This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 424c749a3d [MetaSchedule] Tile and pack intermediate output for CUDA TensorCore (#14108)
424c749a3d is described below
commit 424c749a3dac0ba42e89d3cbd04b024658d7d104
Author: Wuwei Lin <[email protected]>
AuthorDate: Mon Mar 6 03:24:33 2023 -0800
    [MetaSchedule] Tile and pack intermediate output for CUDA TensorCore (#14108)
* [MetaSchedule] Tile and pack intermediate output for CUDA TensorCore
* clean up schedule rule mltc
* add lhs analyzer
* prevent simplifying single point
* clean up
* lint
* fix rewrite_tensorize test
* fix software pipeline test
* fix compile on mac
* fix test cases
* remove unused
* rebase
* only use json format for roundtrip
* lint
* Update src/tir/schedule/ir_comparator.h
Co-authored-by: Siyuan Feng <[email protected]>
---------
Co-authored-by: Tianqi Chen <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
---
.../tvm/meta_schedule/testing/space_generation.py | 2 +-
src/meta_schedule/postproc/verify_gpu_code.cc | 1 +
.../schedule_rule/multi_level_tiling.cc | 13 +-
.../schedule_rule/multi_level_tiling.h | 8 +-
.../multi_level_tiling_tensor_core.cc | 176 ++++-
.../multi_level_tiling_wide_vector.cc | 15 +-
src/tir/analysis/block_access_region_detector.cc | 9 +-
src/tir/schedule/ir_comparator.cc | 10 +-
src/tir/schedule/ir_comparator.h | 7 +-
.../test_meta_schedule_schedule_rule_mlt_tc.py | 783 +++++++++------------
.../test_tir_transform_inject_software_pipeline.py | 16 +-
11 files changed, 567 insertions(+), 473 deletions(-)
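[A rough Python sketch, illustrative only and not part of the commit, of the packed intermediate-output layout this change introduces: a 2D output index (m, n) is remapped to [warp_m, warp_n, accum_frag_m, accum_frag_n, accum_elem_m, accum_elem_n]. The fragment shape comes from the wmma intrin and the per-warp fragment counts from the sampled tile factors; the constants below are example values.]

    # Hypothetical standalone model of the index map built in
    # multi_level_tiling_tensor_core.cc below; FRAG_* and WARP_FRAG_* are
    # example values, not taken from the commit.
    FRAG_M, FRAG_N = 16, 16          # wmma accumulator fragment shape
    WARP_FRAG_M, WARP_FRAG_N = 2, 1  # fragments owned by each warp

    def pack_index(m, n):
        accum_m, accum_n = m % FRAG_M, n % FRAG_N    # element within a fragment
        outer_m, outer_n = m // FRAG_M, n // FRAG_N  # fragment index
        return (outer_m // WARP_FRAG_M, outer_n // WARP_FRAG_N,  # warp id
                outer_m % WARP_FRAG_M, outer_n % WARP_FRAG_N,    # fragment in warp
                accum_m, accum_n)                                # element in fragment

    assert pack_index(33, 17) == (1, 1, 0, 0, 1, 1)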
diff --git a/python/tvm/meta_schedule/testing/space_generation.py b/python/tvm/meta_schedule/testing/space_generation.py
index 0b7072b65a..45cd6659b6 100644
--- a/python/tvm/meta_schedule/testing/space_generation.py
+++ b/python/tvm/meta_schedule/testing/space_generation.py
@@ -88,7 +88,7 @@ def _find_match_sketch_id(
decisions=new_decisions,
).apply_to_schedule(sch, remove_postproc=True)
if structural_equal(sch.mod, expected_mod):
-            verify_trace_roundtrip(sch=sch, mod=mod, debug_mask=debug_mask)
+            verify_trace_roundtrip(sch=sch, mod=mod, debug_mask=debug_mask, text_format="json")
return sketch_id
return None
diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc
index 99ffc1bfcd..6f9b46a0f7 100644
--- a/src/meta_schedule/postproc/verify_gpu_code.cc
+++ b/src/meta_schedule/postproc/verify_gpu_code.cc
@@ -162,6 +162,7 @@ class VerifyGPUCodeNode : public PostprocNode {
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
pass_list.push_back(tir::transform::UnifyThreadBinding());
+    pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage());
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index 779114e9cf..0312c100b5 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -186,15 +186,15 @@ std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
return results;
}
-Array<tir::LoopRV> MultiLevelTilingNode::SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop,
-                                                   int n_tiles) const {
+std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> MultiLevelTilingNode::SplitLoop(
+    const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const {
Array<tir::ExprRV> factors = sch->SamplePerfectTile(
/*loop=*/loop,
/*n=*/n_tiles,
/*max_innermost_factor=*/max_innermost_factor);
   Array<tir::LoopRV> splits = sch->Split(/*loop=*/loop,
                                          /*factors=*/{factors.begin(), factors.end()});
- return splits;
+ return {factors, splits};
}
std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
@@ -207,6 +207,9 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
// Step 2. For each loop axis, tile it
int64_t spatial_loop_product = 1;
std::vector<Array<LoopRV>> tiles(s_indices_.size() + r_indices_.size());
+ state->tile_factors.resize(tiles.size());
+ std::vector<Array<tir::ExprRV>> tile_factors;
+ tile_factors.resize(tiles.size());
for (int i = 0, n = loops.size(); i < n; ++i) {
LoopRV loop = loops[i];
const std::vector<int>* idx = nullptr;
@@ -231,14 +234,16 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
if (n_tiles == 1) {
tiles[idx->at(0)].push_back(loop);
} else {
- auto splits = SplitLoop(sch, block_rv, loop, n_tiles);
+ auto [factors, splits] = SplitLoop(sch, block_rv, loop, n_tiles);
// Put every tile to its slot
for (int j = 0; j < n_tiles; ++j) {
tiles[idx->at(j)].push_back(splits[j]);
+ tile_factors[idx->at(j)].push_back(factors[j]);
}
}
}
+ state->tile_factors = std::move(tile_factors);
// Step 3. Reorder to organize the tiles
sch->Reorder(support::ConcatArrayList<LoopRV>(tiles.begin(), tiles.end()));
// Step 4. Bind the tiles to threads
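[A minimal Python analogue, with hypothetical names rather than the commit's API, of the change in this file: SplitLoop now returns the sampled factors together with the split loops, and TileLoopNest buckets each factor into the same per-level slot as its loop.]

    def tile_loop_nest(loops, slots_of, split_loop, n_levels):
        """slots_of(loop) -> tile-level slots; split_loop(loop, n) -> (factors, splits)."""
        tiles = [[] for _ in range(n_levels)]
        tile_factors = [[] for _ in range(n_levels)]
        for loop in loops:
            slots = slots_of(loop)
            factors, splits = split_loop(loop, len(slots))
            for j, slot in enumerate(slots):
                tiles[slot].append(splits[j])          # put every tile in its slot
                tile_factors[slot].append(factors[j])  # and record its factor
        return tiles, tile_factors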
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h
index ff38756ff0..41b3ca9f26 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.h
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -94,6 +94,8 @@ class StateNode : public Object {
tir::BlockRV block_rv;
/*! \brief The loop tiles */
Array<Array<tir::LoopRV>> tiles;
+ /*! \brief The factors of the loop tiles. */
+ Array<Array<tir::ExprRV>> tile_factors;
/*! \brief The mapping from buffer index to read cache block. */
std::unordered_map<int, tir::BlockRV> read_reuse;
/*! \brief The mapping from buffer index to write cache block. */
@@ -163,8 +165,10 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
protected:
virtual std::vector<State> ApplySubRules(std::vector<State> states);
-  virtual Array<tir::LoopRV> SplitLoop(const tir::Schedule& sch, tir::BlockRV block,
-                                       tir::LoopRV loop, int n_tiles) const;
+  virtual std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> SplitLoop(const tir::Schedule& sch,
+                                                                      tir::BlockRV block,
+                                                                      tir::LoopRV loop,
+                                                                      int n_tiles) const;
// Annotate a block to use cooperative fetching
   void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::BlockRV& block) const;
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index d5cca52d41..1f9945022b 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/meta_schedule/schedule_rule.h>
+#include <tvm/tir/op.h>
#include <algorithm>
#include <utility>
@@ -124,6 +125,9 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode {
private:
// SubRule: Add tensorization-related transformations
   inline std::vector<State> TransformForTensorization(TensorCoreState state) const;
+  // Subrule: Transform the layout of the output. This is necessary for efficient cache write of
+  // the output in the shared memory.
+  std::vector<State> TransformIntermediateOutputLayout(TensorCoreState state);
// Subrule: Add tensorized load
   inline std::vector<State> AddReadReuseTensorCore(TensorCoreState state) const;
// Subrule: Add tensorized store
@@ -225,6 +229,9 @@ std::vector<State> MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector<Sta
return TransformForTensorization(Downcast<TensorCoreState>(state));
});
   states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); });
+ states = SubRule(std::move(states), [&](State state) {
+ return TransformIntermediateOutputLayout(Downcast<TensorCoreState>(state));
+ });
   states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); });
states = SubRule(std::move(states), [&](State state) {
return AddWriteReuseTensorCore(Downcast<TensorCoreState>(state));
@@ -248,25 +255,162 @@ void MultiLevelTilingTensorCoreNode::TileAndAnnotateTensorize(Schedule* sch,
   (*sch)->Annotate(blockized_outer, tir::attr::meta_schedule_auto_tensorize, intrin_name);
}
+std::vector<State> MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLayout(
+ TensorCoreState state) {
+  // Transform the intermediate output to packed layout
+  // [..., warp_m, warp_n, accum_frag_m, accum_frag_n, accum_elem_m, accum_elem_n]
+  // where warp_m, warp_n are thread indices bound to the warp id, accum_frag_m, accum_frag_n are
+  // the indices of the fragments in each warp, and accum_elem_m, accum_elem_n are the indices of
+  // the elements in each accumulator fragment.
+
+ // Get the shape of the wmma accumulator
+ auto [frag_shape_m, frag_shape_n] = [&]() {
+    tir::Block intrin_block =
+        Downcast<tir::BlockRealize>(
+            tir::TensorIntrin::Get(state->intrin_group.init_intrin).value()->desc->body)
+            ->block;
+ tir::For loop_m = Downcast<tir::For>(intrin_block->body);
+ tir::For loop_n = Downcast<tir::For>(loop_m->body);
+ return std::make_tuple(loop_m->extent, loop_n->extent);
+ }();
+
+ // Get the tile index of the warp id (i.e. threadIdx.y)
+ auto it = std::find(tile_binds.begin(), tile_binds.end(), "threadIdx.y");
+ ICHECK(it != tile_binds.end());
+ auto tile_index_warp_id = std::distance(tile_binds.begin(), it);
+
+ // Get the extent of loop indicated by `loop_idx` inside the warp scope.
+ // For example, after spatial loops i, j are tiled, we will have
+ // tile_factors = ((i0, j0), (i1, j1), ..., (in, jn))
+  // This function computes the product of tile_factors[i][loop_idx] for i > tile_index_warp_id.
+ // `loop_idx` can be negative, in which case it is counted from the end.
+ auto f_get_inner_tile_product = [&](int loop_idx) {
+ Array<tir::ExprRV> factors;
+    for (int i = tile_index_warp_id + 1; i < static_cast<int>(s_indices_.size()); ++i) {
+ auto s_factors = state->tile_factors[s_indices_[i]];
+ if (loop_idx < 0) {
+ loop_idx += s_factors.size();
+ }
+ factors.push_back(s_factors[loop_idx]);
+ }
+ ICHECK(!factors.empty());
+ if (factors.size() == 1) {
+ return factors[0];
+ }
+ auto result = factors[0];
+ for (int i = 1; i < static_cast<int>(factors.size()); ++i) {
+ result = result * factors[i];
+ }
+ return result;
+ };
+
+  // Compute the number of output fragments of each warp
+ auto warp_num_frag_m = f_get_inner_tile_product(-2);
+ auto warp_num_frag_n = f_get_inner_tile_product(-1);
+
+ Schedule& sch = state->sch;
+  int buffer_ndim = static_cast<int>(sch->Get(state->block_rv)->writes[0]->buffer->shape.size());
+  // The dimension of the buffer should be larger than or equal to that of the tensor intrin.
+ ICHECK_GE(buffer_ndim, 2);
+ int num_higher_dims = buffer_ndim - 2;
+
+  auto index_map =
+      tir::IndexMap::FromFunc(buffer_ndim,
+                              // frag_shape_m and frag_shape_n are structured bindings that
+                              // cannot be automatically captured until C++20
+                              [&, frag_shape_m = frag_shape_m,
+                               frag_shape_n = frag_shape_n](const Array<tir::Var>& indices) {
+                                Array<PrimExpr> result;
+                                result.reserve(indices.size() + 4);
+                                for (int i = 0; i < num_higher_dims; ++i) {
+                                  result.push_back(indices[i]);
+                                }
+                                const auto& m = indices[num_higher_dims];
+                                const auto& n = indices[num_higher_dims + 1];
+                                auto accum_m = floormod(m, frag_shape_m);
+                                auto accum_n = floormod(n, frag_shape_n);
+                                auto outer_m = floordiv(m, frag_shape_m);
+                                auto outer_n = floordiv(n, frag_shape_n);
+
+                                result.push_back(floordiv(outer_m, warp_num_frag_m));
+                                result.push_back(floordiv(outer_n, warp_num_frag_n));
+                                result.push_back(floormod(outer_m, warp_num_frag_m));
+                                result.push_back(floormod(outer_n, warp_num_frag_n));
+                                result.push_back(accum_m);
+                                result.push_back(accum_n);
+                                return result;
+                              });
+  sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, index_map,
+                       /*pad_value=*/NullOpt, /*assume_injective_transform=*/true);
+
+ return {state};
+}
+
std::vector<State> MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore(
TensorCoreState state) const {
// Add the cache write stage for Tensor Core
- int level = r_indices_.front() - 1;
- const LoopRV& loop = state->tiles[level].back();
Schedule& sch = state->sch;
auto cache_write = sch->CacheWrite(state->block_rv, 0, "wmma.accumulator");
- sch->ReverseComputeAt(cache_write, loop, true);
-
- if (state->write_reuse.count(0)) {
- // Fuse the iterators of the cache_write
- Array<LoopRV> buffer_loops = sch->GetLoops(state->write_reuse[0]);
- ICHECK_GT(buffer_loops.size(), 2);
-    sch->Fuse(Array<LoopRV>{buffer_loops.end() - 2,  // The src shmem is always 2D
-                            buffer_loops.end()});
- AnnotateCooperativeFetching(&sch, state->write_reuse[0]);
+
+  // The compute block has been tiled by the warp shape and the fragment shape.
+  // We need to bind the cache write block (from the accumulator to the shared memory) to the warp
+  // id. The schedule is as follows:
+  //
+  // After adding cache write for wmma.accumulator, we will have
+  //   for i0, j0, i1, j1, accum_m, accum_n:
+  //     shared_mem[i0, j0, i1, j1, accum_m, accum_n] = accum[i0, j0, i1, j1, accum_m, accum_n]
+  //   for i0', j0', i1', j1', accum_m', accum_n':
+  //     global_mem[i0', j0', i1', j1', accum_m', accum_n'] =
+  //         shared_mem[i0', j0', i1', j1', accum_m', accum_n']
+  // where i0' and j0' are already bound to the block id and warp id.
+  //
+  // To reduce the shared memory usage and allow efficient data movement, we will apply
+  // transformations to generate the following schedule:
+  //
+  //   for i1':
+  //     for i0_j0 (fused and bound to threadIdx.y):
+  //       for j1, accum_m, accum_n:
+  //         shared_mem[i0, j0, i1, j1, accum_m, accum_n] = accum[i0, j0, i1, j1, accum_m, accum_n]
+  //     for i0', j0', j1', accum_m', accum_n':
+  //       global_mem[i0', j0', i1', j1', accum_m', accum_n'] =
+  //           shared_mem[i0', j0', i1', j1', accum_m', accum_n']
+  //
+  // i1' is reordered to the outermost. This effectively allows only a row (i.e. loop i1') of the
+  // fragments to be moved to the shared memory and then to the global memory each time.
+  // As a result, the shared memory for the output will only have a shape of [j1, accum_m, accum_n]
+  // instead of [i0 * i1 * accum_m, j0 * j1 * accum_n].
+
+ // Get the loops other than the innermost two loops (accum_m and accum_n).
+ auto f_get_loops = [&](const BlockRV& block_rv) -> std::array<LoopRV, 4> {
+ Array<LoopRV> buffer_loops = sch->GetLoops(block_rv);
+ ICHECK_GT(buffer_loops.size(), 6);
+    return {buffer_loops[buffer_loops.size() - 6], buffer_loops[buffer_loops.size() - 5],
+            buffer_loops[buffer_loops.size() - 4], buffer_loops[buffer_loops.size() - 3]};
+ };
+ {
+ const auto& [i0, j0, i1, j1] = f_get_loops(state->write_reuse[0]);
+ sch->Reorder({i1, i0, j0, j1});
+ sch->ComputeAt(cache_write, i1, true);
+ }
+ {
+ auto loops = f_get_loops(cache_write);
+ const auto& i0 = loops[0];
+ const auto& j0 = loops[1];
+ auto fused = sch->Fuse({i0, j0});
+ sch->Bind(fused, "threadIdx.y");
}
+
sch->ReverseComputeInline(state->tensor_core_reindex_store);
-  TileAndAnnotateTensorize(&sch, cache_write, state->intrin_group.store_intrin);
+ auto loops = sch->GetLoops(cache_write);
+ auto blockized_store = sch->Blockize(loops[loops.size() - 2]);
+  sch->Annotate(blockized_store, tir::attr::meta_schedule_auto_tensorize,
+                state->intrin_group.store_intrin);
+
+ Array<LoopRV> buffer_loops = sch->GetLoops(state->write_reuse[0]);
+ ICHECK_GT(buffer_loops.size(), 5);
+  sch->Fuse(Array<LoopRV>{buffer_loops.end() - 5,  // The src shmem is always 2D
+                          buffer_loops.end()});
+ AnnotateCooperativeFetching(&sch, state->write_reuse[0]);
return {state};
}
@@ -508,7 +652,8 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
         state->sch->state(), GetRef<tir::Block>(block), buffer_index, index_type);
     auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region);
buffer_sub_index_map.Set(lhs_buffer, sub_index_map);
-    state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map, NullOpt);
+    state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map,
+                                /*pad_value=*/NullOpt, /*assume_injective_transform=*/true);
};
for (int i = 0, n = block_before_reindex->reads.size(); i < n; ++i) {
@@ -569,6 +714,11 @@ ScheduleRule ScheduleRule::MultiLevelTilingTensorCore(
auto node = MultiLevelTilingInitCommon<MultiLevelTilingTensorCoreNode>(
       structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
+  CHECK(node->reuse_write_.req == ReuseType::kMustReuse &&
+        runtime::StorageScope::Create(node->reuse_write_.scope).rank ==
+            runtime::StorageRank::kShared)
+      << "ValueError: Shared memory write reuse must be enabled for MultiLevelTilingTensorCore.";
+
node->intrin_groups.reserve(intrin_groups.size());
for (const auto& intrin_group_config : intrin_groups) {
node->intrin_groups.emplace_back(TensorCoreIntrinGroup::FromConfig(intrin_group_config));
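[A small Python model, illustrative only, of f_get_inner_tile_product above: the per-warp fragment count along a loop is the product of that loop's tile factors over all spatial levels inside the warp-id level, with a negative loop_idx counting from the end as in the C++.]

    def inner_tile_product(tile_factors, s_indices, tile_index_warp_id, loop_idx):
        result = 1
        for i in s_indices[tile_index_warp_id + 1:]:
            s_factors = tile_factors[i]
            result *= s_factors[loop_idx]  # Python indexes negatives from the end
        return result

    # e.g. spatial factors [(4, 8), (2, 1), (1, 1), (2, 1), (1, 1)] with the warp
    # level at index 2 give warp_num_frag_m == 2 (loop_idx=-2) and
    # warp_num_frag_n == 1 (loop_idx=-1).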
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc
index d4c4a10fdd..e68b64ea2d 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc
@@ -48,11 +48,12 @@ class MultiLevelTilingWideVectorNode : public MultiLevelTilingNode {
return ScheduleRule(n);
}
-  Array<tir::LoopRV> SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const;
+  std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> SplitLoop(const Schedule& sch, BlockRV block,
+                                                              LoopRV loop, int n_tiles) const;
};
-Array<tir::LoopRV> MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch, BlockRV block_rv,
-                                                             LoopRV loop_rv, int n_tiles) const {
+std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> MultiLevelTilingWideVectorNode::SplitLoop(
+    const Schedule& sch, BlockRV block_rv, LoopRV loop_rv, int n_tiles) const {
const tir::ForNode* loop = TVM_SREF_TO_FOR(sch->GetSRef(loop_rv));
const tir::StmtSRef block_sref = sch->GetSRef(block_rv);
const tir::BlockNode* block_node = block_sref->StmtAs<tir::BlockNode>();
@@ -99,12 +100,14 @@ Array<tir::LoopRV> MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch
     Array<tir::LoopRV> outer_splits = sch->Split(
         /*loop=*/inner_splits[0], /*factors=*/{outer_factors.begin(), outer_factors.end()});
outer_splits.push_back(inner_splits[1]);
- return outer_splits;
+ outer_factors.push_back(PrimExpr(vec_len));
+ return {outer_factors, outer_splits};
} else {
Array<tir::ExprRV> factors(n_tiles - 1, PrimExpr(1));
factors.push_back(loop->extent);
- return sch->Split(/*loop=*/loop_rv,
- /*factors=*/{factors.begin(), factors.end()});
+    Array<tir::LoopRV> splits = sch->Split(/*loop=*/loop_rv,
+                                           /*factors=*/{factors.begin(), factors.end()});
+ return {factors, splits};
}
}
}
diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc
index e9bff1b6fd..ab328efaa6 100644
--- a/src/tir/analysis/block_access_region_detector.cc
+++ b/src/tir/analysis/block_access_region_detector.cc
@@ -76,8 +76,6 @@ class BlockReadWriteDetector : public StmtExprVisitor {
Map<Var, Buffer> buffer_var_map_;
/*! \brief The target buffer var mapping to its matching */
std::unordered_map<const VarNode*, MatchBufferRegion> match_buffers_;
- /*! \brief The analyzer for simplifying*/
- arith::Analyzer analyzer_;
/*!
   * \brief Update read/write buffers and regions with provided buffer and region
@@ -330,7 +328,12 @@ Array<BufferRegion> BlockReadWriteDetector::CollectRegions(
ICHECK_EQ(buffers[i]->shape.size(), regions[i].size());
for (size_t j = 0; j < regions[i].size(); j++) {
const tvm::arith::IntSet& range = regions[i][j];
-      region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j])));
+ if (range.IsSinglePoint()) {
+ PrimExpr min = range.min();
+        region.push_back(Range::FromMinExtent(min, make_const(min.dtype(), 1)));
+ } else {
+        region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j])));
+ }
}
res.push_back(BufferRegion(buffers[i], region));
}
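[In Python terms, with hypothetical helper names mirroring the C++ above, the detector now keeps a provably single-point access as an exact unit-extent range instead of widening it to the buffer's full extent:]

    def region_dim(int_set, shape_j):
        # int_set is the inferred IntSet of one buffer dimension's indices.
        if int_set.is_single_point():          # e.g. the access is exactly [i, i]
            return (int_set.min, 1)            # Range::FromMinExtent(min, 1)
        return int_set.cover_range((0, shape_j))  # previous behavior otherwise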
diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc
index 9d89c64163..5353a051a6 100644
--- a/src/tir/schedule/ir_comparator.cc
+++ b/src/tir/schedule/ir_comparator.cc
@@ -43,7 +43,7 @@ class TensorIntrinMismatchError : public ScheduleError {
std::ostringstream os;
     os << "The stmt {0} doesn't match the tensor intrin\nThe pattern attempting to be matched:\n"
<< lhs_stmt_ << "\nDoes not match the tensorize description:\n"
- << rhs_stmt_;
+ << rhs_stmt_ << '\n';
for (const auto& msg : error_messages_) {
os << msg << std::endl;
}
@@ -173,6 +173,9 @@ bool TensorizeComparator::VisitStmt_(const BlockRealizeNode* op, const Stmt& oth
bool TensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) {
const auto* rhs = other.as<BlockNode>();
+ for (const IterVar& iter : op->iter_vars) {
+ lhs_analyzer_.Bind(iter->var, iter->dom);
+ }
// Check block equality.
// All iter vars and buffer regions including the order should match.
// When checking iter vars, DefEqual is used to remap variables.
@@ -465,7 +468,7 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf
}
return false;
}
- if (!analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) {
+ if (!lhs_analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) {
if (assert_mode_) {
std::ostringstream os;
         os << "Buffer base index consistency check failed due to unequal index base: "
@@ -487,7 +490,8 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf
}
return false;
}
-      PrimExpr normalized_lhs_min = (lhs->region[i + offset]->min - indices_base[i + offset]);
+      PrimExpr normalized_lhs_min =
+          lhs_analyzer_.Simplify((lhs->region[i + offset]->min - indices_base[i + offset]));
if (!analyzer_.CanProveEqual(normalized_lhs_min, rhs->region[i]->min)) {
if (assert_mode_) {
std::ostringstream os;
diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h
index 394d828673..debf0f946e 100644
--- a/src/tir/schedule/ir_comparator.h
+++ b/src/tir/schedule/ir_comparator.h
@@ -102,8 +102,13 @@ class TensorizeComparator : public ExprComparator, public StmtComparator {
bool assert_mode_;
/*! \brief Whether it is visiting the scope block (the outermost block). */
bool is_scope_block = true;
- /*! \brief The arithmetic analyzer. */
+ /*! \brief The arithmetic analyzer for comparing LHS and RHS */
arith::Analyzer analyzer_;
+ /*!
+ * \brief The arithmetic analyzer for simplifying expressions on LHS.
+ * This analyzer only contains the domains of the iterators on LHS.
+ */
+ arith::Analyzer lhs_analyzer_;
/*! \brief Additional error messages. Only used when assert_mode is true. */
std::vector<std::string> error_messages_;
// variable remap if any
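[A quick illustration of what the new lhs_analyzer_ buys, shown with TVM's Python arith.Analyzer, which wraps the same C++ class: once the LHS block's iterator domains are bound, expressions over those iterators can be simplified and compared reliably.]

    import tvm
    from tvm import arith, tir

    ana = arith.Analyzer()
    i = tir.Var("i", "int32")
    ana.bind(i, tvm.ir.Range(0, 16))          # iter domain taken from the LHS block
    print(ana.simplify(tir.floordiv(i, 16)))  # 0 -- provable only with the domain bound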
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
index 9b869b4436..1cab2554e8 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
@@ -83,39 +83,39 @@ def test_matmul_relu(shared_scope):
@T.prim_func
     def matmul_relu_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), compute: T.Buffer((128, 128), "float32")) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope=shared_scope)
-        C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator")
-        A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope=shared_scope)
-        B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope=shared_scope)
-        A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a")
-        B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b")
+        C_reindex_shared = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope=shared_scope)
+        C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope="wmma.accumulator")
+        A_reindex_shared = T.alloc_buffer((128, 128), "float16", scope=shared_scope)
+        B_reindex_shared = T.alloc_buffer((128, 128), "float16", scope=shared_scope)
+        A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_a")
+        B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_b")
for ax0_0_0_ax1_0_0_fused in T.thread_binding(8, thread="blockIdx.y"):
             for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"):
                 for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"):
- for ax2_0_0 in T.serial(1):
- for ax0_ax1_fused in T.serial(4096):
+ for ax2_0_0 in range(1):
+ for ax0_ax1_fused in range(4096):
with T.block("A_reindex_shared"):
                                 v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 128)
                                 v1 = T.axis.spatial(128, ax0_ax1_fused % 128)
                                 T.reads(A[v0, v1])
                                 T.writes(A_reindex_shared[v0, v1])
-                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8})
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8})
                                 A_reindex_shared[v0, v1] = A[v0, v1]
-                        for ax0_ax1_fused in T.serial(4096):
+                        for ax0_ax1_fused in range(4096):
                             with T.block("B_reindex_shared"):
                                 v0 = T.axis.spatial(128, ax0_ax1_fused // 32)
                                 v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32)
                                 T.reads(B[v0, v1])
                                 T.writes(B_reindex_shared[v0, v1])
-                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1})
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1})
                                 B_reindex_shared[v0, v1] = B[v0, v1]
-                    for ax2_0_1 in T.serial(4):
+                    for ax2_0_1 in range(4):
                         for ax0_0, ax1_0 in T.grid(2, 2):
                             with T.block("A_reindex_shared_wmma.matrix_a_o"):
                                 v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0)
                                 v1_o = T.axis.spatial(8, ax2_0_1 * 2 + ax1_0)
-                                T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                T.reads(A_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
                                 T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"})
for ax0_1, ax1_1 in T.grid(16, 16):
                                     with T.block("A_reindex_shared_wmma.matrix_a"):
@@ -127,8 +127,8 @@ def test_matmul_relu(shared_scope):
                             with T.block("B_reindex_shared_wmma.matrix_b_o"):
                                 v0_o = T.axis.spatial(8, ax2_0_1 * 2 + ax0_0)
                                 v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0)
-                                T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
                                 T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"})
for ax0_1, ax1_1 in T.grid(16, 16):
                                     with T.block("B_reindex_shared_wmma.matrix_b"):
@@ -141,44 +141,54 @@ def test_matmul_relu(shared_scope):
                                         v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0_3 * 2 + ax0_0_4)
                                         v1_o = T.axis.spatial(8, ax1_0_4 + ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0_3)
                                         v2_o = T.axis.reduce(8, ax2_0_0 * 8 + ax2_0_1 * 2 + ax2_0_2)
-                                        T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                        T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                        T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1})
+                                        T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                        T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, 0:16, 0:16])
+                                        T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1})
with T.init():
for ax0_1, ax1_1 in T.grid(16, 16):
with T.block("C_init"):
                                                     v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1])
T.reads()
-                                                    T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init])
-                                                    C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0)
+                                                    T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, v1_i_init])
+                                                    C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, v1_i_init] = T.float32(0)
                                         for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16):
with T.block("C"):
                                                 v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
-                                                T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
-                                                T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-                                                T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
-                                                C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32")
-                    for ax0_0, ax1_0 in T.grid(2, 1):
-                        with T.block("C_reindex_shared_wmma.accumulator_o"):
-                            v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0)
-                            v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0)
-                            T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                            T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                            T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"})
-                            for ax0_1, ax1_1 in T.grid(16, 16):
-                                with T.block("C_reindex_shared_wmma.accumulator"):
-                                    v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
-                                    T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-                                    T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-                                    C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
-                for ax0_ax1_fused in T.serial(1024):
-                    with T.block("C_reindex_shared"):
-                        v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 32)
-                        v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32)
-                        T.reads(C_reindex_shared[v0, v1])
-                        T.writes(compute[v0, v1])
-                        T.block_attr({"meta_schedule.cooperative_fetch":4})
-                        compute[v0, v1] = T.max(C_reindex_shared[v0, v1], T.float32(0))
+                                                T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+                                                T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i])
+                                                T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
+                                                C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+ for ax2 in range(2):
+                        for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"):
+                            for ax2_1, ax3 in T.grid(1, 1):
+                                with T.block("C_reindex_shared_wmma.accumulator_o"):
+                                    v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2)
+                                    v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_fused)
+                                    v2 = T.axis.spatial(2, ax2 + ax2_1)
+                                    v3 = T.axis.spatial(1, ax3)
+                                    v4_o = T.axis.spatial(1, 0)
+                                    v5_o = T.axis.spatial(1, 0)
+                                    T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16])
+                                    T.writes(C_reindex_shared[v0, v1, v2, v3, 0:16, 0:16])
+                                    T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"})
+                                    for ax4, ax5 in T.grid(16, 16):
+                                        with T.block("C_reindex_shared_wmma.accumulator"):
+                                            v4_i, v5_i = T.axis.remap("SS", [ax4, ax5])
+                                            T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i])
+                                            T.writes(C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i])
+                                            C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]
+                        for ax0_ax1_ax3_ax4_ax5_fused in range(512):
+                            with T.block("C_reindex_shared"):
+                                v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2)
+                                v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_ax3_ax4_ax5_fused // 256)
+                                v2 = T.axis.spatial(2, ax2)
+                                v3 = T.axis.spatial(1, 0)
+                                v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16)
+                                v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16)
+                                T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5])
+                                T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16])
+                                T.block_attr({"meta_schedule.cooperative_fetch": 4})
+                                compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0))
decision_0 = [
@@ -223,44 +233,42 @@ def test_matmul_relu_with_fallback():
# fmt: off
@T.prim_func
     def matmul_relu_fallback_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), compute: T.Buffer((128, 128), "float32")) -> None:
- # function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
- # body
- # with T.block("root")
-        C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
-        C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator")
-        A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared")
-        B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared")
-        A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a")
-        B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b")
+        # with T.block("root"):
+        C_reindex_shared = T.alloc_buffer((4, 2, 2, 4, 16, 16), scope="shared")
+        C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 2, 2, 4, 16, 16), scope="wmma.accumulator")
+        A_reindex_shared = T.alloc_buffer((128, 128), "float16", scope="shared")
+        B_reindex_shared = T.alloc_buffer((128, 128), "float16", scope="shared")
+        A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_a")
+        B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_b")
for ax0_0_0_ax1_0_0_fused in T.thread_binding(2, thread="blockIdx.y"):
             for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"):
                 for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"):
- for ax2_0_0 in T.serial(2):
- for ax0_ax1_fused in T.serial(2048):
+ for ax2_0_0 in range(2):
+ for ax0_ax1_fused in range(2048):
with T.block("A_reindex_shared"):
                                 v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused // 64)
                                 v1 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused % 64)
T.reads(A[v0, v1])
T.writes(A_reindex_shared[v0, v1])
-                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":4})
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 4})
A_reindex_shared[v0, v1] = A[v0, v1]
- for ax0_ax1_fused in T.serial(8192):
+ for ax0_ax1_fused in range(8192):
with T.block("B_reindex_shared"):
                                 v0 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused // 128)
v1 = T.axis.spatial(128, ax0_ax1_fused % 128)
T.reads(B[v0, v1])
T.writes(B_reindex_shared[v0, v1])
-                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":2})
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2})
B_reindex_shared[v0, v1] = B[v0, v1]
- for ax2_0_1 in T.serial(1):
+ for ax2_0_1 in range(1):
for ax0_0, ax1_0 in T.grid(2, 4):
                             with T.block("A_reindex_shared_wmma.matrix_a_o"):
                                 v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0)
                                 v1_o = T.axis.spatial(8, ax2_0_0 * 4 + ax1_0)
-                                T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"})
+                                T.reads(A_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"})
                                 for ax0_1, ax1_1 in T.grid(16, 16):
                                     with T.block("A_reindex_shared_wmma.matrix_a"):
                                         v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
@@ -271,9 +279,9 @@ def test_matmul_relu_with_fallback():
                             with T.block("B_reindex_shared_wmma.matrix_b_o"):
                                 v0_o = T.axis.spatial(8, ax2_0_0 * 4 + ax0_0)
                                 v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0)
-                                T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"})
+                                T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"})
                                 for ax0_1, ax1_1 in T.grid(16, 16):
                                     with T.block("B_reindex_shared_wmma.matrix_b"):
                                         v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
@@ -285,44 +293,54 @@ def test_matmul_relu_with_fallback():
                                         v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_3 * 2 + ax0_0_4)
                                         v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0_3 * 4 + ax1_0_4)
                                         v2_o = T.axis.reduce(8, ax2_0_0 * 4 + ax2_0_1 * 4 + ax2_0_2)
-                                        T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                        T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                        T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1})
+                                        T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                        T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, 0:16, 0:16])
+                                        T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1})
with T.init():
for ax0_1, ax1_1 in T.grid(16, 16):
with T.block("C_init"):
                                                     v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1])
                                                     T.reads()
-                                                    T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init])
-                                                    C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0)
+                                                    T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i_init, v1_i_init])
+                                                    C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i_init, v1_i_init] = T.float32(0)
                                         for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16):
                                             with T.block("C"):
                                                 v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
-                                                T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
-                                                T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-                                                T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
-                                                C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32")
-                    for ax0_0, ax1_0 in T.grid(2, 4):
-                        with T.block("C_reindex_shared_wmma.accumulator_o"):
-                            v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0)
-                            v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0)
-                            T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                            T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                            T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"})
-                            for ax0_1, ax1_1 in T.grid(16, 16):
-                                with T.block("C_reindex_shared_wmma.accumulator"):
-                                    v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
-                                    T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-                                    T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-                                    C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
-                for ax0_ax1_fused in T.serial(4096):
-                    with T.block("C_reindex_shared"):
-                        v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused // 128)
-                        v1 = T.axis.spatial(128, ax0_ax1_fused % 128)
-                        T.reads(C_reindex_shared[v0, v1])
-                        T.writes(compute[v0, v1])
-                        T.block_attr({"meta_schedule.cooperative_fetch":4})
-                        compute[v0, v1] = T.max(C_reindex_shared[v0, v1], T.float32(0))
+                                                T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+                                                T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i])
+                                                T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
+                                                C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+ for ax2 in range(2):
+                        for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"):
+                            for ax2_1, ax3 in T.grid(1, 4):
+                                with T.block("C_reindex_shared_wmma.accumulator_o"):
+                                    v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused * 2 + ax0_0_1_ax1_0_1_fused)
+                                    v1 = T.axis.spatial(2, ax0_ax1_fused)
+                                    v2 = T.axis.spatial(2, ax2 + ax2_1)
+                                    v3 = T.axis.spatial(4, ax3)
+                                    v4_o = T.axis.spatial(1, 0)
+                                    v5_o = T.axis.spatial(1, 0)
+                                    T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16])
+                                    T.writes(C_reindex_shared[v0, v1, v2, v3, 0:16, 0:16])
+                                    T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"})
+                                    for ax4, ax5 in T.grid(16, 16):
+                                        with T.block("C_reindex_shared_wmma.accumulator"):
+                                            v4_i, v5_i = T.axis.remap("SS", [ax4, ax5])
+                                            T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i])
+                                            T.writes(C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i])
+                                            C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]
+                        for ax0_ax1_ax3_ax4_ax5_fused in range(2048):
+                            with T.block("C_reindex_shared"):
+                                v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused * 2 + ax0_0_1_ax1_0_1_fused)
+                                v1 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused // 1024)
+                                v2 = T.axis.spatial(2, ax2)
+                                v3 = T.axis.spatial(4, ax0_ax1_ax3_ax4_ax5_fused % 1024 // 256)
+                                v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16)
+                                v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16)
+                                T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5])
+                                T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 64])
+                                T.block_attr({"meta_schedule.cooperative_fetch": 4})
+                                compute[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 64] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0))
# fmt: on
decision_0 = [
("SamplePerfectTile", [2, 2, 1, 1, 2]),
@@ -373,46 +391,46 @@ def test_conv2d(shared_scope):
@T.prim_func
def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight:
T.Buffer((3, 3, 32, 32), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 32),
"float32")) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        PadInput = T.alloc_buffer([1, 18, 18, 32], dtype="float16")
-        conv2d_nhwc_reindex_shared = T.alloc_buffer([256, 32], dtype="float32", scope=shared_scope)
-        conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([256, 32], dtype="float32", scope="wmma.accumulator")
-        PadInput_reindex_shared = T.alloc_buffer([256, 288], dtype="float16", scope=shared_scope)
-        weight_reindex_shared = T.alloc_buffer([288, 32], dtype="float16", scope=shared_scope)
-        PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer([256, 288], dtype="float16", scope="wmma.matrix_a")
-        weight_reindex_shared_wmma_matrix_b = T.alloc_buffer([288, 32], dtype="float16", scope="wmma.matrix_b")
+        PadInput = T.alloc_buffer((1, 18, 18, 32), "float16")
+        conv2d_nhwc_reindex_shared = T.alloc_buffer((16, 2, 1, 1, 16, 16), scope=shared_scope)
+        conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((16, 2, 1, 1, 16, 16), scope="wmma.accumulator")
+        PadInput_reindex_shared = T.alloc_buffer((256, 288), "float16", scope=shared_scope)
+        weight_reindex_shared = T.alloc_buffer((288, 32), "float16", scope=shared_scope)
+        PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer((256, 288), "float16", scope="wmma.matrix_a")
+        weight_reindex_shared_wmma_matrix_b = T.alloc_buffer((288, 32), "float16", scope="wmma.matrix_b")
for i0, i1, i2, i3 in T.grid(1, 18, 18, 32):
with T.block("PadInput"):
-                i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
-                T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1])
-                T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
-                PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 17 and 1 <= i2_1 and i2_1 < 17, inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float16(0), dtype="float16")
+                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                T.reads(inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3])
+                T.writes(PadInput[v_i0, v_i1, v_i2, v_i3])
+                PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 and v_i1 < 17 and 1 <= v_i2 and v_i2 < 17, inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.float16(0))
for ax0_0_0_ax1_0_0_fused in T.thread_binding(2, thread="blockIdx.y"):
             for ax0_0_1_ax1_0_1_fused in T.thread_binding(16, thread="blockIdx.x"):
                 for ax0_0_2_ax1_0_2_fused in T.thread_binding(1, thread="threadIdx.y"):
- for ax2_0_0 in T.serial(1):
- for ax0_ax1_fused in T.serial(4608):
+ for ax2_0_0 in range(1):
+ for ax0_ax1_fused in range(4608):
with T.block("PadInput_reindex_shared"):
                                 v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused * 16 + ax0_ax1_fused // 288)
v1 = T.axis.spatial(288, ax0_ax1_fused % 288)
                                 T.reads(PadInput[v0 // 256, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32])
T.writes(PadInput_reindex_shared[v0, v1])
-                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":2})
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2})
                                 PadInput_reindex_shared[v0, v1] = PadInput[v0 // 256, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32]
- for ax0_ax1_fused in T.serial(4608):
+ for ax0_ax1_fused in range(4608):
with T.block("weight_reindex_shared"):
v0 = T.axis.spatial(288, ax0_ax1_fused // 16)
                                 v1 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused * 16 + ax0_ax1_fused % 16)
                                 T.reads(weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1])
T.writes(weight_reindex_shared[v0, v1])
-                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8})
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8})
                                 weight_reindex_shared[v0, v1] = weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1]
- for ax2_0_1 in T.serial(18):
+ for ax2_0_1 in range(18):
for ax0_0, ax1_0 in T.grid(1, 1):
                             with T.block("PadInput_reindex_shared_wmma.matrix_a_o"):
                                 v0_o = T.axis.spatial(16, ax0_0_1_ax1_0_1_fused + ax0_0)
v1_o = T.axis.spatial(18, ax2_0_1 + ax1_0)
-                                T.reads(PadInput_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                T.reads(PadInput_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
                                 T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"})
for ax0_1, ax1_1 in T.grid(16, 16):
                                     with T.block("PadInput_reindex_shared_wmma.matrix_a"):
@@ -424,8 +442,8 @@ def test_conv2d(shared_scope):
                             with T.block("weight_reindex_shared_wmma.matrix_b_o"):
v0_o = T.axis.spatial(18, ax2_0_1 + ax0_0)
                                 v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0)
-                                T.reads(weight_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                T.reads(weight_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
                                 T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"})
for ax0_1, ax1_1 in T.grid(16, 16):
                                     with T.block("weight_reindex_shared_wmma.matrix_b"):
@@ -438,44 +456,49 @@ def test_conv2d(shared_scope):
                                         v0_o = T.axis.spatial(16, ax0_0_4 + ax0_0_1_ax1_0_1_fused + ax0_0_3)
                                         v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0_3 + ax1_0_4)
                                         v2_o = T.axis.reduce(18, ax2_0_0 * 18 + ax2_0_1 + ax2_0_2)
-                                        T.reads(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                        T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                        T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1})
+                                        T.reads(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                        T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, 0:16, 0:16])
+                                        T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1})
with T.init():
for ax0_1, ax1_1 in T.grid(16, 16):
with T.block("conv2d_nhwc_init"):
                                                     v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1])
T.reads()
-                                                    T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init])
-                                                    conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0)
+                                                    T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i_init, v1_i_init])
+                                                    conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i_init, v1_i_init] = T.float32(0)
                                         for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16):
with T.block("conv2d_nhwc"):
                                                 v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
-                                                T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
-                                                T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-                                                T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
-                                                conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32")
-                    for ax0_0, ax1_0 in T.grid(1, 1):
-                        with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
-                            v0_o = T.axis.spatial(16, ax0_0_1_ax1_0_1_fused + ax0_0)
-                            v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0)
-                            T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                            T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                            T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"})
-                            for ax0_1, ax1_1 in T.grid(16, 16):
-                                with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"):
-                                    v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
-                                    T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-                                    T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-                                    conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
-                for ax0_ax1_fused in T.serial(256):
-                    with T.block("conv2d_nhwc_reindex_shared"):
-                        v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused * 16 + ax0_ax1_fused // 16)
-                        v1 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused * 16 + ax0_ax1_fused % 16)
-                        T.reads(conv2d_nhwc_reindex_shared[v0, v1])
-                        T.writes(conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1])
-                        T.block_attr({"meta_schedule.cooperative_fetch":3})
-                        conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1] = conv2d_nhwc_reindex_shared[v0, v1]
+                                                T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i], PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+                                                T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i])
+                                                T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
+                                                conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i] + T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+ for ax2 in range(1):
+                        for ax0_ax1_fused in T.thread_binding(1, thread="threadIdx.y"):
+                            for ax2_1, ax3 in T.grid(1, 1):
+                                with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
+                                    v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0_0_1_ax1_0_1_fused, ax0_0_0_ax1_0_0_fused, ax2_1, ax3])
+                                    v4_o = T.axis.spatial(1, 0)
+                                    v5_o = T.axis.spatial(1, 0)
+                                    T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16])
+                                    T.writes(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, 0:16, 0:16])
+                                    T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"})
+                                    for ax4, ax5 in T.grid(16, 16):
+                                        with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"):
+                                            v4_i, v5_i = T.axis.remap("SS", [ax4, ax5])
+                                            T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i])
+                                            T.writes(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4_i, v5_i])
+                                            conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]
+                        for ax0_ax1_ax3_ax4_ax5_fused in range(256):
+                            with T.block("conv2d_nhwc_reindex_shared"):
+                                v0, v1, v2 = T.axis.remap("SSS", [ax0_0_1_ax1_0_1_fused, ax0_0_0_ax1_0_0_fused, ax2])
+                                v3 = T.axis.spatial(1, 0)
+                                v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused // 16)
+                                v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16)
+                                T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5])
+                                T.writes(conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16])
+                                T.block_attr({"meta_schedule.cooperative_fetch": 3})
+                                conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]
# fmt: on
decision_0 = [
("SamplePerfectTile", [1, 16, 1, 1, 1]),
@@ -551,40 +574,40 @@ def test_matmul_relu_pipeline(shared_scope):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with T.block("root")
- C = T.alloc_buffer([128, 128], dtype="float32")
- C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32",
scope=shared_scope)
- C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128],
dtype="float32", scope="wmma.accumulator")
- A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16",
scope=shared_scope)
- B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16",
scope=shared_scope)
- A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128],
dtype="float16", scope="wmma.matrix_a")
- B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128],
dtype="float16", scope="wmma.matrix_b")
+ C = T.alloc_buffer((128, 128))
+ C_reindex_shared = T.alloc_buffer((4, 4, 2, 2, 16, 16),
scope=shared_scope)
+ C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 4, 2, 2, 16,
16), scope="wmma.accumulator")
+ A_reindex_shared = T.alloc_buffer((128, 128), "float16",
scope=shared_scope)
+ B_reindex_shared = T.alloc_buffer((128, 128), "float16",
scope=shared_scope)
+ A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16",
scope="wmma.matrix_a")
+ B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16",
scope="wmma.matrix_b")
for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"):
for ax0_0_1_ax1_0_1_fused in T.thread_binding(16,
thread="blockIdx.x"):
for ax0_0_2_ax1_0_2_fused in T.thread_binding(1,
thread="threadIdx.y"):
- for ax2_0_0 in T.serial(4,
annotations={"software_pipeline_order":[0, 3, 1, 4, 5, 2, 6],
"software_pipeline_stage":[0, 0, 0, 0, 0, 1, 1]}):
- for ax0_ax1_fused in T.serial(1024):
+ for ax2_0_0 in T.serial(4,
annotations={"software_pipeline_order": [0, 3, 1, 4, 5, 2, 6],
"software_pipeline_stage": [0, 0, 0, 0, 0, 1, 1]}):
+ for ax0_ax1_fused in range(1024):
with T.block("A_reindex_shared"):
v0 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused
// 4 * 32 + ax0_ax1_fused // 32)
v1 = T.axis.spatial(128, ax2_0_0 * 32 +
ax0_ax1_fused % 32)
T.reads(A[v0, v1])
T.writes(A_reindex_shared[v0, v1])
- T.block_attr({"buffer_dim_align":[[0, 0, 32,
8]], "double_buffer_scope":0, "meta_schedule.cooperative_fetch":4,
"tir.manifest_shared_memory_local_stage":1})
+ T.block_attr({"buffer_dim_align": [[0, 0, 32,
8]], "double_buffer_scope": 0, "meta_schedule.cooperative_fetch": 4,
"tir.manifest_shared_memory_local_stage": 1})
A_reindex_shared[v0, v1] = A[v0, v1]
- for ax0_ax1_fused in T.serial(1024):
+ for ax0_ax1_fused in range(1024):
with T.block("B_reindex_shared"):
v0 = T.axis.spatial(128, ax2_0_0 * 32 +
ax0_ax1_fused // 32)
v1 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused
% 4 * 32 + ax0_ax1_fused % 32)
T.reads(B[v0, v1])
T.writes(B_reindex_shared[v0, v1])
- T.block_attr({"buffer_dim_align":[[0, 0, 32,
8]], "double_buffer_scope":0, "meta_schedule.cooperative_fetch":2,
"tir.manifest_shared_memory_local_stage":1})
+ T.block_attr({"buffer_dim_align": [[0, 0, 32,
8]], "double_buffer_scope": 0, "meta_schedule.cooperative_fetch": 2,
"tir.manifest_shared_memory_local_stage": 1})
B_reindex_shared[v0, v1] = B[v0, v1]
- for ax2_0_1 in T.serial(2,
annotations={"software_pipeline_order":[0, 1, 2], "software_pipeline_stage":[0,
0, 1]}):
+ for ax2_0_1 in T.serial(2,
annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage":
[0, 0, 1]}):
for ax0_0, ax1_0 in T.grid(2, 1):
with
T.block("A_reindex_shared_wmma.matrix_a_o"):
v0_o = T.axis.spatial(8,
ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0)
v1_o = T.axis.spatial(8, ax2_0_0 * 2 +
ax2_0_1 + ax1_0)
- T.reads(A_reindex_shared[v0_o * 16 : v0_o
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-
T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 :
v1_o * 16 + 16])
+ T.reads(A_reindex_shared[v0_o * 16:v0_o *
16 + 16, v1_o * 16:v1_o * 16 + 16])
+
T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o *
16:v1_o * 16 + 16])
T.block_attr({"meta_schedule.auto_tensorize":
f"wmma_load_16x16x16_f16_a_{intrin_suffix}"})
for ax0_1, ax1_1 in T.grid(16, 16):
with
T.block("A_reindex_shared_wmma.matrix_a"):
@@ -596,8 +619,8 @@ def test_matmul_relu_pipeline(shared_scope):
with
T.block("B_reindex_shared_wmma.matrix_b_o"):
v0_o = T.axis.spatial(8, ax2_0_0 * 2 +
ax2_0_1 + ax0_0)
v1_o = T.axis.spatial(8,
ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0)
- T.reads(B_reindex_shared[v0_o * 16 : v0_o
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 :
v1_o * 16 + 16])
+ T.reads(B_reindex_shared[v0_o * 16:v0_o *
16 + 16, v1_o * 16:v1_o * 16 + 16])
+
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o *
16:v1_o * 16 + 16])
T.block_attr({"meta_schedule.auto_tensorize":
f"wmma_load_16x16x16_f16_b_{intrin_suffix}"})
for ax0_1, ax1_1 in T.grid(16, 16):
with
T.block("B_reindex_shared_wmma.matrix_b"):
@@ -610,50 +633,61 @@ def test_matmul_relu_pipeline(shared_scope):
v0_o = T.axis.spatial(8,
ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0_3 * 2 + ax0_0_4)
v1_o = T.axis.spatial(8,
ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0_3 * 2 + ax1_0_4)
v2_o = T.axis.reduce(8, ax2_0_0 * 2 +
ax2_0_1 + ax2_0_2)
-
T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 :
v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16,
v1_o * 16 : v1_o * 16 + 16])
-
T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o *
16 : v1_o * 16 + 16])
-
T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32",
"meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32",
"warp_execution":1})
+
T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o
* 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o *
16:v1_o * 16 + 16])
+
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o
% 2, 0:16, 0:16])
+
T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32",
"meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32",
"warp_execution": 1})
with T.init():
for ax0_1, ax1_1 in T.grid(16, 16):
with T.block("C_init"):
v0_i_init, v1_i_init =
T.axis.remap("SS", [ax0_1, ax1_1])
T.reads()
-
T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 +
v1_i_init])
-
C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]
= T.float32(0)
+
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o
% 2, v0_i_init, v1_i_init])
+
C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2,
v0_i_init, v1_i_init] = T.float32(0)
for ax0_1, ax1_1, ax2_1 in T.grid(16, 16,
16):
with T.block("C"):
v0_i, v1_i, v2_i =
T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
-
T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i],
A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i],
B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
-
T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-
T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
-
C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] =
C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] +
T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i],
"float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16
+ v1_i], "float32")
- for ax0_0, ax1_0 in T.grid(2, 2):
- with T.block("C_reindex_shared_wmma.accumulator_o"):
- v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused //
4 * 2 + ax0_0)
- v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4
* 2 + ax1_0)
- T.reads(C_reindex_shared_wmma_accumulator[v0_o *
16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
- T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 +
16, v1_o * 16 : v1_o * 16 + 16])
- T.block_attr({"meta_schedule.auto_tensorize":
f"wmma_store_16x16x16_f32_{intrin_suffix}"})
- for ax0_1, ax1_1 in T.grid(16, 16):
- with
T.block("C_reindex_shared_wmma.accumulator"):
- v0_i, v1_i = T.axis.remap("SS", [ax0_1,
ax1_1])
-
T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
- T.writes(C_reindex_shared[v0_o * 16 +
v0_i, v1_o * 16 + v1_i])
- C_reindex_shared[v0_o * 16 + v0_i, v1_o *
16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 +
v1_i]
- for ax0_ax1_fused in T.grid(1024):
- with T.block("C_reindex_shared"):
- v0 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused // 4 *
32 + ax0_ax1_fused // 32)
- v1 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused % 4 *
32 + ax0_ax1_fused % 32)
- T.reads(C_reindex_shared[v0, v1])
- T.writes(C[v0, v1])
- T.block_attr({"meta_schedule.cooperative_fetch":3})
- C[v0, v1] = C_reindex_shared[v0, v1]
+
T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o
% 2, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 +
v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o
% 2, v0_i, v1_i])
+
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
+
C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2,
v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2,
v1_o % 2, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o *
16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32",
B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+ for ax2 in range(2):
+ for ax0_ax1_fused in T.thread_binding(1,
thread="threadIdx.y"):
+ for ax2_1, ax3 in T.grid(1, 2):
+ with
T.block("C_reindex_shared_wmma.accumulator_o"):
+ v0 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused
// 4)
+ v1 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused %
4)
+ v2 = T.axis.spatial(2, ax2 + ax2_1)
+ v3 = T.axis.spatial(2, ax3)
+ v4_o = T.axis.spatial(1, 0)
+ v5_o = T.axis.spatial(1, 0)
+ T.reads(C_reindex_shared_wmma_accumulator[v0,
v1, v2, v3, 0:16, 0:16])
+ T.writes(C_reindex_shared[v0, v1, v2, v3,
0:16, 0:16])
+ T.block_attr({"meta_schedule.auto_tensorize":
f"wmma_store_16x16x16_f32_{intrin_suffix}"})
+ for ax4, ax5 in T.grid(16, 16):
+ with
T.block("C_reindex_shared_wmma.accumulator"):
+ v4_i, v5_i = T.axis.remap("SS", [ax4,
ax5])
+
T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i])
+ T.writes(C_reindex_shared[v0, v1, v2,
v3, v4_i, v5_i])
+ C_reindex_shared[v0, v1, v2, v3, v4_i,
v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]
+ for ax0_ax1_ax3_ax4_ax5_fused in range(512):
+ with T.block("C_reindex_shared"):
+ v0 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused // 4)
+ v1 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused % 4)
+ v2 = T.axis.spatial(2, ax2)
+ v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused
// 256)
+ v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused
% 256 // 16)
+ v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused
% 16)
+ T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5])
+ T.writes(C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 +
v1 * 32])
+ T.block_attr({"meta_schedule.cooperative_fetch":
3})
+ C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 32]
= C_reindex_shared[v0, v1, v2, v3, v4, v5]
for i0, i1 in T.grid(128, 128):
with T.block("compute"):
- i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
- T.reads(C[i0_1, i1_1])
- T.writes(compute[i0_1, i1_1])
- compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0))
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(C[v_i0, v_i1])
+ T.writes(compute[v_i0, v_i1])
+ compute[v_i0, v_i1] = T.max(C[v_i0, v_i1], T.float32(0))
+
# fmt: on
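
[Editorial note, not part of the patch] In this pipelined variant the packed
buffers are allocated as (4, 4, 2, 2, 16, 16) and the accumulator is addressed as
[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i, v1_i]: each warp tile owns a 2x2
grid of 16x16 wmma fragments, and the split keeps every fragment contiguous in
shared memory. A standalone check (ours) that this mapping is a bijection on the
128x128 output:

FRAG = 16  # wmma fragment size

def packed(i, j, tiles=2):
    # flat (i, j) -> (i_outer, j_outer, i_tile, j_tile, i_frag, j_frag)
    io, ii = divmod(i, FRAG)
    jo, ji = divmod(j, FRAG)
    return (io // tiles, jo // tiles, io % tiles, jo % tiles, ii, ji)

coords = {packed(i, j) for i in range(128) for j in range(128)}
assert len(coords) == 4 * 4 * 2 * 2 * 16 * 16   # covers the whole buffer exactly
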
decision_0 = [
("SamplePerfectTile", [1, 4, 1, 1, 2]),
@@ -693,141 +727,6 @@ def test_matmul_relu_pipeline(shared_scope):
)
-def test_matmul_relu_global():
- # fmt: off
- @T.prim_func
- def matmul_relu_global_0(A: T.Buffer((128, 128), "float16"), B:
T.Buffer((128, 128), "float16"), compute: T.Buffer((128, 128), "float32")) ->
None:
- # function attr dict
- T.func_attr({"global_symbol": "main", "tir.noalias": True})
- # body
- # with T.block("root")
- C = T.alloc_buffer([128, 128], dtype="float32")
- C_reindex_wmma_accumulator = T.alloc_buffer([128, 128],
dtype="float32", scope="wmma.accumulator")
- A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16",
scope="shared")
- B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16",
scope="shared")
- A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128],
dtype="float16", scope="wmma.matrix_a")
- B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128],
dtype="float16", scope="wmma.matrix_b")
- for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"):
- for ax0_0_1_ax1_0_1_fused in T.thread_binding(1,
thread="blockIdx.x"):
- for ax0_0_2_ax1_0_2_fused in T.thread_binding(16,
thread="threadIdx.y"):
- for ax2_0_0 in T.serial(2):
- for ax0_ax1_fused in T.serial(8192):
- with T.block("A_reindex_shared"):
- v0 = T.axis.spatial(128, ax0_ax1_fused // 64)
- v1 = T.axis.spatial(128, ax2_0_0 * 64 +
ax0_ax1_fused % 64)
- T.reads(A[v0, v1])
- T.writes(A_reindex_shared[v0, v1])
- T.block_attr({"buffer_dim_align":[[0, 0, 32,
8]], "meta_schedule.cooperative_fetch":1})
- A_reindex_shared[v0, v1] = A[v0, v1]
- for ax0_ax1_fused in T.serial(8192):
- with T.block("B_reindex_shared"):
- v0 = T.axis.spatial(128, ax2_0_0 * 64 +
ax0_ax1_fused // 128)
- v1 = T.axis.spatial(128, ax0_ax1_fused % 128)
- T.reads(B[v0, v1])
- T.writes(B_reindex_shared[v0, v1])
- T.block_attr({"buffer_dim_align":[[0, 0, 32,
8]], "meta_schedule.cooperative_fetch":1})
- B_reindex_shared[v0, v1] = B[v0, v1]
- for ax2_0_1 in T.serial(2):
- for ax0_0, ax1_0 in T.grid(1, 2):
- with
T.block("A_reindex_shared_wmma.matrix_a_o"):
- v0_o = T.axis.spatial(8,
ax0_0_2_ax1_0_2_fused // 2 + ax0_0)
- v1_o = T.axis.spatial(8, ax2_0_0 * 4 +
ax2_0_1 * 2 + ax1_0)
- T.reads(A_reindex_shared[v0_o * 16 : v0_o
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-
T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 :
v1_o * 16 + 16])
-
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"})
- for ax0_1, ax1_1 in T.grid(16, 16):
- with
T.block("A_reindex_shared_wmma.matrix_a"):
- v0_i, v1_i = T.axis.remap("SS",
[ax0_1, ax1_1])
- T.reads(A_reindex_shared[v0_o * 16
+ v0_i, v1_o * 16 + v1_i])
-
T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-
A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] =
A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
- for ax0_0, ax1_0 in T.grid(2, 4):
- with
T.block("B_reindex_shared_wmma.matrix_b_o"):
- v0_o = T.axis.spatial(8, ax2_0_0 * 4 +
ax2_0_1 * 2 + ax0_0)
- v1_o = T.axis.spatial(8,
ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0)
- T.reads(B_reindex_shared[v0_o * 16 : v0_o
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 :
v1_o * 16 + 16])
-
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"})
- for ax0_1, ax1_1 in T.grid(16, 16):
- with
T.block("B_reindex_shared_wmma.matrix_b"):
- v0_i, v1_i = T.axis.remap("SS",
[ax0_1, ax1_1])
- T.reads(B_reindex_shared[v0_o * 16
+ v0_i, v1_o * 16 + v1_i])
-
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-
B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] =
B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
- for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in
T.grid(1, 4, 2, 1, 1):
- with T.block("C_o"):
- v0_o = T.axis.spatial(8,
ax0_0_2_ax1_0_2_fused // 2 + ax0_0_3 + ax0_0_4)
- v1_o = T.axis.spatial(8, ax1_0_4 +
ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0_3)
- v2_o = T.axis.reduce(8, ax2_0_0 * 4 +
ax2_0_1 * 2 + ax2_0_2)
-
T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 :
v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16,
v1_o * 16 : v1_o * 16 + 16])
- T.writes(C_reindex_wmma_accumulator[v0_o *
16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-
T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32",
"meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32",
"warp_execution":1})
- with T.init():
- for ax0_1, ax1_1 in T.grid(16, 16):
- with T.block("C_init"):
- v0_i_init, v1_i_init =
T.axis.remap("SS", [ax0_1, ax1_1])
- T.reads()
-
T.writes(C_reindex_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 +
v1_i_init])
-
C_reindex_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] =
T.float32(0)
- for ax0_1, ax1_1, ax2_1 in T.grid(16, 16,
16):
- with T.block("C"):
- v0_i, v1_i, v2_i =
T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
-
T.reads(C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i],
A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i],
B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
-
T.writes(C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-
T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
- C_reindex_wmma_accumulator[v0_o *
16 + v0_i, v1_o * 16 + v1_i] = C_reindex_wmma_accumulator[v0_o * 16 + v0_i,
v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i,
v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16
+ v2_i, v1_o * 16 + v1_i], "float32")
- for ax0_0, ax1_0 in T.grid(1, 4):
- with T.block("C_reindex_wmma.accumulator_o"):
- v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused //
2 + ax0_0)
- v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused % 2
* 4 + ax1_0)
- T.reads(C_reindex_wmma_accumulator[v0_o * 16 :
v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
- T.writes(C[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 :
v1_o * 16 + 16])
-
T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_global"})
- for ax0_1, ax1_1 in T.grid(16, 16):
- with T.block("C_reindex_wmma.accumulator"):
- v0_i, v1_i = T.axis.remap("SS", [ax0_1,
ax1_1])
- T.reads(C_reindex_wmma_accumulator[v0_o *
16 + v0_i, v1_o * 16 + v1_i])
- T.writes(C[v0_o * 16 + v0_i, v1_o * 16 +
v1_i])
- C[v0_o * 16 + v0_i, v1_o * 16 + v1_i] =
C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
- for i0, i1 in T.grid(128, 128):
- with T.block("compute"):
- i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
- T.reads(C[i0_1, i1_1])
- T.writes(compute[i0_1, i1_1])
- compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0))
- # fmt: on
- decision_0 = [
- ("SamplePerfectTile", [1, 1, 8, 1, 1]),
- ("SamplePerfectTile", [1, 1, 2, 4, 1]),
- ("SamplePerfectTile", [2, 2, 2]),
- ("SampleCategorical", 0),
- ("SampleCategorical", 0),
- ]
- mod = te.create_prim_func(
- te_workload.matmul_relu(
- n=128,
- m=128,
- k=128,
- in_dtype="float16",
- out_dtype="float32",
- )
- )
- actual = generate_design_space(
- kind="cuda",
- mod=mod,
- target=tvm.target.Target("cuda"),
- types=None,
- sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="global")]
- + get_rules("cuda", ms.schedule_rule.AutoInline),
- )
- check_sketches(
- mod,
- sketches=actual,
- expected_mods=[matmul_relu_global_0],
- expected_decisions=[decision_0],
- )
-
-
def test_matmul_relu_non_tensorizable():
# expected to do nothing on non-tensorizable workloads
mod = te.create_prim_func(
@@ -842,7 +741,7 @@ def test_matmul_relu_non_tensorizable():
mod=mod,
target=tvm.target.Target("cuda"),
types=None,
- sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="global")]
+ sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")]
+ get_rules("cuda", ms.schedule_rule.AutoInline),
)
tvm.ir.assert_structural_equal(mod, sch.mod["main"])
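
[Editorial note, not part of the patch] With the output now staged through a
packed shared-memory buffer, the tests exercise the rule only with
write_reuse_scope="shared" (presumably why the dedicated write_reuse_scope="global"
sketch test above is deleted rather than rewritten). For reference, a hedged
sketch of the configuration the remaining tests use; all helper names
(multi_level_tiling_tensor_core, get_rules, generate_design_space) are the ones
this test file already imports:

sch_rules = [multi_level_tiling_tensor_core(write_reuse_scope="shared")]
sch_rules += get_rules("cuda", ms.schedule_rule.AutoInline)
spaces = generate_design_space(
    kind="cuda",
    mod=mod,    # any tensorizable fp16 matmul PrimFunc
    target=tvm.target.Target("cuda"),
    types=None,
    sch_rules=sch_rules,
)
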
@@ -856,40 +755,40 @@ def test_padded_matmul_relu():
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with T.block("root")
- C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32",
scope="shared")
- C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128],
dtype="float32", scope="wmma.accumulator")
- A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16",
scope="shared")
- B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16",
scope="shared")
- A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128],
dtype="float16", scope="wmma.matrix_a")
- B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128],
dtype="float16", scope="wmma.matrix_b")
+ C_reindex_shared = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope="shared")
+ C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 8, 2, 1, 16,
16), scope="wmma.accumulator")
+ A_reindex_shared = T.alloc_buffer((128, 128), "float16",
scope="shared")
+ B_reindex_shared = T.alloc_buffer((128, 128), "float16",
scope="shared")
+ A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16",
scope="wmma.matrix_a")
+ B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16",
scope="wmma.matrix_b")
for ax0_0_0_ax1_0_0_fused in T.thread_binding(8, thread="blockIdx.y"):
for ax0_0_1_ax1_0_1_fused in T.thread_binding(2,
thread="blockIdx.x"):
for ax0_0_2_ax1_0_2_fused in T.thread_binding(2,
thread="threadIdx.y"):
- for ax2_0_0 in T.serial(1):
- for ax0_ax1_fused in T.serial(4096):
+ for ax2_0_0 in range(1):
+ for ax0_ax1_fused in range(4096):
with T.block("A_reindex_shared"):
v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused
// 2 * 32 + ax0_ax1_fused // 128)
v1 = T.axis.spatial(128, ax0_ax1_fused % 128)
T.reads(A[v0, v1])
T.writes(A_reindex_shared[v0, v1])
- T.block_attr({"buffer_dim_align":[[0, 0, 32,
8]], "meta_schedule.cooperative_fetch":8})
- A_reindex_shared[v0, v1] = T.if_then_else(v0 <
127 and v1 < 127, A[v0, v1], T.float16(0), dtype="float16")
- for ax0_ax1_fused in T.serial(4096):
+ T.block_attr({"buffer_dim_align": [[0, 0, 32,
8]], "meta_schedule.cooperative_fetch": 8})
+ A_reindex_shared[v0, v1] = T.if_then_else(v0 <
127 and v1 < 127, A[v0, v1], T.float16(0))
+ for ax0_ax1_fused in range(4096):
with T.block("B_reindex_shared"):
v0 = T.axis.spatial(128, ax0_ax1_fused // 32)
v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused
% 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32)
T.reads(B[v0, v1])
T.writes(B_reindex_shared[v0, v1])
- T.block_attr({"buffer_dim_align":[[0, 0, 32,
8]], "meta_schedule.cooperative_fetch":1})
- B_reindex_shared[v0, v1] = T.if_then_else(v0 <
127 and v1 < 127, B[v0, v1], T.float16(0), dtype="float16")
- for ax2_0_1 in T.serial(4):
+ T.block_attr({"buffer_dim_align": [[0, 0, 32,
8]], "meta_schedule.cooperative_fetch": 1})
+ B_reindex_shared[v0, v1] = T.if_then_else(v0 <
127 and v1 < 127, B[v0, v1], T.float16(0))
+ for ax2_0_1 in range(4):
for ax0_0, ax1_0 in T.grid(2, 2):
with
T.block("A_reindex_shared_wmma.matrix_a_o"):
v0_o = T.axis.spatial(8,
ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0)
v1_o = T.axis.spatial(8, ax2_0_1 * 2 +
ax1_0)
- T.reads(A_reindex_shared[v0_o * 16 : v0_o
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-
T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 :
v1_o * 16 + 16])
-
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"})
+ T.reads(A_reindex_shared[v0_o * 16:v0_o *
16 + 16, v1_o * 16:v1_o * 16 + 16])
+
T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o *
16:v1_o * 16 + 16])
+
T.block_attr({"meta_schedule.auto_tensorize":
"wmma_load_16x16x16_f16_a_shared"})
for ax0_1, ax1_1 in T.grid(16, 16):
with
T.block("A_reindex_shared_wmma.matrix_a"):
v0_i, v1_i = T.axis.remap("SS",
[ax0_1, ax1_1])
@@ -900,9 +799,9 @@ def test_padded_matmul_relu():
with
T.block("B_reindex_shared_wmma.matrix_b_o"):
v0_o = T.axis.spatial(8, ax2_0_1 * 2 +
ax0_0)
v1_o = T.axis.spatial(8,
ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 +
ax0_0_2_ax1_0_2_fused + ax1_0)
- T.reads(B_reindex_shared[v0_o * 16 : v0_o
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 :
v1_o * 16 + 16])
-
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"})
+ T.reads(B_reindex_shared[v0_o * 16:v0_o *
16 + 16, v1_o * 16:v1_o * 16 + 16])
+
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o *
16:v1_o * 16 + 16])
+
T.block_attr({"meta_schedule.auto_tensorize":
"wmma_load_16x16x16_f16_b_shared"})
for ax0_1, ax1_1 in T.grid(16, 16):
with
T.block("B_reindex_shared_wmma.matrix_b"):
v0_i, v1_i = T.axis.remap("SS",
[ax0_1, ax1_1])
@@ -914,45 +813,56 @@ def test_padded_matmul_relu():
v0_o = T.axis.spatial(8,
ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0_3 * 2 + ax0_0_4)
v1_o = T.axis.spatial(8, ax1_0_4 +
ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 +
ax0_0_2_ax1_0_2_fused + ax1_0_3)
v2_o = T.axis.reduce(8, ax2_0_0 * 8 +
ax2_0_1 * 2 + ax2_0_2)
-
T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 :
v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16,
v1_o * 16 : v1_o * 16 + 16])
-
T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o *
16 : v1_o * 16 + 16])
-
T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32",
"meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32",
"warp_execution":1})
+
T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o
* 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o *
16:v1_o * 16 + 16])
+
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, 0:16,
0:16])
+
T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32",
"meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32",
"warp_execution": 1})
with T.init():
for ax0_1, ax1_1 in T.grid(16, 16):
with T.block("C_init"):
v0_i_init, v1_i_init =
T.axis.remap("SS", [ax0_1, ax1_1])
T.reads()
-
T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 +
v1_i_init])
-
C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]
= T.float32(0)
+
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0,
v0_i_init, v1_i_init])
+
C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init,
v1_i_init] = T.float32(0)
for ax0_1, ax1_1, ax2_1 in T.grid(16, 16,
16):
with T.block("C"):
v0_i, v1_i, v2_i =
T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
-
T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i],
A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i],
B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
-
T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
-
T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
-
C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] =
C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] +
T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i],
"float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16
+ v1_i], "float32")
- for ax0_0, ax1_0 in T.grid(2, 1):
- with T.block("C_reindex_shared_wmma.accumulator_o"):
- v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused //
2 * 2 + ax0_0)
- v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2
* 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0)
- T.reads(C_reindex_shared_wmma_accumulator[v0_o *
16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
- T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 +
16, v1_o * 16 : v1_o * 16 + 16])
-
T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"})
- for ax0_1, ax1_1 in T.grid(16, 16):
- with
T.block("C_reindex_shared_wmma.accumulator"):
- v0_i, v1_i = T.axis.remap("SS", [ax0_1,
ax1_1])
-
T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
- T.writes(C_reindex_shared[v0_o * 16 +
v0_i, v1_o * 16 + v1_i])
- C_reindex_shared[v0_o * 16 + v0_i, v1_o *
16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 +
v1_i]
- for ax0_ax1_fused in T.serial(1024):
- with T.block("C_reindex_shared"):
- T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 +
ax0_ax1_fused // 32 < 127 and ax0_0_0_ax1_0_0_fused % 2 * 64 +
ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32 < 127)
- v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 *
32 + ax0_ax1_fused // 32)
- v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 *
64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32)
- T.reads(C_reindex_shared[v0, v1])
- T.writes(compute[v0, v1])
- T.block_attr({"meta_schedule.cooperative_fetch":4})
- compute[v0, v1] = T.max(C_reindex_shared[v0, v1],
T.float32(0))
+
T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i,
v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i],
B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+
T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i,
v1_i])
+
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
+
C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] =
C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] +
T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 +
v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i,
v1_o * 16 + v1_i])
+ for ax2 in range(2):
+ for ax0_ax1_fused in T.thread_binding(2,
thread="threadIdx.y"):
+ for ax2_1, ax3 in T.grid(1, 1):
+ with
T.block("C_reindex_shared_wmma.accumulator_o"):
+ v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused
// 2)
+ v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused %
2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_fused)
+ v2 = T.axis.spatial(2, ax2 + ax2_1)
+ v3 = T.axis.spatial(1, ax3)
+ v4_o = T.axis.spatial(1, 0)
+ v5_o = T.axis.spatial(1, 0)
+ T.reads(C_reindex_shared_wmma_accumulator[v0,
v1, v2, v3, 0:16, 0:16])
+ T.writes(C_reindex_shared[v0, v1, v2, v3,
0:16, 0:16])
+ T.block_attr({"meta_schedule.auto_tensorize":
"wmma_store_16x16x16_f32_shared"})
+ for ax4, ax5 in T.grid(16, 16):
+ with
T.block("C_reindex_shared_wmma.accumulator"):
+ v4_i, v5_i = T.axis.remap("SS", [ax4,
ax5])
+
T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i])
+ T.writes(C_reindex_shared[v0, v1, v2,
v3, v4_i, v5_i])
+ C_reindex_shared[v0, v1, v2, v3, v4_i,
v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]
+ for ax0_ax1_ax3_ax4_ax5_fused in range(512):
+ with T.block("C_reindex_shared"):
+ v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2
+ 0)
+ v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 *
4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_ax3_ax4_ax5_fused % 512 // 256)
+ v2 = T.axis.spatial(2, ax2)
+ v3 = T.axis.spatial(1, 0)
+ v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused
% 256 // 16)
+ v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused
% 16)
+ T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + ax2 * 16
+ ax0_ax1_ax3_ax4_ax5_fused % 256 // 16 < 127 and ax0_0_0_ax1_0_0_fused % 2 *
64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_ax3_ax4_ax5_fused % 512 // 256 * 16 +
ax0_ax1_ax3_ax4_ax5_fused % 16 < 127)
+ T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5])
+ T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 *
16])
+ T.block_attr({"meta_schedule.cooperative_fetch":
4})
+ compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] =
T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0))
+
# fmt: on
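
[Editorial note, not part of the patch] The T.where predicate in the rewritten
epilogue above looks different from the old one but encodes the same 127x127
bound, spelled in terms of the raw loop variables instead of the block vars.
Substituting the bindings gives row = v0 * 32 + v2 * 16 + v4 and
col = v1 * 16 + v5; a brute-force check (ours) that the two forms agree:

for f in range(8):                     # ax0_0_0_ax1_0_0_fused
    for b in range(2):                 # ax0_0_1_ax1_0_1_fused
        for ax2 in range(2):
            for fused in range(512):   # ax0_ax1_ax3_ax4_ax5_fused
                v0 = f // 2
                v1 = f % 2 * 4 + b * 2 + fused % 512 // 256
                v4, v5 = fused % 256 // 16, fused % 16
                row, col = v0 * 32 + ax2 * 16 + v4, v1 * 16 + v5
                guard = (f // 2 * 32 + ax2 * 16 + fused % 256 // 16 < 127
                         and f % 2 * 64 + b * 32
                         + fused % 512 // 256 * 16 + fused % 16 < 127)
                assert guard == (row < 127 and col < 127)
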
decision_0 = [
@@ -994,25 +904,25 @@ def test_conv_1x1():
@T.prim_func
def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight:
T.Buffer((1, 1, 64, 64), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 64),
"float32")) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
- conv2d_nhwc_reindex_shared = T.alloc_buffer([256, 64],
dtype="float32", scope="shared")
- conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([256,
64], dtype="float32", scope="wmma.accumulator")
- PadInput_reindex_shared = T.alloc_buffer([256, 64], dtype="float16",
scope="shared")
- weight_reindex_shared = T.alloc_buffer([1, 1, 64, 64],
dtype="float16", scope="shared")
- PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer([256, 64],
dtype="float16", scope="wmma.matrix_a")
- weight_reindex_shared_wmma_matrix_b = T.alloc_buffer([1, 1, 64, 64],
dtype="float16", scope="wmma.matrix_b")
+ conv2d_nhwc_reindex_shared = T.alloc_buffer((16, 4, 1, 1, 16, 16),
scope="shared")
+ conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((16, 4,
1, 1, 16, 16), scope="wmma.accumulator")
+ PadInput_reindex_shared = T.alloc_buffer((256, 64), "float16",
scope="shared")
+ weight_reindex_shared = T.alloc_buffer((1, 1, 64, 64), "float16",
scope="shared")
+ PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer((256, 64),
"float16", scope="wmma.matrix_a")
+ weight_reindex_shared_wmma_matrix_b = T.alloc_buffer((1, 1, 64, 64),
"float16", scope="wmma.matrix_b")
for ax2_0_0_ax3_0_0_fused in T.thread_binding(16, thread="blockIdx.y"):
for ax2_0_1_ax3_0_1_fused in T.thread_binding(2,
thread="blockIdx.x"):
for ax2_0_2_ax3_0_2_fused in T.thread_binding(2,
thread="threadIdx.y"):
for ax0_0, ax1_0, ax4_0_0 in T.grid(1, 1, 1):
- for ax0_ax1_fused in T.serial(1024):
+ for ax0_ax1_fused in range(1024):
with T.block("PadInput_reindex_shared"):
v0 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused
// 2 * 32 + ax2_0_1_ax3_0_1_fused * 16 + ax0_ax1_fused // 64)
v1 = T.axis.spatial(64, ax0_ax1_fused % 64)
T.reads(inputs[v0 // 256, v0 // 16, v0 % 16,
v1])
T.writes(PadInput_reindex_shared[v0, v1])
- T.block_attr({"buffer_dim_align":[[0, 0, 32,
8]], "meta_schedule.cooperative_fetch":1})
+ T.block_attr({"buffer_dim_align": [[0, 0, 32,
8]], "meta_schedule.cooperative_fetch": 1})
PadInput_reindex_shared[v0, v1] = inputs[v0 //
256, v0 // 16, v0 % 16, v1]
- for ax0_ax1_ax2_ax3_fused in T.serial(2048):
+ for ax0_ax1_ax2_ax3_fused in range(2048):
with T.block("weight_reindex_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(1, 0)
@@ -1020,16 +930,16 @@ def test_conv_1x1():
v3 = T.axis.spatial(64, ax2_0_0_ax3_0_0_fused
% 2 * 32 + ax0_ax1_ax2_ax3_fused % 32)
T.reads(weight[v0, v1, v2, v3])
T.writes(weight_reindex_shared[v0, v1, v2, v3])
- T.block_attr({"buffer_dim_align":[[0, 2, 32,
8]], "meta_schedule.cooperative_fetch":4})
+ T.block_attr({"buffer_dim_align": [[0, 2, 32,
8]], "meta_schedule.cooperative_fetch": 4})
weight_reindex_shared[v0, v1, v2, v3] =
weight[v0, v1, v2, v3]
for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 1):
for ax0_0_1, ax1_0_1 in T.grid(1, 4):
with
T.block("PadInput_reindex_shared_wmma.matrix_a_o"):
v0_o = T.axis.spatial(16,
ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax0_0_1)
v1_o = T.axis.spatial(4, ax1_0_1)
- T.reads(PadInput_reindex_shared[v0_o * 16
: v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-
T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o
* 16 : v1_o * 16 + 16])
-
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"})
+ T.reads(PadInput_reindex_shared[v0_o *
16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+
T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o *
16:v1_o * 16 + 16])
+
T.block_attr({"meta_schedule.auto_tensorize":
"wmma_load_16x16x16_f16_a_shared"})
for ax0_1_1, ax1_1_1 in T.grid(16, 16):
with
T.block("PadInput_reindex_shared_wmma.matrix_a"):
v0_i, v1_i = T.axis.remap("SS",
[ax0_1_1, ax1_1_1])
@@ -1040,9 +950,9 @@ def test_conv_1x1():
with
T.block("weight_reindex_shared_wmma.matrix_b_o"):
v0, v1, v2_o = T.axis.remap("SSS", [ax0,
ax1, ax2_0])
v3_o = T.axis.spatial(4,
ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0)
- T.reads(weight_reindex_shared[v0, v1, v2_o
* 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
-
T.writes(weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 +
16, v3_o * 16 : v3_o * 16 + 16])
-
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"})
+ T.reads(weight_reindex_shared[v0, v1, v2_o
* 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16])
+
T.writes(weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16:v2_o * 16 + 16,
v3_o * 16:v3_o * 16 + 16])
+
T.block_attr({"meta_schedule.auto_tensorize":
"wmma_load_16x16x16_f16_b_shared"})
for ax2_1, ax3_1 in T.grid(16, 16):
with
T.block("weight_reindex_shared_wmma.matrix_b"):
v2_i, v3_i = T.axis.remap("SS",
[ax2_1, ax3_1])
@@ -1056,44 +966,53 @@ def test_conv_1x1():
v2_o = T.axis.spatial(16, ax2_0_4 +
ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax2_0_3)
v3_o = T.axis.spatial(4, ax3_0_4 +
ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0_3)
v4_o = T.axis.reduce(4, ax4_0_0 * 4 +
ax4_0_1 * 4 + ax4_0_2)
-
T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o
* 16 : v4_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 :
v4_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
-
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 +
16, v3_o * 16 : v3_o * 16 + 16])
-
T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32",
"meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32",
"warp_execution":1})
+
T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o *
16:v4_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16:v4_o
* 16 + 16, v3_o * 16:v3_o * 16 + 16])
+
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, 0:16,
0:16])
+
T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32",
"meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32",
"warp_execution": 1})
with T.init():
for ax2_1, ax3_1 in T.grid(16, 16):
with T.block("conv2d_nhwc_init"):
v2_i_init, v3_i_init =
T.axis.remap("SS", [ax2_1, ax3_1])
T.reads()
-
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init,
v3_o * 16 + v3_i_init])
-
conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 +
v3_i_init] = T.float32(0)
+
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0,
v2_i_init, v3_i_init])
+
conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i_init,
v3_i_init] = T.float32(0)
for ax2_1, ax3_1, ax4_1 in T.grid(16, 16,
16):
with T.block("conv2d_nhwc"):
v2_i, v3_i, v4_i =
T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1])
-
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16
+ v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 +
v4_i], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16
+ v3_i])
-
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o *
16 + v3_i])
-
T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
-
conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i]
= conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 +
v3_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o *
16 + v4_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v0, v1,
v4_o * 16 + v4_i, v3_o * 16 + v3_i], "float32")
- for ax0_0, ax1_0 in T.grid(1, 1):
- with
T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
- v0_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused //
2 * 2 + ax2_0_1_ax3_0_1_fused + ax0_0)
- v1_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2
* 2 + ax2_0_2_ax3_0_2_fused + ax1_0)
-
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16,
v1_o * 16 : v1_o * 16 + 16])
- T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 :
v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-
T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"})
- for ax0_1, ax1_1 in T.grid(16, 16):
- with
T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"):
- v0_i, v1_i = T.axis.remap("SS", [ax0_1,
ax1_1])
-
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16
+ v1_i])
- T.writes(conv2d_nhwc_reindex_shared[v0_o *
16 + v0_i, v1_o * 16 + v1_i])
- conv2d_nhwc_reindex_shared[v0_o * 16 +
v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16
+ v0_i, v1_o * 16 + v1_i]
- for ax0_ax1_fused in T.serial(512):
- with T.block("conv2d_nhwc_reindex_shared"):
- v0 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused // 2 *
32 + ax2_0_1_ax3_0_1_fused * 16 + ax0_ax1_fused // 32)
- v1 = T.axis.spatial(64, ax2_0_0_ax3_0_0_fused % 2 * 32
+ ax0_ax1_fused % 32)
- T.reads(conv2d_nhwc_reindex_shared[v0, v1])
- T.writes(conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1])
- T.block_attr({"meta_schedule.cooperative_fetch":2})
- conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1] =
conv2d_nhwc_reindex_shared[v0, v1]
+
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i,
v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 +
v4_i], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16
+ v3_i])
+
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i,
v3_i])
+
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
+
conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i, v3_i] =
conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i, v3_i] +
T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o
* 16 + v4_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v0, v1,
v4_o * 16 + v4_i, v3_o * 16 + v3_i])
+ for ax2 in range(1):
+ for ax0_ax1_fused in T.thread_binding(2,
thread="threadIdx.y"):
+ for ax2_1, ax3 in T.grid(1, 1):
+ with
T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
+ v0 = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused
// 2 * 2 + ax2_0_1_ax3_0_1_fused)
+ v1 = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused %
2 * 2 + ax0_ax1_fused)
+ v2, v3 = T.axis.remap("SS", [ax2_1, ax3])
+ v4_o = T.axis.spatial(1, 0)
+ v5_o = T.axis.spatial(1, 0)
+
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16])
+ T.writes(conv2d_nhwc_reindex_shared[v0, v1,
v2, v3, 0:16, 0:16])
+ T.block_attr({"meta_schedule.auto_tensorize":
"wmma_store_16x16x16_f32_shared"})
+ for ax4, ax5 in T.grid(16, 16):
+ with
T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"):
+ v4_i, v5_i = T.axis.remap("SS", [ax4,
ax5])
+
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i])
+
T.writes(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4_i, v5_i])
+ conv2d_nhwc_reindex_shared[v0, v1, v2,
v3, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3,
v4_i, v5_i]
+ for ax0_ax1_ax3_ax4_ax5_fused in range(512):
+ with T.block("conv2d_nhwc_reindex_shared"):
+ v0 = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2
* 2 + ax2_0_1_ax3_0_1_fused)
+ v1 = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 *
2 + ax0_ax1_ax3_ax4_ax5_fused // 256)
+ v2 = T.axis.spatial(1, ax2)
+ v3 = T.axis.spatial(1, 0)
+ v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused
% 256 // 16)
+ v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused
% 16)
+ T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3,
v4, v5])
+ T.writes(conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 +
v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16])
+ T.block_attr({"meta_schedule.cooperative_fetch":
2})
+ conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16)
// 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared[v0, v1,
v2, v3, v4, v5]
# fmt: on
decision_0 = [
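
[Editorial note, not part of the patch] One detail recurring across these hunks is
the buffer_dim_align block attribute: [[0, 0, 32, 8]] on the 2-D staging buffers
and [[0, 2, 32, 8]] on the 4-D weight buffer. The four fields are (buffer index,
axis, factor, offset), i.e. storage_align applied to the second-to-last dimension,
padding the shared-memory row stride so that stride % 32 == 8 and consecutive rows
land in different banks. A small illustration of the resulting stride (our helper,
assuming those storage_align semantics):

def aligned_stride(extent, factor=32, offset=8):
    # smallest stride >= extent with stride % factor == offset
    s = extent
    while s % factor != offset:
        s += 1
    return s

assert aligned_stride(64) == 72     # 64-wide float16 rows padded to 72 elements
assert aligned_stride(128) == 136   # 128-wide rows padded to 136 elements
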
diff --git
a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
index 1e5fd8843b..b9f35ed553 100644
--- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
+++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
@@ -1124,7 +1124,7 @@ def test_simple_compute_async():
B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
with T.block():
T.reads(A[tx, 0])
- T.writes(B[0, tx, 0])
+ T.writes(B[T.FloorMod(0, 2), tx, 0])
with T.attr(0, "async_commit_queue_scope", 0):
with T.attr(0, "async_scope", 1):
B[T.FloorMod(0, 2), tx, 0] = A[tx, 0] *
T.float32(2)
@@ -1350,8 +1350,8 @@ def test_three_stage_compute_two_stage_async():
B[i % 2, tx, 0] = A[tx, i] * T.float32(2)
with T.block():
T.where(i == 1 and i - 1 < 16)
- T.reads(B[(i + 1) % 2, tx, 0])
- T.writes(C[(i + 1) % 2, tx, 0])
+ T.reads(B[(i - 1) % 2, tx, 0])
+ T.writes(C[(i - 1) % 2, tx, 0])
with T.attr(0, "async_commit_queue_scope", 1):
with T.attr(0, "async_wait_queue_scope", 0):
with T.attr(0,
"async_wait_inflight_count", 1):
@@ -1366,14 +1366,14 @@ def test_three_stage_compute_two_stage_async():
with T.block():
T.where(i + 2 < 16)
T.reads(A[tx, i + 2])
- T.writes(B[i % 2, tx, 0])
+ T.writes(B[(i + 2) % 2, tx, 0])
with T.attr(0, "async_commit_queue_scope", 0):
with T.attr(0, "async_scope", 1):
B[(i + 2) % 2, tx, 0] = A[tx, i + 2] *
T.float32(2)
with T.block():
T.where(i + 2 - 1 < 16)
- T.reads(B[(i + 1) % 2, tx, 0])
- T.writes(C[(i + 1) % 2, tx, 0])
+ T.reads(B[(i - 1 + 2) % 2, tx, 0])
+ T.writes(C[(i - 1 + 2) % 2, tx, 0])
with T.attr(0, "async_commit_queue_scope", 1):
with T.attr(0, "async_wait_queue_scope", 0):
with T.attr(0,
"async_wait_inflight_count", 1):
@@ -1394,8 +1394,8 @@ def test_three_stage_compute_two_stage_async():
for i in T.unroll(2):
with T.block():
T.where(i + 16 - 1 < 16)
- T.reads(B[(i + 1) % 2, tx, 0])
- T.writes(C[(i + 1) % 2, tx, 0])
+ T.reads(B[(i - 1 + 16) % 2, tx, 0])
+ T.writes(C[(i - 1 + 16) % 2, tx, 0])
with T.attr(0, "async_commit_queue_scope", 1):
with T.attr(0, "async_wait_queue_scope", 0):
with T.attr(0,
"async_wait_inflight_count", 0 - i):