diff --git a/qpretrieve/fourier/__init__.py b/qpretrieve/fourier/__init__.py index bfb4143..3ef04d6 100644 --- a/qpretrieve/fourier/__init__.py +++ b/qpretrieve/fourier/__init__.py @@ -1,6 +1,7 @@ # flake8: noqa: F401 import warnings +from .base import FFTFilter from .ff_numpy import FFTFilterNumpy try: @@ -11,7 +12,7 @@ PREFERRED_INTERFACE = None -def get_available_interfaces(): +def get_available_interfaces() -> list: """Return a list of available FFT algorithms""" interfaces = [ FFTFilterPyFFTW, @@ -24,7 +25,7 @@ def get_available_interfaces(): return interfaces_available -def get_best_interface(): +def get_best_interface() -> FFTFilter: """Return the fastest refocusing interface available If `pyfftw` is installed, :class:`.FFTFilterPyFFTW` diff --git a/qpretrieve/fourier/base.py b/qpretrieve/fourier/base.py index c8f6e91..0df9b77 100644 --- a/qpretrieve/fourier/base.py +++ b/qpretrieve/fourier/base.py @@ -41,7 +41,7 @@ def __init__(self, data: np.ndarray, subtract_mean: bool = True, padding: int = 2, - copy: bool = True): + copy: bool = True) -> None: r""" Parameters ---------- @@ -135,13 +135,13 @@ def __init__(self, self.fft_used = None @property - def shape(self): + def shape(self) -> tuple: """Shape of the Fourier transform data""" return self.fft_origin.shape @property @abstractmethod - def is_available(self): + def is_available(self) -> bool: """Whether this method is available given current hardware/software""" return True @@ -169,7 +169,7 @@ def _init_fft(self, data): def filter(self, filter_name: str, filter_size: float, freq_pos: (float, float), - scale_to_filter: bool | float = False): + scale_to_filter: bool | float = False) -> np.ndarray: """ Parameters ---------- diff --git a/qpretrieve/fourier/ff_numpy.py b/qpretrieve/fourier/ff_numpy.py index af3d960..0ac86c8 100644 --- a/qpretrieve/fourier/ff_numpy.py +++ b/qpretrieve/fourier/ff_numpy.py @@ -10,7 +10,7 @@ class FFTFilterNumpy(FFTFilter): # always available, because numpy is a dependency is_available = True - def _init_fft(self, data): + def _init_fft(self, data: np.ndarray) -> np.ndarray: """Perform initial Fourier transform of the input data Parameters @@ -25,6 +25,6 @@ def _init_fft(self, data): """ return np.fft.fft2(data, axes=(-2, -1)) - def _ifft(self, data): + def _ifft(self, data: np.ndarray) -> np.ndarray: """Perform inverse Fourier transform""" return np.fft.ifft2(data, axes=(-2, -1)) diff --git a/qpretrieve/fourier/ff_pyfftw.py b/qpretrieve/fourier/ff_pyfftw.py index 56eda93..7ecd056 100644 --- a/qpretrieve/fourier/ff_pyfftw.py +++ b/qpretrieve/fourier/ff_pyfftw.py @@ -1,4 +1,5 @@ import multiprocessing as mp +import numpy as np import pyfftw @@ -11,7 +12,7 @@ class FFTFilterPyFFTW(FFTFilter): # always available, because numpy is a dependency is_available = True - def _init_fft(self, data): + def _init_fft(self, data: np.ndarray) -> np.ndarray: """Perform initial Fourier transform of the input data Parameters @@ -33,13 +34,13 @@ def _init_fft(self, data): fft_obj() return out_arr - def _ifft(self, data): + def _ifft(self, data: np.ndarray) -> np.ndarray: """Perform inverse Fourier transform""" in_arr = pyfftw.empty_aligned(data.shape, dtype='complex128') - ou_arr = pyfftw.empty_aligned(data.shape, dtype='complex128') - fft_obj = pyfftw.FFTW(in_arr, ou_arr, axes=(-2, -1), + out_arr = pyfftw.empty_aligned(data.shape, dtype='complex128') + fft_obj = pyfftw.FFTW(in_arr, out_arr, axes=(-2, -1), direction="FFTW_BACKWARD", ) in_arr[:] = data fft_obj() - return ou_arr + return out_arr