This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 0565405699403698573224b4d8fc01700cae87e5
Author: reminisce <wujun....@gmail.com>
AuthorDate: Sun Jun 23 14:16:31 2019 -0700

    [numpy] Misc fix for other chapters (#15332)
    
    * Add np.prod
    
    * Fix ndarray.reshape accepting positional integers as arguments
    
    * Rebase
    
    * Fix rebase error
    
    * Add np.ndarray.flatten
    
    * Fix
    
    * Add broadcast_to
    
    * Add meshgrid and broadcast_arrays
    
    * Fix sin, cos, sinh, cosh not supporting scalars
    
    * Add more unary ops supporting python scalars
    
    * Fix
    
    * Fix
    
    * Fix ci
    
    * Fix sanity
---
 python/mxnet/_numpy_op_doc.py                      |  34 +++
 python/mxnet/gluon/block.py                        |  13 +-
 python/mxnet/gluon/data/vision/datasets.py         |   2 +
 python/mxnet/ndarray/ndarray.py                    |   2 +-
 python/mxnet/ndarray/numpy/_op.py                  | 220 +++++++++++++++--
 python/mxnet/ndarray/register.py                   |  20 +-
 python/mxnet/numpy/__init__.py                     |   8 +-
 python/mxnet/numpy/function_base.py                | 115 +++++++++
 .../{numpy_extension/__init__.py => numpy/io.py}   |  36 ++-
 python/mxnet/numpy/multiarray.py                   | 275 +++++++++++++++++----
 python/mxnet/numpy/stride_tricks.py                |  56 +++++
 python/mxnet/numpy/utils.py                        | 107 +-------
 python/mxnet/numpy_extension/__init__.py           |   1 +
 python/mxnet/{numpy => numpy_extension}/utils.py   |   2 +-
 python/mxnet/symbol/numpy/_symbol.py               | 240 ++++++++++++++++--
 python/mxnet/symbol/numpy/linalg.py                |   5 +-
 python/mxnet/symbol/register.py                    |   8 +-
 src/operator/numpy/np_broadcast_reduce_op.h        |  67 ++++-
 src/operator/numpy/np_broadcast_reduce_op_value.cc |  75 +++++-
 src/operator/numpy/np_broadcast_reduce_op_value.cu |  12 +
 src/operator/numpy/np_elemwise_unary_op_basic.cc   |  12 +-
 src/operator/numpy/np_elemwise_unary_op_basic.cu   |  12 +-
 src/operator/tensor/broadcast_reduce_op.h          |  36 +--
 tests/python/unittest/test_numpy_ndarray.py        |  10 +-
 tests/python/unittest/test_numpy_op.py             | 104 +++++++-
 25 files changed, 1210 insertions(+), 262 deletions(-)

diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py
index ab81732..995a65c 100644
--- a/python/mxnet/_numpy_op_doc.py
+++ b/python/mxnet/_numpy_op_doc.py
@@ -139,3 +139,37 @@ def _npi_multinomial(a):
         In other words, each entry ``out[i,j,...,:]`` is an N-dimensional value drawn from the distribution.
     """
     pass
+
+
+def _np_cumsum(a, axis=None, dtype=None, out=None):
+    """
+    Return the cumulative sum of the elements along a given axis.
+
+    Parameters
+    ----------
+    a : array_like
+        Input array.
+    axis : int, optional
+        Axis along which the cumulative sum is computed. The default
+        (None) is to compute the cumsum over the flattened array.
+    dtype : dtype, optional
+        Type of the returned array and of the accumulator in which the
+        elements are summed.  If `dtype` is not specified, it defaults
+        to the dtype of `a`, unless `a` has an integer dtype with a
+        precision less than that of the default platform integer.  In
+        that case, the default platform integer is used.
+    out : ndarray, optional
+        Alternative output array in which to place the result. It must
+        have the same shape and buffer length as the expected output
+        but the type will be cast if necessary. See NumPy's ufunc docs
+        (section "Output arguments") for more details.
+
+    Returns
+    -------
+    cumsum_along_axis : ndarray
+        A new array holding the result is returned unless `out` is
+        specified, in which case a reference to `out` is returned. The
+        result has the same size as `a`, and the same shape as `a` if
+        `axis` is not None or `a` is a 1-d array.
+    """
+    pass
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 7866cfb..5b8b2e8 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -336,10 +336,8 @@ class Block(object):
         """
         params = self._collect_params_with_prefix()
         arg_dict = {key : val._reduce() for key, val in params.items()}
-        if is_np_array():
-            _mx_np.save(filename, arg_dict)
-        else:
-            ndarray.save(filename, arg_dict)
+        save_fn = _mx_npx.save if is_np_array() else ndarray.save
+        save_fn(filename, arg_dict)
 
     def save_params(self, filename):
         """[Deprecated] Please use save_parameters. Note that if you want load
@@ -389,7 +387,7 @@ class Block(object):
         <https://mxnet.incubator.apache.org/tutorials/gluon/save_load_params.html>`_
         """
         if is_np_array():
-            loaded = _mx_np.load(filename)
+            loaded = _mx_npx.load(filename)
         else:
             loaded = ndarray.load(filename)
         params = self._collect_params_with_prefix()
@@ -920,7 +918,8 @@ class HybridBlock(Block):
             else:
                 assert name in aux_names
                 arg_dict['aux:%s'%name] = param._reduce()
-        ndarray.save('%s-%04d.params'%(path, epoch), arg_dict)
+        save_fn = _mx_npx.save if is_np_array() else ndarray.save
+        save_fn('%s-%04d.params'%(path, epoch), arg_dict)
 
     def forward(self, x, *args):
         """Defines the forward computation. Arguments can be either
diff --git a/python/mxnet/gluon/data/vision/datasets.py b/python/mxnet/gluon/data/vision/datasets.py
index c580502..362cc9e 100644
--- a/python/mxnet/gluon/data/vision/datasets.py
+++ b/python/mxnet/gluon/data/vision/datasets.py
@@ -83,6 +83,8 @@ class MNIST(dataset._DownloadedDataset):
         with gzip.open(label_file, 'rb') as fin:
             struct.unpack(">II", fin.read(8))
             label = np.frombuffer(fin.read(), dtype=np.uint8).astype(np.int32)
+            if is_np_array():
+                label = _mx_np.array(label, dtype=label.dtype)
 
         with gzip.open(data_file, 'rb') as fin:
             struct.unpack(">IIII", fin.read(16))
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 5ddc9f7..09f76a8 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -2408,7 +2408,7 @@ def _get_broadcast_shape(shape1, shape2):
     for a, b in zip(shape1[::-1], shape2[::-1]):
         if a != 1 and b != 1 and a != b:
             raise ValueError('shape1=%s is not broadcastable to shape2=%s' % (shape1, shape2))
-        shape[i] = max(a, b)
+        shape[i] = b if a == 1 else a
         i -= 1
     return tuple(shape)
 
diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index cf14d89..449f495 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -27,7 +27,8 @@ from ..ndarray import NDArray
 
 __all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax',
            'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
-           'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace']
+           'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace',
+           'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt']
 
 
 @set_module('mxnet.ndarray.numpy')
@@ -99,29 +100,29 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou
 
     Parameters
-    --------
+    ----------
-    lhs : NDArray or numeric value
+    lhs : ndarray or numeric value
         Left-hand side operand.
 
-    rhs : NDArray or numeric value
-        Right-hand operand,
+    rhs : ndarray or numeric value
+        Right-hand side operand.
 
     fn_array : function
-        Function to be called if both lhs and rhs are of ``NDArray`` type.
+        Function to be called if both lhs and rhs are of ``ndarray`` type.
 
     fn_scalar : function
         Function to be called if both lhs and rhs are numeric values.
 
     lfn_scalar : function
-        Function to be called if lhs is ``NDArray`` while rhs is numeric value
+        Function to be called if lhs is ``ndarray`` while rhs is numeric value
 
     rfn_scalar : function
-        Function to be called if lhs is numeric value while rhs is ``NDArray``;
+        Function to be called if lhs is numeric value while rhs is ``ndarray``;
         if none is provided, then the function is commutative, so rfn_scalar is equal to lfn_scalar
 
     Returns
-    --------
-    mxnet.numpy.ndarray
-        result array
+    -------
+    mxnet.numpy.ndarray or scalar
+        result array or scalar
     """
     from ...numpy import ndarray
     if isinstance(lhs, numeric_types):
@@ -138,7 +139,7 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou
     elif isinstance(rhs, ndarray):
         return fn_array(lhs, rhs, out=out)
     else:
-        raise TypeError('type %s not supported' % str(type(rhs)))
+        raise TypeError('type {} not supported'.format(str(type(rhs))))
 #pylint: enable= too-many-arguments, no-member, protected-access
 
 
@@ -633,7 +634,7 @@ def tile(A, reps):
 
 
 @set_module('mxnet.ndarray.numpy')
-def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, **kwargs): #pylint: disable=too-many-arguments
+def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, **kwargs):  # pylint: disable=too-many-arguments
     """Return evenly spaced numbers over a specified interval.
 
     Returns num evenly spaced samples, calculated over the interval [start, stop].
@@ -653,15 +654,16 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
     endpoint : bool, optional
         If True, stop is the last sample. Otherwise, it is not included.
         Default is True.
-    retstep: bool, optional
+    retstep : bool, optional
         If True, return (samples, step), where step is the spacing between samples.
-    dtype: dtype, optional
+    dtype : dtype, optional
         The type of the output array. If dtype is not given, infer the data
         type from the other input arguments.
     axis : int, optional
         The axis in the result to store the samples. Relevant only if start or
         stop are array-like. By default (0), the samples will be along a new
         axis inserted at the beginning. Use -1 to get an axis at the end.
+
     Returns
     -------
     samples : ndarray
@@ -678,7 +680,7 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
     axis could only be 0 now.
     """
     if isinstance(start, (list, _np.ndarray, NDArray)) or \
-        isinstance(stop, (list, _np.ndarray, NDArray)):
+       isinstance(stop, (list, _np.ndarray, NDArray)):
         raise NotImplementedError('start and stop only support int')
     if axis != 0:
         raise NotImplementedError("the function only support axis 0")
@@ -687,6 +689,196 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
         ctx = current_context()
     if retstep:
-        step = (stop - start) / (num - 1)
-        return (_npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype), step)
+        step = (stop - start) / ((num - 1) if endpoint else num)
+        return _npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype), step
     else:
         return _npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype)
+
+
+def _unary_func_helper(x, fn_array, fn_scalar, out=None, **kwargs):
+    """Dispatch a unary operator on either an ndarray or a numeric scalar.
+
+    Parameters
+    ----------
+    x : ndarray or scalar
+        Input of the unary operator; any other type raises ``TypeError``.
+    fn_array : function
+        Called as ``fn_array(x, out=out, **kwargs)`` when `x` is an ``NDArray``.
+    fn_scalar : function
+        Called as ``fn_scalar(x, **kwargs)`` when `x` is a numeric scalar.
+    out : ndarray or None
+        Output buffer forwarded to `fn_array`; ignored when `x` is a scalar.
+
+    Returns
+    -------
+    y : mxnet.numpy.ndarray or scalar
+        Result array or scalar.
+    """
+    if isinstance(x, numeric_types):
+        return fn_scalar(x, **kwargs)
+    elif isinstance(x, NDArray):
+        return fn_array(x, out=out, **kwargs)
+    else:
+        raise TypeError('type {} not supported'.format(str(type(x))))
+
+
+@set_module('mxnet.ndarray.numpy')
+def sin(x, out=None, **kwargs):
+    r"""Trigonometric sine, element-wise.
+
+    Parameters
+    ----------
+    x : ndarray or scalar
+        Angle, in radians (:math:`2 \pi` rad equals 360 degrees).
+    out : ndarray or None
+        A location into which the result is stored. If provided, it
+        must have a shape that the inputs broadcast to. If not provided
+        or None, a freshly-allocated array is returned. The dtype of the
+        output is the same as that of the input if the input is an ndarray.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        The sine of each element of x. This is a scalar if `x` is a scalar.
+
+    Notes
+    -----
+    This function only supports input type of float.
+    """
+    return _unary_func_helper(x, _npi.sin, _np.sin, out=out, **kwargs)
+
+
+@set_module('mxnet.ndarray.numpy')
+def cos(x, out=None, **kwargs):
+    r"""Cosine, element-wise.
+
+    Parameters
+    ----------
+    x : ndarray or scalar
+        Angle, in radians (:math:`2 \pi` rad equals 360 degrees).
+    out : ndarray or None
+        A location into which the result is stored. If provided, it
+        must have a shape that the inputs broadcast to. If not provided
+        or None, a freshly-allocated array is returned. The dtype of the
+        output is the same as that of the input if the input is an ndarray.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        The corresponding cosine values. This is a scalar if x is a scalar.
+
+    Notes
+    -----
+    This function only supports input type of float.
+    """
+    return _unary_func_helper(x, _npi.cos, _np.cos, out=out, **kwargs)
+
+
+@set_module('mxnet.ndarray.numpy')
+def sinh(x, out=None, **kwargs):
+    """Hyperbolic sine, element-wise.
+
+    Equivalent to ``1/2 * (np.exp(x) - np.exp(-x))`` or ``-1j * np.sin(1j*x)``.
+
+    Parameters
+    ----------
+    x : ndarray or scalar
+        Input array or scalar.
+    out : ndarray or None
+        A location into which the result is stored. If provided, it
+        must have a shape that the inputs broadcast to. If not provided
+        or None, a freshly-allocated array is returned. The dtype of the
+        output is the same as that of the input if the input is an ndarray.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        The corresponding hyperbolic sine values. This is a scalar if `x` is a scalar.
+
+    Notes
+    -----
+    This function only supports input type of float.
+    """
+    return _unary_func_helper(x, _npi.sinh, _np.sinh, out=out, **kwargs)
+
+
+@set_module('mxnet.ndarray.numpy')
+def cosh(x, out=None, **kwargs):
+    """Hyperbolic cosine, element-wise.
+
+    Equivalent to ``1/2 * (np.exp(x) + np.exp(-x))`` and ``np.cos(1j*x)``.
+
+
+    Parameters
+    ----------
+    x : ndarray or scalar
+        Input array or scalar.
+    out : ndarray or None
+        A location into which the result is stored. If provided, it
+        must have a shape that the inputs broadcast to. If not provided
+        or None, a freshly-allocated array is returned. The dtype of the
+        output is the same as that of the input if the input is an ndarray.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        The corresponding hyperbolic cosine values. This is a scalar if `x` is a scalar.
+
+    Notes
+    -----
+    This function only supports input type of float.
+    """
+    return _unary_func_helper(x, _npi.cosh, _np.cosh, out=out, **kwargs)
+
+
+@set_module('mxnet.ndarray.numpy')
+def log10(x, out=None, **kwargs):
+    """Return the base 10 logarithm of the input array, element-wise.
+
+    Parameters
+    ----------
+    x : ndarray or scalar
+        Input array or scalar.
+    out : ndarray or None
+        A location into which the result is stored. If provided, it
+        must have a shape that the inputs broadcast to. If not provided
+        or None, a freshly-allocated array is returned. The dtype of the
+        output is the same as that of the input if the input is an ndarray.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        The logarithm to the base 10 of `x`, element-wise. NaNs are
+        returned where x is negative. This is a scalar if `x` is a scalar.
+
+    Notes
+    -----
+    This function only supports input type of float.
+    """
+    return _unary_func_helper(x, _npi.log10, _np.log10, out=out, **kwargs)
+
+
+@set_module('mxnet.ndarray.numpy')
+def sqrt(x, out=None, **kwargs):
+    """
+    Return the non-negative square-root of an array, element-wise.
+
+    Parameters
+    ----------
+    x : ndarray or scalar
+        The values whose square-roots are required.
+    out : ndarray, or None, optional
+        A location into which the result is stored. If provided, it must have
+        a shape that the inputs broadcast to. If not provided or `None`,
+        a freshly-allocated array is returned.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        An array of the same shape as `x`, containing the non-negative
+        square-root of each element in `x`. This is a scalar if `x` is a scalar.
+
+    Notes
+    -----
+    This function only supports input type of float.
+    """
+    return _unary_func_helper(x, _npi.sqrt, _np.sqrt, out=out, **kwargs)
diff --git a/python/mxnet/ndarray/register.py b/python/mxnet/ndarray/register.py
index 20e6223..bdbfa15 100644
--- a/python/mxnet/ndarray/register.py
+++ b/python/mxnet/ndarray/register.py
@@ -49,9 +49,11 @@ def _verify_all_np_ndarrays(op_name, func_name, args, out):
             raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. '
                             'This is a numpy operator which can only accept '
                             'MXNet numpy ndarrays, while received a legacy ndarray. '
-                            'Please call `as_np_ndarray()` upon the legacy ndarray to '
-                            'convert it to an MXNet numpy ndarray, and then feed the converted '
-                            'array to this operator.'
+                            'Please ensure that you have activated numpy semantics by calling '
+                            '`npx.set_np()` in your code. If you still see this error with numpy '
+                            'semantics activated, please call `as_np_ndarray()` upon the legacy '
+                            'ndarray to convert it to an MXNet numpy ndarray, and then feed the '
+                            'converted array to this operator.'
                             .format(op_name, func_name))
     if out is None:
         return
@@ -60,11 +62,13 @@ def _verify_all_np_ndarrays(op_name, func_name, args, out):
     for arr in out:
         if (arr is not None) and (not isinstance(arr, np_ndarray)):
             raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. '
                             'This is a numpy operator which can only write to MXNet numpy '
                             'ndarrays, while received a legacy ndarray. '
-                            'Please call `as_np_ndarray()` upon the legacy ndarray to '
-                            'convert it to an MXNet numpy ndarray, and then feed the converted '
-                            'array to this operator.'
+                            'Please ensure that you have activated numpy semantics by calling '
+                            '`npx.set_np()` in your code. If you still see this error with numpy '
+                            'semantics activated, please call `as_np_ndarray()` upon the legacy '
+                            'ndarray to convert it to an MXNet numpy ndarray, and then feed the '
+                            'converted array to this operator.'
                             .format(op_name, func_name))
 
 
diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/__init__.py
index 266c2fa..7a9a2f6 100644
--- a/python/mxnet/numpy/__init__.py
+++ b/python/mxnet/numpy/__init__.py
@@ -15,9 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""Module for numpy ops used in imperative programming."""
+"""MXNet NumPy module: NumPy-style ndarray and operators for imperative programming."""
+
+from __future__ import division, absolute_import, print_function
 
-from __future__ import absolute_import
 from . import random
 from . import linalg
 from .multiarray import *  # pylint: disable=wildcard-import
@@ -25,5 +26,8 @@ from . import _op
 from . import _register
 from ._op import *  # pylint: disable=wildcard-import
 from .utils import *  # pylint: disable=wildcard-import
+from .function_base import *  # pylint: disable=wildcard-import
+from .stride_tricks import *  # pylint: disable=wildcard-import
+from .io import *  # pylint: disable=wildcard-import
 
 __all__ = []
diff --git a/python/mxnet/numpy/function_base.py b/python/mxnet/numpy/function_base.py
new file mode 100644
index 0000000..e8e07c7
--- /dev/null
+++ b/python/mxnet/numpy/function_base.py
@@ -0,0 +1,115 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Numpy basic functions."""
+from __future__ import absolute_import
+
+from .stride_tricks import broadcast_arrays
+
+__all__ = ['meshgrid']
+
+
+def meshgrid(*xi, **kwargs):
+    """
+    Return coordinate matrices from coordinate vectors.
+
+    Make N-D coordinate arrays for vectorized evaluations of
+    N-D scalar/vector fields over N-D grids, given
+    one-dimensional coordinate arrays x1, x2,..., xn.
+
+    Parameters
+    ----------
+    x1, x2,..., xn : ndarrays
+        1-D arrays representing the coordinates of a grid.
+    indexing : {'xy', 'ij'}, optional
+        Cartesian ('xy', default) or matrix ('ij') indexing of output.
+        See Notes for more details.
+
+    sparse : bool, optional
+        If True a sparse grid is returned in order to conserve memory.
+        Default is False. Please note that `sparse=True` is currently
+        not supported.
+
+    copy : bool, optional
+        If False, a view into the original arrays is returned in order to
+        conserve memory.  Default is True. Please note that `copy=False`
+        is currently not supported.
+
+    Returns
+    -------
+    X1, X2,..., XN : ndarray
+        For vectors `x1`, `x2`,..., `xn` with lengths ``Ni=len(xi)`` ,
+        return ``(N1, N2, N3,...Nn)`` shaped arrays if indexing='ij'
+        or ``(N2, N1, N3,...Nn)`` shaped arrays if indexing='xy'
+        with the elements of `xi` repeated to fill the matrix along
+        the first dimension for `x1`, the second for `x2` and so on.
+
+    Notes
+    -----
+    This function supports both indexing conventions through the indexing
+    keyword argument.  Giving the string 'ij' returns a meshgrid with
+    matrix indexing, while 'xy' returns a meshgrid with Cartesian indexing.
+    In the 2-D case with inputs of length M and N, the outputs are of shape
+    (N, M) for 'xy' indexing and (M, N) for 'ij' indexing.  In the 3-D case
+    with inputs of length M, N and P, outputs are of shape (N, M, P) for
+    'xy' indexing and (M, N, P) for 'ij' indexing.  The difference is
+    illustrated by the following code snippet::
+
+        xv, yv = np.meshgrid(x, y, sparse=False, indexing='ij')
+        for i in range(nx):
+            for j in range(ny):
+                # treat xv[i,j], yv[i,j]
+
+        xv, yv = np.meshgrid(x, y, sparse=False, indexing='xy')
+        for i in range(nx):
+            for j in range(ny):
+                # treat xv[j,i], yv[j,i]
+
+    In the 1-D and 0-D case, the indexing and sparse keywords have no effect.
+    """
+    ndim = len(xi)
+
+    copy_ = kwargs.pop('copy', True)
+    if not copy_:
+        raise NotImplementedError('copy=False is not implemented')
+    sparse = kwargs.pop('sparse', False)
+    if sparse:
+        raise NotImplementedError('sparse=True is not implemented')
+    indexing = kwargs.pop('indexing', 'xy')
+
+    if kwargs:
+        raise TypeError("meshgrid() got an unexpected keyword argument '%s'"
+                        % (list(kwargs)[0],))
+
+    if indexing not in ['xy', 'ij']:
+        raise ValueError(
+            "Valid values for `indexing` are 'xy' and 'ij'.")
+
+    s0 = (1,) * ndim
+    output = [x.reshape(s0[:i] + (-1,) + s0[i + 1:])
+              for i, x in enumerate(xi)]
+
+    if indexing == 'xy' and ndim > 1:
+        # switch first and second axis
+        output[0] = output[0].reshape(1, -1, *s0[2:])
+        output[1] = output[1].reshape(-1, 1, *s0[2:])
+
+    if not sparse:
+        # Return the full N-D matrix (not only the 1-D vector)
+        output = broadcast_arrays(*output)
+
+    return output
diff --git a/python/mxnet/numpy_extension/__init__.py b/python/mxnet/numpy/io.py
similarity index 52%
copy from python/mxnet/numpy_extension/__init__.py
copy to python/mxnet/numpy/io.py
index 0e2d005..aece13f 100644
--- a/python/mxnet/numpy_extension/__init__.py
+++ b/python/mxnet/numpy/io.py
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -17,17 +15,29 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""Module for ops not belonging to the official numpy package for imperative programming."""
 
+"""I/O functions for ndarrays."""
 from __future__ import absolute_import
-from . import _op
-from . import _register
-from ._op import *  # pylint: disable=wildcard-import
-from ..context import *  # pylint: disable=wildcard-import
-# TODO(junwu): revisit what functions should be exposed to users
-from ..util import use_np_shape, np_shape, is_np_shape
-from ..util import use_np_array, np_array, is_np_array
-from ..util import set_np, use_np, reset_np
-from ..ndarray import waitall
+import numpy as onp
+from ..context import current_context
+from .multiarray import array
+
+__all__ = ['genfromtxt']
+
+
+# TODO(junwu): Add doc
+def genfromtxt(*args, **kwargs):
+    """This is a wrapper of the official NumPy's `genfromtxt` function.
+    Please refer to the documentation here
+    https://docs.scipy.org/doc/numpy/reference/generated/numpy.genfromtxt.html.
 
-__all__ = []
+    Notes
+    -----
+    This function accepts an additional keyword argument `ctx` specifying the
+    device on which the returned ndarray is created (default: current context).
+    """
+    ctx = kwargs.pop('ctx', None)
+    if ctx is None:
+        ctx = current_context()
+    ret = onp.genfromtxt(*args, **kwargs)
+    return array(ret, dtype=ret.dtype, ctx=ctx)
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index dd13c8e..2a37af7 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -45,7 +45,8 @@ from ..ndarray.numpy import _internal as _npi
 
 __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange',
            'argmax', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
-           'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace']
+           'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace', 'sin', 'cos',
+           'sinh', 'cosh', 'log10', 'sqrt']
 
 
 # This function is copied from ndarray.py since pylint
@@ -356,6 +357,9 @@ class ndarray(NDArray):
 
     def __len__(self):
         """Number of elements along the first axis."""
-        return self.shape[0]
+        shape = self.shape
+        if len(shape) == 0:
+            raise TypeError('len() of unsized object')
+        return shape[0]
 
     def __reduce__(self):
@@ -419,21 +423,20 @@ class ndarray(NDArray):
         return self
 
     def __repr__(self):
-        """Returns a string representation of the array using the following rules:
-        1. If the `ndarray` is a scalar tensor, only the string of the scalar is returned.
-        2. Else if the `ndarray` is allocated on cpu, the string of its numpy form, class name,
-        and shape is returned.
-        3. Else (the `ndarray` is allocated on gpu), the string of its numpy form, class name,
-        shape, and context is returned."""
-        array_str = str(self.asnumpy())
-        if self.ndim == 0:  # scalar tensor
+        """Returns NumPy's repr of the array, with `ctx=...` appended for non-CPU arrays."""
+        array_str = self.asnumpy().__repr__()
+        context = self.context
+        if context.device_type == 'cpu':
             return array_str
+        return array_str[:-1] + ', ctx={})'.format(str(context))
+
+    def __str__(self):
+        """Returns NumPy's str of the array; non-CPU, non-scalar arrays get an ` @ctx` suffix."""
+        array_str = self.asnumpy().__str__()
         context = self.context
-        if context.device_type == 'gpu':
-            return '%s\n<%s shape=%s ctx=%s>' % (array_str, self.__class__.__name__, self.shape,
-                                                 context)
-        else:
-            return '%s\n<%s shape=%s>' % (array_str, self.__class__.__name__, self.shape)
+        if context.device_type == 'cpu' or self.ndim == 0:
+            return array_str
+        return '{array} @{ctx}'.format(array=array_str, ctx=context)
 
     def attach_grad(self, grad_req='write'):  # pylint: disable=arguments-differ
         """Attach a gradient buffer to this ndarray, so that `backward`
@@ -570,12 +573,33 @@ class ndarray(NDArray):
     def dot(self, b, out=None):
         return _mx_np_op.dot(self, b, out=out)
 
-    def reshape(self, shape, order='C'):  # pylint: disable=arguments-differ
-        """Returns an array containing the same data with a new shape."""
-        if order != 'C':
-            raise NotImplementedError('reshape only supports C-order,'
-                                      ' while received {}'.format(order))
-        return _mx_np_op.reshape(self, newshape=shape, order=order)
+    def reshape(self, *args, **kwargs):  # pylint: disable=arguments-differ
+        """Returns an array containing the same data with a new shape.
+
+        Notes
+        -----
+        Unlike the free function `numpy.reshape`, this method on `ndarray` allows
+        the elements of the shape parameter to be passed in as separate arguments.
+        For example, ``a.reshape(10, 11)`` is equivalent to ``a.reshape((10, 11))``
+        and ``a.reshape([10, 11])``.
+        """
+        order = 'C'
+        if len(kwargs) > 1:
+            raise TypeError('function takes at most 1 keyword argument')
+        if len(kwargs) == 1:
+            if 'order' not in kwargs:
+                raise TypeError('{} is an invalid keyword argument for this function'
+                                .format(next(iter(kwargs))))
+            order = kwargs.pop('order', 'C')
+            if order != 'C':
+                raise NotImplementedError('only supports C-order,'
+                                          ' while received {}'.format(order))
+        if len(args) == 0:
+            raise TypeError('reshape() takes exactly 1 argument (0 given)')
+        if len(args) == 1 and isinstance(args[0], (tuple, list)):
+            return _mx_np_op.reshape(self, newshape=tuple(args[0]), order=order)
+        else:
+            return _mx_np_op.reshape(self, newshape=args, order=order)
 
     def reshape_like(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`reshape_like`.
@@ -753,13 +777,9 @@ class ndarray(NDArray):
         """
         raise AttributeError('mxnet.numpy.ndarray object has no attribute abs')
 
-    def flatten(self, *args, **kwargs):
-        """Convenience fluent method for :py:func:`flatten`.
-
-        The arguments are the same as for :py:func:`flatten`, with
-        this array as data.
-        """
-        raise NotImplementedError
+    def flatten(self, order='C'):  # pylint: disable=arguments-differ
+        """Return a copy of the array collapsed into one dimension."""
+        return self.reshape(-1, order=order)
 
     def shape_array(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`shape_array`.
@@ -849,13 +869,9 @@ class ndarray(NDArray):
         """
         raise AttributeError('mxnet.numpy.ndarray object has no attribute nansum')
 
-    def prod(self, *args, **kwargs):
-        """Convenience fluent method for :py:func:`prod`.
-
-        The arguments are the same as for :py:func:`prod`, with
-        this array as data.
-        """
-        raise NotImplementedError
+    def prod(self, axis=None, dtype=None, out=None, keepdims=False):  # pylint: disable=arguments-differ
+        """Return the product of the array elements over the given axis."""
+        return _mx_np_op.prod(self, axis=axis, dtype=dtype, keepdims=keepdims, out=out)
 
     def nanprod(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`nanprod`.
@@ -866,20 +882,26 @@ class ndarray(NDArray):
         raise AttributeError('mxnet.numpy.ndarray object has no attribute nanprod')
 
     def mean(self, axis=None, dtype=None, out=None, keepdims=False):  # pylint: disable=arguments-differ
-        """Convenience fluent method for :py:func:`mean`.
+        """Returns the average of the array elements along given axis."""
+        return _mx_np_op.mean(self, axis=axis, dtype=dtype, keepdims=keepdims, out=out)
 
-        The arguments are the same as for :py:func:`mean`, with
-        this array as data.
-        """
-        return _mx_nd_np.mean(self, axis=axis, dtype=dtype, keepdims=keepdims, out=out)
+    # TODO(junwu): Use mxnet std op instead of onp.std
+    def std(self, axis=None, dtype=None, out=None, ddof=0, keepdims=False):  # pylint: disable=arguments-differ
+        """Returns the standard deviation of the array elements along given axis."""
+        ret_np = self.asnumpy().std(axis=axis, dtype=dtype, out=out, ddof=ddof, keepdims=keepdims)
+        return array(ret_np, dtype=ret_np.dtype, ctx=self.context)
 
-    def max(self, *args, **kwargs):
-        """Convenience fluent method for :py:func:`max`.
+    def cumsum(self, axis=None, dtype=None, out=None):
+        """Return the cumulative sum of the elements along the given axis."""
+        return _mx_np_op.cumsum(self, axis=axis, dtype=dtype, out=out)
 
-        The arguments are the same as for :py:func:`max`, with
-        this array as data.
-        """
-        raise NotImplementedError
+    def tolist(self):
+        """Return the array data as a (possibly nested) Python list via ``asnumpy().tolist()``."""
+        return self.asnumpy().tolist()
+
+    def max(self, axis=None, out=None, keepdims=False):  # pylint: disable=arguments-differ
+        """Return the maximum along a given axis."""
+        return _mx_np_op.max(self, axis=axis, keepdims=keepdims, out=out)
 
     def min(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`min`.
@@ -1699,7 +1720,7 @@ def swapaxes(a, axis1, axis2):
 def expand_dims(a, axis):
     """Expand the shape of an array.
 
-    Insert a new axis that will appear at the `axis` position in the expanded
+    Insert a new axis that will appear at the `axis` position in the expanded array shape.
 
     Parameters
     ----------
@@ -1833,3 +1854,166 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
         Size of spacing between samples.
     """
     return _mx_nd_np.linspace(start, stop, num, endpoint, retstep, dtype, axis, **kwargs)
+
+
+@set_module('mxnet.numpy')
+def sin(x, out=None, **kwargs):
+    r"""Trigonometric sine, element-wise.
+
+    Parameters
+    ----------
+    x : ndarray or scalar
+        Angle, in radians (:math:`2 \pi` rad equals 360 degrees).
+    out : ndarray or None
+        A location into which the result is stored. If provided, it
+        must have a shape that the inputs broadcast to. If not provided
+        or None, a freshly-allocated array is returned. The dtype of the
+        output is the same as that of the input if the input is an ndarray.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        The sine of each element of x. This is a scalar if `x` is a scalar.
+
+    Notes
+    -----
+    This function only supports input type of float.
+    """
+    return _mx_nd_np.sin(x, out=out, **kwargs)
+
+
+@set_module('mxnet.numpy')
+def cos(x, out=None, **kwargs):
+    r"""Cosine, element-wise.
+
+    Parameters
+    ----------
+    x : ndarray or scalar
+        Angle, in radians (:math:`2 \pi` rad equals 360 degrees).
+    out : ndarray or None
+        A location into which the result is stored. If provided, it
+        must have a shape that the inputs broadcast to. If not provided
+        or None, a freshly-allocated array is returned. The dtype of the
+        output is the same as that of the input if the input is an ndarray.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        The corresponding cosine values. This is a scalar if x is a scalar.
+
+    Notes
+    -----
+    This function only supports input type of float.
+    """
+    return _mx_nd_np.cos(x, out=out, **kwargs)
+
+
+@set_module('mxnet.numpy')
+def sinh(x, out=None, **kwargs):
+    """Hyperbolic sine, element-wise.
+
+    Equivalent to ``1/2 * (np.exp(x) - np.exp(-x))`` or ``-1j * np.sin(1j*x)``.
+
+    Parameters
+    ----------
+    x : ndarray or scalar
+        Input array or scalar.
+    out : ndarray or None
+        A location into which the result is stored. If provided, it
+        must have a shape that the inputs broadcast to. If not provided
+        or None, a freshly-allocated array is returned. The dtype of the
+        output is the same as that of the input if the input is an ndarray.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        The corresponding hyperbolic sine values. This is a scalar if `x` is a scalar.
+
+    Notes
+    -----
+    This function only supports input type of float.
+    """
+    return _mx_nd_np.sinh(x, out=out, **kwargs)
+
+
+@set_module('mxnet.numpy')
+def cosh(x, out=None, **kwargs):
+    """Hyperbolic cosine, element-wise.
+
+    Equivalent to ``1/2 * (np.exp(x) + np.exp(-x))`` and ``np.cos(1j*x)``.
+
+
+    Parameters
+    ----------
+    x : ndarray or scalar
+        Input array or scalar.
+    out : ndarray or None
+        A location into which the result is stored. If provided, it
+        must have a shape that the inputs broadcast to. If not provided
+        or None, a freshly-allocated array is returned. The dtype of the
+        output is the same as that of the input if the input is an ndarray.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        The corresponding hyperbolic cosine values. This is a scalar if `x` is a scalar.
+
+    Notes
+    -----
+    This function only supports input type of float.
+    """
+    return _mx_nd_np.cosh(x, out=out, **kwargs)
+
+
+@set_module('mxnet.numpy')
+def log10(x, out=None, **kwargs):
+    """Return the base 10 logarithm of the input array, element-wise.
+
+    Parameters
+    ----------
+    x : ndarray or scalar
+        Input array or scalar.
+    out : ndarray or None
+        A location into which the result is stored. If provided, it
+        must have a shape that the inputs broadcast to. If not provided
+        or None, a freshly-allocated array is returned. The dtype of the
+        output is the same as that of the input if the input is an ndarray.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        The logarithm to the base 10 of `x`, element-wise. NaNs are
+        returned where x is negative. This is a scalar if `x` is a scalar.
+
+    Notes
+    -----
+    This function only supports input type of float.
+    """
+    return _mx_nd_np.log10(x, out=out, **kwargs)
+
+
+@set_module('mxnet.numpy')
+def sqrt(x, out=None, **kwargs):
+    """
+    Return the non-negative square-root of an array, element-wise.
+
+    Parameters
+    ----------
+    x : ndarray or scalar
+        The values whose square-roots are required.
+    out : ndarray, or None, optional
+        A location into which the result is stored. If provided, it must have
+        a shape that the inputs broadcast to. If not provided or `None`,
+        a freshly-allocated array is returned.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        An array of the same shape as `x`, containing the positive
+        square-root of each element in `x`. This is a scalar if `x` is a scalar.
+
+    Notes
+    -----
+    This function only supports input type of float.
+    """
+    return _mx_nd_np.sqrt(x, out=out, **kwargs)
diff --git a/python/mxnet/numpy/stride_tricks.py b/python/mxnet/numpy/stride_tricks.py
new file mode 100644
index 0000000..1848a29
--- /dev/null
+++ b/python/mxnet/numpy/stride_tricks.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Util functions with broadcast."""
+
+from ..ndarray.ndarray import _get_broadcast_shape
+from . import _op as _mx_np_op
+
+
+__all__ = ['broadcast_arrays']
+
+
+def _broadcast_shape(*args):
+    shape = ()
+    for arr in args:
+        shape = _get_broadcast_shape(shape, arr.shape)
+    return shape
+
+
+def broadcast_arrays(*args):
+    """
+    Broadcast any number of arrays against each other.
+
+    Parameters
+    ----------
+    `*args` : a list of ndarrays
+        The arrays to broadcast.
+
+    Returns
+    -------
+    broadcasted : list of arrays
+        These arrays are copies of the original arrays, unless all the input
+        arrays already have the same shape; in that case the input arrays
+        themselves are returned in a list instead of copies.
+    """
+    shape = _broadcast_shape(*args)
+
+    if all(arr.shape == shape for arr in args):
+        # Common case where nothing needs to be broadcasted.
+        return list(args)
+
+    return [_mx_np_op.broadcast_to(arr, shape) for arr in args]
diff --git a/python/mxnet/numpy/utils.py b/python/mxnet/numpy/utils.py
index 48a47a3..920897e 100644
--- a/python/mxnet/numpy/utils.py
+++ b/python/mxnet/numpy/utils.py
@@ -20,103 +20,16 @@
 
 from __future__ import absolute_import
 
-import ctypes
-from .. util import is_np_array, is_np_shape
-from .. base import _LIB, check_call, string_types, c_str_array
-from .. base import c_handle_array, c_str, mx_uint, NDArrayHandle, py_str
-from . import ndarray
+import numpy as onp
 
-__all__ = ['save', 'load']
+__all__ = ['float16', 'float32', 'float64', 'uint8', 'int32', 'int8', 'int64', 'pi']
 
+float16 = onp.float16
+float32 = onp.float32
+float64 = onp.float64
+uint8 = onp.uint8
+int32 = onp.int32
+int8 = onp.int8
+int64 = onp.int64
 
-def save(file, arr):
-    """Saves a list of `ndarray`s or a dict of `str`->`ndarray` to file.
-
-    Examples of filenames:
-
-    - ``/path/to/file``
-    - ``s3://my-bucket/path/to/file`` (if compiled with AWS S3 supports)
-    - ``hdfs://path/to/file`` (if compiled with HDFS supports)
-
-    Parameters
-    ----------
-    file : str
-        Filename to which the data is saved.
-    arr : `ndarray` or list of `ndarray`s or dict of `str` to `ndarray`
-        The data to be saved.
-
-    Notes
-    -----
-    This function can only be called within numpy semantics, i.e., `npx.is_np_shape()`
-    and `npx.is_np_array()` must both return true.
-    """
-    if not (is_np_shape() and is_np_array()):
-        raise ValueError('Cannot save `mxnet.numpy.ndarray` in legacy mode. Please activate'
-                         ' numpy semantics by calling `npx.set_np()` in the global scope'
-                         ' before calling this function.')
-    if isinstance(arr, ndarray):
-        arr = [arr]
-    if isinstance(arr, dict):
-        str_keys = arr.keys()
-        nd_vals = arr.values()
-        if any(not isinstance(k, string_types) for k in str_keys) or \
-                any(not isinstance(v, ndarray) for v in nd_vals):
-            raise TypeError('Only accepts dict str->ndarray or list of ndarrays')
-        keys = c_str_array(str_keys)
-        handles = c_handle_array(nd_vals)
-    elif isinstance(arr, list):
-        if any(not isinstance(v, ndarray) for v in arr):
-            raise TypeError('Only accepts dict str->ndarray or list of ndarrays')
-        keys = None
-        handles = c_handle_array(arr)
-    else:
-        raise ValueError("data needs to either be a ndarray, dict of (str, ndarray) pairs "
-                         "or a list of ndarrays.")
-    check_call(_LIB.MXNDArraySave(c_str(file),
-                                  mx_uint(len(handles)),
-                                  handles,
-                                  keys))
-
-
-def load(file):
-    """Loads an array from file.
-
-    See more details in ``save``.
-
-    Parameters
-    ----------
-    file : str
-        The filename.
-
-    Returns
-    -------
-    result : list of ndarrays or dict of str -> ndarray
-        Data stored in the file.
-
-    Notes
-    -----
-    This function can only be called within numpy semantics, i.e., `npx.is_np_shape()`
-    and `npx.is_np_array()` must both return true.
-    """
-    if not (is_np_shape() and is_np_array()):
-        raise ValueError('Cannot load `mxnet.numpy.ndarray` in legacy mode. Please activate'
-                         ' numpy semantics by calling `npx.set_np()` in the global scope'
-                         ' before calling this function.')
-    if not isinstance(file, string_types):
-        raise TypeError('file required to be a string')
-    out_size = mx_uint()
-    out_name_size = mx_uint()
-    handles = ctypes.POINTER(NDArrayHandle)()
-    names = ctypes.POINTER(ctypes.c_char_p)()
-    check_call(_LIB.MXNDArrayLoad(c_str(file),
-                                  ctypes.byref(out_size),
-                                  ctypes.byref(handles),
-                                  ctypes.byref(out_name_size),
-                                  ctypes.byref(names)))
-    if out_name_size.value == 0:
-        return [ndarray(NDArrayHandle(handles[i])) for i in range(out_size.value)]
-    else:
-        assert out_name_size.value == out_size.value
-        return dict(
-            (py_str(names[i]), ndarray(NDArrayHandle(handles[i])))
-            for i in range(out_size.value))
+pi = onp.pi
diff --git a/python/mxnet/numpy_extension/__init__.py b/python/mxnet/numpy_extension/__init__.py
index 0e2d005..d80f0cc 100644
--- a/python/mxnet/numpy_extension/__init__.py
+++ b/python/mxnet/numpy_extension/__init__.py
@@ -29,5 +29,6 @@ from ..util import use_np_shape, np_shape, is_np_shape
 from ..util import use_np_array, np_array, is_np_array
 from ..util import set_np, use_np, reset_np
 from ..ndarray import waitall
+from .utils import *  # pylint: disable=wildcard-import
 
 __all__ = []
diff --git a/python/mxnet/numpy/utils.py b/python/mxnet/numpy_extension/utils.py
similarity index 99%
copy from python/mxnet/numpy/utils.py
copy to python/mxnet/numpy_extension/utils.py
index 48a47a3..0aa89ba 100644
--- a/python/mxnet/numpy/utils.py
+++ b/python/mxnet/numpy_extension/utils.py
@@ -24,7 +24,7 @@ import ctypes
 from .. util import is_np_array, is_np_shape
 from .. base import _LIB, check_call, string_types, c_str_array
 from .. base import c_handle_array, c_str, mx_uint, NDArrayHandle, py_str
-from . import ndarray
+from ..numpy import ndarray
 
 __all__ = ['save', 'load']
 
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index e015b7a..55577e9 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -31,7 +31,7 @@ from . import _internal as _npi
 
 __all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'concatenate', 'arange', 'argmax',
            'clip', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'split', 'swapaxes',
-           'expand_dims', 'tile', 'linspace']
+           'expand_dims', 'tile', 'linspace', 'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt']
 
 
 def _num_outputs(sym):
@@ -216,11 +216,33 @@ class _Symbol(Symbol):
     def dot(self, b, out=None):
         return _mx_np_op.dot(self, b, out=out)
 
-    def reshape(self, shape, order='C'):  # pylint: disable=arguments-differ
-        if order != 'C':
-            raise NotImplementedError('only supports order=\'C\', while received {}'
-                                      .format(str(order)))
-        return _mx_np_op.reshape(self, newshape=shape, order=order)
+    def reshape(self, *args, **kwargs):  # pylint: disable=arguments-differ
+        """Returns a symbol containing the same data with a new shape.
+
+        Notes
+        -----
+        Unlike the free function `numpy.reshape`, this method on `_Symbol` allows
+        the elements of the shape parameter to be passed in as separate arguments.
+        For example, ``a.reshape(10, 11)`` is equivalent to ``a.reshape((10, 11))``
+        and ``a.reshape([10, 11])``.
+        """
+        order = 'C'
+        if len(kwargs) > 1:
+            raise TypeError('function takes at most 1 keyword argument')
+        if len(kwargs) == 1:
+            if 'order' not in kwargs:
+                raise TypeError('{} is an invalid keyword argument for this function'
+                                .format(next(iter(kwargs))))
+            order = kwargs.pop('order', 'C')
+            if order != 'C':
+                raise NotImplementedError('only supports C-order,'
+                                          ' while received {}'.format(order))
+        if len(args) == 0:
+            raise TypeError('reshape() takes exactly 1 argument (0 given)')
+        if len(args) == 1 and isinstance(args[0], (tuple, list)):
+            return _mx_np_op.reshape(self, newshape=tuple(args[0]), order=order)
+        else:
+            return _mx_np_op.reshape(self, newshape=args, order=order)
 
     def argmax(self, axis=None, out=None):  # pylint: disable=arguments-differ
         return _mx_np_op.argmax(self, axis, out)
@@ -401,13 +423,9 @@ class _Symbol(Symbol):
         """
         raise AttributeError('_Symbol object has no attribute abs')
 
-    def flatten(self, *args, **kwargs):
-        """Convenience fluent method for :py:func:`flatten`.
-
-        The arguments are the same as for :py:func:`flatten`, with
-        this array as data.
-        """
-        raise NotImplementedError
+    def flatten(self, order='C'):  # pylint: disable=arguments-differ
+        """Return a copy of the array collapsed into one dimension."""
+        return self.reshape(-1, order=order)
 
     def shape_array(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`shape_array`.
@@ -497,13 +515,9 @@ class _Symbol(Symbol):
         """
         raise AttributeError('_Symbol object has no attribute nansum')
 
-    def prod(self, *args, **kwargs):
-        """Convenience fluent method for :py:func:`prod`.
-
-        The arguments are the same as for :py:func:`prod`, with
-        this array as data.
-        """
-        raise NotImplementedError
+    def prod(self, axis=None, dtype=None, out=None, keepdims=False):  # pylint: disable=arguments-differ
+        """Return the product of the array elements over the given axis."""
+        return _mx_np_op.prod(self, axis=axis, dtype=dtype, keepdims=keepdims, out=out)
 
     def nanprod(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`nanprod`.
@@ -521,13 +535,13 @@ class _Symbol(Symbol):
         """
         return _mx_np_op.mean(self, axis=axis, dtype=dtype, keepdims=keepdims, out=out)
 
-    def max(self, *args, **kwargs):
-        """Convenience fluent method for :py:func:`max`.
+    def cumsum(self, axis=None, dtype=None, out=None):
+        """Return the cumulative sum of the elements along the given axis."""
+        return _mx_np_op.cumsum(self, axis=axis, dtype=dtype, out=out)
 
-        The arguments are the same as for :py:func:`max`, with
-        this array as data.
-        """
-        raise NotImplementedError
+    def max(self, axis=None, out=None, keepdims=False):  # pylint: disable=arguments-differ
+        """Return the maximum along a given axis."""
+        return _mx_np_op.max(self, axis=axis, keepdims=keepdims, out=out)
 
     def min(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`min`.
@@ -1367,4 +1381,178 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
         return _npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype)
 
 
+def _unary_func_helper(x, fn_array, fn_scalar, out=None, **kwargs):
+    """Helper function for unary operators.
+
+    Parameters
+    ----------
+    x : _Symbol or scalar
+        Input of the unary operator.
+    fn_array : function
+        Function to be called if x is of ``_Symbol`` type.
+    fn_scalar : function
+        Function to be called if x is a Python scalar.
+    out : _Symbol
+        Dummy parameter to keep the consistency with the ndarray counterpart.
+
+    Returns
+    -------
+    out : _Symbol or scalar
+        Result _Symbol or scalar.
+    """
+    if isinstance(x, numeric_types):
+        return fn_scalar(x, **kwargs)
+    elif isinstance(x, _Symbol):
+        return fn_array(x, out=out, **kwargs)
+    else:
+        raise TypeError('type {} not supported'.format(str(type(x))))
+
+
+@set_module('mxnet.symbol.numpy')
+def sin(x, out=None, **kwargs):
+    r"""Trigonometric sine, element-wise.
+
+    Parameters
+    ----------
+    x : _Symbol or scalar
+        Angle, in radians (:math:`2 \pi` rad equals 360 degrees).
+    out : _Symbol or None
+        Dummy parameter to keep the consistency with the ndarray counterpart.
+
+    Returns
+    -------
+    y : _Symbol or scalar
+        The sine of each element of x.
+        This is a scalar if `x` is a scalar.
+
+    Notes
+    -----
+    This function only supports input type of float.
+    """
+    return _unary_func_helper(x, _npi.sin, _np.sin, out=out, **kwargs)
+
+
+@set_module('mxnet.symbol.numpy')
+def cos(x, out=None, **kwargs):
+    r"""Cosine, element-wise.
+
+    Parameters
+    ----------
+    x : _Symbol or scalar
+        Angle, in radians (:math:`2 \pi` rad equals 360 degrees).
+    out : _Symbol or None
+        Dummy parameter to keep the consistency with the ndarray counterpart.
+
+    Returns
+    -------
+    y : _Symbol or scalar
+        The corresponding cosine values. This is a scalar if x is a scalar.
+
+    Notes
+    -----
+    This function only supports input type of float.
+    """
+    return _unary_func_helper(x, _npi.cos, _np.cos, out=out, **kwargs)
+
+
+@set_module('mxnet.symbol.numpy')
+def sinh(x, out=None, **kwargs):
+    """Hyperbolic sine, element-wise.
+
+    Equivalent to ``1/2 * (np.exp(x) - np.exp(-x))`` or ``-1j * np.sin(1j*x)``.
+
+    Parameters
+    ----------
+    x : _Symbol or scalar
+        Input array or scalar.
+    out : _Symbol or None
+        Dummy parameter to keep the consistency with the ndarray counterpart.
+
+    Returns
+    -------
+    y : _Symbol or scalar
+        The corresponding hyperbolic sine values. This is a scalar if `x` is a scalar.
+
+    Notes
+    -----
+    This function only supports input type of float.
+    """
+    return _unary_func_helper(x, _npi.sinh, _np.sinh, out=out, **kwargs)
+
+
+@set_module('mxnet.symbol.numpy')
+def cosh(x, out=None, **kwargs):
+    """Hyperbolic cosine, element-wise.
+
+    Equivalent to ``1/2 * (np.exp(x) + np.exp(-x))`` and ``np.cos(1j*x)``.
+
+
+    Parameters
+    ----------
+    x : _Symbol or scalar
+        Input array or scalar.
+    out : _Symbol or None
+        Dummy parameter to keep the consistency with the ndarray counterpart.
+
+    Returns
+    -------
+    y : _Symbol or scalar
+        The corresponding hyperbolic cosine values. This is a scalar if `x` is a scalar.
+
+    Notes
+    -----
+    This function only supports input type of float.
+    """
+    return _unary_func_helper(x, _npi.cosh, _np.cosh, out=out, **kwargs)
+
+
+@set_module('mxnet.symbol.numpy')
+def log10(x, out=None, **kwargs):
+    """Return the base 10 logarithm of the input array, element-wise.
+
+    Parameters
+    ----------
+    x : _Symbol or scalar
+        Input array or scalar.
+    out : _Symbol or None
+        Dummy parameter to keep the consistency with the ndarray counterpart.
+
+    Returns
+    -------
+    y : _Symbol or scalar
+        The logarithm to the base 10 of `x`, element-wise. NaNs are
+        returned where x is negative. This is a scalar if `x` is a scalar.
+
+    Notes
+    -----
+    This function only supports input type of float.
+    """
+    return _unary_func_helper(x, _npi.log10, _np.log10, out=out, **kwargs)
+
+
+@set_module('mxnet.symbol.numpy')
+def sqrt(x, out=None, **kwargs):
+    """
+    Return the non-negative square-root of an array, element-wise.
+
+    Parameters
+    ----------
+    x : _Symbol or scalar
+        The values whose square-roots are required.
+    out : _Symbol, or None, optional
+        Dummy parameter to keep the consistency with the ndarray counterpart.
+
+    Returns
+    -------
+    y : _Symbol or scalar
+        An array of the same shape as `x`, containing the positive
+        square-root of each element in `x`. This is a scalar if `x` is a scalar.
+
+    Notes
+    -----
+    This function only supports input type of float.
+    """
+    return _unary_func_helper(x, _npi.sqrt, _np.sqrt, out=out, **kwargs)
+
+
 _set_np_symbol_class(_Symbol)
diff --git a/python/mxnet/symbol/numpy/linalg.py b/python/mxnet/symbol/numpy/linalg.py
index 2cb0d22..d1918ef 100644
--- a/python/mxnet/symbol/numpy/linalg.py
+++ b/python/mxnet/symbol/numpy/linalg.py
@@ -18,7 +18,8 @@
 """Namespace for operators used in Gluon dispatched by F=symbol."""
 
 from __future__ import absolute_import
-from . import _op as _mx_nd_np
+from . import _symbol
+from . import _op as _mx_sym_np
 
 __all__ = ['norm']
 
@@ -64,4 +65,4 @@ def norm(x, ord=None, axis=None, keepdims=False):
     if isinstance(axis, tuple) and len(axis) > 2:
         raise ValueError('Improper number of dimensions to norm')
     # TODO(junwu): When ord = 'fro', axis = None, and x.ndim > 2, raise exception
-    return _mx_nd_np.sqrt(_mx_nd_np.sum(x * x, axis=axis, keepdims=keepdims))
+    return _symbol.sqrt(_mx_sym_np.sum(x * x, axis=axis, keepdims=keepdims))
diff --git a/python/mxnet/symbol/register.py b/python/mxnet/symbol/register.py
index 365a088..a17dd79 100644
--- a/python/mxnet/symbol/register.py
+++ b/python/mxnet/symbol/register.py
@@ -49,9 +49,11 @@ def _verify_np_symbol(op_name, func_name, sym):
         raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. '
                         'This is a numpy operator which can only accept '
                         'MXNet numpy ndarrays, while received a legacy ndarray. '
-                        'Please call `as_np_ndarray()` upon the legacy ndarray to '
-                        'convert it to an MXNet numpy ndarray, and then feed the converted '
-                        'array to this operator.'
+                        'Please ensure that you have activated numpy semantics by calling '
+                        '`npx.set_np()` in your code. If you still see this error with numpy '
+                        'semantics activated, please call `as_np_ndarray()` upon the legacy '
+                        'ndarray to convert it to an MXNet numpy ndarray, and then feed the '
+                        'converted array to this operator.'
                         .format(op_name, func_name))
 
 
diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h
index c76b596..3e28f0a 100644
--- a/src/operator/numpy/np_broadcast_reduce_op.h
+++ b/src/operator/numpy/np_broadcast_reduce_op.h
@@ -289,10 +289,10 @@ inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
 
 template<typename xpu, typename OP>
 void NumpyMaxBackward(const nnvm::NodeAttrs& attrs,
-                                const OpContext& ctx,
-                                const std::vector<TBlob>& inputs,
-                                const std::vector<OpReqType>& req,
-                                const std::vector<TBlob>& outputs) {
+                      const OpContext& ctx,
+                      const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs) {
   using namespace mshadow;
   using namespace mshadow::expr;
   const NumpyMaxParam& param = nnvm::get<NumpyMaxParam>(attrs.parsed);
@@ -305,6 +305,65 @@ void NumpyMaxBackward(const nnvm::NodeAttrs& attrs,
   ReduceAxesBackwardUseInOutImpl<xpu, OP, false>(ctx, small, inputs, req, 
outputs);
 }
 
+template<typename xpu, typename OP, bool normalize = false>
+void NumpyReduceAxesBackwardUseInOut(const nnvm::NodeAttrs& attrs,
+                                     const OpContext& ctx,
+                                     const std::vector<TBlob>& inputs,
+                                     const std::vector<OpReqType>& req,
+                                     const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  const NumpyReduceAxesParam& param = 
nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
+  TShape small;
+  if (param.keepdims) {
+    small = inputs[0].shape_;
+  } else {
+    small = NumpyReduceAxesShapeImpl(outputs[0].shape_, param.axis, true);
+  }
+  ReduceAxesBackwardUseInOutImpl<xpu, OP, normalize>(ctx, small, inputs, req, 
outputs);
+}
+
+template<typename xpu>
+void NumpyBroadcastToForward(const nnvm::NodeAttrs& attrs,
+                             const OpContext& ctx,
+                             const std::vector<TBlob>& inputs,
+                             const std::vector<OpReqType>& req,
+                             const std::vector<TBlob>& outputs) {
+  if (outputs[0].shape_.Size() == 0U) return;  // zero-size tensor
+  TShape expanded_ishape(outputs[0].shape_.ndim(), 1);
+  const TShape& ishape = inputs[0].shape_;
+  CHECK_LE(ishape.ndim(), expanded_ishape.ndim()) << "output ndim cannot be 
less than input ndim";
+  const int ndim_delta = expanded_ishape.ndim() - ishape.ndim();
+  for (int i = 0; i < ishape.ndim(); ++i) {
+    expanded_ishape[i + ndim_delta] = ishape[i];
+  }
+  BroadcastComputeImpl<xpu>(attrs, ctx, {inputs[0].reshape(expanded_ishape)},
+                            req, outputs, expanded_ishape);
+}
+
+template<typename xpu>
+void NumpyBroadcastToBackward(const nnvm::NodeAttrs& attrs,
+                              const OpContext& ctx,
+                              const std::vector<TBlob>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<TBlob>& outputs) {
+  TShape expanded_igrad_shape(inputs[0].shape_.ndim(), 1);
+  const TShape& igrad_shape = outputs[0].shape_;
+  CHECK_LE(igrad_shape.ndim(), expanded_igrad_shape.ndim())
+      << "output ndim cannot be less than input ndim";
+  const int ndim_delta = expanded_igrad_shape.ndim() - igrad_shape.ndim();
+  for (int i = 0; i < igrad_shape.ndim(); ++i) {
+    expanded_igrad_shape[i + ndim_delta] = igrad_shape[i];
+  }
+  if (NeedSafeAcc<true>(inputs[0].type_flag_, outputs[0].type_flag_)) {
+    ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(
+        ctx, inputs, req, {outputs[0].reshape(expanded_igrad_shape)}, 
expanded_igrad_shape);
+  } else {
+    ReduceAxesComputeImpl<xpu, mshadow_op::sum, false>(
+        ctx, inputs, req, {outputs[0].reshape(expanded_igrad_shape)}, 
expanded_igrad_shape);
+  }
+}
+
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc 
b/src/operator/numpy/np_broadcast_reduce_op_value.cc
index 168fe59..d8234c5 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cc
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -103,7 +103,6 @@ inline bool NumpyMeanType(const nnvm::NodeAttrs& attrs,
 }
 
 NNVM_REGISTER_OP(_np_mean)
-.describe(R"code()code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<NumpyReduceAxesParam>)
@@ -141,7 +140,7 @@ inline bool NumpyMaxType(const nnvm::NodeAttrs& attrs,
 }
 
 NNVM_REGISTER_OP(_np_max)
-.describe(R"code()code" ADD_FILELINE)
+.add_alias("_np_amax")
 .set_num_inputs(1)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<NumpyMaxParam>)
@@ -167,5 +166,77 @@ NNVM_REGISTER_OP(_backward_np_max)
 .set_num_inputs(3)
 .set_attr<FCompute>("FCompute<cpu>", NumpyMaxBackward<cpu, mshadow_op::eq>);
 
+NNVM_REGISTER_OP(_np_prod)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyReduceAxesParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxesShape)
+.set_attr<nnvm::FInferType>("FInferType", NumpySumType)
+.add_arguments(NumpyReduceAxesParam::__FIELDS__())
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a"};
+  })
+.add_argument("a", "NDArray-or-Symbol", "The input")
+.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesCompute<cpu, 
mshadow_op::product, true>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<nnvm::FGradient>("FGradient", ReduceGrad{"_backward_np_prod"});
+
+NNVM_REGISTER_OP(_backward_np_prod)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyReduceAxesParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesBackwardUseInOut<cpu, 
mshadow_op::rdiv>);
+
+bool NumpyBroadcastToShape(const nnvm::NodeAttrs& attrs,
+                           mxnet::ShapeVector *in_attrs,
+                           mxnet::ShapeVector *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  mxnet::TShape& ishape = (*in_attrs)[0];
+  if (!mxnet::shape_is_known(ishape)) return false;
+  const BroadcastToParam& param = nnvm::get<BroadcastToParam>(attrs.parsed);
+  CHECK(mxnet::shape_is_known(param.shape))
+      << "the objective shape for broadcasting array must be known";
+  CHECK_LE(ishape.ndim(), param.shape.ndim())
+      << "shape " << ishape << " is not broadcastable to " << param.shape;
+  for (int i = param.shape.ndim() - 1; i >= 0; --i) {
+    int j = i - param.shape.ndim() + ishape.ndim();
+    if (j < 0) break;
+    CHECK(ishape[j] == param.shape[i] || ishape[j] == 1)
+        << "shape " << ishape << " is not broadcastable to " << param.shape;
+  }
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, param.shape);
+  return true;
+}
+
+NNVM_REGISTER_OP(_np_broadcast_to)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FGradient>("FGradient",
+  [](const nnvm::NodePtr& n,
+     const std::vector<nnvm::NodeEntry>& ograds) {
+    return MakeNonlossGradNode("_backward_np_broadcast_to", n, ograds, {}, 
n->attrs.dict);
+  })
+.add_argument("array", "NDArray-or-Symbol", "The input")
+.set_attr_parser(ParamParser<BroadcastToParam>)
+.add_arguments(BroadcastToParam::__FIELDS__())
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyBroadcastToShape)
+.set_attr<FCompute>("FCompute<cpu>", NumpyBroadcastToForward<cpu>);
+
+NNVM_REGISTER_OP(_backward_np_broadcast_to)
+.set_attr_parser(ParamParser<BroadcastToParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", NumpyBroadcastToBackward<cpu>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  });
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cu 
b/src/operator/numpy/np_broadcast_reduce_op_value.cu
index 49bef09..a0a6472 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cu
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cu
@@ -45,5 +45,17 @@ NNVM_REGISTER_OP(_np_max)
 NNVM_REGISTER_OP(_backward_np_max)
 .set_attr<FCompute>("FCompute<gpu>", NumpyMaxBackward<gpu, mshadow_op::eq>);
 
+NNVM_REGISTER_OP(_np_prod)
+.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesCompute<gpu, 
mshadow_op::product, true>);
+
+NNVM_REGISTER_OP(_backward_np_prod)
+.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBackwardUseInOut<gpu, 
mshadow_op::rdiv>);
+
+NNVM_REGISTER_OP(_np_broadcast_to)
+.set_attr<FCompute>("FCompute<gpu>", NumpyBroadcastToForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_np_broadcast_to)
+.set_attr<FCompute>("FCompute<gpu>", NumpyBroadcastToBackward<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cc 
b/src/operator/numpy/np_elemwise_unary_op_basic.cc
index 1acec6f..4932ee8 100644
--- a/src/operator/numpy/np_elemwise_unary_op_basic.cc
+++ b/src/operator/numpy/np_elemwise_unary_op_basic.cc
@@ -175,7 +175,7 @@ Example::
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_square"});
 
 // sqrt
-MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_np_sqrt, "x", mshadow_op::square_root)
+MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_sqrt, "x", mshadow_op::square_root)
 .describe(R"code(Return the non-negative square-root of an array, element-wise.
 Example::
    sqrt([4, 9, 16]) = [2, 3, 4]
@@ -220,7 +220,7 @@ The natural logarithm is logarithm in base *e*, so that 
``log(exp(x)) = x``
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_log"});
 
 // log10
-MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_np_log10, "x", mshadow_op::log10)
+MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_log10, "x", mshadow_op::log10)
 .describe(R"code(Returns element-wise Base-10 logarithmic value of the input.
 ``10**log10(x) = x``
 )code" ADD_FILELINE)
@@ -255,7 +255,7 @@ Example::
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
 
 // sin
-MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_np_sin, "x", mshadow_op::sin)
+MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_sin, "x", mshadow_op::sin)
 .describe(R"code(Trigonometric sine, element-wise.
 .. math::
    sin([0, \pi/4, \pi/2]) = [0, 0.707, 1]
@@ -263,7 +263,7 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_np_sin, "x", 
mshadow_op::sin)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_sin" });
 
 // cos
-MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_np_cos, "x", mshadow_op::cos)
+MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_cos, "x", mshadow_op::cos)
 .describe(R"code(Computes the element-wise cosine of the input array.
 .. math::
    cos([0, \pi/4, \pi/2]) = [1, 0.707, 0]
@@ -322,7 +322,7 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_np_radians, "x", 
mshadow_op::radians)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_radians" 
});
 
 // sinh
-MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_np_sinh, "x", mshadow_op::sinh)
+MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_sinh, "x", mshadow_op::sinh)
 .describe(R"code(Returns the hyperbolic sine of the input array, computed 
element-wise.
 .. math::
    sinh(x) = 0.5\times(exp(x) - exp(-x))
@@ -330,7 +330,7 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_np_sinh, "x", 
mshadow_op::sinh)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_sinh" });
 
 // cosh
-MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_np_cosh, "x", mshadow_op::cosh)
+MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_cosh, "x", mshadow_op::cosh)
 .describe(R"code(Returns the hyperbolic cosine  of the input array, computed 
element-wise.
 .. math::
    cosh(x) = 0.5\times(exp(x) + exp(-x))
diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cu 
b/src/operator/numpy/np_elemwise_unary_op_basic.cu
index 1323768..887c74e 100644
--- a/src/operator/numpy/np_elemwise_unary_op_basic.cu
+++ b/src/operator/numpy/np_elemwise_unary_op_basic.cu
@@ -59,7 +59,7 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_fix, 
mshadow_op::fix);
 
 MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_square, mshadow_op::square);
 
-MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_sqrt, mshadow_op::square_root);
+MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_sqrt, mshadow_op::square_root);
 
 MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_cbrt, mshadow_op::cube_root);
 
@@ -68,7 +68,7 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_exp, 
mshadow_op::exp);
 NNVM_REGISTER_OP(_np_log)
 .set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::log>);
 
-MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_log10, mshadow_op::log10);
+MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_log10, mshadow_op::log10);
 
 MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_log2, mshadow_op::log2);
 
@@ -78,9 +78,9 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_expm1, 
mshadow_op::expm1);
 
 MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_logical_not, mshadow_op::nt);
 
-MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_sin, mshadow_op::sin);
+MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_sin, mshadow_op::sin);
 
-MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_cos, mshadow_op::cos);
+MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_cos, mshadow_op::cos);
 
 MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_tan, mshadow_op::tan);
 
@@ -94,9 +94,9 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_degrees, 
mshadow_op::degrees);
 
 MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_radians, mshadow_op::radians);
 
-MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_sinh, mshadow_op::sinh);
+MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_sinh, mshadow_op::sinh);
 
-MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_cosh, mshadow_op::cosh);
+MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_cosh, mshadow_op::cosh);
 
 MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_tanh, mshadow_op::tanh);
 
diff --git a/src/operator/tensor/broadcast_reduce_op.h 
b/src/operator/tensor/broadcast_reduce_op.h
index cba9821..07ce716 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -946,36 +946,36 @@ void ReduceAxesBackwardUseInOutImpl(const OpContext& ctx,
         }
       }
       if (dst_shape.ndim() == 2) {
-        Tensor<xpu, 2, DType> igrad =
-          outputs[0].get_with_shape<xpu, 2, DType>(src_shape.get<2>(), s);
-        Tensor<xpu, 2, OType> ograd =
-          inputs[0].get_with_shape<xpu, 2, OType>(dst_shape.get<2>(), s);
-        Tensor<xpu, 2, DType> data =
-          inputs[1].get_with_shape<xpu, 2, DType>(src_shape.get<2>(), s);
-        Tensor<xpu, 2, OType> out =
-          inputs[2].get_with_shape<xpu, 2, OType>(dst_shape.get<2>(), s);
+        Tensor<xpu, 2, OType> igrad =
+          outputs[0].get_with_shape<xpu, 2, OType>(src_shape.get<2>(), s);
+        Tensor<xpu, 2, DType> ograd =
+          inputs[0].get_with_shape<xpu, 2, DType>(dst_shape.get<2>(), s);
+        Tensor<xpu, 2, OType> data =
+          inputs[1].get_with_shape<xpu, 2, OType>(src_shape.get<2>(), s);
+        Tensor<xpu, 2, DType> out =
+          inputs[2].get_with_shape<xpu, 2, DType>(dst_shape.get<2>(), s);
         MXNET_REQ_TYPE_SWITCH(req[0], Req, {
           Kernel<reduce_axes_backward_broadcast<Req, OP>, xpu>::Launch(
             s, outputs[0].shape_.Size(), data.dptr_, out.dptr_, igrad.dptr_, 
ograd.dptr_,
             in_shape, out_shape, src_shape.ndim());
         });
-        if (normalize) igrad /= 
scalar<DType>(src_shape.Size()/dst_shape.Size());
+        if (normalize) igrad /= 
scalar<OType>(src_shape.Size()/dst_shape.Size());
       } else {
         const int ndim = MXNET_SPECIAL_MAX_NDIM;
-        Tensor<xpu, ndim, DType> igrad =
-          outputs[0].get_with_shape<xpu, ndim, DType>(src_shape.get<ndim>(), 
s);
-        Tensor<xpu, ndim, OType> ograd =
-          inputs[0].get_with_shape<xpu, ndim, OType>(dst_shape.get<ndim>(), s);
-        Tensor<xpu, ndim, DType> data =
-          inputs[1].get_with_shape<xpu, ndim, DType>(src_shape.get<ndim>(), s);
-        Tensor<xpu, ndim, OType> out =
-          inputs[2].get_with_shape<xpu, ndim, OType>(dst_shape.get<ndim>(), s);
+        Tensor<xpu, ndim, OType> igrad =
+          outputs[0].get_with_shape<xpu, ndim, OType>(src_shape.get<ndim>(), 
s);
+        Tensor<xpu, ndim, DType> ograd =
+          inputs[0].get_with_shape<xpu, ndim, DType>(dst_shape.get<ndim>(), s);
+        Tensor<xpu, ndim, OType> data =
+          inputs[1].get_with_shape<xpu, ndim, OType>(src_shape.get<ndim>(), s);
+        Tensor<xpu, ndim, DType> out =
+          inputs[2].get_with_shape<xpu, ndim, DType>(dst_shape.get<ndim>(), s);
         MXNET_REQ_TYPE_SWITCH(req[0], Req, {
           Kernel<reduce_axes_backward_broadcast<Req, OP>, xpu>::Launch(
             s, outputs[0].shape_.Size(), data.dptr_, out.dptr_, igrad.dptr_, 
ograd.dptr_,
             in_shape, out_shape, src_shape.ndim());
         });
-        if (normalize) igrad /= 
scalar<DType>(src_shape.Size()/dst_shape.Size());
+        if (normalize) igrad /= 
scalar<OType>(src_shape.Size()/dst_shape.Size());
       }
     });
   });
diff --git a/tests/python/unittest/test_numpy_ndarray.py 
b/tests/python/unittest/test_numpy_ndarray.py
index e6e4911..c5a9279 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -636,8 +636,8 @@ def test_np_save_load_ndarrays():
     for i, arr in enumerate(array_list):
         with TemporaryDirectory() as work_dir:
             fname = os.path.join(work_dir, 'dataset.npy')
-            np.save(fname, arr)
-            arr_loaded = np.load(fname)
+            npx.save(fname, arr)
+            arr_loaded = npx.load(fname)
             assert isinstance(arr_loaded, list)
             assert len(arr_loaded) == 1
             assert _np.array_equal(arr_loaded[0].asnumpy(), 
array_list[i].asnumpy())
@@ -645,7 +645,7 @@ def test_np_save_load_ndarrays():
     # test save/load a list of ndarrays
     with TemporaryDirectory() as work_dir:
         fname = os.path.join(work_dir, 'dataset.npy')
-        np.save(fname, array_list)
+        npx.save(fname, array_list)
         array_list_loaded = mx.nd.load(fname)
         assert isinstance(arr_loaded, list)
         assert len(array_list) == len(array_list_loaded)
@@ -660,8 +660,8 @@ def test_np_save_load_ndarrays():
         arr_dict[k] = v
     with TemporaryDirectory() as work_dir:
         fname = os.path.join(work_dir, 'dataset.npy')
-        np.save(fname, arr_dict)
-        arr_dict_loaded = np.load(fname)
+        npx.save(fname, arr_dict)
+        arr_dict_loaded = npx.load(fname)
         assert isinstance(arr_dict_loaded, dict)
         assert len(arr_dict_loaded) == len(arr_dict)
         for k, v in arr_dict_loaded.items():
diff --git a/tests/python/unittest/test_numpy_op.py 
b/tests/python/unittest/test_numpy_op.py
index 7a43083..ac1da8c 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -356,7 +356,7 @@ def test_npx_sigmoid():
 def test_np_reshape():
     # TODO(junwu): Add more test cases
     data = mx.sym.var('a').as_np_ndarray()
-    ret = data.reshape(shape=())
+    ret = data.reshape(())
     assert type(ret) == mx.sym.np._Symbol
 
     data = np.ones((1, 1, 1))
@@ -365,6 +365,8 @@ def test_np_reshape():
     ret = np.reshape(ret, (1, 1, 1, 1))
     assert ret.shape == (1, 1, 1, 1)
     assert type(ret) == np.ndarray
+    ret2 = ret.reshape(1, 1, -1)
+    assert ret2.shape == (1, 1, 1)
 
 
 @with_seed()
@@ -1060,6 +1062,106 @@ def test_np_tile():
             assert same(ret_mx.asnumpy(), ret_np)
 
 
+@with_seed()
+@npx.use_np_shape
+def test_np_prod():
+    class TestProd(HybridBlock):
+        def __init__(self, axis=None, dtype=None, keepdims=False):
+            super(TestProd, self).__init__()
+            self._axis = axis
+            self._dtype = dtype
+            self._keepdims = keepdims
+
+        def hybrid_forward(self, F, a, *args, **kwargs):
+            return F.np.prod(a, axis=self._axis, dtype=self._dtype, 
keepdims=self._keepdims)
+
+    in_data_dim = random.choice([3, 4])
+    shape = rand_shape_nd(in_data_dim, dim=3)
+    for hybridize in [False, True]:
+        for keepdims in [True, False]:
+            for axis in ([i for i in range(in_data_dim)] + [(), None]):
+                for itype in ['float32', 'float64']:
+                    for dtype in ['float32', 'float64']:
+                        # test gluon
+                        test_prod = TestProd(axis=axis, dtype=dtype, 
keepdims=keepdims)
+                        if hybridize:
+                            test_prod.hybridize()
+                        x = np.random.uniform(-2.0, 2.0, size=shape, 
dtype=itype)
+                        x.attach_grad()
+                        print(x.grad.dtype)
+                        expected_ret = _np.prod(x.asnumpy(), axis=axis, 
keepdims=keepdims)
+                        expected_ret = expected_ret.astype(dtype)
+                        with mx.autograd.record():
+                            y = test_prod(x)
+                        assert y.shape == expected_ret.shape
+                        assert_almost_equal(y.asnumpy(), expected_ret, 
rtol=1e-3, atol=1e-5)
+                        y.backward()
+                        # use keepdims=True so that broadcast divide can be 
used to calculate
+                        # grad of input
+                        expected_ret = _np.prod(x.asnumpy(), axis=axis, 
keepdims=True)
+                        assert_almost_equal(x.grad.asnumpy(), expected_ret / 
x.asnumpy(), rtol=1e-3, atol=1e-3)
+
+                        # test numeric
+                        if itype == 'float32' and dtype == 'float32':
+                            x_sym = mx.sym.Variable("x").as_np_ndarray()
+                            mx_sym = mx.sym.np.prod(x_sym, axis=axis, 
dtype=dtype, keepdims=keepdims).as_nd_ndarray()
+                            check_numeric_gradient(mx_sym, [x.as_nd_ndarray()],
+                                                   numeric_eps=1e-3, 
rtol=1e-3, atol=1e-4, dtype=_np.float32)
+
+                        # test imperative
+                        mx_out = np.prod(x, axis=axis, dtype=dtype, 
keepdims=keepdims)
+                        np_out = _np.prod(x.asnumpy(), axis=axis, 
keepdims=keepdims).astype(dtype)
+                        assert_almost_equal(mx_out.asnumpy(), np_out, 
rtol=1e-3, atol=1e-5)
+
+
+@with_seed()
+@npx.use_np
+def test_np_flatten():
+    # TODO(junwu): Add more test cases
+    shapes = [(), (2, 0, 1), (3, 4, 5), 6]
+    for shape in shapes:
+        a = _np.random.uniform(size=shape).astype('float32')
+        a_mx = np.array(a, dtype=a.dtype)
+        expected_ret = a.flatten()
+        ret_mx = a_mx.flatten()
+        assert same(expected_ret, ret_mx.asnumpy())
+
+
+@with_seed()
+@npx.use_np
+def test_np_broadcast_to():
+    # TODO(junwu): Add more test cases and backward test
+    shapes = [(1, 2, 3, 4, 5), (1, 0, 3, 4, 5)]
+    for shape in shapes:
+        a = _np.random.uniform(size=(4, 1)).astype('float32')
+        a_mx = np.array(a, dtype=a.dtype)
+        expected_ret = _np.broadcast_to(a, shape)
+        ret_mx = np.broadcast_to(a_mx, shape)
+        assert same(expected_ret, ret_mx.asnumpy())
+
+
+@with_seed()
+@npx.use_np
+def test_np_meshgrid():
+    nx, ny = (4, 5)
+    x = np.linspace(0, 1, nx)
+    y = np.linspace(0, 1, ny)
+    z = np.ones(())
+    xv, yv, zv = np.meshgrid(x, y, z)
+    xv_expected, yv_expected, zv_expected = _np.meshgrid(x.asnumpy(), 
y.asnumpy(), z.asnumpy())
+    assert same(xv.asnumpy(), xv_expected)
+    assert same(yv.asnumpy(), yv_expected)
+    assert same(zv.asnumpy(), zv_expected)
+    # TODO(junwu): Add more test
+
+
+@with_seed()
+@npx.use_np
+def test_np_broadcast_arrays():
+    # TODO(junwu): Add test
+    pass
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

Reply via email to