Replies: 1 comment
-
Have you tried inspecting the sharding annotations on the grads and putting sharding constraints around them so that the all-reduce isn't issued before the accumulation is finished? It also seems you'd need shard_map, which would be better suited for such a non-trivial setup.
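A rough sketch of the shard_map layout suggested here is below; the mesh axis name, placeholder loss, and microbatch count are assumptions for illustration, not taken from the original code:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

NUM_MICROBATCHES = 4  # assumed; must divide the per-device batch evenly

# Hypothetical 1D data-parallel mesh; the axis name "data" is an assumption.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

def loss_fn(params, x, y):
    # Placeholder quadratic loss standing in for the real model.
    pred = x @ params
    return jnp.mean((pred - y) ** 2)

def accumulate(params, x, y):
    # Inside shard_map each device sees only its local shard, so the
    # per-microbatch loss/grad computation stays local: no all-reduce here.
    grads = jax.tree_util.tree_map(jnp.zeros_like, params)
    for mx, my in zip(jnp.split(x, NUM_MICROBATCHES),
                      jnp.split(y, NUM_MICROBATCHES)):
        g = jax.grad(loss_fn)(params, mx, my)
        grads = jax.tree_util.tree_map(jnp.add, grads, g)
    grads = jax.tree_util.tree_map(lambda g: g / NUM_MICROBATCHES, grads)
    # Single cross-device collective at the end of the accumulation loop.
    return jax.tree_util.tree_map(lambda g: jax.lax.pmean(g, "data"), grads)

train_step = jax.jit(
    shard_map(accumulate, mesh=mesh,
              in_specs=(P(), P("data"), P("data")),
              out_specs=P()))
```

With everything inside shard_map operating on per-device shards, the only cross-device collective in the compiled step should be the final pmean over the "data" axis, which can be checked in the profiler.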
-
Hi,
I am in the process of moving over some data-parallel training code written with xmap to the new jit API, with the intent to extend it to allow model-parallel training. The training step function, which was previously xmapped, takes a batch of data, splits the data into microbatches, and then performs gradient accumulation on the microbatches, with a single pmean operation at the end of the accumulation loop to synchronize gradients across devices.

I've been trying to replicate this behavior with jit, unsuccessfully (minimal code reproduction attached below). Every time jnp.mean(loss) is called in the accumulation loop, an all-reduce across all devices is performed, which I have been able to confirm with the JAX profiler. I've tried sharding the batch and then re-sharding every microbatch within the accumulation loop, but the compiled code seems to want to perform this all-reduce no matter the sharding annotations. Is there something I am missing with respect to the sharding annotations, or is this a bug?

Thank you!
Code to reproduce:
What jax/jaxlib version are you using?
Which accelerator(s) are you using?
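For concreteness, a minimal sketch of the kind of jitted accumulation loop described above might look like the following; the mesh axis name, placeholder loss, and microbatch count are illustrative assumptions, and this is not the attached reproduction:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

NUM_MICROBATCHES = 4  # assumed

# Hypothetical 1D data-parallel mesh; the axis name "data" is an assumption.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
batch_sharding = NamedSharding(mesh, P("data"))

def loss_fn(params, x, y):
    # Placeholder quadratic loss standing in for the real model.
    pred = x @ params
    # jnp.mean reduces over the batch axis, which is sharded across devices,
    # so the compiler lowers this to a cross-device all-reduce per microbatch.
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, batch_x, batch_y):
    grads = jax.tree_util.tree_map(jnp.zeros_like, params)
    for mx, my in zip(jnp.split(batch_x, NUM_MICROBATCHES),
                      jnp.split(batch_y, NUM_MICROBATCHES)):
        # Re-applying the sharding constraint keeps each microbatch sharded,
        # but does not remove the all-reduce triggered by the mean above.
        mx = jax.lax.with_sharding_constraint(mx, batch_sharding)
        my = jax.lax.with_sharding_constraint(my, batch_sharding)
        g = jax.grad(loss_fn)(params, mx, my)
        grads = jax.tree_util.tree_map(jnp.add, grads, g)
    return jax.tree_util.tree_map(lambda g: g / NUM_MICROBATCHES, grads)
```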