[Python-checkins] Extend _sqrtprod() to cover the full range of inputs. Add tests. (GH-107855)

rhettinger webhook-mailer at python.org
Fri Aug 11 12:19:23 EDT 2023


https://github.com/python/cpython/commit/52e0797f8e1c631eecf24cb3f997ace336f52271
commit: 52e0797f8e1c631eecf24cb3f997ace336f52271
branch: main
author: Raymond Hettinger <rhettinger at users.noreply.github.com>
committer: rhettinger <rhettinger at users.noreply.github.com>
date: 2023-08-11T11:19:19-05:00
summary:

Extend _sqrtprod() to cover the full range of inputs. Add tests. (GH-107855)

files:
M Lib/statistics.py
M Lib/test/test_statistics.py

diff --git a/Lib/statistics.py b/Lib/statistics.py
index 93c44f1f13fab..a8036e9928c46 100644
--- a/Lib/statistics.py
+++ b/Lib/statistics.py
@@ -137,6 +137,7 @@
 from itertools import count, groupby, repeat
 from bisect import bisect_left, bisect_right
 from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum, sumprod
+from math import isfinite, isinf
 from functools import reduce
 from operator import itemgetter
 from collections import Counter, namedtuple, defaultdict
@@ -1005,14 +1006,25 @@ def _mean_stdev(data):
         return float(xbar), float(xbar) / float(ss)
 
 def _sqrtprod(x: float, y: float) -> float:
-    "Return sqrt(x * y) computed with high accuracy."
-    # Square root differential correction:
-    # https://www.wolframalpha.com/input/?i=Maclaurin+series+sqrt%28h**2+%2B+x%29+at+x%3D0
+    "Return sqrt(x * y) computed with improved accuracy and without overflow/underflow."
     h = sqrt(x * y)
+    if not isfinite(h):
+        if isinf(h) and not isinf(x) and not isinf(y):
+            # Finite inputs overflowed, so scale down, and recompute.
+            scale = 2.0 ** -512  # sqrt(1 / sys.float_info.max)
+            return _sqrtprod(scale * x, scale * y) / scale
+        return h
     if not h:
-        return 0.0
-    x = sumprod((x, h), (y, -h))
-    return h + x / (2.0 * h)
+        if x and y:
+            # Non-zero inputs underflowed, so scale up, and recompute.
+            # Scale:  1 / sqrt(sys.float_info.min * sys.float_info.epsilon)
+            scale = 2.0 ** 537
+            return _sqrtprod(scale * x, scale * y) / scale
+        return h
+    # Improve accuracy with a differential correction.
+    # https://www.wolframalpha.com/input/?i=Maclaurin+series+sqrt%28h**2+%2B+x%29+at+x%3D0
+    d = sumprod((x, h), (y, -h))
+    return h + d / (2.0 * h)
 
 
 # === Statistics for relations between two inputs ===
diff --git a/Lib/test/test_statistics.py b/Lib/test/test_statistics.py
index f0fa6454b1f91..aa2cf2b1edc58 100644
--- a/Lib/test/test_statistics.py
+++ b/Lib/test/test_statistics.py
@@ -28,6 +28,12 @@
 
 # === Helper functions and class ===
 
+# Test copied from Lib/test/test_math.py
+# detect evidence of double-rounding: fsum is not always correctly
+# rounded on machines that suffer from double rounding.
+x, y = 1e16, 2.9999 # use temporary values to defeat peephole optimizer
+HAVE_DOUBLE_ROUNDING = (x + y == 1e16 + 4)
+
 def sign(x):
     """Return -1.0 for negatives, including -0.0, otherwise +1.0."""
     return math.copysign(1, x)
@@ -2564,6 +2570,79 @@ def test_different_scales(self):
         self.assertAlmostEqual(statistics.correlation(x, y), 1)
         self.assertAlmostEqual(statistics.covariance(x, y), 0.1)
 
+    def test_sqrtprod_helper_function_fundamentals(self):
+        # Verify that results are close to sqrt(x * y)
+        for i in range(100):
+            x = random.expovariate()
+            y = random.expovariate()
+            expected = math.sqrt(x * y)
+            actual = statistics._sqrtprod(x, y)
+            with self.subTest(x=x, y=y, expected=expected, actual=actual):
+                self.assertAlmostEqual(expected, actual)
+
+        x, y, target = 0.8035720646477457, 0.7957468097636939, 0.7996498651651661
+        self.assertEqual(statistics._sqrtprod(x, y), target)
+        self.assertNotEqual(math.sqrt(x * y), target)
+
+        # Test that range extremes avoid underflow and overflow
+        smallest = sys.float_info.min * sys.float_info.epsilon
+        self.assertEqual(statistics._sqrtprod(smallest, smallest), smallest)
+        biggest = sys.float_info.max
+        self.assertEqual(statistics._sqrtprod(biggest, biggest), biggest)
+
+        # Check special values and the sign of the result
+        special_values = [0.0, -0.0, 1.0, -1.0, 4.0, -4.0,
+                          math.nan, -math.nan, math.inf, -math.inf]
+        for x, y in itertools.product(special_values, repeat=2):
+            try:
+                expected = math.sqrt(x * y)
+            except ValueError:
+                expected = 'ValueError'
+            try:
+                actual = statistics._sqrtprod(x, y)
+            except ValueError:
+                actual = 'ValueError'
+            with self.subTest(x=x, y=y, expected=expected, actual=actual):
+                if isinstance(expected, str) and expected == 'ValueError':
+                    self.assertEqual(actual, 'ValueError')
+                    continue
+                self.assertIsInstance(actual, float)
+                if math.isnan(expected):
+                    self.assertTrue(math.isnan(actual))
+                    continue
+                self.assertEqual(actual, expected)
+                self.assertEqual(sign(actual), sign(expected))
+
+    @requires_IEEE_754
+    @unittest.skipIf(HAVE_DOUBLE_ROUNDING,
+                     "accuracy not guaranteed on machines with double rounding")
+    @support.cpython_only    # Allow for a weaker sumprod() implementation
+    def test_sqrtprod_helper_function_improved_accuracy(self):
+        # Test a known example where accuracy is improved
+        x, y, target = 0.8035720646477457, 0.7957468097636939, 0.7996498651651661
+        self.assertEqual(statistics._sqrtprod(x, y), target)
+        self.assertNotEqual(math.sqrt(x * y), target)
+
+        def reference_value(x: float, y: float) -> float:
+            x = decimal.Decimal(x)
+            y = decimal.Decimal(y)
+            with decimal.localcontext() as ctx:
+                ctx.prec = 200
+                return float((x * y).sqrt())
+
+        # Verify that the new function with improved accuracy
+        # agrees with a reference value more often than old version.
+        new_agreements = 0
+        old_agreements = 0
+        for i in range(10_000):
+            x = random.expovariate()
+            y = random.expovariate()
+            new = statistics._sqrtprod(x, y)
+            old = math.sqrt(x * y)
+            ref = reference_value(x, y)
+            new_agreements += (new == ref)
+            old_agreements += (old == ref)
+        self.assertGreater(new_agreements, old_agreements)
 
     def test_correlation_spearman(self):
         # https://statistics.laerd.com/statistical-guides/spearmans-rank-order-correlation-statistical-guide-2.php



More information about the Python-checkins mailing list