Author: Richard Plangger <r...@pasra.at> Branch: vecopt-merge-iterator-sharing Changeset: r78988:690ba1eaa6a8 Date: 2015-08-14 18:50 +0200 http://bitbucket.org/pypy/pypy/changeset/690ba1eaa6a8/
Log: (plan_rich, ronan) first working version that generates all five possible call2 combinations that shares the iterators diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py --- a/pypy/module/micronumpy/iterators.py +++ b/pypy/module/micronumpy/iterators.py @@ -88,7 +88,6 @@ return self.iterator.same_shape(other.iterator) return False - class ArrayIter(object): _immutable_fields_ = ['contiguous', 'array', 'size', 'ndim_m1', 'shape_m1[*]', 'strides[*]', 'backstrides[*]', 'factors[*]', diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py --- a/pypy/module/micronumpy/loop.py +++ b/pypy/module/micronumpy/loop.py @@ -2,6 +2,7 @@ operations. This is the place to look for all the computations that iterate over all the array elements. """ +import py from pypy.interpreter.error import OperationError from rpython.rlib import jit from rpython.rlib.rstring import StringBuilder @@ -13,11 +14,6 @@ from pypy.interpreter.argument import Arguments -call2_driver = jit.JitDriver( - name='numpy_call2', - greens=['shapelen','state_count', 'left_index', 'right_index', 'left', 'right', 'func', 'calc_dtype', 'res_dtype'], - reds='auto', vectorize=True) - def call2(space, shape, func, calc_dtype, w_lhs, w_rhs, out): if w_lhs.get_size() == 1: w_left = w_lhs.get_scalar_value().convert_to(space, calc_dtype) @@ -40,68 +36,102 @@ res_dtype = out.get_dtype() states = [out_state,left_state,right_state] - out_index = 0 left_index = 1 right_index = 2 # left == right == out # left == right # left == out # right == out + params = (space, shapelen, func, calc_dtype, res_dtype, out, + w_left, w_right, left_iter, right_iter, out_iter, + left_state, right_state, out_state) if not right_iter: + # rhs is a scalar del states[2] else: + # rhs is NOT a scalar if out_state.same(right_state): # (1) out and right are the same -> remove right right_index = 0 del states[2] + # if not left_iter: + # lhs is a scalar del states[1] if right_index == 2: 
right_index = 1 + return call2_advance_out_right(*params) else: + # lhs is NOT a scalar if out_state.same(left_state): # (2) out and left are the same -> remove left left_index = 0 del states[1] if right_index == 2: right_index = 1 + return call2_advance_out_right(*params) else: if len(states) == 3: # did not enter (1) if right_iter and right_state.same(left_state): right_index = 1 del states[2] + return call2_advance_out_left_eq_right(*params) + else: + # worst case + return call2_advance_out_left_right(*params) + else: + return call2_advance_out_left(*params) + state_count = len(states) + if state_count == 1: + return call2_advance_out(*params) + + assert 0, "logical problem with the selection of the call 2 case" + +def generate_call2_cases(name, left_state, right_state): + call2_driver = jit.JitDriver(name='numpy_call2_' + name, + greens=['shapelen', 'func', 'calc_dtype', 'res_dtype'], + reds='auto', vectorize=True) # - while not out_iter.done(states[0]): - call2_driver.jit_merge_point(shapelen=shapelen, - func=func, - left=left_iter is None, - right=right_iter is None, - state_count=state_count, - left_index=left_index, - right_index=right_index, - calc_dtype=calc_dtype, - res_dtype=res_dtype) - if left_iter: - left_state = states[left_index] - w_left = left_iter.getitem(left_state).convert_to(space, calc_dtype) - if right_iter: - right_state = states[right_index] - w_right = right_iter.getitem(right_state).convert_to(space, calc_dtype) - w_out = func(calc_dtype, w_left, w_right) - out_iter.setitem(states[0], w_out.convert_to(space, res_dtype)) - # - for i,state in enumerate(states): - states[i] = state.iterator.next(state) + advance_left_state = left_state == "left_state" + advance_right_state = right_state == "right_state" + code = """ + def method(space, shapelen, func, calc_dtype, res_dtype, out, + w_left, w_right, left_iter, right_iter, out_iter, + left_state, right_state, out_state): + while not out_iter.done(out_state): + 
call2_driver.jit_merge_point(shapelen=shapelen, func=func, + calc_dtype=calc_dtype, res_dtype=res_dtype) + if left_iter: + w_left = left_iter.getitem({left_state}).convert_to(space, calc_dtype) + if right_iter: + w_right = right_iter.getitem({right_state}).convert_to(space, calc_dtype) + w_out = func(calc_dtype, w_left, w_right) + out_iter.setitem(out_state, w_out.convert_to(space, res_dtype)) + out_state = out_iter.next(out_state) + if advance_left_state and left_iter: + left_state = left_iter.next(left_state) + if advance_right_state and right_iter: + right_state = right_iter.next(right_state) + # + # if not set to None, the values will be loop carried + # (for the var,var case), forcing the vectorization to unpack + # the vector registers at the end of the loop + if left_iter: + w_left = None + if right_iter: + w_right = None + return out + """ + exec(py.code.Source(code.format(left_state=left_state,right_state=right_state)).compile(), locals()) + method.__name__ = "call2_" + name + return method - # if not set to None, the values will be loop carried - # (for the var,var case), forcing the vectorization to unpack - # the vector registers at the end of the loop - if left_iter: - w_left = None - if right_iter: - w_right = None - return out +call2_advance_out = generate_call2_cases("inc_out", "out_state", "out_state") +call2_advance_out_left = generate_call2_cases("inc_out_left", "left_state", "out_state") +call2_advance_out_right = generate_call2_cases("inc_out_right", "out_state", "right_state") +call2_advance_out_left_eq_right = generate_call2_cases("inc_out_left_eq_right", "left_state", "left_state") +call2_advance_out_left_right = generate_call2_cases("inc_out_left_right", "left_state", "right_state") call1_driver = jit.JitDriver( name='numpy_call1', diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py --- a/pypy/module/micronumpy/test/test_zjit.py +++ b/pypy/module/micronumpy/test/test_zjit.py @@ -911,8 +911,10 @@ 
def test_multidim_slice(self): result = self.run('multidim_slice') assert result == 12 - self.check_trace_count(2) - self.check_vectorized(1,0) # TODO? + self.check_trace_count(3) + # ::2 creates a view object -> needs an inner loop + # that iterates continuous chunks of the matrix + self.check_vectorized(1,1) # NOT WORKING _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit