[Python-checkins] GH-74116: Allow multiple drain waiters for asyncio.StreamWriter (GH-94705)

gvanrossum webhook-mailer at python.org
Mon Aug 29 14:31:25 EDT 2022


https://github.com/python/cpython/commit/e5b2453e61ba5376831093236d598ef5f9f1de61
commit: e5b2453e61ba5376831093236d598ef5f9f1de61
branch: main
author: Kumar Aditya <59607654+kumaraditya303 at users.noreply.github.com>
committer: gvanrossum <gvanrossum at gmail.com>
date: 2022-08-29T11:31:11-07:00
summary:

GH-74116: Allow multiple drain waiters for asyncio.StreamWriter (GH-94705)

files:
A Misc/NEWS.d/next/Library/2022-07-09-08-55-04.gh-issue-74116.0XwYC1.rst
M Lib/asyncio/streams.py
M Lib/test/test_asyncio/test_streams.py

diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py
index 614b2cda6068..c4d837a11708 100644
--- a/Lib/asyncio/streams.py
+++ b/Lib/asyncio/streams.py
@@ -2,6 +2,7 @@
     'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
     'open_connection', 'start_server')
 
+import collections
 import socket
 import sys
 import weakref
@@ -128,7 +129,7 @@ def __init__(self, loop=None):
         else:
             self._loop = loop
         self._paused = False
-        self._drain_waiter = None
+        self._drain_waiters = collections.deque()
         self._connection_lost = False
 
     def pause_writing(self):
@@ -143,38 +144,34 @@ def resume_writing(self):
         if self._loop.get_debug():
             logger.debug("%r resumes writing", self)
 
-        waiter = self._drain_waiter
-        if waiter is not None:
-            self._drain_waiter = None
+        for waiter in self._drain_waiters:
             if not waiter.done():
                 waiter.set_result(None)
 
     def connection_lost(self, exc):
         self._connection_lost = True
-        # Wake up the writer if currently paused.
+        # Wake up the writer(s) if currently paused.
         if not self._paused:
             return
-        waiter = self._drain_waiter
-        if waiter is None:
-            return
-        self._drain_waiter = None
-        if waiter.done():
-            return
-        if exc is None:
-            waiter.set_result(None)
-        else:
-            waiter.set_exception(exc)
+
+        for waiter in self._drain_waiters:
+            if not waiter.done():
+                if exc is None:
+                    waiter.set_result(None)
+                else:
+                    waiter.set_exception(exc)
 
     async def _drain_helper(self):
         if self._connection_lost:
             raise ConnectionResetError('Connection lost')
         if not self._paused:
             return
-        waiter = self._drain_waiter
-        assert waiter is None or waiter.cancelled()
         waiter = self._loop.create_future()
-        self._drain_waiter = waiter
-        await waiter
+        self._drain_waiters.append(waiter)
+        try:
+            await waiter
+        finally:
+            self._drain_waiters.remove(waiter)
 
     def _get_close_waiter(self, stream):
         raise NotImplementedError
diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py
index 098a0da344d0..0c49099bc499 100644
--- a/Lib/test/test_asyncio/test_streams.py
+++ b/Lib/test/test_asyncio/test_streams.py
@@ -864,6 +864,25 @@ def test_streamreaderprotocol_constructor_use_global_loop(self):
         self.assertEqual(cm.filename, __file__)
         self.assertIs(protocol._loop, self.loop)
 
+    def test_multiple_drain(self):
+        # See https://github.com/python/cpython/issues/74116
+        drained = 0
+
+        async def drainer(stream):
+            nonlocal drained
+            await stream._drain_helper()
+            drained += 1
+
+        async def main():
+            loop = asyncio.get_running_loop()
+            stream = asyncio.streams.FlowControlMixin(loop)
+            stream.pause_writing()
+            loop.call_later(0.1, stream.resume_writing)
+            await asyncio.gather(*[drainer(stream) for _ in range(10)])
+            self.assertEqual(drained, 10)
+
+        self.loop.run_until_complete(main())
+
     def test_drain_raises(self):
         # See http://bugs.python.org/issue25441
 
diff --git a/Misc/NEWS.d/next/Library/2022-07-09-08-55-04.gh-issue-74116.0XwYC1.rst b/Misc/NEWS.d/next/Library/2022-07-09-08-55-04.gh-issue-74116.0XwYC1.rst
new file mode 100644
index 000000000000..33782598745b
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2022-07-09-08-55-04.gh-issue-74116.0XwYC1.rst
@@ -0,0 +1 @@
+Allow :meth:`asyncio.StreamWriter.drain` to be awaited concurrently by multiple tasks. Patch by Kumar Aditya.



More information about the Python-checkins mailing list