reminisce commented on a change in pull request #12157: Subgraph API for 
integrating accelerators with MXNet
URL: https://github.com/apache/incubator-mxnet/pull/12157#discussion_r213503015
 
 

 ##########
 File path: src/operator/subgraph/partition_graph.cc
 ##########
 @@ -0,0 +1,774 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file partition_graph.cc
+ * \brief
+ */
+#include <nnvm/graph.h>
+#include <nnvm/pass.h>
+#include <mxnet/op_attr_types.h>
+#include <unordered_set>
+#include <stack>
+#include <queue>
+
+#include "./subgraph_property.h"
+
+namespace nnvm {
+NodePtr CreateVariableNode(const std::string& name);
+}
+
+namespace mxnet {
+
+namespace op {
+
+using nnvm::Symbol;
+using nnvm::Node;
+using nnvm::NodePtr;
+using nnvm::NodeEntry;
+using nnvm::Graph;
+
+#define DEBUG_SUBGRAPH 0
+
+namespace sg {  // sg stands for subgraph
+
+struct SimpleNode;
+using SimpleNodePtr = std::shared_ptr<SimpleNode>;
+
+/*!
+ * \brief Vertex of the undirected graph that mirrors the network structure of the
+ * computational graph. It exists to make graph traversal easier when searching
+ * for subgraphs.
+ */
+struct SimpleNode {
+  /*! \brief allocates a new node with label -1 that references no graph node yet */
+  static SimpleNodePtr Create() {
+    return std::make_shared<SimpleNode>();
+  }
+  SimpleNode() = default;
+  /*! subgraph label; -1 is the initial (unlabelled) state */
+  int label = -1;
+  /*! the original node in the computational graph that this node references */
+  nnvm::Node* node = nullptr;
+  /*!
+   * \brief output nodes of the current node
+   * key is the node ptr of an output node and value is the list of entry indices
+   * in key->inputs whose source is the current node.
+   */
+  std::unordered_map<nnvm::Node*, std::vector<size_t>> outputs;
+};  // struct SimpleNode
+
+#if DEBUG_SUBGRAPH
+/*! \brief Debug helper: logs the names of all given nodes, each followed by a space. */
+void PrintSubgraph(const std::vector<SimpleNode*>& simple_nodes) {
+  std::string op_names;
+  for (const SimpleNode* sn : simple_nodes) {
+    op_names += sn->node->attrs.name;
+    op_names += ' ';
+  }
+  LOG(INFO) << "Subgraph node names: " << op_names;
+}
+
+/*! \brief Debug helper: logs the node name, index and version of one node entry. */
+void PrintNodeEntry(const nnvm::NodeEntry& entry) {
+  std::string ret = "NodeEntry: node_name=";
+  ret += entry.node->attrs.name;
+  ret += ", index=";
+  ret += std::to_string(entry.index);
+  ret += ", version=";
+  ret += std::to_string(entry.version);
+  LOG(INFO) << ret;
+}
+
+/*! \brief Debug helper: logs every node entry of the list, in order. */
+void PrintNodeEntries(const std::vector<nnvm::NodeEntry*>& entries) {
+  for (const nnvm::NodeEntry* entry : entries) {
+    PrintNodeEntry(*entry);
+  }
+}
+#endif
+
+/*!
+ * \brief Builds the graph of SimpleNodes that mirrors an MXNet computational graph.
+ * One SimpleNode is appended for every node visited by DFSVisit, and each visited
+ * node is recorded in the outputs map of every one of its input nodes.
+ * \param g the MXNet computational graph
+ * \param simple_nodes receives the nodes of the undirected graph in top sorted order
+ */
+void CreateSimpleGraph(const Graph& g,
+                       std::vector<SimpleNodePtr>* simple_nodes) {
+  const auto& idx = g.indexed_graph();
+  simple_nodes->reserve(idx.num_nodes());
+  DFSVisit(g.outputs, [&](const NodePtr& node) {
+    nnvm::Node* const cur = node.get();
+    const auto& inputs = cur->inputs;
+    for (size_t pos = 0; pos < inputs.size(); ++pos) {
+      const auto producer_nid = idx.node_id(inputs[pos].node.get());
+      // the simple node of an input must already exist when its consumer is visited
+      CHECK_LT(producer_nid, simple_nodes->size());
+      // operator[] creates the index list on first use; the entry position is appended
+      (*simple_nodes)[producer_nid]->outputs[cur].push_back(pos);
+    }
+    SimpleNodePtr created = SimpleNode::Create();
+    created->node = cur;
+    simple_nodes->push_back(std::move(created));
+  });
+}
+
+/*!
+ * \brief Sets the label of every node listed in subgraph_nodes back to -1, the
+ * original state, and then empties subgraph_nodes.
+ * \param g the whole graph, used to map nodes to their node ids
+ * \param simple_nodes all simple nodes, indexed by node id
+ * \param subgraph_nodes nodes whose labels are reset; cleared on return
+ */
+void ResetNodeLabels(const nnvm::Graph& g,
+                     const std::vector<SimpleNodePtr>& simple_nodes,
+                     std::vector<nnvm::Node*>* subgraph_nodes) {
+  for (size_t i = 0; i < subgraph_nodes->size(); ++i) {
+    const nnvm::Node* member = (*subgraph_nodes)[i];
+    const auto member_id = g.indexed_graph().node_id(member);
+    simple_nodes[member_id]->label = -1;
+  }
+  subgraph_nodes->clear();
+}
+
+/*!
+ * \brief Grows a candidate subgraph from a seed node by traversing the computation
+ * graph along both input edges and output edges, labelling every node accepted by
+ * the selector. Before returning, it checks whether the candidate subgraph forms a
+ * loop with nodes outside of it. If so, the subgraph node that should break the loop
+ * is added to excluded_nodes (when provided), the labels assigned by this call are
+ * reset, subgraph_nodes is cleared and false is returned. Otherwise subgraph_nodes
+ * is sorted by node id and true is returned.
+ * \param g the whole graph
+ * \param subgraph_selector determines whether the visited node should be chosen or not
+ * \param label the label of the current subgraph
+ * \param snid node id of the seed simple node
+ * \param simple_nodes all simple nodes in the top sorted order
+ * \param subgraph_nodes all the nodes belonging to the same subgraph of seed node
+ * \param excluded_nodes set of nodes that should be excluded from the current subgraph;
+ * may be nullptr, in which case nothing is excluded and a detected loop is reported
+ * through the return value only
+ * \return true if no loop between the subgraph and outside nodes was found
+ */
+bool LabelSubgraph(const Graph& g,
+                   SubgraphSelectorPtr subgraph_selector,
+                   const int label,
+                   const size_t snid,  // simple node id, this is a seed
+                   const std::vector<SimpleNodePtr>& simple_nodes,
+                   std::vector<nnvm::Node*>* subgraph_nodes,
+                   std::unordered_set<const nnvm::Node*>* excluded_nodes = nullptr) {
+  const auto& indexed_graph = g.indexed_graph();
+  // a node counts as excluded only if an exclusion set was supplied and contains it
+  auto is_excluded = [&](const nnvm::Node* n) {
+    return excluded_nodes != nullptr && excluded_nodes->count(n) != 0;
+  };
+  std::queue<SimpleNode*> node_queue;
+  // labels a node that has not been visited yet and schedules it for traversal
+  auto enqueue_if_unvisited = [&](const nnvm::Node* n) {
+    const auto nid = indexed_graph.node_id(n);
+    CHECK_LT(nid, simple_nodes.size());
+    if (simple_nodes[nid]->label == -1) {
+      simple_nodes[nid]->label = label;
+      node_queue.push(simple_nodes[nid].get());
+    }
+  };
+  if (!is_excluded(simple_nodes[snid]->node)) {
+    CHECK_EQ(simple_nodes[snid]->label, -1);
+    simple_nodes[snid]->label = label;
+    node_queue.push(simple_nodes[snid].get());
+  }
+  // key: nodes outside the subgraph that are directly connected to it
+  // value: pair of vectors of nodes in the subgraph. The first vector contains the
+  // output nodes of the key in the subgraph (subgraph nodes that read from the key),
+  // and the second vector contains the input nodes of the key in the subgraph
+  // (subgraph nodes that the key reads from).
+  // If a non-subgraph node has inputs from the subgraph and the other non-subgraph node
+  // has outputs to the subgraph, and the first non-subgraph node is an ancestor
+  // of the second non-subgraph node, there exists a cycle.
+  // When breaking the cycle, we want to start from removing the node with the largest
+  // node id in the subgraph.
+  std::unordered_map<const nnvm::Node*,
+    std::pair<std::vector<const nnvm::Node*>,
+              std::vector<const nnvm::Node*>>> non_subgraph_node_map;
+  while (!node_queue.empty()) {
+    SimpleNode* cur_node = node_queue.front();
+    node_queue.pop();
+    subgraph_nodes->push_back(cur_node->node);
+    // get qualified adjacent input nodes; the selector is not consulted for excluded nodes
+    for (auto& e : cur_node->node->inputs) {
+      const bool select_input = !is_excluded(e.node.get())
+        && subgraph_selector->SelectInput(*cur_node->node, *e.node);
+      if (select_input) {
+        // e.node is a subgraph node
+        enqueue_if_unvisited(e.node.get());
+      } else {
+        // e.node is an input node of the subgraph
+        non_subgraph_node_map[e.node.get()].first.push_back(cur_node->node);
+      }
+    }
+    // get qualified output nodes; the selector is not consulted for excluded nodes
+    for (auto it = cur_node->outputs.begin(); it != cur_node->outputs.end(); ++it) {
+      const bool select_output = !is_excluded(it->first)
+          && subgraph_selector->SelectOutput(*cur_node->node, *it->first);
+      if (select_output) {
+        // it->first is a subgraph node
+        enqueue_if_unvisited(it->first);
+      } else {
+        // it->first is an output node of the subgraph
+        non_subgraph_node_map[it->first].second.push_back(cur_node->node);
+      }
+    }
+  }
+  // prepare to check if there is a cycle
+  auto node_cmp = [&] (const nnvm::Node* node1, const nnvm::Node* node2) {
+    return indexed_graph.node_id(node1) < indexed_graph.node_id(node2);
+  };
+  std::vector<const nnvm::Node*> non_subgraph_nodes;
+  non_subgraph_nodes.reserve(non_subgraph_node_map.size());
+  for (auto& kv : non_subgraph_node_map) {
+    // sort by node id so that back() is the subgraph node with the largest node id
+    std::sort(kv.second.first.begin(), kv.second.first.end(), node_cmp);
+    std::sort(kv.second.second.begin(), kv.second.second.end(), node_cmp);
+    non_subgraph_nodes.push_back(kv.first);
+  }
+  // hash set of the subgraph members, so that membership tests during the ancestor
+  // search do not scan the whole vector for every edge
+  const std::unordered_set<const nnvm::Node*> subgraph_node_set(subgraph_nodes->begin(),
+                                                                subgraph_nodes->end());
+  // Returns true if `ancestor` can be reached from `descendant` by following input
+  // edges without passing through any subgraph node. Each node is expanded at most
+  // once; without the visited set, nodes reachable over several paths would be
+  // expanded repeatedly.
+  auto is_ancestor = [&](const nnvm::Node* ancestor, const nnvm::Node* descendant) {
+    if (ancestor == descendant) return true;
+    std::stack<const nnvm::Node*> s;
+    std::unordered_set<const nnvm::Node*> visited;
+    s.push(descendant);
+    visited.insert(descendant);
+    while (!s.empty()) {
+      const nnvm::Node* top = s.top();
+      s.pop();
+      if (top == ancestor) {
+        return true;
+      }
+      for (const auto& entry : top->inputs) {
+        const nnvm::Node* input = entry.node.get();
+        // when searching for the ancestor, the path cannot cross any subgraph node
+        if (subgraph_node_set.count(input) == 0 && visited.insert(input).second) {
+          s.push(input);
+        }
+      }
+    }
+    return false;
+  };
+  std::sort(non_subgraph_nodes.begin(), non_subgraph_nodes.end(), node_cmp);
+  int excluded_node_id = -1;
+  for (size_t i = 0; i < non_subgraph_nodes.size(); ++i) {
+    auto it1 = non_subgraph_node_map.find(non_subgraph_nodes[i]);
+    CHECK(it1 != non_subgraph_node_map.end());
+    // subgraph nodes that read from node i, sorted by node id
+    auto& subgraph_consumers = it1->second.first;
+    // subgraph nodes that node i reads from, sorted by node id
+    auto& subgraph_producers = it1->second.second;
+    if (!subgraph_consumers.empty() && !subgraph_producers.empty()) {
+      // there is a loop between node i and the subgraph
+      const auto node_id = std::max(indexed_graph.node_id(subgraph_consumers.back()),
+                                    indexed_graph.node_id(subgraph_producers.back()));
+      excluded_node_id = std::max(excluded_node_id, static_cast<int>(node_id));
+    } else if (!subgraph_producers.empty()) {
+      // node i reads from the subgraph; find out whether there is a later node j
+      // which feeds the subgraph and has node i as an ancestor.
+      for (size_t j = i + 1; j < non_subgraph_nodes.size(); ++j) {
+        auto it2 = non_subgraph_node_map.find(non_subgraph_nodes[j]);
+        CHECK(it2 != non_subgraph_node_map.end());
+        // i is topologically before j, j might be a direct/indirect output node of i
+        CHECK_LT(indexed_graph.node_id(it1->first), indexed_graph.node_id(it2->first));
+        if (!it2->second.first.empty() && is_ancestor(it1->first, it2->first)) {
+          // found a loop
+          const auto node_id = std::max(indexed_graph.node_id(subgraph_producers.back()),
+                                        indexed_graph.node_id(it2->second.first.back()));
+          excluded_node_id = std::max(excluded_node_id, static_cast<int>(node_id));
+        }
+      }
+    }
+  }
+
+  if (excluded_node_id != -1) {
+    CHECK_LT(excluded_node_id, static_cast<int>(simple_nodes.size()));
+    CHECK_NE(excluded_node_id, static_cast<int>(snid))
+      << "A cycle is found in the computational graph between nodes "
+      << simple_nodes[excluded_node_id]->node->attrs.name << " and "
+      << simple_nodes[snid]->node->attrs.name;
+    // excluded_nodes defaults to nullptr, so it must not be dereferenced blindly
+    if (excluded_nodes != nullptr) {
+      excluded_nodes->insert(simple_nodes[excluded_node_id]->node);
+    }
+    ResetNodeLabels(g, simple_nodes, subgraph_nodes);
+    return false;
+  }
+  std::sort(subgraph_nodes->begin(), subgraph_nodes->end(), node_cmp);
+  return true;
+}
+
+/*!
+ * \brief Finds all the nodes belonging to the same subgraph given a seed node.
+ * \param g the whole graph
+ * \subgraph_selector determines whether the visited node should be choosen or 
not
+ * \label the label of the current subgraph
+ * \snid node id of the seed simple node
+ * \simple_nodes all simple nodes in the top sorted order
+ * \subgraph_nodes all the nodes belonging to the same subgraph of seed node
+ * \return Subgraph node candidates sorted in the topological order
+ */
+void PreSelectSubgraphNodes(const Graph& g,
+                            SubgraphSelectorPtr subgraph_selector,
+                            const int label,
+                            const size_t snid,
+                            const std::vector<SimpleNodePtr>& simple_nodes,
+                            std::vector<nnvm::Node*>* subgraph_nodes) {
+  std::unordered_set<const nnvm::Node*> excluded_nodes;
+  const size_t max_num_retry = simple_nodes.size() * simple_nodes.size();
+  size_t count = 0;
+  bool success = false;
+  while (!success && count < max_num_retry) {
 
 Review comment:
   Theoretically, yes. But just want to be safe here.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to