sxjscience commented on a change in pull request #16715: Lamb optimizer update
URL: https://github.com/apache/incubator-mxnet/pull/16715#discussion_r349947004
 
 

 ##########
 File path: src/operator/optimizer_op-inl.h
 ##########
 @@ -1563,6 +1563,192 @@ inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
   }
 }
 
+struct LambUpdatePhaseOneParam : public dmlc::Parameter<LambUpdatePhaseOneParam> {
+    float beta1;
+    float beta2;
+    float epsilon;
+    float t;
+    bool bias_correction;
+    float wd;
+    float rescale_grad;
+    float clip_gradient;
+    DMLC_DECLARE_PARAMETER(LambUpdatePhaseOneParam) {
+      DMLC_DECLARE_FIELD(beta1)
+      .set_default(0.9f)
+      .describe("The decay rate for the 1st moment estimates.");
+      DMLC_DECLARE_FIELD(beta2)
+      .set_default(0.999f)
+      .describe("The decay rate for the 2nd moment estimates.");
+      DMLC_DECLARE_FIELD(epsilon)
+      .set_default(1e-6f)
+      .describe("A small constant for numerical stability.");
+      DMLC_DECLARE_FIELD(t)
+      .describe("Index update count.");
+      DMLC_DECLARE_FIELD(bias_correction)
+      .set_default(true)
+      .describe("Whether to use bias correction.");
+      DMLC_DECLARE_FIELD(wd)
+      .describe("Weight decay augments the objective function with a "
+                "regularization term that penalizes large weights. "
+                "The penalty scales with the square of the magnitude of each 
weight.");
+      DMLC_DECLARE_FIELD(rescale_grad)
+      .set_default(1.0f)
+      .describe("Rescale gradient to grad = rescale_grad*grad.");
+      DMLC_DECLARE_FIELD(clip_gradient)
+      .set_default(-1.0f)
+      .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] 
"
+                "If clip_gradient <= 0, gradient clipping is turned off. "
+                "grad = max(min(grad, clip_gradient), -clip_gradient).");
+    }
+};
+
+struct LambUpdatePhaseTwoParam : public dmlc::Parameter<LambUpdatePhaseTwoParam> {
+    float lr;
+    float lower_bound;
+    float upper_bound;
+    DMLC_DECLARE_PARAMETER(LambUpdatePhaseTwoParam) {
+      DMLC_DECLARE_FIELD(lr)
+      .describe("Learning rate");
+      DMLC_DECLARE_FIELD(lower_bound)
+      .set_default(-1.0f)
+      .describe("Lower limit of norm of weight. If lower_bound <= 0, Lower 
limit is not set");
+      DMLC_DECLARE_FIELD(upper_bound)
+      .set_default(-1.0f)
+      .describe("Upper limit of norm of weight. If upper_bound <= 0, Upper 
limit is not set");
+    }
+};
+
+struct LambUpdatePhaseOneKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data,
+    DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
+    const DType clip_gradient, const DType rescale_grad,
+    const DType beta1, const DType beta2, const DType wd,
+    const DType epsilon, const DType t,
+    bool bias_correction, const OpReqType req) {
+    using namespace mshadow_op;
+
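+    // Rescale the raw gradient and, if clip_gradient is set, clip it.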
+    DType grad_rescaled = grad_data[i] * rescale_grad;
+    if (clip_gradient >= 0.f) {
+      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+    }
+
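+    // Update the biased first (mean) and second (variance) moment estimates.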
+    mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
+    var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled;
+
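+    // Adam-style update direction plus the decoupled weight-decay term.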
+    DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i];
+
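+    // Correct the zero-initialization bias of the moment estimates.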
+    if (bias_correction) {
+      DType mean_hat = mean_data[i] / (1.f - power::Map(beta1, t));
+      DType var_hat = var_data[i] / (1.f - power::Map(beta2, t));
 
 Review comment:
   Actually, in apex, the power is calculated in float32 and the result is then cast back to float16:
   
https://github.com/NVIDIA/apex/blob/325f5a0bec542701edba1628ad34f3b2ea47c556/csrc/multi_tensor_lamb.cu#L231-L249
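   
   For reference, a minimal sketch of that pattern (hypothetical helper, not the actual apex code): evaluate the scalar pow in float32 even when DType is float16, and only cast back at the end:
   
       #include <cmath>
   
       // Sketch only: the bias-correction factor is computed in float32,
       // then the corrected moment is cast back to the parameter dtype.
       template <typename DType>
       DType BiasCorrect(DType moment, float beta, float t) {
         float correction = 1.f - std::pow(beta, t);  // float32 pow
         return static_cast<DType>(static_cast<float>(moment) / correction);
       }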

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services