Author: Armin Rigo <[email protected]>
Branch: stm-thread-2
Changeset: r61407:7000756077c5
Date: 2013-02-18 14:07 +0100
http://bitbucket.org/pypy/pypy/changeset/7000756077c5/
Log: ThreadLocalReference: implementation.
diff --git a/rpython/rlib/rstm.py b/rpython/rlib/rstm.py
--- a/rpython/rlib/rstm.py
+++ b/rpython/rlib/rstm.py
@@ -1,10 +1,9 @@
-import thread
+import thread, weakref
from rpython.translator.stm import stmgcintf
from rpython.rlib.debug import ll_assert, fatalerror
from rpython.rlib.objectmodel import keepalive_until_here, specialize
from rpython.rlib.objectmodel import we_are_translated
from rpython.rlib.rposix import get_errno, set_errno
-from rpython.rlib.rarithmetic import intmask
from rpython.rtyper.lltypesystem import lltype, llmemory, rffi, rclass
from rpython.rtyper.lltypesystem.lloperation import llop
from rpython.rtyper.annlowlevel import cast_instance_to_base_ptr, llhelper
@@ -108,12 +107,16 @@
# ____________________________________________________________
class ThreadLocalReference(object):
+ _ALL = weakref.WeakKeyDictionary()
+ _COUNT = 0
def __init__(self, Cls):
"NOT_RPYTHON: must be prebuilt"
self.Cls = Cls
- self.unique_id = intmask(id(self))
self.local = thread._local()
+ # unique_id now comes from a sequential class-wide counter instead of
+ # id(self) (presumably to get small, reproducible ids); transform_tlref
+ # builds the struct field name 'ptr<id>' from it.
+ self.unique_id = ThreadLocalReference._COUNT
+ ThreadLocalReference._COUNT += 1
+ # Registered in a WeakKeyDictionary so that flush_all_in_this_thread()
+ # can reach every live instance when running untranslated.
+ ThreadLocalReference._ALL[self] = True
def _freeze_(self):
+ # NOTE(review): presumably tells the RPython annotator to treat instances
+ # as frozen prebuilt constants, matching "must be prebuilt" in __init__.
return True
@@ -121,7 +124,7 @@
@specialize.arg(0)
def get(self):
if we_are_translated():
- ptr = llop.stm_localref_get(llmemory.Address, self.unique_id)
+ ptr = llop.stm_threadlocalref_get(llmemory.Address, self.unique_id)
ptr = rffi.cast(rclass.OBJECTPTR, ptr)
return cast_base_ptr_to_instance(self.Cls, ptr)
else:
@@ -133,6 +136,14 @@
if we_are_translated():
ptr = cast_instance_to_base_ptr(value)
ptr = rffi.cast(llmemory.Address, ptr)
- llop.stm_localref_set(lltype.Void, self.unique_id, ptr)
+ llop.stm_threadlocalref_set(lltype.Void, self.unique_id, ptr)
else:
self.local.value = value
+
+    @staticmethod
+    def flush_all_in_this_thread():
+        """Reset every ThreadLocalReference to None for the calling thread.
+
+        Translated, this emits the stm_threadlocalref_flush operation, which
+        transform_tlref expands into one NULL store per reference id.
+        Untranslated, it clears the thread._local slot of every live
+        instance registered in _ALL.
+        """
+        if we_are_translated():
+            llop.stm_threadlocalref_flush(lltype.Void)
+            return
+        live_refs = ThreadLocalReference._ALL.keys()
+        for ref in live_refs:
+            ref.local.value = None
diff --git a/rpython/rlib/test/test_rstm.py b/rpython/rlib/test/test_rstm.py
--- a/rpython/rlib/test/test_rstm.py
+++ b/rpython/rlib/test/test_rstm.py
@@ -8,10 +8,13 @@
results = []
def subthread():
x = FooBar()
+ results.append(t.get() is None)
t.set(x)
time.sleep(0.2)
results.append(t.get() is x)
+ ThreadLocalReference.flush_all_in_this_thread()
+ results.append(t.get() is None)
for i in range(5):
thread.start_new_thread(subthread, ())
time.sleep(0.5)
- assert results == [True] * 5
+ assert results == [True] * 15
diff --git a/rpython/rtyper/lltypesystem/lloperation.py b/rpython/rtyper/lltypesystem/lloperation.py
--- a/rpython/rtyper/lltypesystem/lloperation.py
+++ b/rpython/rtyper/lltypesystem/lloperation.py
@@ -429,6 +429,9 @@
'stm_start_transaction': LLOp(canrun=True, canmallocgc=True),
'stm_stop_transaction': LLOp(canrun=True, canmallocgc=True),
#'stm_jit_invoke_code': LLOp(canmallocgc=True),
+ 'stm_threadlocalref_get': LLOp(sideeffects=False),
+ 'stm_threadlocalref_set': LLOp(),
+ 'stm_threadlocalref_flush': LLOp(),
# __________ address operations __________
diff --git a/rpython/translator/stm/src_stm/atomic_ops.h b/rpython/translator/stm/src_stm/atomic_ops.h
--- a/rpython/translator/stm/src_stm/atomic_ops.h
+++ b/rpython/translator/stm/src_stm/atomic_ops.h
@@ -1,6 +1,8 @@
#ifndef _SRCSTM_ATOMIC_OPS_
#define _SRCSTM_ATOMIC_OPS_
+#include <assert.h>
+
/* "compiler fence" for preventing reordering of loads/stores to
non-volatiles */
diff --git a/rpython/translator/stm/test/test_ztranslated.py b/rpython/translator/stm/test/test_ztranslated.py
--- a/rpython/translator/stm/test/test_ztranslated.py
+++ b/rpython/translator/stm/test/test_ztranslated.py
@@ -104,3 +104,20 @@
t, cbuilder = self.compile(main)
data = cbuilder.cmdexec('')
assert '42\n' in data, "got: %r" % (data,)
+
+    def test_threadlocalref(self):
+        # get/set/flush of a ThreadLocalReference in a translated program.
+        # The reference is named 'tlref', not 't': main() reads it through
+        # its closure, and 't' is rebound below by self.compile(main).
+        class FooBar(object):
+            pass
+        tlref = rstm.ThreadLocalReference(FooBar)
+        def main(argv):
+            x = FooBar()
+            assert tlref.get() is None
+            tlref.set(x)
+            assert tlref.get() is x
+            rstm.ThreadLocalReference.flush_all_in_this_thread()
+            assert tlref.get() is None
+            # the reference must still be usable after a flush
+            tlref.set(x)
+            assert tlref.get() is x
+            print "ok"
+            return 0
+        t, cbuilder = self.compile(main)
+        data = cbuilder.cmdexec('')
+        assert 'ok\n' in data
diff --git a/rpython/translator/stm/threadlocalref.py b/rpython/translator/stm/threadlocalref.py
new file mode 100644
--- /dev/null
+++ b/rpython/translator/stm/threadlocalref.py
@@ -0,0 +1,83 @@
+from rpython.rtyper.lltypesystem import lltype, llmemory
+from rpython.translator.unsimplify import varoftype
+from rpython.flowspace.model import SpaceOperation, Constant
+
+#
+# Note: all this slightly messy code is to have 'stm_threadlocalref_flush'
+# which zeroes *all* thread-local variables accessed with
+# stm_threadlocalref_{get,set}.
+#
+
+def transform_tlref(graphs):
+    """Rewrite the stm_threadlocalref_* operations found in 'graphs'.
+
+    Every ThreadLocalReference id used by a get or set becomes an
+    Address field 'ptr<id>' of one immortal struct THREADLOCALREF,
+    created with the hint 'stm_thread_local' (presumably honoured by
+    the C backend to give each thread its own copy -- not visible here).
+
+      stm_threadlocalref_get(id)      -> getfield(struct, 'ptr<id>')
+      stm_threadlocalref_set(id, ptr) -> setfield(struct, 'ptr<id>', ptr)
+      stm_threadlocalref_flush()      -> setfield(struct, f, NULL) for
+                                         every field f (no operation at
+                                         all if no id is used anywhere)
+
+    The graphs are modified in place; nothing is returned.
+    """
+    # First pass: collect the ids used by get/set, and note whether any
+    # flush occurs.  A flush must be rewritten even if there are no ids,
+    # otherwise it would survive untransformed.
+    ids = set()
+    has_flush = False
+    for graph in graphs:
+        for block in graph.iterblocks():
+            for op in block.operations:
+                if op.opname in ('stm_threadlocalref_set',
+                                 'stm_threadlocalref_get'):
+                    ids.add(op.args[0].value)
+                elif op.opname == 'stm_threadlocalref_flush':
+                    has_flush = True
+    if not ids and not has_flush:
+        return
+    #
+    # Sorting makes the field order of the struct deterministic.
+    ids = sorted(ids)
+    c_threadlocalref = None
+    c_fieldnames = {}
+    if ids:
+        fields = [('ptr%d' % id1, llmemory.Address) for id1 in ids]
+        kwds = {'hints': {'stm_thread_local': True}}
+        S = lltype.Struct('THREADLOCALREF', *fields, **kwds)
+        ll_threadlocalref = lltype.malloc(S, immortal=True)
+        c_threadlocalref = Constant(ll_threadlocalref, lltype.Ptr(S))
+        for id1 in ids:
+            c_fieldnames[id1] = Constant('ptr%d' % id1, lltype.Void)
+    c_null = Constant(llmemory.NULL, llmemory.Address)
+    #
+    # Second pass: rewrite.  Each block is walked backwards so that
+    # replacing one flush by several setfields (or by none) does not
+    # shift the indices that remain to be visited.
+    for graph in graphs:
+        for block in graph.iterblocks():
+            for i in range(len(block.operations) - 1, -1, -1):
+                op = block.operations[i]
+                if op.opname == 'stm_threadlocalref_set':
+                    id1 = op.args[0].value
+                    block.operations[i] = SpaceOperation(
+                        'setfield',
+                        [c_threadlocalref, c_fieldnames[id1], op.args[1]],
+                        op.result)
+                elif op.opname == 'stm_threadlocalref_get':
+                    id1 = op.args[0].value
+                    block.operations[i] = SpaceOperation(
+                        'getfield',
+                        [c_threadlocalref, c_fieldnames[id1]],
+                        op.result)
+                elif op.opname == 'stm_threadlocalref_flush':
+                    block.operations[i:i+1] = [
+                        SpaceOperation('setfield',
+                                       [c_threadlocalref,
+                                        c_fieldnames[id1],
+                                        c_null],
+                                       varoftype(lltype.Void))
+                        for id1 in ids]
diff --git a/rpython/translator/stm/transform2.py b/rpython/translator/stm/transform2.py
--- a/rpython/translator/stm/transform2.py
+++ b/rpython/translator/stm/transform2.py
@@ -7,6 +7,7 @@
def transform(self):
assert not hasattr(self.translator, 'stm_transformation_applied')
self.start_log()
+ self.transform_threadlocalref()
self.transform_jit_driver()
self.transform_write_barrier()
self.transform_turn_inevitable()
@@ -34,6 +35,10 @@
for graph in self.translator.graphs:
reorganize_around_jit_driver(self, graph)
+    def transform_threadlocalref(self):
+        """Apply the thread-local-reference rewrite from
+        rpython.translator.stm.threadlocalref to every graph of the
+        translator."""
+        from rpython.translator.stm import threadlocalref
+        graphs = self.translator.graphs
+        threadlocalref.transform_tlref(graphs)
+
def start_log(self):
from rpython.translator.c.support import log
log.info("Software Transactional Memory transformation")
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit