Skip to content

Commit dad8815

Browse files
author
Sathvik Bhagavan
committed
refactor: error out if QuadratureTraining is used with complex parameters for NNODE
1 parent fd9afba commit dad8815

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

src/ode_solve.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,8 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
358358

359359
!(chain isa Lux.AbstractExplicitLayer) && error("Only Lux.AbstractExplicitLayer neural networks are supported")
360360
phi, init_params = generate_phi_θ(chain, t0, u0, init_params)
361+
((eltype(eltype(init_params).types[1]) <: Complex || eltype(eltype(init_params).types[2]) <: Complex) && alg.strategy isa QuadratureTraining) &&
362+
error("QuadratureTraining cannot be used with complex parameters. Use other strategies.")
361363

362364
init_params = if alg.param_estim
363365
ComponentArrays.ComponentArray(; depvar = ComponentArrays.ComponentArray(init_params), p = prob.p)

0 commit comments

Comments
 (0)