diff --git a/supervision/detection/utils.py b/supervision/detection/utils.py index a2cbd87bd..0d5ec475e 100644 --- a/supervision/detection/utils.py +++ b/supervision/detection/utils.py @@ -720,25 +720,71 @@ def move_masks( masks (npt.NDArray[np.bool_]): A 3D array of binary masks corresponding to the predictions. Shape: `(N, H, W)`, where N is the number of predictions, and H, W are the dimensions of each mask. - offset (npt.NDArray[np.int32]): An array of shape `(2,)` containing non-negative - int values `[dx, dy]`. + offset (npt.NDArray[np.int32]): An array of shape `(2,)` containing int values + `[dx, dy]`. Supports both positive and negative values for bidirectional + movement; mask pixels shifted outside `resolution_wh` are discarded. resolution_wh (Tuple[int, int]): The width and height of the desired mask resolution. Returns: (npt.NDArray[np.bool_]) repositioned masks, optionally padded to the specified shape. - """ - if offset[0] < 0 or offset[1] < 0: - raise ValueError(f"Offset values must be non-negative integers. Got: {offset}") + Examples: + ```python + import numpy as np + import supervision as sv + mask = np.array([[[False, False, False, False], + [False, True, True, False], + [False, True, True, False], + [False, False, False, False]]], dtype=bool) + + offset = np.array([1, 1]) + sv.move_masks(mask, offset, resolution_wh=(4, 4)) + # array([[[False, False, False, False], + # [False, False, False, False], + # [False, False, True, True], + # [False, False, True, True]]], dtype=bool) + + offset = np.array([-2, 2]) + sv.move_masks(mask, offset, resolution_wh=(4, 4)) + # array([[[False, False, False, False], + # [False, False, False, False], + # [False, False, False, False], + # [True, False, False, False]]], dtype=bool) + ``` + """ mask_array = np.full((masks.shape[0], resolution_wh[1], resolution_wh[0]), False) - mask_array[ - :, - offset[1] : masks.shape[1] + offset[1], - offset[0] : masks.shape[2] + offset[0], - ] = masks + + if offset[0] < 0: + source_x_start = -offset[0] + source_x_end = min(masks.shape[2], 
resolution_wh[0] - offset[0]) + destination_x_start = 0 + destination_x_end = min(resolution_wh[0], masks.shape[2] + offset[0]) + else: + source_x_start = 0 + source_x_end = min(masks.shape[2], resolution_wh[0] - offset[0]) + destination_x_start = offset[0] + destination_x_end = offset[0] + source_x_end - source_x_start + + if offset[1] < 0: + source_y_start = -offset[1] + source_y_end = min(masks.shape[1], resolution_wh[1] - offset[1]) + destination_y_start = 0 + destination_y_end = min(resolution_wh[1], masks.shape[1] + offset[1]) + else: + source_y_start = 0 + source_y_end = min(masks.shape[1], resolution_wh[1] - offset[1]) + destination_y_start = offset[1] + destination_y_end = offset[1] + source_y_end - source_y_start + + if source_x_end > source_x_start and source_y_end > source_y_start: + mask_array[ + :, + destination_y_start:destination_y_end, + destination_x_start:destination_x_end, + ] = masks[:, source_y_start:source_y_end, source_x_start:source_x_end] return mask_array diff --git a/test/detection/test_utils.py b/test/detection/test_utils.py index 87e50f6a4..d93c72c83 100644 --- a/test/detection/test_utils.py +++ b/test/detection/test_utils.py @@ -16,6 +16,7 @@ merge_data, merge_metadata, move_boxes, + move_masks, process_roboflow_result, scale_boxes, xcycwh_to_xyxy, @@ -442,6 +443,268 @@ def test_move_boxes( assert np.array_equal(result, expected_result) +@pytest.mark.parametrize( + "masks, offset, resolution_wh, expected_result, exception", + [ + ( + np.array( + [ + [ + [False, False, False, False], + [False, True, True, False], + [False, True, True, False], + [False, False, False, False], + ] + ], + dtype=bool, + ), + np.array([0, 0]), + (4, 4), + np.array( + [ + [ + [False, False, False, False], + [False, True, True, False], + [False, True, True, False], + [False, False, False, False], + ] + ], + dtype=bool, + ), + DoesNotRaise(), + ), + ( + np.array( + [ + [ + [False, False, False, False], + [False, True, True, False], + [False, True, True, False], 
+ [False, False, False, False], + ] + ], + dtype=bool, + ), + np.array([-1, -1]), + (4, 4), + np.array( + [ + [ + [True, True, False, False], + [True, True, False, False], + [False, False, False, False], + [False, False, False, False], + ] + ], + dtype=bool, + ), + DoesNotRaise(), + ), + ( + np.array( + [ + [ + [False, False, False, False], + [False, True, True, False], + [False, True, True, False], + [False, False, False, False], + ] + ], + dtype=bool, + ), + np.array([-2, -2]), + (4, 4), + np.array( + [ + [ + [True, False, False, False], + [False, False, False, False], + [False, False, False, False], + [False, False, False, False], + ] + ], + dtype=bool, + ), + DoesNotRaise(), + ), + ( + np.array( + [ + [ + [False, False, False, False], + [False, True, True, False], + [False, True, True, False], + [False, False, False, False], + ] + ], + dtype=bool, + ), + np.array([-3, -3]), + (4, 4), + np.array( + [ + [ + [False, False, False, False], + [False, False, False, False], + [False, False, False, False], + [False, False, False, False], + ] + ], + dtype=bool, + ), + DoesNotRaise(), + ), + ( + np.array( + [ + [ + [False, False, False, False], + [False, True, True, False], + [False, True, True, False], + [False, False, False, False], + ] + ], + dtype=bool, + ), + np.array([-2, -1]), + (4, 4), + np.array( + [ + [ + [True, False, False, False], + [True, False, False, False], + [False, False, False, False], + [False, False, False, False], + ] + ], + dtype=bool, + ), + DoesNotRaise(), + ), + ( + np.array( + [ + [ + [False, False, False, False], + [False, True, True, False], + [False, True, True, False], + [False, False, False, False], + ] + ], + dtype=bool, + ), + np.array([-1, -2]), + (4, 4), + np.array( + [ + [ + [True, True, False, False], + [False, False, False, False], + [False, False, False, False], + [False, False, False, False], + ] + ], + dtype=bool, + ), + DoesNotRaise(), + ), + ( + np.array( + [ + [ + [False, False, False, False], + [False, True, True, False], + 
[False, True, True, False], + [False, False, False, False], + ] + ], + dtype=bool, + ), + np.array([-2, 2]), + (4, 4), + np.array( + [ + [ + [False, False, False, False], + [False, False, False, False], + [False, False, False, False], + [True, False, False, False], + ] + ], + dtype=bool, + ), + DoesNotRaise(), + ), + ( + np.array( + [ + [ + [False, False, False, False], + [False, True, True, False], + [False, True, True, False], + [False, False, False, False], + ] + ], + dtype=bool, + ), + np.array([3, 3]), + (4, 4), + np.array( + [ + [ + [False, False, False, False], + [False, False, False, False], + [False, False, False, False], + [False, False, False, False], + ] + ], + dtype=bool, + ), + DoesNotRaise(), + ), + ( + np.array( + [ + [ + [False, False, False, False], + [False, True, True, False], + [False, True, True, False], + [False, False, False, False], + ] + ], + dtype=bool, + ), + np.array([3, 3]), + (6, 6), + np.array( + [ + [ + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, True, True], + [False, False, False, False, True, True], + ] + ], + dtype=bool, + ), + DoesNotRaise(), + ), + ], +) +def test_move_masks( + masks: np.ndarray, + offset: np.ndarray, + resolution_wh: Tuple[int, int], + expected_result: np.ndarray, + exception: Exception, +) -> None: + with exception: + result = move_masks(masks=masks, offset=offset, resolution_wh=resolution_wh) + np.testing.assert_array_equal(result, expected_result) + + @pytest.mark.parametrize( "xyxy, factor, expected_result, exception", [