SGLD in PyTorch

This package implements Stochastic Gradient Langevin Dynamics (SGLD) and cyclical SGLD (cSGLD) as a PyTorch optimizer.
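The following sketch makes the SGLD update concrete. It is plain Python on a 1-D standard Gaussian target and uses one common parameterization: theta <- theta - lr * dU/dtheta + sqrt(2 * lr) * xi, with xi ~ N(0, 1). This is equivalent to Welling and Teh's form with step size epsilon = 2 * lr. It illustrates the algorithm and is not the package's code. The function name `sgld_chain` is hypothetical. Exact gradients stand in for minibatch estimates, and the noise scaling used inside `torch_sgld.SGLD` may differ.

```python
import math
import random


def sgld_chain(grad_energy, theta0, lr, num_steps, seed=0):
    """Run a 1-D SGLD chain targeting p(theta) proportional to exp(-U(theta)).

    Update (unit temperature): theta <- theta - lr * dU/dtheta + sqrt(2 * lr) * xi,
    with xi ~ N(0, 1). Exact gradients are used here; SGLD proper would use a
    minibatch estimate of the gradient.
    """
    rng = random.Random(seed)
    theta = theta0
    samples = []
    for _ in range(num_steps):
        noise = rng.gauss(0.0, 1.0)
        theta = theta - lr * grad_energy(theta) + math.sqrt(2.0 * lr) * noise
        samples.append(theta)
    return samples


# Standard Gaussian target: U(theta) = theta^2 / 2, so dU/dtheta = theta.
samples = sgld_chain(lambda t: t, theta0=3.0, lr=0.05, num_steps=200_000)
kept = samples[1_000:]  # discard burn-in from the starting point theta0 = 3.0
mean = sum(kept) / len(kept)
var = sum((s - mean) ** 2 for s in kept) / len(kept)
print(f"mean={mean:.3f}, var={var:.3f}")  # mean near 0, variance near 1
```

The sampler has no Metropolis correction, so it carries a small discretization bias. For this step size, the sample variance settles slightly above 1.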

Installation

Install from PyPI with pip:

pip install torch-sgld

To install the latest version directly from source, run:

pip install git+https://github.com/activatedgeek/torch-sgld.git

Usage

The general idea is to take the usual gradient-based update loop in PyTorch and use the SGLD optimizer in place of a standard optimizer.

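import torch
from torch_sgld import SGLD

## Hypothetical example module: forward() returns a scalar energy U(theta),
## here a standard Gaussian, U(theta) = ||theta||^2 / 2. Replace with your own.
class Energy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.theta = torch.nn.Parameter(torch.zeros(2))

    def forward(self):
        return 0.5 * self.theta.pow(2).sum()

f = Energy()  ## construct PyTorch nn.Module.
lr, num_steps = 1e-2, 1000  ## illustrative values.

sgld = SGLD(f.parameters(), lr=lr, momentum=.9)  ## Add momentum to make it SG-HMC.
sgld_scheduler = torch.optim.lr_scheduler.StepLR(sgld, step_size=100, gamma=.9)  ## Optional step-size scheduler; StepLR is just an example.

for _ in range(num_steps):
    sgld.zero_grad()  ## Clear gradients left over from the previous step.
    energy = f()
    energy.backward()

    sgld.step()

    sgld_scheduler.step()  ## Optional scheduler step.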
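The usage comment says that adding momentum turns SGLD into SG-HMC. The plain-Python sketch below illustrates what that means. It uses one common SGD-with-momentum style discretization of SGHMC (Chen et al., 2014), with friction = 1 - momentum and the gradient-noise estimate set to zero. The velocity update is v <- momentum * v - lr * dU/dtheta + sqrt(2 * (1 - momentum) * lr) * xi, followed by theta <- theta + v. The function `sghmc_chain` is hypothetical, and the sketch is not claimed to match the exact update inside `torch_sgld.SGLD`.

```python
import math
import random


def sghmc_chain(grad_energy, theta0, lr, momentum, num_steps, seed=0):
    """1-D SGHMC in SGD-with-momentum form (friction = 1 - momentum).

    v     <- momentum * v - lr * dU/dtheta + sqrt(2 * (1 - momentum) * lr) * xi
    theta <- theta + v
    The estimate of the stochastic-gradient noise is taken to be zero.
    """
    rng = random.Random(seed)
    noise_scale = math.sqrt(2.0 * (1.0 - momentum) * lr)
    theta, v = theta0, 0.0
    samples = []
    for _ in range(num_steps):
        v = momentum * v - lr * grad_energy(theta) + noise_scale * rng.gauss(0.0, 1.0)
        theta = theta + v
        samples.append(theta)
    return samples


# Same standard Gaussian target as before: dU/dtheta = theta.
samples = sghmc_chain(lambda t: t, theta0=3.0, lr=0.01, momentum=0.9, num_steps=100_000)
kept = samples[2_000:]  # discard burn-in
mean = sum(kept) / len(kept)
var = sum((s - mean) ** 2 for s in kept) / len(kept)
print(f"mean={mean:.3f}, var={var:.3f}")  # mean near 0, variance near 1
```

The velocity acts like the momentum buffer in `torch.optim.SGD`. The injected noise shrinks as momentum approaches 1, because less friction means less noise is needed to keep the target distribution stationary.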

cSGLD can be implemented by using a cyclical learning rate schedule. See the toy_csgld.ipynb notebook for a complete example.
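As a hedged illustration, the sketch below implements the cosine cyclical step-size schedule proposed for cSGLD (Zhang et al., ICLR 2020). With K total steps and M cycles, the step size restarts at lr0 every ceil(K / M) steps and decays toward zero within each cycle. The function name `cyclical_lr` is hypothetical, and this is not necessarily how `toy_csgld.ipynb` builds its schedule. The paper also starts each cycle with a noise-free exploration stage, which this sketch does not cover.

```python
import math


def cyclical_lr(step, total_steps, num_cycles, lr0):
    """Cosine cyclical step size: starts each cycle at lr0 and decays toward 0.

    step is 0-indexed; each cycle lasts ceil(total_steps / num_cycles) steps.
    """
    cycle_len = math.ceil(total_steps / num_cycles)
    pos = (step % cycle_len) / cycle_len  # position within the current cycle, in [0, 1)
    return 0.5 * lr0 * (math.cos(math.pi * pos) + 1.0)


# 100 steps split into 4 cycles of 25 steps each.
schedule = [cyclical_lr(k, total_steps=100, num_cycles=4, lr0=0.1) for k in range(100)]
```

To drive the SGLD optimizer with this schedule, one option is `torch.optim.lr_scheduler.LambdaLR(sgld, lr_lambda=lambda k: cyclical_lr(k, num_steps, num_cycles, 1.0))`. In that call, `num_cycles` is a value you choose. Passing `lr0=1.0` makes the function return a multiplier on the base `lr` given to `SGLD`, which is how `LambdaLR` interprets `lr_lambda`.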

License

Apache 2.0