This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
     new df9f79a  standard update for sparse sgd_mom_update (#9189)
df9f79a is described below

commit df9f79ae5e265e28ceecab3c58828f3a84769eb4
Author: Ziyue Huang <zyhuan...@gmail.com>
AuthorDate: Fri Jan 5 13:36:15 2018 +0800

    standard update for sparse sgd_mom_update (#9189)

    * standard sparse sgd mom update

    * update

    * update comments

    * address comments

    * revise

    * more general infer stype

    * fix

    * fix

    * add comments for stype inference func

    * update
---
 python/mxnet/optimizer.py               |  25 ++++---
 src/operator/optimizer_op-inl.h         | 112 ++++++++++++++++++++++++++++++--
 src/operator/optimizer_op.cc            |  62 +++++++++++++++++-
 src/operator/optimizer_op.cu            |  66 +++++++++++++++++++
 tests/python/unittest/test_optimizer.py |  24 ++++++-
 5 files changed, 272 insertions(+), 17 deletions(-)

diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 59898c9..feff87e 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -433,14 +433,8 @@ register = Optimizer.register  # pylint: disable=invalid-name
 class SGD(Optimizer):
     """The SGD optimizer with momentum and weight decay.
 
-    The optimizer updates the weight by::
-
-        rescaled_grad = lr * rescale_grad * clip(grad, clip_gradient) + wd * weight
-        state = momentum * state + rescaled_grad
-        weight = weight - state
-
-    If the storage types of weight, state and grad are all ``row_sparse``, \
-    **sparse updates** are applied by::
+    If the storage types of weight and grad are both ``row_sparse``, and ``lazy_update`` is True, \
+    **lazy updates** are applied by::
 
         for row in grad.indices:
             rescaled_grad[row] = lr * rescale_grad * clip(grad[row], clip_gradient) + wd * weight[row]
@@ -454,6 +448,12 @@ class SGD(Optimizer):
     provides slightly different semantics than the original update, and
     may lead to different empirical results.
 
+    Otherwise, **standard updates** are applied by::
+
+        rescaled_grad = lr * rescale_grad * clip(grad, clip_gradient) + wd * weight
+        state = momentum * state + rescaled_grad
+        weight = weight - state
+
     For details of the update algorithm see
     :class:`~mxnet.ndarray.sgd_update` and :class:`~mxnet.ndarray.sgd_mom_update`.
 
@@ -464,6 +464,9 @@
     ----------
     momentum : float, optional
        The momentum value.
+    lazy_update : bool, optional
+       Default is True. If True, lazy updates are applied \
+       if the storage types of weight and grad are both ``row_sparse``.
     multi_precision: bool, optional
        Flag to control the internal precision of the optimizer.
        ``False`` results in using the same precision as the weights (default),
@@ -471,9 +474,10 @@
        in 32-bit precision even if actual weights used in the model have lower precision.\
        Turning this on can improve convergence and accuracy when training with float16.
     """
-    def __init__(self, momentum=0.0, **kwargs):
+    def __init__(self, momentum=0.0, lazy_update=True, **kwargs):
         super(SGD, self).__init__(**kwargs)
         self.momentum = momentum
+        self.lazy_update = lazy_update
 
     def create_state_multi_precision(self, index, weight):
         weight_master_copy = None
@@ -489,8 +493,9 @@ class SGD(Optimizer):
 
     def create_state(self, index, weight):
         momentum = None
+        stype = weight.stype if self.lazy_update else 'default'
         if self.momentum != 0.0:
-            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
+            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype)
         return momentum
 
     def _update_impl(self, index, weight, grad, state, multi_precision=False):
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index a6b32b1..33b7dd5 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -38,6 +38,7 @@
 #include "./elemwise_op_common.h"
 #include "mxnet_op.h"
 #include "./tensor/init_op.h"
+#include "./tensor/util/tensor_util-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -460,6 +461,106 @@ inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param,
                                 mom.data(), req, &out_blob);
 }
 
+/*!
+ * \brief Storage type inference function in optimizer.
+ * \param n_rsp The number of inputs that should be of row_sparse storage type
+ *              if kFComputeEx is dispatched
+ * \param n_rsp_dns The number of inputs that should be of row_sparse or default storage type
+ *              if kFComputeEx is dispatched
+ */
+template<int n_rsp, int n_rsp_dns>
+inline bool StdOptStorageType(const nnvm::NodeAttrs& attrs,
+                              const int dev_mask,
+                              DispatchMode* dispatch_mode,
+                              std::vector<int>* in_attrs,
+                              std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_rsp + n_rsp_dns));
+  CHECK_EQ(out_attrs->size(), 1U);
+  bool dispatched = false;
+
+  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+    // dns, ... -> dns
+    dispatched = storage_type_assign(out_attrs, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
+  }
+  const std::vector<int> rsp_stypes(in_attrs->begin(), in_attrs->begin() + n_rsp);
+  const std::vector<int> rsp_dns_stypes(in_attrs->begin() + n_rsp, in_attrs->end());
+  if (!dispatched && common::ContainsOnlyStorage(rsp_stypes, kRowSparseStorage) &&
+      (common::ContainsOnlyStorage(rsp_dns_stypes, kRowSparseStorage) ||
+       common::ContainsOnlyStorage(rsp_dns_stypes, kDefaultStorage))) {
+    // rsp, ..., rsp/dns, ... -> rsp
+    dispatched = storage_type_assign(out_attrs, kRowSparseStorage,
+                                     dispatch_mode, DispatchMode::kFComputeEx);
+  }
+
+  if (!dispatched) {
+    dispatch_fallback(out_attrs, dispatch_mode);
+    LogStorageFallback(attrs, dev_mask, in_attrs, out_attrs);
+  }
+  return true;
+}
+
+template<int req>
+struct SGDMomStdDnsRspDnsKernel {
+  template<typename DType, typename IType, typename RType>
+  MSHADOW_XINLINE static void Map(int i, index_t row_length, DType* out_data,
+    DType* mom_data, const DType* weight_data, const IType* grad_idx,
+    const DType* grad_data, const RType* prefix_sum, const DType clip_gradient,
+    const DType momentum, const DType lr, const DType wd, const DType rescale_grad) {
+    const DType rate = lr * wd;
+    const bool non_zero = (i == 0) ? prefix_sum[0] > 0
+                                   : prefix_sum[i] > prefix_sum[i-1];
+
+    const index_t row_i = i * row_length;
+    const RType grad_i = (prefix_sum[i]-1) * row_length;
+    for (index_t j = 0; j < row_length; j++) {
+      const index_t data_i = row_i + j;
+      const DType grad = non_zero ? grad_data[grad_i + j]
+                                  : static_cast<DType>(0);
+      if (clip_gradient >= 0.0f) {
+        mom_data[data_i] = momentum * mom_data[data_i]
+                - rate * weight_data[data_i]
+                - lr *
+                mshadow_op::clip::Map(rescale_grad * grad,
+                                      clip_gradient);
+      } else {
+        mom_data[data_i] = momentum * mom_data[data_i]
+                  - rate * weight_data[data_i]
+                  - lr * rescale_grad * grad;
+      }
+      KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] + mom_data[data_i]);
+    }
+  }
+};
+
+template<typename xpu>
+void SGDMomStdUpdateDnsRspDnsImpl(const SGDMomParam& param,
+                                  const OpContext& ctx,
+                                  const TBlob& weight,
+                                  const NDArray& grad,
+                                  const TBlob& mom,
+                                  const OpReqType& req,
+                                  TBlob *out);
+
+template<typename xpu>
+inline void SGDMomStdUpdateRspRspDnsImpl(const SGDMomParam& param,
+                                         const OpContext& ctx,
+                                         const NDArray& weight,
+                                         const NDArray& grad,
+                                         const NDArray& mom,
+                                         const OpReqType& req,
+                                         NDArray *out) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDMomUpdate", "weights");
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  TBlob out_blob = out->data();
+  SGDMomStdUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
+                                    mom.data(), req, &out_blob);
+}
+
 template<typename xpu>
 inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
                            const OpContext &ctx,
@@ -474,12 +575,15 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
   const auto weight_stype = weight.storage_type();
   const auto mom_stype = mom.storage_type();
   const auto out_stype = outputs[0].storage_type();
-  CHECK_EQ(weight_stype, mom_stype) << "Inconsistent storage type detected between mom.stype = "
-                                    << mom_stype << " and weight.stype = " << weight_stype;
+  NDArray out = outputs[0];
   if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
      out_stype == kRowSparseStorage) {
-    NDArray out = outputs[0];
-    SGDMomUpdateRspRspRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
+    SGDMomUpdateRspRspRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
+  } else if (weight.storage_type() == kRowSparseStorage &&
+             grad.storage_type() == kRowSparseStorage &&
+             mom.storage_type() == kDefaultStorage &&
+             out_stype == kRowSparseStorage) {
+    SGDMomStdUpdateRspRspDnsImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
   } else {
     LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
   }
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index 4de94e5..dda8092 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -37,6 +37,57 @@ DMLC_REGISTER_PARAMETER(RMSPropParam);
 DMLC_REGISTER_PARAMETER(RMSPropAlexParam);
 DMLC_REGISTER_PARAMETER(FtrlParam);
 
+template<>
+void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param,
+                                       const OpContext& ctx,
+                                       const TBlob& weight,
+                                       const NDArray& grad,
+                                       const TBlob& mom,
+                                       const OpReqType& req,
+                                       TBlob *out) {
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  using namespace mshadow;
+  Stream<cpu>* s = ctx.get_stream<cpu>();
+  if (req == kNullOp) return;
+  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update";
+  CHECK_GT(weight.shape_.Size(), 0);
+  CHECK_GT(mom.shape_.Size(), 0);
+  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
+    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+        DType* weight_data = weight.dptr<DType>();
+        IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
+        DType* grad_val = grad.data().dptr<DType>();
+        DType* mom_data = mom.dptr<DType>();
+        DType* out_data = out->dptr<DType>();
+        nnvm::dim_t num_rows = weight.shape_[0];
+        auto row_length = weight.shape_.ProdShape(1, weight.ndim());
+        Tensor<cpu, 1, char> workspace = ctx.requested[0]
+          .get_space_typed<cpu, 1, char>(Shape1(num_rows * sizeof(nnvm::dim_t)), s);
+
+        nnvm::dim_t* prefix_sum = reinterpret_cast<nnvm::dim_t*>(workspace.dptr_);
+        // mark row flags
+        Kernel<set_zero, cpu>::Launch(s, num_rows, prefix_sum);
+        if (grad.storage_initialized()) {
+          Kernel<MarkRowFlgKernel, cpu>::Launch(s, grad.aux_shape(kIdx)[0],
+            prefix_sum, grad_idx);
+          // calculate inclusive prefix sum
+          for (nnvm::dim_t i = 1; i < num_rows; i++) {
+            prefix_sum[i] += prefix_sum[i - 1];
+          }
+        }
+        Kernel<SGDMomStdDnsRspDnsKernel<req_type>, cpu>::Launch(s, num_rows, row_length,
+          out_data, mom_data, weight_data, grad_idx, grad_val, prefix_sum,
+          static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
+          static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+          static_cast<DType>(param.rescale_grad));
+      });
+    });
+  });
+}
+
+
 NNVM_REGISTER_OP(sgd_update)
 MXNET_ADD_SPARSE_OP_ALIAS(sgd_update)
 .describe(R"code(Update function for Stochastic Gradient Descent (SGD) optimizer.
@@ -84,7 +135,10 @@ It updates the weights using::
 
 Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.
 
-If weight and momentum are both of ``row_sparse`` storage type,
+If weight and grad are both of ``row_sparse`` storage type and momentum is of ``default`` storage type,
+standard update is applied.
+
+If weight, grad and momentum are all of ``row_sparse`` storage type,
 only the row slices whose indices appear in grad.indices are updated (for both weight and momentum)::
 
   for row in gradient.indices:
@@ -97,11 +151,15 @@ only the row slices whose indices appear in grad.indices are updated (for both w
 .set_attr_parser(ParamParser<SGDMomParam>)
 .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
-.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<3, 1, false, true, false>)
+.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<2, 1>)
 .set_attr<nnvm::FMutateInputs>("FMutateInputs",
   [](const nnvm::NodeAttrs& attrs) {
     return std::vector<uint32_t>{2};
   })
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
 .set_attr<FCompute>("FCompute<cpu>", SGDMomUpdate<cpu>)
 .set_attr<FComputeEx>("FComputeEx<cpu>", SGDMomUpdateEx<cpu>)
 .add_argument("weight", "NDArray-or-Symbol", "Weight")
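As a rough illustration of the behaviour documented above: with row_sparse weight and grad, the default lazy_update=True keeps the momentum row_sparse and only touches the rows listed in grad.indices, while lazy_update=False creates a dense momentum and dispatches to the new standard kernel, so weight decay and momentum reach every row. The snippet below is only a hedged sketch against the public mxnet Python API; the shapes, hyper-parameter values and helper names are illustrative assumptions, not taken from the patch.

    # Illustrative sketch: lazy vs. standard sparse SGD momentum update.
    # Assumes the mxnet Python package; values below are made up for demonstration.
    import mxnet as mx

    weight = mx.nd.ones((4, 2)).tostype('row_sparse')
    grad = mx.nd.zeros((4, 2))
    grad[0] = 0.5                        # only row 0 carries a gradient
    grad = grad.tostype('row_sparse')    # grad.indices == [0]

    def run(lazy):
        opt = mx.optimizer.SGD(learning_rate=0.1, momentum=0.9, wd=0.01,
                               lazy_update=lazy)
        w = weight.copy()
        state = opt.create_state(0, w)   # row_sparse momentum if lazy, dense otherwise
        opt.update(0, w, grad, state)
        return w.asnumpy()

    print(run(lazy=True))    # rows 1-3 stay at 1.0: they are absent from grad.indices
    print(run(lazy=False))   # rows 1-3 shrink slightly: wd is applied to every row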
diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu
index 4306b32..9512e92 100644
--- a/src/operator/optimizer_op.cu
+++ b/src/operator/optimizer_op.cu
@@ -24,10 +24,76 @@
  * \author Junyuan Xie
  */
 #include "./optimizer_op-inl.h"
+#include <cub/cub.cuh>
 
 namespace mxnet {
 namespace op {
 
+template<>
+void SGDMomStdUpdateDnsRspDnsImpl<gpu>(const SGDMomParam& param,
+                                       const OpContext& ctx,
+                                       const TBlob& weight,
+                                       const NDArray& grad,
+                                       const TBlob& mom,
+                                       const OpReqType& req,
+                                       TBlob *out) {
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  using namespace mshadow;
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+  if (req == kNullOp) return;
+  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update";
+  CHECK_GT(weight.shape_.Size(), 0);
+  CHECK_GT(mom.shape_.Size(), 0);
+
+  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
+    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+        DType* weight_data = weight.dptr<DType>();
+        IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
+        DType* grad_val = grad.data().dptr<DType>();
+        DType* mom_data = mom.dptr<DType>();
+        DType* out_data = out->dptr<DType>();
+        nnvm::dim_t num_rows = weight.shape_[0];
+        nnvm::dim_t row_length = weight.shape_.ProdShape(1, weight.ndim());
+
+        nnvm::dim_t* prefix_sum = NULL;
+        void* d_temp_storage = NULL;
+        size_t temp_storage_bytes = 0;
+        cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                      temp_storage_bytes,
+                                      prefix_sum,
+                                      prefix_sum,
+                                      num_rows,
+                                      Stream<gpu>::GetStream(s));
+        Tensor<gpu, 1, char> workspace = ctx.requested[0]
+          .get_space_typed<gpu, 1, char>(Shape1(num_rows * sizeof(nnvm::dim_t) +
+                                         temp_storage_bytes), s);
+        prefix_sum = reinterpret_cast<nnvm::dim_t*>(workspace.dptr_);
+        d_temp_storage = workspace.dptr_ + num_rows*sizeof(nnvm::dim_t);
+        // mark row flags
+        Fill<false>(s, TBlob(prefix_sum, Shape1(num_rows), gpu::kDevMask), kWriteTo, 0);
+        if (grad.storage_initialized()) {
+          Kernel<MarkRowFlgKernel, gpu>::Launch(s, grad.aux_shape(kIdx)[0],
+            prefix_sum, grad_idx);
+          // calculate inclusive prefix sum
+          cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                        temp_storage_bytes,
+                                        prefix_sum,
+                                        prefix_sum,
+                                        num_rows,
+                                        mshadow::Stream<gpu>::GetStream(s));
+        }
+        Kernel<SGDMomStdDnsRspDnsKernel<req_type>, gpu>::Launch(s, num_rows, row_length,
+          out_data, mom_data, weight_data, grad_idx, grad_val, prefix_sum,
+          static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
+          static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+          static_cast<DType>(param.rescale_grad));
+      });
+    });
+  });
+}
+
 NNVM_REGISTER_OP(sgd_update)
 .set_attr<FCompute>("FCompute<gpu>", SGDUpdate<gpu>)
 .set_attr<FComputeEx>("FComputeEx<gpu>", SGDUpdateEx<gpu>);
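Both the CPU and GPU implementations above rely on the same bookkeeping: mark which dense rows occur in grad.indices, then take an inclusive prefix sum so that row i is "non-zero" when prefix_sum[i] > prefix_sum[i-1] and, in that case, reads its gradient from row prefix_sum[i] - 1 of the compressed gradient. The NumPy sketch below is illustrative only; the helper name and values are assumptions, not code from the patch.

    # Sketch of the row-flag + inclusive prefix sum mapping used by the standard kernel.
    import numpy as np

    def row_mapping(num_rows, grad_idx):
        flags = np.zeros(num_rows, dtype=np.int64)
        flags[grad_idx] = 1              # MarkRowFlgKernel: flag rows present in grad
        prefix_sum = np.cumsum(flags)    # inclusive scan (cub::DeviceScan on the GPU)
        for i in range(num_rows):
            non_zero = prefix_sum[i] > (prefix_sum[i - 1] if i > 0 else 0)
            src = int(prefix_sum[i] - 1) # row inside grad.data() when non_zero is True
            yield i, bool(non_zero), (src if non_zero else None)

    print(list(row_mapping(5, np.array([1, 3]))))
    # dense rows 1 and 3 map to compressed rows 0 and 1; all other rows get a zero gradient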
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index 655e157..ae248b0 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -334,6 +334,29 @@ def test_sparse_sgd():
                           w_stype='row_sparse', g_stype='row_sparse')
 
 
+def test_std_sparse_sgd():
+    mx.random.seed(0)
+    opt1 = PySGD
+    opt2 = mx.optimizer.SGD
+    shape = (3, 4, 5)
+    mom_options = [{'momentum': 0.9}]
+    cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
+    rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
+    wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
+    for dtype in [np.float32]:
+        for mom_option in mom_options:
+            for cg_option in cg_options:
+                for rg_option in rg_options:
+                    for wd_option in wd_options:
+                        kwarg = {}
+                        kwarg.update(mom_option)
+                        kwarg.update(cg_option)
+                        kwarg.update(rg_option)
+                        kwarg.update(wd_option)
+                        compare_optimizer(opt1(**kwarg), opt2(lazy_update=False, **kwarg), shape, dtype,
+                                          w_stype='row_sparse', g_stype='row_sparse')
+
+
 # FTML
 
 class PyFTML(mx.optimizer.Optimizer):
@@ -400,7 +423,6 @@ def test_ftml():
                     compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
 
 
-
 # ADAM
 
 class PyAdam(mx.optimizer.Optimizer):

-- 
To stop receiving notification emails like this one, please contact
"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>.