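To add on, I think it would be ergonomic to include an `in_axes` argument, so that some inputs are scanned over while others are passed unchanged to every step:

```python
import jax
import jax.numpy as jnp

theta = jnp.ones((3, 3))
x = jnp.ones((10, 3))
h = jnp.zeros((3,))

def f(carry, inputs):
    theta, x = inputs
    h = carry
    return h + theta @ x, None

# Proposed (hypothetical) API: sequence-map (smap?) over x but not over the params theta
scanf = jax.scan(f, in_axes=(None, 0))
scanf(init=h, xs=(theta, x))
```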
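With today's API, the same behaviour can be approximated by splitting `xs` before calling `jax.lax.scan` and closing over the pieces that are not scanned. A minimal sketch, assuming a hypothetical helper `scan_with_in_axes` (not part of JAX) that accepts only `0` (scan over the leading axis) or `None` (pass the value unchanged to every step) for each element of `xs`:

```python
import jax
import jax.numpy as jnp


def scan_with_in_axes(f, in_axes):
    """Hypothetical helper (not a JAX API): scan f over the elements of xs whose
    in_axes entry is 0, and pass elements whose entry is None unchanged to every step."""
    for ax in in_axes:
        if ax is not None and ax != 0:
            raise ValueError("only in_axes entries of 0 or None are supported")

    def scanned(init, xs):
        # Split xs into the part scanned over its leading axis and the fixed part.
        scanned_xs = tuple(v for v, ax in zip(xs, in_axes) if ax == 0)
        fixed_xs = tuple(v for v, ax in zip(xs, in_axes) if ax is None)

        def step(carry, sliced):
            # Rebuild the per-step inputs in the caller's original order.
            s_iter, c_iter = iter(sliced), iter(fixed_xs)
            inputs = tuple(next(s_iter) if ax == 0 else next(c_iter) for ax in in_axes)
            return f(carry, inputs)

        return jax.lax.scan(step, init, scanned_xs)

    return scanned


theta = jnp.ones((3, 3))
x = jnp.ones((10, 3))
h = jnp.zeros((3,))


def f(carry, inputs):
    theta, x = inputs
    return carry + theta @ x, None


scanf = scan_with_in_axes(f, in_axes=(None, 0))
h_final, _ = scanf(h, (theta, x))  # theta is reused at every step; x is scanned over
```

Because `theta` is closed over rather than sliced, every step sees the same array, which is what the `(None, 0)` setting describes.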
-
Is there a reason why `jax.lax.scan` does not transform functions? Having `jit` and `vmap` operate on functions is super useful and very readable. It is a shame that it doesn't work for `scan`. For example, ideally one could do something like:
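A minimal sketch of that transform-style usage, built on `jax.lax.scan` as it exists today. `scan_transform` is a hypothetical name (it is not a JAX API); like `jax.jit(f)` or `jax.vmap(f)`, it takes the body function and returns a new function, here one mapping `(init, xs)` to `(final_carry, ys)`:

```python
import jax
import jax.numpy as jnp


def scan_transform(f, **scan_kwargs):
    """Hypothetical transform-style wrapper (not a JAX API): returns a function
    (init, xs) -> (final_carry, ys) that calls jax.lax.scan with body f."""
    def scanned(init, xs):
        return jax.lax.scan(f, init, xs, **scan_kwargs)
    return scanned


def cumsum_step(carry, x):
    carry = carry + x
    return carry, carry  # new carry, per-step output


scanf = scan_transform(cumsum_step)
total, running = scanf(jnp.zeros(()), jnp.arange(5.0))
```

Extra keyword arguments such as `length`, `reverse` or `unroll` are forwarded to `jax.lax.scan` unchanged.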
Or even use a decorator:
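The decorator form could look like this minimal sketch, where `scanned` is a hypothetical decorator (not part of JAX) that wraps a scan body with `jax.lax.scan`:

```python
import functools

import jax
import jax.numpy as jnp


def scanned(f):
    """Hypothetical decorator: turn a scan body f(carry, x) -> (carry, y) into a
    function (init, xs) -> (final_carry, ys) that runs jax.lax.scan."""
    @functools.wraps(f)
    def wrapper(init, xs):
        return jax.lax.scan(f, init, xs)
    return wrapper


@scanned
def running_max(carry, x):
    carry = jnp.maximum(carry, x)
    return carry, carry  # new carry, per-step output


init = jnp.full((), -jnp.inf, dtype=jnp.float32)
best, history = running_max(init, jnp.array([3.0, 1.0, 4.0, 1.0, 5.0]))
```

Since the decorated result is an ordinary Python function, it can still be wrapped with `jax.jit`.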
I suppose you can use `functools.partial`, but it's less clean:
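For comparison, the `functools.partial` approach works with the current `jax.lax.scan(f, init, xs, ...)` signature because `f` is the first positional argument; keyword options such as `reverse` can be bound the same way. A minimal sketch (the `step` body is illustrative):

```python
import functools

import jax
import jax.numpy as jnp


def step(carry, x):
    carry = carry + x
    return carry, carry  # new carry, per-step output (a running sum)


# Bind the body now and supply init/xs later, similar in spirit to jax.jit(f).
scanf = functools.partial(jax.lax.scan, step)
total, running = scanf(jnp.zeros(()), jnp.arange(1.0, 4.0))

# Keyword options can be bound too; reverse=True scans from the last element,
# and ys stays aligned with xs.
scanf_rev = functools.partial(jax.lax.scan, step, reverse=True)
total_rev, running_rev = scanf_rev(jnp.zeros(()), jnp.arange(1.0, 4.0))
```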
This is somewhat related to #23487.