@@ -129,38 +129,75 @@ function ∂H∂θ_cache(
129129
130130 G, Q, λ, softabsλ = softabs (H, h. metric. map. α)
131131
132- R = diagm (1 ./ softabsλ)
132+ R = Diagonal (1 ./ softabsλ)
133133
134134 # softabsΛ = diagm(softabsλ)
135135 # M = inv(softabsΛ) * Q' * r
136136 # M = R * Q' * r # equiv to above but avoid inv
137137
138138 J = make_J (λ, h. metric. map. α)
139139
140+ tmp1 = similar (H)
141+ tmp2 = similar (H)
142+ tmp3 = similar (H)
143+ tmp4 = similar (softabsλ)
144+
140145 # ! Based on the two equations from the right column of Page 3 of Betancourt (2012)
141- term_1_cached = Q * (R .* J) * Q'
146+ tmp1 = R .* J
147+ # tmp2 = Q * tmp1
148+ mul! (tmp2, Q, tmp1)
149+
150+ # tmp1 = tmp2 * Q'
151+ mul! (tmp1, tmp2, Q' )
152+
153+ term_1_cached = tmp1
154+
155+ # Cache first part of the equation
156+ term_1_prod = similar (∂ℓπ∂θ)
157+ @inbounds for i in 1 : length (∂ℓπ∂θ)
158+ ∂H∂θᵢ = ∂H∂θ[:, :, i]
159+ term_1_prod[i] = ∂ℓπ∂θ[i] - 1 / 2 * tr (term_1_cached * ∂H∂θᵢ)
160+ end
161+
142162 else
143- ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached = cache
163+ ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_prod, tmp1, tmp2, tmp3, tmp4 = cache
144164 end
145165 d = length (∂ℓπ∂θ)
146- D = diagm ((Q' * r) ./ softabsλ)
147- term_2_cached = Q * D * J * D * Q'
148- g =
149- - mapreduce (vcat, 1 : d) do i
150- ∂H∂θᵢ = ∂H∂θ[:, :, i]
151- # ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1)
152- # NOTE Some further optimization can be done here: cache the 1st product all together
153- ∂ℓπ∂θ[i] - 1 / 2 * tr (term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr (term_2_cached * ∂H∂θᵢ) # (v2) cache friendly
154- end
166+ mul! (tmp4, Q' , r)
167+ D = Diagonal (tmp4 ./ softabsλ)
168+
169+ # tmp1 = D * J
170+ mul! (tmp1, D, J)
171+ # tmp2 = tmp1 * D
172+ mul! (tmp2, tmp1, D)
173+ # tmp1 = Q * tmp2
174+ mul! (tmp1, Q, tmp2)
175+ # tmp2 = tmp1 * Q'
176+ mul! (tmp2, tmp1, Q' )
177+ term_2_cached = tmp2
178+
179+ # g =
180+ # -mapreduce(vcat, 1:d) do i
181+ # ∂H∂θᵢ = ∂H∂θ[:, :, i]
182+ # # ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1)
183+ # # NOTE Some further optimization can be done here: cache the 1st product all together
184+ # ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # (v2) cache friendly
185+ # end
186+ g = similar (∂ℓπ∂θ)
187+ @inbounds for i in 1 : d
188+ ∂H∂θᵢ = ∂H∂θ[:, :, i]
189+ g[i] = term_1_prod[i] + 1 / 2 * tr (term_2_cached * ∂H∂θᵢ)
190+ end
191+ g .*= - 1
155192
156193 dv = DualValue (ℓπ, g)
157- return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached )) : dv
194+ return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_prod, tmp1, tmp2, tmp3, tmp4 )) : dv
158195end
159196
160197# ! Eq (14) of Girolami & Calderhead (2011)
161198function ∂H∂r (
162- h:: Hamiltonian{<:DenseRiemannianMetric} , θ:: AbstractVecOrMat , r:: AbstractVecOrMat
163- )
199+ h:: Hamiltonian{<:DenseRiemannianMetric} , θ:: AbstractVecOrMat{T} , r:: AbstractVecOrMat{T}
200+ ) where {T}
164201 H = h. metric. G (θ)
165202 # if !all(isfinite, H)
166203 # println("θ: ", θ)
0 commit comments