[Python-checkins] [3.10] bpo-20499: Rounding error in statistics.pvariance (GH-28230) (GH-28248)

rhettinger webhook-mailer at python.org
Wed Sep 8 23:42:39 EDT 2021


https://github.com/python/cpython/commit/3c30805b58421a1e2aa613052b6d45899f9b1b5d
commit: 3c30805b58421a1e2aa613052b6d45899f9b1b5d
branch: 3.10
author: Raymond Hettinger <rhettinger at users.noreply.github.com>
committer: rhettinger <rhettinger at users.noreply.github.com>
date: 2021-09-08T22:42:29-05:00
summary:

[3.10] bpo-20499: Rounding error in statistics.pvariance (GH-28230) (GH-28248)

files:
A Misc/NEWS.d/next/Library/2021-09-08-01-19-31.bpo-20499.tSxx8Y.rst
M Lib/statistics.py
M Lib/test/test_statistics.py

diff --git a/Lib/statistics.py b/Lib/statistics.py
index 268cc71a0952b..cfcc456fd786e 100644
--- a/Lib/statistics.py
+++ b/Lib/statistics.py
@@ -147,21 +147,17 @@ class StatisticsError(ValueError):
 
 # === Private utilities ===
 
-def _sum(data, start=0):
-    """_sum(data [, start]) -> (type, sum, count)
+def _sum(data):
+    """_sum(data) -> (type, sum, count)
 
     Return a high-precision sum of the given numeric data as a fraction,
     together with the type to be converted to and the count of items.
 
-    If optional argument ``start`` is given, it is added to the total.
-    If ``data`` is empty, ``start`` (defaulting to 0) is returned.
-
-
     Examples
     --------
 
-    >>> _sum([3, 2.25, 4.5, -0.5, 1.0], 0.75)
-    (<class 'float'>, Fraction(11, 1), 5)
+    >>> _sum([3, 2.25, 4.5, -0.5, 0.25])
+    (<class 'float'>, Fraction(19, 2), 5)
 
     Some sources of round-off error will be avoided:
 
@@ -184,10 +180,9 @@ def _sum(data, start=0):
     allowed.
     """
     count = 0
-    n, d = _exact_ratio(start)
-    partials = {d: n}
+    partials = {}
     partials_get = partials.get
-    T = _coerce(int, type(start))
+    T = int
     for typ, values in groupby(data, type):
         T = _coerce(T, typ)  # or raise TypeError
         for n, d in map(_exact_ratio, values):
@@ -200,8 +195,7 @@ def _sum(data, start=0):
         assert not _isfinite(total)
     else:
         # Sum all the partial sums using builtin sum.
-        # FIXME is this faster if we sum them in order of the denominator?
-        total = sum(Fraction(n, d) for d, n in sorted(partials.items()))
+        total = sum(Fraction(n, d) for d, n in partials.items())
     return (T, total, count)
 
 
@@ -252,27 +246,19 @@ def _exact_ratio(x):
     x is expected to be an int, Fraction, Decimal or float.
     """
     try:
-        # Optimise the common case of floats. We expect that the most often
-        # used numeric type will be builtin floats, so try to make this as
-        # fast as possible.
-        if type(x) is float or type(x) is Decimal:
-            return x.as_integer_ratio()
-        try:
-            # x may be an int, Fraction, or Integral ABC.
-            return (x.numerator, x.denominator)
-        except AttributeError:
-            try:
-                # x may be a float or Decimal subclass.
-                return x.as_integer_ratio()
-            except AttributeError:
-                # Just give up?
-                pass
+        return x.as_integer_ratio()
+    except AttributeError:
+        pass
     except (OverflowError, ValueError):
         # float NAN or INF.
         assert not _isfinite(x)
         return (x, None)
-    msg = "can't convert type '{}' to numerator/denominator"
-    raise TypeError(msg.format(type(x).__name__))
+    try:
+        # x may be an Integral ABC.
+        return (x.numerator, x.denominator)
+    except AttributeError:
+        msg = f"can't convert type '{type(x).__name__}' to numerator/denominator"
+        raise TypeError(msg)
 
 
 def _convert(value, T):
@@ -719,14 +705,20 @@ def _ss(data, c=None):
     if c is not None:
         T, total, count = _sum((x-c)**2 for x in data)
         return (T, total)
-    c = mean(data)
-    T, total, count = _sum((x-c)**2 for x in data)
-    # The following sum should mathematically equal zero, but due to rounding
-    # error may not.
-    U, total2, count2 = _sum((x - c) for x in data)
-    assert T == U and count == count2
-    total -= total2 ** 2 / len(data)
-    assert not total < 0, 'negative sum of square deviations: %f' % total
+    T, total, count = _sum(data)
+    mean_n, mean_d = (total / count).as_integer_ratio()
+    partials = Counter()
+    for n, d in map(_exact_ratio, data):
+        diff_n = n * mean_d - d * mean_n
+        diff_d = d * mean_d
+        partials[diff_d * diff_d] += diff_n * diff_n
+    if None in partials:
+        # The sum will be a NAN or INF. We can ignore all the finite
+        # partials, and just look at this special one.
+        total = partials[None]
+        assert not _isfinite(total)
+    else:
+        total = sum(Fraction(n, d) for d, n in partials.items())
     return (T, total)
 
 
@@ -830,6 +822,9 @@ def stdev(data, xbar=None):
     1.0810874155219827
 
     """
+    # Fixme: Despite the exact sum of squared deviations, some inaccuracy
+    # remain because there are two rounding steps.  The first occurs in
+    # the _convert() step for variance(), the second occurs in math.sqrt().
     var = variance(data, xbar)
     try:
         return var.sqrt()
@@ -846,6 +841,9 @@ def pstdev(data, mu=None):
     0.986893273527251
 
     """
+    # Fixme: Despite the exact sum of squared deviations, some inaccuracy
+    # remain because there are two rounding steps.  The first occurs in
+    # the _convert() step for pvariance(), the second occurs in math.sqrt().
     var = pvariance(data, mu)
     try:
         return var.sqrt()
diff --git a/Lib/test/test_statistics.py b/Lib/test/test_statistics.py
index 436c420149489..adccfad7b8ed1 100644
--- a/Lib/test/test_statistics.py
+++ b/Lib/test/test_statistics.py
@@ -1247,20 +1247,14 @@ def test_empty_data(self):
         # Override test for empty data.
         for data in ([], (), iter([])):
             self.assertEqual(self.func(data), (int, Fraction(0), 0))
-            self.assertEqual(self.func(data, 23), (int, Fraction(23), 0))
-            self.assertEqual(self.func(data, 2.3), (float, Fraction(2.3), 0))
 
     def test_ints(self):
         self.assertEqual(self.func([1, 5, 3, -4, -8, 20, 42, 1]),
                          (int, Fraction(60), 8))
-        self.assertEqual(self.func([4, 2, 3, -8, 7], 1000),
-                         (int, Fraction(1008), 5))
 
     def test_floats(self):
         self.assertEqual(self.func([0.25]*20),
                          (float, Fraction(5.0), 20))
-        self.assertEqual(self.func([0.125, 0.25, 0.5, 0.75], 1.5),
-                         (float, Fraction(3.125), 4))
 
     def test_fractions(self):
         self.assertEqual(self.func([Fraction(1, 1000)]*500),
@@ -1281,14 +1275,6 @@ def test_compare_with_math_fsum(self):
         data = [random.uniform(-100, 1000) for _ in range(1000)]
         self.assertApproxEqual(float(self.func(data)[1]), math.fsum(data), rel=2e-16)
 
-    def test_start_argument(self):
-        # Test that the optional start argument works correctly.
-        data = [random.uniform(1, 1000) for _ in range(100)]
-        t = self.func(data)[1]
-        self.assertEqual(t+42, self.func(data, 42)[1])
-        self.assertEqual(t-23, self.func(data, -23)[1])
-        self.assertEqual(t+Fraction(1e20), self.func(data, 1e20)[1])
-
     def test_strings_fail(self):
         # Sum of strings should fail.
         self.assertRaises(TypeError, self.func, [1, 2, 3], '999')
@@ -2077,6 +2063,13 @@ def test_decimals(self):
         self.assertEqual(result, exact)
         self.assertIsInstance(result, Decimal)
 
+    def test_accuracy_bug_20499(self):
+        data = [0, 0, 1]
+        exact = 2 / 9
+        result = self.func(data)
+        self.assertEqual(result, exact)
+        self.assertIsInstance(result, float)
+
 
 class TestVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
     # Tests for sample variance.
@@ -2117,6 +2110,13 @@ def test_center_not_at_mean(self):
         self.assertEqual(self.func(data), 0.5)
         self.assertEqual(self.func(data, xbar=2.0), 1.0)
 
+    def test_accuracy_bug_20499(self):
+        data = [0, 0, 2]
+        exact = 4 / 3
+        result = self.func(data)
+        self.assertEqual(result, exact)
+        self.assertIsInstance(result, float)
+
 class TestPStdev(VarianceStdevMixin, NumericTestCase):
     # Tests for population standard deviation.
     def setUp(self):
diff --git a/Misc/NEWS.d/next/Library/2021-09-08-01-19-31.bpo-20499.tSxx8Y.rst b/Misc/NEWS.d/next/Library/2021-09-08-01-19-31.bpo-20499.tSxx8Y.rst
new file mode 100644
index 0000000000000..cbbe61ac4a269
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2021-09-08-01-19-31.bpo-20499.tSxx8Y.rst
@@ -0,0 +1 @@
+Improve the speed and accuracy of statistics.pvariance().



More information about the Python-checkins mailing list