This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit e080ceb049548b8628e2d756b646ee0bf9072612
Author: Jake Lee <gstu1...@gmail.com>
AuthorDate: Fri Jul 12 03:18:53 2019 -0700

    numpy eye op (#15282)
    
    address the comment
---
 python/mxnet/ndarray/numpy/_op.py      |  33 +++++++++-
 python/mxnet/numpy/multiarray.py       |  29 ++++++++-
 python/mxnet/symbol/numpy/_symbol.py   |  33 +++++++++-
 src/operator/numpy/np_init_op.cc       |  30 ++++-----
 src/operator/numpy/np_init_op.cu       |   5 +-
 src/operator/numpy/np_init_op.h        | 113 +++++++++++++++++++++++++++++++++
 src/operator/tensor/init_op.h          |  40 +++++++-----
 tests/python/unittest/test_numpy_op.py |  69 ++++++++++++++++++++
 8 files changed, 314 insertions(+), 38 deletions(-)

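For orientation, the new front end mirrors numpy.eye. Below is a minimal usage sketch of the imperative API (not part of the patch), patterned on the new unit test and assuming a build that includes this commit:

    from mxnet import np      # MXNet numpy front end; this commit adds np.eye
    import numpy as _np       # official NumPy, used as the reference

    a = np.eye(3, 4, k=1)     # ones on the first super-diagonal
    b = _np.eye(3, 4, k=1)
    assert (a.asnumpy() == b).all()
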
diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 7f710a0..ff0e8c8 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -30,7 +30,7 @@ from ..ndarray import NDArray
 
 __all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax',
            'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
-           'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace',
+           'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace', 'eye',
            'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt', 'abs', 'exp', 'arctan', 'sign', 'log',
            'degrees', 'log2', 'rint', 'radians', 'mean', 'reciprocal', 'square', 'arcsin',
            'argsort']
@@ -997,6 +997,37 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
         return _npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype)
 
 
+@set_module('mxnet.ndarray.numpy')
+def eye(N, M=None, k=0, dtype=_np.float32, **kwargs):
+    """
+    Return a 2-D array with ones on the diagonal and zeros elsewhere.
+
+    Parameters
+    ----------
+    N : int
+        Number of rows in the output.
+    M : int, optional
+        Number of columns in the output. If None, defaults to N.
+    k : int, optional
+        Index of the diagonal: 0 (the default) refers to the main diagonal,
+        a positive value refers to an upper diagonal,
+        and a negative value to a lower diagonal.
+    dtype : data-type, optional
+        Data-type of the returned array.
+
+    Returns
+    -------
+    I : ndarray of shape (N,M)
+        An array where all elements are equal to zero,
+        except for the k-th diagonal, whose values are equal to one.
+    """
+    _sanity_check_params('eye', ['order'], kwargs)
+    ctx = kwargs.pop('ctx', current_context())
+    if ctx is None:
+        ctx = current_context()
+    return _npi.eye(N, M, k, ctx, dtype)
+
+
 def _unary_func_helper(x, fn_array, fn_scalar, out=None, **kwargs):
     """Helper function for unary operators.
 
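A note on the ctx handling above: the imperative implementation pops 'ctx' from **kwargs and falls back to current_context(), so a device can be passed explicitly. A small sketch (illustrative only, assuming this commit is applied):

    import mxnet as mx
    from mxnet import np

    # allocate the identity matrix explicitly on CPU; mx.gpu(0) would work
    # the same way on a GPU build
    ident = np.eye(4, ctx=mx.cpu())
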
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index cafc656..83fcfc1 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -45,7 +45,7 @@ from ..ndarray.numpy import _internal as _npi
 
 __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange',
            'argmax', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
-           'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace', 'sin', 'cos',
+           'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace', 'eye', 'sin', 'cos',
            'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt', 'abs', 'exp', 'arctan', 'sign', 'log',
            'degrees', 'log2', 'rint', 'radians', 'mean', 'reciprocal', 'square', 'arcsin',
            'argsort']
@@ -2144,6 +2144,33 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
 
 
 @set_module('mxnet.numpy')
+def eye(N, M=None, k=0, dtype=_np.float32, **kwargs):
+    """
+    Return a 2-D array with ones on the diagonal and zeros elsewhere.
+
+    Parameters
+    ----------
+    N : int
+        Number of rows in the output.
+    M : int, optional
+        Number of columns in the output. If None, defaults to N.
+    k : int, optional
+        Index of the diagonal: 0 (the default) refers to the main diagonal,
+        a positive value refers to an upper diagonal,
+        and a negative value to a lower diagonal.
+    dtype : data-type, optional
+        Data-type of the returned array.
+
+    Returns
+    -------
+    I : ndarray of shape (N,M)
+        An array where all elements are equal to zero,
+        except for the k-th diagonal, whose values are equal to one.
+    """
+    return _mx_nd_np.eye(N, M, k, dtype, **kwargs)
+
+
+@set_module('mxnet.numpy')
 def sin(x, out=None, **kwargs):
     r"""Trigonometric sine, element-wise.
 
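One behavioral detail worth keeping in mind: as the dtype=_np.float32 default above shows, mxnet.numpy.eye returns float32 unless told otherwise, while official NumPy defaults to float64. A quick comparison (illustrative only):

    from mxnet import np
    import numpy as _np

    print(np.eye(2).dtype)   # float32 (MXNet default per this patch)
    print(_np.eye(2).dtype)  # float64 (official NumPy default)
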
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index fa47d8d..92e0563 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -31,7 +31,7 @@ from . import _internal as _npi
 
 __all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'concatenate', 'arange', 'argmax',
            'clip', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'split', 'swapaxes',
-           'expand_dims', 'tile', 'linspace', 'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt',
+           'expand_dims', 'tile', 'linspace', 'eye', 'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt',
            'abs', 'exp', 'arctan', 'sign', 'log', 'degrees', 'log2', 'rint', 'radians', 'mean',
            'reciprocal', 'square', 'arcsin', 'argsort']
 
@@ -1626,6 +1626,37 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
         return _npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype)
 
 
+@set_module('mxnet.symbol.numpy')
+def eye(N, M=None, k=0, dtype=_np.float32, **kwargs):
+    """
+    Return a 2-D array with ones on the diagonal and zeros elsewhere.
+
+    Parameters
+    ----------
+    N : int
+        Number of rows in the output.
+    M : int, optional
+        Number of columns in the output. If None, defaults to N.
+    k : int, optional
+        Index of the diagonal: 0 (the default) refers to the main diagonal,
+        a positive value refers to an upper diagonal,
+        and a negative value to a lower diagonal.
+    dtype : data-type, optional
+        Data-type of the returned array.
+
+    Returns
+    -------
+    I : ndarray of shape (N,M)
+        An array where all elements are equal to zero,
+        except for the k-th diagonal, whose values are equal to one.
+    """
+    _sanity_check_params('eye', ['order'], kwargs)
+    ctx = kwargs.pop('ctx', current_context())
+    if ctx is None:
+        ctx = current_context()
+    return _npi.eye(N, M, k, ctx, dtype)
+
+
 def _unary_func_helper(x, fn_array, fn_scalar, out=None, **kwargs):
     """Helper function for unary operators.
 
diff --git a/src/operator/numpy/np_init_op.cc b/src/operator/numpy/np_init_op.cc
index 9edfa20..dc262fe 100644
--- a/src/operator/numpy/np_init_op.cc
+++ b/src/operator/numpy/np_init_op.cc
@@ -22,28 +22,12 @@
  * \file np_init_op.cc
  * \brief CPU Implementation of numpy init op
  */
-#include "../tensor/init_op.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "./np_init_op.h"
 
 namespace mxnet {
 namespace op {
 
-inline bool NumpyRangeShape(const nnvm::NodeAttrs& attrs,
-                            mxnet::ShapeVector* in_shapes,
-                            mxnet::ShapeVector* out_shapes) {
-  const RangeParam& param = nnvm::get<RangeParam>(attrs.parsed);
-  CHECK_EQ(in_shapes->size(), 0U);
-  CHECK_EQ(out_shapes->size(), 1U);
-  CHECK_NE(param.step, 0) << "_npi_arange does not support step=0";
-  CHECK_EQ(param.repeat, 1) << "_npi_arange only supports repeat=1, received " << param.repeat;
-  CHECK(param.stop.has_value()) << "_npi_arange requires stop to have a value";
-  double out_size = std::ceil((param.stop.value() - param.start) / param.step);
-  if (out_size < 0) {
-    out_size = 0;
-  }
-  SHAPE_ASSIGN_CHECK(*out_shapes, 0, mxnet::TShape({static_cast<nnvm::dim_t>(out_size)}));
-  return true;
-}
+DMLC_REGISTER_PARAMETER(NumpyEyeParam);
 
 NNVM_REGISTER_OP(_npi_zeros)
 .describe("Return a new array of given shape, type, and context, filled with 
zeros.")
@@ -134,5 +118,15 @@ NNVM_REGISTER_OP(_npi_arange)
 .set_attr<FCompute>("FCompute<cpu>", RangeCompute<cpu>)
 .add_arguments(RangeParam::__FIELDS__());
 
+NNVM_REGISTER_OP(_npi_eye)
+.describe("Return a 2-D array with ones on the diagonal and zeros elsewhere.")
+.set_num_inputs(0)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyEyeParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyEyeShape)
+.set_attr<nnvm::FInferType>("FInferType", InitType<NumpyEyeParam>)
+.set_attr<FCompute>("FCompute<cpu>", NumpyEyeFill<cpu>)
+.add_arguments(NumpyEyeParam::__FIELDS__());
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/np_init_op.cu b/src/operator/numpy/np_init_op.cu
index 2c41e56..68d1681 100644
--- a/src/operator/numpy/np_init_op.cu
+++ b/src/operator/numpy/np_init_op.cu
@@ -23,7 +23,7 @@
  * \brief GPU Implementation of numpy init op
  */
 
-#include "../tensor/init_op.h"
+#include "./np_init_op.h"
 
 namespace mxnet {
 namespace op {
@@ -43,5 +43,8 @@ NNVM_REGISTER_OP(_np_ones_like)
 NNVM_REGISTER_OP(_npi_arange)
 .set_attr<FCompute>("FCompute<gpu>", RangeCompute<gpu>);
 
+NNVM_REGISTER_OP(_npi_eye)
+.set_attr<FCompute>("FCompute<gpu>", NumpyEyeFill<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/np_init_op.h b/src/operator/numpy/np_init_op.h
new file mode 100644
index 0000000..52be5fb
--- /dev/null
+++ b/src/operator/numpy/np_init_op.h
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file np_init_op.h
+ * \brief CPU Implementation of numpy init op
+ */
+#ifndef MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_
+#define MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_
+
+#include <vector>
+#include <string>
+#include "../tensor/init_op.h"
+#include "../tensor/elemwise_unary_op.h"
+
+
+namespace mxnet {
+namespace op {
+
+struct NumpyEyeParam : public dmlc::Parameter<NumpyEyeParam> {
+  nnvm::dim_t N;
+  dmlc::optional<nnvm::dim_t> M;
+  nnvm::dim_t k;
+  std::string ctx;
+  int dtype;
+  DMLC_DECLARE_PARAMETER(NumpyEyeParam) {
+    DMLC_DECLARE_FIELD(N)
+    .describe("Number of rows in the output.");
+    DMLC_DECLARE_FIELD(M)
+    .set_default(dmlc::optional<nnvm::dim_t>())
+    .describe("Number of columns in the output. If None, defaults to N.");
+    DMLC_DECLARE_FIELD(k)
+    .set_default(0)
+    .describe("Index of the diagonal. 0 (the default) refers to the main 
diagonal,"
+              "a positive value refers to an upper diagonal."
+              "and a negative value to a lower diagonal.");
+    DMLC_DECLARE_FIELD(ctx)
+    .set_default("")
+    .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
+              "Only used for imperative calls.");
+    DMLC_DECLARE_FIELD(dtype)
+    .set_default(mshadow::kFloat32)
+    MXNET_ADD_ALL_TYPES
+    .describe("Data-type of the returned array.");
+  }
+};
+
+inline bool NumpyRangeShape(const nnvm::NodeAttrs& attrs,
+                            mxnet::ShapeVector* in_shapes,
+                            mxnet::ShapeVector* out_shapes) {
+  const RangeParam& param = nnvm::get<RangeParam>(attrs.parsed);
+  CHECK_EQ(in_shapes->size(), 0U);
+  CHECK_EQ(out_shapes->size(), 1U);
+  CHECK_NE(param.step, 0) << "_npi_arange does not support step=0";
+  CHECK_EQ(param.repeat, 1) << "_npi_arange only supports repeat=1, received " << param.repeat;
+  CHECK(param.stop.has_value()) << "_npi_arange requires stop to have a value";
+  double out_size = std::ceil((param.stop.value() - param.start) / param.step);
+  if (out_size < 0) {
+    out_size = 0;
+  }
+  SHAPE_ASSIGN_CHECK(*out_shapes, 0, mxnet::TShape({static_cast<nnvm::dim_t>(out_size)}));
+  return true;
+}
+
+inline bool NumpyEyeShape(const nnvm::NodeAttrs& attrs,
+                         mxnet::ShapeVector *in_attrs,
+                         mxnet::ShapeVector *out_attrs) {
+  const NumpyEyeParam& param = nnvm::get<NumpyEyeParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), 0U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  nnvm::dim_t M = param.M.has_value() ? param.M.value() : param.N;
+  CHECK(param.N >= 0) << "negative dimensions are not allowed. N is " << param.N;
+  CHECK(M >= 0) << "negative dimensions are not allowed. M is " << M;
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape2(param.N, M));
+
+  return out_attrs->at(0).ndim() != 0U;
+}
+
+template<typename xpu>
+void NumpyEyeFill(const nnvm::NodeAttrs& attrs,
+                  const OpContext& ctx,
+                  const std::vector<TBlob>& inputs,
+                  const std::vector<OpReqType>& req,
+                  const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 0U);
+  CHECK_EQ(outputs.size(), 1U);
+  if (outputs[0].shape_.Size() == 0) return;  // zero-size tensor
+  const NumpyEyeParam& param = nnvm::get<NumpyEyeParam>(attrs.parsed);
+  const nnvm::dim_t num_cols = param.M.has_value() ? param.M.value() : param.N;
+  EyeFillImpl<xpu>(outputs[0], ctx, req, num_cols, param.N, param.k);
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_
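
The shape logic in NumpyEyeShape reduces to: the output is (N, M) with M defaulting to N, and negative dimensions are rejected. A pure-Python sketch of that rule (illustration only; the real check runs in C++ during shape inference, and the helper name below is made up):

    def numpy_eye_shape(N, M=None, k=0):
        # mirrors NumpyEyeShape: M falls back to N, negative dims are an error
        M = N if M is None else M
        if N < 0 or M < 0:
            raise ValueError("negative dimensions are not allowed")
        return (N, M)

    assert numpy_eye_shape(4) == (4, 4)
    assert numpy_eye_shape(5, 3, k=-2) == (5, 3)
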
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index 51c8436..068ddd4 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -487,6 +487,29 @@ void FillComputeZerosEx(const nnvm::NodeAttrs& attrs,
   }
 }
 
+template<typename xpu>
+inline void EyeFillImpl(const TBlob& out_data,
+                        const OpContext& ctx,
+                        const std::vector<OpReqType>& req,
+                        const nnvm::dim_t num_cols,
+                        const nnvm::dim_t N,
+                        const nnvm::dim_t k) {
+  using namespace mxnet_op;
+  const nnvm::dim_t cnnz = std::max(num_cols - std::abs(k), (nnvm::dim_t)0);
+  const nnvm::dim_t rnnz = std::max(N - std::abs(k), (nnvm::dim_t)0);
+  const nnvm::dim_t nnz = k > 0 ? std::min(cnnz, N) :
+                                        std::min(rnnz, num_cols);
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+      Fill(s, out_data, req[0], static_cast<DType>(0));
+      if (nnz > 0) {
+        Kernel<eye_dns_fill<req_type>, xpu>::Launch(s, nnz, out_data.dptr<DType>(),
+          std::max(static_cast<nnvm::dim_t>(0), k), k, num_cols);
+      }
+    });
+  });
+}
 
 template<typename xpu>
 void EyeFill(const nnvm::NodeAttrs& attrs,
@@ -497,25 +520,10 @@ void EyeFill(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(inputs.size(), 0U);
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
-  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   const EyeParam& param = nnvm::get<EyeParam>(attrs.parsed);
   const TBlob& out_data = outputs[0];
   const nnvm::dim_t num_cols = param.M > 0 ? param.M : param.N;
-
-  const nnvm::dim_t cnnz = std::max(num_cols - std::abs(param.k), (nnvm::dim_t)0);
-  const nnvm::dim_t rnnz = std::max(param.N - std::abs(param.k), (nnvm::dim_t)0);
-  const nnvm::dim_t nnz = param.k > 0 ? std::min(cnnz, param.N) :
-                                        std::min(rnnz, num_cols);
-  using namespace mxnet_op;
-  MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
-    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-      Fill(s, out_data, req[0], static_cast<DType>(0));
-      if (nnz > 0) {
-        Kernel<eye_dns_fill<req_type>, xpu>::Launch(s, nnz, out_data.dptr<DType>(),
-          std::max(static_cast<nnvm::dim_t>(0), param.k), param.k, num_cols);
-      }
-    });
-  });
+  EyeFillImpl<xpu>(out_data, ctx, req, num_cols, param.N, param.k);
 }
 
 
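EyeFillImpl zero-fills the output and then writes nnz ones along the requested diagonal, where nnz = min(max(M - |k|, 0), N) for k > 0 and min(max(N - |k|, 0), M) otherwise. A small Python cross-check of that count against NumPy (illustration only, not part of the patch):

    import numpy as _np

    def eye_nnz(N, M, k):
        # same diagonal-length computation as EyeFillImpl above
        cnnz = max(M - abs(k), 0)
        rnnz = max(N - abs(k), 0)
        return min(cnnz, N) if k > 0 else min(rnnz, M)

    for N, M, k in [(4, 6, 1), (7, 3, -3), (3, 2, -2), (5, 5, 0), (2, 2, 3)]:
        assert eye_nnz(N, M, k) == _np.count_nonzero(_np.eye(N, M, k))
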
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index d373419..06f0994 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -22,6 +22,7 @@ import mxnet as mx
 from mxnet import np, npx
 from mxnet.base import MXNetError
 from mxnet.gluon import HybridBlock
+from mxnet.base import MXNetError
 from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray
 from mxnet.test_utils import check_numeric_gradient
 from common import assertRaises, with_seed
@@ -718,6 +719,74 @@ def test_np_linspace():
 
 @with_seed()
 @npx.use_np_shape
+def test_np_eye():
+    configs = [
+        4,
+        1000,
+        (4, 3),
+        (5, None),
+        (4, None, 1),
+        (2, 2, 1),
+        (4, 6, 1),
+        (7, 3, -3),
+        (3, 2, -2),
+        (4, 0),
+        (0, 0),
+        (0, 3),
+        (0, 0, -2)
+    ]
+    exception_configs = [
+        -1,
+        -1000,
+        (-2, None),
+        (1, -1)
+    ]
+    dtypes = ['int32', 'float16', 'float32', 'float64', None]
+    for config in configs:
+        for dtype in dtypes:
+            if isinstance(config, tuple):
+                mx_ret = np.eye(*config, dtype=dtype)
+                np_ret = _np.eye(*config, dtype=dtype)
+            else:
+                mx_ret = np.eye(config, dtype=dtype)
+                np_ret = _np.eye(config, dtype=dtype)
+            assert same(mx_ret.asnumpy(), np_ret)
+    # check that invalid inputs raise MXNetError
+    for config in exception_configs:
+        if isinstance(config, tuple):
+            assertRaises(MXNetError, np.eye, *config)
+        else:
+            assertRaises(MXNetError, np.eye, config)
+    @npx.use_np
+    class TestEye(HybridBlock):
+        def __init__(self, N, M=None, k=0, dtype=None):
+            super(TestEye, self).__init__()
+            self._N = N
+            self._M = M
+            self._k = k
+            self._dtype = dtype
+
+        def hybrid_forward(self, F, x):
+            return x + F.np.eye(self._N, self._M, self._k, dtype=self._dtype)
+
+    for dtype in dtypes:
+        x = np.zeros(shape=(), dtype=dtype)
+        for config in configs:
+            for hybridize in [False, True]:
+                if isinstance(config, tuple):
+                    net = TestEye(*config, dtype=dtype)
+                    np_out = _np.eye(*config, dtype=dtype)
+                else:
+                    net = TestEye(config, dtype=dtype)
+                    np_out = _np.eye(config, dtype=dtype)
+                if hybridize:
+                    net.hybridize()
+                mx_out = net(x)
+                assert same(mx_out.asnumpy(), np_out)
+
+
+@with_seed()
+@npx.use_np_shape
 def test_np_argmax():
     workloads = [
         ((), 0, False),

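The new test can be run on its own, for example with `pytest tests/python/unittest/test_numpy_op.py -k test_np_eye` (an illustrative invocation; any runner that collects this file works).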