This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

The following commit(s) were added to refs/heads/master by this push:
     new 349803c  Multi-precision AdamW update op (#14171)
349803c is described below

commit 349803ce9e737248ef8eb97914fcd87d9a5d75d8
Author:     Haibin Lin <linhaibin.e...@gmail.com>
AuthorDate: Tue Feb 19 16:02:00 2019 -0800

    Multi-precision AdamW update op (#14171)

    * mp adamw update
    * Softmax fp16 (#201)
    * softmax for fp16 with fp32 accumulator
    * return AType in kernel
    * add dtype
    * kernel
    * adamw with nan check
    * add doc
    * Revert "Softmax fp16 (#201)"
      This reverts commit 5869e0ae832437c839bb4ccbcc434971bf5c3486.
    * add test
    * more test for fp16
    * skip update for rescale = 0
---
 src/operator/contrib/adamw-inl.h                | 165 ++++++++++++++++++------
 src/operator/contrib/adamw.cc                   |  76 ++++++++++-
 src/operator/contrib/adamw.cu                   |  27 +++-
 tests/python/unittest/test_contrib_optimizer.py |  84 ++++++++++++
 4 files changed, 310 insertions(+), 42 deletions(-)

diff --git a/src/operator/contrib/adamw-inl.h b/src/operator/contrib/adamw-inl.h
index 3d76b33..66bd4f3 100644
--- a/src/operator/contrib/adamw-inl.h
+++ b/src/operator/contrib/adamw-inl.h
@@ -33,6 +33,7 @@
 #include <nnvm/op.h>
 #include <nnvm/op_attr_types.h>
 #include <vector>
+#include <cmath>
 #include "../operator_common.h"
 #include "../mshadow_op.h"
 #include "../elemwise_op_common.h"
@@ -48,7 +49,6 @@ struct AdamWParam : public dmlc::Parameter<AdamWParam> {
   float epsilon;
   float wd;
   float eta;
-  float rescale_grad;
   float clip_gradient;
   DMLC_DECLARE_PARAMETER(AdamWParam) {
     DMLC_DECLARE_FIELD(lr)
@@ -69,9 +69,6 @@ struct AdamWParam : public dmlc::Parameter<AdamWParam> {
              "The penalty scales with the square of the magnitude of each weight.");
     DMLC_DECLARE_FIELD(eta)
     .describe("Learning rate schedule multiplier");
-    DMLC_DECLARE_FIELD(rescale_grad)
-    .set_default(1.0f)
-    .describe("Rescale gradient to grad = rescale_grad*grad.");
     DMLC_DECLARE_FIELD(clip_gradient)
     .set_default(-1.0f)
     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
@@ -80,44 +77,138 @@ struct AdamWParam : public dmlc::Parameter<AdamWParam> {
   }
 };
 
+// rescale_grad is a reserved argument at position -1. Example:
+// n_in = 2: weight, grad (fp16)
+// n_out = 1: weight (fp16)
+// total_in = 6: weight, grad, mean, var, weight32, rescale_grad (fp32)
+template<int n_in, int n_out, int total_in>
+inline bool MPUpdateInferShape(const nnvm::NodeAttrs& attrs,
+                               std::vector<TShape> *in_attrs,
+                               std::vector<TShape> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
+  CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
+  // rescale_grad.shape = (1,)
+  SHAPE_ASSIGN_CHECK(*in_attrs, total_in - 1, mshadow::Shape1(1));
+  return ElemwiseAttr<TShape, shape_is_none, shape_assign, true, shape_string, n_in, n_out>(
+    attrs, in_attrs, out_attrs, TShape());
+}
+
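For context when reading the hunks that follow: rescale_grad is no longer an
AdamWParam attribute (its default used to be 1.0); it is now passed as the last
input of the op, a (1,)-shaped float32 NDArray. A minimal sketch of the resulting
six-input layout of the new multi-precision op; the shapes and hyper-parameter
values here are illustrative and not taken from the patch:

    import mxnet as mx

    shape = (3, 4)
    weight32 = mx.nd.random.uniform(shape=shape)            # fp32 master weights
    weight16 = weight32.astype('float16')                   # fp16 weight (n_in)
    grad16 = mx.nd.random.uniform(shape=shape).astype('float16')
    mean, var = mx.nd.zeros(shape), mx.nd.zeros(shape)      # fp32 optimizer state
    rescale = mx.nd.array([1.0])                            # fp32, shape (1,)

    mx.nd.contrib.mp_adamw_update(weight16, grad16, mean, var, weight32, rescale,
                                  out=weight16, eta=1.0, lr=1e-3, wd=0.0,
                                  epsilon=1e-8, beta1=0.9, beta2=0.999)
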
+// rescale_grad is a reserved argument at position -1. Example:
+// n_in = 2: weight, grad (fp16)
+// n_out = 1: weight (fp16)
+// total_in = 6: weight, grad, mean, var, weight32, rescale_grad (fp32)
+template<int n_in, int n_out, int total_in>
+inline bool MPUpdateInferType(const nnvm::NodeAttrs& attrs,
+                              std::vector<int> *in_attrs,
+                              std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
+  CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
+  for (int i = n_in; i < total_in; ++i) {
+    TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat32);
+  }
+  return ElemwiseAttr<int, type_is_none, type_assign, true, type_string, n_in, n_out>(
+    attrs, in_attrs, out_attrs, -1);
+}
+
+template<int req>
+struct MPAdamWKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data, float* mean_data,
+    float* var_data, const DType* weight_data, const DType* grad_data, float* weight32,
+    const float param_clip_gradient, const float param_beta1, const float param_beta2,
+    const float param_eta, const float param_lr, const float param_wd,
+    const float param_rescale_grad, const float param_epsilon) {
+    float w = weight32[i];
+    float mean = mean_data[i];
+    float var = var_data[i];
+    float scaled_grad = param_rescale_grad*static_cast<float>(grad_data[i]);
+    if (param_clip_gradient >= 0.0f) {
+      mean = param_beta1 * mean +
+        (1 - param_beta1) * mshadow_op::clip::Map(scaled_grad, param_clip_gradient);
+      var = param_beta2 * var + (1 - param_beta2) *
+        mshadow_op::square::Map(mshadow_op::clip::Map(scaled_grad, param_clip_gradient));
+    } else {
+      mean = param_beta1 * mean + (1 - param_beta1) * scaled_grad;
+      var = param_beta2 * var + (1 - param_beta2) * mshadow_op::square::Map(scaled_grad);
+    }
+    mean_data[i] = mean;
+    var_data[i] = var;
+    w = w - param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon) +
+        param_wd * w);
+    weight32[i] = w;
+    KERNEL_ASSIGN(out_data[i], req, w);
+  }
+};
+
+
+template<typename xpu>
+struct MPAdamWUpdate {
+  static inline void Forward(const nnvm::NodeAttrs& attrs,
+                             const OpContext &ctx,
+                             const std::vector<TBlob> &inputs,
+                             const std::vector<OpReqType> &req,
+                             const std::vector<TBlob> &outputs,
+                             const float rescale_grad) {
+    using namespace mxnet_op;
+    AdamWParam param = nnvm::get<AdamWParam>(attrs.parsed);
+    Stream<xpu>* s = ctx.get_stream<xpu>();
+    MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+      Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+      Tensor<xpu, 2, float> mean = inputs[2].FlatTo2D<xpu, float>(s);
+      Tensor<xpu, 2, float> var = inputs[3].FlatTo2D<xpu, float>(s);
+      Tensor<xpu, 2, float> weight32 = inputs[4].FlatTo2D<xpu, float>(s);
+      Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        Kernel<MPAdamWKernel<req_type>, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mean.dptr_,
+          var.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_, param.clip_gradient, param.beta1,
+          param.beta2, param.eta, param.lr, param.wd, rescale_grad, param.epsilon);
+      });
+    });
+  }
+};
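The kernel above does all optimizer arithmetic in float32: the fp16 gradient is
cast up, the master copy weight32 and the moments are updated in place, and only
the value written to out is cast back to the weight's dtype. A NumPy sketch of
the same per-element math, for illustration only (gradient clipping is omitted
here; see the note after the next hunk):

    import numpy as np

    def mp_adamw_step(weight32, mean, var, grad16, rescale_grad,
                      beta1=0.9, beta2=0.999, eta=1.0, lr=1e-3, wd=0.0, eps=1e-8):
        """Illustrative float32 math of MPAdamWKernel (clipping left out)."""
        g = rescale_grad * grad16.astype(np.float32)        # scaled_grad
        mean[:] = beta1 * mean + (1 - beta1) * g            # 1st moment, in place
        var[:] = beta2 * var + (1 - beta2) * np.square(g)   # 2nd moment, in place
        weight32[:] -= eta * (lr * mean / (np.sqrt(var) + eps) + wd * weight32)
        return weight32.astype(np.float16)                  # fp16 copy written to `out`

mean, var and weight32 being mutated in place matches the FMutateInputs list
{2, 3, 4} in the operator registration further down.
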
 /*
  * \brief adam_w update.
  */
 template<typename xpu>
-inline void AdamWUpdate(const nnvm::NodeAttrs& attrs,
-                        const OpContext &ctx,
-                        const std::vector<TBlob> &inputs,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &outputs) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  using namespace mshadow_op;
-  const AdamWParam& param = nnvm::get<AdamWParam>(attrs.parsed);
-  Stream<xpu>* s = ctx.get_stream<xpu>();
-  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+struct AdamWUpdate {
+  static inline void Forward(const nnvm::NodeAttrs& attrs,
+                             const OpContext &ctx,
+                             const std::vector<TBlob> &inputs,
+                             const std::vector<OpReqType> &req,
+                             const std::vector<TBlob> &outputs,
+                             const float rescale_grad) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    using namespace mshadow_op;
+    const AdamWParam& param = nnvm::get<AdamWParam>(attrs.parsed);
+    Stream<xpu>* s = ctx.get_stream<xpu>();
+    MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+      Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+      Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
+      Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
+      Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
 
-  grad = scalar<DType>(param.rescale_grad) * grad;
-  if (param.clip_gradient >= 0.0f) {
-    mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) *
-           F<clip>(grad, DType(param.clip_gradient));
-    var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2)*F<square>(
-          F<clip>(grad, DType(param.clip_gradient)));
-  } else {
-    mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) * grad;
-    var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
-  }
-  Assign(out, req[0],
-         weight -
-         scalar<DType>(param.eta) * (scalar<DType>(param.lr) *
-         mean / (F<square_root>(var) + scalar<DType>(param.epsilon)) +
-         (scalar<DType>(param.wd) * weight)));
-  });
-}
+      grad = scalar<DType>(rescale_grad) * grad;
+      if (param.clip_gradient >= 0.0f) {
+        mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) *
+               F<clip>(grad, DType(param.clip_gradient));
+        var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2)*F<square>(
+              F<clip>(grad, DType(param.clip_gradient)));
+      } else {
+        mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) * grad;
+        var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
+      }
+      Assign(out, req[0],
+             weight -
+             scalar<DType>(param.eta) * (scalar<DType>(param.lr) *
+             mean / (F<square_root>(var) + scalar<DType>(param.epsilon)) +
+             (scalar<DType>(param.wd) * weight)));
+    });
+  }
+};
 
 } // namespace op
 } // namespace mxnet
diff --git a/src/operator/contrib/adamw.cc b/src/operator/contrib/adamw.cc
index 94623fe..2fbc397 100644
--- a/src/operator/contrib/adamw.cc
+++ b/src/operator/contrib/adamw.cc
@@ -24,12 +24,76 @@
  * \author Haibin Lin
  */
 #include "./adamw-inl.h"
+#include "../optimizer_op-inl.h"
 
 namespace mxnet {
 namespace op {
 
 DMLC_REGISTER_PARAMETER(AdamWParam);
 
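Worth noting while reading both code paths: the optional gradient clipping from
the original operator is preserved. When clip_gradient >= 0, the rescaled
gradient is clamped into [-clip_gradient, clip_gradient] before the moment
updates; the default of -1 disables clipping. An illustrative NumPy equivalent
of that branch (not part of the patch):

    import numpy as np

    def rescale_and_clip(grad, rescale_grad, clip_gradient=-1.0):
        """Sketch of the clip_gradient branch shared by AdamWUpdate and MPAdamWKernel."""
        g = rescale_grad * grad.astype(np.float32)
        if clip_gradient >= 0.0:
            g = np.clip(g, -clip_gradient, clip_gradient)   # mshadow_op::clip
        return g
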
+template<template <typename xpu> class F>
+inline void MPUpdateCPU(const nnvm::NodeAttrs& attrs,
+                        const OpContext &ctx,
+                        const std::vector<TBlob> &inputs,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &outputs) {
+  // copy to cpu and check NaN value
+  TBlob scale_blob = inputs[inputs.size() - 1];
+  MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType, {
+    float scalef = static_cast<float>(*scale_blob.dptr<DType>());
+    if (!std::isfinite(scalef) || scalef == 0) return;
+    std::vector<TBlob> inputs_wo_scale;
+    size_t num_in = inputs.size();
+    inputs_wo_scale.reserve(num_in - 1);
+    for (size_t i = 0; i < num_in - 1; i++) inputs_wo_scale.emplace_back(inputs[i]);
+    F<cpu>::Forward(attrs, ctx, inputs_wo_scale, req, outputs, scalef);
+  });
+}
+
+NNVM_REGISTER_OP(_contrib_mp_adamw_update)
+.describe(R"code(Update function for multi-precision AdamW optimizer.
+
+AdamW is seen as a modification of Adam by decoupling the weight decay from the
+optimization steps taken w.r.t. the loss function.
+
+Adam update consists of the following steps, where g represents gradient and m, v
+are 1st and 2nd order moment estimates (mean and variance).
+
+.. math::
+
+ g_t = \nabla J(W_{t-1})\\
+ m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
+ v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
+ W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1})
+
+It updates the weights using::
+
+ m = beta1*m + (1-beta1)*grad
+ v = beta2*v + (1-beta2)*(grad**2)
+ w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)
+
+Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0,
+the update is skipped.
+)code" ADD_FILELINE)
+.set_num_inputs(6)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<AdamWParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", MPUpdateInferShape<2, 1, 6>)
+.set_attr<nnvm::FInferType>("FInferType", MPUpdateInferType<2, 1, 6>)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+  [](const nnvm::NodeAttrs& attrs) {
+    return std::vector<uint32_t>{2, 3, 4};
+  })
+.set_attr<FCompute>("FCompute<cpu>", MPUpdateCPU<MPAdamWUpdate>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("grad", "NDArray-or-Symbol", "Gradient")
+.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
+.add_argument("var", "NDArray-or-Symbol", "Moving variance")
+.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
+.add_argument("rescale_grad", "NDArray-or-Symbol",
+  "Rescale gradient to rescale_grad * grad. If NaN, the update is skipped.")
+.add_arguments(AdamWParam::__FIELDS__());
+
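The MPUpdateCPU wrapper above (and its GPU twin later in the patch) reads the
scalar and returns early when it is NaN, Inf, or 0, which is exactly the
"the update is skipped" behaviour promised in the docstring: weights and
optimizer state are left untouched. A quick illustration, mirroring the new
unit test at the end of this patch (values here are illustrative):

    import numpy as np
    import mxnet as mx

    shape = (3, 4)
    weight = mx.nd.random.uniform(shape=shape)
    grad = mx.nd.random.uniform(shape=shape)
    m, v = mx.nd.zeros(shape), mx.nd.zeros(shape)
    before = weight.copy()

    # rescale_grad is NaN, so the whole update is a no-op
    mx.nd.contrib.adamw_update(weight, grad, m, v, mx.nd.array([np.nan]),
                               out=weight, eta=1.0, lr=1e-3, wd=0.0,
                               epsilon=1e-8, beta1=0.9, beta2=0.999)
    mx.test_utils.assert_almost_equal(before.asnumpy(), weight.asnumpy())
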
 NNVM_REGISTER_OP(_contrib_adamw_update)
 .describe(R"code(Update function for AdamW optimizer.
 
 AdamW is seen as a modification of Adam by decoupling the weight decay from the
 optimization steps taken w.r.t. the loss function.
@@ -50,21 +114,25 @@ It updates the weights using::
  v = beta2*v + (1-beta2)*(grad**2)
  w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)
 
+Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0,
+the update is skipped.
 )code" ADD_FILELINE)
-.set_num_inputs(4)
+.set_num_inputs(5)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<AdamWParam>)
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", MPUpdateInferShape<4, 1, 5>)
+.set_attr<nnvm::FInferType>("FInferType", MPUpdateInferType<4, 1, 5>)
 .set_attr<nnvm::FMutateInputs>("FMutateInputs",
   [](const nnvm::NodeAttrs& attrs) {
     return std::vector<uint32_t>{2, 3};
   })
-.set_attr<FCompute>("FCompute<cpu>", AdamWUpdate<cpu>)
+.set_attr<FCompute>("FCompute<cpu>", MPUpdateCPU<AdamWUpdate>)
 .add_argument("weight", "NDArray-or-Symbol", "Weight")
 .add_argument("grad", "NDArray-or-Symbol", "Gradient")
 .add_argument("mean", "NDArray-or-Symbol", "Moving mean")
 .add_argument("var", "NDArray-or-Symbol", "Moving variance")
+.add_argument("rescale_grad", "NDArray-or-Symbol",
+  "Rescale gradient to rescale_grad * grad. If NaN, the update is skipped.")
 .add_arguments(AdamWParam::__FIELDS__());
 
 } // namespace op
diff --git a/src/operator/contrib/adamw.cu b/src/operator/contrib/adamw.cu
index b7452f8..e21b83b 100644
--- a/src/operator/contrib/adamw.cu
+++ b/src/operator/contrib/adamw.cu
@@ -28,8 +28,33 @@
 namespace mxnet {
 namespace op {
 
+template<template <typename xpu> class F>
+inline void MPUpdateGPU(const nnvm::NodeAttrs& attrs,
+                        const OpContext &ctx,
+                        const std::vector<TBlob> &inputs,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &outputs) {
+  // copy to cpu and check NaN value
+  TBlob scale_blob = inputs[inputs.size() - 1];
+  MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType, {
+    DType scale = 0;
+    CUDA_CALL(cudaMemcpy(&scale, scale_blob.dptr<DType>(), sizeof(DType),
+              cudaMemcpyDeviceToHost));
+    float scalef = static_cast<float>(scale);
+    if (!std::isfinite(scalef) || scalef == 0) return;
+    std::vector<TBlob> inputs_wo_scale;
+    size_t num_in = inputs.size();
+    inputs_wo_scale.reserve(num_in - 1);
+    for (size_t i = 0; i < num_in - 1; i++) inputs_wo_scale.emplace_back(inputs[i]);
+    F<gpu>::Forward(attrs, ctx, inputs_wo_scale, req, outputs, scalef);
+  });
+}
+
 NNVM_REGISTER_OP(_contrib_adamw_update)
-.set_attr<FCompute>("FCompute<gpu>", AdamWUpdate<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", MPUpdateGPU<AdamWUpdate>);
+
+NNVM_REGISTER_OP(_contrib_mp_adamw_update)
+.set_attr<FCompute>("FCompute<gpu>", MPUpdateGPU<MPAdamWUpdate>);
 
 } // namespace op
 } // namespace mxnet
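The GPU wrapper is the same idea as the CPU one; the only difference is that the
scalar lives in device memory, so it is copied back to the host with cudaMemcpy
before the finiteness check. From Python the call does not change apart from the
context. An illustrative sketch, assuming a CUDA build of MXNet and that GPU 0
is available:

    import mxnet as mx

    ctx = mx.gpu(0)
    shape = (3, 4)
    weight = mx.nd.random.uniform(shape=shape, ctx=ctx)
    grad = mx.nd.random.uniform(shape=shape, ctx=ctx)
    m, v = mx.nd.zeros(shape, ctx=ctx), mx.nd.zeros(shape, ctx=ctx)
    rescale = mx.nd.array([1.0], ctx=ctx)   # read back to the host for the NaN/0 check

    mx.nd.contrib.adamw_update(weight, grad, m, v, rescale, out=weight,
                               eta=1.0, lr=1e-3, wd=1e-2, epsilon=1e-8,
                               beta1=0.9, beta2=0.999)
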
diff --git a/tests/python/unittest/test_contrib_optimizer.py b/tests/python/unittest/test_contrib_optimizer.py
index 8ff8a7e..dad7bed 100644
--- a/tests/python/unittest/test_contrib_optimizer.py
+++ b/tests/python/unittest/test_contrib_optimizer.py
@@ -94,6 +94,90 @@ def test_group_adagrad():
                             g_stype='row_sparse',
                             compare_states=False)
 
+def test_adamw():
+    shape = (3, 4)
+    weight = mx.nd.random.uniform(shape=shape)
+    weight_ref = weight.copy()
+    grad = mx.nd.random.uniform(shape=shape)
+    m = mx.nd.random.uniform(shape=shape)
+    v = mx.nd.random.uniform(shape=shape)
+    rescale_grad = mx.nd.array([10])
+    eta, lr, wd, epsilon = 1, 1, 0, 1e-8
+    beta1, beta2 = 0.9, 0.999
+    kwargs = {'eta': eta, 'lr': lr, 'wd': wd, 'epsilon': epsilon,
+              'beta1': beta1, 'beta2': beta2}
+
+    # update is skipped for rescale = 0
+    mx.nd.contrib.adamw_update(weight, grad, m, v,
+                               rescale_grad * 0, out=weight, **kwargs)
+    # weight remains unchanged
+    mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+
+    # update is skipped for rescale = nan
+    mx.nd.contrib.adamw_update(weight, grad, m, v,
+                               rescale_grad * np.nan, out=weight, **kwargs)
+    # weight remains unchanged
+    mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+
+    # update is skipped for rescale = inf
+    mx.nd.contrib.adamw_update(weight, grad, m, v,
+                               rescale_grad * np.inf, out=weight, **kwargs)
+    # weight remains unchanged
+    mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+
+    # multi-precision update is skipped for rescale = nan
+    weight_fp16 = weight.astype('float16')
+    grad_fp16 = grad.astype('float16')
+    weight_fp16_ref = weight_fp16.copy()
+    mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
+                                  rescale_grad * np.nan, out=weight_fp16, **kwargs)
+    mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+    mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())
+
+    # multi-precision update is skipped for rescale = inf
+    mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
+                                  rescale_grad * np.inf, out=weight_fp16, **kwargs)
+    mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+    mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())
+
+    # multi-precision update is skipped for rescale = 0
+    mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
+                                  rescale_grad * 0, out=weight_fp16, **kwargs)
+    mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+    mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())
+
+    # reference normal update
+    grad_rescale = rescale_grad * grad
+    m_ref = beta1*m + (1-beta1)*grad_rescale
+    v_ref = beta2*v + (1-beta2)*(grad_rescale**2)
+    weight_ref = weight - eta * (1 * m_ref / (v_ref.sqrt() + epsilon) + weight * wd)
+    m_test = m.copy()
+    v_test = v.copy()
+    weight_test = weight.copy()
+    # op normal update
+    mx.nd.contrib.adamw_update(weight_test, grad, m_test, v_test,
+                               rescale_grad, out=weight_test, **kwargs)
+    mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight_test.asnumpy())
+    mx.test_utils.assert_almost_equal(m_ref.asnumpy(), m_test.asnumpy())
+    mx.test_utils.assert_almost_equal(v_ref.asnumpy(), v_test.asnumpy())
+
+    # reference normal multi-precision update
+    m_fp32 = m.copy()
+    v_fp32 = v.copy()
+    weight_fp32 = weight.copy()
+    grad_rescale = rescale_grad * grad_fp16.astype('float32')
+    m_ref = beta1*m_fp32 + (1-beta1)*grad_rescale
+    v_ref = beta2*v_fp32 + (1-beta2)*(grad_rescale**2)
+    weight_ref = weight - eta * (1 * m_ref / (v_ref.sqrt() + epsilon) + weight * wd)
+    weight_fp16_ref = weight_ref.astype('float16')
+    # op normal multi-precision update
+    mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m_fp32, v_fp32, weight_fp32,
+                                  rescale_grad, out=weight_fp16, **kwargs)
+    mx.test_utils.assert_almost_equal(m_ref.asnumpy(), m_fp32.asnumpy())
+    mx.test_utils.assert_almost_equal(v_ref.asnumpy(), v_fp32.asnumpy())
+    mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight_fp32.asnumpy())
+    mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())
+
 
 if __name__ == '__main__':
     import nose
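To run just the new test locally, something along these lines should work (an
assumed invocation, not part of the patch; it presumes nose is installed and the
command is issued from the repository root):

    import nose

    # select only the new test from the contrib optimizer test module
    nose.run(argv=['nosetests', '--verbose',
                   'tests/python/unittest/test_contrib_optimizer.py:test_adamw'])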