-
Here is the MWE i have: import jax
import numpy as np
gpus = jax.devices("gpu")
cpus = jax.devices("cpu")
import jax.numpy as jnp
from jax import jit
N = 3000
mat_list = [ np.random.randn(N, N).astype(np.float64) for i in range(5)]
b = np.random.randn(N, 1).astype(np.float64)
mat_jnp_list = jnp.array([jnp.array(x) for x in mat_list])
b_jnp = jnp.array(b)
seq = np.random.randint(0, 5, size=8).tolist()
@jax.partial(jit, static_argnums=0)
def evolve(seq, b_jnp, mat_jnp_list):
v_jnp = b_jnp
for i in seq:
v_jnp = np.dot(mat_jnp_list[i], v_jnp)
return v_jnp
seq_jnp = jnp.array(seq)
v_jnp = evolve(tuple(seq), b_jnp, mat_jnp_list) I have the error of the following: ---------------------------------------------------------------------------
Exception Traceback (most recent call last)
<ipython-input-22-d8bc7e0a30d8> in <module>()
15 tic = time.time()
16 seq_jnp = jnp.array(seq)
---> 17 v_jnp = evolve(seq, b_jnp, mat_jnp_list)
18
19 toc = time.time()
~/Downloads/jax/jax/api.py in f_jitted(*args, **kwargs)
215 backend=backend,
216 name=flat_fun.__name__,
--> 217 donated_invars=donated_invars)
218 return tree_unflatten(out_tree(), out)
219
~/Downloads/jax/jax/core.py in bind(self, fun, *args, **params)
1160
1161 def bind(self, fun, *args, **params):
-> 1162 return call_bind(self, fun, *args, **params)
1163
1164 def process(self, trace, fun, tracers, params):
~/Downloads/jax/jax/core.py in call_bind(primitive, fun, *args, **params)
1151 tracers = map(top_trace.full_raise, args)
1152 with maybe_new_sublevel(top_trace):
-> 1153 outs = primitive.process(top_trace, fun, tracers, params)
1154 return map(full_lower, apply_todos(env_trace_todo(), outs))
1155
~/Downloads/jax/jax/core.py in process(self, trace, fun, tracers, params)
1163
1164 def process(self, trace, fun, tracers, params):
-> 1165 return trace.process_call(self, fun, tracers, params)
1166
1167 def post_process(self, trace, out_tracers, params):
~/Downloads/jax/jax/core.py in process_call(self, primitive, f, tracers, params)
573
574 def process_call(self, primitive, f, tracers, params):
--> 575 return primitive.impl(f, *tracers, **params)
576 process_map = process_call
577
~/Downloads/jax/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
555 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
556 compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
--> 557 *unsafe_map(arg_spec, args))
558 try:
559 return compiled_fun(*args)
~/Downloads/jax/jax/linear_util.py in memoized_fun(fun, *args)
245 fun.populate_stores(stores)
246 else:
--> 247 ans = call(fun, *args)
248 cache[key] = (ans, fun.stores)
249
~/Downloads/jax/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
630 abstract_args, arg_devices = unzip2(arg_specs)
631 if config.omnistaging_enabled:
--> 632 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
633 if any(isinstance(c, core.Tracer) for c in consts):
634 raise core.UnexpectedTracerError("Encountered an unexpected tracer.")
~/Downloads/jax/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals)
1036 main.source_info = fun_sourceinfo(fun.f) # type: ignore
1037 main.jaxpr_stack = () # type: ignore
-> 1038 jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
1039 del main
1040 return jaxpr, out_avals, consts
~/Downloads/jax/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
1017 trace = DynamicJaxprTrace(main, core.cur_sublevel())
1018 in_tracers = map(trace.new_arg, in_avals)
-> 1019 ans = fun.call_wrapped(*in_tracers)
1020 out_tracers = map(trace.full_raise, ans)
1021 jaxpr, out_avals, consts = frame.to_jaxpr(in_tracers, out_tracers)
~/Downloads/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
154
155 try:
--> 156 ans = self.f(*args, **dict(self.params, **kwargs))
157 except:
158 # Some transformations yield from inside context managers, so we have to
<ipython-input-22-d8bc7e0a30d8> in evolve(seq, b_jnp, mat_jnp_list)
8 # for i in seq:
9 for i in range(N):
---> 10 v_jnp = np.dot(mat_jnp_list[i], v_jnp)
11
12 return v_jnp
<__array_function__ internals> in dot(*args, **kwargs)
~/Downloads/jax/jax/core.py in __array__(self, *args, **kw)
439 "JAX Tracer instance; in that case, you can instead write "
440 "`jax.device_put(x)[idx]`.")
--> 441 raise Exception(msg)
442
443 def __init__(self, trace: Trace):
Exception: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[3000,1])>with<DynamicJaxprTrace(level=0/1)>.
This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using `jnp` together with `import jax.numpy as jnp` rather than using `np` via `import numpy as np`. If this error arises on a line that involves array indexing, like `x[idx]`, it may be that the array being indexed `x` is a raw numpy.ndarray while the indices `idx` are a JAX Tracer instance; in that case, you can instead write `jax.device_put(x)[idx]`. Does anyone know what I can do to help the jit to support indexing? |
Beta Was this translation helpful? Give feedback.
Answered by
jakevdp
Feb 1, 2021
Replies: 2 comments
-
In this line: v_jnp = np.dot(mat_jnp_list[i], v_jnp) You're using |
Beta Was this translation helpful? Give feedback.
0 replies
Answer selected by
JiahaoYao
-
Oh, i see. The issue is resoved. Thanks, appreciated! |
Beta Was this translation helpful? Give feedback.
0 replies
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
In this line:
You're using
np.dot
within a jitted function, which is not supported. Try changingnp.dot
tojnp.dot
.