summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorpeterbell10 <peterbell10@live.co.uk>2020-08-12 07:36:07 +0100
committerGitHub <noreply@github.com>2020-08-12 09:36:07 +0300
commita2b9c2d5b6637b040917c0a2ef393dae83f09ee3 (patch)
treeafa02fc2d1bdf8115d4abf49aa2f99db02094841 /numpy/lib
parent7ec2e1bac72afcdc68cf8256879afbc4cb14a907 (diff)
downloadnumpy-a2b9c2d5b6637b040917c0a2ef393dae83f09ee3.tar.gz
API, BUG: Raise error on complex input to i0 (#17062)
* BUG, API: Raise error on complex input to np.i0
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/function_base.py23
-rw-r--r--numpy/lib/tests/test_function_base.py9
2 files changed, 17 insertions, 15 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index cd8862c94..b530f0aa1 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -3193,25 +3193,18 @@ def i0(x):
"""
Modified Bessel function of the first kind, order 0.
- Usually denoted :math:`I_0`. This function does broadcast, but will *not*
- "up-cast" int dtype arguments unless accompanied by at least one float or
- complex dtype argument (see Raises below).
+ Usually denoted :math:`I_0`.
Parameters
----------
- x : array_like, dtype float or complex
+ x : array_like of float
Argument of the Bessel function.
Returns
-------
- out : ndarray, shape = x.shape, dtype = x.dtype
+ out : ndarray, shape = x.shape, dtype = float
The modified Bessel function evaluated at each of the elements of `x`.
- Raises
- ------
- TypeError: array cannot be safely cast to required type
- If argument consists exclusively of int dtypes.
-
See Also
--------
scipy.special.i0, scipy.special.iv, scipy.special.ive
@@ -3241,12 +3234,16 @@ def i0(x):
Examples
--------
>>> np.i0(0.)
- array(1.0) # may vary
- >>> np.i0([0., 1. + 2j])
- array([ 1.00000000+0.j , 0.18785373+0.64616944j]) # may vary
+ array(1.0)
+ >>> np.i0([0, 1, 2, 3])
+ array([1. , 1.26606588, 2.2795853 , 4.88079259])
"""
x = np.asanyarray(x)
+ if x.dtype.kind == 'c':
+ raise TypeError("i0 not supported for complex values")
+ if x.dtype.kind != 'f':
+ x = x.astype(float)
x = np.abs(x)
return piecewise(x, [x <= 8.0], [_i0_1, _i0_2])
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 89c1a2d9b..635fe1432 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -2111,8 +2111,9 @@ class Test_I0:
i0(0.5),
np.array(1.0634833707413234))
- A = np.array([0.49842636, 0.6969809, 0.22011976, 0.0155549])
- expected = np.array([1.06307822, 1.12518299, 1.01214991, 1.00006049])
+ # need at least one test above 8, as the implementation is piecewise
+ A = np.array([0.49842636, 0.6969809, 0.22011976, 0.0155549, 10.0])
+ expected = np.array([1.06307822, 1.12518299, 1.01214991, 1.00006049, 2815.71662847])
assert_almost_equal(i0(A), expected)
assert_almost_equal(i0(-A), expected)
@@ -2149,6 +2150,10 @@ class Test_I0:
assert_array_equal(exp, res)
+ def test_complex(self):
+ a = np.array([0, 1 + 2j])
+ with pytest.raises(TypeError, match="i0 not supported for complex values"):
+ res = i0(a)
class TestKaiser: