Skip to content

Commit

Permalink
Fix get() ambiguities
Browse files Browse the repository at this point in the history
Done by:

1. Constraining the type parameter to AbstractVector{Symbol}
2. Modifying the method below it to use a vector instead of a tuple
  • Loading branch information
penelopeysm committed Aug 14, 2024
1 parent a6e01f8 commit 36accad
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/optimisation/Optimisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,13 @@ StatsBase.loglikelihood(m::ModeResult) = m.lp

"""
Base.get(m::ModeResult, var_symbol::Symbol)
Base.get(m::ModeResult, var_symbols)
Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})
Return the values of all the variables with the symbol(s) `var_symbol` in the mode result
`m`. The return value is a `NamedTuple` with `var_symbols` as the key(s). The second
argument should be either a `Symbol` or an iterator of `Symbol`s.
argument should be either a `Symbol` or a vector of `Symbol`s.
"""
function Base.get(m::ModeResult, var_symbols)
function Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})
log_density = m.f
# Get all the variable names in the model. This is the same as the list of keys in
# m.values, but they are more convenient to filter when they are VarNames rather than
Expand All @@ -304,7 +304,7 @@ function Base.get(m::ModeResult, var_symbols)
return (; zip(var_symbols, value_vectors)...)
end

Base.get(m::ModeResult, var_symbol::Symbol) = get(m, (var_symbol,))
Base.get(m::ModeResult, var_symbol::Symbol) = get(m, [var_symbol])

"""
ModeResult(log_density::OptimLogDensity, solution::SciMLBase.OptimizationSolution)
Expand Down

0 comments on commit 36accad

Please sign in to comment.