[Python-checkins] cpython: Close #22457: Honour load_tests in the start_dir of discovery.

robert.collins python-checkins at python.org
Tue Nov 4 15:09:42 CET 2014


https://hg.python.org/cpython/rev/ce0dd5e4b801
changeset:   93383:ce0dd5e4b801
parent:      93381:be374b8c40c8
user:        Robert Collins <rbtcollins at hp.com>
date:        Wed Nov 05 03:09:01 2014 +1300
summary:
  Close #22457: Honour load_tests in the start_dir of discovery.

We were not honouring load_tests in a package/__init__.py when that package was
the start_dir parameter, though we do when it is a child package. The fix
required a little care because it introduces the possibility of infinite
recursion: a package's load_tests may itself call loader.discover() on its own
directory.

files:
  Doc/library/unittest.rst            |    6 +-
  Lib/unittest/__init__.py            |    9 +
  Lib/unittest/loader.py              |  160 ++++++++++-----
  Lib/unittest/test/test_discovery.py |   45 ++++
  Lib/unittest/test/test_loader.py    |    2 +-
  Misc/NEWS                           |    2 +
  6 files changed, 166 insertions(+), 58 deletions(-)


diff --git a/Doc/library/unittest.rst b/Doc/library/unittest.rst
--- a/Doc/library/unittest.rst
+++ b/Doc/library/unittest.rst
@@ -1668,7 +1668,11 @@
 
       If a package (a directory containing a file named :file:`__init__.py`) is
       found, the package will be checked for a ``load_tests`` function. If this
-      exists then it will be called with *loader*, *tests*, *pattern*.
+      exists then it will be called as
+      ``package.load_tests(loader, tests, pattern)``. Test discovery takes care
+      to ensure that a package is only checked for tests once during an
+      invocation, even if the load_tests function itself calls
+      ``loader.discover``.
 
       If ``load_tests`` exists then discovery does *not* recurse into the
       package, ``load_tests`` is responsible for loading all tests in the
diff --git a/Lib/unittest/__init__.py b/Lib/unittest/__init__.py
--- a/Lib/unittest/__init__.py
+++ b/Lib/unittest/__init__.py
@@ -67,3 +67,12 @@
 
 # deprecated
 _TextTestResult = TextTestResult
+
+# There are no tests here, so don't try to run anything discovered from
+# introspecting the symbols (e.g. FunctionTestCase). Instead, all our
+# tests come from within unittest.test.
+def load_tests(loader, tests, pattern):
+    import os.path
+    # top level directory cached on loader instance
+    this_dir = os.path.dirname(__file__)
+    return loader.discover(start_dir=this_dir, pattern=pattern)
diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py
--- a/Lib/unittest/loader.py
+++ b/Lib/unittest/loader.py
@@ -65,6 +65,9 @@
     def __init__(self):
         super(TestLoader, self).__init__()
         self.errors = []
+        # Tracks packages which we have called into via load_tests, to
+        # avoid infinite re-entrancy.
+        self._loading_packages = set()
 
     def loadTestsFromTestCase(self, testCaseClass):
         """Return a suite of all tests cases contained in testCaseClass"""
@@ -229,9 +232,13 @@
 
         If a test package name (directory with '__init__.py') matches the
         pattern then the package will be checked for a 'load_tests' function. If
-        this exists then it will be called with loader, tests, pattern.
+        this exists then it will be called with (loader, tests, pattern) unless
+        the package has already had load_tests called from the same discovery
+        invocation, in which case the package module object is not scanned for
+        tests - this ensures that when a package uses discover to further
+        discover child tests that infinite recursion does not happen.
 
-        If load_tests exists then discovery does  *not* recurse into the package,
+        If load_tests exists then discovery does *not* recurse into the package,
         load_tests is responsible for loading all tests in the package.
 
         The pattern is deliberately not stored as a loader attribute so that
@@ -355,69 +362,110 @@
 
     def _find_tests(self, start_dir, pattern, namespace=False):
         """Used by discovery. Yields test suites it loads."""
+        # Handle the __init__ in this package
+        name = self._get_name_from_path(start_dir)
+        # name is '.' when start_dir == top_level_dir (and top_level_dir is by
+        # definition not a package).
+        if name != '.' and name not in self._loading_packages:
+            # name is in self._loading_packages while we have called into
+            # loadTestsFromModule with name.
+            tests, should_recurse = self._find_test_path(
+                start_dir, pattern, namespace)
+            if tests is not None:
+                yield tests
+            if not should_recurse:
+                # Either an error occured, or load_tests was used by the
+                # package.
+                return
+        # Handle the contents.
         paths = sorted(os.listdir(start_dir))
-
         for path in paths:
             full_path = os.path.join(start_dir, path)
-            if os.path.isfile(full_path):
-                if not VALID_MODULE_NAME.match(path):
-                    # valid Python identifiers only
-                    continue
-                if not self._match_path(path, full_path, pattern):
-                    continue
-                # if the test file matches, load it
+            tests, should_recurse = self._find_test_path(
+                full_path, pattern, namespace)
+            if tests is not None:
+                yield tests
+            if should_recurse:
+                # we found a package that didn't use load_tests.
                 name = self._get_name_from_path(full_path)
+                self._loading_packages.add(name)
                 try:
-                    module = self._get_module_from_name(name)
-                except case.SkipTest as e:
-                    yield _make_skipped_test(name, e, self.suiteClass)
-                except:
-                    error_case, error_message = \
-                        _make_failed_import_test(name, self.suiteClass)
-                    self.errors.append(error_message)
-                    yield error_case
-                else:
-                    mod_file = os.path.abspath(getattr(module, '__file__', full_path))
-                    realpath = _jython_aware_splitext(os.path.realpath(mod_file))
-                    fullpath_noext = _jython_aware_splitext(os.path.realpath(full_path))
-                    if realpath.lower() != fullpath_noext.lower():
-                        module_dir = os.path.dirname(realpath)
-                        mod_name = _jython_aware_splitext(os.path.basename(full_path))
-                        expected_dir = os.path.dirname(full_path)
-                        msg = ("%r module incorrectly imported from %r. Expected %r. "
-                               "Is this module globally installed?")
-                        raise ImportError(msg % (mod_name, module_dir, expected_dir))
-                    yield self.loadTestsFromModule(module, pattern=pattern)
-            elif os.path.isdir(full_path):
-                if (not namespace and
-                    not os.path.isfile(os.path.join(full_path, '__init__.py'))):
-                    continue
+                    yield from self._find_tests(full_path, pattern, namespace)
+                finally:
+                    self._loading_packages.discard(name)
 
-                load_tests = None
-                tests = None
-                name = self._get_name_from_path(full_path)
+    def _find_test_path(self, full_path, pattern, namespace=False):
+        """Used by discovery.
+
+        Loads tests from a single file, or a directories' __init__.py when
+        passed the directory.
+
+        Returns a tuple (None_or_tests_from_file, should_recurse).
+        """
+        basename = os.path.basename(full_path)
+        if os.path.isfile(full_path):
+            if not VALID_MODULE_NAME.match(basename):
+                # valid Python identifiers only
+                return None, False
+            if not self._match_path(basename, full_path, pattern):
+                return None, False
+            # if the test file matches, load it
+            name = self._get_name_from_path(full_path)
+            try:
+                module = self._get_module_from_name(name)
+            except case.SkipTest as e:
+                return _make_skipped_test(name, e, self.suiteClass), False
+            except:
+                error_case, error_message = \
+                    _make_failed_import_test(name, self.suiteClass)
+                self.errors.append(error_message)
+                return error_case, False
+            else:
+                mod_file = os.path.abspath(
+                    getattr(module, '__file__', full_path))
+                realpath = _jython_aware_splitext(
+                    os.path.realpath(mod_file))
+                fullpath_noext = _jython_aware_splitext(
+                    os.path.realpath(full_path))
+                if realpath.lower() != fullpath_noext.lower():
+                    module_dir = os.path.dirname(realpath)
+                    mod_name = _jython_aware_splitext(
+                        os.path.basename(full_path))
+                    expected_dir = os.path.dirname(full_path)
+                    msg = ("%r module incorrectly imported from %r. Expected "
+                           "%r. Is this module globally installed?")
+                    raise ImportError(
+                        msg % (mod_name, module_dir, expected_dir))
+                return self.loadTestsFromModule(module, pattern=pattern), False
+        elif os.path.isdir(full_path):
+            if (not namespace and
+                not os.path.isfile(os.path.join(full_path, '__init__.py'))):
+                return None, False
+
+            load_tests = None
+            tests = None
+            name = self._get_name_from_path(full_path)
+            try:
+                package = self._get_module_from_name(name)
+            except case.SkipTest as e:
+                return _make_skipped_test(name, e, self.suiteClass), False
+            except:
+                error_case, error_message = \
+                    _make_failed_import_test(name, self.suiteClass)
+                self.errors.append(error_message)
+                return error_case, False
+            else:
+                load_tests = getattr(package, 'load_tests', None)
+                # Mark this package as being in load_tests (possibly ;))
+                self._loading_packages.add(name)
                 try:
-                    package = self._get_module_from_name(name)
-                except case.SkipTest as e:
-                    yield _make_skipped_test(name, e, self.suiteClass)
-                except:
-                    error_case, error_message = \
-                        _make_failed_import_test(name, self.suiteClass)
-                    self.errors.append(error_message)
-                    yield error_case
-                else:
-                    load_tests = getattr(package, 'load_tests', None)
                     tests = self.loadTestsFromModule(package, pattern=pattern)
-                    if tests is not None:
-                        # tests loaded from package file
-                        yield tests
-
                     if load_tests is not None:
-                        # loadTestsFromModule(package) has load_tests for us.
-                        continue
-                    # recurse into the package
-                    yield from self._find_tests(full_path, pattern,
-                                                namespace=namespace)
+                        # loadTestsFromModule(package) has loaded tests for us.
+                        return tests, False
+                    return tests, True
+                finally:
+                    self._loading_packages.discard(name)
 
 
 defaultTestLoader = TestLoader()
diff --git a/Lib/unittest/test/test_discovery.py b/Lib/unittest/test/test_discovery.py
--- a/Lib/unittest/test/test_discovery.py
+++ b/Lib/unittest/test/test_discovery.py
@@ -368,6 +368,51 @@
         self.assertEqual(_find_tests_args, [(start_dir, 'pattern')])
         self.assertIn(top_level_dir, sys.path)
 
+    def test_discover_start_dir_is_package_calls_package_load_tests(self):
+        # This test verifies that the package load_tests in a package is indeed
+        # invoked when the start_dir is a package (and not the top level).
+        # http://bugs.python.org/issue22457
+
+        # Test data: we expect the following:
+        # an isfile to verify the package, then importing and scanning
+        # as per _find_tests' normal behaviour.
+        # We expect to see our load_tests hook called once.
+        vfs = {abspath('/toplevel'): ['startdir'],
+               abspath('/toplevel/startdir'): ['__init__.py']}
+        def list_dir(path):
+            return list(vfs[path])
+        self.addCleanup(setattr, os, 'listdir', os.listdir)
+        os.listdir = list_dir
+        self.addCleanup(setattr, os.path, 'isfile', os.path.isfile)
+        os.path.isfile = lambda path: path.endswith('.py')
+        self.addCleanup(setattr, os.path, 'isdir', os.path.isdir)
+        os.path.isdir = lambda path: not path.endswith('.py')
+        self.addCleanup(sys.path.remove, abspath('/toplevel'))
+
+        class Module(object):
+            paths = []
+            load_tests_args = []
+
+            def __init__(self, path):
+                self.path = path
+
+            def load_tests(self, loader, tests, pattern):
+                return ['load_tests called ' + self.path]
+
+            def __eq__(self, other):
+                return self.path == other.path
+
+        loader = unittest.TestLoader()
+        loader._get_module_from_name = lambda name: Module(name)
+        loader.suiteClass = lambda thing: thing
+
+        suite = loader.discover('/toplevel/startdir', top_level_dir='/toplevel')
+
+        # We should have loaded tests from the package __init__.
+        # (normally this would be nested TestSuites.)
+        self.assertEqual(suite,
+                         [['load_tests called startdir']])
+
     def setup_import_issue_tests(self, fakefile):
         listdir = os.listdir
         os.listdir = lambda _: [fakefile]
diff --git a/Lib/unittest/test/test_loader.py b/Lib/unittest/test/test_loader.py
--- a/Lib/unittest/test/test_loader.py
+++ b/Lib/unittest/test/test_loader.py
@@ -841,7 +841,7 @@
         loader = unittest.TestLoader()
 
         suite = loader.loadTestsFromNames(
-            ['unittest.loader.sdasfasfasdf', 'unittest'])
+            ['unittest.loader.sdasfasfasdf', 'unittest.test.dummy'])
         error, test = self.check_deferred_error(loader, list(suite)[0])
         expected = "module 'unittest.loader' has no attribute 'sdasfasfasdf'"
         self.assertIn(
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -212,6 +212,8 @@
 
 - Issue #22217: Implemented reprs of classes in the zipfile module.
 
+- Issue #22457: Honour load_tests in the start_dir of discovery.
+
 - Issue #18216: gettext now raises an error when a .mo file has an
   unsupported major version number.  Patch by Aaron Hill.
 

-- 
Repository URL: https://hg.python.org/cpython


More information about the Python-checkins mailing list