rongzha1 commented on a change in pull request #16141: [mkldnn-v1.0] Add MKL-DNN Convolution URL: https://github.com/apache/incubator-mxnet/pull/16141#discussion_r325194667
########## File path: src/operator/nn/mkldnn/mkldnn_convolution-inl.h ########## @@ -79,54 +79,63 @@ struct MKLDNNConvFullParam { MKLDNNPostEltwiseParam postsum_act_param; }; -mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullParam &param, - const bool is_train, - const NDArray &data, - const NDArray &weights, - const NDArray *bias, - const NDArray &output); +std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetConvFwdImpl( + const ConvolutionParam &param, const bool is_train, const NDArray &data, const NDArray &weight, + const NDArray *bias, const NDArray &output); class MKLDNNConvForward { public: - mkldnn::convolution_forward::primitive_desc fwd_pd; - MKLDNNConvForward(const MKLDNNConvFullParam &param, const bool is_train, const NDArray &data, - const NDArray &weights, const NDArray *bias, const NDArray &output); + const NDArray &weight, const NDArray *bias, const NDArray &output); - void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight, - const mkldnn::memory *bias, const mkldnn::memory &output); + const mkldnn::convolution_forward &GetFwd() const { return *fwd_; } - void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) { - this->data_->set_data_handle(data.get_data_handle()); - this->out_->set_data_handle(output.get_data_handle()); - } - - const mkldnn::convolution_forward &GetFwd() const { - return *fwd_; - } + const mkldnn::convolution_forward::primitive_desc &GetPd() const { return *pd_; } private: std::shared_ptr<mkldnn::convolution_forward> fwd_; - std::shared_ptr<mkldnn::memory> data_; - std::shared_ptr<mkldnn::memory> weight_; - std::shared_ptr<mkldnn::memory> bias_; - std::shared_ptr<mkldnn::memory> out_; + std::shared_ptr<mkldnn::convolution_forward::primitive_desc> pd_; }; typedef ParamOpSign<ConvolutionParam> MKLDNNConvSignature; -MKLDNNConvForward &GetConvFwd(const ConvolutionParam &param, - const bool is_train, const NDArray &data, - const NDArray &weights, const NDArray *bias, - const 
NDArray &output); - void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam &param, const OpContext &ctx, MKLDNNConvForward *fwd, const std::vector<NDArray> &in_data, const std::vector<OpReqType> &req, const std::vector<NDArray> &out_data); +void MKLDNNConvolutionForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector<NDArray> &in_data, + const std::vector<OpReqType> &req, + const std::vector<NDArray> &out_data); + +class MKLDNNConvBackward { + public: + MKLDNNConvBackward(const MKLDNNConvFullParam &param, const NDArray &data, const NDArray &weight, + const NDArray *bias, const NDArray &output); + + const mkldnn::convolution_backward_data &GetBwdData() const { return *bwd_data_; } + + const mkldnn::convolution_backward_weights &GetBwdWeights() const { return *bwd_weight_; } + + const mkldnn::convolution_backward_data::primitive_desc &GetDataPd() const { + return *bwd_data_pd_; + } + + const mkldnn::convolution_backward_weights::primitive_desc &GetWeightsPd() const { + return *bwd_weights_pd_; + } + + private: + std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc> bwd_data_pd_; + std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc> bwd_weights_pd_; + std::shared_ptr<mkldnn::convolution_backward_data> bwd_data_; + std::shared_ptr<mkldnn::convolution_backward_weights> bwd_weight_; +}; + Review comment: Let's consistently use the singular `weight` in member names, as mkldnn_fully_connected.cc does. So please rename `bwd_weights_pd_` to `bwd_weight_pd_`. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services