use jnp.nanmin and jnp.nanmax to compute new stepsize factor #235
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
…instead of of jnp.clip in diffrax.step_size_controller.adaptive.adapt_step_size()
Care was taken to make sure that the order of jnp.nanmin and jnp.nanmax reflects the actual behavior of jnp.clip. According to https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.clip.html, using jnp.(nan)min and jnp.(nan)max may be a bit slower than jnp.clip but at least it will be robust against NaN's in y_error.
This solves #223 and is similar to my proposed bugfix / pull request for jax.experimental.ode.optimal_step_size() here: google/jax#14612 and google/jax#14624
I confirmed that this works with different explicit/implicit diffrax solvers and I get the expected correct solution to my non-autonomous ODE system vs. scipy.integrate.solve_ivp, manual non-adaptive Euler integration with extremely small timesteps, and manual adaptive RK23 (Bogacki-Shampine) solver in both pure python and JAX.