One backward function for Ghost Clipping (pytorch#661)
Summary:
Pull Request resolved: pytorch#661

Simplified the training loop for ghost clipping by using only one "double backward" function.

Reviewed By: HuanyuZhang

Differential Revision: D60427371

fbshipit-source-id: 73c016a31f0692adcfa3f6838e74315fbed26bb1
EnayatUllah authored and facebook-github-bot committed Aug 1, 2024
1 parent a059670 commit 4804a51
Showing 3 changed files with 53 additions and 14 deletions.
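For orientation, a minimal sketch (not part of the commit) of what a ghost-clipping training step reduces to with the new double_backward helper. The names model, criterion, optimizer, data, and target are illustrative; model is assumed to be wrapped in GradSampleModuleFastGradientClipping, optimizer in DPOptimizerFastGradientClipping, and criterion to use reduction="none" so the loss is per-sample.

output = model(data)
loss_per_sample = criterion(output, target)           # one loss value per sample
double_backward(model, optimizer, loss_per_sample)    # both backward passes, rescaling, hook toggling
optimizer.step()
optimizer.zero_grad()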
2 changes: 1 addition & 1 deletion opacus/optimizers/ddpoptimizer_fast_gradient_clipping.py
@@ -24,7 +24,7 @@

class DistributedDPOptimizerFastGradientClipping(DPOptimizerFastGradientClipping):
"""
:class:`~opacus.optimizers.optimizer.DPOptimizer` compatible with
:class:`opacus.optimizers.optimizer.DPOptimizer` compatible with
distributed data processing
"""

18 changes: 5 additions & 13 deletions opacus/tests/grad_sample_module_fast_gradient_clipping_test.py
@@ -22,6 +22,7 @@
from hypothesis import given, settings
from opacus.grad_sample import GradSampleModule, GradSampleModuleFastGradientClipping
from opacus.optimizers import DPOptimizer, DPOptimizerFastGradientClipping
from opacus.utils.fast_gradient_clipping_utils import double_backward
from opacus.utils.per_sample_gradients_utils import clone_module
from torch.utils.data import DataLoader, Dataset

@@ -108,7 +109,7 @@ def setUp_data_sequantial(self, size, length, dim):
@settings(deadline=1000000)
def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
"""
Tests if norm calculation is same between standard (opacus) and fast gradient clipping"
Tests if norm calculation is the same between standard (opacus) and fast gradient clipping
"""
self.length = length
self.size = size
@@ -189,7 +190,7 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
@settings(deadline=1000000)
def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
"""
Tests if gradients are same between standard (opacus) and fast gradient clipping"
Tests if gradients are the same between standard (opacus) and fast gradient clipping, using the double_backward function
"""

noise_multiplier = 0.0
@@ -237,19 +238,10 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
]
flat_grads_normal = torch.cat([p.flatten() for p in all_grads_normal])

self.grad_sample_module.enable_hooks()
output_gc = self.grad_sample_module(input_data)

first_loss_per_sample = self.criterion(output_gc, target_data)
first_loss = torch.mean(first_loss_per_sample)
first_loss.backward(retain_graph=True)

optimizer_gc.zero_grad()
coeff = self.grad_sample_module.get_coeff()
second_loss_per_sample = coeff * first_loss_per_sample
second_loss = torch.sum(second_loss_per_sample)
self.grad_sample_module.disable_hooks()
second_loss.backward()
double_backward(self.grad_sample_module, optimizer_gc, first_loss_per_sample)

all_grads_gc = [param.grad for param in self.grad_sample_module.parameters()]
flat_grads_gc = torch.cat([p.flatten() for p in all_grads_gc])
@@ -261,5 +253,5 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
]
)
logging.info(f"Diff = {diff}")
msg = "FAIL: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg
47 changes: 47 additions & 0 deletions opacus/utils/fast_gradient_clipping_utils.py
@@ -0,0 +1,47 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from opacus.grad_sample.grad_sample_module_fast_gradient_clipping import (
GradSampleModuleFastGradientClipping,
)
from opacus.optimizers import DPOptimizerFastGradientClipping


def double_backward(
module: GradSampleModuleFastGradientClipping,
optimizer: DPOptimizerFastGradientClipping,
loss_per_sample: torch.Tensor,
) -> None:
"""
Packages the training loop for Fast Gradient and Ghost Clipping: it performs the two backward passes, along with the loss rescaling and hook toggling in between.
This function also works with DistributedDPOptimizer.
Args:
module: The DP gradient sample module to train
optimizer: The DP optimizer used to train the module
loss_per_sample: loss for each sample in the mini-batch, of size [batch_size, 1]
Returns:
None
"""

torch.mean(loss_per_sample).backward(retain_graph=True)
optimizer.zero_grad()
rescaled_loss_per_sample = module.get_coeff() * loss_per_sample
rescaled_loss = torch.sum(rescaled_loss_per_sample)
module.disable_hooks()
rescaled_loss.backward()
module.enable_hooks()
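For context, a hedged end-to-end sketch of how double_backward might be wired up when wrapping the module and optimizer manually, as the test above does. The model architecture, batch size, and hyperparameters below are illustrative only, not taken from this commit.

import torch
import torch.nn as nn
from opacus.grad_sample import GradSampleModuleFastGradientClipping
from opacus.optimizers import DPOptimizerFastGradientClipping
from opacus.utils.fast_gradient_clipping_utils import double_backward

# Illustrative model and data; the criterion must return one loss per sample.
model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2))
criterion = nn.CrossEntropyLoss(reduction="none")

gs_module = GradSampleModuleFastGradientClipping(
    model, max_grad_norm=1.0, use_ghost_clipping=True
)
optimizer = DPOptimizerFastGradientClipping(
    torch.optim.SGD(gs_module.parameters(), lr=0.1),
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    expected_batch_size=32,
)

data, target = torch.randn(32, 16), torch.randint(0, 2, (32,))

output = gs_module(data)
loss_per_sample = criterion(output, target)
double_backward(gs_module, optimizer, loss_per_sample)  # replaces the manual two-pass sequence
optimizer.step()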
