This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 8e743904b4e76b29f61f2d5bb34f5dfaa2cccc72
Author: Jake Lee <gstu1...@gmail.com>
AuthorDate: Tue Jun 18 23:45:32 2019 -0700

    Numpy compatible max (#15161)
    
    * numpy amax
    
    * weird cu file diff
    
    * fix the unit test error
    
    * fix gpu bug
    
    * minor fix
    
    * fix lint
    
    * remove scalar value check
    
    * fix the bug on unit test
    
    * fix the case () that breaks the kernel launch
    
    * add zero dimension unit test
    
    * revert the tuple change
    
    * use mshadow maximum
    
    * remove test zero
    
    * change the macro for now
    
    * change the cuda to use mshadow op
    
    * fix the broadcast_reduce_op_value.cu wrong kernel
    
    * add more logic in shape to detect the invalid situation
    
    * change back to type switch
    
    * change to as_nd_ndarray
    
    * add missing @npx.use_np_shape
    
    * retrigger CI
    
    * address the comment
    
    * undo algorithm import
    
    * remove the numeric gradient check
---
 src/operator/numpy/np_broadcast_reduce_op.h        | 92 ++++++++++++++++++++++
 src/operator/numpy/np_broadcast_reduce_op_value.cc | 39 +++++++++
 src/operator/numpy/np_broadcast_reduce_op_value.cu |  5 ++
 tests/python/unittest/test_numpy_op.py             | 92 +++++++++++++++++++++-
 4 files changed, 227 insertions(+), 1 deletion(-)

diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h
index 0f3d71d..c76b596 100644
--- a/src/operator/numpy/np_broadcast_reduce_op.h
+++ b/src/operator/numpy/np_broadcast_reduce_op.h
@@ -64,6 +64,24 @@ struct NumpyReduceAxesParam : public dmlc::Parameter<NumpyReduceAxesParam> {
   }
 };
 
+struct NumpyMaxParam : public dmlc::Parameter<NumpyMaxParam> {
+  dmlc::optional<mxnet::Tuple<int>> axis;
+  bool keepdims;
+  dmlc::optional<double> initial;
+  DMLC_DECLARE_PARAMETER(NumpyMaxParam) {
+    DMLC_DECLARE_FIELD(axis)
+      .set_default(dmlc::optional<mxnet::Tuple<int>>())
+      .describe("Axis or axes along which a sum is performed. The default, axis=None, will sum "
+                "all of the elements of the input array. If axis is negative it counts from the "
+                "last to the first axis.");
+    DMLC_DECLARE_FIELD(keepdims).set_default(false)
+      .describe("If this is set to `True`, the reduced axes are left "
+                "in the result as dimension with size one.");
+    DMLC_DECLARE_FIELD(initial).set_default(dmlc::optional<double>())
+      .describe("Starting value for the sum.");
+  }
+};
+
 inline TShape NumpyReduceAxesShapeImpl(const TShape& ishape,
                                        const dmlc::optional<mxnet::Tuple<int>>& axis,
                                        bool keepdims) {
@@ -152,6 +170,39 @@ inline bool NumpyReduceAxesShape(const nnvm::NodeAttrs& attrs,
   return shape_is_known(out_attrs->at(0));
 }
 
+inline bool NumpyMaxShape(const nnvm::NodeAttrs& attrs,
+                                 std::vector<TShape> *in_attrs,
+                                 std::vector<TShape> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  if (!shape_is_known(in_attrs->at(0))) {
+    return false;
+  }
+  const NumpyMaxParam& param = nnvm::get<NumpyMaxParam>(attrs.parsed);
+  // check the case where the reduction axis should not be zero
+  bool is_all_reducded_axes_not_zero = true;
+  const TShape& ishape = (*in_attrs)[0];
+  if (param.axis.has_value()) {
+    const mxnet::Tuple<int>& axes = param.axis.value();
+    for (int i = 0; i < axes.ndim(); ++i) {
+      if (ishape[axes[i]] == 0) {
+        is_all_reducded_axes_not_zero = false;
+        break;
+      }
+    }
+  } else {
+    if (ishape.Size() == 0) {
+      // global reduction should be executed only when the input size is more than 0
+      is_all_reducded_axes_not_zero = false;
+    }
+  }
+  CHECK(is_all_reducded_axes_not_zero)
+    << "zero-size array to reduction operation maximum which has no identity";
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0,
+                     NumpyReduceAxesShapeImpl((*in_attrs)[0], param.axis, param.keepdims));
+  return shape_is_known(out_attrs->at(0));
+}
+
 template<bool safe_acc_hint = false>
 inline bool NeedSafeAcc(int itype, int otype) {
   bool rule = (itype != otype) || (itype != mshadow::kFloat32 && itype != mshadow::kFloat64);
@@ -187,6 +238,29 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
   }
 }
 
+template<typename xpu, typename reducer, typename OP = op::mshadow_op::identity>
+void NumpyMaxCompute(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<TBlob>& inputs,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<TBlob>& outputs) {
+  const NumpyMaxParam& param = nnvm::get<NumpyMaxParam>(attrs.parsed);
+  if (param.initial.has_value()) {
+    LOG(FATAL) << "initial is not supported yet";
+  }
+  if (inputs[0].shape_.Size() == 0U || outputs[0].shape_.Size() == 0U) return;  // zero-size tensor
+  if (param.axis.has_value() && param.axis.value().ndim() == 0) {
+    UnaryOp::IdentityCompute<xpu>(attrs, ctx, inputs, req, outputs);
+  }
+  TShape small;
+  if (param.keepdims) {
+    small = outputs[0].shape_;
+  } else {
+    small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true);
+  }
+  ReduceAxesComputeImpl<xpu, reducer, false, false, OP>(ctx, inputs, req, outputs, small);
+}
+
 template<typename xpu, bool normalize = false>
 inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
                                            const OpContext& ctx,
@@ -213,6 +287,24 @@ inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
   }
 }
 
+template<typename xpu, typename OP>
+void NumpyMaxBackward(const nnvm::NodeAttrs& attrs,
+                                const OpContext& ctx,
+                                const std::vector<TBlob>& inputs,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  const NumpyMaxParam& param = nnvm::get<NumpyMaxParam>(attrs.parsed);
+  TShape small;
+  if (param.keepdims) {
+    small = inputs[0].shape_;
+  } else {
+    small = NumpyReduceAxesShapeImpl(outputs[0].shape_, param.axis, true);
+  }
+  ReduceAxesBackwardUseInOutImpl<xpu, OP, false>(ctx, small, inputs, req, outputs);
+}
+
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc
index 078cd46..168fe59 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cc
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -29,6 +29,7 @@ namespace mxnet {
 namespace op {
 
 DMLC_REGISTER_PARAMETER(NumpyReduceAxesParam);
+DMLC_REGISTER_PARAMETER(NumpyMaxParam);
 
 inline bool NumpySumType(const nnvm::NodeAttrs& attrs,
                          std::vector<int> *in_attrs,
@@ -128,5 +129,43 @@ NNVM_REGISTER_OP(_backward_np_mean)
 .set_num_inputs(1)
 .set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesBackwardUseNone<cpu, true>);
 
+inline bool NumpyMaxType(const nnvm::NodeAttrs& attrs,
+                         std::vector<int> *in_attrs,
+                         std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+  TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+
+  return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
+}
+
+NNVM_REGISTER_OP(_np_max)
+.describe(R"code()code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyMaxParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyMaxShape)
+.set_attr<nnvm::FInferType>("FInferType", NumpyMaxType)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a"};
+  })
+.add_argument("a", "NDArray-or-Symbol", "The input")
+.add_arguments(NumpyMaxParam::__FIELDS__())
+.set_attr<FCompute>("FCompute<cpu>", NumpyMaxCompute<cpu, mshadow::red::maximum>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<nnvm::FGradient>("FGradient", ReduceGrad{"_backward_np_max"});
+
+NNVM_REGISTER_OP(_backward_np_max)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyMaxParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_num_inputs(3)
+.set_attr<FCompute>("FCompute<cpu>", NumpyMaxBackward<cpu, mshadow_op::eq>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cu b/src/operator/numpy/np_broadcast_reduce_op_value.cu
index 7740c03..49bef09 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cu
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cu
@@ -39,6 +39,11 @@ NNVM_REGISTER_OP(_np_mean)
 NNVM_REGISTER_OP(_backward_np_mean)
 .set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBackwardUseNone<gpu, true>);
 
+NNVM_REGISTER_OP(_np_max)
+.set_attr<FCompute>("FCompute<gpu>", NumpyMaxCompute<gpu, mshadow::red::maximum>);
+
+NNVM_REGISTER_OP(_backward_np_max)
+.set_attr<FCompute>("FCompute<gpu>", NumpyMaxBackward<gpu, mshadow_op::eq>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 862c4d4..031719c 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -20,10 +20,11 @@ from __future__ import absolute_import
 import numpy as _np
 import mxnet as mx
 from mxnet import np, npx
+from mxnet.base import MXNetError
 from mxnet.gluon import HybridBlock
 from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray
 from mxnet.test_utils import check_numeric_gradient
-from common import with_seed
+from common import assertRaises, with_seed
 import random
 
 
@@ -201,6 +202,95 @@ def test_np_mean():
 
 @with_seed()
 @npx.use_np_shape
+def test_np_max():
+    @npx.use_np_shape
+    class TestMax(HybridBlock):
+        def __init__(self, axis=None, keepdims=False):
+            super(TestMax, self).__init__()
+            self._axis = axis
+            self._keepdims = keepdims
+
+        def hybrid_forward(self, F, a, *args, **kwargs):
+            return F.np.max(a, axis=self._axis, keepdims=self._keepdims)
+
+    def is_int(dtype):
+        return 'int' == dtype
+
+    def get_grad(axis):
+        if axis == ():
+            return _np.ones((2,3,4,5))
+        else:
+            temp = _np.zeros((2,3,4,5))
+            if axis == 0:
+                temp[-1,:,:,:] = 1
+                return temp
+            elif axis == 1:
+                temp[:,-1,:,:] = 1
+                return temp
+            elif axis == 2:
+                temp[:,:,-1,:] = 1
+                return temp
+            elif axis == 3:
+                temp[:,:,:,-1] = 1
+                return temp
+            elif not axis:
+                temp[-1,-1,-1,-1] = 1
+                return temp
+            raise ValueError('axis should be int or None or ()')
+
+    def _test_np_max_exception(shape, dim):
+        x = _np.random.uniform(-1.0, 1.0, shape)
+        x = mx.nd.array(x).as_np_ndarray()
+        out = mx.np.max(x)
+        assert out.ndim == dim, 'dimension mismatch, output.ndim={}, dim={}'.format(output.ndim, dim)
+
+    in_data_dim = random.choice([2, 3, 4])
+    shape = rand_shape_nd(in_data_dim, dim=3)
+    for hybridize in [False, True]:
+        for keepdims in [True, False]:
+            for axis in ([i for i in range(in_data_dim)] + [(), None]):
+                for itype in ['float16', 'float32', 'float64', 'int']:
+                    # test gluon
+                    test_max = TestMax(axis=axis, keepdims=keepdims)
+                    if hybridize:
+                        test_max.hybridize()
+                    if is_int(itype):
+                        x = mx.nd.arange(120).reshape((2, 3, 4, 5))
+                        x = mx.nd.array(x)
+                    else:
+                        x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype)
+                    x = x.as_np_ndarray()
+                    x.attach_grad()
+                    expected_ret = _np.amax(x.asnumpy(), axis=axis, keepdims=keepdims)
+                    with mx.autograd.record():
+                        y = test_max(x)
+                    assert y.shape == expected_ret.shape
+                    assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if itype == 'float16' else 1e-3,
+                                        atol=1e-5 if itype == 'float16' else 1e-5)
+                    y.backward()
+                    # only check the gradient with hardcoded input
+                    if is_int(itype):
+                        assert same(x.grad.asnumpy(), get_grad(axis)), \
+                            'x={}\ny={}\nx.grad={}\nnumpy={}'.format(x.asnumpy(), y.asnumpy(), x.grad.asnumpy(), get_grad(axis))
+
+                    # test imperative
+                    mx_out = np.max(x, axis=axis, keepdims=keepdims)
+                    np_out = _np.amax(x.asnumpy(), axis=axis, keepdims=keepdims)
+                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+    # test zero and zero dim
+    shapes = [(), (0), (2, 0), (0, 2, 1)]
+    exceptions = [False, True, True, True]
+    dims = [0] * len(shapes)
+    for shape, exception, dim in zip(shapes, exceptions, dims):
+        if exception:
+            assertRaises(MXNetError, _test_np_max_exception, shape, dim)
+        else:
+            _test_np_max_exception(shape, dim)
+
+
+@with_seed()
+@npx.use_np_shape
 def test_np_transpose():
     # TODO(junwu): Add more test cases
     data = mx.sym.var('a').as_np_ndarray()

Reply via email to