mbs-octoml commented on a change in pull request #9735:
URL: https://github.com/apache/tvm/pull/9735#discussion_r771001558



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -381,6 +381,18 @@ class MixedPrecisionPass : public MixedModeMutator {
     return Call(cur_op, new_args, pre_call_node->attrs, new_arg_types, 
pre_call_node->span);
   }
 
+  Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) {
+    // The old checked type in the expression may not be valid so clear it
+    post->checked_type_ = Type(nullptr);

Review comment:
       am I missing something, or will checked_type_ be null iff some 
sub-expression of post has been rewritten and thus its type has changed?
   i.e. checked_type_ is non-null only if pre == post.get()?
   

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -176,13 +176,13 @@ class MixedPrecisionPass : public MixedModeMutator {
   }
 
   Type GetType(const Expr& expr) const {
-    auto mod = IRModule::FromExpr(expr);
-    mod = transform::InferType()(mod);
-    if (expr.as<FunctionNode>()) {
-      return mod->Lookup("main")->checked_type();
-    } else {
-      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    Type checked_type = expr->checked_type_;
+    if (checked_type.defined()) {
+      return checked_type;

Review comment:
       // The expression has not been changed AND its existing type
   // is known to still be valid. (See the special handling for tuples etc.
   // below for where we null out checked_type_ when we cannot be
   // sure it is still valid.)
   
   (though see my comment below)

##########
File path: src/relay/transforms/type_infer.cc
##########
@@ -824,8 +824,107 @@ void AddGlobalTypes(IRModule mod) {
   }
 }
 
+class SameTypedSubgraphExtractor : public ExprMutator {
+  /*

Review comment:
       nit: Returns the largest sub-graph whose inner nodes need types and 
whose leaves are vars standing in
   for already-typed sub-expressions.

##########
File path: src/relay/transforms/type_infer.cc
##########
@@ -824,8 +824,107 @@ void AddGlobalTypes(IRModule mod) {
   }
 }
 
+class SameTypedSubgraphExtractor : public ExprMutator {
+  /*

Review comment:
       micro nit: move this comment to before the class, and use /*! etc.

##########
File path: src/relay/transforms/type_infer.cc
##########
@@ -824,8 +824,107 @@ void AddGlobalTypes(IRModule mod) {
   }
 }
 
+class SameTypedSubgraphExtractor : public ExprMutator {
+  /*
+  Creates a small subgraph with the same type as the input expression. We 
attempt to do
+  by depending on existing type information being populated in expressions the 
target
+  node depends on. If a node with populated type information is found we simply
+  replace it with a variable of that type. In this way, we can avoid copying 
and
+  recursing through most of the expression graph. Note, this assumes that 
current
+  populated type information is correct!
+
+  ExprMutator is sufficient over MixedModemutator since we will not recurse 
much.
+  */
+
+  Expr VisitExpr_(const VarNode* op) { return Var(op->vid, 
op->type_annotation, op->span); }
+  Expr VisitExpr_(const ConstantNode* op) { return Constant(op->data, 
op->span); }
+  Expr VisitExpr_(const GlobalVarNode* op) { return GlobalVar(op->name_hint); }
+  Expr VisitExpr_(const OpNode* op) { return Op(GetRef<Op>(op)); }
+  Expr VisitExpr_(const TupleNode* op) {
+    return Tuple(get_analogous_expression(op->fields), op->span);
+  }
+  Expr VisitExpr_(const FunctionNode* op) {
+    // Here will be the only VisitExpr
+    return Function(op->params, get_analogous_expression(op->body), 
op->ret_type, op->type_params,
+                    op->attrs, op->span);
+  }
+  Expr VisitExpr_(const CallNode* op) {
+    return Call(op->op, get_analogous_expression(op->args), op->attrs, 
op->type_args, op->span);
+  }
+  Expr VisitExpr_(const LetNode* op) {
+    return Let(op->var, get_analogous_expression(op->value), 
get_analogous_expression(op->body),
+               op->span);
+  }
+  Expr VisitExpr_(const IfNode* op) {
+    return If(get_analogous_expression(op->cond), 
get_analogous_expression(op->true_branch),
+              get_analogous_expression(op->false_branch), op->span);
+  }
+  Expr VisitExpr_(const TupleGetItemNode* op) {
+    return TupleGetItem(get_analogous_expression(op->tuple), op->index, 
op->span);
+  }
+  Expr VisitExpr_(const RefCreateNode* op) {
+    return RefCreate(get_analogous_expression(op->value), op->span);
+  }
+  Expr VisitExpr_(const RefReadNode* op) {
+    return RefRead(get_analogous_expression(op->ref), op->span);
+  }
+  Expr VisitExpr_(const RefWriteNode* op) {
+    return RefWrite(get_analogous_expression(op->ref), 
get_analogous_expression(op->value),
+                    op->span);
+  }
+  Expr VisitExpr_(const ConstructorNode* op) {
+    return Constructor(op->name_hint, op->inputs, op->belong_to);
+  }
+  Expr VisitExpr_(const MatchNode* op) {
+    return Match(get_analogous_expression(op->data), op->clauses, 
op->complete, op->span);
+  }
+
+ private:
+  Expr get_analogous_expression(const Expr& expr) {
+    // Replace the expression with a potentially simpler expression of the 
same type
+    if (!expr->checked_type_.defined()) {
+      return VisitExpr(expr);
+    }
+
+    return Var("dummy_var", expr->checked_type(), expr->span);

Review comment:
       // Since the expression already has a checked_type which we trust, we 
don't need
   // full type inference to enter it. So stub it out with a dummy var of the 
same type.

##########
File path: src/relay/transforms/type_infer.cc
##########
@@ -824,8 +824,107 @@ void AddGlobalTypes(IRModule mod) {
   }
 }
 
+class SameTypedSubgraphExtractor : public ExprMutator {
+  /*
+  Creates a small subgraph with the same type as the input expression. We 
attempt to do
+  by depending on existing type information being populated in expressions the 
target
+  node depends on. If a node with populated type information is found we simply
+  replace it with a variable of that type. In this way, we can avoid copying 
and
+  recursing through most of the expression graph. Note, this assumes that 
current
+  populated type information is correct!
+
+  ExprMutator is sufficient over MixedModemutator since we will not recurse 
much.
+  */
+
+  Expr VisitExpr_(const VarNode* op) { return Var(op->vid, 
op->type_annotation, op->span); }
+  Expr VisitExpr_(const ConstantNode* op) { return Constant(op->data, 
op->span); }
+  Expr VisitExpr_(const GlobalVarNode* op) { return GlobalVar(op->name_hint); }
+  Expr VisitExpr_(const OpNode* op) { return Op(GetRef<Op>(op)); }
+  Expr VisitExpr_(const TupleNode* op) {
+    return Tuple(get_analogous_expression(op->fields), op->span);
+  }
+  Expr VisitExpr_(const FunctionNode* op) {
+    // Here will be the only VisitExpr
+    return Function(op->params, get_analogous_expression(op->body), 
op->ret_type, op->type_params,
+                    op->attrs, op->span);
+  }
+  Expr VisitExpr_(const CallNode* op) {
+    return Call(op->op, get_analogous_expression(op->args), op->attrs, 
op->type_args, op->span);
+  }
+  Expr VisitExpr_(const LetNode* op) {
+    return Let(op->var, get_analogous_expression(op->value), 
get_analogous_expression(op->body),
+               op->span);
+  }
+  Expr VisitExpr_(const IfNode* op) {
+    return If(get_analogous_expression(op->cond), 
get_analogous_expression(op->true_branch),
+              get_analogous_expression(op->false_branch), op->span);
+  }
+  Expr VisitExpr_(const TupleGetItemNode* op) {
+    return TupleGetItem(get_analogous_expression(op->tuple), op->index, 
op->span);
+  }
+  Expr VisitExpr_(const RefCreateNode* op) {
+    return RefCreate(get_analogous_expression(op->value), op->span);
+  }
+  Expr VisitExpr_(const RefReadNode* op) {
+    return RefRead(get_analogous_expression(op->ref), op->span);
+  }
+  Expr VisitExpr_(const RefWriteNode* op) {
+    return RefWrite(get_analogous_expression(op->ref), 
get_analogous_expression(op->value),
+                    op->span);
+  }
+  Expr VisitExpr_(const ConstructorNode* op) {
+    return Constructor(op->name_hint, op->inputs, op->belong_to);
+  }
+  Expr VisitExpr_(const MatchNode* op) {
+    return Match(get_analogous_expression(op->data), op->clauses, 
op->complete, op->span);
+  }
+
+ private:
+  Expr get_analogous_expression(const Expr& expr) {

Review comment:
       nit: rename to GetAnalogousExpression (CamelCase, like the other method names here)




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to