Skip to content

Commit

Permalink
Add predict_by_points method
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs committed Oct 18, 2024
1 parent 0726b29 commit a67328a
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 15 deletions.
37 changes: 27 additions & 10 deletions samgeo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,25 +795,38 @@ def geojson_to_coords(

def coords_to_xy(
src_fp: str,
coords: list,
coords: np.ndarray,
coord_crs: str = "epsg:4326",
return_out_of_bounds=False,
**kwargs,
) -> list:
"""Converts a list of coordinates to pixel coordinates, i.e., (col, row) coordinates.
) -> np.ndarray:
"""Converts a list or array of coordinates to pixel coordinates, i.e., (col, row) coordinates.
Args:
src_fp: The source raster file path.
coords: A list of coordinates in the format of [[x1, y1], [x2, y2], ...]
coords: A 2D or 3D array of coordinates. Can be of shape [[x1, y1], [x2, y2], ...]
or [[[x1, y1]], [[x2, y2]], ...].
coord_crs: The coordinate CRS of the input coordinates. Defaults to "epsg:4326".
return_out_of_bounds: Whether to return out of bounds coordinates. Defaults to False.
return_out_of_bounds: Whether to return out-of-bounds coordinates. Defaults to False.
**kwargs: Additional keyword arguments to pass to rasterio.transform.rowcol.
Returns:
A list of pixel coordinates in the format of [[x1, y1], [x2, y2], ...]
A 2D or 3D array of pixel coordinates in the same format as the input.
"""
from rasterio.warp import transform as transform_coords

out_of_bounds = []
if isinstance(coords, np.ndarray):
input_is_3d = coords.ndim == 3 # Check if the input is a 3D array
else:
input_is_3d = False

# Flatten the 3D array to 2D if necessary
if input_is_3d:
original_shape = coords.shape # Store the original shape
coords = coords.reshape(-1, 2) # Flatten to 2D

# Convert ndarray to a list if necessary
if isinstance(coords, np.ndarray):
coords = coords.tolist()

Expand All @@ -822,8 +835,9 @@ def coords_to_xy(
width = src.width
height = src.height
if coord_crs != src.crs:
xs, ys = transform_coords(xs, ys, coord_crs, src.crs, **kwargs)
xs, ys = transform_coords(coord_crs, src.crs, xs, ys, **kwargs)
rows, cols = rasterio.transform.rowcol(src.transform, xs, ys, **kwargs)

result = [[col, row] for col, row in zip(cols, rows)]

output = []
Expand All @@ -834,9 +848,12 @@ def coords_to_xy(
else:
out_of_bounds.append(i)

# output = [
# [x, y] for x, y in result if x >= 0 and y >= 0 and x < width and y < height
# ]
# Convert the output back to the original shape if input was 3D
output = np.array(output)
if input_is_3d:
output = output.reshape(original_shape)

# Handle cases where no valid pixel coordinates are found
if len(output) == 0:
print("No valid pixel coordinates found.")
elif len(output) < len(coords):
Expand Down
140 changes: 135 additions & 5 deletions samgeo/samgeo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
min_mask_region_area: int = 0,
output_mode: str = "binary_mask",
use_m2m: bool = False,
multimask_output: bool = True,
multimask_output: bool = False,
max_hole_area: float = 0.0,
max_sprinkle_area: float = 0.0,
**kwargs: Any,
Expand Down Expand Up @@ -100,6 +100,7 @@ def __init__(
memory.
use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
multimask_output (bool): Whether to output multimask at each point of the grid.
Defaults to False.
max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
the maximum area of max_hole_area in low_res_masks.
max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
Expand Down Expand Up @@ -546,7 +547,7 @@ def predict(
point_labels: Optional[np.ndarray] = None,
boxes: Optional[np.ndarray] = None,
mask_input: Optional[np.ndarray] = None,
multimask_output: bool = True,
multimask_output: bool = False,
return_logits: bool = False,
normalize_coords: bool = True,
point_crs: Optional[str] = None,
Expand All @@ -573,7 +574,7 @@ def predict(
to select the best mask. For non-ambiguous prompts, such as multiple
input prompts, multimask_output=False can give better results.
multimask_output (bool, optional): Whether to output multimask at each
point of the grid. Defaults to True.
point of the grid. Defaults to False.
return_logits (bool, optional): If true, returns un-thresholded masks logits
instead of a binary mask.
normalize_coords (bool, optional): Whether to normalize the coordinates.
Expand Down Expand Up @@ -688,13 +689,142 @@ def predict(
if return_results:
return masks, scores, logits

def predict_by_points(
self,
point_coords_batch: List[np.ndarray] = None,
point_labels_batch: List[np.ndarray] = None,
box_batch: List[np.ndarray] = None,
mask_input_batch: List[np.ndarray] = None,
multimask_output: bool = False,
return_logits: bool = False,
normalize_coords=True,
point_crs: Optional[str] = None,
output: Optional[str] = None,
index: Optional[int] = None,
mask_multiplier: int = 255,
dtype: str = "float32",
return_results: bool = False,
**kwargs: Any,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Predict the mask for the input image.
Args:
point_coords (np.ndarray, optional): The point coordinates. Defaults to None.
point_labels (np.ndarray, optional): The point labels. Defaults to None.
boxes (list | np.ndarray, optional): A length 4 array given a box prompt to the
model, in XYXY format.
mask_input (np.ndarray, optional): A low resolution mask input to the model, typically
coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256.
multimask_output (bool, optional): If true, the model will return three masks.
For ambiguous input prompts (such as a single click), this will often
produce better masks than a single prediction. If only a single
mask is needed, the model's predicted quality score can be used
to select the best mask. For non-ambiguous prompts, such as multiple
input prompts, multimask_output=False can give better results.
multimask_output (bool, optional): Whether to output multimask at each
point of the grid. Defaults to True.
return_logits (bool, optional): If true, returns un-thresholded masks logits
instead of a binary mask.
normalize_coords (bool, optional): Whether to normalize the coordinates.
Defaults to True.
point_crs (str, optional): The coordinate reference system (CRS) of the point prompts.
output (str, optional): The path to the output image. Defaults to None.
index (index, optional): The index of the mask to save. Defaults to None,
which will save the mask with the highest score.
mask_multiplier (int, optional): The mask multiplier for the output mask,
which is usually a binary mask [0, 1].
dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.
return_results (bool, optional): Whether to return the predicted masks,
scores, and logits. Defaults to False.
Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray]: The mask, the multimask,
and the logits.
"""
import geopandas as gpd

if hasattr(self, "image_batch") and self.image_batch is not None:
pass
elif self.image is not None:
self.predictor.set_image_batch([self.image])
setattr(self, "image_batch", [self.image])
else:
raise ValueError("Please set the input image first using set_image().")

if isinstance(point_coords_batch, str) or isinstance(
point_coords_batch, gpd.GeoDataFrame
):
if isinstance(point_coords_batch, str):
gdf = gpd.read_file(point_coords_batch)
else:
gdf = point_coords_batch
if gdf.crs is None and (point_crs is not None):
gdf.crs = point_crs

points = gdf.geometry.apply(lambda geom: [geom.x, geom.y])
coordinates_array = np.array([[point] for point in points])
points = common.coords_to_xy(self.source, coordinates_array, point_crs)
num_points = points.shape[0]
if point_labels_batch is None:
labels = np.array([[1] for i in range(num_points)])
else:
labels = point_labels_batch

elif isinstance(point_coords_batch, list):
points = point_coords_batch
num_points = points.shape[0]
if point_labels_batch is None:
labels = np.array([[1] for i in range(num_points)])
else:
labels = point_labels_batch
else:
raise ValueError("point_coords must be a list, a GeoDataFrame, or a path.")

predictor = self.predictor

masks_batch, scores_batch, logits_batch = predictor.predict_batch(
point_coords_batch=[points],
point_labels_batch=[labels],
box_batch=box_batch,
mask_input_batch=mask_input_batch,
multimask_output=multimask_output,
return_logits=return_logits,
normalize_coords=normalize_coords,
)

masks = masks_batch[0]
scores = scores_batch[0]
logits = logits_batch[0]

if multimask_output and (index is not None):
masks = masks[:, index, :, :]

if masks.ndim > 3:
masks = masks.squeeze()

output_masks = []
sums = np.sum(masks, axis=(1, 2))
for index, mask in enumerate(masks):
item = {"segmentation": mask.astype("bool"), "area": sums[index]}
output_masks.append(item)

self.masks = output_masks
self.scores = scores
self.logits = logits

if output is not None:
self.save_prediction(output, index, mask_multiplier, dtype, **kwargs)

if return_results:
return output_masks, scores, logits

def predict_batch(
self,
point_coords_batch: List[np.ndarray] = None,
point_labels_batch: List[np.ndarray] = None,
box_batch: List[np.ndarray] = None,
mask_input_batch: List[np.ndarray] = None,
multimask_output: bool = True,
multimask_output: bool = False,
return_logits: bool = False,
normalize_coords=True,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
Expand All @@ -710,7 +840,7 @@ def predict_batch(
mask_input_batch (Optional[List[np.ndarray]]): A batch of mask inputs.
Defaults to None.
multimask_output (bool): Whether to output multimask at each point
of the grid. Defaults to True.
of the grid. Defaults to False.
return_logits (bool): Whether to return the logits. Defaults to False.
normalize_coords (bool): Whether to normalize the coordinates.
Defaults to True.
Expand Down

0 comments on commit a67328a

Please sign in to comment.