summaryrefslogtreecommitdiff
path: root/numpy/fft/tests/test_pocketfft.py
diff options
context:
space:
mode:
authorRoman Yurchak <rth.yurchak@pm.me>2019-01-09 13:50:55 +0200
committerRoman Yurchak <rth.yurchak@pm.me>2019-01-09 13:54:30 +0200
commit1c55693772871dd1473e0013a44d56ed6b2fdb7f (patch)
tree25f6ed918439b7a7320f716ee69015ddb50ff976 /numpy/fft/tests/test_pocketfft.py
parentad0e902717d1245d17856d47d7f16bc7817da866 (diff)
downloadnumpy-1c55693772871dd1473e0013a44d56ed6b2fdb7f.tar.gz
TST Check FFT for C/Fortran ordered and non contigous arrays
Diffstat (limited to 'numpy/fft/tests/test_pocketfft.py')
-rw-r--r--numpy/fft/tests/test_pocketfft.py38
1 files changed, 38 insertions, 0 deletions
diff --git a/numpy/fft/tests/test_pocketfft.py b/numpy/fft/tests/test_pocketfft.py
index 1029294a1..6092c976f 100644
--- a/numpy/fft/tests/test_pocketfft.py
+++ b/numpy/fft/tests/test_pocketfft.py
@@ -166,6 +166,44 @@ class TestFFT1D(object):
assert_array_almost_equal(np.fft.ifft(np.fft.fft(x)), x)
assert_array_almost_equal(np.fft.irfft(np.fft.rfft(x)), x)
+
+@pytest.mark.parametrize(
+ "dtype",
+ [np.float32, np.float64, np.complex64, np.complex128])
+@pytest.mark.parametrize("order", ["F", 'non-contiguous'])
+@pytest.mark.parametrize(
+ "fft",
+ [np.fft.fft, np.fft.fft2, np.fft.fftn,
+ np.fft.ifft, np.fft.ifft2, np.fft.ifftn])
+def test_fft_with_order(dtype, order, fft):
+ # Check that FFT/IFFT produces identical results for C, Fotran and
+ # non contiguous arrays
+ rng = np.random.RandomState(42)
+ X = rng.rand(8, 7, 13).astype(dtype, copy=False)
+ if order == 'F':
+ Y = np.asfortranarray(X)
+ else:
+ # Make a non contiguous array
+ Y = X[::-1]
+ X = np.ascontiguousarray(X[::-1])
+
+ if fft.__name__.endswith('fft'):
+ for axis in range(3):
+ X_res = fft(X, axis=axis)
+ Y_res = fft(Y, axis=axis)
+ assert_array_almost_equal(X_res, Y_res)
+ elif fft.__name__.endswith(('fft2', 'fftn')):
+ axes = [(0, 1), (1, 2), (0, 2)]
+ if fft.__name__.endswith('fftn'):
+ axes.extend([(0,), (1,), (2,), None])
+ for ax in axes:
+ X_res = fft(X, axes=ax)
+ Y_res = fft(Y, axes=ax)
+ assert_array_almost_equal(X_res, Y_res)
+ else:
+ raise ValueError
+
+
class TestFFTThreadSafe(object):
threads = 16
input_shape = (800, 200)