Most pythonic way of weighted random selection

Steven D'Aprano steve at REMOVE-THIS-cybersource.com.au
Sat Aug 30 22:02:17 EDT 2008


On Sat, 30 Aug 2008 17:41:27 +0200, Manuel Ebert wrote:

> Dear list,
> 
> who's got aesthetic advice for the following problem? 

...

[ugly code removed]
> Now that looks plain ugly, and I wonder whether you might find a
> slightly more elegant way of doing it without using numpy and the like.

Never be afraid to factor out pieces of code into small functions. 
Writing a single huge while loop that does everything is not only hard to 
read, hard to write and hard to debug, but it can also run slower. (It 
depends on a number of factors.) 

Anyway, here's my attempt to solve the problem, as best as I can 
understand it:


import random

def eq(x, y, tol=1e-10):
    # floating point equality within some tolerance
    return abs(x-y) <= tol

M = [[0.2, 0.4, 0.05], [0.1, 0.05, 0.2]]
# all the probabilities in M (the row sums added together) must total 1.0
assert eq(1.0, sum([sum(row) for row in M]))

# build a cumulative probability matrix
CM = []
for row in M:
    for p in row:
        CM.append(p)  # initialize with the raw probabilities

for i in range(1, len(CM)):
    CM[i] += CM[i-1]  # and turn into cumulative probabilities

assert CM[0] >= 0.0
assert eq(CM[-1], 1.0)

def index(data, p):
    """Return the index of the first item in data
    which is no smaller than float p.
    """
    # Note: this uses a linear search. If it is too slow,
    # you can re-write it using the bisect module.
    for i, x in enumerate(data):
        if x >= p:
            return i
    # p is larger than every entry (possible through rounding): use the last one
    return len(data) - 1

def index_to_rowcolumn(i, num_columns):
    """Convert a linear index number i into a (row, column) tuple."""
    # When converting [ [a, b, c, ...], [...] ] into a single
    # array [a, b, c, ... z] we have the identity:
    # index number = row number * number of columns + column number
    return divmod(i, num_columns)

# Now with these two helper functions, we can find the row and column
# number of the first entry in M where the cumulative probability
# reaches or exceeds some given value.

# You will need to define your own fulfills_criterion_a and
# fulfills_criterion_b, but here are a couple of mock functions
# for testing with:

def fulfills_criterion_a(row, column):
    return random.random() < 0.5

fulfills_criterion_b = fulfills_criterion_a

def find_match(p=0.2):
    while True:
        r = random.random()
        i = index(CM, r)
        row, column = index_to_rowcolumn(i, len(M[0]))
        if (fulfills_criterion_a(row, column) or
                fulfills_criterion_b(row, column)):
            return row, column
        else:
            if random.random() <= p:
                return row, column
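The comment in index() suggests the bisect module if the linear search is too slow. Here is a minimal sketch of that idea, assuming Python 3.2 or later for itertools.accumulate (which postdates this post). The names index_bisect, pick_cell and NUM_COLUMNS are hypothetical, chosen only for illustration, and M is assumed to be rectangular:

```python
import bisect
import itertools
import random

M = [[0.2, 0.4, 0.05], [0.1, 0.05, 0.2]]
NUM_COLUMNS = len(M[0])  # assumes every row of M has the same length

# Flatten M row by row and turn it into running totals (cumulative probabilities).
CM = list(itertools.accumulate(p for row in M for p in row))


def index_bisect(data, p):
    """Return the index of the first item in sorted data that is >= p.

    bisect_left gives the leftmost insertion point for p, which is the
    first position whose value is not smaller than p. min() clamps the
    result in case p lies just past data[-1] because of rounding.
    """
    return min(bisect.bisect_left(data, p), len(data) - 1)


def pick_cell(rng=random):
    """Draw one (row, column) pair of M, weighted by its probability."""
    i = index_bisect(CM, rng.random())
    return divmod(i, NUM_COLUMNS)  # linear index -> (row, column)
```

Because CM is non-decreasing, bisect_left returns the same index as the linear scan, but in O(log n) comparisons instead of O(n). For a six-element table the difference is negligible; it only matters for large matrices.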


And here's my test:

>>> find_match()
(1, 2)
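This post predates Python 3.6, which added random.choices with a weights argument. On a modern Python the cumulative table and the index search can be handed to the standard library. A sketch under that assumption; CELLS, WEIGHTS and sample_cells are hypothetical names, and the fulfills_criterion_a / fulfills_criterion_b loop in find_match would still have to wrap each draw:

```python
import random
from collections import Counter

M = [[0.2, 0.4, 0.05], [0.1, 0.05, 0.2]]
NUM_COLUMNS = len(M[0])  # assumes a rectangular matrix

# (row, column) pairs in the same row-major order as the flattened weights.
CELLS = [divmod(i, NUM_COLUMNS) for i in range(len(M) * NUM_COLUMNS)]
WEIGHTS = [p for row in M for p in row]


def sample_cells(k, rng=random):
    """Return k (row, column) pairs drawn with replacement, weighted by M."""
    return rng.choices(CELLS, weights=WEIGHTS, k=k)


if __name__ == "__main__":
    rng = random.Random(12345)  # fixed seed so the run is repeatable
    n = 100_000
    counts = Counter(sample_cells(n, rng))
    for (r, c), count in sorted(counts.items()):
        print(f"cell ({r}, {c}): observed {count / n:.3f}, expected {M[r][c]}")
```

random.choices treats the weights as relative, so they need not sum exactly to 1.0; the eq() tolerance check above is still a useful sanity test on the input data.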


Hope this helps.


-- 
Steven


