[Jython-checkins] jython: Initial support for SSLContext

darjus.loktevic jython-checkins at python.org
Sun Oct 11 08:17:14 CEST 2015


https://hg.python.org/jython/rev/57e704347787
changeset:   7754:57e704347787
user:        Darjus Loktevic <darjus at gmail.com>
date:        Sun Oct 11 17:17:10 2015 +1100
summary:
  Initial support for SSLContext

files:
  Lib/_sslcerts.py |   76 +++++-----
  Lib/ssl.py       |  247 +++++++++++++++++++++++++++++++---
  2 files changed, 261 insertions(+), 62 deletions(-)


diff --git a/Lib/_sslcerts.py b/Lib/_sslcerts.py
--- a/Lib/_sslcerts.py
+++ b/Lib/_sslcerts.py
@@ -34,15 +34,16 @@
 
 
 
-def _get_ca_certs_trust_manager(ca_certs):
+def _get_ca_certs_trust_manager(ca_certs=None):
     trust_store = KeyStore.getInstance(KeyStore.getDefaultType())
     trust_store.load(None, None)
     num_certs_installed = 0
-    with open(ca_certs) as f:
-        cf = CertificateFactory.getInstance("X.509")
-        for cert in cf.generateCertificates(BufferedInputStream(f)):
-            trust_store.setCertificateEntry(str(uuid.uuid4()), cert)
-            num_certs_installed += 1
+    if ca_certs is not None:
+        with open(ca_certs) as f:
+            cf = CertificateFactory.getInstance("X.509")
+            for cert in cf.generateCertificates(BufferedInputStream(f)):
+                trust_store.setCertificateEntry(str(uuid.uuid4()), cert)
+                num_certs_installed += 1
     tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
     tmf.init(trust_store)
     log.debug("Installed %s certificates", num_certs_installed, extra={"sock": "*"})
@@ -67,9 +68,13 @@
     return _stringio_as_reader(private_key), _stringio_as_reader(certs)
 
 
-def _get_openssl_key_manager(cert_file, key_file=None):
+def _get_openssl_key_manager(cert_file=None, key_file=None, password=None, _key_store=None):
+    if password is None:
+        password = []
+
     paths = [key_file] if key_file else []
-    paths.append(cert_file)
+    if cert_file:
+        paths.append(cert_file)
 
     # Go from Bouncy Castle API to Java's; a bit heavyweight for the Python dev ;)
     key_converter = JcaPEMKeyConverter().setProvider("BC")
@@ -90,41 +95,23 @@
                 elif isinstance(obj, X509CertificateHolder):
                     certs.append(cert_converter.getCertificate(obj))
 
-    if not private_key:
-        from _socket import SSLError, SSL_ERROR_SSL
-        raise SSLError(SSL_ERROR_SSL, "No private key loaded")
-    key_store = KeyStore.getInstance(KeyStore.getDefaultType())
-    key_store.load(None, None)
-    key_store.setKeyEntry(str(uuid.uuid4()), private_key, [], certs)
+
+    if _key_store is None:
+        key_store = KeyStore.getInstance(KeyStore.getDefaultType())
+        key_store.load(None, None)
+
+    if cert_file is not None:
+        if not private_key:
+            from _socket import SSLError, SSL_ERROR_SSL
+            raise SSLError(SSL_ERROR_SSL, "No private key loaded")
+
+        key_store.setKeyEntry(str(uuid.uuid4()), private_key, [], certs)
+
     kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
     kmf.init(key_store, [])
     return kmf
 
 
-def _get_ssl_context(keyfile, certfile, ca_certs):
-    if certfile is None and ca_certs is None:
-        log.debug("Using default SSL context", extra={"sock": "*"})
-        return SSLContext.getDefault()
-    else:
-        log.debug("Setting up a specific SSL context for keyfile=%s, certfile=%s, ca_certs=%s",
-                  keyfile, certfile, ca_certs, extra={"sock": "*"})
-        if ca_certs:
-            # should support composite usage below
-            trust_managers = _get_ca_certs_trust_manager(ca_certs).getTrustManagers()
-        else:
-            trust_managers = None
-        if certfile:
-            key_managers = _get_openssl_key_manager(certfile, keyfile).getKeyManagers()
-        else:
-            key_managers = None
-
-        # FIXME FIXME for performance, cache this lookup in the future
-        # to avoid re-reading files on every lookup
-        context = SSLContext.getInstance("SSL")
-        context.init(key_managers, trust_managers, None)
-        return context
-
-
 # CompositeX509KeyManager and CompositeX509TrustManager allow for mixing together Java built-in managers
 # with new managers to support Python ssl.
 #
@@ -214,3 +201,16 @@
         for trust_manager in self.trust_managers:
             certs.extend(trustManager.getAcceptedIssuers())
         return certs
+
+
+# To use with CERT_NONE
+class NoVerifyX509TrustManager(X509TrustManager):
+
+    def checkClientTrusted(self, chain, auth_type):
+        pass
+
+    def checkServerTrusted(self, chain, auth_type):
+        pass
+
+    def getAcceptedIssuers(self):
+        return None
diff --git a/Lib/ssl.py b/Lib/ssl.py
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -1,5 +1,12 @@
 import base64
+from collections import namedtuple
 import errno
+from java.security.cert import CertificateFactory
+import uuid
+from java.io import BufferedInputStream
+from java.security import KeyStore
+from java.security.cert import CertificateParsingException
+from javax.net.ssl import TrustManagerFactory
 import logging
 import os.path
 import textwrap
@@ -27,7 +34,8 @@
     SSL_ERROR_EOF,
     SSL_ERROR_INVALID_ERROR_CODE,
     error as socket_error)
-from _sslcerts import _get_ssl_context
+from _sslcerts import _get_openssl_key_manager, NoVerifyX509TrustManager
+from _sslcerts import SSLContext as _JavaSSLContext
 
 from java.text import SimpleDateFormat
 from java.util import ArrayList, Locale, TimeZone
@@ -35,7 +43,6 @@
 from javax.naming.ldap import LdapName
 from javax.security.auth.x500 import X500Principal
 
-
 log = logging.getLogger("_socket")
 
 
@@ -47,11 +54,16 @@
 CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED = range(3)
 
 # Do not support PROTOCOL_SSLv2, it is highly insecure and it is optional
-_, PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1 = range(4)
+_, PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1, PROTOCOL_TLSv1_1, PROTOCOL_TLSv1_2 = range(6)
 _PROTOCOL_NAMES = {
-    PROTOCOL_SSLv3: 'SSLv3', 
+    PROTOCOL_SSLv3: 'SSLv3',
     PROTOCOL_SSLv23: 'SSLv23',
-    PROTOCOL_TLSv1: 'TLSv1'}
+    PROTOCOL_TLSv1: 'TLSv1',
+    PROTOCOL_TLSv1_1: 'TLSv1.1',
+    PROTOCOL_TLSv1_2: 'TLSv1.2'
+}
+
+OP_ALL, OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_TLSv1 = range(4)
 
 _rfc2822_date_format = SimpleDateFormat("MMM dd HH:mm:ss yyyy z", Locale.US)
 _rfc2822_date_format.setTimeZone(TimeZone.getTimeZone("GMT"))
@@ -59,11 +71,11 @@
 _ldap_rdn_display_names = {
     # list from RFC 2253
     "CN": "commonName",
-    "L":  "localityName",
+    "L": "localityName",
     "ST": "stateOrProvinceName",
-    "O":  "organizationName",
+    "O": "organizationName",
     "OU": "organizationalUnitName",
-    "C":  "countryName",
+    "C": "countryName",
     "STREET": "streetAddress",
     "DC": "domainComponent",
     "UID": "userid"
@@ -84,7 +96,6 @@
 
 
 class SSLInitializer(ChannelInitializer):
-
     def __init__(self, ssl_handler):
         self.ssl_handler = ssl_handler
 
@@ -94,15 +105,35 @@
 
 
 class SSLSocket(object):
-    
-    def __init__(self, sock,
-                 keyfile, certfile, ca_certs,
-                 do_handshake_on_connect, server_side):
+
+    def __init__(self, sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE,
+                 ssl_version=PROTOCOL_SSLv23, ca_certs=None,
+                 do_handshake_on_connect=True, suppress_ragged_eofs=True, ciphers=None, _context=None):
+        # TODO ^^ handle suppress_ragged_eofs
         self.sock = sock
         self.do_handshake_on_connect = do_handshake_on_connect
         self._sock = sock._sock  # the real underlying socket
-        self.context = _get_ssl_context(keyfile, certfile, ca_certs)
-        self.engine = self.context.createSSLEngine()
+        self.context = _context
+        if _context is None:
+            self.context = SSLContext(ssl_version)
+        else:
+            if server_side and not certfile:
+                raise ValueError("certfile must be specified for server-side "
+                                 "operations")
+            if keyfile and not certfile:
+                raise ValueError("certfile must be specified")
+            if certfile and not keyfile:
+                keyfile = certfile
+            self._context = SSLContext(ssl_version)
+            self._context.verify_mode = cert_reqs
+            if ca_certs:
+                self._context.load_verify_locations(ca_certs)
+            if certfile:
+                self._context.load_cert_chain(certfile, keyfile)
+            if ciphers:
+                self._context.set_ciphers(ciphers)
+
+        self.engine = self.context._createSSLEngine()
         self.server_side = server_side
         self.engine.setUseClientMode(not server_side)
         self.ssl_handler = None
@@ -254,7 +285,7 @@
         pycert = {
             "notAfter": _rfc2822_date_format.format(cert.getNotAfter()),
             "subject": rdns,
-            "subjectAltName": alt_names, 
+            "subjectAltName": alt_names,
         }
         return pycert
 
@@ -274,7 +305,6 @@
         return suite, str(session.protocol), strength
 
 
-
 # instantiates a SSLEngine, with the following things to keep in mind:
 
 # FIXME not yet supported
@@ -284,10 +314,11 @@
 
 @raises_java_exception
 def wrap_socket(sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE,
-                ssl_version=None, ca_certs=None, do_handshake_on_connect=True,
+                ssl_version=PROTOCOL_SSLv23, ca_certs=None, do_handshake_on_connect=True,
                 suppress_ragged_eofs=True, ciphers=None):
+
     return SSLSocket(
-        sock, 
+        sock,
         keyfile=keyfile, certfile=certfile, ca_certs=ca_certs,
         server_side=server_side,
         do_handshake_on_connect=do_handshake_on_connect)
@@ -296,7 +327,6 @@
 # some utility functions
 
 def cert_time_to_seconds(cert_time):
-
     """Takes a date-time string in standard ASN1_print form
     ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return
     a Python time value in seconds past the epoch."""
@@ -304,11 +334,12 @@
     import time
     return time.mktime(time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT"))
 
+
 PEM_HEADER = "-----BEGIN CERTIFICATE-----"
 PEM_FOOTER = "-----END CERTIFICATE-----"
 
+
 def DER_cert_to_PEM_cert(der_cert_bytes):
-
     """Takes a certificate in binary DER format and returns the
     PEM version of it as a string."""
 
@@ -323,8 +354,8 @@
                 base64.encodestring(der_cert_bytes) +
                 PEM_FOOTER + '\n')
 
+
 def PEM_cert_to_DER_cert(pem_cert_string):
-
     """Takes a certificate in ASCII PEM format and returns the
     DER-encoded version of it as a byte sequence"""
 
@@ -337,8 +368,8 @@
     d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)]
     return base64.decodestring(d)
 
+
 def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
-
     """Retrieve the certificate from the server at the specified address,
     and return it as a PEM-encoded string.
     If 'ca_certs' is specified, validate the server cert against it.
@@ -356,13 +387,14 @@
     s.close()
     return DER_cert_to_PEM_cert(dercert)
 
+
 def get_protocol_name(protocol_code):
     return _PROTOCOL_NAMES.get(protocol_code, '<unknown>')
 
+
 # a replacement for the old socket.ssl function
 
 def sslwrap_simple(sock, keyfile=None, certfile=None):
-
     """A replacement for the old socket.ssl function.  Designed
     for compability with Python 2.5 and earlier.  Will disappear in
     Python 3.0."""
@@ -385,11 +417,178 @@
 def RAND_status():
     return True
 
+
 def RAND_egd(path):
     if os.path.abspath(str(path)) != path:
         raise TypeError("Must be an absolute path, but ignoring it regardless")
 
+
 def RAND_add(bytes, entropy):
     pass
 
 
+class Purpose(object):
+    """SSLContext purpose flags with X509v3 Extended Key Usage objects
+    """
+    SERVER_AUTH = '1.3.6.1.5.5.7.3.1'
+    CLIENT_AUTH = '1.3.6.1.5.5.7.3.2'
+
+
+class SSLContext(object):
+    _DN_TO_CPY = {'CN': 'commonName', 'O': 'commonOrganization', 'C': 'countryName', 'DC': 'domainComponent',
+                  'SN': 'surname', 'GN': 'givenName', 'OU': 'organizationalUnitName', 'ST': 'stateOrProvinceName',
+                  'L': 'localityName', 'SERIALNUMBER': 'serialNumber', 'EMAILADDRESS': 'emailAddress'}
+
+    def __init__(self, protocol):
+        protocol_name = _PROTOCOL_NAMES[protocol]
+        if protocol == PROTOCOL_SSLv23:  # darjus: at least my Java does not let me use v2
+            protocol_name = 'SSL'
+
+        self.protocol = protocol
+        self.check_hostname = False
+        self.options = OP_ALL
+        self.verify_flags = None
+        self.verify_mode = CERT_NONE
+        self._ciphers = None
+
+        self._trust_store = KeyStore.getInstance(KeyStore.getDefaultType())
+        self._trust_store.load(None, None)
+
+        self._key_store = KeyStore.getInstance(KeyStore.getDefaultType())
+        self._key_store.load(None, None)
+
+        self._context = _JavaSSLContext.getInstance(protocol_name)
+        self._key_managers = None
+
+    def wrap_socket(self, sock, server_side=False,
+                    do_handshake_on_connect=True,
+                    suppress_ragged_eofs=True,
+                    server_hostname=None):
+        # FIXME do something about server_hostname
+        return SSLSocket(sock, keyfile=None, certfile=None, ca_certs=None, suppress_ragged_eofs=suppress_ragged_eofs,
+                         do_handshake_on_connect=do_handshake_on_connect, server_side=server_side, _context=self)
+
+    def _createSSLEngine(self):
+        trust_managers = [NoVerifyX509TrustManager()]
+        if self.verify_mode == CERT_REQUIRED:
+            tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
+            tmf.init(self._trust_store)
+            trust_managers = tmf.getTrustManagers()
+
+        if self._key_managers is None:  # get an e
+            self._context.init(_get_openssl_key_manager().getKeyManagers(), trust_managers, None)
+        else:
+            self._context.init(self._key_managers.getKeyManagers(), trust_managers, None)
+
+        engine = self._context.createSSLEngine()
+
+        if self._ciphers is not None:
+            engine.setEnabledCipherSuites(self._ciphers)
+
+        return engine
+
+    def cert_store_stats(self):
+        # TODO not sure if we can even get something similar from Java
+        return {}
+
+    def load_cert_chain(self, certfile, keyfile=None, password=None):
+        self._key_managers = _get_openssl_key_manager(certfile, keyfile, password, _key_store=self._key_store)
+
+    def set_ciphers(self, ciphers):
+        # TODO conversion from OpenSSL to http://www.iana.org/assignments/tls-parameters/tls-parameters.xml
+        # as Java knows no other
+        #self._ciphers = ciphers
+        pass
+
+    def load_verify_locations(self, cafile=None, capath=None, cadata=None):
+        print cafile
+        if cafile is not None:
+            with open(cafile) as f:
+                self._load_certificates(f)
+        if capath is not None:
+            for fname in os.listdir(capath):
+                _, ext = os.path.splitext()
+                if ext.lower() == 'pem':
+                    with open(os.path.join(capath, fname)) as f:
+                        self._load_certificates(f)
+        if cadata is not None:
+            self._load_certificates(f)
+
+    @raises_java_exception
+    def _load_certificates(self, f):
+        cf = CertificateFactory.getInstance("X.509")
+        try:
+            for cert in cf.generateCertificates(BufferedInputStream(f)):
+                self._trust_store.setCertificateEntry(str(uuid.uuid4()), cert)
+        except CertificateParsingException:
+            log.debug("Failed to parse certificate", exc_info=True)
+            raise
+
+    def load_default_certs(self, purpose=None):
+        # TODO handle/support purpose
+        self.set_default_verify_paths()
+
+    def set_default_verify_paths(self):
+        """
+        Load a set of default "certification authority" (CA) certificates from a filesystem path defined when building
+        the OpenSSL library. Unfortunately, there's no easy way to know whether this method succeeds: no error is
+        returned if no certificates are to be found. When the OpenSSL library is provided as part of the operating
+        system, though, it is likely to be configured properly.
+        """
+        # TODO not implemented, we want to use some default Java's loading method.
+        return None
+
+    def set_alpn_protocols(self, protocols):
+        raise NotImplementedError()
+
+    def set_npn_protocols(self, protocols):
+        raise NotImplementedError()
+
+    def set_servername_callback(self, server_name_callback):
+        raise NotImplementedError()
+
+    def load_dh_params(self, dhfile):
+        # TODO?
+        pass
+
+    def set_ecdh_curve(self, curve_name):
+        # TODO?
+        pass
+
+    def get_ca_certs(self, binary_form=False):
+        """get_ca_certs(binary_form=False) -> list of loaded certificate
+
+        Returns a list of dicts with information of loaded CA certs. If the optional argument is True,
+        returns a DER-encoded copy of the CA certificate.
+        NOTE: Certificates in a capath directory aren't loaded unless they have been used at least once.
+        """
+        if binary_form:
+            raise NotImplementedError()
+
+        certs = []
+        enumerator = self._trust_store.aliases()
+        while enumerator.hasMoreElements():
+            alias = enumerator.next()
+            if self._trust_store.isCertificateEntry(alias):
+                cert = self._trust_store.getCertificate(alias)
+                issuer_info = self._parse_dn(cert.issuerDN)
+                subject_info = self._parse_dn(cert.subjectDN)
+
+                cert_info = {'issuer': issuer_info, 'subject': subject_info}
+                for k in ('notBefore', 'serialNumber', 'notAfter', 'version'):
+                    cert_info[k] = getattr(cert, k)
+
+                certs.append(cert_info)
+
+        return certs
+
+    @classmethod
+    def _parse_dn(cls, dn):
+        try:
+            dn_dct = dict([iss.split('=', 1) for iss in unicode(dn).split(',')])
+        except ValueError:
+            # FIXME CN=Starfield Root Certificate Authority - G2, O="Starfield Technologies, Inc.",
+            log.error("Failed to parse {}".format(dn), exc_info=True)
+            return tuple()
+
+        return tuple((cls._DN_TO_CPY.get(key.strip(), 'unk'), val) for key, val in dn_dct.iteritems())

-- 
Repository URL: https://hg.python.org/jython


More information about the Jython-checkins mailing list