This is an automated email from the ASF dual-hosted git repository. jxie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new a40055e [MXNET-546] Add unit test for MKLDNNSum (#11272) a40055e is described below commit a40055eeb98326120642039e29732d2ca39a7610 Author: Alexander Zai <aza...@gmail.com> AuthorDate: Thu Jun 14 10:26:46 2018 -0700 [MXNET-546] Add unit test for MKLDNNSum (#11272) * add test for mkldnnsum * add extra comment * fix verify * only run if mkldnn supported * divide size by size of type * filter out regular arrays * fix cond * add verify mes * filter views * different outputs arrays * move print message * add in place test * update copy fn * refactor copyfrom * use arr.copy instead of tmpmemmg * use InitMKLDNNArray helper * fix params * pass correct type to copyfrom * add print message for inplace sum * remove copyfrom refactor * remove redundant header * remove extra line * fix lint * retrigger --- tests/cpp/operator/mkldnn.cc | 46 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/cpp/operator/mkldnn.cc b/tests/cpp/operator/mkldnn.cc index 76872d5..a7a1187 100644 --- a/tests/cpp/operator/mkldnn.cc +++ b/tests/cpp/operator/mkldnn.cc @@ -30,6 +30,7 @@ #include "gtest/gtest.h" #include "mxnet/imperative.h" #include "../../src/operator/nn/mkldnn/mkldnn_base-inl.h" +#include "../../src/operator/nn/mkldnn/mkldnn_ops-inl.h" using namespace mxnet; @@ -736,4 +737,49 @@ TEST(IMPERATIVE, BinaryOp) { TestBinaryOp(attrs, VerifySumResult); } +void VerifySumMemory(mkldnn::memory in_mem1, mkldnn::memory in_mem2, mkldnn::memory out_mem) { + float *in1 = static_cast<float*>(in_mem1.get_data_handle()); + float *in2 = static_cast<float*>(in_mem2.get_data_handle()); + float *out = static_cast<float*>(out_mem.get_data_handle()); + for (size_t i = 0; i < in_mem1.get_primitive_desc().get_size() / sizeof(float); i++) { + ASSERT_EQ(in1[i] + in2[i], out[i]); + } +} + +TEST(MKLDNN_BASE, MKLDNNSum) { + std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(InitDefaultArray); + TestArrayShapes tas = GetTestArrayShapes(); + std::vector<mkldnn::memory::primitive_desc> pds = tas.pds; + + for (auto in_arr : in_arrs) { + std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr.arr.shape(), pds, + InitDefaultArray); + if (!SupportMKLDNN(in_arr.arr) || !in_arr.arr.IsMKLDNNData() || in_arr.arr.IsView()) + continue; + + for (auto out_arr : out_arrs) { + auto in_mem1 = in_arr.arr.GetMKLDNNData(); + auto in_mem2 = in_arr.arr.GetMKLDNNData(); + auto out_mem = out_arr.arr.GetMKLDNNData(in_mem1->get_primitive_desc()); + if (out_mem == nullptr) + continue; + PrintVerifyMsg(in_arr, in_arr); + op::MKLDNNSum(*in_mem1, *in_mem2, *out_mem); + MKLDNNStream::Get()->Submit(); + VerifySumMemory(*in_mem1, *in_mem2, *out_mem); + } + + // in place + auto input_mem = in_arr.arr.GetMKLDNNData(); + NDArrayAttrs orig_arr(in_arr.arr.Copy(in_arr.arr.ctx()), "In Place Copy"); + PrintVerifyMsg(orig_arr, in_arr); + InitMKLDNNArray(&orig_arr.arr, input_mem->get_primitive_desc(), InitDefaultArray); + orig_arr.arr.CopyFrom(*input_mem); + auto old_mem = orig_arr.arr.GetMKLDNNData(); + op::MKLDNNSum(*input_mem, *input_mem, *input_mem); + MKLDNNStream::Get()->Submit(); + VerifySumMemory(*old_mem, *old_mem, *input_mem); + } +} + #endif -- To stop receiving notification emails like this one, please contact j...@apache.org.