This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e0678699f9 [Fix] CommReduce could handle 0-dim data (#19683)
e0678699f9 is described below

commit e0678699f9eb525e6dfb8acf817aa4ddcae1a080
Author: flashmouse <[email protected]>
AuthorDate: Wed Jun 10 02:41:41 2026 +0800

    [Fix] CommReduce could handle 0-dim data (#19683)
    
    This PR tries to fix #19676 by allowing ``CommReduce`` to handle 0-dim
    data correctly.
    
    ---------
    
    Co-authored-by: flashmouse <[email protected]>
---
 include/tvm/topi/reduction.h                       |  5 +++-
 ...st_transform_legalize_ops_search_statistical.py | 32 ++++++++++++++++++++++
 2 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h
index e3f5444efe..d8889000f2 100644
--- a/include/tvm/topi/reduction.h
+++ b/include/tvm/topi/reduction.h
@@ -184,7 +184,10 @@ inline Tensor DoCommReduce(const Tensor& data, FReduce func,
 inline Tensor CommReduce(const Tensor& data, const ffi::Optional<ffi::Array<int64_t>>& axis,
                          FReduce func, bool keepdims, bool atleast1d) {
   auto ndim = data->shape.size();
-  TVM_FFI_ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
+  if (ndim == 0) {
+    auto identity = topi::identity(data, data->op->name + "_red", kCommReduce);
+    return atleast1d ? topi::expand_dims(identity, 0, 1) : identity;
+  }
   auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
   auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d);
   return DoCommReduce(data, func, target_shape, real_axis,
diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
index 1a0b71690d..82c478bd51 100644
--- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
+++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
@@ -1179,5 +1179,37 @@ def test_variance_no_keepdims():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_max_zero_dim():
+    # Reducing a 0-D (scalar) tensor is the identity; it must legalize, not crash.
+    # Regression test for https://github.com/apache/tvm/issues/19676
+    # fmt: off
+    @tvm.script.ir_module
+    class Max:
+        @R.function
+        def main(x: R.Tensor((), "float32")) -> R.Tensor((), "float32"):
+            gv: R.Tensor((), "float32") = R.max(x)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((), dtype="float32")) -> R.Tensor((), dtype="float32"):
+            gv = R.call_tir(Expected.max, (x,), R.Tensor((), dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True, s_tir=True)
+        def max(x: T.Buffer((), "float32"), x_red: T.Buffer((), "float32")):
+            T.func_attr({"tirx.noalias": True})
+            with T.sblock("x_red"):
+                vi = T.axis.spatial(1, T.int64(0))
+                T.reads(x[()])
+                T.writes(x_red[()])
+                x_red[()] = x[()]
+    # fmt: on
+
+    mod = LegalizeOps()(Max)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to