[Python-checkins] [3.10] bpo-40173: Fix test.support.import_helper.import_fresh_module() (GH-28654) (GH-28657)

serhiy-storchaka webhook-mailer at python.org
Thu Sep 30 12:56:52 EDT 2021


https://github.com/python/cpython/commit/7873884d4730d7e637a968011b8958bd79fd3398
commit: 7873884d4730d7e637a968011b8958bd79fd3398
branch: 3.10
author: Serhiy Storchaka <storchaka at gmail.com>
committer: serhiy-storchaka <storchaka at gmail.com>
date: 2021-09-30T19:56:41+03:00
summary:

[3.10] bpo-40173: Fix test.support.import_helper.import_fresh_module() (GH-28654) (GH-28657)

* Work correctly if an additional fresh module imports another
  additional fresh module which imports a blocked module.
* Raise ImportError if the specified module cannot be imported
  while all additional fresh modules are successfully imported.
* Support blocking packages.
* Always restore the import state of fresh and blocked modules
  and their submodules.
* Fix test_decimal and test_xml_etree which depended on an undesired
  side effect of import_fresh_module().
(cherry picked from commit ec4d917a6a68824f1895f75d113add9410283da7)

files:
A Misc/NEWS.d/next/Tests/2021-09-30-16-54-39.bpo-40173.J_slCw.rst
M Lib/test/support/import_helper.py
M Lib/test/test_decimal.py
M Lib/test/test_xml_etree.py

diff --git a/Lib/test/support/import_helper.py b/Lib/test/support/import_helper.py
index 5d1e9406879cc..43ae31483420d 100644
--- a/Lib/test/support/import_helper.py
+++ b/Lib/test/support/import_helper.py
@@ -80,33 +80,13 @@ def import_module(name, deprecated=False, *, required_on=()):
             raise unittest.SkipTest(str(msg))
 
 
-def _save_and_remove_module(name, orig_modules):
-    """Helper function to save and remove a module from sys.modules
-
-    Raise ImportError if the module can't be imported.
-    """
-    # try to import the module and raise an error if it can't be imported
-    if name not in sys.modules:
-        __import__(name)
-        del sys.modules[name]
+def _save_and_remove_modules(names):
+    orig_modules = {}
+    prefixes = tuple(name + '.' for name in names)
     for modname in list(sys.modules):
-        if modname == name or modname.startswith(name + '.'):
-            orig_modules[modname] = sys.modules[modname]
-            del sys.modules[modname]
-
-
-def _save_and_block_module(name, orig_modules):
-    """Helper function to save and block a module in sys.modules
-
-    Return True if the module was in sys.modules, False otherwise.
-    """
-    saved = True
-    try:
-        orig_modules[name] = sys.modules[name]
-    except KeyError:
-        saved = False
-    sys.modules[name] = None
-    return saved
+        if modname in names or modname.startswith(prefixes):
+            orig_modules[modname] = sys.modules.pop(modname)
+    return orig_modules
 
 
 def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
@@ -118,7 +98,8 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
     this operation.
 
     *fresh* is an iterable of additional module names that are also removed
-    from the sys.modules cache before doing the import.
+    from the sys.modules cache before doing the import. If one of these
+    modules can't be imported, None is returned.
 
     *blocked* is an iterable of module names that are replaced with None
     in the module cache during the import to ensure that attempts to import
@@ -139,24 +120,24 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
     with _ignore_deprecated_imports(deprecated):
         # Keep track of modules saved for later restoration as well
         # as those which just need a blocking entry removed
-        orig_modules = {}
-        names_to_remove = []
-        _save_and_remove_module(name, orig_modules)
+        fresh = list(fresh)
+        blocked = list(blocked)
+        names = {name, *fresh, *blocked}
+        orig_modules = _save_and_remove_modules(names)
+        for modname in blocked:
+            sys.modules[modname] = None
+
         try:
-            for fresh_name in fresh:
-                _save_and_remove_module(fresh_name, orig_modules)
-            for blocked_name in blocked:
-                if not _save_and_block_module(blocked_name, orig_modules):
-                    names_to_remove.append(blocked_name)
-            fresh_module = importlib.import_module(name)
-        except ImportError:
-            fresh_module = None
+            # Return None when one of the "fresh" modules can not be imported.
+            try:
+                for modname in fresh:
+                    __import__(modname)
+            except ImportError:
+                return None
+            return importlib.import_module(name)
         finally:
-            for orig_name, module in orig_modules.items():
-                sys.modules[orig_name] = module
-            for name_to_remove in names_to_remove:
-                del sys.modules[name_to_remove]
-        return fresh_module
+            _save_and_remove_modules(names)
+            sys.modules.update(orig_modules)
 
 
 class CleanImport(object):
diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py
index 99263bb13b0d1..b6173a5ffec96 100644
--- a/Lib/test/test_decimal.py
+++ b/Lib/test/test_decimal.py
@@ -62,7 +62,7 @@
 
 C = import_fresh_module('decimal', fresh=['_decimal'])
 P = import_fresh_module('decimal', blocked=['_decimal'])
-orig_sys_decimal = sys.modules['decimal']
+import decimal as orig_sys_decimal
 
 # fractions module must import the correct decimal module.
 cfractions = import_fresh_module('fractions', fresh=['fractions'])
diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py
index c79b5462b931d..5a8824a78ffa4 100644
--- a/Lib/test/test_xml_etree.py
+++ b/Lib/test/test_xml_etree.py
@@ -26,7 +26,7 @@
 from test import support
 from test.support import os_helper
 from test.support import warnings_helper
-from test.support import findfile, gc_collect, swap_attr
+from test.support import findfile, gc_collect, swap_attr, swap_item
 from test.support.import_helper import import_fresh_module
 from test.support.os_helper import TESTFN
 
@@ -167,12 +167,11 @@ def setUpClass(cls):
         cls.modules = {pyET, ET}
 
     def pickleRoundTrip(self, obj, name, dumper, loader, proto):
-        save_m = sys.modules[name]
         try:
-            sys.modules[name] = dumper
-            temp = pickle.dumps(obj, proto)
-            sys.modules[name] = loader
-            result = pickle.loads(temp)
+            with swap_item(sys.modules, name, dumper):
+                temp = pickle.dumps(obj, proto)
+            with swap_item(sys.modules, name, loader):
+                result = pickle.loads(temp)
         except pickle.PicklingError as pe:
             # pyET must be second, because pyET may be (equal to) ET.
             human = dict([(ET, "cET"), (pyET, "pyET")])
@@ -180,8 +179,6 @@ def pickleRoundTrip(self, obj, name, dumper, loader, proto):
                                      % (obj,
                                         human.get(dumper, dumper),
                                         human.get(loader, loader))) from pe
-        finally:
-            sys.modules[name] = save_m
         return result
 
     def assertEqualElements(self, alice, bob):
diff --git a/Misc/NEWS.d/next/Tests/2021-09-30-16-54-39.bpo-40173.J_slCw.rst b/Misc/NEWS.d/next/Tests/2021-09-30-16-54-39.bpo-40173.J_slCw.rst
new file mode 100644
index 0000000000000..21671473c16cc
--- /dev/null
+++ b/Misc/NEWS.d/next/Tests/2021-09-30-16-54-39.bpo-40173.J_slCw.rst
@@ -0,0 +1,2 @@
+Fix :func:`test.support.import_helper.import_fresh_module`.
+



More information about the Python-checkins mailing list