[Python-checkins] bpo-46771: Implement asyncio context managers for handling timeouts (GH-31394)

gvanrossum webhook-mailer at python.org
Thu Mar 10 11:05:38 EST 2022


https://github.com/python/cpython/commit/f537b2a4fb86445ee3bd6ca7f10bc9d3a9f37da5
commit: f537b2a4fb86445ee3bd6ca7f10bc9d3a9f37da5
branch: main
author: Andrew Svetlov <andrew.svetlov at gmail.com>
committer: gvanrossum <gvanrossum at gmail.com>
date: 2022-03-10T08:05:20-08:00
summary:

bpo-46771: Implement asyncio context managers for handling timeouts (GH-31394)

Example:

async with asyncio.timeout(5):
    await some_task()

Will interrupt the await and raise TimeoutError if some_task() takes longer than 5 seconds.

Co-authored-by: Guido van Rossum <guido at python.org>

files:
A Lib/asyncio/timeouts.py
A Lib/test/test_asyncio/test_timeouts.py
A Misc/NEWS.d/next/Library/2022-02-21-11-41-23.bpo-464471.fL06TV.rst
M Lib/asyncio/__init__.py

diff --git a/Lib/asyncio/__init__.py b/Lib/asyncio/__init__.py
index db1124cc9bd1e..fed16ec7c67fa 100644
--- a/Lib/asyncio/__init__.py
+++ b/Lib/asyncio/__init__.py
@@ -18,6 +18,7 @@
 from .subprocess import *
 from .tasks import *
 from .taskgroups import *
+from .timeouts import *
 from .threads import *
 from .transports import *
 
@@ -34,6 +35,7 @@
            subprocess.__all__ +
            tasks.__all__ +
            threads.__all__ +
+           timeouts.__all__ +
            transports.__all__)
 
 if sys.platform == 'win32':  # pragma: no cover
diff --git a/Lib/asyncio/timeouts.py b/Lib/asyncio/timeouts.py
new file mode 100644
index 0000000000000..a89205348ff24
--- /dev/null
+++ b/Lib/asyncio/timeouts.py
@@ -0,0 +1,183 @@
+import enum
+
+from types import TracebackType
+from typing import final, Optional, Type
+
+from . import events
+from . import exceptions
+from . import tasks
+
+
+__all__ = (
+    "Timeout",
+    "timeout",
+    "timeout_at",
+)
+
+
+class _State(enum.Enum):
+    CREATED = "created"
+    ENTERED = "active"
+    EXPIRING = "expiring"
+    EXPIRED = "expired"
+    EXITED = "finished"
+
+
+@final
+class Timeout:
+    """Async context manager that cancels the current task at a deadline.
+
+    Instances are normally created by timeout() or timeout_at().  Each one
+    guards a single ``async with`` block: if the deadline passes while the
+    block is running, the task is cancelled and the resulting CancelledError
+    is converted into TimeoutError when the block exits.
+    """
+
+    def __init__(self, when: Optional[float]) -> None:
+        self._state = _State.CREATED
+
+        self._timeout_handler: Optional[events.TimerHandle] = None
+        self._task: Optional[tasks.Task] = None
+        # Cancellation requests already pending on the task when the block
+        # is entered; __aexit__ compares against this baseline.
+        self._cancelling = 0
+        self._when = when
+
+    def when(self) -> Optional[float]:
+        """Return the deadline (loop.time() clock), or None if disabled."""
+        return self._when
+
+    def reschedule(self, when: Optional[float]) -> None:
+        """Move the deadline to *when*; None disables the timeout.
+
+        Raises RuntimeError unless the block is currently active.
+        """
+        if self._state is _State.CREATED:
+            # An explicit check rather than an assert, so misuse is still
+            # reported when Python runs with -O.
+            raise RuntimeError("Timeout has not been entered")
+        if self._state is not _State.ENTERED:
+            raise RuntimeError(
+                f"Cannot change state of {self._state.value} Timeout",
+            )
+
+        self._when = when
+
+        if self._timeout_handler is not None:
+            self._timeout_handler.cancel()
+
+        if when is None:
+            self._timeout_handler = None
+        else:
+            loop = events.get_running_loop()
+            self._timeout_handler = loop.call_at(
+                when,
+                self._on_timeout,
+            )
+
+    def expired(self) -> bool:
+        """Return True if the timeout expired while the block was running."""
+        return self._state in (_State.EXPIRING, _State.EXPIRED)
+
+    def __repr__(self) -> str:
+        # The leading '' makes join() put a space before "when=...".
+        info = ['']
+        if self._state is _State.ENTERED:
+            when = round(self._when, 3) if self._when is not None else None
+            info.append(f"when={when}")
+        info_str = ' '.join(info)
+        return f"<Timeout [{self._state.value}]{info_str}>"
+
+    async def __aenter__(self) -> "Timeout":
+        # A Timeout tracks exactly one block; re-entering it would corrupt
+        # the state that __aexit__ and _on_timeout rely on.
+        if self._state is not _State.CREATED:
+            raise RuntimeError("Timeout has already been entered")
+        task = tasks.current_task()
+        if task is None:
+            # Checked before any state changes, so a failed entry leaves
+            # the instance untouched.
+            raise RuntimeError("Timeout should be used inside a task")
+        self._state = _State.ENTERED
+        self._task = task
+        self._cancelling = task.cancelling()
+        self.reschedule(self._when)
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> Optional[bool]:
+        assert self._state in (_State.ENTERED, _State.EXPIRING)
+
+        if self._timeout_handler is not None:
+            self._timeout_handler.cancel()
+            self._timeout_handler = None
+
+        if self._state is _State.EXPIRING:
+            self._state = _State.EXPIRED
+
+            # uncancel() withdraws the cancel() issued by _on_timeout.  If no
+            # requests remain beyond those already pending at entry (as when
+            # this block runs in a ``finally`` during cancellation), the
+            # CancelledError is ours and is reported as a timeout.
+            if (self._task.uncancel() <= self._cancelling
+                    and exc_type is not None
+                    and issubclass(exc_type, exceptions.CancelledError)):
+                raise TimeoutError from exc_val
+        elif self._state is _State.ENTERED:
+            self._state = _State.EXITED
+
+        return None
+
+    def _on_timeout(self) -> None:
+        """Timer callback: cancel the task and mark the timeout expiring."""
+        assert self._state is _State.ENTERED
+        self._task.cancel()
+        self._state = _State.EXPIRING
+        # drop the reference early
+        self._timeout_handler = None
+
+
+def timeout(delay: Optional[float]) -> Timeout:
+    """Timeout async context manager.
+
+    Useful in cases when you want to apply timeout logic around a block
+    of code or in cases when asyncio.wait_for is not suitable. For example:
+
+    >>> async with asyncio.timeout(10):  # 10 seconds timeout
+    ...     await long_running_task()
+
+
+    delay - value in seconds or None to disable timeout logic
+
+    long_running_task() is interrupted by raising asyncio.CancelledError;
+    the top-most affected timeout() context manager converts CancelledError
+    into TimeoutError.
+    """
+    loop = events.get_running_loop()
+    return Timeout(loop.time() + delay if delay is not None else None)
+
+
+def timeout_at(when: Optional[float]) -> Timeout:
+    """Schedule the timeout at absolute time.
+
+    Like timeout() but argument gives absolute time in the same clock system
+    as loop.time().
+
+    Please note: it is not POSIX time but a time with
+    undefined starting base, e.g. the time of the system power on.
+
+    >>> async with asyncio.timeout_at(loop.time() + 10):
+    ...     await long_running_task()
+
+
+    when - a deadline when timeout occurs or None to disable timeout logic
+
+    long_running_task() is interrupted by raising asyncio.CancelledError;
+    the top-most affected timeout() context manager converts CancelledError
+    into TimeoutError.
+    """
+    return Timeout(when)
diff --git a/Lib/test/test_asyncio/test_timeouts.py b/Lib/test/test_asyncio/test_timeouts.py
new file mode 100644
index 0000000000000..ef1ab0acb390d
--- /dev/null
+++ b/Lib/test/test_asyncio/test_timeouts.py
@@ -0,0 +1,228 @@
+"""Tests for asyncio/timeouts.py"""
+
+import unittest
+import time
+
+import asyncio
+from asyncio import tasks
+
+
+def tearDownModule():
+    asyncio.set_event_loop_policy(None)
+
+
+class TimeoutTests(unittest.IsolatedAsyncioTestCase):
+
+    async def test_timeout_basic(self):
+        with self.assertRaises(TimeoutError):
+            async with asyncio.timeout(0.01) as cm:
+                await asyncio.sleep(10)
+        self.assertTrue(cm.expired())
+
+    async def test_timeout_at_basic(self):
+        loop = asyncio.get_running_loop()
+
+        with self.assertRaises(TimeoutError):
+            deadline = loop.time() + 0.01
+            async with asyncio.timeout_at(deadline) as cm:
+                await asyncio.sleep(10)
+        self.assertTrue(cm.expired())
+        self.assertEqual(deadline, cm.when())
+
+    async def test_nested_timeouts(self):
+        loop = asyncio.get_running_loop()
+        cancelled = False
+        with self.assertRaises(TimeoutError):
+            deadline = loop.time() + 0.01
+            async with asyncio.timeout_at(deadline) as cm1:
+                # Only the topmost context manager should raise TimeoutError
+                try:
+                    async with asyncio.timeout_at(deadline) as cm2:
+                        await asyncio.sleep(10)
+                except asyncio.CancelledError:
+                    cancelled = True
+                    raise
+        self.assertTrue(cancelled)
+        self.assertTrue(cm1.expired())
+        self.assertTrue(cm2.expired())
+
+    async def test_waiter_cancelled(self):
+        cancelled = False
+        with self.assertRaises(TimeoutError):
+            async with asyncio.timeout(0.01):
+                try:
+                    await asyncio.sleep(10)
+                except asyncio.CancelledError:
+                    cancelled = True
+                    raise
+        self.assertTrue(cancelled)
+
+    async def test_timeout_not_called(self):
+        loop = asyncio.get_running_loop()
+        t0 = loop.time()
+        async with asyncio.timeout(10) as cm:
+            await asyncio.sleep(0.01)
+        t1 = loop.time()
+
+        self.assertFalse(cm.expired())
+        # 2 sec for slow CI boxes
+        self.assertLess(t1-t0, 2)
+        self.assertGreater(cm.when(), t1)
+
+    async def test_timeout_disabled(self):
+        loop = asyncio.get_running_loop()
+        t0 = loop.time()
+        async with asyncio.timeout(None) as cm:
+            await asyncio.sleep(0.01)
+        t1 = loop.time()
+
+        self.assertFalse(cm.expired())
+        self.assertIsNone(cm.when())
+        # 2 sec for slow CI boxes
+        self.assertLess(t1-t0, 2)
+
+    async def test_timeout_at_disabled(self):
+        loop = asyncio.get_running_loop()
+        t0 = loop.time()
+        async with asyncio.timeout_at(None) as cm:
+            await asyncio.sleep(0.01)
+        t1 = loop.time()
+
+        self.assertFalse(cm.expired())
+        self.assertIsNone(cm.when())
+        # 2 sec for slow CI boxes
+        self.assertLess(t1-t0, 2)
+
+    async def test_timeout_zero(self):
+        loop = asyncio.get_running_loop()
+        t0 = loop.time()
+        with self.assertRaises(TimeoutError):
+            async with asyncio.timeout(0) as cm:
+                await asyncio.sleep(10)
+        t1 = loop.time()
+        self.assertTrue(cm.expired())
+        # 2 sec for slow CI boxes
+        self.assertLess(t1-t0, 2)
+        self.assertTrue(t0 <= cm.when() <= t1)
+
+    async def test_foreign_exception_passed(self):
+        with self.assertRaises(KeyError):
+            async with asyncio.timeout(0.01) as cm:
+                raise KeyError
+        self.assertFalse(cm.expired())
+
+    async def test_foreign_exception_on_timeout(self):
+        async def crash():
+            try:
+                await asyncio.sleep(1)
+            finally:
+                1/0
+        with self.assertRaises(ZeroDivisionError):
+            async with asyncio.timeout(0.01):
+                await crash()
+
+    async def test_foreign_cancel_doesnt_timeout_if_not_expired(self):
+        with self.assertRaises(asyncio.CancelledError):
+            async with asyncio.timeout(10) as cm:
+                asyncio.current_task().cancel()
+                await asyncio.sleep(10)
+        self.assertFalse(cm.expired())
+
+    async def test_outer_task_is_not_cancelled(self):
+        async def outer() -> None:
+            with self.assertRaises(TimeoutError):
+                async with asyncio.timeout(0.001):
+                    await asyncio.sleep(10)
+
+        task = asyncio.create_task(outer())
+        await task
+        self.assertFalse(task.cancelled())
+        self.assertTrue(task.done())
+
+    async def test_nested_timeouts_concurrent(self):
+        with self.assertRaises(TimeoutError):
+            async with asyncio.timeout(0.002):
+                with self.assertRaises(TimeoutError):
+                    async with asyncio.timeout(0.1):
+                        # Pretend we crunch some numbers.
+                        time.sleep(0.01)
+                        await asyncio.sleep(1)
+
+    async def test_nested_timeouts_loop_busy(self):
+        # After the inner timeout is an expensive operation which should
+        # be stopped by the outer timeout.
+        loop = asyncio.get_running_loop()
+        # Disable a message about long running task
+        loop.slow_callback_duration = 10
+        t0 = loop.time()
+        with self.assertRaises(TimeoutError):
+            async with asyncio.timeout(0.1):  # (1)
+                with self.assertRaises(TimeoutError):
+                    async with asyncio.timeout(0.01):  # (2)
+                        # Pretend the loop is busy for a while.
+                        time.sleep(0.1)
+                        await asyncio.sleep(1)
+                # TimeoutError was caught by (2)
+                await asyncio.sleep(10)  # This sleep should be interrupted by (1)
+        t1 = loop.time()
+        self.assertTrue(t0 <= t1 <= t0 + 1)
+
+    async def test_reschedule(self):
+        loop = asyncio.get_running_loop()
+        fut = loop.create_future()
+        deadline1 = loop.time() + 10
+        deadline2 = deadline1 + 20
+
+        async def f():
+            async with asyncio.timeout_at(deadline1) as cm:
+                fut.set_result(cm)
+                await asyncio.sleep(50)
+
+        task = asyncio.create_task(f())
+        cm = await fut
+
+        self.assertEqual(cm.when(), deadline1)
+        cm.reschedule(deadline2)
+        self.assertEqual(cm.when(), deadline2)
+        cm.reschedule(None)
+        self.assertIsNone(cm.when())
+
+        task.cancel()
+
+        with self.assertRaises(asyncio.CancelledError):
+            await task
+        self.assertFalse(cm.expired())
+
+    async def test_repr_active(self):
+        async with asyncio.timeout(10) as cm:
+            self.assertRegex(repr(cm), r"<Timeout \[active\] when=\d+\.\d*>")
+
+    async def test_repr_expired(self):
+        with self.assertRaises(TimeoutError):
+            async with asyncio.timeout(0.01) as cm:
+                await asyncio.sleep(10)
+        self.assertEqual(repr(cm), "<Timeout [expired]>")
+
+    async def test_repr_finished(self):
+        async with asyncio.timeout(10) as cm:
+            await asyncio.sleep(0)
+
+        self.assertEqual(repr(cm), "<Timeout [finished]>")
+
+    async def test_repr_disabled(self):
+        async with asyncio.timeout(None) as cm:
+            self.assertEqual(repr(cm), r"<Timeout [active] when=None>")
+
+    async def test_nested_timeout_in_finally(self):
+        with self.assertRaises(TimeoutError):
+            async with asyncio.timeout(0.01):
+                try:
+                    await asyncio.sleep(1)
+                finally:
+                    with self.assertRaises(TimeoutError):
+                        async with asyncio.timeout(0.01):
+                            await asyncio.sleep(10)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/Misc/NEWS.d/next/Library/2022-02-21-11-41-23.bpo-464471.fL06TV.rst b/Misc/NEWS.d/next/Library/2022-02-21-11-41-23.bpo-464471.fL06TV.rst
new file mode 100644
index 0000000000000..b8a48d658250f
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2022-02-21-11-41-23.bpo-464471.fL06TV.rst
@@ -0,0 +1,2 @@
+:func:`asyncio.timeout` and :func:`asyncio.timeout_at` context managers
+added. Patch by Tin Tvrtković and Andrew Svetlov.



More information about the Python-checkins mailing list