[Python-checkins] bpo-29883: Asyncio proactor udp (GH-13440)

Miss Islington (bot) webhook-mailer at python.org
Tue May 28 05:52:20 EDT 2019


https://github.com/python/cpython/commit/bafd4b5ac83b6cc0b7455290a04c4bfad34bdc90
commit: bafd4b5ac83b6cc0b7455290a04c4bfad34bdc90
branch: master
author: Andrew Svetlov <andrew.svetlov at gmail.com>
committer: Miss Islington (bot) <31488909+miss-islington at users.noreply.github.com>
date: 2019-05-28T02:52:15-07:00
summary:

bpo-29883: Asyncio proactor udp (GH-13440)



Follow-up for #1067


https://bugs.python.org/issue29883

files:
A Misc/NEWS.d/next/Windows/2018-09-15-11-36-55.bpo-29883.HErerE.rst
M Doc/library/asyncio-eventloop.rst
M Doc/library/asyncio-platforms.rst
M Lib/asyncio/proactor_events.py
M Lib/asyncio/windows_events.py
M Lib/test/test_asyncio/test_events.py
M Lib/test/test_asyncio/test_proactor_events.py
M Modules/overlapped.c

diff --git a/Doc/library/asyncio-eventloop.rst b/Doc/library/asyncio-eventloop.rst
index 4acd23f59465..8673f84e9638 100644
--- a/Doc/library/asyncio-eventloop.rst
+++ b/Doc/library/asyncio-eventloop.rst
@@ -504,8 +504,6 @@ Opening network connections
      transport. If specified, *local_addr* and *remote_addr* should be omitted
      (must be :const:`None`).
 
-   On Windows, with :class:`ProactorEventLoop`, this method is not supported.
-
    See :ref:`UDP echo client protocol <asyncio-udp-echo-client-protocol>` and
    :ref:`UDP echo server protocol <asyncio-udp-echo-server-protocol>` examples.
 
@@ -513,6 +511,9 @@ Opening network connections
       The *family*, *proto*, *flags*, *reuse_address*, *reuse_port,
       *allow_broadcast*, and *sock* parameters were added.
 
+   .. versionchanged:: 3.8
+      Added support for Windows.
+
 .. coroutinemethod:: loop.create_unix_connection(protocol_factory, \
                         path=None, \*, ssl=None, sock=None, \
                         server_hostname=None, ssl_handshake_timeout=None)
diff --git a/Doc/library/asyncio-platforms.rst b/Doc/library/asyncio-platforms.rst
index 81d840e23277..7e4a70f91c6e 100644
--- a/Doc/library/asyncio-platforms.rst
+++ b/Doc/library/asyncio-platforms.rst
@@ -53,9 +53,6 @@ All event loops on Windows do not support the following methods:
 
 :class:`ProactorEventLoop` has the following limitations:
 
-* The :meth:`loop.create_datagram_endpoint` method
-  is not supported.
-
 * The :meth:`loop.add_reader` and :meth:`loop.add_writer`
   methods are not supported.
 
diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py
index 6a53b2edaac1..9b8ae064a892 100644
--- a/Lib/asyncio/proactor_events.py
+++ b/Lib/asyncio/proactor_events.py
@@ -11,6 +11,7 @@
 import socket
 import warnings
 import signal
+import collections
 
 from . import base_events
 from . import constants
@@ -23,6 +24,24 @@
 from .log import logger
 
 
+def _set_socket_extra(transport, sock):
+    transport._extra['socket'] = trsock.TransportSocket(sock)
+
+    try:
+        transport._extra['sockname'] = sock.getsockname()
+    except socket.error:
+        if transport._loop.get_debug():
+            logger.warning(
+                "getsockname() failed on %r", sock, exc_info=True)
+
+    if 'peername' not in transport._extra:
+        try:
+            transport._extra['peername'] = sock.getpeername()
+        except socket.error:
+            # Unconnected sockets (e.g. UDP) have no peer; record None
+            transport._extra['peername'] = None
+
+
 class _ProactorBasePipeTransport(transports._FlowControlMixin,
                                  transports.BaseTransport):
     """Base class for pipe and socket transports."""
@@ -430,6 +449,134 @@ def _pipe_closed(self, fut):
             self.close()
 
 
+class _ProactorDatagramTransport(_ProactorBasePipeTransport):
+    max_size = 256 * 1024
+    def __init__(self, loop, sock, protocol, address=None,
+                 waiter=None, extra=None):
+        self._address = address
+        self._empty_waiter = None
+        # We don't need to call _protocol.connection_made() since our base
+        # constructor does it for us.
+        super().__init__(loop, sock, protocol, waiter=waiter, extra=extra)
+
+        # The base constructor sets _buffer = None, so we set it here
+        self._buffer = collections.deque()
+        self._loop.call_soon(self._loop_reading)
+
+    def _set_extra(self, sock):
+        _set_socket_extra(self, sock)
+
+    def get_write_buffer_size(self):
+        return sum(len(data) for data, _ in self._buffer)
+
+    def abort(self):
+        self._force_close(None)
+
+    def sendto(self, data, addr=None):
+        if not isinstance(data, (bytes, bytearray, memoryview)):
+            raise TypeError('data argument must be bytes-like object (%r)' %
+                            type(data))
+
+        if not data:
+            return
+
+        if self._address is not None and addr not in (None, self._address):
+            raise ValueError(
+                f'Invalid address: must be None or {self._address}')
+
+        if self._conn_lost and self._address:
+            if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
+                logger.warning('socket.sendto() raised exception.')
+            self._conn_lost += 1
+            return
+
+        # Ensure that what we buffer is immutable.
+        self._buffer.append((bytes(data), addr))
+
+        if self._write_fut is None:
+            # No current write operations are active, kick one off
+            self._loop_writing()
+        # else: A write operation is already kicked off
+
+        self._maybe_pause_protocol()
+
+    def _loop_writing(self, fut=None):
+        try:
+            if self._conn_lost:
+                return
+
+            assert fut is self._write_fut
+            self._write_fut = None
+            if fut:
+                # We are in a _loop_writing() done callback, get the result
+                fut.result()
+
+            if not self._buffer or (self._conn_lost and self._address):
+                # Nothing left to send; finish a pending close() if any
+                if self._closing:
+                    self._loop.call_soon(self._call_connection_lost, None)
+                return
+
+            data, addr = self._buffer.popleft()
+            if self._address is not None:
+                self._write_fut = self._loop._proactor.send(self._sock,
+                                                            data)
+            else:
+                self._write_fut = self._loop._proactor.sendto(self._sock,
+                                                              data,
+                                                              addr=addr)
+        except OSError as exc:
+            self._protocol.error_received(exc)
+        except Exception as exc:
+            self._fatal_error(exc, 'Fatal write error on datagram transport')
+        else:
+            self._write_fut.add_done_callback(self._loop_writing)
+            self._maybe_resume_protocol()
+
+    def _loop_reading(self, fut=None):
+        data = None
+        try:
+            if self._conn_lost:
+                return
+
+            assert self._read_fut is fut or (self._read_fut is None and
+                                             self._closing)
+
+            self._read_fut = None
+            if fut is not None:
+                res = fut.result()
+
+                if self._closing:
+                    # since close() has been called we ignore any read data
+                    data = None
+                    return
+
+                if self._address is not None:
+                    data, addr = res, self._address
+                else:
+                    data, addr = res
+
+            if self._conn_lost:
+                return
+            if self._address is not None:
+                self._read_fut = self._loop._proactor.recv(self._sock,
+                                                           self.max_size)
+            else:
+                self._read_fut = self._loop._proactor.recvfrom(self._sock,
+                                                               self.max_size)
+        except OSError as exc:
+            self._protocol.error_received(exc)
+        except exceptions.CancelledError:
+            if not self._closing:
+                raise
+        else:
+            if self._read_fut is not None:
+                self._read_fut.add_done_callback(self._loop_reading)
+        finally:
+            if data:
+                self._protocol.datagram_received(data, addr)
+
+
 class _ProactorDuplexPipeTransport(_ProactorReadPipeTransport,
                                    _ProactorBaseWritePipeTransport,
                                    transports.Transport):
@@ -455,22 +602,7 @@ def __init__(self, loop, sock, protocol, waiter=None,
         base_events._set_nodelay(sock)
 
     def _set_extra(self, sock):
-        self._extra['socket'] = trsock.TransportSocket(sock)
-
-        try:
-            self._extra['sockname'] = sock.getsockname()
-        except (socket.error, AttributeError):
-            if self._loop.get_debug():
-                logger.warning(
-                    "getsockname() failed on %r", sock, exc_info=True)
-
-        if 'peername' not in self._extra:
-            try:
-                self._extra['peername'] = sock.getpeername()
-            except (socket.error, AttributeError):
-                if self._loop.get_debug():
-                    logger.warning("getpeername() failed on %r",
-                                   sock, exc_info=True)
+        _set_socket_extra(self, sock)
 
     def can_write_eof(self):
         return True
@@ -515,6 +647,11 @@ def _make_ssl_transport(
                                  extra=extra, server=server)
         return ssl_protocol._app_transport
 
+    def _make_datagram_transport(self, sock, protocol,
+                                 address=None, waiter=None, extra=None):
+        return _ProactorDatagramTransport(self, sock, protocol, address,
+                                          waiter, extra)
+
     def _make_duplex_pipe_transport(self, sock, protocol, waiter=None,
                                     extra=None):
         return _ProactorDuplexPipeTransport(self,
diff --git a/Lib/asyncio/windows_events.py b/Lib/asyncio/windows_events.py
index 61b40ba52a64..ac51109ff1a8 100644
--- a/Lib/asyncio/windows_events.py
+++ b/Lib/asyncio/windows_events.py
@@ -483,6 +483,44 @@ def finish_recv(trans, key, ov):
 
         return self._register(ov, conn, finish_recv)
 
+    def recvfrom(self, conn, nbytes, flags=0):
+        self._register_with_iocp(conn)
+        ov = _overlapped.Overlapped(NULL)
+        try:
+            ov.WSARecvFrom(conn.fileno(), nbytes, flags)
+        except BrokenPipeError:
+            return self._result((b'', None))
+
+        def finish_recv(trans, key, ov):
+            try:
+                return ov.getresult()
+            except OSError as exc:
+                if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
+                                    _overlapped.ERROR_OPERATION_ABORTED):
+                    raise ConnectionResetError(*exc.args)
+                else:
+                    raise
+
+        return self._register(ov, conn, finish_recv)
+
+    def sendto(self, conn, buf, flags=0, addr=None):
+        self._register_with_iocp(conn)
+        ov = _overlapped.Overlapped(NULL)
+
+        ov.WSASendTo(conn.fileno(), buf, flags, addr)
+
+        def finish_send(trans, key, ov):
+            try:
+                return ov.getresult()
+            except OSError as exc:
+                if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
+                                    _overlapped.ERROR_OPERATION_ABORTED):
+                    raise ConnectionResetError(*exc.args)
+                else:
+                    raise
+
+        return self._register(ov, conn, finish_send)
+
     def send(self, conn, buf, flags=0):
         self._register_with_iocp(conn)
         ov = _overlapped.Overlapped(NULL)
@@ -532,6 +570,14 @@ def finish_accept(trans, key, ov):
         return future
 
     def connect(self, conn, address):
+        if conn.type == socket.SOCK_DGRAM:
+            # WSAConnect will complete immediately for UDP sockets so we don't
+            # need to register any IOCP operation
+            _overlapped.WSAConnect(conn.fileno(), address)
+            fut = self._loop.create_future()
+            fut.set_result(None)
+            return fut
+
         self._register_with_iocp(conn)
         # The socket needs to be locally bound before we call ConnectEx().
         try:
diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py
index e89db99df312..045654e87a85 100644
--- a/Lib/test/test_asyncio/test_events.py
+++ b/Lib/test/test_asyncio/test_events.py
@@ -1249,11 +1249,6 @@ def datagram_received(self, data, addr):
         server.transport.close()
 
     def test_create_datagram_endpoint_sock(self):
-        if (sys.platform == 'win32' and
-                isinstance(self.loop, proactor_events.BaseProactorEventLoop)):
-            raise unittest.SkipTest(
-                'UDP is not supported with proactor event loops')
-
         sock = None
         local_address = ('127.0.0.1', 0)
         infos = self.loop.run_until_complete(
@@ -2004,10 +1999,6 @@ def test_writer_callback(self):
         def test_writer_callback_cancel(self):
             raise unittest.SkipTest("IocpEventLoop does not have add_writer()")
 
-        def test_create_datagram_endpoint(self):
-            raise unittest.SkipTest(
-                "IocpEventLoop does not have create_datagram_endpoint()")
-
         def test_remove_fds_after_closing(self):
             raise unittest.SkipTest("IocpEventLoop does not have add_reader()")
 else:
diff --git a/Lib/test/test_asyncio/test_proactor_events.py b/Lib/test/test_asyncio/test_proactor_events.py
index 5952ccccce0e..2e9995d32807 100644
--- a/Lib/test/test_asyncio/test_proactor_events.py
+++ b/Lib/test/test_asyncio/test_proactor_events.py
@@ -4,6 +4,7 @@
 import socket
 import unittest
 import sys
+from collections import deque
 from unittest import mock
 
 import asyncio
@@ -12,6 +13,7 @@
 from asyncio.proactor_events import _ProactorSocketTransport
 from asyncio.proactor_events import _ProactorWritePipeTransport
 from asyncio.proactor_events import _ProactorDuplexPipeTransport
+from asyncio.proactor_events import _ProactorDatagramTransport
 from test import support
 from test.test_asyncio import utils as test_utils
 
@@ -725,6 +727,208 @@ def test_pause_resume_reading(self):
         self.assertFalse(tr.is_reading())
 
 
+class ProactorDatagramTransportTests(test_utils.TestCase):
+
+    def setUp(self):
+        super().setUp()
+        self.loop = self.new_test_loop()
+        self.proactor = mock.Mock()
+        self.loop._proactor = self.proactor
+        self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol)
+        self.sock = mock.Mock(spec_set=socket.socket)
+        self.sock.fileno.return_value = 7
+
+    def datagram_transport(self, address=None):
+        self.sock.getpeername.side_effect = None if address else OSError
+        transport = _ProactorDatagramTransport(self.loop, self.sock,
+                                               self.protocol,
+                                               address=address)
+        self.addCleanup(close_transport, transport)
+        return transport
+
+    def test_sendto(self):
+        data = b'data'
+        transport = self.datagram_transport()
+        transport.sendto(data, ('0.0.0.0', 1234))
+        self.assertTrue(self.proactor.sendto.called)
+        self.proactor.sendto.assert_called_with(
+            self.sock, data, addr=('0.0.0.0', 1234))
+
+    def test_sendto_bytearray(self):
+        data = bytearray(b'data')
+        transport = self.datagram_transport()
+        transport.sendto(data, ('0.0.0.0', 1234))
+        self.assertTrue(self.proactor.sendto.called)
+        self.proactor.sendto.assert_called_with(
+            self.sock, b'data', addr=('0.0.0.0', 1234))
+
+    def test_sendto_memoryview(self):
+        data = memoryview(b'data')
+        transport = self.datagram_transport()
+        transport.sendto(data, ('0.0.0.0', 1234))
+        self.assertTrue(self.proactor.sendto.called)
+        self.proactor.sendto.assert_called_with(
+            self.sock, b'data', addr=('0.0.0.0', 1234))
+
+    def test_sendto_no_data(self):
+        transport = self.datagram_transport()
+        transport._buffer.append((b'data', ('0.0.0.0', 12345)))
+        transport.sendto(b'', ())
+        self.assertFalse(self.proactor.sendto.called)
+        self.assertEqual(
+            [(b'data', ('0.0.0.0', 12345))], list(transport._buffer))
+
+    def test_sendto_buffer(self):
+        transport = self.datagram_transport()
+        transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
+        transport._write_fut = object()
+        transport.sendto(b'data2', ('0.0.0.0', 12345))
+        self.assertFalse(self.proactor.sendto.called)
+        self.assertEqual(
+            [(b'data1', ('0.0.0.0', 12345)),
+             (b'data2', ('0.0.0.0', 12345))],
+            list(transport._buffer))
+
+    def test_sendto_buffer_bytearray(self):
+        data2 = bytearray(b'data2')
+        transport = self.datagram_transport()
+        transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
+        transport._write_fut = object()
+        transport.sendto(data2, ('0.0.0.0', 12345))
+        self.assertFalse(self.proactor.sendto.called)
+        self.assertEqual(
+            [(b'data1', ('0.0.0.0', 12345)),
+             (b'data2', ('0.0.0.0', 12345))],
+            list(transport._buffer))
+        self.assertIsInstance(transport._buffer[1][0], bytes)
+
+    def test_sendto_buffer_memoryview(self):
+        data2 = memoryview(b'data2')
+        transport = self.datagram_transport()
+        transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
+        transport._write_fut = object()
+        transport.sendto(data2, ('0.0.0.0', 12345))
+        self.assertFalse(self.proactor.sendto.called)
+        self.assertEqual(
+            [(b'data1', ('0.0.0.0', 12345)),
+             (b'data2', ('0.0.0.0', 12345))],
+            list(transport._buffer))
+        self.assertIsInstance(transport._buffer[1][0], bytes)
+
+    @mock.patch('asyncio.proactor_events.logger')
+    def test_sendto_exception(self, m_log):
+        data = b'data'
+        err = self.proactor.sendto.side_effect = RuntimeError()
+
+        transport = self.datagram_transport()
+        transport._fatal_error = mock.Mock()
+        transport.sendto(data, ())
+
+        self.assertTrue(transport._fatal_error.called)
+        transport._fatal_error.assert_called_with(
+                                   err,
+                                   'Fatal write error on datagram transport')
+        transport._conn_lost = 1
+
+        transport._address = ('123',)
+        transport.sendto(data)
+        transport.sendto(data)
+        transport.sendto(data)
+        transport.sendto(data)
+        transport.sendto(data)
+        m_log.warning.assert_called_with('socket.sendto() raised exception.')
+
+    def test_sendto_error_received(self):
+        data = b'data'
+
+        self.proactor.sendto.side_effect = ConnectionRefusedError
+
+        transport = self.datagram_transport()
+        transport._fatal_error = mock.Mock()
+        transport.sendto(data, ())
+
+        self.assertEqual(transport._conn_lost, 0)
+        self.assertFalse(transport._fatal_error.called)
+
+    def test_sendto_error_received_connected(self):
+        data = b'data'
+
+        self.proactor.send.side_effect = ConnectionRefusedError
+
+        transport = self.datagram_transport(address=('0.0.0.0', 1))
+        transport._fatal_error = mock.Mock()
+        transport.sendto(data)
+
+        self.assertFalse(transport._fatal_error.called)
+        self.assertTrue(self.protocol.error_received.called)
+
+    def test_sendto_str(self):
+        transport = self.datagram_transport()
+        self.assertRaises(TypeError, transport.sendto, 'str', ())
+
+    def test_sendto_connected_addr(self):
+        transport = self.datagram_transport(address=('0.0.0.0', 1))
+        self.assertRaises(
+            ValueError, transport.sendto, b'str', ('0.0.0.0', 2))
+
+    def test_sendto_closing(self):
+        transport = self.datagram_transport(address=(1,))
+        transport.close()
+        self.assertEqual(transport._conn_lost, 1)
+        transport.sendto(b'data', (1,))
+        self.assertEqual(transport._conn_lost, 2)
+
+    def test__loop_writing_closing(self):
+        transport = self.datagram_transport()
+        transport._closing = True
+        transport._loop_writing()
+        self.assertIsNone(transport._write_fut)
+        test_utils.run_briefly(self.loop)
+        self.sock.close.assert_called_with()
+        self.protocol.connection_lost.assert_called_with(None)
+
+    def test__loop_writing_exception(self):
+        err = self.proactor.sendto.side_effect = RuntimeError()
+
+        transport = self.datagram_transport()
+        transport._fatal_error = mock.Mock()
+        transport._buffer.append((b'data', ()))
+        transport._loop_writing()
+
+        transport._fatal_error.assert_called_with(
+                                   err,
+                                   'Fatal write error on datagram transport')
+
+    def test__loop_writing_error_received(self):
+        self.proactor.sendto.side_effect = ConnectionRefusedError
+
+        transport = self.datagram_transport()
+        transport._fatal_error = mock.Mock()
+        transport._buffer.append((b'data', ()))
+        transport._loop_writing()
+
+        self.assertFalse(transport._fatal_error.called)
+
+    def test__loop_writing_error_received_connection(self):
+        self.proactor.send.side_effect = ConnectionRefusedError
+
+        transport = self.datagram_transport(address=('0.0.0.0', 1))
+        transport._fatal_error = mock.Mock()
+        transport._buffer.append((b'data', ()))
+        transport._loop_writing()
+
+        self.assertFalse(transport._fatal_error.called)
+        self.assertTrue(self.protocol.error_received.called)
+
+    @mock.patch('asyncio.base_events.logger.error')
+    def test_fatal_error_connected(self, m_exc):
+        transport = self.datagram_transport(address=('0.0.0.0', 1))
+        err = ConnectionRefusedError()
+        transport._fatal_error(err)
+        self.assertFalse(self.protocol.error_received.called)
+        m_exc.assert_not_called()
+
+
 class BaseProactorEventLoopTests(test_utils.TestCase):
 
     def setUp(self):
@@ -864,6 +1068,80 @@ def test_stop_serving(self):
         self.assertFalse(sock2.close.called)
         self.assertFalse(future2.cancel.called)
 
+    def datagram_transport(self):
+        self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol)
+        return self.loop._make_datagram_transport(self.sock, self.protocol)
+
+    def test_make_datagram_transport(self):
+        tr = self.datagram_transport()
+        self.assertIsInstance(tr, _ProactorDatagramTransport)
+        close_transport(tr)
+
+    def test_datagram_loop_writing(self):
+        tr = self.datagram_transport()
+        tr._buffer.appendleft((b'data', ('127.0.0.1', 12068)))
+        tr._loop_writing()
+        self.loop._proactor.sendto.assert_called_with(self.sock, b'data', addr=('127.0.0.1', 12068))
+        self.loop._proactor.sendto.return_value.add_done_callback.\
+            assert_called_with(tr._loop_writing)
+
+        close_transport(tr)
+
+    def test_datagram_loop_reading(self):
+        tr = self.datagram_transport()
+        tr._loop_reading()
+        self.loop._proactor.recvfrom.assert_called_with(self.sock, 256 * 1024)
+        self.assertFalse(self.protocol.datagram_received.called)
+        self.assertFalse(self.protocol.error_received.called)
+        close_transport(tr)
+
+    def test_datagram_loop_reading_data(self):
+        res = asyncio.Future(loop=self.loop)
+        res.set_result((b'data', ('127.0.0.1', 12068)))
+
+        tr = self.datagram_transport()
+        tr._read_fut = res
+        tr._loop_reading(res)
+        self.loop._proactor.recvfrom.assert_called_with(self.sock, 256 * 1024)
+        self.protocol.datagram_received.assert_called_with(b'data', ('127.0.0.1', 12068))
+        close_transport(tr)
+
+    def test_datagram_loop_reading_no_data(self):
+        res = asyncio.Future(loop=self.loop)
+        res.set_result((b'', ('127.0.0.1', 12068)))
+
+        tr = self.datagram_transport()
+        self.assertRaises(AssertionError, tr._loop_reading, res)
+
+        tr.close = mock.Mock()
+        tr._read_fut = res
+        tr._loop_reading(res)
+        self.assertTrue(self.loop._proactor.recvfrom.called)
+        self.assertFalse(self.protocol.error_received.called)
+        self.assertFalse(tr.close.called)
+        close_transport(tr)
+
+    def test_datagram_loop_reading_aborted(self):
+        err = self.loop._proactor.recvfrom.side_effect = ConnectionAbortedError()
+
+        tr = self.datagram_transport()
+        tr._fatal_error = mock.Mock()
+        tr._protocol.error_received = mock.Mock()
+        tr._loop_reading()
+        tr._protocol.error_received.assert_called_with(err)
+        close_transport(tr)
+
+    def test_datagram_loop_writing_aborted(self):
+        err = self.loop._proactor.sendto.side_effect = ConnectionAbortedError()
+
+        tr = self.datagram_transport()
+        tr._fatal_error = mock.Mock()
+        tr._protocol.error_received = mock.Mock()
+        tr._buffer.appendleft((b'Hello', ('127.0.0.1', 12068)))
+        tr._loop_writing()
+        tr._protocol.error_received.assert_called_with(err)
+        close_transport(tr)
+
 
 @unittest.skipIf(sys.platform != 'win32',
                  'Proactor is supported on Windows only')
diff --git a/Misc/NEWS.d/next/Windows/2018-09-15-11-36-55.bpo-29883.HErerE.rst b/Misc/NEWS.d/next/Windows/2018-09-15-11-36-55.bpo-29883.HErerE.rst
new file mode 100644
index 000000000000..b6d1375c7752
--- /dev/null
+++ b/Misc/NEWS.d/next/Windows/2018-09-15-11-36-55.bpo-29883.HErerE.rst
@@ -0,0 +1,2 @@
+Add Windows support for UDP transports for the Proactor Event Loop. Patch by
+Adam Meily.
diff --git a/Modules/overlapped.c b/Modules/overlapped.c
index e5a209bf7582..aad531e47893 100644
--- a/Modules/overlapped.c
+++ b/Modules/overlapped.c
@@ -39,7 +39,8 @@
 
 enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_READINTO, TYPE_WRITE,
       TYPE_ACCEPT, TYPE_CONNECT, TYPE_DISCONNECT, TYPE_CONNECT_NAMED_PIPE,
-      TYPE_WAIT_NAMED_PIPE_AND_CONNECT, TYPE_TRANSMIT_FILE};
+      TYPE_WAIT_NAMED_PIPE_AND_CONNECT, TYPE_TRANSMIT_FILE, TYPE_READ_FROM,
+      TYPE_WRITE_TO};
 
 typedef struct {
     PyObject_HEAD
@@ -53,8 +54,19 @@ typedef struct {
     union {
         /* Buffer allocated by us: TYPE_READ and TYPE_ACCEPT */
         PyObject *allocated_buffer;
-        /* Buffer passed by the user: TYPE_WRITE and TYPE_READINTO */
+        /* Buffer passed by the user: TYPE_WRITE, TYPE_WRITE_TO, and TYPE_READINTO */
         Py_buffer user_buffer;
+
+        /* Data used for reading from a connectionless socket:
+           TYPE_READ_FROM */
+        struct {
+            // A (message, address) tuple; address as from unparse_address()
+            PyObject *result;
+            // The actual read buffer
+            PyObject *allocated_buffer;
+            struct sockaddr_in6 address;
+            int address_length;
+        } read_from;
     };
 } OverlappedObject;
 
@@ -570,16 +582,32 @@ static int
 Overlapped_clear(OverlappedObject *self)
 {
     switch (self->type) {
-    case TYPE_READ:
-    case TYPE_ACCEPT:
-        Py_CLEAR(self->allocated_buffer);
-        break;
-    case TYPE_WRITE:
-    case TYPE_READINTO:
-        if (self->user_buffer.obj) {
-            PyBuffer_Release(&self->user_buffer);
+        case TYPE_READ:
+        case TYPE_ACCEPT: {
+            Py_CLEAR(self->allocated_buffer);
+            break;
+        }
+        case TYPE_READ_FROM: {
+            // An initial call to WSARecvFrom will only allocate the buffer.
+            // The result tuple of (message, address) is only
+            // allocated _after_ a message has been received.
+            if(self->read_from.result) {
+                // We've received a message, free the result tuple.
+                Py_CLEAR(self->read_from.result);
+            }
+            if(self->read_from.allocated_buffer) {
+                Py_CLEAR(self->read_from.allocated_buffer);
+            }
+            break;
+        }
+        case TYPE_WRITE:
+        case TYPE_WRITE_TO:
+        case TYPE_READINTO: {
+            if (self->user_buffer.obj) {
+                PyBuffer_Release(&self->user_buffer);
+            }
+            break;
         }
-        break;
     }
     self->type = TYPE_NOT_STARTED;
     return 0;
@@ -627,6 +655,73 @@ Overlapped_dealloc(OverlappedObject *self)
     SetLastError(olderr);
 }
 
+
+/* Convert IPv4 sockaddr to a Python str. */
+
+static PyObject *
+make_ipv4_addr(const struct sockaddr_in *addr)
+{
+        char buf[INET_ADDRSTRLEN];
+        if (inet_ntop(AF_INET, &addr->sin_addr, buf, sizeof(buf)) == NULL) {
+                SetFromWindowsErr(WSAGetLastError());
+                return NULL;
+        }
+        return PyUnicode_FromString(buf);
+}
+
+#ifdef ENABLE_IPV6
+/* Convert IPv6 sockaddr to a Python str. */
+
+static PyObject *
+make_ipv6_addr(const struct sockaddr_in6 *addr)
+{
+        char buf[INET6_ADDRSTRLEN];
+        if (inet_ntop(AF_INET6, &addr->sin6_addr, buf, sizeof(buf)) == NULL) {
+                SetFromWindowsErr(WSAGetLastError());
+                return NULL;
+        }
+        return PyUnicode_FromString(buf);
+}
+#endif
+
+static PyObject*
+unparse_address(LPSOCKADDR Address, DWORD Length)
+{
+        /* The function is adapted from socketmodule.c makesockaddr() */
+
+    switch(Address->sa_family) {
+        case AF_INET: {
+            const struct sockaddr_in *a = (const struct sockaddr_in *)Address;
+            PyObject *addrobj = make_ipv4_addr(a);
+            PyObject *ret = NULL;
+            if (addrobj) {
+                ret = Py_BuildValue("Oi", addrobj, ntohs(a->sin_port));
+                Py_DECREF(addrobj);
+            }
+            return ret;
+        }
+#ifdef ENABLE_IPV6
+        case AF_INET6: {
+            const struct sockaddr_in6 *a = (const struct sockaddr_in6 *)Address;
+            PyObject *addrobj = make_ipv6_addr(a);
+            PyObject *ret = NULL;
+            if (addrobj) {
+                ret = Py_BuildValue("OiII",
+                                    addrobj,
+                                    ntohs(a->sin6_port),
+                                    ntohl(a->sin6_flowinfo),
+                                    a->sin6_scope_id);
+                Py_DECREF(addrobj);
+            }
+            return ret;
+        }
+#endif /* ENABLE_IPV6 */
+        default: {
+            return SetFromWindowsErr(ERROR_INVALID_PARAMETER);
+        }
+    }
+}
+
 PyDoc_STRVAR(
     Overlapped_cancel_doc,
     "cancel() -> None\n\n"
@@ -670,6 +765,7 @@ Overlapped_getresult(OverlappedObject *self, PyObject *args)
     DWORD transferred = 0;
     BOOL ret;
     DWORD err;
+    PyObject *addr;
 
     if (!PyArg_ParseTuple(args, "|" F_BOOL, &wait))
         return NULL;
@@ -695,8 +791,15 @@ Overlapped_getresult(OverlappedObject *self, PyObject *args)
         case ERROR_MORE_DATA:
             break;
         case ERROR_BROKEN_PIPE:
-            if (self->type == TYPE_READ || self->type == TYPE_READINTO)
+            if (self->type == TYPE_READ || self->type == TYPE_READINTO) {
                 break;
+            }
+            else if (self->type == TYPE_READ_FROM &&
+                     (self->read_from.result != NULL ||
+                      self->read_from.allocated_buffer != NULL))
+            {
+                break;
+            }
             /* fall through */
         default:
             return SetFromWindowsErr(err);
@@ -708,8 +811,43 @@ Overlapped_getresult(OverlappedObject *self, PyObject *args)
             if (transferred != PyBytes_GET_SIZE(self->allocated_buffer) &&
                 _PyBytes_Resize(&self->allocated_buffer, transferred))
                 return NULL;
+
             Py_INCREF(self->allocated_buffer);
             return self->allocated_buffer;
+        case TYPE_READ_FROM:
+            assert(PyBytes_CheckExact(self->read_from.allocated_buffer));
+
+            if (transferred != PyBytes_GET_SIZE(
+                    self->read_from.allocated_buffer) &&
+                _PyBytes_Resize(&self->read_from.allocated_buffer, transferred))
+            {
+                return NULL;
+            }
+
+            // unparse the address
+            addr = unparse_address((SOCKADDR*)&self->read_from.address,
+                                   self->read_from.address_length);
+
+            if (addr == NULL) {
+                return NULL;
+            }
+
+            // The result is a two item tuple: (message, address)
+            self->read_from.result = PyTuple_New(2);
+            if (self->read_from.result == NULL) {
+                Py_CLEAR(addr);
+                return NULL;
+            }
+
+            // first item: message
+            Py_INCREF(self->read_from.allocated_buffer);
+            PyTuple_SET_ITEM(self->read_from.result, 0,
+                             self->read_from.allocated_buffer);
+            // second item: address
+            PyTuple_SET_ITEM(self->read_from.result, 1, addr);
+
+            Py_INCREF(self->read_from.result);
+            return self->read_from.result;
         default:
             return PyLong_FromUnsignedLong((unsigned long) transferred);
     }
@@ -1121,7 +1259,6 @@ parse_address(PyObject *obj, SOCKADDR *Address, int Length)
     return -1;
 }
 
-
 PyDoc_STRVAR(
     Overlapped_ConnectEx_doc,
     "ConnectEx(client_handle, address_as_bytes) -> Overlapped[None]\n\n"
@@ -1314,7 +1451,7 @@ PyDoc_STRVAR(
     "Connect to the pipe for asynchronous I/O (overlapped).");
 
 static PyObject *
-ConnectPipe(OverlappedObject *self, PyObject *args)
+overlapped_ConnectPipe(PyObject *self, PyObject *args)
 {
     PyObject *AddressObj;
     wchar_t *Address;
@@ -1362,15 +1499,213 @@ Overlapped_traverse(OverlappedObject *self, visitproc visit, void *arg)
         Py_VISIT(self->allocated_buffer);
         break;
     case TYPE_WRITE:
+    case TYPE_WRITE_TO:
     case TYPE_READINTO:
         if (self->user_buffer.obj) {
             Py_VISIT(&self->user_buffer.obj);
         }
         break;
+    case TYPE_READ_FROM:
+        if(self->read_from.result) {
+            Py_VISIT(self->read_from.result);
+        }
+        if(self->read_from.allocated_buffer) {
+            Py_VISIT(self->read_from.allocated_buffer);
+        }
     }
     return 0;
 }
 
+// UDP functions
+
+PyDoc_STRVAR(
+    WSAConnect_doc,
+    "WSAConnect(client_handle, address_as_bytes) -> None\n\n"
+    "Bind a remote address to a connectionless (UDP) socket");
+
+/*
+ * Note: WSAConnect does not support Overlapped I/O so this function should
+ * _only_ be used for connectionless sockets (UDP).
+ *
+ * Unlike the Overlapped.* methods, this is a plain module-level function:
+ * the call runs synchronously (with the GIL released) and None is returned
+ * on success.  On failure the error from WSAGetLastError() is raised via
+ * SetFromWindowsErr().
+ */
+static PyObject *
+overlapped_WSAConnect(PyObject *self, PyObject *args)
+{
+    SOCKET ConnectSocket;
+    PyObject *AddressObj;
+    /* Storage sized for the largest supported address (sockaddr_in6). */
+    char AddressBuf[sizeof(struct sockaddr_in6)];
+    SOCKADDR *Address = (SOCKADDR*)AddressBuf;
+    int Length;
+    int err;
+
+    if (!PyArg_ParseTuple(args, F_HANDLE "O", &ConnectSocket, &AddressObj)) {
+        return NULL;
+    }
+
+    /* parse_address() fills Address and returns the number of bytes used;
+       a negative result means it failed (presumably with a Python exception
+       already set, since we just return NULL here). */
+    Length = sizeof(AddressBuf);
+    Length = parse_address(AddressObj, Address, Length);
+    if (Length < 0) {
+        return NULL;
+    }
+
+    Py_BEGIN_ALLOW_THREADS
+    // WSAConnect does not support overlapped I/O, so this call has fully
+    // completed (successfully or not) by the time it returns.
+    err = WSAConnect(ConnectSocket, Address, Length,
+                     NULL, NULL, NULL, NULL);
+    Py_END_ALLOW_THREADS
+
+    if (err == 0) {
+        Py_RETURN_NONE;
+    }
+    else {
+        return SetFromWindowsErr(WSAGetLastError());
+    }
+}
+
+PyDoc_STRVAR(
+    Overlapped_WSASendTo_doc,
+    "WSASendTo(handle, buf, flags, address_as_bytes) -> "
+    "Overlapped[bytes_transferred]\n\n"
+    "Start overlapped sendto over a connectionless (UDP) socket");
+
+/*
+ * Overlapped.WSASendTo(handle, buf, flags, address)
+ *
+ * Queue a send of *buf* to *address* on socket *handle*, driven by this
+ * Overlapped object.  The buffer view is stored in self->user_buffer so the
+ * memory stays pinned while the send may still be pending.  Returns None when
+ * the send completed immediately or was queued; the byte count is obtained
+ * later through getresult().  Any other Winsock error resets the object to
+ * TYPE_NOT_STARTED and is raised via SetFromWindowsErr().
+ */
+static PyObject *
+Overlapped_WSASendTo(OverlappedObject *self, PyObject *args)
+{
+    HANDLE sock;
+    PyObject *payload;
+    DWORD send_flags;
+    PyObject *dest_obj;
+    char dest_storage[sizeof(struct sockaddr_in6)];
+    SOCKADDR *dest = (SOCKADDR*)dest_storage;
+    int dest_len;
+    DWORD bytes_sent;
+    WSABUF chunk;
+    int rc;
+    DWORD status;
+
+    if (!PyArg_ParseTuple(args, F_HANDLE "O" F_DWORD "O",
+                          &sock, &payload, &send_flags, &dest_obj))
+    {
+        return NULL;
+    }
+
+    /* Decode the destination before touching any Overlapped state. */
+    dest_len = parse_address(dest_obj, dest, (int)sizeof(dest_storage));
+    if (dest_len < 0) {
+        return NULL;
+    }
+
+    /* An Overlapped object can only ever drive a single operation. */
+    if (self->type != TYPE_NONE) {
+        PyErr_SetString(PyExc_ValueError, "operation already attempted");
+        return NULL;
+    }
+
+    if (!PyArg_Parse(payload, "y*", &self->user_buffer)) {
+        return NULL;
+    }
+
+#if SIZEOF_SIZE_T > SIZEOF_LONG
+    /* WSABUF.len is 32 bits wide; reject buffers it cannot describe. */
+    if (self->user_buffer.len > (Py_ssize_t)ULONG_MAX) {
+        PyBuffer_Release(&self->user_buffer);
+        PyErr_SetString(PyExc_ValueError, "buffer too large");
+        return NULL;
+    }
+#endif
+
+    chunk.buf = self->user_buffer.buf;
+    chunk.len = (DWORD)self->user_buffer.len;
+    self->handle = sock;
+    self->type = TYPE_WRITE_TO;
+
+    Py_BEGIN_ALLOW_THREADS
+    rc = WSASendTo((SOCKET)sock, &chunk, 1, &bytes_sent, send_flags,
+                   dest, dest_len, &self->overlapped, NULL);
+    Py_END_ALLOW_THREADS
+
+    if (rc == SOCKET_ERROR) {
+        status = WSAGetLastError();
+    }
+    else {
+        status = ERROR_SUCCESS;
+    }
+    self->error = status;
+
+    /* Immediate completion and a pending send are both successes here. */
+    if (status != ERROR_SUCCESS && status != ERROR_IO_PENDING) {
+        self->type = TYPE_NOT_STARTED;
+        return SetFromWindowsErr(status);
+    }
+    Py_RETURN_NONE;
+}
+
+
+
+PyDoc_STRVAR(
+    Overlapped_WSARecvFrom_doc,
+    "WSARecvFrom(handle, size, flags) -> "
+    "Overlapped[(message, address)]\n\n"
+    "Start overlapped receive over a connectionless (UDP) socket");
+
+/*
+ * Overlapped.WSARecvFrom(handle, size[, flags])
+ *
+ * Queue a receive of up to *size* bytes on socket *handle*.  The payload
+ * buffer and the sender-address storage are kept in self->read_from so they
+ * outlive this call; getresult() later trims the buffer to the number of
+ * bytes transferred and returns it together with the decoded source
+ * address.  Returns None when the receive completed immediately or was
+ * queued.
+ */
+static PyObject *
+Overlapped_WSARecvFrom(OverlappedObject *self, PyObject *args)
+{
+    HANDLE handle;
+    DWORD size;
+    DWORD flags = 0;
+    DWORD nread;
+    PyObject *buf;
+    WSABUF wsabuf;
+    int ret;
+    DWORD err;
+
+    if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD "|" F_DWORD,
+                          &handle, &size, &flags))
+    {
+        return NULL;
+    }
+
+    /* An Overlapped object can only ever drive a single operation. */
+    if (self->type != TYPE_NONE) {
+        PyErr_SetString(PyExc_ValueError, "operation already attempted");
+        return NULL;
+    }
+
+#if SIZEOF_SIZE_T <= SIZEOF_LONG
+    /* Keep the request within what a bytes object can represent. */
+    size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX);
+#endif
+    /* Allocate at least one byte; presumably so that size == 0 still yields
+       a fresh, writable bytes object rather than a shared empty one. */
+    buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1));
+    if (buf == NULL) {
+        return NULL;
+    }
+
+    wsabuf.len = size;
+    wsabuf.buf = PyBytes_AS_STRING(buf);
+
+    self->type = TYPE_READ_FROM;
+    self->handle = handle;
+    /* The Overlapped object now owns buf (it is visited by traverse). */
+    self->read_from.allocated_buffer = buf;
+    /* WSARecvFrom stores the sender address here; the length is an in/out
+       argument that must start as the full size of the storage. */
+    memset(&self->read_from.address, 0, sizeof(self->read_from.address));
+    self->read_from.address_length = sizeof(self->read_from.address);
+
+    Py_BEGIN_ALLOW_THREADS
+    ret = WSARecvFrom((SOCKET)handle, &wsabuf, 1, &nread, &flags,
+                      (SOCKADDR*)&self->read_from.address,
+                      &self->read_from.address_length,
+                      &self->overlapped, NULL);
+    Py_END_ALLOW_THREADS
+
+    self->error = err = (ret == SOCKET_ERROR ? WSAGetLastError() :
+                                               ERROR_SUCCESS);
+
+    switch(err) {
+    case ERROR_BROKEN_PIPE:
+        /* Mark the OVERLAPPED as completed before raising. */
+        mark_as_completed(&self->overlapped);
+        return SetFromWindowsErr(err);
+    case ERROR_SUCCESS:
+    case ERROR_MORE_DATA:
+    case ERROR_IO_PENDING:
+        /* Completed or queued: the (message, address) result is produced
+           by getresult(). */
+        Py_RETURN_NONE;
+    default:
+        self->type = TYPE_NOT_STARTED;
+        return SetFromWindowsErr(err);
+    }
+}
+
 
 static PyMethodDef Overlapped_methods[] = {
     {"getresult", (PyCFunction) Overlapped_getresult,
@@ -1399,6 +1734,10 @@ static PyMethodDef Overlapped_methods[] = {
      METH_VARARGS, Overlapped_TransmitFile_doc},
     {"ConnectNamedPipe", (PyCFunction) Overlapped_ConnectNamedPipe,
      METH_VARARGS, Overlapped_ConnectNamedPipe_doc},
+    {"WSARecvFrom", (PyCFunction) Overlapped_WSARecvFrom,
+     METH_VARARGS, Overlapped_WSARecvFrom_doc },
+    {"WSASendTo", (PyCFunction) Overlapped_WSASendTo,
+     METH_VARARGS, Overlapped_WSASendTo_doc },
     {NULL}
 };
 
@@ -1484,9 +1823,10 @@ static PyMethodDef overlapped_functions[] = {
      METH_VARARGS, SetEvent_doc},
     {"ResetEvent", overlapped_ResetEvent,
      METH_VARARGS, ResetEvent_doc},
-    {"ConnectPipe",
-     (PyCFunction) ConnectPipe,
+    {"ConnectPipe", overlapped_ConnectPipe,
      METH_VARARGS, ConnectPipe_doc},
+    {"WSAConnect", overlapped_WSAConnect,
+     METH_VARARGS, WSAConnect_doc},
     {NULL}
 };
 



More information about the Python-checkins mailing list