AndrewZhaoLuo commented on a change in pull request #9735:
URL: https://github.com/apache/tvm/pull/9735#discussion_r771638737



##########
File path: src/relay/transforms/type_infer.cc
##########
@@ -824,8 +824,107 @@ void AddGlobalTypes(IRModule mod) {
   }
 }
 
+class SameTypedSubgraphExtractor : public ExprMutator {
+  /*
+  Creates a small subgraph with the same type as the input expression. We attempt to do so
+  by depending on existing type information being populated in expressions the target
+  node depends on. If a node with populated type information is found we simply
+  replace it with a variable of that type. In this way, we can avoid copying and
+  recursing through most of the expression graph. Note, this assumes that current
+  populated type information is correct!
+
+  ExprMutator is sufficient over MixedModeMutator since we will not recurse much.
+  */
+
+  Expr VisitExpr_(const VarNode* op) { return Var(op->vid, 
op->type_annotation, op->span); }
+  Expr VisitExpr_(const ConstantNode* op) { return Constant(op->data, 
op->span); }
+  Expr VisitExpr_(const GlobalVarNode* op) { return GlobalVar(op->name_hint); }
+  Expr VisitExpr_(const OpNode* op) { return Op(GetRef<Op>(op)); }
+  Expr VisitExpr_(const TupleNode* op) {
+    return Tuple(get_analogous_expression(op->fields), op->span);
+  }
+  Expr VisitExpr_(const FunctionNode* op) {
+    // Here will be the only VisitExpr
+    return Function(op->params, get_analogous_expression(op->body), 
op->ret_type, op->type_params,
+                    op->attrs, op->span);
+  }
+  Expr VisitExpr_(const CallNode* op) {
+    return Call(op->op, get_analogous_expression(op->args), op->attrs, 
op->type_args, op->span);
+  }
+  Expr VisitExpr_(const LetNode* op) {
+    return Let(op->var, get_analogous_expression(op->value), 
get_analogous_expression(op->body),
+               op->span);
+  }
+  Expr VisitExpr_(const IfNode* op) {
+    return If(get_analogous_expression(op->cond), 
get_analogous_expression(op->true_branch),
+              get_analogous_expression(op->false_branch), op->span);
+  }
+  Expr VisitExpr_(const TupleGetItemNode* op) {
+    return TupleGetItem(get_analogous_expression(op->tuple), op->index, 
op->span);
+  }
+  Expr VisitExpr_(const RefCreateNode* op) {
+    return RefCreate(get_analogous_expression(op->value), op->span);
+  }
+  Expr VisitExpr_(const RefReadNode* op) {
+    return RefRead(get_analogous_expression(op->ref), op->span);
+  }
+  Expr VisitExpr_(const RefWriteNode* op) {
+    return RefWrite(get_analogous_expression(op->ref), 
get_analogous_expression(op->value),
+                    op->span);
+  }
+  Expr VisitExpr_(const ConstructorNode* op) {
+    return Constructor(op->name_hint, op->inputs, op->belong_to);
+  }
+  Expr VisitExpr_(const MatchNode* op) {
+    return Match(get_analogous_expression(op->data), op->clauses, 
op->complete, op->span);
+  }
+
+ private:
+  Expr get_analogous_expression(const Expr& expr) {
+    // Replace the expression with a potentially simpler expression of the 
same type
+    if (!expr->checked_type_.defined()) {
+      return VisitExpr(expr);
+    }
+
+    return Var("dummy_var", expr->checked_type(), expr->span);

Review comment:
       Done

##########
File path: src/relay/transforms/type_infer.cc
##########
@@ -824,8 +824,107 @@ void AddGlobalTypes(IRModule mod) {
   }
 }
 
+class SameTypedSubgraphExtractor : public ExprMutator {
+  /*
+  Creates a small subgraph with the same type as the input expression. We attempt to do so
+  by depending on existing type information being populated in expressions the target
+  node depends on. If a node with populated type information is found we simply
+  replace it with a variable of that type. In this way, we can avoid copying and
+  recursing through most of the expression graph. Note, this assumes that current
+  populated type information is correct!
+
+  ExprMutator is sufficient over MixedModeMutator since we will not recurse much.
+  */
+
+  Expr VisitExpr_(const VarNode* op) { return Var(op->vid, 
op->type_annotation, op->span); }
+  Expr VisitExpr_(const ConstantNode* op) { return Constant(op->data, 
op->span); }
+  Expr VisitExpr_(const GlobalVarNode* op) { return GlobalVar(op->name_hint); }
+  Expr VisitExpr_(const OpNode* op) { return Op(GetRef<Op>(op)); }
+  Expr VisitExpr_(const TupleNode* op) {
+    return Tuple(get_analogous_expression(op->fields), op->span);
+  }
+  Expr VisitExpr_(const FunctionNode* op) {
+    // Here will be the only VisitExpr
+    return Function(op->params, get_analogous_expression(op->body), 
op->ret_type, op->type_params,
+                    op->attrs, op->span);
+  }
+  Expr VisitExpr_(const CallNode* op) {
+    return Call(op->op, get_analogous_expression(op->args), op->attrs, 
op->type_args, op->span);
+  }
+  Expr VisitExpr_(const LetNode* op) {
+    return Let(op->var, get_analogous_expression(op->value), 
get_analogous_expression(op->body),
+               op->span);
+  }
+  Expr VisitExpr_(const IfNode* op) {
+    return If(get_analogous_expression(op->cond), 
get_analogous_expression(op->true_branch),
+              get_analogous_expression(op->false_branch), op->span);
+  }
+  Expr VisitExpr_(const TupleGetItemNode* op) {
+    return TupleGetItem(get_analogous_expression(op->tuple), op->index, 
op->span);
+  }
+  Expr VisitExpr_(const RefCreateNode* op) {
+    return RefCreate(get_analogous_expression(op->value), op->span);
+  }
+  Expr VisitExpr_(const RefReadNode* op) {
+    return RefRead(get_analogous_expression(op->ref), op->span);
+  }
+  Expr VisitExpr_(const RefWriteNode* op) {
+    return RefWrite(get_analogous_expression(op->ref), 
get_analogous_expression(op->value),
+                    op->span);
+  }
+  Expr VisitExpr_(const ConstructorNode* op) {
+    return Constructor(op->name_hint, op->inputs, op->belong_to);
+  }
+  Expr VisitExpr_(const MatchNode* op) {
+    return Match(get_analogous_expression(op->data), op->clauses, 
op->complete, op->span);
+  }
+
+ private:
+  Expr get_analogous_expression(const Expr& expr) {

Review comment:
       Done

##########
File path: src/relay/transforms/type_infer.cc
##########
@@ -824,8 +824,107 @@ void AddGlobalTypes(IRModule mod) {
   }
 }
 
+class SameTypedSubgraphExtractor : public ExprMutator {
+  /*

Review comment:
       Done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to