@@ -24,11 +24,47 @@ function Base.show(io::IO, l::GeneralizedLeapfrog)
2424 return print (io, " GeneralizedLeapfrog(ϵ=" , round .(l. ϵ; sigdigits= 3 ), " , n=" , l. n, " )" )
2525end
2626
27- # fallback to ignore return_cache & cache kwargs for other ∂H∂θ
28- function ∂H∂θ_cache (h, θ, r; return_cache= false , cache= nothing )
29- dv = ∂H∂θ (h, θ, r)
30- return return_cache ? (dv, nothing ) : dv
27+ abstract type AbstractImplicitMidpoint{T} <: AbstractIntegrator end
28+
29+ step_size (lf:: AbstractImplicitMidpoint ) = lf. ϵ
30+ jitter (:: Union{AbstractRNG,AbstractVector{<:AbstractRNG}} , lf:: AbstractImplicitMidpoint ) = lf
31+ function temper (
32+ lf:: AbstractImplicitMidpoint , r, :: NamedTuple{(:i, :is_half),<:Tuple{Integer,Bool}} , :: Int
33+ )
34+ return r
3135end
36+ stat (lf:: AbstractImplicitMidpoint ) = (step_size= step_size (lf), nom_step_size= nom_step_size (lf))
37+ update_nom_step_size (lf:: AbstractImplicitMidpoint , ϵ) = @set lf. ϵ = ϵ
38+
39+ """
40+ $(TYPEDEF)
41+
42+ Implicit midpoint integrator with fixed step size `ϵ`.
43+
44+ # Fields
45+
46+ $(TYPEDFIELDS)
47+
48+
49+ ## References
50+
51+ 1. James A. Brofos, Roy R. Lederman. "Evaluating the Implicit Midpoint
52+ Integrator for Riemannian Manifold Hamiltonian Monte Carlo"
53+ """
54+ struct ImplicitMidpoint{T<: AbstractScalarOrVec{<:AbstractFloat} } <: AbstractLeapfrog{T}
55+ " Step size."
56+ ϵ:: T
57+ n:: Int
58+ end
59+ function Base. show (io:: IO , l:: ImplicitMidpoint )
60+ return print (io, " ImplicitMidpoint(ϵ=" , round .(l. ϵ; sigdigits= 3 ), " , n=" , l. n, " )" )
61+ end
62+
63+ # fallback to ignore return_cache & cache kwargs for other ∂H∂θ
64+ # function ∂H∂θ_cache(h, θ, r; return_cache=false, cache=nothing)
65+ # dv = ∂H∂θ(h, θ, r)
66+ # return return_cache ? (dv, nothing) : dv
67+ # end
3268
3369# TODO (Kai) make sure vectorization works
3470# TODO (Kai) check if tempering is valid
@@ -104,3 +140,63 @@ function step(
104140 end
105141 return res
106142end
143+
144+ function step (
145+ lf:: ImplicitMidpoint{T} ,
146+ h:: Hamiltonian ,
147+ z:: P ,
148+ n_steps:: Int = 1 ;
149+ fwd:: Bool = n_steps > 0 , # simulate hamiltonian backward when n_steps < 0
150+ full_trajectory:: Val{FullTraj} = Val (false ),
151+ ) where {T<: AbstractScalarOrVec{<:AbstractFloat} ,TP,P<: PhasePoint{TP} ,FullTraj}
152+ n_steps = abs (n_steps) # to support `n_steps < 0` cases
153+
154+ ϵ = fwd ? step_size (lf) : - step_size (lf)
155+ ϵ = ϵ'
156+
157+ if ! (T <: AbstractFloat ) || ! (TP <: AbstractVector )
158+ @warn " Vectorization is not tested for ImplicitMidpoint."
159+ end
160+
161+ res = if FullTraj
162+ Vector {P} (undef, n_steps)
163+ else
164+ z
165+ end
166+
167+ for i in 1 : n_steps
168+ θ_init, r_init = z. θ, z. r
169+
170+
171+ θ_full = θ_init
172+ r_full = r_init
173+ for j in 1 : (lf. n)
174+ θ_bar = (θ_full + θ_init) / 2
175+ r_bar = (r_full + r_init) / 2
176+
177+ dHdr = ∂H∂r (h, θ_bar, r_bar)
178+ (; value, gradient) = ∂H∂θ (h, θ_bar, r_bar)
179+
180+ θ_full = θ_init + ϵ * dHdr
181+ r_full = r_init - ϵ * gradient
182+ end
183+
184+ (; value, gradient) = ∂H∂θ (h, θ_full, r_full)
185+ z = phasepoint (h, θ_full, r_full; ℓπ= DualValue (value, gradient))
186+
187+ if FullTraj
188+ res[i] = z
189+ else
190+ res = z
191+ end
192+ if ! isfinite (z)
193+ # Remove undef
194+ if FullTraj
195+ res = res[isassigned .(Ref (res), 1 : n_steps)]
196+ end
197+ break
198+ end
199+ end
200+
201+ return res
202+ end
0 commit comments