Author: Matti Picus <[email protected]>
Branch:
Changeset: r61717:36a9d19e67e4
Date: 2013-02-24 01:43 +0200
http://bitbucket.org/pypy/pypy/changeset/36a9d19e67e4/
Log: test, fix segfault for negative axis in concatenate
diff --git a/pypy/module/micronumpy/interp_arrayops.py b/pypy/module/micronumpy/interp_arrayops.py
--- a/pypy/module/micronumpy/interp_arrayops.py
+++ b/pypy/module/micronumpy/interp_arrayops.py
@@ -106,29 +106,32 @@
args_w = [convert_to_array(space, w_arg) for w_arg in args_w]
dtype = args_w[0].get_dtype()
shape = args_w[0].get_shape()[:]
- if len(shape) <= axis:
+ _axis = axis
+ if axis < 0:
+ _axis = len(shape) + axis
+ if _axis < 0 or len(shape) <= _axis:
raise operationerrfmt(space.w_IndexError, "axis %d out of bounds [0,
%d)", axis, len(shape))
for arr in args_w[1:]:
dtype = interp_ufuncs.find_binop_result_dtype(space, dtype,
arr.get_dtype())
- if len(arr.get_shape()) <= axis:
+ if _axis < 0 or len(arr.get_shape()) <= _axis:
raise operationerrfmt(space.w_IndexError, "axis %d out of bounds
[0, %d)", axis, len(shape))
for i, axis_size in enumerate(arr.get_shape()):
- if len(arr.get_shape()) != len(shape) or (i != axis and axis_size
!= shape[i]):
+ if len(arr.get_shape()) != len(shape) or (i != _axis and axis_size
!= shape[i]):
raise OperationError(space.w_ValueError, space.wrap(
"all the input arrays must have same number of
dimensions"))
- elif i == axis:
+ elif i == _axis:
shape[i] += axis_size
res = W_NDimArray.from_shape(shape, dtype, 'C')
chunks = [Chunk(0, i, 1, i) for i in shape]
axis_start = 0
for arr in args_w:
- if arr.get_shape()[axis] == 0:
+ if arr.get_shape()[_axis] == 0:
continue
- chunks[axis] = Chunk(axis_start, axis_start + arr.get_shape()[axis], 1,
- arr.get_shape()[axis])
+ chunks[_axis] = Chunk(axis_start, axis_start + arr.get_shape()[_axis],
1,
+ arr.get_shape()[_axis])
Chunks(chunks).apply(res).implementation.setslice(space, arr)
- axis_start += arr.get_shape()[axis]
+ axis_start += arr.get_shape()[_axis]
return res
@unwrap_spec(repeats=int)
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1435,9 +1435,8 @@
if 0: # XXX why does numpy allow this?
a = concatenate((a1, a2), axis=1)
assert (a == [0,1,2,3,4,5]).all()
- if 0: # segfault
- a = concatenate((a1, a2), axis=-1)
- assert (a == [0,1,2,3,4,5]).all()
+ a = concatenate((a1, a2), axis=-1)
+ assert (a == [0,1,2,3,4,5]).all()
b1 = array([[1, 2], [3, 4]])
b2 = array([[5, 6]])
@@ -1456,12 +1455,10 @@
g1 = array([[0,1,2]])
g2 = array([[3,4,5]])
- if 0: # segfault
- g = concatenate((g1, g2), axis=-2)
- assert (g == [[0,1,2],[3,4,5]]).all()
- if 0: # XXX why does numpy allow this?
- exc = raises(IndexError, concatenate, (g1, g2), axis=-3)
- assert str(exc.value) == "axis -3 out of bounds [0, 2)"
+ g = concatenate((g1, g2), axis=-2)
+ assert (g == [[0,1,2],[3,4,5]]).all()
+ exc = raises(IndexError, concatenate, (g1, g2), axis=-3)
+ assert str(exc.value) == "axis -3 out of bounds [0, 2)"
exc = raises(IndexError, concatenate, (g1, g2), axis=2)
assert str(exc.value) == "axis 2 out of bounds [0, 2)"
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit