diff options
Diffstat (limited to 'numpy/ma/extras.py')
-rw-r--r-- | numpy/ma/extras.py | 41 |
1 files changed, 25 insertions, 16 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index f53d9c7e5..82a61a67c 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -668,15 +668,9 @@ def median(a, axis=None, out=None, overwrite_input=False): fill_value = 1e+20) """ - def _median1D(data): - counts = filled(count(data), 0) - (idx, rmd) = divmod(counts, 2) - if rmd: - choice = slice(idx, idx + 1) - else: - choice = slice(idx - 1, idx + 1) - return data[choice].mean(0) - # + if not hasattr(a, 'mask') or np.count_nonzero(a.mask) == 0: + return masked_array(np.median(a, axis=axis, out=out, + overwrite_input=overwrite_input), copy=False) if overwrite_input: if axis is None: asorted = a.ravel() @@ -687,14 +681,29 @@ def median(a, axis=None, out=None, overwrite_input=False): else: asorted = sort(a, axis=axis) if axis is None: - result = _median1D(asorted) + axis = 0 + elif axis < 0: + axis += a.ndim + + counts = asorted.shape[axis] - (asorted.mask).sum(axis=axis) + h = counts // 2 + # create indexing mesh grid for all but reduced axis + axes_grid = [np.arange(x) for i, x in enumerate(asorted.shape) + if i != axis] + ind = np.meshgrid(*axes_grid, sparse=True, indexing='ij') + # insert indices of low and high median + ind.insert(axis, h - 1) + low = asorted[ind] + ind[axis] = h + high = asorted[ind] + # duplicate high if odd number of elements so mean does nothing + odd = counts % 2 == 1 + if asorted.ndim == 1: + if odd: + low = high else: - result = apply_along_axis(_median1D, axis, asorted) - if out is not None: - out = result - return result - - + low[odd] = high[odd] + return np.ma.mean([low, high], axis=0, out=out) #.............................................................................. |