This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4b93f2060f [Fix][Relax] Lower bool prod as logical all (#19557)
4b93f2060f is described below

commit 4b93f2060f1c11786ed3e257bbccc617b8a20b3b
Author: ConvolutedDog <[email protected]>
AuthorDate: Fri May 15 01:47:21 2026 +0800

    [Fix][Relax] Lower bool prod as logical all (#19557)
    
    This PR fixes https://github.com/apache/tvm/issues/19551.
    
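    A product over bool values has logical-AND semantics, and LLVM codegen
    cannot lower it when it is expressed as a TIR Mul. Route bool prod through
    all() instead, and add frontend and legalization test coverage for bool
    R.prod. Also switch an isinstance check in test_frontend_onnx.py to the
    `list | tuple` union syntax.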
---
 src/tirx/op/op.cc                                  | 18 ++++++++----
 .../relax/test_frontend_from_exported_program.py   | 18 +++++++-----
 tests/python/relax/test_frontend_from_fx.py        | 18 +++++++-----
 tests/python/relax/test_frontend_onnx.py           |  2 +-
 ...st_transform_legalize_ops_search_statistical.py | 33 ++++++++++++++++++++++
 5 files changed, 69 insertions(+), 20 deletions(-)

diff --git a/src/tirx/op/op.cc b/src/tirx/op/op.cc
index 91539c9e7c..59b2c750d3 100644
--- a/src/tirx/op/op.cc
+++ b/src/tirx/op/op.cc
@@ -994,11 +994,19 @@ PrimExpr min(PrimExpr source, ffi::Array<IterVar> rdom, 
ffi::Array<PrimExpr> ini
 }
 
 PrimExpr prod(PrimExpr source, ffi::Array<IterVar> rdom, ffi::Array<PrimExpr> 
init, Span span) {
-  Var x("x", source.dtype(), span), y("y", source.dtype(), span);
-  PrimExpr result = tirx::Mul(x, y, span);
-  PrimExpr identity_element = make_const(source.dtype(), 1, span);
-  tirx::CommReducer combiner = tirx::CommReducer({x}, {y}, {result}, 
{identity_element}, span);
-  return tirx::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), 
true), 0, init, span);
+  if (source.dtype().is_bool()) {
+    // Bool product (prod) has the same truth table as logical AND.  Reuse 
all() to
+    // avoid lowering bool prod through Mul, which LLVM codegen does not 
support.
+    return all(source, rdom, init, span);
+  } else {
+    // For non-bool types, we lower prod through Mul.
+    Var x("x", source.dtype(), span), y("y", source.dtype(), span);
+    PrimExpr result = tirx::Mul(x, y, span);
+    PrimExpr identity_element = make_const(source.dtype(), 1, span);
+    tirx::CommReducer combiner = tirx::CommReducer({x}, {y}, {result}, 
{identity_element}, span);
+    return tirx::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), 
true), 0, init,
+                        span);
+  }
 }
 
 // fmod
diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
index 5d032ba5c7..1f3848ff64 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7724,24 +7724,28 @@ def test_var_correction():
     verify_model(VarCorrection0(), example_args, {}, Expected0)
 
 
-def test_prod():
[email protected](
+    "torch_dtype,relax_dtype",
+    [(torch.float32, "float32"), (torch.bool, "bool")],
+)
+def test_prod(torch_dtype, relax_dtype):
     class Prod(Module):
         def forward(self, x):
-            return torch.prod(x)
+            return torch.prod(x, dtype=torch_dtype)
 
     @tvm.script.ir_module
     class Expected:
         @R.function
         def main(
-            x: R.Tensor((5, 3), dtype="float32"),
-        ) -> R.Tuple(R.Tensor((), dtype="float32")):
+            x: R.Tensor((5, 3), dtype=relax_dtype),
+        ) -> R.Tuple(R.Tensor((), dtype=relax_dtype)):
             with R.dataflow():
-                lv: R.Tensor((), dtype="float32") = R.prod(x, axis=None, 
keepdims=False)
-                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+                lv: R.Tensor((), dtype=relax_dtype) = R.prod(x, axis=None, 
keepdims=False)
+                gv: R.Tuple(R.Tensor((), dtype=relax_dtype)) = (lv,)
                 R.output(gv)
             return gv
 
-    example_args = (torch.randn(5, 3, dtype=torch.float32),)
+    example_args = (torch.ones(5, 3, dtype=torch_dtype),)
     verify_model(Prod(), example_args, {}, Expected)
 
 
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
index b2fe59b507..410875985e 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -6231,24 +6231,28 @@ def test_var():
     verify_model(Var(), [([5, 3], "float32")], {}, Expected)
 
 
-def test_prod():
[email protected](
+    "torch_dtype,relax_dtype",
+    [(torch.float32, "float32"), (torch.bool, "bool")],
+)
+def test_prod(torch_dtype, relax_dtype):
     class Prod(Module):
         def forward(self, x):
-            return torch.prod(x)
+            return torch.prod(x, dtype=torch_dtype)
 
     @tvm.script.ir_module
     class Expected:
         @R.function
         def main(
-            inp_0: R.Tensor((5, 3), dtype="float32"),
-        ) -> R.Tensor((), dtype="float32"):
+            inp_0: R.Tensor((5, 3), dtype=relax_dtype),
+        ) -> R.Tensor((), dtype=relax_dtype):
             with R.dataflow():
-                lv: R.Tensor((), dtype="float32") = R.prod(inp_0, axis=None, 
keepdims=False)
-                gv: R.Tensor((), dtype="float32") = lv
+                lv: R.Tensor((), dtype=relax_dtype) = R.prod(inp_0, axis=None, 
keepdims=False)
+                gv: R.Tensor((), dtype=relax_dtype) = lv
                 R.output(gv)
             return gv
 
-    verify_model(Prod(), [([5, 3], "float32")], {}, Expected)
+    verify_model(Prod(), [([5, 3], relax_dtype)], {}, Expected)
 
 
 def test_cumprod():
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index 0d1d9f2d7c..151ec35e89 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -4924,7 +4924,7 @@ def 
test_nms_max_output_boxes_per_class_zero(with_explicit_max: bool):
     check_correctness(model, inputs=inputs, opset=11)
 
     tvm_out = run_in_tvm(model, inputs=inputs, opset=11)
-    tvm_selected = tvm_out[0].numpy() if isinstance(tvm_out, (list, tuple)) 
else tvm_out.numpy()
+    tvm_selected = tvm_out[0].numpy() if isinstance(tvm_out, list | tuple) 
else tvm_out.numpy()
     assert tvm_selected.shape == (0, 3)
 
 
diff --git 
a/tests/python/relax/test_transform_legalize_ops_search_statistical.py 
b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
index 304227d30d..c607a784f5 100644
--- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
+++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
@@ -557,6 +557,39 @@ def test_prod():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_prod_bool():
+    # fmt: off
+    @tvm.script.ir_module
+    class Prod:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4, 5), "bool")) -> R.Tensor((1, 1, 1, 1), 
"bool"):
+            gv: R.Tensor((1, 1, 1, 1), "bool") = R.prod(x, keepdims=True)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4, 5), "bool")) -> R.Tensor((1, 1, 1, 1), 
"bool"):
+            gv = R.call_tir(Expected.prod, (x,), R.Tensor((1, 1, 1, 1), 
dtype="bool"))
+            return gv
+
+        @T.prim_func(private=True)
+        def prod(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)), "bool"), rxplaceholder_red: T.Buffer((T.int64(1), T.int64(1), 
T.int64(1), T.int64(1)), "bool")):
+            T.func_attr({"tirx.noalias": True})
+            for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(T.int64(1), 
T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)):
+                with T.sblock("rxplaceholder_red"):
+                    ax0, ax1, ax2, ax3, k0, k1, k2, k3 = 
T.axis.remap("SSSSRRRR", [i0, i1, i2, i3, i4, i5, i6, i7])
+                    T.reads(rxplaceholder[k0, k1, k2, k3])
+                    T.writes(rxplaceholder_red[ax0, ax1, ax2, ax3])
+                    with T.init():
+                        rxplaceholder_red[ax0, ax1, ax2, ax3] = T.bool(1)
+                    rxplaceholder_red[ax0, ax1, ax2, ax3] = 
rxplaceholder_red[ax0, ax1, ax2, ax3] and rxplaceholder[k0, k1, k2, k3]
+    # fmt: on
+
+    mod = LegalizeOps()(Prod)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_prod_symbolic():
     # fmt: off
     @tvm.script.ir_module

Reply via email to