diff options
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/arraysetops.py | 26 | ||||
-rw-r--r-- | numpy/lib/tests/test_arraysetops.py | 19 |
2 files changed, 25 insertions, 20 deletions
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py index 2309f7e42..d65316598 100644 --- a/numpy/lib/arraysetops.py +++ b/numpy/lib/arraysetops.py @@ -94,8 +94,7 @@ def ediff1d(ary, to_end=None, to_begin=None): # force a 1d array ary = np.asanyarray(ary).ravel() - # enforce propagation of the dtype of input - # ary to returned result + # enforce that the dtype of `ary` is used for the output dtype_req = ary.dtype # fast track default case @@ -105,22 +104,23 @@ def ediff1d(ary, to_end=None, to_begin=None): if to_begin is None: l_begin = 0 else: - _to_begin = np.asanyarray(to_begin, dtype=dtype_req) - if not np.all(_to_begin == to_begin): - raise ValueError("cannot convert 'to_begin' to array with dtype " - "'%r' as required for input ary" % dtype_req) - to_begin = _to_begin.ravel() + to_begin = np.asanyarray(to_begin) + if not np.can_cast(to_begin, dtype_req, casting="same_kind"): + raise TypeError("dtype of `to_end` must be compatible " + "with input `ary` under the `same_kind` rule.") + + to_begin = to_begin.ravel() l_begin = len(to_begin) if to_end is None: l_end = 0 else: - _to_end = np.asanyarray(to_end, dtype=dtype_req) - # check that casting has not overflowed - if not np.all(_to_end == to_end): - raise ValueError("cannot convert 'to_end' to array with dtype " - "'%r' as required for input ary" % dtype_req) - to_end = _to_end.ravel() + to_end = np.asanyarray(to_end) + if not np.can_cast(to_end, dtype_req, casting="same_kind"): + raise TypeError("dtype of `to_end` must be compatible " + "with input `ary` under the `same_kind` rule.") + + to_end = to_end.ravel() l_end = len(to_end) # do the calculation in place and copy to_begin and to_end diff --git a/numpy/lib/tests/test_arraysetops.py b/numpy/lib/tests/test_arraysetops.py index fd21a7f76..1d38d8d27 100644 --- a/numpy/lib/tests/test_arraysetops.py +++ b/numpy/lib/tests/test_arraysetops.py @@ -135,9 +135,9 @@ class TestSetOps(object): None, np.nan), # should fail because attempting - # to downcast to smaller int type: - (np.array([1, 2, 3], dtype=np.int16), - np.array([5, 1<<20, 2], dtype=np.int32), + # to downcast to int type: + (np.array([1, 2, 3], dtype=np.int64), + np.array([5, 7, 2], dtype=np.float32), None), # should fail because attempting to cast # two special floating point values @@ -152,8 +152,8 @@ class TestSetOps(object): # specifically, raise an appropriate # Exception when attempting to append or # prepend with an incompatible type - msg = 'cannot convert' - with assert_raises_regex(ValueError, msg): + msg = 'must be compatible' + with assert_raises_regex(TypeError, msg): ediff1d(ary=ary, to_end=append, to_begin=prepend) @@ -163,9 +163,13 @@ class TestSetOps(object): "append," "expected", [ (np.array([1, 2, 3], dtype=np.int16), - 0, + 2**16, # will be cast to int16 under same kind rule. + 2**16 + 4, + np.array([0, 1, 1, 4], dtype=np.int16)), + (np.array([1, 2, 3], dtype=np.float32), + np.array([5], dtype=np.float64), None, - np.array([0, 1, 1], dtype=np.int16)), + np.array([5, 1, 1], dtype=np.float32)), (np.array([1, 2, 3], dtype=np.int32), 0, 0, @@ -187,6 +191,7 @@ class TestSetOps(object): to_end=append, to_begin=prepend) assert_equal(actual, expected) + assert actual.dtype == expected.dtype def test_isin(self): |