
[AMD] Relax the restriction of dot shape >= 16 #3908

Draft
wants to merge 6 commits into main

Conversation

giuseros
Contributor

This is my first PR in Triton, and it tries to address the limitation that tt.dot only accepts sizes of at least (M,N,K) == (16,16,16).

I modified semantic.py to relax the tt.dot restrictions on the sizes of the matrices (for gfx9 architectures). Please note that there is a supportMFMA function that only accepts M/N sizes that are multiples of 16 and K sizes that are multiples of 8.

Based on that, I relaxed the restriction in semantic.py to support (M,N,K) >= (16,16,8). This is the minimal change.

If we want to push further, we would need to change supportMFMA and add tests for smaller layouts (many of those smaller layouts are broadcast layouts; do we support those in the AMD backend?).

Please note: for now, if I try to feed a smaller layout (e.g., 8x8x8), the test fails with mismatches.
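
For reference, here is a minimal sketch of the kind of shape check being relaxed in semantic.py; the helper name, the is_amd_gfx9 flag, and the error message are assumptions for illustration, not the actual diff:

    # Hypothetical sketch, not the exact change in this PR.
    # On AMD gfx9 the MFMA path can handle K >= 8, so only M and N keep the >= 16 floor.
    def check_dot_shapes(lhs_shape, rhs_shape, is_amd_gfx9):
        M, K = lhs_shape[-2], lhs_shape[-1]
        N = rhs_shape[-1]
        min_k = 8 if is_amd_gfx9 else 16  # assumed gating; the real code may query the target differently
        assert M >= 16 and N >= 16 and K >= min_k, \
            f"dot requires M >= 16, N >= 16, K >= {min_k}; got M={M}, N={N}, K={K}"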

@giuseros giuseros requested a review from ptillet as a code owner May 14, 2024 16:57
@giuseros
Contributor Author

cc @zhanglx13 @binarman

@jlebar jlebar requested a review from antiagainst May 14, 2024 16:58
@zhanglx13 zhanglx13 marked this pull request as draft May 14, 2024 17:01
@binarman
Contributor

@giuseros

many of those smaller layouts are broadcast layouts; do we support those in the AMD backend?

Do you mean slice layout?
If so, the answer is yes, we do. There are some issues with WMMA at this point, but I think @joviliast is working on that at the moment.

@YixinSong-e

Nice! Do you have plans to support 8x8x8?

@giuseros
Contributor Author

@giuseros

many of those smaller layouts are broadcast layouts; do we support those in the AMD backend?

Do you mean slice layout? If so, the answer is yes, we do. There are some issues with WMMA at this point, but I think @joviliast is working on that at the moment.

I am not sure what a slice layout is, but instructions like mfma_4x4x1_16B work on 16 blocks. You can broadcast rows of A (or columns of B) to make this work as a single GEMM. Is this what @joviliast is working on?
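
To make the broadcast idea concrete, here is an illustrative NumPy sketch (plain Python, not backend code) of how a single 4x64 GEMM can be laid across 16 blocks when A is broadcast to every block:

    # Illustrative only: mfma_4x4x1 with 16 blocks computes, per block b,
    # a 4x4 tile C_b += A_b (4x1) @ B_b (1x4). Broadcasting the same A slice
    # to all 16 blocks and giving each block its own 4 columns of B yields
    # one 4 x 64 GEMM.
    import numpy as np

    M, N, K, BLOCKS = 4, 64, 1, 16
    A = np.random.rand(M, K).astype(np.float32)
    B = np.random.rand(K, N).astype(np.float32)

    C = np.zeros((M, N), dtype=np.float32)
    for b in range(BLOCKS):                       # one iteration per hardware block
        a_blk = A                                 # A is broadcast to every block
        b_blk = B[:, 4 * b:4 * (b + 1)]           # each block owns 4 columns of B
        C[:, 4 * b:4 * (b + 1)] += a_blk @ b_blk  # 4x4 += 4x1 @ 1x4

    assert np.allclose(C, A @ B)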

@giuseros
Contributor Author

Nice! Do you have plans to support 8x8x8?

Hi @YixinSong-e, I think we should support any size in the front-end and let the backend decide how to lower it. But we need other people to agree with this :)

@joviliast
Contributor

@giuseros

many of those smaller layouts are broadcast layouts; do we support those in the AMD backend?

Do you mean slice layout? If so, the answer is yes, we do. There are some issues with WMMA at this point, but I think @joviliast is working on that at the moment.

I am not sure what a slice layout is, but instructions like mfma_4x4x1_16B work on 16 blocks. You can broadcast rows of A (or columns of B) to make this work as a single GEMM. Is this what @joviliast is working on?

I believe slices for WMMA layouts are completely supported.

@binarman
Contributor

@giuseros

I am not sure what a slice layout is, but instructions like mfma_4x4x1_16B work on 16 blocks. You can broadcast rows of A (or columns of B) to make this work as a single GEMM. Is this what @joviliast is working on?

Ah, I see, thanks.

FYI: I've experimented with 3 types of layouts that use mfma4x4 in the ROCm fork.
They had the following tile sizes (A(MxK) * B(KxN)):

  1. 4(M) x 4(N) x 64(K)
  2. 4(M) x 64(N) x 4(K)
  3. 4(M) x 64(N) x 64(K)

So far, the most promising layout is the third one, but its use is limited because of the large difference in size between the first and second operands.

@giuseros
Contributor Author

@giuseros

I am not sure what a slice layout is, but instructions like mfma_4x4x1_16B work on 16 blocks. You can broadcast rows of A (or columns of B) to make this work as a single GEMM. Is this what @joviliast is working on?

Ah, I see, thanks.

FYI: I've experimented with 3 types of layouts that use mfma4x4 in the ROCm fork. They had the following tile sizes (A(MxK) * B(KxN)):

  1. 4(M) x 4(N) x 64(K)
  2. 4(M) x 64(N) x 4(K)
  3. 4(M) x 64(N) x 64(K)

So far, the most promising layout is the third one, but its use is limited because of the large difference in size between the first and second operands.

So I guess my point is that there are two natural next steps to this PR:

  • First, support every size for the non-accelerated layout
  • Second, introduce broadcast layouts and support those, at least for the cases where the size is too small to use any other mfma

I think the second point is not super important, because many frameworks simply use reduction mfma. It might be that we also want to skip the first point if there is higher-priority work to do.

@ThomasRaoux
Contributor

Can you provide a bit more info on the motivation? It sounds like this breaks portability.

@giuseros
Contributor Author

Can you provide a bit more info on the motivation? It sounds like this breaks portability.

Hi @ThomasRaoux, the point is that the AMD backend can accelerate smaller sizes than 16x16x16; that's why we are trying to add this relaxation in the frontend.

@antiagainst
Collaborator

Can you provide a bit more info on the motivation? It sounds like this breaks portability.

Do we require all implementations to support the same set of shapes? I think that'd be hard, right? Various ways to accelerate different dot variants are very important "innovations" these days. And we have different levels of support for various element types anyway.

I feel it might make sense to be less restrictive here and let the backend decide how best to lower it, and/or reject it if it cannot be supported?

else:
    assert lhs.shape[-2].value >= 16 and lhs.shape[-1].value >= 16 \
           and rhs.shape[-1].value >= 16, \
        f"All non-batch values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!"
if lhs.type.scalar.is_int():

The int type requirement is specific to CUDA, so we should remove it for the AMD backend.
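
As a hedged sketch of what "remove it for the AMD backend" could look like, where the is_cuda flag and the concrete int8 constraint are both assumptions for illustration rather than the file's actual logic:

    # Hypothetical illustration only: gate a backend-specific integer requirement.
    # How semantic.py actually queries the target, and what the CUDA-only
    # requirement is, may differ from this sketch.
    def check_int_dot(element_is_int8: bool, K: int, is_cuda: bool) -> None:
        if element_is_int8 and is_cuda:
            # Example of a CUDA-only constraint that would not apply to the AMD backend.
            assert K >= 32, "int8 dot assumed to require K >= 32 on the CUDA backend in this sketch"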

@@ -1319,6 +1319,9 @@ def _str_to_dot_input_precision(input_precision, builder):
def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optional[str], max_num_imprecise_acc: int,
        out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:

    def support_m16n16k8():

For fp8 and int8 on MI300, the mfma instructions are 32x32x16 and 16x16x32, which are not applicable here.
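
A hedged sketch of that point, where the element-type names and concrete values are assumptions for illustration rather than code from this PR:

    # Illustrative only: the minimum K an MFMA tile needs depends on the element
    # type, so a single m16n16k8 check does not cover fp8/int8 on MI300, where the
    # MFMA shapes are 32x32x16 and 16x16x32. Values below are assumptions.
    def min_mfma_k(element_type: str) -> int:
        if element_type in ("fp8", "int8"):
            return 16   # smallest K among the 32x32x16 / 16x16x32 variants
        return 8        # assumed: the case the m16n16k8 helper targets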
