Author: mattip
Branch: numpypy-shape-bug
Changeset: r51621:e11082ff75b9
Date: 2012-01-22 01:50 +0200
http://bitbucket.org/pypy/pypy/changeset/e11082ff75b9/

Log:    add a failing setshape test and the fix for it: calc_new_strides now takes the array's order explicitly

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -156,7 +156,7 @@
 # fits the new shape, using those steps. If there is a shape/step mismatch
 # (meaning that the realignment of elements crosses from one step into another)
 # return None so that the caller can raise an exception.
-def calc_new_strides(new_shape, old_shape, old_strides):
+def calc_new_strides(new_shape, old_shape, old_strides, order):
     # Return the proper strides for new_shape, or None if the mapping crosses
     # stepping boundaries
 
@@ -166,7 +166,7 @@
     last_step = 1
     oldI = 0
     new_strides = []
-    if old_strides[0] < old_strides[-1]:
+    if order == 'F':
         for i in range(len(old_shape)):
             steps.append(old_strides[i] / last_step)
             last_step *= old_shape[i]
@@ -187,7 +187,7 @@
                     break
                 cur_step = steps[oldI]
                 n_old_elems_to_use *= old_shape[oldI]
-    else:
+    elif order == 'C':
         for i in range(len(old_shape) - 1, -1, -1):
             steps.insert(0, old_strides[i] / last_step)
             last_step *= old_shape[i]
@@ -543,8 +543,8 @@
         concrete = self.get_concrete()
         new_shape = get_shape_from_iterable(space, concrete.size, w_shape)
         # Since we got to here, prod(new_shape) == self.size
-        new_strides = calc_new_strides(new_shape,
-                                       concrete.shape, concrete.strides)
+        new_strides = calc_new_strides(new_shape, concrete.shape,
+                                     concrete.strides, concrete.order)
         if new_strides:
             # We can create a view, strides somehow match up.
             ndims = len(new_shape)
@@ -1105,7 +1105,8 @@
             self.backstrides = backstrides
             self.shape = new_shape
             return
-        new_strides = calc_new_strides(new_shape, self.shape, self.strides)
+        new_strides = calc_new_strides(new_shape, self.shape, self.strides,
+                                                   self.order)
         if new_strides is None:
             raise OperationError(space.w_AttributeError, space.wrap(
                           "incompatible shape for a non-contiguous array"))
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -152,11 +152,11 @@
 
     def test_calc_new_strides(self):
         from pypy.module.micronumpy.interp_numarray import calc_new_strides
-        assert calc_new_strides([2, 4], [4, 2], [4, 2]) == [8, 2]
-        assert calc_new_strides([2, 4, 3], [8, 3], [1, 16]) == [1, 2, 16]
-        assert calc_new_strides([2, 3, 4], [8, 3], [1, 16]) is None
-        assert calc_new_strides([24], [2, 4, 3], [48, 6, 1]) is None
-        assert calc_new_strides([24], [2, 4, 3], [24, 6, 2]) == [2]
+        assert calc_new_strides([2, 4], [4, 2], [4, 2], "C") == [8, 2]
+        assert calc_new_strides([2, 4, 3], [8, 3], [1, 16], 'F') == [1, 2, 16]
+        assert calc_new_strides([2, 3, 4], [8, 3], [1, 16], 'F') is None
+        assert calc_new_strides([24], [2, 4, 3], [48, 6, 1], 'C') is None
+        assert calc_new_strides([24], [2, 4, 3], [24, 6, 2], 'C') == [2]
 
 
 class AppTestNumArray(BaseNumpyAppTest):
@@ -381,6 +381,8 @@
         a.shape = ()
         #numpy allows this
         a.shape = (1,)
+        a = array(range(6)).reshape(2,3).T
+        raises(AttributeError, 'a.shape = 6')
 
     def test_reshape(self):
         from _numpypy import array, zeros
@@ -765,7 +767,7 @@
         assert (a[:, 1, :].sum(1) == [70, 315, 560]).all()
         raises (ValueError, 'a[:, 1, :].sum(2)')
         assert ((a + a).T.sum(2).T == (a + a).sum(0)).all()
-        skip("Those are broken on reshape, fix!")
+        skip("Those are broken, fix after removing Scalar")
         assert (a.reshape(1,-1).sum(0) == range(105)).all()
         assert (a.reshape(1,-1).sum(1) == 5460)
 
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to