- general reduction type
- autograd
- CPU version
- PR into torch_scatter
Although scatter and segment operations are cornerstones of geometric deep learning systems, current deep learning frameworks do not optimize them effectively. Moreover, current framework solutions do not support fusion for segment reduction. In torch.compile, the scatter and message steps are fused into a fully atomic SpMM, which significantly hurts performance. (See benchmark details here.)
To address these challenges, we propose new operators and paradigms as shown in the following sections:
dst = index_scatter(dim, index, src, reduce, sorted=True) → Tensor
Unlike the conventional scatter_reduce operation, which allows flexibility in the dimensionality of dst and src, our approach requires that dst and src have the same number of dimensions. This constraint aligns our method more closely with the index_reduce operation.
E.g., for a 3-D tensor with reduce="sum", the output is calculated as follows:
dst[index[i]][j][k] += src[i][j][k] # if dim == 0
dst[i][index[j]][k] += src[i][j][k] # if dim == 1
dst[i][j][index[k]] += src[i][j][k] # if dim == 2
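The update rule above can be sketched in plain Python as a reference for the semantics only. This is not GeoT's kernel: the helper name `index_scatter_sum_3d` is hypothetical, it works on nested lists rather than tensors, and it takes a pre-allocated `dst` as an argument (an assumption made for illustration; the operator itself returns `dst`).

```python
def index_scatter_sum_3d(dim, index, src, dst):
    """Reference semantics only (hypothetical helper, reduce="sum").

    Accumulates every element of the 3-D nested list `src` into `dst`,
    redirecting the coordinate on dimension `dim` through `index`.
    """
    for i, plane in enumerate(src):
        for j, row in enumerate(plane):
            for k, value in enumerate(row):
                pos = [i, j, k]
                # Only the coordinate on the reduced dimension is remapped.
                pos[dim] = index[pos[dim]]
                dst[pos[0]][pos[1]][pos[2]] += value
    return dst


# dim == 0: src has shape (3, 1, 2); slices 0 and 2 both land in dst[0].
src0 = [[[1.0, 2.0]], [[3.0, 4.0]], [[5.0, 6.0]]]
dst0 = [[[0.0, 0.0]], [[0.0, 0.0]]]  # shape (2, 1, 2)
out0 = index_scatter_sum_3d(0, [0, 1, 0], src0, dst0)

# dim == 2: src has shape (1, 1, 3); elements 0 and 2 both land in column 1.
src2 = [[[1.0, 2.0, 3.0]]]
dst2 = [[[0.0, 0.0]]]  # shape (1, 1, 2)
out2 = index_scatter_sum_3d(2, [1, 0, 1], src2, dst2)
```

Note that `src` and `dst` have the same number of dimensions in both calls, and `index` is one-dimensional with one entry per slice of `src` along `dim`.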
Additionally, we have integrated a sorted flag to optimize the index_scatter_reduce kernel's performance. A sorted index enhances processing locality and minimizes atomic operations, thereby substantially improving the parallel processing efficiency for both CPUs and GPUs. This is formulated as implicit segment reduction (or segment coo).
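To illustrate why a sorted index helps, here is a minimal sequential sketch of segment reduction. It assumes `index` is sorted in ascending order, and the names `segment_sum_sorted` and `num_segments` are hypothetical. Equal consecutive indices form one segment, which is accumulated locally and written to the output once; the `writes` counter stands in for the output updates that a parallel kernel would otherwise perform once per input row.

```python
def segment_sum_sorted(index, src, num_segments):
    """Sum rows of `src` that share an index (hypothetical helper).

    Assumes `index` is sorted ascending, so each output row is
    produced by exactly one contiguous run of input rows.
    """
    width = len(src[0]) if src else 0
    dst = [[0.0] * width for _ in range(num_segments)]
    writes = 0  # number of updates to dst
    start = 0
    n = len(index)
    while start < n:
        acc = [0.0] * width
        end = start
        # Walk the contiguous run of rows that belong to the same segment.
        while end < n and index[end] == index[start]:
            for c in range(width):
                acc[c] += src[end][c]
            end += 1
        dst[index[start]] = acc  # one write per segment, not per row
        writes += 1
        start = end
    return dst, writes


index = [0, 0, 1, 3, 3, 3]
src = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0], [5.0, 5.0], [6.0, 6.0]]
dst, writes = segment_sum_sorted(index, src, num_segments=4)
```

Six input rows are reduced with three output writes here. The segment boundaries are recovered from the index itself, which is what makes the segmentation implicit.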
In Graph Neural Networks (GNNs), fusing the message and aggregation steps is a prevalent strategy. PyTorch Geometric (PyG) uses the torch_sparse library, which offers optimized Sparse Matrix-Matrix Multiplication (SpMM) to facilitate this process. Beyond the traditional sparse formats, we have developed an efficient SpMM built upon segment reduction. Within GeoT, SpMM can be executed with the gather_scatter function, with the adjacency matrix sorted by column.
# consider adj as edge_index
# no weight
dst = gather_scatter(edge_index[0], edge_index[1], src, reduce) → Tensor
# with weight
dst = gather_weight_scatter(edge_index[0], edge_index[1], weight, src, reduce) → Tensor
This format-agnostic approach ensures compatibility with mainstream frameworks and compilers by achieving segmentation implicitly, without requiring explicit sparse format specification. We plan to integrate this feature into the Triton language in the future.
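As a rough illustration of how a gather followed by a scatter amounts to SpMM, the sketch below spells out the weighted case in plain Python and compares it with a dense matrix product. It is written under assumptions not stated above: the first index is treated as the gather (source node) index and the second as the scatter (destination node) index, following the PyG edge_index convention; reduce is taken to be "sum"; and the helper names and the `num_dst` parameter are hypothetical.

```python
def gather_weight_scatter_sum(src_index, dst_index, weight, src, num_dst):
    """Reference semantics only (hypothetical helper, reduce="sum").

    For every edge e: dst[dst_index[e]] += weight[e] * src[src_index[e]].
    """
    width = len(src[0])
    dst = [[0.0] * width for _ in range(num_dst)]
    for s, d, w in zip(src_index, dst_index, weight):
        for c in range(width):
            dst[d][c] += w * src[s][c]  # gather row s, scale, scatter to row d
    return dst


def dense_matmul(a, x):
    """Plain dense product, used only to check the result."""
    return [
        [sum(a[r][k] * x[k][c] for k in range(len(x))) for c in range(len(x[0]))]
        for r in range(len(a))
    ]


# Edges 0->1, 2->1, 1->2, already sorted by destination.
edge_index = [[0, 2, 1], [1, 1, 2]]
weight = [0.5, 2.0, 1.0]
src = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

out = gather_weight_scatter_sum(edge_index[0], edge_index[1], weight, src, 3)

# Equivalent dense adjacency: adj[dst][src] = weight.
adj = [[0.0, 0.0, 0.0], [0.5, 0.0, 2.0], [0.0, 1.0, 0.0]]
expected = dense_matmul(adj, src)
```

No sparse format such as CSR is constructed: the edge list alone determines the result, and because it is sorted by destination, the scatter step can be carried out as a segment reduction.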
Setup
python setup.py build install
Run benchmark
cd benchmark
python bench_index_scatter.py
python bench_spmm.py
For more details, please see our paper. If you find this repo useful, please cite it using the following BibTeX entry:
@article{yu2024geot,
title={GeoT: Tensor Centric Library for Graph Neural Network via Efficient Segment Reduction on GPU},
author={Yu, Zhongming and Zhang, Genghan and Huang, Hanxian and Chen, Xin and Zhao, Jishen},
journal={arXiv preprint arXiv:2404.03019},
year={2024}
}