Author: Richard Plangger <r...@pasra.at>
Branch: vecopt-merge-iterator-sharing
Changeset: r78985:3cc178b04857
Date: 2015-08-14 13:24 +0200
http://bitbucket.org/pypy/pypy/changeset/3cc178b04857/

Log:    Started transforming call2 to share iterator states in the loop; it works,
        but we still need to check whether the generated JIT code improves as well.

diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -83,6 +83,11 @@
         self._indices = indices
         self.offset = offset
 
+    def same(self, other):
+        if self.offset == other.offset:
+            return self.iterator.same_shape(other.iterator)
+        return False
+
 
 class ArrayIter(object):
     _immutable_fields_ = ['contiguous', 'array', 'size', 'ndim_m1', 
'shape_m1[*]',
@@ -100,6 +105,7 @@
         self.array = array
         self.size = size
         self.ndim_m1 = len(shape) - 1
+        #
         self.shape_m1 = [s - 1 for s in shape]
         self.strides = strides
         self.backstrides = backstrides
@@ -113,6 +119,17 @@
                 factors[ndim-i-1] = factors[ndim-i] * shape[ndim-i]
         self.factors = factors
 
+    def same_shape(self, other):
+        """ Return True if both iterators step through memory identically
+        (same dtype/shape/strides), so next() need only run on one of them.
+        """
+        return (self.contiguous == other.contiguous and
+                self.array.dtype is other.array.dtype and
+                self.shape_m1 == other.shape_m1 and
+                self.strides == other.strides and
+                self.backstrides == other.backstrides and
+                self.factors == other.factors)
+
     @jit.unroll_safe
     def reset(self, state=None, mutate=False):
         index = 0
@@ -196,7 +213,7 @@
         return state.index >= self.size
 
     def getitem(self, state):
-        assert state.iterator is self
+        # assert state.iterator is self
         return self.array.getitem(state.offset)
 
     def getitem_bool(self, state):
@@ -207,7 +224,6 @@
         assert state.iterator is self
         self.array.setitem(state.offset, elem)
 
-
 def AxisIter(array, shape, axis):
     strides = array.get_strides()
     backstrides = array.get_backstrides()
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -15,7 +15,7 @@
 
 call2_driver = jit.JitDriver(
     name='numpy_call2',
-    greens=['shapelen', 'func', 'left', 'right', 'calc_dtype', 'res_dtype'],
+    greens=['shapelen','state_count', 'left_index', 'right_index', 'left', 
'right', 'func', 'calc_dtype', 'res_dtype'],
     reds='auto', vectorize=True)
 
 def call2(space, shape, func, calc_dtype, w_lhs, w_rhs, out):
@@ -38,20 +38,62 @@
     out_iter, out_state = out.create_iter(shape)
     shapelen = len(shape)
     res_dtype = out.get_dtype()
-    while not out_iter.done(out_state):
-        call2_driver.jit_merge_point(shapelen=shapelen, func=func,
+
+    states = [out_state,left_state,right_state]
+    out_index = 0
+    left_index = 1
+    right_index = 2
+    # left == right == out
+    # left == right
+    # left == out
+    # right == out
+    if not right_iter:
+        del states[2]
+    else:
+        if out_state.same(right_state):
+            # (1) out and right are the same -> remove right
+            right_index = 0
+            del states[2]
+    if not left_iter:
+        del states[1]
+        if right_index == 2:
+            right_index = 1
+    else:
+        if out_state.same(left_state):
+            # (2) out and left are the same -> remove left
+            left_index = 0
+            del states[1]
+            if right_index == 2:
+                right_index = 1
+        else:
+            if len(states) == 3: # did not enter (1)
+                if right_iter and right_state.same(left_state):
+                    right_index = 1
+                    del states[2]
+    state_count = len(states)
+    #
+    while not out_iter.done(states[0]):
+        call2_driver.jit_merge_point(shapelen=shapelen,
+                                     func=func,
                                      left=left_iter is None,
                                      right=right_iter is None,
-                                     calc_dtype=calc_dtype, 
res_dtype=res_dtype)
+                                     state_count=state_count,
+                                     left_index=left_index,
+                                     right_index=right_index,
+                                     calc_dtype=calc_dtype,
+                                     res_dtype=res_dtype)
         if left_iter:
+            left_state = states[left_index]
             w_left = left_iter.getitem(left_state).convert_to(space, 
calc_dtype)
-            left_state = left_iter.next(left_state)
         if right_iter:
+            right_state = states[right_index]
             w_right = right_iter.getitem(right_state).convert_to(space, 
calc_dtype)
-            right_state = right_iter.next(right_state)
         w_out = func(calc_dtype, w_left, w_right)
-        out_iter.setitem(out_state, w_out.convert_to(space, res_dtype))
-        out_state = out_iter.next(out_state)
+        out_iter.setitem(states[0], w_out.convert_to(space, res_dtype))
+        #
+        for i,state in enumerate(states):
+            states[i] = state.iterator.next(state)
+
         # if not set to None, the values will be loop carried
         # (for the var,var case), forcing the vectorization to unpack
         # the vector registers at the end of the loop
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to