zheng-da commented on a change in pull request #8302: Refactor operators & 
MKLDNN
URL: https://github.com/apache/incubator-mxnet/pull/8302#discussion_r162502678
 
 

 ##########
 File path: src/ndarray/ndarray.cc
 ##########
 @@ -181,6 +322,408 @@ void NDArray::set_fresh_out_grad(bool state) const {
   info.fresh_out_grad = state;
 }
 
+#if MXNET_USE_MKLDNN == 1
+static inline bool same_shape(const TShape &shape, mkldnn_dims_t dims, int 
ndims) {
+  if (shape.ndim() != (size_t)ndims)
+    return false;
+  for (int i = 0; i < ndims; i++)
+    if (shape[i] != dims[i])
+      return false;
+  return true;
+}
+
+static inline bool same_shape(const TShape &shape, int dtype, 
mkldnn::memory::desc desc) {
+  return same_shape(shape, desc.data.dims, desc.data.ndims)
+      && get_mkldnn_type(dtype) == desc.data.data_type;
+}
+
+bool NDArray::Chunk::IsMKLDNN() const {
+  if (storage_type != kDefaultStorage)
+    return false;
+  if (mkl_mem_ == nullptr)
+    return false;
+  auto desc = mkl_mem_->get_primitive_desc().desc();
+  return desc.data.format != GetDefaultFormat(desc);
+}
+
+bool NDArray::Chunk::IsDefault() const {
+  if (storage_type != kDefaultStorage)
+    return false;
+  // If we don't have mkldnn memory yet, we just assume it's the default
+  // format.
+  if (mkl_mem_ == nullptr)
+    return true;
+  auto desc = mkl_mem_->get_primitive_desc().desc();
+  return desc.data.format == GetDefaultFormat(desc);
+}
+
+void NDArray::Chunk::Reorder2Default() {
+  if (mkl_mem_ == nullptr)
+    return;
+
+  auto format = GetDefaultFormat(mkl_mem_->get_primitive_desc().desc());
+  CHECK(format != mkl_mem_->get_primitive_desc().desc().data.format);
+
+  auto def_pd = GetPrimitiveDesc(mkl_mem_->get_primitive_desc(), format);
+  mkldnn_mem_ptr def_mem(new mkldnn::memory(def_pd));
+  // This may be called in MKLDNN operators. We can't use MKLDNNStream here.
+  std::vector<mkldnn::primitive> net;
+  net.push_back(mkldnn::reorder(*mkl_mem_, *def_mem));
+  mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
+
+  CHECK(shandle.size >= def_pd.get_size());
+  CheckAndAlloc(def_pd.get_size());
+  // TODO(zhengda) We need to avoid memory copy here.
+  memcpy(shandle.dptr, def_mem->get_data_handle(), def_pd.get_size());
+  mkl_mem_.reset(new mkldnn::memory(def_pd, shandle.dptr));
+}
+
+void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) {
+  // The shape of the array and the one of the MKL memory may mismatch.
+  // For example, if the array stores parameters, the MKL memory may store data
+  // in 5 dimensions while the NDArray stores data in 4 dimensions.
+  if (mkl_mem_ && mkl_mem_->get_data_handle() == shandle.dptr
+      && same_shape(shape, dtype, mkl_mem_->get_primitive_desc().desc())) {
+    return;
+  }
+
+  mkldnn::memory::dims dims;
+  // These are shapes supported by MKLDNN.
+  if (shape.ndim() == 1 || shape.ndim() == 2 || shape.ndim() == 4
+      || shape.ndim() == 5) {
+    dims.resize(shape.ndim());
+    for (size_t i = 0; i < dims.size(); i++)
+      dims[i] = shape[i];
+  } else if (shape.ndim() == 3) {
+    // If there are 3 dimensions, we'll force it to 4 dimensions.
+    dims.resize(shape.ndim() + 1);
+    dims[0] = 1;
+    for (size_t i = 0; i < shape.ndim(); i++)
+      dims[i + 1] = shape[i];
+  } else {
+    LOG(FATAL) << "MKLDNN doesn't support " << shape.ndim() << " dimensions";
+  }
+  mkldnn::memory::format layout = mkldnn::memory::format::format_undef;
+  switch (dims.size()) {
+    case 1: layout = mkldnn::memory::format::x; break;
+    case 2: layout = mkldnn::memory::format::nc; break;
+    case 4: layout = mkldnn::memory::format::nchw; break;
+    // This isn't the right layout when the data has 5 dimensions in MXNet.
+    // MXNet interprets 5 dimensions as ncdhw, but MKLDNN doesn't have
+    // a corresponding format.
+    case 5: layout = mkldnn::memory::format::goihw; break;
+  }
+  mkldnn::memory::desc data_md{dims, get_mkldnn_type(dtype), layout};
+  auto cpu_engine = CpuEngine::Get()->get_engine();
+  if (shandle.dptr == nullptr) {
+    CHECK(delay_alloc);
+    CheckAndAlloc();
+  }
+  mkldnn::memory::primitive_desc pd(data_md, cpu_engine);
+  CHECK(shandle.size >= pd.get_size());
+  mkl_mem_.reset(new mkldnn::memory(pd, shandle.dptr));
+}
+
+/*
+ * Here we want to get MKLDNN memory whose primitive desc is exactly the same 
as
+ * the given one. operator== can't guarantee that: it can return true even if
+ * the formats are different, so the format is compared explicitly as well.
+ */
+static inline mkldnn::memory *GetMKLDNNExact(
+    const mkldnn::memory *mem, mkldnn::memory::primitive_desc desc) {
+  auto src_desc = mem->get_primitive_desc();
+  if (desc == src_desc && desc.desc().data.format == 
src_desc.desc().data.format) {
+    return const_cast<mkldnn::memory *>(mem);
+  } else {
+    std::shared_ptr<mkldnn::memory> ret(new mkldnn::memory(
+            desc, mem->get_data_handle()));
+    MKLDNNStream::Get()->RegisterMem(ret);
+    return ret.get();
+  }
+}
+
+const mkldnn::memory *NDArray::GetMKLDNNData(
+    const mkldnn::memory::primitive_desc &desc) const {
+  if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
+    LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN 
memory desc";
+    return nullptr;
+  }
+  auto mem = GetMKLDNNData();
+  mkldnn::memory::primitive_desc _desc = desc;
+  auto desc1 = mem->get_primitive_desc().desc();
 
 Review comment:
   they are just some mkldnn::memory::desc.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to