[pypy-commit] pypy stdlib-2.7.9: SSL: Add support for npn_protocols

amauryfa noreply at buildbot.pypy.org
Fri Jan 23 23:41:14 CET 2015


Author: Amaury Forgeot d'Arc <amauryfa at gmail.com>
Branch: stdlib-2.7.9
Changeset: r75507:75d21c757ba1
Date: 2014-12-16 18:40 +0100
http://bitbucket.org/pypy/pypy/changeset/75d21c757ba1/

Log:	SSL: Add support for npn_protocols

diff --git a/pypy/module/_ssl/interp_ssl.py b/pypy/module/_ssl/interp_ssl.py
--- a/pypy/module/_ssl/interp_ssl.py
+++ b/pypy/module/_ssl/interp_ssl.py
@@ -1,7 +1,8 @@
 from rpython.rlib import rpoll, rsocket
-from rpython.rlib.rarithmetic import intmask
+from rpython.rlib.rarithmetic import intmask, widen, r_uint
 from rpython.rlib.ropenssl import *
 from rpython.rlib.rposix import get_errno, set_errno
+from rpython.rlib.rweakref import RWeakValueDictionary
 from rpython.rtyper.lltypesystem import lltype, rffi
 
 from pypy.interpreter.baseobjspace import W_Root
@@ -72,6 +73,10 @@
 constants["OP_NO_SSLv2"] = SSL_OP_NO_SSLv2
 constants["OP_NO_SSLv3"] = SSL_OP_NO_SSLv3
 constants["OP_NO_TLSv1"] = SSL_OP_NO_TLSv1
+constants["HAS_SNI"] = HAS_SNI
+constants["HAS_ECDH"] = True  # To break the test suite
+constants["HAS_NPN"] = HAS_NPN
+constants["HAS_TLS_UNIQUE"] = True  # To break the test suite
 
 constants["OPENSSL_VERSION_NUMBER"] = OPENSSL_VERSION_NUMBER
 ver = OPENSSL_VERSION_NUMBER
@@ -95,6 +100,54 @@
                                           space.wrap(errno), space.wrap(msg))
     return OperationError(w_exception_class, w_exception)
 
+class SSLNpnProtocols(object):
+
+    def __init__(self, ctx, protos):
+        self.protos = protos
+        self.buf, self.pinned, self.is_raw = rffi.get_nonmovingbuffer(protos)
+        NPN_STORAGE.set(r_uint(rffi.cast(rffi.UINT, self.buf)), self)
+
+        # set both server and client callbacks, because the context
+        # can be used to create both types of sockets
+        libssl_SSL_CTX_set_next_protos_advertised_cb(
+            ctx, self.advertiseNPN_cb, self.buf)
+        libssl_SSL_CTX_set_next_proto_select_cb(
+            ctx, self.selectNPN_cb, self.buf)
+
+    def __del__(self):
+        rffi.free_nonmovingbuffer(
+            self.protos, self.buf, self.pinned, self.is_raw)    
+
+    @staticmethod
+    def advertiseNPN_cb(s, data_ptr, len_ptr, args):
+        npn = NPN_STORAGE.get(r_uint(rffi.cast(rffi.UINT, args)))
+        if npn and npn.protos:
+            data_ptr[0] = npn.buf
+            len_ptr[0] = rffi.cast(rffi.UINT, len(npn.protos))
+        else:
+            data_ptr[0] = lltype.nullptr(rffi.CCHARP.TO)
+            len_ptr[0] = rffi.cast(rffi.UINT, 0)
+
+        return rffi.cast(rffi.INT, SSL_TLSEXT_ERR_OK)
+
+    @staticmethod
+    def selectNPN_cb(s, out_ptr, outlen_ptr, server, server_len, args):
+        npn = NPN_STORAGE.get(r_uint(rffi.cast(rffi.UINT, args)))
+        if npn and npn.protos:
+            client = npn.buf
+            client_len = len(npn.protos)
+        else:
+            client = lltype.nullptr(rffi.CCHARP.TO)
+            client_len = 0            
+
+        libssl_SSL_select_next_proto(out_ptr, outlen_ptr,
+                                     server, server_len,
+                                     client, client_len)
+        return rffi.cast(rffi.INT, SSL_TLSEXT_ERR_OK)
+
+NPN_STORAGE = RWeakValueDictionary(r_uint, SSLNpnProtocols)
+
+
 if HAVE_OPENSSL_RAND:
     # helper routines for seeding the SSL PRNG
     @unwrap_spec(string=str, entropy=float)
@@ -452,6 +505,15 @@
             else:
                 return _decode_certificate(space, self.peer_cert)
 
+    def selected_npn_protocol(self, space):
+        with lltype.scoped_alloc(rffi.CCHARPP.TO, 1) as out_ptr:
+            with lltype.scoped_alloc(rffi.UINTP.TO, 1) as len_ptr:
+                libssl_SSL_get0_next_proto_negotiated(self.ssl,
+                                                      out_ptr, len_ptr)
+                if out_ptr[0]:
+                    return space.wrap(
+                        rffi.charpsize2str(out_ptr[0], widen(len_ptr[0])))
+
 _SSLSocket.typedef = TypeDef(
     "_ssl._SSLSocket",
 
@@ -462,6 +524,7 @@
     peer_certificate=interp2app(_SSLSocket.peer_certificate),
     cipher=interp2app(_SSLSocket.cipher),
     shutdown=interp2app(_SSLSocket.shutdown),
+    selected_npn_protocol = interp2app(_SSLSocket.selected_npn_protocol),
 )
 
 
@@ -1079,6 +1142,15 @@
                       space.wrap('crl'), space.wrap(counters['crl']))
         return w_result
 
+    @unwrap_spec(protos='bufferstr')
+    def set_npn_protocols_w(self, space, protos):
+        if not HAS_NPN:
+            raise oefmt(space.w_NotImplementedError,
+                        "The NPN extension requires OpenSSL 1.0.1 or later.")
+
+        self.npn_protocols = SSLNpnProtocols(self.ctx, protos)
+
+
 _SSLContext.typedef = TypeDef(
     "_ssl._SSLContext",
     __new__=interp2app(_SSLContext.descr_new),
@@ -1088,6 +1160,7 @@
     cert_store_stats=interp2app(_SSLContext.cert_store_stats_w),
     load_cert_chain=interp2app(_SSLContext.load_cert_chain_w),
     set_default_verify_paths=interp2app(_SSLContext.descr_set_default_verify_paths),
+    _set_npn_protocols=interp2app(_SSLContext.set_npn_protocols_w),
 
     options=GetSetProperty(_SSLContext.descr_get_options,
                            _SSLContext.descr_set_options),
diff --git a/pypy/module/_ssl/test/test_ssl.py b/pypy/module/_ssl/test/test_ssl.py
--- a/pypy/module/_ssl/test/test_ssl.py
+++ b/pypy/module/_ssl/test/test_ssl.py
@@ -222,6 +222,15 @@
         raises(ssl.SSLError, ss.write, "hello\n")
         del ss; gc.collect()
 
+    def test_npn_protocol(self):
+        import socket, _ssl, gc
+        ctx = _ssl._SSLContext(_ssl.PROTOCOL_TLSv1)
+        ctx._set_npn_protocols(b'\x08http/1.1\x06spdy/2')
+        ss = ctx._wrap_socket(self.s, True,
+                              server_hostname="svn.python.org")
+        self.s.close()
+        del ss; gc.collect()
+
 
 class AppTestConnectedSSL_Timeout(AppTestConnectedSSL):
     # Same tests, with a socket timeout
diff --git a/rpython/rlib/ropenssl.py b/rpython/rlib/ropenssl.py
--- a/rpython/rlib/ropenssl.py
+++ b/rpython/rlib/ropenssl.py
@@ -83,6 +83,7 @@
     SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS = rffi_platform.ConstantInteger(
         "SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS")
     HAS_SNI = rffi_platform.Defined("SSL_CTRL_SET_TLSEXT_HOSTNAME")
+    HAS_NPN = rffi_platform.Defined("OPENSSL_NPN_NEGOTIATED")
     SSL_VERIFY_NONE = rffi_platform.ConstantInteger("SSL_VERIFY_NONE")
     SSL_VERIFY_PEER = rffi_platform.ConstantInteger("SSL_VERIFY_PEER")
     SSL_VERIFY_FAIL_IF_NO_PEER_CERT = rffi_platform.ConstantInteger("SSL_VERIFY_FAIL_IF_NO_PEER_CERT")
@@ -105,6 +106,7 @@
         "SSL_RECEIVED_SHUTDOWN")
     SSL_MODE_AUTO_RETRY = rffi_platform.ConstantInteger("SSL_MODE_AUTO_RETRY")
     SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER = rffi_platform.ConstantInteger("SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER")
+    SSL_TLSEXT_ERR_OK = rffi_platform.ConstantInteger("SSL_TLSEXT_ERR_OK")
 
     ERR_LIB_X509 = rffi_platform.ConstantInteger("ERR_LIB_X509")
     ERR_LIB_PEM = rffi_platform.ConstantInteger("ERR_LIB_PEM")
@@ -365,6 +367,24 @@
 ssl_external('PEM_read_bio_X509_AUX',
              [BIO, rffi.VOIDP, rffi.VOIDP, rffi.VOIDP], X509)
 
+if HAS_NPN:
+    SSL_NEXT_PROTOS_ADV_CB = lltype.Ptr(lltype.FuncType(
+        [SSL, rffi.CCHARPP, rffi.UINTP, rffi.VOIDP], rffi.INT))
+    ssl_external('SSL_CTX_set_next_protos_advertised_cb',
+                 [SSL_CTX, SSL_NEXT_PROTOS_ADV_CB, rffi.VOIDP], lltype.Void)
+    SSL_NEXT_PROTOS_SEL_CB = lltype.Ptr(lltype.FuncType(
+        [SSL, rffi.CCHARPP, rffi.UCHARP, rffi.CCHARP, rffi.UINT, rffi.VOIDP],
+        rffi.INT))
+    ssl_external('SSL_CTX_set_next_proto_select_cb',
+                 [SSL_CTX, SSL_NEXT_PROTOS_SEL_CB, rffi.VOIDP], lltype.Void)
+    ssl_external(
+        'SSL_select_next_proto', [rffi.CCHARPP, rffi.UCHARP,
+                                  rffi.CCHARP, rffi.UINT,
+                                  rffi.CCHARP, rffi.UINT], rffi.INT)
+    ssl_external(
+        'SSL_get0_next_proto_negotiated', [
+            SSL, rffi.CCHARPP, rffi.UINTP], lltype.Void)
+
 EVP_MD_CTX = rffi.COpaquePtr('EVP_MD_CTX', compilation_info=eci)
 EVP_MD     = lltype.Ptr(EVP_MD_st)
 


More information about the pypy-commit mailing list