Skip to content

Commit

Permalink
Merge branch 'master' into update-data-docs-1674
Browse files Browse the repository at this point in the history
  • Loading branch information
guarin authored Oct 10, 2024
2 parents a881429 + 9da0a24 commit 72e23d0
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 1 deletion.
3 changes: 3 additions & 0 deletions lightly/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
from lightly.transforms.moco_transform import MoCoV1Transform, MoCoV2Transform
from lightly.transforms.msn_transform import MSNTransform, MSNViewTransform
from lightly.transforms.pirl_transform import PIRLTransform
from lightly.transforms.random_frequency_mask_transform import (
RandomFrequencyMaskTransform,
)
from lightly.transforms.rfft2d_transform import RFFT2DTransform
from lightly.transforms.rotation import (
RandomRotate,
Expand Down
2 changes: 1 addition & 1 deletion lightly/transforms/random_crop_and_flip_with_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def forward(
return img, location


class RandomResizedCropAndFlip(nn.Module):
class RandomResizedCropAndFlip(nn.Module): # type: ignore[misc] # Class cannot subclass "RandomResizedCropAndFlip" (has type "Any")
"""Randomly flip and crop an image.
A PyTorch module that applies random cropping, horizontal and vertical flipping to an image,
Expand Down
41 changes: 41 additions & 0 deletions lightly/transforms/random_frequency_mask_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Tuple

import numpy as np
import torch
from torch import Tensor


class RandomFrequencyMaskTransform:
    """2D Random Frequency Mask Transformation.

    Applies a random binary mask to the Fourier transform of an image. On
    every call a masking probability is sampled uniformly from the range
    ``k``, and each frequency is then set to zero independently with that
    probability. A single (H, W) mask is drawn per call and shared by all
    channels, so the same frequencies are zeroed in every channel. The
    zero-frequency element at index [0, 0] is never masked.

    Both the masking probability and the mask itself are drawn from PyTorch's
    global random number generator, so seeding with ``torch.manual_seed`` is
    sufficient to make the transform reproducible.

    Input:
        - Tensor: RFFT of a 2D image with shape (C, H, W), where C is the
          number of channels.

    Output:
        - Tensor: The masked RFFT of the image, with the same shape as the
          input.

    Attributes:
        k: (lower, upper) bounds of the range from which the per-call masking
            probability is sampled.
    """

    def __init__(self, k: Tuple[float, float] = (0.01, 0.1)) -> None:
        self.k = k

    def __call__(self, fft_image: Tensor) -> Tensor:
        """Return ``fft_image`` with a random subset of frequencies set to zero.

        Args:
            fft_image: Frequency spectrum of shape (C, H, W).

        Returns:
            A new tensor of the same shape as ``fft_image``; the input is not
            modified in place.
        """
        low, high = self.k[0], self.k[1]
        # Sample the masking probability with torch (not numpy) so that the
        # whole transform depends on a single random number generator.
        k = low + (high - low) * torch.rand(()).item()

        # One (H, W) mask for all channels: True keeps a frequency, False
        # sets it to zero.
        mask = torch.rand(fft_image.shape[1:], device=fft_image.device) > k

        # Do not mask the zero frequency mode to retain the majority of the
        # semantic information. See https://arxiv.org/abs/2312.02205
        mask[0, 0] = True

        # Add the channel dimension so the mask broadcasts over all channels.
        return fft_image * mask.unsqueeze(0)
13 changes: 13 additions & 0 deletions tests/transforms/test_random_frequency_mask_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import torch

from lightly.transforms import RandomFrequencyMaskTransform, RFFT2DTransform


def test() -> None:
    """Check the basic contract of RandomFrequencyMaskTransform.

    The output must keep the input shape, preserve the zero-frequency term of
    every channel, and leave each remaining coefficient either unchanged or
    set to zero.
    """
    rfm_transform = RandomFrequencyMaskTransform()
    rfft2d_transform = RFFT2DTransform()
    image = torch.randn(3, 64, 64)
    fft_image = rfft2d_transform(image)
    transformed_image = rfm_transform(fft_image)

    assert transformed_image.shape == fft_image.shape

    # The element at index [0, 0] of every channel is never masked.
    assert torch.equal(transformed_image[:, 0, 0], fft_image[:, 0, 0])

    # Masking only zeroes coefficients; it never alters the surviving ones.
    unchanged = transformed_image == fft_image
    zeroed = transformed_image == 0
    assert bool((unchanged | zeroed).all())

0 comments on commit 72e23d0

Please sign in to comment.