Skip to content

Commit 85adcbd

Browse files
devmotionsunxd3
andauthored
Fix show definitions (#466)
* Fix `show` definitions * Update src/integrator.jl Co-authored-by: Xianda Sun <[email protected]> --------- Co-authored-by: Xianda Sun <[email protected]>
1 parent 7290b91 commit 85adcbd

File tree

8 files changed

+62
-16
lines changed

8 files changed

+62
-16
lines changed

src/adaptation/Adaptation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ struct NaiveHMCAdaptor{M<:MassMatrixAdaptor,Tssa<:StepSizeAdaptor} <: AbstractAd
3737
pc::M
3838
ssa::Tssa
3939
end
40-
function Base.show(io::IO, ::MIME"text/plain", a::NaiveHMCAdaptor)
40+
function Base.show(io::IO, a::NaiveHMCAdaptor)
4141
return print(io, "NaiveHMCAdaptor(pc=", a.pc, ", ssa=", a.ssa, ")")
4242
end
4343

src/adaptation/massmatrix.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ end
2323

2424
struct UnitMassMatrix{T<:AbstractFloat} <: MassMatrixAdaptor end
2525

26-
function Base.show(io::IO, mime::MIME"text/plain", ::UnitMassMatrix{T}) where {T}
26+
function Base.show(io::IO, ::UnitMassMatrix{T}) where {T}
2727
return print(io, "UnitMassMatrix{", T, "} adaptor")
2828
end
2929

@@ -93,7 +93,7 @@ mutable struct WelfordVar{T<:AbstractFloat,E<:AbstractVecOrMat{T},V<:AbstractVec
9393
end
9494
end
9595

96-
function Base.show(io::IO, mime::MIME"text/plain", ::WelfordVar{T}) where {T}
96+
function Base.show(io::IO, ::WelfordVar{T}) where {T}
9797
return print(io, "WelfordVar{", T, "} adaptor")
9898
end
9999

@@ -194,7 +194,7 @@ mutable struct WelfordCov{F<:AbstractFloat,C<:AbstractMatrix{F}} <: DenseMatrixE
194194
cov::C
195195
end
196196

197-
function Base.show(io::IO, mime::MIME"text/plain", ::WelfordCov{T}) where {T}
197+
function Base.show(io::IO, ::WelfordCov{T}) where {T}
198198
return print(io, "WelfordCov{", T, "} adaptor")
199199
end
200200

src/adaptation/stan_adaptor.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ function initialize!(
4949
return nothing
5050
end
5151

52-
function Base.show(io::IO, mime::MIME"text/plain", state::StanHMCAdaptorState)
52+
function Base.show(io::IO, state::StanHMCAdaptorState)
5353
print(io, "window(", state.window_start, ", ", state.window_end, "), window_splits(")
5454
join(io, state.window_splits, ", ")
5555
return print(io, ")")
@@ -66,6 +66,23 @@ struct StanHMCAdaptor{M<:MassMatrixAdaptor,Tssa<:StepSizeAdaptor} <: AbstractAda
6666
window_size::Int
6767
state::StanHMCAdaptorState
6868
end
69+
70+
function Base.show(io::IO, a::StanHMCAdaptor)
71+
return print(
72+
io,
73+
"StanHMCAdaptor(",
74+
a.pc,
75+
", ",
76+
a.ssa,
77+
"; init_buffer=",
78+
a.init_buffer,
79+
", term_buffer=",
80+
a.term_buffer,
81+
", window_size=",
82+
a.window_size,
83+
")",
84+
)
85+
end
6986
function Base.show(io::IO, mime::MIME"text/plain", a::StanHMCAdaptor)
7087
return print(
7188
io,

src/adaptation/stepsize.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ getϵ(ss::StepSizeAdaptor) = ss.state.ϵ
7777
struct FixedStepSize{T<:AbstractScalarOrVec{<:AbstractFloat}} <: StepSizeAdaptor
7878
ϵ::T
7979
end
80-
function Base.show(io::IO, mime::MIME"text/plain", a::FixedStepSize)
80+
function Base.show(io::IO, a::FixedStepSize)
8181
return print(io, "FixedStepSize adaptor with step size ", a.ϵ)
8282
end
8383

@@ -86,7 +86,7 @@ getϵ(fss::FixedStepSize) = fss.ϵ
8686
struct ManualSSAdaptor{T<:AbstractScalarOrVec{<:AbstractFloat}} <: StepSizeAdaptor
8787
state::MSSState{T}
8888
end
89-
function Base.show(io::IO, mime::MIME"text/plain", a::ManualSSAdaptor{T}) where {T}
89+
function Base.show(io::IO, a::ManualSSAdaptor{T}) where {T}
9090
return print(io, "ManualSSAdaptor{", T, "} with step size of ", a.state.ϵ)
9191
end
9292

@@ -119,6 +119,23 @@ struct NesterovDualAveraging{T<:AbstractFloat,S<:AbstractScalarOrVec{T}} <: Step
119119
δ::T
120120
state::DAState{S}
121121
end
122+
123+
function Base.show(io::IO, a::NesterovDualAveraging)
124+
print(
125+
io,
126+
"NesterovDualAveraging(",
127+
a.γ,
128+
", ",
129+
a.t_0,
130+
", ",
131+
a.κ,
132+
", ",
133+
a.δ,
134+
", ",
135+
a.state.ϵ,
136+
")",
137+
)
138+
end
122139
function Base.show(io::IO, mime::MIME"text/plain", a::NesterovDualAveraging{T}) where {T}
123140
return print(
124141
io,

src/hamiltonian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ struct Hamiltonian{M<:AbstractMetric,K<:AbstractKinetic,Tlogπ,T∂logπ∂θ}
44
ℓπ::Tlogπ
55
∂ℓπ∂θ::T∂logπ∂θ
66
end
7-
function Base.show(io::IO, mime::MIME"text/plain", h::Hamiltonian)
7+
function Base.show(io::IO, h::Hamiltonian)
88
return print(
99
io,
1010
"Hamiltonian with ",

src/integrator.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ struct Leapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T}
7272
"Step size."
7373
ϵ::T
7474
end
75-
function Base.show(io::IO, mime::MIME"text/plain", l::Leapfrog)
76-
return print(io, "Leapfrog with step size ϵ=", round.(l.ϵ; sigdigits=3), ")")
75+
function Base.show(io::IO, l::Leapfrog)
76+
return print(io, "Leapfrog with step size ϵ=", round.(l.ϵ; sigdigits=3))
7777
end
7878
integrator_eltype(i::AbstractLeapfrog{T}) where {T<:AbstractFloat} = T
7979

@@ -120,7 +120,7 @@ end
120120

121121
JitteredLeapfrog(ϵ0, jitter) = JitteredLeapfrog(ϵ0, jitter, ϵ0)
122122

123-
function Base.show(io::IO, mime::MIME"text/plain", l::JitteredLeapfrog)
123+
function Base.show(io::IO, l::JitteredLeapfrog)
124124
return print(
125125
io,
126126
"JitteredLeapfrog with step size ",
@@ -178,7 +178,7 @@ struct TemperedLeapfrog{FT<:AbstractFloat,T<:AbstractScalarOrVec{FT}} <: Abstrac
178178
α::FT
179179
end
180180

181-
function Base.show(io::IO, mime::MIME"text/plain", l::TemperedLeapfrog)
181+
function Base.show(io::IO, l::TemperedLeapfrog)
182182
return print(
183183
io,
184184
"TemperedLeapfrog with step size ϵ=",

src/metric.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ abstract type AbstractMetric end
77

88
_string_M⁻¹(mat::AbstractMatrix, n_chars::Int=32) = _string_M⁻¹(diag(mat), n_chars)
99
function _string_M⁻¹(vec::AbstractVector, n_chars::Int=32)
10-
s_vec = string(vec)
10+
s_vec = repr(vec; context=(:compact => true))
1111
l = length(s_vec)
1212
s_dots = " ...]"
1313
n_diag_chars = n_chars - length(s_dots)
@@ -33,6 +33,10 @@ renew(ue::UnitEuclideanMetric, M⁻¹) = UnitEuclideanMetric(M⁻¹, ue.size)
3333
Base.eltype(::UnitEuclideanMetric{T}) where {T} = T
3434
Base.size(e::UnitEuclideanMetric) = e.size
3535
Base.size(e::UnitEuclideanMetric, dim::Int) = e.size[dim]
36+
37+
function Base.show(io::IO, uem::UnitEuclideanMetric{T}) where {T}
38+
print(io, "UnitEuclideanMetric(", T, ", ", uem.size, ")")
39+
end
3640
function Base.show(io::IO, ::MIME"text/plain", uem::UnitEuclideanMetric{T}) where {T}
3741
return print(
3842
io,
@@ -66,6 +70,10 @@ renew(ue::DiagEuclideanMetric, M⁻¹) = DiagEuclideanMetric(M⁻¹)
6670

6771
Base.eltype(::DiagEuclideanMetric{T}) where {T} = T
6872
Base.size(e::DiagEuclideanMetric, dim...) = size(e.M⁻¹, dim...)
73+
74+
function Base.show(io::IO, dem::DiagEuclideanMetric)
75+
print(io, "DiagEuclideanMetric(", _string_M⁻¹(dem.M⁻¹), ")")
76+
end
6977
function Base.show(io::IO, ::MIME"text/plain", dem::DiagEuclideanMetric{T}) where {T}
7078
return print(
7179
io,
@@ -110,6 +118,10 @@ renew(ue::DenseEuclideanMetric, M⁻¹) = DenseEuclideanMetric(M⁻¹)
110118

111119
Base.eltype(::DenseEuclideanMetric{T}) where {T} = T
112120
Base.size(e::DenseEuclideanMetric, dim...) = size(e._temp, dim...)
121+
122+
function Base.show(io::IO, dem::DenseEuclideanMetric)
123+
print(io, "DenseEuclideanMetric(", _string_M⁻¹(dem.M⁻¹), ")")
124+
end
113125
function Base.show(io::IO, ::MIME"text/plain", dem::DenseEuclideanMetric{T}) where {T}
114126
return print(
115127
io,

src/trajectory.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ struct SliceTS{F<:AbstractFloat,P<:PhasePoint} <: AbstractTrajectorySampler
108108
n::Int
109109
end
110110

111-
function Base.show(io::IO, mime::MIME"text/plain", s::SliceTS)
111+
function Base.show(io::IO, s::SliceTS)
112112
return print(
113113
io,
114114
"SliceTS with slice variable ℓu=",
@@ -225,7 +225,7 @@ end
225225

226226
ConstructionBase.constructorof(::Type{<:Trajectory{TS}}) where {TS} = Trajectory{TS}
227227

228-
function Base.show(io::IO, mime::MIME"text/plain", τ::Trajectory{TS}) where {TS}
228+
function Base.show(io::IO, τ::Trajectory{TS}) where {TS}
229229
return print(
230230
io,
231231
"Trajectory{",
@@ -482,7 +482,7 @@ struct Termination
482482
numerical::Bool
483483
end
484484

485-
function Base.show(io::IO, mime::MIME"text/plain", d::Termination)
485+
function Base.show(io::IO, d::Termination)
486486
return print(
487487
io, "Termination reasons of (dynamic=", d.dynamic, ", numerical=", d.numerical, ")"
488488
)

0 commit comments

Comments
 (0)