merrymercy commented on a change in pull request #6103:
URL: https://github.com/apache/incubator-tvm/pull/6103#discussion_r459749490



##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -114,7 +118,432 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& 
tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Walks an expression, recording every ProducerLoad and whether a branching construct appears
+class TensorAccessExtractor : public StmtExprVisitor {
+ public:
+  // Entry point: visit `expr` and accumulate the results into the public members below
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  // An IfThenElse statement is a branch
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  // A Select expression is a branch
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  // Among calls, only the if_then_else builtin is treated as a branch
+  void VisitExpr_(const CallNode* op) final {
+    has_branch |= op->op.same_as(builtin::if_then_else());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  // Append the indices of this load to the access list of the operation producing the tensor
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    auto& accesses = buf_accesses[Downcast<te::Tensor>(op->producer)->op];
+    accesses.emplace_back(op->indices.begin(), op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  // For each producer operation: one index vector per visited load
+  OperationMap<std::vector<std::vector<PrimExpr>>> buf_accesses;
+  // Set once any if_then_else call, Select expression or IfThenElse statement has been visited
+  bool has_branch{false};
+};
+
+// Returns whether `expr` is `var` itself, or an Add/Sub node with `var` on one side and an
+// integer immediate on the other. Note that `imm - var` is accepted as well as `var - imm`.
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  const VarNode* target = var.get();
+
+  // True when one operand is exactly `var` and the other one is an IntImm
+  auto is_var_with_imm = [target](const PrimExpr& lhs, const PrimExpr& rhs) {
+    return (lhs.get() == target && rhs->IsInstance<IntImmNode>()) ||
+           (rhs.get() == target && lhs->IsInstance<IntImmNode>());
+  };
+
+  if (const auto* as_var = expr.as<VarNode>()) {
+    return as_var == target;
+  }
+  if (const auto* as_add = expr.as<AddNode>()) {
+    return is_var_with_imm(as_add->a, as_add->b);
+  }
+  if (const auto* as_sub = expr.as<SubNode>()) {
+    return is_var_with_imm(as_sub->a, as_sub->b);
+  }
+  return false;
+}
+
+// Return whether the access `index` is injective with respect to the axes of `op`.
+// Returns false, without writing any output flag, when `op` is not a ComputeOp or when some
+// non-constant index does not match any axis var of `op` (see IsConstShiftEqual).
+// On a true return the three output flags are filled in:
+//   *axis_missing    - some axes of `op` are not matched by any index
+//   *axis_duplicated - some axes of `op` are matched by more than one index
+//   *same_order      - the matched axes appear in the same order as in op->axis
+bool IsInjective(const te::Operation& op, const std::vector<PrimExpr>& index, bool* axis_missing,
+                 bool* axis_duplicated, bool* same_order) {
+  const auto* compute = op.as<te::ComputeOpNode>();
+  if (compute == nullptr) {
+    return false;
+  }
+
+  const size_t num_axes = compute->axis.size();
+  std::vector<int> matched_axes;            // axis position matched by each non-constant index
+  std::vector<int> use_count(num_axes, 0);  // number of indices matched to each axis
+
+  for (const auto& expr : index) {
+    if (is_const_int(expr)) {
+      continue;  // constant indices take no part in the analysis
+    }
+    // Linear scan for the first axis whose var matches this index
+    size_t axis_id = 0;
+    while (axis_id < num_axes && !IsConstShiftEqual(compute->axis[axis_id]->var, expr)) {
+      ++axis_id;
+    }
+    if (axis_id == num_axes) {
+      return false;
+    }
+    matched_axes.push_back(axis_id);
+    ++use_count[axis_id];
+  }
+
+  *axis_missing = false;
+  *axis_duplicated = false;
+  *same_order = true;
+  for (int count : use_count) {
+    if (count > 1) {
+      *axis_duplicated = true;
+    } else if (count == 0) {
+      *axis_missing = true;
+    }
+  }
+  // The order is broken as soon as a matched axis position decreases
+  for (size_t i = 0; i + 1 < matched_axes.size(); ++i) {
+    if (matched_axes[i] > matched_axes[i + 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Collect every VarNode visited in `expr` into the set pointed to by `vars`
+static void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {
+  auto collect = [vars](const ObjectRef& node) {
+    const VarNode* var_node = node.as<VarNode>();
+    if (var_node != nullptr) {
+      vars->insert(var_node);
+    }
+  };
+  PostOrderVisit(expr, collect);
+}
+
+// Check whether an expr has expensive operations (e.g. exp).
+// Returns true when any CallNode visited in `expr` calls the Op named "tir.exp".
+static bool HasExpensiveOp(const PrimExpr& expr) {
+  bool found = false;
+  PostOrderVisit(expr, [&found](const ObjectRef& node) {
+    if (const CallNode* call = node.as<CallNode>()) {
+      // `as<OpNode>()` yields nullptr when the callee is not an Op, so it must be checked
+      // before reading `name`; such calls are simply not counted as expensive.
+      const OpNode* callee = call->op.as<OpNode>();
+      if (callee != nullptr && callee->name == "tir.exp") {
+        found = true;
+      }
+    }
+  });
+  return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+  auto node = make_object<AccessAnalyzerNode>();
+  OperationMap<bool> has_branch;
+
+  // get all ops
+  node->ops_topo_order = TopoSortOps(tensors);
+
+  arith::Analyzer analyzer;
+
+  // build read & write access map
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+    } else if (auto cop = op.as<te::ComputeOpNode>()) {
+      TensorAccessExtractor extractor;
+      for (const auto& exp : cop->body) {

Review comment:
       No. It will be addressed later.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to