Skip to content

Commit

Permalink
[CodeCamp2023-603] Add new configuration files for MaskFormer algorit…
Browse files Browse the repository at this point in the history
…hm in mmdetection
  • Loading branch information
SimonGuoNjust authored Oct 9, 2023
1 parent 8cc950c commit d84ea9b
Show file tree
Hide file tree
Showing 2 changed files with 331 additions and 0 deletions.
249 changes: 249 additions & 0 deletions mmdet/configs/maskformer/maskformer_r50_ms_16xb1_75e_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.transforms import RandomChoice, RandomChoiceResize
from mmengine.config import read_base
from mmengine.model.weight_init import PretrainedInit
from mmengine.optim.optimizer import OptimWrapper
from mmengine.optim.scheduler import MultiStepLR
from mmengine.runner import EpochBasedTrainLoop, TestLoop, ValLoop
from torch.nn.modules.activation import ReLU
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.normalization import GroupNorm
from torch.optim.adamw import AdamW

from mmdet.datasets.transforms.transforms import RandomCrop
from mmdet.models import MaskFormer
from mmdet.models.backbones import ResNet
from mmdet.models.data_preprocessors.data_preprocessor import \
DetDataPreprocessor
from mmdet.models.dense_heads.maskformer_head import MaskFormerHead
from mmdet.models.layers.pixel_decoder import TransformerEncoderPixelDecoder
from mmdet.models.losses import CrossEntropyLoss, DiceLoss, FocalLoss
from mmdet.models.seg_heads.panoptic_fusion_heads import MaskFormerFusionHead
from mmdet.models.task_modules.assigners.hungarian_assigner import \
HungarianAssigner
from mmdet.models.task_modules.assigners.match_cost import (ClassificationCost,
DiceCost,
FocalLossCost)
from mmdet.models.task_modules.samplers import MaskPseudoSampler

with read_base():
from .._base_.datasets.coco_panoptic import *
from .._base_.default_runtime import *

data_preprocessor = dict(
type=DetDataPreprocessor,
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_size_divisor=1,
pad_mask=True,
mask_pad_value=0,
pad_seg=True,
seg_pad_value=255)

num_things_classes = 80
num_stuff_classes = 53
num_classes = num_things_classes + num_stuff_classes
model = dict(
type=MaskFormer,
data_preprocessor=data_preprocessor,
backbone=dict(
type=ResNet,
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
norm_cfg=dict(type=BatchNorm2d, requires_grad=False),
norm_eval=True,
style='pytorch',
init_cfg=dict(
type=PretrainedInit, checkpoint='torchvision://resnet50')),
panoptic_head=dict(
type=MaskFormerHead,
in_channels=[256, 512, 1024, 2048], # pass to pixel_decoder inside
feat_channels=256,
out_channels=256,
num_things_classes=num_things_classes,
num_stuff_classes=num_stuff_classes,
num_queries=100,
pixel_decoder=dict(
type=TransformerEncoderPixelDecoder,
norm_cfg=dict(type=GroupNorm, num_groups=32),
act_cfg=dict(type=ReLU),
encoder=dict( # DetrTransformerEncoder
num_layers=6,
layer_cfg=dict( # DetrTransformerEncoderLayer
self_attn_cfg=dict( # MultiheadAttention
embed_dims=256,
num_heads=8,
dropout=0.1,
batch_first=True),
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
ffn_drop=0.1,
act_cfg=dict(type=ReLU, inplace=True)))),
positional_encoding=dict(num_feats=128, normalize=True)),
enforce_decoder_input_project=False,
positional_encoding=dict(num_feats=128, normalize=True),
transformer_decoder=dict( # DetrTransformerDecoder
num_layers=6,
layer_cfg=dict( # DetrTransformerDecoderLayer
self_attn_cfg=dict( # MultiheadAttention
embed_dims=256,
num_heads=8,
dropout=0.1,
batch_first=True),
cross_attn_cfg=dict( # MultiheadAttention
embed_dims=256,
num_heads=8,
dropout=0.1,
batch_first=True),
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
ffn_drop=0.1,
act_cfg=dict(type=ReLU, inplace=True))),
return_intermediate=True),
loss_cls=dict(
type=CrossEntropyLoss,
use_sigmoid=False,
loss_weight=1.0,
reduction='mean',
class_weight=[1.0] * num_classes + [0.1]),
loss_mask=dict(
type=FocalLoss,
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
reduction='mean',
loss_weight=20.0),
loss_dice=dict(
type=DiceLoss,
use_sigmoid=True,
activate=True,
reduction='mean',
naive_dice=True,
eps=1.0,
loss_weight=1.0)),
panoptic_fusion_head=dict(
type=MaskFormerFusionHead,
num_things_classes=num_things_classes,
num_stuff_classes=num_stuff_classes,
loss_panoptic=None,
init_cfg=None),
train_cfg=dict(
assigner=dict(
type=HungarianAssigner,
match_costs=[
dict(type=ClassificationCost, weight=1.0),
dict(type=FocalLossCost, weight=20.0, binary_input=True),
dict(type=DiceCost, weight=1.0, pred_act=True, eps=1.0)
]),
sampler=dict(type=MaskPseudoSampler)),
test_cfg=dict(
panoptic_on=True,
# For now, the dataset does not support
# evaluating semantic segmentation metric.
semantic_on=False,
instance_on=False,
# max_per_image is for instance segmentation.
max_per_image=100,
object_mask_thr=0.8,
iou_thr=0.8,
# In MaskFormer's panoptic postprocessing,
# it will not filter masks whose score is smaller than 0.5 .
filter_low_score=False),
init_cfg=None)

# dataset settings
train_pipeline = [
dict(type=LoadImageFromFile),
dict(
type=LoadPanopticAnnotations,
with_bbox=True,
with_mask=True,
with_seg=True),
dict(type=RandomFlip, prob=0.5),
# dict(type=Resize, scale=(1333, 800), keep_ratio=True),
dict(
type=RandomChoice,
transforms=[[
dict(
type=RandomChoiceResize,
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
resize_type=Resize,
keep_ratio=True)
],
[
dict(
type=RandomChoiceResize,
scales=[(400, 1333), (500, 1333), (600, 1333)],
resize_type=Resize,
keep_ratio=True),
dict(
type=RandomCrop,
crop_type='absolute_range',
crop_size=(384, 600),
allow_negative_crop=True),
dict(
type=RandomChoiceResize,
scales=[(480, 1333), (512, 1333), (544, 1333),
(576, 1333), (608, 1333), (640, 1333),
(672, 1333), (704, 1333), (736, 1333),
(768, 1333), (800, 1333)],
resize_type=Resize,
keep_ratio=True)
]]),
dict(type=PackDetInputs)
]

train_dataloader.update(
dict(batch_size=1, num_workers=1, dataset=dict(pipeline=train_pipeline)))

val_dataloader.update(dict(batch_size=1, num_workers=1))

test_dataloader = val_dataloader

# optimizer
optim_wrapper = dict(
type=OptimWrapper,
optimizer=dict(
type=AdamW,
lr=0.0001,
weight_decay=0.0001,
eps=1e-8,
betas=(0.9, 0.999)),
paramwise_cfg=dict(
custom_keys={
'backbone': dict(lr_mult=0.1, decay_mult=1.0),
'query_embed': dict(lr_mult=1.0, decay_mult=0.0)
},
norm_decay_mult=0.0),
clip_grad=dict(max_norm=0.01, norm_type=2))

max_epochs = 75

# learning rate
param_scheduler = dict(
type=MultiStepLR,
begin=0,
end=max_epochs,
by_epoch=True,
milestones=[50],
gamma=0.1)

train_cfg = dict(
type=EpochBasedTrainLoop, max_epochs=max_epochs, val_interval=1)
val_cfg = dict(type=ValLoop)
test_cfg = dict(type=TestLoop)

# Default setting for scaling LR automatically
# - `enable` means enable scaling LR automatically
# or not by default.
# - `base_batch_size` = (16 GPUs) x (1 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=16)
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base
from mmengine.optim.scheduler import LinearLR

from mmdet.models.backbones import SwinTransformer
from mmdet.models.layers import PixelDecoder

with read_base():
from .maskformer_r50_ms_16xb1_75e_coco import *

pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth' # noqa
depths = [2, 2, 18, 2]
model.update(
dict(
backbone=dict(
_delete_=True,
type=SwinTransformer,
pretrain_img_size=384,
embed_dims=192,
patch_size=4,
window_size=12,
mlp_ratio=4,
depths=depths,
num_heads=[6, 12, 24, 48],
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.3,
patch_norm=True,
out_indices=(0, 1, 2, 3),
with_cp=False,
convert_weights=True,
init_cfg=dict(type=PretrainedInit, checkpoint=pretrained)),
panoptic_head=dict(
in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside
pixel_decoder=dict(
_delete_=True,
type=PixelDecoder,
norm_cfg=dict(type=GroupNorm, num_groups=32),
act_cfg=dict(type=ReLU)),
enforce_decoder_input_project=True)))

# optimizer

# weight_decay = 0.01
# norm_weight_decay = 0.0
# embed_weight_decay = 0.0
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
norm_multi = dict(lr_mult=1.0, decay_mult=0.0)
custom_keys = {
'norm': norm_multi,
'absolute_pos_embed': embed_multi,
'relative_position_bias_table': embed_multi,
'query_embed': embed_multi
}

optim_wrapper.update(
dict(
optimizer=dict(lr=6e-5, weight_decay=0.01),
paramwise_cfg=dict(custom_keys=custom_keys, norm_decay_mult=0.0)))

max_epochs = 300

# learning rate
param_scheduler = [
dict(type=LinearLR, start_factor=1e-6, by_epoch=False, begin=0, end=1500),
dict(
type=MultiStepLR,
begin=0,
end=max_epochs,
by_epoch=True,
milestones=[250],
gamma=0.1)
]

train_cfg.update(dict(max_epochs=max_epochs))

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (64 GPUs) x (1 samples per GPU)
auto_scale_lr.update(dict(base_batch_size=64))

0 comments on commit d84ea9b

Please sign in to comment.