Skip to content

Commit fba2156

Browse files
Merge pull request #4059 from SciML/as/source-info
feat: propagate equation source information through `expand_connections`
2 parents 2fd0397 + 1f61fd3 commit fba2156

File tree

4 files changed

+217
-20
lines changed

4 files changed

+217
-20
lines changed

lib/ModelingToolkitBase/src/systems/abstractsystem.jl

Lines changed: 108 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1261,18 +1261,18 @@ namespace_parameters(sys::AbstractSystem) = parameters(sys, parameters(sys))
12611261
namespace_guesses(sys::AbstractSystem) = namespace_expr(guesses(sys), sys)
12621262

12631263
"""
1264-
$(TYPEDSIGNATURES)
1264+
namespace_equations(sys::AbstractSystem)
12651265
12661266
Return `equations(sys)`, namespaced by the name of `sys`.
12671267
"""
1268-
function namespace_equations(sys::AbstractSystem, ivs = independent_variables(sys))
1269-
eqs = equations(sys)
1268+
function namespace_equations(sys::AbstractSystem, visitor = NoVisitor())
1269+
eqs = equations(sys, visitor)
12701270
isempty(eqs) && return eqs
12711271
if eqs === get_eqs(sys)
12721272
eqs = copy(eqs)
12731273
end
12741274
for i in eachindex(eqs)
1275-
eqs[i] = namespace_equation(eqs[i], sys; ivs)
1275+
eqs[i] = namespace_equation(eqs[i], sys)
12761276
end
12771277
return eqs
12781278
end
@@ -1661,25 +1661,125 @@ end
16611661
flatten(sys::AbstractSystem, args...) = sys
16621662

16631663
"""
1664-
$(TYPEDSIGNATURES)
1664+
$TYPEDEF
1665+
1666+
Abstract supertype for functors that can be passed to recursive functions such as
1667+
[`equations`](@ref) to track additional information.
1668+
"""
1669+
abstract type AbstractRecursivePropertyVisitor end
1670+
1671+
"""
1672+
descend_visitor!(visitor::AbstractRecursivePropertyVisitor, sys::AbstractSystem, f)
1673+
1674+
Descend the `visitor` into system `sys`. Also provide the getter function `f` for the
1675+
property the recursive function handles (e.g. `get_eqs` for `equations`).
1676+
"""
1677+
function descend_visitor! end
1678+
1679+
"""
1680+
ascend_visitor!(visitor::AbstractRecursivePropertyVisitor, sys::AbstractSystem, f)
1681+
1682+
Ascend the `visitor` from system `sys` into the parent, marking that all its subsystems
1683+
have been explored. Also provide the getter function `f` for the property the recursive
1684+
function handles (e.g. `get_eqs` for `equations`).
1685+
"""
1686+
function ascend_visitor! end
1687+
1688+
"""
1689+
$TYPEDEF
1690+
1691+
Dummy visitor
1692+
"""
1693+
struct NoVisitor <: AbstractRecursivePropertyVisitor end
1694+
descend_visitor!(::NoVisitor, ::AbstractSystem, _) = nothing
1695+
ascend_visitor!(::NoVisitor, ::AbstractSystem, _) = nothing
1696+
1697+
"""
1698+
$TYPEDEF
1699+
1700+
Visitor that tracks source information
1701+
"""
1702+
struct SourceInformationVisitor <: AbstractRecursivePropertyVisitor
1703+
"""
1704+
List of names indicating the subsystem containing each value as a path from the root.
1705+
Names are in reverse order (root occurs last).
1706+
"""
1707+
sources::Vector{Vector{Symbol}}
1708+
"""
1709+
A stack of indices indicating the index where source entries belonging to each system
1710+
in the call stack start.
1711+
"""
1712+
start_positions_stack::Vector{Int}
1713+
end
1714+
1715+
SourceInformationVisitor() = SourceInformationVisitor(Vector{Symbol}[], Int[])
1716+
1717+
function descend_visitor!(vis::SourceInformationVisitor, sys::AbstractSystem, f)
1718+
(; sources, start_positions_stack) = vis
1719+
# The sources for equations in this system start from the next valid index
1720+
start = length(sources) + 1
1721+
push!(start_positions_stack, start)
1722+
# Add source information for the current system
1723+
for _ in f(sys)
1724+
push!(sources, Symbol[])
1725+
end
1726+
end
1727+
1728+
function ascend_visitor!(vis::SourceInformationVisitor, sys::AbstractSystem, f)
1729+
(; sources, start_positions_stack) = vis
1730+
# Get the start position for `sys`. We know we've explored all subsystems of `sys`.
1731+
cur_start = pop!(start_positions_stack)
1732+
# Since the search is DFS, all entries in `sources` from `cur_start` till the
1733+
# end are inside `sys`, so add the name to them.
1734+
name = nameof(sys)
1735+
for i in cur_start:lastindex(sources)
1736+
push!(sources[i], name)
1737+
end
1738+
end
1739+
1740+
"""
1741+
equations(sys::AbstractSystem)
16651742
16661743
Get the flattened equations of the system `sys` and its subsystems.
16671744
It may include some abbreviations and aliases of observables.
16681745
It is often the most useful way to inspect the equations of a system.
16691746
16701747
See also [`full_equations`](@ref) and [`ModelingToolkitBase.get_eqs`](@ref).
16711748
"""
1672-
function equations(sys::AbstractSystem)
1749+
function equations(sys::AbstractSystem, visitor::AbstractRecursivePropertyVisitor = NoVisitor())
16731750
eqs = get_eqs(sys)
16741751
systems = get_systems(sys)
1675-
isempty(systems) && return eqs
1752+
descend_visitor!(visitor, sys, get_eqs)
1753+
if isempty(systems)
1754+
ascend_visitor!(visitor, sys, get_eqs)
1755+
return eqs
1756+
end
16761757
eqs = copy(eqs)
16771758
for subsys in systems
1678-
append!(eqs, namespace_equations(subsys))
1759+
append!(eqs, namespace_equations(subsys, visitor))
16791760
end
1761+
ascend_visitor!(visitor, sys, get_eqs)
16801762
return eqs
16811763
end
16821764

1765+
function equations_source(sys::AbstractSystem)
1766+
source = Vector{Symbol}[]
1767+
for _ in eachindex(get_eqs(sys))
1768+
push!(source, Symbol[])
1769+
end
1770+
systems = get_systems(sys)
1771+
isempty(systems) && return source
1772+
1773+
for subsys in systems
1774+
name = nameof(subsys)
1775+
sub_sources = equations_source(subsys)
1776+
for src in sub_sources
1777+
push!(src, name)
1778+
end
1779+
append!(source, name)
1780+
end
1781+
end
1782+
16831783
"""
16841784
equations_toplevel(sys::AbstractSystem)
16851785

lib/ModelingToolkitBase/src/systems/connectors.jl

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -967,14 +967,38 @@ function get_domain_bindings(
967967
return binds
968968
end
969969

970+
"""
971+
$TYPEDEF
972+
973+
Struct that can optionally be returned from [`expand_connections`](@ref) and indicates
974+
where in the system the equations come from, along with other potentially useful source
975+
information.
976+
977+
# Fields
978+
979+
$TYPEDFIELDS
980+
"""
981+
struct EquationSourceInformation
982+
"""
983+
For each equation, a `Vector{Symbol}` denoting the path from the root system to the
984+
subsystem where this equation comes from. Includes the name of the root system. An
985+
empty entry indicates unknown source (typically for connection equations).
986+
"""
987+
eqs_source::Vector{Vector{Symbol}}
988+
"""
989+
A mask indicating which equations arise from `connect` statements.
990+
"""
991+
is_connection_equation::BitVector
992+
end
993+
970994
"""
971995
$(TYPEDSIGNATURES)
972996
973997
Given a hierarchical system with [`connect`](@ref) equations, expand the connection
974998
equations and return the new system. `tol` is the tolerance for handling the singularities
975999
in stream connection equations that happen when a flow variable approaches zero.
9761000
"""
977-
function expand_connections(sys::AbstractSystem; tol = 1e-10)
1001+
function expand_connections(sys::AbstractSystem, ::Val{with_source_info} = Val(false); tol = 1e-10) where {with_source_info}
9781002
# turn analysis points into standard connection equations
9791003
sys = remove_analysis_points(sys)
9801004
# generate the connection sets
@@ -983,7 +1007,31 @@ function expand_connections(sys::AbstractSystem; tol = 1e-10)
9831007
ceqs, instream_csets = generate_connection_equations_and_stream_connections(sys, csets)
9841008
stream_eqs, instream_subs = expand_instream(instream_csets, sys; tol = tol)
9851009

986-
eqs = [equations(sys); ceqs; stream_eqs]
1010+
if with_source_info
1011+
source_visitor = SourceInformationVisitor()
1012+
eqs = equations(sys, source_visitor)
1013+
N = length(eqs) + length(ceqs) + length(stream_eqs)
1014+
sources = source_visitor.sources
1015+
# Names are in reverse order
1016+
foreach(reverse!, sources)
1017+
is_connection_equation = falses(length(sources))
1018+
sizehint!(eqs, N)
1019+
sizehint!(sources, N)
1020+
sizehint!(is_connection_equation, N)
1021+
for eq in ceqs
1022+
push!(eqs, eq)
1023+
push!(sources, Symbol[])
1024+
push!(is_connection_equation, true)
1025+
end
1026+
for eq in stream_eqs
1027+
push!(eqs, eq)
1028+
push!(sources, Symbol[])
1029+
push!(is_connection_equation, true)
1030+
end
1031+
source_info = EquationSourceInformation(sources, is_connection_equation)
1032+
else
1033+
eqs = [equations(sys); ceqs; stream_eqs]
1034+
end
9871035
if !isempty(instream_subs)
9881036
# substitute `instream(..)` expressions with their new values
9891037
for i in eachindex(eqs)
@@ -996,7 +1044,13 @@ function expand_connections(sys::AbstractSystem; tol = 1e-10)
9961044
# build the new system
9971045
sys = flatten(sys, true)
9981046
@set! sys.eqs = eqs
999-
@set sys.bindings = newbinds
1047+
@set! sys.bindings = newbinds
1048+
1049+
if with_source_info
1050+
return sys, source_info
1051+
else
1052+
return sys
1053+
end
10001054
end
10011055

10021056
"""

lib/ModelingToolkitBase/src/utils.jl

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,6 +1201,25 @@ end
12011201

12021202
_eq_unordered(a, b) = isequal(a, b)
12031203

1204+
"""
1205+
$TYPEDSIGNATURES
1206+
1207+
Given an equation that may be an array equation, return a `Vector{Equation}` representing
1208+
its scalarized form.
1209+
"""
1210+
function flatten_equation(eq::Equation)::Vector{Equation}
1211+
if !SU.is_array_shape(SU.shape(eq.lhs))
1212+
return [eq]
1213+
end
1214+
lhs = vec(collect(eq.lhs)::Array{SymbolicT})::Vector{SymbolicT}
1215+
rhs = vec(collect(eq.rhs)::Array{SymbolicT})::Vector{SymbolicT}
1216+
result = Equation[]
1217+
for (l, r) in zip(lhs, rhs)
1218+
push!(result, l ~ r)
1219+
end
1220+
return result
1221+
end
1222+
12041223
"""
12051224
$(TYPEDSIGNATURES)
12061225
@@ -1210,15 +1229,7 @@ without scalarizing occurrences of array variables and return the new list of eq
12101229
function flatten_equations(eqs::Vector{Equation})
12111230
_eqs = Equation[]
12121231
for eq in eqs
1213-
if !SU.is_array_shape(SU.shape(eq.lhs))
1214-
push!(_eqs, eq)
1215-
continue
1216-
end
1217-
lhs = vec(collect(eq.lhs)::Array{SymbolicT})::Vector{SymbolicT}
1218-
rhs = vec(collect(eq.rhs)::Array{SymbolicT})::Vector{SymbolicT}
1219-
for (l, r) in zip(lhs, rhs)
1220-
push!(_eqs, l ~ r)
1221-
end
1232+
append!(_eqs, flatten_equation(eq))
12221233
end
12231234
return _eqs
12241235
end

lib/ModelingToolkitBase/test/components.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,3 +388,35 @@ end
388388
# as opposed to `output.u ~ input.u`
389389
@test isequal(eq, comp1.input.u ~ comp2.output.u)
390390
end
391+
392+
@testset "Source information propagation through `expand_connections`" begin
393+
@named rc_model = RCModel()
394+
sys, source_info = expand_connections(rc_model, Val(true))
395+
@test source_info.eqs_source == [
396+
[:rc_model, :resistor],
397+
[:rc_model, :resistor],
398+
[:rc_model, :resistor],
399+
[:rc_model, :resistor],
400+
[:rc_model, :capacitor],
401+
[:rc_model, :capacitor],
402+
[:rc_model, :capacitor],
403+
[:rc_model, :capacitor],
404+
[:rc_model, :shape],
405+
[:rc_model, :source],
406+
[:rc_model, :source],
407+
[:rc_model, :source],
408+
[:rc_model, :source],
409+
[:rc_model, :ground],
410+
Symbol[],
411+
Symbol[],
412+
Symbol[],
413+
Symbol[],
414+
Symbol[],
415+
Symbol[],
416+
Symbol[],
417+
Symbol[]
418+
]
419+
is_connect_truth = falses(22)
420+
is_connect_truth[end-7:end] .= true
421+
@test source_info.is_connection_equation == is_connect_truth
422+
end

0 commit comments

Comments
 (0)