Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 108 additions & 8 deletions lib/ModelingToolkitBase/src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1261,18 +1261,18 @@ namespace_parameters(sys::AbstractSystem) = parameters(sys, parameters(sys))
namespace_guesses(sys::AbstractSystem) = namespace_expr(guesses(sys), sys)

"""
$(TYPEDSIGNATURES)
namespace_equations(sys::AbstractSystem)

Return `equations(sys)`, namespaced by the name of `sys`.
"""
function namespace_equations(sys::AbstractSystem, ivs = independent_variables(sys))
eqs = equations(sys)
function namespace_equations(sys::AbstractSystem, visitor = NoVisitor())
eqs = equations(sys, visitor)
isempty(eqs) && return eqs
if eqs === get_eqs(sys)
eqs = copy(eqs)
end
for i in eachindex(eqs)
eqs[i] = namespace_equation(eqs[i], sys; ivs)
eqs[i] = namespace_equation(eqs[i], sys)
end
return eqs
end
Expand Down Expand Up @@ -1661,25 +1661,125 @@ end
flatten(sys::AbstractSystem, args...) = sys

"""
$(TYPEDSIGNATURES)
$TYPEDEF

Abstract supertype for functors that can be passed to recursive functions such as
[`equations`](@ref) to track additional information.
"""
abstract type AbstractRecursivePropertyVisitor end

"""
descend_visitor!(visitor::AbstractRecursivePropertyVisitor, sys::AbstractSystem, f)

Descend the `visitor` into system `sys`. Also provide the getter function `f` for the
property the recursive function handles (e.g. `get_eqs` for `equations`).
"""
function descend_visitor! end

"""
ascend_visitor!(visitor::AbstractRecursivePropertyVisitor, sys::AbstractSystem, f)

Ascend the `visitor` from system `sys` into the parent, marking that all its subsystems
have been explored. Also provide the getter function `f` for the property the recursive
function handles (e.g. `get_eqs` for `equations`).
"""
function ascend_visitor! end

"""
$TYPEDEF

Dummy visitor
"""
struct NoVisitor <: AbstractRecursivePropertyVisitor end
descend_visitor!(::NoVisitor, ::AbstractSystem, _) = nothing
ascend_visitor!(::NoVisitor, ::AbstractSystem, _) = nothing

"""
$TYPEDEF

Visitor that tracks source information
"""
struct SourceInformationVisitor <: AbstractRecursivePropertyVisitor
"""
List of names indicating the subsystem containing each value as a path from the root.
Names are in reverse order (root occurs last).
"""
sources::Vector{Vector{Symbol}}
"""
A stack of indices indicating the index where source entries belonging to each system
in the call stack start.
"""
start_positions_stack::Vector{Int}
end

SourceInformationVisitor() = SourceInformationVisitor(Vector{Symbol}[], Int[])

function descend_visitor!(vis::SourceInformationVisitor, sys::AbstractSystem, f)
(; sources, start_positions_stack) = vis
# The sources for equations in this system start from the next valid index
start = length(sources) + 1
push!(start_positions_stack, start)
# Add source information for the current system
for _ in f(sys)
push!(sources, Symbol[])
end
end

function ascend_visitor!(vis::SourceInformationVisitor, sys::AbstractSystem, f)
(; sources, start_positions_stack) = vis
# Get the start position for `sys`. We know we've explored all subsystems of `sys`.
cur_start = pop!(start_positions_stack)
# Since the search is DFS, all entries in `sources` from `cur_start` till the
# end are inside `sys`, so add the name to them.
name = nameof(sys)
for i in cur_start:lastindex(sources)
push!(sources[i], name)
end
end

"""
equations(sys::AbstractSystem)

Get the flattened equations of the system `sys` and its subsystems.
It may include some abbreviations and aliases of observables.
It is often the most useful way to inspect the equations of a system.

See also [`full_equations`](@ref) and [`ModelingToolkitBase.get_eqs`](@ref).
"""
function equations(sys::AbstractSystem)
function equations(sys::AbstractSystem, visitor::AbstractRecursivePropertyVisitor = NoVisitor())
eqs = get_eqs(sys)
systems = get_systems(sys)
isempty(systems) && return eqs
descend_visitor!(visitor, sys, get_eqs)
if isempty(systems)
ascend_visitor!(visitor, sys, get_eqs)
return eqs
end
eqs = copy(eqs)
for subsys in systems
append!(eqs, namespace_equations(subsys))
append!(eqs, namespace_equations(subsys, visitor))
end
ascend_visitor!(visitor, sys, get_eqs)
return eqs
end

function equations_source(sys::AbstractSystem)
source = Vector{Symbol}[]
for _ in eachindex(get_eqs(sys))
push!(source, Symbol[])
end
systems = get_systems(sys)
isempty(systems) && return source

for subsys in systems
name = nameof(subsys)
sub_sources = equations_source(subsys)
for src in sub_sources
push!(src, name)
end
append!(source, name)
end
end

"""
equations_toplevel(sys::AbstractSystem)

Expand Down
60 changes: 57 additions & 3 deletions lib/ModelingToolkitBase/src/systems/connectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -967,14 +967,38 @@ function get_domain_bindings(
return binds
end

"""
$TYPEDEF

Struct that can optionally be returned from [`expand_connections`](@ref) and indicates
where in the system the equations come from, along with other potentially useful source
information.

# Fields

$TYPEDFIELDS
"""
struct EquationSourceInformation
"""
For each equation, a `Vector{Symbol}` denoting the path from the root system to the
subsystem where this equation comes from. Includes the name of the root system. An
empty entry indicates unknown source (typically for connection equations).
"""
eqs_source::Vector{Vector{Symbol}}
"""
A mask indicating which equations arise from `connect` statements.
"""
is_connection_equation::BitVector
end

"""
$(TYPEDSIGNATURES)

Given a hierarchical system with [`connect`](@ref) equations, expand the connection
equations and return the new system. `tol` is the tolerance for handling the singularities
in stream connection equations that happen when a flow variable approaches zero.
"""
function expand_connections(sys::AbstractSystem; tol = 1e-10)
function expand_connections(sys::AbstractSystem, ::Val{with_source_info} = Val(false); tol = 1e-10) where {with_source_info}
# turn analysis points into standard connection equations
sys = remove_analysis_points(sys)
# generate the connection sets
Expand All @@ -983,7 +1007,31 @@ function expand_connections(sys::AbstractSystem; tol = 1e-10)
ceqs, instream_csets = generate_connection_equations_and_stream_connections(sys, csets)
stream_eqs, instream_subs = expand_instream(instream_csets, sys; tol = tol)

eqs = [equations(sys); ceqs; stream_eqs]
if with_source_info
source_visitor = SourceInformationVisitor()
eqs = equations(sys, source_visitor)
N = length(eqs) + length(ceqs) + length(stream_eqs)
sources = source_visitor.sources
# Names are in reverse order
foreach(reverse!, sources)
is_connection_equation = falses(length(sources))
sizehint!(eqs, N)
sizehint!(sources, N)
sizehint!(is_connection_equation, N)
for eq in ceqs
push!(eqs, eq)
push!(sources, Symbol[])
push!(is_connection_equation, true)
end
for eq in stream_eqs
push!(eqs, eq)
push!(sources, Symbol[])
push!(is_connection_equation, true)
end
source_info = EquationSourceInformation(sources, is_connection_equation)
else
eqs = [equations(sys); ceqs; stream_eqs]
end
if !isempty(instream_subs)
# substitute `instream(..)` expressions with their new values
for i in eachindex(eqs)
Expand All @@ -996,7 +1044,13 @@ function expand_connections(sys::AbstractSystem; tol = 1e-10)
# build the new system
sys = flatten(sys, true)
@set! sys.eqs = eqs
@set sys.bindings = newbinds
@set! sys.bindings = newbinds

if with_source_info
return sys, source_info
else
return sys
end
end

"""
Expand Down
29 changes: 20 additions & 9 deletions lib/ModelingToolkitBase/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1201,6 +1201,25 @@ end

_eq_unordered(a, b) = isequal(a, b)

"""
$TYPEDSIGNATURES

Given an equation that may be an array equation, return a `Vector{Equation}` representing
its scalarized form.
"""
function flatten_equation(eq::Equation)::Vector{Equation}
if !SU.is_array_shape(SU.shape(eq.lhs))
return [eq]
end
lhs = vec(collect(eq.lhs)::Array{SymbolicT})::Vector{SymbolicT}
rhs = vec(collect(eq.rhs)::Array{SymbolicT})::Vector{SymbolicT}
result = Equation[]
for (l, r) in zip(lhs, rhs)
push!(result, l ~ r)
end
return result
end

"""
$(TYPEDSIGNATURES)

Expand All @@ -1210,15 +1229,7 @@ without scalarizing occurrences of array variables and return the new list of eq
function flatten_equations(eqs::Vector{Equation})
_eqs = Equation[]
for eq in eqs
if !SU.is_array_shape(SU.shape(eq.lhs))
push!(_eqs, eq)
continue
end
lhs = vec(collect(eq.lhs)::Array{SymbolicT})::Vector{SymbolicT}
rhs = vec(collect(eq.rhs)::Array{SymbolicT})::Vector{SymbolicT}
for (l, r) in zip(lhs, rhs)
push!(_eqs, l ~ r)
end
append!(_eqs, flatten_equation(eq))
end
return _eqs
end
Expand Down
32 changes: 32 additions & 0 deletions lib/ModelingToolkitBase/test/components.jl
Original file line number Diff line number Diff line change
Expand Up @@ -388,3 +388,35 @@ end
# as opposed to `output.u ~ input.u`
@test isequal(eq, comp1.input.u ~ comp2.output.u)
end

@testset "Source information propagation through `expand_connections`" begin
@named rc_model = RCModel()
sys, source_info = expand_connections(rc_model, Val(true))
@test source_info.eqs_source == [
[:rc_model, :resistor],
[:rc_model, :resistor],
[:rc_model, :resistor],
[:rc_model, :resistor],
[:rc_model, :capacitor],
[:rc_model, :capacitor],
[:rc_model, :capacitor],
[:rc_model, :capacitor],
[:rc_model, :shape],
[:rc_model, :source],
[:rc_model, :source],
[:rc_model, :source],
[:rc_model, :source],
[:rc_model, :ground],
Symbol[],
Symbol[],
Symbol[],
Symbol[],
Symbol[],
Symbol[],
Symbol[],
Symbol[]
]
is_connect_truth = falses(22)
is_connect_truth[end-7:end] .= true
@test source_info.is_connection_equation == is_connect_truth
end
Loading