From d25708aaa9145f04d2bf12f2c310c1f71d1ad6f2 Mon Sep 17 00:00:00 2001 From: Junya Morioka Date: Sun, 27 Oct 2024 19:54:25 +0900 Subject: [PATCH] Update build.yml --- .github/workflows/build.yml | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index cb754b6..bf0d95c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -96,6 +96,7 @@ jobs: run: | echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV + echo "CACHE_KEY=cuda-ext-${{ matrix.flash-attn-version }}-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-cuda${{ matrix.cuda-version }}" >> $GITHUB_ENV - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }} run: | @@ -128,6 +129,18 @@ jobs: cd flash-attention git checkout v${{ matrix.flash-attn-version }} + # Add cache steps for CUDA extension build + - name: Cache CUDA extension build + uses: actions/cache@v3 + with: + path: | + flash-attention/build + flash-attention/flash_attn.egg-info + flash-attention/**/*.so + key: ${{ env.CACHE_KEY }}-${{ hashFiles('flash-attention/csrc/**') }} + restore-keys: | + ${{ env.CACHE_KEY }}- + - name: Build wheels run: | pip install setuptools==68.0.0 ninja packaging wheel @@ -137,7 +150,8 @@ jobs: cd flash-attention FLASH_ATTENTION_FORCE_BUILD="TRUE" python setup.py bdist_wheel --dist-dir=dist base_wheel_name=$(basename $(ls dist/*.whl | head -n 1)) - wheel_name=$(echo $base_wheel_name | sed 's/${{ matrix.flash-attn-version }}/${{ matrix.flash-attn-version }}+cu${{ env.MATRIX_CUDA_VERSION }}torch${{ env.MATRIX_TORCH_VERSION }}') + wheel_name=$(echo $base_wheel_name | sed "s/${{ matrix.flash-attn-version }}/${{ matrix.flash-attn-version }}+cu${{ env.MATRIX_CUDA_VERSION }}torch${{ env.MATRIX_TORCH_VERSION }}/") + mv dist/$base_wheel_name dist/$wheel_name echo "wheel_name=$wheel_name" >> $GITHUB_ENV - name: Install Test