This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch tvm-ffi-bool
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/tvm-ffi-bool by this push:
new 11f324d9ea Update
11f324d9ea is described below
commit 11f324d9ea19dbe63c9aedacc1eec66f58fd3d28
Author: tqchen <[email protected]>
AuthorDate: Wed Nov 12 20:04:28 2025 -0500
Update
---
3rdparty/tvm-ffi | 2 +-
include/tvm/runtime/data_type.h | 2 ++
python/tvm/tir/ir_builder.py | 2 +-
src/runtime/vm/builtin.cc | 2 +-
src/tir/ir/expr.cc | 2 +-
src/tir/ir/stmt.cc | 4 ++--
src/tir/op/op.cc | 17 ++++++++++++++++-
tests/python/tir-base/test_tir_constructor.py | 8 ++++----
tests/python/tir-base/test_tir_ops.py | 14 +++++++-------
tests/python/tvmscript/test_tvmscript_ir_builder_tir.py | 2 +-
tests/python/tvmscript/test_tvmscript_printer_tir.py | 4 ++--
11 files changed, 38 insertions(+), 21 deletions(-)
diff --git a/3rdparty/tvm-ffi b/3rdparty/tvm-ffi
index 60f45ac017..5fcdf8597f 160000
--- a/3rdparty/tvm-ffi
+++ b/3rdparty/tvm-ffi
@@ -1 +1 @@
-Subproject commit 60f45ac017964caf2252b3c74a6e10a4422a1835
+Subproject commit 5fcdf8597f1ecb1a76e2eb0578bded73de91ace0
diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index da355bd7ce..3a91d4777b 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -140,6 +140,8 @@ class DataType {
bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; }
/*! \return whether type is a scalar type. */
bool is_bool() const { return code() == DataType::kBool; }
+ /*! \return whether type can be used in a predicate expression. */
+ bool is_predicate_dtype() const { return is_bool() || (is_uint() && bits()
== 1); }
/*! \return whether type is a float type. */
bool is_float() const { return code() == DataType::kFloat; }
/*! \return whether type is a bfloat type. */
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index d6466b0922..a6313ae3bc 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -448,7 +448,7 @@ class IRBuilder(object):
)
buffer_var = buffer.data
- self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1,
dtype="uint1"), x))
+ self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1,
dtype="bool"), x))
return BufferVar(self, buffer, dtype)
def pointer(self, content_type, name="ptr", scope=""):
diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc
index 13446a158f..1bd3084c21 100644
--- a/src/runtime/vm/builtin.cc
+++ b/src/runtime/vm/builtin.cc
@@ -535,7 +535,7 @@ bool ReadIfCond(ffi::AnyView cond) {
if (arr->device.device_type != kDLCPU) {
arr = arr.CopyTo(DLDevice{kDLCPU, 0});
}
- ICHECK(arr->dtype.code == kDLInt || arr->dtype.code == kDLUInt);
+ ICHECK(arr->dtype.code == kDLInt || arr->dtype.code == kDLUInt ||
arr->dtype.code == kDLBool);
int64_t result;
switch (arr->dtype.bits) {
case 1: {
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 252b8693a7..5eee4ffd8b 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -840,7 +840,7 @@ BufferLoad::BufferLoad(Buffer buffer, ffi::Array<PrimExpr>
indices,
<< " lanes. The number of lanes must match.";
DataType predicate_element_dtype = predicate_dtype.element_of();
- ICHECK(predicate_element_dtype.is_bool())
+ ICHECK(predicate_element_dtype.is_predicate_dtype())
<< "Predicate mask elements must be boolean values, but got " <<
predicate_element_dtype
<< ".";
}
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index d33a01340b..781fb887ff 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -485,7 +485,7 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value,
ffi::Array<PrimExpr> ind
<< " lanes. The number of lanes must match.";
DataType predicate_element_dtype = predicate_dtype.element_of();
- ICHECK(predicate_element_dtype.is_bool())
+ ICHECK(predicate_element_dtype.is_predicate_dtype())
<< "Predicate mask elements must be boolean values, but got " <<
predicate_element_dtype
<< ".";
}
@@ -687,7 +687,7 @@ BlockRealize::BlockRealize(ffi::Array<PrimExpr> values,
PrimExpr predicate, Bloc
Span span) {
CHECK_EQ(block->iter_vars.size(), values.size())
<< "ValueError: BlockRealize needs to have the same number of iter_vars
and binding values";
- CHECK(predicate.dtype().is_bool()) << "TypeError: Expect Block.predicate to
be a bool expression";
+ CHECK(predicate.dtype().is_bool() || predicate.dtype() == DataType::UInt(1))
<< "TypeError: Expect Block.predicate to be a bool expression";
ObjectPtr<BlockRealizeNode> node = ffi::make_object<BlockRealizeNode>();
node->iter_values = std::move(values);
node->predicate = std::move(predicate);
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 935f9928a5..d6d68e5410 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -214,6 +214,12 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span
span) { // NOLINT(*)
} else if (ltype.is_float4() && !rtype.is_float4()) {
// Cast int->float4 for rhs when lhs is a float4
rhs = cast(ltype, rhs);
+ } else if (ltype.is_bool() && (rtype.is_int() || rtype.is_uint())) {
+ // Cast bool to int for lhs when rhs is an int or uint
+ lhs = cast(rtype, lhs);
+ } else if ((ltype.is_int() || ltype.is_uint()) && rtype.is_bool()) {
+ // Cast bool to int for rhs when lhs is an int or uint
+ rhs = cast(ltype, rhs);
} else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() &&
rtype.is_uint())) {
// Promote int to higher bits e.g. int8 + int16 --> int16 + int16
if (ltype.bits() < rtype.bits()) {
@@ -712,6 +718,15 @@ void type_check_integer_args(const PrimExpr& lhs, const
PrimExpr& rhs, const cha
<< "Expected integer argument as RHS of " << op << ", but received " <<
rhs << " of type "
<< rhs.dtype();
}
+
+void type_check_int_or_bool_args(const PrimExpr& lhs, const PrimExpr& rhs,
const char* op) {
+ ICHECK(lhs.dtype().is_int() || lhs.dtype().is_uint() ||
lhs.dtype().is_bool())
+ << "Expected integer argument as LHS of " << op << ", but received " <<
lhs << " of type "
+ << lhs.dtype();
+ ICHECK(rhs.dtype().is_int() || rhs.dtype().is_uint() ||
rhs.dtype().is_bool())
+ << "Expected integer argument as RHS of " << op << ", but received " <<
rhs << " of type "
+ << rhs.dtype();
+}
} // namespace
PrimExpr operator&&(PrimExpr a, PrimExpr b) { return logical_and(a, b); }
@@ -805,7 +820,7 @@ PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) {
// bitwise_xor
PrimExpr operator^(PrimExpr a, PrimExpr b) { return bitwise_xor(a, b); }
PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) {
- type_check_integer_args(a, b, "^ operator (bitwise XOR)");
+ type_check_int_or_bool_args(a, b, "^ operator (bitwise XOR)");
BinaryOpMatchTypes(a, b, span);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
diff --git a/tests/python/tir-base/test_tir_constructor.py
b/tests/python/tir-base/test_tir_constructor.py
index 42c2998e27..fe64efa39b 100644
--- a/tests/python/tir-base/test_tir_constructor.py
+++ b/tests/python/tir-base/test_tir_constructor.py
@@ -140,7 +140,7 @@ def test_stmt_constructor():
assert isinstance(x, tvm.tir.AttrStmt)
assert x.value.value == 1
- x = tvm.tir.AssertStmt(tvm.tir.const(1, "uint1"),
tvm.runtime.convert("hellow"), nop)
+ x = tvm.tir.AssertStmt(tvm.tir.const(1, "bool"),
tvm.runtime.convert("hellow"), nop)
assert isinstance(x, tvm.tir.AssertStmt)
assert x.body == nop
@@ -160,7 +160,7 @@ def test_stmt_constructor():
assert x.value.value == 1
buffer_var = tvm.tir.Var("buf",
tvm.ir.PointerType(tvm.ir.PrimType("float32")))
- x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1,
"uint1"), nop)
+ x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1,
"bool"), nop)
assert isinstance(x, tvm.tir.Allocate)
assert x.dtype == "float32"
assert x.buffer_var == buffer_var
@@ -168,7 +168,7 @@ def test_stmt_constructor():
storage_scope = "global.texture"
buffer_var = tvm.tir.Var("buf",
tvm.ir.PointerType(tvm.ir.PrimType("float32"), storage_scope))
- x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1,
"uint1"), nop)
+ x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1,
"bool"), nop)
assert isinstance(x, tvm.tir.Allocate)
assert x.dtype == "float32"
assert x.buffer_var == buffer_var
@@ -181,7 +181,7 @@ def test_stmt_constructor():
assert x.attr_key == "xyz"
assert x.body == nop
- x = tvm.tir.IfThenElse(tvm.tir.const(1, "uint1"), tvm.tir.Evaluate(11),
nop)
+ x = tvm.tir.IfThenElse(tvm.tir.const(1, "bool"), tvm.tir.Evaluate(11), nop)
assert isinstance(x, tvm.tir.IfThenElse)
assert x.then_case.value.value == 11
assert x.else_case == nop
diff --git a/tests/python/tir-base/test_tir_ops.py
b/tests/python/tir-base/test_tir_ops.py
index dfa5cbab80..cb7d8c597a 100644
--- a/tests/python/tir-base/test_tir_ops.py
+++ b/tests/python/tir-base/test_tir_ops.py
@@ -69,8 +69,8 @@ def test_const_fold3():
x = te.var("x")
for val in [0, 1]:
for func in [tvm.tir.all, tvm.tir.any]:
- check_throws(lambda: func(tvm.tir.const(val, "uint1"), x))
- check_throws(lambda: func(x, tvm.tir.const(val, "uint1")))
+ check_throws(lambda: func(tvm.tir.const(val, "bool"), x))
+ check_throws(lambda: func(x, tvm.tir.const(val, "bool")))
# Test const folding when both arguments are const
for tvm_func, py_func in [
@@ -80,13 +80,13 @@ def test_const_fold3():
for v1 in [0, 1]:
for v2 in [0, 1]:
tvm.ir.assert_structural_equal(
- tvm_func(tvm.tir.const(v1, "uint1"), tvm.tir.const(v2,
"uint1")),
- tvm.tir.const(py_func(v1, v2), "uint1"),
+ tvm_func(tvm.tir.const(v1, "bool"), tvm.tir.const(v2,
"bool")),
+ tvm.tir.const(py_func(v1, v2), "bool"),
)
- x = te.var("x", "uint1")
- true = tvm.tir.const(1, "uint1")
- false = tvm.tir.const(0, "uint1")
+ x = te.var("x", "bool")
+ true = tvm.tir.const(1, "bool")
+ false = tvm.tir.const(0, "bool")
assert tvm.tir.all(x, true).same_as(x)
assert tvm.tir.all(true, x).same_as(x)
diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
index db6f4ba47f..8352b11644 100644
--- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
@@ -366,7 +366,7 @@ def test_ir_builder_tir_allocate():
# the expected allocate
buffer_var = tir.Var("v", tvm.ir.PointerType(tvm.ir.PrimType("float32"),
"local"))
ir_expected = tir.Allocate(
- buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), tir.Evaluate(1)
+ buffer_var, "float32", [10], tvm.tir.const(1, "bool"), tir.Evaluate(1)
)
# Check if the generated ir is expected
diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py
b/tests/python/tvmscript/test_tvmscript_printer_tir.py
index fc7deacd98..e4af158074 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py
@@ -961,13 +961,13 @@ def test_predicated_buffer_load_store():
buffer_load = tir.BufferLoad(
buffer=buffer_map[b],
indices=[0, tir.Ramp(0, 4, 4)],
- predicate=tir.Broadcast(tir.IntImm("uint1", 0), 4),
+ predicate=tir.Broadcast(tir.IntImm("bool", 0), 4),
)
body = tir.BufferStore(
buffer=buffer_map[a],
value=buffer_load,
indices=[0, tir.Ramp(0, 2, 4)],
- predicate=tir.Broadcast(tir.IntImm("uint1", 0), 4),
+ predicate=tir.Broadcast(tir.IntImm("bool", 0), 4),
)
func = tir.PrimFunc(
params=[a, b],