This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch tvm-ffi-bool in repository https://gitbox.apache.org/repos/asf/tvm.git
commit ee4270e2095abfb69ec5f6ea586000d5d7bbdee6 Author: tqchen <[email protected]> AuthorDate: Wed Nov 12 12:26:59 2025 -0500 [DataType] Update to use explicit Bool Type Aligning with DLPack This PR updates the project to use an explicit bool type, which helps us align with DLPack. It also streamlines the explicit use of bool types. --- .gitmodules | 2 +- 3rdparty/tvm-ffi | 2 +- include/tvm/runtime/data_type.h | 7 ++-- include/tvm/tir/op.h | 6 ++-- src/arith/const_fold.h | 26 +++++++------- src/ir/expr.cc | 7 ++-- src/relax/transform/utils.h | 2 +- src/target/llvm/codegen_llvm.cc | 8 +++-- src/target/llvm/codegen_llvm.h | 1 + src/target/source/codegen_source_base.cc | 5 +++ src/target/spirv/codegen_spirv.cc | 4 +-- src/target/spirv/ir_builder.cc | 59 ++++++++++++++++---------------- 12 files changed, 70 insertions(+), 59 deletions(-) diff --git a/.gitmodules b/.gitmodules index 0513981e58..7acfdb6a38 100644 --- a/.gitmodules +++ b/.gitmodules @@ -24,4 +24,4 @@ url = https://github.com/madler/zlib.git [submodule "3rdparty/tvm-ffi"] path = 3rdparty/tvm-ffi - url = https://github.com/apache/tvm-ffi + url = https://github.com/tqchen/tvm-ffi diff --git a/3rdparty/tvm-ffi b/3rdparty/tvm-ffi index f703a0cf93..60f45ac017 160000 --- a/3rdparty/tvm-ffi +++ b/3rdparty/tvm-ffi @@ -1 +1 @@ -Subproject commit f703a0cf9358fa30d8faee719f905c58d8ca6ee3 +Subproject commit 60f45ac017964caf2252b3c74a6e10a4422a1835 diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 0af3022bbd..da355bd7ce 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -60,6 +60,7 @@ class DataType { kFloat = kDLFloat, kHandle = kDLOpaqueHandle, kBFloat = kDLBfloat, + kBool = kDLBool, kFloat8_e3m4 = kDLFloat8_e3m4, kFloat8_e4m3 = kDLFloat8_e4m3, kFloat8_e4m3b11fnuz = kDLFloat8_e4m3b11fnuz, @@ -138,7 +139,7 @@ class DataType { /*! \return whether type is a scalar type. */ bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; } /*!
\return whether type is a scalar type. */ - bool is_bool() const { return code() == DataType::kUInt && bits() == 1; } + bool is_bool() const { return code() == DataType::kBool; } /*! \return whether type is a float type. */ bool is_float() const { return code() == DataType::kFloat; } /*! \return whether type is a bfloat type. */ @@ -204,7 +205,7 @@ class DataType { /*! \return whether type is a vector type. */ bool is_vector() const { return lanes() > 1; } /*! \return whether type is a bool vector type. */ - bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && bits() == 1; } + bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && is_bool(); } /*! \return whether type is a Void type. */ bool is_void() const { return code() == DataType::kHandle && bits() == 0 && static_cast<int16_t>(data_.lanes) == 0; @@ -381,7 +382,7 @@ class DataType { * \return The constructed data type. */ static DataType Bool(int lanes = 1, bool is_scalable = false) { - return DataType::UInt(1, lanes, is_scalable); + return DataType(kDLBool, 8, lanes, is_scalable); } /*! * \brief Construct a handle type. diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 6a0f427b80..57f8681514 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -816,7 +816,7 @@ inline PrimExpr make_zero(DataType t, Span span = Span()); * \return The result expression. */ inline PrimExpr const_true(int lanes = 1, Span span = Span()) { - return make_const(DataType::UInt(1, lanes), 1); + return make_const(DataType::Bool(lanes), 1); } /*! * \brief Make a constant false expression. @@ -825,7 +825,7 @@ inline PrimExpr const_true(int lanes = 1, Span span = Span()) { * \return The result expression. */ inline PrimExpr const_false(int lanes = 1, Span span = Span()) { - return make_const(DataType::UInt(1, lanes), 0); + return make_const(DataType::Bool(lanes), 0); } /*! * \brief Get x as constant int expression. 
@@ -957,7 +957,7 @@ inline bool is_no_op(const tir::Stmt& stmt) { template <typename ValueType> inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) { - if (t.is_int()) return IntImm(t, static_cast<int64_t>(value), span); + if (t.is_int() || t.is_bool()) return IntImm(t, static_cast<int64_t>(value), span); if (t.is_uint()) { // Use IntImm if it is a small integer uint64_t uval = static_cast<uint64_t>(value); diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index dda7f67465..1851142b02 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -349,8 +349,8 @@ inline ffi::Optional<PrimExpr> TryConstFold<tir::Max>(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional<PrimExpr> TryConstFold<tir::GT>(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value); + if (pa && pb) return IntImm(DataType::Bool(), pa->value > pb->value); + if (fa && fb) return IntImm(DataType::Bool(), fa->value > fb->value); }); return std::nullopt; } @@ -358,8 +358,8 @@ inline ffi::Optional<PrimExpr> TryConstFold<tir::GT>(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional<PrimExpr> TryConstFold<tir::GE>(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value); + if (pa && pb) return IntImm(DataType::Bool(), pa->value >= pb->value); + if (fa && fb) return IntImm(DataType::Bool(), fa->value >= fb->value); }); return std::nullopt; } @@ -367,8 +367,8 @@ inline ffi::Optional<PrimExpr> TryConstFold<tir::GE>(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional<PrimExpr> TryConstFold<tir::LT>(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value); - if (fa && fb) return 
IntImm(DataType::UInt(1), fa->value < fb->value); + if (pa && pb) return IntImm(DataType::Bool(), pa->value < pb->value); + if (fa && fb) return IntImm(DataType::Bool(), fa->value < fb->value); }); return std::nullopt; } @@ -376,8 +376,8 @@ inline ffi::Optional<PrimExpr> TryConstFold<tir::LT>(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional<PrimExpr> TryConstFold<tir::LE>(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value); + if (pa && pb) return IntImm(DataType::Bool(1), pa->value <= pb->value); + if (fa && fb) return IntImm(DataType::Bool(1), fa->value <= fb->value); }); return std::nullopt; } @@ -385,8 +385,8 @@ inline ffi::Optional<PrimExpr> TryConstFold<tir::LE>(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional<PrimExpr> TryConstFold<tir::EQ>(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value); + if (pa && pb) return IntImm(DataType::Bool(1), pa->value == pb->value); + if (fa && fb) return IntImm(DataType::Bool(1), fa->value == fb->value); }); return std::nullopt; } @@ -394,8 +394,8 @@ inline ffi::Optional<PrimExpr> TryConstFold<tir::EQ>(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional<PrimExpr> TryConstFold<tir::NE>(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value); + if (pa && pb) return IntImm(DataType::Bool(1), pa->value != pb->value); + if (fa && fb) return IntImm(DataType::Bool(1), fa->value != fb->value); }); return std::nullopt; } @@ -426,7 +426,7 @@ template <> inline ffi::Optional<PrimExpr> TryConstFold<tir::Not>(PrimExpr a) { const IntImmNode* pa = a.as<IntImmNode>(); if 
(pa) { - return IntImm(DataType::UInt(1), !(pa->value)); + return IntImm(DataType::Bool(), !(pa->value)); } return std::nullopt; } diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 6c0065c29c..b856854a5d 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -53,8 +53,9 @@ PrimExpr PrimExpr::ConvertFallbackValue(ffi::String value) { return tir::StringI IntImm::IntImm(DataType dtype, int64_t value, Span span) { ICHECK(dtype.is_scalar()) << "ValueError: IntImm can only take scalar, but " << dtype << " was supplied."; - ICHECK(dtype.is_int() || dtype.is_uint()) - << "ValueError: IntImm supports only int or uint type, but " << dtype << " was supplied."; + ICHECK(dtype.is_int() || dtype.is_uint() || dtype.is_bool()) + << "ValueError: IntImm supports only int or uint or bool type, but " << dtype + << " was supplied."; if (dtype.is_uint()) { ICHECK_GE(value, 0U) << "ValueError: Literal value " << value << " is negative for unsigned integer type " << dtype; @@ -62,7 +63,7 @@ IntImm::IntImm(DataType dtype, int64_t value, Span span) { ICHECK_LT(value, 1LL << dtype.bits()) << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; } - } else if (dtype.bits() == 1) { + } else if (dtype.bits() == 1 || dtype.is_bool()) { // int(1) ICHECK(value == 0 || value == 1) << "ValueError: " << value << " exceeds range of " << dtype; } else if (dtype.bits() < 64) { diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index ff8596cd79..5bcb5f2199 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -328,7 +328,7 @@ inline Constant MakeConstantScalar(T value, DataType dtype) { *static_cast<int32_t*>(arr->data) = static_cast<int32_t>(value); } else if (dtype == DataType::Int(64)) { *static_cast<int64_t*>(arr->data) = static_cast<int64_t>(value); - } else if (dtype == DataType::UInt(1)) { + } else if (dtype == DataType::Bool()) { *static_cast<bool*>(arr->data) = static_cast<bool>(value); } else if (dtype == DataType::UInt(8)) { 
*static_cast<uint8_t*>(arr->data) = static_cast<uint8_t>(value); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index bdb0c6b738..39bb221cf2 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -148,6 +148,7 @@ void CodeGenLLVM::Init(const std::string& module_name, LLVMTarget* llvm_target, // types t_void_ = llvm::Type::getVoidTy(*ctx); t_void_p_ = llvmGetPointerTo(llvm::Type::getInt8Ty(*ctx), GetGlobalAddressSpace()); + t_int1_ = llvm::Type::getInt1Ty(*ctx); t_int_ = llvm::Type::getInt32Ty(*ctx); t_char_ = llvm::Type::getInt8Ty(*ctx); t_int8_ = llvm::Type::getInt8Ty(*ctx); @@ -572,6 +573,9 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { if (dtype.is_void()) { return t_void_; } + if (dtype.is_bool()) { + return t_int1_; + } llvm::Type* etype = nullptr; llvm::LLVMContext* ctx = llvm_target_->GetContext(); if (dtype.is_int() || dtype.is_uint()) { @@ -922,7 +926,7 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va if (to.is_handle()) { return builder_->CreateBitCast(value, target); - } else if (to.is_uint() && to.bits() == 1) { + } else if (to.is_bool()) { if (from.is_float()) { llvm::Constant* zero = llvm::ConstantFP::get(DTypeToLLVMType(from), 0.); return builder_->CreateFCmpONE(value, zero); @@ -943,7 +947,7 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va } } else if (from.is_int() && to.is_float()) { return builder_->CreateSIToFP(value, target); - } else if (from.is_uint() && to.is_float()) { + } else if ((from.is_uint() || from.is_bool()) && to.is_float()) { return builder_->CreateUIToFP(value, target); } else { ICHECK(from.is_float() && to.is_float()); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 5cf053cf71..efec7ad6ad 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -536,6 +536,7 @@ class CodeGenLLVM : public 
ExprFunctor<llvm::Value*(const PrimExpr&)>, llvm::Type* t_void_{nullptr}; llvm::PointerType* t_void_p_{nullptr}; llvm::Type* t_int_{nullptr}; + llvm::Type* t_int1_{nullptr}; llvm::Type* t_char_{nullptr}; llvm::Type* t_int8_{nullptr}; llvm::Type* t_int16_{nullptr}; diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc index 60fa786d52..917036b8e2 100644 --- a/src/target/source/codegen_source_base.cc +++ b/src/target/source/codegen_source_base.cc @@ -109,6 +109,11 @@ void CodeGenSourceBase::PrintType(DataType type, std::ostream& os) { // NOLINT( os << "void"; return; } + // Plain C may not have a bool type, so default to int; subclasses can handle bool differently. + if (type.is_bool()) { + os << "int"; + return; + } if (type.is_float()) { if (type.bits() == 32) { os << "float"; diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index ddbc22d88a..c062926cc2 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -430,7 +430,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { spirv::Value dst_ptr = builder_->StructArrayAccess(dst_ptr_type, var_map_[buffer_node], MakeValue(dst_index)); spirv::Value src_ptr = VisitExpr(op->args[5]); - spirv::SType type_bool = builder_->GetSType(DataType::UInt(1)); + spirv::SType type_bool = builder_->GetSType(DataType::Bool()); spirv::Value t_val = builder_->UIntImm(type_bool, 1); spirv::Value f_val = builder_->UIntImm(type_bool, 0); spirv::Value loaded = @@ -492,7 +492,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], MakeValue(index)); uint32_t mask = spv::MemoryAccessMaskNone; spirv::Value loaded = builder_->MakeValue(spv::OpLoad, fragment_type, ptr, mask); - spirv::SType type_bool = builder_->GetSType(DataType::UInt(1)); + spirv::SType type_bool = builder_->GetSType(DataType::Bool()); spirv::Value t_val = builder_->UIntImm(type_bool, 1); spirv::Value f_val =
builder_->UIntImm(type_bool, 0); builder_->MakeInst(spv::OpCooperativeMatrixStoreNV, dst_ptr, loaded, stride_val, diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index 545e677af9..bbe616a5d5 100644 --- a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -76,7 +76,7 @@ void IRBuilder::InitPreDefs() { ext_glsl450_ = ExtInstImport("GLSL.std.450"); t_int32_ = DeclareType(DataType::Int(32)); t_uint32_ = DeclareType(DataType::UInt(32)); - t_bool_ = DeclareType(DataType::UInt(1)); + t_bool_ = DeclareType(DataType::Bool()); t_fp32_ = DeclareType(DataType::Float(32)); const_i32_zero_ = IntImm(t_int32_, 0); @@ -115,7 +115,7 @@ std::vector<uint32_t> IRBuilder::Finalize() { SType IRBuilder::GetSType(const DataType& dtype, uint32_t row, uint32_t col) { if (dtype == DataType::Int(32)) { return t_int32_; - } else if (dtype == DataType::UInt(1)) { + } else if (dtype == DataType::Bool()) { return t_bool_; } else if (dtype == DataType::Float(32)) { return t_fp32_; @@ -467,7 +467,7 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { } ICHECK_LE(dtype.type.bits(), 64); Value ret = NewValue(dtype, kConstant); - if (dtype.type == DataType::UInt(1)) { + if (dtype.type == DataType::Bool()) { // bool types. 
if (*pvalue) { ib_.Begin(spv::OpConstantTrue).AddSeq(dtype, ret); @@ -501,8 +501,7 @@ SType IRBuilder::DeclareType(const DataType& dtype, uint32_t row, uint32_t col) SType t; t.id = id_counter_++; t.type = dtype; - if (dtype.bits() == 1) { - ICHECK(dtype.is_uint()); + if (dtype.is_bool()) { ib_.Begin(spv::OpTypeBool).Add(t).Commit(&global_); } else if (dtype.is_int()) { ib_.Begin(spv::OpTypeInt).AddSeq(t, dtype.bits(), 1).Commit(&global_); @@ -822,19 +821,19 @@ Value IRBuilder::Mod(Value a, Value b) { } } -#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - ICHECK_EQ(a.stype.id, b.stype.id); \ - ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ - const auto& bool_type = this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \ - if (a.stype.type.is_int()) { \ - return MakeValue(spv::OpS##_Op, bool_type, a, b); \ - } else if (a.stype.type.is_uint()) { \ - return MakeValue(spv::OpU##_Op, bool_type, a, b); \ - } else { \ - ICHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ - } \ +#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + ICHECK_EQ(a.stype.id, b.stype.id); \ + ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ + const auto& bool_type = this->GetSType(DataType::Bool().with_lanes(a.stype.type.lanes())); \ + if (a.stype.type.is_int()) { \ + return MakeValue(spv::OpS##_Op, bool_type, a, b); \ + } else if (a.stype.type.is_uint()) { \ + return MakeValue(spv::OpU##_Op, bool_type, a, b); \ + } else { \ + ICHECK(a.stype.type.is_float()); \ + return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ + } \ } DEFINE_BUILDER_CMP_OP(LT, LessThan); @@ -842,17 +841,17 @@ DEFINE_BUILDER_CMP_OP(LE, LessThanEqual); DEFINE_BUILDER_CMP_OP(GT, GreaterThan); DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual); -#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - ICHECK_EQ(a.stype.id, 
b.stype.id); \ - ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ - const auto& bool_type = this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \ - if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ - return MakeValue(spv::OpI##_Op, bool_type, a, b); \ - } else { \ - ICHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ - } \ +#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + ICHECK_EQ(a.stype.id, b.stype.id); \ + ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ + const auto& bool_type = this->GetSType(DataType::Bool().with_lanes(a.stype.type.lanes())); \ + if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ + return MakeValue(spv::OpI##_Op, bool_type, a, b); \ + } else { \ + ICHECK(a.stype.type.is_float()); \ + return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ + } \ } DEFINE_BUILDER_CMP_UOP(EQ, Equal); @@ -860,7 +859,7 @@ DEFINE_BUILDER_CMP_UOP(NE, NotEqual); Value IRBuilder::Select(Value cond, Value a, Value b) { ICHECK_EQ(a.stype.id, b.stype.id); - ICHECK_EQ(cond.stype.type.element_of(), DataType::UInt(1)); + ICHECK_EQ(cond.stype.type.element_of(), DataType::Bool()); return MakeValue(spv::OpSelect, a.stype, cond, a, b); }
