Contributing a differentiable and jittable sparse linear solver to JAX core? #18452
-
Hello all! @denizokt and I recently implemented a direct sparse linear solver that is jittable and differentiable via a custom VJP that leverages the implicit function theorem. The solver works on both CPU and GPU. You can check it out here: https://github.com/arpastrana/jax_fdm/blob/main/src/jax_fdm/equilibrium/sparse.py

We built the solver as part of a research project on differentiable mechanics for inverse design. Many mechanical simulations rely on sparse matrices to compute the structural response of a system efficiently. At the time, we found limited support in JAX for direct, differentiable sparse solvers (especially on CPU), so we decided to take a stab at it. We think this solver could be useful to other JAX + mechanics folks like us!

Is there any chance we could contribute our differentiable sparse solver to JAX core? Happy to hear your thoughts and suggestions. Thanks!
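For readers who want the gist of the custom-VJP approach without opening the repository, here is a minimal sketch of the idea. It is not the jax_fdm implementation: it assumes a square matrix with a fixed sparsity pattern given as COO row/column index arrays, it delegates both the forward and the adjoint solve to SciPy's `spsolve` on the host through `jax.pure_callback` (so as written it is CPU-only and not vmappable), and `make_sparse_solve` is a hypothetical name. The backward pass follows from differentiating A x = b: with lam = A^{-T} g, the cotangent of b is lam and the cotangent of each stored entry A_ij is -lam_i x_j.

```python
import functools

import numpy as np
import scipy.sparse
import scipy.sparse.linalg
import jax
import jax.numpy as jnp


def make_sparse_solve(rows, cols, n):
    """Build a jittable, differentiable solver for A x = b, where A is n x n,
    its nonzero pattern (rows[k], cols[k]) is fixed, and its values are
    passed in as `data`. Hypothetical helper, not the jax_fdm API."""
    rows = np.asarray(rows)
    cols = np.asarray(cols)

    def host_solve(data, b, transpose):
        # Runs on the host: assemble a CSC matrix from the COO triplets
        # (duplicate entries are summed) and call SciPy's direct solver.
        A = scipy.sparse.csc_matrix(
            (np.asarray(data, dtype=np.float64), (rows, cols)), shape=(n, n)
        )
        if transpose:
            A = A.T.tocsc()
        x = scipy.sparse.linalg.spsolve(A, np.asarray(b, dtype=np.float64))
        return np.asarray(x, dtype=b.dtype)

    def solve(data, b, transpose):
        # pure_callback lets jit-compiled code call back into NumPy/SciPy.
        out = jax.ShapeDtypeStruct(b.shape, b.dtype)
        return jax.pure_callback(
            functools.partial(host_solve, transpose=transpose), out, data, b
        )

    @jax.custom_vjp
    def sparse_solve(data, b):
        return solve(data, b, transpose=False)

    def fwd(data, b):
        x = solve(data, b, transpose=False)
        return x, (data, x)  # residuals needed by the backward pass

    def bwd(res, g):
        data, x = res
        # Adjoint solve A^T lam = g (implicit function theorem on A x = b).
        lam = solve(data, g, transpose=True)
        # dL/db = lam; dL/dA_ij = -lam_i * x_j, evaluated only at stored entries.
        return -lam[rows] * x[cols], lam

    sparse_solve.defvjp(fwd, bwd)
    return sparse_solve


# Usage: a 3 x 3 symmetric tridiagonal system stored as COO triplets.
rows = [0, 0, 1, 1, 1, 2, 2]
cols = [0, 1, 0, 1, 2, 1, 2]
data = jnp.array([4.0, 1.0, 1.0, 4.0, 1.0, 1.0, 4.0])
b = jnp.array([1.0, 2.0, 3.0])

sparse_solve = make_sparse_solve(rows, cols, n=3)
x = jax.jit(sparse_solve)(data, b)


def loss(d):
    return jnp.sum(sparse_solve(d, b) ** 2)


g_sparse = jax.grad(loss)(data)


# Dense reference for comparison.
def dense(d):
    return jnp.zeros((3, 3)).at[np.asarray(rows), np.asarray(cols)].set(d)


x_ref = jnp.linalg.solve(dense(data), b)
g_dense = jax.grad(lambda d: jnp.sum(jnp.linalg.solve(dense(d), b) ** 2))(data)
```

Because only the stored entries receive gradients, the backward pass costs one extra sparse solve with A^T plus an O(nnz) gather, and it never forms a dense matrix.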
-
Have you come across JAXopt (https://github.com/google/jaxopt) or lineax? These might be the best places to look for overlap and to start a discussion. You can also read our recent thinking on JAX numpy/scipy scope here.
cc @mblondel @fabianp @patrick-kidger @jakevdp
-
Hello @arpastrana, I am keenly interested in trying out the sparse solver you have developed. What are the requirements on the format of matrix A? Would it be possible to define A as a jax.experimental.sparse.BCOO matrix? A simple example illustrating its usage would be highly appreciated. Thank you in advance for your help.
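Only the authors can say which input format their solver expects, but for anyone wondering how BCOO fits into this picture: a `jax.experimental.sparse.BCOO` array stores exactly the COO triplets a COO-based solver would consume, in `.data` (shape `(nse,)`) and `.indices` (shape `(nse, 2)`). The sketch below pulls those out and, as a stand-in for a direct solver, uses JAX's built-in iterative `jax.scipy.sparse.linalg.cg`, which only needs a matrix-vector product and is jittable and differentiable. It assumes A is symmetric positive definite; a general matrix would need a different solver.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import sparse
from jax.scipy.sparse.linalg import cg

# A small symmetric positive definite matrix stored as a BCOO array.
dense = jnp.array([[4.0, 1.0, 0.0],
                   [1.0, 4.0, 1.0],
                   [0.0, 1.0, 4.0]])
A = sparse.BCOO.fromdense(dense)

# BCOO keeps COO triplets: one (row, col) pair per stored value.
rows, cols = A.indices[:, 0], A.indices[:, 1]
values = A.data

b = jnp.array([1.0, 2.0, 3.0])


@jax.jit
def solve(values, b):
    # Rebuild the BCOO from its parts so gradients flow to `values`,
    # then solve iteratively using only matrix-vector products.
    M = sparse.BCOO((values, A.indices), shape=A.shape)
    x, _ = cg(lambda v: M @ v, b)
    return x


x = solve(values, b)

# Gradient of a scalar loss with respect to the stored values.
g = jax.grad(lambda v: jnp.sum(solve(v, b) ** 2))(values)
```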