From 3111c3187124cf1e45c40235e9f13d0454cf7955 Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Wed, 18 Sep 2024 17:13:02 +0200 Subject: [PATCH 01/16] add MultiOutput in enum Signed-off-by: thibaultdvx --- monai/utils/__init__.py | 1 + monai/utils/enums.py | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 4e36e3cd47..84c9ac8f82 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -47,6 +47,7 @@ MetaKeys, Method, MetricReduction, + MultiOutput, NdimageMode, NumpyPadMode, OrderingTransformations, diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 7838a2e741..e9675ff369 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -230,6 +230,16 @@ class Average(StrEnum): NONE = "none" +class MultiOutput(StrEnum): + """ + See also: :py:func:`monai.metrics.r2_score.compute_r2_score` + """ + + RAW = "raw_values" + UNIFORM = "uniform_average" + VARIANCE = "variance_weighted" + + class MetricReduction(StrEnum): """ See also: :py:func:`monai.metrics.utils.do_metric_reduction` From 62ecb2347ac675059da710cd3e744e7d3caf1ed3 Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Wed, 18 Sep 2024 17:13:27 +0200 Subject: [PATCH 02/16] add r2 metric and compute Signed-off-by: thibaultdvx --- monai/metrics/__init__.py | 1 + monai/metrics/r2_score.py | 184 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 185 insertions(+) create mode 100644 monai/metrics/r2_score.py diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index 201acdfa50..db0de24eb0 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -25,6 +25,7 @@ from .metric import Cumulative, CumulativeIterationMetric, IterationMetric, Metric from .mmd import MMDMetric, compute_mmd from .panoptic_quality import PanopticQualityMetric, compute_panoptic_quality +from .r2_score import R2Metric, compute_r2_score from .regression import ( MAEMetric, MSEMetric, diff --git a/monai/metrics/r2_score.py b/monai/metrics/r2_score.py new file mode 100644 index 0000000000..52ffc3c5b0 --- /dev/null +++ b/monai/metrics/r2_score.py @@ -0,0 +1,184 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + import numpy.typing as npt + +import torch + +from monai.utils import MultiOutput, look_up_option + +from .metric import CumulativeIterationMetric + + +class R2Metric(CumulativeIterationMetric): + r"""Computes :math:`R^{2}` score (coefficient of determination): + + .. math:: + \operatorname {R^{2}}\left(Y, \hat{Y}\right) = 1 - \frac {\sum _{i=1}^{n}\left(y_i-\hat{y_i} \right)^{2}}{\sum _{i=1}^{n}\left(y_i-\bar{y} \right)^{2}}, + + where :math:`\bar{y}` is the mean of observed :math:`y` ; or adjusted :math:`R^{2}` score: + + .. math:: + \operatorname {\bar{R}^{2}} = 1 - (1-R^{2}) \frac {n-1}{n-p-1}, + + where :math:`p` is the number of independant variables used for the regression. + + More info: https://en.wikipedia.org/wiki/Coefficient_of_determination + + Input `y_pred` is compared with ground truth `y`. + `y_pred` and `y` are expected to be 1D (single-output regression) or 2D (multi-output regression) real-valued tensors of same shape. + + Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`. + + Args: + multi_output: {``"raw_values"``, ``"uniform_average"``, ``"variance_weighted"``} + Type of aggregation performed on multi-output scores. + Defaults to ``"uniform_average"``. + + - ``"raw_values"``: the scores for each output are returned. + - ``"uniform_average"``: the scores of all outputs are averaged with uniform weight. + - ``"variance_weighted"``: the scores of all outputs are averaged, weighted by the variances of + each individual output. + p: non-negative integer. + Number of independent variables used for regression. ``p`` is used to compute adjusted :math:`R^{2}` score. + Defaults to 0 (standard :math:`R^{2}` score). + + """ + + def __init__(self, multi_output: MultiOutput | str = MultiOutput.UNIFORM, p: int = 0) -> None: + super().__init__() + multi_output, p = _check_r2_params(multi_output, p) + self.multi_output = multi_output + self.p = p + + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + _check_dim(y_pred, y) + return y_pred, y + + def aggregate(self, multi_output: MultiOutput | str | None = None) -> np.ndarray | float | npt.ArrayLike: + """ + Typically `y_pred` and `y` are stored in the cumulative buffers at each iteration, + This function reads the buffers and computes the :math:`R^{2}` score. + + Args: + multi_output: {``"raw_values"``, ``"uniform_average"``, ``"variance_weighted"``} + Type of aggregation performed on multi-output scores. Defaults to `self.multi_output`. + + """ + y_pred, y = self.get_buffer() + return compute_r2_score(y_pred=y_pred, y=y, multi_output=multi_output or self.multi_output, p=self.p) + + +def _check_dim(y_pred: torch.Tensor, y: torch.Tensor) -> None: + if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor): + raise ValueError("y_pred and y must be PyTorch Tensor.") + + if y.shape != y_pred.shape: + raise ValueError(f"data shapes of y_pred and y do not match, got {y_pred.shape} and {y.shape}.") + + dim = y.ndimension() + if dim not in (1, 2): + raise ValueError( + f"predictions and ground truths should be of shape (batch_size, num_outputs) or (batch_size, ), got {y.shape}." + ) + + +def _check_r2_params(multi_output, p) -> tuple[MultiOutput, int]: + multi_output = look_up_option(multi_output, MultiOutput) + if not isinstance(p, int) or p < 0: + raise ValueError(f"`p` must be an integer larger or equal to 0, got {p}.") + + return multi_output, p + + +def _calculate(y_pred: np.ndarray, y: np.ndarray, p: int) -> float: + num_obs = len(y) + rss = np.sum((y_pred - y) ** 2) + tss = np.sum(y ** 2) - np.sum(y) ** 2 / num_obs + r2 = 1 - (rss / tss) + r2_adjusted = 1 - (1 - r2) * (num_obs - 1) / (num_obs - p - 1) + + return r2_adjusted + + +def compute_r2_score( + y_pred: torch.Tensor, y: torch.Tensor, multi_output: MultiOutput | str = MultiOutput.UNIFORM, p: int = 0 +) -> np.ndarray | float | npt.ArrayLike: + """Computes :math:`R^{2}` score (coefficient of determination). + + Args: + y_pred: input data to compute :math:`R^{2}` score, the first dim must be batch. + For example: shape `[16]` or `[16, 1]` for a single-output regression, shape `[16, x]` for x output variables. + y: ground truth to compute :math:`R^{2}` score, the first dim must be batch. + For example: shape `[16]` or `[16, 1]` for a single-output regression, shape `[16, x]` for x output variables. + multi_output: {``"raw_values"``, ``"uniform_average"``, ``"variance_weighted"``} + Type of aggregation performed on multi-output scores. + Defaults to ``"uniform_average"``. + + - ``"raw_values"``: the scores for each output are returned. + - ``"uniform_average"``: the scores of all outputs are averaged with uniform weight. + - ``"variance_weighted"``: the scores of all outputs are averaged, weighted by the variances + each individual output. + p: non-negative integer. + Number of independent variables used for regression. ``p`` is used to compute adjusted :math:`R^{2}` score. + Defaults to 0 (standard :math:`R^{2}` score). + + Raises: + ValueError: When ``multi_output`` is not one of ["raw_values", "uniform_average", "variance_weighted"]. + ValueError: When ``p`` is not a non-negative integer. + ValueError: When ``y_pred`` or ``y`` are not PyTorch tensors. + ValueError: When ``y_pred`` and ``y`` don't have the same shape. + ValueError: When ``y_pred`` or ``y`` dimension is not one of [1, 2]. + ValueError: When n_samples is less than 2. + ValueError: When ``p`` is greater or equal to n_samples - 1. + + """ + multi_output, p = _check_r2_params(multi_output, p) + _check_dim(y_pred, y) + dim = y.ndimension() + n = y.shape[0] + y = y.cpu().numpy() + y_pred = y_pred.cpu().numpy() + + if n < 2: + raise ValueError( + "There is no enough data for computing. Needs at least two samples to calculate r2 score." + ) + if p >= n - 1: + raise ValueError( + "`p` must be smaller than n_samples - 1, " + f"got p={p}, n_samples={n}.", + ) + + if dim == 2 and y_pred.shape[1] == 1: + y_pred = np.squeeze(y_pred, axis=-1) + y = np.squeeze(y, axis=-1) + dim = 1 + + if dim == 1: + return _calculate(y_pred, y, p) + + y, y_pred = np.transpose(y, axes=(1, 0)), np.transpose(y_pred, axes=(1, 0)) + r2_values = [_calculate(y_pred_, y_, p) for y_pred_, y_ in zip(y_pred, y)] + if multi_output == MultiOutput.RAW: + return r2_values + if multi_output == MultiOutput.UNIFORM: + return np.mean(r2_values) + if multi_output == multi_output.VARIANCE: + weights = np.var(y, axis=1) + return np.average(r2_values, weights=weights) From 9e3458676b4c128e25f8df132be400cb6a014024 Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Wed, 18 Sep 2024 17:13:47 +0200 Subject: [PATCH 03/16] add r2 handler Signed-off-by: thibaultdvx --- monai/handlers/__init__.py | 1 + monai/handlers/r2_score.py | 51 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) create mode 100644 monai/handlers/r2_score.py diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index c1fa448f25..fed8504722 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -34,6 +34,7 @@ from .parameter_scheduler import ParamSchedulerHandler from .postprocessing import PostProcessing from .probability_maps import ProbMapProducer +from .r2_score import R2Score from .regression_metrics import MeanAbsoluteError, MeanSquaredError, PeakSignalToNoiseRatio, RootMeanSquaredError from .roc_auc import ROCAUC from .smartcache_handler import SmartCacheHandler diff --git a/monai/handlers/r2_score.py b/monai/handlers/r2_score.py new file mode 100644 index 0000000000..d90ddede52 --- /dev/null +++ b/monai/handlers/r2_score.py @@ -0,0 +1,51 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable + +from monai.handlers.ignite_metric import IgniteMetricHandler +from monai.metrics import R2Metric +from monai.utils import MultiOutput + + +class R2Score(IgniteMetricHandler): + """ + Computes :math:`R^{2}` score accumulating predictions and the ground-truth during an epoch and applying `compute_r2_score`. + + Args: + multi_output: {``"raw_values"``, ``"uniform_average"``, ``"variance_weighted"``} + Type of aggregation performed on multi-output scores. + Defaults to ``"uniform_average"``. + + - ``"raw_values"``: the scores for each output are returned. + - ``"uniform_average"``: the scores of all outputs are averaged with uniform weight. + - ``"variance_weighted"``: the scores of all outputs are averaged, weighted by the variances + of each individual output. + p: non-negative integer. + Number of independent variables used for regression. ``p`` is used to compute adjusted :math:`R^{2}` score. + Defaults to 0 (standard :math:`R^{2}` score). + output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then + construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or + lists of `channel-first` Tensors. The form of `(y_pred, y)` is required by the `update()`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. + + See also: + :py:class:`monai.metrics.R2Metric` + + """ + + def __init__(self, multi_output: MultiOutput | str = MultiOutput.UNIFORM, p: int = 0, output_transform: Callable = lambda x: x) -> None: + metric_fn = R2Metric(multi_output=multi_output, p=p) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=False) From ee65c1ecdda557f47f02cb14cb84f33047d5b473 Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Wed, 18 Sep 2024 17:14:04 +0200 Subject: [PATCH 04/16] unittests Signed-off-by: thibaultdvx --- tests/test_compute_r2_score.py | 173 ++++++++++++++++++++++++++++ tests/test_handler_r2_score.py | 41 +++++++ tests/test_handler_r2_score_dist.py | 50 ++++++++ 3 files changed, 264 insertions(+) create mode 100644 tests/test_compute_r2_score.py create mode 100644 tests/test_handler_r2_score.py create mode 100644 tests/test_handler_r2_score_dist.py diff --git a/tests/test_compute_r2_score.py b/tests/test_compute_r2_score.py new file mode 100644 index 0000000000..7b069f7238 --- /dev/null +++ b/tests/test_compute_r2_score.py @@ -0,0 +1,173 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.metrics import R2Metric, compute_r2_score + + +_device = "cuda:0" if torch.cuda.is_available() else "cpu" +TEST_CASE_1 = [ + torch.tensor([0.1, -0.25, 3.0, 0.99], device=_device), + torch.tensor([0.1, -0.2, -2.7, 1.58], device=_device), + "uniform_average", + 0, + -2.469944, +] + +TEST_CASE_2 = [ + torch.tensor([0.1, -0.25, 3.0, 0.99]), + torch.tensor([0.1, -0.2, 2.7, 1.58]), + "uniform_average", + 2, + 0.75828, +] + +TEST_CASE_3 = [ + torch.tensor([[0.1], [-0.25], [3.0], [0.99]]), + torch.tensor([[0.1], [-0.2], [2.7], [1.58]]), + "raw_values", + 2, + 0.75828, +] + +TEST_CASE_4 = [ + torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]), + torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]), + "raw_values", + 1, + [0.87914, 0.844375], +] + +TEST_CASE_5 = [ + torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]), + torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]), + "variance_weighted", + 1, + 0.867314, +] + +TEST_CASE_6 = [ + torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]), + torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]), + "uniform_average", + 0, + 0.907838, +] + +TEST_CASE_ERROR_1 = [ + torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]), + torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]), + "abc", + 0, +] + +TEST_CASE_ERROR_2 = [ + torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]), + torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]), + "uniform_average", + -1, +] + +TEST_CASE_ERROR_3 = [ + torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]), + np.array([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]), + "uniform_average", + 0, +] + +TEST_CASE_ERROR_4 = [ + torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]), + torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1]]), + "uniform_average", + 0, +] + +TEST_CASE_ERROR_5 = [ + torch.tensor([[[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]]), + torch.tensor([[[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]]), + "uniform_average", + 0, +] + +TEST_CASE_ERROR_6 = [ + torch.tensor([[0.1, 1.0], [-0.25, 0.5], [3.0, -0.2], [0.99, 2.1]]), + torch.tensor([[0.1, 0.82], [-0.2, 0.01], [2.7, -0.1], [1.58, 2.0]]), + "uniform_average", + 3, +] + +TEST_CASE_ERROR_7 = [ + torch.tensor([[0.1, 1.0]]), + torch.tensor([[0.1, 0.82]]), + "uniform_average", + 0, +] + +class TestComputeR2Score(unittest.TestCase): + + @parameterized.expand( + [ + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + ] + ) + def test_value(self, y_pred, y, multi_output, p, expected_value): + result = compute_r2_score(y_pred=y_pred, y=y, multi_output=multi_output, p=p) + np.testing.assert_allclose(expected_value, result, rtol=1e-5) + + @parameterized.expand( + [ + TEST_CASE_ERROR_1, + TEST_CASE_ERROR_2, + TEST_CASE_ERROR_3, + TEST_CASE_ERROR_4, + TEST_CASE_ERROR_5, + TEST_CASE_ERROR_6, + TEST_CASE_ERROR_7, + ] + ) + def test_error(self, y_pred, y, multi_output, p): + with self.assertRaises(ValueError): + _ = compute_r2_score(y_pred=y_pred, y=y, multi_output=multi_output, p=p) + + @parameterized.expand( + [ + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + ] + ) + def test_class_value(self, y_pred, y, multi_output, p, expected_value): + metric = R2Metric(multi_output=multi_output, p=p) + metric(y_pred=y_pred, y=y) + result = metric.aggregate() + np.testing.assert_allclose(expected_value, result, rtol=1e-5) + result = metric.aggregate(multi_output=multi_output) # test optional argument + metric.reset() + np.testing.assert_allclose(expected_value, result, rtol=1e-5) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_handler_r2_score.py b/tests/test_handler_r2_score.py new file mode 100644 index 0000000000..f2fa243719 --- /dev/null +++ b/tests/test_handler_r2_score.py @@ -0,0 +1,41 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch + +from monai.handlers import R2Score + + +class TestHandlerR2Score(unittest.TestCase): + + def test_compute(self): + r2_score = R2Score(multi_output="variance_weighted", p=1) + + y_pred = [torch.Tensor([0.1, 1.0]), torch.Tensor([-0.25, 0.5])] + y = [torch.Tensor([0.1, 0.82]), torch.Tensor([-0.2, 0.01])] + r2_score.update([y_pred, y]) + + y_pred = [torch.Tensor([3.0, -0.2]), torch.Tensor([0.99, 2.1])] + y = [torch.Tensor([2.7, -0.1]), torch.Tensor([1.58, 2.0])] + + r2_score.update([y_pred, y]) + + r2 = r2_score.compute() + np.testing.assert_allclose(0.867314, r2, rtol=1e-5) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_handler_r2_score_dist.py b/tests/test_handler_r2_score_dist.py new file mode 100644 index 0000000000..a8eb4bf455 --- /dev/null +++ b/tests/test_handler_r2_score_dist.py @@ -0,0 +1,50 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +import torch.distributed as dist + +from monai.handlers import R2Score +from tests.utils import DistCall, DistTestCase + + +class DistributedR2Score(DistTestCase): + + @DistCall(nnodes=1, nproc_per_node=2, node_rank=0) + def test_compute(self): + r2_score = R2Score(multi_output="variance_weighted", p=1) + + device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu" + if dist.get_rank() == 0: + y_pred = [torch.tensor([0.1, 1.0], device=device), torch.tensor([-0.25, 0.5], device=device)] + y = [torch.tensor([0.1, 0.82], device=device), torch.tensor([-0.2, 0.01], device=device)] + + if dist.get_rank() == 1: + y_pred = [ + torch.tensor([3.0, -0.2], device=device), + torch.tensor([0.99, 2.1], device=device), + torch.tensor([-0.1, 0.0], device=device), + ] + y = [torch.tensor([2.7, -0.1], device=device), torch.tensor([1.58, 2.0], device=device), torch.tensor([-1.0, -0.1], device=device)] + + r2_score.update([y_pred, y]) + + result = r2_score.compute() + np.testing.assert_allclose(0.829185, result, rtol=1e-5) + + +if __name__ == "__main__": + unittest.main() From bfce6e7b2b6871a1aeecf8c437e5a2d90486ea6b Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Wed, 18 Sep 2024 17:14:13 +0200 Subject: [PATCH 05/16] docs Signed-off-by: thibaultdvx --- docs/source/handlers.rst | 6 ++++++ docs/source/metrics.rst | 7 +++++++ 2 files changed, 13 insertions(+) diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 270083f717..b48869d01e 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -77,6 +77,12 @@ Panoptic Quality metrics handler :members: +:math:`R^{2}` score +------------------- +.. autoclass:: R2Score + :members: + + Mean squared error metrics handler ---------------------------------- .. autoclass:: MeanSquaredError diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 616f0fe385..751c624405 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -117,6 +117,13 @@ Metrics .. autoclass:: PanopticQualityMetric :members: +:math:`R^{2}` score +------------------- +.. autofunction:: compute_r2_score + +.. autoclass:: R2Metric + :members: + `Mean squared error` -------------------- .. autoclass:: MSEMetric From e00db73f8439bb73406b833c1ad397224ff140a9 Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Wed, 18 Sep 2024 17:22:41 +0200 Subject: [PATCH 06/16] fix code issues Signed-off-by: thibaultdvx --- monai/handlers/r2_score.py | 11 +++++++--- monai/metrics/r2_score.py | 27 +++++++++++-------------- monai/utils/enums.py | 1 + tests/test_compute_r2_score.py | 31 ++++------------------------- tests/test_handler_r2_score_dist.py | 6 +++++- 5 files changed, 30 insertions(+), 46 deletions(-) diff --git a/monai/handlers/r2_score.py b/monai/handlers/r2_score.py index d90ddede52..dc94182885 100644 --- a/monai/handlers/r2_score.py +++ b/monai/handlers/r2_score.py @@ -29,9 +29,9 @@ class R2Score(IgniteMetricHandler): - ``"raw_values"``: the scores for each output are returned. - ``"uniform_average"``: the scores of all outputs are averaged with uniform weight. - - ``"variance_weighted"``: the scores of all outputs are averaged, weighted by the variances + - ``"variance_weighted"``: the scores of all outputs are averaged, weighted by the variances of each individual output. - p: non-negative integer. + p: non-negative integer. Number of independent variables used for regression. ``p`` is used to compute adjusted :math:`R^{2}` score. Defaults to 0 (standard :math:`R^{2}` score). output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then @@ -46,6 +46,11 @@ class R2Score(IgniteMetricHandler): """ - def __init__(self, multi_output: MultiOutput | str = MultiOutput.UNIFORM, p: int = 0, output_transform: Callable = lambda x: x) -> None: + def __init__( + self, + multi_output: MultiOutput | str = MultiOutput.UNIFORM, + p: int = 0, + output_transform: Callable = lambda x: x, + ) -> None: metric_fn = R2Metric(multi_output=multi_output, p=p) super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=False) diff --git a/monai/metrics/r2_score.py b/monai/metrics/r2_score.py index 52ffc3c5b0..e2b323f175 100644 --- a/monai/metrics/r2_score.py +++ b/monai/metrics/r2_score.py @@ -29,19 +29,21 @@ class R2Metric(CumulativeIterationMetric): r"""Computes :math:`R^{2}` score (coefficient of determination): .. math:: - \operatorname {R^{2}}\left(Y, \hat{Y}\right) = 1 - \frac {\sum _{i=1}^{n}\left(y_i-\hat{y_i} \right)^{2}}{\sum _{i=1}^{n}\left(y_i-\bar{y} \right)^{2}}, + \operatorname {R^{2}}\left(Y, \hat{Y}\right) = 1 - \frac {\sum _{i=1}^{n}\left(y_i-\hat{y_i} \right)^{2}} + {\sum _{i=1}^{n}\left(y_i-\bar{y} \right)^{2}}, where :math:`\bar{y}` is the mean of observed :math:`y` ; or adjusted :math:`R^{2}` score: .. math:: - \operatorname {\bar{R}^{2}} = 1 - (1-R^{2}) \frac {n-1}{n-p-1}, + \operatorname {\bar{R}^{2}} = 1 - (1-R^{2}) \frac {n-1}{n-p-1}, where :math:`p` is the number of independant variables used for the regression. More info: https://en.wikipedia.org/wiki/Coefficient_of_determination Input `y_pred` is compared with ground truth `y`. - `y_pred` and `y` are expected to be 1D (single-output regression) or 2D (multi-output regression) real-valued tensors of same shape. + `y_pred` and `y` are expected to be 1D (single-output regression) or 2D (multi-output regression) real-valued + tensors of same shape. Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`. @@ -52,9 +54,9 @@ class R2Metric(CumulativeIterationMetric): - ``"raw_values"``: the scores for each output are returned. - ``"uniform_average"``: the scores of all outputs are averaged with uniform weight. - - ``"variance_weighted"``: the scores of all outputs are averaged, weighted by the variances of + - ``"variance_weighted"``: the scores of all outputs are averaged, weighted by the variances of each individual output. - p: non-negative integer. + p: non-negative integer. Number of independent variables used for regression. ``p`` is used to compute adjusted :math:`R^{2}` score. Defaults to 0 (standard :math:`R^{2}` score). @@ -102,14 +104,14 @@ def _check_r2_params(multi_output, p) -> tuple[MultiOutput, int]: multi_output = look_up_option(multi_output, MultiOutput) if not isinstance(p, int) or p < 0: raise ValueError(f"`p` must be an integer larger or equal to 0, got {p}.") - + return multi_output, p def _calculate(y_pred: np.ndarray, y: np.ndarray, p: int) -> float: num_obs = len(y) rss = np.sum((y_pred - y) ** 2) - tss = np.sum(y ** 2) - np.sum(y) ** 2 / num_obs + tss = np.sum(y**2) - np.sum(y) ** 2 / num_obs r2 = 1 - (rss / tss) r2_adjusted = 1 - (1 - r2) * (num_obs - 1) / (num_obs - p - 1) @@ -132,7 +134,7 @@ def compute_r2_score( - ``"raw_values"``: the scores for each output are returned. - ``"uniform_average"``: the scores of all outputs are averaged with uniform weight. - - ``"variance_weighted"``: the scores of all outputs are averaged, weighted by the variances + - ``"variance_weighted"``: the scores of all outputs are averaged, weighted by the variances each individual output. p: non-negative integer. Number of independent variables used for regression. ``p`` is used to compute adjusted :math:`R^{2}` score. @@ -156,14 +158,9 @@ def compute_r2_score( y_pred = y_pred.cpu().numpy() if n < 2: - raise ValueError( - "There is no enough data for computing. Needs at least two samples to calculate r2 score." - ) + raise ValueError("There is no enough data for computing. Needs at least two samples to calculate r2 score.") if p >= n - 1: - raise ValueError( - "`p` must be smaller than n_samples - 1, " - f"got p={p}, n_samples={n}.", - ) + raise ValueError("`p` must be smaller than n_samples - 1, " f"got p={p}, n_samples={n}.") if dim == 2 and y_pred.shape[1] == 1: y_pred = np.squeeze(y_pred, axis=-1) diff --git a/monai/utils/enums.py b/monai/utils/enums.py index e9675ff369..05cc94500c 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -30,6 +30,7 @@ "NdimageMode", "GridSamplePadMode", "Average", + "MultiOutput", "MetricReduction", "LossReduction", "DiceCEReduction", diff --git a/tests/test_compute_r2_score.py b/tests/test_compute_r2_score.py index 7b069f7238..0cea11cf47 100644 --- a/tests/test_compute_r2_score.py +++ b/tests/test_compute_r2_score.py @@ -19,7 +19,6 @@ from monai.metrics import R2Metric, compute_r2_score - _device = "cuda:0" if torch.cuda.is_available() else "cpu" TEST_CASE_1 = [ torch.tensor([0.1, -0.25, 3.0, 0.99], device=_device), @@ -111,25 +110,12 @@ 3, ] -TEST_CASE_ERROR_7 = [ - torch.tensor([[0.1, 1.0]]), - torch.tensor([[0.1, 0.82]]), - "uniform_average", - 0, -] +TEST_CASE_ERROR_7 = [torch.tensor([[0.1, 1.0]]), torch.tensor([[0.1, 0.82]]), "uniform_average", 0] + class TestComputeR2Score(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - ] - ) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) def test_value(self, y_pred, y, multi_output, p, expected_value): result = compute_r2_score(y_pred=y_pred, y=y, multi_output=multi_output, p=p) np.testing.assert_allclose(expected_value, result, rtol=1e-5) @@ -149,16 +135,7 @@ def test_error(self, y_pred, y, multi_output, p): with self.assertRaises(ValueError): _ = compute_r2_score(y_pred=y_pred, y=y, multi_output=multi_output, p=p) - @parameterized.expand( - [ - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - ] - ) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) def test_class_value(self, y_pred, y, multi_output, p, expected_value): metric = R2Metric(multi_output=multi_output, p=p) metric(y_pred=y_pred, y=y) diff --git a/tests/test_handler_r2_score_dist.py b/tests/test_handler_r2_score_dist.py index a8eb4bf455..378989d555 100644 --- a/tests/test_handler_r2_score_dist.py +++ b/tests/test_handler_r2_score_dist.py @@ -38,7 +38,11 @@ def test_compute(self): torch.tensor([0.99, 2.1], device=device), torch.tensor([-0.1, 0.0], device=device), ] - y = [torch.tensor([2.7, -0.1], device=device), torch.tensor([1.58, 2.0], device=device), torch.tensor([-1.0, -0.1], device=device)] + y = [ + torch.tensor([2.7, -0.1], device=device), + torch.tensor([1.58, 2.0], device=device), + torch.tensor([-1.0, -0.1], device=device), + ] r2_score.update([y_pred, y]) From c9ee60c0624d1f37007ca4b8d15f31e007f86458 Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Wed, 18 Sep 2024 18:22:20 +0200 Subject: [PATCH 07/16] mypy issues Signed-off-by: thibaultdvx --- monai/metrics/r2_score.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/monai/metrics/r2_score.py b/monai/metrics/r2_score.py index e2b323f175..848786cc3a 100644 --- a/monai/metrics/r2_score.py +++ b/monai/metrics/r2_score.py @@ -68,7 +68,7 @@ def __init__(self, multi_output: MultiOutput | str = MultiOutput.UNIFORM, p: int self.multi_output = multi_output self.p = p - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # type: ignore[override] _check_dim(y_pred, y) return y_pred, y @@ -100,7 +100,7 @@ def _check_dim(y_pred: torch.Tensor, y: torch.Tensor) -> None: ) -def _check_r2_params(multi_output, p) -> tuple[MultiOutput, int]: +def _check_r2_params(multi_output: MultiOutput | str, p: int) -> tuple[MultiOutput | str, int]: multi_output = look_up_option(multi_output, MultiOutput) if not isinstance(p, int) or p < 0: raise ValueError(f"`p` must be an integer larger or equal to 0, got {p}.") @@ -115,7 +115,7 @@ def _calculate(y_pred: np.ndarray, y: np.ndarray, p: int) -> float: r2 = 1 - (rss / tss) r2_adjusted = 1 - (1 - r2) * (num_obs - 1) / (num_obs - p - 1) - return r2_adjusted + return r2_adjusted # type: ignore[no-any-return] def compute_r2_score( @@ -154,8 +154,8 @@ def compute_r2_score( _check_dim(y_pred, y) dim = y.ndimension() n = y.shape[0] - y = y.cpu().numpy() - y_pred = y_pred.cpu().numpy() + y = y.cpu().numpy() # type: ignore[assignment] + y_pred = y_pred.cpu().numpy() # type: ignore[assignment] if n < 2: raise ValueError("There is no enough data for computing. Needs at least two samples to calculate r2 score.") @@ -163,19 +163,20 @@ def compute_r2_score( raise ValueError("`p` must be smaller than n_samples - 1, " f"got p={p}, n_samples={n}.") if dim == 2 and y_pred.shape[1] == 1: - y_pred = np.squeeze(y_pred, axis=-1) - y = np.squeeze(y, axis=-1) + y_pred = np.squeeze(y_pred, axis=-1) # type: ignore[assignment] + y = np.squeeze(y, axis=-1) # type: ignore[assignment] dim = 1 if dim == 1: - return _calculate(y_pred, y, p) + return _calculate(y_pred, y, p) # type: ignore[arg-type] - y, y_pred = np.transpose(y, axes=(1, 0)), np.transpose(y_pred, axes=(1, 0)) + y, y_pred = np.transpose(y, axes=(1, 0)), np.transpose(y_pred, axes=(1, 0)) # type: ignore[assignment] r2_values = [_calculate(y_pred_, y_, p) for y_pred_, y_ in zip(y_pred, y)] if multi_output == MultiOutput.RAW: return r2_values if multi_output == MultiOutput.UNIFORM: return np.mean(r2_values) - if multi_output == multi_output.VARIANCE: + if multi_output == MultiOutput.VARIANCE: weights = np.var(y, axis=1) - return np.average(r2_values, weights=weights) + return np.average(r2_values, weights=weights) # type: ignore[no-any-return] + raise ValueError(f'Unsupported multi_output: {multi_output}, available options are ["raw_values", "uniform_average", "variance_weighted"].') From 838d0a1b091df1fdc7d683d56f324bfc234512f6 Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Wed, 18 Sep 2024 18:35:41 +0200 Subject: [PATCH 08/16] code format Signed-off-by: thibaultdvx --- monai/metrics/r2_score.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/metrics/r2_score.py b/monai/metrics/r2_score.py index 848786cc3a..0ad2e133a5 100644 --- a/monai/metrics/r2_score.py +++ b/monai/metrics/r2_score.py @@ -179,4 +179,6 @@ def compute_r2_score( if multi_output == MultiOutput.VARIANCE: weights = np.var(y, axis=1) return np.average(r2_values, weights=weights) # type: ignore[no-any-return] - raise ValueError(f'Unsupported multi_output: {multi_output}, available options are ["raw_values", "uniform_average", "variance_weighted"].') + raise ValueError( + f'Unsupported multi_output: {multi_output}, available options are ["raw_values", "uniform_average", "variance_weighted"].' + ) From a33a2121fee12403da3cc357039987bb328469ab Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 13 Feb 2025 03:10:30 +0800 Subject: [PATCH 09/16] Recursive Item Mapping for Nested Lists in Compose (#8187) Fixes #8186. ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Ben Murray Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- monai/transforms/compose.py | 39 +++++++++++++++++------- monai/transforms/transform.py | 21 +++++++++---- tests/transforms/compose/test_compose.py | 14 +++++++++ 3 files changed, 57 insertions(+), 17 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 236d3cc4c5..4513e26678 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -47,7 +47,7 @@ def execute_compose( data: NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor], transforms: Sequence[Any], - map_items: bool = True, + map_items: bool | int = True, unpack_items: bool = False, start: int = 0, end: int | None = None, @@ -65,8 +65,13 @@ def execute_compose( Args: data: a tensor-like object to be transformed transforms: a sequence of transforms to be carried out - map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple. - defaults to `True`. + map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple, + it can behave as follows: + - Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied + to the first level of items in `data`. + - If an integer is provided, it specifies the maximum level of nesting to which the transformation + should be recursively applied. This allows treating multi-sample transforms applied after another + multi-sample transform while controlling how deep the mapping goes. unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform. defaults to `False`. start: the index of the first transform to be executed. If not set, this defaults to 0 @@ -205,8 +210,14 @@ class Compose(Randomizable, InvertibleTransform, LazyTransform): Args: transforms: sequence of callables. - map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple. - defaults to `True`. + map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple, + it can behave as follows: + + - Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied + to the first level of items in `data`. + - If an integer is provided, it specifies the maximum level of nesting to which the transformation + should be recursively applied. This allows treating multi-sample transforms applied after another + multi-sample transform while controlling how deep the mapping goes. unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform. defaults to `False`. log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution. @@ -227,7 +238,7 @@ class Compose(Randomizable, InvertibleTransform, LazyTransform): def __init__( self, transforms: Sequence[Callable] | Callable | None = None, - map_items: bool = True, + map_items: bool | int = True, unpack_items: bool = False, log_stats: bool | str = False, lazy: bool | None = False, @@ -238,9 +249,9 @@ def __init__( if transforms is None: transforms = [] - if not isinstance(map_items, bool): + if not isinstance(map_items, (bool, int)): raise ValueError( - f"Argument 'map_items' should be boolean. Got {type(map_items)}." + f"Argument 'map_items' should be boolean or int. Got {type(map_items)}." "Check brackets when passing a sequence of callables." ) @@ -391,8 +402,14 @@ class OneOf(Compose): transforms: sequence of callables. weights: probabilities corresponding to each callable in transforms. Probabilities are normalized to sum to one. - map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple. - defaults to `True`. + map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple, + it can behave as follows: + + - Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied + to the first level of items in `data`. + - If an integer is provided, it specifies the maximum level of nesting to which the transformation + should be recursively applied. This allows treating multi-sample transforms applied after another + multi-sample transform while controlling how deep the mapping goes. unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform. defaults to `False`. log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution. @@ -414,7 +431,7 @@ def __init__( self, transforms: Sequence[Callable] | Callable | None = None, weights: Sequence[float] | float | None = None, - map_items: bool = True, + map_items: bool | int = True, unpack_items: bool = False, log_stats: bool | str = False, lazy: bool | None = False, diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 15c2499a73..1a365b8d8e 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -101,12 +101,12 @@ def _apply_transform( def apply_transform( transform: Callable[..., ReturnType], data: Any, - map_items: bool = True, + map_items: bool | int = True, unpack_items: bool = False, log_stats: bool | str = False, lazy: bool | None = None, overrides: dict | None = None, -) -> list[ReturnType] | ReturnType: +) -> list[Any] | ReturnType: """ Transform `data` with `transform`. @@ -117,8 +117,13 @@ def apply_transform( Args: transform: a callable to be used to transform `data`. data: an object to be transformed. - map_items: whether to apply transform to each item in `data`, - if `data` is a list or tuple. Defaults to True. + map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple, + it can behave as follows: + - Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied + to the first level of items in `data`. + - If an integer is provided, it specifies the maximum level of nesting to which the transformation + should be recursively applied. This allows treating multi-sample transforms applied after another + multi-sample transform while controlling how deep the mapping goes. unpack_items: whether to unpack parameters using `*`. Defaults to False. log_stats: log errors when they occur in the processing pipeline. By default, this is set to False, which disables the logger for processing pipeline errors. Setting it to None or True will enable logging to the @@ -136,8 +141,12 @@ def apply_transform( Union[List[ReturnType], ReturnType]: The return type of `transform` or a list thereof. """ try: - if isinstance(data, (list, tuple)) and map_items: - return [_apply_transform(transform, item, unpack_items, lazy, overrides, log_stats) for item in data] + map_items_ = int(map_items) if isinstance(map_items, bool) else map_items + if isinstance(data, (list, tuple)) and map_items_ > 0: + return [ + apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides) + for item in data + ] return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats) except Exception as e: # if in debug mode, don't swallow exception so that the breakpoint diff --git a/tests/transforms/compose/test_compose.py b/tests/transforms/compose/test_compose.py index 3c53ac4a22..e6727c976f 100644 --- a/tests/transforms/compose/test_compose.py +++ b/tests/transforms/compose/test_compose.py @@ -141,6 +141,20 @@ def b(i, i2): self.assertEqual(mt.Compose(transforms, unpack_items=True)(data), expected) self.assertEqual(execute_compose(data, transforms, unpack_items=True), expected) + def test_list_non_dict_compose_with_unpack_map_2(self): + + def a(i, i2): + return i + "a", i2 + "a2" + + def b(i, i2): + return i + "b", i2 + "b2" + + transforms = [a, b, a, b] + data = [[("", ""), ("", "")], [("t", "t"), ("t", "t")]] + expected = [[("abab", "a2b2a2b2"), ("abab", "a2b2a2b2")], [("tabab", "ta2b2a2b2"), ("tabab", "ta2b2a2b2")]] + self.assertEqual(mt.Compose(transforms, map_items=2, unpack_items=True)(data), expected) + self.assertEqual(execute_compose(data, transforms, map_items=2, unpack_items=True), expected) + def test_list_dict_compose_no_map(self): def a(d): # transform to handle dict data From b19f288443eb49f2efcb993b815a3a27024aa206 Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Thu, 13 Feb 2025 11:34:16 +0100 Subject: [PATCH 10/16] update test structure Signed-off-by: thibaultdvx --- tests/{ => handlers}/test_handler_r2_score.py | 0 tests/{ => handlers}/test_handler_r2_score_dist.py | 0 tests/{ => metrics}/test_compute_r2_score.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename tests/{ => handlers}/test_handler_r2_score.py (100%) rename tests/{ => handlers}/test_handler_r2_score_dist.py (100%) rename tests/{ => metrics}/test_compute_r2_score.py (100%) diff --git a/tests/test_handler_r2_score.py b/tests/handlers/test_handler_r2_score.py similarity index 100% rename from tests/test_handler_r2_score.py rename to tests/handlers/test_handler_r2_score.py diff --git a/tests/test_handler_r2_score_dist.py b/tests/handlers/test_handler_r2_score_dist.py similarity index 100% rename from tests/test_handler_r2_score_dist.py rename to tests/handlers/test_handler_r2_score_dist.py diff --git a/tests/test_compute_r2_score.py b/tests/metrics/test_compute_r2_score.py similarity index 100% rename from tests/test_compute_r2_score.py rename to tests/metrics/test_compute_r2_score.py From 5c260aa54fc710f1bc3c32769ac4827e1130413e Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Thu, 13 Feb 2025 11:36:48 +0100 Subject: [PATCH 11/16] merge handler test files Signed-off-by: thibaultdvx --- tests/handlers/test_handler_r2_score.py | 30 +++++++++++ tests/handlers/test_handler_r2_score_dist.py | 54 -------------------- 2 files changed, 30 insertions(+), 54 deletions(-) delete mode 100644 tests/handlers/test_handler_r2_score_dist.py diff --git a/tests/handlers/test_handler_r2_score.py b/tests/handlers/test_handler_r2_score.py index f2fa243719..1700540c2a 100644 --- a/tests/handlers/test_handler_r2_score.py +++ b/tests/handlers/test_handler_r2_score.py @@ -15,8 +15,10 @@ import numpy as np import torch +import torch.distributed as dist from monai.handlers import R2Score +from tests.test_utils import DistCall, DistTestCase class TestHandlerR2Score(unittest.TestCase): @@ -37,5 +39,33 @@ def test_compute(self): np.testing.assert_allclose(0.867314, r2, rtol=1e-5) +class DistributedR2Score(DistTestCase): + + @DistCall(nnodes=1, nproc_per_node=2, node_rank=0) + def test_compute(self): + r2_score = R2Score(multi_output="variance_weighted", p=1) + + device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu" + if dist.get_rank() == 0: + y_pred = [torch.tensor([0.1, 1.0], device=device), torch.tensor([-0.25, 0.5], device=device)] + y = [torch.tensor([0.1, 0.82], device=device), torch.tensor([-0.2, 0.01], device=device)] + + if dist.get_rank() == 1: + y_pred = [ + torch.tensor([3.0, -0.2], device=device), + torch.tensor([0.99, 2.1], device=device), + torch.tensor([-0.1, 0.0], device=device), + ] + y = [ + torch.tensor([2.7, -0.1], device=device), + torch.tensor([1.58, 2.0], device=device), + torch.tensor([-1.0, -0.1], device=device), + ] + + r2_score.update([y_pred, y]) + + result = r2_score.compute() + np.testing.assert_allclose(0.829185, result, rtol=1e-5) + if __name__ == "__main__": unittest.main() diff --git a/tests/handlers/test_handler_r2_score_dist.py b/tests/handlers/test_handler_r2_score_dist.py deleted file mode 100644 index 378989d555..0000000000 --- a/tests/handlers/test_handler_r2_score_dist.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import unittest - -import numpy as np -import torch -import torch.distributed as dist - -from monai.handlers import R2Score -from tests.utils import DistCall, DistTestCase - - -class DistributedR2Score(DistTestCase): - - @DistCall(nnodes=1, nproc_per_node=2, node_rank=0) - def test_compute(self): - r2_score = R2Score(multi_output="variance_weighted", p=1) - - device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu" - if dist.get_rank() == 0: - y_pred = [torch.tensor([0.1, 1.0], device=device), torch.tensor([-0.25, 0.5], device=device)] - y = [torch.tensor([0.1, 0.82], device=device), torch.tensor([-0.2, 0.01], device=device)] - - if dist.get_rank() == 1: - y_pred = [ - torch.tensor([3.0, -0.2], device=device), - torch.tensor([0.99, 2.1], device=device), - torch.tensor([-0.1, 0.0], device=device), - ] - y = [ - torch.tensor([2.7, -0.1], device=device), - torch.tensor([1.58, 2.0], device=device), - torch.tensor([-1.0, -0.1], device=device), - ] - - r2_score.update([y_pred, y]) - - result = r2_score.compute() - np.testing.assert_allclose(0.829185, result, rtol=1e-5) - - -if __name__ == "__main__": - unittest.main() From 2a7fdb506ae6b852d702ba29240fd407007c2f55 Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Thu, 13 Feb 2025 11:41:00 +0100 Subject: [PATCH 12/16] fix issue in distributed test Signed-off-by: thibaultdvx --- tests/handlers/test_handler_r2_score.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/handlers/test_handler_r2_score.py b/tests/handlers/test_handler_r2_score.py index 1700540c2a..c5ad270da9 100644 --- a/tests/handlers/test_handler_r2_score.py +++ b/tests/handlers/test_handler_r2_score.py @@ -49,6 +49,7 @@ def test_compute(self): if dist.get_rank() == 0: y_pred = [torch.tensor([0.1, 1.0], device=device), torch.tensor([-0.25, 0.5], device=device)] y = [torch.tensor([0.1, 0.82], device=device), torch.tensor([-0.2, 0.01], device=device)] + r2_score.update([y_pred, y]) if dist.get_rank() == 1: y_pred = [ @@ -61,8 +62,7 @@ def test_compute(self): torch.tensor([1.58, 2.0], device=device), torch.tensor([-1.0, -0.1], device=device), ] - - r2_score.update([y_pred, y]) + r2_score.update([y_pred, y]) result = r2_score.compute() np.testing.assert_allclose(0.829185, result, rtol=1e-5) From 14dee3e604ef5b748d2b20d83deb8517b4be9323 Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Thu, 13 Feb 2025 13:57:44 +0100 Subject: [PATCH 13/16] r2 docstring Signed-off-by: thibaultdvx --- monai/metrics/r2_score.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/monai/metrics/r2_score.py b/monai/metrics/r2_score.py index 0ad2e133a5..87130d7f1f 100644 --- a/monai/metrics/r2_score.py +++ b/monai/metrics/r2_score.py @@ -26,16 +26,25 @@ class R2Metric(CumulativeIterationMetric): - r"""Computes :math:`R^{2}` score (coefficient of determination): + r"""Computes :math:`R^{2}` score (coefficient of determination). :math:`R^{2}` is used to evaluate + a regression model. In the best case, when the predictions match exactly the observed values, :math:`R^{2} = 1`. + It has no lower bound, and the more negative it is, the worse the model is. Finally, a baseline model, which always + predicts the mean of observed values, will get :math:`R^{2} = 0`. .. math:: \operatorname {R^{2}}\left(Y, \hat{Y}\right) = 1 - \frac {\sum _{i=1}^{n}\left(y_i-\hat{y_i} \right)^{2}} {\sum _{i=1}^{n}\left(y_i-\bar{y} \right)^{2}}, + :label: r2 - where :math:`\bar{y}` is the mean of observed :math:`y` ; or adjusted :math:`R^{2}` score: + where :math:`\bar{y}` is the mean of observed :math:`y`. + + However, :math:`R^{2}` automatically increases when extra + variables are added to the model. To account for this phenomenon and penalize the addition of unnecessary variables, + :math:`adjusted \ R^{2}` (:math:`\bar{R}^{2}`) is defined: .. math:: \operatorname {\bar{R}^{2}} = 1 - (1-R^{2}) \frac {n-1}{n-p-1}, + :label: r2_adjusted where :math:`p` is the number of independant variables used for the regression. @@ -57,7 +66,7 @@ class R2Metric(CumulativeIterationMetric): - ``"variance_weighted"``: the scores of all outputs are averaged, weighted by the variances of each individual output. p: non-negative integer. - Number of independent variables used for regression. ``p`` is used to compute adjusted :math:`R^{2}` score. + Number of independent variables used for regression. ``p`` is used to compute :math:`\bar{R}^{2}` score. Defaults to 0 (standard :math:`R^{2}` score). """ @@ -121,7 +130,8 @@ def _calculate(y_pred: np.ndarray, y: np.ndarray, p: int) -> float: def compute_r2_score( y_pred: torch.Tensor, y: torch.Tensor, multi_output: MultiOutput | str = MultiOutput.UNIFORM, p: int = 0 ) -> np.ndarray | float | npt.ArrayLike: - """Computes :math:`R^{2}` score (coefficient of determination). + """Computes :math:`R^{2}` score (coefficient of determination). :math:`R^{2}` is used to evaluate + a regression model according to equations :eq:`r2` and :eq:`r2_adjusted`. Args: y_pred: input data to compute :math:`R^{2}` score, the first dim must be batch. @@ -137,7 +147,7 @@ def compute_r2_score( - ``"variance_weighted"``: the scores of all outputs are averaged, weighted by the variances each individual output. p: non-negative integer. - Number of independent variables used for regression. ``p`` is used to compute adjusted :math:`R^{2}` score. + Number of independent variables used for regression. ``p`` is used to compute :math:`\bar{R}^{2}` score. Defaults to 0 (standard :math:`R^{2}` score). Raises: From 3a5b44da6902ecc16d276477d060b11abc106078 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 13 Feb 2025 13:00:34 +0000 Subject: [PATCH 14/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/metrics/r2_score.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/metrics/r2_score.py b/monai/metrics/r2_score.py index 87130d7f1f..423b5e951a 100644 --- a/monai/metrics/r2_score.py +++ b/monai/metrics/r2_score.py @@ -36,7 +36,7 @@ class R2Metric(CumulativeIterationMetric): {\sum _{i=1}^{n}\left(y_i-\bar{y} \right)^{2}}, :label: r2 - where :math:`\bar{y}` is the mean of observed :math:`y`. + where :math:`\bar{y}` is the mean of observed :math:`y`. However, :math:`R^{2}` automatically increases when extra variables are added to the model. To account for this phenomenon and penalize the addition of unnecessary variables, From 38d4c13d73d9bb86afda7f4aed1d57841ce759c5 Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Thu, 13 Feb 2025 14:23:55 +0100 Subject: [PATCH 15/16] change tss formula Signed-off-by: thibaultdvx --- monai/metrics/r2_score.py | 2 +- tests/handlers/test_handler_r2_score.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/metrics/r2_score.py b/monai/metrics/r2_score.py index 423b5e951a..22b18ab87d 100644 --- a/monai/metrics/r2_score.py +++ b/monai/metrics/r2_score.py @@ -120,7 +120,7 @@ def _check_r2_params(multi_output: MultiOutput | str, p: int) -> tuple[MultiOutp def _calculate(y_pred: np.ndarray, y: np.ndarray, p: int) -> float: num_obs = len(y) rss = np.sum((y_pred - y) ** 2) - tss = np.sum(y**2) - np.sum(y) ** 2 / num_obs + tss = np.sum((y - np.mean(y)) ** 2) r2 = 1 - (rss / tss) r2_adjusted = 1 - (1 - r2) * (num_obs - 1) / (num_obs - p - 1) diff --git a/tests/handlers/test_handler_r2_score.py b/tests/handlers/test_handler_r2_score.py index c5ad270da9..b4d4c1613e 100644 --- a/tests/handlers/test_handler_r2_score.py +++ b/tests/handlers/test_handler_r2_score.py @@ -67,5 +67,6 @@ def test_compute(self): result = r2_score.compute() np.testing.assert_allclose(0.829185, result, rtol=1e-5) + if __name__ == "__main__": unittest.main() From 809d791d7c08b1c21abe274ff103401b1708d532 Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Thu, 13 Feb 2025 16:07:59 +0100 Subject: [PATCH 16/16] fix test issue with ignite dependency Signed-off-by: thibaultdvx --- tests/min_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/min_tests.py b/tests/min_tests.py index 1fc3da4a19..049c82d4c2 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -98,6 +98,7 @@ def run_testsuit(): "test_handler_parameter_scheduler", "test_handler_post_processing", "test_handler_prob_map_producer", + "test_handler_r2_score", "test_handler_regression_metrics", "test_handler_regression_metrics_dist", "test_handler_rocauc",