Skip to content

Commit

Permalink
tests: fix torch preprocessing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
paulmueller committed Jun 7, 2024
1 parent dd0553d commit 093ad1f
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
run: |
# https://github.com/luispedro/mahotas/issues/144
pip install mahotas==1.4.13
pip install -e .
pip install .[torch]
- name: List installed packages
run: |
pip freeze
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
0.23.1
- enh: support passing custom default arguments to get_class_method_info
- tests: fix torch preprocessing tests
0.23.0
- feat: implement segmentation using PyTorch models
- fix: always compute image_bg if it is not in the input file
Expand Down
19 changes: 12 additions & 7 deletions src/dcnum/segm/segm_torch/torch_preproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@


def preprocess_images(images: np.ndarray,
norm_mean: float,
norm_std: float,
norm_mean: float | None,
norm_std: float | None,
image_shape: Tuple[int, int] = None,
):
"""Transform image data to something torch models expect
Expand All @@ -24,10 +24,11 @@ def preprocess_images(images: np.ndarray,
2D image, it will be reshaped to a 3D image with a batch_size of 1.
norm_mean:
Mean value used for standard score data normalization, i.e.
`normalized = `(images / 255 - norm_mean) / norm_std`
`normalized = `(images / 255 - norm_mean) / norm_std`; Set
to None to disable normalization.
norm_std:
Standard deviation used for standard score data normalization
(see above)
Standard deviation used for standard score data normalization;
Set to None to disable normalization (see above).
image_shape
Image shape for which the model was created (height, width).
If the image shape does not match the input image shape, then
Expand Down Expand Up @@ -102,8 +103,12 @@ def preprocess_images(images: np.ndarray,
# Replace img_norm
img_proc = img_pad

# normalize images
img_norm = (img_proc.astype(np.float32) / 255 - norm_mean) / norm_std
if norm_mean is None or norm_std is None:
# convert to float32
img_norm = img_proc.astype(np.float32)
else:
# normalize images
img_norm = (img_proc.astype(np.float32) / 255 - norm_mean) / norm_std

# Add a "channels" axis for the ML models.
return img_norm[:, np.newaxis, :, :]
24 changes: 12 additions & 12 deletions tests/test_segm_torch_preproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ def test_reshape_crop_both():
image = np.arange(64, dtype=np.float32).reshape(8, 8)
out = torch_preproc.preprocess_images(images=image,
image_shape=(6, 6),
norm_std=1,
norm_mean=1,
norm_std=None,
norm_mean=None,
)
assert out.shape == (1, 1, 6, 6)
imout = out[0, 0, :, :]
Expand All @@ -23,8 +23,8 @@ def test_reshape_crop_height():
image = np.arange(64, dtype=np.float32).reshape(8, 8)
out = torch_preproc.preprocess_images(images=image,
image_shape=(6, 8),
norm_std=1,
norm_mean=1,
norm_std=None,
norm_mean=None,
)
assert out.shape == (1, 1, 6, 8)
imout = out[0, 0, :, :]
Expand All @@ -35,8 +35,8 @@ def test_reshape_crop_width():
image = np.arange(64, dtype=np.float32).reshape(8, 8)
out = torch_preproc.preprocess_images(images=image,
image_shape=(8, 6),
norm_std=1,
norm_mean=1,
norm_std=None,
norm_mean=None,
)
assert out.shape == (1, 1, 8, 6)
imout = out[0, 0, :, :]
Expand All @@ -47,8 +47,8 @@ def test_reshape_pad_both():
image = np.arange(64, dtype=np.float32).reshape(8, 8)
out = torch_preproc.preprocess_images(images=image,
image_shape=(10, 10),
norm_std=1,
norm_mean=1,
norm_std=None,
norm_mean=None,
)
assert out.shape == (1, 1, 10, 10)
imout = out[0, 0, :, :]
Expand All @@ -64,8 +64,8 @@ def test_reshape_pad_height():
image = np.arange(64, dtype=np.float32).reshape(8, 8)
out = torch_preproc.preprocess_images(images=image,
image_shape=(10, 8),
norm_std=1,
norm_mean=1,
norm_std=None,
norm_mean=None,
)
assert out.shape == (1, 1, 10, 8)
imout = out[0, 0, :, :]
Expand All @@ -79,8 +79,8 @@ def test_reshape_pad_width_crop_height():
image = np.arange(64, dtype=np.float32).reshape(8, 8)
out = torch_preproc.preprocess_images(images=image,
image_shape=(6, 10),
norm_std=1,
norm_mean=1,
norm_std=None,
norm_mean=None,
)
assert out.shape == (1, 1, 6, 10)
imout = out[0, 0, :, :]
Expand Down

0 comments on commit 093ad1f

Please sign in to comment.