This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 684a8ca6a4 [Unity][DLight] Enhance the inline consumer rule (#16124)
684a8ca6a4 is described below

commit 684a8ca6a41c984a2431405e29b82ba862f70f82
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Nov 15 14:38:33 2023 +0800

    [Unity][DLight] Enhance the inline consumer rule (#16124)
    
    The current inline consumer rule failed on the following case, because
    of the missing inline of the producers of the output stage
    
    ```
     A   B   D
      \ /    |
     matmul  C
        \   /
         out
    ```
---
 python/tvm/dlight/gpu/matmul.py        |   1 +
 tests/python/dlight/test_gpu_matmul.py | 130 ++++++++++++++++++++++++++++++++-
 2 files changed, 130 insertions(+), 1 deletion(-)

diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index 703f9c151f..7d5d6489cb 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -99,6 +99,7 @@ def auto_inline_consumer_chain(
         for c in remaining_consumers:
             for p in sch.get_producers(c):
                 if sch.get(p) != sch.get(block):
+                    auto_inline_producers(sch, p)
                     sch.compute_inline(p)
 
         # Try inlining into the cache-write stage again, this time it should succeed.
diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py
index 550e30e6e7..82f481da46 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -19,7 +19,6 @@ import pytest
 
 import tvm.testing
 from tvm import dlight as dl
-from tvm.script import ir as I
 from tvm.script import tir as T
 from tvm.target import Target
 
@@ -476,6 +475,135 @@ class TestOutputFP32(BaseBeforeAfter):
     # fmt: on
 
 
+class TestInlineConsumerChain(BaseBeforeAfter):
+    # fmt: off
+    @T.prim_func(private=True)
+    def before(p_lv26: T.handle, lv9: T.Buffer((T.int64(2048), T.int64(2048)), "float16"), p_lv52: T.handle, p_output0: T.handle):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        n = T.int64()
+        lv26 = T.match_buffer(p_lv26, (n, T.int64(2048)), "float16")
+        lv52 = T.match_buffer(p_lv52, (T.int64(1), n, T.int64(2048)))
+        var_T_multiply_intermediate = T.match_buffer(p_output0, (n, T.int64(2048)), "float16")
+        # with T.block("root"):
+        var_NT_matmul_intermediate = T.alloc_buffer((n, T.int64(2048)), "float16")
+        compute = T.alloc_buffer((n, T.int64(2048)), "float16")
+        var_T_multiply_intermediate_1 = T.alloc_buffer((n, T.int64(2048)), "float16")
+        var_T_squeeze_intermediate = T.alloc_buffer((n, T.int64(2048)))
+        var_compute_intermediate = T.alloc_buffer((n, T.int64(2048)), "float16")
+        for i0, i1, k in T.grid(n, T.int64(2048), T.int64(2048)):
+            with T.block("NT_matmul"):
+                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+                T.reads(lv26[v_i0, v_k], lv9[v_i1, v_k])
+                T.writes(var_NT_matmul_intermediate[v_i0, v_i1])
+                with T.init():
+                    var_NT_matmul_intermediate[v_i0, v_i1] = T.float16(0)
+                var_NT_matmul_intermediate[v_i0, v_i1] = var_NT_matmul_intermediate[v_i0, v_i1] + lv26[v_i0, v_k] * lv9[v_i1, v_k]
+        for i0, i1 in T.grid(n, T.int64(2048)):
+            with T.block("compute"):
+                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                T.reads(var_NT_matmul_intermediate[v_i0, v_i1])
+                T.writes(compute[v_i0, v_i1])
+                compute[v_i0, v_i1] = T.sigmoid(var_NT_matmul_intermediate[v_i0, v_i1])
+        for ax0, ax1 in T.grid(n, T.int64(2048)):
+            with T.block("T_multiply"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1], compute[v_ax0, v_ax1])
+                T.writes(var_T_multiply_intermediate_1[v_ax0, v_ax1])
+                var_T_multiply_intermediate_1[v_ax0, v_ax1] = var_NT_matmul_intermediate[v_ax0, v_ax1] * compute[v_ax0, v_ax1]
+        for ax0, ax1 in T.grid(n, T.int64(2048)):
+            with T.block("T_squeeze"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(lv52[T.int64(0), v_ax0, v_ax1])
+                T.writes(var_T_squeeze_intermediate[v_ax0, v_ax1])
+                var_T_squeeze_intermediate[v_ax0, v_ax1] = lv52[T.int64(0), v_ax0, v_ax1]
+        for i0, i1 in T.grid(n, T.int64(2048)):
+            with T.block("compute_1"):
+                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                T.reads(var_T_squeeze_intermediate[v_i0, v_i1])
+                T.writes(var_compute_intermediate[v_i0, v_i1])
+                var_compute_intermediate[v_i0, v_i1] = T.Cast("float16", var_T_squeeze_intermediate[v_i0, v_i1])
+        for ax0, ax1 in T.grid(n, T.int64(2048)):
+            with T.block("T_multiply_1"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(var_compute_intermediate[v_ax0, v_ax1], var_T_multiply_intermediate_1[v_ax0, v_ax1])
+                T.writes(var_T_multiply_intermediate[v_ax0, v_ax1])
+                var_T_multiply_intermediate[v_ax0, v_ax1] = var_compute_intermediate[v_ax0, v_ax1] * var_T_multiply_intermediate_1[v_ax0, v_ax1]
+
+    @T.prim_func
+    def expected(p_lv26: T.handle, lv9: T.Buffer((T.int64(2048), T.int64(2048)), "float16"), p_lv52: T.handle, p_output0: T.handle):
+        T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+        n = T.int64()
+        lv26 = T.match_buffer(p_lv26, (n, T.int64(2048)), "float16")
+        lv52 = T.match_buffer(p_lv52, (T.int64(1), n, T.int64(2048)))
+        var_T_multiply_intermediate = T.match_buffer(p_output0, (n, T.int64(2048)), "float16")
+        # with T.block("root"):
+        var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2048)), "float16", scope="local")
+        lv26_reindex_pad_shared = T.alloc_buffer((T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2048)), "float16", scope="shared")
+        lv9_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(2048), T.int64(2048)), "float16", scope="shared")
+        for ax0_ax2_0_fused in T.thread_binding(T.int64(32), thread="blockIdx.y"):
+            for ax1_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"):
+                for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
+                    for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
+                        for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                            for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                                for ax2_3_init, ax1_3_0_init in T.grid(T.int64(4), T.int64(2)):
+                                    for ax1_3_1_init in T.vectorized(T.int64(2)):
+                                        with T.block("NT_matmul_init"):
+                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                            v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_0_init * T.int64(2) + ax1_3_1_init)
+                                            v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
+                                            T.reads()
+                                            T.writes(var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
+                                            var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = T.float16(0)
+                                for ax3_0 in range(T.int64(128)):
+                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                                            for ax0_ax1_ax2_fused_2 in range(T.int64(2)):
+                                                for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
+                                                    with T.block("lv26_reindex_pad_shared"):
+                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                                        v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+                                                        v2 = T.axis.spatial(T.int64(2048), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+                                                        T.reads(lv26[v1, v2])
+                                                        T.writes(lv26_reindex_pad_shared[v0, v1, v2])
+                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+                                                        lv26_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv26[v1, v2], T.float16(0))
+                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                                            for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
+                                                for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
+                                                    with T.block("lv9_reindex_shared"):
+                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                                        v1 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+                                                        v2 = T.axis.spatial(T.int64(2048), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+                                                        T.reads(lv9[v1, v2])
+                                                        T.writes(lv9_reindex_shared[v0, v1, v2])
+                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+                                                        lv9_reindex_shared[v0, v1, v2] = lv9[v1, v2]
+                                    for ax3_1, ax2_3, ax1_3_0 in T.grid(T.int64(16), T.int64(4), T.int64(2)):
+                                        for ax1_3_1 in T.vectorized(T.int64(2)):
+                                            with T.block("NT_matmul_update"):
+                                                v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                                v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_0 * T.int64(2) + ax1_3_1)
+                                                v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3)
+                                                v3 = T.axis.reduce(T.int64(2048), ax3_0 * T.int64(16) + ax3_1)
+                                                T.reads(var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2], lv26_reindex_pad_shared[T.int64(0), v1, v3], lv9_reindex_shared[T.int64(0), v2, v3])
+                                                T.writes(var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
+                                                var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] + lv26_reindex_pad_shared[T.int64(0), v1, v3] * lv9_reindex_shared[T.int64(0), v2, v3]
+                                for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(2)):
+                                    for ax2_1_1 in T.vectorized(T.int64(2)):
+                                        with T.block("var_NT_matmul_intermediate_reindex_pad_local"):
+                                            v0 = T.axis.spatial(T.int64(1), ax0)
+                                            v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
+                                            v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1)
+                                            T.reads(lv52[T.int64(0), v1, v2], var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
+                                            T.writes(var_T_multiply_intermediate[v1, v2])
+                                            if v1 < n:
+                                                var_T_multiply_intermediate[v1, v2] = T.Cast("float16", lv52[T.int64(0), v1, v2]) * (var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] * T.sigmoid(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]))
+
+    # fmt: on
+
+
 class AndroidBeforeAfter(tvm.testing.CompareBeforeAfter):
     @pytest.fixture
     def transform(self):

Reply via email to