Memory efficient scaled dot product attention (SDPA) implementation using Vulkan and plain GLSL
-----------------------------------------------------------------------------------------------

This code is an independent implementation of the forward pass of the FlashAttention-2 algorithm. It uses cooperative matrices (VK_KHR_cooperative_matrix) to utilize Tensor Cores on NVIDIA hardware (or the equivalent units on other GPUs) to accelerate matrix-matrix multiplications. It is intended to be an open-source, cross-GPU replacement for cuDNN's proprietary flash SDPA implementation. The algorithm (FlashAttention) is memory efficient, meaning that it does not materialize the full softmax score matrix in GPU memory while computing the attention output. For the details, consult the FlashAttention-2 paper [1].
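To give an idea of what "memory efficient" means here, below is a minimal NumPy sketch of the tiled, online-softmax recurrence that FlashAttention-style kernels use. It is illustrative only: the block size and names are placeholders, and the actual shader works on cooperative-matrix tiles rather than NumPy arrays.

    import numpy as np

    def flash_attention_forward(Q, K, V, block=64):
        # Streams over K/V in blocks so the full N x N score matrix is never
        # materialized; only a running row-wise max (m) and softmax
        # denominator (l) are kept, as in FlashAttention-2.
        N, d = Q.shape
        scale = 1.0 / np.sqrt(d)
        O = np.zeros((N, d))
        m = np.full(N, -np.inf)              # running row-wise score maximum
        l = np.zeros(N)                      # running softmax denominator
        for j in range(0, N, block):
            Kj, Vj = K[j:j + block], V[j:j + block]
            S = (Q @ Kj.T) * scale           # scores for this block only
            m_new = np.maximum(m, S.max(axis=1))
            P = np.exp(S - m_new[:, None])   # block-local softmax numerator
            alpha = np.exp(m - m_new)        # rescale the previous partial sums
            l = alpha * l + P.sum(axis=1)
            O = alpha[:, None] * O + P @ Vj
            m = m_new
        return O / l[:, None]

    # Sanity check against the naive formulation that materializes S.
    rng = np.random.default_rng(0)
    Q, K, V = (rng.standard_normal((256, 64)) for _ in range(3))
    S = (Q @ K.T) / np.sqrt(64)
    P = np.exp(S - S.max(axis=1, keepdims=True))
    ref = (P / P.sum(axis=1, keepdims=True)) @ V
    print(np.abs(flash_attention_forward(Q, K, V) - ref).max())  # ~1e-12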

It currently supports the float16 and float32 datatypes with mixed precision: the input type is always float16, while the intermediate computation and matrix accumulation precision can each be either float16 or float32 (in any combination).
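To illustrate why the accumulator precision matters, here is a small NumPy experiment (not part of the repository) that accumulates the same float16 dot product once in float16 and once in float32:

    import numpy as np

    rng = np.random.default_rng(0)
    q = rng.standard_normal(256).astype(np.float16)
    k = rng.standard_normal(256).astype(np.float16)

    acc16 = np.float16(0.0)
    acc32 = np.float32(0.0)
    for a, b in zip(q, k):
        acc16 = np.float16(acc16 + a * b)              # accumulate in float16
        acc32 = acc32 + np.float32(a) * np.float32(b)  # accumulate in float32
    ref = float(np.dot(q.astype(np.float64), k.astype(np.float64)))

    print(abs(float(acc16) - ref))   # typically noticeably larger...
    print(abs(float(acc32) - ref))   # ...than the float32 accumulation error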

A Python ctypes interface is also provided (src/frontend.py). Accuracy should be tested through the Python interface, as it directly compares the results to PyTorch's SDPA implementation (which can itself utilize the GPU).
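The comparison boils down to something like the following sketch (the actual entry point and argument layout in src/frontend.py may differ; vulkan_sdpa below is a hypothetical wrapper around the ctypes call):

    import torch

    def check_accuracy(vulkan_sdpa, B=1, H=8, N=1024, d=64, atol=2e-2):
        # Random float16 inputs, as consumed by the Vulkan kernel.
        q = torch.randn(B, H, N, d, dtype=torch.float16)
        k = torch.randn(B, H, N, d, dtype=torch.float16)
        v = torch.randn(B, H, N, d, dtype=torch.float16)

        # Reference result from PyTorch's built-in SDPA, computed in float32.
        ref = torch.nn.functional.scaled_dot_product_attention(
            q.float(), k.float(), v.float())

        out = vulkan_sdpa(q, k, v)           # hypothetical ctypes wrapper
        err = (out.float() - ref).abs().max().item()
        print(f"max abs error: {err:.3e}")
        assert err < atol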

A CPU implementation is also provided (it uses std::float16_t, which was introduced in C++23, so a compiler that supports it is required).

To test the performance, launch the main application test_sdpa_flash. It can reach 11 TFLOPS on an RTX 2060 SUPER GPU.
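As a rough way to interpret such numbers: the SDPA forward pass is dominated by two matrix multiplications (Q @ K^T and P @ V), about 4 * N^2 * d FLOPs per head, so the achieved throughput can be estimated as in this sketch (the shapes and timing below are made up for illustration):

    def sdpa_tflops(batch, heads, seq_len, head_dim, seconds):
        # Two N x N x d matrix multiplications, 2 * N^2 * d FLOPs each.
        flops = 4.0 * batch * heads * seq_len ** 2 * head_dim
        return flops / seconds / 1e12

    # Made-up example: batch 1, 16 heads, 4096 tokens, head_dim 64, 6 ms per pass.
    print(sdpa_tflops(1, 16, 4096, 64, 0.006))   # ~11.5 TFLOP/s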

Requirements:
------------
Vulkan 1.3 capable device and SDK
VK_KHR_cooperative_matrix support
A compiler that supports std::float16_t (introduced in C++23).

Building:
---------
mkdir build
cd build
cmake ..
make

Launch:
------
python3 ../src/frontend.py
or
./test_sdpa_flash

Set the DEVICE_NAME environment variable to manually select the device; otherwise, the device with the lowest ID that supports the cooperative matrix extension is used.

Debugging:
----------
Uncomment target_compile_definitions(test_sdpa_flash PRIVATE DEBUG_MODE) in CMakeLists.txt. To copy intermediate computation results to the host, define DEBUG_SAVE_S or DEBUG_SAVE_PIJT in the shader code.

When debug mode is enabled (and thus the validation layers are enabled), an error might be triggered with older Vulkan SDKs. In that case, build a recent validation layer from https://github.com/KhronosGroup/Vulkan-ValidationLayers.git and make it available to the application via the VK_LAYER_PATH environment variable.

Copyright notice: 
-----------------
(1) Some (non-significant) parts of the code are built on code from jeffbolznv's cooperative matrix sample (https://github.com/jeffbolznv/vk_cooperative_matrix_perf), mainly the Vulkan initialization and the global memory loads in GLSL; therefore I decided to keep the license of their code.
(2) Vulkan is a registered trademark of the Khronos Group, and this project is not officially endorsed by them.

[1] Tri Dao. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. https://arxiv.org/abs/2307.08691