This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 024b5a9  [MXNET-11241] Avoid use of troublesome cudnnFind() results when grad_req='add' (#11338)
024b5a9 is described below

commit 024b5a916dd3a39a39031ce5e6565cd7d9d60fe2
Author: Dick Carter <dick.car...@comcast.net>
AuthorDate: Mon Jul 30 13:34:34 2018 -0700

    [MXNET-11241] Avoid use of troublesome cudnnFind() results when grad_req='add' (#11338)
    
    * Add tests that fail due to issue 11241
    
    * Fix #11241 Conv1D throws CUDNN_STATUS_EXECUTION_FAILED
    
    * Force algo 1 when grad_req==add with large c.  Expand tests.
    
    * Shorten test runtimes.
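    
    A minimal repro sketch for reviewers (not part of this commit; it mirrors
    the new tests below and assumes a CUDA/cuDNN build of MXNet with a GPU
    available). Before this fix, the backward pass could fail with
    CUDNN_STATUS_EXECUTION_FAILED, because grad_req='add' makes the wgrad
    kernel run with beta == 1.0f, a case the cudnnFind()-selected algo could
    not handle for large c:
    
        import mxnet as mx
    
        c = 64 * 1024  # the 'large c' that trips cudnnFind() in cudnn v7.1.4
        sym = mx.sym.Convolution(layout='NCW', num_filter=8, kernel=(2,), name='conv')
        # grad_req='add' accumulates into existing gradient buffers (beta == 1.0f).
        exe = sym.simple_bind(ctx=mx.gpu(0), conv_data=(1, c, 4), grad_req='add')
        for arr in exe.arg_dict.values():
            arr[:] = 1.0  # give data/weight/bias defined values
        exe.forward(is_train=True)
        exe.backward(mx.nd.ones_like(exe.outputs[0]))
        mx.nd.waitall()  # surfaces the asynchronous cuDNN error, if any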
---
 src/operator/nn/convolution.cu                  | 20 ++++++---
 src/operator/nn/cudnn/cudnn_algoreg-inl.h       | 11 +++--
 src/operator/nn/cudnn/cudnn_convolution-inl.h   | 36 +++++++++++++--
 src/operator/nn/cudnn/cudnn_deconvolution-inl.h | 38 ++++++++++++++--
 src/operator/nn/deconvolution.cu                | 20 ++++++---
 src/operator/operator_common.h                  |  2 +-
 tests/python/gpu/test_operator_gpu.py           | 59 +++++++++++++++++++++++++
 7 files changed, 162 insertions(+), 24 deletions(-)

diff --git a/src/operator/nn/convolution.cu b/src/operator/nn/convolution.cu
index 797557e..daccc55 100644
--- a/src/operator/nn/convolution.cu
+++ b/src/operator/nn/convolution.cu
@@ -41,7 +41,8 @@ static CuDNNConvolutionOp<DType>& GetCuDNNConvOp(const ConvolutionParam& param,
                                                  int backward_compute_type,
                                                  const std::vector<TShape>& in_shape,
                                                  const std::vector<TShape>& out_shape,
-                                                 const RunContext& rctx) {
+                                                 const RunContext& rctx,
+                                                 bool add_to_weight) {
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local std::unordered_map<ConvSignature,
                                          std::shared_ptr<CuDNNConvolutionOp<DType> >,
@@ -57,14 +58,18 @@ static CuDNNConvolutionOp<DType>& GetCuDNNConvOp(const ConvolutionParam& param,
     ndim += s.ndim();
   for (auto &s : out_shape)
     ndim += s.ndim();
-  key.Reserve(1 /* for forward_compute_type */ + 1 /* for backward_compute_type */
-              + ndim + 1 /* for dev_id */);
+  key.Reserve(1 /* for forward_compute_type */ +
+              1 /* for backward_compute_type */ +
+              ndim /* for in and out shapes */ +
+              1 /* for dev_id */ +
+              1 /* for add_to_weight */);
 
   key.AddSign(forward_compute_type);
   key.AddSign(backward_compute_type);
   key.AddSign(in_shape);
   key.AddSign(out_shape);
   key.AddSign(rctx.ctx.dev_id);
+  key.AddSign(add_to_weight ? 1 : 0);
 
   auto it = ops.find(key);
   if (it == ops.end()) {
@@ -74,7 +79,7 @@ static CuDNNConvolutionOp<DType>& GetCuDNNConvOp(const ConvolutionParam& param,
     CHECK(ins_ret.second);
     it = ins_ret.first;
     it->second->Init(param, forward_compute_type, backward_compute_type, in_shape,
-                     out_shape, rctx);
+                     out_shape, rctx, add_to_weight);
   }
   return *it->second;
 }
@@ -141,8 +146,10 @@ void ConvolutionCompute<gpu>(const nnvm::NodeAttrs& attrs,
       std::vector<TShape> out_shape(1, outputs[0].shape_);
       for (size_t i = 0; i < in_shape.size(); i++)
         in_shape[i] = inputs[i].shape_;
+      // req[conv::kWeight] is only set for backward, so assume the typical 'write' for now.
+      auto add_to_weight = false;
       CuDNNConvolutionOp<DType> &op = GetCuDNNConvOp<DType>(param,
-          compute_type, compute_type, in_shape, out_shape, ctx.run_ctx);
+          compute_type, compute_type, in_shape, out_shape, ctx.run_ctx, add_to_weight);
       op.Forward(ctx, inputs, req, outputs);
     }
   })
@@ -220,8 +227,9 @@ void ConvolutionGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
       std::vector<TShape> out_shape(1, out_grad.shape_);
       for (size_t i = 0; i < in_shape.size(); i++)
         in_shape[i] = in_data[i].shape_;
+      auto add_to_weight = req[conv::kWeight] == kAddTo;
       CuDNNConvolutionOp<DType> &op = GetCuDNNConvOp<DType>(param,
-          compute_type, compute_type, in_shape, out_shape, ctx.run_ctx);
+          compute_type, compute_type, in_shape, out_shape, ctx.run_ctx, add_to_weight);
       op.Backward(ctx, std::vector<TBlob>{out_grad}, in_data, req, in_grad);
     }
   })
diff --git a/src/operator/nn/cudnn/cudnn_algoreg-inl.h b/src/operator/nn/cudnn/cudnn_algoreg-inl.h
index e029c83..3b59fd1 100644
--- a/src/operator/nn/cudnn/cudnn_algoreg-inl.h
+++ b/src/operator/nn/cudnn/cudnn_algoreg-inl.h
@@ -72,12 +72,13 @@ class CuDNNAlgoReg {
             cudnnDataType_t cudnn_forward_compute_type,
             cudnnDataType_t cudnn_backward_compute_type,
             int sm_arch,
+            bool add_to_weight,
             CuDNNAlgo<cudnnConvolutionFwdAlgo_t> *fwd,
             CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> *bwd,
             CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> *flt) {
     CHECK(in_shape.size() == 2 || in_shape.size() == 3);
     ParamKey key{param, in_shape[0], in_shape[1], out_shape[0], cudnn_data_type,
-                 cudnn_forward_compute_type, cudnn_backward_compute_type, sm_arch};
+                 cudnn_forward_compute_type, cudnn_backward_compute_type, sm_arch, add_to_weight};
     std::lock_guard<std::mutex> guard(lock_);
     auto i = reg_.find(key);
     if (i != reg_.end()) {
@@ -96,12 +97,13 @@ class CuDNNAlgoReg {
                 cudnnDataType_t cudnn_forward_compute_type,
                 cudnnDataType_t cudnn_backward_compute_type,
                 int sm_arch,
+                bool add_to_weight,
                 const CuDNNAlgo<cudnnConvolutionFwdAlgo_t> &fwd,
                 const CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> &bwd,
                 const CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> &flt) {
     CHECK(in_shape.size() == 2 || in_shape.size() == 3);
     ParamKey key{param, in_shape[0], in_shape[1], out_shape[0], cudnn_data_type,
-                 cudnn_forward_compute_type, cudnn_backward_compute_type, sm_arch};
+                 cudnn_forward_compute_type, cudnn_backward_compute_type, sm_arch, add_to_weight};
     std::lock_guard<std::mutex> guard(lock_);
     if (param.cudnn_tune.value() && reg_.size() % 50 == 0) {
       LOG(INFO) << "Running performance tests to find the best convolution "
@@ -140,6 +142,7 @@ class CuDNNAlgoReg {
     cudnnDataType_t cudnn_forward_compute_type;
     cudnnDataType_t cudnn_backward_compute_type;
     int sm_arch;
+    bool add_to_weight;
 
     bool operator==(const ParamKey& other) const {
       return this->param == other.param &&
@@ -149,7 +152,8 @@ class CuDNNAlgoReg {
              this->cudnn_data_type == other.cudnn_data_type &&
             this->cudnn_forward_compute_type == other.cudnn_forward_compute_type &&
             this->cudnn_backward_compute_type == other.cudnn_backward_compute_type &&
-             this->sm_arch == other.sm_arch;
+             this->sm_arch == other.sm_arch &&
+             this->add_to_weight == other.add_to_weight;
     }
   };
 
@@ -164,6 +168,7 @@ class CuDNNAlgoReg {
       ret = dmlc::HashCombine(ret, static_cast<int>(key.cudnn_forward_compute_type));
       ret = dmlc::HashCombine(ret, static_cast<int>(key.cudnn_backward_compute_type));
       ret = dmlc::HashCombine(ret, key.sm_arch);
+      ret = dmlc::HashCombine(ret, key.add_to_weight);
       return ret;
     }
   };
diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h
index 4b1cbbe..827c89f 100644
--- a/src/operator/nn/cudnn/cudnn_convolution-inl.h
+++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h
@@ -59,9 +59,11 @@ class CuDNNConvolutionOp {
             int backward_compute_type,
             const std::vector<TShape>& in_shape,
             const std::vector<TShape>& out_shape,
-            const RunContext& rctx) {
+            const RunContext& rctx,
+            bool add_to_weight) {
     using namespace mshadow;
     this->param_ = param;
+    this->add_to_weight_ = add_to_weight;
     InitBufferForParam();
     auto cudnn_forward_compute_type = convertToCuDNNDataType(forward_compute_type);
     auto cudnn_backward_compute_type = convertToCuDNNDataType(backward_compute_type);
@@ -247,6 +249,7 @@ class CuDNNConvolutionOp {
                                             gbias.dptr_));
     }
     if (req[conv::kWeight] != kNullOp) {
+        CHECK_EQ(add_to_weight_, req[conv::kWeight] == kAddTo);
         CUDNN_CALL(cudnnConvolutionBackwardFilter(s->dnn_handle_,
             &alpha,
             in_desc_,
@@ -610,8 +613,8 @@ class CuDNNConvolutionOp {
                   cudnnDataType_t cudnn_backward_compute_type) {
     if (!CuDNNConvAlgoReg::Get()->Find(param_, in_shape, out_shape, dtype_,
                                        cudnn_forward_compute_type, cudnn_backward_compute_type,
-                                       SMArch(rctx.ctx.dev_id), &forward_algo_, &back_algo_,
-                                       &back_algo_w_)) {
+                                       SMArch(rctx.ctx.dev_id), add_to_weight_,
+                                       &forward_algo_, &back_algo_, &back_algo_w_)) {
       mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
       CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
       size_t workspace_byte = static_cast<size_t>(param_.workspace * sizeof(DType));
@@ -645,6 +648,8 @@ class CuDNNConvolutionOp {
       auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_);
       std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filt_results(max_bwd_filt_algos);
       int actual_bwd_filter_algos = 0;
+      // In cudnn v7.1.4, find() returned wgrad algos that could fail for large c if we
+      // were summing into the output (i.e. beta != 0).  Get() returned OK algos though.
       auto bwd_filter_algo_discoverer =
        param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardFilterAlgorithm_v7
                                                : cudnnFindConvolutionBackwardFilterAlgorithm;
@@ -792,6 +797,13 @@ class CuDNNConvolutionOp {
         }
       }
       #endif  // CUDNN_MAJOR < 7
+
+      // Fix for issue #11241
+      int cudnn_find_issue_max_features = 64 * 1024;
+      if (add_to_weight_ && Features(in_shape[conv::kData]) >= cudnn_find_issue_max_features) {
+        this->back_algo_w_.Set(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true);
+      }
+
       // An algo specification by the user may be cached here, but another
       // convolution will match only if identically specified.
       // We're caching results of *Get* as well as *Find*, but these records
@@ -799,7 +811,8 @@ class CuDNNConvolutionOp {
       CuDNNConvAlgoReg::Get()->Register(param_, in_shape, out_shape, dtype_,
                                         cudnn_forward_compute_type,
                                         cudnn_backward_compute_type,
-                                        SMArch(rctx.ctx.dev_id), this->forward_algo_,
+                                        SMArch(rctx.ctx.dev_id), this->add_to_weight_,
+                                        this->forward_algo_,
                                         this->back_algo_, this->back_algo_w_);
     }
     // If we're allowing Tensor Core variants of the algos to be considered in
@@ -921,6 +934,19 @@ class CuDNNConvolutionOp {
     return tensor.MSize() * sizeof(DType);
   }
 
+  // Given a tensor shape of this operation, return the number of features 'c'
+  int64_t Features(const TShape &dshape) {
+    int c = 0;
+    switch (dshape.ndim()) {
+      case 3: c = ConvertLayout(dshape.get<3>(), param_.layout.value(), kNCW)[1]; break;
+      case 4: c = ConvertLayout(dshape.get<4>(), param_.layout.value(), kNCHW)[1]; break;
+      case 5: c = ConvertLayout(dshape.get<5>(), param_.layout.value(), kNCDHW)[1]; break;
+      default:
+        LOG(FATAL) << "Unexpected convolution data dimension " << dshape.ndim();
+    }
+    return c;
+  }
+
   std::vector<int> param_stride_;
   std::vector<int> param_dilate_;
   std::vector<int> param_pad_;
@@ -953,6 +979,8 @@ class CuDNNConvolutionOp {
   cudnnTensorFormat_t format_;
   // Allow TensorCore algo policy
   bool cudnn_tensor_core_;
+  // Is req[kWeight] == conv::kAddTo ?
+  bool add_to_weight_;
   ConvolutionParam param_;
 };
 #endif  // __CUDACC__ && CUDNN
diff --git a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h
index cb0de4c..f1b40cc 100644
--- a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h
+++ b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h
@@ -56,9 +56,11 @@ class CuDNNDeconvolutionOp {
             int backward_compute_type,
             const std::vector<TShape>& in_shape,
             const std::vector<TShape>& out_shape,
-            const RunContext& rctx) {
+            const RunContext& rctx,
+            bool add_to_weight) {
     using namespace mshadow;
     this->param_ = param;
+    this->add_to_weight_ = add_to_weight;
     InitBufferForParam();
     auto cudnn_forward_compute_type = convertToCuDNNDataType(forward_compute_type);
     auto cudnn_backward_compute_type = convertToCuDNNDataType(backward_compute_type);
@@ -257,6 +259,7 @@ class CuDNNDeconvolutionOp {
           filter_desc_,
           gwmat.dptr_ + weight_offset_ * g));
         #elif CUDNN_MAJOR >= 5
+        CHECK_EQ(add_to_weight_, req[deconv::kWeight] == kAddTo);
         CUDNN_CALL(cudnnConvolutionBackwardFilter(
           s->dnn_handle_,
           &alpha,
@@ -543,8 +546,8 @@ class CuDNNDeconvolutionOp {
     if (!CuDNNDeconvAlgoReg::Get()->Find(param_, in_shape, out_shape, dtype_,
                                          cudnn_forward_compute_type,
                                          cudnn_backward_compute_type,
-                                         SMArch(rctx.ctx.dev_id), &forward_algo_,
-                                         &back_algo_, &back_algo_w_)) {
+                                         SMArch(rctx.ctx.dev_id), add_to_weight_,
+                                         &forward_algo_, &back_algo_, &back_algo_w_)) {
       mshadow::Stream <gpu> *s = rctx.get_stream<gpu>();
       CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
       size_t workspace_byte = static_cast<size_t>(param_.workspace * sizeof(DType));
@@ -578,6 +581,8 @@ class CuDNNDeconvolutionOp {
       auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_);
       std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filt_results(max_bwd_filt_algos);
       int actual_bwd_filter_algos = 0;
+      // In cudnn v7.1.4, find() returned wgrad algos that could fail for large c if we
+      // were summing into the output (i.e. beta != 0).  Get() returned OK algos though.
       auto bwd_filter_algo_discoverer =
        param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardFilterAlgorithm_v7
                                                : cudnnFindConvolutionBackwardFilterAlgorithm;
@@ -728,6 +733,14 @@ class CuDNNDeconvolutionOp {
         }
       }
       #endif  // CUDNN_MAJOR < 7
+
+      // Fix for issue #11241
+      int cudnn_find_issue_max_features = 64 * 1024;
+      // With deconvolution, the algo sensitivity is to a large number of output features
+      if (add_to_weight_ && Features(out_shape[deconv::kOut]) >= cudnn_find_issue_max_features) {
+        this->back_algo_w_.Set(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true);
+      }
+
       // An algo specification by the user may be cached here, but another
       // convolution will match only if identically specified.
       // We're caching results of *Get* as well as *Find*, but these records
@@ -735,7 +748,8 @@ class CuDNNDeconvolutionOp {
       CuDNNDeconvAlgoReg::Get()->Register(param_, in_shape, out_shape, dtype_,
                                           cudnn_forward_compute_type,
                                           cudnn_backward_compute_type,
-                                          SMArch(rctx.ctx.dev_id), this->forward_algo_,
+                                          SMArch(rctx.ctx.dev_id), this->add_to_weight_,
+                                          this->forward_algo_,
                                            this->back_algo_, this->back_algo_w_);
     }
     // If we're allowing Tensor Core variants of the algos to be considered in
@@ -866,6 +880,20 @@ class CuDNNDeconvolutionOp {
     return tensor.MSize() * sizeof(DType);
   }
 
+
+  // Given a tensor shape of this operation, return the number of features 'c'
+  int64_t Features(const TShape &dshape) {
+    int c = 0;
+    switch (dshape.ndim()) {
+      case 3: c = ConvertLayout(dshape.get<3>(), param_.layout.value(), kNCW)[1]; break;
+      case 4: c = ConvertLayout(dshape.get<4>(), param_.layout.value(), kNCHW)[1]; break;
+      case 5: c = ConvertLayout(dshape.get<5>(), param_.layout.value(), kNCDHW)[1]; break;
+      default:
+        LOG(FATAL) << "Unexpected deconvolution data dimension " << dshape.ndim();
+    }
+    return c;
+  }
+
   std::vector<int> param_stride_;
   std::vector<int> param_dilate_;
 
@@ -912,6 +940,8 @@ class CuDNNDeconvolutionOp {
   cudnnTensorFormat_t format_;
   // Allow TensorCore algo policy
   bool cudnn_tensor_core_;
+  // Is req[kWeight] == deconv::kAddTo ?
+  bool add_to_weight_;
   DeconvolutionParam param_;
 };
 #endif  // CUDNN
diff --git a/src/operator/nn/deconvolution.cu b/src/operator/nn/deconvolution.cu
index cdfb606..1c3970b 100644
--- a/src/operator/nn/deconvolution.cu
+++ b/src/operator/nn/deconvolution.cu
@@ -39,7 +39,8 @@ static CuDNNDeconvolutionOp<DType> &GetCuDNNDeconvOp(const DeconvolutionParam& p
                                                      int backward_compute_type,
                                                     const std::vector<TShape>& in_shape,
                                                     const std::vector<TShape>& out_shape,
-                                                     const RunContext& rctx) {
+                                                     const RunContext& rctx,
+                                                     bool add_to_weight) {
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local std::unordered_map<DeconvSignature,
                                          std::shared_ptr<CuDNNDeconvolutionOp<DType> >,
@@ -55,14 +56,18 @@ static CuDNNDeconvolutionOp<DType> &GetCuDNNDeconvOp(const DeconvolutionParam& p
     ndim += s.ndim();
   for (auto &s : out_shape)
     ndim += s.ndim();
-  key.Reserve(1 /* for forward_compute_type */ + 1 /* for backward_compute_type */
-              + ndim + 1 /* for dev_id */);
+  key.Reserve(1 /* for forward_compute_type */ +
+              1 /* for backward_compute_type */ +
+              ndim /* for in and out shapes */ +
+              1 /* for dev_id */ +
+              1 /* for add_to_weight */);
 
   key.AddSign(forward_compute_type);
   key.AddSign(backward_compute_type);
   key.AddSign(in_shape);
   key.AddSign(out_shape);
   key.AddSign(rctx.ctx.dev_id);
+  key.AddSign(add_to_weight ? 1 : 0);
 
   auto it = ops.find(key);
   if (it == ops.end()) {
@@ -72,7 +77,7 @@ static CuDNNDeconvolutionOp<DType> &GetCuDNNDeconvOp(const DeconvolutionParam& p
     CHECK(ins_ret.second);
     it = ins_ret.first;
     it->second->Init(param, forward_compute_type, backward_compute_type, in_shape,
-                     out_shape, rctx);
+                     out_shape, rctx, add_to_weight);
   }
   return *it->second;
 }
@@ -109,8 +114,10 @@ void DeconvolutionCompute<gpu>(const nnvm::NodeAttrs& attrs,
       for (size_t i = 0; i < in_shape.size(); i++) {
         in_shape[i] = inputs[i].shape_;
       }
+      // req[deconv::kWeight] is only set for backward, so assume the typical 'write' for now.
+      auto add_to_weight = false;
       GetCuDNNDeconvOp<DType>(param, compute_type, compute_type,
-          in_shape, out_shape, ctx.run_ctx).Forward(ctx, inputs, req, outputs);
+          in_shape, out_shape, ctx.run_ctx, add_to_weight).Forward(ctx, inputs, req, outputs);
     }
   })
 #else
@@ -156,8 +163,9 @@ void DeconvolutionGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
       for (size_t i = 0; i < in_shape.size(); i++) {
         in_shape[i] = in_data[i].shape_;
       }
+      auto add_to_weight = req[deconv::kWeight] == kAddTo;
       GetCuDNNDeconvOp<DType>(param, compute_type, compute_type,
-          in_shape, out_shape, ctx.run_ctx).Backward(ctx,
+          in_shape, out_shape, ctx.run_ctx, add_to_weight).Backward(ctx,
             std::vector<TBlob>{out_grad}, in_data, req, in_grad);
     }
   })
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index 02130eb..2911293 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -494,7 +494,7 @@ inline void LogUnimplementedOp(const nnvm::NodeAttrs& attrs,
 }
 
 class OpSignature {
-  std::vector<int> eles;
+  std::vector<int64_t> eles;
   uint64_t hash;
 
  public:
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 8877b57..a3e663a 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -522,6 +522,65 @@ def test_convolution_options():
     sym_no_cudnn = mx.sym.Convolution(num_filter=3, kernel=(1,1,1), pad=(0,0,0), cudnn_off=True, name='conv')
     check_consistency_NxM([sym, sym_no_cudnn], ctx_list)
 
+# This test is designed to expose an issue with cudnn v7.1.4 algo find() when invoked with large c.
+# Algos returned by find() can fail to run with grad_req='add' (wgrad kernel beta parameter == 1.0f).
+@with_seed()
+def test_convolution_large_c():
+    problematic_c = 64 * 1024
+    # The convolution accumulates many values, so set large tolerances.
+    tol = {np.dtype(np.float32): 1,
+           np.dtype(np.float64): 1}
+    def test_1D_with_width(width, grad_req):
+        ctx_list = [{'ctx': mx.gpu(0), 'conv_data': (1, problematic_c, width), 'type_dict': {'conv_data': np.float32}},
+                    {'ctx': mx.gpu(0), 'conv_data': (1, problematic_c, width), 'type_dict': {'conv_data': np.float64}}]
+        sym = mx.sym.Convolution(layout='NCW', num_filter=8, kernel=(2,), name='conv')
+        check_consistency([sym, sym], ctx_list, tol=tol, grad_req=grad_req)
+
+    def test_2D_with_width(width, grad_req):
+        ctx_list = [{'ctx': mx.gpu(0), 'conv_data': (1, problematic_c, 2, width), 'type_dict': {'conv_data': np.float32}},
+                    {'ctx': mx.gpu(0), 'conv_data': (1, problematic_c, 2, width), 'type_dict': {'conv_data': np.float64}}]
+        sym = mx.sym.Convolution(layout='NCHW', num_filter=4, kernel=(2,2), name='conv')
+        check_consistency([sym, sym], ctx_list, tol=tol, grad_req=grad_req)
+
+    # Run with different data tensor shapes to run cudnnFind() multiple times.
+    # First, populate algo and op caches with models that always use cudnnFind() (req == 'write').
+    # Then run models that must avoid cached cudnnFind() results in some cases (req == 'add').
+    widths = [4, 16, 64]
+    for req in ['write', 'add']:
+        for width in widths:
+            test_1D_with_width(width, req)
+            test_2D_with_width(width, req)
+
+
+# This test is designed to expose an issue with cudnn v7.1.4 algo find() when invoked with large c.
+# Algos returned by find() can fail to run with grad_req='add' (wgrad kernel beta parameter == 1.0f).
+@with_seed()
+def test_deconvolution_large_c():
+    problematic_c = 64 * 1024
+    # The deconvolution accumulates many values, so set large tolerances.
+    tol = {np.dtype(np.float32): 1,
+           np.dtype(np.float64): 1}
+    def test_1D_with_width(width, grad_req):
+        ctx_list = [{'ctx': mx.gpu(0), 'deconv_data': (1, 8, width), 'type_dict': {'deconv_data': np.float32}},
+                    {'ctx': mx.gpu(0), 'deconv_data': (1, 8, width), 'type_dict': {'deconv_data': np.float64}}]
+        sym = mx.sym.Deconvolution(layout='NCW', num_filter=problematic_c, kernel=(2,), name='deconv')
+        check_consistency([sym, sym], ctx_list, tol=tol, grad_req=grad_req)
+
+    def test_2D_with_width(width, grad_req):
+        ctx_list = [{'ctx': mx.gpu(0), 'deconv_data': (1, 8, 2, width), 'type_dict': {'deconv_data': np.float32}},
+                    {'ctx': mx.gpu(0), 'deconv_data': (1, 8, 2, width), 'type_dict': {'deconv_data': np.float64}}]
+        sym = mx.sym.Deconvolution(layout='NCHW', num_filter=problematic_c, kernel=(2,2), name='deconv')
+        check_consistency([sym, sym], ctx_list, tol=tol, grad_req=grad_req)
+
+    # Run with different data tensor shapes to run cudnnFind() multiple times.
+    # First, populate algo and op caches with models that always use cudnnFind() (req == 'write').
+    # Then run models that must avoid cached cudnnFind() results in some cases (req == 'add').
+    widths = [4, 16, 64]
+    for req in ['write', 'add']:
+        for width in widths:
+            test_1D_with_width(width, req)
+            test_2D_with_width(width, req)
+
 
 @with_seed()
 def test_convolution_versions():
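
A note on the caching rationale, as an illustrative sketch (plain Python, not
MXNet source): both the thread-local operator cache and the global algo
registry previously keyed entries only on shapes, dtypes and SM arch, so an
algo selected by cudnnFind() during a 'write'-mode backward pass could be
replayed for an 'add'-mode one. Folding add_to_weight into the keys, as the
OpSignature and ParamKey changes above do, keeps the two cases distinct:

    algo_cache = {}

    def lookup_wgrad_algo(data_shape, dtype, sm_arch, add_to_weight):
        # add_to_weight is now part of the cache key, so 'write' and 'add'
        # requests can no longer alias each other's cached results.
        key = (data_shape, dtype, sm_arch, add_to_weight)
        if key not in algo_cache:
            # Post-fix policy: pin large-c accumulating cases to the
            # known-good ALGO_1 rather than trusting cudnnFind()'s choice.
            large_c = data_shape[1] >= 64 * 1024
            algo_cache[key] = ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1'
                               if add_to_weight and large_c
                               else 'fastest_cudnnFind_result')
        return algo_cache[key]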
