Self attention for pooling linear classifier #28

Open · wants to merge 3 commits into base: master
4 changes: 3 additions & 1 deletion fastai_contrib/learner.py
@@ -39,7 +39,9 @@ def bilm_text_classifier_learner(data: DataBunch, bptt: int = 70, max_len: int =
ds = data.train_ds
vocab_size, n_class = len(data.vocab.itos), data.c
if bicls_head == 'BiPoolingLinearClassifier':
count = 3*2
count = 3 * 2
elif bicls_head == 'BiAttentionPoolingClassifier':
count = 5
else:
count = 3
layers = [emb_sz * count] + lin_ftrs + [n_class]
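# Explanatory note (not part of the diff): `count` sets the input width of the
# classifier head as a multiple of emb_sz. The new BiAttentionPoolingClassifier
# concatenates five emb_sz-sized tensors, hence count = 5; see the shape sketch
# after its forward() in models.py below.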
134 changes: 127 additions & 7 deletions fastai_contrib/models.py
@@ -70,6 +70,118 @@ def forward(self, input:LongTensor)->Tuple[Tensor,Tensor]:
outputs.append(o)
return self.concat(raw_outputs), self.concat(outputs)

class BiAttentionPoolingClassifier(nn.Module):
r" BiLM Pooling with self attention"

def __init__(self, layers:Collection[int], drops:Collection[float], emb_sz:int):
super().__init__()
mod_layers = []
activs = [nn.ReLU(inplace=True)] * (len(layers) - 2) + [None]
for n_in,n_out,p,actn in zip(layers[:-1],layers[1:], drops, activs):
mod_layers += bn_drop_lin(n_in, n_out, p=p, actn=actn)
self.self_attn = MultiHeadAttention(n_head=8, d_model=emb_sz, d_k=64, d_v=64, dropout=0.1)
self.layers = nn.Sequential(*mod_layers)

def pool(self, x:Tensor, bs:int, is_max:bool):
"Pool the tensor along the seq_len dimension."
f = F.adaptive_max_pool1d if is_max else F.adaptive_avg_pool1d
# permute to (bs, em_sz, seq_len) so the 1d pooling runs over the sequence dimension
return f(x.permute(0, 2, 1), (1,)).view(bs, -1)

def forward(self, input:Tuple[Tensor,Tensor])->Tuple[Tensor,Tensor,Tensor]:
raw_outputs, outputs = input
output = outputs[-1]
assert output.dim() == 4, 'Expected a 4-dimensional input (bs, sl, em_sz, passes)'
bs, sl, em_sz, passes = output.size()

x = torch.cat([output[..., 0], output[..., 1]], 1)  # (bs, 2*sl, em_sz)
x, _ = self.self_attn(x, x, x)

avgpool = self.pool(x, bs, False)
mxpool = self.pool(x, bs, True)

# concatenate five (bs, em_sz) tensors so the result matches the emb_sz * 5
# input width of the linear head: last forward state, last attended position,
# max pool, avg pool, last backward state
x = torch.cat([output[:, -1, ..., 0], x[:, -1], mxpool,
avgpool, output[:, -1, ..., 1]], 1)
x = self.layers(x)
return x, raw_outputs, outputs
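# Illustrative shape walk-through (not part of the diff), assuming bs=2, sl=10,
# emb_sz=400:
#   output                -> (2, 10, 400, 2)   encoder output, forward/backward passes
#   cat of both passes    -> (2, 20, 400)
#   after self_attn       -> (2, 20, 400)
#   avgpool / mxpool      -> (2, 400) each
#   final concatenation   -> (2, 2000) = (bs, emb_sz * 5)
# which is why learner.py sets count = 5 for this head.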

class ScaledDotProductAttention(nn.Module):
r"""
Scaled Dot-Product Attention
based on: https://github.com/jadore801120/attention-is-all-you-need-pytorch
"""

def __init__(self, temperature:float, attn_dropout:float=0.1):
super().__init__()
self.temperature = temperature
self.dropout = nn.Dropout(attn_dropout)
self.softmax = nn.Softmax(dim=2)

def forward(self, q, k, v):

attn = torch.bmm(q, k.transpose(1, 2))
attn = attn / self.temperature

attn = self.softmax(attn)
attn = self.dropout(attn)
output = torch.bmm(attn, v)

return output, attn
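# Illustrative sketch (not part of the diff): a shape check for the attention
# module above, assuming d_k = 64, the sqrt(d_k) temperature used below, and the
# torch / np imports already available in this file.
def _demo_scaled_dot_product_attention():
    attn_layer = ScaledDotProductAttention(temperature=np.power(64, 0.5))
    q = k = v = torch.randn(16, 20, 64)   # (batch, seq_len, d_k)
    out, attn = attn_layer(q, k, v)       # out: (16, 20, 64), attn: (16, 20, 20)
    return out.shape, attn.shape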

class MultiHeadAttention(nn.Module):
r"""
Multi-Head Attention module
based on: https://github.com/jadore801120/attention-is-all-you-need-pytorch
"""

def __init__(self, n_head:int, d_model:int, d_k:int, d_v:int, dropout:float=0.1):
super().__init__()

self.n_head = n_head
self.d_k = d_k
self.d_v = d_v

self.w_qs = nn.Linear(d_model, n_head * d_k)
self.w_ks = nn.Linear(d_model, n_head * d_k)
self.w_vs = nn.Linear(d_model, n_head * d_v)
nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))

self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
self.layer_norm = nn.LayerNorm(d_model)

self.fc = nn.Linear(n_head * d_v, d_model)
nn.init.xavier_normal_(self.fc.weight)

self.dropout = nn.Dropout(dropout)

def forward(self, q, k, v):
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head


sz_b, len_q, _ = q.size()
sz_b, len_k, _ = k.size()
sz_b, len_v, _ = v.size()
residual = q

q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv

x, attn = self.attention(q, k, v)

x = x.view(n_head, sz_b, len_q, d_v)
x = x.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)

x = self.dropout(self.fc(x))
x = self.layer_norm(x + residual)

return x, attn
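# Illustrative sketch (not part of the diff): a shape check for MultiHeadAttention,
# assuming it is fed (batch, seq_len, d_model) tensors as in the classifier above,
# with d_model=400 as a stand-in for emb_sz.
def _demo_multi_head_attention():
    mha = MultiHeadAttention(n_head=8, d_model=400, d_k=64, d_v=64, dropout=0.1)
    x = torch.randn(16, 20, 400)       # (batch, seq_len, d_model)
    out, attn = mha(x, x, x)
    return out.shape, attn.shape       # (16, 20, 400), (8 * 16, 20, 20)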

class BiPoolingLinearClassifier(PoolingLinearClassifier):
"Create a linear classifier with pooling."

@@ -126,7 +238,6 @@ def forward(self, input:Tuple[Tensor,Tensor])->Tuple[Tensor,Tensor,Tensor]:
x = self.layers(x)
return x, raw_outputs, outputs


def get_bilm(vocab_sz:int, emb_sz:int, n_hid:int, n_layers:int, pad_token:int, tie_weights:bool=True,
qrnn:bool=False, bias:bool=True, bidir:bool=False, output_p:float=0.4, hidden_p:float=0.2, input_p:float=0.6,
embed_p:float=0.1, weight_p:float=0.5)->nn.Module:
@@ -156,11 +267,20 @@ def get_birnn_classifier(bptt:int, max_seq:int, n_class:int, vocab_sz:int, emb_s
qrnn=qrnn, hidden_p=hidden_p, input_p=input_p, embed_p=embed_p, weight_p=weight_p)

head = BiPoolingLinearClassifier
if bicls_head == 'BiPoolingLinearClassifier': head = BiPoolingLinearClassifier
elif bicls_head == 'AvgPoolingLinearClassifier': head = AvgPoolingLinearClassifier

model = SequentialRNN(BiLMModel(fwd_rnn_enc, bwd_rnn_enc), head(layers, drops))
model.reset()

if bicls_head == 'BiPoolingLinearClassifier':
head = BiPoolingLinearClassifier
model = SequentialRNN(BiLMModel(fwd_rnn_enc, bwd_rnn_enc), head(layers, drops))
elif bicls_head == 'AvgPoolingLinearClassifier':
head = AvgPoolingLinearClassifier
model = SequentialRNN(BiLMModel(fwd_rnn_enc, bwd_rnn_enc), head(layers, drops))
elif bicls_head == 'BiAttentionPoolingClassifier':
head = BiAttentionPoolingClassifier
# attention requires an additional argument
# maybe use kwargs for initialising classes
model = SequentialRNN(BiLMModel(fwd_rnn_enc, bwd_rnn_enc), head(layers, drops, emb_sz))

model.reset()
return model
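# Possible follow-up for the "maybe use kwargs" comment above (illustrative only,
# not part of the diff): pick the head class and its extra constructor arguments
# from a dict so the three branches collapse into one.
#
#     heads = {
#         'BiPoolingLinearClassifier': (BiPoolingLinearClassifier, ()),
#         'AvgPoolingLinearClassifier': (AvgPoolingLinearClassifier, ()),
#         'BiAttentionPoolingClassifier': (BiAttentionPoolingClassifier, (emb_sz,)),
#     }
#     head_cls, extra = heads.get(bicls_head, (BiPoolingLinearClassifier, ()))
#     model = SequentialRNN(BiLMModel(fwd_rnn_enc, bwd_rnn_enc), head_cls(layers, drops, *extra))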

#endregion
#endregion