Error when calling `tl.sort` on `torch.bfloat16` data #3873
In another env I get a different error:

```
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/triton/language/core.py", line 33, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/language/core.py", line 912, in to
    return semantic.bitcast(self, dtype, _builder)
  File "/usr/local/lib/python3.10/dist-packages/triton/language/semantic.py", line 737, in bitcast
    raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to "
ValueError: Cannot bitcast data-type of size 32 to data-type of size 16

The above exception was the direct cause of the following exception:

triton.compiler.errors.CompilationError: at 13:12:
    n_outer: core.constexpr = x.numel >> n_dims
    shape: core.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)]
    y = core.reshape(x, shape)
    # slice left/right with 'stride' 2**(n_dims - i - 1)
    mask = core.arange(0, 2)[None, :, None]
    left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape)
    right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape)
    left = core.reshape(left, x.shape)
    right = core.reshape(right, x.shape)
    # actual compare-and-swap
    idtype = core.dtype(f'int{core.constexpr(x.dtype.primitive_bitwidth)}')
    ileft = left.to(idtype, bitcast=True)
```

I think this error comes from the following code in `_compare_and_swap`:

```python
mask = core.arange(0, 2)[None, :, None]
left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape)
right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape)
```

So I tried casting `left` and `right` back to `y.dtype`:

```python
mask = core.arange(0, 2)[None, :, None]
left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape).to(y.dtype)
right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape).to(y.dtype)
```

Then I get the same error as at the beginning:

```
loc(callsite(callsite(callsite("/usr/local/lib/python3.10/dist-packages/triton/language/standard.py":338:34 at "/usr/local/lib/python3.10/dist-packages/triton/language/standard.py":363:61) at "/usr/local/lib/python3.10/dist-packages/triton/language/standard.py":376:66) at "/workspace2/triton-test/test.py":18:18)): error: 'llvm.fcmp' op operand #0 must be floating point LLVM type or LLVM dialect-compatible vector of floating point LLVM type, but got 'i16'
Traceback (most recent call last):
  File "/workspace2/triton-test/test.py", line 49, in <module>
    test_kernel[(8, )](a, b, 16, 16)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 180, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 401, in run
    self.cache[device][key] = compile(
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 268, in compile
    next_module = compile_ir(module, metadata)
  File "/usr/local/lib/python3.10/dist-packages/triton/backends/nvidia/compiler.py", line 265, in <lambda>
    stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
  File "/usr/local/lib/python3.10/dist-packages/triton/backends/nvidia/compiler.py", line 178, in make_llir
    pm.run(mod)
RuntimeError: PassManager::run failed
```

I think this error comes from line 338:

```python
ret = ix ^ core.where((left > right) ^ flip, ileft ^ iright, zeros_like(ix))
```

FYI, thx.
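For context, the trick on line 338 swaps two floats by bitcasting them to same-width integers and XORing. Here is my own NumPy sketch of that idea (not Triton code; `float16` stands in for `bf16`, which NumPy lacks):

```python
import numpy as np

# Illustration of the XOR compare-and-swap used in Triton's _compare_and_swap:
# x ^ (left ^ right) turns left into right and vice versa; XOR with 0 is a no-op.
left = np.array([3.0, 1.0, 4.0, 1.5], dtype=np.float16)
right = np.array([2.0, 2.5, 0.5, 5.0], dtype=np.float16)
flip = False  # ascending order for this stage of the sorting network

# Bitcast to integers of the SAME width; this is where a 32-bit intermediate
# would collide with a 16-bit integer type in the Triton version.
ileft = left.view(np.int16)
iright = right.view(np.int16)

swap = (left > right) ^ flip
delta = np.where(swap, ileft ^ iright, np.int16(0)).astype(np.int16)
new_left = (ileft ^ delta).view(np.float16)   # elementwise minimum
new_right = (iright ^ delta).view(np.float16)  # elementwise maximum
```

The swap decision itself (`left > right`) still compares floats, which is why the lowering must emit `llvm.fcmp` on a float type rather than on `i16`.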
ThomasRaoux pushed a commit that referenced this issue on May 29, 2024:
… bug (#3975)

Since LLVM now supports `bf16`, it is no longer necessary to [represent `bf16` as `i16`](#1245 (comment)) in the TritonGPUToLLVM conversion. With the `i16` representation, `bf16` comparison goes wrong: the compare is converted to `arith.cmpf`, but `i16` is not compatible with `arith.cmpf`, so both `bf16` compare and `tl.sort` hit the [bug](#3873).

Meanwhile, the use of `core.arange` in `_compare_and_swap` produces a mismatched data type when calling `tl.sort` on `bf16`. The data types of `left` and `right` need to be cast back to `y.dtype` to fix `tl.sort`.

The revision has passed the Python tests below, run in Docker on an H100:

```sh
$ sudo pip uninstall pytorch-triton
$ cd triton
$ pip install -e python
$ python -m pytest python/test/unit
# ... 11309 passed, 1219 skipped, 156 warnings in 3222.26s (0:53:42)
```

However, I cannot build the C++ tests; they fail with:

```sh
$ cd python/build/cmake.linux-x86_64-cpython-3.10/
$ ninja test
[0/1] Re-running CMake...
/bin/bash: line 1: /tmp/pip-build-env-phcu6k1b/overlay/local/lib/python3.10/dist-packages/cmake/data/bin/cmake: No such file or directory
FAILED: build.ninja
/tmp/pip-build-env-phcu6k1b/overlay/local/lib/python3.10/dist-packages/cmake/data/bin/cmake --regenerate-during-build -S/home/scratch.haoruoc_gpu/repos/triton -B/home/scratch.haoruoc_gpu/repos/triton/python/build/cmake.linux-x86_64-cpython-3.10
ninja: error: rebuilding 'build.ninja': subcommand failed
```

The path given in `build.ninja` does not exist. Besides, I did not revise the AMD backend as I have no access to the corresponding hardware.

---------

Co-authored-by: haoruoc <[email protected]>
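As a side note on why the `i16` representation breaks comparison: signed bit patterns are not order-isomorphic to floats, so an integer compare cannot stand in for `arith.cmpf`. My own NumPy illustration (`float16` standing in for `bf16`):

```python
import numpy as np

# For two negative floats, comparing the raw bit patterns as signed
# integers gives the REVERSED order (larger magnitude -> larger exponent bits).
a = np.array([-2.0], dtype=np.float16)
b = np.array([-1.0], dtype=np.float16)

float_lt = bool((a < b)[0])                                # proper float compare
int_lt = bool((a.view(np.int16) < b.view(np.int16))[0])    # bit-pattern compare

print(float_lt)  # True:  -2.0 < -1.0
print(int_lt)    # False: bits of -2.0 (0xC000) exceed bits of -1.0 (0xBC00) as int16
```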
Fixed with #3975.
When I call `tl.sort` with input of dtype bf16, I get the error. I use Triton 3.0.0. To reproduce:

Confusingly, fp16 and fp32 work. FYI, I think the offending code is in the function `_compare_and_swap` of standard.py.
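For what it's worth, the dtype mismatch in `_compare_and_swap` can be mimicked outside Triton. In my own NumPy sketch below (`float16` standing in for `bf16`), multiplying the values by the integer `mask` from `core.arange` promotes the product to a wider float, which can then no longer be bitcast to a same-width 16-bit integer:

```python
import numpy as np

# Emulate y * mask from _compare_and_swap: y is a 16-bit float,
# mask is an integer range (core.arange in Triton, np.arange here).
y = np.array([3.0, 1.0], dtype=np.float16)
mask = np.arange(2)  # integer dtype

prod = y * mask  # promoted to a wider float, no longer 16 bits
print(prod.dtype.itemsize * 8)  # wider than 16
# Bitcasting this wider float to a 16-bit integer is exactly the
# "Cannot bitcast data-type of size 32 to data-type of size 16" situation.
```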