From 5579e4b4c525c46ee878e72014ef5ecd6c36af84 Mon Sep 17 00:00:00 2001
From: l-k-11235
Date: Fri, 29 Dec 2023 16:17:49 +0100
Subject: [PATCH] restored masked scaled dot attention

---
 onmt/modules/multi_headed_attn.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py
index d71ffbc460..1780871d9c 100644
--- a/onmt/modules/multi_headed_attn.py
+++ b/onmt/modules/multi_headed_attn.py
@@ -439,6 +439,7 @@ def forward(
         if (
             step == 0
             or not self.flash2
+            or self.self_attn_type != "scaled-dot-flash"
             or self.max_relative_positions not in [0, -1]
             or query.size(0) > 128
             or query.dtype != torch.float16
@@ -685,6 +686,8 @@ def forward(
             scores = self.alibi(scores)
 
         scores = scores.float()
+        if key_pad_mask is not None and mask is None:
+            mask = key_pad_mask.unsqueeze(1)
 
         if mask is not None:
             # not 100% necessary but expand to nb of heads
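
For context, below is a minimal, self-contained sketch (not the OpenNMT-py forward pass itself) of the behaviour the second hunk restores: when only a key padding mask is available and no attention mask was passed, the padding mask is broadcast over the head dimension and applied to the scaled dot-product scores before the softmax. The tensor shapes used here ([batch, 1, k_len] for key_pad_mask, [batch, heads, q_len, k_len] for scores) are assumptions for illustration.

# Illustration only; shapes and names are assumptions, not the patched code.
import torch

batch, heads, q_len, k_len, dim = 2, 4, 5, 5, 8
query = torch.rand(batch, heads, q_len, dim)
key = torch.rand(batch, heads, k_len, dim)

# True where the key position is padding and must receive no attention.
key_pad_mask = torch.zeros(batch, 1, k_len, dtype=torch.bool)
key_pad_mask[:, :, -1] = True  # pretend the last key position is padding

# Scaled dot-product scores: [batch, heads, q_len, k_len]
scores = torch.matmul(query, key.transpose(2, 3)) / dim**0.5
scores = scores.float()

mask = None
if key_pad_mask is not None and mask is None:
    # Same move as the patch: unsqueeze so the mask broadcasts over heads.
    mask = key_pad_mask.unsqueeze(1)  # [batch, 1, 1, k_len]

if mask is not None:
    scores = scores.masked_fill(mask, float("-inf"))

attn = torch.softmax(scores, dim=-1)
print(attn[..., -1].abs().max())  # ~0: padded keys get no attention weight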