diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/lib/arraysetops.py | 10 | ||||
-rw-r--r-- | numpy/lib/tests/test_arraysetops.py | 4 | ||||
-rw-r--r-- | numpy/ma/extras.py | 10 | ||||
-rw-r--r-- | numpy/ma/tests/test_extras.py | 2 |
4 files changed, 13 insertions, 13 deletions
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py index 7776d7e76..3dd97aecf 100644 --- a/numpy/lib/arraysetops.py +++ b/numpy/lib/arraysetops.py @@ -470,11 +470,9 @@ def setdiff1d(ar1, ar2, assume_unique=False): array([1, 2]) """ - if not assume_unique: + if assume_unique: + ar1 = np.asarray(ar1).ravel() + else: ar1 = unique(ar1) ar2 = unique(ar2) - aux = in1d(ar1, ar2, assume_unique=True) - if aux.size == 0: - return aux - else: - return np.asarray(ar1)[aux == 0] + return ar1[in1d(ar1, ar2, assume_unique=True, invert=True)] diff --git a/numpy/lib/tests/test_arraysetops.py b/numpy/lib/tests/test_arraysetops.py index 39196f4bc..852183ffe 100644 --- a/numpy/lib/tests/test_arraysetops.py +++ b/numpy/lib/tests/test_arraysetops.py @@ -5,7 +5,7 @@ from __future__ import division, absolute_import, print_function import numpy as np from numpy.testing import ( - run_module_suite, TestCase, assert_array_equal + run_module_suite, TestCase, assert_array_equal, assert_equal ) from numpy.lib.arraysetops import ( ediff1d, intersect1d, setxor1d, union1d, setdiff1d, unique, in1d @@ -286,6 +286,8 @@ class TestSetOps(TestCase): assert_array_equal(c, ec) assert_array_equal([], setdiff1d([], [])) + a = np.array((), np.uint32) + assert_equal(setdiff1d(a, []).dtype, np.uint32) def test_setdiff1d_char_array(self): a = np.array(['a', 'b', 'c']) diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index 51064e831..64a9844cf 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -1278,14 +1278,12 @@ def setdiff1d(ar1, ar2, assume_unique=False): fill_value = 999999) """ - if not assume_unique: + if assume_unique: + ar1 = ma.asarray(ar1).ravel() + else: ar1 = unique(ar1) ar2 = unique(ar2) - aux = in1d(ar1, ar2, assume_unique=True) - if aux.size == 0: - return aux - else: - return ma.asarray(ar1)[aux == 0] + return ar1[in1d(ar1, ar2, assume_unique=True, invert=True)] #####-------------------------------------------------------------------------- diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py index b6749ae9e..3c7b95c9e 100644 --- a/numpy/ma/tests/test_extras.py +++ b/numpy/ma/tests/test_extras.py @@ -1109,6 +1109,8 @@ class TestArraySetOps(TestCase): a = arange(10) b = arange(8) assert_equal(setdiff1d(a, b), array([8, 9])) + a = array([], np.uint32, mask=[]) + assert_equal(setdiff1d(a, []).dtype, np.uint32) def test_setdiff1d_char_array(self): # Test setdiff1d_charray |