
_allow_non_fake_inputs parameter of make_fx has no effect #125347

Open
ttrouwen-dmatrix opened this issue May 1, 2024 · 1 comment
Labels: module: fakeTensor, module: fx, oncall: pt2, triaged


ttrouwen-dmatrix commented May 1, 2024

🐛 Describe the bug

Running make_fx with _allow_non_fake_inputs=True still raises an exception saying that the inputs are not fake:

Exception: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.t.default(Parameter containing:
tensor([...], size=(64, 64), requires_grad=True))

Note that this issue does not reproduce on PyTorch 2.0 but does occur on PyTorch 2.1 through 2.3.
Full reproducer:

import torch
import torch.nn as nn
from torch.fx.experimental.proxy_tensor import make_fx

class M(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.l = nn.Linear(64, 64)
    
    def forward(self, inp):
        return self.l(inp)

m = M()
inp = torch.zeros(64, 64)

def backend(gm, inps):
    input_cloned = [i.clone().detach() for i in inps][0]
    # inputs_cloned = [i.clone() if isinstance(i, torch.Tensor) else i for i in inputs]
    gm = make_fx(
        gm,
        tracing_mode="fake",
        _allow_non_fake_inputs=True,
        decomposition_table=None,
    )(input_cloned)

    return gm

fn = torch.compile(m, backend=backend)
fn(inp)

Versions

PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.29.2
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.2.0-1018-azure-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 57 bits virtual
Byte Order:                         Little Endian
CPU(s):                             8
On-line CPU(s) list:                0-7
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
CPU family:                         6
Model:                              106
Thread(s) per core:                 2
Core(s) per socket:                 4
Socket(s):                          1
Stepping:                           6
CPU max MHz:                        2800.0000
CPU min MHz:                        800.0000
BogoMIPS:                           5586.87
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid aperfmperf pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid fsrm arch_capabilities
Virtualization:                     VT-x
Hypervisor vendor:                  Microsoft
Virtualization type:                full
L1d cache:                          192 KiB (4 instances)
L1i cache:                          128 KiB (4 instances)
L2 cache:                           5 MiB (4 instances)
L3 cache:                           48 MiB (1 instance)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-7
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed:             Vulnerable
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Vulnerable
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.3.0
[pip3] triton==2.3.0
[conda] Could not collect

cc @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv @msaroufim @bdhirsh @anijain2305 @chauhang @eellison

ttrouwen-dmatrix (author) commented:

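This occurs because a call to detect_fake_mode was added to make_fx. If a fake mode is detected (as in the reproducer, where it comes from torch.compile), make_fx reuses that mode, and _allow_non_fake_inputs is only applied when make_fx creates a new FakeTensorMode itself. In the detected case the argument is therefore ignored.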
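The control flow described here can be modeled without PyTorch. The following is a minimal, torch-free sketch; `FakeModeStub`, `select_mode_buggy`, and `select_mode_patched` are hypothetical names that only mimic the logic described in this comment and in the workaround, not PyTorch's actual code:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeModeStub:
    """Hypothetical stand-in for FakeTensorMode; only models the one flag."""
    allow_non_fake_inputs: bool = False


def select_mode_buggy(detected: Optional[FakeModeStub], allow_non_fake: bool) -> FakeModeStub:
    # Mirrors the reported behavior: the flag is honored only when a
    # brand-new mode is created; a detected mode is reused unchanged.
    if detected is None:
        return FakeModeStub(allow_non_fake_inputs=allow_non_fake)
    return detected


def select_mode_patched(detected: Optional[FakeModeStub], allow_non_fake: bool) -> FakeModeStub:
    # Mirrors the workaround: the flag is also copied onto a reused mode.
    if detected is None:
        return FakeModeStub(allow_non_fake_inputs=allow_non_fake)
    detected.allow_non_fake_inputs = allow_non_fake
    return detected


# A fake mode already exists (e.g. one set up by torch.compile) without the flag.
outer = FakeModeStub(allow_non_fake_inputs=False)
print(select_mode_buggy(outer, True).allow_non_fake_inputs)    # False: flag dropped
print(select_mode_patched(outer, True).allow_non_fake_inputs)  # True
```

Note that the patched variant mutates the shared mode, so the flag stays set after tracing finishes.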

As a workaround, the relevant lines in make_fx can be changed to the following:

elif tracing_mode == "fake":
    import torch._dynamo
    fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args)
    if fake_tensor_mode is None:
        ...  # unchanged: existing code that creates a new FakeTensorMode
    else:  # This line and the next line should be inserted after line 1193
        fake_tensor_mode.allow_non_fake_inputs = _allow_non_fake_inputs
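The workaround above sets the flag on the detected mode and leaves it set after make_fx returns. If patching make_fx is undesirable, a similar effect might be achieved from user code by enabling the flag on the detected mode only for the duration of the trace. This is a minimal sketch, assuming the detected mode exposes the same `allow_non_fake_inputs` attribute that the workaround assigns; the helper `temporarily_allow_non_fake_inputs` and the `_DemoMode` stub are hypothetical:

```python
import contextlib
from typing import Any, Iterator, Optional


@contextlib.contextmanager
def temporarily_allow_non_fake_inputs(
    fake_mode: Optional[Any], enabled: bool = True
) -> Iterator[Optional[Any]]:
    """Set `allow_non_fake_inputs` on an existing fake mode, then restore it.

    `fake_mode` is assumed to expose the boolean `allow_non_fake_inputs`
    attribute that the make_fx workaround assigns. Passing None (no fake
    mode detected) makes this a no-op.
    """
    if fake_mode is None:
        yield None
        return
    previous = fake_mode.allow_non_fake_inputs
    fake_mode.allow_non_fake_inputs = enabled
    try:
        yield fake_mode
    finally:
        # Restore the original value even if tracing raises.
        fake_mode.allow_non_fake_inputs = previous


class _DemoMode:  # hypothetical stand-in for a detected FakeTensorMode
    allow_non_fake_inputs = False


mode = _DemoMode()
with temporarily_allow_non_fake_inputs(mode):
    print(mode.allow_non_fake_inputs)  # True inside the block
print(mode.allow_non_fake_inputs)      # False again afterwards
```

In the reproducer's backend, the existing make_fx call could be placed inside `with temporarily_allow_non_fake_inputs(torch._dynamo.utils.detect_fake_mode(inps)):`. This has not been verified against PyTorch 2.1 through 2.3; restoring the previous value is meant to avoid leaving the shared fake mode permissive for later compilation steps.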

yanboliang added the triaged label on May 6, 2024.
Development

No branches or pull requests

3 participants