azai91 commented on a change in pull request #12985: adding unit test for MKLDNN FullyConnected operator URL: https://github.com/apache/incubator-mxnet/pull/12985#discussion_r229139235
########## File path: tests/cpp/operator/mkldnn_operator_test.cc ########## @@ -554,6 +586,129 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) { } } +uint32_t weight_dim2(const nnvm::TShape arr) { + uint32_t dim = 1; + for (int i = 1; i < arr.ndim(); i++) { + dim *= arr[i]; + } + return dim; +} + +void TestFullyConnectedOp(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) { + std::vector<NDArray*> inputs(forward_attrs.num_inputs); + std::vector<NDArray*> outputs(forward_attrs.num_outputs); + std::vector<NDArray*> ex_outputs(forward_attrs.num_outputs); + + std::vector<NDArray*> backwards_input(backwards_attrs.num_inputs); + std::vector<NDArray*> backwards_outputs(backwards_attrs.num_outputs); + std::vector<NDArray*> backwards_ex_outputs(backwards_attrs.num_outputs); + + std::vector<OpReqType> req(forward_attrs.num_outputs); + std::vector<OpReqType> back_req(backwards_attrs.num_outputs); + + TestArrayShapes tas = GetTestArrayShapes(); + std::vector<mkldnn::memory::primitive_desc> pds = tas.pds; + + std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(forward_attrs.input_types, true); + std::vector<std::vector<NDArrayAttrs>> out_arrs(forward_attrs.num_outputs); + std::vector<std::vector<NDArrayAttrs>> ex_out_arrs(forward_attrs.num_outputs); + + std::string str_hid = const_cast<OpAttrs&>(forward_attrs).attrs.dict["num_hidden"]; + int num_hid = std::stoi(str_hid); + + if (forward_attrs.requests.find(OpReqType::kWriteTo) != forward_attrs.requests.end()) { + for (int i1 = 0; i1 < in_arrs.size(); i1++) { + auto in_arr = in_arrs[i1]; + + if (in_arr.arr.shape().ndim() < 2) + continue; + + auto in_shape = in_arr.arr.shape(); + + nnvm::TShape wt_shape(2); + wt_shape[0] = num_hid; + wt_shape[1] = weight_dim2(in_shape); + NDArray weights(wt_shape, Context()); + InitDefaultArray(&weights, false); + + nnvm::TShape bias_shape(1); + bias_shape[0] = num_hid; + NDArray bias(bias_shape, Context()); + InitDefaultArray(&bias, false); + + inputs[0] = &in_arr.arr; + inputs[1] = &weights; + inputs[2] = &bias; + + nnvm::TShape out_shape(2); + out_shape[0] = in_shape[0]; + out_shape[1] = num_hid; + + for (int i = 0; i < forward_attrs.num_outputs; i++) { + out_arrs[i] = + GetTestOutputArrays(out_shape, pds, {1}, forward_attrs.output_types); + ex_out_arrs[i] = + GetTestOutputArrays(out_shape, pds, {1}, forward_attrs.output_types); + } + + for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) { + if (out_arrs[0][output_i].arr.IsMKLDNNData()) Review comment: why we filtering out mkldnn here? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services