Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,16 @@ This modularity means that different HMC variants can be easily constructed by c
- Diagonal metric: `DiagEuclideanMetric(dim)`
- Dense metric: `DenseEuclideanMetric(dim)`

where `dim` is the dimensionality of the sampling space.
where `dim` is the dimension of the sampling space.

Furthermore, there is now an experimental dense Riemannian metric implementation, specifiable as `DenseRiemannianMetric(dim, premetric, premetric_sensitivities, metric_map=IdentityMap())`, with

- `dim`: again the dimension of the sampling space,
- `premetric`: a function which, for a given posterior position `pos`, computes either
a) a symmetric, **positive definite** matrix acting as the position dependent Riemannian metric (if `metric_map = IdentityMap()`), or
b) a symmetric, **not necessarily positive definite** matrix acting as the position dependent Riemannian metric after being passed through the `metric_map` argument, which will have to ensure that its return value *is* positive definite (like `metric_map = SoftAbsMap(alpha)`),
Comment on lines +17 to +19
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
- `premetric`: a function which, for a given posterior position `pos`, computes either
a) a symmetric, **positive definite** matrix acting as the position dependent Riemannian metric (if `metric_map = IdentityMap()`), or
b) a symmetric, **not necessarily positive definite** matrix acting as the position dependent Riemannian metric after being passed through the `metric_map` argument, which will have to ensure that its return value *is* positive definite (like `metric_map = SoftAbsMap(alpha)`),
- `premetric`: a function which, for a given posterior position `pos`, computes either
a) a symmetric, **positive definite** matrix acting as the position dependent Riemannian metric (if `metric_map = IdentityMap()`), or
b) a symmetric, **not necessarily positive definite** matrix acting as the position dependent Riemannian metric after being passed through the `metric_map` argument, which will have to ensure that its return value *is* positive definite (like `metric_map = SoftAbsMap(alpha)`),

- `premetric_sensitivities`: a function which, again for a given posterior position `pos`, computes the sensitivities with respect to this position of the **`premetric`** function,
- `metric_map=IdentityMap()`: a function which takes in `premetric(pos)` and returns a symmetric positive definite matrix. Provided options are `IdentityMap()` or `SoftAbsMap(alpha)`, with the `SoftAbsMap` type allowing to work directly with the `premetric` returning the Hessian of the log density function, which generally is not guaranteed to be positive definite..

### [Integrator (`integrator`)](@id integrator)

Expand Down
21 changes: 19 additions & 2 deletions src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,19 @@ module AdvancedHMC

using Statistics: mean, var, middle
using LinearAlgebra:
Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, cholesky, UniformScaling
Symmetric,
UpperTriangular,
mul!,
ldiv!,
dot,
I,
diag,
cholesky,
UniformScaling,
logdet,
tr,
eigen,
diagm
using StatsFuns: logaddexp, logsumexp, loghalf
using Random: Random, AbstractRNG
using ProgressMeter: ProgressMeter
Expand Down Expand Up @@ -40,7 +52,7 @@ struct GaussianKinetic <: AbstractKinetic end
export GaussianKinetic

include("metric.jl")
export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric
export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric, DenseRiemannianMetric

include("hamiltonian.jl")
export Hamiltonian
Expand All @@ -50,6 +62,11 @@ export Leapfrog, JitteredLeapfrog, TemperedLeapfrog
include("riemannian/integrator.jl")
export GeneralizedLeapfrog

include("riemannian/metric.jl")
export IdentityMap, SoftAbsMap, DenseRiemannianMetric

include("riemannian/hamiltonian.jl")

include("trajectory.jl")
export Trajectory,
HMCKernel,
Expand Down
Loading
Loading