This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 860adfd84e0523ce8149f18298725189ba1ee162
Author: reminisce <wujun....@gmail.com>
AuthorDate: Mon Jun 17 00:24:58 2019 -0700

    [numpy] [DO NOT MERGE] Fix d2l chapters 9 and 13 (#15246)
    
    * Add npx batch_dot and topk
    
    * Text embedding uses numpy
    
    * Fix SoftmaxCrossEntropyLoss with np
    
    * Fix sentiment cnn
    
    * Fix pylint
    
    * Fix dot attention
    
    * Fix seq2seq attention
    
    * Add np.tile
    
    * Fix transformer
    
    * Fix ci
    
    * Fix ci and rebase
---
 python/mxnet/_numpy_op_doc.py                    | 23 +++++++++++
 python/mxnet/contrib/text/embedding.py           | 29 ++++++++++----
 python/mxnet/gluon/block.py                      |  2 +-
 python/mxnet/gluon/loss.py                       | 21 +++++++---
 python/mxnet/gluon/nn/basic_layers.py            | 16 ++++----
 python/mxnet/gluon/parameter.py                  |  5 ++-
 python/mxnet/gluon/utils.py                      | 35 +++++++++++------
 python/mxnet/ndarray/numpy/_op.py                | 38 +++++++++++++++++-
 python/mxnet/numpy/multiarray.py                 | 50 ++++++++++++++++++++----
 python/mxnet/symbol/numpy/_symbol.py             | 48 +++++++++++++++++++----
 src/operator/nn/dropout.cc                       |  1 +
 src/operator/nn/layer_norm.cc                    |  1 +
 src/operator/nn/softmax.cc                       |  2 +
 src/operator/tensor/broadcast_reduce_op_index.cc |  1 +
 src/operator/tensor/dot.cc                       |  1 +
 src/operator/tensor/matrix_op-inl.h              | 12 +++---
 src/operator/tensor/matrix_op.cc                 |  2 +
 src/operator/tensor/ordering_op.cc               |  1 +
 tests/python/unittest/test_numpy_op.py           | 41 +++++++++++++++++++
 19 files changed, 273 insertions(+), 56 deletions(-)

diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py
index 17f92ce..9265a98 100644
--- a/python/mxnet/_numpy_op_doc.py
+++ b/python/mxnet/_numpy_op_doc.py
@@ -86,3 +86,26 @@ def _np_zeros_like(a):
         Array of zeros with the same shape and type as `a`.
     """
     pass
+
+
+def _np_repeat(a, repeats, axis=None):
+    """Repeat elements of an array.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array.
+    repeats : int or array of ints
+        The number of repetitions for each element.  `repeats` is broadcasted
+        to fit the shape of the given axis.
+    axis : int, optional
+        The axis along which to repeat values.  By default, use the
+        flattened input array, and return a flat output array.
+
+    Returns
+    -------
+    repeated_array : ndarray
+        Output array which has the same shape as `a`, except along
+        the given axis.
+    """
+    pass
diff --git a/python/mxnet/contrib/text/embedding.py b/python/mxnet/contrib/text/embedding.py
index 9d529db..da20fbe 100644
--- a/python/mxnet/contrib/text/embedding.py
+++ b/python/mxnet/contrib/text/embedding.py
@@ -35,6 +35,9 @@ from . import vocab
 from ... import ndarray as nd
 from ... import registry
 from ... import base
+from ...util import is_np_array
+from ... import numpy as _mx_np
+from ... import numpy_extension as _mx_npx
 
 
 def register(embedding_cls):
@@ -295,12 +298,15 @@ class _TokenEmbedding(vocab.Vocabulary):
                     tokens.add(token)
 
         self._vec_len = vec_len
-        self._idx_to_vec = nd.array(all_elems).reshape((-1, self.vec_len))
+        array_fn = _mx_np.array if is_np_array() else nd.array
+        self._idx_to_vec = array_fn(all_elems).reshape((-1, self.vec_len))
 
         if loaded_unknown_vec is None:
-            self._idx_to_vec[C.UNKNOWN_IDX] = init_unknown_vec(shape=self.vec_len)
+            init_val = init_unknown_vec(shape=self.vec_len)
+            self._idx_to_vec[C.UNKNOWN_IDX] =\
+                init_val.as_np_ndarray() if is_np_array() else init_val
         else:
-            self._idx_to_vec[C.UNKNOWN_IDX] = nd.array(loaded_unknown_vec)
+            self._idx_to_vec[C.UNKNOWN_IDX] = array_fn(loaded_unknown_vec)
 
     def _index_tokens_from_vocabulary(self, vocabulary):
         self._token_to_idx = vocabulary.token_to_idx.copy() \
@@ -328,7 +334,8 @@ class _TokenEmbedding(vocab.Vocabulary):
         """
 
         new_vec_len = sum(embed.vec_len for embed in token_embeddings)
-        new_idx_to_vec = nd.zeros(shape=(vocab_len, new_vec_len))
+        zeros_fn = _mx_np.zeros if is_np_array() else nd.zeros
+        new_idx_to_vec = zeros_fn(shape=(vocab_len, new_vec_len))
 
         col_start = 0
         # Concatenate all the embedding vectors in token_embeddings.
@@ -397,7 +404,13 @@ class _TokenEmbedding(vocab.Vocabulary):
                        else self.token_to_idx.get(token.lower(), C.UNKNOWN_IDX)
                        for token in tokens]
 
-        vecs = nd.Embedding(nd.array(indices), self.idx_to_vec, self.idx_to_vec.shape[0],
+        if is_np_array():
+            embedding_fn = _mx_npx.Embedding
+            array_fn = _mx_np.array
+        else:
+            embedding_fn = nd.Embedding
+            array_fn = nd.array
+        vecs = embedding_fn(array_fn(indices), self.idx_to_vec, self.idx_to_vec.shape[0],
                             self.idx_to_vec.shape[1])
 
         return vecs[0] if to_reduce else vecs
@@ -425,7 +438,8 @@ class _TokenEmbedding(vocab.Vocabulary):
             if not isinstance(tokens, list):
                 tokens = [tokens]
             if len(new_vectors.shape) == 1:
-                new_vectors = new_vectors.expand_dims(0)
+                expand_dims_fn = _mx_np.expand_dims if is_np_array() else nd.expand_dims
+                new_vectors = expand_dims_fn(new_vectors, axis=0)
 
         else:
             assert isinstance(new_vectors, nd.NDArray) and len(new_vectors.shape) == 2, \
@@ -444,7 +458,8 @@ class _TokenEmbedding(vocab.Vocabulary):
                                  '`unknown_token` %s in `tokens`. This is to avoid unintended '
                                  'updates.' % (token, self.idx_to_token[C.UNKNOWN_IDX]))
 
-        self._idx_to_vec[nd.array(indices)] = new_vectors
+        array_fn = _mx_np.array if is_np_array() else nd.array
+        self._idx_to_vec[array_fn(indices)] = new_vectors
 
     @classmethod
     def _check_pretrained_file_names(cls, pretrained_file_name):
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 6e92469..4da9e90 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -553,7 +553,7 @@ class Block(object):
         for hook in self._forward_hooks.values():
             hook(self, args, out)
         if _mx_npx.is_np_array():
-            _check_all_np_ndarrays(_flatten(out, "output")[0])
+            _check_all_np_ndarrays(out)
         return out
 
     def forward(self, *args):
diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py
index 79a5981..6c66d4c 100644
--- a/python/mxnet/gluon/loss.py
+++ b/python/mxnet/gluon/loss.py
@@ -357,17 +357,28 @@ class SoftmaxCrossEntropyLoss(Loss):
         self._sparse_label = sparse_label
         self._from_logits = from_logits
 
-    @_adapt_np_array
     def hybrid_forward(self, F, pred, label, sample_weight=None):
+        if is_np_array():
+            log_softmax = F.npx.log_softmax
+            pick = F.npx.pick
+        else:
+            log_softmax = F.log_softmax
+            pick = F.pick
         if not self._from_logits:
-            pred = F.log_softmax(pred, self._axis)
+            pred = log_softmax(pred, self._axis)
         if self._sparse_label:
-            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
+            loss = -pick(pred, label, axis=self._axis, keepdims=True)
         else:
             label = _reshape_like(F, label, pred)
-            loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
+            loss = -(pred * label).sum(axis=self._axis, keepdims=True)
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
-        return F.mean(loss, axis=self._batch_axis, exclude=True)
+        if is_np_array():
+            if F is ndarray:
+                return loss.mean(axis=tuple(range(1, loss.ndim)))
+            else:
+                return F.npx.batch_flatten(loss).mean(axis=1)
+        else:
+            return loss.mean(axis=self._batch_axis, exclude=True)
 
 
 SoftmaxCELoss = SoftmaxCrossEntropyLoss
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index eb4d55b..3c43ac3 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -264,12 +264,13 @@ class Dropout(HybridBlock):
         self._rate = rate
         self._axes = axes
 
-    @_adapt_np_array
     def hybrid_forward(self, F, x):
         if self._rate > 0:
-            return F.Dropout(x, p=self._rate, axes=self._axes, name='fwd', cudnn_off=False)
+            dropout = F.npx.Dropout if is_np_array() else F.Dropout
+            return dropout(x, p=self._rate, axes=self._axes, name='fwd', cudnn_off=False)
         else:
-            return F.identity(x)
+            copy = F.np.copy if is_np_array() else F.identity
+            return copy(x)
 
     def __repr__(self):
         s = '{name}(p = {_rate}, axes={_axes})'
@@ -359,8 +360,9 @@ class BatchNorm(HybridBlock):
             dtype = 'float32'
         super(BatchNorm, self).cast(dtype)
 
-    @_adapt_np_array
     def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
+        if is_np_array():
+            F = F.npx
         return F.BatchNorm(x, gamma, beta, running_mean, running_var,
                            name='fwd', **self._kwargs)
 
@@ -611,10 +613,10 @@ class LayerNorm(HybridBlock):
                                     shape=(in_channels,), init=beta_initializer,
                                     allow_deferred_init=True)
 
-    @_adapt_np_array
     def hybrid_forward(self, F, data, gamma, beta):
-        norm_data = F.LayerNorm(data, gamma=gamma, beta=beta, axis=self._axis, eps=self._epsilon)
-        return norm_data
+        if is_np_array():
+            F = F.npx
+        return F.LayerNorm(data, gamma=gamma, beta=beta, axis=self._axis, eps=self._epsilon)
 
     def __repr__(self):
         s = '{name}({content}'
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 0797b4c..6d8e5c0 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -369,7 +369,10 @@ class Parameter(object):
         ctx = context.cpu()
         if self._stype == 'default':
             block = self.list_data()
-            data = ndarray.add_n(*(w.copyto(ctx).as_nd_ndarray() for w in block)) / len(block)
+            if is_np_array():
+                data = sum([w.copyto(ctx) for w in block]) / len(block)
+            else:
+                data = ndarray.add_n(*(w.copyto(ctx) for w in block)) / len(block)
         else:
             # fetch all rows for 'row_sparse' param
             all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=ctx)
diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index ef26989..542a3c6 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -18,6 +18,8 @@
 # coding: utf-8
 # pylint: disable=
 """Parallelization utility optimizer."""
+from __future__ import absolute_import
+
 __all__ = ['split_data', 'split_and_load', 'clip_global_norm',
            'check_sha1', 'download']
 
@@ -39,6 +41,7 @@ import numpy as np
 
 from .. import ndarray
 from ..util import is_np_shape, is_np_array, wraps_safely
+from .. import numpy as _mx_np  # pylint: disable=reimported
 
 
 def split_data(data, num_slice, batch_axis=0, even_split=True):
@@ -112,15 +115,14 @@ def split_and_load(data, ctx_list, batch_axis=0, even_split=True):
     list of NDArray
         Each corresponds to a context in `ctx_list`.
     """
-    # TODO(junwu): temp solution for supporting np.ndarray
-    # rewrite this using np ops
+    array_fn = _mx_np.array if is_np_array() else ndarray.array
     if not isinstance(data, ndarray.NDArray):
-        data = ndarray.array(data, ctx=ctx_list[0])
+        data = array_fn(data, ctx=ctx_list[0])
     if len(ctx_list) == 1:
-        if is_np_array():
-            data = data.as_np_ndarray()
         return [data.as_in_context(ctx_list[0])]
 
+    # TODO(junwu): temp solution for supporting np.ndarray
+    # rewrite this using np ops
     slices = split_data(data, len(ctx_list), batch_axis, even_split)
     if is_np_array():
         slices = [i.as_np_ndarray() for i in slices]
@@ -445,7 +447,7 @@ def _check_same_symbol_type(symbols):
     Raise type error if the types are different. Return the class of
     the symbols."""
     from ..symbol.numpy import _Symbol as np_symbol
-    from ..symbol import Symbol as classic_symbol
+    from ..symbol import Symbol as nd_symbol
     is_np_sym = bool(isinstance(symbols[0], np_symbol))
     for s in symbols[1:]:
         if is_np_sym != isinstance(s, np_symbol):
@@ -460,18 +462,25 @@ def _check_same_symbol_type(symbols):
                             'on each of them; if you want classic ndarray output(s) from the '
                             'computation graph, please convert all the numpy symbols in the list '
                             'to classic symbols by calling `as_nd_ndarray()` on each of them.')
-    return np_symbol if is_np_sym else classic_symbol
+    return np_symbol if is_np_sym else nd_symbol
 
 
 def _check_all_np_ndarrays(out):
-    """Check if ndarrays in out are all np.ndarray"""
+    """Check if ndarrays/symbols in out are all np.ndarray/np._Symbol."""
     from ..numpy import ndarray as np_ndarray
     from ..symbol.numpy import _Symbol as np_symbol
-    assert isinstance(out, (list, tuple))
-    for array in out:
-        if not isinstance(array, (np_ndarray, np_symbol)):
-            raise TypeError('Expected np.ndarray or np._Symbol type in output, while received type '
-                            '{}'.format(str(type(array))))
+    from ..symbol import Symbol as nd_symbol
+    from ..ndarray import NDArray as nd_ndarray
+
+    # pylint: disable=no-else-raise
+    if isinstance(out, (nd_ndarray, nd_symbol)) and not isinstance(out, (np_ndarray, np_symbol)):
+        raise TypeError("Block's output ndarrays/symbols must be of type `mxnet.numpy.ndarray`"
+                        " or `mxnet.symbol.numpy._Symbol`, while got output type {}"
+                        .format(str(type(out))))
+    elif isinstance(out, (list, tuple)):
+        for i in out:
+            _check_all_np_ndarrays(i)
+    # pylint: enable=no-else-raise
 
 
 def _to_classic_arrays(*args, **kwargs):
diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 087b99e..04de2cd 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -26,7 +26,7 @@ from . import _internal as _npi
 
 __all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax',
            'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
-           'clip', 'split', 'swapaxes', 'expand_dims']
+           'clip', 'split', 'swapaxes', 'expand_dims', 'tile']
 
 
 @set_module('mxnet.ndarray.numpy')
@@ -593,3 +593,39 @@ def split(ary, indices_or_sections, axis=0):
     if not isinstance(ret, list):
         raise NotImplementedError('single output from split is not supported yet...')
     return ret
+
+
+@set_module('mxnet.ndarray.numpy')
+def tile(A, reps):
+    """
+    Construct an array by repeating A the number of times given by reps.
+
+    If `reps` has length ``d``, the result will have dimension of
+    ``max(d, A.ndim)``.
+
+    If ``A.ndim < d``, `A` is promoted to be d-dimensional by prepending new
+    axes. So a shape (3,) array is promoted to (1, 3) for 2-D replication,
+    or shape (1, 1, 3) for 3-D replication. If this is not the desired
+    behavior, promote `A` to d-dimensions manually before calling this
+    function.
+
+    If ``A.ndim > d``, `reps` is promoted to `A`.ndim by pre-pending 1's to it.
+    Thus for an `A` of shape (2, 3, 4, 5), a `reps` of (2, 2) is treated as
+    (1, 1, 2, 2).
+
+    Note : Although tile may be used for broadcasting, it is strongly
+    recommended to use numpy's broadcasting operations and functions.
+
+    Parameters
+    ----------
+    A : ndarray
+        The input array.
+    reps : tuple of integers
+        The number of repetitions of `A` along each axis.
+
+    Returns
+    -------
+    c : ndarray
+        The tiled output array.
+    """
+    return _npi.tile(A, reps)
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 3cf3a44..3c981d1 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -45,7 +45,7 @@ from ..ndarray.numpy import _internal as _npi
 
 __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange',
            'argmax', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
-           'clip', 'split', 'swapaxes', 'expand_dims']
+           'clip', 'split', 'swapaxes', 'expand_dims', 'tile']
 
 
 # This function is copied from ndarray.py since pylint
@@ -340,6 +340,8 @@ class ndarray(NDArray):
         else:
             raise ValueError("The truth value of an ndarray with multiple elements is ambiguous.")
 
+    __nonzero__ = __bool__
+
     def __float__(self):
         num_elements = self.size
         if num_elements != 1:
@@ -607,13 +609,9 @@ class ndarray(NDArray):
         """
         raise AttributeError('mxnet.numpy.ndarray object has no attribute broadcast_like')
 
-    def repeat(self, *args, **kwargs):
-        """Convenience fluent method for :py:func:`repeat`.
-
-        The arguments are the same as for :py:func:`repeat`, with
-        this array as data.
-        """
-        raise NotImplementedError
+    def repeat(self, repeats, axis=None):  # pylint: disable=arguments-differ
+        """Repeat elements of an array."""
+        return _mx_np_op.repeat(self, repeats=repeats, axis=axis)
 
     def pad(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`pad`.
@@ -1757,3 +1755,39 @@ def split(ary, indices_or_sections, axis=0):
         If `indices_or_sections` is given as an integer, but
         a split does not result in equal division."""
     return _mx_nd_np.split(ary, indices_or_sections, axis=axis)
+
+
+@set_module('mxnet.numpy')
+def tile(A, reps):
+    """
+    Construct an array by repeating A the number of times given by reps.
+
+    If `reps` has length ``d``, the result will have dimension of
+    ``max(d, A.ndim)``.
+
+    If ``A.ndim < d``, `A` is promoted to be d-dimensional by prepending new
+    axes. So a shape (3,) array is promoted to (1, 3) for 2-D replication,
+    or shape (1, 1, 3) for 3-D replication. If this is not the desired
+    behavior, promote `A` to d-dimensions manually before calling this
+    function.
+
+    If ``A.ndim > d``, `reps` is promoted to `A`.ndim by pre-pending 1's to it.
+    Thus for an `A` of shape (2, 3, 4, 5), a `reps` of (2, 2) is treated as
+    (1, 1, 2, 2).
+
+    Note : Although tile may be used for broadcasting, it is strongly
+    recommended to use numpy's broadcasting operations and functions.
+
+    Parameters
+    ----------
+    A : ndarray
+        The input array.
+    reps : tuple of integers
+        The number of repetitions of `A` along each axis.
+
+    Returns
+    -------
+    c : ndarray
+        The tiled output array.
+    """
+    return _npi.tile(A, reps)
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index a3b9038..11a1da8 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -31,7 +31,7 @@ from . import _internal as _npi
 
 __all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'concatenate', 'arange', 'argmax',
            'clip', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'split', 'swapaxes',
-           'expand_dims']
+           'expand_dims', 'tile']
 
 
 def _num_outputs(sym):
@@ -257,13 +257,9 @@ class _Symbol(Symbol):
         """
         raise AttributeError('_Symbol object has no attribute broadcast_like')
 
-    def repeat(self, *args, **kwargs):
-        """Convenience fluent method for :py:func:`repeat`.
-
-        The arguments are the same as for :py:func:`repeat`, with
-        this array as data.
-        """
-        raise NotImplementedError
+    def repeat(self, repeats, axis=None):  # pylint: disable=arguments-differ
+        """Repeat elements of an array."""
+        return _mx_np_op.repeat(self, repeats=repeats, axis=axis)
 
     def pad(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`pad`.
@@ -1275,4 +1271,40 @@ def split(ary, indices_or_sections, axis=0):
     return ret
 
 
+@set_module('mxnet.symbol.numpy')
+def tile(A, reps):
+    """
+    Construct an array by repeating A the number of times given by reps.
+
+    If `reps` has length ``d``, the result will have dimension of
+    ``max(d, A.ndim)``.
+
+    If ``A.ndim < d``, `A` is promoted to be d-dimensional by prepending new
+    axes. So a shape (3,) array is promoted to (1, 3) for 2-D replication,
+    or shape (1, 1, 3) for 3-D replication. If this is not the desired
+    behavior, promote `A` to d-dimensions manually before calling this
+    function.
+
+    If ``A.ndim > d``, `reps` is promoted to `A`.ndim by pre-pending 1's to it.
+    Thus for an `A` of shape (2, 3, 4, 5), a `reps` of (2, 2) is treated as
+    (1, 1, 2, 2).
+
+    Note : Although tile may be used for broadcasting, it is strongly
+    recommended to use numpy's broadcasting operations and functions.
+
+    Parameters
+    ----------
+    A : _Symbol
+        The input array.
+    reps : tuple of integers
+        The number of repetitions of `A` along each axis.
+
+    Returns
+    -------
+    c : _Symbol
+        The tiled output array.
+    """
+    return _npi.tile(A, reps)
+
+
 _set_np_symbol_class(_Symbol)
diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc
index 63da561..72ba422 100644
--- a/src/operator/nn/dropout.cc
+++ b/src/operator/nn/dropout.cc
@@ -65,6 +65,7 @@ struct DropoutGrad {
 DMLC_REGISTER_PARAMETER(DropoutParam);
 
 NNVM_REGISTER_OP(Dropout)
+.add_alias("_npx_Dropout")
 .describe(R"(Applies dropout operation to input array.
 
 - During training, each element of the input is set to zero with probability p.
diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc
index e95f472..7c6ddcb 100644
--- a/src/operator/nn/layer_norm.cc
+++ b/src/operator/nn/layer_norm.cc
@@ -127,6 +127,7 @@ void LayerNormGradCompute<cpu>(const nnvm::NodeAttrs& attrs,
 }
 
 NNVM_REGISTER_OP(LayerNorm)
+.add_alias("_npx_LayerNorm")
 .describe(R"code(Layer normalization.
 
 Normalizes the channels of the input tensor by mean and variance, and applies a scale ``gamma`` as
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index e44bbbb..8f1b2e0 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -68,6 +68,7 @@ inline static bool SoftmaxStorageType(const nnvm::NodeAttrs& attrs,
 #endif
 
 NNVM_REGISTER_OP(softmax)
+.add_alias("_npx_softmax")
 .describe(R"code(Applies the softmax function.
 
 The resulting array contains elements in the range (0,1) and the elements along the given axis sum up to 1.
@@ -182,6 +183,7 @@ NNVM_REGISTER_OP(_backward_softmin)
                                                         mxnet_op::softmax_bwd, true>);
 
 NNVM_REGISTER_OP(log_softmax)
+.add_alias("_npx_log_softmax")
 .describe(R"code(Computes the log softmax of the input.
 This is equivalent to computing softmax followed by log.
 
diff --git a/src/operator/tensor/broadcast_reduce_op_index.cc b/src/operator/tensor/broadcast_reduce_op_index.cc
index 56af388..52082f7 100644
--- a/src/operator/tensor/broadcast_reduce_op_index.cc
+++ b/src/operator/tensor/broadcast_reduce_op_index.cc
@@ -110,6 +110,7 @@ Examples::
 
 NNVM_REGISTER_OP(pick)
 .add_alias("choose_element_0index")
+.add_alias("_npx_pick")
 .describe(R"code(Picks elements from an input array according to the input indices along the given axis.
 
 Given an input array of shape ``(d0, d1)`` and indices of shape ``(i0,)``, the result will be
diff --git a/src/operator/tensor/dot.cc b/src/operator/tensor/dot.cc
index 7d7b6c0..11a0561 100644
--- a/src/operator/tensor/dot.cc
+++ b/src/operator/tensor/dot.cc
@@ -111,6 +111,7 @@ NNVM_REGISTER_OP(_backward_dot)
 .add_arguments(DotParam::__FIELDS__());
 
 NNVM_REGISTER_OP(batch_dot)
+.add_alias("_npx_batch_dot")
 .describe(R"doc(Batchwise dot product.
 
 ``batch_dot`` is used to compute dot product of ``x`` and ``y`` when ``x`` and
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index c547eb4..aa6e7bb 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -1787,9 +1787,6 @@ inline bool TileOpShape(const nnvm::NodeAttrs& attrs,
     SHAPE_ASSIGN_CHECK(*out_attrs, 0, ishape);
     return true;
   }
-  for (int i = 0; i < reps.ndim(); ++i) {
-    CHECK_GT(reps[i], 0) << "invalid reps=" << i << ", dim size must be greater than zero";
-  }
   mxnet::TShape oshape(std::max(ishape.ndim(), reps.ndim()), -1);
   int i1 = ishape.ndim() - 1;
   int i2 = reps.ndim() - 1;
@@ -1802,6 +1799,11 @@ inline bool TileOpShape(const nnvm::NodeAttrs& attrs,
       oshape[i] = reps[i2--];
     }
   }
+  // If reps contains 0s, oshape is a zero-size shape.
+  // Need to distinguish between np_shape mode and legacy mode.
+  if (!Imperative::Get()->is_np_shape()) {
+    common::ConvertToNumpyShape(&oshape);
+  }
   SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
   return shape_is_known(oshape);
 }
@@ -1820,7 +1822,7 @@ inline bool TileOpType(const nnvm::NodeAttrs& attrs,
 
 /*!
  * \brief Reshape the input and output tensors for
- * using broadcast_to to achieve the funcitonality
+ * using broadcast_to to achieve the functionality
  * of operator tile.
  * \return a pair of mxnet::TShape's, first is the reshaped
  * input shape, second is the reshaped output shape.
@@ -1828,7 +1830,7 @@ inline bool TileOpType(const nnvm::NodeAttrs& attrs,
 inline std::pair<mxnet::TShape, mxnet::TShape> ReshapeInputOutputForTileOp(
   const mxnet::TShape& ishape,
   const mxnet::Tuple<int>& reps) {
-  if (ishape.ndim() == 0 || reps.ndim() == 0) {
+  if (reps.ndim() == 0) {
     return std::make_pair(ishape, ishape);
   }
 
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index 8743175..59e8386 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -773,6 +773,7 @@ NNVM_REGISTER_OP(_backward_clip)
 .set_attr<FCompute>("FCompute<cpu>", ClipGrad_<cpu>);
 
 NNVM_REGISTER_OP(repeat)
+.add_alias("_np_repeat")
 .describe(R"code(Repeats elements of an array.
 
 By default, ``repeat`` flattens the input array into 1-D and then repeats the
@@ -823,6 +824,7 @@ NNVM_REGISTER_OP(_backward_repeat)
 });
 
 NNVM_REGISTER_OP(tile)
+.add_alias("_npi_tile")
 .describe(R"code(Repeats the whole array multiple times.
 
 If ``reps`` has length *d*, and input array has dimension of *n*. There are
diff --git a/src/operator/tensor/ordering_op.cc b/src/operator/tensor/ordering_op.cc
index b0ade20..58c98f3 100644
--- a/src/operator/tensor/ordering_op.cc
+++ b/src/operator/tensor/ordering_op.cc
@@ -34,6 +34,7 @@ DMLC_REGISTER_PARAMETER(SortParam);
 DMLC_REGISTER_PARAMETER(ArgSortParam);
 
 NNVM_REGISTER_OP(topk)
+.add_alias("_npx_topk")
 .describe(R"code(Returns the top *k* elements in an input array along the given axis.
  The returned elements will be sorted.
 
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 1243c8a..862c4d4 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -817,6 +817,47 @@ def test_np_split():
                     assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
 
 
+@with_seed()
+@npx.use_np_shape
+def test_np_tile():
+    config = [
+        ((), ()),
+        ((), 0),
+        ((), (2, 0)),
+        ((), (2, 3)),
+        ((4, 2), (2,)),
+        ((4, 2), (2, 3)),
+        ((4, 2), (2, 1, 4)),
+        ((4, 2), (2, 3, 4)),
+        ((4, 2), (2, 0)),
+        ((4, 2), (2, 0, 3)),
+        ((4, 2), (2, 0, 3)),
+        ((4, 0), (2, 0, 3)),
+    ]
+
+    class TestTile(HybridBlock):
+        def __init__(self, reps):
+            super(TestTile, self).__init__()
+            self._reps = reps
+
+        def hybrid_forward(self, F, x):
+            return F.np.tile(x, reps=self._reps)
+
+    for shape, reps in config:
+        data_np = _np.random.uniform(size=shape)
+        data_mx = np.array(data_np, dtype=data_np.dtype)
+        ret_np = _np.tile(data_np, reps=reps)
+        ret_mx = np.tile(data_mx, reps=reps)
+        assert same(ret_mx.asnumpy(), ret_np)
+
+        net = TestTile(reps)
+        for hybrid in [False, True]:
+            if hybrid:
+                net.hybridize()
+            ret_mx = net(data_mx)
+            assert same(ret_mx.asnumpy(), ret_np)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

Reply via email to