zhiics commented on a change in pull request #4922: [Relay][pass] call graph 
for relay
URL: https://github.com/apache/incubator-tvm/pull/4922#discussion_r383653603
 
 

 ##########
 File path: src/relay/pass/call_graph.h
 ##########
 @@ -0,0 +1,509 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/pass/call_graph.h
+ * \brief Define data structures for the call graph of an IRModule. It borrows
+ * the idea from how LLVM constructs its CallGraph.
+ *
+ * https://llvm.org/doxygen/CallGraph_8h_source.html
+ */
+
+#ifndef TVM_RELAY_PASS_CALL_GRAPH_H_
+#define TVM_RELAY_PASS_CALL_GRAPH_H_
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/expr.h>
+#include <tvm/runtime/object.h>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+class CallGraphEntryNode;
+class CallGraph;
+
+/*!
+ * \brief The node that stores the call graph of an IR module. It keeps the
+ * module together with a mapping from GlobalVar to CallGraphEntryNode.
+ */
+class CallGraphNode : public Object {
+  // The map type backing the call graph: GlobalVar -> owned CallGraphEntryNode.
+  using CallGraphMap =
+      std::unordered_map<GlobalVar, std::unique_ptr<CallGraphEntryNode>, 
ObjectHash,
+                         ObjectEqual>;
+  // Create iterator aliases for a CallGraphNode object.
+  using iterator = CallGraphMap::iterator;
+  using const_iterator = CallGraphMap::const_iterator;
+
+ public:
+  /*! \brief The IR module for creating a CallGraphNode. */
+  IRModule module;
+
+  /*! \brief Default constructor. */
+  CallGraphNode() {}
+
+  /*!
+   * \brief Visit the attributes of this node. Only `module` is visited.
+   *
+   * \param v The attribute visitor.
+   */
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("module", &module);
+  }
+
+  /*!
+   * \brief Print the call graph.
+   *
+   * \param os The stream for printing.
+   */
+  void Print(std::ostream& os) const;
+
+  /*! \return The begin iterator of the GlobalVar to entry mapping. */
+  iterator begin() {
+    return call_graph_.begin();
+  }
+  /*! \return The end iterator of the GlobalVar to entry mapping. */
+  iterator end() {
+    return call_graph_.end();
+  }
+  /*! \return The const begin iterator of the GlobalVar to entry mapping. */
+  const_iterator begin() const {
+    return call_graph_.begin();
+  }
+  /*! \return The const end iterator of the GlobalVar to entry mapping. */
+  const_iterator end() const {
+    return call_graph_.end();
+  }
+
+  /*!
+   * \brief Get an element from the CallGraphNode using a GlobalVar
+   *        (const overload).
+   *
+   * \param gv The GlobalVar used for indexing.
+   *
+   * \return The fetched element.
+   */
+  const CallGraphEntryNode* operator[](const GlobalVar& gv) const;
+  /*!
+   * \brief Get an element from the CallGraphNode using a GlobalVar.
+   *
+   * \param gv The GlobalVar used for indexing.
+   *
+   * \return The fetched element.
+   */
+  CallGraphEntryNode* operator[](const GlobalVar& gv);
+  /*!
+   * \brief Get an element from the CallGraphNode using the global function name
+   *        (const overload). The name is first resolved to a GlobalVar through
+   *        the module.
+   *
+   * \param gvar_name The global function name used for indexing.
+   *
+   * \return The fetched element.
+   */
+  const CallGraphEntryNode* operator[](const std::string& gvar_name) const {
+    return (*this)[module->GetGlobalVar(gvar_name)];
+  }
+  /*!
+   * \brief Get an element from the CallGraphNode using the global function name.
+   *        The name is first resolved to a GlobalVar through the module.
+   *
+   * \param gvar_name The global function name used for indexing.
+   *
+   * \return The fetched element.
+   */
+  CallGraphEntryNode* operator[](const std::string& gvar_name) {
+    return (*this)[module->GetGlobalVar(gvar_name)];
+  }
+
+  /*! \return The IR module held by this node. */
+  IRModule GetModule() const {
+    return module;
+  }
+
+  /*!
+   * \brief Get the entries/root nodes of CallGraphNode.
+   *
+   *  Entry functions are never referenced by other functions.
+   *  Note these functions can be recursive as well.
+   *
+   * \return The list of CallGraphEntryNode that represent entry nodes.
+   */
+  std::vector<CallGraphEntryNode*> GetEntryGlobals() const;
+
+  /*!
+   * \brief Remove a GlobalVar in a given CallGraphEntryNode from the current
+   *        IR module.
+   *
+   * \param cg_node The CallGraphEntryNode that contains a global function to be
+   *        removed.
+   * \param update_call_graph Indicate if we will update the CallGraph as well
+   *        since updating is costly. We are only able to remove a leaf function
+   *        when update_call_graph is disabled because the edges pointing to
+   *        functions being removed are not updated.
+   *
+   * \return The GlobalVar removed from the current module.
+   */
+  GlobalVar RemoveGlobalVarFromModule(CallGraphEntryNode* cg_node,
+                                      bool update_call_graph = false);
+
+  /*!
+   * \brief Lookup a GlobalVar for the CallGraphNode. It creates an entry for
+   *        the GlobalVar if it doesn't exist.
+   *
+   * \param gv The GlobalVar for query.
+   *
+   * \return The queried entry.
+   */
+  CallGraphEntryNode* LookupGlobalVar(const GlobalVar& gv);
+
+  /*!
+   * \brief Get the entries from the CallGraphNode in the topological order.
+   *
+   *  This is useful for various module-level optimizations/analysis. For example,
+   *  inlining requires the correct order of the functions being processed, i.e.
+   *  callees should always be handled before their callers.
+   *
+   * \return The list of collected entries that are sorted in the topological order.
+   */
+  std::vector<CallGraphEntryNode*> TopologicalOrder() const;
+
+  static constexpr const char* _type_key = "relay.CallGraph";
+  TVM_DECLARE_FINAL_OBJECT_INFO(CallGraphNode, Object);
+
+ private:
+  /*!
+   * \brief Create a CallGraphEntryNode for a global function and add it to the
+   *        CallGraphNode.
+   *
+   * \param gv The global var.
+   * \param func The global function corresponding to `gv`.
+   */
+  void AddToCallGraph(const GlobalVar& gv, const Function& func);
+
+  /*! \brief A record that contains the GlobalVar to CallGraphEntryNode mapping. */
+  CallGraphMap call_graph_;
+
+  // Grant the CallGraph reference class access to the private members above.
+  friend CallGraph;
+};
+
+/*!
+ * \brief The class that represents the call graph of a Relay IR module. It also
+ * provides a variety of utility functions for users to query, view, and update
+ * a call graph.
+ */
+class CallGraph : public ObjectRef {
+  // The same map type that CallGraphNode uses for its entries.
+  using CallGraphMap =
+      std::unordered_map<GlobalVar, std::unique_ptr<CallGraphEntryNode>, 
ObjectHash,
+                         ObjectEqual>;
+  // Create iterator aliases for a CallGraph object.
+  using iterator = CallGraphMap::iterator;
+  using const_iterator = CallGraphMap::const_iterator;
+
+ public:
+  /*!
+   * \brief Construct a CallGraph from an IR module.
+   *
+   * \param module The IR module
+   */
+  explicit CallGraph(IRModule module);
+
+  /*!
+   * \brief Construct from an object pointer.
+   * \param n The object pointer.
+   */
+  explicit CallGraph(ObjectPtr<Object> n) : ObjectRef(n) {}
+
+  /*! \return The begin iterator, forwarded from the underlying node. */
+  iterator begin() {
+    auto* n = operator->();
+    CHECK(n);
+    return n->begin();
+  }
+  /*! \return The end iterator, forwarded from the underlying node. */
+  iterator end() {
+    auto* n = operator->();
+    CHECK(n);
+    return n->end();
+  }
+  /*! \return The const begin iterator, forwarded from the underlying node. */
+  const_iterator begin() const {
+    const auto* n = operator->();
+    CHECK(n);
+    return n->begin();
+  }
+  /*! \return The const end iterator, forwarded from the underlying node. */
+  const_iterator end() const {
+    const auto* n = operator->();
+    CHECK(n);
+    return n->end();
+  }
+
+  /*!
+   * \brief Get an element from the CallGraph using a GlobalVar
+   *        (const overload).
+   *
+   * \param gv The GlobalVar used for indexing.
+   *
+   * \return The fetched element.
+   */
+  const CallGraphEntryNode* operator[](const GlobalVar& gv) const {
+    const auto* n = operator->();
+    CHECK(n);
+    return (*n)[gv];
+  }
+  /*!
+   * \brief Get an element from the CallGraph using a GlobalVar.
+   *
+   * \param gv The GlobalVar used for indexing.
+   *
+   * \return The fetched element.
+   */
+  CallGraphEntryNode* operator[](const GlobalVar& gv) {
+    auto* n = operator->();
+    CHECK(n);
+    return (*n)[gv];
+  }
+  /*!
+   * \brief Get an element from the CallGraph using the global function name
+   *        (const overload).
+   *
+   * \param gvar_name The global function name used for indexing.
+   *
+   * \return The fetched element.
+   */
+  const CallGraphEntryNode* operator[](const std::string& gvar_name) const {
+    const auto* n = operator->();
+    CHECK(n);
+    return (*n)[gvar_name];
+  }
+  /*!
+   * \brief Get an element from the CallGraph using the global function name.
+   *
+   * \param gvar_name The global function name used for indexing.
+   *
+   * \return The fetched element.
+   */
+  CallGraphEntryNode* operator[](const std::string& gvar_name) {
+    auto* n = operator->();
+    CHECK(n);
+    return (*n)[gvar_name];
+  }
+
+  /*!
+   * \return A mutable pointer to the underlying CallGraphNode. The pointer is
+   *         checked to be non-null before it is returned.
+   */
+  CallGraphNode* operator->() const {
+    auto* ptr = get_mutable();
+    CHECK(ptr != nullptr);
+    return static_cast<CallGraphNode*>(ptr);
+  }
+
+ private:
+  /*! \brief Overload the << operator to print a call graph. */
+  friend std::ostream& operator<<(std::ostream& os, const CallGraph&);
+};
+
+/*!
+ * \brief A node in the call graph. It maintains the edges from a caller to
+ * all callees.
+ */
+class CallGraphEntryNode {
 
 Review comment:
   Thanks for pointing that out. Only `CallGraphNode` is in the node system. Let's 
 use `CallGraphEntry` instead of `CallGraphEntryNode` to reduce the confusion.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to