summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_nanfunctions.py
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2021-09-03 15:43:16 +0200
committerBas van Beek <b.f.van.beek@vu.nl>2021-09-03 16:47:42 +0200
commitfffcb6e2bd45bf423f626ebf77a2e3142de636c4 (patch)
tree065369dcfde11c31648613e4fa7489d198831690 /numpy/lib/tests/test_nanfunctions.py
parentb6d7c4680e23520fd90387f72d136717ed882bc0 (diff)
downloadnumpy-fffcb6e2bd45bf423f626ebf77a2e3142de636c4.tar.gz
TST: Expand the old `TestNanFunctions_IntTypes` test with non-integer number types
Diffstat (limited to 'numpy/lib/tests/test_nanfunctions.py')
-rw-r--r--numpy/lib/tests/test_nanfunctions.py123
1 files changed, 52 insertions, 71 deletions
diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py
index 1f1f5601b..dc14c3ad2 100644
--- a/numpy/lib/tests/test_nanfunctions.py
+++ b/numpy/lib/tests/test_nanfunctions.py
@@ -231,79 +231,60 @@ class TestNanFunctions_ArgminArgmax:
assert_(res.shape == ())
-class TestNanFunctions_IntTypes:
-
- int_types = (np.int8, np.int16, np.int32, np.int64, np.uint8,
- np.uint16, np.uint32, np.uint64)
+@pytest.mark.parametrize(
+ "dtype",
+ np.typecodes["AllInteger"] + np.typecodes["AllFloat"] + "O",
+)
+class TestNanFunctions_NumberTypes:
mat = np.array([127, 39, 93, 87, 46])
-
- def integer_arrays(self):
- for dtype in self.int_types:
- yield self.mat.astype(dtype)
-
- def test_nanmin(self):
- tgt = np.min(self.mat)
- for mat in self.integer_arrays():
- assert_equal(np.nanmin(mat), tgt)
-
- def test_nanmax(self):
- tgt = np.max(self.mat)
- for mat in self.integer_arrays():
- assert_equal(np.nanmax(mat), tgt)
-
- def test_nanargmin(self):
- tgt = np.argmin(self.mat)
- for mat in self.integer_arrays():
- assert_equal(np.nanargmin(mat), tgt)
-
- def test_nanargmax(self):
- tgt = np.argmax(self.mat)
- for mat in self.integer_arrays():
- assert_equal(np.nanargmax(mat), tgt)
-
- def test_nansum(self):
- tgt = np.sum(self.mat)
- for mat in self.integer_arrays():
- assert_equal(np.nansum(mat), tgt)
-
- def test_nanprod(self):
- tgt = np.prod(self.mat)
- for mat in self.integer_arrays():
- assert_equal(np.nanprod(mat), tgt)
-
- def test_nancumsum(self):
- tgt = np.cumsum(self.mat)
- for mat in self.integer_arrays():
- assert_equal(np.nancumsum(mat), tgt)
-
- def test_nancumprod(self):
- tgt = np.cumprod(self.mat)
- for mat in self.integer_arrays():
- assert_equal(np.nancumprod(mat), tgt)
-
- def test_nanmean(self):
- tgt = np.mean(self.mat)
- for mat in self.integer_arrays():
- assert_equal(np.nanmean(mat), tgt)
-
- def test_nanvar(self):
- tgt = np.var(self.mat)
- for mat in self.integer_arrays():
- assert_equal(np.nanvar(mat), tgt)
-
- tgt = np.var(mat, ddof=1)
- for mat in self.integer_arrays():
- assert_equal(np.nanvar(mat, ddof=1), tgt)
-
- def test_nanstd(self):
- tgt = np.std(self.mat)
- for mat in self.integer_arrays():
- assert_equal(np.nanstd(mat), tgt)
-
- tgt = np.std(self.mat, ddof=1)
- for mat in self.integer_arrays():
- assert_equal(np.nanstd(mat, ddof=1), tgt)
+ mat.setflags(write=False)
+
+ nanfuncs = {
+ np.nanmin: np.min,
+ np.nanmax: np.max,
+ np.nanargmin: np.argmin,
+ np.nanargmax: np.argmax,
+ np.nansum: np.sum,
+ np.nanprod: np.prod,
+ np.nancumsum: np.cumsum,
+ np.nancumprod: np.cumprod,
+ np.nanmean: np.mean,
+ np.nanvar: np.var,
+ np.nanstd: np.std,
+ }
+ nanfunc_ids = [i.__name__ for i in nanfuncs]
+
+ @pytest.mark.parametrize("nanfunc,func", nanfuncs.items(), ids=nanfunc_ids)
+ def test_nanfunc(self, dtype, nanfunc, func):
+ if nanfunc is np.nanprod and dtype == "e":
+ pytest.xfail(reason="overflow encountered in reduce")
+
+ mat = self.mat.astype(dtype)
+ tgt = func(mat)
+ out = nanfunc(mat)
+
+ assert_almost_equal(out, tgt)
+ if dtype == "O":
+ assert type(out) is type(tgt)
+ else:
+ assert out.dtype == tgt.dtype
+
+ @pytest.mark.parametrize(
+ "nanfunc,func",
+ [(np.nanvar, np.var), (np.nanstd, np.std)],
+ ids=["nanvar", "nanstd"],
+ )
+ def test_nanfunc_ddof(self, dtype, nanfunc, func):
+ mat = self.mat.astype(dtype)
+ tgt = func(mat, ddof=1)
+ out = nanfunc(mat, ddof=1)
+
+ assert_almost_equal(out, tgt)
+ if dtype == "O":
+ assert type(out) is type(tgt)
+ else:
+ assert out.dtype == tgt.dtype
class SharedNanFunctionsTestsMixin: