Skip to content

Commit b26aa57

Browse files
feat: add visitor for recursive accessor functions
Currently only implemented for `equations`
1 parent 2fd0397 commit b26aa57

File tree

1 file changed

+108
-8
lines changed

1 file changed

+108
-8
lines changed

lib/ModelingToolkitBase/src/systems/abstractsystem.jl

Lines changed: 108 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1261,18 +1261,18 @@ namespace_parameters(sys::AbstractSystem) = parameters(sys, parameters(sys))
12611261
namespace_guesses(sys::AbstractSystem) = namespace_expr(guesses(sys), sys)
12621262

12631263
"""
1264-
$(TYPEDSIGNATURES)
1264+
namespace_equations(sys::AbstractSystem)
12651265
12661266
Return `equations(sys)`, namespaced by the name of `sys`.
12671267
"""
1268-
function namespace_equations(sys::AbstractSystem, ivs = independent_variables(sys))
1269-
eqs = equations(sys)
1268+
function namespace_equations(sys::AbstractSystem, visitor = NoVisitor())
1269+
eqs = equations(sys, visitor)
12701270
isempty(eqs) && return eqs
12711271
if eqs === get_eqs(sys)
12721272
eqs = copy(eqs)
12731273
end
12741274
for i in eachindex(eqs)
1275-
eqs[i] = namespace_equation(eqs[i], sys; ivs)
1275+
eqs[i] = namespace_equation(eqs[i], sys)
12761276
end
12771277
return eqs
12781278
end
@@ -1661,25 +1661,125 @@ end
16611661
flatten(sys::AbstractSystem, args...) = sys
16621662

16631663
"""
1664-
$(TYPEDSIGNATURES)
1664+
$TYPEDEF
1665+
1666+
Abstract supertype for functors that can be passed to recursive functions such as
1667+
[`equations`](@ref) to track additional information.
1668+
"""
1669+
abstract type AbstractRecursivePropertyVisitor end
1670+
1671+
"""
1672+
descend_visitor!(visitor::AbstractRecursivePropertyVisitor, sys::AbstractSystem, f)
1673+
1674+
Descend the `visitor` into system `sys`. Also provide the getter function `f` for the
1675+
property the recursive function handles (e.g. `get_eqs` for `equations`).
1676+
"""
1677+
function descend_visitor! end
1678+
1679+
"""
1680+
ascend_visitor!(visitor::AbstractRecursivePropertyVisitor, sys::AbstractSystem, f)
1681+
1682+
Ascend the `visitor` from system `sys` into the parent, marking that all its subsystems
1683+
have been explored. Also provide the getter function `f` for the property the recursive
1684+
function handles (e.g. `get_eqs` for `equations`).
1685+
"""
1686+
function ascend_visitor! end
1687+
1688+
"""
1689+
$TYPEDEF
1690+
1691+
Dummy visitor
1692+
"""
1693+
struct NoVisitor <: AbstractRecursivePropertyVisitor end
1694+
descend_visitor!(::NoVisitor, ::AbstractSystem, _) = nothing
1695+
ascend_visitor!(::NoVisitor, ::AbstractSystem, _) = nothing
1696+
1697+
"""
1698+
$TYPEDEF
1699+
1700+
Visitor that tracks source information
1701+
"""
1702+
struct SourceInformationVisitor <: AbstractRecursivePropertyVisitor
1703+
"""
1704+
List of names indicating the subsystem containing each value as a path from the root.
1705+
Names are in reverse order (root occurs last).
1706+
"""
1707+
sources::Vector{Vector{Symbol}}
1708+
"""
1709+
A stack of indices indicating the index where source entries belonging to each system
1710+
in the call stack start.
1711+
"""
1712+
start_positions_stack::Vector{Int}
1713+
end
1714+
1715+
SourceInformationVisitor() = SourceInformationVisitor(Vector{Symbol}[], Int[])
1716+
1717+
function descend_visitor!(vis::SourceInformationVisitor, sys::AbstractSystem, f)
1718+
(; sources, start_positions_stack) = vis
1719+
# The sources for equations in this system start from the next valid index
1720+
start = length(sources) + 1
1721+
push!(start_positions_stack, start)
1722+
# Add source information for the current system
1723+
for _ in f(sys)
1724+
push!(sources, Symbol[])
1725+
end
1726+
end
1727+
1728+
function ascend_visitor!(vis::SourceInformationVisitor, sys::AbstractSystem, f)
1729+
(; sources, start_positions_stack) = vis
1730+
# Get the start position for `sys`. We know we've explored all subsystems of `sys`.
1731+
cur_start = pop!(start_positions_stack)
1732+
# Since the search is DFS, all entries in `sources` from `cur_start` till the
1733+
# end are inside `sys`, so add the name to them.
1734+
name = nameof(sys)
1735+
for i in cur_start:lastindex(sources)
1736+
push!(sources[i], name)
1737+
end
1738+
end
1739+
1740+
"""
1741+
equations(sys::AbstractSystem)
16651742
16661743
Get the flattened equations of the system `sys` and its subsystems.
16671744
It may include some abbreviations and aliases of observables.
16681745
It is often the most useful way to inspect the equations of a system.
16691746
16701747
See also [`full_equations`](@ref) and [`ModelingToolkitBase.get_eqs`](@ref).
16711748
"""
1672-
function equations(sys::AbstractSystem)
1749+
function equations(sys::AbstractSystem, visitor::AbstractRecursivePropertyVisitor = NoVisitor())
16731750
eqs = get_eqs(sys)
16741751
systems = get_systems(sys)
1675-
isempty(systems) && return eqs
1752+
descend_visitor!(visitor, sys, get_eqs)
1753+
if isempty(systems)
1754+
ascend_visitor!(visitor, sys, get_eqs)
1755+
return eqs
1756+
end
16761757
eqs = copy(eqs)
16771758
for subsys in systems
1678-
append!(eqs, namespace_equations(subsys))
1759+
append!(eqs, namespace_equations(subsys, visitor))
16791760
end
1761+
ascend_visitor!(visitor, sys, get_eqs)
16801762
return eqs
16811763
end
16821764

1765+
function equations_source(sys::AbstractSystem)
1766+
source = Vector{Symbol}[]
1767+
for _ in eachindex(get_eqs(sys))
1768+
push!(source, Symbol[])
1769+
end
1770+
systems = get_systems(sys)
1771+
isempty(systems) && return source
1772+
1773+
for subsys in systems
1774+
name = nameof(subsys)
1775+
sub_sources = equations_source(subsys)
1776+
for src in sub_sources
1777+
push!(src, name)
1778+
end
1779+
append!(source, name)
1780+
end
1781+
end
1782+
16831783
"""
16841784
equations_toplevel(sys::AbstractSystem)
16851785

0 commit comments

Comments
 (0)