Commit
[CodeCamp2023-474] Add new configuration files for DINO algorithm in mmdetection. (open-mmlab#10901)
1 parent 769c810 · commit ed65c3b
Showing 7 changed files with 290 additions and 1 deletion.
190 changes: 190 additions & 0 deletions
mmdet/configs/dino/dino_4scale_r50_8xb2_12e_coco.py
@@ -0,0 +1,190 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.transforms import RandomChoice, RandomChoiceResize
from mmcv.transforms.loading import LoadImageFromFile
from mmengine.config import read_base
from mmengine.model.weight_init import PretrainedInit
from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper
from mmengine.optim.scheduler.lr_scheduler import MultiStepLR
from mmengine.runner.loops import EpochBasedTrainLoop, TestLoop, ValLoop
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.normalization import GroupNorm
from torch.optim.adamw import AdamW

from mmdet.datasets.transforms import (LoadAnnotations, PackDetInputs,
                                       RandomCrop, RandomFlip, Resize)
from mmdet.models import (DINO, ChannelMapper, DetDataPreprocessor, DINOHead,
                          ResNet)
from mmdet.models.losses.focal_loss import FocalLoss
from mmdet.models.losses.iou_loss import GIoULoss
from mmdet.models.losses.smooth_l1_loss import L1Loss
from mmdet.models.task_modules import (BBoxL1Cost, FocalLossCost,
                                       HungarianAssigner, IoUCost)

with read_base():
    from .._base_.datasets.coco_detection import *
    from .._base_.default_runtime import *

model = dict(
    type=DINO,
    num_queries=900,  # num_matching_queries
    with_box_refine=True,
    as_two_stage=True,
    data_preprocessor=dict(
        type=DetDataPreprocessor,
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_size_divisor=1),
    backbone=dict(
        type=ResNet,
        depth=50,
        num_stages=4,
        out_indices=(1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type=BatchNorm2d, requires_grad=False),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(
            type=PretrainedInit, checkpoint='torchvision://resnet50')),
    neck=dict(
        type=ChannelMapper,
        in_channels=[512, 1024, 2048],
        kernel_size=1,
        out_channels=256,
        act_cfg=None,
        norm_cfg=dict(type=GroupNorm, num_groups=32),
        num_outs=4),
    encoder=dict(
        num_layers=6,
        layer_cfg=dict(
            self_attn_cfg=dict(embed_dims=256, num_levels=4,
                               dropout=0.0),  # 0.1 for DeformDETR
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=2048,  # 1024 for DeformDETR
                ffn_drop=0.0))),  # 0.1 for DeformDETR
    decoder=dict(
        num_layers=6,
        return_intermediate=True,
        layer_cfg=dict(
            self_attn_cfg=dict(embed_dims=256, num_heads=8,
                               dropout=0.0),  # 0.1 for DeformDETR
            cross_attn_cfg=dict(embed_dims=256, num_levels=4,
                                dropout=0.0),  # 0.1 for DeformDETR
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=2048,  # 1024 for DeformDETR
                ffn_drop=0.0)),  # 0.1 for DeformDETR
        post_norm_cfg=None),
    positional_encoding=dict(
        num_feats=128,
        normalize=True,
        offset=0.0,  # -0.5 for DeformDETR
        temperature=20),  # 10000 for DeformDETR
    bbox_head=dict(
        type=DINOHead,
        num_classes=80,
        sync_cls_avg_factor=True,
        loss_cls=dict(
            type=FocalLoss,
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),  # 2.0 in DeformDETR
        loss_bbox=dict(type=L1Loss, loss_weight=5.0),
        loss_iou=dict(type=GIoULoss, loss_weight=2.0)),
    dn_cfg=dict(  # TODO: Move to model.train_cfg ?
        label_noise_scale=0.5,
        box_noise_scale=1.0,  # 0.4 for DN-DETR
        group_cfg=dict(dynamic=True, num_groups=None,
                       num_dn_queries=100)),  # TODO: half num_dn_queries
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type=HungarianAssigner,
            match_costs=[
                dict(type=FocalLossCost, weight=2.0),
                dict(type=BBoxL1Cost, weight=5.0, box_format='xywh'),
                dict(type=IoUCost, iou_mode='giou', weight=2.0)
            ])),
    test_cfg=dict(max_per_img=300))  # 100 for DeformDETR

# train_pipeline, NOTE the img_scale and the Pad's size_divisor are different
# from the default settings in mmdet.
train_pipeline = [
    dict(type=LoadImageFromFile, backend_args=backend_args),
    dict(type=LoadAnnotations, with_bbox=True),
    dict(type=RandomFlip, prob=0.5),
    dict(
        type=RandomChoice,
        transforms=[
            [
                dict(
                    type=RandomChoiceResize,
                    resize_type=Resize,
                    scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
                            (608, 1333), (640, 1333), (672, 1333), (704, 1333),
                            (736, 1333), (768, 1333), (800, 1333)],
                    keep_ratio=True)
            ],
            [
                dict(
                    type=RandomChoiceResize,
                    resize_type=Resize,
                    # The aspect ratio of every image in the train dataset
                    # is < 7; follow the original implementation.
                    scales=[(400, 4200), (500, 4200), (600, 4200)],
                    keep_ratio=True),
                dict(
                    type=RandomCrop,
                    crop_type='absolute_range',
                    crop_size=(384, 600),
                    allow_negative_crop=True),
                dict(
                    type=RandomChoiceResize,
                    resize_type=Resize,
                    scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
                            (608, 1333), (640, 1333), (672, 1333), (704, 1333),
                            (736, 1333), (768, 1333), (800, 1333)],
                    keep_ratio=True)
            ]
        ]),
    dict(type=PackDetInputs)
]
train_dataloader.update(
    dataset=dict(
        filter_cfg=dict(filter_empty_gt=False), pipeline=train_pipeline))

# optimizer
optim_wrapper = dict(
    type=OptimWrapper,
    optimizer=dict(
        type=AdamW,
        lr=0.0001,  # 0.0002 for DeformDETR
        weight_decay=0.0001),
    clip_grad=dict(max_norm=0.1, norm_type=2),
    paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.1)})
)  # custom_keys contains sampling_offsets and reference_points in DeformDETR  # noqa

# learning policy
max_epochs = 12
train_cfg = dict(
    type=EpochBasedTrainLoop, max_epochs=max_epochs, val_interval=1)

val_cfg = dict(type=ValLoop)
test_cfg = dict(type=TestLoop)

param_scheduler = [
    dict(
        type=MultiStepLR,
        begin=0,
        end=max_epochs,
        by_epoch=True,
        milestones=[11],
        gamma=0.1)
]

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (2 samples per GPU)
auto_scale_lr = dict(base_batch_size=16)
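
The file above is a complete, self-contained Python config: every `type=` value is a real class object rather than a registry string. A minimal, hedged sketch of how such a config could be loaded and run via mmengine's Runner, assuming an mmdetection source checkout (the `work_dir` value is an arbitrary assumption, not part of this commit):

# Hedged sketch, not part of this commit: load the new-style config and train.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('mmdet/configs/dino/dino_4scale_r50_8xb2_12e_coco.py')
cfg.work_dir = './work_dirs/dino_4scale_r50_8xb2_12e_coco'  # assumed output dir
runner = Runner.from_cfg(cfg)  # builds model, dataloaders, and loops from cfg
runner.train()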
12 changes: 12 additions & 0 deletions
mmdet/configs/dino/dino_4scale_r50_8xb2_24e_coco.py
@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base
from mmengine.runner.loops import EpochBasedTrainLoop

with read_base():
    from .dino_4scale_r50_8xb2_12e_coco import *

max_epochs = 24
train_cfg.update(
    dict(type=EpochBasedTrainLoop, max_epochs=max_epochs, val_interval=1))

param_scheduler[0].update(dict(milestones=[20]))
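
For readers new to the `read_base()` style used above: the starred import pulls every top-level variable of the base config into this module, and `update()` then mutates the inherited dicts in place. A hedged sketch of inspecting the merged result (paths assume an mmdetection checkout):

from mmengine.config import Config

cfg = Config.fromfile('mmdet/configs/dino/dino_4scale_r50_8xb2_24e_coco.py')
print(cfg.train_cfg.max_epochs)           # 24, overriding the base's 12
print(cfg.param_scheduler[0].milestones)  # [20], overriding the base's [11]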
12 changes: 12 additions & 0 deletions
mmdet/configs/dino/dino_4scale_r50_8xb2_36e_coco.py
@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base
from mmengine.runner.loops import EpochBasedTrainLoop

with read_base():
    from .dino_4scale_r50_8xb2_12e_coco import *

max_epochs = 36
train_cfg.update(
    dict(type=EpochBasedTrainLoop, max_epochs=max_epochs, val_interval=1))

param_scheduler[0].update(dict(milestones=[30]))
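
A subtlety these schedule variants rely on: simply re-binding `max_epochs` would change nothing, because the base config already baked the old value into `train_cfg` and `param_scheduler`. That is why each variant calls `update()` on the inherited dicts. A toy demonstration in plain Python:

max_epochs = 12
train_cfg = dict(max_epochs=max_epochs)  # the value is copied in, not linked

max_epochs = 36                 # re-binding the name...
print(train_cfg['max_epochs'])  # ...still prints 12

train_cfg.update(max_epochs=max_epochs)
print(train_cfg['max_epochs'])  # 36 -- hence the explicit update() calls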
24 changes: 24 additions & 0 deletions
mmdet/configs/dino/dino_4scale_r50_improved_8xb2_12e_coco.py
@@ -0,0 +1,24 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base

with read_base():
    from .dino_4scale_r50_8xb2_12e_coco import *

# hyper-parameters taken from Deformable DETR
model.update(
    dict(
        backbone=dict(frozen_stages=-1),
        bbox_head=dict(loss_cls=dict(loss_weight=2.0)),
        positional_encoding=dict(offset=-0.5, temperature=10000),
        dn_cfg=dict(group_cfg=dict(num_dn_queries=300))))

# optimizer
optim_wrapper.update(
    dict(
        optimizer=dict(lr=0.0002),
        paramwise_cfg=dict(
            custom_keys={
                'backbone': dict(lr_mult=0.1),
                'sampling_offsets': dict(lr_mult=0.1),
                'reference_points': dict(lr_mult=0.1)
            })))
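
The `custom_keys` entries scale the learning rate of every parameter whose name contains the given substring. A hedged sketch that approximates mmengine's `DefaultOptimWrapperConstructor` matching (the parameter names below are illustrative, not taken from this commit):

base_lr = 2e-4  # optimizer lr above
custom_keys = {
    'backbone': dict(lr_mult=0.1),
    'sampling_offsets': dict(lr_mult=0.1),
    'reference_points': dict(lr_mult=0.1),
}

def effective_lr(param_name: str) -> float:
    # Substring match, roughly as mmengine applies custom_keys.
    for key, spec in custom_keys.items():
        if key in param_name:
            return base_lr * spec['lr_mult']
    return base_lr

print(effective_lr('backbone.layer1.0.conv1.weight'))   # ≈ 2e-05
print(effective_lr('bbox_head.cls_branches.0.weight'))  # 0.0002 (no match)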
40 changes: 40 additions & 0 deletions
mmdet/configs/dino/dino_5scale_swin_l_8xb2_12e_coco.py
@@ -0,0 +1,40 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base
from mmengine.model.weight_init import PretrainedInit

from mmdet.models import SwinTransformer

with read_base():
    from .dino_4scale_r50_8xb2_12e_coco import *

pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth'  # noqa
num_levels = 5
model.merge(
    dict(
        num_feature_levels=num_levels,
        backbone=dict(
            _delete_=True,
            type=SwinTransformer,
            pretrain_img_size=384,
            embed_dims=192,
            depths=[2, 2, 18, 2],
            num_heads=[6, 12, 24, 48],
            window_size=12,
            mlp_ratio=4,
            qkv_bias=True,
            qk_scale=None,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.2,
            patch_norm=True,
            out_indices=(0, 1, 2, 3),
            # Please only add indices that would be used
            # in FPN, otherwise some parameters will not be used
            with_cp=True,
            convert_weights=True,
            init_cfg=dict(type=PretrainedInit, checkpoint=pretrained)),
        neck=dict(in_channels=[192, 384, 768, 1536], num_outs=num_levels),
        encoder=dict(
            layer_cfg=dict(self_attn_cfg=dict(num_levels=num_levels))),
        decoder=dict(
            layer_cfg=dict(cross_attn_cfg=dict(num_levels=num_levels))))))
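
Note the `_delete_=True` key in the backbone dict: without it, `merge()` would recursively combine the Swin settings with the inherited ResNet backbone and leave stale keys such as `depth=50` behind. A toy illustration of the two behaviours with plain dicts (mmengine's Config implements this logic internally; this is not its actual code):

def merge_cfg(base: dict, override: dict) -> dict:
    override = dict(override)          # avoid mutating the caller's dict
    if override.pop('_delete_', False):
        return override                # replace the base dict wholesale
    merged = dict(base)
    for k, v in override.items():
        if isinstance(v, dict) and isinstance(merged.get(k), dict):
            merged[k] = merge_cfg(merged[k], v)
        else:
            merged[k] = v
    return merged

base_backbone = {'type': 'ResNet', 'depth': 50, 'frozen_stages': 1}
swin = {'_delete_': True, 'type': 'SwinTransformer', 'embed_dims': 192}
print(merge_cfg(base_backbone, swin))
# {'type': 'SwinTransformer', 'embed_dims': 192} -- no leftover ResNet keys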
12 changes: 12 additions & 0 deletions
mmdet/configs/dino/dino_5scale_swin_l_8xb2_36e_coco.py
@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base
from mmengine.runner.loops import EpochBasedTrainLoop

with read_base():
    from .dino_5scale_swin_l_8xb2_12e_coco import *

max_epochs = 36
train_cfg.update(
    dict(type=EpochBasedTrainLoop, max_epochs=max_epochs, val_interval=1))

param_scheduler[0].update(dict(milestones=[27, 33]))
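
With `MultiStepLR(milestones=[27, 33], gamma=0.1)`, the learning rate is multiplied by 0.1 at epochs 27 and 33. A hedged sketch of the resulting schedule for non-backbone parameters (base LR 1e-4 inherited from the 12-epoch config; values approximate up to float rounding):

base_lr, gamma, milestones = 1e-4, 0.1, [27, 33]

def lr_at_epoch(epoch: int) -> float:
    # The decay factor is gamma raised to the number of milestones passed.
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** passed

print(lr_at_epoch(0))   # 1e-4 for epochs 0-26
print(lr_at_epoch(27))  # ~1e-5 for epochs 27-32
print(lr_at_epoch(33))  # ~1e-6 for epochs 33-35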