zheng-da commented on a change in pull request #8302: Refactor operators URL: https://github.com/apache/incubator-mxnet/pull/8302#discussion_r153368427
########## File path: src/operator/nn/batch_norm.cc ########## @@ -313,45 +314,76 @@ void BatchNormOp<xpu, DType, AccReal>::DoBackward(mshadow::Stream<cpu> *, } } -template<> -Operator *CreateOp<cpu>(BatchNormParam param, const int dtype, const TShape& shape) { - param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis); - Operator *op = nullptr; -#if MXNET_USE_MKL2017 == 1 - if (shape.ndim() == 4 - && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS - && !mxnet::op::batchnorm::disable_mkl) { - switch (dtype) { - case mshadow::kFloat32: - op = new MKLBatchNormOp<cpu, float>(param); - break; - case mshadow::kFloat64: - op = new MKLBatchNormOp<cpu, double>(param); - break; - default: - // MKL operator doesn't support half_t, so fall through - break; - } - } -#endif - if (!op) { - MSHADOW_REAL_TYPE_SWITCH_EX(dtype, - DType, - AccReal, { - op = new BatchNormOp<cpu, DType, AccReal>(param); }); +DMLC_REGISTER_PARAMETER(BatchNormParam); + +static bool BatchNormShape(const nnvm::NodeAttrs& attrs, + std::vector<TShape> *in_shape, std::vector<TShape> *out_shape) { + const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed); + using namespace mshadow; + CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]"; + const TShape &dshape = in_shape->at(0); + + const size_t channelAxis = static_cast<size_t>(param.axis < 0 + ? static_cast<int>(dshape.ndim()) + param.axis + : param.axis); + CHECK_LT(channelAxis, dshape.ndim()) << "Channel axis out of range: " << param.axis; + + const int channelCount = dshape[channelAxis]; + + if (dshape.ndim() == 0) { + return false; } - return op; + + in_shape->at(1) = TShape(Shape1(channelCount)); + in_shape->at(2) = TShape(Shape1(channelCount)); + in_shape->at(3) = TShape(Shape1(channelCount)); // kMovingMean + in_shape->at(4) = TShape(Shape1(channelCount)); // kMovingVar + + out_shape->clear(); + out_shape->push_back(dshape); // kOut + out_shape->push_back(Shape1(channelCount)); // kMean + out_shape->push_back(Shape1(channelCount)); // kVar + + return true; } -// DO_BIND_DISPATCH comes from operator_common.h -Operator *BatchNormProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape, - std::vector<int> *in_type) const { - DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_shape)[0]); +static inline std::vector<std::string> ListArguments() { Review comment: this is where LegacyOpRunner is defined, but it's not used for batch norm. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services