from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import six
import os
import os.path as osp
import copy
from ast import literal_eval
import numpy as np
from packaging import version
import torch
import torch.nn as nn
from torch.nn import init
import yaml
import nn as mynn
from utils.collections import AttrDict
__C = AttrDict()
# Consumers can get config by:
# from fast_rcnn_config import cfg
cfg = __C
# Random note: avoid using '.ON' as a config key since yaml converts it to True;
# prefer 'ENABLED' instead
# ---------------------------------------------------------------------------- #
# Training options
# ---------------------------------------------------------------------------- #
__C.TRAIN = AttrDict()
# Datasets to train on
# Available dataset list: datasets.dataset_catalog.DATASETS.keys()
# If multiple datasets are listed, the model is trained on their union
__C.TRAIN.DATASETS = ()
# Scales to use during training
# Each scale is the pixel size of an image's shortest side
# If multiple scales are listed, then one is selected uniformly at random for
# each training image (i.e., scale jitter data augmentation)
__C.TRAIN.SCALES = (600, )
# Max pixel size of the longest side of a scaled input image
__C.TRAIN.MAX_SIZE = 1000
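# A rough sketch of how a scale from TRAIN.SCALES interacts with TRAIN.MAX_SIZE
# when resizing an input image (illustrative only; the exact logic lives in the
# image preprocessing utilities):
#   im_scale = target_size / min(im_h, im_w)
#   if round(im_scale * max(im_h, im_w)) > TRAIN.MAX_SIZE:
#       im_scale = TRAIN.MAX_SIZE / max(im_h, im_w)   # cap the longest side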
# Images *per GPU* in the training minibatch
# Total images per minibatch = TRAIN.IMS_PER_BATCH * NUM_GPUS
__C.TRAIN.IMS_PER_BATCH = 2
# RoI minibatch size *per image* (number of regions of interest [ROIs])
# Total number of RoIs per training minibatch =
# TRAIN.BATCH_SIZE_PER_IM * TRAIN.IMS_PER_BATCH * NUM_GPUS
# E.g., a common configuration is: 512 * 2 * 8 = 8192
__C.TRAIN.BATCH_SIZE_PER_IM = 64
# Fraction of minibatch that is labeled foreground (i.e. class > 0)
__C.TRAIN.FG_FRACTION = 0.25
# Overlap threshold for a ROI to be considered foreground (if >= FG_THRESH)
__C.TRAIN.FG_THRESH = 0.5
# Overlap threshold for a ROI to be considered background (class = 0 if
# overlap in [LO, HI))
__C.TRAIN.BG_THRESH_HI = 0.5
__C.TRAIN.BG_THRESH_LO = 0.0
# Use horizontally-flipped images during training?
__C.TRAIN.USE_FLIPPED = True
# Overlap required between a ROI and ground-truth box in order for that ROI to
# be used as a bounding-box regression training example
__C.TRAIN.BBOX_THRESH = 0.5
# Train using these proposals
# During training, all proposals specified in the file are used (no limit is
# applied)
# Proposal files must be in correspondence with the datasets listed in
# TRAIN.DATASETS
__C.TRAIN.PROPOSAL_FILES = ()
# Snapshot (model checkpoint) period
# Divide by NUM_GPUS to determine actual period (e.g., 20000/8 => 2500 iters)
# to allow for linear training schedule scaling
__C.TRAIN.SNAPSHOT_ITERS = 20000
# Normalize the targets (subtract empirical mean, divide by empirical stddev)
__C.TRAIN.BBOX_NORMALIZE_TARGETS = True
# Deprecated (inside weights)
__C.TRAIN.BBOX_INSIDE_WEIGHTS = (1.0, 1.0, 1.0, 1.0)
# Normalize the targets using "precomputed" (or made up) means and stdevs
# (BBOX_NORMALIZE_TARGETS must also be True) (legacy)
__C.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED = False
__C.TRAIN.BBOX_NORMALIZE_MEANS = (0.0, 0.0, 0.0, 0.0)
__C.TRAIN.BBOX_NORMALIZE_STDS = (0.1, 0.1, 0.2, 0.2)
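# When precomputed normalization is enabled, the regression targets are roughly
# standardized before the loss is computed (a sketch, not the exact code path):
#   normalized_targets = (targets - BBOX_NORMALIZE_MEANS) / BBOX_NORMALIZE_STDS
# and predictions are de-normalized with the same means/stds at inference time.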
# Make minibatches from images that have similar aspect ratios (i.e. both
# tall and thin or both short and wide)
# This feature is critical for saving memory (and makes training slightly
# faster)
__C.TRAIN.ASPECT_GROUPING = True
# Crop images that have too small or too large aspect ratio
__C.TRAIN.ASPECT_CROPPING = False
__C.TRAIN.ASPECT_HI = 2
__C.TRAIN.ASPECT_LO = 0.5
# ---------------------------------------------------------------------------- #
# RPN training options
# ---------------------------------------------------------------------------- #
# Minimum overlap required between an anchor and ground-truth box for the
# (anchor, gt box) pair to be a positive example (IOU >= thresh ==> positive RPN
# example)
__C.TRAIN.RPN_POSITIVE_OVERLAP = 0.7
# Maximum overlap allowed between an anchor and ground-truth box for the
# (anchor, gt box) pair to be a negative example (IOU < thresh ==> negative RPN
# example)
__C.TRAIN.RPN_NEGATIVE_OVERLAP = 0.3
# Target fraction of foreground (positive) examples per RPN minibatch
__C.TRAIN.RPN_FG_FRACTION = 0.5
# Total number of RPN examples per image
__C.TRAIN.RPN_BATCH_SIZE_PER_IM = 256
# NMS threshold used on RPN proposals (used during end-to-end training with RPN)
__C.TRAIN.RPN_NMS_THRESH = 0.7
# Number of top scoring RPN proposals to keep before applying NMS (per image)
# When FPN is used, this is *per FPN level* (not total)
__C.TRAIN.RPN_PRE_NMS_TOP_N = 12000
# Number of top scoring RPN proposals to keep after applying NMS (per image)
# This is the total number of RPN proposals produced (for both FPN and non-FPN
# cases)
__C.TRAIN.RPN_POST_NMS_TOP_N = 2000
# Remove RPN anchors that go outside the image by RPN_STRADDLE_THRESH pixels
# Set to -1 or a large value, e.g. 100000, to disable pruning anchors
__C.TRAIN.RPN_STRADDLE_THRESH = 0
# Proposal height and width both need to be greater than RPN_MIN_SIZE
# (at orig image scale; not scale used during training or inference)
__C.TRAIN.RPN_MIN_SIZE = 0
# Filter proposals that are inside of crowd regions by CROWD_FILTER_THRESH
# "Inside" is measured as: proposal-with-crowd intersection area divided by
# proposal area
__C.TRAIN.CROWD_FILTER_THRESH = 0.7
# Ignore ground-truth objects with area < this threshold
__C.TRAIN.GT_MIN_AREA = -1
# Freeze the backbone architecture during training if set to True
__C.TRAIN.FREEZE_CONV_BODY = False
# ---------------------------------------------------------------------------- #
# Data loader options
# ---------------------------------------------------------------------------- #
__C.DATA_LOADER = AttrDict()
# Number of Python threads to use for the data loader (warning: using too many
# threads can cause GIL-based interference with Python Ops leading to *slower*
# training; 4 seems to be the sweet spot in our experience)
__C.DATA_LOADER.NUM_THREADS = 4
# ---------------------------------------------------------------------------- #
# Inference ('test') options
# ---------------------------------------------------------------------------- #
__C.TEST = AttrDict()
# Datasets to test on
# Available dataset list: datasets.dataset_catalog.DATASETS.keys()
# If multiple datasets are listed, testing is performed on each one sequentially
__C.TEST.DATASETS = ()
# Scale to use during testing (can NOT list multiple scales)
# The scale is the pixel size of an image's shortest side
__C.TEST.SCALE = 600
# Max pixel size of the longest side of a scaled input image
__C.TEST.MAX_SIZE = 1000
# Overlap threshold used for non-maximum suppression (suppress boxes with
# IoU >= this threshold)
__C.TEST.NMS = 0.3
# Apply Fast R-CNN style bounding-box regression if True
__C.TEST.BBOX_REG = True
# Test using these proposal files (must correspond with TEST.DATASETS)
__C.TEST.PROPOSAL_FILES = ()
# Limit on the number of proposals per image used during inference
__C.TEST.PROPOSAL_LIMIT = 2000
# NMS threshold used on RPN proposals
__C.TEST.RPN_NMS_THRESH = 0.7
# Number of top scoring RPN proposals to keep before applying NMS
# When FPN is used, this is *per FPN level* (not total)
__C.TEST.RPN_PRE_NMS_TOP_N = 12000
# Number of top scoring RPN proposals to keep after applying NMS
# This is the total number of RPN proposals produced (for both FPN and non-FPN
# cases)
__C.TEST.RPN_POST_NMS_TOP_N = 2000
# Proposal height and width both need to be greater than RPN_MIN_SIZE
# (at orig image scale; not scale used during training or inference)
__C.TEST.RPN_MIN_SIZE = 0
# Maximum number of detections to return per image (100 is based on the limit
# established for the COCO dataset)
__C.TEST.DETECTIONS_PER_IM = 100
# Minimum score threshold (assuming scores in a [0, 1] range); a value chosen to
# balance obtaining high recall with not having too many low precision
# detections that will slow down inference post processing steps (like NMS)
__C.TEST.SCORE_THRESH = 0.05
# Save detection results files if True
# If false, results files are cleaned up (they can be large) after local
# evaluation
__C.TEST.COMPETITION_MODE = True
# Evaluate detections with the COCO json dataset eval code even if it's not the
# evaluation code for the dataset (e.g. evaluate PASCAL VOC results using the
# COCO API to get COCO style AP on PASCAL VOC)
__C.TEST.FORCE_JSON_DATASET_EVAL = False
# [Inferred value; do not set directly in a config]
# Indicates if precomputed proposals are used at test time
# Not set for 1-stage models and 2-stage models with RPN subnetwork enabled
__C.TEST.PRECOMPUTED_PROPOSALS = True
# ---------------------------------------------------------------------------- #
# Test-time augmentations for bounding box detection
# See configs/test_time_aug/e2e_mask_rcnn_R-50-FPN_2x.yaml for an example
# ---------------------------------------------------------------------------- #
__C.TEST.BBOX_AUG = AttrDict()
# Enable test-time augmentation for bounding box detection if True
__C.TEST.BBOX_AUG.ENABLED = False
# Heuristic used to combine predicted box scores
# Valid options: ('ID', 'AVG', 'UNION')
__C.TEST.BBOX_AUG.SCORE_HEUR = 'UNION'
# Heuristic used to combine predicted box coordinates
# Valid options: ('ID', 'AVG', 'UNION')
__C.TEST.BBOX_AUG.COORD_HEUR = 'UNION'
# Horizontal flip at the original scale (id transform)
__C.TEST.BBOX_AUG.H_FLIP = False
# Each scale is the pixel size of an image's shortest side
__C.TEST.BBOX_AUG.SCALES = ()
# Max pixel size of the longer side
__C.TEST.BBOX_AUG.MAX_SIZE = 4000
# Horizontal flip at each scale
__C.TEST.BBOX_AUG.SCALE_H_FLIP = False
# Apply scaling based on object size
__C.TEST.BBOX_AUG.SCALE_SIZE_DEP = False
__C.TEST.BBOX_AUG.AREA_TH_LO = 50**2
__C.TEST.BBOX_AUG.AREA_TH_HI = 180**2
# Each aspect ratio is relative to image width
__C.TEST.BBOX_AUG.ASPECT_RATIOS = ()
# Horizontal flip at each aspect ratio
__C.TEST.BBOX_AUG.ASPECT_RATIO_H_FLIP = False
# ---------------------------------------------------------------------------- #
# Test-time augmentations for mask detection
# See configs/test_time_aug/e2e_mask_rcnn_R-50-FPN_2x.yaml for an example
# ---------------------------------------------------------------------------- #
__C.TEST.MASK_AUG = AttrDict()
# Enable test-time augmentation for instance mask detection if True
__C.TEST.MASK_AUG.ENABLED = False
# Heuristic used to combine mask predictions
# SOFT prefix indicates that the computation is performed on soft masks
# Valid options: ('SOFT_AVG', 'SOFT_MAX', 'LOGIT_AVG')
__C.TEST.MASK_AUG.HEUR = 'SOFT_AVG'
# Horizontal flip at the original scale (id transform)
__C.TEST.MASK_AUG.H_FLIP = False
# Each scale is the pixel size of an image's shortest side
__C.TEST.MASK_AUG.SCALES = ()
# Max pixel size of the longer side
__C.TEST.MASK_AUG.MAX_SIZE = 4000
# Horizontal flip at each scale
__C.TEST.MASK_AUG.SCALE_H_FLIP = False
# Apply scaling based on object size
__C.TEST.MASK_AUG.SCALE_SIZE_DEP = False
__C.TEST.MASK_AUG.AREA_TH = 180**2
# Each aspect ratio is relative to image width
__C.TEST.MASK_AUG.ASPECT_RATIOS = ()
# Horizontal flip at each aspect ratio
__C.TEST.MASK_AUG.ASPECT_RATIO_H_FLIP = False
# ---------------------------------------------------------------------------- #
# Test-time augmentations for keypoint detection
# See configs/test_time_aug/keypoint_rcnn_R-50-FPN_1x.yaml for an example
# ---------------------------------------------------------------------------- #
__C.TEST.KPS_AUG = AttrDict()
# Enable test-time augmentation for keypoint detection if True
__C.TEST.KPS_AUG.ENABLED = False
# Heuristic used to combine keypoint predictions
# Valid options: ('HM_AVG', 'HM_MAX')
__C.TEST.KPS_AUG.HEUR = 'HM_AVG'
# Horizontal flip at the original scale (id transform)
__C.TEST.KPS_AUG.H_FLIP = False
# Each scale is the pixel size of an image's shortest side
__C.TEST.KPS_AUG.SCALES = ()
# Max pixel size of the longer side
__C.TEST.KPS_AUG.MAX_SIZE = 4000
# Horizontal flip at each scale
__C.TEST.KPS_AUG.SCALE_H_FLIP = False
# Apply scaling based on object size
__C.TEST.KPS_AUG.SCALE_SIZE_DEP = False
__C.TEST.KPS_AUG.AREA_TH = 180**2
# Each aspect ratio is relative to image width
__C.TEST.KPS_AUG.ASPECT_RATIOS = ()
# Horizontal flip at each aspect ratio
__C.TEST.KPS_AUG.ASPECT_RATIO_H_FLIP = False
# ---------------------------------------------------------------------------- #
# Soft NMS
# ---------------------------------------------------------------------------- #
__C.TEST.SOFT_NMS = AttrDict()
# Use soft NMS instead of standard NMS if set to True
__C.TEST.SOFT_NMS.ENABLED = False
# See soft NMS paper for definition of these options
__C.TEST.SOFT_NMS.METHOD = 'linear'
__C.TEST.SOFT_NMS.SIGMA = 0.5
# For the soft NMS overlap threshold, we simply use TEST.NMS
# ---------------------------------------------------------------------------- #
# Bounding box voting (from the Multi-Region CNN paper)
# ---------------------------------------------------------------------------- #
__C.TEST.BBOX_VOTE = AttrDict()
# Use box voting if set to True
__C.TEST.BBOX_VOTE.ENABLED = False
# We use TEST.NMS threshold for the NMS step. VOTE_TH overlap threshold
# is used to select voting boxes (IoU >= VOTE_TH) for each box that survives NMS
__C.TEST.BBOX_VOTE.VOTE_TH = 0.8
# The method used to combine scores when doing bounding box voting
# Valid options include ('ID', 'AVG', 'IOU_AVG', 'GENERALIZED_AVG', 'QUASI_SUM')
__C.TEST.BBOX_VOTE.SCORING_METHOD = 'ID'
# Hyperparameter used by the scoring method (it has different meanings for
# different methods)
__C.TEST.BBOX_VOTE.SCORING_METHOD_BETA = 1.0
# ---------------------------------------------------------------------------- #
# Model options
# ---------------------------------------------------------------------------- #
__C.MODEL = AttrDict()
# The type of model to use
# The string must match a function in the modeling.model_builder module
# (e.g., 'generalized_rcnn', 'mask_rcnn', ...)
__C.MODEL.TYPE = ''
# The backbone conv body to use
__C.MODEL.CONV_BODY = ''
# Number of classes in the dataset; must be set
# E.g., 81 for COCO (80 foreground + 1 background)
__C.MODEL.NUM_CLASSES = -1
# Use a class agnostic bounding box regressor instead of the default per-class
# regressor
__C.MODEL.CLS_AGNOSTIC_BBOX_REG = False
# Default weights on (dx, dy, dw, dh) for normalizing bbox regression targets
# These are empirically chosen to approximately lead to unit variance targets
#
# In older versions, the weights were set such that the regression deltas
# would have unit standard deviation on the training dataset. Presently, rather
# than computing these statistics exactly, we use a fixed set of weights
# (10., 10., 5., 5.) by default. These are approximately the weights one would
# get from COCO using the previous unit stdev heuristic.
__C.MODEL.BBOX_REG_WEIGHTS = (10., 10., 5., 5.)
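# For reference, these weights enter the usual Fast R-CNN box encoding roughly
# as follows (a sketch of the standard formulation, not necessarily the exact
# code in this repo), with (wx, wy, ww, wh) = MODEL.BBOX_REG_WEIGHTS:
#   dx = wx * (gt_ctr_x - roi_ctr_x) / roi_w
#   dy = wy * (gt_ctr_y - roi_ctr_y) / roi_h
#   dw = ww * log(gt_w / roi_w)
#   dh = wh * log(gt_h / roi_h)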
# The meaning of FASTER_RCNN depends on the context (training vs. inference):
# 1) During training, FASTER_RCNN = True means that end-to-end training will be
# used to jointly train the RPN subnetwork and the Fast R-CNN subnetwork
# (Faster R-CNN = RPN + Fast R-CNN).
# 2) During inference, FASTER_RCNN = True means that the model's RPN subnetwork
# will be used to generate proposals rather than relying on precomputed
# proposals. Note that FASTER_RCNN = True can be used at inference time even
# if the Faster R-CNN model was trained with stagewise training (which
# consists of alternating between RPN and Fast R-CNN training in a way that
# finally leads to a single network).
__C.MODEL.FASTER_RCNN = False
# Indicates the model makes instance mask predictions (as in Mask R-CNN)
__C.MODEL.MASK_ON = False
# Indicates the model makes keypoint predictions (as in Mask R-CNN for
# keypoints)
__C.MODEL.KEYPOINTS_ON = False
# Indicates the model's computation terminates with the production of RPN
# proposals (i.e., it outputs proposals ONLY, no actual object detections)
__C.MODEL.RPN_ONLY = False
# [Inferred value; do not set directly in a config]
# Indicates whether the res5 stage weights and training forward computation
# are shared with the box head or not.
__C.MODEL.SHARE_RES5 = False
# Whether to load imagenet pretrained weights
# If True, path to the weight file must be specified.
# See: __C.RESNETS.IMAGENET_PRETRAINED_WEIGHTS
__C.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS = True
# ---------------------------------------------------------------------------- #
# Unsupervised Pose
# ---------------------------------------------------------------------------- #
__C.MODEL.UNSUPERVISED_POSE = False
# ---------------------------------------------------------------------------- #
# RetinaNet options
# ---------------------------------------------------------------------------- #
__C.RETINANET = AttrDict()
# RetinaNet is used (instead of Fast/er/Mask R-CNN/R-FCN/RPN) if True
__C.RETINANET.RETINANET_ON = False
# Anchor aspect ratios to use
__C.RETINANET.ASPECT_RATIOS = (0.5, 1.0, 2.0)
# Anchor scales per octave
__C.RETINANET.SCALES_PER_OCTAVE = 3
# At each FPN level, we generate anchors based on their scale, aspect_ratio,
# stride of the level, and we multiply the resulting anchor by ANCHOR_SCALE
__C.RETINANET.ANCHOR_SCALE = 4
# Convolutions to use in the cls and bbox tower
# NOTE: this doesn't include the last conv for logits
__C.RETINANET.NUM_CONVS = 4
# Weight for bbox_regression loss
__C.RETINANET.BBOX_REG_WEIGHT = 1.0
# Smooth L1 loss beta for bbox regression
__C.RETINANET.BBOX_REG_BETA = 0.11
# During inference, #locs to select based on cls score before NMS is performed
# per FPN level
__C.RETINANET.PRE_NMS_TOP_N = 1000
# IoU overlap ratio for labeling an anchor as positive
# Anchors with >= iou overlap are labeled positive
__C.RETINANET.POSITIVE_OVERLAP = 0.5
# IoU overlap ratio for labeling an anchor as negative
# Anchors with < iou overlap are labeled negative
__C.RETINANET.NEGATIVE_OVERLAP = 0.4
# Focal loss parameter: alpha
__C.RETINANET.LOSS_ALPHA = 0.25
# Focal loss parameter: gamma
__C.RETINANET.LOSS_GAMMA = 2.0
# Prior prob for the positives at the beginning of training. This is used to set
# the bias init for the logits layer
__C.RETINANET.PRIOR_PROB = 0.01
# Whether classification and bbox branch tower should be shared or not
__C.RETINANET.SHARE_CLS_BBOX_TOWER = False
# Use class specific bounding box regression instead of the default class
# agnostic regression
__C.RETINANET.CLASS_SPECIFIC_BBOX = False
# Whether softmax should be used in classification branch training
__C.RETINANET.SOFTMAX = False
# Inference cls score threshold, anchors with score > INFERENCE_TH are
# considered for inference
__C.RETINANET.INFERENCE_TH = 0.05
# ---------------------------------------------------------------------------- #
# Solver options
# Note: all solver options are used exactly as specified; the implication is
# that if you switch from training on 1 GPU to N GPUs, you MUST adjust the
# solver configuration accordingly. We suggest using gradual warmup and the
# linear learning rate scaling rule as described in
# "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour" Goyal et al.
# https://arxiv.org/abs/1706.02677
# ---------------------------------------------------------------------------- #
__C.SOLVER = AttrDict()
# E.g., 'SGD', 'Adam'
__C.SOLVER.TYPE = 'SGD'
# Base learning rate for the specified schedule
__C.SOLVER.BASE_LR = 0.001
# Schedule type (see functions in utils.lr_policy for options)
# E.g., 'step', 'steps_with_decay', ...
__C.SOLVER.LR_POLICY = 'step'
# Some LR Policies (by example):
# 'step'
# lr = SOLVER.BASE_LR * SOLVER.GAMMA ** (cur_iter // SOLVER.STEP_SIZE)
# 'steps_with_decay'
# SOLVER.STEPS = [0, 60000, 80000]
# SOLVER.GAMMA = 0.1
# lr = SOLVER.BASE_LR * SOLVER.GAMMA ** current_step
# iters [0, 59999] are in current_step = 0, iters [60000, 79999] are in
# current_step = 1, and so on
# 'steps_with_lrs'
# SOLVER.STEPS = [0, 60000, 80000]
# SOLVER.LRS = [0.02, 0.002, 0.0002]
# lr = LRS[current_step]
# Hyperparameter used by the specified policy
# For 'step', the current LR is multiplied by SOLVER.GAMMA at each step
__C.SOLVER.GAMMA = 0.1
# Uniform step size for the 'step' policy
__C.SOLVER.STEP_SIZE = 30000
# Non-uniform step iterations for 'steps_with_decay' or 'steps_with_lrs'
# policies
__C.SOLVER.STEPS = []
# Learning rates to use with 'steps_with_lrs' policy
__C.SOLVER.LRS = []
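# Worked example for 'steps_with_decay' (values are illustrative; see
# utils.lr_policy for the authoritative implementation):
#   SOLVER.BASE_LR = 0.02, SOLVER.GAMMA = 0.1, SOLVER.STEPS = [0, 60000, 80000]
#   iter  30000 -> current_step = 0 -> lr = 0.02
#   iter  60000 -> current_step = 1 -> lr = 0.002
#   iter  80000 -> current_step = 2 -> lr = 0.0002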
# Maximum number of SGD iterations
__C.SOLVER.MAX_ITER = 40000
# Momentum to use with SGD
__C.SOLVER.MOMENTUM = 0.9
# L2 regularization hyperparameter
__C.SOLVER.WEIGHT_DECAY = 0.0005
# L2 regularization hyperparameter for GroupNorm's parameters
__C.SOLVER.WEIGHT_DECAY_GN = 0.0
# Whether to double the learning rate for bias
__C.SOLVER.BIAS_DOUBLE_LR = True
# Whether to have weight decay on bias as well
__C.SOLVER.BIAS_WEIGHT_DECAY = False
# Warm up to SOLVER.BASE_LR over this number of SGD iterations
__C.SOLVER.WARM_UP_ITERS = 500
# Start the warm up from SOLVER.BASE_LR * SOLVER.WARM_UP_FACTOR
__C.SOLVER.WARM_UP_FACTOR = 1.0 / 3.0
# WARM_UP_METHOD can be either 'constant' or 'linear' (i.e., gradual)
__C.SOLVER.WARM_UP_METHOD = 'linear'
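# A sketch of the 'linear' warm up (assuming the gradual warm up of Goyal et
# al.; see utils.lr_policy for the exact implementation):
#   alpha = cur_iter / WARM_UP_ITERS
#   lr = BASE_LR * (WARM_UP_FACTOR * (1 - alpha) + alpha)  # ramps up to BASE_LR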
# Scale the momentum update history by new_lr / old_lr when updating the
# learning rate (this is correct given MomentumSGDUpdateOp)
__C.SOLVER.SCALE_MOMENTUM = True
# Only apply the correction if the relative LR change exceeds this threshold
# (prevents every change in linear warm up from scaling the momentum by a tiny
# amount; momentum scaling is only important if the LR change is large)
__C.SOLVER.SCALE_MOMENTUM_THRESHOLD = 1.1
# Suppress logging of changes to LR unless the relative change exceeds this
# threshold (prevents linear warm up from spamming the training log)
__C.SOLVER.LOG_LR_CHANGE_THRESHOLD = 1.1
# ---------------------------------------------------------------------------- #
# Fast R-CNN options
# ---------------------------------------------------------------------------- #
__C.FAST_RCNN = AttrDict()
# The type of RoI head to use for bounding box classification and regression
# The string must match a function that is imported in modeling.model_builder
# (e.g., 'head_builder.add_roi_2mlp_head' to specify a two hidden layer MLP)
__C.FAST_RCNN.ROI_BOX_HEAD = ''
# Hidden layer dimension when using an MLP for the RoI box head
__C.FAST_RCNN.MLP_HEAD_DIM = 1024
# Hidden Conv layer dimension when using Convs for the RoI box head
__C.FAST_RCNN.CONV_HEAD_DIM = 256
# Number of stacked Conv layers in the RoI box head
__C.FAST_RCNN.NUM_STACKED_CONVS = 4
# RoI transformation function (e.g., RoIPool or RoIAlign)
# (RoIPoolF is the same as RoIPool; ignore the trailing 'F')
__C.FAST_RCNN.ROI_XFORM_METHOD = 'RoIPoolF'
# Number of grid sampling points in RoIAlign (usually use 2)
# Only applies to RoIAlign
__C.FAST_RCNN.ROI_XFORM_SAMPLING_RATIO = 0
# RoI transform output resolution
# Note: some models may have constraints on what they can use, e.g. they use
# pretrained FC layers like in VGG16, and will ignore this option
__C.FAST_RCNN.ROI_XFORM_RESOLUTION = 14
# ---------------------------------------------------------------------------- #
# RPN options
# ---------------------------------------------------------------------------- #
__C.RPN = AttrDict()
# [Inferred value; do not set directly in a config]
# Indicates that the model contains an RPN subnetwork
__C.RPN.RPN_ON = False
# `True` for Detectron implementation. `False` for jwyang's implementation.
__C.RPN.OUT_DIM_AS_IN_DIM = True
# Output dim of conv2d. Ignored if `__C.RPN.OUT_DIM_AS_IN_DIM` is True.
# 512 is the fixed value in jwyang's implementation.
__C.RPN.OUT_DIM = 512
# 'sigmoid' or 'softmax'. Detectron uses 'sigmoid'; jwyang uses 'softmax'.
# This will affect the conv2d output dim for classifying the bg/fg rois
__C.RPN.CLS_ACTIVATION = 'sigmoid'
# RPN anchor sizes given in absolute pixels w.r.t. the scaled network input
# Note: these options are *not* used by FPN RPN; see FPN.RPN* options
__C.RPN.SIZES = (64, 128, 256, 512)
# Stride of the feature map that RPN is attached to
__C.RPN.STRIDE = 16
# RPN anchor aspect ratios
__C.RPN.ASPECT_RATIOS = (0.5, 1, 2)
# ---------------------------------------------------------------------------- #
# FPN options
# ---------------------------------------------------------------------------- #
__C.FPN = AttrDict()
# FPN is enabled if True
__C.FPN.FPN_ON = False
# Channel dimension of the FPN feature levels
__C.FPN.DIM = 256
# Initialize the lateral connections to output zero if True
__C.FPN.ZERO_INIT_LATERAL = False
# Stride of the coarsest FPN level
# This is needed so the input can be padded properly
__C.FPN.COARSEST_STRIDE = 32
#
# FPN may be used for just RPN, just object detection, or both
#
# Use FPN for RoI transform for object detection if True
__C.FPN.MULTILEVEL_ROIS = False
# Hyperparameters for the RoI-to-FPN level mapping heuristic
__C.FPN.ROI_CANONICAL_SCALE = 224 # s0
__C.FPN.ROI_CANONICAL_LEVEL = 4 # k0: where s0 maps to
# Coarsest level of the FPN pyramid
__C.FPN.ROI_MAX_LEVEL = 5
# Finest level of the FPN pyramid
__C.FPN.ROI_MIN_LEVEL = 2
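# The heuristic roughly follows Eqn. (1) of the FPN paper (a sketch, not the
# exact implementation in this repo):
#   target_level = floor(ROI_CANONICAL_LEVEL +
#                        log2(sqrt(roi_area) / ROI_CANONICAL_SCALE))
#   target_level = clip(target_level, ROI_MIN_LEVEL, ROI_MAX_LEVEL)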
# Use FPN for RPN if True
__C.FPN.MULTILEVEL_RPN = False
# Coarsest level of the FPN pyramid
__C.FPN.RPN_MAX_LEVEL = 6
# Finest level of the FPN pyramid
__C.FPN.RPN_MIN_LEVEL = 2
# FPN RPN anchor aspect ratios
__C.FPN.RPN_ASPECT_RATIOS = (0.5, 1, 2)
# RPN anchors start at this size on RPN_MIN_LEVEL
# The anchor size doubles at each level after that
# With a default of 32 and levels 2 to 6, we get anchor sizes of 32 to 512
__C.FPN.RPN_ANCHOR_START_SIZE = 32
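# Per-level anchor size (with the defaults: 32, 64, 128, 256, 512 for levels
# 2 through 6):
#   size(level) = RPN_ANCHOR_START_SIZE * 2 ** (level - RPN_MIN_LEVEL)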
# [Inferred value] Scale for RPN_POST_NMS_TOP_N.
# Automatically inferred in training, fixed to 1 in testing.
__C.FPN.RPN_COLLECT_SCALE = 1
# Use extra FPN levels, as done in the RetinaNet paper
__C.FPN.EXTRA_CONV_LEVELS = False
# Use GroupNorm in the FPN-specific layers (lateral, etc.)
__C.FPN.USE_GN = False
# ---------------------------------------------------------------------------- #
# Mask R-CNN options ("MRCNN" means Mask R-CNN)
# ---------------------------------------------------------------------------- #
__C.MRCNN = AttrDict()
# The type of RoI head to use for instance mask prediction
# The string must match a function that is imported in modeling.model_builder
# (e.g., 'mask_rcnn_heads.ResNet_mask_rcnn_fcn_head_v1up4convs')
__C.MRCNN.ROI_MASK_HEAD = ''
# Resolution of mask predictions
__C.MRCNN.RESOLUTION = 14
# RoI transformation function and associated options
__C.MRCNN.ROI_XFORM_METHOD = 'RoIAlign'
# RoI transform output resolution
__C.MRCNN.ROI_XFORM_RESOLUTION = 7
# Number of grid sampling points in RoIAlign (usually use 2)
# Only applies to RoIAlign
__C.MRCNN.ROI_XFORM_SAMPLING_RATIO = 0
# Number of channels in the mask head
__C.MRCNN.DIM_REDUCED = 256
# Use dilated convolution in the mask head
__C.MRCNN.DILATION = 2
# Upsample the predicted masks by this factor
__C.MRCNN.UPSAMPLE_RATIO = 1
# Use a fully-connected layer to predict the final masks instead of a conv layer
__C.MRCNN.USE_FC_OUTPUT = False
# Weight initialization method for the mask head and mask output layers. ['GaussianFill', 'MSRAFill']
__C.MRCNN.CONV_INIT = 'GaussianFill'
# Use class specific mask predictions if True (otherwise use class agnostic mask
# predictions)
__C.MRCNN.CLS_SPECIFIC_MASK = True
# Multi-task loss weight for masks
__C.MRCNN.WEIGHT_LOSS_MASK = 1.0
# Binarization threshold for converting soft masks to hard masks
__C.MRCNN.THRESH_BINARIZE = 0.5
__C.MRCNN.MEMORY_EFFICIENT_LOSS = True # TODO
# ---------------------------------------------------------------------------- #
# Keypoint Mask R-CNN options ("KRCNN" = Mask R-CNN with Keypoint support)
# ---------------------------------------------------------------------------- #
__C.KRCNN = AttrDict()
# The type of RoI head to use for instance keypoint prediction
# The string must match a function that is imported in modeling.model_builder
# (e.g., 'keypoint_rcnn_heads.add_roi_pose_head_v1convX')
__C.KRCNN.ROI_KEYPOINTS_HEAD = ''
# Output heatmap size (the size the loss is computed on), e.g., 56x56
__C.KRCNN.HEATMAP_SIZE = -1
# Use bilinear interpolation to upsample the final heatmap by this factor
__C.KRCNN.UP_SCALE = -1
# Apply a ConvTranspose layer to the hidden representation computed by the
# keypoint head prior to predicting the per-keypoint heatmaps
__C.KRCNN.USE_DECONV = False
# Channel dimension of the hidden representation produced by the ConvTranspose
__C.KRCNN.DECONV_DIM = 256
# Use a ConvTranspose layer to predict the per-keypoint heatmaps
__C.KRCNN.USE_DECONV_OUTPUT = False
# Use dilation in the keypoint head
__C.KRCNN.DILATION = 1
# Size of the kernels to use in all ConvTranspose operations
__C.KRCNN.DECONV_KERNEL = 4
# Number of keypoints in the dataset (e.g., 17 for COCO)
__C.KRCNN.NUM_KEYPOINTS = -1
# Number of stacked Conv layers in keypoint head
__C.KRCNN.NUM_STACKED_CONVS = 8
# Dimension of the hidden representation output by the keypoint head
__C.KRCNN.CONV_HEAD_DIM = 256
# Conv kernel size used in the keypoint head
__C.KRCNN.CONV_HEAD_KERNEL = 3
# Conv kernel weight filling function
__C.KRCNN.CONV_INIT = 'GaussianFill'
# Use NMS based on OKS if True
__C.KRCNN.NMS_OKS = False
# Source of keypoint confidence
# Valid options: ('bbox', 'logit', 'prob')
__C.KRCNN.KEYPOINT_CONFIDENCE = 'bbox'
# Standard ROI XFORM options (see FAST_RCNN or MRCNN options)
__C.KRCNN.ROI_XFORM_METHOD = 'RoIAlign'
__C.KRCNN.ROI_XFORM_RESOLUTION = 7
__C.KRCNN.ROI_XFORM_SAMPLING_RATIO = 0
# Minimum number of labeled keypoints that must exist in a minibatch (otherwise
# the minibatch is discarded)
__C.KRCNN.MIN_KEYPOINT_COUNT_FOR_VALID_MINIBATCH = 20
# When inferring the keypoint locations from the heatmap, don't scale the heatmap
# below this minimum size
__C.KRCNN.INFERENCE_MIN_SIZE = 0
# Multi-task loss weight to use for keypoints
# Recommended values:
# - use 1.0 if KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS is True
# - use 4.0 if KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS is False
__C.KRCNN.LOSS_WEIGHT = 1.0
# Normalize by the total number of visible keypoints in the minibatch if True.
# Otherwise, normalize by the total number of keypoints that could ever exist
# in the minibatch. See comments in modeling.model_builder.add_keypoint_losses
# for detailed discussion.
__C.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS = True
# ---------------------------------------------------------------------------- #
# R-FCN options
# ---------------------------------------------------------------------------- #
__C.RFCN = AttrDict()
# Position-sensitive RoI pooling output grid size (height and width)
__C.RFCN.PS_GRID_SIZE = 3
# ---------------------------------------------------------------------------- #
# ResNets options ("ResNets" = ResNet and ResNeXt)
# ---------------------------------------------------------------------------- #
__C.RESNETS = AttrDict()
# Number of groups to use; 1 ==> ResNet; > 1 ==> ResNeXt
__C.RESNETS.NUM_GROUPS = 1
# Baseline width of each group
__C.RESNETS.WIDTH_PER_GROUP = 64
# Place the stride 2 conv on the 1x1 filter
# Use True only for the original MSRA ResNet; use False for C2 and Torch models
__C.RESNETS.STRIDE_1X1 = True
# Residual transformation function
__C.RESNETS.TRANS_FUNC = 'bottleneck_transformation'
# ResNet's stem function (conv1 and pool1)
__C.RESNETS.STEM_FUNC = 'basic_bn_stem'
# ResNet's shortcut function
__C.RESNETS.SHORTCUT_FUNC = 'basic_bn_shortcut'
# Apply dilation in stage "res5"
__C.RESNETS.RES5_DILATION = 1
# Freeze model weights up to and including this ResNet block.
# Choices: [0, 2, 3, 4, 5]. 0 means no blocks are frozen. The first conv and bn
# layers are frozen by default.
__C.RESNETS.FREEZE_AT = 2
# Path to ResNet weights pretrained on ImageNet.
# If the path starts with '/', it is treated as an absolute path.
# Otherwise, it is treated as a path relative to __C.ROOT_DIR
__C.RESNETS.IMAGENET_PRETRAINED_WEIGHTS = ''
# Use GroupNorm instead of BatchNorm
__C.RESNETS.USE_GN = False
# ---------------------------------------------------------------------------- #
# GroupNorm options
# ---------------------------------------------------------------------------- #
__C.GROUP_NORM = AttrDict()
# Number of dimensions per group in GroupNorm (-1 if using NUM_GROUPS)
__C.GROUP_NORM.DIM_PER_GP = -1
# Number of groups in GroupNorm (-1 if using DIM_PER_GP)
__C.GROUP_NORM.NUM_GROUPS = 32
# GroupNorm's small constant in the denominator
__C.GROUP_NORM.EPSILON = 1e-5
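# The group count for a layer with `dim` channels is typically derived as
# (a sketch; assumes DIM_PER_GP takes precedence when it is > 0):
#   num_groups = dim // DIM_PER_GP  if DIM_PER_GP > 0  else NUM_GROUPS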
# ---------------------------------------------------------------------------- #
# MISC options
# ---------------------------------------------------------------------------- #
# Number of GPUs to use (applies to both training and testing)
__C.NUM_GPUS = 1
# The mapping from image coordinates to feature map coordinates might cause
# some boxes that are distinct in image space to become identical in feature
# coordinates. If DEDUP_BOXES > 0, then DEDUP_BOXES is used as the scale factor
# for identifying duplicate boxes.
# 1/16 is correct for {Alex,Caffe}Net, VGG_CNN_M_1024, and VGG16
__C.DEDUP_BOXES = 1. / 16.
# Clip bounding box transformation predictions to prevent np.exp from
# overflowing
# Heuristic choice: this clip value corresponds to scaling a 16 pixel anchor up
# to 1000 pixels
__C.BBOX_XFORM_CLIP = np.log(1000. / 16.)
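# Rationale (illustrative): predicted dw/dh are clipped to this value before
# np.exp, so the largest possible scale-up factor is exp(log(1000/16)) = 62.5,
# i.e. a 16 pixel anchor can grow to at most 1000 pixels.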
# Pixel mean values (BGR order) as a (1, 1, 3) array
# We use the same pixel mean for all networks even though it's not exactly what
# they were trained with
# "Fun" fact: the history of where these values comes from is lost (From Detectron lol)
__C.PIXEL_MEANS = np.array([[[102.9801, 115.9465, 122.7717]]])
# For reproducibility
__C.RNG_SEED = 3
# A small number that's used many times
__C.EPS = 1e-14
# Root directory of project
__C.ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..', '..'))
# Output basedir
__C.OUTPUT_DIR = 'Outputs'
# Name (or path to) the matlab executable
__C.MATLAB = 'matlab'
# Dump detection visualizations
__C.VIS = False
# Score threshold for visualization
__C.VIS_TH = 0.9
# Expected results should take the form of a list of expectations, each
# specified by four elements (dataset, task, metric, expected value). For
# example: [['coco_2014_minival', 'box_proposal', 'AR@1000', 0.387]]
__C.EXPECTED_RESULTS = []
# Absolute and relative tolerance to use when comparing to EXPECTED_RESULTS
__C.EXPECTED_RESULTS_RTOL = 0.1
__C.EXPECTED_RESULTS_ATOL = 0.005
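# A result is typically treated as matching its expectation when (a sketch,
# assuming np.isclose-style semantics):
#   abs(actual - expected) <= EXPECTED_RESULTS_ATOL + \
#       EXPECTED_RESULTS_RTOL * abs(expected)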
# Set to send email in case of an EXPECTED_RESULTS failure
__C.EXPECTED_RESULTS_EMAIL = ''
# ------------------------------
# Data directory
__C.DATA_DIR = osp.abspath(osp.join(__C.ROOT_DIR, 'data'))
# [Deprecated]
__C.POOLING_MODE = 'crop'
# [Deprecated] Size of the pooled region after RoI pooling
__C.POOLING_SIZE = 7
__C.CROP_RESIZE_WITH_MAX_POOL = True
# [Inferred value]
__C.CUDA = False
__C.DEBUG = False
# [Inferred value]
__C.PYTORCH_VERSION_LESS_THAN_040 = False
# ---------------------------------------------------------------------------- #
# Mask heads or keypoint heads that share res5 stage weights and the training
# forward computation with the box head.
# ---------------------------------------------------------------------------- #
_SHARE_RES5_HEADS = set(
[