Author: Brian Kearns <[email protected]>
Branch: 
Changeset: r68500:aca7d2177494
Date: 2013-12-19 18:37 -0500
http://bitbucket.org/pypy/pypy/changeset/aca7d2177494/

Log:    fix ndarray.take with axis argument

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -1187,10 +1187,11 @@
         if axis is None:
             res = a.ravel()[indices]
         else:
+            from operator import mul
             if axis < 0: axis += len(a.shape)
             s0, s1 = a.shape[:axis], a.shape[axis+1:]
-            l0 = prod(s0) if s0 else 1
-            l1 = prod(s1) if s1 else 1
+            l0 = reduce(mul, s0) if s0 else 1
+            l1 = reduce(mul, s1) if s1 else 1
             res = a.reshape((l0, -1, l1))[:,indices,:].reshape(s0 + (-1,) + s1)
         if out is not None:
             out[:] = res
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -2722,6 +2722,8 @@
         raises(IndexError, "arange(3).take([15])")
         a = arange(6).reshape(2, 3)
         assert (a.take([1, 0, 3]) == [1, 0, 3]).all()
+        assert (a.take([1], axis=0) == [[3, 4, 5]]).all()
+        assert (a.take([1], axis=1) == [[1], [4]]).all()
         assert ((a + a).take([3]) == [6]).all()
         a = arange(12).reshape(2, 6)
         assert (a[:,::2].take([3, 2, 1]) == [6, 4, 2]).all()
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to