-
Notifications
You must be signed in to change notification settings - Fork 550
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Add large image demo with
sahi
(#284)
* add large image demo with sahi * fix some typos * restructure based on reviews * update default patch size * add docstring and update docs * updates based on reviews * print information * add debug, update docs, add large image sample * update docs * update docs * update docs * direct user to install sahi
- Loading branch information
Showing
5 changed files
with
338 additions
and
0 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,208 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
"""Perform MMYOLO inference on large images (as satellite imagery) as: | ||
```shell | ||
wget -P checkpoint https://download.openmmlab.com/mmyolo/v0/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth # noqa: E501, E261. | ||
python demo/large_image_demo.py \ | ||
demo/large_image.jpg \ | ||
configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \ | ||
checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth | ||
``` | ||
""" | ||
|
||
import os | ||
from argparse import ArgumentParser | ||
|
||
import mmcv | ||
from mmdet.apis import inference_detector, init_detector | ||
from mmengine.logging import print_log | ||
from mmengine.utils import ProgressBar | ||
|
||
try: | ||
from sahi.slicing import slice_image | ||
except ImportError: | ||
raise ImportError('Please run "pip install -U sahi" ' | ||
'to install sahi first for large image inference.') | ||
|
||
from mmyolo.registry import VISUALIZERS | ||
from mmyolo.utils import register_all_modules, switch_to_deploy | ||
from mmyolo.utils.large_image import merge_results_by_nms | ||
from mmyolo.utils.misc import get_file_list | ||
|
||
|
||
def parse_args():
    """Parse the command line arguments of the large image demo.

    Returns:
        argparse.Namespace: The parsed arguments. Positional arguments are
        ``img``, ``config`` and ``checkpoint``; every other option has a
        default value.
    """
    parser = ArgumentParser(
        description='Perform MMYOLO inference on large images.')
    parser.add_argument(
        'img', help='Image path, include image file, dir and URL.')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--out-dir', default='./output', help='Path to output file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--show', action='store_true', help='Show the detection results')
    parser.add_argument(
        '--deploy',
        action='store_true',
        help='Switch model to deployment mode')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='Bbox score threshold')
    parser.add_argument(
        '--patch-size', type=int, default=640, help='The size of patches')
    # The overlap is a ratio (default 0.25), so it must be parsed as a float;
    # with ``type=int`` a value such as ``0.3`` is rejected by argparse.
    parser.add_argument(
        '--patch-overlap-ratio',
        type=float,
        default=0.25,
        help='Ratio of overlap between two patches')
    parser.add_argument(
        '--merge-iou-thr',
        type=float,
        default=0.25,
        help='IoU threshold for merging results')
    parser.add_argument(
        '--merge-nms-type',
        type=str,
        default='nms',
        help='NMS type for merging results')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=1,
        help='Batch size, must be greater than or equal to 1')
    parser.add_argument(
        '--debug',
        action='store_true',
        help='Export debug images at each stage for 1 input')
    args = parser.parse_args()
    return args
|
||
|
||
def main():
    """Run sliced inference on large images and visualize merged results.

    Each input image is cut into overlapping patches with sahi's
    ``slice_image``, the patches are forwarded through the detector in
    batches of ``--batch-size``, and the per-patch predictions are merged
    with ``merge_results_by_nms``. The merged result is either shown
    (``--show``) or written below ``--out-dir``. With ``--debug`` only the
    first input is processed and every patch, with and without its
    predictions, is exported as well.

    Raises:
        ValueError: If ``--batch-size`` is smaller than 1.
    """
    args = parse_args()

    # The batching loop below steps by the batch size, so it must be positive.
    if args.batch_size < 1:
        raise ValueError(
            f'--batch-size must be >= 1, but got {args.batch_size}.')

    # register all modules into the registries
    register_all_modules()

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)

    if args.deploy:
        switch_to_deploy(model)

    # The merged visualizations are only written to disk when they are not
    # shown. ``makedirs`` also creates missing parent directories and does
    # not fail when the directory already exists.
    if not args.show:
        os.makedirs(args.out_dir, exist_ok=True)

    # init visualizer
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    visualizer.dataset_meta = model.dataset_meta

    # get file list
    files, source_type = get_file_list(args.img)

    # if debug, only process the first file
    if args.debug:
        files = files[:1]

    # start detector inference
    print(f'Performing inference on {len(files)} images... '
          'This may take a while.')
    progress_bar = ProgressBar(len(files))
    for file in files:
        # read image
        img = mmcv.imread(file)

        # arrange slices
        height, width = img.shape[:2]
        sliced_image_object = slice_image(
            img,
            slice_height=args.patch_size,
            slice_width=args.patch_size,
            auto_slice_resolution=False,
            overlap_height_ratio=args.patch_overlap_ratio,
            overlap_width_ratio=args.patch_overlap_ratio,
        )
        patches = sliced_image_object.images

        # perform sliced inference, one batch of patches at a time
        slice_results = []
        for start in range(0, len(patches), args.batch_size):
            batch = list(patches[start:start + args.batch_size])
            slice_results.extend(inference_detector(model, batch))

        if source_type['is_dir']:
            # Flatten the path relative to the input directory into a single
            # file name; ``os.sep`` keeps this working on every platform.
            filename = os.path.relpath(file, args.img).replace(os.sep, '_')
        else:
            filename = os.path.basename(file)

        # export debug images
        if args.debug:
            # export sliced images
            for i, image in enumerate(patches):
                image = mmcv.imconvert(image, 'bgr', 'rgb')
                out_file = os.path.join(args.out_dir, 'sliced_images',
                                        filename + f'_slice_{i}.jpg')

                mmcv.imwrite(image, out_file)

            # export sliced image results
            for i, slice_result in enumerate(slice_results):
                out_file = os.path.join(args.out_dir, 'sliced_image_results',
                                        filename + f'_slice_{i}_result.jpg')
                image = mmcv.imconvert(patches[i], 'bgr', 'rgb')

                visualizer.add_datasample(
                    os.path.basename(out_file),
                    image,
                    data_sample=slice_result,
                    draw_gt=False,
                    show=args.show,
                    wait_time=0,
                    out_file=out_file,
                    pred_score_thr=args.score_thr,
                )

        # ``iou_threshold`` is the keyword mmcv's NMS ops expect; ``iou_thr``
        # is only accepted as a deprecated alias.
        image_result = merge_results_by_nms(
            slice_results,
            sliced_image_object.starting_pixels,
            src_image_shape=(height, width),
            nms_cfg={
                'type': args.merge_nms_type,
                'iou_threshold': args.merge_iou_thr
            })

        img = mmcv.imconvert(img, 'bgr', 'rgb')
        out_file = None if args.show else os.path.join(args.out_dir, filename)

        # ``out_file`` is None when showing, so the sample name must come
        # from ``filename`` rather than from ``os.path.basename(out_file)``.
        visualizer.add_datasample(
            filename,
            img,
            data_sample=image_result,
            draw_gt=False,
            show=args.show,
            wait_time=0,
            out_file=out_file,
            pred_score_thr=args.score_thr,
        )
        progress_bar.update()

    if not args.show:
        print_log(
            f'\nResults have been saved at {os.path.abspath(args.out_dir)}')


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from typing import Sequence, Tuple | ||
|
||
from mmcv.ops import batched_nms | ||
from mmdet.structures import DetDataSample, SampleList | ||
from mmengine.structures import InstanceData | ||
|
||
|
||
def shift_predictions(det_data_samples: SampleList,
                      offsets: Sequence[Tuple[int, int]],
                      src_image_shape: Tuple[int, int]) -> InstanceData:
    """Shift patch predictions to the original image.

    Args:
        det_data_samples (List[:obj:`DetDataSample`]): A list of patch results.
        offsets (Sequence[Tuple[int, int]]): Positions of the left top points
            of patches.
        src_image_shape (Tuple[int, int]): A (height, width) tuple of the
            large image.

    Returns:
        :obj:`InstanceData`: The shifted predictions of all patches,
        concatenated into a single instance container.

    Raises:
        ImportError: If sahi is not installed.
    """
    try:
        from sahi.slicing import shift_bboxes, shift_masks
    except ImportError as err:
        raise ImportError('Please run "pip install -U sahi" '
                          'to install sahi first for large image inference.'
                          ) from err

    assert len(det_data_samples) == len(
        offsets), 'The `results` should has the ' 'same length with `offsets`.'
    shifted_predictions = []
    for det_data_sample, offset in zip(det_data_samples, offsets):
        # Work on a copy so the caller's patch results are left untouched.
        pred_inst = det_data_sample.pred_instances.clone()

        # shift bboxes and masks
        pred_inst.bboxes = shift_bboxes(pred_inst.bboxes, offset)
        # Masks are stored on the prediction instances, so their presence has
        # to be checked there rather than on the enclosing data sample.
        if 'masks' in pred_inst:
            pred_inst.masks = shift_masks(pred_inst.masks, offset,
                                          src_image_shape)

        shifted_predictions.append(pred_inst)

    return InstanceData.cat(shifted_predictions)
|
||
|
||
def merge_results_by_nms(results: SampleList,
                         offsets: Sequence[Tuple[int, int]],
                         src_image_shape: Tuple[int, int],
                         nms_cfg: dict) -> DetDataSample:
    """Merge the predictions of all patches into one result by NMS.

    The patch predictions are first shifted onto the large image with
    ``shift_predictions``; ``batched_nms`` is then run over the shifted
    boxes, using the predicted labels as the batch indices.

    Args:
        results (List[:obj:`DetDataSample`]): A list of patch results.
        offsets (Sequence[Tuple[int, int]]): Positions of the left top points
            of patches.
        src_image_shape (Tuple[int, int]): A (height, width) tuple of the
            large image.
        nms_cfg (dict): Config passed to ``batched_nms``. It should specify
            the nms type and other parameters like `iou_threshold`.

    Returns:
        :obj:`DetDataSample`: A clone of the first patch result whose
        ``pred_instances`` are replaced by the merged instances.
    """
    instances = shift_predictions(results, offsets, src_image_shape)

    # Only the kept indices are needed; the returned detections are unused.
    keep_indices = batched_nms(
        boxes=instances.bboxes,
        scores=instances.scores,
        idxs=instances.labels,
        nms_cfg=nms_cfg)[1]

    output = results[0].clone()
    output.pred_instances = instances[keep_indices]
    return output
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
sahi>=0.11.4 |