This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 82b471cee4 [TIR] Do not drop 4th argument to tir.max (#15763)
82b471cee4 is described below

commit 82b471cee4f28b691e921373694bfcb0257ec3e1
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun Sep 17 15:10:08 2023 -0500

    [TIR] Do not drop 4th argument to tir.max (#15763)
    
    The `tir.op.comm_reducer` utility provides two distinct APIs, either
    reducing along a tensor axis or reducing along a list of arguments.
    
    Prior to this commit, when reducing along a list of arguments, the 4th
    argument was silently dropped.  For example,
    `tvm.tir.max(1,2,3,4,3,2,1)` would return `3`.
---
 python/tvm/tir/op.py                  |  8 +++++++-
 tests/python/unittest/test_tir_ops.py | 28 ++++++++++++++++++++++------
 2 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 30e2a29487..905d14296d 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -3150,10 +3150,16 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
         if isinstance(axis, (tvm.tir.IterVar, list, tuple)):
             assert not args
             return _make_reduce(expr, axis, where, init)
+
         if where is None:
             assert not args
+            assert init is None
             return _reduce_directly(expr, axis)
-        return _reduce_directly(expr, axis, where, *args)
+        elif init is None:
+            assert not args
+            return _reduce_directly(expr, axis, where)
+        else:
+            return _reduce_directly(expr, axis, where, init, *args)
 
     doc_str = """Create a {0} expression over axis.
 
diff --git a/tests/python/unittest/test_tir_ops.py b/tests/python/unittest/test_tir_ops.py
index 9725650ead..21981d1f0b 100644
--- a/tests/python/unittest/test_tir_ops.py
+++ b/tests/python/unittest/test_tir_ops.py
@@ -15,8 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
+import tvm.testing
 from tvm import te
 
+import pytest
+
 
 def check_throws(f):
     try:
@@ -213,10 +216,23 @@ def test_if_then_else():
             raise ValueError("Unknown combinations")
 
 
+@pytest.mark.parametrize("num_args", list(range(2, 10)))
+def test_comm_reducer(num_args):
+    """Handle all arguments in tir comm_reducer
+
+    The `tir.comm_reducer` API has two distinct usages.  It can reduce
+    a tensor along a specified axis, similar to numpy.max, or it can
+    reduce several arguments together, similar to Python's built-in
+    max().  This choice is based on the type of the second argument.
+
+    If the `tir.comm_reducer` is reducing all arguments, then all
+    arguments should be used.  In the past, the introduction of new
+    arguments intended for use when reducing along a tensor axis has
+    failed to forward these arguments when reducing along a list of
+    items.
+    """
+    assert tvm.tir.max(*range(num_args)) == num_args - 1
+
+
 if __name__ == "__main__":
-    test_const_fold()
-    test_const_fold2()
-    test_const_fold3()
-    test_const_fold4()
-    test_binary_dtype_match()
-    test_if_then_else()
+    tvm.testing.main()

Reply via email to