|
| 1 | +using ImplicitAD |
| 2 | +using ForwardDiff |
| 3 | +using ReverseDiff |
| 4 | +using BenchmarkTools |
| 5 | + |
| 6 | +struct Tsit5{TV, TF} |
| 7 | + k1::TV |
| 8 | + k2::TV |
| 9 | + k3::TV |
| 10 | + k4::TV |
| 11 | + k5::TV |
| 12 | + k6::TV |
| 13 | + ytemp2::TV |
| 14 | + ytemp3::TV |
| 15 | + ytemp4::TV |
| 16 | + ytemp5::TV |
| 17 | + ytemp6::TV |
| 18 | + c1::TF |
| 19 | + c2::TF |
| 20 | + c3::TF |
| 21 | + c4::TF |
| 22 | + a21::TF |
| 23 | + a31::TF |
| 24 | + a32::TF |
| 25 | + a41::TF |
| 26 | + a42::TF |
| 27 | + a43::TF |
| 28 | + a51::TF |
| 29 | + a52::TF |
| 30 | + a53::TF |
| 31 | + a54::TF |
| 32 | + a61::TF |
| 33 | + a62::TF |
| 34 | + a63::TF |
| 35 | + a64::TF |
| 36 | + a65::TF |
| 37 | + a71::TF |
| 38 | + a72::TF |
| 39 | + a73::TF |
| 40 | + a74::TF |
| 41 | + a75::TF |
| 42 | + a76::TF |
| 43 | +end |
| 44 | + |
| 45 | +function Tsit5(ny, T) |
| 46 | + |
| 47 | + k1 = Vector{T}(undef, ny) |
| 48 | + k2 = Vector{T}(undef, ny) |
| 49 | + k3 = Vector{T}(undef, ny) |
| 50 | + k4 = Vector{T}(undef, ny) |
| 51 | + k5 = Vector{T}(undef, ny) |
| 52 | + k6 = Vector{T}(undef, ny) |
| 53 | + ytemp2 = Vector{T}(undef, ny) |
| 54 | + ytemp3 = Vector{T}(undef, ny) |
| 55 | + ytemp4 = Vector{T}(undef, ny) |
| 56 | + ytemp5 = Vector{T}(undef, ny) |
| 57 | + ytemp6 = Vector{T}(undef, ny) |
| 58 | + |
| 59 | + # constants |
| 60 | + c1 = 0.161; c2 = 0.327; c3 = 0.9; c4 = 0.9800255409045097; |
| 61 | + a21 = 0.161; |
| 62 | + a31 = -0.008480655492356989; a32 = 0.335480655492357; |
| 63 | + a41 = 2.8971530571054935; a42 = -6.359448489975075; a43 = 4.3622954328695815; |
| 64 | + a51 = 5.325864828439257; a52 = -11.748883564062828; a53 = 7.4955393428898365; a54 = -0.09249506636175525; |
| 65 | + a61 = 5.86145544294642; a62 = -12.92096931784711; a63 = 8.159367898576159; a64 = -0.071584973281401; a65 = -0.028269050394068383; |
| 66 | + a71 = 0.09646076681806523; a72 = 0.01; a73 = 0.4798896504144996; a74 = 1.379008574103742; a75 = -3.290069515436081; a76 = 2.324710524099774 |
| 67 | + |
| 68 | + return Tsit5(k1, k2, k3, k4, k5, k6, ytemp2, ytemp3, ytemp4, ytemp5, ytemp6, c1, c2, c3, c4, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76) |
| 69 | +end |
| 70 | + |
| 71 | +function odestep!(tsit::Tsit5, odefun, y, yprev, t, tprev, xd, xci, p) |
| 72 | + |
| 73 | + dt = t - tprev |
| 74 | + |
| 75 | + odefun(tsit.k1, yprev, t, xd, xci, p) |
| 76 | + @. tsit.ytemp2 = yprev + dt*tsit.a21*tsit.k1 |
| 77 | + odefun(tsit.k2, tsit.ytemp2, t+tsit.c1*dt, xd, xci, p) |
| 78 | + @. tsit.ytemp3 = yprev + dt*(tsit.a31*tsit.k1+tsit.a32*tsit.k2) |
| 79 | + odefun(tsit.k3, tsit.ytemp3, t+tsit.c2*dt, xd, xci, p) |
| 80 | + @. tsit.ytemp4 = yprev + dt*(tsit.a41*tsit.k1+tsit.a42*tsit.k2+tsit.a43*tsit.k3) |
| 81 | + odefun(tsit.k4, tsit.ytemp4, t+tsit.c3*dt, xd, xci, p) |
| 82 | + @. tsit.ytemp5 = yprev + dt*(tsit.a51*tsit.k1+tsit.a52*tsit.k2+tsit.a53*tsit.k3+tsit.a54*tsit.k4) |
| 83 | + odefun(tsit.k5, tsit.ytemp5, t+tsit.c4*dt, xd, xci, p) |
| 84 | + @. tsit.ytemp6 = yprev + dt*(tsit.a61*tsit.k1+tsit.a62*tsit.k2+tsit.a63*tsit.k3+tsit.a64*tsit.k4+tsit.a65*tsit.k5) |
| 85 | + odefun(tsit.k6, tsit.ytemp6, t+dt, xd, xci, p) |
| 86 | + |
| 87 | + @. y = yprev + dt*(tsit.a71*tsit.k1+tsit.a72*tsit.k2+tsit.a73*tsit.k3+tsit.a74*tsit.k4+tsit.a75*tsit.k5+tsit.a76*tsit.k6) |
| 88 | + |
| 89 | + return y |
| 90 | +end |
| 91 | + |
| 92 | + |
| 93 | +function Q(T, p) |
| 94 | + (; hc, Ta, ϵ, σ) = p |
| 95 | + Qc = hc*(T - Ta) |
| 96 | + Qr = ϵ*σ*(T^4 - Ta^4) |
| 97 | + return Qc + Qr |
| 98 | +end |
| 99 | + |
| 100 | +function plate(dy, y, t, xd, xci, p) |
| 101 | + |
| 102 | + (; k, rho, Cp, δ, tz, n, T, dT) = p |
| 103 | + alpha = k/(rho*Cp*δ^2) |
| 104 | + beta = 2/(rho*Cp*tz) |
| 105 | + |
| 106 | + # set temperature grid |
| 107 | + T[2:n-1, 2:n-1] .= reshape(y, n-2, n-2) |
| 108 | + |
| 109 | + # Direchlet b.c. on bottom |
| 110 | + T[end, :] .= xci |
| 111 | + |
| 112 | + # Neuman b.c. on sides and top |
| 113 | + @views for i = 2:n-1 |
| 114 | + T[i, 1] = T[i, 2] # left |
| 115 | + T[i, end] = T[i, end-1] # right |
| 116 | + end |
| 117 | + @views for j = 1:n |
| 118 | + T[1, j] = T[2, j] # top |
| 119 | + end |
| 120 | + |
| 121 | + # update interior points |
| 122 | + @views for i = 2:n-1 |
| 123 | + for j = 2:n-1 |
| 124 | + Tij = T[i, j] |
| 125 | + dT[i, j] = alpha*(T[i+1, j] + T[i-1, j] + T[i, j+1] + T[i, j-1] - 4*Tij) - beta*(Q(Tij, p)) |
| 126 | + end |
| 127 | + end |
| 128 | + |
| 129 | + dy .= dT[2:n-1, 2:n-1][:] |
| 130 | +end |
| 131 | + |
| 132 | + |
| 133 | +function initialize(t0, xd, xc0, p) |
| 134 | + (; n) = p |
| 135 | + return 300*ones((n-2)*(n-2)) |
| 136 | +end |
| 137 | + |
| 138 | +function runit(n, nt) |
| 139 | + |
| 140 | + tsit = Tsit5((n-2)*(n-2), Float64) |
| 141 | + onestep!(y, yprev, t, tprev, xd, xci, p) = odestep!(tsit, plate, y, yprev, t, tprev, xd, xci, p) |
| 142 | + |
| 143 | + # problem constants |
| 144 | + p = (k=400.0, rho=8960.0, Cp=386.0, tz=.01, σ=5.670373e-8, hc=1.0, Ta=300.0, ϵ=0.5, T=zeros(n, n), dT=zeros(n, n), n=n, δ=1/(n-1)) |
| 145 | + |
| 146 | + t = range(0.0, 5000, nt) |
| 147 | + nt = length(t) |
| 148 | + |
| 149 | + xc = reverse(range(600, 1000, n))*ones(nt)' |
| 150 | + xd = Float64[] |
| 151 | + |
| 152 | + function program(xc) |
| 153 | + xcm = reshape(xc, n, nt) |
| 154 | + |
| 155 | + TF = eltype(xc) |
| 156 | + if eltype(tsit.k1) != TF |
| 157 | + tsit = Tsit5((n-2)*(n-2), TF) |
| 158 | + pf = (k=400.0, rho=8960.0, Cp=386.0, tz=.01, σ=5.670373e-8, hc=1.0, Ta=300.0, ϵ=0.5, T=zeros(TF, n, n), dT=zeros(TF, n, n), n=n, δ=1/(n-1)) |
| 159 | + onestep!(y, yprev, t, tprev, xd, xci, p) = odestep!(tsit, plate, y, yprev, t, tprev, xd, xci, pf) |
| 160 | + end |
| 161 | + |
| 162 | + y = ImplicitAD.odesolve(initialize, onestep!, t, xd, xcm, p) |
| 163 | + |
| 164 | + return y[1, end] # top left corner temperature at last time |
| 165 | + end |
| 166 | + |
| 167 | + # ------ implicitAD ---------- |
| 168 | + |
| 169 | + TR = eltype(ReverseDiff.track([1.0])) |
| 170 | + tsitf = Tsit5((n-2)*(n-2), Float64) |
| 171 | + tsitr = Tsit5((n-2)*(n-2), TR) |
| 172 | + pr = (k=400.0, rho=8960.0, Cp=386.0, tz=.01, σ=5.670373e-8, hc=1.0, Ta=300.0, ϵ=0.5, T=zeros(TR, n, n), dT=zeros(TR, n, n), n=n, δ=1/(n-1)) |
| 173 | + onestepf!(y, yprev, t, tprev, xd, xci, p) = odestep!(tsitf, plate, y, yprev, t, tprev, xd, xci, p) |
| 174 | + onestepr!(y, yprev, t, tprev, xd, xci, p) = odestep!(tsitr, plate, y, yprev, t, tprev, xd, xci, pr) |
| 175 | + cache = ImplicitAD.explicit_unsteady_cache(initialize, onestepr!, (n-2)*(n-2), 0, n, p; compile=true) |
| 176 | + |
| 177 | + function modprogram(xc) |
| 178 | + xcm = reshape(xc, n, nt) |
| 179 | + |
| 180 | + y = explicit_unsteady(initialize, onestepf!, t, xd, xcm, p; cache) |
| 181 | + |
| 182 | + return y[1, end] |
| 183 | + end |
| 184 | + |
| 185 | + xcv = xc[:] |
| 186 | + |
| 187 | + fwd_cache = ForwardDiff.GradientConfig(program, xcv) |
| 188 | + f_tape1 = ReverseDiff.GradientTape(modprogram, xcv) |
| 189 | + rev_cache1 = ReverseDiff.compile(f_tape1) |
| 190 | + f_tape2 = ReverseDiff.GradientTape(program, xcv) |
| 191 | + rev_cache2 = ReverseDiff.compile(f_tape2) |
| 192 | + |
| 193 | + g1 = zeros(length(xcv)) |
| 194 | + g2 = zeros(length(xcv)) |
| 195 | + g3 = zeros(length(xcv)) |
| 196 | + |
| 197 | + |
| 198 | + # approximate cost of central diff |
| 199 | + t1 = @benchmark $program($xcv) |
| 200 | + time1 = median(t1).time * 1e-9 * (2*length(xcv) + 1) |
| 201 | + |
| 202 | + # forward |
| 203 | + t2 = @benchmark ForwardDiff.gradient!($g1, $program, $xcv, $fwd_cache) |
| 204 | + time2 = median(t2).time * 1e-9 |
| 205 | + |
| 206 | + #reverse diff |
| 207 | + t3 = @benchmark ReverseDiff.gradient!($g2, $rev_cache2, $xcv) |
| 208 | + time3 = median(t3).time * 1e-9 |
| 209 | + |
| 210 | + # reverse w/ implicitad |
| 211 | + t4 = @benchmark ReverseDiff.gradient!($g3, $rev_cache1, $xcv) |
| 212 | + time4 = median(t4).time * 1e-9 # reverse implicit diff |
| 213 | + |
| 214 | + println(time1, " ", time2, " ", time3, " ", time4) |
| 215 | + |
| 216 | + return time1, time2, time3, time4 |
| 217 | + |
| 218 | +end |
| 219 | + |
| 220 | +nt = 1000 |
| 221 | +nvec = [3, 5, 7, 9, 11, 13, 15, 17, 19] |
| 222 | +# nvec = [19] |
| 223 | +nn = length(nvec) |
| 224 | +t1 = zeros(nn) |
| 225 | +t2 = zeros(nn) |
| 226 | +t3 = zeros(nn) |
| 227 | +t4 = zeros(nn) |
| 228 | +states = zeros(nn) |
| 229 | + |
| 230 | +for i = 1:nn |
| 231 | + n = nvec[i] |
| 232 | + t1[i], t2[i], t3[i], t4[i] = runit(n, nt) |
| 233 | + states[i] = (n-2)^2 |
| 234 | +end |
0 commit comments