[AMD] Relax the restriction of dot shape >= 16 #3908
base: main
Conversation
Do you mean
Nice! Do you have plans to support 8x8x8?
I am not sure what is a
Hi @YixinSong-e, I think we should support any size in the front-end and let the backend decide how to lower it down. But we need other people to agree with this :)
I believe slices for wmma layouts are completely supported
Ah, I see, thanks. FYI: I've experimented with 3 types of layouts that use mfma4x4 in the ROCm fork.
So far, the most promising layout is the third one, but its use is limited because of the large difference in size between the first and second operands.
So I guess my point is that there are two natural next steps to this PR:
I think the second point is not super important, because many frameworks simply use reduction mfma. It might be that we also want to skip the first point if there is higher-priority work to do.
Can you provide a bit more info on the motivation? It sounds like this breaks portability.
Hi @ThomasRaoux, the point is that the AMD backend can accelerate smaller sizes than 16x16x16; that's why we are trying to add this relaxation in the frontend.
Do we require all implementations to support the same set of shapes? I think that would be hard, right? Various ways to accelerate different dot variants are very important "innovations" these days. And we have different levels of support for various element types anyway. I feel it might make sense to be less restrictive here and let the backend decide how best to lower it, and/or reject it if it cannot be supported.
else:
    assert lhs.shape[-2].value >= 16 and lhs.shape[-1].value >= 16 \
        and rhs.shape[-1].value >= 16, \
        f"All non-batch values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!"
if lhs.type.scalar.is_int():
The int type requirement is specific to CUDA, so we should remove it for the AMD backend.
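As a rough illustration of that suggestion (not the PR's actual code), the checks in semantic.py could be gated on the target; the _is_amd_backend helper and the builder.options.backend_name attribute below are assumptions for the sketch, not real Triton API.

# Hypothetical sketch only: gate the dot checks on the target backend.
# _is_amd_backend() and builder.options.backend_name are assumptions,
# not actual Triton APIs.
def _is_amd_backend(builder):
    return getattr(builder.options, "backend_name", None) == "hip"

def _check_dot_shapes(lhs, rhs, builder):
    M = lhs.shape[-2].value
    K = lhs.shape[-1].value
    N = rhs.shape[-1].value
    if _is_amd_backend(builder):
        # AMD MFMA lowering can go down to K == 8 (the point of this PR).
        assert M >= 16 and N >= 16 and K >= 8, \
            f"AMD backend requires M, N >= 16 and K >= 8, got ({M}, {N}, {K})"
        # The int-type checks that follow in semantic.py would be skipped
        # here, since they are CUDA-specific.
    else:
        assert M >= 16 and N >= 16 and K >= 16, \
            f"All non-batch values in both first input shape ({lhs.shape}) " \
            f"and second input shape ({rhs.shape}) must be >= 16!"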
@@ -1319,6 +1319,9 @@ def _str_to_dot_input_precision(input_precision, builder):
def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optional[str], max_num_imprecise_acc: int,
        out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:

def support_m16n16k8():
For fp8 and int8 on MI300, the mfma instructions are 32x32x16 and 16x16x32, which are not applicable here.
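The diff only shows the signature of support_m16n16k8(), so the body below is a guess at the kind of check it might perform, written as if the function were nested inside dot() so it can see lhs and rhs; the dtype and size tests are assumptions that fold in the reviewer's note about fp8/int8 MFMA tile shapes on MI300.

# Hypothetical body, not the PR's implementation. Assumes the function is a
# closure inside dot() with access to lhs and rhs.
def support_m16n16k8():
    M = lhs.shape[-2].value
    K = lhs.shape[-1].value
    N = rhs.shape[-1].value
    # Per the review comment, fp8/int8 MFMA variants on MI300 use 32x32x16
    # and 16x16x32 tiles, so the 16x16x8 relaxation would not apply to them.
    if lhs.type.scalar.is_fp8() or lhs.type.scalar.is_int8():
        return False
    return M >= 16 and N >= 16 and K >= 8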
This is my first PR in Triton. It tries to relax the limitation on dot that requires sizes of at least (M,N,K)==(16,16,16). I modified semantic.py to relax the tt.dot limitations on the size of the matrices (for gfx9 architectures). Please note that there is a supportMFMA function that only accepts M/N sizes that are multiples of 16 and K sizes that are multiples of 8. Based on that, I relaxed the restriction in semantic.py to support (M,N,K)>=(16,16,8). This is the minimal change.

If we want to push further, we would need to change supportMFMA and add tests for smaller layouts (many of those smaller layouts are broadcast layouts; do we support those in the AMD backend?). Please note: for now, if I try to feed a smaller layout (e.g., 8x8x8), the test fails with mismatches.
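For context, here is a minimal, hypothetical Triton kernel and host snippet that exercises the shape this PR is trying to allow, a single tl.dot with (M,N,K)=(16,16,8); on current main the semantic.py assertion rejects K=8, and even with this change the backend still has to be able to lower it.

import torch
import triton
import triton.language as tl

# One-tile GEMM: C[16, 16] = A[16, 8] @ B[8, 16]
@triton.jit
def small_dot_kernel(a_ptr, b_ptr, c_ptr,
                     M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    # Row-major loads of the full A (M x K) and B (K x N) tiles.
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    c = tl.dot(a, b)  # K == 8 is what the relaxed front-end check would permit
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)

# PyTorch exposes ROCm devices under the "cuda" device string as well.
a = torch.randn(16, 8, device="cuda", dtype=torch.float16)
b = torch.randn(8, 16, device="cuda", dtype=torch.float16)
c = torch.empty(16, 16, device="cuda", dtype=torch.float32)
small_dot_kernel[(1,)](a, b, c, M=16, N=16, K=8)
torch.testing.assert_close(c, (a @ b).float(), rtol=1e-2, atol=1e-2)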