This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 349803c  Multi-precision AdamW update op (#14171)
349803c is described below

commit 349803ce9e737248ef8eb97914fcd87d9a5d75d8
Author: Haibin Lin <linhaibin.e...@gmail.com>
AuthorDate: Tue Feb 19 16:02:00 2019 -0800

    Multi-precision AdamW update op (#14171)
    
    * mp adamw update
    
    * Softmax fp16 (#201)
    
    * softmax for fp16 with fp32 accumulator
    
    * return AType in kernel
    
    * add dtype
    
    * kernel
    
    * adamw with nan check
    
    * add doc
    
    * Revert "Softmax fp16 (#201)"
    
    This reverts commit 5869e0ae832437c839bb4ccbcc434971bf5c3486.
    
    * add test
    
    * more test for fp16
    
    * skip update for rescale = 0
---
 src/operator/contrib/adamw-inl.h                | 165 ++++++++++++++++++------
 src/operator/contrib/adamw.cc                   |  76 ++++++++++-
 src/operator/contrib/adamw.cu                   |  27 +++-
 tests/python/unittest/test_contrib_optimizer.py |  84 ++++++++++++
 4 files changed, 310 insertions(+), 42 deletions(-)
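
For context before the diff below, here is a minimal sketch of how the two operators touched by this change are invoked from Python, mirroring the calls in the new unit test: rescale_grad is now passed as a single-element NDArray rather than a scalar attribute, and the update is skipped whenever it is NaN, Inf, or 0. Hyperparameter values here are illustrative only::

    import mxnet as mx
    import numpy as np

    shape = (3, 4)
    kwargs = {'eta': 1.0, 'lr': 0.1, 'wd': 0.0, 'epsilon': 1e-8,
              'beta1': 0.9, 'beta2': 0.999}

    # fp32 AdamW update: weight, grad, mean and var are all fp32;
    # rescale_grad is a shape-(1,) fp32 NDArray.
    weight = mx.nd.random.uniform(shape=shape)
    grad = mx.nd.random.uniform(shape=shape)
    mean = mx.nd.zeros(shape)
    var = mx.nd.zeros(shape)
    rescale_grad = mx.nd.array([1.0 / 128])   # e.g. an inverse loss scale
    mx.nd.contrib.adamw_update(weight, grad, mean, var,
                               rescale_grad, out=weight, **kwargs)

    # Multi-precision AdamW update: fp16 weight/grad plus an fp32 master
    # weight; mean and var stay fp32. A NaN/Inf/0 scale (e.g. an overflowed
    # loss scale) leaves every array untouched.
    weight_fp16 = weight.astype('float16')
    grad_fp16 = grad.astype('float16')
    weight_fp32 = weight.copy()
    mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, mean, var, weight_fp32,
                                  rescale_grad * np.nan, out=weight_fp16, **kwargs)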

diff --git a/src/operator/contrib/adamw-inl.h b/src/operator/contrib/adamw-inl.h
index 3d76b33..66bd4f3 100644
--- a/src/operator/contrib/adamw-inl.h
+++ b/src/operator/contrib/adamw-inl.h
@@ -33,6 +33,7 @@
 #include <nnvm/op.h>
 #include <nnvm/op_attr_types.h>
 #include <vector>
+#include <cmath>
 #include "../operator_common.h"
 #include "../mshadow_op.h"
 #include "../elemwise_op_common.h"
@@ -48,7 +49,6 @@ struct AdamWParam : public dmlc::Parameter<AdamWParam> {
   float epsilon;
   float wd;
   float eta;
-  float rescale_grad;
   float clip_gradient;
   DMLC_DECLARE_PARAMETER(AdamWParam) {
     DMLC_DECLARE_FIELD(lr)
@@ -69,9 +69,6 @@ struct AdamWParam : public dmlc::Parameter<AdamWParam> {
               "The penalty scales with the square of the magnitude of each 
weight.");
     DMLC_DECLARE_FIELD(eta)
     .describe("Learning rate schedule multiplier");
-    DMLC_DECLARE_FIELD(rescale_grad)
-    .set_default(1.0f)
-    .describe("Rescale gradient to grad = rescale_grad*grad.");
     DMLC_DECLARE_FIELD(clip_gradient)
     .set_default(-1.0f)
     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
@@ -80,44 +77,138 @@ struct AdamWParam : public dmlc::Parameter<AdamWParam> {
   }
 };
 
+// rescale_grad is a reserved argument at position -1. Example:
+// n_in = 2: weight, grad (fp16)
+// n_out = 1: weight (fp16)
+// total_in = 6: weight, grad, mean, var, weight32, rescale_grad (fp32)
+template<int n_in, int n_out, int total_in>
+inline bool MPUpdateInferShape(const nnvm::NodeAttrs& attrs,
+                               std::vector<TShape> *in_attrs,
+                               std::vector<TShape> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
+  CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
+  // rescale_grad.shape = (1,)
+  SHAPE_ASSIGN_CHECK(*in_attrs, total_in - 1, mshadow::Shape1(1));
+  return ElemwiseAttr<TShape, shape_is_none, shape_assign, true, shape_string, n_in, n_out>(
+      attrs, in_attrs, out_attrs, TShape());
+}
+
+// rescale_grad is a reserved argument at position -1. Example:
+// n_in = 2: weight, grad (fp16)
+// n_out = 1: weight (fp16)
+// total_in = 6: weight, grad, mean, var, weight32, rescale_grad (fp32)
+template<int n_in, int n_out, int total_in>
+inline bool MPUpdateInferType(const nnvm::NodeAttrs& attrs,
+                              std::vector<int> *in_attrs,
+                              std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
+  CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
+  for (int i = n_in; i < total_in; ++i) {
+    TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat32);
+  }
+  return ElemwiseAttr<int, type_is_none, type_assign, true, type_string, n_in, n_out>(
+      attrs, in_attrs, out_attrs, -1);
+}
+
+template<int req>
+struct MPAdamWKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data, float* mean_data,
+    float* var_data, const DType* weight_data, const DType* grad_data, float* weight32,
+    const float param_clip_gradient, const float param_beta1, const float param_beta2,
+    const float param_eta, const float param_lr, const float param_wd,
+    const float param_rescale_grad, const float param_epsilon) {
+    float w = weight32[i];
+    float mean = mean_data[i];
+    float var = var_data[i];
+    float scaled_grad = param_rescale_grad*static_cast<float>(grad_data[i]);
+    if (param_clip_gradient >= 0.0f) {
+      mean = param_beta1 * mean +
+             (1 - param_beta1) * mshadow_op::clip::Map(scaled_grad, param_clip_gradient);
+      var = param_beta2 * var + (1 - param_beta2) *
+            mshadow_op::square::Map(mshadow_op::clip::Map(scaled_grad, param_clip_gradient));
+    } else {
+      mean = param_beta1 * mean + (1 - param_beta1) * scaled_grad;
+      var = param_beta2 * var + (1 - param_beta2) * mshadow_op::square::Map(scaled_grad);
+    }
+    mean_data[i] = mean;
+    var_data[i] = var;
+    w = w - param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon)
+                         + param_wd * w);
+    weight32[i] = w;
+    KERNEL_ASSIGN(out_data[i], req, w);
+  }
+};
+
+
+template<typename xpu>
+struct MPAdamWUpdate {
+  static inline void Forward(const nnvm::NodeAttrs& attrs,
+               const OpContext &ctx,
+               const std::vector<TBlob> &inputs,
+               const std::vector<OpReqType> &req,
+               const std::vector<TBlob> &outputs,
+               const float rescale_grad) {
+    using namespace mxnet_op;
+    AdamWParam param = nnvm::get<AdamWParam>(attrs.parsed);
+    Stream<xpu>* s = ctx.get_stream<xpu>();
+    MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+      Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+      Tensor<xpu, 2, float> mean = inputs[2].FlatTo2D<xpu, float>(s);
+      Tensor<xpu, 2, float> var = inputs[3].FlatTo2D<xpu, float>(s);
+      Tensor<xpu, 2, float> weight32 = inputs[4].FlatTo2D<xpu, float>(s);
+      Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        Kernel<MPAdamWKernel<req_type>, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mean.dptr_,
+          var.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_, param.clip_gradient, param.beta1,
+          param.beta2, param.eta, param.lr, param.wd, rescale_grad, param.epsilon);
+      });
+    });
+  }
+};
+
 /*
  * \brief adam_w update.
  */
 template<typename xpu>
-inline void AdamWUpdate(const nnvm::NodeAttrs& attrs,
-                        const OpContext &ctx,
-                        const std::vector<TBlob> &inputs,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &outputs) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  using namespace mshadow_op;
-  const AdamWParam& param = nnvm::get<AdamWParam>(attrs.parsed);
-  Stream<xpu>* s = ctx.get_stream<xpu>();
-  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+struct AdamWUpdate {
+  static inline void Forward(const nnvm::NodeAttrs& attrs,
+                             const OpContext &ctx,
+                             const std::vector<TBlob> &inputs,
+                             const std::vector<OpReqType> &req,
+                             const std::vector<TBlob> &outputs,
+                             const float rescale_grad) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    using namespace mshadow_op;
+    const AdamWParam& param = nnvm::get<AdamWParam>(attrs.parsed);
+    Stream<xpu>* s = ctx.get_stream<xpu>();
+    MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+      Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+      Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
+      Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
+      Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
 
-    grad = scalar<DType>(param.rescale_grad) * grad;
-    if (param.clip_gradient >= 0.0f) {
-      mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) *
-          F<clip>(grad, DType(param.clip_gradient));
-      var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2)*F<square>(
-          F<clip>(grad, DType(param.clip_gradient)));
-    } else {
-      mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) * grad;
-      var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
-    }
-    Assign(out, req[0],
-           weight -
-           scalar<DType>(param.eta) * (scalar<DType>(param.lr) *
-           mean / (F<square_root>(var) + scalar<DType>(param.epsilon)) +
-           (scalar<DType>(param.wd) * weight)));
-  });
-}
+      grad = scalar<DType>(rescale_grad) * grad;
+      if (param.clip_gradient >= 0.0f) {
+        mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) *
+            F<clip>(grad, DType(param.clip_gradient));
+        var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2)*F<square>(
+            F<clip>(grad, DType(param.clip_gradient)));
+      } else {
+        mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) * grad;
+        var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
+      }
+      Assign(out, req[0],
+             weight -
+             scalar<DType>(param.eta) * (scalar<DType>(param.lr) *
+             mean / (F<square_root>(var) + scalar<DType>(param.epsilon)) +
+             (scalar<DType>(param.wd) * weight)));
+    });
+  }
+};
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/contrib/adamw.cc b/src/operator/contrib/adamw.cc
index 94623fe..2fbc397 100644
--- a/src/operator/contrib/adamw.cc
+++ b/src/operator/contrib/adamw.cc
@@ -24,12 +24,76 @@
  * \author Haibin Lin
  */
 #include "./adamw-inl.h"
+#include "../optimizer_op-inl.h"
 
 namespace mxnet {
 namespace op {
 
 DMLC_REGISTER_PARAMETER(AdamWParam);
 
+template<template <typename xpu> class F>
+inline void MPUpdateCPU(const nnvm::NodeAttrs& attrs,
+                        const OpContext &ctx,
+                        const std::vector<TBlob> &inputs,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &outputs) {
+  // the scale is already on the CPU; skip the update if it is NaN, Inf or 0
+  TBlob scale_blob = inputs[inputs.size() - 1];
+  MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType, {
+    float scalef = static_cast<float>(*scale_blob.dptr<DType>());
+    if (!std::isfinite(scalef) || scalef == 0) return;
+    std::vector<TBlob> inputs_wo_scale;
+    size_t num_in = inputs.size();
+    inputs_wo_scale.reserve(num_in - 1);
+  for (size_t i = 0; i < num_in - 1; i++) inputs_wo_scale.emplace_back(inputs[i]);
+    F<cpu>::Forward(attrs, ctx, inputs_wo_scale, req, outputs, scalef);
+  });
+}
+
+NNVM_REGISTER_OP(_contrib_mp_adamw_update)
+.describe(R"code(Update function for multi-precision AdamW optimizer.
+
+AdamW is seen as a modification of Adam by decoupling the weight decay from the
+optimization steps taken w.r.t. the loss function.
+
+Adam update consists of the following steps, where g represents gradient and m, v
+are 1st and 2nd order moment estimates (mean and variance).
+
+.. math::
+
+ g_t = \nabla J(W_{t-1})\\
+ m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
+ v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
+ W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1})
+
+It updates the weights using::
+
+ m = beta1*m + (1-beta1)*grad
+ v = beta2*v + (1-beta2)*(grad**2)
+ w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)
+
+Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0,
+the update is skipped.
+)code" ADD_FILELINE)
+.set_num_inputs(6)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<AdamWParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", MPUpdateInferShape<2, 1, 6>)
+.set_attr<nnvm::FInferType>("FInferType", MPUpdateInferType<2, 1, 6>)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+  [](const nnvm::NodeAttrs& attrs) {
+    return std::vector<uint32_t>{2, 3, 4};
+  })
+.set_attr<FCompute>("FCompute<cpu>", MPUpdateCPU<MPAdamWUpdate>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("grad", "NDArray-or-Symbol", "Gradient")
+.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
+.add_argument("var", "NDArray-or-Symbol", "Moving variance")
+.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
+.add_argument("rescale_grad", "NDArray-or-Symbol",
+              "Rescale gradient to rescale_grad * grad. If NaN, the update is 
skipped.")
+.add_arguments(AdamWParam::__FIELDS__());
+
 NNVM_REGISTER_OP(_contrib_adamw_update)
 .describe(R"code(Update function for AdamW optimizer. AdamW is seen as a 
modification of
 Adam by decoupling the weight decay from the optimization steps taken w.r.t. 
the loss function.
@@ -50,21 +114,25 @@ It updates the weights using::
  v = beta2*v + (1-beta2)*(grad**2)
  w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)
 
+Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0,
+the update is skipped.
 )code" ADD_FILELINE)
-.set_num_inputs(4)
+.set_num_inputs(5)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<AdamWParam>)
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", MPUpdateInferShape<4, 1, 5>)
+.set_attr<nnvm::FInferType>("FInferType", MPUpdateInferType<4, 1, 5>)
 .set_attr<nnvm::FMutateInputs>("FMutateInputs",
   [](const nnvm::NodeAttrs& attrs) {
     return std::vector<uint32_t>{2, 3};
   })
-.set_attr<FCompute>("FCompute<cpu>", AdamWUpdate<cpu>)
+.set_attr<FCompute>("FCompute<cpu>", MPUpdateCPU<AdamWUpdate>)
 .add_argument("weight", "NDArray-or-Symbol", "Weight")
 .add_argument("grad", "NDArray-or-Symbol", "Gradient")
 .add_argument("mean", "NDArray-or-Symbol", "Moving mean")
 .add_argument("var", "NDArray-or-Symbol", "Moving variance")
+.add_argument("rescale_grad", "NDArray-or-Symbol",
+              "Rescale gradient to rescale_grad * grad. If NaN, the update is 
skipped.")
 .add_arguments(AdamWParam::__FIELDS__());
 
 }  // namespace op
diff --git a/src/operator/contrib/adamw.cu b/src/operator/contrib/adamw.cu
index b7452f8..e21b83b 100644
--- a/src/operator/contrib/adamw.cu
+++ b/src/operator/contrib/adamw.cu
@@ -28,8 +28,33 @@
 namespace mxnet {
 namespace op {
 
+template<template <typename xpu> class F>
+inline void MPUpdateGPU(const nnvm::NodeAttrs& attrs,
+                        const OpContext &ctx,
+                        const std::vector<TBlob> &inputs,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &outputs) {
+  // copy to cpu and check NaN value
+  TBlob scale_blob = inputs[inputs.size() - 1];
+  MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType, {
+    DType scale = 0;
+    CUDA_CALL(cudaMemcpy(&scale, scale_blob.dptr<DType>(), sizeof(DType),
+       cudaMemcpyDeviceToHost));
+    float scalef = static_cast<float>(scale);
+    if (!std::isfinite(scalef) || scalef == 0) return;
+    std::vector<TBlob> inputs_wo_scale;
+    size_t num_in = inputs.size();
+    inputs_wo_scale.reserve(num_in - 1);
+  for (size_t i = 0; i < num_in - 1; i++) inputs_wo_scale.emplace_back(inputs[i]);
+    F<gpu>::Forward(attrs, ctx, inputs_wo_scale, req, outputs, scalef);
+  });
+}
+
 NNVM_REGISTER_OP(_contrib_adamw_update)
-.set_attr<FCompute>("FCompute<gpu>", AdamWUpdate<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", MPUpdateGPU<AdamWUpdate>);
+
+NNVM_REGISTER_OP(_contrib_mp_adamw_update)
+.set_attr<FCompute>("FCompute<gpu>", MPUpdateGPU<MPAdamWUpdate>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_contrib_optimizer.py b/tests/python/unittest/test_contrib_optimizer.py
index 8ff8a7e..dad7bed 100644
--- a/tests/python/unittest/test_contrib_optimizer.py
+++ b/tests/python/unittest/test_contrib_optimizer.py
@@ -94,6 +94,90 @@ def test_group_adagrad():
                 g_stype='row_sparse',
                 compare_states=False)
 
+def test_adamw():
+    shape = (3, 4)
+    weight = mx.nd.random.uniform(shape=shape)
+    weight_ref = weight.copy()
+    grad = mx.nd.random.uniform(shape=shape)
+    m = mx.nd.random.uniform(shape=shape)
+    v = mx.nd.random.uniform(shape=shape)
+    rescale_grad = mx.nd.array([10])
+    eta, lr, wd, epsilon = 1, 1, 0, 1e-8
+    beta1, beta2 = 0.9, 0.999
+    kwargs = {'eta': eta, 'lr': lr, 'wd': wd, 'epsilon': epsilon,
+              'beta1': beta1, 'beta2': beta2}
+
+    # update is skipped for rescale = 0
+    mx.nd.contrib.adamw_update(weight, grad, m, v,
+                               rescale_grad * 0, out=weight, **kwargs)
+    # weight remains unchanged
+    mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+
+    # update is skipped for rescale = nan
+    mx.nd.contrib.adamw_update(weight, grad, m, v,
+                               rescale_grad * np.nan, out=weight, **kwargs)
+    # weight remains unchanged
+    mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+
+    # update is skipped for rescale = inf
+    mx.nd.contrib.adamw_update(weight, grad, m, v,
+                               rescale_grad * np.inf, out=weight, **kwargs)
+    # weight remains unchanged
+    mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+
+    # multi-precision update is skipped for rescale = nan
+    weight_fp16 = weight.astype('float16')
+    grad_fp16 = grad.astype('float16')
+    weight_fp16_ref = weight_fp16.copy()
+    mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
+                                  rescale_grad * np.nan, out=weight_fp16, **kwargs)
+    mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+    mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())
+
+    # multi-precision update is skipped for rescale = inf
+    mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
+                                  rescale_grad * np.inf, out=weight_fp16, **kwargs)
+    mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+    mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())
+
+    # multi-precision update is skipped for rescale = 0
+    mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
+                                  rescale_grad * 0, out=weight_fp16, **kwargs)
+    mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+    mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())
+
+    # reference normal update
+    grad_rescale = rescale_grad * grad
+    m_ref = beta1*m + (1-beta1)*grad_rescale
+    v_ref = beta2*v + (1-beta2)*(grad_rescale**2)
+    weight_ref = weight - eta * (1 * m_ref / (v_ref.sqrt() + epsilon) + weight * wd)
+    m_test = m.copy()
+    v_test = v.copy()
+    weight_test = weight.copy()
+    # op normal update
+    mx.nd.contrib.adamw_update(weight_test, grad, m_test, v_test,
+                               rescale_grad, out=weight_test, **kwargs)
+    mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight_test.asnumpy())
+    mx.test_utils.assert_almost_equal(m_ref.asnumpy(), m_test.asnumpy())
+    mx.test_utils.assert_almost_equal(v_ref.asnumpy(), v_test.asnumpy())
+
+    # reference normal multi-precision update
+    m_fp32 = m.copy()
+    v_fp32 = v.copy()
+    weight_fp32 = weight.copy()
+    grad_rescale = rescale_grad * grad_fp16.astype('float32')
+    m_ref = beta1*m_fp32 + (1-beta1)*grad_rescale
+    v_ref = beta2*v_fp32 + (1-beta2)*(grad_rescale**2)
+    weight_ref = weight - eta * (1 * m_ref / (v_ref.sqrt() + epsilon) + weight * wd)
+    weight_fp16_ref = weight_ref.astype('float16')
+    # op normal multi-precision update
+    mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m_fp32, v_fp32, weight_fp32,
+                                  rescale_grad, out=weight_fp16, **kwargs)
+    mx.test_utils.assert_almost_equal(m_ref.asnumpy(), m_fp32.asnumpy())
+    mx.test_utils.assert_almost_equal(v_ref.asnumpy(), v_fp32.asnumpy())
+    mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight_fp32.asnumpy())
+    mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())
+
 
 if __name__ == '__main__':
     import nose

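For readers who want the arithmetic of the new kernel without reading the C++, the per-element work done by MPAdamWKernel corresponds to the following NumPy sketch (function name and in-place convention are illustrative, not an MXNet API)::

    import numpy as np

    def mp_adamw_step(weight_fp16, grad_fp16, mean, var, weight_fp32,
                      rescale_grad, eta, lr, beta1, beta2, epsilon, wd,
                      clip_gradient=-1.0):
        # Gradients are read in fp16 but all arithmetic is done in fp32.
        g = np.float32(rescale_grad) * grad_fp16.astype(np.float32)
        if clip_gradient >= 0.0:
            g = np.clip(g, -clip_gradient, clip_gradient)
        # First and second moment estimates, kept in fp32.
        mean[:] = beta1 * mean + (1.0 - beta1) * g
        var[:] = beta2 * var + (1.0 - beta2) * np.square(g)
        # Decoupled weight decay: wd acts on the weight, not on the gradient.
        weight_fp32[:] -= eta * (lr * mean / (np.sqrt(var) + epsilon) + wd * weight_fp32)
        # The fp16 copy trails the fp32 master weight.
        weight_fp16[:] = weight_fp32.astype(np.float16)

The NaN/Inf/0 check on rescale_grad happens in MPUpdateCPU/MPUpdateGPU before the kernel is launched, so a skipped update never touches mean, var, or either copy of the weight.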