This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
     new faffe56  move parameters to backend for hybridblock (#10632)
faffe56 is described below

commit faffe56b450cfc908bf9d29e810e00c0de8aca7b
Author: Eric Junyuan Xie <piiswr...@users.noreply.github.com>
AuthorDate: Tue Apr 24 15:04:31 2018 -0700

    move parameters to backend for hybridblock (#10632)
---
 include/mxnet/base.h                 |  8 ++++
 include/mxnet/c_api.h                |  7 ++-
 include/mxnet/imperative.h           | 18 +++++---
 python/mxnet/_ctypes/ndarray.py      | 17 ++++++-
 python/mxnet/gluon/block.py          | 90 +++++++++++++++++-------------------
 src/c_api/c_api_ndarray.cc           | 31 +++++++++++--
 src/executor/attach_op_execs_pass.cc | 11 +----
 src/executor/graph_executor.cc       |  1 -
 src/executor/graph_executor.h        |  2 -
 src/imperative/cached_op.cc          | 67 ++++++++++++++++++++++-----
 10 files changed, 168 insertions(+), 84 deletions(-)

diff --git a/include/mxnet/base.h b/include/mxnet/base.h
index 783002e..38fd7ed 100644
--- a/include/mxnet/base.h
+++ b/include/mxnet/base.h
@@ -362,6 +362,14 @@ constexpr size_t kMKLDNNAlign = 64;
 
 }  // namespace mxnet
 
+namespace std {
+template<> struct hash<mxnet::Context> {
+  size_t operator()(const mxnet::Context& ctx) const {
+    return (static_cast<size_t>(ctx.dev_type) << 32) | ctx.dev_id;
+  }
+};
+}
+
 #include "./tensor_blob.h"
 //! \endcond
 #endif  // MXNET_BASE_H_

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index cbc83b2..3f04051 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -976,9 +976,14 @@ MXNET_DLL int MXCreateCachedOp(SymbolHandle handle, CachedOpHandle *out);
  * \brief create cached operator
  */
 MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle,
-                                 int num_params,
+                                 int num_flags,
                                  const char** keys,
                                  const char** vals,
+                                 int num_inputs,
+                                 const char** input_names,
+                                 int num_params,
+                                 const char** param_names,
+                                 NDArrayHandle* params,
                                  CachedOpHandle *out);
 /*!
  * \brief free cached operator
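The std::hash<mxnet::Context> specialization above exists so the backend can key a std::unordered_map by Context and keep one copy of each parameter per device. On the Gluon side those per-device copies come from Parameter.list_data(). A minimal sketch of the idea (two CPU contexts stand in for multiple devices; the name and shape are illustrative):

    import mxnet as mx
    from mxnet import gluon

    # One parameter initialized on two distinct contexts: each context gets
    # its own NDArray copy, and list_data() returns all of them. These are
    # the arrays the cached op later buckets by Context.
    p = gluon.Parameter('weight', shape=(4, 8))
    p.initialize(ctx=[mx.cpu(0), mx.cpu(1)])
    print(len(p.list_data()))  # 2 -- one copy per context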
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index d605e9d..758ce85 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -36,11 +36,11 @@ namespace mxnet {
 /*! \brief CachedOp Parameters */
-struct CachedOpParam : public dmlc::Parameter<CachedOpParam> {
+struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
   uint32_t inline_limit;
   uint32_t forward_bulk_size;
   uint32_t backward_bulk_size;
-  DMLC_DECLARE_PARAMETER(CachedOpParam) {
+  DMLC_DECLARE_PARAMETER(CachedOpConfig) {
     DMLC_DECLARE_FIELD(inline_limit)
     .set_default(2)
     .describe("Maximum number of operators that can be inlined.");
@@ -96,8 +96,11 @@ class Imperative {
   };
   class CachedOp {
    public:
-    CachedOp(const nnvm::Symbol& sym,
-             const std::vector<std::pair<std::string, std::string> >& kwargs);
+    CachedOp(
+        const nnvm::Symbol& sym,
+        const std::vector<std::pair<std::string, std::string> >& flags,
+        const std::vector<std::string> arg_names,
+        const std::unordered_map<std::string, std::vector<NDArray> >& params);
     uint32_t num_inputs() {
       return fwd_graph_.indexed_graph().input_nodes().size();
     }
@@ -124,7 +127,7 @@ class Imperative {
     std::vector<nnvm::NodeEntry> Gradient(
         const nnvm::NodePtr& node,
         const std::vector<nnvm::NodeEntry>& ograds);
     void Forward(const std::shared_ptr<CachedOp>& op_ptr,
-                 const std::vector<NDArray*>& inputs,
+                 const std::vector<NDArray*>& args,
                  const std::vector<NDArray*>& outputs);
     void Backward(const bool retain_graph,
                   const OpStatePtr& state,
@@ -138,14 +141,17 @@ class Imperative {
       std::vector<OpStatePtr> states;
     };
     std::mutex mutex_;
-    CachedOpParam param_;
+    CachedOpConfig config_;
    nnvm::Graph fwd_graph_;
     nnvm::Graph grad_graph_;
     nnvm::Graph full_graph_;
+    std::unordered_map<Context, std::vector<NDArray> > params_;
     bool inlining_;
     std::vector<nnvm::NodeEntry> ograd_entries_;
     std::vector<bool> curr_grad_req_;
     std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
+    std::vector<uint32_t> fwd_args_idx_;
+    std::vector<uint32_t> fwd_params_idx_;
     std::vector<uint32_t> bwd_input_eid_;
     std::vector<bool> save_inputs_, save_outputs_;
   };
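CachedOpConfig is a dmlc::Parameter struct, so its fields travel as string key/value flags through the new constructor. A minimal sketch of passing such flags from Python (only inline_limit's default of 2 comes from the diff; the other values are illustrative):

    import mxnet as mx
    from mxnet import ndarray

    data = mx.sym.var('data')
    out = mx.sym.sin(data)

    # Key/value string pairs parsed by CachedOpConfig::Init.
    flags = (('inline_limit', '2'),         # inline graphs with <= 2 operators
             ('forward_bulk_size', '15'),   # engine bulk size, forward pass
             ('backward_bulk_size', '15'))  # engine bulk size, backward pass

    op = ndarray.CachedOp(out, flags, ['data'], {})
    print(op(mx.nd.ones((3,))))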
diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py
index 20ad2bf..191985e 100644
--- a/python/mxnet/_ctypes/ndarray.py
+++ b/python/mxnet/_ctypes/ndarray.py
@@ -105,13 +105,28 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
 class CachedOp(object):
     """Cached operator handle."""
     __slots__ = ["handle"]
-    def __init__(self, sym, flags=()):
+    def __init__(self, sym, flags=(), inputs=None, params=None):
         self.handle = CachedOpHandle()
+        param_names = []
+        param_arrays = []
+        if inputs is None:
+            assert params is None, "When inputs is None params must also be None."
+            inputs = sym.list_inputs()
+        elif params is not None:
+            for name, arrs in params.items():
+                param_arrays.extend(arrs)
+                param_names.extend([name] * len(arrs))
+
         check_call(_LIB.MXCreateCachedOpEx(
             sym.handle,
             len(flags),
             c_str_array([key for key, _ in flags]),
             c_str_array([str(val) for _, val in flags]),
+            len(inputs),
+            c_str_array(inputs),
+            len(param_names),
+            c_str_array(param_names),
+            c_handle_array(param_arrays),
             ctypes.byref(self.handle)))
 
     def __del__(self):
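With the extended constructor, parameter arrays are bound to the cached op once at creation time and disappear from the per-call argument list. A minimal sketch against this commit's Python API (symbol names and shapes are made up for illustration):

    import mxnet as mx
    from mxnet import ndarray

    data = mx.sym.var('data')
    weight = mx.sym.var('weight')
    out = mx.sym.FullyConnected(data=data, weight=weight, no_bias=True,
                                num_hidden=4, name='fc')

    # 'weight' is registered up front (one array per context); only 'data'
    # remains an argument of each call.
    params = {'weight': [mx.nd.ones((4, 8))]}
    op = ndarray.CachedOp(out, (), ['data'], params)
    print(op(mx.nd.ones((2, 8))).shape)  # (2, 4)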
{}".format(e) raise ValueError(error_msg) - if hybrid: - for is_arg, i in self._cached_op_args: - if not is_arg: - i._finish_deferred_init() - else: - for _, i in self.params.items(): - i._finish_deferred_init() - def _call_cached_op(self, *args): if self._cached_op is None: self._build_cache(*args) args, fmt = _flatten(args, "input") assert fmt == self._in_format, "Invalid input format" - cargs = [args[i] if is_arg else i.data() - for is_arg, i in self._cached_op_args] - out = self._cached_op(*cargs) + out = self._cached_op(*args) if isinstance(out, NDArray): out = [out] return _regroup(out, self._out_format)[0] @@ -533,7 +533,6 @@ class HybridBlock(Block): def _clear_cached_op(self): self._cached_graph = () self._cached_op = None - self._cached_op_args = None def register_child(self, block, name=None): if not isinstance(block, HybridBlock): @@ -616,16 +615,17 @@ class HybridBlock(Block): :py:class:`NDArray` or :py:class:`Symbol`.""" if isinstance(x, NDArray): with x.context as ctx: + if self._active: + return self._call_cached_op(x, *args) + try: - if self._active: - return self._call_cached_op(x, *args) params = {i: j.data(ctx) for i, j in self._reg_params.items()} except DeferredInitializationError: - self._finish_deferred_init(self._active, x, *args) + self._deferred_infer_shape(x, *args) + for _, i in self.params.items(): + i._finish_deferred_init() + params = {i: j.data(ctx) for i, j in self._reg_params.items()} - if self._active: - return self._call_cached_op(x, *args) - params = {i: j.data(ctx) for i, j in self._reg_params.items()} return self.hybrid_forward(ndarray, x, *args, **params) assert isinstance(x, Symbol), \ @@ -709,16 +709,10 @@ class SymbolBlock(HybridBlock): self.params.get(i, grad_req='null', allow_deferred_init=True) self._cached_graph = syms, out - self._build_cache() def forward(self, x, *args): if isinstance(x, NDArray): with x.context: - try: - return self._call_cached_op(x, *args) - except DeferredInitializationError: - self._finish_deferred_init(True, x, *args) - return self._call_cached_op(x, *args) assert isinstance(x, Symbol), \ diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index d67d52c..9aabe04 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -156,26 +156,47 @@ int MXCreateCachedOp(SymbolHandle handle, nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(handle); API_BEGIN(); + auto inputs = sym->ListInputs(nnvm::Symbol::kAll); + std::vector<std::string> input_names; + input_names.reserve(inputs.size()); + for (const auto& i : inputs) input_names.push_back(i->attrs.name); *out = new std::shared_ptr<Imperative::CachedOp>( new Imperative::CachedOp( - *sym, std::vector<std::pair<std::string, std::string> >())); + *sym, + std::vector<std::pair<std::string, std::string> >(), + input_names, + std::unordered_map<std::string, std::vector<NDArray> >())); API_END(); } int MXCreateCachedOpEx(SymbolHandle handle, - int num_params, + int num_flags, const char** keys, const char** vals, + int num_args, + const char** arg_names, + int num_params, + const char** param_names, + NDArrayHandle* params, CachedOpHandle *out) { nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(handle); API_BEGIN(); - std::vector<std::pair<std::string, std::string> > kwargs; + std::vector<std::pair<std::string, std::string> > flags; + for (int i = 0; i < num_flags; ++i) { + flags.push_back({keys[i], vals[i]}); + } + std::vector<std::string> args; + for (int i = 0; i < num_args; ++i) { + args.push_back(arg_names[i]); + } + std::unordered_map<std::string, 
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index d67d52c..9aabe04 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -156,26 +156,47 @@ int MXCreateCachedOp(SymbolHandle handle,
   nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(handle);
 
   API_BEGIN();
+  auto inputs = sym->ListInputs(nnvm::Symbol::kAll);
+  std::vector<std::string> input_names;
+  input_names.reserve(inputs.size());
+  for (const auto& i : inputs) input_names.push_back(i->attrs.name);
   *out = new std::shared_ptr<Imperative::CachedOp>(
       new Imperative::CachedOp(
-          *sym, std::vector<std::pair<std::string, std::string> >()));
+          *sym,
+          std::vector<std::pair<std::string, std::string> >(),
+          input_names,
+          std::unordered_map<std::string, std::vector<NDArray> >()));
   API_END();
 }
 
 int MXCreateCachedOpEx(SymbolHandle handle,
-                       int num_params,
+                       int num_flags,
                        const char** keys,
                        const char** vals,
+                       int num_args,
+                       const char** arg_names,
+                       int num_params,
+                       const char** param_names,
+                       NDArrayHandle* params,
                        CachedOpHandle *out) {
   nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(handle);
 
   API_BEGIN();
-  std::vector<std::pair<std::string, std::string> > kwargs;
+  std::vector<std::pair<std::string, std::string> > flags;
+  for (int i = 0; i < num_flags; ++i) {
+    flags.push_back({keys[i], vals[i]});
+  }
+  std::vector<std::string> args;
+  for (int i = 0; i < num_args; ++i) {
+    args.push_back(arg_names[i]);
+  }
+  std::unordered_map<std::string, std::vector<NDArray> > param_dict;
   for (int i = 0; i < num_params; ++i) {
-    kwargs.push_back({keys[i], vals[i]});
+    param_dict[param_names[i]].emplace_back(
+        *reinterpret_cast<NDArray*>(params[i]));
   }
   *out = new std::shared_ptr<Imperative::CachedOp>(
-      new Imperative::CachedOp(*sym, kwargs));
+      new Imperative::CachedOp(*sym, flags, args, param_dict));
   API_END();
 }

diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc
index e4d4955..b174ea0 100644
--- a/src/executor/attach_op_execs_pass.cc
+++ b/src/executor/attach_op_execs_pass.cc
@@ -241,8 +241,6 @@ Graph AttachOpExecs(Graph g) {
   const auto& vdtype = g.GetAttr<DTypeVector>("dtype");
   const auto& vshape = g.GetAttr<ShapeVector>("shape");
   const auto& vctx = g.GetAttr<ContextVector>("context");
-  const auto& saved_states = g.GetAttr<
-      std::unordered_map<const nnvm::Node*, OpStatePtr> >("saved_states");
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
 
   // get the graph
@@ -271,13 +269,8 @@ Graph AttachOpExecs(Graph g) {
         itype.emplace_back(vdtype[idx.entry_id(e)]);
       }
 
-      OpStatePtr state;
-      if (saved_states.count(inode.source)) {
-        state = saved_states.at(inode.source);
-      } else {
-        state = fcreate_op_state[op](
-            inode.source->attrs, vctx[i], ishape, itype);
-      }
+      OpStatePtr state = fcreate_op_state[op](
+          inode.source->attrs, vctx[i], ishape, itype);
       FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
           op, "FStatefulComputeEx", vctx[i]);
       // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx

diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 4d24f55..d5dacf7 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -904,7 +904,6 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol,
   }
 
   g = DetectInplaceAddTo(g);
-  g.attrs["saved_states"] = std::make_shared<nnvm::any>(std::move(saved_states_));
   g = AttachOpExecs(g);
   g = AttachOpResources(g);
   graph_ = std::move(g);

diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index 3f1ebe5..bcde41d 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -234,8 +234,6 @@ class GraphExecutor : public Executor {
   size_t num_forward_inputs_{0};
   // number of forward nodes
   size_t num_forward_nodes_{0};
-  // saved operator for autograd
-  std::unordered_map<const nnvm::Node*, OpStatePtr> saved_states_;
   // monitor call back
   std::function<void(const char*, void*)> monitor_callback_{nullptr};
   // whether to enable bulk execution
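Note the shape of the data crossing the C API boundary in MXCreateCachedOpEx above: the Python side flattens the {name: [arrays]} dict into parallel name/array lists (one entry per device copy), the C++ side rebuilds the name-to-arrays map, and the CachedOp constructor then buckets each array under its Context. A small sketch of that flattening, with made-up values:

    import mxnet as mx

    # Two copies of one parameter (both on CPU here, purely for illustration).
    params = {'weight': [mx.nd.ones((4, 8)), mx.nd.ones((4, 8))]}
    param_names, param_arrays = [], []
    for name, arrs in params.items():
        param_arrays.extend(arrs)
        param_names.extend([name] * len(arrs))
    print(param_names)  # ['weight', 'weight']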
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index ce23735..10a8fc0 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -22,17 +22,19 @@
 
 namespace mxnet {
 
-DMLC_REGISTER_PARAMETER(CachedOpParam);
+DMLC_REGISTER_PARAMETER(CachedOpConfig);
 
 Imperative::CachedOp::CachedOp(
     const nnvm::Symbol& sym,
-    const std::vector<std::pair<std::string, std::string> >& kwargs) {
+    const std::vector<std::pair<std::string, std::string> >& flags,
+    const std::vector<std::string> arg_names,
+    const std::unordered_map<std::string, std::vector<NDArray> >& params) {
   using namespace nnvm;
   using namespace imperative;
   static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
   static const auto _copy = Op::Get("_copy");
 
-  param_.Init(kwargs);
+  config_.Init(flags);
 
   // construct forward graph
   {
@@ -66,7 +68,34 @@ Imperative::CachedOp::CachedOp(
     fwd_graph_.attrs["forward_ref_count"] =
         std::make_shared<dmlc::any>(std::move(ref_count));
 
-    inlining_ = (idx.num_nodes() - idx.input_nodes().size()) <= param_.inline_limit;
+    inlining_ = (idx.num_nodes() - idx.input_nodes().size()) <= config_.inline_limit;
+  }
+
+  // Set params
+  {
+    const auto& idx = fwd_graph_.indexed_graph();
+    std::unordered_map<std::string, size_t> arg_name_to_id;
+    for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
+      const auto& name = idx[idx.input_nodes()[i]].source->attrs.name;
+      auto iter = params.find(name);
+      if (iter == params.end()) {
+        arg_name_to_id[name] = i;
+        continue;
+      }
+      fwd_params_idx_.push_back(i);
+      for (const auto& param : iter->second) {
+        params_[param.ctx()].emplace_back(param);
+      }
+    }
+
+    CHECK_EQ(arg_name_to_id.size(), arg_names.size())
+        << "Expecting " << arg_name_to_id.size() << " inputs, given " << arg_names.size();
+
+    for (const auto& name : arg_names) {
+      auto iter = arg_name_to_id.find(name);
+      CHECK(iter != arg_name_to_id.end()) << "Unexpected input name " << name;
+      fwd_args_idx_.push_back(iter->second);
+    }
   }
 
   // construct backward graph
@@ -341,22 +370,38 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
 
 void Imperative::CachedOp::Forward(
     const std::shared_ptr<CachedOp>& op_ptr,
-    const std::vector<NDArray*>& inputs,
+    const std::vector<NDArray*>& args,
     const std::vector<NDArray*>& outputs) {
   using namespace nnvm;
   using namespace imperative;
   static const auto cached_op = nnvm::Op::Get("_CachedOp");
 
+  CHECK_EQ(args.size(), fwd_args_idx_.size())
+      << "CachedOp requires " << fwd_args_idx_.size()
+      << " inputs but got " << args.size();
+
+  Context default_ctx = args[0]->ctx();
+
+
+  std::vector<NDArray*> inputs(num_inputs());
+  for (index_t i = 0; i < fwd_args_idx_.size(); ++i) {
+    inputs[fwd_args_idx_[i]] = args[i];
+  }
+  if (fwd_params_idx_.size()) {
+    CHECK(params_.find(default_ctx) != params_.end())
+        << "CachedOp is not initialized on context " << default_ctx;
+
+    for (size_t i = 0; i < fwd_params_idx_.size(); ++i) {
+      inputs[fwd_params_idx_[i]] = &params_[default_ctx][i];
+    }
+  }
+
   // Initialize
   bool recording = Imperative::Get()->is_recording();
   nnvm::Graph g = GetForwardGraph(recording, inputs);
   const auto& idx = g.indexed_graph();
   size_t num_inputs = idx.input_nodes().size();
 
-  CHECK_EQ(num_inputs, inputs.size())
-      << "CachedOp requires " << num_inputs << " but got " << inputs.size();
-
-  Context default_ctx = inputs[0]->ctx();
   for (size_t i = 0; i < inputs.size(); ++i) {
     CHECK_EQ(inputs[i]->ctx(), default_ctx)
         << "CachedOp requires all inputs to live on the same context. But "
@@ -403,7 +448,7 @@ void Imperative::CachedOp::Forward(
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
 
   if (recording && !inlining_) Imperative::Get()->set_is_recording(false);
-  int prev_bulk_size = Engine::Get()->set_bulk_size(param_.forward_bulk_size);
+  int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size);
 
   Imperative::Get()->RunGraph(
       false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
@@ -491,7 +536,7 @@ void Imperative::CachedOp::Backward(
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
 
-  int prev_bulk_size = Engine::Get()->set_bulk_size(param_.backward_bulk_size);
+  int prev_bulk_size = Engine::Get()->set_bulk_size(config_.backward_bulk_size);
 
   Imperative::Get()->RunGraph(
       retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),

-- 
To stop receiving notification emails like this one, please contact
j...@apache.org.