[Python-checkins] cpython (3.5): Issue #24325, #24400: Add more unittests for types.coroutine; tweak wrapper

yury.selivanov python-checkins at python.org
Wed Jun 24 17:45:28 CEST 2015


https://hg.python.org/cpython/rev/eb6fb8e2f995
changeset:   96665:eb6fb8e2f995
branch:      3.5
parent:      96663:e31aad001fdb
user:        Yury Selivanov <yselivanov at sprymix.com>
date:        Wed Jun 24 11:44:51 2015 -0400
summary:
  Issue #24325, #24400: Add more unittests for types.coroutine; tweak wrapper implementation.

files:
  Lib/test/test_asyncio/test_pep492.py |   19 +
  Lib/test/test_types.py               |  206 ++++++++++++--
  Lib/types.py                         |   65 ++--
  3 files changed, 228 insertions(+), 62 deletions(-)


diff --git a/Lib/test/test_asyncio/test_pep492.py b/Lib/test/test_asyncio/test_pep492.py
--- a/Lib/test/test_asyncio/test_pep492.py
+++ b/Lib/test/test_asyncio/test_pep492.py
@@ -1,6 +1,7 @@
 """Tests support for new syntax introduced by PEP 492."""
 
 import collections.abc
+import types
 import unittest
 
 from test import support
@@ -164,5 +165,23 @@
         self.loop.run_until_complete(start())
 
 
+    def test_types_coroutine(self):
+        def gen():
+            yield from ()
+            return 'spam'
+
+        @types.coroutine
+        def func():
+            return gen()
+
+        async def coro():
+            wrapper = func()
+            self.assertIsInstance(wrapper, types._GeneratorWrapper)
+            return await wrapper
+
+        data = self.loop.run_until_complete(coro())
+        self.assertEqual(data, 'spam')
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py
--- a/Lib/test/test_types.py
+++ b/Lib/test/test_types.py
@@ -7,7 +7,8 @@
 import locale
 import sys
 import types
-import unittest
+import unittest.mock
+import weakref
 
 class TypesTests(unittest.TestCase):
 
@@ -1191,23 +1192,27 @@
 
 class CoroutineTests(unittest.TestCase):
     def test_wrong_args(self):
-        class Foo:
-            def __call__(self):
-                pass
-        def bar(): pass
-
         samples = [None, 1, object()]
         for sample in samples:
             with self.assertRaisesRegex(TypeError,
                                         'types.coroutine.*expects a callable'):
                 types.coroutine(sample)
 
-    def test_wrong_func(self):
+    def test_non_gen_values(self):
         @types.coroutine
         def foo():
             return 'spam'
         self.assertEqual(foo(), 'spam')
 
+        class Awaitable:
+            def __await__(self):
+                return ()
+        aw = Awaitable()
+        @types.coroutine
+        def foo():
+            return aw
+        self.assertIs(aw, foo())
+
     def test_async_def(self):
         # Test that types.coroutine passes 'async def' coroutines
         # without modification
@@ -1263,24 +1268,157 @@
             def send(self): pass
             def throw(self): pass
             def close(self): pass
-            def __iter__(self): return self
+            def __iter__(self): pass
             def __next__(self): pass
 
-        gen = GenLike()
+        # Setup generator mock object
+        gen = unittest.mock.MagicMock(GenLike)
+        gen.__iter__ = lambda gen: gen
+        gen.__name__ = 'gen'
+        gen.__qualname__ = 'test.gen'
+        self.assertIsInstance(gen, collections.abc.Generator)
+        self.assertIs(gen, iter(gen))
+
         @types.coroutine
-        def foo():
-            return gen
-        self.assertIs(foo().__await__(), gen)
-        self.assertTrue(isinstance(foo(), collections.abc.Coroutine))
-        with self.assertRaises(AttributeError):
-            foo().gi_code
+        def foo(): return gen
+
+        wrapper = foo()
+        self.assertIsInstance(wrapper, types._GeneratorWrapper)
+        self.assertIs(wrapper.__await__(), wrapper)
+        # Wrapper proxies duck generators completely:
+        self.assertIs(iter(wrapper), wrapper)
+
+        self.assertIsInstance(wrapper, collections.abc.Coroutine)
+        self.assertIsInstance(wrapper, collections.abc.Awaitable)
+
+        self.assertIs(wrapper.__qualname__, gen.__qualname__)
+        self.assertIs(wrapper.__name__, gen.__name__)
+
+        # Test AttributeErrors
+        for name in {'gi_running', 'gi_frame', 'gi_code',
+                     'cr_running', 'cr_frame', 'cr_code'}:
+            with self.assertRaises(AttributeError):
+                getattr(wrapper, name)
+
+        # Test attributes pass-through
+        gen.gi_running = object()
+        gen.gi_frame = object()
+        gen.gi_code = object()
+        self.assertIs(wrapper.gi_running, gen.gi_running)
+        self.assertIs(wrapper.gi_frame, gen.gi_frame)
+        self.assertIs(wrapper.gi_code, gen.gi_code)
+        self.assertIs(wrapper.cr_running, gen.gi_running)
+        self.assertIs(wrapper.cr_frame, gen.gi_frame)
+        self.assertIs(wrapper.cr_code, gen.gi_code)
+
+        wrapper.close()
+        gen.close.assert_called_once_with()
+
+        wrapper.send(1)
+        gen.send.assert_called_once_with(1)
+
+        wrapper.throw(1, 2, 3)
+        gen.throw.assert_called_once_with(1, 2, 3)
+        gen.reset_mock()
+
+        wrapper.throw(1, 2)
+        gen.throw.assert_called_once_with(1, 2)
+        gen.reset_mock()
+
+        wrapper.throw(1)
+        gen.throw.assert_called_once_with(1)
+        gen.reset_mock()
+
+        # Test exceptions propagation
+        error = Exception()
+        gen.throw.side_effect = error
+        try:
+            wrapper.throw(1)
+        except Exception as ex:
+            self.assertIs(ex, error)
+        else:
+            self.fail('wrapper did not propagate an exception')
+
+        # Test invalid args
+        gen.reset_mock()
+        with self.assertRaises(TypeError):
+            wrapper.throw()
+        self.assertFalse(gen.throw.called)
+        with self.assertRaises(TypeError):
+            wrapper.close(1)
+        self.assertFalse(gen.close.called)
+        with self.assertRaises(TypeError):
+            wrapper.send()
+        self.assertFalse(gen.send.called)
+
+        # Test that we do not double wrap
+        @types.coroutine
+        def bar(): return wrapper
+        self.assertIs(wrapper, bar())
+
+        # Test weakrefs support
+        ref = weakref.ref(wrapper)
+        self.assertIs(ref(), wrapper)
+
+    def test_duck_functional_gen(self):
+        class Generator:
+            """Emulates the following generator (very clumsy):
+
+              def gen(fut):
+                  result = yield fut
+                  return result * 2
+            """
+            def __init__(self, fut):
+                self._i = 0
+                self._fut = fut
+            def __iter__(self):
+                return self
+            def __next__(self):
+                return self.send(None)
+            def send(self, v):
+                try:
+                    if self._i == 0:
+                        assert v is None
+                        return self._fut
+                    if self._i == 1:
+                        raise StopIteration(v * 2)
+                    if self._i > 1:
+                        raise StopIteration
+                finally:
+                    self._i += 1
+            def throw(self, tp, *exc):
+                self._i = 100
+                if tp is not GeneratorExit:
+                    raise tp
+            def close(self):
+                self.throw(GeneratorExit)
+
+        @types.coroutine
+        def foo(): return Generator('spam')
+
+        wrapper = foo()
+        self.assertIsInstance(wrapper, types._GeneratorWrapper)
+
+        async def corofunc():
+            return await foo() + 100
+        coro = corofunc()
+
+        self.assertEqual(coro.send(None), 'spam')
+        try:
+            coro.send(20)
+        except StopIteration as ex:
+            self.assertEqual(ex.args[0], 140)
+        else:
+            self.fail('StopIteration was expected')
 
     def test_gen(self):
         def gen(): yield
         gen = gen()
         @types.coroutine
         def foo(): return gen
-        self.assertIs(foo().__await__(), gen)
+        wrapper = foo()
+        self.assertIsInstance(wrapper, types._GeneratorWrapper)
+        self.assertIs(wrapper.__await__(), gen)
 
         for name in ('__name__', '__qualname__', 'gi_code',
                      'gi_running', 'gi_frame'):
@@ -1289,19 +1427,8 @@
         self.assertIs(foo().cr_code, gen.gi_code)
 
     def test_genfunc(self):
-        def gen():
-            yield
-
-        self.assertFalse(isinstance(gen(), collections.abc.Coroutine))
-        self.assertFalse(isinstance(gen(), collections.abc.Awaitable))
-
-        gen_code = gen.__code__
-        decorated_gen = types.coroutine(gen)
-        self.assertIs(decorated_gen, gen)
-        self.assertIsNot(decorated_gen.__code__, gen_code)
-
-        decorated_gen2 = types.coroutine(decorated_gen)
-        self.assertIs(decorated_gen2.__code__, decorated_gen.__code__)
+        def gen(): yield
+        self.assertIs(types.coroutine(gen), gen)
 
         self.assertTrue(gen.__code__.co_flags & inspect.CO_ITERABLE_COROUTINE)
         self.assertFalse(gen.__code__.co_flags & inspect.CO_COROUTINE)
@@ -1309,10 +1436,27 @@
         g = gen()
         self.assertTrue(g.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE)
         self.assertFalse(g.gi_code.co_flags & inspect.CO_COROUTINE)
-        self.assertTrue(isinstance(g, collections.abc.Coroutine))
-        self.assertTrue(isinstance(g, collections.abc.Awaitable))
+        self.assertIsInstance(g, collections.abc.Coroutine)
+        self.assertIsInstance(g, collections.abc.Awaitable)
         g.close() # silence warning
 
+        self.assertIs(types.coroutine(gen), gen)
+
+    def test_wrapper_object(self):
+        def gen():
+            yield
+        @types.coroutine
+        def coro():
+            return gen()
+
+        wrapper = coro()
+        self.assertIn('GeneratorWrapper', repr(wrapper))
+        self.assertEqual(repr(wrapper), str(wrapper))
+        self.assertTrue(set(dir(wrapper)).issuperset({
+            '__await__', '__iter__', '__next__', 'cr_code', 'cr_running',
+            'cr_frame', 'gi_code', 'gi_frame', 'gi_running', 'send',
+            'close', 'throw'}))
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/Lib/types.py b/Lib/types.py
--- a/Lib/types.py
+++ b/Lib/types.py
@@ -166,6 +166,39 @@
 import functools as _functools
 import collections.abc as _collections_abc
 
+class _GeneratorWrapper:
+    # TODO: Implement this in C.
+    def __init__(self, gen):
+        self.__wrapped__ = gen
+        self.__isgen__ = gen.__class__ is GeneratorType
+        self.__name__ = getattr(gen, '__name__', None)
+        self.__qualname__ = getattr(gen, '__qualname__', None)
+    def send(self, val):
+        return self.__wrapped__.send(val)
+    def throw(self, tp, *rest):
+        return self.__wrapped__.throw(tp, *rest)
+    def close(self):
+        return self.__wrapped__.close()
+    @property
+    def gi_code(self):
+        return self.__wrapped__.gi_code
+    @property
+    def gi_frame(self):
+        return self.__wrapped__.gi_frame
+    @property
+    def gi_running(self):
+        return self.__wrapped__.gi_running
+    cr_code = gi_code
+    cr_frame = gi_frame
+    cr_running = gi_running
+    def __next__(self):
+        return next(self.__wrapped__)
+    def __iter__(self):
+        if self.__isgen__:
+            return self.__wrapped__
+        return self
+    __await__ = __iter__
+
 def coroutine(func):
     """Convert regular generator function to a coroutine."""
 
@@ -201,36 +234,6 @@
     # return generator-like objects (for instance generators
     # compiled with Cython).
 
-    class GeneratorWrapper:
-        def __init__(self, gen):
-            self.__wrapped__ = gen
-            self.__name__ = getattr(gen, '__name__', None)
-            self.__qualname__ = getattr(gen, '__qualname__', None)
-        def send(self, val):
-            return self.__wrapped__.send(val)
-        def throw(self, *args):
-            return self.__wrapped__.throw(*args)
-        def close(self):
-            return self.__wrapped__.close()
-        @property
-        def gi_code(self):
-            return self.__wrapped__.gi_code
-        @property
-        def gi_frame(self):
-            return self.__wrapped__.gi_frame
-        @property
-        def gi_running(self):
-            return self.__wrapped__.gi_running
-        cr_code = gi_code
-        cr_frame = gi_frame
-        cr_running = gi_running
-        def __next__(self):
-            return next(self.__wrapped__)
-        def __iter__(self):
-            return self.__wrapped__
-        def __await__(self):
-            return self.__wrapped__
-
     @_functools.wraps(func)
     def wrapped(*args, **kwargs):
         coro = func(*args, **kwargs)
@@ -243,7 +246,7 @@
             # 'coro' is either a pure Python generator iterator, or it
             # implements collections.abc.Generator (and does not implement
             # collections.abc.Coroutine).
-            return GeneratorWrapper(coro)
+            return _GeneratorWrapper(coro)
         # 'coro' is either an instance of collections.abc.Coroutine or
         # some other object -- pass it through.
         return coro

-- 
Repository URL: https://hg.python.org/cpython


More information about the Python-checkins mailing list