Skip to content

Commit

Permalink
tests: fix shapes for oah tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Eoghan O'Connell committed Jan 21, 2025
1 parent ffbb2c8 commit fb205f8
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions tests/test_oah.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def test_get_field_filter_names(hologram):
def test_get_field_interpretation_fourier_index(hologram):
"""Filter size in Fourier space using Fourier index new in 0.7.0"""
data = hologram
shape_expected = (1, hologram.shape[-2], hologram.shape[-1])
holo = qpretrieve.OffAxisHologram(data)

ft_data = holo.fft_origin
Expand All @@ -119,8 +120,8 @@ def test_get_field_interpretation_fourier_index(hologram):
)
res2 = holo.run_pipeline(**kwargs2)

assert res1.shape == hologram.shape
assert res2.shape == hologram.shape
assert res1.shape == shape_expected
assert res2.shape == shape_expected
assert np.all(res1 == res2)


Expand Down Expand Up @@ -229,6 +230,8 @@ def test_get_field_three_axes(hologram):
# create a copy with empty entry in third axis
data2 = np.zeros((data1.shape[0], data1.shape[1], 3))
data2[:, :, 0] = data1
# both will be output as (z,y,x) shaped image stacks
shape_expected = (1, hologram.shape[-2], hologram.shape[-1])

holo1 = qpretrieve.OffAxisHologram(data1)
holo2 = qpretrieve.OffAxisHologram(data2)
Expand All @@ -238,17 +241,16 @@ def test_get_field_three_axes(hologram):
res1 = holo1.run_pipeline(**kwargs)
res2 = holo2.run_pipeline(**kwargs)

assert res1.shape == (data1.shape[0], data1.shape[1])
assert res2.shape == (data1.shape[0], data1.shape[1], 3)

assert np.all(res1 == res2[:, :, 0])
assert res1.shape == shape_expected
assert res2.shape == shape_expected
assert np.all(res1 == res2)


def test_get_field_compare_FFTFilters(hologram):
data1 = hologram
kwargs = dict(filter_name="disk", filter_size=1 / 3)
padding = False
shape_expected = (64, 64)
shape_expected = (1, hologram.shape[-2], hologram.shape[-1])

holo1 = qpretrieve.OffAxisHologram(data1,
fft_interface=FFTFilterNumpy,
Expand All @@ -268,28 +270,33 @@ def test_get_field_compare_FFTFilters(hologram):


def test_field_format_consistency(hologram):
"""The data format provided by the user should be returned"""
data_2d = hologram
"""The data format returned should be (z,y,x)"""
data_2d = hologram.copy()
shape_expected = (1, hologram.shape[-2], hologram.shape[-1])

# 2d data format
holo1 = qpretrieve.OffAxisHologram(data_2d)
res1 = holo1.run_pipeline()
assert res1.shape == data_2d.shape
holo_2d = qpretrieve.OffAxisHologram(data_2d)
res_2d = holo_2d.run_pipeline()
assert res_2d.shape == shape_expected

# 3d data format
data_3d, _ = _convert_2d_to_3d(data_2d)
holo_3d = qpretrieve.OffAxisHologram(data_3d)
res_3d = holo_3d.run_pipeline()
assert res_3d.shape == data_3d.shape
assert res_3d.shape == shape_expected

# rgb data format
data_rgb = _revert_3d_to_rgb(data_3d)
holo_rgb = qpretrieve.OffAxisHologram(data_rgb)
res_rgb = holo_rgb.run_pipeline()
assert res_rgb.shape == data_rgb.shape
assert res_rgb.shape == shape_expected

# rgba data format
data_rgba = _revert_3d_to_rgba(data_3d)
holo_rgba = qpretrieve.OffAxisHologram(data_rgba)
res_rgba = holo_rgba.run_pipeline()
assert res_rgba.shape == data_rgba.shape
assert res_rgba.shape == shape_expected

assert np.all(res_2d == res_3d)
assert np.all(res_2d == res_rgb)
assert np.all(res_2d == res_rgba)

0 comments on commit fb205f8

Please sign in to comment.