[Python-checkins] r66310 - in python/trunk: Doc/library/ssl.rst Lib/ssl.py Lib/test/test_ssl.py

bill.janssen python-checkins at python.org
Mon Sep 8 18:37:24 CEST 2008


Author: bill.janssen
Date: Mon Sep  8 18:37:24 2008
New Revision: 66310

Log:
incorporate fixes from issue 3162; SSL doc patch

Modified:
   python/trunk/Doc/library/ssl.rst
   python/trunk/Lib/ssl.py
   python/trunk/Lib/test/test_ssl.py

Modified: python/trunk/Doc/library/ssl.rst
==============================================================================
--- python/trunk/Doc/library/ssl.rst	(original)
+++ python/trunk/Doc/library/ssl.rst	Mon Sep  8 18:37:24 2008
@@ -327,9 +327,10 @@
    Performs the SSL shutdown handshake, which removes the TLS layer
    from the underlying socket, and returns the underlying socket
    object.  This can be used to go from encrypted operation over a
-   connection to unencrypted.  The returned socket should always be
+   connection to unencrypted.  The socket instance returned should always be
    used for further communication with the other side of the
-   connection, rather than the original socket
+   connection, rather than the original socket instance (which may
+   not function properly after the unwrap).
 
 .. index:: single: certificates
 

Modified: python/trunk/Lib/ssl.py
==============================================================================
--- python/trunk/Lib/ssl.py	(original)
+++ python/trunk/Lib/ssl.py	Mon Sep  8 18:37:24 2008
@@ -91,10 +91,12 @@
                  suppress_ragged_eofs=True):
         socket.__init__(self, _sock=sock._sock)
         # the initializer for socket trashes the methods (tsk, tsk), so...
-        self.send = lambda x, flags=0: SSLSocket.send(self, x, flags)
-        self.recv = lambda x, flags=0: SSLSocket.recv(self, x, flags)
+        self.send = lambda data, flags=0: SSLSocket.send(self, data, flags)
         self.sendto = lambda data, addr, flags=0: SSLSocket.sendto(self, data, addr, flags)
-        self.recvfrom = lambda addr, buflen, flags: SSLSocket.recvfrom(self, addr, buflen, flags)
+        self.recv = lambda buflen=1024, flags=0: SSLSocket.recv(self, buflen, flags)
+        self.recvfrom = lambda addr, buflen=1024, flags=0: SSLSocket.recvfrom(self, addr, buflen, flags)
+        self.recv_into = lambda buffer, nbytes=None, flags=0: SSLSocket.recv_into(self, buffer, nbytes, flags)
+        self.recvfrom_into = lambda buffer, nbytes=None, flags=0: SSLSocket.recvfrom_into(self, buffer, nbytes, flags)
 
         if certfile and not keyfile:
             keyfile = certfile
@@ -221,6 +223,30 @@
         else:
             return socket.recv(self, buflen, flags)
 
+    def recv_into (self, buffer, nbytes=None, flags=0):
+        if buffer and (nbytes is None):
+            nbytes = len(buffer)  # default: fill the caller's whole buffer
+        elif nbytes is None:
+            nbytes = 1024  # NOTE(review): empty buffer here; on the SSL path the slice assignment below grows a bytearray
+        if self._sslobj:
+            if flags != 0:
+                raise ValueError(
+                  "non-zero flags not allowed in calls to recv_into() on %s" %
+                  self.__class__)
+            while True:
+                try:
+                    tmp_buffer = self.read(nbytes)  # decrypt via SSL, then copy into the caller's buffer
+                    v = len(tmp_buffer)
+                    buffer[:v] = tmp_buffer
+                    return v
+                except SSLError as x:
+                    if x.args[0] == SSL_ERROR_WANT_READ:
+                        continue  # SSL wants more input: retry the read
+                    else:
+                        raise  # bare raise keeps the original traceback
+        else:
+            return self._sock.recv_into(buffer, nbytes, flags)  # socket.__init__ overwrites these methods per instance (see __init__), so call the wrapped _sock
+
     def recvfrom (self, addr, buflen=1024, flags=0):
         if self._sslobj:
             raise ValueError("recvfrom not allowed on instances of %s" %
@@ -228,6 +254,13 @@
         else:
             return socket.recvfrom(self, addr, buflen, flags)
 
+    def recvfrom_into (self, buffer, nbytes=None, flags=0):
+        if self._sslobj:  # refused while an SSL object is active, like recvfrom
+            raise ValueError("recvfrom_into not allowed on instances of %s" %
+                             self.__class__)
+        else:
+            return self._sock.recvfrom_into(buffer, nbytes or 0, flags)  # nbytes=None -> 0, which the real socket treats as len(buffer)
+
     def pending (self):
         if self._sslobj:
             return self._sslobj.pending()
@@ -295,8 +328,9 @@
 
     def makefile(self, mode='r', bufsize=-1):
 
-        """Ouch.  Need to make and return a file-like object that
-        works with the SSL connection."""
+        """Return a file-like object (the socket module's _fileobject)
+        that reads and writes through this SSL connection.  Each call
+        increments self._makefile_refs."""
 
         self._makefile_refs += 1
         return _fileobject(self, mode, bufsize)

Modified: python/trunk/Lib/test/test_ssl.py
==============================================================================
--- python/trunk/Lib/test/test_ssl.py	(original)
+++ python/trunk/Lib/test/test_ssl.py	Mon Sep  8 18:37:24 2008
@@ -1030,6 +1030,127 @@
                 server.join()
 
 
+        def testAllRecvAndSendMethods(self):  # each send*/recv* variant must echo the data or raise a ValueError naming itself
+
+            if test_support.verbose:
+                sys.stdout.write("\n")
+
+            server = ThreadedEchoServer(CERTFILE,
+                                        certreqs=ssl.CERT_NONE,
+                                        ssl_version=ssl.PROTOCOL_TLSv1,
+                                        cacerts=CERTFILE,
+                                        chatty=True,
+                                        connectionchatty=False)
+            flag = threading.Event()
+            server.start(flag)
+            # wait for it to start
+            flag.wait()
+            # try to connect
+            try:
+                s = ssl.wrap_socket(socket.socket(),
+                                    server_side=False,
+                                    certfile=CERTFILE,
+                                    ca_certs=CERTFILE,
+                                    cert_reqs=ssl.CERT_NONE,
+                                    ssl_version=ssl.PROTOCOL_TLSv1)
+                s.connect((HOST, server.port))
+            except ssl.SSLError as x:
+                raise test_support.TestFailed("Unexpected SSL error:  " + str(x))
+            except Exception as x:
+                raise test_support.TestFailed("Unexpected exception:  " + str(x))
+            else:
+                # helper methods for standardising recv* method signatures
+                def _recv_into():
+                    b = bytearray("\0"*100)
+                    count = s.recv_into(b)
+                    return b[:count]
+
+                def _recvfrom_into():
+                    b = bytearray("\0"*100)
+                    count, addr = s.recvfrom_into(b)
+                    return b[:count]
+
+                # (name, method, whether to expect success, *args)
+                send_methods = [
+                    ('send', s.send, True, []),
+                    ('sendto', s.sendto, False, ["some.address"]),
+                    ('sendall', s.sendall, True, []),
+                ]
+                recv_methods = [
+                    ('recv', s.recv, True, []),
+                    ('recvfrom', s.recvfrom, False, ["some.address"]),
+                    ('recv_into', _recv_into, True, []),
+                    ('recvfrom_into', _recvfrom_into, False, []),
+                ]
+                data_prefix = u"PREFIX_"
+
+                for meth_name, send_meth, expect_success, args in send_methods:
+                    indata = data_prefix + meth_name
+                    try:
+                        send_meth(indata.encode('ASCII', 'strict'), *args)
+                        outdata = s.read()
+                        outdata = outdata.decode('ASCII', 'strict')
+                        if outdata != indata.lower():
+                            raise test_support.TestFailed(
+                                "While sending with <<%s>> bad data "
+                                "<<%r>> (%d) received; "
+                                "expected <<%r>> (%d)\n" % (
+                                    meth_name, outdata[:20], len(outdata),
+                                    indata[:20], len(indata)
+                                )
+                            )
+                    except ValueError as e:
+                        if expect_success:
+                            raise test_support.TestFailed(
+                                "Failed to send with method <<%s>>; "
+                                "expected to succeed.\n" % (meth_name,)
+                            )
+                        if not str(e).startswith(meth_name):
+                            raise test_support.TestFailed(
+                                "Method <<%s>> failed with unexpected "
+                                "exception message: %s\n" % (
+                                    meth_name, e
+                                )
+                            )
+
+                for meth_name, recv_meth, expect_success, args in recv_methods:
+                    indata = data_prefix + meth_name
+                    try:
+                        s.send(indata.encode('ASCII', 'strict'))
+                        outdata = recv_meth(*args)
+                        outdata = outdata.decode('ASCII', 'strict')
+                        if outdata != indata.lower():
+                            raise test_support.TestFailed(
+                                "While receiving with <<%s>> bad data "
+                                "<<%r>> (%d) received; "
+                                "expected <<%r>> (%d)\n" % (
+                                    meth_name, outdata[:20], len(outdata),
+                                    indata[:20], len(indata)
+                                )
+                            )
+                    except ValueError as e:
+                        if expect_success:
+                            raise test_support.TestFailed(
+                                "Failed to receive with method <<%s>>; "
+                                "expected to succeed.\n" % (meth_name,)
+                            )
+                        if not str(e).startswith(meth_name):
+                            raise test_support.TestFailed(
+                                "Method <<%s>> failed with unexpected "
+                                "exception message: %s\n" % (
+                                    meth_name, e
+                                )
+                            )
+                        # consume data
+                        s.read()
+
+                s.write("over\n".encode("ASCII", "strict"))
+                s.close()
+            finally:
+                server.stop()
+                server.join()
+
+
 def test_main(verbose=False):
     if skip_expected:
         raise test_support.TestSkipped("No SSL support")


More information about the Python-checkins mailing list