summaryrefslogtreecommitdiff
path: root/numpy/lib/function_base.py
diff options
context:
space:
mode:
authorSebastian Berg <sebastianb@nvidia.com>2023-05-03 12:36:43 +0200
committerSebastian Berg <sebastianb@nvidia.com>2023-05-03 12:44:59 +0200
commit89486a335a478aac46be91b92e69267f9409b1be (patch)
tree495565abe628ea0f6fd2251dcf49be89cc79ac37 /numpy/lib/function_base.py
parentc37a577c9df74e29c97a7bb010de0b37f83870bb (diff)
downloadnumpy-89486a335a478aac46be91b92e69267f9409b1be.tar.gz
MAINT: Reorganize the way windowing functions ensure float64 result
This roughly changes things so that we ensure a float64 working values up-front. There is a tiny chance of precision changes if the input was not float64 or error changes on bad input. I don't think this should matter in practice, precision changes (as far as I can tell) should happen rather the other way around. Since float64 has 53bits mantissa, I think the arange should give the correct result reliably for any sensible inputs. There is an argument to be made that the windowing functions could return float32 for float32 input, but I somewhat think this is OK and users can be expected to just cast manually after the fact. The result type is tested, but this ensures the tests pass also when enabling weak promotion.
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r--numpy/lib/function_base.py48
1 files changed, 38 insertions, 10 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 22371a038..02e141920 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -2999,10 +2999,15 @@ def blackman(M):
>>> plt.show()
"""
+ # Ensures at least float64 via 0.0. M should be an integer, but conversion
+ # to double is safe for a range.
+ values = np.array([0.0, M])
+ M = values[1]
+
if M < 1:
- return array([], dtype=np.result_type(M, 0.0))
+ return array([], dtype=values.dtype)
if M == 1:
- return ones(1, dtype=np.result_type(M, 0.0))
+ return ones(1, dtype=values.dtype)
n = arange(1-M, M, 2)
return 0.42 + 0.5*cos(pi*n/(M-1)) + 0.08*cos(2.0*pi*n/(M-1))
@@ -3107,10 +3112,15 @@ def bartlett(M):
>>> plt.show()
"""
+ # Ensures at least float64 via 0.0. M should be an integer, but conversion
+ # to double is safe for a range.
+ values = np.array([0.0, M])
+ M = values[1]
+
if M < 1:
- return array([], dtype=np.result_type(M, 0.0))
+ return array([], dtype=values.dtype)
if M == 1:
- return ones(1, dtype=np.result_type(M, 0.0))
+ return ones(1, dtype=values.dtype)
n = arange(1-M, M, 2)
return where(less_equal(n, 0), 1 + n/(M-1), 1 - n/(M-1))
@@ -3211,10 +3221,15 @@ def hanning(M):
>>> plt.show()
"""
+ # Ensures at least float64 via 0.0. M should be an integer, but conversion
+ # to double is safe for a range.
+ values = np.array([0.0, M])
+ M = values[1]
+
if M < 1:
- return array([], dtype=np.result_type(M, 0.0))
+ return array([], dtype=values.dtype)
if M == 1:
- return ones(1, dtype=np.result_type(M, 0.0))
+ return ones(1, dtype=values.dtype)
n = arange(1-M, M, 2)
return 0.5 + 0.5*cos(pi*n/(M-1))
@@ -3311,10 +3326,15 @@ def hamming(M):
>>> plt.show()
"""
+ # Ensures at least float64 via 0.0. M should be an integer, but conversion
+ # to double is safe for a range.
+ values = np.array([0.0, M])
+ M = values[1]
+
if M < 1:
- return array([], dtype=np.result_type(M, 0.0))
+ return array([], dtype=values.dtype)
if M == 1:
- return ones(1, dtype=np.result_type(M, 0.0))
+ return ones(1, dtype=values.dtype)
n = arange(1-M, M, 2)
return 0.54 + 0.46*cos(pi*n/(M-1))
@@ -3590,11 +3610,19 @@ def kaiser(M, beta):
>>> plt.show()
"""
+ # Ensures at least float64 via 0.0. M should be an integer, but conversion
+ # to double is safe for a range. (Simplified result_type with 0.0
+ # strongly typed. result-type is not/less order sensitive, but that mainly
+ # matters for integers anyway.)
+ values = np.array([0.0, M, beta])
+ M = values[1]
+ beta = values[2]
+
if M == 1:
- return np.ones(1, dtype=np.result_type(M, 0.0))
+ return np.ones(1, dtype=values.dtype)
n = arange(0, M)
alpha = (M-1)/2.0
- return i0(beta * sqrt(1-((n-alpha)/alpha)**2.0))/i0(float(beta))
+ return i0(beta * sqrt(1-((n-alpha)/alpha)**2.0))/i0(beta)
def _sinc_dispatcher(x):