[Python-checkins] cpython: Issue #18937: Add an assertLogs() context manager to unittest.TestCase to

antoine.pitrou python-checkins at python.org
Sat Sep 14 19:45:54 CEST 2013


http://hg.python.org/cpython/rev/4f5815747f58
changeset:   85701:4f5815747f58
user:        Antoine Pitrou <solipsis at pitrou.net>
date:        Sat Sep 14 19:45:47 2013 +0200
summary:
  Issue #18937: Add an assertLogs() context manager to unittest.TestCase to ensure that a block of code emits a message using the logging module.

files:
  Doc/library/unittest.rst       |   41 +++++++
  Lib/unittest/case.py           |  109 +++++++++++++++++++-
  Lib/unittest/test/test_case.py |   96 ++++++++++++++++++
  Misc/NEWS                      |    2 +
  4 files changed, 242 insertions(+), 6 deletions(-)


diff --git a/Doc/library/unittest.rst b/Doc/library/unittest.rst
--- a/Doc/library/unittest.rst
+++ b/Doc/library/unittest.rst
@@ -1031,6 +1031,47 @@
       .. versionchanged:: 3.3
          Added the *msg* keyword argument when used as a context manager.
 
+   .. method:: assertLogs(logger=None, level=None)
+
+      A context manager to test that at least one message is logged on
+      the *logger* or one of its children, with at least the given
+      *level*.
+
+      If given, *logger* should be a :class:`logging.Logger` object or a
+      :class:`str` giving the name of a logger.  The default is the root
+      logger, which will catch all messages.
+
+      If given, *level* should be either a numeric logging level or
+      its string equivalent (for example either ``"ERROR"`` or
+      :attr:`logging.ERROR`).  The default is :attr:`logging.INFO`.
+
+      The test passes if at least one message emitted inside the ``with``
+      block matches the *logger* and *level* conditions, otherwise it fails.
+
+      The object returned by the context manager is a recording helper
+      which keeps track of the matching log messages.  It has two
+      attributes:
+
+      .. attribute:: records
+
+         A list of :class:`logging.LogRecord` objects of the matching
+         log messages.
+
+      .. attribute:: output
+
+         A list of :class:`str` objects with the formatted output of
+         matching messages.
+
+      Example::
+
+         with self.assertLogs('foo', level='INFO') as cm:
+            logging.getLogger('foo').info('first message')
+            logging.getLogger('foo.bar').error('second message')
+         self.assertEqual(cm.output, ['INFO:foo:first message',
+                                      'ERROR:foo.bar:second message'])
+
+      .. versionadded:: 3.4
+
 
    There are also other methods used to perform more specific checks, such as:
 
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py
--- a/Lib/unittest/case.py
+++ b/Lib/unittest/case.py
@@ -3,6 +3,7 @@
 import sys
 import functools
 import difflib
+import logging
 import pprint
 import re
 import warnings
@@ -115,10 +116,21 @@
     return test_item
 
 
-class _AssertRaisesBaseContext(object):
+class _BaseTestCaseContext:
+
+    def __init__(self, test_case):
+        self.test_case = test_case
+
+    def _raiseFailure(self, standardMsg):
+        msg = self.test_case._formatMessage(self.msg, standardMsg)
+        raise self.test_case.failureException(msg)
+
+
+class _AssertRaisesBaseContext(_BaseTestCaseContext):
 
     def __init__(self, expected, test_case, callable_obj=None,
                  expected_regex=None):
+        _BaseTestCaseContext.__init__(self, test_case)
         self.expected = expected
         self.test_case = test_case
         if callable_obj is not None:
@@ -133,10 +145,6 @@
         self.expected_regex = expected_regex
         self.msg = None
 
-    def _raiseFailure(self, standardMsg):
-        msg = self.test_case._formatMessage(self.msg, standardMsg)
-        raise self.test_case.failureException(msg)
-
     def handle(self, name, callable_obj, args, kwargs):
         """
         If callable_obj is None, assertRaises/Warns is being used as a
@@ -150,7 +158,6 @@
             callable_obj(*args, **kwargs)
 
 
-
 class _AssertRaisesContext(_AssertRaisesBaseContext):
     """A context manager used to implement TestCase.assertRaises* methods."""
 
@@ -232,6 +239,74 @@
             self._raiseFailure("{} not triggered".format(exc_name))
 
 
+
+_LoggingWatcher = collections.namedtuple("_LoggingWatcher",
+                                         ["records", "output"])
+
+
+class _CapturingHandler(logging.Handler):
+    """
+    A logging handler capturing all (raw and formatted) logging output.
+    """
+
+    def __init__(self):
+        logging.Handler.__init__(self)
+        self.watcher = _LoggingWatcher([], [])
+
+    def flush(self):
+        pass
+
+    def emit(self, record):
+        self.watcher.records.append(record)
+        msg = self.format(record)
+        self.watcher.output.append(msg)
+
+
+
+class _AssertLogsContext(_BaseTestCaseContext):
+    """A context manager used to implement TestCase.assertLogs()."""
+
+    LOGGING_FORMAT = "%(levelname)s:%(name)s:%(message)s"
+
+    def __init__(self, test_case, logger_name, level):
+        _BaseTestCaseContext.__init__(self, test_case)
+        self.logger_name = logger_name
+        if level:
+            self.level = logging._nameToLevel.get(level, level)
+        else:
+            self.level = logging.INFO
+        self.msg = None
+
+    def __enter__(self):
+        if isinstance(self.logger_name, logging.Logger):
+            logger = self.logger = self.logger_name
+        else:
+            logger = self.logger = logging.getLogger(self.logger_name)
+        formatter = logging.Formatter(self.LOGGING_FORMAT)
+        handler = _CapturingHandler()
+        handler.setFormatter(formatter)
+        self.watcher = handler.watcher
+        self.old_handlers = logger.handlers[:]
+        self.old_level = logger.level
+        self.old_propagate = logger.propagate
+        logger.handlers = [handler]
+        logger.setLevel(self.level)
+        logger.propagate = False
+        return handler.watcher
+
+    def __exit__(self, exc_type, exc_value, tb):
+        self.logger.handlers = self.old_handlers
+        self.logger.propagate = self.old_propagate
+        self.logger.setLevel(self.old_level)
+        if exc_type is not None:
+            # let unexpected exceptions pass through
+            return False
+        if len(self.watcher.records) == 0:
+            self._raiseFailure(
+                "no logs of level {} or higher triggered on {}"
+                .format(logging.getLevelName(self.level), self.logger.name))
+
+
 class TestCase(object):
     """A class whose instances are single test cases.
 
@@ -644,6 +719,28 @@
         context = _AssertWarnsContext(expected_warning, self, callable_obj)
         return context.handle('assertWarns', callable_obj, args, kwargs)
 
+    def assertLogs(self, logger=None, level=None):
+        """Fail unless a log message of level *level* or higher is emitted
+        on *logger* or its children.  If omitted, *level* defaults to
+        INFO and *logger* defaults to the root logger.
+
+        This method must be used as a context manager, and will yield
+        a recording object with two attributes: `output` and `records`.
+        At the end of the context manager, the `output` attribute will
+        be a list of the matching formatted log messages and the
+        `records` attribute will be a list of the corresponding LogRecord
+        objects.
+
+        Example::
+
+            with self.assertLogs('foo', level='INFO') as cm:
+                logging.getLogger('foo').info('first message')
+                logging.getLogger('foo.bar').error('second message')
+            self.assertEqual(cm.output, ['INFO:foo:first message',
+                                         'ERROR:foo.bar:second message'])
+        """
+        return _AssertLogsContext(self, logger, level)
+
     def _getAssertEqualityFunc(self, first, second):
         """Get a detailed comparison function for the types of the two args.
 
diff --git a/Lib/unittest/test/test_case.py b/Lib/unittest/test/test_case.py
--- a/Lib/unittest/test/test_case.py
+++ b/Lib/unittest/test/test_case.py
@@ -1,8 +1,10 @@
+import contextlib
 import difflib
 import pprint
 import pickle
 import re
 import sys
+import logging
 import warnings
 import weakref
 import inspect
@@ -16,6 +18,12 @@
     TestEquality, TestHashing, LoggingResult, LegacyLoggingResult,
     ResultWithNoStartTestRunStopTestRun
 )
+from test.support import captured_stderr
+
+
+log_foo = logging.getLogger('foo')
+log_foobar = logging.getLogger('foo.bar')
+log_quux = logging.getLogger('quux')
 
 
 class Test(object):
@@ -1251,6 +1259,94 @@
                 with self.assertWarnsRegex(RuntimeWarning, "o+"):
                     _runtime_warn("barz")
 
+    @contextlib.contextmanager
+    def assertNoStderr(self):
+        with captured_stderr() as buf:
+            yield
+        self.assertEqual(buf.getvalue(), "")
+
+    def assertLogRecords(self, records, matches):
+        self.assertEqual(len(records), len(matches))
+        for rec, match in zip(records, matches):
+            self.assertIsInstance(rec, logging.LogRecord)
+            for k, v in match.items():
+                self.assertEqual(getattr(rec, k), v)
+
+    def testAssertLogsDefaults(self):
+        # defaults: root logger, level INFO
+        with self.assertNoStderr():
+            with self.assertLogs() as cm:
+                log_foo.info("1")
+                log_foobar.debug("2")
+            self.assertEqual(cm.output, ["INFO:foo:1"])
+            self.assertLogRecords(cm.records, [{'name': 'foo'}])
+
+    def testAssertLogsTwoMatchingMessages(self):
+        # Same, but with two matching log messages
+        with self.assertNoStderr():
+            with self.assertLogs() as cm:
+                log_foo.info("1")
+                log_foobar.debug("2")
+                log_quux.warning("3")
+            self.assertEqual(cm.output, ["INFO:foo:1", "WARNING:quux:3"])
+            self.assertLogRecords(cm.records,
+                                   [{'name': 'foo'}, {'name': 'quux'}])
+
+    def checkAssertLogsPerLevel(self, level):
+        # Check level filtering
+        with self.assertNoStderr():
+            with self.assertLogs(level=level) as cm:
+                log_foo.warning("1")
+                log_foobar.error("2")
+                log_quux.critical("3")
+            self.assertEqual(cm.output, ["ERROR:foo.bar:2", "CRITICAL:quux:3"])
+            self.assertLogRecords(cm.records,
+                                   [{'name': 'foo.bar'}, {'name': 'quux'}])
+
+    def testAssertLogsPerLevel(self):
+        self.checkAssertLogsPerLevel(logging.ERROR)
+        self.checkAssertLogsPerLevel('ERROR')
+
+    def checkAssertLogsPerLogger(self, logger):
+        # Check per-logger filtering
+        with self.assertNoStderr():
+            with self.assertLogs(level='DEBUG') as outer_cm:
+                with self.assertLogs(logger, level='DEBUG') as cm:
+                    log_foo.info("1")
+                    log_foobar.debug("2")
+                    log_quux.warning("3")
+                self.assertEqual(cm.output, ["INFO:foo:1", "DEBUG:foo.bar:2"])
+                self.assertLogRecords(cm.records,
+                                       [{'name': 'foo'}, {'name': 'foo.bar'}])
+            # The outer catchall caught the quux log
+            self.assertEqual(outer_cm.output, ["WARNING:quux:3"])
+
+    def testAssertLogsPerLogger(self):
+        self.checkAssertLogsPerLogger(logging.getLogger('foo'))
+        self.checkAssertLogsPerLogger('foo')
+
+    def testAssertLogsFailureNoLogs(self):
+        # Failure due to no logs
+        with self.assertNoStderr():
+            with self.assertRaises(self.failureException):
+                with self.assertLogs():
+                    pass
+
+    def testAssertLogsFailureLevelTooHigh(self):
+        # Failure due to level too high
+        with self.assertNoStderr():
+            with self.assertRaises(self.failureException):
+                with self.assertLogs(level='WARNING'):
+                    log_foo.info("1")
+
+    def testAssertLogsFailureMismatchingLogger(self):
+        # Failure due to mismatching logger (and the logged message is
+        # passed through)
+        with self.assertLogs('quux', level='ERROR'):
+            with self.assertRaises(self.failureException):
+                with self.assertLogs('foo'):
+                    log_quux.error("1")
+
     def testDeprecatedMethodNames(self):
         """
         Test that the deprecated methods raise a DeprecationWarning. See #9424.
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -12,6 +12,8 @@
 Library
 -------
 
+- Issue #18937: Add an assertLogs() context manager to unittest.TestCase
+  to ensure that a block of code emits a message using the logging module.
 
 - Issue #17324: Fix http.server's request handling case on trailing '/'. Patch
   contributed by Vajrasky Kok.

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list