zhiics commented on a change in pull request #4927: [Relay][Pass] Add inline pass
URL: https://github.com/apache/incubator-tvm/pull/4927#discussion_r386043406
 
 

 ##########
 File path: src/relay/pass/inline.cc
 ##########
 @@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/pass/inline.cc
+ * \brief Global function inliner. It contains the following steps:
+ *
+ *  - Preprocessing: eligibility checking. Only inline the functions that can
+ *  be inlined. We currently only use simple rules to make the decision. No
+ *  profitability analysis is available for now.
+ *
+ *  - Inline: replace the call with a function or the function body depending on
+ *  the attribute of the callee function. For example, we return the function
+ *  node when it doesn't use the default compiler, i.e. llvm. This is because
+ *  these functions are packed to be offloaded to external codegen.
+ *
+ *  - Postprocessing: remove the replaced functions that have no reference.
+ */
+
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/support/logging.h>
+#include <tvm/relay/transform.h>
+#include <string>
+#include <unordered_set>
+
+#include "call_graph.h"
+
+using namespace tvm::runtime;
+
+namespace tvm {
+namespace relay {
+
+class Inliner : ExprMutator {
+ public:
+  /*!
+   * \brief Create an inliner that rewrites call sites using call-graph data.
+   * \param cur_node Call graph entry whose outgoing call edges are removed
+   *        (via RemoveCallTo) whenever a callee is inlined. Presumably the
+   *        entry of the function later passed to Inline() -- confirm at caller.
+   * \param call_graph Call graph used to look up callee entries by name and
+   *        to fetch their function definitions.
+   */
+  explicit Inliner(CallGraphEntry* cur_node, CallGraphNode* call_graph)
+      : cur_node_(cur_node), call_graph_(call_graph) {}
+
+  /*!
+   * \brief Inline direct calls to eligible global functions.
+   *
+   * Calls whose operator is not a GlobalVar, or whose callee fails
+   * CanInline(), are handed to the default ExprMutator traversal. For an
+   * eligible callee the arguments are rewritten first, the call edge is then
+   * dropped from the current call-graph entry, and finally the replacement
+   * expression is built by MakeNewExpr.
+   */
+  Expr VisitExpr_(const CallNode* call_node) final {
+    const auto* callee = call_node->op.as<GlobalVarNode>();
+    if (callee == nullptr) {
+      return ExprMutator::VisitExpr_(call_node);
+    }
+
+    GlobalVar callee_var = GetRef<GlobalVar>(callee);
+    if (!CanInline((*call_graph_)[callee_var->name_hint])) {
+      return ExprMutator::VisitExpr_(call_node);
+    }
+
+    tvm::Array<Expr> rewritten_args;
+    for (const auto& original_arg : call_node->args) {
+      rewritten_args.push_back(VisitExpr(original_arg));
+    }
+    // The edge is removed only after the arguments have been visited, matching
+    // the order in which nested call sites update the call graph.
+    cur_node_->RemoveCallTo(callee_var);
+    return MakeNewExpr(callee_var, rewritten_args, GetRef<Call>(call_node));
+  }
+
+  /*!
+   * \brief Handle a GlobalVar reached outside an inlined call (for example,
+   * one used as a value).
+   *
+   * Non-inlinable globals go through the default traversal. Otherwise the call
+   * edge is removed and MakeNewExpr builds the replacement with no arguments.
+   */
+  Expr VisitExpr_(const GlobalVarNode* gvn) final {
+    GlobalVar global = GetRef<GlobalVar>(gvn);
+    if (!CanInline((*call_graph_)[global->name_hint])) {
+      return ExprMutator::VisitExpr_(gvn);
+    }
+    cur_node_->RemoveCallTo(global);
+    return MakeNewExpr(global, {}, global);
+  }
+
+  /*!
+   * \brief Return a copy of func whose body has had eligible calls inlined.
+   * \param func The function to rewrite.
+   * \return A new function with the same params, return type, type params and
+   *         attributes, and a rewritten body.
+   */
+  Function Inline(const Function& func) {
+    Expr inlined_body = VisitExpr(func->body);
+    return FunctionNode::make(func->params, inlined_body, func->ret_type,
+                              func->type_params, func->attrs);
+  }
+
+ private:
+  /*!
+   * \brief Decide whether the global function behind cg_node may be inlined.
+   *
+   * A function qualifies only if it:
+   *  - is a leaf in the call graph and is not recursive,
+   *  - is a Relay function (not some other kind of global definition),
+   *  - has a defined body, and
+   *  - is marked with the inline attribute.
+   *
+   * \param cg_node Call graph entry of the candidate callee.
+   * \return true if every condition above holds.
+   */
+  bool CanInline(const CallGraphEntry* cg_node) {
+    // The node must be a leaf node and it cannot be recursive. Because a leaf
+    // has no callees, no transitive check over its callees is needed (the
+    // previous loop over *cg_node could never execute past this guard).
+    if (!cg_node->empty() || cg_node->IsRecursive()) return false;
+
+    auto base_func = call_graph_->GetGlobalFunction(cg_node->GetGlobalVar());
+    // Reject non-Relay globals instead of failing a Downcast check.
+    const auto* func = base_func.as<FunctionNode>();
+    if (func == nullptr) return false;
+
+    // The body of a global function must be defined.
+    if (!func->body.defined()) return false;
+
+    // The function must be annotated with the inline attribute.
+    return func->IsMarkedInline();
+  }
+
+  // Make a new Relay expression to replace the caller.
 
 Review comment:
   Sorry, I misread it; you are right.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to