summaryrefslogtreecommitdiff
path: root/numpy/lib/arraysetops.py
diff options
context:
space:
mode:
authorTyler Reddy <tyler.je.reddy@gmail.com>2019-01-10 11:16:55 -0800
committerGitHub <noreply@github.com>2019-01-10 11:16:55 -0800
commitce779dc607baa0630095dc456f5b9ad001095cca (patch)
treec7d55c1489f32c1787de5a2ed273b0989b3066a8 /numpy/lib/arraysetops.py
parent23d53d90f7918f5c85c8e9423d06daf4cabfa437 (diff)
parentc088383cb290ca064d456e89d79177a0e234cb8d (diff)
downloadnumpy-ce779dc607baa0630095dc456f5b9ad001095cca.tar.gz
Merge pull request #12713 from mattip/gh-12711
BUG: loosen kwargs requirements in ediff1d
Diffstat (limited to 'numpy/lib/arraysetops.py')
-rw-r--r--numpy/lib/arraysetops.py26
1 files changed, 12 insertions, 14 deletions
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py
index 558150e48..b53d8c03f 100644
--- a/numpy/lib/arraysetops.py
+++ b/numpy/lib/arraysetops.py
@@ -94,8 +94,7 @@ def ediff1d(ary, to_end=None, to_begin=None):
# force a 1d array
ary = np.asanyarray(ary).ravel()
- # we have unit tests enforcing
- # propagation of the dtype of input
+ # enforce propagation of the dtype of input
# ary to returned result
dtype_req = ary.dtype
@@ -106,23 +105,22 @@ def ediff1d(ary, to_end=None, to_begin=None):
if to_begin is None:
l_begin = 0
else:
- to_begin = np.asanyarray(to_begin)
- if not np.can_cast(to_begin, dtype_req):
- raise TypeError("dtype of to_begin must be compatible "
- "with input ary")
-
- to_begin = to_begin.ravel()
+ _to_begin = np.asanyarray(to_begin, dtype=dtype_req)
+ if not np.all(_to_begin == to_begin):
+ raise ValueError("cannot convert 'to_begin' to array with dtype "
+ "'%r' as required for input ary" % dtype_req)
+ to_begin = _to_begin.ravel()
l_begin = len(to_begin)
if to_end is None:
l_end = 0
else:
- to_end = np.asanyarray(to_end)
- if not np.can_cast(to_end, dtype_req):
- raise TypeError("dtype of to_end must be compatible "
- "with input ary")
-
- to_end = to_end.ravel()
+ _to_end = np.asanyarray(to_end, dtype=dtype_req)
+ # check that casting has not overflowed
+ if not np.all(_to_end == to_end):
+ raise ValueError("cannot convert 'to_end' to array with dtype "
+ "'%r' as required for input ary" % dtype_req)
+ to_end = _to_end.ravel()
l_end = len(to_end)
# do the calculation in place and copy to_begin and to_end