This is an automated email from the ASF dual-hosted git repository.

reminisce pushed a commit to branch numpy_staging_prs
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/numpy_staging_prs by this push:
     new 6e2a58d  numpy operator arctan2 (#15890)
6e2a58d is described below

commit 6e2a58d917ad3b42d29a1c973ce7cb857c0e52d4
Author: tingying <tingying2...@u.northwestern.edu>
AuthorDate: Mon Sep 23 13:38:37 2019 +0800

    numpy operator arctan2 (#15890)
    
    * change the test code
    
    * add @use_np in test code
    
    * only supports float16, float32 and float64.
    
    * fix format error
    
    * remove redundant backslash
    
    * change wrapper in symbol
    
    * delete gpu test
    
    * edit test
    
    * change infer type
    
    * remove redundant **kwargs
    
    * change atol and rtol in test
    
    * edit test shape
---
 python/mxnet/ndarray/numpy/_op.py              | 92 +++++++++++++++++++++++++-
 python/mxnet/numpy/multiarray.py               | 91 ++++++++++++++++++++++++-
 python/mxnet/symbol/numpy/_symbol.py           | 72 +++++++++++++++++++-
 src/operator/math_functions-inl.h              |  2 +
 src/operator/mshadow_op.h                      | 10 +++
 src/operator/numpy/np_elemwise_broadcast_op.cc | 70 ++++++++++++++++++++
 src/operator/numpy/np_elemwise_broadcast_op.cu | 19 ++++++
 src/operator/operator_tune.cc                  |  5 ++
 tests/python/unittest/test_numpy_op.py         | 67 +++++++++++++++++++
 9 files changed, 420 insertions(+), 8 deletions(-)

diff --git a/python/mxnet/ndarray/numpy/_op.py 
b/python/mxnet/ndarray/numpy/_op.py
index 2cdfff1..197bae6 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -28,9 +28,9 @@ from . import _internal as _npi
 from ..ndarray import NDArray
 
 __all__ = ['zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', 
'mod', 'remainder', 'power',
-           'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 
'cbrt', 'abs', 'absolute',
-           'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 
'degrees', 'log2', 'log1p',
-           'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 
'ceil', 'floor',
+           'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 
'sqrt', 'cbrt', 'abs',
+           'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 
'log', 'degrees', 'log2',
+           'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 
'fix', 'ceil', 'floor',
            'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 
'tensordot',
            'linspace', 'expand_dims', 'tile', 'arange', 'split', 
'concatenate', 'stack', 'vstack', 'mean',
            'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 
'indices', 'copysign',
@@ -2953,3 +2953,89 @@ def around(x, decimals=0, out=None, **kwargs):
         return _npi.around(x, decimals, out=out, **kwargs)
     else:
         raise TypeError('type {} not supported'.format(str(type(x))))
+
+
+@set_module('mxnet.ndarray.numpy')
+def arctan2(x1, x2, out=None):
+    r"""
+    arctan2(x1, x2, out=None)
+
+    Element-wise arc tangent of ``x1/x2`` choosing the quadrant correctly.
+
+    The quadrant (i.e., branch) is chosen so that ``arctan2(x1, x2)`` is
+    the signed angle in radians between the ray ending at the origin and
+    passing through the point (1,0), and the ray ending at the origin and
+    passing through the point (`x2`, `x1`).  (Note the role reversal: the
+    "`y`-coordinate" is the first function parameter, the "`x`-coordinate"
+    is the second.)  By IEEE convention, this function is defined for
+    `x2` = +/-0 and for either or both of `x1` and `x2` = +/-inf (see
+    Notes for specific values).
+
+    This function is not defined for complex-valued arguments; for the
+    so-called argument of complex values, use `angle`.
+
+    Parameters
+    ----------
+    x1 : ndarray or scalar
+        `y`-coordinates.
+    x2 : ndarray or scalar
+        `x`-coordinates. `x2` must be broadcastable to match the shape of
+        `x1` or vice versa.
+    out : ndarray or None, optional
+        A location into which the result is stored. If provided, it must have
+        a shape that the inputs broadcast to. If not provided or `None`,
+        a freshly-allocated array is returned.
+
+    Returns
+    -------
+    out : ndarray or scalar
+        Array of angles in radians, in the range ``[-pi, pi]``. This is a 
scalar if
+        `x1` and `x2` are scalars.
+
+    Notes
+    -----
+    *arctan2* is identical to the `atan2` function of the underlying
+    C library.  The following special values are defined in the C
+    standard: [1]_
+
+    ====== ====== ================
+    `x1`   `x2`   `arctan2(x1,x2)`
+    ====== ====== ================
+    +/- 0  +0     +/- 0
+    +/- 0  -0     +/- pi
+     > 0   +/-inf +0 / +pi
+     < 0   +/-inf -0 / -pi
+    +/-inf +inf   +/- (pi/4)
+    +/-inf -inf   +/- (3*pi/4)
+    ====== ====== ================
+
+    Note that +0 and -0 are distinct floating point numbers, as are +inf
+    and -inf.
+
+    This function differs from the original numpy.arctan2 in the following 
aspects:
+        - Only supports float16, float32 and float64 inputs.
+
+    References
+    ----------
+    .. [1] ISO/IEC standard 9899:1999, "Programming language C."
+
+    Examples
+    --------
+    Consider four points in different quadrants:
+
+    >>> x = np.array([-1, +1, +1, -1])
+    >>> y = np.array([-1, -1, +1, +1])
+    >>> np.arctan2(y, x) * 180 / np.pi
+    array([-135.,  -45.,   45.,  135.])
+
+    Note the order of the parameters. `arctan2` is defined also when `x2` = 0
+    and at several other special points, obtaining values in
+    the range ``[-pi, pi]``:
+
+    >>> x = np.array([1, -1])
+    >>> y = np.array([0, 0])
+    >>> np.arctan2(x, y)
+    array([ 1.5707964, -1.5707964])
+    """
+    return _ufunc_helper(x1, x2, _npi.arctan2, _np.arctan2,
+                         _npi.arctan2_scalar, _npi.rarctan2_scalar, out=out)
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index e0c7a67..0cd9036 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -47,13 +47,13 @@ from ..ndarray import numpy as _mx_nd_np
 from ..ndarray.numpy import _internal as _npi
 
 __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'full', 'add', 
'subtract', 'multiply', 'divide',
-           'mod', 'remainder', 'power', 'sin', 'cos', 'tan', 'sinh', 'cosh', 
'tanh', 'log10', 'sqrt', 'cbrt',
-           'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 
'sign', 'log',
+           'mod', 'remainder', 'power', 'arctan2', 'sin', 'cos', 'tan', 
'sinh', 'cosh', 'tanh', 'log10',
+           'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 
'arccos', 'arctan', 'sign', 'log',
            'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 
'square', 'negative',
            'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 
'arccosh', 'arctanh',
            'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 
'concatenate',
            'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 
'clip', 'argmax', 'std', 'var', 'indices',
-           'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 
'around']
+           'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 
'around']
 
 # Return code for dispatching indexing function call
 _NDARRAY_UNSUPPORTED_INDEXING = -1
@@ -4481,3 +4481,88 @@ def around(x, decimals=0, out=None, **kwargs):
     array([ 0,  0,  0, 10])
     """
     return _mx_nd_np.around(x, decimals, out=out, **kwargs)
+
+
+@set_module('mxnet.numpy')
+def arctan2(x1, x2, out=None):
+    r"""
+    arctan2(x1, x2, out=None)
+
+    Element-wise arc tangent of ``x1/x2`` choosing the quadrant correctly.
+
+    The quadrant (i.e., branch) is chosen so that ``arctan2(x1, x2)`` is
+    the signed angle in radians between the ray ending at the origin and
+    passing through the point (1,0), and the ray ending at the origin and
+    passing through the point (`x2`, `x1`).  (Note the role reversal: the
+    "`y`-coordinate" is the first function parameter, the "`x`-coordinate"
+    is the second.)  By IEEE convention, this function is defined for
+    `x2` = +/-0 and for either or both of `x1` and `x2` = +/-inf (see
+    Notes for specific values).
+
+    This function is not defined for complex-valued arguments; for the
+    so-called argument of complex values, use `angle`.
+
+    Parameters
+    ----------
+    x1 : ndarray or scalar
+        `y`-coordinates.
+    x2 : ndarray or scalar
+        `x`-coordinates. `x2` must be broadcastable to match the shape of
+        `x1` or vice versa.
+    out : ndarray or None, optional
+        A location into which the result is stored. If provided, it must have
+        a shape that the inputs broadcast to. If not provided or `None`,
+        a freshly-allocated array is returned.
+
+    Returns
+    -------
+    out : ndarray or scalar
+        Array of angles in radians, in the range ``[-pi, pi]``. This is a 
scalar if
+        `x1` and `x2` are scalars.
+
+    Notes
+    -----
+    *arctan2* is identical to the `atan2` function of the underlying
+    C library.  The following special values are defined in the C
+    standard: [1]_
+
+    ====== ====== ================
+    `x1`   `x2`   `arctan2(x1,x2)`
+    ====== ====== ================
+    +/- 0  +0     +/- 0
+    +/- 0  -0     +/- pi
+     > 0   +/-inf +0 / +pi
+     < 0   +/-inf -0 / -pi
+    +/-inf +inf   +/- (pi/4)
+    +/-inf -inf   +/- (3*pi/4)
+    ====== ====== ================
+
+    Note that +0 and -0 are distinct floating point numbers, as are +inf
+    and -inf.
+
+    This function differs from the original numpy.arctan2 in the following 
aspects:
+        - Only supports float16, float32 and float64 inputs.
+
+    References
+    ----------
+    .. [1] ISO/IEC standard 9899:1999, "Programming language C."
+
+    Examples
+    --------
+    Consider four points in different quadrants:
+
+    >>> x = np.array([-1, +1, +1, -1])
+    >>> y = np.array([-1, -1, +1, +1])
+    >>> np.arctan2(y, x) * 180 / np.pi
+    array([-135.,  -45.,   45.,  135.])
+
+    Note the order of the parameters. `arctan2` is defined also when `x2` = 0
+    and at several other special points, obtaining values in
+    the range ``[-pi, pi]``:
+
+    >>> x = np.array([1, -1])
+    >>> y = np.array([0, 0])
+    >>> np.arctan2(x, y)
+    array([ 1.5707964, -1.5707964])
+    """
+    return _mx_nd_np.arctan2(x1, x2, out=out)
diff --git a/python/mxnet/symbol/numpy/_symbol.py 
b/python/mxnet/symbol/numpy/_symbol.py
index 1bfbba2..94a4a37 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -29,8 +29,8 @@ from ..symbol import Symbol
 from .._internal import _set_np_symbol_class
 from . import _internal as _npi
 
-__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 
'remainder', 'power', 'sin',
-           'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 
'abs', 'absolute', 'exp',
+__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 
'remainder', 'power', 'arctan2',
+           'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 
'cbrt', 'abs', 'absolute', 'exp',
            'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 
'log2', 'log1p',
            'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 
'ceil', 'floor',
            'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 
'tensordot',
@@ -3172,4 +3172,72 @@ def around(x, decimals=0, out=None, **kwargs):
         raise TypeError('type {} not supported'.format(str(type(x))))
 
 
+@set_module('mxnet.symbol.numpy')
+def arctan2(x1, x2, out=None):
+    r"""
+    arctan2(x1, x2, out=None)
+
+    Element-wise arc tangent of ``x1/x2`` choosing the quadrant correctly.
+
+    The quadrant (i.e., branch) is chosen so that ``arctan2(x1, x2)`` is
+    the signed angle in radians between the ray ending at the origin and
+    passing through the point (1,0), and the ray ending at the origin and
+    passing through the point (`x2`, `x1`).  (Note the role reversal: the
+    "`y`-coordinate" is the first function parameter, the "`x`-coordinate"
+    is the second.)  By IEEE convention, this function is defined for
+    `x2` = +/-0 and for either or both of `x1` and `x2` = +/-inf (see
+    Notes for specific values).
+
+    This function is not defined for complex-valued arguments; for the
+    so-called argument of complex values, use `angle`.
+
+    Parameters
+    ----------
+    x1 : _Symbol or scalar
+        `y`-coordinates.
+    x2 : _Symbol or scalar
+        `x`-coordinates. `x2` must be broadcastable to match the shape of
+        `x1` or vice versa.
+    out : _Symbol or None, optional
+        A location into which the result is stored. If provided, it must have
+        a shape that the inputs broadcast to. If not provided or `None`,
+        a freshly-allocated array is returned.
+
+    Returns
+    -------
+    out : _Symbol or scalar
+        Array of angles in radians, in the range ``[-pi, pi]``. This is a 
scalar if
+        `x1` and `x2` are scalars.
+
+    Notes
+    -----
+    *arctan2* is identical to the `atan2` function of the underlying
+    C library.  The following special values are defined in the C
+    standard: [1]_
+
+    ====== ====== ================
+    `x1`   `x2`   `arctan2(x1,x2)`
+    ====== ====== ================
+    +/- 0  +0     +/- 0
+    +/- 0  -0     +/- pi
+     > 0   +/-inf +0 / +pi
+     < 0   +/-inf -0 / -pi
+    +/-inf +inf   +/- (pi/4)
+    +/-inf -inf   +/- (3*pi/4)
+    ====== ====== ================
+
+    Note that +0 and -0 are distinct floating point numbers, as are +inf
+    and -inf.
+
+    This function differs from the original numpy.arctan2 in the following 
aspects:
+        - Only supports float16, float32 and float64 inputs.
+
+    References
+    ----------
+    .. [1] ISO/IEC standard 9899:1999, "Programming language C."
+    """
+    return _ufunc_helper(x1, x2, _npi.arctan2, _np.arctan2,
+                         _npi.arctan2_scalar, _npi.rarctan2_scalar, out=out)
+
+
 _set_np_symbol_class(_Symbol)
diff --git a/src/operator/math_functions-inl.h 
b/src/operator/math_functions-inl.h
index 45d74a6..5f95654 100644
--- a/src/operator/math_functions-inl.h
+++ b/src/operator/math_functions-inl.h
@@ -125,6 +125,8 @@ MXNET_BINARY_MATH_FUNC(hypot)
 
 MXNET_BINARY_MATH_FUNC(pow)
 
+MXNET_BINARY_MATH_FUNC(atan2)  // backs math::atan2, called by mshadow_op::arctan2 / rarctan2
+
 template<typename DType> MSHADOW_XINLINE
 float id(DType a) {
   return static_cast<float>(a);
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index f3d24b2..6261638 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -322,6 +322,16 @@ MXNET_BINARY_MATH_OP(rpower, math::pow(b, a));
 
 MXNET_BINARY_MATH_OP(rpower_grad, math::id(a) * math::log(b));
 
+MXNET_BINARY_MATH_OP(arctan2, math::atan2(a, b));
+
+MXNET_BINARY_MATH_OP(arctan2_grad, math::id(b) / (math::id(a * a + b * b)));  // d/da atan2(a, b)
+
+MXNET_BINARY_MATH_OP(arctan2_rgrad, -math::id(a) / (math::id(a * a + b * b)));  // d/db atan2(a, b)
+
+MXNET_BINARY_MATH_OP(rarctan2, math::atan2(b, a));
+
+MXNET_BINARY_MATH_OP(rarctan2_grad, -math::id(b) / (math::id(a * a + b * b)));  // d/da atan2(b, a)
+
 MXNET_UNARY_MATH_OP_NC(nt, a != DType(0) ? DType(0) : DType(1));
 
 MXNET_BINARY_MATH_OP_NC(ge, a >= b ? DType(1) : DType(0));
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc 
b/src/operator/numpy/np_elemwise_broadcast_op.cc
index a9254e8..f9293ee 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -144,5 +144,75 @@ 
MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_rcopysign_scalar)
 .set_attr<FCompute>("FCompute<cpu>",
                     BinaryScalarOp::Backward<cpu, mshadow_op::rcopysign_grad>);
 
+inline bool IsFloatType(const int dtype) {
+  return (dtype == mshadow::kFloat16 ||
+          dtype == mshadow::kFloat32 ||
+          dtype == mshadow::kFloat64);
+}
+
+inline bool Arctan2OpType(const nnvm::NodeAttrs& attrs,
+                          std::vector<int>* in_attrs,
+                          std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1));
+  TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+  TYPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0));
+  if (out_attrs->at(0) == -1) return false;  // dtype not inferred yet; do not fail
+  CHECK(IsFloatType(out_attrs->at(0))) << "Do not support `int` as input.\n";
+  return true;
+}
+
+NNVM_REGISTER_OP(_npi_arctan2)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"x1", "x2"};
+  })
+.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
+.set_attr<nnvm::FInferType>("FInferType", Arctan2OpType)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, 
mshadow_op::arctan2>)
+.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_npi_arctan2"})
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.add_argument("x1", "NDArray-or-Symbol", "The input array")
+.add_argument("x2", "NDArray-or-Symbol", "The input array");
+
+NNVM_REGISTER_OP(_backward_npi_arctan2)
+.set_num_inputs(3)
+.set_num_outputs(2)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastBackwardUseIn<cpu, 
mshadow_op::arctan2_grad,
+                                                                  
mshadow_op::arctan2_rgrad>);
+
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_arctan2_scalar)
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, 
mshadow_op::arctan2>)
+.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_npi_arctan2_scalar"});
+
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rarctan2_scalar)
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, 
mshadow_op::rarctan2>)
+.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_npi_rarctan2_scalar"});
+
+MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_arctan2_scalar)
+.add_argument("scalar", "float", "scalar value")
+.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = 
std::stod(attrs->dict["scalar"]); })
+.set_attr<FCompute>("FCompute<cpu>",
+                    BinaryScalarOp::Backward<cpu, mshadow_op::arctan2_grad>);
+
+MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rarctan2_scalar)
+.add_argument("scalar", "float", "scalar value")
+.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = 
std::stod(attrs->dict["scalar"]); })
+.set_attr<FCompute>("FCompute<cpu>",
+                    BinaryScalarOp::Backward<cpu, mshadow_op::rarctan2_grad>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu 
b/src/operator/numpy/np_elemwise_broadcast_op.cu
index ecf8e85..ab76e5c 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cu
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cu
@@ -49,6 +49,13 @@ NNVM_REGISTER_OP(_backward_npi_copysign)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, 
mshadow_op::copysign_grad,
                                                                   
mshadow_op::copysign_rgrad>);
 
+NNVM_REGISTER_OP(_npi_arctan2)
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, 
mshadow_op::arctan2>);
+
+NNVM_REGISTER_OP(_backward_npi_arctan2)
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, 
mshadow_op::arctan2_grad,
+                                                                  
mshadow_op::arctan2_rgrad>);
+
 NNVM_REGISTER_OP(_npi_add_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
op::mshadow_op::plus>);
 
@@ -87,5 +94,17 @@ NNVM_REGISTER_OP(_backward_npi_rcopysign_scalar)
 .set_attr<FCompute>("FCompute<gpu>",
                     BinaryScalarOp::Backward<gpu, mshadow_op::rcopysign_grad>);
 
+NNVM_REGISTER_OP(_npi_arctan2_scalar)
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
mshadow_op::arctan2>);
+
+NNVM_REGISTER_OP(_backward_npi_arctan2_scalar)
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<gpu, 
mshadow_op::arctan2_grad>);
+
+NNVM_REGISTER_OP(_npi_rarctan2_scalar)
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
mshadow_op::rarctan2>);
+
+NNVM_REGISTER_OP(_backward_npi_rarctan2_scalar)
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<gpu, 
mshadow_op::rarctan2_grad>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index 5159525..1d64438 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -333,6 +333,11 @@ 
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rcopysign);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::copysign_grad);  // 
NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::copysign_rgrad);  // 
NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rcopysign_grad);  // 
NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arctan2);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rarctan2);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctan2_grad);  // 
NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rarctan2_grad);  // 
NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctan2_rgrad);  // 
NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::xelu_grad); // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gelu_grad); // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::prelu_grad); // NOLINT()
diff --git a/tests/python/unittest/test_numpy_op.py 
b/tests/python/unittest/test_numpy_op.py
index 1324af0..3d30012 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -2380,6 +2380,73 @@ def test_np_around():
                     assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, 
atol=atol)    
 
 
+@with_seed()
+@use_np
+def test_np_arctan2():
+    class TestArctan2(HybridBlock):
+        def __init__(self):
+            super(TestArctan2, self).__init__()
+
+        def hybrid_forward(self, F, x1, x2):
+            return F.np.arctan2(x1, x2)
+
+    # Sum the broadcast-shaped gradient `src` back down to the shape of `des`.
+    def dimReduce(src, des):
+        srcShape = src.shape
+        desShape = des.shape
+        if len(desShape) == 0:
+            return src.sum()
+        redu = []
+        for i, j in zip(range(len(srcShape)-1, -1, -1), range(len(desShape)-1, 
-1, -1)):
+            if srcShape[i] != desShape[j] and desShape[j] == 1:
+                redu.append(i)
+            if j == 0:
+                for k in range(0, i):
+                    redu.append(k)
+                break
+        if len(redu) > 0:
+            src = _np.reshape(src.sum(axis=tuple(redu)), desShape)
+        return src
+
+    types = ['float64', 'float32', 'float16']
+    for hybridize in [True, False]:
+        for shape1, shape2 in [[(3, 2), (3, 2)],  # tall matrices
+                               [(), ()],  # scalar only
+                               [(3, 0, 2), (3, 0, 2)],  # zero-dim
+                               [(3, 4, 5), (4, 1)],  # trailing dim 
broadcasting
+                               [(3, 4, 5), ()],  # scalar broadcasting
+                               [(), (1, 2, 3)],  # scalar broadcasting
+                               ]:
+            for oneType in types:
+                rtol = 1e-2 if oneType == 'float16' else 1e-3
+                atol = 1e-2 if oneType == 'float16' else 1e-5
+                test_arctan2 = TestArctan2()
+                if hybridize:
+                    test_arctan2.hybridize()
+                x1 = rand_ndarray(shape1, dtype=oneType).as_np_ndarray()
+                x2 = rand_ndarray(shape2, dtype=oneType).as_np_ndarray()
+                x11 = x1.asnumpy()
+                x21 = x2.asnumpy()
+                x1.attach_grad()
+                x2.attach_grad()
+                np_out = _np.arctan2(x1.asnumpy(), x2.asnumpy())
+                with mx.autograd.record():
+                    mx_out = test_arctan2(x1, x2)
+                assert mx_out.shape == np_out.shape
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, 
atol=atol)
+                mx_out.backward()
+                np_backward_1 = x21 / (x11 * x11 + x21 * x21)  # d/dx1 arctan2(x1, x2)
+                np_backward_2 = -1 * x11 / (x11 * x11 + x21 * x21)  # d/dx2 arctan2(x1, x2)
+                np_backward_1 = dimReduce(np_backward_1, x11)
+                np_backward_2 = dimReduce(np_backward_2, x21)
+                assert_almost_equal(x1.grad.asnumpy(), np_backward_1, 
rtol=rtol, atol=atol)
+                assert_almost_equal(x2.grad.asnumpy(), np_backward_2, 
rtol=rtol, atol=atol)
+
+                mx_out = np.arctan2(x1, x2)  # NOTE: Python-scalar operand paths untested
+                np_out = _np.arctan2(x1.asnumpy(), x2.asnumpy())
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, 
atol=atol)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

Reply via email to