[Python-checkins] [3.10] bpo-40173: Fix test.support.import_helper.import_fresh_module() (GH-28654) (GH-28657)
serhiy-storchaka
webhook-mailer at python.org
Thu Sep 30 12:56:52 EDT 2021
https://github.com/python/cpython/commit/7873884d4730d7e637a968011b8958bd79fd3398
commit: 7873884d4730d7e637a968011b8958bd79fd3398
branch: 3.10
author: Serhiy Storchaka <storchaka at gmail.com>
committer: serhiy-storchaka <storchaka at gmail.com>
date: 2021-09-30T19:56:41+03:00
summary:
[3.10] bpo-40173: Fix test.support.import_helper.import_fresh_module() (GH-28654) (GH-28657)
* Work correctly if an additional fresh module imports another
additional fresh module which imports a blocked module.
* Raise ImportError if the specified module cannot be imported
while all additional fresh modules are successfully imported.
* Support blocking packages.
* Always restore the import state of fresh and blocked modules
and their submodules.
* Fix test_decimal and test_xml_etree which depended on an undesired
side effect of import_fresh_module().
(cherry picked from commit ec4d917a6a68824f1895f75d113add9410283da7)
files:
A Misc/NEWS.d/next/Tests/2021-09-30-16-54-39.bpo-40173.J_slCw.rst
M Lib/test/support/import_helper.py
M Lib/test/test_decimal.py
M Lib/test/test_xml_etree.py
diff --git a/Lib/test/support/import_helper.py b/Lib/test/support/import_helper.py
index 5d1e9406879cc..43ae31483420d 100644
--- a/Lib/test/support/import_helper.py
+++ b/Lib/test/support/import_helper.py
@@ -80,33 +80,13 @@ def import_module(name, deprecated=False, *, required_on=()):
raise unittest.SkipTest(str(msg))
-def _save_and_remove_module(name, orig_modules):
- """Helper function to save and remove a module from sys.modules
-
- Raise ImportError if the module can't be imported.
- """
- # try to import the module and raise an error if it can't be imported
- if name not in sys.modules:
- __import__(name)
- del sys.modules[name]
+def _save_and_remove_modules(names):
+ orig_modules = {}
+ prefixes = tuple(name + '.' for name in names)
for modname in list(sys.modules):
- if modname == name or modname.startswith(name + '.'):
- orig_modules[modname] = sys.modules[modname]
- del sys.modules[modname]
-
-
-def _save_and_block_module(name, orig_modules):
- """Helper function to save and block a module in sys.modules
-
- Return True if the module was in sys.modules, False otherwise.
- """
- saved = True
- try:
- orig_modules[name] = sys.modules[name]
- except KeyError:
- saved = False
- sys.modules[name] = None
- return saved
+ if modname in names or modname.startswith(prefixes):
+ orig_modules[modname] = sys.modules.pop(modname)
+ return orig_modules
def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
@@ -118,7 +98,8 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
this operation.
*fresh* is an iterable of additional module names that are also removed
- from the sys.modules cache before doing the import.
+ from the sys.modules cache before doing the import. If one of these
+ modules can't be imported, None is returned.
*blocked* is an iterable of module names that are replaced with None
in the module cache during the import to ensure that attempts to import
@@ -139,24 +120,24 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
with _ignore_deprecated_imports(deprecated):
# Keep track of modules saved for later restoration as well
# as those which just need a blocking entry removed
- orig_modules = {}
- names_to_remove = []
- _save_and_remove_module(name, orig_modules)
+ fresh = list(fresh)
+ blocked = list(blocked)
+ names = {name, *fresh, *blocked}
+ orig_modules = _save_and_remove_modules(names)
+ for modname in blocked:
+ sys.modules[modname] = None
+
try:
- for fresh_name in fresh:
- _save_and_remove_module(fresh_name, orig_modules)
- for blocked_name in blocked:
- if not _save_and_block_module(blocked_name, orig_modules):
- names_to_remove.append(blocked_name)
- fresh_module = importlib.import_module(name)
- except ImportError:
- fresh_module = None
+ # Return None when one of the "fresh" modules can not be imported.
+ try:
+ for modname in fresh:
+ __import__(modname)
+ except ImportError:
+ return None
+ return importlib.import_module(name)
finally:
- for orig_name, module in orig_modules.items():
- sys.modules[orig_name] = module
- for name_to_remove in names_to_remove:
- del sys.modules[name_to_remove]
- return fresh_module
+ _save_and_remove_modules(names)
+ sys.modules.update(orig_modules)
class CleanImport(object):
diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py
index 99263bb13b0d1..b6173a5ffec96 100644
--- a/Lib/test/test_decimal.py
+++ b/Lib/test/test_decimal.py
@@ -62,7 +62,7 @@
C = import_fresh_module('decimal', fresh=['_decimal'])
P = import_fresh_module('decimal', blocked=['_decimal'])
-orig_sys_decimal = sys.modules['decimal']
+import decimal as orig_sys_decimal
# fractions module must import the correct decimal module.
cfractions = import_fresh_module('fractions', fresh=['fractions'])
diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py
index c79b5462b931d..5a8824a78ffa4 100644
--- a/Lib/test/test_xml_etree.py
+++ b/Lib/test/test_xml_etree.py
@@ -26,7 +26,7 @@
from test import support
from test.support import os_helper
from test.support import warnings_helper
-from test.support import findfile, gc_collect, swap_attr
+from test.support import findfile, gc_collect, swap_attr, swap_item
from test.support.import_helper import import_fresh_module
from test.support.os_helper import TESTFN
@@ -167,12 +167,11 @@ def setUpClass(cls):
cls.modules = {pyET, ET}
def pickleRoundTrip(self, obj, name, dumper, loader, proto):
- save_m = sys.modules[name]
try:
- sys.modules[name] = dumper
- temp = pickle.dumps(obj, proto)
- sys.modules[name] = loader
- result = pickle.loads(temp)
+ with swap_item(sys.modules, name, dumper):
+ temp = pickle.dumps(obj, proto)
+ with swap_item(sys.modules, name, loader):
+ result = pickle.loads(temp)
except pickle.PicklingError as pe:
# pyET must be second, because pyET may be (equal to) ET.
human = dict([(ET, "cET"), (pyET, "pyET")])
@@ -180,8 +179,6 @@ def pickleRoundTrip(self, obj, name, dumper, loader, proto):
% (obj,
human.get(dumper, dumper),
human.get(loader, loader))) from pe
- finally:
- sys.modules[name] = save_m
return result
def assertEqualElements(self, alice, bob):
diff --git a/Misc/NEWS.d/next/Tests/2021-09-30-16-54-39.bpo-40173.J_slCw.rst b/Misc/NEWS.d/next/Tests/2021-09-30-16-54-39.bpo-40173.J_slCw.rst
new file mode 100644
index 0000000000000..21671473c16cc
--- /dev/null
+++ b/Misc/NEWS.d/next/Tests/2021-09-30-16-54-39.bpo-40173.J_slCw.rst
@@ -0,0 +1,2 @@
+Fix :func:`test.support.import_helper.import_fresh_module`.
+
More information about the Python-checkins
mailing list