early_stopping_scheduler.py
"""A modified ReduceLROnPlateau-like scheduler for early stopping."""
class EarlyStopping:
"""
ReduceLROnPlateau with Early Stopping and model restoration.
Inputs (see PyTorch docs on ReduceLROnPlateau for more info):
- optimizer: PyTorch Optimizer
- factor: float, factor to multiply the learning rate
- mode: str in 'min', 'max'. Whether min or max value of metric
is best
- patience: int, number of epochs with no improvement after
which learning rate will be reduced.
- threshold: float, threshold for measuring the new optimum
- threshold_mode: str in 'rel', 'abs'
- max_decays: int, max number of lr decays allowed
"""
def __init__(self, optimizer, factor=0.1, mode='min', patience=1,
threshold=1e-4, threshold_mode='rel', max_decays=2):
"""Initialize scheduler."""
assert (
factor < 1.0
and mode in ('min', 'max')
and threshold_mode in ('rel', 'abs')
)
self.optimizer = optimizer
self.mode = mode
self.factor = factor
self.patience = patience
self.threshold = threshold
self.threshold_mode = threshold_mode
self.max_decays = max_decays
self.best = None
self.num_bad_epochs = 0
self.last_epoch = -1
self.decay_times = 0
def step(self, metrics):
"""Scheduler step."""
self.last_epoch += 1
if self.best is None or self._is_better(metrics, self.best):
self.best = metrics
self.num_bad_epochs = 0
else:
self.num_bad_epochs += 1
if self.num_bad_epochs > self.patience:
self.decay_times += 1
self.reduce_lr()
self.num_bad_epochs = 0
return (
False, # improved?
self.last_epoch - self.patience - 1, # checkpoint to back-off
self.decay_times <= self.max_decays) # keep training?
return self.num_bad_epochs == 0, self.last_epoch, True
def reduce_lr(self):
"""Apply learning rate decay on all parameters."""
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.factor * float(param_group['lr'])
def _is_better(self, curr, best):
if self.mode == 'min' and self.threshold_mode == 'rel':
rel_epsilon = 1. - self.threshold
return curr < best * rel_epsilon
if self.mode == 'min' and self.threshold_mode == 'abs':
return curr < best - self.threshold
if self.mode == 'max' and self.threshold_mode == 'rel':
rel_epsilon = self.threshold + 1.
return curr > best * rel_epsilon
return curr > best + self.threshold
def state_dict(self):
"""Return state dict except for optimizer."""
return {k: v for k, v in self.__dict__.items() if k != 'optimizer'}
def load_state_dict(self, state_dict):
"""Load state dict from checkpoint."""
self.__dict__.update(state_dict)
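

# A minimal usage sketch, assuming a toy torch model and a hard-coded list of
# validation losses; checkpoint saving/restoring is only hinted at in comments.
# Names such as `fake_val_losses` are placeholders chosen for illustration.
if __name__ == '__main__':
    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = EarlyStopping(optimizer, factor=0.1, mode='min',
                              patience=1, max_decays=1)

    # Losses improve twice, then plateau, so the scheduler decays the lr and
    # eventually signals that training should stop.
    fake_val_losses = [1.0, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95]
    for epoch, val_loss in enumerate(fake_val_losses):
        improved, checkpoint_epoch, keep_training = scheduler.step(val_loss)
        if improved:
            pass  # e.g. save a checkpoint tagged with this epoch
        else:
            pass  # e.g. restore the checkpoint saved at checkpoint_epoch
        if not keep_training:
            print(f'Stopping early at epoch {epoch}, '
                  f'lr={optimizer.param_groups[0]["lr"]}')
            break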