Error in DPOptimizer: Inconsistency between batch_first argument of PrivacyEngine and DPMultiheadAttention #650

Closed
tklausen opened this issue May 9, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@tklausen (Contributor) commented on May 9, 2024

🐛 Bug

Context

Both PrivacyEngine and DPMultiheadAttention accept the bool argument batch_first, which indicates whether the batch dimension is the first or the second dimension. In the case of the PrivacyEngine, this argument is passed down to the GradSampleModule, which ensures that the batch dimension is always the first dimension of the stored .grad_samples (the per-sample gradients; see rearrange_grad_samples), so that DPOptimizer can consume them.
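To make this concrete, here is a minimal sketch of that guarantee using GradSampleModule directly (the toy layer and shapes are illustrative and not taken from this report):

```python
import torch
import torch.nn as nn
from opacus.grad_sample import GradSampleModule

batch_size, seq_len, embed_dim = 16, 8, 4
lin = nn.Linear(embed_dim, embed_dim)
# batch_first=False: GradSampleModule treats dim 1 of the activations as the batch dim
wrapped = GradSampleModule(lin, batch_first=False, loss_reduction="sum")

x = torch.randn(seq_len, batch_size, embed_dim)  # (seq, batch, embed)
wrapped(x).sum().backward()

# rearrange_grad_samples moves the batch dimension to the front, so the stored
# per-sample gradients are always indexed by sample along dim 0
print(lin.weight.grad_sample.shape)  # torch.Size([16, 4, 4]); dim 0 == batch_size
```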

Problem

Using PrivacyEngine and DPMultiheadAttention both with batch_first=True mixes up the batch dimension and can throw an error.

When batch_first=True, DPMultiheadAttention reorders the inputs to its forward method (query, key, value) so that the batch dimension becomes the second dimension (and the sequence dimension the first). The internal linear layers of DPMultiheadAttention are therefore called with inputs whose second dimension is the batch dimension. The GradSampleModule, however, expects the batch dimension to be the first dimension (because batch_first was set to True in the PrivacyEngine), so the computed gradients are not per-sample gradients. If the model additionally contains a layer other than DPMultiheadAttention whose input is batch-first, this even raises an error: the per-parameter gradient samples have mismatched leading dimensions, and the torch.stack call in DPOptimizer's clip_and_accumulate method fails.
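A minimal sketch of this mismatch, standing in a bare nn.Linear for one of DPMultiheadAttention's internal projections (the shapes are chosen to match the stack trace below; this is an illustration, not the Opacus internals):

```python
import torch
import torch.nn as nn
from opacus.grad_sample import GradSampleModule

batch_size, seq_len, embed_dim = 16, 8, 4
lin = nn.Linear(embed_dim, embed_dim)
# batch_first=True: GradSampleModule assumes dim 0 of every activation is the batch dim
wrapped = GradSampleModule(lin, batch_first=True, loss_reduction="sum")

# DPMultiheadAttention(batch_first=True) transposes its input before calling its
# internal linear layers, so they effectively see (seq_len, batch, embed):
x = torch.randn(seq_len, batch_size, embed_dim)
wrapped(x).sum().backward()

# the leading dimension of grad_sample is now seq_len instead of batch_size, so these
# are not per-sample gradients and cannot be stacked with those of batch-first layers
print(lin.weight.grad_sample.shape)  # torch.Size([8, 4, 4]); dim 0 == seq_len
```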

To Reproduce

See Colab. A minimal self-contained sketch mirroring these steps is also included after the list below.

  1. Initialize PrivacyEngine with batch_first=True
  2. Create a model that has at least:
    a. one DPMultiheadAttention layer with batch_first=True
    b. one other layer such as nn.Linear
  3. Ensure that batch size != sequence length of input to DPMultiheadAttention layer
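For reference, a minimal self-contained sketch along these lines (assuming the Opacus 1.4.x make_private signature; the toy model, shapes, and hyperparameters are illustrative and not taken from the Colab):

```python
import torch
import torch.nn as nn
from opacus import PrivacyEngine
from opacus.layers import DPMultiheadAttention

batch_size, seq_len, embed_dim = 16, 8, 8  # batch size != sequence length


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # 2a. DPMultiheadAttention layer with batch_first=True
        self.attn = DPMultiheadAttention(embed_dim, num_heads=2, batch_first=True)
        # 2b. one other layer whose input is batch-first
        self.fc = nn.Linear(embed_dim, 1)

    def forward(self, x):                 # x: (batch, seq, embed)
        out, _ = self.attn(x, x, x)       # out: (batch, seq, embed)
        return self.fc(out).mean(dim=1)   # (batch, 1)


model = ToyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataset = torch.utils.data.TensorDataset(
    torch.randn(64, seq_len, embed_dim), torch.randn(64, 1)
)
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

# 1. batch_first=True on the PrivacyEngine side as well
privacy_engine = PrivacyEngine()
model, optimizer, loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    batch_first=True,
    poisson_sampling=False,  # keep the batch size fixed at 16 for this sketch
)

criterion = nn.MSELoss()
for x, y in loader:
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()  # RuntimeError: stack expects each tensor to be equal size ...
    break
```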

Stack trace:

<ipython-input-2-e63f0910143c> in train(model, criterion, optimizer, train_loader, device)
     12         loss.backward()
     13 
---> 14         optimizer.step()
     15         optimizer.zero_grad()
     16 

/content/opacus/opacus/optimizers/optimizer.py in step(self, closure)
    518                 closure()
    519 
--> 520         if self.pre_step():
    521             return self.original_optimizer.step()
    522         else:

/content/opacus/opacus/optimizers/optimizer.py in pre_step(self, closure)
    499             return True
    500 
--> 501         self.clip_and_accumulate()
    502         if self._check_skip_next_step():
    503             self._is_last_step_skipped = True

/content/opacus/opacus/optimizers/optimizer.py in clip_and_accumulate(self)
    404                 g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
    405             ]
--> 406             per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
    407             per_sample_clip_factor = (
    408                 self.max_grad_norm / (per_sample_norms + 1e-6)

RuntimeError: stack expects each tensor to be equal size, but got [16] at entry 0 and [8] at entry 2
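The two sizes in the error line up with the setup above: 16 is the batch size seen by the batch-first nn.Linear, while 8 is the sequence length seen by the attention's internal linear layers. A tiny reconstruction of the failing stack with dummy gradient values:

```python
import torch

grad_samples = [
    torch.randn(16, 3),  # per-sample grads of a batch-first layer (leading dim = batch size)
    torch.randn(16, 3),
    torch.randn(8, 3),   # grads of an internal attention linear (leading dim = seq length)
]
per_param_norms = [g.reshape(len(g), -1).norm(2, dim=-1) for g in grad_samples]
torch.stack(per_param_norms, dim=1)
# RuntimeError: stack expects each tensor to be equal size, but got [16] at entry 0 and [8] at entry 2
```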

Expected Behavior

The per-sample gradients are computed correctly and no error is thrown if batch_first has the same value in both PrivacyEngine and DPMultiheadAttention.

For batch_first=False, no changes are required.

For batch_first=True, the DPMultiheadAttention layer should call its internal linear layers with an input whose first dimension is the batch dimension.
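As a hedged sketch of that idea, here is a standalone toy module rather than the real DPMultiheadAttention internals (the class and attribute names below are illustrative; this is not the actual patch from #651):

```python
import torch
import torch.nn as nn


class ProjectionStub(nn.Module):
    """Toy stand-in for one internal projection of an attention layer.

    The linear layer is applied *before* any transposition, so the activations captured
    by GradSampleModule keep the caller's (batch-first) layout; the transpose to the
    seq-first layout used by the attention math happens afterwards.
    """

    def __init__(self, embed_dim: int, batch_first: bool = True):
        super().__init__()
        self.batch_first = batch_first
        self.qlinear = nn.Linear(embed_dim, embed_dim)

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        q = self.qlinear(query)    # hooks see activations in the caller's layout
        if self.batch_first:
            q = q.transpose(0, 1)  # (batch, seq, embed) -> (seq, batch, embed)
        return q                   # downstream attention math expects seq-first input
```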

Environment

opacus: 1.4.1
pytorch: 2.2.1

Other packages should not be relevant as this is a pure Opacus bug.

Additional context

This issue may be related to #505, but I can't confirm this because the source code for that issue seems to have been deleted.

@tklausen tklausen changed the title Error in DPOptimizer when using DPMultiheadAttention: DPMultiheadAttention seems to violate the batch first assumption of DPOptimizer Error in DPOptimizer: Inconsistency between batch_first argument of PrivacyEngine and DPMultiheadAttention May 10, 2024
@HuanyuZhang HuanyuZhang added the bug Something isn't working label May 15, 2024
@HuanyuZhang (Contributor) commented:

Thanks for contributing to Opacus, and great catch!
Let me put together a fix: we need to guarantee that the input to all the linear layers inside DPMultiheadAttention has batch_size as the first dimension when batch_first = True.

HuanyuZhang added a commit to HuanyuZhang/opacus that referenced this issue May 16, 2024
Summary:
When batch_first = True, the activations and partial gradients for each linear layer in DPMultiheadAttention still have batch_size in the second dimension, which leads to a wrong gradient shape in optimizer.step().

Details in: pytorch#650

Differential Revision: D57446245
facebook-github-bot pushed a commit that referenced this issue May 31, 2024
Summary:
Pull Request resolved: #651

When batch_first = True, the activations and partial gradients for each linear layer in DPMultiheadAttention still have batch_size in the second dimension, which leads to a wrong gradient shape in optimizer.step().

Details in: #650

Reviewed By: EnayatUllah

Differential Revision: D57446245

fbshipit-source-id: c0f0e3643c802e51afe7ddb6bb054e5447845f32
@HuanyuZhang (Contributor) commented:

Closing this issue since the fix landed in #651.
