[Python-checkins] bpo-40791: Use CRYPTO_memcmp() for compare_digest (#20456)

Christian Heimes webhook-mailer at python.org
Wed May 27 15:50:12 EDT 2020


https://github.com/python/cpython/commit/db5aed931f8a617f7b63e773f62db468fe9c5ca1
commit: db5aed931f8a617f7b63e773f62db468fe9c5ca1
branch: master
author: Christian Heimes <christian at python.org>
committer: GitHub <noreply at github.com>
date: 2020-05-27T21:50:06+02:00
summary:

bpo-40791: Use CRYPTO_memcmp() for compare_digest (#20456)

hmac.compare_digest uses OpenSSL's CRYPTO_memcmp() function
when OpenSSL is available.

Note: The _operator module is a builtin module. I don't want to add a
libcrypto dependency to libpython. Therefore I duplicated the wrapper
function and added a copy to _hashopenssl.c.

files:
A Misc/NEWS.d/next/Library/2020-05-27-18-04-52.bpo-40791.IzpNor.rst
M Doc/library/hmac.rst
M Lib/hmac.py
M Lib/test/test_hmac.py
M Modules/_hashopenssl.c
M Modules/_operator.c
M Modules/clinic/_hashopenssl.c.h

diff --git a/Doc/library/hmac.rst b/Doc/library/hmac.rst
index 5ad348490eaf6..6f1b59b57ce58 100644
--- a/Doc/library/hmac.rst
+++ b/Doc/library/hmac.rst
@@ -138,6 +138,11 @@ This module also provides the following helper function:
 
    .. versionadded:: 3.3
 
+   .. versionchanged:: 3.10
+
+      The function uses OpenSSL's ``CRYPTO_memcmp()`` internally when
+      available.
+
 
 .. seealso::
 
diff --git a/Lib/hmac.py b/Lib/hmac.py
index 54a1ef9bdbdcf..180bc378b52d6 100644
--- a/Lib/hmac.py
+++ b/Lib/hmac.py
@@ -4,14 +4,15 @@
 """
 
 import warnings as _warnings
-from _operator import _compare_digest as compare_digest
 try:
     import _hashlib as _hashopenssl
 except ImportError:
     _hashopenssl = None
     _openssl_md_meths = None
+    # Without _hashlib (no OpenSSL), fall back to the builtin _operator
+    # implementation so hmac.compare_digest is always defined.
+    from _operator import _compare_digest as compare_digest
 else:
     _openssl_md_meths = frozenset(_hashopenssl.openssl_md_meth_names)
+    # _hashlib is available: use its CRYPTO_memcmp()-based implementation.
+    compare_digest = _hashopenssl.compare_digest
 import hashlib as _hashlib
 
 trans_5C = bytes((x ^ 0x5C) for x in range(256))
diff --git a/Lib/test/test_hmac.py b/Lib/test/test_hmac.py
index 7a52e39c5d471..6daf22ca06fb8 100644
--- a/Lib/test/test_hmac.py
+++ b/Lib/test/test_hmac.py
@@ -8,12 +8,16 @@
 
 from test.support import hashlib_helper
 
+from _operator import _compare_digest as operator_compare_digest
+
 try:
     from _hashlib import HMAC as C_HMAC
     from _hashlib import hmac_new as c_hmac_new
+    from _hashlib import compare_digest as openssl_compare_digest
 except ImportError:
     C_HMAC = None
     c_hmac_new = None
+    openssl_compare_digest = None
 
 
 def ignore_warning(func):
@@ -505,87 +509,101 @@ def test_equality_new(self):
 
 class CompareDigestTestCase(unittest.TestCase):
 
-    def test_compare_digest(self):
+    def test_hmac_compare_digest(self):
+        # hmac.compare_digest must pass the shared checks, and it must be
+        # the _hashlib (OpenSSL) function when _hashlib imported successfully,
+        # otherwise the _operator fallback.
+        self._test_compare_digest(hmac.compare_digest)
+        if openssl_compare_digest is not None:
+            self.assertIs(hmac.compare_digest, openssl_compare_digest)
+        else:
+            self.assertIs(hmac.compare_digest, operator_compare_digest)
+
+    def test_operator_compare_digest(self):
+        # The builtin _operator implementation is always importable.
+        self._test_compare_digest(operator_compare_digest)
+
+    @unittest.skipIf(openssl_compare_digest is None, "test requires _hashlib")
+    def test_openssl_compare_digest(self):
+        self._test_compare_digest(openssl_compare_digest)
+
+    def _test_compare_digest(self, compare_digest):
+        # Shared behavioural checks, run against each compare_digest
+        # implementation passed in by the test methods above.
         # Testing input type exception handling
         a, b = 100, 200
-        self.assertRaises(TypeError, hmac.compare_digest, a, b)
+        self.assertRaises(TypeError, compare_digest, a, b)
         a, b = 100, b"foobar"
-        self.assertRaises(TypeError, hmac.compare_digest, a, b)
+        self.assertRaises(TypeError, compare_digest, a, b)
         a, b = b"foobar", 200
-        self.assertRaises(TypeError, hmac.compare_digest, a, b)
+        self.assertRaises(TypeError, compare_digest, a, b)
         a, b = "foobar", b"foobar"
-        self.assertRaises(TypeError, hmac.compare_digest, a, b)
+        self.assertRaises(TypeError, compare_digest, a, b)
         a, b = b"foobar", "foobar"
-        self.assertRaises(TypeError, hmac.compare_digest, a, b)
+        self.assertRaises(TypeError, compare_digest, a, b)
 
         # Testing bytes of different lengths
         a, b = b"foobar", b"foo"
-        self.assertFalse(hmac.compare_digest(a, b))
+        self.assertFalse(compare_digest(a, b))
         a, b = b"\xde\xad\xbe\xef", b"\xde\xad"
-        self.assertFalse(hmac.compare_digest(a, b))
+        self.assertFalse(compare_digest(a, b))
 
         # Testing bytes of same lengths, different values
         a, b = b"foobar", b"foobaz"
-        self.assertFalse(hmac.compare_digest(a, b))
+        self.assertFalse(compare_digest(a, b))
         a, b = b"\xde\xad\xbe\xef", b"\xab\xad\x1d\xea"
-        self.assertFalse(hmac.compare_digest(a, b))
+        self.assertFalse(compare_digest(a, b))
 
         # Testing bytes of same lengths, same values
         a, b = b"foobar", b"foobar"
-        self.assertTrue(hmac.compare_digest(a, b))
+        self.assertTrue(compare_digest(a, b))
         a, b = b"\xde\xad\xbe\xef", b"\xde\xad\xbe\xef"
-        self.assertTrue(hmac.compare_digest(a, b))
+        self.assertTrue(compare_digest(a, b))
 
         # Testing bytearrays of same lengths, same values
         a, b = bytearray(b"foobar"), bytearray(b"foobar")
-        self.assertTrue(hmac.compare_digest(a, b))
+        self.assertTrue(compare_digest(a, b))
 
         # Testing bytearrays of different lengths
         a, b = bytearray(b"foobar"), bytearray(b"foo")
-        self.assertFalse(hmac.compare_digest(a, b))
+        self.assertFalse(compare_digest(a, b))
 
         # Testing bytearrays of same lengths, different values
         a, b = bytearray(b"foobar"), bytearray(b"foobaz")
-        self.assertFalse(hmac.compare_digest(a, b))
+        self.assertFalse(compare_digest(a, b))
 
         # Testing byte and bytearray of same lengths, same values
         a, b = bytearray(b"foobar"), b"foobar"
-        self.assertTrue(hmac.compare_digest(a, b))
-        self.assertTrue(hmac.compare_digest(b, a))
+        self.assertTrue(compare_digest(a, b))
+        self.assertTrue(compare_digest(b, a))
 
         # Testing byte bytearray of different lengths
         a, b = bytearray(b"foobar"), b"foo"
-        self.assertFalse(hmac.compare_digest(a, b))
-        self.assertFalse(hmac.compare_digest(b, a))
+        self.assertFalse(compare_digest(a, b))
+        self.assertFalse(compare_digest(b, a))
 
         # Testing byte and bytearray of same lengths, different values
         a, b = bytearray(b"foobar"), b"foobaz"
-        self.assertFalse(hmac.compare_digest(a, b))
-        self.assertFalse(hmac.compare_digest(b, a))
+        self.assertFalse(compare_digest(a, b))
+        self.assertFalse(compare_digest(b, a))
 
         # Testing str of same lengths
         a, b = "foobar", "foobar"
-        self.assertTrue(hmac.compare_digest(a, b))
+        self.assertTrue(compare_digest(a, b))
 
         # Testing str of different lengths
         a, b = "foo", "foobar"
-        self.assertFalse(hmac.compare_digest(a, b))
+        self.assertFalse(compare_digest(a, b))
 
         # Testing bytes of same lengths, different values
         a, b = "foobar", "foobaz"
-        self.assertFalse(hmac.compare_digest(a, b))
+        self.assertFalse(compare_digest(a, b))
 
         # Testing error cases
         a, b = "foobar", b"foobar"
-        self.assertRaises(TypeError, hmac.compare_digest, a, b)
+        self.assertRaises(TypeError, compare_digest, a, b)
         a, b = b"foobar", "foobar"
-        self.assertRaises(TypeError, hmac.compare_digest, a, b)
+        self.assertRaises(TypeError, compare_digest, a, b)
         a, b = b"foobar", 1
-        self.assertRaises(TypeError, hmac.compare_digest, a, b)
+        self.assertRaises(TypeError, compare_digest, a, b)
         a, b = 100, 200
-        self.assertRaises(TypeError, hmac.compare_digest, a, b)
+        self.assertRaises(TypeError, compare_digest, a, b)
         a, b = "fooä", "fooä"
-        self.assertRaises(TypeError, hmac.compare_digest, a, b)
+        self.assertRaises(TypeError, compare_digest, a, b)
 
         # subclasses are supported by ignore __eq__
         class mystr(str):
@@ -593,22 +611,22 @@ def __eq__(self, other):
                 return False
 
         a, b = mystr("foobar"), mystr("foobar")
-        self.assertTrue(hmac.compare_digest(a, b))
+        self.assertTrue(compare_digest(a, b))
         a, b = mystr("foobar"), "foobar"
-        self.assertTrue(hmac.compare_digest(a, b))
+        self.assertTrue(compare_digest(a, b))
         a, b = mystr("foobar"), mystr("foobaz")
-        self.assertFalse(hmac.compare_digest(a, b))
+        self.assertFalse(compare_digest(a, b))
 
         class mybytes(bytes):
             def __eq__(self, other):
                 return False
 
         a, b = mybytes(b"foobar"), mybytes(b"foobar")
-        self.assertTrue(hmac.compare_digest(a, b))
+        self.assertTrue(compare_digest(a, b))
         a, b = mybytes(b"foobar"), b"foobar"
-        self.assertTrue(hmac.compare_digest(a, b))
+        self.assertTrue(compare_digest(a, b))
         a, b = mybytes(b"foobar"), mybytes(b"foobaz")
-        self.assertFalse(hmac.compare_digest(a, b))
+        self.assertFalse(compare_digest(a, b))
 
 
 if __name__ == "__main__":
diff --git a/Misc/NEWS.d/next/Library/2020-05-27-18-04-52.bpo-40791.IzpNor.rst b/Misc/NEWS.d/next/Library/2020-05-27-18-04-52.bpo-40791.IzpNor.rst
new file mode 100644
index 0000000000000..b88f308ec3b52
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2020-05-27-18-04-52.bpo-40791.IzpNor.rst
@@ -0,0 +1,2 @@
+:func:`hmac.compare_digest` uses OpenSSL's ``CRYPTO_memcmp()`` function
+when OpenSSL is available.
diff --git a/Modules/_hashopenssl.c b/Modules/_hashopenssl.c
index 0b2ef95a6f126..adc8653773250 100644
--- a/Modules/_hashopenssl.c
+++ b/Modules/_hashopenssl.c
@@ -21,6 +21,7 @@
 /* EVP is the preferred interface to hashing in OpenSSL */
 #include <openssl/evp.h>
 #include <openssl/hmac.h>
+#include <openssl/crypto.h>
 /* We use the object interface to discover what hashes OpenSSL supports. */
 #include <openssl/objects.h>
 #include "openssl/err.h"
@@ -1833,6 +1834,120 @@ _hashlib_get_fips_mode_impl(PyObject *module)
 #endif  // !LIBRESSL_VERSION_NUMBER
 
 
+/* Compare two byte ranges for equality, delegating the byte comparison
+ * to OpenSSL's CRYPTO_memcmp() so it is resistant to timing analysis.
+ *
+ * a, len_a: first operand and its length in bytes.
+ * b, len_b: second operand and its length in bytes.
+ *
+ * Returns 1 if the lengths are equal and all bytes match, 0 otherwise.
+ */
+static int
+_tscmp(const unsigned char *a, const unsigned char *b,
+        Py_ssize_t len_a, Py_ssize_t len_b)
+{
+    /* loop count depends on length of b. Might leak very little timing
+     * information if sizes are different.
+     */
+    Py_ssize_t length = len_b;
+    const void *left = a;
+    const void *right = b;
+    int result = 0;
+
+    if (len_a != length) {
+        /* Lengths differ, so the answer is already "not equal". Still run
+         * CRYPTO_memcmp() over len_b bytes by comparing b with itself,
+         * which also avoids reading past the end of the shorter a. */
+        left = b;
+        result = 1;
+    }
+
+    /* OR the comparison result in, so a recorded length mismatch can never
+     * be cleared by an all-equal byte comparison. */
+    result |= CRYPTO_memcmp(left, right, length);
+
+    return (result == 0);
+}
+
+/* NOTE: Keep in sync with _operator.c implementation. */
+
+/*[clinic input]
+_hashlib.compare_digest
+
+    a: object
+    b: object
+    /
+
+Return 'a == b'.
+
+This function uses an approach designed to prevent
+timing analysis, making it appropriate for cryptography.
+
+a and b must both be of the same type: either str (ASCII only),
+or any bytes-like object.
+
+Note: If a and b are of different lengths, or if an error occurs,
+a timing attack could theoretically reveal information about the
+types and lengths of a and b--but not their values.
+[clinic start generated code]*/
+
+static PyObject *
+_hashlib_compare_digest_impl(PyObject *module, PyObject *a, PyObject *b)
+/*[clinic end generated code: output=6f1c13927480aed9 input=9c40c6e566ca12f5]*/
+{
+    int rc;  /* _tscmp() result: 1 when equal, 0 otherwise */
+
+    /* ASCII unicode string */
+    if(PyUnicode_Check(a) && PyUnicode_Check(b)) {
+        /* Strings must be ready before PyUnicode_IS_ASCII()/PyUnicode_DATA()
+         * may be used on them. */
+        if (PyUnicode_READY(a) == -1 || PyUnicode_READY(b) == -1) {
+            return NULL;
+        }
+        /* Restricting to ASCII means one byte per character, so the
+         * character counts below are also the byte lengths of the data. */
+        if (!PyUnicode_IS_ASCII(a) || !PyUnicode_IS_ASCII(b)) {
+            PyErr_SetString(PyExc_TypeError,
+                            "comparing strings with non-ASCII characters is "
+                            "not supported");
+            return NULL;
+        }
+
+        rc = _tscmp(PyUnicode_DATA(a),
+                    PyUnicode_DATA(b),
+                    PyUnicode_GET_LENGTH(a),
+                    PyUnicode_GET_LENGTH(b));
+    }
+    /* fallback to buffer interface for bytes, bytearray and other */
+    else {
+        Py_buffer view_a;
+        Py_buffer view_b;
+
+        /* Only reject here when neither operand exports a buffer. If just
+         * one does (e.g. str vs bytes), the PyObject_GetBuffer() call on
+         * the other one below fails and its error is propagated. */
+        /* NOTE(review): "types(s)" below looks like a typo for "type(s)";
+         * presumably the _operator.c copy has the same text (see the
+         * keep-in-sync note above), so fix both together. */
+        if (PyObject_CheckBuffer(a) == 0 && PyObject_CheckBuffer(b) == 0) {
+            PyErr_Format(PyExc_TypeError,
+                         "unsupported operand types(s) or combination of types: "
+                         "'%.100s' and '%.100s'",
+                         Py_TYPE(a)->tp_name, Py_TYPE(b)->tp_name);
+            return NULL;
+        }
+
+        if (PyObject_GetBuffer(a, &view_a, PyBUF_SIMPLE) == -1) {
+            return NULL;
+        }
+        if (view_a.ndim > 1) {
+            PyErr_SetString(PyExc_BufferError,
+                            "Buffer must be single dimension");
+            PyBuffer_Release(&view_a);
+            return NULL;
+        }
+
+        /* view_a is held from here on and must be released on every
+         * exit path. */
+        if (PyObject_GetBuffer(b, &view_b, PyBUF_SIMPLE) == -1) {
+            PyBuffer_Release(&view_a);
+            return NULL;
+        }
+        if (view_b.ndim > 1) {
+            PyErr_SetString(PyExc_BufferError,
+                            "Buffer must be single dimension");
+            PyBuffer_Release(&view_a);
+            PyBuffer_Release(&view_b);
+            return NULL;
+        }
+
+        rc = _tscmp((const unsigned char*)view_a.buf,
+                    (const unsigned char*)view_b.buf,
+                    view_a.len,
+                    view_b.len);
+
+        PyBuffer_Release(&view_a);
+        PyBuffer_Release(&view_b);
+    }
+
+    return PyBool_FromLong(rc);
+}
+
 /* List of functions exported by this module */
 
 static struct PyMethodDef EVP_functions[] = {
@@ -1840,6 +1955,7 @@ static struct PyMethodDef EVP_functions[] = {
     PBKDF2_HMAC_METHODDEF
     _HASHLIB_SCRYPT_METHODDEF
     _HASHLIB_GET_FIPS_MODE_METHODDEF
+    _HASHLIB_COMPARE_DIGEST_METHODDEF
     _HASHLIB_HMAC_SINGLESHOT_METHODDEF
     _HASHLIB_HMAC_NEW_METHODDEF
     _HASHLIB_OPENSSL_MD5_METHODDEF
diff --git a/Modules/_operator.c b/Modules/_operator.c
index 19026b6c38e60..8a54829e5bbcc 100644
--- a/Modules/_operator.c
+++ b/Modules/_operator.c
@@ -785,6 +785,8 @@ _operator_length_hint_impl(PyObject *module, PyObject *obj,
     return PyObject_LengthHint(obj, default_value);
 }
 
+/* NOTE: Keep in sync with _hashopenssl.c implementation. */
+
 /*[clinic input]
 _operator._compare_digest = _operator.eq
 
diff --git a/Modules/clinic/_hashopenssl.c.h b/Modules/clinic/_hashopenssl.c.h
index 619cb1c8516b8..51ae2402896c1 100644
--- a/Modules/clinic/_hashopenssl.c.h
+++ b/Modules/clinic/_hashopenssl.c.h
@@ -1338,6 +1338,46 @@ _hashlib_get_fips_mode(PyObject *module, PyObject *Py_UNUSED(ignored))
 
 #endif /* !defined(LIBRESSL_VERSION_NUMBER) */
 
+PyDoc_STRVAR(_hashlib_compare_digest__doc__,
+"compare_digest($module, a, b, /)\n"
+"--\n"
+"\n"
+"Return \'a == b\'.\n"
+"\n"
+"This function uses an approach designed to prevent\n"
+"timing analysis, making it appropriate for cryptography.\n"
+"\n"
+"a and b must both be of the same type: either str (ASCII only),\n"
+"or any bytes-like object.\n"
+"\n"
+"Note: If a and b are of different lengths, or if an error occurs,\n"
+"a timing attack could theoretically reveal information about the\n"
+"types and lengths of a and b--but not their values.");
+
+#define _HASHLIB_COMPARE_DIGEST_METHODDEF    \
+    {"compare_digest", (PyCFunction)(void(*)(void))_hashlib_compare_digest, METH_FASTCALL, _hashlib_compare_digest__doc__},
+
+static PyObject *
+_hashlib_compare_digest_impl(PyObject *module, PyObject *a, PyObject *b);
+
+static PyObject *
+_hashlib_compare_digest(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
+{
+    PyObject *return_value = NULL;
+    PyObject *a;
+    PyObject *b;
+
+    if (!_PyArg_CheckPositional("compare_digest", nargs, 2, 2)) {
+        goto exit;
+    }
+    a = args[0];
+    b = args[1];
+    return_value = _hashlib_compare_digest_impl(module, a, b);
+
+exit:
+    return return_value;
+}
+
 #ifndef EVPXOF_DIGEST_METHODDEF
     #define EVPXOF_DIGEST_METHODDEF
 #endif /* !defined(EVPXOF_DIGEST_METHODDEF) */
@@ -1377,4 +1417,4 @@ _hashlib_get_fips_mode(PyObject *module, PyObject *Py_UNUSED(ignored))
 #ifndef _HASHLIB_GET_FIPS_MODE_METHODDEF
     #define _HASHLIB_GET_FIPS_MODE_METHODDEF
 #endif /* !defined(_HASHLIB_GET_FIPS_MODE_METHODDEF) */
-/*[clinic end generated code: output=d8dddcd85fb11dde input=a9049054013a1b77]*/
+/*[clinic end generated code: output=95447a60132f039e input=a9049054013a1b77]*/



More information about the Python-checkins mailing list