diff options
Diffstat (limited to 'numpy/testing/pytest_tools/utils.py')
-rw-r--r-- | numpy/testing/pytest_tools/utils.py | 21 |
1 files changed, 17 insertions, 4 deletions
diff --git a/numpy/testing/pytest_tools/utils.py b/numpy/testing/pytest_tools/utils.py index a873199b1..19982ec54 100644 --- a/numpy/testing/pytest_tools/utils.py +++ b/numpy/testing/pytest_tools/utils.py @@ -45,7 +45,7 @@ class KnownFailureException(Exception): def __new__(cls, *args, **kwargs): # import _pytest here to avoid hard dependency import _pytest - return _pytest.skipping.XFailed(*args, **kwargs) + return _pytest.skipping.xfail(*args, **kwargs) class SkipTest(Exception): @@ -1187,7 +1187,7 @@ def raises(*exceptions): return raises_decorator -def assert_raises(*args, **kwargs): +def assert_raises(exception_class, fn=None, *args, **kwargs): """ assert_raises(exception_class, callable, *args, **kwargs) assert_raises(exception_class) @@ -1215,7 +1215,20 @@ def assert_raises(*args, **kwargs): import pytest __tracebackhide__ = True # Hide traceback for py.test - pytest.raises(*args,**kwargs) + + if fn is not None: + pytest.raises(exception_class, fn, *args,**kwargs) + else: + @contextlib.contextmanager + def assert_raises_context(): + try: + yield + except BaseException as raised_exception: + assert isinstance(raised_exception, exception_class) + else: + raise ValueError('Function did not raise an exception') + + return assert_raises_context() def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs): @@ -1245,7 +1258,7 @@ def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs): def do_nothing(self): pass - tmp = Dummy('') + tmp = Dummy('do_nothing') __tracebackhide__ = True # Hide traceback for py.test res = pytest.raises(exception_class, *args, **kwargs) |