Skip to content

Commit

Permalink
[TKW] Add support for tkw.round_even
Browse files Browse the repository at this point in the history
Signed-off-by: Ege Beysel <[email protected]>
  • Loading branch information
egebeysel committed Jan 14, 2025
1 parent d759cb5 commit 6b02ab5
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 4 deletions.
5 changes: 5 additions & 0 deletions iree/turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ def reshape(inputs: Sequence["Register"]) -> "Register":
...


def round_even(src: "Register") -> "Register":
...


def define_op(op_name: str) -> Callable[[T], T]:
def decorator(cls: T) -> T:
cls.tkw_op_name = op_name
Expand Down Expand Up @@ -704,6 +708,7 @@ def infer_type(self):
@define_interface_op("exp2")
@define_interface_op("reciprocal")
@define_interface_op("abs")
@define_interface_op("round_even")
@define_py_op(operator.neg)
@dataclass
class UnaryPyOp(CustomOp, ABC):
Expand Down
11 changes: 11 additions & 0 deletions iree/turbine/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
cast,
permute,
reshape,
round_even,
)
from ..lang.wave_types import IndexMapping, IndexSymbol
from ..compiler.base import CodegenError, ValidationError, NDEBUG
Expand Down Expand Up @@ -1197,6 +1198,16 @@ def handle_abs(source: Value) -> OpResult:
return abs


@handle_unary_op(round_even)
def handle_round_even(source: Value) -> OpResult:
element_type = get_type_or_element_type(source.type)
if _is_float_type(element_type):
round_even = math_d.roundeven(source)
else:
raise ValidationError(f"Found unhandled operand type for abs: {element_type}")
return round_even


###############################################################################
# Control Flow ops
###############################################################################
Expand Down
12 changes: 8 additions & 4 deletions lit_tests/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,7 @@ def test(
res = tkw.reciprocal(res)
res = tkw.abs(res)
res_b = tkw.abs(b_reg)
res = tkw.round_even(res)
tkw.write(res, a, elements_per_thread=4)
tkw.write(res_b, b, elements_per_thread=4)

Expand All @@ -740,12 +741,15 @@ def test(
# CHECK: %[[EXP2:.+]] = math.exp2 %[[NEG]]

# Testing reciprocal
# %[[ONES:.+]] = arith.constant dense<1.000000e+00> : vector<4xf16>
# %[[RECIPROCAL:.+]] = arith.divf %[[ONES]], %[[EXP2]] : vector<4xf16>
# CHECK: %[[ONES:.+]] = arith.constant dense<1.000000e+00> : vector<4xf16>
# CHECK: %[[RECIPROCAL:.+]] = arith.divf %[[ONES]], %[[EXP2]] : vector<4xf16>

# Testing abs
# %[[ABSF:.+]] = math.absf %[[RECIPROCAL]]
# %[[ABSI:.+]] = math.absi
# CHECK: %[[ABSF:.+]] = math.absf %[[RECIPROCAL]]
# CHECK: %[[ABSI:.+]] = math.absi

# Testing round_even
# CHECK: %[[ROUNDEVEN:.+]] = math.roundeven %[[ABSF]]


@run_test
Expand Down

0 comments on commit 6b02ab5

Please sign in to comment.