[RFC][FSDP2] Added `register_fsdp_forward_method` for user fwd methods #125394
[ghstack-poisoned]
ghstack-source-id: 01bc2fab00edeb16ce0d2d06d0a784fe3619911f Pull Request resolved: #125394
… fwd methods"

FSDP only runs its pre/post-forward hooks on `nn.Module.forward`. This means that if the user runs a custom method meant as a forward pass, then FSDP will not all-gather the parameters. Examples include HuggingFace models' `generate()` (#123962, #100069) or others (#109385).

This PR adds a monkey patching API to allow FSDP pre/post-forward hooks to run on the method.

cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
ghstack-source-id: ba8e1d1f417cfc622b7098808a508f99968f61be Pull Request resolved: #125394
From
We should be able to replace this with:
@@ -314,3 +315,35 @@ def wait(self):
        self._fsdp_param_group.wait_for_unshard()
        # Avoid keeping a reference
        self._fsdp_param_group = None


def register_fsdp_forward_method(module: nn.Module, method_name: str) -> None:
cc: @Skylion007 if you have any opinions on this
Nice! This is fantastic. Let me ping some Composer folks and see if they have any more detailed feedback on this PR. :)
Nice. I did not know it only takes a few lines of code to support a user-defined fwd.
Is this assuming that a user-defined fwd, e.g. `forward_features`, won't call hooks from `nn.Module`? Otherwise we will have `fsdp_hook(forward_features(fsdp_hook))`.
Since the user-defined forward method (e.g. Note that this is only adding FSDP hooks to the user-defined method for that one particular module. Any nested submodules will run
That is a good point. We are assuming that the user is not calling the hooks themselves in This is not so much of a concern to me because (1) users calling the hooks themselves seems unlikely to me and (2) FSDP wants to prepend the pre-forward hook anyway. (The post-forward hook being prepended might be an issue, though.) At least in the use cases we have seen in practice, I think this is okay, but your point is definitely valid.
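To make the nesting concern above concrete, here is a minimal sketch that does not use PyTorch. `with_hooks` and `TracedModule` are hypothetical stand-ins (not FSDP code) whose "hooks" only append to a log. It illustrates the `fsdp_hook(forward_features(fsdp_hook))` ordering that arises when a hooked custom method calls another hooked method on the same module.

```python
import functools


def with_hooks(method):
    """Run logging-only pre/post hooks around `method` (hypothetical helper)."""

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        self.trace.append(f"pre:{method.__name__}")
        out = method(self, *args, **kwargs)
        self.trace.append(f"post:{method.__name__}")
        return out

    return wrapper


class TracedModule:
    """Hypothetical module; `trace` records the order in which hooks fire."""

    def __init__(self):
        self.trace = []

    @with_hooks
    def forward(self, x):
        return x * 2

    @with_hooks
    def forward_features(self, x):
        # Calling the hooked forward from inside a hooked method nests the hooks.
        return self.forward(x) + 1


m = TracedModule()
y = m.forward_features(3)
for entry in m.trace:
    print(entry)
```

The log shows the inner pre/post pair fully enclosed by the outer pair. The sketch only demonstrates the call order under discussion; whether such nesting is harmful in FSDP depends on how its hooks handle repeated calls, which this example does not model.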
Is the particular module mostly the root module? like
Yes. I have mainly seen it for the root module's
This sgtm!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Hi @awgu, thanks for the patch! I wonder how this can be used together with
@gaotianyu1350 Sorry, this does not apply to
Stack from ghstack (oldest at bottom):
`register_fsdp_forward_method` for user fwd methods #125394

FSDP only runs its pre/post-forward hooks on `nn.Module.forward`. This means that if the user runs a custom method meant as a forward pass, then FSDP will not all-gather the parameters. Examples include HuggingFace models' `generate()` (#123962, #100069) or others (#109385).

This PR adds a monkey patching API `register_fsdp_forward_method(module: nn.Module, method_name: str)` to allow FSDP pre/post-forward hooks to run on the method. The function is a no-op if the passed-in `module` is not an FSDP module so that the register function can be called even if the FSDP wrapping changes.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k
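The monkey-patching idea in the description can be illustrated without PyTorch. The following is a minimal sketch under stated assumptions, not the code from this PR: `FakeFSDPModule`, `PlainModule`, `_pre_forward`, `_post_forward`, and `register_forward_method` are hypothetical stand-ins, and the hooks only log where a real implementation would all-gather and free parameters.

```python
import functools
import types


class FakeFSDPModule:
    """Hypothetical stand-in for an FSDP-managed module; it only logs hook calls."""

    def __init__(self):
        self.events = []

    def _pre_forward(self, args, kwargs):
        # A real implementation would all-gather (unshard) parameters here.
        self.events.append("pre_forward")
        return args, kwargs

    def _post_forward(self, output):
        # A real implementation would free (reshard) parameters here.
        self.events.append("post_forward")
        return output

    def generate(self, prompt):
        # Custom forward-like method that would need the full parameters.
        self.events.append("generate")
        return prompt + "!"


class PlainModule:
    """Hypothetical module that is not managed by FSDP."""

    def generate(self, prompt):
        return prompt + "!"


def register_forward_method(module, method_name):
    """Wrap `module.<method_name>` so the pre/post hooks run around it."""
    if not isinstance(module, FakeFSDPModule):
        return  # no-op, so callers need not know whether the module is wrapped
    if not hasattr(module, method_name):
        raise ValueError(f"{type(module).__name__} has no method {method_name!r}")
    orig_method = getattr(module, method_name)  # already bound to `module`

    @functools.wraps(orig_method)
    def wrapped(self, *args, **kwargs):
        args, kwargs = self._pre_forward(args, kwargs)
        out = orig_method(*args, **kwargs)
        return self._post_forward(out)

    # Patch this instance only; the class and other instances are untouched.
    setattr(module, method_name, types.MethodType(wrapped, module))


sharded = FakeFSDPModule()
register_forward_method(sharded, "generate")
result = sharded.generate("hi")

plain = PlainModule()
register_forward_method(plain, "generate")  # does nothing for a non-FSDP module
```

Patching the instance rather than the class keeps the change local to the one module that was registered, which matches the description's point that the call is safe to leave in place when the wrapping changes.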