Skip to content

Commit

Permalink
Reworked RandomFrequencyMask logic
Browse files Browse the repository at this point in the history
  • Loading branch information
payo101 committed Oct 9, 2024
1 parent e8cfc24 commit bce2003
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions lightly/transforms/random_frequency_mask_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,17 @@ def __call__(self, fft_image: Tensor) -> Tensor:
k = np.random.uniform(low=self.k[0], high=self.k[1])

# Every mask for every channel will have same frequencies being turned off i.e. being set to zero
mask_type = (
mask = (
torch.rand(fft_image.shape[1:], device=fft_image.device) > k
) # mask_type: (H, W)
mask_list = [] # Mask: (C, H, W)

for c in range(fft_image.size(dim=0)):
mask_list.append(mask_type)
# Do not mask zero frequency mode to retain majority of the semantic information.
# Please refer https://arxiv.org/abs/2312.02205
mask[0, 0] = 1

# Adding channel dimension
mask = mask.unsqueeze(0)

mask = torch.stack(mask_list)
# To retain majority of the semantic information. Please refer https://arxiv.org/abs/2312.02205
mask[:, 0, 0] = 1
masked_frequency_spectrum_image = fft_image * mask

return masked_frequency_spectrum_image

0 comments on commit bce2003

Please sign in to comment.