This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch tvm-ffi-bool
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit a975381e85b7c52feaa108d3a628bb800dccb0f4
Author: tqchen <[email protected]>
AuthorDate: Wed Nov 12 12:26:59 2025 -0500

    [DataType] Update to use explicit Bool Type Aligning with DLPack
    
    This PR updates the project to use an explicit bool type (kDLBool), which
    helps us align with DLPack. It also streamlines the explicit use of bool types.
---
 .gitmodules                                        |  2 +-
 3rdparty/tvm-ffi                                   |  2 +-
 include/tvm/runtime/data_type.h                    |  9 ++--
 include/tvm/tir/op.h                               |  6 +--
 python/tvm/script/parser/tir/operation.py          |  2 +
 python/tvm/tir/ir_builder.py                       |  2 +-
 src/arith/const_fold.h                             | 26 +++++-----
 src/arith/const_int_bound.cc                       |  5 +-
 src/ir/expr.cc                                     |  7 +--
 src/relax/transform/utils.h                        |  2 +-
 src/runtime/vm/builtin.cc                          |  2 +-
 src/target/llvm/codegen_llvm.cc                    |  8 ++-
 src/target/llvm/codegen_llvm.h                     |  1 +
 src/target/source/codegen_opencl.cc                |  6 +++
 src/target/source/codegen_source_base.cc           |  5 ++
 src/target/spirv/codegen_spirv.cc                  |  4 +-
 src/target/spirv/ir_builder.cc                     | 59 +++++++++++-----------
 src/tir/ir/expr.cc                                 |  2 +-
 src/tir/ir/stmt.cc                                 |  5 +-
 src/tir/op/op.cc                                   | 41 ++++++++++-----
 tests/cpp/tir_scalable_datatype.cc                 |  4 +-
 tests/python/arith/test_arith_rewrite_simplify.py  | 22 ++++----
 tests/python/relax/test_op_nn.py                   |  2 -
 tests/python/tir-base/test_tir_constructor.py      | 12 ++---
 tests/python/tir-base/test_tir_nodes.py            |  2 +-
 tests/python/tir-base/test_tir_ops.py              | 14 ++---
 .../tvmscript/test_tvmscript_ir_builder_tir.py     |  2 +-
 .../python/tvmscript/test_tvmscript_printer_tir.py |  4 +-
 28 files changed, 148 insertions(+), 110 deletions(-)

diff --git a/.gitmodules b/.gitmodules
index 0513981e58..7acfdb6a38 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -24,4 +24,4 @@
        url = https://github.com/madler/zlib.git
 [submodule "3rdparty/tvm-ffi"]
        path = 3rdparty/tvm-ffi
-       url = https://github.com/apache/tvm-ffi
+       url = https://github.com/tqchen/tvm-ffi
diff --git a/3rdparty/tvm-ffi b/3rdparty/tvm-ffi
index f703a0cf93..646d3bab6b 160000
--- a/3rdparty/tvm-ffi
+++ b/3rdparty/tvm-ffi
@@ -1 +1 @@
-Subproject commit f703a0cf9358fa30d8faee719f905c58d8ca6ee3
+Subproject commit 646d3bab6bfe381213b25659f395c2ea1df5b91f
diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index 0af3022bbd..3a91d4777b 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -60,6 +60,7 @@ class DataType {
     kFloat = kDLFloat,
     kHandle = kDLOpaqueHandle,
     kBFloat = kDLBfloat,
+    kBool = kDLBool,
     kFloat8_e3m4 = kDLFloat8_e3m4,
     kFloat8_e4m3 = kDLFloat8_e4m3,
     kFloat8_e4m3b11fnuz = kDLFloat8_e4m3b11fnuz,
@@ -138,7 +139,9 @@ class DataType {
   /*! \return whether type is a scalar type. */
   bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; }
   /*! \return whether type is a scalar type. */
-  bool is_bool() const { return code() == DataType::kUInt && bits() == 1; }
+  bool is_bool() const { return code() == DataType::kBool; }
+  /*! \return whether type can be used in a predicate expression. */
+  bool is_predicate_dtype() const { return is_bool() || (is_uint() && bits() 
== 1); }
   /*! \return whether type is a float type. */
   bool is_float() const { return code() == DataType::kFloat; }
   /*! \return whether type is a bfloat type. */
@@ -204,7 +207,7 @@ class DataType {
   /*! \return whether type is a vector type. */
   bool is_vector() const { return lanes() > 1; }
   /*! \return whether type is a bool vector type. */
-  bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && 
bits() == 1; }
+  bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && 
is_bool(); }
   /*! \return whether type is a Void type. */
   bool is_void() const {
     return code() == DataType::kHandle && bits() == 0 && 
static_cast<int16_t>(data_.lanes) == 0;
@@ -381,7 +384,7 @@ class DataType {
    * \return The constructed data type.
    */
   static DataType Bool(int lanes = 1, bool is_scalable = false) {
-    return DataType::UInt(1, lanes, is_scalable);
+    return DataType(kDLBool, 8, lanes, is_scalable);
   }
   /*!
    * \brief Construct a handle type.
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 6a0f427b80..57f8681514 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -816,7 +816,7 @@ inline PrimExpr make_zero(DataType t, Span span = Span());
  * \return The result expression.
  */
 inline PrimExpr const_true(int lanes = 1, Span span = Span()) {
-  return make_const(DataType::UInt(1, lanes), 1);
+  return make_const(DataType::Bool(lanes), 1);
 }
 /*!
  * \brief Make a constant false expression.
@@ -825,7 +825,7 @@ inline PrimExpr const_true(int lanes = 1, Span span = 
Span()) {
  * \return The result expression.
  */
 inline PrimExpr const_false(int lanes = 1, Span span = Span()) {
-  return make_const(DataType::UInt(1, lanes), 0);
+  return make_const(DataType::Bool(lanes), 0);
 }
 /*!
  * \brief Get x as constant int expression.
@@ -957,7 +957,7 @@ inline bool is_no_op(const tir::Stmt& stmt) {
 
 template <typename ValueType>
 inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = 
Span()) {
-  if (t.is_int()) return IntImm(t, static_cast<int64_t>(value), span);
+  if (t.is_int() || t.is_bool()) return IntImm(t, static_cast<int64_t>(value), 
span);
   if (t.is_uint()) {
     // Use IntImm if it is a small integer
     uint64_t uval = static_cast<uint64_t>(value);
diff --git a/python/tvm/script/parser/tir/operation.py 
b/python/tvm/script/parser/tir/operation.py
index 22f996a456..b22b0a7335 100644
--- a/python/tvm/script/parser/tir/operation.py
+++ b/python/tvm/script/parser/tir/operation.py
@@ -61,6 +61,7 @@ def _register_expr_op(ty: Type):  # pylint: 
disable=invalid-name
                 if (
                     DataType(b.dtype).type_code == DataTypeCode.INT
                     or DataType(b.dtype).type_code == DataTypeCode.UINT
+                    or DataType(b.dtype).type_code == DataTypeCode.BOOL
                 ):
                     a = IntImm(_get_type_str(b.dtype), a)
                 elif DataType(b.dtype).type_code == DataTypeCode.FLOAT:
@@ -80,6 +81,7 @@ def _register_expr_op(ty: Type):  # pylint: 
disable=invalid-name
             if (
                 DataType(a.dtype).type_code == DataTypeCode.INT
                 or DataType(a.dtype).type_code == DataTypeCode.UINT
+                or DataType(a.dtype).type_code == DataTypeCode.BOOL
             ):
                 b = IntImm(_get_type_str(a.dtype), b)
             elif DataType(a.dtype).type_code == DataTypeCode.FLOAT:
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index d6466b0922..a6313ae3bc 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -448,7 +448,7 @@ class IRBuilder(object):
         )
 
         buffer_var = buffer.data
-        self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, 
dtype="uint1"), x))
+        self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, 
dtype="bool"), x))
         return BufferVar(self, buffer, dtype)
 
     def pointer(self, content_type, name="ptr", scope=""):
diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h
index dda7f67465..1851142b02 100644
--- a/src/arith/const_fold.h
+++ b/src/arith/const_fold.h
@@ -349,8 +349,8 @@ inline ffi::Optional<PrimExpr> 
TryConstFold<tir::Max>(PrimExpr a, PrimExpr b) {
 template <>
 inline ffi::Optional<PrimExpr> TryConstFold<tir::GT>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-    if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value);
-    if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value);
+    if (pa && pb) return IntImm(DataType::Bool(), pa->value > pb->value);
+    if (fa && fb) return IntImm(DataType::Bool(), fa->value > fb->value);
   });
   return std::nullopt;
 }
@@ -358,8 +358,8 @@ inline ffi::Optional<PrimExpr> 
TryConstFold<tir::GT>(PrimExpr a, PrimExpr b) {
 template <>
 inline ffi::Optional<PrimExpr> TryConstFold<tir::GE>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-    if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value);
-    if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value);
+    if (pa && pb) return IntImm(DataType::Bool(), pa->value >= pb->value);
+    if (fa && fb) return IntImm(DataType::Bool(), fa->value >= fb->value);
   });
   return std::nullopt;
 }
@@ -367,8 +367,8 @@ inline ffi::Optional<PrimExpr> 
TryConstFold<tir::GE>(PrimExpr a, PrimExpr b) {
 template <>
 inline ffi::Optional<PrimExpr> TryConstFold<tir::LT>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-    if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value);
-    if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value);
+    if (pa && pb) return IntImm(DataType::Bool(), pa->value < pb->value);
+    if (fa && fb) return IntImm(DataType::Bool(), fa->value < fb->value);
   });
   return std::nullopt;
 }
@@ -376,8 +376,8 @@ inline ffi::Optional<PrimExpr> 
TryConstFold<tir::LT>(PrimExpr a, PrimExpr b) {
 template <>
 inline ffi::Optional<PrimExpr> TryConstFold<tir::LE>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-    if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value);
-    if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value);
+    if (pa && pb) return IntImm(DataType::Bool(), pa->value <= pb->value);
+    if (fa && fb) return IntImm(DataType::Bool(), fa->value <= fb->value);
   });
   return std::nullopt;
 }
@@ -385,8 +385,8 @@ inline ffi::Optional<PrimExpr> 
TryConstFold<tir::LE>(PrimExpr a, PrimExpr b) {
 template <>
 inline ffi::Optional<PrimExpr> TryConstFold<tir::EQ>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-    if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value);
-    if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value);
+    if (pa && pb) return IntImm(DataType::Bool(), pa->value == pb->value);
+    if (fa && fb) return IntImm(DataType::Bool(), fa->value == fb->value);
   });
   return std::nullopt;
 }
@@ -394,8 +394,8 @@ inline ffi::Optional<PrimExpr> 
TryConstFold<tir::EQ>(PrimExpr a, PrimExpr b) {
 template <>
 inline ffi::Optional<PrimExpr> TryConstFold<tir::NE>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-    if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value);
-    if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value);
+    if (pa && pb) return IntImm(DataType::Bool(), pa->value != pb->value);
+    if (fa && fb) return IntImm(DataType::Bool(), fa->value != fb->value);
   });
   return std::nullopt;
 }
@@ -426,7 +426,7 @@ template <>
 inline ffi::Optional<PrimExpr> TryConstFold<tir::Not>(PrimExpr a) {
   const IntImmNode* pa = a.as<IntImmNode>();
   if (pa) {
-    return IntImm(DataType::UInt(1), !(pa->value));
+    return IntImm(DataType::Bool(), !(pa->value));
   }
   return std::nullopt;
 }
diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index 7e1d8fb3fb..d8296bafd9 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -735,9 +735,12 @@ class ConstIntBoundAnalyzer::Impl
    * \return Bound that represent everything dtype can represent.
    */
   static Entry Everything(DataType dtype) {
-    if (!dtype.is_int() && !dtype.is_uint()) {
+    if (!dtype.is_int() && !dtype.is_uint() && !dtype.is_bool()) {
       return MakeBound(kNegInf, kPosInf);
     }
+    if (dtype.is_bool()) {
+      return MakeBound(0, 1);
+    }
     Entry ret;
     int64_t vbits = dtype.bits() - static_cast<int>(dtype.is_int());
     if (dtype.is_uint()) {
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index 6c0065c29c..b856854a5d 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -53,8 +53,9 @@ PrimExpr PrimExpr::ConvertFallbackValue(ffi::String value) { 
return tir::StringI
 IntImm::IntImm(DataType dtype, int64_t value, Span span) {
   ICHECK(dtype.is_scalar()) << "ValueError: IntImm can only take scalar, but " 
<< dtype
                             << " was supplied.";
-  ICHECK(dtype.is_int() || dtype.is_uint())
-      << "ValueError: IntImm supports only int or uint type, but " << dtype << 
" was supplied.";
+  ICHECK(dtype.is_int() || dtype.is_uint() || dtype.is_bool())
+      << "ValueError: IntImm supports only int or uint or bool type, but " << 
dtype
+      << " was supplied.";
   if (dtype.is_uint()) {
     ICHECK_GE(value, 0U) << "ValueError: Literal value " << value
                          << " is negative for unsigned integer type " << dtype;
@@ -62,7 +63,7 @@ IntImm::IntImm(DataType dtype, int64_t value, Span span) {
       ICHECK_LT(value, 1LL << dtype.bits())
           << "ValueError: Literal value " << value << " exceeds maximum of " 
<< dtype;
     }
-  } else if (dtype.bits() == 1) {
+  } else if (dtype.bits() == 1 || dtype.is_bool()) {
     // int(1)
     ICHECK(value == 0 || value == 1) << "ValueError: " << value << " exceeds 
range of " << dtype;
   } else if (dtype.bits() < 64) {
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index ff8596cd79..5bcb5f2199 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -328,7 +328,7 @@ inline Constant MakeConstantScalar(T value, DataType dtype) 
{
     *static_cast<int32_t*>(arr->data) = static_cast<int32_t>(value);
   } else if (dtype == DataType::Int(64)) {
     *static_cast<int64_t*>(arr->data) = static_cast<int64_t>(value);
-  } else if (dtype == DataType::UInt(1)) {
+  } else if (dtype == DataType::Bool()) {
     *static_cast<bool*>(arr->data) = static_cast<bool>(value);
   } else if (dtype == DataType::UInt(8)) {
     *static_cast<uint8_t*>(arr->data) = static_cast<uint8_t>(value);
diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc
index 13446a158f..1bd3084c21 100644
--- a/src/runtime/vm/builtin.cc
+++ b/src/runtime/vm/builtin.cc
@@ -535,7 +535,7 @@ bool ReadIfCond(ffi::AnyView cond) {
   if (arr->device.device_type != kDLCPU) {
     arr = arr.CopyTo(DLDevice{kDLCPU, 0});
   }
-  ICHECK(arr->dtype.code == kDLInt || arr->dtype.code == kDLUInt);
+  ICHECK(arr->dtype.code == kDLInt || arr->dtype.code == kDLUInt || 
arr->dtype.code == kDLBool);
   int64_t result;
   switch (arr->dtype.bits) {
     case 1: {
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index bdb0c6b738..39bb221cf2 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -148,6 +148,7 @@ void CodeGenLLVM::Init(const std::string& module_name, 
LLVMTarget* llvm_target,
   // types
   t_void_ = llvm::Type::getVoidTy(*ctx);
   t_void_p_ = llvmGetPointerTo(llvm::Type::getInt8Ty(*ctx), 
GetGlobalAddressSpace());
+  t_int1_ = llvm::Type::getInt1Ty(*ctx);
   t_int_ = llvm::Type::getInt32Ty(*ctx);
   t_char_ = llvm::Type::getInt8Ty(*ctx);
   t_int8_ = llvm::Type::getInt8Ty(*ctx);
@@ -572,7 +573,10 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
   if (dtype.is_void()) {
     return t_void_;
   }
   llvm::Type* etype = nullptr;
   llvm::LLVMContext* ctx = llvm_target_->GetContext();
-  if (dtype.is_int() || dtype.is_uint()) {
+  if (dtype.is_bool()) {
+    // Bool is lowered to the LLVM i1 element type.
+    etype = t_int1_;
+  } else if (dtype.is_int() || dtype.is_uint()) {
     etype = llvm::Type::getIntNTy(*ctx, dtype.bits());
@@ -922,7 +926,7 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, 
DataType to, llvm::Value* va
 
   if (to.is_handle()) {
     return builder_->CreateBitCast(value, target);
-  } else if (to.is_uint() && to.bits() == 1) {
+  } else if (to.is_predicate_dtype()) {
     if (from.is_float()) {
       llvm::Constant* zero = llvm::ConstantFP::get(DTypeToLLVMType(from), 0.);
       return builder_->CreateFCmpONE(value, zero);
@@ -943,7 +947,7 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, 
DataType to, llvm::Value* va
     }
   } else if (from.is_int() && to.is_float()) {
     return builder_->CreateSIToFP(value, target);
-  } else if (from.is_uint() && to.is_float()) {
+  } else if ((from.is_uint() || from.is_bool()) && to.is_float()) {
     return builder_->CreateUIToFP(value, target);
   } else {
     ICHECK(from.is_float() && to.is_float());
diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index 5cf053cf71..efec7ad6ad 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -536,6 +536,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const 
PrimExpr&)>,
   llvm::Type* t_void_{nullptr};
   llvm::PointerType* t_void_p_{nullptr};
   llvm::Type* t_int_{nullptr};
+  llvm::Type* t_int1_{nullptr};
   llvm::Type* t_char_{nullptr};
   llvm::Type* t_int8_{nullptr};
   llvm::Type* t_int16_{nullptr};
diff --git a/src/target/source/codegen_opencl.cc 
b/src/target/source/codegen_opencl.cc
index 769401c4bc..8ea55b8ff5 100644
--- a/src/target/source/codegen_opencl.cc
+++ b/src/target/source/codegen_opencl.cc
@@ -230,6 +230,12 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& 
os) {  // NOLINT(*)
       os << lanes;
       return;
     }
+  } else if (t.is_bool()) {
+    os << "uint";
+    if (!fail && ((lanes >= 2 && lanes <= 4) || lanes == 8 || lanes == 16)) {
+      os << lanes;
+      return;
+    }
   } else if (t.is_uint() || t.is_int()) {
     if (t.is_uint()) {
       os << 'u';
diff --git a/src/target/source/codegen_source_base.cc 
b/src/target/source/codegen_source_base.cc
index 60fa786d52..917036b8e2 100644
--- a/src/target/source/codegen_source_base.cc
+++ b/src/target/source/codegen_source_base.cc
@@ -109,6 +109,11 @@ void CodeGenSourceBase::PrintType(DataType type, 
std::ostream& os) {  // NOLINT(
     os << "void";
     return;
   }
+  // Plain C may lack a native bool type, so print bool as int; subclasses can override.
+  if (type.is_bool()) {
+    os << "int";
+    return;
+  }
   if (type.is_float()) {
     if (type.bits() == 32) {
       os << "float";
diff --git a/src/target/spirv/codegen_spirv.cc 
b/src/target/spirv/codegen_spirv.cc
index ddbc22d88a..c062926cc2 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -430,7 +430,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
     spirv::Value dst_ptr =
         builder_->StructArrayAccess(dst_ptr_type, var_map_[buffer_node], 
MakeValue(dst_index));
     spirv::Value src_ptr = VisitExpr(op->args[5]);
-    spirv::SType type_bool = builder_->GetSType(DataType::UInt(1));
+    spirv::SType type_bool = builder_->GetSType(DataType::Bool());
     spirv::Value t_val = builder_->UIntImm(type_bool, 1);
     spirv::Value f_val = builder_->UIntImm(type_bool, 0);
     spirv::Value loaded =
@@ -492,7 +492,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
         builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], 
MakeValue(index));
     uint32_t mask = spv::MemoryAccessMaskNone;
     spirv::Value loaded = builder_->MakeValue(spv::OpLoad, fragment_type, ptr, 
mask);
-    spirv::SType type_bool = builder_->GetSType(DataType::UInt(1));
+    spirv::SType type_bool = builder_->GetSType(DataType::Bool());
     spirv::Value t_val = builder_->UIntImm(type_bool, 1);
     spirv::Value f_val = builder_->UIntImm(type_bool, 0);
     builder_->MakeInst(spv::OpCooperativeMatrixStoreNV, dst_ptr, loaded, 
stride_val,
diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc
index 545e677af9..bbe616a5d5 100644
--- a/src/target/spirv/ir_builder.cc
+++ b/src/target/spirv/ir_builder.cc
@@ -76,7 +76,7 @@ void IRBuilder::InitPreDefs() {
   ext_glsl450_ = ExtInstImport("GLSL.std.450");
   t_int32_ = DeclareType(DataType::Int(32));
   t_uint32_ = DeclareType(DataType::UInt(32));
-  t_bool_ = DeclareType(DataType::UInt(1));
+  t_bool_ = DeclareType(DataType::Bool());
   t_fp32_ = DeclareType(DataType::Float(32));
   const_i32_zero_ = IntImm(t_int32_, 0);
 
@@ -115,7 +115,7 @@ std::vector<uint32_t> IRBuilder::Finalize() {
 SType IRBuilder::GetSType(const DataType& dtype, uint32_t row, uint32_t col) {
   if (dtype == DataType::Int(32)) {
     return t_int32_;
-  } else if (dtype == DataType::UInt(1)) {
+  } else if (dtype == DataType::Bool()) {
     return t_bool_;
   } else if (dtype == DataType::Float(32)) {
     return t_fp32_;
@@ -467,7 +467,7 @@ Value IRBuilder::GetConst_(const SType& dtype, const 
uint64_t* pvalue) {
   }
   ICHECK_LE(dtype.type.bits(), 64);
   Value ret = NewValue(dtype, kConstant);
-  if (dtype.type == DataType::UInt(1)) {
+  if (dtype.type == DataType::Bool()) {
     // bool types.
     if (*pvalue) {
       ib_.Begin(spv::OpConstantTrue).AddSeq(dtype, ret);
@@ -501,8 +501,7 @@ SType IRBuilder::DeclareType(const DataType& dtype, 
uint32_t row, uint32_t col)
     SType t;
     t.id = id_counter_++;
     t.type = dtype;
-    if (dtype.bits() == 1) {
-      ICHECK(dtype.is_uint());
+    if (dtype.is_bool()) {
       ib_.Begin(spv::OpTypeBool).Add(t).Commit(&global_);
     } else if (dtype.is_int()) {
       ib_.Begin(spv::OpTypeInt).AddSeq(t, dtype.bits(), 1).Commit(&global_);
@@ -822,19 +821,19 @@ Value IRBuilder::Mod(Value a, Value b) {
   }
 }
 
-#define DEFINE_BUILDER_CMP_OP(_OpName, _Op)                                    
                 \
-  Value IRBuilder::_OpName(Value a, Value b) {                                 
                 \
-    ICHECK_EQ(a.stype.id, b.stype.id);                                         
                 \
-    ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes());                     
                 \
-    const auto& bool_type = 
this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \
-    if (a.stype.type.is_int()) {                                               
                 \
-      return MakeValue(spv::OpS##_Op, bool_type, a, b);                        
                 \
-    } else if (a.stype.type.is_uint()) {                                       
                 \
-      return MakeValue(spv::OpU##_Op, bool_type, a, b);                        
                 \
-    } else {                                                                   
                 \
-      ICHECK(a.stype.type.is_float());                                         
                 \
-      return MakeValue(spv::OpFOrd##_Op, bool_type, a, b);                     
                 \
-    }                                                                          
                 \
+#define DEFINE_BUILDER_CMP_OP(_OpName, _Op)                                    
                \
+  Value IRBuilder::_OpName(Value a, Value b) {                                 
                \
+    ICHECK_EQ(a.stype.id, b.stype.id);                                         
                \
+    ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes());                     
                \
+    const auto& bool_type = 
this->GetSType(DataType::Bool().with_lanes(a.stype.type.lanes())); \
+    if (a.stype.type.is_int()) {                                               
                \
+      return MakeValue(spv::OpS##_Op, bool_type, a, b);                        
                \
+    } else if (a.stype.type.is_uint()) {                                       
                \
+      return MakeValue(spv::OpU##_Op, bool_type, a, b);                        
                \
+    } else {                                                                   
                \
+      ICHECK(a.stype.type.is_float());                                         
                \
+      return MakeValue(spv::OpFOrd##_Op, bool_type, a, b);                     
                \
+    }                                                                          
                \
   }
 
 DEFINE_BUILDER_CMP_OP(LT, LessThan);
@@ -842,17 +841,17 @@ DEFINE_BUILDER_CMP_OP(LE, LessThanEqual);
 DEFINE_BUILDER_CMP_OP(GT, GreaterThan);
 DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual);
 
-#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op)                                   
                 \
-  Value IRBuilder::_OpName(Value a, Value b) {                                 
                 \
-    ICHECK_EQ(a.stype.id, b.stype.id);                                         
                 \
-    ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes());                     
                 \
-    const auto& bool_type = 
this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \
-    if (a.stype.type.is_int() || a.stype.type.is_uint()) {                     
                 \
-      return MakeValue(spv::OpI##_Op, bool_type, a, b);                        
                 \
-    } else {                                                                   
                 \
-      ICHECK(a.stype.type.is_float());                                         
                 \
-      return MakeValue(spv::OpFOrd##_Op, bool_type, a, b);                     
                 \
-    }                                                                          
                 \
+#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op)                                   
                \
+  Value IRBuilder::_OpName(Value a, Value b) {                                 
                \
+    ICHECK_EQ(a.stype.id, b.stype.id);                                         
                \
+    ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes());                     
                \
+    const auto& bool_type = 
this->GetSType(DataType::Bool().with_lanes(a.stype.type.lanes())); \
+    if (a.stype.type.is_int() || a.stype.type.is_uint()) {                     
                \
+      return MakeValue(spv::OpI##_Op, bool_type, a, b);                        
                \
+    } else {                                                                   
                \
+      ICHECK(a.stype.type.is_float());                                         
                \
+      return MakeValue(spv::OpFOrd##_Op, bool_type, a, b);                     
                \
+    }                                                                          
                \
   }
 
 DEFINE_BUILDER_CMP_UOP(EQ, Equal);
@@ -860,7 +859,7 @@ DEFINE_BUILDER_CMP_UOP(NE, NotEqual);
 
 Value IRBuilder::Select(Value cond, Value a, Value b) {
   ICHECK_EQ(a.stype.id, b.stype.id);
-  ICHECK_EQ(cond.stype.type.element_of(), DataType::UInt(1));
+  ICHECK_EQ(cond.stype.type.element_of(), DataType::Bool());
   return MakeValue(spv::OpSelect, a.stype, cond, a, b);
 }
 
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 252b8693a7..5eee4ffd8b 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -840,7 +840,7 @@ BufferLoad::BufferLoad(Buffer buffer, ffi::Array<PrimExpr> 
indices,
         << " lanes. The number of lanes must match.";
 
     DataType predicate_element_dtype = predicate_dtype.element_of();
-    ICHECK(predicate_element_dtype.is_bool())
+    ICHECK(predicate_element_dtype.is_predicate_dtype())
         << "Predicate mask elements must be boolean values, but got " << 
predicate_element_dtype
         << ".";
   }
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index d33a01340b..47622757e5 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -485,7 +485,7 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, 
ffi::Array<PrimExpr> ind
         << " lanes. The number of lanes must match.";
 
     DataType predicate_element_dtype = predicate_dtype.element_of();
-    ICHECK(predicate_element_dtype.is_bool())
+    ICHECK(predicate_element_dtype.is_predicate_dtype())
         << "Predicate mask elements must be boolean values, but got " << 
predicate_element_dtype
         << ".";
   }
@@ -687,7 +687,8 @@ BlockRealize::BlockRealize(ffi::Array<PrimExpr> values, 
PrimExpr predicate, Bloc
                            Span span) {
   CHECK_EQ(block->iter_vars.size(), values.size())
       << "ValueError: BlockRealize needs to have the same number of iter_vars 
and binding values";
-  CHECK(predicate.dtype().is_bool()) << "TypeError: Expect Block.predicate to be a bool expression";
+  CHECK(predicate.dtype().is_predicate_dtype())
+      << "TypeError: Expect Block.predicate to be a bool expression";
   ObjectPtr<BlockRealizeNode> node = ffi::make_object<BlockRealizeNode>();
   node->iter_values = std::move(values);
   node->predicate = std::move(predicate);
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 935f9928a5..808456a80c 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -214,6 +214,12 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span 
span) {  // NOLINT(*)
   } else if (ltype.is_float4() && !rtype.is_float4()) {
     // Cast int->float4 for rhs when lhs is a float4
     rhs = cast(ltype, rhs);
+  } else if (ltype.is_bool() && (rtype.is_int() || rtype.is_uint())) {
+    // Cast bool to int for lhs when rhs is a int or uint
+    lhs = cast(rtype, lhs);
+  } else if ((ltype.is_int() || ltype.is_uint()) && rtype.is_bool()) {
+    // Cast bool to int for rhs when lhs is a int or uint
+    rhs = cast(ltype, rhs);
   } else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && 
rtype.is_uint())) {
     // Promote int to higher bits e.g. int8 + int16 --> int16 + int16
     if (ltype.bits() < rtype.bits()) {
@@ -698,10 +704,10 @@ void type_check_boolean_args(const PrimExpr& lhs, const 
PrimExpr& rhs, const cha
                                 << rhs << " of type " << rhs.dtype();
 }
 
-void type_check_integer_args(const PrimExpr& arg, const char* op) {
-  ICHECK(arg.dtype().is_int() || arg.dtype().is_uint())
-      << "Expected integer argument for " << op << ", but received " << arg << 
" of type "
-      << arg.dtype();
+void type_check_int_or_bool_args(const PrimExpr& arg, const char* op) {
+  ICHECK(arg.dtype().is_int() || arg.dtype().is_uint() || 
arg.dtype().is_bool())
+      << "Expected integer or boolean argument for " << op << ", but received 
" << arg
+      << " of type " << arg.dtype();
 }
 
 void type_check_integer_args(const PrimExpr& lhs, const PrimExpr& rhs, const 
char* op) {
@@ -712,6 +718,15 @@ void type_check_integer_args(const PrimExpr& lhs, const 
PrimExpr& rhs, const cha
       << "Expected integer argument as RHS of " << op << ", but received " << 
rhs << " of type "
       << rhs.dtype();
 }
+
+void type_check_int_or_bool_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) {
+  ICHECK(lhs.dtype().is_int() || lhs.dtype().is_uint() || lhs.dtype().is_bool())
+      << "Expected integer or boolean argument as LHS of " << op << ", but received " << lhs
+      << " of type " << lhs.dtype();
+  ICHECK(rhs.dtype().is_int() || rhs.dtype().is_uint() || rhs.dtype().is_bool())
+      << "Expected integer or boolean argument as RHS of " << op << ", but received " << rhs
+      << " of type " << rhs.dtype();
+}
 }  // namespace
 
 PrimExpr operator&&(PrimExpr a, PrimExpr b) { return logical_and(a, b); }
@@ -781,7 +796,7 @@ PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span) {
 // bitwise and
 PrimExpr operator&(PrimExpr a, PrimExpr b) { return bitwise_and(a, b); }
 PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span) {
-  type_check_integer_args(a, b, "& operator (bitwise AND)");
+  type_check_int_or_bool_args(a, b, "& operator (bitwise AND)");
   BinaryOpMatchTypes(a, b, span);
   TVM_INDEX_CONST_PROPAGATION({
     const DataType& rtype = a.dtype();
@@ -793,7 +808,7 @@ PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span) {
 // bitwise_or
 PrimExpr operator|(PrimExpr a, PrimExpr b) { return bitwise_or(a, b); }
 PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) {
-  type_check_integer_args(a, b, "| operator (bitwise OR)");
+  type_check_int_or_bool_args(a, b, "| operator (bitwise OR)");
   BinaryOpMatchTypes(a, b, span);
   TVM_INDEX_CONST_PROPAGATION({
     const DataType& rtype = a.dtype();
@@ -805,7 +820,7 @@ PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) {
 // bitwise_xor
 PrimExpr operator^(PrimExpr a, PrimExpr b) { return bitwise_xor(a, b); }
 PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) {
-  type_check_integer_args(a, b, "^ operator (bitwise XOR)");
+  type_check_int_or_bool_args(a, b, "^ operator (bitwise XOR)");
   BinaryOpMatchTypes(a, b, span);
   TVM_INDEX_CONST_PROPAGATION({
     const DataType& rtype = a.dtype();
@@ -818,7 +833,9 @@ PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) {
 PrimExpr operator~(PrimExpr a) { return bitwise_neg(a); }
 
 PrimExpr bitwise_neg(PrimExpr a, Span span) {
-  type_check_integer_args(a, "~ operator (bitwise NOT)");
+  // NOTE(review): bool operands are accepted and lowered via bitwise_not; in C-like backends `~`
+  // promotes bool to int (~1 == -2, ~0 == -1), leaving both truthy. Confirm lowering for bool.
+  type_check_int_or_bool_args(a, "~ operator (bitwise NOT)");
   return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, span);
 }
 
@@ -992,7 +1007,7 @@ TVM_TIR_REGISTER_PURE_UNARY_OP("fmod");
 
 // floor
 PrimExpr floor(PrimExpr x, Span span) {
-  if (x.dtype().is_int() || x.dtype().is_uint()) {
+  if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) {
     return x;
   }
   using tir::FloatImmNode;
@@ -1006,7 +1021,7 @@ 
TVM_TIR_REGISTER_PURE_UNARY_OP("floor").set_attr<TVectorizable>("TVectorizable",
 
 // ceil
 PrimExpr ceil(PrimExpr x, Span span) {
-  if (x.dtype().is_int() || x.dtype().is_uint()) {
+  if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) {
     return x;
   }
   using tir::FloatImmNode;
@@ -1020,7 +1035,7 @@ 
TVM_TIR_REGISTER_PURE_UNARY_OP("ceil").set_attr<TVectorizable>("TVectorizable",
 
 // round
 PrimExpr round(PrimExpr x, Span span) {
-  if (x.dtype().is_int() || x.dtype().is_uint()) {
+  if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) {
     return x;
   }
   using tir::FloatImmNode;
@@ -1034,7 +1049,7 @@ 
TVM_TIR_REGISTER_PURE_UNARY_OP("round").set_attr<TVectorizable>("TVectorizable",
 
 // nearbyint
 PrimExpr nearbyint(PrimExpr x, Span span) {
-  if (x.dtype().is_int() || x.dtype().is_uint()) {
+  if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) {
     return x;
   }
   using tir::FloatImmNode;
@@ -1048,7 +1063,7 @@ TVM_TIR_REGISTER_PURE_UNARY_OP("nearbyint");
 
 // trunc
 PrimExpr trunc(PrimExpr x, Span span) {
-  if (x.dtype().is_int() || x.dtype().is_uint()) {
+  if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) {
     return x;
   }
   using tir::FloatImmNode;
diff --git a/tests/cpp/tir_scalable_datatype.cc 
b/tests/cpp/tir_scalable_datatype.cc
index 6c42972d94..6ae6deb50d 100644
--- a/tests/cpp/tir_scalable_datatype.cc
+++ b/tests/cpp/tir_scalable_datatype.cc
@@ -167,8 +167,8 @@ TEST(ScalableDataType, 
TestScalableDataTypeInvalidLanesAccess) {
 
 TEST(ScalableDataType, TestScalableBool) {
   tvm::DataType scalable_type = tvm::DataType::Bool(4, true);
-  ASSERT_EQ(scalable_type.code(), kDLUInt);
-  ASSERT_EQ(scalable_type.bits(), 1);
+  ASSERT_EQ(scalable_type.code(), kDLBool);
+  ASSERT_EQ(scalable_type.bits(), 8);
   ASSERT_EQ(scalable_type.vscale_factor(), 4);
   ASSERT_TRUE(scalable_type.is_scalable_vector());
 }
diff --git a/tests/python/arith/test_arith_rewrite_simplify.py 
b/tests/python/arith/test_arith_rewrite_simplify.py
index 6954cf4e1d..5eaaac68f0 100644
--- a/tests/python/arith/test_arith_rewrite_simplify.py
+++ b/tests/python/arith/test_arith_rewrite_simplify.py
@@ -93,7 +93,7 @@ class TestVector(BaseCompare):
     x, y, z = te.var("x"), te.var("y"), te.var("z")
     x64 = te.var("x", dtype="int64")
     vx = te.var("vx", dtype="int32x2")
-    vc = te.var("vc", dtype="uint1")
+    vc = te.var("vc", dtype="bool")
     test_case = tvm.testing.parameter(
         # Add rules
         TestCase(tvm.tir.Ramp(x, 1, 4) + tvm.tir.Ramp(y, 2, 4), tvm.tir.Ramp(x 
+ y, 3, 4)),
@@ -285,22 +285,22 @@ class TestVector(BaseCompare):
             tvm.te.max(vx, tvm.te.max(y, x).astype("int32x2")),
         ),
         ## Logical rules
-        TestCase(y.astype("int32x2").equal(x.astype("int32x2")), 
(y.equal(x)).astype("uint1x2")),
+        TestCase(y.astype("int32x2").equal(x.astype("int32x2")), 
(y.equal(x)).astype("boolx2")),
         TestCase(
             tvm.tir.NE(y.astype("int32x2"), (x.astype("int32x2"))),
-            (tvm.tir.NE(y, x)).astype("uint1x2"),
+            (tvm.tir.NE(y, x)).astype("boolx2"),
         ),
-        TestCase(y.astype("int32x2") > x.astype("int32x2"), (x < 
y).astype("uint1x2")),
-        TestCase(y.astype("int32x2") >= x.astype("int32x2"), (x <= 
y).astype("uint1x2")),
-        TestCase(y.astype("int32x2") < x.astype("int32x2"), (y < 
x).astype("uint1x2")),
-        TestCase(y.astype("int32x2") <= x.astype("int32x2"), (y <= 
x).astype("uint1x2")),
+        TestCase(y.astype("int32x2") > x.astype("int32x2"), (x < 
y).astype("boolx2")),
+        TestCase(y.astype("int32x2") >= x.astype("int32x2"), (x <= 
y).astype("boolx2")),
+        TestCase(y.astype("int32x2") < x.astype("int32x2"), (y < 
x).astype("boolx2")),
+        TestCase(y.astype("int32x2") <= x.astype("int32x2"), (y <= 
x).astype("boolx2")),
         TestCase(
-            tvm.tir.And(y.astype("int32x2") <= x.astype("int32x2"), 
vc.astype("uint1x2")),
-            (tvm.tir.And(y <= x, vc)).astype("uint1x2"),
+            tvm.tir.And(y.astype("int32x2") <= x.astype("int32x2"), 
vc.astype("boolx2")),
+            (tvm.tir.And(y <= x, vc)).astype("boolx2"),
         ),
         TestCase(
-            tvm.tir.Or(y.astype("int32x2") <= x.astype("int32x2"), 
vc.astype("uint1x2")),
-            (tvm.tir.Or(y <= x, vc)).astype("uint1x2"),
+            tvm.tir.Or(y.astype("int32x2") <= x.astype("int32x2"), 
vc.astype("boolx2")),
+            (tvm.tir.Or(y <= x, vc)).astype("boolx2"),
         ),
     )
 
diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py
index a0ff507ef8..b076827dc4 100644
--- a/tests/python/relax/test_op_nn.py
+++ b/tests/python/relax/test_op_nn.py
@@ -1721,7 +1721,6 @@ def test_nll_loss_infer_struct_info_targets_dtype():
     w = relax.Var("w", R.Tensor((5,), "float32"))
     targets0 = relax.Var("targets", R.Tensor((3, 10, 10), "float32"))
     targets1 = relax.Var("targets", R.Tensor((3, 10, 10), "float64"))
-    targets2 = relax.Var("targets", R.Tensor((3, 10, 10), "bool"))
     targets3 = relax.Var("targets", R.Tensor((3, 10, 10), "int32"))
     targets4 = relax.Var("targets", R.Tensor((3, 10, 10), "int64"))
     targets5 = relax.Var("targets", R.Tensor((3, 10, 10), "uint32"))
@@ -1733,7 +1732,6 @@ def test_nll_loss_infer_struct_info_targets_dtype():
         bb.normalize(relax.op.nn.nll_loss(x, targets1, w))
 
     # correct cases
-    bb.normalize(relax.op.nn.nll_loss(x, targets2, w))  # bool is uint1
     bb.normalize(relax.op.nn.nll_loss(x, targets3, w))
     bb.normalize(relax.op.nn.nll_loss(x, targets4, w))
     bb.normalize(relax.op.nn.nll_loss(x, targets5, w))
diff --git a/tests/python/tir-base/test_tir_constructor.py 
b/tests/python/tir-base/test_tir_constructor.py
index 42c2998e27..4076070557 100644
--- a/tests/python/tir-base/test_tir_constructor.py
+++ b/tests/python/tir-base/test_tir_constructor.py
@@ -140,7 +140,7 @@ def test_stmt_constructor():
     assert isinstance(x, tvm.tir.AttrStmt)
     assert x.value.value == 1
 
-    x = tvm.tir.AssertStmt(tvm.tir.const(1, "uint1"), 
tvm.runtime.convert("hellow"), nop)
+    x = tvm.tir.AssertStmt(tvm.tir.const(1, "bool"), 
tvm.runtime.convert("hellow"), nop)
     assert isinstance(x, tvm.tir.AssertStmt)
     assert x.body == nop
 
@@ -150,8 +150,8 @@ def test_stmt_constructor():
     assert x.extent.value == 10
     assert x.body == nop
 
-    buffer_var = tvm.tir.Var("buf", 
tvm.ir.PointerType(tvm.ir.PrimType("uint1")))
-    buffer = tvm.tir.decl_buffer([16], "uint1", data=buffer_var)
+    buffer_var = tvm.tir.Var("buf", 
tvm.ir.PointerType(tvm.ir.PrimType("bool")))
+    buffer = tvm.tir.decl_buffer([16], "bool", data=buffer_var)
     x = tvm.tir.BufferStore(buffer, tvm.tir.IntImm("bool", 1), [10])
     assert isinstance(x, tvm.tir.BufferStore)
     assert x.buffer == buffer
@@ -160,7 +160,7 @@ def test_stmt_constructor():
     assert x.value.value == 1
 
     buffer_var = tvm.tir.Var("buf", 
tvm.ir.PointerType(tvm.ir.PrimType("float32")))
-    x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, 
"uint1"), nop)
+    x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, 
"bool"), nop)
     assert isinstance(x, tvm.tir.Allocate)
     assert x.dtype == "float32"
     assert x.buffer_var == buffer_var
@@ -168,7 +168,7 @@ def test_stmt_constructor():
 
     storage_scope = "global.texture"
     buffer_var = tvm.tir.Var("buf", 
tvm.ir.PointerType(tvm.ir.PrimType("float32"), storage_scope))
-    x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, 
"uint1"), nop)
+    x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, 
"bool"), nop)
     assert isinstance(x, tvm.tir.Allocate)
     assert x.dtype == "float32"
     assert x.buffer_var == buffer_var
@@ -181,7 +181,7 @@ def test_stmt_constructor():
     assert x.attr_key == "xyz"
     assert x.body == nop
 
-    x = tvm.tir.IfThenElse(tvm.tir.const(1, "uint1"), tvm.tir.Evaluate(11), 
nop)
+    x = tvm.tir.IfThenElse(tvm.tir.const(1, "bool"), tvm.tir.Evaluate(11), nop)
     assert isinstance(x, tvm.tir.IfThenElse)
     assert x.then_case.value.value == 11
     assert x.else_case == nop
diff --git a/tests/python/tir-base/test_tir_nodes.py 
b/tests/python/tir-base/test_tir_nodes.py
index 5e1d25e48b..bc7cfeae17 100644
--- a/tests/python/tir-base/test_tir_nodes.py
+++ b/tests/python/tir-base/test_tir_nodes.py
@@ -302,7 +302,7 @@ def test_isnan():
     z = te.var("z", "int32")
     assert str(tvm.tir.isnan(z)) == "T.bool(False)"
     k = te.var("k", "int8x2")
-    assert str(tvm.tir.isnan(k).dtype) == "uint1x2"
+    assert str(tvm.tir.isnan(k).dtype) == "boolx2"
 
 
 def test_equality():
diff --git a/tests/python/tir-base/test_tir_ops.py 
b/tests/python/tir-base/test_tir_ops.py
index dfa5cbab80..cb7d8c597a 100644
--- a/tests/python/tir-base/test_tir_ops.py
+++ b/tests/python/tir-base/test_tir_ops.py
@@ -69,8 +69,8 @@ def test_const_fold3():
     x = te.var("x")
     for val in [0, 1]:
         for func in [tvm.tir.all, tvm.tir.any]:
-            check_throws(lambda: func(tvm.tir.const(val, "uint1"), x))
-            check_throws(lambda: func(x, tvm.tir.const(val, "uint1")))
+            check_throws(lambda: func(tvm.tir.const(val, "bool"), x))
+            check_throws(lambda: func(x, tvm.tir.const(val, "bool")))
 
     # Test const folding when both arguments are const
     for tvm_func, py_func in [
@@ -80,13 +80,13 @@ def test_const_fold3():
         for v1 in [0, 1]:
             for v2 in [0, 1]:
                 tvm.ir.assert_structural_equal(
-                    tvm_func(tvm.tir.const(v1, "uint1"), tvm.tir.const(v2, 
"uint1")),
-                    tvm.tir.const(py_func(v1, v2), "uint1"),
+                    tvm_func(tvm.tir.const(v1, "bool"), tvm.tir.const(v2, 
"bool")),
+                    tvm.tir.const(py_func(v1, v2), "bool"),
                 )
 
-    x = te.var("x", "uint1")
-    true = tvm.tir.const(1, "uint1")
-    false = tvm.tir.const(0, "uint1")
+    x = te.var("x", "bool")
+    true = tvm.tir.const(1, "bool")
+    false = tvm.tir.const(0, "bool")
 
     assert tvm.tir.all(x, true).same_as(x)
     assert tvm.tir.all(true, x).same_as(x)
diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py 
b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
index db6f4ba47f..8352b11644 100644
--- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
@@ -366,7 +366,7 @@ def test_ir_builder_tir_allocate():
     # the expected allocate
     buffer_var = tir.Var("v", tvm.ir.PointerType(tvm.ir.PrimType("float32"), 
"local"))
     ir_expected = tir.Allocate(
-        buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), tir.Evaluate(1)
+        buffer_var, "float32", [10], tvm.tir.const(1, "bool"), tir.Evaluate(1)
     )
 
     # Check if the generated ir is expected
diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py 
b/tests/python/tvmscript/test_tvmscript_printer_tir.py
index fc7deacd98..e4af158074 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py
@@ -961,13 +961,13 @@ def test_predicated_buffer_load_store():
     buffer_load = tir.BufferLoad(
         buffer=buffer_map[b],
         indices=[0, tir.Ramp(0, 4, 4)],
-        predicate=tir.Broadcast(tir.IntImm("uint1", 0), 4),
+        predicate=tir.Broadcast(tir.IntImm("bool", 0), 4),
     )
     body = tir.BufferStore(
         buffer=buffer_map[a],
         value=buffer_load,
         indices=[0, tir.Ramp(0, 2, 4)],
-        predicate=tir.Broadcast(tir.IntImm("uint1", 0), 4),
+        predicate=tir.Broadcast(tir.IntImm("bool", 0), 4),
     )
     func = tir.PrimFunc(
         params=[a, b],


Reply via email to