abstracted_axes and eval_jaxpr #18567
Hi, I am now looking into
I am trying to

```python
def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True):
  def read(v: Atom) -> Any:
    return v.val if isinstance(v, Literal) else env[v]

  def write(v: Var, val: Any) -> None:
    if config.jax_enable_checks and not config.jax_dynamic_shapes:
      assert typecheck(v.aval, val), (v.aval, val)
    env[v] = val

  env: dict[Var, Any] = {}
  map(write, jaxpr.constvars, consts)
  # Here, can I add the implicit input variables from args and perhaps some other sources of information?
  # E.g., args = (args[0].shape[0], *args)  (args is a tuple, so it has no .insert())
  map(write, jaxpr.invars, args)
  ...
```

Thanks!
Thanks for the question!
Actually, the parameter isn't implicit in the jaxpr; it appears as just an ordinary parameter, and correspondingly the caller can just pass it as an ordinary argument:

```python
import jax
import jax.numpy as jnp

jax.config.update('jax_dynamic_shapes', True)

def f(x):
  return jnp.sin(x) + jnp.cos(x)

jaxpr = jax.make_jaxpr(f, abstracted_axes={0: 'n'})(jnp.arange(3.)).jaxpr
print(jaxpr)
# { lambda ; a:i32[] b:f32[a]. let
#     c:f32[a] = sin b
#     d:f32[a] = cos b
#     e:f32[a] = add c d
#   in (e,) }

# Note: jax._src is internal, so this import path may change.
from jax._src.core import eval_jaxpr

# The axis size `a` is passed explicitly as the first argument.
ans, = eval_jaxpr(jaxpr, (), 3, jnp.arange(3.))
print(ans)  # [1. 1.3817732 0.49315056]
```

The axis size arguments (and outputs) are only implicit at the Python tracing level; everything is made explicit at the jaxpr level. That means things like the dispatch path for jitted computations need to infer axis sizes from arguments, as you describe, but not …

That said, we're not actively working on this dynamic shape stuff right now, so there's probably lots of other stuff that doesn't work. WDYT?
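To illustrate the inference step that a caller (or a dispatch path) has to perform before calling `eval_jaxpr`, here is a minimal pure-Python sketch. It is not JAX's API: `infer_axis_size_args` and its `size_vars`/`array_dims` descriptions are hypothetical, written by hand to mirror the `a:i32[] b:f32[a]` signature of the jaxpr above. It reads only each argument's `.shape`, binds each named size on first sight, checks later occurrences for consistency, and prepends the sizes to the array arguments.

```python
from types import SimpleNamespace
from typing import Any, Sequence, Union

# A dimension is either a concrete int or the name of a size variable,
# e.g. ("a",) for an array typed f32[a] in the jaxpr above.
Dim = Union[int, str]


def infer_axis_size_args(size_vars: Sequence[str],
                         array_dims: Sequence[tuple[Dim, ...]],
                         arrays: Sequence[Any]) -> list[Any]:
    """Hypothetical helper: build the full, explicit argument list.

    size_vars:  names of the leading scalar axis-size invars, in order.
    array_dims: symbolic shape of each remaining (array) invar.
    arrays:     the actual array arguments; only `.shape` is read.
    """
    if len(array_dims) != len(arrays):
        raise ValueError("expected one symbolic shape per array argument")
    env: dict[str, int] = {}
    for dims, x in zip(array_dims, arrays):
        shape = tuple(x.shape)
        if len(shape) != len(dims):
            raise ValueError(f"rank mismatch: expected {dims}, got {shape}")
        for d, s in zip(dims, shape):
            if isinstance(d, str):
                # The first occurrence binds the size; later ones must agree.
                if env.setdefault(d, s) != s:
                    raise ValueError(f"inconsistent size for {d!r}: {env[d]} vs {s}")
            elif d != s:
                raise ValueError(f"expected static size {d}, got {s}")
    missing = [v for v in size_vars if v not in env]
    if missing:
        raise ValueError(f"cannot infer axis sizes for {missing}")
    # Axis sizes come first, matching the invar order `a:i32[] b:f32[a]`.
    return [env[v] for v in size_vars] + list(arrays)


if __name__ == "__main__":
    x = SimpleNamespace(shape=(3,))  # stand-in for jnp.arange(3.)
    print(infer_axis_size_args(["a"], [("a",)], [x])[0])  # 3
```

With dynamic shapes enabled, `eval_jaxpr(jaxpr, (), *infer_axis_size_args(["a"], [("a",)], [jnp.arange(3.)]))` would pass the same arguments as the explicit `eval_jaxpr(jaxpr, (), 3, jnp.arange(3.))` call. Keeping the inference outside `eval_jaxpr` leaves the evaluator working purely on explicit jaxpr invars.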