[Spambayes-checkins] spambayes cdb.py,NONE,1.1 neilfilter.py,NONE,1.1 neiltrain.py,NONE,1.1

Neil Schemenauer nascheme@users.sourceforge.net
Mon, 09 Sep 2002 14:21:56 -0700


Update of /cvsroot/spambayes/spambayes
In directory usw-pr-cvs1:/tmp/cvs-serv8534

Added Files:
	cdb.py neilfilter.py neiltrain.py 
Log Message:
Add a pure Python implementation of CDB and two scripts that use it.  It
seems pretty zippy for both reading and creating.


--- NEW FILE: cdb.py ---
"""
Dan Bernstein's CDB implemented in Python

see http://cr.yp.to/cdb.html

"""
import os
import struct
import mmap
import sys

def uint32_unpack(buf):
    """Decode a 4-byte little-endian unsigned integer from buf."""
    (value,) = struct.unpack('<L', buf)
    return value

def uint32_pack(n):
    """Encode n as a 4-byte little-endian unsigned integer."""
    packed = struct.pack('<L', n)
    return packed
    
CDB_HASHSTART = 5381

def cdb_hash(buf):
    """Return djb's 32-bit CDB hash of buf.

    Starting from CDB_HASHSTART, each item c updates the hash as
    h = ((h * 33) mod 2**32) ^ c.

    buf may be a str (each character is converted with ord()) or a
    sequence whose items are already ints, such as a bytearray (or bytes
    under Python 3).
    """
    h = CDB_HASHSTART
    for c in buf:
        if not isinstance(c, int):
            c = ord(c)
        # h + (h << 5) == h * 33; mask to keep the value in 32 bits.
        h = (h + (h << 5)) & 0xffffffff
        h ^= c
    return h

class Cdb(object):
    """Read-only reader for a CDB file.

    The whole file is memory-mapped with mmap.ACCESS_READ.  close() releases
    the map; the file object passed to the constructor is not closed here.

    Lookups follow djb's findstart()/findnext() protocol: call findstart(),
    then findnext(key) repeatedly to obtain successive records stored under
    the same key.  __getitem__ and get() return the first matching value.
    """

    def __init__(self, fp):
        fd = fp.fileno()
        self.size = os.fstat(fd).st_size
        self.map = mmap.mmap(fd, self.size, access=mmap.ACCESS_READ)
        self.findstart()
        self.loop = 0 # number of hash slots searched under this key
        # meaningful only while loop is nonzero
        self.khash = 0  # hash of the key being searched for
        self.hpos = 0   # file offset of that key's hash table
        self.hslots = 0 # number of slots in that table
        self.kpos = 0   # file offset of the next slot to probe
        # offset/length of the value found by the last successful findnext()
        self.dpos = 0
        self.dlen = 0

    def close(self):
        """Release the memory map."""
        self.map.close()

    def findstart(self):
        """Reset the search state so the next findnext() starts afresh."""
        self.loop = 0

    def read(self, n, pos):
        """Return up to n bytes of the file starting at offset pos."""
        # XXX add code for platforms without mmap
        return self.map[pos:pos+n]

    def match(self, key, pos):
        """Return True if the file holds exactly key at offset pos."""
        return self.read(len(key), pos) == key

    def findnext(self, key):
        """Return the next value stored under key.

        Continues the probe sequence begun after the last findstart(), so
        repeated calls return successive records with the same key.

        Raises KeyError(key) when no further record matches.
        """
        if not self.loop:
            u = cdb_hash(key)
            # Header entry number (u & 255) holds (table offset, slot count).
            buf = self.read(8, u << 3 & 2047)
            self.hslots = uint32_unpack(buf[4:])
            if not self.hslots:
                raise KeyError(key)
            self.hpos = uint32_unpack(buf[:4])
            self.khash = u
            # First slot to probe is (hash >> 8) % hslots, as in cdb_make().
            u >>= 8
            u %= self.hslots
            u <<= 3
            self.kpos = self.hpos + u

        while self.loop < self.hslots:
            buf = self.read(8, self.kpos)
            pos = uint32_unpack(buf[4:])
            if not pos:
                # An empty slot ends the probe chain: no (more) matches.
                raise KeyError(key)
            self.loop += 1
            self.kpos += 8
            if self.kpos == self.hpos + (self.hslots << 3):
                # Wrap around to the start of the table.
                self.kpos = self.hpos
            u = uint32_unpack(buf[:4])
            if u == self.khash:
                # Record layout: key length, value length, key, value.
                buf = self.read(8, pos)
                u = uint32_unpack(buf[:4])
                if u == len(key) and self.match(key, pos + 8):
                    self.dlen = uint32_unpack(buf[4:])
                    self.dpos = pos + 8 + len(key)
                    return self.read(self.dlen, self.dpos)
        raise KeyError(key)

    def __getitem__(self, key):
        """Return the first value stored under key; raise KeyError if none."""
        self.findstart()
        return self.findnext(key)

    def get(self, key, default=None):
        """Return the first value stored under key, or default if none."""
        self.findstart()
        try:
            return self.findnext(key)
        except KeyError:
            return default

def cdb_make(outfile, items):
    """Write a CDB file to outfile from an iterable of (key, value) pairs.

    outfile must be seekable and opened for binary writing; it is flushed
    but not closed.  The file layout is a 2048-byte header of 256
    (table offset, slot count) pairs, then the records (key length, value
    length, key, value), then the 256 hash tables.
    """
    pos = 2048
    tables = {} # { h & 255 : [(h, p)] }

    # write keys and data, leaving room for the header
    outfile.seek(pos)
    for key, value in items:
        outfile.write(uint32_pack(len(key)) + uint32_pack(len(value)))
        h = cdb_hash(key)
        outfile.write(key)
        outfile.write(value)
        tables.setdefault(h & 255, []).append((h, pos))
        pos += 8 + len(key) + len(value)

    # Collect the 256 header entries in a list instead of concatenating onto
    # a '' literal, so this works whether uint32_pack() returns str or bytes.
    header = []
    null = (0, 0)  # empty-slot marker, tested by identity below
    # write hash tables
    for i in range(256):
        entries = tables.get(i, [])
        # Tables are kept half full so probe chains stay short.
        nslots = 2 * len(entries)
        header.append(uint32_pack(pos) + uint32_pack(nslots))
        table = [null] * nslots
        for h, p in entries:
            # Linear probing from (h >> 8) % nslots, matching Cdb.findnext().
            n = (h >> 8) % nslots
            while table[n] is not null:
                n = (n + 1) % nslots
            table[n] = (h, p)
        for h, p in table:
            outfile.write(uint32_pack(h) + uint32_pack(p))
        pos += 8 * nslots

    # write header (pointers to tables and their lengths)
    outfile.flush()
    outfile.seek(0)
    outfile.writelines(header)
    

def test():
    #db = Cdb(open("t"))
    #print db['one']
    #print db['two']
    #print db['foo']
    #print db['us']
    #print db.get('ec')
    #print db.get('notthere')
    db = open('test.cdb', 'wb')
    cdb_make(db,
             [('one', 'Hello'),
              ('two', 'Goodbye'),
              ('foo', 'Bar'),
              ('us', 'United States'),
              ])
    db.close()
    db = Cdb(open("test.cdb", 'rb'))
    print db['one']
    print db['two']
    print db['foo']
    print db['us']
    print db.get('ec')
    print db.get('notthere')

if __name__ == "__main__":
    # Running cdb.py directly performs the self-test.
    test()

--- NEW FILE: neilfilter.py ---
#! /usr/bin/env python

"""Usage: %(program)s wordprobs.cdb
"""

import sys
import os
import email
from heapq import heapreplace
from sets import Set
from classifier import MIN_SPAMPROB, MAX_SPAMPROB, UNKNOWN_SPAMPROB, \
    MAX_DISCRIMINATORS
import cdb

program = sys.argv[0] # script name; usage() fills %(program)s in the docstring from globals()

from tokenizer import tokenize

def spamprob(wordprobs, wordstream, evidence=False):
    """Return best-guess probability that wordstream is spam.

    wordprobs is a CDB of word probabilities: wordprobs.get(word, default)
    must return either a float or a string that float() can parse.

    wordstream is an iterable object producing words; each distinct word
    is counted only once.
    The return value is a float in [0.0, 1.0].

    If optional arg evidence is True, the return value is a pair
        probability, evidence
    where evidence is a list of (word, probability) pairs sorted by
    ascending probability.
    """

    # A priority queue to remember the MAX_DISCRIMINATORS best
    # probabilities, where "best" means largest distance from 0.5.
    # The tuples are (distance, prob, word).
    nbest = [(-1.0, None, None)] * MAX_DISCRIMINATORS
    smallest_best = -1.0

    mins = []   # all words w/ prob MIN_SPAMPROB
    maxs = []   # all words w/ prob MAX_SPAMPROB
    # Counting a unique word multiple times hurts, although counting one
    # at most two times had some benefit when UNKNOWN_SPAMPROB was 0.2.
    # When that got boosted to 0.5, counting more than once became
    # counterproductive.
    for word in Set(wordstream):
        prob = float(wordprobs.get(word, UNKNOWN_SPAMPROB))
        distance = abs(prob - 0.5)
        if prob == MIN_SPAMPROB:
            mins.append((distance, prob, word))
        elif prob == MAX_SPAMPROB:
            maxs.append((distance, prob, word))
        elif distance > smallest_best:
            # Strict ">" (rather than ">=") means a word displaces the
            # weakest kept clue only when it is strictly stronger, so among
            # equally strong clues the one seen first is kept.  The intent
            # was to favor indicators early in the message (headers), but
            # note that Set() iteration order is arbitrary, not message
            # order, so "seen first" is not "earliest in the message" here.
            heapreplace(nbest, (distance, prob, word))
            smallest_best = nbest[0][0]

    # Compute the probability.  Note:  This is what Graham's code did,
    # but it's dubious for reasons explained in great detail on Python-
    # Dev:  it's missing P(spam) and P(not-spam) adjustments that
    # straightforward Bayesian analysis says should be here.  It's
    # unclear how much it matters, though, as the omissions here seem
    # to tend in part to cancel out distortions introduced earlier by
    # HAMBIAS.  Experiments will decide the issue.
    clues = []

    # First cancel out competing extreme clues (see comment block at
    # MAX_DISCRIMINATORS declaration in classifier -- this is a twist on
    # Graham).
    if mins or maxs:
        if len(mins) < len(maxs):
            shorter, longer = mins, maxs
        else:
            shorter, longer = maxs, mins
        tokeep = min(len(longer) - len(shorter), MAX_DISCRIMINATORS)
        # They're all good clues, but we're only going to feed the tokeep
        # initial clues from the longer list into the probability
        # computation.  The rest are reported only as evidence.
        if evidence:
            clues.extend([(w, p) for d, p, w in shorter + longer[tokeep:]])
        for x in longer[:tokeep]:
            heapreplace(nbest, x)

    prob_product = inverse_prob_product = 1.0
    for distance, prob, word in nbest:
        if prob is None:    # it's one of the dummies nbest started with
            continue
        if evidence:
            clues.append((word, prob))
        prob_product *= prob
        inverse_prob_product *= 1.0 - prob

    prob = prob_product / (prob_product + inverse_prob_product)
    if evidence:
        # Stable sort by probability (same order the old cmp-based sort gave).
        clues.sort(key=lambda clue: clue[1])
        return prob, clues
    return prob

def formatclues(clues, sep="; "):
    """Render (word, prob) pairs as one string, joined with sep."""
    pieces = []
    for word, prob in clues:
        pieces.append("%r: %.2f" % (word, prob))
    return sep.join(pieces)

def is_spam(wordprobs, input, threshold=0.9):
    """Filter (judge) a message.

    Parses the message read from the file object input with
    email.message_from_file(), scores its tokens with spamprob() against
    wordprobs, and returns True when the score is at least threshold
    (default 0.9), False otherwise.
    """
    msg = email.message_from_file(input)
    # Only the score is needed; asking for evidence would build and sort
    # a clue list that is then thrown away.
    prob = spamprob(wordprobs, tokenize(msg))
    return prob >= threshold

def usage(code, msg=''):
    """Write msg (if any) and the usage text to stderr, then exit with code."""
    err = sys.stderr
    if msg:
        # The message, then a blank line separating it from the usage text.
        err.write(str(msg) + '\n')
        err.write('\n')
    err.write((__doc__ % globals()) + '\n')
    raise SystemExit(code)

def main():
    """Judge the message on stdin; exit status 1 means spam, 0 means ham."""
    if len(sys.argv) != 2:
        usage(2)

    db_path = sys.argv[1]
    wordprobs = cdb.Cdb(open(db_path, 'rb'))
    status = 0
    if is_spam(wordprobs, sys.stdin):
        status = 1
    sys.exit(status)

if __name__ == '__main__':
    # Script entry point.
    main()

--- NEW FILE: neiltrain.py ---
#! /usr/bin/env python

"""Usage: %(program)s spam.mbox ham.mbox wordprobs.cdb
"""

import sys
import os
import mailbox
import email
import classifier
import cdb

program = sys.argv[0] # script name; usage() fills %(program)s in the docstring from globals()

from tokenizer import tokenize

def _txt_dir_messages(dirname, factory):
    """Yield factory(fp) for each *.txt file in dirname, in sorted name order.

    Files that cannot be opened are skipped; each file is closed as soon as
    factory() has returned.
    """
    names = [name for name in os.listdir(dirname) if name.endswith('.txt')]
    names.sort()
    for name in names:
        try:
            fp = open(os.path.join(dirname, name))
        except IOError:
            continue
        try:
            msg = factory(fp)
        finally:
            fp.close()
        yield msg

def getmbox(msgs):
    """Return an iterable of messages read from msgs.

    msgs may be:
      "+folder"       an MH folder under the user's MH path;
      a directory     an MH folder if the path contains "/Mail/", otherwise
                      a directory of *.txt files holding one message each;
      anything else   a Unix mbox file.

    Messages the email package fails to parse are yielded as ''.
    """
    def _factory(fp):
        try:
            return email.message_from_file(fp)
        except email.Errors.MessageParseError:
            return ''

    if msgs.startswith("+"):
        import mhlib
        mh = mhlib.MH()
        mbox = mailbox.MHMailbox(os.path.join(mh.getpath(), msgs[1:]),
                                 _factory)
    elif os.path.isdir(msgs):
        # XXX Bogus: use an MHMailbox if the pathname contains /Mail/,
        # else treat it as a directory of .txt files.
        if "/Mail/" in msgs:
            mbox = mailbox.MHMailbox(msgs, _factory)
        else:
            # No DirOfTxtFileMailbox is defined or imported in this script
            # (using it raised NameError), so read the files directly.
            mbox = _txt_dir_messages(msgs, _factory)
    else:
        fp = open(msgs)
        mbox = mailbox.PortableUnixMailbox(fp, _factory)
    return mbox

def train(bayes, msgs, is_spam):
    """Feed every message of the mailbox named msgs to bayes.learn()."""
    for message in getmbox(msgs):
        tokens = tokenize(message)
        bayes.learn(tokens, is_spam, False)

def usage(code, msg=''):
    """Write msg (if any) and the usage text to stderr, then exit with code."""
    err = sys.stderr
    if msg:
        # The message, then a blank line separating it from the usage text.
        err.write(str(msg) + '\n')
        err.write('\n')
    err.write((__doc__ % globals()) + '\n')
    raise SystemExit(code)

def main():
    """Main program; parse options and go."""
    if len(sys.argv) != 4:
        usage(2)

    spam_name = sys.argv[1]
    ham_name = sys.argv[2]
    db_name = sys.argv[3]
    bayes = classifier.GrahamBayes()
    print 'Training with spam...'
    train(bayes, spam_name, True)
    print 'Training with ham...'
    train(bayes, ham_name, False)
    print 'Updating probabilities...'
    bayes.update_probabilities()
    items = []
    for word, winfo in bayes.wordinfo.iteritems():
        #print `word`, str(winfo.spamprob)
        items.append((word, str(winfo.spamprob)))
    print 'Writing DB...'
    db = open(db_name, "wb")
    cdb.cdb_make(db, items)
    db.close()
    print 'done'

if __name__ == '__main__':
    # Script entry point.
    main()