diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/mixins.py | 17 | ||||
-rw-r--r-- | tests/test_misc.py | 7 |
2 files changed, 11 insertions, 13 deletions
diff --git a/tests/mixins.py b/tests/mixins.py index 0638f336..95b2145a 100644 --- a/tests/mixins.py +++ b/tests/mixins.py @@ -15,6 +15,7 @@ import sys import pytest +from coverage.misc import SysModuleSaver from tests.helpers import change_dir, make_file, remove_files @@ -96,21 +97,11 @@ class SysPathModulesMixin: @pytest.fixture(autouse=True) def _module_saving(self): """Remove modules we imported during the test.""" - self._old_modules = list(sys.modules) + self._sys_module_saver = SysModuleSaver() try: yield finally: - self._cleanup_modules() - - def _cleanup_modules(self): - """Remove any new modules imported since our construction. - - This lets us import the same source files for more than one test, or - if called explicitly, within one test. - - """ - for m in [m for m in sys.modules if m not in self._old_modules]: - del sys.modules[m] + self._sys_module_saver.restore() def clean_local_file_imports(self): """Clean up the results of calls to `import_local_file`. @@ -120,7 +111,7 @@ class SysPathModulesMixin: """ # So that we can re-import files, clean them out first. - self._cleanup_modules() + self._sys_module_saver.restore() # Also have to clean out the .pyc file, since the timestamp # resolution is only one second, a changed file might not be diff --git a/tests/test_misc.py b/tests/test_misc.py index 077c2434..74002232 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -165,8 +165,15 @@ class ImportThirdPartyTest(CoverageTest): run_in_temp_dir = False def test_success(self): + # Make sure we don't have pytest in sys.modules before we start. + del sys.modules["pytest"] + # Import pytest mod = import_third_party("pytest") + # Yes, it's really pytest: assert mod.__name__ == "pytest" + print(dir(mod)) + assert all(hasattr(mod, name) for name in ["skip", "mark", "raises", "warns"]) + # But it's not in sys.modules: assert "pytest" not in sys.modules def test_failure(self): |