[Python-checkins] cpython: Issue #18844: Add random.weighted_choices()

raymond.hettinger python-checkins at python.org
Tue Sep 6 20:15:59 EDT 2016


https://hg.python.org/cpython/rev/a5856153d942
changeset:   103181:a5856153d942
user:        Raymond Hettinger <python at rcn.com>
date:        Tue Sep 06 17:15:29 2016 -0700
summary:
  Issue #18844: Add random.weighted_choices()

files:
  Doc/library/random.rst  |  21 ++++++++
  Lib/random.py           |  28 +++++++++++-
  Lib/test/test_random.py |  68 +++++++++++++++++++++++++++++
  Misc/NEWS               |   2 +
  4 files changed, 118 insertions(+), 1 deletions(-)


diff --git a/Doc/library/random.rst b/Doc/library/random.rst
--- a/Doc/library/random.rst
+++ b/Doc/library/random.rst
@@ -124,6 +124,27 @@
    Return a random element from the non-empty sequence *seq*. If *seq* is empty,
    raises :exc:`IndexError`.
 
+.. function:: weighted_choices(k, population, weights=None, *, cum_weights=None)
+
+   Return a *k* sized list of elements chosen from the *population* with replacement.
+   If the *population* is empty, raises :exc:`IndexError`.
+
+   If a *weights* sequence is specified, selections are made according to the
+   relative weights.  Alternatively, if a *cum_weights* sequence is given, the
+   selections are made according to the cumulative weights.  For example, the
+   relative weights ``[10, 5, 30, 5]`` are equivalent to the cumulative
+   weights ``[10, 15, 45, 50]``.  Internally, the relative weights are
+   converted to cumulative weights before making selections, so supplying the
+   cumulative weights saves work.
+
+   If neither *weights* nor *cum_weights* is specified, selections are made
+   with equal probability.  A *weights* or *cum_weights* sequence must be the
+   same length as the *population* sequence, otherwise :exc:`ValueError` is
+   raised.  It is a :exc:`TypeError` to specify both *weights* and *cum_weights*.
+
+   The *weights* or *cum_weights* can use any numeric type that interoperates
+   with the :class:`float` values returned by :func:`random` (that includes
+   integers, floats, and fractions but excludes decimals).
 
 .. function:: shuffle(x[, random])
 
diff --git a/Lib/random.py b/Lib/random.py
--- a/Lib/random.py
+++ b/Lib/random.py
@@ -8,6 +8,7 @@
     ---------
            pick random element
            pick random sample
+           pick weighted random sample
            generate random permutation
 
     distributions on the real line:
@@ -43,12 +44,14 @@
 from os import urandom as _urandom
 from _collections_abc import Set as _Set, Sequence as _Sequence
 from hashlib import sha512 as _sha512
+import itertools as _itertools
+import bisect as _bisect
 
 __all__ = ["Random","seed","random","uniform","randint","choice","sample",
            "randrange","shuffle","normalvariate","lognormvariate",
            "expovariate","vonmisesvariate","gammavariate","triangular",
            "gauss","betavariate","paretovariate","weibullvariate",
-           "getstate","setstate", "getrandbits",
+           "getstate","setstate", "getrandbits", "weighted_choices",
            "SystemRandom"]
 
 NV_MAGICCONST = 4 * _exp(-0.5)/_sqrt(2.0)
@@ -334,6 +337,28 @@
                 result[i] = population[j]
         return result
 
+    def weighted_choices(self, k, population, weights=None, *, cum_weights=None):
+        """Return a k sized list of population elements chosen with replacement.
+
+        Selections follow the relative *weights* or the cumulative *cum_weights*
+        (at most one may be given; its length must match the population).
+        If neither is specified, selections are made with equal probability.
+        """
+        if cum_weights is None:
+            if weights is None:
+                choice = self.choice
+                return [choice(population) for i in range(k)]
+            cum_weights = list(_itertools.accumulate(weights))  # relative -> cumulative
+        elif weights is not None:
+            raise TypeError('Cannot specify both weights and cumulative weights')
+        if len(cum_weights) != len(population):
+            raise ValueError('The number of weights does not match the population')
+        bisect = _bisect.bisect
+        random = self.random
+        total = cum_weights[-1]  # IndexError if the population is empty
+        # bisect (== bisect_right) picks the first bucket whose cumulative weight exceeds the draw.
+        return [population[bisect(cum_weights, random() * total)] for i in range(k)]
+
 ## -------------------- real-valued distributions  -------------------
 
 ## -------------------- uniform distribution -------------------
@@ -724,6 +749,7 @@
 randrange = _inst.randrange
 sample = _inst.sample
 shuffle = _inst.shuffle
+weighted_choices = _inst.weighted_choices
 normalvariate = _inst.normalvariate
 lognormvariate = _inst.lognormvariate
 expovariate = _inst.expovariate
diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py
--- a/Lib/test/test_random.py
+++ b/Lib/test/test_random.py
@@ -7,6 +7,7 @@
 from functools import partial
 from math import log, exp, pi, fsum, sin
 from test import support
+from fractions import Fraction
 
 class TestBasicOps:
     # Superclass with tests common to all generators.
@@ -141,6 +142,73 @@
     def test_sample_on_dicts(self):
         self.assertRaises(TypeError, self.gen.sample, dict.fromkeys('abcdef'), 2)
 
+    def test_weighted_choices(self):
+        weighted_choices = self.gen.weighted_choices
+        data = ['red', 'green', 'blue', 'yellow']
+        str_data = 'abcd'
+        range_data = range(4)
+        set_data = set(range(4))
+
+        # basic functionality
+        for sample in [
+            weighted_choices(5, data),
+            weighted_choices(5, data, range(4)),
+            weighted_choices(k=5, population=data, weights=range(4)),
+            weighted_choices(k=5, population=data, cum_weights=range(4)),
+        ]:
+            self.assertEqual(len(sample), 5)
+            self.assertEqual(type(sample), list)
+            self.assertTrue(set(sample) <= set(data))
+
+        # test argument handling
+        with self.assertRaises(TypeError):                                        # missing arguments
+            weighted_choices(2)
+
+        self.assertEqual(weighted_choices(0, data), [])                           # k == 0
+        self.assertEqual(weighted_choices(-1, data), [])                          # negative k behaves like ``[0] * -1``
+        with self.assertRaises(TypeError):
+            weighted_choices(2.5, data)                                           # k is a float
+
+        self.assertTrue(set(weighted_choices(5, str_data)) <= set(str_data))      # population is a string sequence
+        self.assertTrue(set(weighted_choices(5, range_data)) <= set(range_data))  # population is a range
+        with self.assertRaises(TypeError):
+            weighted_choices(2, set_data)                                         # population is not a sequence
+
+        self.assertTrue(set(weighted_choices(5, data, None)) <= set(data))        # weights is None
+        self.assertTrue(set(weighted_choices(5, data, weights=None)) <= set(data))
+        with self.assertRaises(ValueError):
+            weighted_choices(5, data, [1,2])                                      # len(weights) != len(population)
+        with self.assertRaises(IndexError):
+            weighted_choices(5, data, [0]*4)                                      # weights sum to zero
+        with self.assertRaises(TypeError):
+            weighted_choices(5, data, 10)                                         # non-iterable weights
+        with self.assertRaises(TypeError):
+            weighted_choices(5, data, [None]*4)                                   # non-numeric weights
+        for weights in [
+                [15, 10, 25, 30],                                                 # integer weights
+                [15.1, 10.2, 25.2, 30.3],                                         # float weights
+                [Fraction(1, 3), Fraction(2, 6), Fraction(3, 6), Fraction(4, 6)], # fractional weights
+                [True, False, True, False]                                        # booleans (include / exclude)
+        ]:
+            self.assertTrue(set(weighted_choices(5, data, weights)) <= set(data))
+
+        with self.assertRaises(ValueError):
+            weighted_choices(5, data, cum_weights=[1,2])                          # len(cum_weights) != len(population)
+        with self.assertRaises(IndexError):
+            weighted_choices(5, data, cum_weights=[0]*4)                          # cum_weights sum to zero
+        with self.assertRaises(TypeError):
+            weighted_choices(5, data, cum_weights=10)                             # non-iterable cum_weights
+        with self.assertRaises(TypeError):
+            weighted_choices(5, data, cum_weights=[None]*4)                       # non-numeric cum_weights
+        with self.assertRaises(TypeError):
+            weighted_choices(5, data, range(4), cum_weights=range(4))             # both weights and cum_weights
+        for weights in [
+                [15, 25, 50, 80],                                                 # integer cum_weights
+                [15.1, 25.3, 50.5, 80.8],                                         # float cum_weights
+                [Fraction(1, 3), Fraction(2, 6), Fraction(3, 6), Fraction(4, 6)], # fractional cum_weights
+        ]:
+            self.assertTrue(set(weighted_choices(5, data, cum_weights=weights)) <= set(data))
+
     def test_gauss(self):
         # Ensure that the seed() method initializes all the hidden state.  In
         # particular, through 2.2.1 it failed to reset a piece of state used
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -101,6 +101,8 @@
 - Issue #27691: Fix ssl module's parsing of GEN_RID subject alternative name
   fields in X.509 certs.
 
+- Issue #18844: Add random.weighted_choices().
+
 - Issue #25761: Improved error reporting about truncated pickle data in
   C implementation of unpickler.  UnpicklingError is now raised instead of
   AttributeError and ValueError in some cases.

-- 
Repository URL: https://hg.python.org/cpython


More information about the Python-checkins mailing list