diff options
author | Sebastian Berg <sebastianb@nvidia.com> | 2023-05-03 12:36:43 +0200 |
---|---|---|
committer | Sebastian Berg <sebastianb@nvidia.com> | 2023-05-03 12:44:59 +0200 |
commit | 89486a335a478aac46be91b92e69267f9409b1be (patch) | |
tree | 495565abe628ea0f6fd2251dcf49be89cc79ac37 /numpy/lib/function_base.py | |
parent | c37a577c9df74e29c97a7bb010de0b37f83870bb (diff) | |
download | numpy-89486a335a478aac46be91b92e69267f9409b1be.tar.gz |
MAINT: Reorganize the way windowing functions ensure float64 result
This roughly changes things so that we ensure a float64 working
values up-front. There is a tiny chance of precision changes if the
input was not float64 or error changes on bad input.
I don't think this should matter in practice, precision changes
(as far as I can tell) should happen rather the other way around.
Since float64 has 53bits mantissa, I think the arange should give
the correct result reliably for any sensible inputs.
There is an argument to be made that the windowing functions could
return float32 for float32 input, but I somewhat think this is OK
and users can be expected to just cast manually after the fact.
The result type is tested, but this ensures the tests pass also
when enabling weak promotion.
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r-- | numpy/lib/function_base.py | 48 |
1 files changed, 38 insertions, 10 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 22371a038..02e141920 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -2999,10 +2999,15 @@ def blackman(M): >>> plt.show() """ + # Ensures at least float64 via 0.0. M should be an integer, but conversion + # to double is safe for a range. + values = np.array([0.0, M]) + M = values[1] + if M < 1: - return array([], dtype=np.result_type(M, 0.0)) + return array([], dtype=values.dtype) if M == 1: - return ones(1, dtype=np.result_type(M, 0.0)) + return ones(1, dtype=values.dtype) n = arange(1-M, M, 2) return 0.42 + 0.5*cos(pi*n/(M-1)) + 0.08*cos(2.0*pi*n/(M-1)) @@ -3107,10 +3112,15 @@ def bartlett(M): >>> plt.show() """ + # Ensures at least float64 via 0.0. M should be an integer, but conversion + # to double is safe for a range. + values = np.array([0.0, M]) + M = values[1] + if M < 1: - return array([], dtype=np.result_type(M, 0.0)) + return array([], dtype=values.dtype) if M == 1: - return ones(1, dtype=np.result_type(M, 0.0)) + return ones(1, dtype=values.dtype) n = arange(1-M, M, 2) return where(less_equal(n, 0), 1 + n/(M-1), 1 - n/(M-1)) @@ -3211,10 +3221,15 @@ def hanning(M): >>> plt.show() """ + # Ensures at least float64 via 0.0. M should be an integer, but conversion + # to double is safe for a range. + values = np.array([0.0, M]) + M = values[1] + if M < 1: - return array([], dtype=np.result_type(M, 0.0)) + return array([], dtype=values.dtype) if M == 1: - return ones(1, dtype=np.result_type(M, 0.0)) + return ones(1, dtype=values.dtype) n = arange(1-M, M, 2) return 0.5 + 0.5*cos(pi*n/(M-1)) @@ -3311,10 +3326,15 @@ def hamming(M): >>> plt.show() """ + # Ensures at least float64 via 0.0. M should be an integer, but conversion + # to double is safe for a range. + values = np.array([0.0, M]) + M = values[1] + if M < 1: - return array([], dtype=np.result_type(M, 0.0)) + return array([], dtype=values.dtype) if M == 1: - return ones(1, dtype=np.result_type(M, 0.0)) + return ones(1, dtype=values.dtype) n = arange(1-M, M, 2) return 0.54 + 0.46*cos(pi*n/(M-1)) @@ -3590,11 +3610,19 @@ def kaiser(M, beta): >>> plt.show() """ + # Ensures at least float64 via 0.0. M should be an integer, but conversion + # to double is safe for a range. (Simplified result_type with 0.0 + # strongly typed. result-type is not/less order sensitive, but that mainly + # matters for integers anyway.) + values = np.array([0.0, M, beta]) + M = values[1] + beta = values[2] + if M == 1: - return np.ones(1, dtype=np.result_type(M, 0.0)) + return np.ones(1, dtype=values.dtype) n = arange(0, M) alpha = (M-1)/2.0 - return i0(beta * sqrt(1-((n-alpha)/alpha)**2.0))/i0(float(beta)) + return i0(beta * sqrt(1-((n-alpha)/alpha)**2.0))/i0(beta) def _sinc_dispatcher(x): |