[Python-checkins] cpython (merge 3.5 -> 3.6): Issue #23996: Added _PyGen_SetStopIterationValue for safe raising

serhiy.storchaka python-checkins at python.org
Sun Nov 6 11:48:03 EST 2016


https://hg.python.org/cpython/rev/a2c9f06ada28
changeset:   104931:a2c9f06ada28
branch:      3.6
parent:      104928:d9eb609be2ca
parent:      104930:bce18f5c0bc4
user:        Serhiy Storchaka <storchaka at gmail.com>
date:        Sun Nov 06 18:47:03 2016 +0200
summary:
  Issue #23996: Added _PyGen_SetStopIterationValue for safe raising
StopIteration with value. More safely handle non-normalized exceptions
in _PyGen_FetchStopIterationValue.

files:
  Include/genobject.h         |   1 +
  Lib/test/test_asyncgen.py   |  79 +++++++++++++++++++++
  Lib/test/test_coroutines.py |  61 ++++++++++++++++
  Lib/test/test_generators.py |  21 +++++
  Lib/test/test_yield_from.py |  93 +++++++++++++++++-------
  Modules/_asynciomodule.c    |  22 +----
  Objects/genobject.c         |  72 +++++++++++++-----
  7 files changed, 281 insertions(+), 68 deletions(-)


diff --git a/Include/genobject.h b/Include/genobject.h
--- a/Include/genobject.h
+++ b/Include/genobject.h
@@ -41,6 +41,7 @@
 PyAPI_FUNC(PyObject *) PyGen_NewWithQualName(struct _frame *,
     PyObject *name, PyObject *qualname);
 PyAPI_FUNC(int) PyGen_NeedsFinalizing(PyGenObject *);
+PyAPI_FUNC(int) _PyGen_SetStopIterationValue(PyObject *);
 PyAPI_FUNC(int) _PyGen_FetchStopIterationValue(PyObject **);
 PyAPI_FUNC(PyObject *) _PyGen_Send(PyGenObject *, PyObject *);
 PyObject *_PyGen_yf(PyGenObject *);
diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py
--- a/Lib/test/test_asyncgen.py
+++ b/Lib/test/test_asyncgen.py
@@ -450,6 +450,48 @@
 
         self.loop.run_until_complete(run())
 
+    def test_async_gen_asyncio_anext_tuple(self):
+        async def foo():
+            try:
+                yield (1,)
+            except ZeroDivisionError:
+                yield (2,)
+
+        async def run():
+            it = foo().__aiter__()
+
+            self.assertEqual(await it.__anext__(), (1,))
+            with self.assertRaises(StopIteration) as cm:
+                it.__anext__().throw(ZeroDivisionError)
+            self.assertEqual(cm.exception.args[0], (2,))
+            with self.assertRaises(StopAsyncIteration):
+                await it.__anext__()
+
+        self.loop.run_until_complete(run())
+
+    def test_async_gen_asyncio_anext_stopiteration(self):
+        async def foo():
+            try:
+                yield StopIteration(1)
+            except ZeroDivisionError:
+                yield StopIteration(3)
+
+        async def run():
+            it = foo().__aiter__()
+
+            v = await it.__anext__()
+            self.assertIsInstance(v, StopIteration)
+            self.assertEqual(v.value, 1)
+            with self.assertRaises(StopIteration) as cm:
+                it.__anext__().throw(ZeroDivisionError)
+            v = cm.exception.args[0]
+            self.assertIsInstance(v, StopIteration)
+            self.assertEqual(v.value, 3)
+            with self.assertRaises(StopAsyncIteration):
+                await it.__anext__()
+
+        self.loop.run_until_complete(run())
+
     def test_async_gen_asyncio_aclose_06(self):
         async def foo():
             try:
@@ -759,6 +801,43 @@
             self.loop.run_until_complete(run())
         self.assertEqual(DONE, 1)
 
+    def test_async_gen_asyncio_athrow_tuple(self):
+        async def gen():
+            try:
+                yield 1
+            except ZeroDivisionError:
+                yield (2,)
+
+        async def run():
+            g = gen()
+            v = await g.asend(None)
+            self.assertEqual(v, 1)
+            v = await g.athrow(ZeroDivisionError)
+            self.assertEqual(v, (2,))
+            with self.assertRaises(StopAsyncIteration):
+                await g.asend(None)
+
+        self.loop.run_until_complete(run())
+
+    def test_async_gen_asyncio_athrow_stopiteration(self):
+        async def gen():
+            try:
+                yield 1
+            except ZeroDivisionError:
+                yield StopIteration(2)
+
+        async def run():
+            g = gen()
+            v = await g.asend(None)
+            self.assertEqual(v, 1)
+            v = await g.athrow(ZeroDivisionError)
+            self.assertIsInstance(v, StopIteration)
+            self.assertEqual(v.value, 2)
+            with self.assertRaises(StopAsyncIteration):
+                await g.asend(None)
+
+        self.loop.run_until_complete(run())
+
     def test_async_gen_asyncio_shutdown_01(self):
         finalized = 0
 
diff --git a/Lib/test/test_coroutines.py b/Lib/test/test_coroutines.py
--- a/Lib/test/test_coroutines.py
+++ b/Lib/test/test_coroutines.py
@@ -838,6 +838,21 @@
             coro.close()
             self.assertEqual(CHK, 1)
 
+    def test_coro_wrapper_send_tuple(self):
+        async def foo():
+            return (10,)
+
+        result = run_async__await__(foo())
+        self.assertEqual(result, ([], (10,)))
+
+    def test_coro_wrapper_send_stop_iterator(self):
+        async def foo():
+            return StopIteration(10)
+
+        result = run_async__await__(foo())
+        self.assertIsInstance(result[1], StopIteration)
+        self.assertEqual(result[1].value, 10)
+
     def test_cr_await(self):
         @types.coroutine
         def a():
@@ -1665,6 +1680,52 @@
                 warnings.simplefilter("error")
                 run_async(foo())
 
+    def test_for_tuple(self):
+        class Done(Exception): pass
+
+        class AIter(tuple):
+            i = 0
+            def __aiter__(self):
+                return self
+            async def __anext__(self):
+                if self.i >= len(self):
+                    raise StopAsyncIteration
+                self.i += 1
+                return self[self.i - 1]
+
+        result = []
+        async def foo():
+            async for i in AIter([42]):
+                result.append(i)
+            raise Done
+
+        with self.assertRaises(Done):
+            foo().send(None)
+        self.assertEqual(result, [42])
+
+    def test_for_stop_iteration(self):
+        class Done(Exception): pass
+
+        class AIter(StopIteration):
+            i = 0
+            def __aiter__(self):
+                return self
+            async def __anext__(self):
+                if self.i:
+                    raise StopAsyncIteration
+                self.i += 1
+                return self.value
+
+        result = []
+        async def foo():
+            async for i in AIter(42):
+                result.append(i)
+            raise Done
+
+        with self.assertRaises(Done):
+            foo().send(None)
+        self.assertEqual(result, [42])
+
     def test_comp_1(self):
         async def f(i):
             return i
diff --git a/Lib/test/test_generators.py b/Lib/test/test_generators.py
--- a/Lib/test/test_generators.py
+++ b/Lib/test/test_generators.py
@@ -277,6 +277,27 @@
             # hence no warning.
             next(g)
 
+    def test_return_tuple(self):
+        def g():
+            return (yield 1)
+
+        gen = g()
+        self.assertEqual(next(gen), 1)
+        with self.assertRaises(StopIteration) as cm:
+            gen.send((2,))
+        self.assertEqual(cm.exception.value, (2,))
+
+    def test_return_stopiteration(self):
+        def g():
+            return (yield 1)
+
+        gen = g()
+        self.assertEqual(next(gen), 1)
+        with self.assertRaises(StopIteration) as cm:
+            gen.send(StopIteration(2))
+        self.assertIsInstance(cm.exception.value, StopIteration)
+        self.assertEqual(cm.exception.value.value, 2)
+
 
 class YieldFromTests(unittest.TestCase):
     def test_generator_gi_yieldfrom(self):
diff --git a/Lib/test/test_yield_from.py b/Lib/test/test_yield_from.py
--- a/Lib/test/test_yield_from.py
+++ b/Lib/test/test_yield_from.py
@@ -384,9 +384,10 @@
             trace.append("Starting g1")
             yield "g1 ham"
             ret = yield from g2()
-            trace.append("g2 returned %s" % (ret,))
-            ret = yield from g2(42)
-            trace.append("g2 returned %s" % (ret,))
+            trace.append("g2 returned %r" % (ret,))
+            for v in 1, (2,), StopIteration(3):
+                ret = yield from g2(v)
+                trace.append("g2 returned %r" % (ret,))
             yield "g1 eggs"
             trace.append("Finishing g1")
         def g2(v = None):
@@ -410,7 +411,17 @@
             "Yielded g2 spam",
             "Yielded g2 more spam",
             "Finishing g2",
-            "g2 returned 42",
+            "g2 returned 1",
+            "Starting g2",
+            "Yielded g2 spam",
+            "Yielded g2 more spam",
+            "Finishing g2",
+            "g2 returned (2,)",
+            "Starting g2",
+            "Yielded g2 spam",
+            "Yielded g2 more spam",
+            "Finishing g2",
+            "g2 returned StopIteration(3,)",
             "Yielded g1 eggs",
             "Finishing g1",
         ])
@@ -670,14 +681,16 @@
                 next(gi)
                 trace.append("f SHOULD NOT BE HERE")
             except StopIteration as e:
-                trace.append("f caught %s" % (repr(e),))
+                trace.append("f caught %r" % (e,))
         def g(r):
             trace.append("g starting")
             yield
-            trace.append("g returning %s" % (r,))
+            trace.append("g returning %r" % (r,))
             return r
         f(None)
-        f(42)
+        f(1)
+        f((2,))
+        f(StopIteration(3))
         self.assertEqual(trace,[
             "g starting",
             "f resuming g",
@@ -685,8 +698,16 @@
             "f caught StopIteration()",
             "g starting",
             "f resuming g",
-            "g returning 42",
-            "f caught StopIteration(42,)",
+            "g returning 1",
+            "f caught StopIteration(1,)",
+            "g starting",
+            "f resuming g",
+            "g returning (2,)",
+            "f caught StopIteration((2,),)",
+            "g starting",
+            "f resuming g",
+            "g returning StopIteration(3,)",
+            "f caught StopIteration(StopIteration(3,),)",
         ])
 
     def test_send_and_return_with_value(self):
@@ -706,22 +727,34 @@
         def g(r):
             trace.append("g starting")
             x = yield
-            trace.append("g received %s" % (x,))
-            trace.append("g returning %s" % (r,))
+            trace.append("g received %r" % (x,))
+            trace.append("g returning %r" % (r,))
             return r
         f(None)
-        f(42)
-        self.assertEqual(trace,[
+        f(1)
+        f((2,))
+        f(StopIteration(3))
+        self.assertEqual(trace, [
             "g starting",
             "f sending spam to g",
-            "g received spam",
+            "g received 'spam'",
             "g returning None",
             "f caught StopIteration()",
             "g starting",
             "f sending spam to g",
-            "g received spam",
-            "g returning 42",
-            "f caught StopIteration(42,)",
+            "g received 'spam'",
+            "g returning 1",
+            'f caught StopIteration(1,)',
+            'g starting',
+            'f sending spam to g',
+            "g received 'spam'",
+            'g returning (2,)',
+            'f caught StopIteration((2,),)',
+            'g starting',
+            'f sending spam to g',
+            "g received 'spam'",
+            'g returning StopIteration(3,)',
+            'f caught StopIteration(StopIteration(3,),)'
         ])
 
     def test_catching_exception_from_subgen_and_returning(self):
@@ -729,27 +762,29 @@
         Test catching an exception thrown into a
         subgenerator and returning a value
         """
-        trace = []
         def inner():
             try:
                 yield 1
             except ValueError:
                 trace.append("inner caught ValueError")
-            return 2
+            return value
 
         def outer():
             v = yield from inner()
-            trace.append("inner returned %r to outer" % v)
+            trace.append("inner returned %r to outer" % (v,))
             yield v
-        g = outer()
-        trace.append(next(g))
-        trace.append(g.throw(ValueError))
-        self.assertEqual(trace,[
-            1,
-            "inner caught ValueError",
-            "inner returned 2 to outer",
-            2,
-        ])
+
+        for value in 2, (2,), StopIteration(2):
+            trace = []
+            g = outer()
+            trace.append(next(g))
+            trace.append(repr(g.throw(ValueError)))
+            self.assertEqual(trace, [
+                1,
+                "inner caught ValueError",
+                "inner returned %r to outer" % (value,),
+                repr(value),
+            ])
 
     def test_throwing_GeneratorExit_into_subgen_that_returns(self):
         """
diff --git a/Modules/_asynciomodule.c b/Modules/_asynciomodule.c
--- a/Modules/_asynciomodule.c
+++ b/Modules/_asynciomodule.c
@@ -997,26 +997,12 @@
 
     res = _asyncio_Future_result_impl(fut);
     if (res != NULL) {
-        /* The result of the Future is not an exception.
-
-           We construct an exception instance manually with
-           PyObject_CallFunctionObjArgs and pass it to PyErr_SetObject
-           (similarly to what genobject.c does).
-
-           We do this to handle a situation when "res" is a tuple, in which
-           case PyErr_SetObject would set the value of StopIteration to
-           the first element of the tuple.
-
-           (See PyErr_SetObject/_PyErr_CreateException code for details.)
-        */
-        PyObject *e = PyObject_CallFunctionObjArgs(
-            PyExc_StopIteration, res, NULL);
-        Py_DECREF(res);
-        if (e == NULL) {
+        /* The result of the Future is not an exception. */
+        if (_PyGen_SetStopIterationValue(res) < 0) {
+            Py_DECREF(res);
             return NULL;
         }
-        PyErr_SetObject(PyExc_StopIteration, e);
-        Py_DECREF(e);
+        Py_DECREF(res);
     }
 
     it->future = NULL;
diff --git a/Objects/genobject.c b/Objects/genobject.c
--- a/Objects/genobject.c
+++ b/Objects/genobject.c
@@ -208,16 +208,9 @@
             }
         }
         else {
-            PyObject *e = PyObject_CallFunctionObjArgs(
-                               PyExc_StopIteration, result, NULL);
-
             /* Async generators cannot return anything but None */
             assert(!PyAsyncGen_CheckExact(gen));
-
-            if (e != NULL) {
-                PyErr_SetObject(PyExc_StopIteration, e);
-                Py_DECREF(e);
-            }
+            _PyGen_SetStopIterationValue(result);
         }
         Py_CLEAR(result);
     }
@@ -562,6 +555,43 @@
 }
 
 /*
+ * Set StopIteration with specified value.  Value can be arbitrary object
+ * or NULL.
+ *
+ * Returns 0 if StopIteration is set and -1 if any other exception is set.
+ */
+int
+_PyGen_SetStopIterationValue(PyObject *value)
+{
+    PyObject *e;
+
+    if (value == NULL ||
+        (!PyTuple_Check(value) &&
+         !PyObject_TypeCheck(value, (PyTypeObject *) PyExc_StopIteration)))
+    {
+        /* Delay exception instantiation if we can */
+        PyErr_SetObject(PyExc_StopIteration, value);
+        return 0;
+    }
+    /* Construct an exception instance manually with
+     * PyObject_CallFunctionObjArgs and pass it to PyErr_SetObject.
+     *
+     * We do this to handle a situation when "value" is a tuple, in which
+     * case PyErr_SetObject would set the value of StopIteration to
+     * the first element of the tuple.
+     *
+     * (See PyErr_SetObject/_PyErr_CreateException code for details.)
+     */
+    e = PyObject_CallFunctionObjArgs(PyExc_StopIteration, value, NULL);
+    if (e == NULL) {
+        return -1;
+    }
+    PyErr_SetObject(PyExc_StopIteration, e);
+    Py_DECREF(e);
+    return 0;
+}
+
+/*
  *   If StopIteration exception is set, fetches its 'value'
  *   attribute if any, otherwise sets pvalue to None.
  *
@@ -571,7 +601,8 @@
  */
 
 int
-_PyGen_FetchStopIterationValue(PyObject **pvalue) {
+_PyGen_FetchStopIterationValue(PyObject **pvalue)
+{
     PyObject *et, *ev, *tb;
     PyObject *value = NULL;
 
@@ -583,8 +614,15 @@
                 value = ((PyStopIterationObject *)ev)->value;
                 Py_INCREF(value);
                 Py_DECREF(ev);
-            } else if (et == PyExc_StopIteration) {
-                /* avoid normalisation and take ev as value */
+            } else if (et == PyExc_StopIteration && !PyTuple_Check(ev)) {
+                /* Avoid normalisation and take ev as value.
+                 *
+                 * Normalization is required if the value is a tuple, in
+                 * that case the value of StopIteration would be set to
+                 * the first element of the tuple.
+                 *
+                 * (See _PyErr_CreateException code for details.)
+                 */
                 value = ev;
             } else {
                 /* normalisation required */
@@ -1106,7 +1144,7 @@
 static PyObject *
 aiter_wrapper_iternext(PyAIterWrapper *aw)
 {
-    PyErr_SetObject(PyExc_StopIteration, aw->ags_aiter);
+    _PyGen_SetStopIterationValue(aw->ags_aiter);
     return NULL;
 }
 
@@ -1504,16 +1542,8 @@
 
     if (_PyAsyncGenWrappedValue_CheckExact(result)) {
         /* async yield */
-        PyObject *e = PyObject_CallFunctionObjArgs(
-            PyExc_StopIteration,
-            ((_PyAsyncGenWrappedValue*)result)->agw_val,
-            NULL);
+        _PyGen_SetStopIterationValue(((_PyAsyncGenWrappedValue*)result)->agw_val);
         Py_DECREF(result);
-        if (e == NULL) {
-            return NULL;
-        }
-        PyErr_SetObject(PyExc_StopIteration, e);
-        Py_DECREF(e);
         return NULL;
     }
 

-- 
Repository URL: https://hg.python.org/cpython


More information about the Python-checkins mailing list