-
Your main issue is that you used `pmap`, e.g.:

```python
@partial(pmap, axis_name='devices')
def train_step(params, inputs, bboxes, labels, opt_state, clip_state, step):
    (loss_val, params), grads = train_forward(params, inputs, bboxes, labels, step)
    # .....
```

For `pjit`, you need to specify how the data and parameters are sharded:

```python
from functools import partial
from jax.experimental.pjit import pjit
from jax.sharding import NamedSharding, Mesh, PartitionSpec as P

@partial(pjit,
    in_shardings=(replicated_sharding, data_sharding, bboxes_sharding, labels_sharding,
                  replicated_sharding, replicated_sharding, replicated_sharding),
    out_shardings=(replicated_sharding, replicated_sharding, replicated_sharding, replicated_sharding),
)
def train_step(params, inputs, bboxes, labels, opt_state, clip_state, step):
    (loss_val, params), grads = train_forward(params, inputs, bboxes, labels, step)
    # .....
```
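The sharding objects above (`replicated_sharding`, `data_sharding`, and so on) aren't defined in the snippet; a minimal sketch of how they could be built from a 1-D device mesh, assuming the data is split along its leading (batch) dimension, would be:

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over all available devices; the axis name is arbitrary.
mesh = Mesh(np.array(jax.devices()), axis_names=('devices',))

# params, opt_state, clip_state and step are kept identical on every device.
replicated_sharding = NamedSharding(mesh, P())

# inputs, bboxes and labels are split across devices along the batch (leading) axis.
data_sharding = NamedSharding(mesh, P('devices'))
bboxes_sharding = NamedSharding(mesh, P('devices'))
labels_sharding = NamedSharding(mesh, P('devices'))

# Arrays can be placed with the matching sharding before calling train_step, e.g.:
# inputs = jax.device_put(inputs, data_sharding)
```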
-
Hi guys, I attempted to train a model using data parallelism with JAX, but the speedup was a meager 1.14×. Below is a portion of my code. Could you help me identify where I went wrong?
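For comparison, a minimal self-contained data-parallel `pmap` training step (with a placeholder model and loss, not the original code) looks roughly like the sketch below; the key requirements are that the global batch is reshaped to a leading axis of size `jax.local_device_count()` and that gradients are averaged with `pmean`:

```python
from functools import partial
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

def loss_fn(params, batch):
    # Placeholder quadratic loss standing in for the real model forward pass.
    return jnp.mean((batch - params) ** 2)

@partial(jax.pmap, axis_name='devices')
def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Average gradients across devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name='devices')
    return params - 0.1 * grads, loss

# The global batch must be reshaped to [n_devices, per_device_batch, ...].
global_batch = jnp.ones((n_dev * 8, 128))
sharded_batch = global_batch.reshape(n_dev, -1, 128)

# Replicated params: one identical copy per device along the leading axis.
params = jnp.zeros((n_dev, 128))

params, loss = train_step(params, sharded_batch)
```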