Description:
I am training a simple VAE with JAX 0.4.29 (CUDA enabled) and Optax 0.2.4, and was using the optax sgd optimiser to update my parameters. However, when I switch to the adam optimiser, the pytree of parameters and their gradients is no longer accepted by the solver.update function, even without any changes to the parameters, and the following error is thrown:

ValueError: Expected dict, got Traced<ShapedArray(float32[128,6])>with<DynamicJaxprTrace(level=1/0)>.

When I switch back to the SGD optimiser, I get the updates successfully and can move on to optax.apply_updates. When I downgrade to Optax 0.2.3, the error goes away for the Adam optimiser as well, and the weights in the JAX pytree are unwrapped successfully without the error claiming the JAX arrays are not dicts.
The code to reproduce this is very long, so I will not include it here, but if it's helpful, my training functions are in this GitHub repository file.
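For reference, here is a minimal sketch of the update pattern described above. It is not the code from the linked repository; the parameter pytree, shapes, and loss function are illustrative placeholders.

```python
# Minimal sketch of the reported pattern: a jitted training step where only the
# optimiser construction changes between sgd and adam.
import jax
import jax.numpy as jnp
import optax

# Illustrative parameter pytree (dict of arrays), not the real VAE parameters.
params = {"w": jnp.zeros((128, 6)), "b": jnp.zeros((6,))}

def loss_fn(params, x):
    return jnp.mean((x @ params["w"] + params["b"]) ** 2)

solver = optax.adam(1e-3)        # swapping in optax.sgd(1e-3) reportedly works
opt_state = solver.init(params)

@jax.jit
def train_step(params, opt_state, x):
    grads = jax.grad(loss_fn)(params, x)
    updates, opt_state = solver.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state

params, opt_state = train_step(params, opt_state, jnp.ones((32, 128)))
```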
Hey, I took a look, but can't see any obvious changes between versions 0.2.3 and 0.2.4 that could cause something like this. Without a repro, this error might indicate that one of your jitted functions is leaking tracer arrays during tracing.
A repro would be really useful for debugging this. Is there maybe a shell command I can use to run a repro from the git repo you linked?
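To illustrate the kind of tracer leak mentioned above, here is a hypothetical pattern (not taken from the linked repository) where traced values escape a jitted function and are later fed back into the optimiser:

```python
# Hypothetical illustration of a leaked tracer: gradients computed inside a
# jitted function are stashed in an outer-scope variable at trace time.
import jax
import jax.numpy as jnp
import optax

solver = optax.adam(1e-3)
captured_grads = None  # illustrative global, not from the linked repo

def loss_fn(params, x):
    return jnp.mean((x @ params["w"]) ** 2)

@jax.jit
def compute_grads(params, x):
    global captured_grads
    grads = jax.grad(loss_fn)(params, x)
    captured_grads = grads  # grads are tracers here; they escape the trace
    return grads

params = {"w": jnp.zeros((128, 6))}
opt_state = solver.init(params)
compute_grads(params, jnp.ones((32, 128)))

# Passing the escaped tracers back into the optimiser later fails (e.g. with
# jax.errors.UnexpectedTracerError) instead of behaving like concrete arrays:
# updates, opt_state = solver.update(captured_grads, opt_state, params)
```

If a leak like this is suspected, running the training step under `with jax.checking_leaks():` can help pinpoint where a tracer escapes.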