From 1f4fae51bc55b9c078f19d9c7c9919fdebc16d60 Mon Sep 17 00:00:00 2001 From: yuanuayuan <784968994@qq.com> Date: Thu, 1 Nov 2018 18:13:59 -0400 Subject: [PATCH] Update SubLayers.py --- transformer/SubLayers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer/SubLayers.py b/transformer/SubLayers.py index 0298a19..b023a21 100644 --- a/transformer/SubLayers.py +++ b/transformer/SubLayers.py @@ -50,7 +50,8 @@ def forward(self, q, k, v, mask=None): k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv - mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. + if mask is not None: + mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. output, attn = self.attention(q, k, v, mask=mask) output = output.view(n_head, sz_b, len_q, d_v)