Author: Ronan Lamy <[email protected]>
Branch: ufunc-reduce
Changeset: r78686:ea4e0e3201e8
Date: 2015-07-27 19:53 +0100
http://bitbucket.org/pypy/pypy/changeset/ea4e0e3201e8/
Log: small simplification
diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -204,17 +204,16 @@
self.array.setitem(state.offset, elem)
-def AxisIter(array, shape, axis, cumulative):
+def AxisIter(array, shape, axis):
strides = array.get_strides()
backstrides = array.get_backstrides()
- if not cumulative:
- if len(shape) == len(strides):
- # keepdims = True
- strides = strides[:axis] + [0] + strides[axis + 1:]
- backstrides = backstrides[:axis] + [0] + backstrides[axis + 1:]
- else:
- strides = strides[:axis] + [0] + strides[axis:]
- backstrides = backstrides[:axis] + [0] + backstrides[axis:]
+ if len(shape) == len(strides):
+ # keepdims = True
+ strides = strides[:axis] + [0] + strides[axis + 1:]
+ backstrides = backstrides[:axis] + [0] + backstrides[axis + 1:]
+ else:
+ strides = strides[:axis] + [0] + strides[axis:]
+ backstrides = backstrides[:axis] + [0] + backstrides[axis:]
return ArrayIter(array, support.product(shape), shape, strides,
backstrides)
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -9,7 +9,7 @@
from pypy.module.micronumpy import support, constants as NPY
from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
from pypy.module.micronumpy.iterators import PureShapeIter, AxisIter, \
- AllButAxisIter
+ AllButAxisIter, ArrayIter
from pypy.interpreter.argument import Arguments
@@ -246,13 +246,13 @@
greens=['shapelen', 'func', 'dtype'],
reds='auto')
+
def do_accumulate(space, shape, func, arr, dtype, axis, out, identity):
- out_iter = AxisIter(out.implementation, arr.get_shape(), axis, cumulative=True)
- out_state = out_iter.reset()
+ out_iter, out_state = out.create_iter()
obj_shape = arr.get_shape()
temp_shape = obj_shape[:axis] + obj_shape[axis + 1:]
temp = W_NDimArray.from_shape(space, temp_shape, dtype, w_instance=arr)
- temp_iter = AxisIter(temp.implementation, arr.get_shape(), axis, False)
+ temp_iter = AxisIter(temp.implementation, arr.get_shape(), axis)
temp_state = temp_iter.reset()
arr_iter, arr_state = arr.create_iter()
arr_iter.track_index = False
@@ -340,7 +340,7 @@
reds='auto')
def do_axis_reduce(space, shape, func, arr, dtype, axis, out, identity):
- out_iter = AxisIter(out.implementation, arr.get_shape(), axis, cumulative=False)
+ out_iter = AxisIter(out.implementation, arr.get_shape(), axis)
out_state = out_iter.reset()
arr_iter, arr_state = arr.create_iter()
arr_iter.track_index = False
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit