diff --git a/opacus/layers/dp_multihead_attention.py b/opacus/layers/dp_multihead_attention.py
index 40b5c8ed..8cc18c2e 100644
--- a/opacus/layers/dp_multihead_attention.py
+++ b/opacus/layers/dp_multihead_attention.py
@@ -203,17 +203,12 @@ def forward(
         r"""
         Using the same logic with ``nn.MultiheadAttention``
         (https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html).
         """
-        if self.batch_first:
-            if key is value:
-                if query is key:
-                    query = key = value = query.transpose(1, 0)
-                else:
-                    query, key = [x.transpose(1, 0) for x in (query, key)]
-                    value = key
-            else:
-                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
-        tgt_len, bsz, embed_dim = query.size()
+        if not self.batch_first:
+            tgt_len, bsz, embed_dim = query.size()
+        else:
+            bsz, tgt_len, embed_dim = query.size()
+
         if embed_dim != self.embed_dim:
             raise ValueError(
                 f"query has as size of {embed_dim} while the embedding"
@@ -234,6 +229,9 @@ def forward(
 
         q = q * scaling
 
+        if self.batch_first:
+            q, k, v = [x.transpose(0, 1) for x in (q, k, v)]
+
         if attn_mask is not None:
             if attn_mask.dtype not in (
                 torch.float32,
@@ -352,13 +350,14 @@ def forward(
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]
-        attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = self.out_proj(attn_output)
 
         if self.batch_first:
-            attn_output = attn_output.transpose(1, 0)
+            attn_output = attn_output.contiguous().view(bsz, tgt_len, embed_dim)
+        else:
+            attn_output = (
+                attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+            )
+        attn_output = self.out_proj(attn_output)
 
         if need_weights:
             # average attention weights over heads
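
For reference, a minimal sketch of how the changed batch_first path can be exercised from user code. It assumes DPMultiheadAttention is importable from opacus.layers and that its constructor mirrors nn.MultiheadAttention (embed_dim, num_heads, ..., batch_first=False); the variable names below (attn_bf, attn_sf, x_bf, x_sf) are illustrative only and not part of the diff. With shared weights, the batch-first module should match the sequence-first module up to a transpose of the batch and sequence dimensions:

# Sketch, not part of the diff: compare batch_first=True and batch_first=False
# outputs of DPMultiheadAttention with shared weights.
import torch

from opacus.layers import DPMultiheadAttention

embed_dim, num_heads, bsz, seq_len = 16, 4, 2, 5

attn_bf = DPMultiheadAttention(embed_dim, num_heads, batch_first=True)   # expects (batch, seq, embed)
attn_sf = DPMultiheadAttention(embed_dim, num_heads, batch_first=False)  # expects (seq, batch, embed)
attn_sf.load_state_dict(attn_bf.state_dict())  # same weights so outputs are comparable
attn_bf.eval()
attn_sf.eval()

x_bf = torch.randn(bsz, seq_len, embed_dim)  # batch-first layout
x_sf = x_bf.transpose(0, 1)                  # sequence-first layout of the same data

out_bf, _ = attn_bf(x_bf, x_bf, x_bf)
out_sf, _ = attn_sf(x_sf, x_sf, x_sf)

assert out_bf.shape == (bsz, seq_len, embed_dim)
assert torch.allclose(out_bf, out_sf.transpose(0, 1), atol=1e-6)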