This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 684a8ca6a4 [Unity][DLight] Enhance the inline consumer rule (#16124)
684a8ca6a4 is described below
commit 684a8ca6a41c984a2431405e29b82ba862f70f82
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Nov 15 14:38:33 2023 +0800
[Unity][DLight] Enhance the inline consumer rule (#16124)
The current inline consumer rule fails on the following case because
the producers of the output stage are not inlined
```
A B D
\ / |
matmul C
\ /
out
```
---
python/tvm/dlight/gpu/matmul.py | 1 +
tests/python/dlight/test_gpu_matmul.py | 130 ++++++++++++++++++++++++++++++++-
2 files changed, 130 insertions(+), 1 deletion(-)
diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index 703f9c151f..7d5d6489cb 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -99,6 +99,7 @@ def auto_inline_consumer_chain(
for c in remaining_consumers:
for p in sch.get_producers(c):
if sch.get(p) != sch.get(block):
+ auto_inline_producers(sch, p)
sch.compute_inline(p)
# Try inlining into the cache-write stage again, this time it should
succeed.
diff --git a/tests/python/dlight/test_gpu_matmul.py
b/tests/python/dlight/test_gpu_matmul.py
index 550e30e6e7..82f481da46 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -19,7 +19,6 @@ import pytest
import tvm.testing
from tvm import dlight as dl
-from tvm.script import ir as I
from tvm.script import tir as T
from tvm.target import Target
@@ -476,6 +475,135 @@ class TestOutputFP32(BaseBeforeAfter):
# fmt: on
+class TestInlineConsumerChain(BaseBeforeAfter):
+ # fmt: off
+ @T.prim_func(private=True)
+ def before(p_lv26: T.handle, lv9: T.Buffer((T.int64(2048), T.int64(2048)),
"float16"), p_lv52: T.handle, p_output0: T.handle):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ n = T.int64()
+ lv26 = T.match_buffer(p_lv26, (n, T.int64(2048)), "float16")
+ lv52 = T.match_buffer(p_lv52, (T.int64(1), n, T.int64(2048)))
+ var_T_multiply_intermediate = T.match_buffer(p_output0, (n,
T.int64(2048)), "float16")
+ # with T.block("root"):
+ var_NT_matmul_intermediate = T.alloc_buffer((n, T.int64(2048)),
"float16")
+ compute = T.alloc_buffer((n, T.int64(2048)), "float16")
+ var_T_multiply_intermediate_1 = T.alloc_buffer((n, T.int64(2048)),
"float16")
+ var_T_squeeze_intermediate = T.alloc_buffer((n, T.int64(2048)))
+ var_compute_intermediate = T.alloc_buffer((n, T.int64(2048)),
"float16")
+ for i0, i1, k in T.grid(n, T.int64(2048), T.int64(2048)):
+ with T.block("NT_matmul"):
+ v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+ T.reads(lv26[v_i0, v_k], lv9[v_i1, v_k])
+ T.writes(var_NT_matmul_intermediate[v_i0, v_i1])
+ with T.init():
+ var_NT_matmul_intermediate[v_i0, v_i1] = T.float16(0)
+ var_NT_matmul_intermediate[v_i0, v_i1] =
var_NT_matmul_intermediate[v_i0, v_i1] + lv26[v_i0, v_k] * lv9[v_i1, v_k]
+ for i0, i1 in T.grid(n, T.int64(2048)):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(var_NT_matmul_intermediate[v_i0, v_i1])
+ T.writes(compute[v_i0, v_i1])
+ compute[v_i0, v_i1] =
T.sigmoid(var_NT_matmul_intermediate[v_i0, v_i1])
+ for ax0, ax1 in T.grid(n, T.int64(2048)):
+ with T.block("T_multiply"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1],
compute[v_ax0, v_ax1])
+ T.writes(var_T_multiply_intermediate_1[v_ax0, v_ax1])
+ var_T_multiply_intermediate_1[v_ax0, v_ax1] =
var_NT_matmul_intermediate[v_ax0, v_ax1] * compute[v_ax0, v_ax1]
+ for ax0, ax1 in T.grid(n, T.int64(2048)):
+ with T.block("T_squeeze"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(lv52[T.int64(0), v_ax0, v_ax1])
+ T.writes(var_T_squeeze_intermediate[v_ax0, v_ax1])
+ var_T_squeeze_intermediate[v_ax0, v_ax1] = lv52[T.int64(0),
v_ax0, v_ax1]
+ for i0, i1 in T.grid(n, T.int64(2048)):
+ with T.block("compute_1"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(var_T_squeeze_intermediate[v_i0, v_i1])
+ T.writes(var_compute_intermediate[v_i0, v_i1])
+ var_compute_intermediate[v_i0, v_i1] = T.Cast("float16",
var_T_squeeze_intermediate[v_i0, v_i1])
+ for ax0, ax1 in T.grid(n, T.int64(2048)):
+ with T.block("T_multiply_1"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(var_compute_intermediate[v_ax0, v_ax1],
var_T_multiply_intermediate_1[v_ax0, v_ax1])
+ T.writes(var_T_multiply_intermediate[v_ax0, v_ax1])
+ var_T_multiply_intermediate[v_ax0, v_ax1] =
var_compute_intermediate[v_ax0, v_ax1] * var_T_multiply_intermediate_1[v_ax0,
v_ax1]
+
+ @T.prim_func
+ def expected(p_lv26: T.handle, lv9: T.Buffer((T.int64(2048),
T.int64(2048)), "float16"), p_lv52: T.handle, p_output0: T.handle):
+ T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1,
"tir.noalias": T.bool(True)})
+ n = T.int64()
+ lv26 = T.match_buffer(p_lv26, (n, T.int64(2048)), "float16")
+ lv52 = T.match_buffer(p_lv52, (T.int64(1), n, T.int64(2048)))
+ var_T_multiply_intermediate = T.match_buffer(p_output0, (n,
T.int64(2048)), "float16")
+ # with T.block("root"):
+ var_NT_matmul_intermediate_reindex_pad_local =
T.alloc_buffer((T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32),
T.int64(2048)), "float16", scope="local")
+ lv26_reindex_pad_shared = T.alloc_buffer((T.int64(1), (n +
T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2048)), "float16",
scope="shared")
+ lv9_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(2048),
T.int64(2048)), "float16", scope="shared")
+ for ax0_ax2_0_fused in T.thread_binding(T.int64(32),
thread="blockIdx.y"):
+ for ax1_0 in T.thread_binding((n + T.int64(31)) // T.int64(32),
thread="blockIdx.x"):
+ for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
+ for ax1_1 in T.thread_binding(T.int64(1),
thread="vthread.x"):
+ for ax2_2 in T.thread_binding(T.int64(16),
thread="threadIdx.y"):
+ for ax1_2 in T.thread_binding(T.int64(8),
thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256,
"pragma_unroll_explicit": 1}):
+ for ax2_3_init, ax1_3_0_init in
T.grid(T.int64(4), T.int64(2)):
+ for ax1_3_1_init in
T.vectorized(T.int64(2)):
+ with T.block("NT_matmul_init"):
+ v0 = T.axis.spatial(T.int64(1),
T.int64(0))
+ v1 = T.axis.spatial((n +
T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 *
T.int64(32) + ax1_2 * T.int64(4) + ax1_3_0_init * T.int64(2) + ax1_3_1_init)
+ v2 = T.axis.spatial(T.int64(2048),
ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) +
ax2_3_init)
+ T.reads()
+
T.writes(var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
+
var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = T.float16(0)
+ for ax3_0 in range(T.int64(128)):
+ for ax0_ax1_ax2_fused_0 in
T.thread_binding(T.int64(16), thread="threadIdx.y"):
+ for ax0_ax1_ax2_fused_1 in
T.thread_binding(T.int64(8), thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_2 in
range(T.int64(2)):
+ for ax0_ax1_ax2_fused_3 in
T.vectorized(T.int64(2)):
+ with
T.block("lv26_reindex_pad_shared"):
+ v0 =
T.axis.spatial(T.int64(1), T.int64(0))
+ v1 = T.axis.spatial((n
+ T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) +
(ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) +
ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+ v2 =
T.axis.spatial(T.int64(2048), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 *
T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 *
T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+ T.reads(lv26[v1, v2])
+
T.writes(lv26_reindex_pad_shared[v0, v1, v2])
+
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+
lv26_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv26[v1, v2],
T.float16(0))
+ for ax0_ax1_ax2_fused_0 in
T.thread_binding(T.int64(16), thread="threadIdx.y"):
+ for ax0_ax1_ax2_fused_1 in
T.thread_binding(T.int64(8), thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_2 in
range(T.int64(4)):
+ for ax0_ax1_ax2_fused_3 in
T.vectorized(T.int64(2)):
+ with
T.block("lv9_reindex_shared"):
+ v0 =
T.axis.spatial(T.int64(1), T.int64(0))
+ v1 =
T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(64) +
(ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) +
ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+ v2 =
T.axis.spatial(T.int64(2048), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 *
T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 *
T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+ T.reads(lv9[v1, v2])
+
T.writes(lv9_reindex_shared[v0, v1, v2])
+
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+ lv9_reindex_shared[v0,
v1, v2] = lv9[v1, v2]
+ for ax3_1, ax2_3, ax1_3_0 in
T.grid(T.int64(16), T.int64(4), T.int64(2)):
+ for ax1_3_1 in
T.vectorized(T.int64(2)):
+ with T.block("NT_matmul_update"):
+ v0 =
T.axis.spatial(T.int64(1), T.int64(0))
+ v1 = T.axis.spatial((n +
T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 *
T.int64(32) + ax1_2 * T.int64(4) + ax1_3_0 * T.int64(2) + ax1_3_1)
+ v2 =
T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(64) + ax2_1 *
T.int64(64) + ax2_2 * T.int64(4) + ax2_3)
+ v3 =
T.axis.reduce(T.int64(2048), ax3_0 * T.int64(16) + ax3_1)
+
T.reads(var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2],
lv26_reindex_pad_shared[T.int64(0), v1, v3], lv9_reindex_shared[T.int64(0), v2,
v3])
+
T.writes(var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
+
var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] =
var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] +
lv26_reindex_pad_shared[T.int64(0), v1, v3] * lv9_reindex_shared[T.int64(0),
v2, v3]
+ for ax0, ax1, ax2_0 in T.grid(T.int64(1),
T.int64(4), T.int64(2)):
+ for ax2_1_1 in T.vectorized(T.int64(2)):
+ with
T.block("var_NT_matmul_intermediate_reindex_pad_local"):
+ v0 = T.axis.spatial(T.int64(1),
ax0)
+ v1 = T.axis.spatial((n +
T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 *
T.int64(4) + ax1)
+ v2 = T.axis.spatial(T.int64(2048),
ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) +
ax2_1_1)
+ T.reads(lv52[T.int64(0), v1, v2],
var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
+
T.writes(var_T_multiply_intermediate[v1, v2])
+ if v1 < n:
+
var_T_multiply_intermediate[v1, v2] = T.Cast("float16", lv52[T.int64(0), v1,
v2]) * (var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] *
T.sigmoid(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]))
+
+ # fmt: on
+
+
class AndroidBeforeAfter(tvm.testing.CompareBeforeAfter):
@pytest.fixture
def transform(self):