-
Notifications
You must be signed in to change notification settings - Fork 103
/
utils.py
156 lines (123 loc) · 4.38 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import shutil
import torch
import sys
import os
import json
import numpy as np
from config import config
from torch import nn
import torch.nn.functional as F
def save_checkpoint(state, is_best, fold):
    """Write a training checkpoint and optionally mirror it as the best model.

    The checkpoint goes to
    ``config.weights + config.model_name + <sep> + str(fold) + <sep> +
    "_checkpoint.pth.tar"``. When ``is_best`` is truthy the same file is
    copied to the matching location under ``config.best_models`` with the
    name ``model_best.pth.tar``.

    Paths are assembled by plain string concatenation, so the two config
    prefixes must already end with whatever separator they need. This
    function does not create any directories.

    Args:
        state: Object handed to ``torch.save``.
        is_best: When truthy, also copy the checkpoint to the best-model path.
        fold: Fold identifier; converted with ``str`` to form a path segment.
    """
    # Shared "<model_name>/<fold>/" tail used by both destinations.
    fold_subdir = config.model_name + os.sep + str(fold) + os.sep

    checkpoint_path = config.weights + fold_subdir + "_checkpoint.pth.tar"
    torch.save(state, checkpoint_path)

    if not is_best:
        return

    best_path = config.best_models + fold_subdir + 'model_best.pth.tar'
    shutil.copyfile(checkpoint_path, best_path)
class AverageMeter(object):
    """Track the latest value and the running weighted mean of a metric.

    Attributes are plain public fields:
      val   -- the most recent value passed to ``update``
      sum   -- accumulated ``val * n`` over all updates
      count -- accumulated ``n`` over all updates
      avg   -- ``sum / count`` as of the last update
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean.

        Args:
            val: Newly observed value (e.g. a per-batch average).
            n: Weight of the observation (e.g. the batch size).
        """
        self.count += n
        self.sum += val * n
        self.val = val
        self.avg = self.sum / self.count
def adjust_learning_rate(optimizer, epoch):
    """Step-decay the learning rate: ``config.lr`` times 0.1 per 3 epochs.

    Every entry of ``optimizer.param_groups`` has its ``'lr'`` key
    overwritten with the same decayed value.

    Args:
        optimizer: Object exposing ``param_groups`` (an iterable of dicts).
        epoch: Current epoch number; ``epoch // 3`` gives the decay exponent.
    """
    decay_steps = epoch // 3
    decayed_lr = config.lr * (0.1 ** decay_steps)
    for group in optimizer.param_groups:
        group['lr'] = decayed_lr
def schedule(current_epoch, current_lrs, **logs):
    """Update a six-slot list of learning rates from a fixed epoch schedule.

    The schedule is piecewise constant: the rate attached to the latest
    milestone that ``current_epoch`` has reached is the active one.

      * Slot 5 always receives the active rate.
      * From epoch 2 onwards, slots 1-4 receive the active rate as well and
        slot 0 receives one tenth of it.
      * If ``current_epoch`` is below the first milestone nothing changes.

    Args:
        current_epoch: Epoch number to look up in the schedule.
        current_lrs: Mutable sequence with at least six entries; it is
            modified in place.
        **logs: Accepted and ignored.

    Returns:
        The same ``current_lrs`` object that was passed in.
    """
    # (first epoch, rate) pairs, latest milestone first so the first hit wins.
    milestones = (
        (12, 0.5e-5),
        (8, 1e-5),
        (6, 0.5e-4),
        (1, 1e-4),
        (0, 1e-3),
    )
    active = next(
        (rate for start, rate in milestones if current_epoch >= start),
        None,
    )
    if active is None:
        return current_lrs

    current_lrs[5] = active
    if current_epoch >= 2:
        for slot in (4, 3, 2, 1):
            current_lrs[slot] = active * 1
        current_lrs[0] = active * 0.1
    return current_lrs
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: Score tensor. ``topk`` is taken along dimension 1, so this is
            expected to be laid out as (batch, classes).
        target: Tensor of ground-truth class indices, one per sample; its
            first dimension gives the batch size.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one single-element float tensor per entry of ``topk``,
        each holding the percentage of samples whose target appears among
        the k highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk largest scores per sample, best first.
        _, pred = output.topk(maxk, 1, True, True)
        # Transpose to (maxk, batch) so row i holds every sample's i-th guess.
        pred = pred.t()
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # `correct` inherits the transposed (non-contiguous) layout, so
            # flattening its first k rows with .view(-1) raises for k > 1;
            # .reshape(-1) copies only when it has to.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
class Logger(object):
    """Tee-style writer that sends messages to the terminal and a log file.

    ``terminal`` is bound to ``sys.stdout`` at construction time. ``file`` is
    ``None`` until ``open`` is called; until then messages only go to the
    terminal.
    """

    def __init__(self):
        self.terminal = sys.stdout  # stdout
        self.file = None

    def open(self, file, mode=None):
        """Open ``file`` as the log destination.

        Args:
            file: Path passed to the builtin ``open``.
            mode: File mode; ``None`` means ``'w'`` (truncate).
        """
        if mode is None:
            mode = 'w'
        new_file = open(file, mode)
        # Release a previously opened log only after the new one opened
        # successfully, so a failed re-open keeps the old destination intact.
        if self.file is not None:
            self.file.close()
        self.file = new_file

    def write(self, message, is_terminal=1, is_file=1):
        """Write ``message`` to the selected destinations, flushing each.

        Args:
            message: Text to emit.
            is_terminal: ``1`` to echo to the terminal stream.
            is_file: ``1`` to append to the log file. Forced to ``0`` for any
                message containing a carriage return, so in-place progress
                lines never reach the file.
        """
        if '\r' in message:
            is_file = 0

        if is_terminal == 1:
            self.terminal.write(message)
            self.terminal.flush()

        # Skip the file silently when no log has been opened yet instead of
        # failing with AttributeError on None.
        if is_file == 1 and self.file is not None:
            self.file.write(message)
            self.file.flush()

    def flush(self):
        """Do nothing; present so the object can stand in for a stream.

        ``write`` already flushes both destinations after every message.
        """
        pass

    def close(self):
        """Close the log file, if one is open. The terminal is left alone."""
        if self.file is not None:
            self.file.close()
            self.file = None
def get_learning_rate(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    The ``'lr'`` entry of every group in ``optimizer.param_groups`` is read,
    but only the first one is returned; other groups may hold different
    values.
    """
    rates = [group['lr'] for group in optimizer.param_groups]
    # Multiple param groups are tolerated; the first one is the reported rate.
    return rates[0]
def time_to_str(t, mode='min'):
    """Format a duration given in seconds as a short human-readable string.

    Args:
        t: Duration in seconds; truncated to a whole number with ``int``.
        mode: ``'min'`` renders hours and minutes (``' 1 hr 02 min'``);
            ``'sec'`` renders minutes and seconds (``' 2 min 05 sec'``).

    Returns:
        The formatted string.

    Raises:
        NotImplementedError: If ``mode`` is neither ``'min'`` nor ``'sec'``.
    """
    whole_seconds = int(t)
    if mode == 'min':
        # Integer arithmetic throughout: the former `int(t)/60` is a float in
        # Python 3 and only worked because %d truncates it again.
        hours, minutes = divmod(whole_seconds // 60, 60)
        return '%2d hr %02d min' % (hours, minutes)
    if mode == 'sec':
        minutes, seconds = divmod(whole_seconds, 60)
        return '%2d min %02d sec' % (minutes, seconds)
    raise NotImplementedError("unsupported mode: %r" % (mode,))
class FocalLoss(nn.Module):
    """Focal-style loss built on top of ``F.cross_entropy``.

    Args:
        focusing_param: Exponent applied to ``(1 - pt)``.
        balance_param: Constant factor multiplied into the final loss.
    """

    def __init__(self, focusing_param=2, balance_param=0.25):
        super(FocalLoss, self).__init__()
        self.focusing_param = focusing_param
        self.balance_param = balance_param

    def forward(self, output, target):
        """Return ``balance_param * (1 - pt) ** focusing_param * CE``.

        ``CE`` is ``F.cross_entropy(output, target)`` with its default
        arguments and ``pt = exp(-CE)``.

        Args:
            output: Logits accepted by ``F.cross_entropy``.
            target: Targets accepted by ``F.cross_entropy``.

        Returns:
            The loss tensor.
        """
        # NOTE(review): the modulating factor is computed from the already
        # reduced cross-entropy, not per sample as in the usual focal-loss
        # formulation. Kept as-is so existing training results do not shift;
        # confirm whether that is intended.
        #
        # Cross-entropy is evaluated once; an unused log of it and a second
        # identical evaluation were dropped without changing the result.
        logpt = -F.cross_entropy(output, target)
        pt = torch.exp(logpt)

        focal_loss = -((1 - pt) ** self.focusing_param) * logpt
        return self.balance_param * focal_loss
class MyEncoder(json.JSONEncoder):
    """JSON encoder that also understands NumPy scalars and arrays."""

    def default(self, obj):
        """Convert NumPy values to built-in Python types.

        NumPy integers become ``int``, NumPy floats become ``float`` and
        arrays become (nested) lists. Anything else is delegated to the base
        encoder.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        return super().default(obj)