summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRaghuveer Devulapalli <raghuveer.devulapalli@intel.com>2021-07-16 12:25:01 -0700
committerRaghuveer Devulapalli <raghuveer.devulapalli@intel.com>2021-07-16 15:39:56 -0700
commitf17470711c61a37397c67c9241796eae4b29e9e0 (patch)
tree32ee0b9d026785e632383eba057b96b2a3b2145f
parent41b14b16cd75ad2d3cb2fb1409dc152de5c32cd2 (diff)
downloadnumpy-f17470711c61a37397c67c9241796eae4b29e9e0.tar.gz
MAINT: Abstract tests to avoid repetition
-rw-r--r--numpy/core/tests/test_umath.py72
1 files changed, 25 insertions, 47 deletions
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index d7f48040e..8e579b65d 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -1173,24 +1173,12 @@ class TestSpecialFloats:
assert_equal(np.arcsin(in_arr), out_arr)
assert_equal(np.arccos(in_arr), out_arr)
- with np.errstate(invalid='raise'):
- for dt in ['f', 'd']:
- assert_raises(FloatingPointError, np.arcsin,
- np.array(np.inf, dtype=dt))
- assert_raises(FloatingPointError, np.arcsin,
- np.array(-np.inf, dtype=dt))
- assert_raises(FloatingPointError, np.arcsin,
- np.array(2.0, dtype=dt))
- assert_raises(FloatingPointError, np.arcsin,
- np.array(-2.0, dtype=dt))
- assert_raises(FloatingPointError, np.arccos,
- np.array(np.inf, dtype=dt))
- assert_raises(FloatingPointError, np.arccos,
- np.array(-np.inf, dtype=dt))
- assert_raises(FloatingPointError, np.arccos,
- np.array(2.0, dtype=dt))
- assert_raises(FloatingPointError, np.arccos,
- np.array(-2.0, dtype=dt))
+ for callable in [np.arcsin, np.arccos]:
+ for value in [np.inf, -np.inf, 2.0, -2.0]:
+ for dt in ['f', 'd']:
+ with np.errstate(invalid='raise'):
+ assert_raises(FloatingPointError, callable,
+ np.array(value, dtype=dt))
def test_arctan(self):
with np.errstate(all='ignore'):
@@ -1254,12 +1242,11 @@ class TestSpecialFloats:
out_arr = np.array(out, dtype=dt)
assert_equal(np.arccosh(in_arr), out_arr)
- with np.errstate(invalid='raise'):
- for dt in ['f', 'd']:
- assert_raises(FloatingPointError, np.arccosh,
- np.array(0.0, dtype=dt))
- assert_raises(FloatingPointError, np.arccosh,
- np.array(-np.inf, dtype=dt))
+ for value in [0.0, -np.inf]:
+ with np.errstate(invalid='raise'):
+ for dt in ['f', 'd']:
+ assert_raises(FloatingPointError, np.arccosh,
+ np.array(value, dtype=dt))
def test_arctanh(self):
with np.errstate(all='ignore'):
@@ -1270,18 +1257,11 @@ class TestSpecialFloats:
out_arr = np.array(out, dtype=dt)
assert_equal(np.arctanh(in_arr), out_arr)
- with np.errstate(invalid='raise', divide='raise'):
- for dt in ['f', 'd']:
- assert_raises(FloatingPointError, np.arctanh,
- np.array(1.01, dtype=dt))
- assert_raises(FloatingPointError, np.arctanh,
- np.array(np.inf, dtype=dt))
- assert_raises(FloatingPointError, np.arctanh,
- np.array(-np.inf, dtype=dt))
- assert_raises(FloatingPointError, np.arctanh,
- np.array(1.0, dtype=dt))
- assert_raises(FloatingPointError, np.arctanh,
- np.array(-1.0, dtype=dt))
+ for value in [1.01, np.inf, -np.inf, 1.0, -1.0]:
+ with np.errstate(invalid='raise', divide='raise'):
+ for dt in ['f', 'd']:
+ assert_raises(FloatingPointError, np.arctanh,
+ np.array(value, dtype=dt))
def test_exp2(self):
with np.errstate(all='ignore'):
@@ -1292,12 +1272,11 @@ class TestSpecialFloats:
out_arr = np.array(out, dtype=dt)
assert_equal(np.exp2(in_arr), out_arr)
- with np.errstate(over='raise', under='raise'):
- for dt in ['f', 'd']:
- assert_raises(FloatingPointError, np.exp2,
- np.array(2000.0, dtype=dt))
- assert_raises(FloatingPointError, np.exp2,
- np.array(-2000.0, dtype=dt))
+ for value in [2000.0, -2000.0]:
+ with np.errstate(over='raise', under='raise'):
+ for dt in ['f', 'd']:
+ assert_raises(FloatingPointError, np.exp2,
+ np.array(value, dtype=dt))
def test_expm1(self):
with np.errstate(all='ignore'):
@@ -1308,11 +1287,10 @@ class TestSpecialFloats:
out_arr = np.array(out, dtype=dt)
assert_equal(np.expm1(in_arr), out_arr)
- with np.errstate(over='raise'):
- assert_raises(FloatingPointError, np.expm1,
- np.array(200.0, dtype='f'))
- assert_raises(FloatingPointError, np.expm1,
- np.array(2000.0, dtype='d'))
+ for value in [200.0, 2000.0]:
+ with np.errstate(over='raise'):
+ assert_raises(FloatingPointError, np.expm1,
+ np.array(value, dtype='f'))
class TestFPClass:
@pytest.mark.parametrize("stride", [-4,-2,-1,1,2,4])