Fix gradient shape error for DPMultiheadAttention (issue 650) (#651)
Summary:
Pull Request resolved: #651

When batch_first = True, the activations and partial (per-sample) gradients for each linear layer in DPMultiheadAttention still have batch_size in the second dimension, which produces gradients of the wrong shape in optimizer.step().

Details in: #650
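
For illustration, a minimal sketch of the behavior this fixes (not part of the PR; it assumes the public opacus API, i.e. GradSampleModule and DPMultiheadAttention, and hypothetical tensor sizes): wrapping a batch_first=True DPMultiheadAttention in GradSampleModule should now produce per-sample gradients whose leading dimension is batch_size.

import torch
from opacus import GradSampleModule
from opacus.layers import DPMultiheadAttention

# Hypothetical sizes, chosen only for illustration.
batch_size, tgt_len, embed_dim, num_heads = 4, 5, 8, 2

attn = DPMultiheadAttention(embed_dim, num_heads, batch_first=True)
model = GradSampleModule(attn)

x = torch.randn(batch_size, tgt_len, embed_dim)
out, _ = model(x, x, x)
out.sum().backward()

# After the fix, each per-sample gradient should have batch_size as its
# leading dimension, as expected by the DP optimizer in optimizer.step().
for name, p in model.named_parameters():
    grad_sample = getattr(p, "grad_sample", None)
    if grad_sample is not None:
        print(name, tuple(grad_sample.shape))  # leading dim == batch_size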

Reviewed By: EnayatUllah

Differential Revision: D57446245

fbshipit-source-id: c0f0e3643c802e51afe7ddb6bb054e5447845f32
HuanyuZhang authored and facebook-github-bot committed May 31, 2024
1 parent 7d65ddf commit 202c58a
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions opacus/layers/dp_multihead_attention.py
@@ -203,17 +203,12 @@ def forward(
         r"""
         Using the same logic with ``nn.MultiheadAttention`` (https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html).
         """
-        if self.batch_first:
-            if key is value:
-                if query is key:
-                    query = key = value = query.transpose(1, 0)
-                else:
-                    query, key = [x.transpose(1, 0) for x in (query, key)]
-                    value = key
-            else:
-                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
 
-        tgt_len, bsz, embed_dim = query.size()
+        if not self.batch_first:
+            tgt_len, bsz, embed_dim = query.size()
+        else:
+            bsz, tgt_len, embed_dim = query.size()
+
         if embed_dim != self.embed_dim:
             raise ValueError(
                 f"query has as size of {embed_dim} while the embedding"
@@ -234,6 +229,9 @@ def forward(
 
         q = q * scaling
 
+        if self.batch_first:
+            q, k, v = [x.transpose(0, 1) for x in (q, k, v)]
+
         if attn_mask is not None:
             if attn_mask.dtype not in (
                 torch.float32,
@@ -352,13 +350,14 @@ def forward(
 
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]
-        attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = self.out_proj(attn_output)
 
         if self.batch_first:
-            attn_output = attn_output.transpose(1, 0)
+            attn_output = attn_output.contiguous().view(bsz, tgt_len, embed_dim)
+        else:
+            attn_output = (
+                attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+            )
+        attn_output = self.out_proj(attn_output)
 
         if need_weights:
             # average attention weights over heads
