summaryrefslogtreecommitdiff
path: root/numpy/testing
diff options
context:
space:
mode:
authorDavid Cournapeau <cournape@gmail.com>2009-11-23 09:28:02 +0000
committerDavid Cournapeau <cournape@gmail.com>2009-11-23 09:28:02 +0000
commit7277495f0c0bf4b64be4987243e1a08b2f831549 (patch)
tree3615bcacf79a4a61bee56bccb438352c10a7dfd8 /numpy/testing
parentd9306312fb09e86736717a0b4121794de5a3034d (diff)
downloadnumpy-7277495f0c0bf4b64be4987243e1a08b2f831549.tar.gz
ENH: add an assert_warns testing utility.
Diffstat (limited to 'numpy/testing')
-rw-r--r--numpy/testing/tests/test_utils.py34
-rw-r--r--numpy/testing/utils.py26
2 files changed, 59 insertions, 1 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index ab314a703..2d22789ff 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -1,3 +1,6 @@
+import warnings
+import sys
+
import numpy as np
from numpy.testing import *
import unittest
@@ -301,6 +304,37 @@ class TestRaises(unittest.TestCase):
else:
raise AssertionError("should have raised an AssertionError")
+class TestWarns(unittest.TestCase):
+ def test_warn(self):
+ def f():
+ warnings.warn("yo")
+
+ before_filters = sys.modules['warnings'].filters[:]
+ assert_warns(UserWarning, f)
+ after_filters = sys.modules['warnings'].filters
+
+ # Check that the warnings state is unchanged
+ assert_equal(before_filters, after_filters,
+ "assert_warns does not preserver warnings state")
+
+ def test_warn_wrong_warning(self):
+ def f():
+ warnings.warn("yo", DeprecationWarning)
+
+ failed = False
+ filters = sys.modules['warnings'].filters[:]
+ try:
+ # Should raise an AssertionError
+ assert_warns(UserWarning, f)
+ failed = True
+ except AssertionError:
+ pass
+ finally:
+ sys.modules['warnings'].filters = filters
+
+ if failed:
+ raise AssertionError("wrong warning caught by assert_warn")
+
class TestArrayAlmostEqualNulp(unittest.TestCase):
def test_simple(self):
dev = np.random.randn(10)
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index bafee6363..7c8e978e2 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -7,6 +7,7 @@ import sys
import re
import operator
import types
+import warnings
from nosetester import import_nose
__all__ = ['assert_equal', 'assert_almost_equal','assert_approx_equal',
@@ -15,7 +16,7 @@ __all__ = ['assert_equal', 'assert_almost_equal','assert_approx_equal',
'decorate_methods', 'jiffies', 'memusage', 'print_assert_equal',
'raises', 'rand', 'rundocs', 'runstring', 'verbose', 'measure',
'assert_', 'assert_array_almost_equal_nulp',
- 'assert_array_max_ulp']
+ 'assert_array_max_ulp', 'assert_warns']
verbose = 0
@@ -1281,3 +1282,26 @@ class WarningManager:
self._module.filters = self._filters
self._module.showwarning = self._showwarning
+def assert_warns(warning_class, func, *args, **kw):
+ """Fail unless a warning of class warning_class is thrown by callable when
+ invoked with arguments args and keyword arguments kwargs.
+
+ If a different type of warning is thrown, it will not be caught, and the
+ test case will be deemed to have suffered an error.
+ """
+
+ # XXX: once we may depend on python >= 2.6, this can be replaced by the
+ # warnings module context manager.
+ ctx = WarningManager(record=True)
+ l = ctx.__enter__()
+ warnings.simplefilter('always')
+ try:
+ func(*args, **kw)
+ if not len(l) > 0:
+ raise AssertionError("No warning raised when calling %s"
+ % func.__name__)
+ if not l[0].category is warning_class:
+ raise AssertionError("First warning for %s is not a " \
+ "%s( is %s)" % (func.__name__, warning_class, l[0]))
+ finally:
+ ctx.__exit__()