summaryrefslogtreecommitdiff
path: root/numpy/ma/extras.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/ma/extras.py')
-rw-r--r--numpy/ma/extras.py41
1 files changed, 25 insertions, 16 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index f53d9c7e5..82a61a67c 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -668,15 +668,9 @@ def median(a, axis=None, out=None, overwrite_input=False):
fill_value = 1e+20)
"""
- def _median1D(data):
- counts = filled(count(data), 0)
- (idx, rmd) = divmod(counts, 2)
- if rmd:
- choice = slice(idx, idx + 1)
- else:
- choice = slice(idx - 1, idx + 1)
- return data[choice].mean(0)
- #
+ if not hasattr(a, 'mask') or np.count_nonzero(a.mask) == 0:
+ return masked_array(np.median(a, axis=axis, out=out,
+ overwrite_input=overwrite_input), copy=False)
if overwrite_input:
if axis is None:
asorted = a.ravel()
@@ -687,14 +681,29 @@ def median(a, axis=None, out=None, overwrite_input=False):
else:
asorted = sort(a, axis=axis)
if axis is None:
- result = _median1D(asorted)
+ axis = 0
+ elif axis < 0:
+ axis += a.ndim
+
+ counts = asorted.shape[axis] - (asorted.mask).sum(axis=axis)
+ h = counts // 2
+ # create indexing mesh grid for all but reduced axis
+ axes_grid = [np.arange(x) for i, x in enumerate(asorted.shape)
+ if i != axis]
+ ind = np.meshgrid(*axes_grid, sparse=True, indexing='ij')
+ # insert indices of low and high median
+ ind.insert(axis, h - 1)
+ low = asorted[ind]
+ ind[axis] = h
+ high = asorted[ind]
+ # duplicate high if odd number of elements so mean does nothing
+ odd = counts % 2 == 1
+ if asorted.ndim == 1:
+ if odd:
+ low = high
else:
- result = apply_along_axis(_median1D, axis, asorted)
- if out is not None:
- out = result
- return result
-
-
+ low[odd] = high[odd]
+ return np.ma.mean([low, high], axis=0, out=out)
#..............................................................................