From 310a4e3e4723bc3d1b7ea447550652d10b0ad7f0 Mon Sep 17 00:00:00 2001 From: Cheng Cao Date: Tue, 25 Apr 2023 18:21:54 -0700 Subject: [PATCH 1/2] Make statements final, and change `IRNode::as` to use `static_cast` --- taichi/ir/ir.h | 6 +- taichi/ir/statements.h | 128 ++++++++++++++++++++--------------------- 2 files changed, 66 insertions(+), 68 deletions(-) diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index fc15d05be1f73..0bdfb8b729095 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -244,14 +244,12 @@ class IRNode { template <typename T> T *as() { - TI_ASSERT(is<T>()); - return dynamic_cast<T *>(this); + return static_cast<T *>(this); } template <typename T> const T *as() const { - TI_ASSERT(is<T>()); - return dynamic_cast<const T *>(this); + return static_cast<const T *>(this); } template <typename T> diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 68996e70e6fe4..df107db55f370 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -16,7 +16,7 @@ class Function; /** * Allocate a local variable with initial value 0. */ -class AllocaStmt : public Stmt, public ir_traits::Store { +class AllocaStmt final : public Stmt, public ir_traits::Store { public: explicit AllocaStmt(DataType type) : is_shared(false) { ret_type = type; @@ -59,7 +59,7 @@ class AllocaStmt : public Stmt, public ir_traits::Store { /** * Updates mask, break if all bits of the mask are 0. */ -class WhileControlStmt : public Stmt { +class WhileControlStmt final : public Stmt { public: Stmt *mask; Stmt *cond; @@ -74,7 +74,7 @@ class WhileControlStmt : public Stmt { /** * Jump to the next loop iteration, i.e., `continue` in C++. */ -class ContinueStmt : public Stmt { +class ContinueStmt final : public Stmt { public: // This is the loop on which this continue stmt has effects. It can be either // an offloaded task, or a for/while loop inside the kernel. @@ -113,7 +113,7 @@ class ContinueStmt : public Stmt { /** * A decoration statement. The decorated "operands" will keep this decoration. 
*/ -class DecorationStmt : public Stmt { +class DecorationStmt final : public Stmt { public: enum class Decoration : uint32_t { kUnknown, kLoopUnique }; @@ -145,7 +145,7 @@ class DecorationStmt : public Stmt { /** * A unary operation. The field |cast_type| is used only when is_cast() is true. */ -class UnaryOpStmt : public Stmt { +class UnaryOpStmt final : public Stmt { public: UnaryOpType op_type; Stmt *operand; @@ -169,7 +169,7 @@ class UnaryOpStmt : public Stmt { * statement. |is_ptr| should be true iff the result can be used as a base * pointer of an ExternalPtrStmt. */ -class ArgLoadStmt : public Stmt { +class ArgLoadStmt final : public Stmt { public: int arg_id; @@ -210,7 +210,7 @@ class ArgLoadStmt : public Stmt { * random seed. Each invocation of a RandStmt compiles to a call of a * deterministic PRNG to generate a random value in the backend. */ -class RandStmt : public Stmt { +class RandStmt final : public Stmt { public: explicit RandStmt(const DataType &dt) { ret_type = dt; @@ -232,7 +232,7 @@ class RandStmt : public Stmt { /** * A binary operation. */ -class BinaryOpStmt : public Stmt { +class BinaryOpStmt final : public Stmt { public: BinaryOpType op_type; Stmt *lhs, *rhs; @@ -263,7 +263,7 @@ class BinaryOpStmt : public Stmt { * A ternary operation. Currently "select" (the ternary conditional operator, * "?:" in C++) is the only supported ternary operation. */ -class TernaryOpStmt : public Stmt { +class TernaryOpStmt final : public Stmt { public: TernaryOpType op_type; Stmt *op1, *op2, *op3; @@ -287,7 +287,7 @@ class TernaryOpStmt : public Stmt { /** * An atomic operation. */ -class AtomicOpStmt : public Stmt, +class AtomicOpStmt final : public Stmt, public ir_traits::Store, public ir_traits::Load { public: @@ -330,7 +330,7 @@ class AtomicOpStmt : public Stmt, * An external pointer. |base_ptr| should be ArgLoadStmt with * |is_ptr| == true. 
*/ -class ExternalPtrStmt : public Stmt { +class ExternalPtrStmt final : public Stmt { public: Stmt *base_ptr; std::vector indices; @@ -372,7 +372,7 @@ class ExternalPtrStmt : public Stmt { * SNodeLookupStmts and GetChStmts, and should not appear in the final lowered * IR. */ -class GlobalPtrStmt : public Stmt { +class GlobalPtrStmt final : public Stmt { public: SNode *snode; std::vector indices; @@ -404,7 +404,7 @@ class GlobalPtrStmt : public Stmt { * the lower_matrix_ptr pass, this stmt will either be eliminated (constant * index) or have ptr_base initialized (dynamic index or whole-matrix access). */ -class MatrixOfGlobalPtrStmt : public Stmt { +class MatrixOfGlobalPtrStmt final : public Stmt { public: std::vector snodes; std::vector indices; @@ -446,7 +446,7 @@ class MatrixOfGlobalPtrStmt : public Stmt { * TODO(yi/zhanlue): Keep scalarization pass alive for MatrixOfMatrixPtrStmt * operations even with real_matrix_scalarize=False */ -class MatrixOfMatrixPtrStmt : public Stmt { +class MatrixOfMatrixPtrStmt final : public Stmt { public: std::vector stmts; @@ -459,7 +459,7 @@ class MatrixOfMatrixPtrStmt : public Stmt { /** * A pointer to an element of a matrix. */ -class MatrixPtrStmt : public Stmt { +class MatrixPtrStmt final : public Stmt { public: Stmt *origin{nullptr}; Stmt *offset{nullptr}; @@ -503,7 +503,7 @@ class MatrixPtrStmt : public Stmt { /** * An operation to a SNode (not necessarily a leaf SNode). */ -class SNodeOpStmt : public Stmt, public ir_traits::Store { +class SNodeOpStmt final : public Stmt, public ir_traits::Store { public: SNodeOpType op_type; SNode *snode; @@ -539,7 +539,7 @@ class SNodeOpStmt : public Stmt, public ir_traits::Store { // TODO: remove this // (penguinliong) This Stmt is used for both ND-arrays and textures. This is // subject to change in the future. 
-class ExternalTensorShapeAlongAxisStmt : public Stmt { +class ExternalTensorShapeAlongAxisStmt final : public Stmt { public: int axis; int arg_id; @@ -559,7 +559,7 @@ class ExternalTensorShapeAlongAxisStmt : public Stmt { * If |cond| is false, print the formatted |text| with |args|, and terminate * the program. */ -class AssertStmt : public Stmt { +class AssertStmt final : public Stmt { public: Stmt *cond; std::string text; @@ -580,7 +580,7 @@ class AssertStmt : public Stmt { /** * Call an external (C++) function. */ -class ExternalFuncCallStmt : public Stmt, +class ExternalFuncCallStmt final : public Stmt, public ir_traits::Store, public ir_traits::Load { public: @@ -645,7 +645,7 @@ class ExternalFuncCallStmt : public Stmt, * This statement simply returns the input statement at the backend, and hints * the Taichi compiler that |base| + |low| <= |input| < |base| + |high|. */ -class RangeAssumptionStmt : public Stmt { +class RangeAssumptionStmt final : public Stmt { public: Stmt *input; Stmt *base; @@ -674,7 +674,7 @@ class RangeAssumptionStmt : public Stmt { * of this statement. Since this statement can only evaluate to one value, * the SNodes with id in the |covers| field should have only one dimension. */ -class LoopUniqueStmt : public Stmt { +class LoopUniqueStmt final : public Stmt { public: Stmt *input; std::unordered_set covers; // Stores SNode id @@ -695,7 +695,7 @@ class LoopUniqueStmt : public Stmt { * A load from a global address, including SNodes, external arrays, TLS, BLS, * and global temporary variables. */ -class GlobalLoadStmt : public Stmt, public ir_traits::Load { +class GlobalLoadStmt final : public Stmt, public ir_traits::Load { public: Stmt *src; @@ -724,7 +724,7 @@ class GlobalLoadStmt : public Stmt, public ir_traits::Load { * A store to a global address, including SNodes, external arrays, TLS, BLS, * and global temporary variables. 
*/ -class GlobalStoreStmt : public Stmt, public ir_traits::Store { +class GlobalStoreStmt final : public Stmt, public ir_traits::Store { public: Stmt *dest; Stmt *val; @@ -753,7 +753,7 @@ class GlobalStoreStmt : public Stmt, public ir_traits::Store { /** * A load from a local variable, i.e., an "alloca". */ -class LocalLoadStmt : public Stmt, public ir_traits::Load { +class LocalLoadStmt final : public Stmt, public ir_traits::Load { public: Stmt *src; @@ -781,7 +781,7 @@ class LocalLoadStmt : public Stmt, public ir_traits::Load { /** * A store to a local variable, i.e., an "alloca". */ -class LocalStoreStmt : public Stmt, public ir_traits::Store { +class LocalStoreStmt final : public Stmt, public ir_traits::Store { public: Stmt *dest; Stmt *val; @@ -821,7 +821,7 @@ class LocalStoreStmt : public Stmt, public ir_traits::Store { * Same as "if (cond) true_statements; else false_statements;" in C++. * |true_mask| and |false_mask| are used to support vectorization. */ -class IfStmt : public Stmt { +class IfStmt final : public Stmt { public: Stmt *cond; std::unique_ptr true_statements, false_statements; @@ -847,7 +847,7 @@ class IfStmt : public Stmt { * either a statement or a string, and they are printed one by one, separated * by a comma and a space. */ -class PrintStmt : public Stmt { +class PrintStmt final : public Stmt { public: using EntryType = std::variant; using FormatType = std::optional; @@ -898,7 +898,7 @@ class PrintStmt : public Stmt { /** * A constant value. */ -class ConstStmt : public Stmt { +class ConstStmt final : public Stmt { public: TypedConstant val; @@ -923,7 +923,7 @@ class ConstStmt : public Stmt { * offloaded to a parallel for loop. Otherwise, it will be offloaded to a * serial for loop. 
*/ -class RangeForStmt : public Stmt { +class RangeForStmt final : public Stmt { public: Stmt *begin, *end; std::unique_ptr body; @@ -967,7 +967,7 @@ class RangeForStmt : public Stmt { * A parallel for loop over a SNode, similar to "for i in snode: body" * in Python. This statement must be at the top level before offloading. */ -class StructForStmt : public Stmt { +class StructForStmt final : public Stmt { public: SNode *snode; std::unique_ptr body; @@ -1003,7 +1003,7 @@ class StructForStmt : public Stmt { /** * meshfor */ -class MeshForStmt : public Stmt { +class MeshForStmt final : public Stmt { public: mesh::Mesh *mesh; std::unique_ptr body; @@ -1042,7 +1042,7 @@ class MeshForStmt : public Stmt { /** * Call an inline Taichi function. */ -class FuncCallStmt : public Stmt { +class FuncCallStmt final : public Stmt { public: Function *func; std::vector args; @@ -1061,7 +1061,7 @@ class FuncCallStmt : public Stmt { /** * A reference to a variable. */ -class ReferenceStmt : public Stmt, public ir_traits::Load { +class ReferenceStmt final : public Stmt, public ir_traits::Load { public: Stmt *var; bool global_side_effect{false}; @@ -1086,7 +1086,7 @@ class ReferenceStmt : public Stmt, public ir_traits::Load { /** * Gets an element from a struct */ -class GetElementStmt : public Stmt { +class GetElementStmt final : public Stmt { public: Stmt *src; std::vector index; @@ -1102,7 +1102,7 @@ class GetElementStmt : public Stmt { /** * Exit the kernel or function with a return value. */ -class ReturnStmt : public Stmt { +class ReturnStmt final : public Stmt { public: std::vector values; @@ -1139,7 +1139,7 @@ class ReturnStmt : public Stmt { /** * A serial while-true loop. |mask| is to support vectorization. 
*/ -class WhileStmt : public Stmt { +class WhileStmt final : public Stmt { public: Stmt *mask; std::unique_ptr body; @@ -1157,7 +1157,7 @@ class WhileStmt : public Stmt { }; // TODO: remove this (replace with input + ConstStmt(offset)) -class IntegerOffsetStmt : public Stmt { +class IntegerOffsetStmt final : public Stmt { public: Stmt *input; int64 offset; @@ -1177,7 +1177,7 @@ class IntegerOffsetStmt : public Stmt { /** * All indices of an address fused together. */ -class LinearizeStmt : public Stmt { +class LinearizeStmt final : public Stmt { public: std::vector inputs; std::vector strides; @@ -1200,7 +1200,7 @@ class LinearizeStmt : public Stmt { /** * The SNode root. */ -class GetRootStmt : public Stmt { +class GetRootStmt final : public Stmt { public: explicit GetRootStmt(SNode *root = nullptr) : root_(root) { if (this->root_ != nullptr) { @@ -1233,7 +1233,7 @@ class GetRootStmt : public Stmt { /** * Lookup a component of a SNode. */ -class SNodeLookupStmt : public Stmt { +class SNodeLookupStmt final : public Stmt { public: SNode *snode; Stmt *input_snode; @@ -1266,7 +1266,7 @@ class SNodeLookupStmt : public Stmt { /** * Get a child of a SNode on the hierarchical SNode tree. */ -class GetChStmt : public Stmt { +class GetChStmt final : public Stmt { public: Stmt *input_ptr; SNode *input_snode, *output_snode; @@ -1299,7 +1299,7 @@ class GetChStmt : public Stmt { /** * The statement corresponding to an offloaded task. */ -class OffloadedStmt : public Stmt { +class OffloadedStmt final : public Stmt { public: using TaskType = OffloadedTaskType; @@ -1387,7 +1387,7 @@ class OffloadedStmt : public Stmt { /** * The |index|-th index of the |loop|. */ -class LoopIndexStmt : public Stmt { +class LoopIndexStmt final : public Stmt { public: Stmt *loop; int index; @@ -1429,7 +1429,7 @@ class LoopIndexStmt : public Stmt { * thread index within a CUDA block * TODO: Remove this. Have a better way for retrieving thread index. 
*/ -class LoopLinearIndexStmt : public Stmt { +class LoopLinearIndexStmt final : public Stmt { public: Stmt *loop; @@ -1448,7 +1448,7 @@ class LoopLinearIndexStmt : public Stmt { /** * global thread index, i.e. thread_idx() + block_idx() * block_dim() */ -class GlobalThreadIndexStmt : public Stmt { +class GlobalThreadIndexStmt final : public Stmt { public: explicit GlobalThreadIndexStmt() { TI_STMT_REG_FIELDS; @@ -1466,7 +1466,7 @@ class GlobalThreadIndexStmt : public Stmt { * The lowest |index|-th index of the |loop| among the iterations iterated by * the block. */ -class BlockCornerIndexStmt : public Stmt { +class BlockCornerIndexStmt final : public Stmt { public: Stmt *loop; int index; @@ -1487,7 +1487,7 @@ class BlockCornerIndexStmt : public Stmt { * A global temporary variable, located at |offset| in the global temporary * buffer. */ -class GlobalTemporaryStmt : public Stmt { +class GlobalTemporaryStmt final : public Stmt { public: std::size_t offset; @@ -1508,7 +1508,7 @@ class GlobalTemporaryStmt : public Stmt { /** * A thread-local pointer, located at |offset| in the thread-local storage. */ -class ThreadLocalPtrStmt : public Stmt { +class ThreadLocalPtrStmt final : public Stmt { public: std::size_t offset; @@ -1529,7 +1529,7 @@ class ThreadLocalPtrStmt : public Stmt { /** * A block-local pointer, located at |offset| in the block-local storage. */ -class BlockLocalPtrStmt : public Stmt { +class BlockLocalPtrStmt final : public Stmt { public: Stmt *offset; @@ -1549,7 +1549,7 @@ class BlockLocalPtrStmt : public Stmt { /** * The statement corresponding to a clear-list task. */ -class ClearListStmt : public Stmt { +class ClearListStmt final : public Stmt { public: explicit ClearListStmt(SNode *snode); @@ -1562,7 +1562,7 @@ class ClearListStmt : public Stmt { // Checks if the task represented by |stmt| contains a single ClearListStmt. 
bool is_clear_list_task(const OffloadedStmt *stmt); -class InternalFuncStmt : public Stmt { +class InternalFuncStmt final : public Stmt { public: std::string func_name; std::vector args; @@ -1589,7 +1589,7 @@ class InternalFuncStmt : public Stmt { class Texture; -class TexturePtrStmt : public Stmt { +class TexturePtrStmt final : public Stmt { public: Stmt *arg_load_stmt{nullptr}; int dimensions{2}; @@ -1621,7 +1621,7 @@ class TexturePtrStmt : public Stmt { TI_DEFINE_ACCEPT_AND_CLONE }; -class TextureOpStmt : public Stmt { +class TextureOpStmt final : public Stmt { public: TextureOpType op; Stmt *texture_ptr; @@ -1651,7 +1651,7 @@ class TextureOpStmt : public Stmt { /** * A local AD-stack. */ -class AdStackAllocaStmt : public Stmt { +class AdStackAllocaStmt final : public Stmt { public: DataType dt; std::size_t max_size{0}; // 0 = adaptive @@ -1688,7 +1688,7 @@ class AdStackAllocaStmt : public Stmt { /** * Load the top primal value of an AD-stack. */ -class AdStackLoadTopStmt : public Stmt, public ir_traits::Load { +class AdStackLoadTopStmt final : public Stmt, public ir_traits::Load { public: Stmt *stack; @@ -1718,7 +1718,7 @@ class AdStackLoadTopStmt : public Stmt, public ir_traits::Load { /** * Load the top adjoint value of an AD-stack. */ -class AdStackLoadTopAdjStmt : public Stmt, public ir_traits::Load { +class AdStackLoadTopAdjStmt final : public Stmt, public ir_traits::Load { public: Stmt *stack; @@ -1748,7 +1748,7 @@ class AdStackLoadTopAdjStmt : public Stmt, public ir_traits::Load { /** * Pop the top primal and adjoint values in the AD-stack. */ -class AdStackPopStmt : public Stmt, public ir_traits::Load { +class AdStackPopStmt final : public Stmt, public ir_traits::Load { public: Stmt *stack; @@ -1775,7 +1775,7 @@ class AdStackPopStmt : public Stmt, public ir_traits::Load { * Push a primal value to the AD-stack, and set the corresponding adjoint * value to 0. 
*/ -class AdStackPushStmt : public Stmt, public ir_traits::Load { +class AdStackPushStmt final : public Stmt, public ir_traits::Load { public: Stmt *stack; Stmt *v; @@ -1804,7 +1804,7 @@ class AdStackPushStmt : public Stmt, public ir_traits::Load { * Accumulate |v| to the top adjoint value of the AD-stack. * This statement loads and stores the adjoint data. */ -class AdStackAccAdjointStmt : public Stmt, public ir_traits::Load { +class AdStackAccAdjointStmt final : public Stmt, public ir_traits::Load { public: Stmt *stack; Stmt *v; @@ -1831,7 +1831,7 @@ class AdStackAccAdjointStmt : public Stmt, public ir_traits::Load { /** * A global store to one or more children of a bit struct. */ -class BitStructStoreStmt : public Stmt { +class BitStructStoreStmt final : public Stmt { public: Stmt *ptr; std::vector ch_ids; @@ -1863,7 +1863,7 @@ class BitStructStoreStmt : public Stmt { * If neibhor_idex has no value, it returns the number of neighbors (length of * relation) of a mesh idx */ -class MeshRelationAccessStmt : public Stmt { +class MeshRelationAccessStmt final : public Stmt { public: mesh::Mesh *mesh; Stmt *mesh_idx; @@ -1920,7 +1920,7 @@ class MeshRelationAccessStmt : public Stmt { /** * Convert a mesh index to another index space */ -class MeshIndexConversionStmt : public Stmt { +class MeshIndexConversionStmt final : public Stmt { public: mesh::Mesh *mesh; mesh::MeshElementType idx_type; @@ -1948,7 +1948,7 @@ class MeshIndexConversionStmt : public Stmt { /** * The patch index of the |mesh_loop|. 
*/ -class MeshPatchIndexStmt : public Stmt { +class MeshPatchIndexStmt final : public Stmt { public: MeshPatchIndexStmt() { this->ret_type = PrimitiveType::i32; @@ -1966,7 +1966,7 @@ class MeshPatchIndexStmt : public Stmt { /** * Initialization of a local matrix */ -class MatrixInitStmt : public Stmt { +class MatrixInitStmt final : public Stmt { public: std::vector values; From 5e0a34e4222e6da8abc0b07afa51aa8e4e13f61f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 26 Apr 2023 01:26:52 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/ir/statements.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index df107db55f370..4e95b97157116 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -288,8 +288,8 @@ class TernaryOpStmt final : public Stmt { * An atomic operation. */ class AtomicOpStmt final : public Stmt, - public ir_traits::Store, - public ir_traits::Load { + public ir_traits::Store, + public ir_traits::Load { public: AtomicOpType op_type; Stmt *dest, *val; @@ -581,8 +581,8 @@ class AssertStmt final : public Stmt { * Call an external (C++) function. */ class ExternalFuncCallStmt final : public Stmt, - public ir_traits::Store, - public ir_traits::Load { + public ir_traits::Store, + public ir_traits::Load { public: enum Type { SHARED_OBJECT = 0, ASSEMBLY = 1, BITCODE = 2 };