
Commit 0cb2e06

Merge pull request #836 from sathvikbhagavan/sb/batch
refactor: correctly lower quadrature training strategy in NNODE
2 parents 99feba6 + 7e3de98

File tree: 11 files changed (+82 −87 lines)


Project.toml

Lines changed: 4 additions & 4 deletions

@@ -44,7 +44,7 @@ AdvancedHMC = "0.6.1"
 Aqua = "0.8"
 ArrayInterface = "7.7"
 CUDA = "5.2"
-ChainRulesCore = "1.18"
+ChainRulesCore = "1.21"
 ComponentArrays = "0.15.8"
 Cubature = "1.5"
 DiffEqBase = "6.144"
@@ -59,7 +59,7 @@ Integrals = "4"
 LineSearches = "7.2"
 LinearAlgebra = "1"
 LogDensityProblems = "2"
-Lux = "0.5.14"
+Lux = "0.5.22"
 LuxCUDA = "0.3.2"
 MCMCChains = "6"
 MethodOfLines = "0.10.7"
@@ -82,7 +82,7 @@ SymbolicUtils = "1.4"
 Symbolics = "5.17"
 Test = "1"
 UnPack = "1"
-Zygote = "0.6.68"
+Zygote = "0.6.69"
 julia = "1.10"

 [extras]
@@ -91,12 +91,12 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
 LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
+MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4"
 OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4"

 [targets]
 test = ["Aqua", "Test", "CUDA", "SafeTestsets", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "LineSearches", "LuxCUDA", "Flux", "MethodOfLines"]

docs/src/tutorials/neural_adapter.md

Lines changed: 4 additions & 4 deletions

@@ -69,7 +69,7 @@ function loss(cord, θ)
     ch2 .- phi(cord, res.u)
 end

-strategy = NeuralPDE.QuadratureTraining()
+strategy = NeuralPDE.QuadratureTraining(; reltol = 1e-6)

 prob_ = NeuralPDE.neural_adapter(loss, init_params2, pde_system, strategy)
 res_ = Optimization.solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 10000)

@@ -173,7 +173,7 @@ for i in 1:count_decomp
     bcs_ = create_bcs(domains_[1].domain, phi_bound)
     @named pde_system_ = PDESystem(eq, bcs_, domains_, [x, y], [u(x, y)])
     push!(pde_system_map, pde_system_)
-    strategy = NeuralPDE.QuadratureTraining()
+    strategy = NeuralPDE.QuadratureTraining(; reltol = 1e-6)

     discretization = NeuralPDE.PhysicsInformedNN(chains[i], strategy;
         init_params = init_params[i])

@@ -243,10 +243,10 @@ callback = function (p, l)
 end

 prob_ = NeuralPDE.neural_adapter(losses, init_params2, pde_system_map,
-    NeuralPDE.QuadratureTraining())
+    NeuralPDE.QuadratureTraining(; reltol = 1e-6))
 res_ = Optimization.solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 5000)
 prob_ = NeuralPDE.neural_adapter(losses, res_.u, pde_system_map,
-    NeuralPDE.QuadratureTraining())
+    NeuralPDE.QuadratureTraining(; reltol = 1e-6))
 res_ = Optimization.solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 5000)

 phi_ = PhysicsInformedNN(chain2, strategy; init_params = res_.u).phi
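
Why the tutorial now pins `reltol`: this PR swaps the default tolerances of `QuadratureTraining` (see src/training_strategies.jl below), so the docs pass `reltol = 1e-6` explicitly to retain the tighter relative tolerance they relied on before. A minimal sketch of the two constructions under the new defaults:

```julia
using NeuralPDE

# New defaults after this PR: reltol = 1e-3, abstol = 1e-6.
default_strategy = NeuralPDE.QuadratureTraining()

# The tutorials pin reltol to keep the pre-PR relative tolerance:
tutorial_strategy = NeuralPDE.QuadratureTraining(; reltol = 1e-6)
```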

src/BPINN_ode.jl

Lines changed: 1 addition & 1 deletion

@@ -113,7 +113,7 @@ function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
         targetacceptancerate = 0.8),
     Integratorkwargs = (Integrator = Leapfrog,),
     autodiff = false, progress = false, verbose = false)
-    !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+    !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
     BNNODE(chain, Kernel, strategy,
         draw_samples, priorsNNw, param, l2std,
         phystd, dataset, physdt, MCMCkwargs,
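
For readers unfamiliar with the replacement for `Lux.transform`: a minimal, self-contained sketch of the Flux-to-Lux conversion now used throughout, assuming Flux, Lux (≥ 0.5.22), and Adapt are available; the example chain is illustrative, not from this PR.

```julia
import Flux, Lux
using Adapt: adapt

# An illustrative Flux model; any Flux.Chain would do.
flux_chain = Flux.Chain(Flux.Dense(1 => 16, tanh), Flux.Dense(16 => 1))

# FromFluxAdaptor(false, false): do not carry over Flux's parameters/state, and do
# not force-preserve layers that lack an exact Lux equivalent.
lux_chain = adapt(Lux.FromFluxAdaptor(false, false), flux_chain)
@assert lux_chain isa Lux.AbstractExplicitLayer
```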

src/NeuralPDE.jl

Lines changed: 1 addition & 0 deletions

@@ -30,6 +30,7 @@ using DomainSets: Domain, ClosedInterval, AbstractInterval, leftendpoint, righte
 using SciMLBase: @add_kwonly, parameterless_type
 using UnPack: @unpack
 import ChainRulesCore, Lux, ComponentArrays
+using Lux: FromFluxAdaptor
 using ChainRulesCore: @non_differentiable

 RuntimeGeneratedFunctions.init(@__MODULE__)

src/advancedHMC_MCMC.jl

Lines changed: 1 addition & 1 deletion

@@ -439,7 +439,7 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
     MCMCkwargs = (n_leapfrog = 30,),
     progress = false, verbose = false)

-    !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+    !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
     # NN parameter prior mean and variance(PriorsNN must be a tuple)
     if isinplace(prob)
         throw(error("The BPINN ODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."))

src/dae_solve.jl

Lines changed: 1 addition & 1 deletion

@@ -42,7 +42,7 @@ end

 function NNDAE(chain, opt, init_params = nothing; strategy = nothing, autodiff = false,
         kwargs...)
-    !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+    !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
     NNDAE(chain, opt, init_params, autodiff, strategy, kwargs)
 end

src/ode_solve.jl

Lines changed: 14 additions & 29 deletions

@@ -15,7 +15,7 @@ of the physics-informed neural network which is used as a solver for a standard
 ## Positional Arguments

 * `chain`: A neural network architecture, defined as a `Lux.AbstractExplicitLayer` or `Flux.Chain`.
-  `Flux.Chain` will be converted to `Lux` using `Lux.transform`.
+  `Flux.Chain` will be converted to `Lux` using `adapt(FromFluxAdaptor(false, false), chain)`.
 * `opt`: The optimizer to train the neural network.
 * `init_params`: The initial parameter of the neural network. By default, this is `nothing`
   which thus uses the random initialization provided by the neural network library.

@@ -27,11 +27,10 @@ of the physics-informed neural network which is used as a solver for a standard
   the PDE operators. The reverse mode of the loss function is always
   automatic differentiation (via Zygote), this is only for the derivative
   in the loss function (the derivative with respect to time).
-* `batch`: The batch size to use for the internal quadrature. Defaults to `0`, which
-  means the application of the neural network is done at individual time points one
-  at a time. `batch>0` means the neural network is applied at a row vector of values
-  `t` simultaneously, i.e. it's the batch size for the neural network evaluations.
-  This requires a neural network compatible with batched data.
+* `batch`: Whether to batch the evaluations of the neural network in the loss. Defaults to `true`,
+  meaning the network is applied to a row vector of time values `t` simultaneously, i.e. `t` is the
+  batch dimension of the network evaluations; this requires a network compatible with batched data.
+  `false` applies the network at individual time points, one at a time. Not applicable to `QuadratureTraining`, where `batch` is set on the `strategy` instead and is the number of points at which the integrand can be evaluated in parallel.
 * `param_estim`: Boolean to indicate whether parameters of the differential equations are learnt along with parameters of the neural network.
 * `strategy`: The training strategy used to choose the points for the evaluations.
   Default of `nothing` means that `QuadratureTraining` with QuadGK is used if no

@@ -88,8 +87,8 @@ struct NNODE{C, O, P, B, PE, K, AL <: Union{Nothing, Function},
 end
 function NNODE(chain, opt, init_params = nothing;
     strategy = nothing,
-    autodiff = false, batch = nothing, param_estim = false, additional_loss = nothing, kwargs...)
-    !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+    autodiff = false, batch = true, param_estim = false, additional_loss = nothing, kwargs...)
+    !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
     NNODE(chain, opt, init_params, autodiff, batch, strategy, param_estim, additional_loss, kwargs)
 end

@@ -111,11 +110,7 @@ end

 function generate_phi_θ(chain::Lux.AbstractExplicitLayer, t, u0, init_params)
     θ, st = Lux.setup(Random.default_rng(), chain)
-    if init_params === nothing
-        init_params = ComponentArrays.ComponentArray(θ)
-    else
-        init_params = ComponentArrays.ComponentArray(init_params)
-    end
+    isnothing(init_params) && (init_params = θ)
     ODEPhi(chain, t, u0, st), init_params
 end

@@ -182,7 +177,7 @@ function ode_dfdx(phi::ODEPhi, t::AbstractVector, θ, autodiff::Bool)
 end

 """
-    inner_loss(phi, f, autodiff, t, θ, p)
+    inner_loss(phi, f, autodiff, t, θ, p, param_estim)

 Simple L2 inner loss at a time `t` with parameters `θ` of the neural network.
 """

@@ -220,7 +215,7 @@ function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector,
 end

 """
-    generate_loss(strategy, phi, f, autodiff, tspan, p, batch)
+    generate_loss(strategy, phi, f, autodiff, tspan, p, batch, param_estim)

 Representation of the loss function, parametric on the training strategy `strategy`.
 """

@@ -229,14 +224,13 @@ function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tsp
     integrand(t::Number, θ) = abs2(inner_loss(phi, f, autodiff, t, θ, p, param_estim))

     integrand(ts, θ) = [abs2(inner_loss(phi, f, autodiff, t, θ, p, param_estim)) for t in ts]
-    @assert batch == 0 # not implemented

     function loss(θ, _)
-        intprob = IntegralProblem(integrand, (tspan[1], tspan[2]), θ)
-        sol = solve(intprob, QuadGKJL(); abstol = strategy.abstol, reltol = strategy.reltol)
+        intf = BatchIntegralFunction(integrand, max_batch = strategy.batch)
+        intprob = IntegralProblem(intf, (tspan[1], tspan[2]), θ)
+        sol = solve(intprob, strategy.quadrature_alg; abstol = strategy.abstol, reltol = strategy.reltol, maxiters = strategy.maxiters)
         sol.u
     end
-
     return loss
 end

@@ -395,16 +389,7 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
         alg.strategy
     end

-    batch = if alg.batch === nothing
-        if strategy isa QuadratureTraining
-            strategy.batch
-        else
-            true
-        end
-    else
-        alg.batch
-    end
-
+    batch = alg.batch
     inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, batch, param_estim)
     additional_loss = alg.additional_loss
     (param_estim && isnothing(additional_loss)) && throw(ArgumentError("Please provide `additional_loss` in `NNODE` for parameter estimation (`param_estim` is true)."))
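
To make the new lowering concrete: a self-contained sketch of the batched quadrature pattern the `QuadratureTraining` loss now follows, assuming Integrals.jl v4 with Cubature.jl loaded for `CubatureJLh`. The residual function and tolerance values here are stand-ins, not NeuralPDE internals.

```julia
using Integrals, Cubature

# Stand-in for abs2 ∘ inner_loss: squared ODE residual at time t for parameters θ.
residual(t, θ) = abs2(cos(θ * t) - t)
# Out-of-place batched integrand: receives a whole vector of quadrature nodes at once.
integrand(ts, θ) = [residual(t, θ) for t in ts]

θ = 2.0
intf = BatchIntegralFunction(integrand, max_batch = 100)  # plays the role of strategy.batch
intprob = IntegralProblem(intf, (0.0, 1.0), θ)            # integrate over tspan
sol = solve(intprob, CubatureJLh(); abstol = 1e-6, reltol = 1e-3, maxiters = 1_000)
sol.u  # quadrature approximation of the integrated squared residual
```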

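And a hypothetical end-to-end `NNODE` call under the new interface, with `batch` now defaulting to `true` and the quadrature batch living on the strategy; the network, optimizer, tolerances, and the choice of `QuadGKJL` are illustrative assumptions, not code from this PR.

```julia
using NeuralPDE, Lux, OptimizationOptimisers
using SciMLBase: ODEProblem, solve
using Integrals: QuadGKJL

f(u, p, t) = cos(2π * t)                 # out-of-place ODE right-hand side
prob = ODEProblem(f, 0.0, (0.0, 1.0))
chain = Lux.Chain(Lux.Dense(1, 16, tanh), Lux.Dense(16, 1))

# batch = true (the new default) evaluates the network on a row vector of time
# points at once; for QuadratureTraining the quadrature batch is set on the strategy.
alg = NNODE(chain, OptimizationOptimisers.Adam(0.01);
    strategy = QuadratureTraining(; quadrature_alg = QuadGKJL(),
        reltol = 1e-3, abstol = 1e-6, maxiters = 1_000, batch = 100))
sol = solve(prob, alg; maxiters = 2000)
```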
src/pinn_types.jl

Lines changed: 3 additions & 3 deletions

@@ -48,7 +48,7 @@ methodology.
 * `chain`: a vector of Lux/Flux chains with a d-dimensional input and a
   1-dimensional output corresponding to each of the dependent variables. Note that this
   specification respects the order of the dependent variables as specified in the PDESystem.
-  Flux chains will be converted to Lux internally using `Lux.transform`.
+  Flux chains will be converted to Lux internally using `adapt(FromFluxAdaptor(false, false), chain)`.
 * `strategy`: determines which training strategy will be used. See the Training Strategy
   documentation for more details.

@@ -107,7 +107,7 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN
     if multioutput
         !all(i -> i isa Lux.AbstractExplicitLayer, chain) && (chain = Lux.transform.(chain))
     else
-        !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+        !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
     end
     if phi === nothing
         if multioutput

@@ -243,7 +243,7 @@ struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN
     if multioutput
         !all(i -> i isa Lux.AbstractExplicitLayer, chain) && (chain = Lux.transform.(chain))
     else
-        !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+        !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
     end
     if phi === nothing
         if multioutput

src/training_strategies.jl

Lines changed: 1 addition & 5 deletions

@@ -272,7 +272,7 @@ struct QuadratureTraining{Q <: SciMLBase.AbstractIntegralAlgorithm, T} <:
     batch::Int64
 end

-function QuadratureTraining(; quadrature_alg = CubatureJLh(), reltol = 1e-6, abstol = 1e-3,
+function QuadratureTraining(; quadrature_alg = CubatureJLh(), reltol = 1e-3, abstol = 1e-6,
     maxiters = 1_000, batch = 100)
     QuadratureTraining(quadrature_alg, reltol, abstol, maxiters, batch)
 end

@@ -306,11 +306,7 @@ function get_loss_function(loss_function, lb, ub, eltypeθ, strategy::Quadrature
 end
 area = eltypeθ(prod(abs.(ub .- lb)))
 f_ = (lb, ub, loss_, θ) -> begin
-    # last_x = 1
     function integrand(x, θ)
-        # last_x = x
-        # mean(abs2,loss_(x,θ), dims=2)
-        # size_x = fill(size(x)[2],(1,1))
         x = adapt(parameterless_type(ComponentArrays.getdata(θ)), x)
         sum(abs2, view(loss_(x, θ), 1, :), dims = 2) #./ size_x
     end
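
For clarity, a short check of the corrected defaults against the diff above; the field names mirror the struct definition shown.

```julia
using NeuralPDE

strategy = QuadratureTraining()
# After this change, the relative tolerance is the loose one and the absolute the tight one:
strategy.reltol  # 1e-3 (was 1e-6)
strategy.abstol  # 1e-6 (was 1e-3)
strategy.batch   # 100: max number of integrand points evaluated per quadrature call
```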
