Skip to content

Commit f4c9b58

Browse files
authored
Merge pull request #31 from gdkrmr/ND-Histograms
Nd histograms
2 parents d98b9c2 + 69db6aa commit f4c9b58

File tree

3 files changed

+174
-40
lines changed

3 files changed

+174
-40
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
name = "WeightedOnlineStats"
22
uuid = "bbac0a1f-7c9d-5672-960b-c6ca726e5d5d"
33
authors = ["Guido Kraemer <[email protected]>", "Martin Gutwin <[email protected]>"]
4-
version = "0.3.2"
4+
version = "0.4.0"
55

66
[compat]
77
julia = "1"
8+
MultivariateStats = "0.7"
9+
OnlineStats = "1"
10+
OnlineStatsBase = "1"
11+
StatsBase = "0.30,0.31,0.32"
812

913
[deps]
1014
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/histogram.jl

Lines changed: 120 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
# Modifying it to work with WeightedOnlineStats
44
##############################################################
55

6+
import LinearAlgebra
67
abstract type WeightedHistogramStat{T} <: WeightedOnlineStat{T} end
8+
abstract type WeightedHist{T} <: WeightedHistogramStat{T} end
79
split_candidates(o::WeightedHistogramStat) = midpoints(o)
810
Statistics.mean(o::WeightedHistogramStat) = mean(midpoints(o), fweights(counts(o)))
911
Statistics.var(o::WeightedHistogramStat) = var(midpoints(o), fweights(counts(o)); corrected=true)
10-
Statistics.std(o::WeightedHistogramStat) = sqrt(var(o))
12+
Statistics.std(o::WeightedHistogramStat) = sqrt.(var(o))
1113
Statistics.median(o::WeightedHistogramStat) = quantile(o, .5)
1214

1315
function Base.show(io::IO, o::WeightedHistogramStat)
@@ -25,7 +27,12 @@ Create a histogram with bin partition defined by `edges`.
2527
- If `left`, the bins will be left-closed.
2628
- If `closed`, the bin on the end will be closed.
2729
- E.g. for a two bin histogram ``[a, b), [b, c)`` vs. ``[a, b), [b, c]``
28-
# Example
30+
31+
If `edges` is a tuple instead of an array, a multidimensional histogram will be
32+
generated that behaves like a `WeightedOnlineStat{VectorOb}`.
33+
34+
# Examples
35+
2936
o = fit!(WeightedHist(-5:.1:5), randn(10^6))
3037
3138
# approximate statistics
@@ -38,68 +45,152 @@ Create a histogram with bin partition defined by `edges`.
3845
extrema(o)
3946
area(o)
4047
pdf(o)
48+
49+
## 2d Histogram
50+
51+
hist2d = fit!(WeightedHist((-5:1:5, -5:1:5) ), randn(10000,2), rand(10000))
52+
value(hist2d).y
4153
"""
42-
struct WeightedHist{T, R} <: WeightedHistogramStat{T}
54+
struct WeightedHist1D{R} <: WeightedHist{Float64}
55+
edges::R
56+
counts::Vector{Int}
57+
meanw::Vector{Float64}
58+
outcount::Vector{Int}
59+
meanwout::Vector{Float64}
60+
left::Bool
61+
closed::Bool
62+
end
63+
struct WeightedHistND{R, N} <: WeightedHist{OnlineStats.VectorOb}
4364
edges::R
44-
counts::Vector{Float64}
45-
out::Vector{Float64}
65+
counts::Array{Int,N}
66+
meanw::Array{Float64,N}
67+
outcount::Array{Int,N}
68+
meanwout::Array{Float64,N}
4669
left::Bool
4770
closed::Bool
71+
end
4872

49-
function WeightedHist(edges::R, T::Type = eltype(edges); left::Bool=true, closed::Bool = true) where {R<:AbstractVector}
50-
new{T,R}(edges, zeros(Int, length(edges) - 1), [0,0], left, closed)
73+
function WeightedHist(edges; left::Bool=true, closed::Bool = true)
74+
edges = isa(edges,Tuple) ? edges : (edges,)
75+
counts = zeros(Int, map(i->length(i)-1, edges))
76+
meanw = zeros(Float64, map(i->length(i)-1, edges))
77+
outcount = zeros(Int,ntuple(_->3,length(edges)))
78+
meanwout = zeros(Float64,ntuple(_->3,length(edges)))
79+
if length(edges) == 1
80+
WeightedHist1D(edges[1],counts,meanw,outcount,meanwout,left,closed)
81+
else
82+
WeightedHistND{typeof(edges),length(edges)}(edges, counts, meanw,outcount,meanwout, left, closed)
5183
end
5284
end
53-
nobs(o::WeightedHist) = sum(o.counts) + sum(o.out)
54-
weightsum(o::WeightedHist) = nobs(o)
55-
value(o::WeightedHist) = (x=o.edges, y=o.counts)
56-
57-
midpoints(o::WeightedHist) = midpoints(o.edges)
85+
# Special case for 1D Histogram
86+
nobs(o::WeightedHist) = sum(o.counts) + sum(o.outcount)
87+
weightsum(o::WeightedHist) = LinearAlgebra.dot(o.counts, o.meanw) + LinearAlgebra.dot(o.outcount,o.meanwout)
88+
value(o::WeightedHist) = (x=edges(o), y=o.counts .* o.meanw)
89+
binindices(o::WeightedHistND{<:Any,N}, x::AbstractVector) where N = binindices(o, ntuple(i->x[i],N))
90+
binindices(o::WeightedHist1D,x) = OnlineStats.binindex(o.edges, x, o.left, o.closed)
91+
binindices(o::WeightedHistND,x) = CartesianIndex(map((e,ix)->OnlineStats.binindex(e, ix, o.left, o.closed), o.edges, x))
92+
midpoints(o::WeightedHistND) = Iterators.product(map(midpoints,o.edges)...)
93+
midpoints(o::WeightedHist1D) = midpoints(edges(o))
5894
counts(o::WeightedHist) = o.counts
5995
edges(o::WeightedHist) = o.edges
96+
function Statistics.mean(o::WeightedHist)
97+
weights = value(o).y
98+
N = ndims(o.counts)
99+
r = ntuple(N) do idim
100+
a = map(i->i[idim],midpoints(o))
101+
mean(a,fweights(weights))
102+
end
103+
N==1 ? r[1] : r
104+
end
105+
function Statistics.var(o::WeightedHist)
106+
weights = value(o).y
107+
N = ndims(o.counts)
108+
r = ntuple(N) do idim
109+
a = map(i->i[idim],midpoints(o))
110+
var(a,fweights(weights),corrected=true)
111+
end
112+
N==1 ? r[1] : r
113+
end
114+
Statistics.std(o::WeightedHist) = sqrt.(var(o))
115+
Statistics.median(o::WeightedHist) = quantile(o, .5)
60116

61-
function Base.extrema(o::WeightedHist)
117+
function Base.extrema(o::WeightedHist1D)
118+
x, y = midpoints(o), counts(o)
119+
x[findfirst(!iszero,y)],x[findlast(!iszero,y)]
120+
end
121+
function Base.extrema(o::WeightedHistND{<:Any,N}) where N
62122
x, y = midpoints(o), counts(o)
63-
x[findfirst(x -> x > 0, y)], x[findlast(x -> x > 0, y)]
123+
ntuple(N) do idim
124+
avalue = any(!iszero, y, dims = setdiff(1:N,idim))[:]
125+
x.iterators[idim][findfirst(avalue)],x.iterators[idim][findlast(avalue)]
126+
end
64127
end
128+
65129
function Statistics.quantile(o::WeightedHist, p = [0, .25, .5, .75, 1])
66130
x, y = midpoints(o), counts(o)
67-
inds = findall(x -> x != 0, y)
68-
quantile(x[inds], fweights(y[inds]), p)
131+
N = ndims(y)
132+
inds = findall(!iszero, y)
133+
yweights = fweights(y[inds])
134+
subset = collect(x)[inds]
135+
r = ntuple(N) do idim
136+
data = map(i->i[idim],subset)
137+
quantile(data, fweights(y[inds]), p)
138+
end
139+
if N==1
140+
return r[1]
141+
else
142+
return r
143+
end
69144
end
70145

71146
function area(o::WeightedHist)
72147
c = o.counts
73148
e = o.edges
74-
if isa(e, AbstractRange)
75-
return step(e) * sum(c)
76-
else
77-
return sum((e[i+1] - e[i]) * c[i] for i in 1:length(c))
149+
return mapreduce(+, CartesianIndices(c)) do I
150+
ar = prod(map((ed,i)->ed[i+1]-ed[i],e,I.I))
151+
c[I]*ar
78152
end
79153
end
80154

155+
outindex(o, ci::CartesianIndex) = CartesianIndex(map((i,l)->i < 1 ? 1 : i > l ? 3 : 2, ci.I, size(o.counts)))
156+
outindex(o, ci::Int) = CartesianIndex(ci < 1 ? 1 : ci > length(o.counts) ? 3 : 2)
81157
function pdf(o::WeightedHist, y)
82-
i = OnlineStats.binindex(o.edges, y, o.left, o.closed)
83-
if i < 1 || i > length(o.counts)
84-
return 0.0
158+
ci = binindices(o, y)
159+
if all(isequal(2),outindex(o,ci).I)
160+
return o.counts[ci]*o.meanw[ci] / area(o) / weightsum(o)
85161
else
86-
return o.counts[i] / area(o)
162+
return 0.0
87163
end
88164
end
89165

90166
function _fit!(o::WeightedHist, x, wt)
91-
i = OnlineStats.binindex(o.edges, x, o.left, o.closed)
92-
if 1 ≤ i < length(o.edges)
93-
o.counts[i] += wt
167+
#length(x) == N || error("You must provide $(N) values for the histogram")
168+
ci = binindices(o, x)
169+
oi = outindex(o,ci)
170+
if all(isequal(2),oi.I)
171+
o.counts[ci] += 1
172+
o.meanw[ci] = smooth(o.meanw[ci], wt, 1.0 / o.counts[ci])
94173
else
95-
o.out[1 + (i > 0)] += wt
174+
o.outcount[oi] += 1
175+
o.meanwout[oi] = smooth(o.meanwout[oi], wt, 1.0 / o.outcount[oi])
96176
end
97177
end
98178

99179
function _merge!(o::WeightedHist, o2::WeightedHist)
100180
if o.edges == o2.edges
101181
for j in eachindex(o.counts)
102-
o.counts[j] += o2.counts[j]
182+
newcount = o.counts[j] + o2.counts[j]
183+
if newcount > 0
184+
o.meanw[j] = (o.meanw[j]*o.counts[j] + o2.meanw[j]*o2.counts[j])/newcount
185+
end
186+
o.counts[j] = newcount
187+
end
188+
for j in eachindex(o.outcount)
189+
newcount = o.outcount[j] + o2.outcount[j]
190+
if newcount > 0
191+
o.meanwout[j] = (o.meanwout[j]*o.outcount[j] + o2.meanwout[j]*o2.outcount[j])/newcount
192+
end
193+
o.outcount[j] = newcount
103194
end
104195
else
105196
@warn("WeightedHistogram edges do not align. Merging is approximate.")

test/test_hist.jl

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -86,23 +86,28 @@ end
8686
h = WeightedHist(-3:1:1)
8787

8888
fit!(h, -2.5,1.3)
89-
@test h.counts == [1.3, 0,0,0]
89+
@test h.counts == [1,0,0,0]
90+
@test h.meanw == [1.3,0,0,0]
9091

9192
fit!(h, (-2.1, 1.0))
92-
@test h.counts == [2.3, 0,0,0]
93+
@test value(h).y == [2.3, 0,0,0]
94+
@test h.counts == [2,0,0,0]
95+
@test h.meanw == [1.15,0,0,0]
9396

9497
fit!(h, (-20, 2.0))
95-
@test h.counts == [2.3, 0,0,0]
96-
@test h.out == [2.0, 0]
98+
@test value(h).y == [2.3, 0,0,0]
99+
@test h.outcount == [1, 0, 0]
100+
@test h.meanwout == [2.0,0,0]
97101

98102
fit!(h, 20, 1.7)
99-
@test h.counts == [2.3, 0,0,0]
100-
@test h.out == [2.0, 1.7]
103+
@test value(h).y == [2.3, 0,0,0]
104+
@test h.meanwout == [2.0, 0.0, 1.7]
105+
@test h.outcount == [1, 0, 1]
106+
101107

102108
fit!(h, -0.1, 1.1)
103-
@test h.counts == [2.3, 0,1.1,0]
109+
@test value(h).y == [2.3, 0,1.1,0]
104110
@test h.edges === -3:1:1
105-
@test h.out == [2.0, 1.7]
106111
end
107112

108113
@testset "merge!" begin
@@ -111,7 +116,7 @@ end
111116
fit!(h1, (1, 10))
112117
h1_copy = deepcopy(h1)
113118
@test merge!(h1, WeightedHist([-2,0,2])) == h1_copy
114-
@test merge!(h1_copy, h1).counts == [10, 20.]
119+
@test value(merge!(h1_copy, h1)).y == [10, 20.]
115120
end
116121

117122
@testset "stats" begin
@@ -122,12 +127,46 @@ end
122127
fit!(h, x, 1.0)
123128
fit!(ho, x)
124129
end
125-
126130
@test mean(h) ≈ mean(ho)
127131
@test std(h) ≈ std(ho)
128132
@test median(h) ≈ median(ho)
129133
@test nobs(h) ≈ nobs(ho)
130134
@test var(h) ≈ var(ho)
135+
@test all(extrema(h) .≈ extrema(ho))
136+
end
137+
138+
@testset "N-dimensional Hist" begin
139+
h = WeightedHist((-2:2:2,0:3:6))
140+
fit!(h,(-1.5,1.5),1.5)
141+
@test h.counts == [1 0; 0 0]
142+
@test h.meanw == [1.5 0; 0 0]
143+
144+
fit!(h,(-0.5,0.2),1.1)
145+
@test h.counts == [2 0; 0 0]
146+
@test h.meanw == [1.3 0;0 0]
147+
148+
fit!(h,(-3.0,0.0),1.5)
149+
@test h.counts == [2 0; 0 0]
150+
@test h.meanw == [1.3 0;0 0]
151+
@test h.outcount == [0 1 0; 0 0 0; 0 0 0]
152+
153+
fit!(h,(-10,-10),1.1)
154+
@test h.counts == [2 0; 0 0]
155+
@test h.meanw == [1.3 0;0 0]
156+
@test h.outcount == [1 1 0; 0 0 0; 0 0 0]
157+
158+
fit!(h,(1.5,4.5),2.6)
159+
@test h.counts == [2 0; 0 1]
160+
@test h.meanw == [1.3 0; 0 2.6]
161+
@test value(h) == (x=(-2:2:2, 0:3:6),y=[2.6 0;0 2.6])
162+
163+
@test mean(h) == (0.0,3.0)
164+
@test var(h) == (1.2380952380952381, 2.785714285714286)
165+
@test std(h) == (1.1126972805283737, 1.6690459207925605)
166+
@test median(h) == (-1.0, 1.5)
167+
@test nobs(h) == 5
168+
@test extrema(h) == ((-1,1),(1.5,4.5))
131169
end
132170

171+
133172
end

0 commit comments

Comments
 (0)