How to code "if" statement when using "vmap" #4951
I ran into a problem when using `vmap` on a simple function that contains an `if` statement. Here is the code:

```python
import jax
import jax.numpy as jnp
from jax import vmap

def k(u):
    if u == 1.:
        return u
    elif u == -1.:
        return 2*u
    else:
        return 3*u

k_map = vmap(k, (0))
u = jnp.linspace(-1., 1., 5)
k_map(u)
```

Running it raises an error:

Thanks a lot for your help!
Check out the control flow section of "The Sharp Bits" in the documentation; it discusses this!
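A minimal sketch of why the original code fails, assuming a reasonably recent JAX release (where the failure surfaces as `jax.errors.ConcretizationTypeError` or a subclass of it). Inside `vmap`, `u` is not a Python float but an abstract tracer standing in for the whole batch, so Python's `if` has no concrete `True`/`False` to branch on:

```python
import jax
import jax.numpy as jnp

def k(u):
    # Under vmap, `u == 1.` produces a traced boolean array, not a Python bool.
    # Python's `if` calls bool() on it, which a tracer cannot answer.
    if u == 1.:
        return u
    elif u == -1.:
        return 2 * u
    else:
        return 3 * u

u = jnp.linspace(-1., 1., 5)

try:
    jax.vmap(k)(u)
    raised = False
except jax.errors.ConcretizationTypeError:
    # Converting a tracer to a Python bool raises a concretization error.
    raised = True

print(raised)  # True
```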
Issue #196 also explains this (in @mattjj's main reply).
If you need `if`-style branching on traced values in JAX, use `jax.lax.cond` or `jax.lax.switch`.
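As a sketch (not the only way to write it), the three-way branch from the question can be expressed with `jax.lax.switch`. The branch index is computed with `jnp.where`, which operates on traced values, and the index-to-case mapping below (`0` for `u == 1`, `1` for `u == -1`, `2` otherwise) is our own choice for illustration:

```python
import jax.numpy as jnp
from jax import lax, vmap

def k(u):
    # Pick a branch index from the value of u using array ops (trace-safe):
    # 0 -> u == 1, 1 -> u == -1, 2 -> any other value.
    index = jnp.where(u == 1., 0, jnp.where(u == -1., 1, 2))
    branches = [
        lambda x: x,      # case u == 1
        lambda x: 2 * x,  # case u == -1
        lambda x: 3 * x,  # all other values
    ]
    # lax.switch applies branches[index] to the operand u.
    return lax.switch(index, branches, u)

k_map = vmap(k, in_axes=0)
u = jnp.linspace(-1., 1., 5)  # [-1, -0.5, 0, 0.5, 1]
result = k_map(u)             # expected values: -2, -1.5, 0, 1.5, 1
print(result)
```

Note that when the branch index is batched by `vmap`, JAX evaluates every branch and selects the results element-wise, so this is not a short-circuiting branch. For cheap arithmetic like this that is fine; nested `lax.cond` calls would express the same logic.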
Hope that helps!