[Python-checkins] python/nondist/sandbox/spambayes Tester.py,1.2,1.3 timtest.py,1.13,1.14

tim_one@users.sourceforge.net tim_one@users.sourceforge.net
Tue, 03 Sep 2002 16:53:39 -0700


Update of /cvsroot/python/python/nondist/sandbox/spambayes
In directory usw-pr-cvs1:/tmp/cvs-serv11194

Modified Files:
	Tester.py timtest.py 
Log Message:
Added support for simple histograms of the probability distributions for
ham and spam.


Index: Tester.py
===================================================================
RCS file: /cvsroot/python/python/nondist/sandbox/spambayes/Tester.py,v
retrieving revision 1.2
retrieving revision 1.3
diff -C2 -d -r1.2 -r1.3
*** Tester.py	28 Aug 2002 21:04:56 -0000	1.2
--- Tester.py	3 Sep 2002 23:53:36 -0000	1.3
***************
*** 53,60 ****
      # in a little class that identifies the example in a useful way, and whose
      # __iter__ produces a token stream for the classifier.
!     def predict(self, stream, is_spam):
          guess = self.classifier.spamprob
          for example in stream:
!             is_spam_guessed = guess(example) > 0.90
              correct = is_spam_guessed == is_spam
              if is_spam:
--- 53,66 ----
      # in a little class that identifies the example in a useful way, and whose
      # __iter__ produces a token stream for the classifier.
!     #
!     # If specified, callback(msg, spam_probability) is called for each
!     # msg in the stream, after the spam probability is computed.
!     def predict(self, stream, is_spam, callback=None):
          guess = self.classifier.spamprob
          for example in stream:
!             prob = guess(example)
!             if callback:
!                 callback(example, prob)
!             is_spam_guessed = prob > 0.90
              correct = is_spam_guessed == is_spam
              if is_spam:

Index: timtest.py
===================================================================
RCS file: /cvsroot/python/python/nondist/sandbox/spambayes/timtest.py,v
retrieving revision 1.13
retrieving revision 1.14
diff -C2 -d -r1.13 -r1.14
*** timtest.py	3 Sep 2002 02:13:46 -0000	1.13
--- timtest.py	3 Sep 2002 23:53:36 -0000	1.14
***************
*** 16,19 ****
--- 16,61 ----
  import classifier
  
+ class Hist:
+     def __init__(self, nbuckets=20):
+         self.buckets = [0] * nbuckets
+         self.nbuckets = nbuckets
+ 
+     def add(self, x):
+         n = self.nbuckets
+         i = int(n * x)
+         if i >= n:
+             i = n-1
+         self.buckets[i] += 1
+ 
+     def __iadd__(self, other):
+         if self.nbuckets != other.nbuckets:
+             raise ValueError('bucket size mismatch')
+         for i in range(self.nbuckets):
+             self.buckets[i] += other.buckets[i]
+         return self
+ 
+     def display(self, WIDTH=60):
+         biggest = max(self.buckets)
+         hunit, r = divmod(biggest, WIDTH)
+         if r:
+             hunit += 1
+         print "* =", hunit, "items"
+ 
+         ndigits = len(str(biggest))
+         format = "%6.2f %" + str(ndigits) + "d"
+ 
+         for i, n in enumerate(self.buckets):
+             print format % (100.0 * i / self.nbuckets, n),
+             print '*' * ((n + hunit - 1) // hunit)
+ 
+ def printhist(tag, ham, spam):
+     print
+     print "Ham distribution for", tag
+     ham.display()
+ 
+     print
+     print "Spam distribution for", tag
+     spam.display()
+ 
  # Find all the text components of the msg.  There's no point decoding
  # binary blobs (like images).  If a multipart/alternative has both plain
***************
*** 543,548 ****
--- 585,593 ----
  
  def drive():
+     nbuckets = 40
      falsepos = Set()
      falseneg = Set()
+     global_ham_hist = Hist(nbuckets)
+     global_spam_hist = Hist(nbuckets)
      for spamdir, hamdir in SPAMHAMDIRS:
          c = classifier.GrahamBayes()
***************
*** 552,566 ****
          print t.nham, "hams &", t.nspam, "spams"
  
!         fp = file('w.pik', 'wb')
!         pickle.dump(c, fp, 1)
!         fp.close()
  
          for sd2, hd2 in SPAMHAMDIRS:
              if (sd2, hd2) == (spamdir, hamdir):
                  continue
              t.reset_test_results()
              print "    testing against", hd2, "&", sd2, "...",
!             t.predict(MsgStream(sd2), True)
!             t.predict(MsgStream(hd2), False)
              print t.nham_tested, "hams &", t.nspam_tested, "spams"
  
--- 597,624 ----
          print t.nham, "hams &", t.nspam, "spams"
  
!         trained_ham_hist = Hist(nbuckets)
!         trained_spam_hist = Hist(nbuckets)
! 
!         #fp = file('w.pik', 'wb')
!         #pickle.dump(c, fp, 1)
!         #fp.close()
  
          for sd2, hd2 in SPAMHAMDIRS:
              if (sd2, hd2) == (spamdir, hamdir):
                  continue
+ 
+             local_ham_hist = Hist(nbuckets)
+             local_spam_hist = Hist(nbuckets)
+ 
+             def new_ham(msg, prob):
+                 local_ham_hist.add(prob)
+ 
+             def new_spam(msg, prob):
+                 local_spam_hist.add(prob)
+ 
              t.reset_test_results()
              print "    testing against", hd2, "&", sd2, "...",
!             t.predict(MsgStream(sd2), True, new_spam)
!             t.predict(MsgStream(hd2), False, new_ham)
              print t.nham_tested, "hams &", t.nspam_tested, "spams"
  
***************
*** 595,599 ****
  
              print
- 
              print "    best discriminators:"
              stats = [(r.killcount, w) for w, r in c.wordinfo.iteritems()]
--- 653,656 ----
***************
*** 603,607 ****
                  r = c.wordinfo[w]
                  print "        %r %d %g" % (w, r.killcount, r.spamprob)
!             print
  
  drive()
--- 660,675 ----
                  r = c.wordinfo[w]
                  print "        %r %d %g" % (w, r.killcount, r.spamprob)
! 
! 
!             printhist("this pair:", local_ham_hist, local_spam_hist)
! 
!             trained_ham_hist += local_ham_hist
!             trained_spam_hist += local_spam_hist
! 
!         printhist("all in this set:", trained_ham_hist, trained_spam_hist)
!         global_ham_hist += trained_ham_hist
!         global_spam_hist += trained_spam_hist
! 
!     printhist("all runs:", global_ham_hist, global_spam_hist)
  
  drive()