summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/lib/arraysetops.py10
-rw-r--r--numpy/lib/tests/test_arraysetops.py4
-rw-r--r--numpy/ma/extras.py10
-rw-r--r--numpy/ma/tests/test_extras.py2
4 files changed, 13 insertions, 13 deletions
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py
index 7776d7e76..3dd97aecf 100644
--- a/numpy/lib/arraysetops.py
+++ b/numpy/lib/arraysetops.py
@@ -470,11 +470,9 @@ def setdiff1d(ar1, ar2, assume_unique=False):
array([1, 2])
"""
- if not assume_unique:
+ if assume_unique:
+ ar1 = np.asarray(ar1).ravel()
+ else:
ar1 = unique(ar1)
ar2 = unique(ar2)
- aux = in1d(ar1, ar2, assume_unique=True)
- if aux.size == 0:
- return aux
- else:
- return np.asarray(ar1)[aux == 0]
+ return ar1[in1d(ar1, ar2, assume_unique=True, invert=True)]
diff --git a/numpy/lib/tests/test_arraysetops.py b/numpy/lib/tests/test_arraysetops.py
index 39196f4bc..852183ffe 100644
--- a/numpy/lib/tests/test_arraysetops.py
+++ b/numpy/lib/tests/test_arraysetops.py
@@ -5,7 +5,7 @@ from __future__ import division, absolute_import, print_function
import numpy as np
from numpy.testing import (
- run_module_suite, TestCase, assert_array_equal
+ run_module_suite, TestCase, assert_array_equal, assert_equal
)
from numpy.lib.arraysetops import (
ediff1d, intersect1d, setxor1d, union1d, setdiff1d, unique, in1d
@@ -286,6 +286,8 @@ class TestSetOps(TestCase):
assert_array_equal(c, ec)
assert_array_equal([], setdiff1d([], []))
+ a = np.array((), np.uint32)
+ assert_equal(setdiff1d(a, []).dtype, np.uint32)
def test_setdiff1d_char_array(self):
a = np.array(['a', 'b', 'c'])
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index 51064e831..64a9844cf 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -1278,14 +1278,12 @@ def setdiff1d(ar1, ar2, assume_unique=False):
fill_value = 999999)
"""
- if not assume_unique:
+ if assume_unique:
+ ar1 = ma.asarray(ar1).ravel()
+ else:
ar1 = unique(ar1)
ar2 = unique(ar2)
- aux = in1d(ar1, ar2, assume_unique=True)
- if aux.size == 0:
- return aux
- else:
- return ma.asarray(ar1)[aux == 0]
+ return ar1[in1d(ar1, ar2, assume_unique=True, invert=True)]
#####--------------------------------------------------------------------------
diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py
index b6749ae9e..3c7b95c9e 100644
--- a/numpy/ma/tests/test_extras.py
+++ b/numpy/ma/tests/test_extras.py
@@ -1109,6 +1109,8 @@ class TestArraySetOps(TestCase):
a = arange(10)
b = arange(8)
assert_equal(setdiff1d(a, b), array([8, 9]))
+ a = array([], np.uint32, mask=[])
+ assert_equal(setdiff1d(a, []).dtype, np.uint32)
def test_setdiff1d_char_array(self):
# Test setdiff1d_charray