vmap across arbitrary leading dimensions
#14034
-
I wanted to … The first type of solution I came up with was something like this (illustrated for the case of matrix-vector multiplication):
However, this only supports up to three batch dimensions, so it is pretty tedious, limited, and really ugly (although it is mainly for illustrative purposes). But we can express this as a recursion and get something more general, like this:
which we can call as follows:
The … There are some deviations here from the JAX-ey way of doing things; namely that … In some ways, though, I think this is more JAX-ey than my naive alternatives: in the same way that you should write … I'd love to know whether there is a better way of going about this, either already in JAX or otherwise, and whether this would be a good thing to add. Thanks,
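The question's original snippets are not preserved in this copy. As a hedged sketch of the recursive idea described above (not the author's code), a hypothetical helper `vmap_leading` can wrap a function in one `jax.vmap` per leading batch axis. The matrix `A`, the function `fn`, and the shapes of `x` and `c` are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def vmap_leading(fn, n_batch):
    """Hypothetical helper (not a JAX API): map `fn` over the first
    `n_batch` axes of every argument via nested `jax.vmap` calls."""
    if n_batch == 0:
        return fn
    # The outer vmap peels off axis 0; the recursion handles the rest.
    return jax.vmap(vmap_leading(fn, n_batch - 1))


# Assumed stand-ins for the question's variables.
A = jnp.arange(6.0).reshape(3, 2)        # (n, m) = (3, 2)


def fn(x, c):
    # Core computation: x has shape (m,), c is a scalar -> result (n,).
    return A @ x + c


x = jnp.ones((4, 5, 2))                  # batch axes (4, 5), core axis m = 2
c = jnp.arange(20.0).reshape(4, 5)       # must carry the same batch axes
out = vmap_leading(fn, x.ndim - 1)(x, c)
print(out.shape)  # (4, 5, 3)
```

Note that plain `jax.vmap` (with the default `in_axes=0`) requires every argument to carry the same leading batch axes; it does not broadcast them, which is why `c` above has the full batch shape `(4, 5)`.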
Replies: 3 comments
-
It sounds like what you want is similar to the semantics of `jnp.vectorize`? It might look something like this, using the variables you defined:
```python
from functools import partial

import jax.numpy as jnp

@partial(jnp.vectorize, signature='(m),()->(n)')
def fn(x, c):
    return A @ x + c

out = fn(x, c)
```
This is implemented via nested calls to `vmap`, as in your approach. But if that's not the API you want, I think your wrapper function looks like a great solution.
-
Oh. Sometimes it really is that simple, eh? 😅😂 Thank you for posting this! This is exactly what I was looking for. I was sort of disbelieving that it wouldn't already be in JAX somewhere, but I didn't realise it would be that simple. Thanks!
-
Also linking this Stack Overflow answer, hilariously by jakevdp as well, giving some more context. https://stackoverflow.com/questions/69099847/jax-vectorization-vmap-and-or-numpy-vectorize
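A minimal runnable version of the `jnp.vectorize` suggestion, assuming stand-in values for `A`, `x`, and `c` (the question's actual definitions are not shown). The signature `'(m),()->(n)'` declares the core dimensions: `x` contributes a trailing vector axis `m`, `c` is a scalar per batch element, and the result gains a trailing axis `n`. Every remaining leading axis is treated as a batch axis and broadcast across the arguments, following NumPy's generalized-ufunc conventions.

```python
from functools import partial

import jax.numpy as jnp

# Assumed stand-in for the question's matrix (not shown in the thread).
A = jnp.arange(6.0).reshape(3, 2)   # (n, m) = (3, 2)


@partial(jnp.vectorize, signature='(m),()->(n)')
def fn(x, c):
    # Core computation on a single vector x of shape (m,) and a scalar c.
    return A @ x + c


x = jnp.ones((4, 5, 2))   # batch axes (4, 5), core axis m = 2
c = jnp.arange(5.0)       # batch shape (5,) broadcasts against (4, 5)
out = fn(x, c)
print(out.shape)  # (4, 5, 3)
```

Unlike a hand-rolled stack of `jax.vmap` calls, `jnp.vectorize` broadcasts the batch axes, so here `c[j]` is added to every `out[i, j]` without `c` having to be tiled to shape `(4, 5)` first.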