Author: mattip Branch: numpypy-axisops Changeset: r50912:e5b246bae93a Date: 2011-12-25 21:09 +0200 http://bitbucket.org/pypy/pypy/changeset/e5b246bae93a/
Log: add tests for numpypy.reduce diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py --- a/pypy/module/micronumpy/interp_ufuncs.py +++ b/pypy/module/micronumpy/interp_ufuncs.py @@ -3,7 +3,8 @@ from pypy.interpreter.gateway import interp2app, unwrap_spec from pypy.interpreter.typedef import TypeDef, GetSetProperty, interp_attrproperty from pypy.module.micronumpy import interp_boxes, interp_dtype, types -from pypy.module.micronumpy.signature import ReduceSignature, ScalarSignature, find_sig +from pypy.module.micronumpy.signature import (ReduceSignature, ScalarSignature, + ArraySignature, find_sig) from pypy.rlib import jit from pypy.rlib.rarithmetic import LONG_BIT from pypy.tool.sourcetools import func_with_new_name @@ -46,8 +47,60 @@ ) return self.call(space, __args__.arguments_w) - def descr_reduce(self, space, w_obj): - return self.reduce(space, w_obj, False, space.wrap(-1)) + def descr_reduce(self, space, w_obj, w_dim=0): + '''reduce(...) + reduce(a, axis=0) + + Reduces `a`'s dimension by one, by applying ufunc along one axis. + + Let :math:`a.shape = (N_0, ..., N_i, ..., N_{M-1})`. Then + :math:`ufunc.reduce(a, axis=i)[k_0, ..,k_{i-1}, k_{i+1}, .., k_{M-1}]` = + the result of iterating `j` over :math:`range(N_i)`, cumulatively applying + ufunc to each :math:`a[k_0, ..,k_{i-1}, j, k_{i+1}, .., k_{M-1}]`. + For a one-dimensional array, reduce produces results equivalent to: + :: + + r = op.identity # op = ufunc + for i in xrange(len(A)): + r = op(r, A[i]) + return r + + For example, add.reduce() is equivalent to sum(). + + Parameters + ---------- + a : array_like + The array to act on. + axis : int, optional + The axis along which to apply the reduction. + + Examples + -------- + >>> np.multiply.reduce([2,3,5]) + 30 + + A multi-dimensional array example: + + >>> X = np.arange(8).reshape((2,2,2)) + >>> X + array([[[0, 1], + [2, 3]], + [[4, 5], + [6, 7]]]) + >>> np.add.reduce(X, 0) + array([[ 4, 6], + [ 8, 10]]) + >>> np.add.reduce(X) # confirm: default axis value is 0 + array([[ 4, 6], + [ 8, 10]]) + >>> np.add.reduce(X, 1) + array([[ 2, 4], + [10, 12]]) + >>> np.add.reduce(X, 2) + array([[ 1, 5], + [ 9, 13]]) + ''' + return self.reduce(space, w_obj, False, w_dim) def reduce(self, space, w_obj, multidim, w_dim): from pypy.module.micronumpy.interp_numarray import convert_to_array, Scalar @@ -57,6 +110,8 @@ dim = -1 if not space.is_w(w_dim, space.w_None): dim = space.int_w(w_dim) + if not multidim and space.is_w(w_dim, space.w_None): + dim = 0 assert isinstance(self, W_Ufunc2) obj = convert_to_array(space, w_obj) if isinstance(obj, Scalar): @@ -69,14 +124,15 @@ promote_to_largest=True ) shapelen = len(obj.shape) - #TODO: if dim>=0 return a ArraySignature? - sig = find_sig(ReduceSignature(self.func, self.name, dtype, + if dim>=0 or 0: + sig = find_sig(ReduceSignature(self.func, self.name, dtype, + ArraySignature(dtype), + obj.create_sig(obj.shape)), obj) + else: + sig = find_sig(ReduceSignature(self.func, self.name, dtype, ScalarSignature(dtype), obj.create_sig(obj.shape)), obj) frame = sig.create_frame(obj) - if shapelen > 1 and not multidim: - raise OperationError(space.w_NotImplementedError, - space.wrap("not implemented yet")) if self.identity is None: if size == 0: raise operationerrfmt(space.w_ValueError, "zero-size array to " diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py --- a/pypy/module/micronumpy/test/test_ufuncs.py +++ b/pypy/module/micronumpy/test/test_ufuncs.py @@ -339,12 +339,15 @@ raises(TypeError, add.reduce, 1) def test_reduce(self): - from numpypy import add, maximum + from numpypy import add, maximum, arange assert add.reduce([1, 2, 3]) == 6 assert maximum.reduce([1]) == 1 assert maximum.reduce([1, 2, 3]) == 3 raises(ValueError, maximum.reduce, []) + a = arange(12).reshape(3,4) + assert add.reduce(a, 0) == add.reduce(a) + assert (add.reduce(a, 1) == [ 6, 22, 38]).all() def test_comparisons(self): import operator _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit