[Python-checkins] cpython: fix issue #17552: add socket.sendfile() method allowing to send a file over a

giampaolo.rodola python-checkins at python.org
Wed Jun 11 03:54:44 CEST 2014


http://hg.python.org/cpython/rev/001895c39fea
changeset:   91122:001895c39fea
user:        Giampaolo Rodola' <g.rodola at gmail.com>
date:        Wed Jun 11 03:54:30 2014 +0200
summary:
  fix issue #17552: add socket.sendfile() method allowing to send a file over a socket by using high-performance os.sendfile() on UNIX. Patch by Giampaolo Rodola'.

files:
  Doc/library/os.rst      |    4 +
  Doc/library/socket.rst  |   15 +
  Doc/library/ssl.rst     |    3 +
  Doc/whatsnew/3.5.rst    |   11 +-
  Lib/socket.py           |  148 +++++++++++++++-
  Lib/ssl.py              |   10 +
  Lib/test/test_socket.py |  273 ++++++++++++++++++++++++++++
  Lib/test/test_ssl.py    |   17 +
  Misc/NEWS               |    4 +
  9 files changed, 483 insertions(+), 2 deletions(-)


diff --git a/Doc/library/os.rst b/Doc/library/os.rst
--- a/Doc/library/os.rst
+++ b/Doc/library/os.rst
@@ -1092,6 +1092,10 @@
 
    Availability: Unix.
 
+   .. note::
+
+      For a higher-level version of this see :meth:`socket.socket.sendfile`.
+
    .. versionadded:: 3.3
 
 
diff --git a/Doc/library/socket.rst b/Doc/library/socket.rst
--- a/Doc/library/socket.rst
+++ b/Doc/library/socket.rst
@@ -1148,6 +1148,21 @@
 
    .. versionadded:: 3.3
 
+.. method:: socket.sendfile(file, offset=0, count=None)
+
+   Send a file until EOF is reached by using high-performance
+   :func:`os.sendfile` and return the total number of bytes which were sent.
+   *file* must be a regular file object opened in binary mode. If
+   :func:`os.sendfile` is not available (e.g. Windows) or *file* is not a
+   regular file, :meth:`send` will be used instead. *offset* tells from where
+   to start reading the file. If specified, *count* is the total number of
+   bytes to transmit as opposed to sending the file until EOF is reached. File
+   position is updated on return or also in case of error in which case
+   :meth:`file.tell() <io.IOBase.tell>` can be used to figure out the number of
+   bytes which were sent. The socket must be of :const:`SOCK_STREAM` type.
+   Non-blocking sockets are not supported.
+
+   .. versionadded:: 3.5
 
 .. method:: socket.set_inheritable(inheritable)
 
diff --git a/Doc/library/ssl.rst b/Doc/library/ssl.rst
--- a/Doc/library/ssl.rst
+++ b/Doc/library/ssl.rst
@@ -789,6 +789,9 @@
   (but passing a non-zero ``flags`` argument is not allowed)
 - :meth:`~socket.socket.send()`, :meth:`~socket.socket.sendall()` (with
   the same limitation)
+- :meth:`~socket.socket.sendfile()` (but :func:`os.sendfile` will be used
+  for plain-text sockets only, else :meth:`~socket.socket.send()` will be used)
+   .. versionadded:: 3.5
 - :meth:`~socket.socket.shutdown()`
 
 However, since the SSL (and TLS) protocol has its own framing atop
diff --git a/Doc/whatsnew/3.5.rst b/Doc/whatsnew/3.5.rst
--- a/Doc/whatsnew/3.5.rst
+++ b/Doc/whatsnew/3.5.rst
@@ -181,9 +181,18 @@
 
 * Different constants of :mod:`signal` module are now enumeration values using
   the :mod:`enum` module. This allows meaningful names to be printed during
-  debugging, instead of integer “magic numbers”. (contribute by Giampaolo
+  debugging, instead of integer “magic numbers”. (contributed by Giampaolo
   Rodola' in :issue:`21076`)
 
+socket
+------
+
+* New :meth:`socket.socket.sendfile` method allows sending a file over a socket
+  by using the high-performance :func:`os.sendfile` function on UNIX, resulting
+  in uploads being from 2x to 3x faster than when using plain
+  :meth:`socket.socket.send`.
+  (contributed by Giampaolo Rodola' in :issue:`17552`)
+
 xmlrpc
 ------
 
diff --git a/Lib/socket.py b/Lib/socket.py
--- a/Lib/socket.py
+++ b/Lib/socket.py
@@ -47,7 +47,7 @@
 import _socket
 from _socket import *
 
-import os, sys, io
+import os, sys, io, selectors
 from enum import IntEnum
 
 try:
@@ -109,6 +109,9 @@
     __all__.append("errorTab")
 
 
+class _GiveupOnSendfile(Exception): pass
+
+
 class socket(_socket.socket):
 
     """A subclass of _socket.socket adding the makefile() method."""
@@ -233,6 +236,149 @@
         text.mode = mode
         return text
 
+    if hasattr(os, 'sendfile'):
+
+        def _sendfile_use_sendfile(self, file, offset=0, count=None):
+            self._check_sendfile_params(file, offset, count)
+            sockno = self.fileno()
+            try:
+                fileno = file.fileno()
+            except (AttributeError, io.UnsupportedOperation) as err:
+                raise _GiveupOnSendfile(err)  # not a regular file
+            try:
+                fsize = os.fstat(fileno).st_size
+            except OSError:
+                raise _GiveupOnSendfile(err)  # not a regular file
+            if not fsize:
+                return 0  # empty file
+            blocksize = fsize if not count else count
+
+            timeout = self.gettimeout()
+            if timeout == 0:
+                raise ValueError("non-blocking sockets are not supported")
+            # poll/select have the advantage of not requiring any
+            # extra file descriptor, contrarily to epoll/kqueue
+            # (also, they require a single syscall).
+            if hasattr(selectors, 'PollSelector'):
+                selector = selectors.PollSelector()
+            else:
+                selector = selectors.SelectSelector()
+            selector.register(sockno, selectors.EVENT_WRITE)
+
+            total_sent = 0
+            # localize variable access to minimize overhead
+            selector_select = selector.select
+            os_sendfile = os.sendfile
+            try:
+                while True:
+                    if timeout and not selector_select(timeout):
+                        raise _socket.timeout('timed out')
+                    if count:
+                        blocksize = count - total_sent
+                        if blocksize <= 0:
+                            break
+                    try:
+                        sent = os_sendfile(sockno, fileno, offset, blocksize)
+                    except BlockingIOError:
+                        if not timeout:
+                            # Block until the socket is ready to send some
+                            # data; avoids hogging CPU resources.
+                            selector_select()
+                        continue
+                    except OSError as err:
+                        if total_sent == 0:
+                            # We can get here for different reasons, the main
+                            # one being 'file' is not a regular mmap(2)-like
+                            # file, in which case we'll fall back on using
+                            # plain send().
+                            raise _GiveupOnSendfile(err)
+                        raise err from None
+                    else:
+                        if sent == 0:
+                            break  # EOF
+                        offset += sent
+                        total_sent += sent
+                return total_sent
+            finally:
+                if total_sent > 0 and hasattr(file, 'seek'):
+                    file.seek(offset)
+    else:
+        def _sendfile_use_sendfile(self, file, offset=0, count=None):
+            raise _GiveupOnSendfile(
+                "os.sendfile() not available on this platform")
+
+    def _sendfile_use_send(self, file, offset=0, count=None):
+        self._check_sendfile_params(file, offset, count)
+        if self.gettimeout() == 0:
+            raise ValueError("non-blocking sockets are not supported")
+        if offset:
+            file.seek(offset)
+        blocksize = min(count, 8192) if count else 8192
+        total_sent = 0
+        # localize variable access to minimize overhead
+        file_read = file.read
+        sock_send = self.send
+        try:
+            while True:
+                if count:
+                    blocksize = min(count - total_sent, blocksize)
+                    if blocksize <= 0:
+                        break
+                data = memoryview(file_read(blocksize))
+                if not data:
+                    break  # EOF
+                while True:
+                    try:
+                        sent = sock_send(data)
+                    except BlockingIOError:
+                        continue
+                    else:
+                        total_sent += sent
+                        if sent < len(data):
+                            data = data[sent:]
+                        else:
+                            break
+            return total_sent
+        finally:
+            if total_sent > 0 and hasattr(file, 'seek'):
+                file.seek(offset + total_sent)
+
+    def _check_sendfile_params(self, file, offset, count):
+        if 'b' not in getattr(file, 'mode', 'b'):
+            raise ValueError("file should be opened in binary mode")
+        if not self.type & SOCK_STREAM:
+            raise ValueError("only SOCK_STREAM type sockets are supported")
+        if count is not None:
+            if not isinstance(count, int):
+                raise TypeError(
+                    "count must be a positive integer (got {!r})".format(count))
+            if count <= 0:
+                raise ValueError(
+                    "count must be a positive integer (got {!r})".format(count))
+
+    def sendfile(self, file, offset=0, count=None):
+        """sendfile(file[, offset[, count]]) -> sent
+
+        Send a file until EOF is reached by using high-performance
+        os.sendfile() and return the total number of bytes which
+        were sent.
+        *file* must be a regular file object opened in binary mode.
+        If os.sendfile() is not available (e.g. Windows) or file is
+        not a regular file socket.send() will be used instead.
+        *offset* tells from where to start reading the file.
+        If specified, *count* is the total number of bytes to transmit
+        as opposed to sending the file until EOF is reached.
+        File position is updated on return or also in case of error in
+        which case file.tell() can be used to figure out the number of
+        bytes which were sent.
+        The socket must be of SOCK_STREAM type.
+        Non-blocking sockets are not supported.
+        """
+        try:
+            return self._sendfile_use_sendfile(file, offset, count)
+        except _GiveupOnSendfile:
+            return self._sendfile_use_send(file, offset, count)
+
     def _decref_socketios(self):
         if self._io_refs > 0:
             self._io_refs -= 1
diff --git a/Lib/ssl.py b/Lib/ssl.py
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -700,6 +700,16 @@
         else:
             return socket.sendall(self, data, flags)
 
+    def sendfile(self, file, offset=0, count=None):
+        """Send a file, possibly by using os.sendfile() if this is a
+        clear-text socket.  Return the total number of bytes sent.
+        """
+        if self._sslobj is None:
+            # os.sendfile() works with plain sockets only
+            return super().sendfile(file, offset, count)
+        else:
+            return self._sendfile_use_send(file, offset, count)
+
     def recv(self, buflen=1024, flags=0):
         self._checkClosed()
         if self._sslobj:
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -19,6 +19,8 @@
 import math
 import pickle
 import struct
+import random
+import string
 try:
     import multiprocessing
 except ImportError:
@@ -5077,6 +5079,275 @@
                     source.close()
 
 
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class SendfileUsingSendTest(ThreadedTCPSocketTest):
+    """
+    Test the send() implementation of socket.sendfile().
+    """
+
+    FILESIZE = (10 * 1024 * 1024)  # 10MB
+    BUFSIZE = 8192
+    FILEDATA = b""
+    TIMEOUT = 2
+
+    @classmethod
+    def setUpClass(cls):
+        def chunks(total, step):
+            assert total >= step
+            while total > step:
+                yield step
+                total -= step
+            if total:
+                yield total
+
+        chunk = b"".join([random.choice(string.ascii_letters).encode()
+                          for i in range(cls.BUFSIZE)])
+        with open(support.TESTFN, 'wb') as f:
+            for csize in chunks(cls.FILESIZE, cls.BUFSIZE):
+                f.write(chunk)
+        with open(support.TESTFN, 'rb') as f:
+            cls.FILEDATA = f.read()
+            assert len(cls.FILEDATA) == cls.FILESIZE
+
+    @classmethod
+    def tearDownClass(cls):
+        support.unlink(support.TESTFN)
+
+    def accept_conn(self):
+        self.serv.settimeout(self.TIMEOUT)
+        conn, addr = self.serv.accept()
+        conn.settimeout(self.TIMEOUT)
+        self.addCleanup(conn.close)
+        return conn
+
+    def recv_data(self, conn):
+        received = []
+        while True:
+            chunk = conn.recv(self.BUFSIZE)
+            if not chunk:
+                break
+            received.append(chunk)
+        return b''.join(received)
+
+    def meth_from_sock(self, sock):
+        # Depending on the mixin class being run return either send()
+        # or sendfile() method implementation.
+        return getattr(sock, "_sendfile_use_send")
+
+    # regular file
+
+    def _testRegularFile(self):
+        address = self.serv.getsockname()
+        file = open(support.TESTFN, 'rb')
+        with socket.create_connection(address) as sock, file as file:
+            meth = self.meth_from_sock(sock)
+            sent = meth(file)
+            self.assertEqual(sent, self.FILESIZE)
+            self.assertEqual(file.tell(), self.FILESIZE)
+
+    def testRegularFile(self):
+        conn = self.accept_conn()
+        data = self.recv_data(conn)
+        self.assertEqual(len(data), self.FILESIZE)
+        self.assertEqual(data, self.FILEDATA)
+
+    # non regular file
+
+    def _testNonRegularFile(self):
+        address = self.serv.getsockname()
+        file = io.BytesIO(self.FILEDATA)
+        with socket.create_connection(address) as sock, file as file:
+            sent = sock.sendfile(file)
+            self.assertEqual(sent, self.FILESIZE)
+            self.assertEqual(file.tell(), self.FILESIZE)
+            self.assertRaises(socket._GiveupOnSendfile,
+                              sock._sendfile_use_sendfile, file)
+
+    def testNonRegularFile(self):
+        conn = self.accept_conn()
+        data = self.recv_data(conn)
+        self.assertEqual(len(data), self.FILESIZE)
+        self.assertEqual(data, self.FILEDATA)
+
+    # empty file
+
+    def _testEmptyFileSend(self):
+        address = self.serv.getsockname()
+        filename = support.TESTFN + "2"
+        with open(filename, 'wb'):
+            self.addCleanup(support.unlink, filename)
+        file = open(filename, 'rb')
+        with socket.create_connection(address) as sock, file as file:
+            meth = self.meth_from_sock(sock)
+            sent = meth(file)
+            self.assertEqual(sent, 0)
+            self.assertEqual(file.tell(), 0)
+
+    def testEmptyFileSend(self):
+        conn = self.accept_conn()
+        data = self.recv_data(conn)
+        self.assertEqual(data, b"")
+
+    # offset
+
+    def _testOffset(self):
+        address = self.serv.getsockname()
+        file = open(support.TESTFN, 'rb')
+        with socket.create_connection(address) as sock, file as file:
+            meth = self.meth_from_sock(sock)
+            sent = meth(file, offset=5000)
+            self.assertEqual(sent, self.FILESIZE - 5000)
+            self.assertEqual(file.tell(), self.FILESIZE)
+
+    def testOffset(self):
+        conn = self.accept_conn()
+        data = self.recv_data(conn)
+        self.assertEqual(len(data), self.FILESIZE - 5000)
+        self.assertEqual(data, self.FILEDATA[5000:])
+
+    # count
+
+    def _testCount(self):
+        address = self.serv.getsockname()
+        file = open(support.TESTFN, 'rb')
+        with socket.create_connection(address, timeout=2) as sock, file as file:
+            count = 5000007
+            meth = self.meth_from_sock(sock)
+            sent = meth(file, count=count)
+            self.assertEqual(sent, count)
+            self.assertEqual(file.tell(), count)
+
+    def testCount(self):
+        count = 5000007
+        conn = self.accept_conn()
+        data = self.recv_data(conn)
+        self.assertEqual(len(data), count)
+        self.assertEqual(data, self.FILEDATA[:count])
+
+    # count small
+
+    def _testCountSmall(self):
+        address = self.serv.getsockname()
+        file = open(support.TESTFN, 'rb')
+        with socket.create_connection(address, timeout=2) as sock, file as file:
+            count = 1
+            meth = self.meth_from_sock(sock)
+            sent = meth(file, count=count)
+            self.assertEqual(sent, count)
+            self.assertEqual(file.tell(), count)
+
+    def testCountSmall(self):
+        count = 1
+        conn = self.accept_conn()
+        data = self.recv_data(conn)
+        self.assertEqual(len(data), count)
+        self.assertEqual(data, self.FILEDATA[:count])
+
+    # count + offset
+
+    def _testCountWithOffset(self):
+        address = self.serv.getsockname()
+        file = open(support.TESTFN, 'rb')
+        with socket.create_connection(address, timeout=2) as sock, file as file:
+            count = 100007
+            meth = self.meth_from_sock(sock)
+            sent = meth(file, offset=2007, count=count)
+            self.assertEqual(sent, count)
+            self.assertEqual(file.tell(), count + 2007)
+
+    def testCountWithOffset(self):
+        count = 100007
+        conn = self.accept_conn()
+        data = self.recv_data(conn)
+        self.assertEqual(len(data), count)
+        self.assertEqual(data, self.FILEDATA[2007:count+2007])
+
+    # non blocking sockets are not supposed to work
+
+    def _testNonBlocking(self):
+        address = self.serv.getsockname()
+        file = open(support.TESTFN, 'rb')
+        with socket.create_connection(address) as sock, file as file:
+            sock.setblocking(False)
+            meth = self.meth_from_sock(sock)
+            self.assertRaises(ValueError, meth, file)
+            self.assertRaises(ValueError, sock.sendfile, file)
+
+    def testNonBlocking(self):
+        conn = self.accept_conn()
+        if conn.recv(8192):
+            self.fail('was not supposed to receive any data')
+
+    # timeout (non-triggered)
+
+    def _testWithTimeout(self):
+        address = self.serv.getsockname()
+        file = open(support.TESTFN, 'rb')
+        with socket.create_connection(address, timeout=2) as sock, file as file:
+            meth = self.meth_from_sock(sock)
+            sent = meth(file)
+            self.assertEqual(sent, self.FILESIZE)
+
+    def testWithTimeout(self):
+        conn = self.accept_conn()
+        data = self.recv_data(conn)
+        self.assertEqual(len(data), self.FILESIZE)
+        self.assertEqual(data, self.FILEDATA)
+
+    # timeout (triggered)
+
+    def _testWithTimeoutTriggeredSend(self):
+        address = self.serv.getsockname()
+        file = open(support.TESTFN, 'rb')
+        with socket.create_connection(address, timeout=0.01) as sock, \
+                file as file:
+            meth = self.meth_from_sock(sock)
+            self.assertRaises(socket.timeout, meth, file)
+
+    def testWithTimeoutTriggeredSend(self):
+        conn = self.accept_conn()
+        conn.recv(88192)
+
+    # errors
+
+    def _test_errors(self):
+        pass
+
+    def test_errors(self):
+        with open(support.TESTFN, 'rb') as file:
+            with socket.socket(type=socket.SOCK_DGRAM) as s:
+                meth = self.meth_from_sock(s)
+                self.assertRaisesRegex(
+                    ValueError, "SOCK_STREAM", meth, file)
+        with open(support.TESTFN, 'rt') as file:
+            with socket.socket() as s:
+                meth = self.meth_from_sock(s)
+                self.assertRaisesRegex(
+                    ValueError, "binary mode", meth, file)
+        with open(support.TESTFN, 'rb') as file:
+            with socket.socket() as s:
+                meth = self.meth_from_sock(s)
+                self.assertRaisesRegex(TypeError, "positive integer",
+                                       meth, file, count='2')
+                self.assertRaisesRegex(TypeError, "positive integer",
+                                       meth, file, count=0.1)
+                self.assertRaisesRegex(ValueError, "positive integer",
+                                       meth, file, count=0)
+                self.assertRaisesRegex(ValueError, "positive integer",
+                                       meth, file, count=-1)
+
+
+@unittest.skipUnless(thread, 'Threading required for this test.')
+@unittest.skipUnless(hasattr(os, "sendfile"),
+                     'os.sendfile() required for this test.')
+class SendfileUsingSendfileTest(SendfileUsingSendTest):
+    """
+    Test the sendfile() implementation of socket.sendfile().
+    """
+    def meth_from_sock(self, sock):
+        return getattr(sock, "_sendfile_use_sendfile")
+
+
 def test_main():
     tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
              TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ]
@@ -5129,6 +5400,8 @@
         InterruptedRecvTimeoutTest,
         InterruptedSendTimeoutTest,
         TestSocketSharing,
+        SendfileUsingSendTest,
+        SendfileUsingSendfileTest,
     ])
 
     thread_info = support.threading_setup()
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -2957,6 +2957,23 @@
                 self.assertRaises(ValueError, s.read, 1024)
                 self.assertRaises(ValueError, s.write, b'hello')
 
+        def test_sendfile(self):
+            TEST_DATA = b"x" * 512
+            with open(support.TESTFN, 'wb') as f:
+                f.write(TEST_DATA)
+            self.addCleanup(support.unlink, support.TESTFN)
+            context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+            context.verify_mode = ssl.CERT_REQUIRED
+            context.load_verify_locations(CERTFILE)
+            context.load_cert_chain(CERTFILE)
+            server = ThreadedEchoServer(context=context, chatty=False)
+            with server:
+                with context.wrap_socket(socket.socket()) as s:
+                    s.connect((HOST, server.port))
+                    with open(support.TESTFN, 'rb') as file:
+                        s.sendfile(file)
+                        self.assertEqual(s.recv(1024), TEST_DATA)
+
 
 def test_main(verbose=False):
     if support.verbose:
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -92,6 +92,10 @@
 Library
 -------
 
+- Issue #17552: new socket.sendfile() method allowing to send a file over a
+  socket by using high-performance os.sendfile() on UNIX.
+  Patch by Giampaolo Rodola'.
+
 - Issue #18039: dbm.dump.open() now always creates a new database when the
   flag has the value 'n'.  Patch by Claudiu Popa.
 

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list