Yet another PyTorch implementation of the ACM Multimedia 2018 paper on Semantic Human Matting.
All dependencies are listed in the Pipfile. You can install them using pipenv:

```sh
$ pipenv install
```
This repository depends on the PSPNet implementation from https://github.com/hszhao/semseg. You will need to download the ResNet model `resnet50_v2.pth` from the `initmodel` directory from this google drive link and place it in `data/models`.
This repository expects training data in the form of raw images and alpha mattes placed in the `data/images` and `data/mattes` folders respectively.
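Assuming matte filenames mirror the image filenames (an assumption — check the data loader for the exact convention), the expected layout would look like:

```
data/
├── images/   # raw training images
├── mattes/   # corresponding ground-truth alpha mattes
└── models/   # pretrained weights, e.g. resnet50_v2.pth
```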
If you're considering pre-training the TNet separately, you will need target trimaps for training. To generate them, simply run the `generate_trimap.py` script located in the `data` directory, with a list of all files to be converted in `images.txt`. This will create trimaps in `data/trimaps` which can be used while pre-training the model.
```sh
$ cd data
$ python3 generate_trimap.py
```
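The exact procedure used by `generate_trimap.py` is not reproduced here, but a common way to derive a trimap from an alpha matte is morphological dilation and erosion around the mask boundary: the eroded core becomes definite foreground, the dilated band becomes the unknown region. A minimal NumPy sketch (the function names and the `iterations` parameter are illustrative, not the script's actual interface):

```python
import numpy as np

def _dilate(mask, iterations):
    """Binary dilation with a 4-connected structuring element."""
    for _ in range(iterations):
        p = np.pad(mask, 1)  # outside the image counts as background
        mask = (p[1:-1, 1:-1] | p[:-2, 1:-1] | p[2:, 1:-1]
                | p[1:-1, :-2] | p[1:-1, 2:])
    return mask

def _erode(mask, iterations):
    """Binary erosion with a 4-connected structuring element."""
    for _ in range(iterations):
        p = np.pad(mask, 1)
        mask = (p[1:-1, 1:-1] & p[:-2, 1:-1] & p[2:, 1:-1]
                & p[1:-1, :-2] & p[1:-1, 2:])
    return mask

def make_trimap(alpha, iterations=10):
    """Turn an alpha matte (uint8, 0-255) into a trimap:
    255 = definite foreground, 0 = background, 128 = unknown band."""
    fg = alpha > 127
    trimap = np.zeros_like(alpha)
    trimap[_dilate(fg, iterations)] = 128  # widened boundary is "unknown"
    trimap[_erode(fg, iterations)] = 255   # shrunken core is definite fg
    return trimap
```

Running something like this over every file listed in `images.txt` and writing the results to `data/trimaps` matches the layout described above.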
This repository currently assumes that the final mattes in `data/mattes` are also the ground truths for pre-training the MNet. There is no support for using a separate ground truth as of now.
To train the image matting pipeline end-to-end, simply run the `train.py` script:

```sh
$ python3 train.py
```
The training script also supports pre-training of the TNet and MNet. This can easily be done using the `--mode` flag:

```sh
# Pre-train TNet
$ python3 train.py --mode pretrain_tnet

# Pre-train MNet
$ python3 train.py --mode pretrain_mnet
```
For additional options, such as changing hyperparameters or using a GPU, please use the `--help` flag.
To run inference with a trained model, use the `test.py` script. This will automatically choose the best model available:

```sh
$ python3 test.py
```

For additional options, please see the `--help` flag.
Although there are several implementations available for this paper, here are a few key reasons why you might want to consider this repository.

- Minimal dependencies: The only dependencies are `torch` and `torchvision`.
- Correct loss computation: Most other implementations use the L2 loss even though the paper specifically calls for the L1 loss.
- Based on official repositories: The code is based on the official implementations of PSPNet and DIMNet.
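The L1 loss mentioned above is typically implemented, as in Deep Image Matting, with a smooth Charbonnier-style approximation so it remains differentiable at zero. A minimal PyTorch sketch (the function name and `eps` default are illustrative, not this repository's actual API):

```python
import torch

def alpha_prediction_loss(pred, target, eps=1e-6):
    """Differentiable approximation of the L1 alpha loss:
    sqrt(d^2 + eps^2) ~= |d|, but smooth at d = 0."""
    return torch.sqrt((pred - target) ** 2 + eps ** 2).mean()
```

Swapping this in for `torch.nn.MSELoss` is the difference the bullet point above refers to.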
This repository is primarily based on the official implementations of PSPNet (https://github.com/hszhao/semseg) and DIMNet (https://github.com/foamliu/Deep-Image-Matting-PyTorch). Any other attributions are noted in comments at the top of the individual files.