[Spambayes-checkins] spambayes mboxtest.py,NONE,1.1 timtest.py,1.7,1.8

Jeremy Hylton jhylton@users.sourceforge.net
Fri, 06 Sep 2002 12:26:36 -0700


Update of /cvsroot/spambayes/spambayes
In directory usw-pr-cvs1:/tmp/cvs-serv16790

Modified Files:
	timtest.py 
Added Files:
	mboxtest.py 
Log Message:
Add a test driver that works with mboxes.

This is similar in spirit to timtest, but it works with any old kind
of mailbox recognized by the Python mailbox module.  

One non-trivial difference from timtest: Rather than requiring that
the user split the mailbox into separate parts, it selects NSETS
different subsets of the mailbox at random to use for testing.  The
subsets are chosen randomly rather than as contiguous runs because my
mailboxes are sorted by date, and I didn't want to bias tests by
choosing training data from a small period of time.

The timtest module has grown a Driver() class that is intended to work
just like the drive() function, but with a bit more flexibility.  The
jdrive() function might be able to replace drive(), but I can't test
it so I'm not going to replace it.  Maybe Tim will try jdrive() and
report if it works correctly.

I didn't find the MsgStream() class useful outside of timtest, but
mailboxes are represented by the mbox class, which is an iterable
collection of Msg objects.  

Renamed the path attribute of Msg to tag, since path doesn't make
sense with an mbox.  The path was getting used as a human-readable tag
for messages, so I synthesized one for mbox messages.


--- NEW FILE: mboxtest.py ---
#! /usr/bin/env python

from timtoken import tokenize
from classifier import GrahamBayes
from Tester import Test
from timtest import Driver, Msg

import getopt
import mailbox
import random
from sets import Set
import sys

# Mailbox reader classes from the standard mailbox module, keyed by the
# format name accepted by the -f command-line option (stored in the global
# FMT by main() and looked up in mbox.__iter__).
mbox_fmts = {"unix": mailbox.PortableUnixMailbox,
             "mmdf": mailbox.MmdfMailbox,
             "mh": mailbox.MHMailbox,
             "qmail": mailbox.Maildir,
             }

class MboxMsg(Msg):
    """A Msg whose text comes from an already-open mailbox message file.

    Msg.__init__ is deliberately not called; this constructor sets the
    guts and tag attributes itself.
    """

    def __init__(self, fp, path, index):
        """Read the whole message from fp and build a readable tag.

        fp    -- file-like object positioned at the message; read to EOF
        path  -- path of the mailbox the message came from
        index -- position of the message within that mailbox

        The tag has the form "<path>:<index> <subject line>".
        """
        text = fp.read()
        self.guts = text
        self.tag = "%s:%s %s" % (path, index, subject(text))

class mbox(object):
    """Iterable collection of the messages in one mailbox.

    Iterating yields one MboxMsg per selected message.  The mailbox is
    re-read from disk on every iteration, using the reader class that
    mbox_fmts associates with the global FMT.
    """

    def __init__(self, path, indices=None):
        """Describe a mailbox, optionally restricted to some messages.

        path    -- path of the mailbox (a file, or a directory for the
                   "mh" and "qmail" formats)
        indices -- iterable of 0-based message positions to include, or
                   None (the default) to include every message.  An empty
                   iterable selects no messages.
        """
        self.path = path
        self.indices = {}
        self.key = ''
        # An empty self.indices cannot tell "no subset requested" apart from
        # "an empty subset requested", so record the distinction explicitly.
        self.select_all = indices is None
        if indices is not None:
            indices = list(indices)
            if indices:
                # The first index serves as a short label in __repr__.
                self.key = " %s" % indices[0]
            for i in indices:
                self.indices[i] = 1

    def __repr__(self):
        """Return "<mbox: PATH>" or, for a subset, "<mbox: PATH FIRSTINDEX>"."""
        return "<mbox: %s%s>" % (self.path, self.key)

    def __iter__(self):
        """Yield an MboxMsg for each selected message, in mailbox order."""
        index = 0

        # The factory reads `index` when the reader calls it from next(),
        # i.e. it sees the 0-based position of the message being produced.
        factory = lambda f: MboxMsg(f, self.path, index)

        if FMT in ("mh", "qmail"):
            # mailbox.MHMailbox and mailbox.Maildir take a directory name,
            # not an open file.
            fp = None
            source = self.path
        else:
            fp = source = open(self.path, "rb")
        box = mbox_fmts[FMT](source, factory)

        while 1:
            msg = box.next()
            if msg is None:
                # Mailbox exhausted: release the file we opened, if any.
                if fp is not None:
                    fp.close()
                return
            if self.select_all or index in self.indices:
                yield msg
            index += 1

def subject(buf):
    """Return the lower-cased 'subject:' line found in buf.

    buf is the full message text.  The result starts at the first
    occurrence of the text 'subject:' (matched case-insensitively, anywhere
    in buf) and runs to the end of that line, without the newline.  An
    empty string is returned when 'subject:' does not occur.
    """
    buf = buf.lower()
    start = buf.find('subject:')
    if start < 0:
        return ''
    end = buf.find("\n", start)
    if end < 0:
        # The subject is on the last line and has no trailing newline;
        # slicing to -1 would chop off its final character.
        end = len(buf)
    return buf[start:end]

def randindices(nelts, nresults):
    """Yield nresults disjoint, equally sized random Sets of indices.

    The indices 0 .. nelts-1 are shuffled with the random module and then
    handed out in consecutive chunks of nelts / nresults elements (integer
    division).  Indices left over after the last chunk are not yielded.
    """
    shuffled = range(nelts)
    random.shuffle(shuffled)
    size = nelts // nresults
    for n in range(nresults):
        first = n * size
        yield Set(shuffled[first:first + size])

def sort(seq):
    """Return a new list holding the items of seq in ascending order.

    seq may be any iterable; it is not modified.
    """
    ordered = []
    ordered.extend(seq)
    ordered.sort()
    return ordered

def main(args):
    global FMT
    
    FMT = "unix"
    NSETS = 5
    SEED = 101
    LIMIT = None
    opts, args = getopt.getopt(args, "f:n:s:l:")
    for k, v in opts:
        if k == '-f':
            FMT = v
        if k == '-n':
            NSETS = int(v)
        if k == '-s':
            SEED = int(v)
        if k == '-l':
            LIMIT = int(v)

    ham, spam = args

    random.seed(SEED)

    nham = len(list(mbox(ham)))
    nspam = len(list(mbox(spam)))

    if LIMIT:
        nham = min(nham, LIMIT)
        nspam = min(nspam, LIMIT)

    print "ham", ham, nham
    print "spam", spam, nspam

    testsets = []
    for iham in randindices(nham, NSETS):
        for ispam in randindices(nspam, NSETS):
            testsets.append((sort(iham), sort(ispam)))
            
    driver = Driver()

    for iham, ispam in testsets:
        driver.train(mbox(ham, iham), mbox(spam, ispam))
        for ihtest, istest in testsets:
            if (iham, ispam) == (ihtest, istest):
                continue
            driver.test(mbox(ham, ihtest), mbox(spam, istest))
        driver.finish()
    driver.alldone()

# Script entry point.  main() has no return statement, so sys.exit()
# receives None and the exit status is 0 unless an exception escapes.
if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))


Index: timtest.py
===================================================================
RCS file: /cvsroot/spambayes/spambayes/timtest.py,v
retrieving revision 1.7
retrieving revision 1.8
diff -C2 -d -r1.7 -r1.8
*** timtest.py	6 Sep 2002 19:12:59 -0000	1.7
--- timtest.py	6 Sep 2002 19:26:34 -0000	1.8
***************
*** 59,63 ****
      def __init__(self, dir, name):
          path = dir + "/" + name
!         self.path = path
          f = open(path, 'rb')
          guts = f.read()
--- 59,63 ----
      def __init__(self, dir, name):
          path = dir + "/" + name
!         self.tag = path
          f = open(path, 'rb')
          guts = f.read()
***************
*** 69,76 ****
  
      def __hash__(self):
!         return hash(self.path)
  
      def __eq__(self, other):
!         return self.path == other.path
  
  class MsgStream(object):
--- 69,76 ----
  
      def __hash__(self):
!         return hash(self.tag)
  
      def __eq__(self, other):
!         return self.tag == other.tag
  
  class MsgStream(object):
***************
*** 86,89 ****
--- 86,198 ----
          return self.produce()
  
+ class Driver:
+ 
+     def __init__(self):
+         self.nbuckets = 40
+         self.falsepos = Set()
+         self.falseneg = Set()
+         self.global_ham_hist = Hist(self.nbuckets)
+         self.global_spam_hist = Hist(self.nbuckets)
+ 
+     def train(self, ham, spam):
+         self.classifier = classifier.GrahamBayes()
+         self.tester = Tester.Test(self.classifier)
+         print "Training on", ham, "&", spam, "..."
+         self.tester.train(ham, spam)
+ 
+         self.trained_ham_hist = Hist(self.nbuckets)
+         self.trained_spam_hist = Hist(self.nbuckets)
+ 
+     def finish(self):
+         printhist("all in this set:",
+                   self.trained_ham_hist, self.trained_spam_hist)
+         self.global_ham_hist += self.trained_ham_hist
+         self.global_spam_hist += self.trained_spam_hist
+ 
+     def alldone(self):
+         printhist("all runs:", self.global_ham_hist, self.global_spam_hist)
+ 
+     def test(self, ham, spam):
+         c = self.classifier
+         t = self.tester
+         local_ham_hist = Hist(self.nbuckets)
+         local_spam_hist = Hist(self.nbuckets)
+ 
+         def new_ham(msg, prob):
+             local_ham_hist.add(prob)
+ 
+         def new_spam(msg, prob):
+             local_spam_hist.add(prob)
+             if prob < 0.1:
+                 print
+                 print "Low prob spam!", prob
+                 print msg.tag
+                 prob, clues = c.spamprob(msg, True)
+                 for clue in clues:
+                     print "prob(%r) = %g" % clue
+                 print
+                 print msg.guts
+ 
+         t.reset_test_results()
+         print "    testing against", ham, "&", spam, "...",
+         t.predict(spam, True, new_spam)
+         t.predict(ham, False, new_ham)
+         print t.nham_tested, "hams &", t.nspam_tested, "spams"
+ 
+         print "    false positive:", t.false_positive_rate()
+         print "    false negative:", t.false_negative_rate()
+ 
+         newfpos = Set(t.false_positives()) - self.falsepos
+         self.falsepos |= newfpos
+         print "    new false positives:", [e.tag for e in newfpos]
+         for e in newfpos:
+             print '*' * 78
+             print e.tag
+             prob, clues = c.spamprob(e, True)
+             print "prob =", prob
+             for clue in clues:
+                 print "prob(%r) = %g" % clue
+             print
+             print e.guts
+ 
+         newfneg = Set(t.false_negatives()) - self.falseneg
+         self.falseneg |= newfneg
+         print "    new false negatives:", [e.tag for e in newfneg]
+         for e in []:#newfneg:
+             print '*' * 78
+             print e.tag
+             prob, clues = c.spamprob(e, True)
+             print "prob =", prob
+             for clue in clues:
+                 print "prob(%r) = %g" % clue
+             print
+             print e.guts[:1000]
+ 
+         print
+         print "    best discriminators:"
+         stats = [(r.killcount, w) for w, r in c.wordinfo.iteritems()]
+         stats.sort()
+         del stats[:-30]
+         for count, w in stats:
+             r = c.wordinfo[w]
+             print "        %r %d %g" % (w, r.killcount, r.spamprob)
+ 
+ 
+         printhist("this pair:", local_ham_hist, local_spam_hist)
+ 
+         self.trained_ham_hist += local_ham_hist
+         self.trained_spam_hist += local_spam_hist
+ 
+ def jdrive():
+     d = Driver()
+ 
+     for spamdir, hamdir in SPAMHAMDIRS:
+         d.train(MsgStream(hamdir), MsgStream(spamdir))
+         for sd2, hd2 in SPAMHAMDIRS:
+             if (sd2, hd2) == (spamdir, hamdir):
+                 continue
+             d.test(MsgStream(hd2), MsgStream(sd2))
+         d.finish()
+     d.alldone()
  
  def drive():
***************
*** 185,187 ****
      printhist("all runs:", global_ham_hist, global_spam_hist)
  
! drive()
--- 294,297 ----
      printhist("all runs:", global_ham_hist, global_spam_hist)
  
! if __name__ == "__main__":
!     drive()