
[DeviceMesh] Add support for group: Tuple[ProcessGroup, ...] in from_group() #125358

Open

awgu opened this issue May 1, 2024 · 2 comments

Labels: module: DeviceMesh, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments


awgu commented May 1, 2024

We recently added a DeviceMesh.from_group() API to support constructing a DeviceMesh from an existing ProcessGroup to help interoperate with training code that uses ProcessGroup for some parallelisms and DeviceMesh for others.

```python
@staticmethod
def from_group(group: ProcessGroup, device_type: str) -> "DeviceMesh":
    ...
```
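For reference, a minimal usage sketch of the current API. It assumes the process group is already initialized via `init_process_group`; `dp_pg` is a hypothetical name for the caller's existing group:

```python
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh

# Assumes dist.init_process_group(...) has already been called.
dp_pg = dist.new_group()  # hypothetical existing ProcessGroup (e.g., for data parallelism)
mesh = DeviceMesh.from_group(dp_pg, device_type="cuda")
```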

We want to expand the API to support HSDP.

  • Generalize the group argument to support Union[ProcessGroup, Tuple[ProcessGroup, ...]] so that the user can pass in a tuple of the inter-node and intra-node PGs
  • [Optional] Add a mesh_dim_names: Optional[Tuple[str, ...]] = None kwarg so that the user can still give named mesh dims (see the signature sketch below)
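Putting the two bullets together, the expanded signature might look like the following sketch (illustrative only, not a finalized API):

```python
from typing import Optional, Tuple, Union

from torch.distributed import ProcessGroup

# Inside DeviceMesh; a sketch of the proposed signature, not an implementation.
@staticmethod
def from_group(
    group: Union[ProcessGroup, Tuple[ProcessGroup, ...]],
    device_type: str,
    mesh_dim_names: Optional[Tuple[str, ...]] = None,
) -> "DeviceMesh":
    ...
```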
@awgu awgu added triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) and module: DeviceMesh labels May 1, 2024
@awgu awgu assigned awgu and unassigned awgu May 1, 2024

awgu commented May 2, 2024

It is not clear to me how to recover the mesh tensor from the process groups in the general case. Each rank can get the ranks of the process groups passed to it, but recovering the mesh from those rank lists requires some math.

For example, if we have mesh = torch.arange(32).view(4, 8), then rank 0 sees an inter-node PG with ranks (0, 8, 16, 24) and an intra-node PG with ranks (0, 1, 2, 3, 4, 5, 6, 7). We can see that the intra-node ranks (0, 1, 2, 3, 4, 5, 6, 7) increment by 1 each time and use that stride to fill out the ranks along each element of (0, 8, 16, 24) to get back mesh.

Now, if we instead have mesh = torch.arange(128).view(4, 4, 8), where the rightmost dim is excluded (e.g., 8-way TP with (4, 4)-way HSDP), then rank 0 sees "intra-node" ranks (0, 8, 16, 24) and "inter-node" ranks (0, 32, 64, 96). We can similarly see that (0, 8, 16, 24) increments by 8 each time and use that to fill out the ranks along each element of (0, 32, 64, 96).
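To make the 2-D math above concrete, here is a sketch (reconstruct_2d_mesh is a hypothetical helper, not part of any API) that rebuilds the mesh by treating each PG's ranks as additive offsets from the current rank. It only works because these meshes are regularly strided (arange + view), which is exactly what makes the general case unclear:

```python
import torch

def reconstruct_2d_mesh(outer_ranks, inner_ranks, my_rank):
    # Offsets of each PG's ranks relative to the current rank.
    outer_offsets = [r - my_rank for r in outer_ranks]
    inner_offsets = [r - my_rank for r in inner_ranks]
    # Outer sum: mesh[i][j] = my_rank + outer_offsets[i] + inner_offsets[j].
    return torch.tensor(
        [[my_rank + di + dj for dj in inner_offsets] for di in outer_offsets]
    )

# First example: recovers torch.arange(32).view(4, 8) from rank 0's PG ranks.
print(reconstruct_2d_mesh([0, 8, 16, 24], list(range(8)), my_rank=0))
# Second example: recovers the (4, 4) HSDP mesh over the TP-group leaders.
print(reconstruct_2d_mesh([0, 32, 64, 96], [0, 8, 16, 24], my_rank=0))
```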

Now, how does this generalize to the user passing in N process groups?


wanchaol commented May 2, 2024

@awgu I thought a bit about this. The recovery math can be quite complicated in N-D scenarios; even for 2-D/3-D it seems like a non-trivial amount of code. I'm wondering if we should do it the way we do for things like device_type: when the user wants to construct a DeviceMesh from a PG, they also need to tell us a bit more about their sub-PG structure by passing in a mesh tensor in addition to device_type. Then, under the hood, we just do some simple validation to make sure the PG ranks match the mesh tensor's dimension values.

The mesh tensor dim values can be easily derived, similar to what is done here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/device_mesh.py#L289
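A sketch of what that validation could look like (validate_groups_against_mesh is a hypothetical helper; it only checks the mesh slices that contain the current rank):

```python
import torch
import torch.distributed as dist

def validate_groups_against_mesh(groups, mesh, my_rank):
    # This rank's coordinates in the user-provided mesh tensor.
    coord = (mesh == my_rank).nonzero()[0].tolist()
    for dim, group in enumerate(groups):
        # Ranks of the mesh slice along `dim` passing through this rank.
        index = list(coord)
        index[dim] = slice(None)
        expected = mesh[tuple(index)].tolist()
        actual = dist.get_process_group_ranks(group)
        if actual != expected:
            raise ValueError(
                f"Group ranks {actual} do not match mesh dim {dim} ranks {expected}"
            )
```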
