[pypy-commit] pypy stmgc-c8-dictiter: Iterators on hashtables

arigo noreply at buildbot.pypy.org
Thu Nov 5 13:27:09 EST 2015


Author: Armin Rigo <arigo at tunes.org>
Branch: stmgc-c8-dictiter
Changeset: r80553:47629bb038b7
Date: 2015-11-05 19:27 +0100
http://bitbucket.org/pypy/pypy/changeset/47629bb038b7/

Log:	Iterators on hashtables

diff --git a/pypy/module/pypystm/hashtable.py b/pypy/module/pypystm/hashtable.py
--- a/pypy/module/pypystm/hashtable.py
+++ b/pypy/module/pypystm/hashtable.py
@@ -2,6 +2,7 @@
 The class pypystm.hashtable, mapping integers to objects.
 """
 
+from pypy.interpreter.error import OperationError
 from pypy.interpreter.baseobjspace import W_Root
 from pypy.interpreter.typedef import TypeDef
 from pypy.interpreter.gateway import interp2app, unwrap_spec, WrappedDefault
@@ -78,6 +79,57 @@
                  for i in range(count)]
         return space.newlist(lst_w)
 
+    def iterkeys_w(self, space):  # app-level iterkeys() and __iter__
+        return W_HashtableIterKeys(self.h)
+
+    def itervalues_w(self, space):  # app-level itervalues()
+        return W_HashtableIterValues(self.h)
+
+    def iteritems_w(self, space):  # app-level iteritems()
+        return W_HashtableIterItems(self.h)
+
+
+class W_BaseHashtableIter(W_Root):  # shared base of the keys/values/items iterators
+    _immutable_fields_ = ["hiter"]
+
+    def __init__(self, hobj):  # hobj: the low-level hashtable (W_Hashtable.h)
+        self.hiter = hobj.iterentries()
+
+    def descr_iter(self, space):  # an iterator is its own iterator
+        return self
+
+    def descr_length_hint(self, space):
+        # xxx overestimate: doesn't remove the items already yielded,
+        # and uses the faster len_estimate()
+        return space.wrap(self.hiter.hashtable.len_estimate())
+
+    def next_entry(self, space):  # next raw entry; app-level StopIteration at the end
+        try:
+            return self.hiter.next()
+        except StopIteration:
+            raise OperationError(space.w_StopIteration, space.w_None)
+
+    def _cleanup_(self):  # raises if an instance turns up as a prebuilt object
+        raise Exception("seeing a prebuilt %r object" % (
+            self.__class__,))
+
+class W_HashtableIterKeys(W_BaseHashtableIter):  # yields the integer keys
+    def descr_next(self, space):
+        entry = self.next_entry(space)
+        return space.wrap(intmask(entry.index))  # index is Unsigned; intmask gives a signed int
+
+class W_HashtableIterValues(W_BaseHashtableIter):  # yields the stored objects
+    def descr_next(self, space):
+        entry = self.next_entry(space)
+        return cast_gcref_to_instance(W_Root, entry.object)  # GCREF back to W_Root
+
+class W_HashtableIterItems(W_BaseHashtableIter):  # yields (key, object) tuples
+    def descr_next(self, space):
+        entry = self.next_entry(space)
+        return space.newtuple([
+            space.wrap(intmask(entry.index)),
+            cast_gcref_to_instance(W_Root, entry.object)])
+
 
 def W_Hashtable___new__(space, w_subtype):
     r = space.allocate_instance(W_Hashtable, w_subtype)
@@ -98,4 +150,30 @@
     keys    = interp2app(W_Hashtable.keys_w),
     values  = interp2app(W_Hashtable.values_w),
     items   = interp2app(W_Hashtable.items_w),
+
+    __iter__   = interp2app(W_Hashtable.iterkeys_w),
+    iterkeys   = interp2app(W_Hashtable.iterkeys_w),
+    itervalues = interp2app(W_Hashtable.itervalues_w),
+    iteritems  = interp2app(W_Hashtable.iteritems_w),
 )
+
+W_HashtableIterKeys.typedef = TypeDef(  # type returned by iterkeys() and __iter__
+    "hashtable_iterkeys",
+    __iter__ = interp2app(W_HashtableIterKeys.descr_iter),
+    next = interp2app(W_HashtableIterKeys.descr_next),
+    __length_hint__ = interp2app(W_HashtableIterKeys.descr_length_hint),
+    )
+
+W_HashtableIterValues.typedef = TypeDef(  # type returned by itervalues()
+    "hashtable_itervalues",
+    __iter__ = interp2app(W_HashtableIterValues.descr_iter),
+    next = interp2app(W_HashtableIterValues.descr_next),
+    __length_hint__ = interp2app(W_HashtableIterValues.descr_length_hint),
+    )
+
+W_HashtableIterItems.typedef = TypeDef(  # type returned by iteritems()
+    "hashtable_iteritems",
+    __iter__ = interp2app(W_HashtableIterItems.descr_iter),
+    next = interp2app(W_HashtableIterItems.descr_next),
+    __length_hint__ = interp2app(W_HashtableIterItems.descr_length_hint),
+    )
diff --git a/pypy/module/pypystm/test/test_hashtable.py b/pypy/module/pypystm/test/test_hashtable.py
--- a/pypy/module/pypystm/test/test_hashtable.py
+++ b/pypy/module/pypystm/test/test_hashtable.py
@@ -55,3 +55,13 @@
         assert sorted(h.keys()) == [42, 43]
         assert sorted(h.values()) == ["bar", "foo"]
         assert sorted(h.items()) == [(42, "foo"), (43, "bar")]
+
+    def test_iterator(self):  # covers __iter__, iterkeys, itervalues, iteritems
+        import pypystm
+        h = pypystm.hashtable()
+        h[42] = "foo"
+        h[43] = "bar"
+        assert sorted(h) == [42, 43]  # sorted(): the test does not rely on iteration order
+        assert sorted(h.iterkeys()) == [42, 43]
+        assert sorted(h.itervalues()) == ["bar", "foo"]
+        assert sorted(h.iteritems()) == [(42, "foo"), (43, "bar")]
diff --git a/rpython/rlib/rstm.py b/rpython/rlib/rstm.py
--- a/rpython/rlib/rstm.py
+++ b/rpython/rlib/rstm.py
@@ -223,11 +223,13 @@
 # ____________________________________________________________
 
 _STM_HASHTABLE_P = rffi.COpaquePtr('stm_hashtable_t')
+_STM_HASHTABLE_TABLE_P = rffi.COpaquePtr('stm_hashtable_table_t')  # table returned by stm_hashtable_iter
 
 _STM_HASHTABLE_ENTRY = lltype.GcStruct('HASHTABLE_ENTRY',
                                        ('index', lltype.Unsigned),
                                        ('object', llmemory.GCREF))
 _STM_HASHTABLE_ENTRY_P = lltype.Ptr(_STM_HASHTABLE_ENTRY)
+_STM_HASHTABLE_ENTRY_PP = rffi.CArrayPtr(_STM_HASHTABLE_ENTRY_P)  # cursor returned by stm_hashtable_iter_next
 _STM_HASHTABLE_ENTRY_ARRAY = lltype.GcArray(_STM_HASHTABLE_ENTRY_P)
 
 @dont_look_inside
@@ -245,6 +247,11 @@
                                    lltype.nullptr(_STM_HASHTABLE_ENTRY_ARRAY))
 
 @dont_look_inside
+def _ll_hashtable_len_estimate(h):  # cheap upper bound on the number of entries (may overcount)
+    return llop.stm_hashtable_length_upper_bound(lltype.Signed,
+                                                 h.ll_raw_hashtable)
+
+@dont_look_inside
 def _ll_hashtable_list(h):
     upper_bound = llop.stm_hashtable_length_upper_bound(lltype.Signed,
                                                         h.ll_raw_hashtable)
@@ -264,6 +271,27 @@
 def _ll_hashtable_writeobj(h, entry, value):
     llop.stm_hashtable_write_entry(lltype.Void, h, entry, value)
 
+@dont_look_inside
+def _ll_hashtable_iterentries(h):  # start an iteration over h; returns a _HASHTABLE_ITER_OBJ
+    rgc.register_custom_trace_hook(_HASHTABLE_ITER_OBJ,
+                                   lambda_hashtable_iter_trace)  # GC must trace the iterator's fields
+    table = llop.stm_hashtable_iter(_STM_HASHTABLE_TABLE_P, h.ll_raw_hashtable)
+    hiter = lltype.malloc(_HASHTABLE_ITER_OBJ)
+    hiter.hashtable = h    # for keepalive
+    hiter.table = table
+    hiter.prev = lltype.nullptr(_STM_HASHTABLE_ENTRY_PP.TO)  # NULL: nothing yielded yet
+    return hiter
+
+def _ll_hashiter_next(hiter):  # return the next entry, or raise StopIteration when done
+    entrypp = llop.stm_hashtable_iter_next(_STM_HASHTABLE_ENTRY_PP,
+                                           hiter.hashtable,
+                                           hiter.table,
+                                           hiter.prev)
+    if not entrypp:  # NULL result: no more entries
+        raise StopIteration
+    hiter.prev = entrypp  # passed back as 'prev' on the next call
+    return entrypp[0]
+
 _HASHTABLE_OBJ = lltype.GcStruct('HASHTABLE_OBJ',
                                  ('ll_raw_hashtable', _STM_HASHTABLE_P),
                                  hints={'immutable': True},
@@ -271,11 +299,19 @@
                                  adtmeths={'get': _ll_hashtable_get,
                                            'set': _ll_hashtable_set,
                                            'len': _ll_hashtable_len,
+                                  'len_estimate': _ll_hashtable_len_estimate,
                                           'list': _ll_hashtable_list,
                                         'lookup': _ll_hashtable_lookup,
-                                      'writeobj': _ll_hashtable_writeobj})
+                                      'writeobj': _ll_hashtable_writeobj,
+                                   'iterentries': _ll_hashtable_iterentries})
 NULL_HASHTABLE = lltype.nullptr(_HASHTABLE_OBJ)
 
+_HASHTABLE_ITER_OBJ = lltype.GcStruct('HASHTABLE_ITER_OBJ',
+                                      ('hashtable', lltype.Ptr(_HASHTABLE_OBJ)),  # keepalive + hobj arg
+                                      ('table', _STM_HASHTABLE_TABLE_P),  # from stm_hashtable_iter
+                                      ('prev', _STM_HASHTABLE_ENTRY_PP),  # last cursor returned, or NULL
+                                      adtmeths={'next': _ll_hashiter_next})
+
 def _ll_hashtable_trace(gc, obj, callback, arg):
     from rpython.memory.gctransform.stmframework import get_visit_function
     visit_fn = get_visit_function(callback, arg)
@@ -288,6 +324,15 @@
         llop.stm_hashtable_free(lltype.Void, h.ll_raw_hashtable)
 lambda_hashtable_finlz = lambda: _ll_hashtable_finalizer
 
+def _ll_hashtable_iter_trace(gc, obj, callback, arg):  # custom GC trace hook for _HASHTABLE_ITER_OBJ
+    from rpython.memory.gctransform.stmframework import get_visit_function
+    addr = obj + llmemory.offsetof(_HASHTABLE_ITER_OBJ, 'hashtable')
+    gc._trace_callback(callback, arg, addr)  # trace the GC 'hashtable' field
+    visit_fn = get_visit_function(callback, arg)
+    addr = obj + llmemory.offsetof(_HASHTABLE_ITER_OBJ, 'table')
+    llop.stm_hashtable_iter_tracefn(lltype.Void, addr.address[0], visit_fn)  # objects reachable from raw 'table'
+lambda_hashtable_iter_trace = lambda: _ll_hashtable_iter_trace
+
 _false = CDefinedIntSymbolic('0', default=0)    # remains in the C code
 
 @dont_look_inside
@@ -344,6 +389,9 @@
         items = [self.lookup(key) for key, v in self._content.items() if v.object != NULL_GCREF]
         return len(items)
 
+    def len_estimate(self):  # upper bound: also counts entries whose object is NULL_GCREF
+        return len(self._content)
+
     def list(self):
         items = [self.lookup(key) for key, v in self._content.items() if v.object != NULL_GCREF]
         count = len(items)
@@ -359,6 +407,9 @@
         assert isinstance(entry, EntryObjectForTest)
         self.set(entry.key, nvalue)
 
+    def iterentries(self):  # like len()/list(), skip entries whose object is NULL_GCREF
+        return IterEntriesForTest(self, (v for v in self._content.itervalues() if v.object != NULL_GCREF))
+
 class EntryObjectForTest(object):
     def __init__(self, hashtable, key):
         self.hashtable = hashtable
@@ -374,6 +425,14 @@
 
     object = property(_getobj, _setobj)
 
+class IterEntriesForTest(object):  # test-mode counterpart of _HASHTABLE_ITER_OBJ (.hashtable, .next())
+    def __init__(self, hashtable, iterator):
+        self.hashtable = hashtable  # read by W_BaseHashtableIter.descr_length_hint
+        self.iterator = iterator
+
+    def next(self):  # raises StopIteration when the underlying iterator is exhausted
+        return next(self.iterator)
+
 # ____________________________________________________________
 
 _STM_QUEUE_P = rffi.COpaquePtr('stm_queue_t')
diff --git a/rpython/translator/stm/funcgen.py b/rpython/translator/stm/funcgen.py
--- a/rpython/translator/stm/funcgen.py
+++ b/rpython/translator/stm/funcgen.py
@@ -398,9 +398,28 @@
     arg0 = funcgen.expr(op.args[0])
     arg1 = funcgen.expr(op.args[1])
     arg2 = funcgen.expr(op.args[2])
-    return ('stm_hashtable_tracefn(%s, (stm_hashtable_t *)%s, '
+    return ('stm_hashtable_tracefn(%s, (stm_hashtable_t *)%s,'
             ' (void(*)(object_t**))%s);' % (arg0, arg1, arg2))
 
+def stm_hashtable_iter(funcgen, op):  # emits: result = stm_hashtable_iter(hashtable);
+    arg0 = funcgen.expr(op.args[0])
+    result = funcgen.expr(op.result)
+    return '%s = stm_hashtable_iter(%s);' % (result, arg0)
+
+def stm_hashtable_iter_next(funcgen, op):  # emits: result = stm_hashtable_iter_next(hobj, table, prev);
+    arg0 = funcgen.expr(op.args[0])
+    arg1 = funcgen.expr(op.args[1])
+    arg2 = funcgen.expr(op.args[2])
+    result = funcgen.expr(op.result)  # NOTE(review): C casts on args/result may be needed; confirm
+    return ('%s = stm_hashtable_iter_next(%s, %s, %s);' %
+            (result, arg0, arg1, arg2))
+
+def stm_hashtable_iter_tracefn(funcgen, op):  # emits: stm_hashtable_iter_tracefn(table, trace);
+    arg0 = funcgen.expr(op.args[0])
+    arg1 = funcgen.expr(op.args[1])
+    return ('stm_hashtable_iter_tracefn((stm_hashtable_table_t *)%s,'
+            ' (void(*)(object_t**))%s);' % (arg0, arg1))
+
 def stm_queue_create(funcgen, op):
     result = funcgen.expr(op.result)
     return '%s = stm_queue_create();' % (result,)


More information about the pypy-commit mailing list