Author: Brian Kearns <bdkea...@gmail.com>
Branch:
Changeset: r69004:9371233f1468
Date: 2014-01-29 17:26 -0500
http://bitbucket.org/pypy/pypy/changeset/9371233f1468/
Log: support keepdims arg for array reduce operations diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -961,7 +961,8 @@ def _reduce_ufunc_impl(ufunc_name, promote_to_largest=False, cumulative=False): - def impl(self, space, w_axis=None, w_dtype=None, w_out=None): + @unwrap_spec(keepdims=bool) + def impl(self, space, w_axis=None, w_dtype=None, w_out=None, keepdims=False): if space.is_none(w_out): out = None elif not isinstance(w_out, W_NDimArray): @@ -971,7 +972,7 @@ out = w_out return getattr(interp_ufuncs.get(space), ufunc_name).reduce( space, self, promote_to_largest, w_axis, - False, out, w_dtype, cumulative=cumulative) + keepdims, out, w_dtype, cumulative=cumulative) return func_with_new_name(impl, "reduce_%s_impl_%d_%d" % (ufunc_name, promote_to_largest, cumulative)) diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py --- a/pypy/module/micronumpy/interp_ufuncs.py +++ b/pypy/module/micronumpy/interp_ufuncs.py @@ -252,6 +252,11 @@ if out: out.set_scalar_value(res) return out + if keepdims: + shape = [1] * len(obj_shape) + out = W_NDimArray.from_shape(space, [1] * len(obj_shape), dtype, w_instance=obj) + out.implementation.setitem(0, res) + return out return res def descr_outer(self, space, __args__): diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py --- a/pypy/module/micronumpy/test/test_numarray.py +++ b/pypy/module/micronumpy/test/test_numarray.py @@ -1399,6 +1399,8 @@ from numpypy import arange, array a = arange(15).reshape(5, 3) assert a.sum() == 105 + assert a.sum(keepdims=True) == 105 + assert a.sum(keepdims=True).shape == (1, 1) assert a.max() == 14 assert array([]).sum() == 0.0 assert array([]).reshape(0, 2).sum() == 0. 
@@ -1431,6 +1433,8 @@ from numpypy import array, dtype a = array(range(1, 6)) assert a.prod() == 120.0 + assert a.prod(keepdims=True) == 120.0 + assert a.prod(keepdims=True).shape == (1,) assert a[:4].prod() == 24.0 for dt in ['bool', 'int8', 'uint8', 'int16', 'uint16']: a = array([True, False], dtype=dt) @@ -1445,6 +1449,8 @@ from numpypy import array, zeros a = array([-1.2, 3.4, 5.7, -3.0, 2.7]) assert a.max() == 5.7 + assert a.max(keepdims=True) == 5.7 + assert a.max(keepdims=True).shape == (1,) b = array([]) raises(ValueError, "b.max()") assert list(zeros((0, 2)).max(axis=1)) == [] @@ -1458,6 +1464,8 @@ from numpypy import array, zeros a = array([-1.2, 3.4, 5.7, -3.0, 2.7]) assert a.min() == -3.0 + assert a.min(keepdims=True) == -3.0 + assert a.min(keepdims=True).shape == (1,) b = array([]) raises(ValueError, "b.min()") assert list(zeros((0, 2)).min(axis=1)) == [] @@ -1508,6 +1516,8 @@ assert a.all() == False a[0] = 3.0 assert a.all() == True + assert a.all(keepdims=True) == True + assert a.all(keepdims=True).shape == (1,) b = array([]) assert b.all() == True @@ -1515,6 +1525,8 @@ from numpypy import array, zeros a = array(range(5)) assert a.any() == True + assert a.any(keepdims=True) == True + assert a.any(keepdims=True).shape == (1,) b = zeros(5) assert b.any() == False c = array([]) _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit