diff options
author | Julian Taylor <jtaylor.debian@googlemail.com> | 2014-05-31 13:20:43 +0200 |
---|---|---|
committer | Julian Taylor <jtaylor.debian@googlemail.com> | 2014-06-02 23:47:24 +0200 |
commit | 99ff7a7cad36fcb5ba239bccd87a4f01ad25a1c1 (patch) | |
tree | 6bdb286e11684864ba3ff86a0d913a95863235b1 /numpy | |
parent | 0ae36289aa5104fce4e40c63ba46e19365f33b5d (diff) | |
download | numpy-99ff7a7cad36fcb5ba239bccd87a4f01ad25a1c1.tar.gz |
ENH: use masked median for small multidimensional nanmedians
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/lib/nanfunctions.py | 20 | ||||
-rw-r--r-- | numpy/lib/tests/test_nanfunctions.py | 18 |
2 files changed, 36 insertions, 2 deletions
diff --git a/numpy/lib/nanfunctions.py b/numpy/lib/nanfunctions.py index 478e7cf7e..7120760b5 100644 --- a/numpy/lib/nanfunctions.py +++ b/numpy/lib/nanfunctions.py @@ -635,7 +635,7 @@ def _nanmedian(a, axis=None, out=None, overwrite_input=False): See nanmedian for parameter usage """ - if axis is None: + if axis is None or a.ndim == 1: part = a.ravel() if out is None: return _nanmedian1d(part, overwrite_input) @@ -643,11 +643,29 @@ def _nanmedian(a, axis=None, out=None, overwrite_input=False): out[...] = _nanmedian1d(part, overwrite_input) return out else: + # for small medians use sort + indexing which is still faster than + # apply_along_axis + if a.shape[axis] < 400: + return _nanmedian_small(a, axis, out, overwrite_input) result = np.apply_along_axis(_nanmedian1d, axis, a, overwrite_input) if out is not None: out[...] = result return result +def _nanmedian_small(a, axis=None, out=None, overwrite_input=False): + """ + sort + indexing median, faster for small medians along multiple dimensions + due to the high overhead of apply_along_axis + see nanmedian for parameter usage + """ + a = np.ma.masked_array(a, np.isnan(a)) + m = np.ma.median(a, axis=axis, overwrite_input=overwrite_input) + for i in range(np.count_nonzero(m.mask.ravel())): + warnings.warn("All-NaN slice encountered", RuntimeWarning) + if out is not None: + out[...] = m.filled(np.nan) + return out + return m.filled(np.nan) def nanmedian(a, axis=None, out=None, overwrite_input=False, keepdims=False): """ diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py index 3fcfca218..c5af61434 100644 --- a/numpy/lib/tests/test_nanfunctions.py +++ b/numpy/lib/tests/test_nanfunctions.py @@ -5,7 +5,7 @@ import warnings import numpy as np from numpy.testing import ( run_module_suite, TestCase, assert_, assert_equal, assert_almost_equal, - assert_raises + assert_raises, assert_array_equal ) @@ -580,6 +580,22 @@ class TestNanFunctions_Median(TestCase): assert_almost_equal(res, resout) assert_almost_equal(res, tgt) + def test_small_large(self): + # test the small and large code paths, current cutoff 400 elements + for s in [5, 20, 51, 200, 1000]: + d = np.random.randn(4, s) + # Randomly set some elements to NaN: + w = np.random.randint(0, d.size, size=d.size // 5) + d.ravel()[w] = np.nan + d[:,0] = 1. # ensure at least one good value + # use normal median without nans to compare + tgt = [] + for x in d: + nonan = np.compress(~np.isnan(x), x) + tgt.append(np.median(nonan, overwrite_input=True)) + + assert_array_equal(np.nanmedian(d, axis=-1), tgt) + def test_result_values(self): tgt = [np.median(d) for d in _rdat] res = np.nanmedian(_ndat, axis=1) |