
Commit a20efb6

Merge pull request #802 from ayushinav/master
add Deep Galerkin method
2 parents 389086f + c380453 commit a20efb6

File tree

8 files changed: +460 -2 lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -25,6 +25,8 @@ jobs:
           - AdaptiveLoss
           - Logging
           - Forward
+          - NeuralAdapter
+          - DGM
         version:
           - "1"
     steps:
```

Project.toml

Lines changed: 3 additions & 1 deletion
```diff
@@ -82,6 +82,7 @@ Symbolics = "5"
 Test = "1"
 UnPack = "1"
 Zygote = "0.6"
+MethodOfLines = "0.10.7"
 julia = "1.6"

 [extras]
@@ -95,6 +96,7 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4"

 [targets]
-test = ["Aqua", "Test", "CUDA", "SafeTestsets", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "LineSearches", "LuxCUDA", "Flux"]
+test = ["Aqua", "Test", "CUDA", "SafeTestsets", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "LineSearches", "LuxCUDA", "Flux", "MethodOfLines"]
```

docs/pages.jl

Lines changed: 1 addition & 0 deletions
```diff
@@ -3,6 +3,7 @@ pages = ["index.md",
     "Bayesian PINNs for Coupled ODEs" => "tutorials/Lotka_Volterra_BPINNs.md",
     "PINNs DAEs" => "tutorials/dae.md",
     "Parameter Estimation with PINNs for ODEs" => "tutorials/ode_parameter_estimation.md",
+    "Deep Galerkin Method" => "tutorials/dgm.md"
     #"examples/nnrode_example.md", # currently incorrect
 ],
 "PDE PINN Tutorials" => Any["Introduction to NeuralPDE for PDEs" => "tutorials/pdesystem.md",
```

docs/src/tutorials/dgm.md

Lines changed: 112 additions & 0 deletions
## Solving PDEs using the Deep Galerkin Method

### Overview

The Deep Galerkin Method (DGM) is a meshless deep learning algorithm for solving high-dimensional PDEs. It approximates the solution of a PDE with a neural network, and the network's loss function is defined in a similar spirit to PINNs: a PDE residual loss combined with a boundary-condition loss.

In the following example, we demonstrate computing the loss function using Quasi-Random Sampling, a sampling technique that uses quasi-Monte Carlo sequences to generate low-discrepancy point sets in high-dimensional spaces.
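As a standalone illustration of what such a low-discrepancy point set looks like, the sketch below draws Sobol points over the $(t, x)$ domain used later in this tutorial. It assumes the QuasiMonteCarlo.jl package is available; NeuralPDE's `QuasiRandomTraining` strategy generates its training points in a similar way.

```julia
# Illustration only: a Sobol low-discrepancy point set over t ∈ [0, 1], x ∈ [-1, 1].
using QuasiMonteCarlo

lb = [0.0, -1.0]   # lower bounds for (t, x)
ub = [1.0, 1.0]    # upper bounds for (t, x)

# 500 quasi-random points; each column is one (t, x) sample.
points = QuasiMonteCarlo.sample(500, lb, ub, SobolSample())
```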
### Algorithm

The authors of DGM suggest a network composed of LSTM-type layers that works well for most parabolic and quasi-parabolic PDEs.

```math
\begin{align*}
S^1 &= \sigma_1(W^1 \vec{x} + b^1); \\
Z^l &= \sigma_1(U^{z,l} \vec{x} + W^{z,l} S^l + b^{z,l}); \quad l = 1, \ldots, L; \\
G^l &= \sigma_1(U^{g,l} \vec{x} + W^{g,l} S^l + b^{g,l}); \quad l = 1, \ldots, L; \\
R^l &= \sigma_1(U^{r,l} \vec{x} + W^{r,l} S^l + b^{r,l}); \quad l = 1, \ldots, L; \\
H^l &= \sigma_2(U^{h,l} \vec{x} + W^{h,l}(S^l \cdot R^l) + b^{h,l}); \quad l = 1, \ldots, L; \\
S^{l+1} &= (1 - G^l) \cdot H^l + Z^l \cdot S^{l}; \quad l = 1, \ldots, L; \\
f(t, x; \theta) &= \sigma_{out}(W S^{L+1} + b),
\end{align*}
```

where $\vec{x}$ is the concatenation of $(t, x)$ and $L$ is the number of LSTM-type layers in the network.
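To make the recurrence above concrete, here is a minimal, self-contained sketch of a single $S^l \to S^{l+1}$ update written with plain arrays and hypothetical parameter names (illustration only; the layer actually used by NeuralPDE is the one implemented in `src/dgm.jl` in this commit):

```julia
# Illustration only: one DGM LSTM-type update S^l -> S^{l+1} for a single input tx = [t, x].
using Random

function dgm_lstm_step(S, tx, p, σ1, σ2)
    Z = σ1.(p.Uz * tx .+ p.Wz * S .+ p.bz)
    G = σ1.(p.Ug * tx .+ p.Wg * S .+ p.bg)
    R = σ1.(p.Ur * tx .+ p.Wr * S .+ p.br)
    H = σ2.(p.Uh * tx .+ p.Wh * (S .* R) .+ p.bh)
    return (1 .- G) .* H .+ Z .* S
end

rng = Random.MersenneTwister(0)
d, m = 2, 4                      # input dimension (t, x) and hidden width
tx = [0.5, 0.0]                  # one (t, x) sample
W1, b1 = randn(rng, m, d), zeros(m)
p = (Uz = randn(rng, m, d), Wz = randn(rng, m, m), bz = zeros(m),
     Ug = randn(rng, m, d), Wg = randn(rng, m, m), bg = zeros(m),
     Ur = randn(rng, m, d), Wr = randn(rng, m, m), br = zeros(m),
     Uh = randn(rng, m, d), Wh = randn(rng, m, m), bh = zeros(m))
S1 = tanh.(W1 * tx .+ b1)                    # S^1 = σ1(W^1 x̄ + b^1)
S2 = dgm_lstm_step(S1, tx, p, tanh, tanh)    # one recurrent update
```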
### Example

Let's solve the following Burgers' equation with the Deep Galerkin Method for $\alpha = 0.05$ and compare our solution with a finite-difference solution:

$$
\partial_t u(t, x) + u(t, x) \partial_x u(t, x) - \alpha \partial_{xx} u(t, x) = 0
$$

defined over

$$ t \in [0, 1], x \in [-1, 1] $$

with initial and boundary conditions

```math
\begin{align*}
u(0, x) &= -\sin(\pi x), \\
u(t, -1) &= 0, \\
u(t, 1) &= 0.
\end{align*}
```
### Copy-Pasteable Code

```julia
using NeuralPDE
using ModelingToolkit, Optimization, OptimizationOptimisers
import Lux: tanh, identity
using Distributions
import ModelingToolkit: Interval, infimum, supremum
using MethodOfLines, OrdinaryDiffEq

@parameters x t
@variables u(..)

Dt = Differential(t)
Dx = Differential(x)
Dxx = Dx^2
α = 0.05

# Burgers' equation
eq = Dt(u(t, x)) + u(t, x) * Dx(u(t, x)) - α * Dxx(u(t, x)) ~ 0

# initial and boundary conditions
bcs = [
    u(0.0, x) ~ -sin(π * x),
    u(t, -1.0) ~ 0.0,
    u(t, 1.0) ~ 0.0
]

domains = [t ∈ Interval(0.0, 1.0), x ∈ Interval(-1.0, 1.0)]

# MethodOfLines, for the finite-difference reference solution
dx = 0.01
order = 2
discretization = MOLFiniteDifference([x => dx], t, saveat = 0.01)
@named pde_system = PDESystem(eq, bcs, domains, [t, x], [u(t, x)])
prob = discretize(pde_system, discretization)
sol = solve(prob, Tsit5())
ts = sol[t]
xs = sol[x]

u_MOL = sol[u(t, x)]

# NeuralPDE, using the Deep Galerkin Method
strategy = QuasiRandomTraining(4_000, minibatch = 500)
discretization = DeepGalerkin(2, 1, 50, 5, tanh, tanh, identity, strategy)
@named pde_system = PDESystem(eq, bcs, domains, [t, x], [u(t, x)])
prob = discretize(pde_system, discretization)

global iter = 0
callback = function (p, l)
    global iter += 1
    if iter % 20 == 0
        println("$iter => $l")
    end
    return false
end

res = Optimization.solve(prob, Adam(0.01); callback = callback, maxiters = 300)
phi = discretization.phi

u_predict = [first(phi([t, x], res.minimizer)) for t in ts, x in xs]

diff_u = abs.(u_predict .- u_MOL)

using Plots
p1 = plot(ts, xs, u_MOL', linetype = :contourf, title = "FD")
p2 = plot(ts, xs, u_predict', linetype = :contourf, title = "predict")
p3 = plot(ts, xs, diff_u', linetype = :contourf, title = "error")
plot(p1, p2, p3)
```

src/NeuralPDE.jl

Lines changed: 3 additions & 1 deletion
```diff
@@ -51,6 +51,7 @@ include("neural_adapter.jl")
 include("advancedHMC_MCMC.jl")
 include("BPINN_ode.jl")
 include("PDE_BPINN.jl")
+include("dgm.jl")

 export NNODE, NNDAE,
     PhysicsInformedNN, discretize,
@@ -62,6 +63,7 @@ export NNODE, NNDAE,
     AbstractAdaptiveLoss, NonAdaptiveLoss, GradientScaleAdaptiveLoss,
     MiniMaxAdaptiveLoss, LogOptions,
     ahmc_bayesian_pinn_ode, BNNODE, ahmc_bayesian_pinn_pde, vector_to_parameters,
-    BPINNsolution, BayesianPINN
+    BPINNsolution, BayesianPINN,
+    DeepGalerkin

 end # module
```

src/dgm.jl

Lines changed: 174 additions & 0 deletions
````julia
struct dgm_lstm_layer{F1, F2} <: Lux.AbstractExplicitLayer
    activation1::Function
    activation2::Function
    in_dims::Int
    out_dims::Int
    init_weight::F1
    init_bias::F2
end

function dgm_lstm_layer(in_dims::Int, out_dims::Int, activation1, activation2;
        init_weight = Lux.glorot_uniform, init_bias = Lux.zeros32)
    return dgm_lstm_layer{typeof(init_weight), typeof(init_bias)}(
        activation1, activation2, in_dims, out_dims, init_weight, init_bias)
end

import Lux: initialparameters, initialstates, parameterlength, statelength

function Lux.initialparameters(rng::AbstractRNG, l::dgm_lstm_layer)
    return (
        Uz = l.init_weight(rng, l.out_dims, l.in_dims),
        Ug = l.init_weight(rng, l.out_dims, l.in_dims),
        Ur = l.init_weight(rng, l.out_dims, l.in_dims),
        Uh = l.init_weight(rng, l.out_dims, l.in_dims),
        Wz = l.init_weight(rng, l.out_dims, l.out_dims),
        Wg = l.init_weight(rng, l.out_dims, l.out_dims),
        Wr = l.init_weight(rng, l.out_dims, l.out_dims),
        Wh = l.init_weight(rng, l.out_dims, l.out_dims),
        bz = l.init_bias(rng, l.out_dims),
        bg = l.init_bias(rng, l.out_dims),
        br = l.init_bias(rng, l.out_dims),
        bh = l.init_bias(rng, l.out_dims)
    )
end

Lux.initialstates(::AbstractRNG, ::dgm_lstm_layer) = NamedTuple()
Lux.parameterlength(l::dgm_lstm_layer) = 4 * (l.out_dims * l.in_dims + l.out_dims * l.out_dims + l.out_dims)
Lux.statelength(l::dgm_lstm_layer) = 0

function (layer::dgm_lstm_layer)(S::AbstractVecOrMat{T}, x::AbstractVecOrMat{T}, ps, st::NamedTuple) where T
    @unpack Uz, Ug, Ur, Uh, Wz, Wg, Wr, Wh, bz, bg, br, bh = ps
    Z = layer.activation1.(Uz * x + Wz * S .+ bz)
    G = layer.activation1.(Ug * x + Wg * S .+ bg)
    R = layer.activation1.(Ur * x + Wr * S .+ br)
    H = layer.activation2.(Uh * x + Wh * (S .* R) .+ bh)
    S_new = (1. .- G) .* H .+ Z .* S
    return S_new, st
end

struct dgm_lstm_block{L <: NamedTuple} <: Lux.AbstractExplicitContainerLayer{(:layers,)}
    layers::L
end

function dgm_lstm_block(l...)
    names = ntuple(i -> Symbol("dgm_lstm_$i"), length(l))
    layers = NamedTuple{names}(l)
    return dgm_lstm_block(layers)
end

dgm_lstm_block(xs::AbstractVector) = dgm_lstm_block(xs...)

@generated function apply_dgm_lstm_block(layers::NamedTuple{fields}, S::AbstractVecOrMat, x::AbstractVecOrMat, ps, st::NamedTuple) where fields
    N = length(fields)
    S_symbols = vcat([:S], [gensym() for _ in 1:N])
    x_symbol = :x
    st_symbols = [gensym() for _ in 1:N]
    calls = [:(($(S_symbols[i + 1]), $(st_symbols[i])) = layers.$(fields[i])(
                 $(S_symbols[i]), $(x_symbol), ps.$(fields[i]), st.$(fields[i]))) for i in 1:N]
    push!(calls, :(st = NamedTuple{$fields}((($(Tuple(st_symbols)...),)))))
    push!(calls, :(return $(S_symbols[N + 1]), st))
    return Expr(:block, calls...)
end

function (L::dgm_lstm_block)(S::AbstractVecOrMat{T}, x::AbstractVecOrMat{T}, ps, st::NamedTuple) where T
    return apply_dgm_lstm_block(L.layers, S, x, ps, st)
end

struct dgm{S, L, E} <: Lux.AbstractExplicitContainerLayer{(:d_start, :lstm, :d_end)}
    d_start::S
    lstm::L
    d_end::E
end

function (l::dgm)(x::AbstractVecOrMat{T}, ps, st::NamedTuple) where T
    S, st_start = l.d_start(x, ps.d_start, st.d_start)
    S, st_lstm = l.lstm(S, x, ps.lstm, st.lstm)
    y, st_end = l.d_end(S, ps.d_end, st.d_end)

    st_new = (
        d_start = st_start,
        lstm = st_lstm,
        d_end = st_end
    )
    return y, st_new
end

"""
    `dgm(in_dims::Int, out_dims::Int, modes::Int, layers::Int, activation1, activation2, out_activation)`:
    returns the architecture defined for the Deep Galerkin method.

```math
\\begin{align}
S^1 &= \\sigma_1(W^1 x + b^1); \\
Z^l &= \\sigma_1(U^{z,l} x + W^{z,l} S^l + b^{z,l}); \\quad l = 1, \\ldots, L; \\
G^l &= \\sigma_1(U^{g,l} x + W^{g,l} S^l + b^{g,l}); \\quad l = 1, \\ldots, L; \\
R^l &= \\sigma_1(U^{r,l} x + W^{r,l} S^l + b^{r,l}); \\quad l = 1, \\ldots, L; \\
H^l &= \\sigma_2(U^{h,l} x + W^{h,l}(S^l \\cdot R^l) + b^{h,l}); \\quad l = 1, \\ldots, L; \\
S^{l+1} &= (1 - G^l) \\cdot H^l + Z^l \\cdot S^{l}; \\quad l = 1, \\ldots, L; \\
f(t, x; \\theta) &= \\sigma_{out}(W S^{L+1} + b).
\\end{align}
```

## Positional Arguments:
`in_dims`: number of input dimensions = (spatial dimension + 1)

`out_dims`: number of output dimensions

`modes`: width of the LSTM-type layer (output of the first Dense layer)

`layers`: number of LSTM-type layers

`activation1`: activation function used in LSTM-type layers

`activation2`: activation function used for the output of LSTM-type layers

`out_activation`: activation function used for the output of the network
"""
function dgm(in_dims::Int, out_dims::Int, modes::Int, layers::Int, activation1, activation2, out_activation)
    dgm(
        Lux.Dense(in_dims, modes, activation1),
        dgm_lstm_block([dgm_lstm_layer(in_dims, modes, activation1, activation2) for i in 1:layers]),
        Lux.Dense(modes, out_dims, out_activation)
    )
end

"""
    `DeepGalerkin(in_dims::Int, out_dims::Int, modes::Int, L::Int, activation1::Function, activation2::Function, out_activation::Function,
    strategy::NeuralPDE.AbstractTrainingStrategy; kwargs...)`:

returns a `discretize` algorithm for the ModelingToolkit PDESystem interface, which transforms a `PDESystem` into an
`OptimizationProblem` using the Deep Galerkin method.

## Arguments:
`in_dims`: number of input dimensions = (spatial dimension + 1)

`out_dims`: number of output dimensions

`modes`: width of the LSTM-type layer

`L`: number of LSTM-type layers

`activation1`: activation function used in LSTM-type layers

`activation2`: activation function used for the output of LSTM-type layers

`out_activation`: activation function used for the output of the network

`kwargs`: additional arguments to be splatted into `PhysicsInformedNN`

## Examples
```julia
discretization = DeepGalerkin(2, 1, 30, 3, tanh, tanh, identity, QuasiRandomTraining(4_000))
```

## References
Sirignano, Justin and Spiliopoulos, Konstantinos, "DGM: A deep learning algorithm for solving partial differential equations",
Journal of Computational Physics, Volume 375, 2018, Pages 1339-1364, doi: https://doi.org/10.1016/j.jcp.2018.08.029
"""
function DeepGalerkin(in_dims::Int, out_dims::Int, modes::Int, L::Int, activation1::Function, activation2::Function, out_activation::Function,
        strategy::NeuralPDE.AbstractTrainingStrategy; kwargs...)
    PhysicsInformedNN(
        dgm(in_dims, out_dims, modes, L, activation1, activation2, out_activation),
        strategy; kwargs...
    )
end
````
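As an editorial aside (not part of the commit), the `dgm` container defined above can also be exercised as an ordinary Lux model, independent of the PDE machinery. A minimal sketch, assuming NeuralPDE with this file is loaded and using the standard `Lux.setup` API:

```julia
using Lux, Random
import NeuralPDE

# 2 inputs (t, x), 1 output, hidden width 10, 3 LSTM-type layers, tanh activations.
model = NeuralPDE.dgm(2, 1, 10, 3, tanh, tanh, identity)

rng = Random.default_rng()
ps, st = Lux.setup(rng, model)      # initialize parameters and states

x = rand(rng, Float32, 2, 16)       # a batch of 16 (t, x) points
y, _ = model(x, ps, st)             # forward pass; y has size 1 × 16
```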
