This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

The following commit(s) were added to refs/heads/master by this push:
     new beafba7  [Improvement] Invoke mkldnn and cudnn BatchNorm when axis != 1 (#18504)
beafba7 is described below

commit beafba76395e75c093f99d20ac62e38f48e91012
Author: JackieWu <w...@live.cn>
AuthorDate: Thu Jul 9 08:01:35 2020 +0800

    [Improvement] Invoke mkldnn and cudnn BatchNorm when axis != 1 (#18504)

    * fix batch norm when fix_gamma is True
    * support gradient accumulation for batch norm
    * mkldnn batchnorm support grad add
    * unittest for bn
    * fix bn arg
    * fix lint
    * fix mkldnn
    * fix mkldnn bn
    * fix grad when fixing gamma
    * fix naive gpu bn
    * fix lint
    * invoke mkldnn and cudnn batchnorm when axis != 1
    * backport 18500
    * change condition
    * fix
    * fix
    * add mkldnn_off for bn
    * remove mkldnn_off
    * recover save_000800.json
    * cast
---
 src/operator/nn/batch_norm.cc                  | 12 ++++---
 src/operator/nn/batch_norm.cu                  |  6 ++--
 src/operator/nn/cudnn/cudnn_batch_norm-inl.h   | 26 +++++++++++----
 src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 44 +++++++++++++++++++++++---
 4 files changed, 68 insertions(+), 20 deletions(-)

diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 7e540ca..2fdd31e 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -435,10 +435,14 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
 
 #if MXNET_USE_MKLDNN == 1
 static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &param) {
-  mxnet::TShape shape = input.shape();
-  return SupportMKLDNN(input) && shape.ndim() == 4
-      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS
-      && !mxnet::op::batchnorm::disable_mkl;
+  if (mxnet::op::batchnorm::disable_mkl) return false;
+  const mxnet::TShape shape = input.shape();
+  const int ndim = shape.ndim();
+  if (ndim == 0 || shape.Size() == 0) return false;
+  const int dtype = input.dtype();
+  return (dtype == mshadow::kFloat32 ||
+          dtype == mshadow::kBfloat16) &&
+         SupportStorageMKLDNN(input.storage_type());
 }
 
 void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu
index 0875f05..c7e991f 100644
--- a/src/operator/nn/batch_norm.cu
+++ b/src/operator/nn/batch_norm.cu
@@ -698,8 +698,7 @@ void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
   param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
 
 #if MXNET_USE_CUDNN == 1
-  if (!param.use_global_stats && !param.cudnn_off
-      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
+  if (!param.use_global_stats && !param.cudnn_off) {
     MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
       GetCuDNNOp<DType>(param).Forward(ctx, in_data, req, outputs, aux_states);
     })
@@ -727,8 +726,7 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
   param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
 
 #if MXNET_USE_CUDNN == 1
-  if (!param.use_global_stats && !param.cudnn_off
-      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
+  if (!param.use_global_stats && !param.cudnn_off) {
     MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
       GetCuDNNOp<DType>(param).Backward(ctx, inputs, req, outputs);
     })
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
index 13db44d..340c2f3 100644
--- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
+++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
@@ -262,15 +262,27 @@ class CuDNNBatchNormOp {
 
  private:
   void Init(const TBlob &in_data) {
-    if (in_data.ndim() == 4) {
-      for (int i = 0; i < 4; ++i)
-        shape_[i] = in_data.shape_[i];
+    CHECK_GE(param_.axis, 0);
+    CHECK_LT(param_.axis, in_data.ndim());
+    if (param_.axis == 1) {
+      if (in_data.ndim() == 4) {
+        for (int i = 0; i < 4; ++i)
+          shape_[i] = in_data.shape_[i];
+      } else {
+        // when in_data.ndim() != 4
+        shape_[0] = in_data.shape_[0];
+        shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
+        shape_[2] = 1;
+        shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(2,
+                                                                in_data.ndim()));
+      }
     } else {
-      // when in_data.ndim() != 4
-      shape_[0] = in_data.shape_[0];
-      shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
+      // reshape to (N, C, 1, D), C is the `param_.axis` dimension
+      shape_[0] = static_cast<dim_t>(in_data.shape_.ProdShape(0, param_.axis));
+      shape_[1] = in_data.shape_[param_.axis];
       shape_[2] = 1;
-      shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim());
+      shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(param_.axis + 1,
+                                                              in_data.ndim()));
     }
 
     CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_,
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index da4fd97..0a29a6d 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -157,7 +157,25 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                             const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
                             const std::vector<NDArray> &outputs, bool fuse_relu) {
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
-  const std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
+  std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
+
+  mxnet::TShape shape = inputs[batchnorm::kData].shape();
+  const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
+  CHECK_LT(real_axis, shape.ndim());
+  NDArray out = outputs[batchnorm::kOut];
+  if (param.axis != 1 || shape.ndim() != 4) {
+    // reshape to (N, C, 1, D)
+    mxnet::TShape new_shape{
+      static_cast<dim_t>(shape.ProdShape(0, real_axis)),
+      shape[real_axis],
+      1,
+      static_cast<dim_t>(shape.ProdShape(real_axis + 1,
+                                         static_cast<int>(shape.ndim())))
+    };
+    in_data[batchnorm::kData] = in_data[batchnorm::kData].Reshape(new_shape);
+    out = out.Reshape(new_shape);
+  }
+
   const std::vector<NDArray> aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end());
   TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
   mkldnn::normalization_flags flags = _GetFlags(in_data,
@@ -166,7 +184,6 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                                                 fuse_relu);
   const NDArray &data = in_data[batchnorm::kData];
   auto &fwd = GetBNForward<DType>(param, ctx, data, flags);
-  const NDArray &out = outputs[batchnorm::kOut];
 
   // for output memory
   auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_desc());
@@ -325,9 +342,9 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                                                 ctx.is_train && !param.use_global_stats,
                                                 fuse_relu);
 
-  const NDArray &data = in_data[batchnorm::kData];
-  const NDArray &diff = out_grad[batchnorm::kOut];
-  const NDArray &gradIn = in_grad[batchnorm::kData];
+  NDArray data = in_data[batchnorm::kData];
+  NDArray diff = out_grad[batchnorm::kOut];
+  NDArray gradIn = in_grad[batchnorm::kData];
   const NDArray &moving_mean = aux_states[batchnorm::kMovingMean];
   const NDArray &moving_var = aux_states[batchnorm::kMovingVar];
   const NDArray &out_mean = out_data[batchnorm::kMean];
@@ -338,6 +355,23 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
   CHECK(moving_mean.IsDefaultData());
   CHECK(moving_var.IsDefaultData());
 
+  mxnet::TShape shape = data.shape();
+  const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
+  CHECK_LT(real_axis, shape.ndim());
+  if (param.axis != 1 || shape.ndim() != 4) {
+    // reshape to (N, C, 1, D)
+    mxnet::TShape new_shape{
+      static_cast<dim_t>(shape.ProdShape(0, real_axis)),
+      shape[real_axis],
+      1,
+      static_cast<dim_t>(shape.ProdShape(real_axis + 1,
+                                         static_cast<int>(shape.ndim())))
+    };
+    data = data.Reshape(new_shape);
+    diff = diff.Reshape(new_shape);
+    gradIn = gradIn.Reshape(new_shape);
+  }
+
   auto data_mem = data.GetMKLDNNData();
   auto diff_mem = diff.GetMKLDNNData();
   // MKLDNN batchnorm should run on special layouts. If one of them isn't, we
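
Note: both the cuDNN and MKLDNN paths above handle axis != 1 the same way: the input is viewed as a 4-D (N, C, 1, D) tensor, where C is the size of the normalized axis, N is the product of all dimensions before it, and D is the product of all dimensions after it. Because batch norm reduces over every axis except the channel axis, regrouping the remaining axes this way leaves the computed mean and variance unchanged. A minimal standalone sketch of that index arithmetic (plain C++ for illustration only; the CollapseToNC1D helper is hypothetical and stands in for the TShape::ProdShape calls in the diff):

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <numeric>
    #include <vector>

    // Collapse `shape` around `axis` into (N, C, 1, D), mirroring the
    // ProdShape(0, axis) / shape[axis] / ProdShape(axis + 1, ndim) logic
    // used by the cuDNN and MKLDNN BatchNorm paths in this commit.
    std::vector<int64_t> CollapseToNC1D(const std::vector<int64_t>& shape, int axis) {
      const int ndim = static_cast<int>(shape.size());
      if (axis < 0) axis += ndim;  // like GetRealAxis: allow negative axis
      const int64_t n = std::accumulate(shape.begin(), shape.begin() + axis,
                                        int64_t{1}, std::multiplies<int64_t>());
      const int64_t c = shape[axis];
      const int64_t d = std::accumulate(shape.begin() + axis + 1, shape.end(),
                                        int64_t{1}, std::multiplies<int64_t>());
      return {n, c, 1, d};
    }

    int main() {
      // An NHWC input of shape (2, 4, 4, 8) normalized over the last axis
      // becomes (2 * 4 * 4, 8, 1, 1) = (32, 8, 1, 1).
      for (int64_t v : CollapseToNC1D({2, 4, 4, 8}, -1)) std::cout << v << ' ';
      std::cout << '\n';
      return 0;
    }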