Rework ad to prevent compiling the same gradient multiple times #374

Merged: 6 commits, Oct 5, 2024
8 changes: 3 additions & 5 deletions Project.toml
@@ -22,7 +22,6 @@ HypercubeTransform = "9ec9aee3-0fd3-44c2-8e61-a50acc66f3c8"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
PaddedViews = "5432bcbf-9aad-5242-b902-cca2824c8663"
ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5"
@@ -75,14 +74,13 @@ DimensionalData = "0.27, 0.28"
Distributions = "0.25"
DocStringExtensions = "0.8, 0.9"
Dynesty = "0.4"
Enzyme = "0.12"
EnzymeCore = "0.7"
Enzyme = "0.13"
EnzymeCore = "0.8"
FillArrays = "1"
ForwardDiff = "0.9, 0.10"
HypercubeTransform = "0.4"
IntervalSets = "0.6, 0.7"
LogDensityProblems = "2"
LogDensityProblemsAD = "1"
Makie = "0.21"
NamedTupleTools = "0.13,0.14"
Optimization = "4"
@@ -101,7 +99,7 @@ StatsBase = "0.33,0.34"
StructArrays = "0.5,0.6"
Tables = "1"
TransformVariables = "0.8"
VLBIImagePriors = "0.8"
VLBIImagePriors = "0.9"
VLBILikelihoods = "^0.2.6"
VLBISkyModels = "0.6"
julia = "1.10"
7 changes: 6 additions & 1 deletion docs/src/ext/ahmc.md
@@ -13,9 +13,14 @@ To sample, a user can follow the standard `AdvancedHMC` interface, e.g.,
chain = sample(post, NUTS(0.8), 10_000; n_adapts=5_000)
```

!!! warning
    To use HMC, the `VLBIPosterior` must be constructed with a specific `admode`.
    The `admode` can be `nothing` or a `<:EnzymeCore.Mode`; we recommend
    `Enzyme.set_runtime_activity(Enzyme.Reverse)`.
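A minimal sketch of what this looks like end to end, mirroring the repository examples (here `skym`, `intmodel`, and `dvis` are hypothetical stand-ins for the sky model, instrument model, and data objects):

```julia
using Comrade, Enzyme, AdvancedHMC

# Build the posterior with an explicit AD mode so the gradient is compiled once.
post = VLBIPosterior(skym, intmodel, dvis;
                     admode=Enzyme.set_runtime_activity(Enzyme.Reverse))
chain = sample(post, NUTS(0.8), 10_000; n_adapts=5_000)
```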


In addition, our `sample` call accepts a few extra keyword arguments:

- `adtype = Val(:Enzyme)`: The autodiff package to use. Currently the only option is `Enzyme`. Note that you must load Enzyme before calling `sample`.
- `saveto = MemoryStore()`: Specifies how to store the samples. The default is `MemoryStore` which stores the samples directly in RAM. For large models this is not a good idea. To save samples periodically to disk use [`DiskStore`](@ref), and then load the results with `load_samples`.

Note that like most `AbstractMCMC` samplers the initial location can be specified with the `initial_params` argument.
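For large runs, a hedged sketch of streaming samples to disk instead of RAM (the reload path follows the `DiskStore` defaults in the extension code; the exact constructor signature may differ):

```julia
# Stream samples to disk in chunks rather than holding the whole chain in RAM.
chain = sample(post, NUTS(0.8), 10_000; n_adapts=5_000, saveto=DiskStore())

# Later, reload the stored chain; the directory is assumed to be DiskStore's default.
chain = load_samples("Results")
```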
8 changes: 7 additions & 1 deletion docs/src/ext/optimization.md
@@ -7,6 +7,12 @@ optimization algorithm.
To see which optimizers and options are available, please see the `Optimization.jl` [docs](http://optimization.sciml.ai/dev/).


!!! warning
    To use a gradient-based optimizer with AD, the `VLBIPosterior` must be constructed with a
    specific `admode`. The `admode` can be `nothing` or a `<:EnzymeCore.Mode`; we recommend
    `Enzyme.set_runtime_activity(Enzyme.Reverse)`.


## Example

```julia
@@ -18,5 +24,5 @@ using Enzyme
# Some stuff to create a posterior object
post # of type Comrade.Posterior

xopt, sol = comrade_opt(post, LBFGS(); adtype=Val(:Enzyme))
xopt, sol = comrade_opt(post, LBFGS())
```
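As a hedged sketch, the posterior construction elided above might look like the following, mirroring the repository examples (`skym`, `intmodel`, and `dvis` are placeholder model and data objects):

```julia
using Comrade, Enzyme

# comrade_opt differentiates through the posterior, so the AD mode must be set here.
post = VLBIPosterior(skym, intmodel, dvis;
                     admode=Enzyme.set_runtime_activity(Enzyme.Reverse))
```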
2 changes: 1 addition & 1 deletion examples/advanced/HybridImaging/Project.toml
@@ -21,4 +21,4 @@ Plots = "1"
Pyehtim = "0.1"
StableRNGs = "1"
StatsBase = "0.34"
VLBIImagePriors = "0.8"
VLBIImagePriors = "0.9"
6 changes: 3 additions & 3 deletions examples/advanced/HybridImaging/main.jl
@@ -156,7 +156,8 @@ skym = SkyModel(sky, skyprior, g; metadata=skymetadata)

# This is everything we need to specify our posterior distribution, which is the main
# object of interest in image reconstructions when using Bayesian inference.
post = VLBIPosterior(skym, intmodel, dvis)
using Enzyme
post = VLBIPosterior(skym, intmodel, dvis; admode=set_runtime_activity(Enzyme.Reverse))

# To sample from our prior we can do
xrand = prior_sample(rng, post)
@@ -179,8 +180,7 @@ fig |> DisplayAs.PNG |> DisplayAs.Text #hide
# To do this, we use the [`comrade_opt`](@ref) function
using Optimization
using OptimizationOptimJL
using Enzyme
xopt, sol = comrade_opt(post, LBFGS(), AutoEnzyme(;mode=Enzyme.Reverse);
xopt, sol = comrade_opt(post, LBFGS();
initial_params=prior_sample(rng, post), maxiters=1000, g_tol=1e0)


2 changes: 1 addition & 1 deletion examples/beginner/GeometricModeling/Project.toml
@@ -20,4 +20,4 @@ Pigeons = "0.4"
Plots = "1"
Pyehtim = "0.1"
StableRNGs = "1"
VLBIImagePriors = "0.8"
VLBIImagePriors = "0.9"
4 changes: 3 additions & 1 deletion examples/intermediate/ClosureImaging/Project.toml
@@ -6,6 +6,7 @@ DisplayAs = "0b91fe84-8a4c-11e9-3e1d-67c38462b6d6"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
@@ -14,6 +15,7 @@ Pyehtim = "3d61700d-6e5b-419a-8e22-9c066cf00468"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
VLBIImagePriors = "b1ba175b-8447-452c-b961-7db2d6f7a029"
VLBILikelihoods = "90db92cd-0007-4c0a-8e51-dbf0782ce592"

[compat]
CairoMakie = "0.12"
@@ -24,4 +26,4 @@ Pkg = "1"
Plots = "1"
Pyehtim = "0.1"
StableRNGs = "1"
VLBIImagePriors = "0.8"
VLBIImagePriors = "0.9"
6 changes: 3 additions & 3 deletions examples/intermediate/ClosureImaging/main.jl
@@ -129,7 +129,8 @@ skym = SkyModel(sky, prior, grid; metadata=skymeta)

# Since we are fitting closures we do not need to include an instrument model, because
# the closure likelihood is approximately independent of gains in the high-SNR limit.
post = VLBIPosterior(skym, dlcamp, dcphase)
using Enzyme
post = VLBIPosterior(skym, dlcamp, dcphase; admode=set_runtime_activity(Enzyme.Reverse))

# ## Reconstructing the Image

@@ -144,8 +145,7 @@ post = VLBIPosterior(skym, dlcamp, dcphase)
# OptimizationOptimJL. We also need to import Enzyme to allow for automatic differentiation.
using Optimization
using OptimizationOptimJL
using Enzyme
xopt, sol = comrade_opt(post, LBFGS(), AutoEnzyme(;mode=Enzyme.Reverse);
xopt, sol = comrade_opt(post, LBFGS();
maxiters=1000, initial_params=prior_sample(rng, post))


2 changes: 0 additions & 2 deletions examples/intermediate/PolarizedImaging/Project.toml
@@ -6,13 +6,11 @@ DisplayAs = "0b91fe84-8a4c-11e9-3e1d-67c38462b6d6"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Pyehtim = "3d61700d-6e5b-419a-8e22-9c066cf00468"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
VLBIImagePriors = "b1ba175b-8447-452c-b961-7db2d6f7a029"
18 changes: 9 additions & 9 deletions examples/intermediate/PolarizedImaging/main.jl
@@ -180,8 +180,8 @@ end
# image model. Our image will be a 32x32 raster with a 200μas FOV.
using Distributions
using VLBIImagePriors
fovx = μas2rad(150.0)
fovy = μas2rad(150.0)
fovx = μas2rad(200.0)
fovy = μas2rad(200.0)
nx = ny = 32
grid = imagepixels(fovx, fovy, nx, ny)

@@ -204,7 +204,7 @@ skymeta = (; mimg=mimg./flux(mimg), ftot=0.6)
cprior = corr_image_prior(grid, dvis)
skyprior = (
c = cprior,
σ = truncated(Normal(0.0, 0.5); lower=0.0),
σ = Exponential(0.1),
p = cprior,
p0 = Normal(-2.0, 2.0),
pσ = truncated(Normal(0.0, 1.0); lower=0.01),
@@ -287,7 +287,7 @@ J = JonesSandwich(js, G, D, R)
intprior = (
lgR = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.2))),
lgrat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.01))),
gpR = ArrayPrior(IIDSitePrior(ScanSeg(), DiagonalVonMises(0.0, inv(π^2))); refant=SEFDReference(0.0), phase=false),
gpR = ArrayPrior(IIDSitePrior(ScanSeg(), DiagonalVonMises(0.0, inv(π^2))); refant=SEFDReference(0.0), phase=true),
gprat= ArrayPrior(IIDSitePrior(ScanSeg(), DiagonalVonMises(0.0, inv(0.1^2))); refant = SingleReference(:AA, 0.0), phase=false),
dRx = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))),
dRy = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))),
@@ -301,8 +301,9 @@ intmodel = InstrumentModel(J, intprior)

# intmodel = InstrumentModel(JonesR(;add_fr=true))
# Putting it all together, we form our likelihood and posterior objects for optimization and
# sampling.
post = VLBIPosterior(skym, intmodel, dvis)
# sampling, specifying that AD should use Enzyme.Reverse with runtime activity.
using Enzyme
post = VLBIPosterior(skym, intmodel, dvis; admode=set_runtime_activity(Enzyme.Reverse))

# ## Reconstructing the Image and Instrument Effects

@@ -323,8 +324,7 @@ tpost = asflat(post)
# work with the polarized Comrade posterior is Enzyme.
using Optimization
using OptimizationOptimisers
using Enzyme
xopt, sol = comrade_opt(post, Optimisers.Adam(), AutoEnzyme(;mode=Enzyme.Reverse);
xopt, sol = comrade_opt(post, Optimisers.Adam();
initial_params=prior_sample(rng, post), maxiters=25_000)


@@ -384,7 +384,7 @@ p |> DisplayAs.PNG |> DisplayAs.Text
# other imaging examples. For example
# ```julia
# using AdvancedHMC
# chain = sample(rng, post, NUTS(0.8), 10_000; adtype=AutoEnzyme(;mode=Enzyme.Reverse), n_adapts=5000, progress=true, initial_params=xopt)
# chain = sample(rng, post, NUTS(0.8), 10_000, n_adapts=5000, progress=true, initial_params=xopt)
# ```


2 changes: 1 addition & 1 deletion examples/intermediate/StokesIImaging/Project.toml
@@ -23,4 +23,4 @@ Pkg = "1"
Plots = "1"
Pyehtim = "0.1"
StableRNGs = "1"
VLBIImagePriors = "0.8"
VLBIImagePriors = "0.9"
12 changes: 8 additions & 4 deletions examples/intermediate/StokesIImaging/main.jl
@@ -145,7 +145,12 @@ intpr = (
intmodel = InstrumentModel(G, intpr)


post = VLBIPosterior(skym, intmodel, dvis)
# To form the posterior we just combine the sky model, instrument model, and the data. Additionally,
# since we want to use gradients we need to specify the AD mode. For essentially all models we recommend
# using `Enzyme.set_runtime_activity(Enzyme.Reverse)`. Eventually, as Comrade and Enzyme mature, we will
# no longer need `set_runtime_activity`.
using Enzyme
post = VLBIPosterior(skym, intmodel, dvis; admode=set_runtime_activity(Enzyme.Reverse))
# done using the `asflat` function.
tpost = asflat(post)
ndim = dimension(tpost)
@@ -160,8 +165,7 @@ ndim = dimension(tpost)
# To initialize our sampler we will first optimize using Adam
using Optimization
using OptimizationOptimisers
using Enzyme
xopt, sol = comrade_opt(post, Optimisers.Adam(), AutoEnzyme(;mode=Enzyme.Reverse); initial_params=prior_sample(rng, post), maxiters=20_000, g_tol=1e-1)
xopt, sol = comrade_opt(post, Optimisers.Adam(); initial_params=prior_sample(rng, post), maxiters=20_000, g_tol=1e-1)

# !!! warning
# Fitting gains tends to be very difficult, meaning that optimization can take a lot longer.
@@ -208,7 +212,7 @@ plot(gt, layout=(3,3), size=(600,500)) |> DisplayAs.PNG |> DisplayAs.Text
# run
#-
using AdvancedHMC
chain = sample(rng, post, NUTS(0.8), 1_000; adtype=AutoEnzyme(;mode=Enzyme.Reverse), n_adapts=500, progress=false, initial_params=xopt)
chain = sample(rng, post, NUTS(0.8), 1_000; n_adapts=500, progress=false, initial_params=xopt)
#-
# !!! note
# The above sampler will store the samples in memory, i.e. RAM. For large models this
28 changes: 14 additions & 14 deletions ext/ComradeAdvancedHMCExt.jl
@@ -10,7 +10,7 @@ using Accessors
using ArgCheck
using DocStringExtensions
using HypercubeTransform
using LogDensityProblems, LogDensityProblemsAD
using LogDensityProblems
using Printf
using Random
using StatsBase
@@ -36,15 +36,14 @@ end

function AbstractMCMC.Sample(
rng::Random.AbstractRNG, tpost::Comrade.TransformedVLBIPosterior,
sampler::AbstractHMCSampler; adtype=Val(:Enzyme), initial_params=nothiing, kwargs...)
∇ℓ = ADgradient(adtype, tpost)
sampler::AbstractHMCSampler; initial_params=nothing, kwargs...)
θ0 = initialize_params(tpost, initial_params)
model, smplr = make_sampler(rng, ∇ℓ, sampler, θ0)
model, smplr = make_sampler(rng, tpost, sampler, θ0)
return AbstractMCMC.Sample(rng, model, smplr; initial_params=θ0, kwargs...)
end

"""
sample(rng, post::VLBIPosterior, sampler::AbstractHMCSampler, nsamples, args...;saveto=MemoryStore(), adtype=Val(:Enzyme), initial_params=nothing, kwargs...)
sample(rng, post::VLBIPosterior, sampler::AbstractHMCSampler, nsamples, args...;saveto=MemoryStore(), initial_params=nothing, kwargs...)

Sample from the posterior `post` using the sampler `sampler` for `nsamples` samples. Additional
arguments are forwarded to AbstractMCMC.sample. If `saveto` is a DiskStore, the samples will be
@@ -59,23 +58,26 @@ saved to disk. If `initial_params` is not `nothing` then the sampler will start
## Keyword Arguments

- `saveto`: If a DiskStore, the samples will be saved to disk, if [`MemoryStore`](@ref) the samples will be stored in memory/ram.
- `adtype`: The automatic differentiation type to use. The default if Enzyme which is the recommended choice for Comrade currently.
- `initial_params`: The initial parameters to start the sampler from. If `nothing` then the sampler will start from a random point in the prior.
- `kwargs`: Additional keyword arguments to pass to the sampler. Examples include `n_adapts` which is the total number of samples to use for adaptation.
To see the others see the AdvancedHMC documentation.
"""
function AbstractMCMC.sample(
rng::Random.AbstractRNG, post::Comrade.VLBIPosterior,
sampler::AbstractHMCSampler, nsamples, args...;
saveto=MemoryStore(), adtype=Val(:Enzyme), initial_params=nothing, kwargs...)
saveto=MemoryStore(), initial_params=nothing, kwargs...)

saveto isa DiskStore && return sample_to_disk(rng, post, sampler, nsamples, args...; outdir=saveto.name, output_stride=min(saveto.stride, nsamples), adtype, initial_params, kwargs...)
saveto isa DiskStore && return sample_to_disk(rng, post, sampler, nsamples, args...; outdir=saveto.name, output_stride=min(saveto.stride, nsamples), initial_params, kwargs...)

if isnothing(Comrade.admode(post))
throw(ArgumentError("You must specify an automatic differentiation type in VLBIPosterior with admode kwarg"))
else
tpost = asflat(post)
end

tpost = asflat(post)
∇ℓ = ADgradient(adtype, tpost)
θ0 = initialize_params(tpost, initial_params)
model, smplr = make_sampler(rng, ∇ℓ, sampler, θ0)
model, smplr = make_sampler(rng, tpost, sampler, θ0)

res = sample(rng, model, smplr, nsamples, args...;
initial_params=θ0, saveto=saveto, chain_type=Array, kwargs...)
@@ -90,7 +92,6 @@ end
function initialize(rng::Random.AbstractRNG, tpost::Comrade.TransformedVLBIPosterior,
sampler::AbstractHMCSampler, nsamples, outbase, args...;
n_adapts = min(nsamples÷2, 1000),
adtype = Val(:Enzyme),
initial_params=nothing, outdir = "Results",
output_stride=min(100, nsamples),
restart = false,
@@ -119,7 +120,7 @@ function initialize(rng::Random.AbstractRNG, tpost::Comrade.TransformedVLBIPoste
@warn "No starting location chosen, picking start from prior"
θ0 = prior_sample(rng, tpost)
end
t = Sample(rng, tpost, sampler; initial_params=θ0, adtype, n_adapts, kwargs...)(1:nsamples)
t = Sample(rng, tpost, sampler; initial_params=θ0, n_adapts, kwargs...)(1:nsamples)
pt = Iterators.partition(t, output_stride)
nscans = nsamples÷output_stride + (nsamples%output_stride!=0 ? 1 : 0)

@@ -158,7 +159,6 @@ end

function sample_to_disk(rng::Random.AbstractRNG, post::Comrade.VLBIPosterior,
sampler::AbstractHMCSampler, nsamples, args...;
adtype = Val(:Enzyme),
n_adapts = min(nsamples÷2, 1000),
initial_params=nothing, outdir = "Results",
restart=false,
@@ -172,7 +172,7 @@ function sample_to_disk(rng::Random.AbstractRNG, post::Comrade.VLBIPosterior,

pt, state, out, i = initialize(
rng, tpost, sampler, nsamples, outbase, args...;
n_adapts, adtype,
n_adapts,
initial_params, restart, outdir, output_stride, kwargs...
)

16 changes: 13 additions & 3 deletions ext/ComradeEnzymeExt.jl
@@ -1,10 +1,20 @@
module ComradeEnzymeExt

using Enzyme
using Comrade
using LogDensityProblems

function __init__()
# We need this to ensure that Enzyme can AD through the Comrade code base
Enzyme.API.runtimeActivity!(true)
LogDensityProblems.dimension(d::Comrade.TransformedVLBIPosterior) = dimension(d)
LogDensityProblems.capabilities(::Type{<:Comrade.TransformedVLBIPosterior}) = LogDensityProblems.LogDensityOrder{1}()


function LogDensityProblems.logdensity_and_gradient(d::Comrade.TransformedVLBIPosterior, x::AbstractArray)
mode = Enzyme.EnzymeCore.WithPrimal(Comrade.admode(d))
dx = zero(x)
(_, y) = autodiff(mode, Comrade.logdensityof, Active, Const(d), Duplicated(x, dx))
return y, dx
end



end
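For reference, a minimal sketch of how this new interface is exercised (assuming `post` is a `VLBIPosterior` built with an Enzyme `admode`, and `rng` an `AbstractRNG`, as in the examples above):

```julia
using Comrade, Enzyme, LogDensityProblems

tpost = asflat(post)           # transform to the flat, unconstrained space
x0 = prior_sample(rng, tpost)  # random starting point in that space
# Evaluate the log density and its Enzyme gradient in one call.
ℓ, g = LogDensityProblems.logdensity_and_gradient(tpost, x0)
```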