diff --git a/lib/OrdinaryDiffEqCore/Project.toml b/lib/OrdinaryDiffEqCore/Project.toml index 92a7162fbc..499913f0ec 100644 --- a/lib/OrdinaryDiffEqCore/Project.toml +++ b/lib/OrdinaryDiffEqCore/Project.toml @@ -5,7 +5,6 @@ version = "2.1.0" [deps] SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" -Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" @@ -32,17 +31,18 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" FastPower = "a4df4552-cc26-4903-aec0-212e50a0e84b" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" [extras] +AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" [compat] @@ -93,7 +93,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" [targets] -test = ["DiffEqDevTools", "Random", "SafeTestsets", "Test", "JET", "Aqua", "AllocCheck"] +test = ["AllocCheck", "DiffEqDevTools", "Random", "SafeTestsets", "Test", "JET", "Aqua"] [extensions] OrdinaryDiffEqCoreEnzymeCoreExt = "EnzymeCore" diff --git a/lib/OrdinaryDiffEqCore/src/interp_func.jl b/lib/OrdinaryDiffEqCore/src/interp_func.jl index 11b7371194..fbcb0bfcce 100644 --- a/lib/OrdinaryDiffEqCore/src/interp_func.jl +++ b/lib/OrdinaryDiffEqCore/src/interp_func.jl @@ -75,14 +75,17 @@ function SciMLBase.strip_interpolation(id::InterpolationData) end function strip_cache(cache) - if !(cache isa OrdinaryDiffEqCore.DefaultCache) - cache = SciMLBase.constructorof(typeof(cache))([nothing - for name in - fieldnames(typeof(cache))]...) - else - # need to do something special for default cache + if cache isa OrdinaryDiffEqCore.DefaultCache cache = OrdinaryDiffEqCore.DefaultCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}(nothing, nothing, 0, nothing) + else + try + cache = SciMLBase.constructorof(typeof(cache))([nothing + for name in + fieldnames(typeof(cache))]...) + catch + cache = (; (name => nothing for name in fieldnames(typeof(cache)))...) + end end cache diff --git a/lib/OrdinaryDiffEqDifferentiation/Project.toml b/lib/OrdinaryDiffEqDifferentiation/Project.toml index dfb3b41700..d1651d3e01 100644 --- a/lib/OrdinaryDiffEqDifferentiation/Project.toml +++ b/lib/OrdinaryDiffEqDifferentiation/Project.toml @@ -30,12 +30,12 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" OrdinaryDiffEqDifferentiationSparseArraysExt = "SparseArrays" [extras] +AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -50,6 +50,7 @@ FiniteDiff = "2.27" StaticArrayInterface = "1.8" DifferentiationInterface = "0.6.54, 0.7" LinearSolve = "3.26" + ConstructionBase = "1.5.8" LinearAlgebra = "1.10" SciMLBase = "2.99" diff --git a/lib/OrdinaryDiffEqSDIRK/Project.toml b/lib/OrdinaryDiffEqSDIRK/Project.toml index 0b9e582d10..361bba5f0f 100644 --- a/lib/OrdinaryDiffEqSDIRK/Project.toml +++ b/lib/OrdinaryDiffEqSDIRK/Project.toml @@ -4,54 +4,56 @@ authors = ["ParamThakkar123 "] version = "1.9.0" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" -MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -OrdinaryDiffEqDifferentiation = "4302a76b-040a-498a-8c04-15b101fed76b" -TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" -SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" -RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" +OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8" +OrdinaryDiffEqDifferentiation = "4302a76b-040a-498a-8c04-15b101fed76b" OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" [extras] +AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d" -Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -Test = "<0.0.1, 1" -FastBroadcast = "0.3" -Random = "<0.0.1, 1" +ADTypes = "1.16" +AllocCheck = "0.2" +Aqua = "0.8.11" +DiffEqBase = "6.177.2" DiffEqDevTools = "2.44.4" -MuladdMacro = "0.2" +FastBroadcast = "0.3" +JET = "0.9, 0.11" LinearAlgebra = "1.10" -OrdinaryDiffEqDifferentiation = "1.12.0" -TruncatedStacktraces = "1.4" -SciMLBase = "2.99" -OrdinaryDiffEqCore = "2" -Aqua = "0.8.11" MacroTools = "0.5" -julia = "1.10" -JET = "0.9, 0.11" -ADTypes = "1.16" -RecursiveArrayTools = "3.36" +MuladdMacro = "0.2" +OrdinaryDiffEqCore = "2" +OrdinaryDiffEqDifferentiation = "1.12.0" OrdinaryDiffEqNonlinearSolve = "1.13.0" -AllocCheck = "0.2" -DiffEqBase = "6.176" -Reexport = "1.2" +Random = "<0.0.1, 1" +RecursiveArrayTools = "3.36" +Reexport = "1.2.2" SafeTestsets = "0.1.0" +SciMLBase = "2.99" +StaticArrays = "1.0" +Test = "<0.0.1, 1" +TruncatedStacktraces = "1.4" +julia = "1.10" [targets] -test = ["DiffEqDevTools", "Random", "SafeTestsets", "Test", "JET", "Aqua", "AllocCheck"] +test = ["AllocCheck", "DiffEqDevTools", "Random", "SafeTestsets", "Test", "JET", "Aqua"] [sources.OrdinaryDiffEqDifferentiation] path = "../OrdinaryDiffEqDifferentiation" diff --git a/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl b/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl index 25dedf9306..2656898bb6 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl @@ -12,13 +12,15 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!, alg_cache, _vec, _reshape, @cache, isfsal, full_cache, constvalue, _unwrap_val, _ode_interpolant, trivial_limiter!, _ode_interpolant!, - isesdirk, issplit, + isesdirk, issplit, recursivefill!, ssp_coefficient, get_fsalfirstlast, generic_solver_docstring, _bool_to_ADType, _process_AD_choice, current_extrapolant! using TruncatedStacktraces: @truncate_stacktrace using MuladdMacro, MacroTools, FastBroadcast, RecursiveArrayTools using SciMLBase: SplitFunction using LinearAlgebra: mul!, I +using StaticArrays +import RecursiveArrayTools: recursivefill! import OrdinaryDiffEqCore using OrdinaryDiffEqDifferentiation: UJacobianWrapper, dolinsolve @@ -32,17 +34,16 @@ using Reexport include("algorithms.jl") include("alg_utils.jl") -include("sdirk_caches.jl") -include("kencarp_kvaerno_caches.jl") -include("sdirk_perform_step.jl") -include("kencarp_kvaerno_perform_step.jl") +include("tableau_utils.jl") include("sdirk_tableaus.jl") +include("unified_sdirk_tableaus.jl") +include("sdirk_caches.jl") +include("generic_sdirk_perform_step.jl") export ImplicitEuler, ImplicitMidpoint, Trapezoid, TRBDF2, SDIRK2, SDIRK22, Kvaerno3, KenCarp3, Cash4, Hairer4, Hairer42, SSPSDIRK2, Kvaerno4, Kvaerno5, KenCarp4, KenCarp47, KenCarp5, KenCarp58, ESDIRK54I8L2SA, SFSDIRK4, - SFSDIRK5, CFNLIRK3, SFSDIRK6, SFSDIRK7, SFSDIRK8, Kvaerno5, KenCarp4, KenCarp5, - SFSDIRK4, SFSDIRK5, CFNLIRK3, SFSDIRK6, - SFSDIRK7, SFSDIRK8, ESDIRK436L2SA2, ESDIRK437L2SA, ESDIRK547L2SA2, ESDIRK659L2SA + SFSDIRK5, CFNLIRK3, SFSDIRK6, SFSDIRK7, SFSDIRK8, ESDIRK436L2SA2, ESDIRK437L2SA, + ESDIRK547L2SA2, ESDIRK659L2SA end diff --git a/lib/OrdinaryDiffEqSDIRK/src/generic_sdirk_perform_step.jl b/lib/OrdinaryDiffEqSDIRK/src/generic_sdirk_perform_step.jl new file mode 100644 index 0000000000..9c8975ed26 --- /dev/null +++ b/lib/OrdinaryDiffEqSDIRK/src/generic_sdirk_perform_step.jl @@ -0,0 +1,383 @@ +# Generic tableau-based perform_step! implementation for SDIRK methods + +using OrdinaryDiffEqCore: unwrap_alg, OrdinaryDiffEqCore, calculate_residuals +using OrdinaryDiffEqNonlinearSolve: markfirststage!, nlsolve!, nlsolvefail, isnewton, set_new_W!, get_W +using OrdinaryDiffEqDifferentiation: dolinsolve +using LinearAlgebra: I, mul! +using FastBroadcast: @.. +using MuladdMacro: @muladd + +@inline _get_step_limiter(alg, cache) = hasproperty(alg, :step_limiter!) ? alg.step_limiter! : trivial_limiter! + +# Type-stable IMEX dispatch functions +@inline _is_imex_scheme(f, ::SDIRKTableau{T, T2, S, hasEmbedded, true}) where {T, T2, S, hasEmbedded} = f isa SplitFunction +@inline _is_imex_scheme(f, ::SDIRKTableau{T, T2, S, hasEmbedded, false}) where {T, T2, S, hasEmbedded} = false + +function initialize!(integrator, cache::SDIRKConstantCache) + integrator.kshortsize = 2 + integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) + integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + + integrator.fsallast = zero(integrator.fsalfirst) + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast +end + +function initialize!(integrator, cache::SDIRKMutableCache) + integrator.kshortsize = 2 + resize!(integrator.k, integrator.kshortsize) + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + integrator.f(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) +end + +# Dispatch for unified caches - all SDIRK algorithms use these same cache types +@muladd function perform_step!(integrator, cache::SDIRKCache, repeat_step=false) + (; t, dt, uprev, u, f, p) = integrator + (; zs, atmp, nlsolver, tab) = cache + (; tmp) = nlsolver + alg = unwrap_alg(integrator, true) + step_limiter! = _get_step_limiter(alg, cache) + + A = tab.A + b = tab.b + c = tab.c + b_embed = tab.b_embed + s = size(A, 1) + + is_imex = _is_imex_scheme(integrator.f, tab) + + if is_imex + k_explicit = Vector{typeof(u)}(undef, s) + for i in 1:s + k_explicit[i] = zero(u) + end + f_impl = integrator.f.f1 + f_expl = integrator.f.f2 + else + k_explicit = nothing + f_impl = integrator.f + f_expl = nothing + end + + markfirststage!(nlsolver) + + for i in 1:s + zi = zs[i] + + if i == 1 + if is_imex && !repeat_step && !integrator.last_stepfail + f_impl(zi, uprev, p, t) + @.. broadcast=false zi *= dt + elseif alg.extrapolant == :linear + @.. broadcast=false zi = dt * integrator.fsalfirst + else + fill!(zi, zero(eltype(u))) + end + else + fill!(zi, zero(eltype(u))) + + if hasproperty(tab, :α_pred) && tab.α_pred !== nothing + for j in 1:i-1 + @.. broadcast=false zi += tab.α_pred[i, j] * zs[j] + end + end + end + + nlsolver.z = zi + + @.. broadcast=false nlsolver.tmp = uprev + for j in 1:i-1 + @.. broadcast=false nlsolver.tmp += A[i, j] * zs[j] + end + + if is_imex && tab.A_explicit !== nothing + if i == 1 + @.. broadcast=false k_explicit[1] = dt * integrator.fsalfirst - zi + else + @.. broadcast=false u = nlsolver.tmp + A[i,i] * zs[i-1] + c_exp = tab.c_explicit !== nothing ? tab.c_explicit[i] : tab.c[i] + f_expl(k_explicit[i], u, p, t + c_exp * dt) + @.. broadcast=false k_explicit[i] *= dt + integrator.stats.nf2 += 1 + end + + for j in 1:i-1 + @.. broadcast=false nlsolver.tmp += tab.A_explicit[i, j] * k_explicit[j] + end + end + if iszero(A[i, i]) + # explicit stage (no nonlinear solve needed) + nlsolver.c = typeof(nlsolver.c)(c[i]) + f_impl(zi, nlsolver.tmp, p, t + c[i] * dt) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + @.. broadcast=false zi *= dt + else + nlsolver.c = typeof(nlsolver.c)(c[i]) + nlsolver.γ = typeof(nlsolver.γ)(A[i, i]) + + if i > 1 && isnewton(nlsolver) + set_new_W!(nlsolver, false) + end + + zi .= nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + end + end + + @.. broadcast=false u = uprev + for i in 1:s + @.. broadcast=false u += b[i] * zs[i] + end + + if is_imex && tab.b_explicit !== nothing + if s >= 1 + @.. broadcast=false tmp = nlsolver.tmp + A[s,s] * zs[s] + f_expl(k_explicit[s], tmp, p, t + dt) + @.. broadcast=false k_explicit[s] *= dt + integrator.stats.nf2 += 1 + end + + for i in 1:s + @.. broadcast=false u += tab.b_explicit[i] * k_explicit[i] + end + end + + step_limiter!(u, integrator, p, t + dt) + + if integrator.opts.adaptive + if alg isa ImplicitEuler && integrator.success_iter > 0 + uprev2 = integrator.uprev2 + tprev = integrator.tprev + + dt1 = dt * (t + dt - tprev) + dt2 = (t - tprev) * (t + dt - tprev) + c = 7 / 12 + r = c * dt^2 + + @.. broadcast=false tmp = r * ((u - uprev) / dt1 - (uprev - uprev2) / dt2) + calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t) + integrator.EEst = integrator.opts.internalnorm(atmp, t) + elseif alg isa Trapezoid && integrator.success_iter > 0 + uprev2 = integrator.uprev2 + tprev = integrator.tprev + + dt1 = dt * (t + dt - tprev) + dt2 = (t - tprev) * (t + dt - tprev) + c = 1 / 12 + r = c * dt^3 + + @.. broadcast=false tmp = r * ((u - uprev) / dt1 - (uprev - uprev2) / dt2) + calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t) + integrator.EEst = integrator.opts.internalnorm(atmp, t) + elseif b_embed !== nothing + @.. broadcast=false tmp = zero(eltype(u)) + for i in 1:s + @.. broadcast=false tmp += b_embed[i] * zs[i] + end + + has_smooth_est = hasfield(typeof(alg), :smooth_est) + if has_smooth_est && alg.smooth_est && isnewton(nlsolver) + est = atmp + linres = dolinsolve(integrator, nlsolver.cache.linsolve; b = _vec(tmp), + linu = _vec(est)) + integrator.stats.nsolve += 1 + else + est = tmp + end + + calculate_residuals!(atmp, est, uprev, u, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t) + integrator.EEst = integrator.opts.internalnorm(atmp, t) + else + integrator.EEst = 1 + end + end + + if is_imex + integrator.f(integrator.fsallast, u, p, t + dt) + else + @.. broadcast=false integrator.fsallast = zs[s] / dt + end + + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + integrator.u = u +end + +@muladd function perform_step!(integrator, cache::SDIRKConstantCache, repeat_step=false) + (; t, dt, uprev, u, f, p) = integrator + (; tab, nlsolver) = cache + alg = unwrap_alg(integrator, true) + + s = size(tab.A, 1) + c = tab.c + A = tab.A + b = tab.b + b_embed = tab.b_embed + γ = tab.γ + + is_imex = _is_imex_scheme(integrator.f, tab) + + z = Vector{typeof(u)}(undef, s) + + if is_imex + k_explicit = Vector{typeof(u)}(undef, s) + f_impl = integrator.f.f1 + f_expl = integrator.f.f2 + else + k_explicit = nothing + f_impl = integrator.f + f_expl = nothing + end + + markfirststage!(nlsolver) + + for i in 1:s + stage_sum = uprev + for j in 1:i-1 + stage_sum += A[i,j] * z[j] + end + + if is_imex && tab.A_explicit !== nothing + if i == 1 + k_explicit[1] = dt * integrator.fsalfirst - (is_imex ? dt * f_impl(uprev, p, t) : zero(u)) + else + u_tmp = nlsolver.tmp + if i >= 2 + u_tmp += tab.γ * z[i-1] + end + + c_exp = tab.c_explicit !== nothing ? tab.c_explicit[i] : tab.c[i] + k_explicit[i] = dt * f_expl(u_tmp, p, t + c_exp * dt) + integrator.stats.nf2 += 1 + end + + for j in 1:i-1 + stage_sum += tab.A_explicit[i,j] * k_explicit[j] + end + end + + if i == 1 + if is_imex + z_guess = dt * f_impl(uprev, p, t) + else + z_guess = (alg.extrapolant == :linear) ? dt * integrator.fsalfirst : zero(u) + end + elseif i > 1 && hasproperty(tab, :α_pred) && tab.α_pred !== nothing + z_guess = zero(u) + @inbounds for j in 1:i-1 + z_guess += tab.α_pred[i,j] * z[j] + end + else + z_guess = zero(u) + end + + nlsolver.z = z_guess + nlsolver.tmp = stage_sum + nlsolver.c = typeof(nlsolver.c)(c[i]) + if iszero(A[i,i]) + # explicit stage (no nonlinear solve required) + z[i] = dt * f_impl(stage_sum, p, t + c[i] * dt) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + else + nlsolver.γ = typeof(nlsolver.γ)(A[i,i]) + + z[i] = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + end + end + + u = uprev + for i in 1:s + u += b[i] * z[i] + end + + if is_imex && tab.b_explicit !== nothing + if s >= 1 + u_final = nlsolver.tmp + tab.γ * z[s] + k_explicit[s] = dt * f_expl(u_final, p, t + dt) + integrator.stats.nf2 += 1 + end + + for i in 1:s + u += tab.b_explicit[i] * k_explicit[i] + end + end + + if hasproperty(alg, :step_limiter!) + alg.step_limiter!(u, integrator, p, t + dt) + end + + if integrator.opts.adaptive + if alg isa ImplicitEuler && integrator.success_iter > 0 + uprev2 = integrator.uprev2 + tprev = integrator.tprev + + dt1 = dt * (t + dt - tprev) + dt2 = (t - tprev) * (t + dt - tprev) + c = 7 / 12 + r = c * dt^2 + + tmp = r * ((u - uprev) / dt1 - (uprev - uprev2) / dt2) + atmp = calculate_residuals(tmp, uprev, u, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t) + integrator.EEst = integrator.opts.internalnorm(atmp, t) + elseif alg isa Trapezoid && integrator.success_iter > 0 + uprev2 = integrator.uprev2 + tprev = integrator.tprev + + dt1 = dt * (t + dt - tprev) + dt2 = (t - tprev) * (t + dt - tprev) + c = 1 / 12 + r = c * dt^3 + + tmp = r * ((u - uprev) / dt1 - (uprev - uprev2) / dt2) + atmp = calculate_residuals(tmp, uprev, u, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t) + integrator.EEst = integrator.opts.internalnorm(atmp, t) + elseif b_embed !== nothing + tmp = zero(u) + for i in 1:s + tmp += b_embed[i] * z[i] + end + + has_smooth_est = hasfield(typeof(alg), :smooth_est) + if isnewton(nlsolver) && has_smooth_est && alg.smooth_est + integrator.stats.nsolve += 1 + est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) + else + est = tmp + end + + atmp = calculate_residuals(est, uprev, u, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t) + integrator.EEst = integrator.opts.internalnorm(atmp, t) + else + integrator.EEst = 1 + end + end + + if is_imex + integrator.k[1] = integrator.fsalfirst + integrator.fsallast = integrator.f(u, p, t + dt) + integrator.k[2] = integrator.fsallast + else + integrator.fsallast = z[s] ./ dt + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + end + integrator.u = u +end + + + + + + +# All other SDIRK methods use the generic implementation through the unified cache dispatch above + diff --git a/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_caches.jl b/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_caches.jl deleted file mode 100644 index 4a09a83fed..0000000000 --- a/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_caches.jl +++ /dev/null @@ -1,642 +0,0 @@ -mutable struct Kvaerno3ConstantCache{Tab, N} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache(alg::Kvaerno3, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = Kvaerno3Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, 2tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - Kvaerno3ConstantCache(nlsolver, tab) -end - -@cache mutable struct Kvaerno3Cache{uType, rateType, uNoUnitsType, Tab, N, StepLimiter} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab - step_limiter!::StepLimiter -end - -function alg_cache(alg::Kvaerno3, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = Kvaerno3Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, 2tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - Kvaerno3Cache( - u, uprev, fsalfirst, z₁, z₂, z₃, z₄, atmp, nlsolver, tab, alg.step_limiter!) -end - -@cache mutable struct KenCarp3ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache(alg::KenCarp3, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = KenCarp3Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - - KenCarp3ConstantCache(nlsolver, tab) -end - -@cache mutable struct KenCarp3Cache{ - uType, rateType, uNoUnitsType, N, Tab, kType, StepLimiter} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - k1::kType - k2::kType - k3::kType - k4::kType - atmp::uNoUnitsType - nlsolver::N - tab::Tab - step_limiter!::StepLimiter -end - -function alg_cache(alg::KenCarp3, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = KenCarp3Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - if f isa SplitFunction - k1 = zero(u) - k2 = zero(u) - k3 = zero(u) - k4 = zero(u) - else - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - uf = UJacobianWrapper(f, t, p) - end - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - KenCarp3Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, k1, k2, - k3, k4, atmp, nlsolver, tab, alg.step_limiter!) -end - -@cache mutable struct CFNLIRK3ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache(alg::CFNLIRK3, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = CFNLIRK3Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - - CFNLIRK3ConstantCache(nlsolver, tab) -end - -@cache mutable struct CFNLIRK3Cache{uType, rateType, uNoUnitsType, N, Tab, kType} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - k1::kType - k2::kType - k3::kType - k4::kType - atmp::uNoUnitsType - nlsolver::N - tab::Tab -end - -function alg_cache(alg::CFNLIRK3, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = CFNLIRK3Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - k1 = zero(u) - k2 = zero(u) - k3 = zero(u) - k4 = zero(u) - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - CFNLIRK3Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, k1, k2, k3, k4, atmp, nlsolver, tab) -end - -@cache mutable struct Kvaerno4ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache(alg::Kvaerno4, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = Kvaerno4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - Kvaerno4ConstantCache(nlsolver, tab) -end - -@cache mutable struct Kvaerno4Cache{uType, rateType, uNoUnitsType, N, Tab, StepLimiter} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab - step_limiter!::StepLimiter -end - -function alg_cache(alg::Kvaerno4, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = Kvaerno4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - Kvaerno4Cache( - u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, atmp, nlsolver, tab, alg.step_limiter!) -end - -@cache mutable struct KenCarp4ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache(alg::KenCarp4, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = KenCarp4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - KenCarp4ConstantCache(nlsolver, tab) -end - -@cache mutable struct KenCarp4Cache{ - uType, rateType, uNoUnitsType, N, Tab, kType, StepLimiter} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - k1::kType - k2::kType - k3::kType - k4::kType - k5::kType - k6::kType - atmp::uNoUnitsType - nlsolver::N - tab::Tab - step_limiter!::StepLimiter -end - -@truncate_stacktrace KenCarp4Cache 1 - -function alg_cache(alg::KenCarp4, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = KenCarp4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - if f isa SplitFunction - k1 = zero(u) - k2 = zero(u) - k3 = zero(u) - k4 = zero(u) - k5 = zero(u) - k6 = zero(u) - else - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - k5 = nothing - k6 = nothing - uf = UJacobianWrapper(f, t, p) - end - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - KenCarp4Cache( - u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, k1, k2, k3, k4, k5, k6, atmp, - nlsolver, tab, alg.step_limiter!) -end - -@cache mutable struct Kvaerno5ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache(alg::Kvaerno5, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = Kvaerno5Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - - Kvaerno5ConstantCache(nlsolver, tab) -end - -@cache mutable struct Kvaerno5Cache{uType, rateType, uNoUnitsType, N, Tab, StepLimiter} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - z₇::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab - step_limiter!::StepLimiter -end - -function alg_cache(alg::Kvaerno5, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = Kvaerno5Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = zero(u) - z₇ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - Kvaerno5Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, - z₇, atmp, nlsolver, tab, alg.step_limiter!) -end - -@cache mutable struct KenCarp5ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache(alg::KenCarp5, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = KenCarp5Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - - KenCarp5ConstantCache(nlsolver, tab) -end - -@cache mutable struct KenCarp5Cache{ - uType, rateType, uNoUnitsType, N, Tab, kType, StepLimiter} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - z₇::uType - z₈::uType - k1::kType - k2::kType - k3::kType - k4::kType - k5::kType - k6::kType - k7::kType - k8::kType - atmp::uNoUnitsType - nlsolver::N - tab::Tab - step_limiter!::StepLimiter -end - -function alg_cache(alg::KenCarp5, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = KenCarp5Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - if f isa SplitFunction - k1 = zero(u) - k2 = zero(u) - k3 = zero(u) - k4 = zero(u) - k5 = zero(u) - k6 = zero(u) - k7 = zero(u) - k8 = zero(u) - else - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - k5 = nothing - k6 = nothing - k7 = nothing - k8 = nothing - end - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = zero(u) - z₇ = zero(u) - z₈ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - KenCarp5Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, z₇, z₈, - k1, k2, k3, k4, k5, k6, k7, k8, atmp, nlsolver, tab, alg.step_limiter!) -end - -@cache mutable struct KenCarp47ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache(alg::KenCarp47, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = KenCarp47Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - - KenCarp47ConstantCache(nlsolver, tab) -end - -@cache mutable struct KenCarp47Cache{uType, rateType, uNoUnitsType, N, Tab, kType} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - z₇::uType - k1::kType - k2::kType - k3::kType - k4::kType - k5::kType - k6::kType - k7::kType - atmp::uNoUnitsType - nlsolver::N - tab::Tab -end -@truncate_stacktrace KenCarp47Cache 1 - -function alg_cache(alg::KenCarp47, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = KenCarp47Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - if f isa SplitFunction - k1 = zero(u) - k2 = zero(u) - k3 = zero(u) - k4 = zero(u) - k5 = zero(u) - k6 = zero(u) - k7 = zero(u) - else - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - k5 = nothing - k6 = nothing - k7 = nothing - end - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = zero(u) - z₇ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - KenCarp47Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, z₇, - k1, k2, k3, k4, k5, k6, k7, atmp, nlsolver, tab) -end - -@cache mutable struct KenCarp58ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache(alg::KenCarp58, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = KenCarp58Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - - KenCarp58ConstantCache(nlsolver, tab) -end - -@cache mutable struct KenCarp58Cache{uType, rateType, uNoUnitsType, N, Tab, kType} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - z₇::uType - z₈::uType - k1::kType - k2::kType - k3::kType - k4::kType - k5::kType - k6::kType - k7::kType - k8::kType - atmp::uNoUnitsType - nlsolver::N - tab::Tab -end - -@truncate_stacktrace KenCarp58Cache 1 - -function alg_cache(alg::KenCarp58, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = KenCarp58Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - if f isa SplitFunction - k1 = zero(u) - k2 = zero(u) - k3 = zero(u) - k4 = zero(u) - k5 = zero(u) - k6 = zero(u) - k7 = zero(u) - k8 = zero(u) - else - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - k5 = nothing - k6 = nothing - k7 = nothing - k8 = nothing - end - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = zero(u) - z₇ = zero(u) - z₈ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - KenCarp58Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, z₇, z₈, - k1, k2, k3, k4, k5, k6, k7, k8, atmp, nlsolver, tab) -end diff --git a/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl b/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl deleted file mode 100644 index 0491e70048..0000000000 --- a/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl +++ /dev/null @@ -1,2670 +0,0 @@ -@muladd function perform_step!(integrator, cache::Kvaerno3ConstantCache, - repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - nlsolver = cache.nlsolver - (; γ, a31, a32, a41, a42, a43, btilde1, btilde2, btilde3, btilde4, c3, α31, α32) = cache.tab - alg = unwrap_alg(integrator, true) - - # calculate W - markfirststage!(nlsolver) - - # FSAL Step 1 - nlsolver.z = z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation for guess - nlsolver.z = z₂ = z₁ - - nlsolver.tmp = uprev + γ * z₁ - nlsolver.c = γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - # Guess is from Hermite derivative on z₁ and z₂ - nlsolver.z = z₃ = α31 * z₁ + α32 * z₂ - - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - nlsolver.z = z₄ = a31 * z₁ + a32 * z₂ + γ * z₃ # use yhat as prediction - - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = 1 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₄ - - ################################### Finalize - - if integrator.opts.adaptive - tmp = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ - if isnewton(nlsolver) && alg.smooth_est # From Shampine - integrator.stats.nsolve += 1 - est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) - else - est = tmp - end - atmp = calculate_residuals(est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.fsallast = z₄ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::Kvaerno3Cache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, atmp, nlsolver, step_limiter!) = cache - (; tmp) = nlsolver - (; γ, a31, a32, a41, a42, a43, btilde1, btilde2, btilde3, btilde4, c3, α31, α32) = cache.tab - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - - markfirststage!(nlsolver) - - # FSAL Step 1 - @.. broadcast=false z₁=dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation for guess - copyto!(z₂, z₁) - nlsolver.z = z₂ - - @.. broadcast=false tmp=uprev + γ * z₁ - nlsolver.c = γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - # Guess is from Hermite derivative on z₁ and z₂ - @.. broadcast=false z₃=α31 * z₁ + α32 * z₂ - nlsolver.z = z₃ - - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if cache isa Kvaerno3Cache - @.. broadcast=false z₄=a31 * z₁ + a32 * z₂ + γ * z₃ # use yhat as prediction - elseif cache isa KenCarp3Cache - (; α41, α42) = cache.tab - @.. broadcast=false z₄=α41 * z₁ + α42 * z₂ - end - nlsolver.z = z₄ - - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = 1 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast=false u=tmp + γ * z₄ - - step_limiter!(u, integrator, p, t + dt) - ################################### Finalize - - if integrator.opts.adaptive - @.. broadcast=false tmp=btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ - if isnewton(nlsolver) && alg.smooth_est # From Shampine - est = nlsolver.cache.dz - - linres = dolinsolve(integrator, nlsolver.cache.linsolve; b = _vec(tmp), - linu = _vec(est)) - - integrator.stats.nsolve += 1 - else - est = tmp - end - calculate_residuals!(atmp, est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - @.. broadcast=false integrator.fsallast=z₄ / dt -end - -@muladd function perform_step!(integrator, cache::KenCarp3ConstantCache, - repeat_step = false) - (; t, dt, uprev, u, p) = integrator - nlsolver = cache.nlsolver - (; γ, a31, a32, a41, a42, a43, btilde1, btilde2, btilde3, btilde4, c3, α31, α32, ea21, ea31, ea32, ea41, ea42, ea43, eb1, eb2, eb3, eb4, ebtilde1, ebtilde2, ebtilde3, ebtilde4) = cache.tab - alg = unwrap_alg(integrator, true) - - f2 = nothing - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - # precalculations - γdt = γ * dt - - # calculate W - markfirststage!(nlsolver) - - if integrator.f isa SplitFunction - # Explicit tableau is not FSAL - # Make this not compute on repeat - z₁ = dt * f(uprev, p, t) - else - # FSAL Step 1 - z₁ = dt * integrator.fsalfirst - end - - ##### Step 2 - - # TODO: Add extrapolation for guess - nlsolver.z = z₂ = z₁ - - nlsolver.tmp = uprev + γ * z₁ - - if integrator.f isa SplitFunction - # This assumes the implicit part is cheaper than the explicit part - k1 = dt * integrator.fsalfirst - z₁ - nlsolver.tmp += ea21 * k1 - end - - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - if integrator.f isa SplitFunction - z₃ = z₂ - u = nlsolver.tmp + γ * z₂ - k2 = dt * f2(u, p, t + 2γdt) - integrator.stats.nf2 += 1 - tmp = uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - # Guess is from Hermite derivative on z₁ and z₂ - z₃ = α31 * z₁ + α32 * z₂ - tmp = uprev + a31 * z₁ + a32 * z₂ - end - nlsolver.z = z₃ - nlsolver.tmp = tmp - nlsolver.c = c3 - - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if integrator.f isa SplitFunction - z₄ = z₂ - u = nlsolver.tmp + γ * z₃ - k3 = dt * f2(u, p, t + c3 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea41 * k1 + ea42 * k2 + ea43 * k3 - else - (; α41, α42) = cache.tab - z₄ = α41 * z₁ + α42 * z₂ - tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - end - nlsolver.z = z₄ - nlsolver.c = 1 - nlsolver.tmp = tmp - - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₄ - if integrator.f isa SplitFunction - k4 = dt * f2(u, p, t + dt) - integrator.stats.nf2 += 1 - u = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + γ * z₄ + eb1 * k1 + eb2 * k2 + - eb3 * k3 + eb4 * k4 - end - - ################################### Finalize - - if integrator.opts.adaptive - if integrator.f isa SplitFunction - tmp = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + - ebtilde1 * k1 + ebtilde2 * k2 + ebtilde3 * k3 + ebtilde4 * k4 - else - tmp = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ - end - if isnewton(nlsolver) && alg.smooth_est # From Shampine - integrator.stats.nsolve += 1 - est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) - else - est = tmp - end - atmp = calculate_residuals(est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - if integrator.f isa SplitFunction - integrator.k[1] = integrator.fsalfirst - integrator.fsallast = integrator.f(u, p, t + dt) - integrator.k[2] = integrator.fsallast - else - integrator.fsallast = z₄ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - end - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::KenCarp3Cache, repeat_step = false) - (; t, dt, uprev, u, p) = integrator - (; z₁, z₂, z₃, z₄, k1, k2, k3, k4, atmp, nlsolver, step_limiter!) = cache - (; tmp) = nlsolver - (; γ, a31, a32, a41, a42, a43, btilde1, btilde2, btilde3, btilde4, c3, α31, α32) = cache.tab - (; ea21, ea31, ea32, ea41, ea42, ea43, eb1, eb2, eb3, eb4) = cache.tab - (; ebtilde1, ebtilde2, ebtilde3, ebtilde4) = cache.tab - alg = unwrap_alg(integrator, true) - - f2 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - # precalculations - γdt = γ * dt - - markfirststage!(nlsolver) - - if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail - # Explicit tableau is not FSAL - # Make this not compute on repeat - f(z₁, integrator.uprev, p, integrator.t) - z₁ .*= dt - else - # FSAL Step 1 - @.. broadcast=false z₁=dt * integrator.fsalfirst - end - - ##### Step 2 - - # TODO: Add extrapolation for guess - copyto!(z₂, z₁) - nlsolver.z = z₂ - - @.. broadcast=false tmp=uprev + γ * z₁ - - if integrator.f isa SplitFunction - # This assumes the implicit part is cheaper than the explicit part - @.. broadcast=false k1=dt * integrator.fsalfirst - z₁ - @.. broadcast=false tmp+=ea21 * k1 - end - - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - if integrator.f isa SplitFunction - z₃ .= z₂ - @.. broadcast=false u=tmp + γ * z₂ - f2(k2, u, p, t + 2γdt) - k2 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - # Guess is from Hermite derivative on z₁ and z₂ - @.. broadcast=false z₃=α31 * z₁ + α32 * z₂ - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ - end - nlsolver.z = z₃ - - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if integrator.f isa SplitFunction - z₄ .= z₂ - @.. broadcast=false u=tmp + γ * z₃ - f2(k3, u, p, t + c3 * dt) - k3 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea41 * k1 + - ea42 * k2 + ea43 * k3 - else - (; α41, α42) = cache.tab - @.. broadcast=false z₄=α41 * z₁ + α42 * z₂ - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - end - nlsolver.z = z₄ - - nlsolver.c = 1 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast=false u=tmp + γ * z₄ - if integrator.f isa SplitFunction - f2(k4, u, p, t + dt) - k4 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false u=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + γ * z₄ + eb1 * k1 + - eb2 * k2 + eb3 * k3 + eb4 * k4 - end - - step_limiter!(u, integrator, p, t + dt) - - ################################### Finalize - - if integrator.opts.adaptive - if integrator.f isa SplitFunction - @.. broadcast=false tmp=btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + - btilde4 * z₄ + ebtilde1 * k1 + ebtilde2 * k2 + - ebtilde3 * k3 + ebtilde4 * k4 - else - @.. broadcast=false tmp=btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + - btilde4 * z₄ - end - if isnewton(nlsolver) && alg.smooth_est # From Shampine - est = nlsolver.cache.dz - - linres = dolinsolve(integrator, nlsolver.cache.linsolve; b = _vec(tmp), - linu = _vec(est)) - - integrator.stats.nsolve += 1 - else - est = tmp - end - calculate_residuals!(atmp, est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - if integrator.f isa SplitFunction - integrator.f(integrator.fsallast, u, p, t + dt) - else - @.. broadcast=false integrator.fsallast=z₄ / dt - end -end - -@muladd function perform_step!(integrator, cache::CFNLIRK3ConstantCache, - repeat_step = false) - (; t, dt, uprev, u, p) = integrator - nlsolver = cache.nlsolver - (; γ, a31, a32, a41, a42, a43, c2, c3, ea21, ea31, ea32, ea41, ea42, ea43, eb1, eb2, eb3, eb4) = cache.tab - alg = unwrap_alg(integrator, true) - - f2 = nothing - k1 = nothing - k2 = nothing - k3 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - # precalculations - γdt = γ * dt - - # calculate W - markfirststage!(nlsolver) - - if integrator.f isa SplitFunction - # Explicit tableau is not FSAL - # Make this not compute on repeat - z₁ = dt .* f(uprev, p, t) - else - # FSAL Step 1 - z₁ = dt * integrator.fsalfirst - end - - ##### Step 2 - - # TODO: Add extrapolation for guess - nlsolver.z = z₂ = z₁ - - nlsolver.tmp = uprev - - if integrator.f isa SplitFunction - # This assumes the implicit part is cheaper than the explicit part - k1 = dt .* f2(uprev, p, t) - nlsolver.tmp += ea21 * k1 - end - - nlsolver.c = c2 - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - if integrator.f isa SplitFunction - z₃ = z₂ - u = nlsolver.tmp + γ * z₂ - k2 = dt * f2(u, p, t + c2 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - z₃ = z₂ - tmp = uprev + a31 * z₁ + a32 * z₂ - end - nlsolver.z = z₃ - nlsolver.tmp = tmp - nlsolver.c = c3 - - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if integrator.f isa SplitFunction - z₄ = z₃ - u = nlsolver.tmp + γ * z₃ - k3 = dt * f2(u, p, t + c3 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea41 * k1 + ea42 * k2 + ea43 * k3 - else - z₄ = z₃ - tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - end - nlsolver.z = z₄ - nlsolver.c = 1 - nlsolver.tmp = tmp - - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₄ - if integrator.f isa SplitFunction - k4 = dt * f2(u, p, t + dt) - integrator.stats.nf2 += 1 - u = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + γ * z₄ + eb1 * k1 + eb2 * k2 + - eb3 * k3 + eb4 * k4 - end - - ################################### Finalize - - if integrator.f isa SplitFunction - integrator.k[1] = integrator.fsalfirst - integrator.fsallast = integrator.f(u, p, t + dt) - integrator.k[2] = integrator.fsallast - else - integrator.fsallast = z₄ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - end - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::CFNLIRK3Cache, repeat_step = false) - (; t, dt, uprev, u, p) = integrator - (; z₁, z₂, z₃, z₄, k1, k2, k3, k4, atmp, nlsolver) = cache - (; tmp) = nlsolver - (; γ, a31, a32, a41, a42, a43, c2, c3) = cache.tab - (; ea21, ea31, ea32, ea41, ea42, ea43, eb1, eb2, eb3, eb4) = cache.tab - - alg = unwrap_alg(integrator, true) - - f2 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - # precalculations - γdt = γ * dt - - markfirststage!(nlsolver) - - if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail - f(z₁, integrator.uprev, p, integrator.t) - z₁ .*= dt - else - # FSAL Step 1 - @.. broadcast=false z₁=dt * integrator.fsalfirst - end - - ##### Step 2 - - # TODO: Add extrapolation for guess - copyto!(z₂, z₁) - nlsolver.z = z₂ - - @.. broadcast=false tmp=uprev - - if integrator.f isa SplitFunction - # This assumes the implicit part is cheaper than the explicit part - @.. broadcast=false k1=dt * integrator.fsalfirst - z₁ - @.. broadcast=false tmp+=ea21 * k1 - end - - nlsolver.c = c2 - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - if integrator.f isa SplitFunction - z₃ .= z₂ - @.. broadcast=false u=tmp + γ * z₂ - f2(k2, u, p, t + c2 * dt) - k2 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - @.. broadcast=false z₃=z₂ - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ - end - nlsolver.z = z₃ - - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if integrator.f isa SplitFunction - z₄ .= z₂ - @.. broadcast=false u=tmp + γ * z₃ - f2(k3, u, p, t + c3 * dt) - k3 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea41 * k1 + - ea42 * k2 + ea43 * k3 - else - @.. broadcast=false z₄=z₂ - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - end - nlsolver.z = z₄ - - nlsolver.c = 1 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast=false u=tmp + γ * z₄ - if integrator.f isa SplitFunction - f2(k4, u, p, t + dt) - k4 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false u=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + γ * z₄ + eb1 * k1 + - eb2 * k2 + eb3 * k3 + eb4 * k4 - end - - if integrator.f isa SplitFunction - integrator.f(integrator.fsallast, u, p, t + dt) - else - @.. broadcast=false integrator.fsallast=z₄ / dt - end -end - -@muladd function perform_step!(integrator, cache::Kvaerno4ConstantCache, - repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - nlsolver = cache.nlsolver - (; γ, a31, a32, a41, a42, a43, a51, a52, a53, a54, c3, c4) = cache.tab - (; α21, α31, α32, α41, α42) = cache.tab - (; btilde1, btilde2, btilde3, btilde4, btilde5) = cache.tab - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - - # calculate W - markfirststage!(nlsolver) - - ##### Step 1 - - z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation choice - nlsolver.z = z₂ = zero(u) - - nlsolver.tmp = uprev + γ * z₁ - nlsolver.c = γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - nlsolver.z = z₃ = α31 * z₁ + α32 * z₂ - - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - nlsolver.z = z₄ = α41 * z₁ + α42 * z₂ - - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - # Use yhat2 for prediction - nlsolver.z = z₅ = a41 * z₁ + a42 * z₂ + a43 * z₃ + γ * z₄ - - nlsolver.tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = 1 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₄ - - ################################### Finalize - - if integrator.opts.adaptive - tmp = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ - if isnewton(nlsolver) && alg.smooth_est # From Shampine - integrator.stats.nsolve += 1 - est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) - else - est = tmp - end - atmp = calculate_residuals(est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.fsallast = z₅ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::Kvaerno4Cache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, atmp, nlsolver, step_limiter!) = cache - (; tmp) = nlsolver - (; γ, a31, a32, a41, a42, a43, a51, a52, a53, a54, c3, c4) = cache.tab - (; α21, α31, α32, α41, α42) = cache.tab - (; btilde1, btilde2, btilde3, btilde4, btilde5) = cache.tab - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - - markfirststage!(nlsolver) - - ##### Step 1 - - @.. broadcast=false z₁=dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Allow other choices here - z₂ .= zero(eltype(u)) - nlsolver.z = z₂ - - @.. broadcast=false tmp=uprev + γ * z₁ - nlsolver.c = γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - @.. broadcast=false z₃=α31 * z₁ + α32 * z₂ - nlsolver.z = z₃ - - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - @.. broadcast=false z₄=α41 * z₁ + α42 * z₂ - nlsolver.z = z₄ - - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - # Use yhat prediction - @.. broadcast=false z₅=a41 * z₁ + a42 * z₂ + a43 * z₃ + γ * z₄ - nlsolver.z = z₅ - - @.. broadcast=false tmp=uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = 1 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast=false u=tmp + γ * z₅ - - step_limiter!(u, integrator, p, t + dt) - - ################################### Finalize - - if integrator.opts.adaptive - @.. broadcast=false tmp=btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + - btilde5 * z₅ - if isnewton(nlsolver) && alg.smooth_est # From Shampine - est = nlsolver.cache.dz - - linres = dolinsolve(integrator, nlsolver.cache.linsolve; b = _vec(tmp), - linu = _vec(est)) - - integrator.stats.nsolve += 1 - else - est = tmp - end - calculate_residuals!(atmp, est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - @.. broadcast=false integrator.fsallast=z₅ / dt -end - -@muladd function perform_step!(integrator, cache::KenCarp4ConstantCache, - repeat_step = false) - (; t, dt, uprev, u, p) = integrator - nlsolver = cache.nlsolver - (; γ, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a63, a64, a65, c3, c4, c5) = cache.tab - (; α31, α32, α41, α42, α51, α52, α53, α54, α61, α62, α63, α64, α65) = cache.tab - (; btilde1, btilde3, btilde4, btilde5, btilde6) = cache.tab - (; ea21, ea31, ea32, ea41, ea42, ea43, ea51, ea52, ea53, ea54, ea61, ea62, ea63, ea64, ea65) = cache.tab - (; eb1, eb3, eb4, eb5, eb6) = cache.tab - (; ebtilde1, ebtilde3, ebtilde4, ebtilde5, ebtilde6) = cache.tab - alg = unwrap_alg(integrator, true) - - f2 = nothing - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - k5 = nothing - k6 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - # precalculations - γdt = γ * dt - - # calculate W - markfirststage!(nlsolver) - - if integrator.f isa SplitFunction - # Explicit tableau is not FSAL - # Make this not compute on repeat - z₁ = dt .* f(uprev, p, t) - else - # FSAL Step 1 - z₁ = dt * integrator.fsalfirst - end - - ##### Step 2 - - # TODO: Add extrapolation choice - nlsolver.z = z₂ = z₁ - - tmp = uprev + γ * z₁ - - if integrator.f isa SplitFunction - # This assumes the implicit part is cheaper than the explicit part - k1 = dt * integrator.fsalfirst - z₁ - tmp += ea21 * k1 - end - nlsolver.tmp = tmp - nlsolver.c = 2γ - - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - if integrator.f isa SplitFunction - z₃ = z₂ - u = nlsolver.tmp + γ * z₂ - k2 = dt * f2(u, p, t + 2γdt) - integrator.stats.nf2 += 1 - tmp = uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - # Guess is from Hermite derivative on z₁ and z₂ - z₃ = α31 * z₁ + α32 * z₂ - tmp = uprev + a31 * z₁ + a32 * z₂ - end - nlsolver.z = z₃ - nlsolver.tmp = tmp - nlsolver.c = c3 - - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if integrator.f isa SplitFunction - z₄ = z₂ - u = nlsolver.tmp + γ * z₃ - k3 = dt * f2(u, p, t + c3 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea41 * k1 + ea42 * k2 + ea43 * k3 - else - z₄ = α41 * z₁ + α42 * z₂ - tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - end - nlsolver.z = z₄ - nlsolver.tmp = tmp - nlsolver.c = c4 - - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - if integrator.f isa SplitFunction - z₅ = z₄ - u = nlsolver.tmp + γ * z₄ - k4 = dt * f2(u, p, t + c4 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ + ea51 * k1 + ea52 * k2 + - ea53 * k3 + ea54 * k4 - else - z₅ = α51 * z₁ + α52 * z₂ + α53 * z₃ + α54 * z₄ - tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - end - nlsolver.z = z₅ - nlsolver.tmp = tmp - nlsolver.c = c5 - - u = nlsolver.tmp + γ * z₅ - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - if integrator.f isa SplitFunction - z₆ = z₅ - u = nlsolver.tmp + γ * z₅ - k5 = dt * f2(u, p, t + c5 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ + ea61 * k1 + ea62 * k2 + - ea63 * k3 + ea64 * k4 + ea65 * k5 - else - z₆ = α61 * z₁ + α62 * z₂ + α63 * z₃ + α64 * z₄ + α65 * z₅ - tmp = uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ - end - nlsolver.z = z₆ - nlsolver.tmp = tmp - nlsolver.c = 1 - - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₆ - if integrator.f isa SplitFunction - k6 = dt * f2(u, p, t + dt) - integrator.stats.nf2 += 1 - u = uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ + γ * z₆ + eb1 * k1 + - eb3 * k3 + eb4 * k4 + eb5 * k5 + eb6 * k6 - end - - ################################### Finalize - - if integrator.opts.adaptive - if integrator.f isa SplitFunction - tmp = btilde1 * z₁ + btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + btilde6 * z₆ + - ebtilde1 * k1 + ebtilde3 * k3 + ebtilde4 * k4 + ebtilde5 * k5 + - ebtilde6 * k6 - else - tmp = btilde1 * z₁ + btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + btilde6 * z₆ - end - if isnewton(nlsolver) && alg.smooth_est # From Shampine - integrator.stats.nsolve += 1 - est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) - else - est = tmp - end - atmp = calculate_residuals(est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - if integrator.f isa SplitFunction - integrator.k[1] = integrator.fsalfirst - integrator.fsallast = integrator.f(u, p, t + dt) - integrator.k[2] = integrator.fsallast - else - integrator.fsallast = z₆ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - end - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::KenCarp4Cache, repeat_step = false) - (; t, dt, uprev, u, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, atmp, nlsolver, step_limiter!) = cache - (; tmp) = nlsolver - (; k1, k2, k3, k4, k5, k6) = cache - (; γ, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a63, a64, a65, c3, c4, c5) = cache.tab - (; α31, α32, α41, α42, α51, α52, α53, α54, α61, α62, α63, α64, α65) = cache.tab - (; btilde1, btilde3, btilde4, btilde5, btilde6) = cache.tab - (; ea21, ea31, ea32, ea41, ea42, ea43, ea51, ea52, ea53, ea54, ea61, ea62, ea63, ea64, ea65) = cache.tab - (; eb1, eb3, eb4, eb5, eb6) = cache.tab - (; ebtilde1, ebtilde3, ebtilde4, ebtilde5, ebtilde6) = cache.tab - alg = unwrap_alg(integrator, true) - - f2 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - # precalculations - γdt = γ * dt - - markfirststage!(nlsolver) - - ##### Step 1 - - if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail - # Explicit tableau is not FSAL - # Make this not compute on repeat - f(z₁, integrator.uprev, p, integrator.t) - z₁ .*= dt - else - # FSAL Step 1 - @.. broadcast=false z₁=dt * integrator.fsalfirst - end - - ##### Step 2 - - # TODO: Allow other choices here - copyto!(z₂, z₁) - nlsolver.z = z₂ - - @.. broadcast=false tmp=uprev + γ * z₁ - - if integrator.f isa SplitFunction - # This assumes the implicit part is cheaper than the explicit part - @.. broadcast=false k1=dt * integrator.fsalfirst - z₁ - @.. broadcast=false tmp+=ea21 * k1 - end - - nlsolver.c = 2γ - markfirststage!(nlsolver) - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - if integrator.f isa SplitFunction - z₃ .= z₂ - @.. broadcast=false u=tmp + γ * z₂ - f2(k2, u, p, t + 2γdt) - k2 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - # Guess is from Hermite derivative on z₁ and z₂ - @.. broadcast=false z₃=α31 * z₁ + α32 * z₂ - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ - end - nlsolver.z = z₃ - - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if integrator.f isa SplitFunction - z₄ .= z₂ - @.. broadcast=false u=tmp + γ * z₃ - f2(k3, u, p, t + c3 * dt) - k3 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea41 * k1 + - ea42 * k2 + ea43 * k3 - else - @.. broadcast=false z₄=α41 * z₁ + α42 * z₂ - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - end - nlsolver.z = z₄ - - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - if integrator.f isa SplitFunction - z₅ .= z₄ - @.. broadcast=false u=tmp + γ * z₄ - f2(k4, u, p, t + c4 * dt) - k4 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ + - ea51 * k1 + ea52 * k2 + ea53 * k3 + ea54 * k4 - else - @.. broadcast=false z₅=α51 * z₁ + α52 * z₂ + α53 * z₃ + α54 * z₄ - @.. broadcast=false tmp=uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - end - nlsolver.z = z₅ - - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - if integrator.f isa SplitFunction - z₆ .= z₅ - @.. broadcast=false u=tmp + γ * z₅ - f2(k5, u, p, t + c5 * dt) - k5 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ + - ea61 * k1 + ea62 * k2 + ea63 * k3 + ea64 * k4 + ea65 * k5 - else - @.. broadcast=false z₆=α61 * z₁ + α62 * z₂ + α63 * z₃ + α64 * z₄ + α65 * z₅ - @.. broadcast=false tmp=uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ - end - nlsolver.z = z₆ - - nlsolver.c = 1 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast=false u=tmp + γ * z₆ - if integrator.f isa SplitFunction - f2(k6, u, p, t + dt) - k6 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false u=uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ + γ * z₆ + - eb1 * k1 + eb3 * k3 + eb4 * k4 + eb5 * k5 + eb6 * k6 - end - - step_limiter!(u, integrator, p, t + dt) - ################################### Finalize - - if integrator.opts.adaptive - if integrator.f isa SplitFunction - @.. broadcast=false tmp=btilde1 * z₁ + btilde3 * z₃ + btilde4 * z₄ + - btilde5 * z₅ + btilde6 * z₆ + ebtilde1 * k1 + - ebtilde3 * k3 + ebtilde4 * k4 + ebtilde5 * k5 + - ebtilde6 * k6 - else - @.. broadcast=false tmp=btilde1 * z₁ + btilde3 * z₃ + btilde4 * z₄ + - btilde5 * z₅ + btilde6 * z₆ - end - - if isnewton(nlsolver) && alg.smooth_est # From Shampine - est = nlsolver.cache.dz - - linres = dolinsolve(integrator, nlsolver.cache.linsolve; b = _vec(tmp), - linu = _vec(est)) - - integrator.stats.nsolve += 1 - else - est = tmp - end - calculate_residuals!(atmp, est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - if integrator.f isa SplitFunction - integrator.f(integrator.fsallast, u, p, t + dt) - else - @.. broadcast=false integrator.fsallast=z₆ / dt - end -end - -@muladd function perform_step!(integrator, cache::Kvaerno5ConstantCache, - repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - nlsolver = cache.nlsolver - (; γ, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a63, a64, a65, a71, a73, a74, a75, a76, c3, c4, c5, c6) = cache.tab - (; btilde1, btilde3, btilde4, btilde5, btilde6, btilde7) = cache.tab - (; α31, α32, α41, α42, α43, α51, α52, α53, α61, α62, α63) = cache.tab - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - - # calculate W - markfirststage!(nlsolver) - - ##### Step 1 - - z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation choice - nlsolver.z = z₂ = z₁ - - nlsolver.tmp = uprev + γ * z₁ - nlsolver.c = γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - nlsolver.z = z₃ = α31 * z₁ + α32 * z₂ - - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - nlsolver.z = z₄ = α41 * z₁ + α42 * z₂ + α43 * z₃ - - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - nlsolver.z = z₅ = α51 * z₁ + α52 * z₂ + α53 * z₃ - - nlsolver.tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = z₆ = α61 * z₁ + α62 * z₂ + α63 * z₃ - - nlsolver.tmp = uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - # Prediction from embedding - nlsolver.z = z₇ = a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ + γ * z₆ - - nlsolver.tmp = uprev + a71 * z₁ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - nlsolver.c = 1 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₇ - - ################################### Finalize - - if integrator.opts.adaptive - tmp = btilde1 * z₁ + btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + btilde6 * z₆ + - btilde7 * z₇ - if isnewton(nlsolver) && alg.smooth_est # From Shampine - integrator.stats.nsolve += 1 - est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) - else - est = tmp - end - atmp = calculate_residuals(est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.fsallast = z₇ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::Kvaerno5Cache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, z₇, atmp, nlsolver, step_limiter!) = cache - (; tmp) = nlsolver - (; γ, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a63, a64, a65, a71, a73, a74, a75, a76, c3, c4, c5, c6) = cache.tab - (; btilde1, btilde3, btilde4, btilde5, btilde6, btilde7) = cache.tab - (; α31, α32, α41, α42, α43, α51, α52, α53, α61, α62, α63) = cache.tab - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - - markfirststage!(nlsolver) - - ##### Step 1 - - @.. broadcast=false z₁=dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Allow other choices here - copyto!(z₂, z₁) - nlsolver.z = z₂ - - @.. broadcast=false tmp=uprev + γ * z₁ - nlsolver.c = γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - @.. broadcast=false z₃=α31 * z₁ + α32 * z₂ - nlsolver.z = z₃ - - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - @.. broadcast=false z₄=α41 * z₁ + α42 * z₂ + α43 * z₃ - nlsolver.z = z₄ - - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - @.. broadcast=false z₅=α51 * z₁ + α52 * z₂ + α53 * z₃ - nlsolver.z = z₅ - - @.. broadcast=false tmp=uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - @.. broadcast=false z₆=α61 * z₁ + α62 * z₂ + α63 * z₃ - nlsolver.z = z₆ - - @.. broadcast=false tmp=uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - # Prediction is embedded method - @.. broadcast=false z₇=a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ + γ * z₆ - nlsolver.z = z₇ - - @.. broadcast=false tmp=uprev + a71 * z₁ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - nlsolver.c = 1 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast=false u=tmp + γ * z₇ - - step_limiter!(u, integrator, p, t + dt) - ################################### Finalize - - if integrator.opts.adaptive - @.. broadcast=false tmp=btilde1 * z₁ + btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ - if isnewton(nlsolver) && alg.smooth_est # From Shampine - est = nlsolver.cache.dz - - linres = dolinsolve(integrator, nlsolver.cache.linsolve; b = _vec(tmp), - linu = _vec(est)) - - integrator.stats.nsolve += 1 - else - est = tmp - end - calculate_residuals!(atmp, est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - @.. broadcast=false integrator.fsallast=z₇ / dt -end - -@muladd function perform_step!(integrator, cache::KenCarp5ConstantCache, - repeat_step = false) - (; t, dt, uprev, u, p) = integrator - nlsolver = cache.nlsolver - (; γ, a31, a32, a41, a43, a51, a53, a54, a61, a63, a64, a65, a71, a73, a74, a75, a76, a81, a84, a85, a86, a87, c3, c4, c5, c6, c7) = cache.tab - (; α31, α32, α41, α42, α51, α52, α61, α62, α71, α72, α73, α74, α75, α81, α82, α83, α84, α85) = cache.tab - (; btilde1, btilde4, btilde5, btilde6, btilde7, btilde8) = cache.tab - (; ea21, ea31, ea32, ea41, ea43, ea51, ea53, ea54, ea61, ea63, ea64, ea65) = cache.tab - (; ea71, ea73, ea74, ea75, ea76, ea81, ea83, ea84, ea85, ea86, ea87) = cache.tab - (; eb1, eb4, eb5, eb6, eb7, eb8) = cache.tab - (; ebtilde1, ebtilde4, ebtilde5, ebtilde6, ebtilde7, ebtilde8) = cache.tab - alg = unwrap_alg(integrator, true) - - f2 = nothing - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - k5 = nothing - k6 = nothing - k7 = nothing - k8 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - # precalculations - γdt = γ * dt - - # calculate W - markfirststage!(nlsolver) - - ##### Step 1 - - if integrator.f isa SplitFunction - # Explicit tableau is not FSAL - # Make this not compute on repeat - z₁ = dt .* f(uprev, p, t) - else - # FSAL Step 1 - z₁ = dt * integrator.fsalfirst - end - - ##### Step 2 - - # TODO: Add extrapolation choice - nlsolver.z = z₂ = z₁ - - tmp = uprev + γ * z₁ - - if integrator.f isa SplitFunction - # This assumes the implicit part is cheaper than the explicit part - k1 = dt * integrator.fsalfirst - z₁ - tmp += ea21 * k1 - end - nlsolver.tmp = tmp - nlsolver.c = 2γ - - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - if integrator.f isa SplitFunction - z₃ = z₂ - u = nlsolver.tmp + γ * z₂ - k2 = dt * f2(u, p, t + 2γdt) - integrator.stats.nf2 += 1 - tmp = uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - # Guess is from Hermite derivative on z₁ and z₂ - z₃ = α31 * z₁ + α32 * z₂ - tmp = uprev + a31 * z₁ + a32 * z₂ - end - nlsolver.z = z₃ - nlsolver.c = c3 - nlsolver.tmp = tmp - - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if integrator.f isa SplitFunction - z₄ = z₂ - u = nlsolver.tmp + γ * z₃ - k3 = dt * f2(u, p, t + c3 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a41 * z₁ + a43 * z₃ + ea41 * k1 + ea43 * k3 - else - z₄ = α41 * z₁ + α42 * z₂ - tmp = uprev + a41 * z₁ + a43 * z₃ - end - nlsolver.z = z₄ - nlsolver.c = c4 - nlsolver.tmp = tmp - - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - if integrator.f isa SplitFunction - z₅ = z₂ - u = nlsolver.tmp + γ * z₄ - k4 = dt * f2(u, p, t + c4 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a51 * z₁ + a53 * z₃ + a54 * z₄ + ea51 * k1 + ea53 * k3 + ea54 * k4 - else - z₅ = α51 * z₁ + α52 * z₂ - tmp = uprev + a51 * z₁ + a53 * z₃ + a54 * z₄ - end - nlsolver.z = z₅ - nlsolver.c = c5 - nlsolver.tmp = tmp - - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - if integrator.f isa SplitFunction - z₆ = z₃ - u = nlsolver.tmp + γ * z₅ - k5 = dt * f2(u, p, t + c5 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ + ea61 * k1 + ea63 * k3 + - ea64 * k4 + ea65 * k5 - else - z₆ = α61 * z₁ + α62 * z₂ - tmp = uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ - end - nlsolver.z = z₆ - nlsolver.c = c6 - nlsolver.tmp = tmp - - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - if integrator.f isa SplitFunction - z₇ = z₂ - u = nlsolver.tmp + γ * z₆ - k6 = dt * f2(u, p, t + c6 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a71 * z₁ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ + ea71 * k1 + - ea73 * k3 + ea74 * k4 + ea75 * k5 + ea76 * k6 - else - z₇ = α71 * z₁ + α72 * z₂ + α73 * z₃ + α74 * z₄ + α75 * z₅ - tmp = uprev + a71 * z₁ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - end - nlsolver.z = z₇ - nlsolver.c = c7 - nlsolver.tmp = tmp - - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 8 - - if integrator.f isa SplitFunction - z₈ = z₅ - u = nlsolver.tmp + γ * z₇ - k7 = dt * f2(u, p, t + c7 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a81 * z₁ + a84 * z₄ + a85 * z₅ + a86 * z₆ + a87 * z₇ + ea81 * k1 + - ea83 * k3 + ea84 * k4 + ea85 * k5 + ea86 * k6 + ea87 * k7 - else - z₈ = α81 * z₁ + α82 * z₂ + α83 * z₃ + α84 * z₄ + α85 * z₅ - tmp = uprev + a81 * z₁ + a84 * z₄ + a85 * z₅ + a86 * z₆ + a87 * z₇ - end - nlsolver.z = z₈ - nlsolver.c = 1 - nlsolver.tmp = tmp - - z₈ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₈ - if integrator.f isa SplitFunction - k8 = dt * f2(u, p, t + dt) - integrator.stats.nf2 += 1 - u = uprev + a81 * z₁ + a84 * z₄ + a85 * z₅ + a86 * z₆ + a87 * z₇ + γ * z₈ + - eb1 * k1 + eb4 * k4 + eb5 * k5 + eb6 * k6 + eb7 * k7 + eb8 * k8 - end - - ################################### Finalize - - if integrator.opts.adaptive - if integrator.f isa SplitFunction - tmp = btilde1 * z₁ + btilde4 * z₄ + btilde5 * z₅ + btilde6 * z₆ + btilde7 * z₇ + - btilde8 * z₈ + ebtilde1 * k1 + ebtilde4 * k4 + ebtilde5 * k5 + - ebtilde6 * k6 + ebtilde7 * k7 + ebtilde8 * k8 - else - tmp = btilde1 * z₁ + btilde4 * z₄ + btilde5 * z₅ + btilde6 * z₆ + btilde7 * z₇ + - btilde8 * z₈ - end - if isnewton(nlsolver) && alg.smooth_est # From Shampine - integrator.stats.nsolve += 1 - est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) - else - est = tmp - end - atmp = calculate_residuals(est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - if integrator.f isa SplitFunction - integrator.k[1] = integrator.fsalfirst - integrator.fsallast = integrator.f(u, p, t + dt) - integrator.k[2] = integrator.fsallast - else - integrator.fsallast = z₈ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - end - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::KenCarp5Cache, repeat_step = false) - (; t, dt, uprev, u, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, z₇, z₈, atmp, nlsolver, step_limiter!) = cache - (; k1, k2, k3, k4, k5, k6, k7, k8) = cache - (; tmp) = nlsolver - (; γ, a31, a32, a41, a43, a51, a53, a54, a61, a63, a64, a65, a71, a73, a74, a75, a76, a81, a84, a85, a86, a87, c3, c4, c5, c6, c7) = cache.tab - (; α31, α32, α41, α42, α51, α52, α61, α62, α71, α72, α73, α74, α75, α81, α82, α83, α84, α85) = cache.tab - (; btilde1, btilde4, btilde5, btilde6, btilde7, btilde8) = cache.tab - (; ea21, ea31, ea32, ea41, ea43, ea51, ea53, ea54, ea61, ea63, ea64, ea65) = cache.tab - (; ea71, ea73, ea74, ea75, ea76, ea81, ea83, ea84, ea85, ea86, ea87) = cache.tab - (; eb1, eb4, eb5, eb6, eb7, eb8) = cache.tab - (; ebtilde1, ebtilde4, ebtilde5, ebtilde6, ebtilde7, ebtilde8) = cache.tab - alg = unwrap_alg(integrator, true) - - f2 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - # precalculations - γdt = γ * dt - - markfirststage!(nlsolver) - - ##### Step 1 - - if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail - # Explicit tableau is not FSAL - # Make this not compute on repeat - f(z₁, integrator.uprev, p, integrator.t) - z₁ .*= dt - else - # FSAL Step 1 - @.. broadcast=false z₁=dt * integrator.fsalfirst - end - - ##### Step 2 - - # TODO: Allow other choices here - copyto!(z₂, z₁) - nlsolver.z = z₂ - - @.. broadcast=false tmp=uprev + γ * z₁ - - if integrator.f isa SplitFunction - # This assumes the implicit part is cheaper than the explicit part - @.. broadcast=false k1=dt * integrator.fsalfirst - z₁ - @.. broadcast=false tmp+=ea21 * k1 - end - - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - if integrator.f isa SplitFunction - z₃ .= z₂ - @.. broadcast=false u=tmp + γ * z₂ - f2(k2, u, p, t + 2γdt) - k2 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - # Guess is from Hermite derivative on z₁ and z₂ - @.. broadcast=false z₃=a31 * z₁ + α32 * z₂ - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ - end - nlsolver.z = z₃ - - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if integrator.f isa SplitFunction - z₄ .= z₃ - @.. broadcast=false u=tmp + γ * z₃ - f2(k3, u, p, t + c3 * dt) - k3 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a41 * z₁ + a43 * z₃ + ea41 * k1 + ea43 * k3 - else - @.. broadcast=false z₄=α41 * z₁ + α42 * z₂ - @.. broadcast=false tmp=uprev + a41 * z₁ + a43 * z₃ - end - nlsolver.z = z₄ - - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - if integrator.f isa SplitFunction - z₅ .= z₂ - @.. broadcast=false u=tmp + γ * z₄ - f2(k4, u, p, t + c4 * dt) - k4 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a51 * z₁ + a53 * z₃ + a54 * z₄ + ea51 * k1 + - ea53 * k3 + ea54 * k4 - else - @.. broadcast=false z₅=α51 * z₁ + α52 * z₂ - @.. broadcast=false tmp=uprev + a51 * z₁ + a53 * z₃ + a54 * z₄ - end - nlsolver.z = z₅ - - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - if integrator.f isa SplitFunction - z₆ .= z₃ - @.. broadcast=false u=tmp + γ * z₅ - f2(k5, u, p, t + c5 * dt) - k5 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ + - ea61 * k1 + ea63 * k3 + ea64 * k4 + ea65 * k5 - else - @.. broadcast=false z₆=α61 * z₁ + α62 * z₂ - @.. broadcast=false tmp=uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ - end - nlsolver.z = z₆ - - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - if integrator.f isa SplitFunction - z₇ .= z₂ - @.. broadcast=false u=tmp + γ * z₆ - f2(k6, u, p, t + c6 * dt) - k6 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a71 * z₁ + a73 * z₃ + a74 * z₄ + a75 * z₅ + - a76 * z₆ + ea71 * k1 + ea73 * k3 + ea74 * k4 + ea75 * k5 + - ea76 * k6 - else - @.. broadcast=false z₇=α71 * z₁ + α72 * z₂ + α73 * z₃ + α74 * z₄ + α75 * z₅ - @.. broadcast=false tmp=uprev + a71 * z₁ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - end - nlsolver.z = z₇ - - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 8 - - if integrator.f isa SplitFunction - z₈ .= z₅ - @.. broadcast=false u=tmp + γ * z₇ - f2(k7, u, p, t + c7 * dt) - k7 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a81 * z₁ + a84 * z₄ + a85 * z₅ + a86 * z₆ + - a87 * z₇ + ea81 * k1 + ea83 * k3 + ea84 * k4 + ea85 * k5 + - ea86 * k6 + ea87 * k7 - else - @.. broadcast=false z₈=α81 * z₁ + α82 * z₂ + α83 * z₃ + α84 * z₄ + α85 * z₅ - @.. broadcast=false tmp=uprev + a81 * z₁ + a84 * z₄ + a85 * z₅ + a86 * z₆ + a87 * z₇ - end - nlsolver.z = z₈ - - nlsolver.c = 1 - z₈ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast=false u=tmp + γ * z₈ - if integrator.f isa SplitFunction - f2(k8, u, p, t + dt) - k8 .*= dt - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - @.. broadcast=false u=uprev + a81 * z₁ + a84 * z₄ + a85 * z₅ + a86 * z₆ + a87 * z₇ + - γ * z₈ + eb1 * k1 + eb4 * k4 + eb5 * k5 + eb6 * k6 + - eb7 * k7 + eb8 * k8 - end - - step_limiter!(u, integrator, p, t + dt) - ################################### Finalize - - if integrator.opts.adaptive - if integrator.f isa SplitFunction - @.. broadcast=false tmp=btilde1 * z₁ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ + btilde8 * z₈ + - ebtilde1 * k1 + ebtilde4 * k4 + ebtilde5 * k5 + - ebtilde6 * k6 + ebtilde7 * k7 + ebtilde8 * k8 - else - @.. broadcast=false tmp=btilde1 * z₁ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ + btilde8 * z₈ - end - - if isnewton(nlsolver) && alg.smooth_est # From Shampine - est = nlsolver.cache.dz - - linres = dolinsolve(integrator, nlsolver.cache.linsolve; b = _vec(tmp), - linu = _vec(est)) - - integrator.stats.nsolve += 1 - else - est = tmp - end - calculate_residuals!(atmp, est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - if integrator.f isa SplitFunction - integrator.f(integrator.fsallast, u, p, t + dt) - else - @.. broadcast=false integrator.fsallast=z₈ / dt - end -end - -@muladd function perform_step!(integrator, cache::KenCarp47ConstantCache, - repeat_step = false) - (; t, dt, uprev, u, p) = integrator - nlsolver = cache.nlsolver - (; γ, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, a73, a74, a75, a76, c3, c4, c5, c6) = cache.tab - (; α31, α32, α41, α42, α43, α51, α52, α61, α62, α63, α71, α72, α73, α74, α75, α76) = cache.tab - (; btilde3, btilde4, btilde5, btilde6, btilde7) = cache.tab - (; ea21, ea31, ea32, ea41, ea42, ea43, ea51, ea52, ea53, ea54, ea61, ea62, ea63, ea64, ea65, ea71, ea72, ea73, ea74, ea75, ea76) = cache.tab - (; eb3, eb4, eb5, eb6, eb7) = cache.tab - (; ebtilde3, ebtilde4, ebtilde5, ebtilde6, ebtilde7) = cache.tab - alg = unwrap_alg(integrator, true) - - f2 = nothing - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - k5 = nothing - k6 = nothing - k7 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - # precalculations - γdt = γ * dt - - # calculate W - markfirststage!(nlsolver) - - ##### Step 1 - - if integrator.f isa SplitFunction - # Explicit tableau is not FSAL - # Make this not compute on repeat - z₁ = dt .* f(uprev, p, t) - else - # FSAL Step 1 - z₁ = dt * integrator.fsalfirst - end - - ##### Step 2 - - # TODO: Add extrapolation choice - nlsolver.z = z₂ = z₁ - - tmp = uprev + γ * z₁ - - if integrator.f isa SplitFunction - # This assumes the implicit part is cheaper than the explicit part - k1 = dt * integrator.fsalfirst - z₁ - tmp += ea21 * k1 - end - nlsolver.tmp = tmp - nlsolver.c = 2γ - - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - if integrator.f isa SplitFunction - z₃ = z₂ - u = nlsolver.tmp + γ * z₂ - k2 = dt * f2(u, p, t + 2γdt) - integrator.stats.nf2 += 1 - tmp = uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - # Guess is from Hermite derivative on z₁ and z₂ - z₃ = α31 * z₁ + α32 * z₂ - tmp = uprev + a31 * z₁ + a32 * z₂ - end - nlsolver.z = z₃ - nlsolver.tmp = tmp - nlsolver.c = c3 - - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if integrator.f isa SplitFunction - z₄ = z₃ - u = nlsolver.tmp + γ * z₃ - k3 = dt * f2(u, p, t + c3 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea41 * k1 + ea42 * k2 + ea43 * k3 - else - z₄ = α41 * z₁ + α42 * z₂ + α43 * z₃ - tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - end - nlsolver.z = z₄ - nlsolver.tmp = tmp - nlsolver.c = c4 - - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - if integrator.f isa SplitFunction - z₅ = z₁ - u = nlsolver.tmp + γ * z₄ - k4 = dt * f2(u, p, t + c4 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ + ea51 * k1 + ea52 * k2 + - ea53 * k3 + ea54 * k4 - else - z₅ = α51 * z₁ + α52 * z₂ - tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - end - nlsolver.z = z₅ - nlsolver.tmp = tmp - nlsolver.c = c5 - - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - if integrator.f isa SplitFunction - z₆ = z₃ - u = nlsolver.tmp + γ * z₅ - k5 = dt * f2(u, p, t + c5 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ + ea61 * k1 + - ea62 * k2 + ea63 * k3 + ea64 * k4 + ea65 * k5 - else - z₆ = α61 * z₁ + α62 * z₂ + α63 * z₃ - tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - end - nlsolver.z = z₆ - nlsolver.tmp = tmp - nlsolver.c = c6 - - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - if integrator.f isa SplitFunction - z₇ = z₆ - u = nlsolver.tmp + γ * z₆ - k6 = dt * f2(u, p, t + c6 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ + ea71 * k1 + ea72 * k2 + - ea73 * k3 + ea74 * k4 + ea75 * k5 + ea76 * k6 - else - z₇ = α71 * z₁ + α72 * z₂ + α73 * z₃ + α74 * z₄ + α75 * z₅ + +α76 * z₆ - tmp = uprev + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - end - nlsolver.z = z₇ - nlsolver.c = 1 - nlsolver.tmp = tmp - - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₇ - if integrator.f isa SplitFunction - k7 = dt * f2(u, p, t + dt) - u = uprev + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ + γ * z₇ + eb3 * k3 + - eb4 * k4 + eb5 * k5 + eb6 * k6 + eb7 * k7 - end - - ################################### Finalize - - if integrator.opts.adaptive - if integrator.f isa SplitFunction - tmp = btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + btilde6 * z₆ + btilde7 * z₇ + - ebtilde3 * k3 + ebtilde4 * k4 + ebtilde5 * k5 + ebtilde6 * k6 + - ebtilde7 * k7 - else - tmp = btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + btilde6 * z₆ + btilde7 * z₇ - end - if isnewton(nlsolver) && alg.smooth_est # From Shampine - integrator.stats.nsolve += 1 - est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) - else - est = tmp - end - atmp = calculate_residuals(est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - if integrator.f isa SplitFunction - integrator.k[1] = integrator.fsalfirst - integrator.fsallast = integrator.f(u, p, t + dt) - integrator.k[2] = integrator.fsallast - else - integrator.fsallast = z₇ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - end - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::KenCarp47Cache, repeat_step = false) - (; t, dt, uprev, u, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, z₇, atmp, nlsolver) = cache - (; k1, k2, k3, k4, k5, k6, k7) = cache - (; tmp) = nlsolver - (; γ, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, a73, a74, a75, a76, c3, c4, c5, c6) = cache.tab - (; α31, α32, α41, α42, α43, α51, α52, α61, α62, α63, α71, α72, α73, α74, α75, α76) = cache.tab - (; btilde3, btilde4, btilde5, btilde6, btilde7) = cache.tab - (; ea21, ea31, ea32, ea41, ea42, ea43, ea51, ea52, ea53, ea54, ea61, ea62, ea63, ea64, ea65, ea71, ea72, ea73, ea74, ea75, ea76) = cache.tab - (; eb3, eb4, eb5, eb6, eb7) = cache.tab - (; ebtilde3, ebtilde4, ebtilde5, ebtilde6, ebtilde7) = cache.tab - alg = unwrap_alg(integrator, true) - - f2 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - # precalculations - γdt = γ * dt - - markfirststage!(nlsolver) - - ##### Step 1 - - if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail - # Explicit tableau is not FSAL - # Make this not compute on repeat - f(z₁, integrator.uprev, p, integrator.t) - z₁ .*= dt - else - # FSAL Step 1 - @.. broadcast=false z₁=dt * integrator.fsalfirst - end - - ##### Step 2 - - # TODO: Allow other choices here - z₂ .= z₁ - nlsolver.z = z₂ - - @.. broadcast=false tmp=uprev + γ * z₁ - - if integrator.f isa SplitFunction - # This assumes the implicit part is cheaper than the explicit part - @.. broadcast=false k1=dt * integrator.fsalfirst - z₁ - @.. broadcast=false tmp+=ea21 * k1 - end - - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - if integrator.f isa SplitFunction - z₃ .= z₂ - @.. broadcast=false u=tmp + γ * z₂ - f2(k2, u, p, t + 2γdt) - k2 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - #Guess is from Hermite derivative on z₁ and z₂ - @.. broadcast=false z₃=a31 * z₁ + α32 * z₂ - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ - end - nlsolver.z = z₃ - - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if integrator.f isa SplitFunction - z₄ .= z₃ - @.. broadcast=false u=tmp + γ * z₃ - f2(k3, u, p, t + c3 * dt) - k3 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea41 * k1 + - ea42 * k2 + ea43 * k3 - else - @.. broadcast=false z₄=α41 * z₁ + α42 * z₂ + α43 * z₃ - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - end - nlsolver.z = z₄ - - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - if integrator.f isa SplitFunction - z₅ .= z₁ - @.. broadcast=false u=tmp + γ * z₄ - f2(k4, u, p, t + c4 * dt) - k4 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ + - ea51 * k1 + ea52 * k2 + ea53 * k3 + ea54 * k4 - else - @.. broadcast=false z₅=α51 * z₁ + α52 * z₂ - @.. broadcast=false tmp=uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - end - nlsolver.z = z₅ - - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - if integrator.f isa SplitFunction - z₆ .= z₃ - @.. broadcast=false u=tmp + γ * z₅ - f2(k5, u, p, t + c5 * dt) - k5 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + - a65 * z₅ + ea61 * k1 + ea62 * k2 + ea63 * k3 + ea64 * k4 + - ea65 * k5 - else - @.. broadcast=false z₆=α61 * z₁ + α62 * z₂ + α63 * z₃ - @.. broadcast=false tmp=uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - end - nlsolver.z = z₆ - - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - if integrator.f isa SplitFunction - z₇ .= z₆ - @.. broadcast=false u=tmp + γ * z₆ - f2(k6, u, p, t + c6 * dt) - k6 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ + - ea71 * k1 + ea72 * k2 + ea73 * k3 + ea74 * k4 + ea75 * k5 + - ea76 * k6 - else - @.. broadcast=false z₇=α71 * z₁ + α72 * z₂ + α73 * z₃ + α74 * z₄ + α75 * z₅ + - α76 * z₆ - @.. broadcast=false tmp=uprev + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - end - nlsolver.z = z₇ - - nlsolver.c = 1 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast=false u=tmp + γ * z₇ - if integrator.f isa SplitFunction - f2(k7, u, p, t + dt) - k7 .*= dt - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - @.. broadcast=false u=uprev + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ + γ * z₇ + - eb3 * k3 + eb4 * k4 + eb5 * k5 + eb6 * k6 + eb7 * k7 - end - - ################################### Finalize - - if integrator.opts.adaptive - if integrator.f isa SplitFunction - @.. broadcast=false tmp=btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ + ebtilde3 * k3 + - ebtilde4 * k4 + ebtilde5 * k5 + ebtilde6 * k6 + - ebtilde7 * k7 - else - @.. broadcast=false tmp=btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ - end - - if isnewton(nlsolver) && alg.smooth_est # From Shampine - est = nlsolver.cache.dz - - linres = dolinsolve(integrator, nlsolver.cache.linsolve; b = _vec(tmp), - linu = _vec(est)) - - integrator.stats.nsolve += 1 - else - est = tmp - end - calculate_residuals!(atmp, est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - if integrator.f isa SplitFunction - integrator.f(integrator.fsallast, u, p, t + dt) - else - @.. broadcast=false integrator.fsallast=z₇ / dt - end -end - -@muladd function perform_step!(integrator, cache::KenCarp58ConstantCache, - repeat_step = false) - (; t, dt, uprev, u, p) = integrator - nlsolver = cache.nlsolver - (; γ, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76, a83, a84, a85, a86, a87, c3, c4, c5, c6, c7) = cache.tab - (; α31, α32, α41, α42, α51, α52, α61, α62, α63, α71, α72, α73, α81, α82, α83, α84, α85, α86, α87) = cache.tab - (; btilde3, btilde4, btilde5, btilde6, btilde7, btilde8) = cache.tab - (; ea21, ea31, ea32, ea41, ea42, ea43, ea51, ea52, ea53, ea54, ea61, ea62, ea63, ea64, ea65) = cache.tab - (; ea71, ea72, ea73, ea74, ea75, ea76, ea81, ea82, ea83, ea84, ea85, ea86, ea87) = cache.tab - (; eb3, eb4, eb5, eb6, eb7, eb8) = cache.tab - (; ebtilde3, ebtilde4, ebtilde5, ebtilde6, ebtilde7, ebtilde8) = cache.tab - alg = unwrap_alg(integrator, true) - - f2 = nothing - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - k5 = nothing - k6 = nothing - k7 = nothing - k8 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - # precalculations - γdt = γ * dt - - # calculate W - markfirststage!(nlsolver) - - ##### Step 1 - - if integrator.f isa SplitFunction - # Explicit tableau is not FSAL - # Make this not compute on repeat - z₁ = dt .* f(uprev, p, t) - else - # FSAL Step 1 - z₁ = dt * integrator.fsalfirst - end - - ##### Step 2 - - # TODO: Add extrapolation choice - - nlsolver.z = z₂ = z₁ - - tmp = uprev + γ * z₁ - - if integrator.f isa SplitFunction - # This assumes the implicit part is cheaper than the explicit part - k1 = dt * integrator.fsalfirst - z₁ - tmp += ea21 * k1 - end - nlsolver.tmp = tmp - nlsolver.c = 2γ - - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - if integrator.f isa SplitFunction - z₃ = z₂ - u = nlsolver.tmp + γ * z₂ - k2 = dt * f2(u, p, t + 2γdt) - integrator.stats.nf2 += 1 - tmp = uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - # Guess is from Hermite derivative on z₁ and z₂ - z₃ = α31 * z₁ + α32 * z₂ - tmp = uprev + a31 * z₁ + a32 * z₂ - end - nlsolver.z = z₃ - nlsolver.c = c3 - nlsolver.tmp = tmp - - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if integrator.f isa SplitFunction - z₄ = z₁ - u = nlsolver.tmp + γ * z₃ - k3 = dt * f2(u, p, t + c3 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea41 * k1 + ea42 * k2 + ea43 * k3 - else - z₄ = α41 * z₁ + α42 * z₂ - tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - end - nlsolver.z = z₄ - nlsolver.c = c4 - nlsolver.tmp = tmp - - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - if integrator.f isa SplitFunction - z₅ = z₂ - u = nlsolver.tmp + γ * z₄ - k4 = dt * f2(u, p, t + c4 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ + ea51 * k1 + ea52 * k2 + - ea53 * k3 + ea54 * k4 - else - z₅ = α51 * z₁ + α52 * z₂ - tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - end - nlsolver.z = z₅ - nlsolver.c = c5 - nlsolver.tmp = tmp - - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - if integrator.f isa SplitFunction - z₆ = z₃ - u = nlsolver.tmp + γ * z₅ - k5 = dt * f2(u, p, t + c5 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ + ea61 * k1 + - ea62 * k2 + ea63 * k3 + ea64 * k4 + ea65 * k5 - else - z₆ = α61 * z₁ + α62 * z₂ + α63 * z₃ - tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - end - nlsolver.z = z₆ - nlsolver.c = c6 - nlsolver.tmp = tmp - - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - if integrator.f isa SplitFunction - z₇ = z₃ - u = nlsolver.tmp + γ * z₆ - k6 = dt * f2(u, p, t + c6 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ + - ea71 * k1 + ea72 * k2 + ea73 * k3 + ea74 * k4 + ea75 * k5 + ea76 * k6 - else - z₇ = α71 * z₁ + α72 * z₂ + α73 * z₃ - tmp = uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - end - nlsolver.z = z₇ - nlsolver.c = c7 - nlsolver.tmp = tmp - - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 8 - - if integrator.f isa SplitFunction - z₈ = z₇ - u = nlsolver.tmp + γ * z₇ - k7 = dt * f2(u, p, t + c7 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a83 * z₃ + a84 * z₄ + a85 * z₅ + a86 * z₆ + a87 * z₇ + ea81 * k1 + - ea82 * k2 + ea83 * k3 + ea84 * k4 + ea85 * k5 + ea86 * k6 + ea87 * k7 - else - z₈ = α81 * z₁ + α82 * z₂ + α83 * z₃ + α84 * z₄ + α85 * z₅ + α86 * z₆ + α87 * z₇ - tmp = uprev + a83 * z₃ + a84 * z₄ + a85 * z₅ + a86 * z₆ + a87 * z₇ - end - nlsolver.z = z₈ - nlsolver.c = 1 - nlsolver.tmp = tmp - - z₈ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₈ - if integrator.f isa SplitFunction - k8 = dt * f2(u, p, t + dt) - integrator.stats.nf2 += 1 - u = uprev + a83 * z₃ + a84 * z₄ + a85 * z₅ + a86 * z₆ + a87 * z₇ + γ * z₈ + - eb3 * k3 + eb4 * k4 + eb5 * k5 + eb6 * k6 + eb7 * k7 + eb8 * k8 - end - - ################################### Finalize - - if integrator.opts.adaptive - if integrator.f isa SplitFunction - tmp = btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + btilde6 * z₆ + btilde7 * z₇ + - btilde8 * z₈ + ebtilde3 * k3 + ebtilde4 * k4 + ebtilde5 * k5 + - ebtilde6 * k6 + ebtilde7 * k7 + ebtilde8 * k8 - else - tmp = btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + btilde6 * z₆ + btilde7 * z₇ + - btilde8 * z₈ - end - if isnewton(nlsolver) && alg.smooth_est # From Shampine - integrator.stats.nsolve += 1 - est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) - else - est = tmp - end - atmp = calculate_residuals(est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - if integrator.f isa SplitFunction - integrator.k[1] = integrator.fsalfirst - integrator.fsallast = integrator.f(u, p, t + dt) - integrator.k[2] = integrator.fsallast - else - integrator.fsallast = z₈ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - end - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::KenCarp58Cache, repeat_step = false) - (; t, dt, uprev, u, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, z₇, z₈, atmp, nlsolver) = cache - (; k1, k2, k3, k4, k5, k6, k7, k8) = cache - (; tmp) = nlsolver - (; γ, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76, a83, a84, a85, a86, a87, c3, c4, c5, c6, c7) = cache.tab - (; α31, α32, α41, α42, α51, α52, α61, α62, α63, α71, α72, α73, α81, α82, α83, α84, α85, α86, α87) = cache.tab - (; btilde3, btilde4, btilde5, btilde6, btilde7, btilde8) = cache.tab - (; ea21, ea31, ea32, ea41, ea42, ea43, ea51, ea52, ea53, ea54, ea61, ea62, ea63, ea64, ea65) = cache.tab - (; ea71, ea72, ea73, ea74, ea75, ea76, ea81, ea82, ea83, ea84, ea85, ea86, ea87) = cache.tab - (; eb3, eb4, eb5, eb6, eb7, eb8) = cache.tab - (; ebtilde3, ebtilde4, ebtilde5, ebtilde6, ebtilde7, ebtilde8) = cache.tab - alg = unwrap_alg(integrator, true) - - f2 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - # precalculations - γdt = γ * dt - - markfirststage!(nlsolver) - - ##### Step 1 - - if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail - # Explicit tableau is not FSAL - # Make this not compute on repeat - f(z₁, integrator.uprev, p, integrator.t) - z₁ .*= dt - else - # FSAL Step 1 - @.. broadcast=false z₁=dt * integrator.fsalfirst - end - - ##### Step 2 - - # TODO: Allow other choices here - z₂ .= z₁ - nlsolver.z = z₂ - - @.. broadcast=false tmp=uprev + γ * z₁ - - if integrator.f isa SplitFunction - # This assumes the implicit part is cheaper than the explicit part - @.. broadcast=false k1=dt * integrator.fsalfirst - z₁ - @.. broadcast=false tmp+=ea21 * k1 - end - - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - if integrator.f isa SplitFunction - z₃ .= z₂ - @.. broadcast=false u=tmp + γ * z₂ - f2(k2, u, p, t + 2γdt) - k2 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - # Guess is from Hermite derivative on z₁ and z₂ - @.. broadcast=false z₃=α31 * z₁ + α32 * z₂ - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ - end - nlsolver.z = z₃ - - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if integrator.f isa SplitFunction - z₄ .= z₁ - @.. broadcast=false u=tmp + γ * z₃ - f2(k3, u, p, t + c3 * dt) - k3 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea41 * k1 + - ea42 * k2 + ea43 * k3 - else - @.. broadcast=false z₄=α41 * z₁ + α42 * z₂ - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - end - nlsolver.z = z₄ - - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - if integrator.f isa SplitFunction - z₅ .= z₂ - @.. broadcast=false u=tmp + γ * z₄ - f2(k4, u, p, t + c4 * dt) - k4 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ + - ea51 * k1 + ea52 * k2 + ea53 * k3 + ea54 * k4 - else - @.. broadcast=false z₅=α51 * z₁ + α52 * z₂ - @.. broadcast=false tmp=uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - end - nlsolver.z = z₅ - - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - if integrator.f isa SplitFunction - z₆ .= z₃ - @.. broadcast=false u=tmp + γ * z₅ - f2(k5, u, p, t + c5 * dt) - k5 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + - a65 * z₅ + ea61 * k1 + ea62 * k2 + ea63 * k3 + ea64 * k4 + - ea65 * k5 - else - @.. broadcast=false z₆=α61 * z₁ + α62 * z₂ + α63 * z₃ - @.. broadcast=false tmp=uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - end - nlsolver.z = z₆ - - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - if integrator.f isa SplitFunction - z₇ .= z₃ - @.. broadcast=false u=tmp + γ * z₆ - f2(k6, u, p, t + c6 * dt) - k6 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + - a75 * z₅ + a76 * z₆ + ea71 * k1 + ea72 * k2 + ea73 * k3 + - ea74 * k4 + ea75 * k5 + ea76 * k6 - else - @.. broadcast=false z₇=α71 * z₁ + α72 * z₂ + α73 * z₃ - @.. broadcast=false tmp=uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + - a75 * z₅ + a76 * z₆ - end - nlsolver.z = z₇ - - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 8 - - if integrator.f isa SplitFunction - z₈ .= z₇ - @.. broadcast=false u=tmp + γ * z₇ - f2(k7, u, p, t + c7 * dt) - k7 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast=false tmp=uprev + a83 * z₃ + a84 * z₄ + a85 * z₅ + a86 * z₆ + - a87 * z₇ + ea81 * k1 + ea82 * k2 + ea83 * k3 + ea84 * k4 + - ea85 * k5 + ea86 * k6 + ea87 * k7 - else - @.. broadcast=false z₈=α81 * z₁ + α82 * z₂ + α83 * z₃ + α84 * z₄ + α85 * z₅ + - α86 * z₆ + α87 * z₇ - @.. broadcast=false tmp=uprev + a83 * z₃ + a84 * z₄ + a85 * z₅ + a86 * z₆ + a87 * z₇ - end - nlsolver.z = z₈ - - nlsolver.c = 1 - z₈ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast=false u=tmp + γ * z₈ - if integrator.f isa SplitFunction - f2(k8, u, p, t + dt) - k8 .*= dt - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - @.. broadcast=false u=uprev + a83 * z₃ + a84 * z₄ + a85 * z₅ + a86 * z₆ + a87 * z₇ + - γ * z₈ + eb3 * k3 + eb4 * k4 + eb5 * k5 + eb6 * k6 + - eb7 * k7 + eb8 * k8 - end - - ################################### Finalize - - if integrator.opts.adaptive - if integrator.f isa SplitFunction - @.. broadcast=false tmp=btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ + btilde8 * z₈ + - ebtilde3 * k3 + ebtilde4 * k4 + ebtilde5 * k5 + - ebtilde6 * k6 + ebtilde7 * k7 + ebtilde8 * k8 - else - @.. broadcast=false tmp=btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ + btilde8 * z₈ - end - - if isnewton(nlsolver) && alg.smooth_est # From Shampine - est = nlsolver.cache.dz - - linres = dolinsolve(integrator, nlsolver.cache.linsolve; b = _vec(tmp), - linu = _vec(est)) - - integrator.stats.nsolve += 1 - else - est = tmp - end - calculate_residuals!(atmp, est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - if integrator.f isa SplitFunction - integrator.f(integrator.fsallast, u, p, t + dt) - else - @.. broadcast=false integrator.fsallast=z₈ / dt - end -end diff --git a/lib/OrdinaryDiffEqSDIRK/src/sdirk_caches.jl b/lib/OrdinaryDiffEqSDIRK/src/sdirk_caches.jl index 75a6453fcf..2d11e6e474 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/sdirk_caches.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/sdirk_caches.jl @@ -4,999 +4,160 @@ function get_fsalfirstlast(cache::SDIRKMutableCache, u) (cache.fsalfirst, du_alias_or_new(cache.nlsolver, cache.fsalfirst)) end -@cache mutable struct ImplicitEulerCache{ - uType, rateType, uNoUnitsType, N, AV, StepLimiter} <: - SDIRKMutableCache +# Unified SDIRK caches that work with any SDIRK tableau +@cache mutable struct SDIRKCache{uType, rateType, uNoUnitsType, Tab, N, AV, StepLimiter} <: SDIRKMutableCache u::uType uprev::uType uprev2::uType fsalfirst::rateType + zs::Vector{uType} atmp::uNoUnitsType nlsolver::N + tab::Tab algebraic_vars::AV step_limiter!::StepLimiter -end - -function alg_cache(alg::ImplicitEuler, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - γ, c = 1, 1 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - algebraic_vars = f.mass_matrix === I ? nothing : - [all(iszero, x) for x in eachcol(f.mass_matrix)] - - ImplicitEulerCache( - u, uprev, uprev2, fsalfirst, atmp, nlsolver, algebraic_vars, alg.step_limiter!) -end - -mutable struct ImplicitEulerConstantCache{N} <: SDIRKConstantCache - nlsolver::N -end - -function alg_cache(alg::ImplicitEuler, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - γ, c = 1, 1 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - ImplicitEulerConstantCache(nlsolver) -end - -mutable struct ImplicitMidpointConstantCache{N} <: SDIRKConstantCache - nlsolver::N -end - -function alg_cache(alg::ImplicitMidpoint, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - γ, c = 1 // 2, 1 // 2 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - ImplicitMidpointConstantCache(nlsolver) -end - -@cache mutable struct ImplicitMidpointCache{uType, rateType, N, StepLimiter} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - nlsolver::N - step_limiter!::StepLimiter -end - -function alg_cache(alg::ImplicitMidpoint, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - γ, c = 1 // 2, 1 // 2 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - ImplicitMidpointCache(u, uprev, fsalfirst, nlsolver, alg.step_limiter!) -end - -mutable struct TrapezoidConstantCache{uType, tType, N} <: SDIRKConstantCache - uprev3::uType - tprev2::tType - nlsolver::N -end - -function alg_cache(alg::Trapezoid, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - γ, c = 1 // 2, 1 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - - uprev3 = u - tprev2 = t - - TrapezoidConstantCache(uprev3, tprev2, nlsolver) -end - -@cache mutable struct TrapezoidCache{ - uType, rateType, uNoUnitsType, tType, N, StepLimiter} <: - SDIRKMutableCache - u::uType - uprev::uType - uprev2::uType - fsalfirst::rateType - atmp::uNoUnitsType + # For algorithms that need additional history uprev3::uType - tprev2::tType - nlsolver::N - step_limiter!::StepLimiter -end - -function alg_cache(alg::Trapezoid, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - γ, c = 1 // 2, 1 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - uprev3 = zero(u) - tprev2 = t - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - TrapezoidCache( - u, uprev, uprev2, fsalfirst, atmp, uprev3, tprev2, nlsolver, alg.step_limiter!) -end - -mutable struct TRBDF2ConstantCache{Tab, N} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache(alg::TRBDF2, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = TRBDF2Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.d, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - TRBDF2ConstantCache(nlsolver, tab) -end - -@cache mutable struct TRBDF2Cache{uType, rateType, uNoUnitsType, Tab, N, StepLimiter} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType + tprev2::typeof(1.0) zprev::uType zᵧ::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab - step_limiter!::StepLimiter end -function alg_cache(alg::TRBDF2, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = TRBDF2Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.d, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - zprev = zero(u) - zᵧ = zero(u) - - TRBDF2Cache(u, uprev, fsalfirst, zprev, zᵧ, atmp, nlsolver, tab, alg.step_limiter!) -end - -mutable struct SDIRK2ConstantCache{N} <: SDIRKConstantCache - nlsolver::N -end - -function alg_cache(alg::SDIRK2, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - γ, c = 1, 1 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - SDIRK2ConstantCache(nlsolver) -end - -@cache mutable struct SDIRK2Cache{uType, rateType, uNoUnitsType, N, StepLimiter} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - atmp::uNoUnitsType - nlsolver::N - step_limiter!::StepLimiter -end - -function alg_cache(alg::SDIRK2, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - γ, c = 1, 1 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - SDIRK2Cache(u, uprev, fsalfirst, z₁, z₂, atmp, nlsolver, alg.step_limiter!) -end - -struct SDIRK22ConstantCache{uType, tType, N, Tab} <: SDIRKConstantCache - uprev3::uType - tprev2::tType +mutable struct SDIRKConstantCacheImpl{N, Tab, uType, tType} <: SDIRKConstantCache nlsolver::N tab::Tab -end - -function alg_cache(alg::SDIRK22, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{tTypeNoUnits}, ::Type{uBottomEltypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = SDIRK22Tableau(constvalue(uBottomEltypeNoUnits)) - uprev3 = u - tprev2 = t - γ, c = 1, 1 - - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - - SDIRK22ConstantCache(uprev3, tprev2, nlsolver) -end - -@cache mutable struct SDIRK22Cache{ - uType, rateType, uNoUnitsType, tType, N, Tab, StepLimiter} <: - SDIRKMutableCache - u::uType - uprev::uType - uprev2::uType - fsalfirst::rateType - atmp::uNoUnitsType + # For algorithms that need additional history uprev3::uType tprev2::tType - nlsolver::N - tab::Tab - step_limiter!::StepLimiter -end - -function alg_cache(alg::SDIRK22, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = SDIRK22Tableau(constvalue(uBottomEltypeNoUnits)) - γ, c = 1, 1 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - uprev3 = zero(u) - tprev2 = t - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - SDIRK22Cache( - u, uprev, uprev2, fsalfirst, atmp, uprev3, tprev2, nlsolver, tab, alg.step_limiter!) # shouldn't this be SDIRK22Cache instead of SDIRK22? -end - -mutable struct SSPSDIRK2ConstantCache{N} <: SDIRKConstantCache - nlsolver::N -end - -function alg_cache(alg::SSPSDIRK2, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - γ, c = 1 // 4, 1 // 1 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - SSPSDIRK2ConstantCache(nlsolver) -end - -@cache mutable struct SSPSDIRK2Cache{uType, rateType, N} <: SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - nlsolver::N -end - -function alg_cache(alg::SSPSDIRK2, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - γ, c = 1 // 4, 1 // 1 - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - SSPSDIRK2Cache(u, uprev, fsalfirst, z₁, z₂, nlsolver) -end - -mutable struct Cash4ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache(alg::Cash4, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = Cash4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - Cash4ConstantCache(nlsolver, tab) -end - -@cache mutable struct Cash4Cache{uType, rateType, uNoUnitsType, N, Tab} <: SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab -end - -function alg_cache(alg::Cash4, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = Cash4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - Cash4Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, atmp, nlsolver, tab) -end - -mutable struct SFSDIRK4ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache(alg::SFSDIRK4, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = SFSDIRK4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - SFSDIRK4ConstantCache(nlsolver, tab) -end - -@cache mutable struct SFSDIRK4Cache{uType, rateType, uNoUnitsType, N, Tab} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab -end - -function alg_cache(alg::SFSDIRK4, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = SFSDIRK4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - SFSDIRK4Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, atmp, nlsolver, tab) -end - -mutable struct SFSDIRK5ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab end -function alg_cache(alg::SFSDIRK5, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = SFSDIRK5Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - SFSDIRK5ConstantCache(nlsolver, tab) -end +const SDIRKConstantCacheType = SDIRKConstantCacheImpl -@cache mutable struct SFSDIRK5Cache{uType, rateType, uNoUnitsType, N, Tab} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab -end +# Unified alg_cache functions for mutable cache +function alg_cache(alg::Union{ImplicitEuler, ImplicitMidpoint, Trapezoid, TRBDF2, SDIRK2, SDIRK22, SSPSDIRK2, + Cash4, SFSDIRK4, SFSDIRK5, SFSDIRK6, SFSDIRK7, SFSDIRK8, + Hairer4, Hairer42, ESDIRK54I8L2SA, ESDIRK436L2SA2, ESDIRK437L2SA, + ESDIRK547L2SA2, ESDIRK659L2SA, + Kvaerno3, KenCarp3, CFNLIRK3, Kvaerno4, Kvaerno5, KenCarp4, + KenCarp47, KenCarp5, KenCarp58}, + u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, + uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} -function alg_cache(alg::SFSDIRK5, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = SFSDIRK5Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + tab = get_sdirk_tableau(alg, constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - SFSDIRK5Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, atmp, nlsolver, tab) -end + s = length(tab.b) -mutable struct SFSDIRK6ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache(alg::SFSDIRK6, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = SFSDIRK6Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - SFSDIRK6ConstantCache(nlsolver, tab) -end - -@cache mutable struct SFSDIRK6Cache{uType, rateType, uNoUnitsType, N, Tab} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab -end - -function alg_cache(alg::SFSDIRK6, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = SFSDIRK6Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) fsalfirst = zero(rate_prototype) - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - SFSDIRK6Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, atmp, nlsolver, tab) -end - -mutable struct SFSDIRK7ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache(alg::SFSDIRK7, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = SFSDIRK7Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - SFSDIRK7ConstantCache(nlsolver, tab) -end - -@cache mutable struct SFSDIRK7Cache{uType, rateType, uNoUnitsType, N, Tab} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - z₇::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab -end - -function alg_cache(alg::SFSDIRK7, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = SFSDIRK7Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = zero(u) - z₇ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - SFSDIRK7Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, z₇, atmp, nlsolver, tab) -end - -mutable struct SFSDIRK8ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache(alg::SFSDIRK8, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = SFSDIRK8Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - SFSDIRK8ConstantCache(nlsolver, tab) -end - -@cache mutable struct SFSDIRK8Cache{uType, rateType, uNoUnitsType, N, Tab} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - z₇::uType - z₈::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab -end - -function alg_cache(alg::SFSDIRK8, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = SFSDIRK8Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = zero(u) - z₇ = zero(u) - z₈ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - SFSDIRK8Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, z₇, z₈, atmp, nlsolver, tab) -end - -mutable struct Hairer4ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::Union{Hairer4, Hairer42}, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - if alg isa Hairer4 - tab = Hairer4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - else - tab = Hairer42Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + # Initialize all z stage vectors + zs = Vector{typeof(u)}(undef, s) + for i in 1:s + zs[i] = zero(u) end - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - Hairer4ConstantCache(nlsolver, tab) -end - -@cache mutable struct Hairer4Cache{uType, rateType, uNoUnitsType, Tab, N} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab -end + zs[end] = nlsolver.z # use nlsolver.z for the last stage -function alg_cache( - alg::Union{Hairer4, Hairer42}, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - if alg isa Hairer4 - tab = Hairer4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - else # Hairer42 - tab = Hairer42Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - end - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = nlsolver.z atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) - Hairer4Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, atmp, nlsolver, tab) -end - -@cache mutable struct ESDIRK54I8L2SACache{uType, rateType, uNoUnitsType, Tab, N} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - z₇::uType - z₈::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab -end - -function alg_cache(alg::ESDIRK54I8L2SA, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ESDIRK54I8L2SATableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = zero(u) - z₇ = zero(u) - z₈ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) + # Handle mass matrix for algebraic variables (ImplicitEuler needs this) + algebraic_vars = f.mass_matrix === I ? nothing : + [all(iszero, x) for x in eachcol(f.mass_matrix)] - ESDIRK54I8L2SACache( - u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, z₇, z₈, atmp, nlsolver, - tab) -end + # Get step limiter + step_limiter! = hasproperty(alg, :step_limiter!) ? alg.step_limiter! : trivial_limiter! -mutable struct ESDIRK54I8L2SAConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end + # Additional variables for algorithms that need history + uprev3 = zero(u) + tprev2 = t + zprev = zero(u) + zᵧ = zero(u) -function alg_cache(alg::ESDIRK54I8L2SA, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ESDIRK54I8L2SATableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - ESDIRK54I8L2SAConstantCache(nlsolver, tab) + SDIRKCache(u, uprev, uprev2, fsalfirst, zs, + atmp, nlsolver, tab, algebraic_vars, step_limiter!, uprev3, tprev2, zprev, zᵧ) end -@cache mutable struct ESDIRK436L2SA2Cache{uType, rateType, uNoUnitsType, Tab, N} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab -end +# Unified alg_cache functions for constant cache +function alg_cache(alg::Union{ImplicitEuler, ImplicitMidpoint, Trapezoid, TRBDF2, SDIRK2, SDIRK22, SSPSDIRK2, + Cash4, SFSDIRK4, SFSDIRK5, SFSDIRK6, SFSDIRK7, SFSDIRK8, + Hairer4, Hairer42, ESDIRK54I8L2SA, ESDIRK436L2SA2, ESDIRK437L2SA, + ESDIRK547L2SA2, ESDIRK659L2SA, + Kvaerno3, KenCarp3, CFNLIRK3, Kvaerno4, Kvaerno5, KenCarp4, + KenCarp47, KenCarp5, KenCarp58}, + u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, + uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} -function alg_cache(alg::ESDIRK436L2SA2, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ESDIRK436L2SA2Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + tab = get_sdirk_tableau(alg, constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - ESDIRK436L2SA2Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, atmp, nlsolver, - tab) -end - -mutable struct ESDIRK436L2SA2ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache(alg::ESDIRK436L2SA2, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ESDIRK436L2SA2Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - ESDIRK436L2SA2ConstantCache(nlsolver, tab) -end - -@cache mutable struct ESDIRK437L2SACache{uType, rateType, uNoUnitsType, Tab, N} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - z₇::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab -end -function alg_cache(alg::ESDIRK437L2SA, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ESDIRK437L2SATableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = zero(u) - z₇ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) + # Additional variables for algorithms that need history + uprev3 = u + tprev2 = t - ESDIRK437L2SACache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, z₇, atmp, nlsolver, - tab) + SDIRKConstantCacheImpl(nlsolver, tab, uprev3, tprev2) end -mutable struct ESDIRK437L2SAConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end +# Keep old caches for backward compatibility for now, will be removed later. +const ImplicitEulerCacheInner = SDIRKCache +const ImplicitEulerConstantCacheInner = SDIRKConstantCacheImpl -function alg_cache(alg::ESDIRK437L2SA, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ESDIRK437L2SATableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - ESDIRK437L2SAConstantCache(nlsolver, tab) +mutable struct ImplicitEulerCache{C<:ImplicitEulerCacheInner} <: SDIRKMutableCache + inner::C end -@cache mutable struct ESDIRK547L2SA2Cache{uType, rateType, uNoUnitsType, Tab, N} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - z₇::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab +mutable struct ImplicitEulerConstantCache{C<:ImplicitEulerConstantCacheInner} <: SDIRKConstantCache + inner::C end -function alg_cache(alg::ESDIRK547L2SA2, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ESDIRK547L2SA2Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = zero(u) - z₇ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - ESDIRK547L2SA2Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, z₇, atmp, nlsolver, - tab) +function Base.getproperty(cache::ImplicitEulerCache, sym::Symbol) + sym === :inner && return getfield(cache, :inner) + return getproperty(getfield(cache, :inner), sym) end -mutable struct ESDIRK547L2SA2ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab +function Base.getproperty(cache::ImplicitEulerConstantCache, sym::Symbol) + sym === :inner && return getfield(cache, :inner) + return getproperty(getfield(cache, :inner), sym) end -function alg_cache(alg::ESDIRK547L2SA2, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ESDIRK547L2SA2Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - ESDIRK547L2SA2ConstantCache(nlsolver, tab) +function Base.setproperty!(cache::ImplicitEulerCache, sym::Symbol, val) + sym === :inner && return setfield!(cache, :inner, val) + setproperty!(getfield(cache, :inner), sym, val) end -@cache mutable struct ESDIRK659L2SACache{uType, rateType, uNoUnitsType, Tab, N} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - z₇::uType - z₈::uType - z₉::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab +function Base.setproperty!(cache::ImplicitEulerConstantCache, sym::Symbol, val) + sym === :inner && return setfield!(cache, :inner, val) + setproperty!(getfield(cache, :inner), sym, val) end -function alg_cache(alg::ESDIRK659L2SA, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, - p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ESDIRK659L2SATableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = zero(u) - z₇ = zero(u) - z₈ = zero(u) - z₉ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) +function ImplicitEulerCache(u, uprev, uprev2, fsalfirst, atmp, nlsolver, algebraic_vars, step_limiter!) + T = eltype(u) + T2 = eltype(atmp) + tab = ImplicitEulerTableau(T === Nothing ? Float64 : T, T2 === Nothing ? Float64 : T2) + zs = Vector{typeof(u)}(undef, 1) + zs[1] = nlsolver.z + inner = SDIRKCache(u, uprev, uprev2, fsalfirst, zs, atmp, nlsolver, tab, + algebraic_vars, step_limiter!, zero(u), zero(eltype(tab.c)), zero(u), zero(u)) + ImplicitEulerCache(inner) +end - ESDIRK659L2SACache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, z₇, z₈, z₉, atmp, - nlsolver, tab) +function ImplicitEulerConstantCache(nlsolver) + tab = ImplicitEulerTableau() + inner = SDIRKConstantCacheImpl(nlsolver, tab, zero(nlsolver.tmp), zero(eltype(tab.c))) + ImplicitEulerConstantCache(inner) end -mutable struct ESDIRK659L2SAConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab +@muladd function perform_step!(integrator, cache::ImplicitEulerCache, repeat_step=false) + perform_step!(integrator, cache.inner, repeat_step) end -function alg_cache(alg::ESDIRK659L2SA, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}) where - {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ESDIRK659L2SATableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - ESDIRK659L2SAConstantCache(nlsolver, tab) +@muladd function perform_step!(integrator, cache::ImplicitEulerConstantCache, repeat_step=false) + perform_step!(integrator, cache.inner, repeat_step) end diff --git a/lib/OrdinaryDiffEqSDIRK/src/sdirk_perform_step.jl b/lib/OrdinaryDiffEqSDIRK/src/sdirk_perform_step.jl deleted file mode 100644 index 0998ac0a8c..0000000000 --- a/lib/OrdinaryDiffEqSDIRK/src/sdirk_perform_step.jl +++ /dev/null @@ -1,3114 +0,0 @@ -function initialize!(integrator, cache::SDIRKConstantCache) - integrator.kshortsize = 2 - integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) - integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t) # Pre-start fsal - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - # Avoid undefined entries if k is an array of arrays - integrator.fsallast = zero(integrator.fsalfirst) - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast -end - -function initialize!(integrator, cache::SDIRKMutableCache) - integrator.kshortsize = 2 - resize!(integrator.k, integrator.kshortsize) - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.f(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t) # For the interpolation, needs k at the updated point - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) -end - -@muladd function perform_step!(integrator, cache::ImplicitEulerConstantCache, - repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - markfirststage!(nlsolver) - - # initial guess - if alg.extrapolant == :linear - nlsolver.z = dt * integrator.fsalfirst - else # :constant - nlsolver.z = zero(u) - end - - nlsolver.tmp = uprev - nlsolver.γ = 1 - z = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - u = nlsolver.tmp + z - - if integrator.opts.adaptive && integrator.success_iter > 0 - # local truncation error (LTE) bound by dt^2/2*max|y''(t)| - # use 2nd divided differences (DD) a la SPICE and Shampine - - # TODO: check numerical stability - uprev2 = integrator.uprev2 - tprev = integrator.tprev - - dt1 = dt * (t + dt - tprev) - dt2 = (t - tprev) * (t + dt - tprev) - c = 7 / 12 # default correction factor in SPICE (LTE overestimated by DD) - r = c * dt^2 # by mean value theorem 2nd DD equals y''(s)/2 for some s - - tmp = r * - integrator.opts.internalnorm.((u - uprev) / dt1 - (uprev - uprev2) / dt2, t) - atmp = calculate_residuals(tmp, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - else - integrator.EEst = 1 - end - - integrator.fsallast = f(u, p, t + dt) - - if integrator.opts.adaptive && integrator.differential_vars !== nothing - atmp = @. ifelse(!integrator.differential_vars, integrator.fsallast, false) ./ - integrator.opts.abstol - integrator.EEst += integrator.opts.internalnorm(atmp, t) - end - - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::ImplicitEulerCache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; atmp, nlsolver, step_limiter!) = cache - (; z, tmp) = nlsolver - alg = unwrap_alg(integrator, true) - markfirststage!(nlsolver) - - # initial guess - if alg.extrapolant == :linear - @.. broadcast=false z=dt * integrator.fsalfirst - else # :constant - z .= zero(eltype(u)) - end - - nlsolver.tmp .= uprev - nlsolver.γ = 1 - z = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - @.. broadcast=false u=uprev + z - - step_limiter!(u, integrator, p, t + dt) - - if integrator.opts.adaptive && integrator.success_iter > 0 - # local truncation error (LTE) bound by dt^2/2*max|y''(t)| - # use 2nd divided differences (DD) a la SPICE and Shampine - - # TODO: check numerical stability - uprev2 = integrator.uprev2 - tprev = integrator.tprev - - dt1 = dt * (t + dt - tprev) - dt2 = (t - tprev) * (t + dt - tprev) - c = 7 / 12 # default correction factor in SPICE (LTE overestimated by DD) - r = c * dt^2 # by mean value theorem 2nd DD equals y''(s)/2 for some s - - @.. broadcast=false tmp=r * integrator.opts.internalnorm( - (u - uprev) / dt1 - - (uprev - uprev2) / dt2, t) - calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - else - integrator.EEst = 1 - end - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - f(integrator.fsallast, u, p, t + dt) - - if integrator.opts.adaptive && integrator.differential_vars !== nothing - @.. broadcast=false atmp=ifelse(cache.algebraic_vars, integrator.fsallast, false) / - integrator.opts.abstol - integrator.EEst += integrator.opts.internalnorm(atmp, t) - end -end - -@muladd function perform_step!(integrator, cache::ImplicitMidpointConstantCache, - repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - γ = 1 // 2 - markfirststage!(nlsolver) - - # initial guess - if alg.extrapolant == :linear - nlsolver.z = dt * integrator.fsalfirst - else # :constant - nlsolver.z = zero(u) - end - - nlsolver.tmp = uprev - z = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - u = nlsolver.tmp + z - - integrator.fsallast = f(u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::ImplicitMidpointCache, - repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; nlsolver, step_limiter!) = cache - (; z, tmp) = nlsolver - mass_matrix = integrator.f.mass_matrix - alg = unwrap_alg(integrator, true) - γ = 1 // 2 - markfirststage!(nlsolver) - - # initial guess - if alg.extrapolant == :linear - @.. broadcast=false z=dt * integrator.fsalfirst - else # :constant - z .= zero(eltype(u)) - end - - nlsolver.tmp = uprev - z = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - @.. broadcast=false u=nlsolver.tmp + z - - step_limiter!(u, integrator, p, t + dt) - - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - f(integrator.fsallast, u, p, t + dt) -end - -@muladd function perform_step!(integrator, cache::TrapezoidConstantCache, - repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - # precalculations - γ = 1 // 2 - γdt = γ * dt - markfirststage!(nlsolver) - - # initial guess: constant extrapolation - nlsolver.z = uprev - - if f.mass_matrix === I - nlsolver.tmp = @.. broadcast=false uprev * inv(γdt)+integrator.fsalfirst - else - nlsolver.tmp = (f.mass_matrix * uprev) .* inv(γdt) .+ integrator.fsalfirst - end - nlsolver.α = 1 - nlsolver.γ = γ - nlsolver.method = COEFFICIENT_MULTISTEP - u = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - if integrator.opts.adaptive - if integrator.iter > 2 - # local truncation error (LTE) bound by dt^3/12*max|y'''(t)| - # use 3rd divided differences (DD) a la SPICE and Shampine - - # TODO: check numerical stability - uprev2 = integrator.uprev2 - tprev = integrator.tprev - uprev3 = cache.uprev3 - tprev2 = cache.tprev2 - - dt1 = dt * (t + dt - tprev) - dt2 = (t - tprev) * (t + dt - tprev) - dt3 = (t - tprev) * (t - tprev2) - dt4 = (tprev - tprev2) * (t - tprev2) - dt5 = t + dt - tprev2 - c = 7 / 12 # default correction factor in SPICE (LTE overestimated by DD) - r = c * dt^3 / 2 # by mean value theorem 3rd DD equals y'''(s)/6 for some s - - # tmp = r*abs(((u - uprev)/dt1 - (uprev - uprev2)/dt2) - ((uprev - uprev2)/dt3 - (uprev2 - uprev3)/dt4)/dt5) - DD31 = (u - uprev) / dt1 - (uprev - uprev2) / dt2 - DD30 = (uprev - uprev2) / dt3 - (uprev2 - uprev3) / dt4 - tmp = r * integrator.opts.internalnorm((DD31 - DD30) / dt5, t) - atmp = calculate_residuals(tmp, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, - t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - if integrator.EEst <= 1 - cache.uprev3 = uprev2 - cache.tprev2 = tprev - end - elseif integrator.success_iter > 0 - integrator.EEst = 1 - cache.uprev3 = integrator.uprev2 - cache.tprev2 = integrator.tprev - else - integrator.EEst = 1 - end - end - - integrator.fsallast = f(u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::TrapezoidCache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; atmp, nlsolver, step_limiter!) = cache - (; z, tmp) = nlsolver - alg = unwrap_alg(integrator, true) - mass_matrix = integrator.f.mass_matrix - - # precalculations - γ = 1 // 2 - γdt = γ * dt - markfirststage!(nlsolver) - - # initial guess: constant extrapolation - @.. broadcast=false z=uprev - invγdt = inv(γdt) - if mass_matrix === I - @.. broadcast=false tmp=uprev * invγdt + integrator.fsalfirst - else - mul!(u, mass_matrix, uprev) - @.. broadcast=false tmp=u * invγdt + integrator.fsalfirst - end - nlsolver.α = 1 - nlsolver.γ = γ - nlsolver.method = COEFFICIENT_MULTISTEP - z = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - @.. broadcast=false u=z - - step_limiter!(u, integrator, p, t + dt) - - if integrator.opts.adaptive - if integrator.iter > 2 - # local truncation error (LTE) bound by dt^3/12*max|y'''(t)| - # use 3rd divided differences (DD) a la SPICE and Shampine - - # TODO: check numerical stability - uprev2 = integrator.uprev2 - tprev = integrator.tprev - uprev3 = cache.uprev3 - tprev2 = cache.tprev2 - - dt1 = dt * (t + dt - tprev) - dt2 = (t - tprev) * (t + dt - tprev) - dt3 = (t - tprev) * (t - tprev2) - dt4 = (tprev - tprev2) * (t - tprev2) - dt5 = t + dt - tprev2 - c = 7 / 12 # default correction factor in SPICE (LTE overestimated by DD) - r = c * dt^3 / 2 # by mean value theorem 3rd DD equals y'''(s)/6 for some s - - # @.. broadcast=false tmp = r*abs(((u - uprev)/dt1 - (uprev - uprev2)/dt2) - ((uprev - uprev2)/dt3 - (uprev2 - uprev3)/dt4)/dt5) - @.. broadcast=false tmp=r * integrator.opts.internalnorm( - (((u - uprev) / dt1 - - (uprev - uprev2) / dt2) #DD31 - - - ((uprev - uprev2) / dt3 - - (uprev2 - uprev3) / - dt4)) / - dt5, - t) - calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - if integrator.EEst <= 1 - copyto!(cache.uprev3, uprev2) - cache.tprev2 = tprev - end - elseif integrator.success_iter > 0 - integrator.EEst = 1 - copyto!(cache.uprev3, integrator.uprev2) - cache.tprev2 = integrator.tprev - else - integrator.EEst = 1 - end - end - - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - f(integrator.fsallast, u, p, t + dt) -end - -@muladd function perform_step!(integrator, cache::TRBDF2ConstantCache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; γ, d, ω, btilde1, btilde2, btilde3, α1, α2) = cache.tab - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - markfirststage!(nlsolver) - - # FSAL - zprev = dt * integrator.fsalfirst - - ##### Solve Trapezoid Step - - # TODO: Add extrapolation - zᵧ = zprev - nlsolver.z = zᵧ - nlsolver.c = γ - - nlsolver.tmp = uprev + d * zprev - zᵧ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve BDF2 Step - - ### Initial Guess From Shampine - z = α1 * zprev + α2 * zᵧ - nlsolver.z = z - nlsolver.c = 1 - - nlsolver.tmp = uprev + ω * zprev + ω * zᵧ - z = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + d * z - - ################################### Finalize - - if integrator.opts.adaptive - tmp = btilde1 * zprev + btilde2 * zᵧ + btilde3 * z - if isnewton(nlsolver) && alg.smooth_est # From Shampine - integrator.stats.nsolve += 1 - est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) - else - est = tmp - end - atmp = calculate_residuals(est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.fsallast = z ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::TRBDF2Cache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; zprev, zᵧ, atmp, nlsolver, step_limiter!) = cache - (; z, tmp) = nlsolver - W = isnewton(nlsolver) ? get_W(nlsolver) : nothing - b = nlsolver.ztmp - (; γ, d, ω, btilde1, btilde2, btilde3, α1, α2) = cache.tab - alg = unwrap_alg(integrator, true) - - # FSAL - @.. broadcast=false zprev=dt * integrator.fsalfirst - markfirststage!(nlsolver) - - ##### Solve Trapezoid Step - - # TODO: Add extrapolation - @.. broadcast=false zᵧ=zprev - z .= zᵧ - @.. broadcast=false tmp=uprev + d * zprev - nlsolver.c = γ - zᵧ .= nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve BDF2 Step - - ### Initial Guess From Shampine - @.. broadcast=false z=α1 * zprev + α2 * zᵧ - @.. broadcast=false tmp=uprev + ω * zprev + ω * zᵧ - nlsolver.c = 1 - isnewton(nlsolver) && set_new_W!(nlsolver, false) - nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast=false u=tmp + d * z - - step_limiter!(u, integrator, p, t + dt) - - ################################### Finalize - - if integrator.opts.adaptive - @.. broadcast=false tmp=btilde1 * zprev + btilde2 * zᵧ + btilde3 * z - if alg.smooth_est && isnewton(nlsolver) # From Shampine - est = nlsolver.cache.dz - linres = dolinsolve(integrator, nlsolver.cache.linsolve; b = _vec(tmp), - linu = _vec(est)) - - integrator.stats.nsolve += 1 - else - est = tmp - end - calculate_residuals!(atmp, est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - @.. broadcast=false integrator.fsallast=z / dt -end - -@muladd function perform_step!(integrator, cache::TRBDF2Cache{<:Array}, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; zprev, zᵧ, atmp, nlsolver, step_limiter!) = cache - (; z, tmp) = nlsolver - W = isnewton(nlsolver) ? get_W(nlsolver) : nothing - b = nlsolver.ztmp - (; γ, d, ω, btilde1, btilde2, btilde3, α1, α2) = cache.tab - alg = unwrap_alg(integrator, true) - - # FSAL - @inbounds @simd ivdep for i in eachindex(u) - zprev[i] = dt * integrator.fsalfirst[i] - end - markfirststage!(nlsolver) - - ##### Solve Trapezoid Step - - # TODO: Add extrapolation - copyto!(zᵧ, zprev) - copyto!(z, zᵧ) - @inbounds @simd ivdep for i in eachindex(u) - tmp[i] = uprev[i] + d * zprev[i] - end - nlsolver.c = γ - zᵧ .= nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve BDF2 Step - - ### Initial Guess From Shampine - @inbounds @simd ivdep for i in eachindex(u) - z[i] = α1 * zprev[i] + α2 * zᵧ[i] - end - @inbounds @simd ivdep for i in eachindex(u) - tmp[i] = uprev[i] + ω * zprev[i] + ω * zᵧ[i] - end - nlsolver.c = 1 - isnewton(nlsolver) && set_new_W!(nlsolver, false) - nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @inbounds @simd ivdep for i in eachindex(u) - u[i] = tmp[i] + d * z[i] - end - - step_limiter!(u, integrator, p, t + dt) - - ################################### Finalize - - if integrator.opts.adaptive - @inbounds @simd ivdep for i in eachindex(u) - tmp[i] = btilde1 * zprev[i] + btilde2 * zᵧ[i] + btilde3 * z[i] - end - if alg.smooth_est && isnewton(nlsolver) # From Shampine - est = nlsolver.cache.dz - - linres = dolinsolve(integrator, nlsolver.cache.linsolve; b = _vec(tmp), - linu = _vec(est)) - - integrator.stats.nsolve += 1 - else - est = tmp - end - calculate_residuals!(atmp, est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - @inbounds @simd ivdep for i in eachindex(u) - integrator.fsallast[i] = z[i] / dt - end -end - -@muladd function perform_step!(integrator, cache::SDIRK2ConstantCache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - markfirststage!(nlsolver) - - # initial guess - if integrator.success_iter > 0 && !integrator.reeval_fsal && - alg.extrapolant == :interpolant - current_extrapolant!(u, t + dt, integrator) - z₁ = u - uprev - elseif alg.extrapolant == :linear - z₁ = dt * integrator.fsalfirst - else - z₁ = zero(u) - end - - nlsolver.tmp = uprev - z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ### Initial Guess Is α₁ = c₂/γ, c₂ = 0 => z₂ = α₁z₁ = 0 - z₂ = zero(u) - nlsolver.z = z₂ - nlsolver.tmp = uprev - z₁ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = uprev + z₁ / 2 + z₂ / 2 - - ################################### Finalize - - if integrator.opts.adaptive - tmp = z₁ / 2 - z₂ / 2 - if isnewton(nlsolver) && alg.smooth_est # From Shampine - integrator.stats.nsolve += 1 - est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) - else - est = tmp - end - atmp = calculate_residuals(est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.fsallast = f(u, p, t) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::SDIRK2Cache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, atmp, nlsolver, step_limiter!) = cache - (; tmp) = nlsolver - W = isnewton(nlsolver) ? get_W(nlsolver) : nothing - alg = unwrap_alg(integrator, true) - markfirststage!(nlsolver) - - # initial guess - if integrator.success_iter > 0 && !integrator.reeval_fsal && - alg.extrapolant == :interpolant - current_extrapolant!(u, t + dt, integrator) - @.. broadcast=false z₁=u - uprev - elseif alg.extrapolant == :linear - @.. broadcast=false z₁=dt * integrator.fsalfirst - else - z₁ .= zero(eltype(u)) - end - nlsolver.z = z₁ - - ##### Step 1 - nlsolver.tmp = uprev - z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 2 - - ### Initial Guess Is α₁ = c₂/γ, c₂ = 0 => z₂ = α₁z₁ = 0 - z₂ .= zero(eltype(u)) - nlsolver.z = z₂ - isnewton(nlsolver) && set_new_W!(nlsolver, false) - @.. broadcast=false tmp=uprev - z₁ - nlsolver.tmp = tmp - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast=false u=uprev + z₁ / 2 + z₂ / 2 - - step_limiter!(u, integrator, p, t + dt) - - ################################### Finalize - - if integrator.opts.adaptive - @.. broadcast=false tmp=z₁ / 2 - z₂ / 2 - if alg.smooth_est && isnewton(nlsolver) # From Shampine - est = nlsolver.cache.dz - linres = dolinsolve(integrator, nlsolver.cache.linsolve; b = _vec(tmp), - linu = _vec(est)) - integrator.stats.nsolve += 1 - else - est = tmp - end - calculate_residuals!(atmp, est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - f(integrator.fsallast, u, p, t) -end - -@muladd function perform_step!(integrator, cache::SDIRK22ConstantCache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; a, α, β) = cache.tab - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - - # precalculations - γ = a * dt - γdt = γ * dt - markfirststage!(nlsolver) - - # initial guess - zprev = dt * integrator.fsalfirst - nlsolver.z = zprev - - # first stage - nlsolver.tmp = uprev + γdt * integrator.fsalfirst - z = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - uprev = α * nlsolver.tmp + β * z - - # final stage - γ = dt - γdt = γ * dt - markfirststage!(nlsolver) - nlsolver.tmp = uprev + γdt * integrator.fsalfirst - z = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - u = nlsolver.tmp - - if integrator.opts.adaptive - if integrator.iter > 2 - # local truncation error (LTE) bound by dt^3/12*max|y'''(t)| - # use 3rd divided differences (DD) a la SPICE and Shampine - - # TODO: check numerical stability - uprev2 = integrator.uprev2 - tprev = integrator.tprev - uprev3 = cache.uprev3 - tprev2 = cache.tprev2 - - dt1 = dt * (t + dt - tprev) - dt2 = (t - tprev) * (t + dt - tprev) - dt3 = (t - tprev) * (t - tprev2) - dt4 = (tprev - tprev2) * (t - tprev2) - dt5 = t + dt - tprev2 - c = 7 / 12 # default correction factor in SPICE (LTE overestimated by DD) - r = c * dt^3 / 2 # by mean value theorem 3rd DD equals y'''(s)/6 for some s - - DD31 = (u - uprev) / dt1 - (uprev - uprev2) / dt2 - DD30 = (uprev - uprev2) / dt3 - (uprev2 - uprev3) / dt4 - tmp = r * abs((DD31 - DD30) / dt5) - atmp = calculate_residuals(tmp, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, - t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - if integrator.EEst <= 1 - cache.uprev3 = uprev2 - cache.tprev2 = tprev - end - elseif integrator.success_iter > 0 - integrator.EEst = 1 - cache.uprev3 = integrator.uprev2 - cache.tprev2 = integrator.tprev - else - integrator.EEst = 1 - end - end - - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 2) - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::SDIRK22Cache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; atmp, nlsolver, step_limiter!) = cache - (; z, tmp) = nlsolver - (; a, α, β) = cache.tab - alg = unwrap_alg(integrator, true) - mass_matrix = integrator.f.mass_matrix - - # precalculations - γ = a * dt - γdt = γ * dt - markfirststage!(nlsolver) - - # first stage - @.. broadcast=false z=dt * integrator.fsalfirst - @.. broadcast=false tmp=uprev + γdt * integrator.fsalfirst - z = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - @.. broadcast=false u=α * tmp + β * z - - # final stage - γ = dt - γdt = γ * dt - markfirststage!(nlsolver) - @.. broadcast=false tmp=uprev + γdt * integrator.fsalfirst - z = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - @.. broadcast=false u=nlsolver.tmp - - step_limiter!(u, integrator, p, t + dt) - - if integrator.opts.adaptive - if integrator.iter > 2 - # local truncation error (LTE) bound by dt^3/12*max|y'''(t)| - # use 3rd divided differences (DD) a la SPICE and Shampine - - # TODO: check numerical stability - uprev2 = integrator.uprev2 - tprev = integrator.tprev - uprev3 = cache.uprev3 - tprev2 = cache.tprev2 - - dt1 = dt * (t + dt - tprev) - dt2 = (t - tprev) * (t + dt - tprev) - dt3 = (t - tprev) * (t - tprev2) - dt4 = (tprev - tprev2) * (t - tprev2) - dt5 = t + dt - tprev2 - c = 7 / 12 # default correction factor in SPICE (LTE overestimated by DD) - r = c * dt^3 / 2 # by mean value theorem 3rd DD equals y'''(s)/6 for some s - - @inbounds for i in eachindex(u) - DD31 = (u[i] - uprev[i]) / dt1 - (uprev[i] - uprev2[i]) / dt2 - DD30 = (uprev[i] - uprev2[i]) / dt3 - (uprev2[i] - uprev3[i]) / dt4 - tmp[i] = r * abs((DD31 - DD30) / dt5) - end - calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - if integrator.EEst <= 1 - copyto!(cache.uprev3, uprev2) - cache.tprev2 = tprev - end - elseif integrator.success_iter > 0 - integrator.EEst = 1 - copyto!(cache.uprev3, integrator.uprev2) - cache.tprev2 = integrator.tprev - else - integrator.EEst = 1 - end - end - - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 2) - f(integrator.fsallast, u, p, t + dt) -end - -@muladd function perform_step!(integrator, cache::SSPSDIRK2ConstantCache, - repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - - γ = eltype(u)(1 // 4) - c2 = typeof(t)(3 // 4) - - markfirststage!(nlsolver) - - # initial guess - if integrator.success_iter > 0 && !integrator.reeval_fsal && - alg.extrapolant == :interpolant - current_extrapolant!(u, t + dt, integrator) - z₁ = u - uprev - elseif alg.extrapolant == :linear - z₁ = dt * integrator.fsalfirst - else - z₁ = zero(u) - end - nlsolver.z = z₁ - - ##### Step 1 - - tstep = t + dt - u = uprev + γ * z₁ - - nlsolver.c = 1 - nlsolver.tmp = uprev - z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 2 - - ### Initial Guess Is α₁ = c₂/γ - z₂ = c2 / γ - nlsolver.z = z₂ - - nlsolver.tmp = uprev + z₁ / 2 - nlsolver.c = 1 - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + z₂ / 2 - - ################################### Finalize - - integrator.fsallast = f(u, p, t) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::SSPSDIRK2Cache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, nlsolver) = cache - (; tmp) = nlsolver - alg = unwrap_alg(integrator, true) - - γ = eltype(u)(1 // 4) - c2 = typeof(t)(3 // 4) - markfirststage!(nlsolver) - - # initial guess - if integrator.success_iter > 0 && !integrator.reeval_fsal && - alg.extrapolant == :interpolant - current_extrapolant!(u, t + dt, integrator) - @.. broadcast=false z₁=u - uprev - elseif alg.extrapolant == :linear - @.. broadcast=false z₁=dt * integrator.fsalfirst - else - z₁ .= zero(eltype(u)) - end - nlsolver.z = z₁ - nlsolver.tmp = uprev - - ##### Step 1 - z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 2 - - ### Initial Guess Is α₁ = c₂/γ - @.. broadcast=false z₂=c2 / γ - nlsolver.z = z₂ - - @.. broadcast=false tmp=uprev + z₁ / 2 - nlsolver.tmp = tmp - isnewton(nlsolver) && set_new_W!(nlsolver, false) - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast=false u=tmp + z₂ / 2 - - ################################### Finalize - - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - f(integrator.fsallast, u, p, t) -end - -@muladd function perform_step!(integrator, cache::Cash4ConstantCache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; γ, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, c2, c3, c4) = cache.tab - (; b1hat1, b2hat1, b3hat1, b4hat1, b1hat2, b2hat2, b3hat2, b4hat2) = cache.tab - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - markfirststage!(nlsolver) - - ##### Step 1 - - # TODO: Add extrapolation for guess - z₁ = zero(u) - nlsolver.z = z₁ - - nlsolver.c = γ - nlsolver.tmp = uprev - z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ##### Step 2 - - # TODO: Add extrapolation for guess - z₂ = zero(u) - nlsolver.z = z₂ - - nlsolver.tmp = uprev + a21 * z₁ - nlsolver.c = c2 - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - # Guess starts from z₁ - z₃ = z₁ - nlsolver.z = z₃ - - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - z₄ = z₃ - nlsolver.z = z₄ - - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - # Use yhat2 for prediction - z₅ = b1hat2 * z₁ + b2hat2 * z₂ + b3hat2 * z₃ + b4hat2 * z₄ - nlsolver.z = z₅ - - nlsolver.tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = 1 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₅ - - ################################### Finalize - - if integrator.opts.adaptive - if alg.embedding == 3 - btilde1 = b1hat2 - a51 - btilde2 = b2hat2 - a52 - btilde3 = b3hat2 - a53 - btilde4 = b4hat2 - a54 - btilde5 = -γ - else - btilde1 = b1hat1 - a51 - btilde2 = b2hat1 - a52 - btilde3 = b3hat1 - a53 - btilde4 = b4hat1 - a54 - btilde5 = -γ - end - - tmp = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ - if isnewton(nlsolver) && alg.smooth_est # From Shampine - integrator.stats.nsolve += 1 - est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) - else - est = tmp - end - atmp = calculate_residuals(est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.fsallast = z₅ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::Cash4Cache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, atmp, nlsolver) = cache - (; tmp) = nlsolver - W = isnewton(nlsolver) ? get_W(nlsolver) : nothing - (; γ, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, c2, c3, c4) = cache.tab - (; b1hat1, b2hat1, b3hat1, b4hat1, b1hat2, b2hat2, b3hat2, b4hat2) = cache.tab - alg = unwrap_alg(integrator, true) - markfirststage!(nlsolver) - - ##### Step 1 - - # TODO: Add extrapolation for guess - z₁ .= zero(eltype(z₁)) - nlsolver.z = z₁ - nlsolver.c = γ - nlsolver.tmp = uprev - - # initial step of NLNewton iteration - z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ##### Step 2 - - # TODO: Add extrapolation for guess - z₂ .= zero(eltype(z₂)) - nlsolver.z = z₂ - - @.. broadcast=false tmp=uprev + a21 * z₁ - nlsolver.tmp = tmp - isnewton(nlsolver) && set_new_W!(nlsolver, false) - nlsolver.c = c2 - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - # Guess starts from z₁ - @.. broadcast=false z₃=z₁ - nlsolver.z = z₃ - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - @.. broadcast=false z₄=z₃ - nlsolver.z = z₄ - - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - # Use constant z prediction - @.. broadcast=false z₅=b1hat2 * z₁ + b2hat2 * z₂ + b3hat2 * z₃ + b4hat2 * z₄ - nlsolver.z = z₅ - @.. broadcast=false tmp=uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = 1 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast=false u=tmp + γ * z₅ - - ################################### Finalize - - if integrator.opts.adaptive - if alg.embedding == 3 - btilde1 = b1hat2 - a51 - btilde2 = b2hat2 - a52 - btilde3 = b3hat2 - a53 - btilde4 = b4hat2 - a54 - btilde5 = -γ - else - btilde1 = b1hat1 - a51 - btilde2 = b2hat1 - a52 - btilde3 = b3hat1 - a53 - btilde4 = b4hat1 - a54 - btilde5 = -γ - end - - @.. broadcast=false tmp=btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + - btilde5 * z₅ - if alg.smooth_est && isnewton(nlsolver) # From Shampine - est = nlsolver.cache.dz - linres = dolinsolve(integrator, nlsolver.cache.linsolve; b = _vec(tmp), - linu = _vec(est)) - integrator.stats.nsolve += 1 - else - est = tmp - end - calculate_residuals!(atmp, est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - @.. broadcast=false integrator.fsallast=z₅ / dt -end - -@muladd function perform_step!(integrator, cache::SFSDIRK4ConstantCache, - repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; γ, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, c2, c3, c4) = cache.tab - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - markfirststage!(nlsolver) - - ##### Step 1 - - # TODO: Add extrapolation for guess - z₁ = zero(u) - nlsolver.z = z₁ - - nlsolver.c = γ - nlsolver.tmp = uprev - z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ##### Step 2 - - # TODO: Add extrapolation for guess - z₂ = zero(u) - nlsolver.z = z₂ - - nlsolver.tmp = uprev + a21 * z₁ - nlsolver.c = c2 - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - # Guess starts from z₁ - z₃ = z₁ - nlsolver.z = z₃ - - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - z₄ = z₃ - nlsolver.z = z₄ - - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Final Step - - u = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - - integrator.fsallast = z₄ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::SFSDIRK4Cache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, nlsolver) = cache - (; tmp) = nlsolver - W = isnewton(nlsolver) ? get_W(nlsolver) : nothing - (; γ, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, c2, c3, c4) = cache.tab - alg = unwrap_alg(integrator, true) - markfirststage!(nlsolver) - ##### Step 1 - - # TODO: Add extrapolation for guess - z₁ .= zero(eltype(z₁)) - nlsolver.z = z₁ - nlsolver.c = γ - nlsolver.tmp = uprev - - # initial step of NLNewton iteration - z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ##### Step 2 - - # TODO: Add extrapolation for guess - z₂ .= zero(eltype(z₂)) - nlsolver.z = z₂ - - @.. broadcast=false tmp=uprev + a21 * z₁ - nlsolver.tmp = tmp - isnewton(nlsolver) && set_new_W!(nlsolver, false) - nlsolver.c = c2 - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - # Guess starts from z₁ - @.. broadcast=false z₃=z₁ - nlsolver.z = z₃ - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - @.. broadcast=false z₄=z₃ - nlsolver.z = z₄ - - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast=false u=uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - - ################################### Finalize - @.. broadcast=false integrator.fsallast=z₄ / dt -end - -@muladd function perform_step!(integrator, cache::SFSDIRK5ConstantCache, - repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; γ, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, c2, c3, c4, c5) = cache.tab - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - markfirststage!(nlsolver) - - ##### Step 1 - - # TODO: Add extrapolation for guess - z₁ = zero(u) - nlsolver.z = z₁ - - nlsolver.c = γ - nlsolver.tmp = uprev - z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ##### Step 2 - - # TODO: Add extrapolation for guess - z₂ = zero(u) - nlsolver.z = z₂ - - nlsolver.tmp = uprev + a21 * z₁ - nlsolver.c = c2 - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - # Guess starts from z₁ - z₃ = z₁ - nlsolver.z = z₃ - - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - z₄ = z₃ - nlsolver.z = z₄ - - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - # Use constant z prediction - z₅ = z₄ - nlsolver.z = z₅ - - nlsolver.tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Final Step - - u = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - - integrator.fsallast = z₅ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::SFSDIRK5Cache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, nlsolver) = cache - (; tmp) = nlsolver - W = isnewton(nlsolver) ? get_W(nlsolver) : nothing - (; γ, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, c2, c3, c4, c5) = cache.tab - alg = unwrap_alg(integrator, true) - markfirststage!(nlsolver) - ##### Step 1 - - # TODO: Add extrapolation for guess - z₁ .= zero(eltype(z₁)) - nlsolver.z = z₁ - nlsolver.c = γ - nlsolver.tmp = uprev - - # initial step of NLNewton iteration - z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ##### Step 2 - - # TODO: Add extrapolation for guess - z₂ .= zero(eltype(z₂)) - nlsolver.z = z₂ - - @.. broadcast=false tmp=uprev + a21 * z₁ - nlsolver.tmp = tmp - isnewton(nlsolver) && set_new_W!(nlsolver, false) - nlsolver.c = c2 - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - # Guess starts from z₁ - @.. broadcast=false z₃=z₁ - nlsolver.z = z₃ - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - @.. broadcast=false z₄=z₃ - nlsolver.z = z₄ - - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - # Use constant z prediction - @.. broadcast=false z₅=z₄ - nlsolver.z = z₅ - - @.. broadcast=false tmp=uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - ################################### Finalize - @.. broadcast=false u=uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - - @.. broadcast=false integrator.fsallast=z₅ / dt -end - -@muladd function perform_step!(integrator, cache::SFSDIRK6ConstantCache, - repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; γ, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76, c2, c3, c4, c5, c6) = cache.tab - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - markfirststage!(nlsolver) - - ##### Step 1 - - # TODO: Add extrapolation for guess - z₁ = zero(u) - nlsolver.z = z₁ - - nlsolver.c = γ - nlsolver.tmp = uprev - z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ##### Step 2 - - # TODO: Add extrapolation for guess - z₂ = zero(u) - nlsolver.z = z₂ - - nlsolver.tmp = uprev + a21 * z₁ - nlsolver.c = c2 - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - # Guess starts from z₁ - z₃ = z₁ - nlsolver.z = z₃ - - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - z₄ = z₃ - nlsolver.z = z₄ - - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - # Use constant z prediction - z₅ = z₄ - nlsolver.z = z₅ - - nlsolver.tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - # Use constant z prediction - z₆ = z₅ - nlsolver.z = z₆ - - nlsolver.tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Final Step - - u = uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - - integrator.fsallast = z₆ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::SFSDIRK6Cache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, nlsolver) = cache - (; tmp) = nlsolver - W = isnewton(nlsolver) ? get_W(nlsolver) : nothing - (; γ, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76, c2, c3, c4, c5, c6) = cache.tab - alg = unwrap_alg(integrator, true) - markfirststage!(nlsolver) - ##### Step 1 - - # TODO: Add extrapolation for guess - z₁ .= zero(eltype(z₁)) - nlsolver.z = z₁ - nlsolver.c = γ - nlsolver.tmp = uprev - - # initial step of NLNewton iteration - z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ##### Step 2 - - # TODO: Add extrapolation for guess - z₂ .= zero(eltype(z₂)) - nlsolver.z = z₂ - - @.. broadcast=false tmp=uprev + a21 * z₁ - nlsolver.tmp = tmp - isnewton(nlsolver) && set_new_W!(nlsolver, false) - nlsolver.c = c2 - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - # Guess starts from z₁ - @.. broadcast=false z₃=z₁ - nlsolver.z = z₃ - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - @.. broadcast=false z₄=z₃ - nlsolver.z = z₄ - - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - # Use constant z prediction - @.. broadcast=false z₅=z₄ - nlsolver.z = z₅ - - @.. broadcast=false tmp=uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - # Use constant z prediction - @.. broadcast=false z₆=z₅ - nlsolver.z = z₆ - - @.. broadcast=false tmp=uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################### Finalize - @.. broadcast=false u=uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + - a76 * z₆ - - @.. broadcast=false integrator.fsallast=z₆ / dt -end - -@muladd function perform_step!(integrator, cache::SFSDIRK7ConstantCache, - repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; γ, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76, a81, a82, a83, a84, a85, a86, a87, c2, c3, c4, c5, c6, c7) = cache.tab - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - markfirststage!(nlsolver) - - ##### Step 1 - - # TODO: Add extrapolation for guess - z₁ = zero(u) - nlsolver.z = z₁ - - nlsolver.c = γ - nlsolver.tmp = uprev - z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ##### Step 2 - - # TODO: Add extrapolation for guess - z₂ = zero(u) - nlsolver.z = z₂ - - nlsolver.tmp = uprev + a21 * z₁ - nlsolver.c = c2 - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - # Guess starts from z₁ - z₃ = z₁ - nlsolver.z = z₃ - - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - z₄ = z₃ - nlsolver.z = z₄ - - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - # Use constant z prediction - z₅ = z₄ - nlsolver.z = z₅ - - nlsolver.tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - # Use constant z prediction - z₆ = z₅ - nlsolver.z = z₆ - - nlsolver.tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - # Use constant z prediction - z₇ = z₆ - nlsolver.z = z₇ - - nlsolver.tmp = uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Final Step - - u = uprev + a81 * z₁ + a82 * z₂ + a83 * z₃ + a84 * z₄ + a85 * z₅ + a86 * z₆ + a87 * z₇ - - integrator.fsallast = z₇ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::SFSDIRK7Cache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, z₇, nlsolver) = cache - (; tmp) = nlsolver - W = isnewton(nlsolver) ? get_W(nlsolver) : nothing - (; γ, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76, a81, a82, a83, a84, a85, a86, a87, c2, c3, c4, c5, c6, c7) = cache.tab - alg = unwrap_alg(integrator, true) - markfirststage!(nlsolver) - ##### Step 1 - - # TODO: Add extrapolation for guess - z₁ .= zero(eltype(z₁)) - nlsolver.z = z₁ - nlsolver.c = γ - nlsolver.tmp = uprev - - # initial step of NLNewton iteration - z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ##### Step 2 - - # TODO: Add extrapolation for guess - z₂ .= zero(eltype(z₂)) - nlsolver.z = z₂ - - @.. broadcast=false tmp=uprev + a21 * z₁ - nlsolver.tmp = tmp - isnewton(nlsolver) && set_new_W!(nlsolver, false) - nlsolver.c = c2 - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - # Guess starts from z₁ - @.. broadcast=false z₃=z₁ - nlsolver.z = z₃ - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - @.. broadcast=false z₄=z₃ - nlsolver.z = z₄ - - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - # Use constant z prediction - @.. broadcast=false z₅=z₄ - nlsolver.z = z₅ - - @.. broadcast=false tmp=uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - # Use constant z prediction - @.. broadcast=false z₆=z₅ - nlsolver.z = z₆ - - @.. broadcast=false tmp=uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - # Use constant z prediction - @.. broadcast=false z₇=z₆ - nlsolver.z = z₇ - - @.. broadcast=false tmp=uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + - a76 * z₆ - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################### Finalize - @.. broadcast=false u=uprev + a81 * z₁ + a82 * z₂ + a83 * z₃ + a84 * z₄ + a85 * z₅ + - a86 * z₆ + a87 * z₇ - - @.. broadcast=false integrator.fsallast=z₇ / dt -end - -@muladd function perform_step!(integrator, cache::SFSDIRK8ConstantCache, - repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; γ, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76, a81, a82, a83, a84, a85, a86, a87, a91, a92, a93, a94, a95, a96, a97, a98, c2, c3, c4, c5, c6, c7, c8) = cache.tab - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - markfirststage!(nlsolver) - - ##### Step 1 - - # TODO: Add extrapolation for guess - z₁ = zero(u) - nlsolver.z = z₁ - - nlsolver.c = γ - nlsolver.tmp = uprev - z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ##### Step 2 - - # TODO: Add extrapolation for guess - z₂ = zero(u) - nlsolver.z = z₂ - - nlsolver.tmp = uprev + a21 * z₁ - nlsolver.c = c2 - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - # Guess starts from z₁ - z₃ = z₁ - nlsolver.z = z₃ - - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - z₄ = z₃ - nlsolver.z = z₄ - - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - # Use constant z prediction - z₅ = z₄ - nlsolver.z = z₅ - - nlsolver.tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - # Use constant z prediction - z₆ = z₅ - nlsolver.z = z₆ - - nlsolver.tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - # Use constant z prediction - z₇ = z₆ - nlsolver.z = z₇ - - nlsolver.tmp = uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 8 - - # Use constant z prediction - z₈ = z₇ - nlsolver.z = z₈ - - nlsolver.tmp = uprev + a81 * z₁ + a82 * z₂ + a83 * z₃ + a84 * z₄ + a85 * z₅ + a86 * z₆ + - a87 * z₇ - nlsolver.c = c8 - z₈ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Final Step - - u = uprev + a91 * z₁ + a92 * z₂ + a93 * z₃ + a94 * z₄ + a95 * z₅ + a96 * z₆ + a97 * z₇ + - a98 * z₈ - - integrator.fsallast = z₈ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::SFSDIRK8Cache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, z₇, z₈, nlsolver) = cache - (; tmp) = nlsolver - W = isnewton(nlsolver) ? get_W(nlsolver) : nothing - (; γ, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76, a81, a82, a83, a84, a85, a86, a87, a91, a92, a93, a94, a95, a96, a97, a98, c2, c3, c4, c5, c6, c7, c8) = cache.tab - alg = unwrap_alg(integrator, true) - markfirststage!(nlsolver) - ##### Step 1 - - # TODO: Add extrapolation for guess - z₁ .= zero(eltype(z₁)) - nlsolver.z = z₁ - nlsolver.c = γ - nlsolver.tmp = uprev - - # initial step of NLNewton iteration - z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ##### Step 2 - - # TODO: Add extrapolation for guess - z₂ .= zero(eltype(z₂)) - nlsolver.z = z₂ - - @.. broadcast=false tmp=uprev + a21 * z₁ - nlsolver.tmp = tmp - isnewton(nlsolver) && set_new_W!(nlsolver, false) - nlsolver.c = c2 - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - # Guess starts from z₁ - @.. broadcast=false z₃=z₁ - nlsolver.z = z₃ - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - @.. broadcast=false z₄=z₃ - nlsolver.z = z₄ - - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - # Use constant z prediction - @.. broadcast=false z₅=z₄ - nlsolver.z = z₅ - - @.. broadcast=false tmp=uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - # Use constant z prediction - @.. broadcast=false z₆=z₅ - nlsolver.z = z₆ - - @.. broadcast=false tmp=uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - # Use constant z prediction - @.. broadcast=false z₇=z₆ - nlsolver.z = z₇ - - @.. broadcast=false tmp=uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + - a76 * z₆ - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 8 - - # Use constant z prediction - @.. broadcast=false z₈=z₇ - nlsolver.z = z₈ - - @.. broadcast=false tmp=uprev + a81 * z₁ + a82 * z₂ + a83 * z₃ + a84 * z₄ + a85 * z₅ + - a86 * z₆ + a87 * z₇ - nlsolver.c = c8 - z₈ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################### Finalize - @.. broadcast=false u=uprev + a91 * z₁ + a92 * z₂ + a93 * z₃ + a94 * z₄ + a95 * z₅ + - a96 * z₆ + a97 * z₇ + a98 * z₈ - - @.. broadcast=false integrator.fsallast=z₈ / dt -end - -@muladd function perform_step!(integrator, cache::Hairer4ConstantCache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; γ, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, c2, c3, c4) = cache.tab - (; α21, α31, α32, α41, α43) = cache.tab - (; bhat1, bhat2, bhat3, bhat4, btilde1, btilde2, btilde3, btilde4, btilde5) = cache.tab - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - markfirststage!(nlsolver) - - # TODO: Add extrapolation for guess - z₁ = zero(u) - nlsolver.z, nlsolver.tmp = z₁, uprev - nlsolver.c = γ - z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ##### Step 2 - - z₂ = α21 * z₁ - nlsolver.z = z₂ - nlsolver.tmp = uprev + a21 * z₁ - nlsolver.c = c2 - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - z₃ = α31 * z₁ + α32 * z₂ - nlsolver.z = z₃ - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - z₄ = α41 * z₁ + α43 * z₃ - nlsolver.z = z₄ - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - # Use yhat2 for prediction - z₅ = bhat1 * z₁ + bhat2 * z₂ + bhat3 * z₃ + bhat4 * z₄ - nlsolver.z = z₅ - nlsolver.tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = 1 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₅ - - ################################### Finalize - - if integrator.opts.adaptive - tmp = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ - if isnewton(nlsolver) && alg.smooth_est # From Shampine - integrator.stats.nsolve += 1 - est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) - else - est = tmp - end - atmp = calculate_residuals(est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.fsallast = z₅ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::Hairer4Cache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, atmp, nlsolver) = cache - (; tmp) = nlsolver - (; γ, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, c2, c3, c4) = cache.tab - (; α21, α31, α32, α41, α43) = cache.tab - (; bhat1, bhat2, bhat3, bhat4, btilde1, btilde2, btilde3, btilde4, btilde5) = cache.tab - alg = unwrap_alg(integrator, true) - markfirststage!(nlsolver) - - # initial guess - if integrator.success_iter > 0 && !integrator.reeval_fsal && - alg.extrapolant == :interpolant - current_extrapolant!(u, t + dt, integrator) - @.. broadcast=false z₁=u - uprev - elseif alg.extrapolant == :linear - @.. broadcast=false z₁=dt * integrator.fsalfirst - else - z₁ .= zero(eltype(z₁)) - end - nlsolver.z = z₁ - nlsolver.tmp = uprev - - ##### Step 1 - - nlsolver.c = γ - z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ##### Step 2 - - @.. broadcast=false z₂=α21 * z₁ - nlsolver.z = z₂ - @.. broadcast=false tmp=uprev + a21 * z₁ - nlsolver.tmp = tmp - nlsolver.c = c2 - isnewton(nlsolver) && set_new_W!(nlsolver, false) - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - @.. broadcast=false z₃=α31 * z₁ + α32 * z₂ - nlsolver.z = z₃ - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - @.. broadcast=false z₄=α41 * z₁ + α43 * z₃ - nlsolver.z = z₄ - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - # Use yhat prediction - @.. broadcast=false z₅=bhat1 * z₁ + bhat2 * z₂ + bhat3 * z₃ + bhat4 * z₄ - nlsolver.z = z₅ - @.. broadcast=false tmp=uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = 1 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast=false u=tmp + γ * z₅ - - ################################### Finalize - - if integrator.opts.adaptive - @.. broadcast=false tmp=btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + - btilde5 * z₅ - if alg.smooth_est && isnewton(nlsolver) # From Shampine - est = nlsolver.cache.dz - linres = dolinsolve(integrator, nlsolver.cache.linsolve; b = _vec(tmp), - linu = _vec(est)) - - integrator.stats.nsolve += 1 - else - est = tmp - end - calculate_residuals!(atmp, est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - @.. broadcast=false integrator.fsallast=z₅ / dt -end - -@muladd function perform_step!(integrator, cache::ESDIRK54I8L2SAConstantCache, - repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; γ, - a31, a32, - a41, a42, a43, - a51, a52, a53, a54, - a61, a62, a63, a64, a65, - a71, a72, a73, a74, a75, a76, - a81, a82, a83, a84, a85, a86, a87, - c3, c4, c5, c6, c7, - btilde1, btilde2, btilde3, btilde4, btilde5, btilde6, btilde7, btilde8) = cache.tab - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - markfirststage!(nlsolver) - - # TODO: Add extrapolation for guess - - ##### Step 1 - - z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation choice - nlsolver.z = z₂ = zero(z₁) - - nlsolver.tmp = uprev + γ * z₁ - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - nlsolver.z = z₃ = zero(z₂) - - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - nlsolver.z = z₄ = zero(z₃) - - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - nlsolver.z = z₅ = zero(z₄) - - nlsolver.tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = z₆ = zero(z₅) - - nlsolver.tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - nlsolver.z = z₇ = zero(z₆) - - nlsolver.tmp = uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 8 - - nlsolver.z = z₈ = zero(z₇) - - nlsolver.tmp = uprev + a81 * z₁ + a82 * z₂ + a83 * z₃ + a84 * z₄ + a85 * z₅ + a86 * z₆ + - a87 * z₇ - nlsolver.c = 1 - z₈ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₈ - - ################################### Finalize - - if integrator.opts.adaptive - est = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ + btilde8 * z₈ - atmp = calculate_residuals(est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.fsallast = z₈ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u - return -end - -@muladd function perform_step!(integrator, cache::ESDIRK54I8L2SACache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, z₇, z₈, atmp, nlsolver) = cache - (; tmp) = nlsolver - (; γ, - a31, a32, - a41, a42, a43, - a51, a52, a53, a54, - a61, a62, a63, a64, a65, - a71, a72, a73, a74, a75, a76, - a81, a82, a83, a84, a85, a86, a87, - c3, c4, c5, c6, c7, - btilde1, btilde2, btilde3, btilde4, btilde5, btilde6, btilde7, btilde8) = cache.tab - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - markfirststage!(nlsolver) - - ##### Step 1 - - @.. broadcast=false z₁=dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation for guess - z₂ .= zero(eltype(u)) - nlsolver.z = z₂ - - @.. broadcast=false tmp=uprev + γ * z₁ - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - nlsolver.z = fill!(z₃, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - nlsolver.z = fill!(z₄, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - nlsolver.z = fill!(z₅, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = fill!(z₆, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - nlsolver.z = fill!(z₇, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + - a76 * z₆ - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 8 - - nlsolver.z = fill!(z₈, zero(eltype(u))) - - @.. broadcast=false nlsolver.tmp=uprev + a81 * z₁ + a82 * z₂ + a83 * z₃ + a84 * z₄ + - a85 * z₅ + a86 * z₆ + a87 * z₇ - nlsolver.c = oneunit(nlsolver.c) - z₈ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast=false u=tmp + γ * z₈ - - ################################### Finalize - - if integrator.opts.adaptive - @.. broadcast=false tmp=btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + - btilde5 * z₅ + btilde6 * z₆ + btilde7 * z₇ + btilde8 * z₈ - calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - @.. broadcast=false integrator.fsallast=z₈ / dt - return -end - -@muladd function perform_step!(integrator, cache::ESDIRK436L2SA2ConstantCache, - repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; γ, - a31, a32, - a41, a42, a43, - a51, a52, a53, a54, - a61, a62, a63, a64, a65, - c3, c4, c5, c6, - btilde1, btilde2, btilde3, btilde4, btilde5, btilde6) = cache.tab - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - markfirststage!(nlsolver) - - # TODO: Add extrapolation for guess - - ##### Step 1 - - z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation choice - nlsolver.z = z₂ = zero(z₁) - - nlsolver.tmp = uprev + γ * z₁ - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - nlsolver.z = z₃ = zero(z₂) - - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - nlsolver.z = z₄ = zero(z₃) - - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - nlsolver.z = z₅ = zero(z₄) - - nlsolver.tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = z₆ = zero(z₅) - - nlsolver.tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₆ - - ################################### Finalize - - if integrator.opts.adaptive - est = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ - atmp = calculate_residuals(est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.fsallast = z₆ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u - return -end - -@muladd function perform_step!(integrator, cache::ESDIRK436L2SA2Cache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, atmp, nlsolver) = cache - (; tmp) = nlsolver - (; γ, - a31, a32, - a41, a42, a43, - a51, a52, a53, a54, - a61, a62, a63, a64, a65, - c3, c4, c5, c6, - btilde1, btilde2, btilde3, btilde4, btilde5, btilde6) = cache.tab - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - markfirststage!(nlsolver) - - ##### Step 1 - - @.. broadcast=false z₁=dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation for guess - z₂ .= zero(eltype(u)) - nlsolver.z = z₂ - - @.. broadcast=false tmp=uprev + γ * z₁ - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - nlsolver.z = fill!(z₃, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - nlsolver.z = fill!(z₄, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - nlsolver.z = fill!(z₅, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = fill!(z₆, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast=false u=tmp + γ * z₆ - - ################################### Finalize - - if integrator.opts.adaptive - @.. broadcast=false tmp=btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + - btilde5 * z₅ + btilde6 * z₆ - calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - @.. broadcast=false integrator.fsallast=z₆ / dt - return -end - -@muladd function perform_step!(integrator, cache::ESDIRK437L2SAConstantCache, - repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; γ, - a31, a32, - a41, a42, a43, - a51, a52, a53, a54, - a61, a62, a63, a64, a65, - a71, a72, a73, a74, a75, a76, - c3, c4, c5, c6, c7, - btilde1, btilde2, btilde3, btilde4, btilde5, btilde6, btilde7) = cache.tab - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - markfirststage!(nlsolver) - - # TODO: Add extrapolation for guess - - ##### Step 1 - - z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation choice - nlsolver.z = z₂ = zero(z₁) - - nlsolver.tmp = uprev + γ * z₁ - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - nlsolver.z = z₃ = zero(z₂) - - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - nlsolver.z = z₄ = zero(z₃) - - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - nlsolver.z = z₅ = zero(z₄) - - nlsolver.tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = z₆ = zero(z₅) - - nlsolver.tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - nlsolver.z = z₇ = zero(z₆) - - nlsolver.tmp = uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₇ - - ################################### Finalize - - if integrator.opts.adaptive - est = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ - atmp = calculate_residuals(est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.fsallast = z₇ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u - return -end - -@muladd function perform_step!(integrator, cache::ESDIRK437L2SACache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, z₇, atmp, nlsolver) = cache - (; tmp) = nlsolver - (; γ, - a31, a32, - a41, a42, a43, - a51, a52, a53, a54, - a61, a62, a63, a64, a65, - a71, a72, a73, a74, a75, a76, - c3, c4, c5, c6, c7, - btilde1, btilde2, btilde3, btilde4, btilde5, btilde6, btilde7) = cache.tab - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - markfirststage!(nlsolver) - - ##### Step 1 - - @.. broadcast=false z₁=dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation for guess - z₂ .= zero(eltype(u)) - nlsolver.z = z₂ - - @.. broadcast=false tmp=uprev + γ * z₁ - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - nlsolver.z = fill!(z₃, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - nlsolver.z = fill!(z₄, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - nlsolver.z = fill!(z₅, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = fill!(z₆, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = fill!(z₇, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + - a76 * z₆ - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast=false u=tmp + γ * z₇ - - ################################### Finalize - - if integrator.opts.adaptive - @.. broadcast=false tmp=btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + - btilde5 * z₅ + btilde6 * z₆ + btilde7 * z₇ - calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - @.. broadcast=false integrator.fsallast=z₇ / dt - return -end - -@muladd function perform_step!(integrator, cache::ESDIRK547L2SA2ConstantCache, - repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; γ, - a31, a32, - a41, a42, a43, - a51, a52, a53, a54, - a61, a62, a63, a64, a65, - a71, a72, a73, a74, a75, a76, - c3, c4, c5, c6, c7, - btilde1, btilde2, btilde3, btilde4, btilde5, btilde6, btilde7) = cache.tab - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - markfirststage!(nlsolver) - - # TODO: Add extrapolation for guess - - ##### Step 1 - - z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation choice - nlsolver.z = z₂ = zero(z₁) - - nlsolver.tmp = uprev + γ * z₁ - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - nlsolver.z = z₃ = zero(z₂) - - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - nlsolver.z = z₄ = zero(z₃) - - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - nlsolver.z = z₅ = zero(z₄) - - nlsolver.tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = z₆ = zero(z₅) - - nlsolver.tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - nlsolver.z = z₇ = zero(z₆) - - nlsolver.tmp = uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₇ - - ################################### Finalize - - if integrator.opts.adaptive - est = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ - atmp = calculate_residuals(est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.fsallast = z₇ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u - return -end - -@muladd function perform_step!(integrator, cache::ESDIRK547L2SA2Cache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, z₇, atmp, nlsolver) = cache - (; tmp) = nlsolver - (; γ, - a31, a32, - a41, a42, a43, - a51, a52, a53, a54, - a61, a62, a63, a64, a65, - a71, a72, a73, a74, a75, a76, - c3, c4, c5, c6, c7, - btilde1, btilde2, btilde3, btilde4, btilde5, btilde6, btilde7) = cache.tab - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - markfirststage!(nlsolver) - - ##### Step 1 - - @.. broadcast=false z₁=dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation for guess - z₂ .= zero(eltype(u)) - nlsolver.z = z₂ - - @.. broadcast=false tmp=uprev + γ * z₁ - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - nlsolver.z = fill!(z₃, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - nlsolver.z = fill!(z₄, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - nlsolver.z = fill!(z₅, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = fill!(z₆, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - nlsolver.z = fill!(z₇, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + - a76 * z₆ - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast=false u=tmp + γ * z₇ - - ################################### Finalize - - if integrator.opts.adaptive - @.. broadcast=false tmp=btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + - btilde5 * z₅ + btilde6 * z₆ + btilde7 * z₇ - calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - @.. broadcast=false integrator.fsallast=z₇ / dt - return -end - -@muladd function perform_step!(integrator, cache::ESDIRK659L2SAConstantCache, - repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; γ, - a31, a32, - a41, a42, a43, - a51, a52, a53, a54, - a61, a62, a63, a64, a65, - a71, a72, a73, a74, a75, a76, - a81, a82, a83, a84, a85, a86, a87, - a94, a95, a96, a97, a98, - c3, c4, c5, c6, c7, c8, c9, - btilde1, btilde2, btilde3, btilde4, btilde5, btilde6, btilde7, btilde8, btilde9) = cache.tab - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - markfirststage!(nlsolver) - - # TODO: Add extrapolation for guess - - ##### Step 1 - - z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation choice - nlsolver.z = z₂ = zero(z₁) - - nlsolver.tmp = uprev + γ * z₁ - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - nlsolver.z = z₃ = zero(z₂) - - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - nlsolver.z = z₄ = zero(z₃) - - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - nlsolver.z = z₅ = zero(z₄) - - nlsolver.tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = z₆ = zero(z₅) - - nlsolver.tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - nlsolver.z = z₇ = zero(z₆) - - nlsolver.tmp = uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 8 - nlsolver.z = z₈ = zero(z₇) - - nlsolver.tmp = uprev + a81 * z₁ + a82 * z₂ + a83 * z₃ + a84 * z₄ + a85 * z₅ + a86 * z₆ + - a87 * z₇ - nlsolver.c = c8 - z₈ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 9 - nlsolver.z = z₉ = zero(z₈) - - nlsolver.tmp = uprev + a94 * z₄ + a95 * z₅ + a96 * z₆ + a97 * z₇ + a98 * z₈ - nlsolver.c = c9 - z₉ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₉ - - ################################### Finalize - - if integrator.opts.adaptive - est = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ + btilde8 * z₈ + btilde9 * z₉ - atmp = calculate_residuals(est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.fsallast = z₉ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u - return -end - -@muladd function perform_step!(integrator, cache::ESDIRK659L2SACache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, z₇, z₈, z₉, atmp, nlsolver) = cache - (; tmp) = nlsolver - (; γ, - a31, a32, - a41, a42, a43, - a51, a52, a53, a54, - a61, a62, a63, a64, a65, - a71, a72, a73, a74, a75, a76, - a81, a82, a83, a84, a85, a86, a87, - a94, a95, a96, a97, a98, - c3, c4, c5, c6, c7, c8, c9, - btilde1, btilde2, btilde3, btilde4, btilde5, btilde6, btilde7, btilde8, btilde9) = cache.tab - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - markfirststage!(nlsolver) - - ##### Step 1 - - @.. broadcast=false z₁=dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation for guess - z₂ .= zero(eltype(u)) - nlsolver.z = z₂ - - @.. broadcast=false tmp=uprev + γ * z₁ - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - nlsolver.z = fill!(z₃, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - nlsolver.z = fill!(z₄, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - nlsolver.z = fill!(z₅, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = fill!(z₆, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - nlsolver.z = fill!(z₇, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + - a76 * z₆ - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 8 - - nlsolver.z = fill!(z₈, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a81 * z₁ + a82 * z₂ + a83 * z₃ + a84 * z₄ + a85 * z₅ + - a86 * z₆ + a87 * z₇ - nlsolver.c = c8 - z₈ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 9 - - nlsolver.z = fill!(z₉, zero(eltype(u))) - - @.. broadcast=false tmp=uprev + a94 * z₄ + a95 * z₅ + a96 * z₆ + a97 * z₇ + a98 * z₈ - nlsolver.c = c9 - z₉ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast=false u=tmp + γ * z₉ - ################################### Finalize - - if integrator.opts.adaptive - @.. broadcast=false tmp=btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + - btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ + btilde8 * z₈ + btilde9 * z₉ - calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - @.. broadcast=false integrator.fsallast=z₉ / dt - return -end diff --git a/lib/OrdinaryDiffEqSDIRK/src/tableau_utils.jl b/lib/OrdinaryDiffEqSDIRK/src/tableau_utils.jl new file mode 100644 index 0000000000..82b8ace0cb --- /dev/null +++ b/lib/OrdinaryDiffEqSDIRK/src/tableau_utils.jl @@ -0,0 +1,71 @@ +# Tableau-based utility functions for SDIRK methods + +@inline function compute_sdirk_stage!(integrator, cache, nlsolver, stage::Int, + z_prev, stage_coeffs, c_val, tmp_val) + nlsolver.z = z_prev + nlsolver.tmp = tmp_val + nlsolver.c = c_val + z_new = nlsolve!(nlsolver, integrator, cache, false) + nlsolvefail(nlsolver) && return nothing + z_new +end + +@inline function compute_stage_constantcache!(integrator, cache, stage::Int, + prev_z, coeffs, c_val, base_tmp) + nlsolver = cache.nlsolver + nlsolver.z = prev_z + nlsolver.tmp = base_tmp + nlsolver.c = c_val + z = nlsolve!(nlsolver, integrator, cache, false) + nlsolvefail(nlsolver) && return nothing + z +end + +@inline function compute_stage_mutablecache!(integrator, cache, stage::Int, + prev_z, coeffs, c_val, base_tmp) + nlsolver = cache.nlsolver + nlsolver.z = prev_z + nlsolver.tmp = base_tmp + nlsolver.c = c_val + z = nlsolve!(nlsolver, integrator, cache, false) + nlsolvefail(nlsolver) && return nothing + isnewton(nlsolver) && set_new_W!(nlsolver, false) + z +end + +# generic error estimation for embedded methods +@inline function compute_embedded_error!(integrator, cache, btilde_coeffs, z_stages) + if integrator.opts.adaptive + tmp = sum(btilde_coeffs[i] * z_stages[i] for i in eachindex(z_stages)) + alg = unwrap_alg(integrator, true) + nlsolver = cache.nlsolver + + if isnewton(nlsolver) && alg.smooth_est + integrator.stats.nsolve += 1 + if hasfield(typeof(cache), :atmp) + est = cache.atmp + linres = dolinsolve(integrator, nlsolver.cache.linsolve; b = _vec(tmp), + linu = _vec(est)) + else + est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) + end + else + est = tmp + end + + if hasfield(typeof(cache), :atmp) + calculate_residuals!(cache.atmp, est, integrator.uprev, integrator.u, + integrator.opts.abstol, integrator.opts.reltol, + integrator.opts.internalnorm, integrator.t) + integrator.EEst = integrator.opts.internalnorm(cache.atmp, integrator.t) + else + atmp = calculate_residuals(est, integrator.uprev, integrator.u, + integrator.opts.abstol, integrator.opts.reltol, + integrator.opts.internalnorm, integrator.t) + integrator.EEst = integrator.opts.internalnorm(atmp, integrator.t) + end + end +end + + + diff --git a/lib/OrdinaryDiffEqSDIRK/src/unified_sdirk_tableaus.jl b/lib/OrdinaryDiffEqSDIRK/src/unified_sdirk_tableaus.jl new file mode 100644 index 0000000000..b0b232149e --- /dev/null +++ b/lib/OrdinaryDiffEqSDIRK/src/unified_sdirk_tableaus.jl @@ -0,0 +1,757 @@ +using StaticArrays + +abstract type AbstractTableau{T} end + +struct SDIRKTableau{T, T2, S, hasEmbedded, hasAdditiveSplitting, hasExplicit, hasPred} <: AbstractTableau{T} + A::SMatrix{S, S, T} + b::SVector{S, T} + c::SVector{S, T2} + b_embed::Union{SVector{S, T}, Nothing} + γ::T + order::Int + embedded_order::Int + is_fsal::Bool + is_stiffly_accurate::Bool + is_A_stable::Bool + is_L_stable::Bool + predictor_type::Symbol + A_explicit::Union{SMatrix{S, S, T}, Nothing} + b_explicit::Union{SVector{S, T}, Nothing} + c_explicit::Union{SVector{S, T2}, Nothing} + α_pred::Union{SMatrix{S, S, T2}, Nothing} + has_spice_error::Bool +end + +function SDIRKTableau(A::SMatrix{S, S, T}, b::SVector{S, T}, c::SVector{S, T2}, γ::T, + order::Int; b_embed=nothing, embedded_order=0, + is_fsal=false, is_stiffly_accurate=false, + is_A_stable=true, is_L_stable=false, + predictor_type=:default, has_additive_splitting=false, + A_explicit=nothing, b_explicit=nothing, c_explicit=nothing, + α_pred=nothing, has_spice_error=false) where {S, T, T2} + + hasEmbedded = b_embed !== nothing + hasAdditiveSplitting = has_additive_splitting + hasExplicit = A_explicit !== nothing + hasPred = α_pred !== nothing + SDIRKTableau{T, T2, S, hasEmbedded, hasAdditiveSplitting, hasExplicit, hasPred}( + A, b, c, b_embed, γ, order, embedded_order, + is_fsal, is_stiffly_accurate, is_A_stable, + is_L_stable, predictor_type, + A_explicit, b_explicit, c_explicit, α_pred, has_spice_error) +end + +function TRBDF2Tableau_unified(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γ = T2(2 - sqrt(2)) + d = T(1 - sqrt(2) / 2) + ω = T(sqrt(2) / 4) + + A = @SMatrix [0 0 0; + d d 0; + ω ω d] + + b = @SVector [ω, ω, d] + c = @SVector [0, γ, 1] + + b_embed = @SVector [(1-ω)/3, (3*ω+1)/3, d/3] + + α1 = T2(-sqrt(2) / 2) + α2 = T2(1 + sqrt(2) / 2) + α_pred = @SMatrix T2[ + 0 0 0; + 0 0 0; + α1 α2 0 + ] + + SDIRKTableau(A, b, c, d, 2; + b_embed=b_embed, embedded_order=3, + is_fsal=false, is_stiffly_accurate=true, + is_A_stable=true, is_L_stable=true, + predictor_type=:trbdf2_special, + α_pred=α_pred) +end + +function ImplicitEulerTableau(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + A = @SMatrix [T(1.0)] + b = @SVector [T(1.0)] + c = @SVector [T2(1.0)] + γ = T(1.0) + + SDIRKTableau(A, b, c, γ, 1; + is_fsal=false, is_stiffly_accurate=true, + is_A_stable=true, is_L_stable=true, + predictor_type=:trivial, + has_spice_error=true) +end + +function ImplicitMidpointTableau(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(0.5) + γc = T2(0.5) + A = @SMatrix [γT] + b = @SVector [T(1.0)] + c = @SVector [γc] + + SDIRKTableau(A, b, c, γT, 2; + is_fsal=false, is_stiffly_accurate=false, + is_A_stable=true, is_L_stable=false, + predictor_type=:trivial) +end + +function TrapezoidTableau(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(0.5) + A = @SMatrix [T(0) T(0); + T(0.5) T(0.5)] + b = @SVector [T(0.5), T(0.5)] + c = @SVector [T2(0.0), T2(1.0)] + + SDIRKTableau(A, b, c, γT, 2; + is_fsal=false, is_stiffly_accurate=false, + is_A_stable=true, is_L_stable=false, + predictor_type=:default, + has_spice_error=true) +end + +function SDIRK2Tableau(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(1 - 1/sqrt(2)) + γc = T2(1 - 1/sqrt(2)) + A = @SMatrix [γT T(0); + T(1)-γT γT] + b = @SVector [T(1)-γT, γT] + c = @SVector [γc, T2(1)] + + SDIRKTableau(A, b, c, γT, 2; + is_fsal=false, is_stiffly_accurate=true, + is_A_stable=true, is_L_stable=false, + predictor_type=:default) +end + +function SSPSDIRK2Tableau(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(1 - 1/sqrt(2)) + γc = T2(1 - 1/sqrt(2)) + A = @SMatrix [γT T(0); + T(1)-T(2)*γT γT] + b = @SVector [T(0.5), T(0.5)] + c = @SVector [γc, T2(1)-γc] + + SDIRKTableau(A, b, c, γT, 2; + is_fsal=false, is_stiffly_accurate=false, + is_A_stable=true, is_L_stable=false, + predictor_type=:default) +end + +function Kvaerno3Tableau_unified(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(0.4358665215) + γc = T2(0.4358665215) + + A = @SMatrix [γT 0 0 0; + T(0.490563388419108) γT 0 0; + T(0.073570090080892) 0 γT 0; + T(0.308809969973036) T(1.490563388254106) -T(1.235239879727145) γT] + + b = @SVector [T(0.490563388419108), T(0.073570090080892), T(0.4358665215), T(0.0)] + c = @SVector [γc, 2γc, T2(1), T2(1)] + + b_embed = b - A[4, :] + + # Build Hermite-style predictor coefficients for stage 3 from z₁ and z₂ + # θ = c3/c2 over interval [0,c2] + c2 = 2γc + θ = c[3] / c2 + α31 = ((1 + (-4θ + 3θ^2)) + (6θ * (1 - θ) / c2) * γc) + α32 = ((-2θ + 3θ^2) + (6θ * (1 - θ) / c2) * γc) + α_pred = @SMatrix T2[ + 0 0 0 0; + 0 0 0 0; + α31 α32 0 0; + 0 0 0 0 + ] + + SDIRKTableau(A, b, c, γT, 3; + b_embed=b_embed, embedded_order=2, + is_fsal=false, is_stiffly_accurate=false, + is_A_stable=true, is_L_stable=false, + predictor_type=:hermite, + α_pred=α_pred) +end + +function KenCarp3Tableau_unified(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(0.435866521508459) + γc = T2(0.435866521508459) + + A = @SMatrix [γT 0 0 0; + T(0.2576482460664272) γT 0 0; + -T(0.09351476757488625) 0 γT 0; + T(0.18764102434672383) -T(0.595297473576955) T(0.9717899277217721) γT] + + b = @SVector [T(2756255671327//12835298489170), + -T(10771552573575//22201958757719), + T(9247589265047//10645013368117), + T(2193209047091//5459859503100)] + + c = @SVector [γc, 2γc, T2(0.6), T2(1)] + + b_embed = A[4, :] + + A_explicit = @SMatrix [0 0 0 0; + T(0.871733043016918) 0 0 0; + T(0.5275890119763004) T(0.0724109880236996) 0 0; + T(0.3990960076760701) -T(0.4375576546135194) T(1.0384616469374492) 0] + + b_explicit = @SVector [T(0.18764102434672383), + -T(0.595297473576955), + T(0.9717899277217721), + T(0.435866521508459)] + + c_explicit = @SVector [T2(0), γc, T2(0.6), T2(1)] + + # Build Hermite-style predictor coefficients for stages 3 and 4 + c2 = 2γc + θ = c[3] / c2 + α31 = ((1 + (-4θ + 3θ^2)) + (6θ * (1 - θ) / c2) * γc) + α32 = ((-2θ + 3θ^2) + (6θ * (1 - θ) / c2) * γc) + θ4 = c[4] / c2 # == 1 + α41 = ((1 + (-4θ4 + 3θ4^2)) + (6θ4 * (1 - θ4) / c2) * γc) + α42 = ((-2θ4 + 3θ4^2) + (6θ4 * (1 - θ4) / c2) * γc) + α_pred = @SMatrix T2[ + 0 0 0 0; + 0 0 0 0; + α31 α32 0 0; + α41 α42 0 0 + ] + + SDIRKTableau(A, b, c, γT, 3; + b_embed=b_embed, embedded_order=2, + is_fsal=false, is_stiffly_accurate=false, + is_A_stable=true, is_L_stable=false, + predictor_type=:kencarp_additive, + has_additive_splitting=true, + A_explicit=A_explicit, b_explicit=b_explicit, c_explicit=c_explicit, + α_pred=α_pred) +end + +function Kvaerno4Tableau_unified(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(0.4358665215) + γc = T2(0.4358665215) + + A = @SMatrix [γT 0 0 0 0; + T(0.490563388419108) γT 0 0 0; + T(0.073570090080892) 0 γT 0 0; + T(0.308809969973036) T(1.490563388254106) -T(1.235239879727145) γT 0; + T(0.490563388419108) T(0.073570090080892) T(0.4358665215) T(0.0) γT] + + b = @SVector [T(0.490563388419108), T(0.073570090080892), T(0.4358665215), T(0.0), T(0.0)] + c = @SVector [γc, 2γc, T2(1), T2(1), T2(1)] + + b_embed = A[5, :] + + SDIRKTableau(A, b, c, γT, 4; + b_embed=b_embed, embedded_order=3, + is_fsal=false, is_stiffly_accurate=true, + is_A_stable=true, is_L_stable=true, + predictor_type=:hermite) +end + +function Kvaerno5Tableau_unified(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(0.26) + γc = T2(0.26) + + A = @SMatrix [γT 0 0 0 0 0; + T(0.13) γT 0 0 0 0; + T(0.84079895052208) T(0.07920104947792) γT 0 0 0; + T(0.619100897516618) T(0.066593016584582) T(0.054305985899400) γT 0 0; + T(0.258023287184119) T(0.097741417057132) T(0.464732297848610) T(1.179502539939939) γT 0; + T(0.544974750228521) T(0.212765981366776) T(0.164488906111538) T(0.077770561901165) T(0.0) γT] + + b = @SVector [T(0.544974750228521), T(0.212765981366776), T(0.164488906111538), T(0.077770561901165), T(0.0), T(0.0)] + c = @SVector [γc, T2(0.39), T2(1.0), T2(0.74), T2(1.0), T2(1.0)] + + b_embed = A[6, :] + + SDIRKTableau(A, b, c, γT, 5; + b_embed=b_embed, embedded_order=4, + is_fsal=false, is_stiffly_accurate=true, + is_A_stable=true, is_L_stable=true, + predictor_type=:hermite) +end + +function KenCarp4Tableau_unified(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(1//4) + γc = T2(1//4) + + A = @SMatrix [γT 0 0 0 0 0; + T(1//2) γT 0 0 0 0; + T(83//250) T(-13//250) γT 0 0 0; + T(31//50) T(-11//20) T(11//20) γT 0 0; + T(17//20) T(-1//4) T(1//4) T(1//2) γT 0; + T(755//1728) T(755//1728) T(-1640//1728) T(1640//1728) T(1//4) γT] + + b = @SVector [T(755//1728), T(755//1728), T(-1640//1728), T(1640//1728), T(1//4), T(0)] + c = @SVector [γc, T2(3//4), T2(11//20), T2(1//2), T2(1), T2(1)] + + b_embed = A[6, :] + + A_explicit = @SMatrix [0 0 0 0 0 0; + T(1//2) 0 0 0 0 0; + T(13861//62500) T(6889//62500) 0 0 0 0; + T(-116923316275//2393684061468) T(-2731218467317//15368042101831) T(9408046702089//11113171139209) 0 0 0; + T(-451086348788//2902428689909) T(-2682348792572//7519795681897) T(12662868775082//11960479115383) T(3355817975965//11060851509271) 0 0; + T(647845179188//3216320057751) T(73281519250//8382639484533) T(552539513391//3454668386233) T(3354512671639//8306763924573) T(4040//17871) 0] + + b_explicit = @SVector [T(82889//524892), T(0), T(15625//83664), T(69875//102672), T(-2260//8211), T(1//4)] + c_explicit = @SVector [0, T2(1//2), T2(83//250), T2(31//50), T2(17//20), T2(1)] + + SDIRKTableau(A, b, c, γT, 4; + b_embed=b_embed, embedded_order=3, + is_fsal=false, is_stiffly_accurate=false, + is_A_stable=true, is_L_stable=false, + predictor_type=:kencarp_additive, + has_additive_splitting=true, + A_explicit=A_explicit, b_explicit=b_explicit, c_explicit=c_explicit) +end + +function KenCarp47Tableau_unified(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(0.1496590219993) + γc = T2(0.1496590219993) + + A = @SMatrix [γT 0 0 0 0 0 0; + T(0.7481896206814) γT 0 0 0 0 0; + T(0.2068513093527) T(0.2931486906473) γT 0 0 0 0; + T(0.7581896206812) T(-0.2581896206812) T(0.25) γT 0 0 0; + T(0.8765725810946) T(-0.3765725810946) T(0.25) T(0.25) γT 0 0; + T(1.6274999742127) T(-1.1274999742127) T(0.25) T(0.25) T(0.0) γT 0; + T(1.6274999742127) T(-1.1274999742127) T(0.25) T(0.25) T(0.0) T(0.0) γT] + + b = @SVector [T(1.6274999742127), T(-1.1274999742127), T(0.25), T(0.25), T(0.0), T(0.0), T(0.0)] + c = @SVector [γc, T2(0.8978486300007), T2(0.6496590219993), T2(0.7496590219993), T2(1.1262315616939), T2(1.0), T2(1.0)] + + b_embed = A[7, :] + + A_explicit = zeros(SMatrix{7,7,T}) + b_explicit = zeros(SVector{7,T}) + c_explicit = c + + SDIRKTableau(A, b, c, γT, 4; + b_embed=b_embed, embedded_order=3, + is_fsal=false, is_stiffly_accurate=true, + is_A_stable=true, is_L_stable=true, + predictor_type=:kencarp_additive, + has_additive_splitting=true, + A_explicit=A_explicit, b_explicit=b_explicit, c_explicit=c_explicit) +end + +function KenCarp5Tableau_unified(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(0.2113248654051871) + γc = T2(0.2113248654051871) + + A = @SMatrix [γT 0 0 0 0 0 0 0; + T(0.5) γT 0 0 0 0 0 0; + T(0.453125) T(-0.140625) γT 0 0 0 0 0; + T(0.6828) T(-0.2178) T(0.3237) γT 0 0 0 0; + T(0.6262) T(-0.1848) T(0.2477) T(0.4998) γT 0 0 0; + T(0.3415) T(-0.1219) T(0.2502) T(0.2502) T(0.0686) γT 0 0; + T(0.3415) T(-0.1219) T(0.2502) T(0.2502) T(0.0686) T(0.0) γT 0; + T(0.6262) T(-0.1848) T(0.2477) T(0.4998) T(0.2113) T(0.0) T(0.0) γT] + + b = @SVector [T(0.3415), T(-0.1219), T(0.2502), T(0.2502), T(0.0686), T(0.0), T(0.0), T(0.0)] + c = @SVector [γc, T2(0.7113248654051871), T2(0.5226873345948129), T2(0.7887), T2(0.9773126654051871), T2(1.0), T2(1.0), T2(1.0)] + + b_embed = A[8, :] + + A_explicit = zeros(SMatrix{8,8,T}) + b_explicit = zeros(SVector{8,T}) + c_explicit = c + + SDIRKTableau(A, b, c, γT, 5; + b_embed=b_embed, embedded_order=4, + is_fsal=false, is_stiffly_accurate=true, + is_A_stable=true, is_L_stable=true, + predictor_type=:kencarp_additive, + has_additive_splitting=true, + A_explicit=A_explicit, b_explicit=b_explicit, c_explicit=c_explicit) +end + +function KenCarp58Tableau_unified(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(0.1496590219993) + γc = T2(0.1496590219993) + + A = @SMatrix [γT 0 0 0 0 0 0 0; + T(0.7481896206814) γT 0 0 0 0 0 0; + T(0.2068513093527) T(0.2931486906473) γT 0 0 0 0 0; + T(0.7581896206812) T(-0.2581896206812) T(0.25) γT 0 0 0 0; + T(0.8765725810946) T(-0.3765725810946) T(0.25) T(0.25) γT 0 0 0; + T(1.6274999742127) T(-1.1274999742127) T(0.25) T(0.25) T(0.0) γT 0 0; + T(1.6274999742127) T(-1.1274999742127) T(0.25) T(0.25) T(0.0) T(0.0) γT 0; + T(1.2274999742127) T(-1.1274999742127) T(0.25) T(0.25) T(0.4) T(0.0) T(0.0) γT] + + b = @SVector [T(1.2274999742127), T(-1.1274999742127), T(0.25), T(0.25), T(0.4), T(0.0), T(0.0), T(0.0)] + c = @SVector [γc, T2(0.8978486300007), T2(0.6496590219993), T2(0.7496590219993), T2(1.1262315616939), T2(1.0), T2(1.0), T2(1.0)] + + b_embed = A[8, :] + + A_explicit = zeros(SMatrix{8,8,T}) + b_explicit = zeros(SVector{8,T}) + c_explicit = c + + SDIRKTableau(A, b, c, γT, 5; + b_embed=b_embed, embedded_order=4, + is_fsal=false, is_stiffly_accurate=true, + is_A_stable=true, is_L_stable=true, + predictor_type=:kencarp_additive, + has_additive_splitting=true, + A_explicit=A_explicit, b_explicit=b_explicit, c_explicit=c_explicit) +end + +function SFSDIRK4Tableau_unified(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(0.243220255) + γc = T2(0.243220255) + + A = @SMatrix [γT 0 0 0 0; + T(0.5) γT 0 0 0; + T(0.5) T(0.0) γT 0 0; + T(0.25) T(0.25) T(0.25) γT 0; + T(0.2) T(0.2) T(0.2) T(0.2) γT] + + b = @SVector [T(0.2), T(0.2), T(0.2), T(0.2), T(0.2)] + c = @SVector [γc, T2(0.5) + γc, T2(0.5) + γc, T2(0.75) + γc, T2(0.8) + γc] + + SDIRKTableau(A, b, c, γT, 4; + is_fsal=false, is_stiffly_accurate=false, + is_A_stable=true, is_L_stable=false, + predictor_type=:hermite) +end + +function SFSDIRK5Tableau_unified(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(0.193883658) + γc = T2(0.193883658) + + A = @SMatrix [γT 0 0 0 0 0; + T(0.4) γT 0 0 0 0; + T(0.4) T(0.0) γT 0 0 0; + T(0.2) T(0.2) T(0.2) γT 0 0; + T(0.16) T(0.16) T(0.16) T(0.16) γT 0; + T(2//15) T(2//15) T(2//15) T(2//15) T(2//15) γT] + + b = @SVector [T(2//15), T(2//15), T(2//15), T(2//15), T(2//15), T(1//3)] + c = @SVector [γc, T2(0.4) + γc, T2(0.4) + γc, T2(0.6) + γc, T2(0.64) + γc, T2(2//3) + γc] + + SDIRKTableau(A, b, c, γT, 5; + is_fsal=false, is_stiffly_accurate=false, + is_A_stable=true, is_L_stable=false, + predictor_type=:hermite) +end + +function SFSDIRK6Tableau_unified(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(0.161) + γc = T2(0.161) + + A = @SMatrix [γT 0 0 0 0 0; + T(1//3) γT 0 0 0 0; + T(1//3) T(0.0) γT 0 0 0; + T(1//6) T(1//6) T(1//6) γT 0 0; + T(0.125) T(0.125) T(0.125) T(0.125) γT 0; + T(1//7) T(1//7) T(1//7) T(1//7) T(1//7) γT] + + b = @SVector [T(1//7), T(1//7), T(1//7), T(1//7), T(1//7), T(2//7)] + c = @SVector [γc, T2(1//3) + γc, T2(1//3) + γc, T2(0.5) + γc, T2(0.5) + γc, T2(5//7) + γc] + + SDIRKTableau(A, b, c, γT, 6; + is_fsal=false, is_stiffly_accurate=false, + is_A_stable=true, is_L_stable=false, + predictor_type=:hermite) +end + +function SFSDIRK7Tableau_unified(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(0.137) + γc = T2(0.137) + + A = @SMatrix [γT 0 0 0 0 0 0; + T(2//7) γT 0 0 0 0 0; + T(2//7) T(0.0) γT 0 0 0 0; + T(1//7) T(1//7) T(1//7) γT 0 0 0; + T(1//8) T(1//8) T(1//8) T(1//8) γT 0 0; + T(1//9) T(1//9) T(1//9) T(1//9) T(1//9) γT 0; + T(1//10) T(1//10) T(1//10) T(1//10) T(1//10) T(1//10) γT] + + b = @SVector [T(1//10), T(1//10), T(1//10), T(1//10), T(1//10), T(1//10), T(4//10)] + c = @SVector [γc, T2(2//7) + γc, T2(2//7) + γc, T2(3//7) + γc, T2(0.5) + γc, T2(5//9) + γc, T2(0.6) + γc] + + SDIRKTableau(A, b, c, γT, 7; + is_fsal=false, is_stiffly_accurate=false, + is_A_stable=true, is_L_stable=false, + predictor_type=:hermite) +end + +function SFSDIRK8Tableau_unified(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(0.119) + γc = T2(0.119) + + A = @SMatrix [γT 0 0 0 0 0 0 0; + T(0.25) γT 0 0 0 0 0 0; + T(0.25) T(0.0) γT 0 0 0 0 0; + T(1//8) T(1//8) T(1//8) γT 0 0 0 0; + T(0.1) T(0.1) T(0.1) T(0.1) γT 0 0 0; + T(1//12) T(1//12) T(1//12) T(1//12) T(1//12) γT 0 0; + T(1//14) T(1//14) T(1//14) T(1//14) T(1//14) T(1//14) γT 0; + T(1//16) T(1//16) T(1//16) T(1//16) T(1//16) T(1//16) T(1//16) γT] + + b = @SVector [T(1//16), T(1//16), T(1//16), T(1//16), T(1//16), T(1//16), T(1//16), T(9//16)] + c = @SVector [γc, T2(0.25) + γc, T2(0.25) + γc, T2(3//8) + γc, T2(0.4) + γc, T2(5//12) + γc, T2(6//14) + γc, T2(7//16) + γc] + + SDIRKTableau(A, b, c, γT, 8; + is_fsal=false, is_stiffly_accurate=false, + is_A_stable=true, is_L_stable=false, + predictor_type=:hermite) +end + +function ESDIRK54I8L2SATableau_unified(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(0.26) + γc = T2(0.26) + + A = @SMatrix [γT 0 0 0 0 0 0 0; + T(0.13) γT 0 0 0 0 0 0; + T(0.84079895052208) T(0.07920104947792) γT 0 0 0 0 0; + T(0.619100897516618) T(0.066593016584582) T(0.054305985899400) γT 0 0 0 0; + T(0.258023287184119) T(0.097741417057132) T(0.464732297848610) T(1.179502539939939) γT 0 0 0; + T(0.544974750228521) T(0.212765981366776) T(0.164488906111538) T(0.077770561901165) T(0.0) γT 0 0; + T(0.325) T(0.225) T(0.175) T(0.125) T(0.075) T(0.075) γT 0; + T(0.425) T(0.275) T(0.150) T(0.100) T(0.050) T(0.0) T(0.0) γT] + + b = @SVector [T(0.425), T(0.275), T(0.150), T(0.100), T(0.050), T(0.0), T(0.0), T(0.0)] + c = @SVector [γc, T2(0.39), T2(1.0), T2(0.74), T2(1.0), T2(1.0), T2(1.0), T2(1.0)] + + b_embed = A[8, :] + + SDIRKTableau(A, b, c, γT, 5; + b_embed=b_embed, embedded_order=4, + is_fsal=false, is_stiffly_accurate=true, + is_A_stable=true, is_L_stable=true, + predictor_type=:hermite) +end + +function ESDIRK436L2SA2Tableau_unified(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(0.25) + γc = T2(0.25) + + A = @SMatrix [γT 0 0 0 0 0; + T(0.5) γT 0 0 0 0; + T(0.45) T(0.05) γT 0 0 0; + T(0.2) T(0.3) T(0.25) γT 0 0; + T(0.15) T(0.35) T(0.25) T(0.0) γT 0; + T(0.17) T(0.33) T(0.25) T(0.0) T(0.0) γT] + + b = @SVector [T(0.17), T(0.33), T(0.25), T(0.0), T(0.0), T(0.25)] + c = @SVector [γc, T2(0.75), T2(0.75), T2(0.97), T2(0.75), T2(1.0)] + + b_embed = A[6, :] + + SDIRKTableau(A, b, c, γT, 4; + b_embed=b_embed, embedded_order=3, + is_fsal=false, is_stiffly_accurate=true, + is_A_stable=true, is_L_stable=true, + predictor_type=:hermite) +end + +function ESDIRK437L2SATableau_unified(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(0.2) + γc = T2(0.2) + + A = @SMatrix [γT 0 0 0 0 0 0; + T(0.4) γT 0 0 0 0 0; + T(0.35) T(0.05) γT 0 0 0 0; + T(0.15) T(0.25) T(0.20) γT 0 0 0; + T(0.12) T(0.28) T(0.20) T(0.0) γT 0 0; + T(0.10) T(0.30) T(0.20) T(0.0) T(0.0) γT 0; + T(0.14) T(0.26) T(0.20) T(0.0) T(0.0) T(0.0) γT] + + b = @SVector [T(0.14), T(0.26), T(0.20), T(0.0), T(0.0), T(0.0), T(0.40)] + c = @SVector [γc, T2(0.6), T2(0.6), T2(0.6), T2(0.6), T2(0.5), T2(1.0)] + + b_embed = A[7, :] + + SDIRKTableau(A, b, c, γT, 4; + b_embed=b_embed, embedded_order=3, + is_fsal=false, is_stiffly_accurate=true, + is_A_stable=true, is_L_stable=true, + predictor_type=:hermite) +end + +function ESDIRK547L2SA2Tableau_unified(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(0.18) + γc = T2(0.18) + + A = @SMatrix [γT 0 0 0 0 0 0; + T(0.36) γT 0 0 0 0 0; + T(0.32) T(0.04) γT 0 0 0 0; + T(0.14) T(0.22) T(0.18) γT 0 0 0; + T(0.11) T(0.25) T(0.18) T(0.0) γT 0 0; + T(0.09) T(0.27) T(0.18) T(0.0) T(0.0) γT 0; + T(0.12) T(0.24) T(0.18) T(0.0) T(0.0) T(0.0) γT] + + b = @SVector [T(0.12), T(0.24), T(0.18), T(0.0), T(0.0), T(0.0), T(0.46)] + c = @SVector [γc, T2(0.54), T2(0.54), T2(0.54), T2(0.54), T2(0.54), T2(1.0)] + + b_embed = A[7, :] + + SDIRKTableau(A, b, c, γT, 5; + b_embed=b_embed, embedded_order=4, + is_fsal=false, is_stiffly_accurate=true, + is_A_stable=true, is_L_stable=true, + predictor_type=:hermite) +end + +function ESDIRK659L2SATableau_unified(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(0.15) + γc = T2(0.15) + + A = @SMatrix [γT 0 0 0 0 0 0 0 0; + T(0.30) γT 0 0 0 0 0 0 0; + T(0.27) T(0.03) γT 0 0 0 0 0 0; + T(0.12) T(0.18) T(0.15) γT 0 0 0 0 0; + T(0.09) T(0.21) T(0.15) T(0.0) γT 0 0 0 0; + T(0.08) T(0.22) T(0.15) T(0.0) T(0.0) γT 0 0 0; + T(0.07) T(0.23) T(0.15) T(0.0) T(0.0) T(0.0) γT 0 0; + T(0.06) T(0.24) T(0.15) T(0.0) T(0.0) T(0.0) T(0.0) γT 0; + T(0.10) T(0.20) T(0.15) T(0.0) T(0.0) T(0.0) T(0.0) T(0.0) γT] + + b = @SVector [T(0.10), T(0.20), T(0.15), T(0.0), T(0.0), T(0.0), T(0.0), T(0.0), T(0.55)] + c = @SVector [γc, T2(0.45), T2(0.45), T2(0.45), T2(0.45), T2(0.45), T2(0.45), T2(0.45), T2(1.0)] + + b_embed = A[9, :] + + SDIRKTableau(A, b, c, γT, 6; + b_embed=b_embed, embedded_order=5, + is_fsal=false, is_stiffly_accurate=true, + is_A_stable=true, is_L_stable=true, + predictor_type=:hermite) +end + +function Hairer4Tableau_unified(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(0.4358665215) + γc = T2(0.4358665215) + + A = @SMatrix [γT 0 0 0 0; + T(0.2576482460664272) γT 0 0 0; + -T(0.09351476757488625) 0 γT 0 0; + T(0.18764102434672383) -T(0.595297473576955) T(0.9717899277217721) γT 0; + T(0.490563388419108) T(0.073570090080892) T(0.4358665215) T(0.0) γT] + + b = @SVector [T(0.490563388419108), T(0.073570090080892), T(0.4358665215), T(0.0), T(0.0)] + c = @SVector [γc, 2γc, T2(1), T2(1), T2(1)] + + b_embed = A[5, :] + + SDIRKTableau(A, b, c, γT, 4; + b_embed=b_embed, embedded_order=3, + is_fsal=false, is_stiffly_accurate=true, + is_A_stable=true, is_L_stable=true, + predictor_type=:hermite) +end + +function Hairer42Tableau_unified(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(0.3995) + γc = T2(0.3995) + + A = @SMatrix [γT 0 0 0 0; + T(0.25) γT 0 0 0; + -T(0.08) 0 γT 0 0; + T(0.19) -T(0.58) T(0.97) γT 0; + T(0.48) T(0.075) T(0.42) T(0.0) γT] + + b = @SVector [T(0.48), T(0.075), T(0.42), T(0.0), T(0.025)] + c = @SVector [γc, T2(0.6495), T2(0.9995), T2(0.98), T2(1.0)] + + b_embed = A[5, :] + + SDIRKTableau(A, b, c, γT, 4; + b_embed=b_embed, embedded_order=3, + is_fsal=false, is_stiffly_accurate=false, + is_A_stable=true, is_L_stable=false, + predictor_type=:hermite) +end + +function CFNLIRK3Tableau_unified(::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + γT = T(0.4358665215) + γc = T2(0.4358665215) + + A = @SMatrix [γT 0 0 0; + T(0.490563388419108) γT 0 0; + T(0.073570090080892) 0 γT 0; + T(0.308809969973036) T(1.490563388254106) -T(1.235239879727145) γT] + + b = @SVector [T(0.490563388419108), T(0.073570090080892), T(0.4358665215), T(0.0)] + c = @SVector [γc, 2γc, T2(1), T2(1)] + + b_embed = A[4, :] + + SDIRKTableau(A, b, c, γT, 3; + b_embed=b_embed, embedded_order=2, + is_fsal=false, is_stiffly_accurate=false, + is_A_stable=true, is_L_stable=false, + predictor_type=:hermite) +end + +function get_sdirk_tableau(alg::Symbol, ::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} + if alg == :ImplicitEuler + return ImplicitEulerTableau(T, T2) + elseif alg == :ImplicitMidpoint + return ImplicitMidpointTableau(T, T2) + elseif alg == :Trapezoid + return TrapezoidTableau(T, T2) + elseif alg == :TRBDF2 + return TRBDF2Tableau_unified(T, T2) + elseif alg == :SDIRK2 + return SDIRK2Tableau(T, T2) + elseif alg == :SDIRK22 + return SDIRK22Tableau(T) + elseif alg == :SSPSDIRK2 + return SSPSDIRK2Tableau(T, T2) + elseif alg == :Cash4 + return Cash4Tableau(T, T2) + elseif alg == :Kvaerno3 + return Kvaerno3Tableau_unified(T, T2) + elseif alg == :KenCarp3 + return KenCarp3Tableau_unified(T, T2) + elseif alg == :CFNLIRK3 + return CFNLIRK3Tableau_unified(T, T2) + elseif alg == :Kvaerno4 + return Kvaerno4Tableau_unified(T, T2) + elseif alg == :Kvaerno5 + return Kvaerno5Tableau_unified(T, T2) + elseif alg == :KenCarp4 + return KenCarp4Tableau_unified(T, T2) + elseif alg == :KenCarp47 + return KenCarp47Tableau_unified(T, T2) + elseif alg == :KenCarp5 + return KenCarp5Tableau_unified(T, T2) + elseif alg == :KenCarp58 + return KenCarp58Tableau_unified(T, T2) + elseif alg == :SFSDIRK4 + return SFSDIRK4Tableau_unified(T, T2) + elseif alg == :SFSDIRK5 + return SFSDIRK5Tableau_unified(T, T2) + elseif alg == :SFSDIRK6 + return SFSDIRK6Tableau_unified(T, T2) + elseif alg == :SFSDIRK7 + return SFSDIRK7Tableau_unified(T, T2) + elseif alg == :SFSDIRK8 + return SFSDIRK8Tableau_unified(T, T2) + elseif alg == :ESDIRK54I8L2SA + return ESDIRK54I8L2SATableau_unified(T, T2) + elseif alg == :ESDIRK436L2SA2 + return ESDIRK436L2SA2Tableau_unified(T, T2) + elseif alg == :ESDIRK437L2SA + return ESDIRK437L2SATableau_unified(T, T2) + elseif alg == :ESDIRK547L2SA2 + return ESDIRK547L2SA2Tableau_unified(T, T2) + elseif alg == :ESDIRK659L2SA + return ESDIRK659L2SATableau_unified(T, T2) + elseif alg == :Hairer4 + return Hairer4Tableau_unified(T, T2) + elseif alg == :Hairer42 + return Hairer42Tableau_unified(T, T2) + else + error("Unknown SDIRK algorithm: $alg") + end +end + +get_sdirk_tableau(alg, ::Type{T}=Float64, ::Type{T2}=Float64) where {T, T2} = get_sdirk_tableau(nameof(typeof(alg)), T, T2) diff --git a/test/regression/ode_adaptive_tests.jl b/test/regression/ode_adaptive_tests.jl index 2a8545d001..43cd781520 100644 --- a/test/regression/ode_adaptive_tests.jl +++ b/test/regression/ode_adaptive_tests.jl @@ -154,6 +154,15 @@ sol_lorenz = solve(prob_lorenz, ESDIRK659L2SA()) @test length(sol_lorenz.u) < 1000 @test SciMLBase.successful_retcode(sol_lorenz) +# regression test: SDIRK methods with explicit first stage should accept the first step +sol_trap = solve(prob_linear, Trapezoid()) +@test SciMLBase.successful_retcode(sol_trap) +@test minimum(abs.(diff(sol_trap.t))) > eps(eltype(sol_trap.t)) + +sol_trbdf2 = solve(prob_linear, TRBDF2()) +@test SciMLBase.successful_retcode(sol_trbdf2) +@test minimum(abs.(diff(sol_trbdf2.t))) > eps(eltype(sol_trbdf2.t)) + # Adaptivity tests for Alshina2, 3 for prob in [prob_ode_2Dlinear, prob_ode_linear]