[pypy-svn] r58657 - in pypy/branch/2.5-merge/pypy/module/itertools: . test
arigo at codespeak.net
arigo at codespeak.net
Mon Oct 6 16:05:31 CEST 2008
Author: arigo
Date: Mon Oct 6 16:05:29 2008
New Revision: 58657
Modified:
pypy/branch/2.5-merge/pypy/module/itertools/interp_itertools.py
pypy/branch/2.5-merge/pypy/module/itertools/test/test_itertools.py
Log:
(iko, arigo)
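Redo an optimization that CPython tests for: itertools.tee() applied to an existing tee iterator returns that iterator as the first result and shares its state with the new iterators.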
Modified: pypy/branch/2.5-merge/pypy/module/itertools/interp_itertools.py
==============================================================================
--- pypy/branch/2.5-merge/pypy/module/itertools/interp_itertools.py (original)
+++ pypy/branch/2.5-merge/pypy/module/itertools/interp_itertools.py Mon Oct 6 16:05:29 2008
@@ -678,9 +678,16 @@
"""
if n < 0:
raise OperationError(space.w_ValueError, space.wrap("n must be >= 0"))
-
- tee_state = TeeState(space, w_iterable)
- iterators_w = [space.wrap(W_TeeIterable(space, tee_state)) for x in range(n)]
+
+ myiter = space.interpclass_w(w_iterable)
+ if isinstance(myiter, W_TeeIterable): # optimization only
+ tee_state = myiter.tee_state
+ iterators_w = [w_iterable] * n
+ for i in range(1, n):
+ iterators_w[i] = space.wrap(W_TeeIterable(space, tee_state))
+ else:
+ tee_state = TeeState(space, w_iterable)
+ iterators_w = [space.wrap(W_TeeIterable(space, tee_state)) for x in range(n)]
return space.newtuple(iterators_w)
tee.unwrap_spec = [ObjSpace, W_Root, int]
Modified: pypy/branch/2.5-merge/pypy/module/itertools/test/test_itertools.py
==============================================================================
--- pypy/branch/2.5-merge/pypy/module/itertools/test/test_itertools.py (original)
+++ pypy/branch/2.5-merge/pypy/module/itertools/test/test_itertools.py Mon Oct 6 16:05:29 2008
@@ -439,6 +439,22 @@
raises(ValueError, itertools.tee, [], -1)
raises(TypeError, itertools.tee, [], None)
+ def test_tee_optimization(self):
+ import itertools
+ # tee() of a tee iterator must return that iterator itself first; the other results share its tee_state. NOTE(review): only an unconsumed b is exercised here.
+ a, b = itertools.tee(iter('foobar'))
+ c, d = itertools.tee(b)
+ assert c is b
+ assert a is not c
+ assert a is not d
+ assert c is not d
+ res = list(a)
+ assert res == list('foobar')
+ res = list(c)
+ assert res == list('foobar')
+ res = list(d)
+ assert res == list('foobar')
+
def test_groupby(self):
import itertools
More information about the Pypy-commit mailing list