diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-02-23 14:55:01 +0000 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-03-07 17:33:20 +0000 |
commit | 11c5a9f3fed6ccd2a8f77b22cc1abb405b9d24ab (patch) | |
tree | 42c8071676ddf156444de866054c1ad5c6cd9828 /numpy | |
parent | f91eb364283bf6066f3a18e4b9738bc3452d155b (diff) | |
download | numpy-11c5a9f3fed6ccd2a8f77b22cc1abb405b9d24ab.tar.gz |
BUG: Make MaskedArray.argsort and MaskedArray.sort consistent
Previously, these had different rules for unmasking values, and even different
arguments to decide how to do so.
Fixes #8664
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/lib/tests/test_arraysetops.py | 12 | ||||
-rw-r--r-- | numpy/ma/core.py | 47 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 14 |
3 files changed, 57 insertions, 16 deletions
diff --git a/numpy/lib/tests/test_arraysetops.py b/numpy/lib/tests/test_arraysetops.py index 8b142c264..eb4cca0ce 100644 --- a/numpy/lib/tests/test_arraysetops.py +++ b/numpy/lib/tests/test_arraysetops.py @@ -352,6 +352,18 @@ class TestUnique(TestCase): result = np.array([[-0.0, 0.0]]) assert_array_equal(unique(data, axis=0), result, msg) + def test_unique_masked(self): + # issue 8664 + x = np.array([64, 0, 1, 2, 3, 63, 63, 0, 0, 0, 1, 2, 0, 63, 0], dtype='uint8') + y = np.ma.masked_equal(x, 0) + + v = np.unique(y) + v2, i, c = np.unique(y, return_index=True, return_counts=True) + + msg = 'Unique returned different results when asked for index' + assert_array_equal(v.data, v2.data, msg) + assert_array_equal(v.mask, v2.mask, msg) + def _run_axis_tests(self, dtype): data = np.array([[0, 1, 0, 0], [1, 0, 0, 0], diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 427c146a2..219bea2f6 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -5230,7 +5230,8 @@ class MaskedArray(ndarray): out.__setmask__(self._mask) return out - def argsort(self, axis=None, kind='quicksort', order=None, fill_value=None): + def argsort(self, axis=None, kind='quicksort', order=None, + endwith=True, fill_value=None): """ Return an ndarray of indices that sort the array along the specified axis. Masked values are filled beforehand to @@ -5241,15 +5242,21 @@ class MaskedArray(ndarray): axis : int, optional Axis along which to sort. The default is -1 (last axis). If None, the flattened array is used. - fill_value : var, optional - Value used to fill the array before sorting. - The default is the `fill_value` attribute of the input array. kind : {'quicksort', 'mergesort', 'heapsort'}, optional Sorting algorithm. order : list, optional When `a` is an array with fields defined, this argument specifies which fields to compare first, second, etc. Not all fields need be specified. + endwith : {True, False}, optional + Whether missing values (if any) should be treated as the largest values + (True) or the smallest values (False) + When the array contains unmasked values at the same extremes of the + datatype, the ordering of these values and the masked values is + undefined. + fill_value : {var}, optional + Value used internally for the masked values. + If ``fill_value`` is not None, it supersedes ``endwith``. Returns ------- @@ -5259,7 +5266,7 @@ class MaskedArray(ndarray): See Also -------- - sort : Describes sorting algorithms used. + MaskedArray.sort : Describes sorting algorithms used. lexsort : Indirect stable sort with multiple keys. ndarray.sort : Inplace sort. @@ -5278,10 +5285,19 @@ class MaskedArray(ndarray): array([1, 0, 2]) """ + if fill_value is None: - fill_value = default_fill_value(self) - d = self.filled(fill_value).view(ndarray) - return d.argsort(axis=axis, kind=kind, order=order) + if endwith: + # nan > inf + if np.issubdtype(self.dtype, np.floating): + fill_value = np.nan + else: + fill_value = minimum_fill_value(self) + else: + fill_value = maximum_fill_value(self) + + filled = self.filled(fill_value) + return filled.argsort(axis=axis, kind=kind, order=order) def argmin(self, axis=None, fill_value=None, out=None): """ @@ -5380,12 +5396,11 @@ class MaskedArray(ndarray): to compare first, second, and so on. This list does not need to include all of the fields. endwith : {True, False}, optional - Whether missing values (if any) should be forced in the upper indices - (at the end of the array) (True) or lower indices (at the beginning). - When the array contains unmasked values of the largest (or smallest if - False) representable value of the datatype the ordering of these values - and the masked values is undefined. To enforce the masked values are - at the end (beginning) in this case one must sort the mask. + Whether missing values (if any) should be treated as the largest values + (True) or the smallest values (False) + When the array contains unmasked values at the same extremes of the + datatype, the ordering of these values and the masked values is + undefined. fill_value : {var}, optional Value used internally for the masked values. If ``fill_value`` is not None, it supersedes ``endwith``. @@ -6503,13 +6518,13 @@ def power(a, b, third=None): argmin = _frommethod('argmin') argmax = _frommethod('argmax') -def argsort(a, axis=None, kind='quicksort', order=None, fill_value=None): +def argsort(a, axis=None, kind='quicksort', order=None, endwith=True, fill_value=None): "Function version of the eponymous method." a = np.asanyarray(a) if isinstance(a, MaskedArray): return a.argsort(axis=axis, kind=kind, order=order, - fill_value=fill_value) + endwith=endwith, fill_value=fill_value) else: return a.argsort(axis=axis, kind=kind, order=order) argsort.__doc__ = MaskedArray.argsort.__doc__ diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 93898c4d0..a65cac8c8 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -3031,6 +3031,20 @@ class TestMaskedArrayMethods(TestCase): assert_equal(sortedx._data, [1, 2, -2, -1, 0]) assert_equal(sortedx._mask, [1, 1, 0, 0, 0]) + def test_argsort_matches_sort(self): + x = array([1, 4, 2, 3], mask=[0, 1, 0, 0], dtype=np.uint8) + + for kwargs in [dict(), + dict(endwith=True), + dict(endwith=False), + dict(fill_value=2), + dict(fill_value=2, endwith=True), + dict(fill_value=2, endwith=False)]: + sortedx = sort(x, **kwargs) + argsortedx = x[argsort(x, **kwargs)] + assert_equal(sortedx._data, argsortedx._data) + assert_equal(sortedx._mask, argsortedx._mask) + def test_sort_2d(self): # Check sort of 2D array. # 2D array w/o mask |