[Python-checkins] bpo-43751: Fix anext() bug where it erroneously returned None (GH-25238)
pablogsal
webhook-mailer at python.org
Sun Apr 11 00:51:42 EDT 2021
https://github.com/python/cpython/commit/dfb45323ce8a543ca844c311e32c994ec9554c1b
commit: dfb45323ce8a543ca844c311e32c994ec9554c1b
branch: master
author: Dennis Sweeney <36520290+sweeneyde at users.noreply.github.com>
committer: pablogsal <Pablogsal at gmail.com>
date: 2021-04-11T05:51:35+01:00
summary:
bpo-43751: Fix anext() bug where it erroneously returned None (GH-25238)
files:
A Misc/NEWS.d/next/Core and Builtins/2021-04-07-18-00-05.bpo-43751.8fHsqQ.rst
M Lib/test/test_asyncgen.py
M Objects/iterobject.c
diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py
index 99464e3d0929f..77c15c02bc891 100644
--- a/Lib/test/test_asyncgen.py
+++ b/Lib/test/test_asyncgen.py
@@ -372,11 +372,8 @@ def tearDown(self):
self.loop = None
asyncio.set_event_loop_policy(None)
- def test_async_gen_anext(self):
- async def gen():
- yield 1
- yield 2
- g = gen()
+ def check_async_iterator_anext(self, ait_class):
+ g = ait_class()
async def consume():
results = []
results.append(await anext(g))
@@ -388,6 +385,66 @@ async def consume():
with self.assertRaises(StopAsyncIteration):
self.loop.run_until_complete(consume())
+ async def test_2():
+ g1 = ait_class()
+ self.assertEqual(await anext(g1), 1)
+ self.assertEqual(await anext(g1), 2)
+ with self.assertRaises(StopAsyncIteration):
+ await anext(g1)
+ with self.assertRaises(StopAsyncIteration):
+ await anext(g1)
+
+ g2 = ait_class()
+ self.assertEqual(await anext(g2, "default"), 1)
+ self.assertEqual(await anext(g2, "default"), 2)
+ self.assertEqual(await anext(g2, "default"), "default")
+ self.assertEqual(await anext(g2, "default"), "default")
+
+ return "completed"
+
+ result = self.loop.run_until_complete(test_2())
+ self.assertEqual(result, "completed")
+
+ def test_async_generator_anext(self):
+ async def agen():
+ yield 1
+ yield 2
+ self.check_async_iterator_anext(agen)
+
+ def test_python_async_iterator_anext(self):
+ class MyAsyncIter:
+ """Asynchronously yield 1, then 2."""
+ def __init__(self):
+ self.yielded = 0
+ def __aiter__(self):
+ return self
+ async def __anext__(self):
+ if self.yielded >= 2:
+ raise StopAsyncIteration()
+ else:
+ self.yielded += 1
+ return self.yielded
+ self.check_async_iterator_anext(MyAsyncIter)
+
+ def test_python_async_iterator_types_coroutine_anext(self):
+ import types
+ class MyAsyncIterWithTypesCoro:
+ """Asynchronously yield 1, then 2."""
+ def __init__(self):
+ self.yielded = 0
+ def __aiter__(self):
+ return self
+ @types.coroutine
+ def __anext__(self):
+ if False:
+ yield "this is a generator-based coroutine"
+ if self.yielded >= 2:
+ raise StopAsyncIteration()
+ else:
+ self.yielded += 1
+ return self.yielded
+ self.check_async_iterator_anext(MyAsyncIterWithTypesCoro)
+
def test_async_gen_aiter(self):
async def gen():
yield 1
@@ -431,12 +488,85 @@ async def call_with_too_many_args():
await anext(gen(), 1, 3)
async def call_with_wrong_type_args():
await anext(1, gen())
+ async def call_with_kwarg():
+ await anext(aiterator=gen())
with self.assertRaises(TypeError):
self.loop.run_until_complete(call_with_too_few_args())
with self.assertRaises(TypeError):
self.loop.run_until_complete(call_with_too_many_args())
with self.assertRaises(TypeError):
self.loop.run_until_complete(call_with_wrong_type_args())
+ with self.assertRaises(TypeError):
+ self.loop.run_until_complete(call_with_kwarg())
+
+ def test_anext_bad_await(self):
+ async def bad_awaitable():
+ class BadAwaitable:
+ def __await__(self):
+ return 42
+ class MyAsyncIter:
+ def __aiter__(self):
+ return self
+ def __anext__(self):
+ return BadAwaitable()
+ regex = r"__await__.*iterator"
+ awaitable = anext(MyAsyncIter(), "default")
+ with self.assertRaisesRegex(TypeError, regex):
+ await awaitable
+ awaitable = anext(MyAsyncIter())
+ with self.assertRaisesRegex(TypeError, regex):
+ await awaitable
+ return "completed"
+ result = self.loop.run_until_complete(bad_awaitable())
+ self.assertEqual(result, "completed")
+
+ async def check_anext_returning_iterator(self, aiter_class):
+ awaitable = anext(aiter_class(), "default")
+ with self.assertRaises(TypeError):
+ await awaitable
+ awaitable = anext(aiter_class())
+ with self.assertRaises(TypeError):
+ await awaitable
+ return "completed"
+
+ def test_anext_return_iterator(self):
+ class WithIterAnext:
+ def __aiter__(self):
+ return self
+ def __anext__(self):
+ return iter("abc")
+ result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithIterAnext))
+ self.assertEqual(result, "completed")
+
+ def test_anext_return_generator(self):
+ class WithGenAnext:
+ def __aiter__(self):
+ return self
+ def __anext__(self):
+ yield
+ result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithGenAnext))
+ self.assertEqual(result, "completed")
+
+ def test_anext_await_raises(self):
+ class RaisingAwaitable:
+ def __await__(self):
+ raise ZeroDivisionError()
+ yield
+ class WithRaisingAwaitableAnext:
+ def __aiter__(self):
+ return self
+ def __anext__(self):
+ return RaisingAwaitable()
+ async def do_test():
+ awaitable = anext(WithRaisingAwaitableAnext())
+ with self.assertRaises(ZeroDivisionError):
+ await awaitable
+ awaitable = anext(WithRaisingAwaitableAnext(), "default")
+ with self.assertRaises(ZeroDivisionError):
+ await awaitable
+ return "completed"
+ result = self.loop.run_until_complete(do_test())
+ self.assertEqual(result, "completed")
def test_aiter_bad_args(self):
async def gen():
diff --git a/Misc/NEWS.d/next/Core and Builtins/2021-04-07-18-00-05.bpo-43751.8fHsqQ.rst b/Misc/NEWS.d/next/Core and Builtins/2021-04-07-18-00-05.bpo-43751.8fHsqQ.rst
new file mode 100644
index 0000000000000..75951ae794d10
--- /dev/null
+++ b/Misc/NEWS.d/next/Core and Builtins/2021-04-07-18-00-05.bpo-43751.8fHsqQ.rst
@@ -0,0 +1 @@
+Fixed a bug where ``anext(ait, default)`` would erroneously return None.
\ No newline at end of file
diff --git a/Objects/iterobject.c b/Objects/iterobject.c
index 65af18abf79de..6961fc3b4a949 100644
--- a/Objects/iterobject.c
+++ b/Objects/iterobject.c
@@ -316,7 +316,52 @@ anextawaitable_traverse(anextawaitableobject *obj, visitproc visit, void *arg)
static PyObject *
anextawaitable_iternext(anextawaitableobject *obj)
{
- PyObject *result = PyIter_Next(obj->wrapped);
+ /* Consider the following class:
+ *
+ * class A:
+ * async def __anext__(self):
+ * ...
+ * a = A()
+ *
+ * Then `await anext(a)` should call
+ * a.__anext__().__await__().__next__()
+ *
+ * On the other hand, given
+ *
+ * async def agen():
+ * yield 1
+ * yield 2
+ * gen = agen()
+ *
+ * Then `await anext(gen)` can just call
+ * gen.__anext__().__next__()
+ */
+ assert(obj->wrapped != NULL);
+ PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped);
+ if (awaitable == NULL) {
+ return NULL;
+ }
+ if (Py_TYPE(awaitable)->tp_iternext == NULL) {
+ /* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator,
+ * or an iterator. Of these, only coroutines lack tp_iternext.
+ */
+ assert(PyCoro_CheckExact(awaitable));
+ unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await;
+ PyObject *new_awaitable = getter(awaitable);
+ if (new_awaitable == NULL) {
+ Py_DECREF(awaitable);
+ return NULL;
+ }
+ Py_SETREF(awaitable, new_awaitable);
+ if (Py_TYPE(awaitable)->tp_iternext == NULL) {
+ PyErr_SetString(PyExc_TypeError,
+ "__await__ returned a non-iterable");
+ Py_DECREF(awaitable);
+ return NULL;
+ }
+ }
+ PyObject *result = (*Py_TYPE(awaitable)->tp_iternext)(awaitable);
+ Py_DECREF(awaitable);
if (result != NULL) {
return result;
}
More information about the Python-checkins mailing list