[Feature] Add row-decomposition of adj. matrix to reduce graph partitioning overhead #720
Open
chang-l wants to merge 5 commits into NVIDIA:main from chang-l:dist_mpnn_sharding
Commits (5):
f5132d9 Add matrix partition to replace current graph partition
2ccaab9 Refactor and renaming
a0f287b Add back orig impl and test
fec3dec Minor update to remove notimpl blocks
665acd6 Merge branch 'main' into dist_mpnn_sharding
(all commits by chang-l)
@@ -67,6 +67,8 @@ class GraphPartition:
    partition_size: int
    partition_rank: int
    device: torch.device
    # flag indicating a 1-D row-wise decomposition of the adjacency matrix
    matrix_decomp: bool = False

    # data structures defining the partition;
    # set after initialization or during execution
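To make the new flag concrete: a minimal, runnable sketch of the dataclass fields visible in this hunk and of how downstream code can branch on matrix_decomp. The instantiation values are invented for illustration:

import torch
from dataclasses import dataclass

@dataclass
class GraphPartition:  # abridged to the fields visible in this hunk
    partition_size: int
    partition_rank: int
    device: torch.device
    # flag indicating a 1-D row-wise decomposition of the adjacency matrix
    matrix_decomp: bool = False

part = GraphPartition(
    partition_size=4, partition_rank=0, device=torch.device("cpu"), matrix_decomp=True
)
if part.matrix_decomp:
    # mirrors the guards added later in this diff: with matrix decomposition,
    # only the dst-side feature accessors are meaningful
    print("use dst-node feature accessors")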
@@ -394,12 +396,133 @@ def partition_graph_with_id_mapping(
    return graph_partition

def partition_graph_with_matrix_decomposition(
    global_offsets: torch.Tensor,
    global_indices: torch.Tensor,
    num_nodes: int,
    partition_book: torch.Tensor,
    partition_size: int,
    partition_rank: int,
    device: torch.device,
) -> GraphPartition:
    """
    Utility function which partitions a global graph, given as a CSC structure, based on a
    1-D row-wise decomposition of its adjacency matrix. This approach ensures a uniform 1-D
    distribution of nodes and their associated 1-hop incoming edges. By treating source and
    destination nodes equivalently during partitioning, this approach assumes the graph is
    not bipartite. The decomposition also keeps the graph convolution (spMM) local by
    maintaining a copy of the local incoming edge features and the local node outputs from
    the graph convolution. The memory complexity of this approach is
    O[(N/P + E/P) * hid_dim * L], where N and E are the numbers of nodes and edges, P is
    the number of partitions, hid_dim is the hidden feature dimension, and L is the number
    of layers. The transformation from local node storage to local edge storage is achieved
    using NCCL `alltoall`.

    Key differences from the existing graph partition scheme (partition_graph_with_id_mapping):
    (1) This function partitions the global node ID space uniformly, without distinguishing
    between source and destination nodes (i.e., matrix row ordering or column ordering). Both
    src/dst (row/col) nodes are indexed consistently within the adjacency matrix.
    (2) Each local graph (sub-matrix) can be constructed from just node/edge offsets into
    the global graph.
    (3) The partitioning is performed on a global graph stored in CPU memory, and each
    device (rank) then constructs its local graph independently from the global CSC matrix.

    Parameters
    ----------
    global_offsets : torch.Tensor
        CSC offsets, can live on the CPU
    global_indices : torch.Tensor
        CSC indices, can live on the CPU
    num_nodes : int
        number of nodes in the global graph
    partition_book : torch.Tensor
        the boundaries of the 1-D row-wise decomposition of the adjacency matrix for all ranks
    partition_size : int
        number of process groups across which the graph is partitioned,
        i.e. the number of graph partitions
    partition_rank : int
        rank within the process group managing the distributed graph, i.e.
        the rank determining which partition the corresponding local rank
        will manage
    device : torch.device
        device connected to the passed partition rank, i.e. the device
        on which the local graph and related buffers will live
    """
    # initialize graph partition
    graph_partition = GraphPartition(
        partition_size=partition_size, partition_rank=partition_rank, device=device
    )
    dtype = global_indices.dtype

    # --------------------------------------------------------------
    # first, partition the global row pointers (dst nodes) into local row pointers
    num_edges = global_indices.size(0)
    node_offset = partition_book[partition_rank]
    num_local_nodes = partition_book[partition_rank + 1] - partition_book[partition_rank]
    edge_partition_offset = global_offsets[node_offset]
    assert (
        node_offset + num_local_nodes <= num_nodes
    ), "Invalid node offset and number of local nodes"
    local_offsets = global_offsets[node_offset : node_offset + num_local_nodes + 1].to(
        device=device, non_blocking=True
    )
    graph_partition.local_offsets = local_offsets - edge_partition_offset
    graph_partition.num_local_dst_nodes = num_local_nodes
    # scan through all partitions and compress the source nodes (edges) of each
    # partition to fill the local send/recv buffers for the all-to-all communications
    partition_book = partition_book.to(device=device)
    for to_partition in range(partition_size):
        local_indices = global_indices[
            global_offsets[partition_book[to_partition]] : global_offsets[
                partition_book[to_partition + 1]
            ]
        ].to(device=device, non_blocking=True)
        # compress the columns (src nodes, i.e. local_indices) of each partition
        # and record the mapping (inverse_indices)
        global_src_node_at_partition, inverse_indices = local_indices.unique(
            sorted=True, return_inverse=True
        )
        global_src_node_at_partition_rank = (
            torch.bucketize(global_src_node_at_partition, partition_book, right=True) - 1
        )
        src_node_indices = torch.nonzero(
            global_src_node_at_partition_rank == partition_rank, as_tuple=False
        ).squeeze(1)
        # fill the local send buffer for the all-to-alls
        # (scatter the selected nodes to the to_partition rank)
        graph_partition.scatter_indices[to_partition] = (
            global_src_node_at_partition[src_node_indices] - node_offset
        )
        # fill the numbers of indices (edges), dst nodes, and src nodes for each partition
        graph_partition.num_indices_in_each_partition[to_partition] = local_indices.size(0)
        graph_partition.num_dst_nodes_in_each_partition[to_partition] = (
            partition_book[to_partition + 1] - partition_book[to_partition]
        )
        graph_partition.num_src_nodes_in_each_partition[to_partition] = (
            global_src_node_at_partition.size(0)
        )

        if to_partition == partition_rank:
            graph_partition.local_indices = inverse_indices
            graph_partition.num_local_indices = graph_partition.local_indices.size(0)
            graph_partition.num_local_src_nodes = global_src_node_at_partition.size(0)
            # map from local (compressed) column indices [0, ..., num_local_src_nodes]
            # to their global node IDs
            graph_partition.map_partitioned_src_ids_to_global = global_src_node_at_partition

        for from_partition in range(partition_size):
            # fill all recv buffer sizes for the all-to-alls
            graph_partition.sizes[from_partition][to_partition] = torch.count_nonzero(
                global_src_node_at_partition_rank == from_partition
            )
    # trivial mappings due to the 1-D row-wise decomposition
    graph_partition.map_partitioned_dst_ids_to_global = torch.arange(
        node_offset, node_offset + num_local_nodes, dtype=dtype, device=device
    )
    graph_partition.map_partitioned_edge_ids_to_global = torch.arange(
        edge_partition_offset,
        edge_partition_offset + graph_partition.num_local_indices,
        dtype=dtype,
        device=device,
    )
    # trivial mappings due to the 1-D row-wise decomposition, with O(N) / O(E)
    # memory cost on each device; needs optimization
    graph_partition.map_concatenated_local_src_ids_to_global = torch.arange(
        num_nodes, dtype=dtype, device=device
    )
    graph_partition.map_concatenated_local_edge_ids_to_global = torch.arange(
        num_edges, dtype=dtype, device=device
    )
    graph_partition.map_concatenated_local_dst_ids_to_global = (
        graph_partition.map_concatenated_local_src_ids_to_global
    )
    graph_partition.map_global_src_ids_to_concatenated_local = (
        graph_partition.map_concatenated_local_src_ids_to_global
    )
    graph_partition.map_global_dst_ids_to_concatenated_local = (
        graph_partition.map_concatenated_local_src_ids_to_global
    )
    graph_partition.map_global_edge_ids_to_concatenated_local = (
        graph_partition.map_concatenated_local_edge_ids_to_global
    )
    graph_partition.matrix_decomp = True
    for r in range(graph_partition.partition_size):
        err_msg = (
            "error in graph partition: list containing sizes of exchanged indices "
            "does not match the tensor of indices to be exchanged"
        )
        if (
            graph_partition.sizes[graph_partition.partition_rank][r]
            != graph_partition.scatter_indices[r].numel()
        ):
            raise AssertionError(err_msg)
    graph_partition = graph_partition.to(device=device)
    return graph_partition
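To make the rank lookup in the loop above concrete, here is a minimal, standalone reproduction of the torch.bucketize idiom it uses; the partition book and node IDs below are invented for illustration:

import torch

# boundaries of a 1-D row-wise decomposition of 10 nodes across 3 ranks:
# rank 0 owns nodes [0, 3), rank 1 owns [3, 6), rank 2 owns [6, 10)
partition_book = torch.tensor([0, 3, 6, 10])

# global IDs of the (unique, sorted) source nodes seen by one partition
global_src_nodes = torch.tensor([1, 4, 5, 9])

# right=True counts the boundaries <= each ID; subtracting 1 gives the owning rank
owning_rank = torch.bucketize(global_src_nodes, partition_book, right=True) - 1
print(owning_rank)  # tensor([0, 1, 1, 2])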
def partition_graph_nodewise(
    global_offsets: torch.Tensor,
    global_indices: torch.Tensor,
    partition_size: int,
    partition_rank: int,
    device: torch.device,
    matrix_decomp: bool = False,
) -> GraphPartition:
    """
    Utility function which partitions a global graph given as CSC structure naively

@@ -429,13 +552,29 @@ def partition_graph_nodewise(
    device : torch.device
        device connected to the passed partition rank, i.e. the device
        on which the local graph and related buffers will live
    matrix_decomp : bool
        flag to enable matrix decomposition for partitioning
    """
    num_global_src_nodes = global_indices.max().item() + 1
    num_global_dst_nodes = global_offsets.size(0) - 1
    num_dst_nodes_per_partition = (
        num_global_dst_nodes + partition_size - 1
    ) // partition_size

    if matrix_decomp:
        assert num_global_src_nodes == num_global_dst_nodes, (
            "Assuming square adjacency matrix (num_src=num_dst) for matrix decomposition"
        )
        partition_book = torch.arange(
            0,
            num_global_dst_nodes,
            num_dst_nodes_per_partition,
            dtype=global_indices.dtype,
        )
        partition_book = torch.cat(
            [partition_book, torch.tensor([num_global_dst_nodes], dtype=global_indices.dtype)]
        )
        return partition_graph_with_matrix_decomposition(
            global_offsets,
            global_indices,
            num_global_dst_nodes,
            partition_book,
            partition_size,
            partition_rank,
            device,
        )

    num_src_nodes_per_partition = (
        num_global_src_nodes + partition_size - 1
    ) // partition_size
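To see what this branch produces, a small runnable example that mirrors the partition-book construction above on a toy CSC graph (the graph itself is invented for illustration; the final call is commented out because it needs the module under review):

import torch

# toy CSC graph: 6 nodes, 8 edges, square adjacency matrix
global_offsets = torch.tensor([0, 1, 3, 4, 6, 7, 8])  # length = num_nodes + 1
global_indices = torch.tensor([1, 0, 2, 3, 1, 5, 4, 0])

partition_size = 3
num_global_dst_nodes = global_offsets.size(0) - 1  # 6
num_dst_nodes_per_partition = (
    num_global_dst_nodes + partition_size - 1
) // partition_size  # ceil(6 / 3) = 2
partition_book = torch.arange(
    0, num_global_dst_nodes, num_dst_nodes_per_partition, dtype=global_indices.dtype
)
partition_book = torch.cat(
    [partition_book, torch.tensor([num_global_dst_nodes], dtype=global_indices.dtype)]
)
print(partition_book)  # tensor([0, 2, 4, 6]): rank r owns dst nodes [book[r], book[r+1])

# entry point added in this diff (requires the surrounding module):
# graph_partition = partition_graph_nodewise(
#     global_offsets, global_indices, partition_size, partition_rank=0,
#     device=torch.device("cpu"), matrix_decomp=True,
# )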
@@ -769,6 +908,12 @@ def get_src_node_features_in_partition(
    ) -> torch.Tensor:  # pragma: no cover
        # if global features are only on local rank 0, also scatter them: split them
        # according to the partition and scatter them to the other ranks

        if self.graph_partition.matrix_decomp:
            raise NotImplementedError(
                "Use get_dst_node_features_in_partition instead of collecting source node features, "
                "as there is only one node feature partition in matrix decomposition."
            )
        if scatter_features:
            global_node_features = global_node_features[
                self.graph_partition.map_concatenated_local_src_ids_to_global
@@ -872,6 +1017,12 @@ def get_global_src_node_features(
        if partitioned_node_features.device != self.device:
            raise AssertionError(error_msg)

        if self.graph_partition.matrix_decomp:
            raise NotImplementedError(
                "Use get_global_dst_node_features instead of collecting source node features, "
                "as there is only one node feature partition in matrix decomposition."
            )

        if not get_on_all_ranks:
            global_node_feat = gather_v(
                partitioned_node_features,

Review comment on the matrix_decomp guard: same comment as above.
Review comment:
Ideally, one shouldn't make the code using these utilities dependent on how a graph is partitioned. Couldn't one, instead of throwing this error, just use get_dst_node_features underneath?
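The suggestion is workable because, under the 1-D row-wise decomposition, src and dst nodes share a single node partition. A hedged sketch of what the delegation could look like; the enclosing class and exact signatures are assumptions inferred from the visible diff, not confirmed by the full file:

    def get_global_src_node_features(
        self, partitioned_node_features: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        if self.graph_partition.matrix_decomp:
            # src and dst nodes live in the same single partition, so the
            # dst-side gather already returns the requested features
            return self.get_global_dst_node_features(partitioned_node_features, **kwargs)
        # ... existing src-side gather path (gather_v, etc.) continues here ...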