
[AMD][WMMA] Support dot3d #3674

Merged: 4 commits merged into triton-lang:main on May 28, 2024

Conversation

@binarman (Contributor) commented Apr 16, 2024:

This PR enables support of 3d dot for RDNA GPUs.

@binarman (Contributor, Author):

+cc @joviliast

@@ -1649,7 +1676,7 @@ AMDWmmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
     unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperands(
         ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
      auto rep = getWMMARepForOperands(shape, eltTy, kWidth, opIdx);
    - return rep[0] * rep[1] * kWidth;
    + return rep[0] * rep[1] * rep[2] * kWidth;
@joviliast (Contributor) commented Apr 16, 2024:

Could we use something like
return product(rep) * kWidth; ?
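
A minimal sketch of the suggested variant, folding all repetition dimensions with std::accumulate (illustrative only; whatever product helper the Triton codebase actually provides may look different):

    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Fold every per-dimension repetition count (batch, M/N, K) instead of
    // hard-coding rep[0] * rep[1] * rep[2], then scale by kWidth.
    unsigned totalElemsPerThreadForOperands(const std::vector<int64_t> &rep,
                                            unsigned kWidth) {
      int64_t prod = std::accumulate(rep.begin(), rep.end(), int64_t{1},
                                     std::multiplies<int64_t>());
      return static_cast<unsigned>(prod) * kWidth;
    }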

Comment on lines 100 to 101:

    unsigned mfmaInstrNonK = elemsPerInstr[opIdx == 0 ? 0 : 1];
    unsigned mfmaInstrK = elemsPerInstr[opIdx == 0 ? 1 : 0];
Contributor:

As long as this has become common logic, could you please rename it?
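
A hedged sketch of the requested rename, with illustrative identifiers only (the names actually chosen in the PR may differ); the point is to drop the MFMA-specific prefix now that the logic is shared with WMMA:

    // Illustrative names only: pick the non-K and K extents of the per-instruction
    // tile depending on which dot operand is handled (A: opIdx == 0, B: opIdx == 1).
    unsigned instrNonKDim = elemsPerInstr[opIdx == 0 ? 0 : 1];
    unsigned instrKDim    = elemsPerInstr[opIdx == 0 ? 1 : 0];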

@joviliast (Contributor):

LGTM. Thank you for this PR.

Have you run test_dot locally on Navi?

@binarman (Contributor, Author):

> Have you run test_dot locally on Navi?

Yes:

  • fp16->fp32 tests pass
  • fp16->fp16: some tests pass, some fail due to mismatches (a few values exceed the tolerance threshold)
  • fp32->fp32 is not supported in WMMA; these tests go through the FMA pipeline and fail
  • int8 tests do not pass, but this is expected: test_dot does not work for them at the moment either

@joviliast (Contributor) left a review comment:

Thanks. LGTM

@antiagainst (Collaborator) left a review comment:

Can we add some lit tests? At the moment we don't have CI for RDNA GPUs, so test_core.py is effectively not checked and may regress at any time. Lit tests check the compiler transformation and give us some guarantee. It's also easier to read and fix lit tests than full-blown integration runtime tests, so lit tests are typically the first line of defense for quality.

@binarman (Contributor, Author) commented May 1, 2024:

@antiagainst hi!
In my experience, most dot issues are related to wrong indexing/address computations, which require a large and complex lit test to check. Such a test would be extremely fragile and hard to read.

@joviliast is working on the same code, so even if I add a lit test, it will probably break in the near future, creating redundant work for him (or for me, depending on who merges their changes first).

I can implement a basic test that checks there are no crashes, but in my opinion such a test does not guarantee much.

P.S. We have a basic LLIR interpreter that could help check the changes from this PR, but at this point it requires massive work. I would prefer to invest time in that task if correctness on Navi aligns with our team's priorities.

@binarman (Contributor, Author):

@antiagainst PTAL

python/test/unit/language/test_core.py:
    if triton.runtime.driver.active.get_current_target().arch == "gfx1100":
        if in_dtype_str == "int8" or in_dtype_str == "float32":
            pytest.skip(f"{in_dtype_str} is not supported in WMMA dot")
        if out_dtype_str == "float16":
Collaborator:

There are float16-accumulate WMMA ops? Are they not matching the precision w.r.t. the reference PyTorch?

Contributor (Author):

I need to check this. At some point they did not match, but maybe that is no longer the case, since a lot of time has passed since I implemented this.

@binarman (Contributor, Author) commented May 23, 2024:

Yes, the precision issue is still there. I suspect this is a hardware problem, though it requires more investigation of WMMA behavior.

Collaborator:

Okay, thanks. Worth understanding more. I think we can also promote to f32 and then downcast if necessary.

Contributor:

@binarman Are we currently using V_WMMA_F16_16X16X16_F16 and seeing an accuracy mismatch with PyTorch? If so, can we use V_WMMA_F32_16X16X16_F16 and then cast to fp16, as @antiagainst mentioned?
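
For illustration only, a minimal scalar sketch of the promote-then-downcast idea under discussion: it is not the Triton lowering, it just shows accumulating fp16 products in fp32 and truncating once at the end, which mirrors using V_WMMA_F32_16X16X16_F16 plus a final cast instead of V_WMMA_F16_16X16X16_F16. The use of _Float16 assumes a compiler that supports that extension:

    // One output element of a 16x16x16 tile: fp16 inputs, fp32 accumulator,
    // single conversion back to fp16 at the end.
    constexpr int K = 16;

    _Float16 dotRowCol(const _Float16 *aRow, const _Float16 *bCol) {
      float acc = 0.0f; // fp32 accumulator avoids per-step fp16 rounding
      for (int k = 0; k < K; ++k)
        acc += static_cast<float>(aRow[k]) * static_cast<float>(bCol[k]);
      return static_cast<_Float16>(acc); // downcast once
    }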

    assert(shape[0] % mnkDim[0] == 0);
    multiDimWarpId[0] =
        urem(multiDimWarpId[0], i32_val(ceil<unsigned>(shape[0], mnkDim[0])));
    if (shape[rank - 2] >= mnkDim[0]) {
Collaborator:

We have quite a few duplicated shape[rank - N] references. What about using some self-documenting local variables for them? Then there is less chance of being inconsistent.
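
A hedged sketch of that suggestion, with illustrative variable names (as discussed in the reply below, the PR keeps the existing shape[rank - N] style for now):

    // Name the trailing dimensions once instead of repeating shape[rank - N]
    // throughout the indexing code.
    const size_t rank = shape.size();
    const int64_t mDim = shape[rank - 2]; // M extent of the output tile
    const int64_t nDim = shape[rank - 1]; // N extent of the output tile
    // ... then, e.g.:
    // if (mDim >= mnkDim[0]) { ... }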

@binarman (Contributor, Author) commented May 23, 2024:

This PR is intended to support dot3d; I suggest refactoring this code as a separate task (FYI, we discussed this some time ago: #3600 (comment)). A lot of this code is the same on the MFMA side, and it would be better to refactor both MFMA and WMMA at the same time.

We have two ideas for how to refactor this code:

  • always assume the dot has a batch dimension
  • use a structure with named fields, i.e. M/N/K/B instead of indexes (see the sketch below)

Choosing one of these paths is a separate task, which will be the next step after test bring-up.
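
A hedged sketch of the second idea, using hypothetical names (no such struct exists in the codebase; this only illustrates addressing dimensions by name instead of by index):

    #include <cstdint>
    #include <vector>

    // Hypothetical: tile dimensions addressed by name, so batch handling does
    // not rely on shape[rank - N] arithmetic.
    struct DotDims {
      int64_t batch; // B, 1 for the 2D case
      int64_t m;     // M
      int64_t n;     // N
      int64_t k;     // K
    };

    // Illustrative helper: derive named dims for the result shape of a 2D or
    // 3D dot, with K supplied separately.
    DotDims makeDotDims(const std::vector<int64_t> &resultShape, int64_t k) {
      const bool hasBatch = resultShape.size() == 3;
      return DotDims{hasBatch ? resultShape[0] : 1,
                     resultShape[hasBatch ? 1 : 0],
                     resultShape[hasBatch ? 2 : 1],
                     k};
    }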

Collaborator:

Sounds good to follow up on this later.

    for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) {
      offsets.push_back(
          {ctaOffsetX * shapePerCta[0] + 2 * elem, ctaOffsetY * shapePerCta[1]});
      elemOffset[rank - 2] = ctaOffsetX * shapePerCta[rank - 2] + 2 * elem;
Collaborator:

Minus is less mentally straightforward than plus. I'd suggest doing bool hasBatch = rank == 3; and then using [0 + hasBatch] for the M index and [1 + hasBatch] for the N index.
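
A minimal sketch of the suggested idiom (illustrative only; as the reply below notes, the PR keeps the existing [rank - N] style for consistency):

    // Count forward from the optional batch dimension instead of backwards
    // from the rank: index 0 + hasBatch is M, 1 + hasBatch is N.
    const bool hasBatch = (rank == 3);
    elemOffset[0 + hasBatch] = ctaOffsetX * shapePerCta[0 + hasBatch] + 2 * elem; // M
    elemOffset[1 + hasBatch] = ctaOffsetY * shapePerCta[1 + hasBatch];            // N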

Contributor (Author):

I would prefer not to change this now: all the other places like this use the minus style. My suggestion is to make this refactoring a separate task.

Collaborator:

Sure, works for me.

lib/Dialect/TritonGPU/IR/Dialect.cpp (review threads resolved)
@antiagainst (Collaborator) left a review comment:

(Sorry, clicked the wrong button before.)

@antiagainst (Collaborator) commented May 23, 2024:

Regarding tests, I treat this as a question of what we want to invest to guard against future breakages. We don't have RDNA CI, so this can easily regress. Compared to the effort spent writing some tests now (which is mostly one-time), I'm more concerned about the potential time lost in the future debugging all this complex logic only through integration Python tests. And we don't know how many regressions we will see throughout the journey.

Also, btw, lit tests don't need to be super detailed and cover all the lines; we can just cover the important parts so they don't become a change detector. I don't think it's a lot of effort to update them, given that the index calculation doesn't change frequently, I believe. And whatever we change there is deliberate, so it can help folks touching the code verify their changes too. (Keep in mind that there are contributors who only work on the MFMA parts; they will not run their changes on RDNA cards to verify things pass, let alone folks only touching NVIDIA support. But a lit test runs everywhere and can give us guarantees.)

@antiagainst (Collaborator):

This PR has extensive indexing calculation, so +@zhanglx13 to double-check too.


@antiagainst marked this pull request as ready for review on May 28, 2024 at 20:27.
@antiagainst removed the request for review from Jokeren on May 28, 2024 at 20:31.
@antiagainst merged commit 100e2aa into triton-lang:main on May 28, 2024. 6 checks passed.