[Python-checkins] gh-90155: Fix bug in asyncio.Semaphore and strengthen FIFO guarantee (GH-93222)

miss-islington webhook-mailer at python.org
Thu Sep 22 12:58:41 EDT 2022


https://github.com/python/cpython/commit/773dbb9e3a7dc5d4a8560bc3ffb28c16758f159f
commit: 773dbb9e3a7dc5d4a8560bc3ffb28c16758f159f
branch: 3.11
author: Miss Islington (bot) <31488909+miss-islington at users.noreply.github.com>
committer: miss-islington <31488909+miss-islington at users.noreply.github.com>
date: 2022-09-22T09:58:35-07:00
summary:

gh-90155: Fix bug in asyncio.Semaphore and strengthen FIFO guarantee (GH-93222)


The main problem was that an unluckily timed task cancellation could leave
the semaphore stuck. There were also doubts about whether waiting tasks were
allowed to pass in strict FIFO order.

The Semaphore implementation was rewritten to be more similar to Lock.
Many tests for edge cases (including cancellation) were added.
(cherry picked from commit 24e03796248ab8c7f62d715c28156abe2f1c0d20)

Co-authored-by: Cyker Way <cykerway at gmail.com>

files:
A Misc/NEWS.d/next/Library/2022-05-25-15-57-39.gh-issue-90155.YMstB5.rst
M Lib/asyncio/locks.py
M Lib/test/test_asyncio/test_locks.py

diff --git a/Lib/asyncio/locks.py b/Lib/asyncio/locks.py
index e71130274dd6..f8f590304e31 100644
--- a/Lib/asyncio/locks.py
+++ b/Lib/asyncio/locks.py
@@ -346,9 +346,8 @@ class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
     def __init__(self, value=1):
         if value < 0:
             raise ValueError("Semaphore initial value must be >= 0")
+        self._waiters = None
         self._value = value
-        self._waiters = collections.deque()
-        self._wakeup_scheduled = False
 
     def __repr__(self):
         res = super().__repr__()
@@ -357,16 +356,8 @@ def __repr__(self):
             extra = f'{extra}, waiters:{len(self._waiters)}'
         return f'<{res[1:-1]} [{extra}]>'
 
-    def _wake_up_next(self):
-        while self._waiters:
-            waiter = self._waiters.popleft()
-            if not waiter.done():
-                waiter.set_result(None)
-                self._wakeup_scheduled = True
-                return
-
     def locked(self):
-        """Returns True if semaphore can not be acquired immediately."""
+        """Returns True if semaphore counter is zero."""
         return self._value == 0
 
     async def acquire(self):
@@ -378,28 +369,57 @@ async def acquire(self):
         called release() to make it larger than 0, and then return
         True.
         """
-        # _wakeup_scheduled is set if *another* task is scheduled to wakeup
-        # but its acquire() is not resumed yet
-        while self._wakeup_scheduled or self._value <= 0:
-            fut = self._get_loop().create_future()
-            self._waiters.append(fut)
+        if (not self.locked() and (self._waiters is None or
+                all(w.cancelled() for w in self._waiters))):
+            self._value -= 1
+            return True
+
+        if self._waiters is None:
+            self._waiters = collections.deque()
+        fut = self._get_loop().create_future()
+        self._waiters.append(fut)
+
+        # Finally block should be called before the CancelledError
+        # handling as we don't want CancelledError to call
+        # _wake_up_first() and attempt to wake up itself.
+        try:
             try:
                 await fut
-                # reset _wakeup_scheduled *after* waiting for a future
-                self._wakeup_scheduled = False
-            except exceptions.CancelledError:
-                self._wake_up_next()
-                raise
+            finally:
+                self._waiters.remove(fut)
+        except exceptions.CancelledError:
+            if not self.locked():
+                self._wake_up_first()
+            raise
+
         self._value -= 1
+        if not self.locked():
+            self._wake_up_first()
         return True
 
     def release(self):
         """Release a semaphore, incrementing the internal counter by one.
+
         When it was zero on entry and another coroutine is waiting for it to
         become larger than zero again, wake up that coroutine.
         """
         self._value += 1
-        self._wake_up_next()
+        self._wake_up_first()
+
+    def _wake_up_first(self):
+        """Wake up the first waiter if it isn't done."""
+        if not self._waiters:
+            return
+        try:
+            fut = next(iter(self._waiters))
+        except StopIteration:
+            return
+
+        # .done() necessarily means that a waiter will wake up later on and
+        # either take the lock, or, if it was cancelled and lock wasn't
+        # taken already, will hit this again and wake up a new waiter.
+        if not fut.done():
+            fut.set_result(True)
 
 
 class BoundedSemaphore(Semaphore):
diff --git a/Lib/test/test_asyncio/test_locks.py b/Lib/test/test_asyncio/test_locks.py
index 541b4907b6de..4b9d166e1a0f 100644
--- a/Lib/test/test_asyncio/test_locks.py
+++ b/Lib/test/test_asyncio/test_locks.py
@@ -5,6 +5,7 @@
 import re
 
 import asyncio
+import collections
 
 STR_RGX_REPR = (
     r'^<(?P<class>.*?) object at (?P<address>.*?)'
@@ -774,6 +775,9 @@ async def test_repr(self):
         self.assertTrue('waiters' not in repr(sem))
         self.assertTrue(RGX_REPR.match(repr(sem)))
 
+        if sem._waiters is None:
+            sem._waiters = collections.deque()
+
         sem._waiters.append(mock.Mock())
         self.assertTrue('waiters:1' in repr(sem))
         self.assertTrue(RGX_REPR.match(repr(sem)))
@@ -842,6 +846,7 @@ async def c4(result):
         sem.release()
         self.assertEqual(2, sem._value)
 
+        await asyncio.sleep(0)
         await asyncio.sleep(0)
         self.assertEqual(0, sem._value)
         self.assertEqual(3, len(result))
@@ -884,6 +889,7 @@ async def test_acquire_cancel_before_awoken(self):
         t2.cancel()
         sem.release()
 
+        await asyncio.sleep(0)
         await asyncio.sleep(0)
         num_done = sum(t.done() for t in [t3, t4])
         self.assertEqual(num_done, 1)
@@ -904,9 +910,32 @@ async def test_acquire_hang(self):
         t1.cancel()
         sem.release()
         await asyncio.sleep(0)
+        await asyncio.sleep(0)
         self.assertTrue(sem.locked())
         self.assertTrue(t2.done())
 
+    async def test_acquire_no_hang(self):
+
+        sem = asyncio.Semaphore(1)
+
+        async def c1():
+            async with sem:
+                await asyncio.sleep(0)
+            t2.cancel()
+
+        async def c2():
+            async with sem:
+                self.assertFalse(True)
+
+        t1 = asyncio.create_task(c1())
+        t2 = asyncio.create_task(c2())
+
+        r1, r2 = await asyncio.gather(t1, t2, return_exceptions=True)
+        self.assertTrue(r1 is None)
+        self.assertTrue(isinstance(r2, asyncio.CancelledError))
+
+        await asyncio.wait_for(sem.acquire(), timeout=1.0)
+
     def test_release_not_acquired(self):
         sem = asyncio.BoundedSemaphore()
 
@@ -945,6 +974,77 @@ async def coro(tag):
             result
         )
 
+    async def test_acquire_fifo_order_2(self):
+        sem = asyncio.Semaphore(1)
+        result = []
+
+        async def c1(result):
+            await sem.acquire()
+            result.append(1)
+            return True
+
+        async def c2(result):
+            await sem.acquire()
+            result.append(2)
+            sem.release()
+            await sem.acquire()
+            result.append(4)
+            return True
+
+        async def c3(result):
+            await sem.acquire()
+            result.append(3)
+            return True
+
+        t1 = asyncio.create_task(c1(result))
+        t2 = asyncio.create_task(c2(result))
+        t3 = asyncio.create_task(c3(result))
+
+        await asyncio.sleep(0)
+
+        sem.release()
+        sem.release()
+
+        tasks = [t1, t2, t3]
+        await asyncio.gather(*tasks)
+        self.assertEqual([1, 2, 3, 4], result)
+
+    async def test_acquire_fifo_order_3(self):
+        sem = asyncio.Semaphore(0)
+        result = []
+
+        async def c1(result):
+            await sem.acquire()
+            result.append(1)
+            return True
+
+        async def c2(result):
+            await sem.acquire()
+            result.append(2)
+            return True
+
+        async def c3(result):
+            await sem.acquire()
+            result.append(3)
+            return True
+
+        t1 = asyncio.create_task(c1(result))
+        t2 = asyncio.create_task(c2(result))
+        t3 = asyncio.create_task(c3(result))
+
+        await asyncio.sleep(0)
+
+        t1.cancel()
+
+        await asyncio.sleep(0)
+
+        sem.release()
+        sem.release()
+
+        tasks = [t1, t2, t3]
+        await asyncio.gather(*tasks, return_exceptions=True)
+        self.assertEqual([2, 3], result)
+
 
 class BarrierTests(unittest.IsolatedAsyncioTestCase):
 
diff --git a/Misc/NEWS.d/next/Library/2022-05-25-15-57-39.gh-issue-90155.YMstB5.rst b/Misc/NEWS.d/next/Library/2022-05-25-15-57-39.gh-issue-90155.YMstB5.rst
new file mode 100644
index 000000000000..8def76914eda
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2022-05-25-15-57-39.gh-issue-90155.YMstB5.rst
@@ -0,0 +1 @@
+Fix broken :class:`asyncio.Semaphore` when acquire is cancelled.



More information about the Python-checkins mailing list