This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 38b85c9f61 [Fix][dlight] add an explicit reduction loop check in Reduce (#17711)
38b85c9f61 is described below
commit 38b85c9f612e77e2c0e25eb3872b89fe254b12a8
Author: PatrikPerssonInceptron <[email protected]>
AuthorDate: Sun Mar 9 03:37:09 2025 +0100
[Fix][dlight] add an explicit reduction loop check in Reduce (#17711)
* added an explicit check to verify that the block has a reduction loop
since this is assumed in later stages
* added unit test to verify that the Reduction schedule is not applied to
prim funcs without a reduction loop
---
python/tvm/dlight/gpu/reduction.py | 7 ++++++-
tests/python/dlight/test_gpu_reduction.py | 26 ++++++++++++++++++++++++++
2 files changed, 32 insertions(+), 1 deletion(-)
diff --git a/python/tvm/dlight/gpu/reduction.py b/python/tvm/dlight/gpu/reduction.py
index 9851bb9800..4faaa1cab9 100644
--- a/python/tvm/dlight/gpu/reduction.py
+++ b/python/tvm/dlight/gpu/reduction.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""A rule for reduction. """
+"""A rule for reduction."""
# TODO: combine reduction rule and general reduction rule into one file.
from typing import List, Mapping, Optional, Tuple, Union
@@ -47,6 +47,10 @@ def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
return buffer_store.value.b
+def _has_reduction_loop(block_info):
+ return any([info.kind == "R" for info in block_info.iters])
+
+
class Reduction(GPUScheduleRule):
"""A rule for Reduction."""
@@ -79,6 +83,7 @@ class Reduction(GPUScheduleRule):
# Step 1. Check reduction block
if (
(not block_info.is_reduction())
+ or (not _has_reduction_loop(block_info))
or len(block_stmt.writes) != 1
or _get_reduction_expr(block_stmt) is None
):
diff --git a/tests/python/dlight/test_gpu_reduction.py b/tests/python/dlight/test_gpu_reduction.py
index 1ce57eb53d..0a74df70c0 100644
--- a/tests/python/dlight/test_gpu_reduction.py
+++ b/tests/python/dlight/test_gpu_reduction.py
@@ -1152,5 +1152,31 @@ def test_gemv_output_one_element():
assert_structural_equal(mod, Expected)
+def test_no_reduction_loop_check():
+ # The normalized prime func will not contain a reduction loop since its extent is one.
+ # This checks that the Reduction schedule is correctly not applied in this case
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @T.prim_func(private=True)
+ def matmul(lv43: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16"), lv44: T.Buffer((T.int64(1), T.int64(1), T.int64(1)), "float16"), matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16")):
+ T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(1)):
+ with T.block("matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+ T.reads(lv43[v_i0, v_i1, v_k], lv44[v_i0, v_k, v_i2])
+ T.writes(matmul[v_i0, v_i1, v_i2])
+ with T.init():
+ matmul[v_i0, v_i1, v_i2] = T.float16(0.0)
+ matmul[v_i0, v_i1, v_i2] = matmul[v_i0, v_i1, v_i2] + lv43[v_i0, v_i1, v_k] * lv44[v_i0, v_k, v_i2]
+ # fmt: on
+
+ target = Target("nvidia/geforce-rtx-3090-ti")
+ with target:
+ mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Before) # pylint: disable=not-callable
+ assert_structural_equal(mod, Before)
+
+
if __name__ == "__main__":
tvm.testing.main()