jinhongyii commented on a change in pull request #9341:
URL: https://github.com/apache/tvm/pull/9341#discussion_r735568966



##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -36,46 +38,250 @@
 namespace tvm {
 namespace tir {
 
+using runtime::StorageRank;
+using runtime::StorageScope;
+
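+// Note (illustrative, not part of the original change): for a buffer var
+// whose pointer storage scope is "shared.dyn", StorageScope::Create returns
+// rank kShared with tag ".dyn"; plain "shared" has an empty tag and is
+// therefore excluded below.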
 bool IsDynamicSharedMemory(Var buffer_var) {
-  auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+  StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
   return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
 }
 
+/*!
+ * \brief Collect the mapping from each buffer var to its Allocate node.
+ */
 class AllocateCollector : public StmtExprVisitor {
  public:
   void VisitStmt_(const AllocateNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
-      dyn_shmem_allocs_.insert(op);
+      dyn_shmem_allocs_[op->buffer_var.get()] = op;
     }
     StmtExprVisitor::VisitStmt_(op);
   }
+  // The mapping from the original buffer var to its Allocate node
+  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
+};
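+// Hypothetical usage sketch (illustrative only; names assumed, not from
+// this change):
+//   AllocateCollector collector;
+//   collector(func->body);  // func : PrimFunc
+//   // collector.dyn_shmem_allocs_ now maps each dynamic shared memory
+//   // buffer var to its Allocate node.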
+
+// Finds a linear pattern of storage access, used for liveness analysis.
+// Composite scopes (loop/thread_launch/IfThenElse) are represented by two
+// points:
+//   before_scope -> scope_body -> after_scope
+//
+// linear_seq_ stores the before_scope and after_scope points.
+// Accesses to the arrays are recorded at the after_scope point.
+//
+// Define "scope" as the body of a For/thread_launch/IfThenElse.
+// This pass detects the last point at which the memory must be kept
+// alive, under the same scope as the allocate.
+// The storage needs to be kept alive between the allocate and the last
+// access. The free point is inserted only in the same scope as the
+// allocate.
+//
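+// For illustration (hypothetical IR, not produced by this pass): given
+//   allocate A; for (i, 0, n) { A[i] = ...; }   // A in "shared.dyn"
+// linear_seq_ holds [before_for, after_for], and the access to A inside
+// the loop is attributed to the after_for point, so A must be kept alive
+// at least until the loop exits.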
+class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
+ public:
+  /*! \brief Record the touch history of a statement. */
+  struct StmtEntry {
+    // The statement
+    const Object* stmt;
+    // The index in linear_seq_ pointing to the end of the nested scope.
+    // This is only set to non-zero if stmt is a nested scope.
+    // If offset > 0, this is the begin entry and the end entry is at
+    // current_index + offset.
+    // If offset < 0, this is the end entry and the begin entry is at
+    // current_index + offset.
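+    // For example (illustrative indices): if a scope's begin entry is at
+    // linear_seq_[3] and its end entry at linear_seq_[7], then the begin
+    // entry stores scope_pair_offset = 4 and the end entry stores -4.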
+    int64_t scope_pair_offset{0};
+    // The buffer variables this statement touched.
+    std::vector<const VarNode*> touched;
+  };
+  // The scope of each allocation
+  struct AllocEntry {
+    // scope level
+    size_t level{0};
+    // allocation stmt
+    const AllocateNode* alloc{nullptr};
+  };
 
-  std::unordered_set<const AllocateNode*> dyn_shmem_allocs_;
+  void VisitStmt_(const AllocateNode* op) final {
+    size_t level = scope_.size();
+    const VarNode* buf = op->buffer_var.get();
+    alloc_info_[buf].alloc = op;
+    alloc_info_[buf].level = level;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+  void VisitStmt_(const StoreNode* op) final {
+    scope_.push_back(StmtEntry());
+    // visit subexpr
+    StmtExprVisitor::VisitStmt_(op);
+    // Add write access.
+    const VarNode* buf = op->buffer_var.get();
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size());
+      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+        scope_[it->second.level].touched.push_back(buf);
+      }
+    }
+    StmtEntry e = scope_.back();
+    scope_.pop_back();
+    if (e.touched.size() != 0) {
+      e.stmt = op;
+      linear_seq_.push_back(e);
+    }
+  }
+  void VisitStmt_(const EvaluateNode* op) final {
+    scope_.push_back(StmtEntry());
+    // visit subexpr
+    StmtExprVisitor::VisitStmt_(op);
+    StmtEntry e = scope_.back();
+    scope_.pop_back();
+    if (e.touched.size() != 0) {
+      e.stmt = op;
+      linear_seq_.push_back(e);
+    }
+  }
+  void VisitExpr_(const LoadNode* op) final {
+    // Add read access.
+    StmtExprVisitor::VisitExpr_(op);
+    const VarNode* buf = op->buffer_var.get();
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
+      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+        scope_[it->second.level].touched.push_back(buf);
+      }
+    }
+  }
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::address_of())) {
+      // Visit only the index; taking the address does not read the buffer
+      // contents, so the load itself is not recorded as a touch.
+      const LoadNode* l = op->args[0].as<LoadNode>();
+      ICHECK(l != nullptr);
+      this->VisitExpr(l->index);
+    } else {
+      StmtExprVisitor::VisitExpr_(op);
+    }
+  }
+  void VisitExpr_(const VarNode* buf) final {
+    // A direct reference to the variable counts as a read.
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size()) << " buf=" << buf->name_hint;
+      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+        scope_[it->second.level].touched.push_back(buf);
+      }
+    }
+  }
+  template <typename T>
+  void VisitNewScope(const T* op) {
+    scope_.push_back(StmtEntry());
+    StmtEntry e;
+    e.stmt = op;
+    int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
+    // before scope.
+    linear_seq_.push_back(e);
+    StmtExprVisitor::VisitStmt_(op);
+    // after scope.
+    e.touched = std::move(scope_.back().touched);
+    scope_.pop_back();
+    int64_t end_index = static_cast<int64_t>(linear_seq_.size());
+    ICHECK_GT(end_index, begin_index);
+    e.scope_pair_offset = begin_index - end_index;
+    linear_seq_.push_back(e);
+    // record the pointer to end index.
+    ICHECK_NE(end_index, 0U);
+    linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
+  }
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Only record the outermost thread extent.
+    if (op->attr_key == attr::thread_extent && !in_thread_env_) {
+      in_thread_env_ = true;
+      VisitNewScope(op);
+      in_thread_env_ = false;
+    } else if (op->attr_key == attr::extern_scope) {
+      VisitNewScope(op);
+    } else if (op->attr_key == attr::virtual_thread) {
+      VisitNewScope(op);
+    } else {
+      StmtExprVisitor::VisitStmt_(op);
+    }
+  }
+  void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); }
+
+  void VisitStmt_(const ForNode* op) final { VisitNewScope(op); }
+
+  void VisitStmt_(const WhileNode* op) final { VisitNewScope(op); }
+
+  void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); }
+
+  // linearized access sequence.
+  std::vector<StmtEntry> linear_seq_;
+  // The allocation info (scope level and Allocate node) of each buffer
+  std::unordered_map<const VarNode*, AllocEntry> alloc_info_;

Review comment:
       It is accessed in the rewriter.



