diff --git a/misc/generate_c_api.py b/misc/generate_c_api.py index c9b0b6a311a18..a378712d656f7 100644 --- a/misc/generate_c_api.py +++ b/misc/generate_c_api.py @@ -19,12 +19,11 @@ def get_type_name(x: EntryBase): - ty = type(x) - if ty in [BuiltInType]: + if isinstance(x, (BuiltInType)): return x.type_name - elif ty in [Alias, Handle, Enumeration, Structure, Union, Callback]: + elif isinstance(x, (Alias, Handle, Enumeration, Structure, Union, Callback)): return x.name.upper_camel_case - elif ty in [BitField]: + elif isinstance(x, BitField): return x.name.extend("flags").upper_camel_case else: raise RuntimeError(f"'{x.id}' is not a type") @@ -67,41 +66,40 @@ def get_declr(module: Module, x: EntryBase, with_docs=False): if with_docs: out += get_api_ref(module, x) - ty = type(x) - if ty is BuiltInType: + if isinstance(x, BuiltInType): out += [""] - elif ty is Alias: + elif isinstance(x, Alias): out += [f"typedef {get_type_name(x.alias_of)} {get_type_name(x)};"] - elif ty is Definition: + elif isinstance(x, Definition): out += [f"#define {x.name.screaming_snake_case} {x.value}"] - elif ty is Handle: + elif isinstance(x, Handle): out += [f"typedef struct {get_type_name(x)}_t* {get_type_name(x)};"] - elif ty is Enumeration: + elif isinstance(x, Enumeration): out += ["typedef enum " + get_type_name(x) + " {"] for name, value in x.cases.items(): if with_docs: - out += get_api_field_ref(module, x, name) + out += get_api_field_ref(module, x, name.snake_case) name = x.name.extend(name).screaming_snake_case out += [f" {name} = {value},"] out += [f" {x.name.extend('max_enum').screaming_snake_case} = 0xffffffff,"] out += ["} " + get_type_name(x) + ";"] - elif ty is BitField: + elif isinstance(x, BitField): bit_type_name = x.name.extend("flag_bits").upper_camel_case out += ["typedef enum " + bit_type_name + " {"] for name, value in x.bits.items(): if with_docs: - out += get_api_field_ref(module, x, name) + out += get_api_field_ref(module, x, name.snake_case) name = x.name.extend(name).extend("bit").screaming_snake_case out += [f" {name} = 1 << {value},"] out += ["} " + bit_type_name + ";"] out += [f"typedef TiFlags {get_type_name(x)};"] - elif ty is Structure: + elif isinstance(x, Structure): out += ["typedef struct " + get_type_name(x) + " {"] for field in x.fields: if with_docs: @@ -109,7 +107,7 @@ def get_declr(module: Module, x: EntryBase, with_docs=False): out += [f" {get_field(field)};"] out += ["} " + get_type_name(x) + ";"] - elif ty is Union: + elif isinstance(x, Union): out += ["typedef union " + get_type_name(x) + " {"] for variant in x.variants: if with_docs: @@ -117,7 +115,7 @@ def get_declr(module: Module, x: EntryBase, with_docs=False): out += [f" {get_field(variant)};"] out += ["} " + get_type_name(x) + ";"] - elif ty is Callback: + elif isinstance(x, Callback): return_value_type = "void" if x.return_value_type == None else get_type_name(x.return_value_type) out += [f"typedef {return_value_type} (TI_API_CALL *{get_type_name(x)})("] if x.params: @@ -129,7 +127,7 @@ def get_declr(module: Module, x: EntryBase, with_docs=False): out += [f" {get_field(param)}"] out += [");"] - elif ty is Function: + elif isinstance(x, Function): return_value_type = "void" if x.return_value_type == None else get_type_name(x.return_value_type) out += ["TI_DLL_EXPORT " + return_value_type + " TI_API_CALL " + x.name.snake_case + "("] if x.params: @@ -211,6 +209,7 @@ def resolve_symbol_to_name(module: Module, id: str): pass out = module.declr_reg.resolve(id) + assert out is not None, f"Unable to resolve symbol {id}" href = None try: @@ -300,6 +299,7 @@ def print_module_header(module: Module): for x in module.declr_reg: declr = module.declr_reg.resolve(x) + assert declr is not None, f"Unable to resolve {x}" out += ["", get_declr(module, declr, True)] out += [ diff --git a/misc/generate_unity_language_binding.py b/misc/generate_unity_language_binding.py index fe39e58f63551..d9c89d77aca97 100644 --- a/misc/generate_unity_language_binding.py +++ b/misc/generate_unity_language_binding.py @@ -23,12 +23,11 @@ def get_type_name(x: EntryBase): - ty = type(x) - if ty in [BuiltInType]: + if isinstance(x, BuiltInType): return x.type_name - elif ty in [Alias, Handle, Enumeration, Structure, Union, Callback]: + elif isinstance(x, (Alias, Handle, Enumeration, Structure, Union, Callback)): return x.name.upper_camel_case - elif ty in [BitField]: + elif isinstance(x, BitField): return x.name.extend("flag_bits").upper_camel_case else: raise RuntimeError(f"'{x.id}' is not a type") @@ -74,7 +73,7 @@ def get_struct_field(x: Field): out = "" if is_ptr: - out += f"IntPtr {name}" + out += f"public IntPtr {name}" elif x.count: out += f"[MarshalAs(UnmanagedType.ByValArray, SizeConst={x.count})] " out += f"public {get_type_name(x.type)}[] {name}" @@ -93,7 +92,7 @@ def get_union_variant(x: Field): out = "[FieldOffset(0)] " if is_ptr: - out += f"IntPtr {name}" + out += f"public IntPtr {name}" elif x.count: out += f"[MarshalAs(UnmanagedType.ByValArray, SizeConst={x.count})] " out += f"public {get_type_name(x.type)}[] {name}" @@ -103,14 +102,13 @@ def get_union_variant(x: Field): def get_declr(x: EntryBase): - ty = type(x) - if ty is BuiltInType: + if isinstance(x, BuiltInType): return "" - elif ty is Alias: + elif isinstance(x, Alias): return f"// using {get_type_name(x)} = {get_type_name(x.alias_of)};" - elif ty is Definition: + elif isinstance(x, Definition): out = [ "static partial class Def {", f"public const uint {x.name.screaming_snake_case} = {x.value};", @@ -118,7 +116,7 @@ def get_declr(x: EntryBase): ] return "\n".join(out) - elif ty is Handle: + elif isinstance(x, Handle): out = [ "[StructLayout(LayoutKind.Sequential)]", "public struct " + get_type_name(x) + " {", @@ -127,7 +125,7 @@ def get_declr(x: EntryBase): ] return "\n".join(out) - elif ty is Enumeration: + elif isinstance(x, Enumeration): out = ["public enum " + get_type_name(x) + " {"] for name, value in x.cases.items(): name = x.name.extend(name).screaming_snake_case @@ -136,7 +134,7 @@ def get_declr(x: EntryBase): out += ["}"] return "\n".join(out) - elif ty is BitField: + elif isinstance(x, BitField): out = ["[Flags]", "public enum " + get_type_name(x) + " {"] for name, value in x.bits.items(): name = x.name.extend(name).extend("bit").screaming_snake_case @@ -144,7 +142,7 @@ def get_declr(x: EntryBase): out += ["};"] return "\n".join(out) - elif ty is Structure: + elif isinstance(x, Structure): out = [ "[StructLayout(LayoutKind.Sequential)]", "public struct " + get_type_name(x) + " {", @@ -154,7 +152,7 @@ def get_declr(x: EntryBase): out += ["}"] return "\n".join(out) - elif ty is Union: + elif isinstance(x, Union): out = [ "[StructLayout(LayoutKind.Explicit)]", "public struct " + get_type_name(x) + " {", @@ -164,7 +162,7 @@ def get_declr(x: EntryBase): out += ["}"] return "\n".join(out) - elif ty is Callback: + elif isinstance(x, Callback): out = [ "[StructLayout(LayoutKind.Sequential)]", "public struct " + get_type_name(x) + " {", @@ -173,7 +171,7 @@ def get_declr(x: EntryBase): ] return "\n".join(out) - elif ty is Function: + elif isinstance(x, Function): out = [] return_value_type = "void" if x.return_value_type == None else get_type_name(x.return_value_type) @@ -298,7 +296,7 @@ def generate_module_header(module): return print(f"processing module '{module.name}'") - assert re.match("taichi/\w+.h", module.name) + assert re.match(r"taichi/\w+.h", module.name) module_name = module.name[len("taichi/") : -len(".h")] path = f"c_api/unity/{module_name}.cs" with open(path, "w") as f: diff --git a/misc/taichi_json.py b/misc/taichi_json.py index 7df697ff0baa0..9d1fa4be7b751 100644 --- a/misc/taichi_json.py +++ b/misc/taichi_json.py @@ -3,7 +3,7 @@ import re from collections import defaultdict from pathlib import Path -from typing import List +from typing import DefaultDict, Dict, List, Optional class Version: @@ -12,7 +12,7 @@ def __init__(self, ver: str) -> None: ver = ver[1:] xs = [int(x) for x in ver.split(".")] assert len(xs) <= 3 - xs += ["0"] * (3 - len(xs)) + xs += [0] * (3 - len(xs)) self.major = xs[0] self.minor = xs[1] @@ -66,7 +66,7 @@ def __repr__(self) -> str: class DeclarationRegistry: - current = None + current: "Optional[DeclarationRegistry]" = None def __init__(self, builtin_tys={}): # "xxx.yyy" -> Xxx(yyy) Look-up table. @@ -74,7 +74,7 @@ def __init__(self, builtin_tys={}): self._imported = {} self._builtin_tys = dict((x.id, x) for x in builtin_tys) - def resolve(self, id: str) -> "EntryBase": + def resolve(self, id: str) -> Optional["EntryBase"]: if id in self._builtin_tys: return self._builtin_tys[id] elif id in self._inner: @@ -99,7 +99,7 @@ def set_current(declr_reg): DeclarationRegistry.current = declr_reg -def load_inc_enums(): +def load_inc_enums() -> DefaultDict[str, Dict[Name, int]]: paths = glob.glob("taichi/inc/*.inc.h") cases = defaultdict(dict) for path in paths: @@ -161,7 +161,10 @@ def __init__(self, id, type_name): class Alias(EntryBase): def __init__(self, j): super().__init__(j, "alias") - self.alias_of = DeclarationRegistry.current.resolve(j["alias_of"]) + assert DeclarationRegistry.current is not None + alias_of = DeclarationRegistry.current.resolve(j["alias_of"]) + assert alias_of is not None + self.alias_of = alias_of class Definition(EntryBase): @@ -179,10 +182,12 @@ def __init__(self, j): class Enumeration(EntryBase): def __init__(self, j): super().__init__(j, "enumeration") + cases: Dict[Name, int] if "inc_cases" in j: - self.cases = load_inc_enums()[j["inc_cases"]] + cases = load_inc_enums()[j["inc_cases"]] else: - self.cases = dict((Name(name), value) for name, value in j["cases"].items()) + cases = dict((Name(name), value) for name, value in j["cases"].items()) + self.cases = cases class BitField(EntryBase): @@ -196,6 +201,7 @@ def __init__(self, j): class Field: def __init__(self, j): + assert DeclarationRegistry.current is not None ty = DeclarationRegistry.current.resolve(j["type"]) assert ty != None, f"unknown type '{j['type']}'" # The type has been registered. @@ -268,10 +274,12 @@ def __init__(self, name: str): with path.open() as f: templ = f.readlines() + i = 1 + # Ignore markdown headers markdown_metadata = [] if len(templ) > 0 and templ[0].startswith("---"): - for i in range(1, len(templ)): + for i in range(i, len(templ)): if templ[i].startswith("---"): i += 1 break