[Jython-checkins] jython: More ssl improvements: Support check_hostname, support homebrew openssl certs, upgraded urllib, other SSLEngine creation improvements, more tests passing

darjus.loktevic jython-checkins at python.org
Thu Dec 10 21:10:51 EST 2015


https://hg.python.org/jython/rev/1df0ae17b03f
changeset:   7828:1df0ae17b03f
user:        Darjus Loktevic <darjus at gmail.com>
date:        Fri Dec 11 11:53:56 2015 +1100
summary:
  More ssl improvements: Support check_hostname, support homebrew openssl certs, upgraded urllib, other SSLEngine creation improvements, more tests passing

files:
  Lib/_socket.py             |    4 +-
  Lib/_sslcerts.py           |    6 +-
  Lib/ssl.py                 |   60 +++++++++-----
  Lib/test/test_socket_jy.py |    2 +-
  Lib/test/test_ssl.py       |    1 -
  Lib/urllib.py              |  105 ++++++++++++++++--------
  6 files changed, 114 insertions(+), 64 deletions(-)


diff --git a/Lib/_socket.py b/Lib/_socket.py
--- a/Lib/_socket.py
+++ b/Lib/_socket.py
@@ -926,8 +926,8 @@
             # from socketmodule.c
             # if (res == EISCONN)
             #   res = 0;
-            # and that is what tests expect, so we return 0 to be like CPython
-            return 0
+            # but http://bugs.jython.org/issue2428
+            return errno.EISCONN
         else:
             return errno.ENOTCONN
 
diff --git a/Lib/_sslcerts.py b/Lib/_sslcerts.py
--- a/Lib/_sslcerts.py
+++ b/Lib/_sslcerts.py
@@ -5,13 +5,13 @@
 import types
 
 from java.lang import RuntimeException
-from java.io import BufferedInputStream, BufferedReader, InputStreamReader, ByteArrayInputStream, IOException
+from java.io import BufferedInputStream, BufferedReader, FileReader, InputStreamReader, ByteArrayInputStream, IOException
 from java.security import KeyStore, Security, InvalidAlgorithmParameterException
 from java.security.cert import CertificateException, CertificateFactory
 from java.security.interfaces import RSAPrivateCrtKey
 from java.security.interfaces import RSAPublicKey
 from javax.net.ssl import (
-    X509KeyManager, X509TrustManager, KeyManagerFactory, TrustManagerFactory)
+    X509KeyManager, X509TrustManager, KeyManagerFactory, SSLContext, TrustManager, TrustManagerFactory)
 try:
     # jarjar-ed version
     from org.python.bouncycastle.asn1.pkcs import PrivateKeyInfo
@@ -137,7 +137,7 @@
     _hash = 0
     for arg in args:
         if arg:
-            _hash += hash(str(arg))
+            _hash += hash(arg.toString().encode('utf8'))
 
     return str(_hash)
 
diff --git a/Lib/ssl.py b/Lib/ssl.py
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -488,20 +488,22 @@
         self._sslobj = None
         self.handshake_count = 0
 
-        self.engine = self._context._createSSLEngine()
-
+        self.engine = None
         if self.do_handshake_on_connect and self.sock._sock.connected:
             self.do_handshake()
 
+    def setup_engine(self, addr):
+        if self.engine is None:
+            # http://stackoverflow.com/questions/13390964/java-ssl-fatal-error-80-unwrapping-net-record-after-adding-the-https-en
+            self.engine = self._context._createSSLEngine(addr, self.server_hostname)
+            self.engine.setUseClientMode(not self.server_side)
+
     def connect(self, addr):
         if self.server_side:
             raise ValueError("can't connect in server-side mode")
 
         log.debug("Connect SSL with handshaking %s", self.do_handshake_on_connect, extra={"sock": self._sock})
 
-        self.engine = self._context._createSSLEngine(*addr)
-        self.engine.setUseClientMode(not self.server_side)
-
         self._sock._connect(addr)
         if self.do_handshake_on_connect:
             self.do_handshake()
@@ -512,13 +514,18 @@
 
         log.debug("Connect SSL with handshaking %s", self.do_handshake_on_connect, extra={"sock": self._sock})
 
-        self.engine = self._context._createSSLEngine(*addr)
-        self.engine.setUseClientMode(not self.server_side)
-
         self._sock._connect(addr)
         if self.do_handshake_on_connect:
             self.do_handshake()
-        return self._sock.connect_ex(addr)
+
+        # from socketmodule.c
+        # if (res == EISCONN)
+        #   res = 0;
+        # but http://bugs.jython.org/issue2428
+        res = self._sock.connect_ex(addr)
+        if res == errno.EISCONN:
+            return 0
+        return res
 
     def unwrap(self):
         self._sock.channel.pipeline().remove("ssl")
@@ -527,6 +534,7 @@
 
     def do_handshake(self):
         log.debug("SSL handshaking", extra={"sock": self._sock})
+        self.setup_engine(self.sock.getpeername())
 
         def handshake_step(result):
             log.debug("SSL handshaking completed %s", result, extra={"sock": self._sock})
@@ -603,7 +611,6 @@
 
     def shutdown(self, how):
         self.sock.shutdown(how)
-
     # Need to work with the real underlying socket as well
 
     def pending(self):
@@ -844,12 +851,12 @@
 
     def __init__(self, protocol):
         try:
-            protocol_name = _PROTOCOL_NAMES[protocol]
+            self._protocol_name = _PROTOCOL_NAMES[protocol]
         except KeyError:
             raise ValueError("invalid protocol version")
 
         if protocol == PROTOCOL_SSLv23:  # darjus: at least my Java does not let me use v2
-            protocol_name = 'SSL'
+            self._protocol_name = 'SSL'
 
         self.protocol = protocol
         self._check_hostname = False
@@ -866,33 +873,42 @@
         self._key_store = KeyStore.getInstance(KeyStore.getDefaultType())
         self._key_store.load(None, None)
 
-        self._context = _JavaSSLContext.getInstance(protocol_name)
         self._key_managers = None
 
     def wrap_socket(self, sock, server_side=False,
                     do_handshake_on_connect=True,
                     suppress_ragged_eofs=True,
                     server_hostname=None):
-        return SSLSocket(sock, keyfile=None, certfile=None, ca_certs=None, cert_reqs=self.verify_mode, suppress_ragged_eofs=suppress_ragged_eofs,
-                         do_handshake_on_connect=do_handshake_on_connect, server_side=server_side,
-                         server_hostname=server_hostname, _context=self)
+        return SSLSocket(sock=sock, server_side=server_side,
+                         do_handshake_on_connect=do_handshake_on_connect,
+                         suppress_ragged_eofs=suppress_ragged_eofs,
+                         server_hostname=server_hostname,
+                         _context=self)
 
-    def _createSSLEngine(self, host=None, port=None):
+    def _createSSLEngine(self, addr, hostname=None):
         trust_managers = [NoVerifyX509TrustManager()]
         if self.verify_mode == CERT_REQUIRED:
             tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
             tmf.init(self._trust_store)
             trust_managers = [CompositeX509TrustManager(tmf.getTrustManagers())]
 
+        context = _JavaSSLContext.getInstance(self._protocol_name)
+
         if self._key_managers is None:  # get an e
-            self._context.init(_get_openssl_key_manager().getKeyManagers(), trust_managers, None)
+            context.init(_get_openssl_key_manager().getKeyManagers(), trust_managers, None)
         else:
-            self._context.init(self._key_managers.getKeyManagers(), trust_managers, None)
+            context.init(self._key_managers.getKeyManagers(), trust_managers, None)
 
-        if host is not None and port is not None:
-            engine = self._context.createSSLEngine(host, port)
+        if hostname is not None:
+            engine = context.createSSLEngine(hostname, addr[1])
         else:
-            engine = self._context.createSSLEngine()
+            engine = context.createSSLEngine(*addr)
+
+        # apparently this can be used to enforce hostname verification
+        if hostname is not None and self._check_hostname:
+            params = engine.getSSLParameters()
+            params.setEndpointIdentificationAlgorithm('HTTPS')
+            engine.setSSLParameters(params)
 
         if self._ciphers is not None:
             engine.setEnabledCipherSuites(self._ciphers)
diff --git a/Lib/test/test_socket_jy.py b/Lib/test/test_socket_jy.py
--- a/Lib/test/test_socket_jy.py
+++ b/Lib/test/test_socket_jy.py
@@ -41,7 +41,7 @@
         connect_errno = 0
         connect_attempt = 0
 
-        while connect_errno != errno.EISCONN and connect_attempt < 100:
+        while connect_errno != errno.EISCONN and connect_attempt < 500:
             connect_attempt += 1
             connect_errno = sock.connect_ex(self.address)
             results[index].append(connect_errno)
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -1324,7 +1324,6 @@
             finally:
                 s.close()
 
-    @unittest.skipIf(support.is_jython, "TODO, should raise certificate verify failed but does not")
     def test_connect_with_context(self):
         with support.transient_internet("svn.python.org"):
             # Same as test_connect, but with a separately created context
diff --git a/Lib/urllib.py b/Lib/urllib.py
--- a/Lib/urllib.py
+++ b/Lib/urllib.py
@@ -28,6 +28,7 @@
 import time
 import sys
 import base64
+import re
 
 from urlparse import urljoin as basejoin
 
@@ -68,15 +69,15 @@
 
 # Shortcut for basic usage
 _urlopener = None
-def urlopen(url, data=None, proxies=None):
+def urlopen(url, data=None, proxies=None, context=None):
     """Create a file-like object for the specified URL to read from."""
     from warnings import warnpy3k
     warnpy3k("urllib.urlopen() has been removed in Python 3.0 in "
              "favor of urllib2.urlopen()", stacklevel=2)
 
     global _urlopener
-    if proxies is not None:
-        opener = FancyURLopener(proxies=proxies)
+    if proxies is not None or context is not None:
+        opener = FancyURLopener(proxies=proxies, context=context)
     elif not _urlopener:
         opener = FancyURLopener()
         _urlopener = opener
@@ -86,11 +87,15 @@
         return opener.open(url)
     else:
         return opener.open(url, data)
-def urlretrieve(url, filename=None, reporthook=None, data=None):
+def urlretrieve(url, filename=None, reporthook=None, data=None, context=None):
     global _urlopener
-    if not _urlopener:
-        _urlopener = FancyURLopener()
-    return _urlopener.retrieve(url, filename, reporthook, data)
+    if context is not None:
+        opener = FancyURLopener(context=context)
+    elif not _urlopener:
+        _urlopener = opener = FancyURLopener()
+    else:
+        opener = _urlopener
+    return opener.retrieve(url, filename, reporthook, data)
 def urlcleanup():
     if _urlopener:
         _urlopener.cleanup()
@@ -125,13 +130,14 @@
     version = "Python-urllib/%s" % __version__
 
     # Constructor
-    def __init__(self, proxies=None, **x509):
+    def __init__(self, proxies=None, context=None, **x509):
         if proxies is None:
             proxies = getproxies()
         assert hasattr(proxies, 'has_key'), "proxies must be a mapping"
         self.proxies = proxies
         self.key_file = x509.get('key_file')
         self.cert_file = x509.get('cert_file')
+        self.context = context
         self.addheaders = [('User-Agent', self.version)]
         self.__tempfiles = []
         self.__unlink = os.unlink # See cleanup()
@@ -421,7 +427,8 @@
                 auth = None
             h = httplib.HTTPS(host, 0,
                               key_file=self.key_file,
-                              cert_file=self.cert_file)
+                              cert_file=self.cert_file,
+                              context=self.context)
             if data is not None:
                 h.putrequest('POST', selector)
                 h.putheader('Content-Type',
@@ -818,7 +825,10 @@
     """Return the IP address of the current host."""
     global _thishost
     if _thishost is None:
-        _thishost = socket.gethostbyname(socket.gethostname())
+        try:
+            _thishost = socket.gethostbyname(socket.gethostname())
+        except socket.gaierror:
+            _thishost = socket.gethostbyname('localhost')
     return _thishost
 
 _ftperrors = None
@@ -861,7 +871,11 @@
         self.timeout = timeout
         self.refcount = 0
         self.keepalive = persistent
-        self.init()
+        try:
+            self.init()
+        except:
+            self.close()
+            raise
 
     def init(self):
         import ftplib
@@ -869,8 +883,8 @@
         self.ftp = ftplib.FTP()
         self.ftp.connect(self.host, self.port, self.timeout)
         self.ftp.login(self.user, self.passwd)
-        for dir in self.dirs:
-            self.ftp.cwd(dir)
+        _target = '/'.join(self.dirs)
+        self.ftp.cwd(_target)
 
     def retrfile(self, file, type):
         import ftplib
@@ -980,11 +994,16 @@
         self.hookargs = hookargs
 
     def close(self):
-        if self.closehook:
-            self.closehook(*self.hookargs)
-            self.closehook = None
-            self.hookargs = None
-        addbase.close(self)
+        try:
+            closehook = self.closehook
+            hookargs = self.hookargs
+            if closehook:
+                self.closehook = None
+                self.hookargs = None
+                closehook(*hookargs)
+        finally:
+            addbase.close(self)
+
 
 class addinfo(addbase):
     """class to add an info() method to an open file."""
@@ -1121,10 +1140,13 @@
     global _portprog
     if _portprog is None:
         import re
-        _portprog = re.compile('^(.*):([0-9]+)$')
+        _portprog = re.compile('^(.*):([0-9]*)$')
 
     match = _portprog.match(host)
-    if match: return match.group(1, 2)
+    if match:
+        host, port = match.groups()
+        if port:
+            return host, port
     return host, None
 
 _nportprog = None
@@ -1141,12 +1163,12 @@
     match = _nportprog.match(host)
     if match:
         host, port = match.group(1, 2)
-        try:
-            if not port: raise ValueError, "no digits"
-            nport = int(port)
-        except ValueError:
-            nport = None
-        return host, nport
+        if port:
+            try:
+                nport = int(port)
+            except ValueError:
+                nport = None
+            return host, nport
     return host, defport
 
 _queryprog = None
@@ -1198,22 +1220,35 @@
 _hexdig = '0123456789ABCDEFabcdef'
 _hextochr = dict((a + b, chr(int(a + b, 16)))
                  for a in _hexdig for b in _hexdig)
+_asciire = re.compile('([\x00-\x7f]+)')
 
 def unquote(s):
     """unquote('abc%20def') -> 'abc def'."""
-    res = s.split('%')
+    if _is_unicode(s):
+        if '%' not in s:
+            return s
+        bits = _asciire.split(s)
+        res = [bits[0]]
+        append = res.append
+        for i in range(1, len(bits), 2):
+            append(unquote(str(bits[i])).decode('latin1'))
+            append(bits[i + 1])
+        return ''.join(res)
+
+    bits = s.split('%')
     # fastpath
-    if len(res) == 1:
+    if len(bits) == 1:
         return s
-    s = res[0]
-    for item in res[1:]:
+    res = [bits[0]]
+    append = res.append
+    for item in bits[1:]:
         try:
-            s += _hextochr[item[:2]] + item[2:]
+            append(_hextochr[item[:2]])
+            append(item[2:])
         except KeyError:
-            s += '%' + item
-        except UnicodeDecodeError:
-            s += unichr(int(item[:2], 16)) + item[2:]
-    return s
+            append('%')
+            append(item)
+    return ''.join(res)
 
 def unquote_plus(s):
     """unquote('%7e/abc+def') -> '~/abc def'"""

-- 
Repository URL: https://hg.python.org/jython


More information about the Jython-checkins mailing list