reminisce commented on a change in pull request #16621: [Numpy] Basic indexing in symbolic interface of DeepNumpy URL: https://github.com/apache/incubator-mxnet/pull/16621#discussion_r350463623
########## File path: python/mxnet/symbol/numpy/_symbol.py ########## @@ -43,31 +50,131 @@ 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where'] -def _num_outputs(sym): - return len(sym.as_nd_ndarray()) - @set_module('mxnet.symbol.numpy') class _Symbol(Symbol): - def __getitem__(self, key): - num_outputs = _num_outputs(self) - if num_outputs == 1: - raise NotImplementedError - if not isinstance(key, int): - raise NotImplementedError - if key >= num_outputs: - # Important, python determines the end by this exception - raise IndexError - handle = SymbolHandle() - check_call(_LIB.MXSymbolGetOutput( - self.handle, mx_uint(key), ctypes.byref(handle))) - return _Symbol(handle=handle) + def __init__(self, handle): + super(_Symbol, self).__init__(handle) + + def __getitem__(self, key): # pylint: disable = too-many-return-statements, inconsistent-return-statements + """Return self[key]. + + If the symbol is a symbol list, it returns the i-th symbol or a list of symbols + selected by key. + + Otherwise, it outputs a symbol that slice the input by the given key. Currently, this + function supports the following types of key: + + - integer types, e.g., int, long, np.int32, np.int64 + - slice containing integer constants, e.g., slice(0, None, None) + - tuple contaning the above elements, which is used for multidimensional indexing + + Parameters + ---------- + key : int, slice, or tuple of all previous types + Indexing key. + + """ + num_outputs = self.num_outputs + if num_outputs > 1: + num_outputs = self.num_outputs + if isinstance(key, integer_types): + key = int(key) + if key < -num_outputs or key >= num_outputs: + raise IndexError('list index out of range') + if key < 0: + key += num_outputs + ret_handle = SymbolHandle() + check_call(_LIB.MXSymbolGetOutput(self.handle, mx_uint(key), + ctypes.byref(ret_handle))) + return _Symbol(handle=ret_handle) + elif isinstance(key, py_slice): + start, stop, step = key.indices(num_outputs) + return Group([self[i] for i in range(start, stop, step)], _Symbol) + else: + raise TypeError('indices of symbol group must be integers or slices, not {}' + .format(type(key))) + else: + if isinstance(key, integer_types): + sliced = _npi.slice(self, [key], [key+1]) + return _npi.reshape(sliced, (-3, -4)) + elif isinstance(key, py_slice): + if key.step is None or key.step != 0: + start = [None] if key.start is None else key.start + stop = [None] if key.stop is None else key.stop + return _npi.slice(self, start, stop, key.step) + else: + raise ValueError("slice step cannot be zero") + elif isinstance(key, tuple): + begin = [] + end = [] + step = [] + new_shape = () + if len(key) == 0: + return self + for index in key: + if isinstance(index, py_slice): + if index.step is not None and index.step == 0: + raise ValueError("slice step cannot be zero") + begin.append(index.start) + end.append(index.stop) + step.append(index.step) + new_shape += (-2,) + elif isinstance(index, integer_types): + if index >= 0: + begin.append(index) + end.append(index+1) + step.append(1) + else: + begin.append(index) + end.append(index - 1) + step.append(-1) + new_shape += (-3,) + else: + raise IndexError('Only integer, slice, or tuple of these types' + ' are supported! Received key={}'.format(key)) + new_shape += (-4,) + sliced = _npi.slice(self, begin, end, step) + return _npi.reshape(sliced, new_shape) + else: + raise IndexError('Only integer, slice, or tuple of these types are supported! ' + 'Received key={}'.format(key)) def __setitem__(self, key, value): raise NotImplementedError + def __repr__(self): + """Gets a string representation of the symbol.""" + if self.num_outputs > 1: + name = ', '.join([str(ele_sym) for ele_sym in self]) + return '<%s group [%s]>' % (self.__class__.__name__, name) + else: + return '<%s %s>' % (self.__class__.__name__, self.name) + + @property + def name(self): + """Gets name string from the symbol, this function only works for symbols + that are not a list (grouped symbols). + + Returns + ------- + value : str + The name of this symbol, returns ``None`` for list symbol. + """ + if self.num_outputs > 1: + return None Review comment: `raise` instead of `return`? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services