zheng-da commented on a change in pull request #8302: Refactor operators & MKLDNN
URL: https://github.com/apache/incubator-mxnet/pull/8302#discussion_r156519807
 
 

 ##########
 File path: src/ndarray/ndarray.cc
 ##########
 @@ -64,17 +158,147 @@ nnvm::Symbol NDArray::get_autograd_symbol() const {
   return ret;
 }
 
+#if MXNET_USE_MKLDNN == 1
+
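+// Map a (possibly blocked) MKLDNN layout to the plain row-major format of
+// the same rank; e.g. nChw8c and nChw16c both map back to nchw.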
+static inline mkldnn_memory_format_t GetDefaultFormat(mkldnn::memory::desc desc) {
+  if (desc.data.ndims == 1) {
+    return desc.data.format;
+  } else if (desc.data.ndims == 2) {
+    if (desc.data.format == mkldnn_io)
+      return mkldnn_oi;
+    else
+      return desc.data.format;
+  } else if (desc.data.ndims == 4) {
+    switch (desc.data.format) {
+      case mkldnn_nchw:
+      case mkldnn_nhwc:
+      case mkldnn_chwn:
+      case mkldnn_nChw8c:
+      case mkldnn_nChw16c:
+        return mkldnn_nchw;
+      case mkldnn_oihw:
+      case mkldnn_ihwo:
+      case mkldnn_hwio:
+      case mkldnn_OIhw8i8o:
+      case mkldnn_OIhw16i16o:
+      case mkldnn_OIhw8i16o2i:
+      case mkldnn_OIhw8o16i2o:
+      case mkldnn_OIhw8o8i:
+      case mkldnn_OIhw16o16i:
+      case mkldnn_IOhw16o16i:
+      case mkldnn_Oihw8o:
+      case mkldnn_Oihw16o:
+      case mkldnn_Ohwi8o:
+      case mkldnn_Ohwi16o:
+      case mkldnn_OhIw16o4i:
+        return mkldnn_oihw;
+      default:
+        LOG(FATAL) << "Unknown MKLDNN format for 4 dimensions: " << desc.data.format;
+        return mkldnn_format_undef;
+    }
+  } else if (desc.data.ndims == 5) {
+    switch (desc.data.format) {
+      case mkldnn_goihw:
+      case mkldnn_gOIhw8i8o:
+      case mkldnn_gOIhw16i16o:
+      case mkldnn_gOIhw8i16o2i:
+      case mkldnn_gOIhw8o16i2o:
+      case mkldnn_gOIhw8o8i:
+      case mkldnn_gOIhw16o16i:
+      case mkldnn_gIOhw16o16i:
+      case mkldnn_gOihw8o:
+      case mkldnn_gOihw16o:
+      case mkldnn_gOhwi8o:
+      case mkldnn_gOhwi16o:
+      case mkldnn_gOhIw16o4i:
+        return mkldnn_goihw;
+      default:
+        LOG(FATAL) << "Unknown MKLDNN format for 5 dimensions: " << desc.data.format;
+        return mkldnn_format_undef;
+    }
+  } else {
+    LOG(FATAL) << "Unsupported dimensions: " << desc.data.ndims;
+    return mkldnn_format_undef;
+  }
+}
+
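+// Reorder `mem` into the default layout for its rank. The reorder primitive
+// is only registered on the MKLDNN stream; with submit_now == false the
+// caller is responsible for submitting the stream later.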
+static inline mkldnn_mem_ptr Reorder2Default(mkldnn_mem_ptr mem,
+                                             bool submit_now = true) {
+  auto format = GetDefaultFormat(mem->get_primitive_desc().desc());
+  if (format == mem->get_primitive_desc().desc().data.format)
+    return mem;
+
+  auto pd = mem->get_primitive_desc();
+  mkldnn::memory::dims dims(pd.desc().data.ndims);
+  for (size_t i = 0; i < dims.size(); i++)
+    dims[i] = pd.desc().data.dims[i];
+  mkldnn::memory::format cpp_format = static_cast<mkldnn::memory::format>(format);
+  mkldnn::memory::data_type cpp_type = static_cast<mkldnn::memory::data_type>(
+      pd.desc().data.data_type);
+  mkldnn::memory::desc data_md(dims, cpp_type, cpp_format);
+  mkldnn_mem_ptr def_mem(new mkldnn::memory(mkldnn::memory::primitive_desc(data_md,
+          pd.get_engine())));
+
+  MKLDNNStream &stream = MKLDNNStream::Instance();
+  stream.RegisterMem(mem);
+  stream.RegisterMem(def_mem);
+  stream.RegisterPrim(mkldnn::reorder(*mem, *def_mem));
+  if (submit_now)
+    stream.Submit();
+  return def_mem;
+}
+
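+// Reshape that keeps the result in MKLDNN storage. The underlying memory is
+// converted to the default layout, but the reorder is only registered here;
+// the operator that consumes the result submits the stream.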
+NDArray NDArray::ReshapeMKLDNN(const TShape &shape) const {
+  CHECK(!is_none()) << "NDArray is not initialized";
+  CHECK_GE(shape_.Size(), shape.Size())
+    << "NDArray.Reshape: target shape size is larger current shape";
+  if (storage_type() == kDefaultStorage) {
+    NDArray ret = this->Detach();
+    ret.shape_ = shape;
+    return ret;
+  } else if (storage_type() == kMKLDNNStorage) {
+    NDArray ret(kMKLDNNStorage, shape, ctx(), ptr_->delay_alloc, dtype());
+    CHECK(ptr_->Mkl_mem_ != nullptr);
+    // We shouldn't submit the reorder primitive here because submit will
+    // be called in operators.
+    ret.ptr_->Mkl_mem_ = Reorder2Default(ptr_->Mkl_mem_, false);
+    return ret;
+  }
+  LOG(FATAL) << "Reshape for storage type " << storage_type() << " is not implemented yet";
+  return NDArray();
+}
+
+#endif
+
 NDArray NDArray::Reshape(const TShape &shape) const {
   CHECK(!is_none()) << "NDArray is not initialized";
-  auto stype = storage_type();
-  // reshape is not supported for non-default ndarray with dismatching shapes
-  CHECK((shape_ == shape) || stype == kDefaultStorage)
-    << "Reshape for storage type " << stype << " is not implemented yet";
   CHECK_GE(shape_.Size(), shape.Size())
     << "NDArray.Reshape: target shape size is larger current shape";
-  NDArray ret = this->Detach();
-  ret.shape_ = shape;
-  return ret;
+  if (storage_type() == kDefaultStorage) {
+    NDArray ret = this->Detach();
+    ret.shape_ = shape;
+    return ret;
+#if MXNET_USE_MKLDNN == 1
+  } else if (storage_type() == kMKLDNNStorage) {
+    NDArray ret(kMKLDNNStorage, shape, ctx(), ptr_->delay_alloc, dtype());
+    // We need to convert the MKL memory to the default layout.
+    Engine::Get()->PushSync([&](RunContext ctx) {
+        if (this->ptr_->Mkl_mem_) {
+          auto def_format = GetDefaultFormat(this->ptr_->Mkl_mem_->get_primitive_desc().desc());
+          if (this->ptr_->Mkl_mem_->get_primitive_desc().desc().data.format != def_format) {
+            ret.ptr_->Mkl_mem_ = Reorder2Default(this->ptr_->Mkl_mem_);
+          } else {
+            ret.ptr_->Mkl_mem_ = this->ptr_->Mkl_mem_;
+          }
+        }
+    }, ctx(), {this->var()}, {ret.var()},
+    FnProperty::kNormal, 0, PROFILER_MESSAGE("SyncMKLDNN2Default"));
+    ret.WaitToRead();
 
 Review comment:
   I tried not using WaitToRead and saw some non-deterministic memory errors. I'm not entirely sure how the async exec engine works, so I wanted to ask: what happens if WaitToRead() isn't called here?
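
   For context, the symptom fits a lifetime hazard rather than an engine bug: the lambda passed to PushSync captures `ret` (a stack local) with [&], so if the engine runs the callback asynchronously after Reshape has already returned, the callback writes through a dangling reference. WaitToRead() blocks until the engine has finished the write, keeping the stack frame alive. Below is a minimal sketch of that failure mode, with a toy worker thread standing in for the real async exec engine; ToyEngine and MakeArray are illustrative names, not MXNet's API.

// Toy illustration (not MXNet code) of an async callback capturing a
// stack variable by reference, as the [&] lambda above does with `ret`.
#include <chrono>
#include <functional>
#include <iostream>
#include <thread>
#include <vector>

struct ToyEngine {
  std::thread worker;
  // Stand-in for Engine::Get()->PushSync: may run fn on another thread.
  void PushSync(std::function<void()> fn) {
    worker = std::thread(std::move(fn));
  }
  ~ToyEngine() { if (worker.joinable()) worker.join(); }
};

ToyEngine engine;

std::vector<int> MakeArray(bool wait) {
  std::vector<int> ret(16, 0);  // plays the role of the local NDArray `ret`
  engine.PushSync([&]() {
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
    ret[0] = 42;  // writes through a reference into MakeArray's stack frame
  });
  if (wait)
    engine.worker.join();  // analogue of ret.WaitToRead()
  // Without the join, `ret` is moved/destroyed on return while the callback
  // may still be pending: a use-after-free whose visibility depends on
  // thread timing, hence the non-deterministic memory errors.
  return ret;
}

int main() {
  std::cout << MakeArray(true)[0] << std::endl;  // safe: prints 42
  // MakeArray(false) would race with the worker thread: undefined behavior.
}

   If that reading is right, WaitToRead() is doing exactly this join. Capturing `ret` by value could be an alternative, since NDArray holds its chunk through a shared pointer, but that is speculation about the engine's contract rather than a verified fix.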

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
