Author: Ronan Lamy <ronan.l...@gmail.com>
Branch: ufunc-reduce
Changeset: r78697:77850d7d684c
Date: 2015-07-28 18:37 +0100
http://bitbucket.org/pypy/pypy/changeset/77850d7d684c/
Log: Refactor do_axis_reduce() so that it takes the axis_flags list instead of a single axis

diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -339,30 +339,60 @@
                                    greens=['shapelen', 'func', 'dtype'],
                                    reds='auto')

-def do_axis_reduce(space, func, arr, dtype, axis, out, identity):
-    out_iter = AxisIter(out.implementation, arr.get_shape(), axis)
-    out_state = out_iter.reset()
-    arr_iter, arr_state = arr.create_iter()
-    arr_iter.track_index = False
+def do_axis_reduce(space, func, arr, dtype, axis_flags, out, identity):
+    out_iter, out_state = out.create_iter()
+    out_iter.track_index = False
+    shape = arr.get_shape()
+    strides = arr.implementation.get_strides()
+    backstrides = arr.implementation.get_backstrides()
+    shapelen = len(shape)
+    inner_shape = [-1] * shapelen
+    inner_strides = [-1] * shapelen
+    inner_backstrides = [-1] * shapelen
+    outer_shape = [-1] * shapelen
+    outer_strides = [-1] * shapelen
+    outer_backstrides = [-1] * shapelen
+    for i in range(len(shape)):
+        if axis_flags[i]:
+            inner_shape[i] = shape[i]
+            inner_strides[i] = strides[i]
+            inner_backstrides[i] = backstrides[i]
+            outer_shape[i] = 1
+            outer_strides[i] = 0
+            outer_backstrides[i] = 0
+        else:
+            outer_shape[i] = shape[i]
+            outer_strides[i] = strides[i]
+            outer_backstrides[i] = backstrides[i]
+            inner_shape[i] = 1
+            inner_strides[i] = 0
+            inner_backstrides[i] = 0
+    inner_iter = ArrayIter(arr.implementation, support.product(inner_shape),
+                           inner_shape, inner_strides, inner_backstrides)
+    outer_iter = ArrayIter(arr.implementation, support.product(outer_shape),
+                           outer_shape, outer_strides, outer_backstrides)
+    assert outer_iter.size == out_iter.size
+
     if identity is not None:
         identity = identity.convert_to(space, dtype)
-    shapelen = len(out.get_shape())
-    while not out_iter.done(out_state):
-        axis_reduce_driver.jit_merge_point(shapelen=shapelen, func=func,
-                                           dtype=dtype)
-        w_val = arr_iter.getitem(arr_state).convert_to(space, dtype)
-        arr_state = arr_iter.next(arr_state)
-
-        out_indices = out_iter.indices(out_state)
-        if out_indices[axis] == 0:
-            if identity is not None:
-                w_val = func(dtype, identity, w_val)
+    outer_state = outer_iter.reset()
+    while not outer_iter.done(outer_state):
+        inner_state = inner_iter.reset()
+        inner_state.offset = outer_state.offset
+        if identity is not None:
+            w_val = identity
         else:
-            cur = out_iter.getitem(out_state)
-            w_val = func(dtype, cur, w_val)
-
+            w_val = inner_iter.getitem(inner_state).convert_to(space, dtype)
+            inner_state = inner_iter.next(inner_state)
+        while not inner_iter.done(inner_state):
+            axis_reduce_driver.jit_merge_point(shapelen=shapelen, func=func,
+                                               dtype=dtype)
+            w_item = inner_iter.getitem(inner_state).convert_to(space, dtype)
+            w_val = func(dtype, w_item, w_val)
+            inner_state = inner_iter.next(inner_state)
         out_iter.setitem(out_state, w_val)
         out_state = out_iter.next(out_state)
+        outer_state = outer_iter.next(outer_state)
     return out


diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -410,7 +410,7 @@
                 if self.identity is not None:
                     out.fill(space, self.identity.convert_to(space, dtype))
                 return out
-            loop.do_axis_reduce(space, self.func, obj, dtype, axis,
+            loop.do_axis_reduce(space, self.func, obj, dtype, axis_flags,
                                 out, self.identity)
             if call__array_wrap__:
                 out = space.call_method(obj, '__array_wrap__', out)

_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit