Author: Ronan Lamy <ronan.l...@gmail.com>
Branch: ufunc-reduce
Changeset: r78686:ea4e0e3201e8
Date: 2015-07-27 19:53 +0100
http://bitbucket.org/pypy/pypy/changeset/ea4e0e3201e8/
Log: small simplification diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py --- a/pypy/module/micronumpy/iterators.py +++ b/pypy/module/micronumpy/iterators.py @@ -204,17 +204,16 @@ self.array.setitem(state.offset, elem) -def AxisIter(array, shape, axis, cumulative): +def AxisIter(array, shape, axis): strides = array.get_strides() backstrides = array.get_backstrides() - if not cumulative: - if len(shape) == len(strides): - # keepdims = True - strides = strides[:axis] + [0] + strides[axis + 1:] - backstrides = backstrides[:axis] + [0] + backstrides[axis + 1:] - else: - strides = strides[:axis] + [0] + strides[axis:] - backstrides = backstrides[:axis] + [0] + backstrides[axis:] + if len(shape) == len(strides): + # keepdims = True + strides = strides[:axis] + [0] + strides[axis + 1:] + backstrides = backstrides[:axis] + [0] + backstrides[axis + 1:] + else: + strides = strides[:axis] + [0] + strides[axis:] + backstrides = backstrides[:axis] + [0] + backstrides[axis:] return ArrayIter(array, support.product(shape), shape, strides, backstrides) diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py --- a/pypy/module/micronumpy/loop.py +++ b/pypy/module/micronumpy/loop.py @@ -9,7 +9,7 @@ from pypy.module.micronumpy import support, constants as NPY from pypy.module.micronumpy.base import W_NDimArray, convert_to_array from pypy.module.micronumpy.iterators import PureShapeIter, AxisIter, \ - AllButAxisIter + AllButAxisIter, ArrayIter from pypy.interpreter.argument import Arguments @@ -246,13 +246,13 @@ greens=['shapelen', 'func', 'dtype'], reds='auto') + def do_accumulate(space, shape, func, arr, dtype, axis, out, identity): - out_iter = AxisIter(out.implementation, arr.get_shape(), axis, cumulative=True) - out_state = out_iter.reset() + out_iter, out_state = out.create_iter() obj_shape = arr.get_shape() temp_shape = obj_shape[:axis] + obj_shape[axis + 1:] temp = W_NDimArray.from_shape(space, temp_shape, dtype, 
w_instance=arr) - temp_iter = AxisIter(temp.implementation, arr.get_shape(), axis, False) + temp_iter = AxisIter(temp.implementation, arr.get_shape(), axis) temp_state = temp_iter.reset() arr_iter, arr_state = arr.create_iter() arr_iter.track_index = False @@ -340,7 +340,7 @@ reds='auto') def do_axis_reduce(space, shape, func, arr, dtype, axis, out, identity): - out_iter = AxisIter(out.implementation, arr.get_shape(), axis, cumulative=False) + out_iter = AxisIter(out.implementation, arr.get_shape(), axis) out_state = out_iter.reset() arr_iter, arr_state = arr.create_iter() arr_iter.track_index = False _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit