Author: mattip
Branch: numpypy-shape-bug
Changeset: r51671:5a8fc969e644
Date: 2012-01-19 01:55 +0200
http://bitbucket.org/pypy/pypy/changeset/5a8fc969e644/
Log: add more tests
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -183,10 +183,9 @@
n_old_elems_to_use *= old_shape[oldI]
if n_new_elems_used == n_old_elems_to_use:
oldI += 1
- if oldI >= len(old_shape):
- break
- cur_step = steps[oldI]
- n_old_elems_to_use *= old_shape[oldI]
+ if oldI < len(old_shape):
+ cur_step = steps[oldI]
+ n_old_elems_to_use *= old_shape[oldI]
elif order == 'C':
for i in range(len(old_shape) - 1, -1, -1):
steps.insert(0, old_strides[i] / last_step)
@@ -206,10 +205,10 @@
n_old_elems_to_use *= old_shape[oldI]
if n_new_elems_used == n_old_elems_to_use:
oldI -= 1
- if oldI < -len(old_shape):
- break
- cur_step = steps[oldI]
- n_old_elems_to_use *= old_shape[oldI]
+ if oldI >= -len(old_shape):
+ cur_step = steps[oldI]
+ n_old_elems_to_use *= old_shape[oldI]
+ assert len(new_strides) == len(new_shape)
return new_strides
class BaseArray(Wrappable):
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -157,6 +157,13 @@
assert calc_new_strides([2, 3, 4], [8, 3], [1, 16], 'F') is None
assert calc_new_strides([24], [2, 4, 3], [48, 6, 1], 'C') is None
assert calc_new_strides([24], [2, 4, 3], [24, 6, 2], 'C') == [2]
+ assert calc_new_strides([105, 1], [3, 5, 7], [35, 7, 1],'C') == [1, 1]
+ assert calc_new_strides([1, 105], [3, 5, 7], [35, 7, 1],'C') == [105, 1]
+ assert calc_new_strides([1, 105], [3, 5, 7], [35, 7, 1],'F') is None
+ assert calc_new_strides([1, 1, 1, 105, 1], [15, 7], [7, 1],'C') == \
+ [105, 105, 105, 1, 1]
+ assert calc_new_strides([1, 1, 105, 1, 1], [7, 15], [1, 7],'F') == \
+ [1, 1, 1, 105, 105]
class AppTestNumArray(BaseNumpyAppTest):
@@ -767,7 +774,6 @@
assert (a[:, 1, :].sum(1) == [70, 315, 560]).all()
raises (ValueError, 'a[:, 1, :].sum(2)')
assert ((a + a).T.sum(2).T == (a + a).sum(0)).all()
- skip("Those are broken, fix after removing Scalar")
assert (a.reshape(1,-1).sum(0) == range(105)).all()
assert (a.reshape(1,-1).sum(1) == 5460)
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit