Skip to content

Commit

Permalink
Add paged flash decoding kernel
Browse files Browse the repository at this point in the history
Signed-off-by: Harsh Menon <[email protected]>
  • Loading branch information
harsh-nod committed Dec 12, 2024
1 parent 71eb1c8 commit cba32c1
Show file tree
Hide file tree
Showing 10 changed files with 1,220 additions and 622 deletions.
24 changes: 21 additions & 3 deletions iree/turbine/kernel/compiler/kernel_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,27 @@ def sym_to_dim_asm(s: IndexSymbol) -> str:
else:
# Unranked. Not well supported, but for completeness.
spec_asm = element_type_asm
strides = strides_from_symbolic_shape(
idx_context, kb_t.symbolic_shape, allow_mixed_shapes=True
)
# If strides have been specified in the type, that implies that they are
# not consistent with the dimensions of the tensor, so we default to
# dynamic dims for all shapes.
ref_type = self.reference[1].type
if ref_type.physical_layout:
# Strides are always present in the physical layout.
strides = [
idx_context.get_static_value(s)
for s in ref_type.physical_layout["stride"]
]
# Shapes are not always present in the physical layout.
if ref_type.physical_layout.get("shape", None):
shape_asm = "x".join(
sym_to_dim_asm(s) for s in ref_type.physical_layout["shape"]
)
spec_asm = f"{shape_asm}x{element_type_asm}"
else:
strides = strides_from_symbolic_shape(
idx_context, kb_t.symbolic_shape, allow_mixed_shapes=True
)

if strides is None:
memref_asm = f"memref<{spec_asm}>"
elif _is_symbolic(strides):
Expand Down
4 changes: 4 additions & 0 deletions iree/turbine/kernel/lang/kernel_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,23 @@ def new_subtype(
address_space: AddressSpace | NotSetType = NotSet,
symbolic_shape: tuple[IndexExpr, ...] | NotSetType = NotSet,
dtype: DataType | NotSetType = NotSet,
physical_layout: dict[str, IndexExpr] | NotSetType = NotSet,
usage: KernelBufferUsage | NotSetType = NotSet,
) -> Type[SubtypeT]:
init_address_space = (
address_space if address_space else AddressSpace.GLOBAL_MEMORY
)
init_symbolic_shape = symbolic_shape if symbolic_shape is not NotSet else cls.symbolic_shape # type: ignore
init_dtype = dtype if dtype is not NotSet else cls.dtype # type: ignore
init_physical_layout = physical_layout if physical_layout else None # type: ignore
init_usage = usage if usage is not NotSet else cls.usage # type: ignore

class SubType(cls):
address_space = init_address_space
symbolic_shape = init_symbolic_shape
rank = len(init_symbolic_shape) # type: ignore
dtype = init_dtype
physical_layout = init_physical_layout
usage = init_usage

if name is not NotSet:
Expand All @@ -104,6 +107,7 @@ class KernelBuffer(metaclass=KernelBufferMeta):
symbolic_shape: ClassVar[tuple[IndexExpr, ...]]
rank: ClassVar[int]
dtype: ClassVar[DataType]
stride: ClassVar[tuple[IndexExpr, ...]]

def __init__(self, tensor: torch.Tensor):
assert isinstance(tensor, torch.Tensor), f"Expected Tensor but got {tensor}"
Expand Down
16 changes: 12 additions & 4 deletions iree/turbine/kernel/lang/wave_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .._support.dtype import DataType
from .._support.indexing import IndexExpr, IndexSymbol, index_symbol

from sympy import Symbol
from sympy import Symbol, Integer
from sympy.core.expr import Expr
from typing_extensions import Self

Expand All @@ -41,6 +41,7 @@ class Memory(metaclass=KernelBufferMeta):
symbolic_shape: ClassVar[tuple[IndexExpr, ...]]
rank: ClassVar[int]
dtype: ClassVar[DataType]
physical_layout: ClassVar[Optional[dict[str, IndexExpr]]]
usage: ClassVar[Optional[KernelBufferUsage]]

def __init__(self) -> None:
Expand All @@ -55,9 +56,15 @@ def __class_getitem__(

shift = 0
usage = KernelBufferUsage.NONE
if isinstance(shape_and_dtype[-1], KernelBufferUsage):
shift = 1
usage = shape_and_dtype[-1]
last_dim = -1
if isinstance(shape_and_dtype[last_dim], KernelBufferUsage):
shift += 1
usage = shape_and_dtype[last_dim]
last_dim -= 1
physical_layout = None
if isinstance(shape_and_dtype[last_dim], dict):
shift += 1
physical_layout = shape_and_dtype[last_dim]
shape = shape_and_dtype[: -2 - shift]
addressSpace = shape_and_dtype[-2 - shift]
dtype = shape_and_dtype[-1 - shift]
Expand Down Expand Up @@ -85,6 +92,7 @@ def __class_getitem__(
address_space=addressSpace,
symbolic_shape=shape,
dtype=dtype,
physical_layout=physical_layout,
usage=usage,
)

Expand Down
29 changes: 26 additions & 3 deletions iree/turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def read(
elements_per_thread: Optional[IndexExpr | int] = None,
mapping: Optional[IndexMapping] = None,
mapping_dynamic_vals: "Register" | tuple["Register", ...] = (),
drop_dims: Optional[IndexExpr] = (),
) -> "Register":
...

Expand Down Expand Up @@ -461,7 +462,7 @@ def node_args(self) -> dict[int, Any]:
for i, arg in enumerate(self.fx_node.args):
if isinstance(arg, fx.Node):
custom_args[i] = get_custom(arg)
if isinstance(arg, list) and all(isinstance(x, fx.Node) for x in arg):
if isinstance(arg, Sequence) and all(isinstance(x, fx.Node) for x in arg):
custom_args[i] = [get_custom(x) for x in arg]
return custom_args

Expand Down Expand Up @@ -965,6 +966,22 @@ class Read(CustomOp):
mapping_dynamic_vals: tuple[fx.Node, ...] = ()
_write_dependency: Optional[list[fx.Node]] = None

"""
Note on drop_dims.
Consider the following loop:
for b in range(B):
for k1 in range(K1):
for k2 in range(K2):
out[b, k1, k2] = in[b, 0, k1, k2]
This is a slice where the output is a 3D tensor and the input is a 4D tensor.
The index mapping does not allow rank-reducing operations, since every symbol in the output must be
bound to an index variable. So we introduce a drop_dims field to specify which dimensions are dropped
after the mapping.
"""

@property
def indexing_dims(self) -> list[IndexSymbol]:
if self.mapping is not None:
Expand Down Expand Up @@ -1013,7 +1030,10 @@ def transform_index_backwards(
iters = self.mapping.iters
mapping = self.mapping.dynamic_val_mappings[i]
subs = {v: k for k, v in zip(iters, mapping.keys())}
return {k: v.apply_expr(subs[k], mapping[k]) for k, v in index.items()}
return {
k: v.apply_expr(subs[k], mapping[k]) if k in mapping else v
for k, v in index.items()
}

return index

Expand Down Expand Up @@ -1253,7 +1273,10 @@ def transform_index_backwards(
iters = self.mapping.iters
mapping = self.mapping.dynamic_val_mappings[i]
subs = {v: k for k, v in zip(iters, mapping.keys())}
return {k: v.apply_expr(subs[k], mapping[k]) for k, v in index.items()}
return {
k: v.apply_expr(subs[k], mapping[k]) if k in mapping else v
for k, v in index.items()
}

return index

Expand Down
Loading

0 comments on commit cba32c1

Please sign in to comment.