[Python-checkins] bpo-43751: Fix anext() bug where it erroneously returned None (GH-25238)

pablogsal webhook-mailer at python.org
Sun Apr 11 00:51:42 EDT 2021


https://github.com/python/cpython/commit/dfb45323ce8a543ca844c311e32c994ec9554c1b
commit: dfb45323ce8a543ca844c311e32c994ec9554c1b
branch: master
author: Dennis Sweeney <36520290+sweeneyde at users.noreply.github.com>
committer: pablogsal <Pablogsal at gmail.com>
date: 2021-04-11T05:51:35+01:00
summary:

bpo-43751: Fix anext() bug where it erroneously returned None (GH-25238)

files:
A Misc/NEWS.d/next/Core and Builtins/2021-04-07-18-00-05.bpo-43751.8fHsqQ.rst
M Lib/test/test_asyncgen.py
M Objects/iterobject.c

diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py
index 99464e3d0929f..77c15c02bc891 100644
--- a/Lib/test/test_asyncgen.py
+++ b/Lib/test/test_asyncgen.py
@@ -372,11 +372,8 @@ def tearDown(self):
         self.loop = None
         asyncio.set_event_loop_policy(None)
 
-    def test_async_gen_anext(self):
-        async def gen():
-            yield 1
-            yield 2
-        g = gen()
+    def check_async_iterator_anext(self, ait_class):
+        g = ait_class()
         async def consume():
             results = []
             results.append(await anext(g))
@@ -388,6 +385,66 @@ async def consume():
         with self.assertRaises(StopAsyncIteration):
             self.loop.run_until_complete(consume())
 
+        async def test_2():
+            g1 = ait_class()
+            self.assertEqual(await anext(g1), 1)
+            self.assertEqual(await anext(g1), 2)
+            with self.assertRaises(StopAsyncIteration):
+                await anext(g1)
+            with self.assertRaises(StopAsyncIteration):
+                await anext(g1)
+
+            g2 = ait_class()
+            self.assertEqual(await anext(g2, "default"), 1)
+            self.assertEqual(await anext(g2, "default"), 2)
+            self.assertEqual(await anext(g2, "default"), "default")
+            self.assertEqual(await anext(g2, "default"), "default")
+
+            return "completed"
+
+        result = self.loop.run_until_complete(test_2())
+        self.assertEqual(result, "completed")
+
+    def test_async_generator_anext(self):
+        async def agen():
+            yield 1
+            yield 2
+        self.check_async_iterator_anext(agen)
+
+    def test_python_async_iterator_anext(self):
+        class MyAsyncIter:
+            """Asynchronously yield 1, then 2."""
+            def __init__(self):
+                self.yielded = 0
+            def __aiter__(self):
+                return self
+            async def __anext__(self):
+                if self.yielded >= 2:
+                    raise StopAsyncIteration()
+                else:
+                    self.yielded += 1
+                    return self.yielded
+        self.check_async_iterator_anext(MyAsyncIter)
+
+    def test_python_async_iterator_types_coroutine_anext(self):
+        import types
+        class MyAsyncIterWithTypesCoro:
+            """Asynchronously yield 1, then 2."""
+            def __init__(self):
+                self.yielded = 0
+            def __aiter__(self):
+                return self
+            @types.coroutine
+            def __anext__(self):
+                if False:
+                    yield "this is a generator-based coroutine"
+                if self.yielded >= 2:
+                    raise StopAsyncIteration()
+                else:
+                    self.yielded += 1
+                    return self.yielded
+        self.check_async_iterator_anext(MyAsyncIterWithTypesCoro)
+
     def test_async_gen_aiter(self):
         async def gen():
             yield 1
@@ -431,12 +488,85 @@ async def call_with_too_many_args():
             await anext(gen(), 1, 3)
         async def call_with_wrong_type_args():
             await anext(1, gen())
+        async def call_with_kwarg():
+            await anext(aiterator=gen())
         with self.assertRaises(TypeError):
             self.loop.run_until_complete(call_with_too_few_args())
         with self.assertRaises(TypeError):
             self.loop.run_until_complete(call_with_too_many_args())
         with self.assertRaises(TypeError):
             self.loop.run_until_complete(call_with_wrong_type_args())
+        with self.assertRaises(TypeError):
+            self.loop.run_until_complete(call_with_kwarg())
+
+    def test_anext_bad_await(self):
+        async def bad_awaitable():
+            class BadAwaitable:
+                def __await__(self):
+                    return 42
+            class MyAsyncIter:
+                def __aiter__(self):
+                    return self
+                def __anext__(self):
+                    return BadAwaitable()
+            regex = r"__await__.*iterator"
+            awaitable = anext(MyAsyncIter(), "default")
+            with self.assertRaisesRegex(TypeError, regex):
+                await awaitable
+            awaitable = anext(MyAsyncIter())
+            with self.assertRaisesRegex(TypeError, regex):
+                await awaitable
+            return "completed"
+        result = self.loop.run_until_complete(bad_awaitable())
+        self.assertEqual(result, "completed")
+
+    async def check_anext_returning_iterator(self, aiter_class):
+        awaitable = anext(aiter_class(), "default")
+        with self.assertRaises(TypeError):
+            await awaitable
+        awaitable = anext(aiter_class())
+        with self.assertRaises(TypeError):
+            await awaitable
+        return "completed"
+
+    def test_anext_return_iterator(self):
+        class WithIterAnext:
+            def __aiter__(self):
+                return self
+            def __anext__(self):
+                return iter("abc")
+        result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithIterAnext))
+        self.assertEqual(result, "completed")
+
+    def test_anext_return_generator(self):
+        class WithGenAnext:
+            def __aiter__(self):
+                return self
+            def __anext__(self):
+                yield
+        result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithGenAnext))
+        self.assertEqual(result, "completed")
+
+    def test_anext_await_raises(self):
+        class RaisingAwaitable:
+            def __await__(self):
+                raise ZeroDivisionError()
+                yield
+        class WithRaisingAwaitableAnext:
+            def __aiter__(self):
+                return self
+            def __anext__(self):
+                return RaisingAwaitable()
+        async def do_test():
+            awaitable = anext(WithRaisingAwaitableAnext())
+            with self.assertRaises(ZeroDivisionError):
+                await awaitable
+            awaitable = anext(WithRaisingAwaitableAnext(), "default")
+            with self.assertRaises(ZeroDivisionError):
+                await awaitable
+            return "completed"
+        result = self.loop.run_until_complete(do_test())
+        self.assertEqual(result, "completed")
 
     def test_aiter_bad_args(self):
         async def gen():
diff --git a/Misc/NEWS.d/next/Core and Builtins/2021-04-07-18-00-05.bpo-43751.8fHsqQ.rst b/Misc/NEWS.d/next/Core and Builtins/2021-04-07-18-00-05.bpo-43751.8fHsqQ.rst
new file mode 100644
index 0000000000000..75951ae794d10
--- /dev/null
+++ b/Misc/NEWS.d/next/Core and Builtins/2021-04-07-18-00-05.bpo-43751.8fHsqQ.rst	
@@ -0,0 +1 @@
+Fixed a bug where ``anext(ait, default)`` would erroneously return ``None``.
\ No newline at end of file
diff --git a/Objects/iterobject.c b/Objects/iterobject.c
index 65af18abf79de..6961fc3b4a949 100644
--- a/Objects/iterobject.c
+++ b/Objects/iterobject.c
@@ -316,7 +316,52 @@ anextawaitable_traverse(anextawaitableobject *obj, visitproc visit, void *arg)
 static PyObject *
 anextawaitable_iternext(anextawaitableobject *obj)
 {
-    PyObject *result = PyIter_Next(obj->wrapped);
+    /* Consider the following class:
+     *
+     *     class A:
+     *         async def __anext__(self):
+     *             ...
+     *     a = A()
+     *
+     * Then `await anext(a)` should call
+     * a.__anext__().__await__().__next__()
+     *
+     * On the other hand, given
+     *
+     *     async def agen():
+     *         yield 1
+     *         yield 2
+     *     gen = agen()
+     *
+     * Then `await anext(gen)` can just call
+     * gen.__anext__().__next__()
+     */
+    assert(obj->wrapped != NULL);
+    PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped);
+    if (awaitable == NULL) {
+        return NULL;
+    }
+    if (Py_TYPE(awaitable)->tp_iternext == NULL) {
+        /* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator,
+         * or an iterator. Of these, only coroutines lack tp_iternext.
+         */
+        assert(PyCoro_CheckExact(awaitable));
+        unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await;
+        PyObject *new_awaitable = getter(awaitable);
+        if (new_awaitable == NULL) {
+            Py_DECREF(awaitable);
+            return NULL;
+        }
+        Py_SETREF(awaitable, new_awaitable);
+        if (Py_TYPE(awaitable)->tp_iternext == NULL) {
+            PyErr_SetString(PyExc_TypeError,
+                            "__await__ returned a non-iterable");
+            Py_DECREF(awaitable);
+            return NULL;
+        }
+    }
+    PyObject *result = (*Py_TYPE(awaitable)->tp_iternext)(awaitable);
+    Py_DECREF(awaitable);
     if (result != NULL) {
         return result;
     }



More information about the Python-checkins mailing list