zheng-da commented on a change in pull request #8302: Refactor operators & 
MKLDNN
URL: https://github.com/apache/incubator-mxnet/pull/8302#discussion_r156520457
 
 

 ##########
 File path: src/ndarray/ndarray.cc
 ##########
 @@ -64,17 +158,147 @@ nnvm::Symbol NDArray::get_autograd_symbol() const {
   return ret;
 }
 
+#if MXNET_USE_MKLDNN == 1
+
+/*
+ * Returns the plain (non-blocked) MKLDNN format corresponding to the format of
+ * `desc`, keeping the same number of dimensions:
+ *   - 1-D: the format is returned unchanged.
+ *   - 2-D: mkldnn_io is mapped to mkldnn_oi; any other format is unchanged.
+ *   - 4-D: nchw, nhwc, chwn, nChw8c and nChw16c map to mkldnn_nchw; oihw,
+ *          ihwo, hwio and the listed blocked OIhw-family formats map to
+ *          mkldnn_oihw.
+ *   - 5-D: goihw and the listed blocked gOIhw-family formats map to
+ *          mkldnn_goihw.
+ * Any other dimensionality, or a format not listed for its dimensionality, is
+ * reported with LOG(FATAL). The mkldnn_format_undef returns that follow it
+ * only give every control path a return value.
+ */
+static inline mkldnn_memory_format_t GetDefaultFormat(
+    const mkldnn::memory::desc &desc) {
+  if (desc.data.ndims == 1) {
+    return desc.data.format;
+  } else if (desc.data.ndims == 2) {
+    return desc.data.format == mkldnn_io ? mkldnn_oi : desc.data.format;
+  } else if (desc.data.ndims == 4) {
+    switch (desc.data.format) {
+      case mkldnn_nchw:
+      case mkldnn_nhwc:
+      case mkldnn_chwn:
+      case mkldnn_nChw8c:
+      case mkldnn_nChw16c:
+        return mkldnn_nchw;
+      case mkldnn_oihw:
+      case mkldnn_ihwo:
+      case mkldnn_hwio:
+      case mkldnn_OIhw8i8o:
+      case mkldnn_OIhw16i16o:
+      case mkldnn_OIhw8i16o2i:
+      case mkldnn_OIhw8o16i2o:
+      case mkldnn_OIhw8o8i:
+      case mkldnn_OIhw16o16i:
+      case mkldnn_IOhw16o16i:
+      case mkldnn_Oihw8o:
+      case mkldnn_Oihw16o:
+      case mkldnn_Ohwi8o:
+      case mkldnn_Ohwi16o:
+      case mkldnn_OhIw16o4i:
+        return mkldnn_oihw;
+      default:
+        LOG(FATAL) << "Unknown MKLDNN format for 4 dimensions: " << desc.data.format;
+        return mkldnn_format_undef;
+    }
+  } else if (desc.data.ndims == 5) {
+    switch (desc.data.format) {
+      case mkldnn_goihw:
+      case mkldnn_gOIhw8i8o:
+      case mkldnn_gOIhw16i16o:
+      case mkldnn_gOIhw8i16o2i:
+      case mkldnn_gOIhw8o16i2o:
+      case mkldnn_gOIhw8o8i:
+      case mkldnn_gOIhw16o16i:
+      case mkldnn_gIOhw16o16i:
+      case mkldnn_gOihw8o:
+      case mkldnn_gOihw16o:
+      case mkldnn_gOhwi8o:
+      case mkldnn_gOhwi16o:
+      case mkldnn_gOhIw16o4i:
+        return mkldnn_goihw;
+      default:
+        // This branch handles 5-D descriptors; the message must say so.
+        LOG(FATAL) << "Unknown MKLDNN format for 5 dimensions: " << desc.data.format;
+        return mkldnn_format_undef;
+    }
+  } else {
+    LOG(FATAL) << "Unsupported dimensions: " << desc.data.ndims;
+    return mkldnn_format_undef;
+  }
+}
+
+/*
+ * If `mem` already uses the format chosen by GetDefaultFormat(), `mem` itself
+ * is returned. Otherwise a new memory is created on the same engine, with the
+ * same dims and data type but that default format. `mem`, the new memory, and
+ * a reorder primitive from `mem` into it are then registered on
+ * MKLDNNStream::Instance(), and the new memory is returned. The stream is
+ * submitted right away only when `submit_now` is true.
+ */
+static inline mkldnn_mem_ptr Reorder2Default(mkldnn_mem_ptr mem,
+                                             bool submit_now = true) {
+  mkldnn::memory::primitive_desc src_pd = mem->get_primitive_desc();
+  mkldnn::memory::desc src_desc = src_pd.desc();
+  auto default_format = GetDefaultFormat(src_desc);
+
+  // Nothing to convert: the memory is already in its default layout.
+  const bool already_default = (default_format == src_desc.data.format);
+  if (already_default)
+    return mem;
+
+  // Same shape and element type as the source, but in the default layout.
+  mkldnn::memory::dims src_dims(src_desc.data.dims,
+                                src_desc.data.dims + src_desc.data.ndims);
+  mkldnn::memory::desc default_md(
+      src_dims,
+      static_cast<mkldnn::memory::data_type>(src_desc.data.data_type),
+      static_cast<mkldnn::memory::format>(default_format));
+  mkldnn_mem_ptr default_mem(new mkldnn::memory(
+      mkldnn::memory::primitive_desc(default_md, src_pd.get_engine())));
+
+  // Queue the conversion: both memories are registered with the stream
+  // before the reorder primitive that reads `mem` and writes `default_mem`.
+  auto &stream = MKLDNNStream::Instance();
+  stream.RegisterMem(mem);
+  stream.RegisterMem(default_mem);
+  stream.RegisterPrim(mkldnn::reorder(*mem, *default_mem));
+  if (submit_now) stream.Submit();
+  return default_mem;
+}
+
+NDArray NDArray::ReshapeMKLDNN(const TShape &shape) const {
+  CHECK(!is_none()) << "NDArray is not initialized";
+  CHECK_GE(shape_.Size(), shape.Size())
+    << "NDArray.Reshape: target shape size is larger current shape";
+  if (storage_type() == kDefaultStorage) {
+    NDArray ret = this->Detach();
+    ret.shape_ = shape;
+    return ret;
+  } else if (storage_type() == kMKLDNNStorage) {
+    NDArray ret(kMKLDNNStorage, shape, ctx(), ptr_->delay_alloc, dtype());
 
 Review comment:
   Actually, we should avoid allocating memory here. Should I always pass `true`
   here instead of `ptr_->delay_alloc`?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to