[Python-checkins] cpython (2.7): allow a SSLContext to be given to ftplib.FTP_TLS

benjamin.peterson python-checkins at python.org
Sun Jan 4 22:36:56 CET 2015


https://hg.python.org/cpython/rev/e8342b3154d1
changeset:   94017:e8342b3154d1
branch:      2.7
user:        Benjamin Peterson <benjamin at python.org>
date:        Sun Jan 04 15:36:31 2015 -0600
summary:
  allow a SSLContext to be given to ftplib.FTP_TLS

files:
  Doc/library/ftplib.rst  |   18 +++-
  Lib/ftplib.py           |   22 ++++-
  Lib/test/test_ftplib.py |  115 +++++++++++++++++++++------
  Misc/NEWS               |    2 +
  4 files changed, 120 insertions(+), 37 deletions(-)


diff --git a/Doc/library/ftplib.rst b/Doc/library/ftplib.rst
--- a/Doc/library/ftplib.rst
+++ b/Doc/library/ftplib.rst
@@ -55,18 +55,26 @@
       *timeout* was added.
 
 
-.. class:: FTP_TLS([host[, user[, passwd[, acct[, keyfile[, certfile[, timeout]]]]]]])
+.. class:: FTP_TLS([host[, user[, passwd[, acct[, keyfile[, certfile[, context[, timeout]]]]]]]])
 
-   A :class:`FTP` subclass which adds TLS support to FTP as described in
+   A :class:`FTP` subclass which adds TLS support to FTP as described in
    :rfc:`4217`.
    Connect as usual to port 21 implicitly securing the FTP control connection
    before authenticating. Securing the data connection requires the user to
-   explicitly ask for it by calling the :meth:`prot_p` method.
-   *keyfile* and *certfile* are optional -- they can contain a PEM formatted
-   private key and certificate chain file name for the SSL connection.
+   explicitly ask for it by calling the :meth:`prot_p` method.  *context*
+   is a :class:`ssl.SSLContext` object which allows bundling SSL configuration
+   options, certificates and private keys into a single (potentially
+   long-lived) structure.  Please read :ref:`ssl-security` for best practices.
+
+   *keyfile* and *certfile* are a legacy alternative to *context* and
+   may not be combined with it -- they can point to PEM-formatted private
+   key and certificate chain files (respectively) for the SSL connection.
 
    .. versionadded:: 2.7
 
+   .. versionchanged:: 2.7.10
+      The *context* parameter was added.
+
    Here's a sample session using the :class:`FTP_TLS` class:
 
    >>> from ftplib import FTP_TLS
diff --git a/Lib/ftplib.py b/Lib/ftplib.py
--- a/Lib/ftplib.py
+++ b/Lib/ftplib.py
@@ -641,9 +641,21 @@
         ssl_version = ssl.PROTOCOL_SSLv23
 
         def __init__(self, host='', user='', passwd='', acct='', keyfile=None,
-                     certfile=None, timeout=_GLOBAL_DEFAULT_TIMEOUT):
+                     certfile=None, context=None,
+                     timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=None):
+            if context is not None and keyfile is not None:
+                raise ValueError("context and keyfile arguments are mutually "
+                                 "exclusive")
+            if context is not None and certfile is not None:
+                raise ValueError("context and certfile arguments are mutually "
+                                 "exclusive")
             self.keyfile = keyfile
             self.certfile = certfile
+            if context is None:
+                context = ssl._create_stdlib_context(self.ssl_version,
+                                                     certfile=certfile,
+                                                     keyfile=keyfile)
+            self.context = context
             self._prot_p = False
             FTP.__init__(self, host, user, passwd, acct, timeout)
 
@@ -660,8 +672,8 @@
                 resp = self.voidcmd('AUTH TLS')
             else:
                 resp = self.voidcmd('AUTH SSL')
-            self.sock = ssl.wrap_socket(self.sock, self.keyfile, self.certfile,
-                                        ssl_version=self.ssl_version)
+            self.sock = self.context.wrap_socket(self.sock,
+                                                 server_hostname=self.host)
             self.file = self.sock.makefile(mode='rb')
             return resp
 
@@ -692,8 +704,8 @@
         def ntransfercmd(self, cmd, rest=None):
             conn, size = FTP.ntransfercmd(self, cmd, rest)
             if self._prot_p:
-                conn = ssl.wrap_socket(conn, self.keyfile, self.certfile,
-                                       ssl_version=self.ssl_version)
+                conn = self.context.wrap_socket(conn,
+                                                server_hostname=self.host)
             return conn, size
 
         def retrbinary(self, cmd, callback, blocksize=8192, rest=None):
diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py
--- a/Lib/test/test_ftplib.py
+++ b/Lib/test/test_ftplib.py
@@ -20,7 +20,7 @@
 from test.test_support import HOST, HOSTv6
 threading = test_support.import_module('threading')
 
-
+TIMEOUT = 3
 # the dummy data returned by server over the data channel when
 # RETR, LIST and NLST commands are issued
 RETR_DATA = 'abcde12345\r\n' * 1000
@@ -223,6 +223,7 @@
         self.active = False
         self.active_lock = threading.Lock()
         self.host, self.port = self.socket.getsockname()[:2]
+        self.handler_instance = None
 
     def start(self):
         assert not self.active
@@ -246,8 +247,7 @@
 
     def handle_accept(self):
         conn, addr = self.accept()
-        self.handler = self.handler(conn)
-        self.close()
+        self.handler_instance = self.handler(conn)
 
     def handle_connect(self):
         self.close()
@@ -262,7 +262,8 @@
 
 if ssl is not None:
 
-    CERTFILE = os.path.join(os.path.dirname(__file__), "keycert.pem")
+    CERTFILE = os.path.join(os.path.dirname(__file__), "keycert3.pem")
+    CAFILE = os.path.join(os.path.dirname(__file__), "pycacert.pem")
 
     class SSLConnection(object, asyncore.dispatcher):
         """An asyncore.dispatcher subclass supporting TLS/SSL."""
@@ -271,23 +272,25 @@
         _ssl_closing = False
 
         def secure_connection(self):
-            self.socket = ssl.wrap_socket(self.socket, suppress_ragged_eofs=False,
-                                          certfile=CERTFILE, server_side=True,
-                                          do_handshake_on_connect=False,
-                                          ssl_version=ssl.PROTOCOL_SSLv23)
+            socket = ssl.wrap_socket(self.socket, suppress_ragged_eofs=False,
+                                     certfile=CERTFILE, server_side=True,
+                                     do_handshake_on_connect=False,
+                                     ssl_version=ssl.PROTOCOL_SSLv23)
+            self.del_channel()
+            self.set_socket(socket)
             self._ssl_accepting = True
 
         def _do_ssl_handshake(self):
             try:
                 self.socket.do_handshake()
-            except ssl.SSLError, err:
+            except ssl.SSLError as err:
                 if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
                                    ssl.SSL_ERROR_WANT_WRITE):
                     return
                 elif err.args[0] == ssl.SSL_ERROR_EOF:
                     return self.handle_close()
                 raise
-            except socket.error, err:
+            except socket.error as err:
                 if err.args[0] == errno.ECONNABORTED:
                     return self.handle_close()
             else:
@@ -297,18 +300,21 @@
             self._ssl_closing = True
             try:
                 self.socket = self.socket.unwrap()
-            except ssl.SSLError, err:
+            except ssl.SSLError as err:
                 if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
                                    ssl.SSL_ERROR_WANT_WRITE):
                     return
-            except socket.error, err:
+            except socket.error as err:
                 # Any "socket error" corresponds to a SSL_ERROR_SYSCALL return
                 # from OpenSSL's SSL_shutdown(), corresponding to a
                 # closed socket condition. See also:
                 # http://www.mail-archive.com/openssl-users@openssl.org/msg60710.html
                 pass
             self._ssl_closing = False
-            super(SSLConnection, self).close()
+            if getattr(self, '_ccc', False) is False:
+                super(SSLConnection, self).close()
+            else:
+                pass
 
         def handle_read_event(self):
             if self._ssl_accepting:
@@ -329,7 +335,7 @@
         def send(self, data):
             try:
                 return super(SSLConnection, self).send(data)
-            except ssl.SSLError, err:
+            except ssl.SSLError as err:
                 if err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN,
                                    ssl.SSL_ERROR_WANT_READ,
                                    ssl.SSL_ERROR_WANT_WRITE):
@@ -339,13 +345,13 @@
         def recv(self, buffer_size):
             try:
                 return super(SSLConnection, self).recv(buffer_size)
-            except ssl.SSLError, err:
+            except ssl.SSLError as err:
                 if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
                                    ssl.SSL_ERROR_WANT_WRITE):
-                    return ''
+                    return b''
                 if err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN):
                     self.handle_close()
-                    return ''
+                    return b''
                 raise
 
         def handle_error(self):
@@ -355,6 +361,8 @@
             if (isinstance(self.socket, ssl.SSLSocket) and
                 self.socket._sslobj is not None):
                 self._do_ssl_shutdown()
+            else:
+                super(SSLConnection, self).close()
 
 
     class DummyTLS_DTPHandler(SSLConnection, DummyDTPHandler):
@@ -462,12 +470,12 @@
 
     def test_rename(self):
         self.client.rename('a', 'b')
-        self.server.handler.next_response = '200'
+        self.server.handler_instance.next_response = '200'
         self.assertRaises(ftplib.error_reply, self.client.rename, 'a', 'b')
 
     def test_delete(self):
         self.client.delete('foo')
-        self.server.handler.next_response = '199'
+        self.server.handler_instance.next_response = '199'
         self.assertRaises(ftplib.error_reply, self.client.delete, 'foo')
 
     def test_size(self):
@@ -515,7 +523,7 @@
     def test_storbinary(self):
         f = StringIO.StringIO(RETR_DATA)
         self.client.storbinary('stor', f)
-        self.assertEqual(self.server.handler.last_received_data, RETR_DATA)
+        self.assertEqual(self.server.handler_instance.last_received_data, RETR_DATA)
         # test new callback arg
         flag = []
         f.seek(0)
@@ -527,12 +535,12 @@
         for r in (30, '30'):
             f.seek(0)
             self.client.storbinary('stor', f, rest=r)
-            self.assertEqual(self.server.handler.rest, str(r))
+            self.assertEqual(self.server.handler_instance.rest, str(r))
 
     def test_storlines(self):
         f = StringIO.StringIO(RETR_DATA.replace('\r\n', '\n'))
         self.client.storlines('stor', f)
-        self.assertEqual(self.server.handler.last_received_data, RETR_DATA)
+        self.assertEqual(self.server.handler_instance.last_received_data, RETR_DATA)
         # test new callback arg
         flag = []
         f.seek(0)
@@ -551,14 +559,14 @@
     def test_makeport(self):
         self.client.makeport()
         # IPv4 is in use, just make sure send_eprt has not been used
-        self.assertEqual(self.server.handler.last_received_cmd, 'port')
+        self.assertEqual(self.server.handler_instance.last_received_cmd, 'port')
 
     def test_makepasv(self):
         host, port = self.client.makepasv()
         conn = socket.create_connection((host, port), 10)
         conn.close()
         # IPv4 is in use, just make sure send_epsv has not been used
-        self.assertEqual(self.server.handler.last_received_cmd, 'pasv')
+        self.assertEqual(self.server.handler_instance.last_received_cmd, 'pasv')
 
     def test_line_too_long(self):
         self.assertRaises(ftplib.Error, self.client.sendcmd,
@@ -600,13 +608,13 @@
 
     def test_makeport(self):
         self.client.makeport()
-        self.assertEqual(self.server.handler.last_received_cmd, 'eprt')
+        self.assertEqual(self.server.handler_instance.last_received_cmd, 'eprt')
 
     def test_makepasv(self):
         host, port = self.client.makepasv()
         conn = socket.create_connection((host, port), 10)
         conn.close()
-        self.assertEqual(self.server.handler.last_received_cmd, 'epsv')
+        self.assertEqual(self.server.handler_instance.last_received_cmd, 'epsv')
 
     def test_transfer(self):
         def retr():
@@ -642,7 +650,7 @@
     def setUp(self):
         self.server = DummyTLS_FTPServer((HOST, 0))
         self.server.start()
-        self.client = ftplib.FTP_TLS(timeout=10)
+        self.client = ftplib.FTP_TLS(timeout=TIMEOUT)
         self.client.connect(self.server.host, self.server.port)
 
     def tearDown(self):
@@ -695,6 +703,59 @@
         finally:
             self.client.ssl_version = ssl.PROTOCOL_TLSv1
 
+    def test_context(self):
+        self.client.quit()
+        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+        self.assertRaises(ValueError, ftplib.FTP_TLS, keyfile=CERTFILE,
+                          context=ctx)
+        self.assertRaises(ValueError, ftplib.FTP_TLS, certfile=CERTFILE,
+                          context=ctx)
+        self.assertRaises(ValueError, ftplib.FTP_TLS, certfile=CERTFILE,
+                          keyfile=CERTFILE, context=ctx)
+
+        self.client = ftplib.FTP_TLS(context=ctx, timeout=TIMEOUT)
+        self.client.connect(self.server.host, self.server.port)
+        self.assertNotIsInstance(self.client.sock, ssl.SSLSocket)
+        self.client.auth()
+        self.assertIs(self.client.sock.context, ctx)
+        self.assertIsInstance(self.client.sock, ssl.SSLSocket)
+
+        self.client.prot_p()
+        sock = self.client.transfercmd('list')
+        try:
+            self.assertIs(sock.context, ctx)
+            self.assertIsInstance(sock, ssl.SSLSocket)
+        finally:
+            sock.close()
+
+    def test_check_hostname(self):
+        self.client.quit()
+        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+        ctx.verify_mode = ssl.CERT_REQUIRED
+        ctx.check_hostname = True
+        ctx.load_verify_locations(CAFILE)
+        self.client = ftplib.FTP_TLS(context=ctx, timeout=TIMEOUT)
+
+        # 127.0.0.1 doesn't match SAN
+        self.client.connect(self.server.host, self.server.port)
+        with self.assertRaises(ssl.CertificateError):
+            self.client.auth()
+        # exception quits connection
+
+        self.client.connect(self.server.host, self.server.port)
+        self.client.prot_p()
+        with self.assertRaises(ssl.CertificateError):
+            self.client.transfercmd("list").close()
+        self.client.quit()
+
+        self.client.connect("localhost", self.server.port)
+        self.client.auth()
+        self.client.quit()
+
+        self.client.connect("localhost", self.server.port)
+        self.client.prot_p()
+        self.client.transfercmd("list").close()
+
 
 class TestTimeouts(TestCase):
 
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -15,6 +15,8 @@
 Library
 -------
 
+- Backport the context argument to ftplib.FTP_TLS.
+
 - Issue #23111: Maximize compatibility in protocol versions of ftplib.FTP_TLS.
 
 - Issue #23112: Fix SimpleHTTPServer to correctly carry the query string and

-- 
Repository URL: https://hg.python.org/cpython


More information about the Python-checkins mailing list