Skip to content

Commit 92713e9

Browse files
committed
Forward.State: make insertion lazy
Previously, the forward state was eagerly updated whenever a hypothesis was added. Now, we use a two-phase protocol. Initially, hypotheses are added to (and removed from) per-rule queues, which were already in place as part of a different laziness scheme. They are properly processed only once we need the forward matches, i.e. when selecting forward rules. The processing is also staggered by phase, so before norm rule selection only norm rules are processed. This yields a nice speedup when hypotheses both appear and disappear within the same phase, e.g. due to repeated `simp`s or `subst`s during normalisation. Previously, safe and unsafe forward rules would have been updated for each new hypotheses. Now, we only process those hypotheses that remain after normalisation. TODO docs
1 parent 0c22752 commit 92713e9

File tree

8 files changed

+217
-235
lines changed

8 files changed

+217
-235
lines changed

Aesop/Forward/State.lean

Lines changed: 135 additions & 183 deletions
Large diffs are not rendered by default.

Aesop/Forward/State/ApplyGoalDiff.lean

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,42 +7,37 @@ Authors: Jannis Limperg
77
import Aesop.Forward.State
88
import Aesop.RuleSet
99

10-
namespace Aesop
10+
namespace Aesop.ForwardState
1111

1212
open Lean Lean.Meta
1313

1414
/-- Apply a goal diff to the state, adding and removing hypotheses as indicated
1515
by the diff. -/
16-
def ForwardState.applyGoalDiff (rs : LocalRuleSet) (diff : GoalDiff)
17-
(fs : ForwardState) : BaseM (ForwardState × Array ForwardRuleMatch) :=
16+
def applyGoalDiff (rs : LocalRuleSet) (diff : GoalDiff) (fs : ForwardState) :
17+
BaseM ForwardState :=
1818
profilingForwardState do
1919
if ! aesop.dev.statefulForward.get (← getOptions) then
20-
return (fs, #[])
20+
return fs
2121
let fs ← diff.oldGoal.withContext do
2222
diff.removedFVars.foldM (init := fs) λ fs h => eraseHyp h fs
2323
diff.newGoal.withContext do
24-
let (fs, ruleMatches) ←
25-
diff.addedFVars.foldM (init := (fs, ∅)) λ (fs, ruleMatches) h =>
26-
addHyp h fs ruleMatches
24+
let fs ← diff.addedFVars.foldM (init := fs) λ fs h => addHyp h fs
2725
if diff.targetChanged then
28-
updateTarget fs ruleMatches
26+
updateTarget fs
2927
else
30-
return (fs, ruleMatches)
28+
return fs
3129
where
3230
eraseHyp (h : FVarId) (fs : ForwardState) : BaseM ForwardState :=
3331
withConstAesopTraceNode .forward (return m!"erase hyp {Expr.fvar h} ({h.name})") do
3432
return fs.eraseHyp h
3533

36-
addHyp (h : FVarId) (fs : ForwardState)
37-
(ruleMatches : Array ForwardRuleMatch) :
38-
BaseM (ForwardState × Array ForwardRuleMatch) := do
34+
addHyp (h : FVarId) (fs : ForwardState) : BaseM ForwardState := do
3935
let rules ← rs.applicableForwardRules (← h.getType)
4036
let patInsts ← rs.forwardRulePatternSubstsInLocalDecl (← h.getDecl)
41-
fs.addHypWithPatSubstsCore ruleMatches diff.newGoal h rules patInsts
37+
return fs.enqueueHypWithPatSubsts h rules patInsts
4238

43-
updateTarget (fs : ForwardState) (ruleMatches : Array ForwardRuleMatch) :
44-
BaseM (ForwardState × Array ForwardRuleMatch) := do
39+
updateTarget (fs : ForwardState) : BaseM ForwardState := do
4540
let patInsts ← rs.forwardRulePatternSubstsInExpr (← diff.newGoal.getType)
46-
fs.updateTargetPatSubstsCore ruleMatches diff.newGoal patInsts
41+
return fs.enqueueTargetPatSubsts patInsts
4742

48-
end Aesop
43+
end Aesop.ForwardState

Aesop/Forward/State/Initial.lean

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,18 @@ def mkInitialForwardState (goal : MVarId) (rs : LocalRuleSet) :
1818
if ! aesop.dev.statefulForward.get (← getOptions) then
1919
return (∅, #[])
2020
let mut fs : ForwardState := ∅
21-
let mut ruleMatches := rs.constForwardRuleMatches
22-
aesop_trace[forward] do
23-
for m in ruleMatches do
24-
aesop_trace![forward] "match for constant rule {m.rule.name}"
2521
for ldecl inshow MetaM _ from getLCtx do
2622
if ldecl.isImplementationDetail then
2723
continue
2824
let rules ← rs.applicableForwardRules ldecl.type
2925
let patInsts ← rs.forwardRulePatternSubstsInLocalDecl ldecl
30-
let (fs', ruleMatches') ←
31-
fs.addHypWithPatSubstsCore ruleMatches goal ldecl.fvarId rules patInsts
32-
fs := fs'
33-
ruleMatches := ruleMatches'
26+
fs := fs.enqueueHypWithPatSubsts ldecl.fvarId rules patInsts
3427
let patInsts ← rs.forwardRulePatternSubstsInExpr (← goal.getType)
35-
fs.addPatSubstsCore ruleMatches goal patInsts
28+
fs := fs.enqueueTargetPatSubsts patInsts
29+
let ruleMatches := rs.constForwardRuleMatches
30+
aesop_trace[forward] do
31+
for m in ruleMatches do
32+
aesop_trace![forward] "match for constant rule {m.rule.name}"
33+
return (fs, ruleMatches)
3634

3735
end Aesop.LocalRuleSet
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/-
2+
Copyright (c) 2025 Jannis Limperg. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Jannis Limperg
5+
-/
6+
7+
import Aesop.Forward.State
8+
import Aesop.Tree.TreeM
9+
import Aesop.Tree.RunMetaM
10+
11+
open Lean
12+
13+
namespace Aesop.GoalRef
14+
15+
def updateForwardState (phase : PhaseName) (gref : GoalRef) : TreeM Unit := do
16+
let g ← gref.get
17+
if phase == .norm then
18+
throwError "aesop: internal error: {decl_name%}: at goal {g.id}: norm phase not supported"
19+
if ! g.isNormal then
20+
throwError "aesop: internal error: {decl_name%}: attempt to update forward state of non-normal goal {g.id} to phase {phase}"
21+
let (fs, ms) ← g.runMetaMInPostNormState' fun goal =>
22+
g.forwardState.update goal phase
23+
let ms := g.forwardRuleMatches.update ms ∅ ∅
24+
gref.set <| g.setForwardState fs |>.setForwardRuleMatches ms
25+
26+
end Aesop.GoalRef

Aesop/Saturate.lean

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,22 +48,18 @@ initialize
4848
partial def saturateCore (rs : LocalRuleSet) (goal : MVarId) : SaturateM MVarId :=
4949
withExceptionPrefix "saturate: internal error: " do
5050
goal.checkNotAssigned `saturate
51-
-- We use the forward state only to track the hypotheses present in the goal.
52-
let (fs, _) ← rs.mkInitialForwardState goal
53-
go goal fs
51+
go goal
5452
where
55-
go (goal : MVarId) (fs : ForwardState) : SaturateM MVarId :=
53+
go (goal : MVarId) : SaturateM MVarId :=
5654
withIncRecDepth do
5755
checkSystem "saturate"
5856
trace[saturate] "goal {goal.name}:{indentD goal}"
5957
let mvars := UnorderedArraySet.ofHashSet $ ← goal.getMVarDependencies
6058
let preState ← show MetaM _ from saveState
6159
if let some diff ← tryNormRules goal mvars preState then
62-
let (fs, _) ← fs.applyGoalDiff rs diff
63-
go diff.newGoal fs
60+
go diff.newGoal
6461
else if let some diff ← trySafeRules goal mvars preState then
65-
let (fs, _) ← fs.applyGoalDiff rs diff
66-
go diff.newGoal fs
62+
go diff.newGoal
6763
else
6864
clearForwardImplDetailHyps goal
6965

@@ -131,8 +127,13 @@ partial def saturateCore (rs : LocalRuleSet) (goal : MVarId) : SaturateM MVarId
131127
withExceptionPrefix "saturate: internal error: " do
132128
goal.withContext do
133129
goal.checkNotAssigned `saturate
130+
let mut queue := ∅
134131
let (fs, ruleMatches) ← rs.mkInitialForwardState goal
135-
let queue := ruleMatches.foldl (init := ∅) λ queue m => queue.insert m
132+
for m in ruleMatches do
133+
queue := queue.insert m
134+
let (fs, ruleMatches) ← fs.update goal (phase? := none)
135+
for m in ruleMatches do
136+
queue := queue.insert m
136137
go ∅ fs queue ∅ goal
137138
where
138139
go (hypDepths : Std.HashMap FVarId Nat) (fs : ForwardState) (queue : Queue)
@@ -166,10 +167,11 @@ where
166167
else
167168
let rules ← profilingRuleSelection do
168169
rs.applicableForwardRules type
169-
let (fs, ruleMatches) ← profilingForwardState do
170+
let fs ← profilingForwardState do
170171
let patInsts ←
171172
rs.forwardRulePatternSubstsInLocalDecl (← hyp.getDecl)
172-
fs.addHypWithPatSubsts goal hyp rules patInsts
173+
return fs.enqueueHypWithPatSubsts hyp rules patInsts
174+
let (fs, ruleMatches) ← fs.update goal (phase? := none)
173175
let queue :=
174176
ruleMatches.foldl (init := queue) λ queue m => queue.insert m
175177
go hypDepths fs queue erasedHyps goal

Aesop/Search/Expansion.lean

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Authors: Jannis Limperg
66

77
import Aesop.Search.Expansion.Norm
88
import Aesop.Tree.AddRapp
9+
import Aesop.Forward.State.UpdateGoal
910

1011
open Lean
1112
open Lean.Meta
@@ -173,11 +174,12 @@ def SafeRulesResult.toEmoji : SafeRulesResult → String
173174
| skipped => ruleSkippedEmoji
174175

175176
def runFirstSafeRule (gref : GoalRef) : SearchM Q SafeRulesResult := do
176-
let g ← gref.get
177-
if g.unsafeRulesSelected then
177+
if (← gref.get).unsafeRulesSelected then
178178
return .skipped
179179
-- If the unsafe rules have been selected, we have already tried all the
180180
-- safe rules.
181+
gref.updateForwardState .safe
182+
let g ← gref.get
181183
let rules ← selectSafeRules g
182184
let mut postponedRules := {}
183185
for r in rules do
@@ -197,6 +199,7 @@ def applyPostponedSafeRule (r : PostponedSafeRule) (parentRef : GoalRef) :
197199

198200
partial def runFirstUnsafeRule (postponedSafeRules : Array PostponedSafeRule)
199201
(parentRef : GoalRef) : SearchM Q RuleResult := do
202+
parentRef.updateForwardState .unsafe
200203
let queue ← selectUnsafeRules postponedSafeRules parentRef
201204
let (remainingQueue, result) ← loop queue
202205
parentRef.modify λ g => g.setUnsafeQueue remainingQueue

Aesop/Search/Expansion/Norm.lean

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def getForwardState : NormM ForwardState :=
4141
def getResetForwardState : NormM ForwardState := do
4242
modifyGetThe NormM.State λ s => (s.forwardState, { s with forwardState := ∅ })
4343

44-
def updateForwardState (fs : ForwardState) (newMatches : Array ForwardRuleMatch)
44+
def modifyForwardState (fs : ForwardState) (newMatches : Array ForwardRuleMatch)
4545
(erasedHyps : Std.HashSet FVarId) : NormM Unit :=
4646
modifyThe NormM.State λ s => { s with
4747
forwardState := fs
@@ -50,13 +50,18 @@ def updateForwardState (fs : ForwardState) (newMatches : Array ForwardRuleMatch)
5050
(consumedForwardRuleMatches := #[]) -- We erase the consumed matches separately.
5151
}
5252

53+
def updateForwardState (goal : MVarId) : NormM Unit := do
54+
let fs ← getResetForwardState
55+
let (fs, ms) ← fs.update goal (phase? := some .norm)
56+
modifyForwardState fs ms ∅
57+
5358
def eraseForwardRuleMatch (m : ForwardRuleMatch) : NormM Unit := do
5459
modifyThe NormM.State λ s => { s with forwardRuleMatches := s.forwardRuleMatches.erase m }
5560

5661
def applyDiffToForwardState (diff : GoalDiff) : NormM Unit := do
5762
let fs ← getResetForwardState
58-
let (fs, ms) ← fs.applyGoalDiff (← read).ruleSet diff
59-
updateForwardState fs ms diff.removedFVars
63+
let fs ← fs.applyGoalDiff (← read).ruleSet diff
64+
modifyForwardState fs #[] diff.removedFVars
6065

6166
inductive NormRuleResult
6267
| succeeded (goal : MVarId) (steps? : Option (Array Script.LazyStep))
@@ -97,7 +102,7 @@ returns the matches as well. -/
97102
def runNormRuleTac (rule : NormRule) (input : RuleTacInput) (fs : ForwardState)
98103
(rs : LocalRuleSet) :
99104
NormM $
100-
Option (NormRuleResult × ForwardState × Array ForwardRuleMatch × Std.HashSet FVarId) ×
105+
Option (NormRuleResult × ForwardState × Std.HashSet FVarId) ×
101106
Array ForwardRuleMatch := do
102107
let preMetaState ← show MetaM _ from saveState
103108
let result? ← runRuleTac rule.tac.run rule.name preMetaState input
@@ -111,18 +116,18 @@ def runNormRuleTac (rule : NormRule) (input : RuleTacInput) (fs : ForwardState)
111116
| err m!"rule did not produce exactly one rule application."
112117
show MetaM _ from restoreState rapp.postState
113118
if rapp.goals.isEmpty then
114-
return (some (.proved rapp.scriptSteps?, fs, #[], ∅), forwardRuleMatches)
119+
return (some (.proved rapp.scriptSteps?, fs, ∅), forwardRuleMatches)
115120
let (#[{ diff }]) := rapp.goals
116121
| err m!"rule produced more than one subgoal."
117-
let (fs, ms) ← fs.applyGoalDiff rs diff
122+
let fs ← fs.applyGoalDiff rs diff
118123
let g := diff.newGoal
119124
if ← Check.rules.isEnabled then
120125
let mvars := .ofArray input.mvars.toArray
121126
let actualMVars ← rapp.postState.runMetaM' g.getMVarDependencies
122127
if ! actualMVars == mvars then
123128
err "the goal produced by the rule depends on different metavariables than the original goal."
124129
let result := .succeeded g rapp.scriptSteps?
125-
return (some (result, fs, ms, diff.removedFVars), forwardRuleMatches)
130+
return (some (result, fs, diff.removedFVars), forwardRuleMatches)
126131
where
127132
err {α} (msg : MessageData) : MetaM α := throwError
128133
"aesop: error while running norm rule {rule.name}: {msg}\nThe rule was run on this goal:{indentD $ MessageData.ofGoal input.goal}"
@@ -142,9 +147,9 @@ def runNormRule (goal : MVarId) (mvars : UnorderedArraySet MVarId)
142147
runNormRuleTac rule.rule ruleInput fs (← read).ruleSet
143148
for m in consumedForwardRuleMatches do
144149
eraseForwardRuleMatch m
145-
let (some (result, fs, ms, removedFVars)) := result?
150+
let (some (result, fs, removedFVars)) := result?
146151
| return none
147-
updateForwardState fs ms removedFVars
152+
modifyForwardState fs #[] removedFVars
148153
return result
149154

150155
def runFirstNormRule (goal : MVarId) (mvars : UnorderedArraySet MVarId)
@@ -323,6 +328,7 @@ def runNormSteps (goal : MVarId) (steps : Array NormStep)
323328
let mut anySuccess := false
324329
while iteration < maxIterations do
325330
if step.val == 0 then
331+
updateForwardState goal
326332
let rules ←
327333
selectNormRules ctx.ruleSet (← getThe NormM.State).forwardRuleMatches
328334
goal

Aesop/Tree/AddRapp.lean

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ unsafe def copyGoals (assignedMVars : UnorderedArraySet MVarId)
9393
runInMetaState parentMetaState do
9494
let start ← start.get
9595
let diff ← diffGoals start.currentGoal g.preNormGoal
96-
let (forwardState, ms) ← start.forwardState.applyGoalDiff rs diff
96+
let forwardState ← start.forwardState.applyGoalDiff rs diff
9797
let forwardRuleMatches :=
98-
start.forwardRuleMatches.update ms diff.removedFVars
98+
start.forwardRuleMatches.update #[] diff.removedFVars
9999
(consumedForwardRuleMatches := #[]) -- TODO unsure whether this is correct
100100
let mvars ← .ofHashSet <$> g.preNormGoal.getMVarDependencies
101101
pure (forwardState, forwardRuleMatches, mvars)
@@ -129,9 +129,9 @@ def makeInitialGoal (goal : Subgoal) (mvars : UnorderedArraySet MVarId)
129129
(successProbability : Percent) (origin : GoalOrigin) : TreeM Goal := do
130130
let rs := (← read).ruleSet
131131
let (forwardState, forwardRuleMatches) ← runInMetaState parentMetaState do
132-
let (fs, newMatches) ← parentForwardState.applyGoalDiff rs goal.diff
132+
let fs ← parentForwardState.applyGoalDiff rs goal.diff
133133
let ms :=
134-
parentForwardMatches.update newMatches goal.diff.removedFVars
134+
parentForwardMatches.update #[] goal.diff.removedFVars
135135
consumedForwardRuleMatches
136136
pure (fs, ms)
137137
return Goal.mk {

0 commit comments

Comments
 (0)