Implementation of the ID-GAN in PyTorch (https://arxiv.org/pdf/2001.04296.pdf)
import id_gan
# Train and save VAE
id_gan.train_vae("mnist", epochs=20, batch_size=512, num_workers=8)
# Train and save GAN based on VAE
id_gan.train_gan("mnist", epochs=60, batch_size=512, num_workers=8)
# Load generator from GAN model
gan = id_gan.load_gan("mnist")
# Sample
with torch.no_grad():
z = torch.randn(16, 40)
images = gan(z).numpy()
The same code as above, but with the parameter celeba
instead of mnist
- Add dataset config to the
id_gan/config.py
- Add function to create dataset to the
id_gan/data.py