[RFC][FSDP2] Added register_fsdp_forward_method for user fwd methods #125394

Closed

53 changes: 53 additions & 0 deletions test/distributed/_composable/fsdp/test_fully_shard_training.py
@@ -16,6 +16,7 @@
    FSDPModule,
    fully_shard,
    OffloadPolicy,
    register_fsdp_forward_method,
)
from torch.distributed._tensor import DTensor, init_device_mesh
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -1139,5 +1140,57 @@ def _test_train_parity_hsdp(
        check_sharded_parity(self, ref_model, model)


class TestFullyShardCustomForwardMethod(FSDPTestMultiThread):
    @property
    def world_size(self) -> int:
        return 2

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_register_fsdp_forward_method(self):
        """Based on https://github.com/pytorch/pytorch/issues/109385"""

        class VisionTransformer(nn.Module):
            def __init__(self):
                super().__init__()
                self.patch_proj = nn.Conv2d(3, 1024, kernel_size=14, stride=14)

            def forward_features(self, imgs: torch.Tensor) -> torch.Tensor:
                return self.patch_proj(imgs).flatten(2).transpose(1, 2)

            def forward(self, imgs: torch.Tensor) -> torch.Tensor:
                return self.forward_features(imgs).sum(dim=1)

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.vit, self.projector = VisionTransformer(), nn.Linear(1024, 256)

            def forward(self, imgs: torch.Tensor) -> torch.Tensor:
                # Run `vit.forward_features`, which is not `forward`!
                patch_embeddings = self.vit.forward_features(imgs)
                return self.projector(patch_embeddings)

        torch.manual_seed(42)
        model = Model()
        for param in model.parameters():
            dist.broadcast(param.detach(), src=0)
        ref_model = copy.deepcopy(model).cuda()
        fully_shard(model.vit)
        fully_shard(model.projector)
        fully_shard(model)
        register_fsdp_forward_method(model.vit, "forward_features")

        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn(4, 3, 224, 224, device="cuda")
        ref_loss = ref_model(inp).sum()
        loss = model(inp).sum()
        self.assertEqual(ref_loss, loss)
        ref_loss.backward()
        loss.backward()
        for param in ref_model.parameters():
            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
        check_sharded_parity(self, ref_model, model)


if __name__ == "__main__":
run_tests()
2 changes: 1 addition & 1 deletion torch/distributed/_composable/fsdp/__init__.py
@@ -1,2 +1,2 @@
from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
from .fully_shard import FSDPModule, fully_shard
from .fully_shard import FSDPModule, fully_shard, register_fsdp_forward_method
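
With this re-export, the new helper is importable from the package root alongside the existing public names. A minimal import sketch, assuming the private torch.distributed._composable.fsdp namespace used throughout this PR:

    from torch.distributed._composable.fsdp import fully_shard, register_fsdp_forward_method
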
33 changes: 33 additions & 0 deletions torch/distributed/_composable/fsdp/fully_shard.py
@@ -1,3 +1,4 @@
import functools
from typing import Any, cast, Optional, Union

import typing_extensions
@@ -314,3 +315,35 @@ def wait(self):
        self._fsdp_param_group.wait_for_unshard()
        # Avoid keeping a reference
        self._fsdp_param_group = None


def register_fsdp_forward_method(module: nn.Module, method_name: str) -> None:

Review comment (Contributor Author): cc: @Skylion007 if you have any opinions on this

Review comment (Collaborator): Nice! This is fantastic, let me ping some Composer folks and see if they have any more detailed feedback on this PR. :)

"""
Registers a method on ``module`` to be a forward method for FSDP.

FSDP only knows to run its pre-forward and post-forward hooks on the
default :meth:`nn.Module.forward` method. This function patches a user
specified method to run the pre/post-forward hooks before/after the method,
respectively. If ``module`` is not an :class:`FSDPModule`, then this is a
no-op.

Args:
module (nn.Module): Module to register the forward method on.
method_name (str): Name of the forward method.
"""
if not isinstance(module, FSDPModule):
# Make no-op to allow including both when using/not using FSDP
return
if not hasattr(module, method_name):
raise ValueError(f"{type(module)} does not have a method {method_name}")
orig_method = getattr(module, method_name)

@functools.wraps(orig_method)
def wrapped_method(self, *args, **kwargs):
fsdp_state = self._get_fsdp_state()
args, kwargs = fsdp_state._pre_forward(self, args, kwargs)
out = orig_method(*args, **kwargs)
return fsdp_state._post_forward(self, args, out)

# Use `__get__` to make `wrapped_method` an instance method
setattr(module, method_name, wrapped_method.__get__(module, type(module)))
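
For reference, a condensed usage sketch distilled from the test added above. It assumes a multi-GPU launch via torchrun with NCCL available; the VisionTransformer and Model classes mirror the ones in the test, and only register_fsdp_forward_method is specific to this PR:

    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn

    from torch.distributed._composable.fsdp import fully_shard, register_fsdp_forward_method

    class VisionTransformer(nn.Module):
        def __init__(self):
            super().__init__()
            self.patch_proj = nn.Conv2d(3, 1024, kernel_size=14, stride=14)

        def forward_features(self, imgs: torch.Tensor) -> torch.Tensor:
            # Non-`forward` entry point that callers invoke directly
            return self.patch_proj(imgs).flatten(2).transpose(1, 2)

        def forward(self, imgs: torch.Tensor) -> torch.Tensor:
            return self.forward_features(imgs).sum(dim=1)

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.vit = VisionTransformer()
            self.projector = nn.Linear(1024, 256)

        def forward(self, imgs: torch.Tensor) -> torch.Tensor:
            # Calls `vit.forward_features`, not `vit.forward`, so FSDP's default
            # hook on `forward` alone would never unshard `vit`'s parameters here
            return self.projector(self.vit.forward_features(imgs))

    # Assumes launch via `torchrun --nproc_per_node=<N> this_script.py`
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    dist.init_process_group(backend="nccl")

    model = Model()
    fully_shard(model.vit)
    fully_shard(model.projector)
    fully_shard(model)
    # Run FSDP's pre/post-forward hooks (unshard/reshard) around `forward_features`
    register_fsdp_forward_method(model.vit, "forward_features")

    imgs = torch.randn(4, 3, 224, 224, device="cuda")
    loss = model(imgs).sum()
    loss.backward()

Without the registration call, FSDP would only hook nn.Module.forward, so the direct forward_features call inside Model.forward would run against vit's still-sharded parameters.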