haojin2 commented on a change in pull request #17009: [Numpy] support boolean indexing URL: https://github.com/apache/incubator-mxnet/pull/17009#discussion_r358077224
########## File path: python/mxnet/numpy/multiarray.py ########## @@ -324,6 +327,134 @@ def _set_np_advanced_indexing(self, key, value): value_nd = self._prepare_value_nd(value, bcast_shape=vshape, squeeze_axes=new_axes) self._scatter_set_nd(value_nd, idcs) + def _is_basic_boolean_indexing(self, key): + """Check boolean indexing type arr[bool], arr[1, bool, 4], or arr[:, bool, :] + return bool, type, position""" + if isinstance(key, ndarray) and key.dtype == _np.bool_: # boolean indexing + return True, _NDARRAY_SINGLE_BOOLEAN_INDEXING, 0 + elif not isinstance(key, tuple): + return False, _NDARRAY_UNSUPPORTED_INDEXING, 0 + num_bool = 0 + dim = len(key) + rest_int = True + rest_full_slice = True + pos = None + for idx in range(dim): + if isinstance(key[idx], _np.ndarray) and key[idx].dtype == _np.bool_: + key[idx] = array(key[idx], dtype='bool', ctx=self.ctx) + if isinstance(key[idx], ndarray) and key[idx].dtype == _np.bool_: + num_bool += 1 + pos = idx + elif isinstance(key[idx], integer_types): + rest_full_slice = False + elif isinstance(key[idx], py_slice) and key[idx] == slice(None, None, None): + rest_int = False + # not arr[:, bool, :] format slicing or not arr[3,bool,4] + else: + return False, _NDARRAY_UNSUPPORTED_INDEXING, pos + + if num_bool == 1 and rest_int: + return True, _NDARRAY_INT_BOOLEAN_INDEXING, pos + elif num_bool == 1 and rest_full_slice: + return True, _NDARRAY_SLICE_BOOLEAN_INDEXING, pos + elif num_bool > 2: + raise NotImplementedError("Do not support more than two boolean arrays as part of indexing") + return False, _NDARRAY_UNSUPPORTED_INDEXING, pos + + @staticmethod + def _calculate_new_idx(key, shape, mask_pos, mask_ndim): # pylint: disable=redefined-outer-name + new_idx = 0 + step = 1 + for idx in range(len(key)-1, mask_pos, -1): + new_idx += key[idx]*step + step *= shape[idx+mask_ndim-1] + return new_idx + + def _get_boolean_indexing(self, key, pos, bool_type): + if bool_type == _NDARRAY_SINGLE_BOOLEAN_INDEXING: + key = (key,) + bool_type = _NDARRAY_SLICE_BOOLEAN_INDEXING + + from functools import reduce + mask_shape = key[pos].shape + mask_ndim = len(mask_shape) + ndim = len(self.shape) + if len(key) + mask_ndim - 1 > ndim: + raise IndexError('too many indices, whose ndim = {}, for array with ndim = {}' + .format(len(key) + mask_ndim - 1, ndim)) + for i in range(mask_ndim): + if key[pos].shape[i] != self.shape[pos + i]: + raise IndexError('boolean index did not match indexed array along axis {};' + ' size is {} but corresponding boolean size is {}' + .format(pos + i, self.shape[pos + i], key[pos].shape[i])) + remaining_idces = pos + mask_ndim + remaining_shapes = self.shape[remaining_idces:] + mask = _reshape_view(key[pos], -1) + + if bool_type == _NDARRAY_SLICE_BOOLEAN_INDEXING: + data = _reshape_view(self, -1, *remaining_shapes) + # if mask is at the begining, then the scale is one + scale = reduce(lambda x, y: x * y, self.shape[:pos], 1) + keys = mask if scale == 1 else _reshape_view(_npi.stack(*[mask for i in range(scale)]), -1) + all_shapes = self.shape[:pos] + remaining_shapes + return _reshape_view(_npi.boolean_mask(data, keys), -1, *all_shapes) + + elif bool_type == _NDARRAY_INT_BOOLEAN_INDEXING: + out = self + for idx in range(pos): + out = out[key[idx]] + data = _reshape_view(out, -1, *remaining_shapes) + after_mask = _reshape_view(_npi.boolean_mask(data, mask), -1, *remaining_shapes) + if pos == len(key) - 1: + return after_mask + # check boundary + for idx in range(pos+1, len(key)): + if key[idx] >= self.shape[idx+mask_ndim-1]: + raise IndexError('index {} on a dimension of' # pylint: disable=too-many-format-args + .format(key[idx], self.shape[idx+mask_ndim-1])) + implicit_idces = len(key)+mask_ndim-1 # idces not explictly shown in the key + implicit_shape = self.shape[implicit_idces:] + new_dim = reduce(lambda x, y: x * y, self.shape[pos+mask_ndim:implicit_idces], 1) + new_idx = self._calculate_new_idx(key, self.shape, pos, mask_ndim) + after_reshape = _reshape_view(after_mask, -1, new_dim, *implicit_shape) + return _reshape_view(_npi.take(after_reshape, array([new_idx]), axis=1), -1, *implicit_shape) + + raise NotImplementedError("This boolean indexing type is not supported.") + + def _set_boolean_indexing(self, key, pos, bool_type, value): + if bool_type == _NDARRAY_SINGLE_BOOLEAN_INDEXING: + key = (key,) + bool_type = _NDARRAY_SLICE_BOOLEAN_INDEXING + + mask = key[pos] + mask_shape = mask.shape + mask_ndim = len(mask_shape) + if len(key) + mask_ndim - 1 > self.ndim: + raise IndexError('too many indices, whose ndim = {}, for array with ndim = {}' + .format(len(key) + mask_ndim - 1, self.ndim)) + for i in range(mask_ndim): + if mask_shape[i] != self.shape[pos + i]: + raise IndexError('boolean index did not match indexed array along axis {};' + ' size is {} but corresponding boolean size is {}' + .format(pos + i, self.shape[pos + i], mask_shape[i])) + + data = self # when bool_type == _NDARRAY_SLICE_BOOLEAN_INDEXING + if bool_type == _NDARRAY_INT_BOOLEAN_INDEXING: + if pos != len(key) - 1: + raise NotImplementedError('only support boolean array at the end of the idces ' + 'when it is mixed with integers') + for idx in range(pos): + data = data[key[idx]] + pos -= 1 + + if isinstance(value, numeric_types): + _npi.boolean_mask_assign_scalar(data=data, mask=mask, value=value, start_axis=pos, out=data) + elif isinstance(value, NDArray): Review comment: should check ```python elif isinstance(value, ndarray): ``` ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services