This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit b26abdc7783aeae118a3e8c96d75b14e7d2b2f51
Author: Hao Jin <h...@amazon.com>
AuthorDate: Wed Jun 12 16:17:47 2019 -0700

    fix for chapter6 conv nn (#15224)
---
 python/mxnet/gluon/data/dataloader.py        | 21 +++++++++++++++++----
 python/mxnet/gluon/data/vision/transforms.py |  1 +
 python/mxnet/gluon/nn/conv_layers.py         | 11 ++++++++++-
 python/mxnet/gluon/utils.py                  |  2 +-
 python/mxnet/numpy/multiarray.py             |  4 ++++
 5 files changed, 33 insertions(+), 6 deletions(-)

diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 1923f65..59b1582 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -38,7 +38,7 @@ except ImportError:
 
 from . import sampler as _sampler
 from ... import nd, context
-from ...util import is_np_array
+from ...util import is_np_shape, is_np_array, set_np
 from ... import numpy as _mx_np  # pylint: disable=reimported
 
 if sys.platform == 'darwin' or sys.platform == 'win32':
@@ -392,14 +392,21 @@ class DataLoaderV1(object):
     def __len__(self):
         return len(self._batch_sampler)
 
+
+def _thread_worker_initializer(active_shape, active_array):
+    """Initializer for ThreadPool."""
+    set_np(shape=active_shape, array=active_array)
+
+
 _worker_dataset = None
-def _worker_initializer(dataset):
+def _worker_initializer(dataset, active_shape, active_array):
     """Initialier for processing pool."""
     # global dataset is per-process based and only available in worker processes
     # this is only necessary to handle MXIndexedRecordIO because otherwise dataset
     # can be passed as argument
     global _worker_dataset
     _worker_dataset = dataset
+    set_np(shape=active_shape, array=active_array)
 
 def _worker_fn(samples, batchify_fn, dataset=None):
     """Function for processing data in worker process."""
@@ -463,6 +470,9 @@ class _MultiWorkerIter(object):
             batch = _as_in_context(batch, context.cpu_pinned(self._pin_device_id))
         batch = batch[0] if len(batch) == 1 else batch
         self._rcvd_idx += 1
+        if is_np_array():
+            new_batch = [member.as_np_ndarray() for member in batch]
+            batch = new_batch
         return batch
 
     def next(self):
@@ -566,10 +576,13 @@ class DataLoader(object):
         self._prefetch = max(0, int(prefetch) if prefetch is not None else 2 * self._num_workers)
         if self._num_workers > 0:
             if self._thread_pool:
-                self._worker_pool = ThreadPool(self._num_workers)
+                self._worker_pool = ThreadPool(self._num_workers,
+                                               initializer=_thread_worker_initializer,
+                                               initargs=(is_np_shape(), is_np_array()))
             else:
                 self._worker_pool = multiprocessing.Pool(
-                    self._num_workers, initializer=_worker_initializer, initargs=[self._dataset])
+                    self._num_workers, initializer=_worker_initializer,
+                    initargs=[self._dataset, is_np_shape(), is_np_array()])
         if batchify_fn is None:
             if num_workers > 0:
                 self._batchify_fn = default_mp_batchify_fn
diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py
index 2648997..54af87e 100644
--- a/python/mxnet/gluon/data/vision/transforms.py
+++ b/python/mxnet/gluon/data/vision/transforms.py
@@ -370,6 +370,7 @@ class Resize(HybridBlock):
         self._size = size
         self._interpolation = interpolation
 
+    @_adapt_np_array
     def hybrid_forward(self, F, x):
         return F.image.resize(x, self._size, self._keep, self._interpolation)
 
diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py
index 4122a08..3e8516b 100644
--- a/python/mxnet/gluon/nn/conv_layers.py
+++ b/python/mxnet/gluon/nn/conv_layers.py
@@ -30,6 +30,7 @@ from ..block import HybridBlock
 from ... import symbol
 from ...base import numeric_types
 from .activations import Activation
+from ...util import is_np_array
 
 
 def _infer_weight_shape(op_name, data_shape, kwargs):
@@ -109,7 +110,11 @@ class _Conv(HybridBlock):
             if adj is not None:
                 self._kwargs['adj'] = adj
 
-            dshape = [0]*(len(kernel_size) + 2)
+            if is_np_array():
+                dshape = [-1]*(len(kernel_size) + 2)
+            else:
+                dshape = [0]*(len(kernel_size) + 2)
+
             dshape[layout.find('N')] = 1
             dshape[layout.find('C')] = in_channels
             wshapes = _infer_weight_shape(op_name, dshape, self._kwargs)
@@ -129,6 +134,8 @@ class _Conv(HybridBlock):
                 self.act = None
 
     def hybrid_forward(self, F, x, weight, bias=None):
+        if is_np_array():
+            F = F.npx
         if bias is None:
             act = getattr(F, self._op_name)(x, weight, name='fwd', **self._kwargs)
         else:
@@ -693,6 +700,8 @@ class _Pooling(HybridBlock):
         return 'pool'
 
     def hybrid_forward(self, F, x):
+        if is_np_array():
+            F = F.npx
         return F.Pooling(x, name='fwd', **self._kwargs)
 
     def __repr__(self):
diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index 63dc1b2..bd69503 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -516,7 +516,7 @@ def _adapt_np_array(func):
         assert len(args) > 2, "expect at least three arguments in args"
         if is_np_array():
             input_args, kwargs = _to_classic_arrays(*args[2:], **kwargs)
-            input_args = list(args[0:2]) + input_args
+            input_args = list(args[0:2]) + list(input_args)
             out = func(*input_args, **kwargs)
             return _to_np_arrays(out)
         else:
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index a4a05af..409cbf4 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -111,6 +111,10 @@ class ndarray(NDArray):
                 out = out[idx]
             return out.reshape(()).as_np_ndarray()
         if isinstance(key, integer_types):
+            if key > self.shape[0] - 1:
+                raise IndexError(
+                    'index {} is out of bounds for axis 0 with size {}'.format(
+                        key, self.shape[0]))
             return self._at(key)
         if isinstance(key, ndarray):
             key = key._as_nd_ndarray()

Reply via email to