reminisce commented on a change in pull request #16621: [Numpy] Basic indexing in symbolic interface of DeepNumpy
URL: https://github.com/apache/incubator-mxnet/pull/16621#discussion_r350463623
 
 

 ##########
 File path: python/mxnet/symbol/numpy/_symbol.py
 ##########
 @@ -43,31 +50,131 @@
            'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff',
            'resize', 'nan_to_num', 'where']
 
-def _num_outputs(sym):
-    return len(sym.as_nd_ndarray())
-
 
 @set_module('mxnet.symbol.numpy')
 class _Symbol(Symbol):
-    def __getitem__(self, key):
-        num_outputs = _num_outputs(self)
-        if num_outputs == 1:
-            raise NotImplementedError
-        if not isinstance(key, int):
-            raise NotImplementedError
-        if key >= num_outputs:
-            # Important, python determines the end by this exception
-            raise IndexError
-        handle = SymbolHandle()
-        check_call(_LIB.MXSymbolGetOutput(
-            self.handle, mx_uint(key), ctypes.byref(handle)))
-        return _Symbol(handle=handle)
+    def __init__(self, handle):
+        super(_Symbol, self).__init__(handle)
+
+    def __getitem__(self, key): # pylint: disable = too-many-return-statements, inconsistent-return-statements
+        """Return self[key].
+
+        If the symbol is a symbol list, it returns the i-th symbol or a list of symbols
+        selected by key.
+
+        Otherwise, it outputs a symbol that slice the input by the given key. Currently, this
+        function supports the following types of key:
+
+        - integer types, e.g., int, long, np.int32, np.int64
+        - slice containing integer constants, e.g., slice(0, None, None)
+        - tuple contaning the above elements, which is used for multidimensional indexing
+
+        Parameters
+        ----------
+        key : int, slice, or tuple of all previous types
+            Indexing key.
+
+        """
+        num_outputs = self.num_outputs
+        if num_outputs > 1:
+            num_outputs = self.num_outputs
+            if isinstance(key, integer_types):
+                key = int(key)
+                if key < -num_outputs or key >= num_outputs:
+                    raise IndexError('list index out of range')
+                if key < 0:
+                    key += num_outputs
+                ret_handle = SymbolHandle()
+                check_call(_LIB.MXSymbolGetOutput(self.handle, mx_uint(key),
+                                                  ctypes.byref(ret_handle)))
+                return _Symbol(handle=ret_handle)
+            elif isinstance(key, py_slice):
+                start, stop, step = key.indices(num_outputs)
+                return Group([self[i] for i in range(start, stop, step)], _Symbol)
+            else:
+                raise TypeError('indices of symbol group must be integers or slices, not {}'
+                                .format(type(key)))
+        else:
+            if isinstance(key, integer_types):
+                sliced = _npi.slice(self, [key], [key+1])
+                return _npi.reshape(sliced, (-3, -4))
+            elif isinstance(key, py_slice):
+                if key.step is None or key.step != 0:
+                    start = [None] if key.start is None else key.start
+                    stop = [None] if key.stop is None else key.stop
+                    return _npi.slice(self, start, stop, key.step)
+                else:
+                    raise ValueError("slice step cannot be zero")
+            elif isinstance(key, tuple):
+                begin = []
+                end = []
+                step = []
+                new_shape = ()
+                if len(key) == 0:
+                    return self
+                for index in key:
+                    if isinstance(index, py_slice):
+                        if index.step is not None and index.step == 0:
+                            raise ValueError("slice step cannot be zero")
+                        begin.append(index.start)
+                        end.append(index.stop)
+                        step.append(index.step)
+                        new_shape += (-2,)
+                    elif isinstance(index, integer_types):
+                        if index >= 0:
+                            begin.append(index)
+                            end.append(index+1)
+                            step.append(1)
+                        else:
+                            begin.append(index)
+                            end.append(index - 1)
+                            step.append(-1)
+                        new_shape += (-3,)
+                    else:
+                        raise IndexError('Only integer, slice, or tuple of these types'
+                                         ' are supported! Received key={}'.format(key))
+                new_shape += (-4,)
+                sliced = _npi.slice(self, begin, end, step)
+                return _npi.reshape(sliced, new_shape)
+            else:
+                raise IndexError('Only integer, slice, or tuple of these types are supported! '
+                                 'Received key={}'.format(key))
 
     def __setitem__(self, key, value):
         raise NotImplementedError
 
+    def __repr__(self):
+        """Gets a string representation of the symbol."""
+        if self.num_outputs > 1:
+            name = ', '.join([str(ele_sym) for ele_sym in self])
+            return '<%s group [%s]>' % (self.__class__.__name__, name)
+        else:
+            return '<%s %s>' % (self.__class__.__name__, self.name)
+
+    @property
+    def name(self):
+        """Gets name string from the symbol, this function only works for symbols
+         that are not a list (grouped symbols).
+
+        Returns
+        -------
+        value : str
+            The name of this symbol, returns ``None`` for list symbol.
+        """
+        if self.num_outputs > 1:
+            return None
 
 Review comment:
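  Should this `raise` instead of `return` `None`? The docstring says this property only works for symbols that are not a list (grouped symbols), so raising an error would make misuse explicit rather than silently returning `None`.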

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to