[Python-checkins] bpo-29302: Implement contextlib.AsyncExitStack. (#4790)

Yury Selivanov webhook-mailer at python.org
Thu Jan 25 15:51:27 EST 2018


https://github.com/python/cpython/commit/1aa094f74039cd20fdc7df56c68f6848c18ce4dd
commit: 1aa094f74039cd20fdc7df56c68f6848c18ce4dd
branch: master
author: Ilya Kulakov <kulakov.ilya at gmail.com>
committer: Yury Selivanov <yury at magic.io>
date: 2018-01-25T15:51:18-05:00
summary:

bpo-29302: Implement contextlib.AsyncExitStack. (#4790)

files:
A Misc/NEWS.d/next/Library/2017-12-11-15-14-55.bpo-29302.Nczj9l.rst
M Doc/library/contextlib.rst
M Doc/whatsnew/3.7.rst
M Lib/contextlib.py
M Lib/test/test_contextlib.py
M Lib/test/test_contextlib_async.py

diff --git a/Doc/library/contextlib.rst b/Doc/library/contextlib.rst
index faa6c8ac23b..54d3a8e4296 100644
--- a/Doc/library/contextlib.rst
+++ b/Doc/library/contextlib.rst
@@ -435,6 +435,44 @@ Functions and classes provided:
       callbacks registered, the arguments passed in will indicate that no
       exception occurred.
 
+.. class:: AsyncExitStack()
+
+   An :ref:`asynchronous context manager <async-context-managers>`, similar
+   to :class:`ExitStack`, that supports combining both synchronous and
+   asynchronous context managers, as well as having coroutines for
+   cleanup logic.
+
+   The :meth:`close` method is not implemented; :meth:`aclose` must be used
+   instead.
+
+   .. method:: enter_async_context(cm)
+
+      Similar to :meth:`enter_context` but expects an asynchronous context
+      manager.
+
+   .. method:: push_async_exit(exit)
+
+      Similar to :meth:`push` but expects either an asynchronous context manager
+      or a coroutine function.
+
+   .. method:: push_async_callback(callback, *args, **kwds)
+
+      Similar to :meth:`callback` but expects a coroutine function.
+
+   .. method:: aclose()
+
+      Similar to :meth:`close` but properly handles awaitables.
+
+   Continuing the example for :func:`asynccontextmanager`::
+
+      async with AsyncExitStack() as stack:
+          connections = [await stack.enter_async_context(get_connection())
+              for i in range(5)]
+          # All opened connections will automatically be released at the end of
+          # the async with statement, even if attempts to open a connection
+          # later in the list raise an exception.
+
+   .. versionadded:: 3.7
 
 Examples and Recipes
 --------------------
diff --git a/Doc/whatsnew/3.7.rst b/Doc/whatsnew/3.7.rst
index e52e32c4736..e9f057029eb 100644
--- a/Doc/whatsnew/3.7.rst
+++ b/Doc/whatsnew/3.7.rst
@@ -379,6 +379,9 @@ contextlib
 :class:`~contextlib.AbstractAsyncContextManager` have been added. (Contributed
 by Jelle Zijlstra in :issue:`29679` and :issue:`30241`.)
 
+:class:`contextlib.AsyncExitStack` has been added. (Contributed by
+Alexander Mohr and Ilya Kulakov in :issue:`29302`.)
+
 cProfile
 --------
 
diff --git a/Lib/contextlib.py b/Lib/contextlib.py
index 96c8c22084a..ef8f8c9f55b 100644
--- a/Lib/contextlib.py
+++ b/Lib/contextlib.py
@@ -7,7 +7,7 @@
 
 __all__ = ["asynccontextmanager", "contextmanager", "closing", "nullcontext",
            "AbstractContextManager", "AbstractAsyncContextManager",
-           "ContextDecorator", "ExitStack",
+           "AsyncExitStack", "ContextDecorator", "ExitStack",
            "redirect_stdout", "redirect_stderr", "suppress"]
 
 
@@ -365,85 +365,102 @@ def __exit__(self, exctype, excinst, exctb):
         return exctype is not None and issubclass(exctype, self._exceptions)
 
 
-# Inspired by discussions on http://bugs.python.org/issue13585
-class ExitStack(AbstractContextManager):
-    """Context manager for dynamic management of a stack of exit callbacks
+class _BaseExitStack:
+    """A base class for ExitStack and AsyncExitStack."""
 
-    For example:
+    @staticmethod
+    def _create_exit_wrapper(cm, cm_exit):
+        def _exit_wrapper(exc_type, exc, tb):
+            return cm_exit(cm, exc_type, exc, tb)
+        return _exit_wrapper
 
-        with ExitStack() as stack:
-            files = [stack.enter_context(open(fname)) for fname in filenames]
-            # All opened files will automatically be closed at the end of
-            # the with statement, even if attempts to open files later
-            # in the list raise an exception
+    @staticmethod
+    def _create_cb_wrapper(callback, *args, **kwds):
+        def _exit_wrapper(exc_type, exc, tb):
+            callback(*args, **kwds)
+        return _exit_wrapper
 
-    """
     def __init__(self):
         self._exit_callbacks = deque()
 
     def pop_all(self):
-        """Preserve the context stack by transferring it to a new instance"""
+        """Preserve the context stack by transferring it to a new instance."""
         new_stack = type(self)()
         new_stack._exit_callbacks = self._exit_callbacks
         self._exit_callbacks = deque()
         return new_stack
 
-    def _push_cm_exit(self, cm, cm_exit):
-        """Helper to correctly register callbacks to __exit__ methods"""
-        def _exit_wrapper(*exc_details):
-            return cm_exit(cm, *exc_details)
-        _exit_wrapper.__self__ = cm
-        self.push(_exit_wrapper)
-
     def push(self, exit):
-        """Registers a callback with the standard __exit__ method signature
-
-        Can suppress exceptions the same way __exit__ methods can.
+        """Registers a callback with the standard __exit__ method signature.
 
+        Can suppress exceptions the same way __exit__ method can.
         Also accepts any object with an __exit__ method (registering a call
-        to the method instead of the object itself)
+        to the method instead of the object itself).
         """
         # We use an unbound method rather than a bound method to follow
-        # the standard lookup behaviour for special methods
+        # the standard lookup behaviour for special methods.
         _cb_type = type(exit)
+
         try:
             exit_method = _cb_type.__exit__
         except AttributeError:
-            # Not a context manager, so assume its a callable
-            self._exit_callbacks.append(exit)
+            # Not a context manager, so assume it's a callable.
+            self._push_exit_callback(exit)
         else:
             self._push_cm_exit(exit, exit_method)
-        return exit # Allow use as a decorator
-
-    def callback(self, callback, *args, **kwds):
-        """Registers an arbitrary callback and arguments.
-
-        Cannot suppress exceptions.
-        """
-        def _exit_wrapper(exc_type, exc, tb):
-            callback(*args, **kwds)
-        # We changed the signature, so using @wraps is not appropriate, but
-        # setting __wrapped__ may still help with introspection
-        _exit_wrapper.__wrapped__ = callback
-        self.push(_exit_wrapper)
-        return callback # Allow use as a decorator
+        return exit  # Allow use as a decorator.
 
     def enter_context(self, cm):
-        """Enters the supplied context manager
+        """Enters the supplied context manager.
 
         If successful, also pushes its __exit__ method as a callback and
         returns the result of the __enter__ method.
         """
-        # We look up the special methods on the type to match the with statement
+        # We look up the special methods on the type to match the with
+        # statement.
         _cm_type = type(cm)
         _exit = _cm_type.__exit__
         result = _cm_type.__enter__(cm)
         self._push_cm_exit(cm, _exit)
         return result
 
-    def close(self):
-        """Immediately unwind the context stack"""
-        self.__exit__(None, None, None)
+    def callback(self, callback, *args, **kwds):
+        """Registers an arbitrary callback and arguments.
+
+        Cannot suppress exceptions.
+        """
+        _exit_wrapper = self._create_cb_wrapper(callback, *args, **kwds)
+
+        # We changed the signature, so using @wraps is not appropriate, but
+        # setting __wrapped__ may still help with introspection.
+        _exit_wrapper.__wrapped__ = callback
+        self._push_exit_callback(_exit_wrapper)
+        return callback  # Allow use as a decorator
+
+    def _push_cm_exit(self, cm, cm_exit):
+        """Helper to correctly register callbacks to __exit__ methods."""
+        _exit_wrapper = self._create_exit_wrapper(cm, cm_exit)
+        _exit_wrapper.__self__ = cm
+        self._push_exit_callback(_exit_wrapper, True)
+
+    def _push_exit_callback(self, callback, is_sync=True):
+        self._exit_callbacks.append((is_sync, callback))
+
+
+# Inspired by discussions on http://bugs.python.org/issue13585
+class ExitStack(_BaseExitStack, AbstractContextManager):
+    """Context manager for dynamic management of a stack of exit callbacks.
+
+    For example:
+        with ExitStack() as stack:
+            files = [stack.enter_context(open(fname)) for fname in filenames]
+            # All opened files will automatically be closed at the end of
+            # the with statement, even if attempts to open files later
+            # in the list raise an exception.
+    """
+
+    def __enter__(self):
+        return self
 
     def __exit__(self, *exc_details):
         received_exc = exc_details[0] is not None
@@ -470,7 +487,8 @@ def _fix_exception_context(new_exc, old_exc):
         suppressed_exc = False
         pending_raise = False
         while self._exit_callbacks:
-            cb = self._exit_callbacks.pop()
+            is_sync, cb = self._exit_callbacks.pop()
+            assert is_sync
             try:
                 if cb(*exc_details):
                     suppressed_exc = True
@@ -493,6 +511,147 @@ def _fix_exception_context(new_exc, old_exc):
                 raise
         return received_exc and suppressed_exc
 
+    def close(self):
+        """Immediately unwind the context stack."""
+        self.__exit__(None, None, None)
+
+
+# Inspired by discussions on https://bugs.python.org/issue29302
+class AsyncExitStack(_BaseExitStack, AbstractAsyncContextManager):
+    """Async context manager for dynamic management of a stack of exit
+    callbacks, which may be plain callables or coroutine functions.
+
+    For example:
+        async with AsyncExitStack() as stack:
+            connections = [await stack.enter_async_context(get_connection())
+                for i in range(5)]
+            # All opened connections will automatically be released at the
+            # end of the async with statement, even if attempts to open a
+            # connection later in the list raise an exception.
+    """
+
+    @staticmethod
+    def _create_async_exit_wrapper(cm, cm_exit):  # binds cm to unbound __aexit__
+        async def _exit_wrapper(exc_type, exc, tb):
+            return await cm_exit(cm, exc_type, exc, tb)  # truthy => suppress
+        return _exit_wrapper
+
+    @staticmethod
+    def _create_async_cb_wrapper(callback, *args, **kwds):  # __aexit__ shape
+        async def _exit_wrapper(exc_type, exc, tb):
+            await callback(*args, **kwds)  # returns None: never suppresses
+        return _exit_wrapper
+
+    async def enter_async_context(self, cm):
+        """Enters the supplied async context manager.
+
+        If successful, also pushes its __aexit__ method as a callback and
+        returns the result of the __aenter__ method.
+        """
+        _cm_type = type(cm)  # special methods are looked up on the type
+        _exit = _cm_type.__aexit__
+        result = await _cm_type.__aenter__(cm)
+        self._push_async_cm_exit(cm, _exit)
+        return result
+
+    def push_async_exit(self, exit):
+        """Registers a coroutine function with the standard __aexit__ method
+        signature.
+
+        Can suppress exceptions the same way an __aexit__ method can.
+        Also accepts any object with an __aexit__ method (registering a call
+        to the method instead of the object itself).
+        """
+        _cb_type = type(exit)
+        try:
+            exit_method = _cb_type.__aexit__
+        except AttributeError:
+            # Not an async context manager, so assume it's a coroutine function.
+            self._push_exit_callback(exit, False)
+        else:
+            self._push_async_cm_exit(exit, exit_method)
+        return exit  # Allow use as a decorator
+
+    def push_async_callback(self, callback, *args, **kwds):
+        """Registers an arbitrary coroutine function and arguments.
+
+        Cannot suppress exceptions.
+        """
+        _exit_wrapper = self._create_async_cb_wrapper(callback, *args, **kwds)
+
+        # We changed the signature, so using @wraps is not appropriate, but
+        # setting __wrapped__ may still help with introspection.
+        _exit_wrapper.__wrapped__ = callback
+        self._push_exit_callback(_exit_wrapper, False)
+        return callback  # Allow use as a decorator
+
+    async def aclose(self):
+        """Immediately unwind the context stack."""
+        await self.__aexit__(None, None, None)
+
+    def _push_async_cm_exit(self, cm, cm_exit):
+        """Helper to register an __aexit__ method, wrapped in a coroutine
+        function bound to cm, as an async exit callback."""
+        _exit_wrapper = self._create_async_exit_wrapper(cm, cm_exit)
+        _exit_wrapper.__self__ = cm
+        self._push_exit_callback(_exit_wrapper, False)
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, *exc_details):
+        received_exc = exc_details[0] is not None
+
+        # We manipulate the exception state so it behaves as though
+        # we were actually nesting multiple async with statements.
+        frame_exc = sys.exc_info()[1]
+        def _fix_exception_context(new_exc, old_exc):
+            # Context may not be correct, so find the end of the chain
+            while 1:
+                exc_context = new_exc.__context__
+                if exc_context is old_exc:
+                    # Context is already set correctly (see issue 20317)
+                    return
+                if exc_context is None or exc_context is frame_exc:
+                    break
+                new_exc = exc_context
+            # Change the end of the chain to point to the exception
+            # we expect it to reference
+            new_exc.__context__ = old_exc
+
+        # Callbacks are invoked in LIFO order to match the behaviour of
+        # nested context managers.
+        suppressed_exc = False
+        pending_raise = False
+        while self._exit_callbacks:
+            is_sync, cb = self._exit_callbacks.pop()
+            try:
+                if is_sync:  # registered by push/callback/enter_context
+                    cb_suppress = cb(*exc_details)
+                else:  # registered by the *_async_* methods: await it
+                    cb_suppress = await cb(*exc_details)
+
+                if cb_suppress:
+                    suppressed_exc = True
+                    pending_raise = False
+                    exc_details = (None, None, None)
+            except:  # bare on purpose: BaseException is captured too
+                new_exc_details = sys.exc_info()
+                # simulate the stack of exceptions by setting the context
+                _fix_exception_context(new_exc_details[1], exc_details[1])
+                pending_raise = True
+                exc_details = new_exc_details
+        if pending_raise:
+            try:
+                # bare "raise exc_details[1]" replaces our carefully
+                # set-up context
+                fixed_ctx = exc_details[1].__context__
+                raise exc_details[1]
+            except BaseException:
+                exc_details[1].__context__ = fixed_ctx
+                raise
+        return received_exc and suppressed_exc
+
 
 class nullcontext(AbstractContextManager):
     """Context manager that does no additional processing.
diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py
index 1a5e6edad9b..96ceaa7f060 100644
--- a/Lib/test/test_contextlib.py
+++ b/Lib/test/test_contextlib.py
@@ -1,5 +1,6 @@
 """Unit tests for contextlib.py, and other context managers."""
 
+import asyncio
 import io
 import sys
 import tempfile
@@ -505,17 +506,18 @@ def test(x):
         self.assertEqual(state, [1, 'something else', 999])
 
 
-class TestExitStack(unittest.TestCase):
+class TestBaseExitStack:
+    exit_stack = None
 
     @support.requires_docstrings
     def test_instance_docs(self):
         # Issue 19330: ensure context manager instances have good docstrings
-        cm_docstring = ExitStack.__doc__
-        obj = ExitStack()
+        cm_docstring = self.exit_stack.__doc__
+        obj = self.exit_stack()
         self.assertEqual(obj.__doc__, cm_docstring)
 
     def test_no_resources(self):
-        with ExitStack():
+        with self.exit_stack():
             pass
 
     def test_callback(self):
@@ -531,7 +533,7 @@ def test_callback(self):
         def _exit(*args, **kwds):
             """Test metadata propagation"""
             result.append((args, kwds))
-        with ExitStack() as stack:
+        with self.exit_stack() as stack:
             for args, kwds in reversed(expected):
                 if args and kwds:
                     f = stack.callback(_exit, *args, **kwds)
@@ -543,9 +545,9 @@ def _exit(*args, **kwds):
                     f = stack.callback(_exit)
                 self.assertIs(f, _exit)
             for wrapper in stack._exit_callbacks:
-                self.assertIs(wrapper.__wrapped__, _exit)
-                self.assertNotEqual(wrapper.__name__, _exit.__name__)
-                self.assertIsNone(wrapper.__doc__, _exit.__doc__)
+                self.assertIs(wrapper[1].__wrapped__, _exit)
+                self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
+                self.assertIsNone(wrapper[1].__doc__, _exit.__doc__)
         self.assertEqual(result, expected)
 
     def test_push(self):
@@ -565,21 +567,21 @@ def __enter__(self):
                 self.fail("Should not be called!")
             def __exit__(self, *exc_details):
                 self.check_exc(*exc_details)
-        with ExitStack() as stack:
+        with self.exit_stack() as stack:
             stack.push(_expect_ok)
-            self.assertIs(stack._exit_callbacks[-1], _expect_ok)
+            self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
             cm = ExitCM(_expect_ok)
             stack.push(cm)
-            self.assertIs(stack._exit_callbacks[-1].__self__, cm)
+            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
             stack.push(_suppress_exc)
-            self.assertIs(stack._exit_callbacks[-1], _suppress_exc)
+            self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
             cm = ExitCM(_expect_exc)
             stack.push(cm)
-            self.assertIs(stack._exit_callbacks[-1].__self__, cm)
+            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
             stack.push(_expect_exc)
-            self.assertIs(stack._exit_callbacks[-1], _expect_exc)
+            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
             stack.push(_expect_exc)
-            self.assertIs(stack._exit_callbacks[-1], _expect_exc)
+            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
             1/0
 
     def test_enter_context(self):
@@ -591,19 +593,19 @@ def __exit__(self, *exc_details):
 
         result = []
         cm = TestCM()
-        with ExitStack() as stack:
+        with self.exit_stack() as stack:
             @stack.callback  # Registered first => cleaned up last
             def _exit():
                 result.append(4)
             self.assertIsNotNone(_exit)
             stack.enter_context(cm)
-            self.assertIs(stack._exit_callbacks[-1].__self__, cm)
+            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
             result.append(2)
         self.assertEqual(result, [1, 2, 3, 4])
 
     def test_close(self):
         result = []
-        with ExitStack() as stack:
+        with self.exit_stack() as stack:
             @stack.callback
             def _exit():
                 result.append(1)
@@ -614,7 +616,7 @@ def _exit():
 
     def test_pop_all(self):
         result = []
-        with ExitStack() as stack:
+        with self.exit_stack() as stack:
             @stack.callback
             def _exit():
                 result.append(3)
@@ -627,12 +629,12 @@ def _exit():
 
     def test_exit_raise(self):
         with self.assertRaises(ZeroDivisionError):
-            with ExitStack() as stack:
+            with self.exit_stack() as stack:
                 stack.push(lambda *exc: False)
                 1/0
 
     def test_exit_suppress(self):
-        with ExitStack() as stack:
+        with self.exit_stack() as stack:
             stack.push(lambda *exc: True)
             1/0
 
@@ -696,7 +698,7 @@ def suppress_exc(*exc_details):
             return True
 
         try:
-            with ExitStack() as stack:
+            with self.exit_stack() as stack:
                 stack.callback(raise_exc, IndexError)
                 stack.callback(raise_exc, KeyError)
                 stack.callback(raise_exc, AttributeError)
@@ -724,7 +726,7 @@ def suppress_exc(*exc_details):
             return True
 
         try:
-            with ExitStack() as stack:
+            with self.exit_stack() as stack:
                 stack.callback(lambda: None)
                 stack.callback(raise_exc, IndexError)
         except Exception as exc:
@@ -733,7 +735,7 @@ def suppress_exc(*exc_details):
             self.fail("Expected IndexError, but no exception was raised")
 
         try:
-            with ExitStack() as stack:
+            with self.exit_stack() as stack:
                 stack.callback(raise_exc, KeyError)
                 stack.push(suppress_exc)
                 stack.callback(raise_exc, IndexError)
@@ -760,7 +762,7 @@ def gets_the_context_right(exc):
         # fix, ExitStack would try to fix it *again* and get into an
         # infinite self-referential loop
         try:
-            with ExitStack() as stack:
+            with self.exit_stack() as stack:
                 stack.enter_context(gets_the_context_right(exc4))
                 stack.enter_context(gets_the_context_right(exc3))
                 stack.enter_context(gets_the_context_right(exc2))
@@ -787,7 +789,7 @@ def raise_nested(inner_exc, outer_exc):
         exc4 = Exception(4)
         exc5 = Exception(5)
         try:
-            with ExitStack() as stack:
+            with self.exit_stack() as stack:
                 stack.callback(raise_nested, exc4, exc5)
                 stack.callback(raise_nested, exc2, exc3)
                 raise exc1
@@ -801,27 +803,25 @@ def raise_nested(inner_exc, outer_exc):
             self.assertIsNone(
                 exc.__context__.__context__.__context__.__context__.__context__)
 
-
-
     def test_body_exception_suppress(self):
         def suppress_exc(*exc_details):
             return True
         try:
-            with ExitStack() as stack:
+            with self.exit_stack() as stack:
                 stack.push(suppress_exc)
                 1/0
         except IndexError as exc:
             self.fail("Expected no exception, got IndexError")
 
     def test_exit_exception_chaining_suppress(self):
-        with ExitStack() as stack:
+        with self.exit_stack() as stack:
             stack.push(lambda *exc: True)
             stack.push(lambda *exc: 1/0)
             stack.push(lambda *exc: {}[1])
 
     def test_excessive_nesting(self):
         # The original implementation would die with RecursionError here
-        with ExitStack() as stack:
+        with self.exit_stack() as stack:
             for i in range(10000):
                 stack.callback(int)
 
@@ -829,10 +829,10 @@ def test_instance_bypass(self):
         class Example(object): pass
         cm = Example()
         cm.__exit__ = object()
-        stack = ExitStack()
+        stack = self.exit_stack()
         self.assertRaises(AttributeError, stack.enter_context, cm)
         stack.push(cm)
-        self.assertIs(stack._exit_callbacks[-1], cm)
+        self.assertIs(stack._exit_callbacks[-1][1], cm)
 
     def test_dont_reraise_RuntimeError(self):
         # https://bugs.python.org/issue27122
@@ -856,7 +856,7 @@ def first():
         # The UniqueRuntimeError should be caught by second()'s exception
         # handler which chain raised a new UniqueException.
         with self.assertRaises(UniqueException) as err_ctx:
-            with ExitStack() as es_ctx:
+            with self.exit_stack() as es_ctx:
                 es_ctx.enter_context(second())
                 es_ctx.enter_context(first())
                 raise UniqueRuntimeError("please no infinite loop.")
@@ -869,6 +869,10 @@ def first():
         self.assertIs(exc.__cause__, exc.__context__)
 
 
+class TestExitStack(TestBaseExitStack, unittest.TestCase):
+    exit_stack = ExitStack  # class exercised by the shared TestBaseExitStack tests
+
+
 class TestRedirectStream:
 
     redirect_stream = None
diff --git a/Lib/test/test_contextlib_async.py b/Lib/test/test_contextlib_async.py
index 447ca965122..879ddbe0e11 100644
--- a/Lib/test/test_contextlib_async.py
+++ b/Lib/test/test_contextlib_async.py
@@ -1,9 +1,11 @@
 import asyncio
-from contextlib import asynccontextmanager, AbstractAsyncContextManager
+from contextlib import asynccontextmanager, AbstractAsyncContextManager, AsyncExitStack
 import functools
 from test import support
 import unittest
 
+from .test_contextlib import TestBaseExitStack
+
 
 def _async_test(func):
     """Decorator to turn an async function into a test case."""
@@ -255,5 +257,168 @@ def test_contextmanager_doc_attrib(self):
             self.assertEqual(target, (11, 22, 33, 44))
 
 
+class TestAsyncExitStack(TestBaseExitStack, unittest.TestCase):
+    class SyncAsyncExitStack(AsyncExitStack):  # sync facade for the shared tests
+        @staticmethod
+        def run_coroutine(coro):  # drive coro to completion on the current loop
+            loop = asyncio.get_event_loop()
+
+            f = asyncio.ensure_future(coro)
+            f.add_done_callback(lambda f: loop.stop())
+            loop.run_forever()
+
+            exc = f.exception()
+
+            if not exc:
+                return f.result()
+            else:
+                context = exc.__context__  # saved: the re-raise below may clobber it
+
+                try:
+                    raise exc
+                except:
+                    exc.__context__ = context
+                    raise exc
+
+        def close(self):
+            return self.run_coroutine(self.aclose())
+
+        def __enter__(self):
+            return self.run_coroutine(self.__aenter__())
+
+        def __exit__(self, *exc_details):
+            return self.run_coroutine(self.__aexit__(*exc_details))
+
+    exit_stack = SyncAsyncExitStack  # lets TestBaseExitStack's tests run here too
+
+    def setUp(self):  # run_coroutine() relies on get_event_loop()
+        self.loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(self.loop)
+        self.addCleanup(self.loop.close)
+
+    @_async_test
+    async def test_async_callback(self):
+        expected = [
+            ((), {}),
+            ((1,), {}),
+            ((1,2), {}),
+            ((), dict(example=1)),
+            ((1,), dict(example=1)),
+            ((1,2), dict(example=1)),
+        ]
+        result = []
+        async def _exit(*args, **kwds):
+            """Test metadata propagation"""
+            result.append((args, kwds))
+
+        async with AsyncExitStack() as stack:
+            for args, kwds in reversed(expected):
+                if args and kwds:
+                    f = stack.push_async_callback(_exit, *args, **kwds)
+                elif args:
+                    f = stack.push_async_callback(_exit, *args)
+                elif kwds:
+                    f = stack.push_async_callback(_exit, **kwds)
+                else:
+                    f = stack.push_async_callback(_exit)
+                self.assertIs(f, _exit)
+            for wrapper in stack._exit_callbacks:  # (is_sync, callback) pairs
+                self.assertIs(wrapper[1].__wrapped__, _exit)
+                self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
+                self.assertIsNone(wrapper[1].__doc__, _exit.__doc__)
+
+        self.assertEqual(result, expected)
+
+    @_async_test
+    async def test_async_push(self):
+        exc_raised = ZeroDivisionError
+        async def _expect_exc(exc_type, exc, exc_tb):
+            self.assertIs(exc_type, exc_raised)
+        async def _suppress_exc(*exc_details):
+            return True
+        async def _expect_ok(exc_type, exc, exc_tb):
+            self.assertIsNone(exc_type)
+            self.assertIsNone(exc)
+            self.assertIsNone(exc_tb)
+        class ExitCM(object):
+            def __init__(self, check_exc):
+                self.check_exc = check_exc
+            async def __aenter__(self):
+                self.fail("Should not be called!")  # unreached; self is ExitCM
+            async def __aexit__(self, *exc_details):
+                await self.check_exc(*exc_details)
+
+        async with self.exit_stack() as stack:
+            stack.push_async_exit(_expect_ok)
+            self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
+            cm = ExitCM(_expect_ok)
+            stack.push_async_exit(cm)
+            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
+            stack.push_async_exit(_suppress_exc)
+            self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
+            cm = ExitCM(_expect_exc)
+            stack.push_async_exit(cm)
+            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
+            stack.push_async_exit(_expect_exc)
+            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
+            stack.push_async_exit(_expect_exc)
+            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
+            1/0
+
+    @_async_test
+    async def test_async_enter_context(self):
+        class TestCM(object):
+            async def __aenter__(self):
+                result.append(1)
+            async def __aexit__(self, *exc_details):
+                result.append(3)
+
+        result = []
+        cm = TestCM()
+
+        async with AsyncExitStack() as stack:
+            @stack.push_async_callback  # Registered first => cleaned up last
+            async def _exit():
+                result.append(4)
+            self.assertIsNotNone(_exit)
+            await stack.enter_async_context(cm)
+            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
+            result.append(2)
+
+        self.assertEqual(result, [1, 2, 3, 4])
+
+    @_async_test
+    async def test_async_exit_exception_chaining(self):
+        # Ensure exception chaining matches the reference behaviour
+        async def raise_exc(exc):
+            raise exc
+
+        saved_details = None
+        async def suppress_exc(*exc_details):
+            nonlocal saved_details
+            saved_details = exc_details
+            return True
+
+        try:
+            async with self.exit_stack() as stack:
+                stack.push_async_callback(raise_exc, IndexError)
+                stack.push_async_callback(raise_exc, KeyError)
+                stack.push_async_callback(raise_exc, AttributeError)
+                stack.push_async_exit(suppress_exc)
+                stack.push_async_callback(raise_exc, ValueError)
+                1 / 0
+        except IndexError as exc:
+            self.assertIsInstance(exc.__context__, KeyError)
+            self.assertIsInstance(exc.__context__.__context__, AttributeError)
+            # Inner exceptions were suppressed
+            self.assertIsNone(exc.__context__.__context__.__context__)
+        else:
+            self.fail("Expected IndexError, but no exception was raised")
+        # Check the inner exceptions
+        inner_exc = saved_details[1]
+        self.assertIsInstance(inner_exc, ValueError)
+        self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/Misc/NEWS.d/next/Library/2017-12-11-15-14-55.bpo-29302.Nczj9l.rst b/Misc/NEWS.d/next/Library/2017-12-11-15-14-55.bpo-29302.Nczj9l.rst
new file mode 100644
index 00000000000..0030e2ce367
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2017-12-11-15-14-55.bpo-29302.Nczj9l.rst
@@ -0,0 +1 @@
+Add contextlib.AsyncExitStack. Patch by Alexander Mohr and Ilya Kulakov.



More information about the Python-checkins mailing list