Skip to content

Commit

Permalink
Implementing requested changes on GMM
Browse files Browse the repository at this point in the history
  • Loading branch information
snehilchatterjee committed Oct 14, 2024
1 parent 8366a47 commit c9fe57a
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 24 deletions.
2 changes: 1 addition & 1 deletion lightly/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from lightly.transforms.dino_transform import DINOTransform, DINOViewTransform
from lightly.transforms.fast_siam_transform import FastSiamTransform
from lightly.transforms.gaussian_blur import GaussianBlur
from lightly.transforms.gaussian_mixture_masks_transform import GaussianMixtureMasks
from lightly.transforms.gaussian_mixture_masks_transform import GaussianMixtureMask
from lightly.transforms.irfft2d_transform import IRFFT2DTransform
from lightly.transforms.jigsaw import Jigsaw
from lightly.transforms.mae_transform import MAETransform
Expand Down
54 changes: 31 additions & 23 deletions lightly/transforms/gaussian_mixture_masks_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,21 @@
from lightly.transforms.rfft2d_transform import RFFT2DTransform


class GaussianMixtureMasks:
"""Applies a Gaussian Mixture Mask in the Fourier domain to RGB images.
class GaussianMixtureMask:
"""Applies a Gaussian Mixture Mask in the Fourier domain to a single-channel image.
The mask is created using random Gaussian kernels, which are applied in
the frequency domain via RFFT2D, and then the IRFFT2D is used to return
to the spatial domain. The transformation is applied to each RGB channel separately.
to the spatial domain. The transformation is applied to each image channel separately.
Attributes:
num_gaussians: Number of Gaussian kernels to generate in the mixture mask.
std_range: Tuple containing the minimum and maximum standard deviation for the Gaussians.
"""

def __init__(self, num_gaussians: int = 20, std_range: Tuple[int, int] = (10, 15)):
def __init__(
self, num_gaussians: int = 20, std_range: Tuple[float, float] = (10, 15)
):
"""Initializes GaussianMixtureMasks with the given parameters.
Args:
Expand All @@ -29,6 +31,7 @@ def __init__(self, num_gaussians: int = 20, std_range: Tuple[int, int] = (10, 15
"""

self.rfft2d_transform = RFFT2DTransform()

self.num_gaussians = num_gaussians
self.std_range = std_range

Expand All @@ -38,14 +41,16 @@ def gaussian_kernel(
"""Generates a 2D Gaussian kernel.
Args:
size: Tuple specifying the dimensions of the Gaussian kernel (C, H, W).
size: Tuple specifying the dimensions of the Gaussian kernel (H, W).
sigma: Tensor specifying the standard deviation of the Gaussian.
center: Tensor specifying the center of the Gaussian kernel.
Returns:
Tensor: A 2D Gaussian kernel.
"""
u, v = torch.meshgrid(torch.arange(0, size[0]), torch.arange(0, size[1]))
u = u.to(sigma.device)
v = v.to(sigma.device)
u0, v0 = center
gaussian = torch.exp(
-((u - u0) ** 2 / (2 * sigma[0] ** 2) + (v - v0) ** 2 / (2 * sigma[1] ** 2))
Expand All @@ -54,50 +59,53 @@ def gaussian_kernel(
return gaussian

def apply_gaussian_mixture_mask(
self, image_channel: Tensor, num_gaussians: int, std: Tuple[int, int]
self, freq_image: Tensor, num_gaussians: int, std: Tuple[int, int]
) -> Tensor:
"""Applies the Gaussian mixture mask to a single channel in the frequency domain.
"""Applies the Gaussian mixture mask to a frequency-domain image.
Args:
image_channel: Tensor representing a single channel of the image.
freq_image: Tensor representing the frequency-domain image of shape (C, H, W//2+1).
num_gaussians: Number of Gaussian kernels to generate in the mask.
std: Tuple specifying the standard deviation range for the Gaussians.
Returns:
Tensor: Image after applying the Gaussian mixture mask.
"""
image_size = image_channel[0].shape
image_size = freq_image.shape[1:]
original_height = image_size[0]
original_width = 2 * (image_size[1] - 1)

original_shape = (original_height, original_width)

self.irfft2d_transform = IRFFT2DTransform((image_size[0], image_size[1]))
f_transform = self.rfft2d_transform(image_channel)
self.irfft2d_transform = IRFFT2DTransform(original_shape)

size = f_transform[0].shape
size = freq_image[0].shape

mask = torch.ones(size)
mask = freq_image.new_ones(freq_image.shape)

for _ in range(num_gaussians):
u0 = torch.randint(0, size[0], (1,))
v0 = torch.randint(0, size[1], (1,))
center = torch.tensor((u0, v0))
sigma = torch.rand(2) * 5 + 10
u0 = torch.randint(0, size[0], (1,), device=freq_image.device)
v0 = torch.randint(0, size[1], (1,), device=freq_image.device)
center = torch.tensor((u0, v0), device=freq_image.device)
sigma = torch.rand(2, device=freq_image.device) * (std[1] - std[0]) + std[0]

g_kernel = self.gaussian_kernel((size[0], size[1]), sigma, center)
mask -= g_kernel

filtered_f_transform = f_transform * mask
filtered_image = self.irfft2d_transform(filtered_f_transform).abs()
filtered_freq_image = freq_image * mask
filtered_image = self.irfft2d_transform(filtered_freq_image).abs()
return filtered_image

def __call__(self, image_tensor: Tensor) -> Tensor:
"""Applies the Gaussian mixture mask transformation to the input image.
def __call__(self, freq_image: Tensor) -> Tensor:
"""Applies the Gaussian mixture mask transformation to the input frequency-domain image.
Args:
image_tensor: Tensor representing an RGB image of shape (C, H, W).
freq_image: Tensor representing a frequency-domain image of shape (C, H, W//2+1).
Returns:
Tensor: The transformed image after applying the Gaussian mixture mask.
"""
transformed_channel: Tensor = self.apply_gaussian_mixture_mask(
image_tensor, self.num_gaussians, self.std_range
freq_image, self.num_gaussians, self.std_range
)
return transformed_channel

0 comments on commit c9fe57a

Please sign in to comment.