[RFC][FSDP2] Added register_fsdp_forward_method for user fwd methods
ghstack-source-id: ba8e1d1f417cfc622b7098808a508f99968f61be
Pull Request resolved: #125394
awgu committed May 2, 2024
1 parent b03fb49 commit 1f3372a
Showing 3 changed files with 87 additions and 1 deletion.
53 changes: 53 additions & 0 deletions test/distributed/_composable/fsdp/test_fully_shard_training.py
@@ -16,6 +16,7 @@
    FSDPModule,
    fully_shard,
    OffloadPolicy,
    register_fsdp_forward_method,
)
from torch.distributed._tensor import DTensor, init_device_mesh
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -1139,5 +1140,57 @@ def _test_train_parity_hsdp(
            check_sharded_parity(self, ref_model, model)


class TestFullyShardCustomForwardMethod(FSDPTestMultiThread):
    @property
    def world_size(self) -> int:
        return 2

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_register_fsdp_forward_method(self):
        """Based on https://github.com/pytorch/pytorch/issues/109385"""

        class VisionTransformer(nn.Module):
            def __init__(self):
                super().__init__()
                self.patch_proj = nn.Conv2d(3, 1024, kernel_size=14, stride=14)

            def forward_features(self, imgs: torch.Tensor) -> torch.Tensor:
                return self.patch_proj(imgs).flatten(2).transpose(1, 2)

            def forward(self, imgs: torch.Tensor) -> torch.Tensor:
                return self.forward_features(imgs).sum(dim=1)

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.vit, self.projector = VisionTransformer(), nn.Linear(1024, 256)

            def forward(self, imgs: torch.Tensor) -> torch.Tensor:
                # Run `vit.forward_features`, which is not `forward`!
                patch_embeddings = self.vit.forward_features(imgs)
                return self.projector(patch_embeddings)

        torch.manual_seed(42)
        model = Model()
        for param in model.parameters():
            dist.broadcast(param.detach(), src=0)
        ref_model = copy.deepcopy(model).cuda()
        fully_shard(model.vit)
        fully_shard(model.projector)
        fully_shard(model)
        register_fsdp_forward_method(model.vit, "forward_features")

        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn(4, 3, 224, 224, device="cuda")
        ref_loss = ref_model(inp).sum()
        loss = model(inp).sum()
        self.assertEqual(ref_loss, loss)
        ref_loss.backward()
        loss.backward()
        for param in ref_model.parameters():
            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
        check_sharded_parity(self, ref_model, model)


if __name__ == "__main__":
    run_tests()
2 changes: 1 addition & 1 deletion torch/distributed/_composable/fsdp/__init__.py
@@ -1,2 +1,2 @@
from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
from .fully_shard import FSDPModule, fully_shard
from .fully_shard import FSDPModule, fully_shard, register_fsdp_forward_method
33 changes: 33 additions & 0 deletions torch/distributed/_composable/fsdp/fully_shard.py
@@ -1,3 +1,4 @@
import functools
from typing import Any, cast, Optional, Union

import typing_extensions
@@ -314,3 +315,35 @@ def wait(self):
            self._fsdp_param_group.wait_for_unshard()
            # Avoid keeping a reference
            self._fsdp_param_group = None


def register_fsdp_forward_method(module: nn.Module, method_name: str) -> None:
    """
    Registers a method on ``module`` to be a forward method for FSDP.

    FSDP only knows to run its pre-forward and post-forward hooks on the
    default :meth:`nn.Module.forward` method. This function patches a
    user-specified method to run the pre/post-forward hooks before/after the
    method, respectively. If ``module`` is not an :class:`FSDPModule`, then
    this is a no-op.

    Args:
        module (nn.Module): Module to register the forward method on.
        method_name (str): Name of the forward method.
    """
    if not isinstance(module, FSDPModule):
        # Make this a no-op so the same calling code works whether or not FSDP is applied
        return
    if not hasattr(module, method_name):
        raise ValueError(f"{type(module)} does not have a method {method_name}")
    orig_method = getattr(module, method_name)

    @functools.wraps(orig_method)
    def wrapped_method(self, *args, **kwargs):
        fsdp_state = self._get_fsdp_state()
        args, kwargs = fsdp_state._pre_forward(self, args, kwargs)
        out = orig_method(*args, **kwargs)
        return fsdp_state._post_forward(self, args, out)

    # Use `__get__` to make `wrapped_method` an instance method
    setattr(module, method_name, wrapped_method.__get__(module, type(module)))
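
For reference, a minimal usage sketch of the new API, mirroring the test above. This is not part of the commit: the Encoder module, its encode method, and the tensor shapes are illustrative, and it assumes a process group has already been initialized and a GPU is available.

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard, register_fsdp_forward_method

# Hypothetical module whose entry point is `encode`, not `forward`.
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 32)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

# Assumes torch.distributed.init_process_group(...) was already called.
model = Encoder().cuda()
fully_shard(model)
# Without this registration, FSDP's pre/post-forward hooks (unshard/reshard)
# only wrap `forward`, so calling `encode` directly would hit sharded parameters.
register_fsdp_forward_method(model, "encode")
out = model.encode(torch.randn(8, 16, device="cuda"))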
