Skip to content

Commit 981932e

Browse files
committed
Implement Implicit Midpoint integrator
1 parent 8f1ebc5 commit 981932e

File tree

2 files changed

+101
-5
lines changed

2 files changed

+101
-5
lines changed

src/AdvancedHMC.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ export Leapfrog, JitteredLeapfrog, TemperedLeapfrog
5151
include("riemannian/metric.jl")
5252
export AbstractRiemannianMetric, DenseRiemannianMetric, IdentityMap, SoftAbsMap
5353
include("riemannian/integrator.jl")
54-
export GeneralizedLeapfrog
54+
export GeneralizedLeapfrog, ImplicitMidpoint
5555
include("riemannian/hamiltonian.jl")
5656

5757
include("trajectory.jl")

src/riemannian/integrator.jl

Lines changed: 100 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,47 @@ function Base.show(io::IO, l::GeneralizedLeapfrog)
2424
return print(io, "GeneralizedLeapfrog(ϵ=", round.(l.ϵ; sigdigits=3), ", n=", l.n, ")")
2525
end
2626

27-
# fallback to ignore return_cache & cache kwargs for other ∂H∂θ
28-
function ∂H∂θ_cache(h, θ, r; return_cache=false, cache=nothing)
29-
dv = ∂H∂θ(h, θ, r)
30-
return return_cache ? (dv, nothing) : dv
27+
abstract type AbstractImplicitMidpoint{T} <: AbstractIntegrator end
28+
29+
step_size(lf::AbstractImplicitMidpoint) = lf.ϵ
30+
jitter(::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, lf::AbstractImplicitMidpoint) = lf
31+
function temper(
32+
lf::AbstractImplicitMidpoint, r, ::NamedTuple{(:i, :is_half),<:Tuple{Integer,Bool}}, ::Int
33+
)
34+
return r
3135
end
36+
stat(lf::AbstractImplicitMidpoint) = (step_size=step_size(lf), nom_step_size=nom_step_size(lf))
37+
update_nom_step_size(lf::AbstractImplicitMidpoint, ϵ) = @set lf.ϵ = ϵ
38+
39+
"""
40+
$(TYPEDEF)
41+
42+
Implicit midpoint integrator with fixed step size `ϵ`.
43+
44+
# Fields
45+
46+
$(TYPEDFIELDS)
47+
48+
49+
## References
50+
51+
1. James A. Brofos, Roy R. Lederman. "Evaluating the Implicit Midpoint
52+
Integrator for Riemannian Manifold Hamiltonian Monte Carlo"
53+
"""
54+
struct ImplicitMidpoint{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T}
55+
"Step size."
56+
ϵ::T
57+
n::Int
58+
end
59+
function Base.show(io::IO, l::ImplicitMidpoint)
60+
return print(io, "ImplicitMidpoint(ϵ=", round.(l.ϵ; sigdigits=3), ", n=", l.n, ")")
61+
end
62+
63+
# fallback to ignore return_cache & cache kwargs for other ∂H∂θ
64+
# function ∂H∂θ_cache(h, θ, r; return_cache=false, cache=nothing)
65+
# dv = ∂H∂θ(h, θ, r)
66+
# return return_cache ? (dv, nothing) : dv
67+
# end
3268

3369
# TODO(Kai) make sure vectorization works
3470
# TODO(Kai) check if tempering is valid
@@ -104,3 +140,63 @@ function step(
104140
end
105141
return res
106142
end
143+
144+
function step(
145+
lf::ImplicitMidpoint{T},
146+
h::Hamiltonian,
147+
z::P,
148+
n_steps::Int=1;
149+
fwd::Bool=n_steps > 0, # simulate hamiltonian backward when n_steps < 0
150+
full_trajectory::Val{FullTraj}=Val(false),
151+
) where {T<:AbstractScalarOrVec{<:AbstractFloat},TP,P<:PhasePoint{TP},FullTraj}
152+
n_steps = abs(n_steps) # to support `n_steps < 0` cases
153+
154+
ϵ = fwd ? step_size(lf) : -step_size(lf)
155+
ϵ = ϵ'
156+
157+
if !(T <: AbstractFloat) || !(TP <: AbstractVector)
158+
@warn "Vectorization is not tested for ImplicitMidpoint."
159+
end
160+
161+
res = if FullTraj
162+
Vector{P}(undef, n_steps)
163+
else
164+
z
165+
end
166+
167+
for i in 1:n_steps
168+
θ_init, r_init = z.θ, z.r
169+
170+
171+
θ_full = θ_init
172+
r_full = r_init
173+
for j in 1:(lf.n)
174+
θ_bar = (θ_full + θ_init) / 2
175+
r_bar = (r_full + r_init) / 2
176+
177+
dHdr = ∂H∂r(h, θ_bar, r_bar)
178+
(; value, gradient) = ∂H∂θ(h, θ_bar, r_bar)
179+
180+
θ_full = θ_init + ϵ * dHdr
181+
r_full = r_init - ϵ * gradient
182+
end
183+
184+
(; value, gradient) = ∂H∂θ(h, θ_full, r_full)
185+
z = phasepoint(h, θ_full, r_full; ℓπ=DualValue(value, gradient))
186+
187+
if FullTraj
188+
res[i] = z
189+
else
190+
res = z
191+
end
192+
if !isfinite(z)
193+
# Remove undef
194+
if FullTraj
195+
res = res[isassigned.(Ref(res), 1:n_steps)]
196+
end
197+
break
198+
end
199+
end
200+
201+
return res
202+
end

0 commit comments

Comments
 (0)