
vmap across arbitrary leading dimensions. #14034

Answered by jakevdp
andrewwarrington asked this question in Ideas

It sounds like what you want is similar to the semantics of jnp.vectorize? It might look something like this, using the variables you defined:

from functools import partial
import jax.numpy as jnp

# A (shape (n, m)), x, and c are the arrays defined in the question.
@partial(jnp.vectorize, signature='(m),()->(n)')
def fn(x, c):
  return A @ x + c

out = fn(x, c)

jnp.vectorize is itself implemented via nested calls to vmap, as in your approach. But if that's not the API you want, I think your wrapper function looks like a great solution.
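For comparison, here is a minimal sketch of a wrapper that maps over arbitrary leading dimensions by stacking one jax.vmap per extra axis, checked against jnp.vectorize. The helper name vmap_leading, the core function apply_A, and the concrete shapes of A and x are hypothetical stand-ins; this is not the asker's actual wrapper and not a JAX API.

```python
import jax
import jax.numpy as jnp


def vmap_leading(f, core_ndim):
    """Hypothetical helper: map f over all leading axes of its single argument.

    f is written for an input with `core_ndim` dimensions; each extra
    leading dimension is handled by wrapping f in one more jax.vmap.
    """
    def wrapped(x):
        g = f
        for _ in range(x.ndim - core_ndim):
            g = jax.vmap(g)  # each vmap maps over one more leading axis
        return g(x)
    return wrapped


# Assumed shapes for illustration: A is (n, m) = (3, 2).
A = jnp.arange(6.0).reshape(3, 2)


def apply_A(v):
    # Core function: one length-m vector in, one length-n vector out.
    return A @ v


# Leading dims (2, 3, 2), core dim m = 2.
x = jnp.arange(24.0).reshape(2, 3, 2, 2)

out_vmap = vmap_leading(apply_A, core_ndim=1)(x)
out_vec = jnp.vectorize(apply_A, signature='(m)->(n)')(x)

print(out_vmap.shape)                         # (2, 3, 2, 3)
print(bool(jnp.allclose(out_vmap, out_vec)))  # True
```

Unlike jnp.vectorize, this sketch handles only a single argument and does no broadcasting of leading dimensions between several inputs (such as x and c in the snippet above), which is one reason the signature-based API can be more convenient.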

Replies: 3 comments

Answer selected by andrewwarrington
2 participants