Skip to content

Commit

Permalink
test: ensure the users provided data format is returned
Browse files Browse the repository at this point in the history
  • Loading branch information
Eoghan O'Connell committed Nov 20, 2024
1 parent fe33bf0 commit 568d511
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 13 deletions.
26 changes: 13 additions & 13 deletions qpretrieve/data_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ def revert_to_data_input_shape(data_format, field):
assert len(field.shape) == 3, "the field should be 3d"
field = field.copy()
if data_format == "rgb":
field = _convert_3d_to_rgb(field)
field = _revert_3d_to_rgb(field)
elif data_format == "rgba":
field = _convert_3d_to_rgba(field)
field = _revert_3d_to_rgba(field)
elif data_format == "3d":
field = field
else:
field = _convert_3d_to_2d(field)
field = _revert_3d_to_2d(field)
return field


Expand All @@ -68,17 +68,17 @@ def _convert_2d_to_3d(data_input):
return data, data_format


def _convert_3d_to_rgb(field):
field = field[0]
field = np.dstack((field, field, field))
return field
def _revert_3d_to_rgb(data_input):
data = data_input[0]
data = np.dstack((data, data, data))
return data


def _convert_3d_to_rgba(field):
field = field[0]
field = np.dstack((field, field, field, np.ones_like(field)))
return field
def _revert_3d_to_rgba(data_input):
data = data_input[0]
data = np.dstack((data, data, data, np.ones_like(data)))
return data


def _convert_3d_to_2d(field):
return field[0]
def _revert_3d_to_2d(data_input):
return data_input[0]
31 changes: 31 additions & 0 deletions tests/test_oah_from_qpimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import qpretrieve
from qpretrieve.interfere import if_oah
from qpretrieve.fourier import FFTFilterNumpy, FFTFilterScipy, FFTFilterPyFFTW
from qpretrieve.data_input import (
_convert_2d_to_3d, _revert_3d_to_rgb, _revert_3d_to_rgba,
)


def test_find_sideband():
Expand Down Expand Up @@ -268,3 +271,31 @@ def test_get_field_compare_FFTFilters(hologram):

assert not np.all(res1 == res2)
assert not np.all(res2 == res3)


def test_field_format_consistency(hologram):
"""The data format provided by the user should be returned"""
data_2d = hologram

# 2d data format
holo1 = qpretrieve.OffAxisHologram(data_2d)
res1 = holo1.run_pipeline()
assert res1.shape == data_2d.shape

# 3d data format
data_3d, _ = _convert_2d_to_3d(data_2d)
holo_3d = qpretrieve.OffAxisHologram(data_3d)
res_3d = holo_3d.run_pipeline()
assert res_3d.shape == data_3d.shape

# rgb data format
data_rgb = _revert_3d_to_rgb(data_3d)
holo_rgb = qpretrieve.OffAxisHologram(data_rgb)
res_rgb = holo_rgb.run_pipeline()
assert res_rgb.shape == data_rgb.shape

# rgba data format
data_rgba = _revert_3d_to_rgba(data_3d)
holo_rgba = qpretrieve.OffAxisHologram(data_rgba)
res_rgba = holo_rgba.run_pipeline()
assert res_rgba.shape == data_rgba.shape

0 comments on commit 568d511

Please sign in to comment.