[Python-checkins] cpython: Issue #12328: Fix multiprocessing's use of overlapped I/O on Windows.

antoine.pitrou python-checkins at python.org
Mon Mar 5 19:32:42 CET 2012


http://hg.python.org/cpython/rev/75c27daa592e
changeset:   75442:75c27daa592e
user:        Antoine Pitrou <solipsis at pitrou.net>
date:        Mon Mar 05 19:28:37 2012 +0100
summary:
  Issue #12328: Fix multiprocessing's use of overlapped I/O on Windows.
Also, add a multiprocessing.connection.wait(rlist, timeout=None) function
for polling multiple objects at once.  Patch by sbt.

Complete changelist from sbt's patch:

* Adds a wait(rlist, timeout=None) function for polling multiple
  objects at once.  On Unix this is just a wrapper for
  select(rlist, [], [], timeout=None).

* Removes use of the SentinelReady exception and the sentinels argument
  to certain methods.  concurrent.futures.process has been changed to
  use wait() instead of SentinelReady.

* Fixes bugs concerning PipeConnection.poll() and messages of zero
  length.

* Fixes PipeListener.accept() to call ConnectNamedPipe() with
  overlapped=True.

* Fixes Queue.empty() and SimpleQueue.empty() so that they are
  threadsafe on Windows.

* Now PipeConnection.poll() and wait() will not modify the pipe except
  possibly by consuming a zero length message.  (Previously poll()
  could consume a partial message.)

* All of multiprocessing's pipe related blocking functions/methods are
  now interruptible by SIGINT on Windows.

files:
  Doc/library/multiprocessing.rst            |   82 ++-
  Lib/concurrent/futures/process.py          |   12 +-
  Lib/multiprocessing/connection.py          |  332 ++++++---
  Lib/multiprocessing/queues.py              |    9 +-
  Lib/test/test_multiprocessing.py           |  235 +++++++-
  Misc/NEWS                                  |    4 +
  Modules/_multiprocessing/win32_functions.c |   69 +-
  7 files changed, 586 insertions(+), 157 deletions(-)


diff --git a/Doc/library/multiprocessing.rst b/Doc/library/multiprocessing.rst
--- a/Doc/library/multiprocessing.rst
+++ b/Doc/library/multiprocessing.rst
@@ -415,13 +415,14 @@
       A numeric handle of a system object which will become "ready" when
       the process ends.
 
+      You can use this value if you want to wait on several events at
+      once using :func:`multiprocessing.connection.wait`.  Otherwise
+      calling :meth:`join()` is simpler.
+
       On Windows, this is an OS handle usable with the ``WaitForSingleObject``
       and ``WaitForMultipleObjects`` family of API calls.  On Unix, this is
       a file descriptor usable with primitives from the :mod:`select` module.
 
-      You can use this value if you want to wait on several events at once.
-      Otherwise calling :meth:`join()` is simpler.
-
       .. versionadded:: 3.3
 
    .. method:: terminate()
@@ -785,6 +786,9 @@
       *timeout* is a number then this specifies the maximum time in seconds to
       block.  If *timeout* is ``None`` then an infinite timeout is used.
 
+      Note that multiple connection objects may be polled at once by
+      using :func:`multiprocessing.connection.wait`.
+
    .. method:: send_bytes(buffer[, offset[, size]])
 
       Send byte data from an object supporting the buffer interface as a
@@ -1779,8 +1783,9 @@
 
 However, the :mod:`multiprocessing.connection` module allows some extra
 flexibility.  It basically gives a high level message oriented API for dealing
-with sockets or Windows named pipes, and also has support for *digest
-authentication* using the :mod:`hmac` module.
+with sockets or Windows named pipes.  It also has support for *digest
+authentication* using the :mod:`hmac` module, and for polling
+multiple connections at the same time.
 
 
 .. function:: deliver_challenge(connection, authkey)
@@ -1878,6 +1883,38 @@
       The address from which the last accepted connection came.  If this is
       unavailable then it is ``None``.
 
+.. function:: wait(object_list, timeout=None)
+
+   Wait till an object in *object_list* is ready.  Returns the list of
+   those objects in *object_list* which are ready.  If *timeout* is a
+   float then the call blocks for at most that many seconds.  If
+   *timeout* is ``None`` then it will block for an unlimited period.
+
+   For both Unix and Windows, an object can appear in *object_list* if
+   it is
+
+   * a readable :class:`~multiprocessing.Connection` object;
+   * a connected and readable :class:`socket.socket` object; or
+   * the :attr:`~multiprocessing.Process.sentinel` attribute of a
+     :class:`~multiprocessing.Process` object.
+
+   A connection or socket object is ready when there is data available
+   to be read from it, or the other end has been closed.
+
+   **Unix**: ``wait(object_list, timeout)`` is almost equivalent to
+   ``select.select(object_list, [], [], timeout)``.  The difference is
+   that, if :func:`select.select` is interrupted by a signal, it can
+   raise :exc:`OSError` with an error number of ``EINTR``, whereas
+   :func:`wait` will not.
+
+   **Windows**: An item in *object_list* must either be an integer
+   handle which is waitable (according to the definition used by the
+   documentation of the Win32 function ``WaitForMultipleObjects()``)
+   or it can be an object with a :meth:`fileno` method which returns a
+   socket handle or pipe handle.  (Note that pipe handles and socket
+   handles are **not** waitable handles.)
+
+   .. versionadded:: 3.3
 
 The module defines two exceptions:
 
@@ -1929,6 +1966,41 @@
 
    conn.close()
 
+The following code uses :func:`~multiprocessing.connection.wait` to
+wait for messages from multiple processes at once::
+
+   import time, random
+   from multiprocessing import Process, Pipe, current_process
+   from multiprocessing.connection import wait
+
+   def foo(w):
+       for i in range(10):
+           w.send((i, current_process().name))
+       w.close()
+
+   if __name__ == '__main__':
+       readers = []
+
+       for i in range(4):
+           r, w = Pipe(duplex=False)
+           readers.append(r)
+           p = Process(target=foo, args=(w,))
+           p.start()
+           # We close the writable end of the pipe now to be sure that
+           # p is the only process which owns a handle for it.  This
+           # ensures that when p closes its handle for the writable end,
+           # wait() will promptly report the readable end as being ready.
+           w.close()
+
+       while readers:
+           for r in wait(readers):
+               try:
+                   msg = r.recv()
+               except EOFError:
+                   readers.remove(r)
+               else:
+                   print(msg)
+
 
 .. _multiprocessing-address-formats:
 
diff --git a/Lib/concurrent/futures/process.py b/Lib/concurrent/futures/process.py
--- a/Lib/concurrent/futures/process.py
+++ b/Lib/concurrent/futures/process.py
@@ -50,7 +50,8 @@
 from concurrent.futures import _base
 import queue
 import multiprocessing
-from multiprocessing.queues import SimpleQueue, SentinelReady, Full
+from multiprocessing.queues import SimpleQueue, Full
+from multiprocessing.connection import wait
 import threading
 import weakref
 
@@ -212,6 +213,8 @@
         for p in processes.values():
             p.join()
 
+    reader = result_queue._reader
+
     while True:
         _add_call_item_to_queue(pending_work_items,
                                 work_ids_queue,
@@ -219,9 +222,10 @@
 
         sentinels = [p.sentinel for p in processes.values()]
         assert sentinels
-        try:
-            result_item = result_queue.get(sentinels=sentinels)
-        except SentinelReady:
+        ready = wait([reader] + sentinels)
+        if reader in ready:
+            result_item = reader.recv()
+        else:
             # Mark the process pool broken so that submits fail right now.
             executor = executor_reference()
             if executor is not None:
diff --git a/Lib/multiprocessing/connection.py b/Lib/multiprocessing/connection.py
--- a/Lib/multiprocessing/connection.py
+++ b/Lib/multiprocessing/connection.py
@@ -32,7 +32,7 @@
 # SUCH DAMAGE.
 #
 
-__all__ = [ 'Client', 'Listener', 'Pipe' ]
+__all__ = [ 'Client', 'Listener', 'Pipe', 'wait' ]
 
 import io
 import os
@@ -58,8 +58,6 @@
         raise
     win32 = None
 
-_select = _eintr_retry(select.select)
-
 #
 #
 #
@@ -122,15 +120,6 @@
     else:
         raise ValueError('address type of %r unrecognized' % address)
 
-
-class SentinelReady(Exception):
-    """
-    Raised when a sentinel is ready when polling.
-    """
-    def __init__(self, *args):
-        Exception.__init__(self, *args)
-        self.sentinels = args[0]
-
 #
 # Connection classes
 #
@@ -268,11 +257,11 @@
                               (offset + size) // itemsize])
             return size
 
-    def recv(self, sentinels=None):
+    def recv(self):
         """Receive a (picklable) object"""
         self._check_closed()
         self._check_readable()
-        buf = self._recv_bytes(sentinels=sentinels)
+        buf = self._recv_bytes()
         return pickle.loads(buf.getbuffer())
 
     def poll(self, timeout=0.0):
@@ -290,85 +279,80 @@
         Overlapped I/O is used, so the handles must have been created
         with FILE_FLAG_OVERLAPPED.
         """
-        _buffered = b''
+        _got_empty_message = False
 
         def _close(self, _CloseHandle=win32.CloseHandle):
             _CloseHandle(self._handle)
 
         def _send_bytes(self, buf):
-            overlapped = win32.WriteFile(self._handle, buf, overlapped=True)
-            nwritten, complete = overlapped.GetOverlappedResult(True)
-            assert complete
+            ov, err = win32.WriteFile(self._handle, buf, overlapped=True)
+            try:
+                if err == win32.ERROR_IO_PENDING:
+                    waitres = win32.WaitForMultipleObjects(
+                        [ov.event], False, INFINITE)
+                    assert waitres == WAIT_OBJECT_0
+            except:
+                ov.cancel()
+                raise
+            finally:
+                nwritten, err = ov.GetOverlappedResult(True)
+            assert err == 0
             assert nwritten == len(buf)
 
-        def _recv_bytes(self, maxsize=None, sentinels=()):
-            if sentinels:
-                self._poll(-1.0, sentinels)
-            buf = io.BytesIO()
-            firstchunk = self._buffered
-            if firstchunk:
-                lenfirstchunk = len(firstchunk)
-                buf.write(firstchunk)
-                self._buffered = b''
+        def _recv_bytes(self, maxsize=None):
+            if self._got_empty_message:
+                self._got_empty_message = False
+                return io.BytesIO()
             else:
-                # A reasonable size for the first chunk transfer
-                bufsize = 128
-                if maxsize is not None and maxsize < bufsize:
-                    bufsize = maxsize
+                bsize = 128 if maxsize is None else min(maxsize, 128)
                 try:
-                    overlapped = win32.ReadFile(self._handle, bufsize, overlapped=True)
-                    lenfirstchunk, complete = overlapped.GetOverlappedResult(True)
-                    firstchunk = overlapped.getbuffer()
-                    assert lenfirstchunk == len(firstchunk)
+                    ov, err = win32.ReadFile(self._handle, bsize,
+                                             overlapped=True)
+                    try:
+                        if err == win32.ERROR_IO_PENDING:
+                            waitres = win32.WaitForMultipleObjects(
+                                [ov.event], False, INFINITE)
+                            assert waitres == WAIT_OBJECT_0
+                    except:
+                        ov.cancel()
+                        raise
+                    finally:
+                        nread, err = ov.GetOverlappedResult(True)
+                        if err == 0:
+                            f = io.BytesIO()
+                            f.write(ov.getbuffer())
+                            return f
+                        elif err == win32.ERROR_MORE_DATA:
+                            return self._get_more_data(ov, maxsize)
                 except IOError as e:
                     if e.winerror == win32.ERROR_BROKEN_PIPE:
                         raise EOFError
-                    raise
-                buf.write(firstchunk)
-                if complete:
-                    return buf
-            navail, nleft = win32.PeekNamedPipe(self._handle)
-            if maxsize is not None and lenfirstchunk + nleft > maxsize:
-                return None
-            if nleft > 0:
-                overlapped = win32.ReadFile(self._handle, nleft, overlapped=True)
-                res, complete = overlapped.GetOverlappedResult(True)
-                assert res == nleft
-                assert complete
-                buf.write(overlapped.getbuffer())
-            return buf
+                    else:
+                        raise
+            raise RuntimeError("shouldn't get here; expected KeyboardInterrupt")
 
-        def _poll(self, timeout, sentinels=()):
-            # Fast non-blocking path
-            navail, nleft = win32.PeekNamedPipe(self._handle)
-            if navail > 0:
+        def _poll(self, timeout):
+            if (self._got_empty_message or
+                        win32.PeekNamedPipe(self._handle)[0] != 0):
                 return True
-            elif timeout == 0.0:
-                return False
-            # Blocking: use overlapped I/O
-            if timeout < 0.0:
-                timeout = INFINITE
-            else:
-                timeout = int(timeout * 1000 + 0.5)
-            overlapped = win32.ReadFile(self._handle, 1, overlapped=True)
-            try:
-                handles = [overlapped.event]
-                handles += sentinels
-                res = win32.WaitForMultipleObjects(handles, False, timeout)
-            finally:
-                # Always cancel overlapped I/O in the same thread
-                # (because CancelIoEx() appears only in Vista)
-                overlapped.cancel()
-            if res == WAIT_TIMEOUT:
-                return False
-            idx = res - WAIT_OBJECT_0
-            if idx == 0:
-                # I/O was successful, store received data
-                overlapped.GetOverlappedResult(True)
-                self._buffered += overlapped.getbuffer()
-                return True
-            assert 0 < idx < len(handles)
-            raise SentinelReady([handles[idx]])
+            if timeout < 0:
+                timeout = None
+            return bool(wait([self], timeout))
+
+        def _get_more_data(self, ov, maxsize):
+            buf = ov.getbuffer()
+            f = io.BytesIO()
+            f.write(buf)
+            left = win32.PeekNamedPipe(self._handle)[1]
+            assert left > 0
+            if maxsize is not None and len(buf) + left > maxsize:
+                self._bad_message_length()
+            ov, err = win32.ReadFile(self._handle, left, overlapped=True)
+            rbytes, err = ov.GetOverlappedResult(True)
+            assert err == 0
+            assert rbytes == left
+            f.write(ov.getbuffer())
+            return f
 
 
 class Connection(_ConnectionBase):
@@ -397,17 +381,11 @@
                 break
             buf = buf[n:]
 
-    def _recv(self, size, sentinels=(), read=_read):
+    def _recv(self, size, read=_read):
         buf = io.BytesIO()
         handle = self._handle
-        if sentinels:
-            handles = [handle] + sentinels
         remaining = size
         while remaining > 0:
-            if sentinels:
-                r = _select(handles, [], [])[0]
-                if handle not in r:
-                    raise SentinelReady(r)
             chunk = read(handle, remaining)
             n = len(chunk)
             if n == 0:
@@ -428,17 +406,17 @@
         if n > 0:
             self._send(buf)
 
-    def _recv_bytes(self, maxsize=None, sentinels=()):
-        buf = self._recv(4, sentinels)
+    def _recv_bytes(self, maxsize=None):
+        buf = self._recv(4)
         size, = struct.unpack("!i", buf.getvalue())
         if maxsize is not None and size > maxsize:
             return None
-        return self._recv(size, sentinels)
+        return self._recv(size)
 
     def _poll(self, timeout):
         if timeout < 0.0:
             timeout = None
-        r = _select([self._handle], [], [], timeout)[0]
+        r = wait([self._handle], timeout)
         return bool(r)
 
 
@@ -559,7 +537,8 @@
             )
 
         overlapped = win32.ConnectNamedPipe(h1, overlapped=True)
-        overlapped.GetOverlappedResult(True)
+        _, err = overlapped.GetOverlappedResult(True)
+        assert err == 0
 
         c1 = PipeConnection(h1, writable=duplex)
         c2 = PipeConnection(h2, readable=duplex)
@@ -633,39 +612,40 @@
         '''
         def __init__(self, address, backlog=None):
             self._address = address
-            handle = win32.CreateNamedPipe(
-                address, win32.PIPE_ACCESS_DUPLEX |
-                win32.FILE_FLAG_FIRST_PIPE_INSTANCE,
+            self._handle_queue = [self._new_handle(first=True)]
+
+            self._last_accepted = None
+            sub_debug('listener created with address=%r', self._address)
+            self.close = Finalize(
+                self, PipeListener._finalize_pipe_listener,
+                args=(self._handle_queue, self._address), exitpriority=0
+                )
+
+        def _new_handle(self, first=False):
+            flags = win32.PIPE_ACCESS_DUPLEX | win32.FILE_FLAG_OVERLAPPED
+            if first:
+                flags |= win32.FILE_FLAG_FIRST_PIPE_INSTANCE
+            return win32.CreateNamedPipe(
+                self._address, flags,
                 win32.PIPE_TYPE_MESSAGE | win32.PIPE_READMODE_MESSAGE |
                 win32.PIPE_WAIT,
                 win32.PIPE_UNLIMITED_INSTANCES, BUFSIZE, BUFSIZE,
                 win32.NMPWAIT_WAIT_FOREVER, win32.NULL
                 )
-            self._handle_queue = [handle]
-            self._last_accepted = None
-
-            sub_debug('listener created with address=%r', self._address)
-
-            self.close = Finalize(
-                self, PipeListener._finalize_pipe_listener,
-                args=(self._handle_queue, self._address), exitpriority=0
-                )
 
         def accept(self):
-            newhandle = win32.CreateNamedPipe(
-                self._address, win32.PIPE_ACCESS_DUPLEX,
-                win32.PIPE_TYPE_MESSAGE | win32.PIPE_READMODE_MESSAGE |
-                win32.PIPE_WAIT,
-                win32.PIPE_UNLIMITED_INSTANCES, BUFSIZE, BUFSIZE,
-                win32.NMPWAIT_WAIT_FOREVER, win32.NULL
-                )
-            self._handle_queue.append(newhandle)
+            self._handle_queue.append(self._new_handle())
             handle = self._handle_queue.pop(0)
+            ov = win32.ConnectNamedPipe(handle, overlapped=True)
             try:
-                win32.ConnectNamedPipe(handle, win32.NULL)
-            except WindowsError as e:
-                if e.winerror != win32.ERROR_PIPE_CONNECTED:
-                    raise
+                res = win32.WaitForMultipleObjects([ov.event], False, INFINITE)
+            except:
+                ov.cancel()
+                win32.CloseHandle(handle)
+                raise
+            finally:
+                _, err = ov.GetOverlappedResult(True)
+            assert err == 0
             return PipeConnection(handle)
 
         @staticmethod
@@ -684,7 +664,8 @@
                 win32.WaitNamedPipe(address, 1000)
                 h = win32.CreateFile(
                     address, win32.GENERIC_READ | win32.GENERIC_WRITE,
-                    0, win32.NULL, win32.OPEN_EXISTING, 0, win32.NULL
+                    0, win32.NULL, win32.OPEN_EXISTING,
+                    win32.FILE_FLAG_OVERLAPPED, win32.NULL
                     )
             except WindowsError as e:
                 if e.winerror not in (win32.ERROR_SEM_TIMEOUT,
@@ -773,6 +754,125 @@
     import xmlrpc.client as xmlrpclib
     return ConnectionWrapper(Client(*args, **kwds), _xml_dumps, _xml_loads)
 
+#
+# Wait
+#
+
+if sys.platform == 'win32':
+
+    def _exhaustive_wait(handles, timeout):
+        # Return ALL handles which are currently signalled.  (Only
+        # returning the first signalled might create starvation issues.)
+        L = list(handles)
+        ready = []
+        while L:
+            res = win32.WaitForMultipleObjects(L, False, timeout)
+            if res == WAIT_TIMEOUT:
+                break
+            elif WAIT_OBJECT_0 <= res < WAIT_OBJECT_0 + len(L):
+                res -= WAIT_OBJECT_0
+            elif WAIT_ABANDONED_0 <= res < WAIT_ABANDONED_0 + len(L):
+                res -= WAIT_ABANDONED_0
+            else:
+                raise RuntimeError('Should not get here')
+            ready.append(L[res])
+            L = L[res+1:]
+            timeout = 0
+        return ready
+
+    _ready_errors = {win32.ERROR_BROKEN_PIPE, win32.ERROR_NETNAME_DELETED}
+
+    def wait(object_list, timeout=None):
+        '''
+        Wait till an object in object_list is ready/readable.
+
+        Returns list of those objects in object_list which are ready/readable.
+        '''
+        if timeout is None:
+            timeout = INFINITE
+        elif timeout < 0:
+            timeout = 0
+        else:
+            timeout = int(timeout * 1000 + 0.5)
+
+        object_list = list(object_list)
+        waithandle_to_obj = {}
+        ov_list = []
+        ready_objects = set()
+        ready_handles = set()
+
+        try:
+            for o in object_list:
+                try:
+                    fileno = getattr(o, 'fileno')
+                except AttributeError:
+                    waithandle_to_obj[o.__index__()] = o
+                else:
+                    # start an overlapped read of length zero
+                    try:
+                        ov, err = win32.ReadFile(fileno(), 0, True)
+                    except OSError as e:
+                        err = e.winerror
+                        if err not in _ready_errors:
+                            raise
+                    if err == win32.ERROR_IO_PENDING:
+                        ov_list.append(ov)
+                        waithandle_to_obj[ov.event] = o
+                    else:
+                        # If o.fileno() is an overlapped pipe handle and
+                        # err == 0 then there is a zero length message
+                        # in the pipe, but it HAS NOT been consumed.
+                        ready_objects.add(o)
+                        timeout = 0
+
+            ready_handles = _exhaustive_wait(waithandle_to_obj.keys(), timeout)
+        finally:
+            # request that overlapped reads stop
+            for ov in ov_list:
+                ov.cancel()
+
+            # wait for all overlapped reads to stop
+            for ov in ov_list:
+                try:
+                    _, err = ov.GetOverlappedResult(True)
+                except OSError as e:
+                    err = e.winerror
+                    if err not in _ready_errors:
+                        raise
+                if err != win32.ERROR_OPERATION_ABORTED:
+                    o = waithandle_to_obj[ov.event]
+                    ready_objects.add(o)
+                    if err == 0:
+                        # If o.fileno() is an overlapped pipe handle then
+                        # a zero length message HAS been consumed.
+                        if hasattr(o, '_got_empty_message'):
+                            o._got_empty_message = True
+
+        ready_objects.update(waithandle_to_obj[h] for h in ready_handles)
+        return [o for o in object_list if o in ready_objects]
+
+else:
+
+    def wait(object_list, timeout=None):
+        '''
+        Wait till an object in object_list is ready/readable.
+
+        Returns list of those objects in object_list which are ready/readable.
+        '''
+        if timeout is not None:
+            if timeout <= 0:
+                return select.select(object_list, [], [], 0)[0]
+            else:
+                deadline = time.time() + timeout
+        while True:
+            try:
+                return select.select(object_list, [], [], timeout)[0]
+            except OSError as e:
+                if e.errno != errno.EINTR:
+                    raise
+            if timeout is not None:
+                timeout = deadline - time.time()
+
 
 # Late import because of circular import
 from multiprocessing.forking import duplicate, close
diff --git a/Lib/multiprocessing/queues.py b/Lib/multiprocessing/queues.py
--- a/Lib/multiprocessing/queues.py
+++ b/Lib/multiprocessing/queues.py
@@ -44,7 +44,7 @@
 
 from queue import Empty, Full
 import _multiprocessing
-from multiprocessing.connection import Pipe, SentinelReady
+from multiprocessing.connection import Pipe
 from multiprocessing.synchronize import Lock, BoundedSemaphore, Semaphore, Condition
 from multiprocessing.util import debug, info, Finalize, register_after_fork
 from multiprocessing.forking import assert_spawning
@@ -360,6 +360,7 @@
     def __init__(self):
         self._reader, self._writer = Pipe(duplex=False)
         self._rlock = Lock()
+        self._poll = self._reader.poll
         if sys.platform == 'win32':
             self._wlock = None
         else:
@@ -367,7 +368,7 @@
         self._make_methods()
 
     def empty(self):
-        return not self._reader.poll()
+        return not self._poll()
 
     def __getstate__(self):
         assert_spawning(self)
@@ -380,10 +381,10 @@
     def _make_methods(self):
         recv = self._reader.recv
         racquire, rrelease = self._rlock.acquire, self._rlock.release
-        def get(*, sentinels=None):
+        def get():
             racquire()
             try:
-                return recv(sentinels)
+                return recv()
             finally:
                 rrelease()
         self.get = get
diff --git a/Lib/test/test_multiprocessing.py b/Lib/test/test_multiprocessing.py
--- a/Lib/test/test_multiprocessing.py
+++ b/Lib/test/test_multiprocessing.py
@@ -1811,6 +1811,84 @@
             p.join()
             l.close()
 
+class _TestPoll(unittest.TestCase):
+
+    ALLOWED_TYPES = ('processes', 'threads')
+
+    def test_empty_string(self):
+        a, b = self.Pipe()
+        self.assertEqual(a.poll(), False)
+        b.send_bytes(b'')
+        self.assertEqual(a.poll(), True)
+        self.assertEqual(a.poll(), True)
+
+    @classmethod
+    def _child_strings(cls, conn, strings):
+        for s in strings:
+            time.sleep(0.1)
+            conn.send_bytes(s)
+        conn.close()
+
+    def test_strings(self):
+        strings = (b'hello', b'', b'a', b'b', b'', b'bye', b'', b'lop')
+        a, b = self.Pipe()
+        p = self.Process(target=self._child_strings, args=(b, strings))
+        p.start()
+
+        for s in strings:
+            for i in range(200):
+                if a.poll(0.01):
+                    break
+            x = a.recv_bytes()
+            self.assertEqual(s, x)
+
+        p.join()
+
+    @classmethod
+    def _child_boundaries(cls, r):
+        # Polling may "pull" a message in to the child process, but we
+        # don't want it to pull only part of a message, as that would
+        # corrupt the pipe for any other processes which might later
+        # read from it.
+        r.poll(5)
+
+    def test_boundaries(self):
+        r, w = self.Pipe(False)
+        p = self.Process(target=self._child_boundaries, args=(r,))
+        p.start()
+        time.sleep(2)
+        L = [b"first", b"second"]
+        for obj in L:
+            w.send_bytes(obj)
+        w.close()
+        p.join()
+        self.assertIn(r.recv_bytes(), L)
+
+    @classmethod
+    def _child_dont_merge(cls, b):
+        b.send_bytes(b'a')
+        b.send_bytes(b'b')
+        b.send_bytes(b'cd')
+
+    def test_dont_merge(self):
+        a, b = self.Pipe()
+        self.assertEqual(a.poll(0.0), False)
+        self.assertEqual(a.poll(0.1), False)
+
+        p = self.Process(target=self._child_dont_merge, args=(b,))
+        p.start()
+
+        self.assertEqual(a.recv_bytes(), b'a')
+        self.assertEqual(a.poll(1.0), True)
+        self.assertEqual(a.poll(1.0), True)
+        self.assertEqual(a.recv_bytes(), b'b')
+        self.assertEqual(a.poll(1.0), True)
+        self.assertEqual(a.poll(1.0), True)
+        self.assertEqual(a.poll(0.0), True)
+        self.assertEqual(a.recv_bytes(), b'cd')
+
+        p.join()
+
 #
 # Test of sending connection and socket objects between processes
 #
@@ -2404,8 +2482,163 @@
         flike.flush()
         assert sio.getvalue() == 'foo'
 
+
+class TestWait(unittest.TestCase):
+
+    @classmethod
+    def _child_test_wait(cls, w, slow):
+        for i in range(10):
+            if slow:
+                time.sleep(random.random()*0.1)
+            w.send((i, os.getpid()))
+        w.close()
+
+    def test_wait(self, slow=False):
+        from multiprocessing import Pipe, Process
+        from multiprocessing.connection import wait
+        readers = []
+        procs = []
+        messages = []
+
+        for i in range(4):
+            r, w = Pipe(duplex=False)
+            p = Process(target=self._child_test_wait, args=(w, slow))
+            p.daemon = True
+            p.start()
+            w.close()
+            readers.append(r)
+            procs.append(p)
+
+        while readers:
+            for r in wait(readers):
+                try:
+                    msg = r.recv()
+                except EOFError:
+                    readers.remove(r)
+                    r.close()
+                else:
+                    messages.append(msg)
+
+        messages.sort()
+        expected = sorted((i, p.pid) for i in range(10) for p in procs)
+        self.assertEqual(messages, expected)
+
+    @classmethod
+    def _child_test_wait_socket(cls, address, slow):
+        s = socket.socket()
+        s.connect(address)
+        for i in range(10):
+            if slow:
+                time.sleep(random.random()*0.1)
+            s.sendall(('%s\n' % i).encode('ascii'))
+        s.close()
+
+    def test_wait_socket(self, slow=False):
+        from multiprocessing import Process
+        from multiprocessing.connection import wait
+        l = socket.socket()
+        l.bind(('', 0))
+        l.listen(4)
+        addr = ('localhost', l.getsockname()[1])
+        readers = []
+        procs = []
+        dic = {}
+
+        for i in range(4):
+            p = Process(target=self._child_test_wait_socket, args=(addr, slow))
+            p.daemon = True
+            p.start()
+            procs.append(p)
+
+        for i in range(4):
+            r, _ = l.accept()
+            readers.append(r)
+            dic[r] = []
+        l.close()
+
+        while readers:
+            for r in wait(readers):
+                msg = r.recv(32)
+                if not msg:
+                    readers.remove(r)
+                    r.close()
+                else:
+                    dic[r].append(msg)
+
+        expected = ''.join('%s\n' % i for i in range(10)).encode('ascii')
+        for v in dic.values():
+            self.assertEqual(b''.join(v), expected)
+
+    def test_wait_slow(self):
+        self.test_wait(True)
+
+    def test_wait_socket_slow(self):
+        self.test_wait(True)
+
+    def test_wait_timeout(self):
+        from multiprocessing.connection import wait
+
+        expected = 1
+        a, b = multiprocessing.Pipe()
+
+        start = time.time()
+        res = wait([a, b], 1)
+        delta = time.time() - start
+
+        self.assertEqual(res, [])
+        self.assertLess(delta, expected + 0.2)
+        self.assertGreater(delta, expected - 0.2)
+
+        b.send(None)
+
+        start = time.time()
+        res = wait([a, b], 1)
+        delta = time.time() - start
+
+        self.assertEqual(res, [a])
+        self.assertLess(delta, 0.2)
+
+    def test_wait_integer(self):
+        from multiprocessing.connection import wait
+
+        expected = 5
+        a, b = multiprocessing.Pipe()
+        p = multiprocessing.Process(target=time.sleep, args=(expected,))
+
+        p.start()
+        self.assertIsInstance(p.sentinel, int)
+
+        start = time.time()
+        res = wait([a, p.sentinel, b], expected + 20)
+        delta = time.time() - start
+
+        self.assertEqual(res, [p.sentinel])
+        self.assertLess(delta, expected + 1)
+        self.assertGreater(delta, expected - 1)
+
+        a.send(None)
+
+        start = time.time()
+        res = wait([a, p.sentinel, b], 20)
+        delta = time.time() - start
+
+        self.assertEqual(res, [p.sentinel, b])
+        self.assertLess(delta, 0.2)
+
+        b.send(None)
+
+        start = time.time()
+        res = wait([a, p.sentinel, b], 20)
+        delta = time.time() - start
+
+        self.assertEqual(res, [a, p.sentinel, b])
+        self.assertLess(delta, 0.2)
+
+        p.join()
+
+
 testcases_other = [OtherTest, TestInvalidHandle, TestInitializers,
-                   TestStdinBadfiledescriptor]
+                   TestStdinBadfiledescriptor, TestWait]
 
 #
 #
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -13,6 +13,10 @@
 Library
 -------
 
+- Issue #12328: Fix multiprocessing's use of overlapped I/O on Windows.
+  Also, add a multiprocessing.connection.wait(rlist, timeout=None) function
+  for polling multiple objects at once.  Patch by sbt.
+
 - Issue #13719: Make the distutils and packaging upload commands aware of
   bdist_msi products.
 
diff --git a/Modules/_multiprocessing/win32_functions.c b/Modules/_multiprocessing/win32_functions.c
--- a/Modules/_multiprocessing/win32_functions.c
+++ b/Modules/_multiprocessing/win32_functions.c
@@ -60,16 +60,18 @@
 static void
 overlapped_dealloc(OverlappedObject *self)
 {
+    DWORD bytes;
     int err = GetLastError();
     if (self->pending) {
-        if (check_CancelIoEx())
-            Py_CancelIoEx(self->handle, &self->overlapped);
-        else {
-            PyErr_SetString(PyExc_RuntimeError,
-                            "I/O operations still in flight while destroying "
-                            "Overlapped object, the process may crash");
-            PyErr_WriteUnraisable(NULL);
-        }
+        /* make it a programming error to deallocate while operation
+           is pending, even if we can safely cancel it */
+        if (check_CancelIoEx() &&
+                Py_CancelIoEx(self->handle, &self->overlapped))
+            GetOverlappedResult(self->handle, &self->overlapped, &bytes, TRUE);
+        PyErr_SetString(PyExc_RuntimeError,
+                        "I/O operations still in flight while destroying "
+                        "Overlapped object, the process may crash");
+        PyErr_WriteUnraisable(NULL);
     }
     CloseHandle(self->overlapped.hEvent);
     SetLastError(err);
@@ -85,6 +87,7 @@
     int wait;
     BOOL res;
     DWORD transferred = 0;
+    DWORD err;
 
     wait = PyObject_IsTrue(waitobj);
     if (wait < 0)
@@ -94,23 +97,27 @@
                               wait != 0);
     Py_END_ALLOW_THREADS
 
-    if (!res) {
-        int err = GetLastError();
-        if (err == ERROR_IO_INCOMPLETE)
-            Py_RETURN_NONE;
-        if (err != ERROR_MORE_DATA) {
+    err = res ? ERROR_SUCCESS : GetLastError();
+    switch (err) {
+        case ERROR_SUCCESS:
+        case ERROR_MORE_DATA:
+        case ERROR_OPERATION_ABORTED:
+            self->completed = 1;
+            self->pending = 0;
+            break;
+        case ERROR_IO_INCOMPLETE:
+            break;
+        default:
             self->pending = 0;
             return PyErr_SetExcFromWindowsErr(PyExc_IOError, err);
-        }
     }
-    self->pending = 0;
-    self->completed = 1;
-    if (self->read_buffer) {
+    if (self->completed && self->read_buffer != NULL) {
         assert(PyBytes_CheckExact(self->read_buffer));
-        if (_PyBytes_Resize(&self->read_buffer, transferred))
+        if (transferred != PyBytes_GET_SIZE(self->read_buffer) &&
+            _PyBytes_Resize(&self->read_buffer, transferred))
             return NULL;
     }
-    return Py_BuildValue("lN", (long) transferred, PyBool_FromLong(res));
+    return Py_BuildValue("II", (unsigned) transferred, (unsigned) err);
 }
 
 static PyObject *
@@ -522,9 +529,10 @@
     HANDLE handle;
     Py_buffer _buf, *buf;
     PyObject *bufobj;
-    int written;
+    DWORD written;
     BOOL ret;
     int use_overlapped = 0;
+    DWORD err;
     OverlappedObject *overlapped = NULL;
     static char *kwlist[] = {"handle", "buffer", "overlapped", NULL};
 
@@ -553,8 +561,9 @@
                     overlapped ? &overlapped->overlapped : NULL);
     Py_END_ALLOW_THREADS
 
+    err = ret ? 0 : GetLastError();
+
     if (overlapped) {
-        int err = GetLastError();
         if (!ret) {
             if (err == ERROR_IO_PENDING)
                 overlapped->pending = 1;
@@ -563,13 +572,13 @@
                 return PyErr_SetExcFromWindowsErr(PyExc_IOError, 0);
             }
         }
-        return (PyObject *) overlapped;
+        return Py_BuildValue("NI", (PyObject *) overlapped, err);
     }
 
     PyBuffer_Release(buf);
     if (!ret)
         return PyErr_SetExcFromWindowsErr(PyExc_IOError, 0);
-    return PyLong_FromLong(written);
+    return Py_BuildValue("II", written, err);
 }
 
 static PyObject *
@@ -581,6 +590,7 @@
     PyObject *buf;
     BOOL ret;
     int use_overlapped = 0;
+    DWORD err;
     OverlappedObject *overlapped = NULL;
     static char *kwlist[] = {"handle", "size", "overlapped", NULL};
 
@@ -607,8 +617,9 @@
                    overlapped ? &overlapped->overlapped : NULL);
     Py_END_ALLOW_THREADS
 
+    err = ret ? 0 : GetLastError();
+
     if (overlapped) {
-        int err = GetLastError();
         if (!ret) {
             if (err == ERROR_IO_PENDING)
                 overlapped->pending = 1;
@@ -617,16 +628,16 @@
                 return PyErr_SetExcFromWindowsErr(PyExc_IOError, 0);
             }
         }
-        return (PyObject *) overlapped;
+        return Py_BuildValue("NI", (PyObject *) overlapped, err);
     }
 
-    if (!ret && GetLastError() != ERROR_MORE_DATA) {
+    if (!ret && err != ERROR_MORE_DATA) {
         Py_DECREF(buf);
         return PyErr_SetExcFromWindowsErr(PyExc_IOError, 0);
     }
     if (_PyBytes_Resize(&buf, nread))
         return NULL;
-    return Py_BuildValue("NN", buf, PyBool_FromLong(ret));
+    return Py_BuildValue("NI", buf, err);
 }
 
 static PyObject *
@@ -783,7 +794,11 @@
 
     WIN32_CONSTANT(F_DWORD, ERROR_ALREADY_EXISTS);
     WIN32_CONSTANT(F_DWORD, ERROR_BROKEN_PIPE);
+    WIN32_CONSTANT(F_DWORD, ERROR_IO_PENDING);
+    WIN32_CONSTANT(F_DWORD, ERROR_MORE_DATA);
+    WIN32_CONSTANT(F_DWORD, ERROR_NETNAME_DELETED);
     WIN32_CONSTANT(F_DWORD, ERROR_NO_SYSTEM_RESOURCES);
+    WIN32_CONSTANT(F_DWORD, ERROR_OPERATION_ABORTED);
     WIN32_CONSTANT(F_DWORD, ERROR_PIPE_BUSY);
     WIN32_CONSTANT(F_DWORD, ERROR_PIPE_CONNECTED);
     WIN32_CONSTANT(F_DWORD, ERROR_SEM_TIMEOUT);

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list