Memory bank distributed training support #1293

Merged
2 changes: 1 addition & 1 deletion benchmarks/imagenet/resnet50/main.py
@@ -230,7 +230,7 @@ def pretrain(
logger=TensorBoardLogger(save_dir=str(log_dir), name="pretrain"),
precision=precision,
strategy="ddp_find_unused_parameters_true",
- sync_batchnorm=True,
+ sync_batchnorm=accelerator != "cpu",  # Sync batchnorm is not supported on CPU.
num_sanity_val_steps=0,
)

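The new guard reflects that PyTorch's SyncBatchNorm only works with a distributed GPU backend. A minimal sketch of the same Trainer setup, assuming a recent PyTorch Lightning version (the accelerator choice and device count here are illustrative):

from pytorch_lightning import Trainer

accelerator = "gpu"  # switch to "cpu" for local debugging runs
trainer = Trainer(
    accelerator=accelerator,
    devices=2,
    strategy="ddp_find_unused_parameters_true",
    # SyncBatchNorm requires a distributed GPU backend, so disable it on CPU.
    sync_batchnorm=accelerator != "cpu",
)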
20 changes: 16 additions & 4 deletions benchmarks/imagenet/resnet50/mocov2.py
@@ -33,7 +33,11 @@ def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
self.projection_head = MoCoProjectionHead()
self.query_backbone = copy.deepcopy(self.backbone)
self.query_projection_head = MoCoProjectionHead()
- self.criterion = NTXentLoss(temperature=0.2, memory_bank_size=65536)
+ self.criterion = NTXentLoss(
+     temperature=0.2,
+     memory_bank_size=(65536, 128),
+     gather_distributed=self.trainer.num_devices > 1,
+ )

self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)

@@ -45,8 +49,16 @@ def forward_key_encoder(self, x: Tensor) -> Tuple[Tensor, Tensor]:
x, shuffle = batch_shuffle(batch=x, distributed=self.trainer.num_devices > 1)
features = self.forward(x).flatten(start_dim=1)
projections = self.projection_head(features)
- features = batch_unshuffle(batch=features, shuffle=shuffle)
- projections = batch_unshuffle(batch=projections, shuffle=shuffle)
+ features = batch_unshuffle(
+     batch=features,
+     shuffle=shuffle,
+     distributed=self.trainer.num_devices > 1,
+ )
+ projections = batch_unshuffle(
+     batch=projections,
+     shuffle=shuffle,
+     distributed=self.trainer.num_devices > 1,
+ )
return features, projections

def forward_query_encoder(self, x: Tensor) -> Tensor:
@@ -123,7 +135,7 @@ def configure_optimizers(self):
"scheduler": CosineWarmupScheduler(
optimizer=optimizer,
warmup_epochs=0,
- max_epochs=self.trainer.estimated_stepping_batches,
+ max_epochs=int(self.trainer.estimated_stepping_batches),
),
"interval": "step",
}
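The MoCo v2 hunks gate every distributed helper on self.trainer.num_devices > 1, so single-device runs skip the collective ops. A minimal single-process sketch of the shuffle/unshuffle round trip, using the batch_shuffle and batch_unshuffle signatures shown in the diff (the random tensor stands in for a real batch):

import torch
from lightly.models.utils import batch_shuffle, batch_unshuffle

distributed = torch.distributed.is_available() and torch.distributed.is_initialized()

x = torch.randn(8, 3, 32, 32)
# Shuffle the batch (across GPUs when distributed) before the key encoder...
x_shuffled, shuffle = batch_shuffle(batch=x, distributed=distributed)
features = x_shuffled.flatten(start_dim=1)  # stand-in for the backbone forward
# ...and restore the original sample order afterwards with the same flag.
features = batch_unshuffle(batch=features, shuffle=shuffle, distributed=distributed)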
2 changes: 1 addition & 1 deletion benchmarks/imagenet/resnet50/swav.py
@@ -40,7 +40,7 @@ def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
self.queues = ModuleList(
[
MemoryBankModule(
- size=self.n_batches_in_queue * self.batch_size_per_device
+ size=(self.n_batches_in_queue * self.batch_size_per_device, 128)
)
for _ in range(CROP_COUNTS[0])
]
4 changes: 2 additions & 2 deletions docs/source/getting_started/advanced.rst
@@ -298,9 +298,9 @@ For more information check the documentation:
.. code-block:: python

# to create a NTXentLoss with a memory bank (like for MoCo) set the
- # memory_bank_size parameter to a value > 0
+ # memory_bank_size parameter to a value > 0 and specify the feature dimension
from lightly.loss import NTXentLoss
- criterion = NTXentLoss(memory_bank_size=4096)
+ criterion = NTXentLoss(memory_bank_size=(4096, 128))
# the memory bank is used automatically for every forward pass
y0, y1 = resnet_moco(x0, x1)
loss = criterion(y0, y1)
@@ -799,7 +799,7 @@ def __init__(self, dataloader_kNN, num_classes):
# smog
self.n_groups = 300
memory_bank_size = 10000
- self.memory_bank = memory_bank.MemoryBankModule(size=memory_bank_size)
+ self.memory_bank = memory_bank.MemoryBankModule(size=(memory_bank_size, 128))
# create our loss
group_features = torch.nn.functional.normalize(
torch.rand(self.n_groups, 128), dim=1
@@ -225,7 +225,9 @@ def __init__(self, dataloader_kNN, num_classes):
utils.deactivate_requires_grad(self.projection_head_momentum)

# create our loss with the optional memory bank
- self.criterion = NTXentLoss(temperature=0.07, memory_bank_size=memory_bank_size)
+ self.criterion = NTXentLoss(
+     temperature=0.07, memory_bank_size=(memory_bank_size, 128)
+ )

def forward(self, x):
x = self.backbone(x).flatten(start_dim=1)
@@ -505,7 +507,7 @@ def __init__(self, dataloader_kNN, num_classes):
self.projection_head = heads.NNCLRProjectionHead(feature_dim, 4096, 256)

self.criterion = NTXentLoss()
- self.memory_bank = modules.NNMemoryBankModule(size=memory_bank_size)
+ self.memory_bank = modules.NNMemoryBankModule(size=(memory_bank_size, 256))

def forward(self, x):
y = self.backbone(x).flatten(start_dim=1)
@@ -329,7 +329,9 @@ def __init__(self, dataloader_kNN, num_classes):
utils.deactivate_requires_grad(self.projection_head_momentum)

# create our loss with the optional memory bank
- self.criterion = NTXentLoss(temperature=0.1, memory_bank_size=memory_bank_size)
+ self.criterion = NTXentLoss(
+     temperature=0.1, memory_bank_size=(memory_bank_size, 128)
+ )

def forward(self, x):
x = self.backbone(x).flatten(start_dim=1)
@@ -1014,7 +1016,7 @@ def __init__(self, dataloader_kNN, num_classes):
# smog
self.n_groups = 300
memory_bank_size = 10000
- self.memory_bank = memory_bank.MemoryBankModule(size=memory_bank_size)
+ self.memory_bank = memory_bank.MemoryBankModule(size=(memory_bank_size, 128))
# create our loss
group_features = torch.nn.functional.normalize(
torch.rand(self.n_groups, 128), dim=1
@@ -1319,7 +1321,7 @@ def __init__(self, dataloader_kNN, num_classes):
self.prototypes = heads.SwaVPrototypes(128, 3000, 1)
self.start_queue_at_epoch = 15
self.queues = nn.ModuleList(
- [memory_bank.MemoryBankModule(size=384) for _ in range(2)]
+ [memory_bank.MemoryBankModule(size=(384, 128)) for _ in range(2)]
) # Queue size reduced in order to work with a smaller dataset
self.criterion = SwaVLoss()

@@ -268,7 +268,7 @@ def __init__(self):
deactivate_requires_grad(self.projection_head_momentum)

# Create the loss function with memory bank.
- self.criterion = NTXentLoss(temperature=0.1, memory_bank_size=4096)
+ self.criterion = NTXentLoss(temperature=0.1, memory_bank_size=(4096, 128))

def training_step(self, batch, batch_idx):
(x_q, x_k), _, _ = batch
@@ -237,7 +237,9 @@ def __init__(self):
deactivate_requires_grad(self.projection_head_momentum)

# create our loss with the optional memory bank
- self.criterion = NTXentLoss(temperature=0.1, memory_bank_size=memory_bank_size)
+ self.criterion = NTXentLoss(
+     temperature=0.1, memory_bank_size=(memory_bank_size, 128)
+ )

def training_step(self, batch, batch_idx):
(x_q, x_k), _, _ = batch
2 changes: 1 addition & 1 deletion examples/pytorch/moco.py
@@ -62,7 +62,7 @@ def forward_momentum(self, x):
num_workers=8,
)

- criterion = NTXentLoss(memory_bank_size=4096)
+ criterion = NTXentLoss(memory_bank_size=(4096, 128))
optimizer = torch.optim.SGD(model.parameters(), lr=0.06)

epochs = 10
2 changes: 1 addition & 1 deletion examples/pytorch/nnclr.py
@@ -38,7 +38,7 @@ def forward(self, x):
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

- memory_bank = NNMemoryBankModule(size=4096)
+ memory_bank = NNMemoryBankModule(size=(4096, 128))
memory_bank.to(device)

transform = SimCLRTransform(input_size=32)
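With an explicit (size, dim) tuple the nearest-neighbour bank no longer infers its feature dimension from the first batch. A minimal single-process sketch of the lookup, assuming the update keyword used in lightly's NNCLR examples (the random tensor stands in for projection head output):

import torch
from lightly.models.modules import NNMemoryBankModule

memory_bank = NNMemoryBankModule(size=(4096, 128))
projections = torch.nn.functional.normalize(torch.randn(8, 128), dim=1)
# Replace each projection by its nearest neighbour in the bank and
# enqueue the current batch for future lookups.
neighbors = memory_bank(projections, update=True)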
2 changes: 1 addition & 1 deletion examples/pytorch/smog.py
@@ -80,7 +80,7 @@ def forward_momentum(self, x):

# memory bank because we reset the group features every 300 iterations
memory_bank_size = 300 * batch_size
- memory_bank = MemoryBankModule(size=memory_bank_size)
+ memory_bank = MemoryBankModule(size=(memory_bank_size, 128))

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
4 changes: 3 additions & 1 deletion examples/pytorch/swav_queue.py
@@ -20,7 +20,9 @@ def __init__(self, backbone):
self.prototypes = SwaVPrototypes(128, 512, 1)

self.start_queue_at_epoch = 2
- self.queues = nn.ModuleList([MemoryBankModule(size=3840) for _ in range(2)])
+ self.queues = nn.ModuleList(
+     [MemoryBankModule(size=(3840, 128)) for _ in range(2)]
+ )

def forward(self, high_resolution, low_resolution, epoch):
self.prototypes.normalize()
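A minimal sketch of how such a queue is typically consumed, assuming the (output, bank) return convention of MemoryBankModule used in lightly's SwAV queue examples:

import torch
from lightly.models.modules.memory_bank import MemoryBankModule

# Explicit (size, dim): room for 3840 queued projections of dimension 128.
queue = MemoryBankModule(size=(3840, 128))
projections = torch.nn.functional.normalize(torch.randn(8, 128), dim=1)
# Returns the input unchanged together with a detached snapshot of the queue;
# update=True enqueues the current batch.
projections, queue_features = queue(projections, update=True)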
2 changes: 1 addition & 1 deletion examples/pytorch_lightning/moco.py
@@ -29,7 +29,7 @@ def __init__(self):
deactivate_requires_grad(self.backbone_momentum)
deactivate_requires_grad(self.projection_head_momentum)

- self.criterion = NTXentLoss(memory_bank_size=4096)
+ self.criterion = NTXentLoss(memory_bank_size=(4096, 128))

def forward(self, x):
query = self.backbone(x).flatten(start_dim=1)
2 changes: 1 addition & 1 deletion examples/pytorch_lightning/nnclr.py
@@ -23,7 +23,7 @@ def __init__(self):
self.backbone = nn.Sequential(*list(resnet.children())[:-1])
self.projection_head = NNCLRProjectionHead(512, 512, 128)
self.prediction_head = NNCLRPredictionHead(128, 512, 128)
- self.memory_bank = NNMemoryBankModule(size=4096)
+ self.memory_bank = NNMemoryBankModule(size=(4096, 128))

self.criterion = NTXentLoss()

4 changes: 3 additions & 1 deletion examples/pytorch_lightning/swav_queue.py
@@ -21,7 +21,9 @@ def __init__(self):
self.projection_head = SwaVProjectionHead(512, 512, 128)
self.prototypes = SwaVPrototypes(128, 512, 1)
self.start_queue_at_epoch = 2
- self.queues = nn.ModuleList([MemoryBankModule(size=3840) for _ in range(2)])
+ self.queues = nn.ModuleList(
+     [MemoryBankModule(size=(3840, 128)) for _ in range(2)]
+ )
self.criterion = SwaVLoss()

def training_step(self, batch, batch_idx):
2 changes: 1 addition & 1 deletion examples/pytorch_lightning_distributed/moco.py
@@ -29,7 +29,7 @@ def __init__(self):
deactivate_requires_grad(self.backbone_momentum)
deactivate_requires_grad(self.projection_head_momentum)

- self.criterion = NTXentLoss(memory_bank_size=4096)
+ self.criterion = NTXentLoss(memory_bank_size=(4096, 128))

def forward(self, x):
query = self.backbone(x).flatten(start_dim=1)
2 changes: 1 addition & 1 deletion examples/pytorch_lightning_distributed/nnclr.py
@@ -23,7 +23,7 @@ def __init__(self):
self.backbone = nn.Sequential(*list(resnet.children())[:-1])
self.projection_head = NNCLRProjectionHead(512, 512, 128)
self.prediction_head = NNCLRPredictionHead(128, 512, 128)
- self.memory_bank = NNMemoryBankModule(size=4096)
+ self.memory_bank = NNMemoryBankModule(size=(4096, 128))

self.criterion = NTXentLoss()

18 changes: 12 additions & 6 deletions lightly/loss/ntx_ent_loss.py
@@ -3,6 +3,8 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved

+ from typing import Sequence, Union

import torch
from torch import distributed as torch_dist
from torch import nn
@@ -25,8 +27,14 @@ class NTXentLoss(MemoryBankModule):
temperature:
Scale logits by the inverse of the temperature.
memory_bank_size:
- Number of negative samples to store in the memory bank.
- Use 0 for SimCLR. For MoCo we typically use numbers like 4096 or 65536.
+ Size of the memory bank as (num_features, dim) tuple. num_features is the
+ number of negative samples stored in the memory bank. If num_features is 0,
+ the memory bank is disabled. Use 0 for SimCLR. For MoCo we typically use
+ numbers like 4096 or 65536.
+ Deprecated: If only a single integer is passed, it is interpreted as the
+ number of features and the feature dimension is inferred from the first
+ batch stored in the memory bank. Leaving out the feature dimension might
+ lead to errors in distributed training.
gather_distributed:
If True then negatives from all gpus are gathered before the
loss calculation. If a memory bank is used and gather_distributed is True,
@@ -56,12 +64,10 @@ class NTXentLoss(MemoryBankModule):
def __init__(
self,
temperature: float = 0.5,
- memory_bank_size: int = 0,
+ memory_bank_size: Union[int, Sequence[int]] = 0,
gather_distributed: bool = False,
):
- super(NTXentLoss, self).__init__(
-     size=memory_bank_size, gather_distributed=gather_distributed
- )
+ super().__init__(size=memory_bank_size, gather_distributed=gather_distributed)
self.temperature = temperature
self.gather_distributed = gather_distributed
self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")
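The explicit tuple removes the lazy dimension inference that could leave ranks with differently shaped banks. A minimal single-process sketch of both accepted forms, based on the docstring above (gather_distributed stays False here):

import torch
from lightly.loss import NTXentLoss

# Explicit form: a bank of 4096 negatives with feature dimension 128.
criterion = NTXentLoss(temperature=0.5, memory_bank_size=(4096, 128))

# Deprecated form: the dimension is inferred from the first batch, which
# may fail when training is distributed.
legacy_criterion = NTXentLoss(memory_bank_size=4096)

z0 = torch.nn.functional.normalize(torch.randn(8, 128), dim=1)
z1 = torch.nn.functional.normalize(torch.randn(8, 128), dim=1)
loss = criterion(z0, z1)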
19 changes: 14 additions & 5 deletions lightly/loss/regularizer/co2.py
@@ -3,6 +3,8 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved

+ from typing import Sequence, Union

import torch

from lightly.models.modules.memory_bank import MemoryBankModule
@@ -19,15 +21,19 @@ class CO2Regularizer(MemoryBankModule):
t_consistency:
Temperature used during softmax calculations.
memory_bank_size:
- Number of negative samples to store in the memory bank.
- Use 0 to use the second batch for negative samples.
+ Size of the memory bank as (num_features, dim) tuple. num_features is the
+ number of negatives stored in the bank. If set to 0, the memory bank is
+ disabled. Deprecated: If only a single integer is passed, it is interpreted
+ as the number of features and the feature dimension is inferred from the
+ first batch stored in the memory bank. Leaving out the feature dimension
+ might lead to errors in distributed training.

Examples:
>>> # initialize loss function for MoCo
- >>> loss_fn = NTXentLoss(memory_bank_size=4096)
+ >>> loss_fn = NTXentLoss(memory_bank_size=(4096, 128))
>>>
>>> # initialize CO2 regularizer
- >>> co2 = CO2Regularizer(alpha=1.0, memory_bank_size=4096)
+ >>> co2 = CO2Regularizer(alpha=1.0, memory_bank_size=(4096, 128))
>>>
>>> # generate two random transforms of images
>>> t0 = transforms(images)
@@ -42,7 +48,10 @@ class CO2Regularizer(MemoryBankModule):
"""

def __init__(
- self, alpha: float = 1, t_consistency: float = 0.05, memory_bank_size: int = 0
+ self,
+ alpha: float = 1,
+ t_consistency: float = 0.05,
+ memory_bank_size: Union[int, Sequence[int]] = 0,
):
super(CO2Regularizer, self).__init__(size=memory_bank_size)
# try-catch the KLDivLoss construction for backwards compatibility
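A minimal sketch combining the regularizer with the contrastive loss, following the docstring example above (random tensors stand in for projection head outputs):

import torch
from lightly.loss import NTXentLoss
from lightly.loss.regularizer import CO2Regularizer

loss_fn = NTXentLoss(memory_bank_size=(4096, 128))
co2 = CO2Regularizer(alpha=1.0, memory_bank_size=(4096, 128))

out0 = torch.nn.functional.normalize(torch.randn(8, 128), dim=1)
out1 = torch.nn.functional.normalize(torch.randn(8, 128), dim=1)
# Total objective: contrastive loss plus the consistency regularization term.
loss = loss_fn(out0, out1) + co2(out0, out1)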