diff options
author | Jörg Döpfert <jdoepfert@users.noreply.github.com> | 2017-11-18 19:21:23 +0100 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-11-18 10:21:23 -0800 |
commit | 2d140f11857fc1353d6b2dcbb801e693128fd09a (patch) | |
tree | e2b7b9edcbcc07de412e357198593fec333635aa /numpy/lib | |
parent | c41961313e7ddab6563bd82a51ee512a5c44e785 (diff) | |
download | numpy-2d140f11857fc1353d6b2dcbb801e693128fd09a.tar.gz |
ENH: Make `np.in1d()` work for unorderable object arrays (#9999)
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/arraysetops.py | 10 | ||||
-rw-r--r-- | numpy/lib/tests/test_arraysetops.py | 31 |
2 files changed, 39 insertions, 2 deletions
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py index ededb9dd0..f3301af92 100644 --- a/numpy/lib/arraysetops.py +++ b/numpy/lib/arraysetops.py @@ -448,8 +448,14 @@ def in1d(ar1, ar2, assume_unique=False, invert=False): ar1 = np.asarray(ar1).ravel() ar2 = np.asarray(ar2).ravel() - # This code is significantly faster when the condition is satisfied. - if len(ar2) < 10 * len(ar1) ** 0.145: + # Check if one of the arrays may contain arbitrary objects + contains_object = ar1.dtype.hasobject or ar2.dtype.hasobject + + # This code is run when + # a) the first condition is true, making the code significantly faster + # b) the second condition is true (i.e. `ar1` or `ar2` may contain + # arbitrary objects), since then sorting is not guaranteed to work + if len(ar2) < 10 * len(ar1) ** 0.145 or contains_object: if invert: mask = np.ones(len(ar1), dtype=bool) for a in ar2: diff --git a/numpy/lib/tests/test_arraysetops.py b/numpy/lib/tests/test_arraysetops.py index b8ced41e8..b4787838d 100644 --- a/numpy/lib/tests/test_arraysetops.py +++ b/numpy/lib/tests/test_arraysetops.py @@ -208,6 +208,37 @@ class TestSetOps(object): assert_array_equal(in1d(a, long_b, assume_unique=True), ec) assert_array_equal(in1d(a, long_b, assume_unique=False), ec) + def test_in1d_first_array_is_object(self): + ar1 = [None] + ar2 = np.array([1]*10) + expected = np.array([False]) + result = np.in1d(ar1, ar2) + assert_array_equal(result, expected) + + def test_in1d_second_array_is_object(self): + ar1 = 1 + ar2 = np.array([None]*10) + expected = np.array([False]) + result = np.in1d(ar1, ar2) + assert_array_equal(result, expected) + + def test_in1d_both_arrays_are_object(self): + ar1 = [None] + ar2 = np.array([None]*10) + expected = np.array([True]) + result = np.in1d(ar1, ar2) + assert_array_equal(result, expected) + + def test_in1d_both_arrays_have_structured_dtype(self): + # Test arrays of a structured data type containing an integer field + # and a field of dtype `object` allowing for arbitrary Python objects + dt = np.dtype([('field1', int), ('field2', object)]) + ar1 = np.array([(1, None)], dtype=dt) + ar2 = np.array([(1, None)]*10, dtype=dt) + expected = np.array([True]) + result = np.in1d(ar1, ar2) + assert_array_equal(result, expected) + def test_union1d(self): a = np.array([5, 4, 7, 1, 2]) b = np.array([2, 4, 3, 3, 2, 1, 5]) |