Add tests for PyTorch operators #17

Open
ScottTodd opened this issue Aug 21, 2024 · 1 comment
Labels
enhancement New feature or request

Comments

@ScottTodd
Member

Could look at the Python test cases in https://github.com/pytorch/pytorch/tree/main/test, or lean more on tests in torch-mlir: https://github.com/llvm/torch-mlir/tree/main/test. There are also some tests in https://github.com/nod-ai/SHARK-TestSuite/tree/main/e2eshark/pytorch/operators.

Generally looking for proof that the work tracked in nod-ai/SHARK-ModelDev#347 has fully landed in IREE and is supported across the full range of targets.

@ScottTodd
Member Author

From @Groverkss here on Discord:

I've started using iree-turbine to generate tests. Writing IR for something like gather + attention fusion is somewhat hard for me; it's much easier to just write it in iree-turbine.
This is the template I used:

import torch
import torch.nn as nn

import iree.turbine.aot as aot

class GatherAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, q: torch.Tensor, k_mem: torch.Tensor,
                k_indices: torch.Tensor, v: torch.Tensor):
        k = torch.gather(k_mem, dim=2, index=k_indices)
        out = nn.functional.scaled_dot_product_attention(q, k, v)
        return out

model = GatherAttention()

q = torch.randn(2, 10, 4096, 64, dtype=torch.float16)
k_mem = torch.randn(2, 10, 4096, 64, dtype=torch.float16)
v = torch.randn(2, 10, 4096, 64, dtype=torch.float16)
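# torch.gather indices along dim=2 must lie in [0, 4096); arange over the
# full element count would go out of bounds when run eagerly.
k_indices = torch.randint(0, 4096, (2, 10, 4096, 64))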

exported = aot.export(model, q, k_mem, k_indices, v)
exported.print_readable()

Would recommend generating tests like this, because you can also check correctness against torch this way.
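
A minimal sketch of what that correctness check could look like, assuming small float32 shapes so it runs quickly on CPU. The eager PyTorch output is treated as the expected value. The `reference` function is a hypothetical stand-in for the output of the compiled module (running the IREE-compiled artifact is outside this sketch), and the tolerances are assumptions, not values from this thread.

```python
import torch
import torch.nn as nn


class GatherAttention(nn.Module):
    def forward(self, q, k_mem, k_indices, v):
        k = torch.gather(k_mem, dim=2, index=k_indices)
        return nn.functional.scaled_dot_product_attention(q, k, v)


def reference(q, k_mem, k_indices, v):
    # Hypothetical stand-in for the compiled module's output:
    # gather followed by hand-written softmax attention.
    k = torch.gather(k_mem, 2, k_indices)
    scale = q.shape[-1] ** -0.5  # default scale used by SDPA
    weights = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    return weights @ v


torch.manual_seed(0)
shape = (1, 2, 8, 4)  # (batch, heads, seq_len, head_dim), kept small on purpose
q = torch.randn(shape)
k_mem = torch.randn(shape)
v = torch.randn(shape)
# Indices into dim=2 must be in [0, seq_len).
k_indices = torch.randint(0, shape[2], shape)

expected = GatherAttention()(q, k_mem, k_indices, v)  # eager torch result
actual = reference(q, k_mem, k_indices, v)            # result under test

# Raises AssertionError with a mismatch report if the results diverge.
torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-5)
```

In a real test, `actual` would come from whatever runs the exported program on the target under test, and float16 inputs would likely need looser tolerances.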
