This is an automated email from the ASF dual-hosted git repository.

hongyij pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dafd053be2 [Dlight] Fix general reduction rule to support non-last 
reduction axis (#17754)
dafd053be2 is described below

commit dafd053be246718c5fb15b049e309688959165ab
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Mar 17 17:12:10 2025 -0400

    [Dlight] Fix general reduction rule to support non-last reduction axis 
(#17754)
    
    This PR fixes a bug in the general reduction dlight rule, which happens
    when there is a trailing spatial block, and for the previous reduction
    blocks, the reduction axes are not on the back.
    
    In the case above, the loop orders of the reduction blocks and the
    trailing spatial block are inconsistent, while the dlight rule before
    this fix always treats the loop orders as consistent.
    
    As a result, though the function after applying the rule is numerically
    correct, it may require much extra shared memory use (in proportion to
    the size of spatial loops). And when the spatial dimensions are large,
    the required shared memory size may exceed the device limit.
    
    This PR fixes this bug and adds a unit test.
---
 python/tvm/dlight/gpu/general_reduction.py        | 40 ++++++++++
 tests/python/dlight/test_gpu_general_reduction.py | 91 +++++++++++++++++++++++
 2 files changed, 131 insertions(+)

diff --git a/python/tvm/dlight/gpu/general_reduction.py 
b/python/tvm/dlight/gpu/general_reduction.py
index a068e732b9..d3979ce0e4 100644
--- a/python/tvm/dlight/gpu/general_reduction.py
+++ b/python/tvm/dlight/gpu/general_reduction.py
@@ -99,6 +99,46 @@ class GeneralReduction(GPUScheduleRule):
         except AssertionError:
             return None
 
+        if "R" not in block_infos[-1].dom_kind():
+            # The final block is a spatial block.
+            # It is possible that the loop order of the last block is not the 
same as
+            # previous blocks.
+            # Thus we reorder spatial loops to align with reduction loops for 
followup schedule.
+            # We first collect all the buffers written by reduction blocks,
+            # then in the final block, any index of those buffers are spatial.
+            reduced_buffers = []
+            for block_info in block_infos[:-1]:
+                for buffer_write in sch.get(block_info.block_rv).writes:
+                    reduced_buffers.append(buffer_write.buffer)
+
+            spatial_block = sch.get(block_infos[-1].block_rv)
+            spatial_loops = set()
+            block_var_to_loop_var = {}
+            loops = sch.get_loops(block_infos[-1].block_rv)
+            for block_iter, loop_rv in zip(spatial_block.iter_vars, loops):
+                block_var_to_loop_var[block_iter.var] = 
sch.get(loop_rv).loop_var
+
+            def _visit_expr(e: tir.PrimExpr):
+                if isinstance(e, tir.Var) and e in block_var_to_loop_var:
+                    spatial_loops.add(block_var_to_loop_var[e])
+
+            for buffer_read in spatial_block.reads:
+                buffer = buffer_read.buffer
+                if buffer in reduced_buffers:
+                    for read_range in buffer_read.region:
+                        tir.stmt_functor.post_order_visit(read_range.min, 
_visit_expr)
+                        tir.stmt_functor.post_order_visit(read_range.extent, 
_visit_expr)
+
+            s_loops = []
+            other_loops = []
+            for loop_rv in loops:
+                loop = sch.get(loop_rv)
+                if loop.loop_var in spatial_loops or loop.extent == 1:
+                    s_loops.append(loop_rv)
+                else:
+                    other_loops.append(loop_rv)
+            sch.reorder(*s_loops, *other_loops)
+
         loops = sch.get_loops(block_infos[-1].block_rv)
         bx = sch.fuse(*loops[:num_leading_s])
         r_loop, tx = sch.split(loops[-1], [None, len_tx])
diff --git a/tests/python/dlight/test_gpu_general_reduction.py 
b/tests/python/dlight/test_gpu_general_reduction.py
index e1a9a8e018..9549441a11 100644
--- a/tests/python/dlight/test_gpu_general_reduction.py
+++ b/tests/python/dlight/test_gpu_general_reduction.py
@@ -222,6 +222,97 @@ def test_softmax_2():
     _check(Before, After)
 
 
+def test_softmax_3():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(input: T.Buffer((T.int64(1), T.int64(4), T.int64(32), 
T.int64(8192)), "float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(4), 
T.int64(32), T.int64(8192)), "float32")):
+            # with T.block("root"):
+            T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(4), 
T.int64(8192)))
+            T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(4), 
T.int64(32), T.int64(8192)))
+            T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(4), 
T.int64(8192)))
+            for i0, i1, i2, k in T.grid(T.int64(1), T.int64(4), T.int64(8192), 
T.int64(32)):
+                with T.block("T_softmax_maxelem"):
+                    v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, 
k])
+                    T.reads(input[v_i0, v_i1, v_k, v_i2])
+                    T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2])
+                    with T.init():
+                        T_softmax_maxelem[v_i0, v_i1, v_i2] = 
T.float32(-340282346638528859811704183484516925440.0)
+                    T_softmax_maxelem[v_i0, v_i1, v_i2] = 
T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], input[v_i0, v_i1, v_k, v_i2])
+            for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(4), T.int64(32), 
T.int64(8192)):
+                with T.block("T_softmax_exp"):
+                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, 
i3])
+                    T.reads(input[v_i0, v_i1, v_i2, v_i3], 
T_softmax_maxelem[v_i0, v_i1, v_i3])
+                    T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3])
+                    T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(input[v_i0, 
v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i3])
+            for i0, i1, i2, k in T.grid(T.int64(1), T.int64(4), T.int64(8192), 
T.int64(32)):
+                with T.block("T_softmax_expsum"):
+                    v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, 
k])
+                    T.reads(T_softmax_exp[v_i0, v_i1, v_k, v_i2])
+                    T.writes(T_softmax_expsum[v_i0, v_i1, v_i2])
+                    with T.init():
+                        T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0.0)
+                    T_softmax_expsum[v_i0, v_i1, v_i2] = 
T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_k, v_i2]
+            for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(4), T.int64(32), 
T.int64(8192)):
+                with T.block("T_softmax_norm"):
+                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, 
i3])
+                    T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], 
T_softmax_expsum[v_i0, v_i1, v_i3])
+                    T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3])
+                    T.block_attr({"axis": 2})
+                    T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = 
T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i3]
+
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main(input: T.Buffer((T.int64(1), T.int64(4), T.int64(32), 
T.int64(8192)), "float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(4), 
T.int64(32), T.int64(8192)), "float32")):
+            T.func_attr({"tir.is_scheduled": 1})
+            # with T.block("root"):
+            T_softmax_maxelem_shared = T.alloc_buffer((T.int64(1), T.int64(4), 
T.int64(8192)), scope="shared")
+            T_softmax_expsum_shared = T.alloc_buffer((T.int64(1), T.int64(4), 
T.int64(8192)), scope="shared")
+            for ax0_ax2_fused in T.thread_binding(T.int64(32768), 
thread="blockIdx.x"):
+                for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
+                    for ax2_fused_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
+                        for ax2_fused_0 in T.serial(T.int64(1), 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                            with T.block("T_softmax_maxelem"):
+                                v0 = T.axis.spatial(T.int64(4), ax0_ax2_fused 
// T.int64(8192) + ax0)
+                                v1 = T.axis.spatial(T.int64(8192), 
ax0_ax2_fused % T.int64(8192) + ax1)
+                                v2 = T.axis.reduce(T.int64(32), ax2_fused_0 * 
T.int64(256) + ax2_fused_1)
+                                T.where(ax2_fused_0 * T.int64(256) + 
ax2_fused_1 < T.int64(32))
+                                T.reads(input[T.int64(0), v0, v2, v1])
+                                T.writes(T_softmax_maxelem_shared[T.int64(0), 
v0, v1])
+                                with T.init():
+                                    T_softmax_maxelem_shared[T.int64(0), v0, 
v1] = T.float32(-340282346638528859811704183484516925440.0)
+                                T_softmax_maxelem_shared[T.int64(0), v0, v1] = 
T.max(T_softmax_maxelem_shared[T.int64(0), v0, v1], input[T.int64(0), v0, v2, 
v1])
+                for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
+                    for ax2_fused_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
+                        for ax2_fused_0 in T.serial(T.int64(1), 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                            with T.block("T_softmax_expsum"):
+                                v0 = T.axis.spatial(T.int64(4), ax0_ax2_fused 
// T.int64(8192) + ax0)
+                                v1 = T.axis.spatial(T.int64(8192), 
ax0_ax2_fused % T.int64(8192) + ax1)
+                                v2 = T.axis.reduce(T.int64(32), ax2_fused_0 * 
T.int64(256) + ax2_fused_1)
+                                T.where(ax2_fused_0 * T.int64(256) + 
ax2_fused_1 < T.int64(32))
+                                T.reads(input[T.int64(0), v0, v2, v1], 
T_softmax_maxelem_shared[T.int64(0), v0, v1])
+                                T.writes(T_softmax_expsum_shared[T.int64(0), 
v0, v1])
+                                with T.init():
+                                    T_softmax_expsum_shared[T.int64(0), v0, 
v1] = T.float32(0.0)
+                                T_softmax_expsum_shared[T.int64(0), v0, v1] = 
T_softmax_expsum_shared[T.int64(0), v0, v1] + T.exp(input[T.int64(0), v0, v2, 
v1] - T_softmax_maxelem_shared[T.int64(0), v0, v1])
+                for ax1_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
+                    for ax1_0 in T.serial(T.int64(1), 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                        with T.block("T_softmax_norm"):
+                            v0 = T.axis.spatial(T.int64(4), ax0_ax2_fused // 
T.int64(8192))
+                            v1 = T.axis.spatial(T.int64(32), ax1_0 * 
T.int64(256) + ax1_1)
+                            v2 = T.axis.spatial(T.int64(8192), ax0_ax2_fused % 
T.int64(8192))
+                            T.where(ax1_0 * T.int64(256) + ax1_1 < T.int64(32))
+                            T.reads(input[T.int64(0), v0, v1, v2], 
T_softmax_maxelem_shared[T.int64(0), v0, v2], 
T_softmax_expsum_shared[T.int64(0), v0, v2])
+                            T.writes(T_softmax_norm[T.int64(0), v0, v1, v2])
+                            T.block_attr({"axis": 2})
+                            T_softmax_norm[T.int64(0), v0, v1, v2] = 
T.exp(input[T.int64(0), v0, v1, v2] - T_softmax_maxelem_shared[T.int64(0), v0, 
v2]) / T_softmax_expsum_shared[T.int64(0), v0, v2]
+    # fmt: on
+    _check(Before, After)
+
+
 def test_layer_norm():
     # fmt: off
     @I.ir_module

Reply via email to