summaryrefslogtreecommitdiff
path: root/Lib/importlib/test
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/importlib/test')
-rw-r--r--Lib/importlib/test/__main__.py19
-rw-r--r--Lib/importlib/test/benchmark.py157
-rw-r--r--Lib/importlib/test/extension/test_case_sensitivity.py6
-rw-r--r--Lib/importlib/test/import_/test___package__.py2
-rw-r--r--Lib/importlib/test/import_/test_api.py7
-rw-r--r--Lib/importlib/test/import_/test_path.py10
-rw-r--r--Lib/importlib/test/regrtest.py7
-rw-r--r--Lib/importlib/test/source/test_abc_loader.py17
-rw-r--r--Lib/importlib/test/source/test_case_sensitivity.py6
-rw-r--r--Lib/importlib/test/source/test_file_loader.py44
-rw-r--r--Lib/importlib/test/source/test_finder.py7
-rw-r--r--Lib/importlib/test/test_api.py28
-rw-r--r--Lib/importlib/test/test_util.py10
13 files changed, 239 insertions, 81 deletions
diff --git a/Lib/importlib/test/__main__.py b/Lib/importlib/test/__main__.py
index decc53d8c5..92171b25ca 100644
--- a/Lib/importlib/test/__main__.py
+++ b/Lib/importlib/test/__main__.py
@@ -4,26 +4,27 @@ Specifying the ``--builtin`` flag will run tests, where applicable, with
builtins.__import__ instead of importlib.__import__.
"""
-import importlib
from importlib.test.import_ import util
import os.path
from test.support import run_unittest
-import sys
import unittest
def test_main():
- if '__pycache__' in __file__:
- parts = __file__.split(os.path.sep)
- start_dir = sep.join(parts[:-2])
- else:
- start_dir = os.path.dirname(__file__)
+ start_dir = os.path.dirname(__file__)
top_dir = os.path.dirname(os.path.dirname(start_dir))
test_loader = unittest.TestLoader()
- if '--builtin' in sys.argv:
- util.using___import__ = True
run_unittest(test_loader.discover(start_dir, top_level_dir=top_dir))
if __name__ == '__main__':
+ import argparse
+
+ parser = argparse.ArgumentParser(description='Execute the importlib test '
+ 'suite')
+ parser.add_argument('-b', '--builtin', action='store_true', default=False,
+ help='use builtins.__import__() instead of importlib')
+ args = parser.parse_args()
+ if args.builtin:
+ util.using___import__ = True
test_main()
diff --git a/Lib/importlib/test/benchmark.py b/Lib/importlib/test/benchmark.py
index b5de6c6b01..87b1775f66 100644
--- a/Lib/importlib/test/benchmark.py
+++ b/Lib/importlib/test/benchmark.py
@@ -9,9 +9,11 @@ from .source import util as source_util
import decimal
import imp
import importlib
+import json
import os
import py_compile
import sys
+import tabnanny
import timeit
@@ -59,7 +61,7 @@ def builtin_mod(seconds, repeat):
def source_wo_bytecode(seconds, repeat):
- """Source w/o bytecode: simple"""
+ """Source w/o bytecode: small"""
sys.dont_write_bytecode = True
try:
name = '__importlib_test_benchmark__'
@@ -73,23 +75,30 @@ def source_wo_bytecode(seconds, repeat):
sys.dont_write_bytecode = False
-def decimal_wo_bytecode(seconds, repeat):
- """Source w/o bytecode: decimal"""
- name = 'decimal'
- decimal_bytecode = imp.cache_from_source(decimal.__file__)
- if os.path.exists(decimal_bytecode):
- os.unlink(decimal_bytecode)
- sys.dont_write_bytecode = True
- try:
- for result in bench(name, lambda: sys.modules.pop(name), repeat=repeat,
- seconds=seconds):
- yield result
- finally:
- sys.dont_write_bytecode = False
+def _wo_bytecode(module):
+ name = module.__name__
+ def benchmark_wo_bytecode(seconds, repeat):
+ """Source w/o bytecode: {}"""
+ bytecode_path = imp.cache_from_source(module.__file__)
+ if os.path.exists(bytecode_path):
+ os.unlink(bytecode_path)
+ sys.dont_write_bytecode = True
+ try:
+ for result in bench(name, lambda: sys.modules.pop(name),
+ repeat=repeat, seconds=seconds):
+ yield result
+ finally:
+ sys.dont_write_bytecode = False
+
+ benchmark_wo_bytecode.__doc__ = benchmark_wo_bytecode.__doc__.format(name)
+ return benchmark_wo_bytecode
+
+tabnanny_wo_bytecode = _wo_bytecode(tabnanny)
+decimal_wo_bytecode = _wo_bytecode(decimal)
def source_writing_bytecode(seconds, repeat):
- """Source writing bytecode: simple"""
+ """Source writing bytecode: small"""
assert not sys.dont_write_bytecode
name = '__importlib_test_benchmark__'
with source_util.create_modules(name) as mapping:
@@ -101,19 +110,27 @@ def source_writing_bytecode(seconds, repeat):
yield result
-def decimal_writing_bytecode(seconds, repeat):
- """Source writing bytecode: decimal"""
- assert not sys.dont_write_bytecode
- name = 'decimal'
- def cleanup():
- sys.modules.pop(name)
- os.unlink(imp.cache_from_source(decimal.__file__))
- for result in bench(name, cleanup, repeat=repeat, seconds=seconds):
- yield result
+def _writing_bytecode(module):
+ name = module.__name__
+ def writing_bytecode_benchmark(seconds, repeat):
+ """Source writing bytecode: {}"""
+ assert not sys.dont_write_bytecode
+ def cleanup():
+ sys.modules.pop(name)
+ os.unlink(imp.cache_from_source(module.__file__))
+ for result in bench(name, cleanup, repeat=repeat, seconds=seconds):
+ yield result
+
+ writing_bytecode_benchmark.__doc__ = (
+ writing_bytecode_benchmark.__doc__.format(name))
+ return writing_bytecode_benchmark
+
+tabnanny_writing_bytecode = _writing_bytecode(tabnanny)
+decimal_writing_bytecode = _writing_bytecode(decimal)
def source_using_bytecode(seconds, repeat):
- """Bytecode w/ source: simple"""
+ """Source w/ bytecode: small"""
name = '__importlib_test_benchmark__'
with source_util.create_modules(name) as mapping:
py_compile.compile(mapping[name])
@@ -123,27 +140,56 @@ def source_using_bytecode(seconds, repeat):
yield result
-def decimal_using_bytecode(seconds, repeat):
- """Bytecode w/ source: decimal"""
- name = 'decimal'
- py_compile.compile(decimal.__file__)
- for result in bench(name, lambda: sys.modules.pop(name), repeat=repeat,
- seconds=seconds):
- yield result
+def _using_bytecode(module):
+ name = module.__name__
+ def using_bytecode_benchmark(seconds, repeat):
+ """Source w/ bytecode: {}"""
+ py_compile.compile(module.__file__)
+ for result in bench(name, lambda: sys.modules.pop(name), repeat=repeat,
+ seconds=seconds):
+ yield result
+ using_bytecode_benchmark.__doc__ = (
+ using_bytecode_benchmark.__doc__.format(name))
+ return using_bytecode_benchmark
-def main(import_):
+tabnanny_using_bytecode = _using_bytecode(tabnanny)
+decimal_using_bytecode = _using_bytecode(decimal)
+
+
+def main(import_, options):
+ if options.source_file:
+ with options.source_file:
+ prev_results = json.load(options.source_file)
+ else:
+ prev_results = {}
__builtins__.__import__ = import_
benchmarks = (from_cache, builtin_mod,
- source_using_bytecode, source_wo_bytecode,
source_writing_bytecode,
- decimal_using_bytecode, decimal_writing_bytecode,
- decimal_wo_bytecode,)
+ source_wo_bytecode, source_using_bytecode,
+ tabnanny_writing_bytecode,
+ tabnanny_wo_bytecode, tabnanny_using_bytecode,
+ decimal_writing_bytecode,
+ decimal_wo_bytecode, decimal_using_bytecode,
+ )
+ if options.benchmark:
+ for b in benchmarks:
+ if b.__doc__ == options.benchmark:
+ benchmarks = [b]
+ break
+ else:
+ print('Unknown benchmark: {!r}'.format(options.benchmark,
+ file=sys.stderr))
+ sys.exit(1)
seconds = 1
seconds_plural = 's' if seconds > 1 else ''
repeat = 3
- header = "Measuring imports/second over {} second{}, best out of {}\n"
- print(header.format(seconds, seconds_plural, repeat))
+ header = ('Measuring imports/second over {} second{}, best out of {}\n'
+ 'Entire benchmark run should take about {} seconds\n'
+ 'Using {!r} as __import__\n')
+ print(header.format(seconds, seconds_plural, repeat,
+ len(benchmarks) * seconds * repeat, __import__))
+ new_results = {}
for benchmark in benchmarks:
print(benchmark.__doc__, "[", end=' ')
sys.stdout.flush()
@@ -154,19 +200,40 @@ def main(import_):
sys.stdout.flush()
assert not sys.dont_write_bytecode
print("]", "best is", format(max(results), ',d'))
+ new_results[benchmark.__doc__] = results
+ if prev_results:
+ print('\n\nComparing new vs. old\n')
+ for benchmark in benchmarks:
+ benchmark_name = benchmark.__doc__
+ old_result = max(prev_results[benchmark_name])
+ new_result = max(new_results[benchmark_name])
+ result = '{:,d} vs. {:,d} ({:%})'.format(new_result,
+ old_result,
+ new_result/old_result)
+ print(benchmark_name, ':', result)
+ if options.dest_file:
+ with options.dest_file:
+ json.dump(new_results, options.dest_file, indent=2)
if __name__ == '__main__':
- import optparse
+ import argparse
- parser = optparse.OptionParser()
- parser.add_option('-b', '--builtin', dest='builtin', action='store_true',
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-b', '--builtin', dest='builtin', action='store_true',
default=False, help="use the built-in __import__")
- options, args = parser.parse_args()
- if args:
- raise RuntimeError("unrecognized args: {}".format(args))
+ parser.add_argument('-r', '--read', dest='source_file',
+ type=argparse.FileType('r'),
+ help='file to read benchmark data from to compare '
+ 'against')
+ parser.add_argument('-w', '--write', dest='dest_file',
+ type=argparse.FileType('w'),
+ help='file to write benchmark data to')
+ parser.add_argument('--benchmark', dest='benchmark',
+ help='specific benchmark to run')
+ options = parser.parse_args()
import_ = __import__
if not options.builtin:
import_ = importlib.__import__
- main(import_)
+ main(import_, options)
diff --git a/Lib/importlib/test/extension/test_case_sensitivity.py b/Lib/importlib/test/extension/test_case_sensitivity.py
index e062fb6597..add830dedd 100644
--- a/Lib/importlib/test/extension/test_case_sensitivity.py
+++ b/Lib/importlib/test/extension/test_case_sensitivity.py
@@ -20,12 +20,18 @@ class ExtensionModuleCaseSensitivityTest(unittest.TestCase):
def test_case_sensitive(self):
with support.EnvironmentVarGuard() as env:
env.unset('PYTHONCASEOK')
+ if b'PYTHONCASEOK' in _bootstrap._os.environ:
+ self.skipTest('os.environ changes not reflected in '
+ '_os.environ')
loader = self.find_module()
self.assertIsNone(loader)
def test_case_insensitivity(self):
with support.EnvironmentVarGuard() as env:
env.set('PYTHONCASEOK', '1')
+ if b'PYTHONCASEOK' not in _bootstrap._os.environ:
+ self.skipTest('os.environ changes not reflected in '
+ '_os.environ')
loader = self.find_module()
self.assertTrue(hasattr(loader, 'load_module'))
diff --git a/Lib/importlib/test/import_/test___package__.py b/Lib/importlib/test/import_/test___package__.py
index 5056ae59cc..783cde1729 100644
--- a/Lib/importlib/test/import_/test___package__.py
+++ b/Lib/importlib/test/import_/test___package__.py
@@ -67,7 +67,7 @@ class Using__package__(unittest.TestCase):
def test_bunk__package__(self):
globals = {'__package__': 42}
- with self.assertRaises(ValueError):
+ with self.assertRaises(TypeError):
import_util.import_('', globals, {}, ['relimport'], 1)
diff --git a/Lib/importlib/test/import_/test_api.py b/Lib/importlib/test/import_/test_api.py
index 9075d42759..2fa1f90954 100644
--- a/Lib/importlib/test/import_/test_api.py
+++ b/Lib/importlib/test/import_/test_api.py
@@ -12,6 +12,13 @@ class APITest(unittest.TestCase):
with self.assertRaises(TypeError):
util.import_(42)
+ def test_negative_level(self):
+ # Raise ValueError when a negative level is specified.
+ # PEP 328 did away with sys.module None entries and the ambiguity of
+ # absolute/relative imports.
+ with self.assertRaises(ValueError):
+ util.import_('os', globals(), level=-1)
+
def test_main():
from test.support import run_unittest
diff --git a/Lib/importlib/test/import_/test_path.py b/Lib/importlib/test/import_/test_path.py
index 2faa23174b..5713319612 100644
--- a/Lib/importlib/test/import_/test_path.py
+++ b/Lib/importlib/test/import_/test_path.py
@@ -73,6 +73,16 @@ class FinderTests(unittest.TestCase):
loader = machinery.PathFinder.find_module(module)
self.assertTrue(loader is importer)
+ def test_path_importer_cache_empty_string(self):
+ # The empty string should create a finder using the cwd.
+ path = ''
+ module = '<test module>'
+ importer = util.mock_modules(module)
+ hook = import_util.mock_path_hook(os.curdir, importer=importer)
+ with util.import_state(path=[path], path_hooks=[hook]):
+ loader = machinery.PathFinder.find_module(module)
+ self.assertIs(loader, importer)
+ self.assertIn(os.curdir, sys.path_importer_cache)
class DefaultPathFinderTests(unittest.TestCase):
diff --git a/Lib/importlib/test/regrtest.py b/Lib/importlib/test/regrtest.py
index b103ae7d0e..dc0eb97022 100644
--- a/Lib/importlib/test/regrtest.py
+++ b/Lib/importlib/test/regrtest.py
@@ -5,13 +5,6 @@ invalidates are automatically skipped if the entire test suite is run.
Otherwise all command-line options valid for test.regrtest are also valid for
this script.
-XXX FAILING
- * test_import
- - test_incorrect_code_name
- file name differing between __file__ and co_filename (r68360 on trunk)
- - test_import_by_filename
- exception for trying to import by file name does not match
-
"""
import importlib
import sys
diff --git a/Lib/importlib/test/source/test_abc_loader.py b/Lib/importlib/test/source/test_abc_loader.py
index 32459074a0..01acda4b9a 100644
--- a/Lib/importlib/test/source/test_abc_loader.py
+++ b/Lib/importlib/test/source/test_abc_loader.py
@@ -40,8 +40,10 @@ class SourceLoaderMock(SourceOnlyLoaderMock):
def __init__(self, path, magic=imp.get_magic()):
super().__init__(path)
self.bytecode_path = imp.cache_from_source(self.path)
+ self.source_size = len(self.source)
data = bytearray(magic)
- data.extend(marshal._w_long(self.source_mtime))
+ data.extend(importlib._w_long(self.source_mtime))
+ data.extend(importlib._w_long(self.source_size))
code_object = compile(self.source, self.path, 'exec',
dont_inherit=True)
data.extend(marshal.dumps(code_object))
@@ -56,9 +58,9 @@ class SourceLoaderMock(SourceOnlyLoaderMock):
else:
raise IOError
- def path_mtime(self, path):
+ def path_stats(self, path):
assert path == self.path
- return self.source_mtime
+ return {'mtime': self.source_mtime, 'size': self.source_size}
def set_data(self, path, data):
self.written[path] = bytes(data)
@@ -102,7 +104,7 @@ class PyLoaderMock(abc.PyLoader):
warnings.simplefilter("always")
path = super().get_filename(name)
assert len(w) == 1
- assert issubclass(w[0].category, PendingDeprecationWarning)
+ assert issubclass(w[0].category, DeprecationWarning)
return path
@@ -198,7 +200,7 @@ class PyPycLoaderMock(abc.PyPycLoader, PyLoaderMock):
warnings.simplefilter("always")
code_object = super().get_code(name)
assert len(w) == 1
- assert issubclass(w[0].category, PendingDeprecationWarning)
+ assert issubclass(w[0].category, DeprecationWarning)
return code_object
class PyLoaderTests(testing_abc.LoaderTests):
@@ -656,7 +658,8 @@ class SourceLoaderBytecodeTests(SourceLoaderTestHarness):
if bytecode_written:
self.assertIn(self.cached, self.loader.written)
data = bytearray(imp.get_magic())
- data.extend(marshal._w_long(self.loader.source_mtime))
+ data.extend(importlib._w_long(self.loader.source_mtime))
+ data.extend(importlib._w_long(self.loader.source_size))
data.extend(marshal.dumps(code_object))
self.assertEqual(self.loader.written[self.cached], bytes(data))
@@ -847,7 +850,7 @@ class AbstractMethodImplTests(unittest.TestCase):
# Required abstractmethods.
self.raises_NotImplementedError(ins, 'get_filename', 'get_data')
# Optional abstractmethods.
- self.raises_NotImplementedError(ins,'path_mtime', 'set_data')
+ self.raises_NotImplementedError(ins,'path_stats', 'set_data')
def test_PyLoader(self):
self.raises_NotImplementedError(self.PyLoader(), 'source_path',
diff --git a/Lib/importlib/test/source/test_case_sensitivity.py b/Lib/importlib/test/source/test_case_sensitivity.py
index 73777de4ba..569f516d5a 100644
--- a/Lib/importlib/test/source/test_case_sensitivity.py
+++ b/Lib/importlib/test/source/test_case_sensitivity.py
@@ -37,6 +37,9 @@ class CaseSensitivityTest(unittest.TestCase):
def test_sensitive(self):
with test_support.EnvironmentVarGuard() as env:
env.unset('PYTHONCASEOK')
+ if b'PYTHONCASEOK' in _bootstrap._os.environ:
+ self.skipTest('os.environ changes not reflected in '
+ '_os.environ')
sensitive, insensitive = self.sensitivity_test()
self.assertTrue(hasattr(sensitive, 'load_module'))
self.assertIn(self.name, sensitive.get_filename(self.name))
@@ -45,6 +48,9 @@ class CaseSensitivityTest(unittest.TestCase):
def test_insensitive(self):
with test_support.EnvironmentVarGuard() as env:
env.set('PYTHONCASEOK', '1')
+ if b'PYTHONCASEOK' not in _bootstrap._os.environ:
+ self.skipTest('os.environ changes not reflected in '
+ '_os.environ')
sensitive, insensitive = self.sensitivity_test()
self.assertTrue(hasattr(sensitive, 'load_module'))
self.assertIn(self.name, sensitive.get_filename(self.name))
diff --git a/Lib/importlib/test/source/test_file_loader.py b/Lib/importlib/test/source/test_file_loader.py
index c7a7d8fbca..21e718f7d5 100644
--- a/Lib/importlib/test/source/test_file_loader.py
+++ b/Lib/importlib/test/source/test_file_loader.py
@@ -71,11 +71,6 @@ class SimpleTest(unittest.TestCase):
module_dict_id = id(module.__dict__)
with open(mapping['_temp'], 'w') as file:
file.write("testing_var = 42\n")
- # For filesystems where the mtime is only to a second granularity,
- # everything that has happened above can be too fast;
- # force an mtime on the source that is guaranteed to be different
- # than the original mtime.
- loader.path_mtime = self.fake_mtime(loader.path_mtime)
module = loader.load_module('_temp')
self.assertTrue('testing_var' in module.__dict__,
"'testing_var' not in "
@@ -215,10 +210,17 @@ class BadBytecodeTest(unittest.TestCase):
del_source=del_source)
test('_temp', mapping, bc_path)
+ def _test_partial_size(self, test, *, del_source=False):
+ with source_util.create_modules('_temp') as mapping:
+ bc_path = self.manipulate_bytecode('_temp', mapping,
+ lambda bc: bc[:11],
+ del_source=del_source)
+ test('_temp', mapping, bc_path)
+
def _test_no_marshal(self, *, del_source=False):
with source_util.create_modules('_temp') as mapping:
bc_path = self.manipulate_bytecode('_temp', mapping,
- lambda bc: bc[:8],
+ lambda bc: bc[:12],
del_source=del_source)
file_path = mapping['_temp'] if not del_source else bc_path
with self.assertRaises(EOFError):
@@ -227,7 +229,7 @@ class BadBytecodeTest(unittest.TestCase):
def _test_non_code_marshal(self, *, del_source=False):
with source_util.create_modules('_temp') as mapping:
bytecode_path = self.manipulate_bytecode('_temp', mapping,
- lambda bc: bc[:8] + marshal.dumps(b'abcd'),
+ lambda bc: bc[:12] + marshal.dumps(b'abcd'),
del_source=del_source)
file_path = mapping['_temp'] if not del_source else bytecode_path
with self.assertRaises(ImportError):
@@ -236,7 +238,7 @@ class BadBytecodeTest(unittest.TestCase):
def _test_bad_marshal(self, *, del_source=False):
with source_util.create_modules('_temp') as mapping:
bytecode_path = self.manipulate_bytecode('_temp', mapping,
- lambda bc: bc[:8] + b'<test>',
+ lambda bc: bc[:12] + b'<test>',
del_source=del_source)
file_path = mapping['_temp'] if not del_source else bytecode_path
with self.assertRaises(EOFError):
@@ -260,7 +262,7 @@ class SourceLoaderBadBytecodeTest(BadBytecodeTest):
def test(name, mapping, bytecode_path):
self.import_(mapping[name], name)
with open(bytecode_path, 'rb') as file:
- self.assertGreater(len(file.read()), 8)
+ self.assertGreater(len(file.read()), 12)
self._test_empty_file(test)
@@ -268,7 +270,7 @@ class SourceLoaderBadBytecodeTest(BadBytecodeTest):
def test(name, mapping, bytecode_path):
self.import_(mapping[name], name)
with open(bytecode_path, 'rb') as file:
- self.assertGreater(len(file.read()), 8)
+ self.assertGreater(len(file.read()), 12)
self._test_partial_magic(test)
@@ -279,7 +281,7 @@ class SourceLoaderBadBytecodeTest(BadBytecodeTest):
def test(name, mapping, bytecode_path):
self.import_(mapping[name], name)
with open(bytecode_path, 'rb') as file:
- self.assertGreater(len(file.read()), 8)
+ self.assertGreater(len(file.read()), 12)
self._test_magic_only(test)
@@ -301,11 +303,22 @@ class SourceLoaderBadBytecodeTest(BadBytecodeTest):
def test(name, mapping, bc_path):
self.import_(mapping[name], name)
with open(bc_path, 'rb') as file:
- self.assertGreater(len(file.read()), 8)
+ self.assertGreater(len(file.read()), 12)
self._test_partial_timestamp(test)
@source_util.writes_bytecode_files
+ def test_partial_size(self):
+ # When the size is partial, regenerate the .pyc, else
+ # raise EOFError.
+ def test(name, mapping, bc_path):
+ self.import_(mapping[name], name)
+ with open(bc_path, 'rb') as file:
+ self.assertGreater(len(file.read()), 12)
+
+ self._test_partial_size(test)
+
+ @source_util.writes_bytecode_files
def test_no_marshal(self):
# When there is only the magic number and timestamp, raise EOFError.
self._test_no_marshal()
@@ -400,6 +413,13 @@ class SourcelessLoaderBadBytecodeTest(BadBytecodeTest):
self._test_partial_timestamp(test, del_source=True)
+ def test_partial_size(self):
+ def test(name, mapping, bytecode_path):
+ with self.assertRaises(EOFError):
+ self.import_(bytecode_path, name)
+
+ self._test_partial_size(test, del_source=True)
+
def test_no_marshal(self):
self._test_no_marshal(del_source=True)
diff --git a/Lib/importlib/test/source/test_finder.py b/Lib/importlib/test/source/test_finder.py
index 7b9088da0c..68e9ae71a4 100644
--- a/Lib/importlib/test/source/test_finder.py
+++ b/Lib/importlib/test/source/test_finder.py
@@ -143,6 +143,13 @@ class FinderTests(abc.FinderTests):
finally:
os.unlink('mod.py')
+ def test_invalidate_caches(self):
+ # invalidate_caches() should reset the mtime.
+ finder = _bootstrap._FileFinder('', _bootstrap._SourceFinderDetails())
+ finder._path_mtime = 42
+ finder.invalidate_caches()
+ self.assertEqual(finder._path_mtime, -1)
+
def test_main():
from test.support import run_unittest
diff --git a/Lib/importlib/test/test_api.py b/Lib/importlib/test/test_api.py
index a151626de7..cc147c200b 100644
--- a/Lib/importlib/test/test_api.py
+++ b/Lib/importlib/test/test_api.py
@@ -84,6 +84,34 @@ class ImportModuleTests(unittest.TestCase):
importlib.import_module('a.b')
self.assertEqual(b_load_count, 1)
+
+class InvalidateCacheTests(unittest.TestCase):
+
+ def test_method_called(self):
+ # If defined the method should be called.
+ class InvalidatingNullFinder:
+ def __init__(self, *ignored):
+ self.called = False
+ def find_module(self, *args):
+ return None
+ def invalidate_caches(self):
+ self.called = True
+
+ key = 'gobledeegook'
+ ins = InvalidatingNullFinder()
+ sys.path_importer_cache[key] = ins
+ self.addCleanup(lambda: sys.path_importer_cache.__delitem__(key))
+ importlib.invalidate_caches()
+ self.assertTrue(ins.called)
+
+ def test_method_lacking(self):
+ # There should be no issues if the method is not defined.
+ key = 'gobbledeegook'
+ sys.path_importer_cache[key] = imp.NullImporter('abc')
+ self.addCleanup(lambda: sys.path_importer_cache.__delitem__(key))
+ importlib.invalidate_caches() # Shouldn't trigger an exception.
+
+
def test_main():
from test.support import run_unittest
run_unittest(ImportModuleTests)
diff --git a/Lib/importlib/test/test_util.py b/Lib/importlib/test/test_util.py
index 602447f09e..c7cdad14d7 100644
--- a/Lib/importlib/test/test_util.py
+++ b/Lib/importlib/test/test_util.py
@@ -59,6 +59,11 @@ class ModuleForLoaderTests(unittest.TestCase):
self.raise_exception(name)
self.assertIs(module, sys.modules[name])
+ def test_decorator_attrs(self):
+ def fxn(self, module): pass
+ wrapped = util.module_for_loader(fxn)
+ self.assertEqual(wrapped.__name__, fxn.__name__)
+ self.assertEqual(wrapped.__qualname__, fxn.__qualname__)
class SetPackageTests(unittest.TestCase):
@@ -108,6 +113,11 @@ class SetPackageTests(unittest.TestCase):
module.__package__ = value
self.verify(module, value)
+ def test_decorator_attrs(self):
+ def fxn(module): pass
+ wrapped = util.set_package(fxn)
+ self.assertEqual(wrapped.__name__, fxn.__name__)
+ self.assertEqual(wrapped.__qualname__, fxn.__qualname__)
def test_main():
from test import support