Flash paged decoding #325

Open · wants to merge 1 commit into main
24 changes: 21 additions & 3 deletions iree/turbine/kernel/compiler/kernel_codegen.py
@@ -129,9 +129,27 @@ def sym_to_dim_asm(s: IndexSymbol) -> str:
else:
# Unranked. Not well supported, but for completeness.
spec_asm = element_type_asm
strides = strides_from_symbolic_shape(
idx_context, kb_t.symbolic_shape, allow_mixed_shapes=True
)
# If strides have been specified in the type, that implies they are not
# consistent with the dimensions of the tensor, so we default to dynamic
# dims for all shapes.
ref_type = self.reference[1].type
if ref_type.physical_layout:
# Strides are always present in the physical layout.
strides = [
idx_context.get_static_value(s)
for s in ref_type.physical_layout["stride"]
]
# Shapes are not always present in the physical layout.
if ref_type.physical_layout.get("shape", None):
shape_asm = "x".join(
sym_to_dim_asm(s) for s in ref_type.physical_layout["shape"]
)
spec_asm = f"{shape_asm}x{element_type_asm}"
else:
strides = strides_from_symbolic_shape(
idx_context, kb_t.symbolic_shape, allow_mixed_shapes=True
)

if strides is None:
memref_asm = f"memref<{spec_asm}>"
elif _is_symbolic(strides):
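To make the new branch above easier to follow, here is a minimal stand-alone restatement of its effect in plain Python. The function name, the resolve callback (a stand-in for idx_context.get_static_value), and the literal values are illustrative only, not part of this PR:

def spec_from_layout(prev_spec_asm, element_type_asm, physical_layout, resolve):
    # Strides are always present in the layout dict; resolve() returns the
    # static value of a symbol, or None when it is dynamic.
    strides = [resolve(s) for s in physical_layout["stride"]]
    # The shape part of the memref spec is rebuilt only if the layout also
    # names a shape; otherwise the previously computed spec is kept.
    spec_asm = prev_spec_asm
    if physical_layout.get("shape", None):
        shape_asm = "x".join(
            "?" if resolve(s) is None else str(resolve(s))
            for s in physical_layout["shape"]
        )
        spec_asm = f"{shape_asm}x{element_type_asm}"
    return spec_asm, strides

# spec_from_layout("?x?xf16", "f16",
#                  {"stride": ("N", 1), "shape": ("B", "N")},
#                  {"B": None, "N": 128, 1: 1}.get)
# -> ("?x128xf16", [128, 1])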
4 changes: 4 additions & 0 deletions iree/turbine/kernel/lang/kernel_buffer.py
@@ -64,20 +64,23 @@ def new_subtype(
address_space: AddressSpace | NotSetType = NotSet,
symbolic_shape: tuple[IndexExpr, ...] | NotSetType = NotSet,
dtype: DataType | NotSetType = NotSet,
physical_layout: dict[str, IndexExpr] | NotSetType = NotSet,
usage: KernelBufferUsage | NotSetType = NotSet,
) -> Type[SubtypeT]:
init_address_space = (
address_space if address_space else AddressSpace.GLOBAL_MEMORY
)
init_symbolic_shape = symbolic_shape if symbolic_shape is not NotSet else cls.symbolic_shape # type: ignore
init_dtype = dtype if dtype is not NotSet else cls.dtype # type: ignore
init_physical_layout = physical_layout if physical_layout else None # type: ignore
init_usage = usage if usage is not NotSet else cls.usage # type: ignore

class SubType(cls):
address_space = init_address_space
symbolic_shape = init_symbolic_shape
rank = len(init_symbolic_shape) # type: ignore
dtype = init_dtype
physical_layout = init_physical_layout
usage = init_usage

if name is not NotSet:
@@ -104,6 +107,7 @@ class KernelBuffer(metaclass=KernelBufferMeta):
symbolic_shape: ClassVar[tuple[IndexExpr, ...]]
rank: ClassVar[int]
dtype: ClassVar[DataType]
stride: ClassVar[tuple[IndexExpr, ...]]

def __init__(self, tensor: torch.Tensor):
assert isinstance(tensor, torch.Tensor), f"Expected Tensor but got {tensor}"
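With the extra parameter above, a layout dict passed to new_subtype simply becomes a class attribute on the generated subtype. A hedged sketch of that plumbing, assuming new_subtype is reachable from Memory the same way __class_getitem__ in wave_types.py reaches it; the symbol names and the PagedKV type name are made up for illustration:

import iree.turbine.kernel.lang as tkl

B, N = tkl.sym.B, tkl.sym.N
# Sketch under the assumptions stated above -- not code from this PR.
PagedKV = tkl.Memory.new_subtype(
    address_space=tkl.AddressSpace.GLOBAL_MEMORY,
    symbolic_shape=(B, N),
    dtype=tkl.f16,
    physical_layout={"stride": (N, 1)},
)
assert PagedKV.physical_layout == {"stride": (N, 1)}
assert PagedKV.rank == 2  # rank still follows the symbolic shape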
16 changes: 12 additions & 4 deletions iree/turbine/kernel/lang/wave_types.py
@@ -14,7 +14,7 @@
from .._support.dtype import DataType
from .._support.indexing import IndexExpr, IndexSymbol, index_symbol

from sympy import Symbol
from sympy import Symbol, Integer
from sympy.core.expr import Expr
from typing_extensions import Self

@@ -41,6 +41,7 @@ class Memory(metaclass=KernelBufferMeta):
symbolic_shape: ClassVar[tuple[IndexExpr, ...]]
rank: ClassVar[int]
dtype: ClassVar[DataType]
physical_layout: ClassVar[Optional[dict[str, IndexExpr]]]
usage: ClassVar[Optional[KernelBufferUsage]]

def __init__(self) -> None:
@@ -55,9 +56,15 @@ def __class_getitem__(

shift = 0
usage = KernelBufferUsage.NONE
if isinstance(shape_and_dtype[-1], KernelBufferUsage):
shift = 1
usage = shape_and_dtype[-1]
last_dim = -1
if isinstance(shape_and_dtype[last_dim], KernelBufferUsage):
shift += 1
usage = shape_and_dtype[last_dim]
last_dim -= 1
physical_layout = None
if isinstance(shape_and_dtype[last_dim], dict):
shift += 1
physical_layout = shape_and_dtype[last_dim]
shape = shape_and_dtype[: -2 - shift]
addressSpace = shape_and_dtype[-2 - shift]
dtype = shape_and_dtype[-1 - shift]
@@ -85,6 +92,7 @@ def __class_getitem__(
address_space=addressSpace,
symbolic_shape=shape,
dtype=dtype,
physical_layout=physical_layout,
usage=usage,
)

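Because the subscript parser above now peels an optional trailing dict (before checking for an optional usage flag), a physical layout can be attached directly in the type expression. A hedged usage sketch; the module alias, symbol names, and the INPUT usage flag follow common wave examples and are assumptions rather than code from this PR:

import iree.turbine.kernel.lang as tkl
from iree.turbine.kernel.lang.kernel_buffer import KernelBufferUsage

B, N = tkl.sym.B, tkl.sym.N
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE  # bound to a real address space later

# Subscript order implied by the parsing above:
#   *shape, address_space, dtype, [physical_layout dict], [usage flag]
PagedMem = tkl.Memory[
    B, N, ADDRESS_SPACE, tkl.f16, {"stride": (N, 1)}, KernelBufferUsage.INPUT
]
assert PagedMem.physical_layout == {"stride": (N, 1)}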
29 changes: 26 additions & 3 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -70,6 +70,7 @@ def read(
elements_per_thread: Optional[IndexExpr | int] = None,
mapping: Optional[IndexMapping] = None,
mapping_dynamic_vals: "Register" | tuple["Register", ...] = (),
drop_dims: Optional[IndexExpr] = (),
) -> "Register":
...

@@ -461,7 +462,7 @@ def node_args(self) -> dict[int, Any]:
for i, arg in enumerate(self.fx_node.args):
if isinstance(arg, fx.Node):
custom_args[i] = get_custom(arg)
if isinstance(arg, list) and all(isinstance(x, fx.Node) for x in arg):
if isinstance(arg, Sequence) and all(isinstance(x, fx.Node) for x in arg):
custom_args[i] = [get_custom(x) for x in arg]
return custom_args

@@ -965,6 +966,22 @@ class Read(CustomOp):
mapping_dynamic_vals: tuple[fx.Node, ...] = ()
_write_dependency: Optional[list[fx.Node]] = None

"""
Note on drop_dims.
Consider the following loop:

for b in range(B):
for k1 in range(K1):
for k2 in range(K2):
out[b, k1, k2] = in[b, 0, k1, k2]

This is a slice where the output is a 3D tensor and the input is a 4D tensor.
The index mapping does not allow rank-reducing operations, since every symbol in the output must be
bound to an index variable. So we introduce a drop_dims field to specify which dimensions are dropped
after the mapping.

"""

@property
def indexing_dims(self) -> list[IndexSymbol]:
if self.mapping is not None:
@@ -1013,7 +1030,10 @@ def transform_index_backwards(
iters = self.mapping.iters
mapping = self.mapping.dynamic_val_mappings[i]
subs = {v: k for k, v in zip(iters, mapping.keys())}
return {k: v.apply_expr(subs[k], mapping[k]) for k, v in index.items()}
return {
k: v.apply_expr(subs[k], mapping[k]) if k in mapping else v
for k, v in index.items()
}

return index
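The only behavioural change in this hunk (and in the identical hunk for the write path below) is the `if k in mapping else v` guard: index entries whose dimension the dynamic-value mapping does not cover are now passed through unchanged instead of raising a KeyError on subs[k]. A toy restatement with plain strings standing in for index sequences and expressions:

index = {"B": "b_seq", "K2": "k2_seq"}   # per-dimension index sequences
mapping = {"K2": "k2_expr"}              # dynamic-val mapping omits "B"
subs = {"K2": "k2_iter"}

transformed = {
    k: f"apply_expr({subs[k]}, {mapping[k]})" if k in mapping else v
    for k, v in index.items()
}
assert transformed == {"B": "b_seq", "K2": "apply_expr(k2_iter, k2_expr)"}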

@@ -1253,7 +1273,10 @@ def transform_index_backwards(
iters = self.mapping.iters
mapping = self.mapping.dynamic_val_mappings[i]
subs = {v: k for k, v in zip(iters, mapping.keys())}
return {k: v.apply_expr(subs[k], mapping[k]) for k, v in index.items()}
return {
k: v.apply_expr(subs[k], mapping[k]) if k in mapping else v
for k, v in index.items()
}

return index

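Relating back to the drop_dims note added to the Read class above: the operation being enabled is a rank-reducing slice, which the index mapping alone cannot express. A plain NumPy restatement of the loop from that note (the concrete sizes are arbitrary):

import numpy as np

B, K1, K2 = 2, 3, 4
inp = np.arange(B * 1 * K1 * K2).reshape(B, 1, K1, K2)

# out[b, k1, k2] = in[b, 0, k1, k2]: the singleton dimension is indexed at a
# fixed 0 and dropped from the result, which is what drop_dims records.
out = inp[:, 0, :, :]
assert out.shape == (B, K1, K2)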