[pypy-commit] pypy stm-gc: Some progress, fixing test_rstm.

arigo noreply at buildbot.pypy.org
Mon Apr 16 15:51:33 CEST 2012


Author: Armin Rigo <arigo at tunes.org>
Branch: stm-gc
Changeset: r54419:53a337586265
Date: 2012-04-16 15:43 +0200
http://bitbucket.org/pypy/pypy/changeset/53a337586265/

Log:	Some progress, fixing test_rstm.

diff --git a/pypy/module/transaction/interp_transaction.py b/pypy/module/transaction/interp_transaction.py
--- a/pypy/module/transaction/interp_transaction.py
+++ b/pypy/module/transaction/interp_transaction.py
@@ -220,8 +220,8 @@
     """The main function running one of the threads."""
     # Note that we cannot allocate any object here outside a transaction,
     # so we need to be very careful.
+    rstm.descriptor_init()
     state.lock()
-    rstm.descriptor_init()
     #
     rstm.perform_transaction(_setup_thread, AbstractPending, None)
     my_transactions_pending = state.getvalue()._transaction_pending
@@ -260,10 +260,10 @@
             state.lock()
             _add_list(my_transactions_pending)
     #
-    rstm.descriptor_done()
     if state.num_waiting_threads == 0:    # only the last thread to leave
         state.unlock_unfinished()
     state.unlock()
+    rstm.descriptor_done()
 
 
 @rgc.no_collect
diff --git a/pypy/rlib/test/test_rstm.py b/pypy/rlib/test/test_rstm.py
--- a/pypy/rlib/test/test_rstm.py
+++ b/pypy/rlib/test/test_rstm.py
@@ -1,12 +1,14 @@
 import os, thread, time
-from pypy.rlib.debug import debug_print, ll_assert
+from pypy.rlib.debug import debug_print, ll_assert, fatalerror
 from pypy.rlib import rstm
+from pypy.rpython.annlowlevel import llhelper
 from pypy.translator.stm.test.support import CompiledSTMTests
+from pypy.module.thread import ll_thread
 
 
 class Arg(object):
     pass
-arg = Arg()
+arg_list = [Arg() for i in range(10)]
 
 def setx(arg, retry_counter):
     debug_print(arg.x)
@@ -17,22 +19,29 @@
         assert rstm._debug_get_state() == 2
     arg.x = 42
 
-def stm_perform_transaction(initial_x=202):
-    arg.x = initial_x
+def stm_perform_transaction(done=None, i=0):
     ll_assert(rstm._debug_get_state() == -2, "bad debug_get_state (1)")
     rstm.descriptor_init()
+    arg = arg_list[i]
+    if done is None:
+        arg.x = 202
+    else:
+        arg.x = done.initial_x
     ll_assert(rstm._debug_get_state() == 0, "bad debug_get_state (2)")
     rstm.perform_transaction(setx, Arg, arg)
     ll_assert(rstm._debug_get_state() == 0, "bad debug_get_state (3)")
+    ll_assert(arg.x == 42, "bad arg.x")
+    if done is not None:
+        ll_thread.release_NOAUTO(done.finished_lock)
     rstm.descriptor_done()
     ll_assert(rstm._debug_get_state() == -2, "bad debug_get_state (4)")
-    ll_assert(arg.x == 42, "bad arg.x")
 
 def test_stm_multiple_threads():
     ok = []
     def f(i):
-        stm_perform_transaction()
+        stm_perform_transaction(i=i)
         ok.append(i)
+    rstm.enter_transactional_mode()
     for i in range(10):
         thread.start_new_thread(f, (i,))
     timeout = 10
@@ -40,6 +49,7 @@
         time.sleep(0.1)
         timeout -= 0.1
         assert timeout >= 0.0, "timeout!"
+    rstm.leave_transactional_mode()
     assert sorted(ok) == range(10)
 
 
@@ -58,22 +68,23 @@
         assert '102' in dataerr.splitlines()
 
     def build_perform_transaction(self):
-        from pypy.module.thread import ll_thread
         class Done: done = False
         done = Done()
         def g():
-            stm_perform_transaction(done.initial_x)
-            done.done = True
+            stm_perform_transaction(done)
         def f(argv):
             done.initial_x = int(argv[1])
             assert rstm._debug_get_state() == -1    # main thread
-            ll_thread.start_new_thread(g, ())
-            for i in range(20):
-                if done.done: break
-                time.sleep(0.1)
-            else:
-                print "timeout!"
-                raise Exception
+            done.finished_lock = ll_thread.allocate_ll_lock()
+            ll_thread.acquire_NOAUTO(done.finished_lock, True)
+            #
+            rstm.enter_transactional_mode()
+            #
+            llcallback = llhelper(ll_thread.CALLBACK, g)
+            ident = ll_thread.c_thread_start_NOGIL(llcallback)
+            ll_thread.acquire_NOAUTO(done.finished_lock, True)
+            #
+            rstm.leave_transactional_mode()
             return 0
         t, cbuilder = self.compile(f)
         return cbuilder
diff --git a/pypy/rpython/lltypesystem/lloperation.py b/pypy/rpython/lltypesystem/lloperation.py
--- a/pypy/rpython/lltypesystem/lloperation.py
+++ b/pypy/rpython/lltypesystem/lloperation.py
@@ -403,8 +403,8 @@
     'stm_become_inevitable':  LLOp(),
     'stm_descriptor_init':    LLOp(canrun=True),
     'stm_descriptor_done':    LLOp(canrun=True),
-    'stm_enter_transactional_mode': LLOp(),
-    'stm_leave_transactional_mode': LLOp(),
+    'stm_enter_transactional_mode': LLOp(canrun=True),
+    'stm_leave_transactional_mode': LLOp(canrun=True),
     'stm_writebarrier':       LLOp(sideeffects=False),
     'stm_normalize_global':   LLOp(),
     'stm_start_transaction':  LLOp(canrun=True),
diff --git a/pypy/rpython/lltypesystem/opimpl.py b/pypy/rpython/lltypesystem/opimpl.py
--- a/pypy/rpython/lltypesystem/opimpl.py
+++ b/pypy/rpython/lltypesystem/opimpl.py
@@ -630,6 +630,12 @@
 def op_stm_commit_transaction():
     pass
 
+def op_stm_enter_transactional_mode():
+    pass
+
+def op_stm_leave_transactional_mode():
+    pass
+
 # ____________________________________________________________
 
 def get_op_impl(opname):
diff --git a/pypy/rpython/memory/gc/stmgc.py b/pypy/rpython/memory/gc/stmgc.py
--- a/pypy/rpython/memory/gc/stmgc.py
+++ b/pypy/rpython/memory/gc/stmgc.py
@@ -240,8 +240,7 @@
         self.get_tls().start_transaction()
 
     def commit_transaction(self):
-        raise NotImplementedError
-        self.collector.commit_transaction()
+        self.get_tls().stop_transaction()
 
 
     @always_inline
@@ -317,6 +316,11 @@
         #
         @dont_inline
         def _stm_write_barrier_global(obj):
+            tls = self.get_tls()
+            if not tls.in_transaction():
+                return obj # not in transaction: only when running the code
+                           # in _run_thread(), i.e. in sub-threads outside
+                           # transactions.  xxx statically detect this case?
             # we need to find or make a local copy
             hdr = self.header(obj)
             if hdr.tid & GCFLAG_WAS_COPIED == 0:
@@ -339,7 +343,6 @@
             # Here, we need to really make a local copy
             size = self.get_size(obj)
             totalsize = self.gcheaderbuilder.size_gc_header + size
-            tls = self.get_tls()
             try:
                 localobj = tls.malloc_local_copy(totalsize)
             except MemoryError:
@@ -453,230 +456,3 @@
 
     def identityhash(self, gcobj):
         return self.id_or_identityhash(gcobj, True)
-
-# ------------------------------------------------------------
-
-
-class Collector(object):
-    """A separate frozen class.  Useful to prevent any buggy concurrent
-    access to GC data.  The methods here use the GCTLS instead for
-    storing things in a thread-local way."""
-
-    def __init__(self, gc):
-        self.gc = gc
-        self.stm_operations = gc.stm_operations
-
-    def _freeze_(self):
-        return True
-
-    def is_in_nursery(self, tls, addr):
-        ll_assert(llmemory.cast_adr_to_int(addr) & 1 == 0,
-                  "odd-valued (i.e. tagged) pointer unexpected here")
-        return tls.nursery_start <= addr < tls.nursery_top
-
-    def header(self, obj):
-        return self.gc.header(obj)
-
-
-    def start_transaction(self):
-        """Start a transaction, by clearing and resetting the tls nursery."""
-        tls = self.get_tls()
-        self.gc.reset_nursery(tls)
-
-
-    def commit_transaction(self):
-        """End of a transaction, just before its end.  No more GC
-        operations should occur afterwards!  Note that the C code that
-        does the commit runs afterwards, and may still abort."""
-        #
-        debug_start("gc-collect-commit")
-        #
-        tls = self.get_tls()
-        #
-        # Do a mark-and-move minor collection out of the tls' nursery
-        # into the main thread's global area (which is right now also
-        # called a nursery).
-        debug_print("local arena:", tls.nursery_free - tls.nursery_start,
-                    "bytes")
-        #
-        # We are starting from the tldict's local objects as roots.  At
-        # this point, these objects have GCFLAG_WAS_COPIED, and the other
-        # local objects don't.  We want to move all reachable local objects
-        # to the global area.
-        #
-        # Start from tracing the root objects
-        self.collect_roots_from_tldict(tls)
-        #
-        # Continue iteratively until we have reached all the reachable
-        # local objects
-        self.collect_from_pending_list(tls)
-        #
-        # Fix up the weakrefs that used to point to local objects
-        self.fixup_weakrefs(tls)
-        #
-        # Now, all indirectly reachable local objects have been copied into
-        # the global area, and all pointers have been fixed to point to the
-        # global copies, including in the local copy of the roots.  What
-        # remains is only overwriting of the global copy of the roots.
-        # This is done by the C code.
-        debug_stop("gc-collect-commit")
-
-
-    def collect_roots_from_tldict(self, tls):
-        tls.pending_list = NULL
-        tls.surviving_weakrefs = NULL
-        # Enumerate the roots, which are the local copies of global objects.
-        # For each root, trace it.
-        CALLBACK = self.stm_operations.CALLBACK_ENUM
-        callback = llhelper(CALLBACK, self._enum_entries)
-        # xxx hack hack hack!  Stores 'self' in a global place... but it's
-        # pointless after translation because 'self' is a Void.
-        _global_collector.collector = self
-        self.stm_operations.tldict_enum(callback)
-
-
-    @staticmethod
-    def _enum_entries(tls_addr, globalobj, localobj):
-        self = _global_collector.collector
-        tls = llmemory.cast_adr_to_ptr(tls_addr, lltype.Ptr(StmGC.GCTLS))
-        #
-        localhdr = self.header(localobj)
-        ll_assert(localhdr.version == globalobj,
-                  "in a root: localobj.version != globalobj")
-        ll_assert(localhdr.tid & GCFLAG_GLOBAL == 0,
-                  "in a root: unexpected GCFLAG_GLOBAL")
-        ll_assert(localhdr.tid & GCFLAG_WAS_COPIED != 0,
-                  "in a root: missing GCFLAG_WAS_COPIED")
-        #
-        self.trace_and_drag_out_of_nursery(tls, localobj)
-
-
-    def collect_from_pending_list(self, tls):
-        while tls.pending_list != NULL:
-            pending_obj = tls.pending_list
-            pending_hdr = self.header(pending_obj)
-            #
-            # 'pending_list' is a chained list of fresh global objects,
-            # linked together via their 'version' field.  The 'version'
-            # must be replaced with NULL after we pop the object from
-            # the linked list.
-            tls.pending_list = pending_hdr.version
-            pending_hdr.version = NULL
-            #
-            # Check the flags of pending_obj: it should be a fresh global
-            # object, without GCFLAG_WAS_COPIED
-            ll_assert(pending_hdr.tid & GCFLAG_GLOBAL != 0,
-                      "from pending list: missing GCFLAG_GLOBAL")
-            ll_assert(pending_hdr.tid & GCFLAG_WAS_COPIED == 0,
-                      "from pending list: unexpected GCFLAG_WAS_COPIED")
-            #
-            self.trace_and_drag_out_of_nursery(tls, pending_obj)
-
-
-    def trace_and_drag_out_of_nursery(self, tls, obj):
-        # This is called to fix the references inside 'obj', to ensure that
-        # they are global.  If necessary, the referenced objects are copied
-        # into the global area first.  This is called on the *local* copy of
-        # the roots, and on the fresh *global* copy of all other reached
-        # objects.
-        self.gc.trace(obj, self._trace_drag_out, tls)
-
-    def _trace_drag_out(self, root, tls):
-        obj = root.address[0]
-        hdr = self.header(obj)
-        #
-        # Figure out if the object is GLOBAL or not by looking at its
-        # address, not at its header --- to avoid cache misses and
-        # pollution for all global objects
-        if not self.is_in_nursery(tls, obj):
-            ll_assert(hdr.tid & GCFLAG_GLOBAL != 0,
-                      "trace_and_mark: non-GLOBAL obj is not in nursery")
-            return        # ignore global objects
-        #
-        ll_assert(hdr.tid & GCFLAG_GLOBAL == 0,
-                  "trace_and_mark: GLOBAL obj in nursery")
-        #
-        if hdr.tid & (GCFLAG_WAS_COPIED | GCFLAG_HAS_SHADOW) == 0:
-            # First visit to a local-only 'obj': allocate a corresponding
-            # global object
-            size = self.gc.get_size(obj)
-            globalobj = self.gc._malloc_global_raw(tls, size)
-            need_to_copy = True
-            #
-        else:
-            globalobj = hdr.version
-            if hdr.tid & GCFLAG_WAS_COPIED != 0:
-                # this local object is a root or was already marked.  Either
-                # way, its 'version' field should point to the corresponding
-                # global object. 
-                size = 0
-                need_to_copy = False
-            else:
-                # this local object has a shadow made by id_or_identityhash();
-                # and the 'version' field points to the global shadow.
-                ll_assert(hdr.tid & GCFLAG_HAS_SHADOW != 0, "uh?")
-                size = self.gc.get_size(obj)
-                need_to_copy = True
-        #
-        if need_to_copy:
-            # Copy the data of the object from the local to the global
-            llmemory.raw_memcopy(obj, globalobj, size)
-            #
-            # Initialize the header of the 'globalobj'
-            globalhdr = self.header(globalobj)
-            globalhdr.tid = hdr.tid | GCFLAG_GLOBAL
-            #
-            # Add the flags to 'localobj' to say 'has been copied now'
-            hdr.tid |= GCFLAG_WAS_COPIED
-            hdr.version = globalobj
-            #
-            # Set a temporary linked list through the globalobj's version
-            # numbers.  This is normally not allowed, but it works here
-            # because these new globalobjs are not visible to any other
-            # thread before the commit is really complete.
-            globalhdr.version = tls.pending_list
-            tls.pending_list = globalobj
-            #
-            if hdr.tid & GCFLAG_WEAKREF != 0:
-                # this was a weakref object that survives.
-                self.young_weakref_survives(tls, obj)
-        #
-        # Fix the original root.address[0] to point to the globalobj
-        root.address[0] = globalobj
-
-
-    @dont_inline
-    def young_weakref_survives(self, tls, obj):
-        # Relink it in the tls.surviving_weakrefs chained list,
-        # via the weakpointer_offset in the local copy of the object.
-        # Do it only if the weakref points to a local object.
-        offset = self.gc.weakpointer_offset(self.gc.get_type_id(obj))
-        if self.is_in_nursery(tls, (obj + offset).address[0]):
-            (obj + offset).address[0] = tls.surviving_weakrefs
-            tls.surviving_weakrefs = obj
-
-    def fixup_weakrefs(self, tls):
-        obj = tls.surviving_weakrefs
-        while obj:
-            offset = self.gc.weakpointer_offset(self.gc.get_type_id(obj))
-            #
-            hdr = self.header(obj)
-            ll_assert(hdr.tid & GCFLAG_GLOBAL == 0,
-                      "weakref: unexpectedly global")
-            globalobj = hdr.version
-            obj2 = (globalobj + offset).address[0]
-            hdr2 = self.header(obj2)
-            ll_assert(hdr2.tid & GCFLAG_GLOBAL == 0,
-                      "weakref: points to a global")
-            if hdr2.tid & GCFLAG_WAS_COPIED:
-                obj2g = hdr2.version    # obj2 survives, going there
-            else:
-                obj2g = llmemory.NULL   # obj2 dies
-            (globalobj + offset).address[0] = obj2g
-            #
-            obj = (obj + offset).address[0]
-
-
-class _GlobalCollector(object):
-    pass
-_global_collector = _GlobalCollector()
diff --git a/pypy/rpython/memory/gc/stmtls.py b/pypy/rpython/memory/gc/stmtls.py
--- a/pypy/rpython/memory/gc/stmtls.py
+++ b/pypy/rpython/memory/gc/stmtls.py
@@ -1,13 +1,15 @@
 from pypy.rpython.lltypesystem import lltype, llmemory, llarena, rffi
-from pypy.rpython.annlowlevel import cast_instance_to_base_ptr
+from pypy.rpython.annlowlevel import cast_instance_to_base_ptr, llhelper
 from pypy.rpython.annlowlevel import cast_base_ptr_to_instance, base_ptr_lltype
 from pypy.rlib.objectmodel import we_are_translated, free_non_gc_object
+from pypy.rlib.objectmodel import specialize
 from pypy.rlib.rarithmetic import r_uint
 from pypy.rlib.debug import ll_assert, debug_start, debug_stop, fatalerror
 
 from pypy.rpython.memory.gc.stmgc import WORD, NULL
 from pypy.rpython.memory.gc.stmgc import always_inline, dont_inline
 from pypy.rpython.memory.gc.stmgc import GCFLAG_GLOBAL, GCFLAG_VISITED
+from pypy.rpython.memory.gc.stmgc import GCFLAG_WAS_COPIED
 
 
 class StmGCTLS(object):
@@ -101,11 +103,8 @@
     def enter_transactional_mode(self):
         """Called on the main thread, just before spawning the other
         threads."""
-        self.local_collection()
-        if not self.local_nursery_is_empty():
-            self.local_collection(run_finalizers=False)
-        self._promote_locals_to_globals()
-        self._disable_mallocs()
+        self.stop_transaction()
+        self.stm_operations.enter_transactional_mode()
 
     def leave_transactional_mode(self):
         """Restart using the main thread for mallocs."""
@@ -113,10 +112,11 @@
             for key, value in StmGCTLS.nontranslated_dict.items():
                 if value is not self:
                     del StmGCTLS.nontranslated_dict[key]
+        self.stm_operations.leave_transactional_mode()
         self.start_transaction()
 
     def start_transaction(self):
-        """Enter a thread: performs any pending cleanups, and set
+        """Start a transaction: performs any pending cleanups, and set
         up a fresh state for allocating.  Called at the start of
         each transaction, and at the start of the main thread."""
         # Note that the calls to enter() and
@@ -134,11 +134,24 @@
         self.nursery_free = self.nursery_start
         self.nursery_top  = self.nursery_start + self.nursery_size
 
+    def stop_transaction(self):
+        """Stop a transaction: do a local collection to empty the
+        nursery and track which objects are still alive now, and
+        then mark all these objects as global."""
+        self.local_collection()
+        if not self.local_nursery_is_empty():
+            self.local_collection(run_finalizers=False)
+        self._promote_locals_to_globals()
+        self._disable_mallocs()
+
     def local_nursery_is_empty(self):
         ll_assert(bool(self.nursery_free),
                   "local_nursery_is_empty: gc not running")
         return self.nursery_free == self.nursery_start
 
+    def in_transaction(self):
+        return bool(self.nursery_free)
+
     # ------------------------------------------------------------
 
     def local_collection(self, run_finalizers=True):
@@ -192,18 +205,6 @@
         #
         debug_stop("gc-local")
 
-    def end_of_transaction_collection(self):
-        """Do an end-of-transaction collection.  Finds all surviving
-        non-GCFLAG_WAS_COPIED young objects and make them old.  Assumes
-        that there are no roots from the stack.  This guarantees that the
-        nursery will end up empty, apart from GCFLAG_WAS_COPIED objects.
-        To finish the commit, the C code will need to copy them over the
-        global objects (or abort in case of conflict, which is still ok).
-
-        No more mallocs are allowed after this is called.
-        """
-        raise NotImplementedError
-
     # ------------------------------------------------------------
 
     @always_inline
@@ -245,8 +246,6 @@
 
     def _promote_locals_to_globals(self):
         ll_assert(self.local_nursery_is_empty(), "nursery must be empty [1]")
-        ll_assert(not self.sharedarea_tls.special_stack.non_empty(),
-                  "special_stack should be empty here [1]")
         #
         # Promote all objects in sharedarea_tls to global
         obj = self.sharedarea_tls.chained_list
@@ -272,18 +271,31 @@
         self.gc.root_walker.walk_current_stack_roots(
             StmGCTLS._trace_drag_out1, self)
 
+    def trace_and_drag_out_of_nursery(self, obj):
+        # This is called to fix the references inside 'obj', to ensure that
+        # they are global.  If necessary, the referenced objects are copied
+        # into the global area first.  This is called on the LOCAL copy of
+        # the roots, and on the freshly OLD copy of all other reached LOCAL
+        # objects.
+        self.gc.trace(obj, self._trace_drag_out, None)
+
     def _trace_drag_out1(self, root):
         self._trace_drag_out(root, None)
 
     def _trace_drag_out(self, root, ignored):
         """Trace callback: 'root' is the address of some pointer.  If that
         pointer points to a YOUNG object, allocate an OLD copy of it and
-        fix the pointer.  Also, add the object to 'pending_list', if it was
-        not done so far.
+        fix the pointer.  Also, add the object to the 'pending' stack, if
+        it was not done so far.
         """
         obj = root.address[0]
         hdr = self.gc.header(obj)
         #
+        # If 'obj' is a LOCAL copy of a GLOBAL object, skip it
+        # (this case is handled differently in collect_roots_from_tldict)
+        if hdr.tid & GCFLAG_WAS_COPIED:
+            return
+        #
         # If 'obj' is not in the nursery, we set GCFLAG_VISITED
         if not self.is_in_nursery(obj):
             if hdr.tid & GCFLAG_VISITED == 0:
@@ -335,14 +347,32 @@
         self.sharedarea_tls.add_regular(obj)
 
     def collect_roots_from_tldict(self):
-        pass  # XXX
+        if not we_are_translated():
+            if not hasattr(self.stm_operations, 'tldict_enum'):
+                return
+        CALLBACK = self.stm_operations.CALLBACK_ENUM
+        callback = llhelper(CALLBACK, StmGCTLS._enum_entries)
+        self.stm_operations.tldict_enum(callback)
+
+    @staticmethod
+    def _enum_entries(tlsaddr, globalobj, localobj):
+        self = StmGCTLS.cast_address_to_tls_object(tlsaddr)
+        localhdr = self.gc.header(localobj)
+        ll_assert(localhdr.version == globalobj,
+                  "in a root: localobj.version != globalobj")
+        ll_assert(localhdr.tid & GCFLAG_GLOBAL == 0,
+                  "in a root: unexpected GCFLAG_GLOBAL")
+        ll_assert(localhdr.tid & GCFLAG_WAS_COPIED != 0,
+                  "in a root: missing GCFLAG_WAS_COPIED")
+        #
+        self.trace_and_drag_out_of_nursery(localobj)
 
     def collect_flush_pending(self):
         # Follow the objects in the 'pending' stack and move the
         # young objects they point to out of the nursery.
         while self.pending.non_empty():
             obj = self.pending.pop()
-            self.gc.trace(obj, self._trace_drag_out, None)
+            self.trace_and_drag_out_of_nursery(obj)
         self.pending.delete()
 
     def mass_free_old_local(self, previous_sharedarea_tls):
diff --git a/pypy/rpython/memory/gctransform/stmframework.py b/pypy/rpython/memory/gctransform/stmframework.py
--- a/pypy/rpython/memory/gctransform/stmframework.py
+++ b/pypy/rpython/memory/gctransform/stmframework.py
@@ -57,6 +57,14 @@
     def gct_stm_descriptor_done(self, hop):
         hop.genop("direct_call", [self.teardown_thread_ptr, self.c_const_gc])
 
+    def gct_stm_enter_transactional_mode(self, hop):
+        hop.genop("direct_call", [self.stm_enter_transactional_mode_ptr,
+                                  self.c_const_gc])
+
+    def gct_stm_leave_transactional_mode(self, hop):
+        hop.genop("direct_call", [self.stm_leave_transactional_mode_ptr,
+                                  self.c_const_gc])
+
     def gct_stm_writebarrier(self, hop):
         op = hop.spaceop
         v_adr = hop.genop('cast_ptr_to_adr',
diff --git a/pypy/translator/stm/src_stm/et.c b/pypy/translator/stm/src_stm/et.c
--- a/pypy/translator/stm/src_stm/et.c
+++ b/pypy/translator/stm/src_stm/et.c
@@ -143,26 +143,18 @@
 {
   owner_version_t newver = d->end_time;
   wlog_t *item;
-  /* loop in "forward" order: in this order, if there are duplicate orecs
-     then only the last one has p != -1. */
   REDOLOG_LOOP_FORWARD(d->redolog, item)
     {
       void *globalobj = item->addr;
       void *localobj = item->val;
-      owner_version_t p = item->p;
       long size = rpython_get_size(localobj);
       memcpy(((char *)globalobj) + sizeof(orec_t),
              ((char *)localobj) + sizeof(orec_t),
              size - sizeof(orec_t));
-      /* but we must only unlock the orec if it's the last time it
-         appears in the redolog list.  If it's not, then p == -1.
-         XXX I think that duplicate orecs are not possible any more. */
-      if (p != -1)
-        {
-          volatile orec_t* o = get_orec(globalobj);
-          CFENCE;
-          o->version = newver;
-        }
+      /* unlock the orec */
+      volatile orec_t* o = get_orec(globalobj);
+      CFENCE;
+      o->version = newver;
     } REDOLOG_LOOP_END;
 }
 
@@ -186,12 +178,10 @@
   wlog_t *item;
   REDOLOG_LOOP_FORWARD(d->redolog, item)
     {
-      if (item->p != -1)
-        {
-          volatile orec_t* o = get_orec(item->addr);
-          o->version = item->p;
-          item->p = -1;
-        }
+      volatile orec_t* o = get_orec(item->addr);
+      assert(item->p != -1);
+      o->version = item->p;
+      item->p = -1;
     } REDOLOG_LOOP_END;
 }
 
@@ -218,14 +208,14 @@
         if (!bool_cas(&o->version, ovt, d->my_lock_word))
           goto retry;
         // save old version to item->p.  Now we hold the lock.
-        // in case of duplicate orecs, only the last one has p != -1.
         item->p = ovt;
       }
       // else if the location is too recent...
       else if (!IS_LOCKED(ovt))
         tx_abort(0);
-      // else it is locked: if we don't hold the lock...
-      else if (ovt != d->my_lock_word) {
+      // else it is locked: check it's not by me
+      else {
+        assert(ovt != d->my_lock_word);
         // we can either abort or spinloop.  Because we are at the end of
         // the transaction we might try to spinloop, even though after the
         // lock is released the ovt will be very recent, possibly
@@ -534,7 +524,10 @@
       /*d->spinloop_counter = (unsigned int)(d->my_lock_word | 1);*/
 
       thread_descriptor = d;
-      /* active_thread_descriptor stays NULL */
+      if (in_main_thread)
+        stm_leave_transactional_mode();
+      else
+        ;   /* active_thread_descriptor stays NULL */
 
 #ifdef RPY_STM_DEBUG_PRINT
       if (PYPY_HAVE_DEBUG_PRINTS) fprintf(PYPY_DEBUG_FILE, "thread %lx starting\n",
@@ -681,6 +674,33 @@
   return result;
 }
 
+void stm_enter_transactional_mode(void)
+{
+  struct tx_descriptor *d = active_thread_descriptor;
+  assert(d != NULL);
+  assert(is_inevitable(d));
+  /* we only need a subset of a full commit */
+  acquireLocks(d);
+  commitInevitableTransaction(d);
+  common_cleanup(d);
+  active_thread_descriptor = NULL;
+}
+
+void stm_leave_transactional_mode(void)
+{
+  struct tx_descriptor *d = thread_descriptor;
+  assert(active_thread_descriptor == NULL);
+
+  mutex_lock();
+  d->setjmp_buf = NULL;
+  d->start_time = get_global_timestamp(d);
+  assert(!(d->start_time & 1));
+  set_global_timestamp(d, d->start_time | 1);
+
+  assert(is_inevitable(d));
+  active_thread_descriptor = d;
+}
+
 void stm_try_inevitable(STM_CCHARP1(why))
 {
   /* when a transaction is inevitable, its start_time is equal to
@@ -745,13 +765,12 @@
     return -2;
   if (active_thread_descriptor == NULL)
     {
-      if (d->my_lock_word == 0)
-        return -1;
-      else
-        return 0;
+      assert(d->my_lock_word != 0);
+      return 0;
     }
   assert(d == active_thread_descriptor);
-  assert(d->my_lock_word != 0);
+  if (d->my_lock_word == 0)
+    return -1;
   if (!is_inevitable(d))
     return 1;
   else
diff --git a/pypy/translator/stm/src_stm/et.h b/pypy/translator/stm/src_stm/et.h
--- a/pypy/translator/stm/src_stm/et.h
+++ b/pypy/translator/stm/src_stm/et.h
@@ -40,6 +40,8 @@
 
 
 void* stm_perform_transaction(void*(*)(void*, long), void*);
+void stm_enter_transactional_mode(void);
+void stm_leave_transactional_mode(void);
 void stm_try_inevitable(STM_CCHARP1(why));
 void stm_abort_and_retry(void);
 long stm_debug_get_state(void);  /* -1: descriptor_init() was not called
diff --git a/pypy/translator/stm/stmgcintf.py b/pypy/translator/stm/stmgcintf.py
--- a/pypy/translator/stm/stmgcintf.py
+++ b/pypy/translator/stm/stmgcintf.py
@@ -56,6 +56,11 @@
     get_tls = smexternal('stm_get_tls', [], llmemory.Address)
     del_tls = smexternal('stm_del_tls', [], lltype.Void)
 
+    enter_transactional_mode = smexternal('stm_enter_transactional_mode',
+                                          [], lltype.Void)
+    leave_transactional_mode = smexternal('stm_leave_transactional_mode',
+                                          [], lltype.Void)
+
     tldict_lookup = smexternal('stm_tldict_lookup', [llmemory.Address],
                                llmemory.Address)
     tldict_add = smexternal('stm_tldict_add', [llmemory.Address] * 2,


More information about the pypy-commit mailing list