This is an automated email from the ASF dual-hosted git repository.

patriczhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 2c5d7f7  Optimize transpose operator with MKL-DNN (#14545)
2c5d7f7 is described below

commit 2c5d7f768bdd1599c35f1a3cd1266efd051a9986
Author: Tao Lv <tao.a...@intel.com>
AuthorDate: Thu Apr 11 06:41:32 2019 +0800

    Optimize transpose operator with MKL-DNN (#14545)
    
    * add mkldnn transpose
    
    * general transpose
    
    * support mkldnn format
    
    * fix lint
    
    * address comments
    
    * add unit test
    
    * add comments
    
    * retrigger CI
---
 src/operator/nn/mkldnn/mkldnn_base-inl.h   |   2 +
 src/operator/nn/mkldnn/mkldnn_ops-inl.h    |   6 ++
 src/operator/nn/mkldnn/mkldnn_transpose.cc | 161 +++++++++++++++++++++++++++++
 src/operator/tensor/matrix_op-inl.h        |  15 +++
 src/operator/tensor/matrix_op.cc           |  34 ++++++
 tests/python/mkl/test_mkldnn.py            |  15 +++
 6 files changed, 233 insertions(+)

diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 0a89c0f..a460e33 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -175,12 +175,14 @@ struct ConvolutionParam;
 struct DeconvolutionParam;
 struct SoftmaxParam;
 struct SoftmaxOutputParam;
+struct TransposeParam;
 bool SupportMKLDNNAct(const ActivationParam& param);
 bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input);
 bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input);
 bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray 
&input);
 bool SupportMKLDNNSoftmax(const SoftmaxParam& param);
 bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
+bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
 }  // namespace op
 
 static int GetTypeSize(int dtype) {
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index 39f2632..f3f61b4 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -113,6 +113,12 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
 void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
          const mkldnn::memory &out);
 
+void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,
+                            const OpContext &ctx,
+                            const NDArray &data,
+                            const OpReqType &req,
+                            const NDArray &output);
+
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/nn/mkldnn/mkldnn_transpose.cc b/src/operator/nn/mkldnn/mkldnn_transpose.cc
new file mode 100644
index 0000000..0986d06
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_transpose.cc
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_transpose.cc
+ * \brief Implement transpose operator via MKL-DNN reorder primitive
+ * \author Tao Lv
+*/
+
+#if MXNET_USE_MKLDNN == 1
+
+#include <mkldnn.hpp>
+#include "../../tensor/matrix_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Decide whether the MKL-DNN reorder-based transpose can handle `data`.
+ *
+ * Only float32 inputs with 1 to 4 dimensions and at least one element are
+ * accepted; the caller uses the default CPU kernel for everything else.
+ * `param` is currently unused: any permutation of the axes is supported.
+ */
+bool SupportMKLDNNTranspose(const TransposeParam& param,
+                            const NDArray &data) {
+  auto data_ndim = data.shape().ndim();
+
+  // Scalars (ndim 0) and empty tensors are left to the fallback path: the
+  // destination descriptor is built per dimension, and an empty tensor has
+  // nothing to reorder.
+  if (data_ndim == 0 || data_ndim > 4 || data.shape().Size() == 0 ||
+      data.dtype() != mshadow::kFloat32)
+    return false;
+
+  return true;
+}
+
+typedef ParamOpSign<TransposeParam> MKLDNNTransposeSignature;
+
+/*!
+ * \brief Cached MKL-DNN reorder primitive that implements transpose.
+ *
+ * The reorder reads the input in its own layout and writes into a buffer
+ * whose strides make the physical element order that of the transposed
+ * array. Concrete buffers are bound on every call through SetNewMem().
+ */
+class MKLDNNTransposeForward {
+  std::shared_ptr<mkldnn::memory> data_;
+  std::shared_ptr<mkldnn::memory> out_;
+  std::shared_ptr<mkldnn::memory::primitive_desc> dst_pd_;
+  std::shared_ptr<mkldnn::reorder> transpose_;
+
+ public:
+  MKLDNNTransposeForward(const TransposeParam& param,
+                         const NDArray &data) {
+    auto shape = data.shape();
+    auto data_ndim = shape.ndim();
+    auto axes_ndim = param.axes.ndim();
+    auto axes = mxnet::TShape(data_ndim);
+    if (axes_ndim == 0) {
+      // No axes given: reverse the dimension order.
+      for (size_t i = 0; i < data_ndim; i++) {
+        axes[i] = data_ndim - i - 1;
+      }
+    } else {
+      axes = param.axes;
+    }
+
+    auto engine = CpuEngine::Get()->get_engine();
+    auto in_mem = data.GetMKLDNNData();
+    auto src_pd = in_mem->get_primitive_desc();
+    // Placeholder handle; the actual input buffer is bound in SetNewMem().
+    data_ = std::make_shared<mkldnn::memory>(src_pd, nullptr);
+
+    // Destination: not every permuted layout has a named format in MKL-DNN.
+    // For example, transpose(NCHW, (0, 2, 1, 3)) -> NHCW is not predefined,
+    // so to support general transposing we build the blocked descriptor by
+    // hand. Value-initialize it so that every field not assigned below is
+    // zero rather than indeterminate.
+    mkldnn_memory_desc_t dst_fmt = mkldnn_memory_desc_t();
+    dst_fmt.primitive_kind = mkldnn_memory;
+    dst_fmt.ndims = data_ndim;
+    dst_fmt.data_type = mkldnn_f32;
+    dst_fmt.format = mkldnn_blocked;
+
+    // Logical dims stay those of the input.
+    for (size_t i = 0; i < data_ndim; i++)
+      dst_fmt.dims[i] = shape[i];
+
+    // Strides are assigned so that input axis axes[ndim - 1] varies fastest
+    // and axes[0] slowest; a dense walk of the output buffer therefore
+    // visits elements in transposed order.
+    unsigned int total_stride = 1;
+    for (int i = data_ndim - 1; i >= 0; i--) {
+      dst_fmt.layout_desc.blocking.padding_dims[i] = shape[i];
+      dst_fmt.layout_desc.blocking.block_dims[i] = 1;
+      dst_fmt.layout_desc.blocking.offset_padding_to_data[i] = 0;
+      // strides[0]: stride between the first elements of adjacent blocks.
+      dst_fmt.layout_desc.blocking.strides[0][axes[i]] = total_stride;
+      // strides[1]: strides between elements in the same block.
+      dst_fmt.layout_desc.blocking.strides[1][axes[i]] = 1;
+
+      total_stride *= shape[axes[i]];
+    }
+
+    dst_fmt.layout_desc.blocking.offset_padding = 0;
+    dst_pd_ = std::make_shared<mkldnn::memory::primitive_desc>(dst_fmt, engine);
+    out_ = std::make_shared<mkldnn::memory>(*dst_pd_, nullptr);
+
+    transpose_ = std::make_shared<mkldnn::reorder>(*data_, *out_);
+  }
+
+  /*!
+   * \brief Bind this call's buffers to the cached memory objects.
+   *
+   * An MKL-DNN-formatted input is read through its MKL-DNN memory handle;
+   * otherwise its raw data pointer is used. The output must not be an
+   * MKL-DNN-formatted array.
+   */
+  void SetNewMem(const NDArray &data, const NDArray &output) {
+    if (data.IsMKLDNNData()) {
+      this->data_->set_data_handle(data.GetMKLDNNData()->get_data_handle());
+    } else {
+      MSHADOW_TYPE_SWITCH(data.dtype(), DTYPE, {
+        this->data_->set_data_handle(data.data().dptr<DTYPE>());
+      });
+    }
+
+    CHECK(!output.IsMKLDNNData());
+    MSHADOW_TYPE_SWITCH(output.dtype(), DTYPE, {
+      this->out_->set_data_handle(output.data().dptr<DTYPE>());
+    });
+  }
+
+  /*! \brief The reorder primitive that performs the transpose. */
+  const mkldnn::reorder &GetFwd() const {
+    return *transpose_;
+  }
+};
+
+/*!
+ * \brief Return the transpose primitive for (param, data), creating it on a miss.
+ *
+ * Primitives live in a thread-local map keyed on the parameter plus the
+ * signature of the input array; the returned reference is owned by that map.
+ */
+static MKLDNNTransposeForward &GetTransposeForward(const TransposeParam& param,
+                                                   const NDArray &data) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNTransposeSignature,
+                                         MKLDNNTransposeForward, OpHash> fwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNTransposeSignature,
+                                            MKLDNNTransposeForward, OpHash> fwds;
+#endif
+  MKLDNNTransposeSignature key(param);
+  key.AddSign(data);
+
+  auto cached = fwds.find(key);
+  if (cached != fwds.end())
+    return cached->second;
+
+  MKLDNNTransposeForward fwd(param, data);
+  return AddToCache(&fwds, key, fwd)->second;
+}
+
+/*!
+ * \brief Transpose `data` into `output` with a cached MKL-DNN reorder.
+ *
+ * `req` is not consulted here: the output buffer is always overwritten.
+ */
+void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,
+                            const OpContext &ctx,
+                            const NDArray &data,
+                            const OpReqType &req,
+                            const NDArray &output) {
+  const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
+
+  auto stream = MKLDNNStream::Get();
+  // Bind by reference to the cached entry; copying it would bump four
+  // shared_ptr reference counts on every call for no benefit.
+  auto &fwd = GetTransposeForward(param, data);
+
+  fwd.SetNewMem(data, output);
+  stream->RegisterPrim(fwd.GetFwd());
+  stream->Submit();
+}
+}  // namespace op
+}  // namespace mxnet
+#endif
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 5eecda6..fa10815 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -238,6 +238,10 @@ struct TransposeParam : public dmlc::Parameter<TransposeParam> {
     DMLC_DECLARE_FIELD(axes).set_default(mxnet::TShape())
     .describe("Target axis order. By default the axes will be inverted.");
   }
+
+  /*!
+   * \brief Two parameter sets are equal when their target axis orders match.
+   *        (Presumably required by ParamOpSign for the MKL-DNN primitive cache.)
+   */
+  bool operator==(const TransposeParam &other) const {
+    return axes == other.axes;
+  }
 };
 
 template<typename xpu>
@@ -2841,4 +2845,15 @@ inline uint32_t SplitNumOutputs(const NodeAttrs& attrs) {
 }  // namespace op
 }  // namespace mxnet
 
+namespace std {
+/*!
+ * \brief Hash for TransposeParam, combining only `axes`.
+ *
+ * `axes` is the only field compared by TransposeParam::operator==, so equal
+ * parameters always hash equally. The call operator is const, as the C++ Hash
+ * requirements demand of std::hash specializations.
+ */
+template<>
+struct hash<mxnet::op::TransposeParam> {
+  size_t operator()(const mxnet::op::TransposeParam& val) const {
+    size_t ret = 0;
+    ret = dmlc::HashCombine(ret, val.axes);
+    return ret;
+  }
+};
+}  // namespace std
+
 #endif  // MXNET_OPERATOR_TENSOR_MATRIX_OP_INL_H_
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index 3bca330..1431fef 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -339,6 +339,35 @@ Example::
   })
 .add_argument("data", "NDArray-or-Symbol", "Input array.");
 
+#if MXNET_USE_MKLDNN == 1
+/*!
+ * \brief FComputeEx<cpu> for transpose: use MKL-DNN when supported, else fall back.
+ *
+ * A kNullOp request is a no-op; any other request except kWriteTo is rejected.
+ */
+static void TransposeComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                  const OpContext& ctx,
+                                  const std::vector<NDArray>& inputs,
+                                  const std::vector<OpReqType>& req,
+                                  const std::vector<NDArray>& outputs) {
+  // Validate arity before indexing into req/inputs/outputs.
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  if (req[0] == kNullOp) return;
+  CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace";
+  const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
+
+  if (SupportMKLDNNTranspose(param, inputs[0])) {
+    MKLDNNTransposeForward(attrs, ctx, inputs[0], req[0], outputs[0]);
+    return;
+  }
+  FallBackCompute(Transpose<cpu>, attrs, ctx, inputs, req, outputs);
+}
+
+/*!
+ * \brief Storage-type inference for transpose in MKL-DNN builds.
+ *
+ * Expects exactly one input and one output, then delegates to
+ * MKLDNNStorageType with its MKL-DNN-support flag set to true.
+ */
+inline static bool TransposeStorageType(const nnvm::NodeAttrs& attrs,
+                                        const int dev_mask,
+                                        DispatchMode* dispatch_mode,
+                                        std::vector<int>* in_attrs,
+                                        std::vector<int>* out_attrs) {
+  const bool support_mkldnn = true;
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  return MKLDNNStorageType(attrs, dev_mask, support_mkldnn, dispatch_mode,
+                           in_attrs, out_attrs);
+}
+#endif
+
 NNVM_REGISTER_OP(transpose)
 .describe(R"code(Permutes the dimensions of an array.
 
@@ -393,6 +422,11 @@ Examples::
     }
   })
 .set_attr<FCompute>("FCompute<cpu>", Transpose<cpu>)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FComputeEx>("FComputeEx<cpu>", TransposeComputeExCPU)
+.set_attr<FInferStorageType>("FInferStorageType", TransposeStorageType)
+#endif
 .add_argument("data", "NDArray-or-Symbol", "Source input")
 .add_arguments(TransposeParam::__FIELDS__());
 
diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py
index 01ba03c..0610b60 100644
--- a/tests/python/mkl/test_mkldnn.py
+++ b/tests/python/mkl/test_mkldnn.py
@@ -473,6 +473,21 @@ def test_non_mkldnn_fcomputeex():
     exec1 = custom.bind(mx.cpu(), args={'data': mx.nd.ones([10,3,96,96]), 
'conv_weight': mx.nd.ones([8,3,5,5])})
     exec1.forward()[0].wait_to_read()
 
+@with_seed()
+def test_conv_transpose():
+    axes = [(0,2,1,3), (0,2,3,1), (1,2,3,0), (3,2,1,0)]
+    a = np.random.rand(10, 16, 50, 50)
+    b = np.random.rand(32, 16, 3, 3)
+    x = mx.nd.array(a)
+    w = mx.nd.array(b)
+    y = mx.nd.Convolution(data=x, weight=w, kernel=(3, 3), num_group=1, 
num_filter=32, no_bias=True)
+    for axis in axes:
+        t = mx.nd.transpose(y, axis)
+        t.wait_to_read()
+        s = y.asnumpy()
+        n = np.transpose(s, axis)
+        np.allclose(t.asnumpy(), n)
+
 
 if __name__ == '__main__':
     install.test_mkldnn_install()

Reply via email to