junrushao1994 commented on a change in pull request #10066:
URL: https://github.com/apache/tvm/pull/10066#discussion_r797198386



##########
File path: src/tir/transforms/inject_software_pipeline.cc
##########
@@ -0,0 +1,785 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file inject_software_pipeline.cc
+ * \brief Transform annotated loops into pipelined ones that parallelize producers and consumers
+ */
+#include <tvm/target/target.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/transform.h>
+
+#include "../../support/utils.h"
+#include "../schedule/utils.h"
+#include "./ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+namespace software_pipeline {
+
+/*!
+ * \brief Create a block and infer the access region with the given body.
+ *
+ * The result is an opaque block that doesn't contain any block iter vars. In case the body is
+ * a block realize without a predicate, it is unnecessary to create a new block; the block of
+ * the block realize will be returned instead.
+ *
+ * \param body The body of the block.
+ * \param buffer_data_to_buffer The map from buffer data to buffer.
+ * \return The result block.
+ */
+Block MakeBlock(const Stmt& body, const Map<Var, Buffer>& buffer_data_to_buffer) {
+  if (const BlockRealizeNode* block_realize = body.as<BlockRealizeNode>()) {
+    if (is_one(block_realize->predicate)) {
+      // no need to create a new block
+      return block_realize->block;
+    }
+  }
+  Block block = Block({}, {}, {}, "", body);
+  auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer);
+  auto* n = block.CopyOnWrite();
+  n->reads = access[0];
+  n->writes = access[1];
+  return block;
+}
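+
+// Usage sketch (mirrors `make_nop` later in this file): `MakeBlock(Evaluate(0), {})` wraps a
+// no-op statement into an opaque block with inferred (empty) access regions, whereas calling
+// MakeBlock on a predicate-free BlockRealize simply returns the existing block.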
+
+/*! Structure that represents the stage and order of the software pipeline component. */
+struct PipelineStageOrder {
+  int stage;
+  int order;
+  explicit PipelineStageOrder(int stage, int order) : stage(stage), order(order) {}
+};
+
+using PipelineInfo = std::unordered_map<Block, PipelineStageOrder, ObjectPtrHash, ObjectPtrEqual>;
+
+struct BufferAccessInfo {
+  int def;  // the defining stage of the buffer
+  int use;  // the last using stage of the buffer
+  explicit BufferAccessInfo(int def = -1, int use = -1) : def(def), use(use) {}
+};
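+
+// A minimal illustration of these two structs (hypothetical annotations, not taken from this
+// patch): a pipelined loop whose body contains three blocks annotated as
+//   load_A:  stage = 0, order = 0
+//   load_B:  stage = 0, order = 1
+//   compute: stage = 1, order = 2
+// The loads define their buffers at stage 0 and `compute` uses them at stage 1, so each buffer
+// gets BufferAccessInfo{def = 0, use = 1}, i.e. up to use - def + 1 = 2 versions (double
+// buffering) may be required.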
+
+/*!
+ * \brief Rewriter for the body of the software pipeline. This pass inserts `floormod` into the
+ * indices of the remapped buffers to select the version corresponding to the pipeline stage.
+ */
+class PipelineBodyRewriter : public StmtExprMutator {
+ public:
+  /*!
+   * \brief Constructor of PipelineBodyRewriter.
+   * \param buffer_data_to_buffer The map from buffer data to buffer.
+   * \param buffer_remap The map from original buffer to the buffer with updated shape for
+   *        multi-versioning in the software pipeline.
+   * \param pipeline_loop The original loop to be software pipelined.
+   * \param access_all_versions Whether all versions of the buffers in the software pipeline are
+   *        accessed. This is used to update block access regions. In the prologue and epilogue
+   *        of a two-stage software pipeline, only one version of these buffers is accessed.
+   * \param fragment_info Information about tensor core fragment
+   */
+  PipelineBodyRewriter(const Map<Var, Buffer>& buffer_data_to_buffer,
+                       const Map<Buffer, Buffer>& buffer_remap, For pipeline_loop,
+                       bool access_all_versions,
+                       const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info)
+      : buffer_data_to_buffer_(buffer_data_to_buffer),
+        buffer_remap_(buffer_remap),
+        pipeline_loop_(pipeline_loop),
+        access_all_versions_(access_all_versions),
+        fragment_info_(fragment_info) {}
+
+ private:
+  BufferRegion RewritePipelineBufferRegion(const BufferRegion& buffer_region) const {
+    auto it = buffer_remap_.find(buffer_region->buffer);
+    if (it != buffer_remap_.end()) {
+      Region new_region = buffer_region->region;
+      const Buffer& new_buffer = (*it).second;
+      // For pipeline buffers, relax the access region of the first dimension to full extent
+      // if access_all_versions == true.
+      Range accessed_version =
+          access_all_versions_
+              ? Range::FromMinExtent(0, new_buffer->shape[0])
+              : Range::FromMinExtent(floormod((pipeline_loop_->loop_var - pipeline_loop_->min),
+                                              new_buffer->shape[0]),
+                                     Integer(1));
+      new_region.insert(new_region.begin(), accessed_version);
+      return BufferRegion(new_buffer, new_region);
+    }
+    return buffer_region;
+  }
+
+  Stmt VisitStmt_(const BlockNode* op) final {
+    for (const Buffer& alloc_buffer : op->alloc_buffers) {
+      buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer);
+    }
+    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
+    BlockNode* n = block.CopyOnWrite();
+    n->reads.MutateByApply(
+        std::bind(&PipelineBodyRewriter::RewritePipelineBufferRegion, this, std::placeholders::_1));
+    n->writes.MutateByApply(
+        std::bind(&PipelineBodyRewriter::RewritePipelineBufferRegion, this, std::placeholders::_1));
+    for (const Buffer& alloc_buffer : op->alloc_buffers) {
+      buffer_data_to_buffer_.erase(alloc_buffer->data);
+    }
+    return std::move(block);
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+    auto it = buffer_remap_.find(store->buffer);
+    if (it == buffer_remap_.end()) {
+      return std::move(store);
+    }
+    const Buffer& new_buffer = (*it).second;
+    auto* n = store.CopyOnWrite();
+    n->buffer = new_buffer;
+    PrimExpr version =
+        floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
+    n->indices.insert(n->indices.begin(), version);
+    return std::move(store);
+  }
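+
+  // Sketch of the rewrite above (hypothetical names and shapes): with a pipeline loop variable
+  // `i` starting at `min` and a buffer A_shared remapped from shape [128] to [2, 128] (two
+  // versions),
+  //   A_shared[tx] = ...
+  // becomes
+  //   A_shared[(i - min) % 2, tx] = ...
+  // so consecutive iterations write to alternating versions of the buffer.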
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+    auto it = buffer_remap_.find(load->buffer);
+    if (it == buffer_remap_.end()) {
+      return std::move(load);
+    }
+    const Buffer& new_buffer = (*it).second;
+    auto* n = load.CopyOnWrite();
+    n->buffer = new_buffer;
+    PrimExpr version =
+        floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
+    n->indices.insert(n->indices.begin(), version);
+    return std::move(load);
+  }
+
+  int GetWmmaFragmentSize(const Buffer& buffer) {
+    auto it = fragment_info_.find(buffer->data.get());
+    ICHECK(it != fragment_info_.end());
+    const FragmentInfo& info = (*it).second;
+    String scope = buffer.scope();
+    if (scope == "wmma.matrix_a") {
+      return info.m * info.k;
+    } else if (scope == "wmma.matrix_b") {
+      return info.n * info.k;
+    } else if (scope == "wmma.accumulator") {
+      return info.m * info.n;
+    } else {
+      ICHECK(0);
+      throw;
+    }
+  }
+
+  PrimExpr RewriteWmmaFragmentIndex(const Buffer& old_buffer, const Buffer& new_buffer,
+                                    const PrimExpr& old_index) {
+    PrimExpr new_buffer_offset = old_index;
+
+    int fragment_size = GetWmmaFragmentSize(old_buffer);
+    PrimExpr offset =
+        floordiv(foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
+                       make_const(DataType::Int(32), 1), old_buffer->shape),
+                 fragment_size);
+    new_buffer_offset +=
+        floormod(pipeline_loop_->loop_var - pipeline_loop_->min, new_buffer->shape[0]) * offset;
+    return new_buffer_offset;
+  }
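+
+  // Worked example (assumed sizes, for illustration only): suppose `old_buffer` holds 4
+  // wmma.matrix_a fragments of 16x16x16, so fragment_size = m * k = 256 and
+  // prod(old_buffer->shape) = 1024. Then `offset` = 1024 / 256 = 4 fragments per version, and
+  // with 2 versions a fragment index `idx` is rewritten to idx + ((i - min) % 2) * 4.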
+
+  PrimExpr VisitExpr_(const CallNode* op) final {
+    // Intrinsic calls should be handled explicitly here, as they access buffers opaquely
+    // through data-pointer and index arguments rather than through BufferLoad/BufferStore.
+    static const auto& load_matrix_sync = builtin::tvm_load_matrix_sync();
+    static const auto& store_matrix_sync = builtin::tvm_store_matrix_sync();
+    static const auto& mma_sync = builtin::tvm_mma_sync();
+    static const auto& access_ptr = builtin::tvm_access_ptr();
+    Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
+    if (call->op.same_as(load_matrix_sync) || call->op.same_as(store_matrix_sync)) {
+      const Buffer& buffer = buffer_data_to_buffer_.at(Downcast<Var>(call->args[0]));
+      auto it = buffer_remap_.find(buffer);
+      if (it != buffer_remap_.end()) {
+        Array<PrimExpr> new_args = call->args;
+        const Buffer& new_buffer = (*it).second;
+        new_args.Set(4, RewriteWmmaFragmentIndex(buffer, new_buffer, call->args[4]));
+        return Call(call->dtype, call->op, new_args, call->span);
+      }
+    } else if (call->op.same_as(mma_sync)) {
+      Array<PrimExpr> new_args = call->args;
+      for (int i = 0; i < 4; i++) {
+        const Var& buffer_var = Downcast<Var>(call->args[i * 2]);
+        const PrimExpr& index = call->args[i * 2 + 1];
+        const Buffer& buffer = buffer_data_to_buffer_.at(buffer_var);
+        auto it = buffer_remap_.find(buffer);
+        if (it != buffer_remap_.end()) {
+          PrimExpr new_index = RewriteWmmaFragmentIndex(buffer, (*it).second, index);
+          new_args.Set(i * 2 + 1, new_index);
+        }
+      }
+      return Call(call->dtype, call->op, new_args, call->span);
+    } else if (call->op.same_as(access_ptr)) {
+      const Buffer& buffer = buffer_data_to_buffer_.at(Downcast<Var>(call->args[1]));
+      auto it = buffer_remap_.find(buffer);
+      if (it != buffer_remap_.end()) {
+        Array<PrimExpr> new_args = call->args;
+        const Buffer& new_buffer = (*it).second;
+        const PrimExpr& old_index = call->args[2];
+        PrimExpr offset;
+        if (new_buffer->strides.empty()) {
+          offset = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
+                         make_const(DataType::Int(32), 1), buffer->shape);
+        } else {
+          offset = new_buffer->strides[0];
+        }
+        PrimExpr new_index = old_index + floormod(pipeline_loop_->loop_var, 2) * offset;
+        new_args.Set(2, new_index);
+        return Call(call->dtype, call->op, new_args, call->span);
+      }
+    }
+    return std::move(call);
+  }
+
+  Map<Var, Buffer> buffer_data_to_buffer_;
+  Map<Buffer, Buffer> buffer_remap_;
+  For pipeline_loop_;
+  bool access_all_versions_;
+  const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info_;
+};
+
+/*!
+ * \brief Rewriter for the software pipeline that rewrites a loop into a pipelined one.
+ */
+class PipelineRewriter : public StmtExprMutator {
+ public:
+  static Stmt Rewrite(
+      Map<Var, Buffer> buffer_data_to_buffer,
+      const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>& double_buffers,
+      const Array<Buffer> pipeline_allocs, const For& pipeline_loop,
+      const PipelineInfo& pipeline_info,
+      const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info) {
+    PipelineRewriter rewriter(buffer_data_to_buffer, double_buffers, pipeline_allocs, pipeline_loop,
+                              pipeline_info, fragment_info);
+    return rewriter.BuildPipeline();
+  }
+
+ private:
+  PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
+                   const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>& double_buffers,
+                   const Array<Buffer>& pipeline_allocs, const For& pipeline_loop,
+                   const PipelineInfo& pipeline_info,
+                   const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info)
+
+      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
+        double_buffers_(double_buffers),
+        pipeline_allocs_(pipeline_allocs),
+        pipeline_loop_(pipeline_loop),
+        pipeline_info_(pipeline_info),
+        fragment_info_(fragment_info) {}
+
+  Stmt BuildPipeline() {
+    // Step 1: Analyze accesses to the buffers in the pipeline and compute the number of
+    // versions that need to be maintained for each buffer.
+    RemapPipelineBuffers(pipeline_allocs_);
+
+    ordered_stmts_.resize(pipeline_info_.size());
+    for (const auto& pair : pipeline_info_) {
+      const Block& block = pair.first;
+      int order = pair.second.order;
+      ordered_stmts_.Set(order, block);
+    }
+
+    // Step 2: Emit the pipeline prologue, body and epilogue.
+    Stmt prologue = EmitImpl(pipeline_loop_->min, pipeline_loop_->min + max_stage_, true);
+    Stmt body = EmitImpl(pipeline_loop_->min + max_stage_,
+                         pipeline_loop_->min + pipeline_loop_->extent, false);
+    Stmt epilogue = EmitImpl(pipeline_loop_->min + pipeline_loop_->extent,
+                             pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true);
+
+    SeqStmt stmt = SeqStmt({prologue, body, epilogue});
+
+    // Step 3: Make a new block that contains new buffer allocations after pipeline rewriting.
+    Array<Buffer> alloc_buffers;
+    for (const auto& alloc : pipeline_allocs_) {
+      auto it = buffer_remap_.find(alloc);
+      if (it != buffer_remap_.end()) {
+        alloc_buffers.push_back((*it).second);
+      } else {
+        alloc_buffers.push_back(alloc);
+      }
+      buffer_data_to_buffer_.erase(alloc->data);
+    }
+    Block block = MakeBlock(stmt, buffer_data_to_buffer_);
+    auto* n = block.CopyOnWrite();
+    n->alloc_buffers = std::move(alloc_buffers);
+    return BlockRealize({}, Bool(true), block);
+  }
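+
+  // Shape of the generated pipeline (illustrative, for a loop over [0, n) with max_stage_ == 1):
+  //   prologue: iterates [0, 1)      -- only stage-0 blocks are in bounds
+  //   body:     iterates [1, n)      -- all stages run, one iteration apart
+  //   epilogue: iterates [n, n + 1)  -- only the last stage drains
+  // EmitImpl below guards each block with an `inbound` predicate so that blocks whose skewed
+  // iteration falls outside the original loop range are dropped.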
+
+ private:
+  /*!
+   * \brief Analyze accesses to the buffers in the software pipeline.
+   *
+   * This method checks the 'define' and 'use' stages of the buffers in the software pipeline,
+   * which are used to compute the number of versions that need to be maintained after rewriting.
+   */
+  std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
+  GetBufferAccessInfo() {
+    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual> infos;
+    for (const auto& pair : pipeline_info_) {
+      const Block& block = pair.first;
+      int stage = pair.second.stage;
+      max_stage_ = std::max(max_stage_, stage);
+
+      for (const BufferRegion& write : block->writes) {
+        if (!infos.count(write->buffer)) {
+          infos.emplace(write->buffer, BufferAccessInfo{});
+        }
+        auto& info = infos.at(write->buffer);
+        if (info.def == -1) {
+          info.def = stage;
+        } else {
+          info.def = std::min(info.def, stage);
+        }
+      }
+
+      for (const BufferRegion& read : block->reads) {
+        if (!infos.count(read->buffer)) {
+          infos.emplace(read->buffer, BufferAccessInfo{});
+        }
+        auto& info = infos.at(read->buffer);
+        info.use = std::max(info.use, stage);
+      }
+    }
+    return infos;
+  }
+
+  /*!
+   * \brief Check whether two regions have intersections.
+   * \param region1 The first region.
+   * \param region2 The second region.
+   * \return Whether region1 and region2 have intersections.
+   */
+  bool MayConflict(Region region1, Region region2) {
+    ICHECK(region1.size() == region2.size());
+    for (size_t i = 0; i < region1.size(); i++) {
+      Range dim1 = region1[i];
+      Range dim2 = region2[i];
+      auto int_set1 = arith::IntSet::FromRange(dim1);
+      auto int_set2 = arith::IntSet::FromRange(dim2);
+      if (arith::Intersect({int_set1, int_set2}).IsNothing()) {
+        return false;
+      }
+    }
+    return true;
+  }
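+
+  // Example (hypothetical regions): [0, 4) x [0, 8) and [2, 6) x [0, 8) intersect in every
+  // dimension, so MayConflict returns true; [0, 4) x [0, 8) and [4, 8) x [0, 8) are disjoint
+  // in the first dimension, so they cannot conflict.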
+
+  /*!
+   * \brief Compute the number of versions that need to be maintained for a buffer accessed in
+   * the software pipeline.
+   *
+   * This method applies liveness analysis to the target buffer to compute the number of
+   * versions that need to be maintained during the software pipeline.
+   * The annotation `attr::double_buffer_scope` is also handled here; it provides a way to
+   * override the result of the analysis. Additional double buffering in the software pipeline
+   * can be useful to eliminate synchronizations on GPU devices.
+   *
+   * \param buffer The target buffer
+   * \param buffer_info The access information of the target buffer.
+   * \return The number of versions required for the target buffer.
+   */
+  int ComputeBufferVersions(const Buffer& buffer, const BufferAccessInfo& buffer_info) {
+    if (buffer_info.def == -1) {
+      // Keep the original number of versions, as buffers defined outside the software
+      // pipeline should not be mutated.
+      return 1;
+    }
+
+    // `use - def + 1` is an upper bound of the number of versions needed.
+    // We optimize a few cases where the number of versions can be smaller than the upper bound.
+    int num_versions = buffer_info.use - buffer_info.def + 1;
+    if (num_versions == 2) {
+      // A special case when `use - def + 1 == 2`. Double buffering is only needed in this case
+      // when there exists a writer block_i and a reader block_j such that
+      // order(block_i) < order(block_j) and stage(block_i) < stage(block_j) and the access
+      // regions of block_i and block_j overlap.
+      bool need_multi_version = false;
+      for (const auto& pair1 : pipeline_info_) {
+        const Block& writer_block = pair1.first;
+        const auto& writer_info = pair1.second;
+
+        auto it1 = std::find_if(writer_block->writes.begin(), writer_block->writes.end(),
+                                [&](const BufferRegion& buffer_region) {
+                                  return buffer_region->buffer.same_as(buffer);
+                                });
+        if (it1 == writer_block->writes.end()) {
+          continue;
+        }
+
+        for (const auto& pair2 : pipeline_info_) {
+          const Block& reader_block = pair2.first;
+          const auto& reader_info = pair2.second;
+          auto it2 = std::find_if(reader_block->reads.begin(), reader_block->reads.end(),
+                                  [&](const BufferRegion& buffer_region) {
+                                    return buffer_region->buffer.same_as(buffer);
+                                  });
+          if (it2 == reader_block->reads.end()) {
+            continue;
+          }
+          if (writer_info.order < reader_info.order && writer_info.stage < reader_info.stage &&
+              MayConflict((*it1)->region, (*it2)->region)) {
+            need_multi_version = true;
+            break;
+          }
+        }
+      }
+      if (!need_multi_version) {
+        num_versions = 1;
+      }
+    }
+    if (num_versions == 1 && double_buffers_.count(buffer)) {
+      num_versions = 2;
+    }
+    return num_versions;
+  }
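+
+  // Example (hypothetical stages): a buffer written by a block with stage 0, order 0 and read
+  // by a block with stage 1, order 1 has def = 0 and use = 1, so the upper bound is 2. Since
+  // the writer precedes the reader in both order and stage and the regions overlap, iteration
+  // i + 1 would overwrite the data iteration i is still reading; two versions are kept. If the
+  // writer and reader were in the same stage instead, a single version would suffice.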
+
+  /*!
+   * \brief Rewrite buffer allocations to create new buffers with new shapes according to
+   * the software pipeline.
+   * \param pipeline_allocs The buffer allocations inside the software pipeline scope.
+   */
+  void RemapPipelineBuffers(Array<Buffer> pipeline_allocs) {
+    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual> infos =
+        GetBufferAccessInfo();
+    for (const Buffer& buffer : pipeline_allocs) {
+      const BufferAccessInfo access_info = infos.at(buffer);
+      int num_versions = ComputeBufferVersions(buffer, access_info);
+      if (num_versions > 1) {
+        Buffer new_buffer = RewriteAllocBuffer(buffer, num_versions);
+        buffer_remap_.Set(buffer, new_buffer);
+      }
+    }
+  }
+
+  /*!
+   * \brief Rewrite a buffer allocation to keep multiple versions of the original buffer for
+   * pipelined accesses.
+   * \param buffer The buffer to be resized.
+   * \param num_versions The number of versions to keep.
+   * \return The resized buffer.
+   */
+  Buffer RewriteAllocBuffer(const Buffer& buffer, int num_versions) {
+    ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
+    new_buffer->shape.insert(new_buffer->shape.begin(), num_versions);
+    if (new_buffer->strides.size()) {
+      ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
+      PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1];
+      new_buffer->strides.insert(new_buffer->strides.begin(), stride_0);
+    }
+    return Buffer(new_buffer);
+  }
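+
+  // Example (hypothetical buffer): resizing a buffer of shape [64, 64] with strides [128, 1]
+  // to num_versions = 2 yields shape [2, 64, 64] and strides [8192, 128, 1]; the new leading
+  // dimension selects the version and stride_0 = strides[0] * shape[1] keeps versions apart.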
+
+  /*!
+   * \brief Emit the pipeline loop in the given range.
+   * \param start The start of the range
+   * \param end The end of the range
+   * \param unroll_loop Whether the loop should be unrolled.
+   * \return The result loop.
+   */
+  Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop) {
+    Array<Stmt> stmts;
+    PrimExpr new_loop_var;
+    PrimExpr extent = end - start;
+
+    auto make_nop = []() { return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {})); };
+
+    if (!analyzer_.CanProve(extent > 0)) {
+      return make_nop();
+    }
+    bool is_unit_loop = analyzer_.CanProveEqual(extent, 1);
+    if (is_unit_loop) {
+      new_loop_var = start;  // use constants as the loop var for unit loops
+    } else {
+      new_loop_var = pipeline_loop_->loop_var.copy_with_suffix("");
+      analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end));
+    }
+
+    for (const Block& block : ordered_stmts_) {
+      int stage = pipeline_info_.at(block).stage;
+      PrimExpr skewed_loop_var = new_loop_var - stage;
+      PrimExpr inbound = (skewed_loop_var >= pipeline_loop_->min) &&
+                         (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent);
+      inbound = analyzer_.Simplify(inbound);
+      if (analyzer_.CanProve(!inbound)) {
+        continue;
+      }
+      Block new_block = Downcast<Block>(PipelineBodyRewriter(
+          buffer_data_to_buffer_, buffer_remap_, pipeline_loop_, max_stage_ != 1,
+          fragment_info_)(block));
+      Map<Var, PrimExpr> subst_map;
+      if (is_unit_loop) {
+        subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var);
+      } else {
+        // normalize loop range
+        subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var + (start - pipeline_loop_->min));
+      }
+      new_block = Downcast<Block>(Substitute(new_block, subst_map));
+      stmts.push_back(BlockRealize({}, inbound, new_block));
+    }
+
+    Stmt new_loop{nullptr};
+
+    if (stmts.empty()) {
+      return make_nop();
+    }
+    if (stmts.size() == 1) {
+      new_loop = stmts[0];
+    } else {
+      new_loop = SeqStmt(stmts);
+    }
+
+    if (!is_unit_loop) {
+      new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
+                     unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, std::move(new_loop));
+    }
+    return BlockRealize({}, Bool(true), MakeBlock(std::move(new_loop), buffer_data_to_buffer_));
+  }
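+
+  // Illustrative expansion (hypothetical two-stage pipeline over [0, n)): in the body range
+  // [1, n) the stage-0 block runs at skewed iteration i and the stage-1 block at i - 1, so
+  // the two stages of adjacent iterations overlap. In the prologue the stage-1 block's
+  // `inbound` predicate simplifies to false, so the block is dropped entirely.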
+
+  arith::Analyzer analyzer_;
+  Map<Var, Buffer> buffer_data_to_buffer_;
+  const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>& double_buffers_;
+  Array<Buffer> pipeline_allocs_;
+  For pipeline_loop_;
+  PipelineInfo pipeline_info_;
+  const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info_;
+  int max_stage_ = -1;
+  Map<Buffer, Buffer> buffer_remap_;
+  Array<Block> ordered_stmts_;
+};
+
+class PipelineInjector : private StmtExprMutator {
+ public:
+  static Stmt Inject(const PrimFunc& func) {
+    PipelineInjector injector;
+    for (const auto& kv : func->buffer_map) {
+      const Buffer& buffer = kv.second;
+      injector.buffer_data_to_buffer_.Set(buffer->data, buffer);
+    }
+    injector.fragment_info_ = GetTensorCoreFragmentInfo(func->body);
+    return injector(func->body);
+  }
+
+ private:
+  PipelineInjector() = default;
+
+  /*!
+   * \brief Check that the pipeline satisfies the following conditions:
+   * 1) No conflicting orders: The order of each statement should be unique.
+   * 2) No reordering within the same stage: Statements in the same stage are not allowed to be
+   * reordered.
+   */
+  void ValidatePipelineBody(const PipelineInfo& pipeline_info, const Array<Block>& original_order) {
+    std::unordered_set<int> used_orders;
+    std::unordered_map<int, int> stage_max_order;
+    for (const Block& block : original_order) {
+      const auto& stmt_info = pipeline_info.at(block);
+      int stage = stmt_info.stage;
+      int order = stmt_info.order;
+      CHECK(!used_orders.count(order))
+          << "ValueError: Two statements in the software pipeline cannot have 
the same order";
+      used_orders.insert(order);
+      CHECK(!stage_max_order.count(stage) || stage_max_order[stage] < order)
+          << "ValueError: Statements in the same stage of the software 
pipeline must have "
+             "increasing order.";
+      stage_max_order[stage] = order;
+    }
+  }
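+
+  // Example (hypothetical annotations): stages {0, 0, 1} with orders {0, 1, 2} pass validation.
+  // Reusing an order twice violates condition 1); stages {0, 0} with orders {1, 0} violate
+  // condition 2), since the second stage-0 statement would be reordered before the first.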
+
+  Stmt VisitStmt_(const ForNode* op) final {
+    // Step 1: Recursively rewrite the children first.
+    For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
+    bool is_pipeline = HasPipelineAnnotation(op);
+    if (!is_pipeline) {
+      return std::move(for_node);
+    }

Review comment:
       nit:
   
   ```suggestion
       if (!HasPipelineAnnotation(op)) {
         return std::move(for_node);
       }
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

