Author: Maciej Fijalkowski <[email protected]>
Branch: numpy-multidim
Changeset: r49726:195a2e3ca74a
Date: 2011-11-24 15:17 +0200
http://bitbucket.org/pypy/pypy/changeset/195a2e3ca74a/

Log:    For now, argmax/argmin on a multidimensional array raises TypeError.
        Wait until we implement flatiter.

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -401,6 +401,9 @@
             return result
         def impl(self, space):
             size = self.find_size()
+            if len(self.shape) > 1:
+                raise OperationError(space.w_TypeError,
+                                     space.wrap("argmin/max does not work on 
multidimensional arrays yet"))
             if size == 0:
                 raise OperationError(space.w_ValueError,
                     space.wrap("Can't call %s on zero-size arrays" \
@@ -601,9 +604,6 @@
             item += v * self.strides[i]
         return item
 
-    def get_root_shape(self):
-        return self.shape
-
     def _single_item_result(self, space, w_idx):
         """ The result of getitem/setitem is a single item if w_idx
         is a list of scalars that match the size of shape
@@ -726,14 +726,8 @@
 
     def compute_index(self, space, offset):
         offset -= self.start
-        if len(self.shape) == 1:
-            return space.wrap(offset // self.strides[0])
-        indices_w = []
-        for shard in self.strides:
-            r = offset // shard
-            indices_w.append(space.wrap(r))
-            offset -= shard * r
-        return space.newtuple(indices_w)
+        assert len(self.shape) == 1
+        return space.wrap(offset // self.strides[0])
 
 def convert_to_array(space, w_obj):
     if isinstance(w_obj, BaseArray):
@@ -1009,9 +1003,6 @@
     def setitem(self, item, value):
         self.parent.setitem(item, value)
 
-    def get_root_shape(self):
-        return self.parent.get_root_shape()
-
 class NDimArray(BaseArray):
     """ A class representing contiguous array. We know that each iteration
     by say ufunc will increase the data index by one
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -768,6 +768,7 @@
         a[1, 2, 0] = 3
         assert a[1, 2, 0] == 3
         assert a[1, 1, 0] == 0
+        assert a[1, -1, 0] == 3
 
     def test_slices(self):
         import numpypy
@@ -906,6 +907,11 @@
         b[:] = (a + a)
         assert (b == zeros((4, 3, 5))).all()
 
+    def test_argmax(self):
+        from numpypy import array
+        a = array([[1, 2], [3, 4]])
+        raises(TypeError, a.argmax)
+
 class AppTestSupport(object):
     def setup_class(cls):
         import struct
@@ -994,8 +1000,7 @@
         assert str(a) == "[[0 0 0 ..., 0 0 0]\n [0 0 0 ..., 0 0 0]\n [0 0 0 
..., 0 0 0]\n ..., \n [0 0 0 ..., 0 0 0]\n [0 0 0 ..., 0 0 0]\n [0 0 0 ..., 0 0 
0]]"
         a = zeros((2, 2, 2))
         r = str(a)
-        py.test.skip("this is not exactly right, fix it one day")
-        assert r == '[[[ 0.  0.]\n  [ 0.  0.]]\n\n [[ 0.  0.]\n  [ 0.  0.]]]'
+        assert r == '[[[0.0 0.0]\n  [0.0 0.0]]\n\n [[0.0 0.0]\n  [0.0 0.0]]]'
         
     def test_str_slice(self):
         from numpypy import array, zeros
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to