internal assert failure I saw while working on Thunder #3498

Open
crcrpar opened this issue Nov 29, 2024 · 4 comments
crcrpar commented Nov 29, 2024

# CUDA devices:
#  0: NVIDIA RTX 6000 Ada Generation
# torch version: 2.6.0a0+git62eea62
# cuda version: 12.6
# nvfuser version: 0.2.23+git7b92716
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id4(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[64], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T1 = fd.define_tensor(shape=[16, 64], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T5 = fd.ops.broadcast_in_dim(T0, shape=[16, 64], broadcast_dims=[1])
    T6 = fd.ops.cast(T1, dtype=DataType.Float)
    T7 = fd.ops.cast(T5, dtype=DataType.Float)
    T8 = fd.ops.add(T6, T7)
    T9 = fd.ops.mul(T8, T8)
    T10 = fd.ops.mul(T9, T8)
    S11 = fd.define_scalar(0.500000, dtype=DataType.Double)
    T12 = fd.ops.mul(S11, T8)
    S13 = fd.define_scalar(0.0447150, dtype=DataType.Double)
    T14 = fd.ops.mul(S13, T10)
    T15 = fd.ops.add(T8, T14)
    S16 = fd.define_scalar(0.797885, dtype=DataType.Double)
    T17 = fd.ops.mul(S16, T15)
    T18 = fd.ops.tanh(T17)
    S19 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T20 = fd.ops.add(S19, T18)
    T21 = fd.ops.mul(T12, T20)
    T22 = fd.ops.abs(T21)
    T23 = fd.ops.max(T22, dims=[0, 1], keepdim=False, dtype=DataType.Null)
    T24 = fd.ops.cast(T23, dtype=DataType.Double)
    T25 = fd.ops.ne(T24, T24)
    S26 = fd.define_scalar(1.00000e-12, dtype=DataType.Double)
    T27 = fd.ops.gt(T24, S26)
    S28 = fd.define_scalar(1.00000e-12, dtype=DataType.Double)
    T29 = fd.ops.where(T27, T24, S28)
    T30 = fd.ops.where(T25, T24, T29)
    S31 = fd.define_scalar(448.000, dtype=DataType.Double)
    T32 = fd.ops.reciprocal(T30)
    T33 = fd.ops.mul(S31, T32)
    T34 = fd.ops.cast(T33, dtype=DataType.Float)
    T38 = fd.ops.broadcast_in_dim(T34, shape=[16, 64], broadcast_dims=[])
    T39 = fd.ops.mul(T21, T38)
    T40 = fd.ops.ne(T39, T39)
    S41 = fd.define_scalar(-448.000, dtype=DataType.Double)
    T42 = fd.ops.gt(T39, S41)
    S43 = fd.define_scalar(-448.000, dtype=DataType.Double)
    T44 = fd.ops.where(T42, T39, S43)
    T45 = fd.ops.where(T40, T39, T44)
    T46 = fd.ops.ne(T45, T45)
    S47 = fd.define_scalar(448.000, dtype=DataType.Double)
    T48 = fd.ops.lt(T45, S47)
    S49 = fd.define_scalar(448.000, dtype=DataType.Double)
    T50 = fd.ops.where(T48, T45, S49)
    T51 = fd.ops.where(T46, T45, T50)
    fd.add_output(T34)
    fd.add_output(T51)

with FusionDefinition() as fd:
    nvfuser_fusion_id4(fd)

inputs = [
    torch.testing.make_tensor((64,), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((16, 64), dtype=torch.bfloat16, device='cuda:0'),
]
fd.execute(inputs)
Traceback (most recent call last):
  File "/home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/__init__.py", line 317, in execute
    results = self._execute(
RuntimeError:  INTERNAL ASSERT FAILED at "/home/mkozuki/ghq/github.com/crcrpar/Fuser/csrc/device_lower/analysis/sync_information.cpp":827, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Inconsistent parallelization found between TV52 (T52_l___bfloat[ iS422{1}, iUS424{1}, ithreadIdx.x425{16}_p, iV421{4} ]) and TV2(T2_l___bfloat[ iS199{1}, iUS201{1}, ithreadIdx.x202{16}_p, iS198{4} ] ca_pos( 4 )). Producer is required to be in Global or Shared Memory based on parallelization strategy. RAW flags: (threadIdx.x)
Exception raised from SyncMap at /home/mkozuki/ghq/github.com/crcrpar/Fuser/csrc/device_lower/analysis/sync_information.cpp:827 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xcb (0x72ef20c4811b in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x3b (0x72ef20c4837b in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: nvfuser::SyncMap::SyncMap(nvfuser::Fusion*) + 0x35bd (0x72ef20b0873d in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x33b20f (0x72ef20b3b20f in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: nvfuser::GpuLower::GpuLower(nvfuser::Fusion*, nvfuser::CompileParams const&) + 0x1579 (0x72ef20b39039 in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: nvfuser::KernelExecutor::compile(nvfuser::Fusion*, nvfuser::KernelArgumentHolder const&, nvfuser::LaunchParams const&, nvfuser::CompileParams, nvfuser::SchedulerType) + 0x7bf (0x72ef20f1abbf in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0x72fc2f (0x72ef20f2fc2f in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: <unknown function> + 0x75fad2 (0x72ef20f5fad2 in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #8: nvfuser::FusionKernelRuntime::compileFusionParallel(nvfuser::KernelArgumentHolder) + 0x9ee (0x72ef20f5ea6e in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #9: nvfuser::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0x179 (0x72ef20f51cb9 in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #10: nvfuser::python_frontend::FusionDefinition::execute(c10::ArrayRef<c10::IValue> const&, std::optional<signed char>, bool, bool, bool, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >) const + 0xacc (0x72ef210cc5ec in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #11: <unknown function> + 0x13c688 (0x72ef2093c688 in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #12: <unknown function> + 0x13ba47 (0x72ef2093ba47 in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #13: <unknown function> + 0x1cb301 (0x72ef209cb301 in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)

Steps to reproduce

import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training
import thunder
from thunder.tests.make_tensor import make_tensor


def main():
    batch_size, in_features, out_features = 16, 32, 64

    device = torch.device("cuda")
    dtype = torch.bfloat16
    bias = True

    model = nn.Sequential(
        nn.Linear(in_features, out_features, bias=bias),
        nn.GELU(approximate="tanh"),
        nn.Linear(out_features, out_features, bias=bias),
    ).to(device=device, dtype=dtype)
    fp8_model = convert_to_float8_training(model)
    x = make_tensor((batch_size, in_features), device=device, dtype=dtype)

    jitted = thunder.jit(fp8_model, executors=[thunder.get_executor("torch"), thunder.get_executor("nvfuser")])
    actual = jitted(x)


if __name__ == "__main__":
    main()

Just FYI, this script works if (see the sketch below):

  • the model is just a single nn.Linear(in_features, out_features, bias=bias)
  • dtype is float32
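
For reference, here is a minimal sketch of the first working variant (a single nn.Linear, reusing the setup and variable definitions from main() above; the second variant keeps the original model and just sets dtype = torch.float32):

# Hypothetical single-Linear variant, reusing batch_size, in_features,
# out_features, bias, device, dtype, and the imports from main() above.
model = nn.Linear(in_features, out_features, bias=bias).to(device=device, dtype=dtype)
fp8_model = convert_to_float8_training(model)
x = make_tensor((batch_size, in_features), device=device, dtype=dtype)
jitted = thunder.jit(fp8_model, executors=[thunder.get_executor("torch"), thunder.get_executor("nvfuser")])
actual = jitted(x)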
@kevinstephano

@crcrpar could we get the Thunder command as well?


crcrpar commented Nov 30, 2024

@kevinstephano I added a "Steps to reproduce" section to the description.


naoyam commented Dec 17, 2024

The fusion is scheduled as an inner normalization kernel with no segmentation. It looks like there's some issue with the persistent buffer logic. T52 is picked as the persistent buffer tensor, which makes its inlining position 0. T2 is an immediate consumer of T52 and has a broadcast ID that doesn't exist in T52. Since T2 is inlined, its broadcast ID seems to be promoted.
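
The core producer/consumer pattern can be reduced to something like the following sketch with the nvfuser Python frontend (a hypothetical reduction of the repro at the top of the issue, not verified to hit the same assert): a 1D input is broadcast into a 2D add, the result feeds a max reduction, and the same pre-reduction value is reused after the reduction, which is what makes that path a persistent-buffer candidate.

import torch
from nvfuser import FusionDefinition, DataType

def reduced_pattern(fd: FusionDefinition) -> None:
    # 1D bias-like input; its cache (the T52 analogue) is the persistent-buffer candidate.
    t0 = fd.define_tensor(shape=[64], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    t1 = fd.define_tensor(shape=[16, 64], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    # Broadcast consumer of the 1D input (the T2 analogue).
    t0_b = fd.ops.broadcast_in_dim(t0, shape=[16, 64], broadcast_dims=[1])
    t_sum = fd.ops.add(fd.ops.cast(t1, dtype=DataType.Float), fd.ops.cast(t0_b, dtype=DataType.Float))
    # Normalization-style reduction to a 0-dim tensor.
    t_max = fd.ops.max(fd.ops.abs(t_sum), dims=[0, 1], keepdim=False, dtype=DataType.Null)
    t_scale = fd.ops.broadcast_in_dim(t_max, shape=[16, 64], broadcast_dims=[])
    # Reuse of the pre-reduction value after the reduction.
    fd.add_output(fd.ops.mul(t_sum, t_scale))

with FusionDefinition() as fd:
    reduced_pattern(fd)

fd.execute([
    torch.testing.make_tensor((64,), dtype=torch.bfloat16, device="cuda:0"),
    torch.testing.make_tensor((16, 64), dtype=torch.bfloat16, device="cuda:0"),
])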

Here's what the fusion looks like. It's fairly complex.

[Image: fusion IR graph visualization]

Fusion IR math:

Inputs:
  T0_g___bfloat[iS428{1}, iS430{1}, iS431{16}, iS427{4}]
  T1_g___bfloat[iS220{1}, iS222{1}, iS223{256}, iS219{4}]
Outputs:
  T54_g_float[]
  T36_g_float[iS339{1}, iUS341{1}, ithreadIdx.x342{256}_p, iV338{4}] ca_pos( 3 ) produce_pos( 3 )

%kernel_math {
T53_l___bfloat[iS213{1}, iUS215{1}, ithreadIdx.x216{256}_p, iV212{4}]
   = Set( T1_g___bfloat[iS220{1}, iS222{1}, iS223{256}, iS219{4}], cache_op=Streaming )
T4_l_float[iS206{1}, iUS208{1}, ithreadIdx.x209{256}_p, iS205{4}] ca_pos( 4 )
   = __bfloat2float(T53_l___bfloat[iS213{1}, iUS215{1}, ithreadIdx.x216{256}_p, iV212{4}]);
T52_l___bfloat[iS422{1}, iUS424{1}, ithreadIdx.x425{16}_p, iV421{4}]
   = Set( T0_g___bfloat[iS428{1}, iS430{1}, iS431{16}, iS427{4}], cache_op=AllLevels )
T2_l___bfloat[iS199{1}, iUS201{1}, ithreadIdx.x202{16}_p, iS198{4}] ca_pos( 4 )
   = broadcast( T52_l___bfloat[iS422{1}, iUS424{1}, ithreadIdx.x425{16}_p, iV421{4}] )
T3_l___bfloat[iS192{1}, iUS194{1}, ithreadIdx.x195{16 ex ( ceilDiv(( 16 * 64 ), 4) )}_p, iS191{4}] ca_pos( 4 ) produce_pos( 4 ) = expand( T2_l___bfloat[iS199{1}, iUS201{1}, ithreadIdx.x202{16}_p, iS198{4}] ca_pos( 4 ), {16, 64} )
T5_l_float[iS185{1}, iUS187{1}, ithreadIdx.x188{16 ex ( ceilDiv(( 16 * 64 ), 4) )}_p, iS184{4}] ca_pos( 4 ) produce_pos( 4 )
   = __bfloat2float(T3_l___bfloat[iS192{1}, iUS194{1}, ithreadIdx.x195{16 ex ( ceilDiv(( 16 * 64 ), 4) )}_p, iS191{4}] ca_pos( 4 ) produce_pos( 4 ));
T6_l_float[iS178{1}, iUS180{1}, ithreadIdx.x181{256}_p, iS177{4}] ca_pos( 4 ) produce_pos( 4 )
   = T4_l_float[iS206{1}, iUS208{1}, ithreadIdx.x209{256}_p, iS205{4}] ca_pos( 4 )
   + T5_l_float[iS185{1}, iUS187{1}, ithreadIdx.x188{16 ex ( ceilDiv(( 16 * 64 ), 4) )}_p, iS184{4}] ca_pos( 4 ) produce_pos( 4 );
T9_l_float[iS409{1}, iUS411{1}, ithreadIdx.x412{256}_p, iS408{4}] ca_pos( 4 ) produce_pos( 4 )
   = double(0.5)
   * T6_l_float[iS178{1}, iUS180{1}, ithreadIdx.x181{256}_p, iS177{4}] ca_pos( 4 ) produce_pos( 4 );
T7_l_float[iS171{1}, iUS173{1}, ithreadIdx.x174{256}_p, iS170{4}] ca_pos( 4 ) produce_pos( 4 )
   = T6_l_float[iS178{1}, iUS180{1}, ithreadIdx.x181{256}_p, iS177{4}] ca_pos( 4 ) produce_pos( 4 )
   * T6_l_float[iS178{1}, iUS180{1}, ithreadIdx.x181{256}_p, iS177{4}] ca_pos( 4 ) produce_pos( 4 );
T8_l_float[iS164{1}, iUS166{1}, ithreadIdx.x167{256}_p, iS163{4}] ca_pos( 4 ) produce_pos( 4 )
   = T7_l_float[iS171{1}, iUS173{1}, ithreadIdx.x174{256}_p, iS170{4}] ca_pos( 4 ) produce_pos( 4 )
   * T6_l_float[iS178{1}, iUS180{1}, ithreadIdx.x181{256}_p, iS177{4}] ca_pos( 4 ) produce_pos( 4 );
T10_l_float[iS157{1}, iUS159{1}, ithreadIdx.x160{256}_p, iS156{4}] ca_pos( 4 ) produce_pos( 4 )
   = double(0.044714999999999998)
   * T8_l_float[iS164{1}, iUS166{1}, ithreadIdx.x167{256}_p, iS163{4}] ca_pos( 4 ) produce_pos( 4 );
T11_l_float[iS150{1}, iUS152{1}, ithreadIdx.x153{256}_p, iS149{4}] ca_pos( 4 ) produce_pos( 4 )
   = T6_l_float[iS178{1}, iUS180{1}, ithreadIdx.x181{256}_p, iS177{4}] ca_pos( 4 ) produce_pos( 4 )
   + T10_l_float[iS157{1}, iUS159{1}, ithreadIdx.x160{256}_p, iS156{4}] ca_pos( 4 ) produce_pos( 4 );
T12_l_float[iS143{1}, iUS145{1}, ithreadIdx.x146{256}_p, iS142{4}] ca_pos( 4 ) produce_pos( 4 )
   = double(0.79788499999999996)
   * T11_l_float[iS150{1}, iUS152{1}, ithreadIdx.x153{256}_p, iS149{4}] ca_pos( 4 ) produce_pos( 4 );
T13_l_float[iS136{1}, iUS138{1}, ithreadIdx.x139{256}_p, iS135{4}] ca_pos( 4 ) produce_pos( 4 )
   = tanhf(T12_l_float[iS143{1}, iUS145{1}, ithreadIdx.x146{256}_p, iS142{4}] ca_pos( 4 ) produce_pos( 4 ));
T14_l_float[iS129{1}, iUS131{1}, ithreadIdx.x132{256}_p, iS128{4}] ca_pos( 4 ) produce_pos( 4 )
   = double(1)
   + T13_l_float[iS136{1}, iUS138{1}, ithreadIdx.x139{256}_p, iS135{4}] ca_pos( 4 ) produce_pos( 4 );
T15_l_float[iS122{1}, iUS124{1}, ithreadIdx.x125{256}_p, iS121{4}] ca_pos( 4 ) produce_pos( 4 )
   = T9_l_float[iS409{1}, iUS411{1}, ithreadIdx.x412{256}_p, iS408{4}] ca_pos( 4 ) produce_pos( 4 )
   * T14_l_float[iS129{1}, iUS131{1}, ithreadIdx.x132{256}_p, iS128{4}] ca_pos( 4 ) produce_pos( 4 );
T16_l_float[iS115{1}, iUS117{1}, ithreadIdx.x118{256}_p, iS114{4}] ca_pos( 4 ) produce_pos( 4 )
   = abs(T15_l_float[iS122{1}, iUS124{1}, ithreadIdx.x125{256}_p, iS121{4}] ca_pos( 4 ) produce_pos( 4 ));
T57_l_float[rS107{1}rf, rUS109{1}rf, ithreadIdx.x110{256}rf_p, rS106{4}rf] produce_pos( 4 )
   = reduction( T16_l_float[iS115{1}, iUS117{1}, ithreadIdx.x118{256}_p, iS114{4}] ca_pos( 4 ) produce_pos( 4 ), op = fmax, initial value = double(-inf), allreduce = false )
T17_l_float[rthreadIdx.x111{256}_p]
   = reduction( T57_l_float[rS107{1}rf, rUS109{1}rf, ithreadIdx.x110{256}rf_p, rS106{4}rf] produce_pos( 4 ), op = fmax, initial value = double(-inf), allreduce = false )
T18_l_double[]
   = (double)(T17_l_float[rthreadIdx.x111{256}_p]);
T19_l_bool[]
   = T18_l_double[]
   != T18_l_double[];
T20_l_bool[]
   = T18_l_double[]
   > double(9.9999999999999998e-13);
T21_l_double[]
   = where(T20_l_bool[]
  , T18_l_double[]
  , T21_l_double[]);
T23_l_double[]
   = reciprocal(T22_l_double[]);
T24_l_double[]
   = double(448)
   * T23_l_double[];
T25_l_float[]
   = (float)(T24_l_double[]);
T55_l_float[]
   = Set( T25_l_float[], cache_op=Streaming )
T54_g_float[]
   = Set( T55_l_float[], cache_op=Streaming )
T37_l_float[iS227{1}, iUS229{1}, ithreadIdx.x230{256}_p, iS226{4}] ca_pos( 4 )
   = __bfloat2float(T53_l___bfloat[iS213{1}, iUS215{1}, ithreadIdx.x216{256}_p, iV212{4}]);
T38_l___bfloat[iS255{1}, iUS257{1}, ithreadIdx.x258{16}_p, iS254{4}] ca_pos( 4 )
   = broadcast( T52_l___bfloat[iS422{1}, iUS424{1}, ithreadIdx.x425{16}_p, iV421{4}] )
T39_l___bfloat[iS248{1}, iUS250{1}, ithreadIdx.x251{16 ex ( ceilDiv(( 16 * 64 ), 4) )}_p, iS247{4}] ca_pos( 4 ) produce_pos( 4 ) = expand( T38_l___bfloat[iS255{1}, iUS257{1}, ithreadIdx.x258{16}_p, iS254{4}] ca_pos( 4 ), {16, 64} )
T40_l_float[iS241{1}, iUS243{1}, ithreadIdx.x244{16 ex ( ceilDiv(( 16 * 64 ), 4) )}_p, iS240{4}] ca_pos( 4 ) produce_pos( 4 )
   = __bfloat2float(T39_l___bfloat[iS248{1}, iUS250{1}, ithreadIdx.x251{16 ex ( ceilDiv(( 16 * 64 ), 4) )}_p, iS247{4}] ca_pos( 4 ) produce_pos( 4 ));
T41_l_float[iS234{1}, iUS236{1}, ithreadIdx.x237{256}_p, iS233{4}] ca_pos( 4 ) produce_pos( 4 )
   = T37_l_float[iS227{1}, iUS229{1}, ithreadIdx.x230{256}_p, iS226{4}] ca_pos( 4 )
   + T40_l_float[iS241{1}, iUS243{1}, ithreadIdx.x244{16 ex ( ceilDiv(( 16 * 64 ), 4) )}_p, iS240{4}] ca_pos( 4 ) produce_pos( 4 );
T42_l_float[iS402{1}, iUS404{1}, ithreadIdx.x405{256}_p, iS401{4}] ca_pos( 4 ) produce_pos( 4 )
   = double(0.5)
   * T41_l_float[iS234{1}, iUS236{1}, ithreadIdx.x237{256}_p, iS233{4}] ca_pos( 4 ) produce_pos( 4 );
T43_l_float[iS395{1}, iUS397{1}, ithreadIdx.x398{256}_p, iS394{4}] ca_pos( 4 ) produce_pos( 4 )
   = T41_l_float[iS234{1}, iUS236{1}, ithreadIdx.x237{256}_p, iS233{4}] ca_pos( 4 ) produce_pos( 4 )
   * T41_l_float[iS234{1}, iUS236{1}, ithreadIdx.x237{256}_p, iS233{4}] ca_pos( 4 ) produce_pos( 4 );
T44_l_float[iS388{1}, iUS390{1}, ithreadIdx.x391{256}_p, iS387{4}] ca_pos( 4 ) produce_pos( 4 )
   = T43_l_float[iS395{1}, iUS397{1}, ithreadIdx.x398{256}_p, iS394{4}] ca_pos( 4 ) produce_pos( 4 )
   * T41_l_float[iS234{1}, iUS236{1}, ithreadIdx.x237{256}_p, iS233{4}] ca_pos( 4 ) produce_pos( 4 );
T45_l_float[iS269{1}, iUS271{1}, ithreadIdx.x272{256}_p, iS268{4}] ca_pos( 4 ) produce_pos( 4 )
   = double(0.044714999999999998)
   * T44_l_float[iS388{1}, iUS390{1}, ithreadIdx.x391{256}_p, iS387{4}] ca_pos( 4 ) produce_pos( 4 );
T46_l_float[iS262{1}, iUS264{1}, ithreadIdx.x265{256}_p, iS261{4}] ca_pos( 4 ) produce_pos( 4 )
   = T41_l_float[iS234{1}, iUS236{1}, ithreadIdx.x237{256}_p, iS233{4}] ca_pos( 4 ) produce_pos( 4 )
   + T45_l_float[iS269{1}, iUS271{1}, ithreadIdx.x272{256}_p, iS268{4}] ca_pos( 4 ) produce_pos( 4 );
T47_l_float[iS276{1}, iUS278{1}, ithreadIdx.x279{256}_p, iS275{4}] ca_pos( 4 ) produce_pos( 4 )
   = double(0.79788499999999996)
   * T46_l_float[iS262{1}, iUS264{1}, ithreadIdx.x265{256}_p, iS261{4}] ca_pos( 4 ) produce_pos( 4 );
T48_l_float[iS283{1}, iUS285{1}, ithreadIdx.x286{256}_p, iS282{4}] ca_pos( 4 ) produce_pos( 4 )
   = tanhf(T47_l_float[iS276{1}, iUS278{1}, ithreadIdx.x279{256}_p, iS275{4}] ca_pos( 4 ) produce_pos( 4 ));
T49_l_float[iS290{1}, iUS292{1}, ithreadIdx.x293{256}_p, iS289{4}] ca_pos( 4 ) produce_pos( 4 )
   = double(1)
   + T48_l_float[iS283{1}, iUS285{1}, ithreadIdx.x286{256}_p, iS282{4}] ca_pos( 4 ) produce_pos( 4 );
T50_l_float[iS297{1}, iUS299{1}, ithreadIdx.x300{256}_p, iS296{4}] ca_pos( 4 ) produce_pos( 4 )
   = T42_l_float[iS402{1}, iUS404{1}, ithreadIdx.x405{256}_p, iS401{4}] ca_pos( 4 ) produce_pos( 4 )
   * T49_l_float[iS290{1}, iUS292{1}, ithreadIdx.x293{256}_p, iS289{4}] ca_pos( 4 ) produce_pos( 4 );
T26_l_float[bS318{1}, bUS320{1}, bthreadIdx.x321{1}_p, bS317{4}]
   = broadcast( T25_l_float[] )
T27_l_float[bS311{1}, bUS313{1}, bthreadIdx.x314{1 ex ( ceilDiv(( 16 * 64 ), 4) )}_p, bS310{4}] = expand( T26_l_float[bS318{1}, bUS320{1}, bthreadIdx.x321{1}_p, bS317{4}], {16, 64} )
T28_l_float[iS304{1}, iUS306{1}, ithreadIdx.x307{256}_p, iS303{4}] ca_pos( 4 ) produce_pos( 4 )
   = T50_l_float[iS297{1}, iUS299{1}, ithreadIdx.x300{256}_p, iS296{4}] ca_pos( 4 ) produce_pos( 4 )
   * T27_l_float[bS311{1}, bUS313{1}, bthreadIdx.x314{1 ex ( ceilDiv(( 16 * 64 ), 4) )}_p, bS310{4}];
T29_l_bool[iS381{1}, iUS383{1}, ithreadIdx.x384{256}_p, iS380{4}] ca_pos( 4 ) produce_pos( 4 )
   = T28_l_float[iS304{1}, iUS306{1}, ithreadIdx.x307{256}_p, iS303{4}] ca_pos( 4 ) produce_pos( 4 )
   != T28_l_float[iS304{1}, iUS306{1}, ithreadIdx.x307{256}_p, iS303{4}] ca_pos( 4 ) produce_pos( 4 );
T30_l_bool[iS374{1}, iUS376{1}, ithreadIdx.x377{256}_p, iS373{4}] ca_pos( 4 ) produce_pos( 4 )
   = T28_l_float[iS304{1}, iUS306{1}, ithreadIdx.x307{256}_p, iS303{4}] ca_pos( 4 ) produce_pos( 4 )
   > double(-448);
T31_l_float[iS367{1}, iUS369{1}, ithreadIdx.x370{256}_p, iS366{4}] ca_pos( 4 ) produce_pos( 4 )
   = where(T30_l_bool[iS374{1}, iUS376{1}, ithreadIdx.x377{256}_p, iS373{4}] ca_pos( 4 ) produce_pos( 4 )
  , T28_l_float[iS304{1}, iUS306{1}, ithreadIdx.x307{256}_p, iS303{4}] ca_pos( 4 ) produce_pos( 4 )
  , double(-448));
T32_l_float[iS325{1}, iUS327{1}, ithreadIdx.x328{256}_p, iS324{4}] ca_pos( 4 ) produce_pos( 4 )
   = where(T29_l_bool[iS381{1}, iUS383{1}, ithreadIdx.x384{256}_p, iS380{4}] ca_pos( 4 ) produce_pos( 4 )
  , T28_l_float[iS304{1}, iUS306{1}, ithreadIdx.x307{256}_p, iS303{4}] ca_pos( 4 ) produce_pos( 4 )
  , T31_l_float[iS367{1}, iUS369{1}, ithreadIdx.x370{256}_p, iS366{4}] ca_pos( 4 ) produce_pos( 4 ));
T33_l_bool[iS360{1}, iUS362{1}, ithreadIdx.x363{256}_p, iS359{4}] ca_pos( 4 ) produce_pos( 4 )
   = T32_l_float[iS325{1}, iUS327{1}, ithreadIdx.x328{256}_p, iS324{4}] ca_pos( 4 ) produce_pos( 4 )
   != T32_l_float[iS325{1}, iUS327{1}, ithreadIdx.x328{256}_p, iS324{4}] ca_pos( 4 ) produce_pos( 4 );
T34_l_bool[iS353{1}, iUS355{1}, ithreadIdx.x356{256}_p, iS352{4}] ca_pos( 4 ) produce_pos( 4 )
   = T32_l_float[iS325{1}, iUS327{1}, ithreadIdx.x328{256}_p, iS324{4}] ca_pos( 4 ) produce_pos( 4 )
   < double(448);
T35_l_float[iS346{1}, iUS348{1}, ithreadIdx.x349{256}_p, iS345{4}] ca_pos( 4 ) produce_pos( 4 )
   = where(T34_l_bool[iS353{1}, iUS355{1}, ithreadIdx.x356{256}_p, iS352{4}] ca_pos( 4 ) produce_pos( 4 )
  , T32_l_float[iS325{1}, iUS327{1}, ithreadIdx.x328{256}_p, iS324{4}] ca_pos( 4 ) produce_pos( 4 )
  , double(448));
T56_l_float[iS332{1}, iUS334{1}, ithreadIdx.x335{256}_p, iS331{4}] ca_pos( 3 ) produce_pos( 4 )
   = where(T33_l_bool[iS360{1}, iUS362{1}, ithreadIdx.x363{256}_p, iS359{4}] ca_pos( 4 ) produce_pos( 4 )
  , T32_l_float[iS325{1}, iUS327{1}, ithreadIdx.x328{256}_p, iS324{4}] ca_pos( 4 ) produce_pos( 4 )
  , T35_l_float[iS346{1}, iUS348{1}, ithreadIdx.x349{256}_p, iS345{4}] ca_pos( 4 ) produce_pos( 4 ));
T36_g_float[iS339{1}, iUS341{1}, ithreadIdx.x342{256}_p, iV338{4}] ca_pos( 3 ) produce_pos( 3 )
   = Set( T56_l_float[iS332{1}, iUS334{1}, ithreadIdx.x335{256}_p, iS331{4}] ca_pos( 3 ) produce_pos( 4 ), cache_op=Streaming )
} // %kernel_math

I'm not sure what to blame. The sync analysis seems to be doing the right thing: if the fusion is scheduled this way, T52 needs to be in shared memory. I also suspect the inlining position of T52 is correct. I'm not entirely sure this fusion can really be scheduled as an inner-normalization kernel with T52 as the persistent buffer.

There are some 0-dim tensors after the reduction. I'm not sure if they are related to the error (I suspect not).

@liqiangxl Could you please take a look when you have time?
@crcrpar Please let me know if this is urgent.

naoyam added a commit that referenced this issue Dec 17, 2024
`NVFUSER_DUMP=fusion_ir_graph` saves the dot representation of a fusion
before lowering to a file named like
`__tmp_fusion_ir_graph_inner_persistent_f0_c1_r0_g0.dot`.

Example visualization:
#3498 (comment)
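
For example, the dump can be enabled for the standalone repro in the issue description like this (a sketch; the assumption is that nvfuser reads NVFUSER_DUMP from the environment, so setting it in the shell before launching Python works just as well):

import os

# Assumption: set before importing nvfuser so the option is picked up.
# Equivalent to running `NVFUSER_DUMP=fusion_ir_graph python repro.py`.
os.environ["NVFUSER_DUMP"] = "fusion_ir_graph"

import torch
from nvfuser import FusionDefinition, DataType
# ... then define and execute the FusionDefinition from the issue description;
# a .dot file named like the one above is written out before lowering.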

crcrpar commented Dec 18, 2024

I don't think this is urgent: I'm seeing some working cases under slightly different settings, and my recent update hits another issue that looks orthogonal to nvfuser.
