summaryrefslogtreecommitdiff
path: root/numpy/testing
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/testing')
-rw-r--r--numpy/testing/tests/test_utils.py70
-rw-r--r--numpy/testing/utils.py64
2 files changed, 132 insertions, 2 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 756ea997e..68075fc3d 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -4,7 +4,12 @@ import warnings
import sys
import numpy as np
-from numpy.testing import *
+from numpy.testing import (
+ assert_equal, assert_array_equal, assert_almost_equal,
+ assert_array_almost_equal, build_err_msg, raises, assert_raises,
+ assert_warns, assert_no_warnings, assert_allclose, assert_approx_equal,
+ assert_array_almost_equal_nulp, assert_array_max_ulp,
+ clear_and_catch_warnings, run_module_suite)
import unittest
class _GenericTest(object):
@@ -252,6 +257,7 @@ class TestArrayAlmostEqual(_GenericTest, unittest.TestCase):
assert_array_almost_equal(b, a)
assert_array_almost_equal(b, b)
+
class TestAlmostEqual(_GenericTest, unittest.TestCase):
def setUp(self):
self._assert_func = assert_almost_equal
@@ -688,5 +694,67 @@ class TestULP(unittest.TestCase):
self.assertRaises(AssertionError,
lambda: assert_array_max_ulp(nan, nzero,
maxulp=maxulp))
+
+def assert_warn_len_equal(mod, n_in_context):
+ mod_warns = mod.__warningregistry__
+ # Python 3.4 appears to clear any pre-existing warnings of the same type,
+ # when raising warnings inside a catch_warnings block. So, there is a
+ # warning generated by the tests within the context manager, but no
+ # previous warnings.
+ if 'version' in mod_warns:
+ assert_equal(len(mod_warns), 2) # including 'version'
+ else:
+ assert_equal(len(mod_warns), n_in_context)
+
+
+def _get_fresh_mod():
+ # Get this module, with warning registry empty
+ my_mod = sys.modules[__name__]
+ try:
+ my_mod.__warningregistry__.clear()
+ except AttributeError:
+ pass
+ return my_mod
+
+
+def test_clear_and_catch_warnings():
+ # Initial state of module, no warnings
+ my_mod = _get_fresh_mod()
+ assert_equal(getattr(my_mod, '__warningregistry__', {}), {})
+ with clear_and_catch_warnings(modules=[my_mod]):
+ warnings.simplefilter('ignore')
+ warnings.warn('Some warning')
+ assert_equal(my_mod.__warningregistry__, {})
+ # Without specified modules, don't clear warnings during context
+ with clear_and_catch_warnings():
+ warnings.simplefilter('ignore')
+ warnings.warn('Some warning')
+ assert_warn_len_equal(my_mod, 1)
+ # Confirm that specifying module keeps old warning, does not add new
+ with clear_and_catch_warnings(modules=[my_mod]):
+ warnings.simplefilter('ignore')
+ warnings.warn('Another warning')
+ assert_warn_len_equal(my_mod, 1)
+ # Another warning, no module spec does add to warnings dict, except on
+ # Python 3.4 (see comments in `assert_warn_len_equal`)
+ with clear_and_catch_warnings():
+ warnings.simplefilter('ignore')
+ warnings.warn('Another warning')
+ assert_warn_len_equal(my_mod, 2)
+
+
+class my_cacw(clear_and_catch_warnings):
+ class_modules = (sys.modules[__name__],)
+
+
+def test_clear_and_catch_warnings_inherit():
+ # Test can subclass and add default modules
+ my_mod = _get_fresh_mod()
+ with my_cacw():
+ warnings.simplefilter('ignore')
+ warnings.warn('Some warning')
+ assert_equal(my_mod.__warningregistry__, {})
+
+
if __name__ == '__main__':
run_module_suite()
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index 0971ebe94..4527a51d9 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -28,7 +28,7 @@ __all__ = ['assert_equal', 'assert_almost_equal', 'assert_approx_equal',
'raises', 'rand', 'rundocs', 'runstring', 'verbose', 'measure',
'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex',
'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings',
- 'assert_allclose', 'IgnoreException']
+ 'assert_allclose', 'IgnoreException', 'clear_and_catch_warnings']
verbose = 0
@@ -1718,3 +1718,65 @@ def tempdir(*args, **kwargs):
tmpdir = mkdtemp(*args, **kwargs)
yield tmpdir
shutil.rmtree(tmpdir)
+
+
+class clear_and_catch_warnings(warnings.catch_warnings):
+ """ Context manager that resets warning registry for catching warnings
+
+ Warnings can be slippery, because, whenever a warning is triggered, Python
+ adds a ``__warningregistry__`` member to the *calling* module. This makes
+ it impossible to retrigger the warning in this module, whatever you put in
+ the warnings filters. This context manager accepts a sequence of `modules`
+ as a keyword argument to its constructor and:
+
+ * stores and removes any ``__warningregistry__`` entries in given `modules`
+ on entry;
+ * resets ``__warningregistry__`` to its previous state on exit.
+
+ This makes it possible to trigger any warning afresh inside the context
+ manager without disturbing the state of warnings outside.
+
+ For compatibility with Python 3.0, please consider all arguments to be
+ keyword-only.
+
+ Parameters
+ ----------
+ record : bool, optional
+ Specifies whether warnings should be captured by a custom
+ implementation of ``warnings.showwarning()`` and be appended to a list
+ returned by the context manager. Otherwise None is returned by the
+ context manager. The objects appended to the list are arguments whose
+ attributes mirror the arguments to ``showwarning()``.
+ modules : sequence, optional
+ Sequence of modules for which to reset warnings registry on entry and
+ restore on exit
+
+ Examples
+ --------
+ >>> import warnings
+ >>> with clear_and_catch_warnings(modules=[np.core.fromnumeric]):
+ ... warnings.simplefilter('always')
+ ... # do something that raises a warning in np.core.fromnumeric
+ """
+ class_modules = ()
+
+ def __init__(self, record=False, modules=()):
+ self.modules = set(modules).union(self.class_modules)
+ self._warnreg_copies = {}
+ super(clear_and_catch_warnings, self).__init__(record=record)
+
+ def __enter__(self):
+ for mod in self.modules:
+ if hasattr(mod, '__warningregistry__'):
+ mod_reg = mod.__warningregistry__
+ self._warnreg_copies[mod] = mod_reg.copy()
+ mod_reg.clear()
+ return super(clear_and_catch_warnings, self).__enter__()
+
+ def __exit__(self, *exc_info):
+ super(clear_and_catch_warnings, self).__exit__(*exc_info)
+ for mod in self.modules:
+ if hasattr(mod, '__warningregistry__'):
+ mod.__warningregistry__.clear()
+ if mod in self._warnreg_copies:
+ mod.__warningregistry__.update(self._warnreg_copies[mod])