diff options
author | Bas van Beek <b.f.van.beek@vu.nl> | 2021-09-03 15:43:16 +0200 |
---|---|---|
committer | Bas van Beek <b.f.van.beek@vu.nl> | 2021-09-03 16:47:42 +0200 |
commit | fffcb6e2bd45bf423f626ebf77a2e3142de636c4 (patch) | |
tree | 065369dcfde11c31648613e4fa7489d198831690 /numpy/lib/tests/test_nanfunctions.py | |
parent | b6d7c4680e23520fd90387f72d136717ed882bc0 (diff) | |
download | numpy-fffcb6e2bd45bf423f626ebf77a2e3142de636c4.tar.gz |
TST: Expand the old `TestNanFunctions_IntTypes` test with non-integer number types
Diffstat (limited to 'numpy/lib/tests/test_nanfunctions.py')
-rw-r--r-- | numpy/lib/tests/test_nanfunctions.py | 123 |
1 files changed, 52 insertions, 71 deletions
diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py index 1f1f5601b..dc14c3ad2 100644 --- a/numpy/lib/tests/test_nanfunctions.py +++ b/numpy/lib/tests/test_nanfunctions.py @@ -231,79 +231,60 @@ class TestNanFunctions_ArgminArgmax: assert_(res.shape == ()) -class TestNanFunctions_IntTypes: - - int_types = (np.int8, np.int16, np.int32, np.int64, np.uint8, - np.uint16, np.uint32, np.uint64) +@pytest.mark.parametrize( + "dtype", + np.typecodes["AllInteger"] + np.typecodes["AllFloat"] + "O", +) +class TestNanFunctions_NumberTypes: mat = np.array([127, 39, 93, 87, 46]) - - def integer_arrays(self): - for dtype in self.int_types: - yield self.mat.astype(dtype) - - def test_nanmin(self): - tgt = np.min(self.mat) - for mat in self.integer_arrays(): - assert_equal(np.nanmin(mat), tgt) - - def test_nanmax(self): - tgt = np.max(self.mat) - for mat in self.integer_arrays(): - assert_equal(np.nanmax(mat), tgt) - - def test_nanargmin(self): - tgt = np.argmin(self.mat) - for mat in self.integer_arrays(): - assert_equal(np.nanargmin(mat), tgt) - - def test_nanargmax(self): - tgt = np.argmax(self.mat) - for mat in self.integer_arrays(): - assert_equal(np.nanargmax(mat), tgt) - - def test_nansum(self): - tgt = np.sum(self.mat) - for mat in self.integer_arrays(): - assert_equal(np.nansum(mat), tgt) - - def test_nanprod(self): - tgt = np.prod(self.mat) - for mat in self.integer_arrays(): - assert_equal(np.nanprod(mat), tgt) - - def test_nancumsum(self): - tgt = np.cumsum(self.mat) - for mat in self.integer_arrays(): - assert_equal(np.nancumsum(mat), tgt) - - def test_nancumprod(self): - tgt = np.cumprod(self.mat) - for mat in self.integer_arrays(): - assert_equal(np.nancumprod(mat), tgt) - - def test_nanmean(self): - tgt = np.mean(self.mat) - for mat in self.integer_arrays(): - assert_equal(np.nanmean(mat), tgt) - - def test_nanvar(self): - tgt = np.var(self.mat) - for mat in self.integer_arrays(): - assert_equal(np.nanvar(mat), tgt) - - tgt = np.var(mat, ddof=1) - for mat in self.integer_arrays(): - assert_equal(np.nanvar(mat, ddof=1), tgt) - - def test_nanstd(self): - tgt = np.std(self.mat) - for mat in self.integer_arrays(): - assert_equal(np.nanstd(mat), tgt) - - tgt = np.std(self.mat, ddof=1) - for mat in self.integer_arrays(): - assert_equal(np.nanstd(mat, ddof=1), tgt) + mat.setflags(write=False) + + nanfuncs = { + np.nanmin: np.min, + np.nanmax: np.max, + np.nanargmin: np.argmin, + np.nanargmax: np.argmax, + np.nansum: np.sum, + np.nanprod: np.prod, + np.nancumsum: np.cumsum, + np.nancumprod: np.cumprod, + np.nanmean: np.mean, + np.nanvar: np.var, + np.nanstd: np.std, + } + nanfunc_ids = [i.__name__ for i in nanfuncs] + + @pytest.mark.parametrize("nanfunc,func", nanfuncs.items(), ids=nanfunc_ids) + def test_nanfunc(self, dtype, nanfunc, func): + if nanfunc is np.nanprod and dtype == "e": + pytest.xfail(reason="overflow encountered in reduce") + + mat = self.mat.astype(dtype) + tgt = func(mat) + out = nanfunc(mat) + + assert_almost_equal(out, tgt) + if dtype == "O": + assert type(out) is type(tgt) + else: + assert out.dtype == tgt.dtype + + @pytest.mark.parametrize( + "nanfunc,func", + [(np.nanvar, np.var), (np.nanstd, np.std)], + ids=["nanvar", "nanstd"], + ) + def test_nanfunc_ddof(self, dtype, nanfunc, func): + mat = self.mat.astype(dtype) + tgt = func(mat, ddof=1) + out = nanfunc(mat, ddof=1) + + assert_almost_equal(out, tgt) + if dtype == "O": + assert type(out) is type(tgt) + else: + assert out.dtype == tgt.dtype class SharedNanFunctionsTestsMixin: |