zheng-da commented on a change in pull request #8302: Refactor operators
URL: https://github.com/apache/incubator-mxnet/pull/8302#discussion_r150124107
 
 

 ##########
 File path: src/operator/nn/batch_norm-inl.h
 ##########
 @@ -212,150 +212,43 @@ class BatchNormOp : public Operator {
 };  // class BatchNormOp
 
 template<typename xpu>
-Operator *CreateOp(BatchNormParam param, const int dtype, const TShape& shape);
-
-#if DMLC_USE_CXX11
-class BatchNormProp : public OperatorProperty {
- public:
-  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
-    param_.Init(kwargs);
-  }
-
-  std::map<std::string, std::string> GetParams() const override {
-    return param_.__DICT__();
-  }
-
-  bool InferShape(std::vector<TShape> *in_shape,
-                  std::vector<TShape> *out_shape,
-                  std::vector<TShape> *aux_shape) const override {
-    using namespace mshadow;
-    CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
-    const TShape &dshape = in_shape->at(0);
-
-    const size_t channelAxis = static_cast<size_t>(param_.axis < 0
-                            ? static_cast<int>(dshape.ndim()) + param_.axis
-                            : param_.axis);
-    CHECK_LT(channelAxis, dshape.ndim()) << "Channel axis out of range: " << param_.axis;
-
-    const int channelCount = dshape[channelAxis];
-
-    if (dshape.ndim() == 0) {
-      return false;
-    }
-
-    in_shape->at(1) = TShape(Shape1(channelCount));
-    in_shape->at(2) = TShape(Shape1(channelCount));
-
-    out_shape->clear();
-    out_shape->push_back(dshape);                // kOut
-    out_shape->push_back(Shape1(channelCount));  // kMean
-    out_shape->push_back(Shape1(channelCount));  // kVar
-
-    aux_shape->clear();
-    aux_shape->push_back(Shape1(channelCount));  // kMovingMean
-    aux_shape->push_back(Shape1(channelCount));  // kMovingVar
-    return true;
-  }
-
-  bool InferType(std::vector<int> *in_type,
-                 std::vector<int> *out_type,
-                 std::vector<int> *aux_type) const override {
-    using namespace mshadow;
-    CHECK_GE(in_type->size(), 1U);
-    const int dtype = (*in_type)[0];
-    CHECK_NE(dtype, -1) << "First input must have specified type";
-    // For float16 input type beta, gamma, mean, and average are stored in float32.
-    // For other input types, these parameters have the same type as input
-    // NOTE: This requirement is from cuDNN (v. 4 and 5)
-    int dtype_param;
-    MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
-         dtype_param = mshadow::DataType<AccRealX>::kFlag; });
-    for (index_t i = 1; i < in_type->size(); ++i) {
-      if ((*in_type)[i] == -1) {
-        (*in_type)[i] = dtype_param;
-      } else {
-        UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, ListArguments()[i]);
-      }
-    }
-    for (index_t i = 0; i < aux_type->size(); ++i) {
-      if ((*aux_type)[i] != -1) {
-        UNIFORM_TYPE_CHECK((*aux_type)[i], dtype_param, ListArguments()[i]);
-      }
-    }
-    const size_t n_aux = this->ListAuxiliaryStates().size();
-    aux_type->clear();
-    for (size_t i = 0; i < n_aux; ++i) {
-      aux_type->push_back(dtype_param);
-    }
-    const size_t n_out = this->ListOutputs().size();
-    out_type->clear();
-    out_type->push_back(dtype);
-    for (size_t i = 1; i < n_out; ++i) {
-      out_type->push_back(dtype_param);
-    }
-    return true;
-  }
-
-  OperatorProperty* Copy() const override {
-    auto ptr = new BatchNormProp();
-    ptr->param_ = param_;
-    return ptr;
-  }
-
-  std::string TypeString() const override {
-    return "BatchNorm";
-  }
-
-  std::vector<int> DeclareBackwardDependency(
-    const std::vector<int> &out_grad,
-    const std::vector<int> &in_data,
-    const std::vector<int> &out_data) const override {
-    return {out_grad[batchnorm::kOut],
-            out_data[batchnorm::kMean],
-            out_data[batchnorm::kVar],
-            in_data[batchnorm::kData],
-            in_data[batchnorm::kGamma]
-           };
-  }
-
-  int NumVisibleOutputs() const override {
-    if (param_.output_mean_var) {
-      return 3;
-    }
-    return 1;
-  }
-
-  int NumOutputs() const override {
-    return 3;
-  }
-
-  std::vector<std::string> ListArguments() const override {
-    return {"data", "gamma", "beta"};
-  }
-
-  std::vector<std::string> ListOutputs() const override {
-    return {"output", "mean", "var"};
-  }
-
-  std::vector<std::string> ListAuxiliaryStates() const override {
-    return {"moving_mean", "moving_var"};
-  }
-
-  Operator* CreateOperator(Context ctx) const override {
-      LOG(FATAL) << "Not Implemented.";
-      return NULL;
-  }
-
-  Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
-      std::vector<int> *in_type) const override;
+void BatchNormCompute(const nnvm::NodeAttrs& attrs,
+    const OpContext& ctx, const std::vector<TBlob>& inputs,
+    const std::vector<OpReqType>& req,
+    const std::vector<TBlob>& outputs) {
+  const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
+  CHECK_EQ(inputs.size(), 5U);
+  std::vector<TBlob> in_data(inputs.begin(), inputs.begin() + 3);
+  std::vector<TBlob> aux_states(inputs.begin() + 3, inputs.end());
+  MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
+    static thread_local BatchNormOp<xpu, DType, AccReal> op;
+    op.Init(param);
+    op.Forward(ctx, in_data, req, outputs, aux_states);
+  });
+}
 
-  inline const BatchNormParam& getParam() const {
-    return param_;
-  }
+template<typename xpu>
+void BatchNormGradCompute(const nnvm::NodeAttrs& attrs,
+    const OpContext& ctx, const std::vector<TBlob>& inputs,
+    const std::vector<OpReqType>& req,
+    const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 11U);
+  const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
+  std::vector<TBlob> out_grad(inputs.begin(),
+      inputs.begin() + (param.output_mean_var ? 3U : 1U));
+  std::vector<TBlob> in_data(inputs.begin() + 3, inputs.begin() + 6);
+  std::vector<TBlob> aux_states(inputs.begin() + 6, inputs.begin() + 8);
+  std::vector<TBlob> out_data(inputs.begin() + 8, inputs.end());
+  std::vector<TBlob> in_grad(outputs.begin(), outputs.begin() + 3);
+
 
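 (The quoted hunk ends at the line under review, so the body of BatchNormGradCompute is cut off here. By analogy with the forward path above, it presumably finishes by dispatching to the existing BatchNormOp::Backward through the same thread-local pattern; a minimal sketch under that assumption, not quoted from the PR:

    MSHADOW_REAL_TYPE_SWITCH_EX(out_grad[0].type_flag_, DType, AccReal, {
      // Reuse a per-thread operator instance, as in BatchNormCompute above.
      static thread_local BatchNormOp<xpu, DType, AccReal> op;
      op.Init(param);
      // Legacy Operator::Backward signature: gradients are written into
      // in_grad according to req, using the saved mean/var from out_data.
      op.Backward(ctx, out_grad, in_data, out_data, req, in_grad, aux_states);
    });
  }
 )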
 Review comment:
  You are right. I was trying to avoid modifying the original code as much as possible.
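
 For context: with the BatchNormProp class removed, compute functions like these are wired up through NNVM registration instead of OperatorProperty. A sketch of the usual pattern in src/operator/nn/batch_norm.cc, assuming the conventional op names (BatchNorm and _backward_BatchNorm are the customary identifiers, not confirmed from this diff):

    // CPU-side registration; the .cu file would register the <gpu>
    // instantiations of the same functions analogously.
    NNVM_REGISTER_OP(BatchNorm)
    .set_attr<FCompute>("FCompute<cpu>", BatchNormCompute<cpu>);

    NNVM_REGISTER_OP(_backward_BatchNorm)
    .set_attr<FCompute>("FCompute<cpu>", BatchNormGradCompute<cpu>);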

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
