Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make AdvancedMH compatible with AbstractMCMC 5 #92

Merged
merged 3 commits into from
Oct 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedMH"
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
version = "0.7.6"
version = "0.8.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -23,18 +23,18 @@ AdvancedMHMCMCChainsExt = "MCMCChains"
AdvancedMHStructArraysExt = "StructArrays"

[compat]
AbstractMCMC = "4, 5"
AbstractMCMC = "5"
DiffResults = "1"
Distributions = "0.20 - 0.25"
LinearAlgebra = "1.6 - 1.11"
Random = "1.6 - 1.11"
Distributions = "0.25"
FillArrays = "1"
ForwardDiff = "0.10"
LogDensityProblems = "2"
MCMCChains = "5, 6"
MCMCChains = "6.0.4"
Requires = "1"
StructArrays = "0.6"
julia = "1.6"
LinearAlgebra = "1.6"
Random = "1.6"

[extras]
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Expand Down
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,21 +138,21 @@ AdvancedMH.jl implements the interface of [AbstractMCMC](https://github.com/Turi

```julia
# Sample 4 chains from the posterior serially, without thread or process parallelism.
chain = sample(model, RWMH(init_params), MCMCSerial(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)
chain = sample(model, spl, MCMCSerial(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)

# Sample 4 chains from the posterior using multiple threads.
chain = sample(model, RWMH(init_params), MCMCThreads(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)
chain = sample(model, spl, MCMCThreads(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)

# Sample 4 chains from the posterior using multiple processes.
chain = sample(model, RWMH(init_params), MCMCDistributed(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)
chain = sample(model, spl, MCMCDistributed(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)
```

## Metropolis-adjusted Langevin algorithm (MALA)

AdvancedMH.jl also offers an implementation of [MALA](https://en.wikipedia.org/wiki/Metropolis-adjusted_Langevin_algorithm) if the `ForwardDiff` and `DiffResults` packages are available.

A `MALA` sampler can be constructed by `MALA(proposal)` where `proposal` is a function that
takes the gradient computed at the current sample. It is required to specify an initial sample `init_params` when calling `sample`.
takes the gradient computed at the current sample. It is required to specify an initial sample `initial_params` when calling `sample`.

```julia
# Import the package.
Expand Down Expand Up @@ -180,7 +180,7 @@ model = DensityModel(density)
spl = MALA(x -> MvNormal((σ² / 2) .* x, σ² * I))

# Sample from the posterior.
chain = sample(model, spl, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
chain = sample(model, spl, 100000; initial_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
```

### Usage with [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl)
Expand All @@ -192,5 +192,5 @@ Using our implementation of the `LogDensityProblems.jl` interface above:
```julia
using LogDensityProblemsAD
model_with_ad = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), LogTargetDensity())
sample(model_with_ad, spl, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
sample(model_with_ad, spl, 100000; initial_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
```
6 changes: 3 additions & 3 deletions src/mh-core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ spl = MetropolisHastings(proposal)
When using `MetropolisHastings` with the function `sample`, the following keyword
arguments are allowed:

- `init_params` defines the initial parameterization for your model. If
- `initial_params` defines the initial parameterization for your model. If
none is given, the initial parameters will be drawn from the sampler's proposals.
- `param_names` is a vector of strings to be assigned to parameters. This is only
used if `chain_type=Chains`.
Expand Down Expand Up @@ -77,10 +77,10 @@ function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DensityModelOrLogDensityModel,
sampler::MHSampler;
init_params=nothing,
initial_params=nothing,
kwargs...
)
params = init_params === nothing ? propose(rng, sampler, model) : init_params
params = initial_params === nothing ? propose(rng, sampler, model) : initial_params
transition = AdvancedMH.transition(sampler, model, params)
return transition, transition
end
Expand Down
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ include("util.jl")
val = [0.4, 1.2]

# Sample from the posterior.
chain1 = sample(model, spl1, 10, init_params = val)
chain1 = sample(model, spl1, 10, initial_params = val)

@test chain1[1].params == val
end
Expand Down Expand Up @@ -265,7 +265,7 @@ include("util.jl")
spl1 = MALA(x -> MvNormal((σ² / 2) .* x, σ² * I))

# Sample from the posterior with initial parameters.
chain1 = sample(model, spl1, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
chain1 = sample(model, spl1, 100000; initial_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])

@test mean(chain1.μ) ≈ 0.0 atol=0.1
@test mean(chain1.σ) ≈ 1.0 atol=0.1
Expand All @@ -276,7 +276,7 @@ include("util.jl")
admodel,
spl1,
100000;
init_params=ones(2),
initial_params=ones(2),
chain_type=StructArray,
param_names=["μ", "σ"]
)
Expand Down
Loading