
Bug in optax 0.2.4: Adam optimiser does not work in a jax tracer function, but Optax 0.2.3 does #1159

Open
olive004 opened this issue Dec 19, 2024 · 1 comment

olive004 commented Dec 19, 2024

Description:
I am training a simple VAE with Jax 0.4.29 (CUDA enabled) and Optax 0.2.4, using the optax sgd optimiser to update my parameters. However, when I switch to the adam optimiser, the pytree of parameters and their gradients is no longer accepted by the solver.update function, even though the parameters themselves are unchanged, and the following error is thrown:

ValueError: Expected dict, got Traced<ShapedArray(float32[128,6])>with<DynamicJaxprTrace(level=1/0)>.

When I switch to the SGD optimiser, I can get the updates successfully and move on to optax.apply_updates. When I switch back to Optax 0.2.3, the error goes away for the Adam optimiser as well, and the weights in the jax tree are unwrapped successfully without the error claiming the jax arrays are not dicts.

The code to reproduce this is very long, so I will not include it here, but if it's helpful, my training functions are in this GitHub repository file.

@rdyro rdyro self-assigned this Dec 19, 2024

rdyro commented Dec 23, 2024

Hey, I took a look, but can't see any obvious changes between versions 0.2.3 and 0.2.4 that could cause something like this. Without a repro, this error might indicate that one of your jitted functions is leaking tracer arrays during tracing.

A repro would be really useful for debugging this. Is there maybe a shell command I can use to run a repro against the git repo you linked?
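For illustration, the kind of tracer leak mentioned above can be sketched as follows (a hedged, hypothetical example, not the reporter's code): a traced value is stored into Python-level state from inside a jitted function, then touched after tracing has finished.

```python
# Sketch of a tracer leak: appending a traced array to Python-level
# state inside jax.jit, then using it after tracing. Assumes jax is installed.
import jax
import jax.numpy as jnp

stash = []  # Python-level mutable state captured by the closure

@jax.jit
def step(x):
    stash.append(x)  # inside jit, x is a tracer; storing it here leaks it
    return x * 2.0

step(jnp.ones(3))  # runs fine, but stash[0] now holds an escaped tracer

# Any later use of stash[0] (e.g. handing it to solver.update) fails,
# typically with jax.errors.UnexpectedTracerError or a pytree/type error.
```

The fix is the same as in the standard pattern: pass state like params and opt_state in and out of the jitted function as explicit arguments instead of mutating captured variables.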
