This is an automated email from the ASF dual-hosted git repository.

expye pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 468bf2da79 [TIR][Transform] Introduce new `InjectPermutedLayout` pass (#16070)
468bf2da79 is described below

commit 468bf2da7902974047977a7f31ffbfcfaa422eac
Author: Yixin Dong <[email protected]>
AuthorDate: Sun Nov 12 01:25:40 2023 -0800

    [TIR][Transform] Introduce new `InjectPermutedLayout` pass (#16070)
    
    * 1104
    
    1104
    
    1106
    
    * 1106
    
    * try fix ci
---
 src/tir/transforms/inject_permuted_layout.cc       | 401 +++++++++++----------
 .../test_tir_transform_inject_permuted_layout.py   | 351 ++++++++++++++++++
 2 files changed, 569 insertions(+), 183 deletions(-)

diff --git a/src/tir/transforms/inject_permuted_layout.cc b/src/tir/transforms/inject_permuted_layout.cc
index a1afbeae6f..cccf2c505a 100644
--- a/src/tir/transforms/inject_permuted_layout.cc
+++ b/src/tir/transforms/inject_permuted_layout.cc
@@ -19,44 +19,55 @@
 
 /*!
  * \file inject_permuted_layout.cc
- * \brief The pass for inject permuted layout.
+ * \brief The pass injects permuted layout for shared memory buffers to avoid bank conflicts.
  */
-
 #include <tvm/arith/analyzer.h>
 #include <tvm/tir/function.h>
 #include <tvm/tir/op.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
+#include "../../arith/ir_mutator_with_analyzer.h"
+#include "../../runtime/thread_storage_scope.h"
 #include "../../support/utils.h"
-#include "../ir/functor_common.h"
 #include "ir_utils.h"
 
 namespace tvm {
 namespace tir {
 
-using tir::Block;
-using tir::BlockRealize;
-using tir::Call;
-using tir::For;
+using namespace arith;
+using namespace runtime;
 
-class PermutedLayoutInjector : public StmtExprMutator {
+class PermutedLayoutInjector : private IRMutatorWithAnalyzer {
  public:
-  PermutedLayoutInjector() {}
+  static PrimFunc Transform(PrimFunc func) {
+    Analyzer analyzer;
+
+    auto new_body = PermutedLayoutInjector(func, &analyzer)(func->body);
+    auto func_node = func.CopyOnWrite();
+    func_node->body = new_body;
+    return func;
+  }
 
  private:
-  Array<PrimExpr> GetNewIndices(PrimExpr s0, PrimExpr s1, int smem_width) {
-    // index after vectorize(8)
-    PrimExpr i = s0, j = floordiv(s1, 8), v = floormod(s1, 8);
-    PrimExpr permuted_j;
-    // In the following comments, each number represent a 8 * fp16 load
-    // which is correspond to a index (i, j) in line 50's PrimExpr
-    // Each 8 number correspond to 32 memory bank (every bank has 32 bit):
-    //   8 * 8 * 16bit = 32 * 32bit
-    // And we have 32 banks in total, so all loads in one column share
-    // same memory bank
-    if (smem_width % 64 == 0) {
-      // use 8 * 8 permuted
+  explicit PermutedLayoutInjector(PrimFunc func, Analyzer* analyzer)
+      : IRMutatorWithAnalyzer(analyzer) {
+    buffer_map_.insert(func->buffer_map.begin(), func->buffer_map.end());
+  }
+
+  using IRMutatorWithAnalyzer::VisitExpr_;
+  using IRMutatorWithAnalyzer::VisitStmt_;
+
+  Array<PrimExpr> PermuteIndices(PrimExpr row_idx, PrimExpr col_idx, int row_size) {
+    ICHECK(permute_);
+    // Index after vectorizing by 8
+    PrimExpr col_idx_outer = floordiv(col_idx, VECTORIZE_FACTOR),
+             col_idx_inner = floormod(col_idx, VECTORIZE_FACTOR);
+    PrimExpr new_col_idx_outer;
+    if (row_size % 64 == 0) {
+      // Use 8 * 8 permuted layout
+      // Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
+      // Every row below corresponds to 32 banks
       // 0  1  2  3  4  5  6  7    ==>    0  1  2  3  4  5  6  7
       // 0  1  2  3  4  5  6  7    ==>    1  0  3  2  5  4  7  6
       // 0  1  2  3  4  5  6  7    ==>    2  3  0  1  6  7  4  5
@@ -65,10 +76,13 @@ class PermutedLayoutInjector : public StmtExprMutator {
       // 0  1  2  3  4  5  6  7    ==>    5  4  7  6  1  0  3  2
       // 0  1  2  3  4  5  6  7    ==>    6  7  4  5  2  3  0  1
       // 0  1  2  3  4  5  6  7    ==>    7  6  5  4  3  2  1  0
-      PrimExpr permuted_j_mod_8 = (floormod(j, 8) ^ floormod(i, 8));
-      permuted_j = floordiv(j, 8) * 8 + permuted_j_mod_8;
+      auto row_idx_sub = floormod(row_idx, 8);
+      new_col_idx_outer = col_idx_outer ^ row_idx_sub;
     } else {
-      // use 8 * 4 permuted
+      ICHECK(row_size % 32 == 0);
+      // Use 8 * 4 permuted layout
+      // Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
+      // Every row below corresponds to 16 banks
       // 0  1  2  3    ==>    0  1  2  3
       // 0  1  2  3    ==>    0  1  2  3
       // 0  1  2  3    ==>    1  0  3  2
@@ -77,183 +91,204 @@ class PermutedLayoutInjector : public StmtExprMutator {
       // 0  1  2  3    ==>    2  3  0  1
       // 0  1  2  3    ==>    3  2  1  0
       // 0  1  2  3    ==>    3  2  1  0
-      // in 8 number each line view:
+      // View with 8 elements per row:
       // 0  1  2  3  4  0  1  2  3    ==>    0  1  2  3  0  1  2  3
       // 0  1  2  3  4  0  1  2  3    ==>    1  0  3  2  1  0  3  2
       // 0  1  2  3  4  0  1  2  3    ==>    2  3  0  1  2  3  0  1
       // 0  1  2  3  4  0  1  2  3    ==>    3  2  1  0  3  2  1  0
-      permuted_j = floormod(j, 4) ^ floordiv(floormod(i, 8), 2);
+      auto row_idx_sub = floormod(row_idx, 8);
+      new_col_idx_outer = col_idx_outer ^ floordiv(row_idx_sub, 2);
     }
-    return {s0, permuted_j * 8 + v};
+    return {row_idx, analyzer_->Simplify(new_col_idx_outer * 8 + col_idx_inner)};
   }
 
-  Stmt VisitStmt_(const BlockRealizeNode* _op) final {
-    BlockRealize br = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(_op));
-    BlockRealizeNode* op = br.CopyOnWrite();
-    if (op->block->annotations.count("permuted_layout") == 0) {
-      return br;
+  static bool CheckAnnotation(ObjectRef annotation) {
+    if (auto* node = annotation.as<StringObj>()) {
+      // Support string annotation for backward compatibility
+      return GetRef<String>(node) != "";
+    } else if (auto* node = annotation.as<IntImmNode>()) {
+      return node->value != 0;
+    } else {
+      LOG(FATAL) << "Invalid permuted layout annotation: " << annotation;
     }
-    String val = Downcast<String>(op->block->annotations.at("permuted_layout"));
-    if (val.empty()) return br;
-    Block blk = op->block;
-    Stmt body = blk->body;
-    if (support::StartsWith(val, "g2s")) {
-      // Case 1. Rewrite global to share.dyn
-
-      // Step 1.1. Handle case when have local stage
-      // Block with local stage is like
-      // body {
-      //   SeqStmt {
-      //     seq[0]: local <- global
-      //     seq[1]: shared.dyn <- local
-      //   }
-      // }
-      // We only need to rewrite seq[1]
-      bool have_local_stage = (body.as<SeqStmtNode>() != nullptr);
-      Stmt upper_loop;
-      if (have_local_stage) {
-        SeqStmt seq = Downcast<SeqStmt>(body);
-        ICHECK(seq->size() == 2);
-        upper_loop = seq->seq[0];
-        body = seq->seq[1];
-      }
-
-      // Step 1.2. get inner loop body
-      std::vector<const ForNode*> loops;
-      while (const ForNode* loop = body.as<ForNode>()) {
-        loops.push_back(loop);
-        body = loop->body;
-      }
-      Optional<PrimExpr> if_then_else_condition = NullOpt;
-      const BufferStoreNode* store = body.as<BufferStoreNode>();
-      if (!store) {
-        // Case 1.2.1. IfThenElse generated by reverse_compute_inline
-        // It is always like
-        // if condition:
-        //   loop_body
-        // We just extract the inner loop body inside IfThenElseNode
-        const IfThenElseNode* if_then_else = body.as<IfThenElseNode>();
-        store = if_then_else->then_case.as<BufferStoreNode>();
-        ICHECK(!if_then_else->else_case);
-        if_then_else_condition = if_then_else->condition;
-      }
-      ICHECK(store) << body;
-
-      // Step 1.3. Get smem width and refuse to make any difference if invalid
-      auto smem_width = store->buffer->shape[1].as<IntImmNode>()->value;
-      if (smem_width % 32 != 0) {
-        LOG(WARNING) << "Permuted Layout for " << op->block->name_hint
-                     << " is not supported since its second dimension is not divisible by 32";
-        return br;
-      }
-      if (smem_width % 64 == 32) {
-        if (store->buffer->shape[0].as<IntImmNode>()->value % 2 != 0) {
-          LOG(WARNING) << "Permuted Layout for " << op->block->name_hint
-                       << " is not supported since its first dimension is not divisible by 2"
-                       << " and second dimension is not divisible by 64";
-          return br;
-        }
-      }
-
-      // Step 1.4. Set corresponding member variable
-      if (val.at(4) == 'A') {
-        smem_width_A_ = smem_width;
-      } else {
-        smem_width_B_ = smem_width;
-      }
-
-      // Step 1.5. Rewrite index
-      PrimExpr s0 = store->indices[0];
-      PrimExpr s1 = store->indices[1];
-      Array<PrimExpr> new_indices = GetNewIndices(s0, s1, smem_width);
-      // Step 1.6. Create new BlockRealize
-      Stmt new_body = BufferStore(store->buffer, store->value, new_indices);
-      if (if_then_else_condition) {
-        // Case 1.6.1. Add back IfThenElse
-        new_body = IfThenElse(if_then_else_condition.value(), new_body);
-      }
-      for (int i = loops.size() - 1; i >= 0; i--) {
-        const ForNode* loop = loops[i];
-        new_body = For(loop->loop_var, loop->min, loop->extent, loop->kind, new_body,
-                       loop->thread_binding, loop->annotations);
-      }
-      if (have_local_stage) {
-        // Case 1.6.1. Add back local stage
-        new_body = SeqStmt({upper_loop, new_body});
-      }
-      Block new_blk = Block(blk->iter_vars, blk->reads, blk->writes, blk->name_hint, new_body,
-                            blk->init, blk->alloc_buffers, blk->match_buffers, blk->annotations);
-      BlockRealize new_br = BlockRealize(op->iter_values, op->predicate, new_blk);
-      return new_br;
-    } else if (support::StartsWith(val, "s2l")) {
-      // Case 2. rewrite share.dyn to local
-      // Step 2.1. Retrieve previous set member variable
-      int smem_width = val.at(4) == 'A' ? smem_width_A_ : smem_width_B_;
-      if (smem_width == -1) {
-        return br;
-      }
-
-      // Step 2.2. Rewrite index
-      // Body of shared.dyn to local is always T.evaluate(T.ptx_ldmatrix(args...))
-      // Please refer to the load tensor intrinsic
-      Evaluate eval = Downcast<Evaluate>(body);
-      Call ldmat_call = Downcast<Call>(eval->value);
-      ICHECK(ldmat_call->args.size() == 7);
-      Array<PrimExpr> new_ldmat_args;
-      // Step 2.2.1. Add unchanged args
-      for (int i = 0; i < 5; i++) {
-        new_ldmat_args.push_back(ldmat_call->args[i]);
-      }
-      // 5th argument is always a T.tvm_access_ptr call
-      // Please refer to the load tensor intrinsic
-      Call accptr_call = Downcast<Call>(ldmat_call->args[5]);
-      PrimExpr smem_offset = ldmat_call->args[6];
-
-      // Step 2.2.2. Create new access ptr call
-      Array<PrimExpr> new_accptr_args;
-      for (int i = 0; i < 5; i++) {
-        // 2th args of T.tvm_access_ptr call is offset, we set it to 0 and calculate
-        // total offset in ldmatrix call
-        new_accptr_args.push_back(i == 2 ? 0 : accptr_call->args[i]);
-      }
-      Call new_accptr_call = Call(accptr_call->dtype, accptr_call->op, new_accptr_args);
-      new_ldmat_args.push_back(new_accptr_call);
-
-      // Step 2.2.3. Calculate new offset
-      // We convert offset to 2-dimension, reindex it and convert it back
-      PrimExpr accptr_offset = accptr_call->args[2];
-      PrimExpr offset = smem_offset + accptr_offset;
-      PrimExpr s0 = floordiv(offset, smem_width), s1 = floormod(offset, smem_width);
-      Array<PrimExpr> new_indices = GetNewIndices(s0, s1, smem_width);
-      PrimExpr new_offset = new_indices[0] * smem_width + new_indices[1];
-      new_ldmat_args.push_back(new_offset);
-      // Step 2.2.4. Rewrite the rest part
-      Call new_ldmat_call = Call(ldmat_call->dtype, ldmat_call->op, new_ldmat_args);
-      Stmt new_body = Evaluate(new_ldmat_call);
-      Block new_blk = Block(blk->iter_vars, blk->reads, blk->writes, blk->name_hint, new_body,
-                            blk->init, blk->alloc_buffers, blk->match_buffers, blk->annotations);
-      BlockRealize new_br = BlockRealize(op->iter_values, op->predicate, new_blk);
-      return new_br;
+  }
+
+  Stmt VisitStmt_(const BlockNode* op) final {
+    // Record the mapping from buffer data var to buffer for later lookup
+    for (auto buffer : op->alloc_buffers) {
+      buffer_map_.insert({buffer->data, buffer});
+    }
+    for (auto match_buffer : op->match_buffers) {
+      buffer_map_.insert({match_buffer->buffer->data, match_buffer->buffer});
+    }
+
+    if (op->annotations.count("permuted_layout") == 0 ||
+        !CheckAnnotation(op->annotations.at("permuted_layout"))) {
+      return IRMutatorWithAnalyzer::VisitStmt_(op);
     }
 
-    return StmtExprMutator::VisitStmt_(op);
+    auto prev_permute = permute_;
+    permute_ = true;
+
+    Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));
+
+    permute_ = prev_permute;
+
+    // Erase the permuted_layout annotation after the pass
+    auto block_node = block.CopyOnWrite();
+    block_node->annotations.erase("permuted_layout");
+    return block;
   }
 
-  int smem_width_A_ = -1;
-  int smem_width_B_ = -1;
-};
+  int CheckAndGetBufferRowSize(Buffer buffer) {
+    CHECK(buffer->shape.size() >= 2)
+        << "The dimension of Buffer \"" << buffer->name << "\" with shape " << buffer->shape
+        << " should be at least 2";
 
-PrimFunc InjectPermutedLayout(PrimFunc func) {
-  auto fptr = func.CopyOnWrite();
-  fptr->body = PermutedLayoutInjector()(std::move(fptr->body));
-  return func;
-}
+    auto dim = buffer->shape.size();
+    auto buffer_row_size = buffer->shape[dim - 1].as<IntImmNode>()->value;
+    auto buffer_col_size = buffer->shape[dim - 2].as<IntImmNode>()->value;
+
+    if (buffer_row_size % 64 != 0) {
+      CHECK(buffer_row_size % 32 == 0)
+          << "Permuted Layout for Buffer \"" << buffer->name << "\" with shape " << buffer->shape
+          << " is not supported since its second dimension is not divisible by 32";
+      CHECK(buffer_col_size % 2 == 0)
+          << "Permuted Layout for Buffer \"" << buffer->name << "\" with shape " << buffer->shape
+          << " is not supported since its first dimension is not divisible by 2 and second "
+             "dimension is not divisible by 64";
+    }
+
+    return buffer_row_size;
+  }
+
+  Array<PrimExpr> HandleBufferIndices(Buffer buffer, Array<PrimExpr> indices) {
+    auto buffer_row_size = CheckAndGetBufferRowSize(buffer);
+
+    // Mutate the last two indices
+    auto indices_size = indices.size();
+    PrimExpr row_idx = indices[indices_size - 2];
+    PrimExpr col_idx = indices[indices_size - 1];
+    auto new_indices = PermuteIndices(row_idx, col_idx, buffer_row_size);
+    indices.Set(indices_size - 2, new_indices[0]);
+    indices.Set(indices_size - 1, new_indices[1]);
+    return indices;
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    // Rewrite write from global to shared.dyn or shared
+    // We assume the shape of the shared memory is [..., row_size, col_size],
+    // where row_size is divisible by 64, or divisible by 32 and col_size is divisible by 2.
+    auto store = Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op));
+
+    if (!permute_ || store->buffer->shape.size() < 2) {
+      return store;
+    }
+
+    auto scope = StorageScope::Create(GetPtrStorageScope(store->buffer->data));
+    if (scope.rank != StorageRank::kShared) {
+      return store;
+    }
+
+    auto store_node = store.CopyOnWrite();
+    store_node->indices = HandleBufferIndices(store_node->buffer, store_node->indices);
+    return store;
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    // Rewrite load from shared or shared.dyn to global
+    auto load = Downcast<BufferLoad>(IRMutatorWithAnalyzer::VisitExpr_(op));
+
+    if (!permute_ || load->buffer->shape.size() < 2) {
+      return load;
+    }
+
+    auto scope = StorageScope::Create(GetPtrStorageScope(load->buffer->data));
+    if (scope.rank != StorageRank::kShared) {
+      return load;
+    }
+
+    auto load_node = load.CopyOnWrite();
+    load_node->indices = HandleBufferIndices(load_node->buffer, load_node->indices);
+    return load;
+  }
+
+  PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr, Optional<PrimExpr> offset = NullOpt) {
+    // The 2th arg of T.tvm_access_ptr call is offset, we set it to 0 and accumulate it to
+    // smem_offset
+    CHECK(access_ptr->IsInstance<CallNode>())
+        << "Invalid access ptr for permuted layout: " << access_ptr;
+    auto access_ptr_call = Downcast<Call>(access_ptr);
+    CHECK(access_ptr_call->op.same_as(builtin::tvm_access_ptr()))
+        << "Invalid access ptr for permuted layout: " << access_ptr;
+
+    auto buffer_map_iter = buffer_map_.find(Downcast<Var>(access_ptr_call->args[1]));
+    CHECK(buffer_map_iter != buffer_map_.end())
+        << "The buffer corresponding to data Var " << access_ptr_call->args[1] << " is not found";
+    int buffer_row_size = CheckAndGetBufferRowSize(buffer_map_iter->second);
+
+    PrimExpr smem_offset = access_ptr_call->args[2] + (offset.defined() ? offset.value() : 0);
+
+    // Convert offset to 2-dimension, reindex it and convert it back
+    PrimExpr row_idx = floordiv(smem_offset, buffer_row_size);
+    PrimExpr col_idx = floormod(smem_offset, buffer_row_size);
+
+    auto new_indices = PermuteIndices(row_idx, col_idx, buffer_row_size);
+    auto new_offset = analyzer_->Simplify(new_indices[0] * buffer_row_size + new_indices[1]);
+
+    auto new_access_ptr = access_ptr_call.CopyOnWrite();
+    new_access_ptr->args.Set(2, new_offset);
+    return access_ptr_call;
+  }
+
+  PrimExpr VisitExpr_(const CallNode* op) final {
+    // Rewrite from/to shared or shared.dyn to/from local
+    auto call = Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
+
+    if (!permute_) {
+      return call;
+    }
+
+    if (!call->op.same_as(builtin::ptx_ldmatrix()) && !call->op.same_as(builtin::mma_store())) {
+      return call;
+    }
+
+    if (call->op.same_as(builtin::ptx_ldmatrix())) {
+      // form: T.ptx_ldmatrix(..., smem_ptr, smem_offset)
+      // smem_ptr: T.tvm_access_ptr(ptype, data, offset, extent, rw_mask)
+      auto access_ptr = call->args[5];
+      PrimExpr smem_offset = call->args[6];
+      auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr, smem_offset);
+      auto new_call = call.CopyOnWrite();
+      new_call->args.Set(5, new_access_ptr);
+      new_call->args.Set(6, IntImm(smem_offset->dtype, 0));
+      return call;
+    } else if (call->op.same_as(builtin::mma_store())) {
+      // TODO(yixin): mma_store is not fully tested yet
+      // because we will directly store result to Buffer instead of calling mma_store now
+      auto access_ptr = call->args[2];
+      auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr);
+      auto new_call = call.CopyOnWrite();
+      new_call->args.Set(2, new_access_ptr);
+      return call;
+    } else {
+      LOG(FATAL) << "Invalid call node: " << call;
+    }
+  }
+
+  static constexpr size_t VECTORIZE_FACTOR = 8;
+  static constexpr size_t BANK_SIZE_BYTES = 128;
+
+  // Mapping from data Var of a Buffer to Buffer, for lookup
+  std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
+  bool permute_ = false;
+};
 
 namespace transform {
 
 Pass InjectPermutedLayout() {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
-    return InjectPermutedLayout(std::move(f));
+    return PermutedLayoutInjector::Transform(std::move(f));
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.InjectPermutedLayout", {});
 }
diff --git a/tests/python/unittest/test_tir_transform_inject_permuted_layout.py b/tests/python/unittest/test_tir_transform_inject_permuted_layout.py
new file mode 100644
index 0000000000..6495cdb2bd
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_inject_permuted_layout.py
@@ -0,0 +1,351 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+from tvm import IRModule
+from tvm.script import tir as T
+from tvm.tir import PrimFunc
+
+
+def _check_primfunc_transform(before: PrimFunc, expected: PrimFunc):
+    before_module = IRModule.from_expr(before)
+    after_module = tvm.tir.transform.InjectPermutedLayout()(before_module)
+
+    after = after_module["before"].without_attr("global_symbol")
+    expected = expected.without_attr("global_symbol")
+
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+# This pass is adapted from another previous pass, so we need to ensure backward compatibility here
+def test_backward_compatibility_shared_a():
+    # fmt: off
+    @T.prim_func
+    def before(X: T.Buffer((4096, 4096), "float16")):
+        # with T.block("root"):
+        for blockIdx_y in T.thread_binding(256, thread="blockIdx.y"):
+            for threadIdx_y in T.thread_binding(4, thread="threadIdx.y"):
+                for threadIdx_x in T.thread_binding(32, thread="threadIdx.x"):
+                    with T.block(""):
+                        T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, threadIdx_x % 4 * 8:threadIdx_x % 4 * 8 + 4072])
+                        T.writes()
+                        for ax2_0_0 in range(128):
+                            with T.block(""):
+                                T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, ax2_0_0 * 32 + threadIdx_x % 4 * 8:ax2_0_0 * 32 + threadIdx_x % 4 * 8 + 8])
+                                T.writes()
+                                X_reindex_shared_dyn = T.alloc_buffer((128, 32), "float16", strides=(32, 1), scope="shared.dyn")
+                                with T.block("X_reindex_shared.dyn"):
+                                    T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, ax2_0_0 * 32 + threadIdx_x % 4 * 8:ax2_0_0 * 32 + threadIdx_x % 4 * 8 + 8])
+                                    T.writes(X_reindex_shared_dyn[threadIdx_y * 8 + threadIdx_x // 4:threadIdx_y * 8 + threadIdx_x // 4 + 97, threadIdx_x % 4 * 8:threadIdx_x % 4 * 8 + 8])
+                                    T.block_attr({"permuted_layout": "g2s_A"})
+                                    for ax0_ax1_fused_0 in range(4):
+                                        for ax0_ax1_fused_3 in T.vectorized(8):
+                                            X_reindex_shared_dyn[ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, threadIdx_x % 4 * 8 + ax0_ax1_fused_3] = X[blockIdx_y // 8 * 128 + ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, ax2_0_0 * 32 + threadIdx_x % 4 * 8 + ax0_ax1_fused_3]
+                                for ax2_0_1 in range(4):
+                                    with T.block(""):
+                                        T.reads(X_reindex_shared_dyn[threadIdx_y // 2 * 64:threadIdx_y // 2 * 64 + 64, ax2_0_1 * 8:ax2_0_1 * 8 + 8])
+                                        T.writes()
+                                        X_reindex_shared_dyn_m16n8k8_matrixA = T.alloc_buffer((64, 8), "float16", scope="m16n8k8.matrixA")
+                                        for ax0_0, ax1_0 in T.grid(2, 1):
+                                            with T.block("X_reindex_shared.dyn_m16n8k8.matrixA_o"):
+                                                T.reads(X_reindex_shared_dyn[threadIdx_y // 2 * 64 + ax0_0 * 32:threadIdx_y // 2 * 64 + ax0_0 * 32 + 32, ax2_0_1 * 8:ax2_0_1 * 8 + 8])
+                                                T.writes(X_reindex_shared_dyn_m16n8k8_matrixA[ax0_0 * 32:ax0_0 * 32 + 32, 0:8])
+                                                T.block_attr({"permuted_layout": "s2l_A"})
+                                                T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", X_reindex_shared_dyn_m16n8k8_matrixA.data, ax0_0 * 8, T.tvm_access_ptr(T.type_annotation("float16"), X_reindex_shared_dyn.data, threadIdx_y // 2 * 2048 + ax0_0 * 1024 + ax2_0_1 * 8, 1024, 1), threadIdx_x * 32)
+
+    @T.prim_func
+    def expected(X: T.Buffer((4096, 4096), "float16")):
+        for blockIdx_y in T.thread_binding(256, thread="blockIdx.y"):
+            for threadIdx_y in T.thread_binding(4, thread="threadIdx.y"):
+                for threadIdx_x in T.thread_binding(32, thread="threadIdx.x"):
+                    with T.block(""):
+                        for ax2_0_0 in T.serial(128):
+                            with T.block(""):
+                                X_reindex_shared_dyn = T.alloc_buffer((128, 32), "float16", strides=(32, 1), scope="shared.dyn")
+                                with T.block("X_reindex_shared.dyn"):
+                                    # annotate the reads and writes because they cannot be inferred from tir.bitwise_xor
+                                    T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, ax2_0_0 * 32 + threadIdx_x % 4 * 8:ax2_0_0 * 32 + threadIdx_x % 4 * 8 + 8])
+                                    T.writes(X_reindex_shared_dyn[threadIdx_y * 8 + threadIdx_x // 4:threadIdx_y * 8 + threadIdx_x // 4 + 97, threadIdx_x % 4 * 8:threadIdx_x % 4 * 8 + 8])
+                                    for ax0_ax1_fused_0 in range(4):
+                                        for ax0_ax1_fused_3 in T.vectorized(8):
+                                            X_reindex_shared_dyn[ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, T.bitwise_xor(threadIdx_x % 4, threadIdx_x // 8) * 8 + ax0_ax1_fused_3] = X[blockIdx_y // 8 * 128 + ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, ax2_0_0 * 32 + threadIdx_x % 4 * 8 + ax0_ax1_fused_3]
+                                for ax2_0_1 in T.serial(4):
+                                    with T.block(""):
+                                        X_reindex_shared_dyn_m16n8k8_matrixA = T.alloc_buffer((64, 8), "float16", scope="m16n8k8.matrixA")
+                                        for ax0_0, ax1_0 in T.grid(2, 1):
+                                            with T.block("X_reindex_shared.dyn_m16n8k8.matrixA_o"):
+                                                T.reads(X_reindex_shared_dyn[threadIdx_y // 2 * 64 + ax0_0 * 32:threadIdx_y // 2 * 64 + ax0_0 * 32 + 32, ax2_0_1 * 8:ax2_0_1 * 8 + 8])
+                                                T.writes(X_reindex_shared_dyn_m16n8k8_matrixA[ax0_0 * 32:ax0_0 * 32 + 32, 0:8])
+                                                T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", X_reindex_shared_dyn_m16n8k8_matrixA.data, ax0_0 * 8, T.tvm_access_ptr(T.type_annotation("float16"), X_reindex_shared_dyn.data, threadIdx_y // 2 * 2048 + ax0_0 * 1024 + threadIdx_x * 32 + T.bitwise_xor(ax2_0_1, threadIdx_x % 8 // 2) * 8, 1024, 1), 0)
+    # fmt: on
+    _check_primfunc_transform(before, expected)
+
+
+def test_backward_compatibility_shared_a_and_b():
+    # fmt: off
+    @T.prim_func
+    def before(X: T.Buffer((4096, 4096), "float16"), Y: T.Buffer((4096, 4096), 
"float16")):
+        for blockIdx_x in T.thread_binding(4, thread="blockIdx.x"):
+            for blockIdx_y in T.thread_binding(256, thread="blockIdx.y"):
+                for threadIdx_y in T.thread_binding(4, thread="threadIdx.y"):
+                    for threadIdx_x in T.thread_binding(32, 
thread="threadIdx.x"):
+                        with T.block(""):
+                            for ax2_0_0 in T.serial(128):
+                                with T.block(""):
+                                    X_reindex_shared_dyn = 
T.alloc_buffer((128, 32), "float16", strides=(32, 1), scope="shared.dyn")
+                                    Y_reindex_shared_dyn = T.alloc_buffer((32, 
128), "float16", strides=(128, 1), scope="shared.dyn")
+                                    with T.block("X_reindex_shared.dyn"):
+                                        T.block_attr({"permuted_layout": 
"g2s_A"})
+                                        for ax0_ax1_fused_0 in range(4):
+                                            for ax0_ax1_fused_3 in 
T.vectorized(8):
+                                                
X_reindex_shared_dyn[ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, 
threadIdx_x % 4 * 8 + ax0_ax1_fused_3] = X[blockIdx_y // 8 * 128 + 
ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, ax2_0_0 * 32 + 
threadIdx_x % 4 * 8 + ax0_ax1_fused_3]
+                                    with T.block("Y_reindex_shared.dyn"):
+                                        T.block_attr({"permuted_layout": 
"g2s_B"})
+                                        for ax0_ax1_fused_0 in range(4):
+                                            for ax0_ax1_fused_3 in 
T.vectorized(8):
+                                                
Y_reindex_shared_dyn[ax0_ax1_fused_0 * 8 + threadIdx_y * 2 + threadIdx_x // 16, 
threadIdx_x % 16 * 8 + ax0_ax1_fused_3] = Y[ax2_0_0 * 32 + ax0_ax1_fused_0 * 8 
+ threadIdx_y * 2 + threadIdx_x // 16, blockIdx_x * 1024 + blockIdx_y % 8 * 128 
+ threadIdx_x % 16 * 8 + ax0_ax1_fused_3]
+                                    for ax2_0_1 in T.serial(4):
+                                        with T.block(""):
+                                            
X_reindex_shared_dyn_m16n8k8_matrixA = T.alloc_buffer((64, 8), "float16", 
scope="m16n8k8.matrixA")
+                                            
Y_reindex_shared_dyn_m16n8k8_matrixB = T.alloc_buffer((8, 64), "float16", 
scope="m16n8k8.matrixB")
+                                            for ax0_0, ax1_0 in T.grid(2, 1):
+                                                with 
T.block("X_reindex_shared.dyn_m16n8k8.matrixA_o"):
+                                                    
T.reads(X_reindex_shared_dyn[threadIdx_y // 2 * 64 + ax0_0 * 32:threadIdx_y // 
2 * 64 + ax0_0 * 32 + 32, ax2_0_1 * 8:ax2_0_1 * 8 + 8])
+                                                    
T.writes(X_reindex_shared_dyn_m16n8k8_matrixA[ax0_0 * 32:ax0_0 * 32 + 32, 0:8])
+                                                    
T.block_attr({"permuted_layout": "s2l_A"})
+                                                    T.ptx_ldmatrix("float16", 
T.bool(False), 4, ".b16", X_reindex_shared_dyn_m16n8k8_matrixA.data, ax0_0 * 8, 
T.tvm_access_ptr(T.type_annotation("float16"), X_reindex_shared_dyn.data, 
threadIdx_y // 2 * 2048 + ax0_0 * 1024 + ax2_0_1 * 8, 1024, 1), threadIdx_x * 
32)
+                                            for ax0_0, ax1_0 in T.grid(1, 2):
+                                                with 
T.block("Y_reindex_shared.dyn_m16n8k8.matrixB_o"):
+                                                    
T.reads(Y_reindex_shared_dyn[ax2_0_1 * 8:ax2_0_1 * 8 + 8, threadIdx_y % 2 * 64 
+ ax1_0 * 32:threadIdx_y % 2 * 64 + ax1_0 * 32 + 32])
+                                                    
T.writes(Y_reindex_shared_dyn_m16n8k8_matrixB[0:8, ax1_0 * 32:ax1_0 * 32 + 32])
+                                                    
T.block_attr({"permuted_layout": "s2l_B"})
+                                                    T.ptx_ldmatrix("float16", 
T.bool(True), 4, ".b16", Y_reindex_shared_dyn_m16n8k8_matrixB.data, ax1_0 * 8, 
T.tvm_access_ptr(T.type_annotation("float16"), Y_reindex_shared_dyn.data, 
ax2_0_1 * 1024 + threadIdx_y % 2 * 64 + ax1_0 * 32, 1024, 1), threadIdx_x % 8 * 
128 + threadIdx_x // 8 * 8)
+
+    @T.prim_func
+    def expected(X: T.Buffer((4096, 4096), "float16"), Y: T.Buffer((4096, 
4096), "float16")):
+        for blockIdx_x in T.thread_binding(4, thread="blockIdx.x"):
+            for blockIdx_y in T.thread_binding(256, thread="blockIdx.y"):
+                for threadIdx_y in T.thread_binding(4, thread="threadIdx.y"):
+                    for threadIdx_x in T.thread_binding(32, 
thread="threadIdx.x"):
+                        with T.block(""):
+                            T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 
+ threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 
97, threadIdx_x % 4 * 8:threadIdx_x % 4 * 8 + 4072], Y[threadIdx_y * 2 + 
threadIdx_x // 16:threadIdx_y * 2 + threadIdx_x // 16 + 4089, blockIdx_x * 1024 
+ blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8:blockIdx_x * 1024 + blockIdx_y % 
8 * 128 + threadIdx_x % 16 * 8 + 8])
+                            T.writes()
+                            for ax2_0_0 in T.serial(128):
+                                with T.block(""):
+                                    T.reads(X[blockIdx_y // 8 * 128 + 
threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + 
threadIdx_x // 4 + 97, ax2_0_0 * 32 + threadIdx_x % 4 * 8:ax2_0_0 * 32 + 
threadIdx_x % 4 * 8 + 8], Y[ax2_0_0 * 32 + threadIdx_y * 2 + threadIdx_x // 
16:ax2_0_0 * 32 + threadIdx_y * 2 + threadIdx_x // 16 + 25, blockIdx_x * 1024 + 
blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8:blockIdx_x * 1024 + blockIdx_y % 8 
* 128 + threadIdx_x % 16 * 8 + 8])
+                                    T.writes()
+                                    X_reindex_shared_dyn = 
T.alloc_buffer((128, 32), "float16", strides=(32, 1), scope="shared.dyn")
+                                    Y_reindex_shared_dyn = T.alloc_buffer((32, 
128), "float16", strides=(128, 1), scope="shared.dyn")
+                                    with T.block("X_reindex_shared.dyn"):
+                                        T.reads(X[blockIdx_y // 8 * 128 + 
threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + 
threadIdx_x // 4 + 97, ax2_0_0 * 32 + threadIdx_x % 4 * 8:ax2_0_0 * 32 + 
threadIdx_x % 4 * 8 + 8])
+                                        
T.writes(X_reindex_shared_dyn[threadIdx_y * 8 + threadIdx_x // 4:threadIdx_y * 
8 + threadIdx_x // 4 + 97, threadIdx_x % 4 * 8:threadIdx_x % 4 * 8 + 8])
+                                        for ax0_ax1_fused_0 in range(4):
+                                            for ax0_ax1_fused_3 in 
T.vectorized(8):
+                                                
X_reindex_shared_dyn[ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, 
T.bitwise_xor(threadIdx_x % 4, threadIdx_x // 8) * 8 + ax0_ax1_fused_3] = 
X[blockIdx_y // 8 * 128 + ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x 
// 4, ax2_0_0 * 32 + threadIdx_x % 4 * 8 + ax0_ax1_fused_3]
+                                    with T.block("Y_reindex_shared.dyn"):
+                                        T.reads(Y[ax2_0_0 * 32 + threadIdx_y * 
2 + threadIdx_x // 16:ax2_0_0 * 32 + threadIdx_y * 2 + threadIdx_x // 16 + 25, 
blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8:blockIdx_x * 
1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8 + 8])
+                                        
T.writes(Y_reindex_shared_dyn[threadIdx_y * 2 + threadIdx_x // 16:threadIdx_y * 
2 + threadIdx_x // 16 + 25, threadIdx_x % 16 * 8:threadIdx_x % 16 * 8 + 8])
+                                        for ax0_ax1_fused_0 in range(4):
+                                            for ax0_ax1_fused_3 in 
T.vectorized(8):
+                                                
Y_reindex_shared_dyn[ax0_ax1_fused_0 * 8 + threadIdx_y * 2 + threadIdx_x // 16, 
T.bitwise_xor(threadIdx_x % 16, threadIdx_y * 2 + threadIdx_x // 16) * 8 + 
ax0_ax1_fused_3]   = Y[ax2_0_0 * 32 + ax0_ax1_fused_0 * 8 + threadIdx_y * 2 + 
threadIdx_x // 16, blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 
* 8 + ax0_ax1_fused_3]
+                                    for ax2_0_1 in T.serial(4):
+                                        with T.block(""):
+                                            
X_reindex_shared_dyn_m16n8k8_matrixA = T.alloc_buffer((64, 8), "float16", 
scope="m16n8k8.matrixA")
+                                            
Y_reindex_shared_dyn_m16n8k8_matrixB = T.alloc_buffer((8, 64), "float16", 
scope="m16n8k8.matrixB")
+                                            for ax0_0, ax1_0 in T.grid(2, 1):
+                                                with 
T.block("X_reindex_shared.dyn_m16n8k8.matrixA_o"):
+                                                    
T.reads(X_reindex_shared_dyn[threadIdx_y // 2 * 64 + ax0_0 * 32:threadIdx_y // 
2 * 64 + ax0_0 * 32 + 32, ax2_0_1 * 8:ax2_0_1 * 8 + 8])
+                                                    
T.writes(X_reindex_shared_dyn_m16n8k8_matrixA[ax0_0 * 32:ax0_0 * 32 + 32, 0:8])
+                                                    T.ptx_ldmatrix("float16", 
T.bool(False), 4, ".b16", X_reindex_shared_dyn_m16n8k8_matrixA.data, ax0_0 * 8, 
T.tvm_access_ptr(T.type_annotation("float16"), X_reindex_shared_dyn.data, 
threadIdx_y // 2 * 2048 + ax0_0 * 1024 + threadIdx_x * 32 + 
T.bitwise_xor(ax2_0_1, threadIdx_x % 8 // 2) * 8, 1024, 1), 0)
+                                            for ax0_0, ax1_0 in T.grid(1, 2):
+                                                with 
T.block("Y_reindex_shared.dyn_m16n8k8.matrixB_o"):
+                                                    
T.reads(Y_reindex_shared_dyn[ax2_0_1 * 8:ax2_0_1 * 8 + 8, threadIdx_y % 2 * 64 
+ ax1_0 * 32:threadIdx_y % 2 * 64 + ax1_0 * 32 + 32])
+                                                    
T.writes(Y_reindex_shared_dyn_m16n8k8_matrixB[0:8, ax1_0 * 32:ax1_0 * 32 + 32])
+                                                    T.ptx_ldmatrix("float16", 
T.bool(True), 4, ".b16", Y_reindex_shared_dyn_m16n8k8_matrixB.data, ax1_0 * 8, 
T.tvm_access_ptr(T.type_annotation("float16"), Y_reindex_shared_dyn.data, 
ax2_0_1 * 1024 + threadIdx_x % 8 * 128 + T.bitwise_xor(threadIdx_y % 2 * 8 + 
ax1_0 * 4 + threadIdx_x // 8, threadIdx_x % 8) * 8, 1024, 1), 0)
+    # fmt: on
+    _check_primfunc_transform(before, expected)
+
+
+def test_buffer_a():
+    """InjectPermutedLayout on the float16 (128, 32) "shared.dyn" buffer `A_shared_dyn`.
+
+    In `before`, both the copy from `A` into `A_shared_dyn` and the
+    `ptx_ldmatrix` read from `A_shared_dyn` into `A_warp` are tagged with the
+    `permuted_layout` block attribute. `expected` carries no such attribute and:
+      * the shared-memory store's column becomes
+        `bitwise_xor(threadIdx_x % 4, threadIdx_x // 8) * 8 + v1`, i.e. the
+        8-element column chunk index is XOR-swizzled;
+      * the ldmatrix's trailing per-thread offset argument
+        (`threadIdx_x % 16 * 32 + threadIdx_x // 16 * 8`) is folded into the
+        `tvm_access_ptr` offset, with the chunk index
+        `v0 * 2 + threadIdx_x // 16` XOR-ed by `threadIdx_x % 8 // 2`, and the
+        trailing argument becomes 0.
+    """
+    # fmt: off
+    @T.prim_func
+    def before(p_A: T.handle):
+        A = T.match_buffer(p_A, (T.int64(128), T.int64(32)), "float16")
+        A_shared_dyn = T.alloc_buffer((T.int64(128), T.int64(32)), "float16", 
scope="shared.dyn")
+        A_warp = T.alloc_buffer((T.int64(4), T.int64(1), T.int64(32), 
T.int64(8)), "float16", scope="warp")
+        for threadIdx_z in T.thread_binding(T.int64(2), thread="threadIdx.z"):
+            for threadIdx_y in T.thread_binding(T.int64(2), 
thread="threadIdx.y"):
+                for threadIdx_x in T.thread_binding(T.int64(32), 
thread="threadIdx.x"):
+                    for v0 in range(T.int64(4)):
+                        for v1 in T.vectorized(T.int64(8)):
+                            with T.block("A_reindex_shared.dyn"):
+                                T.block_attr({"permuted_layout": 1})
+                                A_shared_dyn[
+                                    v0 * T.int64(32) + threadIdx_z * 
T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4),
+                                    threadIdx_x % T.int64(4) * T.int64(8) + v1
+                                ] = A[
+                                    (v0 * T.int64(32) + threadIdx_z * 
T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4)) % 
T.int64(32),
+                                    threadIdx_x % T.int64(4) * T.int64(8) + v1
+                                ]
+                    for v0, v1 in T.grid(T.int64(2), T.int64(4)):
+                        with T.block("A_reindex_shared.dyn_warp_o"):
+                            T.block_attr({"permuted_layout": 1})
+                            with T.block("A_reindex_shared.dyn_warp_o"):
+                                T.reads(A_shared_dyn[threadIdx_z * T.int64(64) 
+ v1 * T.int64(16):threadIdx_z * T.int64(64) + v1 * T.int64(16) + T.int64(16), 
v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)])
+                                T.writes(A_warp[v1, T.int64(0), 
T.int64(0):T.int64(32), T.int64(0):T.int64(8)])
+                                T.ptx_ldmatrix("float16", T.bool(False), 4, 
".b16",
+                                    A_warp.data,
+                                    v1 * T.int64(256) + threadIdx_x * 
T.int64(8),
+                                    
T.tvm_access_ptr(T.type_annotation("float16"),
+                                        A_shared_dyn.data,
+                                        threadIdx_z * T.int64(2048) + v1 * 
T.int64(512) + v0 * T.int64(16), T.int64(512),
+                                        1
+                                    ),
+                                    threadIdx_x % T.int64(16) * T.int64(32) + 
threadIdx_x // T.int64(16) * T.int64(8)
+                                )
+
+    @T.prim_func
+    def expected(A: T.Buffer((T.int64(128), T.int64(32)), "float16")):
+        A_shared_dyn = T.alloc_buffer((T.int64(128), T.int64(32)), "float16", 
scope="shared.dyn")
+        A_warp = T.alloc_buffer((T.int64(4), T.int64(1), T.int64(32), 
T.int64(8)), "float16", scope="warp")
+        for threadIdx_z in T.thread_binding(T.int64(2), thread="threadIdx.z"):
+            for threadIdx_y in T.thread_binding(T.int64(2), 
thread="threadIdx.y"):
+                for threadIdx_x in T.thread_binding(T.int64(32), 
thread="threadIdx.x"):
+                    for v0 in range(T.int64(4)):
+                        for v1 in T.vectorized(T.int64(8)):
+                            with T.block("A_reindex_shared.dyn"):
+                                T.reads(A[(v0 * T.int64(32) + threadIdx_z * 
T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4)) % 
T.int64(32), threadIdx_x % T.int64(4) * T.int64(8) + v1])
+                                T.writes(A_shared_dyn[v0 * T.int64(32) + 
threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // 
T.int64(4), threadIdx_x % T.int64(4) * T.int64(8) + v1])
+                                A_shared_dyn[v0 * T.int64(32) + threadIdx_z * 
T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), 
T.bitwise_xor(threadIdx_x % T.int64(4), threadIdx_x // T.int64(8)) * T.int64(8) 
+ v1] = A[(v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * 
T.int64(8) + threadIdx_x // T.int64(4)) % T.int64(32), threadIdx_x % T.int64(4) 
* T.int64(8) + v1]
+                    for v0, v1 in T.grid(T.int64(2), T.int64(4)):
+                        with T.block("A_reindex_shared.dyn_warp_o"):
+                            T.reads(A_shared_dyn[threadIdx_z * T.int64(64) + 
v1 * T.int64(16):threadIdx_z * T.int64(64) + v1 * T.int64(16) + T.int64(16), v0 
* T.int64(16):v0 * T.int64(16) + T.int64(16)])
+                            T.writes(A_warp[v1, T.int64(0), 
T.int64(0):T.int64(32), T.int64(0):T.int64(8)])
+                            with T.block("A_reindex_shared.dyn_warp_o"):
+                                T.reads(A_shared_dyn[threadIdx_z * T.int64(64) 
+ v1 * T.int64(16):threadIdx_z * T.int64(64) + v1 * T.int64(16) + T.int64(16), 
v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)])
+                                T.writes(A_warp[v1, T.int64(0), 
T.int64(0):T.int64(32), T.int64(0):T.int64(8)])
+                                T.ptx_ldmatrix("float16", T.bool(False), 4, 
".b16", A_warp.data, v1 * T.int64(256) + threadIdx_x * T.int64(8), 
T.tvm_access_ptr(T.type_annotation("float16"), A_shared_dyn.data, threadIdx_z * 
T.int64(2048) + v1 * T.int64(512) + threadIdx_x % T.int64(16) * T.int64(32) + 
T.bitwise_xor(v0 * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % 
T.int64(8) // T.int64(2)) * T.int64(8), T.int64(512), 1), T.int64(0))
+
+    # fmt: on
+    _check_primfunc_transform(before, expected)
+
+
+def test_buffer_b():
+    """InjectPermutedLayout on the float16 (128, 32) "shared.dyn" buffer `B_shared_dyn`.
+
+    In `before`, the copy from `B` into `B_shared_dyn` and the `ptx_ldmatrix`
+    read into `B_warp` (allocated inside an inner block) are tagged with the
+    `permuted_layout` block attribute. The ldmatrix's trailing per-thread offset
+    here is
+    `threadIdx_x // 16 * 256 + threadIdx_x % 8 * 32 + threadIdx_x % 16 // 8 * 8`.
+    `expected` drops the attribute, XOR-swizzles the store's 8-element column
+    chunk (`bitwise_xor(threadIdx_x % 4, threadIdx_x // 8)`), and folds the
+    trailing offset into the `tvm_access_ptr` offset with the chunk index
+    `v0 * 2 + threadIdx_x % 16 // 8` XOR-ed by `threadIdx_x % 8 // 2`.
+    """
+    # fmt: off
+    @T.prim_func
+    def before(B: T.Buffer((T.int64(128), T.int64(32)), "float16")):
+        B_shared_dyn = T.alloc_buffer((T.int64(128), T.int64(32)), "float16", 
scope="shared.dyn")
+        for threadIdx_z in T.thread_binding(T.int64(2), thread="threadIdx.z"):
+            for threadIdx_y in T.thread_binding(T.int64(2), 
thread="threadIdx.y"):
+                for threadIdx_x in T.thread_binding(T.int64(32), 
thread="threadIdx.x"):
+                    for v0 in range(T.int64(4)):
+                        for v1 in T.vectorized(T.int64(8)):
+                            with T.block("B_reindex_shared.dyn"):
+                                T.block_attr({"permuted_layout": 1})
+                                B_shared_dyn[v0 * T.int64(32) + threadIdx_z * 
T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x 
% T.int64(4) * T.int64(8) + v1] = B[v0 * T.int64(32) + threadIdx_z * 
T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x 
% T.int64(4) * T.int64(8) + v1]
+                    for v0 in range(T.int64(2)):
+                        with T.block(""):
+                            B_warp = T.alloc_buffer((T.int64(4), T.int64(1), 
T.int64(32), T.int64(8)), "float16", scope="warp")
+                            for v1 in range(T.int64(4)):
+                                with T.block("B_reindex_shared.dyn_warp_o"):
+                                    T.block_attr({"permuted_layout": 1})
+                                    with 
T.block("B_reindex_shared.dyn_warp_o"):
+                                        T.reads(B_shared_dyn[threadIdx_y * 
T.int64(64) + v1 * T.int64(16):threadIdx_y * T.int64(64) + v1 * T.int64(16) + 
T.int64(16), v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)])
+                                        T.writes(B_warp[v1, T.int64(0), 
T.int64(0):T.int64(32), T.int64(0):T.int64(8)])
+                                        T.ptx_ldmatrix("float16", 
T.bool(False), 4, ".b16", B_warp.data, v1 * T.int64(256) + threadIdx_x * 
T.int64(8), T.tvm_access_ptr(T.type_annotation("float16"), B_shared_dyn.data, 
threadIdx_y * T.int64(2048) + v1 * T.int64(512) + v0 * T.int64(16), 
T.int64(512), 1), threadIdx_x // T.int64(16) * T.int64(256) + threadIdx_x % 
T.int64(8) * T.int64(32) + threadIdx_x % T.int64(16) // T.int64(8) * T.int64(8))
+
+    # NOTE(review): in this copy the last line of `expected`'s ptx_ldmatrix call
+    # is cut off ("[...]"), so the call is not closed here; restore its trailing
+    # arguments from the source tree before relying on this text.
+    @T.prim_func
+    def expected(B: T.Buffer((T.int64(128), T.int64(32)), "float16")):
+        B_shared_dyn = T.alloc_buffer((T.int64(128), T.int64(32)), "float16", 
scope="shared.dyn")
+        for threadIdx_z in T.thread_binding(T.int64(2), thread="threadIdx.z"):
+            for threadIdx_y in T.thread_binding(T.int64(2), 
thread="threadIdx.y"):
+                for threadIdx_x in T.thread_binding(T.int64(32), 
thread="threadIdx.x"):
+                    for v0 in range(T.int64(4)):
+                        for v1 in T.vectorized(T.int64(8)):
+                            with T.block("B_reindex_shared.dyn"):
+                                T.reads(B[v0 * T.int64(32) + threadIdx_z * 
T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x 
% T.int64(4) * T.int64(8) + v1])
+                                T.writes(B_shared_dyn[v0 * T.int64(32) + 
threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // 
T.int64(4), threadIdx_x % T.int64(4) * T.int64(8) + v1])
+                                B_shared_dyn[v0 * T.int64(32) + threadIdx_z * 
T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), 
T.bitwise_xor(threadIdx_x % T.int64(4), threadIdx_x // T.int64(8)) * T.int64(8) 
+ v1] = B[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * 
T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x % T.int64(4) * T.int64(8) + 
v1]
+                    for v0 in range(T.int64(2)):
+                        with T.block(""):
+                            B_warp = T.alloc_buffer((T.int64(4), T.int64(1), 
T.int64(32), T.int64(8)), "float16", scope="warp")
+                            for v1 in range(T.int64(4)):
+                                with T.block("B_reindex_shared.dyn_warp_o"):
+                                    T.reads(B_shared_dyn[threadIdx_y * 
T.int64(64) + v1 * T.int64(16):threadIdx_y * T.int64(64) + v1 * T.int64(16) + 
T.int64(16), v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)])
+                                    T.writes(B_warp[v1, T.int64(0), 
T.int64(0):T.int64(32), T.int64(0):T.int64(8)])
+                                    with 
T.block("B_reindex_shared.dyn_warp_o"):
+                                        T.reads(B_shared_dyn[threadIdx_y * 
T.int64(64) + v1 * T.int64(16):threadIdx_y * T.int64(64) + v1 * T.int64(16) + 
T.int64(16), v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)])
+                                        T.writes(B_warp[v1, T.int64(0), 
T.int64(0):T.int64(32), T.int64(0):T.int64(8)])
+                                        T.ptx_ldmatrix("float16", 
T.bool(False), 4, ".b16", B_warp.data, v1 * T.int64(256) + threadIdx_x * 
T.int64(8), T.tvm_access_ptr(T.type_annotation("float16"), B_shared_dyn.data, 
threadIdx_y * T.int64(2048) + v1 * T.int64(512) + threadIdx_x // T.int64(16) * 
T.int64(256) + threadIdx_x % T.int64(8) * T.int64(32) + T.bitwise_xor(v0 * 
T.int64(2) + threadIdx_x % T.int64(16) // T.int64(8), threadIdx_x % T.int64(8) 
// T.int64(2)) * T.int64(8), T.int64(512), [...]
+
+    # fmt: on
+    _check_primfunc_transform(before, expected)
+
+
+def test_buffer_c_fp32():
+    """InjectPermutedLayout on the (128, 128) "shared.dyn" buffer `O_shared_dyn`.
+
+    `O_shared_dyn` and `O_warp` are allocated without an explicit dtype (the
+    test name indicates float32), while the output `O` is float16.
+    `before` tags both the `O_warp` -> `O_shared_dyn` store and the
+    `O_shared_dyn` -> `O` cast-copy with `permuted_layout`. In `expected` the
+    attribute is gone, and both accesses XOR-swizzle the 8-element column chunk:
+      * the store's chunk `threadIdx_y * 8 + v1 * 2 + local_id // 4` is XOR-ed
+        with `threadIdx_x // 4`;
+      * the load's chunk `threadIdx_x % 16` is XOR-ed with
+        `threadIdx_z * 4 + threadIdx_y * 2 + threadIdx_x // 16` (the row index
+        without its `v0 * 8` term).
+    """
+    # fmt: off
+    @T.prim_func
+    def before(p_O: T.handle):
+        O = T.match_buffer(p_O, (T.int64(128), T.int64(128)), "float16")
+        O_shared_dyn = T.alloc_buffer((T.int64(128), T.int64(128)), 
scope="shared.dyn")
+        O_warp = T.alloc_buffer((T.int64(4), T.int64(4), T.int64(32), 
T.int64(8)), scope="warp")
+        for threadIdx_z in T.thread_binding(T.int64(2), thread="threadIdx.z"):
+            for threadIdx_y in T.thread_binding(T.int64(2), 
thread="threadIdx.y"):
+                for threadIdx_x in T.thread_binding(T.int64(32), 
thread="threadIdx.x"):
+                    for v0, v1 in T.grid(T.int64(4), T.int64(4)):
+                        with T.block("O.dyn_warp_o"):
+                            T.block_attr({"permuted_layout": 1})
+                            with T.block("O.dyn_warp_o"):
+                                for local_id in range(T.int64(8)):
+                                    O_shared_dyn[threadIdx_z * T.int64(64) + 
v0 * T.int64(16) + local_id % T.int64(4) // T.int64(2) * T.int64(8) + 
threadIdx_x // T.int64(4), threadIdx_y * T.int64(64) + v1 * T.int64(16) + 
local_id // T.int64(4) * T.int64(8) + threadIdx_x % T.int64(4) * T.int64(2) + 
local_id % T.int64(2)] = O_warp[v0, v1, threadIdx_x, local_id]
+                    for v0 in range(T.int64(16)):
+                        for v1 in T.vectorized(T.int64(8)):
+                            with T.block("O.dyn"):
+                                T.block_attr({"permuted_layout": 1})
+                                O[v0 * T.int64(8) + threadIdx_z * T.int64(4) + 
threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % 
T.int64(16) * T.int64(8) + v1] = T.Cast("float16", O_shared_dyn[v0 * T.int64(8) 
+ threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // 
T.int64(16), threadIdx_x % T.int64(16) * T.int64(8) + v1])
+
+
+    @T.prim_func
+    def expected(O: T.Buffer((T.int64(128), T.int64(128)), "float16")):
+        # with T.block("root"):
+        O_shared_dyn = T.alloc_buffer((T.int64(128), T.int64(128)), 
scope="shared.dyn")
+        O_warp = T.alloc_buffer((T.int64(4), T.int64(4), T.int64(32), 
T.int64(8)), scope="warp")
+        for threadIdx_z in T.thread_binding(T.int64(2), thread="threadIdx.z"):
+            for threadIdx_y in T.thread_binding(T.int64(2), 
thread="threadIdx.y"):
+                for threadIdx_x in T.thread_binding(T.int64(32), 
thread="threadIdx.x"):
+                    for v0, v1 in T.grid(T.int64(4), T.int64(4)):
+                        with T.block("O.dyn_warp_o"):
+                            T.reads(O_warp[v0, v1, threadIdx_x, 
T.int64(0):T.int64(8)])
+                            T.writes(O_shared_dyn[threadIdx_z * T.int64(64) + 
v0 * T.int64(16) + threadIdx_x // T.int64(4):threadIdx_z * T.int64(64) + v0 * 
T.int64(16) + threadIdx_x // T.int64(4) + T.int64(9), threadIdx_y * T.int64(64) 
+ v1 * T.int64(16) + threadIdx_x % T.int64(4) * T.int64(2):threadIdx_y * 
T.int64(64) + v1 * T.int64(16) + threadIdx_x % T.int64(4) * T.int64(2) + 
T.int64(10)])
+                            with T.block("O.dyn_warp_o"):
+                                T.reads(O_warp[v0, v1, threadIdx_x, 
T.int64(0):T.int64(8)])
+                                T.writes(O_shared_dyn[threadIdx_z * 
T.int64(64) + v0 * T.int64(16) + threadIdx_x // T.int64(4):threadIdx_z * 
T.int64(64) + v0 * T.int64(16) + threadIdx_x // T.int64(4) + T.int64(9), 
threadIdx_y * T.int64(64) + v1 * T.int64(16) + threadIdx_x % T.int64(4) * 
T.int64(2):threadIdx_y * T.int64(64) + v1 * T.int64(16) + threadIdx_x % 
T.int64(4) * T.int64(2) + T.int64(10)])
+                                for local_id in range(T.int64(8)):
+                                    O_shared_dyn[threadIdx_z * T.int64(64) + 
v0 * T.int64(16) + local_id % T.int64(4) // T.int64(2) * T.int64(8) + 
threadIdx_x // T.int64(4), T.bitwise_xor(threadIdx_y * T.int64(8) + v1 * 
T.int64(2) + local_id // T.int64(4), threadIdx_x // T.int64(4)) * T.int64(8) + 
threadIdx_x % T.int64(4) * T.int64(2) + local_id % T.int64(2)] = O_warp[v0, v1, 
threadIdx_x, local_id]
+                    for v0 in range(T.int64(16)):
+                        for v1 in T.vectorized(T.int64(8)):
+                            with T.block("O.dyn"):
+                                T.reads(O_shared_dyn[v0 * T.int64(8) + 
threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // 
T.int64(16), threadIdx_x % T.int64(16) * T.int64(8) + v1])
+                                T.writes(O[v0 * T.int64(8) + threadIdx_z * 
T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x 
% T.int64(16) * T.int64(8) + v1])
+                                O[v0 * T.int64(8) + threadIdx_z * T.int64(4) + 
threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % 
T.int64(16) * T.int64(8) + v1] = T.Cast("float16", O_shared_dyn[v0 * T.int64(8) 
+ threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // 
T.int64(16), T.bitwise_xor(threadIdx_x % T.int64(16), threadIdx_z * T.int64(4) 
+ threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16)) * T.int64(8) + v1])
+
+    # fmt: on
+    _check_primfunc_transform(before, expected)
+
+
+if __name__ == "__main__":
+    # Allow running this test module directly; delegates to tvm.testing.main().
+    tvm.testing.main()

Reply via email to