Author: Richard Plangger <r...@pasra.at>
Branch: vecopt
Changeset: r78558:04f489dd60dd
Date: 2015-07-16 12:16 +0200
http://bitbucket.org/pypy/pypy/changeset/04f489dd60dd/

Log:    added ABC (array bound check) optimization, which is turned on when
        user code is vectorized. Note that this optimization versions the
        loop immediately (still to be tested): it introduces a guard before
        the loop is entered in order to remove guards contained in the loop body.

diff --git a/rpython/jit/metainterp/history.py 
b/rpython/jit/metainterp/history.py
--- a/rpython/jit/metainterp/history.py
+++ b/rpython/jit/metainterp/history.py
@@ -746,7 +746,7 @@
 
 class LoopVersion(object):
 
-    def __init__(self, operations, opt_ops, invariant_arg_count=0, 
aligned=False):
+    def __init__(self, operations, aligned=False):
         self.operations = operations
         self.aligned = aligned
         self.faildescrs = []
@@ -756,6 +756,8 @@
         label = operations[idx]
         self.label_pos = idx
         self.inputargs = label.getarglist()
+
+    def register_all_guards(self, opt_ops, invariant_arg_count=0):
         idx = index_of_first(rop.LABEL, opt_ops)
         assert idx >= 0
         version_failargs = opt_ops[idx].getarglist()
@@ -769,6 +771,7 @@
 
         for op in opt_ops:
             if op.is_guard():
+                import pdb; pdb.set_trace()
                 assert isinstance(op, GuardResOp)
                 descr = op.getdescr()
                 if descr.loop_version():
@@ -780,6 +783,13 @@
                     op.setfailargs(version_failargs)
                     op.rd_snapshot = None
 
+    def register_guard(self, op):
+        assert isinstance(op, GuardResOp)
+        descr = op.getdescr()
+        self.faildescrs.append(descr)
+        op.setfailargs(self.inputargs)
+        op.rd_snapshot = None
+
     def copy_operations(self):
         return [op.clone() for op in self.operations]
 
@@ -803,7 +813,6 @@
     call_pure_results = None
     logops = None
     quasi_immutable_deps = None
-    versions = None
 
     def _token(*args):
         raise Exception("TreeLoop.token is killed")
@@ -816,6 +825,7 @@
 
     def __init__(self, name):
         self.name = name
+        self.versions = []
         # self.operations = list of ResOperations
         #   ops of the kind 'guard_xxx' contain a further list of operations,
         #   which may itself contain 'guard_xxx' and so on, making a tree.
@@ -841,6 +851,14 @@
         """ return the first operation having the same opnum or -1 """
         return index_of_first(opnum, self.operations)
 
+    def snapshot(self):
+        version = LoopVersion(self.copy_operations(), [])
+        self.versions.append(version)
+        return version
+
+    def copy_operations(self):
+        return [ op.clone() for op in self.operations ]
+
     def get_display_text(self):    # for graphpage.py
         return self.name + '\n' + repr(self.inputargs)
 
diff --git a/rpython/jit/metainterp/optimizeopt/__init__.py 
b/rpython/jit/metainterp/optimizeopt/__init__.py
--- a/rpython/jit/metainterp/optimizeopt/__init__.py
+++ b/rpython/jit/metainterp/optimizeopt/__init__.py
@@ -73,7 +73,7 @@
                  or warmstate.vectorize_user):
                 optimize_vector(metainterp_sd, jitdriver_sd, loop,
                                 optimizations, inline_short_preamble,
-                                start_state, warmstate.vec_cost)
+                                start_state, warmstate)
             else:
                 return optimize_unroll(metainterp_sd, jitdriver_sd, loop,
                                        optimizations, inline_short_preamble,
diff --git a/rpython/jit/metainterp/optimizeopt/dependency.py 
b/rpython/jit/metainterp/optimizeopt/dependency.py
--- a/rpython/jit/metainterp/optimizeopt/dependency.py
+++ b/rpython/jit/metainterp/optimizeopt/dependency.py
@@ -891,8 +891,9 @@
         return mycoeff + self.constant - (othercoeff + other.constant)
 
     def emit_operations(self, opt, result_box=None):
-        assert not self.is_identity()
         box = self.var
+        if self.is_identity():
+            return box
         last_op = None
         if self.coefficient_mul != 1:
             box_result = box.clonebox()
@@ -904,25 +905,31 @@
             last_op = ResOperation(rop.INT_FLOORDIV, [box, 
ConstInt(self.coefficient_div)], box_result)
             opt.emit_operation(last_op)
             box = box_result
-        if self.constant != 0:
+        if self.constant > 0:
             box_result = box.clonebox()
             last_op = ResOperation(rop.INT_ADD, [box, 
ConstInt(self.constant)], box_result)
             opt.emit_operation(last_op)
             box = box_result
+        if self.constant < 0:
+            box_result = box.clonebox()
+            last_op = ResOperation(rop.INT_SUB, [box, 
ConstInt(self.constant)], box_result)
+            opt.emit_operation(last_op)
+            box = box_result
         if result_box is not None:
             last_op.result = box = result_box
         return box
 
     def compare(self, other):
-        assert isinstance(other, IndexVar)
+        """ returns if the two are compareable as a first result
+            and a number (-1,0,1) of the ordering
+        """
         v1 = (self.coefficient_mul // self.coefficient_div) + self.constant
         v2 = (other.coefficient_mul // other.coefficient_div) + other.constant
-        if v1 == v2:
-            return 0
-        elif v1 < v2:
-            return -1
-        else:
-            return 1
+        c = (v1 - v2)
+        if self.var.same_box(other.var):
+            #print "cmp(",self,",",other,") =>", (v1 - v2)
+            return True, (v1 - v2)
+        return False, 0
 
     def __repr__(self):
         if self.is_identity():
diff --git a/rpython/jit/metainterp/optimizeopt/guard.py 
b/rpython/jit/metainterp/optimizeopt/guard.py
--- a/rpython/jit/metainterp/optimizeopt/guard.py
+++ b/rpython/jit/metainterp/optimizeopt/guard.py
@@ -16,27 +16,38 @@
     """ An object wrapper around a guard. Helps to determine
         if one guard implies another
     """
-    def __init__(self, index, op, cmp_op, lhs_arg, rhs_arg):
+    def __init__(self, index, op, cmp_op, index_vars):
         self.index = index
         self.op = op
         self.cmp_op = cmp_op
-        self.lhs_arg = lhs_arg
-        self.rhs_arg = rhs_arg
         self.lhs_key = None
         self.rhs_key = None
+        lhs = cmp_op.getarg(0)
+        self.lhs = index_vars.get(lhs, None)
+        if self.lhs is None:
+            self.lhs = IndexVar(lhs)
+        #
+        rhs = cmp_op.getarg(1)
+        self.rhs = index_vars.get(rhs, None)
+        if self.rhs is None:
+            self.rhs = IndexVar(rhs)
+
+    def getleftkey(self):
+        return self.lhs.getvariable()
+
+    def getrightkey(self):
+        return self.rhs.getvariable()
 
     def implies(self, guard, opt):
         if self.op.getopnum() != guard.op.getopnum():
             return False
 
-        if self.lhs_key == guard.lhs_key:
+        if self.getleftkey() is guard.getleftkey():
             # same operation
-            valid, lc = self.compare(self.lhs, guard.lhs)
-            if not valid:
-                return False
-            valid, rc = self.compare(self.rhs, guard.rhs)
-            if not valid:
-                return False
+            valid, lc = self.lhs.compare(guard.lhs)
+            if not valid: return False
+            valid, rc = self.rhs.compare(guard.rhs)
+            if not valid: return False
             opnum = self.get_compare_opnum()
             if opnum == -1:
                 return False
@@ -53,6 +64,35 @@
                 return (lc <= 0 and rc >= 0) or (lc == 0 and rc >= 0)
         return False
 
+    def transitive_imply(self, other, opt):
+        if self.op.getopnum() != other.op.getopnum():
+            # stronger restriction, intermixing e.g. <= and < would be possible
+            return None
+        if self.getleftkey() is not other.getleftkey():
+            return None
+        if not self.rhs.is_identity():
+            # stronger restriction
+            return None
+        # this is a valid transitive guard that eliminates the loop guard
+        opnum = self.transitive_cmpop(self.cmp_op.getopnum())
+        box_rhs = self.emit_varops(opt, self.rhs, self.cmp_op.getarg(1))
+        other_rhs = self.emit_varops(opt, other.rhs, other.cmp_op.getarg(1))
+        box_result = self.cmp_op.result.clonebox()
+        opt.emit_operation(ResOperation(opnum, [box_rhs, other_rhs], 
box_result))
+        # guard
+        guard = self.op.clone()
+        guard.setarg(0, box_result)
+        opt.emit_operation(guard)
+
+        return guard
+
+    def transitive_cmpop(self, opnum):
+        if opnum == rop.INT_LT:
+            return rop.INT_LE
+        if opnum == rop.INT_GT:
+            return rop.INT_GE
+        return opnum
+
     def get_compare_opnum(self):
         opnum = self.op.getopnum()
         if opnum == rop.GUARD_TRUE:
@@ -74,65 +114,38 @@
             myop.rd_snapshot = otherop.rd_snapshot
             myop.setfailargs(otherop.getfailargs())
 
-    def compare(self, key1, key2):
-        if isinstance(key1, Box):
-            if isinstance(key2, Box) and key1 is key2:
-                return True, 0
-            return False, 0
-        #
-        if isinstance(key1, ConstInt):
-            if not isinstance(key2, ConstInt):
-                return False, 0
-            v1 = key1.value
-            v2 = key2.value
-            if v1 == v2:
-                return True, 0
-            elif v1 < v2:
-                return True, -1
-            else:
-                return True, 1
-        #
-        if isinstance(key1, IndexVar):
-            assert isinstance(key2, IndexVar)
-            return True, key1.compare(key2)
-        #
-        raise AssertionError("cannot compare: " + str(key1) + " <=> " + 
str(key2))
-
     def emit_varops(self, opt, var, old_arg):
-        if isinstance(var, IndexVar):
-            if var.is_identity():
-                return var.var
-            box = var.emit_operations(opt)
-            opt.renamer.start_renaming(old_arg, box)
-            return box
-        else:
-            return var
+        assert isinstance(var, IndexVar)
+        if var.is_identity():
+            return var.var
+        box = var.emit_operations(opt)
+        opt.renamer.start_renaming(old_arg, box)
+        return box
 
     def emit_operations(self, opt):
-        lhs, opnum, rhs = opt._get_key(self.cmp_op)
         # create trace instructions for the index
-        box_lhs = self.emit_varops(opt, self.lhs, self.lhs_arg)
-        box_rhs = self.emit_varops(opt, self.rhs, self.rhs_arg)
+        box_lhs = self.emit_varops(opt, self.lhs, self.cmp_op.getarg(0))
+        box_rhs = self.emit_varops(opt, self.rhs, self.cmp_op.getarg(1))
         box_result = self.cmp_op.result.clonebox()
-        opt.emit_operation(ResOperation(opnum, [box_lhs, box_rhs], box_result))
-        # guard
+        opnum = self.cmp_op.getopnum()
+        cmp_op = ResOperation(opnum, [box_lhs, box_rhs], box_result)
+        opt.emit_operation(cmp_op)
+        # emit that actual guard
         guard = self.op.clone()
         guard.setarg(0, box_result)
         opt.emit_operation(guard)
+        guard.index = opt.operation_position()-1
+        guard.op = guard
+        guard.cmp_op = cmp_op
 
-    def update_keys(self, index_vars):
-        self.lhs = index_vars.get(self.lhs_arg, self.lhs_arg)
-        if isinstance(self.lhs, IndexVar):
-            self.lhs = self.lhs.var
-        self.lhs_key = self.lhs
-        #
-        self.rhs = index_vars.get(self.rhs_arg, self.rhs_arg)
-        if isinstance(self.rhs, IndexVar):
-            self.rhs = self.rhs.var
-        self.rhs_key = self.rhs
+    def set_to_none(self, operations):
+        assert operations[self.index] is self.op
+        operations[self.index] = None
+        if operations[self.index-1] is self.cmp_op:
+            operations[self.index-1] = None
 
     @staticmethod
-    def of(boolarg, operations, index):
+    def of(boolarg, operations, index, index_vars):
         guard_op = operations[index]
         i = index - 1
         # most likely hit in the first iteration
@@ -147,9 +160,7 @@
         else:
             raise AssertionError("guard_true/false first arg not defined")
         #
-        lhs_arg = cmp_op.getarg(0)
-        rhs_arg = cmp_op.getarg(1)
-        return Guard(i, guard_op, cmp_op, lhs_arg, rhs_arg)
+        return Guard(index, guard_op, cmp_op, index_vars)
 
 class GuardStrengthenOpt(object):
     def __init__(self, index_vars):
@@ -159,25 +170,6 @@
         self.strongest_guards = {}
         self.guards = {}
 
-    #def _get_key(self, cmp_op):
-    #    assert cmp_op
-    #    lhs_arg = cmp_op.getarg(0)
-    #    rhs_arg = cmp_op.getarg(1)
-    #    lhs_index_var = self.index_vars.get(lhs_arg, None)
-    #    rhs_index_var = self.index_vars.get(rhs_arg, None)
-
-    #    cmp_opnum = cmp_op.getopnum()
-    #    # get the key, this identifies the guarded operation
-    #    if lhs_index_var and rhs_index_var:
-    #        return (lhs_index_var.getvariable(), cmp_opnum, 
rhs_index_var.getvariable())
-    #    elif lhs_index_var:
-    #        return (lhs_index_var.getvariable(), cmp_opnum, None)
-    #    elif rhs_index_var:
-    #        return (None, cmp_opnum, rhs_index_var)
-    #    else:
-    #        return (None, cmp_opnum, None)
-    #    return key
-
     def collect_guard_information(self, loop):
         operations = loop.operations
         last_guard = None
@@ -186,12 +178,11 @@
             if not op.is_guard():
                 continue
             if op.getopnum() in (rop.GUARD_TRUE, rop.GUARD_FALSE):
-                guard = Guard.of(op.getarg(0), operations, i)
+                guard = Guard.of(op.getarg(0), operations, i, self.index_vars)
                 if guard is None:
                     continue
-                guard.update_keys(self.index_vars)
-                self.record_guard(guard.lhs_key, guard)
-                self.record_guard(guard.rhs_key, guard)
+                self.record_guard(guard.getleftkey(), guard)
+                self.record_guard(guard.getrightkey(), guard)
 
     def record_guard(self, key, guard):
         if key is None:
@@ -204,18 +195,23 @@
         # not emitted and the original is replaced with the current
         others = self.strongest_guards.setdefault(key, [])
         if len(others) > 0: # (2)
+            replaced = False
             for i,other in enumerate(others):
                 if guard.implies(other, self):
                     # strengthend
+                    others[i] = guard
+                    self.guards[guard.index] = None # mark as 'do not emit'
                     guard.inhert_attributes(other)
-                    others[i] = guard
                     self.guards[other.index] = guard
-                    self.guards[guard.index] = None # mark as 'do not emit'
+                    replaced = True
                     continue
                 elif other.implies(guard, self):
                     # implied
                     self.guards[guard.index] = None # mark as 'do not emit'
+                    replaced = True
                     continue
+            if not replaced:
+                others.append(guard)
         else: # (2)
             others.append(guard)
 
@@ -247,14 +243,58 @@
         #
         loop.operations = self._newoperations[:]
 
-    def propagate_all_forward(self, loop):
+    def propagate_all_forward(self, loop, user_code=False):
         """ strengthens the guards that protect an integral value """
         # the guards are ordered. guards[i] is before guards[j] iff i < j
         self.collect_guard_information(loop)
-        #
         self.eliminate_guards(loop)
 
+        if user_code:
+            version = loop.snapshot()
+            self.eliminate_array_bound_checks(loop, version)
+
     def emit_operation(self, op):
         self.renamer.rename(op)
         self._newoperations.append(op)
 
+    def operation_position(self):
+        return len(self._newoperations)
+
+    def eliminate_array_bound_checks(self, loop, version):
+        self._newoperations = []
+        for key, guards in self.strongest_guards.items():
+            if len(guards) <= 1:
+                continue
+            # there is more than one guard for that key,
+            # that is why we could imply the guards 2..n
+            # iff we add invariant guards
+            one = guards[0]
+            for other in guards[1:]:
+                transitive_guard = one.transitive_imply(other, self)
+                if transitive_guard:
+                    other.set_to_none(loop.operations)
+                    version.register_guard(transitive_guard)
+
+        loop.operations = self._newoperations + \
+                [op for op in loop.operations if op]
+
+    # OLD
+    #def _get_key(self, cmp_op):
+    #    assert cmp_op
+    #    lhs_arg = cmp_op.getarg(0)
+    #    rhs_arg = cmp_op.getarg(1)
+    #    lhs_index_var = self.index_vars.get(lhs_arg, None)
+    #    rhs_index_var = self.index_vars.get(rhs_arg, None)
+
+    #    cmp_opnum = cmp_op.getopnum()
+    #    # get the key, this identifies the guarded operation
+    #    if lhs_index_var and rhs_index_var:
+    #        return (lhs_index_var.getvariable(), cmp_opnum, 
rhs_index_var.getvariable())
+    #    elif lhs_index_var:
+    #        return (lhs_index_var.getvariable(), cmp_opnum, None)
+    #    elif rhs_index_var:
+    #        return (None, cmp_opnum, rhs_index_var)
+    #    else:
+    #        return (None, cmp_opnum, None)
+    #    return key
+
diff --git a/rpython/jit/metainterp/optimizeopt/test/test_guard.py 
b/rpython/jit/metainterp/optimizeopt/test/test_guard.py
--- a/rpython/jit/metainterp/optimizeopt/test/test_guard.py
+++ b/rpython/jit/metainterp/optimizeopt/test/test_guard.py
@@ -31,10 +31,10 @@
         return abs(val) == 1
 
 class GuardBaseTest(SchedulerBaseTest):
-    def optguards(self, loop):
+    def optguards(self, loop, user_code=False):
         dep = DependencyGraph(loop)
         opt = GuardStrengthenOpt(dep.index_vars)
-        opt.propagate_all_forward(loop)
+        opt.propagate_all_forward(loop, user_code)
         return opt
 
     def assert_guard_count(self, loop, count):
@@ -48,6 +48,8 @@
 
     def assert_contains_sequence(self, loop, instr):
         class Glob(object):
+            next = None
+            prev = None
             def __repr__(self):
                 return '*'
         from rpython.jit.tool.oparser import OpParser, default_fail_descr
@@ -73,33 +75,36 @@
             prev_op = op
 
         def check(op, candidate, rename):
+            m = 0
             if isinstance(candidate, Glob):
                 if candidate.next is None:
                     return 0 # consumes the rest
                 if op.getopnum() != candidate.next.getopnum():
                     return 0
+                m = 1
                 candidate = candidate.next
             if op.getopnum() == candidate.getopnum():
                 for i,arg in enumerate(op.getarglist()):
                     oarg = candidate.getarg(i)
                     if arg in rename:
-                        assert rename[arg] is oarg
+                        assert rename[arg].same_box(oarg)
                     else:
                         rename[arg] = oarg
 
                 if op.result:
                     rename[op.result] = candidate.result
-                return 1
+                m += 1
+                return m
             return 0
         j = 0
         rename = {}
         for i, op in enumerate(loop.operations):
             candidate = operations[j]
             j += check(op, candidate, rename)
-        if isinstance(operations[0], Glob):
-            assert j == len(operations)-2
+        if isinstance(operations[-1], Glob):
+            assert j == len(operations)-1, self.debug_print_operations(loop)
         else:
-            assert j == len(operations)-1
+            assert j == len(operations), self.debug_print_operations(loop)
 
     def test_basic(self):
         loop1 = self.parse("""
@@ -141,17 +146,18 @@
         loop1 = self.parse("""
         i10 = int_gt(i1, 42)
         guard_true(i10) []
-        i11 = int_sub(i1, 1)
-        i12 = int_gt(i11, 42)
+        i11 = int_add(i1, 1)
+        i12 = int_gt(i11, i2)
         guard_true(i12) []
         """)
-        opt = self.optguards(loop1)
-        self.assert_guard_count(loop1, 1)
+        opt = self.optguards(loop1, True)
+        self.assert_guard_count(loop1, 2)
         self.assert_contains_sequence(loop1, """
+        i40 = int_ge(42, i2)
+        guard_true(i40) []
         ...
-        i11 = int_sub(i1, 1)
-        i12 = int_gt(i11, 42)
-        guard_true(i12) []
+        i10 = int_gt(i1, 42)
+        guard_true(i10) []
         ...
         """)
 
diff --git a/rpython/jit/metainterp/optimizeopt/vectorize.py 
b/rpython/jit/metainterp/optimizeopt/vectorize.py
--- a/rpython/jit/metainterp/optimizeopt/vectorize.py
+++ b/rpython/jit/metainterp/optimizeopt/vectorize.py
@@ -32,11 +32,11 @@
 from rpython.rtyper.lltypesystem import lltype, rffi
 
 def optimize_vector(metainterp_sd, jitdriver_sd, loop, optimizations,
-                    inline_short_preamble, start_state, cost_threshold):
+                    inline_short_preamble, start_state, warmstate):
     optimize_unroll(metainterp_sd, jitdriver_sd, loop, optimizations,
                     inline_short_preamble, start_state, False)
-    orig_ops = loop.operations
-    if len(orig_ops) >= 75:
+    version = loop.snapshot()
+    if len(loop.operations) >= 75:
         # if more than 75 operations are present in this loop,
         # it won't be possible to vectorize. There are too many
         # guards that prevent parallel execution of instructions
@@ -52,9 +52,11 @@
         opt = VectorizingOptimizer(metainterp_sd, jitdriver_sd, loop, 
cost_threshold)
         opt.propagate_all_forward()
         gso = GuardStrengthenOpt(opt.dependency_graph.index_vars)
-        gso.propagate_all_forward(opt.loop)
+        user_code = not jitdriver_sd.vectorize and warmstate.vectorize_user
+        gso.propagate_all_forward(opt.loop, user_code)
         # loop versioning
-        loop.versions = [LoopVersion(orig_ops, loop.operations, 
opt.appended_arg_count)]
+        version.register_all_guards(loop.operations, opt.appended_arg_count)
+        loop.versions.append(version)
         #
         #
         end = time.clock()
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to