This is an automated email from the ASF dual-hosted git repository. zhasheng pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 37bdf0b [MXNET-1453] Support the input whose dimension is greater than 6 for Transpose and Rollaxis (#18707) 37bdf0b is described below commit 37bdf0bf981d11a89bd248b02f473211d57bc9c6 Author: JackieWu <w...@live.cn> AuthorDate: Fri Jul 17 01:25:01 2020 +0800 [MXNET-1453] Support the input whose dimension is greater than 6 for Transpose and Rollaxis (#18707) * support 6+ dims for transpose * test over * reorder code * fix transposeex --- src/operator/numpy/np_matrix_op-inl.h | 51 ++++++++---- src/operator/numpy/np_matrix_op.cc | 17 +++- src/operator/tensor/matrix_op-inl.h | 138 +++++++++++++++++++++++++++++++-- src/operator/tensor/matrix_op.cc | 4 + tests/python/unittest/test_numpy_op.py | 8 +- tests/python/unittest/test_operator.py | 4 +- 6 files changed, 191 insertions(+), 31 deletions(-) diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index 0125feb..0fea76b 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -134,10 +134,10 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs, const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req, const std::vector<TBlob>& outputs) { - const NumpyTransposeParam& param = nnvm::get<NumpyTransposeParam>(attrs.parsed); if (req[0] == kNullOp) return; CHECK(req[0] == kWriteTo || req[0] == kAddTo) - << "Transpose only supports kWriteTo, kNullOp and kAddTo"; + << "Transpose does not support inplace"; + const NumpyTransposeParam& param = nnvm::get<NumpyTransposeParam>(attrs.parsed); mxnet::TShape axes; if (ndim_is_known(param.axes)) { axes = common::CanonicalizeAxes(param.axes); @@ -147,10 +147,14 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs, axes[i] = axes.ndim() - 1 - i; } } + mshadow::Tensor<xpu, 1, dim_t> workspace = + GetTransposeExWorkspace<xpu>(ctx, axes); if (req[0] == kAddTo) { - TransposeImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0], axes); + 
TransposeExImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0], + axes, workspace); } else { - TransposeImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0], axes); + TransposeExImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0], + axes, workspace); } } @@ -779,13 +783,21 @@ void NumpyRollaxisCompute(const nnvm::NodeAttrs& attrs, using namespace mshadow::expr; CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 1U); - CHECK_EQ(req[0], kWriteTo) << "Rollaxis does not support inplace"; - mxnet::TShape axes; + if (req[0] == kNullOp) return; + CHECK(req[0] == kWriteTo || req[0] == kAddTo) + << "Rollaxis does not support inplace"; const NumpyRollaxisParam& param = nnvm::get<NumpyRollaxisParam>(attrs.parsed); - axes = NumpyRollaxisShapeImpl(param.axis, param.start, inputs[0].ndim()); - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, { - TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes); - }) + mxnet::TShape axes = NumpyRollaxisShapeImpl(param.axis, param.start, inputs[0].ndim()); + + mshadow::Tensor<xpu, 1, dim_t> workspace = + GetTransposeExWorkspace<xpu>(ctx, axes); + if (req[0] == kAddTo) { + TransposeExImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0], + axes, workspace); + } else { + TransposeExImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0], + axes, workspace); + } } template<typename xpu> @@ -796,6 +808,9 @@ void NumpyRollaxisBackward(const nnvm::NodeAttrs &attrs, const std::vector<TBlob> &outputs) { using namespace mshadow; using namespace mshadow::expr; + if (req[0] == kNullOp) return; + CHECK(req[0] == kWriteTo || req[0] == kAddTo) + << "Rollaxis Backward does not support inplace"; const NumpyRollaxisParam& param = nnvm::get<NumpyRollaxisParam>(attrs.parsed); int axis_origin = param.axis; int start_origin = param.start; @@ -819,11 +834,17 @@ void NumpyRollaxisBackward(const nnvm::NodeAttrs &attrs, axis = start_origin; start = axis_origin + 1; } - mxnet::TShape axes; - axes = NumpyRollaxisShapeImpl(axis, start, inputs[0].ndim()); - 
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, { - TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes); - }) + mxnet::TShape axes = NumpyRollaxisShapeImpl(axis, start, inputs[0].ndim()); + + mshadow::Tensor<xpu, 1, dim_t> workspace = + GetTransposeExWorkspace<xpu>(ctx, axes); + if (req[0] == kAddTo) { + TransposeExImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0], + axes, workspace); + } else { + TransposeExImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0], + axes, workspace); + } } struct NumpyRot90Param : public dmlc::Parameter<NumpyRot90Param> { diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index da9839f..2bb2fe3 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -51,7 +51,6 @@ bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1U); mxnet::TShape& shp = (*in_attrs)[0]; mxnet::TShape& out_shp = (*out_attrs)[0]; - CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions"; int ndim = -1; if (ndim_is_known(shp)) { @@ -133,6 +132,10 @@ NNVM_REGISTER_OP(_npi_transpose) } }) .set_attr<FCompute>("FCompute<cpu>", NumpyTranspose<cpu>) +.set_attr<FResourceRequest>("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; + }) .set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) { return std::vector<std::string>{"a"}; @@ -1261,7 +1264,6 @@ bool NumpyRollaxisShape(const nnvm::NodeAttrs& attrs, // check transpose dimentions no more than 6 mxnet::TShape& shp = (*in_attrs)[0]; - CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions"; // check axis and start range CHECK_GE(param.axis, -shp.ndim()) @@ -1304,6 +1306,10 @@ until it lies in a given position.)code" ADD_FILELINE) .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FCompute>("FCompute<cpu>", NumpyRollaxisCompute<cpu>) .set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseNone{"_npi_rollaxis_backward"}) +.set_attr<FResourceRequest>("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; + }) .add_argument("data", "NDArray-or-Symbol", "Input ndarray") .add_arguments(NumpyRollaxisParam::__FIELDS__()); @@ -1312,7 +1318,11 @@ NNVM_REGISTER_OP(_npi_rollaxis_backward) .set_num_outputs(1) .set_attr_parser(ParamParser<NumpyRollaxisParam>) .set_attr<nnvm::TIsBackward>("TIsBackward", true) -.set_attr<FCompute>("FCompute<cpu>", NumpyRollaxisBackward<cpu>); +.set_attr<FCompute>("FCompute<cpu>", NumpyRollaxisBackward<cpu>) +.set_attr<FResourceRequest>("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; + }); template<> void NumpyFlipForwardImpl<cpu>(const OpContext& ctx, @@ -1368,7 +1378,6 @@ bool NumpyMoveaxisShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); mxnet::TShape& shp = (*in_attrs)[0]; - CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions"; CHECK_EQ(param.source.ndim(), param.destination.ndim()) << "source and destination not equal."; mxnet::TShape ret(shp.ndim(), -1); diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 79c2a6d..6c125ed 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -321,15 +321,16 @@ inline bool IsIdentityTranspose(const TShape& axes) { } template<typename xpu, bool is_addto = false> -void TransposeImpl(RunContext ctx, +bool TransposeCommonImpl(RunContext ctx, const TBlob& src, const TBlob& ret, const mxnet::TShape& axes) { + // return true when running successfully, otherwise false using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(src.type_flag_, ret.type_flag_); // zero-size tensor, no need to compute - if (src.shape_.Size() == 0U) return; + if (src.shape_.Size() == 0U) return true; Stream<xpu> *s = 
ctx.get_stream<xpu>(); #ifdef __CUDACC__ // This transpose can be used only if there exist n and m such that: @@ -339,7 +340,7 @@ void TransposeImpl(RunContext ctx, MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, { transpose_pseudo2D<DType, is_addto>(ret, src, axes, s); }); - return; + return true; } #endif // Special handle the identity case @@ -355,7 +356,7 @@ void TransposeImpl(RunContext ctx, s, ret.Size(), out.dptr_, in.dptr_); } }); - return; + return true; } // Handle the general transpose case MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, { @@ -413,10 +414,127 @@ void TransposeImpl(RunContext ctx, break; } default: - LOG(FATAL) << "Transpose support at most 6 dimensions"; + // return false when dimensions > 6 + return false; break; } }); + return true; +} + +template<typename xpu, bool is_addto = false> +void TransposeImpl(RunContext ctx, + const TBlob& src, + const TBlob& ret, + const mxnet::TShape& axes) { + CHECK_LE(axes.ndim(), 6) << "TransposeImpl supports at most 6 dimensions"; + CHECK((TransposeCommonImpl<xpu, is_addto>(ctx, src, ret, axes))) << + "Failed to execute TransposeImpl Operator"; +} + +template <bool is_addto> +struct TransposeExKernel { + /*! 
+ * \brief + * \param tid global thread id + * \param out_data output data + * \param in_data input data + * \param strides input strides and output strides + * \param ndim the number of dimension + */ + template <typename DType> + MSHADOW_XINLINE static void Map(int tid, + DType *out_data, + const DType *in_data, + const dim_t *strides, + const int ndim + ) { + // tid is the index of input data + const dim_t* const out_strides = strides + ndim; + int k = tid; + int out_id = 0; + for (int i = 0; i < ndim; ++i) { + out_id += (k / strides[i]) * out_strides[i]; + k %= strides[i]; + } + if (is_addto) + out_data[out_id] += in_data[tid]; + else + out_data[out_id] = in_data[tid]; + } +}; + +template<typename xpu, bool is_addto = false> +void TransposeExImpl(RunContext ctx, + const TBlob& src, + const TBlob& ret, + const mxnet::TShape& axes, + mshadow::Tensor<xpu, 1, dim_t>& strides_xpu + ) { + /* + * If ndim <= 6, it is not necessary to allocate any space for `strides_xpu` + * If ndim > 6, `strides_xpu` should be allocated `ndim * 2` elements + */ + using namespace mshadow; + using namespace mshadow::expr; + if (TransposeCommonImpl<xpu, is_addto>(ctx, src, ret, axes)) return; + CHECK_GT(axes.ndim(), 6) << + "Failed to execute TransposeExImpl when axes.ndim() <= 6"; + Stream<xpu> *s = ctx.get_stream<xpu>(); + MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, { + CHECK_EQ(strides_xpu.MSize(), axes.ndim() * 2) << \ + "If ndim > 6, `strides_xpu` should be allocated `ndim * 2` elements"; + + const mxnet::TShape &in_shape = src.shape_; + // strides: in_strides and out_strides + const int ndim = axes.ndim(); + std::vector<dim_t> strides(ndim * 2); + // compute in_strides + strides[ndim - 1] = 1; + for (int i = ndim - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * in_shape[i + 1]; + } + // compute out_strides + std::vector<dim_t> tmp_strides(ndim); + tmp_strides[ndim - 1] = 1; + for (int i = ndim - 2; i >= 0; --i) { + tmp_strides[i] = tmp_strides[i + 1] * in_shape[axes[i + 1]]; + } + 
// reorder tmp_strides to out_strides + dim_t * const out_strides = &strides[ndim]; + for (int i = 0; i < ndim; ++i) { + out_strides[axes[i]] = tmp_strides[i]; + } + Shape<1> strides_shape; + strides_shape[0] = ndim * 2; + Tensor<cpu, 1, dim_t> strides_cpu(strides.data(), strides_shape); + // copy arguments into xpu context + Copy(strides_xpu, strides_cpu, s); + const DType *in = src.dptr<DType>(); + DType *out = ret.dptr<DType>(); + if (is_addto) { + mxnet_op::Kernel<TransposeExKernel<true>, xpu>::Launch(s, + in_shape.Size(), out, in, strides_xpu.dptr_, ndim); + } else { + mxnet_op::Kernel<TransposeExKernel<false>, xpu>::Launch(s, + in_shape.Size(), out, in, strides_xpu.dptr_, ndim); + } + }); +} + +template<typename xpu> +mshadow::Tensor<xpu, 1, dim_t> GetTransposeExWorkspace( + const OpContext& ctx, + const mxnet::TShape& axes + ) { + if (axes.ndim() > 6) { + // allocate workspace when axes.ndim() > 6 + mshadow::Shape<1> strides_shape; + strides_shape[0] = axes.ndim() * 2; + return ctx.requested[0].get_space_typed<xpu, 1, dim_t>( + strides_shape, ctx.get_stream<xpu>()); + } + return {}; } // matrix transpose @@ -441,10 +559,15 @@ void Transpose(const nnvm::NodeAttrs& attrs, } else { axes = common::CanonicalizeAxes(param.axes); } + + mshadow::Tensor<xpu, 1, dim_t> workspace = + GetTransposeExWorkspace<xpu>(ctx, axes); if (req[0] == kAddTo) { - TransposeImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0], axes); + TransposeExImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0], + axes, workspace); } else { - TransposeImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0], axes); + TransposeExImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0], + axes, workspace); } } @@ -458,7 +581,6 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs, mxnet::TShape& out_shp = (*out_attrs)[0]; if (!mxnet::ndim_is_known(shp) && !mxnet::ndim_is_known(out_shp)) return false; // none of the shapes is known - CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions"; if 
(out_shp.ndim() >= 0 && shp.ndim() >= 0) CHECK_EQ(out_shp.ndim(), shp.ndim()); mxnet::TShape get(std::max(shp.ndim(), out_shp.ndim()), -1); diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index 8a86f7a..8c3d14f 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -351,6 +351,10 @@ Examples:: } }) .set_attr<FCompute>("FCompute<cpu>", Transpose<cpu>) +.set_attr<FResourceRequest>("FResourceRequest", + [](const NodeAttrs& n) { + return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; +}) #if MXNET_USE_MKLDNN == 1 .set_attr<bool>("TIsMKLDNN", true) .set_attr<FComputeEx>("FComputeEx<cpu>", TransposeComputeExCPU) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 2599db0..88ad77f 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2440,7 +2440,11 @@ def test_np_broadcast_to_npx(src_shape, npx_dst_shape, np_dst_shape, hybridize): [(8, 2, 16), [(0, 2, 1), (2, 0, 1), (0, 1, 2), (2, 1, 0), (-1, -2, -3)]], [(8, 3, 4, 8), [(0, 2, 3, 1), (1, 2, 3, 0), (0, 3, 2, 1)]], [(8, 3, 2, 3, 8), [(0, 1, 3, 2, 4), (0, 1, 2, 3, 4), (4, 0, 1, 2, 3)]], - [(3, 4, 3, 4, 3, 2), [(0, 1, 3, 2, 4, 5), (2, 3, 4, 1, 0, 5), None]] + [(3, 4, 3, 4, 3, 2), [(0, 1, 3, 2, 4, 5), (2, 3, 4, 1, 0, 5), None]], + [(3, 4, 3, 4, 3, 2, 2), [(0, 1, 3, 2, 4, 5, 6), + (2, 3, 4, 1, 0, 5, 6), None]], + [(3, 4, 3, 4, 3, 2, 3, 2), [(0, 1, 3, 2, 4, 5, 7, 6), + (2, 3, 4, 1, 0, 5, 7, 6), None]], ]) @pytest.mark.parametrize('grad_req', ['write', 'add']) def test_np_transpose(data_shape, axes_workload, hybridize, dtype, grad_req): @@ -10117,7 +10121,7 @@ def test_np_rollaxis(): dtypes = ['int32', 'int64', 'float16', 'float32', 'float64'] for hybridize in [False, True]: for dtype in dtypes: - for ndim in [0, 1, 2, 3, 4, 5, 6]: + for ndim in [0, 1, 2, 3, 4, 5, 6, 7, 8]: shape = rand_shape_nd(ndim, dim=5, allow_zero_size=True) np_data = 
_np.random.uniform(low=-100, high=100, size=shape).astype(dtype) mx_data = np.array(np_data, dtype=dtype) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 6546751..1578e14 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -2573,9 +2573,9 @@ def test_broadcast(): @with_seed() def test_transpose(): - for ndim in range(1, 7): + for ndim in range(1, 10): for t in range(5): - dims = list(np.random.randint(1, 10, size=ndim)) + dims = list(np.random.randint(1, 5, size=ndim)) axes = list(range(ndim)) random.shuffle(axes) axes = tuple(axes)