[Python-checkins] Rework integer overflow path in math.prod and add more tests (GH-11809)

Pablo Galindo webhook-mailer at python.org
Sat Mar 9 14:18:13 EST 2019


https://github.com/python/cpython/commit/0411411c6b16a574144dfb59a7780b057ca8e750
commit: 0411411c6b16a574144dfb59a7780b057ca8e750
branch: master
author: Pablo Galindo <Pablogsal at gmail.com>
committer: GitHub <noreply at github.com>
date: 2019-03-09T19:18:08Z
summary:

Rework integer overflow path in math.prod and add more tests (GH-11809)

The overflow check relied on undefined behaviour: it performed the multiplication first and then inspected the result, but once signed integer overflow has occurred, any operation on that result is itself undefined behaviour.

Additional tests that exercise the code paths related to this overflow check are also added.

files:
M Lib/test/test_math.py
M Modules/mathmodule.c

diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py
index 856b1e8ac11e..cb05dee0e0fd 100644
--- a/Lib/test/test_math.py
+++ b/Lib/test/test_math.py
@@ -1595,6 +1595,92 @@ def test_mtestfile(self):
             self.fail('Failures in test_mtestfile:\n  ' +
                       '\n  '.join(failures))
 
+    def test_prod(self):
+        prod = math.prod
+        self.assertEqual(prod([]), 1)
+        self.assertEqual(prod([], start=5), 5)
+        self.assertEqual(prod(list(range(2,8))), 5040)
+        self.assertEqual(prod(iter(list(range(2,8)))), 5040)
+        self.assertEqual(prod(range(1, 10), start=10), 3628800)
+
+        self.assertEqual(prod([1, 2, 3, 4, 5]), 120)
+        self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0)
+        self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0)
+        self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0)
+
+        # Test overflow in fast-path for integers
+        self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32)
+        # Test overflow in fast-path for floats
+        self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32))
+
+        self.assertRaises(TypeError, prod)
+        self.assertRaises(TypeError, prod, 42)
+        self.assertRaises(TypeError, prod, ['a', 'b', 'c'])
+        self.assertRaises(TypeError, prod, ['a', 'b', 'c'], '')
+        self.assertRaises(TypeError, prod, [b'a', b'c'], b'')
+        values = [bytearray(b'a'), bytearray(b'b')]
+        self.assertRaises(TypeError, prod, values, bytearray(b''))
+        self.assertRaises(TypeError, prod, [[1], [2], [3]])
+        self.assertRaises(TypeError, prod, [{2:3}])
+        self.assertRaises(TypeError, prod, [{2:3}]*2, {2:3})
+        self.assertRaises(TypeError, prod, [[1], [2], [3]], [])
+        with self.assertRaises(TypeError):
+            prod([10, 20], [30, 40])     # start is a keyword-only argument
+
+        self.assertEqual(prod([0, 1, 2, 3]), 0)
+        self.assertEqual(prod([1, 0, 2, 3]), 0)
+        self.assertEqual(prod([1, 2, 3, 0]), 0)
+
+        def _naive_prod(iterable, start=1):
+            for elem in iterable:
+                start *= elem
+            return start
+
+        # Big integers
+
+        iterable = range(1, 10000)
+        self.assertEqual(prod(iterable), _naive_prod(iterable))
+        iterable = range(-10000, -1)
+        self.assertEqual(prod(iterable), _naive_prod(iterable))
+        iterable = range(-1000, 1000)
+        self.assertEqual(prod(iterable), 0)
+
+        # Big floats
+
+        iterable = [float(x) for x in range(1, 1000)]
+        self.assertEqual(prod(iterable), _naive_prod(iterable))
+        iterable = [float(x) for x in range(-1000, -1)]
+        self.assertEqual(prod(iterable), _naive_prod(iterable))
+        iterable = [float(x) for x in range(-1000, 1000)]
+        self.assertIsNaN(prod(iterable))
+
+        # Float tests
+
+        self.assertIsNaN(prod([1, 2, 3, float("nan"), 2, 3]))
+        self.assertIsNaN(prod([1, 0, float("nan"), 2, 3]))
+        self.assertIsNaN(prod([1, float("nan"), 0, 3]))
+        self.assertIsNaN(prod([1, float("inf"), float("nan"),3]))
+        self.assertIsNaN(prod([1, float("-inf"), float("nan"),3]))
+        self.assertIsNaN(prod([1, float("nan"), float("inf"),3]))
+        self.assertIsNaN(prod([1, float("nan"), float("-inf"),3]))
+
+        self.assertEqual(prod([1, 2, 3, float('inf'),-3,4]), float('-inf'))
+        self.assertEqual(prod([1, 2, 3, float('-inf'),-3,4]), float('inf'))
+
+        self.assertIsNaN(prod([1,2,0,float('inf'), -3, 4]))
+        self.assertIsNaN(prod([1,2,0,float('-inf'), -3, 4]))
+        self.assertIsNaN(prod([1, 2, 3, float('inf'), -3, 0, 3]))
+        self.assertIsNaN(prod([1, 2, 3, float('-inf'), -3, 0, 2]))
+
+        # Type preservation
+
+        self.assertEqual(type(prod([1, 2, 3, 4, 5, 6])), int)
+        self.assertEqual(type(prod([1, 2.0, 3, 4, 5, 6])), float)
+        self.assertEqual(type(prod(range(1, 10000))), int)
+        self.assertEqual(type(prod(range(1, 10000), start=1.0)), float)
+        self.assertEqual(type(prod([1, decimal.Decimal(2.0), 3, 4, 5, 6])),
+                         decimal.Decimal)
+
     # Custom assertions.
 
     def assertIsNaN(self, value):
@@ -1724,41 +1810,6 @@ def test_fractions(self):
         self.assertAllClose(fraction_examples, rel_tol=1e-8)
         self.assertAllNotClose(fraction_examples, rel_tol=1e-9)
 
-    def test_prod(self):
-        prod = math.prod
-        self.assertEqual(prod([]), 1)
-        self.assertEqual(prod([], start=5), 5)
-        self.assertEqual(prod(list(range(2,8))), 5040)
-        self.assertEqual(prod(iter(list(range(2,8)))), 5040)
-        self.assertEqual(prod(range(1, 10), start=10), 3628800)
-
-        self.assertEqual(prod([1, 2, 3, 4, 5]), 120)
-        self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0)
-        self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0)
-        self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0)
-
-        # Test overflow in fast-path for integers
-        self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32)
-        # Test overflow in fast-path for floats
-        self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32))
-
-        self.assertRaises(TypeError, prod)
-        self.assertRaises(TypeError, prod, 42)
-        self.assertRaises(TypeError, prod, ['a', 'b', 'c'])
-        self.assertRaises(TypeError, prod, ['a', 'b', 'c'], '')
-        self.assertRaises(TypeError, prod, [b'a', b'c'], b'')
-        values = [bytearray(b'a'), bytearray(b'b')]
-        self.assertRaises(TypeError, prod, values, bytearray(b''))
-        self.assertRaises(TypeError, prod, [[1], [2], [3]])
-        self.assertRaises(TypeError, prod, [{2:3}])
-        self.assertRaises(TypeError, prod, [{2:3}]*2, {2:3})
-        self.assertRaises(TypeError, prod, [[1], [2], [3]], [])
-        with self.assertRaises(TypeError):
-            prod([10, 20], [30, 40])     # start is a keyword-only argument
-
-        self.assertEqual(prod([0, 1, 2, 3]), 0)
-        self.assertEqual(prod([1, 0, 2, 3]), 0)
-        self.assertEqual(prod(range(10)), 0)
 
 def test_main():
     from doctest import DocFileSuite
diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c
index fd0eb327c743..ba8423211c2b 100644
--- a/Modules/mathmodule.c
+++ b/Modules/mathmodule.c
@@ -2493,6 +2493,55 @@ math_isclose_impl(PyObject *module, double a, double b, double rel_tol,
             (diff <= abs_tol));
 }
 
+static inline int
+_check_long_mult_overflow(long a, long b) {
+
+    /* From Python2's int_mul code:
+
+    Integer overflow checking for * is painful:  Python tried a couple ways, but
+    they didn't work on all platforms, or failed in endcases (a product of
+    -sys.maxint-1 has been a particular pain).
+
+    Here's another way:
+
+    The native long product x*y is either exactly right or *way* off, being
+    just the last n bits of the true product, where n is the number of bits
+    in a long (the delivered product is the true product plus i*2**n for
+    some integer i).
+
+    The native double product (double)x * (double)y is subject to three
+    rounding errors:  on a sizeof(long)==8 box, each cast to double can lose
+    info, and even on a sizeof(long)==4 box, the multiplication can lose info.
+    But, unlike the native long product, it's not in *range* trouble:  even
+    if sizeof(long)==32 (256-bit longs), the product easily fits in the
+    dynamic range of a double.  So the leading 50 (or so) bits of the double
+    product are correct.
+
+    We check these two ways against each other, and declare victory if they're
+    approximately the same.  Else, because the native long product is the only
+    one that can lose catastrophic amounts of information, it's the native long
+    product that must have overflowed.
+
+    */
+
+    long longprod = (long)((unsigned long)a * b);
+    double doubleprod = (double)a * (double)b;
+    double doubled_longprod = (double)longprod;
+
+    if (doubled_longprod == doubleprod) {
+        return 0;
+    }
+
+    const double diff = doubled_longprod - doubleprod;
+    const double absdiff = diff >= 0.0 ? diff : -diff;
+    const double absprod = doubleprod >= 0.0 ? doubleprod : -doubleprod;
+
+    if (32.0 * absdiff <= absprod) {
+        return 0;
+    }
+
+    return 1;
+}
 
 /*[clinic input]
 math.prod
@@ -2558,11 +2607,8 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
             }
             if (PyLong_CheckExact(item)) {
                 long b = PyLong_AsLongAndOverflow(item, &overflow);
-                long x = i_result * b;
-                /* Continue if there is no overflow */
-                if (overflow == 0
-                    && x < LONG_MAX && x > LONG_MIN
-                    && !(b != 0 && x / b != i_result)) {
+                if (overflow == 0 && !_check_long_mult_overflow(i_result, b)) {
+                    long x = i_result * b;
                     i_result = x;
                     Py_DECREF(item);
                     continue;



More information about the Python-checkins mailing list