From c35262d778545e468e8526b4acd3e9a91690f954 Mon Sep 17 00:00:00 2001
From: Dariush Wahdany
Date: Fri, 22 Mar 2024 18:26:08 +0100
Subject: [PATCH] Fix BatchMemoryManager length

---
 opacus/utils/batch_memory_manager.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/opacus/utils/batch_memory_manager.py b/opacus/utils/batch_memory_manager.py
index c5d6dcc0..a2e2de62 100644
--- a/opacus/utils/batch_memory_manager.py
+++ b/opacus/utils/batch_memory_manager.py
@@ -16,12 +16,13 @@
 from typing import List
 
 import numpy as np
+from torch.utils.data import BatchSampler, DataLoader, Sampler
+
 from opacus.optimizers import DPOptimizer
 from opacus.utils.uniform_sampler import (
     DistributedUniformWithReplacementSampler,
     UniformWithReplacementSampler,
 )
-from torch.utils.data import BatchSampler, DataLoader, Sampler
 
 
 class BatchSplittingSampler(Sampler[List[int]]):
@@ -71,13 +72,17 @@ def __iter__(self):
     def __len__(self):
         if isinstance(self.sampler, BatchSampler):
             return int(
-                len(self.sampler) * (self.sampler.batch_size / self.max_batch_size)
+                np.ceil(
+                    len(self.sampler) * (self.sampler.batch_size / self.max_batch_size)
+                )
             )
         elif isinstance(self.sampler, UniformWithReplacementSampler) or isinstance(
             self.sampler, DistributedUniformWithReplacementSampler
         ):
             expected_batch_size = self.sampler.sample_rate * self.sampler.num_samples
-            return int(len(self.sampler) * (expected_batch_size / self.max_batch_size))
+            return int(
+                np.ceil(len(self.sampler) * (expected_batch_size / self.max_batch_size))
+            )
         return len(self.sampler)