From 2cb2d9e770697de1a3a12a8bd58c477068dc5107 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Wed, 10 Jul 2024 15:07:16 -0400 Subject: [PATCH 1/6] feat: port generalized leapfrog Signed-off-by: Kai Xu --- src/AdvancedHMC.jl | 2 + src/riemannian/integrator.jl | 93 ++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+) create mode 100644 src/riemannian/integrator.jl diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index 42d52767..fcaab095 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -52,6 +52,8 @@ export Hamiltonian include("integrator.jl") export Leapfrog, JitteredLeapfrog, TemperedLeapfrog +include("riemannian/integrator.jl") +export GeneralizedLeapfrog include("trajectory.jl") export Trajectory, diff --git a/src/riemannian/integrator.jl b/src/riemannian/integrator.jl new file mode 100644 index 00000000..718193e8 --- /dev/null +++ b/src/riemannian/integrator.jl @@ -0,0 +1,93 @@ +""" +$(TYPEDEF) + +Generalized leapfrog integrator with fixed step size `ϵ`. + +# Fields + +$(TYPEDFIELDS) +""" +struct GeneralizedLeapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T} + "Step size." + ϵ::T + n::Int +end +Base.show(io::IO, l::GeneralizedLeapfrog) = + print(io, "GeneralizedLeapfrog(ϵ=$(round.(l.ϵ; sigdigits=3)), n=$(l.n))") + +# fallback to ignore return_cache & cache kwargs for other ∂H∂θ +function ∂H∂θ_cache(h, θ, r; return_cache = false, cache = nothing) + dv = ∂H∂θ(h, θ, r) + return return_cache ? 
(dv, nothing) : dv +end + +# TODO(Kai) make sure vectorization works +# TODO(Kai) check if tempering is valid +# TODO(Kai) abstract out the 3 main steps and merge with `step` in `integrator.jl` +function step( + lf::GeneralizedLeapfrog{T}, + h::Hamiltonian, + z::P, + n_steps::Int = 1; + fwd::Bool = n_steps > 0, # simulate hamiltonian backward when n_steps < 0 + full_trajectory::Val{FullTraj} = Val(false), +) where {T<:AbstractScalarOrVec{<:AbstractFloat},P<:PhasePoint,FullTraj} + n_steps = abs(n_steps) # to support `n_steps < 0` cases + + ϵ = fwd ? step_size(lf) : -step_size(lf) + ϵ = ϵ' + + res = if FullTraj + Vector{P}(undef, n_steps) + else + z + end + + for i = 1:n_steps + θ_init, r_init = z.θ, z.r + # Tempering + #r = temper(lf, r, (i=i, is_half=true), n_steps) + # eq (16) of Girolami & Calderhead (2011) + r_half = copy(r_init) + local cache + for j = 1:lf.n + # Reuse cache for the first iteration + if j == 1 + @unpack value, gradient = z.ℓπ + elseif j == 2 # cache intermediate values that depends on θ only (which are unchanged) + retval, cache = ∂H∂θ_cache(h, θ_init, r_half; return_cache = true) + @unpack value, gradient = retval + else # reuse cache + @unpack value, gradient = ∂H∂θ_cache(h, θ_init, r_half; cache = cache) + end + r_half = r_init - ϵ / 2 * gradient + end + # eq (17) of Girolami & Calderhead (2011) + θ_full = copy(θ_init) + term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop + for j = 1:lf.n + θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half)) + end + # eq (18) of Girolami & Calderhead (2011) + @unpack value, gradient = ∂H∂θ(h, θ_full, r_half) + r_full = r_half - ϵ / 2 * gradient + # Tempering + #r = temper(lf, r, (i=i, is_half=false), n_steps) + # Create a new phase point by caching the logdensity and gradient + z = phasepoint(h, θ_full, r_full; ℓπ = DualValue(value, gradient)) + # Update result + if FullTraj + res[i] = z + else + res = z + end + if !isfinite(z) + # Remove undef + if FullTraj + res = 
res[isassigned.(Ref(res), 1:n_steps)] + end + break + end + end + return res +end \ No newline at end of file From c99e0062b09266f05fc9cea0dc75bc19e7a3d824 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Thu, 11 Jul 2024 01:17:27 -0400 Subject: [PATCH 2/6] refactor: remove copy Signed-off-by: Kai Xu --- src/riemannian/integrator.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/riemannian/integrator.jl b/src/riemannian/integrator.jl index 718193e8..06f557ba 100644 --- a/src/riemannian/integrator.jl +++ b/src/riemannian/integrator.jl @@ -48,7 +48,7 @@ function step( # Tempering #r = temper(lf, r, (i=i, is_half=true), n_steps) # eq (16) of Girolami & Calderhead (2011) - r_half = copy(r_init) + r_half = r_init local cache for j = 1:lf.n # Reuse cache for the first iteration @@ -63,7 +63,7 @@ function step( r_half = r_init - ϵ / 2 * gradient end # eq (17) of Girolami & Calderhead (2011) - θ_full = copy(θ_init) + θ_full = θ_init term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop for j = 1:lf.n θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half)) From 27c2728bfd59ae84a4a5c08eb11ac40016a2fb72 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Thu, 11 Jul 2024 09:30:03 -0400 Subject: [PATCH 3/6] format: add empty line Signed-off-by: Kai Xu --- src/riemannian/integrator.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/riemannian/integrator.jl b/src/riemannian/integrator.jl index 06f557ba..1c15eefb 100644 --- a/src/riemannian/integrator.jl +++ b/src/riemannian/integrator.jl @@ -90,4 +90,4 @@ function step( end end return res -end \ No newline at end of file +end From 62b4ede3fe94f38d733a654fbe95dfee3045f4eb Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 23 Jul 2024 17:53:03 -0400 Subject: [PATCH 4/6] Update src/riemannian/integrator.jl Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- src/riemannian/integrator.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/riemannian/integrator.jl 
b/src/riemannian/integrator.jl index 1c15eefb..25461d1a 100644 --- a/src/riemannian/integrator.jl +++ b/src/riemannian/integrator.jl @@ -6,6 +6,11 @@ Generalized leapfrog integrator with fixed step size `ϵ`. # Fields $(TYPEDFIELDS) + + +## References + +1. Girolami, Mark, and Ben Calderhead. "Riemann manifold Langevin and Hamiltonian Monte Carlo methods." Journal of the Royal Statistical Society Series B: Statistical Methodology 73, no. 2 (2011): 123-214. """ struct GeneralizedLeapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T} "Step size." From 5462aad73b1220f87ebd2b1abba7b3dc14d3da15 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Wed, 24 Jul 2024 00:02:24 +0200 Subject: [PATCH 5/6] chore: add warning for using generalized leapfrog with vectorization Signed-off-by: Kai Xu --- src/riemannian/integrator.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/riemannian/integrator.jl b/src/riemannian/integrator.jl index 25461d1a..d6108340 100644 --- a/src/riemannian/integrator.jl +++ b/src/riemannian/integrator.jl @@ -36,12 +36,16 @@ function step( n_steps::Int = 1; fwd::Bool = n_steps > 0, # simulate hamiltonian backward when n_steps < 0 full_trajectory::Val{FullTraj} = Val(false), -) where {T<:AbstractScalarOrVec{<:AbstractFloat},P<:PhasePoint,FullTraj} +) where {T<:AbstractScalarOrVec{<:AbstractFloat},P<:PhasePoint{TP},FullTraj,TP} n_steps = abs(n_steps) # to support `n_steps < 0` cases ϵ = fwd ? step_size(lf) : -step_size(lf) ϵ = ϵ' + if !(T <: AbstractFloat) || !(TP <: AbstractVector) + @warn "Vectorization is not tested for GeneralizedLeapfrog." 
+ end + res = if FullTraj Vector{P}(undef, n_steps) else From ccfd74043ceccae2669b7c2d8323f5f6c1fa3066 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Wed, 24 Jul 2024 00:10:57 +0200 Subject: [PATCH 6/6] fix: type order Signed-off-by: Kai Xu --- src/riemannian/integrator.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/riemannian/integrator.jl b/src/riemannian/integrator.jl index d6108340..cce0dc27 100644 --- a/src/riemannian/integrator.jl +++ b/src/riemannian/integrator.jl @@ -36,7 +36,7 @@ function step( n_steps::Int = 1; fwd::Bool = n_steps > 0, # simulate hamiltonian backward when n_steps < 0 full_trajectory::Val{FullTraj} = Val(false), -) where {T<:AbstractScalarOrVec{<:AbstractFloat},P<:PhasePoint{TP},FullTraj,TP} +) where {T<:AbstractScalarOrVec{<:AbstractFloat},TP,P<:PhasePoint{TP},FullTraj} n_steps = abs(n_steps) # to support `n_steps < 0` cases ϵ = fwd ? step_size(lf) : -step_size(lf)