[Python-checkins] cpython: Issue #16997: unittest.TestCase now provides a subTest() context manager to

antoine.pitrou python-checkins at python.org
Wed Mar 20 20:21:04 CET 2013


http://hg.python.org/cpython/rev/5c09e1c57200
changeset:   82844:5c09e1c57200
parent:      82840:2873f4971281
user:        Antoine Pitrou <solipsis at pitrou.net>
date:        Wed Mar 20 20:16:47 2013 +0100
summary:
  Issue #16997: unittest.TestCase now provides a subTest() context manager to procedurally generate, in an easy way, small test instances.

files:
  Doc/library/unittest.rst           |   93 ++++++
  Lib/test/test_socket.py            |    4 -
  Lib/unittest/case.py               |  252 +++++++++++-----
  Lib/unittest/result.py             |   16 +
  Lib/unittest/test/support.py       |   38 ++-
  Lib/unittest/test/test_case.py     |   94 ++++++-
  Lib/unittest/test/test_result.py   |   58 +++
  Lib/unittest/test/test_runner.py   |   12 +-
  Lib/unittest/test/test_skipping.py |   74 ++++
  Misc/NEWS                          |    3 +
  10 files changed, 539 insertions(+), 105 deletions(-)


diff --git a/Doc/library/unittest.rst b/Doc/library/unittest.rst
--- a/Doc/library/unittest.rst
+++ b/Doc/library/unittest.rst
@@ -556,6 +556,68 @@
 Skipped classes will not have :meth:`setUpClass` or :meth:`tearDownClass` run.
 
 
+.. _subtests:
+
+Distinguishing test iterations using subtests
+---------------------------------------------
+
+.. versionadded:: 3.4
+
+When some of your tests differ only in small details, for
+instance some parameters, unittest allows you to distinguish them inside
+the body of a test method using the :meth:`~TestCase.subTest` context manager.
+
+For example, the following test::
+
+   class NumbersTest(unittest.TestCase):
+
+       def test_even(self):
+           """
+           Test that numbers between 0 and 5 are all even.
+           """
+           for i in range(0, 6):
+               with self.subTest(i=i):
+                   self.assertEqual(i % 2, 0)
+
+will produce the following output::
+
+   ======================================================================
+   FAIL: test_even (__main__.NumbersTest) (i=1)
+   ----------------------------------------------------------------------
+   Traceback (most recent call last):
+     File "subtests.py", line 32, in test_even
+       self.assertEqual(i % 2, 0)
+   AssertionError: 1 != 0
+
+   ======================================================================
+   FAIL: test_even (__main__.NumbersTest) (i=3)
+   ----------------------------------------------------------------------
+   Traceback (most recent call last):
+     File "subtests.py", line 32, in test_even
+       self.assertEqual(i % 2, 0)
+   AssertionError: 1 != 0
+
+   ======================================================================
+   FAIL: test_even (__main__.NumbersTest) (i=5)
+   ----------------------------------------------------------------------
+   Traceback (most recent call last):
+     File "subtests.py", line 32, in test_even
+       self.assertEqual(i % 2, 0)
+   AssertionError: 1 != 0
+
+Without using a subtest, execution would stop after the first failure,
+and the error would be harder to diagnose because the value of ``i``
+wouldn't be displayed::
+
+   ======================================================================
+   FAIL: test_even (__main__.NumbersTest)
+   ----------------------------------------------------------------------
+   Traceback (most recent call last):
+     File "subtests.py", line 32, in test_even
+       self.assertEqual(i % 2, 0)
+   AssertionError: 1 != 0
+
+
 .. _unittest-contents:
 
 Classes and functions
@@ -669,6 +731,21 @@
       .. versionadded:: 3.1
 
 
+   .. method:: subTest(msg=None, **params)
+
+      Return a context manager which executes the enclosed code block as a
+      subtest.  *msg* and *params* are optional, arbitrary values which are
+      displayed whenever a subtest fails, allowing you to identify them
+      clearly.
+
+      A test case can contain any number of subtest declarations, and
+      they can be arbitrarily nested.
+
+      See :ref:`subtests` for more information.
+
+      .. versionadded:: 3.4
+
+
    .. method:: debug()
 
       Run the test without collecting the result.  This allows exceptions raised
@@ -1733,6 +1810,22 @@
       :attr:`unexpectedSuccesses` attribute.
 
 
+   .. method:: addSubTest(test, subtest, outcome)
+
+      Called when a subtest finishes.  *test* is the test case
+      corresponding to the test method.  *subtest* is a custom
+      :class:`TestCase` instance describing the subtest.
+
+      If *outcome* is :const:`None`, the subtest succeeded.  Otherwise,
+      it failed with an exception where *outcome* is a tuple of the form
+      returned by :func:`sys.exc_info`: ``(type, value, traceback)``.
+
+      The default implementation does nothing on success, and records
+      subtest failures and errors as ordinary failures and errors.
+
+      .. versionadded:: 3.4
+
+
 .. class:: TextTestResult(stream, descriptions, verbosity)
 
    A concrete implementation of :class:`TestResult` used by the
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -2,7 +2,6 @@
 
 import unittest
 from test import support
-from unittest.case import _ExpectedFailure
 
 import errno
 import io
@@ -272,9 +271,6 @@
             raise TypeError("test_func must be a callable function")
         try:
             test_func()
-        except _ExpectedFailure:
-            # We deliberately ignore expected failures
-            pass
         except BaseException as e:
             self.queue.put(e)
         finally:
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py
--- a/Lib/unittest/case.py
+++ b/Lib/unittest/case.py
@@ -7,6 +7,7 @@
 import re
 import warnings
 import collections
+import contextlib
 
 from . import result
 from .util import (strclass, safe_repr, _count_diff_all_purpose,
@@ -26,17 +27,11 @@
     instead of raising this directly.
     """
 
-class _ExpectedFailure(Exception):
+class _ShouldStop(Exception):
     """
-    Raise this when a test is expected to fail.
-
-    This is an implementation detail.
+    The test should stop.
     """
 
-    def __init__(self, exc_info):
-        super(_ExpectedFailure, self).__init__()
-        self.exc_info = exc_info
-
 class _UnexpectedSuccess(Exception):
     """
     The test was supposed to fail, but it didn't!
@@ -44,13 +39,40 @@
 
 
 class _Outcome(object):
-    def __init__(self):
+    def __init__(self, result=None):
+        self.expecting_failure = False
+        self.result = result
+        self.result_supports_subtests = hasattr(result, "addSubTest")
         self.success = True
-        self.skipped = None
-        self.unexpectedSuccess = None
+        self.skipped = []
         self.expectedFailure = None
         self.errors = []
-        self.failures = []
+
+    @contextlib.contextmanager
+    def testPartExecutor(self, test_case, isTest=False):
+        old_success = self.success
+        self.success = True
+        try:
+            yield
+        except KeyboardInterrupt:
+            raise
+        except SkipTest as e:
+            self.success = False
+            self.skipped.append((test_case, str(e)))
+        except _ShouldStop:
+            pass
+        except:
+            exc_info = sys.exc_info()
+            if self.expecting_failure:
+                self.expectedFailure = exc_info
+            else:
+                self.success = False
+                self.errors.append((test_case, exc_info))
+        else:
+            if self.result_supports_subtests and self.success:
+                self.errors.append((test_case, None))
+        finally:
+            self.success = self.success and old_success
 
 
 def _id(obj):
@@ -88,16 +110,9 @@
         return skip(reason)
     return _id
 
-
-def expectedFailure(func):
-    @functools.wraps(func)
-    def wrapper(*args, **kwargs):
-        try:
-            func(*args, **kwargs)
-        except Exception:
-            raise _ExpectedFailure(sys.exc_info())
-        raise _UnexpectedSuccess
-    return wrapper
+def expectedFailure(test_item):
+    test_item.__unittest_expecting_failure__ = True
+    return test_item
 
 
 class _AssertRaisesBaseContext(object):
@@ -271,7 +286,7 @@
            not have a method with the specified name.
         """
         self._testMethodName = methodName
-        self._outcomeForDoCleanups = None
+        self._outcome = None
         self._testMethodDoc = 'No test'
         try:
             testMethod = getattr(self, methodName)
@@ -284,6 +299,7 @@
         else:
             self._testMethodDoc = testMethod.__doc__
         self._cleanups = []
+        self._subtest = None
 
         # Map types to custom assertEqual functions that will compare
         # instances of said type in more detail to generate a more useful
@@ -371,44 +387,80 @@
         return "<%s testMethod=%s>" % \
                (strclass(self.__class__), self._testMethodName)
 
-    def _addSkip(self, result, reason):
+    def _addSkip(self, result, test_case, reason):
         addSkip = getattr(result, 'addSkip', None)
         if addSkip is not None:
-            addSkip(self, reason)
+            addSkip(test_case, reason)
         else:
             warnings.warn("TestResult has no addSkip method, skips not reported",
                           RuntimeWarning, 2)
+            result.addSuccess(test_case)
+
+    @contextlib.contextmanager
+    def subTest(self, msg=None, **params):
+        """Return a context manager that will run the enclosed block
+        of code in a subtest identified by the optional message and
+        keyword parameters.  A failure in the subtest marks the test
+        case as failed but resumes execution at the end of the enclosed
+        block, allowing further test code to be executed.
+        """
+        if not self._outcome.result_supports_subtests:
+            yield
+            return
+        parent = self._subtest
+        if parent is None:
+            params_map = collections.ChainMap(params)
+        else:
+            params_map = parent.params.new_child(params)
+        self._subtest = _SubTest(self, msg, params_map)
+        try:
+            with self._outcome.testPartExecutor(self._subtest, isTest=True):
+                yield
+            if not self._outcome.success:
+                result = self._outcome.result
+                if result is not None and result.failfast:
+                    raise _ShouldStop
+            elif self._outcome.expectedFailure:
+                # If the test is expecting a failure, we really want to
+                # stop now and register the expected failure.
+                raise _ShouldStop
+        finally:
+            self._subtest = parent
+
+    def _feedErrorsToResult(self, result, errors):
+        for test, exc_info in errors:
+            if isinstance(test, _SubTest):
+                result.addSubTest(test.test_case, test, exc_info)
+            elif exc_info is not None:
+                if issubclass(exc_info[0], self.failureException):
+                    result.addFailure(test, exc_info)
+                else:
+                    result.addError(test, exc_info)
+
+    def _addExpectedFailure(self, result, exc_info):
+        try:
+            addExpectedFailure = result.addExpectedFailure
+        except AttributeError:
+            warnings.warn("TestResult has no addExpectedFailure method, reporting as passes",
+                          RuntimeWarning)
             result.addSuccess(self)
+        else:
+            addExpectedFailure(self, exc_info)
 
-    def _executeTestPart(self, function, outcome, isTest=False):
+    def _addUnexpectedSuccess(self, result):
         try:
-            function()
-        except KeyboardInterrupt:
-            raise
-        except SkipTest as e:
-            outcome.success = False
-            outcome.skipped = str(e)
-        except _UnexpectedSuccess:
-            exc_info = sys.exc_info()
-            outcome.success = False
-            if isTest:
-                outcome.unexpectedSuccess = exc_info
-            else:
-                outcome.errors.append(exc_info)
-        except _ExpectedFailure:
-            outcome.success = False
-            exc_info = sys.exc_info()
-            if isTest:
-                outcome.expectedFailure = exc_info
-            else:
-                outcome.errors.append(exc_info)
-        except self.failureException:
-            outcome.success = False
-            outcome.failures.append(sys.exc_info())
-            exc_info = sys.exc_info()
-        except:
-            outcome.success = False
-            outcome.errors.append(sys.exc_info())
+            addUnexpectedSuccess = result.addUnexpectedSuccess
+        except AttributeError:
+            warnings.warn("TestResult has no addUnexpectedSuccess method, reporting as failure",
+                          RuntimeWarning)
+            # We need to pass an actual exception and traceback to addFailure,
+            # otherwise the legacy result can choke.
+            try:
+                raise _UnexpectedSuccess from None
+            except _UnexpectedSuccess:
+                result.addFailure(self, sys.exc_info())
+        else:
+            addUnexpectedSuccess(self)
 
     def run(self, result=None):
         orig_result = result
@@ -427,46 +479,38 @@
             try:
                 skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
                             or getattr(testMethod, '__unittest_skip_why__', ''))
-                self._addSkip(result, skip_why)
+                self._addSkip(result, self, skip_why)
             finally:
                 result.stopTest(self)
             return
+        expecting_failure = getattr(testMethod,
+                                    "__unittest_expecting_failure__", False)
         try:
-            outcome = _Outcome()
-            self._outcomeForDoCleanups = outcome
+            outcome = _Outcome(result)
+            self._outcome = outcome
 
-            self._executeTestPart(self.setUp, outcome)
+            with outcome.testPartExecutor(self):
+                self.setUp()
             if outcome.success:
-                self._executeTestPart(testMethod, outcome, isTest=True)
-                self._executeTestPart(self.tearDown, outcome)
+                outcome.expecting_failure = expecting_failure
+                with outcome.testPartExecutor(self, isTest=True):
+                    testMethod()
+                outcome.expecting_failure = False
+                with outcome.testPartExecutor(self):
+                    self.tearDown()
 
             self.doCleanups()
+            for test, reason in outcome.skipped:
+                self._addSkip(result, test, reason)
+            self._feedErrorsToResult(result, outcome.errors)
             if outcome.success:
-                result.addSuccess(self)
-            else:
-                if outcome.skipped is not None:
-                    self._addSkip(result, outcome.skipped)
-                for exc_info in outcome.errors:
-                    result.addError(self, exc_info)
-                for exc_info in outcome.failures:
-                    result.addFailure(self, exc_info)
-                if outcome.unexpectedSuccess is not None:
-                    addUnexpectedSuccess = getattr(result, 'addUnexpectedSuccess', None)
-                    if addUnexpectedSuccess is not None:
-                        addUnexpectedSuccess(self)
+                if expecting_failure:
+                    if outcome.expectedFailure:
+                        self._addExpectedFailure(result, outcome.expectedFailure)
                     else:
-                        warnings.warn("TestResult has no addUnexpectedSuccess method, reporting as failures",
-                                      RuntimeWarning)
-                        result.addFailure(self, outcome.unexpectedSuccess)
-
-                if outcome.expectedFailure is not None:
-                    addExpectedFailure = getattr(result, 'addExpectedFailure', None)
-                    if addExpectedFailure is not None:
-                        addExpectedFailure(self, outcome.expectedFailure)
-                    else:
-                        warnings.warn("TestResult has no addExpectedFailure method, reporting as passes",
-                                      RuntimeWarning)
-                        result.addSuccess(self)
+                        self._addUnexpectedSuccess(result)
+                else:
+                    result.addSuccess(self)
             return result
         finally:
             result.stopTest(self)
@@ -478,11 +522,11 @@
     def doCleanups(self):
         """Execute all cleanup functions. Normally called for you after
         tearDown."""
-        outcome = self._outcomeForDoCleanups or _Outcome()
+        outcome = self._outcome or _Outcome()
         while self._cleanups:
             function, args, kwargs = self._cleanups.pop()
-            part = lambda: function(*args, **kwargs)
-            self._executeTestPart(part, outcome)
+            with outcome.testPartExecutor(self):
+                function(*args, **kwargs)
 
         # return this for backwards compatibility
         # even though we no longer us it internally
@@ -1213,3 +1257,39 @@
             return self._description
         doc = self._testFunc.__doc__
         return doc and doc.split("\n")[0].strip() or None
+
+
+class _SubTest(TestCase):
+
+    def __init__(self, test_case, message, params):
+        super().__init__()
+        self._message = message
+        self.test_case = test_case
+        self.params = params
+        self.failureException = test_case.failureException
+
+    def runTest(self):
+        raise NotImplementedError("subtests cannot be run directly")
+
+    def _subDescription(self):
+        parts = []
+        if self._message:
+            parts.append("[{}]".format(self._message))
+        if self.params:
+            params_desc = ', '.join(
+                "{}={!r}".format(k, v)
+                for (k, v) in sorted(self.params.items()))
+            parts.append("({})".format(params_desc))
+        return " ".join(parts) or '(<subtest>)'
+
+    def id(self):
+        return "{} {}".format(self.test_case.id(), self._subDescription())
+
+    def shortDescription(self):
+        """Returns a one-line description of the subtest, or None if no
+        description has been provided.
+        """
+        return self.test_case.shortDescription()
+
+    def __str__(self):
+        return "{} {}".format(self.test_case, self._subDescription())
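Nested subtests accumulate their parameters with `collections.ChainMap.new_child()`, so an inner value shadows an outer one with the same key. `_SubTest._subDescription()` then renders the merged parameters as a sorted `key=value` list. The standalone `describe_subtest` function below is a hypothetical re-implementation, written only to illustrate that formatting logic. It is not part of unittest's API.

```python
from collections import ChainMap


def describe_subtest(message, params):
    """Mirror the formatting logic of _SubTest._subDescription in this patch."""
    parts = []
    if message:
        parts.append("[{}]".format(message))
    if params:
        params_desc = ", ".join(
            "{}={!r}".format(k, v) for (k, v) in sorted(params.items()))
        parts.append("({})".format(params_desc))
    return " ".join(parts) or "(<subtest>)"


# An outer subTest(foo=1) followed by a nested subTest(bar=2):
outer = ChainMap({"foo": 1})
inner = outer.new_child({"bar": 2})
print(describe_subtest(None, inner))                     # (bar=2, foo=1)
# A nested subTest(foo=2) shadows the outer foo=1:
print(describe_subtest(None, outer.new_child({"foo": 2})))  # (foo=2)
print(describe_subtest("some message", {}))              # [some message]
print(describe_subtest(None, ChainMap({})))              # (<subtest>)
```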
diff --git a/Lib/unittest/result.py b/Lib/unittest/result.py
--- a/Lib/unittest/result.py
+++ b/Lib/unittest/result.py
@@ -121,6 +121,22 @@
         self.failures.append((test, self._exc_info_to_string(err, test)))
         self._mirrorOutput = True
 
+    @failfast
+    def addSubTest(self, test, subtest, err):
+        """Called at the end of a subtest.
+        'err' is None if the subtest ended successfully, otherwise it's a
+        tuple of values as returned by sys.exc_info().
+        """
+        # By default, we don't do anything with successful subtests, but
+        # more sophisticated test results might want to record them.
+        if err is not None:
+            if issubclass(err[0], test.failureException):
+                errors = self.failures
+            else:
+                errors = self.errors
+            errors.append((test, self._exc_info_to_string(err, test)))
+            self._mirrorOutput = True
+
     def addSuccess(self, test):
         "Called when a test has completed successfully"
         pass
diff --git a/Lib/unittest/test/support.py b/Lib/unittest/test/support.py
--- a/Lib/unittest/test/support.py
+++ b/Lib/unittest/test/support.py
@@ -41,7 +41,7 @@
                 self.fail("Problem hashing %s and %s: %s" % (obj_1, obj_2, e))
 
 
-class LoggingResult(unittest.TestResult):
+class _BaseLoggingResult(unittest.TestResult):
     def __init__(self, log):
         self._events = log
         super().__init__()
@@ -52,7 +52,7 @@
 
     def startTestRun(self):
         self._events.append('startTestRun')
-        super(LoggingResult, self).startTestRun()
+        super().startTestRun()
 
     def stopTest(self, test):
         self._events.append('stopTest')
@@ -60,7 +60,7 @@
 
     def stopTestRun(self):
         self._events.append('stopTestRun')
-        super(LoggingResult, self).stopTestRun()
+        super().stopTestRun()
 
     def addFailure(self, *args):
         self._events.append('addFailure')
@@ -68,7 +68,7 @@
 
     def addSuccess(self, *args):
         self._events.append('addSuccess')
-        super(LoggingResult, self).addSuccess(*args)
+        super().addSuccess(*args)
 
     def addError(self, *args):
         self._events.append('addError')
@@ -76,15 +76,39 @@
 
     def addSkip(self, *args):
         self._events.append('addSkip')
-        super(LoggingResult, self).addSkip(*args)
+        super().addSkip(*args)
 
     def addExpectedFailure(self, *args):
         self._events.append('addExpectedFailure')
-        super(LoggingResult, self).addExpectedFailure(*args)
+        super().addExpectedFailure(*args)
 
     def addUnexpectedSuccess(self, *args):
         self._events.append('addUnexpectedSuccess')
-        super(LoggingResult, self).addUnexpectedSuccess(*args)
+        super().addUnexpectedSuccess(*args)
+
+
+class LegacyLoggingResult(_BaseLoggingResult):
+    """
+    A legacy TestResult implementation, without an addSubTest method,
+    which records its method calls.
+    """
+
+    @property
+    def addSubTest(self):
+        raise AttributeError
+
+
+class LoggingResult(_BaseLoggingResult):
+    """
+    A TestResult implementation which records its method calls.
+    """
+
+    def addSubTest(self, test, subtest, err):
+        if err is None:
+            self._events.append('addSubTestSuccess')
+        else:
+            self._events.append('addSubTestFailure')
+        super().addSubTest(test, subtest, err)
 
 
 class ResultWithNoStartTestRunStopTestRun(object):
diff --git a/Lib/unittest/test/test_case.py b/Lib/unittest/test/test_case.py
--- a/Lib/unittest/test/test_case.py
+++ b/Lib/unittest/test/test_case.py
@@ -13,7 +13,7 @@
 import unittest
 
 from .support import (
-    TestEquality, TestHashing, LoggingResult,
+    TestEquality, TestHashing, LoggingResult, LegacyLoggingResult,
     ResultWithNoStartTestRunStopTestRun
 )
 
@@ -297,6 +297,98 @@
 
         Foo('test').run()
 
+    def _check_call_order__subtests(self, result, events, expected_events):
+        class Foo(Test.LoggingTestCase):
+            def test(self):
+                super(Foo, self).test()
+                for i in [1, 2, 3]:
+                    with self.subTest(i=i):
+                        if i == 1:
+                            self.fail('failure')
+                        for j in [2, 3]:
+                            with self.subTest(j=j):
+                                if i * j == 6:
+                                    raise RuntimeError('raised by Foo.test')
+                1 / 0
+
+        # Order is the following:
+        # i=1 => subtest failure
+        # i=2, j=2 => subtest success
+        # i=2, j=3 => subtest error
+        # i=3, j=2 => subtest error
+        # i=3, j=3 => subtest success
+        # toplevel => error
+        Foo(events).run(result)
+        self.assertEqual(events, expected_events)
+
+    def test_run_call_order__subtests(self):
+        events = []
+        result = LoggingResult(events)
+        expected = ['startTest', 'setUp', 'test', 'tearDown',
+                    'addSubTestFailure', 'addSubTestSuccess',
+                    'addSubTestFailure', 'addSubTestFailure',
+                    'addSubTestSuccess', 'addError', 'stopTest']
+        self._check_call_order__subtests(result, events, expected)
+
+    def test_run_call_order__subtests_legacy(self):
+        # With a legacy result object (without an addSubTest method),
+        # test execution stops after the first subtest failure.
+        events = []
+        result = LegacyLoggingResult(events)
+        expected = ['startTest', 'setUp', 'test', 'tearDown',
+                    'addFailure', 'stopTest']
+        self._check_call_order__subtests(result, events, expected)
+
+    def _check_call_order__subtests_success(self, result, events, expected_events):
+        class Foo(Test.LoggingTestCase):
+            def test(self):
+                super(Foo, self).test()
+                for i in [1, 2]:
+                    with self.subTest(i=i):
+                        for j in [2, 3]:
+                            with self.subTest(j=j):
+                                pass
+
+        Foo(events).run(result)
+        self.assertEqual(events, expected_events)
+
+    def test_run_call_order__subtests_success(self):
+        events = []
+        result = LoggingResult(events)
+        # The 6 subtest successes are individually recorded, in addition
+        # to the whole test success.
+        expected = (['startTest', 'setUp', 'test', 'tearDown']
+                    + 6 * ['addSubTestSuccess']
+                    + ['addSuccess', 'stopTest'])
+        self._check_call_order__subtests_success(result, events, expected)
+
+    def test_run_call_order__subtests_success_legacy(self):
+        # With a legacy result, only the whole test success is recorded.
+        events = []
+        result = LegacyLoggingResult(events)
+        expected = ['startTest', 'setUp', 'test', 'tearDown',
+                    'addSuccess', 'stopTest']
+        self._check_call_order__subtests_success(result, events, expected)
+
+    def test_run_call_order__subtests_failfast(self):
+        events = []
+        result = LoggingResult(events)
+        result.failfast = True
+
+        class Foo(Test.LoggingTestCase):
+            def test(self):
+                super(Foo, self).test()
+                with self.subTest(i=1):
+                    self.fail('failure')
+                with self.subTest(i=2):
+                    self.fail('failure')
+                self.fail('failure')
+
+        expected = ['startTest', 'setUp', 'test', 'tearDown',
+                    'addSubTestFailure', 'stopTest']
+        Foo(events).run(result)
+        self.assertEqual(events, expected)
+
     # "This class attribute gives the exception raised by the test() method.
     # If a test framework needs to use a specialized exception, possibly to
     # carry additional information, it must subclass this exception in
diff --git a/Lib/unittest/test/test_result.py b/Lib/unittest/test/test_result.py
--- a/Lib/unittest/test/test_result.py
+++ b/Lib/unittest/test/test_result.py
@@ -234,6 +234,37 @@
                 'testGetDescriptionWithoutDocstring (' + __name__ +
                 '.Test_TestResult)')
 
+    def testGetSubTestDescriptionWithoutDocstring(self):
+        with self.subTest(foo=1, bar=2):
+            result = unittest.TextTestResult(None, True, 1)
+            self.assertEqual(
+                    result.getDescription(self._subtest),
+                    'testGetSubTestDescriptionWithoutDocstring (' + __name__ +
+                    '.Test_TestResult) (bar=2, foo=1)')
+        with self.subTest('some message'):
+            result = unittest.TextTestResult(None, True, 1)
+            self.assertEqual(
+                    result.getDescription(self._subtest),
+                    'testGetSubTestDescriptionWithoutDocstring (' + __name__ +
+                    '.Test_TestResult) [some message]')
+
+    def testGetSubTestDescriptionWithoutDocstringAndParams(self):
+        with self.subTest():
+            result = unittest.TextTestResult(None, True, 1)
+            self.assertEqual(
+                    result.getDescription(self._subtest),
+                    'testGetSubTestDescriptionWithoutDocstringAndParams '
+                    '(' + __name__ + '.Test_TestResult) (<subtest>)')
+
+    def testGetNestedSubTestDescriptionWithoutDocstring(self):
+        with self.subTest(foo=1):
+            with self.subTest(bar=2):
+                result = unittest.TextTestResult(None, True, 1)
+                self.assertEqual(
+                        result.getDescription(self._subtest),
+                        'testGetNestedSubTestDescriptionWithoutDocstring '
+                        '(' + __name__ + '.Test_TestResult) (bar=2, foo=1)')
+
     @unittest.skipIf(sys.flags.optimize >= 2,
                      "Docstrings are omitted with -O2 and above")
     def testGetDescriptionWithOneLineDocstring(self):
@@ -247,6 +278,18 @@
 
     @unittest.skipIf(sys.flags.optimize >= 2,
                      "Docstrings are omitted with -O2 and above")
+    def testGetSubTestDescriptionWithOneLineDocstring(self):
+        """Tests getDescription() for a method with a docstring."""
+        result = unittest.TextTestResult(None, True, 1)
+        with self.subTest(foo=1, bar=2):
+            self.assertEqual(
+                result.getDescription(self._subtest),
+               ('testGetSubTestDescriptionWithOneLineDocstring '
+                '(' + __name__ + '.Test_TestResult) (bar=2, foo=1)\n'
+                'Tests getDescription() for a method with a docstring.'))
+
+    @unittest.skipIf(sys.flags.optimize >= 2,
+                     "Docstrings are omitted with -O2 and above")
     def testGetDescriptionWithMultiLineDocstring(self):
         """Tests getDescription() for a method with a longer docstring.
         The second line of the docstring.
@@ -259,6 +302,21 @@
                 'Tests getDescription() for a method with a longer '
                 'docstring.'))
 
+    @unittest.skipIf(sys.flags.optimize >= 2,
+                     "Docstrings are omitted with -O2 and above")
+    def testGetSubTestDescriptionWithMultiLineDocstring(self):
+        """Tests getDescription() for a method with a longer docstring.
+        The second line of the docstring.
+        """
+        result = unittest.TextTestResult(None, True, 1)
+        with self.subTest(foo=1, bar=2):
+            self.assertEqual(
+                result.getDescription(self._subtest),
+               ('testGetSubTestDescriptionWithMultiLineDocstring '
+                '(' + __name__ + '.Test_TestResult) (bar=2, foo=1)\n'
+                'Tests getDescription() for a method with a longer '
+                'docstring.'))
+
     def testStackFrameTrimming(self):
         class Frame(object):
             class tb_frame(object):
diff --git a/Lib/unittest/test/test_runner.py b/Lib/unittest/test/test_runner.py
--- a/Lib/unittest/test/test_runner.py
+++ b/Lib/unittest/test/test_runner.py
@@ -5,6 +5,7 @@
 import subprocess
 
 import unittest
+from unittest.case import _Outcome
 
 from .support import LoggingResult, ResultWithNoStartTestRunStopTestRun
 
@@ -42,12 +43,8 @@
             def testNothing(self):
                 pass
 
-        class MockOutcome(object):
-            success = True
-            errors = []
-
         test = TestableTest('testNothing')
-        test._outcomeForDoCleanups = MockOutcome
+        outcome = test._outcome = _Outcome()
 
         exc1 = Exception('foo')
         exc2 = Exception('bar')
@@ -61,9 +58,10 @@
         test.addCleanup(cleanup2)
 
         self.assertFalse(test.doCleanups())
-        self.assertFalse(MockOutcome.success)
+        self.assertFalse(outcome.success)
 
-        (Type1, instance1, _), (Type2, instance2, _) = reversed(MockOutcome.errors)
+        ((_, (Type1, instance1, _)),
+         (_, (Type2, instance2, _))) = reversed(outcome.errors)
         self.assertEqual((Type1, instance1), (Exception, exc1))
         self.assertEqual((Type2, instance2), (Exception, exc2))
 
diff --git a/Lib/unittest/test/test_skipping.py b/Lib/unittest/test/test_skipping.py
--- a/Lib/unittest/test/test_skipping.py
+++ b/Lib/unittest/test/test_skipping.py
@@ -29,6 +29,31 @@
         self.assertEqual(result.skipped, [(test, "testing")])
         self.assertEqual(result.testsRun, 1)
 
+    def test_skipping_subtests(self):
+        class Foo(unittest.TestCase):
+            def test_skip_me(self):
+                with self.subTest(a=1):
+                    with self.subTest(b=2):
+                        self.skipTest("skip 1")
+                    self.skipTest("skip 2")
+                self.skipTest("skip 3")
+        events = []
+        result = LoggingResult(events)
+        test = Foo("test_skip_me")
+        test.run(result)
+        self.assertEqual(events, ['startTest', 'addSkip', 'addSkip',
+                                  'addSkip', 'stopTest'])
+        self.assertEqual(len(result.skipped), 3)
+        subtest, msg = result.skipped[0]
+        self.assertEqual(msg, "skip 1")
+        self.assertIsInstance(subtest, unittest.TestCase)
+        self.assertIsNot(subtest, test)
+        subtest, msg = result.skipped[1]
+        self.assertEqual(msg, "skip 2")
+        self.assertIsInstance(subtest, unittest.TestCase)
+        self.assertIsNot(subtest, test)
+        self.assertEqual(result.skipped[2], (test, "skip 3"))
+
     def test_skipping_decorators(self):
         op_table = ((unittest.skipUnless, False, True),
                     (unittest.skipIf, True, False))
@@ -95,6 +120,31 @@
         self.assertEqual(result.expectedFailures[0][0], test)
         self.assertTrue(result.wasSuccessful())
 
+    def test_expected_failure_subtests(self):
+        # A failure in any subtest counts as the expected failure of the
+        # whole test.
+        class Foo(unittest.TestCase):
+            @unittest.expectedFailure
+            def test_die(self):
+                with self.subTest():
+                    # This one succeeds
+                    pass
+                with self.subTest():
+                    self.fail("help me!")
+                with self.subTest():
+                    # This one doesn't get executed
+                    self.fail("shouldn't come here")
+        events = []
+        result = LoggingResult(events)
+        test = Foo("test_die")
+        test.run(result)
+        self.assertEqual(events,
+                         ['startTest', 'addSubTestSuccess',
+                          'addExpectedFailure', 'stopTest'])
+        self.assertEqual(len(result.expectedFailures), 1)
+        self.assertIs(result.expectedFailures[0][0], test)
+        self.assertTrue(result.wasSuccessful())
+
     def test_unexpected_success(self):
         class Foo(unittest.TestCase):
             @unittest.expectedFailure
@@ -110,6 +160,30 @@
         self.assertEqual(result.unexpectedSuccesses, [test])
         self.assertTrue(result.wasSuccessful())
 
+    def test_unexpected_success_subtests(self):
+        # Success in all subtests counts as the unexpected success of
+        # the whole test.
+        class Foo(unittest.TestCase):
+            @unittest.expectedFailure
+            def test_die(self):
+                with self.subTest():
+                    # This one succeeds
+                    pass
+                with self.subTest():
+                    # So does this one
+                    pass
+        events = []
+        result = LoggingResult(events)
+        test = Foo("test_die")
+        test.run(result)
+        self.assertEqual(events,
+                         ['startTest',
+                          'addSubTestSuccess', 'addSubTestSuccess',
+                          'addUnexpectedSuccess', 'stopTest'])
+        self.assertFalse(result.failures)
+        self.assertEqual(result.unexpectedSuccesses, [test])
+        self.assertTrue(result.wasSuccessful())
+
     def test_skip_doesnt_run_setup(self):
         class Foo(unittest.TestCase):
             wasSetUp = False
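The test_skipping additions pin down how skips and expected failures interact with subtests:

- A skipTest() call inside a subTest() block is recorded against that subtest, and execution resumes after the block.
- A failure in any subtest of an @expectedFailure test counts as the expected failure of the whole test.
- An @expectedFailure test whose subtests all pass is an unexpected success.

The sketch below reproduces these cases with the public API. It assumes Python 3.4 or later. ParamTest, XFailTest and the run_* helpers are hypothetical names.

```python
import unittest


def run_skip_demo():
    """Skip one parameter value inside a subTest loop."""

    # Hypothetical test case, defined locally so test collectors ignore it.
    class ParamTest(unittest.TestCase):
        def test_values(self):
            for i in range(3):
                with self.subTest(i=i):
                    if i == 1:
                        # The SkipTest raised here is caught when this
                        # subTest block exits; the loop goes on to i == 2.
                        self.skipTest("i == 1 not supported")
                    self.assertGreaterEqual(i, 0)

    test = ParamTest("test_values")
    result = unittest.TestResult()
    test.run(result)
    return test, result


def run_expected_failure_demo(fail_second):
    """Run an @expectedFailure test made of two subtests."""

    class XFailTest(unittest.TestCase):
        @unittest.expectedFailure
        def test_die(self):
            with self.subTest(step=1):
                pass  # always succeeds
            with self.subTest(step=2):
                if fail_second:
                    self.fail("boom")

    test = XFailTest("test_die")
    result = unittest.TestResult()
    test.run(result)
    return test, result


test, result = run_skip_demo()
subtest, reason = result.skipped[0]
print(len(result.skipped), reason, subtest is test)  # 1 i == 1 not supported False

for fail_second in (True, False):
    test, result = run_expected_failure_demo(fail_second)
    print(fail_second, len(result.expectedFailures),
          len(result.unexpectedSuccesses))
# True 1 0
# False 0 1
```

In Python 3.4 and later, TestResult.wasSuccessful() returns False when unexpectedSuccesses is non-empty. The final assertion in test_unexpected_success_subtests therefore reflects the behaviour at the time of this changeset, not that of later releases.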
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -292,6 +292,9 @@
 Library
 -------
 
+- Issue #16997: unittest.TestCase now provides a subTest() context manager
+  to procedurally generate, in an easy way, small test instances.
+
 - Issue #17485: Also delete the Request Content-Length header if the data
   attribute is deleted.  (Follow on to issue 16464).
 

-- 
Repository URL: http://hg.python.org/cpython

