junrushao1994 commented on a change in pull request #8767:
URL: https://github.com/apache/tvm/pull/8767#discussion_r693400173



##########
File path: src/tir/schedule/primitive/loop_transformation.cc
##########
@@ -385,6 +511,113 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& 
loop_srefs) {
   return self->stmt2ref.at(new_stmt.get());
 }
 
+void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) {
+  std::unordered_set<const StmtSRefNode*> loop_srefs;
+  loop_srefs.reserve(ordered_loop_srefs.size());
+  if (ordered_loop_srefs.empty() || ordered_loop_srefs.size() == 1) {
+    return;
+  }
+  // Step 1. check uniqueness
+  for (const StmtSRef& loop_sref : ordered_loop_srefs) {
+    const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+    // uniqueness check
+    auto inserted = loop_srefs.insert(loop_sref.get());
+    if (!inserted.second) {
+      throw LoopMultiAppearanceError(self->mod, GetRef<For>(loop));
+    }
+  }
+  // Step 2. gather loops to be reordered
+  // For each loop, traverse upwards along the parent pointer, and stop on 
either a block, or a
+  // previously-visited loop
+  // - the top of the reorder range is the last loop visited in the first 
traverse which exists in
+  //   the input array
+  // - the bottom of the reorder range is the last loop in the input array 
which is not visited in
+  // the previous traverses
+  const StmtSRefNode* top = nullptr;
+  const StmtSRefNode* bottom = nullptr;
+  // Maps a parent sref to its child sref
+  std::unordered_map<const StmtSRefNode*, const StmtSRefNode*> successor;
+  for (size_t i = 0; i < ordered_loop_srefs.size(); i++) {
+    const StmtSRefNode* sref = ordered_loop_srefs[i].get();
+    // if sref is not visited before, update `bottom`
+    if (!successor.count(sref->parent)) {
+      bottom = sref;
+    }
+    while (true) {
+      // stop at blocknode
+      if (sref->stmt->IsInstance<BlockNode>()) {
+        if (i != 0) {
+          throw LoopsNotAChainError(self->mod, NullOpt,
+                                    
LoopsNotAChainError::ProblemKind::kNotUnderAScope);
+        } else {
+          break;
+        }
+      }
+      const StmtSRefNode* parent_sref = sref->parent;
+      // stop at previously-visited loop
+      if (successor.count(parent_sref)) {
+        if (successor[parent_sref] == sref) {
+          break;
+        } else {
+          throw LoopsNotAChainError(self->mod, GetRef<Stmt>(parent_sref->stmt),
+                                    
LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt);
+        }
+      } else {
+        successor[parent_sref] = sref;
+      }
+      // if it's the first traverse and the loop is in the input array, update 
`top`
+      if (loop_srefs.count(sref) && i == 0) {
+        top = sref;
+      }
+      sref = parent_sref;
+    }
+  }
+  // Step 3. Check that loops are single-branch
+  const ForNode* outer_loop = TVM_SREF_TO_FOR(outer_loop, 
GetRef<StmtSRef>(top));
+  for (const StmtSRefNode* loop_sref = top; loop_sref != bottom;) {
+    loop_sref = successor[loop_sref];
+    const ForNode* inner_loop = TVM_SREF_TO_FOR(inner_loop, 
GetRef<StmtSRef>(loop_sref));
+    if (outer_loop->body.get() != inner_loop) {
+      throw LoopsNotAChainError(self->mod, GetRef<For>(outer_loop),
+                                
LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt);
+    }

Review comment:
       Okay, I thought about it a bit more and agree that you are right in this
particular case. We don't need to add such a method for now, given that the
logic here is only a single line, but in general we should add `HasOnlyChild`
at some point.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to