Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add large image demo with sahi #284

Merged
merged 12 commits into from
Nov 18, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added demo/large_image.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
203 changes: 203 additions & 0 deletions demo/large_image_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Perform MMYOLO inference on large images (as satellite imagery) as:

```shell
wget -P checkpoint https://download.openmmlab.com/mmyolo/v0/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth syncbn_fast_8xb16-300e_coco/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth syncbn_fast_8xb16-300e_coco/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth # noqa: E501, E261.

python demo/large_image_demo.py \
demo/large_image.jpg \
configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \
checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth \
```
"""

import os
from argparse import ArgumentParser

import mmcv
from mmdet.apis import inference_detector, init_detector
from mmengine.logging import print_log
from mmengine.utils import ProgressBar
from sahi.slicing import slice_image
fcakyon marked this conversation as resolved.
Show resolved Hide resolved

from mmyolo.registry import VISUALIZERS
from mmyolo.utils import register_all_modules, switch_to_deploy
from mmyolo.utils.large_image import merge_results_by_nms
from mmyolo.utils.misc import get_file_list

fcakyon marked this conversation as resolved.
Show resolved Hide resolved

def parse_args():
parser = ArgumentParser(
description='Perform MMYOLO inference on large images.')
parser.add_argument(
'img', help='Image path, include image file, dir and URL.')
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument(
'--out-dir', default='./output', help='Path to output file')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--show', action='store_true', help='Show the detection results')
parser.add_argument(
'--deploy',
action='store_true',
help='Switch model to deployment mode')
parser.add_argument(
'--score-thr', type=float, default=0.3, help='Bbox score threshold')
parser.add_argument(
'--patch-size', type=int, default=640, help='The size of patches')
parser.add_argument(
'--patch-overlap-ratio',
type=int,
default=0.25,
help='Ratio of overlap between two patches')
parser.add_argument(
'--merge-iou-thr',
type=float,
default=0.25,
help='IoU threshould for merging results')
parser.add_argument(
'--merge-nms-type',
type=str,
default='nms',
help='NMS type for merging results')
parser.add_argument(
'--batch-size',
type=int,
default=1,
help='Batch size, must greater than or equal to 1')
parser.add_argument(
'--debug',
action='store_true',
help='Export debug images at each stage for 1 input')
args = parser.parse_args()
return args


def main():
args = parse_args()

# register all modules in mmdet into the registries
register_all_modules()

# build the model from a config file and a checkpoint file
model = init_detector(args.config, args.checkpoint, device=args.device)

if args.deploy:
switch_to_deploy(model)

if not os.path.exists(args.out_dir) and not args.show:
os.mkdir(args.out_dir)

# init visualizer
visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = model.dataset_meta

# get file list
files, source_type = get_file_list(args.img)

# if debug, only process the first file
if args.debug:
files = files[:1]
hhaAndroid marked this conversation as resolved.
Show resolved Hide resolved

# start detector inference
print(f'Performing inference on {len(files)} images... \
This may take a while.')
progress_bar = ProgressBar(len(files))
for file in files:
# read image
img = mmcv.imread(file)

# arrange slices
height, width = img.shape[:2]
sliced_image_object = slice_image(
img,
slice_height=args.patch_size,
slice_width=args.patch_size,
auto_slice_resolution=False,
overlap_height_ratio=args.patch_overlap_ratio,
overlap_width_ratio=args.patch_overlap_ratio,
)

# perform sliced inference
slice_results = []
start = 0
while True:
# prepare batch slices
end = min(start + args.batch_size, len(sliced_image_object))
images = []
for sliced_image in sliced_image_object.images[start:end]:
images.append(sliced_image)

# forward the model
slice_results.extend(inference_detector(model, images))

if end >= len(sliced_image_object):
break
start += args.batch_size

if source_type['is_dir']:
filename = os.path.relpath(file, args.img).replace('/', '_')
else:
filename = os.path.basename(file)

# export debug images
if args.debug:
# export sliced images
for i, image in enumerate(sliced_image_object.images):
hhaAndroid marked this conversation as resolved.
Show resolved Hide resolved
image = mmcv.imconvert(image, 'bgr', 'rgb')
out_file = os.path.join(args.out_dir, 'sliced_images',
filename + f'_slice_{i}.jpg')

mmcv.imwrite(image, out_file)

# export sliced image results
for i, slice_result in enumerate(slice_results):
out_file = os.path.join(args.out_dir, 'sliced_image_results',
filename + f'_slice_{i}_result.jpg')
image = mmcv.imconvert(sliced_image_object.images[i], 'bgr',
'rgb')

visualizer.add_datasample(
os.path.basename(out_file),
image,
data_sample=slice_result,
draw_gt=False,
show=args.show,
wait_time=0,
out_file=out_file,
pred_score_thr=args.score_thr,
)

image_result = merge_results_by_nms(
slice_results,
sliced_image_object.starting_pixels,
src_image_shape=(height, width),
nms_cfg={
'type': args.merge_nms_type,
'iou_thr': args.merge_iou_thr
})

img = mmcv.imconvert(img, 'bgr', 'rgb')
out_file = None if args.show else os.path.join(args.out_dir, filename)

visualizer.add_datasample(
os.path.basename(out_file),
img,
data_sample=image_result,
draw_gt=False,
show=args.show,
wait_time=0,
out_file=out_file,
pred_score_thr=args.score_thr,
)
progress_bar.update()

if not args.show:
print_log(
f'\nResults have been saved at {os.path.abspath(args.out_dir)}')


if __name__ == '__main__':
main()
53 changes: 53 additions & 0 deletions docs/en/user_guides/useful_tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,59 @@ python tools/analysis_tools/optimize_anchors.py ${CONFIG} \
--output-dir ${OUTPUT_DIR}
```

## Perform inference on large images

First install [`sahi`](https://github.com/obss/sahi) with:

```shell
pip install -U sahi>=0.11.4
```

Perform MMYOLO inference on large images (as satellite imagery) as:

```shell
wget -P checkpoint https://download.openmmlab.com/mmyolo/v0/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth

python demo/large_image_demo.py \
demo/large_image.jpg \
configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \
checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth \
```

Arrange slicing parameters as:

```shell
python demo/large_image_demo.py \
demo/large_image.jpg \
configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \
checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth \
--patch-size 512
--patch-overlap-ratio 0.25
```

Export debug visuals while performing inference on large images as:

```shell
python demo/large_image_demo.py \
demo/large_image.jpg \
configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \
checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth \
--debug
```

[`sahi`](https://github.com/obss/sahi) citation:

```
@article{akyon2022sahi,
title={Slicing Aided Hyper Inference and Fine-tuning for Small Object Detection},
author={Akyon, Fatih Cagatay and Altinuc, Sinan Onur and Temizel, Alptekin},
journal={2022 IEEE International Conference on Image Processing (ICIP)},
doi={10.1109/ICIP46576.2022.9897990},
pages={966-970},
year={2022}
}
```

## Extracts a subset of COCO

The training dataset of the COCO2017 dataset includes 118K images, and the validation set includes 5K images, which is a relatively large dataset. Loading JSON in debugging or quick verification scenarios will consume more resources and bring slower startup speed.
Expand Down
74 changes: 74 additions & 0 deletions mmyolo/utils/large_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Tuple

from mmcv.ops import batched_nms
from mmdet.structures import DetDataSample, SampleList
from mmengine.structures import InstanceData


def shift_predictions(det_data_samples: SampleList,
offsets: Sequence[Tuple[int, int]],
src_image_shape: Tuple[int, int]) -> SampleList:
"""Shift predictions to the original image.

Args:
det_data_samples (List[:obj:`DetDataSample`]): A list of patch results.
offsets (Sequence[Tuple[int, int]]): Positions of the left top points
of patches.
src_image_shape (Tuple[int, int]): A (height, width) tuple of the large
image's width and height.
Returns:
(List[:obj:`DetDataSample`]): shifted results.
"""
from sahi.slicing import shift_bboxes, shift_masks
fcakyon marked this conversation as resolved.
Show resolved Hide resolved

assert len(det_data_samples) == len(
offsets), 'The `results` should has the ' 'same length with `offsets`.'
shifted_predictions = []
for det_data_sample, offset in zip(det_data_samples, offsets):
pred_inst = det_data_sample.pred_instances.clone()

# shift bboxes and masks
pred_inst.bboxes = shift_bboxes(pred_inst.bboxes, offset)
if 'masks' in det_data_sample:
pred_inst.masks = shift_masks(pred_inst.masks, offset,
src_image_shape)

shifted_predictions.append(pred_inst.clone())

shifted_predictions = InstanceData.cat(shifted_predictions)

return shifted_predictions


def merge_results_by_nms(results: SampleList, offsets: Sequence[Tuple[int,
int]],
src_image_shape: Tuple[int, int],
nms_cfg: dict) -> DetDataSample:
"""Merge patch results by nms.

Args:
results (List[:obj:`DetDataSample`]): A list of patch results.
offsets (Sequence[Tuple[int, int]]): Positions of the left top points
of patches.
src_image_shape (Tuple[int, int]): A (height, width) tuple of the large
image's width and height.
nms_cfg (dict): it should specify nms type and other parameters
like `iou_threshold`.
Returns:
:obj:`DetDataSample`: merged results.
"""
shifted_instances = shift_predictions(results, offsets, src_image_shape)

_, keeps = batched_nms(
boxes=shifted_instances.bboxes,
scores=shifted_instances.scores,
idxs=shifted_instances.labels,
nms_cfg=nms_cfg)
merged_instances = shifted_instances[keeps]

merged_result = results[0].clone()
# update items like gt_instances, ignore_instances
merged_result.update(results[0])
fcakyon marked this conversation as resolved.
Show resolved Hide resolved
merged_result.pred_instances = merged_instances
return merged_result
1 change: 1 addition & 0 deletions requirements/sahi.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
sahi>=0.11.4