[pypy-svn] pypy default: Add SSLObject.shutdown()

amauryfa commits-noreply at bitbucket.org
Tue Jan 18 18:08:33 CET 2011


Author: Amaury Forgeot d'Arc <amauryfa at gmail.com>
Branch: 
Changeset: r40884:27ab1f6e4568
Date: 2011-01-18 18:05 +0100
http://bitbucket.org/pypy/pypy/changeset/27ab1f6e4568/

Log:	Add SSLObject.shutdown()

diff --git a/pypy/module/_ssl/test/test_ssl.py b/pypy/module/_ssl/test/test_ssl.py
--- a/pypy/module/_ssl/test/test_ssl.py
+++ b/pypy/module/_ssl/test/test_ssl.py
@@ -61,12 +61,14 @@
         _ssl.RAND_egd("entropy")
 
     def test_sslwrap(self):
-        import _ssl
-        import _socket
+        import _ssl, _socket, sys
         s = _socket.socket()
         ss = _ssl.sslwrap(s, 0)
         exc = raises(_socket.error, ss.do_handshake)
-        assert exc.value.errno == 32 # Broken pipe
+        if sys.platform == 'win32':
+            assert exc.value.errno == 2 # Cannot find file (=not a socket)
+        else:
+            assert exc.value.errno == 32 # Broken pipe
 
 class AppTestConnectedSSL:
     def setup_class(cls):
@@ -132,6 +134,13 @@
         assert len(data) == 10
         self.s.close()

+    def test_shutdown(self):
+        import socket, ssl
+        ss = socket.ssl(self.s)
+        ss.write("hello\n")
+        assert ss.shutdown() is self.s._sock
+        raises(ssl.SSLError, ss.write, "hello\n")
+
 class AppTestConnectedSSL_Timeout(AppTestConnectedSSL):
     # Same tests, with a socket timeout
     # to exercise the poll() calls
 

diff --git a/pypy/module/_ssl/interp_ssl.py b/pypy/module/_ssl/interp_ssl.py
--- a/pypy/module/_ssl/interp_ssl.py
+++ b/pypy/module/_ssl/interp_ssl.py
@@ -133,6 +133,7 @@
         self._server[0] = '\0'
         self._issuer = lltype.malloc(rffi.CCHARP.TO, X509_NAME_MAXLEN, flavor='raw')
         self._issuer[0] = '\0'
+        self.shutdown_seen_zero = False
     
     def server(self):
         return self.space.wrap(rffi.charp2str(self._server))
@@ -157,6 +158,8 @@

         Writes the string s into the SSL object.  Returns the number
         of bytes written."""
+        self._refresh_nonblocking(self.space)
+
         sockstate = check_socket_and_wait_for_timeout(self.space,
             self.w_socket, True)
         if sockstate == SOCKET_HAS_TIMED_OUT:
@@ -248,13 +251,16 @@
         return self.space.wrap(result)
     read.unwrap_spec = ['self', int]

-    def do_handshake(self, space):
+    def _refresh_nonblocking(self, space):
         # just in case the blocking state of the socket has been changed
         w_timeout = space.call_method(self.w_socket, "gettimeout")
         nonblocking = not space.is_w(w_timeout, space.w_None)
         libssl_BIO_set_nbio(libssl_SSL_get_rbio(self.ssl), nonblocking)
         libssl_BIO_set_nbio(libssl_SSL_get_wbio(self.ssl), nonblocking)

+    def do_handshake(self, space):
+        self._refresh_nonblocking(space)
+
         # Actually negotiate SSL connection
         # XXX If SSL_do_handshake() returns 0, it's also a failure.
         while True:
@@ -297,6 +303,75 @@
                 libssl_X509_get_issuer_name(self.peer_cert),
                 self._issuer, X509_NAME_MAXLEN)

+    def shutdown(self, space):
+        """shutdown(s) -> socket
+
+        Does the SSL shutdown handshake with the remote end, and returns
+        the underlying socket object.  An error is raised if the socket
+        has already been closed, if waiting on it times out, or if
+        SSL_shutdown() reports a failure."""
+        # Guard against closed socket
+        w_fileno = space.call_method(self.w_socket, "fileno")
+        if space.int_w(w_fileno) < 0:
+            raise ssl_error(space, "Underlying socket has been closed")
+
+        self._refresh_nonblocking(space)
+
+        zeros = 0
+
+        while True:
+            # Disable read-ahead so that unwrap can work correctly.
+            # Otherwise OpenSSL might read in too much data,
+            # eating clear text data that happens to be
+            # transmitted after the SSL shutdown.
+            # Should be safe to call repeatedly every time this
+            # function is used once the shutdown_seen_zero flag
+            # has been set.
+            if self.shutdown_seen_zero:
+                libssl_SSL_set_read_ahead(self.ssl, 0)
+            ret = libssl_SSL_shutdown(self.ssl)
+
+            # if ret == 1, a secure shutdown with SSL_shutdown() is complete
+            if ret > 0:
+                break
+            if ret == 0:
+                # Don't loop endlessly; instead preserve legacy
+                # behaviour of trying SSL_shutdown() only twice.
+                # This looks necessary for OpenSSL < 0.9.8m
+                zeros += 1
+                if zeros > 1:
+                    break
+                # Shutdown was sent, now try receiving
+                self.shutdown_seen_zero = True
+                continue
+
+            # Possibly retry shutdown until timeout or failure
+            ssl_err = libssl_SSL_get_error(self.ssl, ret)
+            if ssl_err == SSL_ERROR_WANT_READ:
+                sockstate = check_socket_and_wait_for_timeout(
+                    space, self.w_socket, False)
+            elif ssl_err == SSL_ERROR_WANT_WRITE:
+                sockstate = check_socket_and_wait_for_timeout(
+                    space, self.w_socket, True)
+            else:
+                break
+
+            if sockstate == SOCKET_HAS_TIMED_OUT:
+                if ssl_err == SSL_ERROR_WANT_READ:
+                    raise ssl_error(space, "The read operation timed out")
+                else:
+                    raise ssl_error(space, "The write operation timed out")
+            elif sockstate == SOCKET_TOO_LARGE_FOR_SELECT:
+                raise ssl_error(space, "Underlying socket too large for select().")
+            elif sockstate != SOCKET_OPERATION_OK:
+                # Retain the SSL error code
+                break
+
+        if ret < 0:
+            raise _ssl_seterror(space, self, ret)
+
+        return self.w_socket
+

 SSLObject.typedef = TypeDef("SSLObject",
     server = interp2app(SSLObject.server,
@@ -308,6 +383,8 @@
     read = interp2app(SSLObject.read, unwrap_spec=SSLObject.read.unwrap_spec),
     do_handshake=interp2app(SSLObject.do_handshake,
                             unwrap_spec=['self', ObjSpace]),
+    shutdown=interp2app(SSLObject.shutdown,
+                        unwrap_spec=['self', ObjSpace]),
 )


@@ -357,6 +434,7 @@
     libssl_SSL_CTX_set_verify(ss.ctx, SSL_VERIFY_NONE, None) # set verify level
     ss.ssl = libssl_SSL_new(ss.ctx) # new ssl struct
     libssl_SSL_set_fd(ss.ssl, sock_fd) # set the socket for SSL
+    libssl_SSL_set_mode(ss.ssl, SSL_MODE_AUTO_RETRY) # hide renegotiation retries on blocking sockets

     # If the socket is in non-blocking mode or timeout mode, set the BIO
     # to non-blocking mode (blocking is the default)

diff --git a/pypy/rlib/ropenssl.py b/pypy/rlib/ropenssl.py
--- a/pypy/rlib/ropenssl.py
+++ b/pypy/rlib/ropenssl.py
@@ -66,7 +66,9 @@
     SSL_ERROR_SYSCALL = rffi_platform.ConstantInteger("SSL_ERROR_SYSCALL")
     SSL_ERROR_SSL = rffi_platform.ConstantInteger("SSL_ERROR_SSL")
     SSL_CTRL_OPTIONS = rffi_platform.ConstantInteger("SSL_CTRL_OPTIONS")
+    SSL_CTRL_MODE = rffi_platform.ConstantInteger("SSL_CTRL_MODE")  # SSL_ctrl() command used by libssl_SSL_set_mode()
     BIO_C_SET_NBIO = rffi_platform.ConstantInteger("BIO_C_SET_NBIO")
+    SSL_MODE_AUTO_RETRY = rffi_platform.ConstantInteger("SSL_MODE_AUTO_RETRY")  # mode bit passed to libssl_SSL_set_mode()
 
 for k, v in rffi_platform.configure(CConfig).items():
     globals()[k] = v
@@ -105,6 +107,7 @@
 ssl_external('SSL_CTX_set_verify', [SSL_CTX, rffi.INT, rffi.VOIDP], lltype.Void)
 ssl_external('SSL_new', [SSL_CTX], SSL)
 ssl_external('SSL_set_fd', [SSL, rffi.INT], rffi.INT)
+ssl_external('SSL_ctrl', [SSL, rffi.INT, rffi.INT, rffi.VOIDP], rffi.INT)  # NOTE(review): the C prototype takes and returns long; INT mirrors BIO_ctrl below - confirm on LP64
 ssl_external('BIO_ctrl', [BIO, rffi.INT, rffi.INT, rffi.VOIDP], rffi.INT)
 ssl_external('SSL_get_rbio', [SSL], BIO)
 ssl_external('SSL_get_wbio', [SSL], BIO)
@@ -153,6 +156,8 @@
 EVP_MD_CTX_cleanup = external(
     'EVP_MD_CTX_cleanup', [EVP_MD_CTX], rffi.INT)

+def libssl_SSL_set_mode(ssl, op):  # same as the C macro SSL_set_mode(): SSL_ctrl(ssl, SSL_CTRL_MODE, op, NULL)
+    return libssl_SSL_ctrl(ssl, SSL_CTRL_MODE, op, None)
 def libssl_SSL_CTX_set_options(ctx, op):
     return libssl_SSL_CTX_ctrl(ctx, SSL_CTRL_OPTIONS, op, None)
 def libssl_BIO_set_nbio(bio, nonblocking):


More information about the Pypy-commit mailing list