[Python-checkins] Optimize fmean() weighted average (#102626)

rhettinger webhook-mailer at python.org
Sun Mar 12 13:48:32 EDT 2023


https://github.com/python/cpython/commit/6cd7572f859a32a1f4626644c3e8139055df59e3
commit: 6cd7572f859a32a1f4626644c3e8139055df59e3
branch: main
author: Raymond Hettinger <rhettinger at users.noreply.github.com>
committer: rhettinger <rhettinger at users.noreply.github.com>
date: 2023-03-12T12:48:25-05:00
summary:

Optimize fmean() weighted average (#102626)

files:
M Lib/statistics.py

diff --git a/Lib/statistics.py b/Lib/statistics.py
index 07d1fd5ba6e9..7d5d750193a5 100644
--- a/Lib/statistics.py
+++ b/Lib/statistics.py
@@ -136,9 +136,9 @@
 from decimal import Decimal
 from itertools import count, groupby, repeat
 from bisect import bisect_left, bisect_right
-from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
+from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum, sumprod
 from functools import reduce
-from operator import mul, itemgetter
+from operator import itemgetter
 from collections import Counter, namedtuple, defaultdict
 
 _SQRT2 = sqrt(2.0)
@@ -496,28 +496,26 @@ def fmean(data, weights=None):
     >>> fmean([3.5, 4.0, 5.25])
     4.25
     """
-    try:
-        n = len(data)
-    except TypeError:
-        # Handle iterators that do not define __len__().
-        n = 0
-        def count(iterable):
-            nonlocal n
-            for n, x in enumerate(iterable, start=1):
-                yield x
-        data = count(data)
     if weights is None:
+        try:
+            n = len(data)
+        except TypeError:
+            # Handle iterators that do not define __len__().
+            n = 0
+            def count(iterable):
+                nonlocal n
+                for n, x in enumerate(iterable, start=1):
+                    yield x
+            data = count(data)
         total = fsum(data)
         if not n:
             raise StatisticsError('fmean requires at least one data point')
         return total / n
-    try:
-        num_weights = len(weights)
-    except TypeError:
+    if not isinstance(weights, (list, tuple)):
         weights = list(weights)
-        num_weights = len(weights)
-    num = fsum(map(mul, data, weights))
-    if n != num_weights:
+    try:
+        num = sumprod(data, weights)
+    except ValueError:
         raise StatisticsError('data and weights must be the same length')
     den = fsum(weights)
     if not den:



More information about the Python-checkins mailing list