# Build wheels and upload to GitHub Releases #14
# Workflow file for this run
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
# Builds flash-attn wheels for every combination in the build matrix and
# attaches each wheel to a GitHub Release named after the pushed tag.
name: Build wheels and upload to GitHub Releases

# Run only when a tag matching "v*" is pushed. The `tags` filter is documented
# for the `push` event, not for `create`, so `push` is used to make the filter
# effective.
on:
  push:
    tags:
      - "v*"

# Both jobs create or modify a release through GITHUB_TOKEN, which requires
# write access to repository contents.
permissions:
  contents: write

jobs:
  # Creates the (empty) release that the matrix jobs later upload wheels to.
  create_releases:
    name: Create Releases
    runs-on: ubuntu-latest
    steps:
      - name: Get the tag version
        id: extract_branch
        # Strip the "refs/tags/" prefix so only the tag name is exported.
        run: echo "branch=${GITHUB_REF#refs/tags/}" >> "$GITHUB_OUTPUT"
        shell: bash

      # NOTE(review): actions/create-release appears to be archived upstream;
      # consider migrating to a maintained alternative.
      - name: Create Release
        id: create_release
        uses: actions/create-release@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          tag_name: ${{ steps.extract_branch.outputs.branch }}
          release_name: ${{ steps.extract_branch.outputs.branch }}
          # Keep this table in sync with the build_wheels matrix below.
          body: |
            | Flash-Attention | Python | PyTorch | CUDA |
            |-----------------|--------|---------|------|
            | 2.4.3, 2.5.6, 2.6.3, 2.7.0.post2 | 3.10, 3.11, 3.12 | 2.0.1, 2.1.2, 2.2.2, 2.3.1, 2.4.1, 2.5.1 | 11.8.0, 12.1.1, 12.4.1 |

  # Builds one wheel per matrix entry and uploads it to the release above.
  build_wheels:
    name: Build wheels and Upload
    needs: create_releases
    runs-on: ubuntu-latest
    strategy:
      # Let the remaining combinations finish even if one of them fails.
      fail-fast: false
      matrix:
        flash-attn-version: ["2.4.3", "2.5.6", "2.6.3", "2.7.0.post2"]
        python-version: ["3.10", "3.11", "3.12"]
        torch-version: ["2.0.1", "2.1.2", "2.2.2", "2.3.1", "2.4.1", "2.5.1"]
        cuda-version: ["11.8.0", "12.1.1", "12.4.1"]
        exclude:
          # torch < 2.2 does not support Python 3.12
          - python-version: "3.12"
            torch-version: "2.0.1"
          - python-version: "3.12"
            torch-version: "2.1.2"
          # torch 2.0.1 does not support CUDA 12.x
          - torch-version: "2.0.1"
            cuda-version: "12.1.1"
          - torch-version: "2.0.1"
            cuda-version: "12.4.1"
    steps:
      - uses: actions/checkout@v4

      # Remove large preinstalled toolchains; disk usage is printed before
      # and after so the reclaimed space shows up in the log.
      - name: Maximize build space
        run: |
          df -h
          echo "-----------------------------"
          sudo rm -rf /usr/share/dotnet
          sudo rm -rf /usr/local/lib/android
          sudo rm -rf /opt/ghc
          sudo rm -rf /opt/hostedtoolcache/CodeQL
          df -h

      # NOTE(review): `@master` is a mutable ref; consider pinning to a
      # release tag or commit SHA.
      - name: Set Swap Space
        uses: pierotofy/set-swap-space@master
        with:
          swap-size-gb: 48

      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      # NOTE(review): `@master` is a mutable ref; consider pinning to a
      # release tag or commit SHA.
      - uses: Jimver/cuda-toolkit@master
        with:
          cuda: ${{ matrix.cuda-version }}
          linux-local-args: '["--toolkit"]'
          method: "network"

      # Refresh the package index first so the install does not depend on
      # whatever index the runner image shipped with.
      - run: |
          sudo apt-get update
          sudo apt-get install -y ninja-build

      # Derive short version strings for later steps:
      #   MATRIX_CUDA_VERSION  "12.4.1" -> "124"
      #   MATRIX_TORCH_VERSION "2.5.1"  -> "2.5"
      - name: Set CUDA and PyTorch versions
        run: |
          echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F. '{print $1 $2}')" >> "$GITHUB_ENV"
          echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F. '{print $1 "." $2}')" >> "$GITHUB_ENV"
          echo "CACHE_KEY=cuda-ext-${{ matrix.flash-attn-version }}-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-cuda${{ matrix.cuda-version }}" >> "$GITHUB_ENV"

      # Pick the entry from the per-torch table below that is numerically
      # closest to the matrix CUDA version and install torch from the
      # matching download.pytorch.org wheel index.
      - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
        run: |
          pip install -U pip
          pip install wheel setuptools packaging
          export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
            support_cuda_versions = { \
              '2.0': [117, 118], \
              '2.1': [118, 121], \
              '2.2': [118, 121], \
              '2.3': [118, 121], \
              '2.4': [118, 121, 124], \
              '2.5': [118, 121, 124], \
            }; \
            target_cuda_versions = support_cuda_versions[env['MATRIX_TORCH_VERSION']]; \
            cuda_version = int(env['MATRIX_CUDA_VERSION']); \
            closest_version = min(target_cuda_versions, key=lambda x: abs(x - cuda_version)); \
            print(closest_version) \
          ")
          pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
          nvcc --version
          python -V
          python -c "import torch; print('PyTorch:', torch.__version__)"
          python -c "import torch; print('CUDA:', torch.version.cuda)"
          python -c "from torch.utils import cpp_extension; print(cpp_extension.CUDA_HOME)"

      - name: Checkout flash-attn
        run: |
          git clone https://github.com/Dao-AILab/flash-attention.git
          cd flash-attention
          git checkout v${{ matrix.flash-attn-version }}

      # Add cache steps for CUDA extension build
      # The key combines the matrix entry with a hash of csrc/; restore-keys
      # falls back to any cache saved for the same matrix entry.
      - name: Cache CUDA extension build
        uses: actions/cache@v4
        with:
          path: |
            flash-attention/build
            flash-attention/flash_attn.egg-info
            flash-attention/**/*.so
          key: ${{ env.CACHE_KEY }}-${{ hashFiles('flash-attention/csrc/**') }}
          restore-keys: |
            ${{ env.CACHE_KEY }}-

      # Build the wheel, then rename it so the version carries a
      # "+cu<cuda>torch<torch>" suffix, and export the final file name as
      # `wheel_name` for the following steps.
      - name: Build wheels
        run: |
          pip install setuptools==68.0.0 ninja packaging wheel
          export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
          export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
          export MAX_JOBS=$(($(nproc) - 1))
          cd flash-attention
          FLASH_ATTENTION_FORCE_BUILD="TRUE" python setup.py bdist_wheel --dist-dir=dist
          base_wheel_name=$(basename $(ls dist/*.whl | head -n 1))
          wheel_name=$(echo $base_wheel_name | sed "s/${{ matrix.flash-attn-version }}/${{ matrix.flash-attn-version }}+cu${{ env.MATRIX_CUDA_VERSION }}torch${{ env.MATRIX_TORCH_VERSION }}/")
          mv dist/$base_wheel_name dist/$wheel_name
          echo "wheel_name=$wheel_name" >> "$GITHUB_ENV"

      # Smoke test: the renamed wheel must install and import.
      - name: Install Test
        run: |
          pip install flash-attention/dist/${{ env.wheel_name }}
          python -c "import flash_attn; print(flash_attn.__version__)"

      - name: Get the tag version
        id: extract_branch
        # Strip the "refs/tags/" prefix so only the tag name is exported.
        run: echo "branch=${GITHUB_REF#refs/tags/}" >> "$GITHUB_OUTPUT"

      # Look up the release for this tag to obtain its upload URL.
      - name: Get Release with Tag
        id: get_release
        uses: joutvhu/get-release@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          tag_name: ${{ steps.extract_branch.outputs.branch }}

      # NOTE(review): actions/upload-release-asset appears to be archived
      # upstream; consider migrating to a maintained alternative.
      - name: Upload Release Asset
        uses: actions/upload-release-asset@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          upload_url: ${{ steps.get_release.outputs.upload_url }}
          asset_path: flash-attention/dist/${{ env.wheel_name }}
          asset_name: ${{ env.wheel_name }}
          asset_content_type: application/*