This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c9a77d6712 [S-TIR][Tests] Fix transform test failures after TIRx
bringup (#19735)
c9a77d6712 is described below
commit c9a77d671232a192df5241a8ffd5f15be8274224
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Jun 11 17:34:05 2026 -0400
[S-TIR][Tests] Fix transform test failures after TIRx bringup (#19735)
This PR fixes 11 test failures in `tests/python/s_tir/transform/` that were
introduced as side effects of the TIRx bringup (#19581 / 859498dc01). The
fixes are split into three independent commits.
### 1. LowerOpaqueBlock: update expected IR for buffer metadata annotations
`LowerOpaqueBlock` now emits `buffer_allocated_addr` and
`buffer_data_alignment` annotations on lowered allocations (intentional
in #19581: the annotations are consumed downstream by `codegen_cuda.cc`
/ `codegen_trn.cc`; the alignment value 64 comes from
`kAllocAlignment`). The tests' expected IR predates this, so
`assert_structural_equal` failed on the missing annotations.
Fix: update the expected IR in
`test_s_tir_transform_lower_opaque_block.py` to carry the annotations
(`T.decl_buffer(...)` → `T.alloc_buffer(..., annotations={...})`). Fixes
6 tests.
### 2. DefaultGPUSchedule: parse scalar-block test in s_tir mode
#19581 added a well-formedness rule rejecting `SBlockRealize` in
`tirx=True` mode, which is correct — sblocks are s_tir-mode constructs.
The hand-written `Before`/`Expected` modules in
`test_scalar_block_no_loops` were the only ones in the file still using
plain `T.prim_func`, so they failed at parse time before the pass under
test even ran.
Fix: parse both modules with `T.prim_func(s_tir=True)`, consistent with
every other test in the file. Fixes 1 test.
### 3. InjectPermutedLayout: match legacy PTX intrinsics by canonical name
#19581 registers device intrinsics under two Op identities: a flat
builtin name (returned by `builtin::xxx()` in C++) and a canonical
dotted name (e.g. `tirx.ptx.ldmatrix_legacy`, produced when TVMScript /
tensor intrinsics are parsed). `InjectPermutedLayout` only compared with
`same_as(builtin::...)`, so it silently skipped rewriting the swizzled
shared-memory offsets of parsed legacy-form calls, leaving the expected
swizzle index expressions unmatched.
Fix: match `ptx_ldmatrix_legacy` / `mma_store_legacy` by both the
builtin Op and the canonical name via an `IsOp` helper, following the
existing pattern in `lower_warp_memory.cc` and `codegen_cuda.cc`. Only
the legacy intrinsic forms fold shared-memory access into
`tvm_access_ptr` + offset; non-legacy forms address shared memory
through `BufferLoad` and are already handled by the BufferLoad visitor,
so the unreachable `InternalError` throw is replaced by a pass-through.
(`mma_store_legacy` has no dotted alias, hence the asymmetric name
strings.) Fixes 4 tests.
---
src/s_tir/transform/inject_permuted_layout.cc | 28 ++++++++++-----
.../test_s_tir_transform_default_gpu_schedule.py | 4 +--
.../test_s_tir_transform_lower_opaque_block.py | 41 ++++++++++++++++++----
3 files changed, 55 insertions(+), 18 deletions(-)
diff --git a/src/s_tir/transform/inject_permuted_layout.cc
b/src/s_tir/transform/inject_permuted_layout.cc
index fe90f38cec..74e843a6e5 100644
--- a/src/s_tir/transform/inject_permuted_layout.cc
+++ b/src/s_tir/transform/inject_permuted_layout.cc
@@ -246,6 +246,17 @@ class PermutedLayoutInjector : private
IRMutatorWithAnalyzer {
return access_ptr_call;
}
+ // Device intrinsics are registered under both a flat name (the builtin Op)
+ // and a canonical dotted name (emitted by TVMScript and the tensor
+ // intrinsics), so compare against both.
+ static bool IsOp(const Call& call, const Op& compat_op, const char*
canonical_name) {
+ if (call->op.same_as(compat_op)) {
+ return true;
+ }
+ const auto* op_node = call->op.as<OpNode>();
+ return op_node != nullptr && op_node->name == canonical_name;
+ }
+
PrimExpr VisitExpr_(const CallNode* op) final {
// Rewrite from/to shared or shared.dyn to/from local
auto call = Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
@@ -254,12 +265,12 @@ class PermutedLayoutInjector : private
IRMutatorWithAnalyzer {
return call;
}
- if (!call->op.same_as(builtin::ptx_ldmatrix()) &&
!call->op.same_as(builtin::mma_store())) {
- return call;
- }
-
- if (call->op.same_as(builtin::ptx_ldmatrix())) {
- // form: T.ptx_ldmatrix(..., smem_ptr, smem_offset)
+ // Only the legacy intrinsic forms fold the shared memory access into a
+ // tvm_access_ptr + offset, which must be rewritten here. The non-legacy
+ // forms address shared memory through BufferLoad (e.g. via address_of),
+ // which is already handled by the BufferLoad visitor above.
+ if (IsOp(call, builtin::ptx_ldmatrix_legacy(),
"tirx.ptx.ldmatrix_legacy")) {
+ // form: T.ptx.ldmatrix_legacy(..., smem_ptr, smem_offset)
// smem_ptr: T.tvm_access_ptr(ptype, data, offset, extent, rw_mask)
auto access_ptr = call->args[5];
PrimExpr smem_offset = call->args[6];
@@ -268,7 +279,7 @@ class PermutedLayoutInjector : private
IRMutatorWithAnalyzer {
new_call->args.Set(5, new_access_ptr);
new_call->args.Set(6, IntImm(smem_offset->dtype, 0));
return call;
- } else if (call->op.same_as(builtin::mma_store())) {
+ } else if (IsOp(call, builtin::mma_store_legacy(),
"tirx.mma_store_legacy")) {
// TODO(yixin): mma_store is not fully tested yet
// because we will directly store result to Buffer instead of calling
mma_store now
auto access_ptr = call->args[2];
@@ -276,9 +287,8 @@ class PermutedLayoutInjector : private
IRMutatorWithAnalyzer {
auto new_call = call.CopyOnWrite();
new_call->args.Set(2, new_access_ptr);
return call;
- } else {
- TVM_FFI_THROW(InternalError) << "Invalid call node: " << call;
}
+ return call;
}
static constexpr size_t VECTORIZE_FACTOR = 8;
diff --git
a/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
b/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
index 891ba3f208..875fe18182 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
@@ -575,14 +575,14 @@ def test_scalar_block_no_loops():
# fmt: off
@tvm.script.ir_module
class Before:
- @T.prim_func
+ @T.prim_func(s_tir=True)
def scalar_add(a: T.Buffer((), "float32"), b: T.Buffer((), "float32"),
c: T.Buffer((), "float32")):
with T.sblock("scalar_add"):
c[()] = a[()] + b[()]
@tvm.script.ir_module
class Expected:
- @T.prim_func
+ @T.prim_func(s_tir=True)
def scalar_add(a: T.Buffer((), "float32"), b: T.Buffer((), "float32"),
c: T.Buffer((), "float32")):
T.func_attr({"tirx.is_scheduled": True})
# with T.sblock("root"):
diff --git
a/tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py
b/tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py
index 62ad915a57..441074128e 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py
@@ -56,7 +56,11 @@ def transformed_elementwise_func(a: T.handle, c: T.handle)
-> None:
A = T.match_buffer(a, (16, 16), "float32")
C = T.match_buffer(c, (16, 16), "float32")
for i in T.serial(0, 16):
- B_new = T.decl_buffer(shape=[1, 16], dtype="float32")
+ B_new = T.alloc_buffer(
+ [1, 16],
+ "float32",
+ annotations={"buffer_allocated_addr": [], "buffer_data_alignment":
64},
+ )
for j in T.serial(0, 16):
B_new[0, j] = A[i, j] + 1.0
for j in T.serial(0, 16):
@@ -98,7 +102,12 @@ def transformed_gpu_func(a: T.handle, c: T.handle) -> None:
T.launch_thread(i0, 4)
T.launch_thread(i1, 2)
T.launch_thread(i2, 2)
- B = T.decl_buffer(shape=[1, 16], dtype="float32", scope="local")
+ B = T.alloc_buffer(
+ [1, 16],
+ "float32",
+ scope="local",
+ annotations={"buffer_allocated_addr": [], "buffer_data_alignment": 64},
+ )
for j in range(0, 16):
B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0
for j in range(0, 16):
@@ -133,7 +142,11 @@ def transformed_symbolic_func(a: T.handle, c: T.handle, n:
T.int32, m: T.int32)
C = T.match_buffer(c, (n, m), "float32")
for i in range(0, n):
- B = T.decl_buffer(shape=[m], dtype="float32")
+ B = T.alloc_buffer(
+ [m],
+ "float32",
+ annotations={"buffer_allocated_addr": [], "buffer_data_alignment":
64},
+ )
for j in range(0, m):
B[j] = A[i, j] + 1.0
for j in range(0, m):
@@ -206,8 +219,16 @@ def transformed_multi_alloc_func(a: T.handle, d: T.handle)
-> None:
D = T.match_buffer(d, (32), "float32")
for i in range(0, 32):
- B = T.decl_buffer(shape=(32,), dtype="float32")
- C = T.decl_buffer(shape=(32,), dtype="float32")
+ B = T.alloc_buffer(
+ (32,),
+ "float32",
+ annotations={"buffer_allocated_addr": [], "buffer_data_alignment":
64},
+ )
+ C = T.alloc_buffer(
+ (32,),
+ "float32",
+ annotations={"buffer_allocated_addr": [], "buffer_data_alignment":
64},
+ )
B[i] = A[i] + 1.0
C[i] = A[i] + B[i]
D[i] = C[i] * 2.0
@@ -242,7 +263,12 @@ def transformed_strided_buffer_func(
) -> None:
# body
for i0 in T.serial(4):
- B = T.decl_buffer(shape=[4, 16], dtype="float32", strides=[17, 1])
+ B = T.alloc_buffer(
+ [4, 16],
+ "float32",
+ strides=[17, 1],
+ annotations={"buffer_allocated_addr": [], "buffer_data_alignment":
64},
+ )
for i1, j in T.grid(4, 16):
B[i1, j] = A[i0 * 4 + i1, j] + T.float32(1)
for i1, j in T.grid(4, 16):
@@ -275,10 +301,11 @@ def transformed_symbolic_strided_buffer_func(a: T.handle):
n = T.int32()
A = T.match_buffer(a, (1, n, 10240))
for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160):
- A_pad_shared_dyn = T.decl_buffer(
+ A_pad_shared_dyn = T.alloc_buffer(
(1, T.min((n + 63) // 64 * 64, 96), 64),
strides=(72 * T.min((n + 63) // 64 * 64, 96), 72, 1),
scope="shared.dyn",
+ annotations={"buffer_allocated_addr": [], "buffer_data_alignment":
64},
)
for ax0, ax1 in T.grid(96, 64):
if i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64: