diff options
author | jaimefrio <jaime.frio@gmail.com> | 2014-01-09 10:23:07 -0800 |
---|---|---|
committer | jaimefrio <jaime.frio@gmail.com> | 2014-04-05 21:06:56 -0700 |
commit | eae3d1a73f2f901da5956e3bcdaf2c44bfdd1ed3 (patch) | |
tree | eacd0e21014656e4cf36fe617830697474d9166f /numpy | |
parent | 52d5d109f9dedf4f006b930abef9ff9c54ec1542 (diff) | |
download | numpy-eae3d1a73f2f901da5956e3bcdaf2c44bfdd1ed3.tar.gz |
ENH: add a 'return_counts=' keyword argument to `np.unique`
This PR adds a new keyword argument to `np.unique` that returns the
number of times each unique item comes up in the array. This allows
replacing a typical numpy construct:

    unq, _ = np.unique(a, return_inverse=True)
    unq_counts = np.bincount(_)

with a single line of code:

    unq, unq_counts = np.unique(a, return_counts=True)

As a plus, it runs faster, because it does not need the extra
operations required to produce `unique_inverse`.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/lib/arraysetops.py | 68 | ||||
-rw-r--r-- | numpy/lib/tests/test_arraysetops.py | 50 |
2 files changed, 81 insertions, 37 deletions
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py index 691550579..0755fffd1 100644 --- a/numpy/lib/arraysetops.py +++ b/numpy/lib/arraysetops.py @@ -90,7 +90,7 @@ def ediff1d(ary, to_end=None, to_begin=None): return ed -def unique(ar, return_index=False, return_inverse=False): +def unique(ar, return_index=False, return_inverse=False, return_counts=False): """ Find the unique elements of an array. @@ -109,6 +109,10 @@ def unique(ar, return_index=False, return_inverse=False): return_inverse : bool, optional If True, also return the indices of the unique array that can be used to reconstruct `ar`. + return_counts : bool, optional + .. versionadded:: 1.9.0 + If True, also return the number of times each unique value comes up + in `ar`. Returns ------- @@ -120,6 +124,10 @@ def unique(ar, return_index=False, return_inverse=False): unique_inverse : ndarray, optional The indices to reconstruct the (flattened) original array from the unique array. Only provided if `return_inverse` is True. + unique_counts : ndarray, optional + .. versionadded:: 1.9.0 + The number of times each of the unique values comes up in the + original array. Only provided if `return_counts` is True. 
See Also -------- @@ -162,41 +170,49 @@ def unique(ar, return_index=False, return_inverse=False): try: ar = ar.flatten() except AttributeError: - if not return_inverse and not return_index: - return np.sort(list(set(ar))) + if not return_inverse and not return_index and not return_counts: + return np.sort(list((set(ar)))) else: ar = np.asanyarray(ar).flatten() + optional_indices = return_index or return_inverse + optional_returns = optional_indices or return_counts + if ar.size == 0: - if return_inverse and return_index: - return ar, np.empty(0, np.bool), np.empty(0, np.bool) - elif return_inverse or return_index: - return ar, np.empty(0, np.bool) + if not optional_returns: + ret = ar else: - return ar + ret = (ar,) + if return_index: + ret += (np.empty(0, np.bool),) + if return_inverse: + ret += (np.empty(0, np.bool),) + if return_counts: + ret += (np.empty(0, np.intp),) + return ret + + if optional_indices: + perm = ar.argsort(kind='mergesort' if return_index else 'quicksort') + aux = ar[perm] + else: + ar.sort() + aux = ar + flag = np.concatenate(([True], aux[1:] != aux[:-1])) - if return_inverse or return_index: + if not optional_returns: + ret = aux[flag] + else: + ret = (aux[flag],) if return_index: - perm = ar.argsort(kind='mergesort') - else: - perm = ar.argsort() - aux = ar[perm] - flag = np.concatenate(([True], aux[1:] != aux[:-1])) + ret += (perm[flag],) if return_inverse: iflag = np.cumsum(flag) - 1 iperm = perm.argsort() - if return_index: - return aux[flag], perm[flag], iflag[iperm] - else: - return aux[flag], iflag[iperm] - else: - return aux[flag], perm[flag] - - else: - ar.sort() - flag = np.concatenate(([True], ar[1:] != ar[:-1])) - return ar[flag] - + ret += (np.take(iflag, iperm),) + if return_counts: + idx = np.concatenate(np.nonzero(flag) + ([ar.size],)) + ret += (np.diff(idx),) + return ret def intersect1d(ar1, ar2, assume_unique=False): """ diff --git a/numpy/lib/tests/test_arraysetops.py b/numpy/lib/tests/test_arraysetops.py index 
e44ccd12b..41d77c07f 100644 --- a/numpy/lib/tests/test_arraysetops.py +++ b/numpy/lib/tests/test_arraysetops.py @@ -14,31 +14,59 @@ class TestSetOps(TestCase): def test_unique(self): - def check_all(a, b, i1, i2, dt): - msg = "check values failed for type '%s'" % dt + def check_all(a, b, i1, i2, c, dt): + base_msg = 'check {0} failed for type {1}' + + msg = base_msg.format('values', dt) v = unique(a) assert_array_equal(v, b, msg) - msg = "check indexes failed for type '%s'" % dt - v, j = unique(a, 1, 0) + msg = base_msg.format('return_index', dt) + v, j = unique(a, 1, 0, 0) assert_array_equal(v, b, msg) assert_array_equal(j, i1, msg) - msg = "check reverse indexes failed for type '%s'" % dt - v, j = unique(a, 0, 1) + msg = base_msg.format('return_inverse', dt) + v, j = unique(a, 0, 1, 0) assert_array_equal(v, b, msg) assert_array_equal(j, i2, msg) - msg = "check with all indexes failed for type '%s'" % dt - v, j1, j2 = unique(a, 1, 1) + msg = base_msg.format('return_counts', dt) + v, j = unique(a, 0, 0, 1) + assert_array_equal(v, b, msg) + assert_array_equal(j, c, msg) + + msg = base_msg.format('return_index and return_inverse', dt) + v, j1, j2 = unique(a, 1, 1, 0) + assert_array_equal(v, b, msg) + assert_array_equal(j1, i1, msg) + assert_array_equal(j2, i2, msg) + + msg = base_msg.format('return_index and return_counts', dt) + v, j1, j2 = unique(a, 1, 0, 1) + assert_array_equal(v, b, msg) + assert_array_equal(j1, i1, msg) + assert_array_equal(j2, c, msg) + + msg = base_msg.format('return_inverse and return_counts', dt) + v, j1, j2 = unique(a, 0, 1, 1) + assert_array_equal(v, b, msg) + assert_array_equal(j1, i2, msg) + assert_array_equal(j2, c, msg) + + msg = base_msg.format(('return_index, return_inverse ' + 'and return_counts'), dt) + v, j1, j2, j3 = unique(a, 1, 1, 1) assert_array_equal(v, b, msg) assert_array_equal(j1, i1, msg) assert_array_equal(j2, i2, msg) + assert_array_equal(j3, c, msg) a = [5, 7, 1, 2, 1, 5, 7]*10 b = [1, 2, 5, 7] i1 = [2, 3, 0, 1] i2 = 
[2, 3, 0, 1, 0, 2, 3]*10 + c = np.multiply([2, 1, 2, 2], 10) # test for numeric arrays types = [] @@ -49,7 +77,7 @@ class TestSetOps(TestCase): for dt in types: aa = np.array(a, dt) bb = np.array(b, dt) - check_all(aa, bb, i1, i2, dt) + check_all(aa, bb, i1, i2, c, dt) # test for object arrays dt = 'O' @@ -57,13 +85,13 @@ class TestSetOps(TestCase): aa[:] = a bb = np.empty(len(b), dt) bb[:] = b - check_all(aa, bb, i1, i2, dt) + check_all(aa, bb, i1, i2, c, dt) # test for structured arrays dt = [('', 'i'), ('', 'i')] aa = np.array(list(zip(a, a)), dt) bb = np.array(list(zip(b, b)), dt) - check_all(aa, bb, i1, i2, dt) + check_all(aa, bb, i1, i2, c, dt) # test for ticket #2799 aa = [1.+0.j, 1- 1.j, 1] |