[Feature] Add row-decomposition of adj. matrix to reduce graph partitioning overhead #720

Open. Wants to merge 5 commits into main.
153 changes: 152 additions & 1 deletion modulus/models/gnn_layers/distributed_graph.py
@@ -67,6 +67,8 @@ class GraphPartition:
partition_size: int
partition_rank: int
device: torch.device
# flag to indicate the use of a 1-D row-wise decomposition of the adjacency matrix
matrix_decomp: bool = False

# data structures defining partition
# set after initialization or during execution
@@ -394,12 +396,133 @@ def partition_graph_with_id_mapping(
return graph_partition


def partition_graph_with_matrix_decomposition(
global_offsets: torch.Tensor,
global_indices: torch.Tensor,
num_nodes: int,
partition_book: torch.Tensor,
partition_size: int,
partition_rank: int,
device: torch.device,
) -> GraphPartition:
"""
Utility function which partitions a global graph, given as a CSC structure, based on a 1-D row-wise
decomposition of its adjacency matrix. This approach ensures a uniform 1-D distribution of nodes
and their associated 1-hop incoming edges. By treating source and destination nodes equivalently
during partitioning, this approach assumes the graph is not bipartite.
This decomposition also ensures that the graph convolution (spMM) remains local by maintaining a copy of
the local incoming edge features and the local node outputs from the graph convolution.
The memory complexity of this approach is O[(N/P + E/P) * hid_dim * L], where N and E are the numbers of
nodes and edges, P is the number of partitions, hid_dim is the hidden feature dimension, and L is the number of layers.
The transformation from local node storage to local edge storage is achieved using an NCCL `alltoall`.

Key differences from the existing graph partition scheme (partition_graph_with_id_mapping):
(1) This function partitions the global node ID space uniformly, without distinguishing
between source and destination nodes (i.e., matrix row ordering or column ordering). Both
src/dst or row/col nodes are indexed consistently within the adjacency matrix.
(2) Each local graph (sub-matrix) can be defined/constructed from just node/edge offsets into the
global graph.
(3) The partitioning is performed on a global graph stored in CPU memory, and then each device
(rank) constructs its local graph independently from the global CSC matrix.

Parameters
----------
global_offsets : torch.Tensor
CSC offsets, can live on the CPU
global_indices : torch.Tensor
CSC indices, can live on the CPU
num_nodes : int
number of nodes in the global graph
partition_book : torch.Tensor
the boundaries of 1-D row-decomp of adj. matrix for all ranks
partition_size : int
number of process groups across which graph is partitioned,
i.e. the number of graph partitions
partition_rank : int
rank within process group managing the distributed graph, i.e.
the rank determining which partition the corresponding local rank
will manage
device : torch.device
device connected to the passed partition rank, i.e. the device
on which the local graph and related buffers will live
"""

# initialize graph partition
graph_partition = GraphPartition(
partition_size=partition_size, partition_rank=partition_rank, device=device
)
dtype = global_indices.dtype
# --------------------------------------------------------------
# First, partition the global row pointers (dst nodes) into local row pointers
num_edges = global_indices.size(0)
node_offset = partition_book[partition_rank]
num_local_nodes = partition_book[partition_rank + 1] - partition_book[partition_rank]
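# total number of edges incoming to dst nodes owned by lower ranks; used to shift global edge positions to rank-local ones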
edge_partition_offset = global_offsets[node_offset]
assert node_offset + num_local_nodes <= num_nodes, "Invalid node offset and number of local nodes"
local_offsets = global_offsets[node_offset:node_offset + num_local_nodes + 1].to(device=device, non_blocking=True)
graph_partition.local_offsets = local_offsets - edge_partition_offset
graph_partition.num_local_dst_nodes = num_local_nodes

# Scan through all partitions and compress the source nodes (edges) for each partition
# to fill the local send/recv buffers for all-to-all communications
partition_book = partition_book.to(device=device)
for to_partition in range(partition_size):
local_indices = global_indices[
global_offsets[partition_book[to_partition]]:global_offsets[partition_book[to_partition+1]]
].to(device=device, non_blocking=True)
# compress the columns (src nodes or local_indices) for each partition and record mapping (inverse_indices)
global_src_node_at_partition, inverse_indices = local_indices.unique(sorted=True, return_inverse=True)
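# look up the owning rank of each unique global src node: right-bucketize against the partition boundaries, then subtract 1 to get a rank in [0, partition_size)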
global_src_node_at_partition_rank = torch.bucketize(global_src_node_at_partition, partition_book, right=True) - 1
src_node_indices = torch.nonzero(global_src_node_at_partition_rank == partition_rank, as_tuple=False).squeeze(1)
# fill the local send buffer for alltoalls (scatter the selected nodes to the `to_partition` rank)
graph_partition.scatter_indices[to_partition] = global_src_node_at_partition[src_node_indices] - node_offset
# fill the numbers of indices (edges), dst nodes and src nodes for each partition
graph_partition.num_indices_in_each_partition[to_partition] = local_indices.size(0)
graph_partition.num_dst_nodes_in_each_partition[to_partition] = partition_book[to_partition+1] - partition_book[to_partition]
graph_partition.num_src_nodes_in_each_partition[to_partition] = global_src_node_at_partition.size(0)

if to_partition == partition_rank:
graph_partition.local_indices = inverse_indices
graph_partition.num_local_indices = graph_partition.local_indices.size(0)
graph_partition.num_local_src_nodes = global_src_node_at_partition.size(0)
# map from local (compressed) column indices [0, ..., num_local_src_nodes] to their global node IDs
graph_partition.map_partitioned_src_ids_to_global = global_src_node_at_partition

for from_partition in range(partition_size):
# fill all recv buffer sizes for alltoalls
graph_partition.sizes[from_partition][to_partition] = torch.count_nonzero(global_src_node_at_partition_rank == from_partition)

# trivial mappings due to 1D row-wise decomposition
graph_partition.map_partitioned_dst_ids_to_global = torch.arange(node_offset, node_offset + num_local_nodes, dtype=dtype, device=device)
graph_partition.map_partitioned_edge_ids_to_global = torch.arange(edge_partition_offset,
edge_partition_offset + graph_partition.num_local_indices,
dtype=dtype, device=device)
# trivial mappings due to 1-D row-wise decomposition, with O(N) and O(E) memory cost on each device; needs optimization
graph_partition.map_concatenated_local_src_ids_to_global = torch.arange(num_nodes, dtype=dtype, device=device)
graph_partition.map_concatenated_local_edge_ids_to_global = torch.arange(num_edges, dtype=dtype, device=device)
graph_partition.map_concatenated_local_dst_ids_to_global = graph_partition.map_concatenated_local_src_ids_to_global
graph_partition.map_global_src_ids_to_concatenated_local = graph_partition.map_concatenated_local_src_ids_to_global
graph_partition.map_global_dst_ids_to_concatenated_local = graph_partition.map_concatenated_local_src_ids_to_global
graph_partition.map_global_edge_ids_to_concatenated_local = graph_partition.map_concatenated_local_edge_ids_to_global
graph_partition.matrix_decomp = True

for r in range(graph_partition.partition_size):
err_msg = "error in graph partition: list containing sizes of exchanged indices does not match the tensor of indices to be exchanged"
if (
graph_partition.sizes[graph_partition.partition_rank][r]
!= graph_partition.scatter_indices[r].numel()
):
raise AssertionError(err_msg)
graph_partition = graph_partition.to(device=device)
return graph_partition
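
For intuition, here is a minimal standalone sketch (not part of this diff; the graph below is made up) of the 1-D row-wise decomposition used above: each rank owns a contiguous range of dst nodes (rows) and slices its local CSC piece directly out of the global offsets/indices.

import torch

# hypothetical 6-node graph in CSC form: offsets index dst nodes, indices hold src node IDs
offsets = torch.tensor([0, 2, 3, 5, 6, 8, 9], dtype=torch.int64)
indices = torch.tensor([1, 2, 0, 3, 4, 5, 0, 2, 4], dtype=torch.int64)
num_nodes, partition_size = 6, 3

# uniform 1-D row boundaries: ranks own dst-node ranges [0, 2), [2, 4), [4, 6)
rows_per_rank = (num_nodes + partition_size - 1) // partition_size
partition_book = torch.arange(0, num_nodes, rows_per_rank, dtype=torch.int64)
partition_book = torch.cat([partition_book, torch.tensor([num_nodes], dtype=torch.int64)])

for rank in range(partition_size):
    start, end = int(partition_book[rank]), int(partition_book[rank + 1])
    # each rank's local CSC piece is a plain slice of the global arrays
    local_offsets = offsets[start : end + 1] - offsets[start]
    local_indices = indices[offsets[start] : offsets[end]]
    print(rank, local_offsets.tolist(), local_indices.tolist())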


def partition_graph_nodewise(
global_offsets: torch.Tensor,
global_indices: torch.Tensor,
partition_size: int,
partition_rank: int,
device: torch.device,
matrix_decomp: bool = False,
) -> GraphPartition:
"""
Utility function which partitions a global graph given as CSC structure naively
@@ -429,13 +552,29 @@ def partition_graph_nodewise(
device : torch.device
device connected to the passed partition rank, i.e. the device
on which the local graph and related buffers will live
matrix_decomp : bool
flag to enable matrix decomposition for partitioning
"""

num_global_src_nodes = global_indices.max().item() + 1
num_global_dst_nodes = global_offsets.size(0) - 1
num_dst_nodes_per_partition = (
num_global_dst_nodes + partition_size - 1
) // partition_size

if matrix_decomp:
assert num_global_src_nodes == num_global_dst_nodes, "Assuming square adjacency matrix (num_src=num_dst) for matrix decomposition"
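# uniform row boundaries: each rank owns a contiguous block of up to num_dst_nodes_per_partition dst nodes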
partition_book = torch.arange(0, num_global_dst_nodes, num_dst_nodes_per_partition, dtype=global_indices.dtype)
partition_book = torch.cat([partition_book, torch.tensor([num_global_dst_nodes], dtype=global_indices.dtype)])
return partition_graph_with_matrix_decomposition(
global_offsets,
global_indices,
num_global_dst_nodes,
partition_book,
partition_size,
partition_rank,
device,
)

num_src_nodes_per_partition = (
num_global_src_nodes + partition_size - 1
) // partition_size
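
As a usage sketch (mirroring the square test fixture added below; not part of the diff, and assuming the module path matches the file location in this PR):

import torch
from modulus.models.gnn_layers.distributed_graph import partition_graph_nodewise

# 4-node square graph in CSC form, identical to the global_graph_square fixture
offsets = torch.tensor([0, 2, 4, 6, 8], dtype=torch.int64)
indices = torch.tensor([0, 3, 2, 1, 1, 0, 1, 2], dtype=torch.int64)

# in practice, partition_rank comes from the rank within the distributed process group
pg = partition_graph_nodewise(
    offsets,
    indices,
    partition_size=4,
    partition_rank=0,
    device="cpu",
    matrix_decomp=True,
)
# rank 0 owns dst node 0 and its two incoming edges
assert pg.num_local_dst_nodes == 1 and pg.num_local_indices == 2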
@@ -769,6 +908,12 @@ def get_src_node_features_in_partition(
) -> torch.Tensor: # pragma: no cover
# if global features only on local rank 0 also scatter, split them
# according to the partition and scatter them to other ranks

if self.graph_partition.matrix_decomp:
Review comment (Collaborator): Ideally, one shouldn't make the code using these utilities dependent on how a graph is partitioned. Couldn't one, instead of throwing this error, just use get_dst_node_features underneath?

raise NotImplementedError(
"Use get_dst_node_features_in_partition instead of collecting source node features, "
"as there is only one node feature partition in matrix decomposition."
)
if scatter_features:
global_node_features = global_node_features[
self.graph_partition.map_concatenated_local_src_ids_to_global
@@ -872,6 +1017,12 @@ def get_global_src_node_features(
if partitioned_node_features.device != self.device:
raise AssertionError(error_msg)

if self.graph_partition.matrix_decomp:
Review comment (Collaborator): Same comment as above.

raise NotImplementedError(
"Use get_global_dst_node_features instead of collecting source node features, "
"as there is only one node feature partition in matrix decomposition."
)

if not get_on_all_ranks:
global_node_feat = gather_v(
partitioned_node_features,
52 changes: 52 additions & 0 deletions test/models/test_graph_partition.py
@@ -35,6 +35,17 @@ def global_graph():

return (offsets, indices, num_src_nodes, num_dst_nodes)

@pytest.fixture
def global_graph_square():
"""test fixture: simple non-bipartie graph with a degree of 2 per node"""
#num_src_nodes = 4
#num_dst_nodes = 4
#num_edges = 8
offsets = torch.tensor([0, 2, 4, 6, 8], dtype=torch.int64)
indices = torch.tensor([0, 3, 2, 1, 1, 0, 1, 2], dtype=torch.int64)
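# CSC layout: dst 0 <- src {0, 3}, dst 1 <- src {2, 1}, dst 2 <- src {1, 0}, dst 3 <- src {1, 2}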

return (offsets, indices, 4, 4)


def assert_partitions_are_equal(a, b):
"""test utility: check if a matches b"""
@@ -162,6 +173,47 @@ def test_gp_nodewise(global_graph, device):

assert_partitions_are_equal(pg, pg_expected)

@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
def test_gp_matrixdecomp(global_graph_square, device):
offsets, indices, num_src_nodes, num_dst_nodes = global_graph_square
partition_size = 4
partition_rank = 0

pg = partition_graph_nodewise(
offsets,
indices,
partition_size,
partition_rank,
device,
matrix_decomp=True,
)

pg_expected = GraphPartition(
partition_size=4,
partition_rank=0,
device=device,
local_offsets=torch.tensor([0, 2]),
local_indices=torch.tensor([0, 1]),
num_local_src_nodes=2,
num_local_dst_nodes=1,
num_local_indices=2,
map_partitioned_src_ids_to_global=torch.tensor([0, 3]),
map_partitioned_dst_ids_to_global=torch.tensor([0]),
map_partitioned_edge_ids_to_global=torch.tensor([0, 1]),
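# sizes[i][j]: number of (unique) src-node features rank i sends to rank j in the alltoall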
sizes=[[1, 0, 1, 0], [0, 1, 1, 1], [0, 1, 0, 1], [1, 0, 0, 0]],
scatter_indices=[
torch.tensor([0]),
torch.tensor([], dtype=torch.int64),
torch.tensor([0]),
torch.tensor([], dtype=torch.int64),
],
num_src_nodes_in_each_partition=[2, 2, 2, 2],
num_dst_nodes_in_each_partition=[1, 1, 1, 1],
num_indices_in_each_partition=[2, 2, 2, 2],
).to(device=device)

assert_partitions_are_equal(pg, pg_expected)


@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
def test_gp_coordinate_bbox(global_graph, device):