Skip to content

Commit e6630fa

Browse files
committed
Fix build
1 parent 50d878e commit e6630fa

File tree

1 file changed

+28
-65
lines changed

1 file changed

+28
-65
lines changed

test/Jacobi.hs

Lines changed: 28 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
{-# LANGUAGE TypeFamilies #-}
88
{-# LANGUAGE FlexibleInstances #-}
99
{-# LANGUAGE LambdaCase #-}
10+
{-# LANGUAGE StandaloneDeriving #-}
1011
module Jacobi where
1112

1213
import Test.Tasty
1314
import Test.Tasty.HUnit
1415

15-
import qualified Data.IntMap.Strict as IM
1616
import qualified Data.Set as S
1717
import Data.String
1818
import Data.List (sort)
@@ -24,7 +24,6 @@ import Text.Show.Deriving
2424

2525
import qualified Data.Foldable as F
2626

27-
import Control.Applicative (liftA2)
2827
import Control.Monad (unless)
2928

3029
import Data.Equality.Graph.Lens
@@ -36,6 +35,7 @@ import Data.Equality.Matching.Database
3635
import Data.Equality.Saturation
3736
import Data.Equality.Saturation.Scheduler
3837
import Numeric.GSL.Special(elljac_e)
38+
import Data.Function
3939

4040
data Expr a = Sym !String
4141
| Const !Double
@@ -69,41 +69,7 @@ deriveShow1 ''Expr
6969
instance IsString (Fix Expr) where
7070
fromString = Fix . Sym
7171

72-
-- TODO: There should be a way to do this with TH.
73-
-- This Ord instance doesn't seem to get derived.
74-
instance Ord (Fix Expr) where
75-
Fix (Sym v) `compare` e = case e of
76-
Fix (Sym v') -> v `compare` v'
77-
_ -> LT
78-
Fix (Const x) `compare` e = case e of
79-
Fix (Sym _) -> GT
80-
Fix (Const y) -> x `compare` y
81-
_ -> LT
82-
Fix (Sum es) `compare` e = case e of
83-
Fix (Sym _) -> GT
84-
Fix (Const _) -> GT
85-
Fix (Sum es') -> es `compare` es'
86-
_ -> LT
87-
Fix (Prod es) `compare` e = case e of
88-
Fix (Sym _) -> GT
89-
Fix (Const _) -> GT
90-
Fix (Sum _) -> GT
91-
Fix (Prod es') -> es `compare` es'
92-
_ -> LT
93-
Fix (UnOp o x) `compare` e = case e of
94-
Fix (Sym _) -> GT
95-
Fix (Const _) -> GT
96-
Fix (Sum _) -> GT
97-
Fix (Prod _) -> GT
98-
Fix (UnOp o' y) -> (o, x) `compare` (o', y)
99-
_ -> LT
100-
Fix (BinOp o x y) `compare` e = case e of
101-
Fix (Sym _) -> GT
102-
Fix (Const _) -> GT
103-
Fix (Sum _) -> GT
104-
Fix (Prod _) -> GT
105-
Fix (UnOp _ _) -> GT
106-
Fix (BinOp o' x' y') -> (o, x, y) `compare` (o', x', y')
72+
deriving instance Ord (Fix Expr)
10773

10874
instance Num (Fix Expr) where
10975
(+) a b = case (a, b) of
@@ -177,18 +143,15 @@ instance Analysis (Maybe Double) Expr where
177143
!_ <- unless (a == b || (a == 0 && b == (-0)) || (a == (-0) && b == 0)) (error "Merged non-equal constants!")
178144
return a
179145

180-
modifyA cl = case cl^._data of
181-
Nothing -> (cl, [])
182-
Just d -> ((_nodes %~ S.filter (F.null .unNode)) cl, [Fix (Const d)])
183-
184-
-- -- Add constant as e-node
185-
-- new_c <- represent (Fix $ Const d)
186-
-- _ <- GM.merge i new_c
187-
188-
-- -- Prune all except leaf e-nodes
189-
-- modify (_class i._nodes %~ S.filter (F.null . unNode))
190-
191-
146+
modifyA cl eg0 =
147+
case eg0^._class cl._data of
148+
Nothing -> eg0
149+
Just d ->
150+
-- Add constant as e-node
151+
let (new_c,eg1) = represent (Fix (Const d)) eg0
152+
(rep, eg2) = merge cl new_c eg1
153+
-- Prune all except leaf e-nodes
154+
in eg2 & _class rep._nodes %~ S.filter (F.null .unNode)
192155

193156
evalConstant :: Expr (Maybe Double) -> Maybe Double
194157
evalConstant = \case
@@ -212,40 +175,40 @@ evalConstant = \case
212175
Sym _ -> Nothing
213176
Const x -> Just x
214177

215-
unsafeGetSubst :: Pattern Expr -> Subst -> ClassId
216-
unsafeGetSubst (NonVariablePattern _) _ = error "unsafeGetSubst: NonVariablePattern; expecting VariablePattern"
217-
unsafeGetSubst (VariablePattern v) subst = case IM.lookup v subst of
178+
unsafeGetSubst :: Pattern Expr -> VarsState -> Subst -> ClassId
179+
unsafeGetSubst (NonVariablePattern _) _ _ = error "unsafeGetSubst: NonVariablePattern; expecting VariablePattern"
180+
unsafeGetSubst (VariablePattern v) vss subst = case lookupSubst (findVarName vss v) subst of
218181
Nothing -> error "Searching for non existent bound var in conditional"
219182
Just class_id -> class_id
220183

221184
is_not_zero :: Pattern Expr -> RewriteCondition (Maybe Double) Expr
222-
is_not_zero v subst egr =
223-
egr^._class (unsafeGetSubst v subst)._data /= Just 0
185+
is_not_zero v vss subst egr =
186+
egr^._class (unsafeGetSubst v vss subst)._data /= Just 0
224187

225188
is_int :: Pattern Expr -> RewriteCondition (Maybe Double) Expr
226-
is_int v subst egr =
227-
case egr^._class (unsafeGetSubst v subst)._data of
189+
is_int v vss subst egr =
190+
case egr^._class (unsafeGetSubst v vss subst)._data of
228191
Just x -> snd (properFraction x :: (Int, Double)) == 0
229192
Nothing -> False
230193

231194
is_positive :: Pattern Expr -> RewriteCondition (Maybe Double) Expr
232-
is_positive v subst egr =
233-
case egr^._class (unsafeGetSubst v subst)._data of
195+
is_positive v vss subst egr =
196+
case egr^._class (unsafeGetSubst v vss subst)._data of
234197
Just x -> x > 0
235198
Nothing -> False
236199

237200
is_sym :: Pattern Expr -> RewriteCondition (Maybe Double) Expr
238-
is_sym v subst egr =
239-
any ((\case (Sym _) -> True; _ -> False) . unNode) (egr^._class (unsafeGetSubst v subst)._nodes)
201+
is_sym v vss subst egr =
202+
any ((\case (Sym _) -> True; _ -> False) . unNode) (egr^._class (unsafeGetSubst v vss subst)._nodes)
240203

241204
is_const :: Pattern Expr -> RewriteCondition (Maybe Double) Expr
242-
is_const v subst egr =
243-
isJust (egr^._class (unsafeGetSubst v subst)._data)
205+
is_const v vss subst egr =
206+
isJust (egr^._class (unsafeGetSubst v vss subst)._data)
244207

245208
is_const_or_distinct_var :: Pattern Expr -> Pattern Expr -> RewriteCondition (Maybe Double) Expr
246-
is_const_or_distinct_var v w subst egr =
247-
let v' = unsafeGetSubst v subst
248-
w' = unsafeGetSubst w subst
209+
is_const_or_distinct_var v w vss subst egr =
210+
let v' = unsafeGetSubst v vss subst
211+
w' = unsafeGetSubst w vss subst
249212
in (eClassId (egr^._class v') /= eClassId (egr^._class w'))
250213
&& (isJust (egr^._class v'._data)
251214
|| any ((\case (Sym _) -> True; _ -> False) . unNode) (egr^._class v'._nodes))

0 commit comments

Comments
 (0)