This package implements Stochastic Gradient Langevin Dynamics (SGLD) and cyclical SGLD (cSGLD) as a PyTorch Optimizer.
Install from pip:
pip install torch-sgld
To install the latest version directly from source, run:
pip install git+https://github.com/activatedgeek/torch-sgld.git
The general idea is to use the SGLD optimizer in place of the usual optimizer in a PyTorch gradient-based update loop:
from torch_sgld import SGLD

f = module()  ## Construct a PyTorch nn.Module.

sgld = SGLD(f.parameters(), lr=lr, momentum=.9)  ## Add momentum to make it SG-HMC.
sgld_scheduler = ...  ## Optionally add a step-size scheduler.

for _ in range(num_steps):
    sgld.zero_grad()  ## Clear gradients left over from the previous step.
    energy = f()
    energy.backward()

    sgld.step()

    sgld_scheduler.step()  ## Optional scheduler step.
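For intuition about what each optimizer step is doing, here is a minimal, dependency-free sketch of the textbook Langevin update on a one-dimensional toy energy. It is not this package's implementation: the helper `sgld_sample`, its arguments, and the noise scale `sqrt(2 * lr)` are assumptions chosen for illustration, and the gradient is exact rather than a minibatch estimate as it would be in SGLD proper.

```python
import math
import random


def sgld_sample(grad_energy, x0, lr, num_steps, seed=0):
    """Run a Langevin chain (hypothetical helper, not part of torch_sgld).

    Each step moves against the energy gradient and adds Gaussian noise
    with variance 2 * lr, so the chain approximately targets the density
    proportional to exp(-energy(x)).
    """
    rng = random.Random(seed)
    x = x0
    samples = []
    for _ in range(num_steps):
        noise = rng.gauss(0.0, 1.0)
        x = x - lr * grad_energy(x) + math.sqrt(2.0 * lr) * noise
        samples.append(x)
    return samples


# Toy target: energy(x) = x**2 / 2 (a standard normal), so the gradient is x.
samples = sgld_sample(grad_energy=lambda x: x, x0=3.0, lr=0.05, num_steps=20000)
kept = samples[1000:]  # discard the initial burn-in portion of the chain
mean = sum(kept) / len(kept)
var = sum((s - mean) ** 2 for s in kept) / len(kept)
print(f"mean={mean:.3f}, var={var:.3f}")
```

With a small step size, the empirical mean and variance of the retained samples should land close to those of the target distribution (0 and 1 here); a larger step size trades accuracy for faster mixing.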
cSGLD can be implemented by using a cyclical learning rate schedule.
See the toy_csgld.ipynb notebook for a complete example.
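As a rough illustration of what a cyclical schedule looks like, the sketch below computes a cosine step size that restarts at the beginning of every cycle. The function name `cyclical_cosine_lr` and its parameters are hypothetical, and the notebook may use a different schedule or different settings.

```python
import math


def cyclical_cosine_lr(step, base_lr, total_steps, num_cycles):
    """Cosine step size that restarts every cycle (hypothetical helper).

    Within a cycle the step size decays from base_lr towards zero; at the
    start of the next cycle it jumps back to base_lr.
    """
    cycle_len = math.ceil(total_steps / num_cycles)
    position = (step % cycle_len) / cycle_len  # fraction of the cycle done
    return 0.5 * base_lr * (math.cos(math.pi * position) + 1.0)


# Example: 1000 steps split into 4 cycles of 250 steps each.
schedule = [cyclical_cosine_lr(k, 0.1, 1000, 4) for k in range(1000)]
```

One way to use such a function with a PyTorch optimizer would be `torch.optim.lr_scheduler.LambdaLR`, which expects a multiplicative factor on the initial learning rate; in that case `base_lr=1.0` would be passed so the function returns a factor in the range 0 to 1.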
Apache 2.0