diff --git a/.gitignore b/.gitignore index 2fdc54c33..031ae8795 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ lightning_logs/ **lightning_logs/ **/__MACOSX +datasets/ docs/source/tutorials/package/* docs/source/tutorials/platform/* docs/source/tutorials_source/platform/data diff --git a/README.md b/README.md index 1733d15eb..7938b21e8 100644 --- a/README.md +++ b/README.md @@ -287,11 +287,18 @@ tuned for maximum accuracy. For detailed results and more info about the benchma > > See the [benchmarking scripts](./benchmarks/imagenet/resnet50/) for details. -| Model | Backbone | Batch Size | Epochs | Linear Top1 | Finetune Top1 | KNN Top1 | Tensorboard | Checkpoint | -|-------------|----------|------------|--------|-------------|---------------|----------|-------------|------------| -| DINO | Res50 | 128 | 100 | 68.2 | 72.5 | 49.9 | [link](https://tensorboard.dev/experiment/DvKHX9sNSWWqDrRksllPLA) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dino_2023-06-06_13-59-48/pretrain/version_0/checkpoints/epoch%3D99-step%3D1000900.ckpt) | -| SimCLR | Res50 | 256 | 100 | 63.2 | N/A | 44.9 | [link](https://tensorboard.dev/experiment/JwNs9E02TeeQkS7aljh8dA) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_simclr_2023-05-04_09-02-54/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) | -| SwAV | Res50 | 256 | 100 | 67.2 | 75.4 | 49.5 | [link](https://tensorboard.dev/experiment/Ipx4Oxl5Qkqm5Sl5kWyKKg) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_swav_2023-05-25_08-29-14/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) +| Model | Backbone | Batch Size | Epochs | Linear Top1 | Finetune Top1 | KNN Top1 | Tensorboard | Checkpoint | +|----------------|----------|------------|--------|-------------|---------------|----------|-------------|------------| +| BYOL | Res50 | 256 | 100 | 62.4 | 74.0 | 45.6 | [link](https://tensorboard.dev/experiment/Z0iG2JLaTJe5nuBD7DK1bg) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_byol_2023-07-10_10-37-32/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) | +| DINO | Res50 | 128 | 100 | 68.2 | 72.5 | 49.9 | [link](https://tensorboard.dev/experiment/DvKHX9sNSWWqDrRksllPLA) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dino_2023-06-06_13-59-48/pretrain/version_0/checkpoints/epoch%3D99-step%3D1000900.ckpt) | +| SimCLR* | Res50 | 256 | 100 | 63.2 | 73.9 | 44.8 | [link](https://tensorboard.dev/experiment/Ugol97adQdezgcVibDYMMA) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_simclr_2023-06-22_09-11-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) | +| SimCLR* + DCL | Res50 | 256 | 100 | 65.1 | 73.5 | 49.6 | [link](https://tensorboard.dev/experiment/k4ZonZ77QzmBkc0lXswQlg/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dcl_2023-07-04_16-51-40/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) | +| SimCLR* + DCLW | Res50 | 256 | 100 | 64.5 | 73.2 | 48.5 | [link](https://tensorboard.dev/experiment/TrALnpwFQ4OkZV3uvaX7wQ/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dclw_2023-07-07_14-57-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) | +| SwAV | Res50 | 256 | 100 | 67.2 | 75.4 | 49.5 | [link](https://tensorboard.dev/experiment/Ipx4Oxl5Qkqm5Sl5kWyKKg) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_swav_2023-05-25_08-29-14/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) | + +*\*We use square root learning rate scaling instead of linear scaling as it yields +better results for smaller batch sizes. See Appendix B.1 in [SimCLR paper](https://arxiv.org/abs/2002.05709).* + ### ImageNette diff --git a/benchmarks/imagenet/resnet50/byol.py b/benchmarks/imagenet/resnet50/byol.py new file mode 100644 index 000000000..bfc2f1080 --- /dev/null +++ b/benchmarks/imagenet/resnet50/byol.py @@ -0,0 +1,148 @@ +import copy +from typing import List, Tuple + +import torch +from pytorch_lightning import LightningModule +from torch import Tensor +from torch.nn import Identity +from torchvision.models import resnet50 + +from lightly.loss import NegativeCosineSimilarity +from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead +from lightly.models.utils import get_weight_decay_parameters, update_momentum +from lightly.transforms import SimCLRTransform +from lightly.utils.benchmarking import OnlineLinearClassifier +from lightly.utils.lars import LARS +from lightly.utils.scheduler import CosineWarmupScheduler, cosine_schedule + + +class BYOL(LightningModule): + def __init__(self, batch_size_per_device: int, num_classes: int) -> None: + super().__init__() + self.save_hyperparameters() + self.batch_size_per_device = batch_size_per_device + + resnet = resnet50() + resnet.fc = Identity() # Ignore classification head + self.backbone = resnet + self.projection_head = BYOLProjectionHead() + self.student_backbone = copy.deepcopy(self.backbone) + self.student_projection_head = BYOLProjectionHead() + self.student_prediction_head = BYOLPredictionHead() + self.criterion = NegativeCosineSimilarity() + + self.online_classifier = OnlineLinearClassifier(num_classes=num_classes) + + def forward(self, x: Tensor) -> Tensor: + return self.backbone(x) + + @torch.no_grad() + def forward_teacher(self, x: Tensor) -> Tuple[Tensor, Tensor]: + features = self(x).flatten(start_dim=1) + projections = self.projection_head(features) + return features, projections + + def forward_student(self, x: Tensor) -> Tensor: + features = self.student_backbone(x).flatten(start_dim=1) + projections = self.student_projection_head(features) + predictions = self.student_prediction_head(projections) + return predictions + + def training_step( + self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int + ) -> Tensor: + # Momentum update teacher. + # Settings follow original code for 100 epochs which are slightly different + # from the paper, see: + # https://github.com/deepmind/deepmind-research/blob/f5de0ede8430809180254ee957abf36ed62579ef/byol/configs/byol.py#L21-L23 + momentum = cosine_schedule( + step=self.trainer.global_step, + max_steps=self.trainer.estimated_stepping_batches, + start_value=0.99, + end_value=1.0, + ) + update_momentum(self.student_backbone, self.backbone, m=momentum) + update_momentum(self.student_projection_head, self.projection_head, m=momentum) + + # Forward pass and loss calculation. + views, targets = batch[0], batch[1] + teacher_features_0, teacher_projections_0 = self.forward_teacher(views[0]) + _, teacher_projections_1 = self.forward_teacher(views[1]) + student_predictions_0 = self.forward_student(views[0]) + student_predictions_1 = self.forward_student(views[1]) + # NOTE: Factor 2 because: L2(norm(x), norm(y)) = 2 - 2 * cossim(x, y) + loss_0 = 2 * self.criterion(teacher_projections_0, student_predictions_1) + loss_1 = 2 * self.criterion(teacher_projections_1, student_predictions_0) + # NOTE: No mean because original code only takes mean over batch dimension, not + # views. + loss = loss_0 + loss_1 + self.log( + "train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets) + ) + + # Online linear evaluation. + cls_loss, cls_log = self.online_classifier.training_step( + (teacher_features_0.detach(), targets), batch_idx + ) + self.log_dict(cls_log, sync_dist=True, batch_size=len(targets)) + return loss + cls_loss + + def validation_step( + self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int + ) -> Tensor: + images, targets = batch[0], batch[1] + features = self.forward(images).flatten(start_dim=1) + cls_loss, cls_log = self.online_classifier.validation_step( + (features.detach(), targets), batch_idx + ) + self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets)) + return cls_loss + + def configure_optimizers(self): + # Don't use weight decay for batch norm, bias parameters, and classification + # head to improve performance. + params, params_no_weight_decay = get_weight_decay_parameters( + [ + self.student_backbone, + self.student_projection_head, + self.student_prediction_head, + ] + ) + optimizer = LARS( + [ + {"name": "byol", "params": params}, + { + "name": "byol_no_weight_decay", + "params": params_no_weight_decay, + "weight_decay": 0.0, + }, + { + "name": "online_classifier", + "params": self.online_classifier.parameters(), + "weight_decay": 0.0, + }, + ], + # Settings follow original code for 100 epochs which are slightly different + # from the paper, see: + # https://github.com/deepmind/deepmind-research/blob/f5de0ede8430809180254ee957abf36ed62579ef/byol/configs/byol.py#L21-L23 + lr=0.45 * self.batch_size_per_device * self.trainer.world_size / 256, + momentum=0.9, + weight_decay=1e-6, + ) + scheduler = { + "scheduler": CosineWarmupScheduler( + optimizer=optimizer, + warmup_epochs=( + self.trainer.estimated_stepping_batches + / self.trainer.max_epochs + * 10 + ), + max_epochs=self.trainer.estimated_stepping_batches, + ), + "interval": "step", + } + return [optimizer], [scheduler] + + +# BYOL uses same transform as SimCLR. +transform = SimCLRTransform() diff --git a/benchmarks/imagenet/resnet50/dcl.py b/benchmarks/imagenet/resnet50/dcl.py new file mode 100644 index 000000000..42c66c7c5 --- /dev/null +++ b/benchmarks/imagenet/resnet50/dcl.py @@ -0,0 +1,113 @@ +import math +from typing import List, Tuple + +import torch +from pytorch_lightning import LightningModule +from torch import Tensor +from torch.nn import Identity +from torchvision.models import resnet50 + +from lightly.loss.dcl_loss import DCLLoss +from lightly.models.modules import SimCLRProjectionHead +from lightly.models.utils import get_weight_decay_parameters +from lightly.transforms import SimCLRTransform +from lightly.utils.benchmarking import OnlineLinearClassifier +from lightly.utils.lars import LARS +from lightly.utils.scheduler import CosineWarmupScheduler + + +class DCL(LightningModule): + def __init__(self, batch_size_per_device: int, num_classes: int) -> None: + super().__init__() + self.save_hyperparameters() + self.batch_size_per_device = batch_size_per_device + + resnet = resnet50() + resnet.fc = Identity() # Ignore classification head + self.backbone = resnet + self.projection_head = SimCLRProjectionHead() # DCL uses SimCLR head + self.criterion = DCLLoss(temperature=0.1, gather_distributed=True) + + self.online_classifier = OnlineLinearClassifier(num_classes=num_classes) + + def forward(self, x: Tensor) -> Tensor: + return self.backbone(x) + + def training_step( + self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int + ) -> Tensor: + views, targets = batch[0], batch[1] + features = self.forward(torch.cat(views)).flatten(start_dim=1) + z = self.projection_head(features) + z0, z1 = z.chunk(len(views)) + loss = self.criterion(z0, z1) + self.log( + "train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets) + ) + + cls_loss, cls_log = self.online_classifier.training_step( + (features.detach(), targets.repeat(len(views))), batch_idx + ) + self.log_dict(cls_log, sync_dist=True, batch_size=len(targets)) + return loss + cls_loss + + def validation_step( + self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int + ) -> Tensor: + images, targets = batch[0], batch[1] + features = self.forward(images).flatten(start_dim=1) + cls_loss, cls_log = self.online_classifier.validation_step( + (features.detach(), targets), batch_idx + ) + self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets)) + return cls_loss + + def configure_optimizers(self): + # Don't use weight decay for batch norm, bias parameters, and classification + # head to improve performance. + params, params_no_weight_decay = get_weight_decay_parameters( + [self.backbone, self.projection_head] + ) + optimizer = LARS( + [ + {"name": "dcl", "params": params}, + { + "name": "dcl_no_weight_decay", + "params": params_no_weight_decay, + "weight_decay": 0.0, + }, + { + "name": "online_classifier", + "params": self.online_classifier.parameters(), + "weight_decay": 0.0, + }, + ], + # DCL uses SimCLR's learning rate scaling scheme. + # Square root learning rate scaling improves performance for small + # batch sizes (<=2048) and few training epochs (<=200). Alternatively, + # linear scaling can be used for larger batches and longer training: + # lr=0.3 * self.batch_size_per_device * self.trainer.world_size / 256 + # See Appendix B.1. in the SimCLR paper https://arxiv.org/abs/2002.05709 + lr=0.075 * math.sqrt(self.batch_size_per_device * self.trainer.world_size), + momentum=0.9, + # Note: Paper uses weight decay of 1e-6 but reference code 1e-4. See: + # https://github.com/google-research/simclr/blob/2fc637bdd6a723130db91b377ac15151e01e4fc2/README.md?plain=1#L103 + weight_decay=1e-6, + ) + scheduler = { + "scheduler": CosineWarmupScheduler( + optimizer=optimizer, + warmup_epochs=( + self.trainer.estimated_stepping_batches + / self.trainer.max_epochs + * 10 + ), + max_epochs=self.trainer.estimated_stepping_batches, + ), + "interval": "step", + } + return [optimizer], [scheduler] + + +# DCL uses SimCLR augmentations +transform = SimCLRTransform() diff --git a/benchmarks/imagenet/resnet50/dclw.py b/benchmarks/imagenet/resnet50/dclw.py new file mode 100644 index 000000000..bcae95d6e --- /dev/null +++ b/benchmarks/imagenet/resnet50/dclw.py @@ -0,0 +1,113 @@ +import math +from typing import List, Tuple + +import torch +from pytorch_lightning import LightningModule +from torch import Tensor +from torch.nn import Identity +from torchvision.models import resnet50 + +from lightly.loss.dcl_loss import DCLWLoss +from lightly.models.modules import SimCLRProjectionHead +from lightly.models.utils import get_weight_decay_parameters +from lightly.transforms import SimCLRTransform +from lightly.utils.benchmarking import OnlineLinearClassifier +from lightly.utils.lars import LARS +from lightly.utils.scheduler import CosineWarmupScheduler + + +class DCLW(LightningModule): + def __init__(self, batch_size_per_device: int, num_classes: int) -> None: + super().__init__() + self.save_hyperparameters() + self.batch_size_per_device = batch_size_per_device + + resnet = resnet50() + resnet.fc = Identity() # Ignore classification head + self.backbone = resnet + self.projection_head = SimCLRProjectionHead() # DCLW uses SimCLR head + self.criterion = DCLWLoss(temperature=0.1, sigma=0.5, gather_distributed=True) + + self.online_classifier = OnlineLinearClassifier(num_classes=num_classes) + + def forward(self, x: Tensor) -> Tensor: + return self.backbone(x) + + def training_step( + self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int + ) -> Tensor: + views, targets = batch[0], batch[1] + features = self.forward(torch.cat(views)).flatten(start_dim=1) + z = self.projection_head(features) + z0, z1 = z.chunk(len(views)) + loss = self.criterion(z0, z1) + self.log( + "train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets) + ) + + cls_loss, cls_log = self.online_classifier.training_step( + (features.detach(), targets.repeat(len(views))), batch_idx + ) + self.log_dict(cls_log, sync_dist=True, batch_size=len(targets)) + return loss + cls_loss + + def validation_step( + self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int + ) -> Tensor: + images, targets = batch[0], batch[1] + features = self.forward(images).flatten(start_dim=1) + cls_loss, cls_log = self.online_classifier.validation_step( + (features.detach(), targets), batch_idx + ) + self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets)) + return cls_loss + + def configure_optimizers(self): + # Don't use weight decay for batch norm, bias parameters, and classification + # head to improve performance. + params, params_no_weight_decay = get_weight_decay_parameters( + [self.backbone, self.projection_head] + ) + optimizer = LARS( + [ + {"name": "dclw", "params": params}, + { + "name": "dclw_no_weight_decay", + "params": params_no_weight_decay, + "weight_decay": 0.0, + }, + { + "name": "online_classifier", + "params": self.online_classifier.parameters(), + "weight_decay": 0.0, + }, + ], + # DCLW uses SimCLR's learning rate scaling scheme. + # Square root learning rate scaling improves performance for small + # batch sizes (<=2048) and few training epochs (<=200). Alternatively, + # linear scaling can be used for larger batches and longer training: + # lr=0.3 * self.batch_size_per_device * self.trainer.world_size / 256 + # See Appendix B.1. in the SimCLR paper https://arxiv.org/abs/2002.05709 + lr=0.075 * math.sqrt(self.batch_size_per_device * self.trainer.world_size), + momentum=0.9, + # Note: Paper uses weight decay of 1e-6 but reference code 1e-4. See: + # https://github.com/google-research/simclr/blob/2fc637bdd6a723130db91b377ac15151e01e4fc2/README.md?plain=1#L103 + weight_decay=1e-6, + ) + scheduler = { + "scheduler": CosineWarmupScheduler( + optimizer=optimizer, + warmup_epochs=( + self.trainer.estimated_stepping_batches + / self.trainer.max_epochs + * 10 + ), + max_epochs=self.trainer.estimated_stepping_batches, + ), + "interval": "step", + } + return [optimizer], [scheduler] + + +# DCLW uses SimCLR augmentations +transform = SimCLRTransform() diff --git a/benchmarks/imagenet/resnet50/main.py b/benchmarks/imagenet/resnet50/main.py index bdda083e8..ce1e9e40b 100644 --- a/benchmarks/imagenet/resnet50/main.py +++ b/benchmarks/imagenet/resnet50/main.py @@ -3,6 +3,9 @@ from pathlib import Path from typing import Sequence, Union +import byol +import dcl +import dclw import dino import finetune_eval import knn_eval @@ -43,6 +46,9 @@ parser.add_argument("--skip-finetune-eval", action="store_true") METHODS = { + "byol": {"model": byol.BYOL, "transform": byol.transform}, + "dcl": {"model": dcl.DCL, "transform": dcl.transform}, + "dclw": {"model": dclw.DCLW, "transform": dclw.transform}, "dino": {"model": dino.DINO, "transform": dino.transform}, "mocov2": {"model": mocov2.MoCoV2, "transform": mocov2.transform}, "simclr": {"model": simclr.SimCLR, "transform": simclr.transform}, diff --git a/benchmarks/imagenet/resnet50/swav.py b/benchmarks/imagenet/resnet50/swav.py index af1ac3f3b..62d7fc10a 100644 --- a/benchmarks/imagenet/resnet50/swav.py +++ b/benchmarks/imagenet/resnet50/swav.py @@ -105,14 +105,14 @@ def training_step( loss, prog_bar=True, sync_dist=True, - batch_size_per_device=len(targets), + batch_size=len(targets), ) # Calculate the classification loss. cls_loss, cls_log = self.online_classifier.training_step( (multi_crop_features[0].detach(), targets), batch_idx ) - self.log_dict(cls_log, sync_dist=True, batch_size_per_device=len(targets)) + self.log_dict(cls_log, sync_dist=True, batch_size=len(targets)) return loss + cls_loss def validation_step( @@ -123,9 +123,7 @@ def validation_step( cls_loss, cls_log = self.online_classifier.validation_step( (features.detach(), targets), batch_idx ) - self.log_dict( - cls_log, prog_bar=True, sync_dist=True, batch_size_per_device=len(targets) - ) + self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets)) return cls_loss def configure_optimizers(self): @@ -192,10 +190,10 @@ def _update_queue( # Get the queue projections queue_projections = [] for i in range(len(queues)): - _, projections = queues[i](projections[i], update=True) + _, queue_proj = queues[i](projections[i], update=True) # Queue projections are in (num_ftrs X queue_length) shape, while the high res # projections are in (batch_size_per_device X num_ftrs). Swap the axes for interoperability. - projections = torch.permute(projections, (1, 0)) - queue_projections.append(projections) + queue_proj = torch.permute(queue_proj, (1, 0)) + queue_projections.append(queue_proj) return queue_projections diff --git a/docs/source/getting_started/benchmarks.rst b/docs/source/getting_started/benchmarks.rst index 24ff79545..585a6a514 100644 --- a/docs/source/getting_started/benchmarks.rst +++ b/docs/source/getting_started/benchmarks.rst @@ -33,10 +33,15 @@ See the `benchmarking scripts `_", "`link `_" "DINO", "Res50", "128", "100", "68.2", "87.9", "72.5", "90.8", "49.9", "78.7", "`link `_", "`link `_" - "SimCLR", "Res50", "256", "100", "63.2", "85.3", "N/A", "N/A", "44.9", "74.2", "`link `_", "`link `_" + "SimCLR*", "Res50", "256", "100", "63.2", "85.2", "73.9", "91.9", "44.8", "73.9", "`link `_", "`link `_" + "SimCLR* + DCL", "Res50", "256", "100", "65.1", "86.2", "73.5", "91.7", "49.6", "77.5", "`link `_", "`link `_" + "SimCLR* + DCLW", "Res50", "256", "100", "64.5", "86.0", "73.2", "91.5", "48.5", "76.8", "`link `_", "`link `_" "SwAV", "Res50", "256", "100", "67.2", "88.1", "75.4", "92.7", "49.5", "78.6", "`link `_", "`link `_" +*\*We use square root learning rate scaling instead of linear scaling as it yields better results for smaller batch sizes. See Appendix B.1 in SimCLR paper.* + ImageNette ----------------------------------- diff --git a/examples/pytorch/ijepa.py b/examples/pytorch/ijepa.py new file mode 100644 index 000000000..eb4730e04 --- /dev/null +++ b/examples/pytorch/ijepa.py @@ -0,0 +1,117 @@ +import copy + +import torch +import torchvision +from torch import nn +from torch.nn import functional as F +from tqdm import tqdm + +from lightly.data.collate import IJEPAMaskCollator +from lightly.models import utils +from lightly.models.modules.ijepa import IJEPABackbone, IJEPAPredictor +from lightly.transforms.ijepa_transform import IJEPATransform + + +class IJEPA(nn.Module): + def __init__(self, vit_encoder, vit_predictor, momentum_scheduler): + super().__init__() + self.encoder = IJEPABackbone.from_vit(vit_encoder) + self.predictor = IJEPAPredictor.from_vit_encoder( + vit_predictor.encoder, + (vit_predictor.image_size // vit_predictor.patch_size) ** 2, + ) + self.target_encoder = copy.deepcopy(self.encoder) + self.momentum_scheduler = momentum_scheduler + + def forward_target(self, imgs, masks_enc, masks_pred): + with torch.no_grad(): + h = self.target_encoder(imgs) + h = F.layer_norm(h, (h.size(-1),)) # normalize over feature-dim + B = len(h) + # -- create targets (masked regions of h) + h = utils.apply_masks(h, masks_pred) + h = utils.repeat_interleave_batch(h, B, repeat=len(masks_enc)) + return h + + def forward_context(self, imgs, masks_enc, masks_pred): + z = self.encoder(imgs, masks_enc) + z = self.predictor(z, masks_enc, masks_pred) + return z + + def forward(self, imgs, masks_enc, masks_pred): + z = self.forward_context(imgs, masks_enc, masks_pred) + h = self.forward_target(imgs, masks_enc, masks_pred) + return z, h + + def update_target_encoder( + self, + ): + with torch.no_grad(): + m = next(self.momentum_scheduler) + for param_q, param_k in zip( + self.encoder.parameters(), self.target_encoder.parameters() + ): + param_k.data.mul_(m).add_((1.0 - m) * param_q.detach().data) + + +collator = IJEPAMaskCollator( + input_size=(224, 224), + patch_size=32, +) + +transform = IJEPATransform() + +# we ignore object detection annotations by setting target_transform to return 0 +# or create a dataset from a folder containing images or videos: +# dataset = LightlyDataset("path/to/folder") +dataset = torchvision.datasets.VOCDetection( + "datasets/pascal_voc", + download=True, + transform=transform, + target_transform=lambda t: 0, +) +data_loader = torch.utils.data.DataLoader( + dataset, collate_fn=collator, batch_size=10, persistent_workers=False +) + +ema = (0.996, 1.0) +ipe_scale = 1.0 +ipe = len(data_loader) +num_epochs = 10 +momentum_scheduler = ( + ema[0] + i * (ema[1] - ema[0]) / (ipe * num_epochs * ipe_scale) + for i in range(int(ipe * num_epochs * ipe_scale) + 1) +) + +vit_for_predictor = torchvision.models.vit_b_32(pretrained=False) +vit_for_embedder = torchvision.models.vit_b_32(pretrained=False) +model = IJEPA(vit_for_embedder, vit_for_predictor, momentum_scheduler) + +criterion = nn.SmoothL1Loss() +optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4) +device = "cuda" if torch.cuda.is_available() else "cpu" +model.to(device) + +print("Starting Training") +for epoch in range(num_epochs): + total_loss = 0 + for udata, masks_enc, masks_pred in tqdm(data_loader): + + def load_imgs(): + # -- unsupervised imgs + imgs = udata[0].to(device, non_blocking=True) + masks_1 = [u.to(device, non_blocking=True) for u in masks_enc] + masks_2 = [u.to(device, non_blocking=True) for u in masks_pred] + return (imgs, masks_1, masks_2) + + imgs, masks_enc, masks_pred = load_imgs() + z, h = model(imgs, masks_enc, masks_pred) + loss = criterion(z, h) + total_loss += loss.detach() + loss.backward() + optimizer.step() + optimizer.zero_grad() + model.update_target_encoder() + + avg_loss = total_loss / len(data_loader) + print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}") diff --git a/examples/pytorch_lightning/ijepa.py b/examples/pytorch_lightning/ijepa.py new file mode 100644 index 000000000..464090415 --- /dev/null +++ b/examples/pytorch_lightning/ijepa.py @@ -0,0 +1 @@ +# TODO diff --git a/examples/pytorch_lightning_distributed/ijepa.py b/examples/pytorch_lightning_distributed/ijepa.py new file mode 100644 index 000000000..464090415 --- /dev/null +++ b/examples/pytorch_lightning_distributed/ijepa.py @@ -0,0 +1 @@ +# TODO diff --git a/lightly/__init__.py b/lightly/__init__.py index 533ebb904..50929e040 100644 --- a/lightly/__init__.py +++ b/lightly/__init__.py @@ -75,7 +75,7 @@ # All Rights Reserved __name__ = "lightly" -__version__ = "1.4.8" +__version__ = "1.4.12" import os @@ -94,17 +94,6 @@ msg = f"Partial import of {__name__}=={__version__} during build process." print(msg) else: - # see if prefetch_generator is available - try: - import prefetch_generator - except ImportError: - _prefetch_generator_available = False - else: - _prefetch_generator_available = True - - def _is_prefetch_generator_available(): - return _prefetch_generator_available - # see if torchvision vision transformer is available try: import torchvision.models.vision_transformer @@ -122,18 +111,10 @@ def _is_prefetch_generator_available(): from multiprocessing import current_process if current_process().name == "MainProcess": - from lightly.api.version_checking import ( - LightlyAPITimeoutException, - is_latest_version, - ) - from lightly.openapi_generated.swagger_client.rest import ApiException + from lightly.api.version_checking import is_latest_version try: is_latest_version(current_version=__version__) - except ( - ValueError, - ApiException, - LightlyAPITimeoutException, - AttributeError, - ): + except Exception: + # Version check should never break the package. pass diff --git a/lightly/api/api_workflow_client.py b/lightly/api/api_workflow_client.py index 9c89e57ea..18fa190e7 100644 --- a/lightly/api/api_workflow_client.py +++ b/lightly/api/api_workflow_client.py @@ -8,6 +8,7 @@ from requests import Response from lightly.__init__ import __version__ +from lightly.api import utils, version_checking from lightly.api.api_workflow_artifacts import _ArtifactsMixin from lightly.api.api_workflow_collaboration import _CollaborationMixin from lightly.api.api_workflow_compute_worker import _ComputeWorkerMixin @@ -18,19 +19,11 @@ from lightly.api.api_workflow_predictions import _PredictionsMixin from lightly.api.api_workflow_selection import _SelectionMixin from lightly.api.api_workflow_tags import _TagsMixin -from lightly.api.api_workflow_upload_dataset import _UploadDatasetMixin from lightly.api.api_workflow_upload_embeddings import _UploadEmbeddingsMixin from lightly.api.api_workflow_upload_metadata import _UploadCustomMetadataMixin from lightly.api.swagger_api_client import LightlySwaggerApiClient -from lightly.api.utils import ( - DatasourceType, - get_api_client_configuration, - get_signed_url_destination, -) -from lightly.api.version_checking import ( - LightlyAPITimeoutException, - is_compatible_version, -) +from lightly.api.utils import DatasourceType +from lightly.api.version_checking import LightlyAPITimeoutException from lightly.openapi_generated.swagger_client.api import ( CollaborationApi, DatasetsApi, @@ -49,7 +42,7 @@ ) from lightly.openapi_generated.swagger_client.models import Creator, DatasetData from lightly.openapi_generated.swagger_client.rest import ApiException -from lightly.utils.reordering import sort_items_by_keys +from lightly.utils import reordering # Env variable for server side encryption on S3 LIGHTLY_S3_SSE_KMS_KEY = "LIGHTLY_S3_SSE_KMS_KEY" @@ -58,7 +51,6 @@ class ApiWorkflowClient( _UploadEmbeddingsMixin, _SelectionMixin, - _UploadDatasetMixin, _DownloadDatasetMixin, _DatasetsMixin, _UploadCustomMetadataMixin, @@ -101,7 +93,7 @@ def __init__( creator: str = Creator.USER_PIP, ): try: - if not is_compatible_version(__version__): + if not version_checking.is_compatible_version(__version__): warnings.warn( UserWarning( ( @@ -119,7 +111,7 @@ def __init__( ): pass - configuration = get_api_client_configuration(token=token) + configuration = utils.get_api_client_configuration(token=token) self.api_client = LightlySwaggerApiClient(configuration=configuration) self.api_client.user_agent = f"Lightly/{__version__} ({platform.system()}/{platform.release()}; {platform.platform()}; {platform.processor()};) python/{platform.python_version()}" @@ -210,7 +202,7 @@ def _order_list_by_filenames( """ filenames_on_server = self.get_filenames() - list_ordered = sort_items_by_keys( + list_ordered = reordering.sort_items_by_keys( filenames_for_list, list_to_order, filenames_on_server ) return list_ordered @@ -259,7 +251,7 @@ def upload_file_with_signed_url( lightly_s3_sse_kms_key = os.environ.get(LIGHTLY_S3_SSE_KMS_KEY, "").strip() # Only set s3 related headers when we are talking with s3 if ( - get_signed_url_destination(signed_write_url) == DatasourceType.S3 + utils.get_signed_url_destination(signed_write_url) == DatasourceType.S3 and lightly_s3_sse_kms_key ): if headers is None: diff --git a/lightly/api/api_workflow_compute_worker.py b/lightly/api/api_workflow_compute_worker.py index a0a8c32f3..06c37c392 100644 --- a/lightly/api/api_workflow_compute_worker.py +++ b/lightly/api/api_workflow_compute_worker.py @@ -1,12 +1,12 @@ import copy import dataclasses import difflib +import json import time from functools import partial from typing import Any, Callable, Dict, Iterator, List, Optional, Type, TypeVar, Union from lightly.api import utils -from lightly.api.utils import retry from lightly.openapi_generated.swagger_client.api_client import ApiClient from lightly.openapi_generated.swagger_client.models import ( CreateDockerWorkerRegistryEntryRequest, @@ -244,8 +244,24 @@ def create_compute_worker_config( request = DockerWorkerConfigV3CreateRequest( config=config, creator=self._creator ) - response = self._compute_worker_api.create_docker_worker_config_v3(request) - return response.id + try: + response = self._compute_worker_api.create_docker_worker_config_v3(request) + return response.id + except ApiException as e: + if e.body is None: + raise e + eb = json.loads(e.body) + eb_code = eb.get("code") + eb_error = eb.get("error") + if str(e.status)[0] == "4" and eb_code is not None and eb_error is not None: + raise ValueError( + f"Trying to schedule your job resulted in\n" + f">> {eb_code}\n>> {json.dumps(eb_error, indent=4)}\n" + f">> Please fix the issue mentioned above and see our docs " + f"https://docs.lightly.ai/docs/all-configuration-options for more help." + ) from e + else: + raise e def schedule_compute_worker_run( self, @@ -478,7 +494,7 @@ def _get_scheduled_run_by_id(self, scheduled_run_id: str) -> DockerRunScheduledD try: run: DockerRunScheduledData = next( run - for run in retry( + for run in utils.retry( lambda: self._compute_worker_api.get_docker_runs_scheduled_by_dataset_id( self.dataset_id ) @@ -618,12 +634,13 @@ def get_compute_worker_run_tags(self, run_id: str) -> List[TagData]: def selection_config_from_dict(cfg: Dict[str, Any]) -> SelectionConfig: """Recursively converts selection config from dict to a SelectionConfig instance.""" - new_cfg = copy.deepcopy(cfg) strategies = [] - for entry in new_cfg.get("strategies", []): - entry["input"] = SelectionConfigEntryInput(**entry["input"]) - entry["strategy"] = SelectionConfigEntryStrategy(**entry["strategy"]) - strategies.append(SelectionConfigEntry(**entry)) + for entry in cfg.get("strategies", []): + new_entry = copy.deepcopy(entry) + new_entry["input"] = SelectionConfigEntryInput(**entry["input"]) + new_entry["strategy"] = SelectionConfigEntryStrategy(**entry["strategy"]) + strategies.append(SelectionConfigEntry(**new_entry)) + new_cfg = copy.deepcopy(cfg) new_cfg["strategies"] = strategies return SelectionConfig(**new_cfg) diff --git a/lightly/api/api_workflow_datasets.py b/lightly/api/api_workflow_datasets.py index 76190e73e..29d33d63d 100644 --- a/lightly/api/api_workflow_datasets.py +++ b/lightly/api/api_workflow_datasets.py @@ -1,6 +1,6 @@ import warnings from itertools import chain -from typing import Iterator, List, Optional +from typing import Iterator, List, Optional, Set from lightly.api import utils from lightly.openapi_generated.swagger_client.models import ( @@ -52,13 +52,20 @@ def dataset_name_exists( ) -> bool: """Checks if a dataset with the given name exists. + There can be multiple datasets with the same name accessible to the current + user. This can happen if either: + * A dataset has been explicitly shared with the user + * The user has access to team datasets + The `shared` flag controls whether these datasets are checked. + Args: dataset_name: Name of the dataset. shared: - If False, considers only datasets owned by the user. - If True, considers only datasets which have been shared with the user. - If None, considers all datasets the users has access to. Defaults to False. + * If False (default), checks only datasets owned by the user. + * If True, checks datasets which have been shared with the user, + including team datasets. Excludes user's own datasets. + * If None, checks all datasets the users has access to. Returns: A boolean value indicating whether any dataset with the given name exists. @@ -101,18 +108,25 @@ def get_datasets_by_name( dataset_name: str, shared: Optional[bool] = False, ) -> List[DatasetData]: - """Fetches a dataset by name. + """Fetches datasets by name. + + There can be multiple datasets with the same name accessible to the current + user. This can happen if either: + * A dataset has been explicitly shared with the user + * The user has access to team datasets + The `shared` flag controls whether these datasets are returned. Args: dataset_name: Name of the target dataset. shared: - If False, returns only datasets owned by the user. In this case at most - one dataset will be returned. - If True, returns only datasets which have been shared with the user. Can - return multiple datasets. - If None, returns datasets the users has access to. Can return multiple - datasets. Defaults to False. + * If False (default), returns only datasets owned by the user. In this + case at most one dataset will be returned. + * If True, returns datasets which have been shared with the user, + including team datasets. Excludes user's own datasets. Can return + multiple datasets. + * If None, returns all datasets the users has access to. Can return + multiple datasets. Returns: A list of datasets that match the name. If no datasets with the name exist, @@ -134,41 +148,76 @@ def get_datasets_by_name( >>> client.get_datasets_by_name(dataset_name="random-name") [] """ - datasets = [] + datasets: List[DatasetData] = [] if not shared or shared is None: datasets.extend( - self._datasets_api.get_datasets_query_by_name( - dataset_name=dataset_name, - exact=True, - shared=False, + list( + utils.paginate_endpoint( + self._datasets_api.get_datasets_query_by_name, + dataset_name=dataset_name, + exact=True, + shared=False, + ) ) ) if shared or shared is None: datasets.extend( - self._datasets_api.get_datasets_query_by_name( - dataset_name=dataset_name, - exact=True, - shared=True, + list( + utils.paginate_endpoint( + self._datasets_api.get_datasets_query_by_name, + dataset_name=dataset_name, + exact=True, + shared=True, + ) + ) + ) + datasets.extend( + list( + utils.paginate_endpoint( + self._datasets_api.get_datasets_query_by_name, + dataset_name=dataset_name, + exact=True, + get_assets_of_team=True, + ) ) ) - return datasets + + # De-duplicate datasets because results from shared=True and + # those from get_assets_of_team=True might overlap + dataset_ids: Set[str] = set() + filtered_datasets: List[DatasetData] = [] + for dataset in datasets: + if dataset.id not in dataset_ids: + dataset_ids.add(dataset.id) + filtered_datasets.append(dataset) + + return filtered_datasets def get_datasets_iter( self, shared: Optional[bool] = False ) -> Iterator[DatasetData]: """Returns an iterator over all datasets owned by the current user. + There can be multiple datasets with the same name accessible to the current + user. This can happen if either: + * A dataset has been explicitly shared with the user + * The user has access to team datasets + The `shared` flag controls whether these datasets are returned. + Args: shared: - If False, returns only datasets owned by the user. - If True, returns only the datasets which have been shared with the user. - If None, returns all datasets the user has access to (owned and shared). - Defaults to False. + * If False (default), returns only datasets owned by the user. In this + case at most one dataset will be returned. + * If True, returns datasets which have been shared with the user, + including team datasets. Excludes user's own datasets. Can return + multiple datasets. + * If None, returns all datasets the users has access to. Can return + multiple datasets. Returns: An iterator over datasets owned by the current user. """ - dataset_iterable = [] + dataset_iterable: Iterator[DatasetData] = (_ for _ in ()) if not shared or shared is None: dataset_iterable = utils.paginate_endpoint( self._datasets_api.get_datasets, @@ -182,17 +231,40 @@ def get_datasets_iter( shared=True, ), ) - return dataset_iterable + dataset_iterable = chain( + dataset_iterable, + utils.paginate_endpoint( + self._datasets_api.get_datasets, + get_assets_of_team=True, + ), + ) + + # De-duplicate datasets because results from shared=True and + # those from get_assets_of_team=True might overlap + dataset_ids: Set[str] = set() + for dataset in dataset_iterable: + if dataset.id not in dataset_ids: + dataset_ids.add(dataset.id) + yield dataset def get_datasets(self, shared: Optional[bool] = False) -> List[DatasetData]: """Returns all datasets owned by the current user. + There can be multiple datasets with the same name accessible to the current + user. This can happen if either: + * A dataset has been explicitly shared with the user + * The user has access to team datasets + The `shared` flag controls whether these datasets are returned. + Args: shared: - If False, returns only datasets owned by the user. - If True, returns only the datasets which have been shared with the user. - If None, returns all datasets the user has access to (owned and shared). - Defaults to False. + * If False (default), returns only datasets owned by the user. In this + case at most one dataset will be returned. + * If True, returns datasets which have been shared with the user, + including team datasets. Excludes user's own datasets. Can return + multiple datasets. + * If None, returns all datasets the users has access to. Can return + multiple datasets. Returns: A list of datasets owned by the current user. @@ -230,14 +302,24 @@ def set_dataset_id_by_name( ) -> None: """Sets the dataset ID in the API client given the name of the desired dataset. + There can be multiple datasets with the same name accessible to the current + user. This can happen if either: + * A dataset has been explicitly shared with the user + * The user has access to team datasets + The `shared` flag controls whether these datasets are also checked. If multiple + datasets with the given name are found, the API client uses the ID of the first + dataset and prints a warning message. + Args: dataset_name: The name of the target dataset. shared: - If False, considers only datasets owned by the user. - If True, considers only the datasets which have been shared with the user. - If None, consider all datasets the user has access to (owned and shared). - Defaults to False. + * If False (default), checks only datasets owned by the user. + * If True, returns datasets which have been shared with the user, + including team datasets. Excludes user's own datasets. There can be + multiple candidate datasets. + * If None, returns all datasets the users has access to. There can be + multiple candidate datasets. Raises: ValueError: @@ -384,10 +466,13 @@ def create_new_dataset_with_unique_name( dataset_type=dataset_type, ) else: - existing_datasets = self._datasets_api.get_datasets_query_by_name( - dataset_name=dataset_basename, - exact=False, - shared=False, + existing_datasets = list( + utils.paginate_endpoint( + self._datasets_api.get_datasets_query_by_name, + dataset_name=dataset_basename, + exact=False, + shared=False, + ) ) existing_dataset_names = {dataset.name for dataset in existing_datasets} counter = 1 diff --git a/lightly/api/api_workflow_datasources.py b/lightly/api/api_workflow_datasources.py index 39cb836ce..90566002c 100644 --- a/lightly/api/api_workflow_datasources.py +++ b/lightly/api/api_workflow_datasources.py @@ -765,26 +765,28 @@ def get_custom_embedding_read_url( def list_datasource_permissions( self, - ) -> Dict[str, Union[bool, Optional[DatasourceConfigVerifyDataErrors]]]: + ) -> Dict[str, Union[bool, Dict[str, str]]]: """Lists granted access permissions for the datasource set up with a dataset. Returns a string dictionary, with each permission mapped to a boolean value, - see the example below. Additionally, there is the ``errors`` key. Permission - errors are stored in a dictionary where permission names are keys and error - messages are values. If there is no error, the value is ``None``. + see the example below. An additional ``errors`` key is present if any permission + errors have been encountered. Permission errors are stored in a dictionary where + permission names are keys and error messages are values. >>> from lightly.api import ApiWorkflowClient >>> client = ApiWorkflowClient( ... token="MY_LIGHTLY_TOKEN", dataset_id="MY_DATASET_ID" ... ) >>> client.list_datasource_permissions() - {'can_list': True, - 'can_overwrite': True, - 'can_read': True, - 'can_write': True, - 'errors': None} + { + 'can_read': True, + 'can_write': True, + 'can_list': False, + 'can_overwrite': True, + 'errors': {'can_list': 'error message'} + } """ return self._datasources_api.verify_datasource_by_dataset_id( dataset_id=self.dataset_id, - ) + ).to_dict() diff --git a/lightly/api/api_workflow_download_dataset.py b/lightly/api/api_workflow_download_dataset.py index a3a22cfee..9cac18882 100644 --- a/lightly/api/api_workflow_download_dataset.py +++ b/lightly/api/api_workflow_download_dataset.py @@ -1,16 +1,16 @@ import io import os +import urllib.request import warnings from concurrent.futures.thread import ThreadPoolExecutor from typing import Dict, List, Optional -from urllib.request import Request, urlopen +from urllib.request import Request import tqdm from PIL import Image -from lightly.api import download +from lightly.api import download, utils from lightly.api.bitmask import BitMask -from lightly.api.utils import paginate_endpoint from lightly.openapi_generated.swagger_client.models import ( DatasetEmbeddingData, ImageType, @@ -33,7 +33,7 @@ def _make_dir_and_save_image(output_dir: str, filename: str, img: Image): def _get_image_from_read_url(read_url: str): """Makes a get request to the signed read url and returns the image.""" request = Request(read_url, method="GET") - with urlopen(request) as response: + with urllib.request.urlopen(request) as response: blob = response.read() img = Image.open(io.BytesIO(blob)) return img @@ -293,7 +293,7 @@ def export_label_studio_tasks_by_tag_id( [{'id': 0, 'data': {'image': '...', ...}}] """ label_studio_tasks = list( - paginate_endpoint( + utils.paginate_endpoint( self._tags_api.export_tag_to_label_studio_tasks, page_size=20000, dataset_id=self.dataset_id, diff --git a/lightly/api/api_workflow_export.py b/lightly/api/api_workflow_export.py index 616132f47..7e9bb3ba4 100644 --- a/lightly/api/api_workflow_export.py +++ b/lightly/api/api_workflow_export.py @@ -1,7 +1,7 @@ import warnings from typing import Dict, List -from lightly.api.utils import paginate_endpoint, retry +from lightly.api import utils from lightly.openapi_generated.swagger_client.models import ( FileNameFormat, LabelBoxDataRow, @@ -39,7 +39,7 @@ def export_label_studio_tasks_by_tag_id( [{'id': 0, 'data': {'image': '...', ...}}] """ label_studio_tasks: List[LabelStudioTask] = list( - paginate_endpoint( + utils.paginate_endpoint( self._tags_api.export_tag_to_label_studio_tasks, page_size=20000, dataset_id=self.dataset_id, @@ -114,7 +114,7 @@ def export_label_box_data_rows_by_tag_id( ) ) label_box_data_rows: List[LabelBoxDataRow] = list( - paginate_endpoint( + utils.paginate_endpoint( self._tags_api.export_tag_to_label_box_data_rows, page_size=20000, dataset_id=self.dataset_id, @@ -187,7 +187,7 @@ def export_label_box_v4_data_rows_by_tag_id( [{'row_data': '...', 'global_key': 'image-1.jpg', 'media_type': 'IMAGE'} """ label_box_data_rows: List[LabelBoxV4DataRow] = list( - paginate_endpoint( + utils.paginate_endpoint( self._tags_api.export_tag_to_label_box_v4_data_rows, page_size=20000, dataset_id=self.dataset_id, @@ -248,10 +248,12 @@ def export_filenames_by_tag_id( >>> client.export_filenames_by_tag_id("646b40d6c06aae1b91294a9e") 'image-1.jpg\nimage-2.jpg\nimage-3.jpg' """ - filenames = retry( - self._tags_api.export_tag_to_basic_filenames, - dataset_id=self.dataset_id, - tag_id=tag_id, + filenames = "\n".join( + utils.paginate_endpoint( + self._tags_api.export_tag_to_basic_filenames, + dataset_id=self.dataset_id, + tag_id=tag_id, + ) ) return filenames @@ -314,23 +316,29 @@ def export_filenames_and_read_urls_by_tag_id( ] """ - filenames_string = retry( - self._tags_api.export_tag_to_basic_filenames, - dataset_id=self.dataset_id, - tag_id=tag_id, - file_name_format=FileNameFormat.NAME, + filenames_string = "\n".join( + utils.paginate_endpoint( + self._tags_api.export_tag_to_basic_filenames, + dataset_id=self.dataset_id, + tag_id=tag_id, + file_name_format=FileNameFormat.NAME, + ) ) - read_urls_string = retry( - self._tags_api.export_tag_to_basic_filenames, - dataset_id=self.dataset_id, - tag_id=tag_id, - file_name_format=FileNameFormat.REDIRECTED_READ_URL, + read_urls_string = "\n".join( + utils.paginate_endpoint( + self._tags_api.export_tag_to_basic_filenames, + dataset_id=self.dataset_id, + tag_id=tag_id, + file_name_format=FileNameFormat.REDIRECTED_READ_URL, + ) ) - datasource_urls_string = retry( - self._tags_api.export_tag_to_basic_filenames, - dataset_id=self.dataset_id, - tag_id=tag_id, - file_name_format=FileNameFormat.DATASOURCE_FULL, + datasource_urls_string = "\n".join( + utils.paginate_endpoint( + self._tags_api.export_tag_to_basic_filenames, + dataset_id=self.dataset_id, + tag_id=tag_id, + file_name_format=FileNameFormat.DATASOURCE_FULL, + ) ) # The endpoint exportTagToBasicFilenames returns a plain string so we # have to split it by newlines in order to get the individual entries. diff --git a/lightly/api/api_workflow_predictions.py b/lightly/api/api_workflow_predictions.py index 30f8edc13..3ea7b63af 100644 --- a/lightly/api/api_workflow_predictions.py +++ b/lightly/api/api_workflow_predictions.py @@ -1,7 +1,4 @@ -from concurrent.futures import ThreadPoolExecutor -from typing import Mapping, Optional, Sequence, Tuple - -import tqdm +from typing import Sequence from lightly.openapi_generated.swagger_client.models import ( PredictionSingleton, @@ -55,82 +52,6 @@ def create_or_update_prediction_task_schema( prediction_uuid_timestamp=prediction_version_id, ) - def create_or_update_predictions( - self, - sample_id_to_prediction_singletons: Mapping[str, Sequence[PredictionSingleton]], - prediction_version_id: int = -1, - progress_bar: Optional[tqdm.tqdm] = None, - max_workers: int = 8, - ) -> None: - """Creates or updates the predictions for specific samples. - - Args: - sample_id_to_prediction_singletons - A mapping from the sample_id of the sample to its corresponding prediction singletons. - The singletons can be from different tasks and different types. - - prediction_version_id: - A numerical ID (e.g., timestamp) to distinguish different predictions of different model versions. - Use the same id if you don't require versioning or if you wish to overwrite the previous schema. - This ID must match the ID of a prediction task schema. - - progress_bar: - Tqdm progress bar to show how many prediction files have already been uploaded. - - max_workers: - Maximum number of workers uploading predictions in parallel. - - Example: - >>> import time - >>> from tqdm import tqdm - >>> from lightly.api import ApiWorkflowClient - >>> from lightly.openapi_generated.swagger_client.models import ( - >>> PredictionTaskSchema, - >>> TaskType, - >>> PredictionTaskSchemaCategory, - >>> ) - >>> from lightly.api.prediction_singletons import PredictionSingletonClassificationRepr - >>> - >>> client = ApiWorkflowClient( - >>> token="MY_LIGHTLY_TOKEN", dataset_id="MY_DATASET_ID" - >>> ) - >>> - >>> samples = client._samples_api.get_samples_partial_by_dataset_id(dataset_id=client.dataset_id, mode=SamplePartialMode.FILENAMES) - >>> sample_id_to_prediction_singletons_dummy = { - >>> sample.id: [PredictionSingletonClassificationRepr(taskName="my-task", categoryId=i%4, score=0.9, probabilities=[0.1, 0.2, 0.3, 0.4])] - >>> for i, sample in enumerate(samples) - >>> } - >>> client.create_or_update_predictions( - >>> sample_id_to_prediction_singletons=sample_id_to_prediction_singletons_dummy, - >>> progress_bar=tqdm(desc="Uploading predictions", total=len(samples), unit=" predictions") - >>> ) - - - """ - - # handle the case where len(sample_id_to_prediction_singletons) < max_workers - max_workers = min(len(sample_id_to_prediction_singletons), max_workers) - max_workers = max(max_workers, 1) - - def upload_prediction( - sample_id_prediction_singletons_tuple: Tuple[ - str, Sequence[PredictionSingleton] - ] - ) -> None: - (sample_id, prediction_singletons) = sample_id_prediction_singletons_tuple - self.create_or_update_prediction( - sample_id=sample_id, - prediction_singletons=prediction_singletons, - prediction_version_id=prediction_version_id, - ) - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - for _ in executor.map( - upload_prediction, sample_id_to_prediction_singletons.items() - ): - if progress_bar is not None: - progress_bar.update(1) - def create_or_update_prediction( self, sample_id: str, diff --git a/lightly/api/api_workflow_upload_dataset.py b/lightly/api/api_workflow_upload_dataset.py deleted file mode 100644 index a13289bfd..000000000 --- a/lightly/api/api_workflow_upload_dataset.py +++ /dev/null @@ -1,336 +0,0 @@ -""" Upload Dataset Mixin """ - - -import os -import warnings -from concurrent.futures.thread import ThreadPoolExecutor -from datetime import datetime -from typing import Any, Dict, Optional, Union - -import tqdm -from lightly_utils import image_processing - -from lightly.api.utils import ( - MAXIMUM_FILENAME_LENGTH, - build_azure_signed_url_write_headers, - check_filename, - retry, -) -from lightly.openapi_generated.swagger_client.models import ( - DatasourceConfigBase, - InitialTagCreateRequest, - JobStatusMeta, - JobStatusUploadMethod, - SampleCreateRequest, - SamplePartialMode, - SampleWriteUrls, - TagUpsizeRequest, -) -from lightly.openapi_generated.swagger_client.rest import ApiException -from lightly.utils.hipify import bcolors - -try: - from lightly.data import LightlyDataset - - _lightly_dataset_available = True -except ( - RuntimeError, # Different CUDA versions for torch and torchvision - OSError, # Different CUDA versions for torch and torchvision (old) - ImportError, # No installation of torch or torchvision -): - _lightly_dataset_available = False - - -class _UploadDatasetMixin: - """Mixin to upload datasets to the Lightly Api.""" - - def upload_dataset( - self, - input: Union[str, "LightlyDataset"], - max_workers: int = 8, - mode: str = "thumbnails", - custom_metadata: Optional[Dict[str, Any]] = None, - ) -> None: - """Uploads a dataset to the Lightly Platform. - - Args: - input: - Either the path to the dataset, e.g. "path/to/dataset", - or the dataset in form of a LightlyDataset. - max_workers: - Maximum number of workers uploading images in parallel. - mode: - One of [full, thumbnails, metadata]. Whether to upload - thumbnails, full images, or metadata only. - custom_metadata: - COCO-style dictionary of custom metadata to be uploaded. Optional. - - Raises: - ValueError: - If dataset is too large or input has the wrong type. - RuntimeError: - If the connection to the server failed. - - """ - - # get all tags of the dataset - tags = self.get_all_tags() - if len(tags) > 0: - print( - f"Dataset with id {self.dataset_id} has {bcolors.OKGREEN}{len(tags)}{bcolors.ENDC} tags.", - flush=True, - ) - - # parse "input" variable - if isinstance(input, str): - if _lightly_dataset_available: - dataset = LightlyDataset(input_dir=input) - else: - raise RuntimeError( - "Can't create LightlyDataset! Requires torch and torchvision." - ) - else: - dataset = input - - # handle the case where len(dataset) < max_workers - max_workers = min(len(dataset), max_workers) - max_workers = max(max_workers, 1) - - # upload the samples - print( - f"Uploading {bcolors.OKGREEN}{len(dataset)}{bcolors.ENDC} images (with {bcolors.OKGREEN}{max_workers}{bcolors.ENDC} workers).", - flush=True, - ) - - # TODO: remove _size_in_bytes from image_processing - image_processing.metadata._size_in_bytes = ( - lambda img: 0 - ) # pylint: disable=protected-access - - # get the filenames of the samples already on the server - samples = retry( - self._samples_api.get_samples_partial_by_dataset_id, - dataset_id=self.dataset_id, - mode=SamplePartialMode.FILENAMES, - ) - filenames_on_server = [sample.file_name for sample in samples] - filenames_on_server_set = set(filenames_on_server) - if len(filenames_on_server) > 0: - print( - f"Found {bcolors.OKGREEN}{len(filenames_on_server)}{bcolors.ENDC} images already on the server" - ", they are skipped during the upload." - ) - - # check the maximum allowed dataset size - total_filenames = set(dataset.get_filenames()).union(filenames_on_server_set) - max_dataset_size = int(self._quota_api.get_quota_maximum_dataset_size()) - if len(total_filenames) > max_dataset_size: - msg = f"Your dataset has {bcolors.OKGREEN}{len(dataset)}{bcolors.ENDC} samples which" - msg += f" is more than the allowed maximum of {bcolors.OKGREEN}{max_dataset_size}{bcolors.ENDC}" - raise ValueError(msg) - - # index custom metadata by filename (only if it exists) - filename_to_metadata = {} - if custom_metadata is not None: - self.verify_custom_metadata_format(custom_metadata) - filename_to_metadata = self.index_custom_metadata_by_filename( - custom_metadata, - ) - - # get the datasource - try: - datasource_config: DatasourceConfigBase = self.get_datasource() - datasource_type = datasource_config["type"] - except ApiException: - datasource_type = "LIGHTLY" # default to lightly datasource - - # register dataset upload - job_status_meta = JobStatusMeta( - total=len(total_filenames), - processed=len(filenames_on_server), - is_registered=True, - upload_method=JobStatusUploadMethod.USER_PIP, - ) - self._datasets_api.register_dataset_upload_by_id( - job_status_meta, self.dataset_id - ) - - pbar = tqdm.tqdm( - unit="imgs", - total=len(total_filenames) - len(filenames_on_server), - ) - tqdm_lock = tqdm.tqdm.get_lock() - - # define lambda function for concurrent upload - def lambda_(i): - # load image - image, _, filename = dataset[i] - if filename in filenames_on_server_set: - # the sample was already uploaded - return True - - filepath = dataset.get_filepath_from_filename(filename, image) - - # get custom metadata (evaluates to None if there is none) - custom_metadata_item = filename_to_metadata.get(filename, None) - - # try to upload image - try: - self._upload_single_image( - image=image, - filename=filename, - filepath=filepath, - mode=mode, - custom_metadata=custom_metadata_item, - datasource_type=datasource_type, - ) - success = True - except Exception as e: # pylint: disable=broad-except - warnings.warn(f"Upload of image {filename} failed with error {e}") - success = False - - # update the progress bar - tqdm_lock.acquire() - pbar.update(1) - tqdm_lock.release() - # return whether the upload was successful - return success - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - results = list( - executor.map(lambda_, [i for i in range(len(dataset))], chunksize=1) - ) - - if not all(results): - msg = "Warning: Unsuccessful upload(s)! " - msg += "This could cause problems when uploading embeddings." - msg += "Failed at image: {}".format(results.index(False)) - warnings.warn(msg) - - # set image type of data and create initial tag - if mode == "full": - img_type = "full" - elif mode == "thumbnails": - img_type = "thumbnail" - else: - img_type = "meta" - - if len(tags) == 0: - # create initial tag - initial_tag_create_request = InitialTagCreateRequest( - img_type=img_type, - creator=self._creator, - ) - self._tags_api.create_initial_tag_by_dataset_id( - initial_tag_create_request=initial_tag_create_request, - dataset_id=self.dataset_id, - ) - else: - # upsize existing tags - upsize_tags_request = TagUpsizeRequest( - upsize_tag_name=datetime.now().strftime("%Y%m%d_%Hh%Mm%Ss"), - upsize_tag_creator=self._creator, - ) - self._tags_api.upsize_tags_by_dataset_id( - tag_upsize_request=upsize_tags_request, - dataset_id=self.dataset_id, - ) - - def _upload_single_image( - self, - image, - filename: str, - filepath: str, - mode: str, - custom_metadata: Union[Dict, None] = None, - datasource_type: str = "LIGHTLY", - ) -> None: - """Uploads a single image to the Lightly platform.""" - # check whether the filepath is too long - if not check_filename(filepath): - msg = ( - "Filepath {filepath} is longer than the allowed maximum of " - f"{MAXIMUM_FILENAME_LENGTH} characters and will be skipped." - ) - raise ValueError(msg) - - # calculate metadata, and check if corrupted - metadata = image_processing.Metadata(image).to_dict() - metadata["sizeInBytes"] = os.path.getsize(filepath) - - # try to get exif data - try: - exifdata = image_processing.Exifdata(image) - except Exception: # pylint disable=broad-except - exifdata = None - - # generate thumbnail if necessary - thumbname = None - if not metadata["is_corrupted"] and mode in ["thumbnails", "full"]: - thumbname = ".".join(filename.split(".")[:-1]) + "_thumb.webp" - - body = SampleCreateRequest( - file_name=filename, - thumb_name=thumbname, - meta_data=metadata, - exif=exifdata if exifdata is None else exifdata.to_dict(), - custom_meta_data=custom_metadata, - ) - sample_id = retry( - self._samples_api.create_sample_by_dataset_id, - sample_create_request=body, - dataset_id=self.dataset_id, - ).id - - if not metadata["is_corrupted"] and mode in ["thumbnails", "full"]: - - def upload_thumbnail(image, signed_url): - thumbnail = image_processing.Thumbnail(image) - image_to_upload = thumbnail.to_bytes() - headers = None - if datasource_type == "AZURE": - # build headers for Azure blob storage - size_in_bytes = str(image_to_upload.getbuffer().nbytes) - headers = build_azure_signed_url_write_headers(size_in_bytes) - retry( - self.upload_file_with_signed_url, - image_to_upload, - signed_url, - headers=headers, - ) - thumbnail.thumbnail.close() - - def upload_full_image(filepath, signed_url): - with open(filepath, "rb") as image_to_upload: - headers = None - if datasource_type == "AZURE": - # build headers for Azure blob storage - image_to_upload.seek(0, 2) - size_in_bytes = str(image_to_upload.tell()) - image_to_upload.seek(0, 0) - headers = build_azure_signed_url_write_headers(size_in_bytes) - retry( - self.upload_file_with_signed_url, - image_to_upload, - signed_url, - headers=headers, - ) - - if mode == "thumbnails": - thumbnail_url = retry( - self._samples_api.get_sample_image_write_url_by_id, - dataset_id=self.dataset_id, - sample_id=sample_id, - is_thumbnail=True, - ) - upload_thumbnail(image, thumbnail_url) - elif mode == "full": - sample_write_urls: SampleWriteUrls = retry( - self._samples_api.get_sample_image_write_urls_by_id, - dataset_id=self.dataset_id, - sample_id=sample_id, - ) - upload_thumbnail(image, sample_write_urls.thumb) - upload_full_image(filepath, sample_write_urls.full) - - image.close() diff --git a/lightly/api/api_workflow_upload_embeddings.py b/lightly/api/api_workflow_upload_embeddings.py index 3ae89df76..a8b21a169 100644 --- a/lightly/api/api_workflow_upload_embeddings.py +++ b/lightly/api/api_workflow_upload_embeddings.py @@ -1,9 +1,10 @@ import csv import io import tempfile +import urllib.request from datetime import datetime from typing import List -from urllib.request import Request, urlopen +from urllib.request import Request from lightly.api.utils import retry from lightly.openapi_generated.swagger_client.models import ( @@ -12,7 +13,7 @@ Trigger2dEmbeddingJobRequest, WriteCSVUrlData, ) -from lightly.utils.io import check_embeddings, check_filenames +from lightly.utils import io as io_utils class EmbeddingDoesNotExistError(ValueError): @@ -23,7 +24,7 @@ class _UploadEmbeddingsMixin: def _get_csv_reader_from_read_url(self, read_url: str) -> None: """Makes a get request to the signed read url and returns the .csv file.""" request = Request(read_url, method="GET") - with urlopen(request) as response: + with urllib.request.urlopen(request) as response: buffer = io.StringIO(response.read().decode("utf-8")) reader = csv.reader(buffer) @@ -104,7 +105,9 @@ def upload_embeddings(self, path_to_embeddings_csv: str, name: str) -> None: the upload is aborted. """ - check_embeddings(path_to_embeddings_csv, remove_additional_columns=True) + io_utils.check_embeddings( + path_to_embeddings_csv, remove_additional_columns=True + ) # Try to append the embeddings on the server, if they exist try: @@ -251,7 +254,7 @@ def _order_csv_by_filenames(self, path_to_embeddings_csv: str) -> List[str]: f"The filenames in the embedding file and " f"the filenames on the server do not align" ) - check_filenames(filenames) + io_utils.check_filenames(filenames) rows_without_header_ordered = self._order_list_by_filenames( filenames, rows_without_header diff --git a/lightly/api/api_workflow_upload_metadata.py b/lightly/api/api_workflow_upload_metadata.py index 48ef01592..e5859b089 100644 --- a/lightly/api/api_workflow_upload_metadata.py +++ b/lightly/api/api_workflow_upload_metadata.py @@ -4,14 +4,15 @@ from requests import Response from tqdm import tqdm -from lightly.api.utils import retry +from lightly.api.utils import paginate_endpoint, retry from lightly.openapi_generated.swagger_client.models import ( ConfigurationEntry, ConfigurationSetRequest, + SampleDataModes, SamplePartialMode, SampleUpdateRequest, ) -from lightly.utils.hipify import print_as_warning +from lightly.utils import hipify from lightly.utils.io import COCO_ANNOTATION_KEYS @@ -155,10 +156,13 @@ def upload_custom_metadata( for image_info in custom_metadata[COCO_ANNOTATION_KEYS.images] } - samples = retry( - self._samples_api.get_samples_partial_by_dataset_id, - dataset_id=self.dataset_id, - mode=SamplePartialMode.FILENAMES, + samples: List[SampleDataModes] = list( + paginate_endpoint( + self._samples_api.get_samples_partial_by_dataset_id, + page_size=25000, # as this information is rather small, we can request a lot of samples at once + dataset_id=self.dataset_id, + mode=SamplePartialMode.FILENAMES, + ) ) filename_to_sample_id = {sample.file_name: sample.id for sample in samples} @@ -168,7 +172,7 @@ def upload_custom_metadata( image_id = metadata[COCO_ANNOTATION_KEYS.custom_metadata_image_id] filename = image_id_to_filename.get(image_id, None) if filename is None: - print_as_warning( + hipify.print_as_warning( "No image found for custom metadata annotation " f"with image_id {image_id}. " "This custom metadata annotation is skipped. ", @@ -177,7 +181,7 @@ def upload_custom_metadata( continue sample_id = filename_to_sample_id.get(filename, None) if sample_id is None: - print_as_warning( + hipify.print_as_warning( "You tried to upload custom metadata for a sample with " f"filename {{{filename}}}, " "but a sample with this filename " diff --git a/lightly/api/bitmask.py b/lightly/api/bitmask.py index ddb935571..a79f9f69f 100644 --- a/lightly/api/bitmask.py +++ b/lightly/api/bitmask.py @@ -2,7 +2,7 @@ # Copyright (c) 2020. Lightly AG and its affiliates. # All Rights Reserved -from copy import deepcopy +import copy from typing import List @@ -175,7 +175,7 @@ def difference(self, other): self.x = self.x - other.x def __sub__(self, other): - ret = deepcopy(self) + ret = copy.deepcopy(self) ret.difference(other) return ret diff --git a/lightly/api/utils.py b/lightly/api/utils.py index 46780b85e..5960fa7f6 100644 --- a/lightly/api/utils.py +++ b/lightly/api/utils.py @@ -70,6 +70,7 @@ def retry(func, *args, **kwargs): class Paginated(Iterator): def __init__(self, fn, page_size, *args, **kwargs): self.entries: List = [] + self.last_chunk_size = page_size self.offset = 0 self.fn = fn self.page_size = page_size @@ -81,6 +82,9 @@ def __iter__(self): def __next__(self): if len(self.entries) == 0: + # stop iteration if the last chunk was smaller than the page size + if self.last_chunk_size < self.page_size: + raise StopIteration chunk = retry( self.fn, page_offset=self.offset * self.page_size, @@ -91,6 +95,11 @@ def __next__(self): if len(chunk) == 0: raise StopIteration self.offset += 1 + self.last_chunk_size = len(chunk) + # Handle the case where the chunk is a string. In this case we want + # to return the whole page as a single string instead of an interable + # of characters. + chunk = chunk if not isinstance(chunk, str) else [chunk] self.entries.extend(chunk) return self.entries.pop(0) diff --git a/lightly/api/version_checking.py b/lightly/api/version_checking.py index 048ac3b8d..175179a5f 100644 --- a/lightly/api/version_checking.py +++ b/lightly/api/version_checking.py @@ -5,7 +5,7 @@ from lightly.api import utils from lightly.api.swagger_api_client import LightlySwaggerApiClient from lightly.openapi_generated.swagger_client.api import VersioningApi -from lightly.utils.version_compare import version_compare +from lightly.utils import version_compare class LightlyAPITimeoutException(Exception): @@ -33,14 +33,14 @@ def is_latest_version(current_version: str) -> bool: latest_version: str = versioning_api.get_latest_pip_version( current_version=current_version ) - return version_compare(current_version, latest_version) >= 0 + return version_compare.version_compare(current_version, latest_version) >= 0 def is_compatible_version(current_version: str) -> bool: with TimeoutDecorator(1): versioning_api = get_versioning_api() minimum_version: str = versioning_api.get_minimum_compatible_pip_version() - return version_compare(current_version, minimum_version) >= 0 + return version_compare.version_compare(current_version, minimum_version) >= 0 def get_versioning_api() -> VersioningApi: diff --git a/lightly/data/collate.py b/lightly/data/collate.py index 25720171d..3a935951f 100644 --- a/lightly/data/collate.py +++ b/lightly/data/collate.py @@ -3,6 +3,8 @@ # Copyright (c) 2020. Lightly AG and its affiliates. # All Rights Reserved +import math +from multiprocessing import Value from typing import List, Optional, Tuple, Union from warnings import warn @@ -1345,6 +1347,176 @@ def forward( return (views_global, views_local, grids_global, grids_local), labels, fnames +class IJEPAMaskCollator: + """Collator for IJEPA model [0]. + + Experimental: Support for I-JEPA is experimental, there might be breaking changes + in the future. + + Code inspired by [1]. + + - [0]: Joint-Embedding Predictive Architecture, 2023, https://arxiv.org/abs/2301.08243 + - [1]: https://github.com/facebookresearch/ijepa + """ + + def __init__( + self, + input_size=(224, 224), + patch_size=16, + enc_mask_scale=(0.2, 0.8), + pred_mask_scale=(0.2, 0.8), + aspect_ratio=(0.3, 3.0), + nenc=1, + npred=2, + min_keep=4, + allow_overlap=False, + ): + if not isinstance(input_size, tuple): + input_size = (input_size,) * 2 + self.patch_size = patch_size + self.height, self.width = ( + input_size[0] // patch_size, + input_size[1] // patch_size, + ) + self.enc_mask_scale = enc_mask_scale + self.pred_mask_scale = pred_mask_scale + self.aspect_ratio = aspect_ratio + self.nenc = nenc + self.npred = npred + self.min_keep = min_keep # minimum number of patches to keep + self.allow_overlap = ( + allow_overlap # whether to allow overlap b/w enc and pred masks + ) + self._itr_counter = Value("i", -1) # collator is shared across worker processes + + def step(self): + i = self._itr_counter + with i.get_lock(): + i.value += 1 + v = i.value + return v + + def _sample_block_size(self, generator, scale, aspect_ratio_scale): + _rand = torch.rand(1, generator=generator).item() + # -- Sample block scale + min_s, max_s = scale + mask_scale = min_s + _rand * (max_s - min_s) + max_keep = int(self.height * self.width * mask_scale) + # -- Sample block aspect-ratio + min_ar, max_ar = aspect_ratio_scale + aspect_ratio = min_ar + _rand * (max_ar - min_ar) + # -- Compute block height and width (given scale and aspect-ratio) + h = int(round(math.sqrt(max_keep * aspect_ratio))) + w = int(round(math.sqrt(max_keep / aspect_ratio))) + while h >= self.height: + h -= 1 + while w >= self.width: + w -= 1 + + return (h, w) + + def _sample_block_mask(self, b_size, acceptable_regions=None): + h, w = b_size + + def constrain_mask(mask, tries=0): + """Helper to restrict given mask to a set of acceptable regions""" + N = max(int(len(acceptable_regions) - tries), 0) + for k in range(N): + mask *= acceptable_regions[k] + + # -- + # -- Loop to sample masks until we find a valid one + tries = 0 + timeout = og_timeout = 20 + valid_mask = False + while not valid_mask: + # -- Sample block top-left corner + top = torch.randint(0, self.height - h, (1,)) + left = torch.randint(0, self.width - w, (1,)) + mask = torch.zeros((self.height, self.width), dtype=torch.int32) + mask[top : top + h, left : left + w] = 1 + # -- Constrain mask to a set of acceptable regions + if acceptable_regions is not None: + constrain_mask(mask, tries) + mask = torch.nonzero(mask.flatten()) + # -- If mask too small try again + valid_mask = len(mask) > self.min_keep + if not valid_mask: + timeout -= 1 + if timeout == 0: + tries += 1 + timeout = og_timeout + mask = mask.squeeze() + # -- + mask_complement = torch.ones((self.height, self.width), dtype=torch.int32) + mask_complement[top : top + h, left : left + w] = 0 + # -- + return mask, mask_complement + + def __call__(self, batch): + """ + Create encoder and predictor masks when collating imgs into a batch + # 1. sample enc block (size + location) using seed + # 2. sample pred block (size) using seed + # 3. sample several enc block locations for each image (w/o seed) + # 4. sample several pred block locations for each image (w/o seed) + # 5. return enc mask and pred mask + """ + B = len(batch) + + collated_batch = torch.utils.data.default_collate(batch) + + seed = self.step() + g = torch.Generator() + g.manual_seed(seed) + p_size = self._sample_block_size( + generator=g, + scale=self.pred_mask_scale, + aspect_ratio_scale=self.aspect_ratio, + ) + e_size = self._sample_block_size( + generator=g, scale=self.enc_mask_scale, aspect_ratio_scale=(1.0, 1.0) + ) + + collated_masks_pred, collated_masks_enc = [], [] + min_keep_pred = self.height * self.width + min_keep_enc = self.height * self.width + for _ in range(B): + masks_p, masks_C = [], [] + for _ in range(self.npred): + mask, mask_C = self._sample_block_mask(p_size) + masks_p.append(mask) + masks_C.append(mask_C) + min_keep_pred = min(min_keep_pred, len(mask)) + collated_masks_pred.append(masks_p) + + acceptable_regions = masks_C + + if self.allow_overlap: + acceptable_regions = None + + masks_e = [] + for _ in range(self.nenc): + mask, _ = self._sample_block_mask( + e_size, acceptable_regions=acceptable_regions + ) + masks_e.append(mask) + min_keep_enc = min(min_keep_enc, len(mask)) + collated_masks_enc.append(masks_e) + + collated_masks_pred = [ + [cm[:min_keep_pred] for cm in cm_list] for cm_list in collated_masks_pred + ] + collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred) + # -- + collated_masks_enc = [ + [cm[:min_keep_enc] for cm in cm_list] for cm_list in collated_masks_enc + ] + collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc) + + return collated_batch, collated_masks_enc, collated_masks_pred + + def _deprecation_warning_collate_functions() -> None: warn( "Collate functions are deprecated and will be removed in favor of transforms in v1.4.0.\n" diff --git a/lightly/embedding/embedding.py b/lightly/embedding/embedding.py index d86e25a3a..5439fea3f 100644 --- a/lightly/embedding/embedding.py +++ b/lightly/embedding/embedding.py @@ -14,9 +14,6 @@ from lightly.embedding._base import BaseEmbedding from lightly.utils.reordering import sort_items_by_keys -if lightly._is_prefetch_generator_available(): - from prefetch_generator import BackgroundGenerator - class SelfSupervisedEmbedding(BaseEmbedding): """Implementation of self-supervised embedding models. @@ -105,8 +102,6 @@ def embed( embeddings, labels, filenames = None, None, [] dataset = dataloader.dataset - if lightly._is_prefetch_generator_available(): - dataloader = BackgroundGenerator(dataloader, max_prefetch=3) pbar = tqdm(total=len(dataset), unit="imgs") diff --git a/lightly/models/modules/ijepa.py b/lightly/models/modules/ijepa.py new file mode 100644 index 000000000..3eb14a247 --- /dev/null +++ b/lightly/models/modules/ijepa.py @@ -0,0 +1,496 @@ +import math +from functools import partial +from typing import Callable, List, Optional + +import numpy as np +import torch +import torch.nn as nn +from torchvision.models import vision_transformer +from torchvision.models.vision_transformer import ConvStemConfig + +from lightly.models import utils + + +class IJEPAPredictor(vision_transformer.Encoder): + """Predictor for the I-JEPA model [0]. + + Experimental: Support for I-JEPA is experimental, there might be breaking changes + in the future. + + Predict patch embeddings. Code inspired by [1]. + + - [0]: Joint-Embedding Predictive Architecture, 2023, https://arxiv.org/abs/2301.08243 + - [1]: https://github.com/facebookresearch/ijepa + + Attributes: + seq_length: + Token sequence length, including the class token. + num_layers: + Number of transformer blocks. + num_heads: + Number of attention heads. + hidden_dim: + Dimension of the input and output tokens. + predictor_embed_dim: + Dimension of inner predicted tokens + mlp_dim: + Dimension of the MLP in the transformer block. + dropout: + Percentage of elements set to zero after the MLP in the transformer. + attention_dropout: + Percentage of elements set to zero after the attention head. + + """ + + def __init__( + self, + seq_length: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + predictor_embed_dim: int, + num_patches: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + **kwargs + ): + super().__init__( + seq_length=seq_length, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + dropout=dropout, + attention_dropout=attention_dropout, + norm_layer=norm_layer, + ) + self.predictor_embed = nn.Linear(mlp_dim, predictor_embed_dim, bias=True) + self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) + self.predictor_proj = nn.Linear(predictor_embed_dim, mlp_dim, bias=True) + self.predictor_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, predictor_embed_dim), requires_grad=False + ) + predictor_pos_embed = _get_2d_sincos_pos_embed( + self.predictor_pos_embed.shape[-1], int(num_patches**0.5), cls_token=False + ) + self.predictor_pos_embed.data.copy_( + torch.from_numpy(predictor_pos_embed).float().unsqueeze(0) + ) + + @classmethod + def from_vit_encoder(cls, vit_encoder, num_patches): + """Creates a I-JEPA predictor backbone (mhas and layernorm) from a torchvision ViT encoder.""" + # Create a new instance with dummy values as they will be overwritten + # by the copied vit_encoder attributes + encoder = cls( + seq_length=1, + num_layers=1, + num_heads=1, + hidden_dim=1, + predictor_embed_dim=768, + mlp_dim=768, + num_patches=num_patches, + dropout=0, + attention_dropout=0, + ) + encoder.layers = vit_encoder.layers + encoder.ln = vit_encoder.ln + return encoder + + def forward(self, x, masks_x, masks): + assert (masks is not None) and ( + masks_x is not None + ), "Cannot run predictor without mask indices" + + if not isinstance(masks_x, list): + masks_x = [masks_x] + + if not isinstance(masks, list): + masks = [masks] + + B = len(x) // len(masks_x) + x = self.predictor_embed(x) + x_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1) + + x += utils.apply_masks(x_pos_embed, masks_x) + _, N_ctxt, _ = x.shape + + pos_embs = self.predictor_pos_embed.repeat(B, 1, 1) + pos_embs = utils.apply_masks(pos_embs, masks) + pos_embs = utils.repeat_interleave_batch(pos_embs, B, repeat=len(masks_x)) + pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1) + + pred_tokens += pos_embs + x = x.repeat(len(masks), 1, 1) + x = torch.cat([x, pred_tokens], dim=1) + + x = self.ln(self.layers(x)) + + x = x[:, N_ctxt:] + x = self.predictor_proj(x) + + return x + + +class IJEPAEncoder(vision_transformer.Encoder): + """Encoder for the I-JEPA model [0]. + + Experimental: Support for I-JEPA is experimental, there might be breaking changes + in the future. + + Encodes patch embeddings. Code inspired by [1]. + + - [0]: Joint-Embedding Predictive Architecture, 2023, https://arxiv.org/abs/2301.08243 + - [1]: https://github.com/facebookresearch/ijepa + + Attributes: + seq_length: + Token sequence length, including the class token. + num_layers: + Number of transformer blocks. + num_heads: + Number of attention heads. + hidden_dim: + Dimension of the input and output tokens. + mlp_dim: + Dimension of the MLP in the transformer block. + dropout: + Percentage of elements set to zero after the MLP in the transformer. + attention_dropout: + Percentage of elements set to zero after the attention head. + + """ + + def __init__( + self, + seq_length: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + ): + super().__init__( + seq_length=seq_length, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + dropout=dropout, + attention_dropout=attention_dropout, + norm_layer=norm_layer, + ) + + @classmethod + def from_vit_encoder(cls, vit_encoder: vision_transformer.Encoder): + """Creates a IJEPA encoder from a torchvision ViT encoder.""" + # Create a new instance with dummy values as they will be overwritten + # by the copied vit_encoder attributes + encoder = cls( + seq_length=1, + num_layers=1, + num_heads=1, + hidden_dim=1, + mlp_dim=1, + dropout=0, + attention_dropout=0, + ) + encoder.pos_embedding = vit_encoder.pos_embedding + encoder.dropout = vit_encoder.dropout + encoder.layers = vit_encoder.layers + encoder.ln = vit_encoder.ln + return encoder + + def forward( + self, input: torch.Tensor, idx_keep: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Encode input tokens. + + Args: + input: + Batch of token sequences. + idx_keep: + Tensor with shape (batch_size, num_tokens_to_keep) where each + entry is an index of the token to keep in the respective batch. + If specified, only the indexed tokens will be encoded. + + Returns: + Batch of encoded output tokens. + """ + input = input + self.interpolate_pos_encoding(input) + if idx_keep is not None: + input = utils.apply_masks(input, idx_keep) + return self.ln(self.layers(self.dropout(input))) + + def interpolate_pos_encoding(self, input: torch.Tensor): + """Returns the interpolated positional embedding for the given input. + + This function interpolates self.pos_embedding for all tokens in the input, + ignoring the class token. This allows encoding variable sized images. + + Args: + input: + Input tensor with shape (batch_size, num_sequences). + + """ + # code copied from: + # https://github.com/facebookresearch/msn/blob/4388dc1eadbe3042b85d3296d41b9b207656e043/src/deit.py#L291 + npatch = input.shape[1] - 1 + N = self.pos_embedding.shape[1] - 1 + if npatch == N: + return self.pos_embedding + class_emb = self.pos_embedding[:, 0] + pos_embedding = self.pos_embedding[:, 1:] + dim = input.shape[-1] + pos_embedding = nn.functional.interpolate( + pos_embedding.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute( + 0, 3, 1, 2 + ), + scale_factor=math.sqrt(npatch / N), + mode="bicubic", + ) + pos_embedding = pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_emb.unsqueeze(0), pos_embedding), dim=1) + + +class IJEPABackbone(vision_transformer.VisionTransformer): + """Encoder for the I-JEPA model [0]. + + Experimental: Support for I-JEPA is experimental, there might be breaking changes + in the future. + + Converts images into patches and encodes them. Code inspired by [1]. + Note that this implementation uses a learned positional embedding while [0] + uses a fixed positional embedding. + + - [0]: Joint-Embedding Predictive Architecture, 2023, https://arxiv.org/abs/2301.08243 + - [1]: https://github.com/facebookresearch/ijepa + + Attributes: + image_size: + Input image size. + patch_size: + Width and height of the image patches. image_size must be a multiple + of patch_size. + num_layers: + Number of transformer blocks. + num_heads: + Number of attention heads. + hidden_dim: + Dimension of the input and output tokens. + mlp_dim: + Dimension of the MLP in the transformer block. + dropout: + Percentage of elements set to zero after the MLP in the transformer. + attention_dropout: + Percentage of elements set to zero after the attention head. + num_classes: + Number of classes for the classification head. Currently not used. + representation_size: + If specified, an additional linear layer is added before the + classification head to change the token dimension from hidden_dim + to representation_size. Currently not used. + norm_layer: + Callable that creates a normalization layer. + + """ + + def __init__( + self, + image_size: int, + patch_size: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float = 0, + attention_dropout: float = 0, + num_classes: int = 1000, + representation_size: Optional[int] = None, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + conv_stem_configs: Optional[List[ConvStemConfig]] = None, + ): + super().__init__( + image_size=image_size, + patch_size=patch_size, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + dropout=dropout, + attention_dropout=attention_dropout, + num_classes=num_classes, + representation_size=representation_size, + norm_layer=norm_layer, + conv_stem_configs=conv_stem_configs, + ) + self.encoder = IJEPAEncoder( + seq_length=self.seq_length, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + dropout=dropout, + attention_dropout=attention_dropout, + norm_layer=norm_layer, + ) + + @classmethod + def from_vit(cls, vit: vision_transformer.VisionTransformer): + """Creates a IJEPABackbone from a torchvision ViT model.""" + # Create a new instance with dummy values as they will be overwritten + # by the copied vit_encoder attributes + backbone = cls( + image_size=vit.image_size, + patch_size=vit.patch_size, + num_layers=1, + num_heads=1, + hidden_dim=vit.hidden_dim, + mlp_dim=vit.mlp_dim, + dropout=vit.dropout, + attention_dropout=vit.attention_dropout, + num_classes=vit.num_classes, + representation_size=vit.representation_size, + norm_layer=vit.norm_layer, + ) + backbone.conv_proj = vit.conv_proj + backbone.class_token = vit.class_token + backbone.seq_length = vit.seq_length + backbone.heads = vit.heads + backbone.encoder = IJEPAEncoder.from_vit_encoder(vit.encoder) + return backbone + + def forward( + self, images: torch.Tensor, idx_keep: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Returns encoded class tokens from a batch of images. + + Args: + images: + Tensor with shape (batch_size, channels, image_size, image_size). + idx_keep: + Tensor with shape (batch_size, num_tokens_to_keep) where each + entry is an index of the token to keep in the respective batch. + If specified, only the indexed tokens will be passed to the + encoder. + + Returns: + Tensor with shape (batch_size, hidden_dim) containing the + encoded class token for every image. + + """ + if idx_keep is not None: + if not isinstance(idx_keep, list): + idx_keep = [idx_keep] + + out = self.encode(images, idx_keep) + return out + + def encode( + self, images: torch.Tensor, idx_keep: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Returns encoded class and patch tokens from images. + + Args: + images: + Tensor with shape (batch_size, channels, image_size, image_size). + idx_keep: + Tensor with shape (batch_size, num_tokens_to_keep) where each + entry is an index of the token to keep in the respective batch. + If specified, only the indexed tokens will be passed to the + encoder. + + Returns: + Tensor with shape (batch_size, sequence_length, hidden_dim) + containing the encoded class and patch tokens for every image. + + """ + out = self.images_to_tokens(images, prepend_class_token=True) + return self.encoder(out, idx_keep) + + def images_to_tokens( + self, images: torch.Tensor, prepend_class_token: bool + ) -> torch.Tensor: + """Converts images into patch tokens. + + Args: + images: + Tensor with shape (batch_size, channels, image_size, image_size). + + Returns: + Tensor with shape (batch_size, sequence_length - 1, hidden_dim) + containing the patch tokens. + """ + x = self.conv_proj(images) + tokens = x.flatten(2).transpose(1, 2) + if prepend_class_token: + tokens = utils.prepend_class_token(tokens, self.class_token) + return tokens + + +def _get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=float) + grid_w = np.arange(grid_size, dtype=float) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = _get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def _get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def _get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid length + return: + pos_embed: [grid_size, embed_dim] or [1+grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid = np.arange(grid_size, dtype=float) + pos_embed = _get_1d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def _get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb diff --git a/lightly/models/utils.py b/lightly/models/utils.py index 614132db4..a983686bc 100644 --- a/lightly/models/utils.py +++ b/lightly/models/utils.py @@ -566,3 +566,31 @@ def get_weight_decay_parameters( else: params.append(param) return params, params_no_weight_decay + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + return _no_grad_trunc_normal(tensor, mean, std, a, b) + + +def apply_masks(x, masks): + """ + :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] + :param masks: list of tensors containing indices of patches in [N] to keep + """ + all_x = [] + for m in masks: + mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) + all_x += [torch.gather(x, dim=1, index=mask_keep)] + return torch.cat(all_x, dim=0) + + +def repeat_interleave_batch(x, B, repeat): + N = len(x) // B + x = torch.cat( + [ + torch.cat([x[i * B : (i + 1) * B] for _ in range(repeat)], dim=0) + for i in range(N) + ], + dim=0, + ) + return x diff --git a/lightly/openapi_generated/swagger_client/__init__.py b/lightly/openapi_generated/swagger_client/__init__.py index 4077a1b74..a03a7a3b8 100644 --- a/lightly/openapi_generated/swagger_client/__init__.py +++ b/lightly/openapi_generated/swagger_client/__init__.py @@ -185,7 +185,15 @@ from lightly.openapi_generated.swagger_client.models.prediction_singleton_semantic_segmentation import PredictionSingletonSemanticSegmentation from lightly.openapi_generated.swagger_client.models.prediction_singleton_semantic_segmentation_all_of import PredictionSingletonSemanticSegmentationAllOf from lightly.openapi_generated.swagger_client.models.prediction_task_schema import PredictionTaskSchema +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_base import PredictionTaskSchemaBase from lightly.openapi_generated.swagger_client.models.prediction_task_schema_category import PredictionTaskSchemaCategory +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_category_keypoints import PredictionTaskSchemaCategoryKeypoints +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_category_keypoints_all_of import PredictionTaskSchemaCategoryKeypointsAllOf +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_keypoint import PredictionTaskSchemaKeypoint +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_keypoint_all_of import PredictionTaskSchemaKeypointAllOf +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_simple import PredictionTaskSchemaSimple +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_simple_all_of import PredictionTaskSchemaSimpleAllOf +from lightly.openapi_generated.swagger_client.models.prediction_task_schemas import PredictionTaskSchemas from lightly.openapi_generated.swagger_client.models.questionnaire_data import QuestionnaireData from lightly.openapi_generated.swagger_client.models.s3_region import S3Region from lightly.openapi_generated.swagger_client.models.sama_task import SamaTask diff --git a/lightly/openapi_generated/swagger_client/api/predictions_api.py b/lightly/openapi_generated/swagger_client/api/predictions_api.py index 23a7ecfa3..057cbdcb7 100644 --- a/lightly/openapi_generated/swagger_client/api/predictions_api.py +++ b/lightly/openapi_generated/swagger_client/api/predictions_api.py @@ -27,6 +27,7 @@ from lightly.openapi_generated.swagger_client.models.create_entity_response import CreateEntityResponse from lightly.openapi_generated.swagger_client.models.prediction_singleton import PredictionSingleton from lightly.openapi_generated.swagger_client.models.prediction_task_schema import PredictionTaskSchema +from lightly.openapi_generated.swagger_client.models.prediction_task_schemas import PredictionTaskSchemas from lightly.openapi_generated.swagger_client.api_client import ApiClient from lightly.openapi_generated.swagger_client.api_response import ApiResponse @@ -718,7 +719,7 @@ def get_prediction_task_schema_by_task_name_with_http_info(self, dataset_id : An _request_auth=_params.get('_request_auth')) @validate_arguments - def get_prediction_task_schemas_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> List[PredictionTaskSchema]: # noqa: E501 + def get_prediction_task_schemas_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> PredictionTaskSchemas: # noqa: E501 """get_prediction_task_schemas_by_dataset_id # noqa: E501 Get list of all the prediction task schemas for a datasetId at a specific predictionUUIDTimestamp. If no predictionUUIDTimestamp is set, it defaults to the newest # noqa: E501 @@ -741,7 +742,7 @@ def get_prediction_task_schemas_by_dataset_id(self, dataset_id : Annotated[const :return: Returns the result object. If the method is called asynchronously, returns the request thread. - :rtype: List[PredictionTaskSchema] + :rtype: PredictionTaskSchemas """ kwargs['_return_http_data_only'] = True if '_preload_content' in kwargs: @@ -785,7 +786,7 @@ def get_prediction_task_schemas_by_dataset_id_with_http_info(self, dataset_id : :return: Returns the result object. If the method is called asynchronously, returns the request thread. - :rtype: tuple(List[PredictionTaskSchema], status_code(int), headers(HTTPHeaderDict)) + :rtype: tuple(PredictionTaskSchemas, status_code(int), headers(HTTPHeaderDict)) """ _params = locals() @@ -847,7 +848,7 @@ def get_prediction_task_schemas_by_dataset_id_with_http_info(self, dataset_id : _auth_settings = ['auth0Bearer', 'ApiKeyAuth'] # noqa: E501 _response_types_map = { - '200': "List[PredictionTaskSchema]", + '200': "PredictionTaskSchemas", '400': "ApiErrorResponse", '401': "ApiErrorResponse", '403': "ApiErrorResponse", diff --git a/lightly/openapi_generated/swagger_client/models/__init__.py b/lightly/openapi_generated/swagger_client/models/__init__.py index dbed5827c..1c47427c7 100644 --- a/lightly/openapi_generated/swagger_client/models/__init__.py +++ b/lightly/openapi_generated/swagger_client/models/__init__.py @@ -152,7 +152,15 @@ from lightly.openapi_generated.swagger_client.models.prediction_singleton_semantic_segmentation import PredictionSingletonSemanticSegmentation from lightly.openapi_generated.swagger_client.models.prediction_singleton_semantic_segmentation_all_of import PredictionSingletonSemanticSegmentationAllOf from lightly.openapi_generated.swagger_client.models.prediction_task_schema import PredictionTaskSchema +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_base import PredictionTaskSchemaBase from lightly.openapi_generated.swagger_client.models.prediction_task_schema_category import PredictionTaskSchemaCategory +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_category_keypoints import PredictionTaskSchemaCategoryKeypoints +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_category_keypoints_all_of import PredictionTaskSchemaCategoryKeypointsAllOf +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_keypoint import PredictionTaskSchemaKeypoint +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_keypoint_all_of import PredictionTaskSchemaKeypointAllOf +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_simple import PredictionTaskSchemaSimple +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_simple_all_of import PredictionTaskSchemaSimpleAllOf +from lightly.openapi_generated.swagger_client.models.prediction_task_schemas import PredictionTaskSchemas from lightly.openapi_generated.swagger_client.models.questionnaire_data import QuestionnaireData from lightly.openapi_generated.swagger_client.models.s3_region import S3Region from lightly.openapi_generated.swagger_client.models.sama_task import SamaTask diff --git a/lightly/openapi_generated/swagger_client/models/docker_worker_config_v3_docker.py b/lightly/openapi_generated/swagger_client/models/docker_worker_config_v3_docker.py index 39310739e..962263917 100644 --- a/lightly/openapi_generated/swagger_client/models/docker_worker_config_v3_docker.py +++ b/lightly/openapi_generated/swagger_client/models/docker_worker_config_v3_docker.py @@ -44,7 +44,8 @@ class DockerWorkerConfigV3Docker(BaseModel): relevant_filenames_file: Optional[StrictStr] = Field(None, alias="relevantFilenamesFile") selected_sequence_length: Optional[conint(strict=True, ge=1)] = Field(None, alias="selectedSequenceLength") upload_report: Optional[StrictBool] = Field(None, alias="uploadReport") - __properties = ["checkpoint", "corruptnessCheck", "datasource", "embeddings", "enableTraining", "training", "normalizeEmbeddings", "numProcesses", "numThreads", "outputImageFormat", "pretagging", "pretaggingUpload", "relevantFilenamesFile", "selectedSequenceLength", "uploadReport"] + use_datapool: Optional[StrictBool] = Field(None, alias="useDatapool") + __properties = ["checkpoint", "corruptnessCheck", "datasource", "embeddings", "enableTraining", "training", "normalizeEmbeddings", "numProcesses", "numThreads", "outputImageFormat", "pretagging", "pretaggingUpload", "relevantFilenamesFile", "selectedSequenceLength", "uploadReport", "useDatapool"] class Config: """Pydantic configuration""" @@ -112,7 +113,8 @@ def from_dict(cls, obj: dict) -> DockerWorkerConfigV3Docker: "pretagging_upload": obj.get("pretaggingUpload"), "relevant_filenames_file": obj.get("relevantFilenamesFile"), "selected_sequence_length": obj.get("selectedSequenceLength"), - "upload_report": obj.get("uploadReport") + "upload_report": obj.get("uploadReport"), + "use_datapool": obj.get("useDatapool") }) return _obj diff --git a/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection.py b/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection.py index cb61221e6..c0def6930 100644 --- a/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection.py +++ b/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection.py @@ -27,7 +27,7 @@ class PredictionSingletonKeypointDetection(PredictionSingletonBase): """ PredictionSingletonKeypointDetection """ - keypoints: conlist(Union[confloat(ge=0, strict=True), conint(ge=0, strict=True)], min_items=3) = Field(..., description="[x1, y2, v1, ..., xk, yk, vk] as outlined by the coco format https://cocodataset.org/#format-results ") + keypoints: conlist(Union[confloat(ge=0, strict=True), conint(ge=0, strict=True)], min_items=3) = Field(..., description="[x1, y2, s1, ..., xk, yk, sk] as outlined by https://docs.lightly.ai/docs/prediction-format#keypoint-detection ") probabilities: Optional[conlist(Union[confloat(le=1, ge=0, strict=True), conint(le=1, ge=0, strict=True)])] = Field(None, description="The probabilities of it being a certain category other than the one which was selected. The sum of all probabilities should equal 1.") __properties = ["type", "taskName", "cropDatasetId", "cropSampleId", "categoryId", "score", "keypoints", "probabilities"] diff --git a/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection_all_of.py b/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection_all_of.py index 73504ae2f..abeadb9dc 100644 --- a/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection_all_of.py +++ b/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection_all_of.py @@ -26,7 +26,7 @@ class PredictionSingletonKeypointDetectionAllOf(BaseModel): """ PredictionSingletonKeypointDetectionAllOf """ - keypoints: conlist(Union[confloat(ge=0, strict=True), conint(ge=0, strict=True)], min_items=3) = Field(..., description="[x1, y2, v1, ..., xk, yk, vk] as outlined by the coco format https://cocodataset.org/#format-results ") + keypoints: conlist(Union[confloat(ge=0, strict=True), conint(ge=0, strict=True)], min_items=3) = Field(..., description="[x1, y2, s1, ..., xk, yk, sk] as outlined by https://docs.lightly.ai/docs/prediction-format#keypoint-detection ") probabilities: Optional[conlist(Union[confloat(le=1, ge=0, strict=True), conint(le=1, ge=0, strict=True)])] = Field(None, description="The probabilities of it being a certain category other than the one which was selected. The sum of all probabilities should equal 1.") __properties = ["keypoints", "probabilities"] diff --git a/lightly/openapi_generated/swagger_client/models/prediction_task_schema.py b/lightly/openapi_generated/swagger_client/models/prediction_task_schema.py index da0aa6918..469dcd4da 100644 --- a/lightly/openapi_generated/swagger_client/models/prediction_task_schema.py +++ b/lightly/openapi_generated/swagger_client/models/prediction_task_schema.py @@ -14,85 +14,170 @@ from __future__ import annotations +from inspect import getfullargspec +import json import pprint import re # noqa: F401 -import json +from typing import Any, List, Optional +from pydantic import BaseModel, Field, StrictStr, ValidationError, validator +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_keypoint import PredictionTaskSchemaKeypoint +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_simple import PredictionTaskSchemaSimple +from typing import Any, List +from pydantic import StrictStr, Field, Extra -from typing import List -from pydantic import Extra, BaseModel, Field, conlist, constr, validator -from lightly.openapi_generated.swagger_client.models.prediction_task_schema_category import PredictionTaskSchemaCategory -from lightly.openapi_generated.swagger_client.models.task_type import TaskType +PREDICTIONTASKSCHEMA_ONE_OF_SCHEMAS = ["PredictionTaskSchemaKeypoint", "PredictionTaskSchemaSimple"] class PredictionTaskSchema(BaseModel): """ - The schema for predictions or labels when doing classification, object detection, keypoint detection or instance segmentation + PredictionTaskSchema """ - name: constr(strict=True, min_length=1) = Field(..., description="A name which is safe to have as a file/folder name in a file system") - type: TaskType = Field(...) - categories: conlist(PredictionTaskSchemaCategory) = Field(..., description="An array of the categories that exist for this prediction task. The id needs to be unique") - __properties = ["name", "type", "categories"] - - @validator('name') - def name_validate_regular_expression(cls, value): - """Validates the regular expression""" - if not re.match(r"^[a-zA-Z0-9][a-zA-Z0-9 ._-]+$", value): - raise ValueError(r"must validate the regular expression /^[a-zA-Z0-9][a-zA-Z0-9 ._-]+$/") - return value + # data type: PredictionTaskSchemaSimple + oneof_schema_1_validator: Optional[PredictionTaskSchemaSimple] = None + # data type: PredictionTaskSchemaKeypoint + oneof_schema_2_validator: Optional[PredictionTaskSchemaKeypoint] = None + actual_instance: Any + one_of_schemas: List[str] = Field(PREDICTIONTASKSCHEMA_ONE_OF_SCHEMAS, const=True) class Config: - """Pydantic configuration""" - allow_population_by_field_name = True validate_assignment = True use_enum_values = True extra = Extra.forbid - def to_str(self, by_alias: bool = False) -> str: - """Returns the string representation of the model""" - return pprint.pformat(self.dict(by_alias=by_alias)) + discriminator_value_class_map = { + } + + def __init__(self, *args, **kwargs): + if args: + if len(args) > 1: + raise ValueError("If a position argument is used, only 1 is allowed to set `actual_instance`") + if kwargs: + raise ValueError("If a position argument is used, keyword arguments cannot be used.") + super().__init__(actual_instance=args[0]) + else: + super().__init__(**kwargs) + + @validator('actual_instance') + def actual_instance_must_validate_oneof(cls, v): + instance = PredictionTaskSchema.construct() + error_messages = [] + match = 0 + # validate data type: PredictionTaskSchemaSimple + if not isinstance(v, PredictionTaskSchemaSimple): + error_messages.append(f"Error! Input type `{type(v)}` is not `PredictionTaskSchemaSimple`") + else: + match += 1 + # validate data type: PredictionTaskSchemaKeypoint + if not isinstance(v, PredictionTaskSchemaKeypoint): + error_messages.append(f"Error! Input type `{type(v)}` is not `PredictionTaskSchemaKeypoint`") + else: + match += 1 + if match > 1: + # more than 1 match + raise ValueError("Multiple matches found when setting `actual_instance` in PredictionTaskSchema with oneOf schemas: PredictionTaskSchemaKeypoint, PredictionTaskSchemaSimple. Details: " + ", ".join(error_messages)) + elif match == 0: + # no match + raise ValueError("No match found when setting `actual_instance` in PredictionTaskSchema with oneOf schemas: PredictionTaskSchemaKeypoint, PredictionTaskSchemaSimple. Details: " + ", ".join(error_messages)) + else: + return v - def to_json(self, by_alias: bool = False) -> str: - """Returns the JSON representation of the model""" - return json.dumps(self.to_dict(by_alias=by_alias)) + @classmethod + def from_dict(cls, obj: dict) -> PredictionTaskSchema: + return cls.from_json(json.dumps(obj)) @classmethod def from_json(cls, json_str: str) -> PredictionTaskSchema: - """Create an instance of PredictionTaskSchema from a JSON string""" - return cls.from_dict(json.loads(json_str)) - - def to_dict(self, by_alias: bool = False): - """Returns the dictionary representation of the model""" - _dict = self.dict(by_alias=by_alias, - exclude={ - }, - exclude_none=True) - # override the default output from pydantic by calling `to_dict()` of each item in categories (list) - _items = [] - if self.categories: - for _item in self.categories: - if _item: - _items.append(_item.to_dict(by_alias=by_alias)) - _dict['categories' if by_alias else 'categories'] = _items - return _dict + """Returns the object represented by the json string""" + instance = PredictionTaskSchema.construct() + error_messages = [] + match = 0 + + # use oneOf discriminator to lookup the data type + _data_type = json.loads(json_str).get("type") + if not _data_type: + raise ValueError("Failed to lookup data type from the field `type` in the input.") + + # check if data type is `PredictionTaskSchemaSimple` + if _data_type == "CLASSIFICATION": + instance.actual_instance = PredictionTaskSchemaSimple.from_json(json_str) + return instance + + # check if data type is `PredictionTaskSchemaSimple` + if _data_type == "INSTANCE_SEGMENTATION": + instance.actual_instance = PredictionTaskSchemaSimple.from_json(json_str) + return instance + + # check if data type is `PredictionTaskSchemaKeypoint` + if _data_type == "KEYPOINT_DETECTION": + instance.actual_instance = PredictionTaskSchemaKeypoint.from_json(json_str) + return instance + + # check if data type is `PredictionTaskSchemaSimple` + if _data_type == "OBJECT_DETECTION": + instance.actual_instance = PredictionTaskSchemaSimple.from_json(json_str) + return instance + + # check if data type is `PredictionTaskSchemaKeypoint` + if _data_type == "PredictionTaskSchemaKeypoint": + instance.actual_instance = PredictionTaskSchemaKeypoint.from_json(json_str) + return instance + + # check if data type is `PredictionTaskSchemaSimple` + if _data_type == "PredictionTaskSchemaSimple": + instance.actual_instance = PredictionTaskSchemaSimple.from_json(json_str) + return instance + + # check if data type is `PredictionTaskSchemaSimple` + if _data_type == "SEMANTIC_SEGMENTATION": + instance.actual_instance = PredictionTaskSchemaSimple.from_json(json_str) + return instance + + # deserialize data into PredictionTaskSchemaSimple + try: + instance.actual_instance = PredictionTaskSchemaSimple.from_json(json_str) + match += 1 + except (ValidationError, ValueError) as e: + error_messages.append(str(e)) + # deserialize data into PredictionTaskSchemaKeypoint + try: + instance.actual_instance = PredictionTaskSchemaKeypoint.from_json(json_str) + match += 1 + except (ValidationError, ValueError) as e: + error_messages.append(str(e)) + + if match > 1: + # more than 1 match + raise ValueError("Multiple matches found when deserializing the JSON string into PredictionTaskSchema with oneOf schemas: PredictionTaskSchemaKeypoint, PredictionTaskSchemaSimple. Details: " + ", ".join(error_messages)) + elif match == 0: + # no match + raise ValueError("No match found when deserializing the JSON string into PredictionTaskSchema with oneOf schemas: PredictionTaskSchemaKeypoint, PredictionTaskSchemaSimple. Details: " + ", ".join(error_messages)) + else: + return instance - @classmethod - def from_dict(cls, obj: dict) -> PredictionTaskSchema: - """Create an instance of PredictionTaskSchema from a dict""" - if obj is None: + def to_json(self, by_alias: bool = False) -> str: + """Returns the JSON representation of the actual instance""" + if self.actual_instance is None: + return "null" + + to_json = getattr(self.actual_instance, "to_json", None) + if callable(to_json): + return self.actual_instance.to_json(by_alias=by_alias) + else: + return json.dumps(self.actual_instance) + + def to_dict(self, by_alias: bool = False) -> dict: + """Returns the dict representation of the actual instance""" + if self.actual_instance is None: return None - if not isinstance(obj, dict): - return PredictionTaskSchema.parse_obj(obj) + to_dict = getattr(self.actual_instance, "to_dict", None) + if callable(to_dict): + return self.actual_instance.to_dict(by_alias=by_alias) + else: + # primitive type + return self.actual_instance - # raise errors for additional fields in the input - for _key in obj.keys(): - if _key not in cls.__properties: - raise ValueError("Error due to additional fields (not defined in PredictionTaskSchema) in the input: " + str(obj)) - - _obj = PredictionTaskSchema.parse_obj({ - "name": obj.get("name"), - "type": obj.get("type"), - "categories": [PredictionTaskSchemaCategory.from_dict(_item) for _item in obj.get("categories")] if obj.get("categories") is not None else None - }) - return _obj + def to_str(self, by_alias: bool = False) -> str: + """Returns the string representation of the actual instance""" + return pprint.pformat(self.dict(by_alias=by_alias)) diff --git a/lightly/openapi_generated/swagger_client/models/prediction_task_schema_base.py b/lightly/openapi_generated/swagger_client/models/prediction_task_schema_base.py new file mode 100644 index 000000000..a657ec939 --- /dev/null +++ b/lightly/openapi_generated/swagger_client/models/prediction_task_schema_base.py @@ -0,0 +1,99 @@ +# coding: utf-8 + +""" + Lightly API + + Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 + + The version of the OpenAPI document: 1.0.0 + Contact: support@lightly.ai + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. +""" + + +from __future__ import annotations +import pprint +import re # noqa: F401 +import json +import lightly.openapi_generated.swagger_client.models + + + +from pydantic import Extra, BaseModel, Field, StrictStr, constr, validator + +class PredictionTaskSchemaBase(BaseModel): + """ + The schema for predictions or labels when doing classification, object detection, keypoint detection or instance segmentation + """ + name: constr(strict=True, min_length=1) = Field(..., description="A name which is safe to have as a file/folder name in a file system") + type: StrictStr = Field(..., description="This is the TaskType. Due to openapi.oneOf fuckery with discriminators, this needs to be a string") + __properties = ["name", "type"] + + @validator('name') + def name_validate_regular_expression(cls, value): + """Validates the regular expression""" + if not re.match(r"^[a-zA-Z0-9][a-zA-Z0-9 ._-]+$", value): + raise ValueError(r"must validate the regular expression /^[a-zA-Z0-9][a-zA-Z0-9 ._-]+$/") + return value + + class Config: + """Pydantic configuration""" + allow_population_by_field_name = True + validate_assignment = True + use_enum_values = True + extra = Extra.forbid + + # JSON field name that stores the object type + __discriminator_property_name = 'type' + + # discriminator mappings + __discriminator_value_class_map = { + 'PredictionTaskSchemaKeypoint': 'PredictionTaskSchemaKeypoint', + 'PredictionTaskSchemaSimple': 'PredictionTaskSchemaSimple' + } + + @classmethod + def get_discriminator_value(cls, obj: dict) -> str: + """Returns the discriminator value (object type) of the data""" + discriminator_value = obj[cls.__discriminator_property_name] + if discriminator_value: + return cls.__discriminator_value_class_map.get(discriminator_value) + else: + return None + + def to_str(self, by_alias: bool = False) -> str: + """Returns the string representation of the model""" + return pprint.pformat(self.dict(by_alias=by_alias)) + + def to_json(self, by_alias: bool = False) -> str: + """Returns the JSON representation of the model""" + return json.dumps(self.to_dict(by_alias=by_alias)) + + @classmethod + def from_json(cls, json_str: str) -> Union(PredictionTaskSchemaKeypoint, PredictionTaskSchemaSimple): + """Create an instance of PredictionTaskSchemaBase from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self, by_alias: bool = False): + """Returns the dictionary representation of the model""" + _dict = self.dict(by_alias=by_alias, + exclude={ + }, + exclude_none=True) + return _dict + + @classmethod + def from_dict(cls, obj: dict) -> Union(PredictionTaskSchemaKeypoint, PredictionTaskSchemaSimple): + """Create an instance of PredictionTaskSchemaBase from a dict""" + # look up the object type based on discriminator mapping + object_type = cls.get_discriminator_value(obj) + if object_type: + klass = getattr(lightly.openapi_generated.swagger_client.models, object_type) + return klass.from_dict(obj) + else: + raise ValueError("PredictionTaskSchemaBase failed to lookup discriminator value from " + + json.dumps(obj) + ". Discriminator property name: " + cls.__discriminator_property_name + + ", mapping: " + json.dumps(cls.__discriminator_value_class_map)) + diff --git a/lightly/openapi_generated/swagger_client/models/prediction_task_schema_category.py b/lightly/openapi_generated/swagger_client/models/prediction_task_schema_category.py index 5d18c8231..505a8b7e1 100644 --- a/lightly/openapi_generated/swagger_client/models/prediction_task_schema_category.py +++ b/lightly/openapi_generated/swagger_client/models/prediction_task_schema_category.py @@ -19,15 +19,15 @@ import json -from typing import Optional + from pydantic import Extra, BaseModel, Field, conint, constr class PredictionTaskSchemaCategory(BaseModel): """ The link between the categoryId and the name that should be used """ - id: Optional[conint(strict=True, ge=0)] = Field(None, description="The id of the category. Needs to be a positive integer but can be any integer (gaps are allowed, does not need to be sequential)") - name: Optional[constr(strict=True, min_length=1)] = Field(None, description="The name of the category when it should be visualized") + id: conint(strict=True, ge=0) = Field(..., description="The id of the category. Needs to be a positive integer but can be any integer (gaps are allowed, does not need to be sequential)") + name: constr(strict=True, min_length=1) = Field(..., description="The name of the category when it should be visualized") __properties = ["id", "name"] class Config: diff --git a/lightly/openapi_generated/swagger_client/models/prediction_task_schema_category_keypoints.py b/lightly/openapi_generated/swagger_client/models/prediction_task_schema_category_keypoints.py new file mode 100644 index 000000000..e0d9a71fa --- /dev/null +++ b/lightly/openapi_generated/swagger_client/models/prediction_task_schema_category_keypoints.py @@ -0,0 +1,84 @@ +# coding: utf-8 + +""" + Lightly API + + Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 + + The version of the OpenAPI document: 1.0.0 + Contact: support@lightly.ai + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. +""" + + +from __future__ import annotations +import pprint +import re # noqa: F401 +import json + + +from typing import List, Optional +from pydantic import Extra, BaseModel, Field, conint, conlist, constr + +class PredictionTaskSchemaCategoryKeypoints(BaseModel): + """ + PredictionTaskSchemaCategoryKeypoints + """ + id: conint(strict=True, ge=0) = Field(..., description="The id of the category. Needs to be a positive integer but can be any integer (gaps are allowed, does not need to be sequential)") + name: constr(strict=True, min_length=1) = Field(..., description="The name of the category when it should be visualized") + keypoint_names: Optional[conlist(constr(strict=True, min_length=1))] = Field(None, alias="keypointNames", description="The names of the individual keypoints. E.g left-shoulder, right-shoulder, nose, etc. Must be of equal length as the number of keypoints of a keypoint detection. ") + keypoint_skeleton: Optional[conlist(conlist(conint(strict=True, ge=0), max_items=2, min_items=2))] = Field(None, alias="keypointSkeleton", description="The keypoint skeleton of a category. It is used to show the overall connectivity between keypoints. Each entry in the array describes a a single connection between two keypoints by their index. e.g [1,3],[2,4],[3,4] ") + __properties = ["id", "name", "keypointNames", "keypointSkeleton"] + + class Config: + """Pydantic configuration""" + allow_population_by_field_name = True + validate_assignment = True + use_enum_values = True + extra = Extra.forbid + + def to_str(self, by_alias: bool = False) -> str: + """Returns the string representation of the model""" + return pprint.pformat(self.dict(by_alias=by_alias)) + + def to_json(self, by_alias: bool = False) -> str: + """Returns the JSON representation of the model""" + return json.dumps(self.to_dict(by_alias=by_alias)) + + @classmethod + def from_json(cls, json_str: str) -> PredictionTaskSchemaCategoryKeypoints: + """Create an instance of PredictionTaskSchemaCategoryKeypoints from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self, by_alias: bool = False): + """Returns the dictionary representation of the model""" + _dict = self.dict(by_alias=by_alias, + exclude={ + }, + exclude_none=True) + return _dict + + @classmethod + def from_dict(cls, obj: dict) -> PredictionTaskSchemaCategoryKeypoints: + """Create an instance of PredictionTaskSchemaCategoryKeypoints from a dict""" + if obj is None: + return None + + if not isinstance(obj, dict): + return PredictionTaskSchemaCategoryKeypoints.parse_obj(obj) + + # raise errors for additional fields in the input + for _key in obj.keys(): + if _key not in cls.__properties: + raise ValueError("Error due to additional fields (not defined in PredictionTaskSchemaCategoryKeypoints) in the input: " + str(obj)) + + _obj = PredictionTaskSchemaCategoryKeypoints.parse_obj({ + "id": obj.get("id"), + "name": obj.get("name"), + "keypoint_names": obj.get("keypointNames"), + "keypoint_skeleton": obj.get("keypointSkeleton") + }) + return _obj + diff --git a/lightly/openapi_generated/swagger_client/models/prediction_task_schema_category_keypoints_all_of.py b/lightly/openapi_generated/swagger_client/models/prediction_task_schema_category_keypoints_all_of.py new file mode 100644 index 000000000..4dba76568 --- /dev/null +++ b/lightly/openapi_generated/swagger_client/models/prediction_task_schema_category_keypoints_all_of.py @@ -0,0 +1,80 @@ +# coding: utf-8 + +""" + Lightly API + + Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 + + The version of the OpenAPI document: 1.0.0 + Contact: support@lightly.ai + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. +""" + + +from __future__ import annotations +import pprint +import re # noqa: F401 +import json + + +from typing import List, Optional +from pydantic import Extra, BaseModel, Field, conint, conlist, constr + +class PredictionTaskSchemaCategoryKeypointsAllOf(BaseModel): + """ + The link between the categoryId and the name that should be used + """ + keypoint_names: Optional[conlist(constr(strict=True, min_length=1))] = Field(None, alias="keypointNames", description="The names of the individual keypoints. E.g left-shoulder, right-shoulder, nose, etc. Must be of equal length as the number of keypoints of a keypoint detection. ") + keypoint_skeleton: Optional[conlist(conlist(conint(strict=True, ge=0), max_items=2, min_items=2))] = Field(None, alias="keypointSkeleton", description="The keypoint skeleton of a category. It is used to show the overall connectivity between keypoints. Each entry in the array describes a a single connection between two keypoints by their index. e.g [1,3],[2,4],[3,4] ") + __properties = ["keypointNames", "keypointSkeleton"] + + class Config: + """Pydantic configuration""" + allow_population_by_field_name = True + validate_assignment = True + use_enum_values = True + extra = Extra.forbid + + def to_str(self, by_alias: bool = False) -> str: + """Returns the string representation of the model""" + return pprint.pformat(self.dict(by_alias=by_alias)) + + def to_json(self, by_alias: bool = False) -> str: + """Returns the JSON representation of the model""" + return json.dumps(self.to_dict(by_alias=by_alias)) + + @classmethod + def from_json(cls, json_str: str) -> PredictionTaskSchemaCategoryKeypointsAllOf: + """Create an instance of PredictionTaskSchemaCategoryKeypointsAllOf from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self, by_alias: bool = False): + """Returns the dictionary representation of the model""" + _dict = self.dict(by_alias=by_alias, + exclude={ + }, + exclude_none=True) + return _dict + + @classmethod + def from_dict(cls, obj: dict) -> PredictionTaskSchemaCategoryKeypointsAllOf: + """Create an instance of PredictionTaskSchemaCategoryKeypointsAllOf from a dict""" + if obj is None: + return None + + if not isinstance(obj, dict): + return PredictionTaskSchemaCategoryKeypointsAllOf.parse_obj(obj) + + # raise errors for additional fields in the input + for _key in obj.keys(): + if _key not in cls.__properties: + raise ValueError("Error due to additional fields (not defined in PredictionTaskSchemaCategoryKeypointsAllOf) in the input: " + str(obj)) + + _obj = PredictionTaskSchemaCategoryKeypointsAllOf.parse_obj({ + "keypoint_names": obj.get("keypointNames"), + "keypoint_skeleton": obj.get("keypointSkeleton") + }) + return _obj + diff --git a/lightly/openapi_generated/swagger_client/models/prediction_task_schema_keypoint.py b/lightly/openapi_generated/swagger_client/models/prediction_task_schema_keypoint.py new file mode 100644 index 000000000..612d9677d --- /dev/null +++ b/lightly/openapi_generated/swagger_client/models/prediction_task_schema_keypoint.py @@ -0,0 +1,89 @@ +# coding: utf-8 + +""" + Lightly API + + Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 + + The version of the OpenAPI document: 1.0.0 + Contact: support@lightly.ai + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. +""" + + +from __future__ import annotations +import pprint +import re # noqa: F401 +import json + + +from typing import List +from pydantic import Extra, BaseModel, Field, conlist +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_base import PredictionTaskSchemaBase +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_category_keypoints import PredictionTaskSchemaCategoryKeypoints + +class PredictionTaskSchemaKeypoint(PredictionTaskSchemaBase): + """ + PredictionTaskSchemaKeypoint + """ + categories: conlist(PredictionTaskSchemaCategoryKeypoints) = Field(..., description="An array of the categories that exist for this prediction task. The id needs to be unique") + __properties = ["name", "type", "categories"] + + class Config: + """Pydantic configuration""" + allow_population_by_field_name = True + validate_assignment = True + use_enum_values = True + extra = Extra.forbid + + def to_str(self, by_alias: bool = False) -> str: + """Returns the string representation of the model""" + return pprint.pformat(self.dict(by_alias=by_alias)) + + def to_json(self, by_alias: bool = False) -> str: + """Returns the JSON representation of the model""" + return json.dumps(self.to_dict(by_alias=by_alias)) + + @classmethod + def from_json(cls, json_str: str) -> PredictionTaskSchemaKeypoint: + """Create an instance of PredictionTaskSchemaKeypoint from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self, by_alias: bool = False): + """Returns the dictionary representation of the model""" + _dict = self.dict(by_alias=by_alias, + exclude={ + }, + exclude_none=True) + # override the default output from pydantic by calling `to_dict()` of each item in categories (list) + _items = [] + if self.categories: + for _item in self.categories: + if _item: + _items.append(_item.to_dict(by_alias=by_alias)) + _dict['categories' if by_alias else 'categories'] = _items + return _dict + + @classmethod + def from_dict(cls, obj: dict) -> PredictionTaskSchemaKeypoint: + """Create an instance of PredictionTaskSchemaKeypoint from a dict""" + if obj is None: + return None + + if not isinstance(obj, dict): + return PredictionTaskSchemaKeypoint.parse_obj(obj) + + # raise errors for additional fields in the input + for _key in obj.keys(): + if _key not in cls.__properties: + raise ValueError("Error due to additional fields (not defined in PredictionTaskSchemaKeypoint) in the input: " + str(obj)) + + _obj = PredictionTaskSchemaKeypoint.parse_obj({ + "name": obj.get("name"), + "type": obj.get("type"), + "categories": [PredictionTaskSchemaCategoryKeypoints.from_dict(_item) for _item in obj.get("categories")] if obj.get("categories") is not None else None + }) + return _obj + diff --git a/lightly/openapi_generated/swagger_client/models/prediction_task_schema_keypoint_all_of.py b/lightly/openapi_generated/swagger_client/models/prediction_task_schema_keypoint_all_of.py new file mode 100644 index 000000000..155c79b4e --- /dev/null +++ b/lightly/openapi_generated/swagger_client/models/prediction_task_schema_keypoint_all_of.py @@ -0,0 +1,86 @@ +# coding: utf-8 + +""" + Lightly API + + Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 + + The version of the OpenAPI document: 1.0.0 + Contact: support@lightly.ai + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. +""" + + +from __future__ import annotations +import pprint +import re # noqa: F401 +import json + + +from typing import List +from pydantic import Extra, BaseModel, Field, conlist +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_category_keypoints import PredictionTaskSchemaCategoryKeypoints + +class PredictionTaskSchemaKeypointAllOf(BaseModel): + """ + The schema for predictions or labels when doing keypoint detection + """ + categories: conlist(PredictionTaskSchemaCategoryKeypoints) = Field(..., description="An array of the categories that exist for this prediction task. The id needs to be unique") + __properties = ["categories"] + + class Config: + """Pydantic configuration""" + allow_population_by_field_name = True + validate_assignment = True + use_enum_values = True + extra = Extra.forbid + + def to_str(self, by_alias: bool = False) -> str: + """Returns the string representation of the model""" + return pprint.pformat(self.dict(by_alias=by_alias)) + + def to_json(self, by_alias: bool = False) -> str: + """Returns the JSON representation of the model""" + return json.dumps(self.to_dict(by_alias=by_alias)) + + @classmethod + def from_json(cls, json_str: str) -> PredictionTaskSchemaKeypointAllOf: + """Create an instance of PredictionTaskSchemaKeypointAllOf from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self, by_alias: bool = False): + """Returns the dictionary representation of the model""" + _dict = self.dict(by_alias=by_alias, + exclude={ + }, + exclude_none=True) + # override the default output from pydantic by calling `to_dict()` of each item in categories (list) + _items = [] + if self.categories: + for _item in self.categories: + if _item: + _items.append(_item.to_dict(by_alias=by_alias)) + _dict['categories' if by_alias else 'categories'] = _items + return _dict + + @classmethod + def from_dict(cls, obj: dict) -> PredictionTaskSchemaKeypointAllOf: + """Create an instance of PredictionTaskSchemaKeypointAllOf from a dict""" + if obj is None: + return None + + if not isinstance(obj, dict): + return PredictionTaskSchemaKeypointAllOf.parse_obj(obj) + + # raise errors for additional fields in the input + for _key in obj.keys(): + if _key not in cls.__properties: + raise ValueError("Error due to additional fields (not defined in PredictionTaskSchemaKeypointAllOf) in the input: " + str(obj)) + + _obj = PredictionTaskSchemaKeypointAllOf.parse_obj({ + "categories": [PredictionTaskSchemaCategoryKeypoints.from_dict(_item) for _item in obj.get("categories")] if obj.get("categories") is not None else None + }) + return _obj + diff --git a/lightly/openapi_generated/swagger_client/models/prediction_task_schema_simple.py b/lightly/openapi_generated/swagger_client/models/prediction_task_schema_simple.py new file mode 100644 index 000000000..45aee273e --- /dev/null +++ b/lightly/openapi_generated/swagger_client/models/prediction_task_schema_simple.py @@ -0,0 +1,89 @@ +# coding: utf-8 + +""" + Lightly API + + Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 + + The version of the OpenAPI document: 1.0.0 + Contact: support@lightly.ai + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. +""" + + +from __future__ import annotations +import pprint +import re # noqa: F401 +import json + + +from typing import List +from pydantic import Extra, BaseModel, Field, conlist +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_base import PredictionTaskSchemaBase +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_category import PredictionTaskSchemaCategory + +class PredictionTaskSchemaSimple(PredictionTaskSchemaBase): + """ + PredictionTaskSchemaSimple + """ + categories: conlist(PredictionTaskSchemaCategory) = Field(..., description="An array of the categories that exist for this prediction task. The id needs to be unique") + __properties = ["name", "type", "categories"] + + class Config: + """Pydantic configuration""" + allow_population_by_field_name = True + validate_assignment = True + use_enum_values = True + extra = Extra.forbid + + def to_str(self, by_alias: bool = False) -> str: + """Returns the string representation of the model""" + return pprint.pformat(self.dict(by_alias=by_alias)) + + def to_json(self, by_alias: bool = False) -> str: + """Returns the JSON representation of the model""" + return json.dumps(self.to_dict(by_alias=by_alias)) + + @classmethod + def from_json(cls, json_str: str) -> PredictionTaskSchemaSimple: + """Create an instance of PredictionTaskSchemaSimple from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self, by_alias: bool = False): + """Returns the dictionary representation of the model""" + _dict = self.dict(by_alias=by_alias, + exclude={ + }, + exclude_none=True) + # override the default output from pydantic by calling `to_dict()` of each item in categories (list) + _items = [] + if self.categories: + for _item in self.categories: + if _item: + _items.append(_item.to_dict(by_alias=by_alias)) + _dict['categories' if by_alias else 'categories'] = _items + return _dict + + @classmethod + def from_dict(cls, obj: dict) -> PredictionTaskSchemaSimple: + """Create an instance of PredictionTaskSchemaSimple from a dict""" + if obj is None: + return None + + if not isinstance(obj, dict): + return PredictionTaskSchemaSimple.parse_obj(obj) + + # raise errors for additional fields in the input + for _key in obj.keys(): + if _key not in cls.__properties: + raise ValueError("Error due to additional fields (not defined in PredictionTaskSchemaSimple) in the input: " + str(obj)) + + _obj = PredictionTaskSchemaSimple.parse_obj({ + "name": obj.get("name"), + "type": obj.get("type"), + "categories": [PredictionTaskSchemaCategory.from_dict(_item) for _item in obj.get("categories")] if obj.get("categories") is not None else None + }) + return _obj + diff --git a/lightly/openapi_generated/swagger_client/models/prediction_task_schema_simple_all_of.py b/lightly/openapi_generated/swagger_client/models/prediction_task_schema_simple_all_of.py new file mode 100644 index 000000000..9616f4fff --- /dev/null +++ b/lightly/openapi_generated/swagger_client/models/prediction_task_schema_simple_all_of.py @@ -0,0 +1,86 @@ +# coding: utf-8 + +""" + Lightly API + + Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 + + The version of the OpenAPI document: 1.0.0 + Contact: support@lightly.ai + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. +""" + + +from __future__ import annotations +import pprint +import re # noqa: F401 +import json + + +from typing import List +from pydantic import Extra, BaseModel, Field, conlist +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_category import PredictionTaskSchemaCategory + +class PredictionTaskSchemaSimpleAllOf(BaseModel): + """ + The schema for predictions or labels when doing classification, object detection or instance segmentation + """ + categories: conlist(PredictionTaskSchemaCategory) = Field(..., description="An array of the categories that exist for this prediction task. The id needs to be unique") + __properties = ["categories"] + + class Config: + """Pydantic configuration""" + allow_population_by_field_name = True + validate_assignment = True + use_enum_values = True + extra = Extra.forbid + + def to_str(self, by_alias: bool = False) -> str: + """Returns the string representation of the model""" + return pprint.pformat(self.dict(by_alias=by_alias)) + + def to_json(self, by_alias: bool = False) -> str: + """Returns the JSON representation of the model""" + return json.dumps(self.to_dict(by_alias=by_alias)) + + @classmethod + def from_json(cls, json_str: str) -> PredictionTaskSchemaSimpleAllOf: + """Create an instance of PredictionTaskSchemaSimpleAllOf from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self, by_alias: bool = False): + """Returns the dictionary representation of the model""" + _dict = self.dict(by_alias=by_alias, + exclude={ + }, + exclude_none=True) + # override the default output from pydantic by calling `to_dict()` of each item in categories (list) + _items = [] + if self.categories: + for _item in self.categories: + if _item: + _items.append(_item.to_dict(by_alias=by_alias)) + _dict['categories' if by_alias else 'categories'] = _items + return _dict + + @classmethod + def from_dict(cls, obj: dict) -> PredictionTaskSchemaSimpleAllOf: + """Create an instance of PredictionTaskSchemaSimpleAllOf from a dict""" + if obj is None: + return None + + if not isinstance(obj, dict): + return PredictionTaskSchemaSimpleAllOf.parse_obj(obj) + + # raise errors for additional fields in the input + for _key in obj.keys(): + if _key not in cls.__properties: + raise ValueError("Error due to additional fields (not defined in PredictionTaskSchemaSimpleAllOf) in the input: " + str(obj)) + + _obj = PredictionTaskSchemaSimpleAllOf.parse_obj({ + "categories": [PredictionTaskSchemaCategory.from_dict(_item) for _item in obj.get("categories")] if obj.get("categories") is not None else None + }) + return _obj + diff --git a/lightly/openapi_generated/swagger_client/models/prediction_task_schemas.py b/lightly/openapi_generated/swagger_client/models/prediction_task_schemas.py new file mode 100644 index 000000000..bbdef7dee --- /dev/null +++ b/lightly/openapi_generated/swagger_client/models/prediction_task_schemas.py @@ -0,0 +1,88 @@ +# coding: utf-8 + +""" + Lightly API + + Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 + + The version of the OpenAPI document: 1.0.0 + Contact: support@lightly.ai + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. +""" + + +from __future__ import annotations +import pprint +import re # noqa: F401 +import json + + +from typing import List +from pydantic import Extra, BaseModel, Field, conint, conlist +from lightly.openapi_generated.swagger_client.models.prediction_task_schema import PredictionTaskSchema + +class PredictionTaskSchemas(BaseModel): + """ + PredictionTaskSchemas + """ + prediction_uuid_timestamp: conint(strict=True, ge=0) = Field(..., alias="predictionUUIDTimestamp", description="unix timestamp in milliseconds") + schemas: conlist(PredictionTaskSchema) = Field(...) + __properties = ["predictionUUIDTimestamp", "schemas"] + + class Config: + """Pydantic configuration""" + allow_population_by_field_name = True + validate_assignment = True + use_enum_values = True + extra = Extra.forbid + + def to_str(self, by_alias: bool = False) -> str: + """Returns the string representation of the model""" + return pprint.pformat(self.dict(by_alias=by_alias)) + + def to_json(self, by_alias: bool = False) -> str: + """Returns the JSON representation of the model""" + return json.dumps(self.to_dict(by_alias=by_alias)) + + @classmethod + def from_json(cls, json_str: str) -> PredictionTaskSchemas: + """Create an instance of PredictionTaskSchemas from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self, by_alias: bool = False): + """Returns the dictionary representation of the model""" + _dict = self.dict(by_alias=by_alias, + exclude={ + }, + exclude_none=True) + # override the default output from pydantic by calling `to_dict()` of each item in schemas (list) + _items = [] + if self.schemas: + for _item in self.schemas: + if _item: + _items.append(_item.to_dict(by_alias=by_alias)) + _dict['schemas' if by_alias else 'schemas'] = _items + return _dict + + @classmethod + def from_dict(cls, obj: dict) -> PredictionTaskSchemas: + """Create an instance of PredictionTaskSchemas from a dict""" + if obj is None: + return None + + if not isinstance(obj, dict): + return PredictionTaskSchemas.parse_obj(obj) + + # raise errors for additional fields in the input + for _key in obj.keys(): + if _key not in cls.__properties: + raise ValueError("Error due to additional fields (not defined in PredictionTaskSchemas) in the input: " + str(obj)) + + _obj = PredictionTaskSchemas.parse_obj({ + "prediction_uuid_timestamp": obj.get("predictionUUIDTimestamp"), + "schemas": [PredictionTaskSchema.from_dict(_item) for _item in obj.get("schemas")] if obj.get("schemas") is not None else None + }) + return _obj + diff --git a/lightly/transforms/dino_transform.py b/lightly/transforms/dino_transform.py index c23e446ae..de6a18657 100644 --- a/lightly/transforms/dino_transform.py +++ b/lightly/transforms/dino_transform.py @@ -15,6 +15,21 @@ class DINOTransform(MultiViewTransform): """Implements the global and local view augmentations for DINO [0]. + Input to this transform: + PIL Image or Tensor. + + Output of this transform: + List of Tensor of length 2 * global + n_local_views. (8 by default) + + Applies the following augmentations by default: + - Random resized crop + - Random horizontal flip + - Color jitter + - Random gray scale + - Gaussian blur + - Random solarization + - ImageNet normalization + This class generates two global and a user defined number of local views for each image in a batch. The code is adapted from [1]. diff --git a/lightly/transforms/fast_siam_transform.py b/lightly/transforms/fast_siam_transform.py index d14768d07..7b560591f 100644 --- a/lightly/transforms/fast_siam_transform.py +++ b/lightly/transforms/fast_siam_transform.py @@ -8,6 +8,20 @@ class FastSiamTransform(MultiViewTransform): """Implements the transformations for FastSiam. + Input to this transform: + PIL Image or Tensor. + + Output of this transform: + List of Tensor of length 4. + + Applies the following augmentations by default: + - Random resized crop + - Random horizontal flip + - Color jitter + - Random gray scale + - Gaussian blur + - ImageNet normalization + Attributes: num_views: Number of views (num_views = K+1 where K is the number of target views). diff --git a/lightly/transforms/ijepa_transform.py b/lightly/transforms/ijepa_transform.py new file mode 100644 index 000000000..321dba66a --- /dev/null +++ b/lightly/transforms/ijepa_transform.py @@ -0,0 +1,58 @@ +from typing import Tuple, Union + +import torchvision.transforms as T +from PIL.Image import Image +from torch import Tensor + +from lightly.transforms.utils import IMAGENET_NORMALIZE + + +class IJEPATransform: + """Implements the augmentations for I-JEPA [0, 1]. + + Experimental: Support for I-JEPA is experimental, there might be breaking changes + in the future. + + - [0]: Joint-Embedding Predictive Architecture, 2023, https://arxiv.org/abs/2301.08243 + - [1]: https://github.com/facebookresearch/ijepa + + Attributes: + input_size: + Size of the input image in pixels. + min_scale: + Minimum size of the randomized crop relative to the input_size. + normalize: + Dictionary with 'mean' and 'std' for torchvision.transforms.Normalize. + + """ + + def __init__( + self, + input_size: Union[int, Tuple[int, int]] = 224, + min_scale: float = 0.2, + normalize: dict = IMAGENET_NORMALIZE, + ): + transforms = [ + T.RandomResizedCrop( + input_size, scale=(min_scale, 1.0), interpolation=3 + ), # 3 is bicubic + T.RandomHorizontalFlip(), + T.ToTensor(), + ] + if normalize: + transforms.append(T.Normalize(mean=normalize["mean"], std=normalize["std"])) + + self.transform = T.Compose(transforms) + + def __call__(self, image: Union[Tensor, Image]) -> Tensor: + """Applies the transforms to the input image. + + Args: + image: + The input image to apply the transforms to. + + Returns: + The transformed image. + + """ + return self.transform(image) diff --git a/lightly/transforms/mae_transform.py b/lightly/transforms/mae_transform.py index 0510ca8d5..3176f084e 100644 --- a/lightly/transforms/mae_transform.py +++ b/lightly/transforms/mae_transform.py @@ -11,6 +11,16 @@ class MAETransform: """Implements the view augmentation for MAE [0]. + Input to this transform: + PIL Image or Tensor. + + Output of this transform: + List of Tensor of length 1. + + Applies the following augmentations by default: + - Random resized crop + - Random horizontal flip + - [0]: Masked Autoencoder, 2021, https://arxiv.org/abs/2111.06377 Attributes: diff --git a/lightly/transforms/moco_transform.py b/lightly/transforms/moco_transform.py index c60e97bf1..8ec8ade55 100644 --- a/lightly/transforms/moco_transform.py +++ b/lightly/transforms/moco_transform.py @@ -7,6 +7,19 @@ class MoCoV1Transform(SimCLRTransform): """Implements the transformations for MoCo v1. + Input to this transform: + PIL Image or Tensor. + + Output of this transform: + List of Tensor of length 2. + + Applies the following augmentations by default: + - Random resized crop + - Random horizontal flip + - Color jitter + - Random gray scale + - ImageNet normalization + Attributes: input_size: Size of the input image in pixels. @@ -98,6 +111,20 @@ class MoCoV2Transform(SimCLRTransform): Identical to SimCLRTransform. + Input to this transform: + PIL Image or Tensor. + + Output of this transform: + List of Tensor of length 2. + + Applies the following augmentations by default: + - Random resized crop + - Random horizontal flip + - Color jitter + - Random gray scale + - Gaussian blur + - ImageNet normalization + - [0]: MoCo v2, 2020, https://arxiv.org/abs/2003.04297 Attributes: diff --git a/lightly/transforms/msn_transform.py b/lightly/transforms/msn_transform.py index 259f7e1ba..d07130732 100644 --- a/lightly/transforms/msn_transform.py +++ b/lightly/transforms/msn_transform.py @@ -12,6 +12,20 @@ class MSNTransform(MultiViewTransform): """Implements the transformations for MSN [0]. + Input to this transform: + PIL Image or Tensor. + + Output of this transform: + List of Tensor of length 2 * random_views + focal_views. (12 by default) + + Applies the following augmentations by default: + - Random resized crop + - Random horizontal flip + - Color jitter + - Random gray scale + - Gaussian blur + - ImageNet normalization + Generates a set of random and focal views for each input image. The generated output is (views, target, filenames) where views is list with the following entries: [random_views_0, random_views_1, ..., focal_views_0, focal_views_1, ...]. diff --git a/lightly/transforms/multi_crop_transform.py b/lightly/transforms/multi_crop_transform.py index 76390b589..307a507c6 100644 --- a/lightly/transforms/multi_crop_transform.py +++ b/lightly/transforms/multi_crop_transform.py @@ -8,6 +8,16 @@ class MultiCropTranform(MultiViewTransform): """Implements the multi-crop transformations. Used by Swav. + Input to this transform: + PIL Image or Tensor. + + Output of this transform: + List of Tensor of length crop_counts. + + Applies the following augmentations by default: + - Random resized crop + - transforms passed by constructor + Attributes: crop_sizes: Size of the input image in pixels for each crop category. diff --git a/lightly/transforms/pirl_transform.py b/lightly/transforms/pirl_transform.py index d41476fb3..b67e451c7 100644 --- a/lightly/transforms/pirl_transform.py +++ b/lightly/transforms/pirl_transform.py @@ -14,6 +14,19 @@ class PIRLTransform(MultiViewTransform): """Implements the transformations for PIRL [0]. The jigsaw augmentation is applied during the forward pass. + Input to this transform: + PIL Image or Tensor. + + Output of this transform: + List of Tensor of length 2 (original, augmented). + + Applies the following augmentations by default: + - Random resized crop + - Random horizontal flip + - Color jitter + - Random gray scale + - Jigsaw puzzle + - [0] PIRL, 2019: https://arxiv.org/abs/1912.01991 Attributes: diff --git a/lightly/transforms/simclr_transform.py b/lightly/transforms/simclr_transform.py index c93a04988..8c39591c7 100644 --- a/lightly/transforms/simclr_transform.py +++ b/lightly/transforms/simclr_transform.py @@ -13,11 +13,31 @@ class SimCLRTransform(MultiViewTransform): """Implements the transformations for SimCLR [0, 1]. + Input to this transform: + PIL Image or Tensor. + + Output of this transform: + List of Tensor of length 2. + + Applies the following augmentations by default: + - Random resized crop + - Random horizontal flip + - Color jitter + - Random gray scale + - Gaussian blur + - ImageNet normalization + Note that SimCLR v1 and v2 use the same data augmentations. - [0]: SimCLR v1, 2020, https://arxiv.org/abs/2002.05709 - [1]: SimCLR v2, 2020, https://arxiv.org/abs/2006.10029 + Input to this transform: + PIL Image or Tensor. + + Output of this transform: + List of [tensor, tensor]. + Attributes: input_size: Size of the input image in pixels. diff --git a/lightly/transforms/simsiam_transform.py b/lightly/transforms/simsiam_transform.py index fad3bba8d..379d55242 100644 --- a/lightly/transforms/simsiam_transform.py +++ b/lightly/transforms/simsiam_transform.py @@ -13,6 +13,20 @@ class SimSiamTransform(MultiViewTransform): """Implements the transformations for SimSiam. + Input to this transform: + PIL Image or Tensor. + + Output of this transform: + List of Tensor of length 2. + + Applies the following augmentations by default: + - Random resized crop + - Random horizontal flip + - Color jitter + - Random gray scale + - Gaussian blur + - ImageNet normalization + Attributes: input_size: Size of the input image in pixels. diff --git a/lightly/transforms/smog_transform.py b/lightly/transforms/smog_transform.py index 96237b3a1..6f8958377 100644 --- a/lightly/transforms/smog_transform.py +++ b/lightly/transforms/smog_transform.py @@ -13,6 +13,21 @@ class SMoGTransform(MultiViewTransform): """Implements the transformations for SMoG. + Input to this transform: + PIL Image or Tensor. + + Output of this transform: + List of Tensor of length sum(crop_counts). (8 by default) + + Applies the following augmentations by default: + - Random resized crop + - Random horizontal flip + - Color jitter + - Random gray scale + - Gaussian blur + - Random solarization + - ImageNet normalization + Attributes: crop_sizes: Size of the input image in pixels for each crop category. diff --git a/lightly/transforms/swav_transform.py b/lightly/transforms/swav_transform.py index c523e7c82..2cbabf94f 100644 --- a/lightly/transforms/swav_transform.py +++ b/lightly/transforms/swav_transform.py @@ -13,6 +13,20 @@ class SwaVTransform(MultiCropTranform): """Implements the multi-crop transformations for SwaV. + Input to this transform: + PIL Image or Tensor. + + Output of this transform: + List of Tensor of length sum(crop_counts). (8 by default) + + Applies the following augmentations by default: + - Random resized crop + - Random horizontal flip + - Color jitter + - Random gray scale + - Gaussian blur + - ImageNet normalization + Attributes: crop_sizes: Size of the input image in pixels for each crop category. diff --git a/lightly/transforms/vicreg_transform.py b/lightly/transforms/vicreg_transform.py index a854eeb15..e2234f6c4 100644 --- a/lightly/transforms/vicreg_transform.py +++ b/lightly/transforms/vicreg_transform.py @@ -14,6 +14,21 @@ class VICRegTransform(MultiViewTransform): """Implements the transformations for VICReg. + Input to this transform: + PIL Image or Tensor. + + Output of this transform: + List of Tensor of length 2. + + Applies the following augmentations by default: + - Random resized crop + - Random horizontal flip + - Color jitter + - Random gray scale + - Random solarization + - Gaussian blur + - ImageNet normalization + Similar to SimCLR transform but with extra solarization. Attributes: diff --git a/lightly/transforms/vicregl_transform.py b/lightly/transforms/vicregl_transform.py index 2b087ceda..70baf139d 100644 --- a/lightly/transforms/vicregl_transform.py +++ b/lightly/transforms/vicregl_transform.py @@ -14,6 +14,21 @@ class VICRegLTransform(ImageGridTransform): """Transforms images for VICRegL. + Input to this transform: + PIL Image or Tensor. + + Output of this transform: + List of Tensor of length n_global_views + n_local_views. (8 by default) + + Applies the following augmentations by default: + - Random resized crop + - Random horizontal flip + - Color jitter + - Random gray scale + - Gaussian blur + - Random solarization + - ImageNet normalization + - [0]: VICRegL, 2022, https://arxiv.org/abs/2210.01571 Attributes: diff --git a/tests/api/test_utils.py b/tests/api/test_utils.py index 57a6f1f9d..5be42f030 100644 --- a/tests/api/test_utils.py +++ b/tests/api/test_utils.py @@ -96,7 +96,7 @@ def test_get_lightly_server_location_from_env(self): def test_paginate_endpoint(self): def some_function(page_size=8, page_offset=0): if page_offset > 3 * page_size: - return [] + assert False # should not happen elif page_offset > 2 * page_size: return (page_size - 1) * ["a"] else: @@ -108,6 +108,38 @@ def some_function(page_size=8, page_offset=0): self.assertEqual((4 * page_size - 1) * ["a"], some_list) self.assertEqual(len(some_list), (4 * page_size - 1)) + def test_paginate_endpoint__string(self): + def paginated_function(page_size=8, page_offset=0): + """Returns one page of size page_size, then one page of size page_size - 1.""" + if page_offset > 3 * page_size: + assert False # This should not happen. + elif page_offset > 2 * page_size: + return (page_size - 1) * "a" + else: + return page_size * "a" + + page_size = 8 + some_iterator = paginate_endpoint(paginated_function, page_size=page_size) + some_list = list(some_iterator) + self.assertEqual((4 * page_size - 1) * "a", "".join(some_list)) + self.assertEqual(len(some_list), 4) # Expect four pages of strings. + + def test_paginate_endpoint__multiple_of_page_size(self): + def paginated_function(page_size=8, page_offset=0): + """Returns two pages of size page_size, then an empty page.""" + if page_offset > 3 * page_size: + return [] + elif page_offset > 2 * page_size: + return page_size * ["a"] + else: + return page_size * ["a"] + + page_size = 8 + some_iterator = paginate_endpoint(paginated_function, page_size=page_size) + some_list = list(some_iterator) + self.assertEqual((4 * page_size) * ["a"], some_list) + self.assertEqual(len(some_list), (4 * page_size)) + def test_paginate_endpoint_empty(self): def some_function(page_size=8, page_offset=0): return [] diff --git a/tests/api_workflow/__init__.py b/tests/api_workflow/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/api_workflow/mocked_api_workflow_client.py b/tests/api_workflow/mocked_api_workflow_client.py index 162526ac4..f61d9d75e 100644 --- a/tests/api_workflow/mocked_api_workflow_client.py +++ b/tests/api_workflow/mocked_api_workflow_client.py @@ -83,7 +83,7 @@ WriteCSVUrlData, ) from lightly.openapi_generated.swagger_client.rest import ApiException -from tests.api_workflow.utils import generate_id +from tests.api_workflow import utils def _check_dataset_id(dataset_id: str): @@ -99,13 +99,13 @@ def __init__(self, api_client): EmbeddingsApi.__init__(self, api_client=api_client) self.embeddings = [ DatasetEmbeddingData( - id=generate_id(), + id=utils.generate_id(), name="embedding_newest", is_processed=True, created_at=1111111, ), DatasetEmbeddingData( - id=generate_id(), + id=utils.generate_id(), name="default", is_processed=True, created_at=0, @@ -116,7 +116,7 @@ def get_embeddings_csv_write_url_by_id(self, dataset_id: str, **kwargs): _check_dataset_id(dataset_id) assert isinstance(dataset_id, str) response_ = WriteCSVUrlData( - signed_write_url="signed_write_url_valid", embedding_id=generate_id() + signed_write_url="signed_write_url_valid", embedding_id=utils.generate_id() ) return response_ @@ -189,7 +189,7 @@ def create_initial_tag_by_dataset_id( _check_dataset_id(dataset_id) assert isinstance(initial_tag_create_request, InitialTagCreateRequest) assert isinstance(dataset_id, str) - response_ = CreateEntityResponse(id=generate_id()) + response_ = CreateEntityResponse(id=utils.generate_id()) return response_ def get_tag_by_tag_id(self, dataset_id, tag_id, **kwargs): @@ -199,7 +199,7 @@ def get_tag_by_tag_id(self, dataset_id, tag_id, **kwargs): response_ = TagData( id=tag_id, dataset_id=dataset_id, - prev_tag_id=generate_id(), + prev_tag_id=utils.generate_id(), bit_mask_data="0x80bda23e9", name="second-tag", tot_size=15, @@ -212,7 +212,7 @@ def get_tags_by_dataset_id(self, dataset_id, **kwargs): _check_dataset_id(dataset_id) tag_1 = TagData( - id=generate_id(), + id=utils.generate_id(), dataset_id=dataset_id, prev_tag_id=None, bit_mask_data="0xf", @@ -222,7 +222,7 @@ def get_tags_by_dataset_id(self, dataset_id, **kwargs): changes=[], ) tag_2 = TagData( - id=generate_id(), + id=utils.generate_id(), dataset_id=dataset_id, prev_tag_id=tag_1.id, bit_mask_data="0xf", @@ -232,7 +232,7 @@ def get_tags_by_dataset_id(self, dataset_id, **kwargs): changes=[], ) tag_3 = TagData( - id=generate_id(), + id=utils.generate_id(), dataset_id=dataset_id, prev_tag_id=tag_1.id, bit_mask_data="0x1", @@ -242,7 +242,7 @@ def get_tags_by_dataset_id(self, dataset_id, **kwargs): changes=[], ) tag_4 = TagData( - id=generate_id(), + id=utils.generate_id(), dataset_id=dataset_id, prev_tag_id=tag_3.id, bit_mask_data="0x3", @@ -252,7 +252,7 @@ def get_tags_by_dataset_id(self, dataset_id, **kwargs): changes=[], ) tag_5 = TagData( - id=generate_id(), + id=utils.generate_id(), dataset_id=dataset_id, prev_tag_id=None, bit_mask_data="0x1", @@ -295,7 +295,7 @@ def create_tag_by_dataset_id( ) -> TagData: _check_dataset_id(dataset_id) tag = TagData( - id=generate_id(), + id=utils.generate_id(), dataset_id=dataset_id, prev_tag_id=tag_create_request["prev_tag_id"], bit_mask_data=tag_create_request["bit_mask_data"], @@ -405,6 +405,8 @@ def export_tag_to_basic_filenames_and_read_urls( def export_tag_to_basic_filenames( self, dataset_id: str, tag_id: str, **kwargs ) -> str: + if kwargs["page_offset"] and kwargs["page_offset"] > 0: + return "" return """ IMG_2276_jpeg_jpg.rf.7411b1902c81bad8cdefd2cc4eb3a97b.jpg IMG_2285_jpeg_jpg.rf.4a93d99b9f0b6cccfb27bf2f4a13b99e.jpg @@ -520,7 +522,7 @@ def __init__(self, api_client): self._default_datasets = [ DatasetData( name=f"dataset_{i}", - id=generate_id(), + id=utils.generate_id(), last_modified_at=i, type="Images", img_type="full", @@ -534,7 +536,7 @@ def __init__(self, api_client): self._shared_datasets = [ DatasetData( name=f"shared_dataset_{i}", - id=generate_id(), + id=utils.generate_id(), last_modified_at=0, type="Images", img_type="full", @@ -557,11 +559,14 @@ def reset(self): def get_datasets( self, - shared: bool, + shared: bool = False, + get_assets_of_team: bool = False, page_size: Optional[int] = None, page_offset: Optional[int] = None, ): start, end = _start_and_end_offset(page_size=page_size, page_offset=page_offset) + if get_assets_of_team: + return [] if shared: return self.shared_datasets[start:end] else: @@ -569,7 +574,7 @@ def get_datasets( def create_dataset(self, dataset_create_request: DatasetCreateRequest, **kwargs): assert isinstance(dataset_create_request, DatasetCreateRequest) - id = generate_id() + id = utils.generate_id() if dataset_create_request.name == "xyz-no-tags": id = "xyz-no-tags" dataset = DatasetData( @@ -580,7 +585,7 @@ def create_dataset(self, dataset_create_request: DatasetCreateRequest, **kwargs) size_in_bytes=-1, n_samples=-1, created_at=-1, - user_id=generate_id(), + user_id=utils.generate_id(), ) self.datasets.append(dataset) response_ = CreateEntityResponse(id=id) @@ -615,9 +620,20 @@ def get_datasets_enriched(self, **kwargs): raise NotImplementedError() def get_datasets_query_by_name( - self, dataset_name: str, shared: bool, exact: bool + self, + dataset_name: str, + page_size: Optional[int] = None, + page_offset: Optional[int] = None, + shared: bool = False, + exact: bool = False, + get_assets_of_team: bool = False, ) -> List[DatasetData]: - datasets = self.get_datasets(shared=shared) + datasets = self.get_datasets( + shared=shared, + get_assets_of_team=get_assets_of_team, + page_size=page_size, + page_offset=page_offset, + ) if exact: return [dataset for dataset in datasets if dataset.name == dataset_name] else: @@ -810,10 +826,10 @@ def __init__(self, api_client=None): super().__init__(api_client=api_client) self._compute_worker_runs = [ DockerRunData( - id=generate_id(), + id=utils.generate_id(), user_id="user-id", docker_version="v1", - dataset_id=generate_id(), + dataset_id=utils.generate_id(), state=DockerRunState.TRAINING, created_at=0, last_modified_at=100, @@ -823,20 +839,20 @@ def __init__(self, api_client=None): ] self._scheduled_compute_worker_runs = [ DockerRunScheduledData( - id=generate_id(), - dataset_id=generate_id(), - config_id=generate_id(), + id=utils.generate_id(), + dataset_id=utils.generate_id(), + config_id=utils.generate_id(), priority=DockerRunScheduledPriority.MID, state=DockerRunScheduledState.OPEN, created_at=0, last_modified_at=100, - owner=generate_id(), + owner=utils.generate_id(), runs_on=[], ) ] self._registered_workers = [ DockerWorkerRegistryEntryData( - id=generate_id(), + id=utils.generate_id(), user_id="user-id", name="worker-name-1", worker_type=DockerWorkerType.FULL, @@ -849,18 +865,18 @@ def __init__(self, api_client=None): def register_docker_worker(self, body, **kwargs): assert isinstance(body, CreateDockerWorkerRegistryEntryRequest) - return CreateEntityResponse(id=generate_id()) + return CreateEntityResponse(id=utils.generate_id()) def get_docker_worker_registry_entries(self, **kwargs): return self._registered_workers def create_docker_worker_config(self, body, **kwargs): assert isinstance(body, DockerWorkerConfigCreateRequest) - return CreateEntityResponse(id=generate_id()) + return CreateEntityResponse(id=utils.generate_id()) def create_docker_worker_config_v3(self, body, **kwargs): assert isinstance(body, DockerWorkerConfigV3CreateRequest) - return CreateEntityResponse(id=generate_id()) + return CreateEntityResponse(id=utils.generate_id()) def create_docker_run_scheduled_by_dataset_id( self, docker_run_scheduled_create_request, dataset_id, **kwargs @@ -869,7 +885,7 @@ def create_docker_run_scheduled_by_dataset_id( docker_run_scheduled_create_request, DockerRunScheduledCreateRequest ) _check_dataset_id(dataset_id) - return CreateEntityResponse(id=generate_id()) + return CreateEntityResponse(id=utils.generate_id()) def get_docker_runs( self, @@ -996,11 +1012,11 @@ def create_or_update_shared_access_config_by_dataset_id( assert isinstance( shared_access_config_create_request, SharedAccessConfigCreateRequest ) - return CreateEntityResponse(id=generate_id()) + return CreateEntityResponse(id=utils.generate_id()) def get_shared_access_configs_by_dataset_id(self, dataset_id, **kwargs): write_config = SharedAccessConfigData( - id=generate_id(), + id=utils.generate_id(), owner="owner-id", users=["user1@gmail.com", "user2@something.com"], teams=["some-id"], diff --git a/tests/api_workflow/test_api_workflow.py b/tests/api_workflow/test_api_workflow.py index b5a437de4..1f5d3fac9 100644 --- a/tests/api_workflow/test_api_workflow.py +++ b/tests/api_workflow/test_api_workflow.py @@ -4,11 +4,11 @@ import numpy as np import lightly +from tests.api_workflow import utils from tests.api_workflow.mocked_api_workflow_client import ( MockedApiWorkflowClient, MockedApiWorkflowSetup, ) -from tests.api_workflow.utils import generate_id class TestApiWorkflow(MockedApiWorkflowSetup): @@ -44,7 +44,7 @@ def test_dataset_id_nonexisting(self): assert dataset_id == self.api_workflow_client._datasets_api.datasets[-1].id def test_dataset_id_existing(self): - id = generate_id() + id = utils.generate_id() self.api_workflow_client._dataset_id = id assert self.api_workflow_client.dataset_id == id diff --git a/tests/api_workflow/test_api_workflow_artifacts.py b/tests/api_workflow/test_api_workflow_artifacts.py index a13240d14..685aa28cf 100644 --- a/tests/api_workflow/test_api_workflow_artifacts.py +++ b/tests/api_workflow/test_api_workflow_artifacts.py @@ -9,7 +9,7 @@ DockerRunData, DockerRunState, ) -from tests.api_workflow.utils import generate_id +from tests.api_workflow import utils def test_download_compute_worker_run_artifacts(mocker: MockerFixture) -> None: @@ -20,12 +20,12 @@ def test_download_compute_worker_run_artifacts(mocker: MockerFixture) -> None: client._download_compute_worker_run_artifact = ( mock_download_compute_worker_run_artifact ) - run_id = generate_id() - artifact_ids = [generate_id(), generate_id()] + run_id = utils.generate_id() + artifact_ids = [utils.generate_id(), utils.generate_id()] run = DockerRunData( id=run_id, user_id="user-id", - dataset_id=generate_id(), + dataset_id=utils.generate_id(), docker_version="", state=DockerRunState.COMPUTING_METADATA, created_at=0, @@ -72,12 +72,12 @@ def test__download_compute_worker_run_artifact_by_type( client._download_compute_worker_run_artifact = ( mock_download_compute_worker_run_artifact ) - run_id = generate_id() - artifact_ids = [generate_id(), generate_id()] + run_id = utils.generate_id() + artifact_ids = [utils.generate_id(), utils.generate_id()] run = DockerRunData( id=run_id, user_id="user-id", - dataset_id=generate_id(), + dataset_id=utils.generate_id(), docker_version="", state=DockerRunState.COMPUTING_METADATA, created_at=0, @@ -120,9 +120,9 @@ def test__download_compute_worker_run_artifact_by_type__no_artifacts( mock_download_compute_worker_run_artifact ) run = DockerRunData( - id=generate_id(), + id=utils.generate_id(), user_id="user-id", - dataset_id=generate_id(), + dataset_id=utils.generate_id(), docker_version="", state=DockerRunState.COMPUTING_METADATA, created_at=0, @@ -149,16 +149,16 @@ def test__download_compute_worker_run_artifact_by_type__no_artifact_with_type( mock_download_compute_worker_run_artifact ) run = DockerRunData( - id=generate_id(), + id=utils.generate_id(), user_id="user-id", - dataset_id=generate_id(), + dataset_id=utils.generate_id(), docker_version="", state=DockerRunState.COMPUTING_METADATA, created_at=0, last_modified_at=0, artifacts=[ DockerRunArtifactData( - id=generate_id(), + id=utils.generate_id(), file_name="report.pdf", type=DockerRunArtifactType.REPORT_PDF, ), @@ -178,7 +178,7 @@ def test__get_compute_worker_run_checkpoint_url( ) -> None: mocked_client = mocker.MagicMock(spec=ApiWorkflowClient) mocked_artifact = DockerRunArtifactData( - id=generate_id(), + id=utils.generate_id(), file_name="report.pdf", type=DockerRunArtifactType.REPORT_PDF, ) @@ -189,9 +189,9 @@ def test__get_compute_worker_run_checkpoint_url( ) run = DockerRunData( - id=generate_id(), + id=utils.generate_id(), user_id="user-id", - dataset_id=generate_id(), + dataset_id=utils.generate_id(), docker_version="", state=DockerRunState.COMPUTING_METADATA, created_at=0, diff --git a/tests/api_workflow/test_api_workflow_client.py b/tests/api_workflow/test_api_workflow_client.py index 0d973028f..48920077e 100644 --- a/tests/api_workflow/test_api_workflow_client.py +++ b/tests/api_workflow/test_api_workflow_client.py @@ -86,7 +86,9 @@ def raise_connection_error(*args, **kwargs): def test_user_agent_header(mocker: MockerFixture) -> None: mocker.patch.object(lightly.api.api_workflow_client, "__version__", new="VERSION") mocker.patch.object( - lightly.api.api_workflow_client, "is_compatible_version", new=lambda _: True + lightly.api.api_workflow_client.version_checking, + "is_compatible_version", + new=lambda _: True, ) mocked_platform = mocker.patch.object( lightly.api.api_workflow_client, "platform", spec_set=platform diff --git a/tests/api_workflow/test_api_workflow_collaboration.py b/tests/api_workflow/test_api_workflow_collaboration.py index 0fd18bc0f..cb1080456 100644 --- a/tests/api_workflow/test_api_workflow_collaboration.py +++ b/tests/api_workflow/test_api_workflow_collaboration.py @@ -1,8 +1,8 @@ +from tests.api_workflow import utils from tests.api_workflow.mocked_api_workflow_client import ( MockedApiWorkflowClient, MockedApiWorkflowSetup, ) -from tests.api_workflow.utils import generate_id class TestApiWorkflowDatasets(MockedApiWorkflowSetup): @@ -11,16 +11,16 @@ def setUp(self) -> None: def test_share_empty_dataset(self): self.api_workflow_client.share_dataset_only_with( - dataset_id=generate_id(), user_emails=[] + dataset_id=utils.generate_id(), user_emails=[] ) def test_share_dataset(self): self.api_workflow_client.share_dataset_only_with( - dataset_id=generate_id(), user_emails=["someone@something.com"] + dataset_id=utils.generate_id(), user_emails=["someone@something.com"] ) def test_get_shared_users(self): user_emails = self.api_workflow_client.get_shared_users( - dataset_id=generate_id() + dataset_id=utils.generate_id() ) assert user_emails == ["user1@gmail.com", "user2@something.com"] diff --git a/tests/api_workflow/test_api_workflow_compute_worker.py b/tests/api_workflow/test_api_workflow_compute_worker.py index 3260b4d7f..09efdd8a2 100644 --- a/tests/api_workflow/test_api_workflow_compute_worker.py +++ b/tests/api_workflow/test_api_workflow_compute_worker.py @@ -44,8 +44,8 @@ TagData, ) from lightly.openapi_generated.swagger_client.rest import ApiException +from tests.api_workflow import utils from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup -from tests.api_workflow.utils import generate_id class TestApiWorkflowComputeWorker(MockedApiWorkflowSetup): @@ -81,7 +81,7 @@ def test_create_compute_worker_config(self): { "input": { "type": "EMBEDDINGS", - "dataset_id": generate_id(), + "dataset_id": utils.generate_id(), "tag_name": "some-tag-name", }, "strategy": {"type": "SIMILARITY"}, @@ -107,7 +107,7 @@ def test_create_compute_worker_config__selection_config_is_class(self) -> None: SelectionConfigEntry( input=SelectionConfigEntryInput( type=SelectionInputType.EMBEDDINGS, - dataset_id=generate_id(), + dataset_id=utils.generate_id(), tag_name="some-tag-name", ), strategy=SelectionConfigEntryStrategy( @@ -254,7 +254,7 @@ def test_selection_config(self): def test_selection_config_from_dict() -> None: - dataset_id = generate_id() + dataset_id = utils.generate_id() cfg = { "n_samples": 10, "proportion_samples": 0.1, @@ -352,13 +352,24 @@ def test_selection_config_from_dict__extra_strategy_strategy_key() -> None: api_workflow_compute_worker.selection_config_from_dict(cfg) +def test_selection_config_from_dict__multiple_references() -> None: + """Test that conversion is successful if the dictionary contains multiple references + to the same object. + """ + strategy = {"input": {"type": "EMBEDDINGS"}, "strategy": {"type": "DIVERSITY"}} + cfg = {"strategies": [strategy, strategy]} + selection_cfg = api_workflow_compute_worker.selection_config_from_dict(cfg) + assert len(selection_cfg.strategies) == 2 + assert selection_cfg.strategies[0] == selection_cfg.strategies[1] + + def test_get_scheduled_run_by_id() -> None: - run_ids = [generate_id() for _ in range(3)] + run_ids = [utils.generate_id() for _ in range(3)] scheduled_runs = [ DockerRunScheduledData( id=run_id, - dataset_id=generate_id(), - config_id=generate_id(), + dataset_id=utils.generate_id(), + config_id=utils.generate_id(), priority=DockerRunScheduledPriority.MID, state=DockerRunScheduledState.OPEN, created_at=0, @@ -384,9 +395,9 @@ def test_get_scheduled_run_by_id() -> None: def test_get_scheduled_run_by_id_not_found() -> None: scheduled_runs = [ DockerRunScheduledData( - id=generate_id(), - dataset_id=generate_id(), - config_id=generate_id(), + id=utils.generate_id(), + dataset_id=utils.generate_id(), + config_id=utils.generate_id(), priority=DockerRunScheduledPriority.LOW, state=DockerRunScheduledState.OPEN, created_at=0, @@ -413,11 +424,11 @@ def test_get_scheduled_run_by_id_not_found() -> None: def test_get_compute_worker_state_and_message_OPEN() -> None: - dataset_id = generate_id() + dataset_id = utils.generate_id() scheduled_run = DockerRunScheduledData( - id=generate_id(), + id=utils.generate_id(), dataset_id=dataset_id, - config_id=generate_id(), + config_id=utils.generate_id(), priority=DockerRunScheduledPriority.MID, state=DockerRunScheduledState.OPEN, created_at=0, @@ -444,12 +455,104 @@ def mocked_raise_exception(*args, **kwargs): assert run_info.in_end_state() == False +def test_create_docker_worker_config_v3_api_error() -> None: + class HttpThing: + def __init__(self, status, reason, data): + self.status = status + self.reason = reason + self.data = data + + def getheaders(self): + return [] + + def mocked_raise_exception(*args, **kwargs): + raise ApiException( + http_resp=HttpThing( + 403, + "Not everything has a reason", + '{"code": "ACCOUNT_SUBSCRIPTION_INSUFFICIENT", "error": "Your current plan allows for 1000000 samples but you tried to use 2000000 samples, please contact sales at sales@lightly.ai to upgrade your account."}', + ) + ) + + client = ApiWorkflowClient(token="123") + client._dataset_id = utils.generate_id() + client._compute_worker_api.create_docker_worker_config_v3 = mocked_raise_exception + with pytest.raises( + ValueError, + match=r'Trying to schedule your job resulted in\n>> ACCOUNT_SUBSCRIPTION_INSUFFICIENT\n>> "Your current plan allows for 1000000 samples but you tried to use 2000000 samples, please contact sales at sales@lightly.ai to upgrade your account."\n>> Please fix the issue mentioned above and see our docs https://docs.lightly.ai/docs/all-configuration-options for more help.', + ): + r = client.create_compute_worker_config( + selection_config={ + "n_samples": 2000000, + "strategies": [ + {"input": {"type": "EMBEDDINGS"}, "strategy": {"type": "DIVERSITY"}} + ], + }, + ) + + +def test_create_docker_worker_config_v3_5xx_api_error() -> None: + class HttpThing: + def __init__(self, status, reason, data): + self.status = status + self.reason = reason + self.data = data + + def getheaders(self): + return [] + + def mocked_raise_exception(*args, **kwargs): + raise ApiException( + http_resp=HttpThing( + 502, + "Not everything has a reason", + '{"code": "SOMETHING_BAD", "error": "Server pains"}', + ) + ) + + client = ApiWorkflowClient(token="123") + client._dataset_id = utils.generate_id() + client._compute_worker_api.create_docker_worker_config_v3 = mocked_raise_exception + with pytest.raises( + ApiException, + match=r"Server pains", + ): + r = client.create_compute_worker_config( + selection_config={ + "n_samples": 20, + "strategies": [ + {"input": {"type": "EMBEDDINGS"}, "strategy": {"type": "DIVERSITY"}} + ], + }, + ) + + +def test_create_docker_worker_config_v3_no_body_api_error() -> None: + def mocked_raise_exception(*args, **kwargs): + raise ApiException + + client = ApiWorkflowClient(token="123") + client._dataset_id = utils.generate_id() + client._compute_worker_api.create_docker_worker_config_v3 = mocked_raise_exception + with pytest.raises( + ApiException, + ): + r = client.create_compute_worker_config( + selection_config={ + "n_samples": 20, + "strategies": [ + {"input": {"type": "EMBEDDINGS"}, "strategy": {"type": "DIVERSITY"}} + ], + }, + ) + + def test_get_compute_worker_state_and_message_CANCELED() -> None: def mocked_raise_exception(*args, **kwargs): raise ApiException mocked_api_client = MagicMock( - dataset_id=generate_id(), + dataset_id=utils.generate_id(), _compute_worker_api=MagicMock( get_docker_run_by_scheduled_id=mocked_raise_exception ), @@ -466,7 +569,7 @@ def mocked_raise_exception(*args, **kwargs): def test_get_compute_worker_state_and_message_docker_state() -> None: message = "SOME_MESSAGE" docker_run = DockerRunData( - id=generate_id(), + id=utils.generate_id(), user_id="user-id", state=DockerRunState.GENERATING_REPORT, docker_version="", @@ -475,14 +578,14 @@ def test_get_compute_worker_state_and_message_docker_state() -> None: message=message, ) mocked_api_client = MagicMock( - dataset_id=generate_id(), + dataset_id=utils.generate_id(), _compute_worker_api=MagicMock( get_docker_run_by_scheduled_id=lambda scheduled_id: docker_run ), ) run_info = ApiWorkflowClient.get_compute_worker_run_info( - self=mocked_api_client, scheduled_run_id=generate_id() + self=mocked_api_client, scheduled_run_id=utils.generate_id() ) assert run_info.state == DockerRunState.GENERATING_REPORT assert run_info.message == message @@ -523,8 +626,8 @@ def get_compute_worker_run_info(self, scheduled_run_id: str): def test_get_compute_worker_runs(mocker: MockerFixture) -> None: mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) - dataset_id = generate_id() - run_ids = [generate_id(), generate_id()] + dataset_id = utils.generate_id() + run_ids = [utils.generate_id(), utils.generate_id()] client = ApiWorkflowClient(token="123") mock_compute_worker_api = mocker.create_autospec( DockerApi, spec_set=True @@ -550,7 +653,6 @@ def test_get_compute_worker_runs(mocker: MockerFixture) -> None: last_modified_at=0, ), ], - [], ] client._compute_worker_api = mock_compute_worker_api runs = client.get_compute_worker_runs() @@ -574,13 +676,15 @@ def test_get_compute_worker_runs(mocker: MockerFixture) -> None: last_modified_at=0, ), ] - assert mock_compute_worker_api.get_docker_runs.call_count == 2 + mock_compute_worker_api.get_docker_runs.assert_called_once_with( + page_offset=0, page_size=5000 + ) def test_get_compute_worker_runs__dataset(mocker: MockerFixture) -> None: mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) - dataset_id = generate_id() - run_id = generate_id() + dataset_id = utils.generate_id() + run_id = utils.generate_id() client = ApiWorkflowClient(token="123") mock_compute_worker_api = mocker.create_autospec( DockerApi, spec_set=True @@ -613,13 +717,15 @@ def test_get_compute_worker_runs__dataset(mocker: MockerFixture) -> None: last_modified_at=0, ), ] - assert mock_compute_worker_api.get_docker_runs_query_by_dataset_id.call_count == 2 + mock_compute_worker_api.get_docker_runs_query_by_dataset_id.assert_called_once_with( + page_offset=0, page_size=5000, dataset_id=dataset_id + ) def test_get_compute_worker_run_tags__no_tags(mocker: MockerFixture) -> None: mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) - run_id = generate_id() - client = ApiWorkflowClient(token="123", dataset_id=generate_id()) + run_id = utils.generate_id() + client = ApiWorkflowClient(token="123", dataset_id=utils.generate_id()) mock_compute_worker_api = mocker.create_autospec( DockerApi, spec_set=True ).return_value @@ -631,8 +737,8 @@ def test_get_compute_worker_run_tags__no_tags(mocker: MockerFixture) -> None: def test_get_compute_worker_run_tags__single_tag(mocker: MockerFixture) -> None: - dataset_id = generate_id() - run_id = generate_id() + dataset_id = utils.generate_id() + run_id = utils.generate_id() mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) client = ApiWorkflowClient(token="123", dataset_id=dataset_id) client._dataset_id = dataset_id @@ -641,7 +747,7 @@ def test_get_compute_worker_run_tags__single_tag(mocker: MockerFixture) -> None: ).return_value mock_compute_worker_api.get_docker_run_tags.return_value = [ TagData( - id=generate_id(), + id=utils.generate_id(), dataset_id=dataset_id, prev_tag_id=None, bit_mask_data="0x1", @@ -660,15 +766,15 @@ def test_get_compute_worker_run_tags__single_tag(mocker: MockerFixture) -> None: def test_get_compute_worker_run_tags__multiple_tags(mocker: MockerFixture) -> None: mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) - run_id = generate_id() - dataset_id = generate_id() + run_id = utils.generate_id() + dataset_id = utils.generate_id() client = ApiWorkflowClient(token="123", dataset_id=dataset_id) client._dataset_id = dataset_id mock_compute_worker_api = mocker.create_autospec( DockerApi, spec_set=True ).return_value - tag_ids = [generate_id() for _ in range(3)] + tag_ids = [utils.generate_id() for _ in range(3)] tag_0 = TagData( id=tag_ids[0], dataset_id=dataset_id, @@ -694,7 +800,7 @@ def test_get_compute_worker_run_tags__multiple_tags(mocker: MockerFixture) -> No # tag from a different dataset tag_2 = TagData( id=tag_ids[2], - dataset_id=generate_id(), + dataset_id=utils.generate_id(), prev_tag_id=None, bit_mask_data="0x1", name="tag-2", diff --git a/tests/api_workflow/test_api_workflow_datasets.py b/tests/api_workflow/test_api_workflow_datasets.py index c3e1e282d..dab2ad39d 100644 --- a/tests/api_workflow/test_api_workflow_datasets.py +++ b/tests/api_workflow/test_api_workflow_datasets.py @@ -12,15 +12,15 @@ DatasetType, ) from lightly.openapi_generated.swagger_client.rest import ApiException +from tests.api_workflow import utils from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup -from tests.api_workflow.utils import generate_id def _get_datasets(count: int) -> List[DatasetData]: return [ DatasetData( name=f"mock_dataset_{i}", - id=generate_id(), + id=utils.generate_id(), last_modified_at=0, type=DatasetType.IMAGES, img_type="full", @@ -170,6 +170,8 @@ def test_create_new_dataset_with_unique_name__name_exists( dataset_name=dataset_name, exact=False, shared=False, + page_offset=0, + page_size=5000, ) mocked_create_dataset.assert_called_once_with( dataset_name=actual_dataset_name, @@ -226,17 +228,27 @@ def test_delete_dataset(mocker: MockerFixture) -> None: def test_get_datasets__shared(mocker: MockerFixture) -> None: + datasets = _get_datasets(2) + # Returns the same set of datasets twice. API client should remove duplicates mocked_pagination = mocker.patch.object( - api_workflow_datasets.utils, "paginate_endpoint" + api_workflow_datasets.utils, + "paginate_endpoint", + side_effect=[datasets, datasets], ) mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) mock_datasets_api = mocker.MagicMock() client = ApiWorkflowClient() client._datasets_api = mock_datasets_api - client.get_datasets(shared=True) - mocked_pagination.assert_called_once_with( - mock_datasets_api.get_datasets, shared=True - ) + datasets = client.get_datasets(shared=True) + unique_dataset_ids = set([dataset.id for dataset in datasets]) + assert len(unique_dataset_ids) == len(datasets) + + assert mocked_pagination.call_count == 2 + call_args = mocked_pagination.call_args_list + assert call_args[0][0] == (mock_datasets_api.get_datasets,) + assert call_args[0][1] == {"shared": True} + assert call_args[1][0] == (mock_datasets_api.get_datasets,) + assert call_args[1][1] == {"get_assets_of_team": True} def test_get_datasets__not_shared(mocker: MockerFixture) -> None: @@ -262,7 +274,128 @@ def test_get_datasets__shared_None(mocker: MockerFixture) -> None: client = ApiWorkflowClient() client._datasets_api = mock_datasets_api client.get_datasets(shared=None) - assert mocked_pagination.call_count == 2 + assert mocked_pagination.call_count == 3 + + +def test_get_datasets_by_name__not_shared__paginated(mocker: MockerFixture) -> None: + datasets = _get_datasets(3) + # Returns the same set of datasets twice. API client should remove duplicates. + mocked_paginate_endpoint = mocker.patch.object( + api_workflow_datasets.utils, + "paginate_endpoint", + # There's one call to paginate_endpoint. + # It returns a paginated list of datasets. + return_value=iter([datasets[0], datasets[1]]), + ) + mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) + mock_datasets_api = mocker.MagicMock() + client = ApiWorkflowClient() + client._datasets_api = mock_datasets_api + + # Note: because the `dataset_name` filtering is mocked away in this test, + # the `dataset_name` passed as argument and in the returned dataset are independent. + datasets_not_shared = client.get_datasets_by_name( + shared=False, dataset_name="some_random_dataset_name" + ) + assert datasets_not_shared == [datasets[0], datasets[1]] + mocked_paginate_endpoint.assert_called_once_with( + mock_datasets_api.get_datasets_query_by_name, + dataset_name="some_random_dataset_name", + exact=True, + shared=False, + ) + + +def test_get_datasets_by_name__shared__paginated(mocker: MockerFixture) -> None: + datasets = _get_datasets(3) + # Returns the same set of datasets twice. API client should remove duplicates. + mocked_paginate_endpoint = mocker.patch.object( + api_workflow_datasets.utils, + "paginate_endpoint", + side_effect=[ + # There are two calls to paginate_endpoint to get all the team's datasets. + iter([datasets[2]]), + iter([]), + ], + ) + mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) + mock_datasets_api = mocker.MagicMock() + client = ApiWorkflowClient() + client._datasets_api = mock_datasets_api + + # Note: because the `dataset_name` filtering is mocked away in this test, + # the `dataset_name` passed as argument and in the returned dataset are independent. + datasets_shared = client.get_datasets_by_name( + shared=True, dataset_name="some_random_dataset_name" + ) + assert datasets_shared == [datasets[2]] + mocked_paginate_endpoint.assert_has_calls( + [ + mocker.call( + mock_datasets_api.get_datasets_query_by_name, + dataset_name="some_random_dataset_name", + exact=True, + shared=True, + ), + mocker.call( + mock_datasets_api.get_datasets_query_by_name, + dataset_name="some_random_dataset_name", + exact=True, + get_assets_of_team=True, + ), + ] + ) + + +def test_get_datasets_by_name__shared_None__paginated(mocker: MockerFixture) -> None: + datasets = _get_datasets(3) + # Returns the same set of datasets twice. API client should remove duplicates. + mocked_paginate_endpoint = mocker.patch.object( + api_workflow_datasets.utils, + "paginate_endpoint", + side_effect=[ + # There are three calls to paginate_endpoint. The first call + # gets all the user's datasets. The second and third calls get + # all the team's datasets. + # The first call returns a paginated list of datasets. + iter([datasets[0], datasets[1]]), + iter([datasets[2]]), + iter([]), + ], + ) + mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) + mock_datasets_api = mocker.MagicMock() + client = ApiWorkflowClient() + client._datasets_api = mock_datasets_api + + # Note: because the `dataset_name` filtering is mocked away in this test, + # the `dataset_name` passed as argument and in the returned dataset are independent. + datasets_shared_none = client.get_datasets_by_name( + shared=None, dataset_name="some_random_dataset_name" + ) + assert datasets_shared_none == [datasets[0], datasets[1], datasets[2]] + mocked_paginate_endpoint.assert_has_calls( + [ + mocker.call( + mock_datasets_api.get_datasets_query_by_name, + dataset_name="some_random_dataset_name", + exact=True, + shared=False, + ), + mocker.call( + mock_datasets_api.get_datasets_query_by_name, + dataset_name="some_random_dataset_name", + exact=True, + shared=True, + ), + mocker.call( + mock_datasets_api.get_datasets_query_by_name, + dataset_name="some_random_dataset_name", + exact=True, + get_assets_of_team=True, + ), + ] + ) def test_set_dataset_id__error(mocker: MockerFixture): diff --git a/tests/api_workflow/test_api_workflow_datasources.py b/tests/api_workflow/test_api_workflow_datasources.py index 414b8c12e..b83910975 100644 --- a/tests/api_workflow/test_api_workflow_datasources.py +++ b/tests/api_workflow/test_api_workflow_datasources.py @@ -10,6 +10,12 @@ DatasourceConfigS3DelegatedAccess, DatasourceRawSamplesDataRow, ) +from lightly.openapi_generated.swagger_client.models.datasource_config_verify_data import ( + DatasourceConfigVerifyData, +) +from lightly.openapi_generated.swagger_client.models.datasource_config_verify_data_errors import ( + DatasourceConfigVerifyDataErrors, +) def test__download_raw_files(mocker: MockerFixture) -> None: @@ -289,3 +295,48 @@ def test_update_processed_until_timestamp(mocker: MockerFixture) -> None: kwargs["datasource_processed_until_timestamp_request"].processed_until_timestamp == 10 ) + + +def test_list_datasource_permissions(mocker: MockerFixture) -> None: + client = ApiWorkflowClient(token="abc") + client._dataset_id = "dataset-id" + client._datasources_api.verify_datasource_by_dataset_id = mocker.MagicMock( + return_value=DatasourceConfigVerifyData( + canRead=True, + canWrite=True, + canList=False, + canOverwrite=True, + errors=None, + ), + ) + assert client.list_datasource_permissions() == { + "can_read": True, + "can_write": True, + "can_list": False, + "can_overwrite": True, + } + + +def test_list_datasource_permissions__error(mocker: MockerFixture) -> None: + client = ApiWorkflowClient(token="abc") + client._dataset_id = "dataset-id" + client._datasources_api.verify_datasource_by_dataset_id = mocker.MagicMock( + return_value=DatasourceConfigVerifyData( + canRead=True, + canWrite=True, + canList=False, + canOverwrite=True, + errors=DatasourceConfigVerifyDataErrors( + canRead=None, canWrite=None, canList="error message", canOverwrite=None + ), + ), + ) + assert client.list_datasource_permissions() == { + "can_read": True, + "can_write": True, + "can_list": False, + "can_overwrite": True, + "errors": { + "can_list": "error message", + }, + } diff --git a/tests/api_workflow/test_api_workflow_download_dataset.py b/tests/api_workflow/test_api_workflow_download_dataset.py index e4fe5385b..da695e0e9 100644 --- a/tests/api_workflow/test_api_workflow_download_dataset.py +++ b/tests/api_workflow/test_api_workflow_download_dataset.py @@ -9,18 +9,18 @@ ImageType, TagData, ) -from tests.api_workflow.utils import generate_id +from tests.api_workflow import utils def test_download_dataset__no_image(mocker: MockerFixture) -> None: - dataset_id = generate_id() + dataset_id = utils.generate_id() mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) mocked_api = mocker.MagicMock() mocked_get_dataset_by_id = mocker.MagicMock( return_value=DatasetData( name="dataset", id=dataset_id, - user_id=generate_id(), + user_id=utils.generate_id(), last_modified_at=0, type=DatasetType.IMAGES, img_type=ImageType.META, @@ -47,8 +47,8 @@ def test_download_dataset__tag_missing(mocker: MockerFixture) -> None: mocked_get_dataset_by_id = mocker.MagicMock( return_value=DatasetData( name="dataset", - id=generate_id(), - user_id=generate_id(), + id=utils.generate_id(), + user_id=utils.generate_id(), last_modified_at=0, type=DatasetType.IMAGES, img_type=ImageType.FULL, @@ -68,13 +68,13 @@ def test_download_dataset__tag_missing(mocker: MockerFixture) -> None: def test_download_dataset__ok(mocker: MockerFixture) -> None: - dataset_id = generate_id() + dataset_id = utils.generate_id() mocked_get_dataset_by_id = mocker.MagicMock( return_value=DatasetData( name="dataset", id=dataset_id, - user_id=generate_id(), + user_id=utils.generate_id(), last_modified_at=0, type=DatasetType.IMAGES, img_type=ImageType.FULL, @@ -106,7 +106,7 @@ def test_download_dataset__ok(mocker: MockerFixture) -> None: "get_all_tags", return_value=[ TagData( - id=generate_id(), + id=utils.generate_id(), dataset_id=dataset_id, prev_tag_id=None, bit_mask_data="0x1", @@ -149,13 +149,13 @@ def test_download_dataset__ok(mocker: MockerFixture) -> None: def test_get_embedding_data_by_name(mocker: MockerFixture) -> None: embedding_0 = DatasetEmbeddingData( - id=generate_id(), + id=utils.generate_id(), name="embedding_0", created_at=0, is_processed=False, ) embedding_1 = DatasetEmbeddingData( - id=generate_id(), + id=utils.generate_id(), name="embedding_1", created_at=1, is_processed=False, @@ -177,7 +177,7 @@ def test_get_embedding_data_by_name__no_embedding_with_name( mocker: MockerFixture, ) -> None: embedding = DatasetEmbeddingData( - id=generate_id(), + id=utils.generate_id(), name="embedding", created_at=0, is_processed=False, @@ -230,7 +230,7 @@ def test_download_embeddings_csv_by_id(mocker: MockerFixture) -> None: def test_download_embeddings_csv(mocker: MockerFixture) -> None: - embedding_id = generate_id() + embedding_id = utils.generate_id() mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) mock_get_all_embedding_data = mocker.patch.object( @@ -282,7 +282,7 @@ def test_download_embeddings_csv__no_default_embedding(mocker: MockerFixture) -> def test__get_latest_default_embedding_data__no_default_embedding() -> None: custom_embedding = DatasetEmbeddingData( - id=generate_id(), + id=utils.generate_id(), name="custom-name", created_at=0, is_processed=False, diff --git a/tests/api_workflow/test_api_workflow_export.py b/tests/api_workflow/test_api_workflow_export.py index 28bc9e957..de5af5e8e 100644 --- a/tests/api_workflow/test_api_workflow_export.py +++ b/tests/api_workflow/test_api_workflow_export.py @@ -1,13 +1,14 @@ from pytest_mock import MockerFixture from lightly.api import ApiWorkflowClient, api_workflow_export +from lightly.api import utils as api_utils from lightly.openapi_generated.swagger_client.models import FileNameFormat, TagData -from tests.api_workflow.utils import generate_id +from tests.api_workflow import utils def _get_tag(dataset_id: str, tag_name: str) -> TagData: return TagData( - id=generate_id(), + id=utils.generate_id(), dataset_id=dataset_id, prev_tag_id=None, bit_mask_data="0x1", @@ -18,15 +19,64 @@ def _get_tag(dataset_id: str, tag_name: str) -> TagData: ) -def test_export_tag_to_basic_filenames_and_read_urls(mocker: MockerFixture) -> None: - dataset_id = generate_id() - mocked_retry = mocker.patch.object( - api_workflow_export, - "retry", +def test_export_filenames_by_tag_id(mocker: MockerFixture) -> None: + dataset_id = utils.generate_id() + mocked_paginate = mocker.patch.object( + api_utils, + "paginate_endpoint", + side_effect=[iter(["file0\nfile1"])], + ) + mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) + mocked_api = mocker.MagicMock() + + client = ApiWorkflowClient() + client._dataset_id = dataset_id + client._tags_api = mocked_api + data = client.export_filenames_by_tag_id(tag_id="tag_id") + + assert data == "file0\nfile1" + mocked_paginate.assert_called_once_with( + mocked_api.export_tag_to_basic_filenames, + dataset_id=dataset_id, + tag_id="tag_id", + ) + + +def test_export_filenames_by_tag_id__two_pages(mocker: MockerFixture) -> None: + dataset_id = utils.generate_id() + mocked_paginate = mocker.patch.object( + api_utils, + "paginate_endpoint", + side_effect=[ + # Simulate two pages. + iter(["file0\nfile1", "file2\nfile3"]) + ], + ) + mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) + mocked_api = mocker.MagicMock() + + client = ApiWorkflowClient() + client._dataset_id = dataset_id + client._tags_api = mocked_api + data = client.export_filenames_by_tag_id(tag_id="tag_id") + + assert data == "file0\nfile1\nfile2\nfile3" + mocked_paginate.assert_called_once_with( + mocked_api.export_tag_to_basic_filenames, + dataset_id=dataset_id, + tag_id="tag_id", + ) + + +def test_export_filenames_and_read_urls_by_tag_id(mocker: MockerFixture) -> None: + dataset_id = utils.generate_id() + mocked_paginate = mocker.patch.object( + api_utils, + "paginate_endpoint", side_effect=[ - "file0\nfile1", - "read_url0\nread_url1", - "datasource_url0\ndatasource_url1", + iter(["file0\nfile1"]), + iter(["read_url0\nread_url1"]), + iter(["datasource_url0\ndatasource_url1"]), ], ) mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) @@ -49,20 +99,102 @@ def test_export_tag_to_basic_filenames_and_read_urls(mocker: MockerFixture) -> N "datasourceUrl": "datasource_url1", }, ] - assert mocked_retry.call_count == 3 - file_name_format_call_args = [ - call_args[1].get("file_name_format") - for call_args in mocked_retry.call_args_list - ] - assert file_name_format_call_args == [ - FileNameFormat.NAME, - FileNameFormat.REDIRECTED_READ_URL, - FileNameFormat.DATASOURCE_FULL, + mocked_paginate.assert_has_calls( + [ + mocker.call( + mocked_api.export_tag_to_basic_filenames, + dataset_id=dataset_id, + tag_id="tag_id", + file_name_format=FileNameFormat.NAME, + ), + mocker.call( + mocked_api.export_tag_to_basic_filenames, + dataset_id=dataset_id, + tag_id="tag_id", + file_name_format=FileNameFormat.REDIRECTED_READ_URL, + ), + mocker.call( + mocked_api.export_tag_to_basic_filenames, + dataset_id=dataset_id, + tag_id="tag_id", + file_name_format=FileNameFormat.DATASOURCE_FULL, + ), + ] + ) + + +def test_export_filenames_and_read_urls_by_tag_id__two_pages( + mocker: MockerFixture, +) -> None: + dataset_id = utils.generate_id() + mocked_paginate = mocker.patch.object( + api_utils, + "paginate_endpoint", + side_effect=[ + # Simulate two pages. + iter(["file0\nfile1", "file2\nfile3"]), + iter(["read_url0\nread_url1", "read_url2\nread_url3"]), + iter( + ["datasource_url0\ndatasource_url1", "datasource_url2\ndatasource_url3"] + ), + ], + ) + mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) + mocked_api = mocker.MagicMock() + + client = ApiWorkflowClient() + client._dataset_id = dataset_id + client._tags_api = mocked_api + data = client.export_filenames_and_read_urls_by_tag_id(tag_id="tag_id") + + assert data == [ + { + "fileName": "file0", + "readUrl": "read_url0", + "datasourceUrl": "datasource_url0", + }, + { + "fileName": "file1", + "readUrl": "read_url1", + "datasourceUrl": "datasource_url1", + }, + { + "fileName": "file2", + "readUrl": "read_url2", + "datasourceUrl": "datasource_url2", + }, + { + "fileName": "file3", + "readUrl": "read_url3", + "datasourceUrl": "datasource_url3", + }, ] + mocked_paginate.assert_has_calls( + [ + mocker.call( + mocked_api.export_tag_to_basic_filenames, + dataset_id=dataset_id, + tag_id="tag_id", + file_name_format=FileNameFormat.NAME, + ), + mocker.call( + mocked_api.export_tag_to_basic_filenames, + dataset_id=dataset_id, + tag_id="tag_id", + file_name_format=FileNameFormat.REDIRECTED_READ_URL, + ), + mocker.call( + mocked_api.export_tag_to_basic_filenames, + dataset_id=dataset_id, + tag_id="tag_id", + file_name_format=FileNameFormat.DATASOURCE_FULL, + ), + ] + ) def test_export_filenames_by_tag_name(mocker: MockerFixture) -> None: - dataset_id = generate_id() + dataset_id = utils.generate_id() tag_name = "some-tag" tag = _get_tag(dataset_id=dataset_id, tag_name=tag_name) mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) @@ -79,12 +211,14 @@ def test_export_filenames_by_tag_name(mocker: MockerFixture) -> None: def test_export_label_box_data_rows_by_tag_id(mocker: MockerFixture) -> None: mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) - mocked_paginate = mocker.patch.object(api_workflow_export, "paginate_endpoint") + mocked_paginate = mocker.patch.object( + api_workflow_export.utils, "paginate_endpoint" + ) mocked_api = mocker.MagicMock() mocked_warning = mocker.patch("warnings.warn") client = ApiWorkflowClient() - client._dataset_id = generate_id() + client._dataset_id = utils.generate_id() client._tags_api = mocked_api client.export_label_box_data_rows_by_tag_id(tag_id="tag_id") mocked_paginate.assert_called_once() @@ -99,7 +233,7 @@ def test_export_label_box_data_rows_by_tag_id(mocker: MockerFixture) -> None: def test_export_label_box_data_rows_by_tag_name(mocker: MockerFixture) -> None: - dataset_id = generate_id() + dataset_id = utils.generate_id() tag_name = "some-tag" tag = _get_tag(dataset_id=dataset_id, tag_name=tag_name) mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) @@ -125,11 +259,13 @@ def test_export_label_box_data_rows_by_tag_name(mocker: MockerFixture) -> None: def test_export_label_box_v4_data_rows_by_tag_id(mocker: MockerFixture) -> None: mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) - mocked_paginate = mocker.patch.object(api_workflow_export, "paginate_endpoint") + mocked_paginate = mocker.patch.object( + api_workflow_export.utils, "paginate_endpoint" + ) mocked_api = mocker.MagicMock() client = ApiWorkflowClient() - client._dataset_id = generate_id() + client._dataset_id = utils.generate_id() client._tags_api = mocked_api client.export_label_box_v4_data_rows_by_tag_id(tag_id="tag_id") mocked_paginate.assert_called_once() @@ -138,7 +274,7 @@ def test_export_label_box_v4_data_rows_by_tag_id(mocker: MockerFixture) -> None: def test_export_label_box_v4_data_rows_by_tag_name(mocker: MockerFixture) -> None: - dataset_id = generate_id() + dataset_id = utils.generate_id() tag_name = "some-tag" tag = _get_tag(dataset_id=dataset_id, tag_name=tag_name) mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) @@ -156,7 +292,7 @@ def test_export_label_box_v4_data_rows_by_tag_name(mocker: MockerFixture) -> Non def test_export_label_studio_tasks_by_tag_name(mocker: MockerFixture) -> None: - dataset_id = generate_id() + dataset_id = utils.generate_id() tag_name = "some-tag" tag = _get_tag(dataset_id=dataset_id, tag_name=tag_name) mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) diff --git a/tests/api_workflow/test_api_workflow_predictions.py b/tests/api_workflow/test_api_workflow_predictions.py index 67d69e660..82e6084c6 100644 --- a/tests/api_workflow/test_api_workflow_predictions.py +++ b/tests/api_workflow/test_api_workflow_predictions.py @@ -15,13 +15,15 @@ def test_create_or_update_prediction_task_schema() -> None: mocked_client.dataset_id = "some_dataset_id" mocked_client._predictions_api = MagicMock(spec_set=PredictionsApi) - schema = PredictionTaskSchema( - name="my-object-detection", - type=TaskType.OBJECT_DETECTION, - categories=[ - PredictionTaskSchemaCategory(id=0, name="dog"), - PredictionTaskSchemaCategory(id=1, name="cat"), - ], + schema = PredictionTaskSchema.from_dict( + { + "name": "my-object-detection", + "type": TaskType.OBJECT_DETECTION, + "categories": [ + PredictionTaskSchemaCategory(id=0, name="dog").to_dict(), + PredictionTaskSchemaCategory(id=1, name="cat").to_dict(), + ], + } ) timestamp = 1234 ApiWorkflowClient.create_or_update_prediction_task_schema( @@ -67,40 +69,3 @@ def test_create_or_update_prediction() -> None: sample_id=sample_id, prediction_uuid_timestamp=timestamp, ) - - -def test_create_or_update_predictions() -> None: - mocked_client = MagicMock(spec=ApiWorkflowClient).return_value - mocked_client.dataset_id = "some_dataset_id" - - sample_id_to_prediction_singletons_dummy = { - f"sample_id_{i}": [ - PredictionSingletonClassification( - type="CLASSIFICATION", - taskName="my-task", - categoryId=i % 4, - score=0.9, - probabilities=[0.1, 0.2, 0.3, 0.4], - ) - ] - for i in range(4) - } - - timestamp = 1234 - ApiWorkflowClient.create_or_update_predictions( - self=mocked_client, - sample_id_to_prediction_singletons=sample_id_to_prediction_singletons_dummy, - prediction_version_id=timestamp, - ) - - expected_calls = [ - call( - sample_id=sample_id, - prediction_singletons=singletons, - prediction_version_id=timestamp, - ) - for sample_id, singletons in sample_id_to_prediction_singletons_dummy.items() - ] - mocked_client.create_or_update_prediction.assert_has_calls( - calls=expected_calls, any_order=True - ) diff --git a/tests/api_workflow/test_api_workflow_selection.py b/tests/api_workflow/test_api_workflow_selection.py index 5c979c6a6..73147fd03 100644 --- a/tests/api_workflow/test_api_workflow_selection.py +++ b/tests/api_workflow/test_api_workflow_selection.py @@ -14,13 +14,13 @@ SamplingMethod, TagData, ) -from tests.api_workflow.utils import generate_id +from tests.api_workflow import utils def _get_tags(dataset_id: str, tag_name: str = "just-a-tag") -> List[TagData]: return [ TagData( - id=generate_id(), + id=utils.generate_id(), dataset_id=dataset_id, prev_tag_id=None, bit_mask_data="0x1", @@ -46,7 +46,7 @@ def test_selection__tag_exists(mocker: MockerFixture) -> None: mocker.patch.object( ApiWorkflowClient, "get_all_tags", - return_value=_get_tags(dataset_id=generate_id(), tag_name=tag_name), + return_value=_get_tags(dataset_id=utils.generate_id(), tag_name=tag_name), ) client = ApiWorkflowClient() @@ -71,7 +71,7 @@ def test_selection__no_tags(mocker: MockerFixture) -> None: def test_selection(mocker: MockerFixture) -> None: tag_name = "some-tag" - dataset_id = generate_id() + dataset_id = utils.generate_id() mocker.patch("time.sleep") mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) mocker.patch.object( @@ -86,13 +86,13 @@ def test_selection(mocker: MockerFixture) -> None: mocked_selection_api = mocker.MagicMock() mocked_sampling_response = mocker.MagicMock() - mocked_sampling_response.job_id = generate_id() + mocked_sampling_response.job_id = utils.generate_id() mocked_selection_api.trigger_sampling_by_id.return_value = mocked_sampling_response mocked_jobs_api = mocker.MagicMock() mocked_get_job_status = mocker.MagicMock( return_value=JobStatusData( - id=generate_id(), + id=utils.generate_id(), wait_time_till_next_poll=1, created_at=0, status=JobState.FINISHED, @@ -118,7 +118,7 @@ def test_selection(mocker: MockerFixture) -> None: def test_selection__job_failed(mocker: MockerFixture) -> None: - dataset_id = generate_id() + dataset_id = utils.generate_id() job_id = "some-job-id" mocker.patch("time.sleep") mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) @@ -140,7 +140,7 @@ def test_selection__job_failed(mocker: MockerFixture) -> None: mocked_jobs_api = mocker.MagicMock() mocked_get_job_status = mocker.MagicMock( return_value=JobStatusData( - id=generate_id(), + id=utils.generate_id(), wait_time_till_next_poll=1, created_at=0, status=JobState.FAILED, @@ -162,7 +162,7 @@ def test_selection__job_failed(mocker: MockerFixture) -> None: def test_selection__too_many_errors(mocker: MockerFixture) -> None: - dataset_id = generate_id() + dataset_id = utils.generate_id() job_id = "some-job-id" mocker.patch("time.sleep") mocked_print = mocker.patch("builtins.print") @@ -203,7 +203,7 @@ def test_selection__too_many_errors(mocker: MockerFixture) -> None: def test_upload_scores(mocker: MockerFixture) -> None: - dataset_id = generate_id() + dataset_id = utils.generate_id() tags = _get_tags(dataset_id=dataset_id, tag_name="initial-tag") tag_id = tags[0].id mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) diff --git a/tests/api_workflow/test_api_workflow_tags.py b/tests/api_workflow/test_api_workflow_tags.py index e626e3b6c..3667355bd 100644 --- a/tests/api_workflow/test_api_workflow_tags.py +++ b/tests/api_workflow/test_api_workflow_tags.py @@ -6,7 +6,7 @@ from lightly.api import ApiWorkflowClient from lightly.api.api_workflow_tags import TagDoesNotExistError from lightly.openapi_generated.swagger_client.models import TagCreator, TagData -from tests.api_workflow.utils import generate_id +from tests.api_workflow import utils def _get_tags( @@ -14,7 +14,7 @@ def _get_tags( ) -> List[TagData]: return [ TagData( - id=generate_id(), + id=utils.generate_id(), dataset_id=dataset_id, prev_tag_id=prev_tag_id, bit_mask_data="0x5", @@ -27,7 +27,7 @@ def _get_tags( def test_create_tag_from_filenames(mocker: MockerFixture) -> None: - dataset_id = generate_id() + dataset_id = utils.generate_id() tags = _get_tags(dataset_id=dataset_id, tag_name="initial-tag") mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) mocker.patch.object(ApiWorkflowClient, "get_all_tags", return_value=tags) @@ -51,7 +51,7 @@ def test_create_tag_from_filenames(mocker: MockerFixture) -> None: def test_create_tag_from_filenames__tag_exists(mocker: MockerFixture) -> None: tag_name = "some-tag" - tags = _get_tags(dataset_id=generate_id(), tag_name=tag_name) + tags = _get_tags(dataset_id=utils.generate_id(), tag_name=tag_name) mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) mocker.patch.object(ApiWorkflowClient, "get_all_tags", return_value=tags) @@ -78,7 +78,7 @@ def test_create_tag_from_filenames__no_tags(mocker: MockerFixture) -> None: def test_create_tag_from_filenames__file_not_found(mocker: MockerFixture) -> None: - tags = _get_tags(dataset_id=generate_id(), tag_name="initial-tag") + tags = _get_tags(dataset_id=utils.generate_id(), tag_name="initial-tag") mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) mocker.patch.object(ApiWorkflowClient, "get_all_tags", return_value=tags) mocked_get_filenames = mocker.patch.object( @@ -101,7 +101,7 @@ def test_create_tag_from_filenames__file_not_found(mocker: MockerFixture) -> Non def test_get_filenames_in_tag(mocker: MockerFixture) -> None: - tag = _get_tags(dataset_id=generate_id())[0] + tag = _get_tags(dataset_id=utils.generate_id())[0] mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) mocked_get_filenames = mocker.patch.object( ApiWorkflowClient, "get_filenames", return_value=[f"file{i}" for i in range(3)] @@ -115,7 +115,7 @@ def test_get_filenames_in_tag(mocker: MockerFixture) -> None: def test_get_filenames_in_tag__filenames_given(mocker: MockerFixture) -> None: - tag = _get_tags(dataset_id=generate_id())[0] + tag = _get_tags(dataset_id=utils.generate_id())[0] mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) mocked_get_filenames = mocker.patch.object(ApiWorkflowClient, "get_filenames") @@ -129,8 +129,8 @@ def test_get_filenames_in_tag__filenames_given(mocker: MockerFixture) -> None: def test_get_filenames_in_tag__exclude_parent_tag(mocker: MockerFixture) -> None: - prev_tag_id = generate_id() - dataset_id = generate_id() + prev_tag_id = utils.generate_id() + dataset_id = utils.generate_id() tag = _get_tags(dataset_id=dataset_id, prev_tag_id=prev_tag_id)[0] mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) mocked_get_filenames = mocker.patch.object( @@ -180,7 +180,7 @@ def test_get_tag_by_id(mocker: MockerFixture) -> None: def test_get_tag_name(mocker: MockerFixture) -> None: tag_name = "some-tag" - tags = _get_tags(dataset_id=generate_id(), tag_name=tag_name) + tags = _get_tags(dataset_id=utils.generate_id(), tag_name=tag_name) mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) mocker.patch.object(ApiWorkflowClient, "get_all_tags", return_value=tags) mocked_get_tag = mocker.patch.object(ApiWorkflowClient, "get_tag_by_id") diff --git a/tests/api_workflow/test_api_workflow_upload_custom_metadata.py b/tests/api_workflow/test_api_workflow_upload_custom_metadata.py index 4c79b4c29..63f36de37 100644 --- a/tests/api_workflow/test_api_workflow_upload_custom_metadata.py +++ b/tests/api_workflow/test_api_workflow_upload_custom_metadata.py @@ -3,10 +3,11 @@ from lightly.api import ApiWorkflowClient, api_workflow_upload_metadata from lightly.openapi_generated.swagger_client.models import ( SampleDataModes, + SamplePartialMode, SampleUpdateRequest, ) from lightly.utils.io import COCO_ANNOTATION_KEYS -from tests.api_workflow.utils import generate_id +from tests.api_workflow import utils def test_index_custom_metadata_by_filename(mocker: MockerFixture) -> None: @@ -41,16 +42,26 @@ def test_upload_custom_metadata(mocker: MockerFixture) -> None: # retry should be called twice: once for get_samples_partial_by_dataset_id # and once for update_sample_by_id. get_samples_partial_by_dataset_id returns # only one valid sample file `file1` + dummy_sample = SampleDataModes(id=utils.generate_id(), file_name="file1") + + mocked_paginate_endpoint = mocker.patch.object( + api_workflow_upload_metadata, + "paginate_endpoint", + side_effect=[ + [dummy_sample], + None, + ], + ) mocked_retry = mocker.patch.object( api_workflow_upload_metadata, "retry", side_effect=[ - [SampleDataModes(id=generate_id(), file_name="file1")], + [dummy_sample], None, ], ) mocked_print_warning = mocker.patch.object( - api_workflow_upload_metadata, "print_as_warning" + api_workflow_upload_metadata.hipify, "print_as_warning" ) mocked_executor = mocker.patch.object( api_workflow_upload_metadata, "ThreadPoolExecutor" @@ -98,21 +109,21 @@ def test_upload_custom_metadata(mocker: MockerFixture) -> None: ), ] - assert mocked_retry.call_count == 2 # First call: get_samples_partial_by_dataset_id - args_first_call = mocked_retry.call_args_list[0][0] - assert ( - # Check first positional argument - args_first_call[0] - == mocked_samples_api.get_samples_partial_by_dataset_id + mocked_paginate_endpoint.assert_called_once_with( + mocked_samples_api.get_samples_partial_by_dataset_id, + dataset_id="dataset-id", + mode=SamplePartialMode.FILENAMES, + page_size=25000, ) # Second call: update_sample_by_id with the only valid sample - args_second_call = mocked_retry.call_args_list[1][0] - kwargs_second_call = mocked_retry.call_args_list[1][1] - # Check first positional argument - assert args_second_call[0] == mocked_samples_api.update_sample_by_id - # Check second positional argument - assert isinstance(kwargs_second_call["sample_update_request"], SampleUpdateRequest) - assert kwargs_second_call["sample_update_request"].custom_meta_data == { - COCO_ANNOTATION_KEYS.custom_metadata_image_id: "image-id1" - } + mocked_retry.assert_called_once_with( + mocked_samples_api.update_sample_by_id, + dataset_id="dataset-id", + sample_id=dummy_sample.id, + sample_update_request=SampleUpdateRequest( + custom_meta_data={ + COCO_ANNOTATION_KEYS.custom_metadata_image_id: "image-id1" + } + ), + ) diff --git a/tests/api_workflow/test_api_workflow_upload_dataset.py b/tests/api_workflow/test_api_workflow_upload_dataset.py deleted file mode 100644 index ec773e2b6..000000000 --- a/tests/api_workflow/test_api_workflow_upload_dataset.py +++ /dev/null @@ -1,201 +0,0 @@ -import os -import pathlib -import tempfile -import warnings - -import cv2 -import numpy as np -import pytest -import torchvision - -from lightly.api.utils import MAXIMUM_FILENAME_LENGTH -from lightly.data.dataset import LightlyDataset -from lightly.openapi_generated.swagger_client.models import SamplePartialMode -from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup - - -# TODO: fix this text -@pytest.skip( - "Skip this test for now. Test cases need to be updated.", allow_module_level=True -) -class TestApiWorkflowUploadDataset(MockedApiWorkflowSetup): - def setUp(self) -> None: - MockedApiWorkflowSetup.setUp(self) - warnings.filterwarnings("ignore", category=UserWarning) - self.n_data = 100 - self.create_fake_dataset() - self.api_workflow_client._tags_api.no_tags = 0 - - def tearDown(self) -> None: - warnings.resetwarnings() - - def create_fake_dataset(self, length_of_filepath: int = -1, sample_names=None): - n_data = self.n_data if sample_names is None else len(sample_names) - self.dataset = torchvision.datasets.FakeData( - size=n_data, image_size=(3, 32, 32) - ) - - self.folder_path = tempfile.mkdtemp() - image_extension = ".jpg" - sample_names = ( - sample_names - if sample_names is not None - else [f"img_{i}{image_extension}" for i in range(n_data)] - ) - for sample_idx in range(n_data): - data = self.dataset[sample_idx] - sample_name = sample_names[sample_idx] - path = os.path.join(self.folder_path, sample_name) - - if length_of_filepath > len(path): - assert path.endswith(image_extension) - n_missing_chars = length_of_filepath - len(path) - path = ( - path[: -len(image_extension)] - + "x" * n_missing_chars - + image_extension - ) - - data[0].save(path) - - def corrupt_fake_dataset(self): - n_data = self.n_data - sample_names = [f"img_{i}.jpg" for i in range(n_data)] - for sample_name in sample_names: - pathlib.Path(os.path.join(self.folder_path, sample_name)).touch() - - def test_upload_dataset_over_quota(self): - quota = self.n_data - 1 - - def get_quota_reduced(): - return str(quota) - - self.api_workflow_client._quota_api.get_quota_maximum_dataset_size = ( - get_quota_reduced - ) - with self.assertRaises(ValueError): - self.api_workflow_client.upload_dataset(input=self.folder_path) - - def test_upload_dataset_from_folder(self): - self.api_workflow_client.upload_dataset(input=self.folder_path) - - def test_upload_dataset_from_folder_full(self): - self.api_workflow_client.upload_dataset(input=self.folder_path, mode="full") - - def test_upload_dataset_from_folder_only_metadata(self): - self.api_workflow_client.upload_dataset(input=self.folder_path, mode="metadata") - - def test_upsize_existing_dataset(self): - self.api_workflow_client._tags_api.no_tags = 1 - self.api_workflow_client.upload_dataset(input=self.folder_path) - - def test_upload_dataset_from_dataset(self): - dataset = LightlyDataset.from_torch_dataset(self.dataset) - self.api_workflow_client.upload_dataset(input=dataset) - - def test_corrupt_dataset_from_folder(self): - self.corrupt_fake_dataset() - self.api_workflow_client.upload_dataset(input=self.folder_path) - self.api_workflow_client.upload_dataset(input=self.folder_path) - - def test_filename_length_lower(self): - self.create_fake_dataset(length_of_filepath=MAXIMUM_FILENAME_LENGTH - 1) - self.api_workflow_client.upload_dataset(input=self.folder_path) - - samples = self.api_workflow_client._samples_api.get_samples_by_dataset_id( - dataset_id="does not matter" - ) - self.assertEqual(self.n_data, len(samples)) - - def test_filename_length_upper(self): - self.create_fake_dataset(length_of_filepath=MAXIMUM_FILENAME_LENGTH + 10) - self.api_workflow_client.upload_dataset(input=self.folder_path) - - samples = self.api_workflow_client._samples_api.get_samples_by_dataset_id( - dataset_id="does not matter" - ) - self.assertEqual(0, len(samples)) - - def create_fake_video_dataset( - self, n_videos=5, n_frames_per_video=10, w=32, h=32, c=3, extension="avi" - ): - self.video_input_dir = tempfile.mkdtemp() - self.frames = (np.random.randn(n_frames_per_video, w, h, c) * 255).astype( - np.uint8 - ) - - for i in range(n_videos): - path = os.path.join(self.video_input_dir, f"output-{i}.{extension}") - out = cv2.VideoWriter(path, 0, 1, (w, h)) - for frame in self.frames: - out.write(frame) - out.release() - - def test_upload_video_dataset_from_folder(self): - self.create_fake_video_dataset() - self.api_workflow_client.upload_dataset(input=self.folder_path) - - def test_upload_dataset_twice(self): - rng = np.random.default_rng(2021) - - base_upload_single_image = self.api_workflow_client._upload_single_image - - # Upload with some uploads failing - def failing_upload_sample(*args, **kwargs): - if rng.random() < 0.9: - return base_upload_single_image(*args, **kwargs) - else: - raise ValueError() - - self.api_workflow_client._upload_single_image = failing_upload_sample - self.api_workflow_client.upload_dataset(input=self.folder_path) - - # Ensure that not all samples were uploaded - samples = self.api_workflow_client._samples_api.get_samples_by_dataset_id( - dataset_id="does not matter" - ) - self.assertLess(len(samples), self.n_data) - - # Upload without failing uploads - self.api_workflow_client._upload_single_image = base_upload_single_image - self.api_workflow_client.upload_dataset(input=self.folder_path) - - # Ensure that now all samples were uploaded exactly once - samples = self.api_workflow_client._samples_api.get_samples_by_dataset_id( - dataset_id="does not matter" - ) - self.assertEqual(self.n_data, len(samples)) - - def test_upload_dataset_twice_with_overlap(self): - all_sample_names = [f"img_upload_twice_{i}.jpg" for i in range(10)] - - # upload first part of the dataset (sample_0 - sample_6) - self.create_fake_dataset(sample_names=all_sample_names[:7]) - self.api_workflow_client.upload_dataset(input=self.folder_path) - - # upload second part of the dataset (sample_3 - sample_9) - self.create_fake_dataset(sample_names=all_sample_names[3:]) - self.api_workflow_client.upload_dataset(input=self.folder_path) - - # always returns all samples so dataset_id doesn't matter - samples = self.api_workflow_client._samples_api.get_samples_by_dataset_id( - dataset_id="" - ) - - # assert the filenames are the same - self.assertListEqual( - sorted(all_sample_names), - sorted([s.file_name for s in samples]), - ) - - # assert partially getting the samples fileNames returns the same data - samples_file_names = ( - self.api_workflow_client._samples_api.get_samples_partial_by_dataset_id( - dataset_id="", mode=SamplePartialMode.FULL - ) - ) - - self.assertListEqual( - sorted(all_sample_names), - sorted([s.file_name for s in samples_file_names]), - ) diff --git a/tests/api_workflow/test_api_workflow_upload_embeddings.py b/tests/api_workflow/test_api_workflow_upload_embeddings.py index f04ac4b3b..4444d11d7 100644 --- a/tests/api_workflow/test_api_workflow_upload_embeddings.py +++ b/tests/api_workflow/test_api_workflow_upload_embeddings.py @@ -3,11 +3,8 @@ import numpy as np -from lightly.utils.io import ( - INVALID_FILENAME_CHARACTERS, - load_embeddings, - save_embeddings, -) +from lightly.utils import io as io_utils +from lightly.utils.io import INVALID_FILENAME_CHARACTERS from tests.api_workflow.mocked_api_workflow_client import ( N_FILES_ON_SERVER, MockedApiWorkflowSetup, @@ -37,7 +34,7 @@ def create_fake_embeddings( f"_{special_char_in_first_filename}" f"{self.sample_names[0]}" ) labels = [0] * len(self.sample_names) - save_embeddings( + io_utils.save_embeddings( self.path_to_embeddings, np.random.randn(n_data, n_dims), labels, @@ -160,13 +157,13 @@ def test_append_embeddings_with_overlap(self): ) # load the new (appended) embeddings - _, labels_appended, filenames_appended = load_embeddings( + _, labels_appended, filenames_appended = io_utils.load_embeddings( self.path_to_embeddings ) # define the expected filenames and labels self.create_fake_embeddings(n_data=n_data_local + n_data_start_local) - _, _, filenames_expected = load_embeddings(self.path_to_embeddings) + _, _, filenames_expected = io_utils.load_embeddings(self.path_to_embeddings) labels_expected = list(range(n_data_start_local)) + [0] * n_data_local # make sure the list of filenames and labels equal