summaryrefslogtreecommitdiff
path: root/numpy/lib/function_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r--numpy/lib/function_base.py51
1 files changed, 33 insertions, 18 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 0ab49fa11..35a3b3543 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -3689,7 +3689,7 @@ def msort(a):
return b
-def _ureduce(a, func, **kwargs):
+def _ureduce(a, func, keepdims=False, **kwargs):
"""
Internal Function.
Call `func` with `a` as first argument swapping the axes to use extended
@@ -3717,13 +3717,20 @@ def _ureduce(a, func, **kwargs):
"""
a = np.asanyarray(a)
axis = kwargs.get('axis', None)
+ out = kwargs.get('out', None)
+
+ if keepdims is np._NoValue:
+ keepdims = False
+
+ nd = a.ndim
if axis is not None:
- keepdim = list(a.shape)
- nd = a.ndim
axis = _nx.normalize_axis_tuple(axis, nd)
- for ax in axis:
- keepdim[ax] = 1
+ if keepdims:
+ if out is not None:
+ index_out = tuple(
+ 0 if i in axis else slice(None) for i in range(nd))
+ kwargs['out'] = out[(Ellipsis, ) + index_out]
if len(axis) == 1:
kwargs['axis'] = axis[0]
@@ -3736,12 +3743,27 @@ def _ureduce(a, func, **kwargs):
# merge reduced axis
a = a.reshape(a.shape[:nkeep] + (-1,))
kwargs['axis'] = -1
- keepdim = tuple(keepdim)
else:
- keepdim = (1,) * a.ndim
+ if keepdims:
+ if out is not None:
+ index_out = (0, ) * nd
+ kwargs['out'] = out[(Ellipsis, ) + index_out]
r = func(a, **kwargs)
- return r, keepdim
+
+ if out is not None:
+ return out
+
+ if keepdims:
+ if axis is None:
+ index_r = (np.newaxis, ) * nd
+ else:
+ index_r = tuple(
+ np.newaxis if i in axis else slice(None)
+ for i in range(nd))
+ r = r[(Ellipsis, ) + index_r]
+
+ return r
def _median_dispatcher(
@@ -3831,12 +3853,8 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
>>> assert not np.all(a==b)
"""
- r, k = _ureduce(a, func=_median, axis=axis, out=out,
+ return _ureduce(a, func=_median, keepdims=keepdims, axis=axis, out=out,
overwrite_input=overwrite_input)
- if keepdims:
- return r.reshape(k)
- else:
- return r
def _median(a, axis=None, out=None, overwrite_input=False):
@@ -4452,17 +4470,14 @@ def _quantile_unchecked(a,
method="linear",
keepdims=False):
"""Assumes that q is in [0, 1], and is an ndarray"""
- r, k = _ureduce(a,
+ return _ureduce(a,
func=_quantile_ureduce_func,
q=q,
+ keepdims=keepdims,
axis=axis,
out=out,
overwrite_input=overwrite_input,
method=method)
- if keepdims:
- return r.reshape(q.shape + k)
- else:
- return r
def _quantile_is_valid(q):