[Python-checkins] python/nondist/sandbox/statistics statistics.py, NONE, 1.1 test_statistics.py, NONE, 1.1

rhettinger at users.sourceforge.net rhettinger at users.sourceforge.net
Sun Feb 15 01:03:59 EST 2004


Update of /cvsroot/python/python/nondist/sandbox/statistics
In directory sc8-pr-cvs1.sourceforge.net:/tmp/cvs-serv20081

Added Files:
	statistics.py test_statistics.py 
Log Message:
Load a draft version of a statistics module.

--- NEW FILE: statistics.py ---
"""statistics.py  A collection of functions for summarizing data

The module focuses on everyday data reduction functions and omits more
sophisticated statistical tools.

Unless otherwise noted, each function is designed to accept any iterable
and make only a single pass over the data.  This makes the tools suitable
for use with generator expressions and keeps them as memory friendly as
possible.  Except for median(), the use of iterators means that the
dataset can be processed one element at a time without bringing it all
into memory.

The functions, nlargest and nsmallest, are also designed to use the fewest
possible data comparisons and the data elements may be non-numerical (such
as strings) as long as they support __cmp__ or __lt__.

The code is designed to run at its best on Py2.4; however,
it is being kept fully compatible with Py2.3.

Written and maintained by Raymond D. Hettinger <python at rcn.com>.
Copyright (c) 2004 Python Software Foundation.  All rights reserved.

"""

from __future__ import division
import bisect
import heapq
import itertools
import operator
import random

# Public API exported by "from statistics import *".
__all__ = ['mean', 'stddev', 'product', 'nlargest', 'nsmallest', 'select',
           'median']

def mean(data):
    """Compute the arithmetic mean of a data sample.

    data may be a sequence or any iterable (including a generator); in the
    latter case it is consumed in a single pass.

    Raises ValueError if data is empty.
    """
    try:
        cnt = len(data)                 # do it the fast way if possible
    except TypeError:
        # No len(): accumulate sum and count together in one pass.  Both
        # start at zero so an empty iterator cleanly yields cnt == 0.
        cum = cnt = 0
        for elem in data:
            cum += elem
            cnt += 1
    else:
        cum = sum(data)
    if cnt == 0:
        raise ValueError('data must have at least one element')
    return cum / cnt                    # relies on true division


def stddev(data, sample=True):
    """Computes the standard deviation of the dataset.

    If sample is True, computes the standard deviation with an n-1 divisor
    for an unbiased estimator for a data sample.  If sample is False, computes
    with a divisor of n, giving the standard deviation for a complete population.

    data may be any iterable; it is consumed in a single pass.

    Raises ValueError if data is empty, or if sample is True and data has
    fewer than two elements.
    """

    # Formula for recurrence taken from Seminumerical Algorithms, Knuth, 4.2.2
    # It has substantially better numerical performance than computing the
    # result from sum(data) and sum(x*x for x in data).
    it = iter(data)
    # Fetch the first element with for/break rather than it.next(): the
    # .next() method only exists on Python 2, this spelling works everywhere.
    for m in it:                        # m is the running mean
        break
    else:
        raise ValueError('data must have at least one element')
    s = 0                               # running sum((x-mean)**2 for x in data)
    k = 1                               # number of items
    dm = 0                              # carried forward error term
    for x in it:
        k += 1

        # This block computes:  newm = m + (x-m)/k  (the updated running mean)
        # The use of a cumulative error term improves accuracy when the mean is
        # much larger than the standard deviation.  Also, it makes the formula
        # less sensitive to data order (with sorted data resulting in larger
        # relative errors according to experiments by Chris Reedy).
        adjm = (x-m)/k - dm             # relies on true division
        newm = m + adjm
        dm = (newm - m) - adjm          # rounding error lost in m + adjm

        s += (x-m)*(x-newm)
        m = newm
    if sample:
        if k < 2:
            raise ValueError('sample deviation requires at least two elements')
        return (s / (k-1)) ** 0.5       # sample standard deviation
    return (s / k) ** 0.5               # population standard deviation

def partition(iterable):
    """Split data around a pivot element (support routine for select).

    Returns (under, pivot, notunder): every element of under compares less
    than pivot, and no element of notunder does (elements equal to the
    pivot, other than the pivot itself, land in notunder).  The input is
    copied, never reordered.

    The pivot is chosen uniformly at random so that select() runs in
    expected linear time even on already-sorted input; always taking the
    first element makes select quadratic on sorted data.

    Raises ValueError if the iterable is empty.
    """
    items = list(iterable)              # private copy; caller's data untouched
    if not items:
        raise ValueError('data must have at least one element')
    i = random.randrange(len(items))
    items[0], items[i] = items[i], items[0]     # move the pivot to the front
    pivot = items[0]
    under, notunder = [], []
    ua, nua = under.append, notunder.append     # bound methods hoisted out of the loop
    for elem in itertools.islice(items, 1, None):
        if elem < pivot:
            ua(elem)
        else:
            nua(elem)
    return under, pivot, notunder

def select(data, n):
    """Find the nth rank ordered element.

    Returns the element that would be at index n (0-based) if data were
    sorted, found by quickselect without fully sorting.  data may be any
    non-empty iterable and is not modified.

    Raises IndexError if n is negative or not less than the number of
    elements.
    """
    if n < 0:
        raise IndexError('select index out of range')
    while True:
        under, pivot, notunder = partition(data)
        if n < len(under):
            data = under
        elif n == len(under):
            return pivot
        else:
            n -= len(under) + 1
            # Past the end of the data: stop here instead of partitioning
            # an empty list.
            if n >= len(notunder):
                raise IndexError('select index out of range')
            data = notunder
        under = notunder = None         # drop references to the discarded half

def median(data):
    """Return the median of the data.

    For an odd number of elements this is the middle element in rank order;
    for an even number it is the mean of the two middle elements (computed
    with true division).  Input without len() is first copied into a list,
    so median() keeps the whole dataset in memory.

    Raises ValueError if the data is empty.
    """
    try:
        size = len(data)
    except TypeError:                   # iterators/generators have no len()
        data = list(data)
        size = len(data)
    if not size:
        raise ValueError('data must have at least one element')
    mid = size // 2
    if size % 2:                        # odd length: a single middle element
        return select(data, mid)
    upper_plus_lower = select(data, mid) + select(data, mid - 1)
    return upper_plus_lower / 2

def product(iterable):
    """Compute the product of the data elements.

    Returns 1 (the multiplicative identity) for an empty iterable.
    """
    # A plain loop instead of reduce(operator.mul, iterable, 1): the builtin
    # reduce() is gone in Python 3 and functools does not exist before
    # Py2.5, so the loop is the spelling that works on every version.
    result = 1
    for elem in iterable:
        result = result * elem
    return result

def nlargest(iterable, n=1):            # XXX key= or cmp= argument
    """Compute the n largest elements in the dataset, largest first.

    Makes a single pass, keeping a min-heap of the n largest elements seen
    so far, so extra memory is proportional to n.  If the dataset has fewer
    than n elements, all of them are returned.

    Raises ValueError if the dataset is empty or n < 1.
    """
    ## When collections.fibheap is available, use it instead of heapq
    it = iter(iterable)
    result = list(itertools.islice(it, n))
    if not result:
        raise ValueError('data must have at least one element')
    heapq.heapify(result)               # result[0] is the smallest kept element
    subst = heapq.heapreplace
    for elem in it:
        if result[0] < elem:            # only __lt__ is required of elements
            subst(result, elem)
    # sort() + reverse() instead of sort(reverse=True): the reverse keyword
    # is new in Py2.4 and this module is kept compatible with Py2.3.
    result.sort()
    result.reverse()
    return result

def nsmallest(iterable, n=1):
    """Compute the n smallest elements in the dataset, smallest first.

    Makes a single pass, keeping a sorted list of the n smallest elements
    seen so far.  If the dataset has fewer than n elements, all of them are
    returned.

    Raises ValueError if the dataset is empty or n < 1.
    """
    ## When collections.fibheap is available, use it instead of bisect
    it = iter(iterable)
    result = list(itertools.islice(it, n))
    if not result:
        raise ValueError('data must have at least one element')
    result.sort()
    insort = bisect.insort
    pop = result.pop
    for elem in it:
        if elem < result[-1]:           # beats the largest element kept so far
            insort(result, elem)
            pop()                       # evict the old largest to keep n items
    return result

# NOTE(review): placeholder list of ideas for future functions.  It is bound
# as a module attribute but is not exported via __all__.
XXX = """
    Other possible functions include data groupers for
    binning, counting, and splitting into equivalence classes.
"""

--- NEW FILE: test_statistics.py ---
from __future__ import division
from statistics import mean, stddev, product, nlargest, nsmallest, select, median


### UNITTESTS #################################################

import unittest
import random

def g(n):
    """Yield 0, 1, ..., n-1 from a generator, which has no __len__.

    Used to exercise the iterator code paths of the statistics functions.
    Counts by hand rather than wrapping xrange, which exists only on
    Python 2; this works on every Python version.
    """
    i = 0
    while i < n:
        yield i
        i += 1

class TestStats(unittest.TestCase):
    """Unit tests for the functions in the statistics module.

    Lists are built with list(range(...)) so the tests behave the same
    whether range returns a list (Python 2) or a lazy range (Python 3).
    """

    def test_mean(self):
        self.assertEqual(mean(range(6)), 15/6.0)
        self.assertEqual(mean(g(6)), 15/6.0)
        self.assertEqual(mean([10]), 10)
        self.assertRaises(ValueError, mean, [])
        self.assertRaises(TypeError, mean, 'abc')

    def test_stddev(self):
        self.assertEqual(stddev([10,15,20]), 5)
        self.assertEqual(round(stddev([11.1,4,9,13]), 3), 3.878)
        self.assertEqual(round(stddev([11.1,4,9,13], False), 3), 3.358)
        self.assertEqual(stddev([10], False), 0.0)
        self.assertRaises(ValueError, stddev, [1])
        self.assertRaises(ValueError, stddev, [], False)

    def test_median(self):
        data = list(range(10))
        random.shuffle(data)
        copy = data[:]
        self.assertEqual(median(data), 4.5)
        self.assertEqual(data, copy)    # median must not reorder its input
        self.assertEqual(median(g(10)), 4.5)
        data.insert(1, 10)
        self.assertEqual(median(data), 5)
        self.assertEqual(median([-50,10,2,11]), 6)
        self.assertEqual(median([30,10,2]), 10)
        self.assertRaises(ValueError, median, [])

    def test_product(self):
        self.assertEqual(product(range(1, 5)), 24)
        def g(lo, hi):
            for i in range(lo, hi):
                yield i
        self.assertEqual(product(g(1, 5)), 24)

    def test_select(self):
        n = 200
        a = list(range(n))
        random.shuffle(a)
        for i in range(100):
            nth = random.randrange(n)
            self.assertEqual(select(a, nth), nth)

    def test_nlargest(self):
        n = 10
        data = list(range(n))
        random.shuffle(data)
        copy = data[:]
        self.assertEqual(nlargest(data), [n-1])
        self.assertRaises(ValueError, nlargest, data, 0)
        for i in range(1, n+1):
            self.assertEqual(nlargest(data, i), list(range(n-1, n-i-1, -1)))
        self.assertEqual(data, copy)
        self.assertEqual(nlargest('abcde', 3), list('edc'))

    def test_nsmallest(self):
        n = 10
        data = list(range(n))
        random.shuffle(data)
        copy = data[:]
        self.assertEqual(nsmallest(data), [0])
        self.assertRaises(ValueError, nsmallest, data, 0)
        for i in range(1, n+1):
            self.assertEqual(nsmallest(data, i), list(range(i)))
        self.assertEqual(data, copy)
        self.assertEqual(nsmallest('abcde', 3), list('abc'))

if __name__ == '__main__':
    # Build the suite with TestLoader.loadTestsFromTestCase: the older
    # unittest.makeSuite helper is deprecated since Python 3.11 and removed
    # in 3.13, while the loader method exists on every supported version.
    suite = unittest.TestSuite()
    suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestStats))
    unittest.TextTestRunner(verbosity=2).run(suite)




More information about the Python-checkins mailing list