eric-haibin-lin commented on a change in pull request #14568: NAG Optimizer with multi-precision support
URL: https://github.com/apache/incubator-mxnet/pull/14568#discussion_r274754817
 
 

 ##########
 File path: python/mxnet/optimizer/optimizer.py
 ##########
 @@ -1051,33 +1051,61 @@ def __init__(self, momentum=0.0, **kwargs):
         super(NAG, self).__init__(**kwargs)
         self.momentum = momentum
 
+    def create_state_multi_precision(self, index, weight):
+        weight_master_copy = None
+        if self.multi_precision and weight.dtype == numpy.float16:
+            weight_master_copy = weight.astype(numpy.float32)
+            return (self.create_state(index, weight_master_copy), weight_master_copy)
+        if weight.dtype == numpy.float16 and not self.multi_precision:
+            warnings.warn("Accumulating with float16 in optimizer can lead to "
+                          "poor accuracy or slow convergence. "
+                          "Consider using multi_precision=True option of the "
+                          "NAG optimizer")
+        return self.create_state(index, weight)
+
     def create_state(self, index, weight):
         momentum = None
         if self.momentum != 0.0:
             momentum = zeros(weight.shape, weight.context, dtype=weight.dtype)
         return momentum
 
-    def update(self, index, weight, grad, state):
+    def _update_impl(self, index, weight, grad, state, multi_precision=False):
         assert(isinstance(weight, NDArray))
         assert(isinstance(grad, NDArray))
         self._update_count(index)
         lr = self._get_lr(index)
         wd = self._get_wd(index)
 
-        grad = grad * self.rescale_grad
-        if self.clip_gradient is not None:
-            grad = clip(grad, -self.clip_gradient, self.clip_gradient)
+        kwargs = {'rescale_grad': self.rescale_grad}
+        if self.momentum > 0:
+            kwargs['momentum'] = self.momentum
+        if self.clip_gradient:
+            kwargs['clip_gradient'] = self.clip_gradient
 
-        if state is not None:
-            mom = state
-            mom[:] *= self.momentum
-            mom[:] += grad
-            mom[:] += wd * weight
-            grad[:] += self.momentum * mom
-            weight[:] -= lr * grad
+        if not multi_precision:
+            if state is not None:
+                nag_mom_update(weight, grad, state, out=weight, lr=lr, wd=wd, **kwargs)
+            else:
+                nag_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)
+        else:
+            if state[0] is not None:
+                mp_nag_mom_update(weight, grad, state[0], state[1], out=weight,
+                                  lr=lr, wd=wd, **kwargs)
+            else:
+                mp_nag_update(weight, grad, state[1], out=weight,
+                              lr=lr, wd=wd, **kwargs)
+
+    def update(self, index, weight, grad, state):
+        self._update_impl(index, weight, grad, state, multi_precision=False)
+
+    def update_multi_precision(self, index, weight, grad, state):
+        if not isinstance(index, (tuple, list)):
 
 Review comment:
   Are you supporting a list/tuple of indices/weights for this optimizer? I see why this makes sense for SGD, because it implements an op that updates a list of weights, but I don't see an equivalent for the NAG op here.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
