summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_function_base.py
diff options
context:
space:
mode:
authorwarren <warren.weckesser@gmail.com>2022-05-09 21:50:55 -0400
committerwarren <warren.weckesser@gmail.com>2022-05-09 21:50:55 -0400
commitc27f7817acb63ad05200c9c240a00cc5a7280394 (patch)
treec56993121bf87a9a1c4141dd5f4c48c0f3f9932b /numpy/lib/tests/test_function_base.py
parent247cb34641bbd3481c3df741f88d7bfa65901e1b (diff)
downloadnumpy-c27f7817acb63ad05200c9c240a00cc5a7280394.tar.gz
ENH: Add 'keepdims' to 'average()' and 'ma.average()'.
Diffstat (limited to 'numpy/lib/tests/test_function_base.py')
-rw-r--r--numpy/lib/tests/test_function_base.py28
1 files changed, 27 insertions, 1 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 874754a64..bdcbef91d 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -305,6 +305,29 @@ class TestAverage:
assert_almost_equal(y5.mean(0), average(y5, 0))
assert_almost_equal(y5.mean(1), average(y5, 1))
+ @pytest.mark.parametrize(
+ 'x, axis, expected_avg, weights, expected_wavg, expected_wsum',
+ [([1, 2, 3], None, [2.0], [3, 4, 1], [1.75], [8.0]),
+ ([[1, 2, 5], [1, 6, 11]], 0, [[1.0, 4.0, 8.0]],
+ [1, 3], [[1.0, 5.0, 9.5]], [[4, 4, 4]])],
+ )
+ def test_basic_keepdims(self, x, axis, expected_avg,
+ weights, expected_wavg, expected_wsum):
+ avg = np.average(x, axis=axis, keepdims=True)
+ assert avg.shape == np.shape(expected_avg)
+ assert_array_equal(avg, expected_avg)
+
+ wavg = np.average(x, axis=axis, weights=weights, keepdims=True)
+ assert wavg.shape == np.shape(expected_wavg)
+ assert_array_equal(wavg, expected_wavg)
+
+ wavg, wsum = np.average(x, axis=axis, weights=weights, returned=True,
+ keepdims=True)
+ assert wavg.shape == np.shape(expected_wavg)
+ assert_array_equal(wavg, expected_wavg)
+ assert wsum.shape == np.shape(expected_wsum)
+ assert_array_equal(wsum, expected_wsum)
+
def test_weights(self):
y = np.arange(10)
w = np.arange(10)
@@ -1242,11 +1265,11 @@ class TestTrimZeros:
res = trim_zeros(arr)
assert_array_equal(arr, res)
-
def test_list_to_list(self):
res = trim_zeros(self.a.tolist())
assert isinstance(res, list)
+
class TestExtins:
def test_basic(self):
@@ -1759,6 +1782,7 @@ class TestLeaks:
finally:
gc.enable()
+
class TestDigitize:
def test_forward(self):
@@ -2339,6 +2363,7 @@ class Test_I0:
with pytest.raises(TypeError, match="i0 not supported for complex values"):
res = i0(a)
+
class TestKaiser:
def test_simple(self):
@@ -3474,6 +3499,7 @@ class TestQuantile:
assert np.isscalar(actual)
assert_equal(np.quantile(a, 0.5), np.nan)
+
class TestLerp:
@hypothesis.given(t0=st.floats(allow_nan=False, allow_infinity=False,
min_value=0, max_value=1),