This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 38b85c9f61 [Fix][dlight] add an explicit reduction loop check in Reduce (#17711)
38b85c9f61 is described below

commit 38b85c9f612e77e2c0e25eb3872b89fe254b12a8
Author: PatrikPerssonInceptron <[email protected]>
AuthorDate: Sun Mar 9 03:37:09 2025 +0100

    [Fix][dlight] add an explicit reduction loop check in Reduce (#17711)
    
    * added an explicit check to verify that the block has a reduction loop, since this is assumed in later stages
    
    * added a unit test to verify that the Reduction schedule is not applied to
    prim funcs without a reduction loop
---
 python/tvm/dlight/gpu/reduction.py        |  7 ++++++-
 tests/python/dlight/test_gpu_reduction.py | 26 ++++++++++++++++++++++++++
 2 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/python/tvm/dlight/gpu/reduction.py b/python/tvm/dlight/gpu/reduction.py
index 9851bb9800..4faaa1cab9 100644
--- a/python/tvm/dlight/gpu/reduction.py
+++ b/python/tvm/dlight/gpu/reduction.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""A rule for reduction. """
+"""A rule for reduction."""
 # TODO: combine reduction rule and general reduction rule into one file.
 from typing import List, Mapping, Optional, Tuple, Union
 
@@ -47,6 +47,10 @@ def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
     return buffer_store.value.b
 
 
+def _has_reduction_loop(block_info):  # True if any block iter has kind "R" (reduction axis)
+    return any(info.kind == "R" for info in block_info.iters)
+
+
 class Reduction(GPUScheduleRule):
     """A rule for Reduction."""
 
@@ -79,6 +83,7 @@ class Reduction(GPUScheduleRule):
         # Step 1. Check reduction block
         if (
             (not block_info.is_reduction())
+            or (not _has_reduction_loop(block_info))
             or len(block_stmt.writes) != 1
             or _get_reduction_expr(block_stmt) is None
         ):
diff --git a/tests/python/dlight/test_gpu_reduction.py b/tests/python/dlight/test_gpu_reduction.py
index 1ce57eb53d..0a74df70c0 100644
--- a/tests/python/dlight/test_gpu_reduction.py
+++ b/tests/python/dlight/test_gpu_reduction.py
@@ -1152,5 +1152,31 @@ def test_gemv_output_one_element():
     assert_structural_equal(mod, Expected)
 
 
+def test_no_reduction_loop_check():
+    # After normalization the prim func has no reduction loop, because the extent of `k` is one.
+    # Check that the Reduction schedule is not applied in this case (the module is unchanged).
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def matmul(lv43: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16"), lv44: T.Buffer((T.int64(1), T.int64(1), T.int64(1)), "float16"), matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16")):
+            T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(1)):
+                with T.block("matmul"):
+                    v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                    T.reads(lv43[v_i0, v_i1, v_k], lv44[v_i0, v_k, v_i2])
+                    T.writes(matmul[v_i0, v_i1, v_i2])
+                    with T.init():
+                        matmul[v_i0, v_i1, v_i2] = T.float16(0.0)
+                    matmul[v_i0, v_i1, v_i2] = matmul[v_i0, v_i1, v_i2] + lv43[v_i0, v_i1, v_k] * lv44[v_i0, v_k, v_i2]
+    # fmt: on
+
+    target = Target("nvidia/geforce-rtx-3090-ti")
+    with target:
+        mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Before)  # pylint: disable=not-callable
+    assert_structural_equal(mod, Before)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to