[pypy-commit] pypy default: Rewrite itertools.groupby(), following CPython instead of having many flags to get a result that differs subtly
arigo
pypy.commits at gmail.com
Thu May 11 05:23:37 EDT 2017
Author: Armin Rigo <arigo at tunes.org>
Branch:
Changeset: r91241:6093ff1a44e6
Date: 2017-05-11 11:22 +0200
http://bitbucket.org/pypy/pypy/changeset/6093ff1a44e6/
Log: Rewrite itertools.groupby(), following CPython instead of having
many flags to get a result that differs subtly
diff --git a/pypy/module/itertools/interp_itertools.py b/pypy/module/itertools/interp_itertools.py
--- a/pypy/module/itertools/interp_itertools.py
+++ b/pypy/module/itertools/interp_itertools.py
@@ -920,90 +920,42 @@
class W_GroupBy(W_Root):
def __init__(self, space, w_iterable, w_fun):
self.space = space
- self.w_iterable = self.space.iter(w_iterable)
- if space.is_none(w_fun):
- self.w_fun = None
- else:
- self.w_fun = w_fun
- self.index = 0
- self.lookahead = False
- self.exhausted = False
- self.started = False
- # new_group - new group not started yet, next should not skip any items
- self.new_group = True
- self.w_lookahead = self.space.w_None
- self.w_key = self.space.w_None
+ self.w_iterator = self.space.iter(w_iterable)
+ if w_fun is None:
+ w_fun = space.w_None
+ self.w_keyfunc = w_fun
+ self.w_tgtkey = None
+ self.w_currkey = None
+ self.w_currvalue = None
def iter_w(self):
return self
def next_w(self):
- if self.exhausted:
- raise OperationError(self.space.w_StopIteration, self.space.w_None)
+ self._skip_to_next_iteration_group()
+ w_key = self.w_tgtkey = self.w_currkey
+ w_grouper = W_GroupByIterator(self, w_key)
+ return self.space.newtuple([w_key, w_grouper])
- if not self.new_group:
- self._consume_unwanted_input()
+ def _skip_to_next_iteration_group(self):
+ space = self.space
+ while True:
+ if self.w_currkey is None:
+ pass
+ elif self.w_tgtkey is None:
+ break
+ else:
+ if not space.eq_w(self.w_tgtkey, self.w_currkey):
+ break
- if not self.started:
- self.started = True
- try:
- w_obj = self.space.next(self.w_iterable)
- except OperationError as e:
- if e.match(self.space, self.space.w_StopIteration):
- self.exhausted = True
- raise
+ w_newvalue = space.next(self.w_iterator)
+ if space.is_w(self.w_keyfunc, space.w_None):
+ w_newkey = w_newvalue
else:
- self.w_lookahead = w_obj
- if self.w_fun is None:
- self.w_key = w_obj
- else:
- self.w_key = self.space.call_function(self.w_fun, w_obj)
- self.lookahead = True
+ w_newkey = space.call_function(self.w_keyfunc, w_newvalue)
- self.new_group = False
- w_iterator = W_GroupByIterator(self.space, self.index, self)
- return self.space.newtuple([self.w_key, w_iterator])
-
- def _consume_unwanted_input(self):
- # Consume unwanted input until we reach the next group
- try:
- while True:
- self.group_next(self.index)
- except StopIteration:
- pass
- if self.exhausted:
- raise OperationError(self.space.w_StopIteration, self.space.w_None)
-
- def group_next(self, group_index):
- if group_index < self.index:
- raise StopIteration
- else:
- if self.lookahead:
- self.lookahead = False
- return self.w_lookahead
-
- try:
- w_obj = self.space.next(self.w_iterable)
- except OperationError as e:
- if e.match(self.space, self.space.w_StopIteration):
- self.exhausted = True
- raise StopIteration
- else:
- raise
- else:
- if self.w_fun is None:
- w_new_key = w_obj
- else:
- w_new_key = self.space.call_function(self.w_fun, w_obj)
- if self.space.eq_w(self.w_key, w_new_key):
- return w_obj
- else:
- self.index += 1
- self.w_lookahead = w_obj
- self.w_key = w_new_key
- self.lookahead = True
- self.new_group = True #new group
- raise StopIteration
+ self.w_currkey = w_newkey
+ self.w_currvalue = w_newvalue
def W_GroupBy___new__(space, w_subtype, w_iterable, w_key=None):
r = space.allocate_instance(W_GroupBy, w_subtype)
@@ -1036,26 +988,33 @@
class W_GroupByIterator(W_Root):
- def __init__(self, space, index, groupby):
- self.space = space
- self.index = index
+ def __init__(self, groupby, w_tgtkey):
self.groupby = groupby
- self.exhausted = False
+ self.w_tgtkey = w_tgtkey
def iter_w(self):
return self
def next_w(self):
- if self.exhausted:
- raise OperationError(self.space.w_StopIteration, self.space.w_None)
+ groupby = self.groupby
+ space = groupby.space
+ if groupby.w_currvalue is None:
+ w_newvalue = space.next(groupby.w_iterator)
+ if space.is_w(groupby.w_keyfunc, space.w_None):
+ w_newkey = w_newvalue
+ else:
+ w_newkey = space.call_function(groupby.w_keyfunc, w_newvalue)
+ assert groupby.w_currvalue is None
+ groupby.w_currkey = w_newkey
+ groupby.w_currvalue = w_newvalue
- try:
- w_obj = self.groupby.group_next(self.index)
- except StopIteration:
- self.exhausted = True
- raise OperationError(self.space.w_StopIteration, self.space.w_None)
- else:
- return w_obj
+ assert groupby.w_currkey is not None
+ if not space.eq_w(self.w_tgtkey, groupby.w_currkey):
+ raise OperationError(space.w_StopIteration, space.w_None)
+ w_result = groupby.w_currvalue
+ groupby.w_currvalue = None
+ groupby.w_currkey = None
+ return w_result
W_GroupByIterator.typedef = TypeDef(
'itertools._groupby',
diff --git a/pypy/module/itertools/test/test_itertools.py b/pypy/module/itertools/test/test_itertools.py
--- a/pypy/module/itertools/test/test_itertools.py
+++ b/pypy/module/itertools/test/test_itertools.py
@@ -634,6 +634,17 @@
it = itertools.groupby([0], 1)
raises(TypeError, it.next)
+ def test_groupby_question_43905804(self):
+ # http://stackoverflow.com/questions/43905804/
+ import itertools
+
+ inputs = ((x > 5, x) for x in range(10))
+ (_, a), (_, b) = itertools.groupby(inputs, key=lambda x: x[0])
+ a = list(a)
+ b = list(b)
+ assert a == []
+ assert b == [(True, 9)]
+
def test_iterables(self):
import itertools
More information about the pypy-commit
mailing list