Author: Spenser Andrew Bauman <[email protected]>
Branch: force-virtual-state
Changeset: r86777:cd41860150a5
Date: 2016-08-30 15:08 -0400
http://bitbucket.org/pypy/pypy/changeset/cd41860150a5/

Log:    Try forcing virtual objects to avoid jumping to preamble

diff --git a/rpython/jit/metainterp/optimizeopt/unroll.py b/rpython/jit/metainterp/optimizeopt/unroll.py
--- a/rpython/jit/metainterp/optimizeopt/unroll.py
+++ b/rpython/jit/metainterp/optimizeopt/unroll.py
@@ -16,7 +16,7 @@
 from rpython.rlib.debug import debug_print, debug_start, debug_stop,\
      have_debug_prints
 
-class UnrollableOptimizer(Optimizer):    
+class UnrollableOptimizer(Optimizer):
     def force_op_from_preamble(self, preamble_op):
         if isinstance(preamble_op, PreambleOp):
             if self.optunroll.short_preamble_producer is None:
@@ -120,7 +120,8 @@
                 assert op.get_forwarded() is None
         if check_newops:
             assert not self.optimizer._newoperations
-    
+
+
     def optimize_preamble(self, trace, runtime_boxes, call_pure_results, memo):
         info, newops = self.optimizer.propagate_all_forward(
             trace.get_iter(), call_pure_results, flush=False)
@@ -156,7 +157,7 @@
         current_vs = self.get_virtual_state(end_jump.getarglist())
         # pick the vs we want to jump to
         assert isinstance(celltoken, JitCellToken)
-        
+
         target_virtual_state = self.pick_virtual_state(current_vs,
                                                        state.virtual_state,
                                                 celltoken.target_tokens)
@@ -180,7 +181,7 @@
             self.jump_to_preamble(celltoken, end_jump, info)
             return (UnrollInfo(target_token, label_op, extra_same_as,
                                self.optimizer.quasi_immutable_deps),
-                    self.optimizer._newoperations)            
+                    self.optimizer._newoperations)
 
         try:
             new_virtual_state = self.jump_to_existing_trace(end_jump, label_op,
@@ -191,6 +192,16 @@
             return (UnrollInfo(target_token, label_op, extra_same_as,
                                self.optimizer.quasi_immutable_deps),
                     self.optimizer._newoperations)
+
+        if new_virtual_state is not None:
+            # Attempt to force virtual boxes in order to avoid jumping
+            # to the preamble.
+            try:
+                new_virtual_state = self.jump_to_existing_trace(
+                        end_jump, label_op, state.runtime_boxes, force_boxes=True)
+            except InvalidLoop:
+                pass
+
         if new_virtual_state is not None:
             self.jump_to_preamble(celltoken, end_jump, info)
             return (UnrollInfo(target_token, label_op, extra_same_as,
@@ -199,7 +210,7 @@
 
         self.disable_retracing_if_max_retrace_guards(
             self.optimizer._newoperations, target_token)
-        
+
         return (UnrollInfo(target_token, label_op, extra_same_as,
                            self.optimizer.quasi_immutable_deps),
                 self.optimizer._newoperations)
@@ -241,7 +252,7 @@
         for a in jump_op.getarglist():
             self.optimizer.force_box_for_end_of_preamble(a)
         try:
-            vs = self.jump_to_existing_trace(jump_op, None, runtime_boxes)
+            vs = self.jump_to_existing_trace(jump_op, None, runtime_boxes, False)
         except InvalidLoop:
             return self.jump_to_preamble(cell_token, jump_op, info)
         if vs is None:
@@ -252,6 +263,13 @@
             cell_token.retraced_count += 1
             debug_print('Retracing (%d/%d)' % (cell_token.retraced_count, limit))
         else:
+            # Try forcing boxes to avoid jumping to the preamble
+            try:
+                vs = self.jump_to_existing_trace(jump_op, None, runtime_boxes, True)
+            except InvalidLoop:
+                pass
+            if vs is None:
+                return info, self.optimizer._newoperations[:]
             debug_print("Retrace count reached, jumping to preamble")
             return self.jump_to_preamble(cell_token, jump_op, info)
         exported_state = self.export_state(info.jump_op.getarglist(),
@@ -288,7 +306,7 @@
         return info, self.optimizer._newoperations[:]
 
 
-    def jump_to_existing_trace(self, jump_op, label_op, runtime_boxes):
+    def jump_to_existing_trace(self, jump_op, label_op, runtime_boxes, force_boxes=False):
         jitcelltoken = jump_op.getdescr()
         assert isinstance(jitcelltoken, JitCellToken)
         virtual_state = self.get_virtual_state(jump_op.getarglist())
@@ -299,17 +317,18 @@
                 continue
             try:
                 extra_guards = target_virtual_state.generate_guards(
-                    virtual_state, args, runtime_boxes, self.optimizer)
+                    virtual_state, args, runtime_boxes, self.optimizer,
+                    force_boxes=force_boxes)
                 patchguardop = self.optimizer.patchguardop
                 for guard in extra_guards.extra_guards:
                     if isinstance(guard, GuardResOp):
                         guard.rd_resume_position = patchguardop.rd_resume_position
                         guard.setdescr(compile.ResumeAtPositionDescr())
                     self.send_extra_operation(guard)
-            except VirtualStatesCantMatch:
+            except VirtualStatesCantMatch as e:
                 continue
             args, virtuals = target_virtual_state.make_inputargs_and_virtuals(
-                args, self.optimizer)
+                args, self.optimizer, force_boxes=force_boxes)
             short_preamble = target_token.short_preamble
             try:
                 extra = self.inline_short_preamble(args + virtuals, args,
@@ -452,7 +471,7 @@
         # by short preamble
         label_args = exported_state.virtual_state.make_inputargs(
             targetargs, self.optimizer)
-        
+
         self.short_preamble_producer = ShortPreambleBuilder(
             label_args, exported_state.short_boxes,
             exported_state.short_inputargs, exported_state.exported_infos,
@@ -497,7 +516,7 @@
     * runtime_boxes - runtime values for boxes, necessary when generating
                       guards to jump to
     """
-    
+
     def __init__(self, end_args, next_iteration_args, virtual_state,
                  exported_infos, short_boxes, renamed_inputargs,
                  short_inputargs, runtime_boxes, memo):
diff --git a/rpython/jit/metainterp/optimizeopt/virtualstate.py b/rpython/jit/metainterp/optimizeopt/virtualstate.py
--- a/rpython/jit/metainterp/optimizeopt/virtualstate.py
+++ b/rpython/jit/metainterp/optimizeopt/virtualstate.py
@@ -4,7 +4,7 @@
      ArrayStructInfo, AbstractStructPtrInfo
 from rpython.jit.metainterp.optimizeopt.intutils import \
      MININT, MAXINT, IntBound, IntLowerBound
-from rpython.jit.metainterp.resoperation import rop, ResOperation,\
+from rpython.jit.metainterp.resoperation import rop, ResOperation, \
      InputArgInt, InputArgRef, InputArgFloat
 from rpython.rlib.debug import debug_print
 
@@ -20,7 +20,7 @@
 
 
 class GenerateGuardState(object):
-    def __init__(self, optimizer=None, guards=None, renum=None, bad=None):
+    def __init__(self, optimizer=None, guards=None, renum=None, bad=None, force_boxes=False):
         self.optimizer = optimizer
         self.cpu = optimizer.cpu
         if guards is None:
@@ -32,6 +32,7 @@
         if bad is None:
             bad = {}
         self.bad = bad
+        self.force_boxes = force_boxes
 
     def get_runtime_item(self, box, descr, i):
         array = box.getref_base()
@@ -303,7 +304,7 @@
             opinfo = state.optimizer.getptrinfo(box)
             assert isinstance(opinfo, ArrayPtrInfo)
         else:
-            opinfo = None            
+            opinfo = None
         for i in range(self.length):
             for descr in self.fielddescrs:
                 index = i * len(self.fielddescrs) + descr.get_index()
@@ -514,6 +515,8 @@
         NotVirtualStateInfo.__init__(self, cpu, type, info)
 
     def _generate_guards(self, other, box, runtime_box, state):
+        if state.force_boxes and isinstance(other, VirtualStateInfo):
+            return self._generate_virtual_guards(other, box, runtime_box, state)
         if not isinstance(other, NotVirtualStateInfoPtr):
             raise VirtualStatesCantMatch(
                     'The VirtualStates does not match as a ' +
@@ -545,6 +548,23 @@
     # to an existing compiled loop or retracing the loop. Both alternatives
     # will always generate correct behaviour, but performance will differ.
 
+    def _generate_virtual_guards(self, other, box, runtime_box, state):
+        """
+        Generate the guards and add state information for unifying a virtual
+        object with a non-virtual. This involves forcing the object in the
+        event that unifcation can succeed. Since virtual objects cannot be null,
+        this method need only check that the virtual object has the expected type.
+        """
+        assert isinstance(other, VirtualStateInfo)
+
+        if self.level == LEVEL_CONSTANT:
+            raise VirtualStatesCantMatch(
+                    "cannot unify a constant value with a virtual object")
+
+        if self.level == LEVEL_KNOWNCLASS:
+            if not self.known_class.same_constant(other.known_class):
+                raise VirtualStatesCantMatch("classes don't match")
+
     def _generate_guards_nonnull(self, other, box, runtime_box, extra_guards,
                                  state):
         if not isinstance(other, NotVirtualStateInfoPtr):
@@ -617,10 +637,10 @@
             return False
         return True
 
-    def generate_guards(self, other, boxes, runtime_boxes, optimizer):
+    def generate_guards(self, other, boxes, runtime_boxes, optimizer, force_boxes=False):
         assert (len(self.state) == len(other.state) == len(boxes) ==
                 len(runtime_boxes))
-        state = GenerateGuardState(optimizer)
+        state = GenerateGuardState(optimizer, force_boxes=force_boxes)
         for i in range(len(self.state)):
             self.state[i].generate_guards(other.state[i], boxes[i],
                                           runtime_boxes[i], state)
@@ -644,8 +664,8 @@
 
         return boxes
 
-    def make_inputargs_and_virtuals(self, inputargs, optimizer):
-        inpargs = self.make_inputargs(inputargs, optimizer)
+    def make_inputargs_and_virtuals(self, inputargs, optimizer, force_boxes=False):
+        inpargs = self.make_inputargs(inputargs, optimizer, force_boxes)
         # we append the virtuals here in case some stuff is proven
         # to be not a virtual and there are getfields in the short preamble
         # that will read items out of there
@@ -653,7 +673,7 @@
         for i in range(len(inputargs)):
             if not isinstance(self.state[i], NotVirtualStateInfo):
                 virtuals.append(inputargs[i])
-            
+
         return inpargs, virtuals
 
     def debug_print(self, hdr='', bad=None, metainterp_sd=None):
diff --git a/rpython/jit/metainterp/test/test_ajit.py b/rpython/jit/metainterp/test/test_ajit.py
--- a/rpython/jit/metainterp/test/test_ajit.py
+++ b/rpython/jit/metainterp/test/test_ajit.py
@@ -4507,3 +4507,33 @@
                 i += 1
             return i
         self.meta_interp(f, [])
+
+    def test_loop_unroll_bug(self):
+        driver = JitDriver(greens=[], reds=['acc', 'i', 'val'])
+        class X(object):
+            # _immutable_ = True
+            def __init__(self, v):
+                self.v = v
+
+        class Box(object):
+            def __init__(self, v):
+                self.unbox = v
+
+        const = Box(X(5))
+        def f(v):
+            val   = X(0)
+            acc   = 0
+            i     = 0
+            const.unbox = X(5)
+            while i < 100:
+                driver.can_enter_jit(acc=acc, i=i, val=val)
+                driver.jit_merge_point(acc=acc, i=i, val=val)
+                acc += val.v
+                if i & 0b100 == 0:
+                    val = const.unbox
+                else:
+                    val = X(i)
+                i += 1
+            return acc
+        result = self.meta_interp(f, [10])
+        # import pdb; pdb.set_trace()
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to