Author: Armin Rigo <[email protected]>
Branch: 
Changeset: r74748:cf25c8fc4cdb
Date: 2014-11-27 13:18 +0100
http://bitbucket.org/pypy/pypy/changeset/cf25c8fc4cdb/

Log:    issue #1921: produce int switches that do not promote the incoming
        value when it matches none of the cases and falls through to
        "default".

diff --git a/rpython/jit/codewriter/assembler.py b/rpython/jit/codewriter/assembler.py
--- a/rpython/jit/codewriter/assembler.py
+++ b/rpython/jit/codewriter/assembler.py
@@ -216,10 +216,11 @@
             self.code[pos  ] = chr(target & 0xFF)
             self.code[pos+1] = chr(target >> 8)
         for descr in self.switchdictdescrs:
-            descr.dict = {}
+            as_dict = {}
             for key, switchlabel in descr._labels:
                 target = self.label_positions[switchlabel.name]
-                descr.dict[key] = target
+                as_dict[key] = target
+            descr.attach(as_dict)
 
     def check_result(self):
         # Limitation of the number of registers, from the single-byte encoding
diff --git a/rpython/jit/codewriter/flatten.py b/rpython/jit/codewriter/flatten.py
--- a/rpython/jit/codewriter/flatten.py
+++ b/rpython/jit/codewriter/flatten.py
@@ -243,55 +243,39 @@
         else:
             # A switch.
             #
-            def emitdefaultpath():
-                if block.exits[-1].exitcase == 'default':
-                    self.make_link(block.exits[-1])
-                else:
-                    self.emitline("unreachable")
-                    self.emitline("---")
-            #
-            self.emitline('-live-')
             switches = [link for link in block.exits
                         if link.exitcase != 'default']
             switches.sort(key=lambda link: link.llexitcase)
             kind = getkind(block.exitswitch.concretetype)
-            if len(switches) >= 5 and kind == 'int':
-                # A large switch on an integer, implementable efficiently
-                # with the help of a SwitchDictDescr
-                from rpython.jit.codewriter.jitcode import SwitchDictDescr
-                switchdict = SwitchDictDescr()
-                switchdict._labels = []
-                self.emitline('switch', self.getcolor(block.exitswitch),
-                                        switchdict)
-                emitdefaultpath()
-                #
-                for switch in switches:
-                    key = lltype.cast_primitive(lltype.Signed,
-                                                switch.llexitcase)
-                    switchdict._labels.append((key, TLabel(switch)))
-                    # emit code for that path
-                    self.emitline(Label(switch))
-                    self.make_link(switch)
+            assert kind == 'int'    # XXX
             #
+            # A switch on an integer, implementable efficiently with the
+            # help of a SwitchDictDescr.  We use this even if there are
+            # very few cases: in pyjitpl.py, opimpl_switch() will promote
+            # the int only if it matches one of the cases.
+            from rpython.jit.codewriter.jitcode import SwitchDictDescr
+            switchdict = SwitchDictDescr()
+            switchdict._labels = []
+            self.emitline('-live-')    # for 'guard_value'
+            self.emitline('switch', self.getcolor(block.exitswitch),
+                                    switchdict)
+            # emit the default path
+            if block.exits[-1].exitcase == 'default':
+                self.make_link(block.exits[-1])
             else:
-                # A switch with several possible answers, though not too
-                # many of them -- a chain of int_eq comparisons is fine
-                assert kind == 'int'    # XXX
-                color = self.getcolor(block.exitswitch)
-                self.emitline('int_guard_value', color)
-                for switch in switches:
-                    # make the case described by 'switch'
-                    self.emitline('goto_if_not_int_eq',
-                                  color,
-                                  Constant(switch.llexitcase,
-                                           block.exitswitch.concretetype),
-                                  TLabel(switch))
-                    # emit code for the "taken" path
-                    self.make_link(switch)
-                    # finally, emit the label for the "non-taken" path
-                    self.emitline(Label(switch))
-                #
-                emitdefaultpath()
+                self.emitline("unreachable")
+                self.emitline("---")
+            #
+            for switch in switches:
+                key = lltype.cast_primitive(lltype.Signed,
+                                            switch.llexitcase)
+                switchdict._labels.append((key, TLabel(switch)))
+                # emit code for that path
+                # note: we need a -live- for all the 'guard_false' we produce
+                # if the switched value doesn't match any case.
+                self.emitline(Label(switch))
+                self.emitline('-live-')
+                self.make_link(switch)
 
     def insert_renamings(self, link):
         renamings = {}
diff --git a/rpython/jit/codewriter/jitcode.py b/rpython/jit/codewriter/jitcode.py
--- a/rpython/jit/codewriter/jitcode.py
+++ b/rpython/jit/codewriter/jitcode.py
@@ -1,4 +1,4 @@
-from rpython.jit.metainterp.history import AbstractDescr
+from rpython.jit.metainterp.history import AbstractDescr, ConstInt
 from rpython.jit.codewriter import heaptracker
 from rpython.rlib.objectmodel import we_are_translated
 
@@ -109,6 +109,10 @@
 class SwitchDictDescr(AbstractDescr):
     "Get a 'dict' attribute mapping integer values to bytecode positions."
 
+    def attach(self, as_dict):
+        self.dict = as_dict
+        self.const_keys_in_order = map(ConstInt, sorted(as_dict.keys()))
+
     def __repr__(self):
         dict = getattr(self, 'dict', '?')
         return '<SwitchDictDescr %s>' % (dict,)
diff --git a/rpython/jit/codewriter/test/test_flatten.py b/rpython/jit/codewriter/test/test_flatten.py
--- a/rpython/jit/codewriter/test/test_flatten.py
+++ b/rpython/jit/codewriter/test/test_flatten.py
@@ -282,30 +282,6 @@
             foobar hi_there!
         """)
 
-    def test_switch(self):
-        def f(n):
-            if n == -5:  return 12
-            elif n == 2: return 51
-            elif n == 7: return 1212
-            else:        return 42
-        self.encoding_test(f, [65], """
-            -live-
-            int_guard_value %i0
-            goto_if_not_int_eq %i0, $-5, L1
-            int_return $12
-            ---
-            L1:
-            goto_if_not_int_eq %i0, $2, L2
-            int_return $51
-            ---
-            L2:
-            goto_if_not_int_eq %i0, $7, L3
-            int_return $1212
-            ---
-            L3:
-            int_return $42
-        """)
-
     def test_switch_dict(self):
         def f(x):
             if   x == 1: return 61
diff --git a/rpython/jit/metainterp/pyjitpl.py b/rpython/jit/metainterp/pyjitpl.py
--- a/rpython/jit/metainterp/pyjitpl.py
+++ b/rpython/jit/metainterp/pyjitpl.py
@@ -402,13 +402,26 @@
 
     @arguments("box", "descr", "orgpc")
     def opimpl_switch(self, valuebox, switchdict, orgpc):
-        box = self.implement_guard_value(valuebox, orgpc)
-        search_value = box.getint()
+        search_value = valuebox.getint()
         assert isinstance(switchdict, SwitchDictDescr)
         try:
-            self.pc = switchdict.dict[search_value]
+            target = switchdict.dict[search_value]
         except KeyError:
-            pass
+            # None of the cases match.  Fall back to generating a chain
+            # of 'int_eq'.
+            # xxx as a minor optimization, if that's a bridge, then we would
+            # not need the cases that we already tested (and failed) with
+            # 'guard_value'.  How to do it is not very clear though.
+            for const1 in switchdict.const_keys_in_order:
+                box = self.execute(rop.INT_EQ, valuebox, const1)
+                assert box.getint() == 0
+                target = switchdict.dict[const1.getint()]
+                self.metainterp.generate_guard(rop.GUARD_FALSE, box,
+                                               resumepc=target)
+        else:
+            # found one of the cases
+            self.implement_guard_value(valuebox, orgpc)
+            self.pc = target
 
     @arguments()
     def opimpl_unreachable(self):
@@ -2270,8 +2283,8 @@
         if opnum == rop.GUARD_TRUE:     # a goto_if_not that jumps only now
             if not dont_change_position:
                 frame.pc = frame.jitcode.follow_jump(frame.pc)
-        elif opnum == rop.GUARD_FALSE:     # a goto_if_not that stops jumping
-            pass
+        elif opnum == rop.GUARD_FALSE:     # a goto_if_not that stops jumping;
+            pass                  # or a switch that was in its "default" case
         elif opnum == rop.GUARD_VALUE or opnum == rop.GUARD_CLASS:
             pass        # the pc is already set to the *start* of the opcode
         elif (opnum == rop.GUARD_NONNULL or
diff --git a/rpython/jit/metainterp/test/test_ajit.py b/rpython/jit/metainterp/test/test_ajit.py
--- a/rpython/jit/metainterp/test/test_ajit.py
+++ b/rpython/jit/metainterp/test/test_ajit.py
@@ -698,6 +698,40 @@
         res = self.interp_operations(f, [12311])
         assert res == 42
 
+    def test_switch_bridges(self):
+        from rpython.rlib.rarithmetic import intmask
+        myjitdriver = JitDriver(greens = [], reds = 'auto')
+        lsts = [[-5, 2, 20] * 6,
+                [7, 123, 2] * 6,
+                [12311, -5, 7] * 6,
+                [7, 123, 2] * 4 + [-5, -5, -5] * 2,
+                [7, 123, 2] * 4 + [-5, -5, -5] * 2 + [12311, 12311, 12311],
+                ]
+        def f(case):
+            x = 0
+            i = 0
+            lst = lsts[case]
+            while i < len(lst):
+                myjitdriver.jit_merge_point()
+                n = lst[i]
+                if n == -5:  a = 5
+                elif n == 2: a = 4
+                elif n == 7: a = 3
+                else:        a = 2
+                x = intmask(x * 10 + a)
+                i += 1
+            return x
+        res = self.meta_interp(f, [0], backendopt=True)
+        assert res == intmask(542 * 1001001001001001)
+        res = self.meta_interp(f, [1], backendopt=True)
+        assert res == intmask(324 * 1001001001001001)
+        res = self.meta_interp(f, [2], backendopt=True)
+        assert res == intmask(253 * 1001001001001001)
+        res = self.meta_interp(f, [3], backendopt=True)
+        assert res == intmask(324324324324555555)
+        res = self.meta_interp(f, [4], backendopt=True)
+        assert res == intmask(324324324324555555222)
+
     def test_r_uint(self):
         from rpython.rlib.rarithmetic import r_uint
         myjitdriver = JitDriver(greens = [], reds = ['y'])
@@ -833,23 +867,6 @@
         assert type(res) == bool
         assert not res
 
-    def test_switch_dict(self):
-        def f(x):
-            if   x == 1: return 61
-            elif x == 2: return 511
-            elif x == 3: return -22
-            elif x == 4: return 81
-            elif x == 5: return 17
-            elif x == 6: return 54
-            elif x == 7: return 987
-            elif x == 8: return -12
-            elif x == 9: return 321
-            return -1
-        res = self.interp_operations(f, [5])
-        assert res == 17
-        res = self.interp_operations(f, [15])
-        assert res == -1
-
     def test_int_add_ovf(self):
         def f(x, y):
             try:
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to