Memory efficient scaled dot product attention (SDPA) implementation using Vulkan and plain GLSL
-----------------------------------------------------------------------------------------------

This code is an independent implementation of the forward pass of the FlashAttention-2 algorithm. It uses cooperative matrices (VK_KHR_cooperative_matrix) to utilize Tensor Cores on NVIDIA hardware (or the equivalent units on other GPUs) to accelerate matrix-matrix multiplications. It is intended to be an open-source, cross-GPU replacement for cuDNN's proprietary flash SDPA implementation.

The algorithm (FlashAttention) is memory efficient, meaning that it does not materialize the full softmax score matrix in GPU memory to compute the attention output. For the details, consult the FlashAttention-2 paper [1]; a sketch of the recurrence is given in the appendix at the end of this README.

It currently supports float16 and float32 data types with mixed precision: the input type is always float16, while the intermediate computation and the matrix accumulation precision can each be either float16 or float32 (in any combination).

A Python ctypes interface is also provided (src/frontend.py). Accuracy should be tested through the Python interface, as it directly compares the results to PyTorch's SDPA implementation (which can utilize the GPU); a hedged sketch of such a check is included in the appendix. A CPU implementation is also provided; it uses std::float16_t, which was introduced in C++23, so you need a compiler that supports it.

To test the performance, launch the main application test_sdpa_flash. It can reach 11 TFLOPS on an RTX 2060 SUPER GPU.

Requirements:
-------------

* A Vulkan 1.3 capable device and SDK
* VK_KHR_cooperative_matrix support
* A compiler that supports std::float16_t (introduced in C++23)

Building:
---------

```
mkdir build
cd build
cmake ..
make
```

Launch:
-------

```
python3 ../src/frontend.py
```

or

```
./test_sdpa_flash
```

Set the DEVICE_NAME environment variable if you want to manually select the device. Otherwise, the device with the minimum ID that supports the cooperative matrix extension is used.

Debugging:
----------

Uncomment `target_compile_definitions(test_sdpa_flash PRIVATE DEBUG_MODE)` in CMakeLists.txt. To copy an intermediate computation result to the host, define DEBUG_SAVE_S or DEBUG_SAVE_PIJT in the shader code.

When debug mode is enabled (and thus validation layers are enabled), an error might be triggered with older Vulkan SDKs. In that case, a recent validation layer build from https://github.com/KhronosGroup/Vulkan-ValidationLayers.git is needed; it can be made available to the application via the VK_LAYER_PATH environment variable.

Copyright notice:
-----------------

(1) Some (non-significant) parts of the code are built on code from jeffbolznv's cooperative matrix sample (https://github.com/jeffbolznv/vk_cooperative_matrix_perf), mainly the Vulkan initialization and the global memory loads in GLSL; therefore, I decided to keep the license of their code.

(2) Vulkan is a registered trademark of the Khronos Group, and this project is not officially endorsed by them.

References:
-----------

[1] FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. https://arxiv.org/abs/2307.08691
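Appendix: illustrative sketches
-------------------------------

To make the memory-efficiency claim concrete, below is a minimal NumPy sketch of the single-head forward recurrence from the FlashAttention-2 paper [1]. It illustrates the algorithm only; it is not the GLSL kernel shipped in this repository, and the function name and tile sizes are illustrative.

```python
# Illustrative NumPy sketch of the FlashAttention-2 forward pass (single head).
# It mirrors the online-softmax recurrence of [1]; it is NOT the GLSL kernel
# in this repository, and all names and tile sizes here are illustrative.
import numpy as np

def flash_attention_forward(Q, K, V, Br=32, Bc=32):
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d), dtype=np.float32)
    for i in range(0, N, Br):                                  # row blocks of Q
        Qi = Q[i:i + Br].astype(np.float32)
        Oi = np.zeros((Qi.shape[0], d), dtype=np.float32)      # unnormalized output block
        li = np.zeros(Qi.shape[0], dtype=np.float32)           # running softmax denominator
        mi = np.full(Qi.shape[0], -np.inf, dtype=np.float32)   # running row maximum
        for j in range(0, N, Bc):                              # column blocks of K/V
            Kj = K[j:j + Bc].astype(np.float32)
            Vj = V[j:j + Bc].astype(np.float32)
            S = scale * (Qi @ Kj.T)               # score tile; the full NxN matrix never exists
            m_new = np.maximum(mi, S.max(axis=1))
            P = np.exp(S - m_new[:, None])        # tile-local softmax numerator
            alpha = np.exp(mi - m_new)            # rescales the previously accumulated state
            li = alpha * li + P.sum(axis=1)
            Oi = alpha[:, None] * Oi + P @ Vj
            mi = m_new
        O[i:i + Br] = Oi / li[:, None]            # a single normalization per row block
    return O

# Check against the naive attention that does materialize the score matrix:
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((128, 64)).astype(np.float16) for _ in range(3))
S = (Q.astype(np.float32) @ K.astype(np.float32).T) / np.sqrt(64)
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V.astype(np.float32)
assert np.allclose(flash_attention_forward(Q, K, V), ref, atol=1e-3)
```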
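The accuracy comparison against PyTorch can be scripted along the following lines. This is a hedged sketch only: the shared-library name, the exported symbol, and its signature below are assumptions made for illustration, so consult src/frontend.py for the project's actual ctypes bindings.

```python
# Hedged sketch of an accuracy check against PyTorch's SDPA, in the spirit of
# src/frontend.py. The library name, symbol, and signature are ASSUMPTIONS.
import ctypes
import numpy as np
import torch

B, H, N, d = 1, 8, 512, 64
q, k, v = (torch.randn(B, H, N, d, dtype=torch.float16) for _ in range(3))

# PyTorch reference, computed in float32 for a stable baseline.
ref = torch.nn.functional.scaled_dot_product_attention(q.float(), k.float(), v.float())

lib = ctypes.CDLL("./libsdpa_flash.so")          # assumed library name
out = np.empty((B, H, N, d), dtype=np.float16)

def as_ptr(t):
    # The tensor owns the storage; the pointer stays valid while `t` is alive.
    return t.numpy().ctypes.data_as(ctypes.c_void_p)

lib.sdpa_flash_forward(                          # assumed symbol and signature
    as_ptr(q), as_ptr(k), as_ptr(v),
    out.ctypes.data_as(ctypes.c_void_p),
    ctypes.c_int(B), ctypes.c_int(H), ctypes.c_int(N), ctypes.c_int(d))

# float16 inputs (and possibly float16 accumulation) warrant a loose tolerance.
print("max abs err:", np.abs(out.astype(np.float32) - ref.numpy()).max())
```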
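When scripting the benchmark, the device override can also be set from Python. The device-name value below is only an example; the exact matching rule is defined by the host code.

```python
# Launch test_sdpa_flash with a manually selected device. The DEVICE_NAME value
# is an example; how the host code matches it against device names may differ.
import os
import subprocess

env = dict(os.environ, DEVICE_NAME="NVIDIA GeForce RTX 2060 SUPER")
subprocess.run(["./test_sdpa_flash"], env=env, check=True)
```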