
Weird AST constructor issue with mode="max-autotune" with python 3.11 #125374

Closed
Chillee opened this issue May 2, 2024 · 8 comments

Chillee (Contributor) commented May 2, 2024

🐛 Describe the bug

Traceback (most recent call last):
  File "/data/users/chilli/pytorch/torch/_inductor/triton_heuristics.py", line 383, in _precompile_config
    binary = triton.compile(*compile_args, **compile_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/chilli/triton_source/python/triton/compiler/compiler.py", line 268, in compile
    module = src.make_ir(options, codegen_fns, context)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/chilli/triton_source/python/triton/compiler/compiler.py", line 112, in make_ir
    return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/chilli/triton_source/python/triton/runtime/jit.py", line 553, in parse
    raise e
  File "/data/users/chilli/triton_source/python/triton/runtime/jit.py", line 551, in parse
    tree = ast.parse(self.src)
           ^^^^^^^^^^^^^^^^^^^
SystemError: AST constructor recursion depth mismatch (before=45, after=42)
backend='inductor' raised:
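The failing frame is Triton's `runtime/jit.py`, which ultimately calls the stdlib `ast.parse` on the kernel's source text. A simplified sketch of that call (`parse_kernel_source` is a hypothetical helper name, not Triton's actual API):

```python
import ast
import textwrap

def parse_kernel_source(src: str) -> ast.Module:
    # Mirrors what triton's runtime/jit.py does, simplified: parse the
    # kernel's source text into a Python AST. On CPython 3.11.0-3.11.7
    # this call is not safe from multiple threads at once, because the
    # AST constructor kept global recursion-depth state.
    return ast.parse(textwrap.dedent(src))

tree = parse_kernel_source("""
def kernel(x):
    return x + 1
""")
print(type(tree).__name__)  # -> Module
```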

Repro:

import torch
import os
import sys
import shutil

directory_path = '/tmp/torchinductor_chilli/'
if os.path.exists(directory_path) and os.path.isdir(directory_path):
    shutil.rmtree(directory_path)
import torch._inductor.config
torch.set_default_device('cuda')

torch._dynamo.config.automatic_dynamic_shapes = False
# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000

@torch.compile(mode="max-autotune-no-cudagraphs")
def f(a, b):
    return torch.mm(a, b)

try:
    for N in range(512, 1024, 16):
        print(N)
        a = torch.randn(N, N, dtype=torch.bfloat16)
        b = torch.randn(N, N, dtype=torch.bfloat16)
        f(a, b)
except Exception as e:
    print(e)
    sys.exit(1)
sys.exit(0)

I bisected back to this commit: #124030

cc: @eellison

cc: @cpuhrsch

Setting TORCHINDUCTOR_COMPILE_THREADS=1 also fixes it.
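That the single-thread setting fixes it is consistent with a parser thread-safety bug. A hypothetical minimal reconstruction of the race (not inductor's actual code): the precompilation pool calls `triton.compile`, and thus `ast.parse`, from several worker threads at once.

```python
import ast
import concurrent.futures

# On CPython 3.11.0-3.11.7 the AST constructor kept global
# recursion-depth state, so concurrent parses could raise:
#   SystemError: AST constructor recursion depth mismatch
SRC = "def f(a, b):\n    return a + b\n"

def parse_once(_):
    # Stand-in for one precompile worker's ast.parse call.
    return ast.parse(SRC)

with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
    trees = list(pool.map(parse_once, range(64)))

# On a fixed interpreter (>= 3.11.8), every parse succeeds.
print(len(trees))  # -> 64
```

With `TORCHINDUCTOR_COMPILE_THREADS=1` there is only ever one parser running, so the global state is never contended.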

Versions

python=3.11

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang

eellison (Contributor) commented May 2, 2024

I just set up a pytorch-311 build and can't repro this. Python 3.11.9. Do you have to run multiple times to repro, or does it happen consistently?

Chillee (Contributor, Author) commented May 2, 2024

@eellison hmm.. this script repros it consistently for me. Like, when you run it, it goes through compilation for all the shapes without erroring?

eellison (Contributor) commented May 2, 2024

That's correct

cpuhrsch (Contributor) commented May 2, 2024

I tried to get it to repro on CI in #125331, but it seems to be green, with only an unrelated failure. The error is definitely real, though, because we see it in AO here: pytorch/ao#197

eellison (Contributor) commented May 2, 2024

Which Python version are you on, specifically?

eellison (Contributor) commented May 2, 2024

There is an AST parsing bug that was fixed in 3.11.8: pytest-dev/pytest#11724 (comment).

I'm using 3.11.9 here, so maybe that explains it.

Can @Chillee or @cpuhrsch share your exact 3.11 versions?

Chillee (Contributor, Author) commented May 2, 2024

Yeah, I'm on 3.11.7.

eellison (Contributor) commented May 2, 2024

I guess we can disable precompilation from 3.11.0 to 3.11.7 and print a warning. The upstream report describes the bug as one

"which has been around in that form in python3.11 releases since 3.11.0b4 up to and including 3.11.7."
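The version gate could look something like the sketch below (hypothetical helper, not the code PyTorch actually landed):

```python
import sys
import warnings

def precompile_in_threads_is_safe(version_info=sys.version_info) -> bool:
    # Thread-based precompilation is unsafe on CPython 3.11.0 through
    # 3.11.7 because the AST constructor state was global; the fix
    # shipped in 3.11.8.
    return not ((3, 11, 0) <= version_info[:3] <= (3, 11, 7))

if not precompile_in_threads_is_safe():
    warnings.warn(
        "Disabling parallel Triton precompilation: this CPython release "
        "has a thread-safety bug in ast.parse (fixed in 3.11.8)."
    )
```

Checking the three-tuple against an inclusive range keeps 3.10.x, 3.11.8+, and 3.12+ on the fast multi-threaded path.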

eellison added a commit that referenced this issue May 2, 2024
…karound 311 cpython bug"


Fix for #125374. We don't have CI for these specific versions, but I verified locally. There is a cpython bug from 3.11.0->3.11.7 where the AST parsing state is global and errors with multiple threads. When the dust settles a little around the new process-based compilation, we can look into migrating.


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
eellison added a commit that referenced this issue May 3, 2024
…karound 311 cpython bug"


Fix for #125374. We don't have CI for these specific versions, but I verified locally. There is a cpython bug from 3.11.0->3.11.7 where the AST parsing state is global and errors with multiple threads. When the dust settles a little around the new process-based compilation, we can look into migrating.


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this issue May 3, 2024
… cpython bug (#125446)

Fix for #125374. We don't have CI for these specific versions, but I verified locally. There is a cpython bug from 3.11.0->3.11.7 where the AST parsing state is global and errors with multiple threads. When the dust settles a little around the new process-based compilation, we can look into migrating.

Pull Request resolved: #125446
Approved by: https://github.com/Chillee
ghstack dependencies: #125289
eellison closed this as completed May 3, 2024