summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_function_base.py
diff options
context:
space:
mode:
authorSebastian Berg <sebastianb@nvidia.com>2022-12-06 09:55:38 +0100
committerGitHub <noreply@github.com>2022-12-06 09:55:38 +0100
commit297c66131f2a32ebb7dfb0aeb9d88d917a791430 (patch)
tree99d543dfd704cc6f62d7f8bd9217d82711efa2f6 /numpy/lib/tests/test_function_base.py
parente877ba95bb3e22238353df0d654ef4d425c42f42 (diff)
parent91432a36a3611c2374ea9e2d45592f0ac5e71adb (diff)
downloadnumpy-297c66131f2a32ebb7dfb0aeb9d88d917a791430.tar.gz
Merge pull request #22721 from byrdie/bugfix/median-keepdims-out
BUG: `keepdims=True` is ignored if `out` is not `None` in `numpy.median`.
Diffstat (limited to 'numpy/lib/tests/test_function_base.py')
-rw-r--r--numpy/lib/tests/test_function_base.py50
1 files changed, 50 insertions, 0 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index c5b31ebf4..1bb4c4efa 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -25,6 +25,7 @@ from numpy.lib import (
i0, insert, interp, kaiser, meshgrid, msort, piecewise, place, rot90,
select, setxor1d, sinc, trapz, trim_zeros, unwrap, unique, vectorize
)
+from numpy.core.numeric import normalize_axis_tuple
def get_mat(n):
@@ -3331,6 +3332,32 @@ class TestPercentile:
assert_equal(np.percentile(d, [1, 7], axis=(0, 3),
keepdims=True).shape, (2, 1, 5, 7, 1))
+ @pytest.mark.parametrize('q', [7, [1, 7]])
+ @pytest.mark.parametrize(
+ argnames='axis',
+ argvalues=[
+ None,
+ 1,
+ (1,),
+ (0, 1),
+ (-3, -1),
+ ]
+ )
+ def test_keepdims_out(self, q, axis):
+ d = np.ones((3, 5, 7, 11))
+ if axis is None:
+ shape_out = (1,) * d.ndim
+ else:
+ axis_norm = normalize_axis_tuple(axis, d.ndim)
+ shape_out = tuple(
+ 1 if i in axis_norm else d.shape[i] for i in range(d.ndim))
+ shape_out = np.shape(q) + shape_out
+
+ out = np.empty(shape_out)
+ result = np.percentile(d, q, axis=axis, keepdims=True, out=out)
+ assert result is out
+ assert_equal(result.shape, shape_out)
+
def test_out(self):
o = np.zeros((4,))
d = np.ones((3, 4))
@@ -3843,6 +3870,29 @@ class TestMedian:
assert_equal(np.median(d, axis=(0, 1, 3), keepdims=True).shape,
(1, 1, 7, 1))
+ @pytest.mark.parametrize(
+ argnames='axis',
+ argvalues=[
+ None,
+ 1,
+ (1, ),
+ (0, 1),
+ (-3, -1),
+ ]
+ )
+ def test_keepdims_out(self, axis):
+ d = np.ones((3, 5, 7, 11))
+ if axis is None:
+ shape_out = (1,) * d.ndim
+ else:
+ axis_norm = normalize_axis_tuple(axis, d.ndim)
+ shape_out = tuple(
+ 1 if i in axis_norm else d.shape[i] for i in range(d.ndim))
+ out = np.empty(shape_out)
+ result = np.median(d, axis=axis, keepdims=True, out=out)
+ assert result is out
+ assert_equal(result.shape, shape_out)
+
class TestAdd_newdoc_ufunc: