summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSylvain Henry <sylvain@haskus.fr>2020-10-23 14:59:47 +0200
committerMarge Bot <ben+marge-bot@smart-cactus.org>2020-10-31 02:54:34 -0400
commita98593f0c7623843a787af5fb628336cb897c527 (patch)
tree9f383bb56ccc7b1eda312d30b776071bd3df6b2b
parent57c3db9612463426e1724816fd3f98142fec0e31 (diff)
downloadhaskell-a98593f0c7623843a787af5fb628336cb897c527.tar.gz
Refactor numeric constant folding rules
Avoid the use of global pattern synonyms. 1) I think it's going to be helpful to implement constant folding for other numeric types, especially Natural which doesn't have a wrapping behavior. We'll have to refactor these rules even more so we'd better make them less cryptic. 2) It should also be slightly faster because global pattern synonyms matched operations for every numeric types instead of the current one: e.g., ":**:" pattern was matching multiplication for both Int# and Word# types. As we will probably want to implement constant folding for other numeric types (Int8#, Int16#, etc.), it is more efficient to only match primops for a given type as we do now.
-rw-r--r--compiler/GHC/Core/Opt/ConstantFold.hs488
-rw-r--r--testsuite/tests/simplCore/should_run/NumConstantFolding.hs109
-rw-r--r--testsuite/tests/simplCore/should_run/NumConstantFolding.stdout2
-rw-r--r--testsuite/tests/simplCore/should_run/all.T1
4 files changed, 443 insertions, 157 deletions
diff --git a/compiler/GHC/Core/Opt/ConstantFold.hs b/compiler/GHC/Core/Opt/ConstantFold.hs
index 92632347e1..b8c2c5d6fa 100644
--- a/compiler/GHC/Core/Opt/ConstantFold.hs
+++ b/compiler/GHC/Core/Opt/ConstantFold.hs
@@ -1,8 +1,6 @@
{-
(c) The GRASP/AQUA Project, Glasgow University, 1992-1998
-\section[ConFold]{Constant Folder}
-
Conceptually, constant folding should be parameterized with the kind
of target machine to get identical behaviour during compilation time
and runtime. We cheat a little bit here...
@@ -13,9 +11,18 @@ ToDo:
-}
{-# LANGUAGE CPP, RankNTypes, PatternSynonyms, ViewPatterns, RecordWildCards,
- DeriveFunctor, LambdaCase, TypeApplications #-}
+ DeriveFunctor, LambdaCase, TypeApplications, MultiWayIf #-}
+
{-# OPTIONS_GHC -optc-DNON_POSIX_SOURCE -Wno-incomplete-uni-patterns #-}
+#if __GLASGOW_HASKELL__ <= 808
+-- GHC 8.10 deprecates this flag, but GHC 8.8 needs it
+-- The default iteration limit is a bit too low for the definitions
+-- in this module.
+{-# OPTIONS_GHC -fmax-pmcheck-iterations=20000000 #-}
+#endif
+
+-- | Constant Folder
module GHC.Core.Opt.ConstantFold
( primOpRules
, builtinRules
@@ -100,12 +107,12 @@ primOpRules nm = \case
-- Int operations
IntAddOp -> mkPrimOpRule nm 2 [ binaryLit (intOp2 (+))
, identityPlatform zeroi
- , numFoldingRules IntAddOp intPrimOps
+ , addFoldingRules IntAddOp intOps
]
IntSubOp -> mkPrimOpRule nm 2 [ binaryLit (intOp2 (-))
, rightIdentityPlatform zeroi
, equalArgs >> retLit zeroi
- , numFoldingRules IntSubOp intPrimOps
+ , subFoldingRules IntSubOp intOps
]
IntAddCOp -> mkPrimOpRule nm 2 [ binaryLit (intOpC2 (+))
, identityCPlatform zeroi ]
@@ -115,7 +122,7 @@ primOpRules nm = \case
IntMulOp -> mkPrimOpRule nm 2 [ binaryLit (intOp2 (*))
, zeroElem zeroi
, identityPlatform onei
- , numFoldingRules IntMulOp intPrimOps
+ , mulFoldingRules IntMulOp intOps
]
IntQuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (intOp2 quot)
, leftZero zeroi
@@ -152,12 +159,12 @@ primOpRules nm = \case
-- Word operations
WordAddOp -> mkPrimOpRule nm 2 [ binaryLit (wordOp2 (+))
, identityPlatform zerow
- , numFoldingRules WordAddOp wordPrimOps
+ , addFoldingRules WordAddOp wordOps
]
WordSubOp -> mkPrimOpRule nm 2 [ binaryLit (wordOp2 (-))
, rightIdentityPlatform zerow
, equalArgs >> retLit zerow
- , numFoldingRules WordSubOp wordPrimOps
+ , subFoldingRules WordSubOp wordOps
]
WordAddCOp -> mkPrimOpRule nm 2 [ binaryLit (wordOpC2 (+))
, identityCPlatform zerow ]
@@ -166,7 +173,7 @@ primOpRules nm = \case
, equalArgs >> retLitNoC zerow ]
WordMulOp -> mkPrimOpRule nm 2 [ binaryLit (wordOp2 (*))
, identityPlatform onew
- , numFoldingRules WordMulOp wordPrimOps
+ , mulFoldingRules WordMulOp wordOps
]
WordQuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (wordOp2 quot)
, rightIdentityPlatform onew ]
@@ -1878,181 +1885,348 @@ match_smallIntegerTo _ _ _ _ _ = Nothing
--
--------------------------------------------------------
--- | Rules to perform constant folding into nested expressions
+-- Rules to perform constant folding into nested expressions
--
--See Note [Constant folding through nested expressions]
-numFoldingRules :: PrimOp -> (Platform -> PrimOps) -> RuleM CoreExpr
-numFoldingRules op dict = do
- env <- getEnv
- if not (roNumConstantFolding env)
- then mzero
- else do
- [e1,e2] <- getArgs
- platform <- getPlatform
- let PrimOps{..} = dict platform
- case BinOpApp e1 op e2 of
- -- R1) +/- simplification
- x :++: (y :++: v) -> return $ mkL (x+y) `add` v
- x :++: (L y :-: v) -> return $ mkL (x+y) `sub` v
- x :++: (v :-: L y) -> return $ mkL (x-y) `add` v
- L x :-: (y :++: v) -> return $ mkL (x-y) `sub` v
- L x :-: (L y :-: v) -> return $ mkL (x-y) `add` v
- L x :-: (v :-: L y) -> return $ mkL (x+y) `sub` v
-
- (y :++: v) :-: L x -> return $ mkL (y-x) `add` v
- (L y :-: v) :-: L x -> return $ mkL (y-x) `sub` v
- (v :-: L y) :-: L x -> return $ mkL (0-y-x) `add` v
-
- (x :++: w) :+: (y :++: v) -> return $ mkL (x+y) `add` (w `add` v)
- (w :-: L x) :+: (L y :-: v) -> return $ mkL (y-x) `add` (w `sub` v)
- (w :-: L x) :+: (v :-: L y) -> return $ mkL (0-x-y) `add` (w `add` v)
- (L x :-: w) :+: (L y :-: v) -> return $ mkL (x+y) `sub` (w `add` v)
- (L x :-: w) :+: (v :-: L y) -> return $ mkL (x-y) `add` (v `sub` w)
- (w :-: L x) :+: (y :++: v) -> return $ mkL (y-x) `add` (w `add` v)
- (L x :-: w) :+: (y :++: v) -> return $ mkL (x+y) `add` (v `sub` w)
- (y :++: v) :+: (w :-: L x) -> return $ mkL (y-x) `add` (w `add` v)
- (y :++: v) :+: (L x :-: w) -> return $ mkL (x+y) `add` (v `sub` w)
-
- (v :-: L y) :-: (w :-: L x) -> return $ mkL (x-y) `add` (v `sub` w)
- (v :-: L y) :-: (L x :-: w) -> return $ mkL (0-x-y) `add` (v `add` w)
- (L y :-: v) :-: (w :-: L x) -> return $ mkL (x+y) `sub` (v `add` w)
- (L y :-: v) :-: (L x :-: w) -> return $ mkL (y-x) `add` (w `sub` v)
- (x :++: w) :-: (y :++: v) -> return $ mkL (x-y) `add` (w `sub` v)
- (w :-: L x) :-: (y :++: v) -> return $ mkL (0-y-x) `add` (w `sub` v)
- (L x :-: w) :-: (y :++: v) -> return $ mkL (x-y) `sub` (v `add` w)
- (y :++: v) :-: (w :-: L x) -> return $ mkL (y+x) `add` (v `sub` w)
- (y :++: v) :-: (L x :-: w) -> return $ mkL (y-x) `add` (v `add` w)
-
- -- R2) * simplification
- x :**: (y :**: v) -> return $ mkL (x*y) `mul` v
- (x :**: w) :*: (y :**: v) -> return $ mkL (x*y) `mul` (w `mul` v)
-
- -- R3) * distribution over +/-
- x :**: (y :++: v) -> return $ mkL (x*y) `add` (mkL x `mul` v)
- x :**: (L y :-: v) -> return $ mkL (x*y) `sub` (mkL x `mul` v)
- x :**: (v :-: L y) -> return $ (mkL x `mul` v) `sub` mkL (x*y)
-
- -- R4) Simple factorization
- v :+: w
- | w `cheapEqExpr` v -> return $ mkL 2 `mul` v
- w :+: (y :**: v)
- | w `cheapEqExpr` v -> return $ mkL (1+y) `mul` v
- w :-: (y :**: v)
- | w `cheapEqExpr` v -> return $ mkL (1-y) `mul` v
- (y :**: v) :+: w
- | w `cheapEqExpr` v -> return $ mkL (y+1) `mul` v
- (y :**: v) :-: w
- | w `cheapEqExpr` v -> return $ mkL (y-1) `mul` v
- (x :**: w) :+: (y :**: v)
- | w `cheapEqExpr` v -> return $ mkL (x+y) `mul` v
- (x :**: w) :-: (y :**: v)
- | w `cheapEqExpr` v -> return $ mkL (x-y) `mul` v
-
- -- R5) +/- propagation
- w :+: (y :++: v) -> return $ mkL y `add` (w `add` v)
- (y :++: v) :+: w -> return $ mkL y `add` (w `add` v)
- w :-: (y :++: v) -> return $ (w `sub` v) `sub` mkL y
- (y :++: v) :-: w -> return $ mkL y `add` (v `sub` w)
- w :-: (L y :-: v) -> return $ (w `add` v) `sub` mkL y
- (L y :-: v) :-: w -> return $ mkL y `sub` (w `add` v)
- w :+: (L y :-: v) -> return $ mkL y `add` (w `sub` v)
- w :+: (v :-: L y) -> return $ (w `add` v) `sub` mkL y
- (L y :-: v) :+: w -> return $ mkL y `add` (w `sub` v)
- (v :-: L y) :+: w -> return $ (w `add` v) `sub` mkL y
-
- _ -> mzero
+addFoldingRules :: PrimOp -> NumOps -> RuleM CoreExpr
+addFoldingRules op num_ops = do
+ ASSERT(op == numAdd num_ops) return ()
+ env <- getEnv
+ guard (roNumConstantFolding env)
+ [arg1,arg2] <- getArgs
+ platform <- getPlatform
+ liftMaybe
+ -- commutativity for + is handled here
+ (addFoldingRules' platform arg1 arg2 num_ops
+ <|> addFoldingRules' platform arg2 arg1 num_ops)
+
+subFoldingRules :: PrimOp -> NumOps -> RuleM CoreExpr
+subFoldingRules op num_ops = do
+ ASSERT(op == numSub num_ops) return ()
+ env <- getEnv
+ guard (roNumConstantFolding env)
+ [arg1,arg2] <- getArgs
+ platform <- getPlatform
+ liftMaybe (subFoldingRules' platform arg1 arg2 num_ops)
+
+mulFoldingRules :: PrimOp -> NumOps -> RuleM CoreExpr
+mulFoldingRules op num_ops = do
+ ASSERT(op == numMul num_ops) return ()
+ env <- getEnv
+ guard (roNumConstantFolding env)
+ [arg1,arg2] <- getArgs
+ platform <- getPlatform
+ liftMaybe
+ -- commutativity for * is handled here
+ (mulFoldingRules' platform arg1 arg2 num_ops
+ <|> mulFoldingRules' platform arg2 arg1 num_ops)
+
+
+addFoldingRules' :: Platform -> CoreExpr -> CoreExpr -> NumOps -> Maybe CoreExpr
+addFoldingRules' platform arg1 arg2 num_ops = case (arg1, arg2) of
+ -- R1) +/- simplification
+
+ -- l1 + (l2 + x) ==> (l1+l2) + x
+ (L l1, is_lit_add num_ops -> Just (l2,x))
+ -> Just (mkL (l1+l2) `add` x)
+
+ -- l1 + (l2 - x) ==> (l1+l2) - x
+ (L l1, is_sub num_ops -> Just (L l2,x))
+ -> Just (mkL (l1+l2) `sub` x)
+
+ -- l1 + (x - l2) ==> (l1-l2) + x
+ (L l1, is_sub num_ops -> Just (x,L l2))
+ -> Just (mkL (l1-l2) `add` x)
+
+ -- (l1 + x) + (l2 + y) ==> (l1+l2) + (x+y)
+ (is_lit_add num_ops -> Just (l1,x), is_lit_add num_ops -> Just (l2,y))
+ -> Just (mkL (l1+l2) `add` (x `add` y))
+
+ -- (l1 + x) + (l2 - y) ==> (l1+l2) + (x-y)
+ (is_lit_add num_ops -> Just (l1,x), is_sub num_ops -> Just (L l2,y))
+ -> Just (mkL (l1+l2) `add` (x `sub` y))
+
+ -- (l1 + x) + (y - l2) ==> (l1-l2) + (x+y)
+ (is_lit_add num_ops -> Just (l1,x), is_sub num_ops -> Just (y,L l2))
+ -> Just (mkL (l1-l2) `add` (x `add` y))
+
+ -- (l1 - x) + (l2 - y) ==> (l1+l2) - (x+y)
+ (is_sub num_ops -> Just (L l1,x), is_sub num_ops -> Just (L l2,y))
+ -> Just (mkL (l1+l2) `sub` (x `add` y))
+
+ -- (l1 - x) + (y - l2) ==> (l1-l2) + (y-x)
+ (is_sub num_ops -> Just (L l1,x), is_sub num_ops -> Just (y,L l2))
+ -> Just (mkL (l1-l2) `add` (y `sub` x))
+
+ -- (x - l1) + (y - l2) ==> (0-l1-l2) + (x+y)
+ (is_sub num_ops -> Just (x,L l1), is_sub num_ops -> Just (y,L l2))
+ -> Just (mkL (0-l1-l2) `add` (x `add` y))
+
+ -- R4) Simple factorization
+
+ -- x + x ==> 2 * x
+ _ | Just l1 <- is_expr_mul num_ops arg1 arg2
+ -> Just (mkL (l1+1) `mul` arg1)
+
+ -- (l1 * x) + x ==> (l1+1) * x
+ _ | Just l1 <- is_expr_mul num_ops arg2 arg1
+ -> Just (mkL (l1+1) `mul` arg2)
+
+ -- (l1 * x) + (l2 * x) ==> (l1+l2) * x
+ (is_lit_mul num_ops -> Just (l1,x), is_expr_mul num_ops x -> Just l2)
+ -> Just (mkL (l1+l2) `mul` x)
+
+ -- R5) +/- propagation: these transformations push literals outwards
+ -- with the hope that other rules can then be applied.
+
+ -- In the following rules, x can't be a literal otherwise another
+ -- rule would have combined it with the other literal in arg2. So we
+ -- don't have to check this to avoid loops here.
+
+ -- x + (l1 + y) ==> l1 + (x + y)
+ (_, is_lit_add num_ops -> Just (l1,y))
+ -> Just (mkL l1 `add` (arg1 `add` y))
+
+ -- x + (l1 - y) ==> l1 + (x - y)
+ (_, is_sub num_ops -> Just (L l1,y))
+ -> Just (mkL l1 `add` (arg1 `sub` y))
+
+ -- x + (y - l1) ==> (x + y) - l1
+ (_, is_sub num_ops -> Just (y,L l1))
+ -> Just ((arg1 `add` y) `sub` mkL l1)
+
+ _ -> Nothing
+ where
+ mkL = Lit . mkNumLiteral platform num_ops
+ add x y = BinOpApp x (numAdd num_ops) y
+ sub x y = BinOpApp x (numSub num_ops) y
+ mul x y = BinOpApp x (numMul num_ops) y
--- | Match the application of a binary primop
-pattern BinOpApp :: Arg CoreBndr -> PrimOp -> Arg CoreBndr -> CoreExpr
-pattern BinOpApp x op y = OpVal op `App` x `App` y
+subFoldingRules' :: Platform -> CoreExpr -> CoreExpr -> NumOps -> Maybe CoreExpr
+subFoldingRules' platform arg1 arg2 num_ops = case (arg1,arg2) of
+ -- R1) +/- simplification
--- | Match a primop
-pattern OpVal :: PrimOp -> Arg CoreBndr
-pattern OpVal op <- Var (isPrimOpId_maybe -> Just op) where
- OpVal op = Var (mkPrimOpId op)
+ -- l1 - (l2 + x) ==> (l1-l2) - x
+ (L l1, is_lit_add num_ops -> Just (l2,x))
+ -> Just (mkL (l1-l2) `sub` x)
+ -- l1 - (l2 - x) ==> (l1-l2) + x
+ (L l1, is_sub num_ops -> Just (L l2,x))
+ -> Just (mkL (l1-l2) `add` x)
+ -- l1 - (x - l2) ==> (l1+l2) - x
+ (L l1, is_sub num_ops -> Just (x, L l2))
+ -> Just (mkL (l1+l2) `sub` x)
--- | Match a literal
-pattern L :: Integer -> Arg CoreBndr
-pattern L l <- Lit (isLitValue_maybe -> Just l)
+ -- (l1 + x) - l2 ==> (l1-l2) + x
+ (is_lit_add num_ops -> Just (l1,x), L l2)
+ -> Just (mkL (l1-l2) `add` x)
+
+ -- (l1 - x) - l2 ==> (l1-l2) - x
+ (is_sub num_ops -> Just (L l1,x), L l2)
+ -> Just (mkL (l1-l2) `sub` x)
+
+ -- (x - l1) - l2 ==> x - (l1+l2)
+ (is_sub num_ops -> Just (x,L l1), L l2)
+ -> Just (x `sub` mkL (l1+l2))
+
+
+ -- (l1 + x) - (l2 + y) ==> (l1-l2) + (x-y)
+ (is_lit_add num_ops -> Just (l1,x), is_lit_add num_ops -> Just (l2,y))
+ -> Just (mkL (l1-l2) `add` (x `sub` y))
+
+ -- (l1 + x) - (l2 - y) ==> (l1-l2) + (x+y)
+ (is_lit_add num_ops -> Just (l1,x), is_sub num_ops -> Just (L l2,y))
+ -> Just (mkL (l1-l2) `add` (x `add` y))
+
+ -- (l1 + x) - (y - l2) ==> (l1+l2) + (x-y)
+ (is_lit_add num_ops -> Just (l1,x), is_sub num_ops -> Just (y,L l2))
+ -> Just (mkL (l1+l2) `add` (x `sub` y))
+
+ -- (l1 - x) - (l2 + y) ==> (l1-l2) - (x+y)
+ (is_sub num_ops -> Just (L l1,x), is_lit_add num_ops -> Just (l2,y))
+ -> Just (mkL (l1-l2) `sub` (x `add` y))
+
+ -- (x - l1) - (l2 + y) ==> (0-l1-l2) + (x-y)
+ (is_sub num_ops -> Just (x,L l1), is_lit_add num_ops -> Just (l2,y))
+ -> Just (mkL (0-l1-l2) `add` (x `sub` y))
+
+ -- (l1 - x) - (l2 - y) ==> (l1-l2) + (y-x)
+ (is_sub num_ops -> Just (L l1,x), is_sub num_ops -> Just (L l2,y))
+ -> Just (mkL (l1-l2) `add` (y `sub` x))
+
+ -- (l1 - x) - (y - l2) ==> (l1+l2) - (x+y)
+ (is_sub num_ops -> Just (L l1,x), is_sub num_ops -> Just (y,L l2))
+ -> Just (mkL (l1+l2) `sub` (x `add` y))
+
+ -- (x - l1) - (l2 - y) ==> (0-l1-l2) + (x+y)
+ (is_sub num_ops -> Just (x,L l1), is_sub num_ops -> Just (L l2,y))
+ -> Just (mkL (0-l1-l2) `add` (x `add` y))
+
+ -- (x - l1) - (y - l2) ==> (l2-l1) + (x-y)
+ (is_sub num_ops -> Just (x,L l1), is_sub num_ops -> Just (y,L l2))
+ -> Just (mkL (l2-l1) `add` (x `sub` y))
+
+ -- R4) Simple factorization
--- | Match an addition
-pattern (:+:) :: Arg CoreBndr -> Arg CoreBndr -> CoreExpr
-pattern x :+: y <- BinOpApp x (isAddOp -> True) y
+ -- x - (l1 * x) ==> (1-l1) * x
+ _ | Just l1 <- is_expr_mul num_ops arg1 arg2
+ -> Just (mkL (1-l1) `mul` arg1)
--- | Match an addition with a literal (handle commutativity)
-pattern (:++:) :: Integer -> Arg CoreBndr -> CoreExpr
-pattern l :++: x <- (isAdd -> Just (l,x))
+ -- (l1 * x) - x ==> (l1-1) * x
+ _ | Just l1 <- is_expr_mul num_ops arg2 arg1
+ -> Just (mkL (l1-1) `mul` arg2)
-isAdd :: CoreExpr -> Maybe (Integer,CoreExpr)
-isAdd e = case e of
- L l :+: x -> Just (l,x)
- x :+: L l -> Just (l,x)
- _ -> Nothing
+ -- (l1 * x) - (l2 * x) ==> (l1-l2) * x
+ (is_lit_mul num_ops -> Just (l1,x), is_expr_mul num_ops x -> Just l2)
+ -> Just (mkL (l1-l2) `mul` x)
--- | Match a multiplication
-pattern (:*:) :: Arg CoreBndr -> Arg CoreBndr -> CoreExpr
-pattern x :*: y <- BinOpApp x (isMulOp -> True) y
+ -- R5) +/- propagation: these transformations push literals outwards
+ -- with the hope that other rules can then be applied.
--- | Match a multiplication with a literal (handle commutativity)
-pattern (:**:) :: Integer -> Arg CoreBndr -> CoreExpr
-pattern l :**: x <- (isMul -> Just (l,x))
+ -- In the following rules, x can't be a literal otherwise another
+ -- rule would have combined it with the other literal in arg2. So we
+ -- don't have to check this to avoid loops here.
-isMul :: CoreExpr -> Maybe (Integer,CoreExpr)
-isMul e = case e of
- L l :*: x -> Just (l,x)
- x :*: L l -> Just (l,x)
- _ -> Nothing
+ -- x - (l1 + y) ==> (x - y) - l1
+ (_, is_lit_add num_ops -> Just (l1,y))
+ -> Just ((arg1 `sub` y) `sub` mkL l1)
+ -- (l1 + x) - y ==> l1 + (x - y)
+ (is_lit_add num_ops -> Just (l1,x), _)
+ -> Just (mkL l1 `add` (x `sub` arg2))
--- | Match a subtraction
-pattern (:-:) :: Arg CoreBndr -> Arg CoreBndr -> CoreExpr
-pattern x :-: y <- BinOpApp x (isSubOp -> True) y
+ -- x - (l1 - y) ==> (x + y) - l1
+ (_, is_sub num_ops -> Just (L l1,y))
+ -> Just ((arg1 `add` y) `sub` mkL l1)
-isSubOp :: PrimOp -> Bool
-isSubOp IntSubOp = True
-isSubOp WordSubOp = True
-isSubOp _ = False
+ -- x - (y - l1) ==> l1 + (x - y)
+ (_, is_sub num_ops -> Just (y,L l1))
+ -> Just (mkL l1 `add` (arg1 `sub` y))
-isAddOp :: PrimOp -> Bool
-isAddOp IntAddOp = True
-isAddOp WordAddOp = True
-isAddOp _ = False
+ -- (l1 - x) - y ==> l1 - (x + y)
+ (is_sub num_ops -> Just (L l1,x), _)
+ -> Just (mkL l1 `sub` (x `add` arg2))
-isMulOp :: PrimOp -> Bool
-isMulOp IntMulOp = True
-isMulOp WordMulOp = True
-isMulOp _ = False
+ -- (x - l1) - y ==> (x - y) - l1
+ (is_sub num_ops -> Just (x,L l1), _)
+ -> Just ((x `sub` arg2) `sub` mkL l1)
+
+ _ -> Nothing
+ where
+ mkL = Lit . mkNumLiteral platform num_ops
+ add x y = BinOpApp x (numAdd num_ops) y
+ sub x y = BinOpApp x (numSub num_ops) y
+ mul x y = BinOpApp x (numMul num_ops) y
+
+mulFoldingRules' :: Platform -> CoreExpr -> CoreExpr -> NumOps -> Maybe CoreExpr
+mulFoldingRules' platform arg1 arg2 num_ops = case (arg1,arg2) of
+ -- l1 * (l2 * x) ==> (l1*l2) * x
+ (L l1, is_lit_mul num_ops -> Just (l2,x))
+ -> Just (mkL (l1*l2) `mul` x)
+
+ -- l1 * (l2 + x) ==> (l1*l2) + (l1 * x)
+ (L l1, is_lit_add num_ops -> Just (l2,x))
+ -> Just (mkL (l1*l2) `add` (arg1 `mul` x))
+
+ -- l1 * (l2 - x) ==> (l1*l2) - (l1 * x)
+ (L l1, is_sub num_ops -> Just (L l2,x))
+ -> Just (mkL (l1*l2) `sub` (arg1 `mul` x))
+
+ -- l1 * (x - l2) ==> (l1 * x) - (l1*l2)
+ (L l1, is_sub num_ops -> Just (x, L l2))
+ -> Just ((arg1 `mul` x) `sub` mkL (l1*l2))
+
+ -- (l1 * x) * (l2 * y) ==> (l1*l2) * (x * y)
+ (is_lit_mul num_ops -> Just (l1,x), is_lit_mul num_ops -> Just (l2,y))
+ -> Just (mkL (l1*l2) `mul` (x `mul` y))
+
+ _ -> Nothing
+ where
+ mkL = Lit . mkNumLiteral platform num_ops
+ add x y = BinOpApp x (numAdd num_ops) y
+ sub x y = BinOpApp x (numSub num_ops) y
+ mul x y = BinOpApp x (numMul num_ops) y
+
+is_op :: PrimOp -> CoreExpr -> Maybe (Arg CoreBndr, Arg CoreBndr)
+is_op op e = case e of
+ BinOpApp x op' y | op == op' -> Just (x,y)
+ _ -> Nothing
+
+is_add, is_sub, is_mul :: NumOps -> CoreExpr -> Maybe (Arg CoreBndr, Arg CoreBndr)
+is_add num_ops = is_op (numAdd num_ops)
+is_sub num_ops = is_op (numSub num_ops)
+is_mul num_ops = is_op (numMul num_ops)
+
+-- match addition with a literal (handles commutativity)
+is_lit_add :: NumOps -> CoreExpr -> Maybe (Integer, Arg CoreBndr)
+is_lit_add num_ops e = case is_add num_ops e of
+ Just (L l, x ) -> Just (l,x)
+ Just (x , L l) -> Just (l,x)
+ _ -> Nothing
+
+-- match multiplication with a literal (handles commutativity)
+is_lit_mul :: NumOps -> CoreExpr -> Maybe (Integer, Arg CoreBndr)
+is_lit_mul num_ops e = case is_mul num_ops e of
+ Just (L l, x ) -> Just (l,x)
+ Just (x , L l) -> Just (l,x)
+ _ -> Nothing
+
+-- match given "x": return 1
+-- match "lit * x": return lit value (handles commutativity)
+is_expr_mul :: NumOps -> Expr CoreBndr -> Expr CoreBndr -> Maybe Integer
+is_expr_mul num_ops x e = if
+ | x `cheapEqExpr` e
+ -> Just 1
+ | Just (k,x') <- is_lit_mul num_ops e
+ , x `cheapEqExpr` x'
+ -> return k
+ | otherwise
+ -> Nothing
+
+
+-- | Match the application of a binary primop
+pattern BinOpApp :: Arg CoreBndr -> PrimOp -> Arg CoreBndr -> CoreExpr
+pattern BinOpApp x op y = OpVal op `App` x `App` y
+
+-- | Match a primop
+pattern OpVal:: PrimOp -> Arg CoreBndr
+pattern OpVal op <- Var (isPrimOpId_maybe -> Just op) where
+ OpVal op = Var (mkPrimOpId op)
+
+-- | Match a literal
+pattern L :: Integer -> Arg CoreBndr
+pattern L l <- Lit (isLitValue_maybe -> Just l)
-- | Explicit "type-class"-like dictionary for numeric primops
---
--- Depends on Platform because creating a literal value depends on Platform
-data PrimOps = PrimOps
- { add :: CoreExpr -> CoreExpr -> CoreExpr -- ^ Add two numbers
- , sub :: CoreExpr -> CoreExpr -> CoreExpr -- ^ Sub two numbers
- , mul :: CoreExpr -> CoreExpr -> CoreExpr -- ^ Multiply two numbers
- , mkL :: Integer -> CoreExpr -- ^ Create a literal value
+data NumOps = NumOps
+ { numAdd :: !PrimOp -- ^ Add two numbers
+ , numSub :: !PrimOp -- ^ Sub two numbers
+ , numMul :: !PrimOp -- ^ Multiply two numbers
+ , numLitType :: !LitNumType -- ^ Literal type
}
-intPrimOps :: Platform -> PrimOps
-intPrimOps platform = PrimOps
- { add = \x y -> BinOpApp x IntAddOp y
- , sub = \x y -> BinOpApp x IntSubOp y
- , mul = \x y -> BinOpApp x IntMulOp y
- , mkL = intResult' platform
- }
+-- | Create a numeric literal
+mkNumLiteral :: Platform -> NumOps -> Integer -> Literal
+mkNumLiteral platform ops i = mkLitNumberWrap platform (numLitType ops) i
-wordPrimOps :: Platform -> PrimOps
-wordPrimOps platform = PrimOps
- { add = \x y -> BinOpApp x WordAddOp y
- , sub = \x y -> BinOpApp x WordSubOp y
- , mul = \x y -> BinOpApp x WordMulOp y
- , mkL = wordResult' platform
+intOps :: NumOps
+intOps = NumOps
+ { numAdd = IntAddOp
+ , numSub = IntSubOp
+ , numMul = IntMulOp
+ , numLitType = LitNumInt
}
+wordOps :: NumOps
+wordOps = NumOps
+ { numAdd = WordAddOp
+ , numSub = WordSubOp
+ , numMul = WordMulOp
+ , numLitType = LitNumWord
+ }
--------------------------------------------------------
-- Constant folding through case-expressions
diff --git a/testsuite/tests/simplCore/should_run/NumConstantFolding.hs b/testsuite/tests/simplCore/should_run/NumConstantFolding.hs
new file mode 100644
index 0000000000..6466adfe4d
--- /dev/null
+++ b/testsuite/tests/simplCore/should_run/NumConstantFolding.hs
@@ -0,0 +1,109 @@
+{-# LANGUAGE MagicHash #-}
+
+import GHC.Exts
+import Data.Word
+import Data.Int
+
+(+&) = plusWord#
+(-&) = minusWord#
+(*&) = timesWord#
+
+{-# NOINLINE testsW #-}
+-- NOINLINE otherwise basic constant folding rules (without
+-- variables) are applied
+testsW :: Word# -> Word# -> [Word]
+testsW x y = fmap (\z -> fromIntegral (fromIntegral z :: Word32))
+ -- narrowing to get the same results on both 64- and 32-bit arch
+ [ W# (43## +& (37## +& x))
+ , W# (43## +& (37## -& x))
+ , W# (43## +& (x -& 37##))
+ , W# (43## -& (37## +& x))
+ , W# (43## -& (37## -& x))
+ , W# (43## -& (x -& 37##))
+ , W# ((43## +& x) -& 37##)
+ , W# ((x +& 43##) -& 37##)
+ , W# ((43## -& x) -& 37##)
+ , W# ((x -& 43##) -& 37##)
+
+ , W# ((x +& 43##) +& (y +& 37##))
+ , W# ((x +& 43##) +& (y -& 37##))
+ , W# ((x +& 43##) +& (37## -& y))
+ , W# ((x -& 43##) +& (37## -& y))
+ , W# ((x -& 43##) +& (y -& 37##))
+ , W# ((43## -& x) +& (37## -& y))
+ , W# ((43## -& x) +& (y -& 37##))
+ ]
+
+{-# NOINLINE testsI #-}
+testsI :: Int# -> Int# -> [Int]
+testsI x y = fmap (\z -> fromIntegral (fromIntegral z :: Int32))
+ [ I# (43# +# (37# +# x))
+ , I# (43# +# (37# -# x))
+ , I# (43# +# (x -# 37#))
+ , I# (43# -# (37# +# x))
+ , I# (43# -# (37# -# x))
+ , I# (43# -# (x -# 37#))
+ , I# ((43# +# x) -# 37#)
+ , I# ((x +# 43#) -# 37#)
+ , I# ((43# -# x) -# 37#)
+ , I# ((x -# 43#) -# 37#)
+
+ , I# ((x +# 43#) +# (y +# 37#))
+ , I# ((x +# 43#) +# (y -# 37#))
+ , I# ((x +# 43#) +# (37# -# y))
+ , I# ((x -# 43#) +# (37# -# y))
+ , I# ((x -# 43#) +# (y -# 37#))
+ , I# ((43# -# x) +# (37# -# y))
+ , I# ((43# -# x) +# (y -# 37#))
+
+ , I# ((x +# 43#) -# (y +# 37#))
+ , I# ((x +# 43#) -# (y -# 37#))
+ , I# ((x +# 43#) -# (37# -# y))
+ , I# ((x -# 43#) -# (y +# 37#))
+ , I# ((43# -# x) -# (37# +# y))
+ , I# ((x -# 43#) -# (y -# 37#))
+ , I# ((x -# 43#) -# (37# -# y))
+ , I# ((43# -# x) -# (y -# 37#))
+ , I# ((43# -# x) -# (37# -# y))
+
+ , I# (43# *# (37# *# y))
+ , I# (43# *# (y *# 37#))
+ , I# ((43# *# x) *# (y *# 37#))
+
+ , I# (43# *# (37# +# y))
+ , I# (43# *# (37# -# y))
+ , I# (43# *# (y -# 37#))
+
+ , I# (x +# x)
+ , I# ((43# *# x) +# x)
+ , I# (x +# (43# *# x))
+ , I# ((43# *# x) +# (37# *# x))
+ , I# ((43# *# x) +# (x *# 37#))
+
+ , I# (x -# x)
+ , I# ((43# *# x) -# x)
+ , I# (x -# (43# *# x))
+ , I# ((43# *# x) -# (37# *# x))
+ , I# ((43# *# x) -# (x *# 37#))
+
+ , I# (x +# (37# +# y))
+ , I# (x +# (y +# 37#))
+ , I# (x +# (37# -# y))
+ , I# (x +# (y -# 37#))
+ , I# (x -# (37# +# y))
+ , I# (x -# (y +# 37#))
+ , I# (x -# (37# -# y))
+ , I# (x -# (y -# 37#))
+ , I# ((37# +# y) -# x)
+ , I# ((y +# 37#) -# x)
+ , I# ((37# -# y) -# x)
+ , I# ((y -# 37#) -# x)
+
+ , I# (y *# y)
+ ]
+
+
+main :: IO ()
+main = do
+ print (testsW 7## 13##)
+ print (testsI 7# 13#)
diff --git a/testsuite/tests/simplCore/should_run/NumConstantFolding.stdout b/testsuite/tests/simplCore/should_run/NumConstantFolding.stdout
new file mode 100644
index 0000000000..da6f72855f
--- /dev/null
+++ b/testsuite/tests/simplCore/should_run/NumConstantFolding.stdout
@@ -0,0 +1,2 @@
+[87,73,13,4294967295,13,73,13,13,4294967295,4294967223,100,26,74,4294967284,4294967236,60,12]
+[87,73,13,-1,13,73,13,13,-1,-73,100,26,74,-12,-60,60,12,0,74,26,-86,-14,-12,-60,60,12,20683,20683,144781,2150,1032,-1032,14,308,308,560,560,0,294,-294,42,42,57,57,31,-17,-43,-43,-17,31,43,43,17,-31,169]
diff --git a/testsuite/tests/simplCore/should_run/all.T b/testsuite/tests/simplCore/should_run/all.T
index a04558be89..ea10cd7914 100644
--- a/testsuite/tests/simplCore/should_run/all.T
+++ b/testsuite/tests/simplCore/should_run/all.T
@@ -93,3 +93,4 @@ test('T17151', [], multimod_compile_and_run, ['T17151', ''])
test('T18012', normal, compile_and_run, [''])
test('T17744', normal, compile_and_run, [''])
test('T18638', normal, compile_and_run, [''])
+test('NumConstantFolding', normal, compile_and_run, [''])