Skip to content

Commit 0c22752

Browse files
committed
EMap: add DiscrTree index
This should reduce the number of defeq comparisons.
1 parent 8d8e553 commit 0c22752

File tree

2 files changed

+63
-41
lines changed

2 files changed

+63
-41
lines changed

Aesop/EMap.lean

Lines changed: 60 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -11,58 +11,79 @@ namespace Aesop
1111
`reducible` transparency and treat metavariables as rigid (i.e.,
1212
unassignable). -/
1313
structure EMap (α) where
14-
rep : AssocList Expr α
14+
/-- The mappings stored in the map. Defeq expressions are identified, so
15+
for each equivalence class of defeq expressions we only store one
16+
representative. Missing values indicate expressions that were removed from the
17+
map. -/
18+
rep : PArray (Option (Expr × α))
19+
/-- An index for `rep`. For each expression `e` at index `i` in `rep`,
20+
`idx.getMatch` returns a list of indexes containing `i`. This is used as a
21+
pre-filter during lookups/insertions/modifications to reduce the number of
22+
defeq comparisons. -/
23+
idx : DiscrTree Nat
1524
deriving Inhabited
1625

1726
namespace EMap
1827

19-
protected def empty : EMap α :=
20-
⟨.nil⟩
28+
protected def empty : EMap α where
29+
rep := .empty
30+
idx := .empty
2131

2232
instance : EmptyCollection (EMap α) :=
2333
⟨.empty⟩
2434

2535
variable [Monad m] [MonadLiftT MetaM m]
2636

2737
instance : ForM m (EMap α) (Expr × α) where
28-
forM map f := map.rep.forM fun e a => f (e, a)
38+
forM map f := map.rep.forM fun
39+
| none => return
40+
| some x => f x
2941

3042
instance : ForIn m (EMap α) (Expr × α) where
31-
forIn map := map.rep.forIn
43+
forIn map init f := map.rep.forIn init fun
44+
| none, s => return .yield s
45+
| some x, s => f x s
3246

3347
def foldlM (init : σ) (f : σ → Expr → α → m σ) (map : EMap α) : m σ :=
34-
map.rep.foldlM f init
48+
map.rep.foldlM (init := init) fun
49+
| s, none => return s
50+
| s, some (e, a) => f s e a
3551

3652
def foldl (init : σ) (f : σ → Expr → α → σ) (map : EMap α) : σ :=
37-
map.rep.foldl f init
53+
inline <| map.foldlM (m := Id) init f
3854

55+
private def getCandidates (e : Expr) (map : EMap α) : m (Array Nat) :=
56+
map.idx.getMatch e
57+
58+
@[specialize]
3959
def alterM (e : Expr) (f : α → m (Option α × β)) (map : EMap α) :
4060
m (EMap α × Option β) := do
4161
let lctx ← show MetaM _ from getLCtx
42-
let (map, b?) ← go lctx map.rep
43-
return (⟨map⟩, b?)
44-
where
45-
go (lctx : LocalContext) : AssocList Expr α → m (AssocList Expr α × Option β)
46-
| .nil => return (.nil, none)
47-
| .cons e' old map => do
48-
if e'.hasAnyFVar (! lctx.contains ·) then
49-
return ← go lctx map
50-
if ← isDefEqReducibleRigid e' e then
51-
let (new?, b) ← f old
52-
match new? with
53-
| none => return (map, b)
54-
| some new => return (.cons e' new map, b)
55-
else
56-
let (map, b) ← go lctx map
57-
return (.cons e' old map, b)
62+
let mut rep := map.rep
63+
for i in ← map.getCandidates e do
64+
let some (e', old) := map.rep[i]!
65+
| continue
66+
if e'.hasAnyFVar (! lctx.contains ·) then
67+
rep := rep.set i none
68+
continue
69+
if ← isDefEqReducibleRigid e' e then
70+
let (new?, b) ← f old
71+
let entry := new?.map (e', ·)
72+
return ({ map with rep := rep.set i entry }, b)
73+
return ({ map with rep }, none)
5874

5975
def alter (e : Expr) (f : α → Option α × β) (map : EMap α) :
6076
MetaM (EMap α × Option β) := do
61-
inline map.alterM e fun a => return f a
77+
inline <| map.alterM e fun a => return f a
6278

63-
def insertNew (e : Expr) (a : α) (map : EMap α) : EMap α :=
64-
⟨.cons e a map.rep⟩
79+
@[specialize]
80+
def insertNew (e : Expr) (a : α) (map : EMap α) : m (EMap α) := do
81+
let i := map.rep.size
82+
let rep := map.rep.push (e, a)
83+
let idx ← map.idx.insert e i
84+
return { idx, rep }
6585

86+
@[specialize]
6687
def insertWithM (e : Expr) (f : Option α → m α) (map : EMap α) :
6788
m (EMap α × Option α × α) := do
6889
let (map, vals?) ← map.alterM e fun old => do
@@ -71,38 +92,38 @@ def insertWithM (e : Expr) (f : Option α → m α) (map : EMap α) :
7192
match vals? with
7293
| none =>
7394
let new ← f none
74-
return (⟨.cons e new map.rep⟩, none, new)
95+
return (← map.insertNew e new, none, new)
7596
| some (old, new) =>
7697
return (map, old, new)
7798

7899
def insertWith (e : Expr) (f : Option α → α) (map : EMap α) :
79100
MetaM (EMap α × Option α × α) :=
80-
inline map.insertWithM e fun a? => return f a?
101+
inline <| map.insertWithM e fun a? => return f a?
81102

82103
def insert (e : Expr) (a : α) (map : EMap α) : MetaM (EMap α) :=
83-
inline (·.fst) <$> map.insertWithM e (fun _ => return a)
104+
(·.fst) <$> inline (map.insertWithM e (fun _ => return a))
84105

85-
def singleton (e : Expr) (a : α) : EMap α :=
86-
⟨.cons e a .nil⟩
106+
def singleton (e : Expr) (a : α) : m (EMap α) :=
107+
EMap.empty.insertNew e a
87108

88109
def findWithKey? (e : Expr) (map : EMap α) : MetaM (Option (Expr × α)) := do
89110
let lctx ← getLCtx
90-
for (e', a) in map.rep do
111+
for i in ← map.getCandidates e do
112+
let some (e', a) := map.rep[i]!
113+
| continue
91114
if e'.hasAnyFVar (! lctx.contains ·) then
92115
continue
93116
if ← isDefEqReducibleRigid e e' then
94117
return some (e', a)
95118
return none
96119

97120
def find? (e : Expr) (map : EMap α) : MetaM (Option α) := do
98-
return (← inline map.findWithKey? e).map (·.2)
99-
100-
private def mapMAssocList (f : α → β → m γ) : AssocList α β → m (AssocList α γ)
101-
| .nil => return .nil
102-
| .cons a b xs => return (.cons a (← f a b) (← mapMAssocList f xs))
121+
return (← inline <| map.findWithKey? e).map (·.2)
103122

104-
def mapM (f : Expr → α → m β) (map : EMap α) : m (EMap β) :=
105-
return ⟨← inline mapMAssocList f map.rep⟩
123+
@[specialize]
124+
def mapM (f : Expr → α → m β) (map : EMap α) : m (EMap β) := do
125+
let rep ← map.rep.mapM fun x? => x?.mapM fun (e, a) => return (e, ← f e a)
126+
return { map with rep }
106127

107128
def map (f : Expr → α → β) (map : EMap α) : EMap β :=
108129
map.mapM (m := Id) f

Aesop/Forward/State.lean

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,14 @@ def modify (imap : InstMap) (slot : SlotIndex) (inst : Expr)
135135
let (ms, hs, a) := f ∅ ∅
136136
if ms.isEmpty && hs.isEmpty then
137137
return (imap, a)
138-
let map := map.insertNew inst (ms, hs)
138+
let map map.insertNew inst (ms, hs)
139139
return (⟨imap.map.insert slot map⟩, a)
140140
else
141141
let (ms, hs, a) := f ∅ ∅
142142
if ms.isEmpty && hs.isEmpty then
143143
return (imap, a)
144-
return (⟨imap.map.insert slot <| .singleton inst (ms, hs)⟩, a)
144+
let map ← EMap.singleton inst (ms, hs)
145+
return (⟨imap.map.insert slot map⟩, a)
145146

146147
/-- Inserts a hyp associated with slot `slot` and instantiation `inst`.
147148
The hyp must be a valid assignment for the slot's premise. Returns `true` if

0 commit comments

Comments
 (0)