diff options
author | Sebastian Berg <sebastianb@nvidia.com> | 2022-12-06 09:55:38 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-12-06 09:55:38 +0100 |
commit | 297c66131f2a32ebb7dfb0aeb9d88d917a791430 (patch) | |
tree | 99d543dfd704cc6f62d7f8bd9217d82711efa2f6 /numpy/lib/tests/test_function_base.py | |
parent | e877ba95bb3e22238353df0d654ef4d425c42f42 (diff) | |
parent | 91432a36a3611c2374ea9e2d45592f0ac5e71adb (diff) | |
download | numpy-297c66131f2a32ebb7dfb0aeb9d88d917a791430.tar.gz |
Merge pull request #22721 from byrdie/bugfix/median-keepdims-out
BUG: `keepdims=True` is ignored if `out` is not `None` in `numpy.median`.
Diffstat (limited to 'numpy/lib/tests/test_function_base.py')
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 50 |
1 files changed, 50 insertions, 0 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index c5b31ebf4..1bb4c4efa 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -25,6 +25,7 @@ from numpy.lib import ( i0, insert, interp, kaiser, meshgrid, msort, piecewise, place, rot90, select, setxor1d, sinc, trapz, trim_zeros, unwrap, unique, vectorize ) +from numpy.core.numeric import normalize_axis_tuple def get_mat(n): @@ -3331,6 +3332,32 @@ class TestPercentile: assert_equal(np.percentile(d, [1, 7], axis=(0, 3), keepdims=True).shape, (2, 1, 5, 7, 1)) + @pytest.mark.parametrize('q', [7, [1, 7]]) + @pytest.mark.parametrize( + argnames='axis', + argvalues=[ + None, + 1, + (1,), + (0, 1), + (-3, -1), + ] + ) + def test_keepdims_out(self, q, axis): + d = np.ones((3, 5, 7, 11)) + if axis is None: + shape_out = (1,) * d.ndim + else: + axis_norm = normalize_axis_tuple(axis, d.ndim) + shape_out = tuple( + 1 if i in axis_norm else d.shape[i] for i in range(d.ndim)) + shape_out = np.shape(q) + shape_out + + out = np.empty(shape_out) + result = np.percentile(d, q, axis=axis, keepdims=True, out=out) + assert result is out + assert_equal(result.shape, shape_out) + def test_out(self): o = np.zeros((4,)) d = np.ones((3, 4)) @@ -3843,6 +3870,29 @@ class TestMedian: assert_equal(np.median(d, axis=(0, 1, 3), keepdims=True).shape, (1, 1, 7, 1)) + @pytest.mark.parametrize( + argnames='axis', + argvalues=[ + None, + 1, + (1, ), + (0, 1), + (-3, -1), + ] + ) + def test_keepdims_out(self, axis): + d = np.ones((3, 5, 7, 11)) + if axis is None: + shape_out = (1,) * d.ndim + else: + axis_norm = normalize_axis_tuple(axis, d.ndim) + shape_out = tuple( + 1 if i in axis_norm else d.shape[i] for i in range(d.ndim)) + out = np.empty(shape_out) + result = np.median(d, axis=axis, keepdims=True, out=out) + assert result is out + assert_equal(result.shape, shape_out) + class TestAdd_newdoc_ufunc: |