Author: Richard Plangger <r...@pasra.at> Branch: vecopt-merge-iterator-sharing Changeset: r78985:3cc178b04857 Date: 2015-08-14 13:24 +0200 http://bitbucket.org/pypy/pypy/changeset/3cc178b04857/
Log: Started transforming call2 to share iterators in the loop. It works, but it still needs to be checked whether the generated JIT code improves as well. diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py --- a/pypy/module/micronumpy/iterators.py +++ b/pypy/module/micronumpy/iterators.py @@ -83,6 +83,11 @@ self._indices = indices self.offset = offset + def same(self, other): + if self.offset == other.offset: + return self.iterator.same_shape(other.iterator) + return False + class ArrayIter(object): _immutable_fields_ = ['contiguous', 'array', 'size', 'ndim_m1', 'shape_m1[*]', @@ -100,6 +105,7 @@ self.array = array self.size = size self.ndim_m1 = len(shape) - 1 + # self.shape_m1 = [s - 1 for s in shape] self.strides = strides self.backstrides = backstrides @@ -113,6 +119,17 @@ factors[ndim-i-1] = factors[ndim-i] * shape[ndim-i] self.factors = factors + def same_shape(self, other): + """ if two iterators share the same shape, + next() only needs to be called on one! + """ + return (self.contiguous == other.contiguous and + self.array.dtype is self.array.dtype and + self.shape_m1 == other.shape_m1 and + self.strides == other.strides and + self.backstrides == other.backstrides and + self.factors == other.factors) + @jit.unroll_safe def reset(self, state=None, mutate=False): index = 0 @@ -196,7 +213,7 @@ return state.index >= self.size def getitem(self, state): - assert state.iterator is self + # assert state.iterator is self return self.array.getitem(state.offset) def getitem_bool(self, state): @@ -207,7 +224,6 @@ assert state.iterator is self self.array.setitem(state.offset, elem) - def AxisIter(array, shape, axis): strides = array.get_strides() backstrides = array.get_backstrides() diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py --- a/pypy/module/micronumpy/loop.py +++ b/pypy/module/micronumpy/loop.py @@ -15,7 +15,7 @@ call2_driver = jit.JitDriver( name='numpy_call2', - greens=['shapelen', 'func', 'left', 'right', 'calc_dtype', 
'res_dtype'], + greens=['shapelen','state_count', 'left_index', 'right_index', 'left', 'right', 'func', 'calc_dtype', 'res_dtype'], reds='auto', vectorize=True) def call2(space, shape, func, calc_dtype, w_lhs, w_rhs, out): @@ -38,20 +38,62 @@ out_iter, out_state = out.create_iter(shape) shapelen = len(shape) res_dtype = out.get_dtype() - while not out_iter.done(out_state): - call2_driver.jit_merge_point(shapelen=shapelen, func=func, + + states = [out_state,left_state,right_state] + out_index = 0 + left_index = 1 + right_index = 2 + # left == right == out + # left == right + # left == out + # right == out + if not right_iter: + del states[2] + else: + if out_state.same(right_state): + # (1) out and right are the same -> remove right + right_index = 0 + del states[2] + if not left_iter: + del states[1] + if right_index == 2: + right_index = 1 + else: + if out_state.same(left_state): + # (2) out and left are the same -> remove left + left_index = 0 + del states[1] + if right_index == 2: + right_index = 1 + else: + if len(states) == 3: # did not enter (1) + if right_iter and right_state.same(left_state): + right_index = 1 + del states[2] + state_count = len(states) + # + while not out_iter.done(states[0]): + call2_driver.jit_merge_point(shapelen=shapelen, + func=func, left=left_iter is None, right=right_iter is None, - calc_dtype=calc_dtype, res_dtype=res_dtype) + state_count=state_count, + left_index=left_index, + right_index=right_index, + calc_dtype=calc_dtype, + res_dtype=res_dtype) if left_iter: + left_state = states[left_index] w_left = left_iter.getitem(left_state).convert_to(space, calc_dtype) - left_state = left_iter.next(left_state) if right_iter: + right_state = states[right_index] w_right = right_iter.getitem(right_state).convert_to(space, calc_dtype) - right_state = right_iter.next(right_state) w_out = func(calc_dtype, w_left, w_right) - out_iter.setitem(out_state, w_out.convert_to(space, res_dtype)) - out_state = out_iter.next(out_state) + 
out_iter.setitem(states[0], w_out.convert_to(space, res_dtype)) + # + for i,state in enumerate(states): + states[i] = state.iterator.next(state) + # if not set to None, the values will be loop carried # (for the var,var case), forcing the vectorization to unpack # the vector registers at the end of the loop _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit