[Python-checkins] bpo-34271: Add ssl debugging helpers (GH-10031)

Christian Heimes webhook-mailer at python.org
Fri May 31 05:44:09 EDT 2019


https://github.com/python/cpython/commit/c7f7069e77c58e83b847c0bfe4d5aadf6add2e68
commit: c7f7069e77c58e83b847c0bfe4d5aadf6add2e68
branch: master
author: Christian Heimes <christian at python.org>
committer: GitHub <noreply at github.com>
date: 2019-05-31T11:44:05+02:00
summary:

bpo-34271: Add ssl debugging helpers (GH-10031)

The ssl module can now dump key material to a keylog file and trace TLS
protocol messages with a tracing callback. The default and stdlib
contexts also support the SSLKEYLOGFILE environment variable.

The msg_callback and related enums are private members. The feature
is designed for internal debugging and not for end users.

Signed-off-by: Christian Heimes <christian at python.org>

files:
A Misc/NEWS.d/next/Library/2018-10-21-17-39-32.bpo-34271.P15VLM.rst
A Modules/_ssl/debughelpers.c
M Doc/library/ssl.rst
M Lib/ssl.py
M Lib/test/test_ssl.py
M Modules/_ssl.c
M setup.py

diff --git a/Doc/library/ssl.rst b/Doc/library/ssl.rst
index 20f572444716..be09f38f7dfa 100644
--- a/Doc/library/ssl.rst
+++ b/Doc/library/ssl.rst
@@ -139,6 +139,10 @@ purposes.
    *cadata* is given) or uses :meth:`SSLContext.load_default_certs` to load
    default CA certificates.
 
+   When :attr:`~SSLContext.keylog_filename` is supported and the environment
+   variable :envvar:`SSLKEYLOGFILE` is set (and not ignored via :option:`-E`),
+   :func:`create_default_context` enables key logging.
+
    .. note::
       The protocol, options, cipher and other settings may change to more
       restrictive values anytime without prior deprecation.  The values
@@ -172,6 +176,10 @@ purposes.
 
      3DES was dropped from the default cipher string.
 
+   .. versionchanged:: 3.8
+
+      Support for key logging to :envvar:`SSLKEYLOGFILE` was added.
+
 
 Exceptions
 ^^^^^^^^^^
@@ -1056,6 +1064,7 @@ Constants
 
    SSL 3.0 to TLS 1.3.
 
+
 SSL Sockets
 -----------
 
@@ -1901,6 +1910,20 @@ to speed up repeated connections from the same clients.
 
      This features requires OpenSSL 0.9.8f or newer.
 
+.. attribute:: SSLContext.keylog_filename
+
+   Write TLS keys to a keylog file whenever key material is generated or
+   received. The keylog file is designed for debugging purposes only. The
+   file format is specified by NSS and used by many traffic analyzers such
+   as Wireshark. The log file is opened in append-only mode. Writes are
+   synchronized between threads, but not between processes.
+
+   .. versionadded:: 3.8
+
+   .. note::
+
+     This feature requires OpenSSL 1.1.1 or newer.
+
 .. attribute:: SSLContext.maximum_version
 
    A :class:`TLSVersion` enum member representing the highest supported
diff --git a/Lib/ssl.py b/Lib/ssl.py
index 793ed496c77a..f5fa6aeec2d2 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -165,6 +165,90 @@ class TLSVersion(_IntEnum):
     MAXIMUM_SUPPORTED = _ssl.PROTO_MAXIMUM_SUPPORTED
 
 
+class _TLSContentType(_IntEnum):
+    """Content types (record layer)
+
+    See RFC 8446, section B.1
+    """
+    CHANGE_CIPHER_SPEC = 20
+    ALERT = 21
+    HANDSHAKE = 22
+    APPLICATION_DATA = 23
+    # pseudo content types
+    HEADER = 0x100
+    INNER_CONTENT_TYPE = 0x101
+
+
+class _TLSAlertType(_IntEnum):
+    """Alert types for TLSContentType.ALERT messages
+
+    See RFC 8466, section B.2
+    """
+    CLOSE_NOTIFY = 0
+    UNEXPECTED_MESSAGE = 10
+    BAD_RECORD_MAC = 20
+    DECRYPTION_FAILED = 21
+    RECORD_OVERFLOW = 22
+    DECOMPRESSION_FAILURE = 30
+    HANDSHAKE_FAILURE = 40
+    NO_CERTIFICATE = 41
+    BAD_CERTIFICATE = 42
+    UNSUPPORTED_CERTIFICATE = 43
+    CERTIFICATE_REVOKED = 44
+    CERTIFICATE_EXPIRED = 45
+    CERTIFICATE_UNKNOWN = 46
+    ILLEGAL_PARAMETER = 47
+    UNKNOWN_CA = 48
+    ACCESS_DENIED = 49
+    DECODE_ERROR = 50
+    DECRYPT_ERROR = 51
+    EXPORT_RESTRICTION = 60
+    PROTOCOL_VERSION = 70
+    INSUFFICIENT_SECURITY = 71
+    INTERNAL_ERROR = 80
+    INAPPROPRIATE_FALLBACK = 86
+    USER_CANCELED = 90
+    NO_RENEGOTIATION = 100
+    MISSING_EXTENSION = 109
+    UNSUPPORTED_EXTENSION = 110
+    CERTIFICATE_UNOBTAINABLE = 111
+    UNRECOGNIZED_NAME = 112
+    BAD_CERTIFICATE_STATUS_RESPONSE = 113
+    BAD_CERTIFICATE_HASH_VALUE = 114
+    UNKNOWN_PSK_IDENTITY = 115
+    CERTIFICATE_REQUIRED = 116
+    NO_APPLICATION_PROTOCOL = 120
+
+
+class _TLSMessageType(_IntEnum):
+    """Message types (handshake protocol)
+
+    See RFC 8446, section B.3
+    """
+    HELLO_REQUEST = 0
+    CLIENT_HELLO = 1
+    SERVER_HELLO = 2
+    HELLO_VERIFY_REQUEST = 3
+    NEWSESSION_TICKET = 4
+    END_OF_EARLY_DATA = 5
+    HELLO_RETRY_REQUEST = 6
+    ENCRYPTED_EXTENSIONS = 8
+    CERTIFICATE = 11
+    SERVER_KEY_EXCHANGE = 12
+    CERTIFICATE_REQUEST = 13
+    SERVER_DONE = 14
+    CERTIFICATE_VERIFY = 15
+    CLIENT_KEY_EXCHANGE = 16
+    FINISHED = 20
+    CERTIFICATE_URL = 21
+    CERTIFICATE_STATUS = 22
+    SUPPLEMENTAL_DATA = 23
+    KEY_UPDATE = 24
+    NEXT_PROTO = 67
+    MESSAGE_HASH = 254
+    CHANGE_CIPHER_SPEC = 0x0101
+
+
 if sys.platform == "win32":
     from _ssl import enum_certificates, enum_crls
 
@@ -523,6 +607,83 @@ def hostname_checks_common_name(self, value):
         def hostname_checks_common_name(self):
             return True
 
+    @property
+    def _msg_callback(self):
+        """TLS message callback
+
+        The message callback provides a debugging hook to analyze TLS
+        connections. The callback is called for any TLS protocol message
+        (header, handshake, alert, and more), but not for application data.
+        Due to technical  limitations, the callback can't be used to filter
+        traffic or to abort a connection. Any exception raised in the
+        callback is delayed until the handshake, read, or write operation
+        has been performed.
+
+        def msg_cb(conn, direction, version, content_type, msg_type, data):
+            pass
+
+        conn
+            :class:`SSLSocket` or :class:`SSLObject` instance
+        direction
+            ``read`` or ``write``
+        version
+            :class:`TLSVersion` enum member or int for unknown version. For a
+            frame header, it's the header version.
+        content_type
+            :class:`_TLSContentType` enum member or int for unsupported
+            content type.
+        msg_type
+            Either a :class:`_TLSContentType` enum number for a header
+            message, a :class:`_TLSAlertType` enum member for an alert
+            message, a :class:`_TLSMessageType` enum member for other
+            messages, or int for unsupported message types.
+        data
+            Raw, decrypted message content as bytes
+        """
+        inner = super()._msg_callback
+        if inner is not None:
+            return inner.user_function
+        else:
+            return None
+
+    @_msg_callback.setter
+    def _msg_callback(self, callback):
+        if callback is None:
+            super(SSLContext, SSLContext)._msg_callback.__set__(self, None)
+            return
+
+        if not hasattr(callback, '__call__'):
+            raise TypeError(f"{callback} is not callable.")
+
+        def inner(conn, direction, version, content_type, msg_type, data):
+            try:
+                version = TLSVersion(version)
+            except TypeError:
+                pass
+
+            try:
+                content_type = _TLSContentType(content_type)
+            except TypeError:
+                pass
+
+            if content_type == _TLSContentType.HEADER:
+                msg_enum = _TLSContentType
+            elif content_type == _TLSContentType.ALERT:
+                msg_enum = _TLSAlertType
+            else:
+                msg_enum = _TLSMessageType
+            try:
+                msg_type = msg_enum(msg_type)
+            except TypeError:
+                pass
+
+            return callback(conn, direction, version,
+                            content_type, msg_type, data)
+
+        inner.user_function = callback
+
+        super(SSLContext, SSLContext)._msg_callback.__set__(self, inner)
+
     @property
     def protocol(self):
         return _SSLMethod(super().protocol)
@@ -576,6 +737,11 @@ def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None,
         # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system
         # root CA certificates for the given purpose. This may fail silently.
         context.load_default_certs(purpose)
+    # OpenSSL 1.1.1 keylog file
+    if hasattr(context, 'keylog_filename'):
+        keylogfile = os.environ.get('SSLKEYLOGFILE')
+        if keylogfile and not sys.flags.ignore_environment:
+            context.keylog_filename = keylogfile
     return context
 
 def _create_unverified_context(protocol=PROTOCOL_TLS, *, cert_reqs=CERT_NONE,
@@ -617,7 +783,11 @@ def _create_unverified_context(protocol=PROTOCOL_TLS, *, cert_reqs=CERT_NONE,
         # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system
         # root CA certificates for the given purpose. This may fail silently.
         context.load_default_certs(purpose)
-
+    # OpenSSL 1.1.1 keylog file
+    if hasattr(context, 'keylog_filename'):
+        keylogfile = os.environ.get('SSLKEYLOGFILE')
+        if keylogfile and not sys.flags.ignore_environment:
+            context.keylog_filename = keylogfile
     return context
 
 # Used by http.client if no context is explicitly passed.
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index d48d6e5569fc..f368906c8a94 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -2,6 +2,7 @@
 
 import sys
 import unittest
+import unittest.mock
 from test import support
 import socket
 import select
@@ -25,6 +26,7 @@
 
 ssl = support.import_module("ssl")
 
+from ssl import TLSVersion, _TLSContentType, _TLSMessageType, _TLSAlertType
 
 PROTOCOLS = sorted(ssl._PROTOCOL_NAMES)
 HOST = support.HOST
@@ -4405,6 +4407,170 @@ def test_pha_not_tls13(self):
                 self.assertIn(b'WRONG_SSL_VERSION', s.recv(1024))
 
 
+HAS_KEYLOG = hasattr(ssl.SSLContext, 'keylog_filename')
+requires_keylog = unittest.skipUnless(
+    HAS_KEYLOG, 'test requires OpenSSL 1.1.1 with keylog callback')
+
+class TestSSLDebug(unittest.TestCase):
+
+    def keylog_lines(self, fname=support.TESTFN):
+        with open(fname) as f:
+            return len(list(f))
+
+    @requires_keylog
+    def test_keylog_defaults(self):
+        self.addCleanup(support.unlink, support.TESTFN)
+        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+        self.assertEqual(ctx.keylog_filename, None)
+
+        self.assertFalse(os.path.isfile(support.TESTFN))
+        ctx.keylog_filename = support.TESTFN
+        self.assertEqual(ctx.keylog_filename, support.TESTFN)
+        self.assertTrue(os.path.isfile(support.TESTFN))
+        self.assertEqual(self.keylog_lines(), 1)
+
+        ctx.keylog_filename = None
+        self.assertEqual(ctx.keylog_filename, None)
+
+        with self.assertRaises((IsADirectoryError, PermissionError)):
+            # Windows raises PermissionError
+            ctx.keylog_filename = os.path.dirname(
+                os.path.abspath(support.TESTFN))
+
+        with self.assertRaises(TypeError):
+            ctx.keylog_filename = 1
+
+    @requires_keylog
+    def test_keylog_filename(self):
+        self.addCleanup(support.unlink, support.TESTFN)
+        client_context, server_context, hostname = testing_context()
+
+        client_context.keylog_filename = support.TESTFN
+        server = ThreadedEchoServer(context=server_context, chatty=False)
+        with server:
+            with client_context.wrap_socket(socket.socket(),
+                                            server_hostname=hostname) as s:
+                s.connect((HOST, server.port))
+        # header, 5 lines for TLS 1.3
+        self.assertEqual(self.keylog_lines(), 6)
+
+        client_context.keylog_filename = None
+        server_context.keylog_filename = support.TESTFN
+        server = ThreadedEchoServer(context=server_context, chatty=False)
+        with server:
+            with client_context.wrap_socket(socket.socket(),
+                                            server_hostname=hostname) as s:
+                s.connect((HOST, server.port))
+        self.assertGreaterEqual(self.keylog_lines(), 11)
+
+        client_context.keylog_filename = support.TESTFN
+        server_context.keylog_filename = support.TESTFN
+        server = ThreadedEchoServer(context=server_context, chatty=False)
+        with server:
+            with client_context.wrap_socket(socket.socket(),
+                                            server_hostname=hostname) as s:
+                s.connect((HOST, server.port))
+        self.assertGreaterEqual(self.keylog_lines(), 21)
+
+        client_context.keylog_filename = None
+        server_context.keylog_filename = None
+
+    @requires_keylog
+    @unittest.skipIf(sys.flags.ignore_environment,
+                     "test is not compatible with ignore_environment")
+    def test_keylog_env(self):
+        self.addCleanup(support.unlink, support.TESTFN)
+        with unittest.mock.patch.dict(os.environ):
+            os.environ['SSLKEYLOGFILE'] = support.TESTFN
+            self.assertEqual(os.environ['SSLKEYLOGFILE'], support.TESTFN)
+
+            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+            self.assertEqual(ctx.keylog_filename, None)
+
+            ctx = ssl.create_default_context()
+            self.assertEqual(ctx.keylog_filename, support.TESTFN)
+
+            ctx = ssl._create_stdlib_context()
+            self.assertEqual(ctx.keylog_filename, support.TESTFN)
+
+    def test_msg_callback(self):
+        client_context, server_context, hostname = testing_context()
+
+        def msg_cb(conn, direction, version, content_type, msg_type, data):
+            pass
+
+        self.assertIs(client_context._msg_callback, None)
+        client_context._msg_callback = msg_cb
+        self.assertIs(client_context._msg_callback, msg_cb)
+        with self.assertRaises(TypeError):
+            client_context._msg_callback = object()
+
+    def test_msg_callback_tls12(self):
+        client_context, server_context, hostname = testing_context()
+        client_context.options |= ssl.OP_NO_TLSv1_3
+
+        msg = []
+
+        def msg_cb(conn, direction, version, content_type, msg_type, data):
+            self.assertIsInstance(conn, ssl.SSLSocket)
+            self.assertIsInstance(data, bytes)
+            self.assertIn(direction, {'read', 'write'})
+            msg.append((direction, version, content_type, msg_type))
+
+        client_context._msg_callback = msg_cb
+
+        server = ThreadedEchoServer(context=server_context, chatty=False)
+        with server:
+            with client_context.wrap_socket(socket.socket(),
+                                            server_hostname=hostname) as s:
+                s.connect((HOST, server.port))
+
+        self.assertEqual(msg, [
+            ("write", TLSVersion.TLSv1, _TLSContentType.HEADER,
+             _TLSMessageType.CERTIFICATE_STATUS),
+            ("write", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+             _TLSMessageType.CLIENT_HELLO),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+             _TLSMessageType.CERTIFICATE_STATUS),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+             _TLSMessageType.SERVER_HELLO),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+             _TLSMessageType.CERTIFICATE_STATUS),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+             _TLSMessageType.CERTIFICATE),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+             _TLSMessageType.CERTIFICATE_STATUS),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+             _TLSMessageType.SERVER_KEY_EXCHANGE),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+             _TLSMessageType.CERTIFICATE_STATUS),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+             _TLSMessageType.SERVER_DONE),
+            ("write", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+             _TLSMessageType.CERTIFICATE_STATUS),
+            ("write", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+             _TLSMessageType.CLIENT_KEY_EXCHANGE),
+            ("write", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+             _TLSMessageType.FINISHED),
+            ("write", TLSVersion.TLSv1_2, _TLSContentType.CHANGE_CIPHER_SPEC,
+             _TLSMessageType.CHANGE_CIPHER_SPEC),
+            ("write", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+             _TLSMessageType.CERTIFICATE_STATUS),
+            ("write", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+             _TLSMessageType.FINISHED),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+             _TLSMessageType.CERTIFICATE_STATUS),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+             _TLSMessageType.NEWSESSION_TICKET),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+             _TLSMessageType.FINISHED),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+             _TLSMessageType.CERTIFICATE_STATUS),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+             _TLSMessageType.FINISHED),
+        ])
+
+
 def test_main(verbose=False):
     if support.verbose:
         import warnings
@@ -4440,7 +4606,7 @@ def test_main(verbose=False):
     tests = [
         ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests,
         SSLObjectTests, SimpleBackgroundTests, ThreadedTests,
-        TestPostHandshakeAuth
+        TestPostHandshakeAuth, TestSSLDebug
     ]
 
     if support.is_resource_enabled('network'):
diff --git a/Misc/NEWS.d/next/Library/2018-10-21-17-39-32.bpo-34271.P15VLM.rst b/Misc/NEWS.d/next/Library/2018-10-21-17-39-32.bpo-34271.P15VLM.rst
new file mode 100644
index 000000000000..344388f7f228
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2018-10-21-17-39-32.bpo-34271.P15VLM.rst
@@ -0,0 +1,3 @@
+Add debugging helpers to the ssl module. It's now possible to dump key
+material and to trace TLS protocol messages. The default and stdlib
+contexts also support the SSLKEYLOGFILE environment variable.
diff --git a/Modules/_ssl.c b/Modules/_ssl.c
index 4fb7dca9bb04..f40127d3d932 100644
--- a/Modules/_ssl.c
+++ b/Modules/_ssl.c
@@ -185,6 +185,10 @@ static void _PySSLFixErrno(void) {
 # define HAVE_NPN 0
 #endif
 
+/* Key logging uses SSL_CTX_set_keylog_callback(), available since OpenSSL
+ * 1.1.1; LibreSSL is excluded explicitly. */
+#if !defined(LIBRESSL_VERSION_NUMBER) && OPENSSL_VERSION_NUMBER >= 0x10101000L
+# define HAVE_OPENSSL_KEYLOG 1
+#endif
+
 #ifndef INVALID_SOCKET /* MS defines this */
 #define INVALID_SOCKET (-1)
 #endif
@@ -423,6 +427,11 @@ typedef struct {
     int protocol;
 #ifdef TLS1_3_VERSION
     int post_handshake_auth;
+#endif
+    PyObject *msg_cb;
+#ifdef HAVE_OPENSSL_KEYLOG
+    PyObject *keylog_filename;
+    BIO *keylog_bio;
 #endif
 } PySSLContext;
 
@@ -444,6 +453,13 @@ typedef struct {
     PyObject *owner; /* Python level "owner" passed to servername callback */
     PyObject *server_hostname;
     _PySSLError err; /* last seen error from various sources */
+    /* Some SSL callbacks don't have error reporting. Callback wrappers
+     * store exception information on the socket. The handshake, read, write,
+     * and shutdown methods check for chained exceptions.
+     */
+    PyObject *exc_type;
+    PyObject *exc_value;
+    PyObject *exc_tb;
 } PySSLSocket;
 
 typedef struct {
@@ -517,6 +533,8 @@ typedef enum {
 #define GET_SOCKET_TIMEOUT(sock) \
     ((sock != NULL) ? (sock)->sock_timeout : 0)
 
+#include "_ssl/debughelpers.c"
+
 /*
  * SSL errors.
  */
@@ -703,6 +721,18 @@ fill_and_set_sslerror(PySSLSocket *sslsock, PyObject *type, int ssl_errno,
     Py_XDECREF(verify_obj);
 }
 
+/* Re-raise an exception stashed on the socket by an OpenSSL callback.
+ *
+ * Returns 0 when nothing is pending. Otherwise the stored triple is handed
+ * to _PyErr_ChainExceptions() (which takes over the three references), the
+ * socket's exc_* fields are reset, and -1 is returned so the caller can
+ * propagate the error.
+ */
+static int
+PySSL_ChainExceptions(PySSLSocket *sslsock) {
+    PyObject *type = sslsock->exc_type;
+    PyObject *value, *tb;
+
+    if (type == NULL) {
+        return 0;
+    }
+    value = sslsock->exc_value;
+    tb = sslsock->exc_tb;
+    sslsock->exc_type = NULL;
+    sslsock->exc_value = NULL;
+    sslsock->exc_tb = NULL;
+    _PyErr_ChainExceptions(type, value, tb);
+    return -1;
+}
+
 static PyObject *
 PySSL_SetError(PySSLSocket *sslsock, int ret, const char *filename, int lineno)
 {
@@ -796,6 +826,7 @@ PySSL_SetError(PySSLSocket *sslsock, int ret, const char *filename, int lineno)
     }
     fill_and_set_sslerror(sslsock, type, p, errstr, lineno, e);
     ERR_clear_error();
+    PySSL_ChainExceptions(sslsock);
     return NULL;
 }
 
@@ -903,6 +934,9 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock,
     self->owner = NULL;
     self->server_hostname = NULL;
     self->err = err;
+    self->exc_type = NULL;
+    self->exc_value = NULL;
+    self->exc_tb = NULL;
 
     /* Make sure the SSL error state is initialized */
     ERR_clear_error();
@@ -1052,11 +1086,12 @@ _ssl__SSLSocket_do_handshake_impl(PySSLSocket *self)
     Py_XDECREF(sock);
     if (ret < 1)
         return PySSL_SetError(self, ret, __FILE__, __LINE__);
-
+    if (PySSL_ChainExceptions(self) < 0)
+        return NULL;
     Py_RETURN_NONE;
-
 error:
     Py_XDECREF(sock);
+    PySSL_ChainExceptions(self);
     return NULL;
 }
 
@@ -2151,8 +2186,26 @@ PyDoc_STRVAR(PySSL_get_owner_doc,
 "The Python-level owner of this object.\
 Passed as \"self\" in servername callback.");
 
+/* tp_traverse for _ssl._SSLSocket: reports the pending exception triple.
+ *
+ * NOTE(review): PySSLSocket_Type's tp_flags in this patch is plain
+ * Py_TPFLAGS_DEFAULT (no Py_TPFLAGS_HAVE_GC), so presumably the cyclic GC
+ * never calls this slot -- confirm whether HAVE_GC was intended.
+ */
+static int
+PySSL_traverse(PySSLSocket *self, visitproc visit, void *arg)
+{
+    PyObject *pending[] = {self->exc_type, self->exc_value, self->exc_tb};
+    size_t i;
+
+    for (i = 0; i < sizeof(pending) / sizeof(pending[0]); i++) {
+        Py_VISIT(pending[i]);
+    }
+    return 0;
+}
+
+/* tp_clear for _ssl._SSLSocket: drops the exception triple stored by an
+ * OpenSSL debug callback, if any. Always returns 0. */
+static int
+PySSL_clear(PySSLSocket *self)
+{
+    PyObject **slots[] = {&self->exc_type, &self->exc_value, &self->exc_tb};
+    size_t i;
+
+    for (i = 0; i < sizeof(slots) / sizeof(slots[0]); i++) {
+        Py_CLEAR(*slots[i]);
+    }
+    return 0;
+}
 
-static void PySSL_dealloc(PySSLSocket *self)
+static void
+PySSL_dealloc(PySSLSocket *self)
 {
     if (self->ssl)
         SSL_free(self->ssl);
@@ -2333,13 +2386,14 @@ _ssl__SSLSocket_write_impl(PySSLSocket *self, Py_buffer *b)
              err.ssl == SSL_ERROR_WANT_WRITE);
 
     Py_XDECREF(sock);
-    if (len > 0)
-        return PyLong_FromLong(len);
-    else
+    if (len <= 0)
         return PySSL_SetError(self, len, __FILE__, __LINE__);
-
+    if (PySSL_ChainExceptions(self) < 0)
+        return NULL;
+    return PyLong_FromLong(len);
 error:
     Py_XDECREF(sock);
+    PySSL_ChainExceptions(self);
     return NULL;
 }
 
@@ -2486,6 +2540,8 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, int len, int group_right_1,
         PySSL_SetError(self, count, __FILE__, __LINE__);
         goto error;
     }
+    if (self->exc_type != NULL)
+        goto error;
 
 done:
     Py_XDECREF(sock);
@@ -2498,6 +2554,7 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, int len, int group_right_1,
     }
 
 error:
+    PySSL_ChainExceptions(self);
     Py_XDECREF(sock);
     if (!group_right_1)
         Py_XDECREF(dest);
@@ -2601,11 +2658,13 @@ _ssl__SSLSocket_shutdown_impl(PySSLSocket *self)
             /* Retain the SSL error code */
             break;
     }
-
     if (ret < 0) {
         Py_XDECREF(sock);
-        return PySSL_SetError(self, ret, __FILE__, __LINE__);
+        PySSL_SetError(self, ret, __FILE__, __LINE__);
+        return NULL;
     }
+    if (self->exc_type != NULL)
+        goto error;
     if (sock)
         /* It's already INCREF'ed */
         return (PyObject *) sock;
@@ -2614,6 +2673,7 @@ _ssl__SSLSocket_shutdown_impl(PySSLSocket *self)
 
 error:
     Py_XDECREF(sock);
+    PySSL_ChainExceptions(self);
     return NULL;
 }
 
@@ -2889,8 +2949,8 @@ static PyTypeObject PySSLSocket_Type = {
     0,                                  /*tp_as_buffer*/
     Py_TPFLAGS_DEFAULT,                 /*tp_flags*/
     0,                                  /*tp_doc*/
-    0,                                  /*tp_traverse*/
-    0,                                  /*tp_clear*/
+    (traverseproc) PySSL_traverse,      /*tp_traverse*/
+    (inquiry) PySSL_clear,              /*tp_clear*/
     0,                                  /*tp_richcompare*/
     0,                                  /*tp_weaklistoffset*/
     0,                                  /*tp_iter*/
@@ -3002,6 +3062,11 @@ _ssl__SSLContext_impl(PyTypeObject *type, int proto_version)
     self->ctx = ctx;
     self->hostflags = X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS;
     self->protocol = proto_version;
+    self->msg_cb = NULL;
+#ifdef HAVE_OPENSSL_KEYLOG
+    self->keylog_filename = NULL;
+    self->keylog_bio = NULL;
+#endif
 #if HAVE_NPN
     self->npn_protocols = NULL;
 #endif
@@ -3127,6 +3192,7 @@ context_traverse(PySSLContext *self, visitproc visit, void *arg)
 #ifndef OPENSSL_NO_TLSEXT
     Py_VISIT(self->set_sni_cb);
 #endif
+    Py_VISIT(self->msg_cb);
     return 0;
 }
 
@@ -3135,6 +3201,16 @@ context_clear(PySSLContext *self)
 {
 #ifndef OPENSSL_NO_TLSEXT
     Py_CLEAR(self->set_sni_cb);
+#endif
+    Py_CLEAR(self->msg_cb);
+#ifdef HAVE_OPENSSL_KEYLOG
+    Py_CLEAR(self->keylog_filename);
+    if (self->keylog_bio != NULL) {
+        PySSL_BEGIN_ALLOW_THREADS
+        BIO_free_all(self->keylog_bio);
+        PySSL_END_ALLOW_THREADS
+        self->keylog_bio = NULL;
+    }
 #endif
     return 0;
 }
@@ -4570,6 +4646,12 @@ static PyGetSetDef context_getsetlist[] = {
     {"maximum_version", (getter) get_maximum_version,
                         (setter) set_maximum_version, NULL},
 #endif
+#ifdef HAVE_OPENSSL_KEYLOG
+    {"keylog_filename", (getter) _PySSLContext_get_keylog_filename,
+                        (setter) _PySSLContext_set_keylog_filename, NULL},
+#endif
+    {"_msg_callback", (getter) _PySSLContext_get_msg_callback,
+                      (setter) _PySSLContext_set_msg_callback, NULL},
     {"sni_callback", (getter) get_sni_callback,
                      (setter) set_sni_callback, PySSLContext_sni_callback_doc},
     {"options", (getter) get_options,
diff --git a/Modules/_ssl/debughelpers.c b/Modules/_ssl/debughelpers.c
new file mode 100644
index 000000000000..53b966749328
--- /dev/null
+++ b/Modules/_ssl/debughelpers.c
@@ -0,0 +1,213 @@
+/* Debug helpers */
+
+/* OpenSSL message callback, installed with SSL_CTX_set_msg_callback().
+ *
+ * Forwards every TLS protocol message to the Python callable ctx->msg_cb as
+ *   msg_cb(owner, "read" | "write", version, content_type, msg_type, data)
+ * OpenSSL gives this callback no way to report errors, so an exception
+ * raised by msg_cb is stored on the PySSLSocket (exc_type/exc_value/exc_tb)
+ * for the handshake/read/write/shutdown methods to re-raise.
+ */
+static void
+_PySSL_msg_callback(int write_p, int version, int content_type,
+                    const void *buf, size_t len, SSL *ssl, void *arg)
+{
+    /* unsigned: bytes >= 0x80 must not be sign-extended into negative
+       message types or versions */
+    const unsigned char *cbuf = (const unsigned char *)buf;
+    PyGILState_STATE threadstate;
+    PyObject *res = NULL;
+    PySSLSocket *ssl_obj = NULL;  /* ssl._SSLSocket, borrowed ref */
+    PyObject *ssl_socket = NULL;  /* ssl.SSLSocket or ssl.SSLObject */
+    int msg_type;
+
+    threadstate = PyGILState_Ensure();
+
+    ssl_obj = (PySSLSocket *)SSL_get_app_data(ssl);
+    assert(PySSLSocket_Check(ssl_obj));
+    if (ssl_obj->ctx->msg_cb == NULL) {
+        /* every exit must balance PyGILState_Ensure() */
+        PyGILState_Release(threadstate);
+        return;
+    }
+
+    if (ssl_obj->owner)
+        ssl_socket = PyWeakref_GetObject(ssl_obj->owner);
+    else if (ssl_obj->Socket)
+        ssl_socket = PyWeakref_GetObject(ssl_obj->Socket);
+    else
+        ssl_socket = (PyObject *)ssl_obj;
+    Py_INCREF(ssl_socket);
+
+    /* OpenSSL is assumed to have validated the payload; the length checks
+       only keep the fixed-offset reads inside the buffer. Messages that are
+       too short are reported with msg_type -1. */
+    switch(content_type) {
+      case SSL3_RT_CHANGE_CIPHER_SPEC:
+        msg_type = SSL3_MT_CHANGE_CIPHER_SPEC;
+        break;
+      case SSL3_RT_ALERT:
+        /* byte 0: level */
+        /* byte 1: alert type */
+        msg_type = (len >= 2) ? (int)cbuf[1] : -1;
+        break;
+      case SSL3_RT_HANDSHAKE:
+        msg_type = (len >= 1) ? (int)cbuf[0] : -1;
+        break;
+#ifdef SSL3_RT_HEADER
+      case SSL3_RT_HEADER:
+        /* frame header: byte 0 content type, bytes 1..2 version */
+        if (len >= 3) {
+            version = cbuf[1] << 8 | cbuf[2];
+            msg_type = (int)cbuf[0];
+        } else {
+            msg_type = -1;
+        }
+        break;
+#endif
+#ifdef SSL3_RT_INNER_CONTENT_TYPE
+      case SSL3_RT_INNER_CONTENT_TYPE:
+        msg_type = (len >= 1) ? (int)cbuf[0] : -1;
+        break;
+#endif
+      default:
+        /* never SSL3_RT_APPLICATION_DATA */
+        msg_type = -1;
+        break;
+    }
+
+    res = PyObject_CallFunction(
+        ssl_obj->ctx->msg_cb, "Osiiiy#",
+        ssl_socket, write_p ? "write" : "read",
+        version, content_type, msg_type,
+        buf, len
+    );
+    if (res == NULL) {
+        if (ssl_obj->exc_type != NULL) {
+            /* An exception from an earlier callback is still pending: make
+               it the __context__ of the new one instead of overwriting (and
+               leaking) the stored references. */
+            PyObject *t = ssl_obj->exc_type;
+            PyObject *v = ssl_obj->exc_value;
+            PyObject *tb = ssl_obj->exc_tb;
+            ssl_obj->exc_type = NULL;
+            ssl_obj->exc_value = NULL;
+            ssl_obj->exc_tb = NULL;
+            _PyErr_ChainExceptions(t, v, tb);
+        }
+        PyErr_Fetch(&ssl_obj->exc_type, &ssl_obj->exc_value, &ssl_obj->exc_tb);
+    } else {
+        Py_DECREF(res);
+    }
+    Py_XDECREF(ssl_socket);
+
+    PyGILState_Release(threadstate);
+}
+
+
+/* Getter for _SSLContext._msg_callback: a new reference to the stored
+   callable, or None when no callback is installed. */
+static PyObject *
+_PySSLContext_get_msg_callback(PySSLContext *self, void *c) {
+    PyObject *result = (self->msg_cb == NULL) ? Py_None : self->msg_cb;
+
+    Py_INCREF(result);
+    return result;
+}
+
+/* Setter for _SSLContext._msg_callback.
+ *
+ * None removes the callback; any other callable is stored (new reference)
+ * and _PySSL_msg_callback is installed on the SSL_CTX. The value is checked
+ * before any state changes, so a rejected value (including attribute
+ * deletion, where arg is NULL) leaves the current callback installed.
+ * Returns 0 on success, -1 with TypeError set otherwise.
+ */
+static int
+_PySSLContext_set_msg_callback(PySSLContext *self, PyObject *arg, void *c) {
+    PyObject *old = self->msg_cb;
+
+    /* PyCallable_Check(NULL) is 0, so deletion is rejected here as well */
+    if (arg != Py_None && !PyCallable_Check(arg)) {
+        PyErr_SetString(PyExc_TypeError,
+                        "not a callable object");
+        return -1;
+    }
+
+    if (arg == Py_None) {
+        SSL_CTX_set_msg_callback(self->ctx, NULL);
+        self->msg_cb = NULL;
+    }
+    else {
+        Py_INCREF(arg);
+        self->msg_cb = arg;
+        SSL_CTX_set_msg_callback(self->ctx, _PySSL_msg_callback);
+    }
+    /* Release the previous callable last: its destructor may run Python
+       code, and the context is consistent again at this point. */
+    Py_XDECREF(old);
+    return 0;
+}
+
+#ifdef HAVE_OPENSSL_KEYLOG
+
+/* OpenSSL keylog callback, installed with SSL_CTX_set_keylog_callback().
+ *
+ * Appends one line of key material to the context's keylog BIO. OpenSSL
+ * gives the callback no way to report errors, so a failure is stored as an
+ * exception on the PySSLSocket for the next handshake/read/write/shutdown
+ * call to re-raise.
+ */
+static void
+_PySSL_keylog_callback(const SSL *ssl, const char *line)
+{
+    PyGILState_STATE threadstate;
+    PySSLSocket *ssl_obj = NULL;  /* ssl._SSLSocket, borrowed ref */
+    BIO *bio;
+    int res, e;
+    /* PyThread_allocate_lock() returns a PyThread_type_lock, not a pointer
+       to one. */
+    static PyThread_type_lock lock = NULL;
+
+    threadstate = PyGILState_Ensure();
+
+    /* Look up the socket first: the error path stores the exception on it. */
+    ssl_obj = (PySSLSocket *)SSL_get_app_data(ssl);
+    assert(PySSLSocket_Check(ssl_obj));
+    /* Read the BIO pointer while holding the GIL; it is used below with the
+       GIL released, when ctx->keylog_bio may be reset concurrently. */
+    bio = ssl_obj->ctx->keylog_bio;
+    if (bio == NULL) {
+        goto done;
+    }
+
+    /* Allocate a static lock to synchronize writes to keylog file.
+     * The lock is neither released on exit nor on fork(). The lock is
+     * also shared between all SSLContexts although contexts may write to
+     * their own files. IMHO that's good enough for a non-performance
+     * critical debug helper.
+     */
+    if (lock == NULL) {
+        lock = PyThread_allocate_lock();
+        if (lock == NULL) {
+            PyErr_SetString(PyExc_MemoryError, "Unable to allocate lock");
+            goto store_error;
+        }
+    }
+
+    /* NOTE(review): the keylog_filename setter frees the BIO without taking
+       this lock, so reassigning keylog_filename while another thread is in
+       a handshake can still free bio during this write. */
+    PySSL_BEGIN_ALLOW_THREADS
+    PyThread_acquire_lock(lock, 1);
+    res = BIO_printf(bio, "%s\n", line);
+    e = errno;
+    (void)BIO_flush(bio);
+    PyThread_release_lock(lock);
+    PySSL_END_ALLOW_THREADS
+
+    if (res == -1) {
+        errno = e;
+        PyErr_SetFromErrnoWithFilenameObject(PyExc_OSError,
+                                             ssl_obj->ctx->keylog_filename);
+        goto store_error;
+    }
+    goto done;
+
+store_error:
+    /* Keep an exception stored by an earlier callback as the __context__ of
+       the new one instead of overwriting (and leaking) its references. */
+    if (ssl_obj->exc_type != NULL) {
+        PyObject *t = ssl_obj->exc_type;
+        PyObject *v = ssl_obj->exc_value;
+        PyObject *tb = ssl_obj->exc_tb;
+        ssl_obj->exc_type = NULL;
+        ssl_obj->exc_value = NULL;
+        ssl_obj->exc_tb = NULL;
+        _PyErr_ChainExceptions(t, v, tb);
+    }
+    PyErr_Fetch(&ssl_obj->exc_type, &ssl_obj->exc_value, &ssl_obj->exc_tb);
+done:
+    /* every exit must balance PyGILState_Ensure() */
+    PyGILState_Release(threadstate);
+}
+
+/* Getter for SSLContext.keylog_filename: a new reference to the object the
+   setter stored, or None when key logging is disabled. */
+static PyObject *
+_PySSLContext_get_keylog_filename(PySSLContext *self, void *c) {
+    PyObject *result = self->keylog_filename;
+
+    if (result == NULL) {
+        result = Py_None;
+    }
+    Py_INCREF(result);
+    return result;
+}
+
+/* Setter for SSLContext.keylog_filename.
+ *
+ * Any existing keylog BIO is closed and the keylog callback removed first.
+ * None leaves key logging disabled. Any other value is opened in append
+ * mode with _Py_fopen_obj() (which also rejects unsupported types) and
+ * wrapped in a BIO that owns the FILE. A header line is written when
+ * BIO_tell() reports position 0, and _PySSL_keylog_callback is installed.
+ * Returns 0 on success, -1 with an exception set otherwise.
+ */
+static int
+_PySSLContext_set_keylog_filename(PySSLContext *self, PyObject *arg, void *c) {
+    FILE *fp;
+
+    if (arg == NULL) {
+        /* "del ctx.keylog_filename": NULL must never reach _Py_fopen_obj() */
+        PyErr_SetString(PyExc_AttributeError, "can't delete attribute");
+        return -1;
+    }
+
+    /* Reset variables and callback first */
+    SSL_CTX_set_keylog_callback(self->ctx, NULL);
+    Py_CLEAR(self->keylog_filename);
+    if (self->keylog_bio != NULL) {
+        BIO *bio = self->keylog_bio;
+        self->keylog_bio = NULL;
+        /* NOTE(review): _PySSL_keylog_callback writes to the BIO with the
+           GIL released, under a function-static lock that is not taken
+           here; a handshake in another thread may still be using it. */
+        PySSL_BEGIN_ALLOW_THREADS
+        BIO_free_all(bio);
+        PySSL_END_ALLOW_THREADS
+    }
+
+    if (arg == Py_None) {
+        /* None disables the callback */
+        return 0;
+    }
+
+    /* _Py_fopen_obj() also checks that arg is of proper type. */
+    fp = _Py_fopen_obj(arg, "a" PY_STDIOTEXTMODE);
+    if (fp == NULL)
+        return -1;
+
+    self->keylog_bio = BIO_new_fp(fp, BIO_CLOSE | BIO_FP_TEXT);
+    if (self->keylog_bio == NULL) {
+        /* No BIO took ownership of fp (BIO_CLOSE), so close it here. */
+        fclose(fp);
+        PyErr_SetString(PySSLErrorObject,
+                        "Can't malloc memory for keylog file");
+        return -1;
+    }
+    Py_INCREF(arg);
+    self->keylog_filename = arg;
+
+    /* Write a header for seekable, empty files (this excludes pipes). */
+    PySSL_BEGIN_ALLOW_THREADS
+    if (BIO_tell(self->keylog_bio) == 0) {
+        BIO_puts(self->keylog_bio,
+                 "# TLS secrets log file, generated by OpenSSL / Python\n");
+        (void)BIO_flush(self->keylog_bio);
+    }
+    PySSL_END_ALLOW_THREADS
+    SSL_CTX_set_keylog_callback(self->ctx, _PySSL_keylog_callback);
+    return 0;
+}
+
+#endif
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 96a49b4e353c..7852c2dfa27e 100644
--- a/setup.py
+++ b/setup.py
@@ -2178,11 +2178,13 @@ def split_var(name, sep):
             ssl_incs.extend(krb5_h)
 
         if config_vars.get("HAVE_X509_VERIFY_PARAM_SET1_HOST"):
-            self.add(Extension('_ssl', ['_ssl.c'],
-                               include_dirs=openssl_includes,
-                               library_dirs=openssl_libdirs,
-                               libraries=openssl_libs,
-                               depends=['socketmodule.h']))
+            self.add(Extension(
+                '_ssl', ['_ssl.c'],
+                include_dirs=openssl_includes,
+                library_dirs=openssl_libdirs,
+                libraries=openssl_libs,
+                depends=['socketmodule.h', '_ssl/debughelpers.c'])
+            )
         else:
             self.missing.append('_ssl')
 



More information about the Python-checkins mailing list