Why not allow JITFunction as parameter to another JITFunction (high-order jit function)? #3918

iclementine commented May 15, 2024

I notice that it is not possible to pass a JITFunction as a parameter to another JITFunction (let's call this a higher-order JITFunction for now).

The code below is an example of a pointwise kernel _jit_function, where the operation mapped over the input elements is defined in scalar_fn.

# test_pointwise.py
import triton
import triton.language as tl
import torch

@triton.jit
def scalar_fn(x):
    # softplus: log(1 + exp(x))
    return tl.log(1 + tl.exp(x))

@triton.jit
def _jit_function(
    in_ptr, o_ptr,
    size,
    tile_size: tl.constexpr,
):
    # each program instance processes one tile of `tile_size` elements
    pid = tl.program_id(0)
    tid = pid * tile_size + tl.arange(0, tile_size)
    mask = tid < size

    input_ = tl.load(in_ptr + tid, mask=mask)
    out = scalar_fn(input_)  # calling another JITFunction in the body works
    tl.store(o_ptr + tid, out, mask=mask)

def _wrapper(x: torch.Tensor):
    out = torch.empty_like(x)
    size = out.numel()
    tile_size = 512
    grid = triton.cdiv(size, tile_size), 1, 1
    _jit_function[grid](
        x, out, size,
        tile_size=tile_size,
        num_warps=4,
    )
    return out

x = torch.randn((3, 4), device="cuda")
print(_wrapper(x))

The JITFunction scalar_fn can be called in the body of another JITFunction (_jit_function). However, if I pass the scalar function as a parameter to _jit_function instead:

# test_pointwise_high_order.py
import triton
import triton.language as tl
import torch

@triton.jit
def scalar_fn(x):
    return tl.log(1 + tl.exp(x))

@triton.jit
def _jit_function(
    in_ptr, o_ptr,
    size,
    f: tl.constexpr,  # the scalar function, passed as a constexpr argument
    tile_size: tl.constexpr,
):
    pid = tl.program_id(0)
    tid = pid * tile_size + tl.arange(0, tile_size)
    mask = tid < size

    input_ = tl.load(in_ptr + tid, mask=mask)
    out = f(input_)  # call the JITFunction received as an argument
    tl.store(o_ptr + tid, out, mask=mask)

def _wrapper(x: torch.Tensor):
    out = torch.empty_like(x)
    size = out.numel()
    tile_size = 512
    grid = triton.cdiv(size, tile_size), 1, 1
    _jit_function[grid](
        x, out, size,
        f=scalar_fn,
        tile_size=tile_size,
        num_warps=4,
    )
    return out

x = torch.randn((3, 4), device="cuda")
print(_wrapper(x))

I get the following error:

Traceback (most recent call last):
  File "test_pointwise_high_order.py", line 38, in <module>
    print(_wrapper(x))
  File "test_pointwise_high_order.py", line 29, in _wrapper
    _jit_function[grid](
  File "<path_to_site_packages>/triton/runtime/jit.py", line 508, in run
    raise TypeError(f"Callable constexpr at index {i} is not supported")
TypeError: Callable constexpr at index 3 is not supported

I notice that the restriction that parameters with the type hint tl.constexpr (constant args) cannot be callable was first introduced in #644.

https://github.com/triton-lang/triton/blob/cfa8d18b835b10fc48449924aadf5982ac10d87c/python/triton/runtime/jit.py#L268C1-L270C79

# build stub signature -- includes arguments that are specialized
for i, arg in constants.items():
    if callable(arg):
        raise TypeError(f"Callable constexpr at index {i} is not supported")

However, commenting out these lines makes the example above work as expected.

I think a JITFunction passed to another JITFunction is always a constexpr (in the sense that it is not a parameter to the CompiledKernel). And recent PRs have included JITFunction arguments passed to tl.reduce or tl.associative_scan in the function body in the cache_key (#3137). So here is my question:

Why allow only Triton builtins like tl.reduce and tl.associative_scan to take JITFunctions as arguments, while disallowing JITFunctions from taking JITFunctions as constant arguments? Allowing it would make higher-order jit functions much easier to write with Triton.
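For contrast, here is the pattern that already works today: builtins such as tl.reduce accept a JITFunction as the combine function inside another JITFunction. A minimal sketch (the names _max_fn and _reduce_kernel are mine, for illustration only):

import triton
import triton.language as tl

@triton.jit
def _max_fn(a, b):
    # combine function handed to tl.reduce
    return tl.maximum(a, b)

@triton.jit
def _reduce_kernel(in_ptr, o_ptr, size, tile_size: tl.constexpr):
    pid = tl.program_id(0)
    tid = pid * tile_size + tl.arange(0, tile_size)
    mask = tid < size
    x = tl.load(in_ptr + tid, mask=mask, other=float("-inf"))
    # a JITFunction is accepted here, and (per #3137) it is folded
    # into the cache key
    result = tl.reduce(x, 0, _max_fn)
    tl.store(o_ptr + pid, result)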

Suppose I want to build a utility that generates a _wrapper function and a _jit_function (like the example above) from a given _scalar_fn. Because Triton's JITFunction relies on inspect to get the source of the function that triton.jit decorates, I have to write the generated code to a file and import it dynamically (via importlib or some kind of exec).

However, since a JITFunction cannot take another JITFunction as an argument, there are no higher-order JITFunctions. So if I want to generate a JITFunction that calls the user-provided scalar function, I have to make the scalar function available in the module where the generated _jit_function is defined. Finding where _scalar_fn is defined and generating code that imports it is difficult, because _scalar_fn may live in a module with a known path, in a script (__name__ == "__main__"), or in an interactive REPL. A sketch of this workaround follows.
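This is roughly the kind of workaround I mean (a rough sketch; the helper _make_pointwise and the template are mine, not a Triton API). The scalar function's source has to be pasted textually into the generated module, decorator included, so that inspect can find it:

import importlib.util
import os
import tempfile

_TEMPLATE = """\
import triton
import triton.language as tl

{scalar_src}

@triton.jit
def _jit_function(in_ptr, o_ptr, size, tile_size: tl.constexpr):
    pid = tl.program_id(0)
    tid = pid * tile_size + tl.arange(0, tile_size)
    mask = tid < size
    input_ = tl.load(in_ptr + tid, mask=mask)
    out = {scalar_name}(input_)
    tl.store(o_ptr + tid, out, mask=mask)
"""

def _make_pointwise(scalar_src, scalar_name):
    # scalar_src must be the full source of the scalar function,
    # including its @triton.jit decorator
    src = _TEMPLATE.format(scalar_src=scalar_src, scalar_name=scalar_name)
    # write to a real file so inspect.getsource works on the kernel
    path = os.path.join(tempfile.mkdtemp(), "_generated_pointwise.py")
    with open(path, "w") as f:
        f.write(src)
    spec = importlib.util.spec_from_file_location("_generated_pointwise", path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod._jit_function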

The main problem I can see with allowing JITFunctions as parameters to another JITFunction is that the cache key of a JITFunction would have to be completed at each run, instead of being derived once from the body of the JITFunction itself, which can be done at __init__.
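To make that concrete, here is a hedged illustration of the extra per-launch work this would imply. source_hash and spec_key are stand-ins of my own for whatever Triton keeps internally, not real API:

import hashlib
from triton.runtime.jit import JITFunction

def source_hash(fn):
    # stand-in: hash the source that triton.jit captured
    # (assuming JITFunction.src holds it, as in triton/runtime/jit.py)
    return hashlib.sha256(fn.src.encode()).hexdigest()

def spec_key(arg):
    # stand-in for Triton's argument specialization key
    return type(arg).__name__

def launch_cache_key(jit_fn, args):
    # a run-time cache key that folds in the hash of any JITFunction
    # passed as an argument, since a different callee means different
    # generated code
    parts = [source_hash(jit_fn)]
    for a in args:
        if isinstance(a, JITFunction):
            parts.append(source_hash(a))  # recomputed at every launch
        else:
            parts.append(spec_key(a))
    return hash(tuple(parts))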

The relation of higher-order JITFunctions to non-higher-order JITFunctions resembles that of function templates to functions in C++. When a higher-order JITFunction has all of its JITFunction parameters specified, it is complete (instantiated) in some sense.

So maybe it would be convenient to add an interface that instantiates a higher-order JITFunction into a non-higher-order JITFunction, to avoid analysing the cache key of a higher-order JITFunction at every run? @ptillet
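As a strawman for what such an interface could look like (purely hypothetical; no instantiate method exists in Triton today):

# hypothetical usage -- nothing like this exists in Triton today
pointwise_softplus = _jit_function.instantiate(f=scalar_fn)

# the result would be an ordinary JITFunction: its cache key is
# complete at construction, so launches pay no per-call hashing of `f`
grid = (triton.cdiv(size, tile_size), 1, 1)
pointwise_softplus[grid](x, out, size, tile_size=512, num_warps=4)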

Thank you
