This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2a16d98b9e [S-TIR] Fix software pipeline offsets for legacy MMA intrinsics (#19742)
2a16d98b9e is described below

commit 2a16d98b9ec1899f11c35a70c779285812d360c2
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Jun 12 00:38:43 2026 -0400

    [S-TIR] Fix software pipeline offsets for legacy MMA intrinsics (#19742)
    
    This PR fixes `InjectSoftwarePipeline` to rewrite opaque buffer offsets
    for legacy PTX MMA intrinsics, including `ptx_ldmatrix_legacy`,
    `ptx_mma_legacy`, `mma_store_legacy`, and `mma_fill_legacy`.
    
    ### Failure
    
    `test_async_nested_pipeline_mma_gemm_ideal_annotation` failed with a
    real numerical mismatch during the final GEMM result check:
    
    ```text
    Mismatched elements: 6744298 / 16777216 (40.2%)
    Max absolute difference: 6.706421
    Max relative difference: 0.00653978
    ```
    
    A diagnostic run with a fixed seed reproduced the same class of error:
    
    ```text
    nested_full bad 6826830 / 16777216
    max_abs 7.0405273
    mean_abs 0.9874781
    ```
    
    The adjacent simple pipeline test passed, and disabling async copy did
    not change the mismatch, so this was not a cp.async lowering issue,
    tolerance noise, or GPU flakiness.
    
    ### Root Cause
    
    `InjectSoftwarePipeline` may add a leading version dimension to pipeline
    buffers. Normal buffer loads/stores and newer opaque PTX intrinsics
    already had their offsets rewritten to include the pipeline version
    slot.
    
    However, the legacy MMA intrinsics were not covered. In this test, the
    warp buffers became multi-versioned, but the legacy `ldmatrix`/`mma`
    offsets still pointed to the original slot. As a result, the second
    `ldmatrix` stage overwrote the first warp fragment, and both MMA stages
    read the same fragment. That skipped one K fragment and duplicated
    another, producing the numerical mismatch above.
---
 src/s_tir/transform/inject_software_pipeline.cc | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/src/s_tir/transform/inject_software_pipeline.cc 
b/src/s_tir/transform/inject_software_pipeline.cc
index d9da151f39..f5190bddbc 100644
--- a/src/s_tir/transform/inject_software_pipeline.cc
+++ b/src/s_tir/transform/inject_software_pipeline.cc
@@ -28,6 +28,7 @@
 #include <tvm/s_tir/transform.h>
 #include <tvm/target/target.h>
 #include <tvm/tirx/builtin.h>
+#include <tvm/tirx/op.h>
 
 #include <map>
 #include <unordered_set>
@@ -42,6 +43,14 @@ using namespace tvm::tirx;
 
 namespace software_pipeline {
 
+static bool IsOp(const Call& call, const Op& compat_op, const char* 
canonical_name) {
+  if (call->op.same_as(compat_op)) {
+    return true;
+  }
+  const auto* op_node = call->op.as<OpNode>();
+  return op_node != nullptr && op_node->name == canonical_name;
+}
+
 /*!
  * \brief Create a block and infer the access region with the given body.
  *
@@ -110,8 +119,8 @@ class PipelineOpaqueAccessRewriter {
     static const auto& store_matrix_sync = builtin::tvm_store_matrix_sync();
     static const auto& mma_sync = builtin::tvm_mma_sync();
     static const auto& access_ptr = builtin::tvm_access_ptr();
-    static const auto& ptx_ldmatrix = builtin::ptx_ldmatrix();
-    static const auto& ptx_mma = builtin::ptx_mma();
+    static const auto& ptx_ldmatrix_legacy = builtin::ptx_ldmatrix_legacy();
+    static const auto& ptx_mma_legacy = builtin::ptx_mma_legacy();
     if (call->op.same_as(load_matrix_sync) || 
call->op.same_as(store_matrix_sync)) {
       const Buffer& buffer = 
buffer_data_to_buffer_.at(Downcast<Var>(call->args[0]));
       auto it = buffer_remap_.find(buffer);
@@ -136,9 +145,9 @@ class PipelineOpaqueAccessRewriter {
       return Call(call->dtype, call->op, new_args, call->attrs, call->span);
     } else if (call->op.same_as(access_ptr)) {
       return RewriteBufferAccess(call, {1});
-    } else if (call->op.same_as(ptx_mma)) {
+    } else if (IsOp(call, ptx_mma_legacy, "tirx.ptx.mma_legacy")) {
       return RewriteBufferAccess(call, {6, 8, 10});
-    } else if (call->op.same_as(ptx_ldmatrix)) {
+    } else if (IsOp(call, ptx_ldmatrix_legacy, "tirx.ptx.ldmatrix_legacy")) {
       return RewriteBufferAccess(call, {3});
     }
     return call;

Reply via email to