summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_nanfunctions.py
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2021-10-04 16:08:59 +0200
committerBas van Beek <43369155+BvB93@users.noreply.github.com>2021-10-04 19:04:23 +0200
commit0437a2518c9f2b33c054c21bc84cd7c1a4880080 (patch)
treed6a8a8a82ccac69ed37042078eaec64aa34b2eb2 /numpy/lib/tests/test_nanfunctions.py
parent27710a458c70298f5c2cb0ffefce0737e43aeaed (diff)
downloadnumpy-0437a2518c9f2b33c054c21bc84cd7c1a4880080.tar.gz
TST: Add tests for the new `nan<x>` function parameters
Diffstat (limited to 'numpy/lib/tests/test_nanfunctions.py')
-rw-r--r--numpy/lib/tests/test_nanfunctions.py103
1 files changed, 103 insertions, 0 deletions
diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py
index 3b0fa6656..0bd68e461 100644
--- a/numpy/lib/tests/test_nanfunctions.py
+++ b/numpy/lib/tests/test_nanfunctions.py
@@ -218,6 +218,46 @@ class TestNanFunctions_MinMax:
assert_(len(w) == 1, 'no warning raised')
assert_(issubclass(w[0].category, RuntimeWarning))
+ @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
+ def test_initial(self, dtype):
+ class MyNDArray(np.ndarray):
+ pass
+
+ ar = np.arange(9).astype(dtype)
+ ar[:5] = np.nan
+
+ for f in self.nanfuncs:
+ initial = 100 if f is np.nanmax else 0
+
+ ret1 = f(ar, initial=initial)
+ assert ret1.dtype == dtype
+ assert ret1 == initial
+
+ ret2 = f(ar.view(MyNDArray), initial=initial)
+ assert ret2.dtype == dtype
+ assert ret2 == initial
+
+ @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
+ def test_where(self, dtype):
+ class MyNDArray(np.ndarray):
+ pass
+
+ ar = np.arange(9).reshape(3, 3).astype(dtype)
+ ar[0, :] = np.nan
+ where = np.ones_like(ar, dtype=np.bool_)
+ where[:, 0] = False
+
+ for f in self.nanfuncs:
+ reference = 4 if f is np.nanmin else 8
+
+ ret1 = f(ar, where=where, initial=5)
+ assert ret1.dtype == dtype
+ assert ret1 == reference
+
+ ret2 = f(ar.view(MyNDArray), where=where, initial=5)
+ assert ret2.dtype == dtype
+ assert ret2 == reference
+
class TestNanFunctions_ArgminArgmax:
@@ -288,6 +328,30 @@ class TestNanFunctions_ArgminArgmax:
res = f(mine)
assert_(res.shape == ())
+ @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
+ def test_keepdims(self, dtype):
+ ar = np.arange(9).astype(dtype)
+ ar[:5] = np.nan
+
+ for f in self.nanfuncs:
+ reference = 5 if f is np.nanargmin else 8
+ ret = f(ar, keepdims=True)
+ assert ret.ndim == ar.ndim
+ assert ret == reference
+
+ @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
+ def test_out(self, dtype):
+ ar = np.arange(9).astype(dtype)
+ ar[:5] = np.nan
+
+ for f in self.nanfuncs:
+ out = np.zeros((), dtype=np.intp)
+ reference = 5 if f is np.nanargmin else 8
+ ret = f(ar, out=out)
+ assert ret is out
+ assert ret == reference
+
+
_TEST_ARRAYS = {
"0d": np.array(5),
@@ -504,6 +568,30 @@ class TestNanFunctions_SumProd(SharedNanFunctionsTestsMixin):
res = f(mat, axis=None)
assert_equal(res, tgt)
+ @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
+ def test_initial(self, dtype):
+ ar = np.arange(9).astype(dtype)
+ ar[:5] = np.nan
+
+ for f in self.nanfuncs:
+ reference = 28 if f is np.nansum else 3360
+ ret = f(ar, initial=2)
+ assert ret.dtype == dtype
+ assert ret == reference
+
+ @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
+ def test_where(self, dtype):
+ ar = np.arange(9).reshape(3, 3).astype(dtype)
+ ar[0, :] = np.nan
+ where = np.ones_like(ar, dtype=np.bool_)
+ where[:, 0] = False
+
+ for f in self.nanfuncs:
+ reference = 26 if f is np.nansum else 2240
+ ret = f(ar, where=where, initial=2)
+ assert ret.dtype == dtype
+ assert ret == reference
+
class TestNanFunctions_CumSumProd(SharedNanFunctionsTestsMixin):
@@ -659,6 +747,21 @@ class TestNanFunctions_MeanVarStd(SharedNanFunctionsTestsMixin):
assert_equal(f(mat, axis=axis), np.zeros([]))
assert_(len(w) == 0)
+ @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
+ def test_where(self, dtype):
+ ar = np.arange(9).reshape(3, 3).astype(dtype)
+ ar[0, :] = np.nan
+ where = np.ones_like(ar, dtype=np.bool_)
+ where[:, 0] = False
+
+ for f, f_std in zip(self.nanfuncs, self.stdfuncs):
+ reference = f_std(ar[where][2:])
+ dtype_reference = dtype if f is np.nanmean else ar.real.dtype
+
+ ret = f(ar, where=where)
+ assert ret.dtype == dtype_reference
+ np.testing.assert_allclose(ret, reference)
+
_TIME_UNITS = (
"Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns", "ps", "fs", "as"