summaryrefslogtreecommitdiff
path: root/numpy/testing/pytest_tools/utils.py
diff options
context:
space:
mode:
authorxoviat <xoviat@users.noreply.github.com>2017-12-22 13:15:07 -0600
committerxoviat <xoviat@users.noreply.github.com>2017-12-22 17:30:50 -0600
commitc6db7da2c932470f70ce9c9ab4dd49fc1a64b789 (patch)
tree7be96d4f535e0f7c684153779e6738772ee328cb /numpy/testing/pytest_tools/utils.py
parentcc4a3df98432fe621429da78596fa746214cd016 (diff)
downloadnumpy-c6db7da2c932470f70ce9c9ab4dd49fc1a64b789.tar.gz
BUG: Fix pytest implementation errors
Diffstat (limited to 'numpy/testing/pytest_tools/utils.py')
-rw-r--r--numpy/testing/pytest_tools/utils.py21
1 files changed, 17 insertions, 4 deletions
diff --git a/numpy/testing/pytest_tools/utils.py b/numpy/testing/pytest_tools/utils.py
index a873199b1..19982ec54 100644
--- a/numpy/testing/pytest_tools/utils.py
+++ b/numpy/testing/pytest_tools/utils.py
@@ -45,7 +45,7 @@ class KnownFailureException(Exception):
def __new__(cls, *args, **kwargs):
# import _pytest here to avoid hard dependency
import _pytest
- return _pytest.skipping.XFailed(*args, **kwargs)
+ return _pytest.skipping.xfail(*args, **kwargs)
class SkipTest(Exception):
@@ -1187,7 +1187,7 @@ def raises(*exceptions):
return raises_decorator
-def assert_raises(*args, **kwargs):
+def assert_raises(exception_class, fn=None, *args, **kwargs):
"""
assert_raises(exception_class, callable, *args, **kwargs)
assert_raises(exception_class)
@@ -1215,7 +1215,20 @@ def assert_raises(*args, **kwargs):
import pytest
__tracebackhide__ = True # Hide traceback for py.test
- pytest.raises(*args,**kwargs)
+
+ if fn is not None:
+ pytest.raises(exception_class, fn, *args,**kwargs)
+ else:
+ @contextlib.contextmanager
+ def assert_raises_context():
+ try:
+ yield
+ except BaseException as raised_exception:
+ assert isinstance(raised_exception, exception_class)
+ else:
+ raise ValueError('Function did not raise an exception')
+
+ return assert_raises_context()
def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs):
@@ -1245,7 +1258,7 @@ def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs):
def do_nothing(self):
pass
- tmp = Dummy('')
+ tmp = Dummy('do_nothing')
__tracebackhide__ = True # Hide traceback for py.test
res = pytest.raises(exception_class, *args, **kwargs)