Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace DataType with const Type* throughout the codebase #7850

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions taichi/ir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ bool QuantFloatType::get_is_signed() const {

BitStructType::BitStructType(
PrimitiveType *physical_type,
const std::vector<Type *> &member_types,
const std::vector<const Type *> &member_types,
const std::vector<int> &member_bit_offsets,
const std::vector<int> &member_exponents,
const std::vector<std::vector<int>> &member_exponent_users)
Expand All @@ -282,7 +282,7 @@ BitStructType::BitStructType(
int physical_type_bits = data_type_bits(physical_type_);
int member_total_bits = 0;
for (auto i = 0; i < member_types_.size(); ++i) {
QuantIntType *component_qit = nullptr;
const QuantIntType *component_qit = nullptr;
if (auto qit = member_types_[i]->cast<QuantIntType>()) {
component_qit = qit;
} else if (auto qfxt = member_types_[i]->cast<QuantFixedType>()) {
Expand Down
12 changes: 6 additions & 6 deletions taichi/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class TI_DLL_EXPORT DataType {
TI_IO_DEF(ptr_);

private:
Type *ptr_;
const Type *ptr_;
};

// Note that all types are immutable once created.
Expand Down Expand Up @@ -335,7 +335,7 @@ class TI_DLL_EXPORT QuantIntType : public Type {
private:
// TODO(type): for now we can uniformly use i32 as the "compute_type". It may
// be a good idea to make "compute_type" also customizable.
Type *compute_type_{nullptr};
const Type *compute_type_{nullptr};
int num_bits_{32};
bool is_signed_{true};
};
Expand All @@ -349,7 +349,7 @@ class TI_DLL_EXPORT QuantFixedType : public Type {

bool get_is_signed() const;

Type *get_digits_type() {
const Type *get_digits_type() {
return digits_type_;
}

Expand Down Expand Up @@ -379,7 +379,7 @@ class TI_DLL_EXPORT QuantFloatType : public Type {

std::string to_string() const override;

Type *get_digits_type() {
const Type *get_digits_type() {
return digits_type_;
}

Expand Down Expand Up @@ -411,7 +411,7 @@ class TI_DLL_EXPORT BitStructType : public Type {
public:
BitStructType() : Type(TypeKind::BitStruct){};
BitStructType(PrimitiveType *physical_type,
const std::vector<Type *> &member_types,
const std::vector<const Type *> &member_types,
const std::vector<int> &member_bit_offsets,
const std::vector<int> &member_exponents,
const std::vector<std::vector<int>> &member_exponent_users);
Expand Down Expand Up @@ -457,7 +457,7 @@ class TI_DLL_EXPORT BitStructType : public Type {

private:
PrimitiveType *physical_type_;
std::vector<Type *> member_types_;
std::vector<const Type *> member_types_;
std::vector<int> member_bit_offsets_;
std::vector<int> member_exponents_;
std::vector<std::vector<int>> member_exponent_users_;
Expand Down
46 changes: 25 additions & 21 deletions taichi/ir/type_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ TypeFactory &TypeFactory::get_instance() {
TypeFactory::TypeFactory() {
}

Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) {
const Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) {
std::lock_guard<std::mutex> _(primitive_mut_);

if (primitive_types_.find(id) == primitive_types_.end()) {
Expand All @@ -22,7 +22,8 @@ Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) {
return primitive_types_[id].get();
}

Type *TypeFactory::get_tensor_type(std::vector<int> shape, Type *element) {
const Type *TypeFactory::get_tensor_type(std::vector<int> shape,
Type *element) {
std::lock_guard<std::mutex> _(tensor_mut_);

auto encode = [](const std::vector<int> &shape) -> std::string {
Expand Down Expand Up @@ -57,7 +58,8 @@ const Type *TypeFactory::get_struct_type(
return struct_types_[key].get();
}

Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) {
const Type *TypeFactory::get_pointer_type(const Type *element,
bool is_bit_pointer) {
std::lock_guard<std::mutex> _(pointer_mut_);

auto key = std::make_pair(element, is_bit_pointer);
Expand All @@ -68,9 +70,9 @@ Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) {
return pointer_types_[key].get();
}

Type *TypeFactory::get_quant_int_type(int num_bits,
bool is_signed,
Type *compute_type) {
const Type *TypeFactory::get_quant_int_type(int num_bits,
bool is_signed,
const Type *compute_type) {
std::lock_guard<std::mutex> _(quant_int_mut_);

auto key = std::make_tuple(num_bits, is_signed, compute_type);
Expand All @@ -81,9 +83,9 @@ Type *TypeFactory::get_quant_int_type(int num_bits,
return quant_int_types_[key].get();
}

Type *TypeFactory::get_quant_fixed_type(Type *digits_type,
Type *compute_type,
float64 scale) {
const Type *TypeFactory::get_quant_fixed_type(const Type *digits_type,
const Type *compute_type,
float64 scale) {
std::lock_guard<std::mutex> _(quant_fixed_mut_);

auto key = std::make_tuple(digits_type, compute_type, scale);
Expand All @@ -94,9 +96,9 @@ Type *TypeFactory::get_quant_fixed_type(Type *digits_type,
return quant_fixed_types_[key].get();
}

Type *TypeFactory::get_quant_float_type(Type *digits_type,
Type *exponent_type,
Type *compute_type) {
const Type *TypeFactory::get_quant_float_type(const Type *digits_type,
const Type *exponent_type,
const Type *compute_type) {
std::lock_guard<std::mutex> _(quant_float_mut_);

auto key = std::make_tuple(digits_type, exponent_type, compute_type);
Expand All @@ -108,8 +110,8 @@ Type *TypeFactory::get_quant_float_type(Type *digits_type,
}

BitStructType *TypeFactory::get_bit_struct_type(
PrimitiveType *physical_type,
const std::vector<Type *> &member_types,
const PrimitiveType *physical_type,
const std::vector<const Type *> &member_types,
const std::vector<int> &member_bit_offsets,
const std::vector<int> &member_exponents,
const std::vector<std::vector<int>> &member_exponent_users) {
Expand All @@ -121,18 +123,20 @@ BitStructType *TypeFactory::get_bit_struct_type(
return bit_struct_types_.back().get();
}

Type *TypeFactory::get_quant_array_type(PrimitiveType *physical_type,
Type *element_type,
int num_elements) {
const Type *TypeFactory::get_quant_array_type(
const PrimitiveType *physical_type,
const Type *element_type,
int num_elements) {
std::lock_guard<std::mutex> _(quant_array_mut_);

quant_array_types_.push_back(std::make_unique<QuantArrayType>(
physical_type, element_type, num_elements));
return quant_array_types_.back().get();
}

PrimitiveType *TypeFactory::get_primitive_int_type(int bits, bool is_signed) {
Type *int_type;
const PrimitiveType *TypeFactory::get_primitive_int_type(int bits,
bool is_signed) {
const Type *int_type;
if (bits == 8) {
int_type = get_primitive_type(PrimitiveTypeID::i8);
} else if (bits == 16) {
Expand All @@ -150,8 +154,8 @@ PrimitiveType *TypeFactory::get_primitive_int_type(int bits, bool is_signed) {
return int_type->cast<PrimitiveType>();
}

PrimitiveType *TypeFactory::get_primitive_real_type(int bits) {
Type *real_type;
const PrimitiveType *TypeFactory::get_primitive_real_type(int bits) {
const Type *real_type;
if (bits == 16) {
real_type = get_primitive_type(PrimitiveTypeID::f16);
} else if (bits == 32) {
Expand Down
63 changes: 34 additions & 29 deletions taichi/ir/type_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,39 +14,42 @@ class TypeFactory {
// TODO(type): maybe it makes sense to let each get_X function return X*
// instead of generic Type*

Type *get_primitive_type(PrimitiveTypeID id);
const Type *get_primitive_type(PrimitiveTypeID id);

PrimitiveType *get_primitive_int_type(int bits, bool is_signed = true);
const PrimitiveType *get_primitive_int_type(int bits, bool is_signed = true);

PrimitiveType *get_primitive_real_type(int bits);
const PrimitiveType *get_primitive_real_type(int bits);

Type *get_tensor_type(std::vector<int> shape, Type *element);
const Type *get_tensor_type(std::vector<int> shape, Type *element);

const Type *get_struct_type(const std::vector<StructMember> &elements,
const std::string &layout = "none");

Type *get_pointer_type(Type *element, bool is_bit_pointer = false);
const Type *get_pointer_type(const Type *element,
bool is_bit_pointer = false);

Type *get_quant_int_type(int num_bits, bool is_signed, Type *compute_type);
const Type *get_quant_int_type(int num_bits,
bool is_signed,
const Type *compute_type);

Type *get_quant_fixed_type(Type *digits_type,
Type *compute_type,
float64 scale);
const Type *get_quant_fixed_type(const Type *digits_type,
const Type *compute_type,
float64 scale);

Type *get_quant_float_type(Type *digits_type,
Type *exponent_type,
Type *compute_type);
const Type *get_quant_float_type(const Type *digits_type,
const Type *exponent_type,
const Type *compute_type);

BitStructType *get_bit_struct_type(
PrimitiveType *physical_type,
const std::vector<Type *> &member_types,
const PrimitiveType *physical_type,
const std::vector<const Type *> &member_types,
const std::vector<int> &member_bit_offsets,
const std::vector<int> &member_exponents,
const std::vector<std::vector<int>> &member_exponent_users);

Type *get_quant_array_type(PrimitiveType *physical_type,
Type *element_type,
int num_elements);
const Type *get_quant_array_type(const PrimitiveType *physical_type,
const Type *element_type,
int num_elements);

static DataType create_tensor_type(std::vector<int> shape, DataType element);

Expand All @@ -56,9 +59,9 @@ class TypeFactory {
std::unordered_map<PrimitiveTypeID, std::unique_ptr<Type>> primitive_types_;
std::mutex primitive_mut_;

std::unordered_map<std::pair<std::string, Type *>,
std::unordered_map<std::pair<std::string, const Type *>,
std::unique_ptr<Type>,
hashing::Hasher<std::pair<std::string, Type *>>>
hashing::Hasher<std::pair<std::string, const Type *>>>
tensor_types_;
std::mutex tensor_mut_;

Expand All @@ -70,27 +73,29 @@ class TypeFactory {
std::mutex struct_mut_;

// TODO: is_bit_ptr?
std::unordered_map<std::pair<Type *, bool>,
std::unordered_map<std::pair<const Type *, bool>,
std::unique_ptr<Type>,
hashing::Hasher<std::pair<Type *, bool>>>
hashing::Hasher<std::pair<const Type *, bool>>>
pointer_types_;
std::mutex pointer_mut_;

std::unordered_map<std::tuple<int, bool, Type *>,
std::unordered_map<std::tuple<int, bool, const Type *>,
std::unique_ptr<Type>,
hashing::Hasher<std::tuple<int, bool, Type *>>>
hashing::Hasher<std::tuple<int, bool, const Type *>>>
quant_int_types_;
std::mutex quant_int_mut_;

std::unordered_map<std::tuple<Type *, Type *, float64>,
std::unique_ptr<Type>,
hashing::Hasher<std::tuple<Type *, Type *, float64>>>
std::unordered_map<
std::tuple<const Type *, const Type *, float64>,
std::unique_ptr<Type>,
hashing::Hasher<std::tuple<const Type *, const Type *, float64>>>
quant_fixed_types_;
std::mutex quant_fixed_mut_;

std::unordered_map<std::tuple<Type *, Type *, Type *>,
std::unique_ptr<Type>,
hashing::Hasher<std::tuple<Type *, Type *, Type *>>>
std::unordered_map<
std::tuple<const Type *, const Type *, const Type *>,
std::unique_ptr<Type>,
hashing::Hasher<std::tuple<const Type *, const Type *, const Type *>>>
quant_float_types_;
std::mutex quant_float_mut_;

Expand Down
6 changes: 3 additions & 3 deletions taichi/ir/type_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ class BitStructTypeBuilder {
member_bit_offsets_.push_back(member_total_bits_);
member_exponents_.push_back(-1);
member_exponent_users_.push_back({});
QuantIntType *member_qit = nullptr;
const QuantIntType *member_qit = nullptr;
if (auto qit = member_type->cast<QuantIntType>()) {
member_qit = qit;
} else if (auto qfxt = member_type->cast<QuantFixedType>()) {
Expand All @@ -260,8 +260,8 @@ class BitStructTypeBuilder {
return old_num_members;
}

PrimitiveType *physical_type_{nullptr};
std::vector<Type *> member_types_;
const PrimitiveType *physical_type_{nullptr};
std::vector<const Type *> member_types_;
std::vector<int> member_bit_offsets_;
int member_total_bits_{0};
std::vector<int> member_exponents_;
Expand Down