Skip to content

Commit 0856525

Browse files
Merge pull request #839 from sathvikbhagavan/sb/complex
feat: allow complex for NNODE
2 parents 0cb2e06 + dad8815 commit 0856525

File tree

5 files changed

+144
-5
lines changed

5 files changed

+144
-5
lines changed

docs/pages.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ pages = ["index.md",
2222
"examples/heterogeneous.md",
2323
"examples/linear_parabolic.md",
2424
"examples/nonlinear_elliptic.md",
25-
"examples/nonlinear_hyperbolic.md"],
25+
"examples/nonlinear_hyperbolic.md",
26+
"examples/complex.md"],
2627
"Manual" => Any["manual/ode.md",
2728
"manual/dae.md",
2829
"manual/pinns.md",

docs/src/examples/complex.md

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Complex Equations with PINNs
2+
3+
NeuralPDE supports training PINNs with complex differential equations. This example will demonstrate how to use it for [`NNODE`](@ref). Let us consider a system of [Bloch equations](https://en.wikipedia.org/wiki/Bloch_equations). Note that [`QuadratureTraining`](@ref) cannot be used with complex equations, due to current limitations of computing quadratures.
4+
5+
As the input to this neural network is time, which is real, we need to initialize the parameters of the neural network with complex values for it to output and train with complex values.
6+
7+
```@example complex
8+
using Random, NeuralPDE
9+
using OrdinaryDiffEq
10+
using Lux, OptimizationOptimisers
11+
using Plots
12+
rng = Random.default_rng()
13+
Random.seed!(100)
14+
15+
function bloch_equations(u, p, t)
16+
Ω, Δ, Γ = p
17+
γ = Γ / 2
18+
ρ₁₁, ρ₂₂, ρ₁₂, ρ₂₁ = u
19+
d̢ρ = [im * Ω * (ρ₁₂ - ρ₂₁) + Γ * ρ₂₂;
20+
-im * Ω * (ρ₁₂ - ρ₂₁) - Γ * ρ₂₂;
21+
-(γ + im * Δ) * ρ₁₂ - im * Ω * (ρ₂₂ - ρ₁₁);
22+
conj(-(γ + im * Δ) * ρ₁₂ - im * Ω * (ρ₂₂ - ρ₁₁))]
23+
return d̢ρ
24+
end
25+
26+
u0 = zeros(ComplexF64, 4)
27+
u0[1] = 1.0
28+
time_span = (0.0, 2.0)
29+
parameters = [2.0, 0.0, 1.0]
30+
31+
problem = ODEProblem(bloch_equations, u0, time_span, parameters)
32+
33+
chain = Lux.Chain(
34+
Lux.Dense(1, 16, tanh; init_weight = (rng, a...) -> Lux.kaiming_normal(rng, ComplexF64, a...)) ,
35+
Lux.Dense(16, 4; init_weight = (rng, a...) -> Lux.kaiming_normal(rng, ComplexF64, a...))
36+
)
37+
ps, st = Lux.setup(rng, chain)
38+
39+
opt = OptimizationOptimisers.Adam(0.01)
40+
ground_truth = solve(problem, Tsit5(), saveat = 0.01)
41+
alg = NNODE(chain, opt, ps; strategy = StochasticTraining(500))
42+
sol = solve(problem, alg, verbose = false, maxiters = 5000, saveat = 0.01)
43+
```
44+
45+
Now, let's plot the predictions.
46+
47+
`u1`:
48+
49+
```@example complex
50+
plot(sol.t, real.(reduce(hcat, sol.u)[1, :]));
51+
plot!(ground_truth.t, real.(reduce(hcat, ground_truth.u)[1, :]))
52+
```
53+
54+
```@example complex
55+
plot(sol.t, imag.(reduce(hcat, sol.u)[1, :]));
56+
plot!(ground_truth.t, imag.(reduce(hcat, ground_truth.u)[1, :]))
57+
```
58+
59+
`u2`:
60+
61+
```@example complex
62+
plot(sol.t, real.(reduce(hcat, sol.u)[2, :]));
63+
plot!(ground_truth.t, real.(reduce(hcat, ground_truth.u)[2, :]))
64+
```
65+
66+
```@example complex
67+
plot(sol.t, imag.(reduce(hcat, sol.u)[2, :]));
68+
plot!(ground_truth.t, imag.(reduce(hcat, ground_truth.u)[2, :]))
69+
```
70+
71+
`u3`:
72+
73+
```@example complex
74+
plot(sol.t, real.(reduce(hcat, sol.u)[3, :]));
75+
plot!(ground_truth.t, real.(reduce(hcat, ground_truth.u)[3, :]))
76+
```
77+
78+
```@example complex
79+
plot(sol.t, imag.(reduce(hcat, sol.u)[3, :]));
80+
plot!(ground_truth.t, imag.(reduce(hcat, ground_truth.u)[3, :]))
81+
```
82+
83+
`u4`:
84+
85+
```@example complex
86+
plot(sol.t, real.(reduce(hcat, sol.u)[4, :]));
87+
plot!(ground_truth.t, real.(reduce(hcat, ground_truth.u)[4, :]))
88+
```
89+
90+
```@example complex
91+
plot(sol.t, imag.(reduce(hcat, sol.u)[4, :]));
92+
plot!(ground_truth.t, imag.(reduce(hcat, ground_truth.u)[4, :]))
93+
```
94+
95+
We can see it is able to learn the real parts of `u1`, `u2` and imaginary parts of `u3`, `u4`.

docs/src/tutorials/neural_adapter.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ function loss(cord, θ)
6969
ch2 .- phi(cord, res.u)
7070
end
7171
72-
strategy = NeuralPDE.QuadratureTraining(; reltol = 1e-6)
72+
strategy = NeuralPDE.QuadratureTraining(; reltol = 1e-6, abstol = 1e-3)
7373
7474
prob_ = NeuralPDE.neural_adapter(loss, init_params2, pde_system, strategy)
7575
res_ = Optimization.solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 10000)
@@ -173,7 +173,7 @@ for i in 1:count_decomp
173173
bcs_ = create_bcs(domains_[1].domain, phi_bound)
174174
@named pde_system_ = PDESystem(eq, bcs_, domains_, [x, y], [u(x, y)])
175175
push!(pde_system_map, pde_system_)
176-
strategy = NeuralPDE.QuadratureTraining(; reltol = 1e-6)
176+
strategy = NeuralPDE.QuadratureTraining(; reltol = 1e-6, abstol = 1e-3)
177177
178178
discretization = NeuralPDE.PhysicsInformedNN(chains[i], strategy;
179179
init_params = init_params[i])
@@ -243,10 +243,10 @@ callback = function (p, l)
243243
end
244244
245245
prob_ = NeuralPDE.neural_adapter(losses, init_params2, pde_system_map,
246-
NeuralPDE.QuadratureTraining(; reltol = 1e-6))
246+
NeuralPDE.QuadratureTraining(; reltol = 1e-6, abstol = 1e-3))
247247
res_ = Optimization.solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 5000)
248248
prob_ = NeuralPDE.neural_adapter(losses, res_.u, pde_system_map,
249-
NeuralPDE.QuadratureTraining(; reltol = 1e-6))
249+
NeuralPDE.QuadratureTraining(; reltol = 1e-6, abstol = 1e-3))
250250
res_ = Optimization.solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 5000)
251251
252252
phi_ = PhysicsInformedNN(chain2, strategy; init_params = res_.u).phi

src/ode_solve.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,7 @@ function (f::NNODEInterpolation)(t::Vector, idxs, ::Type{Val{0}}, p, continuity)
326326
end
327327

328328
SciMLBase.interp_summary(::NNODEInterpolation) = "Trained neural network interpolation"
329+
SciMLBase.allowscomplex(::NNODE) = true
329330

330331
function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
331332
alg::NNODE,
@@ -357,6 +358,8 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
357358

358359
!(chain isa Lux.AbstractExplicitLayer) && error("Only Lux.AbstractExplicitLayer neural networks are supported")
359360
phi, init_params = generate_phi_θ(chain, t0, u0, init_params)
361+
((eltype(eltype(init_params).types[1]) <: Complex || eltype(eltype(init_params).types[2]) <: Complex) && alg.strategy isa QuadratureTraining) &&
362+
error("QuadratureTraining cannot be used with complex parameters. Use other strategies.")
360363

361364
init_params = if alg.param_estim
362365
ComponentArrays.ComponentArray(; depvar = ComponentArrays.ComponentArray(init_params), p = prob.p)

test/NNODE_tests.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import Lux, OptimizationOptimisers, OptimizationOptimJL
55
using Flux
66
using LineSearches
77

8+
rng = Random.default_rng()
89
Random.seed!(100)
910

1011
@testset "Scalar" begin
@@ -250,6 +251,45 @@ end
250251
@test reduce(hcat, sol.u) ≈ u_ atol=1e-2
251252
end
252253

254+
@testset "Complex Numbers" begin
255+
function bloch_equations(u, p, t)
256+
Ω, Δ, Γ = p
257+
γ = Γ / 2
258+
ρ₁₁, ρ₂₂, ρ₁₂, ρ₂₁ = u
259+
d̢ρ = [im * Ω * (ρ₁₂ - ρ₂₁) + Γ * ρ₂₂;
260+
-im * Ω * (ρ₁₂ - ρ₂₁) - Γ * ρ₂₂;
261+
-(γ + im * Δ) * ρ₁₂ - im * Ω * (ρ₂₂ - ρ₁₁);
262+
conj(-(γ + im * Δ) * ρ₁₂ - im * Ω * (ρ₂₂ - ρ₁₁))]
263+
return d̢ρ
264+
end
265+
266+
u0 = zeros(ComplexF64, 4)
267+
u0[1] = 1
268+
time_span = (0.0, 2.0)
269+
parameters = [2.0, 0.0, 1.0]
270+
271+
problem = ODEProblem(bloch_equations, u0, time_span, parameters)
272+
273+
chain = Lux.Chain(
274+
Lux.Dense(1, 16, tanh; init_weight = (rng, a...) -> Lux.kaiming_normal(rng, ComplexF64, a...)) ,
275+
Lux.Dense(16, 4; init_weight = (rng, a...) -> Lux.kaiming_normal(rng, ComplexF64, a...))
276+
)
277+
ps, st = Lux.setup(rng, chain)
278+
279+
opt = OptimizationOptimisers.Adam(0.01)
280+
ground_truth = solve(problem, Tsit5(), saveat = 0.01)
281+
strategies = [StochasticTraining(500), GridTraining(0.01), WeightedIntervalTraining([0.1, 0.4, 0.4, 0.1], 500)]
282+
283+
@testset "$(nameof(typeof(strategy)))" for strategy in strategies
284+
alg = NNODE(chain, opt, ps; strategy)
285+
sol = solve(problem, alg, verbose = false, maxiters = 5000, saveat = 0.01)
286+
@test sol.u ≈ ground_truth.u rtol=1e-1
287+
end
288+
289+
alg = NNODE(chain, opt, ps; strategy = QuadratureTraining())
290+
@test_throws ErrorException solve(problem, alg, verbose = false, maxiters = 5000, saveat = 0.01)
291+
end
292+
253293
@testset "Translating from Flux" begin
254294
println("Translating from Flux")
255295
linear = (u, p, t) -> cos(2pi * t)

0 commit comments

Comments
 (0)