[Python-checkins] cpython: Issue #14204: The ssl module now has support for the Next Protocol Negotiation

antoine.pitrou python-checkins at python.org
Thu Mar 22 00:34:40 CET 2012


http://hg.python.org/cpython/rev/2514a4e2b3ce
changeset:   75865:2514a4e2b3ce
user:        Antoine Pitrou <solipsis at pitrou.net>
date:        Thu Mar 22 00:23:03 2012 +0100
summary:
  Issue #14204: The ssl module now has support for the Next Protocol Negotiation extension, if available in the underlying OpenSSL library.
Patch by Colin Marc.

files:
  Doc/library/ssl.rst  |   35 +++++++++-
  Lib/ssl.py           |   27 ++++++-
  Lib/test/test_ssl.py |   54 +++++++++++++-
  Misc/ACKS            |    1 +
  Misc/NEWS            |    4 +
  Modules/_ssl.c       |  115 +++++++++++++++++++++++++++++++
  6 files changed, 228 insertions(+), 8 deletions(-)


diff --git a/Doc/library/ssl.rst b/Doc/library/ssl.rst
--- a/Doc/library/ssl.rst
+++ b/Doc/library/ssl.rst
@@ -470,6 +470,16 @@
 
    .. versionadded:: 3.2
 
+.. data:: HAS_NPN
+
+   Whether the OpenSSL library has built-in support for *Next Protocol
+   Negotiation* as described in the `NPN draft specification
+   <http://tools.ietf.org/html/draft-agl-tls-nextprotoneg>`_. When true,
+   you can use the :meth:`SSLContext.set_npn_protocols` method to advertise
+   which protocols you want to support.
+
+   .. versionadded:: 3.3
+
 .. data:: CHANNEL_BINDING_TYPES
 
    List of supported TLS channel binding types.  Strings in this list
@@ -609,6 +619,15 @@
 
    .. versionadded:: 3.3
 
+.. method:: SSLSocket.selected_npn_protocol()
+
+   Returns the protocol that was selected during the TLS/SSL handshake. If
+   :meth:`SSLContext.set_npn_protocols` was not called, or if the other party
+   does not support NPN, or if the handshake has not yet happened, this will
+   return ``None``.
+
+   .. versionadded:: 3.3
+
 .. method:: SSLSocket.unwrap()
 
    Performs the SSL shutdown handshake, which removes the TLS layer from the
@@ -617,7 +636,6 @@
    returned socket should always be used for further communication with the
    other side of the connection, rather than the original socket.
 
-
 .. attribute:: SSLSocket.context
 
    The :class:`SSLContext` object this SSL socket is tied to.  If the SSL
@@ -715,6 +733,21 @@
       when connected, the :meth:`SSLSocket.cipher` method of SSL sockets will
       give the currently selected cipher.
 
+.. method:: SSLContext.set_npn_protocols(protocols)
+
+   Specify which protocols the socket should advertise during the SSL/TLS
+   handshake. It should be a list of strings, like ``['http/1.1', 'spdy/2']``,
+   ordered by preference. The selection of a protocol will happen during the
+   handshake, and will play out according to the `NPN draft specification
+   <http://tools.ietf.org/html/draft-agl-tls-nextprotoneg>`_. After a
+   successful handshake, the :meth:`SSLSocket.selected_npn_protocol` method will
+   return the agreed-upon protocol.
+
+   This method will raise :exc:`NotImplementedError` if :data:`HAS_NPN` is
+   False.
+
+   .. versionadded:: 3.3
+
 .. method:: SSLContext.load_dh_params(dhfile)
 
    Load the key generation parameters for Diffie-Helman (DH) key exchange.
diff --git a/Lib/ssl.py b/Lib/ssl.py
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -90,7 +90,7 @@
     SSL_ERROR_EOF,
     SSL_ERROR_INVALID_ERROR_CODE,
     )
-from _ssl import HAS_SNI, HAS_ECDH
+from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN
 from _ssl import (PROTOCOL_SSLv3, PROTOCOL_SSLv23,
                   PROTOCOL_TLSv1)
 from _ssl import _OPENSSL_API_VERSION
@@ -209,6 +209,17 @@
                          server_hostname=server_hostname,
                          _context=self)
 
+    def set_npn_protocols(self, npn_protocols):
+        protos = bytearray()
+        for protocol in npn_protocols:
+            b = bytes(protocol, 'ascii')
+            if len(b) == 0 or len(b) > 255:
+                raise SSLError('NPN protocols must be 1 to 255 in length')
+            protos.append(len(b))
+            protos.extend(b)
+
+        self._set_npn_protocols(protos)
+
 
 class SSLSocket(socket):
     """This class implements a subtype of socket.socket that wraps
@@ -220,7 +231,7 @@
                  ssl_version=PROTOCOL_SSLv23, ca_certs=None,
                  do_handshake_on_connect=True,
                  family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
-                 suppress_ragged_eofs=True, ciphers=None,
+                 suppress_ragged_eofs=True, npn_protocols=None, ciphers=None,
                  server_hostname=None,
                  _context=None):
 
@@ -240,6 +251,8 @@
                 self.context.load_verify_locations(ca_certs)
             if certfile:
                 self.context.load_cert_chain(certfile, keyfile)
+            if npn_protocols:
+                self.context.set_npn_protocols(npn_protocols)
             if ciphers:
                 self.context.set_ciphers(ciphers)
             self.keyfile = keyfile
@@ -340,6 +353,13 @@
         self._checkClosed()
         return self._sslobj.peer_certificate(binary_form)
 
+    def selected_npn_protocol(self):
+        self._checkClosed()
+        if not self._sslobj or not _ssl.HAS_NPN:
+            return None
+        else:
+            return self._sslobj.selected_npn_protocol()
+
     def cipher(self):
         self._checkClosed()
         if not self._sslobj:
@@ -568,7 +588,8 @@
                 server_side=False, cert_reqs=CERT_NONE,
                 ssl_version=PROTOCOL_SSLv23, ca_certs=None,
                 do_handshake_on_connect=True,
-                suppress_ragged_eofs=True, ciphers=None):
+                suppress_ragged_eofs=True,
+                ciphers=None):
 
     return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile,
                      server_side=server_side, cert_reqs=cert_reqs,
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -879,6 +879,7 @@
                 try:
                     self.sslconn = self.server.context.wrap_socket(
                         self.sock, server_side=True)
+                    self.server.selected_protocols.append(self.sslconn.selected_npn_protocol())
                 except ssl.SSLError as e:
                     # XXX Various errors can have happened here, for example
                     # a mismatching protocol version, an invalid certificate,
@@ -901,6 +902,8 @@
                     cipher = self.sslconn.cipher()
                     if support.verbose and self.server.chatty:
                         sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
+                        sys.stdout.write(" server: selected protocol is now "
+                                + str(self.sslconn.selected_npn_protocol()) + "\n")
                     return True
 
             def read(self):
@@ -979,7 +982,7 @@
         def __init__(self, certificate=None, ssl_version=None,
                      certreqs=None, cacerts=None,
                      chatty=True, connectionchatty=False, starttls_server=False,
-                     ciphers=None, context=None):
+                     npn_protocols=None, ciphers=None, context=None):
             if context:
                 self.context = context
             else:
@@ -992,6 +995,8 @@
                     self.context.load_verify_locations(cacerts)
                 if certificate:
                     self.context.load_cert_chain(certificate)
+                if npn_protocols:
+                    self.context.set_npn_protocols(npn_protocols)
                 if ciphers:
                     self.context.set_ciphers(ciphers)
             self.chatty = chatty
@@ -1001,6 +1006,7 @@
             self.port = support.bind_port(self.sock)
             self.flag = None
             self.active = False
+            self.selected_protocols = []
             self.conn_errors = []
             threading.Thread.__init__(self)
             self.daemon = True
@@ -1195,6 +1201,7 @@
         Launch a server, connect a client to it and try various reads
         and writes.
         """
+        stats = {}
         server = ThreadedEchoServer(context=server_context,
                                     chatty=chatty,
                                     connectionchatty=False)
@@ -1220,12 +1227,14 @@
                 if connectionchatty:
                     if support.verbose:
                         sys.stdout.write(" client:  closing connection.\n")
-                stats = {
+                stats.update({
                     'compression': s.compression(),
                     'cipher': s.cipher(),
-                }
+                    'client_npn_protocol': s.selected_npn_protocol()
+                })
                 s.close()
-                return stats
+            stats['server_npn_protocols'] = server.selected_protocols
+        return stats
 
     def try_protocol_combo(server_protocol, client_protocol, expect_success,
                            certsreqs=None, server_options=0, client_options=0):
@@ -1853,6 +1862,43 @@
             if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
                 self.fail("Non-DH cipher: " + cipher[0])
 
+        def test_selected_npn_protocol(self):
+            # selected_npn_protocol() is None unless NPN is used
+            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+            context.load_cert_chain(CERTFILE)
+            stats = server_params_test(context, context,
+                                       chatty=True, connectionchatty=True)
+            self.assertIs(stats['client_npn_protocol'], None)
+
+        @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test")
+        def test_npn_protocols(self):
+            server_protocols = ['http/1.1', 'spdy/2']
+            protocol_tests = [
+                (['http/1.1', 'spdy/2'], 'http/1.1'),
+                (['spdy/2', 'http/1.1'], 'http/1.1'),
+                (['spdy/2', 'test'], 'spdy/2'),
+                (['abc', 'def'], 'abc')
+            ]
+            for client_protocols, expected in protocol_tests:
+                server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+                server_context.load_cert_chain(CERTFILE)
+                server_context.set_npn_protocols(server_protocols)
+                client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+                client_context.load_cert_chain(CERTFILE)
+                client_context.set_npn_protocols(client_protocols)
+                stats = server_params_test(client_context, server_context,
+                                           chatty=True, connectionchatty=True)
+
+                msg = "failed trying %s (s) and %s (c).\n" \
+                      "was expecting %s, but got %%s from the %%s" \
+                          % (str(server_protocols), str(client_protocols),
+                             str(expected))
+                client_result = stats['client_npn_protocol']
+                self.assertEqual(client_result, expected, msg % (client_result, "client"))
+                server_result = stats['server_npn_protocols'][-1] \
+                    if len(stats['server_npn_protocols']) else 'nothing'
+                self.assertEqual(server_result, expected, msg % (server_result, "server"))
+
 
 def test_main(verbose=False):
     if support.verbose:
diff --git a/Misc/ACKS b/Misc/ACKS
--- a/Misc/ACKS
+++ b/Misc/ACKS
@@ -644,6 +644,7 @@
 David Malcolm
 Ken Manheimer
 Vladimir Marangozov
+Colin Marc
 David Marek
 Doug Marien
 Sven Marnach
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -30,6 +30,10 @@
 Library
 -------
 
+- Issue #14204: The ssl module now has support for the Next Protocol
+  Negotiation extension, if available in the underlying OpenSSL library.
+  Patch by Colin Marc.
+
 - Issue #3035: Unused functions from tkinter are marked as pending peprecated.
 
 - Issue #12757: Fix the skipping of doctests when python is run with -OO so
diff --git a/Modules/_ssl.c b/Modules/_ssl.c
--- a/Modules/_ssl.c
+++ b/Modules/_ssl.c
@@ -159,6 +159,10 @@
 typedef struct {
     PyObject_HEAD
     SSL_CTX *ctx;
+#ifdef OPENSSL_NPN_NEGOTIATED
+    char *npn_protocols;
+    int npn_protocols_len;
+#endif
 } PySSLContext;
 
 typedef struct {
@@ -1015,6 +1019,20 @@
     return NULL;
 }
 
+#ifdef OPENSSL_NPN_NEGOTIATED
+static PyObject *PySSL_selected_npn_protocol(PySSLSocket *self) {
+    const unsigned char *out;
+    unsigned int outlen;
+
+    SSL_get0_next_proto_negotiated(self->ssl, 
+                                   &out, &outlen);
+
+    if (out == NULL)
+        Py_RETURN_NONE;
+    return PyUnicode_FromStringAndSize((char *) out, outlen);
+}
+#endif
+
 static PyObject *PySSL_compression(PySSLSocket *self) {
 #ifdef OPENSSL_NO_COMP
     Py_RETURN_NONE;
@@ -1487,6 +1505,9 @@
     {"peer_certificate", (PyCFunction)PySSL_peercert, METH_VARARGS,
      PySSL_peercert_doc},
     {"cipher", (PyCFunction)PySSL_cipher, METH_NOARGS},
+#ifdef OPENSSL_NPN_NEGOTIATED
+    {"selected_npn_protocol", (PyCFunction)PySSL_selected_npn_protocol, METH_NOARGS},
+#endif
     {"compression", (PyCFunction)PySSL_compression, METH_NOARGS},
     {"shutdown", (PyCFunction)PySSL_SSLshutdown, METH_NOARGS,
      PySSL_SSLshutdown_doc},
@@ -1597,6 +1618,9 @@
 context_dealloc(PySSLContext *self)
 {
     SSL_CTX_free(self->ctx);
+#ifdef OPENSSL_NPN_NEGOTIATED
+    PyMem_Free(self->npn_protocols);
+#endif
     Py_TYPE(self)->tp_free(self);
 }
 
@@ -1621,6 +1645,87 @@
     Py_RETURN_NONE;
 }
 
+#ifdef OPENSSL_NPN_NEGOTIATED
+/* this callback gets passed to SSL_CTX_set_next_protos_advertised_cb */
+static int
+_advertiseNPN_cb(SSL *s, 
+                 const unsigned char **data, unsigned int *len, 
+                 void *args)
+{
+    PySSLContext *ssl_ctx = (PySSLContext *) args;
+
+    if (ssl_ctx->npn_protocols == NULL) {
+        *data = (unsigned char *) "";
+        *len = 0;
+    } else {
+        *data = (unsigned char *) ssl_ctx->npn_protocols;
+        *len = ssl_ctx->npn_protocols_len;
+    }
+
+    return SSL_TLSEXT_ERR_OK;
+}
+/* this callback gets passed to SSL_CTX_set_next_proto_select_cb */
+static int
+_selectNPN_cb(SSL *s, 
+              unsigned char **out, unsigned char *outlen,
+              const unsigned char *server, unsigned int server_len,
+              void *args)
+{
+    PySSLContext *ssl_ctx = (PySSLContext *) args;
+
+    unsigned char *client = (unsigned char *) ssl_ctx->npn_protocols;
+    int client_len;
+
+    if (client == NULL) {
+        client = (unsigned char *) "";
+        client_len = 0;
+    } else {
+        client_len = ssl_ctx->npn_protocols_len;
+    }
+
+    SSL_select_next_proto(out, outlen,
+                          server, server_len,
+                          client, client_len);
+
+    return SSL_TLSEXT_ERR_OK;
+}
+#endif
+
+static PyObject *
+_set_npn_protocols(PySSLContext *self, PyObject *args)
+{
+#ifdef OPENSSL_NPN_NEGOTIATED
+    Py_buffer protos;
+
+    if (!PyArg_ParseTuple(args, "y*:set_npn_protocols", &protos))
+        return NULL;
+
+    self->npn_protocols = PyMem_Malloc(protos.len);
+    if (self->npn_protocols == NULL) {
+        PyBuffer_Release(&protos);
+        return PyErr_NoMemory();
+    }
+    memcpy(self->npn_protocols, protos.buf, protos.len);
+    self->npn_protocols_len = (int) protos.len;
+
+    /* set both server and client callbacks, because the context can
+     * be used to create both types of sockets */
+    SSL_CTX_set_next_protos_advertised_cb(self->ctx,
+                                          _advertiseNPN_cb,
+                                          self);
+    SSL_CTX_set_next_proto_select_cb(self->ctx,
+                                     _selectNPN_cb,
+                                     self);
+
+    PyBuffer_Release(&protos);
+    Py_RETURN_NONE;
+#else
+    PyErr_SetString(PyExc_NotImplementedError,
+                    "The NPN extension requires OpenSSL 1.0.1 or later.");
+    return NULL;
+#endif
+}
+
 static PyObject *
 get_verify_mode(PySSLContext *self, void *c)
 {
@@ -2097,6 +2202,8 @@
                        METH_VARARGS | METH_KEYWORDS, NULL},
     {"set_ciphers", (PyCFunction) set_ciphers,
                     METH_VARARGS, NULL},
+    {"_set_npn_protocols", (PyCFunction) _set_npn_protocols,
+                           METH_VARARGS, NULL},
     {"load_cert_chain", (PyCFunction) load_cert_chain,
                         METH_VARARGS | METH_KEYWORDS, NULL},
     {"load_dh_params", (PyCFunction) load_dh_params,
@@ -2590,6 +2697,14 @@
     Py_INCREF(r);
     PyModule_AddObject(m, "HAS_ECDH", r);
 
+#ifdef OPENSSL_NPN_NEGOTIATED
+    r = Py_True;
+#else
+    r = Py_False;
+#endif
+    Py_INCREF(r);
+    PyModule_AddObject(m, "HAS_NPN", r);
+
     /* OpenSSL version */
     /* SSLeay() gives us the version of the library linked against,
        which could be different from the headers version.

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list