This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4b93f2060f [Fix][Relax] Lower bool prod as logical all (#19557)
4b93f2060f is described below
commit 4b93f2060f1c11786ed3e257bbccc617b8a20b3b
Author: ConvolutedDog <[email protected]>
AuthorDate: Fri May 15 01:47:21 2026 +0800
[Fix][Relax] Lower bool prod as logical all (#19557)
This PR fixes https://github.com/apache/tvm/issues/19551.
A bool product has logical-AND semantics, and LLVM codegen cannot lower it
through TIR Mul. Route bool prod through all(), and add frontend and
legalization test coverage for bool R.prod.
---
src/tirx/op/op.cc | 18 ++++++++----
.../relax/test_frontend_from_exported_program.py | 18 +++++++-----
tests/python/relax/test_frontend_from_fx.py | 18 +++++++-----
tests/python/relax/test_frontend_onnx.py | 2 +-
...st_transform_legalize_ops_search_statistical.py | 33 ++++++++++++++++++++++
5 files changed, 69 insertions(+), 20 deletions(-)
diff --git a/src/tirx/op/op.cc b/src/tirx/op/op.cc
index 91539c9e7c..59b2c750d3 100644
--- a/src/tirx/op/op.cc
+++ b/src/tirx/op/op.cc
@@ -994,11 +994,19 @@ PrimExpr min(PrimExpr source, ffi::Array<IterVar> rdom,
ffi::Array<PrimExpr> ini
}
PrimExpr prod(PrimExpr source, ffi::Array<IterVar> rdom, ffi::Array<PrimExpr>
init, Span span) {
- Var x("x", source.dtype(), span), y("y", source.dtype(), span);
- PrimExpr result = tirx::Mul(x, y, span);
- PrimExpr identity_element = make_const(source.dtype(), 1, span);
- tirx::CommReducer combiner = tirx::CommReducer({x}, {y}, {result},
{identity_element}, span);
- return tirx::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(),
true), 0, init, span);
+ if (source.dtype().is_bool()) {
+ // Bool product (prod) has the same truth table as logical AND. Reuse
all() to
+ // avoid lowering bool prod through Mul, which LLVM codegen does not
support.
+ return all(source, rdom, init, span);
+ } else {
+ // For non-bool types, we lower prod through Mul.
+ Var x("x", source.dtype(), span), y("y", source.dtype(), span);
+ PrimExpr result = tirx::Mul(x, y, span);
+ PrimExpr identity_element = make_const(source.dtype(), 1, span);
+ tirx::CommReducer combiner = tirx::CommReducer({x}, {y}, {result},
{identity_element}, span);
+ return tirx::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(),
true), 0, init,
+ span);
+ }
}
// fmod
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 5d032ba5c7..1f3848ff64 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7724,24 +7724,28 @@ def test_var_correction():
verify_model(VarCorrection0(), example_args, {}, Expected0)
-def test_prod():
+@pytest.mark.parametrize(
+ "torch_dtype,relax_dtype",
+ [(torch.float32, "float32"), (torch.bool, "bool")],
+)
+def test_prod(torch_dtype, relax_dtype):
class Prod(Module):
def forward(self, x):
- return torch.prod(x)
+ return torch.prod(x, dtype=torch_dtype)
@tvm.script.ir_module
class Expected:
@R.function
def main(
- x: R.Tensor((5, 3), dtype="float32"),
- ) -> R.Tuple(R.Tensor((), dtype="float32")):
+ x: R.Tensor((5, 3), dtype=relax_dtype),
+ ) -> R.Tuple(R.Tensor((), dtype=relax_dtype)):
with R.dataflow():
- lv: R.Tensor((), dtype="float32") = R.prod(x, axis=None,
keepdims=False)
- gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+ lv: R.Tensor((), dtype=relax_dtype) = R.prod(x, axis=None,
keepdims=False)
+ gv: R.Tuple(R.Tensor((), dtype=relax_dtype)) = (lv,)
R.output(gv)
return gv
- example_args = (torch.randn(5, 3, dtype=torch.float32),)
+ example_args = (torch.ones(5, 3, dtype=torch_dtype),)
verify_model(Prod(), example_args, {}, Expected)
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index b2fe59b507..410875985e 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -6231,24 +6231,28 @@ def test_var():
verify_model(Var(), [([5, 3], "float32")], {}, Expected)
-def test_prod():
+@pytest.mark.parametrize(
+ "torch_dtype,relax_dtype",
+ [(torch.float32, "float32"), (torch.bool, "bool")],
+)
+def test_prod(torch_dtype, relax_dtype):
class Prod(Module):
def forward(self, x):
- return torch.prod(x)
+ return torch.prod(x, dtype=torch_dtype)
@tvm.script.ir_module
class Expected:
@R.function
def main(
- inp_0: R.Tensor((5, 3), dtype="float32"),
- ) -> R.Tensor((), dtype="float32"):
+ inp_0: R.Tensor((5, 3), dtype=relax_dtype),
+ ) -> R.Tensor((), dtype=relax_dtype):
with R.dataflow():
- lv: R.Tensor((), dtype="float32") = R.prod(inp_0, axis=None,
keepdims=False)
- gv: R.Tensor((), dtype="float32") = lv
+ lv: R.Tensor((), dtype=relax_dtype) = R.prod(inp_0, axis=None,
keepdims=False)
+ gv: R.Tensor((), dtype=relax_dtype) = lv
R.output(gv)
return gv
- verify_model(Prod(), [([5, 3], "float32")], {}, Expected)
+ verify_model(Prod(), [([5, 3], relax_dtype)], {}, Expected)
def test_cumprod():
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index 0d1d9f2d7c..151ec35e89 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -4924,7 +4924,7 @@ def
test_nms_max_output_boxes_per_class_zero(with_explicit_max: bool):
check_correctness(model, inputs=inputs, opset=11)
tvm_out = run_in_tvm(model, inputs=inputs, opset=11)
- tvm_selected = tvm_out[0].numpy() if isinstance(tvm_out, (list, tuple))
else tvm_out.numpy()
+ tvm_selected = tvm_out[0].numpy() if isinstance(tvm_out, list | tuple)
else tvm_out.numpy()
assert tvm_selected.shape == (0, 3)
diff --git
a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
index 304227d30d..c607a784f5 100644
--- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
+++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
@@ -557,6 +557,39 @@ def test_prod():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_prod_bool():
+ # fmt: off
+ @tvm.script.ir_module
+ class Prod:
+ @R.function
+ def main(x: R.Tensor((2, 3, 4, 5), "bool")) -> R.Tensor((1, 1, 1, 1),
"bool"):
+ gv: R.Tensor((1, 1, 1, 1), "bool") = R.prod(x, keepdims=True)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 3, 4, 5), "bool")) -> R.Tensor((1, 1, 1, 1),
"bool"):
+ gv = R.call_tir(Expected.prod, (x,), R.Tensor((1, 1, 1, 1),
dtype="bool"))
+ return gv
+
+ @T.prim_func(private=True)
+ def prod(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)), "bool"), rxplaceholder_red: T.Buffer((T.int64(1), T.int64(1),
T.int64(1), T.int64(1)), "bool")):
+ T.func_attr({"tirx.noalias": True})
+ for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(T.int64(1),
T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(3), T.int64(4),
T.int64(5)):
+ with T.sblock("rxplaceholder_red"):
+ ax0, ax1, ax2, ax3, k0, k1, k2, k3 =
T.axis.remap("SSSSRRRR", [i0, i1, i2, i3, i4, i5, i6, i7])
+ T.reads(rxplaceholder[k0, k1, k2, k3])
+ T.writes(rxplaceholder_red[ax0, ax1, ax2, ax3])
+ with T.init():
+ rxplaceholder_red[ax0, ax1, ax2, ax3] = T.bool(1)
+ rxplaceholder_red[ax0, ax1, ax2, ax3] =
rxplaceholder_red[ax0, ax1, ax2, ax3] and rxplaceholder[k0, k1, k2, k3]
+ # fmt: on
+
+ mod = LegalizeOps()(Prod)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_prod_symbolic():
# fmt: off
@tvm.script.ir_module