This is an automated email from the ASF dual-hosted git repository.

taolv pushed a commit to branch v1.7.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.7.x by this push:
     new 75ab155  more support for boolean indexing and assign (#18352)
75ab155 is described below

commit 75ab15569bd0f20a90806ce2fc38df08be208ed7
Author: alicia <32725332+alicia1...@users.noreply.github.com>
AuthorDate: Thu May 28 11:05:13 2020 +0800

    more support for boolean indexing and assign (#18352)
---
 python/mxnet/ndarray/ndarray.py             | 100 ++++++++----
 python/mxnet/numpy/multiarray.py            | 230 ++++++++++++----------------
 src/operator/numpy/np_nonzero_op.cc         |   2 +-
 src/operator/tensor/indexing_op.cc          |   2 +-
 src/operator/tensor/indexing_op.cu          |   2 +-
 src/operator/tensor/indexing_op.h           |   4 +-
 tests/python/unittest/test_numpy_ndarray.py |  36 ++++-
 7 files changed, 206 insertions(+), 170 deletions(-)

diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index cda3166..7ac666e 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -39,6 +39,7 @@ from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int, mx_
 from ..base import ctypes2buffer
 from ..runtime import Features
 from ..context import Context, current_context
+from ..util import is_np_array
 from . import _internal
 from . import op
 from ._internal import NDArrayBase
@@ -111,7 +112,11 @@ _NDARRAY_UNSUPPORTED_INDEXING = -1
 _NDARRAY_BASIC_INDEXING = 0
 _NDARRAY_ADVANCED_INDEXING = 1
 _NDARRAY_EMPTY_TUPLE_INDEXING = 2
-_NDARRAY_BOOLEAN_INDEXING = 3
+
+# Return code for 0-d boolean array handler
+_NDARRAY_NO_ZERO_DIM_BOOL_ARRAY = -1
+_NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE = 0
+_NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE = 1
 
 # Caching whether MXNet was built with INT64 support or not
 _INT64_TENSOR_SIZE_ENABLED = None
@@ -521,7 +526,7 @@ fixed-size items.
             return
 
         else:
-            key = indexing_key_expand_implicit_axes(key, self.shape)
+            key, _ = indexing_key_expand_implicit_axes(key, self.shape)
             slc_key = tuple(idx for idx in key if idx is not None)
 
             if len(slc_key) < self.ndim:
@@ -714,7 +719,7 @@ fixed-size items.
             elif key.step == 0:
                 raise ValueError("slice step cannot be zero")
 
-        key = indexing_key_expand_implicit_axes(key, self.shape)
+        key, _ = indexing_key_expand_implicit_axes(key, self.shape)
         if len(key) == 0:
             raise ValueError('indexing key cannot be an empty tuple')
 
@@ -2574,9 +2579,12 @@ fixed-size items.
         >>> type(x.asscalar())
         <type 'numpy.int32'>
         """
-        if self.shape != (1,):
+        if self.size != 1:
             raise ValueError("The current array is not a scalar")
-        return self.asnumpy()[0]
+        if self.ndim == 1:
+            return self.asnumpy()[0]
+        else:
+            return self.asnumpy()[()]
 
     def astype(self, dtype, copy=True):
         """Returns a copy of the array after casting to a specified type.
@@ -2943,9 +2951,23 @@ fixed-size items.
             lhs=self, rhs=value_nd, indices=indices, shape=self.shape, out=self
         )
 
+def check_boolean_array_dimension(array_shape, axis, bool_shape):
+    """
+    Advanced boolean indexing is implemented through the use of `nonzero`.
+    Size check is necessary to make sure that the boolean array
+    has exactly as many dimensions as it is supposed to work with before the 
conversion
+    """
+    for i, val in enumerate(bool_shape):
+        if array_shape[axis + i] != val:
+            raise IndexError('boolean index did not match indexed array along 
axis {};'
+                             ' size is {} but corresponding boolean size is {}'
+                             .format(axis + i, array_shape[axis + i], val))
 
 def indexing_key_expand_implicit_axes(key, shape):
-    """Make implicit axes explicit by adding ``slice(None)``.
+    """
+    Make implicit axes explicit by adding ``slice(None)``
+    and convert boolean array to integer array through `nonzero`.
+
     Examples
     --------
     >>> shape = (3, 4, 5)
@@ -2957,6 +2979,11 @@ def indexing_key_expand_implicit_axes(key, shape):
     (0, slice(None, None, None), slice(None, None, None))
     >>> indexing_key_expand_implicit_axes(np.s_[:2, None, 0, ...], shape)
     (slice(None, 2, None), None, 0, slice(None, None, None))
+    >>> bool_array = np.array([[True, False, True, False],
+                               [False, True, False, True],
+                               [True, False, True, False]], dtype=np.bool)
+    >>> indexing_key_expand_implicit_axes(np.s_[bool_array, None, 0:2], shape)
+    (array([0, 0, 1, 1, 2, 2], dtype=int64), array([0, 2, 1, 3, 0, 2], 
dtype=int64), None, slice(None, 2, None))
     """
     if not isinstance(key, tuple):
         key = (key,)
@@ -2966,6 +2993,17 @@ def indexing_key_expand_implicit_axes(key, shape):
     ell_idx = None
     num_none = 0
     nonell_key = []
+
+    # For 0-d boolean indices: A new axis is added,
+    # but at the same time no axis is "used". So if we have True,
+    # we add a new axis (a bit like with np.newaxis). If it is
+    # False, we add a new axis, but this axis has 0 entries.
+    # prepend is defined to handle this case.
+    # prepend = _NDARRAY_NO_ZERO_DIM_BOOL_ARRAY/-1 means there is no 0-d 
boolean scalar
+    # prepend = _NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE/0 means an zero dim must be 
expanded
+    # prepend = _NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE/1 means a new axis must be 
expanded
+    prepend = _NDARRAY_NO_ZERO_DIM_BOOL_ARRAY
+    axis = 0
     for i, idx in enumerate(key):
         if idx is Ellipsis:
             if ell_idx is not None:
@@ -2974,14 +3012,38 @@ def indexing_key_expand_implicit_axes(key, shape):
                 )
             ell_idx = i
         else:
+            # convert primitive type boolean value to mx.np.bool type
+            # otherwise will be treated as 1/0
+            if isinstance(idx, bool):
+                idx = array(idx, dtype=np.bool_)
             if idx is None:
                 num_none += 1
-            if isinstance(idx, NDArrayBase) and idx.ndim == 0 and idx.dtype != 
np.bool_:
+            if isinstance(idx, NDArrayBase) and idx.ndim == 0 and idx.dtype == 
np.bool_:
+                if not idx: # array(False) has priority
+                    prepend = _NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE
+                else:
+                    prepend = _NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE
+            elif isinstance(idx, NDArrayBase) and idx.ndim == 0 and idx.dtype 
!= np.bool_:
                 # This handles ndarray of zero dim. e.g array(1)
                 # while advoid converting zero dim boolean array
-                nonell_key.append(idx.item())
+                # float type will be converted to int
+                nonell_key.append(int(idx.item()))
+                axis += 1
+            elif isinstance(idx, NDArrayBase) and idx.dtype == np.bool_:
+                # Necessary size check before using `nonzero`
+                check_boolean_array_dimension(shape, axis, idx.shape)
+                # If the whole array is false and npx.set_np() is not set_up
+                # the program will throw infer shape error
+                if not is_np_array():
+                    raise ValueError('Cannot perform boolean indexing in 
legacy mode. Please activate'
+                                     ' numpy semantics by calling 
`npx.set_np()` in the global scope'
+                                     ' before calling this function.')
+                # Add the arrays from the nonzero result to the index
+                nonell_key.extend(idx.nonzero())
+                axis += idx.ndim
             else:
                 nonell_key.append(idx)
+                axis += 1
 
     nonell_key = tuple(nonell_key)
 
@@ -2995,7 +3057,7 @@ def indexing_key_expand_implicit_axes(key, shape):
                     (slice(None),) * ell_ndim +
                     nonell_key[ell_idx:])
 
-    return expanded_key
+    return expanded_key, prepend
 
 
 def _int_to_slice(idx):
@@ -3053,32 +3115,18 @@ def _is_advanced_index(idx):
 def get_indexing_dispatch_code(key):
     """Returns a dispatch code for calling basic or advanced indexing 
functions."""
     assert isinstance(key, tuple)
-    num_bools = 0
-    basic_indexing = True
 
     for idx in key:
-        if isinstance(idx, (NDArray, np.ndarray, list, tuple)):
+        if isinstance(idx, (NDArray, np.ndarray, list, tuple, range)):
             if isinstance(idx, tuple) and len(idx) == 0:
                 return _NDARRAY_EMPTY_TUPLE_INDEXING
-            if getattr(idx, 'dtype', None) == np.bool_:
-                num_bools += 1
-            basic_indexing = False
-        elif isinstance(idx, range):
-            basic_indexing = False
+            return _NDARRAY_ADVANCED_INDEXING
         elif not (isinstance(idx, (py_slice, integer_types)) or idx is None):
             raise ValueError(
                 'NDArray does not support slicing with key {} of type {}.'
                 ''.format(idx, type(idx))
             )
-    if basic_indexing and num_bools == 0:
-        return _NDARRAY_BASIC_INDEXING
-    elif not basic_indexing and num_bools == 0:
-        return _NDARRAY_ADVANCED_INDEXING
-    elif num_bools == 1:
-        return _NDARRAY_BOOLEAN_INDEXING
-    else:
-        raise TypeError('ndarray indexing does not more than one boolean 
ndarray'
-                        ' in a tuple of complex indices.')
+    return _NDARRAY_BASIC_INDEXING
 
 
 def _get_index_range(start, stop, length, step=1):
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 6516994..fceaaf3 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -76,10 +76,11 @@ _NDARRAY_UNSUPPORTED_INDEXING = -1
 _NDARRAY_BASIC_INDEXING = 0
 _NDARRAY_ADVANCED_INDEXING = 1
 _NDARRAY_EMPTY_TUPLE_INDEXING = 2
-_NDARRAY_BOOLEAN_INDEXING = 3
-_NDARRAY_INT_BOOLEAN_INDEXING = 4
-_NDARRAY_SLICE_BOOLEAN_INDEXING = 5
 
+# Return code for 0-d boolean array handler
+_NDARRAY_NO_ZERO_DIM_BOOL_ARRAY = -1
+_NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE = 0
+_NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE = 1
 
 # This function is copied from ndarray.py since pylint
 # keeps giving false alarm error of undefined-all-variable
@@ -405,121 +406,43 @@ class ndarray(NDArray):
         value_nd = self._prepare_value_nd(value, bcast_shape=vshape, 
squeeze_axes=new_axes)
         self._scatter_set_nd(value_nd, idcs)
 
-    def _check_boolean_indexing_type(self, key):
-        """Check boolean indexing type arr[bool, :, :], arr[1, bool, 4], or 
arr[:, bool, :]
-           return bool_type, bool_position"""
-
-        dim = len(key)
-        rest_int = True
-        rest_full_slice = True
-        pos = None
-        for idx in range(dim):
-            if isinstance(key[idx], _np.ndarray) and key[idx].dtype == 
_np.bool_:
-                key[idx] = array(key[idx], dtype='bool', ctx=self.ctx)
-            if isinstance(key[idx], ndarray) and key[idx].dtype == _np.bool_:
-                pos = idx
-            elif isinstance(key[idx], integer_types):
-                rest_full_slice = False
-            elif isinstance(key[idx], py_slice) and key[idx] == slice(None, 
None, None):
-                rest_int = False
-            # not arr[:, bool, :] format slicing or not arr[3,bool,4]
-            else:
-                raise TypeError('ndarray boolean indexing does not support 
slicing '
-                                'with key {} of type {}'.format(idx, type(idx))
-                                )
-
-        if rest_int:
-            return _NDARRAY_INT_BOOLEAN_INDEXING, pos
-        elif rest_full_slice:
-            return _NDARRAY_SLICE_BOOLEAN_INDEXING, pos
-        raise NotImplementedError("Do not support {} as key for boolean 
indexing".format(key))
-
-    @staticmethod
-    def _calculate_new_idx(key, shape, mask_pos, mask_ndim): # pylint: 
disable=redefined-outer-name
-        new_idx = 0
-        step = 1
-        for idx in range(len(key)-1, mask_pos, -1):
-            new_idx += key[idx]*step
-            step *= shape[idx+mask_ndim-1]
-        return new_idx
-
-    def _get_np_boolean_indexing(self, key):
-        if not isinstance(key, tuple):
-            key = (key,)
-        bool_type, pos = self._check_boolean_indexing_type(key)
-
-        from functools import reduce
-        mask_shape = key[pos].shape
-        mask_ndim = len(mask_shape)
-        ndim = len(self.shape)  # pylint: disable=redefined-outer-name, 
unused-variable
-        for i in range(mask_ndim):
-            if key[pos].shape[i] != self.shape[pos + i]:
-                raise IndexError('boolean index did not match indexed array 
along axis {};'
-                                 ' size is {} but corresponding boolean size 
is {}'
-                                 .format(pos + i, self.shape[pos + i], 
key[pos].shape[i]))
-        remaining_idces = pos + mask_ndim
-        remaining_shapes = self.shape[remaining_idces:]
-        mask = _reshape_view(key[pos], -1)
-
-        if bool_type == _NDARRAY_SLICE_BOOLEAN_INDEXING:
-            data = _reshape_view(self, -1, *remaining_shapes)
-            # if mask is at the begining, then the scale is one
-            scale = reduce(lambda x, y: x * y, self.shape[:pos], 1)
-            keys = mask if scale == 1 else _reshape_view(_npi.stack(*[mask for 
i in range(scale)]), -1)
-            all_shapes = self.shape[:pos] + remaining_shapes
-            return _reshape_view(_npi.boolean_mask(data, keys), -1, 
*all_shapes)
-
-        elif bool_type == _NDARRAY_INT_BOOLEAN_INDEXING:
-            out = self
-            for idx in range(pos):
-                out = out[key[idx]]
-            data = _reshape_view(out, -1, *remaining_shapes)
-            after_mask = _reshape_view(_npi.boolean_mask(data, mask), -1, 
*remaining_shapes)
-            if pos == len(key) - 1:
-                return after_mask
-            # check boundary
-            for idx in range(pos+1, len(key)):
-                if key[idx] >= self.shape[idx+mask_ndim-1]:
-                    raise IndexError('index {} on a dimension of {}'
-                                     .format(key[idx], 
self.shape[idx+mask_ndim-1]))
-            implicit_idces = len(key)+mask_ndim-1 # idces not explictly shown 
in the key
-            implicit_shape = self.shape[implicit_idces:]
-            new_dim = reduce(lambda x, y: x * y, 
self.shape[pos+mask_ndim:implicit_idces], 1)
-            new_idx = self._calculate_new_idx(key, self.shape, pos, mask_ndim)
-            after_reshape = _reshape_view(after_mask, -1, new_dim, 
*implicit_shape)
-            return _reshape_view(_npi.take(after_reshape, array([new_idx]), 
axis=1), -1, *implicit_shape)
-
-        raise NotImplementedError("This boolean indexing type is not 
supported.")
+    # pylint: disable=redefined-outer-name
+    def _get_np_boolean_indexing(self, key, ndim, shape):
+        """
+        There are two types of boolean indices (which are equivalent,
+        for the most part though). This function will handle single
+        boolean indexing for higher speed.
+        If this is not the case, it is instead expanded into (multiple)
+        integer array indices and will be handled by advanced indexing.
+        """
+        key_shape = key.shape
+        key_ndim = len(key_shape)
+        if ndim < key_ndim:
+            raise IndexError('too many indices, whose ndim = {}, for array 
with ndim = {}'
+                             .format(key_ndim, ndim))
+        for i in range(key_ndim):
+            if key_shape[i] != shape[i]:
+                raise IndexError('boolean index did not match indexed array 
along dimension {};'
+                                 ' dimension is {} but corresponding boolean 
dimension is {}'
+                                 .format(i, shape[i], key_shape[i]))
+        remaining_dims = shape[key_ndim:]
+        data = _reshape_view(self, -1, *remaining_dims)
+        key = _reshape_view(key, -1)
+        return _reshape_view(_npi.boolean_mask(data, key), -1, *remaining_dims)
 
     def _set_np_boolean_indexing(self, key, value):
-        if not isinstance(key, tuple):
-            key = (key,)
-        bool_type, pos = self._check_boolean_indexing_type(key)
-
-        mask = key[pos]
-        mask_shape = mask.shape
-        mask_ndim = len(mask_shape)
-        for i in range(mask_ndim):
-            if mask_shape[i] != self.shape[pos + i]:
-                raise IndexError('boolean index did not match indexed array 
along axis {};'
-                                 ' size is {} but corresponding boolean size 
is {}'
-                                 .format(pos + i, self.shape[pos + i], 
mask_shape[i]))
-
-        data = self # when bool_type == _NDARRAY_SLICE_BOOLEAN_INDEXING
-        if bool_type == _NDARRAY_INT_BOOLEAN_INDEXING:
-            if pos != len(key) - 1:
-                raise NotImplementedError('only support boolean array at the 
end of the idces '
-                                          'when it is mixed with integers')
-            for idx in range(pos):
-                data = data[key[idx]]
-                pos -= 1
-
+        """
+        There are two types of boolean indices (which are equivalent,
+        for the most part though). This function will handle single boolean 
assign for higher speed.
+        If this is not the case, it is instead expanded into (multiple)
+        integer array indices and will be handled by advanced assign.
+        """
         if isinstance(value, numeric_types):
-            _npi.boolean_mask_assign_scalar(data=data, mask=mask,
+            _npi.boolean_mask_assign_scalar(data=self, mask=key,
                                             value=int(value) if 
isinstance(value, bool) else value,
-                                            start_axis=pos, out=data)
+                                            start_axis=0, out=self)
         elif isinstance(value, ndarray):
-            _npi.boolean_mask_assign_tensor(data=data, mask=mask, value=value, 
start_axis=pos, out=data)
+            _npi.boolean_mask_assign_tensor(data=self, mask=key, value=value, 
start_axis=0, out=self)
         else:
             raise NotImplementedError('type %s is not 
supported.'%(type(value)))
 
@@ -658,12 +581,14 @@ class ndarray(NDArray):
         >>> x = np.array([1., -1., -2., 3])
         >>> x[x < 0]
         array([-1., -2.])
+
+        For more imformation related to boolean indexing, please refer to
+        https://docs.scipy.org/doc/numpy-1.17.0/reference/arrays.indexing.html.
         """
-        # handling possible boolean indexing first
         ndim = self.ndim  # pylint: disable=redefined-outer-name
         shape = self.shape  # pylint: disable=redefined-outer-name
         if isinstance(key, bool): # otherwise will be treated as 0 and 1
-            key = array(key, dtype=_np.bool)
+            key = array(key, dtype=_np.bool, ctx=self.ctx)
         if isinstance(key, list):
             try:
                 new_key = _np.array(key)
@@ -674,11 +599,15 @@ class ndarray(NDArray):
         if isinstance(key, _np.ndarray) and key.dtype == _np.bool_:
             key = array(key, dtype='bool', ctx=self.ctx)
 
-        if ndim == 0:
-            if isinstance(key, ndarray) and key.dtype == _np.bool:
-                pass # will handle by function for boolean indexing
-            elif key != ():
-                raise IndexError('scalar tensor can only accept `()` as index')
+        # Handle single boolean index of matching dimensionality and size 
first for higher speed
+        # If the boolean array is mixed with other idices, it is instead 
expanded into (multiple)
+        # integer array indices and will be handled by advanced indexing.
+        # Come before the check self.dim == 0 as it also handle the 0-dim case.
+        if isinstance(key, ndarray) and key.dtype == _np.bool_:
+            return self._get_np_boolean_indexing(key, ndim, shape)
+
+        if ndim == 0 and key != ():
+            raise IndexError('scalar tensor can only accept `()` as index')
         # Handle simple cases for higher speed
         if isinstance(key, tuple) and len(key) == 0:
             return self
@@ -703,17 +632,33 @@ class ndarray(NDArray):
             elif key.step == 0:
                 raise ValueError("slice step cannot be zero")
 
-        key_before_expaned = key
-        key = indexing_key_expand_implicit_axes(key, self.shape)
+        # For 0-d boolean indices: A new axis is added,
+        # but at the same time no axis is "used". So if we have True,
+        # we add a new axis (a bit like with np.newaxis). If it is
+        # False, we add a new axis, but this axis has 0 entries.
+        # prepend is defined to handle this case.
+        # prepend = _NDARRAY_NO_ZERO_DIM_BOOL_ARRAY/-1 means there is no 0-d 
boolean scalar
+        # prepend = _NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE/0 means an zero dim 
must be expanded
+        # prepend = _NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE/1 means a new axis must 
be prepended
+        key, prepend = indexing_key_expand_implicit_axes(key, self.shape)
         indexing_dispatch_code = get_indexing_dispatch_code(key)
-        if indexing_dispatch_code == _NDARRAY_BASIC_INDEXING:
-            return self._get_np_basic_indexing(key)
-        elif indexing_dispatch_code == _NDARRAY_EMPTY_TUPLE_INDEXING:
+        if indexing_dispatch_code == _NDARRAY_EMPTY_TUPLE_INDEXING:
+            # won't be affected by zero-dim boolean indices
             return self._get_np_empty_tuple_indexing(key)
+        elif indexing_dispatch_code == _NDARRAY_BASIC_INDEXING:
+            if prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE:
+                return empty((0,) + self._get_np_basic_indexing(key).shape,
+                             dtype=self.dtype, ctx=self.ctx)
+            if prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE:
+                key = (_np.newaxis,) + key
+            return self._get_np_basic_indexing(key)
         elif indexing_dispatch_code == _NDARRAY_ADVANCED_INDEXING:
+            if prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE:
+                return empty((0,) + self._get_np_adanced_indexing(key).shape,
+                             dtype=self.dtype, ctx=self.ctx)
+            if prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE:
+                key = (_np.newaxis,) + key
             return self._get_np_advanced_indexing(key)
-        elif indexing_dispatch_code == _NDARRAY_BOOLEAN_INDEXING:
-            return self._get_np_boolean_indexing(key_before_expaned)
         else:
             raise RuntimeError
 
@@ -764,16 +709,25 @@ class ndarray(NDArray):
         >>> x
         array([[ 6.,  5.,  5.],
                [ 6.,  0.,  4.]])
+
+        For imformation related to boolean indexing, please refer to
+        https://docs.scipy.org/doc/numpy-1.17.0/reference/arrays.indexing.html.
         """
         if isinstance(value, NDArray) and not isinstance(value, ndarray):
             raise TypeError('Cannot assign mx.nd.NDArray to 
mxnet.numpy.ndarray')
         if isinstance(key, bool): # otherwise will be treated as 0 and 1
             key = array(key, dtype=_np.bool)
+
+        # Handle single boolean assign of matching dimensionality and size 
first for higher speed
+        # If the boolean array is mixed with other idices, it is instead 
expanded into (multiple)
+        # integer array indices and will be handled by advanced assign.
+        # Come before the check self.dim == 0 as it also handle the 0-dim case.
+        if isinstance(key, ndarray) and key.dtype == _np.bool:
+            return self._set_np_boolean_indexing(key, value)
+
         # handle basic and advanced indexing
         if self.ndim == 0:
-            if isinstance(key, ndarray) and key.dtype == _np.bool:
-                pass # will be handled by boolean indexing
-            elif not isinstance(key, tuple) or len(key) != 0:
+            if not isinstance(key, tuple) or len(key) != 0:
                 raise IndexError('scalar tensor can only accept `()` as index')
             if isinstance(value, numeric_types):
                 self._full(value)
@@ -788,8 +742,18 @@ class ndarray(NDArray):
             else:
                 raise ValueError('setting an array element with a sequence.')
         else:
-            key_before_expaned = key
-            key = indexing_key_expand_implicit_axes(key, self.shape)
+            # For 0-d boolean indices: A new axis is added,
+            # but at the same time no axis is "used". So if we have True,
+            # we add a new axis (a bit like with np.newaxis). If it is
+            # False, we add a new axis, but this axis has 0 entries.
+            # prepend is defined to handle this case.
+            # prepend == _NDARRAY_NO_ZERO_DIM_BOOL_ARRAY/-1 means there is no 
0-d boolean scalar
+            # prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE/0 means an zero 
dim must be expanded
+            # prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE/1 means a new axis 
must be expanded
+            # prepend actually has no influence on __setitem__
+            key, prepend = indexing_key_expand_implicit_axes(key, self.shape)
+            if prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE:
+                return # no action is needed
             slc_key = tuple(idx for idx in key if idx is not None)
             if len(slc_key) < self.ndim:
                 raise RuntimeError(
@@ -809,8 +773,6 @@ class ndarray(NDArray):
                 pass # no action needed
             elif indexing_dispatch_code == _NDARRAY_ADVANCED_INDEXING:
                 self._set_np_advanced_indexing(key, value)
-            elif indexing_dispatch_code == _NDARRAY_BOOLEAN_INDEXING:
-                return self._set_np_boolean_indexing(key_before_expaned, value)
             else:
                 raise ValueError(
                     'Indexing NDArray with index {} of type {} is not 
supported'
diff --git a/src/operator/numpy/np_nonzero_op.cc b/src/operator/numpy/np_nonzero_op.cc
index 0eaf087..2b9de76 100644
--- a/src/operator/numpy/np_nonzero_op.cc
+++ b/src/operator/numpy/np_nonzero_op.cc
@@ -66,7 +66,7 @@ void NonzeroForwardCPU(const nnvm::NodeAttrs& attrs,
   CHECK_LE(in.shape().ndim(), MAXDIM) << "ndim of input cannot larger than " 
<< MAXDIM;
   // 0-dim
   if (0 == in.shape().ndim()) {
-    MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(in.dtype(), DType, {
       DType* in_dptr = in.data().dptr<DType>();
       if (*in_dptr) {
         mxnet::TShape s(2, 1);
diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index 2a2dd76..1303cb9 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -480,7 +480,7 @@ void GatherNDForwardCPU(const nnvm::NodeAttrs& attrs,
     strides[i] = stride;
     mshape[i] = dshape[i];
   }
-  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {  // output data type 
switch
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {  // output data 
type switch
     MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // indices data type 
switch
       // check whether indices are out of bound
       IType* idx_ptr = inputs[1].dptr<IType>();
diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu
index 2d8789d..6904656 100644
--- a/src/operator/tensor/indexing_op.cu
+++ b/src/operator/tensor/indexing_op.cu
@@ -486,7 +486,7 @@ void GatherNDForwardGPU(const nnvm::NodeAttrs& attrs,
     strides[i] = stride;
     mshape[i] = dshape[i];
   }
-  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {  // output data type 
switch
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {  // output data 
type switch
     MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // indices data type 
switch
       // check whether indices are out of bound
       IType* idx_ptr = inputs[1].dptr<IType>();
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 5449fbe..2b04881 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -1447,8 +1447,8 @@ void ScatterNDForward(const nnvm::NodeAttrs& attrs,
   if (kWriteTo == req[0]) {
     Fill<true>(s, outputs[0], req[0], 0);
   }
-  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {  // output data type 
switch
-    MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // indices data type 
switch
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {  // output data 
type switch
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[1].type_flag_, IType, {  // indices 
data type switch
       mxnet_op::Kernel<scatter_nd, xpu>::Launch(
         s, N, req[0], N, M, K, strides, outputs[0].dptr<DType>(),
         inputs[0].dptr<DType>(), inputs[1].dptr<IType>());
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index 9ef5d79..3ce53c6 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -1194,6 +1194,24 @@ def test_np_ndarray_boolean_indexing():
         assert same(a[0, b].asnumpy(), _np_a[0, _np_b])
         assert same(a[b, 1].asnumpy(), _np_a[_np_b, 1])
 
+        a = np.arange(12).reshape(4,3)
+        b = np.array([1.,2.,3.])
+        _np_a = a.asnumpy()
+        _np_b = b.asnumpy()
+        assert same(a[:, b > 2].shape, _np_a[:, _np_b > 2].shape)
+        assert same(a[:, b > 2].asnumpy(), _np_a[:, _np_b > 2])
+
+        a = np.array([[1,2,3],[3,4,5]])
+        _np_a = a.asnumpy()
+        assert same(a[:,a[1,:] > 0].shape, _np_a[:,_np_a[1,: ] > 0].shape)
+        assert same(a[:,a[1,:] > 0].asnumpy(), _np_a[:,_np_a[1,: ] > 0])
+
+        a = np.ones((3,2), dtype='bool')
+        b = np.array([1,2,3])
+        _np_a = a.asnumpy()
+        _np_b = b.asnumpy()
+        assert same(a[b > 1].asnumpy(), _np_a[_np_b > 1])
+
     def test_boolean_indexing_assign():
         # test boolean indexing assign
         shape = (3, 2, 3)
@@ -1208,11 +1226,11 @@ def test_np_ndarray_boolean_indexing():
         np_data[np_mask] = 1
         mx_data[mx_mask] = 1
         assert_almost_equal(mx_data.asnumpy(), np_data, rtol=1e-3, atol=1e-5, 
use_broadcast=False)
-        # not supported at this moment
-        # only support boolean array at the end of the idces when it is mixed 
with integers
-        # np_data[np_mask, 1] = 2
-        # mx_data[mx_mask, 1] = 2
-        # assert_almost_equal(mx_data.asnumpy(), np_data, rtol=1e-3, 
atol=1e-5, use_broadcast=False)
+
+        np_data[np_mask, 1] = 2
+        mx_data[mx_mask, 1] = 2
+        assert_almost_equal(mx_data.asnumpy(), np_data, rtol=1e-3, atol=1e-5, 
use_broadcast=False)
+
         np_data[np_mask, :] = 3
         mx_data[mx_mask, :] = 3
         assert_almost_equal(mx_data.asnumpy(), np_data, rtol=1e-3, atol=1e-5, 
use_broadcast=False)
@@ -1227,6 +1245,14 @@ def test_np_ndarray_boolean_indexing():
         mx_data[:, mx_mask] = 6
         assert_almost_equal(mx_data.asnumpy(), np_data, rtol=1e-3, atol=1e-5, 
use_broadcast=False)
 
+        np_data[0, True, True, np_mask] = 7
+        mx_data[0, True, True, mx_mask] = 7
+        assert_almost_equal(mx_data.asnumpy(), np_data, rtol=1e-3, atol=1e-5, 
use_broadcast=False)
+
+        np_data[False, 1] = 8
+        mx_data[False, 1] = 8
+        assert_almost_equal(mx_data.asnumpy(), np_data, rtol=1e-3, atol=1e-5, 
use_broadcast=False)
+
     def test_boolean_indexing_autograd():
         a = np.random.uniform(size=(3, 4, 5))
         a.attach_grad()

Reply via email to