Author: Brian Kearns <[email protected]>
Branch:
Changeset: r70890:d91034c74551
Date: 2014-04-23 15:17 -0400
http://bitbucket.org/pypy/pypy/changeset/d91034c74551/
Log: fix ufunc reduce with single axis tuple (issue1718)
diff --git a/pypy/module/micronumpy/compile.py b/pypy/module/micronumpy/compile.py
--- a/pypy/module/micronumpy/compile.py
+++ b/pypy/module/micronumpy/compile.py
@@ -136,6 +136,11 @@
def newcomplex(self, r, i):
return ComplexObject(r, i)
+ def getitem(self, obj, index):
+ assert isinstance(obj, ListObject)
+ assert isinstance(index, IntObject)
+ return obj.items[index.intval]
+
def listview(self, obj, number=-1):
assert isinstance(obj, ListObject)
if number != -1:
diff --git a/pypy/module/micronumpy/test/test_ndarray.py b/pypy/module/micronumpy/test/test_ndarray.py
--- a/pypy/module/micronumpy/test/test_ndarray.py
+++ b/pypy/module/micronumpy/test/test_ndarray.py
@@ -1506,6 +1506,9 @@
from numpypy import array, zeros
a = array([-1.2, 3.4, 5.7, -3.0, 2.7])
assert a.max() == 5.7
+ assert a.max().shape == ()
+ assert a.max(axis=(0,)) == 5.7
+ assert a.max(axis=(0,)).shape == ()
assert a.max(keepdims=True) == 5.7
assert a.max(keepdims=True).shape == (1,)
b = array([])
@@ -1521,6 +1524,9 @@
from numpypy import array, zeros
a = array([-1.2, 3.4, 5.7, -3.0, 2.7])
assert a.min() == -3.0
+ assert a.min().shape == ()
+ assert a.min(axis=(0,)) == -3.0
+ assert a.min(axis=(0,)).shape == ()
assert a.min(keepdims=True) == -3.0
assert a.min(keepdims=True).shape == (1,)
b = array([])
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -772,6 +772,7 @@
a = zeros((2, 2)) + 1
assert (add.reduce(a, axis=1) == [2, 2]).all()
+ assert (add.reduce(a, axis=(1,)) == [2, 2]).all()
exc = raises(ValueError, add.reduce, a, axis=2)
assert exc.value[0] == "'axis' entry is out of bounds"
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -178,6 +178,8 @@
if space.is_none(w_axis):
axis = maxint
else:
+ if space.isinstance_w(w_axis, space.w_tuple) and
space.len_w(w_axis) == 1:
+ w_axis = space.getitem(w_axis, space.wrap(0))
axis = space.int_w(w_axis)
if axis < -shapelen or axis >= shapelen:
raise oefmt(space.w_ValueError, "'axis' entry is out of
bounds")
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit