masahi commented on code in PR #11050:
URL: https://github.com/apache/tvm/pull/11050#discussion_r852501884


##########
src/tir/schedule/analysis/analysis.cc:
##########
@@ -2028,5 +2034,107 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,   //
   }
 }
 
+TVM_REGISTER_NODE_TYPE(TensorizeInfoNode);
+
+Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
+                                                const tir::StmtSRef& block_sref,
+                                                const tir::PrimFunc& desc_func) {
+  arith::Analyzer analyzer;
+  const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
+  // Step 1. Analyze desc_func, extract its block, loops and loop vars
+  const tir::BlockRealizeNode* desc_block = nullptr;
+  std::vector<const tir::ForNode*> desc_loops;
+  std::unordered_set<const tir::VarNode*> desc_loop_vars;
+  const auto* desc_scope_realize = desc_func->body.as<tir::BlockRealizeNode>();
+  ICHECK(desc_scope_realize);
+  {
+    auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars,
+                    &analyzer](const ObjectRef& obj) -> bool {
+      // Extract the block
+      if (const auto* block = obj.as<tir::BlockRealizeNode>()) {
+        desc_block = block;
+        return false;
+      }
+      // Extract loops
+      if (const auto* loop = obj.as<tir::ForNode>()) {
+        desc_loops.push_back(loop);
+        desc_loop_vars.insert(loop->loop_var.get());
+        if (!analyzer.CanProve(loop->min == 0)) {
+          return false;
+        }
+      }
+      return true;
+    };
+    tir::PostOrderVisit(desc_scope_realize->block->body, f_visit);
+    std::reverse(desc_loops.begin(), desc_loops.end());
+    ICHECK(desc_block);
+  }
+  // Step 2. Collect loops from block_sref
+  const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);
+  const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
+  std::vector<const tir::ForNode*> block_loops;
+  std::unordered_set<const tir::VarNode*> block_loop_vars;
+  {
+    for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) {
+      const auto* loop = loop_sref->StmtAs<tir::ForNode>();
+      if (loop == nullptr || loop->body->IsInstance<tir::SeqStmtNode>()) {
+        break;
+      }
+      block_loops.push_back(loop);
+      block_loop_vars.insert(loop->loop_var.get());
+      if (!analyzer.CanProve(loop->min == 0)) {
+        return NullOpt;
+      }
+    }
+    std::reverse(block_loops.begin(), block_loops.end());
+  }
+  // Step 3. Map from block loops to desc block loops
+  ObjectPtr<TensorizeInfoNode> ret = make_object<TensorizeInfoNode>();
+  const int n_block_vars = block->iter_values.size();
+  const int n_desc_vars = desc_block->iter_values.size();
+  const int offset = n_block_vars - n_desc_vars;
+
+  if (offset < 0) {
+    return NullOpt;
+  }
+
+  const std::vector<IterVarType> iter_types_block = GetBlockVarTypes(block_sref);
+  const std::vector<IterVarType> iter_types_desc = GetBlockVarTypes(desc_block->block.get());
+
+  ICHECK(desc_loops.size() == static_cast<size_t>(n_desc_vars));
+  ICHECK(block_loops.size() == iter_types_block.size());
+
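+  // Greedily match loops starting from the innermost axis: a block loop can
+  // implement a desc loop iff its extent is divisible by the desc loop's
+  // extent and the two iterator types (spatial vs. reduction) agree.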
+  int next_block_ind = block_loops.size() - 1;
+  for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) {
+    const tir::ForNode* desc_loop = desc_loops[i_desc];
+    const IntImmNode* int_desc_extent = desc_loop->extent.as<IntImmNode>();
+    if (!int_desc_extent) continue;
+
+    for (int i_block = next_block_ind; i_block >= 0; --i_block) {
+      const tir::ForNode* block_loop = block_loops[i_block];
+      const IntImmNode* int_block_extent = block_loop->extent.as<IntImmNode>();
+
+      if (!int_block_extent) continue;
+      if (int_block_extent->value % int_desc_extent->value != 0) continue;
+      if (iter_types_block[i_block] != iter_types_desc[i_desc]) continue;
+
+      const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop];
+      ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
+      next_block_ind = i_block - 1;
+      break;
+    }
+  }

Review Comment:
   The logic here is very different from the one in the original code: https://github.com/spectrometerHBH/tvm/blob/auto-tensorization/src/tir/schedule/analysis/analysis.cc#L1246. I was not able to understand why the original code was written that way, and it didn't work for cases where the matching loops in the target block are not in the innermost positions (conv2d NCHWc on CPU; see the test in https://github.com/apache/tvm/blob/d6ae84879d4eb7befc3fc07e0f967973f50ece16/tests/python/unittest/test_tir_schedule_analysis.py#L199).
   
   I think my change is simple and obvious. The condition for a block loop to match a desc loop is (1) the block loop extent is divisible by the desc loop extent, and (2) the iterator types match (reduction vs. spatial). The mapping is determined greedily, starting from the innermost axis.
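
   To make this concrete, here is a rough standalone sketch of the same greedy, innermost-first matching on plain integer extents. This is not the actual TVM code; `Loop`, `IterType`, and `MatchLoops` are invented for illustration only:
   
   ```cpp
   // Illustrative only: a simplified model of the loop matching in this PR.
   #include <cstdio>
   #include <map>
   #include <vector>
   
   enum class IterType { kSpatial, kReduction };
   
   struct Loop {
     long extent;    // loop extent, assumed to be a known constant here
     IterType type;  // spatial or reduction
   };
   
   // Map desc loops to block loops, innermost first. A pair matches iff the
   // block extent is divisible by the desc extent and the iterator types agree.
   std::map<int, int> MatchLoops(const std::vector<Loop>& block_loops,
                                 const std::vector<Loop>& desc_loops) {
     std::map<int, int> mapping;  // block loop index -> desc loop index
     int next_block = static_cast<int>(block_loops.size()) - 1;
     for (int i_desc = static_cast<int>(desc_loops.size()) - 1; i_desc >= 0; --i_desc) {
       for (int i_block = next_block; i_block >= 0; --i_block) {
         const Loop& b = block_loops[i_block];
         const Loop& d = desc_loops[i_desc];
         if (b.extent % d.extent != 0) continue;  // condition (1)
         if (b.type != d.type) continue;          // condition (2)
         mapping[i_block] = i_desc;
         next_block = i_block - 1;  // outer desc loops map to outer block loops
         break;
       }
     }
     return mapping;
   }
   
   int main() {
     // A conv2d-NCHWc-like nest: the spatial loop to be tensorized (extent 32)
     // sits above the reduction loops, i.e. it is not in an innermost position.
     std::vector<Loop> block = {
         {1, IterType::kSpatial},    {4, IterType::kSpatial},
         {56, IterType::kSpatial},   {32, IterType::kSpatial},
         {64, IterType::kReduction}, {16, IterType::kReduction}};
     // A 16x16 intrinsic description: one spatial and one reduction loop.
     std::vector<Loop> desc = {{16, IterType::kSpatial}, {16, IterType::kReduction}};
     for (auto [b, d] : MatchLoops(block, desc)) {
       std::printf("block loop %d -> desc loop %d\n", b, d);
     }
     return 0;
   }
   ```
   
   Running the sketch prints `block loop 3 -> desc loop 0` and `block loop 5 -> desc loop 1`: the spatial block loop is matched even though a reduction loop sits below it, which mirrors the conv2d NCHWc case above.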
   
   Please have a look at this change carefully, and let me know if I need to bring back some of the logic from the original code. @spectrometerHBH @vinx13



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
