loss_func.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class CE_Loss(nn.Module):
    def __init__(self, opt):
        super(CE_Loss, self).__init__()
        self.celoss = nn.CrossEntropyLoss(reduction='none')
        self.logsoftmax = nn.LogSoftmax(dim=-1)
        self.opt = opt

    def forward(self, inputs, outputs, targets, task, model=None, gamma=None, pseudo=False, inference=False):
        if task == 'ae':
            # Token-level task: build a padding mask and per-sequence lengths (token id 0 is padding).
            text = inputs[0]
            text_mask = (text != 0)
            text_len = torch.sum(text != 0, dim=-1)
        if pseudo:
            if task == 'ae':
                # Self-training: use the model's argmax predictions as pseudo labels.
                pseudo_targets = torch.argmax(outputs, dim=-1)
                probs = F.softmax(outputs, dim=-1)
                maxp = torch.max(probs, dim=-1).values
                # Average the max class probability over non-padding tokens as a per-sequence confidence.
                maxp = torch.where(text_mask != 0, maxp, torch.zeros_like(maxp))
                maxp = torch.sum(maxp, dim=-1).div(text_len.float())
                # Confidence threshold beta: either learned (sigmoid of a model parameter) or fixed.
                beta = torch.sigmoid(model.threshold) if self.opt.learned_beta else self.opt.beta
                if self.opt.hard_D:
                    # Hard selection: keep only sequences whose confidence exceeds beta.
                    weight = torch.where(maxp > beta, torch.ones_like(maxp), torch.zeros_like(maxp))
                else:
                    # Soft selection: sigmoid-weighted by the margin between confidence and beta.
                    weight = torch.sigmoid(self.opt.alpha * (maxp - beta))
                # Masked, length-normalized negative log-likelihood of the pseudo labels.
                pseudo_logits = self.logsoftmax(outputs)
                pseudo_targets = F.one_hot(pseudo_targets, num_classes=3).float()
                log_likelihood = torch.sum(pseudo_targets * pseudo_logits, dim=-1)
                log_likelihood = torch.where(text_mask != 0, log_likelihood, torch.zeros_like(log_likelihood))
                log_likelihood = torch.sum(log_likelihood, dim=-1).div(text_len.float())
                loss = -torch.mean(weight * log_likelihood)
            else:
                # Sentence-level task: same confidence weighting, applied per sample.
                pseudo_targets = torch.argmax(outputs, dim=-1)
                probs = F.softmax(outputs, dim=-1)
                maxp = torch.max(probs, dim=-1).values
                beta = torch.sigmoid(model.threshold) if self.opt.learned_beta else self.opt.beta
                if self.opt.hard_D:
                    weight = torch.where(maxp > beta, torch.ones_like(maxp), torch.zeros_like(maxp))
                else:
                    weight = torch.sigmoid(self.opt.alpha * (maxp - beta))
                # Weighted cross-entropy, normalized by the total weight and scaled by gamma.
                loss = gamma * torch.sum(weight * self.celoss(outputs, pseudo_targets)).div(weight.sum() + 1e-6)
        else:
            if task == 'ae':
                # Supervised token-level loss: masked, length-normalized cross-entropy.
                logits = self.logsoftmax(outputs)
                targets = F.one_hot(targets, num_classes=3).float()
                log_likelihood = torch.sum(targets * logits, dim=-1)
                log_likelihood = torch.where(text_mask != 0, log_likelihood, torch.zeros_like(log_likelihood))
                log_likelihood = torch.sum(log_likelihood, dim=-1).div(text_len.float())
                loss = -torch.mean(log_likelihood)
            else:
                # Supervised sentence-level loss: plain mean cross-entropy.
                loss = torch.mean(self.celoss(outputs, targets))
        if pseudo and inference:
            return loss, weight
        else:
            return loss
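

# Usage sketch (an assumption, not part of the original file): the option names
# below (alpha, beta, learned_beta, hard_D) mirror those read in forward(), and
# the tensor shapes are purely illustrative.
if __name__ == '__main__':
    from types import SimpleNamespace

    opt = SimpleNamespace(alpha=10.0, beta=0.9, learned_beta=False, hard_D=False)
    criterion = CE_Loss(opt)

    # Sentence-level supervised example: outputs are [batch, num_classes] logits.
    outputs = torch.randn(4, 3)
    targets = torch.randint(0, 3, (4,))
    loss = criterion(inputs=None, outputs=outputs, targets=targets, task='clf')
    print(loss.item())

    # Token-level ('ae') supervised example: inputs[0] are token ids (0 = padding),
    # outputs are [batch, seq_len, 3] logits, targets are per-token labels.
    text = torch.randint(1, 100, (4, 8))
    text[:, 6:] = 0  # simulate padding on the last two positions
    outputs = torch.randn(4, 8, 3)
    targets = torch.randint(0, 3, (4, 8))
    loss = criterion(inputs=[text], outputs=outputs, targets=targets, task='ae')
    print(loss.item())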