This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 67b0c6cc5f [Tests][S-TIR] Fix stale MetaSchedule sketch expectations
and migrate let binds to T.let (#19729)
67b0c6cc5f is described below
commit 67b0c6cc5f6f2a987637c2474c74c3489008f3be
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Jun 11 07:19:00 2026 -0400
[Tests][S-TIR] Fix stale MetaSchedule sketch expectations and migrate let
binds to T.let (#19729)
Fix s_tir tests that were broken, or left with stale expectations, by two upstream changes.
* test_meta_schedule_space_cuda.py (cap, dil, gmm, t2d, nrm, sfm, cbr,
tbg) and test_meta_schedule_space_cuda_async.py (c2d): #18927 expanded
DefaultCUDA unroll_max_steps from {0, 16, 64, 512, 1024} to {0, 16, 32,
64, 128, 256, 512, 1024} without updating the recorded SampleCategorical
decisions. Remap the indices (2->3 for 64, 3->6 for 512, 4->7 for 1024)
so each test keeps sampling the same unroll value; every sketch was
re-verified by replaying its trace and structurally comparing the result
against the expected module.
* T.let migration: since #19581 the TIRx parser treats `v: T.int32 =
expr` as a mutable local-scalar buffer rather than an immutable bind; an
immutable bind is now spelled `v: T.let[T.int32] = expr` (a Bind node,
the same form te.create_prim_func emits). Tests whose intent is a bind
are migrated to the new spelling: reduction combiner temporaries
(add_rfactor, lower_cross_thread_reduction) and let-dependent passes
(compact_buffer_region, hoist_expression, remove_undef).
* Also convert reduction temporaries in tests that still pass
(cross_thread_reduction rule, compute_inline, schedule utilities,
parallel_vectorize_unroll postproc, dlight general reduction, relax
cuda_graph) so the hand-written workloads match the canonical Bind form
instead of feeding the rules a mutable-scalar body.
---
.../relax/test_transform_rewrite_cuda_graph.py | 4 +--
.../s_tir/dlight/test_gpu_general_reduction.py | 16 +++++-----
...e_postproc_rewrite_parallel_vectorize_unroll.py | 8 ++---
...test_meta_schedule_schedule_rule_add_rfactor.py | 28 +++++++++--------
...chedule_schedule_rule_cross_thread_reduction.py | 36 ++++++++++++++--------
.../meta_schedule/test_meta_schedule_space_cuda.py | 18 +++++------
.../test_meta_schedule_space_cuda_async.py | 2 +-
.../schedule/test_tir_schedule_compute_inline.py | 8 ++---
.../s_tir/schedule/test_tir_schedule_utilities.py | 4 +--
.../test_s_tir_transform_compact_buffer_region.py | 18 +++++------
.../test_s_tir_transform_hoist_expression.py | 10 +++---
...s_tir_transform_lower_cross_thread_reduction.py | 30 ++++++++++--------
.../transform/test_s_tir_transform_remove_undef.py | 4 +--
13 files changed, 104 insertions(+), 82 deletions(-)
diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py
b/tests/python/relax/test_transform_rewrite_cuda_graph.py
index 80637edcc0..3897e444bc 100644
--- a/tests/python/relax/test_transform_rewrite_cuda_graph.py
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -475,10 +475,10 @@ def test_capture_fixed_inputs():
with T.init():
A_red_temp_v0[v_ax0, v_ax1, v_ax2] = T.float32(0)
A_red_temp_v1[v_ax0, v_ax1, v_ax2] = T.float32(0)
- v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1,
v_ax2] + T.Cast(
+ v_A_red_temp_v0: T.let[T.float32] = A_red_temp_v0[v_ax0,
v_ax1, v_ax2] + T.Cast(
"float32", A[v_ax0, v_ax1, v_ax2, v_k3]
)
- v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1,
v_ax2] + T.Cast(
+ v_A_red_temp_v1: T.let[T.float32] = A_red_temp_v1[v_ax0,
v_ax1, v_ax2] + T.Cast(
"float32", A[v_ax0, v_ax1, v_ax2, v_k3]
) * T.Cast("float32", A[v_ax0, v_ax1, v_ax2, v_k3])
A_red_temp_v0[v_ax0, v_ax1, v_ax2] = v_A_red_temp_v0
diff --git a/tests/python/s_tir/dlight/test_gpu_general_reduction.py
b/tests/python/s_tir/dlight/test_gpu_general_reduction.py
index 7022cef9f2..6ddc4671b9 100644
--- a/tests/python/s_tir/dlight/test_gpu_general_reduction.py
+++ b/tests/python/s_tir/dlight/test_gpu_general_reduction.py
@@ -337,8 +337,8 @@ def test_layer_norm():
with T.init():
A_red_temp_v0[v_ax0, v_ax1] = T.float32(0)
A_red_temp_v1[v_ax0, v_ax1] = T.float32(0)
- v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] +
lv6[v_ax0, v_ax1, v_k2]
- v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] +
lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, v_ax1, v_k2]
+ v_A_red_temp_v0: T.let[T.float32] = A_red_temp_v0[v_ax0,
v_ax1] + lv6[v_ax0, v_ax1, v_k2]
+ v_A_red_temp_v1: T.let[T.float32] = A_red_temp_v1[v_ax0,
v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, v_ax1, v_k2]
A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0
A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1
for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)):
@@ -377,8 +377,8 @@ def test_layer_norm():
with T.init():
A_red_temp_v0_shared[T.int64(0), v0] =
T.float32(0)
A_red_temp_v1_shared[T.int64(0), v0] =
T.float32(0)
- v_A_red_temp_v0: T.float32 =
A_red_temp_v0_shared[T.int64(0), v0] + lv6[T.int64(0), v0, v1]
- v_A_red_temp_v1: T.float32 =
A_red_temp_v1_shared[T.int64(0), v0] + lv6[T.int64(0), v0, v1] *
lv6[T.int64(0), v0, v1]
+ v_A_red_temp_v0: T.let[T.float32] =
A_red_temp_v0_shared[T.int64(0), v0] + lv6[T.int64(0), v0, v1]
+ v_A_red_temp_v1: T.let[T.float32] =
A_red_temp_v1_shared[T.int64(0), v0] + lv6[T.int64(0), v0, v1] *
lv6[T.int64(0), v0, v1]
A_red_temp_v0_shared[T.int64(0), v0] =
v_A_red_temp_v0
A_red_temp_v1_shared[T.int64(0), v0] =
v_A_red_temp_v1
for ax1_1 in T.thread_binding(T.int64(256),
thread="threadIdx.x"):
@@ -481,8 +481,8 @@ def test_group_norm():
with T.init():
A_red_temp_v0[v_ax0, v_ax1] = T.float32(0)
A_red_temp_v1[v_ax0, v_ax1] = T.float32(0)
- v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] +
T_reshape_1[v_ax0, v_ax1, v_k2]
- v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] +
T_reshape_1[v_ax0, v_ax1, v_k2] * T_reshape_1[v_ax0, v_ax1, v_k2]
+ v_A_red_temp_v0: T.let[T.float32] = A_red_temp_v0[v_ax0,
v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2]
+ v_A_red_temp_v1: T.let[T.float32] = A_red_temp_v1[v_ax0,
v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2] * T_reshape_1[v_ax0, v_ax1, v_k2]
A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0
A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1
for ax0, ax1 in T.grid(32, 64):
@@ -531,8 +531,8 @@ def test_group_norm():
with T.init():
A_red_temp_v0_shared[0, v0] = T.float32(0)
A_red_temp_v1_shared[0, v0] = T.float32(0)
- v_A_red_temp_v0: T.float32 =
A_red_temp_v0_shared[0, v0] + A[0, v0 * 64 + v1]
- v_A_red_temp_v1: T.float32 =
A_red_temp_v1_shared[0, v0] + A[0, v0 * 64 + v1] * A[0, v0 * 64 + v1]
+ v_A_red_temp_v0: T.let[T.float32] =
A_red_temp_v0_shared[0, v0] + A[0, v0 * 64 + v1]
+ v_A_red_temp_v1: T.let[T.float32] =
A_red_temp_v1_shared[0, v0] + A[0, v0 * 64 + v1] * A[0, v0 * 64 + v1]
A_red_temp_v0_shared[0, v0] = v_A_red_temp_v0
A_red_temp_v1_shared[0, v0] = v_A_red_temp_v1
for ax1_1 in T.thread_binding(256, thread="threadIdx.x"):
diff --git
a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
index 84d667f47d..5152e2fa15 100644
---
a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
+++
b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
@@ -241,8 +241,8 @@ def test_no_unroll_for_spatial_block():
with T.init():
A_red_temp_v0[v_ax0] = T.float32(0)
A_red_temp_v1[v_ax0] = T.float32(0)
- v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0] +
A[v_ax0, v_k1, v_k2, v_k3]
- v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0] +
A[v_ax0, v_k1, v_k2, v_k3] * A[v_ax0, v_k1, v_k2, v_k3]
+ v_A_red_temp_v0: T.let[T.float32] = A_red_temp_v0[v_ax0] +
A[v_ax0, v_k1, v_k2, v_k3]
+ v_A_red_temp_v1: T.let[T.float32] = A_red_temp_v1[v_ax0] +
A[v_ax0, v_k1, v_k2, v_k3] * A[v_ax0, v_k1, v_k2, v_k3]
A_red_temp_v0[v_ax0] = v_A_red_temp_v0
A_red_temp_v1[v_ax0] = v_A_red_temp_v1
for ax0, ax1, ax2, ax3 in T.grid(1, 4, 4, 32):
@@ -267,8 +267,8 @@ def test_no_unroll_for_spatial_block():
with T.init():
A_red_temp_v0[0] = T.float32(0)
A_red_temp_v1[0] = T.float32(0)
- v_A_red_temp_v0: T.float32 = A_red_temp_v0[0] + A[0,
v_k1, v_k2, v_k3]
- v_A_red_temp_v1: T.float32 = A_red_temp_v1[0] + A[0,
v_k1, v_k2, v_k3] * A[0, v_k1, v_k2, v_k3]
+ v_A_red_temp_v0: T.let[T.float32] = A_red_temp_v0[0] +
A[0, v_k1, v_k2, v_k3]
+ v_A_red_temp_v1: T.let[T.float32] = A_red_temp_v1[0] +
A[0, v_k1, v_k2, v_k3] * A[0, v_k1, v_k2, v_k3]
A_red_temp_v0[0] = v_A_red_temp_v0
A_red_temp_v1[0] = v_A_red_temp_v1
for ax0, ax1, ax2, ax3 in T.grid(1, 4, 4, 32):
diff --git
a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_add_rfactor.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_add_rfactor.py
index fc6043526d..550cc35e94 100644
---
a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_add_rfactor.py
+++
b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_add_rfactor.py
@@ -138,8 +138,10 @@ def test_cpu_argmax():
with T.init():
argmax_v0[i] = -1
argmax_v1[i] = T.min_value("float32")
- v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k],
argmax_v0[i], idx[i, k])
- v_argmax_v1: T.float32 = T.Select(
+ v_argmax_v0: T.let[T.int32] = T.Select(
+ argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]
+ )
+ v_argmax_v1: T.let[T.float32] = T.Select(
argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
)
argmax_v0[i] = v_argmax_v0
@@ -160,8 +162,10 @@ def test_cpu_argmax():
with T.init():
argmax_v0[i] = -1
argmax_v1[i] = T.float32(-3.4028234663852886e38)
- v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k],
argmax_v0[i], idx[i, k])
- v_argmax_v1: T.float32 = T.Select(
+ v_argmax_v0: T.let[T.int32] = T.Select(
+ argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]
+ )
+ v_argmax_v1: T.let[T.float32] = T.Select(
argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
)
argmax_v0[i] = v_argmax_v0
@@ -184,12 +188,12 @@ def test_cpu_argmax():
with T.init():
argmax_v0_rf[i, vi1_1] = -1
argmax_v1_rf[i, vi1_1] = T.float32(-3.4028234663852886e38)
- v_argmax_v0_rf: T.int32 = T.Select(
+ v_argmax_v0_rf: T.let[T.int32] = T.Select(
argmax_v1_rf[i, vi1_1] >= val[i, vi1_0 * 16 + vi1_1],
argmax_v0_rf[i, vi1_1],
idx[i, vi1_0 * 16 + vi1_1],
)
- v_argmax_v1_rf: T.float32 = T.Select(
+ v_argmax_v1_rf: T.let[T.float32] = T.Select(
argmax_v1_rf[i, vi1_1] >= val[i, vi1_0 * 16 + vi1_1],
argmax_v1_rf[i, vi1_1],
val[i, vi1_0 * 16 + vi1_1],
@@ -205,10 +209,10 @@ def test_cpu_argmax():
with T.init():
argmax_v0[i] = -1
argmax_v1[i] = T.float32(-3.4028234663852886e38)
- v_argmax_v0: T.int32 = T.Select(
+ v_argmax_v0: T.let[T.int32] = T.Select(
argmax_v1[i] >= argmax_v1_rf[i, vi1_1], argmax_v0[i],
argmax_v0_rf[i, vi1_1]
)
- v_argmax_v1: T.float32 = T.Select(
+ v_argmax_v1: T.let[T.float32] = T.Select(
argmax_v1[i] >= argmax_v1_rf[i, vi1_1], argmax_v1[i],
argmax_v1_rf[i, vi1_1]
)
argmax_v0[i] = v_argmax_v0
@@ -233,12 +237,12 @@ def test_cpu_argmax():
with T.init():
argmax_v0_rf[i, vi1_0] = -1
argmax_v1_rf[i, vi1_0] = T.float32(-3.4028234663852886e38)
- v_argmax_v0_rf: T.int32 = T.Select(
+ v_argmax_v0_rf: T.let[T.int32] = T.Select(
argmax_v1_rf[i, vi1_0] >= val[i, vi1_0 * 16 + vi1_1],
argmax_v0_rf[i, vi1_0],
idx[i, vi1_0 * 16 + vi1_1],
)
- v_argmax_v1_rf: T.float32 = T.Select(
+ v_argmax_v1_rf: T.let[T.float32] = T.Select(
argmax_v1_rf[i, vi1_0] >= val[i, vi1_0 * 16 + vi1_1],
argmax_v1_rf[i, vi1_0],
val[i, vi1_0 * 16 + vi1_1],
@@ -254,10 +258,10 @@ def test_cpu_argmax():
with T.init():
argmax_v0[i] = -1
argmax_v1[i] = T.float32(-3.4028234663852886e38)
- v_argmax_v0: T.int32 = T.Select(
+ v_argmax_v0: T.let[T.int32] = T.Select(
argmax_v1[i] >= argmax_v1_rf[i, vi1_0], argmax_v0[i],
argmax_v0_rf[i, vi1_0]
)
- v_argmax_v1: T.float32 = T.Select(
+ v_argmax_v1: T.let[T.float32] = T.Select(
argmax_v1[i] >= argmax_v1_rf[i, vi1_0], argmax_v1[i],
argmax_v1_rf[i, vi1_0]
)
argmax_v0[i] = v_argmax_v0
diff --git
a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_cross_thread_reduction.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_cross_thread_reduction.py
index eaecaa0fb5..0c6cc9bb24 100644
---
a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_cross_thread_reduction.py
+++
b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_cross_thread_reduction.py
@@ -582,8 +582,12 @@ def argmax(
with T.init():
argmax_v0[i] = -1
argmax_v1[i] = T.min_value("float32")
- v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k],
argmax_v0[i], idx[i, k])
- v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k],
argmax_v1[i], val[i, k])
+ v_argmax_v0: T.let[T.int32] = T.Select(
+ argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]
+ )
+ v_argmax_v1: T.let[T.float32] = T.Select(
+ argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
+ )
argmax_v0[i] = v_argmax_v0
argmax_v1[i] = v_argmax_v1
@@ -604,8 +608,12 @@ def argmax_32(
with T.init():
argmax_v0[i] = -1
argmax_v1[i] = T.min_value("float32")
- v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k],
argmax_v0[i], idx[i, k])
- v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k],
argmax_v1[i], val[i, k])
+ v_argmax_v0: T.let[T.int32] = T.Select(
+ argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]
+ )
+ v_argmax_v1: T.let[T.float32] = T.Select(
+ argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
+ )
argmax_v0[i] = v_argmax_v0
argmax_v1[i] = v_argmax_v1
@@ -628,8 +636,10 @@ def test_gpu_argmax():
with T.init():
argmax_v0[i] = -1
argmax_v1[i] = T.float32(-3.4028234663852886e38)
- v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k],
argmax_v0[i], idx[i, k])
- v_argmax_v1: T.float32 = T.Select(
+ v_argmax_v0: T.let[T.int32] = T.Select(
+ argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]
+ )
+ v_argmax_v1: T.let[T.float32] = T.Select(
argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
)
argmax_v0[i] = v_argmax_v0
@@ -654,10 +664,10 @@ def test_gpu_argmax():
with T.init():
argmax_v0[i] = -1
argmax_v1[i] = T.float32(-3.4028234663852886e38)
- v_argmax_v0: T.int32 = T.Select(
+ v_argmax_v0: T.let[T.int32] = T.Select(
argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]
)
- v_argmax_v1: T.float32 = T.Select(
+ v_argmax_v1: T.let[T.float32] = T.Select(
argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
)
argmax_v0[i] = v_argmax_v0
@@ -701,8 +711,10 @@ def test_gpu_argmax_32():
with T.init():
argmax_v0[i] = -1
argmax_v1[i] = T.float32(-3.4028234663852886e38)
- v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k],
argmax_v0[i], idx[i, k])
- v_argmax_v1: T.float32 = T.Select(
+ v_argmax_v0: T.let[T.int32] = T.Select(
+ argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]
+ )
+ v_argmax_v1: T.let[T.float32] = T.Select(
argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
)
argmax_v0[i] = v_argmax_v0
@@ -728,10 +740,10 @@ def test_gpu_argmax_32():
with T.init():
argmax_v0[i] = -1
argmax_v1[i] = T.float32(-3.4028234663852886e38)
- v_argmax_v0: T.int32 = T.Select(
+ v_argmax_v0: T.let[T.int32] = T.Select(
argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]
)
- v_argmax_v1: T.float32 = T.Select(
+ v_argmax_v1: T.let[T.float32] = T.Select(
argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
)
argmax_v0[i] = v_argmax_v0
diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py
index ba9ac778a5..d748d074ed 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py
@@ -375,7 +375,7 @@ def test_cuda_cap():
("SamplePerfectTile", [8, 4, 1]),
("SampleCategorical", 1),
("SampleCategorical", 3),
- ("SampleCategorical", 2),
+ ("SampleCategorical", 3),
]
mod = create_te_workload("CAP", 0)
actual = _design_space(mod)
@@ -537,7 +537,7 @@ def test_cuda_dil():
("SamplePerfectTile", [3, 1, 1]),
("SampleCategorical", 1),
("SampleCategorical", 3),
- ("SampleCategorical", 3),
+ ("SampleCategorical", 6),
]
mod = create_te_workload("DIL", 0)
actual = _design_space(mod)
@@ -611,7 +611,7 @@ def test_cuda_gmm():
("SamplePerfectTile", [1, 32, 4]),
("SampleCategorical", 1),
("SampleCategorical", 0),
- ("SampleCategorical", 4),
+ ("SampleCategorical", 7),
]
mod = create_te_workload("GMM", 0)
actual = _design_space(mod)
@@ -776,7 +776,7 @@ def test_cuda_t2d():
("SamplePerfectTile", [16, 4, 8]),
("SampleCategorical", 1),
("SampleCategorical", 3),
- ("SampleCategorical", 2),
+ ("SampleCategorical", 3),
]
mod = create_te_workload("T2D", 0)
actual = _design_space(mod)
@@ -846,11 +846,11 @@ def test_cuda_nrm():
D[v_b] = T.sqrt(C_shared[v_b])
# fmt: on
decision_0 = [
- ("SampleCategorical", 3),
+ ("SampleCategorical", 6),
]
decision_1 = [
("SampleCategorical", 5),
- ("SampleCategorical", 4),
+ ("SampleCategorical", 7),
]
mod = create_te_workload("NRM", 0)
actual = _design_space(mod)
@@ -1043,7 +1043,7 @@ def test_cuda_sfm():
]
decision_2 = [
("SampleCategorical", 7),
- ("SampleCategorical", 3),
+ ("SampleCategorical", 6),
("SampleCategorical", 0),
]
decision_3 = [
@@ -1132,7 +1132,7 @@ def test_cuda_cbr():
("SamplePerfectTile", [3, 1, 1]),
("SampleCategorical", 0),
("SampleCategorical", 0),
- ("SampleCategorical", 3),
+ ("SampleCategorical", 6),
]
mod = create_te_workload("CBR", 0)
actual = _design_space(mod)
@@ -1211,7 +1211,7 @@ def test_cuda_tbg():
("SamplePerfectTile", [8, 4, 2]),
("SampleCategorical", 2),
("SampleCategorical", 3),
- ("SampleCategorical", 4),
+ ("SampleCategorical", 7),
]
mod = create_te_workload("TBG", 0)
actual = _design_space(mod)
diff --git
a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda_async.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda_async.py
index 4c44feae91..1907502b43 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda_async.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda_async.py
@@ -178,7 +178,7 @@ def test_cuda_c2d():
("SamplePerfectTile", [3, 1, 1]),
("SampleCategorical", 3),
("SampleCategorical", 2),
- ("SampleCategorical", 4),
+ ("SampleCategorical", 7),
]
mod = create_te_workload("C2D", 0)
diff --git a/tests/python/s_tir/schedule/test_tir_schedule_compute_inline.py
b/tests/python/s_tir/schedule/test_tir_schedule_compute_inline.py
index df0e4963b9..d05d67b4a8 100644
--- a/tests/python/s_tir/schedule/test_tir_schedule_compute_inline.py
+++ b/tests/python/s_tir/schedule/test_tir_schedule_compute_inline.py
@@ -1424,8 +1424,8 @@ def test_reverse_compute_inline_layer_norm():
with T.init():
A_red_temp_v0_shared[v_ax0, v_ax1] = T.float32(0)
A_red_temp_v1_shared[v_ax0, v_ax1] = T.float32(0)
- v_A_red_temp_v0: T.float32 =
A_red_temp_v0_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2]
- v_A_red_temp_v1: T.float32 =
A_red_temp_v1_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0,
v_ax1, v_k2]
+ v_A_red_temp_v0: T.let[T.float32] =
A_red_temp_v0_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2]
+ v_A_red_temp_v1: T.let[T.float32] =
A_red_temp_v1_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0,
v_ax1, v_k2]
A_red_temp_v0_shared[v_ax0, v_ax1] = v_A_red_temp_v0
A_red_temp_v1_shared[v_ax0, v_ax1] = v_A_red_temp_v1
for ax2_0 in range(T.int64(10)):
@@ -1465,8 +1465,8 @@ def test_reverse_compute_inline_layer_norm():
with T.init():
A_red_temp_v0_shared[v_ax0, v_ax1] = T.float32(0)
A_red_temp_v1_shared[v_ax0, v_ax1] = T.float32(0)
- v_A_red_temp_v0: T.float32 =
A_red_temp_v0_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2]
- v_A_red_temp_v1: T.float32 =
A_red_temp_v1_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0,
v_ax1, v_k2]
+ v_A_red_temp_v0: T.let[T.float32] =
A_red_temp_v0_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2]
+ v_A_red_temp_v1: T.let[T.float32] =
A_red_temp_v1_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0,
v_ax1, v_k2]
A_red_temp_v0_shared[v_ax0, v_ax1] = v_A_red_temp_v0
A_red_temp_v1_shared[v_ax0, v_ax1] = v_A_red_temp_v1
for ax2_0 in range(T.int64(10)):
diff --git a/tests/python/s_tir/schedule/test_tir_schedule_utilities.py
b/tests/python/s_tir/schedule/test_tir_schedule_utilities.py
index dcd3b7b5a2..a2aa9c699d 100644
--- a/tests/python/s_tir/schedule/test_tir_schedule_utilities.py
+++ b/tests/python/s_tir/schedule/test_tir_schedule_utilities.py
@@ -147,8 +147,8 @@ def tuple_reduction(data: T.Buffer((4, 32), "float32"),
T_add: T.Buffer((4,), "f
with T.init():
data_red_temp_v0[ax0] = T.float32(0)
data_red_temp_v1[ax0] = T.float32(0)
- v_data_red_temp_v0: T.float32 = data_red_temp_v0[ax0] +
data[ax0, k1]
- v_data_red_temp_v1: T.float32 = (
+ v_data_red_temp_v0: T.let[T.float32] = data_red_temp_v0[ax0] +
data[ax0, k1]
+ v_data_red_temp_v1: T.let[T.float32] = (
data_red_temp_v1[ax0] + data[ax0, k1] * data[ax0, k1]
)
data_red_temp_v0[ax0] = v_data_red_temp_v0
diff --git
a/tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py
b/tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py
index e398876f35..3185147f8e 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py
@@ -755,8 +755,8 @@ class TestLetBinding(BaseCompactTest):
for rii, rjj in T.grid(8, 8):
C[rii, rjj] = T.float32(0)
for riijj in T.serial(8 * 8):
- rii: T.int32 = riijj // 8
- rjj: T.int32 = riijj % 8
+ rii: T.let[T.int32] = riijj // 8
+ rjj: T.let[T.int32] = riijj % 8
C[rii, rjj] += A[rk, rii] * B[rk, rjj]
expected = before
@@ -766,13 +766,13 @@ class TestNonIndexLetBinding(BaseCompactTest):
@T.prim_func(s_tir=True)
def before():
A = T.sblock_alloc_buffer((64), "float32")
- x1 = T.call_extern("get", dtype="float16")
- x2 = T.call_extern("get", dtype="float32")
- x3 = T.call_extern("get", dtype="float64")
- x4 = T.call_extern("get", dtype="uint8")
- x5 = T.call_extern("get", dtype="int32x16")
- x6 = T.call_extern("get", dtype="handle")
- x7 = T.call_extern("get", dtype="")
+ x1: T.let[T.float16] = T.call_extern("get", dtype="float16")
+ x2: T.let[T.float32] = T.call_extern("get", dtype="float32")
+ x3: T.let[T.float64] = T.call_extern("get", dtype="float64")
+ x4: T.let[T.uint8] = T.call_extern("get", dtype="uint8")
+ x5: T.let[T.int32x16] = T.call_extern("get", dtype="int32x16")
+ x6: T.let[T.handle] = T.call_extern("get", dtype="handle")
+ x7: T.let = T.call_extern("get", dtype="")
for rk in range(64):
A[rk] = T.call_extern("load_ptr", x1, x2, x3, x4, x5, x6, x7,
dtype="float32")
diff --git
a/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py
b/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py
index b4c52d2831..da86edb310 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py
@@ -221,14 +221,14 @@ def test_hoist_with_let():
def before(A: T.Buffer((4, 4), "float32")):
for i in T.serial(4):
for j in T.serial(4):
- condition = i < 3
+ condition: T.let[T.bool] = i < 3
if condition:
A[i, j] = 0.0
@T.prim_func(private=True, s_tir=True)
def expected(A: T.Buffer((4, 4), "float32")):
for i in T.serial(4):
- condition: T.bool = i < 3 # noqa: F841
+ condition: T.let[T.bool] = i < 3 # noqa: F841
if i < 3:
for j in T.serial(4):
A[i, j] = T.float32(0.0)
@@ -250,14 +250,14 @@ def test_hoist_disable_let():
def before(A: T.Buffer((4, 4), "float32")):
for i in T.serial(4):
for j in T.serial(4):
- condition = i < 3
+ condition: T.let[T.bool] = i < 3
if condition:
A[i, j] = 0.0
@T.prim_func(private=True, s_tir=True)
def expected(A: T.Buffer((4, 4), "float32")):
for i, j in T.grid(4, 4):
- condition: T.bool = i < 3 # noqa: F841
+ condition: T.let[T.bool] = i < 3 # noqa: F841
if i < 3:
A[i, j] = T.float32(0.0)
@@ -519,7 +519,7 @@ def test_hoist_let_expr():
@T.prim_func(private=True, s_tir=True)
def expected(A: T.Buffer((4, 4), "float32")):
for i in T.serial(4):
- x: T.float32 = T.cast(i + 1, "float32") # noqa: F841
+ x: T.let[T.float32] = T.cast(i + 1, "float32") # noqa: F841
for j in T.serial(4):
A[i, j] = T.float32(5.0) * T.cast(i + 1, "float32") +
T.cast(j, "float32")
diff --git
a/tests/python/s_tir/transform/test_s_tir_transform_lower_cross_thread_reduction.py
b/tests/python/s_tir/transform/test_s_tir_transform_lower_cross_thread_reduction.py
index 34e08718f5..5e83bd8f7f 100644
---
a/tests/python/s_tir/transform/test_s_tir_transform_lower_cross_thread_reduction.py
+++
b/tests/python/s_tir/transform/test_s_tir_transform_lower_cross_thread_reduction.py
@@ -1246,8 +1246,10 @@ def argmax_split(
with T.init():
argmax_v0[i] = -1
argmax_v1[i] = T.float32(-3.4028234663852886e38)
- v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k],
argmax_v0[i], idx[i, k])
- v_argmax_v1: T.float32 = T.Select(
+ v_argmax_v0: T.let[T.int32] = T.Select(
+ argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]
+ )
+ v_argmax_v1: T.let[T.float32] = T.Select(
argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
)
argmax_v0[i] = v_argmax_v0
@@ -1278,10 +1280,10 @@ def lowered_argmax_split(
k = T.axis.reduce(128, i1_0 * 32 + i1_1)
T.reads(idx[i, k], val[i, k])
T.writes(in_thread_argmax_v0[0], in_thread_argmax_v1[0])
- v_argmax_v0: T.int32 = T.Select(
+ v_argmax_v0: T.let[T.int32] = T.Select(
in_thread_argmax_v1[0] >= val[i, k],
in_thread_argmax_v0[0], idx[i, k]
)
- v_argmax_v1: T.float32 = T.Select(
+ v_argmax_v1: T.let[T.float32] = T.Select(
in_thread_argmax_v1[0] >= val[i, k],
in_thread_argmax_v1[0], val[i, k]
)
in_thread_argmax_v0[0] = v_argmax_v0
@@ -1338,8 +1340,10 @@ def argmin_split_init_update_reordered(
with T.init():
argmin_v1[i] = T.float32(3.4028234663852886e38)
argmin_v0[i] = -1
- v_argmin_v0: T.int32 = T.Select(argmin_v1[i] <= val[i, k],
argmin_v0[i], idx[i, k])
- v_argmin_v1: T.float32 = T.Select(
+ v_argmin_v0: T.let[T.int32] = T.Select(
+ argmin_v1[i] <= val[i, k], argmin_v0[i], idx[i, k]
+ )
+ v_argmin_v1: T.let[T.float32] = T.Select(
argmin_v1[i] <= val[i, k], argmin_v1[i], val[i, k]
)
argmin_v1[i] = v_argmin_v1
@@ -1370,10 +1374,10 @@ def lowered_argmin_split_init_update_reordered(
k = T.axis.reduce(128, i1_0 * 32 + i1_1)
T.reads(idx[i, k], val[i, k])
T.writes(in_thread_argmin_v0[0], in_thread_argmin_v1[0])
- v_argmin_v0: T.int32 = T.Select(
+ v_argmin_v0: T.let[T.int32] = T.Select(
in_thread_argmin_v1[0] <= val[i, k],
in_thread_argmin_v0[0], idx[i, k]
)
- v_argmin_v1: T.float32 = T.Select(
+ v_argmin_v1: T.let[T.float32] = T.Select(
in_thread_argmin_v1[0] <= val[i, k],
in_thread_argmin_v1[0], val[i, k]
)
in_thread_argmin_v1[0] = v_argmin_v1
@@ -1433,8 +1437,8 @@ def layer_norm_tuple_sum(
with T.init():
data_red_temp_v0[ax0] = T.float32(0)
data_red_temp_v1[ax0] = T.float32(0)
- v_data_red_temp_v0: T.float32 = data_red_temp_v0[ax0] +
data[ax0, k1]
- v_data_red_temp_v1: T.float32 = (
+ v_data_red_temp_v0: T.let[T.float32] =
data_red_temp_v0[ax0] + data[ax0, k1]
+ v_data_red_temp_v1: T.let[T.float32] = (
data_red_temp_v1[ax0] + data[ax0, k1] * data[ax0, k1]
)
data_red_temp_v0[ax0] = v_data_red_temp_v0
@@ -1499,8 +1503,10 @@ def lowered_layer_norm_tuple_sum(
k1 = T.axis.reduce(768, i1_0 * 32 + i1_1)
T.reads(data[ax0, k1])
T.writes(in_thread_data_red_temp_v0[0],
in_thread_data_red_temp_v1[0])
- v_data_red_temp_v0: T.float32 =
in_thread_data_red_temp_v0[0] + data[ax0, k1]
- v_data_red_temp_v1: T.float32 = (
+ v_data_red_temp_v0: T.let[T.float32] = (
+ in_thread_data_red_temp_v0[0] + data[ax0, k1]
+ )
+ v_data_red_temp_v1: T.let[T.float32] = (
in_thread_data_red_temp_v1[0] + data[ax0, k1] *
data[ax0, k1]
)
in_thread_data_red_temp_v0[0] = v_data_red_temp_v0
diff --git a/tests/python/s_tir/transform/test_s_tir_transform_remove_undef.py
b/tests/python/s_tir/transform/test_s_tir_transform_remove_undef.py
index cdc39c443a..a1dc101c74 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_remove_undef.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_remove_undef.py
@@ -84,7 +84,7 @@ def test_remove_let_undef():
class Before:
@T.prim_func(s_tir=True)
def main(A: T.Buffer(1, "int32")):
- val = T.undef(dtype="int32")
+ val: T.let[T.int32] = T.undef(dtype="int32")
A[0] = val
@I.ir_module
@@ -104,7 +104,7 @@ def test_raise_error_for_undef_as_store_indices():
class Before:
@T.prim_func(s_tir=True)
def main(A: T.Buffer(1, "int32")):
- val = T.undef(dtype="int32")
+ val: T.let[T.int32] = T.undef(dtype="int32")
A[val] = 5
with pytest.raises(TVMError):