cjolivier01 commented on a change in pull request #8302: Refactor operators
URL: https://github.com/apache/incubator-mxnet/pull/8302#discussion_r153356747
 
 

 ##########
 File path: src/operator/nn/batch_norm.cc
 ##########
 @@ -313,45 +314,76 @@ void BatchNormOp<xpu, DType, AccReal>::DoBackward(mshadow::Stream<cpu> *,
   }
 }
 
-template<>
-Operator *CreateOp<cpu>(BatchNormParam param, const int dtype, const TShape& shape) {
-  param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
-  Operator *op = nullptr;
-#if MXNET_USE_MKL2017 == 1
-  if (shape.ndim() == 4
-      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS
-      && !mxnet::op::batchnorm::disable_mkl) {
-    switch (dtype) {
-      case mshadow::kFloat32:
-        op = new MKLBatchNormOp<cpu, float>(param);
-        break;
-      case mshadow::kFloat64:
-        op = new MKLBatchNormOp<cpu, double>(param);
-        break;
-      default:
-        // MKL operator doesn't support half_t, so fall through
-        break;
-    }
-  }
-#endif
-  if (!op) {
-    MSHADOW_REAL_TYPE_SWITCH_EX(dtype,
-                                DType,
-                                AccReal, {
-                                  op = new BatchNormOp<cpu, DType, AccReal>(param); });
+DMLC_REGISTER_PARAMETER(BatchNormParam);
+
+static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
+    std::vector<TShape> *in_shape, std::vector<TShape> *out_shape) {
+  const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
+  using namespace mshadow;
+  CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]";
+  const TShape &dshape = in_shape->at(0);
+
+  const size_t channelAxis = static_cast<size_t>(param.axis < 0
+      ? static_cast<int>(dshape.ndim()) + param.axis
+      : param.axis);
+  CHECK_LT(channelAxis, dshape.ndim()) << "Channel axis out of range: " << param.axis;
+
+  const int channelCount = dshape[channelAxis];
+
+  if (dshape.ndim() == 0) {
+    return false;
   }
-  return op;
+
+  in_shape->at(1) = TShape(Shape1(channelCount));
+  in_shape->at(2) = TShape(Shape1(channelCount));
+  in_shape->at(3) = TShape(Shape1(channelCount));  // kMovingMean
+  in_shape->at(4) = TShape(Shape1(channelCount));  // kMovingVar
+
+  out_shape->clear();
+  out_shape->push_back(dshape);                // kOut
+  out_shape->push_back(Shape1(channelCount));  // kMean
+  out_shape->push_back(Shape1(channelCount));  // kVar
+
+  return true;
 }
 
-// DO_BIND_DISPATCH comes from operator_common.h
-Operator *BatchNormProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
-                                          std::vector<int> *in_type) const {
-  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_shape)[0]);
+static inline std::vector<std::string> ListArguments() {
 
 Review comment:
   In the gtests, in test_legacy_op.h:
   template<typename OperatorProp, typename DType, typename AccReal>
   using LegacyOpRunner =
   mxnet::test::OperatorRunner<OperatorProp, LegacyOperatorExecutor<DType, AccReal>>;
   
   
   

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to