This is an automated email from the ASF dual-hosted git repository. wkcn pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 6692d2c [Bug Fix] support multiple-dim input for unravel_index (#17748) 6692d2c is described below commit 6692d2cc76c4bb841d43abbe53f4d4aff059ba77 Author: JackieWu <w...@live.cn> AuthorDate: Sat Apr 11 21:43:12 2020 +0800 [Bug Fix] support multiple-dim input for unravel_index (#17748) * support multiple-dim input for unravel_index * sanity --- src/operator/tensor/ravel.cc | 12 ++++++++++-- src/operator/tensor/ravel.h | 19 ++++++++++++++----- tests/python/unittest/test_operator.py | 13 +++++++++++++ 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/src/operator/tensor/ravel.cc b/src/operator/tensor/ravel.cc index e04628e..e7cd303 100644 --- a/src/operator/tensor/ravel.cc +++ b/src/operator/tensor/ravel.cc @@ -62,8 +62,16 @@ NNVM_REGISTER_OP(_unravel_index) Examples:: A = [22,41,37] - unravel(A, shape=(7,6)) = [[3,6,6],[4,5,1]] - unravel(A, shape=(-1,6)) = [[3,6,6],[4,5,1]] + unravel_index(A, shape=(7,6)) = [[3,6,6], + [4,5,1]] + unravel_index(A, shape=(-1,6)) = [[3,6,6], + [4,5,1]] + + B = [[22,41,37],[10,11,15]] + unravel_index(B, shape=(7,6)) = [[[3,6,6],[1,1,2]], + [[4,5,1],[4,5,3]]] + unravel_index(B, shape=(-1,6)) = [[[3,6,6],[1,1,2]], + [[4,5,1],[4,5,3]]] )code" ADD_FILELINE) .set_num_inputs(1) diff --git a/src/operator/tensor/ravel.h b/src/operator/tensor/ravel.h index d96b9cf..abf9383 100644 --- a/src/operator/tensor/ravel.h +++ b/src/operator/tensor/ravel.h @@ -76,16 +76,24 @@ inline bool UnravelOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1); CHECK_EQ(out_attrs->size(), 1); CHECK_GT(shape.ndim(), 0) << "Empty shape parameter for unravel operator."; - if ((*in_attrs)[0].ndim() > 0) { - SHAPE_ASSIGN_CHECK(*out_attrs, 0, Shape2(shape.ndim(), (*in_attrs)[0][0])); + const mxnet::TShape &in_shape = (*in_attrs)[0]; + if (in_shape.ndim() > 0) { + mxnet::TShape out_shape(in_shape.ndim() + 1, -1); + out_shape[0] = shape.ndim(); + for (int i = 0; i < in_shape.ndim(); ++i) { + out_shape[i+1] = in_shape[i]; + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, out_shape); return true; } if ((*out_attrs)[0].ndim() > 0) { + const mxnet::TShape &out_shape = (*out_attrs)[0]; CHECK_EQ((*out_attrs)[0].ndim(), 2) << "Output of unravel operator must be two-dimensional."; CHECK_EQ((*out_attrs)[0][0], shape.ndim()) << "First dimension of output of ravel operator does not match shape parameter dimension."; - SHAPE_ASSIGN_CHECK(*in_attrs, 0, Shape1((*out_attrs)[0][1])); + SHAPE_ASSIGN_CHECK(*in_attrs, 0, mxnet::TShape( + out_shape.data() + 1, out_shape.data() + out_shape.ndim())); return true; } return false; @@ -156,8 +164,9 @@ void UnravelForward(const nnvm::NodeAttrs& attrs, MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { Tensor<xpu, 1, OType> in = inputs[0].FlatTo1D<xpu, OType>(s); Tensor<xpu, 1, OType> out = outputs[0].FlatTo1D<xpu, OType>(s); - mxnet_op::Kernel<unravel_index, xpu>::Launch(s, in.size(0), in.size(0), out.size(0)/in.size(0), - work.dptr_, out.dptr_, in.dptr_); + mxnet_op::Kernel<unravel_index, xpu>::Launch( + s, in.shape_.Size(), in.shape_.Size(), shape.ndim(), + work.dptr_, out.dptr_, in.dptr_); }); } diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 0c795db..230073a 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -8427,6 +8427,19 @@ def test_ravel(): c = mx.sym.unravel_index(a, shape=shape2) check_symbolic_forward(c, location={'a': ravel_npy}, expected=[data]) + +@with_seed() +def test_unravel_index(): + unravel_shape = (2, 10) + unravel_size = np.prod(unravel_shape) + for shape in [(10,), (2, 10), (3, 4, 5)]: + a = np.random.randint(0, unravel_size, size=shape) + b = np.stack(np.unravel_index(a, shape=unravel_shape), 0) + a_mx = mx.nd.array(a) + b_mx = mx.nd.unravel_index(a_mx, shape=unravel_shape) + assert_array_equal(b, b_mx.asnumpy()) + + def test_context_num_gpus(): try: # Note: the test is run both on GPU and CPU hosts, so that we can not assert