This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 37bdf0b  [MXNET-1453] Support the input whose dimension is greater 
than 6 for Transpose and Rollaxis (#18707)
37bdf0b is described below

commit 37bdf0bf981d11a89bd248b02f473211d57bc9c6
Author: JackieWu <w...@live.cn>
AuthorDate: Fri Jul 17 01:25:01 2020 +0800

    [MXNET-1453] Support the input whose dimension is greater than 6 for 
Transpose and Rollaxis (#18707)
    
    * support 6+ dims for transpose
    
    * test over
    
    * reorder code
    
    * fix transposeex
---
 src/operator/numpy/np_matrix_op-inl.h  |  51 ++++++++----
 src/operator/numpy/np_matrix_op.cc     |  17 +++-
 src/operator/tensor/matrix_op-inl.h    | 138 +++++++++++++++++++++++++++++++--
 src/operator/tensor/matrix_op.cc       |   4 +
 tests/python/unittest/test_numpy_op.py |   8 +-
 tests/python/unittest/test_operator.py |   4 +-
 6 files changed, 191 insertions(+), 31 deletions(-)

diff --git a/src/operator/numpy/np_matrix_op-inl.h 
b/src/operator/numpy/np_matrix_op-inl.h
index 0125feb..0fea76b 100644
--- a/src/operator/numpy/np_matrix_op-inl.h
+++ b/src/operator/numpy/np_matrix_op-inl.h
@@ -134,10 +134,10 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs,
                     const std::vector<TBlob>& inputs,
                     const std::vector<OpReqType>& req,
                     const std::vector<TBlob>& outputs) {
-  const NumpyTransposeParam& param = 
nnvm::get<NumpyTransposeParam>(attrs.parsed);
   if (req[0] == kNullOp) return;
   CHECK(req[0] == kWriteTo || req[0] == kAddTo)
-      << "Transpose only supports kWriteTo, kNullOp and kAddTo";
+      << "Transpose does not support inplace";
+  const NumpyTransposeParam& param = 
nnvm::get<NumpyTransposeParam>(attrs.parsed);
   mxnet::TShape axes;
   if (ndim_is_known(param.axes)) {
     axes = common::CanonicalizeAxes(param.axes);
@@ -147,10 +147,14 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs,
       axes[i] = axes.ndim() - 1 - i;
     }
   }
+  mshadow::Tensor<xpu, 1, dim_t> workspace =
+    GetTransposeExWorkspace<xpu>(ctx, axes);
   if (req[0] == kAddTo) {
-    TransposeImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0], axes);
+    TransposeExImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0],
+        axes, workspace);
   } else {
-    TransposeImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0], axes);
+    TransposeExImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0],
+        axes, workspace);
   }
 }
 
@@ -779,13 +783,21 @@ void NumpyRollaxisCompute(const nnvm::NodeAttrs& attrs,
   using namespace mshadow::expr;
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
-  CHECK_EQ(req[0], kWriteTo) << "Rollaxis does not support inplace";
-  mxnet::TShape axes;
+  if (req[0] == kNullOp) return;
+  CHECK(req[0] == kWriteTo || req[0] == kAddTo)
+      << "Rollaxis does not support inplace";
   const NumpyRollaxisParam& param = 
nnvm::get<NumpyRollaxisParam>(attrs.parsed);
-  axes = NumpyRollaxisShapeImpl(param.axis, param.start, inputs[0].ndim());
-  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, {
-    TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
-  })
+  mxnet::TShape axes = NumpyRollaxisShapeImpl(param.axis, param.start, 
inputs[0].ndim());
+
+  mshadow::Tensor<xpu, 1, dim_t> workspace =
+    GetTransposeExWorkspace<xpu>(ctx, axes);
+  if (req[0] == kAddTo) {
+    TransposeExImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0],
+        axes, workspace);
+  } else {
+    TransposeExImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0],
+        axes, workspace);
+  }
 }
 
 template<typename xpu>
@@ -796,6 +808,9 @@ void NumpyRollaxisBackward(const nnvm::NodeAttrs &attrs,
                             const std::vector<TBlob> &outputs) {
   using namespace mshadow;
   using namespace mshadow::expr;
+  if (req[0] == kNullOp) return;
+  CHECK(req[0] == kWriteTo || req[0] == kAddTo)
+      << "Rollaxis Backward does not support inplace";
   const NumpyRollaxisParam& param = 
nnvm::get<NumpyRollaxisParam>(attrs.parsed);
   int axis_origin = param.axis;
   int start_origin = param.start;
@@ -819,11 +834,17 @@ void NumpyRollaxisBackward(const nnvm::NodeAttrs &attrs,
     axis = start_origin;
     start = axis_origin + 1;
   }
-  mxnet::TShape axes;
-  axes = NumpyRollaxisShapeImpl(axis, start, inputs[0].ndim());
-  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, {
-    TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
-  })
+  mxnet::TShape axes = NumpyRollaxisShapeImpl(axis, start, inputs[0].ndim());
+
+  mshadow::Tensor<xpu, 1, dim_t> workspace =
+    GetTransposeExWorkspace<xpu>(ctx, axes);
+  if (req[0] == kAddTo) {
+    TransposeExImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0],
+        axes, workspace);
+  } else {
+    TransposeExImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0],
+        axes, workspace);
+  }
 }
 
 struct NumpyRot90Param : public dmlc::Parameter<NumpyRot90Param> {
diff --git a/src/operator/numpy/np_matrix_op.cc 
b/src/operator/numpy/np_matrix_op.cc
index da9839f..2bb2fe3 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -51,7 +51,6 @@ bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_attrs->size(), 1U);
   mxnet::TShape& shp = (*in_attrs)[0];
   mxnet::TShape& out_shp = (*out_attrs)[0];
-  CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
 
   int ndim = -1;
   if (ndim_is_known(shp)) {
@@ -133,6 +132,10 @@ NNVM_REGISTER_OP(_npi_transpose)
     }
   })
 .set_attr<FCompute>("FCompute<cpu>", NumpyTranspose<cpu>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
 .set_attr<nnvm::FListInputNames>("FListInputNames",
   [](const NodeAttrs& attrs) {
     return std::vector<std::string>{"a"};
@@ -1261,7 +1264,6 @@ bool NumpyRollaxisShape(const nnvm::NodeAttrs& attrs,
 
   // check transpose dimentions no more than 6
   mxnet::TShape& shp = (*in_attrs)[0];
-  CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
 
   // check axis and start range
   CHECK_GE(param.axis, -shp.ndim())
@@ -1304,6 +1306,10 @@ until it lies in a given position.)code" ADD_FILELINE)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<FCompute>("FCompute<cpu>", NumpyRollaxisCompute<cpu>)
 .set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseNone{"_npi_rollaxis_backward"})
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
 .add_argument("data", "NDArray-or-Symbol", "Input ndarray")
 .add_arguments(NumpyRollaxisParam::__FIELDS__());
 
@@ -1312,7 +1318,11 @@ NNVM_REGISTER_OP(_npi_rollaxis_backward)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<NumpyRollaxisParam>)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<FCompute>("FCompute<cpu>", NumpyRollaxisBackward<cpu>);
+.set_attr<FCompute>("FCompute<cpu>", NumpyRollaxisBackward<cpu>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  });
 
 template<>
 void NumpyFlipForwardImpl<cpu>(const OpContext& ctx,
@@ -1368,7 +1378,6 @@ bool NumpyMoveaxisShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 1U);
   mxnet::TShape& shp = (*in_attrs)[0];
-  CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
   CHECK_EQ(param.source.ndim(), param.destination.ndim())
     << "source and destination not equal.";
   mxnet::TShape ret(shp.ndim(), -1);
diff --git a/src/operator/tensor/matrix_op-inl.h 
b/src/operator/tensor/matrix_op-inl.h
index 79c2a6d..6c125ed 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -321,15 +321,16 @@ inline bool IsIdentityTranspose(const TShape& axes) {
 }
 
 template<typename xpu, bool is_addto = false>
-void TransposeImpl(RunContext ctx,
+bool TransposeCommonImpl(RunContext ctx,
                    const TBlob& src,
                    const TBlob& ret,
                    const mxnet::TShape& axes) {
+  // return true when running successfully, otherwise false
   using namespace mshadow;
   using namespace mshadow::expr;
   CHECK_EQ(src.type_flag_, ret.type_flag_);
   // zero-size tensor, no need to compute
-  if (src.shape_.Size() == 0U) return;
+  if (src.shape_.Size() == 0U) return true;
   Stream<xpu> *s = ctx.get_stream<xpu>();
 #ifdef __CUDACC__
   // This transpose can be used only if there exist n and m such that:
@@ -339,7 +340,7 @@ void TransposeImpl(RunContext ctx,
     MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
       transpose_pseudo2D<DType, is_addto>(ret, src, axes, s);
     });
-    return;
+    return true;
   }
 #endif
   // Special handle the identity case
@@ -355,7 +356,7 @@ void TransposeImpl(RunContext ctx,
             s, ret.Size(), out.dptr_, in.dptr_);
       }
     });
-    return;
+    return true;
   }
   // Handle the general transpose case
   MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
@@ -413,10 +414,127 @@ void TransposeImpl(RunContext ctx,
       break;
      }
      default:
-      LOG(FATAL) << "Transpose support at most 6 dimensions";
+      // return false when dimensions > 6
+      return false;
       break;
     }
   });
+  return true;
+}
+
+template<typename xpu, bool is_addto = false>
+void TransposeImpl(RunContext ctx,
+                   const TBlob& src,
+                   const TBlob& ret,
+                   const mxnet::TShape& axes) {
+  CHECK_LE(axes.ndim(), 6) << "TransposeImpl supports at most 6 dimensions";
+  CHECK((TransposeCommonImpl<xpu, is_addto>(ctx, src, ret, axes))) <<
+    "Failed to execute TransposeImpl Operator";
+}
+
+template <bool is_addto>
+struct TransposeExKernel {
+  /*!
+   * \brief
+   * \param tid      global thread id
+   * \param out_data output data
+   * \param in_data  input data
+   * \param strides  input strides and output strides
+   * \param ndim     the number of dimension
+   */
+  template <typename DType>
+  MSHADOW_XINLINE static void Map(int tid,
+      DType *out_data,
+      const DType *in_data,
+      const dim_t *strides,
+      const int ndim
+      ) {
+    // tid is the index of input data
+    const dim_t* const out_strides = strides + ndim;
+    int k = tid;
+    int out_id = 0;
+    for (int i = 0; i < ndim; ++i) {
+      out_id += (k / strides[i]) * out_strides[i];
+      k %= strides[i];
+    }
+    if (is_addto)
+      out_data[out_id] += in_data[tid];
+    else
+      out_data[out_id] = in_data[tid];
+  }
+};
+
+template<typename xpu, bool is_addto = false>
+void TransposeExImpl(RunContext ctx,
+                   const TBlob& src,
+                   const TBlob& ret,
+                   const mxnet::TShape& axes,
+                   mshadow::Tensor<xpu, 1, dim_t>& strides_xpu
+                   ) {
+  /*
+   * If ndim <= 6, it is not necessary to allocate any space for `strides_xpu`
+   * If ndim > 6, `strides_xpu` should be allocated `ndim * 2` elements
+   */
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  if (TransposeCommonImpl<xpu, is_addto>(ctx, src, ret, axes)) return;
+  CHECK_GT(axes.ndim(), 6) <<
+    "Failed to execute TransposeExImpl when axes.ndim() <= 6";
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
+      CHECK_EQ(strides_xpu.MSize(), axes.ndim() * 2) << \
+        "If ndim > 6, `strides_xpu` should be allocated `ndim * 2` elements";
+
+      const mxnet::TShape &in_shape = src.shape_;
+      // strides: in_strides and out_strides
+      const int ndim = axes.ndim();
+      std::vector<dim_t> strides(ndim * 2);
+      // compute in_strides
+      strides[ndim - 1] = 1;
+      for (int i = ndim - 2; i >= 0; --i) {
+        strides[i] = strides[i + 1] * in_shape[i + 1];
+      }
+      // compute out_strides
+      std::vector<dim_t> tmp_strides(ndim);
+      tmp_strides[ndim - 1] = 1;
+      for (int i = ndim - 2; i >= 0; --i) {
+        tmp_strides[i] = tmp_strides[i + 1] * in_shape[axes[i + 1]];
+      }
+      // reorder tmp_strides to out_strides
+      dim_t * const out_strides = &strides[ndim];
+      for (int i = 0; i < ndim; ++i) {
+        out_strides[axes[i]] = tmp_strides[i];
+      }
+      Shape<1> strides_shape;
+      strides_shape[0] = ndim * 2;
+      Tensor<cpu, 1, dim_t> strides_cpu(strides.data(), strides_shape);
+      // copy arguments into xpu context
+      Copy(strides_xpu, strides_cpu, s);
+      const DType *in = src.dptr<DType>();
+      DType *out = ret.dptr<DType>();
+      if (is_addto) {
+        mxnet_op::Kernel<TransposeExKernel<true>, xpu>::Launch(s,
+            in_shape.Size(), out, in, strides_xpu.dptr_, ndim);
+      } else {
+        mxnet_op::Kernel<TransposeExKernel<false>, xpu>::Launch(s,
+            in_shape.Size(), out, in, strides_xpu.dptr_, ndim);
+      }
+  });
+}
+
+template<typename xpu>
+mshadow::Tensor<xpu, 1, dim_t> GetTransposeExWorkspace(
+      const OpContext& ctx,
+      const mxnet::TShape& axes
+    ) {
+  if (axes.ndim() > 6) {
+    // allocate workspace when axes.ndim() > 6
+    mshadow::Shape<1> strides_shape;
+    strides_shape[0] = axes.ndim() * 2;
+    return ctx.requested[0].get_space_typed<xpu, 1, dim_t>(
+        strides_shape, ctx.get_stream<xpu>());
+  }
+  return {};
 }
 
 // matrix transpose
@@ -441,10 +559,15 @@ void Transpose(const nnvm::NodeAttrs& attrs,
   } else {
     axes = common::CanonicalizeAxes(param.axes);
   }
+
+  mshadow::Tensor<xpu, 1, dim_t> workspace =
+    GetTransposeExWorkspace<xpu>(ctx, axes);
   if (req[0] == kAddTo) {
-    TransposeImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0], axes);
+    TransposeExImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0],
+        axes, workspace);
   } else {
-    TransposeImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0], axes);
+    TransposeExImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0],
+        axes, workspace);
   }
 }
 
@@ -458,7 +581,6 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
   mxnet::TShape& out_shp = (*out_attrs)[0];
   if (!mxnet::ndim_is_known(shp) && !mxnet::ndim_is_known(out_shp))
     return false;  // none of the shapes is known
-  CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
   if (out_shp.ndim() >= 0 && shp.ndim() >= 0)
     CHECK_EQ(out_shp.ndim(), shp.ndim());
   mxnet::TShape get(std::max(shp.ndim(), out_shp.ndim()), -1);
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index 8a86f7a..8c3d14f 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -351,6 +351,10 @@ Examples::
     }
   })
 .set_attr<FCompute>("FCompute<cpu>", Transpose<cpu>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& n) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
 #if MXNET_USE_MKLDNN == 1
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", TransposeComputeExCPU)
diff --git a/tests/python/unittest/test_numpy_op.py 
b/tests/python/unittest/test_numpy_op.py
index 2599db0..88ad77f 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -2440,7 +2440,11 @@ def test_np_broadcast_to_npx(src_shape, npx_dst_shape, 
np_dst_shape, hybridize):
     [(8, 2, 16), [(0, 2, 1), (2, 0, 1), (0, 1, 2), (2, 1, 0), (-1, -2, -3)]],
     [(8, 3, 4, 8), [(0, 2, 3, 1), (1, 2, 3, 0), (0, 3, 2, 1)]],
     [(8, 3, 2, 3, 8), [(0, 1, 3, 2, 4), (0, 1, 2, 3, 4), (4, 0, 1, 2, 3)]],
-    [(3, 4, 3, 4, 3, 2), [(0, 1, 3, 2, 4, 5), (2, 3, 4, 1, 0, 5), None]]
+    [(3, 4, 3, 4, 3, 2), [(0, 1, 3, 2, 4, 5), (2, 3, 4, 1, 0, 5), None]],
+    [(3, 4, 3, 4, 3, 2, 2), [(0, 1, 3, 2, 4, 5, 6),
+     (2, 3, 4, 1, 0, 5, 6), None]],
+    [(3, 4, 3, 4, 3, 2, 3, 2), [(0, 1, 3, 2, 4, 5, 7, 6),
+     (2, 3, 4, 1, 0, 5, 7, 6), None]],
 ])
 @pytest.mark.parametrize('grad_req', ['write', 'add'])
 def test_np_transpose(data_shape, axes_workload, hybridize, dtype, grad_req):
@@ -10117,7 +10121,7 @@ def test_np_rollaxis():
     dtypes = ['int32', 'int64', 'float16', 'float32', 'float64']
     for hybridize in [False, True]:
         for dtype in dtypes:
-            for ndim in [0, 1, 2, 3, 4, 5, 6]:
+            for ndim in [0, 1, 2, 3, 4, 5, 6, 7, 8]:
                 shape = rand_shape_nd(ndim, dim=5, allow_zero_size=True)
                 np_data = _np.random.uniform(low=-100, high=100, 
size=shape).astype(dtype)
                 mx_data = np.array(np_data, dtype=dtype)
diff --git a/tests/python/unittest/test_operator.py 
b/tests/python/unittest/test_operator.py
index 6546751..1578e14 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -2573,9 +2573,9 @@ def test_broadcast():
 
 @with_seed()
 def test_transpose():
-    for ndim in range(1, 7):
+    for ndim in range(1, 10):
         for t in range(5):
-            dims = list(np.random.randint(1, 10, size=ndim))
+            dims = list(np.random.randint(1, 5, size=ndim))
             axes = list(range(ndim))
             random.shuffle(axes)
             axes = tuple(axes)

Reply via email to