azai91 commented on a change in pull request #12985: adding unit test for 
MKLDNN FullyConnected operator
URL: https://github.com/apache/incubator-mxnet/pull/12985#discussion_r229139235
 
 

 ##########
 File path: tests/cpp/operator/mkldnn_operator_test.cc
 ##########
 @@ -554,6 +586,129 @@ void TestOpEx(const OpAttrs &forward_attrs, const 
OpAttrs &backwards_attrs) {
   }
 }
 
+// Returns the product of every dimension of `arr` except the first (dim 0).
+// The FullyConnected test uses it as the second dimension of the weight
+// shape, i.e. the input flattened past its leading axis.
+// NOTE(review): the product is accumulated in uint32_t, so very large shapes
+// could overflow; fine for the small test shapes used here.
+uint32_t weight_dim2(const nnvm::TShape &arr) {
+  uint32_t dim = 1;
+  // ndim() is an unsigned count on nnvm::TShape; an unsigned index avoids a
+  // signed/unsigned comparison (-Wsign-compare).
+  for (uint32_t i = 1; i < arr.ndim(); ++i) {
+    dim *= arr[i];
+  }
+  return dim;
+}
+
+void TestFullyConnectedOp(const OpAttrs &forward_attrs, const OpAttrs 
&backwards_attrs) {
+  std::vector<NDArray*> inputs(forward_attrs.num_inputs);
+  std::vector<NDArray*> outputs(forward_attrs.num_outputs);
+  std::vector<NDArray*> ex_outputs(forward_attrs.num_outputs);
+
+  std::vector<NDArray*> backwards_input(backwards_attrs.num_inputs);
+  std::vector<NDArray*> backwards_outputs(backwards_attrs.num_outputs);
+  std::vector<NDArray*> backwards_ex_outputs(backwards_attrs.num_outputs);
+
+  std::vector<OpReqType> req(forward_attrs.num_outputs);
+  std::vector<OpReqType> back_req(backwards_attrs.num_outputs);
+
+  TestArrayShapes tas = GetTestArrayShapes();
+  std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
+
+  std::vector<NDArrayAttrs> in_arrs = 
GetTestInputArrays(forward_attrs.input_types, true);
+  std::vector<std::vector<NDArrayAttrs>> out_arrs(forward_attrs.num_outputs);
+  std::vector<std::vector<NDArrayAttrs>> 
ex_out_arrs(forward_attrs.num_outputs);
+
+  std::string str_hid = 
const_cast<OpAttrs&>(forward_attrs).attrs.dict["num_hidden"];
+  int num_hid = std::stoi(str_hid);
+
+  if (forward_attrs.requests.find(OpReqType::kWriteTo) != 
forward_attrs.requests.end()) {
+    for (int i1 = 0; i1 < in_arrs.size(); i1++) {
+      auto in_arr = in_arrs[i1];
+
+      if (in_arr.arr.shape().ndim() < 2)
+        continue;
+
+      auto in_shape = in_arr.arr.shape();
+
+      nnvm::TShape wt_shape(2);
+      wt_shape[0] = num_hid;
+      wt_shape[1] = weight_dim2(in_shape);
+      NDArray weights(wt_shape, Context());
+      InitDefaultArray(&weights, false);
+
+      nnvm::TShape bias_shape(1);
+      bias_shape[0] = num_hid;
+      NDArray bias(bias_shape, Context());
+      InitDefaultArray(&bias, false);
+
+      inputs[0] = &in_arr.arr;
+      inputs[1] = &weights;
+      inputs[2] = &bias;
+
+      nnvm::TShape out_shape(2);
+      out_shape[0] = in_shape[0];
+      out_shape[1] = num_hid;
+
+      for (int i = 0; i < forward_attrs.num_outputs; i++) {
+        out_arrs[i] =
+            GetTestOutputArrays(out_shape, pds, {1}, 
forward_attrs.output_types);
+        ex_out_arrs[i] =
+            GetTestOutputArrays(out_shape, pds, {1}, 
forward_attrs.output_types);
+      }
+
+      for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
+        if (out_arrs[0][output_i].arr.IsMKLDNNData())
 
 Review comment:
   Why are we filtering out MKLDNN here?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to