Author: Richard Plangger <[email protected]>
Branch: vecopt-merge-iterator-sharing
Changeset: r78985:3cc178b04857
Date: 2015-08-14 13:24 +0200
http://bitbucket.org/pypy/pypy/changeset/3cc178b04857/
Log: started to transform call2 to share iterators in the loop; it works,
but still need to check whether the JIT-generated code improves as well
diff --git a/pypy/module/micronumpy/iterators.py
b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -83,6 +83,11 @@
self._indices = indices
self.offset = offset
+ def same(self, other):
+     """ Return True if this state and `other` are at the same offset
+     and their iterators report the same shape (see same_shape).
+     """
+     return (self.offset == other.offset and
+             self.iterator.same_shape(other.iterator))
+
class ArrayIter(object):
_immutable_fields_ = ['contiguous', 'array', 'size', 'ndim_m1',
'shape_m1[*]',
@@ -100,6 +105,7 @@
self.array = array
self.size = size
self.ndim_m1 = len(shape) - 1
+ #
self.shape_m1 = [s - 1 for s in shape]
self.strides = strides
self.backstrides = backstrides
@@ -113,6 +119,17 @@
factors[ndim-i-1] = factors[ndim-i] * shape[ndim-i]
self.factors = factors
+ def same_shape(self, other):
+     """ Return True if `other` is laid out exactly like this iterator:
+     same contiguity, dtype, shape, strides, backstrides and factors.
+
+     if two iterators share the same shape,
+     next() only needs to be called on one!
+     """
+     return (self.contiguous == other.contiguous and
+             # must compare against other's dtype; comparing self's
+             # dtype with itself is always True and checks nothing
+             self.array.dtype is other.array.dtype and
+             self.shape_m1 == other.shape_m1 and
+             self.strides == other.strides and
+             self.backstrides == other.backstrides and
+             self.factors == other.factors)
+
@jit.unroll_safe
def reset(self, state=None, mutate=False):
index = 0
@@ -196,7 +213,7 @@
return state.index >= self.size
def getitem(self, state):
- assert state.iterator is self
+ # NOTE: identity check disabled -- call2 may now pass in a state that
+ # is shared with (and was created by) another iterator.
+ # assert state.iterator is self
return self.array.getitem(state.offset)
def getitem_bool(self, state):
@@ -207,7 +224,6 @@
assert state.iterator is self
self.array.setitem(state.offset, elem)
-
def AxisIter(array, shape, axis):
strides = array.get_strides()
backstrides = array.get_backstrides()
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -15,7 +15,7 @@
call2_driver = jit.JitDriver(
name='numpy_call2',
- greens=['shapelen', 'func', 'left', 'right', 'calc_dtype', 'res_dtype'],
+ greens=['shapelen','state_count', 'left_index', 'right_index', 'left',
'right', 'func', 'calc_dtype', 'res_dtype'],
reds='auto', vectorize=True)
def call2(space, shape, func, calc_dtype, w_lhs, w_rhs, out):
@@ -38,20 +38,62 @@
out_iter, out_state = out.create_iter(shape)
shapelen = len(shape)
res_dtype = out.get_dtype()
- while not out_iter.done(out_state):
- call2_driver.jit_merge_point(shapelen=shapelen, func=func,
+
+ states = [out_state,left_state,right_state]
+ out_index = 0
+ left_index = 1
+ right_index = 2
+ # left == right == out
+ # left == right
+ # left == out
+ # right == out
+ if not right_iter:
+ del states[2]
+ else:
+ if out_state.same(right_state):
+ # (1) out and right are the same -> remove right
+ right_index = 0
+ del states[2]
+ if not left_iter:
+ del states[1]
+ if right_index == 2:
+ right_index = 1
+ else:
+ if out_state.same(left_state):
+ # (2) out and left are the same -> remove left
+ left_index = 0
+ del states[1]
+ if right_index == 2:
+ right_index = 1
+ else:
+ if len(states) == 3: # did not enter (1)
+ if right_iter and right_state.same(left_state):
+ right_index = 1
+ del states[2]
+ state_count = len(states)
+ #
+ while not out_iter.done(states[0]):
+ call2_driver.jit_merge_point(shapelen=shapelen,
+ func=func,
left=left_iter is None,
right=right_iter is None,
- calc_dtype=calc_dtype,
res_dtype=res_dtype)
+ state_count=state_count,
+ left_index=left_index,
+ right_index=right_index,
+ calc_dtype=calc_dtype,
+ res_dtype=res_dtype)
if left_iter:
+ left_state = states[left_index]
w_left = left_iter.getitem(left_state).convert_to(space,
calc_dtype)
- left_state = left_iter.next(left_state)
if right_iter:
+ right_state = states[right_index]
w_right = right_iter.getitem(right_state).convert_to(space,
calc_dtype)
- right_state = right_iter.next(right_state)
w_out = func(calc_dtype, w_left, w_right)
- out_iter.setitem(out_state, w_out.convert_to(space, res_dtype))
- out_state = out_iter.next(out_state)
+ out_iter.setitem(states[0], w_out.convert_to(space, res_dtype))
+ #
+ for i,state in enumerate(states):
+ states[i] = state.iterator.next(state)
+
# if not set to None, the values will be loop carried
# (for the var,var case), forcing the vectorization to unpack
# the vector registers at the end of the loop
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit