This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new faffe56  move parameters to backend for hybridblock (#10632)
faffe56 is described below

commit faffe56b450cfc908bf9d29e810e00c0de8aca7b
Author: Eric Junyuan Xie <piiswr...@users.noreply.github.com>
AuthorDate: Tue Apr 24 15:04:31 2018 -0700

    move parameters to backend for hybridblock (#10632)
---
 include/mxnet/base.h                 |  8 ++++
 include/mxnet/c_api.h                |  7 ++-
 include/mxnet/imperative.h           | 18 +++++---
 python/mxnet/_ctypes/ndarray.py      | 17 ++++++-
 python/mxnet/gluon/block.py          | 90 +++++++++++++++++-------------------
 src/c_api/c_api_ndarray.cc           | 31 +++++++++++--
 src/executor/attach_op_execs_pass.cc | 11 +----
 src/executor/graph_executor.cc       |  1 -
 src/executor/graph_executor.h        |  2 -
 src/imperative/cached_op.cc          | 67 ++++++++++++++++++++++-----
 10 files changed, 168 insertions(+), 84 deletions(-)

diff --git a/include/mxnet/base.h b/include/mxnet/base.h
index 783002e..38fd7ed 100644
--- a/include/mxnet/base.h
+++ b/include/mxnet/base.h
@@ -362,6 +362,14 @@ constexpr size_t kMKLDNNAlign = 64;
 
 }  // namespace mxnet
 
+namespace std {
+template<> struct hash<mxnet::Context> {
+  size_t operator()(const mxnet::Context& ctx) const {
+    return (static_cast<size_t>(ctx.dev_type) << 32) | ctx.dev_id;
+  }
+};
+}
+
 #include "./tensor_blob.h"
 //! \endcond
 #endif  // MXNET_BASE_H_
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index cbc83b2..3f04051 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -976,9 +976,14 @@ MXNET_DLL int MXCreateCachedOp(SymbolHandle handle, CachedOpHandle *out);
  * \brief create cached operator
  */
 MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle,
-                                 int num_params,
+                                 int num_flags,
                                  const char** keys,
                                  const char** vals,
+                                 int num_inputs,
+                                 const char** input_names,
+                                 int num_params,
+                                 const char** param_names,
+                                 NDArrayHandle* params,
                                  CachedOpHandle *out);
 /*!
  * \brief free cached operator
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index d605e9d..758ce85 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -36,11 +36,11 @@
 
 namespace mxnet {
 /*! \brief CachedOp Parameters */
-struct CachedOpParam : public dmlc::Parameter<CachedOpParam> {
+struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
   uint32_t inline_limit;
   uint32_t forward_bulk_size;
   uint32_t backward_bulk_size;
-  DMLC_DECLARE_PARAMETER(CachedOpParam) {
+  DMLC_DECLARE_PARAMETER(CachedOpConfig) {
     DMLC_DECLARE_FIELD(inline_limit)
     .set_default(2)
     .describe("Maximum number of operators that can be inlined.");
@@ -96,8 +96,11 @@ class Imperative {
   };
   class CachedOp {
    public:
-    CachedOp(const nnvm::Symbol& sym,
-             const std::vector<std::pair<std::string, std::string> >& kwargs);
+    CachedOp(
+        const nnvm::Symbol& sym,
+        const std::vector<std::pair<std::string, std::string> >& flags,
+        const std::vector<std::string> arg_names,
+        const std::unordered_map<std::string, std::vector<NDArray> >& params);
     uint32_t num_inputs() {
       return fwd_graph_.indexed_graph().input_nodes().size();
     }
@@ -124,7 +127,7 @@ class Imperative {
     std::vector<nnvm::NodeEntry> Gradient(const nnvm::NodePtr& node,
                                           const std::vector<nnvm::NodeEntry>& ograds);
     void Forward(const std::shared_ptr<CachedOp>& op_ptr,
-                 const std::vector<NDArray*>& inputs,
+                 const std::vector<NDArray*>& args,
                  const std::vector<NDArray*>& outputs);
     void Backward(const bool retain_graph,
                   const OpStatePtr& state,
@@ -138,14 +141,17 @@ class Imperative {
       std::vector<OpStatePtr> states;
     };
     std::mutex mutex_;
-    CachedOpParam param_;
+    CachedOpConfig config_;
     nnvm::Graph fwd_graph_;
     nnvm::Graph grad_graph_;
     nnvm::Graph full_graph_;
+    std::unordered_map<Context, std::vector<NDArray> > params_;
     bool inlining_;
     std::vector<nnvm::NodeEntry> ograd_entries_;
     std::vector<bool> curr_grad_req_;
     std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
+    std::vector<uint32_t> fwd_args_idx_;
+    std::vector<uint32_t> fwd_params_idx_;
     std::vector<uint32_t> bwd_input_eid_;
     std::vector<bool> save_inputs_, save_outputs_;
   };
diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py
index 20ad2bf..191985e 100644
--- a/python/mxnet/_ctypes/ndarray.py
+++ b/python/mxnet/_ctypes/ndarray.py
@@ -105,13 +105,28 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
 class CachedOp(object):
     """Cached operator handle."""
     __slots__ = ["handle"]
-    def __init__(self, sym, flags=()):
+    def __init__(self, sym, flags=(), inputs=None, params=None):
         self.handle = CachedOpHandle()
+        param_names = []
+        param_arrays = []
+        if inputs is None:
+            assert params is None, "When inputs is None params must also be None."
+            inputs = sym.list_inputs()
+        elif params is not None:
+            for name, arrs in params.items():
+                param_arrays.extend(arrs)
+                param_names.extend([name] * len(arrs))
+
         check_call(_LIB.MXCreateCachedOpEx(
             sym.handle,
             len(flags),
             c_str_array([key for key, _ in flags]),
             c_str_array([str(val) for _, val in flags]),
+            len(inputs),
+            c_str_array(inputs),
+            len(params),
+            c_str_array(param_names),
+            c_handle_array(param_arrays),
             ctypes.byref(self.handle)))
 
     def __del__(self):
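
For reference, a rough sketch of how the extended frontend CachedOp constructor can be used directly, mirroring the call that gluon/block.py makes below. CachedOp is an internal handle normally built by HybridBlock rather than user code, and the toy symbol, names and shapes here are illustrative, not taken from this commit; params maps each parameter name to one NDArray per context, while inputs lists only the runtime arguments.

    import mxnet as mx

    # Toy graph for illustration: 'data' stays a runtime argument,
    # 'weight' is handed to the backend as a parameter.
    data = mx.sym.var('data')
    weight = mx.sym.var('weight')
    out = mx.sym.FullyConnected(data, weight=weight, no_bias=True, num_hidden=4)

    params = {'weight': [mx.nd.ones((4, 8))]}       # one NDArray per context
    op = mx.nd.CachedOp(out, (), ['data'], params)  # flags, arg names, params

    y = op(mx.nd.ones((2, 8)))                      # only 'data' is passed per call
    print(y.shape)                                  # (2, 4)
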
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 0f41543..a737817 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -447,7 +447,6 @@ class HybridBlock(Block):
         super(HybridBlock, self).__init__(prefix=prefix, params=params)
         self._cached_graph = ()
         self._cached_op = None
-        self._cached_op_args = None
         self._out_format = None
         self._in_format = None
         self._active = False
@@ -479,53 +478,54 @@ class HybridBlock(Block):
 
     def _build_cache(self, *args):
         inputs, out = self._get_graph(*args)
-        input_idx = {var.name: i for i, var in enumerate(inputs)}
-        self._cached_op = ndarray.CachedOp(out, self._flags)
-        params = dict(self.collect_params().items())
-
-        # verify graph inputs
-        expected_inputs = set(out.list_inputs())
-        for name in expected_inputs:
-            assert name in params or name in input_idx, \
+        input_names = [i.name for i in inputs]
+
+        params = self.collect_params()
+        param_names = set(params.keys())
+        expected_names = set(out.list_inputs())
+        for name in expected_names:
+            assert name in param_names or name in input_names, \
                 "Unknown input to HybridBlock: %s"%name
-        for name, i in input_idx.items():
-            if name not in expected_inputs:
-                warnings.warn("The %d-th input to HybridBlock is not used by any "
-                              "computation. Is this intended?"%i, stacklevel=4)
-        for name in params:
-            if name not in expected_inputs:
-                warnings.warn("Parameter %s is not used by any computation. "
-                              "Is this intended?"%name, stacklevel=4)
-
-        self._cached_op_args = [(False, params[name]) if name in params
-                                else (True, input_idx[name])
-                                for name in out.list_inputs()]
-
-    def _finish_deferred_init(self, hybrid, *args):
+
+        used_input_names = [i for i in input_names if i in expected_names]
+        if len(used_input_names) != len(input_names):
+            unused = ', '.join(['%d-th'%i for i, name in enumerate(input_names)
+                                if name not in expected_names])
+            warnings.warn("The %s input to HybridBlock is not used by any "
+                          "computation. Is this intended?"%unused, stacklevel=4)
+
+        used_param_names = set(i for i in param_names if i in expected_names)
+        if len(used_param_names) != len(param_names):
+            unused = ', '.join(list(param_names - used_param_names))
+            warnings.warn("Parameter %s is not used by any computation. "
+                          "Is this intended?"%unused, stacklevel=4)
+
+        used_params = {k: params[k] for k in used_param_names}
+        try:
+            param_dict = {k: v.list_data() for k, v in used_params.items()}
+        except DeferredInitializationError:
+            self._deferred_infer_shape(*args)
+            for i in used_params.values():
+                i._finish_deferred_init()
+            param_dict = {k: v.list_data() for k, v in used_params.items()}
+
+        self._cached_op = ndarray.CachedOp(out, self._flags, input_names, param_dict)
+
+    def _deferred_infer_shape(self, *args):
         try:
             self.infer_shape(*args)
         except Exception as e:
             error_msg = "Deferred initialization failed because shape"\
-                        " cannot be inferred \n {}".format(e)
+                        " cannot be inferred. {}".format(e)
             raise ValueError(error_msg)
 
-        if hybrid:
-            for is_arg, i in self._cached_op_args:
-                if not is_arg:
-                    i._finish_deferred_init()
-        else:
-            for _, i in self.params.items():
-                i._finish_deferred_init()
-
     def _call_cached_op(self, *args):
         if self._cached_op is None:
             self._build_cache(*args)
 
         args, fmt = _flatten(args, "input")
         assert fmt == self._in_format, "Invalid input format"
-        cargs = [args[i] if is_arg else i.data()
-                 for is_arg, i in self._cached_op_args]
-        out = self._cached_op(*cargs)
+        out = self._cached_op(*args)
         if isinstance(out, NDArray):
             out = [out]
         return _regroup(out, self._out_format)[0]
@@ -533,7 +533,6 @@ class HybridBlock(Block):
     def _clear_cached_op(self):
         self._cached_graph = ()
         self._cached_op = None
-        self._cached_op_args = None
 
     def register_child(self, block, name=None):
         if not isinstance(block, HybridBlock):
@@ -616,16 +615,17 @@ class HybridBlock(Block):
         :py:class:`NDArray` or :py:class:`Symbol`."""
         if isinstance(x, NDArray):
             with x.context as ctx:
+                if self._active:
+                    return self._call_cached_op(x, *args)
+
                 try:
-                    if self._active:
-                        return self._call_cached_op(x, *args)
                     params = {i: j.data(ctx) for i, j in self._reg_params.items()}
                 except DeferredInitializationError:
-                    self._finish_deferred_init(self._active, x, *args)
+                    self._deferred_infer_shape(x, *args)
+                    for _, i in self.params.items():
+                        i._finish_deferred_init()
+                    params = {i: j.data(ctx) for i, j in self._reg_params.items()}
 
-                if self._active:
-                    return self._call_cached_op(x, *args)
-                params = {i: j.data(ctx) for i, j in self._reg_params.items()}
                 return self.hybrid_forward(ndarray, x, *args, **params)
 
         assert isinstance(x, Symbol), \
@@ -709,16 +709,10 @@ class SymbolBlock(HybridBlock):
                 self.params.get(i, grad_req='null', allow_deferred_init=True)
 
         self._cached_graph = syms, out
-        self._build_cache()
 
     def forward(self, x, *args):
         if isinstance(x, NDArray):
             with x.context:
-                try:
-                    return self._call_cached_op(x, *args)
-                except DeferredInitializationError:
-                    self._finish_deferred_init(True, x, *args)
-
                 return self._call_cached_op(x, *args)
 
         assert isinstance(x, Symbol), \
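
The user-visible flow affected by the block.py changes above can be sketched with a stock Gluon layer (layer choice and shapes are arbitrary examples, not from this commit): after hybridize(), the first call builds the CachedOp and hands the block's parameters to the backend, so subsequent calls pass only the data arguments from Python.

    import mxnet as mx
    from mxnet.gluon import nn

    net = nn.Dense(4)   # in_units left to deferred shape inference
    net.initialize()
    net.hybridize()

    x = mx.nd.ones((2, 8))
    y = net(x)          # first call: _build_cache() registers parameters with the backend
    y = net(x)          # later calls: only the data argument crosses into C++
    print(y.shape)      # (2, 4)
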
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index d67d52c..9aabe04 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -156,26 +156,47 @@ int MXCreateCachedOp(SymbolHandle handle,
   nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(handle);
 
   API_BEGIN();
+  auto inputs = sym->ListInputs(nnvm::Symbol::kAll);
+  std::vector<std::string> input_names;
+  input_names.reserve(inputs.size());
+  for (const auto& i : inputs) input_names.push_back(i->attrs.name);
   *out = new std::shared_ptr<Imperative::CachedOp>(
       new Imperative::CachedOp(
-        *sym, std::vector<std::pair<std::string, std::string> >()));
+        *sym,
+        std::vector<std::pair<std::string, std::string> >(),
+        input_names,
+        std::unordered_map<std::string, std::vector<NDArray> >()));
   API_END();
 }
 
 int MXCreateCachedOpEx(SymbolHandle handle,
-                       int num_params,
+                       int num_flags,
                        const char** keys,
                        const char** vals,
+                       int num_args,
+                       const char** arg_names,
+                       int num_params,
+                       const char** param_names,
+                       NDArrayHandle* params,
                        CachedOpHandle *out) {
   nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(handle);
 
   API_BEGIN();
-  std::vector<std::pair<std::string, std::string> > kwargs;
+  std::vector<std::pair<std::string, std::string> > flags;
+  for (int i = 0; i < num_flags; ++i) {
+    flags.push_back({keys[i], vals[i]});
+  }
+  std::vector<std::string> args;
+  for (int i = 0; i < num_args; ++i) {
+    args.push_back(arg_names[i]);
+  }
+  std::unordered_map<std::string, std::vector<NDArray> > param_dict;
   for (int i = 0; i < num_params; ++i) {
-    kwargs.push_back({keys[i], vals[i]});
+    param_dict[param_names[i]].emplace_back(
+        *reinterpret_cast<NDArray*>(params[i]));
   }
   *out = new std::shared_ptr<Imperative::CachedOp>(
-      new Imperative::CachedOp(*sym, kwargs));
+      new Imperative::CachedOp(*sym, flags, args, param_dict));
   API_END();
 }
 
diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc
index e4d4955..b174ea0 100644
--- a/src/executor/attach_op_execs_pass.cc
+++ b/src/executor/attach_op_execs_pass.cc
@@ -241,8 +241,6 @@ Graph AttachOpExecs(Graph g) {
   const auto& vdtype = g.GetAttr<DTypeVector>("dtype");
   const auto& vshape = g.GetAttr<ShapeVector>("shape");
   const auto& vctx = g.GetAttr<ContextVector>("context");
-  const auto& saved_states = g.GetAttr<
-    std::unordered_map<const nnvm::Node*, OpStatePtr> >("saved_states");
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
 
   // get the graph
@@ -271,13 +269,8 @@ Graph AttachOpExecs(Graph g) {
         itype.emplace_back(vdtype[idx.entry_id(e)]);
       }
 
-      OpStatePtr state;
-      if (saved_states.count(inode.source)) {
-        state = saved_states.at(inode.source);
-      } else {
-        state = fcreate_op_state[op](
-            inode.source->attrs, vctx[i], ishape, itype);
-      }
+      OpStatePtr state = fcreate_op_state[op](
+          inode.source->attrs, vctx[i], ishape, itype);
       FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
           op, "FStatefulComputeEx", vctx[i]);
       // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 4d24f55..d5dacf7 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -904,7 +904,6 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol,
   }
   g = DetectInplaceAddTo(g);
 
-  g.attrs["saved_states"] = std::make_shared<nnvm::any>(std::move(saved_states_));
   g = AttachOpExecs(g);
   g = AttachOpResources(g);
   graph_ = std::move(g);
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index 3f1ebe5..bcde41d 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -234,8 +234,6 @@ class GraphExecutor : public Executor {
   size_t num_forward_inputs_{0};
   // number of forward nodes
   size_t num_forward_nodes_{0};
-  // saved operator for autograd
-  std::unordered_map<const nnvm::Node*, OpStatePtr> saved_states_;
   // monitor call back
   std::function<void(const char*, void*)> monitor_callback_{nullptr};
   // whether to enable bulk execution
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index ce23735..10a8fc0 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -22,17 +22,19 @@
 
 namespace mxnet {
 
-DMLC_REGISTER_PARAMETER(CachedOpParam);
+DMLC_REGISTER_PARAMETER(CachedOpConfig);
 
 Imperative::CachedOp::CachedOp(
     const nnvm::Symbol& sym,
-    const std::vector<std::pair<std::string, std::string> >& kwargs) {
+    const std::vector<std::pair<std::string, std::string> >& flags,
+    const std::vector<std::string> arg_names,
+    const std::unordered_map<std::string, std::vector<NDArray> >& params) {
   using namespace nnvm;
   using namespace imperative;
   static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
   static const auto _copy = Op::Get("_copy");
 
-  param_.Init(kwargs);
+  config_.Init(flags);
 
   // construct forward graph
   {
@@ -66,7 +68,34 @@ Imperative::CachedOp::CachedOp(
     fwd_graph_.attrs["forward_ref_count"] =
         std::make_shared<dmlc::any>(std::move(ref_count));
 
-    inlining_ = (idx.num_nodes() - idx.input_nodes().size()) <= param_.inline_limit;
+    inlining_ = (idx.num_nodes() - idx.input_nodes().size()) <= config_.inline_limit;
+  }
+
+  // Set params
+  {
+    const auto& idx = fwd_graph_.indexed_graph();
+    std::unordered_map<std::string, size_t> arg_name_to_id;
+    for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
+      const auto& name = idx[idx.input_nodes()[i]].source->attrs.name;
+      auto iter = params.find(name);
+      if (iter == params.end()) {
+        arg_name_to_id[name] = i;
+        continue;
+      }
+      fwd_params_idx_.push_back(i);
+      for (const auto& param : iter->second) {
+        params_[param.ctx()].emplace_back(param);
+      }
+    }
+
+    CHECK_EQ(arg_name_to_id.size(), arg_names.size())
+        << "Expecting " << arg_name_to_id.size() << "inputs, given " << arg_names.size();
+
+    for (const auto& name : arg_names) {
+      auto iter = arg_name_to_id.find(name);
+      CHECK(iter != arg_name_to_id.end()) << "Unexpected input name " << name;
+      fwd_args_idx_.push_back(iter->second);
+    }
   }
 
   // construct backward graph
@@ -341,22 +370,38 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
 
 void Imperative::CachedOp::Forward(
     const std::shared_ptr<CachedOp>& op_ptr,
-    const std::vector<NDArray*>& inputs,
+    const std::vector<NDArray*>& args,
     const std::vector<NDArray*>& outputs) {
   using namespace nnvm;
   using namespace imperative;
   static const auto cached_op = nnvm::Op::Get("_CachedOp");
 
+  CHECK_EQ(args.size(), fwd_args_idx_.size())
+      << "CachedOp requires " << fwd_args_idx_.size()
+      << " inputs but got " << args.size();
+
+  Context default_ctx = args[0]->ctx();
+
+
+  std::vector<NDArray*> inputs(num_inputs());
+  for (index_t i = 0; i < fwd_args_idx_.size(); ++i) {
+    inputs[fwd_args_idx_[i]] = args[i];
+  }
+  if (fwd_params_idx_.size()) {
+    CHECK(params_.find(default_ctx) != params_.end())
+        << "CachedOp is not initialized on context " << default_ctx;
+
+    for (size_t i = 0; i < fwd_params_idx_.size(); ++i) {
+      inputs[fwd_params_idx_[i]] = &params_[default_ctx][i];
+    }
+  }
+
   // Initialize
   bool recording = Imperative::Get()->is_recording();
   nnvm::Graph g = GetForwardGraph(recording, inputs);
   const auto& idx = g.indexed_graph();
   size_t num_inputs = idx.input_nodes().size();
 
-  CHECK_EQ(num_inputs, inputs.size())
-      << "CachedOp requires " << num_inputs << " but got " << inputs.size();
-
-  Context default_ctx = inputs[0]->ctx();
   for (size_t i = 0; i < inputs.size(); ++i) {
     CHECK_EQ(inputs[i]->ctx(), default_ctx)
         << "CachedOp requires all inputs to live on the same context. But "
@@ -403,7 +448,7 @@ void Imperative::CachedOp::Forward(
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
 
   if (recording && !inlining_) Imperative::Get()->set_is_recording(false);
-  int prev_bulk_size = Engine::Get()->set_bulk_size(param_.forward_bulk_size);
+  int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size);
 
   Imperative::Get()->RunGraph(
       false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
@@ -491,7 +536,7 @@ void Imperative::CachedOp::Backward(
 
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
 
-  int prev_bulk_size = Engine::Get()->set_bulk_size(param_.backward_bulk_size);
+  int prev_bulk_size = Engine::Get()->set_bulk_size(config_.backward_bulk_size);
 
   Imperative::Get()->RunGraph(
       retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
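
One consequence of keying params_ by Context (see the CHECK in Forward above): the cached op can only run on a context for which parameter copies were registered, so a hybridized block's inputs must live on a context the block was initialized on. A rough sketch, with the context and shapes chosen only for illustration:

    import mxnet as mx
    from mxnet.gluon import nn

    ctx = mx.cpu()             # or mx.gpu(0) where available

    net = nn.Dense(4)
    net.initialize(ctx=ctx)    # parameter copies are created on ctx
    net.hybridize()

    y = net(mx.nd.ones((2, 8), ctx=ctx))   # inputs share the context the
    print(y.context)                       # parameters were initialized on
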

-- 
To stop receiving notification emails like this one, please contact
j...@apache.org.
