This repo is a Python implementation of Geometric Dataset Distances via Optimal Transport and Robust Optimal Transport. Routines are implemented in NumPy with Python Optimal Transport (POT) and CVXPY, as well as in PyTorch using KeOps and GeomLoss.
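The Optimal Transport Dataset Distance (OTDD) algorithm incorporates label information into the optimal transport problem, so the distance between two datasets reflects both their features and their labels.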
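As a rough illustration of how label information can enter the transport cost, the sketch below builds a label-aware ground cost for one-dimensional features: the squared feature distance plus the squared 2-Wasserstein distance between the samples that share each label. It is a minimal sketch under stated assumptions, not this repo's implementation: the function names `w2_squared_1d` and `label_aware_cost` and the toy data are hypothetical, features are scalars, and all classes have the same number of samples (in which case the 1-D Wasserstein distance reduces to matching sorted samples).

```python
# Sketch only: 1-D features, equal-size classes, standard library only.
# The names below are hypothetical and are not part of this repo's API.


def w2_squared_1d(a, b):
    """Squared 2-Wasserstein distance between two equal-size 1-D samples.

    For equal-size empirical samples on the real line, the optimal
    coupling pairs the sorted values in order.
    """
    if len(a) != len(b):
        raise ValueError("this sketch assumes equal sample sizes")
    return sum((x - y) ** 2 for x, y in zip(sorted(a), sorted(b))) / len(a)


def label_aware_cost(x1, y1, x2, y2, samples_1, samples_2):
    """Squared ground cost between labelled points (x1, y1) and (x2, y2).

    samples_1 and samples_2 map each label to the feature samples that
    carry it in the first and second dataset respectively.
    """
    feature_term = (x1 - x2) ** 2
    label_term = w2_squared_1d(samples_1[y1], samples_2[y2])
    return feature_term + label_term


# Toy datasets: label -> list of scalar features (made-up values).
dataset_a = {"cat": [0.0, 1.0, 2.0], "dog": [10.0, 11.0, 12.0]}
dataset_b = {"lynx": [0.5, 1.5, 2.5], "wolf": [9.0, 10.0, 11.0]}

# Same pair of features, but different labels on the second point.
cost_similar = label_aware_cost(1.0, "cat", 1.5, "lynx", dataset_a, dataset_b)
cost_different = label_aware_cost(1.0, "cat", 1.5, "wolf", dataset_a, dataset_b)
print(cost_similar, cost_different)
```

The two calls use identical feature values, so any difference in cost comes from the label term alone. A full dataset distance would then solve an optimal transport problem over all pairs of labelled points using a cost of this kind.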
• Algorithm Overview • API • Examples
Core dependencies can be installed from the environment.yml file:
conda env create -f environment.yml
To use the PyTorch implementation, install PyTorch, KeOps and GeomLoss:
conda install pytorch torchvision torchaudio -c pytorch
pip install pykeops
pip install geomloss
Then validate the KeOps installation:
import pykeops
pykeops.clean_pykeops()
pykeops.test_torch_bindings()
To use the cheminformatics functions in chem.py, install RDKit:
conda install -c rdkit rdkit
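After running the installation steps, it can help to confirm which of the core and optional packages are visible to the current Python environment. The snippet below is a small sketch that uses only the standard library. The helper `check_dependencies` is hypothetical (not part of this repo), and the import names are the usual ones for these packages (for example, Python Optimal Transport imports as `ot` and PyTorch as `torch`).

```python
import importlib.util

# Package name -> import name. This mapping is an assumption based on the
# usual import names of these packages, not something defined by this repo.
MODULES = {
    "NumPy": "numpy",
    "Python Optimal Transport": "ot",
    "CVXPY": "cvxpy",
    "PyTorch": "torch",
    "KeOps": "pykeops",
    "GeomLoss": "geomloss",
    "RDKit": "rdkit",
}


def check_dependencies(modules):
    """Return {package name: True/False} without importing the packages."""
    status = {}
    for package, import_name in modules.items():
        try:
            status[package] = importlib.util.find_spec(import_name) is not None
        except (ImportError, ValueError):
            # Treat a broken or half-installed module as unavailable.
            status[package] = False
    return status


status = check_dependencies(MODULES)
for package, available in status.items():
    print(f"{package:26s} {'found' if available else 'missing'}")
```

This only checks that each package can be located; it does not verify that KeOps can compile its kernels, which is what the `pykeops.test_torch_bindings()` call is for.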