Modular functions in JAX #25260
Hi JAX community! I'm fairly new to JAX and really enjoying it so far (it's very accessible to someone from a non-maths/CS background). I have a software-design question about JAX. To give a bit of context, I am trying to write a physical simulation package, with modularity as one of its features. For example, I want solvers that can be easily swapped out, and I am not sure what the best way to go about this is. I know of one approach:

```python
from jax.tree_util import Partial

# `state` and `dt` are placeholder arguments; each solver supplies its own body.
@Partial
def solver_a(state, dt):
    ...

@Partial
def solver_b(state, dt):
    ...
```

This works, but if a solver needs inputs that the other solvers don't, it quickly becomes unwieldy (closures, using …). I've been considering this approach recently:

```python
from typing import NamedTuple

from jax.typing import ArrayLike

class solver_a(NamedTuple):
    unique_params: ArrayLike
    # any other unique params can go here

    # one implementation
    def __call__(self, state, dt) -> float:
        ...

class solver_b(NamedTuple):
    unique_params: float | int

    # another implementation
    def __call__(self, state, dt) -> float:
        ...
```

The benefit of this approach is that named tuples are immutable, and JAX treats them as pytrees by default. Are there any downsides that I am not seeing? Looking at similar packages (e.g. fluid simulations) written in JAX, I see a variety of different approaches. Some use …
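Here is a minimal sketch comparing the two approaches above. It assumes a hypothetical explicit-Euler solver for the decay ODE dy/dt = -k * y. The names `euler_step`, `EulerSolver`, `run` and `trace_count` are illustrative and are not part of JAX.

Both a `jax.tree_util.Partial` and a `NamedTuple` are pytrees, so either can be passed as an ordinary argument to a jitted function. Array leaves such as `k` are traced, while the wrapped function or the tuple's class becomes part of the pytree structure, and therefore of the compilation cache key.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.tree_util import Partial

# Counts how often `run` is traced (Python side effects only run while tracing).
trace_count = 0


def euler_step(k, y, dt):
    # Hypothetical solver: one explicit-Euler step of dy/dt = -k * y.
    return y + dt * (-k * y)


class EulerSolver(NamedTuple):
    k: jax.Array  # solver-specific parameter, stored as a pytree leaf

    def __call__(self, y, dt):
        return euler_step(self.k, y, dt)


@jax.jit
def run(solver, y, dt):
    global trace_count
    trace_count += 1
    return solver(y, dt)


y0 = jnp.array(1.0)

# Partial approach: `k` is bound as a pytree leaf, `euler_step` is static aux data.
print(run(Partial(euler_step, jnp.array(2.0)), y0, 0.1))

# NamedTuple approach: a different pytree structure, so `run` is traced again.
print(run(EulerSolver(k=jnp.array(2.0)), y0, 0.1))

# Changing only the parameter value reuses the already-compiled NamedTuple version.
print(run(EulerSolver(k=jnp.array(3.0)), y0, 0.1))
print("traces:", trace_count)
```

Either way, swapping in a solver with a different structure costs one extra trace and compilation, while changing its parameter values does not.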
Replies: 2 comments
If you want to pass a function as a parameter, you can mark it as static (e.g. with `static_argnums` or `static_argnames` in `jax.jit`), and that should be enough. If you want to pass a more general structure, then you could use …
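Here is a minimal sketch of the "mark it static" suggestion. It assumes two hypothetical stepper functions (`euler`, `heun`) for the test problem dy/dt = -y, with `decay` and `step` also being illustrative names.

Plain Python functions are hashable, so they can be passed through `static_argnames`. Each distinct solver then gets its own compiled version. Any solver-specific parameters would have to travel as separate arguments.

```python
import functools

import jax
import jax.numpy as jnp


def decay(y):
    # Hypothetical right-hand side of dy/dt = -y.
    return -y


def euler(rhs, y, dt):
    return y + dt * rhs(y)


def heun(rhs, y, dt):
    k1 = rhs(y)
    k2 = rhs(y + dt * k1)
    return y + 0.5 * dt * (k1 + k2)


@functools.partial(jax.jit, static_argnames="solver")
def step(y, dt, solver):
    # `solver` is static: it must be hashable, and each distinct
    # function passed here triggers a separate compilation.
    return solver(decay, y, dt)


y0 = jnp.array(1.0)
print(step(y0, 0.1, solver=euler))
print(step(y0, 0.1, solver=heun))
```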
Unsurprisingly I'll advocate for `equinox.Module` as the best way to do this. :)

Of the options you list:

- `dict`: unstructured, probably not recommended, in the same way that we prefer classes over dicts in regular Python.
- `equinox.Module`: this is just a dataclass-as-pytree, which handles all the edge cases for you (bound methods, inheritance, etc.). If you have multiple solvers that you want to treat interchangeably, then I think one of its big advantages is its strong support for abstract classes (`AbstractVar`, `__check_init__`; see also this advanced guide).

Regardless of what you do, if you want interchangeable solvers, you might like to look at the solvers in libraries like Lineax, Optimistix, or Diffrax for inspiration, e.g. how …