
[Feature] Add gradient accumulation #292

Open
XinDongol opened this issue May 1, 2024 · 5 comments
Labels: enhancement (New feature or request)

Comments

XinDongol commented May 1, 2024

Gradient accumulation (micro steps) could be very useful when we want a large effective batch size but have a limited number of GPUs.
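
(For reference, a minimal sketch of plain gradient accumulation with no FSDP-specific handling; num_microbatches and the surrounding model/optimizer/data_iterator setup are illustrative, not torchtitan's actual training loop.)

num_microbatches = 4  # effective batch size = per-step batch size * num_microbatches
optimizer.zero_grad()
for _ in range(num_microbatches):
    input_ids, labels = next(data_iterator)
    pred = model(input_ids)
    # scale the loss so the accumulated gradients equal the full-batch average
    loss = loss_fn(pred, labels) / num_microbatches
    loss.backward()  # gradients accumulate into param.grad across micro steps
optimizer.step()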

wanchaol (Contributor) commented May 1, 2024

@XinDongol do you mean microbatching or pipeline parallel?

lessw2020 (Contributor) commented

@awgu - is there a context manager or similar option in FSDP2 that would support gradient accumulation and thus enable this in torchtitan? I know we talked about this for HSDP, but I'm not sure about generic FSDP2.

awgu (Contributor) commented May 1, 2024

I am guessing this is asking for normal microbatching. There are similar APIs for FSDP2 that can control communication during gradient accumulation.

We migrated the no_sync() context manager to the method module.set_requires_gradient_sync(bool), so it can simply be placed at the top of the training loop as module.set_requires_gradient_sync(is_last_microbatch). Note, however, that for memory-constrained cases we typically prefer to just proceed as normal and reduce-scatter every microbatch.
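
(A rough sketch of the migration, assuming fsdp_model is an FSDP1 FullyShardedDataParallel module in the first loop and an FSDP2 fully_shard-ed module in the second; num_microbatches and data_iterator are illustrative.)

import contextlib

# FSDP1 style: suppress gradient communication with the no_sync() context manager,
# entering it for every microbatch except the last
for i in range(num_microbatches):
    input_ids, labels = next(data_iterator)
    ctx = fsdp_model.no_sync() if i < num_microbatches - 1 else contextlib.nullcontext()
    with ctx:
        (loss_fn(fsdp_model(input_ids), labels) / num_microbatches).backward()

# FSDP2 style: set the flag once at the top of each micro step
for i in range(num_microbatches):
    input_ids, labels = next(data_iterator)
    fsdp_model.set_requires_gradient_sync(i == num_microbatches - 1)
    (loss_fn(fsdp_model(input_ids), labels) / num_microbatches).backward()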

XinDongol (Author) commented May 1, 2024

Thanks for the update.
@wanchaol Yes, I am talking about microbatching.

torchtitan/train.py, lines 291 to 294 (at commit 58b1169):

with loss_parallel_ctx():
    pred = model(input_ids)
    loss = loss_fn(pred, labels)
    loss.backward()

@awgu is it sufficient to make the following change? Thanks.
from (current):

with loss_parallel_ctx():
    pred = model(input_ids)
    loss = loss_fn(pred, labels)
    loss.backward()

to:

for microbatch_idx in range(microbatch):
    batch = next(data_iterator)
    input_ids, labels = batch
    model.set_requires_gradient_sync(microbatch_idx == (microbatch - 1))
    with loss_parallel_ctx():
        pred = model(input_ids)
        loss = loss_fn(pred, labels) / microbatch
        loss.backward()

awgu (Contributor) commented May 1, 2024

@XinDongol I think that is sufficient.

If you want to avoid the reduce-scatter in backward, then what you have is right. Note, however, that this means gradients are left unsharded through backward, which may use too much memory depending on the workload.

If you still want to reduce-scatter in backward, you can simply remove that model.set_requires_gradient_sync line (effectively leaving it at the default of True).
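
(I.e., a sketch of the default-communication variant: drop the set_requires_gradient_sync call, keep the loss scaling, and let reduce-scatter run in every backward so gradients stay sharded.)

for microbatch_idx in range(microbatch):
    input_ids, labels = next(data_iterator)
    with loss_parallel_ctx():
        pred = model(input_ids)
        loss = loss_fn(pred, labels) / microbatch  # still scale by the number of micro steps
        loss.backward()  # reduce-scatter runs in each backward (default requires_gradient_sync=True)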

tianyu-l added the enhancement (New feature or request) label on May 3, 2024