77{-# LANGUAGE TypeFamilies #-}
88{-# LANGUAGE FlexibleInstances #-}
99{-# LANGUAGE LambdaCase #-}
10+ {-# LANGUAGE StandaloneDeriving #-}
1011module Jacobi where
1112
1213import Test.Tasty
1314import Test.Tasty.HUnit
1415
15- import qualified Data.IntMap.Strict as IM
1616import qualified Data.Set as S
1717import Data.String
1818import Data.List (sort )
@@ -24,7 +24,6 @@ import Text.Show.Deriving
2424
2525import qualified Data.Foldable as F
2626
27- import Control.Applicative (liftA2 )
2827import Control.Monad (unless )
2928
3029import Data.Equality.Graph.Lens
@@ -36,6 +35,7 @@ import Data.Equality.Matching.Database
3635import Data.Equality.Saturation
3736import Data.Equality.Saturation.Scheduler
3837import Numeric.GSL.Special (elljac_e )
38+ import Data.Function
3939
4040data Expr a = Sym ! String
4141 | Const ! Double
@@ -69,41 +69,7 @@ deriveShow1 ''Expr
6969instance IsString (Fix Expr ) where
7070 fromString = Fix . Sym
7171
72- -- TODO: There should be a way to do this with TH.
73- -- This Ord instance doesn't seem to get derived.
74- instance Ord (Fix Expr ) where
75- Fix (Sym v) `compare` e = case e of
76- Fix (Sym v') -> v `compare` v'
77- _ -> LT
78- Fix (Const x) `compare` e = case e of
79- Fix (Sym _) -> GT
80- Fix (Const y) -> x `compare` y
81- _ -> LT
82- Fix (Sum es) `compare` e = case e of
83- Fix (Sym _) -> GT
84- Fix (Const _) -> GT
85- Fix (Sum es') -> es `compare` es'
86- _ -> LT
87- Fix (Prod es) `compare` e = case e of
88- Fix (Sym _) -> GT
89- Fix (Const _) -> GT
90- Fix (Sum _) -> GT
91- Fix (Prod es') -> es `compare` es'
92- _ -> LT
93- Fix (UnOp o x) `compare` e = case e of
94- Fix (Sym _) -> GT
95- Fix (Const _) -> GT
96- Fix (Sum _) -> GT
97- Fix (Prod _) -> GT
98- Fix (UnOp o' y) -> (o, x) `compare` (o', y)
99- _ -> LT
100- Fix (BinOp o x y) `compare` e = case e of
101- Fix (Sym _) -> GT
102- Fix (Const _) -> GT
103- Fix (Sum _) -> GT
104- Fix (Prod _) -> GT
105- Fix (UnOp _ _) -> GT
106- Fix (BinOp o' x' y') -> (o, x, y) `compare` (o', x', y')
72+ deriving instance Ord (Fix Expr )
10773
10874instance Num (Fix Expr ) where
10975 (+) a b = case (a, b) of
@@ -177,18 +143,15 @@ instance Analysis (Maybe Double) Expr where
177143 ! _ <- unless (a == b || (a == 0 && b == (- 0 )) || (a == (- 0 ) && b == 0 )) (error " Merged non-equal constants!" )
178144 return a
179145
180- modifyA cl = case cl^. _data of
181- Nothing -> (cl, [] )
182- Just d -> ((_nodes %~ S. filter (F. null . unNode)) cl, [Fix (Const d)])
183-
184- -- -- Add constant as e-node
185- -- new_c <- represent (Fix $ Const d)
186- -- _ <- GM.merge i new_c
187-
188- -- -- Prune all except leaf e-nodes
189- -- modify (_class i._nodes %~ S.filter (F.null . unNode))
190-
191-
146+ modifyA cl eg0 =
147+ case eg0^. _class cl. _data of
148+ Nothing -> eg0
149+ Just d ->
150+ -- Add constant as e-node
151+ let (new_c,eg1) = represent (Fix (Const d)) eg0
152+ (rep, eg2) = merge cl new_c eg1
153+ -- Prune all except leaf e-nodes
154+ in eg2 & _class rep. _nodes %~ S. filter (F. null . unNode)
192155
193156evalConstant :: Expr (Maybe Double ) -> Maybe Double
194157evalConstant = \ case
@@ -212,40 +175,40 @@ evalConstant = \case
212175 Sym _ -> Nothing
213176 Const x -> Just x
214177
215- unsafeGetSubst :: Pattern Expr -> Subst -> ClassId
216- unsafeGetSubst (NonVariablePattern _) _ = error " unsafeGetSubst: NonVariablePattern; expecting VariablePattern"
217- unsafeGetSubst (VariablePattern v) subst = case IM. lookup v subst of
178+ unsafeGetSubst :: Pattern Expr -> VarsState -> Subst -> ClassId
179+ unsafeGetSubst (NonVariablePattern _) _ _ = error " unsafeGetSubst: NonVariablePattern; expecting VariablePattern"
180+ unsafeGetSubst (VariablePattern v) vss subst = case lookupSubst (findVarName vss v) subst of
218181 Nothing -> error " Searching for non existent bound var in conditional"
219182 Just class_id -> class_id
220183
221184is_not_zero :: Pattern Expr -> RewriteCondition (Maybe Double ) Expr
222- is_not_zero v subst egr =
223- egr^. _class (unsafeGetSubst v subst). _data /= Just 0
185+ is_not_zero v vss subst egr =
186+ egr^. _class (unsafeGetSubst v vss subst). _data /= Just 0
224187
225188is_int :: Pattern Expr -> RewriteCondition (Maybe Double ) Expr
226- is_int v subst egr =
227- case egr^. _class (unsafeGetSubst v subst). _data of
189+ is_int v vss subst egr =
190+ case egr^. _class (unsafeGetSubst v vss subst). _data of
228191 Just x -> snd (properFraction x :: (Int , Double )) == 0
229192 Nothing -> False
230193
231194is_positive :: Pattern Expr -> RewriteCondition (Maybe Double ) Expr
232- is_positive v subst egr =
233- case egr^. _class (unsafeGetSubst v subst). _data of
195+ is_positive v vss subst egr =
196+ case egr^. _class (unsafeGetSubst v vss subst). _data of
234197 Just x -> x > 0
235198 Nothing -> False
236199
237200is_sym :: Pattern Expr -> RewriteCondition (Maybe Double ) Expr
238- is_sym v subst egr =
239- any ((\ case (Sym _) -> True ; _ -> False ) . unNode) (egr^. _class (unsafeGetSubst v subst). _nodes)
201+ is_sym v vss subst egr =
202+ any ((\ case (Sym _) -> True ; _ -> False ) . unNode) (egr^. _class (unsafeGetSubst v vss subst). _nodes)
240203
241204is_const :: Pattern Expr -> RewriteCondition (Maybe Double ) Expr
242- is_const v subst egr =
243- isJust (egr^. _class (unsafeGetSubst v subst). _data)
205+ is_const v vss subst egr =
206+ isJust (egr^. _class (unsafeGetSubst v vss subst). _data)
244207
245208is_const_or_distinct_var :: Pattern Expr -> Pattern Expr -> RewriteCondition (Maybe Double ) Expr
246- is_const_or_distinct_var v w subst egr =
247- let v' = unsafeGetSubst v subst
248- w' = unsafeGetSubst w subst
209+ is_const_or_distinct_var v w vss subst egr =
210+ let v' = unsafeGetSubst v vss subst
211+ w' = unsafeGetSubst w vss subst
249212 in (eClassId (egr^. _class v') /= eClassId (egr^. _class w'))
250213 && (isJust (egr^. _class v'. _data)
251214 || any ((\ case (Sym _) -> True ; _ -> False ) . unNode) (egr^. _class v'. _nodes))
0 commit comments