diff options
author | warren <warren.weckesser@gmail.com> | 2022-05-09 21:50:55 -0400 |
---|---|---|
committer | warren <warren.weckesser@gmail.com> | 2022-05-09 21:50:55 -0400 |
commit | c27f7817acb63ad05200c9c240a00cc5a7280394 (patch) | |
tree | c56993121bf87a9a1c4141dd5f4c48c0f3f9932b /numpy/lib/tests/test_function_base.py | |
parent | 247cb34641bbd3481c3df741f88d7bfa65901e1b (diff) | |
download | numpy-c27f7817acb63ad05200c9c240a00cc5a7280394.tar.gz |
ENH: Add 'keepdims' to 'average()' and 'ma.average()'.
Diffstat (limited to 'numpy/lib/tests/test_function_base.py')
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 28 |
1 files changed, 27 insertions, 1 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 874754a64..bdcbef91d 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -305,6 +305,29 @@ class TestAverage: assert_almost_equal(y5.mean(0), average(y5, 0)) assert_almost_equal(y5.mean(1), average(y5, 1)) + @pytest.mark.parametrize( + 'x, axis, expected_avg, weights, expected_wavg, expected_wsum', + [([1, 2, 3], None, [2.0], [3, 4, 1], [1.75], [8.0]), + ([[1, 2, 5], [1, 6, 11]], 0, [[1.0, 4.0, 8.0]], + [1, 3], [[1.0, 5.0, 9.5]], [[4, 4, 4]])], + ) + def test_basic_keepdims(self, x, axis, expected_avg, + weights, expected_wavg, expected_wsum): + avg = np.average(x, axis=axis, keepdims=True) + assert avg.shape == np.shape(expected_avg) + assert_array_equal(avg, expected_avg) + + wavg = np.average(x, axis=axis, weights=weights, keepdims=True) + assert wavg.shape == np.shape(expected_wavg) + assert_array_equal(wavg, expected_wavg) + + wavg, wsum = np.average(x, axis=axis, weights=weights, returned=True, + keepdims=True) + assert wavg.shape == np.shape(expected_wavg) + assert_array_equal(wavg, expected_wavg) + assert wsum.shape == np.shape(expected_wsum) + assert_array_equal(wsum, expected_wsum) + def test_weights(self): y = np.arange(10) w = np.arange(10) @@ -1242,11 +1265,11 @@ class TestTrimZeros: res = trim_zeros(arr) assert_array_equal(arr, res) - def test_list_to_list(self): res = trim_zeros(self.a.tolist()) assert isinstance(res, list) + class TestExtins: def test_basic(self): @@ -1759,6 +1782,7 @@ class TestLeaks: finally: gc.enable() + class TestDigitize: def test_forward(self): @@ -2339,6 +2363,7 @@ class Test_I0: with pytest.raises(TypeError, match="i0 not supported for complex values"): res = i0(a) + class TestKaiser: def test_simple(self): @@ -3474,6 +3499,7 @@ class TestQuantile: assert np.isscalar(actual) assert_equal(np.quantile(a, 0.5), np.nan) + class TestLerp: @hypothesis.given(t0=st.floats(allow_nan=False, allow_infinity=False, min_value=0, max_value=1), |