This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 67b0c6cc5f [Tests][S-TIR] Fix stale MetaSchedule sketch expectations and migrate let binds to T.let (#19729)
67b0c6cc5f is described below

commit 67b0c6cc5f6f2a987637c2474c74c3489008f3be
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Jun 11 07:19:00 2026 -0400

    [Tests][S-TIR] Fix stale MetaSchedule sketch expectations and migrate let binds to T.let (#19729)
    
    Fix the s_tir tests broken or left stale by two upstream changes.
    
    * test_meta_schedule_space_cuda.py (cap, dil, gmm, t2d, nrm, sfm, cbr,
    tbg) and test_meta_schedule_space_cuda_async.py (c2d): #18927 expanded
    DefaultCUDA unroll_max_steps from {0, 16, 64, 512, 1024} to {0, 16, 32,
    64, 128, 256, 512, 1024} without updating the recorded SampleCategorical
    decisions. Remap the indices (2->3, 3->6, 4->7) so each test keeps
    sampling the same unroll value; every sketch was re-verified by
    replaying the trace and structurally comparing against the expected
    module.
    
    * T.let migration: since #19581 the TIRx parser treats `v: T.int32 =
    expr` as a mutable local-scalar buffer instead of an immutable bind,
    which is now spelled `v: T.let[T.int32] = expr` (a Bind node, the same
    form te.create_prim_func emits). Tests whose intent is a bind are
    migrated to the new spelling: reduction combiner temporaries
    (add_rfactor, lower_cross_thread_reduction) and let-dependent passes
    (compact_buffer_region, hoist_expression, remove_undef).
    
    * Also convert reduction temporaries in still-green tests
    (cross_thread_reduction rule, compute_inline, schedule utilities,
    parallel_vectorize_unroll postproc, dlight general reduction, relax
    cuda_graph) so the hand-written workloads match the canonical Bind form
    instead of feeding rules a mutable-scalar body.
---
 .../relax/test_transform_rewrite_cuda_graph.py     |  4 +--
 .../s_tir/dlight/test_gpu_general_reduction.py     | 16 +++++-----
 ...e_postproc_rewrite_parallel_vectorize_unroll.py |  8 ++---
 ...test_meta_schedule_schedule_rule_add_rfactor.py | 28 +++++++++--------
 ...chedule_schedule_rule_cross_thread_reduction.py | 36 ++++++++++++++--------
 .../meta_schedule/test_meta_schedule_space_cuda.py | 18 +++++------
 .../test_meta_schedule_space_cuda_async.py         |  2 +-
 .../schedule/test_tir_schedule_compute_inline.py   |  8 ++---
 .../s_tir/schedule/test_tir_schedule_utilities.py  |  4 +--
 .../test_s_tir_transform_compact_buffer_region.py  | 18 +++++------
 .../test_s_tir_transform_hoist_expression.py       | 10 +++---
 ...s_tir_transform_lower_cross_thread_reduction.py | 30 ++++++++++--------
 .../transform/test_s_tir_transform_remove_undef.py |  4 +--
 13 files changed, 104 insertions(+), 82 deletions(-)

diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py 
b/tests/python/relax/test_transform_rewrite_cuda_graph.py
index 80637edcc0..3897e444bc 100644
--- a/tests/python/relax/test_transform_rewrite_cuda_graph.py
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -475,10 +475,10 @@ def test_capture_fixed_inputs():
                     with T.init():
                         A_red_temp_v0[v_ax0, v_ax1, v_ax2] = T.float32(0)
                         A_red_temp_v1[v_ax0, v_ax1, v_ax2] = T.float32(0)
-                    v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1, 
v_ax2] + T.Cast(
+                    v_A_red_temp_v0: T.let[T.float32] = A_red_temp_v0[v_ax0, 
v_ax1, v_ax2] + T.Cast(
                         "float32", A[v_ax0, v_ax1, v_ax2, v_k3]
                     )
-                    v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1, 
v_ax2] + T.Cast(
+                    v_A_red_temp_v1: T.let[T.float32] = A_red_temp_v1[v_ax0, 
v_ax1, v_ax2] + T.Cast(
                         "float32", A[v_ax0, v_ax1, v_ax2, v_k3]
                     ) * T.Cast("float32", A[v_ax0, v_ax1, v_ax2, v_k3])
                     A_red_temp_v0[v_ax0, v_ax1, v_ax2] = v_A_red_temp_v0
diff --git a/tests/python/s_tir/dlight/test_gpu_general_reduction.py 
b/tests/python/s_tir/dlight/test_gpu_general_reduction.py
index 7022cef9f2..6ddc4671b9 100644
--- a/tests/python/s_tir/dlight/test_gpu_general_reduction.py
+++ b/tests/python/s_tir/dlight/test_gpu_general_reduction.py
@@ -337,8 +337,8 @@ def test_layer_norm():
                     with T.init():
                         A_red_temp_v0[v_ax0, v_ax1] = T.float32(0)
                         A_red_temp_v1[v_ax0, v_ax1] = T.float32(0)
-                    v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + 
lv6[v_ax0, v_ax1, v_k2]
-                    v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + 
lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, v_ax1, v_k2]
+                    v_A_red_temp_v0: T.let[T.float32] = A_red_temp_v0[v_ax0, 
v_ax1] + lv6[v_ax0, v_ax1, v_k2]
+                    v_A_red_temp_v1: T.let[T.float32] = A_red_temp_v1[v_ax0, 
v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, v_ax1, v_k2]
                     A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0
                     A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1
             for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)):
@@ -377,8 +377,8 @@ def test_layer_norm():
                                 with T.init():
                                     A_red_temp_v0_shared[T.int64(0), v0] = 
T.float32(0)
                                     A_red_temp_v1_shared[T.int64(0), v0] = 
T.float32(0)
-                                v_A_red_temp_v0: T.float32 = 
A_red_temp_v0_shared[T.int64(0), v0] + lv6[T.int64(0), v0, v1]
-                                v_A_red_temp_v1: T.float32 = 
A_red_temp_v1_shared[T.int64(0), v0] + lv6[T.int64(0), v0, v1] * 
lv6[T.int64(0), v0, v1]
+                                v_A_red_temp_v0: T.let[T.float32] = 
A_red_temp_v0_shared[T.int64(0), v0] + lv6[T.int64(0), v0, v1]
+                                v_A_red_temp_v1: T.let[T.float32] = 
A_red_temp_v1_shared[T.int64(0), v0] + lv6[T.int64(0), v0, v1] * 
lv6[T.int64(0), v0, v1]
                                 A_red_temp_v0_shared[T.int64(0), v0] = 
v_A_red_temp_v0
                                 A_red_temp_v1_shared[T.int64(0), v0] = 
v_A_red_temp_v1
                 for ax1_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
@@ -481,8 +481,8 @@ def test_group_norm():
                     with T.init():
                         A_red_temp_v0[v_ax0, v_ax1] = T.float32(0)
                         A_red_temp_v1[v_ax0, v_ax1] = T.float32(0)
-                    v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + 
T_reshape_1[v_ax0, v_ax1, v_k2]
-                    v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + 
T_reshape_1[v_ax0, v_ax1, v_k2] * T_reshape_1[v_ax0, v_ax1, v_k2]
+                    v_A_red_temp_v0: T.let[T.float32] = A_red_temp_v0[v_ax0, 
v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2]
+                    v_A_red_temp_v1: T.let[T.float32] = A_red_temp_v1[v_ax0, 
v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2] * T_reshape_1[v_ax0, v_ax1, v_k2]
                     A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0
                     A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1
             for ax0, ax1 in T.grid(32, 64):
@@ -531,8 +531,8 @@ def test_group_norm():
                                 with T.init():
                                     A_red_temp_v0_shared[0, v0] = T.float32(0)
                                     A_red_temp_v1_shared[0, v0] = T.float32(0)
-                                v_A_red_temp_v0: T.float32 = 
A_red_temp_v0_shared[0, v0] + A[0, v0 * 64 + v1]
-                                v_A_red_temp_v1: T.float32 = 
A_red_temp_v1_shared[0, v0] + A[0, v0 * 64 + v1] * A[0, v0 * 64 + v1]
+                                v_A_red_temp_v0: T.let[T.float32] = 
A_red_temp_v0_shared[0, v0] + A[0, v0 * 64 + v1]
+                                v_A_red_temp_v1: T.let[T.float32] = 
A_red_temp_v1_shared[0, v0] + A[0, v0 * 64 + v1] * A[0, v0 * 64 + v1]
                                 A_red_temp_v0_shared[0, v0] = v_A_red_temp_v0
                                 A_red_temp_v1_shared[0, v0] = v_A_red_temp_v1
                 for ax1_1 in T.thread_binding(256, thread="threadIdx.x"):
diff --git 
a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
 
b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
index 84d667f47d..5152e2fa15 100644
--- 
a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
+++ 
b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
@@ -241,8 +241,8 @@ def test_no_unroll_for_spatial_block():
                     with T.init():
                         A_red_temp_v0[v_ax0] = T.float32(0)
                         A_red_temp_v1[v_ax0] = T.float32(0)
-                    v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0] + 
A[v_ax0, v_k1, v_k2, v_k3]
-                    v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0] + 
A[v_ax0, v_k1, v_k2, v_k3] * A[v_ax0, v_k1, v_k2, v_k3]
+                    v_A_red_temp_v0: T.let[T.float32] = A_red_temp_v0[v_ax0] + 
A[v_ax0, v_k1, v_k2, v_k3]
+                    v_A_red_temp_v1: T.let[T.float32] = A_red_temp_v1[v_ax0] + 
A[v_ax0, v_k1, v_k2, v_k3] * A[v_ax0, v_k1, v_k2, v_k3]
                     A_red_temp_v0[v_ax0] = v_A_red_temp_v0
                     A_red_temp_v1[v_ax0] = v_A_red_temp_v1
             for ax0, ax1, ax2, ax3 in T.grid(1, 4, 4, 32):
@@ -267,8 +267,8 @@ def test_no_unroll_for_spatial_block():
                         with T.init():
                             A_red_temp_v0[0] = T.float32(0)
                             A_red_temp_v1[0] = T.float32(0)
-                        v_A_red_temp_v0: T.float32 = A_red_temp_v0[0] + A[0, 
v_k1, v_k2, v_k3]
-                        v_A_red_temp_v1: T.float32 = A_red_temp_v1[0] + A[0, 
v_k1, v_k2, v_k3] * A[0, v_k1, v_k2, v_k3]
+                        v_A_red_temp_v0: T.let[T.float32] = A_red_temp_v0[0] + 
A[0, v_k1, v_k2, v_k3]
+                        v_A_red_temp_v1: T.let[T.float32] = A_red_temp_v1[0] + 
A[0, v_k1, v_k2, v_k3] * A[0, v_k1, v_k2, v_k3]
                         A_red_temp_v0[0] = v_A_red_temp_v0
                         A_red_temp_v1[0] = v_A_red_temp_v1
             for ax0, ax1, ax2, ax3 in T.grid(1, 4, 4, 32):
diff --git 
a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_add_rfactor.py
 
b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_add_rfactor.py
index fc6043526d..550cc35e94 100644
--- 
a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_add_rfactor.py
+++ 
b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_add_rfactor.py
@@ -138,8 +138,10 @@ def test_cpu_argmax():
                 with T.init():
                     argmax_v0[i] = -1
                     argmax_v1[i] = T.min_value("float32")
-                v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], 
argmax_v0[i], idx[i, k])
-                v_argmax_v1: T.float32 = T.Select(
+                v_argmax_v0: T.let[T.int32] = T.Select(
+                    argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]
+                )
+                v_argmax_v1: T.let[T.float32] = T.Select(
                     argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
                 )
                 argmax_v0[i] = v_argmax_v0
@@ -160,8 +162,10 @@ def test_cpu_argmax():
                 with T.init():
                     argmax_v0[i] = -1
                     argmax_v1[i] = T.float32(-3.4028234663852886e38)
-                v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], 
argmax_v0[i], idx[i, k])
-                v_argmax_v1: T.float32 = T.Select(
+                v_argmax_v0: T.let[T.int32] = T.Select(
+                    argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]
+                )
+                v_argmax_v1: T.let[T.float32] = T.Select(
                     argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
                 )
                 argmax_v0[i] = v_argmax_v0
@@ -184,12 +188,12 @@ def test_cpu_argmax():
                 with T.init():
                     argmax_v0_rf[i, vi1_1] = -1
                     argmax_v1_rf[i, vi1_1] = T.float32(-3.4028234663852886e38)
-                v_argmax_v0_rf: T.int32 = T.Select(
+                v_argmax_v0_rf: T.let[T.int32] = T.Select(
                     argmax_v1_rf[i, vi1_1] >= val[i, vi1_0 * 16 + vi1_1],
                     argmax_v0_rf[i, vi1_1],
                     idx[i, vi1_0 * 16 + vi1_1],
                 )
-                v_argmax_v1_rf: T.float32 = T.Select(
+                v_argmax_v1_rf: T.let[T.float32] = T.Select(
                     argmax_v1_rf[i, vi1_1] >= val[i, vi1_0 * 16 + vi1_1],
                     argmax_v1_rf[i, vi1_1],
                     val[i, vi1_0 * 16 + vi1_1],
@@ -205,10 +209,10 @@ def test_cpu_argmax():
                 with T.init():
                     argmax_v0[i] = -1
                     argmax_v1[i] = T.float32(-3.4028234663852886e38)
-                v_argmax_v0: T.int32 = T.Select(
+                v_argmax_v0: T.let[T.int32] = T.Select(
                     argmax_v1[i] >= argmax_v1_rf[i, vi1_1], argmax_v0[i], 
argmax_v0_rf[i, vi1_1]
                 )
-                v_argmax_v1: T.float32 = T.Select(
+                v_argmax_v1: T.let[T.float32] = T.Select(
                     argmax_v1[i] >= argmax_v1_rf[i, vi1_1], argmax_v1[i], 
argmax_v1_rf[i, vi1_1]
                 )
                 argmax_v0[i] = v_argmax_v0
@@ -233,12 +237,12 @@ def test_cpu_argmax():
                 with T.init():
                     argmax_v0_rf[i, vi1_0] = -1
                     argmax_v1_rf[i, vi1_0] = T.float32(-3.4028234663852886e38)
-                v_argmax_v0_rf: T.int32 = T.Select(
+                v_argmax_v0_rf: T.let[T.int32] = T.Select(
                     argmax_v1_rf[i, vi1_0] >= val[i, vi1_0 * 16 + vi1_1],
                     argmax_v0_rf[i, vi1_0],
                     idx[i, vi1_0 * 16 + vi1_1],
                 )
-                v_argmax_v1_rf: T.float32 = T.Select(
+                v_argmax_v1_rf: T.let[T.float32] = T.Select(
                     argmax_v1_rf[i, vi1_0] >= val[i, vi1_0 * 16 + vi1_1],
                     argmax_v1_rf[i, vi1_0],
                     val[i, vi1_0 * 16 + vi1_1],
@@ -254,10 +258,10 @@ def test_cpu_argmax():
                 with T.init():
                     argmax_v0[i] = -1
                     argmax_v1[i] = T.float32(-3.4028234663852886e38)
-                v_argmax_v0: T.int32 = T.Select(
+                v_argmax_v0: T.let[T.int32] = T.Select(
                     argmax_v1[i] >= argmax_v1_rf[i, vi1_0], argmax_v0[i], 
argmax_v0_rf[i, vi1_0]
                 )
-                v_argmax_v1: T.float32 = T.Select(
+                v_argmax_v1: T.let[T.float32] = T.Select(
                     argmax_v1[i] >= argmax_v1_rf[i, vi1_0], argmax_v1[i], 
argmax_v1_rf[i, vi1_0]
                 )
                 argmax_v0[i] = v_argmax_v0
diff --git 
a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_cross_thread_reduction.py
 
b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_cross_thread_reduction.py
index eaecaa0fb5..0c6cc9bb24 100644
--- 
a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_cross_thread_reduction.py
+++ 
b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_cross_thread_reduction.py
@@ -582,8 +582,12 @@ def argmax(
             with T.init():
                 argmax_v0[i] = -1
                 argmax_v1[i] = T.min_value("float32")
-            v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], 
argmax_v0[i], idx[i, k])
-            v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], 
argmax_v1[i], val[i, k])
+            v_argmax_v0: T.let[T.int32] = T.Select(
+                argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]
+            )
+            v_argmax_v1: T.let[T.float32] = T.Select(
+                argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
+            )
             argmax_v0[i] = v_argmax_v0
             argmax_v1[i] = v_argmax_v1
 
@@ -604,8 +608,12 @@ def argmax_32(
             with T.init():
                 argmax_v0[i] = -1
                 argmax_v1[i] = T.min_value("float32")
-            v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], 
argmax_v0[i], idx[i, k])
-            v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], 
argmax_v1[i], val[i, k])
+            v_argmax_v0: T.let[T.int32] = T.Select(
+                argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]
+            )
+            v_argmax_v1: T.let[T.float32] = T.Select(
+                argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
+            )
             argmax_v0[i] = v_argmax_v0
             argmax_v1[i] = v_argmax_v1
 
@@ -628,8 +636,10 @@ def test_gpu_argmax():
                 with T.init():
                     argmax_v0[i] = -1
                     argmax_v1[i] = T.float32(-3.4028234663852886e38)
-                v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], 
argmax_v0[i], idx[i, k])
-                v_argmax_v1: T.float32 = T.Select(
+                v_argmax_v0: T.let[T.int32] = T.Select(
+                    argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]
+                )
+                v_argmax_v1: T.let[T.float32] = T.Select(
                     argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
                 )
                 argmax_v0[i] = v_argmax_v0
@@ -654,10 +664,10 @@ def test_gpu_argmax():
                     with T.init():
                         argmax_v0[i] = -1
                         argmax_v1[i] = T.float32(-3.4028234663852886e38)
-                    v_argmax_v0: T.int32 = T.Select(
+                    v_argmax_v0: T.let[T.int32] = T.Select(
                         argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]
                     )
-                    v_argmax_v1: T.float32 = T.Select(
+                    v_argmax_v1: T.let[T.float32] = T.Select(
                         argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
                     )
                     argmax_v0[i] = v_argmax_v0
@@ -701,8 +711,10 @@ def test_gpu_argmax_32():
                 with T.init():
                     argmax_v0[i] = -1
                     argmax_v1[i] = T.float32(-3.4028234663852886e38)
-                v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], 
argmax_v0[i], idx[i, k])
-                v_argmax_v1: T.float32 = T.Select(
+                v_argmax_v0: T.let[T.int32] = T.Select(
+                    argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]
+                )
+                v_argmax_v1: T.let[T.float32] = T.Select(
                     argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
                 )
                 argmax_v0[i] = v_argmax_v0
@@ -728,10 +740,10 @@ def test_gpu_argmax_32():
                     with T.init():
                         argmax_v0[i] = -1
                         argmax_v1[i] = T.float32(-3.4028234663852886e38)
-                    v_argmax_v0: T.int32 = T.Select(
+                    v_argmax_v0: T.let[T.int32] = T.Select(
                         argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]
                     )
-                    v_argmax_v1: T.float32 = T.Select(
+                    v_argmax_v1: T.let[T.float32] = T.Select(
                         argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
                     )
                     argmax_v0[i] = v_argmax_v0
diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py 
b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py
index ba9ac778a5..d748d074ed 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py
@@ -375,7 +375,7 @@ def test_cuda_cap():
         ("SamplePerfectTile", [8, 4, 1]),
         ("SampleCategorical", 1),
         ("SampleCategorical", 3),
-        ("SampleCategorical", 2),
+        ("SampleCategorical", 3),
     ]
     mod = create_te_workload("CAP", 0)
     actual = _design_space(mod)
@@ -537,7 +537,7 @@ def test_cuda_dil():
         ("SamplePerfectTile", [3, 1, 1]),
         ("SampleCategorical", 1),
         ("SampleCategorical", 3),
-        ("SampleCategorical", 3),
+        ("SampleCategorical", 6),
     ]
     mod = create_te_workload("DIL", 0)
     actual = _design_space(mod)
@@ -611,7 +611,7 @@ def test_cuda_gmm():
         ("SamplePerfectTile", [1, 32, 4]),
         ("SampleCategorical", 1),
         ("SampleCategorical", 0),
-        ("SampleCategorical", 4),
+        ("SampleCategorical", 7),
     ]
     mod = create_te_workload("GMM", 0)
     actual = _design_space(mod)
@@ -776,7 +776,7 @@ def test_cuda_t2d():
         ("SamplePerfectTile", [16, 4, 8]),
         ("SampleCategorical", 1),
         ("SampleCategorical", 3),
-        ("SampleCategorical", 2),
+        ("SampleCategorical", 3),
     ]
     mod = create_te_workload("T2D", 0)
     actual = _design_space(mod)
@@ -846,11 +846,11 @@ def test_cuda_nrm():
                         D[v_b] = T.sqrt(C_shared[v_b])
     # fmt: on
     decision_0 = [
-        ("SampleCategorical", 3),
+        ("SampleCategorical", 6),
     ]
     decision_1 = [
         ("SampleCategorical", 5),
-        ("SampleCategorical", 4),
+        ("SampleCategorical", 7),
     ]
     mod = create_te_workload("NRM", 0)
     actual = _design_space(mod)
@@ -1043,7 +1043,7 @@ def test_cuda_sfm():
     ]
     decision_2 = [
         ("SampleCategorical", 7),
-        ("SampleCategorical", 3),
+        ("SampleCategorical", 6),
         ("SampleCategorical", 0),
     ]
     decision_3 = [
@@ -1132,7 +1132,7 @@ def test_cuda_cbr():
         ("SamplePerfectTile", [3, 1, 1]),
         ("SampleCategorical", 0),
         ("SampleCategorical", 0),
-        ("SampleCategorical", 3),
+        ("SampleCategorical", 6),
     ]
     mod = create_te_workload("CBR", 0)
     actual = _design_space(mod)
@@ -1211,7 +1211,7 @@ def test_cuda_tbg():
         ("SamplePerfectTile", [8, 4, 2]),
         ("SampleCategorical", 2),
         ("SampleCategorical", 3),
-        ("SampleCategorical", 4),
+        ("SampleCategorical", 7),
     ]
     mod = create_te_workload("TBG", 0)
     actual = _design_space(mod)
diff --git 
a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda_async.py 
b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda_async.py
index 4c44feae91..1907502b43 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda_async.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda_async.py
@@ -178,7 +178,7 @@ def test_cuda_c2d():
         ("SamplePerfectTile", [3, 1, 1]),
         ("SampleCategorical", 3),
         ("SampleCategorical", 2),
-        ("SampleCategorical", 4),
+        ("SampleCategorical", 7),
     ]
 
     mod = create_te_workload("C2D", 0)
diff --git a/tests/python/s_tir/schedule/test_tir_schedule_compute_inline.py 
b/tests/python/s_tir/schedule/test_tir_schedule_compute_inline.py
index df0e4963b9..d05d67b4a8 100644
--- a/tests/python/s_tir/schedule/test_tir_schedule_compute_inline.py
+++ b/tests/python/s_tir/schedule/test_tir_schedule_compute_inline.py
@@ -1424,8 +1424,8 @@ def test_reverse_compute_inline_layer_norm():
                         with T.init():
                             A_red_temp_v0_shared[v_ax0, v_ax1] = T.float32(0)
                             A_red_temp_v1_shared[v_ax0, v_ax1] = T.float32(0)
-                        v_A_red_temp_v0: T.float32 = 
A_red_temp_v0_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2]
-                        v_A_red_temp_v1: T.float32 = 
A_red_temp_v1_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, 
v_ax1, v_k2]
+                        v_A_red_temp_v0: T.let[T.float32] = 
A_red_temp_v0_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2]
+                        v_A_red_temp_v1: T.let[T.float32] = 
A_red_temp_v1_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, 
v_ax1, v_k2]
                         A_red_temp_v0_shared[v_ax0, v_ax1] = v_A_red_temp_v0
                         A_red_temp_v1_shared[v_ax0, v_ax1] = v_A_red_temp_v1
             for ax2_0 in range(T.int64(10)):
@@ -1465,8 +1465,8 @@ def test_reverse_compute_inline_layer_norm():
                         with T.init():
                             A_red_temp_v0_shared[v_ax0, v_ax1] = T.float32(0)
                             A_red_temp_v1_shared[v_ax0, v_ax1] = T.float32(0)
-                        v_A_red_temp_v0: T.float32 = 
A_red_temp_v0_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2]
-                        v_A_red_temp_v1: T.float32 = 
A_red_temp_v1_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, 
v_ax1, v_k2]
+                        v_A_red_temp_v0: T.let[T.float32] = 
A_red_temp_v0_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2]
+                        v_A_red_temp_v1: T.let[T.float32] = 
A_red_temp_v1_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, 
v_ax1, v_k2]
                         A_red_temp_v0_shared[v_ax0, v_ax1] = v_A_red_temp_v0
                         A_red_temp_v1_shared[v_ax0, v_ax1] = v_A_red_temp_v1
             for ax2_0 in range(T.int64(10)):
diff --git a/tests/python/s_tir/schedule/test_tir_schedule_utilities.py 
b/tests/python/s_tir/schedule/test_tir_schedule_utilities.py
index dcd3b7b5a2..a2aa9c699d 100644
--- a/tests/python/s_tir/schedule/test_tir_schedule_utilities.py
+++ b/tests/python/s_tir/schedule/test_tir_schedule_utilities.py
@@ -147,8 +147,8 @@ def tuple_reduction(data: T.Buffer((4, 32), "float32"), 
T_add: T.Buffer((4,), "f
                 with T.init():
                     data_red_temp_v0[ax0] = T.float32(0)
                     data_red_temp_v1[ax0] = T.float32(0)
-                v_data_red_temp_v0: T.float32 = data_red_temp_v0[ax0] + 
data[ax0, k1]
-                v_data_red_temp_v1: T.float32 = (
+                v_data_red_temp_v0: T.let[T.float32] = data_red_temp_v0[ax0] + 
data[ax0, k1]
+                v_data_red_temp_v1: T.let[T.float32] = (
                     data_red_temp_v1[ax0] + data[ax0, k1] * data[ax0, k1]
                 )
                 data_red_temp_v0[ax0] = v_data_red_temp_v0
diff --git 
a/tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py 
b/tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py
index e398876f35..3185147f8e 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py
@@ -755,8 +755,8 @@ class TestLetBinding(BaseCompactTest):
             for rii, rjj in T.grid(8, 8):
                 C[rii, rjj] = T.float32(0)
             for riijj in T.serial(8 * 8):
-                rii: T.int32 = riijj // 8
-                rjj: T.int32 = riijj % 8
+                rii: T.let[T.int32] = riijj // 8
+                rjj: T.let[T.int32] = riijj % 8
                 C[rii, rjj] += A[rk, rii] * B[rk, rjj]
 
     expected = before
@@ -766,13 +766,13 @@ class TestNonIndexLetBinding(BaseCompactTest):
     @T.prim_func(s_tir=True)
     def before():
         A = T.sblock_alloc_buffer((64), "float32")
-        x1 = T.call_extern("get", dtype="float16")
-        x2 = T.call_extern("get", dtype="float32")
-        x3 = T.call_extern("get", dtype="float64")
-        x4 = T.call_extern("get", dtype="uint8")
-        x5 = T.call_extern("get", dtype="int32x16")
-        x6 = T.call_extern("get", dtype="handle")
-        x7 = T.call_extern("get", dtype="")
+        x1: T.let[T.float16] = T.call_extern("get", dtype="float16")
+        x2: T.let[T.float32] = T.call_extern("get", dtype="float32")
+        x3: T.let[T.float64] = T.call_extern("get", dtype="float64")
+        x4: T.let[T.uint8] = T.call_extern("get", dtype="uint8")
+        x5: T.let[T.int32x16] = T.call_extern("get", dtype="int32x16")
+        x6: T.let[T.handle] = T.call_extern("get", dtype="handle")
+        x7: T.let = T.call_extern("get", dtype="")
         for rk in range(64):
             A[rk] = T.call_extern("load_ptr", x1, x2, x3, x4, x5, x6, x7, 
dtype="float32")
 
diff --git 
a/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py 
b/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py
index b4c52d2831..da86edb310 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py
@@ -221,14 +221,14 @@ def test_hoist_with_let():
     def before(A: T.Buffer((4, 4), "float32")):
         for i in T.serial(4):
             for j in T.serial(4):
-                condition = i < 3
+                condition: T.let[T.bool] = i < 3
                 if condition:
                     A[i, j] = 0.0
 
     @T.prim_func(private=True, s_tir=True)
     def expected(A: T.Buffer((4, 4), "float32")):
         for i in T.serial(4):
-            condition: T.bool = i < 3  # noqa: F841
+            condition: T.let[T.bool] = i < 3  # noqa: F841
             if i < 3:
                 for j in T.serial(4):
                     A[i, j] = T.float32(0.0)
@@ -250,14 +250,14 @@ def test_hoist_disable_let():
     def before(A: T.Buffer((4, 4), "float32")):
         for i in T.serial(4):
             for j in T.serial(4):
-                condition = i < 3
+                condition: T.let[T.bool] = i < 3
                 if condition:
                     A[i, j] = 0.0
 
     @T.prim_func(private=True, s_tir=True)
     def expected(A: T.Buffer((4, 4), "float32")):
         for i, j in T.grid(4, 4):
-            condition: T.bool = i < 3  # noqa: F841
+            condition: T.let[T.bool] = i < 3  # noqa: F841
             if i < 3:
                 A[i, j] = T.float32(0.0)
 
@@ -519,7 +519,7 @@ def test_hoist_let_expr():
     @T.prim_func(private=True, s_tir=True)
     def expected(A: T.Buffer((4, 4), "float32")):
         for i in T.serial(4):
-            x: T.float32 = T.cast(i + 1, "float32")  # noqa: F841
+            x: T.let[T.float32] = T.cast(i + 1, "float32")  # noqa: F841
             for j in T.serial(4):
                 A[i, j] = T.float32(5.0) * T.cast(i + 1, "float32") + 
T.cast(j, "float32")
 
diff --git 
a/tests/python/s_tir/transform/test_s_tir_transform_lower_cross_thread_reduction.py
 
b/tests/python/s_tir/transform/test_s_tir_transform_lower_cross_thread_reduction.py
index 34e08718f5..5e83bd8f7f 100644
--- 
a/tests/python/s_tir/transform/test_s_tir_transform_lower_cross_thread_reduction.py
+++ 
b/tests/python/s_tir/transform/test_s_tir_transform_lower_cross_thread_reduction.py
@@ -1246,8 +1246,10 @@ def argmax_split(
                 with T.init():
                     argmax_v0[i] = -1
                     argmax_v1[i] = T.float32(-3.4028234663852886e38)
-                v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], 
argmax_v0[i], idx[i, k])
-                v_argmax_v1: T.float32 = T.Select(
+                v_argmax_v0: T.let[T.int32] = T.Select(
+                    argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]
+                )
+                v_argmax_v1: T.let[T.float32] = T.Select(
                     argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
                 )
                 argmax_v0[i] = v_argmax_v0
@@ -1278,10 +1280,10 @@ def lowered_argmax_split(
                     k = T.axis.reduce(128, i1_0 * 32 + i1_1)
                     T.reads(idx[i, k], val[i, k])
                     T.writes(in_thread_argmax_v0[0], in_thread_argmax_v1[0])
-                    v_argmax_v0: T.int32 = T.Select(
+                    v_argmax_v0: T.let[T.int32] = T.Select(
                         in_thread_argmax_v1[0] >= val[i, k], 
in_thread_argmax_v0[0], idx[i, k]
                     )
-                    v_argmax_v1: T.float32 = T.Select(
+                    v_argmax_v1: T.let[T.float32] = T.Select(
                         in_thread_argmax_v1[0] >= val[i, k], 
in_thread_argmax_v1[0], val[i, k]
                     )
                     in_thread_argmax_v0[0] = v_argmax_v0
@@ -1338,8 +1340,10 @@ def argmin_split_init_update_reordered(
                 with T.init():
                     argmin_v1[i] = T.float32(3.4028234663852886e38)
                     argmin_v0[i] = -1
-                v_argmin_v0: T.int32 = T.Select(argmin_v1[i] <= val[i, k], 
argmin_v0[i], idx[i, k])
-                v_argmin_v1: T.float32 = T.Select(
+                v_argmin_v0: T.let[T.int32] = T.Select(
+                    argmin_v1[i] <= val[i, k], argmin_v0[i], idx[i, k]
+                )
+                v_argmin_v1: T.let[T.float32] = T.Select(
                     argmin_v1[i] <= val[i, k], argmin_v1[i], val[i, k]
                 )
                 argmin_v1[i] = v_argmin_v1
@@ -1370,10 +1374,10 @@ def lowered_argmin_split_init_update_reordered(
                     k = T.axis.reduce(128, i1_0 * 32 + i1_1)
                     T.reads(idx[i, k], val[i, k])
                     T.writes(in_thread_argmin_v0[0], in_thread_argmin_v1[0])
-                    v_argmin_v0: T.int32 = T.Select(
+                    v_argmin_v0: T.let[T.int32] = T.Select(
                         in_thread_argmin_v1[0] <= val[i, k], 
in_thread_argmin_v0[0], idx[i, k]
                     )
-                    v_argmin_v1: T.float32 = T.Select(
+                    v_argmin_v1: T.let[T.float32] = T.Select(
                         in_thread_argmin_v1[0] <= val[i, k], 
in_thread_argmin_v1[0], val[i, k]
                     )
                     in_thread_argmin_v1[0] = v_argmin_v1
@@ -1433,8 +1437,8 @@ def layer_norm_tuple_sum(
                     with T.init():
                         data_red_temp_v0[ax0] = T.float32(0)
                         data_red_temp_v1[ax0] = T.float32(0)
-                    v_data_red_temp_v0: T.float32 = data_red_temp_v0[ax0] + 
data[ax0, k1]
-                    v_data_red_temp_v1: T.float32 = (
+                    v_data_red_temp_v0: T.let[T.float32] = 
data_red_temp_v0[ax0] + data[ax0, k1]
+                    v_data_red_temp_v1: T.let[T.float32] = (
                         data_red_temp_v1[ax0] + data[ax0, k1] * data[ax0, k1]
                     )
                     data_red_temp_v0[ax0] = v_data_red_temp_v0
@@ -1499,8 +1503,10 @@ def lowered_layer_norm_tuple_sum(
                     k1 = T.axis.reduce(768, i1_0 * 32 + i1_1)
                     T.reads(data[ax0, k1])
                     T.writes(in_thread_data_red_temp_v0[0], 
in_thread_data_red_temp_v1[0])
-                    v_data_red_temp_v0: T.float32 = 
in_thread_data_red_temp_v0[0] + data[ax0, k1]
-                    v_data_red_temp_v1: T.float32 = (
+                    v_data_red_temp_v0: T.let[T.float32] = (
+                        in_thread_data_red_temp_v0[0] + data[ax0, k1]
+                    )
+                    v_data_red_temp_v1: T.let[T.float32] = (
                         in_thread_data_red_temp_v1[0] + data[ax0, k1] * 
data[ax0, k1]
                     )
                     in_thread_data_red_temp_v0[0] = v_data_red_temp_v0
diff --git a/tests/python/s_tir/transform/test_s_tir_transform_remove_undef.py b/tests/python/s_tir/transform/test_s_tir_transform_remove_undef.py
index cdc39c443a..a1dc101c74 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_remove_undef.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_remove_undef.py
@@ -84,7 +84,7 @@ def test_remove_let_undef():
     class Before:
         @T.prim_func(s_tir=True)
         def main(A: T.Buffer(1, "int32")):
-            val = T.undef(dtype="int32")
+            val: T.let[T.int32] = T.undef(dtype="int32")
             A[0] = val
 
     @I.ir_module
@@ -104,7 +104,7 @@ def test_raise_error_for_undef_as_store_indices():
     class Before:
         @T.prim_func(s_tir=True)
         def main(A: T.Buffer(1, "int32")):
-            val = T.undef(dtype="int32")
+            val: T.let[T.int32] = T.undef(dtype="int32")
             A[val] = 5
 
     with pytest.raises(TVMError):


Reply via email to