From 093ad1f281563cdbba8754c96f646b84c29b6178 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul=20M=C3=BCller?= Date: Fri, 7 Jun 2024 21:48:36 +0200 Subject: [PATCH] tests: fix torch preprocessing tests --- .github/workflows/check.yml | 2 +- CHANGELOG | 1 + src/dcnum/segm/segm_torch/torch_preproc.py | 19 ++++++++++------- tests/test_segm_torch_preproc.py | 24 +++++++++++----------- 4 files changed, 26 insertions(+), 20 deletions(-) diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index ab206da..2fb721b 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -31,7 +31,7 @@ jobs: run: | # https://github.com/luispedro/mahotas/issues/144 pip install mahotas==1.4.13 - pip install -e . + pip install .[torch] - name: List installed packages run: | pip freeze diff --git a/CHANGELOG b/CHANGELOG index b54a7b4..f2d6794 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,5 +1,6 @@ 0.23.1 - enh: support passing custom default arguments to get_class_method_info + - tests: fix torch preprocessing tests 0.23.0 - feat: implement segmentation using PyTorch models - fix: always compute image_bg if it is not in the input file diff --git a/src/dcnum/segm/segm_torch/torch_preproc.py b/src/dcnum/segm/segm_torch/torch_preproc.py index 0a3abca..f675136 100644 --- a/src/dcnum/segm/segm_torch/torch_preproc.py +++ b/src/dcnum/segm/segm_torch/torch_preproc.py @@ -4,8 +4,8 @@ def preprocess_images(images: np.ndarray, - norm_mean: float, - norm_std: float, + norm_mean: float | None, + norm_std: float | None, image_shape: Tuple[int, int] = None, ): """Transform image data to something torch models expect @@ -24,10 +24,11 @@ def preprocess_images(images: np.ndarray, 2D image, it will be reshaped to a 3D image with a batch_size of 1. norm_mean: Mean value used for standard score data normalization, i.e. - `normalized = `(images / 255 - norm_mean) / norm_std` + `normalized = `(images / 255 - norm_mean) / norm_std`; Set + to None to disable normalization. norm_std: - Standard deviation used for standard score data normalization - (see above) + Standard deviation used for standard score data normalization; + Set to None to disable normalization (see above). image_shape Image shape for which the model was created (height, width). If the image shape does not match the input image shape, then @@ -102,8 +103,12 @@ def preprocess_images(images: np.ndarray, # Replace img_norm img_proc = img_pad - # normalize images - img_norm = (img_proc.astype(np.float32) / 255 - norm_mean) / norm_std + if norm_mean is None or norm_std is None: + # convert to float32 + img_norm = img_proc.astype(np.float32) + else: + # normalize images + img_norm = (img_proc.astype(np.float32) / 255 - norm_mean) / norm_std # Add a "channels" axis for the ML models. return img_norm[:, np.newaxis, :, :] diff --git a/tests/test_segm_torch_preproc.py b/tests/test_segm_torch_preproc.py index d9846a9..c301346 100644 --- a/tests/test_segm_torch_preproc.py +++ b/tests/test_segm_torch_preproc.py @@ -11,8 +11,8 @@ def test_reshape_crop_both(): image = np.arange(64, dtype=np.float32).reshape(8, 8) out = torch_preproc.preprocess_images(images=image, image_shape=(6, 6), - norm_std=1, - norm_mean=1, + norm_std=None, + norm_mean=None, ) assert out.shape == (1, 1, 6, 6) imout = out[0, 0, :, :] @@ -23,8 +23,8 @@ def test_reshape_crop_height(): image = np.arange(64, dtype=np.float32).reshape(8, 8) out = torch_preproc.preprocess_images(images=image, image_shape=(6, 8), - norm_std=1, - norm_mean=1, + norm_std=None, + norm_mean=None, ) assert out.shape == (1, 1, 6, 8) imout = out[0, 0, :, :] @@ -35,8 +35,8 @@ def test_reshape_crop_width(): image = np.arange(64, dtype=np.float32).reshape(8, 8) out = torch_preproc.preprocess_images(images=image, image_shape=(8, 6), - norm_std=1, - norm_mean=1, + norm_std=None, + norm_mean=None, ) assert out.shape == (1, 1, 8, 6) imout = out[0, 0, :, :] @@ -47,8 +47,8 @@ def test_reshape_pad_both(): image = np.arange(64, dtype=np.float32).reshape(8, 8) out = torch_preproc.preprocess_images(images=image, image_shape=(10, 10), - norm_std=1, - norm_mean=1, + norm_std=None, + norm_mean=None, ) assert out.shape == (1, 1, 10, 10) imout = out[0, 0, :, :] @@ -64,8 +64,8 @@ def test_reshape_pad_height(): image = np.arange(64, dtype=np.float32).reshape(8, 8) out = torch_preproc.preprocess_images(images=image, image_shape=(10, 8), - norm_std=1, - norm_mean=1, + norm_std=None, + norm_mean=None, ) assert out.shape == (1, 1, 10, 8) imout = out[0, 0, :, :] @@ -79,8 +79,8 @@ def test_reshape_pad_width_crop_height(): image = np.arange(64, dtype=np.float32).reshape(8, 8) out = torch_preproc.preprocess_images(images=image, image_shape=(6, 10), - norm_std=1, - norm_mean=1, + norm_std=None, + norm_mean=None, ) assert out.shape == (1, 1, 6, 10) imout = out[0, 0, :, :]