[Feature] Add large image demo with sahi (#284)
* add large image demo with sahi

* fix some typos

* restructure based on reviews

* update default patch size

* add docstring and update docs

* updates based on reviews

* print information

* add debug, update docs, add large image sample

* update docs

* update docs

* update docs

* direct user to install sahi
fcakyon authored Nov 18, 2022
1 parent 5cee9c9 commit 0fd6444
Showing 5 changed files with 338 additions and 0 deletions.
Binary file added demo/large_image.jpg
208 changes: 208 additions & 0 deletions demo/large_image_demo.py
@@ -0,0 +1,208 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Perform MMYOLO inference on large images (as satellite imagery) as:
```shell
wget -P checkpoint https://download.openmmlab.com/mmyolo/v0/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth  # noqa: E501, E261
python demo/large_image_demo.py \
demo/large_image.jpg \
configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \
checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth \
```
"""

import os
from argparse import ArgumentParser

import mmcv
from mmdet.apis import inference_detector, init_detector
from mmengine.logging import print_log
from mmengine.utils import ProgressBar

try:
from sahi.slicing import slice_image
except ImportError:
raise ImportError('Please run "pip install -U sahi" '
'to install sahi first for large image inference.')

from mmyolo.registry import VISUALIZERS
from mmyolo.utils import register_all_modules, switch_to_deploy
from mmyolo.utils.large_image import merge_results_by_nms
from mmyolo.utils.misc import get_file_list


def parse_args():
parser = ArgumentParser(
description='Perform MMYOLO inference on large images.')
parser.add_argument(
        'img', help='Image path, which can be an image file, a directory or a URL')
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument(
        '--out-dir', default='./output', help='Path to the output directory')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--show', action='store_true', help='Show the detection results')
parser.add_argument(
'--deploy',
action='store_true',
help='Switch model to deployment mode')
parser.add_argument(
'--score-thr', type=float, default=0.3, help='Bbox score threshold')
parser.add_argument(
'--patch-size', type=int, default=640, help='The size of patches')
parser.add_argument(
'--patch-overlap-ratio',
        type=float,
default=0.25,
help='Ratio of overlap between two patches')
parser.add_argument(
'--merge-iou-thr',
type=float,
default=0.25,
        help='IoU threshold for merging results')
parser.add_argument(
'--merge-nms-type',
type=str,
default='nms',
help='NMS type for merging results')
parser.add_argument(
'--batch-size',
type=int,
default=1,
        help='Batch size, must be greater than or equal to 1')
parser.add_argument(
'--debug',
action='store_true',
        help='Export debug images at each stage for the first input')
args = parser.parse_args()
return args


def main():
args = parse_args()

    # register all modules in mmyolo into the registries
register_all_modules()

# build the model from a config file and a checkpoint file
model = init_detector(args.config, args.checkpoint, device=args.device)

if args.deploy:
switch_to_deploy(model)

if not os.path.exists(args.out_dir) and not args.show:
os.mkdir(args.out_dir)

# init visualizer
visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = model.dataset_meta

# get file list
files, source_type = get_file_list(args.img)

# if debug, only process the first file
if args.debug:
files = files[:1]

# start detector inference
    print(f'Performing inference on {len(files)} images... '
          'This may take a while.')
progress_bar = ProgressBar(len(files))
for file in files:
# read image
img = mmcv.imread(file)

# arrange slices
height, width = img.shape[:2]
sliced_image_object = slice_image(
img,
slice_height=args.patch_size,
slice_width=args.patch_size,
auto_slice_resolution=False,
overlap_height_ratio=args.patch_overlap_ratio,
overlap_width_ratio=args.patch_overlap_ratio,
)

# perform sliced inference
slice_results = []
start = 0
while True:
# prepare batch slices
end = min(start + args.batch_size, len(sliced_image_object))
images = []
for sliced_image in sliced_image_object.images[start:end]:
images.append(sliced_image)

# forward the model
slice_results.extend(inference_detector(model, images))

if end >= len(sliced_image_object):
break
start += args.batch_size

if source_type['is_dir']:
filename = os.path.relpath(file, args.img).replace('/', '_')
else:
filename = os.path.basename(file)

# export debug images
if args.debug:
# export sliced images
for i, image in enumerate(sliced_image_object.images):
image = mmcv.imconvert(image, 'bgr', 'rgb')
out_file = os.path.join(args.out_dir, 'sliced_images',
filename + f'_slice_{i}.jpg')

mmcv.imwrite(image, out_file)

# export sliced image results
for i, slice_result in enumerate(slice_results):
out_file = os.path.join(args.out_dir, 'sliced_image_results',
filename + f'_slice_{i}_result.jpg')
image = mmcv.imconvert(sliced_image_object.images[i], 'bgr',
'rgb')

visualizer.add_datasample(
os.path.basename(out_file),
image,
data_sample=slice_result,
draw_gt=False,
show=args.show,
wait_time=0,
out_file=out_file,
pred_score_thr=args.score_thr,
)

image_result = merge_results_by_nms(
slice_results,
sliced_image_object.starting_pixels,
src_image_shape=(height, width),
nms_cfg={
'type': args.merge_nms_type,
'iou_thr': args.merge_iou_thr
})

img = mmcv.imconvert(img, 'bgr', 'rgb')
out_file = None if args.show else os.path.join(args.out_dir, filename)

        visualizer.add_datasample(
            filename,  # avoid os.path.basename(None) when --show is set
img,
data_sample=image_result,
draw_gt=False,
show=args.show,
wait_time=0,
out_file=out_file,
pred_score_thr=args.score_thr,
)
progress_bar.update()

if not args.show:
print_log(
f'\nResults have been saved at {os.path.abspath(args.out_dir)}')


if __name__ == '__main__':
main()
53 changes: 53 additions & 0 deletions docs/en/user_guides/useful_tools.md
@@ -350,6 +350,59 @@ python tools/analysis_tools/optimize_anchors.py ${CONFIG} \
--output-dir ${OUTPUT_DIR}
```
## Perform inference on large images
First install [`sahi`](https://github.com/obss/sahi) with:
```shell
pip install -U "sahi>=0.11.4"
```
Perform MMYOLO inference on large images (such as satellite imagery) as:
```shell
wget -P checkpoint https://download.openmmlab.com/mmyolo/v0/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth
python demo/large_image_demo.py \
demo/large_image.jpg \
configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \
checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth
```
Adjust the slicing parameters as:
```shell
python demo/large_image_demo.py \
demo/large_image.jpg \
configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \
checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth \
    --patch-size 512 \
    --patch-overlap-ratio 0.25
```
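A larger `--patch-size` means fewer slices (and fewer forward passes) per image, while a smaller one helps with very small objects; the default of 640 matches the input resolution used by the YOLOv5 COCO configs. `--patch-overlap-ratio` controls how much neighbouring patches overlap, so an object cut by one patch border is still fully visible in an adjacent patch.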
Export debug visuals while performing inference on large images as:
```shell
python demo/large_image_demo.py \
demo/large_image.jpg \
configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \
checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth \
--debug
```
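With `--debug`, the sliced patches are written to a `sliced_images` subfolder of `--out-dir` and the per-patch detections to `sliced_image_results`, which makes it easy to inspect the slicing grid and the pre-merge results.

The same flow can also be driven from Python. Below is a minimal sketch (not part of this commit) mirroring what `demo/large_image_demo.py` does, assuming the config and checkpoint paths from the commands above and a CUDA device:
```python
import mmcv
from mmdet.apis import inference_detector, init_detector
from sahi.slicing import slice_image

from mmyolo.utils import register_all_modules
from mmyolo.utils.large_image import merge_results_by_nms

register_all_modules()
model = init_detector(
    'configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py',
    'checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth',  # assumed local path
    device='cuda:0')

img = mmcv.imread('demo/large_image.jpg')
height, width = img.shape[:2]

# Slice the large image into overlapping 640x640 patches.
sliced = slice_image(
    img,
    slice_height=640,
    slice_width=640,
    auto_slice_resolution=False,
    overlap_height_ratio=0.25,
    overlap_width_ratio=0.25)

# Run inference on every patch, then shift the patch detections back to
# full-image coordinates and merge duplicates with NMS.
slice_results = inference_detector(model, sliced.images)
merged = merge_results_by_nms(
    slice_results,
    sliced.starting_pixels,
    src_image_shape=(height, width),
    nms_cfg=dict(type='nms', iou_thr=0.25))
```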
[`sahi`](https://github.com/obss/sahi) citation:
```
@article{akyon2022sahi,
title={Slicing Aided Hyper Inference and Fine-tuning for Small Object Detection},
author={Akyon, Fatih Cagatay and Altinuc, Sinan Onur and Temizel, Alptekin},
journal={2022 IEEE International Conference on Image Processing (ICIP)},
doi={10.1109/ICIP46576.2022.9897990},
pages={966-970},
year={2022}
}
```
## Extract a subset of COCO
The COCO2017 dataset is relatively large: its training split contains 118K images and its validation split 5K. Loading the full JSON annotations while debugging or running a quick verification consumes significant resources and slows startup.
76 changes: 76 additions & 0 deletions mmyolo/utils/large_image.py
@@ -0,0 +1,76 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Tuple

from mmcv.ops import batched_nms
from mmdet.structures import DetDataSample, SampleList
from mmengine.structures import InstanceData


def shift_predictions(det_data_samples: SampleList,
                      offsets: Sequence[Tuple[int, int]],
                      src_image_shape: Tuple[int, int]) -> InstanceData:
    """Shift predictions back to the coordinates of the original image.

    Args:
        det_data_samples (List[:obj:`DetDataSample`]): A list of patch
            results.
        offsets (Sequence[Tuple[int, int]]): Positions of the left top points
            of patches.
        src_image_shape (Tuple[int, int]): A (height, width) tuple of the
            large image.

    Returns:
        :obj:`InstanceData`: Shifted predictions.
    """
try:
from sahi.slicing import shift_bboxes, shift_masks
except ImportError:
raise ImportError('Please run "pip install -U sahi" '
'to install sahi first for large image inference.')

    assert len(det_data_samples) == len(offsets), \
        '`det_data_samples` should have the same length as `offsets`.'
shifted_predictions = []
for det_data_sample, offset in zip(det_data_samples, offsets):
pred_inst = det_data_sample.pred_instances.clone()

# shift bboxes and masks
pred_inst.bboxes = shift_bboxes(pred_inst.bboxes, offset)
if 'masks' in det_data_sample:
pred_inst.masks = shift_masks(pred_inst.masks, offset,
src_image_shape)

shifted_predictions.append(pred_inst.clone())

shifted_predictions = InstanceData.cat(shifted_predictions)

return shifted_predictions


def merge_results_by_nms(results: SampleList,
                         offsets: Sequence[Tuple[int, int]],
                         src_image_shape: Tuple[int, int],
                         nms_cfg: dict) -> DetDataSample:
    """Merge patch results by NMS.

    Args:
        results (List[:obj:`DetDataSample`]): A list of patch results.
        offsets (Sequence[Tuple[int, int]]): Positions of the left top points
            of patches.
        src_image_shape (Tuple[int, int]): A (height, width) tuple of the
            large image.
        nms_cfg (dict): It should specify the NMS type and other parameters
            like `iou_threshold`.

    Returns:
        :obj:`DetDataSample`: Merged results.
    """
shifted_instances = shift_predictions(results, offsets, src_image_shape)

_, keeps = batched_nms(
boxes=shifted_instances.bboxes,
scores=shifted_instances.scores,
idxs=shifted_instances.labels,
nms_cfg=nms_cfg)
merged_instances = shifted_instances[keeps]

merged_result = results[0].clone()
merged_result.pred_instances = merged_instances
return merged_result
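The merge step relies on `mmcv.ops.batched_nms`, which takes the predicted labels as `idxs`, so boxes are only suppressed against boxes of the same class inside the patch overlap regions. The `type` key of `nms_cfg` selects the NMS operator, so other variants can be plugged in via `--merge-nms-type`. A sketch of two configurations (whether `soft_nms` is available depends on the installed mmcv version):
```python
# Hard NMS, matching the demo's CLI defaults (--merge-nms-type nms).
nms_cfg = dict(type='nms', iou_thr=0.25)

# Hypothetical alternative: Soft-NMS with the same IoU threshold.
nms_cfg = dict(type='soft_nms', iou_thr=0.25)
```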
1 change: 1 addition & 0 deletions requirements/sahi.txt
@@ -0,0 +1 @@
sahi>=0.11.4
