rongzha1 commented on a change in pull request #16141: [mkldnn-v1.0] Add MKL-DNN Convolution
URL: https://github.com/apache/incubator-mxnet/pull/16141#discussion_r325194667
 
 

 ##########
 File path: src/operator/nn/mkldnn/mkldnn_convolution-inl.h
 ##########
 @@ -79,54 +79,63 @@ struct MKLDNNConvFullParam {
   MKLDNNPostEltwiseParam postsum_act_param;
 };
 
-mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const 
MKLDNNConvFullParam &param,
-                                                           const bool is_train,
-                                                           const NDArray &data,
-                                                           const NDArray 
&weights,
-                                                           const NDArray *bias,
-                                                           const NDArray 
&output);
+std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetConvFwdImpl(
+    const ConvolutionParam &param, const bool is_train, const NDArray &data, 
const NDArray &weight,
+    const NDArray *bias, const NDArray &output);
 
 class MKLDNNConvForward {
  public:
-  mkldnn::convolution_forward::primitive_desc fwd_pd;
-
   MKLDNNConvForward(const MKLDNNConvFullParam &param, const bool is_train, 
const NDArray &data,
-                    const NDArray &weights, const NDArray *bias, const NDArray 
&output);
+                    const NDArray &weight, const NDArray *bias, const NDArray 
&output);
 
-  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
-                 const mkldnn::memory *bias, const mkldnn::memory &output);
+  const mkldnn::convolution_forward &GetFwd() const { return *fwd_; }
 
-  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) {
-    this->data_->set_data_handle(data.get_data_handle());
-    this->out_->set_data_handle(output.get_data_handle());
-  }
-
-  const mkldnn::convolution_forward &GetFwd() const {
-    return *fwd_;
-  }
+  const mkldnn::convolution_forward::primitive_desc &GetPd() const { return 
*pd_; }
 
  private:
   std::shared_ptr<mkldnn::convolution_forward> fwd_;
-  std::shared_ptr<mkldnn::memory> data_;
-  std::shared_ptr<mkldnn::memory> weight_;
-  std::shared_ptr<mkldnn::memory> bias_;
-  std::shared_ptr<mkldnn::memory> out_;
+  std::shared_ptr<mkldnn::convolution_forward::primitive_desc> pd_;
 };
 
 typedef ParamOpSign<ConvolutionParam> MKLDNNConvSignature;
 
-MKLDNNConvForward &GetConvFwd(const ConvolutionParam &param,
-                              const bool is_train, const NDArray &data,
-                              const NDArray &weights, const NDArray *bias,
-                              const NDArray &output);
-
 void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam &param,
                                          const OpContext &ctx,
                                          MKLDNNConvForward *fwd,
                                          const std::vector<NDArray> &in_data,
                                          const std::vector<OpReqType> &req,
                                          const std::vector<NDArray> &out_data);
 
+void MKLDNNConvolutionForward(const nnvm::NodeAttrs &attrs,
+                              const OpContext &ctx,
+                              const std::vector<NDArray> &in_data,
+                              const std::vector<OpReqType> &req,
+                              const std::vector<NDArray> &out_data);
+
+class MKLDNNConvBackward {
+ public:
+  MKLDNNConvBackward(const MKLDNNConvFullParam &param, const NDArray &data, 
const NDArray &weight,
+                     const NDArray *bias, const NDArray &output);
+
+  const mkldnn::convolution_backward_data &GetBwdData() const { return 
*bwd_data_; }
+
+  const mkldnn::convolution_backward_weights &GetBwdWeights() const { return 
*bwd_weight_; }
+
+  const mkldnn::convolution_backward_data::primitive_desc &GetDataPd() const {
+    return *bwd_data_pd_;
+  }
+
+  const mkldnn::convolution_backward_weights::primitive_desc &GetWeightsPd() 
const {
+    return *bwd_weights_pd_;
+  }
+
+ private:
+  std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc> 
bwd_data_pd_;
+  std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc> 
bwd_weights_pd_;
+  std::shared_ptr<mkldnn::convolution_backward_data> bwd_data_;
+  std::shared_ptr<mkldnn::convolution_backward_weights> bwd_weight_;
+};
+
 
 Review comment:
   Let's use `weight` (singular) consistently, as mkldnn_fully_connected.cc does. So please rename `bwd_weights_pd_` to `bwd_weight_pd_`, which also matches the existing `bwd_weight_` member.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to