Skip to content

Commit

Permalink
Added the Amplitude Rescaling Transform
Browse files Browse the repository at this point in the history
  • Loading branch information
payo101 committed Oct 12, 2024
1 parent 56e669a commit 3082e12
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 0 deletions.
1 change: 1 addition & 0 deletions lightly/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# All Rights Reserved

from lightly.transforms.aim_transform import AIMTransform
from lightly.transforms.amplitude_rescale_transform import AmplitudeRescaleTranform
from lightly.transforms.byol_transform import (
BYOLTransform,
BYOLView1Transform,
Expand Down
42 changes: 42 additions & 0 deletions lightly/transforms/amplitude_rescale_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import Tuple

import numpy as np
import torch
from torch import Tensor


class AmplitudeRescaleTranform:
"""
This transform will rescale the amplitude of the Fourier Spectrum (`input`) of the image and return it.
The scaling value *p* will range within `[m, n)`
```
img = torch.randn(3, 64, 64)
rfft = lightly.transforms.RFFT2DTransform()
rfft_img = rfft(img)
art = AmplitudeRescaleTransform()
rescaled_img = art(rfft_img)
```
# Intial Arguments
**range**: *Tuple of float_like*
The low `m` and high `n` values such that **p belongs to [m, n)**.
# Parameters:
**input**: _torch.Tensor_
The 2D Discrete Fourier Tranform of an Image.
# Returns:
**output**:_torch.Tensor_
The Fourier spectrum of the 2D Image with rescaled Amplitude.
"""

def __init__(self, range: Tuple[float, float] = (0.8, 1.75)) -> None:
self.m = range[0]
self.n = range[1]

def __call__(self, input: Tensor) -> Tensor:
p = np.random.uniform(self.m, self.n)

output = input * p

return output
21 changes: 21 additions & 0 deletions tests/transforms/test_amplitude_rescale_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import numpy as np
import torch

from lightly.transforms import AmplitudeRescaleTranform, RFFT2DTransform


# Testing function image -> FFT -> AmplitudeRescale.
# Compare shapes of source and result.
def test() -> None:
image = torch.randn(3, 64, 64)

rfftTransform = RFFT2DTransform()
rfft = rfftTransform(image)

ampRescaleTf_1 = AmplitudeRescaleTranform()
rescaled_rfft_1 = ampRescaleTf_1(rfft)

ampRescaleTf_2 = AmplitudeRescaleTranform(range=(1.0, 2.0))
rescaled_rfft_2 = ampRescaleTf_2(rfft)

assert rescaled_rfft_1.shape == rfft.shape and rescaled_rfft_2.shape == rfft.shape

0 comments on commit 3082e12

Please sign in to comment.