From 8a1a188eb74651abd2cffa0ebc2bc1fc2c8517ce Mon Sep 17 00:00:00 2001 From: Qiusheng Wu Date: Tue, 12 Dec 2023 13:25:04 -0800 Subject: [PATCH] Update module --- samgeo/efficient_sam.py | 69 ++++++++--------------------------------- 1 file changed, 13 insertions(+), 56 deletions(-) diff --git a/samgeo/efficient_sam.py b/samgeo/efficient_sam.py index 377a7c7d..866d9791 100644 --- a/samgeo/efficient_sam.py +++ b/samgeo/efficient_sam.py @@ -80,18 +80,19 @@ def __init__( if device == "cuda": torch.cuda.empty_cache() - # self.checkpoint = checkpoint - # self.model_type = model_type - # self.device = device - # self.sam_kwargs = sam_kwargs # Optional arguments for fine-tuning the SAM model - # self.source = None # Store the input image path - # self.image = None # Store the input image as a numpy array - # # Store the masks as a list of dictionaries. Each mask is a dictionary - # # containing segmentation, area, bbox, predicted_iou, point_coords, stability_score, and crop_box - # self.masks = None - # self.objects = None # Store the mask objects as a numpy array - # # Store the annotations (objects with random color) as a numpy array. - # self.annotations = None + self.checkpoint = checkpoint + self.model_type = model_type + self.model = model + self.device = device + self.sam_kwargs = sam_kwargs # Optional arguments for fine-tuning the SAM model + self.source = None # Store the input image path + self.image = None # Store the input image as a numpy array + # Store the masks as a list of dictionaries. Each mask is a dictionary + # containing segmentation, area, bbox, predicted_iou, point_coords, stability_score, and crop_box + self.masks = None + self.objects = None # Store the mask objects as a numpy array + # Store the annotations (objects with random color) as a numpy array. + self.annotations = None # # Store the predicted masks, iou_predictions, and low_res_masks # self.prediction = None @@ -111,50 +112,6 @@ def __init__( # # Segment selected objects using input prompts # self.predictor = SamPredictor(self.sam, **sam_kwargs) - def __call__( - self, - image, - foreground=True, - erosion_kernel=(3, 3), - mask_multiplier=255, - **kwargs, - ): - """Generate masks for the input tile. This function originates from the segment-anything-eo repository. - See https://bit.ly/41pwiHw - - Args: - image (np.ndarray): The input image as a numpy array. - foreground (bool, optional): Whether to generate the foreground mask. Defaults to True. - erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders. Defaults to (3, 3). - mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1]. - You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255. - """ - h, w, _ = image.shape - - masks = self.mask_generator.generate(image) - - if foreground: # Extract foreground objects only - resulting_mask = np.zeros((h, w), dtype=np.uint8) - else: - resulting_mask = np.ones((h, w), dtype=np.uint8) - resulting_borders = np.zeros((h, w), dtype=np.uint8) - - for m in masks: - mask = (m["segmentation"] > 0).astype(np.uint8) - resulting_mask += mask - - # Apply erosion to the mask - if erosion_kernel is not None: - mask_erode = cv2.erode(mask, erosion_kernel, iterations=1) - mask_erode = (mask_erode > 0).astype(np.uint8) - edge_mask = mask - mask_erode - resulting_borders += edge_mask - - resulting_mask = (resulting_mask > 0).astype(np.uint8) - resulting_borders = (resulting_borders > 0).astype(np.uint8) - resulting_mask_with_borders = resulting_mask - resulting_borders - return resulting_mask_with_borders * mask_multiplier - def generate( self, source,