Description
On my M1 Mac, with jax==0.4.39.dev20241223, jaxlib==0.4.39.dev20241223, and Python 3.12.6, the following code:
crashes when running the last line, with:
When I update test_empty_io_callback_under_shard_map in tests/pjit_test.py at HEAD to match this code (see #25670), the test fails in GitHub's CI and on my Mac. For example, see https://github.com/jax-ml/jax/actions/runs/12470134880/job/34804668225?pr=25670 :
Fatal Python error: Aborted
Current thread 0x00007a4bd1b93000 (most recent call first):
File "/__w/jax/jax/jax/_src/array.py", line 644 in _value
File "/__w/jax/jax/jax/_src/profiler.py", line 333 in wrapper
File "/__w/jax/jax/jax/_src/array.py", line 1113 in shard_sharded_device_array_slow_path
File "/__w/jax/jax/jax/_src/array.py", line 1170 in _array_shard_arg
File "/__w/jax/jax/jax/_src/array.py", line 1209 in _token_shard_arg
File "/__w/jax/jax/jax/_src/interpreters/pxla.py", line 135 in shard_args
File "/__w/jax/jax/jax/_src/profiler.py", line 333 in wrapper
File "/__w/jax/jax/jax/_src/dispatch.py", line 510 in _batched_device_put_impl
File "/__w/jax/jax/jax/_src/core.py", line 941 in process_primitive
File "/__w/jax/jax/jax/_src/core.py", line 468 in bind_with_trace
File "/__w/jax/jax/jax/_src/core.py", line 463 in bind
File "/__w/jax/jax/jax/_src/api.py", line 2294 in device_put
File "/__w/jax/jax/jax/_src/dispatch.py", line 136 in get_token_input
File "/__w/jax/jax/jax/_src/interpreters/pxla.py", line 1249 in <listcomp>
File "/__w/jax/jax/jax/_src/interpreters/pxla.py", line 1248 in _add_tokens_to_inputs
File "/__w/jax/jax/jax/_src/interpreters/pxla.py", line 1288 in __call__
File "/__w/jax/jax/jax/_src/profiler.py", line 333 in wrapper
File "/__w/jax/jax/jax/_src/pjit.py", line 1688 in _pjit_call_impl_python
File "/__w/jax/jax/jax/_src/pjit.py", line 196 in _python_pjit_helper
The code fails with the same error if the last two lines are swapped, i.e., if the non-shard-mapped _f is run before the shard-mapped f.
System info (python version, jaxlib version, accelerator, etc.)