[Python-checkins] bpo-46805: Add low level UDP socket functions to asyncio (GH-31455)

asvetlov webhook-mailer at python.org
Sun Mar 13 12:42:41 EDT 2022


https://github.com/python/cpython/commit/9f04ee569cebb8b4c6f04bea95d91a19c5403806
commit: 9f04ee569cebb8b4c6f04bea95d91a19c5403806
branch: main
author: Alex Grönholm <alex.gronholm at nextday.fi>
committer: asvetlov <andrew.svetlov at gmail.com>
date: 2022-03-13T18:42:29+02:00
summary:

bpo-46805: Add low level UDP socket functions to asyncio (GH-31455)

files:
A Misc/NEWS.d/next/Library/2022-02-20-23-03-32.bpo-46805.HZ8xWG.rst
M Doc/library/asyncio-eventloop.rst
M Doc/library/asyncio-llapi-index.rst
M Doc/whatsnew/3.11.rst
M Lib/asyncio/events.py
M Lib/asyncio/proactor_events.py
M Lib/asyncio/selector_events.py
M Lib/asyncio/windows_events.py
M Lib/test/test_asyncio/test_sock_lowlevel.py
M Lib/test/test_asyncio/utils.py
M Modules/clinic/overlapped.c.h
M Modules/overlapped.c

diff --git a/Doc/library/asyncio-eventloop.rst b/Doc/library/asyncio-eventloop.rst
index 1b762a75aed0a..4776853b5a56d 100644
--- a/Doc/library/asyncio-eventloop.rst
+++ b/Doc/library/asyncio-eventloop.rst
@@ -922,6 +922,29 @@ convenient.
 
    .. versionadded:: 3.7
 
+.. coroutinemethod:: loop.sock_recvfrom(sock, bufsize)
+
+   Receive a datagram of up to *bufsize* bytes from *sock*.  Asynchronous version of
+   :meth:`socket.recvfrom() <socket.socket.recvfrom>`.
+
+   Return a tuple of (received data, remote address).
+
+   *sock* must be a non-blocking socket.
+
+   .. versionadded:: 3.11
+
+.. coroutinemethod:: loop.sock_recvfrom_into(sock, buf, nbytes=0)
+
+   Receive a datagram of up to *nbytes* bytes (all of *buf* if *nbytes* is
+   0) from *sock* into *buf*.  Asynchronous version of
+   :meth:`socket.recvfrom_into() <socket.socket.recvfrom_into>`.
+
+   Return a tuple of (number of bytes received, remote address).
+
+   *sock* must be a non-blocking socket.
+
+   .. versionadded:: 3.11
+
 .. coroutinemethod:: loop.sock_sendall(sock, data)
 
    Send *data* to the *sock* socket. Asynchronous version of
@@ -940,6 +963,18 @@ convenient.
       method, before Python 3.7 it returned a :class:`Future`.
       Since Python 3.7, this is an ``async def`` method.
 
+.. coroutinemethod:: loop.sock_sendto(sock, data, address)
+
+   Send a datagram from *sock* to *address*.
+   Asynchronous version of
+   :meth:`socket.sendto() <socket.socket.sendto>`.
+
+   Return the number of bytes sent.
+
+   *sock* must be a non-blocking socket.
+
+   .. versionadded:: 3.11
+
 .. coroutinemethod:: loop.sock_connect(sock, address)
 
    Connect *sock* to a remote socket at *address*.
diff --git a/Doc/library/asyncio-llapi-index.rst b/Doc/library/asyncio-llapi-index.rst
index 0ab322af6dc72..69b550e43f5aa 100644
--- a/Doc/library/asyncio-llapi-index.rst
+++ b/Doc/library/asyncio-llapi-index.rst
@@ -189,9 +189,18 @@ See also the main documentation section about the
     * - ``await`` :meth:`loop.sock_recv_into`
       - Receive data from the :class:`~socket.socket` into a buffer.
 
+    * - ``await`` :meth:`loop.sock_recvfrom`
+      - Receive a datagram from the :class:`~socket.socket`.
+
+    * - ``await`` :meth:`loop.sock_recvfrom_into`
+      - Receive a datagram from the :class:`~socket.socket` into a buffer.
+
     * - ``await`` :meth:`loop.sock_sendall`
       - Send data to the :class:`~socket.socket`.
 
+    * - ``await`` :meth:`loop.sock_sendto`
+      - Send a datagram via the :class:`~socket.socket` to the given address.
+
     * - ``await`` :meth:`loop.sock_connect`
       - Connect the :class:`~socket.socket`.
 
diff --git a/Doc/whatsnew/3.11.rst b/Doc/whatsnew/3.11.rst
index 8ab6854663030..9fbf46791c27d 100644
--- a/Doc/whatsnew/3.11.rst
+++ b/Doc/whatsnew/3.11.rst
@@ -226,6 +226,15 @@ New Modules
 Improved Modules
 ================
 
+asyncio
+-------
+
+* Add raw datagram socket functions to the event loop:
+  :meth:`~asyncio.AbstractEventLoop.sock_sendto`,
+  :meth:`~asyncio.AbstractEventLoop.sock_recvfrom` and
+  :meth:`~asyncio.AbstractEventLoop.sock_recvfrom_into`.
+  (Contributed by Alex Grönholm in :issue:`46805`.)
+
 fractions
 ---------
 
diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py
index 1d305e3ddff1c..e682a192a887f 100644
--- a/Lib/asyncio/events.py
+++ b/Lib/asyncio/events.py
@@ -546,9 +546,18 @@ async def sock_recv(self, sock, nbytes):
     async def sock_recv_into(self, sock, buf):
         raise NotImplementedError
 
+    async def sock_recvfrom(self, sock, bufsize):
+        raise NotImplementedError
+
+    async def sock_recvfrom_into(self, sock, buf, nbytes=0):
+        raise NotImplementedError
+
     async def sock_sendall(self, sock, data):
         raise NotImplementedError
 
+    async def sock_sendto(self, sock, data, address):
+        raise NotImplementedError
+
     async def sock_connect(self, sock, address):
         raise NotImplementedError
 
diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py
index 43d5e70b79cac..087f0950d118b 100644
--- a/Lib/asyncio/proactor_events.py
+++ b/Lib/asyncio/proactor_events.py
@@ -700,9 +700,28 @@ async def sock_recv(self, sock, n):
     async def sock_recv_into(self, sock, buf):
         return await self._proactor.recv_into(sock, buf)
 
+    async def sock_recvfrom(self, sock, bufsize):
+        return await self._proactor.recvfrom(sock, bufsize)
+
+    async def sock_recvfrom_into(self, sock, buf, nbytes=0):
+        # Validate nbytes like socket.recvfrom_into() does: the overlapped
+        # receive writes up to nbytes bytes directly into buf.
+        if nbytes < 0:
+            raise ValueError("negative buffersize in recvfrom_into")
+        if not nbytes:
+            nbytes = len(buf)
+        elif nbytes > len(buf):
+            raise ValueError(
+                "nbytes is greater than the length of the buffer")
+
+        return await self._proactor.recvfrom_into(sock, buf, nbytes)
+
     async def sock_sendall(self, sock, data):
         return await self._proactor.send(sock, data)
 
+    async def sock_sendto(self, sock, data, address):
+        return await self._proactor.sendto(sock, data, 0, address)
+
     async def sock_connect(self, sock, address):
         return await self._proactor.connect(sock, address)
 
diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py
index c3c2ec12a7787..bfd8019da606e 100644
--- a/Lib/asyncio/selector_events.py
+++ b/Lib/asyncio/selector_events.py
@@ -434,6 +434,88 @@ def _sock_recv_into(self, fut, sock, buf):
         else:
             fut.set_result(nbytes)
 
+    async def sock_recvfrom(self, sock, bufsize):
+        """Receive a datagram from a datagram socket.
+
+        The return value is a tuple of (bytes, address) representing the
+        datagram received and the address it came from.
+        The maximum amount of data to be received at once is specified by
+        bufsize.
+        """
+        base_events._check_ssl_socket(sock)
+        if self._debug and sock.gettimeout() != 0:
+            raise ValueError("the socket must be non-blocking")
+        try:
+            return sock.recvfrom(bufsize)
+        except (BlockingIOError, InterruptedError):
+            pass
+        fut = self.create_future()
+        fd = sock.fileno()
+        self._ensure_fd_no_transport(fd)
+        handle = self._add_reader(fd, self._sock_recvfrom, fut, sock, bufsize)
+        fut.add_done_callback(
+            functools.partial(self._sock_read_done, fd, handle=handle))
+        return await fut
+
+    def _sock_recvfrom(self, fut, sock, bufsize):
+        # _sock_recvfrom() can add itself as an I/O callback if the operation
+        # can't be done immediately. Don't use it directly, call
+        # sock_recvfrom().
+        if fut.done():
+            return
+        try:
+            result = sock.recvfrom(bufsize)
+        except (BlockingIOError, InterruptedError):
+            return  # try again next time
+        except (SystemExit, KeyboardInterrupt):
+            raise
+        except BaseException as exc:
+            fut.set_exception(exc)
+        else:
+            fut.set_result(result)
+
+    async def sock_recvfrom_into(self, sock, buf, nbytes=0):
+        """Receive a datagram from a datagram socket into *buf*.
+
+        At most *nbytes* bytes (len(buf) if nbytes is 0) are written to buf.
+        The return value is a tuple of (number of bytes written, address).
+        """
+        base_events._check_ssl_socket(sock)
+        if self._debug and sock.gettimeout() != 0:
+            raise ValueError("the socket must be non-blocking")
+        if not nbytes:
+            nbytes = len(buf)
+
+        try:
+            return sock.recvfrom_into(buf, nbytes)
+        except (BlockingIOError, InterruptedError):
+            pass
+        fut = self.create_future()
+        fd = sock.fileno()
+        self._ensure_fd_no_transport(fd)
+        handle = self._add_reader(fd, self._sock_recvfrom_into, fut, sock, buf,
+                                  nbytes)
+        fut.add_done_callback(
+            functools.partial(self._sock_read_done, fd, handle=handle))
+        return await fut
+
+    def _sock_recvfrom_into(self, fut, sock, buf, bufsize):
+        # _sock_recvfrom_into() can add itself as an I/O callback if the
+        # operation can't be done immediately. Don't use it directly, call
+        # sock_recvfrom_into().
+        if fut.done():
+            return
+        try:
+            result = sock.recvfrom_into(buf, bufsize)
+        except (BlockingIOError, InterruptedError):
+            return  # try again next time
+        except (SystemExit, KeyboardInterrupt):
+            raise
+        except BaseException as exc:
+            fut.set_exception(exc)
+        else:
+            fut.set_result(result)
+
     async def sock_sendall(self, sock, data):
         """Send data to the socket.
 
@@ -487,6 +569,48 @@ def _sock_sendall(self, fut, sock, view, pos):
         else:
             pos[0] = start
 
+    async def sock_sendto(self, sock, data, address):
+        """Send a datagram from the socket to the given address.
+
+        Unlike sock_sendall(), this performs a single sendto() call and
+        does not loop over partial writes.  If the call would block, wait
+        until the socket becomes writable and try again.
+
+        Return the number of bytes sent, as reported by socket.sendto().
+        """
+        base_events._check_ssl_socket(sock)
+        if self._debug and sock.gettimeout() != 0:
+            raise ValueError("the socket must be non-blocking")
+        try:
+            return sock.sendto(data, address)
+        except (BlockingIOError, InterruptedError):
+            pass
+
+        fut = self.create_future()
+        fd = sock.fileno()
+        self._ensure_fd_no_transport(fd)
+        # not writable yet: retry sendto() from a writer callback
+        handle = self._add_writer(fd, self._sock_sendto, fut, sock, data,
+                                  address)
+        fut.add_done_callback(
+            functools.partial(self._sock_write_done, fd, handle=handle))
+        return await fut
+
+    def _sock_sendto(self, fut, sock, data, address):
+        if fut.done():
+            # Future cancellation can be scheduled on previous loop iteration
+            return
+        try:
+            n = sock.sendto(data, address)
+        except (BlockingIOError, InterruptedError):
+            return
+        except (SystemExit, KeyboardInterrupt):
+            raise
+        except BaseException as exc:
+            fut.set_exception(exc)
+        else:
+            fut.set_result(n)
+
     async def sock_connect(self, sock, address):
         """Connect to a remote socket at address.
 
diff --git a/Lib/asyncio/windows_events.py b/Lib/asyncio/windows_events.py
index 0d9a07ef4772e..90b259cbafead 100644
--- a/Lib/asyncio/windows_events.py
+++ b/Lib/asyncio/windows_events.py
@@ -512,6 +512,36 @@ def finish_recv(trans, key, ov):
 
         return self._register(ov, conn, finish_recv)
 
+    def recvfrom_into(self, conn, buf, nbytes=0, flags=0):
+        # Receive a datagram into *buf*; the result is a
+        # (number of bytes, address) tuple.  WSARecvFromInto() takes
+        # the size before the flags, so both are passed explicitly.
+        if nbytes < 0:
+            raise ValueError("negative buffersize in recvfrom_into")
+        if not nbytes:
+            nbytes = len(buf)
+        elif nbytes > len(buf):
+            raise ValueError(
+                "nbytes is greater than the length of the buffer")
+        self._register_with_iocp(conn)
+        ov = _overlapped.Overlapped(NULL)
+        try:
+            ov.WSARecvFromInto(conn.fileno(), buf, nbytes, flags)
+        except BrokenPipeError:
+            return self._result((0, None))
+
+        def finish_recv(trans, key, ov):
+            try:
+                return ov.getresult()
+            except OSError as exc:
+                if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
+                                    _overlapped.ERROR_OPERATION_ABORTED):
+                    raise ConnectionResetError(*exc.args)
+                else:
+                    raise
+
+        return self._register(ov, conn, finish_recv)
+
     def sendto(self, conn, buf, flags=0, addr=None):
         self._register_with_iocp(conn)
         ov = _overlapped.Overlapped(NULL)
diff --git a/Lib/test/test_asyncio/test_sock_lowlevel.py b/Lib/test/test_asyncio/test_sock_lowlevel.py
index 14001a4a5001f..db47616d18343 100644
--- a/Lib/test/test_asyncio/test_sock_lowlevel.py
+++ b/Lib/test/test_asyncio/test_sock_lowlevel.py
@@ -5,11 +5,12 @@
 
 from asyncio import proactor_events
 from itertools import cycle, islice
+from unittest.mock import Mock
 from test.test_asyncio import utils as test_utils
 from test import support
 from test.support import socket_helper
 
 
 def tearDownModule():
     asyncio.set_event_loop_policy(None)
 
@@ -380,6 +381,81 @@ def test_huge_content_recvinto(self):
             self.loop.run_until_complete(
                 self._basetest_huge_content_recvinto(httpd.address))
 
+    async def _basetest_datagram_recvfrom(self, server_address):
+        # Happy path, sock.sendto() returns immediately
+        data = b'\x01' * 4096
+        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
+            sock.setblocking(False)
+            await self.loop.sock_sendto(sock, data, server_address)
+            received_data, from_addr = await self.loop.sock_recvfrom(
+                sock, 4096)
+            self.assertEqual(received_data, data)
+            self.assertEqual(from_addr, server_address)
+
+    def test_recvfrom(self):
+        with test_utils.run_udp_echo_server() as server_address:
+            self.loop.run_until_complete(
+                self._basetest_datagram_recvfrom(server_address))
+
+    async def _basetest_datagram_recvfrom_into(self, server_address):
+        # Happy path, sock.sendto() returns immediately
+        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
+            sock.setblocking(False)
+
+            buf = bytearray(4096)
+            data = b'\x01' * 4096
+            await self.loop.sock_sendto(sock, data, server_address)
+            num_bytes, from_addr = await self.loop.sock_recvfrom_into(
+                sock, buf)
+            self.assertEqual(num_bytes, 4096)
+            self.assertEqual(buf, data)
+            self.assertEqual(from_addr, server_address)
+
+            buf = bytearray(8192)
+            await self.loop.sock_sendto(sock, data, server_address)
+            num_bytes, from_addr = await self.loop.sock_recvfrom_into(
+                sock, buf, 4096)
+            self.assertEqual(num_bytes, 4096)
+            self.assertEqual(buf[:4096], data)
+            # nothing may be written past the first nbytes bytes
+            self.assertEqual(buf[4096:], bytearray(4096))
+            self.assertEqual(from_addr, server_address)
+
+    def test_recvfrom_into(self):
+        with test_utils.run_udp_echo_server() as server_address:
+            self.loop.run_until_complete(
+                self._basetest_datagram_recvfrom_into(server_address))
+
+    async def _basetest_datagram_sendto_blocking(self, server_address):
+        # Sad path, sock.sendto() raises BlockingIOError
+        # This involves patching sock.sendto() to raise BlockingIOError but
+        # sendto() is not used by the proactor event loop
+        data = b'\x01' * 4096
+        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
+            sock.setblocking(False)
+            mock_sock = Mock(sock)
+            mock_sock.gettimeout = sock.gettimeout
+            mock_sock.sendto.configure_mock(side_effect=BlockingIOError)
+            mock_sock.fileno = sock.fileno
+            self.loop.call_soon(
+                lambda: setattr(mock_sock, 'sendto', sock.sendto)
+            )
+            await self.loop.sock_sendto(mock_sock, data, server_address)
+
+            received_data, from_addr = await self.loop.sock_recvfrom(
+                sock, 4096)
+            self.assertEqual(received_data, data)
+            self.assertEqual(from_addr, server_address)
+
+    def test_sendto_blocking(self):
+        if (sys.platform == 'win32' and
+                isinstance(self.loop, asyncio.ProactorEventLoop)):
+            raise unittest.SkipTest('Not relevant to ProactorEventLoop')
+
+        with test_utils.run_udp_echo_server() as server_address:
+            self.loop.run_until_complete(
+                self._basetest_datagram_sendto_blocking(server_address))
+
     @socket_helper.skip_unless_bind_unix_socket
     def test_unix_sock_client_ops(self):
         with test_utils.run_test_unix_server() as httpd:
diff --git a/Lib/test/test_asyncio/utils.py b/Lib/test/test_asyncio/utils.py
index 0b9cde6878f37..c32494d40ccea 100644
--- a/Lib/test/test_asyncio/utils.py
+++ b/Lib/test/test_asyncio/utils.py
@@ -281,6 +281,39 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
                                 server_ssl_cls=SSLWSGIServer)
 
 
+def echo_datagrams(sock):
+    """Echo every datagram received on *sock* back to its sender.
+
+    A datagram containing exactly b'STOP' closes *sock* and returns.
+    """
+    while True:
+        data, addr = sock.recvfrom(4096)
+        if data == b'STOP':
+            sock.close()
+            break
+        sock.sendto(data, addr)
+
+
+@contextlib.contextmanager
+def run_udp_echo_server(*, host='127.0.0.1', port=0):
+    """Run a UDP echo server in a background thread.
+
+    Yield the server's bound address; on exit, send b'STOP' to it and
+    wait for the thread to finish.
+    """
+    addr_info = socket.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
+    family, socktype, proto, _, sockaddr = addr_info[0]
+    sock = socket.socket(family, socktype, proto)
+    sock.bind(sockaddr)
+    thread = threading.Thread(target=echo_datagrams, args=(sock,))
+    thread.start()
+    try:
+        yield sock.getsockname()
+    finally:
+        sock.sendto(b'STOP', sock.getsockname())
+        thread.join()
+
+
 def make_test_protocol(base):
     dct = {}
     for name in dir(base):
diff --git a/Misc/NEWS.d/next/Library/2022-02-20-23-03-32.bpo-46805.HZ8xWG.rst b/Misc/NEWS.d/next/Library/2022-02-20-23-03-32.bpo-46805.HZ8xWG.rst
new file mode 100644
index 0000000000000..3c877d5498cd6
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2022-02-20-23-03-32.bpo-46805.HZ8xWG.rst
@@ -0,0 +1,4 @@
+Added raw datagram socket functions for asyncio:
+:meth:`~asyncio.AbstractEventLoop.sock_sendto`,
+:meth:`~asyncio.AbstractEventLoop.sock_recvfrom` and
+:meth:`~asyncio.AbstractEventLoop.sock_recvfrom_into`.
diff --git a/Modules/clinic/overlapped.c.h b/Modules/clinic/overlapped.c.h
index efecd9028b776..16d6013ef2f2e 100644
--- a/Modules/clinic/overlapped.c.h
+++ b/Modules/clinic/overlapped.c.h
@@ -905,4 +905,42 @@ _overlapped_Overlapped_WSARecvFrom(OverlappedObject *self, PyObject *const *args
 exit:
     return return_value;
 }
-/*[clinic end generated code: output=ee2ec2f93c8d334b input=a9049054013a1b77]*/
+
+PyDoc_STRVAR(_overlapped_Overlapped_WSARecvFromInto__doc__,
+"WSARecvFromInto($self, handle, buf, size, flags=0, /)\n"
+"--\n"
+"\n"
+"Start overlapped receive.");
+
+#define _OVERLAPPED_OVERLAPPED_WSARECVFROMINTO_METHODDEF    \
+    {"WSARecvFromInto", (PyCFunction)(void(*)(void))_overlapped_Overlapped_WSARecvFromInto, METH_FASTCALL, _overlapped_Overlapped_WSARecvFromInto__doc__},
+
+static PyObject *
+_overlapped_Overlapped_WSARecvFromInto_impl(OverlappedObject *self,
+                                            HANDLE handle, Py_buffer *bufobj,
+                                            DWORD size, DWORD flags);
+
+static PyObject *
+_overlapped_Overlapped_WSARecvFromInto(OverlappedObject *self, PyObject *const *args, Py_ssize_t nargs)
+{
+    PyObject *return_value = NULL;
+    HANDLE handle;
+    Py_buffer bufobj = {NULL, NULL};
+    DWORD size;
+    DWORD flags = 0;
+
+    if (!_PyArg_ParseStack(args, nargs, ""F_HANDLE"y*k|k:WSARecvFromInto",
+        &handle, &bufobj, &size, &flags)) {
+        goto exit;
+    }
+    return_value = _overlapped_Overlapped_WSARecvFromInto_impl(self, handle, &bufobj, size, flags);
+
+exit:
+    /* Cleanup for bufobj */
+    if (bufobj.obj) {
+       PyBuffer_Release(&bufobj);
+    }
+
+    return return_value;
+}
+/*[clinic end generated code: output=5c9b17890ef29d52 input=a9049054013a1b77]*/
diff --git a/Modules/overlapped.c b/Modules/overlapped.c
index 2ba48c8b845f5..ab9a2f0ce26f6 100644
--- a/Modules/overlapped.c
+++ b/Modules/overlapped.c
@@ -64,7 +64,7 @@ class _overlapped.Overlapped "OverlappedObject *" "&OverlappedType"
 enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_READINTO, TYPE_WRITE,
       TYPE_ACCEPT, TYPE_CONNECT, TYPE_DISCONNECT, TYPE_CONNECT_NAMED_PIPE,
       TYPE_WAIT_NAMED_PIPE_AND_CONNECT, TYPE_TRANSMIT_FILE, TYPE_READ_FROM,
-      TYPE_WRITE_TO};
+      TYPE_WRITE_TO, TYPE_READ_FROM_INTO};
 
 typedef struct {
     PyObject_HEAD
@@ -91,6 +91,17 @@ typedef struct {
             struct sockaddr_in6 address;
             int address_length;
         } read_from;
+
+        /* Data used for reading from a connectionless socket:
+           TYPE_READ_FROM_INTO */
+        struct {
+            // A (number of bytes read, (host, port)) tuple
+            PyObject* result;
+            /* Buffer passed by the user, owned while the receive is pending */
+            Py_buffer user_buffer;
+            struct sockaddr_in6 address;
+            int address_length;
+        } read_from_into;
     };
 } OverlappedObject;
 
@@ -662,6 +673,17 @@ Overlapped_clear(OverlappedObject *self)
             }
             break;
         }
+        case TYPE_READ_FROM_INTO: {
+            if (self->read_from_into.result) {
+                // We've received a message, free the result tuple.
+                Py_CLEAR(self->read_from_into.result);
+            }
+            // Drop the buffer export taken over by WSARecvFromInto().
+            if (self->read_from_into.user_buffer.obj) {
+                PyBuffer_Release(&self->read_from_into.user_buffer);
+            }
+            break;
+        }
         case TYPE_WRITE:
         case TYPE_WRITE_TO:
         case TYPE_READINTO: {
@@ -866,6 +888,11 @@ _overlapped_Overlapped_getresult_impl(OverlappedObject *self, BOOL wait)
             {
                 break;
             }
+            else if (self->type == TYPE_READ_FROM_INTO &&
+                     self->read_from_into.result != NULL)
+            {
+                break;
+            }
             /* fall through */
         default:
             return SetFromWindowsErr(err);
@@ -914,6 +941,38 @@ _overlapped_Overlapped_getresult_impl(OverlappedObject *self, BOOL wait)
 
             Py_INCREF(self->read_from.result);
             return self->read_from.result;
+        case TYPE_READ_FROM_INTO: {
+            PyObject *nread, *result;
+
+            // unparse the address
+            addr = unparse_address((SOCKADDR*)&self->read_from_into.address,
+                self->read_from_into.address_length);
+            if (addr == NULL) {
+                return NULL;
+            }
+
+            // first item: number of bytes read
+            nread = PyLong_FromUnsignedLong((unsigned long)transferred);
+            if (nread == NULL) {
+                Py_DECREF(addr);
+                return NULL;
+            }
+
+            // The result is a two item tuple: (number of bytes read, address)
+            result = PyTuple_Pack(2, nread, addr);
+            Py_DECREF(nread);
+            Py_DECREF(addr);
+            if (result == NULL) {
+                return NULL;
+            }
+
+            // Store the tuple (the early error check above tests it);
+            // Py_XSETREF drops a tuple left by an earlier call instead of
+            // leaking it.
+            Py_XSETREF(self->read_from_into.result, result);
+            Py_INCREF(result);
+            return result;
+        }
         default:
             return PyLong_FromUnsignedLong((unsigned long) transferred);
     }
@@ -1053,6 +1112,7 @@ do_WSARecv(OverlappedObject *self, HANDLE handle,
     }
 }
 
+
 /*[clinic input]
 _overlapped.Overlapped.WSARecv
 
@@ -1617,6 +1677,13 @@ Overlapped_traverse(OverlappedObject *self, visitproc visit, void *arg)
     case TYPE_READ_FROM:
         Py_VISIT(self->read_from.result);
         Py_VISIT(self->read_from.allocated_buffer);
+        break;
+    case TYPE_READ_FROM_INTO:
+        Py_VISIT(self->read_from_into.result);
+        if (self->read_from_into.user_buffer.obj) {
+            Py_VISIT(self->read_from_into.user_buffer.obj);
+        }
+        break;
     }
     return 0;
 }
@@ -1766,8 +1833,8 @@ _overlapped_Overlapped_WSARecvFrom_impl(OverlappedObject *self,
                                         DWORD flags)
 /*[clinic end generated code: output=13832a2025b86860 input=1b2663fa130e0286]*/
 {
-    DWORD nread;
     PyObject *buf;
+    DWORD nread;
     WSABUF wsabuf;
     int ret;
     DWORD err;
@@ -1785,8 +1840,8 @@ _overlapped_Overlapped_WSARecvFrom_impl(OverlappedObject *self,
         return NULL;
     }
 
-    wsabuf.len = size;
     wsabuf.buf = PyBytes_AS_STRING(buf);
+    wsabuf.len = size;
 
     self->type = TYPE_READ_FROM;
     self->handle = handle;
@@ -1802,8 +1869,101 @@ _overlapped_Overlapped_WSARecvFrom_impl(OverlappedObject *self,
     Py_END_ALLOW_THREADS
 
     self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS);
+    switch (err) {
+    case ERROR_BROKEN_PIPE:
+        mark_as_completed(&self->overlapped);
+        return SetFromWindowsErr(err);
+    case ERROR_SUCCESS:
+    case ERROR_MORE_DATA:
+    case ERROR_IO_PENDING:
+        Py_RETURN_NONE;
+    default:
+        self->type = TYPE_NOT_STARTED;
+        return SetFromWindowsErr(err);
+    }
+}
 
-    switch(err) {
+
+/*[clinic input]
+_overlapped.Overlapped.WSARecvFromInto
+
+    handle: HANDLE
+    buf as bufobj: Py_buffer
+    size: DWORD
+    flags: DWORD = 0
+    /
+
+Start overlapped receive.
+[clinic start generated code]*/
+
+static PyObject *
+_overlapped_Overlapped_WSARecvFromInto_impl(OverlappedObject *self,
+                                            HANDLE handle, Py_buffer *bufobj,
+                                            DWORD size, DWORD flags)
+/*[clinic end generated code: output=30c7ea171a691757 input=4be4b08d03531e76]*/
+{
+    DWORD nread;
+    WSABUF wsabuf;
+    int ret;
+    DWORD err;
+
+    if (self->type != TYPE_NONE) {
+        PyErr_SetString(PyExc_ValueError, "operation already attempted");
+        return NULL;
+    }
+
+    /* The "y*" converter also accepts read-only buffers such as bytes,
+       which must never be written into. */
+    if (bufobj->readonly) {
+        PyErr_SetString(PyExc_TypeError, "buffer must be writable");
+        return NULL;
+    }
+
+#if SIZEOF_SIZE_T > SIZEOF_LONG
+    if (bufobj->len > (Py_ssize_t)ULONG_MAX) {
+        PyErr_SetString(PyExc_ValueError, "buffer too large");
+        return NULL;
+    }
+#endif
+
+    /* WSARecvFrom() writes up to size bytes: never past the buffer end. */
+    if ((size_t)size > (size_t)bufobj->len) {
+        PyErr_SetString(PyExc_ValueError,
+                        "size is greater than the length of the buffer");
+        return NULL;
+    }
+
+    /* The receive completes asynchronously, so this object must own the
+       buffer export until Overlapped_clear().  Take it over from the
+       Argument Clinic wrapper, which releases bufobj when we return. */
+    self->type = TYPE_READ_FROM_INTO;
+    self->handle = handle;
+    self->read_from_into.result = NULL;
+    self->read_from_into.user_buffer = *bufobj;
+    memset(bufobj, 0, sizeof(*bufobj));
+    memset(&self->read_from_into.address, 0, sizeof(self->read_from_into.address));
+    self->read_from_into.address_length = sizeof(self->read_from_into.address);
+
+    wsabuf.buf = self->read_from_into.user_buffer.buf;
+    wsabuf.len = size;
+
+    Py_BEGIN_ALLOW_THREADS
+    ret = WSARecvFrom((SOCKET)handle, &wsabuf, 1, &nread, &flags,
+                      (SOCKADDR*)&self->read_from_into.address,
+                      &self->read_from_into.address_length,
+                      &self->overlapped, NULL);
+    Py_END_ALLOW_THREADS
+
+    self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS);
+    if (err != ERROR_SUCCESS && err != ERROR_MORE_DATA &&
+        err != ERROR_IO_PENDING && err != ERROR_BROKEN_PIPE)
+    {
+        /* The receive never started and the default branch below (shared
+           with WSARecvFrom's) resets self->type, so release the buffer
+           here; PyBuffer_Release() clears .obj, keeping clear() safe. */
+        PyBuffer_Release(&self->read_from_into.user_buffer);
+    }
+    switch (err) {
     case ERROR_BROKEN_PIPE:
         mark_as_completed(&self->overlapped);
         return SetFromWindowsErr(err);
@@ -1817,6 +1977,7 @@ _overlapped_Overlapped_WSARecvFrom_impl(OverlappedObject *self,
     }
 }
 
+
 #include "clinic/overlapped.c.h"
 
 static PyMethodDef Overlapped_methods[] = {
@@ -1826,6 +1987,8 @@ static PyMethodDef Overlapped_methods[] = {
     _OVERLAPPED_OVERLAPPED_READFILEINTO_METHODDEF
     _OVERLAPPED_OVERLAPPED_WSARECV_METHODDEF
     _OVERLAPPED_OVERLAPPED_WSARECVINTO_METHODDEF
+    _OVERLAPPED_OVERLAPPED_WSARECVFROM_METHODDEF
+    _OVERLAPPED_OVERLAPPED_WSARECVFROMINTO_METHODDEF
     _OVERLAPPED_OVERLAPPED_WRITEFILE_METHODDEF
     _OVERLAPPED_OVERLAPPED_WSASEND_METHODDEF
     _OVERLAPPED_OVERLAPPED_ACCEPTEX_METHODDEF



More information about the Python-checkins mailing list