Compile error with fp8 block pointer usage #3857

Open
cyang49 opened this issue May 8, 2024 · 0 comments

cyang49 commented May 8, 2024

Hello,

I'm testing a fused attention kernel that optionally supports fp8. The code works in fp16. However, when I enable fp8 and run it on an NVIDIA H100 with the latest Triton built from source (main branch), I get an obscure compile error:

python: /home/ccyang/.triton/llvm/llvm-ed4e505c-ubuntu-x64/include/llvm/Support/Casting.h:572: decltype(auto) llvm::cast(From&) [with To = mlir::IntegerAttr; From = mlir::Attribute]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.
Aborted (core dumped)

With some testing I found that the compile error goes away if I remove the boundary checking and padding options used in load_fn, but I believe those are needed for correctness. I narrowed it down to a minimal reproducible example (see below), where I removed most of the logic and kept only the block pointers, the loads, tl.dot, and the store.

#!/usr/bin/env python

import argparse
import pytest
import random
import sys
import torch

import triton
import triton.language as tl

torch_dtype:tl.constexpr = torch.float16

TORCH_HAS_FP8E5 = True
TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2')
if TORCH_HAS_FP8E5:
    torch_dtype:tl.constexpr = torch.float8_e5m2

class MetaData():
    cu_seqlens_q = None
    cu_seqlens_k = None
    max_seqlens_q = 0
    max_seqlens_k = 0

    def __init__(self, sm_scale=1.0):
        self.sm_scale = sm_scale

@triton.autotune(
   configs=[
       triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
   ],
   key=['hq', 'hk', 'BLOCK_DMODEL'],
   use_cuda_graph=True,
)
@triton.jit
def attn_fwd(
    Q, K, Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_oz, stride_oh, stride_om, stride_on,
    hq, hk,
    ACTUAL_BLOCK_DMODEL:tl.constexpr,
    MAX_SEQLENS_Q:tl.constexpr, MAX_SEQLENS_K:tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    BATCH_SIZE: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_h_q = tl.program_id(1)
    off_z = tl.program_id(2)
    
    cu_seqlens_q_start = 0
    cu_seqlens_k_start = 0
    seqlen_q = MAX_SEQLENS_Q
    seqlen_k = MAX_SEQLENS_K

    is_mqa = hq != hk
    off_h_k = off_h_q % hk if is_mqa else off_h_q

    # Compute pointers for all the tensors used in this kernel.
    q_offset = off_z * stride_qz +  off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )

    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    # q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="")
    q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero")
    
    k = tl.load(K_block_ptr)
    acc += tl.dot(q, k)
    acc = acc.to(Out.type.element_ty)
    
    # write back O
    o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
    O_block_ptr = tl.make_block_ptr(
        base=Out + o_offset,
        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0)
    )
    tl.store(O_block_ptr, acc, boundary_check=(0,1))


class _attention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, o, metadata):
        if o is None:
            o = torch.empty_like(q, dtype=v.dtype)
        # metadata.check_args(q, k, v, o)

        batch, nheads_q, seqlen_q, head_size = q.shape
        _, nheads_k, seqlen_k, _ = k.shape
        q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3))
        k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3))
        v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3))
        o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3))

        padded_d_model = head_size

        grid = lambda META: (
            triton.cdiv(metadata.max_seqlens_q, META['BLOCK_M']),
            nheads_q,
            batch
        )

        encoded_softmax = None

        M = torch.empty((batch, nheads_q, metadata.max_seqlens_q), device=q.device, dtype=torch.float32)

        attn_fwd[grid](
            q, k, o,
            *q_strides, *k_strides, *o_strides,
            hq=nheads_q, hk=nheads_k,
            ACTUAL_BLOCK_DMODEL=head_size,
            MAX_SEQLENS_Q=metadata.max_seqlens_q,
            MAX_SEQLENS_K=metadata.max_seqlens_k,
            BLOCK_DMODEL=padded_d_model,
            BATCH_SIZE= q.shape[0]
        )
        return o, encoded_softmax


attention = _attention.apply

@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD',
                         [(4, 32, 1024, 1024, 64)])
def test_op_fwd(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal=False, use_alibi=False, dtype=torch.float16):
    torch.manual_seed(20)
    sm_scale = D_HEAD ** -0.5
    input_metadata = MetaData(sm_scale=sm_scale)
    input_metadata.max_seqlens_q = N_CTX_Q
    input_metadata.max_seqlens_k = N_CTX_K

    q = torch.randn((Z, H, N_CTX_Q, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    k = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    v = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    o = torch.empty_like(q)

    scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * sm_scale
    if causal:
        mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), 
                          diagonal=N_CTX_K-N_CTX_Q)
        scores[:, :, mask==0] = float("-inf")

    p = torch.softmax(scores, dim=-1)
    ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v)
        
    # triton implementation
    if TORCH_HAS_FP8E5:
        q = q.to(torch_dtype)
        k = k.to(torch_dtype)
    tri_out, _ = attention(q, k, v, o, input_metadata)
    # reference implementation

    # compare
    torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)

if __name__ == "__main__":
    test_op_fwd(4, 32, 1024, 1024, 128, False, False)

Toggling between these two lines either reproduces the compile error or lets the code compile. This suggests a bug or some limitation in boundary checks and the padding option for fp8 block pointers. If it is a limitation, could you suggest a workaround? Thanks.

    # q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="")
    q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero")

When running with pytest, a call stack is dumped after the abort:

Current thread 0x00007f4c8a158740 (most recent call first):
  File "/home/ccyang/github.com/triton/python/triton/backends/nvidia/compiler.py", line 212 in make_llir
  File "/home/ccyang/github.com/triton/python/triton/backends/nvidia/compiler.py", line 302 in <lambda>
  File "/home/ccyang/github.com/triton/python/triton/compiler/compiler.py", line 282 in compile
  File "/home/ccyang/github.com/triton/python/triton/runtime/jit.py", line 662 in run
  File "/home/ccyang/github.com/triton/python/triton/runtime/autotuner.py", line 174 in run
  File "/home/ccyang/github.com/triton/python/triton/runtime/jit.py", line 345 in <lambda>
  File "/home/ccyang/github.com/cyang49/foundation-model-stack/fms/triton/bug_reproducer.py", line 251 in forward