summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJulian Taylor <jtaylor.debian@googlemail.com>2016-12-06 00:37:53 +0100
committerJulian Taylor <jtaylor.debian@googlemail.com>2016-12-12 23:19:06 +0100
commitff4758ff8432077c06cb580b97c6080f9b312c5a (patch)
treea40b77c06a1446f053f405a31bf0bcff91d3d8eb
parente00b9587052248486a9bf66c2ce95638c0d9817f (diff)
downloadnumpy-ff4758ff8432077c06cb580b97c6080f9b312c5a.tar.gz
BUG: handle unmasked NaN in ma.median like normal median
This requires to base masked median on sort(endwith=False) as we need to distinguish Inf and NaN. Using Inf as filler element of the sort does not work as then the mask is not guaranteed to be at the end. Closes gh-8340 Also fixed 1d ma.median not handling np.inf correctly, the nd variant was ok.
-rw-r--r--numpy/lib/function_base.py18
-rw-r--r--numpy/lib/tests/test_nanfunctions.py42
-rw-r--r--numpy/lib/utils.py46
-rw-r--r--numpy/ma/extras.py52
-rw-r--r--numpy/ma/tests/test_extras.py186
5 files changed, 305 insertions, 39 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 352512513..172e9a322 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -3982,23 +3982,7 @@ def _median(a, axis=None, out=None, overwrite_input=False):
if np.issubdtype(a.dtype, np.inexact) and sz > 0:
# warn and return nans like mean would
rout = mean(part[indexer], axis=axis, out=out)
- part = np.rollaxis(part, axis, part.ndim)
- n = np.isnan(part[..., -1])
- if rout.ndim == 0:
- if n == True:
- warnings.warn("Invalid value encountered in median",
- RuntimeWarning, stacklevel=3)
- if out is not None:
- out[...] = a.dtype.type(np.nan)
- rout = out
- else:
- rout = a.dtype.type(np.nan)
- elif np.count_nonzero(n.ravel()) > 0:
- warnings.warn("Invalid value encountered in median for" +
- " %d results" % np.count_nonzero(n.ravel()),
- RuntimeWarning, stacklevel=3)
- rout[n] = np.nan
- return rout
+ return np.lib.utils._median_nancheck(part, rout, axis, out)
else:
# if there are no nans
# Use mean in odd and even case to coerce data type
diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py
index 06c0953b5..18fcb2887 100644
--- a/numpy/lib/tests/test_nanfunctions.py
+++ b/numpy/lib/tests/test_nanfunctions.py
@@ -693,18 +693,36 @@ class TestNanFunctions_Median(TestCase):
def test_float_special(self):
with suppress_warnings() as sup:
sup.filter(RuntimeWarning)
- a = np.array([[np.inf, np.nan], [np.nan, np.nan]])
- assert_equal(np.nanmedian(a, axis=0), [np.inf, np.nan])
- assert_equal(np.nanmedian(a, axis=1), [np.inf, np.nan])
- assert_equal(np.nanmedian(a), np.inf)
-
- # minimum fill value check
- a = np.array([[np.nan, np.nan, np.inf], [np.nan, np.nan, np.inf]])
- assert_equal(np.nanmedian(a, axis=1), np.inf)
-
- # no mask path
- a = np.array([[np.inf, np.inf], [np.inf, np.inf]])
- assert_equal(np.nanmedian(a, axis=1), np.inf)
+ for inf in [np.inf, -np.inf]:
+ a = np.array([[inf, np.nan], [np.nan, np.nan]])
+ assert_equal(np.nanmedian(a, axis=0), [inf, np.nan])
+ assert_equal(np.nanmedian(a, axis=1), [inf, np.nan])
+ assert_equal(np.nanmedian(a), inf)
+
+ # minimum fill value check
+ a = np.array([[np.nan, np.nan, inf],
+ [np.nan, np.nan, inf]])
+ assert_equal(np.nanmedian(a), inf)
+ assert_equal(np.nanmedian(a, axis=0), [np.nan, np.nan, inf])
+ assert_equal(np.nanmedian(a, axis=1), inf)
+
+ # no mask path
+ a = np.array([[inf, inf], [inf, inf]])
+ assert_equal(np.nanmedian(a, axis=1), inf)
+
+ for i in range(0, 10):
+ for j in range(1, 10):
+ a = np.array([([np.nan] * i) + ([inf] * j)] * 2)
+ assert_equal(np.nanmedian(a), inf)
+ assert_equal(np.nanmedian(a, axis=1), inf)
+ assert_equal(np.nanmedian(a, axis=0),
+ ([np.nan] * i) + [inf] * j)
+
+ a = np.array([([np.nan] * i) + ([-inf] * j)] * 2)
+ assert_equal(np.nanmedian(a), -inf)
+ assert_equal(np.nanmedian(a, axis=1), -inf)
+ assert_equal(np.nanmedian(a, axis=0),
+ ([np.nan] * i) + [-inf] * j)
class TestNanFunctions_Percentile(TestCase):
diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py
index 5c364268c..61aa5e33b 100644
--- a/numpy/lib/utils.py
+++ b/numpy/lib/utils.py
@@ -8,6 +8,7 @@ import warnings
from numpy.core.numerictypes import issubclass_, issubsctype, issubdtype
from numpy.core import ndarray, ufunc, asarray
+import numpy as np
# getargspec and formatargspec were removed in Python 3.6
from numpy.compat import getargspec, formatargspec
@@ -1113,4 +1114,49 @@ def safe_eval(source):
import ast
return ast.literal_eval(source)
+
+
+def _median_nancheck(data, result, axis, out):
+ """
+ Utility function to check median result from data for NaN values at the end
+ and return NaN in that case. Input result can also be a MaskedArray.
+
+ Parameters
+ ----------
+ data : array
+ Input data to median function
+ result : Array or MaskedArray
+ Result of median function
+ axis : {int, sequence of int, None}, optional
+ Axis or axes along which the median was computed.
+ out : ndarray, optional
+ Output array in which to place the result.
+ Returns
+ -------
+ median : scalar or ndarray
+ Median or NaN in axes which contained NaN in the input.
+ """
+ if data.size == 0:
+ return result
+ data = np.rollaxis(data, axis, data.ndim)
+ n = np.isnan(data[..., -1])
+ # masked NaN values are ok
+ if np.ma.isMaskedArray(n):
+ n = n.filled(False)
+ if result.ndim == 0:
+ if n == True:
+ warnings.warn("Invalid value encountered in median",
+ RuntimeWarning, stacklevel=3)
+ if out is not None:
+ out[...] = data.dtype.type(np.nan)
+ result = out
+ else:
+ result = data.dtype.type(np.nan)
+ elif np.count_nonzero(n.ravel()) > 0:
+ warnings.warn("Invalid value encountered in median for" +
+ " %d results" % np.count_nonzero(n.ravel()),
+ RuntimeWarning, stacklevel=3)
+ result[n] = np.nan
+ return result
+
#-----------------------------------------------------------------------------
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index e4ff8ef2d..dadf032e0 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -699,15 +699,21 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
return r
def _median(a, axis=None, out=None, overwrite_input=False):
+ # when an unmasked NaN is present return it, so we need to sort the NaN
+ # values behind the mask
+ if np.issubdtype(a.dtype, np.inexact):
+ fill_value = np.inf
+ else:
+ fill_value = None
if overwrite_input:
if axis is None:
asorted = a.ravel()
- asorted.sort()
+ asorted.sort(fill_value=fill_value)
else:
- a.sort(axis=axis)
+ a.sort(axis=axis, fill_value=fill_value)
asorted = a
else:
- asorted = sort(a, axis=axis)
+ asorted = sort(a, axis=axis, fill_value=fill_value)
if axis is None:
axis = 0
@@ -715,8 +721,23 @@ def _median(a, axis=None, out=None, overwrite_input=False):
axis += asorted.ndim
if asorted.ndim == 1:
+ counts = count(asorted)
idx, odd = divmod(count(asorted), 2)
- return asorted[idx + odd - 1 : idx + 1].mean(out=out)
+ mid = asorted[idx + odd - 1 : idx + 1]
+ if np.issubdtype(asorted.dtype, np.inexact) and asorted.size > 0:
+ # avoid inf / x = masked
+ s = mid.sum(out=out)
+ np.true_divide(s, 2., casting='unsafe')
+ s = np.lib.utils._median_nancheck(asorted, s, axis, out)
+ else:
+ s = mid.mean(out=out)
+
+ # if result is masked either the input contained enough
+ # minimum_fill_value so that it would be the median or all values
+ # masked
+ if np.ma.is_masked(s) and not np.all(asorted.mask):
+ return np.ma.minimum_fill_value(asorted)
+ return s
counts = count(asorted, axis=axis)
h = counts // 2
@@ -727,24 +748,35 @@ def _median(a, axis=None, out=None, overwrite_input=False):
ind = np.meshgrid(*axes_grid, sparse=True, indexing='ij')
# insert indices of low and high median
- ind.insert(axis, np.maximum(0, h - 1))
+ ind.insert(axis, h - 1)
low = asorted[tuple(ind)]
- ind[axis] = h
+ ind[axis] = np.minimum(h, asorted.shape[axis] - 1)
high = asorted[tuple(ind)]
# duplicate high if odd number of elements so mean does nothing
odd = counts % 2 == 1
- if asorted.ndim > 1:
- np.copyto(low, high, where=odd)
- elif odd:
- low = high
+ np.copyto(low, high, where=odd)
+ # not necessary for scalar True/False masks
+ try:
+ np.copyto(low.mask, high.mask, where=odd)
+ except:
+ pass
if np.issubdtype(asorted.dtype, np.inexact):
# avoid inf / x = masked
s = np.ma.sum([low, high], axis=0, out=out)
np.true_divide(s.data, 2., casting='unsafe', out=s.data)
+
+ s = np.lib.utils._median_nancheck(asorted, s, axis, out)
else:
s = np.ma.mean([low, high], axis=0, out=out)
+
+ # if result is masked either the input contained enough minimum_fill_value
+ # so that it would be the median or all values masked
+ if np.ma.is_masked(s):
+ rep = (~np.all(asorted.mask, axis=axis)) & s.mask
+ s.data[rep] = np.ma.minimum_fill_value(asorted)
+ s.mask[rep] = False
return s
diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py
index 0a6de4eba..faee4f599 100644
--- a/numpy/ma/tests/test_extras.py
+++ b/numpy/ma/tests/test_extras.py
@@ -659,6 +659,15 @@ class TestMedian(TestCase):
r = np.ma.median([[np.inf, np.inf], [np.inf, np.inf]], axis=-1)
assert_equal(r, np.inf)
+ def test_inf(self):
+ # test that even which computes handles inf / x = masked
+ r = np.ma.median(np.ma.masked_array([[np.inf, np.inf],
+ [np.inf, np.inf]]), axis=-1)
+ assert_equal(r, np.inf)
+ r = np.ma.median(np.ma.masked_array([[np.inf, np.inf],
+ [np.inf, np.inf]]), axis=None)
+ assert_equal(r, np.inf)
+
def test_non_masked(self):
x = np.arange(9)
assert_equal(np.ma.median(x), 4.)
@@ -799,6 +808,183 @@ class TestMedian(TestCase):
assert_array_equal(np.ma.median(masked_arr, axis=0),
expected)
+ def test_nan(self):
+ with suppress_warnings() as w:
+ w.record(RuntimeWarning)
+ w.filter(DeprecationWarning, message=r"in 3\.x, __getslice__")
+ for mask in (False, np.zeros(6, dtype=np.bool)):
+ dm = np.ma.array([[1, np.nan, 3], [1, 2, 3]])
+ dm.mask = mask
+
+ # scalar result
+ r = np.ma.median(dm, axis=None)
+ assert_(np.isscalar(r))
+ assert_array_equal(r, np.nan)
+ r = np.ma.median(dm.ravel(), axis=0)
+ assert_(np.isscalar(r))
+ assert_array_equal(r, np.nan)
+
+ r = np.ma.median(dm, axis=0)
+ assert_equal(type(r), MaskedArray)
+ assert_array_equal(r, [1, np.nan, 3])
+ r = np.ma.median(dm, axis=1)
+ assert_equal(type(r), MaskedArray)
+ assert_array_equal(r, [np.nan, 2])
+ r = np.ma.median(dm, axis=-1)
+ assert_equal(type(r), MaskedArray)
+ assert_array_equal(r, [np.nan, 2])
+
+ dm = np.ma.array([[1, np.nan, 3], [1, 2, 3]])
+ dm[:, 2] = np.ma.masked
+ assert_array_equal(np.ma.median(dm, axis=None), np.nan)
+ assert_array_equal(np.ma.median(dm, axis=0), [1, np.nan, 3])
+ assert_array_equal(np.ma.median(dm, axis=1), [np.nan, 1.5])
+ assert_equal([x.category is RuntimeWarning for x in w.log],
+ [True]*13)
+
+ def test_out_nan(self):
+ with warnings.catch_warnings(record=True):
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ o = np.ma.masked_array(np.zeros((4,)))
+ d = np.ma.masked_array(np.ones((3, 4)))
+ d[2, 1] = np.nan
+ d[2, 2] = np.ma.masked
+ assert_equal(np.ma.median(d, 0, out=o), o)
+ o = np.ma.masked_array(np.zeros((3,)))
+ assert_equal(np.ma.median(d, 1, out=o), o)
+ o = np.ma.masked_array(np.zeros(()))
+ assert_equal(np.ma.median(d, out=o), o)
+
+ def test_nan_behavior(self):
+ a = np.ma.masked_array(np.arange(24, dtype=float))
+ a[::3] = np.ma.masked
+ a[2] = np.nan
+ with suppress_warnings() as w:
+ w.record(RuntimeWarning)
+ w.filter(DeprecationWarning, message=r"in 3\.x, __getslice__")
+ assert_array_equal(np.ma.median(a), np.nan)
+ assert_array_equal(np.ma.median(a, axis=0), np.nan)
+ assert_(w.log[0].category is RuntimeWarning)
+ assert_(w.log[1].category is RuntimeWarning)
+
+ a = np.ma.masked_array(np.arange(24, dtype=float).reshape(2, 3, 4))
+ a.mask = np.arange(a.size) % 2 == 1
+ aorig = a.copy()
+ a[1, 2, 3] = np.nan
+ a[1, 1, 2] = np.nan
+
+ # no axis
+ with suppress_warnings() as w:
+ w.record(RuntimeWarning)
+ w.filter(DeprecationWarning, message=r"in 3\.x, __getslice__")
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ assert_array_equal(np.ma.median(a), np.nan)
+ assert_(np.isscalar(np.ma.median(a)))
+ assert_(w.log[0].category is RuntimeWarning)
+
+ # axis0
+ b = np.ma.median(aorig, axis=0)
+ b[2, 3] = np.nan
+ b[1, 2] = np.nan
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ assert_equal(np.ma.median(a, 0), b)
+ assert_equal(len(w), 1)
+
+ # axis1
+ b = np.ma.median(aorig, axis=1)
+ b[1, 3] = np.nan
+ b[1, 2] = np.nan
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ assert_equal(np.ma.median(a, 1), b)
+ assert_equal(len(w), 1)
+
+ # axis02
+ b = np.ma.median(aorig, axis=(0, 2))
+ b[1] = np.nan
+ b[2] = np.nan
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ assert_equal(np.ma.median(a, (0, 2)), b)
+ assert_equal(len(w), 1)
+
+ def test_ambigous_fill(self):
+ # 255 is max value, used as filler for sort
+ a = np.array([[3, 3, 255], [3, 3, 255]], dtype=np.uint8)
+ a = np.ma.masked_array(a, mask=a == 3)
+ assert_array_equal(np.ma.median(a, axis=1), 255)
+ assert_array_equal(np.ma.median(a, axis=1).mask, False)
+ assert_array_equal(np.ma.median(a, axis=0), a[0])
+ assert_array_equal(np.ma.median(a), 255)
+
+ def test_special(self):
+ for inf in [np.inf, -np.inf]:
+ a = np.array([[inf, np.nan], [np.nan, np.nan]])
+ a = np.ma.masked_array(a, mask=np.isnan(a))
+ assert_equal(np.ma.median(a, axis=0), [inf, np.nan])
+ assert_equal(np.ma.median(a, axis=1), [inf, np.nan])
+ assert_equal(np.ma.median(a), inf)
+
+ a = np.array([[np.nan, np.nan, inf], [np.nan, np.nan, inf]])
+ a = np.ma.masked_array(a, mask=np.isnan(a))
+ assert_array_equal(np.ma.median(a, axis=1), inf)
+ assert_array_equal(np.ma.median(a, axis=1).mask, False)
+ assert_array_equal(np.ma.median(a, axis=0), a[0])
+ assert_array_equal(np.ma.median(a), inf)
+
+ # no mask
+ a = np.array([[inf, inf], [inf, inf]])
+ assert_equal(np.ma.median(a), inf)
+ assert_equal(np.ma.median(a, axis=0), inf)
+ assert_equal(np.ma.median(a, axis=1), inf)
+
+ for i in range(0, 10):
+ for j in range(1, 10):
+ a = np.array([([np.nan] * i) + ([inf] * j)] * 2)
+ a = np.ma.masked_array(a, mask=np.isnan(a))
+ assert_equal(np.ma.median(a), inf)
+ assert_equal(np.ma.median(a, axis=1), inf)
+ assert_equal(np.ma.median(a, axis=0),
+ ([np.nan] * i) + [inf] * j)
+
+ def test_empty(self):
+ # empty arrays
+ a = np.ma.masked_array(np.array([], dtype=float))
+ with suppress_warnings() as w:
+ w.record(RuntimeWarning)
+ w.filter(DeprecationWarning, message=r"in 3\.x, __getslice__")
+ assert_array_equal(np.ma.median(a), np.nan)
+ assert_(w.log[0].category is RuntimeWarning)
+
+ # multiple dimensions
+ a = np.ma.masked_array(np.array([], dtype=float, ndmin=3))
+ # no axis
+ with suppress_warnings() as w:
+ w.record(RuntimeWarning)
+ w.filter(DeprecationWarning, message=r"in 3\.x, __getslice__")
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ assert_array_equal(np.ma.median(a), np.nan)
+ assert_(w.log[0].category is RuntimeWarning)
+
+ # axis 0 and 1
+ b = np.ma.masked_array(np.array([], dtype=float, ndmin=2))
+ assert_equal(np.median(a, axis=0), b)
+ assert_equal(np.median(a, axis=1), b)
+
+ # axis 2
+ b = np.ma.masked_array(np.array(np.nan, dtype=float, ndmin=2))
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ assert_equal(np.median(a, axis=2), b)
+ assert_(w[0].category is RuntimeWarning)
+
+ def test_object(self):
+ o = np.ma.masked_array(np.arange(7.))
+ assert_(type(np.ma.median(o.astype(object))), float)
+ o[2] = np.nan
+ assert_(type(np.ma.median(o.astype(object))), float)
+
class TestCov(TestCase):