Skip to content

Commit 8efdb81

Browse files
feat: add flatten_equation utility
Split out core of `flatten_equations`
1 parent fa650c5 commit 8efdb81

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

lib/ModelingToolkitBase/src/utils.jl

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,6 +1201,25 @@ end
12011201

12021202
_eq_unordered(a, b) = isequal(a, b)
12031203

1204+
"""
1205+
$TYPEDSIGNATURES
1206+
1207+
Given an equation that may be an array equation, return a `Vector{Equation}` representing
1208+
its scalarized form.
1209+
"""
1210+
function flatten_equation(eq::Equation)::Vector{Equation}
1211+
if !SU.is_array_shape(SU.shape(eq.lhs))
1212+
return [eq]
1213+
end
1214+
lhs = vec(collect(eq.lhs)::Array{SymbolicT})::Vector{SymbolicT}
1215+
rhs = vec(collect(eq.rhs)::Array{SymbolicT})::Vector{SymbolicT}
1216+
result = Equation[]
1217+
for (l, r) in zip(lhs, rhs)
1218+
push!(result, l ~ r)
1219+
end
1220+
return result
1221+
end
1222+
12041223
"""
12051224
$(TYPEDSIGNATURES)
12061225
@@ -1210,15 +1229,7 @@ without scalarizing occurrences of array variables and return the new list of eq
12101229
function flatten_equations(eqs::Vector{Equation})
12111230
_eqs = Equation[]
12121231
for eq in eqs
1213-
if !SU.is_array_shape(SU.shape(eq.lhs))
1214-
push!(_eqs, eq)
1215-
continue
1216-
end
1217-
lhs = vec(collect(eq.lhs)::Array{SymbolicT})::Vector{SymbolicT}
1218-
rhs = vec(collect(eq.rhs)::Array{SymbolicT})::Vector{SymbolicT}
1219-
for (l, r) in zip(lhs, rhs)
1220-
push!(_eqs, l ~ r)
1221-
end
1232+
append!(_eqs, flatten_equation(eq))
12221233
end
12231234
return _eqs
12241235
end

0 commit comments

Comments
 (0)