This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e997185  [Relay] Change some passes to mix mode (#6695)
e997185 is described below

commit e997185795480d24075a2e7d3fa42ccec425b5f6
Author: lixiaoquan <radiohe...@163.com>
AuthorDate: Fri Oct 16 23:47:27 2020 +0800

    [Relay] Change some passes to mix mode (#6695)
---
 src/relay/analysis/util.cc            |  8 ++++++--
 src/relay/analysis/well_formed.cc     | 16 +++++++---------
 src/relay/ir/expr_functor.cc          |  4 +++-
 src/relay/transforms/de_duplicate.cc  |  6 ++++--
 src/relay/transforms/fold_constant.cc | 32 ++++++++++++++++----------------
 5 files changed, 36 insertions(+), 30 deletions(-)

diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc
index 59ce01c..edf8fb6 100644
--- a/src/relay/analysis/util.cc
+++ b/src/relay/analysis/util.cc
@@ -71,7 +71,7 @@ class TypeVarTVisitor : public TypeVisitor {
   InsertionSet<TypeVar>* bound_type_vars_;
 };
 
-class TypeVarEVisitor : private ExprVisitor {
+class TypeVarEVisitor : private MixedModeVisitor {
  public:
   explicit TypeVarEVisitor(const IRModule& mod) : mod_(mod) {}
 
@@ -131,6 +131,8 @@ class TypeVarEVisitor : private ExprVisitor {
     return CollectAll();
   }
 
+  using MixedModeVisitor::VisitExpr_;
+
   void VisitExpr_(const FunctionNode* f) final {
     for (const auto& tp : f->type_params) {
       type_vars_.Insert(tp);
@@ -159,7 +161,7 @@ class TypeVarEVisitor : private ExprVisitor {
   const IRModule& mod_;
 };
 
-class VarVisitor : protected ExprVisitor, protected PatternVisitor {
+class VarVisitor : protected MixedModeVisitor, protected PatternVisitor {
  public:
   Array<Var> Free(const Expr& expr) {
     this->VisitExpr(expr);
@@ -204,6 +206,8 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor {
     vars_.Insert(v);
   }
 
+  using MixedModeVisitor::VisitExpr_;
+
   void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef<Var>(var)); }
 
   void VisitExpr_(const FunctionNode* op) final {
diff --git a/src/relay/analysis/well_formed.cc b/src/relay/analysis/well_formed.cc
index 3e409d1..5abbbc9 100644
--- a/src/relay/analysis/well_formed.cc
+++ b/src/relay/analysis/well_formed.cc
@@ -32,7 +32,7 @@ namespace tvm {
 namespace relay {
 
 //! brief make sure each Var is bound at most once in a scope.
-class WellFormedChecker : private ExprVisitor, PatternVisitor {
+class WellFormedChecker : private MixedModeVisitor, PatternVisitor {
  public:
   Optional<DiagnosticContext> diag_ctx;
   Span occurs_in;
@@ -79,6 +79,8 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor {
     total_bound.insert(v);
   }
 
+  using MixedModeVisitor::VisitExpr_;
+
   void VisitExpr_(const VarNode* op) final {
     Var v = GetRef<Var>(op);
     if (current_bound.count(v) == 0) {
@@ -126,7 +128,7 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor {
 
     // CHECK(call->attrs.defined());
     CHECK(call->type_args.defined());
-    ExprVisitor::VisitExpr_(call);
+    MixedModeVisitor::VisitExpr_(call);
   }
 
   void VisitClause(const Clause& c) final {
@@ -139,18 +141,14 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor {
 
   void VisitVar(const Var& v) final { Bound(v); }
 
-  void VisitExpr(const Expr& e) final {
+ public:
+  bool CheckWellFormed(const Expr& e) {
     if (auto v = e.as<VarNode>()) {
       VisitExpr_(v);
     } else {
       // this->occurs_in = e->span;
-      ExprVisitor::VisitExpr(e);
+      VisitExpr(e);
     }
-  }
-
- public:
-  bool CheckWellFormed(const Expr& e) {
-    this->VisitExpr(e);
     return well_formed;
   }
 };
diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc
index cbc41d2..a09179b 100644
--- a/src/relay/ir/expr_functor.cc
+++ b/src/relay/ir/expr_functor.cc
@@ -517,10 +517,12 @@ TVM_REGISTER_GLOBAL("relay.analysis.post_order_visit").set_body_typed([](Expr ex
 });
 
 // Implement bind.
-class ExprBinder : public ExprMutator, PatternMutator {
+class ExprBinder : public MixedModeMutator, PatternMutator {
  public:
   explicit ExprBinder(const tvm::Map<Var, Expr>& args_map) : 
args_map_(args_map) {}
 
+  using MixedModeMutator::VisitExpr_;
+
   Expr VisitExpr_(const LetNode* op) final {
     CHECK(!args_map_.count(op->var)) << "Cannot bind an internel variable in 
let";
     return ExprMutator::VisitExpr_(op);
diff --git a/src/relay/transforms/de_duplicate.cc b/src/relay/transforms/de_duplicate.cc
index d90e5c5..8c62fe6 100644
--- a/src/relay/transforms/de_duplicate.cc
+++ b/src/relay/transforms/de_duplicate.cc
@@ -31,7 +31,7 @@ namespace tvm {
 namespace relay {
 
 Expr DeDup(const Expr& e) {
-  class DeDupMutator : public TypeMutator, public ExprMutator, public 
PatternMutator {
+  class DeDupMutator : public TypeMutator, public MixedModeMutator, public 
PatternMutator {
    public:
     TypeVar Fresh(const TypeVar& tv) {
       TypeVar ret = TypeVar(tv->name_hint, tv->kind);
@@ -47,12 +47,14 @@ Expr DeDup(const Expr& e) {
       return ret;
     }
 
-    Expr VisitExpr(const Expr& e) final {
+    Expr DispatchVisitExpr(const Expr& e) final {
       auto ret = ExprMutator::VisitExpr(e);
       ret->checked_type_ = e->checked_type_;
       return ret;
     }
 
+    using MixedModeMutator::VisitExpr_;
+
     Expr VisitExpr_(const VarNode* op) final {
       Var v = GetRef<Var>(op);
       return rename_.count(v) != 0 ? rename_.at(v) : v;
diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index 660aff2..8d2cba0 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -75,7 +75,7 @@ TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(ConstantChec
 
 // TODO(tvm-team) consider combine dead-code with constant folder.
 // or make a more powerful partial evaluator.
-class ConstantFolder : public ExprMutator {
+class ConstantFolder : public MixedModeMutator {
  public:
   explicit ConstantFolder(IRModule module)
       : module_(module),
@@ -89,6 +89,8 @@ class ConstantFolder : public ExprMutator {
         cast_op_(Op::Get("cast")),
         ndarray_size_op_(Op::Get("ndarray_size")) {}
 
+  using MixedModeMutator::VisitExpr_;
+
   Expr VisitExpr_(const LetNode* op) final {
     Expr value = this->Mutate(op->value);
     if (value.as<ConstantNode>()) {
@@ -118,7 +120,7 @@ class ConstantFolder : public ExprMutator {
     }
   }
 
-  Expr VisitExpr_(const CallNode* call) final {
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
     if (inside_primitive) {
       return GetRef<Expr>(call);
     }
@@ -127,26 +129,25 @@ class ConstantFolder : public ExprMutator {
     std::unordered_set<std::string> skip_list{"zeros_like", "ones_like", 
"full_like", "full"};
 
     auto origin_args = call->args;
-    Expr res = ExprMutator::VisitExpr_(call);
-    call = res.as<CallNode>();
+    call = post.as<CallNode>();
     // We don't constant fold function with zero arguments.
     // This is a heuristic that is useful.
     // For example it is harmful to fold ones(shape=(4, 5)).
-    if (call->args.size() == 0) return res;
+    if (call->args.size() == 0) return post;
     const OpNode* op = call->op.as<OpNode>();
-    if (op == nullptr) return res;
+    if (op == nullptr) return post;
     if (skip_list.count(op->name)) {
-      return res;
+      return post;
     }
     // skip stateful ops.
-    if (op_stateful.get(GetRef<Op>(op), false)) return res;
+    if (op_stateful.get(GetRef<Op>(op), false)) return post;
     // Try to evaluate shape_of op
     if (call->op == shape_of_op_ || call->op == vm_shape_of_op_) {
-      return EvaluateShapeOf(res, origin_args, call->attrs);
+      return EvaluateShapeOf(post, origin_args, call->attrs);
     }
 
     if (call->op == ndarray_size_op_) {
-      return EvaluateNdarraySize(res, origin_args, call->attrs);
+      return EvaluateNdarraySize(post, origin_args, call->attrs);
     }
 
     // We should think about potentially constant evaluation over these ops 
too.
@@ -162,19 +163,18 @@ class ConstantFolder : public ExprMutator {
       }
     }
     if (all_const_args) {
-      return ConstEvaluate(res);
+      return ConstEvaluate(post);
     } else {
-      return res;
+      return post;
     }
   }
 
-  Expr VisitExpr_(const TupleGetItemNode* op) final {
-    Expr res = ExprMutator::VisitExpr_(op);
-    op = res.as<TupleGetItemNode>();
+  Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
+    op = post.as<TupleGetItemNode>();
     if (const auto* tuple = op->tuple.as<TupleNode>()) {
       return tuple->fields[op->index];
     } else {
-      return res;
+      return post;
     }
   }
 

Reply via email to