diff options
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/function_base.py | 31 | ||||
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 10 |
2 files changed, 30 insertions, 11 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 44c31b8ec..374901ccd 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -491,17 +491,28 @@ def trim_zeros(filt, trim='fb'): else: last = last - 1 return filt[first:last] -def unique(inseq): - """Return unique items (in sorted order) from a 1-dimensional sequence. - """ - # Dictionary setting is quite fast. - set = {} - for item in inseq: - set[item] = None - val = asarray(set.keys()) - val.sort() - return val +import sys +if sys.hexversion < 0x2040000: + from sets import Set as set + +def unique(x): + """Return sorted unique items from an array or sequence. + + Example: + >>> unique([5,2,4,0,4,4,2,2,1]) + array([0,1,2,4,5]) + """ + try: + tmp = x.flatten() + tmp.sort() + idx = concatenate(([True],tmp[1:]!=tmp[:-1])) + return tmp[idx] + except AttributeError: + items = list(set(x)) + items.sort() + return asarray(items) + def extract(condition, arr): """Return the elements of ravel(arr) where ravel(condition) is True (in 1D). diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index b359d5732..4642725fc 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -353,7 +353,15 @@ class test_histogram(NumpyTestCase): (a,b)=histogram(linspace(0,10,100)) assert(all(a==10)) - +class test_unique(NumpyTestCase): + def check_simple(self): + x = array([4,3,2,1,1,2,3,4, 0]) + assert(all(unique(x) == [0,1,2,3,4])) + assert(unique(array([1,1,1,1,1])) == array([1])) + x = ['widget', 'ham', 'foo', 'bar', 'foo', 'ham'] + assert(all(unique(x) == ['bar', 'foo', 'ham', 'widget'])) + x = array([5+6j, 1+1j, 1+10j, 10, 5+6j]) + assert(all(unique(x) == [1+1j, 1+10j, 5+6j, 10])) def compare_results(res,desired): for i in range(len(desired)): |