diff options
author | Roman Yurchak <rth.yurchak@pm.me> | 2019-01-09 13:50:55 +0200 |
---|---|---|
committer | Roman Yurchak <rth.yurchak@pm.me> | 2019-01-09 13:54:30 +0200 |
commit | 1c55693772871dd1473e0013a44d56ed6b2fdb7f (patch) | |
tree | 25f6ed918439b7a7320f716ee69015ddb50ff976 /numpy/fft/tests/test_pocketfft.py | |
parent | ad0e902717d1245d17856d47d7f16bc7817da866 (diff) | |
download | numpy-1c55693772871dd1473e0013a44d56ed6b2fdb7f.tar.gz |
TST Check FFT for C/Fortran ordered and non contigous arrays
Diffstat (limited to 'numpy/fft/tests/test_pocketfft.py')
-rw-r--r-- | numpy/fft/tests/test_pocketfft.py | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/numpy/fft/tests/test_pocketfft.py b/numpy/fft/tests/test_pocketfft.py index 1029294a1..6092c976f 100644 --- a/numpy/fft/tests/test_pocketfft.py +++ b/numpy/fft/tests/test_pocketfft.py @@ -166,6 +166,44 @@ class TestFFT1D(object): assert_array_almost_equal(np.fft.ifft(np.fft.fft(x)), x) assert_array_almost_equal(np.fft.irfft(np.fft.rfft(x)), x) + +@pytest.mark.parametrize( + "dtype", + [np.float32, np.float64, np.complex64, np.complex128]) +@pytest.mark.parametrize("order", ["F", 'non-contiguous']) +@pytest.mark.parametrize( + "fft", + [np.fft.fft, np.fft.fft2, np.fft.fftn, + np.fft.ifft, np.fft.ifft2, np.fft.ifftn]) +def test_fft_with_order(dtype, order, fft): + # Check that FFT/IFFT produces identical results for C, Fotran and + # non contiguous arrays + rng = np.random.RandomState(42) + X = rng.rand(8, 7, 13).astype(dtype, copy=False) + if order == 'F': + Y = np.asfortranarray(X) + else: + # Make a non contiguous array + Y = X[::-1] + X = np.ascontiguousarray(X[::-1]) + + if fft.__name__.endswith('fft'): + for axis in range(3): + X_res = fft(X, axis=axis) + Y_res = fft(Y, axis=axis) + assert_array_almost_equal(X_res, Y_res) + elif fft.__name__.endswith(('fft2', 'fftn')): + axes = [(0, 1), (1, 2), (0, 2)] + if fft.__name__.endswith('fftn'): + axes.extend([(0,), (1,), (2,), None]) + for ax in axes: + X_res = fft(X, axes=ax) + Y_res = fft(Y, axes=ax) + assert_array_almost_equal(X_res, Y_res) + else: + raise ValueError + + class TestFFTThreadSafe(object): threads = 16 input_shape = (800, 200) |