diff --git a/lightly/utils/bounding_box.py b/lightly/utils/bounding_box.py index 322b13b16..695380b3b 100644 --- a/lightly/utils/bounding_box.py +++ b/lightly/utils/bounding_box.py @@ -1,4 +1,4 @@ -""" Bounding Box Utils """ +"""Bounding Box Utils""" from __future__ import annotations @@ -31,17 +31,26 @@ class BoundingBox: >>> # (x0, y0, x1, y1) = (10, 20, 30, 40) >>> W, H = 100, 100 # get image shape >>> bbox = BoundingBox(10 / W, 20 / H, 30 / W, 40 / H) - """ def __init__( self, x0: float, y0: float, x1: float, y1: float, clip_values: bool = True ): - """ - clip_values: - Set to true to clip the values into [0, 1] instead of raising an error if they lie outside. - """ + """Initializes a BoundingBox object. + + Args: + x0: + x0 coordinate relative to image width. + y0: + y0 coordinate relative to image height. + x1: + x1 coordinate relative to image width. + y1: + y1 coordinate relative to image height. + clip_values: + If True, clips the coordinates to [0, 1]. + """ if clip_values: def clip_to_0_1(value: float) -> float: @@ -60,14 +69,12 @@ def clip_to_0_1(value: float) -> float: if x0 >= x1: raise ValueError( - f"x0 must be smaller than x1 for bounding box " - f"[{x0}, {y0}, {x1}, {y1}]" + f"x0 must be smaller than x1 for bounding box [{x0}, {y0}, {x1}, {y1}]" ) if y0 >= y1: raise ValueError( - "y0 must be smaller than y1 for bounding box " - f"[{x0}, {y0}, {x1}, {y1}]" + f"y0 must be smaller than y1 for bounding box [{x0}, {y0}, {x1}, {y1}]" ) self.x0 = x0 @@ -77,7 +84,20 @@ def clip_to_0_1(value: float) -> float: @classmethod def from_x_y_w_h(cls, x: float, y: float, w: float, h: float) -> BoundingBox: - """Helper to convert from bounding box format with width and height. + """Creates a BoundingBox from x, y, width, and height. + + Args: + x: + x coordinate of the top-left corner relative to image width. + y: + y coordinate of the top-left corner relative to image height. + w: + Width of the bounding box relative to image width. + h: + Height of the bounding box relative to image height. + + Returns: + BoundingBox: A BoundingBox instance. Examples: >>> bbox = BoundingBox.from_x_y_w_h(0.1, 0.2, 0.2, 0.2) @@ -89,11 +109,23 @@ def from_x_y_w_h(cls, x: float, y: float, w: float, h: float) -> BoundingBox: def from_yolo_label( cls, x_center: float, y_center: float, w: float, h: float ) -> BoundingBox: - """Helper to convert from yolo label format - x_center, y_center, w, h --> x0, y0, x1, y1 + """Creates a BoundingBox from YOLO label format. + + Args: + x_center: + x coordinate of the center relative to image width. + y_center: + y coordinate of the center relative to image height. + w: + Width of the bounding box relative to image width. + h: + Height of the bounding box relative to image height. + + Returns: + BoundingBox: A BoundingBox instance. Examples: - >>> bbox = BoundingBox.from_yolo(0.5, 0.4, 0.2, 0.3) + >>> bbox = BoundingBox.from_yolo_label(0.5, 0.4, 0.2, 0.3) """ return cls( diff --git a/lightly/utils/debug.py b/lightly/utils/debug.py index 7d74d1f68..a2281ab6c 100644 --- a/lightly/utils/debug.py +++ b/lightly/utils/debug.py @@ -15,7 +15,6 @@ "'pip install lightly[matplotlib]'." ) except ImportError as ex: - # Matplotlib import can fail if an incompatible dateutil version is installed. plt = ex @@ -24,9 +23,9 @@ def std_of_l2_normalized(z: torch.Tensor) -> torch.Tensor: """Calculates the mean of the standard deviation of z along each dimension. This measure was used by [0] to determine the level of collapse of the - learned representations. If the returned number is 0., the outputs z have - collapsed to a constant vector. "If the output z has a zero-mean isotropic - Gaussian distribution" [0], the returned number should be close to 1/sqrt(d) + learned representations. If the returned value is 0., the outputs z have + collapsed to a constant vector. If the output z has a zero-mean isotropic + Gaussian distribution [0], the returned value should be close to 1/sqrt(d), where d is the dimensionality of the output. [0]: https://arxiv.org/abs/2011.10566 @@ -38,9 +37,7 @@ def std_of_l2_normalized(z: torch.Tensor) -> torch.Tensor: Returns: The mean of the standard deviation of the l2 normalized tensor z along each dimension. - """ - if len(z.shape) != 2: raise ValueError( f"Input tensor must have two dimensions but has {len(z.shape)}!" @@ -53,8 +50,18 @@ def std_of_l2_normalized(z: torch.Tensor) -> torch.Tensor: def apply_transform_without_normalize( image: Image.Image, transform, -): - """Applies the transform to the image but skips ToTensor and Normalize.""" +) -> Image.Image: + """Applies the transform to the image but skips ToTensor and Normalize. + + Args: + image: + The input PIL image. + transform: + The transformation to apply, excluding ToTensor and Normalize. + + Returns: + The transformed image. + """ skippable_transforms = ( torchvision.transforms.ToTensor, torchvision.transforms.Normalize, @@ -70,10 +77,10 @@ def apply_transform_without_normalize( def generate_grid_of_augmented_images( input_images: List[Image.Image], collate_function: Union[BaseCollateFunction, MultiViewCollateFunction], -): +) -> List[List[Image.Image]]: """Returns a grid of augmented images. Images in a column belong together. - This function ignores the transforms ToTensor and Normalize for visualization purposes. + This function ignores the ToTensor and Normalize transforms for visualization purposes. Args: input_images: @@ -116,9 +123,9 @@ def plot_augmented_images( input_images: List[Image.Image], collate_function: Union[BaseCollateFunction, MultiViewCollateFunction], ): - """Returns a figure showing original images in the left column and augmented images to their right. + """Plots original images and augmented images in a figure. - This function ignores the transforms ToTensor and Normalize for visualization purposes. + This function ignores the ToTensor and Normalize transforms for visualization purposes. Args: input_images: @@ -134,7 +141,6 @@ def plot_augmented_images( MultiViewCollateFunctions all the generated views are shown. """ - _check_matplotlib_available() if len(input_images) == 0: @@ -166,5 +172,6 @@ def plot_augmented_images( def _check_matplotlib_available() -> None: + """Checks if matplotlib is available. Raises an error if not.""" if isinstance(plt, Exception): raise plt diff --git a/lightly/utils/dependency.py b/lightly/utils/dependency.py index bdd36186c..8baeb2a77 100644 --- a/lightly/utils/dependency.py +++ b/lightly/utils/dependency.py @@ -3,24 +3,42 @@ @functools.lru_cache(maxsize=1) def torchvision_vit_available() -> bool: + """Checks if Vision Transformer (ViT) models are available in torchvision. + + This function checks if the `vision_transformer` module is available in torchvision, + which requires torchvision version >= 0.12. It also handles exceptions related to + CUDA version mismatches and installation issues. + + Returns: + True if the Vision Transformer (ViT) models are available in torchvision, + otherwise False. + """ try: - import torchvision.models.vision_transformer # Requires torchvision >=0.12 + import torchvision.models.vision_transformer # Requires torchvision >=0.12. except ( - RuntimeError, # Different CUDA versions for torch and torchvision - OSError, # Different CUDA versions for torch and torchvision (old) - ImportError, # No installation or old version of torchvision + RuntimeError, # Different CUDA versions for torch and torchvision. + OSError, # Different CUDA versions for torch and torchvision (old). + ImportError, # No installation or old version of torchvision. ): return False - else: - return True + return True @functools.lru_cache(maxsize=1) def timm_vit_available() -> bool: + """Checks if Vision Transformer (ViT) models are available in the timm library. + + This function checks if the `vision_transformer` module and `LayerType` from timm + are available, which requires timm version >= 0.3.3 and >= 0.9.9, respectively. + + Returns: + True if the Vision Transformer (ViT) models are available in timm, + otherwise False. + + """ try: import timm.models.vision_transformer # Requires timm >= 0.3.3 from timm.layers import LayerType # Requires timm >= 0.9.9 except ImportError: return False - else: - return True + return True diff --git a/lightly/utils/dist.py b/lightly/utils/dist.py index 7292afaca..5143f3597 100644 --- a/lightly/utils/dist.py +++ b/lightly/utils/dist.py @@ -8,19 +8,19 @@ class GatherLayer(torch.autograd.Function): """Gather tensors from all processes, supporting backward propagation. - This code was taken and adapted from here: + Adapted from the Solo-Learn project: https://github.com/vturrisi/solo-learn/blob/b69b4bd27472593919956d9ac58902a301537a4d/solo/utils/misc.py#L187 """ @staticmethod - def forward(ctx, input: torch.Tensor) -> Tuple[torch.Tensor, ...]: # type: ignore + def forward(ctx: FunctionCtx, input: torch.Tensor) -> Tuple[torch.Tensor, ...]: # type: ignore output = [torch.empty_like(input) for _ in range(dist.get_world_size())] dist.all_gather(output, input) return tuple(output) @staticmethod - def backward(ctx, *grads) -> torch.Tensor: # type: ignore + def backward(ctx: FunctionCtx, *grads: torch.Tensor) -> torch.Tensor: # type: ignore all_gradients = torch.stack(grads) dist.all_reduce(all_gradients) grad_out = all_gradients[dist.get_rank()] @@ -38,7 +38,7 @@ def world_size() -> int: def gather(input: torch.Tensor) -> Tuple[torch.Tensor]: - """Gathers this tensor from all processes. Supports backprop.""" + """Gathers a tensor from all processes and supports backpropagation.""" return GatherLayer.apply(input) # type: ignore[no-any-return] @@ -62,6 +62,9 @@ def eye_rank(n: int, device: Optional[torch.device] = None) -> torch.Tensor: device: Device on which the matrix should be created. + Returns: + A tensor with the appropriate diagonal filled for this rank. + """ rows = torch.arange(n, device=device, dtype=torch.long) cols = rows + rank() * n @@ -74,7 +77,7 @@ def eye_rank(n: int, device: Optional[torch.device] = None) -> torch.Tensor: def rank_zero_only(fn: Callable[..., R]) -> Callable[..., Optional[R]]: - """Decorator that only runs the function on the process with rank 0. + """Decorator to ensure the function only runs on the process with rank 0. Example: >>> @rank_zero_only diff --git a/lightly/utils/embeddings_2d.py b/lightly/utils/embeddings_2d.py index f3e277a71..c10d4bc3a 100644 --- a/lightly/utils/embeddings_2d.py +++ b/lightly/utils/embeddings_2d.py @@ -1,4 +1,4 @@ -""" Transform embeddings to two-dimensional space for visualization. """ +"""Transforms embeddings to two-dimensional space for visualization.""" # Copyright (c) 2020. Lightly AG and its affiliates. # All Rights Reserved @@ -21,13 +21,18 @@ class PCA(object): Number of principal components to keep. eps: Epsilon for numerical stability. + mean: + Mean of the data. + w: + Eigenvectors of the covariance matrix. + """ def __init__(self, n_components: int = 2, eps: float = 1e-10): self.n_components = n_components + self.eps = eps self.mean: Optional[NDArray[np.float32]] = None self.w: Optional[NDArray[np.float32]] = None - self.eps = eps def fit(self, X: NDArray[np.float32]) -> PCA: """Fits PCA to data in X. @@ -37,7 +42,7 @@ def fit(self, X: NDArray[np.float32]) -> PCA: Datapoints stored in numpy array of size n x d. Returns: - PCA object to transform datapoints. + PCA: The fitted PCA object to transform data points. """ X = X.astype(np.float32) @@ -46,7 +51,7 @@ def fit(self, X: NDArray[np.float32]) -> PCA: X = X - self.mean + self.eps cov = np.cov(X.T) / X.shape[0] v, w = np.linalg.eig(cov) - idx = v.argsort()[::-1] + idx = v.argsort()[::-1] # Sort eigenvalues in descending order v, w = v[idx], w[:, idx] self.w = w return self @@ -62,10 +67,13 @@ def transform(self, X: NDArray[np.float32]) -> NDArray[np.float32]: Numpy array of n x p datapoints where p <= d. Raises: - ValueError: If PCA was not fitted before. + ValueError: + If PCA is not fitted before calling this method. + """ if self.mean is None or self.w is None: raise ValueError("PCA not fitted yet. Call fit() before transform().") + X = X.astype(np.float32) X = X - self.mean + self.eps transformed: NDArray[np.float32] = X.dot(self.w)[:, : self.n_components] @@ -77,7 +85,7 @@ def fit_pca( n_components: int = 2, fraction: Optional[float] = None, ) -> PCA: - """Fits PCA to randomly selected subset of embeddings. + """Fits PCA to a randomly selected subset of embeddings. For large datasets, it can be unfeasible to perform PCA on the whole data. This method can fit a PCA on a fraction of the embeddings in order to save @@ -101,8 +109,7 @@ def fit_pca( """ if fraction is not None: if fraction < 0.0 or fraction > 1.0: - msg = f"fraction must be in [0, 1] but was {fraction}." - raise ValueError(msg) + raise ValueError(f"fraction must be in [0, 1] but was {fraction}.") N = embeddings.shape[0] n = N if fraction is None else min(N, int(N * fraction)) diff --git a/lightly/utils/hipify.py b/lightly/utils/hipify.py index 37fbaf8a1..389295d0f 100644 --- a/lightly/utils/hipify.py +++ b/lightly/utils/hipify.py @@ -4,6 +4,8 @@ class bcolors: + """ANSI escape sequences for colored terminal output.""" + HEADER = "\033[95m" OKBLUE = "\033[94m" OKGREEN = "\033[92m" @@ -15,6 +17,18 @@ class bcolors: def print_as_warning(message: str, warning_class: Type[Warning] = UserWarning) -> None: + """Prints a warning message with custom formatting. + + Temporarily overrides the default warning format to apply custom styling, then + restores the original formatting after the warning is printed. + + Args: + message: + The warning message to print. + warning_class: + The type of warning to raise. + + """ old_format = copy.copy(warnings.formatwarning) warnings.formatwarning = _custom_formatwarning warnings.warn(message, warning_class) @@ -28,5 +42,24 @@ def _custom_formatwarning( lineno: int, line: Optional[str] = None, ) -> str: - # ignore everything except the message + """Custom format for warning messages. + + Only the warning message is printed, with additional styling applied. + + Args: + message: + The warning message or warning object. + category: + The warning class. + filename: + The file where the warning originated. + lineno: + The line number where the warning occurred. + line: + The line of code that triggered the warning (if available). + + Returns: + str: The formatted warning message. + + """ return f"{bcolors.WARNING}{message}{bcolors.WARNING}\n" diff --git a/lightly/utils/lars.py b/lightly/utils/lars.py index 315f14559..036178977 100644 --- a/lightly/utils/lars.py +++ b/lightly/utils/lars.py @@ -36,7 +36,6 @@ class LARS(Optimizer): >>> input = torch.Tensor(10) >>> target = torch.Tensor([1.]) >>> loss_fn = lambda input, target: (input - target) ** 2 - >>> # >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9) >>> optimizer.zero_grad() >>> loss_fn(model(input), target).backward() @@ -99,11 +98,10 @@ def __init__( def __setstate__(self, state: Dict[str, Any]) -> None: super().__setstate__(state) - for group in self.param_groups: group.setdefault("nesterov", False) - # Type ignore for overloads is required for Python 3.7 + # Type ignore for overloads is required for Python 3.7. @overload # type: ignore[override] def step(self, closure: None = None) -> None: ... @@ -125,7 +123,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] with torch.enable_grad(): loss = closure() - # exclude scaling for params with 0 weight decay + # Exclude scaling for params with 0 weight decay. for group in self.param_groups: weight_decay = group["weight_decay"] momentum = group["momentum"] @@ -140,7 +138,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] p_norm = torch.norm(p.data) g_norm = torch.norm(p.grad.data) - # lars scaling + weight decay part + # Apply Lars scaling and weight decay. if weight_decay != 0: if p_norm != 0 and g_norm != 0: lars_lr = p_norm / ( @@ -151,7 +149,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] d_p = d_p.add(p, alpha=weight_decay) d_p *= lars_lr - # sgd part + # Apply momentum. if momentum != 0: param_state = self.state[p] if "momentum_buffer" not in param_state: @@ -159,6 +157,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] else: buf = param_state["momentum_buffer"] buf.mul_(momentum).add_(d_p, alpha=1 - dampening) + if nesterov: d_p = d_p.add(buf, alpha=momentum) else: diff --git a/lightly/utils/scheduler.py b/lightly/utils/scheduler.py index f262245da..ef913a78b 100644 --- a/lightly/utils/scheduler.py +++ b/lightly/utils/scheduler.py @@ -32,9 +32,9 @@ def cosine_schedule( """ if step < 0: - raise ValueError(f"Current step number {step} can't be negative") + raise ValueError(f"Current step number {step} can't be negative.") if max_steps < 1: - raise ValueError(f"Total step number {max_steps} must be >= 1") + raise ValueError(f"Total step number {max_steps} must be >= 1.") if period is None and step > max_steps: warnings.warn( f"Current step number {step} exceeds max_steps {max_steps}.", @@ -102,9 +102,9 @@ def cosine_warmup_schedule( Cosine decay value. """ if warmup_steps < 0: - raise ValueError(f"Warmup steps {warmup_steps} can't be negative") + raise ValueError(f"Warmup steps {warmup_steps} can't be negative.") if warmup_steps > max_steps: - raise ValueError(f"Warmup steps {warmup_steps} must be <= max_steps") + raise ValueError(f"Warmup steps {warmup_steps} must be <= max_steps.") if step > max_steps: warnings.warn( f"Current step number {step} exceeds max_steps {max_steps}.", @@ -157,7 +157,7 @@ class CosineWarmupScheduler(torch.optim.lr_scheduler.LambdaLR): Target learning rate for warmup. Defaults to start_value. Note: The `epoch` arguments do not necessarily have to be epochs. Any step or index - can be used. The naming follows the Pytorch convention to use `epoch` for the steps + can be used. The naming follows the PyTorch convention to use `epoch` for the steps in the scheduler. """ @@ -181,6 +181,7 @@ def __init__( self.period = period self.warmup_start_value = warmup_start_value self.warmup_end_value = warmup_end_value + super().__init__( optimizer=optimizer, lr_lambda=self.scale_lr, @@ -189,8 +190,7 @@ def __init__( ) def scale_lr(self, epoch: int) -> float: - """ - Scale learning rate according to the current epoch number. + """Scale learning rate according to the current epoch number. Args: epoch: