Skip to content

Commit a7b8395

Browse files
Merge pull request #826 from ayushinav/dgm
reducing test times for DGM
2 parents 052e610 + 89a340b commit a7b8395

File tree

2 files changed

+33
-26
lines changed

2 files changed

+33
-26
lines changed

docs/src/tutorials/dgm.md

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@ $$
3232
$$
3333

3434
defined over
35-
$$ t \in [0, 1], x \in [-1, 1] $$
35+
36+
$$
37+
t \in [0, 1], x \in [-1, 1]
38+
$$
3639

3740
with boundary conditions
3841
```math
@@ -44,7 +47,7 @@ u(t, 1) & = 0
4447
```
4548

4649
### Copy- Pasteable code
47-
```julia
50+
```@example dgm
4851
using NeuralPDE
4952
using ModelingToolkit, Optimization, OptimizationOptimisers
5053
import Lux: tanh, identity
@@ -55,15 +58,15 @@ using MethodOfLines, OrdinaryDiffEq
5558
@parameters x t
5659
@variables u(..)
5760
58-
Dt= Differential(t)
59-
Dx= Differential(x)
60-
Dxx= Dx^2
61-
α = 0.05;
61+
Dt = Differential(t)
62+
Dx = Differential(x)
63+
Dxx = Dx^2
64+
α = 0.05
6265
# Burger's equation
63-
eq= Dt(u(t,x)) + u(t,x) * Dx(u(t,x)) - α * Dxx(u(t,x)) ~ 0
66+
eq = Dt(u(t,x)) + u(t,x) * Dx(u(t,x)) - α * Dxx(u(t,x)) ~ 0
6467
6568
# boundary conditions
66-
bcs= [
69+
bcs = [
6770
u(0.0, x) ~ - sin(π*x),
6871
u(t, -1.0) ~ 0.0,
6972
u(t, 1.0) ~ 0.0
@@ -72,7 +75,7 @@ bcs= [
7275
domains = [t ∈ Interval(0.0, 1.0), x ∈ Interval(-1.0, 1.0)]
7376
7477
# MethodOfLines, for FD solution
75-
dx= 0.01
78+
dx = 0.01
7679
order = 2
7780
discretization = MOLFiniteDifference([x => dx], t, saveat = 0.01)
7881
@named pde_system = PDESystem(eq, bcs, domains, [t, x], [u(t,x)])
@@ -84,25 +87,27 @@ xs = sol[x]
8487
u_MOL = sol[u(t,x)]
8588
8689
# NeuralPDE, using Deep Galerkin Method
87-
strategy = QuasiRandomTraining(4_000, minibatch= 500);
88-
discretization= DeepGalerkin(2, 1, 50, 5, tanh, tanh, identity, strategy);
89-
@named pde_system = PDESystem(eq, bcs, domains, [t, x], [u(t,x)]);
90-
prob = discretize(pde_system, discretization);
91-
global iter = 0;
90+
strategy = QuasiRandomTraining(256, minibatch= 32)
91+
discretization = DeepGalerkin(2, 1, 50, 5, tanh, tanh, identity, strategy)
92+
@named pde_system = PDESystem(eq, bcs, domains, [t, x], [u(t,x)])
93+
prob = discretize(pde_system, discretization)
94+
global iter = 0
9295
callback = function (p, l)
93-
global iter += 1;
96+
global iter += 1
9497
if iter%20 == 0
9598
println("$iter => $l")
9699
end
97100
return false
98101
end
99102
100-
res = Optimization.solve(prob, Adam(0.01); callback = callback, maxiters = 300);
101-
phi = discretization.phi;
103+
res = Optimization.solve(prob, Adam(0.1); callback = callback, maxiters = 100)
104+
prob = remake(prob, u0 = res.u)
105+
res = Optimization.solve(prob, Adam(0.01); callback = callback, maxiters = 500)
106+
phi = discretization.phi
102107
103108
u_predict= [first(phi([t, x], res.minimizer)) for t in ts, x in xs]
104109
105-
diff_u = abs.(u_predict .- u_MOL);
110+
diff_u = abs.(u_predict .- u_MOL)
106111
107112
using Plots
108113
p1 = plot(tgrid, xgrid, u_MOL', linetype = :contourf, title = "FD");

test/dgm_test.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ import Lux: tanh, identity
1919
# Space and time domains
2020
domains = [x Interval(0.0, 1.0), y Interval(0.0, 1.0)]
2121

22-
strategy = QuasiRandomTraining(4_000, minibatch= 500);
23-
discretization= DeepGalerkin(2, 1, 30, 3, tanh, tanh, identity, strategy);
22+
strategy = QuasiRandomTraining(256, minibatch= 32);
23+
discretization= DeepGalerkin(2, 1, 20, 3, tanh, tanh, identity, strategy);
2424

2525
@named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)])
2626
prob = discretize(pde_system, discretization)
@@ -71,8 +71,8 @@ end
7171

7272
domains = [t Interval(0.0, T), x Interval(0.0, S * S_multiplier)]
7373

74-
strategy = QuasiRandomTraining(4_000, minibatch= 500);
75-
discretization= DeepGalerkin(2, 1, 30, 3, tanh, tanh, identity, strategy);
74+
strategy = QuasiRandomTraining(128, minibatch= 32);
75+
discretization= DeepGalerkin(2, 1, 40, 3, tanh, tanh, identity, strategy);
7676

7777
@named pde_system = PDESystem(eq, bcs, domains, [t, x], [g(t,x)])
7878
prob = discretize(pde_system, discretization)
@@ -86,9 +86,9 @@ end
8686
return false
8787
end
8888

89-
res = Optimization.solve(prob, Adam(0.01); callback = callback, maxiters = 300)
89+
res = Optimization.solve(prob, Adam(0.1); callback = callback, maxiters = 100)
9090
prob = remake(prob, u0 = res.u)
91-
res = Optimization.solve(prob, Adam(0.001); callback = callback, maxiters = 300)
91+
res = Optimization.solve(prob, Adam(0.01); callback = callback, maxiters = 500)
9292
phi = discretization.phi
9393

9494
function analytical_soln(t, x, K, σ, T)
@@ -138,7 +138,7 @@ end
138138
u_MOL = sol[u(t,x)]
139139

140140
# NeuralPDE
141-
strategy = QuasiRandomTraining(4_000, minibatch= 500);
141+
strategy = QuasiRandomTraining(256, minibatch= 32);
142142
discretization= DeepGalerkin(2, 1, 50, 5, tanh, tanh, identity, strategy);
143143
@named pde_system = PDESystem(eq, bcs, domains, [t, x], [u(t,x)]);
144144
prob = discretize(pde_system, discretization);
@@ -151,7 +151,9 @@ end
151151
return false
152152
end
153153

154-
res = Optimization.solve(prob, Adam(0.01); callback = callback, maxiters = 300);
154+
res = Optimization.solve(prob, Adam(0.01); callback = callback, maxiters = 200);
155+
prob = remake(prob, u0 = res.u);
156+
res = Optimization.solve(prob, Adam(0.001); callback = callback, maxiters = 100);
155157
phi = discretization.phi;
156158

157159
u_predict= [first(phi([t, x], res.u)) for t in ts, x in xs]

0 commit comments

Comments
 (0)