[pypy-commit] pypy stmgc-c8-dictiter: stmdict iterators

arigo noreply at buildbot.pypy.org
Thu Nov 5 16:53:54 EST 2015


Author: Armin Rigo <arigo at tunes.org>
Branch: stmgc-c8-dictiter
Changeset: r80556:ea8156b61a62
Date: 2015-11-05 22:53 +0100
http://bitbucket.org/pypy/pypy/changeset/ea8156b61a62/

Log:	stmdict iterators

diff --git a/pypy/module/pypystm/hashtable.py b/pypy/module/pypystm/hashtable.py
--- a/pypy/module/pypystm/hashtable.py
+++ b/pypy/module/pypystm/hashtable.py
@@ -103,29 +103,27 @@
         # and uses the faster len_estimate()
         return space.wrap(self.hiter.hashtable.len_estimate())
 
-    def next_entry(self, space):
+    def descr_next(self, space):
         try:
-            return self.hiter.next()
+            entry = self.hiter.next()
         except StopIteration:
             raise OperationError(space.w_StopIteration, space.w_None)
+        return self.get_final_value(space, entry)
 
     def _cleanup_(self):
         raise Exception("seeing a prebuilt %r object" % (
             self.__class__,))
 
 class W_HashtableIterKeys(W_BaseHashtableIter):
-    def descr_next(self, space):
-        entry = self.next_entry(space)
+    def get_final_value(self, space, entry):
         return space.wrap(intmask(entry.index))
 
 class W_HashtableIterValues(W_BaseHashtableIter):
-    def descr_next(self, space):
-        entry = self.next_entry(space)
+    def get_final_value(self, space, entry):
         return cast_gcref_to_instance(W_Root, entry.object)
 
 class W_HashtableIterItems(W_BaseHashtableIter):
-    def descr_next(self, space):
-        entry = self.next_entry(space)
+    def get_final_value(self, space, entry):
         return space.newtuple([
             space.wrap(intmask(entry.index)),
             cast_gcref_to_instance(W_Root, entry.object)])
@@ -157,23 +155,9 @@
     iteritems  = interp2app(W_Hashtable.iteritems_w),
 )
 
-W_HashtableIterKeys.typedef = TypeDef(
-    "hashtable_iterkeys",
-    __iter__ = interp2app(W_HashtableIterKeys.descr_iter),
-    next = interp2app(W_HashtableIterKeys.descr_next),
-    __length_hint__ = interp2app(W_HashtableIterKeys.descr_length_hint),
+W_BaseHashtableIter.typedef = TypeDef(
+    "hashtable_iter",
+    __iter__ = interp2app(W_BaseHashtableIter.descr_iter),
+    next = interp2app(W_BaseHashtableIter.descr_next),
+    __length_hint__ = interp2app(W_BaseHashtableIter.descr_length_hint),
     )
-
-W_HashtableIterValues.typedef = TypeDef(
-    "hashtable_itervalues",
-    __iter__ = interp2app(W_HashtableIterValues.descr_iter),
-    next = interp2app(W_HashtableIterValues.descr_next),
-    __length_hint__ = interp2app(W_HashtableIterValues.descr_length_hint),
-    )
-
-W_HashtableIterItems.typedef = TypeDef(
-    "hashtable_iteritems",
-    __iter__ = interp2app(W_HashtableIterItems.descr_iter),
-    next = interp2app(W_HashtableIterItems.descr_next),
-    __length_hint__ = interp2app(W_HashtableIterItems.descr_length_hint),
-    )
diff --git a/pypy/module/pypystm/stmdict.py b/pypy/module/pypystm/stmdict.py
--- a/pypy/module/pypystm/stmdict.py
+++ b/pypy/module/pypystm/stmdict.py
@@ -2,6 +2,7 @@
 The class pypystm.stmdict, giving a part of the regular 'dict' interface
 """
 
+from pypy.interpreter.error import OperationError
 from pypy.interpreter.baseobjspace import W_Root
 from pypy.interpreter.typedef import TypeDef
 from pypy.interpreter.gateway import interp2app, unwrap_spec, WrappedDefault
@@ -215,10 +216,6 @@
     def len_w(self, space):
         return space.wrap(self.get_length())
 
-    def iter_w(self, space):
-        # not a real lazy iterator!
-        return space.iter(self.keys_w(space))
-
     def keys_w(self, space):
         return space.newlist(self.get_keys_values_w(offset=0))
 
@@ -228,6 +225,70 @@
     def items_w(self, space):
         return space.newlist(self.get_items_w(space))
 
+    def iterkeys_w(self, space):
+        return W_STMDictIterKeys(self.h)
+
+    def itervalues_w(self, space):
+        return W_STMDictIterValues(self.h)
+
+    def iteritems_w(self, space):
+        return W_STMDictIterItems(self.h)
+
+
+class W_BaseSTMDictIter(W_Root):
+    _immutable_fields_ = ["hiter"]
+    next_from_same_hash = 0
+
+    def __init__(self, hobj):
+        self.hiter = hobj.iterentries()
+
+    def descr_iter(self, space):
+        return self
+
+    def descr_length_hint(self, space):
+        # xxx estimate: doesn't subtract the items already yielded,
+        # and uses the faster len_estimate(); it can also undercount,
+        # as keys sharing a 64-bit hash value are counted only once
+        return space.wrap(self.hiter.hashtable.len_estimate())
+
+    def descr_next(self, space):
+        if self.next_from_same_hash == 0:      # common case
+            try:
+                entry = self.hiter.next()
+            except StopIteration:
+                raise OperationError(space.w_StopIteration, space.w_None)
+            index = 0
+            array = lltype.cast_opaque_ptr(PARRAY, entry.object)
+        else:
+            index = self.next_from_same_hash
+            array = self.next_array
+            self.next_from_same_hash = 0
+            self.next_array = lltype.nullptr(ARRAY)
+        #
+        if len(array) > index + 2:      # uncommon case
+            self.next_from_same_hash = index + 2
+            self.next_array = array
+        #
+        return self.get_final_value(space, array, index)
+
+    def _cleanup_(self):
+        raise Exception("seeing a prebuilt %r object" % (
+            self.__class__,))
+
+class W_STMDictIterKeys(W_BaseSTMDictIter):
+    def get_final_value(self, space, array, index):
+        return cast_gcref_to_instance(W_Root, array[index])
+
+class W_STMDictIterValues(W_BaseSTMDictIter):
+    def get_final_value(self, space, array, index):
+        return cast_gcref_to_instance(W_Root, array[index + 1])
+
+class W_STMDictIterItems(W_BaseSTMDictIter):
+    def get_final_value(self, space, array, index):
+        return space.newtuple([
+            cast_gcref_to_instance(W_Root, array[index]),
+            cast_gcref_to_instance(W_Root, array[index + 1])])
+
 
 def W_STMDict___new__(space, w_subtype):
     r = space.allocate_instance(W_STMDict, w_subtype)
@@ -246,8 +307,19 @@
     setdefault = interp2app(W_STMDict.setdefault_w),
 
     __len__  = interp2app(W_STMDict.len_w),
-    __iter__ = interp2app(W_STMDict.iter_w),
     keys     = interp2app(W_STMDict.keys_w),
     values   = interp2app(W_STMDict.values_w),
     items    = interp2app(W_STMDict.items_w),
+
+    __iter__   = interp2app(W_STMDict.iterkeys_w),
+    iterkeys   = interp2app(W_STMDict.iterkeys_w),
+    itervalues = interp2app(W_STMDict.itervalues_w),
+    iteritems  = interp2app(W_STMDict.iteritems_w),
     )
+
+W_BaseSTMDictIter.typedef = TypeDef(
+    "stmdict_iter",
+    __iter__ = interp2app(W_BaseSTMDictIter.descr_iter),
+    next = interp2app(W_BaseSTMDictIter.descr_next),
+    __length_hint__ = interp2app(W_BaseSTMDictIter.descr_length_hint),
+    )
diff --git a/pypy/module/pypystm/test/test_stmdict.py b/pypy/module/pypystm/test/test_stmdict.py
--- a/pypy/module/pypystm/test/test_stmdict.py
+++ b/pypy/module/pypystm/test/test_stmdict.py
@@ -158,3 +158,24 @@
         assert a not in d
         assert b not in d
         assert d.keys() == []
+
+
+    def test_iterator(self):
+        import pypystm
+        class A(object):
+            def __hash__(self):
+                return 42
+        class B(object):
+            pass
+        d = pypystm.stmdict()
+        a1 = A()
+        a2 = A()
+        b0 = B()
+        d[a1] = "foo"
+        d[a2] = None
+        d[b0] = "bar"
+        assert sorted(d) == sorted([a1, a2, b0])
+        assert sorted(d.iterkeys()) == sorted([a1, a2, b0])
+        assert sorted(d.itervalues()) == [None, "bar", "foo"]
+        assert sorted(d.iteritems()) == sorted([(a1, "foo"), (a2, None),
+                                                (b0, "bar")])


More information about the pypy-commit mailing list