diff options
Diffstat (limited to 'numpy/testing')
-rw-r--r-- | numpy/testing/tests/test_utils.py | 70 | ||||
-rw-r--r-- | numpy/testing/utils.py | 64 |
2 files changed, 132 insertions, 2 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 756ea997e..68075fc3d 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -4,7 +4,12 @@ import warnings import sys import numpy as np -from numpy.testing import * +from numpy.testing import ( + assert_equal, assert_array_equal, assert_almost_equal, + assert_array_almost_equal, build_err_msg, raises, assert_raises, + assert_warns, assert_no_warnings, assert_allclose, assert_approx_equal, + assert_array_almost_equal_nulp, assert_array_max_ulp, + clear_and_catch_warnings, run_module_suite) import unittest class _GenericTest(object): @@ -252,6 +257,7 @@ class TestArrayAlmostEqual(_GenericTest, unittest.TestCase): assert_array_almost_equal(b, a) assert_array_almost_equal(b, b) + class TestAlmostEqual(_GenericTest, unittest.TestCase): def setUp(self): self._assert_func = assert_almost_equal @@ -688,5 +694,67 @@ class TestULP(unittest.TestCase): self.assertRaises(AssertionError, lambda: assert_array_max_ulp(nan, nzero, maxulp=maxulp)) + +def assert_warn_len_equal(mod, n_in_context): + mod_warns = mod.__warningregistry__ + # Python 3.4 appears to clear any pre-existing warnings of the same type, + # when raising warnings inside a catch_warnings block. So, there is a + # warning generated by the tests within the context manager, but no + # previous warnings. + if 'version' in mod_warns: + assert_equal(len(mod_warns), 2) # including 'version' + else: + assert_equal(len(mod_warns), n_in_context) + + +def _get_fresh_mod(): + # Get this module, with warning registry empty + my_mod = sys.modules[__name__] + try: + my_mod.__warningregistry__.clear() + except AttributeError: + pass + return my_mod + + +def test_clear_and_catch_warnings(): + # Initial state of module, no warnings + my_mod = _get_fresh_mod() + assert_equal(getattr(my_mod, '__warningregistry__', {}), {}) + with clear_and_catch_warnings(modules=[my_mod]): + warnings.simplefilter('ignore') + warnings.warn('Some warning') + assert_equal(my_mod.__warningregistry__, {}) + # Without specified modules, don't clear warnings during context + with clear_and_catch_warnings(): + warnings.simplefilter('ignore') + warnings.warn('Some warning') + assert_warn_len_equal(my_mod, 1) + # Confirm that specifying module keeps old warning, does not add new + with clear_and_catch_warnings(modules=[my_mod]): + warnings.simplefilter('ignore') + warnings.warn('Another warning') + assert_warn_len_equal(my_mod, 1) + # Another warning, no module spec does add to warnings dict, except on + # Python 3.4 (see comments in `assert_warn_len_equal`) + with clear_and_catch_warnings(): + warnings.simplefilter('ignore') + warnings.warn('Another warning') + assert_warn_len_equal(my_mod, 2) + + +class my_cacw(clear_and_catch_warnings): + class_modules = (sys.modules[__name__],) + + +def test_clear_and_catch_warnings_inherit(): + # Test can subclass and add default modules + my_mod = _get_fresh_mod() + with my_cacw(): + warnings.simplefilter('ignore') + warnings.warn('Some warning') + assert_equal(my_mod.__warningregistry__, {}) + + if __name__ == '__main__': run_module_suite() diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 0971ebe94..4527a51d9 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -28,7 +28,7 @@ __all__ = ['assert_equal', 'assert_almost_equal', 'assert_approx_equal', 'raises', 'rand', 'rundocs', 'runstring', 'verbose', 'measure', 'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex', 'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings', - 'assert_allclose', 'IgnoreException'] + 'assert_allclose', 'IgnoreException', 'clear_and_catch_warnings'] verbose = 0 @@ -1718,3 +1718,65 @@ def tempdir(*args, **kwargs): tmpdir = mkdtemp(*args, **kwargs) yield tmpdir shutil.rmtree(tmpdir) + + +class clear_and_catch_warnings(warnings.catch_warnings): + """ Context manager that resets warning registry for catching warnings + + Warnings can be slippery, because, whenever a warning is triggered, Python + adds a ``__warningregistry__`` member to the *calling* module. This makes + it impossible to retrigger the warning in this module, whatever you put in + the warnings filters. This context manager accepts a sequence of `modules` + as a keyword argument to its constructor and: + + * stores and removes any ``__warningregistry__`` entries in given `modules` + on entry; + * resets ``__warningregistry__`` to its previous state on exit. + + This makes it possible to trigger any warning afresh inside the context + manager without disturbing the state of warnings outside. + + For compatibility with Python 3.0, please consider all arguments to be + keyword-only. + + Parameters + ---------- + record : bool, optional + Specifies whether warnings should be captured by a custom + implementation of ``warnings.showwarning()`` and be appended to a list + returned by the context manager. Otherwise None is returned by the + context manager. The objects appended to the list are arguments whose + attributes mirror the arguments to ``showwarning()``. + modules : sequence, optional + Sequence of modules for which to reset warnings registry on entry and + restore on exit + + Examples + -------- + >>> import warnings + >>> with clear_and_catch_warnings(modules=[np.core.fromnumeric]): + ... warnings.simplefilter('always') + ... # do something that raises a warning in np.core.fromnumeric + """ + class_modules = () + + def __init__(self, record=False, modules=()): + self.modules = set(modules).union(self.class_modules) + self._warnreg_copies = {} + super(clear_and_catch_warnings, self).__init__(record=record) + + def __enter__(self): + for mod in self.modules: + if hasattr(mod, '__warningregistry__'): + mod_reg = mod.__warningregistry__ + self._warnreg_copies[mod] = mod_reg.copy() + mod_reg.clear() + return super(clear_and_catch_warnings, self).__enter__() + + def __exit__(self, *exc_info): + super(clear_and_catch_warnings, self).__exit__(*exc_info) + for mod in self.modules: + if hasattr(mod, '__warningregistry__'): + mod.__warningregistry__.clear() + if mod in self._warnreg_copies: + mod.__warningregistry__.update(self._warnreg_copies[mod]) |