[Python-checkins] bpo-44963: Implement send() and throw() methods for anext_awaitable objects (GH-27955)

pablogsal webhook-mailer at python.org
Tue Sep 7 06:30:19 EDT 2021


https://github.com/python/cpython/commit/533e725821b15e2df2cd4479a34597c1d8faf616
commit: 533e725821b15e2df2cd4479a34597c1d8faf616
branch: main
author: Pablo Galindo Salgado <Pablogsal at gmail.com>
committer: pablogsal <Pablogsal at gmail.com>
date: 2021-09-07T11:30:14+01:00
summary:

bpo-44963: Implement send() and throw() methods for anext_awaitable objects (GH-27955)

Co-authored-by: Yury Selivanov <yury at edgedb.com>

files:
A Misc/NEWS.d/next/Core and Builtins/2021-08-25-23-07-10.bpo-44963.5EET8y.rst
M Lib/test/test_asyncgen.py
M Objects/iterobject.c

diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py
index bc0ae8f532154..473bce484b47b 100644
--- a/Lib/test/test_asyncgen.py
+++ b/Lib/test/test_asyncgen.py
@@ -1,12 +1,16 @@
 import inspect
 import types
 import unittest
+import contextlib
 
 from test.support.import_helper import import_module
 from test.support import gc_collect
 asyncio = import_module("asyncio")
 
 
+_no_default = object()
+
+
 class AwaitException(Exception):
     pass
 
@@ -45,6 +49,37 @@ async def iterate():
     return run_until_complete(iterate())
 
 
+def py_anext(iterator, default=_no_default):
+    """Pure-Python implementation of anext() for testing purposes.
+
+    Closely matches the builtin anext() C implementation: raises TypeError
+    if type(iterator) has no __anext__; without *default* returns the
+    __anext__() awaitable, otherwise a coroutine that maps
+    StopAsyncIteration to *default*.  Used to cross-test send()/throw().
+    """
+
+    try:
+        __anext__ = type(iterator).__anext__  # looked up on the type
+    except AttributeError:
+        raise TypeError(f'{iterator!r} is not an async iterator')
+
+    if default is _no_default:
+        return __anext__(iterator)  # no default: the raw __anext__() awaitable
+
+    async def anext_impl():
+        try:
+            # The C code is way more low-level than this, as it implements
+            # all methods of the iterator protocol. In this implementation
+            # we're relying on higher-level coroutine concepts, but that's
+            # exactly what we want -- cross-test pure-Python high-level
+            # implementation and low-level C anext() iterators.
+            return await __anext__(iterator)
+        except StopAsyncIteration:
+            return default  # exhausted: the coroutine resolves to default
+
+    return anext_impl()
+
+
 class AsyncGenSyntaxTest(unittest.TestCase):
 
     def test_async_gen_syntax_01(self):
@@ -374,6 +409,12 @@ def tearDown(self):
         asyncio.set_event_loop_policy(None)
 
     def check_async_iterator_anext(self, ait_class):
+        with self.subTest(anext="pure-Python"):  # same checks, py_anext first
+            self._check_async_iterator_anext(ait_class, py_anext)
+        with self.subTest(anext="builtin"):  # ...then the builtin anext()
+            self._check_async_iterator_anext(ait_class, anext)
+
+    def _check_async_iterator_anext(self, ait_class, anext):
         g = ait_class()
         async def consume():
             results = []
@@ -406,6 +447,24 @@ async def test_2():
         result = self.loop.run_until_complete(test_2())
         self.assertEqual(result, "completed")
 
+        def test_send():  # a single send() must finish with StopIteration
+            p = ait_class()
+            obj = anext(p, "completed")
+            with self.assertRaises(StopIteration):
+                with contextlib.closing(obj.__await__()) as g:
+                    g.send(None)
+
+        test_send()
+
+        async def test_throw():  # throw() re-raises the thrown exception
+            p = ait_class()
+            obj = anext(p, "completed")
+            self.assertRaises(SyntaxError, obj.throw, SyntaxError)  # not awaited
+            return "completed"
+
+        result = self.loop.run_until_complete(test_throw())
+        self.assertEqual(result, "completed")
+
     def test_async_generator_anext(self):
         async def agen():
             yield 1
@@ -569,6 +628,119 @@ async def do_test():
         result = self.loop.run_until_complete(do_test())
         self.assertEqual(result, "completed")
 
+    def test_anext_iter(self):  # send()/throw()/close() on anext(agen, default)
+        @types.coroutine
+        def _async_yield(v):  # yields v, returns whatever is sent back
+            return (yield v)
+
+        class MyError(Exception):
+            pass
+
+        async def agenfn():  # awaits 1; on MyError awaits 2; never yields
+            try:
+                await _async_yield(1)
+            except MyError:
+                await _async_yield(2)
+            return
+            yield  # unreachable; makes agenfn an async generator function
+
+        def test1(anext):  # MyError is handled, then exhaustion -> "default"
+            agen = agenfn()
+            with contextlib.closing(anext(agen, "default").__await__()) as g:
+                self.assertEqual(g.send(None), 1)
+                self.assertEqual(g.throw(MyError, MyError(), None), 2)
+                try:
+                    g.send(None)
+                except StopIteration as e:
+                    err = e
+                else:
+                    self.fail('StopIteration was not raised')
+                self.assertEqual(err.value, "default")
+
+        def test2(anext):  # a second MyError, thrown in the handler, escapes
+            agen = agenfn()
+            with contextlib.closing(anext(agen, "default").__await__()) as g:
+                self.assertEqual(g.send(None), 1)
+                self.assertEqual(g.throw(MyError, MyError(), None), 2)
+                with self.assertRaises(MyError):
+                    g.throw(MyError, MyError(), None)
+
+        def test3(anext):  # send() after close() raises RuntimeError
+            agen = agenfn()
+            with contextlib.closing(anext(agen, "default").__await__()) as g:
+                self.assertEqual(g.send(None), 1)
+                g.close()
+                with self.assertRaisesRegex(RuntimeError, 'cannot reuse'):
+                    self.assertEqual(g.send(None), 1)
+
+        def test4(anext):  # like test2, with two-yield awaits and MyError('val')
+            @types.coroutine
+            def _async_yield(v):
+                yield v * 10
+                return (yield (v * 10 + 1))
+
+            async def agenfn():
+                try:
+                    await _async_yield(1)
+                except MyError:
+                    await _async_yield(2)
+                return
+                yield
+
+            agen = agenfn()
+            with contextlib.closing(anext(agen, "default").__await__()) as g:
+                self.assertEqual(g.send(None), 10)
+                self.assertEqual(g.throw(MyError, MyError(), None), 20)
+                with self.assertRaisesRegex(MyError, 'val'):
+                    g.throw(MyError, MyError('val'), None)
+
+        def test5(anext):  # agen returns on MyError -> StopIteration("default")
+            @types.coroutine
+            def _async_yield(v):
+                yield v * 10
+                return (yield (v * 10 + 1))
+
+            async def agenfn():
+                try:
+                    await _async_yield(1)
+                except MyError:
+                    return
+                yield 'aaa'
+
+            agen = agenfn()
+            with contextlib.closing(anext(agen, "default").__await__()) as g:
+                self.assertEqual(g.send(None), 10)
+                with self.assertRaisesRegex(StopIteration, 'default'):
+                    g.throw(MyError, MyError(), None)
+
+        def test6(anext):  # throw() before the first send() propagates MyError
+            @types.coroutine
+            def _async_yield(v):
+                yield v * 10
+                return (yield (v * 10 + 1))
+
+            async def agenfn():
+                await _async_yield(1)
+                yield 'aaa'
+
+            agen = agenfn()
+            with contextlib.closing(anext(agen, "default").__await__()) as g:
+                with self.assertRaises(MyError):
+                    g.throw(MyError, MyError(), None)
+
+        def run_test(test):  # run one scenario with both anext() implementations
+            with self.subTest('pure-Python anext()'):
+                test(py_anext)
+            with self.subTest('builtin anext()'):
+                test(anext)
+
+        run_test(test1)
+        run_test(test2)
+        run_test(test3)
+        run_test(test4)
+        run_test(test5)
+        run_test(test6)
+
     def test_aiter_bad_args(self):
         async def gen():
             yield 1
diff --git a/Misc/NEWS.d/next/Core and Builtins/2021-08-25-23-07-10.bpo-44963.5EET8y.rst b/Misc/NEWS.d/next/Core and Builtins/2021-08-25-23-07-10.bpo-44963.5EET8y.rst
new file mode 100644
index 0000000000000..9a54bda118e00
--- /dev/null
+++ b/Misc/NEWS.d/next/Core and Builtins/2021-08-25-23-07-10.bpo-44963.5EET8y.rst	
@@ -0,0 +1,2 @@
+Implement ``send()``, ``throw()`` and ``close()`` methods for
+``anext_awaitable`` objects. Patch by Pablo Galindo.
diff --git a/Objects/iterobject.c b/Objects/iterobject.c
index 6961fc3b4a949..e493e41131b70 100644
--- a/Objects/iterobject.c
+++ b/Objects/iterobject.c
@@ -313,6 +313,36 @@ anextawaitable_traverse(anextawaitableobject *obj, visitproc visit, void *arg)
     return 0;
 }
 
+static PyObject *
+anextawaitable_getiter(anextawaitableobject *obj)  /* new ref, NULL on error */
+{
+    assert(obj->wrapped != NULL);
+    PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped);
+    if (awaitable == NULL) {
+        return NULL;
+    }
+    if (Py_TYPE(awaitable)->tp_iternext == NULL) {
+        /* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator,
+         * or an iterator. Of these, only coroutines lack tp_iternext,
+         * so the coroutine's am_await slot is asked for an iterator. */
+        assert(PyCoro_CheckExact(awaitable));
+        unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await;
+        PyObject *new_awaitable = getter(awaitable);
+        if (new_awaitable == NULL) {
+            Py_DECREF(awaitable);
+            return NULL;
+        }
+        Py_SETREF(awaitable, new_awaitable);  /* releases the coroutine ref */
+        if (!PyIter_Check(awaitable)) {
+            PyErr_SetString(PyExc_TypeError,
+                            "__await__ returned a non-iterable");
+            Py_DECREF(awaitable);
+            return NULL;
+        }
+    }
+    return awaitable;  /* caller owns this reference */
+}
+
 static PyObject *
 anextawaitable_iternext(anextawaitableobject *obj)
 {
@@ -336,30 +366,10 @@ anextawaitable_iternext(anextawaitableobject *obj)
      * Then `await anext(gen)` can just call
      * gen.__anext__().__next__()
      */
-    assert(obj->wrapped != NULL);
-    PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped);
+    PyObject *awaitable = anextawaitable_getiter(obj);
     if (awaitable == NULL) {
         return NULL;
     }
-    if (Py_TYPE(awaitable)->tp_iternext == NULL) {
-        /* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator,
-         * or an iterator. Of these, only coroutines lack tp_iternext.
-         */
-        assert(PyCoro_CheckExact(awaitable));
-        unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await;
-        PyObject *new_awaitable = getter(awaitable);
-        if (new_awaitable == NULL) {
-            Py_DECREF(awaitable);
-            return NULL;
-        }
-        Py_SETREF(awaitable, new_awaitable);
-        if (Py_TYPE(awaitable)->tp_iternext == NULL) {
-            PyErr_SetString(PyExc_TypeError,
-                            "__await__ returned a non-iterable");
-            Py_DECREF(awaitable);
-            return NULL;
-        }
-    }
     PyObject *result = (*Py_TYPE(awaitable)->tp_iternext)(awaitable);
     Py_DECREF(awaitable);
     if (result != NULL) {
@@ -371,6 +381,90 @@ anextawaitable_iternext(anextawaitableobject *obj)
     return NULL;
 }
 
+
+static PyObject *
+anextawaitable_proxy(anextawaitableobject *obj, const char *meth, PyObject *args) {
+    /* Call `meth` on the iterator behind obj->wrapped (see
+     * anextawaitable_getiter()).  `args` is NULL for a call without
+     * arguments, or an argument tuple: PyObject_CallMethod() unpacks a
+     * tuple passed with the "O" format, so METH_O callers such as
+     * anextawaitable_send() must wrap their single argument first. */
+    PyObject *awaitable = anextawaitable_getiter(obj);
+    if (awaitable == NULL) {
+        return NULL;
+    }
+    PyObject *ret;
+    if (args == NULL) {
+        ret = PyObject_CallMethod(awaitable, meth, NULL);
+    }
+    else {
+        ret = PyObject_CallMethod(awaitable, meth, "O", args);
+    }
+    Py_DECREF(awaitable);
+    if (ret != NULL) {
+        return ret;
+    }
+    if (PyErr_ExceptionMatches(PyExc_StopAsyncIteration)) {
+        /* `anextawaitableobject` is only used by `anext()` when
+         * a default value is provided. So when we have a StopAsyncIteration
+         * exception we replace it with a `StopIteration(default)`, as if
+         * it was the return value of `__anext__()` coroutine.
+         */
+        _PyGen_SetStopIterationValue(obj->default_value);
+    }
+    return NULL;
+}
+
+
+static PyObject *
+anextawaitable_send(anextawaitableobject *obj, PyObject *arg) {
+    /* Pack `arg` into a 1-tuple so that a tuple value is sent as one
+     * object instead of being unpacked into several arguments. */
+    PyObject *args = PyTuple_Pack(1, arg);
+    if (args == NULL) {
+        return NULL;
+    }
+    PyObject *ret = anextawaitable_proxy(obj, "send", args);
+    Py_DECREF(args);
+    return ret;
+}
+
+
+static PyObject *
+anextawaitable_throw(anextawaitableobject *obj, PyObject *args) {
+    /* `args` is the METH_VARARGS tuple (typ[, val[, tb]]). */
+    return anextawaitable_proxy(obj, "throw", args);
+}
+
+
+static PyObject *
+anextawaitable_close(anextawaitableobject *obj, PyObject *Py_UNUSED(ignored)) {
+    return anextawaitable_proxy(obj, "close", NULL);
+}
+
+
+PyDoc_STRVAR(send_doc,
+"send(arg) -> send 'arg' into the wrapped iterator,\n\
+return next yielded value or raise StopIteration.");
+
+
+PyDoc_STRVAR(throw_doc,
+"throw(typ[,val[,tb]]) -> raise exception in the wrapped iterator,\n\
+return next yielded value or raise StopIteration.");
+
+
+PyDoc_STRVAR(close_doc,
+"close() -> raise GeneratorExit inside generator.");
+
+
+static PyMethodDef anextawaitable_methods[] = {
+    {"send",(PyCFunction)anextawaitable_send, METH_O, send_doc},
+    {"throw",(PyCFunction)anextawaitable_throw, METH_VARARGS, throw_doc},
+    {"close",(PyCFunction)anextawaitable_close, METH_NOARGS, close_doc},
+    {NULL, NULL}        /* Sentinel */
+};
+
+
 static PyAsyncMethods anextawaitable_as_async = {
     PyObject_SelfIter,                          /* am_await */
     0,                                          /* am_aiter */
@@ -407,7 +481,7 @@ PyTypeObject _PyAnextAwaitable_Type = {
     0,                                          /* tp_weaklistoffset */
     PyObject_SelfIter,                          /* tp_iter */
     (unaryfunc)anextawaitable_iternext,         /* tp_iternext */
-    0,                                          /* tp_methods */
+    anextawaitable_methods,                     /* tp_methods */
 };
 
 PyObject *



More information about the Python-checkins mailing list