@@ -15,6 +15,7 @@ import Test.Tasty.HUnit
1515import qualified Data.IntMap.Strict as IM
1616import qualified Data.Set as S
1717import Data.String
18+ import Data.List (sort )
1819import Data.Maybe (isJust )
1920
2021import Data.Eq.Deriving
@@ -33,20 +34,19 @@ import Data.Equality.Analysis
3334import Data.Equality.Matching
3435import Data.Equality.Matching.Database
3536import Data.Equality.Saturation
37+ import Data.Equality.Saturation.Scheduler
3638import Numeric.GSL.Special (elljac_e )
3739
3840data Expr a = Sym ! String
3941 | Const ! Double
42+ | Sum [a ]
43+ | Prod [a ]
4044 | UnOp ! UOp ! a
4145 | BinOp ! BOp ! a ! a
4246 deriving ( Eq , Ord , Functor
4347 , Foldable , Traversable
4448 )
45- data BOp = Add
46- | Sub
47- | Mul
48- | Div
49- | Pow
49+ data BOp = Pow
5050 | Sn -- ^ Glaisher notation for Jacobi's sine of the amplitude
5151 | Cn -- ^ Glaisher notation for Jacobi's cosine of the amplitude
5252 | Dn -- ^ Glaisher notation for Jacobi's derivative/delta of the amplitude
@@ -69,29 +69,72 @@ deriveShow1 ''Expr
6969instance IsString (Fix Expr ) where
7070 fromString = Fix . Sym
7171
72+ -- TODO: There should be a way to do this with TH.
73+ -- This Ord instance doesn't seem to get derived.
74+ instance Ord (Fix Expr ) where
75+ Fix (Sym v) `compare` e = case e of
76+ Fix (Sym v') -> v `compare` v'
77+ _ -> LT
78+ Fix (Const x) `compare` e = case e of
79+ Fix (Sym _) -> GT
80+ Fix (Const y) -> x `compare` y
81+ _ -> LT
82+ Fix (Sum es) `compare` e = case e of
83+ Fix (Sym _) -> GT
84+ Fix (Const _) -> GT
85+ Fix (Sum es') -> es `compare` es'
86+ _ -> LT
87+ Fix (Prod es) `compare` e = case e of
88+ Fix (Sym _) -> GT
89+ Fix (Const _) -> GT
90+ Fix (Sum _) -> GT
91+ Fix (Prod es') -> es `compare` es'
92+ _ -> LT
93+ Fix (UnOp o x) `compare` e = case e of
94+ Fix (Sym _) -> GT
95+ Fix (Const _) -> GT
96+ Fix (Sum _) -> GT
97+ Fix (Prod _) -> GT
98+ Fix (UnOp o' y) -> (o, x) `compare` (o', y)
99+ _ -> LT
100+ Fix (BinOp o x y) `compare` e = case e of
101+ Fix (Sym _) -> GT
102+ Fix (Const _) -> GT
103+ Fix (Sum _) -> GT
104+ Fix (Prod _) -> GT
105+ Fix (UnOp _ _) -> GT
106+ Fix (BinOp o' x' y') -> (o, x, y) `compare` (o', x', y')
107+
72108instance Num (Fix Expr ) where
73- (+) a b = Fix (BinOp Add a b)
74- (-) a b = Fix (BinOp Sub a b)
75- (*) a b = Fix (BinOp Mul a b)
109+ (+) a b = case (a, b) of
110+ (Fix (Sum as), Fix (Sum bs)) -> Fix . Sum . sort $ as ++ bs
111+ (Fix (Sum as), _) -> Fix . Sum . sort $ b : as
112+ (_, Fix (Sum bs)) -> Fix . Sum . sort $ a : bs
113+ _ -> Fix . Sum . sort $ [a,b]
114+ (-) a b = a + negate b
115+ (*) a b = case (a, b) of
116+ (Fix (Prod as), Fix (Prod bs)) -> Fix . Prod . sort $ as ++ bs
117+ (Fix (Prod as), _) -> Fix . Prod . sort $ b : as
118+ (_, Fix (Prod bs)) -> Fix . Prod . sort $ a : bs
119+ _ -> Fix . Prod . sort $ [a,b]
76120 fromInteger = Fix . Const . fromInteger
77- negate = error " DONT USE "
121+ negate = (*) . fromInteger $ - 1
78122 abs = error " abs"
79123 signum = error " signum"
80124
81125instance Fractional (Fix Expr ) where
82- (/) a b = Fix (BinOp Div a b )
126+ (/) a b = a * Fix (BinOp Pow b . fromInteger $ - 1 )
83127 fromRational = Fix . Const . fromRational
84128
129+ -- Sum-of-products preference might want a different recursion scheme.
85130symCost :: CostFunction Expr Int
86131symCost = \ case
87132 BinOp Sn e1 e2 -> e1 + e2 + 50
88133 BinOp Cn e1 e2 -> e1 + e2 + 50
89134 BinOp Dn e1 e2 -> e1 + e2 + 50
90- BinOp Pow e1 e2 -> e1 + e2 + 6
91- BinOp Div e1 e2 -> e1 + e2 + 5
92- BinOp Sub e1 e2 -> e1 + e2 + 4
93- BinOp Mul e1 e2 -> e1 + e2 + 4
94- BinOp Add e1 e2 -> e1 + e2 + 2
135+ BinOp Pow e1 e2 -> e1 + e2 + 1
136+ Sum es -> sum es + length es + 5
137+ Prod es -> sum es + length es + 10
95138 BinOp Diff e1 e2 -> e1 + e2 + 500
96139 BinOp Integral e1 e2 -> e1 + e2 + 20000
97140 UnOp Sin e1 -> e1 + 20
@@ -103,17 +146,19 @@ symCost = \case
103146 Sym _ -> 1
104147 Const _ -> 1
105148
149+ -- This Num instance for Pattern Expr may not reflect very well what's
150+ -- needed to successfully describe desired structure.
106151instance Num (Pattern Expr ) where
107- (+) a b = NonVariablePattern $ BinOp Add a b
108- (-) a b = NonVariablePattern $ BinOp Sub a b
109- (*) a b = NonVariablePattern $ BinOp Mul a b
152+ (+) a b = NonVariablePattern $ Sum [a,b]
153+ (-) a b = NonVariablePattern $ Sum [a, NonVariablePattern $ Prod [ fromInteger ( - 1 ), b]]
154+ (*) a b = NonVariablePattern $ Prod [a,b]
110155 fromInteger = NonVariablePattern . Const . fromInteger
111156 negate = error " DONT USE" -- NonVariablePattern. BinOp Mul (fromInteger $ -1)
112157 abs = error " abs"
113158 signum = error " signum"
114159
115160instance Fractional (Pattern Expr ) where
116- (/) a b = NonVariablePattern $ BinOp Div a b
161+ (/) a b = NonVariablePattern $ Prod [a, NonVariablePattern $ BinOp Pow b ( fromInteger ( - 1 ))]
117162 fromRational = NonVariablePattern . Const . fromRational
118163
119164-- | Define analysis for the @Expr@ language over domain @Maybe Double@ for
@@ -148,10 +193,10 @@ instance Analysis (Maybe Double) Expr where
148193evalConstant :: Expr (Maybe Double ) -> Maybe Double
149194evalConstant = \ case
150195 -- Exception: Negative exponent: BinOp Pow e1 e2 -> liftA2 (^) e1 (round <$> e2 :: Maybe Integer)
151- BinOp Div e1 e2 -> liftA2 (/) e1 e2
152- BinOp Sub e1 e2 -> liftA2 (-) e1 e2
153- BinOp Mul e1 e2 -> liftA2 (*) e1 e2
154- BinOp Add e1 e2 -> liftA2 (+) e1 e2
196+ Sum [] -> Just 0
197+ Sum es @ (_ : _) -> foldr1 ( liftA2 (+) ) es
198+ Prod [] -> Just 1
199+ Prod es @ (_ : _) -> foldr1 ( liftA2 (*) ) es
155200 BinOp Pow e1 e2 -> liftA2 (**) e1 e2
156201 BinOp Sn e1 e2 -> fmap (\ (x,_,_) -> x) $ liftA2 elljac_e e1 e2
157202 BinOp Cn e1 e2 -> fmap (\ (_,x,_) -> x) $ liftA2 elljac_e e1 e2
@@ -233,9 +278,10 @@ rewrites =
233278 , (" a" * " b" )+ (" a" * " c" ) := " a" * (" b" + " c" ) -- factor
234279
235280 , powP " a" " b" * powP " a" " c" := powP " a" (" b" + " c" ) -- pow mul
281+ , powP " a" " b" * " a" := powP " a" (" b" + 1 )
236282 , powP " a" 0 := 1 :| is_not_zero " a"
237283 , powP " a" 1 := " a"
238- , powP " a" 2 := " a" * " a "
284+ , " a" * " a " := powP " a" 2
239285 , powP " a" (fromInteger $ - 1 ) := 1 / " a" :| is_not_zero " a"
240286
241287 , " x" * (1 / " x" ) := 1 :| is_not_zero " x"
@@ -247,7 +293,8 @@ rewrites =
247293
248294 -- How can the binomial theorem be represented?
249295 -- Is it really only available for one integer at a time?
250- ++ [ powP (" a" + " b" ) (NonVariablePattern . Const $ fromIntegral n) := sum [(fromInteger $ n `choose` k) * powP " a" (fromInteger k) * powP " b" (fromInteger $ n - k) | k <- [0 .. n]] | n <- [2 .. 1000 ]] ++
296+ -- ++ [ powP ("a" + "b") (NonVariablePattern . Const $ fromIntegral n) := sum [(fromInteger $ n `_choose` k) * powP "a" (fromInteger k) * powP "b" (fromInteger $ n - k) | k <- [0..n]] | n <- [2..1000]]
297+ ++
251298
252299 -- It's a bit unclear to me how to determine that high powers
253300 -- can be reduced. Ideally something like:
@@ -273,11 +320,11 @@ rewrites =
273320 , dnP (fromInteger (- 1 ) * " x" ) " k" := dnP " x" " k"
274321
275322 , snP " x" 0 := sinP " x"
276- -- , snP "x" 1 := tanhP "x"
323+ , snP " x" 1 := tanhP " x"
277324 , cnP " x" 0 := cosP " x"
278- -- , cnP "x" 1 := 1 / powP (coshP "x") 2
325+ , cnP " x" 1 := powP (sechP " x" ) 2
279326 , dnP " x" 0 := 1
280- -- , dnP "x" 1 := 1 / powP (coshP "x") 2
327+ , dnP " x" 1 := powP (sechP " x" ) 2
281328 , cosP (" x" + " y" ) := cosP " x" * cosP " y" - sinP " x" * sinP " y"
282329 , sinP (" x" + " y" ) := sinP " x" * cosP " y" + cosP " x" * sinP " y"
283330 , coshP (" x" + " y" ) := coshP " x" * coshP " y" + sinhP " x" * sinhP " y"
@@ -331,12 +378,12 @@ rewrites =
331378 , " a" - (fromInteger (- 1 )* " b" ) := " a" + " b"
332379
333380 ] where
334- n `choose ` k
381+ n `_choose ` k
335382 | k < 0 || k > n = 0
336383 | k == 0 || k == n = 1
337384 | k == 1 || k == n - 1 = n
338- | 2 * k > n = n `choose ` (n - k)
339- | otherwise = (n - 1 ) `choose ` (k - 1 ) * n `div` k
385+ | 2 * k > n = n `_choose ` (n - k)
386+ | otherwise = (n - 1 ) `_choose ` (k - 1 ) * n `div` k
340387
341388rewrite :: Fix Expr -> Fix Expr
342389rewrite e = fst $ equalitySaturation e rewrites symCost
@@ -435,7 +482,7 @@ symTests = testGroup "Jacobi"
435482
436483 -- TODO: More elliptic function identities may be worthwhile.
437484 , testCase " reduce (dn(x,k))^11 in terms of sn(x,k)" $
438- rewrite ( _pow (_dn " x" " k" ) 11 ) @?= _pow ((1 - _pow (_sn " x" " k" ) 2 ) / _pow " k" 2 ) 5 * _dn " x" " k" -- this should actually not be equal
485+ fst (equalitySaturation' (defaultBackoffScheduler { banLength = 100 }) ( _pow (_dn " x" " k" ) 11 ) rewrites depthCost ) @?= _pow ((1 - _pow (_sn " x" " k" ) 2 ) / _pow " k" 2 ) 5 * _dn " x" " k" -- this should actually not be equal
439486
440487 , testCase " reduce (dn(x,k))^1001 in terms of sn(x,k)" $
441488 rewrite (_pow (_dn " x" " k" ) 1001 ) @?= _pow ((1 - _pow (_sn " x" " k" ) 2 ) / _pow " k" 2 ) 500 * _dn " x" " k"
@@ -485,6 +532,12 @@ coshP a = NonVariablePattern (UnOp Cosh a)
485532sinhP :: Pattern Expr -> Pattern Expr
486533sinhP a = NonVariablePattern (UnOp Sinh a)
487534
535+ tanhP :: Pattern Expr -> Pattern Expr
536+ tanhP a = NonVariablePattern (Prod [NonVariablePattern $ UnOp Sinh a, NonVariablePattern $ BinOp Pow (NonVariablePattern $ UnOp Cosh a) (NonVariablePattern . Const $ - 1 )])
537+
538+ sechP :: Pattern Expr -> Pattern Expr
539+ sechP a = NonVariablePattern $ BinOp Pow (NonVariablePattern $ UnOp Cosh a) (NonVariablePattern . Const $ - 1 )
540+
488541lnP :: Pattern Expr -> Pattern Expr
489542lnP a = NonVariablePattern (UnOp Ln a)
490543
0 commit comments