use jnp.nanmin and jnp.nanmax to compute new stepsize factor #235

Open · wants to merge 1 commit into main
Conversation

virajpandya

…instead of jnp.clip in diffrax.step_size_controller.adaptive.adapt_step_size()

Care was taken to make sure that the order of jnp.nanmin and jnp.nanmax reflects the actual behavior of jnp.clip. According to https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.clip.html, using jnp.(nan)min and jnp.(nan)max may be a bit slower than jnp.clip, but at least it will be robust against NaNs in y_error.
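As a minimal sketch of the idea (not the actual diffrax code -- the bounds 0.1 and 10.0 are just placeholders for the controller's min/max stepsize factors):

```python
import jax.numpy as jnp

lo, hi = 0.1, 10.0     # placeholder min/max stepsize factors
factor = jnp.nan       # e.g. a factor poisoned by NaNs in y_error

# jnp.clip(x, lo, hi) == min(max(x, lo), hi), but it propagates the NaN:
print(jnp.clip(factor, lo, hi))  # nan

# The nan-aware reductions, composed in the same order, ignore the NaN and
# fall back to the lower bound, so the step is shrunk instead of stalling:
clamped = jnp.nanmin(jnp.array([jnp.nanmax(jnp.array([factor, lo])), hi]))
print(clamped)  # 0.1
```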

This solves #223 and is similar to my proposed bugfix pull requests for jax.experimental.ode.optimal_step_size(): google/jax#14612 and google/jax#14624.

I confirmed that this works with several explicit and implicit diffrax solvers: I get the expected solution to my non-autonomous ODE system when compared against scipy.integrate.solve_ivp, manual non-adaptive Euler integration with extremely small timesteps, and a manual adaptive RK23 (Bogacki-Shampine) solver written in both pure Python and JAX.

@patrick-kidger
Owner

Okay, sorry for taking so long to get back around to this!

One thing I would like to understand is why the previous approach failed. It looks like both possible fixes involve detecting NaNs.

In particular, you mentioned wanting to autodiff through these solvers. If a NaN is generated and then removed later, it can sometimes still reappear on the backward pass. Being robust to such issues usually means catching the NaN as early as possible.
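For a standalone illustration of the kind of thing I mean (the standard jnp.where gotcha; nothing diffrax-specific):

```python
import jax
import jax.numpy as jnp

def f(x):
    # The NaN from sqrt of a negative input is masked out of the *value*...
    return jnp.where(x > 0.0, jnp.sqrt(x), 0.0)

print(f(-1.0))            # 0.0  -- forward pass looks clean
print(jax.grad(f)(-1.0))  # nan  -- 0 * nan from sqrt's derivative leaks through
```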

Are you able to track down how the NaN still sneaks by in the previous approach? I'd definitely like to get some version of this fix in.

@virajpandya
Author

virajpandya commented Mar 23, 2023

Sorry for the delay -- I'm finally getting back to this. Thanks for all the help! The good news is that with my fix, the system can be both successfully solved and forward-mode autodiff'd (jax.jacfwd and jax.jvp). I verified the resulting Jacobian and JVP against finite differences (using atol and rtol much smaller than the parameter perturbations). That was with your old v0.2.2 using NoAdjoint(). Switching to your latest v0.3.1 without either my fix or yours, the solve and forward-mode autodiff are also both successful (using DirectAdjoint()). I don't yet know exactly what changed in v0.3.1 for this to work out of the box now, but it's great!

What doesn't work is reverse-mode autodiff, and I would like to know why: RecursiveCheckpointAdjoint() gives NaN gradients, and BacksolveAdjoint() leads to a 'max_steps reached' error even though the system is solved in 290 steps.

Where in diffrax do you recommend putting jax.debug.print statements so I can see if something wonky is happening when my ODE system is solved backwards in time? I checked ad.py, adjoint.py and integrate.py and I see a lot of equinox calls, but are there specific places where I can print, e.g., the values of the 8 state variables and their time derivatives along the backward pass?
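To make the question concrete, this is the kind of instrumentation I have in mind, shown on a placeholder vector field rather than on diffrax internals:

```python
import jax
import jax.numpy as jnp

def vector_field(t, y, args):
    dy = -y  # dummy dynamics standing in for my actual 8-state system
    # Prints on every evaluation, so calls made during the adjoint/backward
    # solve would show up as well:
    jax.debug.print("t={t}  y={y}  dy={dy}", t=t, y=y, dy=dy)
    return dy
```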
