Skip to content

Commit

Permalink
Fix BatchMemoryManager length
Browse files Browse the repository at this point in the history
  • Loading branch information
Dariush Wahdany committed Mar 22, 2024
1 parent c314d42 commit c35262d
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions opacus/utils/batch_memory_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
from typing import List

import numpy as np
from torch.utils.data import BatchSampler, DataLoader, Sampler

from opacus.optimizers import DPOptimizer
from opacus.utils.uniform_sampler import (
DistributedUniformWithReplacementSampler,
UniformWithReplacementSampler,
)
from torch.utils.data import BatchSampler, DataLoader, Sampler


class BatchSplittingSampler(Sampler[List[int]]):
Expand Down Expand Up @@ -71,13 +72,17 @@ def __iter__(self):
def __len__(self):
    """Return the number of (split) batches this sampler will yield.

    Each underlying batch of up to ``self.sampler.batch_size`` (or the
    expected Poisson batch size) elements is split into chunks of at most
    ``self.max_batch_size``, so the count is scaled by the ratio of the
    two and rounded UP — ``np.ceil`` is essential here: truncating with
    ``int()`` alone under-reports the length whenever a batch splits into
    a non-integral number of chunks, which is the bug this version fixes.
    """
    if isinstance(self.sampler, BatchSampler):
        # Plain batch sampling: fixed batch_size per underlying batch.
        return int(
            np.ceil(
                len(self.sampler) * (self.sampler.batch_size / self.max_batch_size)
            )
        )
    elif isinstance(
        self.sampler,
        (UniformWithReplacementSampler, DistributedUniformWithReplacementSampler),
    ):
        # Poisson sampling: batch size is random; use its expectation
        # (sample_rate * num_samples) to estimate chunks per batch.
        expected_batch_size = self.sampler.sample_rate * self.sampler.num_samples
        return int(
            np.ceil(len(self.sampler) * (expected_batch_size / self.max_batch_size))
        )

    # Unknown sampler type: assume no splitting is needed.
    return len(self.sampler)

Expand Down

0 comments on commit c35262d

Please sign in to comment.