add diffaugment script from Zhao's paper, allow user to specify which augmentation types to use from command line
lucidrains committed Sep 24, 2020
1 parent 0cca585 commit e8381fb
Showing 5 changed files with 93 additions and 22 deletions.
20 changes: 19 additions & 1 deletion README.md
@@ -98,7 +98,7 @@ If you have one machine with multiple GPUs, the repository offers a way to utili
You simply have to add a `--multi-gpus` flag; everything else is taken care of. If you would like to restrict training to specific GPUs, you can use the `CUDA_VISIBLE_DEVICES` environment variable to control which devices can be used (e.g. `CUDA_VISIBLE_DEVICES=0,2,3` makes only devices 0, 2, and 3 available).

```bash
-$ stylegan2_pytorch --data /path/to/data --multi-gpus --batch-size 32 --gradient-accumulate-every 1
+$ stylegan2_pytorch --data ./data --multi-gpus --batch-size 32 --gradient-accumulate-every 1
```

## Low amounts of Training Data
@@ -116,6 +116,15 @@ In the setting of low data, you can use the feature with a simple flag.
$ stylegan2_pytorch --data ./data --aug-prob 0.25
```

By default, the augmentations used are `translation` and `cutout`. If you would like to add `color`, you can do so with the `--aug-types` argument.

```bash
# make sure there are no spaces between items!
$ stylegan2_pytorch --data ./data --aug-prob 0.25 --aug-types [translation,cutout,color]
```

You can use any combination of the three. The differentiable augmentation code was copied and slightly modified from <a href="https://github.com/mit-han-lab/data-efficient-gans/blob/master/DiffAugment_pytorch.py">here</a>.

## Attention

This framework also allows you to add an efficient form of self-attention to the designated layers of the discriminator (and the symmetric layer of the generator), which will greatly improve results. The more attention you can afford, the better!
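For example (a sketch: the `attn_layers` option visible in `cli.py` below is exposed on the command line like the other flags, and the list syntax presumably mirrors `--aug-types` above):

```bash
# hypothetical invocation; adds attention at layers 1 and 2 of the discriminator
$ stylegan2_pytorch --data ./data --attn-layers [1,2]
```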
@@ -314,6 +323,15 @@ Thank you to Matthew Mann for his inspiring [simple port](https://github.com/man
}
```

```bibtex
@article{zhao2020diffaugment,
    title = {Differentiable Augmentation for Data-Efficient GAN Training},
    author = {Zhao, Shengyu and Liu, Zhijian and Lin, Ji and Zhu, Jun-Yan and Han, Song},
    journal = {arXiv preprint arXiv:2006.10738},
    year = {2020}
}
```

```bibtex
@misc{zhao2020image,
    title = {Image Augmentations for GAN Training},
2 changes: 1 addition & 1 deletion setup.py
@@ -8,7 +8,7 @@
      'stylegan2_pytorch = stylegan2_pytorch.cli:main',
    ],
  },
-  version = '0.22.2',
+  version = '0.22.3',
  license='GPLv3+',
  description = 'StyleGan2 in Pytorch',
  author = 'Phil Wang',
2 changes: 2 additions & 0 deletions stylegan2_pytorch/cli.py
@@ -92,6 +92,7 @@ def train_from_folder(
    attn_layers = [],
    no_const = False,
    aug_prob = 0.,
+    aug_types = ['translation', 'cutout'],
    dataset_aug_prob = 0.,
    multi_gpus = False
):
@@ -118,6 +119,7 @@
        attn_layers = attn_layers,
        no_const = no_const,
        aug_prob = aug_prob,
+        aug_types = cast_list(aug_types),
        dataset_aug_prob = dataset_aug_prob
    )

59 changes: 59 additions & 0 deletions stylegan2_pytorch/diff_augment.py
@@ -0,0 +1,59 @@
import torch
import torch.nn.functional as F

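# apply, in order, every augmentation function registered under each requested type;
# all ops below are differentiable, so gradients flow through the augmentations
# back to the generator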
def DiffAugment(x, types=[]):
    for p in types:
        for f in AUGMENT_FNS[p]:
            x = f(x)
    return x.contiguous()

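# the color ops jitter each sample independently: brightness adds a uniform shift
# in [-0.5, 0.5), saturation scales distances from the per-pixel channel mean by
# a factor in [0, 2), and contrast scales around the per-image mean by [0.5, 1.5)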
def rand_brightness(x):
    x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
    return x

def rand_saturation(x):
    x_mean = x.mean(dim=1, keepdim=True)
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
    return x

def rand_contrast(x):
    x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
    return x

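# translate each sample by up to (+/-)`ratio` of its height/width: random integer
# offsets are added to a per-sample coordinate grid, which is clamped into a
# one-pixel zero padding, so vacated regions read as black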
def rand_translation(x, ratio=0.125):
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
    )
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
    return x

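# zero out one random square per sample, `ratio` of the image side (half by default);
# coordinates are clamped at the borders, so cutouts near an edge are cropped
# rather than wrapped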
def rand_cutout(x, ratio=0.5):
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
    )
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    x = x * mask.unsqueeze(1)
    return x

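# registry mapping each name accepted by `--aug-types` to its augmentation functions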
AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}
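A minimal sketch of exercising the new module on its own (the batch shape and the types list are arbitrary choices for illustration):

```python
import torch

from stylegan2_pytorch.diff_augment import DiffAugment

images = torch.randn(4, 3, 64, 64, requires_grad=True)  # dummy batch of RGB images

out = DiffAugment(images, types=['translation', 'cutout'])  # the trainer's defaults

assert out.shape == images.shape  # every augmentation preserves the batch shape
out.sum().backward()              # and is differentiable: gradients reach `images`
```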
32 changes: 12 additions & 20 deletions stylegan2_pytorch/stylegan2_pytorch.py
@@ -25,6 +25,7 @@

import torchvision
from torchvision import transforms
+from stylegan2_pytorch.diff_augment import DiffAugment

from vector_quantize_pytorch import VectorQuantize
from linear_attention_transformer import ImageLinearAttention
@@ -166,7 +167,7 @@ def gradient_penalty(images, output, weight = 10):
        grad_outputs=torch.ones(output.size(), device=images.device),
        create_graph=True, retain_graph=True, only_inputs=True)[0]

-    gradients = gradients.view(batch_size, -1)
+    gradients = gradients.reshape(batch_size, -1)
    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
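A plausible reason for the change from `.view` to `.reshape` on the flattened gradients: autograd does not guarantee contiguous gradient tensors, and `.view` raises a `RuntimeError` on non-contiguous input, while `.reshape` falls back to a copy when it must. A minimal illustration:

```python
import torch

t = torch.randn(2, 3, 4).permute(0, 2, 1)  # permuting makes `t` non-contiguous
flat = t.reshape(2, -1)                    # fine: reshape copies when it has to
# flat = t.view(2, -1)                     # would raise a RuntimeError here
```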

def calc_pl_lengths(styles, images):
@@ -294,18 +295,6 @@ def __getitem__(self, index):

# augmentations

-def random_float(lo, hi):
-    return lo + (hi - lo) * random()
-
-def random_crop_and_resize(tensor, scale):
-    b, c, h, _ = tensor.shape
-    new_width = int(h * scale)
-    delta = h - new_width
-    h_delta = int(random() * delta)
-    w_delta = int(random() * delta)
-    cropped = tensor[:, :, h_delta:(h_delta + new_width), w_delta:(w_delta + new_width)].clone()
-    return F.interpolate(cropped, size=(h, h), mode='bilinear')

def random_hflip(tensor, prob):
    if prob > random():
        return tensor
@@ -316,11 +305,10 @@ def __init__(self, D, image_size):
        super().__init__()
        self.D = D

-    def forward(self, images, prob = 0., detach = False):
+    def forward(self, images, prob = 0., types = [], detach = False):
        if random() < prob:
-            random_scale = random_float(0.75, 0.95)
            images = random_hflip(images, prob=0.5)
-            images = random_crop_and_resize(images, scale = random_scale)
+            images = DiffAugment(images, types=types)

        if detach:
            images.detach_()
@@ -673,7 +661,7 @@ def forward(self, x):
        return x

class Trainer():
-    def __init__(self, name, results_dir, models_dir, image_size, network_capacity, transparent = False, batch_size = 4, mixed_prob = 0.9, gradient_accumulate_every=1, lr = 2e-4, lr_mlp = 1., ttur_mult = 2, rel_disc_loss = False, num_workers = None, save_every = 1000, trunc_psi = 0.6, fp16 = False, cl_reg = False, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, aug_prob = 0., dataset_aug_prob = 0., is_ddp = False, rank = 0, world_size = 1, *args, **kwargs):
+    def __init__(self, name, results_dir, models_dir, image_size, network_capacity, transparent = False, batch_size = 4, mixed_prob = 0.9, gradient_accumulate_every=1, lr = 2e-4, lr_mlp = 1., ttur_mult = 2, rel_disc_loss = False, num_workers = None, save_every = 1000, trunc_psi = 0.6, fp16 = False, cl_reg = False, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, aug_prob = 0., aug_types = ['translation', 'cutout'], dataset_aug_prob = 0., is_ddp = False, rank = 0, world_size = 1, *args, **kwargs):
        self.GAN_params = [args, kwargs]
        self.GAN = None
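A sketch of how the new parameter reaches the trainer when constructed directly rather than through the CLI (the first five arguments are required by the signature above; the values here are arbitrary illustrations, and everything omitted keeps its default):

```python
from stylegan2_pytorch.stylegan2_pytorch import Trainer

trainer = Trainer(
    name = 'default',
    results_dir = './results',
    models_dir = './models',
    image_size = 128,
    network_capacity = 16,
    aug_prob = 0.25,                                 # apply DiffAugment 25% of the time
    aug_types = ['translation', 'cutout', 'color'],  # all three supported types
)
```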

@@ -691,7 +679,9 @@ def __init__(self, name, results_dir, models_dir, image_size, network_capacity,

        self.attn_layers = cast_list(attn_layers)
        self.no_const = no_const

        self.aug_prob = aug_prob
+        self.aug_types = aug_types

        self.lr = lr
        self.lr_mlp = lr_mlp
@@ -787,6 +777,8 @@ def train(self):
        num_layers = self.GAN.G.num_layers

        aug_prob = self.aug_prob
+        aug_types = self.aug_types
+        aug_kwargs = {'prob': aug_prob, 'types': aug_types}

        apply_gradient_penalty = self.steps % 4 == 0
        apply_path_penalty = self.steps > 5000 and self.steps % 32 == 0
@@ -838,11 +830,11 @@ def train(self):
            w_styles = styles_def_to_tensor(w_space)

            generated_images = G(w_styles, noise)
-            fake_output, fake_q_loss = D_aug(generated_images.clone().detach(), detach = True, prob = aug_prob)
+            fake_output, fake_q_loss = D_aug(generated_images.clone().detach(), detach = True, **aug_kwargs)

            image_batch = next(self.loader).cuda(self.rank)
            image_batch.requires_grad_()
-            real_output, real_q_loss = D_aug(image_batch, prob = aug_prob)
+            real_output, real_q_loss = D_aug(image_batch, **aug_kwargs)

            real_output_loss = real_output
            fake_output_loss = fake_output
@@ -885,7 +877,7 @@ def train(self):
            w_styles = styles_def_to_tensor(w_space)

            generated_images = G(w_styles, noise)
-            fake_output, _ = D_aug(generated_images, prob = aug_prob)
+            fake_output, _ = D_aug(generated_images, **aug_kwargs)
            loss = fake_output.mean()
            gen_loss = loss

