Skip to content

Commit 7d4a86f

Browse files
committed
Small optimisations for hamiltonian.jl
1 parent d96c29c commit 7d4a86f

File tree

1 file changed

+52
-15
lines changed

1 file changed

+52
-15
lines changed

src/riemannian/hamiltonian.jl

Lines changed: 52 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -129,38 +129,75 @@ function ∂H∂θ_cache(
129129

130130
G, Q, λ, softabsλ = softabs(H, h.metric.map.α)
131131

132-
R = diagm(1 ./ softabsλ)
132+
R = Diagonal(1 ./ softabsλ)
133133

134134
# softabsΛ = diagm(softabsλ)
135135
# M = inv(softabsΛ) * Q' * r
136136
# M = R * Q' * r # equiv to above but avoid inv
137137

138138
J = make_J(λ, h.metric.map.α)
139139

140+
tmp1 = similar(H)
141+
tmp2 = similar(H)
142+
tmp3 = similar(H)
143+
tmp4 = similar(softabsλ)
144+
140145
#! Based on the two equations from the right column of Page 3 of Betancourt (2012)
141-
term_1_cached = Q * (R .* J) * Q'
146+
tmp1 = R .* J
147+
# tmp2 = Q * tmp1
148+
mul!(tmp2, Q, tmp1)
149+
150+
# tmp1 = tmp2 * Q'
151+
mul!(tmp1, tmp2, Q')
152+
153+
term_1_cached = tmp1
154+
155+
# Cache first part of the equation
156+
term_1_prod = similar(∂ℓπ∂θ)
157+
@inbounds for i in 1:length(∂ℓπ∂θ)
158+
∂H∂θᵢ = ∂H∂θ[:, :, i]
159+
term_1_prod[i] = ∂ℓπ∂θ[i] - 1/2 * tr(term_1_cached * ∂H∂θᵢ)
160+
end
161+
142162
else
143-
ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached = cache
163+
ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_prod, tmp1, tmp2, tmp3, tmp4 = cache
144164
end
145165
d = length(∂ℓπ∂θ)
146-
D = diagm((Q' * r) ./ softabsλ)
147-
term_2_cached = Q * D * J * D * Q'
148-
g =
149-
-mapreduce(vcat, 1:d) do i
150-
∂H∂θᵢ = ∂H∂θ[:, :, i]
151-
# ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1)
152-
# NOTE Some further optimization can be done here: cache the 1st product all together
153-
∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # (v2) cache friendly
154-
end
166+
mul!(tmp4, Q', r)
167+
D = Diagonal(tmp4 ./ softabsλ)
168+
169+
# tmp1 = D * J
170+
mul!(tmp1, D, J)
171+
# tmp2 = tmp1 * D
172+
mul!(tmp2, tmp1, D)
173+
# tmp1 = Q * tmp2
174+
mul!(tmp1, Q, tmp2)
175+
# tmp2 = tmp1 * Q'
176+
mul!(tmp2, tmp1, Q')
177+
term_2_cached = tmp2
178+
179+
# g =
180+
# -mapreduce(vcat, 1:d) do i
181+
# ∂H∂θᵢ = ∂H∂θ[:, :, i]
182+
# # ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1)
183+
# # NOTE Some further optimization can be done here: cache the 1st product all together
184+
# ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # (v2) cache friendly
185+
# end
186+
g = similar(∂ℓπ∂θ)
187+
@inbounds for i in 1:d
188+
∂H∂θᵢ = ∂H∂θ[:, :, i]
189+
g[i] = term_1_prod[i] + 1/2 * tr(term_2_cached * ∂H∂θᵢ)
190+
end
191+
g .*= -1
155192

156193
dv = DualValue(ℓπ, g)
157-
return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached)) : dv
194+
return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_prod, tmp1, tmp2, tmp3, tmp4)) : dv
158195
end
159196

160197
#! Eq (14) of Girolami & Calderhead (2011)
161198
function ∂H∂r(
162-
h::Hamiltonian{<:DenseRiemannianMetric}, θ::AbstractVecOrMat, r::AbstractVecOrMat
163-
)
199+
h::Hamiltonian{<:DenseRiemannianMetric}, θ::AbstractVecOrMat{T}, r::AbstractVecOrMat{T}
200+
) where {T}
164201
H = h.metric.G(θ)
165202
# if !all(isfinite, H)
166203
# println("θ: ", θ)

0 commit comments

Comments
 (0)