This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 27d5c69600c1ab90021bcc1e72a6c2ac51bed1cf
Author: reminisce <wujun....@gmail.com>
AuthorDate: Tue Jun 18 11:34:46 2019 -0700

    [numpy] Fix d2l chapter 5 (#15264)
    
    * Fix parameter initializer
    
    * Add np.save and np.load
    
    * Fix read-write
    
    * Fix lint
---
 python/mxnet/gluon/block.py                 |  11 ++-
 python/mxnet/gluon/parameter.py             |  44 ++++++----
 python/mxnet/initializer.py                 |  14 +++-
 python/mxnet/ndarray/utils.py               |   7 ++
 python/mxnet/numpy/__init__.py              |   1 +
 python/mxnet/numpy/multiarray.py            |   3 +-
 python/mxnet/numpy/utils.py                 | 122 ++++++++++++++++++++++++++++
 tests/python/unittest/test_numpy_ndarray.py |  46 ++++++++++-
 8 files changed, 224 insertions(+), 24 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 588d12c..7866cfb 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -37,6 +37,7 @@ from .utils import _indent, _brief_print_list, HookHandle
 from .utils import _check_same_symbol_type, _check_all_np_ndarrays
 from .. import numpy_extension as _mx_npx
 from .. import numpy as _mx_np
+from .. util import is_np_array
 
 
 class _BlockScope(object):
@@ -335,7 +336,10 @@ class Block(object):
         """
         params = self._collect_params_with_prefix()
         arg_dict = {key : val._reduce() for key, val in params.items()}
-        ndarray.save(filename, arg_dict)
+        if is_np_array():
+            _mx_np.save(filename, arg_dict)
+        else:
+            ndarray.save(filename, arg_dict)
 
     def save_params(self, filename):
         """[Deprecated] Please use save_parameters. Note that if you want load
@@ -384,7 +388,10 @@ class Block(object):
         `Saving and Loading Gluon Models \
         <https://mxnet.incubator.apache.org/tutorials/gluon/save_load_params.html>`_
         """
-        loaded = ndarray.load(filename)
+        if is_np_array():
+            loaded = _mx_np.load(filename)
+        else:
+            loaded = ndarray.load(filename)
         params = self._collect_params_with_prefix()
         if not loaded and not params:
             return
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 6d8e5c0..89a3c33 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -18,6 +18,8 @@
 # coding: utf-8
 # pylint: disable=unnecessary-pass, too-many-lines
 """Neural network parameter."""
+from __future__ import absolute_import
+
 __all__ = ['DeferredInitializationError', 'Parameter', 'Constant',
            'ParameterDict', 'tensor_types']
 
@@ -32,6 +34,7 @@ from ..context import Context, cpu
 from .. import autograd
 from .utils import _indent, _brief_print_list, shape_is_known
 from ..util import is_np_shape, is_np_array
+from .. import numpy as _mx_np  # pylint: disable=reimported
 
 # pylint: disable= invalid-name
 tensor_types = (symbol.Symbol, ndarray.NDArray)
@@ -190,9 +193,9 @@ class Parameter(object):
             return
 
         assert len(self._shape) == len(new_shape) and \
-            all(j in (0, i) for i, j in zip(new_shape, self._shape)), \
+            all(j in (-1, 0, i) for i, j in zip(new_shape, self._shape)), \
             "Expected shape %s is incompatible with given shape %s."%(
-                str(new_shape), str(self._shape))
+                str(new_shape), str(self._shape))  # -1 means unknown dim size in np_shape mode
 
         self._shape = new_shape
 
@@ -271,12 +274,14 @@ class Parameter(object):
         if cast_dtype:
             assert dtype_source in ['current', 'saved']
         if self.shape:
+            unknown_dim_size = -1 if is_np_shape() else 0
             for self_dim, data_dim in zip(self.shape, data.shape):
-                assert self_dim in (0, data_dim), \
+                assert self_dim in (unknown_dim_size, data_dim), \
                     "Failed loading Parameter '%s' from saved params: " \
                     "shape incompatible expected %s vs saved %s"%(
                         self.name, str(self.shape), str(data.shape))
-            self.shape = tuple(i if i != 0 else j for i, j in zip(self.shape, data.shape))
+            self.shape = tuple(i if i != unknown_dim_size else j
+                               for i, j in zip(self.shape, data.shape))
         if self.dtype:
             if cast_dtype and np.dtype(self.dtype).type != data.dtype:
                 if dtype_source == 'current':
@@ -326,13 +331,18 @@ class Parameter(object):
 
         with autograd.pause():
             if data is None:
-                data = ndarray.zeros(shape=self.shape, dtype=self.dtype,
-                                     ctx=context.cpu(), stype=self._stype)
+                kwargs = {'shape': self.shape, 'dtype': self.dtype, 'ctx': context.cpu()}
+                if is_np_array():
+                    if self._stype != 'default':
+                        raise ValueError("mxnet.numpy.zeros does not support stype = {}"
+                                         .format(self._stype))
+                    zeros_fn = _mx_np.zeros
+                else:
+                    kwargs['stype'] = self._stype
+                    zeros_fn = ndarray.zeros
+                data = zeros_fn(**kwargs)
                 initializer.create(default_init)(
                     initializer.InitDesc(self.name, {'__init__': init}), data)
-                # TODO(junwu): use np random operators when available
-                if is_np_array():
-                    data = data.as_np_ndarray()  # convert to np.ndarray
 
             self._init_impl(data, ctx)
 
@@ -355,11 +365,15 @@ class Parameter(object):
             self._grad = None
             return
 
-        self._grad = [ndarray.zeros(shape=i.shape, dtype=i.dtype, ctx=i.context,
-                                    stype=self._grad_stype) for i in self._data]
-        # TODO(junwu): use np.zeros
         if is_np_array():
-            self._grad = [arr.as_np_ndarray() for arr in self._grad]
+            if self._grad_stype != 'default':
+                raise ValueError("mxnet.numpy.zeros does not support stype = {}"
+                                 .format(self._grad_stype))
+            self._grad = [_mx_np.zeros(shape=i.shape, dtype=i.dtype, ctx=i.context)
+                          for i in self._data]
+        else:
+            self._grad = [ndarray.zeros(shape=i.shape, dtype=i.dtype, ctx=i.context,
+                                        stype=self._grad_stype) for i in self._data]
 
         autograd.mark_variables(self._check_and_get(self._data, list),
                                 self._grad, self.grad_req)
@@ -773,12 +787,12 @@ class ParameterDict(object):
                         inferred_shape = []
                         matched = True
                         for dim1, dim2 in zip(v, existing):
-                            if dim1 != dim2 and dim1 * dim2 != 0:
+                            if dim1 != dim2 and dim1 > 0 and dim2 > 0:
                                 matched = False
                                 break
                             elif dim1 == dim2:
                                 inferred_shape.append(dim1)
-                            elif dim1 == 0:
+                            elif dim1 in (0, -1):  # -1 means unknown dim size in np_shape mode
                                 inferred_shape.append(dim2)
                             else:
                                 inferred_shape.append(dim1)
diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py
index d028247..c5eef63 100755
--- a/python/mxnet/initializer.py
+++ b/python/mxnet/initializer.py
@@ -29,6 +29,8 @@ from .ndarray import NDArray, load
 from . import random
 from . import registry
 from . import ndarray
+from . util import is_np_array
+from . import numpy as _mx_np  # pylint: disable=reimported
 
 # inherit str for backward compatibility
 class InitDesc(str):
@@ -501,7 +503,8 @@ class Uniform(Initializer):
         self.scale = scale
 
     def _init_weight(self, _, arr):
-        random.uniform(-self.scale, self.scale, out=arr)
+        uniform_fn = _mx_np.random.uniform if is_np_array() else random.uniform
+        uniform_fn(-self.scale, self.scale, out=arr)
 
 @register
 class Normal(Initializer):
@@ -534,7 +537,8 @@ class Normal(Initializer):
         self.sigma = sigma
 
     def _init_weight(self, _, arr):
-        random.normal(0, self.sigma, out=arr)
+        normal_fn = _mx_np.random.normal if is_np_array() else random.normal
+        normal_fn(0, self.sigma, out=arr)
 
 @register
 class Orthogonal(Initializer):
@@ -633,9 +637,11 @@ class Xavier(Initializer):
             raise ValueError("Incorrect factor type")
         scale = np.sqrt(self.magnitude / factor)
         if self.rnd_type == "uniform":
-            random.uniform(-scale, scale, out=arr)
+            uniform_fn = _mx_np.random.uniform if is_np_array() else random.uniform
+            uniform_fn(-scale, scale, out=arr)
         elif self.rnd_type == "gaussian":
-            random.normal(0, scale, out=arr)
+            normal_fn = _mx_np.random.normal if is_np_array() else random.normal
+            normal_fn(0, scale, out=arr)
         else:
             raise ValueError("Unknown random type")
 
diff --git a/python/mxnet/ndarray/utils.py b/python/mxnet/ndarray/utils.py
index ff93d0b..730f217 100644
--- a/python/mxnet/ndarray/utils.py
+++ b/python/mxnet/ndarray/utils.py
@@ -248,6 +248,7 @@ def save(fname, data):
     >>> mx.nd.load('my_dict')
     {'y': <NDArray 1x4 @cpu(0)>, 'x': <NDArray 2x3 @cpu(0)>}
     """
+    from ..numpy import ndarray as np_ndarray
     if isinstance(data, NDArray):
         data = [data]
         handles = c_array(NDArrayHandle, [])
@@ -257,11 +258,17 @@ def save(fname, data):
         if any(not isinstance(k, string_types) for k in str_keys) or \
            any(not isinstance(v, NDArray) for v in nd_vals):
             raise TypeError('save only accept dict str->NDArray or list of NDArray')
+        if any(isinstance(v, np_ndarray) for v in nd_vals):
+            raise TypeError('cannot save mxnet.numpy.ndarray using mxnet.ndarray.save;'
+                            ' use mxnet.numpy.save instead.')
         keys = c_str_array(str_keys)
         handles = c_handle_array(nd_vals)
     elif isinstance(data, list):
         if any(not isinstance(v, NDArray) for v in data):
             raise TypeError('save only accept dict str->NDArray or list of NDArray')
+        if any(isinstance(v, np_ndarray) for v in data):
+            raise TypeError('cannot save mxnet.numpy.ndarray using mxnet.ndarray.save;'
+                            ' use mxnet.numpy.save instead.')
         keys = None
         handles = c_handle_array(data)
     else:
diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/__init__.py
index e1c9d90..266c2fa 100644
--- a/python/mxnet/numpy/__init__.py
+++ b/python/mxnet/numpy/__init__.py
@@ -24,5 +24,6 @@ from .multiarray import *  # pylint: disable=wildcard-import
 from . import _op
 from . import _register
 from ._op import *  # pylint: disable=wildcard-import
+from .utils import *  # pylint: disable=wildcard-import
 
 __all__ = []
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 3c981d1..52a2cf4 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -1285,8 +1285,7 @@ def array(object, dtype=None, ctx=None):
             try:
                 object = _np.array(object, dtype=dtype)
             except Exception as e:
-                print(e)
-                raise TypeError('source array must be an array like object')
+                raise TypeError('{}'.format(str(e)))
     ret = empty(object.shape, dtype=dtype, ctx=ctx)
     if len(object.shape) == 0:
         ret[()] = object
diff --git a/python/mxnet/numpy/utils.py b/python/mxnet/numpy/utils.py
new file mode 100644
index 0000000..48a47a3
--- /dev/null
+++ b/python/mxnet/numpy/utils.py
@@ -0,0 +1,122 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Util functions for the numpy module."""
+
+
+from __future__ import absolute_import
+
+import ctypes
+from .. util import is_np_array, is_np_shape
+from .. base import _LIB, check_call, string_types, c_str_array
+from .. base import c_handle_array, c_str, mx_uint, NDArrayHandle, py_str
+from . import ndarray
+
+__all__ = ['save', 'load']
+
+
+def save(file, arr):
+    """Saves a list of `ndarray`s or a dict of `str`->`ndarray` to file.
+
+    Examples of filenames:
+
+    - ``/path/to/file``
+    - ``s3://my-bucket/path/to/file`` (if compiled with AWS S3 supports)
+    - ``hdfs://path/to/file`` (if compiled with HDFS supports)
+
+    Parameters
+    ----------
+    file : str
+        Filename to which the data is saved.
+    arr : `ndarray` or list of `ndarray`s or dict of `str` to `ndarray`
+        The data to be saved.
+
+    Notes
+    -----
+    This function can only be called within numpy semantics, i.e., `npx.is_np_shape()`
+    and `npx.is_np_array()` must both return true.
+    """
+    if not (is_np_shape() and is_np_array()):
+        raise ValueError('Cannot save `mxnet.numpy.ndarray` in legacy mode. Please activate'
+                         ' numpy semantics by calling `npx.set_np()` in the global scope'
+                         ' before calling this function.')
+    if isinstance(arr, ndarray):
+        arr = [arr]
+    if isinstance(arr, dict):
+        str_keys = arr.keys()
+        nd_vals = arr.values()
+        if any(not isinstance(k, string_types) for k in str_keys) or \
+                any(not isinstance(v, ndarray) for v in nd_vals):
+            raise TypeError('Only accepts dict str->ndarray or list of ndarrays')
+        keys = c_str_array(str_keys)
+        handles = c_handle_array(nd_vals)
+    elif isinstance(arr, list):
+        if any(not isinstance(v, ndarray) for v in arr):
+            raise TypeError('Only accepts dict str->ndarray or list of ndarrays')
+        keys = None
+        handles = c_handle_array(arr)
+    else:
+        raise ValueError("data needs to either be a ndarray, dict of (str, ndarray) pairs "
+                         "or a list of ndarrays.")
+    check_call(_LIB.MXNDArraySave(c_str(file),
+                                  mx_uint(len(handles)),
+                                  handles,
+                                  keys))
+
+
+def load(file):
+    """Loads an array from file.
+
+    See more details in ``save``.
+
+    Parameters
+    ----------
+    file : str
+        The filename.
+
+    Returns
+    -------
+    result : list of ndarrays or dict of str -> ndarray
+        Data stored in the file.
+
+    Notes
+    -----
+    This function can only be called within numpy semantics, i.e., `npx.is_np_shape()`
+    and `npx.is_np_array()` must both return true.
+    """
+    if not (is_np_shape() and is_np_array()):
+        raise ValueError('Cannot load `mxnet.numpy.ndarray` in legacy mode. Please activate'
+                         ' numpy semantics by calling `npx.set_np()` in the global scope'
+                         ' before calling this function.')
+    if not isinstance(file, string_types):
+        raise TypeError('file required to be a string')
+    out_size = mx_uint()
+    out_name_size = mx_uint()
+    handles = ctypes.POINTER(NDArrayHandle)()
+    names = ctypes.POINTER(ctypes.c_char_p)()
+    check_call(_LIB.MXNDArrayLoad(c_str(file),
+                                  ctypes.byref(out_size),
+                                  ctypes.byref(handles),
+                                  ctypes.byref(out_name_size),
+                                  ctypes.byref(names)))
+    if out_name_size.value == 0:
+        return [ndarray(NDArrayHandle(handles[i])) for i in range(out_size.value)]
+    else:
+        assert out_name_size.value == out_size.value
+        return dict(
+            (py_str(names[i]), ndarray(NDArrayHandle(handles[i])))
+            for i in range(out_size.value))
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index 74b3d4d..0d8eacf 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -18,12 +18,13 @@
 # pylint: skip-file
 from __future__ import absolute_import
 from __future__ import division
+import os
 import numpy as _np
 import mxnet as mx
 from mxnet import np, npx, autograd
 from mxnet.gluon import HybridBlock
 from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, assert_exception
-from common import with_seed
+from common import with_seed, TemporaryDirectory
 
 
 @with_seed()
@@ -625,6 +626,49 @@ def test_np_ndarray_indexing():
             test_setitem_autograd(np_array, index)
 
 
+@with_seed()
+@npx.use_np
+def test_np_save_load_ndarrays():
+    shapes = [(2, 0, 1), (0,), (), (), (0, 4), (), (3, 0, 0, 0), (2, 1), (0, 5, 0), (4, 5, 6), (0, 0, 0)]
+    array_list = [_np.random.randint(0, 10, size=shape) for shape in shapes]
+    array_list = [np.array(arr, dtype=arr.dtype) for arr in array_list]
+    # test save/load single ndarray
+    for i, arr in enumerate(array_list):
+        with TemporaryDirectory() as work_dir:
+            fname = os.path.join(work_dir, 'dataset.npy')
+            np.save(fname, arr)
+            arr_loaded = np.load(fname)
+            assert isinstance(arr_loaded, list)
+            assert len(arr_loaded) == 1
+            assert _np.array_equal(arr_loaded[0].asnumpy(), array_list[i].asnumpy())
+
+    # test save/load a list of ndarrays
+    with TemporaryDirectory() as work_dir:
+        fname = os.path.join(work_dir, 'dataset.npy')
+        np.save(fname, array_list)
+        array_list_loaded = mx.nd.load(fname)
+        assert isinstance(arr_loaded, list)
+        assert len(array_list) == len(array_list_loaded)
+        assert all(isinstance(arr, np.ndarray) for arr in arr_loaded)
+        for a1, a2 in zip(array_list, array_list_loaded):
+            assert _np.array_equal(a1.asnumpy(), a2.asnumpy())
+
+    # test save/load a dict of str->ndarray
+    arr_dict = {}
+    keys = [str(i) for i in range(len(array_list))]
+    for k, v in zip(keys, array_list):
+        arr_dict[k] = v
+    with TemporaryDirectory() as work_dir:
+        fname = os.path.join(work_dir, 'dataset.npy')
+        np.save(fname, arr_dict)
+        arr_dict_loaded = np.load(fname)
+        assert isinstance(arr_dict_loaded, dict)
+        assert len(arr_dict_loaded) == len(arr_dict)
+        for k, v in arr_dict_loaded.items():
+            assert k in arr_dict
+            assert _np.array_equal(v.asnumpy(), arr_dict[k].asnumpy())
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

Reply via email to