[Spambayes-checkins] spambayes Histogram.py,NONE,1.1 TestDriver.py,1.18,1.19

Tim Peters tim_one@users.sourceforge.net
Thu, 03 Oct 2002 19:29:22 -0700


Update of /cvsroot/spambayes/spambayes
In directory usw-pr-cvs1:/tmp/cvs-serv17494

Modified Files:
	TestDriver.py 
Added Files:
	Histogram.py 
Log Message:
Split the histogram class into its own file, greatly robustified the
numerics, and made it a lot more flexible and memory-consuming <sigh>.
This is to help make progress on the central-limit scheme, where we
still have a poor idea of what the zscore distributions look like.
The histogram class is flexible enough to give us nice pictures of
that now.

Note that new min, max, and median statistics are displayed for all
histograms now (and computing percentile cutoffs would be easy to add).

Note that if you have a histogram object, you can now pass the # of
buckets desired to the display() method (no code exploits that yet,
but it means, e.g., that if you discover you really wanted more buckets,
you potentially don't have to rerun the test).  HistToGNU.py in
particular may be able to exploit that immediately.


--- NEW FILE: Histogram.py ---
import math

from Options import options

class Hist:
    """Simple histograms of float values.

    The raw data points are retained in self.data; summary statistics and
    bucket counts are derived from them on demand.  That is why display()
    can be asked for a different number of buckets without the data having
    to be collected again.
    """

    # nbuckets is the default number of buckets display() uses.
    # lo and hi are the lower and upper bounds of the displayed range.
    # Pass None for lo and hi and it will automatically adjust to the min
    # and max values seen.
    # Note:  nbuckets can be passed for backward compatibility.  The
    # display() method can be passed a different nbuckets value.
    def __init__(self, nbuckets=options.nbuckets, lo=0.0, hi=100.0):
        self.lo, self.hi = lo, hi
        self.nbuckets = nbuckets
        self.buckets = [0] * nbuckets
        self.data = []  # the raw data points
        self.stats_uptodate = False

    # Add a value to the collection.
    def add(self, x):
        self.data.append(x)
        # Any previously computed statistics are now stale.
        self.stats_uptodate = False

    # Compute, and set as instance attrs:
    #     n         # of data points
    # The rest are set iff n>0:
    #     min       smallest value in collection
    #     max       largest value in collection
    #     median    midpoint
    #     mean
    #     var       variance
    #     sdev      population standard deviation (sqrt(variance))
    # self.data is also sorted.
    # The work is skipped when nothing was added since the last call.
    def compute_stats(self):
        if self.stats_uptodate:
            return
        # This has to be stored on the instance:  binding a plain local
        # name here would be thrown away on return, and the statistics
        # would then be recomputed (including the sort) on every call.
        self.stats_uptodate = True
        data = self.data
        n = self.n = len(data)
        if n == 0:
            return
        data.sort()
        self.min = data[0]
        self.max = data[-1]
        if n & 1:
            # Odd count:  the single middle element.
            self.median = data[n // 2]
        else:
            # Even count:  average of the two middle elements.
            self.median = (data[n // 2] + data[(n-1) // 2]) / 2.0
        # Compute mean.
        # Add in increasing order of magnitude, to minimize roundoff error.
        # The sorted data is already in that order unless it contains
        # negative values, in which case a copy ordered by absolute value
        # is used (self.data itself stays sorted by value).
        if data[0] < 0.0:
            by_magnitude = [(abs(x), x) for x in data]
            by_magnitude.sort()
            data = [pair[1] for pair in by_magnitude]
            del by_magnitude
        total = 0.0
        for x in data:
            total += x
        mean = self.mean = total / n
        # Compute variance, as the mean squared deviation from the mean.
        sumsq = 0.0
        for x in data:
            d = x - mean
            sumsq += d*d
        self.var = sumsq / n
        self.sdev = math.sqrt(self.var)

    # Merge other into self.
    def __iadd__(self, other):
        self.data.extend(other.data)
        self.stats_uptodate = False
        return self

    # Print a histogram to stdout.
    # nbuckets is the number of buckets to use; None (the default) means
    # use the instance's current nbuckets value.  WIDTH is the maximum
    # number of *s printed for a bucket.
    # Raises ValueError if the number of buckets is not positive.
    # Also sets instance var nbuckets to the # of buckets, and
    # buckets to a list of nbuckets counts, but only if at least one
    # data point is in the collection.
    # Each print below passes a single parenthesized string, which is
    # written the same way by the print statement and by a print function.
    def display(self, nbuckets=None, WIDTH=61):
        # Resolve the default before validating, so that calling display()
        # with no argument checks the instance's own bucket count instead
        # of comparing None against 0.
        if nbuckets is None:
            nbuckets = self.nbuckets
        if nbuckets <= 0:
            raise ValueError("nbuckets %g > 0 required" % nbuckets)
        self.compute_stats()
        n = self.n
        if n == 0:
            return
        print("%d items; mean %.2f; sdev %.2f" % (n, self.mean, self.sdev))
        print("-> <stat> min %g; median %g; max %g" % (self.min,
                                                       self.median,
                                                       self.max))
        self.nbuckets = nbuckets
        self.buckets = buckets = [0] * nbuckets

        lo, hi = self.lo, self.hi
        if lo is None:
            lo = self.min
        if hi is None:
            hi = self.max
        if lo > hi:
            return

        # Compute bucket counts.  Values outside [lo, hi] are counted in
        # the first or last bucket.
        span = float(hi - lo)
        bucketwidth = span / nbuckets
        for x in self.data:
            if bucketwidth:
                i = int((x - lo) / bucketwidth)
            elif x > hi:
                # Degenerate range (lo == hi):  there is no width to
                # divide by, so just apply the same clamping rule.
                i = nbuckets - 1
            else:
                i = 0
            if i >= nbuckets:
                i = nbuckets - 1
            elif i < 0:
                i = 0
            buckets[i] += 1

        # hunit is how many items a * represents.  A * is printed for
        # each hunit items, plus any non-zero fraction thereof.
        biggest = max(buckets)
        hunit, r = divmod(biggest, WIDTH)
        if r:
            hunit += 1
        print("* = %s items" % hunit)

        # We need ndigits decimal digits to display the largest bucket count.
        ndigits = len(str(biggest))

        # Displaying the bucket boundaries is more troublesome.  For now,
        # just print one digit after the decimal point, regardless of what
        # the boundaries look like.
        boundary_digits = max(len(str(int(lo))), len(str(int(hi))))
        fmt = "%" + str(boundary_digits + 2) + '.1f %' + str(ndigits) + "d"

        for i in range(nbuckets):
            count = buckets[i]
            stars = '*' * ((count + hunit - 1) // hunit)
            # Lower boundary of the bucket, its count, then the bar.
            print(fmt % (lo + i * bucketwidth, count) + ' ' + stars)

Index: TestDriver.py
===================================================================
RCS file: /cvsroot/spambayes/spambayes/TestDriver.py,v
retrieving revision 1.18
retrieving revision 1.19
diff -C2 -d -r1.18 -r1.19
*** TestDriver.py	28 Sep 2002 03:44:15 -0000	1.18
--- TestDriver.py	4 Oct 2002 02:29:20 -0000	1.19
***************
*** 29,102 ****
  import Tester
  import classifier
  
! class Hist:
!     """Simple histograms of float values in [0.0, 1.0]."""
! 
!     def __init__(self, nbuckets=20):
!         self.buckets = [0] * nbuckets
!         self.nbuckets = nbuckets
!         self.n = 0          # number of data points
!         self.sum = 0.0      # sum of their values
!         self.sumsq = 0.0    # sum of their squares
! 
!     def add(self, x):
!         n = self.nbuckets
!         i = int(n * x)
!         if i >= n:
!             i = n-1
!         self.buckets[i] += 1
! 
!         self.n += 1
!         x *= 100.0
!         self.sum += x
!         self.sumsq += x*x
! 
!     def __iadd__(self, other):
!         if self.nbuckets != other.nbuckets:
!             raise ValueError('bucket size mismatch')
!         for i in range(self.nbuckets):
!             self.buckets[i] += other.buckets[i]
!         self.n += other.n
!         self.sum += other.sum
!         self.sumsq += other.sumsq
!         return self
! 
!     def display(self, WIDTH=61):
!         from math import sqrt
!         if self.n > 0:
!             mean = self.sum / self.n
!             var = self.sumsq / self.n - mean**2
!             # The vagaries of f.p. rounding can make var come out negative.
!             # There are ways to fix that, but they're too painful for this
!             # part of the code to endure.
!             if var < 0.0:
!                 var = 0.0
!             print "%d items; mean %.2f; sdev %.2f" % (self.n, mean, sqrt(var))
! 
!         biggest = max(self.buckets)
!         hunit, r = divmod(biggest, WIDTH)
!         if r:
!             hunit += 1
!         print "* =", hunit, "items"
! 
!         ndigits = len(str(biggest))
!         format = "%5.1f %" + str(ndigits) + "d"
! 
!         for i in range(len(self.buckets)):
!             n = self.buckets[i]
!             print format % (100.0 * i / self.nbuckets, n),
!             print '*' * ((n + hunit - 1) // hunit)
! 
! def printhist(tag, ham, spam):
      print
      print "-> <stat> Ham scores for", tag,
!     ham.display()
  
      print
      print "-> <stat> Spam scores for", tag,
!     spam.display()
  
      if not options.compute_best_cutoffs_from_histograms:
          return
  
      # Figure out "the best" spam cutoff point, meaning the one that minimizes
--- 29,47 ----
  import Tester
  import classifier
+ from Histogram import Hist
  
! def printhist(tag, ham, spam, nbuckets=options.nbuckets):
      print
      print "-> <stat> Ham scores for", tag,
!     ham.display(nbuckets)
  
      print
      print "-> <stat> Spam scores for", tag,
!     spam.display(nbuckets)
  
      if not options.compute_best_cutoffs_from_histograms:
          return
+     if ham.n == 0 or spam.n == 0:
+         return
  
      # Figure out "the best" spam cutoff point, meaning the one that minimizes
***************
*** 112,116 ****
      best_total = fpw * fp + fn
      bests = [(0, fp, fn)]
!     for i in range(ham.nbuckets):
          # When moving the cutoff beyond bucket i, the ham in bucket i
          # are redeemed, and the spam in bucket i become false negatives.
--- 57,61 ----
      best_total = fpw * fp + fn
      bests = [(0, fp, fn)]
!     for i in range(nbuckets):
          # When moving the cutoff beyond bucket i, the ham in bucket i
          # are redeemed, and the spam in bucket i become false negatives.
***************
*** 127,131 ****
  
      i, fp, fn = bests.pop(0)
!     print '-> best cutoff for', tag, float(i) / ham.nbuckets
      print '->     with weighted total %g*%d fp + %d fn = %g' % (
            fpw, fp, fn, best_total)
--- 72,76 ----
  
      i, fp, fn = bests.pop(0)
!     print '-> best cutoff for', tag, float(i) / nbuckets
      print '->     with weighted total %g*%d fp + %d fn = %g' % (
            fpw, fp, fn, best_total)
***************
*** 155,160 ****
          self.falsepos = Set()
          self.falseneg = Set()
!         self.global_ham_hist = Hist(options.nbuckets)
!         self.global_spam_hist = Hist(options.nbuckets)
          self.ntimes_finishtest_called = 0
          self.new_classifier()
--- 100,105 ----
          self.falsepos = Set()
          self.falseneg = Set()
!         self.global_ham_hist = Hist()
!         self.global_spam_hist = Hist()
          self.ntimes_finishtest_called = 0
          self.new_classifier()
***************
*** 163,168 ****
          c = self.classifier = classifier.Bayes()
          self.tester = Tester.Test(c)
!         self.trained_ham_hist = Hist(options.nbuckets)
!         self.trained_spam_hist = Hist(options.nbuckets)
  
      # CAUTION:  this just doesn't work for incrememental training when
--- 108,113 ----
          c = self.classifier = classifier.Bayes()
          self.tester = Tester.Test(c)
!         self.trained_ham_hist = Hist()
!         self.trained_spam_hist = Hist()
  
      # CAUTION:  this just doesn't work for incrememental training when
***************
*** 192,197 ****
          self.global_ham_hist += self.trained_ham_hist
          self.global_spam_hist += self.trained_spam_hist
!         self.trained_ham_hist = Hist(options.nbuckets)
!         self.trained_spam_hist = Hist(options.nbuckets)
  
          self.ntimes_finishtest_called += 1
--- 137,142 ----
          self.global_ham_hist += self.trained_ham_hist
          self.global_spam_hist += self.trained_spam_hist
!         self.trained_ham_hist = Hist()
!         self.trained_spam_hist = Hist()
  
          self.ntimes_finishtest_called += 1
***************
*** 220,229 ****
          c = self.classifier
          t = self.tester
!         local_ham_hist = Hist(options.nbuckets)
!         local_spam_hist = Hist(options.nbuckets)
  
          def new_ham(msg, prob, lo=options.show_ham_lo,
                                 hi=options.show_ham_hi):
!             local_ham_hist.add(prob)
              if lo <= prob <= hi:
                  print
--- 165,174 ----
          c = self.classifier
          t = self.tester
!         local_ham_hist = Hist()
!         local_spam_hist = Hist()
  
          def new_ham(msg, prob, lo=options.show_ham_lo,
                                 hi=options.show_ham_hi):
!             local_ham_hist.add(prob * 100.0)
              if lo <= prob <= hi:
                  print
***************
*** 234,238 ****
          def new_spam(msg, prob, lo=options.show_spam_lo,
                                  hi=options.show_spam_hi):
!             local_spam_hist.add(prob)
              if lo <= prob <= hi:
                  print
--- 179,183 ----
          def new_spam(msg, prob, lo=options.show_spam_lo,
                                  hi=options.show_spam_hi):
!             local_spam_hist.add(prob * 100.0)
              if lo <= prob <= hi:
                  print