junrushao1994 commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r458573185



##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {

Review comment:
       The extractor basically does two things:
   1) Check whether there are branches.
   2) For each op, record where it is read, as a list of multi-dimensional indices.
   
   So I think the class name might be misleading, because tensor writes are not counted. Let's find a better name.
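   For example (just a sketch; `ReadAccessExtractor` is only my placeholder name, feel free to pick a better one):
    ```cpp
    // Gathers the read accesses (and whether branches exist) in an expr;
    // tensor writes are intentionally not recorded.
    class ReadAccessExtractor : public StmtExprVisitor {
     public:
      void Extract(PrimExpr expr) { this->VisitExpr(expr); }
      // ... the rest of the visitor stays as in this PR ...
    };
    ```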

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {

Review comment:
       It doesn't have to be `static` (being `static` might interfere with backtraces in error reporting).
   
   ```suggestion
   std::unordered_set<const VarNode*> ExprGatherVars(const PrimExpr& expr) {
   ```
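    If we go this way, the call sites would switch from an out-parameter to a return value, roughly (sketch only):
    ```cpp
    // before (out-parameter style in the current PR)
    std::unordered_set<const VarNode*> vars;
    GatherVars(expr, &vars);

    // after (return-value style from the suggestion above)
    std::unordered_set<const VarNode*> vars = ExprGatherVars(expr);
    ```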

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {
+  PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+    if (const VarNode* op = node.as<VarNode>()) {
+      vars->insert(op);
+    }
+  });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {

Review comment:
       ditto
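   i.e., drop `static` here as well (sketch):
    ```cpp
    // Check whether an expr has expensive operations (e.g. exp)
    bool HasExpensiveOp(const PrimExpr& expr);
    ```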

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {

Review comment:
       consider moving to utils.h?
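   e.g., just the declaration in `src/auto_scheduler/utils.h` (sketch only; keeping the signature from this PR):
    ```cpp
    /*! \brief Gather all VarNodes that appear in an expr */
    void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars);
    ```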

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));

Review comment:
       It is a bit counter-intuitive. Let's use pattern matching in tvm/arith 
instead.
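   A rough, untested sketch of what I have in mind (assuming `src/arith/pattern_match.h` can be included from here):
    ```cpp
    #include "../arith/pattern_match.h"

    // Returns whether `expr` is `var`, optionally shifted by an integer constant
    bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
      arith::PVar<PrimExpr> x;
      arith::PVar<IntImm> c;
      if ((x + c).Match(expr) || (c + x).Match(expr) ||
          (x - c).Match(expr) || (c - x).Match(expr)) {
        return x.Eval().same_as(var);
      }
      return expr.same_as(var);
    }
    ```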

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -126,6 +555,7 @@ class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
           fail_ = true;
           break;
         }
+        cur_type_code_ = pop->output_dtype(0).code();

Review comment:
       Could you elaborate on why we need `cur_type_code_`? How do we handle the case where the computation mixes int8 and fp32?
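   For instance (a made-up example, not from this PR), an int8 op followed by an fp32 epilogue:
    ```cpp
    // Hypothetical mixed-dtype DAG (names invented for illustration):
    // C does int8 arithmetic, D casts to float32 and scales the result.
    PrimExpr n = 1024;
    te::Tensor A = te::placeholder({n}, DataType::Int(8), "A");
    te::Tensor B = te::placeholder({n}, DataType::Int(8), "B");
    te::Tensor C = te::compute(
        {n}, [&](tir::Var i) { return A(i) + B(i); }, "C");
    te::Tensor D = te::compute(
        {n},
        [&](tir::Var i) {
          return cast(DataType::Float(32), C(i)) * make_const(DataType::Float(32), 0.25f);
        },
        "D");
    // Would the estimator attribute C's ops to int8 and D's to fp32, or does
    // cur_type_code_ effectively pick a single type for the whole count?
    ```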

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    buf_accesses[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                      op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with a const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  if (auto pv = expr.as<VarNode>()) {
+    return pv == var.get();
+  } else if (auto padd = expr.as<AddNode>()) {
+    return ((padd->a.get() == var.get() && padd->b->IsInstance<IntImmNode>()) ||
+            (padd->b.get() == var.get() && padd->a->IsInstance<IntImmNode>()));
+  } else if (auto psub = expr.as<SubNode>()) {
+    return ((psub->a.get() == var.get() && psub->b->IsInstance<IntImmNode>()) ||
+            (psub->b.get() == var.get() && psub->a->IsInstance<IntImmNode>()));
+  } else {
+    return false;
+  }
+}
+
+// Return whether the access is injective
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : index) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {
+  PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+    if (const VarNode* op = node.as<VarNode>()) {
+      vars->insert(op);
+    }
+  });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+static bool HasExpensiveOp(const PrimExpr& expr) {
+  bool found = false;
+  PostOrderVisit(expr, [&found](const ObjectRef& node) {
+    if (const CallNode* op = node.as<CallNode>()) {
+      if (op->op.as<OpNode>()->name == "tir.exp") {
+        found = true;
+      }
+    }
+  });
+  return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+  auto node = make_object<AccessAnalyzerNode>();
+  OperationMap<bool> has_branch;
+
+  // get all ops
+  node->ops_topo_order = TopoSortOps(tensors);
+
+  arith::Analyzer analyzer;
+
+  // build read & write access map
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+    } else if (auto cop = op.as<te::ComputeOpNode>()) {
+      TensorAccessExtractor extractor;
+      for (const auto& exp : cop->body) {
+        extractor.Extract(exp);
+      }
+
+      // read_by and read_from map
+      for (const auto& iter : extractor.buf_accesses) {
+        std::vector<std::vector<PrimExpr>>& accesses = node->read_by[iter.first][op];
+        accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end());
+      }
+
+      node->read_from[op] = std::move(extractor.buf_accesses);
+      has_branch[op] = extractor.has_branch;
+
+      // compute number of common outer iterators
+      for (const auto& pair : node->read_from[op]) {
+        const te::Operation& producer = pair.first;
+        const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+        const Array<PrimExpr>& output_shape = op->output_shape(0);
+        const Array<PrimExpr>& producer_shape = producer->output_shape(0);
+
+        int n_common;
+        for (n_common = 0;
+             n_common < static_cast<int>(std::min(output_shape.size(), producer_shape.size()));
+             n_common++) {
+          if (!is_zero(analyzer.Simplify(output_shape[n_common] - producer_shape[n_common]))) {
+            break;
+          }
+
+          bool direct_access = true;
+          for (const auto& access : access_list) {
+            if (!IsConstShiftEqual(cop->axis[n_common]->var, access[n_common])) {
+              direct_access = false;
+              break;
+            }
+          }
+
+          if (!direct_access) {
+            break;
+          }
+        }
+
+        node->num_common_outer_iterators[op][producer] = n_common;
+        node->num_common_outer_iterators[producer][op] = n_common;
+      }
+    } else {
+      LOG(FATAL) << "Invalid op: " << op;
+    }
+  }
+
+  // do some static analysis
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->is_injective[op] = true;
+      node->needs_multi_level_tiling[op] = false;
+      node->is_strict_inlineable[op] = false;
+      node->is_output[op] = false;
+    } else if (auto pop = op.as<te::ComputeOpNode>()) {
+      // check whether this op is element-wise and strict-inlineable
+      bool is_injective = true;
+      bool is_strict_inlineable = true;
+
+      bool axis_missing, axis_duplicated, same_order;
+      for (const auto& pair : node->read_from[op]) {
+        const std::vector<std::vector<PrimExpr>>& access = pair.second;
+        for (const auto& index : access) {
+          if (!auto_scheduler::IsInjective(op, index, &axis_missing, &axis_duplicated,
+                                           &same_order)) {
+            is_injective = false;
+            is_strict_inlineable = false;
+            break;
+          }
+          if (!same_order || axis_duplicated) {
+            // do not strictly inline transpose
+            is_strict_inlineable = false;
+          }
+        }
+        if (!is_injective) {
+          break;
+        }
+      }
+      if (has_branch[op]) {
+        is_strict_inlineable = false;
+      }
+
+      // don't strictly inline expensive op (e.g. exp)
+      bool has_expensive_op = false;
+      for (const auto& expr : pop->body) {
+        has_expensive_op |= HasExpensiveOp(expr);
+      }
+
+      node->is_injective[op] = is_injective;
+      node->is_strict_inlineable[op] = is_strict_inlineable && !has_expensive_op;
+
+      // check whether the op needs multi-level tiling
+      bool needs_multi_level_tiling = false;
+      int n_missing = 0;
+
+      for (const auto& pair : node->read_from[op]) {
+        const std::vector<std::vector<PrimExpr>>& access = pair.second;
+        std::unordered_set<const VarNode*> vars;
+        for (const std::vector<PrimExpr>& indices : access) {
+          for (const PrimExpr& expr : indices) {
+            GatherVars(expr, &vars);
+          }
+        }
+        bool missing = false;
+        for (const auto& axis : pop->axis) {
+          if (GetIntImm(axis->dom->extent) > 1 && vars.count(axis->var.get()) == 0) {
+            missing = true;
+          }
+        }
+        if (missing) {
+          n_missing++;
+        }
+
+        if (n_missing >= 2 || (n_missing >= 1 && !pop->reduce_axis.empty())) {
+          needs_multi_level_tiling = true;
+          break;
+        }
+      }
+
+      node->needs_multi_level_tiling[op] = needs_multi_level_tiling;
+
+      // check whether is output
+      node->is_output[op] = node->read_by[op].empty();
+    } else {
+      LOG(FATAL) << "Invalid op" << op;
+    }
+  }
+
+  data_ = std::move(node);
+}
+
+bool AccessAnalyzer::NeedsMultiLevelTiling(const te::Operation& op) const {
+  return operator->()->needs_multi_level_tiling.at(op);
+}
+
+bool AccessAnalyzer::IsOutput(const te::Operation& op) const {
+  return operator->()->is_output.at(op);
+}
+
+bool AccessAnalyzer::IsInjective(const te::Operation& op) const {
+  return operator->()->is_injective.at(op);
+}
+
+bool AccessAnalyzer::IsStrictInlineable(const te::Operation& op) const {
+  return operator->()->is_strict_inlineable.at(op);
+}
+
+void AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op,
+                                  OperationSet* consumers) const {
+  OperationSet inlined_ops;
+  for (const auto& stage : state->stages) {
+    if (stage->compute_at == ComputeAtKind::kInlined) {
+      inlined_ops.insert(stage->op);
+    }
+  }
+
+  std::function<void(const te::Operation&)> collect;
+  collect = [this, &collect, &inlined_ops, &consumers](const te::Operation& op) {
+    for (const auto& iter : operator->()->read_by.at(op)) {
+      if (inlined_ops.count(iter.first)) {
+        collect(iter.first);
+      } else {
+        consumers->insert(iter.first);
+      }
+    }
+  };
+
+  consumers->clear();
+  collect(op);
+}
+
+void AccessAnalyzer::GetDirectProducers(const te::Operation& op, OperationSet* producers) const {
+  producers->clear();
+  for (const auto& iter : operator->()->read_from.at(op)) {
+    producers->insert(iter.first);
+  }
+}
+
+void AccessAnalyzer::GetProducers(const State& state, const te::Operation& op,
+                                  OperationSet* producers) const {
+  OperationSet inlined_ops;
+  for (const auto& stage : state->stages) {
+    if (stage->compute_at == ComputeAtKind::kInlined) {
+      inlined_ops.insert(stage->op);
+    }
+  }
+
+  std::function<void(const te::Operation&)> collect;
+  collect = [this, &collect, &inlined_ops, &producers](const te::Operation& op) {
+    for (const auto& iter : operator->()->read_from.at(op)) {
+      if (inlined_ops.count(iter.first)) {
+        collect(iter.first);
+      } else {
+        producers->insert(iter.first);
+      }
+    }
+  };
+
+  producers->clear();
+  collect(op);
+}
+
+int AccessAnalyzer::GetNumCommonOuterIterator(const te::Operation& op,
+                                              const te::Operation& target_op) const {
+  int ret = INT32_MAX;
+  bool meet = false;
+
+  std::function<void(const te::Operation&, int)> traverse;
+  traverse = [this, &traverse, &target_op, &ret, &meet](const te::Operation& cur_op, int cur_num) {
+    if (cur_op == target_op) {
+      ret = std::min(ret, cur_num);
+      meet = true;
+      return;
+    }
+
+    for (const auto& iter : operator->()->read_by.at(cur_op)) {
+      traverse(
+          iter.first,
+          std::min(cur_num, operator->()->num_common_outer_iterators.at(cur_op).at(iter.first)));
+    }
+  };
+
+  traverse(op, op->output_shape(0).size());
+  return meet ? ret : 0;
+}
+
+// Return whether two int arrays are elementwise-equal
+bool IntArrayEqual(const Array<PrimExpr>& arr1, const Array<PrimExpr>& arr2) {

Review comment:
       Move this to utils.h as well?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

