diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index eb33f9b5217..935bd9ab182 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -435,6 +435,90 @@ def test_nag():
                             compare_optimizer(opt1(**kwarg), opt2(**kwarg), 
shape, dtype)
class PySGLD(mx.optimizer.Optimizer):
    """Pure-Python reference implementation of Stochastic Gradient Langevin Dynamics.

    Each update moves the weight by half a (weight-decayed) gradient step and
    then perturbs it with Gaussian noise whose standard deviation is
    ``sqrt(lr)``. No per-parameter optimizer state is kept.
    """

    def __init__(self, **kwargs):
        super(PySGLD, self).__init__(**kwargs)

    def create_state(self, index, weight):
        # SGLD is stateless: nothing to allocate per parameter.
        return None

    def update(self, index, weight, grad, state):
        """Apply one SGLD step to ``weight`` in place.

        Parameters
        ----------
        index : parameter key used for lr/wd lookup and update counting.
        weight : mx.nd.NDArray, modified in place.
        grad : mx.nd.NDArray, raw gradient (rescaled and optionally clipped here).
        state : unused; always ``None`` (see ``create_state``).
        """
        assert isinstance(weight, mx.nd.NDArray)
        assert isinstance(grad, mx.nd.NDArray)
        self._update_count(index)
        step_size = self._get_lr(index)
        decay = self._get_wd(index)

        scaled_grad = grad * self.rescale_grad
        if self.clip_gradient is not None:
            bound = self.clip_gradient
            scaled_grad = mx.nd.clip(scaled_grad, -bound, bound)

        # Deterministic drift: half a gradient step on the L2-regularized objective.
        drift = -step_size / 2 * (scaled_grad + decay * weight)
        # Exactly one RNG draw per update, so seeding right before update()
        # makes the injected noise reproducible.
        noise = mx.random.normal(0, math.sqrt(step_size), shape=weight.shape,
                                 dtype=weight.dtype, ctx=weight.context)
        weight[:] += drift + noise
def test_sgld():
    """Compare mx.optimizer.SGLD with the pure-Python PySGLD reference.

    The sweep covers two noise seeds, three dtypes, gradient clipping, weight
    decay and the multi_precision flag. For every combination both optimizers
    receive identical weights and gradients. The RNG is re-seeded right before
    each update so both draw the same Gaussian noise.
    """
    opt1 = PySGLD
    opt2 = mx.optimizer.SGLD
    shape = (3, 4, 5)
    ns_options = [1234, 42]
    cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
    wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
    mp_options = [{}, {'multi_precision': False}, {'multi_precision': True}]

    def make_array_pair(shape, stype, dtype, **sparse_kwargs):
        """Create a random array of storage type *stype* plus a dense copy.

        Returns ``(dense_copy, array)``. ``array`` is fed to opt2 and
        ``dense_copy`` to the Python reference opt1. ``sparse_kwargs`` are
        forwarded to ``rand_ndarray`` and apply to sparse storage types only.
        Raises ``Exception`` for unsupported storage types.
        """
        if stype == 'default':
            array = mx.random.uniform(shape=shape, ctx=default_context(),
                                      dtype=dtype)
            return array.copyto(default_context()), array
        if stype in ('row_sparse', 'csr'):
            array = rand_ndarray(shape, stype, dtype=dtype, **sparse_kwargs)
            return array.copyto(default_context()).tostype('default'), array
        raise Exception("type not supported yet")

    def compare_optimizer_noise_seeded(opt1, opt2, shape, dtype, noise_seed,
                                       w_stype='default', g_stype='default',
                                       rtol=1e-4, atol=1e-5,
                                       compare_states=True):
        """Run one update of opt1 and opt2 and assert their results match.

        The RNG is seeded with *noise_seed* immediately before each
        optimizer's update, so the random noise in the SGLD update is the
        same for opt1 and opt2.
        """
        # Weights use density=1 when sparse; gradients use rand_ndarray's default.
        w1, w2 = make_array_pair(shape, w_stype, dtype, density=1)
        g1, g2 = make_array_pair(shape, g_stype, dtype)
        state1 = opt1.create_state_multi_precision(0, w1)
        state2 = opt2.create_state_multi_precision(0, w2)
        if compare_states:
            compare_ndarray_tuple(state1, state2)
        # Re-seed before each update so both optimizers draw identical noise.
        mx.random.seed(noise_seed)
        opt1.update_multi_precision(0, w1, g1, state1)
        mx.random.seed(noise_seed)
        opt2.update_multi_precision(0, w2, g2, state2)
        if compare_states:
            compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol)
        assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol)

    for seed in ns_options:
        for dtype in [np.float16, np.float32, np.float64]:
            # float16 resolves only ~3 decimal digits (eps ~ 9.8e-4). The
            # default rtol=1e-4/atol=1e-5 would turn a single rounding
            # difference between the two implementations into a failure.
            tol = {'rtol': 1e-3, 'atol': 1e-3} if dtype == np.float16 else {}
            for params in itertools.product(cg_options, wd_options, mp_options):
                kwarg = {k: v for param in params for k, v in param.items()}
                # float16 is only tested with multi_precision=True.
                if dtype == np.float16 and not kwarg.get('multi_precision', False):
                    continue
                compare_optimizer_noise_seeded(opt1(**kwarg), opt2(**kwarg),
                                               shape, dtype, seed, **tol)

With regards,
Apache Git Services

Reply via email to