This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 82b471cee4 [TIR] Do not drop 4th argument to tir.max (#15763)
82b471cee4 is described below
commit 82b471cee4f28b691e921373694bfcb0257ec3e1
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun Sep 17 15:10:08 2023 -0500
[TIR] Do not drop 4th argument to tir.max (#15763)
The `tir.op.comm_reducer` utility provides two distinct APIs, either
reducing along a tensor axis or reducing along a list of arguments.
Prior to this commit, when reducing along a list of arguments, the 4th
argument was silently dropped. For example,
`tvm.tir.max(1,2,3,4,3,2,1)` would return `3`.
---
python/tvm/tir/op.py | 8 +++++++-
tests/python/unittest/test_tir_ops.py | 28 ++++++++++++++++++++++------
2 files changed, 29 insertions(+), 7 deletions(-)
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 30e2a29487..905d14296d 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -3150,10 +3150,16 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
if isinstance(axis, (tvm.tir.IterVar, list, tuple)):
assert not args
return _make_reduce(expr, axis, where, init)
+
if where is None:
assert not args
+ assert init is None
return _reduce_directly(expr, axis)
- return _reduce_directly(expr, axis, where, *args)
+ elif init is None:
+ assert not args
+ return _reduce_directly(expr, axis, where)
+ else:
+ return _reduce_directly(expr, axis, where, init, *args)
doc_str = """Create a {0} expression over axis.
diff --git a/tests/python/unittest/test_tir_ops.py b/tests/python/unittest/test_tir_ops.py
index 9725650ead..21981d1f0b 100644
--- a/tests/python/unittest/test_tir_ops.py
+++ b/tests/python/unittest/test_tir_ops.py
@@ -15,8 +15,11 @@
# specific language governing permissions and limitations
# under the License.
import tvm
+import tvm.testing
from tvm import te
+import pytest
+
def check_throws(f):
try:
@@ -213,10 +216,23 @@ def test_if_then_else():
raise ValueError("Unknown combinations")
+@pytest.mark.parametrize("num_args", list(range(2, 10)))
+def test_comm_reducer(num_args):
+ """Handle all arguments in tir comm_reducer
+
+ The `tir.comm_reducer` API has two distinct usages. It can reduce
+ a tensor along a specified axis, similar to numpy.max, or it can
+ reduce several arguments together, simililar to Python's built-in
+ max(). This choice is based on the type of the second argument.
+
+ If the `tir.comm_reducer` is reducing all arguments, then all
+ arguments should be used. In the past, the introduction of new
+ arguments intended for use when reducing along a tensor axis has
+ failed to forward these arguments when reducing along a list of
+ items.
+ """
+ assert tvm.tir.max(*range(num_args)) == num_args - 1
+
+
if __name__ == "__main__":
- test_const_fold()
- test_const_fold2()
- test_const_fold3()
- test_const_fold4()
- test_binary_dtype_match()
- test_if_then_else()
+ tvm.testing.main()