In #3595 we are seeing that large matmul problems use `int64_t` indexing even when all global memory transfers are done with TMA instead of vectorized accesses. Since TMA can use 2D indexing in these cases, it is usually safe to use `int32_t` indexing instead.
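To make the distinction concrete, here is a minimal sketch comparing the largest flat element offset of a contiguous row-major 2D operand with the largest per-dimension coordinate. It assumes that with 2D TMA indexing the kernel only computes coordinates and the multiplication by the row stride is handled by the tensor map. The helper names (`fitsInt32`, `maxLinearOffset`, `maxTmaCoordinate`) are hypothetical and are not nvFuser APIs.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical helpers for illustration only; not nvFuser APIs.

// True if v can be represented as a signed 32-bit integer.
bool fitsInt32(int64_t v) {
  return v >= std::numeric_limits<int32_t>::min() &&
         v <= std::numeric_limits<int32_t>::max();
}

// Largest flat element offset in a contiguous row-major [rows, cols] tensor:
// what a vectorized access addressing memory by a linear index must represent.
int64_t maxLinearOffset(int64_t rows, int64_t cols) {
  return (rows - 1) * cols + (cols - 1);
}

// Largest per-dimension coordinate: what a 2D TMA access must represent,
// assuming the tensor map (not kernel code) applies the row stride.
int64_t maxTmaCoordinate(int64_t rows, int64_t cols) {
  return std::max(rows, cols) - 1;
}
```

For a 65536 × 65536 operand the flat offset reaches 4294967295, well above `INT32_MAX` (2147483647), while each coordinate stays at most 65535.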
Currently we compute the index type by looking at all the input tensors in a `KernelArgumentHolder` and finding the largest index that could be used to address an element of any of those tensors. Ideally, we would instead bound each expression in the lowered kernel and, if all of those bounds are within the range of an `int32_t`, use `Int32` as the index type. To do this, we could implement some limited interval arithmetic on `Val*` and evaluate bounds for all scalars in the kernel, stopping as soon as an upper bound indicates overflow.
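The interval arithmetic described above could start as small as the sketch below: closed integer intervals with addition and multiplication, where any result outside the `int32_t` range returns `std::nullopt`, meaning "stop and fall back to `Int64`". The `Interval` type and the `checked`, `add`, and `mul` helpers are hypothetical and are not existing nvFuser types; a real pass would apply them to the scalar `Val*`s of the lowered kernel rather than to raw integers.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

// Closed integer interval [lo, hi]. Hypothetical sketch, not an nvFuser type.
struct Interval {
  int64_t lo;
  int64_t hi;
};

constexpr int64_t kInt32Min = std::numeric_limits<int32_t>::min();
constexpr int64_t kInt32Max = std::numeric_limits<int32_t>::max();

// Returns the interval if it lies inside the int32_t range, else nullopt
// ("this expression may overflow Int32; stop and use Int64").
std::optional<Interval> checked(Interval r) {
  if (r.lo < kInt32Min || r.hi > kInt32Max) {
    return std::nullopt;
  }
  return r;
}

// Operands are assumed to be within int32_t range already (otherwise the
// analysis would have stopped), so every sum and product fits in int64_t.
std::optional<Interval> add(Interval a, Interval b) {
  return checked({a.lo + b.lo, a.hi + b.hi});
}

std::optional<Interval> mul(Interval a, Interval b) {
  // Signs may be mixed, so the extremes are among the four corner products.
  const int64_t p[] = {a.lo * b.lo, a.lo * b.hi, a.hi * b.lo, a.hi * b.hi};
  return checked({*std::min_element(p, p + 4), *std::max_element(p, p + 4)});
}
```

For example, bounding `row * stride + col` with `row` in [0, 4095], `stride` fixed at 8192 and `col` in [0, 8191] gives [0, 33554431], which fits in `Int32`, whereas `row` in [0, 65535] with `stride` 65536 already fails at the multiplication. Checking after every operation is what lets the analysis stop early without ever overflowing its own `int64_t` arithmetic.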
This type of analysis could also allow us to mix index types within the kernel by setting the dtype to `Int32` for tensors that we bound below the overflow threshold.
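A per-tensor choice could be made with a bound like the one sketched below: the largest reachable element offset is the sum of `(size - 1) * |stride|` over dimensions, and a tensor whose bound fits in `int32_t` could be indexed with `Int32`. Taking the maximum of this bound over all inputs is roughly the whole-kernel decision described in the issue. `TensorInfo`, `IndexType`, `maxOffset`, and `indexTypeFor` are hypothetical names for illustration, not nvFuser APIs, and the sketch assumes the bound itself fits in `int64_t`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical stand-in for a tensor's metadata; not an nvFuser type.
struct TensorInfo {
  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;  // in elements
};

enum class IndexType { Int32, Int64 };

// Largest element offset reachable in the tensor: sum of (size - 1) * |stride|.
// Assumes the result itself fits in int64_t.
int64_t maxOffset(const TensorInfo& t) {
  int64_t total = 0;
  for (std::size_t i = 0; i < t.sizes.size(); ++i) {
    if (t.sizes[i] == 0) {
      return 0;  // empty tensor: no element is ever indexed
    }
    const int64_t stride = t.strides[i] < 0 ? -t.strides[i] : t.strides[i];
    total += (t.sizes[i] - 1) * stride;
  }
  return total;
}

// Per-tensor index type: Int32 when every reachable offset fits.
IndexType indexTypeFor(const TensorInfo& t) {
  return maxOffset(t) <= std::numeric_limits<int32_t>::max() ? IndexType::Int32
                                                             : IndexType::Int64;
}
```

With this, a [4096, 8192] operand (bound 33554431) could keep `Int32` indexing even if another [65536, 65536] operand in the same kernel (bound 4294967295) needs `Int64`.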