Skip to content

Commit

Permalink
Update module
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs committed Dec 12, 2023
1 parent 0171979 commit 8a1a188
Showing 1 changed file with 13 additions and 56 deletions.
69 changes: 13 additions & 56 deletions samgeo/efficient_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,18 +80,19 @@ def __init__(
if device == "cuda":
torch.cuda.empty_cache()

# self.checkpoint = checkpoint
# self.model_type = model_type
# self.device = device
# self.sam_kwargs = sam_kwargs # Optional arguments for fine-tuning the SAM model
# self.source = None # Store the input image path
# self.image = None # Store the input image as a numpy array
# # Store the masks as a list of dictionaries. Each mask is a dictionary
# # containing segmentation, area, bbox, predicted_iou, point_coords, stability_score, and crop_box
# self.masks = None
# self.objects = None # Store the mask objects as a numpy array
# # Store the annotations (objects with random color) as a numpy array.
# self.annotations = None
self.checkpoint = checkpoint
self.model_type = model_type
self.model = model
self.device = device
self.sam_kwargs = sam_kwargs # Optional arguments for fine-tuning the SAM model
self.source = None # Store the input image path
self.image = None # Store the input image as a numpy array
# Store the masks as a list of dictionaries. Each mask is a dictionary
# containing segmentation, area, bbox, predicted_iou, point_coords, stability_score, and crop_box
self.masks = None
self.objects = None # Store the mask objects as a numpy array
# Store the annotations (objects with random color) as a numpy array.
self.annotations = None

# # Store the predicted masks, iou_predictions, and low_res_masks
# self.prediction = None
Expand All @@ -111,50 +112,6 @@ def __init__(
# # Segment selected objects using input prompts
# self.predictor = SamPredictor(self.sam, **sam_kwargs)

def __call__(
self,
image,
foreground=True,
erosion_kernel=(3, 3),
mask_multiplier=255,
**kwargs,
):
"""Generate masks for the input tile. This function originates from the segment-anything-eo repository.
See https://bit.ly/41pwiHw
Args:
image (np.ndarray): The input image as a numpy array.
foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders. Defaults to (3, 3).
mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
"""
h, w, _ = image.shape

masks = self.mask_generator.generate(image)

if foreground: # Extract foreground objects only
resulting_mask = np.zeros((h, w), dtype=np.uint8)
else:
resulting_mask = np.ones((h, w), dtype=np.uint8)
resulting_borders = np.zeros((h, w), dtype=np.uint8)

for m in masks:
mask = (m["segmentation"] > 0).astype(np.uint8)
resulting_mask += mask

# Apply erosion to the mask
if erosion_kernel is not None:
mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)
mask_erode = (mask_erode > 0).astype(np.uint8)
edge_mask = mask - mask_erode
resulting_borders += edge_mask

resulting_mask = (resulting_mask > 0).astype(np.uint8)
resulting_borders = (resulting_borders > 0).astype(np.uint8)
resulting_mask_with_borders = resulting_mask - resulting_borders
return resulting_mask_with_borders * mask_multiplier

def generate(
self,
source,
Expand Down

0 comments on commit 8a1a188

Please sign in to comment.