Author: Ronan Lamy <ronan.l...@gmail.com>
Branch: ufunc-reduce
Changeset: r78686:ea4e0e3201e8
Date: 2015-07-27 19:53 +0100
http://bitbucket.org/pypy/pypy/changeset/ea4e0e3201e8/

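Log:    small simplification: drop the `cumulative` argument of AxisIter; do_accumulate now iterates `out` via out.create_iter()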

diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -204,17 +204,16 @@
         self.array.setitem(state.offset, elem)
 
 
-def AxisIter(array, shape, axis, cumulative):
+def AxisIter(array, shape, axis):
     strides = array.get_strides()
     backstrides = array.get_backstrides()
-    if not cumulative:
-        if len(shape) == len(strides):
-            # keepdims = True
-            strides = strides[:axis] + [0] + strides[axis + 1:]
-            backstrides = backstrides[:axis] + [0] + backstrides[axis + 1:]
-        else:
-            strides = strides[:axis] + [0] + strides[axis:]
-            backstrides = backstrides[:axis] + [0] + backstrides[axis:]
+    if len(shape) == len(strides):
+        # keepdims = True
+        strides = strides[:axis] + [0] + strides[axis + 1:]
+        backstrides = backstrides[:axis] + [0] + backstrides[axis + 1:]
+    else:
+        strides = strides[:axis] + [0] + strides[axis:]
+        backstrides = backstrides[:axis] + [0] + backstrides[axis:]
     return ArrayIter(array, support.product(shape), shape, strides, backstrides)
 
 
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -9,7 +9,7 @@
 from pypy.module.micronumpy import support, constants as NPY
 from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
 from pypy.module.micronumpy.iterators import PureShapeIter, AxisIter, \
-    AllButAxisIter
+    AllButAxisIter, ArrayIter
 from pypy.interpreter.argument import Arguments
 
 
@@ -246,13 +246,13 @@
                                   greens=['shapelen', 'func', 'dtype'],
                                   reds='auto')
 
+
 def do_accumulate(space, shape, func, arr, dtype, axis, out, identity):
-    out_iter = AxisIter(out.implementation, arr.get_shape(), axis, cumulative=True)
-    out_state = out_iter.reset()
+    out_iter, out_state = out.create_iter()
     obj_shape = arr.get_shape()
     temp_shape = obj_shape[:axis] + obj_shape[axis + 1:]
     temp = W_NDimArray.from_shape(space, temp_shape, dtype, w_instance=arr)
-    temp_iter = AxisIter(temp.implementation, arr.get_shape(), axis, False)
+    temp_iter = AxisIter(temp.implementation, arr.get_shape(), axis)
     temp_state = temp_iter.reset()
     arr_iter, arr_state = arr.create_iter()
     arr_iter.track_index = False
@@ -340,7 +340,7 @@
                                    reds='auto')
 
 def do_axis_reduce(space, shape, func, arr, dtype, axis, out, identity):
-    out_iter = AxisIter(out.implementation, arr.get_shape(), axis, cumulative=False)
+    out_iter = AxisIter(out.implementation, arr.get_shape(), axis)
     out_state = out_iter.reset()
     arr_iter, arr_state = arr.create_iter()
     arr_iter.track_index = False
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to