diff --git a/setup.py b/setup.py
index 1abc8bb..7d594ba 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
   name = 'stylegan2_pytorch',
   packages = find_packages(),
   scripts=['bin/stylegan2_pytorch'],
-  version = '0.17.2',
+  version = '0.17.4',
   license='GPLv3+',
   description = 'StyleGan2 in Pytorch',
   author = 'Phil Wang',
diff --git a/stylegan2_pytorch/stylegan2_pytorch.py b/stylegan2_pytorch/stylegan2_pytorch.py
index 60b6922..1fc7b7c 100644
--- a/stylegan2_pytorch/stylegan2_pytorch.py
+++ b/stylegan2_pytorch/stylegan2_pytorch.py
@@ -87,6 +87,13 @@ def forward(self, x):
         out = out.permute(0, 3, 1, 2)
         return out, loss
 
+# one layer of self-attention and feedforward, for images
+
+attn_and_ff = lambda chan: nn.Sequential(*[
+    Residual(Rezero(ImageLinearAttention(chan))),
+    Residual(Rezero(nn.Sequential(nn.Conv2d(chan, chan * 2, 1), leaky_relu(), nn.Conv2d(chan * 2, chan, 1))))
+])
+
 # helpers
 
 def default(value, d):
@@ -424,9 +431,7 @@ def __init__(self, image_size, latent_dim, network_capacity = 16, transparent =
             not_last = ind != (self.num_layers - 1)
             num_layer = self.num_layers - ind
 
-            attn_fn = nn.Sequential(*[
-                Residual(Rezero(ImageLinearAttention(in_chan))) for _ in range(2)
-            ]) if num_layer in attn_layers else None
+            attn_fn = attn_and_ff(in_chan) if num_layer in attn_layers else None
 
             self.attns.append(attn_fn)
 
@@ -484,9 +489,7 @@ def __init__(self, image_size, network_capacity = 16, fq_layers = [], fq_dict_si
             block = DiscriminatorBlock(in_chan, out_chan, downsample = is_not_last)
             blocks.append(block)
 
-            attn_fn = nn.Sequential(*[
-                Residual(Rezero(ImageLinearAttention(out_chan))) for _ in range(2)
-            ]) if num_layer in attn_layers else None
+            attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None
 
             attn_blocks.append(attn_fn)
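
The new attn_and_ff helper replaces the previous two stacked attention residuals with one attention layer followed by a pointwise-conv feedforward, each wrapped in Residual(Rezero(...)), and it is now shared by both Generator and Discriminator. A minimal sketch of how such a block behaves is below; note the assumptions: Residual, Rezero, and leaky_relu are simplified stand-ins for the repository's helpers, and a 1x1 convolution is substituted for ImageLinearAttention purely to keep the sketch self-contained.

import torch
from torch import nn

def leaky_relu(p = 0.2):
    return nn.LeakyReLU(p)

class Residual(nn.Module):
    # stand-in: adds the module's output back onto its input
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x):
        return self.fn(x) + x

class Rezero(nn.Module):
    # stand-in: scales the module's output by a learnable gate initialized at zero,
    # so the branch starts as an identity and is learned gradually
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.zeros(1))
    def forward(self, x):
        return self.fn(x) * self.g

# stand-in for ImageLinearAttention(chan), just to make the sketch runnable
attention = lambda chan: nn.Conv2d(chan, chan, 1)

# same structure as the attn_and_ff helper added in the diff:
# residual attention, then a residual 1x1-conv feedforward that expands and re-projects channels
attn_and_ff = lambda chan: nn.Sequential(*[
    Residual(Rezero(attention(chan))),
    Residual(Rezero(nn.Sequential(nn.Conv2d(chan, chan * 2, 1), leaky_relu(), nn.Conv2d(chan * 2, chan, 1))))
])

block = attn_and_ff(64)
x = torch.randn(2, 64, 32, 32)   # (batch, channels, height, width)
print(block(x).shape)            # torch.Size([2, 64, 32, 32]) -- the feature-map shape is preserved

Because the block is shape-preserving, it can be dropped between any generator or discriminator blocks whose layer index appears in attn_layers, which is how the diff uses it.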