Author: Maciej Fijalkowski <[email protected]>
Branch: numpy-back-to-applevel
Changeset: r51598:0fcad0cba011
Date: 2012-01-21 19:34 +0200
http://bitbucket.org/pypy/pypy/changeset/0fcad0cba011/
Log: implement keepdims=True
diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -153,8 +153,13 @@
class AxisIterator(BaseIterator):
def __init__(self, start, dim, shape, strides, backstrides):
self.res_shape = shape[:]
- self.strides = strides[:dim] + [0] + strides[dim:]
- self.backstrides = backstrides[:dim] + [0] + backstrides[dim:]
+ if len(shape) == len(strides):
+ # keepdims = True
+ self.strides = strides[:dim] + [0] + strides[dim + 1:]
+ self.backstrides = backstrides[:dim] + [0] + backstrides[dim + 1:]
+ else:
+ self.strides = strides[:dim] + [0] + strides[dim:]
+ self.backstrides = backstrides[:dim] + [0] + backstrides[dim:]
self.first_line = True
self.indices = [0] * len(shape)
self._done = False
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -1077,6 +1077,8 @@
def array(space, w_item_or_iterable, w_dtype=None, w_order=None,
subok=True, copy=False, w_maskna=None, ownmaskna=False):
# find scalar
+ if w_maskna is None:
+ w_maskna = space.w_None
if (not subok or copy or not space.is_w(w_maskna, space.w_None) or
ownmaskna):
raise OperationError(space.w_NotImplementedError,
space.wrap("Unsupported args"))
@@ -1088,7 +1090,7 @@
space.call_function(space.gettypefor(interp_dtype.W_Dtype), w_dtype)
)
return scalar_w(space, dtype, w_item_or_iterable)
- if space.is_w(w_order, space.w_None):
+ if space.is_w(w_order, space.w_None) or w_order is None:
order = 'C'
else:
order = space.str_w(w_order)
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -46,7 +46,8 @@
return self.identity
def descr_call(self, space, __args__):
- if __args__.keywords or len(__args__.arguments_w) < self.argcount:
+ # XXX do something with strange keywords
+ if len(__args__.arguments_w) < self.argcount:
raise OperationError(space.w_ValueError,
space.wrap("invalid number of arguments")
)
@@ -60,7 +61,7 @@
@unwrap_spec(skipna=bool, keepdims=bool)
def descr_reduce(self, space, w_obj, w_axis=None, w_dtype=None,
- skipna=False, keepdims=True, w_out=None):
+ skipna=False, keepdims=False, w_out=None):
"""reduce(...)
reduce(a, axis=0)
@@ -120,9 +121,9 @@
axis = -1
else:
axis = space.int_w(w_axis)
- return self.reduce(space, w_obj, False, False, axis)
+ return self.reduce(space, w_obj, False, False, axis, keepdims)
- def reduce(self, space, w_obj, multidim, promote_to_largest, dim):
+ def reduce(self, space, w_obj, multidim, promote_to_largest, dim, keepdims):
from pypy.module.micronumpy.interp_numarray import convert_to_array, \
Scalar
if self.argcount != 2:
@@ -148,7 +149,7 @@
raise operationerrfmt(space.w_ValueError, "zero-size array to "
"%s.reduce without identity", self.name)
if shapelen > 1 and dim >= 0:
- res = self.do_axis_reduce(obj, dtype, dim)
+ res = self.do_axis_reduce(obj, dtype, dim, keepdims)
return space.wrap(res)
scalarsig = ScalarSignature(dtype)
sig = find_sig(ReduceSignature(self.func, self.name, dtype,
@@ -162,11 +163,14 @@
value = self.identity.convert_to(dtype)
return self.reduce_loop(shapelen, sig, frame, value, obj, dtype)
- def do_axis_reduce(self, obj, dtype, dim):
+ def do_axis_reduce(self, obj, dtype, dim, keepdims):
from pypy.module.micronumpy.interp_numarray import AxisReduce,\
W_NDimArray
-
- shape = obj.shape[0:dim] + obj.shape[dim + 1:len(obj.shape)]
+
+ if keepdims:
+ shape = obj.shape[:dim] + [1] + obj.shape[dim + 1:]
+ else:
+ shape = obj.shape[:dim] + obj.shape[dim + 1:]
size = 1
for s in shape:
size *= s
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -1,5 +1,5 @@
from pypy.rlib import jit
-
+from pypy.interpreter.error import OperationError
@jit.look_inside_iff(lambda shape, start, strides, backstrides, chunks:
jit.isconstant(len(chunks))
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -344,7 +344,7 @@
from _numpypy import sin, add
raises(ValueError, sin.reduce, [1, 2, 3])
- raises(ValueError, add.reduce, 1)
+ raises(TypeError, add.reduce, 1)
def test_reduce_1d(self):
from _numpypy import add, maximum
@@ -360,6 +360,14 @@
assert (add.reduce(a, 0) == [12, 15, 18, 21]).all()
assert (add.reduce(a, 1) == [6.0, 22.0, 38.0]).all()
+ def test_reduce_keepdims(self):
+ from _numpypy import add, arange
+ a = arange(12).reshape(3, 4)
+ b = add.reduce(a, 0, keepdims=True)
+ assert b.shape == (1, 4)
+ assert (add.reduce(a, 0, keepdims=True) == [12, 15, 18, 21]).all()
+
+
def test_bitwise(self):
from _numpypy import bitwise_and, bitwise_or, arange, array
a = arange(6).reshape(2, 3)
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit