[pypy-commit] pypy default: Rewrite itertools.groupby(), following CPython instead of having many flags to get a result that differs subtly

arigo pypy.commits at gmail.com
Thu May 11 05:23:37 EDT 2017


Author: Armin Rigo <arigo at tunes.org>
Branch: 
Changeset: r91241:6093ff1a44e6
Date: 2017-05-11 11:22 +0200
http://bitbucket.org/pypy/pypy/changeset/6093ff1a44e6/

Log:	Rewrite itertools.groupby(), following CPython instead of having
	many flags to get a result that differs subtly

diff --git a/pypy/module/itertools/interp_itertools.py b/pypy/module/itertools/interp_itertools.py
--- a/pypy/module/itertools/interp_itertools.py
+++ b/pypy/module/itertools/interp_itertools.py
@@ -920,90 +920,42 @@
 class W_GroupBy(W_Root):
     def __init__(self, space, w_iterable, w_fun):
         self.space = space
-        self.w_iterable = self.space.iter(w_iterable)
-        if space.is_none(w_fun):
-            self.w_fun = None
-        else:
-            self.w_fun = w_fun
-        self.index = 0
-        self.lookahead = False
-        self.exhausted = False
-        self.started = False
-        # new_group - new group not started yet, next should not skip any items
-        self.new_group = True
-        self.w_lookahead = self.space.w_None
-        self.w_key = self.space.w_None
+        self.w_iterator = self.space.iter(w_iterable)
+        if w_fun is None:
+            w_fun = space.w_None
+        self.w_keyfunc = w_fun
+        self.w_tgtkey = None
+        self.w_currkey = None
+        self.w_currvalue = None
 
     def iter_w(self):
         return self
 
     def next_w(self):
-        if self.exhausted:
-            raise OperationError(self.space.w_StopIteration, self.space.w_None)
+        self._skip_to_next_iteration_group()
+        w_key = self.w_tgtkey = self.w_currkey
+        w_grouper = W_GroupByIterator(self, w_key)
+        return self.space.newtuple([w_key, w_grouper])
 
-        if not self.new_group:
-            self._consume_unwanted_input()
+    def _skip_to_next_iteration_group(self):
+        space = self.space
+        while True:
+            if self.w_currkey is None:
+                pass
+            elif self.w_tgtkey is None:
+                break
+            else:
+                if not space.eq_w(self.w_tgtkey, self.w_currkey):
+                    break
 
-        if not self.started:
-            self.started = True
-            try:
-                w_obj = self.space.next(self.w_iterable)
-            except OperationError as e:
-                if e.match(self.space, self.space.w_StopIteration):
-                    self.exhausted = True
-                raise
+            w_newvalue = space.next(self.w_iterator)
+            if space.is_w(self.w_keyfunc, space.w_None):
+                w_newkey = w_newvalue
             else:
-                self.w_lookahead = w_obj
-                if self.w_fun is None:
-                    self.w_key = w_obj
-                else:
-                    self.w_key = self.space.call_function(self.w_fun, w_obj)
-                self.lookahead = True
+                w_newkey = space.call_function(self.w_keyfunc, w_newvalue)
 
-        self.new_group = False
-        w_iterator = W_GroupByIterator(self.space, self.index, self)
-        return self.space.newtuple([self.w_key, w_iterator])
-
-    def _consume_unwanted_input(self):
-        # Consume unwanted input until we reach the next group
-        try:
-            while True:
-                self.group_next(self.index)
-        except StopIteration:
-            pass
-        if self.exhausted:
-            raise OperationError(self.space.w_StopIteration, self.space.w_None)
-
-    def group_next(self, group_index):
-        if group_index < self.index:
-            raise StopIteration
-        else:
-            if self.lookahead:
-                self.lookahead = False
-                return self.w_lookahead
-
-            try:
-                w_obj = self.space.next(self.w_iterable)
-            except OperationError as e:
-                if e.match(self.space, self.space.w_StopIteration):
-                    self.exhausted = True
-                    raise StopIteration
-                else:
-                    raise
-            else:
-                if self.w_fun is None:
-                    w_new_key = w_obj
-                else:
-                    w_new_key = self.space.call_function(self.w_fun, w_obj)
-                if self.space.eq_w(self.w_key, w_new_key):
-                    return w_obj
-                else:
-                    self.index += 1
-                    self.w_lookahead = w_obj
-                    self.w_key = w_new_key
-                    self.lookahead = True
-                    self.new_group = True #new group
-                    raise StopIteration
+            self.w_currkey = w_newkey
+            self.w_currvalue = w_newvalue
 
 def W_GroupBy___new__(space, w_subtype, w_iterable, w_key=None):
     r = space.allocate_instance(W_GroupBy, w_subtype)
@@ -1036,26 +988,33 @@
 
 
 class W_GroupByIterator(W_Root):
-    def __init__(self, space, index, groupby):
-        self.space = space
-        self.index = index
+    def __init__(self, groupby, w_tgtkey):
         self.groupby = groupby
-        self.exhausted = False
+        self.w_tgtkey = w_tgtkey
 
     def iter_w(self):
         return self
 
     def next_w(self):
-        if self.exhausted:
-            raise OperationError(self.space.w_StopIteration, self.space.w_None)
+        groupby = self.groupby
+        space = groupby.space
+        if groupby.w_currvalue is None:
+            w_newvalue = space.next(groupby.w_iterator)
+            if space.is_w(groupby.w_keyfunc, space.w_None):
+                w_newkey = w_newvalue
+            else:
+                w_newkey = space.call_function(groupby.w_keyfunc, w_newvalue)
+            assert groupby.w_currvalue is None
+            groupby.w_currkey = w_newkey
+            groupby.w_currvalue = w_newvalue
 
-        try:
-            w_obj = self.groupby.group_next(self.index)
-        except StopIteration:
-            self.exhausted = True
-            raise OperationError(self.space.w_StopIteration, self.space.w_None)
-        else:
-            return w_obj
+        assert groupby.w_currkey is not None
+        if not space.eq_w(self.w_tgtkey, groupby.w_currkey):
+            raise OperationError(space.w_StopIteration, space.w_None)
+        w_result = groupby.w_currvalue
+        groupby.w_currvalue = None
+        groupby.w_currkey = None
+        return w_result
 
 W_GroupByIterator.typedef = TypeDef(
         'itertools._groupby',
diff --git a/pypy/module/itertools/test/test_itertools.py b/pypy/module/itertools/test/test_itertools.py
--- a/pypy/module/itertools/test/test_itertools.py
+++ b/pypy/module/itertools/test/test_itertools.py
@@ -634,6 +634,17 @@
         it = itertools.groupby([0], 1)
         raises(TypeError, it.next)
 
+    def test_groupby_question_43905804(self):
+        # http://stackoverflow.com/questions/43905804/
+        import itertools
+
+        inputs = ((x > 5, x) for x in range(10))
+        (_, a), (_, b) = itertools.groupby(inputs, key=lambda x: x[0])
+        a = list(a)
+        b = list(b)
+        assert a == []
+        assert b == [(True, 9)]
+
     def test_iterables(self):
         import itertools
     


More information about the pypy-commit mailing list