diff options
author | Bas van Beek <b.f.van.beek@vu.nl> | 2021-10-04 16:08:59 +0200 |
---|---|---|
committer | Bas van Beek <43369155+BvB93@users.noreply.github.com> | 2021-10-04 19:04:23 +0200 |
commit | 0437a2518c9f2b33c054c21bc84cd7c1a4880080 (patch) | |
tree | d6a8a8a82ccac69ed37042078eaec64aa34b2eb2 /numpy/lib/tests/test_nanfunctions.py | |
parent | 27710a458c70298f5c2cb0ffefce0737e43aeaed (diff) | |
download | numpy-0437a2518c9f2b33c054c21bc84cd7c1a4880080.tar.gz |
TST: Add tests for the new `nan<x>` function parameters
Diffstat (limited to 'numpy/lib/tests/test_nanfunctions.py')
-rw-r--r-- | numpy/lib/tests/test_nanfunctions.py | 103 |
1 files changed, 103 insertions, 0 deletions
diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py index 3b0fa6656..0bd68e461 100644 --- a/numpy/lib/tests/test_nanfunctions.py +++ b/numpy/lib/tests/test_nanfunctions.py @@ -218,6 +218,46 @@ class TestNanFunctions_MinMax: assert_(len(w) == 1, 'no warning raised') assert_(issubclass(w[0].category, RuntimeWarning)) + @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"]) + def test_initial(self, dtype): + class MyNDArray(np.ndarray): + pass + + ar = np.arange(9).astype(dtype) + ar[:5] = np.nan + + for f in self.nanfuncs: + initial = 100 if f is np.nanmax else 0 + + ret1 = f(ar, initial=initial) + assert ret1.dtype == dtype + assert ret1 == initial + + ret2 = f(ar.view(MyNDArray), initial=initial) + assert ret2.dtype == dtype + assert ret2 == initial + + @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"]) + def test_where(self, dtype): + class MyNDArray(np.ndarray): + pass + + ar = np.arange(9).reshape(3, 3).astype(dtype) + ar[0, :] = np.nan + where = np.ones_like(ar, dtype=np.bool_) + where[:, 0] = False + + for f in self.nanfuncs: + reference = 4 if f is np.nanmin else 8 + + ret1 = f(ar, where=where, initial=5) + assert ret1.dtype == dtype + assert ret1 == reference + + ret2 = f(ar.view(MyNDArray), where=where, initial=5) + assert ret2.dtype == dtype + assert ret2 == reference + class TestNanFunctions_ArgminArgmax: @@ -288,6 +328,30 @@ class TestNanFunctions_ArgminArgmax: res = f(mine) assert_(res.shape == ()) + @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"]) + def test_keepdims(self, dtype): + ar = np.arange(9).astype(dtype) + ar[:5] = np.nan + + for f in self.nanfuncs: + reference = 5 if f is np.nanargmin else 8 + ret = f(ar, keepdims=True) + assert ret.ndim == ar.ndim + assert ret == reference + + @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"]) + def test_out(self, dtype): + ar = np.arange(9).astype(dtype) + ar[:5] = np.nan + + for f in self.nanfuncs: + out = np.zeros((), dtype=np.intp) + reference = 5 if f is np.nanargmin else 8 + ret = f(ar, out=out) + assert ret is out + assert ret == reference + + _TEST_ARRAYS = { "0d": np.array(5), @@ -504,6 +568,30 @@ class TestNanFunctions_SumProd(SharedNanFunctionsTestsMixin): res = f(mat, axis=None) assert_equal(res, tgt) + @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"]) + def test_initial(self, dtype): + ar = np.arange(9).astype(dtype) + ar[:5] = np.nan + + for f in self.nanfuncs: + reference = 28 if f is np.nansum else 3360 + ret = f(ar, initial=2) + assert ret.dtype == dtype + assert ret == reference + + @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"]) + def test_where(self, dtype): + ar = np.arange(9).reshape(3, 3).astype(dtype) + ar[0, :] = np.nan + where = np.ones_like(ar, dtype=np.bool_) + where[:, 0] = False + + for f in self.nanfuncs: + reference = 26 if f is np.nansum else 2240 + ret = f(ar, where=where, initial=2) + assert ret.dtype == dtype + assert ret == reference + class TestNanFunctions_CumSumProd(SharedNanFunctionsTestsMixin): @@ -659,6 +747,21 @@ class TestNanFunctions_MeanVarStd(SharedNanFunctionsTestsMixin): assert_equal(f(mat, axis=axis), np.zeros([])) assert_(len(w) == 0) + @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"]) + def test_where(self, dtype): + ar = np.arange(9).reshape(3, 3).astype(dtype) + ar[0, :] = np.nan + where = np.ones_like(ar, dtype=np.bool_) + where[:, 0] = False + + for f, f_std in zip(self.nanfuncs, self.stdfuncs): + reference = f_std(ar[where][2:]) + dtype_reference = dtype if f is np.nanmean else ar.real.dtype + + ret = f(ar, where=where) + assert ret.dtype == dtype_reference + np.testing.assert_allclose(ret, reference) + _TIME_UNITS = ( "Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns", "ps", "fs", "as" |