Skip to content

Commit

Permalink
tests: fix pyfftw axes input for fft
Browse files Browse the repository at this point in the history
  • Loading branch information
Eoghan O'Connell committed Jan 17, 2025
1 parent a484b98 commit 3ec047b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
4 changes: 2 additions & 2 deletions qpretrieve/fourier/ff_pyfftw.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _init_fft(self, data):
in_arr = pyfftw.empty_aligned(data.shape, dtype='complex128')
out_arr = pyfftw.empty_aligned(data.shape, dtype='complex128')
fft_obj = pyfftw.FFTW(in_arr, out_arr,
axes=(0, 1),
axes=(-2, -1),
threads=mp.cpu_count())
in_arr[:] = data
fft_obj()
Expand All @@ -37,7 +37,7 @@ def _ifft(self, data):
"""Perform inverse Fourier transform"""
in_arr = pyfftw.empty_aligned(data.shape, dtype='complex128')
ou_arr = pyfftw.empty_aligned(data.shape, dtype='complex128')
fft_obj = pyfftw.FFTW(in_arr, ou_arr, axes=(0, 1),
fft_obj = pyfftw.FFTW(in_arr, ou_arr, axes=(-2, -1),
direction="FFTW_BACKWARD",
)
in_arr[:] = data
Expand Down
5 changes: 3 additions & 2 deletions tests/test_oah_from_qpimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,9 @@ def test_get_field_compare_FFTFilters(hologram):
res2 = holo2.run_pipeline(**kwargs)
assert res2.shape == shape_expected

assert np.all(res1 == res2) # fails on linux, passes on windows?!
# assert np.allclose(res1, res2, rtol=1e-3)
# not exactly the same, but roughly equal to 1e-5
assert np.allclose(holo1.fft.fft_used, holo2.fft.fft_used)
assert np.allclose(res1, res2)


def test_field_format_consistency(hologram):
Expand Down

0 comments on commit 3ec047b

Please sign in to comment.