This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
     new 39c0fd8  Fix ndarray assignment issue with basic indexing (#10022)

39c0fd8 is described below

commit 39c0fd82312e138ef6b7f6531adb1f2fe423cb07
Author: reminisce <wujun....@gmail.com>
AuthorDate: Wed Mar 7 22:40:04 2018 -0800

    Fix ndarray assignment issue with basic indexing (#10022)

    * Fix ndarray assignment issue with basic index

    * Uncomment useful code
---
 python/mxnet/ndarray/ndarray.py       | 2 ++
 tests/python/unittest/test_ndarray.py | 5 +++++
 2 files changed, 7 insertions(+)

diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 5ac2796..5367845 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -695,6 +695,8 @@ fixed-size items.
                 # may need to broadcast first
                 if isinstance(value, NDArray):
                     if value.handle is not self.handle:
+                        if value.shape != shape:
+                            value = value.broadcast_to(shape)
                         value.copyto(self)
                 elif isinstance(value, numeric_types):
                     _internal._full(shape=shape, ctx=self.context,
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index e96fb2f..16f08b0 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -992,6 +992,8 @@ def test_ndarray_indexing():
         def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None):
             if np_value is not None:
                 np_array[np_index] = np_value
+            elif isinstance(mx_value, mx.nd.NDArray):
+                np_array[np_index] = mx_value.asnumpy()
             else:
                 np_array[np_index] = mx_value
             mx_array[mx_index] = mx_value
@@ -1024,6 +1026,9 @@ def test_ndarray_indexing():
         # test value is an numeric_type
         assert_same(np_array, np_index, mx_array, index, np.random.randint(low=-10000, high=0))
         if len(indexed_array_shape) > 1:
+            # test NDArray with broadcast
+            assert_same(np_array, np_index, mx_array, index,
+                        mx.nd.random.uniform(low=-10000, high=0, shape=(indexed_array_shape[-1],)))
             # test numpy array with broadcast
             assert_same(np_array, np_index, mx_array, index,
                         np.random.randint(low=-10000, high=0, size=(indexed_array_shape[-1],)))

-- 
To stop receiving notification emails like this one, please contact
zhash...@apache.org.