diff --git a/lib/ModelingToolkitBase/src/systems/abstractsystem.jl b/lib/ModelingToolkitBase/src/systems/abstractsystem.jl index 7112da2c57..9a92195c26 100644 --- a/lib/ModelingToolkitBase/src/systems/abstractsystem.jl +++ b/lib/ModelingToolkitBase/src/systems/abstractsystem.jl @@ -1261,18 +1261,18 @@ namespace_parameters(sys::AbstractSystem) = parameters(sys, parameters(sys)) namespace_guesses(sys::AbstractSystem) = namespace_expr(guesses(sys), sys) """ - $(TYPEDSIGNATURES) + namespace_equations(sys::AbstractSystem) Return `equations(sys)`, namespaced by the name of `sys`. """ -function namespace_equations(sys::AbstractSystem, ivs = independent_variables(sys)) - eqs = equations(sys) +function namespace_equations(sys::AbstractSystem, visitor = NoVisitor()) + eqs = equations(sys, visitor) isempty(eqs) && return eqs if eqs === get_eqs(sys) eqs = copy(eqs) end for i in eachindex(eqs) - eqs[i] = namespace_equation(eqs[i], sys; ivs) + eqs[i] = namespace_equation(eqs[i], sys) end return eqs end @@ -1661,7 +1661,84 @@ end flatten(sys::AbstractSystem, args...) = sys """ -$(TYPEDSIGNATURES) + $TYPEDEF + +Abstract supertype for functors that can be passed to recursive functions such as +[`equations`](@ref) to track additional information. +""" +abstract type AbstractRecursivePropertyVisitor end + +""" + descend_visitor!(visitor::AbstractRecursivePropertyVisitor, sys::AbstractSystem, f) + +Descend the `visitor` into system `sys`. Also provide the getter function `f` for the +property the recursive function handles (e.g. `get_eqs` for `equations`). +""" +function descend_visitor! end + +""" + ascend_visitor!(visitor::AbstractRecursivePropertyVisitor, sys::AbstractSystem, f) + +Ascend the `visitor` from system `sys` into the parent, marking that all its subsystems +have been explored. Also provide the getter function `f` for the property the recursive +function handles (e.g. `get_eqs` for `equations`). +""" +function ascend_visitor! end + +""" + $TYPEDEF + +Dummy visitor +""" +struct NoVisitor <: AbstractRecursivePropertyVisitor end +descend_visitor!(::NoVisitor, ::AbstractSystem, _) = nothing +ascend_visitor!(::NoVisitor, ::AbstractSystem, _) = nothing + +""" + $TYPEDEF + +Visitor that tracks source information +""" +struct SourceInformationVisitor <: AbstractRecursivePropertyVisitor + """ + List of names indicating the subsystem containing each value as a path from the root. + Names are in reverse order (root occurs last). + """ + sources::Vector{Vector{Symbol}} + """ + A stack of indices indicating the index where source entries belonging to each system + in the call stack start. + """ + start_positions_stack::Vector{Int} +end + +SourceInformationVisitor() = SourceInformationVisitor(Vector{Symbol}[], Int[]) + +function descend_visitor!(vis::SourceInformationVisitor, sys::AbstractSystem, f) + (; sources, start_positions_stack) = vis + # The sources for equations in this system start from the next valid index + start = length(sources) + 1 + push!(start_positions_stack, start) + # Add source information for the current system + for _ in f(sys) + push!(sources, Symbol[]) + end +end + +function ascend_visitor!(vis::SourceInformationVisitor, sys::AbstractSystem, f) + (; sources, start_positions_stack) = vis + # Get the start position for `sys`. We know we've explored all subsystems of `sys`. + cur_start = pop!(start_positions_stack) + # Since the search is DFS, all entries in `sources` from `cur_start` till the + # end are inside `sys`, so add the name to them. + name = nameof(sys) + for i in cur_start:lastindex(sources) + push!(sources[i], name) + end +end + +""" + equations(sys::AbstractSystem) Get the flattened equations of the system `sys` and its subsystems. It may include some abbreviations and aliases of observables. @@ -1669,17 +1746,40 @@ It is often the most useful way to inspect the equations of a system. See also [`full_equations`](@ref) and [`ModelingToolkitBase.get_eqs`](@ref). """ -function equations(sys::AbstractSystem) +function equations(sys::AbstractSystem, visitor::AbstractRecursivePropertyVisitor = NoVisitor()) eqs = get_eqs(sys) systems = get_systems(sys) - isempty(systems) && return eqs + descend_visitor!(visitor, sys, get_eqs) + if isempty(systems) + ascend_visitor!(visitor, sys, get_eqs) + return eqs + end eqs = copy(eqs) for subsys in systems - append!(eqs, namespace_equations(subsys)) + append!(eqs, namespace_equations(subsys, visitor)) end + ascend_visitor!(visitor, sys, get_eqs) return eqs end +function equations_source(sys::AbstractSystem) + source = Vector{Symbol}[] + for _ in eachindex(get_eqs(sys)) + push!(source, Symbol[]) + end + systems = get_systems(sys) + isempty(systems) && return source + + for subsys in systems + name = nameof(subsys) + sub_sources = equations_source(subsys) + for src in sub_sources + push!(src, name) + end + append!(source, name) + end +end + """ equations_toplevel(sys::AbstractSystem) diff --git a/lib/ModelingToolkitBase/src/systems/connectors.jl b/lib/ModelingToolkitBase/src/systems/connectors.jl index 2a199b338f..a2444154c0 100644 --- a/lib/ModelingToolkitBase/src/systems/connectors.jl +++ b/lib/ModelingToolkitBase/src/systems/connectors.jl @@ -967,6 +967,30 @@ function get_domain_bindings( return binds end +""" + $TYPEDEF + +Struct that can optionally be returned from [`expand_connections`](@ref) and indicates +where in the system the equations come from, along with other potentially useful source +information. + +# Fields + +$TYPEDFIELDS +""" +struct EquationSourceInformation + """ + For each equation, a `Vector{Symbol}` denoting the path from the root system to the + subsystem where this equation comes from. Includes the name of the root system. An + empty entry indicates unknown source (typically for connection equations). + """ + eqs_source::Vector{Vector{Symbol}} + """ + A mask indicating which equations arise from `connect` statements. + """ + is_connection_equation::BitVector +end + """ $(TYPEDSIGNATURES) @@ -974,7 +998,7 @@ Given a hierarchical system with [`connect`](@ref) equations, expand the connect equations and return the new system. `tol` is the tolerance for handling the singularities in stream connection equations that happen when a flow variable approaches zero. """ -function expand_connections(sys::AbstractSystem; tol = 1e-10) +function expand_connections(sys::AbstractSystem, ::Val{with_source_info} = Val(false); tol = 1e-10) where {with_source_info} # turn analysis points into standard connection equations sys = remove_analysis_points(sys) # generate the connection sets @@ -983,7 +1007,31 @@ function expand_connections(sys::AbstractSystem; tol = 1e-10) ceqs, instream_csets = generate_connection_equations_and_stream_connections(sys, csets) stream_eqs, instream_subs = expand_instream(instream_csets, sys; tol = tol) - eqs = [equations(sys); ceqs; stream_eqs] + if with_source_info + source_visitor = SourceInformationVisitor() + eqs = equations(sys, source_visitor) + N = length(eqs) + length(ceqs) + length(stream_eqs) + sources = source_visitor.sources + # Names are in reverse order + foreach(reverse!, sources) + is_connection_equation = falses(length(sources)) + sizehint!(eqs, N) + sizehint!(sources, N) + sizehint!(is_connection_equation, N) + for eq in ceqs + push!(eqs, eq) + push!(sources, Symbol[]) + push!(is_connection_equation, true) + end + for eq in stream_eqs + push!(eqs, eq) + push!(sources, Symbol[]) + push!(is_connection_equation, true) + end + source_info = EquationSourceInformation(sources, is_connection_equation) + else + eqs = [equations(sys); ceqs; stream_eqs] + end if !isempty(instream_subs) # substitute `instream(..)` expressions with their new values for i in eachindex(eqs) @@ -996,7 +1044,13 @@ function expand_connections(sys::AbstractSystem; tol = 1e-10) # build the new system sys = flatten(sys, true) @set! sys.eqs = eqs - @set sys.bindings = newbinds + @set! sys.bindings = newbinds + + if with_source_info + return sys, source_info + else + return sys + end end """ diff --git a/lib/ModelingToolkitBase/src/utils.jl b/lib/ModelingToolkitBase/src/utils.jl index 833ef016ea..a318311956 100644 --- a/lib/ModelingToolkitBase/src/utils.jl +++ b/lib/ModelingToolkitBase/src/utils.jl @@ -1201,6 +1201,25 @@ end _eq_unordered(a, b) = isequal(a, b) +""" + $TYPEDSIGNATURES + +Given an equation that may be an array equation, return a `Vector{Equation}` representing +its scalarized form. +""" +function flatten_equation(eq::Equation)::Vector{Equation} + if !SU.is_array_shape(SU.shape(eq.lhs)) + return [eq] + end + lhs = vec(collect(eq.lhs)::Array{SymbolicT})::Vector{SymbolicT} + rhs = vec(collect(eq.rhs)::Array{SymbolicT})::Vector{SymbolicT} + result = Equation[] + for (l, r) in zip(lhs, rhs) + push!(result, l ~ r) + end + return result +end + """ $(TYPEDSIGNATURES) @@ -1210,15 +1229,7 @@ without scalarizing occurrences of array variables and return the new list of eq function flatten_equations(eqs::Vector{Equation}) _eqs = Equation[] for eq in eqs - if !SU.is_array_shape(SU.shape(eq.lhs)) - push!(_eqs, eq) - continue - end - lhs = vec(collect(eq.lhs)::Array{SymbolicT})::Vector{SymbolicT} - rhs = vec(collect(eq.rhs)::Array{SymbolicT})::Vector{SymbolicT} - for (l, r) in zip(lhs, rhs) - push!(_eqs, l ~ r) - end + append!(_eqs, flatten_equation(eq)) end return _eqs end diff --git a/lib/ModelingToolkitBase/test/components.jl b/lib/ModelingToolkitBase/test/components.jl index 836f8760dc..183b321a6c 100644 --- a/lib/ModelingToolkitBase/test/components.jl +++ b/lib/ModelingToolkitBase/test/components.jl @@ -388,3 +388,35 @@ end # as opposed to `output.u ~ input.u` @test isequal(eq, comp1.input.u ~ comp2.output.u) end + +@testset "Source information propagation through `expand_connections`" begin + @named rc_model = RCModel() + sys, source_info = expand_connections(rc_model, Val(true)) + @test source_info.eqs_source == [ + [:rc_model, :resistor], + [:rc_model, :resistor], + [:rc_model, :resistor], + [:rc_model, :resistor], + [:rc_model, :capacitor], + [:rc_model, :capacitor], + [:rc_model, :capacitor], + [:rc_model, :capacitor], + [:rc_model, :shape], + [:rc_model, :source], + [:rc_model, :source], + [:rc_model, :source], + [:rc_model, :source], + [:rc_model, :ground], + Symbol[], + Symbol[], + Symbol[], + Symbol[], + Symbol[], + Symbol[], + Symbol[], + Symbol[] + ] + is_connect_truth = falses(22) + is_connect_truth[end-7:end] .= true + @test source_info.is_connection_equation == is_connect_truth +end