Skip to content

Commit

Permalink
Move converters from PPLs to extensions (#293)
Browse files Browse the repository at this point in the history
* Turing MCMCChains integration into extension

* Turing SampleChains integration into extension

* Move conversion declarations to their own file

* Copy rekey from InferenceObjects

* Fix imports

* Prefix function calls with module name

* Separate dynamichmc part to its own extension

* Properly extend functions

* Add weak deps to extras for older Julia versions

* Fix Requires usage for older Julia versions

* Increment patch number

* Remove no-longer-used imports

* Fix prefixes
  • Loading branch information
sethaxen authored Jul 30, 2023
1 parent 7c9ae5a commit a572fe0
Show file tree
Hide file tree
Showing 10 changed files with 319 additions and 237 deletions.
19 changes: 18 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ArviZ"
uuid = "131c737c-5715-5e2e-ad31-c244f01c1dc7"
authors = ["Seth Axen <[email protected]>"]
version = "0.9.0"
version = "0.9.1"

[deps]
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
Expand Down Expand Up @@ -29,6 +29,16 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TableTraits = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[weakdeps]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
SampleChains = "754583d1-7fc4-4dab-93b5-5eaca5c9622e"
SampleChainsDynamicHMC = "6d9fd711-e8b2-4778-9c70-c1dfb499d4c4"

[extensions]
ArviZMCMCChainsExt = "MCMCChains"
ArviZSampleChainsExt = "SampleChains"
ArviZSampleChainsDynamicHMCExt = ["SampleChains", "SampleChainsDynamicHMC"]

[compat]
DataInterpolations = "4"
DimensionalData = "0.23, 0.24"
Expand All @@ -37,14 +47,21 @@ DocStringExtensions = "0.8, 0.9"
InferenceObjects = "0.3.10"
IteratorInterfaceExtensions = "0.1.1, 1"
LogExpFunctions = "0.2.0, 0.3"
MCMCChains = "6"
MCMCDiagnosticTools = "0.3.4"
Optim = "1"
OrderedCollections = "1"
PSIS = "0.9.1"
PrettyTables = "2.1, 2.2"
Requires = "0.5.2, 1.0"
SampleChains = "0.5"
Setfield = "1"
StatsBase = "0.32, 0.33, 0.34"
TableTraits = "0.4, 1"
Tables = "1"
julia = "1.6"

[extras]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
SampleChains = "754583d1-7fc4-4dab-93b5-5eaca5c9622e"
SampleChainsDynamicHMC = "6d9fd711-e8b2-4778-9c70-c1dfb499d4c4"
145 changes: 66 additions & 79 deletions src/mcmcchains.jl → ext/ArviZMCMCChainsExt.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
module ArviZMCMCChainsExt

if isdefined(Base, :get_extension)
using ArviZ: ArviZ, InferenceObjects
using MCMCChains: MCMCChains
else
using ..ArviZ: ArviZ, InferenceObjects
using ..MCMCChains: MCMCChains
end

const turing_key_map = Dict(
:hamiltonian_energy => :energy,
:hamiltonian_energy_error => :energy_error,
Expand Down Expand Up @@ -53,12 +63,12 @@ function varnames_locs(loc_names)
return NamedTuple(vars_to_locs)
end

function attributes_dict(chns::Chains)
function attributes_dict(chns::MCMCChains.Chains)
info = Base.structdiff(chns.info, NamedTuple{(:hashedsummary,)})
return Dict{String,Any}((string(k), v) for (k, v) in pairs(info))
end

function section_namedtuple(chns::Chains, section)
function section_namedtuple(chns::MCMCChains.Chains, section)
ndraws, _, nchains = size(chns)
loc_names = chns.name_map[section]
vars_to_locs = varnames_locs(loc_names)
Expand All @@ -68,7 +78,7 @@ function section_namedtuple(chns::Chains, section)
ndim = length(sizes)
# NOTE: slicing specific entries from AxisArrays does not preserve order
# https://github.com/JuliaArrays/AxisArrays.jl/issues/182
oldarr = replacemissing(permutedims(chns.value[:, loc_names, :], (1, 3, 2)))
oldarr = ArviZ.replacemissing(permutedims(chns.value[:, loc_names, :], (1, 3, 2)))
if iszero(ndim)
arr = dropdims(oldarr; dims=3)
else
Expand All @@ -84,77 +94,29 @@ function section_namedtuple(chns::Chains, section)
end

function chains_to_namedtuple(
chns::Chains; ignore=(), section=:parameters, rekey_fun=identity
chns::MCMCChains.MCMCChains.Chains; ignore=(), section=:parameters, rekey_fun=identity
)
section in sections(chns) || return (;)
section in MCMCChains.sections(chns) || return (;)
chns_data = section_namedtuple(chns, section)
chns_data_return = NamedTuple{filter(∉(ignore), keys(chns_data))}(chns_data)
return rekey_fun(chns_data_return)
end

"""
convert_to_inference_data(obj::Chains; group = :posterior, kwargs...) -> InferenceData
convert_to_inference_data(obj::MCMCChains.Chains; group = :posterior, kwargs...) -> InferenceData
Convert the chains `obj` to an [`InferenceData`](@ref) with the specified `group`.
Remaining `kwargs` are forwarded to [`from_mcmcchains`](@ref).
"""
function convert_to_inference_data(chns::Chains; group::Symbol=:posterior, kwargs...)
group === :posterior && return from_mcmcchains(chns; kwargs...)
return from_mcmcchains(; group => chns, kwargs...)
function InferenceObjects.convert_to_inference_data(
chns::MCMCChains.Chains; group::Symbol=:posterior, kwargs...
)
group === :posterior && return ArviZ.from_mcmcchains(chns; kwargs...)
return ArviZ.from_mcmcchains(; group => chns, kwargs...)
end

@doc doc"""
from_mcmcchains(posterior::MCMCChains.Chains; kwargs...) -> InferenceData
from_mcmcchains(; kwargs...) -> InferenceData
from_mcmcchains(
posterior::MCMCChains.Chains,
posterior_predictive,
predictions,
log_likelihood;
kwargs...
) -> InferenceData
Convert data in an `MCMCChains.Chains` format into an [`InferenceData`](@ref).
Any keyword argument below without an an explicitly annotated type above is allowed, so long
as it can be passed to [`convert_to_inference_data`](@ref).
# Arguments
- `posterior::MCMCChains.Chains`: Draws from the posterior
# Keywords
- `posterior_predictive::Any=nothing`: Draws from the posterior predictive distribution or
name(s) of predictive variables in `posterior`
- `predictions`: Out-of-sample predictions for the posterior.
- `prior`: Draws from the prior
- `prior_predictive`: Draws from the prior predictive distribution or name(s) of predictive
variables in `prior`
- `observed_data`: Observed data on which the `posterior` is conditional. It should only
contain data which is modeled as a random variable. Keys are parameter names and values.
- `constant_data`: Model constants, data included in the model that are not modeled as
random variables. Keys are parameter names.
- `predictions_constant_data`: Constants relevant to the model predictions (i.e. new `x`
values in a linear regression).
- `log_likelihood`: Pointwise log-likelihood for the data. It is recommended to use this
argument as a named tuple whose keys are observed variable names and whose values are log
likelihood arrays. Alternatively, provide the name of variable in `posterior` containing
log likelihoods.
- `library=MCMCChains`: Name of library that generated the chains
- `coords`: Map from named dimension to named indices
- `dims`: Map from variable name to names of its dimensions
- `eltypes`: Map from variable names to eltypes. This is primarily used to assign discrete
eltypes to discrete variables that were stored in `Chains` as floats.
# Returns
- `InferenceData`: The data with groups corresponding to the provided data
"""
from_mcmcchains

function from_mcmcchains(
function ArviZ.from_mcmcchains(
posterior,
posterior_predictive,
predictions,
Expand All @@ -170,13 +132,13 @@ function from_mcmcchains(
post_data = nothing
stats_data = nothing
else
post_data = convert_to_eltypes(chains_to_namedtuple(posterior), eltypes)
post_data = ArviZ.convert_to_eltypes(chains_to_namedtuple(posterior), eltypes)
stats_data = chains_to_namedtuple(posterior; section=:internals, rekey_fun)
stats_data = enforce_stat_eltypes(stats_data)
stats_data = convert_to_eltypes(stats_data, (; is_accept=Bool))
stats_data = ArviZ.enforce_stat_eltypes(stats_data)
stats_data = ArviZ.convert_to_eltypes(stats_data, (; is_accept=Bool))
end

all_idata = InferenceData()
all_idata = InferenceObjects.InferenceData()
for (group, group_data) in [
:posterior_predictive => posterior_predictive,
:predictions => predictions,
Expand All @@ -192,18 +154,22 @@ function from_mcmcchains(
post_data
)
end
group_dataset = if group_data isa Chains
convert_to_dataset(group_data; library, eltypes, kwargs...)
group_dataset = if group_data isa MCMCChains.Chains
InferenceObjects.convert_to_dataset(group_data; library, eltypes, kwargs...)
else
convert_to_dataset(group_data; library, kwargs...)
InferenceObjects.convert_to_dataset(group_data; library, kwargs...)
end
all_idata = merge(all_idata, InferenceData(; group => group_dataset))
all_idata = merge(
all_idata, InferenceObjects.InferenceData(; group => group_dataset)
)
end
post_idata = from_namedtuple(post_data; sample_stats=stats_data, library, kwargs...)
post_idata = ArviZ.from_namedtuple(
post_data; sample_stats=stats_data, library, kwargs...
)
all_idata = merge(all_idata, post_idata)
return all_idata
end
function from_mcmcchains(
function ArviZ.from_mcmcchains(
posterior=nothing;
posterior_predictive=nothing,
predictions=nothing,
Expand All @@ -217,7 +183,7 @@ function from_mcmcchains(
eltypes=(;),
kwargs...,
)
all_idata = from_mcmcchains(
all_idata = ArviZ.from_mcmcchains(
posterior,
posterior_predictive,
predictions,
Expand All @@ -228,7 +194,7 @@ function from_mcmcchains(
)

if prior !== nothing
pre_prior_idata = convert_to_inference_data(
pre_prior_idata = InferenceObjects.convert_to_inference_data(
prior; posterior_predictive=prior_predictive, library, eltypes, kwargs...
)
prior_idata = rekey(
Expand All @@ -241,18 +207,20 @@ function from_mcmcchains(
)
all_idata = merge(all_idata, prior_idata)
elseif prior_predictive !== nothing
if prior_predictive isa Chains
pre_prior_predictive_idata = convert_to_inference_data(
if prior_predictive isa MCMCChains.Chains
pre_prior_predictive_idata = InferenceObjects.convert_to_inference_data(
prior_predictive; library, eltypes, kwargs...
)
else
pre_prior_predictive_idata = convert_to_inference_data(
pre_prior_predictive_idata = InferenceObjects.convert_to_inference_data(
prior_predictive; library, kwargs...
)
end
all_idata = merge(
all_idata,
InferenceData(; prior_predictive=pre_prior_predictive_idata.posterior),
InferenceObjects.InferenceData(;
prior_predictive=pre_prior_predictive_idata.posterior
),
)
end

Expand All @@ -262,10 +230,29 @@ function from_mcmcchains(
:predictions_constant_data => predictions_constant_data,
]
group_data === nothing && continue
group_data = convert_to_eltypes(group_data, eltypes)
group_dataset = convert_to_dataset(group_data; library, default_dims=(), kwargs...)
all_idata = merge(all_idata, InferenceData(; group => group_dataset))
group_data = ArviZ.convert_to_eltypes(group_data, eltypes)
group_dataset = ArviZ.convert_to_dataset(
group_data; library, default_dims=(), kwargs...
)
all_idata = merge(
all_idata, InferenceObjects.InferenceData(; group => group_dataset)
)
end

return all_idata
end

# adapted from InferenceObjects.jl
rekey(d, keymap) = Dict(get(keymap, k, k) => d[k] for k in keys(d))
function rekey(d::NamedTuple, keymap)
new_keys = map(k -> get(keymap, k, k), keys(d))
return NamedTuple{new_keys}(values(d))
end
function rekey(data::InferenceObjects.InferenceData, keymap)
groups_old = InferenceObjects.groups(data)
names_new = map(k -> get(keymap, k, k), propertynames(groups_old))
groups_new = NamedTuple{names_new}(Tuple(groups_old))
return InferenceObjects.InferenceData(groups_new)
end

end # module
31 changes: 31 additions & 0 deletions ext/ArviZSampleChainsDynamicHMCExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
module ArviZSampleChainsDynamicHMCExt

if isdefined(Base, :get_extension)
using ArviZ
using SampleChains: SampleChains
using SampleChainsDynamicHMC: SampleChainsDynamicHMC
else
using ..ArviZ
using ..SampleChains: SampleChains
using ..SampleChainsDynamicHMC: SampleChainsDynamicHMC
end

function ArviZ._samplechains_info(chain::SampleChainsDynamicHMC.DynamicHMCChain)
info = SampleChains.info(chain)
termination = info.termination
tree_stats = (
energy=info.π,
tree_depth=info.depth,
acceptance_rate=info.acceptance_rate,
n_steps=info.steps,
diverging=map(t -> t.left == t.right, termination),
turning=map(t -> t.left < t.right, termination),
)
used_info = (, :depth, :acceptance_rate, :steps, :termination)
skipped_info = setdiff(propertynames(info), used_info)
isempty(skipped_info) ||
@debug "Skipped SampleChainsDynamicHMC info entries: $skipped_info."
return tree_stats
end

end # module
Loading

2 comments on commit a572fe0

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/88661

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.9.1 -m "<description of version>" a572fe059022060e627ae331761b3a1dd23c7284
git push origin v0.9.1

Please sign in to comment.