Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[aot] Proper type checking in C-API header generator #7979

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions misc/generate_c_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@


def get_type_name(x: EntryBase):
ty = type(x)
if ty in [BuiltInType]:
if isinstance(x, (BuiltInType)):
return x.type_name
elif ty in [Alias, Handle, Enumeration, Structure, Union, Callback]:
elif isinstance(x, (Alias, Handle, Enumeration, Structure, Union, Callback)):
return x.name.upper_camel_case
elif ty in [BitField]:
elif isinstance(x, BitField):
return x.name.extend("flags").upper_camel_case
else:
raise RuntimeError(f"'{x.id}' is not a type")
Expand Down Expand Up @@ -67,57 +66,56 @@ def get_declr(module: Module, x: EntryBase, with_docs=False):
if with_docs:
out += get_api_ref(module, x)

ty = type(x)
if ty is BuiltInType:
if isinstance(x, BuiltInType):
out += [""]

elif ty is Alias:
elif isinstance(x, Alias):
out += [f"typedef {get_type_name(x.alias_of)} {get_type_name(x)};"]

elif ty is Definition:
elif isinstance(x, Definition):
out += [f"#define {x.name.screaming_snake_case} {x.value}"]

elif ty is Handle:
elif isinstance(x, Handle):
out += [f"typedef struct {get_type_name(x)}_t* {get_type_name(x)};"]

elif ty is Enumeration:
elif isinstance(x, Enumeration):
out += ["typedef enum " + get_type_name(x) + " {"]
for name, value in x.cases.items():
if with_docs:
out += get_api_field_ref(module, x, name)
out += get_api_field_ref(module, x, name.snake_case)
name = x.name.extend(name).screaming_snake_case
out += [f" {name} = {value},"]
out += [f" {x.name.extend('max_enum').screaming_snake_case} = 0xffffffff,"]
out += ["} " + get_type_name(x) + ";"]

elif ty is BitField:
elif isinstance(x, BitField):
bit_type_name = x.name.extend("flag_bits").upper_camel_case
out += ["typedef enum " + bit_type_name + " {"]
for name, value in x.bits.items():
if with_docs:
out += get_api_field_ref(module, x, name)
out += get_api_field_ref(module, x, name.snake_case)
name = x.name.extend(name).extend("bit").screaming_snake_case
out += [f" {name} = 1 << {value},"]
out += ["} " + bit_type_name + ";"]
out += [f"typedef TiFlags {get_type_name(x)};"]

elif ty is Structure:
elif isinstance(x, Structure):
out += ["typedef struct " + get_type_name(x) + " {"]
for field in x.fields:
if with_docs:
out += get_api_field_ref(module, x, field.name)
out += [f" {get_field(field)};"]
out += ["} " + get_type_name(x) + ";"]

elif ty is Union:
elif isinstance(x, Union):
out += ["typedef union " + get_type_name(x) + " {"]
for variant in x.variants:
if with_docs:
out += get_api_field_ref(module, x, variant.name)
out += [f" {get_field(variant)};"]
out += ["} " + get_type_name(x) + ";"]

elif ty is Callback:
elif isinstance(x, Callback):
return_value_type = "void" if x.return_value_type == None else get_type_name(x.return_value_type)
out += [f"typedef {return_value_type} (TI_API_CALL *{get_type_name(x)})("]
if x.params:
Expand All @@ -129,7 +127,7 @@ def get_declr(module: Module, x: EntryBase, with_docs=False):
out += [f" {get_field(param)}"]
out += [");"]

elif ty is Function:
elif isinstance(x, Function):
return_value_type = "void" if x.return_value_type == None else get_type_name(x.return_value_type)
out += ["TI_DLL_EXPORT " + return_value_type + " TI_API_CALL " + x.name.snake_case + "("]
if x.params:
Expand Down Expand Up @@ -211,6 +209,7 @@ def resolve_symbol_to_name(module: Module, id: str):
pass

out = module.declr_reg.resolve(id)
assert out is not None, f"Unable to resolve symbol {id}"
href = None

try:
Expand Down Expand Up @@ -300,6 +299,7 @@ def print_module_header(module: Module):

for x in module.declr_reg:
declr = module.declr_reg.resolve(x)
assert declr is not None, f"Unable to resolve {x}"
out += ["", get_declr(module, declr, True)]

out += [
Expand Down
34 changes: 16 additions & 18 deletions misc/generate_unity_language_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@


def get_type_name(x: EntryBase):
ty = type(x)
if ty in [BuiltInType]:
if isinstance(x, BuiltInType):
return x.type_name
elif ty in [Alias, Handle, Enumeration, Structure, Union, Callback]:
elif isinstance(x, (Alias, Handle, Enumeration, Structure, Union, Callback)):
return x.name.upper_camel_case
elif ty in [BitField]:
elif isinstance(x, BitField):
return x.name.extend("flag_bits").upper_camel_case
else:
raise RuntimeError(f"'{x.id}' is not a type")
Expand Down Expand Up @@ -74,7 +73,7 @@ def get_struct_field(x: Field):

out = ""
if is_ptr:
out += f"IntPtr {name}"
out += f"public IntPtr {name}"
elif x.count:
out += f"[MarshalAs(UnmanagedType.ByValArray, SizeConst={x.count})] "
out += f"public {get_type_name(x.type)}[] {name}"
Expand All @@ -93,7 +92,7 @@ def get_union_variant(x: Field):

out = "[FieldOffset(0)] "
if is_ptr:
out += f"IntPtr {name}"
out += f"public IntPtr {name}"
elif x.count:
out += f"[MarshalAs(UnmanagedType.ByValArray, SizeConst={x.count})] "
out += f"public {get_type_name(x.type)}[] {name}"
Expand All @@ -103,22 +102,21 @@ def get_union_variant(x: Field):


def get_declr(x: EntryBase):
ty = type(x)
if ty is BuiltInType:
if isinstance(x, BuiltInType):
return ""

elif ty is Alias:
elif isinstance(x, Alias):
return f"// using {get_type_name(x)} = {get_type_name(x.alias_of)};"

elif ty is Definition:
elif isinstance(x, Definition):
out = [
"static partial class Def {",
f"public const uint {x.name.screaming_snake_case} = {x.value};",
"}",
]
return "\n".join(out)

elif ty is Handle:
elif isinstance(x, Handle):
out = [
"[StructLayout(LayoutKind.Sequential)]",
"public struct " + get_type_name(x) + " {",
Expand All @@ -127,7 +125,7 @@ def get_declr(x: EntryBase):
]
return "\n".join(out)

elif ty is Enumeration:
elif isinstance(x, Enumeration):
out = ["public enum " + get_type_name(x) + " {"]
for name, value in x.cases.items():
name = x.name.extend(name).screaming_snake_case
Expand All @@ -136,15 +134,15 @@ def get_declr(x: EntryBase):
out += ["}"]
return "\n".join(out)

elif ty is BitField:
elif isinstance(x, BitField):
out = ["[Flags]", "public enum " + get_type_name(x) + " {"]
for name, value in x.bits.items():
name = x.name.extend(name).extend("bit").screaming_snake_case
out += [f" {name} = 1 << {value},"]
out += ["};"]
return "\n".join(out)

elif ty is Structure:
elif isinstance(x, Structure):
out = [
"[StructLayout(LayoutKind.Sequential)]",
"public struct " + get_type_name(x) + " {",
Expand All @@ -154,7 +152,7 @@ def get_declr(x: EntryBase):
out += ["}"]
return "\n".join(out)

elif ty is Union:
elif isinstance(x, Union):
out = [
"[StructLayout(LayoutKind.Explicit)]",
"public struct " + get_type_name(x) + " {",
Expand All @@ -164,7 +162,7 @@ def get_declr(x: EntryBase):
out += ["}"]
return "\n".join(out)

elif ty is Callback:
elif isinstance(x, Callback):
out = [
"[StructLayout(LayoutKind.Sequential)]",
"public struct " + get_type_name(x) + " {",
Expand All @@ -173,7 +171,7 @@ def get_declr(x: EntryBase):
]
return "\n".join(out)

elif ty is Function:
elif isinstance(x, Function):
out = []

return_value_type = "void" if x.return_value_type == None else get_type_name(x.return_value_type)
Expand Down Expand Up @@ -298,7 +296,7 @@ def generate_module_header(module):
return

print(f"processing module '{module.name}'")
assert re.match("taichi/\w+.h", module.name)
assert re.match(r"taichi/\w+.h", module.name)
module_name = module.name[len("taichi/") : -len(".h")]
path = f"c_api/unity/{module_name}.cs"
with open(path, "w") as f:
Expand Down
26 changes: 17 additions & 9 deletions misc/taichi_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
from collections import defaultdict
from pathlib import Path
from typing import List
from typing import DefaultDict, Dict, List, Optional


class Version:
Expand All @@ -12,7 +12,7 @@ def __init__(self, ver: str) -> None:
ver = ver[1:]
xs = [int(x) for x in ver.split(".")]
assert len(xs) <= 3
xs += ["0"] * (3 - len(xs))
xs += [0] * (3 - len(xs))

self.major = xs[0]
self.minor = xs[1]
Expand Down Expand Up @@ -66,15 +66,15 @@ def __repr__(self) -> str:


class DeclarationRegistry:
current = None
current: "Optional[DeclarationRegistry]" = None

def __init__(self, builtin_tys={}):
# "xxx.yyy" -> Xxx(yyy) Look-up table.
self._inner = {}
self._imported = {}
self._builtin_tys = dict((x.id, x) for x in builtin_tys)

def resolve(self, id: str) -> "EntryBase":
def resolve(self, id: str) -> Optional["EntryBase"]:
if id in self._builtin_tys:
return self._builtin_tys[id]
elif id in self._inner:
Expand All @@ -99,7 +99,7 @@ def set_current(declr_reg):
DeclarationRegistry.current = declr_reg


def load_inc_enums():
def load_inc_enums() -> DefaultDict[str, Dict[Name, int]]:
paths = glob.glob("taichi/inc/*.inc.h")
cases = defaultdict(dict)
for path in paths:
Expand Down Expand Up @@ -161,7 +161,10 @@ def __init__(self, id, type_name):
class Alias(EntryBase):
def __init__(self, j):
super().__init__(j, "alias")
self.alias_of = DeclarationRegistry.current.resolve(j["alias_of"])
assert DeclarationRegistry.current is not None
alias_of = DeclarationRegistry.current.resolve(j["alias_of"])
assert alias_of is not None
self.alias_of = alias_of


class Definition(EntryBase):
Expand All @@ -179,10 +182,12 @@ def __init__(self, j):
class Enumeration(EntryBase):
def __init__(self, j):
super().__init__(j, "enumeration")
cases: Dict[Name, int]
if "inc_cases" in j:
self.cases = load_inc_enums()[j["inc_cases"]]
cases = load_inc_enums()[j["inc_cases"]]
else:
self.cases = dict((Name(name), value) for name, value in j["cases"].items())
cases = dict((Name(name), value) for name, value in j["cases"].items())
self.cases = cases


class BitField(EntryBase):
Expand All @@ -196,6 +201,7 @@ def __init__(self, j):

class Field:
def __init__(self, j):
assert DeclarationRegistry.current is not None
ty = DeclarationRegistry.current.resolve(j["type"])
assert ty != None, f"unknown type '{j['type']}'"
# The type has been registered.
Expand Down Expand Up @@ -268,10 +274,12 @@ def __init__(self, name: str):
with path.open() as f:
templ = f.readlines()

i = 1

# Ignore markdown headers
markdown_metadata = []
if len(templ) > 0 and templ[0].startswith("---"):
for i in range(1, len(templ)):
for i in range(i, len(templ)):
if templ[i].startswith("---"):
i += 1
break
Expand Down