This project aims to implement the neural network architecture described in Krishnan et al. (2021), TextStyleBrush.
Our implementation is unofficial and may differ from the original implementation in some respects. You can also find a link to the slides from the project presentation.
- Install the requirements: `pip install -r requirements.txt`
- Choose a config file in the `src/config` folder.
- Log in to wandb if needed: `wandb login`
- Download the `models` folder from the cloud. This folder contains all the pretrained models we use and should be placed in the root folder, as shown below.
- Download the IMGUR5K dataset: use the original `download_imgur5k.py` script, which you can find here. Cloning the whole original repo will make this easier. Tip: there is a PR that downloads the images in parallel.
- Copy the `prepare_dataset.py` script into that repo and run it to preprocess the files the same way we did.
- Put the prepared dataset in the `data/` folder of this project.
- Run `python3 run.py './src/config/<chosen config>'`. In most cases: `python3 run.py './src/config/stylegan_adversarial.py'`.
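`run.py` takes the path to a Python config file as its argument. One standard way to import a Python file by its path is `importlib`; the sketch below is a hypothetical illustration of that idea (the function name `load_config_module` is ours), not necessarily how `run.py` does it.

```python
import importlib.util
from pathlib import Path
from types import ModuleType


def load_config_module(path: str) -> ModuleType:
    """Import a Python config file, given its path, as a module object (hypothetical helper)."""
    config_path = Path(path).resolve()
    # Use the file name without ".py" as the module name, e.g. "stylegan_adversarial".
    spec = importlib.util.spec_from_file_location(config_path.stem, str(config_path))
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a config from {config_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # executes the config file's top-level code
    return module
```

For example, `load_config_module('./src/config/stylegan_adversarial.py')` would return a module whose top-level names (model, losses, coefficients, and so on) are available as attributes.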
├── run.py <- [entry point]
│
├── prepare_dataset.py <- [our preprocess of images]
│
├── requirements.txt <- [necessary requirements]
│
├── data <- [necessary data (including downloaded datasets)]
│
├── docs <- [docs and images]
│
├── models <- [pretrained models -- download this folder from cloud]
│
├── src <- [project source code]
│ ├── config
│ │ ├── simple.py <- [Template Config]
│ │ ├── gan.py
│ │ ├── ...
│ │
│ ├── data
│ │ ├── simple.py <- [Template CustomDataset]
│ │ ├── ...
│ │
│ ├── disk
│ │ ├── disk.py <- [Disk class to upload and download data from cloud]
│ │ ├── ...
│ │
│ ├── logger
│ │ ├── simple.py <- [Logger class to log train and validation process]
│ │ ├── ...
│ │
│ ├── losses
│ │ ├── ocr.py <- [Recognizer Loss]
│ │ ├── perceptual.py
│ │ ├── ...
│ │
│ ├── metrics
│ │ ├── accuracy.py <- [Accuracy Metric]
│ │ ├── ...
│ │
│ ├── models
│ │ ├── ocr.py <- [Model for CTC Loss]
│ │ ├── ...
│ │
│ ├── storage
│ │ ├── simple.py <- [Storage class to save models' checkpoints]
│ │ ├── ...
│ │
│ ├── training
│ │ ├── simple.py <- [Template Trainer]
│ │ ├── stylegan.py
│ │ ├── ...
│ │
│ ├── utils
│ │ ├── download.py <- [Tool to download data from remote to cloud]
│ │ ├── ...
│ │
│ ├── ...
We started our work from a very simple architecture, shown below:
We call it the baseline; you can find its config here. We built it because we could, and because we needed something to set up our workspace.
In the end, we arrived at this architecture, which is very similar to TextStyleBrush:
You can find its config here. It's not perfect, but we did our best; you can check out the results below.
Before you do, note the following differences from the original paper:
Component | Ours | TextStyleBrush |
---|---|---|
Generator | StyleGAN | StyleGAN2 |
Encoders | ResNet18 | ResNet34 |
Style loss model | VGG16 | VGG19 |
Input style size (px) | 64 x 192 | 256 x 256 |
Input content size (px) | 64 x 192 | 64 x 256 |
Soft masks | no | yes |
Adversarial loss | MSE | non-saturating loss with regularization |
Discriminator | NLayerDiscriminator | unknown |
Text recognizer | TRBA | unknown |
Hardware | Google Colab resources :) | 8 GPUs with 16 GB of RAM |
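To make the "Adversarial loss" row concrete, here is a minimal sketch of both variants on plain lists of discriminator outputs. It assumes the MSE variant uses least-squares targets of 1.0 for real and 0.0 for fake images (as in pix2pix-style `NLayerDiscriminator` setups), and it omits the regularization term of the non-saturating loss entirely. It illustrates the formulas only and is not our training code.

```python
import math
from typing import Sequence


def _mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


# MSE (least-squares) adversarial loss: targets are 1.0 for real and 0.0 for fake.
def mse_discriminator_loss(real_scores: Sequence[float], fake_scores: Sequence[float]) -> float:
    real_term = _mean([(s - 1.0) ** 2 for s in real_scores])
    fake_term = _mean([s ** 2 for s in fake_scores])
    return 0.5 * (real_term + fake_term)


def mse_generator_loss(fake_scores: Sequence[float]) -> float:
    # The generator wants its fakes to be scored as real (target 1.0).
    return _mean([(s - 1.0) ** 2 for s in fake_scores])


# Non-saturating loss on raw discriminator logits (regularization term omitted).
def nonsaturating_discriminator_loss(real_logits: Sequence[float], fake_logits: Sequence[float]) -> float:
    # -log(sigmoid(real)) - log(1 - sigmoid(fake)) == softplus(-real) + softplus(fake)
    return _mean([softplus(-z) for z in real_logits]) + _mean([softplus(z) for z in fake_logits])


def nonsaturating_generator_loss(fake_logits: Sequence[float]) -> float:
    # -log(sigmoid(fake)) == softplus(-fake): gradients stay large even when D rejects fakes
    return _mean([softplus(-z) for z in fake_logits])
```

In PyTorch, the same ideas map onto `nn.MSELoss` against constant target tensors and `torch.nn.functional.softplus` on logits.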
We trained our model on the Imgur5K dataset. You can download it by following the instructions in the original repo.
What we did: we downloaded the original repo from the link above and modified `download_imgur5k.py` a little. We added the ability to resume the download from the point where it stopped after an exception, and the ability to run it in parallel. We do not publish this version because we were afraid of conflicts with their license. You can make these changes yourself or use the code from the PR in the official repo.
After that, we added `prepare_dataset.py` to that folder and ran it. The output of this script is the dataset we use. Put it into the `data/` folder of this project and you are ready to go.
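The two download changes mentioned above, resuming after a failure and running downloads in parallel, can be sketched as follows. This is a hypothetical, simplified illustration, not the modified `download_imgur5k.py` and not the code from the PR: the function name `download_missing`, the id-to-URL input, the injected `fetch` function and the `<id>.jpg` naming are all assumptions.

```python
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, Dict, List


def download_missing(
    items: Dict[str, str],          # image id -> URL (hypothetical input format)
    out_dir: str,
    fetch: Callable[[str], bytes],  # any function returning the bytes at a URL
    max_workers: int = 8,
) -> List[str]:
    """Download every image that is not yet on disk; return the ids that failed."""
    os.makedirs(out_dir, exist_ok=True)

    def target(image_id: str) -> str:
        return os.path.join(out_dir, f"{image_id}.jpg")

    # Resume: skip files already saved by a previous (possibly interrupted) run.
    pending = {i: u for i, u in items.items() if not os.path.exists(target(i))}

    def download_one(image_id: str, url: str) -> None:
        data = fetch(url)
        tmp = target(image_id) + ".part"
        with open(tmp, "wb") as f:
            f.write(data)
        os.replace(tmp, target(image_id))  # rename only when complete: no half-written .jpg

    failed = []
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = {pool.submit(download_one, i, u): i for i, u in pending.items()}
        for future in as_completed(futures):
            if future.exception() is not None:  # one failure does not stop the others
                failed.append(futures[future])
    return sorted(failed)
```

Re-running the function after an interruption only fetches the missing files, which is what makes the download resumable.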
We did our best to make class names speak for themselves. Still, here is a short intro:
- The `Config` class, stored in `src/config`, contains the experiment configuration: model, loss functions, coefficients, optimizer, dataset info, tools, etc.
- The `Trainer` class, stored in `src/training`, contains the experiment's training process: training and validation steps, loss computation and backpropagation, etc.
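The split between `Config` (what to train with) and `Trainer` (how to train) can be illustrated with a toy example. Everything below is hypothetical: the fields, the `compute_losses` method and the list-based losses stand in for the real PyTorch models and loss modules, and only show how a trainer might combine the losses and coefficients declared in a config.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, Sequence

LossFn = Callable[[Sequence[float], Sequence[float]], float]


def l1_loss(prediction: Sequence[float], target: Sequence[float]) -> float:
    return sum(abs(p - t) for p, t in zip(prediction, target)) / len(prediction)


def mse_loss(prediction: Sequence[float], target: Sequence[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(prediction)


@dataclass
class Config:
    """Hypothetical config: declares *what* to optimize (named losses and their coefficients)."""
    losses: Dict[str, LossFn]
    coefficients: Dict[str, float] = field(default_factory=dict)  # missing name -> weight 1.0


class Trainer:
    """Hypothetical trainer: decides *how* the configured losses are combined."""

    def __init__(self, config: Config) -> None:
        self.config = config

    def compute_losses(self, prediction: Sequence[float], target: Sequence[float]) -> Dict[str, float]:
        values = {name: fn(prediction, target) for name, fn in self.config.losses.items()}
        total = sum(self.config.coefficients.get(name, 1.0) * value for name, value in values.items())
        values["total"] = total  # the quantity that would be backpropagated
        return values
```

Keeping the coefficients in the config means an experiment can be reweighted without touching the training code.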
We use Yandex.Disk with 1 TB of storage to store the dataset, logs and checkpoints. The main reason is that we had free access to this service.
We understand that this setup is not user-friendly for others and will come up with a solution soon. For now, you can comment out the `Disk` class in the code and manually download the necessary datasets into the `data` folder.
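Instead of commenting the cloud storage out, one option is a local stand-in that keeps "remote" files in a local folder. The sketch below is hypothetical: the class name `LocalDisk` and the `upload`/`download` method names and signatures are assumptions and are not guaranteed to match the real `Disk` interface in `src/disk`.

```python
import shutil
from pathlib import Path


class LocalDisk:
    """Hypothetical local replacement for the cloud Disk class: 'remote' paths live under `root`."""

    def __init__(self, root: str = "data") -> None:
        self.root = Path(root)

    def upload(self, local_path: str, remote_path: str) -> None:
        destination = self.root / remote_path
        destination.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(local_path, destination)  # copy file contents and metadata

    def download(self, remote_path: str, local_path: str) -> None:
        source = self.root / remote_path
        if not source.exists():
            raise FileNotFoundError(f"{source} not found; download it manually into {self.root}")
        Path(local_path).parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(source, local_path)
```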
- PyTorch framework
- Python 3.7.13
- Type Annotations
- CI tests
- To our supervisors, Renat Khizbullin and Danil Kononyhin from Huawei Technologies, for their support and leadership in this project.