[pypy-commit] pypy stmgc-c8: hg merge stmgc-c8-dictiter

arigo noreply at buildbot.pypy.org
Thu Nov 12 01:47:44 EST 2015


Author: Armin Rigo <arigo at tunes.org>
Branch: stmgc-c8
Changeset: r80642:8c3e06db1827
Date: 2015-11-12 07:48 +0100
http://bitbucket.org/pypy/pypy/changeset/8c3e06db1827/

Log:	hg merge stmgc-c8-dictiter

	Iterators over stm dictionaries.

diff --git a/pypy/module/pypystm/hashtable.py b/pypy/module/pypystm/hashtable.py
--- a/pypy/module/pypystm/hashtable.py
+++ b/pypy/module/pypystm/hashtable.py
@@ -2,6 +2,7 @@
 The class pypystm.hashtable, mapping integers to objects.
 """
 
+from pypy.interpreter.error import OperationError
 from pypy.interpreter.baseobjspace import W_Root
 from pypy.interpreter.typedef import TypeDef
 from pypy.interpreter.gateway import interp2app, unwrap_spec, WrappedDefault
@@ -78,6 +79,55 @@
                  for i in range(count)]
         return space.newlist(lst_w)
 
+    def iterkeys_w(self, space):
+        return W_HashtableIterKeys(self.h)
+
+    def itervalues_w(self, space):
+        return W_HashtableIterValues(self.h)
+
+    def iteritems_w(self, space):
+        return W_HashtableIterItems(self.h)
+
+
+class W_BaseHashtableIter(W_Root):
+    _immutable_fields_ = ["hiter"]
+
+    def __init__(self, hobj):
+        self.hiter = hobj.iterentries()
+
+    def descr_iter(self, space):
+        return self
+
+    def descr_length_hint(self, space):
+        # xxx overestimate: doesn't remove the items already yielded,
+        # and uses the faster len_estimate()
+        return space.wrap(self.hiter.hashtable.len_estimate())
+
+    def descr_next(self, space):
+        try:
+            entry = self.hiter.next()
+        except StopIteration:
+            raise OperationError(space.w_StopIteration, space.w_None)
+        return self.get_final_value(space, entry)
+
+    def _cleanup_(self):
+        raise Exception("seeing a prebuilt %r object" % (
+            self.__class__,))
+
+class W_HashtableIterKeys(W_BaseHashtableIter):
+    def get_final_value(self, space, entry):
+        return space.wrap(intmask(entry.index))
+
+class W_HashtableIterValues(W_BaseHashtableIter):
+    def get_final_value(self, space, entry):
+        return cast_gcref_to_instance(W_Root, entry.object)
+
+class W_HashtableIterItems(W_BaseHashtableIter):
+    def get_final_value(self, space, entry):
+        return space.newtuple([
+            space.wrap(intmask(entry.index)),
+            cast_gcref_to_instance(W_Root, entry.object)])
+
 
 def W_Hashtable___new__(space, w_subtype):
     r = space.allocate_instance(W_Hashtable, w_subtype)
@@ -98,4 +148,16 @@
     keys    = interp2app(W_Hashtable.keys_w),
     values  = interp2app(W_Hashtable.values_w),
     items   = interp2app(W_Hashtable.items_w),
+
+    __iter__   = interp2app(W_Hashtable.iterkeys_w),
+    iterkeys   = interp2app(W_Hashtable.iterkeys_w),
+    itervalues = interp2app(W_Hashtable.itervalues_w),
+    iteritems  = interp2app(W_Hashtable.iteritems_w),
 )
+
+W_BaseHashtableIter.typedef = TypeDef(
+    "hashtable_iter",
+    __iter__ = interp2app(W_BaseHashtableIter.descr_iter),
+    next = interp2app(W_BaseHashtableIter.descr_next),
+    __length_hint__ = interp2app(W_BaseHashtableIter.descr_length_hint),
+    )
diff --git a/pypy/module/pypystm/stmdict.py b/pypy/module/pypystm/stmdict.py
--- a/pypy/module/pypystm/stmdict.py
+++ b/pypy/module/pypystm/stmdict.py
@@ -2,6 +2,7 @@
 The class pypystm.stmdict, giving a part of the regular 'dict' interface
 """
 
+from pypy.interpreter.error import OperationError
 from pypy.interpreter.baseobjspace import W_Root
 from pypy.interpreter.typedef import TypeDef
 from pypy.interpreter.gateway import interp2app, unwrap_spec, WrappedDefault
@@ -215,10 +216,6 @@
     def len_w(self, space):
         return space.wrap(self.get_length())
 
-    def iter_w(self, space):
-        # not a real lazy iterator!
-        return space.iter(self.keys_w(space))
-
     def keys_w(self, space):
         return space.newlist(self.get_keys_values_w(offset=0))
 
@@ -228,6 +225,70 @@
     def items_w(self, space):
         return space.newlist(self.get_items_w(space))
 
+    def iterkeys_w(self, space):
+        return W_STMDictIterKeys(self.h)
+
+    def itervalues_w(self, space):
+        return W_STMDictIterValues(self.h)
+
+    def iteritems_w(self, space):
+        return W_STMDictIterItems(self.h)
+
+
+class W_BaseSTMDictIter(W_Root):
+    _immutable_fields_ = ["hiter"]
+    next_from_same_hash = 0
+
+    def __init__(self, hobj):
+        self.hiter = hobj.iterentries()
+
+    def descr_iter(self, space):
+        return self
+
+    def descr_length_hint(self, space):
+        # xxx estimate: doesn't remove the items already yielded,
+        # and uses the faster len_estimate(); on the other hand,
+        # counts only one for every 64-bit hash value
+        return space.wrap(self.hiter.hashtable.len_estimate())
+
+    def descr_next(self, space):
+        if self.next_from_same_hash == 0:      # common case
+            try:
+                entry = self.hiter.next()
+            except StopIteration:
+                raise OperationError(space.w_StopIteration, space.w_None)
+            index = 0
+            array = lltype.cast_opaque_ptr(PARRAY, entry.object)
+        else:
+            index = self.next_from_same_hash
+            array = self.next_array
+            self.next_from_same_hash = 0
+            self.next_array = lltype.nullptr(ARRAY)
+        #
+        if len(array) > index + 2:      # uncommon case
+            self.next_from_same_hash = index + 2
+            self.next_array = array
+        #
+        return self.get_final_value(space, array, index)
+
+    def _cleanup_(self):
+        raise Exception("seeing a prebuilt %r object" % (
+            self.__class__,))
+
+class W_STMDictIterKeys(W_BaseSTMDictIter):
+    def get_final_value(self, space, array, index):
+        return cast_gcref_to_instance(W_Root, array[index])
+
+class W_STMDictIterValues(W_BaseSTMDictIter):
+    def get_final_value(self, space, array, index):
+        return cast_gcref_to_instance(W_Root, array[index + 1])
+
+class W_STMDictIterItems(W_BaseSTMDictIter):
+    def get_final_value(self, space, array, index):
+        return space.newtuple([
+            cast_gcref_to_instance(W_Root, array[index]),
+            cast_gcref_to_instance(W_Root, array[index + 1])])
+
 
 def W_STMDict___new__(space, w_subtype):
     r = space.allocate_instance(W_STMDict, w_subtype)
@@ -246,8 +307,19 @@
     setdefault = interp2app(W_STMDict.setdefault_w),
 
     __len__  = interp2app(W_STMDict.len_w),
-    __iter__ = interp2app(W_STMDict.iter_w),
     keys     = interp2app(W_STMDict.keys_w),
     values   = interp2app(W_STMDict.values_w),
     items    = interp2app(W_STMDict.items_w),
+
+    __iter__   = interp2app(W_STMDict.iterkeys_w),
+    iterkeys   = interp2app(W_STMDict.iterkeys_w),
+    itervalues = interp2app(W_STMDict.itervalues_w),
+    iteritems  = interp2app(W_STMDict.iteritems_w),
     )
+
+W_BaseSTMDictIter.typedef = TypeDef(
+    "stmdict_iter",
+    __iter__ = interp2app(W_BaseSTMDictIter.descr_iter),
+    next = interp2app(W_BaseSTMDictIter.descr_next),
+    __length_hint__ = interp2app(W_BaseSTMDictIter.descr_length_hint),
+    )
diff --git a/pypy/module/pypystm/stmset.py b/pypy/module/pypystm/stmset.py
--- a/pypy/module/pypystm/stmset.py
+++ b/pypy/module/pypystm/stmset.py
@@ -2,6 +2,7 @@
 The class pypystm.stmset, giving a part of the regular 'set' interface
 """
 
+from pypy.interpreter.error import OperationError
 from pypy.interpreter.baseobjspace import W_Root
 from pypy.interpreter.typedef import TypeDef
 from pypy.interpreter.gateway import interp2app
@@ -150,8 +151,48 @@
         return space.wrap(self.get_length())
 
     def iter_w(self, space):
-        # not a real lazy iterator!
-        return space.iter(space.newlist(self.get_items_w()))
+        return W_STMSetIter(self.h)
+
+
+class W_STMSetIter(W_Root):
+    _immutable_fields_ = ["hiter"]
+    next_from_same_hash = 0
+
+    def __init__(self, hobj):
+        self.hiter = hobj.iterentries()
+
+    def descr_iter(self, space):
+        return self
+
+    def descr_length_hint(self, space):
+        # xxx estimate: doesn't remove the items already yielded,
+        # and uses the faster len_estimate(); on the other hand,
+        # counts only one for every 64-bit hash value
+        return space.wrap(self.hiter.hashtable.len_estimate())
+
+    def descr_next(self, space):
+        if self.next_from_same_hash == 0:      # common case
+            try:
+                entry = self.hiter.next()
+            except StopIteration:
+                raise OperationError(space.w_StopIteration, space.w_None)
+            index = 0
+            array = lltype.cast_opaque_ptr(PARRAY, entry.object)
+        else:
+            index = self.next_from_same_hash
+            array = self.next_array
+            self.next_from_same_hash = 0
+            self.next_array = lltype.nullptr(ARRAY)
+        #
+        if len(array) > index + 1:      # uncommon case
+            self.next_from_same_hash = index + 1
+            self.next_array = array
+        #
+        return cast_gcref_to_instance(W_Root, array[index])
+
+    def _cleanup_(self):
+        raise Exception("seeing a prebuilt %r object" % (
+            self.__class__,))
 
 
 def W_STMSet___new__(space, w_subtype):
@@ -170,3 +211,10 @@
     __len__ = interp2app(W_STMSet.len_w),
     __iter__ = interp2app(W_STMSet.iter_w),
     )
+
+W_STMSetIter.typedef = TypeDef(
+    "stmset_iter",
+    __iter__ = interp2app(W_STMSetIter.descr_iter),
+    next = interp2app(W_STMSetIter.descr_next),
+    __length_hint__ = interp2app(W_STMSetIter.descr_length_hint),
+    )
diff --git a/pypy/module/pypystm/test/test_hashtable.py b/pypy/module/pypystm/test/test_hashtable.py
--- a/pypy/module/pypystm/test/test_hashtable.py
+++ b/pypy/module/pypystm/test/test_hashtable.py
@@ -55,3 +55,13 @@
         assert sorted(h.keys()) == [42, 43]
         assert sorted(h.values()) == ["bar", "foo"]
         assert sorted(h.items()) == [(42, "foo"), (43, "bar")]
+
+    def test_iterator(self):
+        import pypystm
+        h = pypystm.hashtable()
+        h[42] = "foo"
+        h[43] = "bar"
+        assert sorted(h) == [42, 43]
+        assert sorted(h.iterkeys()) == [42, 43]
+        assert sorted(h.itervalues()) == ["bar", "foo"]
+        assert sorted(h.iteritems()) == [(42, "foo"), (43, "bar")]
diff --git a/pypy/module/pypystm/test/test_stmdict.py b/pypy/module/pypystm/test/test_stmdict.py
--- a/pypy/module/pypystm/test/test_stmdict.py
+++ b/pypy/module/pypystm/test/test_stmdict.py
@@ -158,3 +158,24 @@
         assert a not in d
         assert b not in d
         assert d.keys() == []
+
+
+    def test_iterator(self):
+        import pypystm
+        class A(object):
+            def __hash__(self):
+                return 42
+        class B(object):
+            pass
+        d = pypystm.stmdict()
+        a1 = A()
+        a2 = A()
+        b0 = B()
+        d[a1] = "foo"
+        d[a2] = None
+        d[b0] = "bar"
+        assert sorted(d) == sorted([a1, a2, b0])
+        assert sorted(d.iterkeys()) == sorted([a1, a2, b0])
+        assert sorted(d.itervalues()) == [None, "bar", "foo"]
+        assert sorted(d.iteritems()) == sorted([(a1, "foo"), (a2, None),
+                                                (b0, "bar")])
diff --git a/pypy/module/pypystm/test/test_stmset.py b/pypy/module/pypystm/test/test_stmset.py
--- a/pypy/module/pypystm/test/test_stmset.py
+++ b/pypy/module/pypystm/test/test_stmset.py
@@ -83,3 +83,19 @@
         assert len(s) == 2
         items = list(s)
         assert items == [42.5, key3] or items == [key3, 42.5]
+
+    def test_iterator(self):
+        import pypystm
+        class A(object):
+            def __hash__(self):
+                return 42
+        class B(object):
+            pass
+        d = pypystm.stmset()
+        a1 = A()
+        a2 = A()
+        b0 = B()
+        d.add(a1)
+        d.add(a2)
+        d.add(b0)
+        assert sorted(d) == sorted([a1, a2, b0])
diff --git a/rpython/rlib/rstm.py b/rpython/rlib/rstm.py
--- a/rpython/rlib/rstm.py
+++ b/rpython/rlib/rstm.py
@@ -223,11 +223,13 @@
 # ____________________________________________________________
 
 _STM_HASHTABLE_P = rffi.COpaquePtr('stm_hashtable_t')
+_STM_HASHTABLE_TABLE_P = rffi.COpaquePtr('stm_hashtable_table_t')
 
 _STM_HASHTABLE_ENTRY = lltype.GcStruct('HASHTABLE_ENTRY',
                                        ('index', lltype.Unsigned),
                                        ('object', llmemory.GCREF))
 _STM_HASHTABLE_ENTRY_P = lltype.Ptr(_STM_HASHTABLE_ENTRY)
+_STM_HASHTABLE_ENTRY_PP = rffi.CArrayPtr(_STM_HASHTABLE_ENTRY_P)
 _STM_HASHTABLE_ENTRY_ARRAY = lltype.GcArray(_STM_HASHTABLE_ENTRY_P)
 
 @dont_look_inside
@@ -245,6 +247,11 @@
                                    lltype.nullptr(_STM_HASHTABLE_ENTRY_ARRAY))
 
 @dont_look_inside
+def _ll_hashtable_len_estimate(h):
+    return llop.stm_hashtable_length_upper_bound(lltype.Signed,
+                                                 h.ll_raw_hashtable)
+
+@dont_look_inside
 def _ll_hashtable_list(h):
     upper_bound = llop.stm_hashtable_length_upper_bound(lltype.Signed,
                                                         h.ll_raw_hashtable)
@@ -264,6 +271,28 @@
 def _ll_hashtable_writeobj(h, entry, value):
     llop.stm_hashtable_write_entry(lltype.Void, h, entry, value)
 
+@dont_look_inside
+def _ll_hashtable_iterentries(h):
+    rgc.register_custom_trace_hook(_HASHTABLE_ITER_OBJ,
+                                   lambda_hashtable_iter_trace)
+    table = llop.stm_hashtable_iter(_STM_HASHTABLE_TABLE_P, h.ll_raw_hashtable)
+    hiter = lltype.malloc(_HASHTABLE_ITER_OBJ)
+    hiter.hashtable = h    # for keepalive
+    hiter.table = table
+    hiter.prev = lltype.nullptr(_STM_HASHTABLE_ENTRY_PP.TO)
+    return hiter
+
+@dont_look_inside
+def _ll_hashiter_next(hiter):
+    entrypp = llop.stm_hashtable_iter_next(_STM_HASHTABLE_ENTRY_PP,
+                                           hiter.hashtable,
+                                           hiter.table,
+                                           hiter.prev)
+    if not entrypp:
+        raise StopIteration
+    hiter.prev = entrypp
+    return entrypp[0]
+
 _HASHTABLE_OBJ = lltype.GcStruct('HASHTABLE_OBJ',
                                  ('ll_raw_hashtable', _STM_HASHTABLE_P),
                                  hints={'immutable': True},
@@ -271,11 +300,19 @@
                                  adtmeths={'get': _ll_hashtable_get,
                                            'set': _ll_hashtable_set,
                                            'len': _ll_hashtable_len,
+                                  'len_estimate': _ll_hashtable_len_estimate,
                                           'list': _ll_hashtable_list,
                                         'lookup': _ll_hashtable_lookup,
-                                      'writeobj': _ll_hashtable_writeobj})
+                                      'writeobj': _ll_hashtable_writeobj,
+                                   'iterentries': _ll_hashtable_iterentries})
 NULL_HASHTABLE = lltype.nullptr(_HASHTABLE_OBJ)
 
+_HASHTABLE_ITER_OBJ = lltype.GcStruct('HASHTABLE_ITER_OBJ',
+                                      ('hashtable', lltype.Ptr(_HASHTABLE_OBJ)),
+                                      ('table', _STM_HASHTABLE_TABLE_P),
+                                      ('prev', _STM_HASHTABLE_ENTRY_PP),
+                                      adtmeths={'next': _ll_hashiter_next})
+
 def _ll_hashtable_trace(gc, obj, callback, arg):
     from rpython.memory.gctransform.stmframework import get_visit_function
     visit_fn = get_visit_function(callback, arg)
@@ -288,6 +325,15 @@
         llop.stm_hashtable_free(lltype.Void, h.ll_raw_hashtable)
 lambda_hashtable_finlz = lambda: _ll_hashtable_finalizer
 
+def _ll_hashtable_iter_trace(gc, obj, callback, arg):
+    from rpython.memory.gctransform.stmframework import get_visit_function
+    addr = obj + llmemory.offsetof(_HASHTABLE_ITER_OBJ, 'hashtable')
+    gc._trace_callback(callback, arg, addr)
+    visit_fn = get_visit_function(callback, arg)
+    addr = obj + llmemory.offsetof(_HASHTABLE_ITER_OBJ, 'table')
+    llop.stm_hashtable_iter_tracefn(lltype.Void, addr.address[0], visit_fn)
+lambda_hashtable_iter_trace = lambda: _ll_hashtable_iter_trace
+
 _false = CDefinedIntSymbolic('0', default=0)    # remains in the C code
 
 @dont_look_inside
@@ -344,6 +390,9 @@
         items = [self.lookup(key) for key, v in self._content.items() if v.object != NULL_GCREF]
         return len(items)
 
+    def len_estimate(self):
+        return len(self._content)
+
     def list(self):
         items = [self.lookup(key) for key, v in self._content.items() if v.object != NULL_GCREF]
         count = len(items)
@@ -359,6 +408,9 @@
         assert isinstance(entry, EntryObjectForTest)
         self.set(entry.key, nvalue)
 
+    def iterentries(self):
+        return IterEntriesForTest(self, self._content.itervalues())
+
 class EntryObjectForTest(object):
     def __init__(self, hashtable, key):
         self.hashtable = hashtable
@@ -374,6 +426,14 @@
 
     object = property(_getobj, _setobj)
 
+class IterEntriesForTest(object):
+    def __init__(self, hashtable, iterator):
+        self.hashtable = hashtable
+        self.iterator = iterator
+
+    def next(self):
+        return next(self.iterator)
+
 # ____________________________________________________________
 
 _STM_QUEUE_P = rffi.COpaquePtr('stm_queue_t')
diff --git a/rpython/rtyper/lltypesystem/lloperation.py b/rpython/rtyper/lltypesystem/lloperation.py
--- a/rpython/rtyper/lltypesystem/lloperation.py
+++ b/rpython/rtyper/lltypesystem/lloperation.py
@@ -486,6 +486,10 @@
     'stm_hashtable_list'  :   LLOp(),
     'stm_hashtable_tracefn':  LLOp(),
 
+    'stm_hashtable_iter':         LLOp(),
+    'stm_hashtable_iter_next':    LLOp(),
+    'stm_hashtable_iter_tracefn': LLOp(),
+
     'stm_queue_create':       LLOp(),
     'stm_queue_free':         LLOp(),
     'stm_queue_get':          LLOp(canmallocgc=True),   # push roots!
diff --git a/rpython/translator/stm/funcgen.py b/rpython/translator/stm/funcgen.py
--- a/rpython/translator/stm/funcgen.py
+++ b/rpython/translator/stm/funcgen.py
@@ -398,9 +398,28 @@
     arg0 = funcgen.expr(op.args[0])
     arg1 = funcgen.expr(op.args[1])
     arg2 = funcgen.expr(op.args[2])
-    return ('stm_hashtable_tracefn(%s, (stm_hashtable_t *)%s, '
+    return ('stm_hashtable_tracefn(%s, (stm_hashtable_t *)%s,'
             ' (void(*)(object_t**))%s);' % (arg0, arg1, arg2))
 
+def stm_hashtable_iter(funcgen, op):
+    arg0 = funcgen.expr(op.args[0])
+    result = funcgen.expr(op.result)
+    return '%s = stm_hashtable_iter(%s);' % (result, arg0)
+
+def stm_hashtable_iter_next(funcgen, op):
+    arg0 = funcgen.expr(op.args[0])
+    arg1 = funcgen.expr(op.args[1])
+    arg2 = funcgen.expr(op.args[2])
+    result = funcgen.expr(op.result)
+    return ('%s = stm_hashtable_iter_next(%s, %s, %s);' %
+            (result, arg0, arg1, arg2))
+
+def stm_hashtable_iter_tracefn(funcgen, op):
+    arg0 = funcgen.expr(op.args[0])
+    arg1 = funcgen.expr(op.args[1])
+    return ('stm_hashtable_iter_tracefn((stm_hashtable_table_t *)%s,'
+            ' (void(*)(object_t**))%s);' % (arg0, arg1))
+
 def stm_queue_create(funcgen, op):
     result = funcgen.expr(op.result)
     return '%s = stm_queue_create();' % (result,)
diff --git a/rpython/translator/stm/src_stm/revision b/rpython/translator/stm/src_stm/revision
--- a/rpython/translator/stm/src_stm/revision
+++ b/rpython/translator/stm/src_stm/revision
@@ -1,1 +1,1 @@
-41227d7659ac
+72facb6e4533
diff --git a/rpython/translator/stm/src_stm/stm/core.c b/rpython/translator/stm/src_stm/stm/core.c
--- a/rpython/translator/stm/src_stm/stm/core.c
+++ b/rpython/translator/stm/src_stm/stm/core.c
@@ -1374,6 +1374,8 @@
            from its segment.  Better do it as soon as possible, because
            other threads might be spin-looping, waiting for the -1 to
            disappear. */
+        /* but first, emit commit-event of this thread: */
+        timing_event(STM_SEGMENT->running_thread, STM_TRANSACTION_COMMIT);
         STM_SEGMENT->running_thread = NULL;
         write_fence();
         assert(_stm_detached_inevitable_from_thread == -1);
diff --git a/rpython/translator/stm/src_stm/stm/detach.c b/rpython/translator/stm/src_stm/stm/detach.c
--- a/rpython/translator/stm/src_stm/stm/detach.c
+++ b/rpython/translator/stm/src_stm/stm/detach.c
@@ -127,6 +127,7 @@
         // XXX: not sure if the next line is a good idea
         tl->last_associated_segment_num = remote_seg_num;
         ensure_gs_register(remote_seg_num);
+        timing_event(STM_SEGMENT->running_thread, STM_TRANSACTION_REATTACH);
         commit_external_inevitable_transaction();
     }
     dprintf(("reattach_transaction: start a new transaction\n"));
@@ -185,6 +186,7 @@
     assert(segnum > 0);
 
     ensure_gs_register(segnum);
+    timing_event(STM_SEGMENT->running_thread, STM_TRANSACTION_REATTACH);
     commit_external_inevitable_transaction();
     ensure_gs_register(mysegnum);
 }
diff --git a/rpython/translator/stm/src_stm/stm/finalizer.c b/rpython/translator/stm/src_stm/stm/finalizer.c
--- a/rpython/translator/stm/src_stm/stm/finalizer.c
+++ b/rpython/translator/stm/src_stm/stm/finalizer.c
@@ -501,7 +501,17 @@
     /* XXX: become inevitable, bc. otherwise, we would need to keep
        around the original g_finalizers.run_finalizers to restore it
        in case of an abort. */
-    _stm_become_inevitable("finalizer-Tx");
+    _stm_become_inevitable(MSG_INEV_DONT_SLEEP);
+    /* did it work? */
+    if (STM_PSEGMENT->transaction_state != TS_INEVITABLE) {   /* no */
+        /* avoid blocking here, waiting for another INEV transaction.
+           If we did that, application code could not proceed (start the
+           next transaction) and it will not be obvious from the profile
+           why we were WAITing. */
+        _stm_commit_transaction();
+        stm_rewind_jmp_leaveframe(tl, &rjbuf);
+        return;
+    }
 
     while (__sync_lock_test_and_set(&g_finalizers.lock, 1) != 0) {
         /* somebody is adding more finalizers (_commit_finalizer()) */
diff --git a/rpython/translator/stm/src_stm/stm/gcpage.c b/rpython/translator/stm/src_stm/stm/gcpage.c
--- a/rpython/translator/stm/src_stm/stm/gcpage.c
+++ b/rpython/translator/stm/src_stm/stm/gcpage.c
@@ -224,6 +224,9 @@
    version and thus don't need tracing. */
 static struct list_s *marked_objects_to_trace;
 
+/* a list of hobj/hashtable pairs for all hashtables seen */
+static struct list_s *all_hashtables_seen = NULL;
+
 /* we use the sharing seg0's pages for the GCFLAG_VISITED flag */
 
 static inline struct object_s *mark_loc(object_t *obj)
@@ -301,8 +304,6 @@
 }
 
 
-#define TRACE_FOR_MAJOR_COLLECTION  (&mark_record_trace)
-
 static void mark_and_trace(
     object_t *obj,
     char *segment_base, /* to trace obj in */
@@ -791,6 +792,7 @@
 
     /* marking */
     LIST_CREATE(marked_objects_to_trace);
+    LIST_CREATE(all_hashtables_seen);
     mark_visit_from_modified_objects();
     mark_visit_from_markers();
     mark_visit_from_roots();
@@ -815,6 +817,10 @@
     sweep_large_objects();
     sweep_small_objects();
 
+    /* hashtables */
+    stm_compact_hashtables();
+    LIST_FREE(all_hashtables_seen);
+
     dprintf((" | used after collection:  %ld\n",
              (long)pages_ctl.total_allocated));
     dprintf((" `----------------------------------------------\n"));
diff --git a/rpython/translator/stm/src_stm/stm/hashtable.c b/rpython/translator/stm/src_stm/stm/hashtable.c
--- a/rpython/translator/stm/src_stm/stm/hashtable.c
+++ b/rpython/translator/stm/src_stm/stm/hashtable.c
@@ -49,8 +49,12 @@
 #define PERTURB_SHIFT            5
 #define RESIZING_LOCK            0
 
-typedef struct {
-    uintptr_t mask;
+#define TRACE_FLAG_OFF              0
+#define TRACE_FLAG_ONCE             1
+#define TRACE_FLAG_KEEPALIVE        2
+
+struct stm_hashtable_table_s {
+    uintptr_t mask;      /* 'mask' is always immutable. */
 
     /* 'resize_counter' start at an odd value, and is decremented (by
        6) for every new item put in 'items'.  When it crosses 0, we
@@ -63,8 +67,10 @@
     */
     uintptr_t resize_counter;
 
+    uint8_t trace_flag;
+
     stm_hashtable_entry_t *items[INITIAL_HASHTABLE_SIZE];
-} stm_hashtable_table_t;
+};
 
 #define IS_EVEN(p) (((p) & 1) == 0)
 
@@ -79,6 +85,7 @@
 {
     table->mask = itemcount - 1;
     table->resize_counter = itemcount * 4 + 1;
+    table->trace_flag = TRACE_FLAG_OFF;
     memset(table->items, 0, itemcount * sizeof(stm_hashtable_entry_t *));
 }
 
@@ -162,6 +169,7 @@
     assert(biggertable);   // XXX
 
     stm_hashtable_table_t *table = hashtable->table;
+    table->trace_flag = TRACE_FLAG_ONCE;
     table->resize_counter = (uintptr_t)biggertable;
     /* ^^^ this unlocks the table by writing a non-zero value to
        table->resize_counter, but the new value is a pointer to the
@@ -485,6 +493,41 @@
 static void _stm_compact_hashtable(struct object_s *hobj,
                                    stm_hashtable_t *hashtable)
 {
+    /* Walk the chained list that starts at 'hashtable->initial_table'
+       and follows the 'resize_counter' fields.  Remove all tables
+       except (1) the initial one, (2) the most recent one, and (3)
+       the ones on which stm_hashtable_iter_tracefn() was called.
+    */
+    stm_hashtable_table_t *most_recent_table = hashtable->table;
+    assert(!IS_EVEN(most_recent_table->resize_counter));
+    /* set the "don't free me" flag on the most recent table */
+    most_recent_table->trace_flag = TRACE_FLAG_KEEPALIVE;
+
+    stm_hashtable_table_t *known_alive = &hashtable->initial_table;
+    known_alive->trace_flag = TRACE_FLAG_OFF;
+    /* a KEEPALIVE flag is ignored on the initial table: it is never
+       individually freed anyway */
+
+    while (known_alive != most_recent_table) {
+        uintptr_t rc = known_alive->resize_counter;
+        assert(IS_EVEN(rc));
+        assert(rc != RESIZING_LOCK);
+
+        stm_hashtable_table_t *next_table = (stm_hashtable_table_t *)rc;
+        if (next_table->trace_flag != TRACE_FLAG_KEEPALIVE) {
+            /* free this next table and relink the chained list to skip it */
+            assert(IS_EVEN(next_table->resize_counter));
+            known_alive->resize_counter = next_table->resize_counter;
+            free(next_table);
+        }
+        else {
+            /* this next table is kept alive */
+            known_alive = next_table;
+            known_alive->trace_flag = TRACE_FLAG_OFF;
+        }
+    }
+    /* done the first part */
+
     stm_hashtable_table_t *table = hashtable->table;
     uintptr_t rc = table->resize_counter;
     assert(!IS_EVEN(rc));
@@ -515,35 +558,24 @@
         dprintf(("compact with %ld items:\n", num_entries_times_6 / 6));
         _stm_rehash_hashtable(hashtable, count, segnum);
     }
+}
 
-    table = hashtable->table;
-    assert(!IS_EVEN(table->resize_counter));
-
-    if (table != &hashtable->initial_table) {
-        uintptr_t rc = hashtable->initial_table.resize_counter;
-        while (1) {
-            assert(IS_EVEN(rc));
-            assert(rc != RESIZING_LOCK);
-
-            stm_hashtable_table_t *old_table = (stm_hashtable_table_t *)rc;
-            if (old_table == table)
-                break;
-            rc = old_table->resize_counter;
-            free(old_table);
-        }
-        hashtable->initial_table.resize_counter = (uintptr_t)table;
-        assert(IS_EVEN(hashtable->initial_table.resize_counter));
+static void stm_compact_hashtables(void)
+{
+    uintptr_t i = all_hashtables_seen->count;
+    while (i > 0) {
+        i -= 2;
+        _stm_compact_hashtable(
+            (struct object_s *)all_hashtables_seen->items[i],
+            (stm_hashtable_t *)all_hashtables_seen->items[i + 1]);
     }
 }
 
-void stm_hashtable_tracefn(struct object_s *hobj, stm_hashtable_t *hashtable,
-                           void trace(object_t **))
+static void _hashtable_tracefn(stm_hashtable_table_t *table,
+                               void trace(object_t **))
 {
-    if (trace == TRACE_FOR_MAJOR_COLLECTION)
-        _stm_compact_hashtable(hobj, hashtable);
-
-    stm_hashtable_table_t *table;
-    table = VOLATILE_HASHTABLE(hashtable)->table;
+    if (table->trace_flag == TRACE_FLAG_ONCE)
+        table->trace_flag = TRACE_FLAG_OFF;
 
     uintptr_t j, mask = table->mask;
     for (j = 0; j <= mask; j++) {
@@ -554,3 +586,105 @@
         }
     }
 }
+
+void stm_hashtable_tracefn(struct object_s *hobj, stm_hashtable_t *hashtable,
+                           void trace(object_t **))
+{
+    if (all_hashtables_seen != NULL)
+        all_hashtables_seen = list_append2(all_hashtables_seen,
+                                           (uintptr_t)hobj,
+                                           (uintptr_t)hashtable);
+
+    _hashtable_tracefn(VOLATILE_HASHTABLE(hashtable)->table, trace);
+}
+
+
+/* Hashtable iterators */
+
+/* TRACE_FLAG_ONCE: the table must be traced once if it supports an iterator
+   TRACE_FLAG_OFF: the table is the most recent table, or has already been
+       traced once
+   TRACE_FLAG_KEEPALIVE: during major collection only: mark tables that
+       must be kept alive because there are iterators
+*/
+
+struct stm_hashtable_table_s *stm_hashtable_iter(stm_hashtable_t *hashtable)
+{
+    /* Get the table.  No synchronization is needed: we may miss some
+       entries that are being added, but they would contain NULL in
+       this segment anyway. */
+    return VOLATILE_HASHTABLE(hashtable)->table;
+}
+
+stm_hashtable_entry_t **
+stm_hashtable_iter_next(object_t *hobj, stm_hashtable_table_t *table,
+                        stm_hashtable_entry_t **previous)
+{
+    /* Set the read marker on hobj for every item, in case we have
+       transaction breaks in-between.
+    */
+    stm_read(hobj);
+
+    /* Get the bounds of the part of the 'stm_hashtable_entry_t *' array
+       that we have to check */
+    stm_hashtable_entry_t **pp, **last;
+    if (previous == NULL)
+        pp = table->items;
+    else
+        pp = previous + 1;
+    last = table->items + table->mask;
+
+    /* Find the first non-null entry */
+    stm_hashtable_entry_t *entry;
+
+    while (pp <= last) {
+        entry = *(stm_hashtable_entry_t *volatile *)pp;
+        if (entry != NULL) {
+            stm_read((object_t *)entry);
+            if (entry->object != NULL) {
+                //fprintf(stderr, "stm_hashtable_iter_next(%p, %p, %p) = %p\n",
+                //        hobj, table, previous, pp);
+                return pp;
+            }
+        }
+        ++pp;
+    }
+    //fprintf(stderr, "stm_hashtable_iter_next(%p, %p, %p) = %p\n",
+    //        hobj, table, previous, NULL);
+    return NULL;
+}
+
+void stm_hashtable_iter_tracefn(stm_hashtable_table_t *table,
+                                void trace(object_t **))
+{
+    if (all_hashtables_seen == NULL) {   /* for minor collections */
+
+        /* During minor collection, tracing the table is only required
+           the first time: if it contains young objects, they must be
+           kept alive and have their address updated.  We use
+           TRACE_FLAG_ONCE to know that.  We don't need to do it if
+           our 'table' is the latest version, because in that case it
+           will be done by stm_hashtable_tracefn().  That's why
+           TRACE_FLAG_ONCE is only set when a more recent table is
+           attached.
+
+           It is only needed once: non-latest-version tables are
+           immutable.  We mark once all the entries as old, and
+           then these now-old objects stay alive until the next
+           major collection.
+
+           Checking the flag can be done without synchronization: it
+           is never wrong to call _hashtable_tracefn() too much, and the
+           only case where it *has to* be called occurs if the
+           hashtable object is still young (and not seen by other
+           threads).
+        */
+        if (table->trace_flag == TRACE_FLAG_ONCE)
+            _hashtable_tracefn(table, trace);
+    }
+    else {       /* for major collections */
+
+        /* Set this flag for _stm_compact_hashtable() */
+        table->trace_flag = TRACE_FLAG_KEEPALIVE;
+    }
+}
diff --git a/rpython/translator/stm/src_stm/stm/hashtable.h b/rpython/translator/stm/src_stm/stm/hashtable.h
new file mode 100644
--- /dev/null
+++ b/rpython/translator/stm/src_stm/stm/hashtable.h
@@ -0,0 +1,2 @@
+/* Imported by rpython/translator/stm/import_stmgc.py */
+static void stm_compact_hashtables(void);
diff --git a/rpython/translator/stm/src_stm/stmgc.c b/rpython/translator/stm/src_stm/stmgc.c
--- a/rpython/translator/stm/src_stm/stmgc.c
+++ b/rpython/translator/stm/src_stm/stmgc.c
@@ -20,6 +20,7 @@
 #include "stm/finalizer.h"
 #include "stm/locks.h"
 #include "stm/detach.h"
+#include "stm/hashtable.h"
 #include "stm/queue.h"
 #include "stm/misc.c"
 #include "stm/list.c"
diff --git a/rpython/translator/stm/src_stm/stmgc.h b/rpython/translator/stm/src_stm/stmgc.h
--- a/rpython/translator/stm/src_stm/stmgc.h
+++ b/rpython/translator/stm/src_stm/stmgc.h
@@ -100,6 +100,8 @@
 #define _stm_detach_inevitable_transaction(tl)  do {                    \
     write_fence();                                                      \
     assert(_stm_detached_inevitable_from_thread == 0);                  \
+    if (stmcb_timing_event != NULL && tl->self_or_0_if_atomic != 0) /* no DETACH event if atomic */ \
+        {stmcb_timing_event(tl, STM_TRANSACTION_DETACH, NULL);}         \
     _stm_detached_inevitable_from_thread = tl->self_or_0_if_atomic;     \
 } while (0)
 void _stm_reattach_transaction(intptr_t);
@@ -416,69 +418,6 @@
 #endif
 
 
-/* Entering and leaving a "transactional code zone": a (typically very
-   large) section in the code where we are running a transaction.
-   This is the STM equivalent to "acquire the GIL" and "release the
-   GIL", respectively.  stm_read(), stm_write(), stm_allocate(), and
-   other functions should only be called from within a transaction.
-
-   Note that transactions, in the STM sense, cover _at least_ one
-   transactional code zone.  They may be longer; for example, if one
-   thread does a lot of stm_enter_transactional_zone() +
-   stm_become_inevitable() + stm_leave_transactional_zone(), as is
-   typical in a thread that does a lot of C function calls, then we
-   get only a few bigger inevitable transactions that cover the many
-   short transactional zones.  This is done by having
-   stm_leave_transactional_zone() turn the current transaction
-   inevitable and detach it from the running thread (if there is no
-   other inevitable transaction running so far).  Then
-   stm_enter_transactional_zone() will try to reattach to it.  This is
-   far more efficient than constantly starting and committing
-   transactions.
-
-   stm_enter_transactional_zone() and stm_leave_transactional_zone()
-   preserve the value of errno.
-*/
-#ifdef STM_DEBUGPRINT
-#include <stdio.h>
-#endif
-static inline void stm_enter_transactional_zone(stm_thread_local_t *tl) {
-    intptr_t self = tl->self_or_0_if_atomic;
-    if (__sync_bool_compare_and_swap(&_stm_detached_inevitable_from_thread,
-                                     self, 0)) {
-#ifdef STM_DEBUGPRINT
-        fprintf(stderr, "stm_enter_transactional_zone fast path\n");
-#endif
-    }
-    else {
-        _stm_reattach_transaction(self);
-        /* _stm_detached_inevitable_from_thread should be 0 here, but
-           it can already have been changed from a parallel thread
-           (assuming we're not inevitable ourselves) */
-    }
-}
-static inline void stm_leave_transactional_zone(stm_thread_local_t *tl) {
-    assert(STM_SEGMENT->running_thread == tl);
-    if (stm_is_inevitable(tl)) {
-#ifdef STM_DEBUGPRINT
-        fprintf(stderr, "stm_leave_transactional_zone fast path\n");
-#endif
-        _stm_detach_inevitable_transaction(tl);
-    }
-    else {
-        _stm_leave_noninevitable_transactional_zone();
-    }
-}
-
-/* stm_force_transaction_break() is in theory equivalent to
-   stm_leave_transactional_zone() immediately followed by
-   stm_enter_transactional_zone(); however, it is supposed to be
-   called in CPU-heavy threads that had a transaction run for a while,
-   and so it *always* forces a commit and starts the next transaction.
-   The new transaction is never inevitable.  See also
-   stm_should_break_transaction(). */
-void stm_force_transaction_break(stm_thread_local_t *tl);
-
 /* Abort the currently running transaction.  This function never
    returns: it jumps back to the start of the transaction (which must
    not be inevitable). */
@@ -596,6 +535,10 @@
     STM_TRANSACTION_COMMIT,
     STM_TRANSACTION_ABORT,
 
+    /* DETACH/REATTACH is used for leaving/reentering the transactional zone */
+    STM_TRANSACTION_DETACH,
+    STM_TRANSACTION_REATTACH,
+
     /* inevitable contention: all threads that try to become inevitable
        have a STM_BECOME_INEVITABLE event with a position marker.  Then,
        if it waits it gets a STM_WAIT_OTHER_INEVITABLE.  It is possible
@@ -688,6 +631,75 @@
 } while (0)
 
 
+
+/* Entering and leaving a "transactional code zone": a (typically very
+   large) section in the code where we are running a transaction.
+   This is the STM equivalent to "acquire the GIL" and "release the
+   GIL", respectively.  stm_read(), stm_write(), stm_allocate(), and
+   other functions should only be called from within a transaction.
+
+   Note that transactions, in the STM sense, cover _at least_ one
+   transactional code zone.  They may be longer; for example, if one
+   thread does a lot of stm_enter_transactional_zone() +
+   stm_become_inevitable() + stm_leave_transactional_zone(), as is
+   typical in a thread that does a lot of C function calls, then we
+   get only a few bigger inevitable transactions that cover the many
+   short transactional zones.  This is done by having
+   stm_leave_transactional_zone() turn the current transaction
+   inevitable and detach it from the running thread (if there is no
+   other inevitable transaction running so far).  Then
+   stm_enter_transactional_zone() will try to reattach to it.  This is
+   far more efficient than constantly starting and committing
+   transactions.
+
+   stm_enter_transactional_zone() and stm_leave_transactional_zone()
+   preserve the value of errno.
+*/
+#ifdef STM_DEBUGPRINT
+#include <stdio.h>
+#endif
+static inline void stm_enter_transactional_zone(stm_thread_local_t *tl) {
+    intptr_t self = tl->self_or_0_if_atomic;
+    if (__sync_bool_compare_and_swap(&_stm_detached_inevitable_from_thread,
+                                     self, 0)) {   /* fast path: reclaim our own detached transaction */
+        if (self != 0 && stmcb_timing_event != NULL) {
+            /* self == 0 means atomic: we don't emit DETACH/REATTACH then */
+            stmcb_timing_event(tl, STM_TRANSACTION_REATTACH, NULL);
+        }
+#ifdef STM_DEBUGPRINT
+        fprintf(stderr, "stm_enter_transactional_zone fast path\n");
+#endif
+    }
+    else {
+        _stm_reattach_transaction(self);
+        /* _stm_detached_inevitable_from_thread should be 0 here, but
+           it can already have been changed from a parallel thread
+           (assuming we're not inevitable ourselves) */
+    }
+}
+static inline void stm_leave_transactional_zone(stm_thread_local_t *tl) {
+    assert(STM_SEGMENT->running_thread == tl);
+    if (stm_is_inevitable(tl)) {
+#ifdef STM_DEBUGPRINT
+        fprintf(stderr, "stm_leave_transactional_zone fast path\n");
+#endif
+        _stm_detach_inevitable_transaction(tl);   /* fast path: detach, don't commit */
+    }
+    else {
+        _stm_leave_noninevitable_transactional_zone();
+    }
+}
+
+/* stm_force_transaction_break() is in theory equivalent to
+   stm_leave_transactional_zone() immediately followed by
+   stm_enter_transactional_zone(); however, it is supposed to be
+   called in CPU-heavy threads that had a transaction run for a while,
+   and so it *always* forces a commit and starts the next transaction.
+   The new transaction is never inevitable.  See also
+   stm_should_break_transaction(). */
+void stm_force_transaction_break(stm_thread_local_t *tl);
+
+
 /* Support for light finalizers.  This is a simple version of
    finalizers that guarantees not to do anything fancy, like not
    resurrecting objects. */
@@ -755,6 +767,21 @@
     object_t *object;
 };
 
+/* Hashtable iterators.  You get a raw 'table' pointer when you make
+   an iterator, which you pass to stm_hashtable_iter_next().  This may
+   or may not return items added after stm_hashtable_iter() was
+   called; there is no logic so far to detect changes (unlike Python,
+   which raises RuntimeError).  When the GC traces, you must keep the
+   table pointer alive with stm_hashtable_iter_tracefn().  The original
+   hashtable object must also be kept alive. */
+typedef struct stm_hashtable_table_s stm_hashtable_table_t;
+stm_hashtable_table_t *stm_hashtable_iter(stm_hashtable_t *);
+stm_hashtable_entry_t **
+stm_hashtable_iter_next(object_t *hobj, stm_hashtable_table_t *table,
+                        stm_hashtable_entry_t **previous);
+void stm_hashtable_iter_tracefn(stm_hashtable_table_t *table,
+                                void trace(object_t **));
+
 
 /* Queues.  The items you put() and get() back are in random order.
    Like hashtables, the type 'stm_queue_t' is not an object type at


More information about the pypy-commit mailing list