Author: Carl Friedrich Bolz <cfb...@gmx.de>
Branch: better-storesink
Changeset: r87152:ca579ca81332
Date: 2016-08-23 11:40 +0100
http://bitbucket.org/pypy/pypy/changeset/ca579ca81332/

Log:    support for more complex cases in the merge code:

        sometimes there are two operations in the cache that operate on
        different arguments, but those arguments are merged into one
        variable at the merge point. this leads to a new entry in the
        cache of the merged block.

        (the code is a bit terrible and needs refactoring)

diff --git a/rpython/translator/backendopt/cse.py b/rpython/translator/backendopt/cse.py
--- a/rpython/translator/backendopt/cse.py
+++ b/rpython/translator/backendopt/cse.py
@@ -43,7 +43,41 @@
         block = firstlink.target
         # copy all operations that exist in *all* blocks over. need to add a new
         # inputarg if the result is really a variable
+
+        # try non-straight merges
+        for argindex, inputarg in enumerate(block.inputargs):
+            # bit slow, but probably ok
+            firstlinkarg = 
self.variable_families.find_rep(firstlink.args[argindex])
+            results = []
+            for key, res in self.purecache.iteritems():
+                (opname, concretetype, args) = key
+                if args[0] != firstlinkarg: # XXX other args
+                    continue
+                results.append(res)
+                for linkindex, (link, cache) in enumerate(tuples):
+                    if linkindex == 0:
+                        continue
+                    listargs = list(args)
+                    listargs[0] = 
self.variable_families.find_rep(link.args[argindex])
+                    newkey = (opname, concretetype, tuple(listargs))
+                    otherres = cache.purecache.get(newkey, None)
+                    if otherres is None:
+                        break
+                    results.append(otherres)
+                else:
+                    listargs = list(args)
+                    listargs[0] = self.variable_families.find_rep(inputarg)
+                    newkey = (opname, concretetype, tuple(listargs))
+                    newres = res
+                    if isinstance(res, Variable):
+                        newres = res.copy()
+                        for linkindex, (link, cache) in enumerate(tuples):
+                            link.args.append(results[linkindex])
+                        block.inputargs.append(newres)
+                    purecache[newkey] = newres
+
         for key, res in self.purecache.iteritems():
+            # "straight" merge: the variable is in all other caches
             for link, cache in tuples[1:]:
                 val = cache.purecache.get(key, None)
                 if val is None:
@@ -57,8 +91,41 @@
                     block.inputargs.append(newres)
                 purecache[key] = newres
 
+        # ______________________
         # merge heapcache
         heapcache = {}
+
+        # try non-straight merges
+        for argindex, inputarg in enumerate(block.inputargs):
+            # bit slow, but probably ok
+            firstlinkarg = 
self.variable_families.find_rep(firstlink.args[argindex])
+            results = []
+            for key, res in self.heapcache.iteritems():
+                (arg, fieldname) = key
+                if arg != firstlinkarg:
+                    continue
+                results.append(res)
+                for linkindex, (link, cache) in enumerate(tuples):
+                    if linkindex == 0:
+                        continue
+                    otherarg = 
self.variable_families.find_rep(link.args[argindex])
+                    newkey = (otherarg, fieldname)
+                    otherres = cache.heapcache.get(newkey, None)
+                    if otherres is None:
+                        break
+                    results.append(otherres)
+                else:
+                    listargs = list(args)
+                    listargs[0] = inputarg
+                    newkey = (self.variable_families.find_rep(inputarg), 
fieldname)
+                    newres = res
+                    if isinstance(res, Variable):
+                        newres = res.copy()
+                        for linkindex, (link, cache) in enumerate(tuples):
+                            link.args.append(results[linkindex])
+                        block.inputargs.append(newres)
+                    heapcache[newkey] = newres
+
         for key, res in self.heapcache.iteritems():
             for link, cache in tuples[1:]:
                 val = cache.heapcache.get(key, None)
@@ -73,6 +140,8 @@
                     block.inputargs.append(newres)
                 heapcache[key] = newres
 
+
+
         return Cache(
                 self.variable_families, self.analyzer, purecache, heapcache)
 
diff --git a/rpython/translator/backendopt/test/test_cse.py b/rpython/translator/backendopt/test/test_cse.py
--- a/rpython/translator/backendopt/test/test_cse.py
+++ b/rpython/translator/backendopt/test/test_cse.py
@@ -61,6 +61,21 @@
         # an add in each branch, but not the final block
         self.check(f, [int, int], int_add=2)
 
+    def test_merge2(self):
+        # in this test we add two different values, but the final add is on the
+        # same different value, so it can be shared
+        def f(i, j):
+            if j:
+                x = i
+                y = x + 1
+            else:
+                x = ~i
+                y = x + 1
+            return (x + 1) * y
+
+        # an add in each branch, but not the final block
+        self.check(f, [int, int], int_add=2)
+
     def test_optimize_across_merge(self):
         def f(i, j):
             k = i + 1
@@ -213,6 +228,28 @@
 
         self.check(f, [int], getfield=0)
 
+    def test_merge2_heapcache(self):
+        class A(object):
+            pass
+
+        def f(i):
+            a1 = A()
+            a1.x = i
+            a2 = A()
+            a2.x = i + 1
+            a3 = A()
+            a3.x = 1 # clear other caches
+            if i:
+                a = a1
+                j = a.x
+            else:
+                a = a2
+                j = a.x
+            j += a.x
+            return j
+
+        self.check(f, [int], getfield=2)
+
     def test_dont_invalidate_on_call(self):
         class A(object):
             pass
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to