ciyongch commented on a change in pull request #14128: MKLDNN based Quantized FullyConnected Operator and its fusion
URL: https://github.com/apache/incubator-mxnet/pull/14128#discussion_r261482177
 
 

 ##########
 File path: src/operator/nn/mkldnn/mkldnn_fully_connected.cc
 ##########
 @@ -23,215 +23,289 @@
  * \author Da Zheng
 */
 
-#include "../fully_connected-inl.h"
-#include "./mkldnn_base-inl.h"
-
 #if MXNET_USE_MKLDNN == 1
+#include "mkldnn_fully_connected-inl.h"
+
 namespace mxnet {
 namespace op {
 
-inline static mkldnn::inner_product_forward::primitive_desc GetIPFwd(
+DMLC_REGISTER_PARAMETER(MKLDNNFCParam);
+
+mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(
+    const MKLDNNFCFullParam &full_param, const bool is_train,
     const NDArray &data, const NDArray &weight, const NDArray *bias,
-    const mkldnn::memory::desc &out_md, const bool is_train) {
+    const mkldnn::memory::desc &out_md) {
   auto data_md = GetMemDesc(data);
   auto weight_md = GetMemDesc(weight);
   auto engine = CpuEngine::Get()->get_engine();
   auto propagation =
    is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
+
+  mkldnn::primitive_attr attr;
+  mkldnn::post_ops ops;
+  if (full_param.mkldnn_param.with_relu) {
+    float scale = 1.0f;
+    float alpha = 0.0f;
+    float beta = 1.0f;
+    ops.append_eltwise(scale, eltwise_relu, alpha, beta);
+  }
+  attr.set_post_ops(ops);
+
+  if (full_param.mkldnn_param.quantized) {
+    if (full_param.mkldnn_param.fuse_requantize ||
+        full_param.mkldnn_param.fuse_dequantize) {
+      int mask = 0;
+      std::vector<float> scales = {0.0};
+      if (full_param.requantize_scales.size()) {
+        scales[0] = full_param.requantize_scales[0];
+      } else if (full_param.output_scales.size()) {
+        scales[0] = full_param.output_scales[0];
+      } else {
+        LOG(FATAL) << "Must specify either output_scales or requantize_scales!";
+      }
+
+      attr.set_output_scales(mask, scales);
+      attr.set_int_output_round_mode(round_nearest);
+    }
+  }
+
+  auto GetFCFwdPd = [&full_param, &attr,
+                     &engine](const mkldnn::inner_product_forward::desc &desc) {
+    try {
+      return mkldnn::inner_product_forward::primitive_desc(desc, attr, engine);
+    } catch (mkldnn::error &e) {
+      if (e.status == mkldnn_unimplemented &&
+          full_param.mkldnn_param.quantized) {
+        LOG(ERROR) << "AVX512-BW support or MKLDNN v0.18 is required for INT8 fully_connected.";
+      } else {
+        LOG(ERROR) << e.message;
+      }
+      throw;
+    }
+  };
+
   if (bias) {
     auto bias_md = GetMemDesc(*bias);
-    mkldnn::inner_product_forward::desc ipFwd_desc(propagation,
+    mkldnn::inner_product_forward::desc desc(propagation,
         data_md, weight_md, bias_md, out_md);
-    return mkldnn::inner_product_forward::primitive_desc(ipFwd_desc, engine);
+    return GetFCFwdPd(desc);
   } else {
-    mkldnn::inner_product_forward::desc ipFwd_desc(propagation,
+    mkldnn::inner_product_forward::desc desc(propagation,
         data_md, weight_md, out_md);
-    return mkldnn::inner_product_forward::primitive_desc(ipFwd_desc, engine);
+    return GetFCFwdPd(desc);
   }
 }
 
-inline static mkldnn::inner_product_backward_data::primitive_desc GetIpBwdData(
+inline static mkldnn::inner_product_backward_data::primitive_desc GetFCBwdData(
     const NDArray &data, const NDArray &weight, const NDArray &output,
-    mkldnn::inner_product_forward::primitive_desc ipFwd_pd) {
+    mkldnn::inner_product_forward::primitive_desc fwd_pd) {
   auto data_md = GetMemDesc(data);
   auto weight_md = GetMemDesc(weight);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
   mkldnn::inner_product_backward_data::desc desc(data_md, weight_md, out_md);
-  return mkldnn::inner_product_backward_data::primitive_desc(desc, engine, ipFwd_pd);
+  return mkldnn::inner_product_backward_data::primitive_desc(desc, engine, fwd_pd);
 }
 
-inline static mkldnn::inner_product_backward_weights::primitive_desc GetIPBwdWeights(
+inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWeights(
     const NDArray &data, const NDArray &weight, const NDArray *bias,
-    const NDArray &output, mkldnn::inner_product_forward::primitive_desc ipFwd_pd) {
+    const NDArray &output, mkldnn::inner_product_forward::primitive_desc fwd_pd) {
   auto data_md = GetMemDesc(data);
   auto weight_md = GetMemDesc(weight);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
   if (bias) {
     auto bias_md = GetMemDesc(*bias);
-    mkldnn::inner_product_backward_weights::desc ipBwdWeights_desc(data_md,
+    mkldnn::inner_product_backward_weights::desc desc(data_md,
         weight_md, bias_md, out_md);
     return mkldnn::inner_product_backward_weights::primitive_desc(
-        ipBwdWeights_desc, engine, ipFwd_pd);
+        desc, engine, fwd_pd);
   } else {
-    mkldnn::inner_product_backward_weights::desc ipBwdWeights_desc(data_md,
+    mkldnn::inner_product_backward_weights::desc desc(data_md,
         weight_md, out_md);
     return mkldnn::inner_product_backward_weights::primitive_desc(
-        ipBwdWeights_desc, engine, ipFwd_pd);
+        desc, engine, fwd_pd);
   }
 }
 
-class MKLDNNFullyConnectForward {
-  std::shared_ptr<mkldnn::memory> data;
-  std::shared_ptr<mkldnn::memory> weight;
-  std::shared_ptr<mkldnn::memory> out;
-  std::shared_ptr<mkldnn::memory> bias;
-  std::shared_ptr<mkldnn::inner_product_forward> ipFwd;
-
- public:
-  mkldnn::inner_product_forward::primitive_desc ipFwd_pd;
-
-  MKLDNNFullyConnectForward(const FullyConnectedParam &param, bool is_train,
-                            const NDArray &data, const NDArray &weight,
-                            const NDArray *bias,
-                            const mkldnn::memory::desc &output)
-      : ipFwd_pd(GetIPFwd(data, weight, bias, output, is_train)) {}
-
-  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
-                 const mkldnn::memory *bias, const mkldnn::memory &output) {
-    if (this->data == nullptr)
-      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-              ipFwd_pd.src_primitive_desc(), data.get_data_handle()));
-    else
-      this->data->set_data_handle(data.get_data_handle());
+void MKLDNNFullyConnectedForward::SetNewMem(const mkldnn::memory &data,
+                                            const mkldnn::memory &weight,
+                                            const mkldnn::memory *bias,
+                                            const mkldnn::memory &output) {
+  if (this->data_ == nullptr)
+    this->data_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+            fwd_pd.src_primitive_desc(), data.get_data_handle()));
+  else
+    this->data_->set_data_handle(data.get_data_handle());
 
-    if (this->weight == nullptr)
-      this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-              ipFwd_pd.weights_primitive_desc(), weight.get_data_handle()));
-    else
-      this->weight->set_data_handle(weight.get_data_handle());
+  if (this->weight_ == nullptr)
+    this->weight_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+            fwd_pd.weights_primitive_desc(), weight.get_data_handle()));
+  else
+    this->weight_->set_data_handle(weight.get_data_handle());
+
+  if (this->out_ == nullptr)
+    this->out_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+            fwd_pd.dst_primitive_desc(), output.get_data_handle()));
+  else
+    this->out_->set_data_handle(output.get_data_handle());
 
-    if (this->out == nullptr)
-      this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-              ipFwd_pd.dst_primitive_desc(), output.get_data_handle()));
+  if (bias != nullptr) {
+    if (this->bias_ == nullptr)
+      this->bias_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          fwd_pd.bias_primitive_desc(), bias->get_data_handle()));
     else
-      this->out->set_data_handle(output.get_data_handle());
-
-    if (bias != nullptr) {
-      if (this->bias == nullptr)
-        this->bias = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-        ipFwd_pd.bias_primitive_desc(), bias->get_data_handle()));
-      else
-        this->bias->set_data_handle(bias->get_data_handle());
-      if (this->ipFwd == nullptr)
-        this->ipFwd = std::shared_ptr<mkldnn::inner_product_forward>(
-            new mkldnn::inner_product_forward(
-                ipFwd_pd, mkldnn::primitive::at(*this->data),
-                mkldnn::primitive::at(*this->weight),
-                mkldnn::primitive::at(*this->bias), *this->out));
-    } else if (this->ipFwd == nullptr) {
-      this->ipFwd = std::shared_ptr<mkldnn::inner_product_forward>(
+      this->bias_->set_data_handle(bias->get_data_handle());
+
+    if (this->fwd_ == nullptr)
+      this->fwd_ = std::shared_ptr<mkldnn::inner_product_forward>(
           new mkldnn::inner_product_forward(
-              ipFwd_pd, mkldnn::primitive::at(*this->data),
-              mkldnn::primitive::at(*this->weight), *this->out));
+              fwd_pd, mkldnn::primitive::at(*this->data_),
+              mkldnn::primitive::at(*this->weight_),
+              mkldnn::primitive::at(*this->bias_), *this->out_));
+  } else {
+    if (this->fwd_ == nullptr) {
+      this->fwd_ = std::shared_ptr<mkldnn::inner_product_forward>(
+          new mkldnn::inner_product_forward(
+              fwd_pd, mkldnn::primitive::at(*this->data_),
+              mkldnn::primitive::at(*this->weight_), *this->out_));
     }
   }
-  const mkldnn::inner_product_forward &GetIpFwd() const {
-    return *ipFwd;
-  }
-};
-
-typedef ParamOpSign<FullyConnectedParam> MKLDNNFullyconSignature;
+}
 
-static inline MKLDNNFullyConnectForward &GetFCFwd(
-    const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weight,
-    const NDArray *bias, const mkldnn::memory::desc &output,
-    const bool is_train) {
+MKLDNNFullyConnectedForward &GetFCFwd(
+    const FullyConnectedParam &param, const bool is_train,
+    const NDArray &data, const NDArray &weight,
+    const NDArray *bias, const mkldnn::memory::desc &out_md) {
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local std::unordered_map<MKLDNNFullyconSignature,
-              MKLDNNFullyConnectForward, OpHash> fcFwds;
+              MKLDNNFullyConnectedForward, OpHash> fcFwds;
 #else
   static MX_THREAD_LOCAL std::unordered_map<MKLDNNFullyconSignature,
-              MKLDNNFullyConnectForward, OpHash> fcFwds;
+              MKLDNNFullyConnectedForward, OpHash> fcFwds;
 #endif
-  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
   MKLDNNFullyconSignature key(param);
+  key.AddSign(is_train);
   key.AddSign(data);
   key.AddSign(weight);
-  key.AddSign(is_train);
-
   if (bias)
     key.AddSign(*bias);
 
   auto it = fcFwds.find(key);
   if (it == fcFwds.end()) {
-    MKLDNNFullyConnectForward fcFwd(param, is_train, data, weight, bias,
-                                    output);
-    auto ins_ret = fcFwds.insert(
-        std::pair<MKLDNNFullyconSignature, MKLDNNFullyConnectForward>(key, fcFwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    MKLDNNFCFullParam full_param;
+    full_param.default_param = param;
+    full_param.mkldnn_param.Init(std::unordered_map<std::string, std::string>());
 
 Review comment:
   Since this function is only called by the normal FullyConnected and quantized FC ops, and `mkldnn_param` is not used by either of them, passing only the `default_param` down from the caller is enough.
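   For illustration, a minimal call-site sketch under that assumption; the surrounding `attrs`/`ctx`/`in_data`/`out_data` names are assumed from a typical FCompute-style handler and are not taken from this diff:

   ```cpp
   // Hypothetical call-site sketch, not part of this diff: both the normal and
   // the quantized FC paths only carry a FullyConnectedParam, so GetFCFwd takes
   // it directly and default-initializes the MKLDNN-specific fields internally.
   const FullyConnectedParam &param = nnvm::get<FullyConnectedParam>(attrs.parsed);
   auto out_md = GetMemDesc(out_data[fullc::kOut]);
   MKLDNNFullyConnectedForward &fwd =
       GetFCFwd(param, ctx.is_train,
                in_data[fullc::kData], in_data[fullc::kWeight],
                param.no_bias ? nullptr : &in_data[fullc::kBias], out_md);
   ```

   The `MKLDNNFCFullParam` with a default-initialized `mkldnn_param` is then assembled inside `GetFCFwd` itself, exactly as the hunk above does.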

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services