What is the best way to integrate a function? #361
Comments
This would work, but might not be that efficient, algorithmically speaking. To explain this, I like to distinguish between ODEs (when the vector field depends on the solution $y$) and integrals (when it depends only on $t$). ODE solvers are based on the premise that small errors made early on in the integration will result in downstream errors later in the integration -- as the evolving solution feeds back into the vector field. But for an integral like this, that premise no longer holds. That means the solver can be less conservative, and e.g. allow itself to take larger timesteps whilst still getting a sufficiently accurate solution. That means integral-specific solvers (i.e. "quadrature methods") can accomplish the same goal with less computational work. Sadly we don't really have a great quadrature library in JAX yet. (Quadrature and interpolation are the two pieces we still have missing, really.) So if you particularly care about efficiency, I'd suggest coding up your own quadrature rule. But if you just care about getting a result and aren't too fussed about speed, then go ahead and keep using Diffrax!
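As a rough illustration of what "coding up your own quadrature rule" could mean, here is a minimal sketch of a fixed-order Gauss-Legendre rule in JAX. The function name `gauss_legendre`, the node count `n=16`, and the test integrand `jnp.sin` are assumptions made for this example; they do not come from any library discussed in the thread.

```python
import numpy as np
import jax
import jax.numpy as jnp


def gauss_legendre(h, a, b, n=16):
    """Approximate the integral of h over [a, b] with an n-point Gauss-Legendre rule.

    `h` is assumed to map a scalar t to a scalar (hypothetical integrand).
    """
    # Nodes and weights on [-1, 1]; computed once with NumPy, outside of any tracing.
    x, w = np.polynomial.legendre.leggauss(n)
    x = jnp.asarray(x)
    w = jnp.asarray(w)
    # Affine map from [-1, 1] to [a, b].
    t = 0.5 * (b - a) * x + 0.5 * (b + a)
    # Evaluate the integrand at every node and form the weighted sum.
    return 0.5 * (b - a) * jnp.sum(w * jax.vmap(h)(t))


# Example: integrate sin(t) over [0, 1]; the exact value is 1 - cos(1).
approx = gauss_legendre(jnp.sin, 0.0, 1.0)
exact = 1.0 - jnp.cos(1.0)
```

A fixed rule like this evaluates the integrand at a predetermined set of points, so it has no error control; an adaptive scheme would be needed for integrands that are not smooth.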
If you don't mind a little self promotion, I've been working on a library for quadrature in JAX that should do what you need: https://github.com/f0uriest/quadax And in response to @patrick-kidger's comment, I also have one for interpolation and splines: https://github.com/f0uriest/interpax
Haha, self promotion is absolutely encouraged! I think these are both really cool (and I think I've seen them before, and have been meaning to check them out). Poking through Quadax a little bit, two quick questions/comments:
(I should emphasise that I really like the look of both these libraries :) )
Thanks! I'm a big fan of Equinox so I had considered using your version of wrt
So when it comes to "running all loops until the final step", note that this will be until the final step of the overall scan -- to be precise, the (presumably-expensive) body function will be evaluated on every step, for the fixed length of the entire scan. This includes the steps after all batch elements are done -- it's just that in this case the body function will be evaluated and its result discarded. If that's what you mean and you are okay with that, then sure. In practice for diffeqsolves it's fairly common to have a maximum number of steps set at e.g. 10^4, but to only make e.g. 10^1 steps most of the time. So there it's important that we exhibit early-exit behaviour. The status of Equinox's while loop: it's a stable API, and the only reason it isn't documented is that it's easily footgunnable (see the warnings in its docstring). I don't think that's safe enough for an average user -- the rule I've gone with for the Eqx ecosystem has been that I'd rather not do something than do it badly. Too easy to break user trust otherwise. (=exactly the reason I set up my own thing rather than use the Julia ecosystem!) So I think you should be good to use it, just scrutinise what you do carefully :)
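To make the difference between a fixed-length scan and an early-exit loop concrete, here is a small sketch using only `jax.lax.scan` and `jax.lax.while_loop`. It is an illustration under assumptions: the Newton iteration for $\sqrt{2}$ stands in for an expensive body function, and the names `body`, `not_converged`, `solve_with_scan`, and `solve_with_while` are hypothetical. It does not use Equinox's undocumented while loop.

```python
import jax.numpy as jnp
from jax import lax

MAX_STEPS = 10_000  # upper bound on the number of steps, as in the 10^4 example


def body(x):
    # Stand-in for an expensive step: one Newton iteration for sqrt(2).
    return 0.5 * (x + 2.0 / x)


def not_converged(x):
    return jnp.abs(x * x - 2.0) > 1e-5


def solve_with_scan(x0):
    # Fixed-length loop: `body` runs on all MAX_STEPS iterations.
    def step(carry, _):
        x, n_evals = carry
        new_x = body(x)  # always evaluated, even after convergence
        x = jnp.where(not_converged(x), new_x, x)  # result discarded once done
        return (x, n_evals + 1), None

    (x, n_evals), _ = lax.scan(step, (x0, jnp.int32(0)), None, length=MAX_STEPS)
    return x, n_evals


def solve_with_while(x0):
    # Early-exit loop: stops as soon as the condition becomes false.
    def cond(carry):
        x, n_evals = carry
        return not_converged(x) & (n_evals < MAX_STEPS)

    def step(carry):
        x, n_evals = carry
        return body(x), n_evals + 1

    return lax.while_loop(cond, step, (x0, jnp.int32(0)))


x_scan, n_scan = solve_with_scan(jnp.float32(1.0))
x_while, n_while = solve_with_while(jnp.float32(1.0))
```

Both versions return the same value, but the scan counts `MAX_STEPS` body evaluations while the while loop stops after a handful. One trade-off worth keeping in mind is that `lax.scan` supports reverse-mode differentiation, whereas `lax.while_loop` does not.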
Let's say I have a function $h(t)$ that I can evaluate at any $t$ and I want to calculate $$\int_0^1 h(t) dt.$$
Would `diffrax` used like that be a good idea?