Skip to content

Commit

Permalink
Added PR requested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
payo101 committed Oct 9, 2024
1 parent 1136690 commit e8cfc24
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 17 deletions.
4 changes: 3 additions & 1 deletion lightly/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
from lightly.transforms.moco_transform import MoCoV1Transform, MoCoV2Transform
from lightly.transforms.msn_transform import MSNTransform, MSNViewTransform
from lightly.transforms.pirl_transform import PIRLTransform
from lightly.transforms.random_frequency_mask_transform import RFMTransform
from lightly.transforms.random_frequency_mask_transform import (
RandomFrequencyMaskTransform,
)
from lightly.transforms.rfft2d_transform import RFFT2DTransform
from lightly.transforms.rotation import (
RandomRotate,
Expand Down
30 changes: 16 additions & 14 deletions lightly/transforms/random_frequency_mask_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
from torch import Tensor


class RFMTransform:
class RandomFrequencyMaskTransform:
"""2D Random Frequency Mask Transformation.
This transformation applies a binary mask on the fourier transform,
across all channels. k% of frequencies are set to 0 with this.
k ranges [0.01, 0.1)
across all channels. A proportion k of the frequencies is set to 0 with this.
Input
- Tensor: RFFT of a 2D Image (C, H, W) C-> No. of Channels
Expand All @@ -19,21 +18,24 @@ class RFMTransform:
"""

def __call__(self, fft_image: Tensor) -> Tensor:
k = np.random.uniform(0.01, 0.1)
# Mask: (C, H, W)
mask = torch.ones_like(fft_image)
def __init__(self, k: Tuple[float, float] = (0.01, 0.2)) -> None:
self.k = k

total_frequencies = torch.numel(fft_image[0])
num_frequencies_zeroed = int(total_frequencies * k)
zero_frequency_idxs = torch.randperm(total_frequencies)[:num_frequencies_zeroed]
def __call__(self, fft_image: Tensor) -> Tensor:
k = np.random.uniform(low=self.k[0], high=self.k[1])

# The mask for every channel has the same frequencies turned off, i.e. set to zero
for c in range(mask.size(dim=0)):
mask[c].view(-1)[zero_frequency_idxs] = 0
# To retain the majority of the semantic information. Please refer to https://arxiv.org/abs/2312.02205
mask[c][0][0] = 1
mask_type = (
torch.rand(fft_image.shape[1:], device=fft_image.device) > k
) # mask_type: (H, W)
mask_list = [] # Mask: (C, H, W)

for c in range(fft_image.size(dim=0)):
mask_list.append(mask_type)

mask = torch.stack(mask_list)
# To retain the majority of the semantic information. Please refer to https://arxiv.org/abs/2312.02205
mask[:, 0, 0] = 1
masked_frequency_spectrum_image = fft_image * mask

return masked_frequency_spectrum_image
4 changes: 2 additions & 2 deletions tests/transforms/test_random_frequency_mask_transform.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import torch

from lightly.transforms import RFFT2DTransform, RFMTransform
from lightly.transforms import RandomFrequencyMaskTransform, RFFT2DTransform


def test() -> None:
rfm_transform = RFMTransform()
rfm_transform = RandomFrequencyMaskTransform()
rfft2d_transform = RFFT2DTransform()
image = torch.randn(3, 64, 64)
fft_image = rfft2d_transform(image)
Expand Down

0 comments on commit e8cfc24

Please sign in to comment.