This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
     new 9ccd647  standard sgd_update (#10614)
9ccd647 is described below

commit 9ccd64787a05c6c04466ecfb0763b70ee8fbc988
Author: Ziyue Huang <zyhuan...@gmail.com>
AuthorDate: Fri Apr 20 08:03:02 2018 +0800

    standard sgd_update (#10614)
---
 python/mxnet/optimizer.py               |  2 +-
 src/operator/optimizer_op-inl.h         | 15 +++++++++++++--
 tests/python/unittest/test_optimizer.py |  2 +-
 3 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 6589e77..2f7c51b 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -516,7 +516,7 @@ class SGD(Optimizer):
                 sgd_mom_update(weight, grad, state, out=weight,
                                lr=lr, wd=wd, **kwargs)
             else:
-                sgd_update(weight, grad, out=weight,
+                sgd_update(weight, grad, out=weight, lazy_update=self.lazy_update,
                            lr=lr, wd=wd, **kwargs)
         else:
             if state[0] is not None:
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index 3b6bd57..dfc7bef 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -47,6 +47,7 @@ struct SGDParam : public dmlc::Parameter<SGDParam> {
   float wd;
   float rescale_grad;
   float clip_gradient;
+  bool lazy_update;
   DMLC_DECLARE_PARAMETER(SGDParam) {
     DMLC_DECLARE_FIELD(lr)
     .describe("Learning rate");
@@ -63,6 +64,9 @@ struct SGDParam : public dmlc::Parameter<SGDParam> {
     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
               "If clip_gradient <= 0, gradient clipping is turned off. "
               "grad = max(min(grad, clip_gradient), -clip_gradient).");
+    DMLC_DECLARE_FIELD(lazy_update)
+    .set_default(true)
+    .describe("If true, lazy updates are applied.");
   }
 };
@@ -177,7 +181,7 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param,
   Stream<xpu>* s = ctx.get_stream<xpu>();
   CHECK_EQ(grad.storage_type(), kRowSparseStorage);
   // if gradients are zeros, no weights are updated
-  if (!grad.storage_initialized() || req == kNullOp) return;
+  if (req == kNullOp) return;
   CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update";
   CHECK_GT(weight.shape_.Size(), 0);
@@ -185,6 +189,13 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param,
     MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, {
       MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
         DType* weight_data = weight.dptr<DType>();
+        float wd = param.wd;
+        if (!param.lazy_update) {
+          Kernel<op_with_req<mshadow_op::mul, req_type>, xpu>::Launch(s, weight.Size(),
+            weight_data, weight_data, static_cast<DType>(1 - param.lr * param.wd));
+          wd = 0;
+        }
+        if (!grad.storage_initialized()) return;
         const IType* grad_idx = grad.aux_data(rowsparse::kIdx).dptr<IType>();
         const DType* grad_val = grad.data().dptr<DType>();
         const nnvm::dim_t num_rows = grad.aux_shape(rowsparse::kIdx)[0];
@@ -196,7 +207,7 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param,
         Kernel<SGDDnsRspKernel<req_type, xpu>, xpu>::Launch(s, num_threads, row_length,
           out->dptr<DType>(), weight_data, grad_idx, grad_val,
           static_cast<DType>(param.clip_gradient),
-          static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+          static_cast<DType>(param.lr), static_cast<DType>(wd),
           static_cast<DType>(param.rescale_grad));
       });
     });
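
Editor's note on the kernel change above: when lazy_update is false, the patch first scales every weight element by (1 - lr * wd) and then runs the existing row-sparse kernel with wd forced to 0, which reproduces a dense standard SGD update (rows without a gradient still get weight decay). A minimal NumPy sketch of that equivalence, illustrative only and not MXNet code (all array names and values below are made up; gradient clipping is omitted):

    import numpy as np

    lr, wd, rescale_grad = 0.1, 0.03, 1.0
    weight = np.random.randn(5, 4)
    grad_rows = np.array([1, 3])           # rows present in the row-sparse gradient
    grad_vals = np.random.randn(2, 4)      # their gradient values

    # Reference: dense standard SGD, treating missing rows as zero gradient.
    dense_grad = np.zeros_like(weight)
    dense_grad[grad_rows] = grad_vals
    expected = weight - lr * (rescale_grad * dense_grad + wd * weight)

    # The patch's decomposition: decay every row first, then apply the
    # sparse update to the stored rows with wd set to 0.
    updated = weight * (1 - lr * wd)
    updated[grad_rows] -= lr * rescale_grad * grad_vals
    assert np.allclose(updated, expected)
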
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index bbd7845..d1dc31a 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -344,7 +344,7 @@ def test_std_sparse_sgd():
     opt1 = PySGD
     opt2 = mx.optimizer.SGD
     shape = (3, 4, 5)
-    mom_options = [{'momentum': 0.9}]
+    mom_options = [{'momentum': 0.0}, {'momentum': 0.9}]
     cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
     rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
     wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]

--
To stop receiving notification emails like this one, please contact hai...@apache.org.
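
Editor's note, usage sketch: the Python-side lazy_update flag that this patch forwards to sgd_update can be exercised roughly as below. This assumes an MXNet build that includes this change (lazy_update on mx.optimizer.SGD); the shapes and values are illustrative only:

    import mxnet as mx

    # Dense weight, row-sparse gradient in which only row 0 is non-zero.
    weight = mx.nd.ones((3, 4))
    grad = mx.nd.zeros((3, 4))
    grad[0] = 0.5
    grad = grad.tostype('row_sparse')

    # lazy_update=False requests the standard update: weight decay is applied
    # to every row, not only to rows present in the sparse gradient.
    opt = mx.optimizer.SGD(learning_rate=0.1, wd=0.03, lazy_update=False)
    state = opt.create_state(0, weight)    # None when momentum == 0
    opt.update(0, weight, grad, state)
    print(weight.asnumpy())                # rows 1 and 2 are decayed as well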