[pypy-svn] r58657 - in pypy/branch/2.5-merge/pypy/module/itertools: . test

arigo at codespeak.net arigo at codespeak.net
Mon Oct 6 16:05:31 CEST 2008


Author: arigo
Date: Mon Oct  6 16:05:29 2008
New Revision: 58657

Modified:
   pypy/branch/2.5-merge/pypy/module/itertools/interp_itertools.py
   pypy/branch/2.5-merge/pypy/module/itertools/test/test_itertools.py
Log:
(iko, arigo)
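Redo an optimization that CPython's tests check for: calling tee() on an existing tee iterator returns that same iterator as the first result and makes the other results share its tee state, instead of wrapping it in a new one.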


Modified: pypy/branch/2.5-merge/pypy/module/itertools/interp_itertools.py
==============================================================================
--- pypy/branch/2.5-merge/pypy/module/itertools/interp_itertools.py	(original)
+++ pypy/branch/2.5-merge/pypy/module/itertools/interp_itertools.py	Mon Oct  6 16:05:29 2008
@@ -678,9 +678,16 @@
     """
     if n < 0:
         raise OperationError(space.w_ValueError, space.wrap("n must be >= 0"))
-    
-    tee_state = TeeState(space, w_iterable)
-    iterators_w = [space.wrap(W_TeeIterable(space, tee_state)) for x in range(n)]
+
+    myiter = space.interpclass_w(w_iterable)
+    if isinstance(myiter, W_TeeIterable):     # optimization only
+        tee_state = myiter.tee_state
+        iterators_w = [w_iterable] * n
+        for i in range(1, n):
+            iterators_w[i] = space.wrap(W_TeeIterable(space, tee_state))
+    else:
+        tee_state = TeeState(space, w_iterable)
+        iterators_w = [space.wrap(W_TeeIterable(space, tee_state)) for x in range(n)]
     return space.newtuple(iterators_w)
 tee.unwrap_spec = [ObjSpace, W_Root, int]
 

Modified: pypy/branch/2.5-merge/pypy/module/itertools/test/test_itertools.py
==============================================================================
--- pypy/branch/2.5-merge/pypy/module/itertools/test/test_itertools.py	(original)
+++ pypy/branch/2.5-merge/pypy/module/itertools/test/test_itertools.py	Mon Oct  6 16:05:29 2008
@@ -439,6 +439,22 @@
         raises(ValueError, itertools.tee, [], -1)
         raises(TypeError, itertools.tee, [], None)
 
+    def test_tee_optimization(self):
+        import itertools
+
+        a, b = itertools.tee(iter('foobar'))
+        c, d = itertools.tee(b)
+        assert c is b
+        assert a is not c
+        assert a is not d
+        assert c is not d
+        res = list(a)
+        assert res == list('foobar')
+        res = list(c)
+        assert res == list('foobar')
+        res = list(d)
+        assert res == list('foobar')
+
     def test_groupby(self):
         import itertools
         



More information about the Pypy-commit mailing list