Author: Armin Rigo <[email protected]>
Branch: stm-thread-2
Changeset: r61407:7000756077c5
Date: 2013-02-18 14:07 +0100
http://bitbucket.org/pypy/pypy/changeset/7000756077c5/

Log:    ThreadLocalReference: implementation.

diff --git a/rpython/rlib/rstm.py b/rpython/rlib/rstm.py
--- a/rpython/rlib/rstm.py
+++ b/rpython/rlib/rstm.py
@@ -1,10 +1,9 @@
-import thread
+import thread, weakref
 from rpython.translator.stm import stmgcintf
 from rpython.rlib.debug import ll_assert, fatalerror
 from rpython.rlib.objectmodel import keepalive_until_here, specialize
 from rpython.rlib.objectmodel import we_are_translated
 from rpython.rlib.rposix import get_errno, set_errno
-from rpython.rlib.rarithmetic import intmask
 from rpython.rtyper.lltypesystem import lltype, llmemory, rffi, rclass
 from rpython.rtyper.lltypesystem.lloperation import llop
 from rpython.rtyper.annlowlevel import cast_instance_to_base_ptr, llhelper
@@ -108,12 +107,16 @@
 # ____________________________________________________________
 
 class ThreadLocalReference(object):
+    _ALL = weakref.WeakKeyDictionary()
+    _COUNT = 0
 
     def __init__(self, Cls):
         "NOT_RPYTHON: must be prebuilt"
         self.Cls = Cls
-        self.unique_id = intmask(id(self))
         self.local = thread._local()
+        self.unique_id = ThreadLocalReference._COUNT
+        ThreadLocalReference._COUNT += 1
+        ThreadLocalReference._ALL[self] = True
 
     def _freeze_(self):
         return True
@@ -121,7 +124,7 @@
     @specialize.arg(0)
     def get(self):
         if we_are_translated():
-            ptr = llop.stm_localref_get(llmemory.Address, self.unique_id)
+            ptr = llop.stm_threadlocalref_get(llmemory.Address, self.unique_id)
             ptr = rffi.cast(rclass.OBJECTPTR, ptr)
             return cast_base_ptr_to_instance(self.Cls, ptr)
         else:
@@ -133,6 +136,14 @@
         if we_are_translated():
             ptr = cast_instance_to_base_ptr(value)
             ptr = rffi.cast(llmemory.Address, ptr)
-            llop.stm_localref_set(lltype.Void, self.unique_id, ptr)
+            llop.stm_threadlocalref_set(lltype.Void, self.unique_id, ptr)
         else:
             self.local.value = value
+
+    @staticmethod
+    def flush_all_in_this_thread():
+        if we_are_translated():
+            llop.stm_threadlocalref_flush(lltype.Void)
+        else:
+            for tlref in ThreadLocalReference._ALL.keys():
+                tlref.local.value = None
diff --git a/rpython/rlib/test/test_rstm.py b/rpython/rlib/test/test_rstm.py
--- a/rpython/rlib/test/test_rstm.py
+++ b/rpython/rlib/test/test_rstm.py
@@ -8,10 +8,13 @@
     results = []
     def subthread():
         x = FooBar()
+        results.append(t.get() is None)
         t.set(x)
         time.sleep(0.2)
         results.append(t.get() is x)
+        ThreadLocalReference.flush_all_in_this_thread()
+        results.append(t.get() is None)
     for i in range(5):
         thread.start_new_thread(subthread, ())
     time.sleep(0.5)
-    assert results == [True] * 5
+    assert results == [True] * 15
diff --git a/rpython/rtyper/lltypesystem/lloperation.py b/rpython/rtyper/lltypesystem/lloperation.py
--- a/rpython/rtyper/lltypesystem/lloperation.py
+++ b/rpython/rtyper/lltypesystem/lloperation.py
@@ -429,6 +429,9 @@
     'stm_start_transaction':  LLOp(canrun=True, canmallocgc=True),
     'stm_stop_transaction':   LLOp(canrun=True, canmallocgc=True),
     #'stm_jit_invoke_code':    LLOp(canmallocgc=True),
+    'stm_threadlocalref_get': LLOp(sideeffects=False),
+    'stm_threadlocalref_set': LLOp(),
+    'stm_threadlocalref_flush': LLOp(),
 
     # __________ address operations __________
 
diff --git a/rpython/translator/stm/src_stm/atomic_ops.h b/rpython/translator/stm/src_stm/atomic_ops.h
--- a/rpython/translator/stm/src_stm/atomic_ops.h
+++ b/rpython/translator/stm/src_stm/atomic_ops.h
@@ -1,6 +1,8 @@
 #ifndef _SRCSTM_ATOMIC_OPS_
 #define _SRCSTM_ATOMIC_OPS_
 
+#include <assert.h>
+
 
 /* "compiler fence" for preventing reordering of loads/stores to
    non-volatiles */
diff --git a/rpython/translator/stm/test/test_ztranslated.py b/rpython/translator/stm/test/test_ztranslated.py
--- a/rpython/translator/stm/test/test_ztranslated.py
+++ b/rpython/translator/stm/test/test_ztranslated.py
@@ -104,3 +104,20 @@
         t, cbuilder = self.compile(main)
         data = cbuilder.cmdexec('')
         assert '42\n' in data, "got: %r" % (data,)
+
+    def test_threadlocalref(self):
+        class FooBar(object):
+            pass
+        t = rstm.ThreadLocalReference(FooBar)
+        def main(argv):
+            x = FooBar()
+            assert t.get() is None
+            t.set(x)
+            assert t.get() is x
+            rstm.ThreadLocalReference.flush_all_in_this_thread()
+            assert t.get() is None
+            print "ok"
+            return 0
+        t, cbuilder = self.compile(main)
+        data = cbuilder.cmdexec('')
+        assert 'ok\n' in data
diff --git a/rpython/translator/stm/threadlocalref.py b/rpython/translator/stm/threadlocalref.py
new file mode 100644
--- /dev/null
+++ b/rpython/translator/stm/threadlocalref.py
@@ -0,0 +1,61 @@
+from rpython.rtyper.lltypesystem import lltype, llmemory
+from rpython.translator.unsimplify import varoftype
+from rpython.flowspace.model import SpaceOperation, Constant
+
+#
+# Note: all this slightly messy code exists to support 'stm_threadlocalref_flush',
+# which zeroes *all* thread-local variables accessed with
+# stm_threadlocalref_{get,set}.
+#
+
+def transform_tlref(graphs):
+    ids = set()
+    #
+    for graph in graphs:
+        for block in graph.iterblocks():
+            for i in range(len(block.operations)):
+                op = block.operations[i]
+                if (op.opname == 'stm_threadlocalref_set' or
+                    op.opname == 'stm_threadlocalref_get'):
+                    ids.add(op.args[0].value)
+    if len(ids) == 0:
+        return
+    #
+    ids = sorted(ids)
+    fields = [('ptr%d' % id1, llmemory.Address) for id1 in ids]
+    kwds = {'hints': {'stm_thread_local': True}}
+    S = lltype.Struct('THREADLOCALREF', *fields, **kwds)
+    ll_threadlocalref = lltype.malloc(S, immortal=True)
+    c_threadlocalref = Constant(ll_threadlocalref, lltype.Ptr(S))
+    c_fieldnames = {}
+    for id1 in ids:
+        fieldname = 'ptr%d' % id1
+        c_fieldnames[id1] = Constant(fieldname, lltype.Void)
+    c_null = Constant(llmemory.NULL, llmemory.Address)
+    #
+    for graph in graphs:
+        for block in graph.iterblocks():
+            for i in range(len(block.operations)-1, -1, -1):
+                op = block.operations[i]
+                if op.opname == 'stm_threadlocalref_set':
+                    id1 = op.args[0].value
+                    op = SpaceOperation('setfield', [c_threadlocalref,
+                                                     c_fieldnames[id1],
+                                                     op.args[1]],
+                                        op.result)
+                    block.operations[i] = op
+                elif op.opname == 'stm_threadlocalref_get':
+                    id1 = op.args[0].value
+                    op = SpaceOperation('getfield', [c_threadlocalref,
+                                                     c_fieldnames[id1]],
+                                        op.result)
+                    block.operations[i] = op
+                elif op.opname == 'stm_threadlocalref_flush':
+                    extra = []
+                    for id1 in ids:
+                        op = SpaceOperation('setfield', [c_threadlocalref,
+                                                         c_fieldnames[id1],
+                                                         c_null],
+                                            varoftype(lltype.Void))
+                        extra.append(op)
+                    block.operations[i:i+1] = extra
diff --git a/rpython/translator/stm/transform2.py b/rpython/translator/stm/transform2.py
--- a/rpython/translator/stm/transform2.py
+++ b/rpython/translator/stm/transform2.py
@@ -7,6 +7,7 @@
     def transform(self):
         assert not hasattr(self.translator, 'stm_transformation_applied')
         self.start_log()
+        self.transform_threadlocalref()
         self.transform_jit_driver()
         self.transform_write_barrier()
         self.transform_turn_inevitable()
@@ -34,6 +35,10 @@
         for graph in self.translator.graphs:
             reorganize_around_jit_driver(self, graph)
 
+    def transform_threadlocalref(self):
+        from rpython.translator.stm.threadlocalref import transform_tlref
+        transform_tlref(self.translator.graphs)
+
     def start_log(self):
         from rpython.translator.c.support import log
         log.info("Software Transactional Memory transformation")
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to