Skip to content

Commit

Permalink
Add detection_filter to predict() method to allow for user-define…
Browse files Browse the repository at this point in the history
…d logic (#307)

* Added a `detection_filter` parameter to `predict` method to allow for
filtering of detections based on user-defined logic;
appears to seamlessly work with `batch_predict` also;

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
brendancol and pre-commit-ci[bot] authored Aug 19, 2024
1 parent 9894762 commit 409953a
Showing 1 changed file with 31 additions and 2 deletions.
33 changes: 31 additions & 2 deletions samgeo/text_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
Credits to Luca Medeiros for the original implementation.
"""

import argparse
import inspect
import os
import warnings
import argparse

import numpy as np
import torch
from PIL import Image
Expand Down Expand Up @@ -238,6 +240,7 @@ def predict(
save_args={},
return_results=False,
return_coords=False,
detection_filter=None,
**kwargs,
):
"""
Expand All @@ -253,6 +256,10 @@ def predict(
dtype (np.dtype, optional): Data type for the prediction. Defaults to np.uint8.
save_args (dict, optional): Save arguments for the prediction. Defaults to {}.
return_results (bool, optional): Whether to return the results. Defaults to False.
detection_filter (callable, optional):
Callable which with box, mask, logit, phrase, and index args returns a boolean.
If provided, the function will be called for each detected object.
Defaults to None.
Returns:
tuple: Tuple containing masks, boxes, phrases, and logits.
Expand Down Expand Up @@ -312,12 +319,34 @@ def predict(
image_np[..., 0], dtype=dtype
) # Adjusted for single channel

for i, (box, mask) in enumerate(zip(boxes, masks)):
# Validate the detection_filter argument
if detection_filter is not None:

if not callable(detection_filter):
raise ValueError("detection_filter must be callable.")

req_nargs = 6 if inspect.ismethod(detection_filter) else 5
if not len(inspect.signature(detection_filter).parameters) == req_nargs:
raise ValueError(
"detection_filter required args: "
"box, mask, logit, phrase, and index."
)

for i, (box, mask, logit, phrase) in enumerate(
zip(boxes, masks, logits, phrases)
):

# Convert tensor to numpy array if necessary and ensure it contains integers
if isinstance(mask, torch.Tensor):
mask = (
mask.cpu().numpy().astype(dtype)
) # If mask is on GPU, use .cpu() before .numpy()

# Apply the user-supplied filtering logic if provided
if detection_filter is not None:
if not detection_filter(box, mask, logit, phrase, i):
continue

mask_overlay += ((mask > 0) * (i + 1)).astype(
dtype
) # Assign a unique value for each mask
Expand Down

0 comments on commit 409953a

Please sign in to comment.