Skip to content

Commit e67909f

Browse files
committed
examples
1 parent d783d80 commit e67909f

File tree

4 files changed

+538
-0
lines changed

4 files changed

+538
-0
lines changed

examples/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Examples
2+
3+
This folder contains the examples from the paper: <https://doi.org/10.48550/arXiv.2306.15243>. There are three scripts: nonlinear.jl, explicit.jl, and implicit.jl. All are written with for loops that solve problems of increasing size, but in practice I often ran them one size at a time, as I found that produced more consistent timings. The last case (implicit.jl) requires a patch to NLsolve described here: <https://github.com/JuliaNLSolvers/NLsolve.jl/issues/281>. The new methodology, with ImplicitAD, will work fine without it since it doesn't propagate through the solver. But if you want to compare against using reverse mode directly through NLsolve, the patch is required.
4+

examples/explicit.jl

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
using ImplicitAD
2+
using ForwardDiff
3+
using ReverseDiff
4+
using BenchmarkTools
5+
6+
"""
    Tsit5{TV, TF}

Storage for one explicit Runge-Kutta step as performed by `odestep!`.

Fields of type `TV` are work arrays; fields of type `TF` are scalar tableau
coefficients.

- `k1`..`k6`: stage derivative buffers filled by the ODE function.
- `ytemp2`..`ytemp6`: intermediate states at which stages 2-6 are evaluated.
- `c1`..`c4`: fractions of the step size used to offset the stage times.
- `a21`..`a65`: weights combining earlier stage derivatives into each
  intermediate state.
- `a71`..`a76`: weights combining all six stage derivatives into the new state.
"""
struct Tsit5{TV, TF}
    # stage derivative buffers
    k1::TV; k2::TV; k3::TV; k4::TV; k5::TV; k6::TV
    # intermediate stage states
    ytemp2::TV; ytemp3::TV; ytemp4::TV; ytemp5::TV; ytemp6::TV
    # stage time fractions
    c1::TF; c2::TF; c3::TF; c4::TF
    # stage weights, one tableau row per line
    a21::TF
    a31::TF; a32::TF
    a41::TF; a42::TF; a43::TF
    a51::TF; a52::TF; a53::TF; a54::TF
    a61::TF; a62::TF; a63::TF; a64::TF; a65::TF
    # weights of the final update
    a71::TF; a72::TF; a73::TF; a74::TF; a75::TF; a76::TF
end
44+
45+
"""
    Tsit5(ny, T)

Build a `Tsit5` workspace for a system with `ny` states.

All eleven work arrays are uninitialized `Vector{T}` of length `ny`; the
tableau coefficients are stored as `Float64` literals regardless of `T`.
"""
function Tsit5(ny, T)

    # every work array is a separate uninitialized buffer of the requested element type
    newbuffer() = Vector{T}(undef, ny)
    stagerates = ntuple(_ -> newbuffer(), 6)   # k1..k6
    stagestates = ntuple(_ -> newbuffer(), 5)  # ytemp2..ytemp6

    # constants, grouped by tableau row
    nodes = (0.161, 0.327, 0.9, 0.9800255409045097)  # c1..c4
    row2 = (0.161,)
    row3 = (-0.008480655492356989, 0.335480655492357)
    row4 = (2.8971530571054935, -6.359448489975075, 4.3622954328695815)
    row5 = (5.325864828439257, -11.748883564062828, 7.4955393428898365, -0.09249506636175525)
    row6 = (5.86145544294642, -12.92096931784711, 8.159367898576159, -0.071584973281401, -0.028269050394068383)
    row7 = (0.09646076681806523, 0.01, 0.4798896504144996, 1.379008574103742, -3.290069515436081, 2.324710524099774)

    # splat in field order: buffers, c's, then a21..a76
    return Tsit5(stagerates..., stagestates..., nodes..., row2..., row3..., row4..., row5..., row6..., row7...)
end
70+
71+
"""
    odestep!(tsit::Tsit5, odefun, y, yprev, t, tprev, xd, xci, p)

Advance the state from `yprev` at time `tprev` to time `t` with one explicit
six-stage Runge-Kutta step, using the coefficients and work arrays stored in
`tsit`. The result is written into `y`, which is also returned.

# Arguments
- `tsit`: coefficients and work arrays; its `k*` and `ytemp*` buffers are overwritten.
- `odefun`: in-place right-hand side called as `odefun(dy, y, t, xd, xci, p)`.
- `y`: output state (overwritten).
- `yprev`: state at the start of the step.
- `t`, `tprev`: end and start times of the step; the step size is `t - tprev`.
- `xd`, `xci`, `p`: passed through unchanged to every `odefun` call.

Stage times are measured from the start of the step, i.e. stage `i` is
evaluated at `tprev + c_i*dt`, with the final stage at `t`.
"""
function odestep!(tsit::Tsit5, odefun, y, yprev, t, tprev, xd, xci, p)

    dt = t - tprev

    # The step covers [tprev, t], so all stage times are offsets from tprev.
    odefun(tsit.k1, yprev, tprev, xd, xci, p)
    @. tsit.ytemp2 = yprev + dt*tsit.a21*tsit.k1
    odefun(tsit.k2, tsit.ytemp2, tprev + tsit.c1*dt, xd, xci, p)
    @. tsit.ytemp3 = yprev + dt*(tsit.a31*tsit.k1+tsit.a32*tsit.k2)
    odefun(tsit.k3, tsit.ytemp3, tprev + tsit.c2*dt, xd, xci, p)
    @. tsit.ytemp4 = yprev + dt*(tsit.a41*tsit.k1+tsit.a42*tsit.k2+tsit.a43*tsit.k3)
    odefun(tsit.k4, tsit.ytemp4, tprev + tsit.c3*dt, xd, xci, p)
    @. tsit.ytemp5 = yprev + dt*(tsit.a51*tsit.k1+tsit.a52*tsit.k2+tsit.a53*tsit.k3+tsit.a54*tsit.k4)
    odefun(tsit.k5, tsit.ytemp5, tprev + tsit.c4*dt, xd, xci, p)
    @. tsit.ytemp6 = yprev + dt*(tsit.a61*tsit.k1+tsit.a62*tsit.k2+tsit.a63*tsit.k3+tsit.a64*tsit.k4+tsit.a65*tsit.k5)
    # last stage sits at the end of the step (tprev + dt == t)
    odefun(tsit.k6, tsit.ytemp6, t, xd, xci, p)

    # combine all stage derivatives into the new state
    @. y = yprev + dt*(tsit.a71*tsit.k1+tsit.a72*tsit.k2+tsit.a73*tsit.k3+tsit.a74*tsit.k4+tsit.a75*tsit.k5+tsit.a76*tsit.k6)

    return y
end
91+
92+
93+
"""
    Q(T, p)

Return the heat-loss term at temperature `T`: the sum of a term linear in the
temperature difference, `hc*(T - Ta)`, and a fourth-power term,
`ϵ*σ*(T^4 - Ta^4)`. The parameters `hc`, `Ta`, `ϵ`, and `σ` are read from `p`.
"""
function Q(T, p)
    convective = p.hc*(T - p.Ta)
    radiative = p.ϵ*p.σ*(T^4 - p.Ta^4)
    return convective + radiative
end
99+
100+
"""
    plate(dy, y, t, xd, xci, p)

In-place right-hand side for the plate temperature model on an `n`×`n` grid.

`y` holds the `(n-2)*(n-2)` interior temperatures in column-major order. They
are copied into the work grid `p.T`, boundary values are applied, and the time
derivative of every interior node is stored in `p.dT` and then copied to `dy`.

# Arguments
- `dy`: output vector of length `(n-2)*(n-2)` (overwritten and returned).
- `y`: interior temperatures.
- `t`, `xd`: not used by this model.
- `xci`: values assigned to the bottom row of the grid (`T[end, :]`).
- `p`: parameters `k`, `rho`, `Cp`, `δ`, `tz`, `n`, the work arrays `T` and
  `dT` (both overwritten), plus whatever `Q` reads.
"""
function plate(dy, y, t, xd, xci, p)

    (; k, rho, Cp, δ, tz, n, T, dT) = p
    alpha = k/(rho*Cp*δ^2)
    beta = 2/(rho*Cp*tz)
    interior = 2:n-1

    # set temperature grid
    T[interior, interior] .= reshape(y, n-2, n-2)

    # Dirichlet b.c. on bottom
    T[end, :] .= xci

    # Neumann b.c. on sides and top (copy the neighbouring interior value)
    for i in interior
        T[i, 1] = T[i, 2]  # left
        T[i, end] = T[i, end-1]  # right
    end
    # the top row is filled after the sides so its corners pick up the side values
    for j in 1:n
        T[1, j] = T[2, j]  # top
    end

    # update interior points: five-point stencil minus the heat-loss term
    for i in interior
        for j in interior
            Tij = T[i, j]
            dT[i, j] = alpha*(T[i+1, j] + T[i-1, j] + T[i, j+1] + T[i, j-1] - 4*Tij) - beta*(Q(Tij, p))
        end
    end

    # flatten the interior without allocating temporary copies
    dy .= vec(view(dT, interior, interior))
    return dy
end
131+
132+
133+
"""
    initialize(t0, xd, xc0, p)

Return the initial state: a `Float64` vector with one entry per interior grid
node, `(p.n - 2)*(p.n - 2)` in total, every entry equal to 300. The arguments
`t0`, `xd`, and `xc0` are not used.
"""
function initialize(t0, xd, xc0, p)
    nstates = (p.n - 2)*(p.n - 2)
    return fill(300.0, nstates)
end
137+
138+
"""
    runit(n, nt)

Time four ways of obtaining the gradient of the plate simulation output with
respect to the boundary inputs, for an `n`×`n` grid and `nt` time points.

Prints and returns `(time1, time2, time3, time4)` in seconds (medians from
`@benchmark`):

- `time1`: approximate central-difference cost, i.e. the median time of one
  primal evaluation multiplied by `2*length(xcv) + 1`.
- `time2`: `ForwardDiff.gradient!` through `program`.
- `time3`: compiled `ReverseDiff` tape recorded directly through `program`.
- `time4`: compiled `ReverseDiff` tape through `modprogram`, which uses
  `explicit_unsteady` with a precomputed cache.
"""
function runit(n, nt)

    ns = (n-2)*(n-2)  # number of interior states

    # problem constants; the work grids T/dT are allocated with element type TT
    plateparams(TT) = (k=400.0, rho=8960.0, Cp=386.0, tz=.01, σ=5.670373e-8, hc=1.0, Ta=300.0, ϵ=0.5, T=zeros(TT, n, n), dT=zeros(TT, n, n), n=n, δ=1/(n-1))

    p = plateparams(Float64)

    t = range(0.0, 5000, nt)
    # separate name: reassigning `nt` would box it in the closures below
    nsteps = length(t)

    xc = reverse(range(600, 1000, n))*ones(nsteps)'
    xd = Float64[]

    # Integrator workspace and parameters matching the element type most recently
    # seen by `program`. Kept in a Ref so it can be swapped when the element type
    # changes (Float64, dual numbers, tracked reals) without reassigning captured
    # variables or redefining a local function.
    workspace = Ref{Any}((Tsit5(ns, Float64), p))

    # return a workspace whose buffers have element type TF, rebuilding it only when needed
    function workspacefor(TF)
        integrator, params = workspace[]
        if eltype(integrator.k1) != TF
            integrator = Tsit5(ns, TF)
            params = plateparams(TF)
            workspace[] = (integrator, params)
        end
        return integrator, params
    end

    function program(xc)
        xcm = reshape(xc, n, nsteps)

        integrator, params = workspacefor(eltype(xc))
        # the parameter argument supplied by the solver is ignored in favour of
        # `params`, whose work grids match the element type of `xc`
        onestep!(y, yprev, t, tprev, xd, xci, _) = odestep!(integrator, plate, y, yprev, t, tprev, xd, xci, params)

        y = ImplicitAD.odesolve(initialize, onestep!, t, xd, xcm, p)

        return y[1, end]  # first interior state (top left corner temperature) at last time
    end

    # ------ implicitAD ----------

    TR = eltype(ReverseDiff.track([1.0]))
    tsitf = Tsit5(ns, Float64)
    tsitr = Tsit5(ns, TR)
    pr = plateparams(TR)
    onestepf!(y, yprev, t, tprev, xd, xci, p) = odestep!(tsitf, plate, y, yprev, t, tprev, xd, xci, p)
    onestepr!(y, yprev, t, tprev, xd, xci, p) = odestep!(tsitr, plate, y, yprev, t, tprev, xd, xci, pr)
    cache = ImplicitAD.explicit_unsteady_cache(initialize, onestepr!, ns, 0, n, p; compile=true)

    function modprogram(xc)
        xcm = reshape(xc, n, nsteps)

        y = explicit_unsteady(initialize, onestepf!, t, xd, xcm, p; cache)

        return y[1, end]
    end

    xcv = xc[:]

    fwd_cache = ForwardDiff.GradientConfig(program, xcv)
    f_tape1 = ReverseDiff.GradientTape(modprogram, xcv)
    rev_cache1 = ReverseDiff.compile(f_tape1)
    f_tape2 = ReverseDiff.GradientTape(program, xcv)
    rev_cache2 = ReverseDiff.compile(f_tape2)

    g1 = zeros(length(xcv))
    g2 = zeros(length(xcv))
    g3 = zeros(length(xcv))

    # approximate cost of central diff
    t1 = @benchmark $program($xcv)
    time1 = median(t1).time * 1e-9 * (2*length(xcv) + 1)

    # forward
    t2 = @benchmark ForwardDiff.gradient!($g1, $program, $xcv, $fwd_cache)
    time2 = median(t2).time * 1e-9

    # reverse diff
    t3 = @benchmark ReverseDiff.gradient!($g2, $rev_cache2, $xcv)
    time3 = median(t3).time * 1e-9

    # reverse w/ implicitad
    t4 = @benchmark ReverseDiff.gradient!($g3, $rev_cache1, $xcv)
    time4 = median(t4).time * 1e-9  # reverse implicit diff

    println(time1, " ", time2, " ", time3, " ", time4)

    return time1, time2, time3, time4

end
219+
220+
# Driver: run the benchmark for each grid size in `nvec` and record the four
# timings together with the number of states, (n-2)^2.
nt = 1000
nvec = [3, 5, 7, 9, 11, 13, 15, 17, 19]
# nvec = [19]  # alternative: time a single grid size
nn = length(nvec)
t1, t2, t3, t4 = (zeros(nn) for _ in 1:4)
states = zeros(nn)

for (i, n) in enumerate(nvec)
    timings = runit(n, nt)
    t1[i] = timings[1]
    t2[i] = timings[2]
    t3[i] = timings[3]
    t4[i] = timings[4]
    states[i] = (n-2)^2
end

0 commit comments

Comments
 (0)