diff options
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r-- | numpy/lib/function_base.py | 51 |
1 files changed, 33 insertions, 18 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 0ab49fa11..35a3b3543 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -3689,7 +3689,7 @@ def msort(a): return b -def _ureduce(a, func, **kwargs): +def _ureduce(a, func, keepdims=False, **kwargs): """ Internal Function. Call `func` with `a` as first argument swapping the axes to use extended @@ -3717,13 +3717,20 @@ def _ureduce(a, func, **kwargs): """ a = np.asanyarray(a) axis = kwargs.get('axis', None) + out = kwargs.get('out', None) + + if keepdims is np._NoValue: + keepdims = False + + nd = a.ndim if axis is not None: - keepdim = list(a.shape) - nd = a.ndim axis = _nx.normalize_axis_tuple(axis, nd) - for ax in axis: - keepdim[ax] = 1 + if keepdims: + if out is not None: + index_out = tuple( + 0 if i in axis else slice(None) for i in range(nd)) + kwargs['out'] = out[(Ellipsis, ) + index_out] if len(axis) == 1: kwargs['axis'] = axis[0] @@ -3736,12 +3743,27 @@ def _ureduce(a, func, **kwargs): # merge reduced axis a = a.reshape(a.shape[:nkeep] + (-1,)) kwargs['axis'] = -1 - keepdim = tuple(keepdim) else: - keepdim = (1,) * a.ndim + if keepdims: + if out is not None: + index_out = (0, ) * nd + kwargs['out'] = out[(Ellipsis, ) + index_out] r = func(a, **kwargs) - return r, keepdim + + if out is not None: + return out + + if keepdims: + if axis is None: + index_r = (np.newaxis, ) * nd + else: + index_r = tuple( + np.newaxis if i in axis else slice(None) + for i in range(nd)) + r = r[(Ellipsis, ) + index_r] + + return r def _median_dispatcher( @@ -3831,12 +3853,8 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False): >>> assert not np.all(a==b) """ - r, k = _ureduce(a, func=_median, axis=axis, out=out, + return _ureduce(a, func=_median, keepdims=keepdims, axis=axis, out=out, overwrite_input=overwrite_input) - if keepdims: - return r.reshape(k) - else: - return r def _median(a, axis=None, out=None, overwrite_input=False): @@ -4452,17 +4470,14 @@ def _quantile_unchecked(a, method="linear", keepdims=False): """Assumes that q is in [0, 1], and is an ndarray""" - r, k = _ureduce(a, + return _ureduce(a, func=_quantile_ureduce_func, q=q, + keepdims=keepdims, axis=axis, out=out, overwrite_input=overwrite_input, method=method) - if keepdims: - return r.reshape(q.shape + k) - else: - return r def _quantile_is_valid(q): |