This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new beafba7  [Improvement] Invoke mkldnn and cudnn BatchNorm when axis != 1 (#18504)
beafba7 is described below

commit beafba76395e75c093f99d20ac62e38f48e91012
Author: JackieWu <w...@live.cn>
AuthorDate: Thu Jul 9 08:01:35 2020 +0800

    [Improvement] Invoke mkldnn and cudnn BatchNorm when axis != 1 (#18504)
    
    * fix batch norm when fix_gamma is True
    
    * support gradient accumulation for batch norm
    
    * mkldnn batchnorm support grad add
    
    * unittest for bn
    
    * fix bn arg
    
    * fix lint
    
    * fix mkldnn
    
    * fix mkldnn bn
    
    * fix grad when fixing gamma
    
    * fix naive gpu bn
    
    * fix lint
    
    * invoke mkldnn and cudnn batchnorm when axis != 1
    
    * backport 18500
    
    * change condition
    
    * fix
    
    * fix
    
    * add mkldnn_off for bn
    
    * remove mkldnn_off
    
    * recover save_000800.json
    
    * cast
---
 src/operator/nn/batch_norm.cc                  | 12 ++++---
 src/operator/nn/batch_norm.cu                  |  6 ++--
 src/operator/nn/cudnn/cudnn_batch_norm-inl.h   | 26 +++++++++++----
 src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 44 +++++++++++++++++++++++---
 4 files changed, 68 insertions(+), 20 deletions(-)
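
The core of the change: instead of requiring a 4-D NCHW input with axis == 1, the operator collapses any input to a 4-D (N, C, 1, D) view around the normalized axis, so the existing cuDNN and MKL-DNN kernels can serve any `axis`. A minimal Python sketch of that shape arithmetic, illustrative only (`collapse_to_nc1d` is not part of the patch):

    def collapse_to_nc1d(shape, axis):
        """Collapse an arbitrary-rank shape to (N, C, 1, D), where C is the
        size of the normalized axis -- mirrors the ProdShape-based reshape
        added to the cuDNN and MKL-DNN code paths."""
        axis = axis if axis >= 0 else axis + len(shape)
        n = 1
        for s in shape[:axis]:
            n *= s
        d = 1
        for s in shape[axis + 1:]:
            d *= s
        return (n, shape[axis], 1, d)

    # NHWC input normalized over the last axis: (2, 4, 4, 8) -> (32, 8, 1, 1)
    print(collapse_to_nc1d((2, 4, 4, 8), -1))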

diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 7e540ca..2fdd31e 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -435,10 +435,14 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
 
 #if MXNET_USE_MKLDNN == 1
 static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &param) {
-  mxnet::TShape shape = input.shape();
-  return SupportMKLDNN(input) && shape.ndim() == 4
-      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS
-      && !mxnet::op::batchnorm::disable_mkl;
+  if (mxnet::op::batchnorm::disable_mkl) return false;
+  const mxnet::TShape shape = input.shape();
+  const int ndim = shape.ndim();
+  if (ndim == 0 || shape.Size() == 0) return false;
+  const int dtype = input.dtype();
+  return (dtype == mshadow::kFloat32 ||
+          dtype == mshadow::kBfloat16) &&
+          SupportStorageMKLDNN(input.storage_type());
 }
 
 void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu
index 0875f05..c7e991f 100644
--- a/src/operator/nn/batch_norm.cu
+++ b/src/operator/nn/batch_norm.cu
@@ -698,8 +698,7 @@ void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
 
   param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
 #if MXNET_USE_CUDNN == 1
-  if (!param.use_global_stats && !param.cudnn_off
-      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
+  if (!param.use_global_stats && !param.cudnn_off) {
     MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
       GetCuDNNOp<DType>(param).Forward(ctx, in_data, req, outputs, aux_states);
     })
@@ -727,8 +726,7 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
 
   param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
 #if MXNET_USE_CUDNN == 1
-  if (!param.use_global_stats && !param.cudnn_off
-      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
+  if (!param.use_global_stats && !param.cudnn_off) {
     MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
       GetCuDNNOp<DType>(param).Backward(ctx, inputs, req, outputs);
     })
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
index 13db44d..340c2f3 100644
--- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
+++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
@@ -262,15 +262,27 @@ class CuDNNBatchNormOp {
 
  private:
   void Init(const TBlob &in_data) {
-    if (in_data.ndim() == 4) {
-      for (int i = 0; i < 4; ++i)
-        shape_[i] = in_data.shape_[i];
+    CHECK_GE(param_.axis, 0);
+    CHECK_LT(param_.axis, in_data.ndim());
+    if (param_.axis == 1) {
+      if (in_data.ndim() == 4) {
+        for (int i = 0; i < 4; ++i)
+          shape_[i] = in_data.shape_[i];
+      } else {
+        // when in_data.ndim() != 4
+        shape_[0] = in_data.shape_[0];
+        shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
+        shape_[2] = 1;
+        shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(2,
+              in_data.ndim()));
+      }
     } else {
-      // when in_data.ndim() != 4
-      shape_[0] = in_data.shape_[0];
-      shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
+      // reshape to (N, C, 1, D), C is the `param_.axis` dimension
+      shape_[0] = static_cast<dim_t>(in_data.shape_.ProdShape(0, param_.axis));
+      shape_[1] = in_data.shape_[param_.axis];
       shape_[2] = 1;
-      shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim());
+      shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(param_.axis + 1,
+            in_data.ndim()));
     }
 
     CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_,
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index da4fd97..0a29a6d 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -157,7 +157,25 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                            const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
                            const std::vector<NDArray> &outputs, bool fuse_relu) {
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
-  const std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
+  std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
+
+  mxnet::TShape shape = inputs[batchnorm::kData].shape();
+  const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
+  CHECK_LT(real_axis, shape.ndim());
+  NDArray out = outputs[batchnorm::kOut];
+  if (param.axis != 1 || shape.ndim() != 4) {
+    // reshape to (N, C, 1, D)
+    mxnet::TShape new_shape{
+      static_cast<dim_t>(shape.ProdShape(0, real_axis)),
+      shape[real_axis],
+      1,
+      static_cast<dim_t>(shape.ProdShape(real_axis + 1,
+            static_cast<int>(shape.ndim())))
+    };
+    in_data[batchnorm::kData] = in_data[batchnorm::kData].Reshape(new_shape);
+    out = out.Reshape(new_shape);
+  }
+
   const std::vector<NDArray> aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end());
   TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
   mkldnn::normalization_flags flags = _GetFlags(in_data,
@@ -166,7 +184,6 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                                                fuse_relu);
   const NDArray &data = in_data[batchnorm::kData];
   auto &fwd = GetBNForward<DType>(param, ctx, data, flags);
-  const NDArray &out = outputs[batchnorm::kOut];
 
   // for output memory
   auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_desc());
@@ -325,9 +342,9 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                                                ctx.is_train && !param.use_global_stats,
                                                fuse_relu);
 
-  const NDArray &data         = in_data[batchnorm::kData];
-  const NDArray &diff         = out_grad[batchnorm::kOut];
-  const NDArray &gradIn       = in_grad[batchnorm::kData];
+  NDArray data         = in_data[batchnorm::kData];
+  NDArray diff         = out_grad[batchnorm::kOut];
+  NDArray gradIn       = in_grad[batchnorm::kData];
   const NDArray &moving_mean  = aux_states[batchnorm::kMovingMean];
   const NDArray &moving_var   = aux_states[batchnorm::kMovingVar];
   const NDArray &out_mean     = out_data[batchnorm::kMean];
@@ -338,6 +355,23 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
   CHECK(moving_mean.IsDefaultData());
   CHECK(moving_var.IsDefaultData());
 
+  mxnet::TShape shape = data.shape();
+  const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
+  CHECK_LT(real_axis, shape.ndim());
+  if (param.axis != 1 || shape.ndim() != 4) {
+    // reshape to (N, C, 1, D)
+    mxnet::TShape new_shape{
+      static_cast<dim_t>(shape.ProdShape(0, real_axis)),
+      shape[real_axis],
+      1,
+      static_cast<dim_t>(shape.ProdShape(real_axis + 1,
+            static_cast<int>(shape.ndim())))
+    };
+    data = data.Reshape(new_shape);
+    diff = diff.Reshape(new_shape);
+    gradIn = gradIn.Reshape(new_shape);
+  }
+
   auto data_mem  = data.GetMKLDNNData();
   auto diff_mem  = diff.GetMKLDNNData();
   // MKLDNN batchnorm should run on special layouts. If one of them isn't, we
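
With the reshape in place, BatchNorm on channel-last (or any non-default axis) data should dispatch to the MKL-DNN kernel on CPU and the cuDNN kernel on GPU rather than the generic fallback. A minimal usage sketch, assuming the standard mx.nd.BatchNorm signature; shapes are illustrative:

    import mxnet as mx

    # NHWC input; normalize over the trailing (channel) axis
    x = mx.nd.random.uniform(shape=(2, 4, 4, 8))
    c = x.shape[-1]
    gamma, beta = mx.nd.ones((c,)), mx.nd.zeros((c,))
    mean, var = mx.nd.zeros((c,)), mx.nd.ones((c,))

    # axis=-1 now reaches the accelerated path via the internal
    # (N, C, 1, D) reshape instead of the slower generic implementation
    y = mx.nd.BatchNorm(x, gamma, beta, mean, var,
                        axis=-1, fix_gamma=False)
    print(y.shape)  # (2, 4, 4, 8)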
