Skip to content

Commit

Permalink
[AOT][ExportProc] Use annotation for extra signature details.
Browse files Browse the repository at this point in the history
  • Loading branch information
raikonenfnu committed Sep 23, 2023
1 parent d80899a commit 8667f6a
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 24 deletions.
7 changes: 5 additions & 2 deletions python/shark_turbine/aot/compiled_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,11 @@ def def_export_proc(self, name, f) -> ExportProcDef:
raise TypeError(
f"exported functions only support positional parameters"
)
param_desc = param.default
if param_desc is inspect.Parameter.empty:
if param.default is not inspect.Parameter.empty:
param_desc = param.default
elif param.annotation is not inspect.Parameter.empty:
param_desc = param.annotation
else:
# TODO: Merge from a decorator?
raise TypeError(
f"export function {name} missing required default value annotation "
Expand Down
44 changes: 22 additions & 22 deletions tests/aot/iree_procedural_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
class CompiledModuleAPI(unittest.TestCase):
def testTensorDim(self):
class BasicModule(CompiledModule):
def foobar(self, a=AbstractTensor(None, 3)):
def foobar(self, a: AbstractTensor(None, 3)):
return IREE.tensor_dim(a, 0)

inst = BasicModule(context=Context())
Expand All @@ -31,7 +31,7 @@ def foobar(self, a=AbstractTensor(None, 3)):

def testTensorEmpty(self):
class BasicModule(CompiledModule):
def foobar(self, x=AbstractIndex):
def foobar(self, x: AbstractIndex):
empty = IREE.tensor_empty(x, 16)
dim0 = IREE.tensor_dim(empty, 0)
return empty, dim0
Expand All @@ -46,7 +46,7 @@ def foobar(self, x=AbstractIndex):

def testTensorSplat(self):
class BasicModule(CompiledModule):
def foobar(self, x=AbstractIndex, y=AbstractF32):
def foobar(self, x: AbstractIndex, y: AbstractF32):
empty = IREE.tensor_splat(x, 34, value=y, dtype=torch.float32)
dim0 = IREE.tensor_dim(empty, 0)
return empty, dim0
Expand All @@ -63,7 +63,7 @@ def foobar(self, x=AbstractIndex, y=AbstractF32):

def testTensorTrace(self):
class BasicModule(CompiledModule):
def foobar(self, x=AbstractTensor(None), y=AbstractTensor(3)):
def foobar(self, x: AbstractTensor(None), y: AbstractTensor(3)):
IREE.tensor_trace("DEBUG", x, y)

inst = BasicModule(context=Context())
Expand All @@ -75,7 +75,7 @@ def testStoreDynamic(self):
class BasicModule(CompiledModule):
x = export_global(AbstractTensor(None, 34), mutable=True)

def foobar(self, x=AbstractIndex, y=AbstractF32):
def foobar(self, x: AbstractIndex, y: AbstractF32):
splat = IREE.tensor_splat(x, 34, value=y, dtype=torch.float32)
self.x = splat

Expand All @@ -91,7 +91,7 @@ def foobar(self, x=AbstractIndex, y=AbstractF32):

def testTensorSliceStatic(self):
class BasicModule(CompiledModule):
def foobar(self, x=AbstractTensor(3, 4)):
def foobar(self, x: AbstractTensor(3, 4)):
return IREE.tensor_slice(x, 0, (1, 3))

inst = BasicModule(context=Context())
Expand All @@ -104,7 +104,7 @@ def foobar(self, x=AbstractTensor(3, 4)):

def testTensorSliceDynamicIndex(self):
class SliceDynamicIndex(CompiledModule):
def foobar(self, x=AbstractIndex):
def foobar(self, x: AbstractIndex):
empty = IREE.tensor_empty(x, 16)
return IREE.tensor_slice(empty, x, 4)

Expand All @@ -118,7 +118,7 @@ def foobar(self, x=AbstractIndex):

def testTensorSliceDynamicLength(self):
class SliceDynamicIndex(CompiledModule):
def foobar(self, x=AbstractIndex, y=AbstractIndex):
def foobar(self, x: AbstractIndex, y: AbstractIndex):
empty = IREE.tensor_empty(x, 16)
return IREE.tensor_slice(empty, (x, y), 4)

Expand All @@ -134,10 +134,10 @@ def testTensorUpdateStatic(self):
class UpdateStatic(CompiledModule):
def foobar(
self,
target=AbstractTensor(4, 4),
update=AbstractTensor(2, 2),
i=AbstractIndex,
j=AbstractIndex,
target: AbstractTensor(4, 4),
update: AbstractTensor(2, 2),
i: AbstractIndex,
j: AbstractIndex,
):
return IREE.tensor_update(target, update, i, j)

Expand All @@ -153,11 +153,11 @@ def testTensorUpdateDynamic(self):
class UpdateDynamic(CompiledModule):
def foobar(
self,
x=AbstractIndex,
y=AbstractIndex,
i=AbstractIndex,
j=AbstractIndex,
value=AbstractF32,
x: AbstractIndex,
y: AbstractIndex,
i: AbstractIndex,
j: AbstractIndex,
value: AbstractF32,
):
target = IREE.tensor_empty(x, y)
update = IREE.tensor_splat(i, j, value=value, dtype=torch.float32)
Expand All @@ -173,7 +173,7 @@ def foobar(

def testTensorReshape(self):
class ReshapeModule(CompiledModule):
def foobar(self, x=AbstractIndex, y=AbstractIndex):
def foobar(self, x: AbstractIndex, y: AbstractIndex):
empty = IREE.tensor_empty(x, 16)
reshaped = IREE.tensor_reshape(empty, 1, y, y)
return reshaped
Expand All @@ -188,7 +188,7 @@ def foobar(self, x=AbstractIndex, y=AbstractIndex):

def testScalarAddInt(self):
class ArithModule(CompiledModule):
def foobar(self, a=AbstractI32, b=AbstractI32):
def foobar(self, a: AbstractI32, b: AbstractI32):
return a + b

inst = ArithModule(context=Context())
Expand All @@ -197,7 +197,7 @@ def foobar(self, a=AbstractI32, b=AbstractI32):

def testScalarAddFloat(self):
class ArithModule(CompiledModule):
def foobar(self, a=AbstractF32, b=AbstractF32):
def foobar(self, a: AbstractF32, b: AbstractF32):
return a + b

inst = ArithModule(context=Context())
Expand All @@ -206,7 +206,7 @@ def foobar(self, a=AbstractF32, b=AbstractF32):

def testScalarAddLiteral(self):
class ArithModule(CompiledModule):
def foobar(self, a=AbstractI32):
def foobar(self, a: AbstractI32):
return a + 1

inst = ArithModule(context=Context())
Expand All @@ -216,7 +216,7 @@ def foobar(self, a=AbstractI32):

def testScalarAddLiteralMixedType(self):
class ArithModule(CompiledModule):
def foobar(self, a=AbstractI32):
def foobar(self, a: AbstractI32):
return a + 3.23

inst = ArithModule(context=Context())
Expand Down

0 comments on commit 8667f6a

Please sign in to comment.