From 4c10e3e427e65340a8280689e898bee67be21ed6 Mon Sep 17 00:00:00 2001 From: Simon Byrne Date: Wed, 18 Nov 2020 09:05:15 -0800 Subject: [PATCH 1/4] start on dtcache --- src/integrators.jl | 57 ++++++++++++++++++++++++++++------ src/solvers/lsrk.jl | 9 ++++++ src/solvers/mis.jl | 7 +++++ src/solvers/multirate.jl | 11 +++++-- src/solvers/wickerskamarock.jl | 9 ++++++ 5 files changed, 82 insertions(+), 11 deletions(-) diff --git a/src/integrators.jl b/src/integrators.jl index dc5c0165..dfe052e6 100644 --- a/src/integrators.jl +++ b/src/integrators.jl @@ -21,15 +21,15 @@ end # called by DiffEqBase.init and solve (see below) function DiffEqBase.__init( - prob::DiffEqBase.AbstractODEProblem, - alg::DistributedODEAlgorithm, - args...; + prob::DiffEqBase.AbstractODEProblem, + alg::DistributedODEAlgorithm, + args...; dt, # required - stepstop=-1, + stepstop=-1, adjustfinal=false, callback=nothing, - kwargs...) - + kwargs...) + u = prob.u0 t = prob.tspan[1] tstop = prob.tspan[2] @@ -46,10 +46,10 @@ end # called by DiffEqBase.solve function DiffEqBase.__solve( prob::DiffEqBase.AbstractODEProblem, - alg::DistributedODEAlgorithm, + alg::DistributedODEAlgorithm, args...; kwargs...) - + integrator = DiffEqBase.__init(prob, alg, args...; kwargs...) DiffEqBase.solve!(integrator) return integrator.u # ODEProblem returns a Solution objec @@ -90,14 +90,53 @@ function DiffEqBase.step!(integrator::DistributedODEIntegrator) end # solvers need to define this interface -step_u!(integrator) = step_u!(integrator, integrator.cache) +step_u!(integrator) = step_u!(integrator, integrator.cache) function adjust_dt!(integrator::DistributedODEIntegrator, dt) # TODO: figure out interface for recomputing other objects (linear operators, etc) integrator.dt = dt + adjust_dt!() end # not sure what this should do? # defined as default initialize: https://github.com/SciML/DiffEqBase.jl/blob/master/src/callbacks.jl#L3 DiffEqBase.u_modified!(i::DistributedODEIntegrator,bool) = nothing + +#= +cache +dt_cache + + +multirate => need a way to compute fast dts + - WS => round to dt/2 (WSRK2) or dt/6 (WSRK3) + dt/6 / round(dt/6 / fast_dt) + - MIS => round to tab.d[i]*dt for each i + tab.d[i]*dt / round(tab.d[i]*dt / fast_dt) + - LSRK => round to dt * (c[i+1] - c[i]) + + +sub_dts() = () + + + + + +init_alg_cache(prob, alg, dt) + +will call + +init_cache_dt(prob, alg, dt) + + +MultirateCache will have a dtcache field + +A multirate scheme will init the necessary dts: + +init_cache_dt(prob, alg::Multirate, dt) + +which will call +dtcache = (i -> init_cache(innerprob, inneralg, dt*c[i]), Nstages) + + +=# \ No newline at end of file diff --git a/src/solvers/lsrk.jl b/src/solvers/lsrk.jl index e5493283..50f38ddd 100644 --- a/src/solvers/lsrk.jl +++ b/src/solvers/lsrk.jl @@ -60,6 +60,15 @@ function step_u!(int, cache::LowStorageRungeKutta2NIncCache) end # for Multirate +function inner_dts(outercache::LowStorageRungeKutta2NIncCache, dt, fast_dt) + N = nstages(outercache) + tab = outercache.tableau + ntuple(N) do i + Δt = (i == N ? 1-tab.C[i] : tab.C[i+1] - tab.C[i]) * dt + Δt / round(Δt / fast_dt) + end +end + function init_inner(prob, outercache::LowStorageRungeKutta2NIncCache, dt) OffsetODEFunction(prob.f.f1, zero(dt), one(dt), zero(dt), outercache.du) end diff --git a/src/solvers/mis.jl b/src/solvers/mis.jl index 26068567..68b23062 100644 --- a/src/solvers/mis.jl +++ b/src/solvers/mis.jl @@ -66,6 +66,13 @@ function cache( return MultirateInfinitesimalStepCache(ΔU, F, tab) end +function inner_dts(outercache::MultirateInfinitesimalStepCache, dt, fast_dt) + tab = outercache.tableau + map(tab.d) do d_i + Δt = d_i*dt + Δt / round(Δt / fast_dt) + end +end function init_inner(prob, outercache::MultirateInfinitesimalStepCache, dt) OffsetODEFunction(prob.f.f1, zero(dt), one(dt), one(dt), outercache.ΔU[end]) diff --git a/src/solvers/multirate.jl b/src/solvers/multirate.jl index 0e379423..92d7dc34 100644 --- a/src/solvers/multirate.jl +++ b/src/solvers/multirate.jl @@ -19,9 +19,10 @@ struct Multirate{F,S} <: DistributedODEAlgorithm end -struct MultirateCache{OC,II} +struct MultirateCache{OC,II,SD} outercache::OC innerinteg::II + subdtcache::SD end function cache( @@ -35,9 +36,15 @@ function cache( outerprob = DiffEqBase.remake(prob; f=prob.f.f2) outercache = cache(outerprob, alg.slow) + sub_dts = inner_dts(outercache, dt, fast_dt) + innerfun = init_inner(prob, outercache, dt) innerprob = DiffEqBase.remake(prob; f=innerfun) - innerinteg = DiffEqBase.init(innerprob, alg.fast; dt=fast_dt, kwargs...) + innerinteg = DiffEqBase.init(innerprob, alg.fast; dt=sub_dts[1], kwargs...) + + init_cache_dt(innerinteg.cache) + + return MultirateCache(outercache, innerinteg) end diff --git a/src/solvers/wickerskamarock.jl b/src/solvers/wickerskamarock.jl index 6a1c5341..63ea12c2 100644 --- a/src/solvers/wickerskamarock.jl +++ b/src/solvers/wickerskamarock.jl @@ -32,6 +32,15 @@ end nstages(::WickerSkamarockRungeKuttaCache{Nstages}) where {Nstages} = Nstages +function inner_dts(outercache::WickerSkamarockRungeKuttaCache, dt, fast_dt) + tab = outercache.tableau + if length(tab.c) == 2 # WSRK2 + Δt = dt/2 + else # WSRK3 + Δt = dt/6 + end + return (Δt / round(Δt / fast_dt),) +end function init_inner(prob, outercache::WickerSkamarockRungeKuttaCache, dt) OffsetODEFunction(prob.f.f1, zero(dt), one(dt), one(dt), outercache.F) From 48615d2b07cfa59ab14c4bfa546b44517c6f8ca4 Mon Sep 17 00:00:00 2001 From: Simon Byrne Date: Wed, 18 Nov 2020 21:29:55 -0800 Subject: [PATCH 2/4] add dt_cache infratructure for multirate --- src/integrators.jl | 77 +++++++++++++++++++++++- src/solvers/ark.jl | 2 +- src/solvers/lsrk.jl | 6 +- src/solvers/mis.jl | 4 +- src/solvers/multirate.jl | 105 ++++++++++++++++++++++++++------- src/solvers/ssprk.jl | 4 +- src/solvers/wickerskamarock.jl | 7 ++- 7 files changed, 170 insertions(+), 35 deletions(-) diff --git a/src/integrators.jl b/src/integrators.jl index dfe052e6..98aa3468 100644 --- a/src/integrators.jl +++ b/src/integrators.jl @@ -36,7 +36,7 @@ function DiffEqBase.__init( callbackset = DiffEqBase.CallbackSet(callback) isempty(callbackset.continuous_callbacks) || error("Continuous callbacks are not supported") - integrator = DistributedODEIntegrator(prob, alg, u, dt, t, tstop, 0, stepstop, adjustfinal, callbackset, false, cache(prob, alg; dt=dt, kwargs...)) + integrator = DistributedODEIntegrator(prob, alg, u, dt, t, tstop, 0, stepstop, adjustfinal, callbackset, false, init_cache(prob, alg; dt=dt, kwargs...)) DiffEqBase.initialize!(callbackset,u,t,integrator) return integrator @@ -61,6 +61,10 @@ function DiffEqBase.solve!(integrator::DistributedODEIntegrator) if integrator.adjustfinal && integrator.t + integrator.dt > integrator.tstop adjust_dt!(integrator, integrator.tstop - integrator.t) end + if !integrator.adjustfinal && integrator.t + integrator.dt/2 > integrator.tstop + break + end + DiffEqBase.step!(integrator) if integrator.step == integrator.stepstop @@ -92,13 +96,80 @@ end # solvers need to define this interface step_u!(integrator) = step_u!(integrator, integrator.cache) -function adjust_dt!(integrator::DistributedODEIntegrator, dt) +""" + adjust_dt!(integrator::DistributedODEIntegrator, dt[, dt_cache=nothing]) + +Adjust the time step of the integrator to `dt`. The optional `dt_cache` object +can be passed when the integrator has a `dt`-dependent component that needs to +be updated (such as a linear solver). +""" +function adjust_dt!(integrator::DistributedODEIntegrator, dt, dt_cache=nothing) # TODO: figure out interface for recomputing other objects (linear operators, etc) integrator.dt = dt - adjust_dt!() + adjust_dt!(integrator.cache, dt, dt_cache) +end + +# interfaces + +""" + init_cache(prob, alg::A; kwargs...)::AC + +Construct an algorithm cache for the algorithm `alg`. This should be defined +for any algorithm type `A`, and should return an object of an appropriate cache +type `AC` that can be dispatched on for [`step_u!`](@ref) and/or +[`init_inner`](@ref)/[`update_inner!`](@ref). +""" +function init_cache end + +""" + step_u!(integrator, cache::AC) + +Perform a single step that updates the state `integrator.u` using accordint to +the algorithm corresponding to `cache`. + +This should be defined for any algorithm cache type `AC` that can be used +directly or as an inner timestepper. For outer timesteppers, +[`init_inner`](@ref) and [`update_inner!`](@ref) need to be defined instead. +""" +step_u!(integrator, cache) + +""" + init_dt_cache(cache::AC, dt) + +Construct a `dt`-dependent subcache of `cache`. This should _not_ modify `cache` +itself, but return an object that can be passed as the `dt_cache` argument to +[`adjust_dt!`](@ref). + +By default this returns `nothing`. This should be defined for any algorithm +cache type `AC` which has `dt`-dependent components. + +For example, an implicit solver can use this to return a factorized Euler +operator ``I-dt*L`` that is used as part of the implicit solve. + +This initialization will typically be done as part of [`init_cache`](@ref) +itself: this interface is provided for multirate schemes which need to modify +the `dt` of the inner solver at each outer stage. +""" +function init_dt_cache(cache, dt) + return nothing end +function get_dt_cache(cache) + return nothing +end + +""" + adjust_dt!(cache::AC, dt, dt_cache) + +Adjust the time step of the algorithm cache `cache`. This should be defined for +any algorithm cache type `AC`, where `dt_cache` is an object returned by +[`init_dt_cache`](@ref). +""" +adjust_dt!(cache, dt, dt_cache) + + + # not sure what this should do? # defined as default initialize: https://github.com/SciML/DiffEqBase.jl/blob/master/src/callbacks.jl#L3 DiffEqBase.u_modified!(i::DistributedODEIntegrator,bool) = nothing diff --git a/src/solvers/ark.jl b/src/solvers/ark.jl index e00917f4..e4841058 100644 --- a/src/solvers/ark.jl +++ b/src/solvers/ark.jl @@ -39,7 +39,7 @@ struct AdditiveRungeKuttaFullCache{Nstages, RT, A, O, L} end -function cache( +function init_cache( prob::DiffEqBase.AbstractODEProblem{uType, tType, true}, alg::AdditiveRungeKutta; dt, kwargs...) where {uType,tType} diff --git a/src/solvers/lsrk.jl b/src/solvers/lsrk.jl index 50f38ddd..66a640e3 100644 --- a/src/solvers/lsrk.jl +++ b/src/solvers/lsrk.jl @@ -33,7 +33,7 @@ struct LowStorageRungeKutta2NIncCache{Nstages, RT, A} du::A end -function cache(prob::DiffEqBase.ODEProblem, alg::LowStorageRungeKutta2N; kwargs...) +function init_cache(prob::DiffEqBase.ODEProblem, alg::LowStorageRungeKutta2N; kwargs...) # @assert prob.problem_type isa DiffEqBase.IncrementingODEProblem || # prob.f isa DiffEqBase.IncrementingODEFunction du = zero(prob.u0) @@ -59,6 +59,8 @@ function step_u!(int, cache::LowStorageRungeKutta2NIncCache) end end +adjust_dt!(cache::LowStorageRungeKutta2NIncCache, dt, ::Nothing) = nothing + # for Multirate function inner_dts(outercache::LowStorageRungeKutta2NIncCache, dt, fast_dt) N = nstages(outercache) @@ -69,7 +71,7 @@ function inner_dts(outercache::LowStorageRungeKutta2NIncCache, dt, fast_dt) end end -function init_inner(prob, outercache::LowStorageRungeKutta2NIncCache, dt) +function init_inner_fun(prob, outercache::LowStorageRungeKutta2NIncCache, dt) OffsetODEFunction(prob.f.f1, zero(dt), one(dt), zero(dt), outercache.du) end function update_inner!(innerinteg, outercache::LowStorageRungeKutta2NIncCache, diff --git a/src/solvers/mis.jl b/src/solvers/mis.jl index 68b23062..34341732 100644 --- a/src/solvers/mis.jl +++ b/src/solvers/mis.jl @@ -51,7 +51,7 @@ end nstages(::MultirateInfinitesimalStepCache{Nstages}) where {Nstages} = Nstages -function cache( +function init_cache( prob::DiffEqBase.AbstractODEProblem{uType, tType, true}, alg::MultirateInfinitesimalStep; kwargs...) where {uType,tType} @@ -74,7 +74,7 @@ function inner_dts(outercache::MultirateInfinitesimalStepCache, dt, fast_dt) end end -function init_inner(prob, outercache::MultirateInfinitesimalStepCache, dt) +function init_inner_fun(prob, outercache::MultirateInfinitesimalStepCache, dt) OffsetODEFunction(prob.f.f1, zero(dt), one(dt), one(dt), outercache.ΔU[end]) end diff --git a/src/solvers/multirate.jl b/src/solvers/multirate.jl index 92d7dc34..b959f140 100644 --- a/src/solvers/multirate.jl +++ b/src/solvers/multirate.jl @@ -22,10 +22,10 @@ end struct MultirateCache{OC,II,SD} outercache::OC innerinteg::II - subdtcache::SD + dt_cache::SD end -function cache( +function init_cache( prob::DiffEqBase.AbstractODEProblem, alg::Multirate; dt, fast_dt, kwargs...) @@ -34,20 +34,50 @@ function cache( # subproblems outerprob = DiffEqBase.remake(prob; f=prob.f.f2) - outercache = cache(outerprob, alg.slow) + outercache = init_cache(outerprob, alg.slow) sub_dts = inner_dts(outercache, dt, fast_dt) + unique_sub_dts = unique(sub_dts) - innerfun = init_inner(prob, outercache, dt) + innerfun = init_inner_fun(prob, outercache, dt) innerprob = DiffEqBase.remake(prob; f=innerfun) - innerinteg = DiffEqBase.init(innerprob, alg.fast; dt=sub_dts[1], kwargs...) + innerinteg = DiffEqBase.init(innerprob, alg.fast; dt=unique_sub_dts[1], adjustfinal=false, kwargs...) - init_cache_dt(innerinteg.cache) + # build dt_cache + unique_dt_caches = [ + i == 1 ? get_dt_cache(innerinteg.cache) : init_dt_cache(innerinteg.cache, unique_sub_dts[i]) + for i = 1:length(unique_sub_dts)] + dt_cache = map(sub_dts) do sub_dt + i = findfirst(==(sub_dt), unique_sub_dts) + unique_sub_dts[i] => unique_dt_caches[i] + end - return MultirateCache(outercache, innerinteg) + return MultirateCache(outercache, innerinteg, dt_cache) end +get_dt_cache(cache::Multirate) = cache.dt_cache +function init_dt_cache(cache::Multirate, dt) + outercache = cache.outercache + innerinteg = cache.innerinteg + + fast_dt = innerinteg.dt # TODO: get the original fast_dt from somewhere + + sub_dts = inner_dts(outercache, dt, fast_dt) + unique_sub_dts = unique(sub_dts) + + unique_dt_caches = [ + init_dt_cache(innerinteg.cache, unique_sub_dts[i]) + for i = 1:length(unique_sub_dts)] + + dt_cache = map(sub_dts) do sub_dt + i = findfirst(==(sub_dt), unique_sub_dts) + unique_sub_dts[i] => unique_dt_caches[i] + end + return dt_cache +end +adjust_dt!(cache::Multirate, dt, dt_cache::Tuple) = cache.dt_cache + function step_u!(int, cache::MultirateCache) outercache = cache.outercache @@ -61,23 +91,54 @@ function step_u!(int, cache::MultirateCache) innerinteg = cache.innerinteg fast_dt = innerinteg.dt - N = nstages(outercache) - for stage in 1:N + for i in 1:nstages(outercache) + sub_dt, sub_dt_cache = cache.dt_cache[i] + adjust_dt!(innerinteg, sub_dt, sub_dt_cache) + update_inner!(innerinteg, outercache, int.prob.f.f2, u, p, t, dt, i) + DiffEqBase.solve!(innerinteg) + end +end - update_inner!(innerinteg, outercache, int.prob.f.f2, u, p, t, dt, stage) +# interface +""" + nstages(outercache::AC) - # solve inner problem - # dv/dτ .= B[s]/(C[s+1] - C[s]) .* du .+ f_fast(v,τ) τ ∈ [τ0,τ1] +The number of stages of the algorithm determined by cache type `AC`. This should +be defined for any algorithm cache type `AC` used as an outer solver. +""" +function nstages end - # TODO: make this more generic - # there are 2 strategies we can use here: - # a. use same fast_dt for all slow stages, use `adjustfinal=true` - # - problems for ARK (e.g. requires expensive LU factorization) - # b. use different fast_dt, cache expensive ops - innerinteg.adjustfinal = true - DiffEqBase.solve!(innerinteg) - innerinteg.dt = fast_dt # reset - end -end +""" + inner_dts(outercache::AC, dt, fast_dt) + +The inner timesteps that will be used at each stage of the multirate procedure. + +This should be defined for any algorithm cache type `AC` that will be used as an +outer solver, and should return a tuple of the length of the number of stages. +Each value will be approximately `fast_dt`, but rounded so that an integer +number of steps can be used at each outer stage (where `dt` is the slow time +step). +""" +function inner_dts end + +""" + init_inner_fun(prob, outercache::AC, dt) +Construct the inner `ODEFunction` that will be used with inner solver. This +should be defined for any algorithm cache type `AC` that will be used as an +outer solver. +""" +function init_inner_fun end + +""" + update_inner!(innerinteg, outercache::AC, f_slow, u, p, t, dt, i) + +Update the inner integrator `innerinteg` for stage `i` of the outer algorithm. +This should be defined for any `outercache` type `AC`, and will typically modify: +- `innerinteg.prob.f` +- `innerinteg.u` +- `innerinteg.t` +- `innerinteg.tstop` +""" +function update_inner! end \ No newline at end of file diff --git a/src/solvers/ssprk.jl b/src/solvers/ssprk.jl index 46daf40e..6f7bfa65 100644 --- a/src/solvers/ssprk.jl +++ b/src/solvers/ssprk.jl @@ -34,7 +34,7 @@ struct StrongStabilityPreservingRungeKuttaCache{Nstages, RT, A} U::A end -function cache( +function init_cache( prob::DiffEqBase.AbstractODEProblem{uType, tType, true}, alg::StrongStabilityPreservingRungeKutta; kwargs...) where {uType,tType} @@ -44,7 +44,7 @@ function cache( U = zero(prob.u0) return StrongStabilityPreservingRungeKuttaCache(tab, fU, U) end - +adjust_dt!(cache::StrongStabilityPreservingRungeKutta, dt, ::Nothing) = nothing function step_u!(int, cache::StrongStabilityPreservingRungeKuttaCache{Nstages, RT, A}) where {Nstages, RT, A} tab = cache.tableau diff --git a/src/solvers/wickerskamarock.jl b/src/solvers/wickerskamarock.jl index 63ea12c2..3a808e6b 100644 --- a/src/solvers/wickerskamarock.jl +++ b/src/solvers/wickerskamarock.jl @@ -24,7 +24,7 @@ struct WickerSkamarockRungeKuttaCache{Nstages, RT, A} U::A F::A end -function cache(prob::DiffEqBase.ODEProblem, alg::WickerSkamarockRungeKutta; kwargs...) +function init_cache(prob::DiffEqBase.ODEProblem, alg::WickerSkamarockRungeKutta; kwargs...) U = similar(prob.u0) F = similar(prob.u0) return WickerSkamarockRungeKuttaCache(tableau(alg, eltype(F)), U, F) @@ -39,10 +39,11 @@ function inner_dts(outercache::WickerSkamarockRungeKuttaCache, dt, fast_dt) else # WSRK3 Δt = dt/6 end - return (Δt / round(Δt / fast_dt),) + sub_dt = Δt / round(Δt / fast_dt) + return map(c -> sub_dt, tab.c) end -function init_inner(prob, outercache::WickerSkamarockRungeKuttaCache, dt) +function init_inner_fun(prob, outercache::WickerSkamarockRungeKuttaCache, dt) OffsetODEFunction(prob.f.f1, zero(dt), one(dt), one(dt), outercache.F) end function update_inner!(innerinteg, outercache::WickerSkamarockRungeKuttaCache, From dc31baadf41f834e335f398a8936f88762bf4a42 Mon Sep 17 00:00:00 2001 From: Simon Byrne Date: Wed, 18 Nov 2020 21:35:49 -0800 Subject: [PATCH 3/4] cleanup --- src/integrators.jl | 38 -------------------------------------- 1 file changed, 38 deletions(-) diff --git a/src/integrators.jl b/src/integrators.jl index 98aa3468..cda1cc4f 100644 --- a/src/integrators.jl +++ b/src/integrators.jl @@ -173,41 +173,3 @@ adjust_dt!(cache, dt, dt_cache) # not sure what this should do? # defined as default initialize: https://github.com/SciML/DiffEqBase.jl/blob/master/src/callbacks.jl#L3 DiffEqBase.u_modified!(i::DistributedODEIntegrator,bool) = nothing - -#= -cache -dt_cache - - -multirate => need a way to compute fast dts - - WS => round to dt/2 (WSRK2) or dt/6 (WSRK3) - dt/6 / round(dt/6 / fast_dt) - - MIS => round to tab.d[i]*dt for each i - tab.d[i]*dt / round(tab.d[i]*dt / fast_dt) - - LSRK => round to dt * (c[i+1] - c[i]) - - -sub_dts() = () - - - - - -init_alg_cache(prob, alg, dt) - -will call - -init_cache_dt(prob, alg, dt) - - -MultirateCache will have a dtcache field - -A multirate scheme will init the necessary dts: - -init_cache_dt(prob, alg::Multirate, dt) - -which will call -dtcache = (i -> init_cache(innerprob, inneralg, dt*c[i]), Nstages) - - -=# \ No newline at end of file From 0298a873d8fe8c446e283b049870f18e4452e20d Mon Sep 17 00:00:00 2001 From: Simon Byrne Date: Thu, 19 Nov 2020 14:54:38 -0800 Subject: [PATCH 4/4] ARK fixes --- src/integrators.jl | 10 +++++----- src/solvers/ark.jl | 37 ++++++++++++++++++++++++++++--------- src/solvers/multirate.jl | 6 +++--- 3 files changed, 36 insertions(+), 17 deletions(-) diff --git a/src/integrators.jl b/src/integrators.jl index cda1cc4f..a2d08c39 100644 --- a/src/integrators.jl +++ b/src/integrators.jl @@ -134,11 +134,11 @@ directly or as an inner timestepper. For outer timesteppers, step_u!(integrator, cache) """ - init_dt_cache(cache::AC, dt) + init_dt_cache(cache::AC, prob, dt) -Construct a `dt`-dependent subcache of `cache`. This should _not_ modify `cache` -itself, but return an object that can be passed as the `dt_cache` argument to -[`adjust_dt!`](@ref). +Construct a `dt`-dependent subcache of `cache` for the ODE problem `prob`. This +should _not_ modify `cache` itself, but return an object that can be passed as +the `dt_cache` argument to [`adjust_dt!`](@ref). By default this returns `nothing`. This should be defined for any algorithm cache type `AC` which has `dt`-dependent components. @@ -150,7 +150,7 @@ This initialization will typically be done as part of [`init_cache`](@ref) itself: this interface is provided for multirate schemes which need to modify the `dt` of the inner solver at each outer stage. """ -function init_dt_cache(cache, dt) +function init_dt_cache(cache, prob, dt) return nothing end diff --git a/src/solvers/ark.jl b/src/solvers/ark.jl index e4841058..13fb49e9 100644 --- a/src/solvers/ark.jl +++ b/src/solvers/ark.jl @@ -26,7 +26,8 @@ struct AdditiveRungeKuttaTableau{Nstages, Nstages², RT} C::NTuple{Nstages, RT} end -struct AdditiveRungeKuttaFullCache{Nstages, RT, A, O, L} +struct AdditiveRungeKuttaFullCache{Nstages,RT, A, G, O, L} + alg::G "stage value of the state variable" U::A #Qstages "evaluated linear part of each stage ``f_L(U^{(i)})``" @@ -38,6 +39,30 @@ struct AdditiveRungeKuttaFullCache{Nstages, RT, A, O, L} linsolve!::L end +function implicit_part(f::DiffEqBase.ODEFunction) + f.jvp === nothing && error("IMEX solvers require a `SplitODEFunction` or an `ODEFunction` with a `jvp` component.") + return f.jvp +end +implicit_part(f::DiffEqBase.SplitFunction) = f.f1 +implicit_part(f::OffsetODEFunction) = implicit_part(f.f) + +function init_dt_cache(cache::AdditiveRungeKuttaFullCache, prob, dt) + _init_dt_cache(cache.alg, cache.tableau, prob, dt) +end +function _init_dt_cache(alg::AdditiveRungeKutta, tab, prob, dt) + f_impl = implicit_part(prob.f) + W = EulerOperator(f_impl , -dt*tab.Aimpl[2,2], prob.p, prob.tspan[1]) + linsolve! = alg.linsolve(Val{:init}, W, prob.u0) + return (W, linsolve!) +end + +function get_dt_cache(cache::AdditiveRungeKuttaFullCache) + return (cache.W, cache.linsolve!) +end +function adjust_dt!(cache::AdditiveRungeKuttaFullCache, dt, (W, linsolve!)::Tuple) + cache.W = W + cache.linsolve! = linsolve! +end function init_cache( prob::DiffEqBase.AbstractODEProblem{uType, tType, true}, @@ -49,14 +74,8 @@ function init_cache( L = ntuple(i -> zero(prob.u0), Nstages) R = ntuple(i -> zero(prob.u0), Nstages) - if prob.f isa DiffEqBase.ODEFunction - W = EulerOperator(prob.f.jvp, -dt*tab.Aimpl[2,2], prob.p, prob.tspan[1]) - elseif prob.f isa DiffEqBase.SplitFunction - W = EulerOperator(prob.f.f1, -dt*tab.Aimpl[2,2], prob.p, prob.tspan[1]) - end - linsolve! = alg.linsolve(Val{:init}, W, prob.u0; kwargs...) - - AdditiveRungeKuttaFullCache(U, L, R, tab, W, linsolve!) + W, linsolve! = _init_dt_cache(alg, tab, prob, dt) + AdditiveRungeKuttaFullCache(alg, U, L, R, tab, W, linsolve!) end diff --git a/src/solvers/multirate.jl b/src/solvers/multirate.jl index b959f140..d018ec5d 100644 --- a/src/solvers/multirate.jl +++ b/src/solvers/multirate.jl @@ -45,7 +45,7 @@ function init_cache( # build dt_cache unique_dt_caches = [ - i == 1 ? get_dt_cache(innerinteg.cache) : init_dt_cache(innerinteg.cache, unique_sub_dts[i]) + i == 1 ? get_dt_cache(innerinteg.cache) : init_dt_cache(innerinteg.cache, innerinteg.prob, unique_sub_dts[i]) for i = 1:length(unique_sub_dts)] dt_cache = map(sub_dts) do sub_dt @@ -57,7 +57,7 @@ function init_cache( end get_dt_cache(cache::Multirate) = cache.dt_cache -function init_dt_cache(cache::Multirate, dt) +function init_dt_cache(cache::Multirate, prob, dt) outercache = cache.outercache innerinteg = cache.innerinteg @@ -67,7 +67,7 @@ function init_dt_cache(cache::Multirate, dt) unique_sub_dts = unique(sub_dts) unique_dt_caches = [ - init_dt_cache(innerinteg.cache, unique_sub_dts[i]) + init_dt_cache(innerinteg.cache, innerinteg.prob, unique_sub_dts[i]) for i = 1:length(unique_sub_dts)] dt_cache = map(sub_dts) do sub_dt