
[RFC][FSDP2] Added register_fsdp_forward_method for user fwd methods #125394

Closed

Conversation

@awgu (Contributor) commented May 2, 2024

Stack from ghstack (oldest at bottom):

FSDP only runs its pre/post-forward hooks on `nn.Module.forward`. This means that if the user runs a custom method meant as a forward pass, then FSDP will not all-gather the parameters. Examples include HuggingFace models' `generate()` (#123962, #100069) and others (#109385).

This PR adds a monkey-patching API `register_fsdp_forward_method(module: nn.Module, method_name: str)` to allow FSDP pre/post-forward hooks to run on the given method. The function is a no-op if the passed-in module is not an FSDP module, so the register function can be called even if the FSDP wrapping changes.
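For illustration, a minimal usage sketch under the assumptions of this PR (the `Encoder`/`forward_features` names are made up for the example, the import path may differ by release, and process-group initialization is omitted):

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard, register_fsdp_forward_method

class Encoder(nn.Module):  # hypothetical module for illustration
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)

    def forward_features(self, x):  # custom forward-like method
        return self.proj(x)

model = Encoder()
fully_shard(model)  # FSDP2 wrapping (requires an initialized process group)
# Without the call below, model.forward_features(...) would run on sharded
# parameters; with it, FSDP's pre/post-forward hooks run around the method.
register_fsdp_forward_method(model, "forward_features")
```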

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k

@pytorch-bot commented May 2, 2024

🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/125394.

✅ As of commit 4428b8c with merge base b03fb49, you can merge normally (3 unrelated failures, flagged as FLAKY: likely due to flakiness present on trunk).

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@awgu (Contributor, Author) commented May 2, 2024

From mosaicml/composer:

```python
# Note: We need to use the FSDP.summon_full_params context manager here because the generate function
# does not seem to gather the weights for the LM head. This solution works because the tied weights of the LM head
# are in the root FSDP module, and are summoned by the below context manager. See https://github.com/pytorch/pytorch/issues/100069
# for more info.
# Note: We use recurse=False here so that we only summon full params for the LM head, not the entire model.
with FSDP.summon_full_params(self.model, writeback=False, recurse=False):
    return self.model.generate(input_ids=input_ids, pad_token_id=pad_token_id, **kwargs)
```

We should be able to replace this with:

```python
register_fsdp_forward_method(self.model, "generate")  # call once at init time
...
return self.model.generate(input_ids=input_ids, pad_token_id=pad_token_id, **kwargs)
```
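With `generate` registered, FSDP's pre-forward hook all-gathers the root module's parameters (including the tied LM head weights, which per the comment above live on the root FSDP module) for the duration of the call, so the `summon_full_params` workaround should no longer be needed.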

```diff
@@ -314,3 +315,35 @@ def wait(self):
         self._fsdp_param_group.wait_for_unshard()
         # Avoid keeping a reference
         self._fsdp_param_group = None
+
+
+def register_fsdp_forward_method(module: nn.Module, method_name: str) -> None:
```
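The function body is collapsed in this diff view. As a rough conceptual sketch only (not the PR's actual code; `_get_module_fsdp_state`, `_pre_forward`, and `_post_forward` are stand-in names), the monkey patch wraps the named method so FSDP's pre/post-forward logic runs around it:

```python
import functools
import types

import torch.nn as nn

def _get_module_fsdp_state(module: nn.Module):
    # Stand-in: return this module's FSDP state object if it is FSDP-managed,
    # else None. The real accessor is internal to PyTorch.
    return getattr(module, "_fsdp_state", None)  # hypothetical attribute

def register_fsdp_forward_method(module: nn.Module, method_name: str) -> None:
    state = _get_module_fsdp_state(module)
    if state is None:
        return  # no-op on non-FSDP modules, per the PR description
    orig_method = getattr(module, method_name)

    @functools.wraps(orig_method)
    def wrapped_method(self, *args, **kwargs):
        # Stand-ins for FSDP's hooks: all-gather before, reshard after
        args, kwargs = state._pre_forward(self, args, kwargs)
        out = orig_method(*args, **kwargs)
        return state._post_forward(self, args, out)

    # Bind on the instance so only this one module's method is patched
    setattr(module, method_name, types.MethodType(wrapped_method, module))
```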
@awgu (Contributor, Author) commented on the new function:

cc: @Skylion007 if you have any opinions on this

A collaborator replied:

Nice! This is fantastic. Let me ping some Composer folks and see if they have any more detailed feedback on this PR. :)

@weifengpy (Contributor) left a comment

Nice! I did not know it only takes a few lines of code to support a user-defined forward.

Is this making the assumption that the user-defined forward (e.g. `forward_features`) won't call hooks from `nn.Module`? Otherwise we would get `fsdp_hook(forward_features(fsdp_hook))`.

@awgu (Contributor, Author) commented May 2, 2024

> Is this making the assumption that the user-defined forward (e.g. forward_features) won't call hooks from nn.Module? Otherwise we would get fsdp_hook(forward_features(fsdp_hook))

Since the user-defined forward method (e.g. `forward_features`) is not `nn.Module.forward`, the forward hooks registered on the module would not run for that method anyway, so I think this is not a concern.

Note that this only adds FSDP hooks to the user-defined method for that one particular module. Any nested submodules will run forward normally, so a nested FSDP submodule will just work as normal.
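A toy demonstration of this point (plain `nn.Module`, no FSDP; names are made up for the example) showing that module hooks are dispatched in `__call__` and therefore do not fire for a directly called custom method:

```python
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(x)

    def forward_features(self, x):  # custom forward-like method
        return self.lin(x)

m = M()
m.register_forward_hook(lambda mod, args, out: print("forward hook ran"))
x = torch.randn(2, 4)
m(x)                   # prints "forward hook ran" (hooks run via __call__)
m.forward_features(x)  # prints nothing: hooks are not dispatched here
```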

@weifengpy (Contributor) replied:

> Since the user-defined forward method (e.g. forward_features) is not nn.Module.forward, the forward hooks registered on the module would not run for that method anyway [...]

`forward_features` is under the user's control? I guess we are ignoring the chance that the user calls the module's forward hooks explicitly inside `forward_features`?

@awgu (Contributor, Author) commented May 2, 2024

> forward_features is under the user's control? I guess we are ignoring the chance that the user calls the module's forward hooks explicitly inside forward_features?

That is a good point. We are assuming that the user is not calling the hooks themselves in `forward_features`.

This is not much of a concern to me because (1) users calling the hooks themselves seems unlikely, and (2) FSDP wants to prepend its pre-forward hook anyway. (The post-forward hook being prepended might be an issue though.)

At least for the use cases we have seen in practice, I think this is okay, but your point is definitely valid.

@weifengpy (Contributor) replied:

> ...user-defined method for that one particular module. Any nested submodules will run forward normally

Is the particular module usually the root module, as in `model.generate()`?

@awgu (Contributor, Author) commented May 2, 2024

> Is the particular module usually the root module, as in model.generate()?

Yes, I have mainly seen it for the root module's `.generate()`. The vision transformer example was not the root though 🤔.

@wanchaol (Contributor) left a comment

This SGTM!

@awgu (Contributor, Author) commented May 3, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator) replied:

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).


@gaotianyu1350 commented:

Hi @awgu, thanks for the patch! I wonder how this can be used together with `torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel`? It seems that this is only compatible with the `torch.distributed._composable` stuff, which I don't quite understand... Thanks!

@awgu (Contributor, Author) commented May 13, 2024

@gaotianyu1350 Sorry, this does not apply to `torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel`. The current workaround for that is to use `summon_full_params(recurse=False)`.
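For reference, a hedged sketch of that FSDP1 workaround, mirroring the Composer snippet quoted earlier (`model`, `input_ids`, and the tied-weights layout are assumptions of the example):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes `model` is an FSDP1-wrapped module whose tied LM head weights live
# on the root FSDP module; recurse=False summons only the root's parameters.
with FSDP.summon_full_params(model, writeback=False, recurse=False):
    output_ids = model.generate(input_ids=input_ids)
```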
