Crash with jit of ordered io_callback under shard_map and then no shard_map on CPU #25671

Open
jburnim opened this issue Dec 23, 2024 · 1 comment

jburnim (Collaborator) commented Dec 23, 2024

Description

On my M1 Mac, with jax==0.4.39.dev20241223 and jaxlib==0.4.39.dev20241223 and Python 3.12.6, the following code:

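import os
# Simulate 4 CPU devices; this must be set before jax is imported.
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'

import jax
import jax.numpy as jnp

from jax.sharding import PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = jax.make_mesh((4,), 'i')

# Host callback that does nothing and returns no outputs.
def empty_callback(x):
  return

def _f(x, y):
  # Ordered side effect; result_shape_dtypes=() means the callback has no outputs.
  jax.experimental.io_callback(empty_callback, (), x, ordered=True)
  return x + y[..., jnp.newaxis]

f = jax.jit(shard_map(
    _f, mesh, in_specs=(P(None, 'i'), P(None)), out_specs=P(None, 'i')))

# First call: _f under shard_map.
print(f(jnp.zeros((2, 16)), jnp.ones(2)))

# Second call: the same _f under plain jit, without shard_map.
print(jax.jit(_f)(jnp.zeros((2, 16)), jnp.ones(2)))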

crashes when running the last line, with:

F external/xla/xla/pjrt/pjrt_client.h:1128] Check failed: on_device_shape().has_layout()
Abort trap: 6

When I update test_empty_io_callback_under_shard_map in tests/pjit_test.py at HEAD to match this code (see #25670), the test fails both in GitHub's CI and on my Mac. For example, see https://github.com/jax-ml/jax/actions/runs/12470134880/job/34804668225?pr=25670 :

Fatal Python error: Aborted

Current thread 0x00007a4bd1b93000 (most recent call first):
  File "/__w/jax/jax/jax/_src/array.py", line 644 in _value
  File "/__w/jax/jax/jax/_src/profiler.py", line 333 in wrapper
  File "/__w/jax/jax/jax/_src/array.py", line 1113 in shard_sharded_device_array_slow_path
  File "/__w/jax/jax/jax/_src/array.py", line 1170 in _array_shard_arg
  File "/__w/jax/jax/jax/_src/array.py", line 1209 in _token_shard_arg
  File "/__w/jax/jax/jax/_src/interpreters/pxla.py", line 135 in shard_args
  File "/__w/jax/jax/jax/_src/profiler.py", line 333 in wrapper
  File "/__w/jax/jax/jax/_src/dispatch.py", line 510 in _batched_device_put_impl
  File "/__w/jax/jax/jax/_src/core.py", line 941 in process_primitive
  File "/__w/jax/jax/jax/_src/core.py", line 468 in bind_with_trace
  File "/__w/jax/jax/jax/_src/core.py", line 463 in bind
  File "/__w/jax/jax/jax/_src/api.py", line 2294 in device_put
  File "/__w/jax/jax/jax/_src/dispatch.py", line 136 in get_token_input
  File "/__w/jax/jax/jax/_src/interpreters/pxla.py", line 1249 in <listcomp>
  File "/__w/jax/jax/jax/_src/interpreters/pxla.py", line 1248 in _add_tokens_to_inputs
  File "/__w/jax/jax/jax/_src/interpreters/pxla.py", line 1288 in __call__
  File "/__w/jax/jax/jax/_src/profiler.py", line 333 in wrapper
  File "/__w/jax/jax/jax/_src/pjit.py", line 1688 in _pjit_call_impl_python
  File "/__w/jax/jax/jax/_src/pjit.py", line 196 in _python_pjit_helper
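
The frames `_add_tokens_to_inputs` and `get_token_input` in the traceback appear to belong to the machinery JAX uses to sequence ordered effects: with `ordered=True`, callbacks must run in program order, including across separately dispatched computations, and JAX threads a runtime token between those computations to enforce that. For readers unfamiliar with the feature, here is a minimal sketch of an ordered `io_callback` under plain `jax.jit` on the default device. It uses no `shard_map` and no forced device count, so it is not a reproducer of this bug; `calls`, `record` and `g` are illustrative names.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

calls = []  # host-side log filled in by the callback (illustrative)

def record(x):
    # Runs on the host. Copy defensively so the log does not alias a runtime buffer.
    calls.append(np.array(x))
    # Returning None matches result_shape_dtypes=() (no outputs), as in the issue.

@jax.jit
def g(x):
    # With ordered=True, JAX runs these two callbacks in program order.
    io_callback(record, (), x, ordered=True)
    io_callback(record, (), x + 1, ordered=True)
    return x * 2

out = g(jnp.arange(3.0))
jax.effects_barrier()  # wait until all pending effects (callbacks) have finished
print(calls, out)
```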

The code fails with the same error if the last two lines are swapped, i.e., if the jitted but not shard-mapped _f is run before the shard-mapped f.
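
Because the failure is a hard abort (`Fatal Python error: Aborted`), it cannot be caught with `try`/`except` or a test assertion in the same process. Below is a minimal sketch of checking both orderings, each in a fresh interpreter. `run_isolated`, `aborted`, `SETUP`, `SHARDED` and `UNSHARDED` are hypothetical names, not part of JAX or its test suite, and the SIGABRT check assumes a POSIX system, where a child killed by a signal reports a negative return code.

```python
import signal
import subprocess
import sys
import textwrap

# Setup copied from the reproducer above (hypothetical harness, not JAX code).
SETUP = textwrap.dedent("""
    import os
    os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'

    import jax
    import jax.numpy as jnp
    from jax.sharding import PartitionSpec as P
    from jax.experimental.shard_map import shard_map

    mesh = jax.make_mesh((4,), 'i')

    def empty_callback(x):
      return

    def _f(x, y):
      jax.experimental.io_callback(empty_callback, (), x, ordered=True)
      return x + y[..., jnp.newaxis]

    f = jax.jit(shard_map(
        _f, mesh, in_specs=(P(None, 'i'), P(None)), out_specs=P(None, 'i')))
""")

SHARDED = "print(f(jnp.zeros((2, 16)), jnp.ones(2)))\n"
UNSHARDED = "print(jax.jit(_f)(jnp.zeros((2, 16)), jnp.ones(2)))\n"


def run_isolated(code: str, timeout: float = 300.0) -> subprocess.CompletedProcess:
    """Run `code` in a fresh interpreter so an abort cannot kill the caller."""
    return subprocess.run([sys.executable, "-c", code],
                          capture_output=True, text=True, timeout=timeout)


def aborted(proc: subprocess.CompletedProcess) -> bool:
    """True if the child was killed by SIGABRT (reported as -SIGABRT on POSIX)."""
    return proc.returncode == -signal.SIGABRT


if __name__ == "__main__":
    for name, order in [("shard_map first", SHARDED + UNSHARDED),
                        ("plain jit first", UNSHARDED + SHARDED)]:
        proc = run_isolated(SETUP + order)
        print(f"{name}: returncode={proc.returncode}, aborted={aborted(proc)}")
```

Per the report, with jaxlib 0.4.39.dev20241223 on CPU both orderings should print `aborted=True`, since the second call aborts in either order.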

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.39.dev20241223
jaxlib: 0.4.39.dev20241223
numpy:  2.2.1
python: 3.12.6 (main, Sep  9 2024, 21:36:32) [Clang 18.1.8 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='jburnim-macbookpro2.roam.internal', release='23.6.0', version='Darwin Kernel Version 23.6.0: Thu Sep 12 23:35:29 PDT 2024; root:xnu-10063.141.1.701.1~1/RELEASE_ARM64_T6000', machine='arm64')
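
The system info above appears to follow the format of `jax.print_environment_info()`. A minimal sketch of collecting it as a string is below. It reports on the process it runs in, so a run without the `XLA_FLAGS` setting from the reproducer reports a single CPU device, as in the `device info` line above.

```python
import jax

# Collect the environment report (jax/jaxlib/numpy/python versions,
# device info, and so on) as a string instead of printing it directly.
info = jax.print_environment_info(return_string=True)
print(info)
```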
@jburnim jburnim added the bug Something isn't working label Dec 23, 2024
jakevdp (Collaborator) commented Dec 23, 2024

Assigning @yashk2810, who may have ideas!

Note however that due to the holiday it may take a while to look into this.
