summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/ma/extras.py19
-rw-r--r--numpy/ma/tests/test_extras.py32
2 files changed, 43 insertions, 8 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index a05ea476b..0d5c73e7e 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -708,34 +708,37 @@ def _median(a, axis=None, out=None, overwrite_input=False):
asorted = a
else:
asorted = sort(a, axis=axis)
+
if axis is None:
axis = 0
elif axis < 0:
- axis += a.ndim
+ axis += asorted.ndim
if asorted.ndim == 1:
idx, odd = divmod(count(asorted), 2)
- return asorted[idx - (not odd) : idx + 1].mean()
+ return asorted[idx + odd - 1 : idx + 1].mean(out=out)
- counts = asorted.shape[axis] - (asorted.mask).sum(axis=axis)
+ counts = count(asorted, axis=axis)
h = counts // 2
+
# create indexing mesh grid for all but reduced axis
axes_grid = [np.arange(x) for i, x in enumerate(asorted.shape)
if i != axis]
ind = np.meshgrid(*axes_grid, sparse=True, indexing='ij')
+
# insert indices of low and high median
ind.insert(axis, h - 1)
low = asorted[tuple(ind)]
low._sharedmask = False
ind[axis] = h
high = asorted[tuple(ind)]
+
# duplicate high if odd number of elements so mean does nothing
odd = counts % 2 == 1
- if asorted.ndim == 1:
- if odd:
- low = high
- else:
- low[odd] = high[odd]
+ if asorted.ndim > 1:
+ np.copyto(low, high, where=odd)
+ elif odd:
+ low = high
if np.issubdtype(asorted.dtype, np.inexact):
# avoid inf / x = masked
diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py
index 6d56d4dc6..27fac3d63 100644
--- a/numpy/ma/tests/test_extras.py
+++ b/numpy/ma/tests/test_extras.py
@@ -10,6 +10,7 @@ Adapted from the original test_ma by Pierre Gerard-Marchant
from __future__ import division, absolute_import, print_function
import warnings
+import itertools
import numpy as np
from numpy.testing import (
@@ -684,6 +685,37 @@ class TestMedian(TestCase):
assert_equal(ma_x.shape, (2,), "shape mismatch")
assert_(type(ma_x) is MaskedArray)
+ def test_axis_argument_errors(self):
+ msg = "mask = %s, ndim = %s, axis = %s, overwrite_input = %s"
+ for ndmin in range(5):
+ for mask in [False, True]:
+ x = array(1, ndmin=ndmin, mask=mask)
+
+ # Valid axis values should not raise exception
+ args = itertools.product(range(-ndmin, ndmin), [False, True])
+ for axis, over in args:
+ try:
+ np.ma.median(x, axis=axis, overwrite_input=over)
+ except:
+ raise AssertionError(msg % (mask, ndmin, axis, over))
+
+ # Invalid axis values should raise exception
+ args = itertools.product([-(ndmin + 1), ndmin], [False, True])
+ for axis, over in args:
+ try:
+ np.ma.median(x, axis=axis, overwrite_input=over)
+ except IndexError:
+ pass
+ else:
+ raise AssertionError(msg % (mask, ndmin, axis, over))
+
+ def test_masked_0d(self):
+ # Check values
+ x = array(1, mask=False)
+ assert_equal(np.ma.median(x), 1)
+ x = array(1, mask=True)
+ assert_equal(np.ma.median(x), np.ma.masked)
+
def test_masked_1d(self):
x = array(np.arange(5), mask=True)
assert_equal(np.ma.median(x), np.ma.masked)