summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorempeeu <empeeu@yahoo.com>2014-03-08 09:32:03 -0500
committerempeeu <empeeu@users.noreply.github.com>2015-06-22 19:41:58 -0400
commita320fd772468004a53f7c448ae47032eb1b5c5df (patch)
treebf37a3da7ec96d2449f3b23642fc972990ed623a /numpy/lib
parent81c2c16f3218c879f5bfeacd80f237336e56584d (diff)
downloadnumpy-a320fd772468004a53f7c448ae47032eb1b5c5df.tar.gz
BUG: Added proper handling of median and percentile when nan's are present in array to close issue #586.
Also added unit tests.
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/function_base.py118
-rw-r--r--numpy/lib/tests/test_function_base.py178
2 files changed, 269 insertions, 27 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 26d25cd6d..762d338bb 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -3029,41 +3029,37 @@ def _median(a, axis=None, out=None, overwrite_input=False):
# can't be reasonably be implemented in terms of percentile as we have to
# call mean to not break astropy
a = np.asanyarray(a)
- if axis is not None and axis >= a.ndim:
- raise IndexError(
- "axis %d out of bounds (%d)" % (axis, a.ndim))
+
+ # Set the partition indexes
+ if axis is None:
+ sz = a.size
+ else:
+ sz = a.shape[axis]
+ if sz % 2 == 0:
+ szh = sz // 2
+ kth = [szh - 1, szh]
+ else:
+ kth = [(sz - 1) // 2]
+ # Check if the array contains any nan's
+ if np.issubdtype(a.dtype, np.inexact):
+ kth.append(-1)
if overwrite_input:
if axis is None:
part = a.ravel()
- sz = part.size
- if sz % 2 == 0:
- szh = sz // 2
- part.partition((szh - 1, szh))
- else:
- part.partition((sz - 1) // 2)
+ part.partition(kth)
else:
- sz = a.shape[axis]
- if sz % 2 == 0:
- szh = sz // 2
- a.partition((szh - 1, szh), axis=axis)
- else:
- a.partition((sz - 1) // 2, axis=axis)
+ a.partition(kth, axis=axis)
part = a
else:
- if axis is None:
- sz = a.size
- else:
- sz = a.shape[axis]
- if sz % 2 == 0:
- part = partition(a, ((sz // 2) - 1, sz // 2), axis=axis)
- else:
- part = partition(a, (sz - 1) // 2, axis=axis)
+ part = partition(a, kth, axis=axis)
+
if part.shape == ():
# make 0-D arrays work
return part.item()
if axis is None:
axis = 0
+
indexer = [slice(None)] * part.ndim
index = part.shape[axis] // 2
if part.shape[axis] % 2 == 1:
@@ -3071,9 +3067,33 @@ def _median(a, axis=None, out=None, overwrite_input=False):
indexer[axis] = slice(index, index+1)
else:
indexer[axis] = slice(index-1, index+1)
- # Use mean in odd and even case to coerce data type
- # and check, use out array.
- return mean(part[indexer], axis=axis, out=out)
+
+ # Check if the array contains any nan's
+ if np.issubdtype(a.dtype, np.inexact):
+ # warn and return nans like mean would
+ rout = mean(part[indexer], axis=axis, out=out)
+ part = np.rollaxis(part, axis, part.ndim)
+ n = np.isnan(part[..., -1])
+ if rout.ndim == 0:
+ if n == True:
+ warnings.warn("Invalid value encountered in median",
+ RuntimeWarning)
+ if out is not None:
+ out[...] = a.dtype.type(np.nan)
+ rout = out
+ else:
+ rout = a.dtype.type(np.nan)
+ else:
+ for i in range(np.count_nonzero(n.ravel())):
+ warnings.warn("Invalid value encountered in median",
+ RuntimeWarning)
+ rout[n] = np.nan
+ return rout
+ else:
+ # if there are no nans
+ # Use mean in odd and even case to coerce data type
+ # and check, use out array.
+ return mean(part[indexer], axis=axis, out=out)
def percentile(a, q, axis=None, out=None,
@@ -3249,20 +3269,36 @@ def _percentile(a, q, axis=None, out=None,
"interpolation can only be 'linear', 'lower' 'higher', "
"'midpoint', or 'nearest'")
+ n = np.array(False, dtype=bool) # check for nan's flag
if indices.dtype == intp: # take the points along axis
+ # Check if the array contains any nan's
+ if np.issubdtype(a.dtype, np.inexact):
+ indices = concatenate((indices, [-1]))
+
ap.partition(indices, axis=axis)
# ensure axis with qth is first
ap = np.rollaxis(ap, axis, 0)
axis = 0
+ # Check if the array contains any nan's
+ if np.issubdtype(a.dtype, np.inexact):
+ indices = indices[:-1]
+ n = np.isnan(ap[-1:, ...])
+
if zerod:
indices = indices[0]
r = take(ap, indices, axis=axis, out=out)
+
+
else: # weight the points above and below the indices
indices_below = floor(indices).astype(intp)
indices_above = indices_below + 1
indices_above[indices_above > Nx - 1] = Nx - 1
+ # Check if the array contains any nan's
+ if np.issubdtype(a.dtype, np.inexact):
+ indices_above = concatenate((indices_above, [-1]))
+
weights_above = indices - indices_below
weights_below = 1.0 - weights_above
@@ -3272,6 +3308,18 @@ def _percentile(a, q, axis=None, out=None,
weights_above.shape = weights_shape
ap.partition(concatenate((indices_below, indices_above)), axis=axis)
+
+ # ensure axis with qth is first
+ ap = np.rollaxis(ap, axis, 0)
+ weights_below = np.rollaxis(weights_below, axis, 0)
+ weights_above = np.rollaxis(weights_above, axis, 0)
+ axis = 0
+
+ # Check if the array contains any nan's
+ if np.issubdtype(a.dtype, np.inexact):
+ indices_above = indices_above[:-1]
+ n = np.isnan(ap[-1:, ...])
+
x1 = take(ap, indices_below, axis=axis) * weights_below
x2 = take(ap, indices_above, axis=axis) * weights_above
@@ -3288,6 +3336,24 @@ def _percentile(a, q, axis=None, out=None,
else:
r = add(x1, x2)
+ if np.any(n):
+ warnings.warn("Invalid value encountered in median",
+ RuntimeWarning)
+ if zerod:
+ if ap.ndim == 1:
+ if out is not None:
+ out[...] = a.dtype.type(np.nan)
+ r = out
+ else:
+ r = a.dtype.type(np.nan)
+ else:
+ r[..., n.squeeze(0)] = a.dtype.type(np.nan)
+ else:
+ if r.ndim == 1:
+ r[:] = a.dtype.type(np.nan)
+ else:
+ r[..., n.repeat(q.size, 0)] = a.dtype.type(np.nan)
+
return r
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index ad71fd3fa..397505bd0 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -1785,6 +1785,12 @@ class TestScoreatpercentile(TestCase):
assert_equal(np.percentile(x, 0), 0.)
assert_equal(np.percentile(x, 100), 3.5)
assert_equal(np.percentile(x, 50), 1.75)
+ x[1] = np.nan
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ assert_equal(np.percentile(x, 0), np.nan)
+ assert_equal(np.percentile(x, 0, interpolation='nearest'), np.nan)
+ assert_(w[0].category is RuntimeWarning)
def test_api(self):
d = np.ones(5)
@@ -2080,7 +2086,107 @@ class TestScoreatpercentile(TestCase):
keepdims=True).shape, (2, 1, 1, 7, 1))
assert_equal(np.percentile(d, [1, 7], axis=(0, 3),
keepdims=True).shape, (2, 1, 5, 7, 1))
+ def test_out(self):
+ o = np.zeros((4,))
+ d = np.ones((3, 4))
+ assert_equal(np.percentile(d, 0, 0, out=o), o)
+ assert_equal(np.percentile(d, 0, 0, interpolation='nearest', out=o), o)
+ o = np.zeros((3,))
+ assert_equal(np.percentile(d, 1, 1, out=o), o)
+ assert_equal(np.percentile(d, 1, 1, interpolation='nearest', out=o), o)
+
+ o = np.zeros(())
+ assert_equal(np.percentile(d, 2, out=o), o)
+ assert_equal(np.percentile(d, 2, interpolation='nearest', out=o), o)
+
+
+ def test_out_nan(self):
+ with warnings.catch_warnings(record=True):
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ o = np.zeros((4,))
+ d = np.ones((3, 4))
+ d[2, 1] = np.nan
+ assert_equal(np.percentile(d, 0, 0, out=o), o)
+ assert_equal(np.percentile(d, 0, 0, interpolation='nearest', out=o), o)
+ o = np.zeros((3,))
+ assert_equal(np.percentile(d, 1, 1, out=o), o)
+ assert_equal(np.percentile(d, 1, 1, interpolation='nearest', out=o), o)
+ o = np.zeros(())
+ assert_equal(np.percentile(d, 1, out=o), o)
+ assert_equal(np.percentile(d, 1, interpolation='nearest', out=o), o)
+
+ def test_nan_behavior(self):
+ a = np.arange(24, dtype=float)
+ a[2] = np.nan
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ assert_equal(np.percentile(a, 0.3), np.nan)
+ assert_equal(np.percentile(a, 0.3, axis=0), np.nan)
+ assert_equal(np.percentile(a, [0.3, 0.6], axis=0),
+ np.array([np.nan] * 2))
+ assert_(w[0].category is RuntimeWarning)
+ assert_(w[1].category is RuntimeWarning)
+
+ a = np.arange(24, dtype=float).reshape(2, 3, 4)
+ a[1, 2, 3] = np.nan
+ a[1, 1, 2] = np.nan
+
+ # no axis
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ assert_equal(np.percentile(a, 0.3), np.nan)
+ assert_equal(np.percentile(a, 0.3).ndim, 0)
+ assert_(w[0].category is RuntimeWarning)
+
+ # axis0 zerod
+ b = np.percentile(np.arange(24, dtype=float).reshape(2, 3, 4), 0.3, 0)
+ b[2, 3] = np.nan; b[1, 2] = np.nan
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ assert_equal(np.percentile(a, 0.3, 0), b)
+ # axis0 not zerod
+ b = np.percentile(np.arange(24, dtype=float).reshape(2, 3, 4),
+ [0.3, 0.6], 0)
+ b[:, 2, 3] = np.nan; b[:, 1, 2] = np.nan
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ assert_equal(np.percentile(a, [0.3, 0.6], 0), b)
+
+ # axis1 zerod
+ b = np.percentile(np.arange(24, dtype=float).reshape(2, 3, 4), 0.3, 1)
+ b[1, 3] = np.nan; b[1, 2] = np.nan
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ assert_equal(np.percentile(a, 0.3, 1), b)
+ # axis1 not zerod
+ b = np.percentile(np.arange(24, dtype=float).reshape(2, 3, 4), [0.3, 0.6], 1)
+ b[:, 1, 3] = np.nan; b[:, 1, 2] = np.nan
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ assert_equal(np.percentile(a, [0.3, 0.6], 1), b)
+
+ # axis02 zerod
+ b = np.percentile(np.arange(24, dtype=float).reshape(2, 3, 4), 0.3, (0, 2))
+ b[1] = np.nan; b[2] = np.nan
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ assert_equal(np.percentile(a, 0.3, (0, 2)), b)
+ # axis02 not zerod
+ b = np.percentile(np.arange(24, dtype=float).reshape(2, 3, 4),
+ [0.3, 0.6], (0, 2))
+ b[:, 1] = np.nan; b[:, 2] = np.nan
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ assert_equal(np.percentile(a, [0.3, 0.6], (0, 2)), b)
+ # axis02 not zerod with nearest interpolation
+ b = np.percentile(np.arange(24, dtype=float).reshape(2, 3, 4),
+ [0.3, 0.6], (0, 2), interpolation='nearest')
+ b[:, 1] = np.nan; b[:, 2] = np.nan
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ assert_equal(np.percentile(
+ a, [0.3, 0.6], (0, 2), interpolation='nearest'), b)
class TestMedian(TestCase):
def test_basic(self):
@@ -2103,7 +2209,11 @@ class TestMedian(TestCase):
# check array scalar result
assert_equal(np.median(a).ndim, 0)
a[1] = np.nan
- assert_equal(np.median(a).ndim, 0)
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ assert_equal(np.median(a).ndim, 0)
+ assert_(w[0].category is RuntimeWarning)
+
def test_axis_keyword(self):
a3 = np.array([[2, 3],
@@ -2177,6 +2287,72 @@ class TestMedian(TestCase):
a = MySubClass([1,2,3])
assert_equal(np.median(a), -7)
+ def test_out(self):
+ o = np.zeros((4,))
+ d = np.ones((3, 4))
+ assert_equal(np.median(d, 0, out=o), o)
+ o = np.zeros((3,))
+ assert_equal(np.median(d, 1, out=o), o)
+ o = np.zeros(())
+ assert_equal(np.median(d, out=o), o)
+
+ def test_out_nan(self):
+ with warnings.catch_warnings(record=True):
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ o = np.zeros((4,))
+ d = np.ones((3, 4))
+ d[2, 1] = np.nan
+ assert_equal(np.median(d, 0, out=o), o)
+ o = np.zeros((3,))
+ assert_equal(np.median(d, 1, out=o), o)
+ o = np.zeros(())
+ assert_equal(np.median(d, out=o), o)
+
+ def test_nan_behavior(self):
+ a = np.arange(24, dtype=float)
+ a[2] = np.nan
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ assert_equal(np.median(a), np.nan)
+ assert_equal(np.median(a, axis=0), np.nan)
+ assert_(w[0].category is RuntimeWarning)
+ assert_(w[1].category is RuntimeWarning)
+
+ a = np.arange(24, dtype=float).reshape(2, 3, 4)
+ a[1, 2, 3] = np.nan
+ a[1, 1, 2] = np.nan
+
+ #no axis
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ assert_equal(np.median(a), np.nan)
+ assert_equal(np.median(a).ndim, 0)
+ assert_(w[0].category is RuntimeWarning)
+
+ #axis0
+ b = np.median(np.arange(24, dtype=float).reshape(2, 3, 4), 0)
+ b[2, 3] = np.nan; b[1, 2] = np.nan
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ assert_equal(np.median(a, 0), b)
+ assert_equal(len(w), 2)
+
+ #axis1
+ b = np.median(np.arange(24, dtype=float).reshape(2, 3, 4), 1)
+ b[1, 3] = np.nan; b[1, 2] = np.nan
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ assert_equal(np.median(a, 1), b)
+ assert_equal(len(w), 2)
+
+ #axis02
+ b = np.median(np.arange(24, dtype=float).reshape(2, 3, 4), (0, 2))
+ b[1] = np.nan; b[2] = np.nan
+ with warnings.catch_warnings(record=True) as w:
+ warnings.filterwarnings('always', '', RuntimeWarning)
+ assert_equal(np.median(a, (0, 2)), b)
+ assert_equal(len(w), 2)
+
def test_object(self):
o = np.arange(7.);
assert_(type(np.median(o.astype(object))), float)