[Feature] Add gradient accumulation #292
@XinDongol do you mean microbatching or pipeline parallel?

@awgu - is there a context manager or similar option in FSDP2 that would support gradient accumulation and thus enable this in titan? I know we talked about this for HSDP but not sure about generic FSDP2.
I am guessing this is asking for normal microbatching. There are similar APIs for FSDP2 that can control communication during gradient accumulation. We migrated the …
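For reference, a minimal sketch of the FSDP2 communication controls being referred to. `fully_shard` and `FSDPModule.set_requires_gradient_sync` are real FSDP2 APIs (under `torch.distributed._composable.fsdp` at the time of this thread); the toy model and process-group setup below are illustrative, not torchtitan code:

```python
# Sketch of the FSDP2 per-module communication toggles; run under torchrun.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.Linear(16, 16, device="cuda")
fully_shard(model)  # model now also exposes the FSDPModule methods

# Skip gradient reduce-scatter (accumulate unsharded grads locally) ...
model.set_requires_gradient_sync(False)
# ... and re-enable communication before the final microbatch's backward.
model.set_requires_gradient_sync(True)
# For HSDP, there is an analogous set_requires_all_reduce(...) toggle.
```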
Thanks for updating.

Lines 291 to 294 in 58b1169:

```python
with loss_parallel_ctx():
    pred = model(input_ids)
    loss = loss_fn(pred, labels)
    loss.backward()
```

@awgu is it sufficient to change this to the following? Thanks

```python
for microbatch_idx in range(microbatch):
    batch = next(data_iterator)
    input_ids, labels = batch
    model.set_requires_gradient_sync(microbatch_idx == (microbatch - 1))
    with loss_parallel_ctx():
        pred = model(input_ids)
        loss = loss_fn(pred, labels) / microbatch
        loss.backward()
```
@XinDongol I think that is sufficient. If you want to avoid reduce-scatter in backward, then what you have is right. Note, however, that this means gradients are left unsharded through backward, which may use too much memory depending on the workload. If you still want to reduce-scatter in backward, you can simply remove that `set_requires_gradient_sync` call.
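Concretely, the lower-memory variant described above would drop the toggle so FSDP2 reduce-scatters on every microbatch. A sketch reusing the hypothetical names from the snippet above (`microbatch`, `data_iterator`, `loss_parallel_ctx` come from the proposal, not a fixed torchtitan API):

```python
# Sketch: keep reduce-scatter in every microbatch's backward.
# Gradients stay sharded between microbatches (less memory, more communication);
# FSDP2 accumulates into the sharded .grad across backwards by default.
for microbatch_idx in range(microbatch):
    input_ids, labels = next(data_iterator)
    # No set_requires_gradient_sync(False) here, so each backward reduce-scatters.
    with loss_parallel_ctx():
        pred = model(input_ids)
        loss = loss_fn(pred, labels) / microbatch
        loss.backward()
```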
Gradient accumulation (micro-steps) could be very useful when we want a large batch size but have a limited number of GPUs.
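For background, a minimal single-process sketch of the technique with illustrative names (not torchtitan code): the optimizer steps once per `accum_steps` microbatches, so the effective batch size is the microbatch size times `accum_steps`, while peak activation memory stays at one microbatch.

```python
import torch
import torch.nn as nn

model = nn.Linear(32, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
accum_steps = 8  # effective batch = microbatch size * accum_steps

optimizer.zero_grad()
for i in range(64):  # stand-in for iterating a dataloader
    inputs = torch.randn(4, 32)                # one microbatch
    targets = torch.randint(0, 4, (4,))
    loss = loss_fn(model(inputs), targets) / accum_steps  # scale so grads average
    loss.backward()                            # autograd sums into .grad
    if (i + 1) % accum_steps == 0:
        optimizer.step()                       # one step per accum_steps microbatches
        optimizer.zero_grad()
```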