piiswrong closed pull request #10979: Fix bugs in MKLDNN.
URL: https://github.com/apache/incubator-mxnet/pull/10979
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index d87e8bc95ea..94d3d90413a 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -200,6 +200,7 @@ NDArray NDArray::MKLDNNDataReshape(const TShape &shape) const {
     ret.ptr_->delay_alloc = false;
     ret.ptr_->static_data = true;
     ret.byte_offset_ = byte_offset_;
+    ret.reuse_ = false;
     return ret;
   }
 }
@@ -217,6 +218,7 @@ NDArray NDArray::Reshape(const TShape &shape) const {
   // Otherwise, reshape only works on the default layout.
   CHECK_EQ(storage_type(), kDefaultStorage);
   ret.shape_ = shape;
+  ret.reuse_ = false;
   return ret;
 }
 
@@ -249,6 +251,7 @@ NDArray NDArray::Slice(index_t begin, index_t end) const {
   MSHADOW_TYPE_SWITCH(ret.dtype(), DType, {
     ret.byte_offset_ += begin * length * sizeof(DType);
   });
+  ret.reuse_ = false;
   ret.shape_[0] = end - begin;
   return ret;
 }
@@ -555,6 +558,7 @@ NDArray NDArray::Reorder2Default() const {
   // reshape as needed
   ret.shape_ = shape_;
   ret.byte_offset_ = byte_offset_;
+  ret.reuse_ = false;
   return ret;
 }
 
@@ -584,39 +588,39 @@ void NDArray::MKLDNNDataReorderAsync(const mkldnn::memory::primitive_desc &desc)
 
 const mkldnn::memory *NDArray::GetMKLDNNData() const {
   CHECK(storage_type() == kDefaultStorage);
+  bool is_view = IsView();
   if (IsMKLDNNData()) {
     // If this array uses MKLDNN layout, we have to make sure it's not a view.
     // Otherwise, we'll have to change the layout inside the array.
-    CHECK(!IsView());
+    CHECK(!is_view);
     MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
     // If this array uses MKLDNN format, we should return now. Otherwise,
     // SetMKLMem may mess up mkl_mem_.
     return ptr_->mkl_mem_->GetRaw();
-  }
-  ptr_->SetMKLMem(IsView() ? ptr_->storage_shape : shape_, dtype_);
-  MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
-  if (IsView()) {
-    mkldnn::memory::primitive_desc pd = ptr_->mkl_mem_->GetPrimitiveDesc();
-    // Sliced array must use the default layout.
-    CHECK_EQ(GetDefaultFormat(pd.desc()), pd.desc().data.format);
-    void *off_addr = static_cast<char *>(ptr_->mkl_mem_->GetDataHandle())
-        + byte_offset_;
-
+  } else if (is_view) {
+    // If this is a view, we can't create a MKLDNN memory for the chunk
+    // because we don't have the complete data type and shape information for
+    // the chunk.
+    void *off_addr = static_cast<char *>(ptr_->shandle.dptr) + byte_offset_;
     // Create the primitive desc for the new mkldnn memory.
     mkldnn::memory::dims dims(shape().ndim());
     for (size_t i = 0; i < dims.size(); i++)
       dims[i] = shape()[i];
     mkldnn::memory::format cpp_format = static_cast<mkldnn::memory::format>(
         GetDefaultFormat(shape().ndim()));
-    mkldnn::memory::data_type cpp_type = static_cast<mkldnn::memory::data_type>(
-        pd.desc().data.data_type);
+    mkldnn::memory::data_type cpp_type = get_mkldnn_type(dtype_);
     mkldnn::memory::desc data_md(dims, cpp_type, cpp_format);
-    mkldnn::memory::primitive_desc new_pd(data_md, pd.get_engine());
+    mkldnn::memory::primitive_desc new_pd(data_md,
+                                          CpuEngine::Get()->get_engine());
 
     std::shared_ptr<mkldnn::memory> ret(new mkldnn::memory(new_pd, off_addr));
     MKLDNNStream::Get()->RegisterMem(ret);
     return ret.get();
   } else {
+    // If this isn't a view, we can create a MKLDNN memory and store it in the
+    // chunk.
+    ptr_->SetMKLMem(shape_, dtype_);
+    MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
     return ptr_->mkl_mem_->GetRaw();
   }
 }
@@ -637,20 +641,23 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
   MKLDNNStream *stream = MKLDNNStream::Get();
   // If this array uses MKLDNN layout, we have to make sure it's not a view.
   // Otherwise, we'll have to change the layout inside the array.
-  if (IsMKLDNNData())
-    CHECK(!IsView());
-  ptr_->SetMKLMem(IsView() ? ptr_->storage_shape : shape_,
-                  dtype_);
-  stream->RegisterMem(ptr_->mkl_mem_->GetMem());
-  mkldnn::memory::desc from_desc = mem.get_primitive_desc().desc();
-  mkldnn::memory::desc this_desc = ptr_->mkl_mem_->GetPrimitiveDesc().desc();
+
+  if (IsMKLDNNData() && IsView())
+    ptr_->Reorder2Default();
+
+  const mkldnn::memory *this_mem = GetMKLDNNData();
+  mkldnn::memory::primitive_desc from_pd = mem.get_primitive_desc();
+  mkldnn::memory::desc from_desc = from_pd.desc();
+  mkldnn::memory::primitive_desc this_pd = this_mem->get_primitive_desc();
+  mkldnn::memory::desc this_desc = this_pd.desc();
   mkldnn_memory_format_t from_def_format = GetDefaultFormat(from_desc);
+  mkldnn_memory_format_t this_def_format = GetDefaultFormat(this_desc);
   if (IsView()) {
     // Sliced array must use the default layout.
     CHECK_EQ(GetDefaultFormat(this_desc), this_desc.data.format);
   }
   // It's possible that the memory and the NDArray don't have the same shape.
-  if (!same_shape(shape_, from_desc.data.dims, from_desc.data.ndims)
+  if (!same_shape(this_desc, from_desc)
       // If the source memory uses the default layout, we can reshape directly.
       && from_def_format == from_desc.data.format) {
     // In this case, we can simply create a new MKLDNN memory for the required
@@ -660,15 +667,14 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
    auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
    auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
     mkldnn::memory::desc data_md(dims, this_dtype, this_format);
-    mkldnn::memory::primitive_desc pd(data_md, mem.get_primitive_desc().get_engine());
+    mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine());
     mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
     stream->RegisterMem(tmp_mem);
-    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_->GetRaw()));
-  } else if (!same_shape(shape_, from_desc.data.dims, from_desc.data.ndims)) {
+    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
+  } else if (!same_shape(this_desc, from_desc)) {
     // In this case, the source memory stores data in a customized layout. We
     // need to reorganize the data in memory before we can reshape.
-    mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(mem.get_primitive_desc(),
-                                                             from_def_format);
+    mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(from_pd, from_def_format);
     mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_pd);
     stream->RegisterPrim(mkldnn::reorder(mem, *def_mem));
     // Now we can reshape it
@@ -677,45 +683,40 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
    auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
    auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
     mkldnn::memory::desc data_md(dims, this_dtype, this_format);
-    mkldnn::memory::primitive_desc pd(data_md, mem.get_primitive_desc().get_engine());
+    mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine());
     mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, def_mem->get_data_handle()));
     stream->RegisterMem(tmp_mem);
-    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_->GetRaw()));
-  } else if (mem.get_primitive_desc() == ptr_->mkl_mem_->GetPrimitiveDesc()) {
+    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
+  } else if (from_pd == this_pd) {
     // If the layout is the same, we can just copy data.
-    stream->RegisterPrim(mkldnn::reorder(mem, *ptr_->mkl_mem_->GetRaw()));
+    stream->RegisterPrim(mkldnn::reorder(mem, *this_mem));
   } else {
-    mkldnn_memory_format_t src_def = GetDefaultFormat(mem.get_primitive_desc().desc());
-    mkldnn_memory_format_t dst_def = ptr_->mkl_mem_->GetDefaultFormat();
     // If both are not using the default layouts. There isn't much we can do,
     // other than reorder data layout directly.
-    if (dst_def != ptr_->mkl_mem_->GetFormat()
-        && src_def != mem.get_primitive_desc().desc().data.format) {
-      stream->RegisterPrim(mkldnn::reorder(mem, *ptr_->mkl_mem_->GetRaw()));
-    } else if (dst_def == ptr_->mkl_mem_->GetFormat()) {
+    if (this_def_format != this_desc.data.format
+        && from_def_format != from_desc.data.format) {
+      stream->RegisterPrim(mkldnn::reorder(mem, *this_mem));
+    } else if (this_def_format == this_desc.data.format) {
       // If the dest mem uses the default memory layout, we can simply use
       // the default format of the source memory to improve perf of reorder.
-      mkldnn::memory::primitive_desc pd = ptr_->mkl_mem_->GetPrimitiveDesc(src_def);
-      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, ptr_->mkl_mem_->GetDataHandle()));
+      mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(from_pd,
+                                                           from_def_format);
+      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, this_mem->get_data_handle()));
       stream->RegisterMem(tmp_mem);
       stream->RegisterPrim(mkldnn::reorder(mem, *tmp_mem));
     } else {
       // If the src mem uses the default memory layout, we can use
       // the default format of the source memory to improve perf.
-      mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(mem.get_primitive_desc(), dst_def);
+      mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(this_pd,
+                                                           this_def_format);
       mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
       stream->RegisterMem(tmp_mem);
-      stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_->GetRaw()));
+      stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
     }
   }
 }
-mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc pd,
-                                                mkldnn_memory_format_t format);
 
 mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc &desc) {
-  // This array shouldn't be a view.
-  CHECK(!IsView());
-
   if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
    LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc";
     return nullptr;
@@ -726,10 +727,26 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc &
   mkldnn_memory_format_t def_format = GetDefaultFormat(_desc.desc());
  // If the required format is a default format, we don't need to worry about the shape.
   // If the shape isn't the same, it actually implicitly reshapes data.
-  if (required_format == def_format) {
+  if (required_format == def_format && !IsView()) {
     ptr_->SetMKLMem(shape_, dtype_);
     MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
     return GetMKLDNNExact(ptr_->mkl_mem_->GetRaw(), desc);
+  } else if (required_format == def_format) {
+    ptr_->CheckAndAlloc();
+    CHECK(ptr_->shandle.dptr);
+    // When this is a view and a user wants the default layout, we can simply
+    // create a new mkldnn memory that points to the right memory.
+    std::shared_ptr<mkldnn::memory> mem(new mkldnn::memory(
+            desc, static_cast<char *>(ptr_->shandle.dptr) + byte_offset_));
+    MKLDNNStream::Get()->RegisterMem(mem);
+    return mem.get();
+  } else if (IsView()) {
+    // If this is a view and a user wants to write data to it with special
+    // a MKLDNN format, we should reorder the data in the array and return NULL.
+    // In this way, the user will create a new NDArray for the special format
+    // and copy data back.
+    ptr_->Reorder2Default();
+    return nullptr;
   }
 
   if (ptr_->mkl_mem_)
diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc
index 56abc6b79ff..1bd1581dbc2 100644
--- a/src/operator/nn/mkldnn/mkldnn_base.cc
+++ b/src/operator/nn/mkldnn/mkldnn_base.cc
@@ -297,10 +297,7 @@ void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs,
     } else {
       if (in_bufs.empty())
         in_bufs.reserve(inputs.size());
-      in_bufs.emplace_back(inputs[i].shape(), inputs[i].ctx(),
-                           false, inputs[i].dtype());
-      const mkldnn::memory *mem = inputs[i].GetMKLDNNData();
-      in_bufs.back().CopyFrom(*mem);
+      in_bufs.push_back(inputs[i].Reorder2Default());
       in_blobs[i] = in_bufs.back().data();
     }
   }
diff --git a/src/operator/nn/mkldnn/mkldnn_sum.cc b/src/operator/nn/mkldnn/mkldnn_sum.cc
index 361ce58af07..fdbfb1558f6 100644
--- a/src/operator/nn/mkldnn/mkldnn_sum.cc
+++ b/src/operator/nn/mkldnn/mkldnn_sum.cc
@@ -59,8 +59,15 @@ void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   std::vector<float> scales(inputs.size(), 1);
   in_prims.reserve(inputs.size());
   bool pd_same = true;
+  std::vector<NDArray> in_bufs(inputs.size());
   for (size_t i = 0; i < inputs.size(); i++) {
-    auto in_mem = inputs[i].GetMKLDNNData();
+    const mkldnn::memory *in_mem;
+    if (inputs[i].IsMKLDNNData() && inputs[i].IsView()) {
+      in_bufs[i] = inputs[i].Reorder2Default();
+      in_mem = in_bufs[i].GetMKLDNNData();
+    } else {
+      in_mem = inputs[i].GetMKLDNNData();
+    }
     in_prims.push_back(*in_mem);
     in_pds[i] = in_mem->get_primitive_desc();
   }
@@ -68,9 +75,16 @@ void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   mkldnn::sum::primitive_desc pdesc(scales, in_pds);
   pd_same = pd_same && (pdesc.dst_primitive_desc() == in_pds[0]);
   auto out_mem = const_cast<NDArray&>(out_data).CreateMKLDNNData(pdesc.dst_primitive_desc());
-  bool addr_same = out_mem->get_data_handle() == inputs[0].GetMKLDNNData()->get_data_handle();
-  if ((req == kWriteTo) ||
-      (req == kWriteInplace && pd_same && addr_same)) {
+  bool addr_same = false;
+  const void *first_data_handle;
+  if (in_bufs[0].is_none())
+    first_data_handle = inputs[0].GetMKLDNNData()->get_data_handle();
+  else
+    first_data_handle = in_bufs[0].GetMKLDNNData()->get_data_handle();
+  if (out_mem)
+    addr_same = out_mem->get_data_handle() == first_data_handle;
+  if (((req == kWriteTo) || (req == kWriteInplace && pd_same && addr_same))
+      && out_mem) {
     // do sum computation directly on output NDArray
     MKLDNNStream *stream = MKLDNNStream::Get();
     stream->RegisterPrim(mkldnn::sum(pdesc, in_prims, *out_mem));
diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc
index dbb26ead149..9b5b9d39248 100644
--- a/src/operator/tensor/elemwise_binary_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_op_basic.cc
@@ -43,16 +43,8 @@ static void ElemwiseAddEx(const nnvm::NodeAttrs& attrs,
     return;
   } else if (inputs[0].storage_type() == kDefaultStorage
              && inputs[1].storage_type() == kDefaultStorage) {
-    // This happens if inputs are supposed to be in MKLDNN format
-    // but MKLDNN doesn't support the data type or the shape. We're
-    // forced to convert it to the default format.
-    std::vector<TBlob> in_blobs(2);
-    std::vector<TBlob> out_blobs(1);
-    in_blobs[0] = inputs[0].data();
-    in_blobs[1] = inputs[1].data();
-    out_blobs[0] = outputs[0].data();
-    ElemwiseBinaryOp::Compute<cpu, op::mshadow_op::plus>(attrs, ctx, in_blobs,
-                                                         req, out_blobs);
+    FallBackCompute(ElemwiseBinaryOp::Compute<cpu, op::mshadow_op::plus>,
+                    attrs, ctx, inputs, req, outputs);
     return;
   }
 #endif
diff --git a/tests/cpp/operator/mkldnn.cc b/tests/cpp/operator/mkldnn.cc
index bbff53152ec..1d29219105d 100644
--- a/tests/cpp/operator/mkldnn.cc
+++ b/tests/cpp/operator/mkldnn.cc
@@ -89,12 +89,17 @@ TEST(MKLDNN_UTIL_FUNC, MemFormat) {
 }
 
 // Init arrays with the default layout.
-static void InitArray(NDArray *arr) {
+static void InitArray(NDArray *arr, bool is_rand = false) {
   const TBlob &blob = arr->data();
   mshadow::default_real_t *data = blob.dptr<mshadow::default_real_t>();
   size_t size = blob.Size();
-  for (size_t i = 0; i < size; i++)
-    data[i] = i;
+  if (is_rand) {
+    for (size_t i = 0; i < size; i++)
+      data[i] = std::rand();
+  } else {
+    for (size_t i = 0; i < size; i++)
+      data[i] = i;
+  }
 }
 
 // Init arrays with the specified layout.
@@ -361,6 +366,15 @@ OpAttrs GetLeakyReluOp() {
   return attrs;
 }
 
+OpAttrs GetSumOp() {
+  OpAttrs attrs;
+  attrs.attrs.op = Op::Get("elemwise_add");
+  attrs.dispatches.resize(2);
+  attrs.dispatches[0] = DispatchMode::kFCompute;
+  attrs.dispatches[1] = DispatchMode::kFComputeEx;
+  return attrs;
+}
+
 /*
  * We want to get a few types of NDArrays for testing:
  * 1. Normal NDArray
@@ -418,34 +432,65 @@ std::vector<NDArray> GetTestInputArrays() {
  *    pass them to all operators.
  *    In the inference mode, the MKLDNN memory in the weight array will be
  *    reordered to 5 dimensions.
- * 4. Reused NDArray (this is created by the MXNet executor). This type of
+ * 4. Reshaped/sliced NDArray
+ * 5. Reused NDArray (this is created by the MXNet executor). This type of
  *    NDArrays can only be used as output arrays.
+ * 6. Reused NDArray converted from an array with a different data type.
+ * 7. Reused reshaped/sliced NDArray.
+ * 8. Reused NDArray with MKLDNN layout.
+ * 9. Reused NDArray with MKLDNN layout of different dimensions.
  */
 std::vector<NDArray> GetTestOutputArrays(const TShape &shape,
                                         const std::vector<mkldnn::memory::primitive_desc> &pds) {
   std::vector<NDArray> in_arrs;
+  // Type 1.
   in_arrs.emplace_back(shape, Context());
-  InitArray(&in_arrs.back());
+  InitArray(&in_arrs.back(), true);
+
+  // Type 4.
+  TShape tmp_shape = shape;
+  tmp_shape[0] = shape[0] * 2;
+  NDArray arr0(tmp_shape, Context());
+  InitArray(&arr0, true);
+  in_arrs.emplace_back(arr0.Slice(1, shape[0] + 1));
 
+  // Type 5.
   // Get a reused version.
   nnvm::TShape s(1);
   s[0] = shape.Size();
-  NDArray arr(s, Context());
-  arr = arr.AsArray(shape, arr.dtype());
-  InitArray(&arr);
-  in_arrs.emplace_back(arr);
+  NDArray arr1(s, Context());
+  arr1 = arr1.AsArray(shape, arr1.dtype());
+  InitArray(&arr1, true);
+  in_arrs.emplace_back(arr1);
+
+  // Type 6.
+  s[0] = shape.Size() * GetTypeSize(mshadow::default_type_flag);
+  NDArray arr2(s, Context(), true, mshadow::kUint8);
+  arr2 = arr2.AsArray(shape, mshadow::default_type_flag);
+  InitArray(&arr2, true);
+  in_arrs.emplace_back(arr2);
+
+  // Type 7
+  s[0] = shape.Size() * GetTypeSize(mshadow::default_type_flag) * 2;
+  NDArray arr3(s, Context(), true, mshadow::kUint8);
+  tmp_shape[0] = shape[0] * 2;
+  arr3 = arr3.AsArray(tmp_shape, mshadow::default_type_flag);
+  InitArray(&arr3, true);
+  in_arrs.emplace_back(arr3.Slice(1, shape[0] + 1));
 
   for (auto pd : pds) {
     if (shape.Size() != pd.get_size() / sizeof(mshadow::default_real_t))
       continue;
 
+    // Type 2, 3.
     in_arrs.emplace_back(shape, Context());
     InitMKLDNNArray(&in_arrs.back(), pd, true);
 
+    // Type 8, 9.
     // Get a reused version.
     nnvm::TShape s(1);
     s[0] = shape.Size();
-    arr = NDArray(s, Context());
+    NDArray arr = NDArray(s, Context());
     arr = arr.AsArray(shape, arr.dtype());
     InitMKLDNNArray(&arr, pd, true);
     in_arrs.emplace_back(arr);
@@ -453,10 +498,10 @@ std::vector<NDArray> GetTestOutputArrays(const TShape &shape,
   return in_arrs;
 }
 
-using VerifyFunc = std::function<void (const NDArray &in_arr, const NDArray &arr)>;
+using VerifyFunc = std::function<void (const std::vector<NDArray *> &in_arrs, const NDArray &arr)>;
 
-void VerifyCopyResult(const NDArray &in_arr, const NDArray &arr) {
-  NDArray tmp1 = in_arr.Reorder2Default();
+void VerifyCopyResult(const std::vector<NDArray *> &in_arrs, const NDArray &arr) {
+  NDArray tmp1 = in_arrs[0]->Reorder2Default();
   NDArray tmp2 = arr.Reorder2Default();
   EXPECT_EQ(tmp1.shape().Size(), tmp2.shape().Size());
   TBlob d1 = tmp1.data();
@@ -465,6 +510,40 @@ void VerifyCopyResult(const NDArray &in_arr, const NDArray &arr) {
                    tmp1.shape().Size() * sizeof(mshadow::default_real_t)), 0);
 }
 
+void VerifySumResult(const std::vector<NDArray *> &in_arrs, const NDArray &arr) {
+  NDArray in1 = in_arrs[0]->Reorder2Default();
+  NDArray in2 = in_arrs[1]->Reorder2Default();
+  NDArray out = arr.Reorder2Default();
+  EXPECT_EQ(in1.shape().Size(), in2.shape().Size());
+  EXPECT_EQ(in1.shape().Size(), out.shape().Size());
+
+  mshadow::default_real_t *d1 = in1.data().dptr<mshadow::default_real_t>();
+  mshadow::default_real_t *d2 = in2.data().dptr<mshadow::default_real_t>();
+  mshadow::default_real_t *o = out.data().dptr<mshadow::default_real_t>();
+  for (size_t i = 0; i < in1.shape().Size(); i++)
+    EXPECT_EQ(d1[i] + d2[i], o[i]);
+}
+
+TEST(MKLDNN_NDArray, CopyFrom) {
+  TestArrayShapes tas = GetTestArrayShapes();
+  std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
+
+  std::vector<NDArray> in_arrs = GetTestInputArrays();
+  for (auto in_arr : in_arrs) {
+    std::vector<NDArray> out_arrs = GetTestOutputArrays(in_arr.shape(), pds);
+    for (auto out_arr : out_arrs) {
+      if (in_arr.IsMKLDNNData() && in_arr.IsView())
+        in_arr = in_arr.Reorder2Default();
+      const mkldnn::memory *mem = in_arr.GetMKLDNNData();
+      out_arr.CopyFrom(*mem);
+      MKLDNNStream::Get()->Submit();
+      std::vector<NDArray *> inputs(1);
+      inputs[0] = &in_arr;
+      VerifyCopyResult(inputs, out_arr);
+    }
+  }
+}
+
 void TestUnaryOp(const OpAttrs &attrs, VerifyFunc verify_fn) {
   std::vector<NDArray*> inputs(1);
   std::vector<NDArray*> outputs(1);
@@ -485,7 +564,53 @@ void TestUnaryOp(const OpAttrs &attrs, VerifyFunc verify_fn) {
         Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs,
                                    outputs, req, dispatch, mxnet::OpStatePtr());
         out_arr.WaitToRead();
-        verify_fn(in_arr, out_arr);
+        verify_fn(inputs, out_arr);
+      }
+    }
+  }
+
+  for (auto dispatch : dispatches) {
+    in_arrs = GetTestInputArrays();
+    for (auto arr : in_arrs) {
+      // If the array is a view, we shouldn't write data to it.
+      if (arr.IsView())
+        continue;
+
+      NDArray orig = arr.Copy(arr.ctx());
+      req[0] = kWriteInplace;
+      inputs[0] = &arr;
+      outputs[0] = &arr;
+      Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, outputs, req,
+                                  dispatch, mxnet::OpStatePtr());
+      arr.WaitToRead();
+      inputs[0] = &orig;
+      verify_fn(inputs, arr);
+    }
+  }
+}
+
+void TestBinaryOp(const OpAttrs &attrs, VerifyFunc verify_fn) {
+  std::vector<NDArray*> inputs(2);
+  std::vector<NDArray*> outputs(1);
+  std::vector<OpReqType> req(1);
+  std::vector<DispatchMode> dispatches = attrs.dispatches;
+
+  TestArrayShapes tas = GetTestArrayShapes();
+  std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
+
+  std::vector<NDArray> in_arrs = GetTestInputArrays();
+  for (auto in_arr1 : in_arrs) {
+    for (auto dispatch : dispatches) {
+      std::vector<NDArray> out_arrs = GetTestOutputArrays(in_arr1.shape(), pds);
+      for (auto out_arr : out_arrs) {
+        req[0] = kWriteTo;
+        inputs[0] = &in_arr1;
+        inputs[1] = &in_arr1;
+        outputs[0] = &out_arr;
+        Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs,
+                                    outputs, req, dispatch, mxnet::OpStatePtr());
+        out_arr.WaitToRead();
+        verify_fn(inputs, out_arr);
       }
     }
   }
@@ -500,11 +625,15 @@ void TestUnaryOp(const OpAttrs &attrs, VerifyFunc verify_fn) {
       NDArray orig = arr.Copy(arr.ctx());
       req[0] = kWriteInplace;
       inputs[0] = &arr;
+      inputs[1] = &arr;
       outputs[0] = &arr;
       Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, outputs, req,
                                   dispatch, mxnet::OpStatePtr());
       arr.WaitToRead();
-      verify_fn(orig, arr);
+      std::vector<NDArray *> orig_inputs(2);
+      orig_inputs[0] = &orig;
+      orig_inputs[1] = &orig;
+      verify_fn(orig_inputs, arr);
     }
   }
 }
@@ -514,4 +643,10 @@ TEST(IMPERATIVE, UnaryOp) {
   TestUnaryOp(attrs, VerifyCopyResult);
 }
 
+
+TEST(IMPERATIVE, BinaryOp) {
+  OpAttrs attrs = GetSumOp();
+  TestBinaryOp(attrs, VerifySumResult);
+}
+
 #endif
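
For reference, the view branch added to GetMKLDNNData() boils down to wrapping the chunk's storage pointer, shifted by byte_offset_, in an MKLDNN memory that uses the default (plain) layout. Below is a minimal standalone sketch of that pattern against the MKLDNN 0.x C++ API; the buffer, shape, and offset are invented for illustration and this is not MXNet code:

#include <mkldnn.hpp>
#include <vector>

int main() {
  // CPU engine, standing in for CpuEngine::Get()->get_engine() in the patch.
  mkldnn::engine eng(mkldnn::engine::cpu, 0);

  // A flat buffer standing in for the chunk's storage handle; the view
  // starts at a byte offset into it (byte_offset_ in the patch).
  std::vector<float> storage(16 + 2 * 3 * 4 * 5, 0.0f);
  void *off_addr = reinterpret_cast<char *>(storage.data()) + 16 * sizeof(float);

  // Describe the view with its own shape, its data type and the default
  // format for that number of dimensions.
  mkldnn::memory::dims dims = {2, 3, 4, 5};
  mkldnn::memory::desc data_md(dims, mkldnn::memory::data_type::f32,
                               mkldnn::memory::format::nchw);
  mkldnn::memory::primitive_desc new_pd(data_md, eng);

  // Wrap the offset pointer without copying, as the view branch of
  // GetMKLDNNData() now does for arrays not stored in an MKLDNN layout.
  mkldnn::memory view_mem(new_pd, off_addr);
  (void)view_mem;
  return 0;
}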
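
The mkldnn_sum.cc change keeps driving the same sum primitive; the inputs are merely reordered to the default layout first when they are MKLDNN-formatted views. A minimal sketch of the primitive's plain usage (again MKLDNN 0.x C++ API with toy buffers; MXNet's MKLDNNStream and NDArray plumbing is omitted) could look like:

#include <mkldnn.hpp>
#include <vector>

int main() {
  mkldnn::engine eng(mkldnn::engine::cpu, 0);

  // Two f32 inputs and one output, all in the plain nchw layout.
  mkldnn::memory::dims dims = {1, 3, 4, 4};
  mkldnn::memory::desc md(dims, mkldnn::memory::data_type::f32,
                          mkldnn::memory::format::nchw);
  mkldnn::memory::primitive_desc pd(md, eng);

  std::vector<float> a(3 * 4 * 4, 1.0f), b(3 * 4 * 4, 2.0f), c(3 * 4 * 4);
  mkldnn::memory in0(pd, a.data()), in1(pd, b.data()), out(pd, c.data());

  // Per-input scales and primitive descs, as MKLDNNSumForward collects them.
  std::vector<float> scales(2, 1.0f);
  std::vector<mkldnn::memory::primitive_desc> in_pds = {pd, pd};
  mkldnn::sum::primitive_desc pdesc(scales, in_pds);

  std::vector<mkldnn::primitive::at> in_prims = {in0, in1};
  mkldnn::sum sum_prim(pdesc, in_prims, out);

  // Submit eagerly; in MXNet this goes through MKLDNNStream::Submit().
  mkldnn::stream(mkldnn::stream::kind::eager).submit({sum_prim}).wait();
  return 0;
}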


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
