diff options
Diffstat (limited to 'Lib/test/test_importlib/util.py')
-rw-r--r-- | Lib/test/test_importlib/util.py | 105 |
1 files changed, 95 insertions, 10 deletions
diff --git a/Lib/test/test_importlib/util.py b/Lib/test/test_importlib/util.py index ef32f7d690..885cec3b29 100644 --- a/Lib/test/test_importlib/util.py +++ b/Lib/test/test_importlib/util.py @@ -1,9 +1,31 @@ from contextlib import contextmanager -import imp +from importlib import util, invalidate_caches import os.path from test import support import unittest import sys +import types + + +def import_importlib(module_name): + """Import a module from importlib both w/ and w/o _frozen_importlib.""" + fresh = ('importlib',) if '.' in module_name else () + frozen = support.import_fresh_module(module_name) + source = support.import_fresh_module(module_name, fresh=fresh, + blocked=('_frozen_importlib',)) + return frozen, source + + +def test_both(test_class, **kwargs): + frozen_tests = types.new_class('Frozen_'+test_class.__name__, + (test_class, unittest.TestCase)) + source_tests = types.new_class('Source_'+test_class.__name__, + (test_class, unittest.TestCase)) + frozen_tests.__module__ = source_tests.__module__ = test_class.__module__ + for attr, (frozen_value, source_value) in kwargs.items(): + setattr(frozen_tests, attr, frozen_value) + setattr(source_tests, attr, source_value) + return frozen_tests, source_tests CASE_INSENSITIVE_FS = True @@ -24,6 +46,13 @@ def case_insensitive_tests(test): "requires a case-insensitive filesystem")(test) +def submodule(parent, name, pkg_dir, content=''): + path = os.path.join(pkg_dir, name + '.py') + with open(path, 'w') as subfile: + subfile.write(content) + return '{}.{}'.format(parent, name), path + + @contextmanager def uncache(*names): """Uncache a module from sys.modules. @@ -49,6 +78,31 @@ def uncache(*names): except KeyError: pass + +@contextmanager +def temp_module(name, content='', *, pkg=False): + conflicts = [n for n in sys.modules if n.partition('.')[0] == name] + with support.temp_cwd(None) as cwd: + with uncache(name, *conflicts): + with support.DirsOnSysPath(cwd): + invalidate_caches() + + location = os.path.join(cwd, name) + if pkg: + modpath = os.path.join(location, '__init__.py') + os.mkdir(name) + else: + modpath = location + '.py' + if content is None: + # Make sure the module file gets created. + content = '' + if content is not None: + # not a namespace package + with open(modpath, 'w') as modfile: + modfile.write(content) + yield location + + @contextmanager def import_state(**kwargs): """Context manager to manage the various importers and stored state in the @@ -80,9 +134,9 @@ def import_state(**kwargs): setattr(sys, attr, value) -class mock_modules: +class _ImporterMock: - """A mock importer/loader.""" + """Base class to help with creating importer mocks.""" def __init__(self, *names, module_code={}): self.modules = {} @@ -98,7 +152,7 @@ class mock_modules: package = name.rsplit('.', 1)[0] else: package = import_name - module = imp.new_module(import_name) + module = types.ModuleType(import_name) module.__loader__ = self module.__file__ = '<mock __file__>' module.__package__ = package @@ -112,6 +166,19 @@ class mock_modules: def __getitem__(self, name): return self.modules[name] + def __enter__(self): + self._uncache = uncache(*self.modules.keys()) + self._uncache.__enter__() + return self + + def __exit__(self, *exc_info): + self._uncache.__exit__(None, None, None) + + +class mock_modules(_ImporterMock): + + """Importer mock using PEP 302 APIs.""" + def find_module(self, fullname, path=None): if fullname not in self.modules: return None @@ -131,10 +198,28 @@ class mock_modules: raise return self.modules[fullname] - def __enter__(self): - self._uncache = uncache(*self.modules.keys()) - self._uncache.__enter__() - return self +class mock_spec(_ImporterMock): - def __exit__(self, *exc_info): - self._uncache.__exit__(None, None, None) + """Importer mock using PEP 451 APIs.""" + + def find_spec(self, fullname, path=None, parent=None): + try: + module = self.modules[fullname] + except KeyError: + return None + is_package = hasattr(module, '__path__') + spec = util.spec_from_file_location( + fullname, module.__file__, loader=self, + submodule_search_locations=getattr(module, '__path__', None)) + return spec + + def create_module(self, spec): + if spec.name not in self.modules: + raise ImportError + return self.modules[spec.name] + + def exec_module(self, module): + try: + self.module_code[module.__spec__.name]() + except KeyError: + pass |