summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-02-23 14:55:01 +0000
committerEric Wieser <wieser.eric@gmail.com>2017-03-07 17:33:20 +0000
commit11c5a9f3fed6ccd2a8f77b22cc1abb405b9d24ab (patch)
tree42c8071676ddf156444de866054c1ad5c6cd9828 /numpy
parentf91eb364283bf6066f3a18e4b9738bc3452d155b (diff)
downloadnumpy-11c5a9f3fed6ccd2a8f77b22cc1abb405b9d24ab.tar.gz
BUG: Make MaskedArray.argsort and MaskedArray.sort consistent
Previously, these had different rules for unmasking values, and even different arguments to decide how to do so. Fixes #8664
Diffstat (limited to 'numpy')
-rw-r--r--numpy/lib/tests/test_arraysetops.py12
-rw-r--r--numpy/ma/core.py47
-rw-r--r--numpy/ma/tests/test_core.py14
3 files changed, 57 insertions, 16 deletions
diff --git a/numpy/lib/tests/test_arraysetops.py b/numpy/lib/tests/test_arraysetops.py
index 8b142c264..eb4cca0ce 100644
--- a/numpy/lib/tests/test_arraysetops.py
+++ b/numpy/lib/tests/test_arraysetops.py
@@ -352,6 +352,18 @@ class TestUnique(TestCase):
result = np.array([[-0.0, 0.0]])
assert_array_equal(unique(data, axis=0), result, msg)
+ def test_unique_masked(self):
+ # issue 8664
+ x = np.array([64, 0, 1, 2, 3, 63, 63, 0, 0, 0, 1, 2, 0, 63, 0], dtype='uint8')
+ y = np.ma.masked_equal(x, 0)
+
+ v = np.unique(y)
+ v2, i, c = np.unique(y, return_index=True, return_counts=True)
+
+ msg = 'Unique returned different results when asked for index'
+ assert_array_equal(v.data, v2.data, msg)
+ assert_array_equal(v.mask, v2.mask, msg)
+
def _run_axis_tests(self, dtype):
data = np.array([[0, 1, 0, 0],
[1, 0, 0, 0],
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 427c146a2..219bea2f6 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -5230,7 +5230,8 @@ class MaskedArray(ndarray):
out.__setmask__(self._mask)
return out
- def argsort(self, axis=None, kind='quicksort', order=None, fill_value=None):
+ def argsort(self, axis=None, kind='quicksort', order=None,
+ endwith=True, fill_value=None):
"""
Return an ndarray of indices that sort the array along the
specified axis. Masked values are filled beforehand to
@@ -5241,15 +5242,21 @@ class MaskedArray(ndarray):
axis : int, optional
Axis along which to sort. The default is -1 (last axis).
If None, the flattened array is used.
- fill_value : var, optional
- Value used to fill the array before sorting.
- The default is the `fill_value` attribute of the input array.
kind : {'quicksort', 'mergesort', 'heapsort'}, optional
Sorting algorithm.
order : list, optional
When `a` is an array with fields defined, this argument specifies
which fields to compare first, second, etc. Not all fields need be
specified.
+ endwith : {True, False}, optional
+ Whether missing values (if any) should be treated as the largest values
+ (True) or the smallest values (False)
+ When the array contains unmasked values at the same extremes of the
+ datatype, the ordering of these values and the masked values is
+ undefined.
+ fill_value : {var}, optional
+ Value used internally for the masked values.
+ If ``fill_value`` is not None, it supersedes ``endwith``.
Returns
-------
@@ -5259,7 +5266,7 @@ class MaskedArray(ndarray):
See Also
--------
- sort : Describes sorting algorithms used.
+ MaskedArray.sort : Describes sorting algorithms used.
lexsort : Indirect stable sort with multiple keys.
ndarray.sort : Inplace sort.
@@ -5278,10 +5285,19 @@ class MaskedArray(ndarray):
array([1, 0, 2])
"""
+
if fill_value is None:
- fill_value = default_fill_value(self)
- d = self.filled(fill_value).view(ndarray)
- return d.argsort(axis=axis, kind=kind, order=order)
+ if endwith:
+ # nan > inf
+ if np.issubdtype(self.dtype, np.floating):
+ fill_value = np.nan
+ else:
+ fill_value = minimum_fill_value(self)
+ else:
+ fill_value = maximum_fill_value(self)
+
+ filled = self.filled(fill_value)
+ return filled.argsort(axis=axis, kind=kind, order=order)
def argmin(self, axis=None, fill_value=None, out=None):
"""
@@ -5380,12 +5396,11 @@ class MaskedArray(ndarray):
to compare first, second, and so on. This list does not need to
include all of the fields.
endwith : {True, False}, optional
- Whether missing values (if any) should be forced in the upper indices
- (at the end of the array) (True) or lower indices (at the beginning).
- When the array contains unmasked values of the largest (or smallest if
- False) representable value of the datatype the ordering of these values
- and the masked values is undefined. To enforce the masked values are
- at the end (beginning) in this case one must sort the mask.
+ Whether missing values (if any) should be treated as the largest values
+ (True) or the smallest values (False)
+ When the array contains unmasked values at the same extremes of the
+ datatype, the ordering of these values and the masked values is
+ undefined.
fill_value : {var}, optional
Value used internally for the masked values.
If ``fill_value`` is not None, it supersedes ``endwith``.
@@ -6503,13 +6518,13 @@ def power(a, b, third=None):
argmin = _frommethod('argmin')
argmax = _frommethod('argmax')
-def argsort(a, axis=None, kind='quicksort', order=None, fill_value=None):
+def argsort(a, axis=None, kind='quicksort', order=None, endwith=True, fill_value=None):
"Function version of the eponymous method."
a = np.asanyarray(a)
if isinstance(a, MaskedArray):
return a.argsort(axis=axis, kind=kind, order=order,
- fill_value=fill_value)
+ endwith=endwith, fill_value=fill_value)
else:
return a.argsort(axis=axis, kind=kind, order=order)
argsort.__doc__ = MaskedArray.argsort.__doc__
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 93898c4d0..a65cac8c8 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -3031,6 +3031,20 @@ class TestMaskedArrayMethods(TestCase):
assert_equal(sortedx._data, [1, 2, -2, -1, 0])
assert_equal(sortedx._mask, [1, 1, 0, 0, 0])
+ def test_argsort_matches_sort(self):
+ x = array([1, 4, 2, 3], mask=[0, 1, 0, 0], dtype=np.uint8)
+
+ for kwargs in [dict(),
+ dict(endwith=True),
+ dict(endwith=False),
+ dict(fill_value=2),
+ dict(fill_value=2, endwith=True),
+ dict(fill_value=2, endwith=False)]:
+ sortedx = sort(x, **kwargs)
+ argsortedx = x[argsort(x, **kwargs)]
+ assert_equal(sortedx._data, argsortedx._data)
+ assert_equal(sortedx._mask, argsortedx._mask)
+
def test_sort_2d(self):
# Check sort of 2D array.
# 2D array w/o mask