ZhennanQin commented on a change in pull request #16141: [mkldnn-v1.0] Add 
MKL-DNN Convolution
URL: https://github.com/apache/incubator-mxnet/pull/16141#discussion_r324945129
 
 

 ##########
 File path: src/operator/nn/mkldnn/mkldnn_convolution-inl.h
 ##########
 @@ -79,54 +79,63 @@ struct MKLDNNConvFullParam {
   MKLDNNPostEltwiseParam postsum_act_param;
 };
 
-mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const 
MKLDNNConvFullParam &param,
-                                                           const bool is_train,
-                                                           const NDArray &data,
-                                                           const NDArray 
&weights,
-                                                           const NDArray *bias,
-                                                           const NDArray 
&output);
+std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetConvFwdImpl(
+    const ConvolutionParam &param, const bool is_train, const NDArray &data, 
const NDArray &weight,
+    const NDArray *bias, const NDArray &output);
 
 class MKLDNNConvForward {
  public:
-  mkldnn::convolution_forward::primitive_desc fwd_pd;
-
   MKLDNNConvForward(const MKLDNNConvFullParam &param, const bool is_train, 
const NDArray &data,
-                    const NDArray &weights, const NDArray *bias, const NDArray 
&output);
+                    const NDArray &weight, const NDArray *bias, const NDArray 
&output);
 
-  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
-                 const mkldnn::memory *bias, const mkldnn::memory &output);
+  const mkldnn::convolution_forward &GetFwd() const { return *fwd_; }
 
-  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) {
-    this->data_->set_data_handle(data.get_data_handle());
-    this->out_->set_data_handle(output.get_data_handle());
-  }
-
-  const mkldnn::convolution_forward &GetFwd() const {
-    return *fwd_;
-  }
+  const mkldnn::convolution_forward::primitive_desc &GetPd() const { return 
*pd_; }
 
  private:
   std::shared_ptr<mkldnn::convolution_forward> fwd_;
-  std::shared_ptr<mkldnn::memory> data_;
-  std::shared_ptr<mkldnn::memory> weight_;
-  std::shared_ptr<mkldnn::memory> bias_;
-  std::shared_ptr<mkldnn::memory> out_;
+  std::shared_ptr<mkldnn::convolution_forward::primitive_desc> pd_;
 };
 
 typedef ParamOpSign<ConvolutionParam> MKLDNNConvSignature;
 
-MKLDNNConvForward &GetConvFwd(const ConvolutionParam &param,
-                              const bool is_train, const NDArray &data,
-                              const NDArray &weights, const NDArray *bias,
-                              const NDArray &output);
-
 void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam &param,
                                          const OpContext &ctx,
                                          MKLDNNConvForward *fwd,
                                          const std::vector<NDArray> &in_data,
                                          const std::vector<OpReqType> &req,
                                          const std::vector<NDArray> &out_data);
 
+void MKLDNNConvolutionForward(const nnvm::NodeAttrs &attrs,
+                              const OpContext &ctx,
+                              const std::vector<NDArray> &in_data,
+                              const std::vector<OpReqType> &req,
+                              const std::vector<NDArray> &out_data);
+
+class MKLDNNConvBackward {
+ public:
+  MKLDNNConvBackward(const MKLDNNConvFullParam &param, const NDArray &data, 
const NDArray &weight,
+                     const NDArray *bias, const NDArray &output);
+
+  const mkldnn::convolution_backward_data &GetBwdData() const { return 
*bwd_data_; }
+
+  const mkldnn::convolution_backward_weights &GetBwdWeights() const { return 
*bwd_weight_; }
+
+  const mkldnn::convolution_backward_data::primitive_desc &GetDataPd() const {
+    return *bwd_data_pd_;
+  }
+
+  const mkldnn::convolution_backward_weights::primitive_desc &GetWeightsPd() 
const {
+    return *bwd_weights_pd_;
+  }
+
+ private:
+  std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc> 
bwd_data_pd_;
+  std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc> 
bwd_weights_pd_;
+  std::shared_ptr<mkldnn::convolution_backward_data> bwd_data_;
+  std::shared_ptr<mkldnn::convolution_backward_weights> bwd_weight_;
+};
+
 
 Review comment:
   This is a naming difference between mkldnn and mxnet: mxnet prefers 
`weight`, while mkldnn prefers `weights`. I'm not sure where the boundary 
between the two naming conventions should be drawn. @rongzha1, you can make 
the decision.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to