This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 649a0a1  CUDNN not training on backward pass similar to pytorch (#10470)
649a0a1 is described below

commit 649a0a16b91e4c81ec911ecc3be0439d9ebc2a52
Author: Chris Olivier <cjolivie...@gmail.com>
AuthorDate: Mon Apr 9 14:43:53 2018 -0700

    CUDNN not training on backward pass similar to pytorch (#10470)
    
    * CUDNN not training on backward pass similar to pytorch https://github.com/pytorch/pytorch/issues/4284
    
    * add test
    
    * add seed decorator
    
    * Trigger
---
 src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 12 ++++++------
 tests/python/gpu/test_operator_gpu.py        | 20 +++++++++++++++++++-
 2 files changed, 25 insertions(+), 7 deletions(-)
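
For context: the CHECK removed below made CuDNNBatchNorm abort whenever a
backward pass was requested outside training mode or with use_global_stats
set. A minimal sketch of the user-facing scenario this unblocks (not from
the commit; a GPU build is assumed, and the shapes and running statistics
are illustrative only):

    import mxnet as mx
    from mxnet import autograd

    ctx = mx.gpu(0)
    x = mx.nd.random.uniform(shape=(4, 3, 2, 2), ctx=ctx)
    gamma = mx.nd.ones((3,), ctx=ctx)
    beta = mx.nd.zeros((3,), ctx=ctx)
    running_mean = mx.nd.zeros((3,), ctx=ctx)
    running_var = mx.nd.ones((3,), ctx=ctx)
    x.attach_grad()

    # Inference-mode forward: BatchNorm normalizes with the running
    # statistics rather than the batch statistics.
    with autograd.record(train_mode=False):
        y = mx.nd.BatchNorm(x, gamma, beta, running_mean, running_var,
                            fix_gamma=False, use_global_stats=True)
        loss = y.sum()

    # Previously this aborted on the cuDNN path; now it populates x.grad.
    loss.backward(train_mode=False)
    print(x.grad.shape)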

diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
index e3d5dd9..d4b9f84 100644
--- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
+++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
@@ -165,8 +165,6 @@ class CuDNNBatchNormOp {
     using namespace mshadow::expr;
     CHECK_EQ(inputs.size(), 8U);
     CHECK_EQ(outputs.size(), 3U);
-    CHECK(ctx.is_train && !param_.use_global_stats)
-        << "use global statistics is not yet supported in CuDNNBatchNorm";
 
     // Rename the inputs and outputs.
     const TBlob &out_grad = inputs[0];
@@ -183,6 +181,8 @@ class CuDNNBatchNormOp {
       in_grad[cudnnbatchnorm::kData].get_with_shape<gpu, 4, DType>(shape_, s);
     Tensor<gpu, 4, DType> dy = out_grad.get_with_shape<gpu, 4, DType>(shape_, s);
 
+    const bool global_stats = !ctx.is_train || param_.use_global_stats;
+
 #if CUDNN_VERSION >= 4007
 #if CUDNN_VERSION >= 7002
     auto mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
@@ -226,8 +226,8 @@ class CuDNNBatchNormOp {
         dgamma.dptr_,
         dbeta.dptr_,
         param_.eps,
-        save_mean.dptr_,
-        save_inv_var.dptr_));
+        global_stats ? nullptr : save_mean.dptr_,
+        global_stats ? nullptr : save_inv_var.dptr_));
       if (param_.fix_gamma) dgamma = 0.f;
     })
 #else  // CUDNN_VERSION < 4007
@@ -264,8 +264,8 @@ class CuDNNBatchNormOp {
                                                  dgamma.dptr_,
                                                  dbeta.dptr_,
                                                  param_.eps,
-                                                 save_mean.dptr_,
-                                                 save_inv_var.dptr_));
+                                                 global_stats ? nullptr : save_mean.dptr_,
+                                                 global_stats ? nullptr : save_inv_var.dptr_));
       if (param_.fix_gamma) dgamma = 0.f;
     })
 #endif
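
A note on the nullptr arguments: cuDNN documents savedMean/savedInvVariance
of cudnnBatchNormalizationBackward as an optional cache written by the
training-mode forward call; the two must either both be valid or both be
NULL, and with NULL cuDNN recomputes the batch statistics from the input.
An inference-mode forward never fills save_mean/save_inv_var, so passing
those buffers would hand cuDNN stale data; withholding them is the safe
choice. The guard as a runnable Python sketch (the function and argument
names are illustrative, not MXNet API):

    # Mirrors the C++ guard above: only a training-mode forward without
    # use_global_stats leaves valid saved statistics for the backward pass.
    def saved_stats(is_train, use_global_stats, save_mean, save_inv_var):
        global_stats = (not is_train) or use_global_stats
        if global_stats:
            return None, None  # NULL: cuDNN recomputes statistics from x
        return save_mean, save_inv_var

    print(saved_stats(False, False, 'mean_buf', 'inv_var_buf'))  # (None, None)
    print(saved_stats(True, False, 'mean_buf', 'inv_var_buf'))   # cache passed
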
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 85b3e26..08c749e 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -27,6 +27,7 @@ import unittest
 from nose.tools import assert_raises
 from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal
 from mxnet.base import MXNetError
+from mxnet import autograd
 from numpy.testing import assert_allclose
 
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
@@ -1815,7 +1816,24 @@ def test_incorrect_gpu():
     # Try setting dev_id to a really big number
     assert_raises(MXNetError, mx.nd.ones, (2,2), ctx=mx.gpu(100001))
 
+@with_seed()
+def test_batchnorm_backwards_notrain():
+    for ctx in [mx.cpu(0), mx.gpu(0)]:
+        for cudnn_o in [False, True]:
+            B, C, H, W = 4, 3, 2, 2
+            x = mx.nd.random.poisson(1, shape=(B, C, H, W)).as_in_context(ctx)
+            gamma = mx.nd.random.normal(shape=(C,)).as_in_context(ctx)
+            beta = mx.nd.random.normal(shape=(C,)).as_in_context(ctx)
+            mean = mx.nd.random.normal(shape=(C,)).as_in_context(ctx)
+            std = mx.nd.random.normal(shape=(C,)).as_in_context(ctx)
+            x.attach_grad()
+
+            with autograd.record(False):
+                y = mx.ndarray.BatchNorm(x, gamma, beta, mean, std.square(),
+                                         fix_gamma=False, cudnn_off=cudnn_o)
+                loss = y.square().sum()
+            loss.backward(train_mode=False)
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
-

-- 
To stop receiving notification emails like this one, please contact
j...@apache.org.
