diff --git a/docs/examples/sam2_point_prompts.ipynb b/docs/examples/sam2_point_prompts.ipynb new file mode 100644 index 00000000..287a4083 --- /dev/null +++ b/docs/examples/sam2_point_prompts.ipynb @@ -0,0 +1,497 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Segmenting remote sensing imagery with point prompts\n", + "\n", + "[![image](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/opengeos/segment-geospatial/blob/main/docs/examples/sam2_point_prompts.ipynb)\n", + "[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opengeos/segment-geospatial/blob/main/docs/examples/sam2_point_prompts.ipynb)\n", + "\n", + "This notebook shows how to generate object masks from point prompts with the Segment Anything Model 2 (SAM 2). \n", + "\n", + "Make sure you use GPU runtime for this notebook. For Google Colab, go to `Runtime` -> `Change runtime type` and select `GPU` as the hardware accelerator. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Install dependencies\n", + "\n", + "Uncomment and run the following cell to install the required dependencies." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %pip install -U segment-geospatial" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Import libraries" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import leafmap\n", + "from samgeo import SamGeo2, regularize" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create an interactive map" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m = leafmap.Map(center=[47.653287, -117.588070], zoom=16, height=\"800px\")\n", + "m.add_basemap(\"Satellite\")\n", + "m" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download a sample image\n", + "\n", + "Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map. If no geometry is drawn, the default bounding box will be used." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if m.user_roi is not None:\n", + " bbox = m.user_roi_bounds()\n", + "else:\n", + " bbox = [-117.6029, 47.65, -117.5936, 47.6563]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = \"satellite.tif\"\n", + "leafmap.map_tiles_to_geotiff(\n", + " output=image, bbox=bbox, zoom=18, source=\"Satellite\", overwrite=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also use your own image. Uncomment and run the following cell to use your own image." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# image = '/path/to/your/own/image.tif'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display the downloaded image on the map." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m.layers[-1].visible = False\n", + "m.add_raster(image, layer_name=\"Image\")\n", + "m" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialize SAM class\n", + "\n", + "Set `automatic=False` to enable the `SAM2ImagePredictor`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sam = SamGeo2(\n", + " model_id=\"sam2-hiera-large\",\n", + " automatic=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Specify the image to segment. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sam.set_image(image)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Segment the image\n", + "\n", + "Use the `predict_by_points()` method to segment the image with specified point coordinates. You can use the draw tools to add place markers on the map. If no point is added, the default sample points will be used.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if m.user_rois is not None:\n", + " point_coords_batch = m.user_rois\n", + "else:\n", + " point_coords_batch = [\n", + " [-117.599896, 47.655345],\n", + " [-117.59992, 47.655167],\n", + " [-117.599928, 47.654974],\n", + " [-117.599518, 47.655337],\n", + " ]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Segment the objects using the point prompts and save the output masks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sam.predict_by_points(\n", + " point_coords_batch=point_coords_batch,\n", + " point_crs=\"EPSG:4326\",\n", + " output=\"mask.tif\",\n", + " dtype=\"uint8\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Display the result\n", + "\n", + "Add the segmented image to the map." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m.add_raster(\"mask.tif\", cmap=\"viridis\", nodata=0, opacity=0.7, layer_name=\"Mask\")\n", + "m" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![image](https://github.com/user-attachments/assets/49e413b9-e159-4d72-bf23-a0318bc82d44)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use an existing vector dataset as point prompts\n", + "\n", + "Alternatively, you can specify a file path or HTTP URL to a vector dataset containing point geometries." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "geojson = \"https://github.com/opengeos/datasets/releases/download/places/wa_building_centroids.geojson\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display the vector dataset on the map."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m = leafmap.Map()\n", + "m.add_raster(image, layer_name=\"Image\")\n", + "m.add_circle_markers_from_xy(\n", + " geojson, radius=3, color=\"red\", fill_color=\"yellow\", fill_opacity=0.8\n", + ")\n", + "m" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![image](https://github.com/user-attachments/assets/f0d3ff1e-15fa-4bd3-ac15-637e8d63527d)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Segment image with a vector dataset\n", + "\n", + "Segment the image using the specified file path to the vector dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_masks = \"building_masks.tif\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sam.predict_by_points(\n", + " point_coords_batch=geojson,\n", + " point_crs=\"EPSG:4326\",\n", + " output=output_masks,\n", + " dtype=\"uint8\",\n", + " multimask_output=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display the segmented masks on the map." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m.add_raster(\n", + " output_masks, cmap=\"jet\", nodata=0, opacity=0.7, layer_name=\"Building masks\"\n", + ")\n", + "m" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![image](https://github.com/user-attachments/assets/262e1a31-1648-47d2-9e71-c85ab15b1a5c)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Clean up the result\n", + "\n", + "Remove small objects from the segmented masks, fill holes, and compute geometric properties." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "out_vector = \"building_vector.geojson\"\n", + "out_image = \"buildings.tif\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "array, gdf = sam.region_groups(\n", + " output_masks, min_size=200, out_vector=out_vector, out_image=out_image\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gdf.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![image](https://github.com/user-attachments/assets/af9ffa11-8ebe-4b42-8cba-3f5bcc4912f4)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Regularize building footprints\n", + "\n", + "Regularize the building footprints using the `regularize()` method." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_regularized = \"building_regularized.geojson\"\n", + "regularize(out_vector, output_regularized)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display the regularized building footprints on the map." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m = leafmap.Map()\n", + "m.add_raster(image, layer_name=\"Image\")\n", + "style = {\n", + " \"color\": \"#ffff00\",\n", + " \"weight\": 2,\n", + " \"fillColor\": \"#7c4185\",\n", + " \"fillOpacity\": 0,\n", + "}\n", + "m.add_raster(out_image, cmap=\"tab20\", opacity=0.7, nodata=0, layer_name=\"Buildings\")\n", + "m.add_vector(\n", + " output_regularized, style=style, layer_name=\"Building regularized\", info_mode=None\n", + ")\n", + "m" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![image](https://github.com/user-attachments/assets/b39ee029-2089-45b8-8ac0-ba0d750cec22)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Interactive segmentation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sam.show_map()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![](https://github.com/user-attachments/assets/4f487505-6e89-4892-9a70-95ab0aa69cb6)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/mkdocs.yml b/mkdocs.yml index ca162fbc..798eb527 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -65,6 +65,7 @@ nav: - examples/sam2_predictor.ipynb - examples/sam2_video.ipynb - examples/sam2_box_prompts.ipynb + - examples/sam2_point_prompts.ipynb - examples/sam2_text_prompts.ipynb - Workshops: - workshops/purdue.ipynb diff --git a/samgeo/common.py b/samgeo/common.py index 86ba0dac..e4ece042 100644 --- a/samgeo/common.py +++ b/samgeo/common.py @@ -795,25 +795,38 @@ def geojson_to_coords( def coords_to_xy( src_fp: str, - coords: list, + coords: np.ndarray, coord_crs: str = "epsg:4326", return_out_of_bounds=False, **kwargs, -) -> list: - """Converts a list of coordinates to pixel coordinates, i.e., (col, row) coordinates. +) -> np.ndarray: + """Converts a list or array of coordinates to pixel coordinates, i.e., (col, row) coordinates. Args: src_fp: The source raster file path. - coords: A list of coordinates in the format of [[x1, y1], [x2, y2], ...] + coords: A 2D or 3D array of coordinates. Can be of shape [[x1, y1], [x2, y2], ...] + or [[[x1, y1]], [[x2, y2]], ...]. coord_crs: The coordinate CRS of the input coordinates. Defaults to "epsg:4326". - return_out_of_bounds: Whether to return out of bounds coordinates. Defaults to False. + return_out_of_bounds: Whether to return out-of-bounds coordinates. Defaults to False. **kwargs: Additional keyword arguments to pass to rasterio.transform.rowcol. Returns: - A list of pixel coordinates in the format of [[x1, y1], [x2, y2], ...] + A 2D or 3D array of pixel coordinates in the same format as the input. 
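+        Example:
+            A minimal usage sketch (the file name "satellite.tif" and the two points
+            are illustrative; any GeoTIFF whose extent covers the input points will
+            work):
+
+                coords = np.array([[-117.5999, 47.6553], [-117.5995, 47.6550]])
+                pixels = coords_to_xy("satellite.tif", coords, coord_crs="epsg:4326")
+                # pixels is a (2, 2) array of [col, row] pairs; an (N, 1, 2) input
+                # is returned with the same 3D shape.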
""" + from rasterio.warp import transform as transform_coords + out_of_bounds = [] + if isinstance(coords, np.ndarray): + input_is_3d = coords.ndim == 3 # Check if the input is a 3D array + else: + input_is_3d = False + # Flatten the 3D array to 2D if necessary + if input_is_3d: + original_shape = coords.shape # Store the original shape + coords = coords.reshape(-1, 2) # Flatten to 2D + + # Convert ndarray to a list if necessary if isinstance(coords, np.ndarray): coords = coords.tolist() @@ -822,8 +835,9 @@ def coords_to_xy( width = src.width height = src.height if coord_crs != src.crs: - xs, ys = transform_coords(xs, ys, coord_crs, src.crs, **kwargs) + xs, ys = transform_coords(coord_crs, src.crs, xs, ys, **kwargs) rows, cols = rasterio.transform.rowcol(src.transform, xs, ys, **kwargs) + result = [[col, row] for col, row in zip(cols, rows)] output = [] @@ -834,9 +848,12 @@ def coords_to_xy( else: out_of_bounds.append(i) - # output = [ - # [x, y] for x, y in result if x >= 0 and y >= 0 and x < width and y < height - # ] + # Convert the output back to the original shape if input was 3D + output = np.array(output) + if input_is_3d: + output = output.reshape(original_shape) + + # Handle cases where no valid pixel coordinates are found if len(output) == 0: print("No valid pixel coordinates found.") elif len(output) < len(coords): @@ -1914,7 +1931,7 @@ def sam_map_gui(sam, basemap="SATELLITE", repeat_mode=True, out_dir=None, **kwar description="Mask opacity:", min=0, max=1, - value=0.5, + value=0.7, readout=True, continuous_update=True, layout=widgets.Layout(width=widget_width, padding=padding), @@ -2165,12 +2182,20 @@ def segment_button_click(change): filename = f"masks_{random_string()}.tif" filename = os.path.join(out_dir, filename) - sam.predict( - point_coords=point_coords, - point_labels=point_labels, - point_crs="EPSG:4326", - output=filename, - ) + if sam.model_version == "sam": + sam.predict( + point_coords=point_coords, + point_labels=point_labels, + point_crs="EPSG:4326", + output=filename, + ) + elif sam.model_version == "sam2": + sam.predict_by_points( + point_coords_batch=point_coords, + point_labels_batch=point_labels, + point_crs="EPSG:4326", + output=filename, + ) if m.find_layer("Masks") is not None: m.remove_layer(m.find_layer("Masks")) if m.find_layer("Regularized") is not None: @@ -2183,18 +2208,16 @@ def segment_button_click(change): os.remove(sam.prediction_fp) except: pass - # Skip the image layer if localtileserver is not available try: m.add_raster( filename, nodata=0, - cmap="Blues", + cmap="tab20", opacity=opacity_slider.value, layer_name="Masks", zoom_to_layer=False, ) - if rectangular.value: vector = filename.replace(".tif", ".gpkg") vector_rec = filename.replace(".tif", "_rect.gpkg") @@ -2285,7 +2308,7 @@ def reset_button_click(change): if change["new"]: segment_button.value = False reset_button.value = False - opacity_slider.value = 0.5 + opacity_slider.value = 0.7 rectangular.value = False colorpicker.value = "#ffff00" output.clear_output() @@ -2478,7 +2501,7 @@ def text_sam_gui( box_threshold=0.25, text_threshold=0.25, cmap="viridis", - opacity=0.5, + opacity=0.7, **kwargs, ): """Display the SAM Map GUI. 
@@ -2803,7 +2826,7 @@ def reset_button_click(change): segment_button.value = False save_button.value = False reset_button.value = False - opacity_slider.value = 0.5 + opacity_slider.value = 0.7 box_slider.value = 0.25 text_slider.value = 0.25 cmap_dropdown.value = "viridis" diff --git a/samgeo/samgeo.py b/samgeo/samgeo.py index a30d3573..9d0d7e01 100644 --- a/samgeo/samgeo.py +++ b/samgeo/samgeo.py @@ -73,6 +73,7 @@ def __init__( self.checkpoint = checkpoint self.model_type = model_type + self.model_version = "sam" self.device = device self.sam_kwargs = sam_kwargs # Optional arguments for fine-tuning the SAM model self.source = None # Store the input image path diff --git a/samgeo/samgeo2.py b/samgeo/samgeo2.py index a7820e4d..a0bd88ca 100644 --- a/samgeo/samgeo2.py +++ b/samgeo/samgeo2.py @@ -43,7 +43,7 @@ def __init__( min_mask_region_area: int = 0, output_mode: str = "binary_mask", use_m2m: bool = False, - multimask_output: bool = True, + multimask_output: bool = False, max_hole_area: float = 0.0, max_sprinkle_area: float = 0.0, **kwargs: Any, @@ -100,6 +100,7 @@ def __init__( memory. use_m2m (bool): Whether to add a one step refinement using previous mask predictions. multimask_output (bool): Whether to output multimask at each point of the grid. + Defaults to False. max_hole_area (int): If max_hole_area > 0, we fill small holes in up to the maximum area of max_hole_area in low_res_masks. max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to @@ -132,6 +133,7 @@ def __init__( hydra_overrides_extra = [] self.model_id = model_id + self.model_version = "sam2" self.device = device if video: @@ -546,7 +548,7 @@ def predict( point_labels: Optional[np.ndarray] = None, boxes: Optional[np.ndarray] = None, mask_input: Optional[np.ndarray] = None, - multimask_output: bool = True, + multimask_output: bool = False, return_logits: bool = False, normalize_coords: bool = True, point_crs: Optional[str] = None, @@ -573,7 +575,7 @@ def predict( to select the best mask. For non-ambiguous prompts, such as multiple input prompts, multimask_output=False can give better results. multimask_output (bool, optional): Whether to output multimask at each - point of the grid. Defaults to True. + point of the grid. Defaults to False. return_logits (bool, optional): If true, returns un-thresholded masks logits instead of a binary mask. normalize_coords (bool, optional): Whether to normalize the coordinates. @@ -688,13 +690,173 @@ def predict( if return_results: return masks, scores, logits + def predict_by_points( + self, + point_coords_batch: List[np.ndarray] = None, + point_labels_batch: List[np.ndarray] = None, + box_batch: List[np.ndarray] = None, + mask_input_batch: List[np.ndarray] = None, + multimask_output: bool = False, + return_logits: bool = False, + normalize_coords=True, + point_crs: Optional[str] = None, + output: Optional[str] = None, + index: Optional[int] = None, + unique: bool = True, + mask_multiplier: int = 255, + dtype: str = "int32", + return_results: bool = False, + **kwargs: Any, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Predict the mask for the input image. + + Args: + point_coords (np.ndarray, optional): The point coordinates. Defaults to None. + point_labels (np.ndarray, optional): The point labels. Defaults to None. + boxes (list | np.ndarray, optional): A length 4 array given a box prompt to the + model, in XYXY format. 
+ mask_input (np.ndarray, optional): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256. + multimask_output (bool, optional): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + multimask_output (bool, optional): Whether to output multimask at each + point of the grid. Defaults to False. + return_logits (bool, optional): If true, returns un-thresholded masks logits + instead of a binary mask. + normalize_coords (bool, optional): Whether to normalize the coordinates. + Defaults to True. + point_crs (str, optional): The coordinate reference system (CRS) of the point prompts. + output (str, optional): The path to the output image. Defaults to None. + index (int, optional): The index of the mask to save. Defaults to None, + which will save the mask with the highest score. + unique (bool, optional): Whether to assign a unique value to each object. Defaults to True. + mask_multiplier (int, optional): The mask multiplier for the output mask, + which is usually a binary mask [0, 1]. + dtype (str, optional): The data type of the output image. Defaults to "int32". + return_results (bool, optional): Whether to return the predicted masks, + scores, and logits. Defaults to False. + + Returns: + Tuple[np.ndarray, np.ndarray, np.ndarray]: The masks, the scores, + and the logits. + """ + import geopandas as gpd + + if hasattr(self, "image_batch") and self.image_batch is not None: + pass + elif self.image is not None: + self.predictor.set_image_batch([self.image]) + setattr(self, "image_batch", [self.image]) + else: + raise ValueError("Please set the input image first using set_image().") + + if isinstance(point_coords_batch, dict): + point_coords_batch = gpd.GeoDataFrame.from_features(point_coords_batch) + + if isinstance(point_coords_batch, str) or isinstance( + point_coords_batch, gpd.GeoDataFrame + ): + if isinstance(point_coords_batch, str): + gdf = gpd.read_file(point_coords_batch) + else: + gdf = point_coords_batch + if gdf.crs is None and (point_crs is not None): + gdf.crs = point_crs + + points = gdf.geometry.apply(lambda geom: [geom.x, geom.y]) + coordinates_array = np.array([[point] for point in points]) + points = common.coords_to_xy(self.source, coordinates_array, point_crs) + num_points = points.shape[0] + if point_labels_batch is None: + labels = np.array([[1] for i in range(num_points)]) + else: + labels = point_labels_batch + + elif isinstance(point_coords_batch, list): + if point_crs is not None: + point_coords_batch_crs = common.coords_to_xy( + self.source, point_coords_batch, point_crs + ) + else: + point_coords_batch_crs = point_coords_batch + num_points = len(point_coords_batch) + + points = [] + points.append([[point] for point in point_coords_batch_crs]) + + if point_labels_batch is None: + labels = np.array([[1] for i in range(num_points)]) + elif isinstance(point_labels_batch, list): + labels = [] + labels.append([[label] for label in point_labels_batch]) + labels = labels[0] + else: + labels = point_labels_batch + + points = np.array(points[0]) + labels = np.array(labels) + + elif isinstance(point_coords_batch, np.ndarray): + points = point_coords_batch + labels = point_labels_batch + else: + raise ValueError("point_coords_batch must be a list, a 
GeoDataFrame, or a path.") + + predictor = self.predictor + + masks_batch, scores_batch, logits_batch = predictor.predict_batch( + point_coords_batch=[points], + point_labels_batch=[labels], + box_batch=box_batch, + mask_input_batch=mask_input_batch, + multimask_output=multimask_output, + return_logits=return_logits, + normalize_coords=normalize_coords, + ) + + masks = masks_batch[0] + scores = scores_batch[0] + logits = logits_batch[0] + + if multimask_output and (index is not None): + masks = masks[:, index, :, :] + + if masks.ndim > 3: + masks = masks.squeeze() + + output_masks = [] + sums = np.sum(masks, axis=(1, 2)) + for index, mask in enumerate(masks): + item = {"segmentation": mask.astype("bool"), "area": sums[index]} + output_masks.append(item) + + self.masks = output_masks + self.scores = scores + self.logits = logits + + if output is not None: + self.save_masks( + output, + foreground=True, + unique=unique, + mask_multiplier=mask_multiplier, + dtype=dtype, + **kwargs, + ) + + if return_results: + return output_masks, scores, logits + def predict_batch( self, point_coords_batch: List[np.ndarray] = None, point_labels_batch: List[np.ndarray] = None, box_batch: List[np.ndarray] = None, mask_input_batch: List[np.ndarray] = None, - multimask_output: bool = True, + multimask_output: bool = False, return_logits: bool = False, normalize_coords=True, ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: @@ -710,7 +872,7 @@ def predict_batch( mask_input_batch (Optional[List[np.ndarray]]): A batch of mask inputs. Defaults to None. multimask_output (bool): Whether to output multimask at each point - of the grid. Defaults to True. + of the grid. Defaults to False. return_logits (bool): Whether to return the logits. Defaults to False. normalize_coords (bool): Whether to normalize the coordinates. Defaults to True.