Skip to content

Commit

Permalink
Merge pull request #249 from Netruk44/upstream-configurable-augment
Browse files Browse the repository at this point in the history
Add parameters to color augments & separate out into separate augments.
  • Loading branch information
lucidrains authored Aug 6, 2021
2 parents fc22408 + 1693604 commit 573d250
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions stylegan2_pytorch/diff_augment.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
import random

import torch
import torch.nn.functional as F

Expand All @@ -20,18 +20,18 @@ def DiffAugment(x, types=[]):
# 3 - height of image
# """

def rand_brightness(x):
x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
def rand_brightness(x, scale):
x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) * scale
return x

def rand_saturation(x):
def rand_saturation(x, scale):
x_mean = x.mean(dim=1, keepdim=True)
x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
x = (x - x_mean) * (((torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) * 2.0 * scale) + 1.0) + x_mean
return x

def rand_contrast(x):
def rand_contrast(x, scale):
x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
x = (x - x_mean) * (((torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) * 2.0 * scale) + 1.0) + x_mean
return x

def rand_translation(x, ratio=0.125):
Expand Down Expand Up @@ -93,7 +93,14 @@ def rand_cutout(x, ratio=0.5):
return x

AUGMENT_FNS = {
'color': [rand_brightness, rand_saturation, rand_contrast],
'brightness': [partial(rand_brightness, scale=1.)],
'lightbrightness': [partial(rand_brightness, scale=.65)],
'contrast': [partial(rand_contrast, scale=.5)],
'lightcontrast': [partial(rand_contrast, scale=.25)],
'saturation': [partial(rand_saturation, scale=1.)],
'lightsaturation': [partial(rand_saturation, scale=.5)],
'color': [partial(rand_brightness, scale=1.), partial(rand_saturation, scale=1.), partial(rand_contrast, scale=0.5)],
'lightcolor': [partial(rand_brightness, scale=0.65), partial(rand_saturation, scale=.5), partial(rand_contrast, scale=0.5)],
'offset': [rand_offset],
'offset_h': [rand_offset_h],
'offset_v': [rand_offset_v],
Expand Down

0 comments on commit 573d250

Please sign in to comment.