reminisce commented on a change in pull request #12157: Subgraph API for 
integrating accelerators with MXNet
URL: https://github.com/apache/incubator-mxnet/pull/12157#discussion_r212713464
 
 

 ##########
 File path: src/executor/graph_executor.cc
 ##########
 @@ -1428,6 +1430,146 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start,
     iter->c_str());
   return ret;
 }
+
+// Infer shapes, dtypes, stypes, contexts for the forward graph
+static nnvm::Graph InferForwardAttrs(nnvm::Graph g,
+                                     nnvm::ShapeVector arg_shapes,
+                                     nnvm::DTypeVector arg_dtypes,
+                                     StorageTypeVector arg_stypes,
+                                     const Context& default_ctx,
+                                     const std::map<std::string, Context>& ctx_map,
+                                     const std::vector<Context>& in_arg_ctxes,
+                                     const std::vector<Context>& aux_state_ctxes) {
+  const auto& indexed_graph = g.indexed_graph();
+  const auto num_forward_inputs = indexed_graph.input_nodes().size();
+  g = AssignContext(g, default_ctx, ctx_map, in_arg_ctxes, {},
+                   aux_state_ctxes, {}, num_forward_inputs, g.outputs.size());
+  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+  if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
+    HandleInferShapeError(num_forward_inputs, indexed_graph,
+                          g.GetAttr<nnvm::ShapeVector>("shape"));
+  }
+  g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
+  if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
+    HandleInferTypeError(num_forward_inputs, indexed_graph,
+                         g.GetAttr<nnvm::DTypeVector>("dtype"));
+  }
+  g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
+  if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
+    HandleInferStorageTypeError(num_forward_inputs, indexed_graph,
+                                g.GetAttr<StorageTypeVector>("storage_type"));
+  }
+  return g;
+}
+
+// Given input attr arrays, partition the graph using the backend name equal to prop_name.
+// This is a common function for bind and simple_bind flows.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+                                   const std::string& prop_name,
+                                   const nnvm::ShapeVector& arg_shapes,
+                                   const nnvm::DTypeVector& arg_dtypes,
+                                   const StorageTypeVector& arg_stypes,
+                                   const Context& default_ctx,
+                                   const std::map<std::string, Context>& ctx_map,
+                                   const std::vector<Context>& in_arg_ctxes,
+                                   const std::vector<Context>& aux_state_ctxes) {
+  auto subgraph_prop = op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name);
 
 Review comment:
   There is a check inside `CreateSubgraphProperty`.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to