This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new df9f79a  standard update for sparse sgd_mom_update (#9189)
df9f79a is described below

commit df9f79ae5e265e28ceecab3c58828f3a84769eb4
Author: Ziyue Huang <zyhuan...@gmail.com>
AuthorDate: Fri Jan 5 13:36:15 2018 +0800

    standard update for sparse sgd_mom_update (#9189)
    
    * standard sparse sgd mom update
    
    * update
    
    * update comments
    
    * address comments
    
    * revise
    
    * more general infer stype
    
    * fix
    
    * fix
    
    * add comments for stype inference func
    
    * update
---
 python/mxnet/optimizer.py               |  25 ++++---
 src/operator/optimizer_op-inl.h         | 112 ++++++++++++++++++++++++++++++--
 src/operator/optimizer_op.cc            |  62 +++++++++++++++++-
 src/operator/optimizer_op.cu            |  66 +++++++++++++++++++
 tests/python/unittest/test_optimizer.py |  24 ++++++-
 5 files changed, 272 insertions(+), 17 deletions(-)

diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 59898c9..feff87e 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -433,14 +433,8 @@ register = Optimizer.register   # pylint: 
disable=invalid-name
 class SGD(Optimizer):
     """The SGD optimizer with momentum and weight decay.
 
-    The optimizer updates the weight by::
-
-        rescaled_grad = lr * rescale_grad * clip(grad, clip_gradient) + wd * 
weight
-        state = momentum * state + rescaled_grad
-        weight = weight - state
-
-    If the storage types of weight, state and grad are all ``row_sparse``, \
-    **sparse updates** are applied by::
+    If the storage types of weight and grad are both ``row_sparse``, and 
``lazy_update`` is True, \
+    **lazy updates** are applied by::
 
         for row in grad.indices:
             rescaled_grad[row] = lr * rescale_grad * clip(grad[row], 
clip_gradient) + wd * weight[row]
@@ -454,6 +448,12 @@ class SGD(Optimizer):
     provides slightly different semantics than the original update, and
     may lead to different empirical results.
 
+    Otherwise, **standard updates** are applied by::
+
+        rescaled_grad = lr * rescale_grad * clip(grad, clip_gradient) + wd * 
weight
+        state = momentum * state + rescaled_grad
+        weight = weight - state
+
     For details of the update algorithm see
     :class:`~mxnet.ndarray.sgd_update` and 
:class:`~mxnet.ndarray.sgd_mom_update`.
 
@@ -464,6 +464,9 @@ class SGD(Optimizer):
     ----------
     momentum : float, optional
        The momentum value.
+    lazy_update : bool, optional
+       Default is True. If True, lazy updates are applied \
+       if the storage types of weight and grad are both ``row_sparse``.
     multi_precision: bool, optional
        Flag to control the internal precision of the optimizer.
        ``False`` results in using the same precision as the weights (default),
@@ -471,9 +474,10 @@ class SGD(Optimizer):
                 in 32-bit precision even if actual weights used in the model 
have lower precision.\
                 Turning this on can improve convergence and accuracy when 
training with float16.
     """
-    def __init__(self, momentum=0.0, **kwargs):
+    def __init__(self, momentum=0.0, lazy_update=True, **kwargs):
         super(SGD, self).__init__(**kwargs)
         self.momentum = momentum
+        self.lazy_update = lazy_update
 
     def create_state_multi_precision(self, index, weight):
         weight_master_copy = None
@@ -489,8 +493,9 @@ class SGD(Optimizer):
 
     def create_state(self, index, weight):
         momentum = None
+        stype = weight.stype if self.lazy_update else 'default'
         if self.momentum != 0.0:
-            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, 
stype=weight.stype)
+            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, 
stype=stype)
         return momentum
 
     def _update_impl(self, index, weight, grad, state, multi_precision=False):
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index a6b32b1..33b7dd5 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -38,6 +38,7 @@
 #include "./elemwise_op_common.h"
 #include "mxnet_op.h"
 #include "./tensor/init_op.h"
+#include "./tensor/util/tensor_util-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -460,6 +461,106 @@ inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& 
param,
                                  mom.data(), req, &out_blob);
 }
 
+/*!
+ * \brief Storage type inference function for optimizer operators.
+ *
+ * The inputs are split into two groups: the first n_rsp inputs, and the
+ * remaining n_rsp_dns inputs. Rules are tried in order:
+ *   1. every input has default storage
+ *        -> default output, DispatchMode::kFCompute
+ *   2. first group all row_sparse, and second group all row_sparse or
+ *      all default
+ *        -> row_sparse output, DispatchMode::kFComputeEx
+ *   3. anything not dispatched above
+ *        -> dispatch_fallback, with the fallback logged
+ * \param n_rsp     The number of leading inputs that must be of row_sparse
+ *                  storage type for kFComputeEx to be dispatched
+ * \param n_rsp_dns The number of trailing inputs that must be either all
+ *                  row_sparse or all default storage type for kFComputeEx
+ *                  to be dispatched
+ * \return always true
+ */
+template<int n_rsp, int n_rsp_dns>
+inline bool StdOptStorageType(const nnvm::NodeAttrs& attrs,
+                              const int dev_mask,
+                              DispatchMode* dispatch_mode,
+                              std::vector<int>* in_attrs,
+                              std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_rsp + n_rsp_dns));
+  CHECK_EQ(out_attrs->size(), 1U);
+  bool dispatched = false;
+
+  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+    // dns, ... -> dns
+    dispatched = storage_type_assign(out_attrs, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
+  }
+  // Split the input storage types into the leading n_rsp group and the
+  // trailing n_rsp_dns group.
+  const std::vector<int> rsp_stypes(in_attrs->begin(), in_attrs->begin() + 
n_rsp);
+  const std::vector<int> rsp_dns_stypes(in_attrs->begin() + n_rsp, 
in_attrs->end());
+  if (!dispatched && common::ContainsOnlyStorage(rsp_stypes, 
kRowSparseStorage) &&
+      (common::ContainsOnlyStorage(rsp_dns_stypes, kRowSparseStorage) ||
+       common::ContainsOnlyStorage(rsp_dns_stypes, kDefaultStorage))) {
+    // rsp, ..., rsp/dns, ... -> rsp
+    dispatched = storage_type_assign(out_attrs, kRowSparseStorage,
+                                     dispatch_mode, DispatchMode::kFComputeEx);
+  }
+
+  if (!dispatched) {
+    // No sparse kernel applies: fall back and report it.
+    dispatch_fallback(out_attrs, dispatch_mode);
+    LogStorageFallback(attrs, dev_mask, in_attrs, out_attrs);
+  }
+  return true;
+}
+
+/*!
+ * \brief Kernel for the standard (non-lazy) sgd_mom_update with dense
+ *        weight, row_sparse grad and dense momentum.
+ *
+ * Invocation i handles row i of weight, mom and out (row_length elements
+ * each). prefix_sum is read as an inclusive prefix sum of per-row flags.
+ * Row i is treated as present in grad when the sum increases at i, and
+ * prefix_sum[i] - 1 is then used as row i's position in grad_data.
+ * Rows absent from grad get a zero gradient, so momentum decay and weight
+ * decay are still applied to them.
+ * Per element:
+ *   mom = momentum * mom - lr * wd * weight - lr * clip(rescale_grad * grad)
+ *   out = weight + mom
+ * A negative clip_gradient disables clipping.
+ * NOTE(review): grad_idx is accepted but not read by this kernel.
+ * \tparam req  Request type used by KERNEL_ASSIGN when writing out_data.
+ */
+template<int req>
+struct SGDMomStdDnsRspDnsKernel {
+  template<typename DType, typename IType, typename RType>
+  MSHADOW_XINLINE static void Map(int i, index_t row_length, DType* out_data,
+    DType* mom_data, const DType* weight_data, const IType* grad_idx,
+    const DType* grad_data, const RType* prefix_sum, const DType clip_gradient,
+    const DType momentum, const DType lr, const DType wd, const DType 
rescale_grad) {
+    const DType rate = lr * wd;
+    // Row i is present in grad iff the inclusive prefix sum increases at i.
+    const bool non_zero = (i == 0) ? prefix_sum[0] > 0
+                                   : prefix_sum[i] > prefix_sum[i-1];
+
+    const index_t row_i = i * row_length;
+    // Offset of row i inside grad_data; only used when non_zero is true.
+    const RType grad_i = (prefix_sum[i]-1) * row_length;
+    for (index_t j = 0; j < row_length; j++) {
+      const index_t data_i = row_i + j;
+      // Rows missing from grad contribute a zero gradient (standard update).
+      const DType grad = non_zero ? grad_data[grad_i + j]
+                                  : static_cast<DType>(0);
+      if (clip_gradient >= 0.0f) {
+        mom_data[data_i] = momentum * mom_data[data_i]
+                - rate * weight_data[data_i]
+                - lr *
+                mshadow_op::clip::Map(rescale_grad * grad,
+                                      clip_gradient);
+      } else {
+        // Negative clip_gradient: no clipping.
+        mom_data[data_i] = momentum * mom_data[data_i]
+                  - rate * weight_data[data_i]
+                  - lr * rescale_grad * grad;
+      }
+      // The updated weight is the old weight plus the new momentum.
+      KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] + 
mom_data[data_i]);
+    }
+  }
+};
+
+/*!
+ * \brief Standard (non-lazy) sgd_mom_update for dense weight, row_sparse
+ *        grad and dense momentum, writing the updated weight to out.
+ *
+ * Only declared here. It is specialized per device: cpu in optimizer_op.cc
+ * and gpu in optimizer_op.cu. mom is updated in place.
+ */
+template<typename xpu>
+void SGDMomStdUpdateDnsRspDnsImpl(const SGDMomParam& param,
+                                  const OpContext& ctx,
+                                  const TBlob& weight,
+                                  const NDArray& grad,
+                                  const TBlob& mom,
+                                  const OpReqType& req,
+                                  TBlob *out);
+
+/*!
+ * \brief Standard (non-lazy) sgd_mom_update for row_sparse weight,
+ *        row_sparse grad and dense momentum.
+ *
+ * Requires every row of weight to be stored (checked below), so that
+ * weight.data() is a complete dense block. The work is then delegated to
+ * SGDMomStdUpdateDnsRspDnsImpl, using out's data blob as the output.
+ */
+template<typename xpu>
+inline void SGDMomStdUpdateRspRspDnsImpl(const SGDMomParam& param,
+                                         const OpContext& ctx,
+                                         const NDArray& weight,
+                                         const NDArray& grad,
+                                         const NDArray& mom,
+                                         const OpReqType& req,
+                                         NDArray *out) {
+  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDMomUpdate", "weights");
+  TBlob out_blob = out->data();
+  SGDMomStdUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
+                                    mom.data(), req, &out_blob);
+}
+
 template<typename xpu>
 inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
                            const OpContext &ctx,
@@ -474,12 +575,15 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
   const auto weight_stype = weight.storage_type();
   const auto mom_stype = mom.storage_type();
   const auto out_stype = outputs[0].storage_type();
-  CHECK_EQ(weight_stype, mom_stype) << "Inconsistent storage type detected 
between mom.stype = "
-           << mom_stype << " and weight.stype = " << weight_stype;
+  NDArray out = outputs[0];
   if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
       out_stype == kRowSparseStorage) {
-     NDArray out = outputs[0];
-     SGDMomUpdateRspRspRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], 
&out);
+    SGDMomUpdateRspRspRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], 
&out);
+  } else if (weight.storage_type() == kRowSparseStorage &&
+             grad.storage_type() == kRowSparseStorage &&
+             mom.storage_type() == kDefaultStorage &&
+             out_stype == kRowSparseStorage) {
+    SGDMomStdUpdateRspRspDnsImpl<xpu>(param, ctx, weight, grad, mom, req[0], 
&out);
   } else {
     LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, 
req, outputs);
   }
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index 4de94e5..dda8092 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -37,6 +37,57 @@ DMLC_REGISTER_PARAMETER(RMSPropParam);
 DMLC_REGISTER_PARAMETER(RMSPropAlexParam);
 DMLC_REGISTER_PARAMETER(FtrlParam);
 
+/*!
+ * \brief CPU specialization of the standard (non-lazy) sgd_mom_update for
+ *        dense weight, row_sparse grad and dense momentum.
+ *
+ * Every row of weight and mom is updated; rows absent from grad use a zero
+ * gradient. Only kNullOp (returns immediately) and kWriteInplace requests
+ * are accepted. ctx.requested[0] provides a temporary workspace of one
+ * nnvm::dim_t per weight row, so the operator must request kTempSpace.
+ */
+template<>
+void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param,
+                                       const OpContext& ctx,
+                                       const TBlob& weight,
+                                       const NDArray& grad,
+                                       const TBlob& mom,
+                                       const OpReqType& req,
+                                       TBlob *out) {
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  using namespace mshadow;
+  Stream<cpu>* s = ctx.get_stream<cpu>();
+  if (req == kNullOp) return;
+  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse 
sgd_mom_update";
+  CHECK_GT(weight.shape_.Size(), 0);
+  CHECK_GT(mom.shape_.Size(), 0);
+  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
+    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+        DType* weight_data = weight.dptr<DType>();
+        IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
+        DType* grad_val = grad.data().dptr<DType>();
+        DType* mom_data = mom.dptr<DType>();
+        DType* out_data = out->dptr<DType>();
+        nnvm::dim_t num_rows = weight.shape_[0];
+        // Number of elements per row (product of all trailing dimensions).
+        auto row_length = weight.shape_.ProdShape(1, weight.ndim());
+        // Temporary workspace: one nnvm::dim_t per weight row.
+        Tensor<cpu, 1, char> workspace = ctx.requested[0]
+          .get_space_typed<cpu, 1, char>(Shape1(num_rows * 
sizeof(nnvm::dim_t)), s);
+
+        nnvm::dim_t* prefix_sum = 
reinterpret_cast<nnvm::dim_t*>(workspace.dptr_);
+        // Zero the per-row flags, then flag rows listed in grad's indices.
+        Kernel<set_zero, cpu>::Launch(s, num_rows, prefix_sum);
+        if (grad.storage_initialized()) {
+          // MarkRowFlgKernel is defined elsewhere; presumably it sets the
+          // flag of each row listed in grad_idx -- confirm.
+          Kernel<MarkRowFlgKernel, cpu>::Launch(s, grad.aux_shape(kIdx)[0],
+            prefix_sum, grad_idx);
+          // Turn the flags into an inclusive prefix sum in place. The kernel
+          // uses prefix_sum[i] - 1 as row i's position in grad's values,
+          // which assumes grad's indices are sorted and unique -- confirm.
+          for (nnvm::dim_t i = 1; i < num_rows; i++) {
+            prefix_sum[i] += prefix_sum[i - 1];
+          }
+        }
+        // One kernel invocation per weight row; mom is updated in place.
+        Kernel<SGDMomStdDnsRspDnsKernel<req_type>, cpu>::Launch(s, num_rows, 
row_length,
+          out_data, mom_data, weight_data, grad_idx, grad_val, prefix_sum,
+          static_cast<DType>(param.clip_gradient), 
static_cast<DType>(param.momentum),
+          static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+          static_cast<DType>(param.rescale_grad));
+      });
+    });
+  });
+}
+
+
 NNVM_REGISTER_OP(sgd_update)
 MXNET_ADD_SPARSE_OP_ALIAS(sgd_update)
 .describe(R"code(Update function for Stochastic Gradient Descent (SDG) 
optimizer.
@@ -84,7 +135,10 @@ It updates the weights using::
 
 Where the parameter ``momentum`` is the decay rate of momentum estimates at 
each epoch.
 
-If weight and momentum are both of ``row_sparse`` storage type,
+If weight and grad are both of ``row_sparse`` storage type and momentum is of 
``default`` storage type,
+standard update is applied.
+
+If weight, grad and momentum are all of ``row_sparse`` storage type,
 only the row slices whose indices appear in grad.indices are updated (for both 
weight and momentum)::
 
   for row in gradient.indices:
@@ -97,11 +151,15 @@ only the row slices whose indices appear in grad.indices 
are updated (for both w
 .set_attr_parser(ParamParser<SGDMomParam>)
 .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
-.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<3, 1, 
false, true, false>)
+.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<2, 1>)
 .set_attr<nnvm::FMutateInputs>("FMutateInputs",
   [](const nnvm::NodeAttrs& attrs) {
     return std::vector<uint32_t>{2};
   })
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
 .set_attr<FCompute>("FCompute<cpu>", SGDMomUpdate<cpu>)
 .set_attr<FComputeEx>("FComputeEx<cpu>", SGDMomUpdateEx<cpu>)
 .add_argument("weight", "NDArray-or-Symbol", "Weight")
diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu
index 4306b32..9512e92 100644
--- a/src/operator/optimizer_op.cu
+++ b/src/operator/optimizer_op.cu
@@ -24,10 +24,76 @@
  * \author Junyuan Xie
  */
 #include "./optimizer_op-inl.h"
+#include <cub/cub.cuh>
 
 namespace mxnet {
 namespace op {
 
+/*!
+ * \brief GPU specialization of the standard (non-lazy) sgd_mom_update for
+ *        dense weight, row_sparse grad and dense momentum.
+ *
+ * Every row of weight and mom is updated; rows absent from grad use a zero
+ * gradient. Only kNullOp (returns immediately) and kWriteInplace requests
+ * are accepted. ctx.requested[0] provides one workspace holding the per-row
+ * prefix-sum array followed by cub's scan temporary storage.
+ */
+template<>
+void SGDMomStdUpdateDnsRspDnsImpl<gpu>(const SGDMomParam& param,
+                                       const OpContext& ctx,
+                                       const TBlob& weight,
+                                       const NDArray& grad,
+                                       const TBlob& mom,
+                                       const OpReqType& req,
+                                       TBlob *out) {
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  using namespace mshadow;
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+  if (req == kNullOp) return;
+  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse 
sgd_mom_update";
+  CHECK_GT(weight.shape_.Size(), 0);
+  CHECK_GT(mom.shape_.Size(), 0);
+
+  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
+    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+        DType* weight_data = weight.dptr<DType>();
+        IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
+        DType* grad_val = grad.data().dptr<DType>();
+        DType* mom_data = mom.dptr<DType>();
+        DType* out_data = out->dptr<DType>();
+        nnvm::dim_t num_rows = weight.shape_[0];
+        nnvm::dim_t row_length = weight.shape_.ProdShape(1, weight.ndim());
+
+        // First cub call with a NULL temp-storage pointer only computes the
+        // required temp_storage_bytes; no scan is performed.
+        nnvm::dim_t* prefix_sum = NULL;
+        void* d_temp_storage = NULL;
+        size_t temp_storage_bytes = 0;
+        cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                      temp_storage_bytes,
+                                      prefix_sum,
+                                      prefix_sum,
+                                      num_rows,
+                                      Stream<gpu>::GetStream(s));
+        // Workspace layout: [num_rows x nnvm::dim_t prefix sums][cub temp].
+        Tensor<gpu, 1, char> workspace = ctx.requested[0]
+          .get_space_typed<gpu, 1, char>(Shape1(num_rows * sizeof(nnvm::dim_t) 
++
+                                         temp_storage_bytes), s);
+        prefix_sum = reinterpret_cast<nnvm::dim_t*>(workspace.dptr_);
+        d_temp_storage = workspace.dptr_ + num_rows*sizeof(nnvm::dim_t);
+        // Zero the per-row flags, then flag rows listed in grad's indices.
+        Fill<false>(s, TBlob(prefix_sum, Shape1(num_rows), gpu::kDevMask), 
kWriteTo, 0);
+        if (grad.storage_initialized()) {
+          // MarkRowFlgKernel is defined elsewhere; presumably it sets the
+          // flag of each row listed in grad_idx -- confirm.
+          Kernel<MarkRowFlgKernel, gpu>::Launch(s, grad.aux_shape(kIdx)[0],
+            prefix_sum, grad_idx);
+          // In-place inclusive prefix sum of the flags on the same stream.
+          cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                        temp_storage_bytes,
+                                        prefix_sum,
+                                        prefix_sum,
+                                        num_rows,
+                                        mshadow::Stream<gpu>::GetStream(s));
+        }
+        // One kernel invocation per weight row; mom is updated in place.
+        Kernel<SGDMomStdDnsRspDnsKernel<req_type>, gpu>::Launch(s, num_rows, 
row_length,
+          out_data, mom_data, weight_data, grad_idx, grad_val, prefix_sum,
+          static_cast<DType>(param.clip_gradient), 
static_cast<DType>(param.momentum),
+          static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+          static_cast<DType>(param.rescale_grad));
+      });
+    });
+  });
+}
+
 NNVM_REGISTER_OP(sgd_update)
 .set_attr<FCompute>("FCompute<gpu>", SGDUpdate<gpu>)
 .set_attr<FComputeEx>("FComputeEx<gpu>", SGDUpdateEx<gpu>);
diff --git a/tests/python/unittest/test_optimizer.py 
b/tests/python/unittest/test_optimizer.py
index 655e157..ae248b0 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -334,6 +334,29 @@ def test_sparse_sgd():
                                               w_stype='row_sparse', 
g_stype='row_sparse')
 
 
+def test_std_sparse_sgd():
+    mx.random.seed(0)
+    opt1 = PySGD
+    opt2 = mx.optimizer.SGD
+    shape = (3, 4, 5)
+    mom_options = [{'momentum': 0.9}]
+    cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
+    rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
+    wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
+    for dtype in [np.float32]:
+        for mom_option in mom_options:
+            for cg_option in cg_options:
+                for rg_option in rg_options:
+                    for wd_option in wd_options:
+                        kwarg = {}
+                        kwarg.update(mom_option)
+                        kwarg.update(cg_option)
+                        kwarg.update(rg_option)
+                        kwarg.update(wd_option)
+                        compare_optimizer(opt1(**kwarg), 
opt2(lazy_update=False, **kwarg), shape, dtype,
+                                          w_stype='row_sparse', 
g_stype='row_sparse')
+
+
 # FTML
 
 class PyFTML(mx.optimizer.Optimizer):
@@ -400,7 +423,6 @@ def test_ftml():
                             compare_optimizer(opt1(**kwarg), opt2(**kwarg), 
shape, dtype)
 
 
-
 # ADAM
 
 class PyAdam(mx.optimizer.Optimizer):

-- 
To stop receiving notification emails like this one, please contact
['"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>'].

Reply via email to