piiswrong commented on a change in pull request #8302: Refactor operators & MKLDNN
URL: https://github.com/apache/incubator-mxnet/pull/8302#discussion_r156265180
##########
File path: src/operator/nn/activation-inl.h
##########

@@ -61,158 +62,127 @@ struct ActivationParam : public dmlc::Parameter<ActivationParam> {
   }
 };
 
-/**
- * \brief This is the implementation of activation operator.
- * \tparam xpu The device that the op will be executed on.
- */
 template<typename xpu, typename ForwardOp, typename BackwardOp, typename DType>
-class ActivationOp : public Operator {
- public:
-  virtual void Forward(const OpContext &ctx,
-                       const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(in_data.size(), 1U);
-    CHECK_EQ(out_data.size(), 1U);
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    const TBlob& input = in_data[activation::kData];
-    const size_t sz = input.shape_.Size();
-    if (sz) {
-      MXNET_ASSIGN_REQ_SWITCH(req[activation::kOut], Req, {
-        mxnet_op::Kernel<mxnet_op::op_with_req<ForwardOp, Req>, xpu>::Launch(
-          s, sz,
-          out_data[activation::kOut].dptr<DType>(),
-          input.dptr<DType>());
-      });
-    }
+void ActivationForward(const OpContext &ctx, const TBlob &in_data,
+                       const OpReqType &req, const TBlob &out_data) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const size_t sz = in_data.shape_.Size();
+  if (sz) {
+    MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+      mxnet_op::Kernel<mxnet_op::op_with_req<ForwardOp, Req>, xpu>::Launch(
+        s, sz,
+        out_data.dptr<DType>(),
+        in_data.dptr<DType>());
+    });
   }
+}
 
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(out_grad.size(), 1U);
-    CHECK(in_data.size() == 1 && in_grad.size() == 1);
-    CHECK_EQ(req.size(), 1U);
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    const TBlob& m_out_grad = out_grad[activation::kOut];
-    const TBlob& m_out_data = out_data[activation::kOut];
-    const TBlob& m_in_grad = in_grad[activation::kData];
-    const size_t sz = m_out_data.shape_.Size();
-    if (sz) {
-      MXNET_ASSIGN_REQ_SWITCH(req[activation::kData], Req, {
-        mxnet_op::Kernel<mxnet_op::op_with_req<
-          mxnet::op::mxnet_op::backward_grad_tuned<BackwardOp>, Req>, xpu>::Launch(
-          s, sz,
-          m_in_grad.dptr<DType>(),
-          m_out_grad.dptr<DType>(),
-          m_out_data.dptr<DType>());
-      });
-    }
+template<typename xpu, typename ForwardOp, typename BackwardOp, typename DType>
+void ActivationBackward(const OpContext &ctx, const TBlob &out_grad,
+                        const TBlob &out_data, const OpReqType &req,
+                        const TBlob &in_grad) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const size_t sz = out_data.shape_.Size();
+  if (sz) {
+    MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+      mxnet_op::Kernel<mxnet_op::op_with_req<
+        mxnet::op::mxnet_op::backward_grad_tuned<BackwardOp>, Req>, xpu>::Launch(
+        s, sz,
+        in_grad.dptr<DType>(),
+        out_grad.dptr<DType>(),
+        out_data.dptr<DType>());
+    });
   }
-};  // class ActivationOp
+}
 
-// Declare Factory function, used for dispatch specialization
 template<typename xpu>
-Operator* CreateOp(ActivationParam type, int dtype, const TShape& dshape);
-
-#if DMLC_USE_CXX11
-class ActivationProp : public OperatorProperty {
- public:
-  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
-    param_.Init(kwargs);
-  }
-
-  std::map<std::string, std::string> GetParams() const override {
-    return param_.__DICT__();
-  }
-
-  bool InferShape(std::vector<TShape> *in_shape,
-                  std::vector<TShape> *out_shape,
-                  std::vector<TShape> *aux_shape) const override {
-    using namespace mshadow;
-    CHECK_EQ(in_shape->size(), 1U) << "Input:[data]";
-    const TShape &dshape = in_shape->at(activation::kData);
-    if (dshape.ndim() == 0) return false;
-    out_shape->clear();
-    out_shape->push_back(dshape);
-    return true;
-  }
-
-  bool InferType(std::vector<int> *in_type,
-                 std::vector<int> *out_type,
-                 std::vector<int> *aux_type) const override {
-    CHECK_GE(in_type->size(), 1U);
-    int dtype = (*in_type)[0];
-    CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (index_t i = 0; i < in_type->size(); ++i) {
-      if ((*in_type)[i] == -1) {
-        (*in_type)[i] = dtype;
-      } else {
-        UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
-      }
+void _ActivationCompute(const ActivationParam &param, const OpContext &ctx,
+                        const TBlob &input, OpReqType req, const TBlob &output) {
+  MSHADOW_REAL_TYPE_SWITCH(input.type_flag_, DType, {
+    switch (param.act_type) {
+      case activation::kReLU:
+        ActivationForward<xpu, mshadow_op::relu, mshadow_op::relu_grad, DType>(
+          ctx, input, req, output);
+        break;
+      case activation::kSigmoid:
+        ActivationForward<xpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad, DType>(
+          ctx, input, req, output);
+        break;
+      case activation::kTanh:
+        ActivationForward<xpu, mshadow_op::tanh, mshadow_op::tanh_grad, DType>(
+          ctx, input, req, output);
+        break;
+      case activation::kSoftReLU:
+        ActivationForward<xpu, mshadow_op::softrelu, mshadow_op::softrelu_grad, DType>(
+          ctx, input, req, output);
+        break;
+      default:
+        LOG(FATAL) << "unknown activation type";
     }
-    out_type->clear();
-    out_type->push_back(dtype);
-    return true;
-  }
+  });
+}
 
-  OperatorProperty* Copy() const override {
-    auto ptr = new ActivationProp();
-    ptr->param_ = param_;
-    return ptr;
-  }
+template<typename xpu>
+void _ActivationGradCompute(const ActivationParam &param, const OpContext &ctx,

Review comment:
   BTW why separate _ActivationGradCompute from ActivationGradCompute? ActivationGradCompute is almost trivial now.
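For context on the question above: the public entry point presumably reduces to a thin FCompute adapter around _ActivationGradCompute. The following is a minimal sketch, not code from this diff -- the nnvm attrs plumbing, the {out_grad, out_data} inputs layout, and the exact _ActivationGradCompute argument order (which the hunk truncates) are all assumptions:

    // Hypothetical sketch of the wrapper the reviewer calls "almost trivial".
    // Assumes the standard FCompute signature used across src/operator/nn/ and
    // that inputs arrive as {out_grad, out_data}.
    template<typename xpu>
    void ActivationGradCompute(const nnvm::NodeAttrs& attrs,
                               const OpContext& ctx,
                               const std::vector<TBlob>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<TBlob>& outputs) {
      CHECK_EQ(inputs.size(), 2U);   // out_grad, out_data
      CHECK_EQ(outputs.size(), 1U);  // in_grad
      CHECK_EQ(req.size(), 1U);
      const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
      // All real work happens in _ActivationGradCompute; this wrapper only
      // unpacks the argument vectors.
      _ActivationGradCompute<xpu>(param, ctx, inputs[0], inputs[1],
                                  req[0], outputs[0]);
    }

If the wrapper really is this thin, folding _ActivationGradCompute back into ActivationGradCompute would remove one level of indirection at no cost, which appears to be the point of the question.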