[Python-checkins] r78774 - in python/trunk/Lib: test/test_unittest.py unittest/__init__.py unittest/case.py unittest/result.py unittest/suite.py

michael.foord python-checkins at python.org
Sun Mar 7 23:04:55 CET 2010


Author: michael.foord
Date: Sun Mar  7 23:04:55 2010
New Revision: 78774

Log:
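Add class- and module-level shared fixtures (setUpClass/tearDownClass and setUpModule/tearDownModule) to unittest, and add BaseTestSuite, a simple suite without shared-fixture support.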

Modified:
   python/trunk/Lib/test/test_unittest.py
   python/trunk/Lib/unittest/__init__.py
   python/trunk/Lib/unittest/case.py
   python/trunk/Lib/unittest/result.py
   python/trunk/Lib/unittest/suite.py

Modified: python/trunk/Lib/test/test_unittest.py
==============================================================================
--- python/trunk/Lib/test/test_unittest.py	(original)
+++ python/trunk/Lib/test/test_unittest.py	Sun Mar  7 23:04:55 2010
@@ -22,6 +22,9 @@
 ### Support code
 ################################################################
 
+def resultFactory(*_):
+    return unittest.TestResult()
+
 class LoggingResult(unittest.TestResult):
     def __init__(self, log):
         self._events = log
@@ -3937,6 +3940,397 @@
         self.assertEqual(program.verbosity, 2)
 
 
+class TestSetups(unittest.TestCase):
+
+    def getRunner(self):
+        return unittest.TextTestRunner(resultclass=resultFactory,
+                                          stream=StringIO())
+    def runTests(self, *cases):
+        suite = unittest.TestSuite()
+        for case in cases:
+            tests = unittest.defaultTestLoader.loadTestsFromTestCase(case)
+            suite.addTests(tests)
+
+        runner = self.getRunner()
+
+        # creating a nested suite exposes some potential bugs
+        realSuite = unittest.TestSuite()
+        realSuite.addTest(suite)
+        # adding empty suites to the end exposes potential bugs
+        suite.addTest(unittest.TestSuite())
+        realSuite.addTest(unittest.TestSuite())
+        return runner.run(realSuite)
+
+    def test_setup_class(self):
+        class Test(unittest.TestCase):
+            setUpCalled = 0
+            @classmethod
+            def setUpClass(cls):
+                Test.setUpCalled += 1
+                unittest.TestCase.setUpClass()
+            def test_one(self):
+                pass
+            def test_two(self):
+                pass
+
+        result = self.runTests(Test)
+
+        self.assertEqual(Test.setUpCalled, 1)
+        self.assertEqual(result.testsRun, 2)
+        self.assertEqual(len(result.errors), 0)
+
+    def test_teardown_class(self):
+        class Test(unittest.TestCase):
+            tearDownCalled = 0
+            @classmethod
+            def tearDownClass(cls):
+                Test.tearDownCalled += 1
+                unittest.TestCase.tearDownClass()
+            def test_one(self):
+                pass
+            def test_two(self):
+                pass
+
+        result = self.runTests(Test)
+
+        self.assertEqual(Test.tearDownCalled, 1)
+        self.assertEqual(result.testsRun, 2)
+        self.assertEqual(len(result.errors), 0)
+
+    def test_teardown_class_two_classes(self):
+        class Test(unittest.TestCase):
+            tearDownCalled = 0
+            @classmethod
+            def tearDownClass(cls):
+                Test.tearDownCalled += 1
+                unittest.TestCase.tearDownClass()
+            def test_one(self):
+                pass
+            def test_two(self):
+                pass
+
+        class Test2(unittest.TestCase):
+            tearDownCalled = 0
+            @classmethod
+            def tearDownClass(cls):
+                Test2.tearDownCalled += 1
+                unittest.TestCase.tearDownClass()
+            def test_one(self):
+                pass
+            def test_two(self):
+                pass
+
+        result = self.runTests(Test, Test2)
+
+        self.assertEqual(Test.tearDownCalled, 1)
+        self.assertEqual(Test2.tearDownCalled, 1)
+        self.assertEqual(result.testsRun, 4)
+        self.assertEqual(len(result.errors), 0)
+
+    def test_error_in_setupclass(self):
+        class BrokenTest(unittest.TestCase):
+            @classmethod
+            def setUpClass(cls):
+                raise TypeError('foo')
+            def test_one(self):
+                pass
+            def test_two(self):
+                pass
+
+        result = self.runTests(BrokenTest)
+
+        self.assertEqual(result.testsRun, 0)
+        self.assertEqual(len(result.errors), 1)
+        error, _ = result.errors[0]
+        self.assertEqual(str(error),
+                    'classSetUp (%s.BrokenTest)' % __name__)
+
+    def test_error_in_teardown_class(self):
+        class Test(unittest.TestCase):
+            tornDown = 0
+            @classmethod
+            def tearDownClass(cls):
+                Test.tornDown += 1
+                raise TypeError('foo')
+            def test_one(self):
+                pass
+            def test_two(self):
+                pass
+
+        class Test2(unittest.TestCase):
+            tornDown = 0
+            @classmethod
+            def tearDownClass(cls):
+                Test2.tornDown += 1
+                raise TypeError('foo')
+            def test_one(self):
+                pass
+            def test_two(self):
+                pass
+
+        result = self.runTests(Test, Test2)
+        self.assertEqual(result.testsRun, 4)
+        self.assertEqual(len(result.errors), 2)
+        self.assertEqual(Test.tornDown, 1)
+        self.assertEqual(Test2.tornDown, 1)
+
+        error, _ = result.errors[0]
+        self.assertEqual(str(error),
+                    'classTearDown (%s.Test)' % __name__)
+
+    def test_class_not_torndown_when_setup_fails(self):
+        class Test(unittest.TestCase):
+            tornDown = False
+            @classmethod
+            def setUpClass(cls):
+                raise TypeError
+            @classmethod
+            def tearDownClass(cls):
+                Test.tornDown = True
+                raise TypeError('foo')
+            def test_one(self):
+                pass
+
+        self.runTests(Test)
+        self.assertFalse(Test.tornDown)
+
+    def test_class_not_setup_or_torndown_when_skipped(self):
+        class Test(unittest.TestCase):
+            classSetUp = False
+            tornDown = False
+            @classmethod
+            def setUpClass(cls):
+                Test.classSetUp = True
+            @classmethod
+            def tearDownClass(cls):
+                Test.tornDown = True
+            def test_one(self):
+                pass
+
+        Test = unittest.skip("hop")(Test)
+        self.runTests(Test)
+        self.assertFalse(Test.classSetUp)
+        self.assertFalse(Test.tornDown)
+
+    def test_setup_teardown_order_with_pathological_suite(self):
+        results = []
+
+        class Module1(object):
+            @staticmethod
+            def setUpModule():
+                results.append('Module1.setUpModule')
+            @staticmethod
+            def tearDownModule():
+                results.append('Module1.tearDownModule')
+
+        class Module2(object):
+            @staticmethod
+            def setUpModule():
+                results.append('Module2.setUpModule')
+            @staticmethod
+            def tearDownModule():
+                results.append('Module2.tearDownModule')
+
+        class Test1(unittest.TestCase):
+            @classmethod
+            def setUpClass(cls):
+                results.append('setup 1')
+            @classmethod
+            def tearDownClass(cls):
+                results.append('teardown 1')
+            def testOne(self):
+                results.append('Test1.testOne')
+            def testTwo(self):
+                results.append('Test1.testTwo')
+
+        class Test2(unittest.TestCase):
+            @classmethod
+            def setUpClass(cls):
+                results.append('setup 2')
+            @classmethod
+            def tearDownClass(cls):
+                results.append('teardown 2')
+            def testOne(self):
+                results.append('Test2.testOne')
+            def testTwo(self):
+                results.append('Test2.testTwo')
+
+        class Test3(unittest.TestCase):
+            @classmethod
+            def setUpClass(cls):
+                results.append('setup 3')
+            @classmethod
+            def tearDownClass(cls):
+                results.append('teardown 3')
+            def testOne(self):
+                results.append('Test3.testOne')
+            def testTwo(self):
+                results.append('Test3.testTwo')
+
+        Test1.__module__ = Test2.__module__ = 'Module'
+        Test3.__module__ = 'Module2'
+        sys.modules['Module'] = Module1
+        sys.modules['Module2'] = Module2
+
+        first = unittest.TestSuite((Test1('testOne'),))
+        second = unittest.TestSuite((Test1('testTwo'),))
+        third = unittest.TestSuite((Test2('testOne'),))
+        fourth = unittest.TestSuite((Test2('testTwo'),))
+        fifth = unittest.TestSuite((Test3('testOne'),))
+        sixth = unittest.TestSuite((Test3('testTwo'),))
+        suite = unittest.TestSuite((first, second, third, fourth, fifth, sixth))
+
+        runner = self.getRunner()
+        result = runner.run(suite)
+        self.assertEqual(result.testsRun, 6)
+        self.assertEqual(len(result.errors), 0)
+
+        self.assertEqual(results,
+                         ['Module1.setUpModule', 'setup 1',
+                          'Test1.testOne', 'Test1.testTwo', 'teardown 1',
+                          'setup 2', 'Test2.testOne', 'Test2.testTwo',
+                          'teardown 2', 'Module1.tearDownModule',
+                          'Module2.setUpModule', 'setup 3',
+                          'Test3.testOne', 'Test3.testTwo',
+                          'teardown 3', 'Module2.tearDownModule'])
+
+    def test_setup_module(self):
+        class Module(object):
+            moduleSetup = 0
+            @staticmethod
+            def setUpModule():
+                Module.moduleSetup += 1
+
+        class Test(unittest.TestCase):
+            def test_one(self):
+                pass
+            def test_two(self):
+                pass
+        Test.__module__ = 'Module'
+        sys.modules['Module'] = Module
+
+        result = self.runTests(Test)
+        self.assertEqual(Module.moduleSetup, 1)
+        self.assertEqual(result.testsRun, 2)
+        self.assertEqual(len(result.errors), 0)
+
+    def test_error_in_setup_module(self):
+        class Module(object):
+            moduleSetup = 0
+            moduleTornDown = 0
+            @staticmethod
+            def setUpModule():
+                Module.moduleSetup += 1
+                raise TypeError('foo')
+            @staticmethod
+            def tearDownModule():
+                Module.moduleTornDown += 1
+
+        class Test(unittest.TestCase):
+            classSetUp = False
+            classTornDown = False
+            @classmethod
+            def setUpClass(cls):
+                Test.classSetUp = True
+            @classmethod
+            def tearDownClass(cls):
+                Test.classTornDown = True
+            def test_one(self):
+                pass
+            def test_two(self):
+                pass
+
+        class Test2(unittest.TestCase):
+            def test_one(self):
+                pass
+            def test_two(self):
+                pass
+        Test.__module__ = 'Module'
+        Test2.__module__ = 'Module'
+        sys.modules['Module'] = Module
+
+        result = self.runTests(Test, Test2)
+        self.assertEqual(Module.moduleSetup, 1)
+        self.assertEqual(Module.moduleTornDown, 0)
+        self.assertEqual(result.testsRun, 0)
+        self.assertFalse(Test.classSetUp)
+        self.assertFalse(Test.classTornDown)
+        self.assertEqual(len(result.errors), 1)
+        error, _ = result.errors[0]
+        self.assertEqual(str(error), 'setUpModule (Module)')
+
+    def test_testcase_with_missing_module(self):
+        class Test(unittest.TestCase):
+            def test_one(self):
+                pass
+            def test_two(self):
+                pass
+        Test.__module__ = 'Module'
+        sys.modules.pop('Module', None)
+
+        result = self.runTests(Test)
+        self.assertEqual(result.testsRun, 2)
+
+    def test_teardown_module(self):
+        class Module(object):
+            moduleTornDown = 0
+            @staticmethod
+            def tearDownModule():
+                Module.moduleTornDown += 1
+
+        class Test(unittest.TestCase):
+            def test_one(self):
+                pass
+            def test_two(self):
+                pass
+        Test.__module__ = 'Module'
+        sys.modules['Module'] = Module
+
+        result = self.runTests(Test)
+        self.assertEqual(Module.moduleTornDown, 1)
+        self.assertEqual(result.testsRun, 2)
+        self.assertEqual(len(result.errors), 0)
+
+    def test_error_in_teardown_module(self):
+        class Module(object):
+            moduleTornDown = 0
+            @staticmethod
+            def tearDownModule():
+                Module.moduleTornDown += 1
+                raise TypeError('foo')
+
+        class Test(unittest.TestCase):
+            classSetUp = False
+            classTornDown = False
+            @classmethod
+            def setUpClass(cls):
+                Test.classSetUp = True
+            @classmethod
+            def tearDownClass(cls):
+                Test.classTornDown = True
+            def test_one(self):
+                pass
+            def test_two(self):
+                pass
+
+        class Test2(unittest.TestCase):
+            def test_one(self):
+                pass
+            def test_two(self):
+                pass
+        Test.__module__ = 'Module'
+        Test2.__module__ = 'Module'
+        sys.modules['Module'] = Module
+
+        result = self.runTests(Test, Test2)
+        self.assertEqual(Module.moduleTornDown, 1)
+        self.assertEqual(result.testsRun, 4)
+        self.assertTrue(Test.classSetUp)
+        self.assertTrue(Test.classTornDown)
+        self.assertEqual(len(result.errors), 1)
+        error, _ = result.errors[0]
+        self.assertEqual(str(error), 'tearDownModule (Module)')
+
 ######################################################################
 ## Main
 ######################################################################
@@ -3946,7 +4340,7 @@
         Test_TestSuite, Test_TestResult, Test_FunctionTestCase,
         Test_TestSkipping, Test_Assertions, TestLongMessage,
         Test_TestProgram, TestCleanUp, TestDiscovery, Test_TextTestRunner,
-        Test_OldTestResult)
+        Test_OldTestResult, TestSetups)
 
 if __name__ == "__main__":
     test_main()

Modified: python/trunk/Lib/unittest/__init__.py
==============================================================================
--- python/trunk/Lib/unittest/__init__.py	(original)
+++ python/trunk/Lib/unittest/__init__.py	Sun Mar  7 23:04:55 2010
@@ -51,13 +51,12 @@
 
 # Expose obsolete functions for backwards compatibility
 __all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases'])
-__all__.append('_TextTestResult')
 
 
 from .result import TestResult
 from .case import (TestCase, FunctionTestCase, SkipTest, skip, skipIf,
                    skipUnless, expectedFailure)
-from .suite import TestSuite
+from .suite import BaseTestSuite, TestSuite
 from .loader import (TestLoader, defaultTestLoader, makeSuite, getTestCaseNames,
                      findTestCases)
 from .main import TestProgram, main

Modified: python/trunk/Lib/unittest/case.py
==============================================================================
--- python/trunk/Lib/unittest/case.py	(original)
+++ python/trunk/Lib/unittest/case.py	Sun Mar  7 23:04:55 2010
@@ -153,6 +153,9 @@
 
     longMessage = False
 
+    # Attribute used by TestSuite for classSetUp
+
+    _classSetupFailed = False
 
     def __init__(self, methodName='runTest'):
         """Create an instance of the class that will use the named test
@@ -211,6 +214,14 @@
         "Hook method for deconstructing the test fixture after testing it."
         pass
 
+    @classmethod
+    def setUpClass(cls):
+        "Hook method for setting up class fixture before running tests in the class."
+
+    @classmethod
+    def tearDownClass(cls):
+        "Hook method for deconstructing the class fixture after running all tests in the class."
+
     def countTestCases(self):
         return 1
 

Modified: python/trunk/Lib/unittest/result.py
==============================================================================
--- python/trunk/Lib/unittest/result.py	(original)
+++ python/trunk/Lib/unittest/result.py	Sun Mar  7 23:04:55 2010
@@ -16,6 +16,8 @@
     contain tuples of (testcase, exceptioninfo), where exceptioninfo is the
     formatted traceback of the error that occurred.
     """
+    _previousTestClass = None
+    _moduleSetUpFailed = False
     def __init__(self, stream=None, descriptions=None, verbosity=None):
         self.failures = []
         self.errors = []

Modified: python/trunk/Lib/unittest/suite.py
==============================================================================
--- python/trunk/Lib/unittest/suite.py	(original)
+++ python/trunk/Lib/unittest/suite.py	Sun Mar  7 23:04:55 2010
@@ -1,17 +1,13 @@
 """TestSuite"""
 
+import sys
+
 from . import case
 from . import util
 
 
-class TestSuite(object):
-    """A test suite is a composite test consisting of a number of TestCases.
-
-    For use, create an instance of TestSuite, then add test case instances.
-    When all tests have been added, the suite can be passed to a test
-    runner, such as TextTestRunner. It will run the individual test cases
-    in the order in which they were added, aggregating the results. When
-    subclassing, do not forget to call the base class constructor.
+class BaseTestSuite(object):
+    """A simple test suite that doesn't provide class or module shared fixtures.
     """
     def __init__(self, tests=()):
         self._tests = []
@@ -70,3 +66,190 @@
         """Run the tests without collecting errors in a TestResult"""
         for test in self:
             test.debug()
+
+
+class TestSuite(BaseTestSuite):
+    """A test suite is a composite test consisting of a number of TestCases.
+
+    For use, create an instance of TestSuite, then add test case instances.
+    When all tests have been added, the suite can be passed to a test
+    runner, such as TextTestRunner. It will run the individual test cases
+    in the order in which they were added, aggregating the results. When
+    subclassing, do not forget to call the base class constructor.
+    """
+
+
+    def run(self, result):
+        self._wrapped_run(result)
+        self._tearDownPreviousClass(None, result)
+        self._handleModuleTearDown(result)
+        return result
+
+    ################################
+    # private methods
+    def _wrapped_run(self, result):
+        for test in self:
+            if result.shouldStop:
+                break
+
+            if _isnotsuite(test):
+                self._tearDownPreviousClass(test, result)
+                self._handleModuleFixture(test, result)
+                self._handleClassSetUp(test, result)
+                result._previousTestClass = test.__class__
+
+                if (getattr(test.__class__, '_classSetupFailed', False) or
+                    getattr(result, '_moduleSetUpFailed', False)):
+                    continue
+
+            if hasattr(test, '_wrapped_run'):
+                test._wrapped_run(result)
+            else:
+                test(result)
+
+    def _handleClassSetUp(self, test, result):
+        previousClass = getattr(result, '_previousTestClass', None)
+        currentClass = test.__class__
+        if currentClass == previousClass:
+            return
+        if result._moduleSetUpFailed:
+            return
+        if getattr(currentClass, "__unittest_skip__", False):
+            return
+
+        currentClass._classSetupFailed = False
+
+        setUpClass = getattr(currentClass, 'setUpClass', None)
+        if setUpClass is not None:
+            try:
+                setUpClass()
+            except:
+                currentClass._classSetupFailed = True
+                self._addClassSetUpError(result, currentClass)
+
+    def _get_previous_module(self, result):
+        previousModule = None
+        previousClass = getattr(result, '_previousTestClass', None)
+        if previousClass is not None:
+            previousModule = previousClass.__module__
+        return previousModule
+
+
+    def _handleModuleFixture(self, test, result):
+        previousModule = self._get_previous_module(result)
+        currentModule = test.__class__.__module__
+        if currentModule == previousModule:
+            return
+
+        self._handleModuleTearDown(result)
+
+
+        result._moduleSetUpFailed = False
+        try:
+            module = sys.modules[currentModule]
+        except KeyError:
+            return
+        setUpModule = getattr(module, 'setUpModule', None)
+        if setUpModule is not None:
+            try:
+                setUpModule()
+            except:
+                result._moduleSetUpFailed = True
+                error = _ErrorHolder('setUpModule (%s)' % currentModule)
+                result.addError(error, sys.exc_info())
+
+    def _handleModuleTearDown(self, result):
+        previousModule = self._get_previous_module(result)
+        if previousModule is None:
+            return
+        if result._moduleSetUpFailed:
+            return
+
+        try:
+            module = sys.modules[previousModule]
+        except KeyError:
+            return
+
+        tearDownModule = getattr(module, 'tearDownModule', None)
+        if tearDownModule is not None:
+            try:
+                tearDownModule()
+            except:
+                error = _ErrorHolder('tearDownModule (%s)' % previousModule)
+                result.addError(error, sys.exc_info())
+
+    def _tearDownPreviousClass(self, test, result):
+        previousClass = getattr(result, '_previousTestClass', None)
+        currentClass = test.__class__
+        if currentClass == previousClass:
+            return
+        if getattr(previousClass, '_classSetupFailed', False):
+            return
+        if getattr(result, '_moduleSetUpFailed', False):
+            return
+        if getattr(previousClass, "__unittest_skip__", False):
+            return
+
+        tearDownClass = getattr(previousClass, 'tearDownClass', None)
+        if tearDownClass is not None:
+            try:
+                tearDownClass()
+            except:
+                self._addClassTearDownError(result)
+
+    def _addClassTearDownError(self, result):
+        className = util.strclass(result._previousTestClass)
+        error = _ErrorHolder('classTearDown (%s)' % className)
+        result.addError(error, sys.exc_info())
+
+    def _addClassSetUpError(self, result, klass):
+        className = util.strclass(klass)
+        error = _ErrorHolder('classSetUp (%s)' % className)
+        result.addError(error, sys.exc_info())
+
+
+class _ErrorHolder(object):
+    """
+    Placeholder for a TestCase inside a result. As far as a TestResult
+    is concerned, this looks exactly like a unit test. Used to insert
+    arbitrary errors into a test suite run.
+    """
+    # Inspired by the ErrorHolder from Twisted:
+    # http://twistedmatrix.com/trac/browser/trunk/twisted/trial/runner.py
+
+    # attribute used by TestResult._exc_info_to_string
+    failureException = None
+
+    def __init__(self, description):
+        self.description = description
+
+    def id(self):
+        return self.description
+
+    def shortDescription(self):
+        return None
+
+    def __repr__(self):
+        return "<ErrorHolder description=%r>" % (self.description,)
+
+    def __str__(self):
+        return self.id()
+
+    def run(self, result):
+        # could call result.addError(...) - but this test-like object
+        # shouldn't be run anyway
+        pass
+
+    def __call__(self, result):
+        return self.run(result)
+
+    def countTestCases(self):
+        return 0
+
+def _isnotsuite(test):
+    "A crude way to tell apart testcases and suites with duck-typing"
+    try:
+        iter(test)
+    except TypeError:
+        return True
+    return False


More information about the Python-checkins mailing list