This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4a5e22e869 [BugFix][MetaSchedule] MultiLevelTilingTensorCore generates inconsistent thread-binding sketch for batched matmul (#17012)
4a5e22e869 is described below

commit 4a5e22e869e92b9c12b3bda8b88a0ce8c69b8d30
Author: tsu-bin <81693503+tsu-...@users.noreply.github.com>
AuthorDate: Fri Jun 28 12:55:06 2024 +0800

    [BugFix][MetaSchedule] MultiLevelTilingTensorCore generates inconsistent thread-binding sketch for batched matmul (#17012)
    
    * [BugFix][MetaSchedule] MultiLevelTilingTensorCore generates inconsistent thread-binding sketch for batched matmul
    
    * Update testcase test_meta_schedule_schedule_rule_mlt_tc.py::test_conv_1x1
    
    ---------
    
    Co-authored-by: tsu-bin <tsu...@gmail.com>
---
 .../schedule_rule/multi_level_tiling.cc            | 23 +++++-
 .../schedule_rule/multi_level_tiling.h             |  2 +-
 .../multi_level_tiling_tensor_core.cc              |  2 +-
 .../test_meta_schedule_schedule_rule_mlt_tc.py     | 93 +++++++++++-----------
 4 files changed, 70 insertions(+), 50 deletions(-)

diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc 
b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index 702947ebc0..bcaf4343e2 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -190,7 +190,8 @@ std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> 
MultiLevelTilingNode::SplitLoo
   return {factors, splits};
 }
 
-std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
+std::vector<State> MultiLevelTilingNode::TileLoopNest(State state,
+                                                      int 
tile_inner_most_space_loop_num) const {
   Schedule& sch = state->sch;
   const BlockRV& block_rv = state->block_rv;
   // Step 1. Assuming trivial binding, pair the loops and their iter-var-types
@@ -199,6 +200,16 @@ std::vector<State> 
MultiLevelTilingNode::TileLoopNest(State state) const {
   ICHECK_EQ(loops.size(), iter_types.size());
   // Step 2. For each loop axis, tile it
   int64_t spatial_loop_product = 1;
+
+  int total_spatial_loop_num = 0;
+  std::for_each(iter_types.begin(), iter_types.end(), [&](const auto& 
iter_type) {
+    if (iter_type == IterVarType::kDataPar) total_spatial_loop_num++;
+  });
+  CHECK_GE(total_spatial_loop_num, tile_inner_most_space_loop_num);
+  if (tile_inner_most_space_loop_num < 0) tile_inner_most_space_loop_num = 
total_spatial_loop_num;
+  int outer_most_spatial_loop_skipped_num = total_spatial_loop_num - 
tile_inner_most_space_loop_num;
+
+  Array<LoopRV> skipped_outer_spatial_loops;
   std::vector<Array<LoopRV>> tiles(s_indices_.size() + r_indices_.size());
   state->tile_factors.resize(tiles.size());
   std::vector<Array<tir::ExprRV>> tile_factors;
@@ -208,6 +219,11 @@ std::vector<State> 
MultiLevelTilingNode::TileLoopNest(State state) const {
     const std::vector<int>* idx = nullptr;
 
     if (iter_types[i] == IterVarType::kDataPar) {
+      if (outer_most_spatial_loop_skipped_num > 0) {
+        skipped_outer_spatial_loops.push_back(loop);
+        outer_most_spatial_loop_skipped_num--;
+        continue;
+      }
       idx = &s_indices_;
       if (spatial_loop_product != -1) {
         if (const int64_t* extent = 
tir::GetLoopIntExtent(sch->Get(loop).get())) {
@@ -241,6 +257,11 @@ std::vector<State> 
MultiLevelTilingNode::TileLoopNest(State state) const {
   sch->Reorder(support::ConcatArrayList<LoopRV>(tiles.begin(), tiles.end()));
   // Step 4. Bind the tiles to threads
   int n_binds = std::min(tile_binds.size(), tiles.size());
+  if (skipped_outer_spatial_loops.size() && n_binds) {
+    auto& the_first_tile = tiles[0];
+    the_first_tile.insert(the_first_tile.begin(), 
skipped_outer_spatial_loops.begin(),
+                          skipped_outer_spatial_loops.end());
+  }
   for (int i = 0; i < n_binds; ++i) {
     LoopRV fused = sch->Fuse(tiles[i]);
     sch->Bind(fused, tile_binds[i]);
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h 
b/src/meta_schedule/schedule_rule/multi_level_tiling.h
index 2b06aba9c1..23d6599a25 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.h
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -162,7 +162,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
   // SubRule 1. add write cache
   std::vector<State> AddWriteReuse(State state) const;
   // SubRule 2. tile the loop nest
-  std::vector<State> TileLoopNest(State state) const;
+  std::vector<State> TileLoopNest(State state, int 
tile_inner_most_space_loop_num = -1) const;
   // SubRule 3. add read cache
   std::vector<State> AddReadReuse(State state) const;
   // SubRule 4. add async pipeline
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc 
b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index e3b51dda15..e038ab908d 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -251,7 +251,7 @@ std::vector<State> 
MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector<Sta
   });
   states = SubRule(std::move(states), [&](State state) {
     TensorCoreState tc_state = Downcast<TensorCoreState>(state);
-    return tc_state->is_mma ? MMATileLoopNest(tc_state) : TileLoopNest(state);
+    return tc_state->is_mma ? MMATileLoopNest(tc_state) : TileLoopNest(state, 
2);
   });
   states = SubRule(std::move(states), [&](State state) {
     return TransformIntermediateOutputLayout(Downcast<TensorCoreState>(state));
diff --git 
a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py 
b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
index 034bddd971..da00f294ba 100644
--- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
+++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
@@ -903,39 +903,39 @@ def test_conv_1x1():
     def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: 
T.Buffer((1, 1, 64, 64), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 64), 
"float32")):
         T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
         # with T.block("root"):
-        conv2d_nhwc_reindex_shared = T.alloc_buffer((2, 2, 8, 2, 16, 16), 
scope="shared")
-        conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((2, 2, 8, 
2, 16, 16), scope="wmma.accumulator")
+        conv2d_nhwc_reindex_shared = T.alloc_buffer((2, 1, 8, 4, 16, 16), 
scope="shared")
+        conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((2, 1, 8, 
4, 16, 16), scope="wmma.accumulator")
         PadInput_reindex_shared = T.alloc_buffer((256, 64), "float16", 
scope="shared")
         weight_reindex_shared = T.alloc_buffer((1, 1, 64, 64), "float16", 
scope="shared")
         PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer((256, 64), 
"float16", scope="wmma.matrix_a")
         weight_reindex_shared_wmma_matrix_b = T.alloc_buffer((1, 1, 64, 64), 
"float16", scope="wmma.matrix_b")
-        for ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused in T.thread_binding(4, 
thread="blockIdx.y"):
-            for ax0_1_ax1_1_ax2_0_1_ax3_0_1_fused in T.thread_binding(1, 
thread="blockIdx.x"):
-                for ax0_2_ax1_2_ax2_0_2_ax3_0_2_fused in T.thread_binding(1, 
thread="threadIdx.y"):
-                    for ax4_0_0 in range(1):
+        for ax0_ax1_ax2_0_0_ax3_0_0_fused in T.thread_binding(1, 
thread="blockIdx.y"):
+            for ax2_0_1_ax3_0_1_fused in T.thread_binding(1, 
thread="blockIdx.x"):
+                for ax2_0_2_ax3_0_2_fused in T.thread_binding(2, 
thread="threadIdx.y"):
+                    for ax4_0_0 in range(2):
                         for ax0_ax1_fused in range(8192):
                             with T.block("PadInput_reindex_shared"):
-                                v0 = T.axis.spatial(256, 
ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2 * 128 + ax0_ax1_fused // 64)
-                                v1 = T.axis.spatial(64, ax0_ax1_fused % 64)
+                                v0 = T.axis.spatial(256, ax0_ax1_fused // 32)
+                                v1 = T.axis.spatial(64, ax4_0_0 * 32 + 
ax0_ax1_fused % 32)
                                 T.reads(inputs[0, v0 // 16, v0 % 16, v1])
                                 T.writes(PadInput_reindex_shared[v0, v1])
-                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch": 2})
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch": 8})
                                 PadInput_reindex_shared[v0, v1] = inputs[0, v0 
// 16, v0 % 16, v1]
                         for ax0_ax1_ax2_ax3_fused in range(2048):
                             with T.block("weight_reindex_shared"):
                                 v0 = T.axis.spatial(1, 0)
                                 v1 = T.axis.spatial(1, 0)
-                                v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused 
// 32)
-                                v3 = T.axis.spatial(64, 
ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused % 32)
+                                v2 = T.axis.spatial(64, ax4_0_0 * 32 + 
ax0_ax1_ax2_ax3_fused // 64)
+                                v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused 
% 64)
                                 T.reads(weight[v0, v1, v2, v3])
                                 T.writes(weight_reindex_shared[v0, v1, v2, v3])
-                                T.block_attr({"buffer_dim_align": [[0, 2, 32, 
8]], "meta_schedule.cooperative_fetch": 8})
+                                T.block_attr({"buffer_dim_align": [[0, 2, 32, 
8]], "meta_schedule.cooperative_fetch": 4})
                                 weight_reindex_shared[v0, v1, v2, v3] = 
weight[v0, v1, v2, v3]
                         for ax4_0_1 in range(1):
-                            for ax0_0, ax1_0 in T.grid(8, 4):
+                            for ax0_0, ax1_0 in T.grid(8, 2):
                                 with 
T.block("PadInput_reindex_shared_wmma.matrix_a_o"):
-                                    v0_o = T.axis.spatial(16, 
ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2 * 8 + ax0_0)
-                                    v1_o = T.axis.spatial(4, ax1_0)
+                                    v0_o = T.axis.spatial(16, 
ax2_0_2_ax3_0_2_fused * 8 + ax0_0)
+                                    v1_o = T.axis.spatial(4, ax4_0_0 * 2 + 
ax1_0)
                                     T.reads(PadInput_reindex_shared[v0_o * 
16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
                                     
T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 
16:v1_o * 16 + 16])
                                     
T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_load_16x16x16_f16_a_shared"})
@@ -945,10 +945,11 @@ def test_conv_1x1():
                                             
T.reads(PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
                                             
T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + 
v1_i])
                                             
PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = 
PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
-                            for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 4, 2):
+                            for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 2, 4):
                                 with 
T.block("weight_reindex_shared_wmma.matrix_b_o"):
-                                    v0_o, v1_o, v2_o = T.axis.remap("SSS", 
[ax0, ax1, ax2_0])
-                                    v3_o = T.axis.spatial(4, 
ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2 * 2 + ax3_0)
+                                    v0_o, v1_o = T.axis.remap("SS", [ax0, ax1])
+                                    v2_o = T.axis.spatial(4, ax4_0_0 * 2 + 
ax2_0)
+                                    v3_o = T.axis.spatial(4, ax3_0)
                                     T.reads(weight_reindex_shared[v0_o, v1_o, 
v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16])
                                     
T.writes(weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16:v2_o * 16 + 
16, v3_o * 16:v3_o * 16 + 16])
                                     
T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_load_16x16x16_f16_b_shared"})
@@ -958,38 +959,38 @@ def test_conv_1x1():
                                             
T.reads(weight_reindex_shared[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i])
                                             
T.writes(weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16 + v2_i, v3_o 
* 16 + v3_i])
                                             
weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + 
v3_i] = weight_reindex_shared[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i]
-                            for ax0_3, ax1_3, ax2_0_3, ax3_0_3, ax4_0_2, 
ax0_4, ax1_4, ax2_0_4, ax3_0_4 in T.grid(1, 1, 8, 2, 4, 1, 1, 1, 1):
+                            for ax2_0_3, ax3_0_3, ax4_0_2, ax2_0_4, ax3_0_4 in 
T.grid(8, 1, 2, 1, 4):
                                 with T.block("conv2d_nhwc_o"):
-                                    v0_o = T.axis.spatial(1, ax0_3 + ax0_4)
-                                    v1_o = T.axis.spatial(1, ax1_3 + ax1_4)
-                                    v2_o = T.axis.spatial(16, 
ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2 * 8 + ax2_0_3 + ax2_0_4)
-                                    v3_o = T.axis.spatial(4, 
ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2 * 2 + ax3_0_3 + ax3_0_4)
-                                    v4_o = T.axis.reduce(4, ax4_0_0 * 4 + 
ax4_0_1 * 4 + ax4_0_2)
+                                    v0_o = T.axis.spatial(1, 0)
+                                    v1_o = T.axis.spatial(1, 0)
+                                    v2_o = T.axis.spatial(16, 
ax2_0_2_ax3_0_2_fused * 8 + ax2_0_3 + ax2_0_4)
+                                    v3_o = T.axis.spatial(4, ax3_0_3 * 4 + 
ax3_0_4)
+                                    v4_o = T.axis.reduce(4, ax4_0_0 * 2 + 
ax4_0_1 * 2 + ax4_0_2)
                                     
T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 
16:v4_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 
16:v4_o * 16 + 16, v3_o * 16:v3_o * 16 + 16])
-                                    
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o 
% 8, v3_o % 2, 0:16, 0:16])
+                                    
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, 
v3_o, 0:16, 0:16])
                                     
T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", 
"meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", 
"warp_execution": 1})
                                     with T.init():
                                         for ax2_1, ax3_1 in T.grid(16, 16):
                                             with T.block("conv2d_nhwc_init"):
                                                 v2_i_init, v3_i_init = 
T.axis.remap("SS", [ax2_1, ax3_1])
                                                 T.reads()
-                                                
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o 
% 8, v3_o % 2, v2_i_init, v3_i_init])
-                                                
conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, 
v3_o % 2, v2_i_init, v3_i_init] = T.float32(0)
+                                                
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, 
v3_o, v2_i_init, v3_i_init])
+                                                
conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, 
v2_i_init, v3_i_init] = T.float32(0)
                                     for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 
16):
                                         with T.block("conv2d_nhwc"):
                                             v2_i, v3_i, v4_i = 
T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1])
-                                            
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o 
% 8, v3_o % 2, v2_i, v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + 
v2_i, v4_o * 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 
16 + v4_i, v3_o * 16 + v3_i])
-                                            
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o 
% 8, v3_o % 2, v2_i, v3_i])
+                                            
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, 
v3_o, v2_i, v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o 
* 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, 
v3_o * 16 + v3_i])
+                                            
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, 
v3_o, v2_i, v3_i])
                                             
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
-                                            
conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, 
v3_o % 2, v2_i, v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 
v3_o // 2, v2_o % 8, v3_o % 2, v2_i, v3_i] + T.Cast("float32", 
PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * 
T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + 
v4_i, v3_o * 16 + v3_i])
+                                            
conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, 
v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, 
v3_o, v2_i, v3_i] + T.Cast("float32", 
PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * 
T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + 
v4_i, v3_o * 16 + v3_i])
                 for ax2 in range(8):
-                    for ax0_ax1_fused in T.thread_binding(1, 
thread="threadIdx.y"):
-                        for ax2_1, ax3 in T.grid(1, 2):
+                    for ax0_ax1_fused in T.thread_binding(2, 
thread="threadIdx.y"):
+                        for ax2_1, ax3 in T.grid(1, 4):
                             with 
T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
-                                v0_o = T.axis.spatial(2, 
ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2)
-                                v1_o = T.axis.spatial(2, 
ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2)
+                                v0_o = T.axis.spatial(2, ax0_ax1_fused)
+                                v1_o = T.axis.spatial(1, 0)
                                 v2_o = T.axis.spatial(8, ax2 + ax2_1)
-                                v3_o = T.axis.spatial(2, ax3)
+                                v3_o = T.axis.spatial(4, ax3)
                                 v4_o = T.axis.spatial(1, 0)
                                 v5_o = T.axis.spatial(1, 0)
                                 
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 
0:16, 0:16])
@@ -1001,29 +1002,27 @@ def test_conv_1x1():
                                         
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 
v4_i, v5_i])
                                         
T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
                                         conv2d_nhwc_reindex_shared[v0_o, v1_o, 
v2_o, v3_o, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, 
v1_o, v2_o, v3_o, v4_i, v5_i]
-                    for ax0_ax1_ax3_ax4_ax5_fused in range(512):
+                    for ax0_ax1_ax3_ax4_ax5_fused in range(2048):
                         with T.block("conv2d_nhwc_reindex_shared"):
-                            v0 = T.axis.spatial(2, 
ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2)
-                            v1 = T.axis.spatial(2, 
ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2)
+                            v0 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused 
// 1024)
+                            v1 = T.axis.spatial(1, 0)
                             v2 = T.axis.spatial(8, ax2)
-                            v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused 
// 256)
+                            v3 = T.axis.spatial(4, ax0_ax1_ax3_ax4_ax5_fused % 
1024 // 256)
                             v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 256 // 16)
                             v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 16)
                             T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, 
v4, v5])
-                            T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) 
// 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16 + v1 * 32])
+                            T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) 
// 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16])
                             T.block_attr({"meta_schedule.cooperative_fetch": 
1})
-                            conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, 
(v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16 + v1 * 32] = 
conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]
+                            conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, 
(v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16] = conv2d_nhwc_reindex_shared[v0, 
v1, v2, v3, v4, v5]
     # fmt: on
 
     decision_0 = [
-        ("SamplePerfectTile", [1, 1, 1, 1, 1]),
-        ("SamplePerfectTile", [1, 1, 1, 1, 1]),
-        ("SamplePerfectTile", [2, 1, 1, 8, 1]),
-        ("SamplePerfectTile", [2, 1, 1, 2, 1]),
-        ("SamplePerfectTile", [1, 1, 4]),
+        ("SamplePerfectTile", [1, 1, 2, 8, 1]),
+        ("SamplePerfectTile", [1, 1, 1, 1, 4]),
+        ("SamplePerfectTile", [2, 1, 2]),
         ("SampleCategorical", 0),
-        ("SampleCategorical", 1),
         ("SampleCategorical", 3),
+        ("SampleCategorical", 2),
     ]
 
     mod = te.create_prim_func(

Reply via email to