Skip to content

Commit 1b6380c

Browse files
Merge pull request #4044 from SciML/as/precompile
refactor: improve precompile-friendliness
2 parents 8d49c9d + e31579e commit 1b6380c

File tree

15 files changed

+184
-124
lines changed

15 files changed

+184
-124
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ LinearSolve = "3.19.2"
109109
Logging = "1"
110110
ModelingToolkitBase = "1"
111111
ModelingToolkitStandardLibrary = "2.20"
112-
ModelingToolkitTearing = "1"
112+
ModelingToolkitTearing = "1.0.2"
113113
Moshi = "0.3"
114114
NonlinearSolve = "4.3"
115115
OffsetArrays = "1"
@@ -130,12 +130,12 @@ Serialization = "1"
130130
Setfield = "0.7, 0.8, 1"
131131
SimpleNonlinearSolve = "0.1.0, 1, 2"
132132
SparseArrays = "1"
133-
StateSelection = "1"
133+
StateSelection = "1.1.1"
134134
StaticArrays = "1.9.14"
135135
StochasticDelayDiffEq = "1.11"
136136
StochasticDiffEq = "6.82.0"
137137
SymbolicIndexingInterface = "0.3.39"
138-
SymbolicUtils = "4.5.1"
138+
SymbolicUtils = "4.7.1"
139139
Symbolics = "7"
140140
UnPack = "0.1, 1.0"
141141
julia = "1.9"

lib/ModelingToolkitBase/src/ModelingToolkitBase.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ using PrecompileTools, Reexport
1010
using JumpProcesses
1111
# ONLY here for the invalidations
1212
import REPL
13+
using OffsetArrays: Origin
14+
import BlockArrays: BlockArray, BlockedArray, Block, blocksize, blocksizes, blockpush!,
15+
undef_blocks, blocks
1316
end
1417

1518
import SymbolicUtils
@@ -60,9 +63,6 @@ using Moshi.Data: @data
6063
using Reexport
6164
using RecursiveArrayTools
6265
import Graphs: SimpleDiGraph, add_edge!, incidence_matrix
63-
import BlockArrays: BlockArray, BlockedArray, Block, blocksize, blocksizes, blockpush!,
64-
undef_blocks, blocks
65-
using OffsetArrays: Origin
6666
import CommonSolve
6767
import EnumX
6868
import ReadOnlyDicts: ReadOnlyDict

lib/ModelingToolkitBase/src/discretedomain.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ struct Shift <: Operator
2323
"""Fixed Shift"""
2424
t::Union{Nothing, SymbolicT}
2525
steps::Int
26-
Shift(t, steps = 1) = new(value(t), steps)
26+
Shift(t, steps = 1) = new(unwrap(t), steps)
2727
end
2828
Shift(steps::Int) = new(nothing, steps)
2929
normalize_to_differential(s::Shift) = Differential(s.t)^s.steps

lib/ModelingToolkitBase/src/precompile.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ PrecompileTools.@compile_workload begin
1212
x ^ 5
1313
6 ^ x
1414
x - y
15+
x * x * q[1]
1516
-y
1617
2y
1718
z = 2
@@ -81,11 +82,22 @@ PrecompileTools.@compile_workload begin
8182
q[1]
8283
q'q
8384
using ModelingToolkitBase
84-
@variables x(ModelingToolkitBase.t_nounits) y(ModelingToolkitBase.t_nounits)
85-
isequal(ModelingToolkitBase.D_nounits.x, ModelingToolkitBase.t_nounits)
85+
@parameters g
86+
@variables x(ModelingToolkitBase.t_nounits)
87+
@variables y(ModelingToolkitBase.t_nounits) [state_priority = 10]
88+
@variables λ(ModelingToolkitBase.t_nounits)
89+
eqs = [
90+
ModelingToolkitBase.D_nounits(ModelingToolkitBase.D_nounits(x)) ~ λ * x
91+
ModelingToolkitBase.D_nounits(ModelingToolkitBase.D_nounits(y)) ~ λ * y - g
92+
x^2 + y^2 ~ 1
93+
]
94+
dvs = Num[x, y, λ]
95+
ps = Num[g]
8696
ics = Dict{SymbolicT, SymbolicT}()
87-
ics[x] = 2.3
88-
sys = System([ModelingToolkitBase.D_nounits(x) ~ x * y, y ~ 2x], ModelingToolkitBase.t_nounits, [x, y], Num[]; initial_conditions = ics, guesses = ics, name = :sys)
97+
ics[y] = -1.0
98+
ics[ModelingToolkitBase.D_nounits(x)] = 0.5
99+
isequal(ModelingToolkitBase.D_nounits.x, ModelingToolkitBase.t_nounits)
100+
sys = System(eqs, ModelingToolkitBase.t_nounits, dvs, ps; initial_conditions = ics, guesses = ics, name = :sys)
89101
complete(sys)
90102
@static if @isdefined(ModelingToolkit)
91103
TearingState(sys)

lib/ModelingToolkitBase/src/problems/jumpproblem.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -172,23 +172,22 @@ end
172172

173173
##### MTK dispatches for Symbolic jumps #####
174174
eqtype_supports_collect_vars(j::MassActionJump) = true
175-
function collect_vars!(unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, j::MassActionJump, iv::Union{SymbolicT, Nothing}; depth = 0,
176-
op = Differential)
177-
collect_vars!(unknowns, parameters, j.scaled_rates, iv; depth, op)
175+
function collect_vars!(unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, j::MassActionJump, iv::Union{SymbolicT, Nothing}, ::Type{op} = Differential; depth = 0) where {op}
176+
collect_vars!(unknowns, parameters, j.scaled_rates, iv, op; depth)
178177
for field in (j.reactant_stoch, j.net_stoch)
179178
for el in field
180-
collect_vars!(unknowns, parameters, el, iv; depth, op)
179+
collect_vars!(unknowns, parameters, el, iv, op; depth)
181180
end
182181
end
183182
return nothing
184183
end
185184

186185
eqtype_supports_collect_vars(j::Union{ConstantRateJump, VariableRateJump}) = true
187186
function collect_vars!(unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, j::Union{ConstantRateJump, VariableRateJump},
188-
iv::Union{SymbolicT, Nothing}; depth = 0, op = Differential)
189-
collect_vars!(unknowns, parameters, j.rate, iv; depth, op)
187+
iv::Union{SymbolicT, Nothing}, ::Type{op} = Differential; depth = 0) where {op}
188+
collect_vars!(unknowns, parameters, j.rate, iv, op; depth)
190189
for eq in j.affect!
191-
(eq isa Equation) && collect_vars!(unknowns, parameters, eq, iv; depth, op)
190+
(eq isa Equation) && collect_vars!(unknowns, parameters, eq, iv, op; depth)
192191
end
193192
return nothing
194193
end

lib/ModelingToolkitBase/src/systems/abstractsystem.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,6 +1460,14 @@ function unknowns_toplevel(sys::AbstractSystem)
14601460
return get_unknowns(sys)
14611461
end
14621462

1463+
function __no_initial_params_pred(x::SymbolicT)
1464+
arr, _ = split_indexed_var(x)
1465+
Moshi.Match.@match arr begin
1466+
BSImpl.Term(; f) && if f isa Initial end => false
1467+
_ => true
1468+
end
1469+
end
1470+
14631471
"""
14641472
$(TYPEDSIGNATURES)
14651473
@@ -1482,11 +1490,7 @@ function parameters(sys::AbstractSystem; initial_parameters = false)
14821490
end
14831491
result = collect(result)
14841492
if !initial_parameters && !is_initializesystem(sys)
1485-
filter!(result) do sym
1486-
return !(isoperator(sym, Initial) ||
1487-
iscall(sym) && operation(sym) === getindex &&
1488-
isoperator(arguments(sym)[1], Initial))
1489-
end
1493+
filter!(__no_initial_params_pred, result)
14901494
end
14911495
return result
14921496
end
@@ -1699,7 +1703,7 @@ Recursively substitute `dict` into `expr`. Use `Symbolics.simplify` on the expre
16991703
if `simplify == true`.
17001704
"""
17011705
function substitute_and_simplify(expr, dict::AbstractDict, simplify::Bool)
1702-
expr = Symbolics.fixpoint_sub(expr, dict; operator = Union{ModelingToolkitBase.Initial, Pre})
1706+
expr = Symbolics.fixpoint_sub(expr, dict, Union{Initial, Pre})
17031707
simplify ? Symbolics.simplify(expr) : expr
17041708
end
17051709

lib/ModelingToolkitBase/src/systems/callbacks.jl

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ function AffectSystem(affect::Vector{Equation}; discrete_parameters = SymbolicT[
8989
if !haspre(eq) && !(isconst(eq.lhs) && isconst(eq.rhs))
9090
@invokelatest warn_algebraic_equation(eq)
9191
end
92-
collect_vars!(dvs, params, eq, iv; op = Pre)
92+
collect_vars!(dvs, params, eq, iv, Pre)
9393
empty!(_varsbuf)
9494
SU.search_variables!(_varsbuf, eq; is_atomic = OperatorIsAtomic{Pre}())
9595
filter!(x -> iscall(x) && operation(x) === Pre(), _varsbuf)
@@ -125,20 +125,35 @@ function AffectSystem(affect::Vector{Equation}; discrete_parameters = SymbolicT[
125125
# This `@invokelatest` should not be necessary, but it works around the inference bug
126126
# in https://github.com/JuliaLang/julia/issues/59943. Remove it at your own risk, the
127127
# bug took weeks to reduce to an MWE.
128-
affectsys = @invokelatest mtkcompile(affectsys; fully_determined = nothing)
128+
affectsys = (@invokelatest mtkcompile(affectsys; fully_determined = nothing))::System
129129
# get accessed parameters p from Pre(p) in the callback parameters
130130
accessed_params = Vector{SymbolicT}(filter(isparameter, map(unPre, collect(pre_params))))
131131
union!(accessed_params, sys_params)
132132

133133
# add scalarized unknowns to the map.
134134
_obs, _ = unhack_observed(observed(affectsys), equations(affectsys))
135135
_dvs = vcat(unknowns(affectsys), map(eq -> eq.lhs, _obs))
136-
_dvs = reduce(vcat, map(safe_vec scalarize, _dvs), init = SymbolicT[])
137-
_discs = reduce(vcat, map(safe_vec scalarize, discretes); init = SymbolicT[])
136+
_dvs = __safe_scalarize_vars(_dvs)
137+
_discs = __safe_scalarize_vars(discretes)
138138
setdiff!(_dvs, _discs)
139139
AffectSystem(affectsys, _dvs, accessed_params, discrete_parameters)
140140
end
141141

142+
function __safe_scalarize_vars(vars::Vector{SymbolicT})
143+
_vars = SymbolicT[]
144+
for v in vars
145+
sh = SU.shape(v)::SU.ShapeVecT
146+
if isempty(sh)
147+
push!(_vars, v)
148+
continue
149+
end
150+
for i in SU.stable_eachindex(v)
151+
push!(_vars, v[i])
152+
end
153+
end
154+
return _vars
155+
end
156+
142157
safe_vec(@nospecialize(x)) = x isa SymbolicT ? [x] : vec(x::Array{SymbolicT})
143158

144159
system(a::AffectSystem) = a.system
@@ -1043,13 +1058,13 @@ The `SymbolicDiscreteCallback`s in the returned vector are structs with two fiel
10431058
See also `get_discrete_events`, which only returns the events of the top-level system.
10441059
"""
10451060
function discrete_events(sys::AbstractSystem)
1046-
obs = get_discrete_events(sys)
1061+
cbs = get_discrete_events(sys)
10471062
systems = get_systems(sys)
1048-
cbs = [obs;
1049-
reduce(vcat,
1050-
(map(cb -> namespace_callback(cb, s), discrete_events(s)) for s in systems),
1051-
init = SymbolicDiscreteCallback[])]
1052-
cbs
1063+
cbs = copy(cbs)
1064+
for s in systems
1065+
append!(cbs, map(Base.Fix2(namespace_callback, s), discrete_events(s)))
1066+
end
1067+
return cbs
10531068
end
10541069

10551070
"""
@@ -1100,15 +1115,13 @@ The `SymbolicContinuousCallback`s in the returned vector are structs with two fi
11001115
See also `get_continuous_events`, which only returns the events of the top-level system.
11011116
"""
11021117
function continuous_events(sys::AbstractSystem)
1103-
obs = get_continuous_events(sys)
1104-
filter(!isempty, obs)
1105-
1118+
cbs = get_continuous_events(sys)
11061119
systems = get_systems(sys)
1107-
cbs = [obs;
1108-
reduce(vcat,
1109-
(map(o -> namespace_callback(o, s), continuous_events(s)) for s in systems),
1110-
init = SymbolicContinuousCallback[])]
1111-
filter(!isempty, cbs)
1120+
cbs = copy(cbs)
1121+
for s in systems
1122+
append!(cbs, map(Base.Fix2(namespace_callback, s), continuous_events(s)))
1123+
end
1124+
return cbs
11121125
end
11131126

11141127
"""

lib/ModelingToolkitBase/src/systems/codegen_utils.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ reconstruct array variables if they are present scalarized in `args`.
4242
"""
4343
function array_variable_assignments(args...; argument_name = generated_argument_name)
4444
# map array symbolic to an identically sized array where each element is (buffer_idx, idx_in_buffer)
45-
var_to_arridxs = Dict{BasicSymbolic, Array{Tuple{Int, Int}}}()
45+
var_to_arridxs = Dict{SymbolicT, Vector{Tuple{Int, Int}}}()
4646
for (i, arg) in enumerate(args)
4747
# filter out non-arrays
4848
# any element of args which is not an array is assumed to not contain a
@@ -55,13 +55,12 @@ function array_variable_assignments(args...; argument_name = generated_argument_
5555
for (j, var) in enumerate(arg)
5656
var = unwrap(var)
5757
# filter out non-array-symbolics
58-
iscall(var) || continue
59-
operation(var) == getindex || continue
60-
arrvar = arguments(var)[1]
58+
arrvar, isarr = split_indexed_var(var)
59+
isarr || continue
6160
# get and/or construct the buffer storing indexes
6261
idxbuffer = get!(
63-
() -> map(Returns((0, 0)), eachindex(arrvar)), var_to_arridxs, arrvar)
64-
Origin(first.(axes(arrvar))...)(idxbuffer)[arguments(var)[2:end]...] = (i, j)
62+
() -> map(Returns((0, 0)), SU.stable_eachindex(arrvar)), var_to_arridxs, arrvar)
63+
idxbuffer[SU.as_linear_idx(SU.shape(arrvar), get_stable_index(var))] = (i, j)
6564
end
6665
end
6766

lib/ModelingToolkitBase/src/systems/system.jl

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -381,12 +381,12 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = SymbolicT[];
381381
continuous_events = SymbolicContinuousCallback[], discrete_events = SymbolicDiscreteCallback[],
382382
connector_type = nothing, assertions = Dict{SymbolicT, String}(),
383383
metadata = MetadataT(), gui_metadata = nothing,
384-
is_dde = nothing, tstops = [], inputs = OrderedSet{SymbolicT}(),
384+
is_dde = nothing, @nospecialize(tstops = []), inputs = OrderedSet{SymbolicT}(),
385385
outputs = OrderedSet{SymbolicT}(), tearing_state = nothing,
386386
ignored_connections = nothing, parent = nothing,
387387
description = "", name = nothing, discover_from_metadata = true,
388388
initializesystem = nothing, is_initializesystem = false, is_discrete = false,
389-
preface = [], checks = true, __legacy_defaults__ = nothing)
389+
@nospecialize(preface = nothing), checks = true, __legacy_defaults__ = nothing)
390390
name === nothing && throw(NoNameError())
391391

392392
if __legacy_defaults__ !== nothing
@@ -480,14 +480,17 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = SymbolicT[];
480480
filter!(!(Base.Fix1(===, COMMON_NOTHING) last), guesses)
481481

482482
if iv === nothing
483-
filter!(bindings) do kvp
484-
k = kvp[1]
485-
if k in all_dvs
486-
initial_conditions[k] = kvp[2]
487-
return false
483+
filterer = let initial_conditions = initial_conditions, all_dvs = all_dvs
484+
function _filterer(kvp)
485+
k = kvp[1]
486+
if k in all_dvs
487+
initial_conditions[k] = kvp[2]
488+
return false
489+
end
490+
return true
488491
end
489-
return true
490492
end
493+
filter!(filterer, bindings)
491494
end
492495

493496
check_bindings(ps, bindings)
@@ -683,19 +686,20 @@ Create a `System` with a single equation `eq`.
683686
System(eq::Equation, args...; kwargs...) = System([eq], args...; kwargs...)
684687

685688
function gather_array_params(ps)
686-
new_ps = OrderedSet()
689+
new_ps = OrderedSet{SymbolicT}()
687690
for p in ps
688-
if iscall(p) && operation(p) === getindex
689-
par = arguments(p)[begin]
690-
if symbolic_has_known_size(par) && all(par[i] in ps for i in eachindex(par))
691-
push!(new_ps, par)
691+
arr, isarr = split_indexed_var(p)
692+
sh = SU.shape(arr)
693+
if isarr
694+
if !(sh isa SU.Unknown) && all(in(ps) Base.Fix1(getindex, arr), SU.stable_eachindex(arr))
695+
push!(new_ps, arr)
692696
else
693697
push!(new_ps, p)
694698
end
695699
else
696-
if symbolic_type(p) == ArraySymbolic() && symbolic_has_known_size(p)
697-
for i in eachindex(p)
698-
delete!(new_ps, p[i])
700+
if sh isa SU.ShapeVecT && !isempty(sh)
701+
for i in SU.stable_eachindex(arr)
702+
delete!(new_ps, arr[i])
699703
end
700704
end
701705
push!(new_ps, p)

lib/ModelingToolkitBase/src/systems/systems.jl

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@ function canonicalize_io(iovars, type::String)
1010
iobuffer = OrderedSet{SymbolicT}()
1111
arrsyms = AtomicArrayDict{OrderedSet{SymbolicT}}()
1212
for var in iovars
13-
if Symbolics.isarraysymbolic(var)
14-
if !symbolic_has_known_size(var)
15-
throw(ArgumentError("""
16-
All $(type)s must have known shape. Found $var with unknown shape.
17-
"""))
13+
sh = SU.shape(var)
14+
if SU.is_array_shape(sh)
15+
if sh isa SU.ShapeVecT
16+
union!(iobuffer, vec(collect(var)::Array{SymbolicT})::Vector{SymbolicT})
17+
continue
1818
end
19-
union!(iobuffer, vec(collect(var)::Array{SymbolicT})::Vector{SymbolicT})
20-
continue
19+
throw(ArgumentError("""
20+
All $(type)s must have known shape. Found $var with unknown shape.
21+
"""))
2122
end
2223
arr, isarr = split_indexed_var(var)
2324
if isarr
@@ -41,7 +42,7 @@ function canonicalize_io(iovars, type::String)
4142
or simply pass $k as an $type.
4243
"""))
4344
end
44-
if type != "output" && !isequal(vec(collect(k))::Vector{SymbolicT}, collect(v))
45+
if type != "output" && !isequal(vec(collect(k)::Array{SymbolicT})::Vector{SymbolicT}, collect(v))
4546
throw(ArgumentError("""
4647
Elements of scalarized array variables must be in sorted order in $(type)s. \
4748
Either pass all scalarized elements in sorted order as $(type)s \
@@ -601,16 +602,17 @@ function __num_isdiag_noise(mat)
601602
true
602603
end
603604

604-
function __get_num_diag_noise(mat)
605-
map(axes(mat, 1)) do i
605+
function __get_num_diag_noise(mat::Matrix{SymbolicT})
606+
result = fill(Symbolics.COMMON_ZERO, size(mat, 1))
607+
for i in axes(mat, 1)
606608
for j in axes(mat, 2)
607609
mij = mat[i, j]
608-
if !_iszero(mij)
609-
return mij
610-
end
610+
_iszero(mij) && continue
611+
result[i] = mij
612+
break
611613
end
612-
0
613614
end
615+
return result
614616
end
615617

616618
"""

0 commit comments

Comments
 (0)