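I noticed that it is not possible to pass a JITFunction as a parameter to another JITFunction (call it a higher-order JITFunction for now).
Consider a pointwise function `_jit_function` whose per-element operation is defined in `scalar_fn`. The JITFunction `scalar_fn` can be called in the body of another JITFunction (`_jit_function`). However, if I add the scalar function as a parameter to `_jit_function`, I get an error: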
```
Traceback (most recent call last):
  File "test_pointwise_high_order.py", line 38, in <module>
    print(_wrapper(x))
  File "test_pointwise_high_order.py", line 29, in _wrapper
    _jit_function[grid](
  File "<path_to_site_packages>/triton/runtime/jit.py", line 508, in run
    raise TypeError(f"Callable constexpr at index {i} is not supported")
TypeError: Callable constexpr at index 3 is not supported
```
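The restriction that parameters annotated with `tl.constexpr` (constant arguments) cannot be callable was first introduced in #644:
https://github.com/triton-lang/triton/blob/cfa8d18b835b10fc48449924aadf5982ac10d87c/python/triton/runtime/jit.py#L268C1-L270C79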
```python
# build stub signature -- includes arguments that are specialized
for i, arg in constants.items():
    if callable(arg):
        raise TypeError(f"Callable constexpr at index {i} is not supported")
```
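To make the effect of this guard concrete, here is a minimal pure-Python sketch that mimics it outside of Triton. The helper name `check_constants` and the plain function `scalar_fn` are hypothetical stand-ins (a real JITFunction is also callable, which is why it trips the same check); the index-to-value dictionary is only an assumption about the shape of `constants`.

```python
def check_constants(constants):
    """Mimic the quoted guard: reject any callable constant argument."""
    for i, arg in constants.items():
        if callable(arg):
            raise TypeError(f"Callable constexpr at index {i} is not supported")


def scalar_fn(x):
    # Hypothetical stand-in for a JITFunction; any callable triggers the guard.
    return x + 1


# Plain values pass the check without raising.
check_constants({2: 1024})

# A callable at index 3 produces the same message as in the traceback.
try:
    check_constants({2: 1024, 3: scalar_fn})
    message = None
except TypeError as exc:
    message = str(exc)
```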
However, commenting out these lines makes it work as expected.
I think a JITFunction passed to another JITFunction is always a constexpr (in the sense that it is not a parameter to the CompiledKernel). Recent PRs have also included the JITFunction arguments passed to `tl.reduce` or `tl.associative_scan` in the function body in the cache key (#3137). So here is my question:
Why are only Triton builtins like `tl.reduce` and `tl.associative_scan` allowed to take JITFunctions as arguments, while JITFunctions are not allowed to take JITFunctions as constant arguments? Allowing this would make higher-order JIT functions easier to write in Triton.
Suppose I want to write a utility that generates a `_wrapper` function and a `_jit_function` (like the example above) given a `_scalar_fn`. Because Triton's JITFunction relies on `inspect` to get the source of the function that `triton.jit` decorates, I have to write the generated code to a file and import it dynamically (via `importlib` or some kind of `exec`).
However, since a JITFunction cannot take another JITFunction as an argument, there is no higher-order JITFunction. So if I want to generate a JITFunction that calls the user-provided scalar function, I have to make the scalar function available in the module where the generated `_jit_function` is defined. Finding where `_scalar_fn` is defined and generating code to import it into the generated module is difficult, since `_scalar_fn` may be defined in a module whose path is known, or just in a script (`__name__ == "__main__"`) or an interactive REPL.
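The file-based workaround described above might look roughly like the following sketch. It uses plain Python functions in place of `@triton.jit` functions so that it runs without a GPU; `load_generated_module`, `generated_kernels`, and the generated source are hypothetical names chosen for illustration. Note that the scalar function has to be embedded in the generated source, which is exactly the constraint being discussed.

```python
import importlib.util
import inspect
import os
import sys
import tempfile
import textwrap


def load_generated_module(source, module_name="generated_kernels"):
    """Write `source` to a real .py file and import it.

    The file is kept on disk on purpose: inspect.getsource reads it later.
    """
    tmp_dir = tempfile.mkdtemp()
    path = os.path.join(tmp_dir, f"{module_name}.py")
    with open(path, "w") as f:
        f.write(source)
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module  # register before executing the module
    spec.loader.exec_module(module)
    return module


# Hypothetical generated code. A real generator would emit @triton.jit
# functions; the scalar function must be defined (or imported) in this source.
generated_source = textwrap.dedent(
    '''
    def scalar_fn(x):
        return x * 2

    def pointwise(xs):
        return [scalar_fn(x) for x in xs]
    '''
)

module = load_generated_module(generated_source)
result = module.pointwise([1, 2, 3])
# Source lookup works because the function lives in a file on disk.
source_text = inspect.getsource(module.pointwise)
```

If `scalar_fn` were instead defined in `__main__` or a REPL, the generator would have no reliable import path to emit for it, which is the difficulty the paragraph above points out.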
The one problem I can think of with allowing JITFunctions as parameters to another JITFunction is that the cache key of a JITFunction would have to be completed at `run`, instead of being computed only from the body of the JITFunction itself, which can be done in `__init__`.
The relation between a higher-order JITFunction and a non-higher-order JITFunction resembles that between function templates and functions in C++. When a higher-order JITFunction has all of its JITFunction parameters specified, it is complete (instantiated) in some sense.
So maybe it would be convenient to add an interface to instantiate a higher-order JITFunction into a non-higher-order JITFunction, to avoid analysing the cache key of a higher-order JITFunction at each `run`? @ptillet
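As a rough illustration of what such an interface could look like, here is a pure-Python sketch. Everything in it is hypothetical: `SketchJITFunction` and `instantiate` are not part of Triton, and hashing the bytecode is only a stand-in for the source-based key a real JITFunction would compute. The point is that binding the function-valued parameters once yields an object whose cache key is complete at construction time.

```python
import hashlib


class SketchJITFunction:
    """Hypothetical stand-in for a JIT function; not Triton's JITFunction."""

    def __init__(self, fn, bound=None):
        self.fn = fn
        self.bound = dict(bound or {})
        # The key is computed once here: the function's own code plus the
        # keys of every bound function argument, in a stable order.
        code = fn.__code__
        hasher = hashlib.sha256(code.co_code + repr(code.co_consts).encode())
        for name in sorted(self.bound):
            hasher.update(name.encode())
            hasher.update(self.bound[name].cache_key.encode())
        self.cache_key = hasher.hexdigest()

    def instantiate(self, **fn_args):
        """Bind function-valued parameters and return a new, complete object."""
        return SketchJITFunction(self.fn, {**self.bound, **fn_args})

    def __call__(self, *args):
        # Bound functions are passed as keyword arguments to the wrapped fn.
        return self.fn(*args, **self.bound)


def add_one(x):
    return x + 1


def double(x):
    return x * 2


def pointwise(xs, scalar_fn):
    return [scalar_fn(x) for x in xs]


template = SketchJITFunction(pointwise)  # the "function template"
inc = template.instantiate(scalar_fn=SketchJITFunction(add_one))
dbl = template.instantiate(scalar_fn=SketchJITFunction(double))
```

With this shape, `inc` and `dbl` carry different cache keys although they share the same body, and no key has to be recomputed on each call.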
Thank you
iclementine changed the title from "Why not allow JITFunction as parameter to another JITFunction?" to "Why not allow JITFunction as parameter to another JITFunction(high-order jit function)?" on May 15, 2024