[Spambayes-checkins] spambayes/spambayes/test test_stats.py, 1.1, 1.2

Tony Meyer anadelonbrin at users.sourceforge.net
Wed Dec 22 01:22:29 CET 2004


Update of /cvsroot/spambayes/spambayes/spambayes/test
In directory sc8-pr-cvs1.sourceforge.net:/tmp/cvs-serv31740/spambayes/test

Modified Files:
	test_stats.py 
Log Message:
Update tests to reflect the new Stats(options, messageinfo_db) constructor signature.

Add checks for the cost calculations.

Check that retrieving stats from a specified start date works correctly.

Index: test_stats.py
===================================================================
RCS file: /cvsroot/spambayes/spambayes/spambayes/test/test_stats.py,v
retrieving revision 1.1
retrieving revision 1.2
diff -C2 -d -r1.1 -r1.2
*** test_stats.py	21 Dec 2004 21:30:13 -0000	1.1
--- test_stats.py	22 Dec 2004 00:22:26 -0000	1.2
***************
*** 15,27 ****
  class StatsTest(unittest.TestCase):
      def setUp(self):
-         self.s_cut = options["Categorization", "spam_cutoff"]
-         self.h_cut = options["Categorization", "ham_cutoff"]
-         self.h_string = options["Headers", "header_ham_string"]
-         self.u_string = options["Headers", "header_unsure_string"]
-         self.s_string = options["Headers", "header_spam_string"]
          self.messageinfo_db_name = "__unittest.pik"
          self.messageinfo_db = MessageInfoPickle(self.messageinfo_db_name)
!         self.s = Stats(self.s_cut, self.h_cut, self.messageinfo_db,
!                        self.h_string, self.u_string, self.s_string)
  
      def tearDown(self):
--- 15,21 ----
  class StatsTest(unittest.TestCase):
      def setUp(self):
          self.messageinfo_db_name = "__unittest.pik"
          self.messageinfo_db = MessageInfoPickle(self.messageinfo_db_name)
!         self.s = Stats(options, self.messageinfo_db)
  
      def tearDown(self):
***************
*** 43,48 ****
          self.messageinfo_db.close()
          self.messageinfo_db = MessageInfoPickle(self.messageinfo_db_name)
!         self.s = Stats(self.s_cut, self.h_cut, self.messageinfo_db,
!                        self.h_string, self.u_string, self.s_string)
          self.assertEqual(now, self.s.from_date)
  
--- 37,41 ----
          self.messageinfo_db.close()
          self.messageinfo_db = MessageInfoPickle(self.messageinfo_db_name)
!         self.s = Stats(options, self.messageinfo_db)
          self.assertEqual(now, self.s.from_date)
  
***************
*** 235,238 ****
--- 228,241 ----
                            data["num_unsure_trained_spam"]) /
                           data["total_spam"])
+         self.assertEqual(new_data["total_cost"],
+                          data["num_trained_ham_fp"] *
+                          options["TestDriver", "best_cutoff_fp_weight"] + \
+                          data["num_trained_spam_fn"] *
+                          options["TestDriver", "best_cutoff_fn_weight"] + \
+                          data["num_unsure"] *
+                          options["TestDriver", "best_cutoff_unsure_weight"])
+         self.assertEqual(new_data["cost_savings"], data["num_spam"] *
+                          options["TestDriver", "best_cutoff_fn_weight"] -
+                          data["total_cost"])
  
      def test_AddPercentStrings(self):
***************
*** 293,297 ****
          self.assertEqual(s[9], "Manually classified as spam:\t0")
          self.assertEqual(s[10], "")
!         if self.h_cut <= score < self.s_cut:
              self.assertEqual(s[11], "Unsures trained as good:\t0 (0.0% of unsures)")
              self.assertEqual(s[12], "Unsures trained as spam:\t0 (0.0% of unsures)")
--- 296,301 ----
          self.assertEqual(s[9], "Manually classified as spam:\t0")
          self.assertEqual(s[10], "")
!         if options["Categorization", "ham_cutoff"] <= score < \
!            options["Categorization", "spam_cutoff"]:
              self.assertEqual(s[11], "Unsures trained as good:\t0 (0.0% of unsures)")
              self.assertEqual(s[12], "Unsures trained as spam:\t0 (0.0% of unsures)")
***************
*** 372,376 ****
          self.assertEqual(s[17], "Spam correctly identified:\t33.3% (+ 33.3% unsure)")
          self.assertEqual(s[18], "Good incorrectly identified:\t33.3% (+ 33.3% unsure)")
!         self.assertEqual(len(s), 19)
  
      def test_get_all_stats(self):
--- 376,383 ----
          self.assertEqual(s[17], "Spam correctly identified:\t33.3% (+ 33.3% unsure)")
          self.assertEqual(s[18], "Good incorrectly identified:\t33.3% (+ 33.3% unsure)")
!         self.assertEqual(s[19], "")
!         self.assertEqual(s[20], "Total cost of spam:\t$11.60")
!         self.assertEqual(s[21], "SpamBayes savings:\t$-9.60")
!         self.assertEqual(len(s), 22)
  
      def test_get_all_stats(self):
***************
*** 395,401 ****
          self.assertEqual(s[17], "Spam correctly identified:\t40.0% (+ 20.0% unsure)")
          self.assertEqual(s[18], "Good incorrectly identified:\t33.3% (+ 16.7% unsure)")
!         self.assertEqual(len(s), 19)
  
      def _stuff_with_data(self, use_html=False):
          # Record some session data.
          self.s.RecordClassification(0.0)
--- 402,417 ----
          self.assertEqual(s[17], "Spam correctly identified:\t40.0% (+ 20.0% unsure)")
          self.assertEqual(s[18], "Good incorrectly identified:\t33.3% (+ 16.7% unsure)")
!         self.assertEqual(s[19], "")
!         self.assertEqual(s[20], "Total cost of spam:\t$23.40")
!         self.assertEqual(s[21], "SpamBayes savings:\t$-19.40")
!         self.assertEqual(len(s), 22)
  
      def _stuff_with_data(self, use_html=False):
+         self._stuff_with_session_data()
+         self._stuff_with_persistent_data()
+         self.s.CalculatePersistentStats()
+         return self.s.GetStats(use_html=use_html)
+ 
+     def _stuff_with_session_data(self):
          # Record some session data.
          self.s.RecordClassification(0.0)
***************
*** 411,414 ****
--- 427,431 ----
          self.s.RecordTraining(False, 1.0)
  
+     def _stuff_with_persistent_data(self):
          # Put data into the totals.
          msg = Message('0', self.messageinfo_db)
***************
*** 436,441 ****
          msg = Message('8', self.messageinfo_db)
          msg.RememberClassification(options['Headers','header_unsure_string'])
-         self.s.CalculatePersistentStats()
-         return self.s.GetStats(use_html=use_html)
  
      def test_with_html(self):
--- 453,456 ----
***************
*** 449,452 ****
--- 464,504 ----
              self.assert_('&nbsp;' not in line)
  
+     def test_from_date_empty(self):
+         # Put persistent data in, but no session data.
+         self._stuff_with_persistent_data()
+         # Wait for a bit to make sure the time is later.
+         time.sleep(0.1)
+         # Set the date to now.
+         self.s.ResetTotal(permanently=True)
+         # Recalculate.
+         self.s.CalculatePersistentStats()
+         # Check.
+         self.assertEqual(self.s.GetStats(), ["Messages classified: 0"])
+ 
+     def test_from_specified_date(self):
+         # Put persistent data in, but no session data.
+         self._stuff_with_persistent_data()
+         # Wait for a bit to make sure the time is later.
+         time.sleep(0.1)
+         # Set the date to now.
+         self.s.from_date = time.time()
+         # Wait for a bit to make sure the time is later.
+         time.sleep(0.1)
+         # Put more data in.
+         msg = Message('0', self.messageinfo_db)
+         msg.RememberTrained(True)
+         msg.RememberClassification(options['Headers','header_spam_string'])
+         msg = Message('7', self.messageinfo_db)
+         msg.RememberTrained(False)
+         msg.RememberClassification(options['Headers','header_spam_string'])
+         msg = Message('2', self.messageinfo_db)
+         msg.RememberTrained(True)
+         msg.RememberClassification(options['Headers','header_ham_string'])
+         # Recalculate.
+         self.s.CalculatePersistentStats()
+         # Check that there are the right number of messages (assume that
+         # the rest is right - if not it should be caught by other tests).
+         self.assertEqual(self.s.GetStats()[0], "Messages classified: 3")
+         
  
  def suite():



More information about the Spambayes-checkins mailing list