[Python-checkins] bpo-45876: Correctly rounded stdev() and pstdev() for the Decimal case (GH-29828)

rhettinger webhook-mailer at python.org
Tue Nov 30 19:20:17 EST 2021


https://github.com/python/cpython/commit/a39f46afdead515e7ac3722464b5ee8d7b0b2c9b
commit: a39f46afdead515e7ac3722464b5ee8d7b0b2c9b
branch: main
author: Raymond Hettinger <rhettinger at users.noreply.github.com>
committer: rhettinger <rhettinger at users.noreply.github.com>
date: 2021-11-30T18:20:08-06:00
summary:

bpo-45876:  Correctly rounded stdev() and pstdev() for the Decimal case (GH-29828)

files:
M Lib/statistics.py
M Lib/test/test_statistics.py

diff --git a/Lib/statistics.py b/Lib/statistics.py
index cf8eaa0a61e62..9f1efa21b15e3 100644
--- a/Lib/statistics.py
+++ b/Lib/statistics.py
@@ -137,7 +137,7 @@
 from itertools import groupby, repeat
 from bisect import bisect_left, bisect_right
 from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
-from operator import itemgetter, mul
+from operator import mul
 from collections import Counter, namedtuple
 
 _SQRT2 = sqrt(2.0)
@@ -248,6 +248,28 @@ def _exact_ratio(x):
 
     x is expected to be an int, Fraction, Decimal or float.
     """
+
+    # XXX We should revisit whether using fractions to accumulate exact
+    # ratios is the right way to go.
+
+    # The integer ratios for binary floats can have numerators or
+    # denominators with over 300 decimal digits.  The problem is more
+    # acute with decimal floats where the default decimal context
+    # supports a huge range of exponents from Emin=-999999 to
+    # Emax=999999.  When expanded with as_integer_ratio(), numbers like
+    # Decimal('3.14E+5000') and Decimal('3.14E-5000') have large
+    # numerators or denominators that will slow computation.
+
+    # When the integer ratios are accumulated as fractions, the size
+    # grows to cover the full range from the smallest magnitude to the
+    # largest.  For example, Fraction(3.14E+300) + Fraction(3.14E-300)
+    # has a 616 digit numerator.  Likewise,
+    # Fraction(Decimal('3.14E+5000')) + Fraction(Decimal('3.14E-5000'))
+    # has a 10,003 digit numerator.
+
+    # This doesn't seem to have been a problem in practice, but it is a
+    # potential pitfall.
+
     try:
         return x.as_integer_ratio()
     except AttributeError:
@@ -305,28 +327,60 @@ def _fail_neg(values, errmsg='negative value'):
             raise StatisticsError(errmsg)
         yield x
 
-def _isqrt_frac_rto(n: int, m: int) -> float:
+
+def _integer_sqrt_of_frac_rto(n: int, m: int) -> int:
     """Square root of n/m, rounded to the nearest integer using round-to-odd."""
     # Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf
     a = math.isqrt(n // m)
     return a | (a*a*m != n)
 
-# For 53 bit precision floats, the _sqrt_frac() shift is 109.
-_sqrt_shift: int = 2 * sys.float_info.mant_dig + 3
 
-def _sqrt_frac(n: int, m: int) -> float:
+# For 53 bit precision floats, the bit width used in
+# _float_sqrt_of_frac() is 109.
+_sqrt_bit_width: int = 2 * sys.float_info.mant_dig + 3
+
+
+def _float_sqrt_of_frac(n: int, m: int) -> float:
     """Square root of n/m as a float, correctly rounded."""
     # See principle and proof sketch at: https://bugs.python.org/msg407078
-    q = (n.bit_length() - m.bit_length() - _sqrt_shift) // 2
+    q = (n.bit_length() - m.bit_length() - _sqrt_bit_width) // 2
     if q >= 0:
-        numerator = _isqrt_frac_rto(n, m << 2 * q) << q
+        numerator = _integer_sqrt_of_frac_rto(n, m << 2 * q) << q
         denominator = 1
     else:
-        numerator = _isqrt_frac_rto(n << -2 * q, m)
+        numerator = _integer_sqrt_of_frac_rto(n << -2 * q, m)
         denominator = 1 << -q
     return numerator / denominator   # Convert to float
 
 
+def _decimal_sqrt_of_frac(n: int, m: int) -> Decimal:
+    """Square root of n/m as a Decimal, correctly rounded."""
+    # Premise:  For decimal, computing (n/m).sqrt() can be off
+    #           by 1 ulp from the correctly rounded result.
+    # Method:   Check the result, moving up or down a step if needed.
+    if n <= 0:
+        if not n:
+            return Decimal('0.0')
+        n, m = -n, -m
+
+    root = (Decimal(n) / Decimal(m)).sqrt()
+    nr, dr = root.as_integer_ratio()
+
+    plus = root.next_plus()
+    np, dp = plus.as_integer_ratio()
+    # test: n / m > ((root + plus) / 2) ** 2
+    if 4 * n * (dr*dp)**2 > m * (dr*np + dp*nr)**2:
+        return plus
+
+    minus = root.next_minus()
+    nm, dm = minus.as_integer_ratio()
+    # test: n / m < ((root + minus) / 2) ** 2
+    if 4 * n * (dr*dm)**2 < m * (dr*nm + dm*nr)**2:
+        return minus
+
+    return root
+
+
 # === Measures of central tendency (averages) ===
 
 def mean(data):
@@ -869,7 +923,7 @@ def stdev(data, xbar=None):
     if hasattr(T, 'sqrt'):
         var = _convert(mss, T)
         return var.sqrt()
-    return _sqrt_frac(mss.numerator, mss.denominator)
+    return _float_sqrt_of_frac(mss.numerator, mss.denominator)
 
 
 def pstdev(data, mu=None):
@@ -888,10 +942,9 @@ def pstdev(data, mu=None):
         raise StatisticsError('pstdev requires at least one data point')
     T, ss = _ss(data, mu)
     mss = ss / n
-    if hasattr(T, 'sqrt'):
-        var = _convert(mss, T)
-        return var.sqrt()
-    return _sqrt_frac(mss.numerator, mss.denominator)
+    if issubclass(T, Decimal):
+        return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)
+    return _float_sqrt_of_frac(mss.numerator, mss.denominator)
 
 
 # === Statistics for relations between two inputs ===
diff --git a/Lib/test/test_statistics.py b/Lib/test/test_statistics.py
index 771a03e707ee0..bacb76a9b036b 100644
--- a/Lib/test/test_statistics.py
+++ b/Lib/test/test_statistics.py
@@ -2164,9 +2164,9 @@ def test_center_not_at_mean(self):
 
 class TestSqrtHelpers(unittest.TestCase):
 
-    def test_isqrt_frac_rto(self):
+    def test_integer_sqrt_of_frac_rto(self):
         for n, m in itertools.product(range(100), range(1, 1000)):
-            r = statistics._isqrt_frac_rto(n, m)
+            r = statistics._integer_sqrt_of_frac_rto(n, m)
             self.assertIsInstance(r, int)
             if r*r*m == n:
                 # Root is exact
@@ -2177,7 +2177,7 @@ def test_isqrt_frac_rto(self):
             self.assertTrue(m * (r - 1)**2 < n < m * (r + 1)**2)
 
     @requires_IEEE_754
-    def test_sqrt_frac(self):
+    def test_float_sqrt_of_frac(self):
 
         def is_root_correctly_rounded(x: Fraction, root: float) -> bool:
             if not x:
@@ -2204,22 +2204,59 @@ def is_root_correctly_rounded(x: Fraction, root: float) -> bool:
             denonimator: int = randrange(10 ** randrange(50)) + 1
             with self.subTest(numerator=numerator, denonimator=denonimator):
                 x: Fraction = Fraction(numerator, denonimator)
-                root: float = statistics._sqrt_frac(numerator, denonimator)
+                root: float = statistics._float_sqrt_of_frac(numerator, denonimator)
                 self.assertTrue(is_root_correctly_rounded(x, root))
 
         # Verify that corner cases and error handling match math.sqrt()
-        self.assertEqual(statistics._sqrt_frac(0, 1), 0.0)
+        self.assertEqual(statistics._float_sqrt_of_frac(0, 1), 0.0)
         with self.assertRaises(ValueError):
-            statistics._sqrt_frac(-1, 1)
+            statistics._float_sqrt_of_frac(-1, 1)
         with self.assertRaises(ValueError):
-            statistics._sqrt_frac(1, -1)
+            statistics._float_sqrt_of_frac(1, -1)
 
         # Error handling for zero denominator matches that for Fraction(1, 0)
         with self.assertRaises(ZeroDivisionError):
-            statistics._sqrt_frac(1, 0)
+            statistics._float_sqrt_of_frac(1, 0)
 
         # The result is well defined if both inputs are negative
-        self.assertAlmostEqual(statistics._sqrt_frac(-2, -1), math.sqrt(2.0))
+        self.assertEqual(statistics._float_sqrt_of_frac(-2, -1), statistics._float_sqrt_of_frac(2, 1))
+
+    def test_decimal_sqrt_of_frac(self):
+        root: Decimal
+        numerator: int
+        denominator: int
+
+        for root, numerator, denominator in [
+            (Decimal('0.4481904599041192673635338663'), 200874688349065940678243576378, 1000000000000000000000000000000),  # No adj
+            (Decimal('0.7924949131383786609961759598'), 628048187350206338833590574929, 1000000000000000000000000000000),  # Adj up
+            (Decimal('0.8500554152289934068192208727'), 722594208960136395984391238251, 1000000000000000000000000000000),  # Adj down
+        ]:
+            with decimal.localcontext(decimal.DefaultContext):
+                self.assertEqual(statistics._decimal_sqrt_of_frac(numerator, denominator), root)
+
+            # Confirm expected root with a quad precision decimal computation
+            with decimal.localcontext(decimal.DefaultContext) as ctx:
+                ctx.prec *= 4
+                high_prec_ratio = Decimal(numerator) / Decimal(denominator)
+                ctx.rounding = decimal.ROUND_05UP
+                high_prec_root = high_prec_ratio.sqrt()
+            with decimal.localcontext(decimal.DefaultContext):
+                target_root = +high_prec_root
+            self.assertEqual(root, target_root)
+
+        # Verify that corner cases and error handling match Decimal.sqrt()
+        self.assertEqual(statistics._decimal_sqrt_of_frac(0, 1), 0.0)
+        with self.assertRaises(decimal.InvalidOperation):
+            statistics._decimal_sqrt_of_frac(-1, 1)
+        with self.assertRaises(decimal.InvalidOperation):
+            statistics._decimal_sqrt_of_frac(1, -1)
+
+        # Error handling for zero denominator matches that for Fraction(1, 0)
+        with self.assertRaises(ZeroDivisionError):
+            statistics._decimal_sqrt_of_frac(1, 0)
+
+        # The result is well defined if both inputs are negative
+        self.assertEqual(statistics._decimal_sqrt_of_frac(-2, -1), statistics._decimal_sqrt_of_frac(2, 1))
 
 
 class TestStdev(VarianceStdevMixin, NumericTestCase):



More information about the Python-checkins mailing list