eric-haibin-lin closed pull request #9468: standard adam update for sparse tensor
URL: https://github.com/apache/incubator-mxnet/pull/9468
 
 
   

This pull request was merged from a forked repository. Because GitHub
does not show the original diff of a fork's pull request once it has
been merged, the diff is reproduced below for the sake of provenance:

diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 4285aecef1..57340be280 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -780,15 +780,8 @@ class Adam(Optimizer):
     This class implements the optimizer described in *Adam: A Method for
     Stochastic Optimization*, available at http://arxiv.org/abs/1412.6980.
 
-    The optimizer updates the weight by::
-
-        rescaled_grad = clip(grad * rescale_grad + wd * weight, clip_gradient)
-        m = beta1 * m + (1 - beta1) * rescaled_grad
-        v = beta2 * v + (1 - beta2) * (rescaled_grad**2)
-        w = w - learning_rate * m / (sqrt(v) + epsilon)
-
-    If the storage types of weight, state and grad are all ``row_sparse``, \
-    **sparse updates** are applied by::
+    If the storage types of weight and grad are both ``row_sparse``, and 
``lazy_update`` is True, \
+    **lazy updates** are applied by::
 
         for row in grad.indices:
             rescaled_grad[row] = clip(grad[row] * rescale_grad + wd * 
weight[row], clip_gradient)
@@ -796,12 +789,19 @@ class Adam(Optimizer):
             v[row] = beta2 * v[row] + (1 - beta2) * (rescaled_grad[row]**2)
             w[row] = w[row] - learning_rate * m[row] / (sqrt(v[row]) + epsilon)
 
-    The sparse update only updates the mean and var for the weights whose 
row_sparse
+    The lazy update only updates the mean and var for the weights whose 
row_sparse
     gradient indices appear in the current batch, rather than updating it for 
all indices.
     Compared with the original update, it can provide large improvements in 
model training
     throughput for some applications. However, it provides slightly different 
semantics than
     the original update, and may lead to different empirical results.
 
+    Otherwise, **standard updates** are applied by::
+
+        rescaled_grad = clip(grad * rescale_grad + wd * weight, clip_gradient)
+        m = beta1 * m + (1 - beta1) * rescaled_grad
+        v = beta2 * v + (1 - beta2) * (rescaled_grad**2)
+        w = w - learning_rate * m / (sqrt(v) + epsilon)
+
     This optimizer accepts the following parameters in addition to those 
accepted
     by :class:`.Optimizer`.
 
@@ -815,19 +815,24 @@ class Adam(Optimizer):
         Exponential decay rate for the second moment estimates.
     epsilon : float, optional
         Small value to avoid division by 0.
+    lazy_update : bool, optional
+       Default is True. If True, lazy updates are applied \
+       if the storage types of weight and grad are both ``row_sparse``.
     """
     def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, 
epsilon=1e-8,
-                 **kwargs):
+                 lazy_update=True, **kwargs):
         super(Adam, self).__init__(learning_rate=learning_rate, **kwargs)
         self.beta1 = beta1
         self.beta2 = beta2
         self.epsilon = epsilon
+        self.lazy_update = lazy_update
 
     def create_state(self, index, weight):
+        stype = weight.stype if self.lazy_update else 'default'
         return (zeros(weight.shape, weight.context, dtype=weight.dtype,
-                      stype=weight.stype),  # mean
+                      stype=stype),  # mean
                 zeros(weight.shape, weight.context, dtype=weight.dtype,
-                      stype=weight.stype))  # variance
+                      stype=stype))  # variance
 
     def update(self, index, weight, grad, state):
         assert(isinstance(weight, NDArray))
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index 42721a9146..0e0dd7f152 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -859,6 +859,71 @@ inline void AdamUpdateRspRspRspImpl(const AdamParam& param,
                                var.data(), req, &out_blob);
 }
 
+template<int req>
+struct AdamStdDnsRspDnsKernel {
+  template<typename DType, typename IType, typename RType>
+  MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* 
out_data,
+    DType* mean_data, DType* var_data, const DType* weight_data, const IType* 
grad_idx,
+    const DType* grad_data, const RType* prefix_sum, const DType clip_gradient,
+    const DType beta1, const DType beta2, const DType lr, const DType wd,
+    const DType epsilon, const DType rescale_grad) {
+    using namespace mshadow_op;
+    const bool non_zero = (i == 0) ? prefix_sum[0] > 0
+                                   : prefix_sum[i] > prefix_sum[i-1];
+
+    const index_t row_i = i * row_length;
+    const RType grad_i = (prefix_sum[i]-1) * row_length;
+    for (index_t j = 0; j < row_length; j++) {
+      const index_t data_i = row_i + j;
+      const DType grad_rescaled = non_zero ? static_cast<DType>(
+                                               grad_data[grad_i + j] * 
rescale_grad +
+                                               weight_data[data_i] * wd)
+                                           : 
static_cast<DType>(weight_data[data_i] * wd);
+      if (clip_gradient >= 0.0f) {
+        mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) *
+                            clip::Map(grad_rescaled, clip_gradient);
+        var_data[data_i] =  beta2 * var_data[data_i] + (1.f - beta2) * 
square::Map(
+                            clip::Map(grad_rescaled, clip_gradient));
+      } else {
+        mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * 
grad_rescaled;
+        var_data[data_i] = beta2 * var_data[data_i] +
+                           (1.f - beta2) * square::Map(grad_rescaled);
+      }
+      KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] - lr * 
mean_data[data_i] /
+                    (square_root::Map(var_data[data_i]) + epsilon));
+    }
+  }
+};
+
+
+template<typename xpu>
+void AdamStdUpdateDnsRspDnsImpl(const AdamParam& param,
+                                const OpContext& ctx,
+                                const TBlob& weight,
+                                const NDArray& grad,
+                                const TBlob& mean,
+                                const TBlob& var,
+                                const OpReqType& req,
+                                TBlob *out);
+
+template<typename xpu>
+inline void AdamStdUpdateRspRspRspImpl(const AdamParam& param,
+                                       const OpContext& ctx,
+                                       const NDArray& weight,
+                                       const NDArray& grad,
+                                       const NDArray& mean,
+                                       const NDArray& var,
+                                       const OpReqType& req,
+                                       NDArray *out) {
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "AdamStdUpdate", "weights");
+  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
+  TBlob out_blob = out->data();
+  // reuse dns rsp implementation when storage_shape == shape
+  AdamStdUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(),
+                                  var.data(), req, &out_blob);
+}
 
 template<typename xpu>
 inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
@@ -868,18 +933,20 @@ inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
                          const std::vector<NDArray> &outputs) {
   const AdamParam& param = nnvm::get<AdamParam>(attrs.parsed);
   const auto weight_stype = inputs[0].storage_type();
+  const auto grad_stype = inputs[1].storage_type();
   const auto mean_stype = inputs[2].storage_type();
   const auto var_stype = inputs[3].storage_type();
   const auto out_stype = outputs[0].storage_type();
-  CHECK_EQ(mean_stype, weight_stype) << "Inconsistent storage type detected 
between "
-           << " mean.stype = " << mean_stype << " and weight.stype = " << 
weight_stype;
-  CHECK_EQ(var_stype, weight_stype) << "Inconsistent storage type detected 
between "
-           << " var.stype = " << var_stype << " and weight.stype = " << 
weight_stype;
+  NDArray out = outputs[0];
   if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
       out_stype == kRowSparseStorage) {
-     NDArray out = outputs[0];
      AdamUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
                                   inputs[3], req[0], &out);
+  } else if (weight_stype == kRowSparseStorage && grad_stype == 
kRowSparseStorage &&
+             mean_stype == kDefaultStorage && var_stype == kDefaultStorage &&
+             out_stype == kRowSparseStorage) {
+     AdamStdUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], 
inputs[2],
+                                     inputs[3], req[0], &out);
   } else {
     LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, 
req, outputs);
   }
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index 8760fe94a5..136769a1bf 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -148,6 +148,62 @@ void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& 
param,
   });
 }
 
+template<>
+void AdamStdUpdateDnsRspDnsImpl<cpu>(const AdamParam& param,
+                                     const OpContext& ctx,
+                                     const TBlob& weight,
+                                     const NDArray& grad,
+                                     const TBlob& mean,
+                                     const TBlob& var,
+                                     const OpReqType& req,
+                                     TBlob *out) {
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  using namespace mshadow;
+  Stream<cpu>* s = ctx.get_stream<cpu>();
+  if (req == kNullOp) return;
+  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse 
adam_update";
+  CHECK_GT(weight.shape_.Size(), 0);
+  CHECK_GT(mean.shape_.Size(), 0);
+  CHECK_GT(var.shape_.Size(), 0);
+
+  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
+    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+        const DType* weight_data = weight.dptr<DType>();
+        const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
+        const DType* grad_val = grad.data().dptr<DType>();
+        DType* mean_data = mean.dptr<DType>();
+        DType* var_data = var.dptr<DType>();
+        DType* out_data = out->dptr<DType>();
+        nnvm::dim_t num_rows = weight.shape_[0];
+        nnvm::dim_t row_length = weight.shape_.ProdShape(1, weight.ndim());
+        Tensor<cpu, 1, char> workspace = ctx.requested[0]
+          .get_space_typed<cpu, 1, char>(Shape1(num_rows * 
sizeof(nnvm::dim_t)), s);
+
+        nnvm::dim_t* prefix_sum = 
reinterpret_cast<nnvm::dim_t*>(workspace.dptr_);
+        // mark row flags
+        Kernel<set_zero, cpu>::Launch(s, num_rows, prefix_sum);
+        if (grad.storage_initialized()) {
+          Kernel<MarkRowFlgKernel, cpu>::Launch(s, grad.aux_shape(kIdx)[0],
+            prefix_sum, grad_idx);
+          // calculate inclusive prefix sum
+          for (nnvm::dim_t i = 1; i < num_rows; i++) {
+            prefix_sum[i] += prefix_sum[i - 1];
+          }
+        }
+
+        Kernel<AdamStdDnsRspDnsKernel<req_type>, cpu>::Launch(s, num_rows, 
row_length,
+          out_data, mean_data, var_data, weight_data, grad_idx, grad_val, 
prefix_sum,
+          static_cast<DType>(param.clip_gradient), 
static_cast<DType>(param.beta1),
+          static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
+          static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
+          static_cast<DType>(param.rescale_grad));
+      });
+    });
+  });
+}
+
 
 NNVM_REGISTER_OP(sgd_update)
 MXNET_ADD_SPARSE_OP_ALIAS(sgd_update)
@@ -329,8 +385,12 @@ only the row slices whose indices appear in grad.indices 
are updated (for w, m a
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<AdamParam>)
 .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
-.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<4, 1, 
false, true, false>)
+.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<2, 2>)
 .set_attr<nnvm::FMutateInputs>("FMutateInputs",
   [](const nnvm::NodeAttrs& attrs) {
     return std::vector<uint32_t>{2, 3};
diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu
index 891f24fe79..1bd6117432 100644
--- a/src/operator/optimizer_op.cu
+++ b/src/operator/optimizer_op.cu
@@ -94,6 +94,74 @@ void SGDMomStdUpdateDnsRspDnsImpl<gpu>(const SGDMomParam& 
param,
   });
 }
 
+template<>
+void AdamStdUpdateDnsRspDnsImpl<gpu>(const AdamParam& param,
+                                     const OpContext& ctx,
+                                     const TBlob& weight,
+                                     const NDArray& grad,
+                                     const TBlob& mean,
+                                     const TBlob& var,
+                                     const OpReqType& req,
+                                     TBlob *out) {
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  using namespace mshadow;
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+  if (req == kNullOp) return;
+  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse 
adam_update";
+  CHECK_GT(weight.shape_.Size(), 0);
+  CHECK_GT(mean.shape_.Size(), 0);
+  CHECK_GT(var.shape_.Size(), 0);
+
+  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
+    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+        const DType* weight_data = weight.dptr<DType>();
+        const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
+        const DType* grad_val = grad.data().dptr<DType>();
+        DType* mean_data = mean.dptr<DType>();
+        DType* var_data = var.dptr<DType>();
+        DType* out_data = out->dptr<DType>();
+        nnvm::dim_t num_rows = weight.shape_[0];
+        nnvm::dim_t row_length = weight.shape_.ProdShape(1, weight.ndim());
+        nnvm::dim_t* prefix_sum = NULL;
+        void* d_temp_storage = NULL;
+        size_t temp_storage_bytes = 0;
+        cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                      temp_storage_bytes,
+                                      prefix_sum,
+                                      prefix_sum,
+                                      num_rows,
+                                      Stream<gpu>::GetStream(s));
+        Tensor<gpu, 1, char> workspace = ctx.requested[0]
+          .get_space_typed<gpu, 1, char>(Shape1(num_rows * sizeof(nnvm::dim_t) 
+
+                                         temp_storage_bytes), s);
+        prefix_sum = reinterpret_cast<nnvm::dim_t*>(workspace.dptr_);
+        d_temp_storage = workspace.dptr_ + num_rows*sizeof(nnvm::dim_t);
+        // mark row flags
+        Fill<false>(s, TBlob(prefix_sum, Shape1(num_rows), gpu::kDevMask), 
kWriteTo, 0);
+        if (grad.storage_initialized()) {
+          Kernel<MarkRowFlgKernel, gpu>::Launch(s, grad.aux_shape(kIdx)[0],
+            prefix_sum, grad_idx);
+          // calculate inclusive prefix sum
+          cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                        temp_storage_bytes,
+                                        prefix_sum,
+                                        prefix_sum,
+                                        num_rows,
+                                        Stream<gpu>::GetStream(s));
+        }
+
+        Kernel<AdamStdDnsRspDnsKernel<req_type>, gpu>::Launch(s, num_rows, 
row_length,
+          out_data, mean_data, var_data, weight_data, grad_idx, grad_val, 
prefix_sum,
+          static_cast<DType>(param.clip_gradient), 
static_cast<DType>(param.beta1),
+          static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
+          static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
+          static_cast<DType>(param.rescale_grad));
+      });
+    });
+  });
+}
 
 NNVM_REGISTER_OP(signsgd_update)
 .set_attr<FCompute>("FCompute<gpu>", SignSGDUpdate<gpu>);
diff --git a/tests/python/unittest/test_optimizer.py 
b/tests/python/unittest/test_optimizer.py
index 2d22391879..26ff48babc 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -520,10 +520,10 @@ def test_adam():
                                     not kwarg['multi_precision'])):
                             continue
                         compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, 
dtype)
-                        if (default_context() == mx.cpu()):
-                            compare_optimizer(opt1(sparse_update=True, 
**kwarg), opt2(**kwarg), shape,
+                        compare_optimizer(opt1(sparse_update=True, **kwarg), 
opt2(**kwarg), shape,
+                                          dtype, w_stype='row_sparse', 
g_stype='row_sparse')
+                        compare_optimizer(opt1(**kwarg), 
opt2(lazy_update=False, **kwarg), shape,
                                           dtype, w_stype='row_sparse', 
g_stype='row_sparse')
-
 
 # Signum
 class PySignum(mx.optimizer.Optimizer):


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to