eric-haibin-lin commented on a change in pull request #14568: NAG Optimizer with multi-precision support URL: https://github.com/apache/incubator-mxnet/pull/14568#discussion_r274754817
########## File path: python/mxnet/optimizer/optimizer.py ########## @@ -1051,33 +1051,61 @@ def __init__(self, momentum=0.0, **kwargs): super(NAG, self).__init__(**kwargs) self.momentum = momentum + def create_state_multi_precision(self, index, weight): + weight_master_copy = None + if self.multi_precision and weight.dtype == numpy.float16: + weight_master_copy = weight.astype(numpy.float32) + return (self.create_state(index, weight_master_copy), weight_master_copy) + if weight.dtype == numpy.float16 and not self.multi_precision: + warnings.warn("Accumulating with float16 in optimizer can lead to " + "poor accuracy or slow convergence. " + "Consider using multi_precision=True option of the " + "NAG optimizer") + return self.create_state(index, weight) + def create_state(self, index, weight): momentum = None if self.momentum != 0.0: momentum = zeros(weight.shape, weight.context, dtype=weight.dtype) return momentum - def update(self, index, weight, grad, state): + def _update_impl(self, index, weight, grad, state, multi_precision=False): assert(isinstance(weight, NDArray)) assert(isinstance(grad, NDArray)) self._update_count(index) lr = self._get_lr(index) wd = self._get_wd(index) - grad = grad * self.rescale_grad - if self.clip_gradient is not None: - grad = clip(grad, -self.clip_gradient, self.clip_gradient) + kwargs = {'rescale_grad': self.rescale_grad} + if self.momentum > 0: + kwargs['momentum'] = self.momentum + if self.clip_gradient: + kwargs['clip_gradient'] = self.clip_gradient - if state is not None: - mom = state - mom[:] *= self.momentum - mom[:] += grad - mom[:] += wd * weight - grad[:] += self.momentum * mom - weight[:] -= lr * grad + if not multi_precision: + if state is not None: + nag_mom_update(weight, grad, state, out=weight, lr=lr, wd=wd, **kwargs) + else: + nag_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs) + else: + if state[0] is not None: + mp_nag_mom_update(weight, grad, state[0], state[1], out=weight, + lr=lr, wd=wd, **kwargs) + 
else: + mp_nag_update(weight, grad, state[1], out=weight, + lr=lr, wd=wd, **kwargs) + + def update(self, index, weight, grad, state): + self._update_impl(index, weight, grad, state, multi_precision=False) + + def update_multi_precision(self, index, weight, grad, state): + if not isinstance(index, (tuple, list)): Review comment: Are you supporting list/tuple of indices/weights for this optimizer? I see why this makes sense for SGD because it implements an op for list of weights to update, but I don't see that with the NAG op here ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org With regards, Apache Git Services