eric-haibin-lin commented on a change in pull request #6152:
URL: https://github.com/apache/incubator-tvm/pull/6152#discussion_r464588608



##########
File path: src/relay/transforms/to_a_normal_form.cc
##########
@@ -81,171 +70,187 @@ std::unordered_map<DependencyGraph::Node*, Scope> 
CalcScope(const DependencyGrap
       global_scope_used = true;
     } else {
       s = expr_scope.at(iit->value);
+      const auto original_s = s;
       iit = iit->next;
       for (; iit != nullptr; iit = iit->next) {
         s = LCA(s, expr_scope.at(iit->value));
       }
+      if (s != original_s && node_to_expr.find(n) != node_to_expr.end()) {
+        // filter out exprs whose scope do not matter
+        Expr expr = node_to_expr[n];
+        if (!expr.as<OpNode>()) {
+          lifted_exprs.insert(expr);
+        }
+      }
+    }
+    if (n->new_scope) {
+      auto child_scope = std::make_shared<ScopeNode>(s);
+      expr_scope.insert({n, child_scope});
+    } else {
+      expr_scope.insert({n, s});
     }
-    expr_scope.insert({n, n->new_scope ? ChildScope(s) : s});
   }
   CHECK(global_scope_used);
-  return expr_scope;
+  return std::make_pair(expr_scope, lifted_exprs);
 }
 
-/* Special care is needed to handle local recursion.
- * Fill additionally take a (possibly null) Var argument,
- * If it is not null, Fill is required to bind the transformed result to that 
var.
- */
-class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
- public:
-  static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg,
-                            std::unordered_map<DependencyGraph::Node*, Scope>* 
node_scope) {
-    Fill fi(dg, node_scope);
-    return fi.GetScope(e)->ll->Get(fi.VisitExpr(e));
-  }
-
- private:
-  const DependencyGraph& dg_;
-  std::unordered_map<DependencyGraph::Node*, Scope>* node_scope_;
-  std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> memo;
+Expr Fill::ToANormalForm(const Expr& e, const DependencyGraph& dg, 
NodeScopeMap* node_scope) {
+  Fill fi(dg, node_scope, nullptr);
+  return fi.GetScope(e)->ll->Get(fi.VisitExpr(e));
+}
 
-  Fill(const DependencyGraph& dg, std::unordered_map<DependencyGraph::Node*, 
Scope>* node_scope)
-      : dg_(dg), node_scope_(node_scope) {}
+// For basic block normal form, bind expressions only if the original 
expression's scope
+// should be lifted
+Expr Fill::ToBasicBlockNormalForm(const Expr& e, const DependencyGraph& dg,
+                                  NodeScopeMap* node_scope, ExprSet* lifted) {
+  Fill fi(dg, node_scope, lifted);
+  auto var = fi.VisitExpr(e);
+  return fi.GetScope(e)->ll->Get(var);
+}
 
-  Scope GetScope(const Expr& e) { return node_scope_->at(dg_.expr_node.at(e)); 
}
+Scope Fill::GetScope(const Expr& e) { return 
node_scope_->at(dg_.expr_node.at(e)); }
 
-  Scope GetSubScope(const Expr& e, size_t i) {
-    DependencyGraph::Node* n = dg_.expr_node.at(e);
-    auto h = n->children.head;
-    while (i != 0) {
-      CHECK(h);
-      --i;
-      h = h->next;
-    }
+Scope Fill::GetSubScope(const Expr& e, size_t i) {
+  DependencyGraph::Node* n = dg_.expr_node.at(e);
+  auto h = n->children.head;
+  while (i != 0) {
     CHECK(h);
-    return node_scope_->at(h->value);
+    --i;
+    h = h->next;
   }
+  CHECK(h);
+  return node_scope_->at(h->value);
+}
 
-  Expr VisitExpr(const Expr& e, const Var& v) final {
-    if (memo.count(e) == 0) {
-      memo.insert({e, ExprFunctor<Expr(const Expr&, const Var&)>::VisitExpr(e, 
v)});
-    } else if (v.defined()) {
-      GetScope(e)->ll->Push(v, memo.at(e));
-    }
-    auto ret = memo.at(e);
-    CHECK(IsAtomic(ret));
-    return ret;
+Expr Fill::VisitExpr(const Expr& e, const Var& v) {
+  if (memo.count(e) == 0) {
+    memo.insert({e, ExprFunctor<Expr(const Expr&, const Var&)>::VisitExpr(e, 
v)});
+  } else if (v.defined()) {
+    GetScope(e)->ll->Push(v, memo.at(e));
   }
+  auto ret = memo.at(e);
+  // if no include_set is specified, every expression should be atomic.
+  if (include_set_ == nullptr) CHECK(IsAtomic(ret));
+  return ret;
+}
 
-  Expr VisitExpr(const Expr& e) { return this->VisitExpr(e, Var()); }
+Expr Fill::VisitExpr(const Expr& e) { return this->VisitExpr(e, Var()); }
 
-  Expr Atomic(const Expr& e, const Var& v) { return v.defined() ? 
GetScope(e)->ll->Push(v, e) : e; }
+Expr Fill::Atomic(const Expr& e, const Var& v) {
+  return v.defined() ? GetScope(e)->ll->Push(v, e) : e;
+}
 
-  Expr Compound(const Expr& orig, const Expr& now, const Var& v) {
-    Var var = v.defined() ? v : Var(String("x"), Type());
+// Bind expression `now` to var `v` if the original expression is in the 
include set, or if
+// v is already defined (e.g. coming from a Let expression). Otherwise return 
`now` directly
+Expr Fill::Compound(const Expr& orig, const Expr& now, const Var& v) {
+  Var var = v.defined() ? v : Var(String("x"), Type());
+  bool not_included = include_set_ && include_set_->find(orig) == 
include_set_->end();
+  if (!v.defined() && not_included) {
+    return now;
+  } else {
     return GetScope(orig)->ll->Push(var, now);
   }
+}
 
-  Expr VisitExpr_(const CallNode* c, const Var& v) final {
-    Expr e = GetRef<Expr>(c);
-    std::vector<Expr> args;
-    for (const auto& a : c->args) {
-      args.push_back(VisitExpr(a));
-    }
-    return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), 
v);
+Expr Fill::VisitExpr_(const CallNode* c, const Var& v) {

Review comment:
       `final` is still specified on the declaration in the header, so it is not 
needed on the out-of-line definition here (see https://en.cppreference.com/w/cpp/language/final#Syntax).

##########
File path: src/relay/transforms/pass_util.h
##########
@@ -184,6 +189,89 @@ struct TreeBranchNode : TreeNode<ConditionObjectPtr> {
   ~TreeBranchNode() {}
 };
 
+struct ScopeNode;
+using Scope = std::shared_ptr<ScopeNode>;
+using NodeScopeMap = std::unordered_map<DependencyGraph::Node*, Scope>;
+using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;
+
+/* Invariant: when parent is null level is 0
+ * Invariant: when parent is not null level is 1 + parent->level
+ */
+struct ScopeNode {
+  // the level of the scope
+  size_t level;
+  // the parent scope
+  Scope parent;
+  // the corresponding let list which holds all let bindings in the scope
+  std::shared_ptr<LetList> ll = std::make_shared<LetList>();
+  explicit ScopeNode(const Scope& parent) : level(1 + parent->level), 
parent(parent) {}
+  ScopeNode() : level(0) {}
+};
+
+/*! \brief Calculate the scope of nodes in the dependency graph by least 
common ancestor.
+ *
+ *  \param dg the input dependency graph
+ *  \param expr_scope the output node -> scope mapping for all nodes.
+ *  \param lifted_exprs the output set of expressions whose scope is lifted 
due to dependency
+ */
+std::pair<NodeScopeMap, ExprSet> CalcScope(const DependencyGraph& dg);
+
+/*! \brief find the least common ancestor of lhs scope and rhs scope.
+ */
+Scope LCA(Scope lhs, Scope rhs);
+
+/* Special care is needed to handle local recursion.
+ * Fill additionally take a (possibly null) Var argument,
+ * If it is not null, Fill is required to bind the transformed result to that 
var.
+ */
+class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
+ public:
+  static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg, 
NodeScopeMap* node_scope);
+
+  // For basic block normal form, bind expressions only if the original 
expression's
+  // scope should be lifted
+  static Expr ToBasicBlockNormalForm(const Expr& e, const DependencyGraph& dg,
+                                     NodeScopeMap* node_scope, ExprSet* 
lifted);

Review comment:
       Thanks for bringing this up. I was also thinking about how best to 
leverage the existing code without too much ad-hoc code or duplication. I think 
adding an inclusion/exclusion argument to `Fill` is still acceptable, especially 
since the `Fill` constructor is private anyway.

##########
File path: src/relay/transforms/to_basic_block_normal_form.cc
##########
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_basic_block_normal_form.cc
+ *
+ * \brief Turn an expression to the basic normal form.
+ */
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/support/logging.h>
+
+#include "../../support/arena.h"
+#include "../analysis/dependency_graph.h"
+#include "let_list.h"
+#include "pass_util.h"
+
+namespace tvm {
+namespace relay {
+
+Expr ToBasicBlockNormalFormAux(const Expr& e) {

Review comment:
       I guess the name comes from basic blocks (https://en.wikipedia.org/wiki/Basic_block)?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to