Skip to content

Commit 9caffb1

Browse files
committed
Use lists for products and sums.
Guiding sum-of-products simplification via cost functions with access only to outermost constructors is difficult. The effects are limited.
1 parent a004873 commit 9caffb1

File tree

1 file changed

+85
-32
lines changed

1 file changed

+85
-32
lines changed

test/Jacobi.hs

Lines changed: 85 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import Test.Tasty.HUnit
1515
import qualified Data.IntMap.Strict as IM
1616
import qualified Data.Set as S
1717
import Data.String
18+
import Data.List (sort)
1819
import Data.Maybe (isJust)
1920

2021
import Data.Eq.Deriving
@@ -33,20 +34,19 @@ import Data.Equality.Analysis
3334
import Data.Equality.Matching
3435
import Data.Equality.Matching.Database
3536
import Data.Equality.Saturation
37+
import Data.Equality.Saturation.Scheduler
3638
import Numeric.GSL.Special(elljac_e)
3739

3840
data Expr a = Sym !String
3941
| Const !Double
42+
| Sum [a]
43+
| Prod [a]
4044
| UnOp !UOp !a
4145
| BinOp !BOp !a !a
4246
deriving ( Eq, Ord, Functor
4347
, Foldable, Traversable
4448
)
45-
data BOp = Add
46-
| Sub
47-
| Mul
48-
| Div
49-
| Pow
49+
data BOp = Pow
5050
| Sn -- ^ Glaisher notation for Jacobi's sine of the amplitude
5151
| Cn -- ^ Glaisher notation for Jacobi's cosine of the amplitude
5252
| Dn -- ^ Glaisher notation for Jacobi's derivative/delta of the amplitude
@@ -69,29 +69,72 @@ deriveShow1 ''Expr
6969
instance IsString (Fix Expr) where
7070
fromString = Fix . Sym
7171

72+
-- TODO: There should be a way to do this with TH.
73+
-- This Ord instance doesn't seem to get derived.
74+
instance Ord (Fix Expr) where
75+
Fix (Sym v) `compare` e = case e of
76+
Fix (Sym v') -> v `compare` v'
77+
_ -> LT
78+
Fix (Const x) `compare` e = case e of
79+
Fix (Sym _) -> GT
80+
Fix (Const y) -> x `compare` y
81+
_ -> LT
82+
Fix (Sum es) `compare` e = case e of
83+
Fix (Sym _) -> GT
84+
Fix (Const _) -> GT
85+
Fix (Sum es') -> es `compare` es'
86+
_ -> LT
87+
Fix (Prod es) `compare` e = case e of
88+
Fix (Sym _) -> GT
89+
Fix (Const _) -> GT
90+
Fix (Sum _) -> GT
91+
Fix (Prod es') -> es `compare` es'
92+
_ -> LT
93+
Fix (UnOp o x) `compare` e = case e of
94+
Fix (Sym _) -> GT
95+
Fix (Const _) -> GT
96+
Fix (Sum _) -> GT
97+
Fix (Prod _) -> GT
98+
Fix (UnOp o' y) -> (o, x) `compare` (o', y)
99+
_ -> LT
100+
Fix (BinOp o x y) `compare` e = case e of
101+
Fix (Sym _) -> GT
102+
Fix (Const _) -> GT
103+
Fix (Sum _) -> GT
104+
Fix (Prod _) -> GT
105+
Fix (UnOp _ _) -> GT
106+
Fix (BinOp o' x' y') -> (o, x, y) `compare` (o', x', y')
107+
72108
instance Num (Fix Expr) where
73-
(+) a b = Fix (BinOp Add a b)
74-
(-) a b = Fix (BinOp Sub a b)
75-
(*) a b = Fix (BinOp Mul a b)
109+
(+) a b = case (a, b) of
110+
(Fix (Sum as), Fix (Sum bs)) -> Fix . Sum . sort $ as ++ bs
111+
(Fix (Sum as), _) -> Fix . Sum . sort $ b : as
112+
(_, Fix (Sum bs)) -> Fix . Sum . sort $ a : bs
113+
_ -> Fix . Sum . sort $ [a,b]
114+
(-) a b = a + negate b
115+
(*) a b = case (a, b) of
116+
(Fix (Prod as), Fix (Prod bs)) -> Fix . Prod . sort $ as ++ bs
117+
(Fix (Prod as), _) -> Fix . Prod . sort $ b : as
118+
(_, Fix (Prod bs)) -> Fix . Prod . sort $ a : bs
119+
_ -> Fix . Prod . sort $ [a,b]
76120
fromInteger = Fix . Const . fromInteger
77-
negate = error "DONT USE"
121+
negate = (*) . fromInteger $ -1
78122
abs = error "abs"
79123
signum = error "signum"
80124

81125
instance Fractional (Fix Expr) where
82-
(/) a b = Fix (BinOp Div a b)
126+
(/) a b = a * Fix (BinOp Pow b . fromInteger $ -1)
83127
fromRational = Fix . Const . fromRational
84128

129+
-- Sum-of-products preference might want a different recursion scheme.
85130
symCost :: CostFunction Expr Int
86131
symCost = \case
87132
BinOp Sn e1 e2 -> e1 + e2 + 50
88133
BinOp Cn e1 e2 -> e1 + e2 + 50
89134
BinOp Dn e1 e2 -> e1 + e2 + 50
90-
BinOp Pow e1 e2 -> e1 + e2 + 6
91-
BinOp Div e1 e2 -> e1 + e2 + 5
92-
BinOp Sub e1 e2 -> e1 + e2 + 4
93-
BinOp Mul e1 e2 -> e1 + e2 + 4
94-
BinOp Add e1 e2 -> e1 + e2 + 2
135+
BinOp Pow e1 e2 -> e1 + e2 + 1
136+
Sum es -> sum es + length es + 5
137+
Prod es -> sum es + length es + 10
95138
BinOp Diff e1 e2 -> e1 + e2 + 500
96139
BinOp Integral e1 e2 -> e1 + e2 + 20000
97140
UnOp Sin e1 -> e1 + 20
@@ -103,17 +146,19 @@ symCost = \case
103146
Sym _ -> 1
104147
Const _ -> 1
105148

149+
-- This Num instance for Pattern Expr may not reflect very well what's
150+
-- needed to successfully describe desired structure.
106151
instance Num (Pattern Expr) where
107-
(+) a b = NonVariablePattern $ BinOp Add a b
108-
(-) a b = NonVariablePattern $ BinOp Sub a b
109-
(*) a b = NonVariablePattern $ BinOp Mul a b
152+
(+) a b = NonVariablePattern $ Sum [a,b]
153+
(-) a b = NonVariablePattern $ Sum [a, NonVariablePattern $ Prod [fromInteger (-1), b]]
154+
(*) a b = NonVariablePattern $ Prod [a,b]
110155
fromInteger = NonVariablePattern . Const . fromInteger
111156
negate = error "DONT USE" -- NonVariablePattern. BinOp Mul (fromInteger $ -1)
112157
abs = error "abs"
113158
signum = error "signum"
114159

115160
instance Fractional (Pattern Expr) where
116-
(/) a b = NonVariablePattern $ BinOp Div a b
161+
(/) a b = NonVariablePattern $ Prod [a, NonVariablePattern $ BinOp Pow b (fromInteger (-1))]
117162
fromRational = NonVariablePattern . Const . fromRational
118163

119164
-- | Define analysis for the @Expr@ language over domain @Maybe Double@ for
@@ -148,10 +193,10 @@ instance Analysis (Maybe Double) Expr where
148193
evalConstant :: Expr (Maybe Double) -> Maybe Double
149194
evalConstant = \case
150195
-- Exception: Negative exponent: BinOp Pow e1 e2 -> liftA2 (^) e1 (round <$> e2 :: Maybe Integer)
151-
BinOp Div e1 e2 -> liftA2 (/) e1 e2
152-
BinOp Sub e1 e2 -> liftA2 (-) e1 e2
153-
BinOp Mul e1 e2 -> liftA2 (*) e1 e2
154-
BinOp Add e1 e2 -> liftA2 (+) e1 e2
196+
Sum [] -> Just 0
197+
Sum es@(_:_) -> foldr1 (liftA2 (+)) es
198+
Prod [] -> Just 1
199+
Prod es@(_:_) -> foldr1 (liftA2 (*)) es
155200
BinOp Pow e1 e2 -> liftA2 (**) e1 e2
156201
BinOp Sn e1 e2 -> fmap (\(x,_,_) -> x) $ liftA2 elljac_e e1 e2
157202
BinOp Cn e1 e2 -> fmap (\(_,x,_) -> x) $ liftA2 elljac_e e1 e2
@@ -233,9 +278,10 @@ rewrites =
233278
, ("a"*"b")+("a"*"c") := "a"*("b"+"c") -- factor
234279

235280
, powP "a" "b"*powP "a" "c" := powP "a" ("b" + "c") -- pow mul
281+
, powP "a" "b" * "a" := powP "a" ("b" + 1)
236282
, powP "a" 0 := 1 :| is_not_zero "a"
237283
, powP "a" 1 := "a"
238-
, powP "a" 2 := "a"*"a"
284+
, "a" * "a" := powP "a" 2
239285
, powP "a" (fromInteger $ -1) := 1/"a" :| is_not_zero "a"
240286

241287
, "x"*(1/"x") := 1 :| is_not_zero "x"
@@ -247,7 +293,8 @@ rewrites =
247293

248294
-- How can the binomial theorem be represented?
249295
-- Is it really only available for one integer at a time?
250-
++ [ powP ("a" + "b") (NonVariablePattern . Const $ fromIntegral n) := sum [(fromInteger $ n `choose` k) * powP "a" (fromInteger k) * powP "b" (fromInteger $ n - k) | k <- [0..n]] | n <- [2..1000]] ++
296+
-- ++ [ powP ("a" + "b") (NonVariablePattern . Const $ fromIntegral n) := sum [(fromInteger $ n `_choose` k) * powP "a" (fromInteger k) * powP "b" (fromInteger $ n - k) | k <- [0..n]] | n <- [2..1000]]
297+
++
251298

252299
-- It's a bit unclear to me how to determine that high powers
253300
-- can be reduced. Ideally something like:
@@ -273,11 +320,11 @@ rewrites =
273320
, dnP (fromInteger (-1) * "x") "k" := dnP "x" "k"
274321

275322
, snP "x" 0 := sinP "x"
276-
-- , snP "x" 1 := tanhP "x"
323+
, snP "x" 1 := tanhP "x"
277324
, cnP "x" 0 := cosP "x"
278-
-- , cnP "x" 1 := 1 / powP (coshP "x") 2
325+
, cnP "x" 1 := powP (sechP "x") 2
279326
, dnP "x" 0 := 1
280-
-- , dnP "x" 1 := 1 / powP (coshP "x") 2
327+
, dnP "x" 1 := powP (sechP "x") 2
281328
, cosP ("x" + "y") := cosP "x" * cosP "y" - sinP "x" * sinP "y"
282329
, sinP ("x" + "y") := sinP "x" * cosP "y" + cosP "x" * sinP "y"
283330
, coshP ("x" + "y") := coshP "x" * coshP "y" + sinhP "x" * sinhP "y"
@@ -331,12 +378,12 @@ rewrites =
331378
, "a"-(fromInteger (-1)*"b") := "a"+"b"
332379

333380
] where
334-
n `choose` k
381+
n `_choose` k
335382
| k < 0 || k > n = 0
336383
| k == 0 || k == n = 1
337384
| k == 1 || k == n - 1 = n
338-
| 2 * k > n = n `choose` (n - k)
339-
| otherwise = (n - 1) `choose` (k - 1) * n `div` k
385+
| 2 * k > n = n `_choose` (n - k)
386+
| otherwise = (n - 1) `_choose` (k - 1) * n `div` k
340387

341388
rewrite :: Fix Expr -> Fix Expr
342389
rewrite e = fst $ equalitySaturation e rewrites symCost
@@ -435,7 +482,7 @@ symTests = testGroup "Jacobi"
435482

436483
-- TODO: More elliptic function identities may be worthwhile.
437484
, testCase "reduce (dn(x,k))^11 in terms of sn(x,k)" $
438-
rewrite (_pow (_dn "x" "k") 11) @?= _pow ((1 - _pow (_sn "x" "k") 2) / _pow "k" 2) 5 * _dn "x" "k" -- this should actually not be equal
485+
fst (equalitySaturation' (defaultBackoffScheduler { banLength = 100 }) (_pow (_dn "x" "k") 11) rewrites depthCost) @?= _pow ((1 - _pow (_sn "x" "k") 2) / _pow "k" 2) 5 * _dn "x" "k" -- this should actually not be equal
439486

440487
, testCase "reduce (dn(x,k))^1001 in terms of sn(x,k)" $
441488
rewrite (_pow (_dn "x" "k") 1001) @?= _pow ((1 - _pow (_sn "x" "k") 2) / _pow "k" 2) 500 * _dn "x" "k"
@@ -485,6 +532,12 @@ coshP a = NonVariablePattern (UnOp Cosh a)
485532
sinhP :: Pattern Expr -> Pattern Expr
486533
sinhP a = NonVariablePattern (UnOp Sinh a)
487534

535+
tanhP :: Pattern Expr -> Pattern Expr
536+
tanhP a = NonVariablePattern (Prod [NonVariablePattern $ UnOp Sinh a, NonVariablePattern $ BinOp Pow (NonVariablePattern $ UnOp Cosh a) (NonVariablePattern . Const $ -1)])
537+
538+
sechP :: Pattern Expr -> Pattern Expr
539+
sechP a = NonVariablePattern $ BinOp Pow (NonVariablePattern $ UnOp Cosh a) (NonVariablePattern . Const $ -1)
540+
488541
lnP :: Pattern Expr -> Pattern Expr
489542
lnP a = NonVariablePattern (UnOp Ln a)
490543

0 commit comments

Comments
 (0)