Skip to content

Commit

Permalink
enh: add 3d array usage to qsli
Browse files Browse the repository at this point in the history
  • Loading branch information
Eoghan O'Connell committed Nov 20, 2024
1 parent 568d511 commit 7d88bb9
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 32 deletions.
4 changes: 2 additions & 2 deletions qpretrieve/data_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
]


def check_data_input_form(data_input):
def check_data_input_format(data_input):
"""Figure out what data input is provided."""
if len(data_input.shape) == 3:
if data_input.shape[-1] in [1, 2, 3]:
Expand All @@ -30,7 +30,7 @@ def check_data_input_form(data_input):
return data.copy(), data_format


def revert_to_data_input_shape(data_format, field):
def revert_to_data_input_format(data_format, field):
"""Convert the outputted field shape to the original input shape,
for user convenience."""
assert data_format in allowed_data_formats
Expand Down
4 changes: 2 additions & 2 deletions qpretrieve/fourier/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .. import filter
from ..utils import padding_3d, mean_3d
from ..data_input import check_data_input_form
from ..data_input import check_data_input_format


class FFTCache:
Expand Down Expand Up @@ -84,7 +84,7 @@ def __init__(self,
copy = None
data_ed = np.array(data, dtype=dtype, copy=copy)
# figure out what type of data we have
data_ed, self.data_format = check_data_input_form(data_ed)
data_ed, self.data_format = check_data_input_format(data_ed)
#: original data (with subtracted mean)
self.origin = data_ed
# for `subtract_mean` and `padding`, we could use `np.atleast_3d`
Expand Down
12 changes: 4 additions & 8 deletions qpretrieve/interfere/if_oah.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from .base import BaseInterferogram
from ..data_input import revert_to_data_input_shape
from ..data_input import revert_to_data_input_format


class OffAxisHologram(BaseInterferogram):
Expand Down Expand Up @@ -74,7 +74,7 @@ def run_pipeline(self, **pipeline_kws):

if pipeline_kws["sideband_freq"] is None:
pipeline_kws["sideband_freq"] = find_peak_cosine(
self.fft.fft_origin)
self.fft.fft_origin[0])

# convert filter_size to frequency coordinates
fsize = self.compute_filter_size(
Expand All @@ -93,7 +93,7 @@ def run_pipeline(self, **pipeline_kws):
if pipeline_kws["invert_phase"]:
field.imag *= -1

field = revert_to_data_input_shape(self.fft.data_format, field)
field = revert_to_data_input_format(self.fft.data_format, field)
self._field = field
self._phase = None
self._amplitude = None
Expand All @@ -103,7 +103,7 @@ def run_pipeline(self, **pipeline_kws):


def find_peak_cosine(ft_data, copy=True):
"""Find the side band position of a regular off-axis hologram
"""Find the side band position of a 2d regular off-axis hologram
The Fourier transform of a cosine function (known as the
striped fringe pattern in off-axis holography) results in
Expand All @@ -128,10 +128,6 @@ def find_peak_cosine(ft_data, copy=True):
if copy:
ft_data = ft_data.copy()

if len(ft_data.shape) == 3:
# then we have a stack of images, just take one for finding the peak
ft_data = ft_data[0]
assert len(ft_data.shape) == 2
ox, oy = ft_data.shape
cx = ox // 2
cy = oy // 2
Expand Down
29 changes: 16 additions & 13 deletions qpretrieve/interfere/if_qlsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from .base import BaseInterferogram
from ..fourier import get_best_interface
from ..data_input import revert_to_data_input_format


class QLSInterferogram(BaseInterferogram):
Expand Down Expand Up @@ -47,7 +48,7 @@ def amplitude(self):
@property
def field(self):
if self._field is None:
self._field = self.amplitude * np.exp(1j*2*np.pi*self.phase)
self._field = self.amplitude * np.exp(1j * 2 * np.pi * self.phase)
return self._field

@property
Expand Down Expand Up @@ -120,7 +121,7 @@ def run_pipeline(self, **pipeline_kws):

if pipeline_kws["sideband_freq"] is None:
pipeline_kws["sideband_freq"] = find_peaks_qlsi(
self.fft.fft_origin)
self.fft.fft_origin[0])

# convert filter_size to frequency coordinates
fsize = self.compute_filter_size(
Expand Down Expand Up @@ -183,10 +184,10 @@ def run_pipeline(self, **pipeline_kws):
# Pad the gradient information so that we can rotate with cropping
# (keeping the image shape the same).
# TODO: Make padding dependent on rotation angle to save time?
sx, sy = px.shape
gradpad1 = np.pad(px, ((sx // 2, sx // 2), (sy // 2, sy // 2)),
sx, sy = px.shape[-2:]
gradpad1 = np.pad(px, ((0, 0), (sx // 2, sx // 2), (sy // 2, sy // 2)),
mode="constant", constant_values=0)
gradpad2 = np.pad(py, ((sx // 2, sx // 2), (sy // 2, sy // 2)),
gradpad2 = np.pad(py, ((0, 0), (sx // 2, sx // 2), (sy // 2, sy // 2)),
mode="constant", constant_values=0)

# Perform rotation of the gradients.
Expand All @@ -204,19 +205,19 @@ def run_pipeline(self, **pipeline_kws):
copy=False)
# Compute the frequencies that correspond to the frequencies of the
# Fourier-transformed image.
fx = np.fft.fftfreq(rfft.shape[0]).reshape(-1, 1)
fy = np.fft.fftfreq(rfft.shape[1]).reshape(1, -1)
fxy = -2*np.pi*1j * (fx + 1j*fy)
fx = np.fft.fftfreq(rfft.shape[-1]).reshape(rfft.shape[0], -1, 1)
fy = np.fft.fftfreq(rfft.shape[-2]).reshape(rfft.shape[0], 1, -1)
fxy = -2 * np.pi * 1j * (fx + 1j * fy)
fxy[0, 0] = 1

# The wavefront is the real part of the inverse Fourier transform
# of the filtered (divided by frequencies) data.
wfr = rfft._ifft(np.fft.ifftshift(rfft.fft_origin)/fxy).real
wfr = rfft._ifft(np.fft.ifftshift(rfft.fft_origin) / fxy).real

# Rotate the wavefront back and crop it so that the FOV matches
# the input data.
raw_wavefront = rotate_noreshape(wfr,
angle)[sx//2:-sx//2, sy//2:-sy//2]
raw_wavefront = rotate_noreshape(
wfr, angle)[:, sx // 2:-sx // 2, sy // 2:-sy // 2]
# Multiply by qlsi pitch term and the scaling factor to get
# the quantitative wavefront.
scaling_factor = self.fft_origin.shape[0] / wfr.shape[0]
Expand All @@ -230,6 +231,8 @@ def run_pipeline(self, **pipeline_kws):

self.pipeline_kws.update(pipeline_kws)

raw_wavefront = revert_to_data_input_format(
self.fft.data_format, raw_wavefront)
self.wavefront = raw_wavefront

return raw_wavefront
Expand Down Expand Up @@ -287,12 +290,12 @@ def find_peaks_qlsi(ft_data, periodicity=4, copy=True):
# circular bandpass according to periodicity
fx = np.fft.fftshift(np.fft.fftfreq(ft_data.shape[0])).reshape(-1, 1)
fy = np.fft.fftshift(np.fft.fftfreq(ft_data.shape[1])).reshape(1, -1)
frmask1 = np.sqrt(fx**2 + fy**2) > 1/(periodicity*.8)
frmask1 = np.sqrt(fx ** 2 + fy ** 2) > 1 / (periodicity * .8)
frmask2 = np.sqrt(fx ** 2 + fy ** 2) < 1 / (periodicity * 1.2)
ft_data[np.logical_or(frmask1, frmask2)] = 0

# find the peak in the left part
am1 = np.argmax(np.abs(ft_data*(fy < 0)))
am1 = np.argmax(np.abs(ft_data * (fy < 0)))
i1y = am1 % oy
i1x = int((am1 - i1y) / oy)

Expand Down
10 changes: 5 additions & 5 deletions tests/test_data_input.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import numpy as np
import pytest

from qpretrieve.data_input import check_data_input_form
from qpretrieve.data_input import check_data_input_format


def test_check_data_input_2d():
data = np.zeros(shape=(256, 256))

data_new, data_format = check_data_input_form(data)
data_new, data_format = check_data_input_format(data)

assert data_new.shape == (1, 256, 256)
assert np.array_equal(data_new[0], data)
Expand All @@ -17,7 +17,7 @@ def test_check_data_input_2d():
def test_check_data_input_3d_image_stack():
data = np.zeros(shape=(50, 256, 256))

data_new, data_format = check_data_input_form(data)
data_new, data_format = check_data_input_format(data)

assert data_new.shape == (50, 256, 256)
assert np.array_equal(data_new, data)
Expand All @@ -28,7 +28,7 @@ def test_check_data_input_3d_rgb():
data = np.zeros(shape=(256, 256, 3))

with pytest.warns(UserWarning):
data_new, data_format = check_data_input_form(data)
data_new, data_format = check_data_input_format(data)

assert data_new.shape == (1, 256, 256)
assert np.array_equal(data_new[0], data[:, :, 0])
Expand All @@ -39,7 +39,7 @@ def test_check_data_input_3d_rgba():
data = np.zeros(shape=(256, 256, 4))

with pytest.warns(UserWarning):
data_new, data_format = check_data_input_form(data)
data_new, data_format = check_data_input_format(data)

assert data_new.shape == (1, 256, 256)
assert np.array_equal(data_new[0], data[:, :, 0])
Expand Down
14 changes: 12 additions & 2 deletions tests/test_fourier_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,22 @@ def test_scale_to_filter_qlsi():
}

ifh = interfere.QLSInterferogram(image, **pipeline_kws)
ifh.run_pipeline()
raw_wavefront = ifh.run_pipeline()
assert raw_wavefront.shape == (720, 720)
assert ifh.phase.shape == (1, 720, 720)
assert ifh.amplitude.shape == (1, 720, 720)
assert ifh.field.shape == (1, 720, 720)

ifr = interfere.QLSInterferogram(refer, **pipeline_kws)
ifr.run_pipeline()
assert ifr.phase.shape == (1, 720, 720)
assert ifr.amplitude.shape == (1, 720, 720)
assert ifr.field.shape == (1, 720, 720)

phase = unwrap_phase(ifh.phase - ifr.phase)
ifh_phase = ifh.phase[0]
ifr_phase = ifr.phase[0]

phase = unwrap_phase(ifh_phase - ifr_phase)
assert phase.shape == (720, 720)
assert np.allclose(phase.mean(), 0.12434563269684816, atol=1e-6)

Expand Down

0 comments on commit 7d88bb9

Please sign in to comment.