[GitHub] [incubator-mxnet] reminisce commented on a change in pull request #16621: [Numpy] Basic indexing in symbolic interface of DeepNumpy
reminisce commented on a change in pull request #16621: [Numpy] Basic indexing in symbolic interface of DeepNumpy URL: https://github.com/apache/incubator-mxnet/pull/16621#discussion_r350463623 ## File path: python/mxnet/symbol/numpy/_symbol.py ## @@ -43,31 +50,131 @@ 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where'] -def _num_outputs(sym): -return len(sym.as_nd_ndarray()) - @set_module('mxnet.symbol.numpy') class _Symbol(Symbol): -def __getitem__(self, key): -num_outputs = _num_outputs(self) -if num_outputs == 1: -raise NotImplementedError -if not isinstance(key, int): -raise NotImplementedError -if key >= num_outputs: -# Important, python determines the end by this exception -raise IndexError -handle = SymbolHandle() -check_call(_LIB.MXSymbolGetOutput( -self.handle, mx_uint(key), ctypes.byref(handle))) -return _Symbol(handle=handle) +def __init__(self, handle): +super(_Symbol, self).__init__(handle) + +def __getitem__(self, key): # pylint: disable = too-many-return-statements, inconsistent-return-statements +"""Return self[key]. + +If the symbol is a symbol list, it returns the i-th symbol or a list of symbols +selected by key. + +Otherwise, it outputs a symbol that slice the input by the given key. Currently, this +function supports the following types of key: + +- integer types, e.g., int, long, np.int32, np.int64 +- slice containing integer constants, e.g., slice(0, None, None) +- tuple contaning the above elements, which is used for multidimensional indexing + +Parameters +-- +key : int, slice, or tuple of all previous types +Indexing key. 
+ +""" +num_outputs = self.num_outputs +if num_outputs > 1: +num_outputs = self.num_outputs +if isinstance(key, integer_types): +key = int(key) +if key < -num_outputs or key >= num_outputs: +raise IndexError('list index out of range') +if key < 0: +key += num_outputs +ret_handle = SymbolHandle() +check_call(_LIB.MXSymbolGetOutput(self.handle, mx_uint(key), + ctypes.byref(ret_handle))) +return _Symbol(handle=ret_handle) +elif isinstance(key, py_slice): +start, stop, step = key.indices(num_outputs) +return Group([self[i] for i in range(start, stop, step)], _Symbol) +else: +raise TypeError('indices of symbol group must be integers or slices, not {}' +.format(type(key))) +else: +if isinstance(key, integer_types): +sliced = _npi.slice(self, [key], [key+1]) +return _npi.reshape(sliced, (-3, -4)) +elif isinstance(key, py_slice): +if key.step is None or key.step != 0: +start = [None] if key.start is None else key.start +stop = [None] if key.stop is None else key.stop +return _npi.slice(self, start, stop, key.step) +else: +raise ValueError("slice step cannot be zero") +elif isinstance(key, tuple): +begin = [] +end = [] +step = [] +new_shape = () +if len(key) == 0: +return self +for index in key: +if isinstance(index, py_slice): +if index.step is not None and index.step == 0: +raise ValueError("slice step cannot be zero") +begin.append(index.start) +end.append(index.stop) +step.append(index.step) +new_shape += (-2,) +elif isinstance(index, integer_types): +if index >= 0: +begin.append(index) +end.append(index+1) +step.append(1) +else: +begin.append(index) +end.append(index - 1) +step.append(-1) +new_shape += (-3,) +else: +raise IndexError('Only integer, slice, or tuple of these types' + ' are supported! Received key={}'.format(key)) +new_shape += (-4,) +sliced = _npi.slice(self, begin, end, step) +return
[GitHub] [incubator-mxnet] reminisce commented on a change in pull request #16621: [Numpy] Basic indexing in symbolic interface of DeepNumpy
reminisce commented on a change in pull request #16621: [Numpy] Basic indexing in symbolic interface of DeepNumpy URL: https://github.com/apache/incubator-mxnet/pull/16621#discussion_r350470836 ## File path: python/mxnet/symbol/numpy/_symbol.py ## @@ -4957,4 +5119,62 @@ def where(condition, x, y): return _npi.where(condition, x, y, out=None) +@set_module('mxnet.symbol.numpy') +def load_json_string(json_str): +""" +Loads symbol from json string. + +Parameters +-- +json_str : str +A JSON string. + +Returns +--- +sym : Symbol +The loaded symbol. + +See Also + +_Symbol.tojson : Used to save symbol into json string. +""" +if not isinstance(json_str, string_types): +raise TypeError('fname required to be string') +handle = SymbolHandle() +json_data = json.loads(json_str) +check_call(_LIB.MXSymbolCreateFromJSON(c_str(json.dumps(json_data)), ctypes.byref(handle))) +s = _Symbol(handle) +return s + + +@set_module('mxnet.symbol.numpy') +def load(fname): +"""Loads symbol from a JSON file. +You can also use pickle to do the job if you only work on python. +The advantage of load/save is the file is language agnostic. +This means the file saved using save can be loaded by other language binding of mxnet. +You also get the benefit being able to directly load/save from cloud storage(S3, HDFS). +Parameters +-- +fname : str +The name of the file, examples: +- `s3://my-bucket/path/my-s3-symbol` +- `hdfs://my-bucket/path/my-hdfs-symbol` +- `/path-to/my-local-symbol` +Returns +--- +sym : Symbol +The loaded symbol. +See Also + +Symbol.save : Used to save symbol into file. +""" +if not isinstance(fname, string_types): Review comment: Reuse `mx.sym.load`? This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services
[GitHub] [incubator-mxnet] reminisce commented on a change in pull request #16621: [Numpy] Basic indexing in symbolic interface of DeepNumpy
reminisce commented on a change in pull request #16621: [Numpy] Basic indexing in symbolic interface of DeepNumpy URL: https://github.com/apache/incubator-mxnet/pull/16621#discussion_r350469321 ## File path: python/mxnet/symbol/numpy/_symbol.py ## @@ -855,6 +972,51 @@ def broadcast_to(self, *args, **kwargs): def broadcast_like(self, *args, **kwargs): raise AttributeError('_Symbol object has no attribute broadcast_like') +def save(self, fname, remove_amp_cast=True): +"""Saves symbol to a file. +You can also use pickle to do the job if you only work on python. +The advantage of `load`/`save` functions is that the file contents are language agnostic. +This means the model saved by one language binding can be loaded by a different +language binding of `MXNet`. +You also get the benefit of being able to directly load/save from cloud storage(S3, HDFS). +Parameters +-- +fname : str +The name of the file. +- "s3://my-bucket/path/my-s3-symbol" +- "hdfs://my-bucket/path/my-hdfs-symbol" +- "/path-to/my-local-symbol" +remove_amp_cast : bool, optional +Whether to remove the amp_cast and amp_multicast operators, before saving the model. +See Also + +symbol.load : Used to load symbol from file. +""" +if not isinstance(fname, string_types): +raise TypeError('fname need to be string') + +handle = self.handle +if remove_amp_cast: +handle = SymbolHandle() +check_call(_LIB.MXSymbolRemoveAmpCast(self.handle, ctypes.byref(handle))) + +processed_symbol = _Symbol(handle) +json_str = processed_symbol.save_json_string() +json_data = json.loads(json_str) +with open(fname, 'w') as file_out: +json.dump(json_data, file_out, indent=2, sort_keys=True) Review comment: If the attribute `is_output_list` is removed from the symbol, I think we can simply reuse the base class method? This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
For queries about this service, please contact Infrastructure at: us...@infra.apache.org

With regards,
Apache Git Services
[GitHub] [incubator-mxnet] reminisce commented on a change in pull request #16621: [Numpy] Basic indexing in symbolic interface of DeepNumpy
reminisce commented on a change in pull request #16621: [Numpy] Basic indexing in symbolic interface of DeepNumpy URL: https://github.com/apache/incubator-mxnet/pull/16621#discussion_r350468055 ## File path: python/mxnet/ndarray/ndarray.py ## @@ -3125,6 +3128,26 @@ def _get_dim_size(start, stop, step): return dim_size +def _get_slice_len_for(slc, seq_length): Review comment: nit: change `_get_slice_len_for` to `_get_slice_len`. This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services
[GitHub] [incubator-mxnet] reminisce commented on a change in pull request #16621: [Numpy] Basic indexing in symbolic interface of DeepNumpy
reminisce commented on a change in pull request #16621: [Numpy] Basic indexing in symbolic interface of DeepNumpy URL: https://github.com/apache/incubator-mxnet/pull/16621#discussion_r350470620 ## File path: python/mxnet/symbol/numpy/_symbol.py ## @@ -4957,4 +5119,62 @@ def where(condition, x, y): return _npi.where(condition, x, y, out=None) +@set_module('mxnet.symbol.numpy') +def load_json_string(json_str): +""" +Loads symbol from json string. + +Parameters +-- +json_str : str +A JSON string. + +Returns +--- +sym : Symbol +The loaded symbol. + +See Also + +_Symbol.tojson : Used to save symbol into json string. +""" +if not isinstance(json_str, string_types): Review comment: Reuse `mx.sym.load_json`? This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services
[GitHub] [incubator-mxnet] reminisce commented on a change in pull request #16621: [Numpy] Basic indexing in symbolic interface of DeepNumpy
reminisce commented on a change in pull request #16621: [Numpy] Basic indexing in symbolic interface of DeepNumpy URL: https://github.com/apache/incubator-mxnet/pull/16621#discussion_r350468294 ## File path: python/mxnet/ndarray/numpy/_op.py ## @@ -2774,10 +2773,10 @@ def split(ary, indices_or_sections, axis=0): elif isinstance(indices_or_sections, (list, set, tuple)): indices = [0] + list(indices_or_sections) else: -raise ValueError('indices_or_sections must either int, or tuple / list / set of ints') +raise ValueError('indices_or_sections must be either int, or tuple / list / set of ints') ret = _npi.split(ary, indices, axis, False) -if not isinstance(ret, list): -return [ret] +assert isinstance(ret, list), 'Output of split should be list,' \ + ' get a return type {}'.format(type(ret)) Review comment: nit: `get` -> `got`. This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services
[GitHub] [incubator-mxnet] reminisce commented on a change in pull request #16621: [Numpy] Basic indexing in symbolic interface of DeepNumpy
reminisce commented on a change in pull request #16621: [Numpy] Basic indexing in symbolic interface of DeepNumpy URL: https://github.com/apache/incubator-mxnet/pull/16621#discussion_r350469601 ## File path: python/mxnet/symbol/numpy/_symbol.py ## @@ -855,6 +972,51 @@ def broadcast_to(self, *args, **kwargs): def broadcast_like(self, *args, **kwargs): raise AttributeError('_Symbol object has no attribute broadcast_like') +def save(self, fname, remove_amp_cast=True): +"""Saves symbol to a file. +You can also use pickle to do the job if you only work on python. +The advantage of `load`/`save` functions is that the file contents are language agnostic. +This means the model saved by one language binding can be loaded by a different +language binding of `MXNet`. +You also get the benefit of being able to directly load/save from cloud storage(S3, HDFS). +Parameters +-- +fname : str +The name of the file. +- "s3://my-bucket/path/my-s3-symbol" +- "hdfs://my-bucket/path/my-hdfs-symbol" +- "/path-to/my-local-symbol" +remove_amp_cast : bool, optional +Whether to remove the amp_cast and amp_multicast operators, before saving the model. +See Also + +symbol.load : Used to load symbol from file. +""" +if not isinstance(fname, string_types): +raise TypeError('fname need to be string') + +handle = self.handle +if remove_amp_cast: +handle = SymbolHandle() +check_call(_LIB.MXSymbolRemoveAmpCast(self.handle, ctypes.byref(handle))) + +processed_symbol = _Symbol(handle) +json_str = processed_symbol.save_json_string() +json_data = json.loads(json_str) +with open(fname, 'w') as file_out: +json.dump(json_data, file_out, indent=2, sort_keys=True) + +def save_json_string(self): Review comment: Same here. Seems this is not needed anymore. This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
For queries about this service, please contact Infrastructure at: us...@infra.apache.org

With regards,
Apache Git Services