summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compiler/GHC/Core/Opt/ConstantFold.hs488
-rw-r--r--testsuite/tests/simplCore/should_run/NumConstantFolding.hs109
-rw-r--r--testsuite/tests/simplCore/should_run/NumConstantFolding.stdout2
-rw-r--r--testsuite/tests/simplCore/should_run/all.T1
4 files changed, 443 insertions, 157 deletions
diff --git a/compiler/GHC/Core/Opt/ConstantFold.hs b/compiler/GHC/Core/Opt/ConstantFold.hs
index 92632347e1..b8c2c5d6fa 100644
--- a/compiler/GHC/Core/Opt/ConstantFold.hs
+++ b/compiler/GHC/Core/Opt/ConstantFold.hs
@@ -1,8 +1,6 @@
{-
(c) The GRASP/AQUA Project, Glasgow University, 1992-1998
-\section[ConFold]{Constant Folder}
-
Conceptually, constant folding should be parameterized with the kind
of target machine to get identical behaviour during compilation time
and runtime. We cheat a little bit here...
@@ -13,9 +11,18 @@ ToDo:
-}
{-# LANGUAGE CPP, RankNTypes, PatternSynonyms, ViewPatterns, RecordWildCards,
- DeriveFunctor, LambdaCase, TypeApplications #-}
+ DeriveFunctor, LambdaCase, TypeApplications, MultiWayIf #-}
+
{-# OPTIONS_GHC -optc-DNON_POSIX_SOURCE -Wno-incomplete-uni-patterns #-}
+#if __GLASGOW_HASKELL__ <= 808
+-- GHC 8.10 deprecates this flag, but GHC 8.8 needs it
+-- The default iteration limit is a bit too low for the definitions
+-- in this module.
+{-# OPTIONS_GHC -fmax-pmcheck-iterations=20000000 #-}
+#endif
+
+-- | Constant Folder
module GHC.Core.Opt.ConstantFold
( primOpRules
, builtinRules
@@ -100,12 +107,12 @@ primOpRules nm = \case
-- Int operations
IntAddOp -> mkPrimOpRule nm 2 [ binaryLit (intOp2 (+))
, identityPlatform zeroi
- , numFoldingRules IntAddOp intPrimOps
+ , addFoldingRules IntAddOp intOps
]
IntSubOp -> mkPrimOpRule nm 2 [ binaryLit (intOp2 (-))
, rightIdentityPlatform zeroi
, equalArgs >> retLit zeroi
- , numFoldingRules IntSubOp intPrimOps
+ , subFoldingRules IntSubOp intOps
]
IntAddCOp -> mkPrimOpRule nm 2 [ binaryLit (intOpC2 (+))
, identityCPlatform zeroi ]
@@ -115,7 +122,7 @@ primOpRules nm = \case
IntMulOp -> mkPrimOpRule nm 2 [ binaryLit (intOp2 (*))
, zeroElem zeroi
, identityPlatform onei
- , numFoldingRules IntMulOp intPrimOps
+ , mulFoldingRules IntMulOp intOps
]
IntQuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (intOp2 quot)
, leftZero zeroi
@@ -152,12 +159,12 @@ primOpRules nm = \case
-- Word operations
WordAddOp -> mkPrimOpRule nm 2 [ binaryLit (wordOp2 (+))
, identityPlatform zerow
- , numFoldingRules WordAddOp wordPrimOps
+ , addFoldingRules WordAddOp wordOps
]
WordSubOp -> mkPrimOpRule nm 2 [ binaryLit (wordOp2 (-))
, rightIdentityPlatform zerow
, equalArgs >> retLit zerow
- , numFoldingRules WordSubOp wordPrimOps
+ , subFoldingRules WordSubOp wordOps
]
WordAddCOp -> mkPrimOpRule nm 2 [ binaryLit (wordOpC2 (+))
, identityCPlatform zerow ]
@@ -166,7 +173,7 @@ primOpRules nm = \case
, equalArgs >> retLitNoC zerow ]
WordMulOp -> mkPrimOpRule nm 2 [ binaryLit (wordOp2 (*))
, identityPlatform onew
- , numFoldingRules WordMulOp wordPrimOps
+ , mulFoldingRules WordMulOp wordOps
]
WordQuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (wordOp2 quot)
, rightIdentityPlatform onew ]
@@ -1878,181 +1885,348 @@ match_smallIntegerTo _ _ _ _ _ = Nothing
--
--------------------------------------------------------
--- | Rules to perform constant folding into nested expressions
+-- Rules to perform constant folding into nested expressions
--
--See Note [Constant folding through nested expressions]
-numFoldingRules :: PrimOp -> (Platform -> PrimOps) -> RuleM CoreExpr
-numFoldingRules op dict = do
- env <- getEnv
- if not (roNumConstantFolding env)
- then mzero
- else do
- [e1,e2] <- getArgs
- platform <- getPlatform
- let PrimOps{..} = dict platform
- case BinOpApp e1 op e2 of
- -- R1) +/- simplification
- x :++: (y :++: v) -> return $ mkL (x+y) `add` v
- x :++: (L y :-: v) -> return $ mkL (x+y) `sub` v
- x :++: (v :-: L y) -> return $ mkL (x-y) `add` v
- L x :-: (y :++: v) -> return $ mkL (x-y) `sub` v
- L x :-: (L y :-: v) -> return $ mkL (x-y) `add` v
- L x :-: (v :-: L y) -> return $ mkL (x+y) `sub` v
-
- (y :++: v) :-: L x -> return $ mkL (y-x) `add` v
- (L y :-: v) :-: L x -> return $ mkL (y-x) `sub` v
- (v :-: L y) :-: L x -> return $ mkL (0-y-x) `add` v
-
- (x :++: w) :+: (y :++: v) -> return $ mkL (x+y) `add` (w `add` v)
- (w :-: L x) :+: (L y :-: v) -> return $ mkL (y-x) `add` (w `sub` v)
- (w :-: L x) :+: (v :-: L y) -> return $ mkL (0-x-y) `add` (w `add` v)
- (L x :-: w) :+: (L y :-: v) -> return $ mkL (x+y) `sub` (w `add` v)
- (L x :-: w) :+: (v :-: L y) -> return $ mkL (x-y) `add` (v `sub` w)
- (w :-: L x) :+: (y :++: v) -> return $ mkL (y-x) `add` (w `add` v)
- (L x :-: w) :+: (y :++: v) -> return $ mkL (x+y) `add` (v `sub` w)
- (y :++: v) :+: (w :-: L x) -> return $ mkL (y-x) `add` (w `add` v)
- (y :++: v) :+: (L x :-: w) -> return $ mkL (x+y) `add` (v `sub` w)
-
- (v :-: L y) :-: (w :-: L x) -> return $ mkL (x-y) `add` (v `sub` w)
- (v :-: L y) :-: (L x :-: w) -> return $ mkL (0-x-y) `add` (v `add` w)
- (L y :-: v) :-: (w :-: L x) -> return $ mkL (x+y) `sub` (v `add` w)
- (L y :-: v) :-: (L x :-: w) -> return $ mkL (y-x) `add` (w `sub` v)
- (x :++: w) :-: (y :++: v) -> return $ mkL (x-y) `add` (w `sub` v)
- (w :-: L x) :-: (y :++: v) -> return $ mkL (0-y-x) `add` (w `sub` v)
- (L x :-: w) :-: (y :++: v) -> return $ mkL (x-y) `sub` (v `add` w)
- (y :++: v) :-: (w :-: L x) -> return $ mkL (y+x) `add` (v `sub` w)
- (y :++: v) :-: (L x :-: w) -> return $ mkL (y-x) `add` (v `add` w)
-
- -- R2) * simplification
- x :**: (y :**: v) -> return $ mkL (x*y) `mul` v
- (x :**: w) :*: (y :**: v) -> return $ mkL (x*y) `mul` (w `mul` v)
-
- -- R3) * distribution over +/-
- x :**: (y :++: v) -> return $ mkL (x*y) `add` (mkL x `mul` v)
- x :**: (L y :-: v) -> return $ mkL (x*y) `sub` (mkL x `mul` v)
- x :**: (v :-: L y) -> return $ (mkL x `mul` v) `sub` mkL (x*y)
-
- -- R4) Simple factorization
- v :+: w
- | w `cheapEqExpr` v -> return $ mkL 2 `mul` v
- w :+: (y :**: v)
- | w `cheapEqExpr` v -> return $ mkL (1+y) `mul` v
- w :-: (y :**: v)
- | w `cheapEqExpr` v -> return $ mkL (1-y) `mul` v
- (y :**: v) :+: w
- | w `cheapEqExpr` v -> return $ mkL (y+1) `mul` v
- (y :**: v) :-: w
- | w `cheapEqExpr` v -> return $ mkL (y-1) `mul` v
- (x :**: w) :+: (y :**: v)
- | w `cheapEqExpr` v -> return $ mkL (x+y) `mul` v
- (x :**: w) :-: (y :**: v)
- | w `cheapEqExpr` v -> return $ mkL (x-y) `mul` v
-
- -- R5) +/- propagation
- w :+: (y :++: v) -> return $ mkL y `add` (w `add` v)
- (y :++: v) :+: w -> return $ mkL y `add` (w `add` v)
- w :-: (y :++: v) -> return $ (w `sub` v) `sub` mkL y
- (y :++: v) :-: w -> return $ mkL y `add` (v `sub` w)
- w :-: (L y :-: v) -> return $ (w `add` v) `sub` mkL y
- (L y :-: v) :-: w -> return $ mkL y `sub` (w `add` v)
- w :+: (L y :-: v) -> return $ mkL y `add` (w `sub` v)
- w :+: (v :-: L y) -> return $ (w `add` v) `sub` mkL y
- (L y :-: v) :+: w -> return $ mkL y `add` (w `sub` v)
- (v :-: L y) :+: w -> return $ (w `add` v) `sub` mkL y
-
- _ -> mzero
+addFoldingRules :: PrimOp -> NumOps -> RuleM CoreExpr
+addFoldingRules op num_ops = do
+ ASSERT(op == numAdd num_ops) return ()
+ env <- getEnv
+ guard (roNumConstantFolding env)
+ [arg1,arg2] <- getArgs
+ platform <- getPlatform
+ liftMaybe
+ -- commutativity for + is handled here
+ (addFoldingRules' platform arg1 arg2 num_ops
+ <|> addFoldingRules' platform arg2 arg1 num_ops)
+
+subFoldingRules :: PrimOp -> NumOps -> RuleM CoreExpr
+subFoldingRules op num_ops = do
+ ASSERT(op == numSub num_ops) return ()
+ env <- getEnv
+ guard (roNumConstantFolding env)
+ [arg1,arg2] <- getArgs
+ platform <- getPlatform
+ liftMaybe (subFoldingRules' platform arg1 arg2 num_ops)
+
+mulFoldingRules :: PrimOp -> NumOps -> RuleM CoreExpr
+mulFoldingRules op num_ops = do
+ ASSERT(op == numMul num_ops) return ()
+ env <- getEnv
+ guard (roNumConstantFolding env)
+ [arg1,arg2] <- getArgs
+ platform <- getPlatform
+ liftMaybe
+ -- commutativity for * is handled here
+ (mulFoldingRules' platform arg1 arg2 num_ops
+ <|> mulFoldingRules' platform arg2 arg1 num_ops)
+
+
+addFoldingRules' :: Platform -> CoreExpr -> CoreExpr -> NumOps -> Maybe CoreExpr
+addFoldingRules' platform arg1 arg2 num_ops = case (arg1, arg2) of
+ -- R1) +/- simplification
+
+ -- l1 + (l2 + x) ==> (l1+l2) + x
+ (L l1, is_lit_add num_ops -> Just (l2,x))
+ -> Just (mkL (l1+l2) `add` x)
+
+ -- l1 + (l2 - x) ==> (l1+l2) - x
+ (L l1, is_sub num_ops -> Just (L l2,x))
+ -> Just (mkL (l1+l2) `sub` x)
+
+ -- l1 + (x - l2) ==> (l1-l2) + x
+ (L l1, is_sub num_ops -> Just (x,L l2))
+ -> Just (mkL (l1-l2) `add` x)
+
+ -- (l1 + x) + (l2 + y) ==> (l1+l2) + (x+y)
+ (is_lit_add num_ops -> Just (l1,x), is_lit_add num_ops -> Just (l2,y))
+ -> Just (mkL (l1+l2) `add` (x `add` y))
+
+ -- (l1 + x) + (l2 - y) ==> (l1+l2) + (x-y)
+ (is_lit_add num_ops -> Just (l1,x), is_sub num_ops -> Just (L l2,y))
+ -> Just (mkL (l1+l2) `add` (x `sub` y))
+
+ -- (l1 + x) + (y - l2) ==> (l1-l2) + (x+y)
+ (is_lit_add num_ops -> Just (l1,x), is_sub num_ops -> Just (y,L l2))
+ -> Just (mkL (l1-l2) `add` (x `add` y))
+
+ -- (l1 - x) + (l2 - y) ==> (l1+l2) - (x+y)
+ (is_sub num_ops -> Just (L l1,x), is_sub num_ops -> Just (L l2,y))
+ -> Just (mkL (l1+l2) `sub` (x `add` y))
+
+ -- (l1 - x) + (y - l2) ==> (l1-l2) + (y-x)
+ (is_sub num_ops -> Just (L l1,x), is_sub num_ops -> Just (y,L l2))
+ -> Just (mkL (l1-l2) `add` (y `sub` x))
+
+ -- (x - l1) + (y - l2) ==> (0-l1-l2) + (x+y)
+ (is_sub num_ops -> Just (x,L l1), is_sub num_ops -> Just (y,L l2))
+ -> Just (mkL (0-l1-l2) `add` (x `add` y))
+
+ -- R4) Simple factorization
+
+ -- x + x ==> 2 * x
+ _ | Just l1 <- is_expr_mul num_ops arg1 arg2
+ -> Just (mkL (l1+1) `mul` arg1)
+
+ -- (l1 * x) + x ==> (l1+1) * x
+ _ | Just l1 <- is_expr_mul num_ops arg2 arg1
+ -> Just (mkL (l1+1) `mul` arg2)
+
+ -- (l1 * x) + (l2 * x) ==> (l1+l2) * x
+ (is_lit_mul num_ops -> Just (l1,x), is_expr_mul num_ops x -> Just l2)
+ -> Just (mkL (l1+l2) `mul` x)
+
+ -- R5) +/- propagation: these transformations push literals outwards
+ -- with the hope that other rules can then be applied.
+
+ -- In the following rules, x can't be a literal otherwise another
+ -- rule would have combined it with the other literal in arg2. So we
+ -- don't have to check this to avoid loops here.
+
+ -- x + (l1 + y) ==> l1 + (x + y)
+ (_, is_lit_add num_ops -> Just (l1,y))
+ -> Just (mkL l1 `add` (arg1 `add` y))
+
+ -- x + (l1 - y) ==> l1 + (x - y)
+ (_, is_sub num_ops -> Just (L l1,y))
+ -> Just (mkL l1 `add` (arg1 `sub` y))
+
+ -- x + (y - l1) ==> (x + y) - l1
+ (_, is_sub num_ops -> Just (y,L l1))
+ -> Just ((arg1 `add` y) `sub` mkL l1)
+
+ _ -> Nothing
+ where
+ mkL = Lit . mkNumLiteral platform num_ops
+ add x y = BinOpApp x (numAdd num_ops) y
+ sub x y = BinOpApp x (numSub num_ops) y
+ mul x y = BinOpApp x (numMul num_ops) y
--- | Match the application of a binary primop
-pattern BinOpApp :: Arg CoreBndr -> PrimOp -> Arg CoreBndr -> CoreExpr
-pattern BinOpApp x op y = OpVal op `App` x `App` y
+subFoldingRules' :: Platform -> CoreExpr -> CoreExpr -> NumOps -> Maybe CoreExpr
+subFoldingRules' platform arg1 arg2 num_ops = case (arg1,arg2) of
+ -- R1) +/- simplification
--- | Match a primop
-pattern OpVal :: PrimOp -> Arg CoreBndr
-pattern OpVal op <- Var (isPrimOpId_maybe -> Just op) where
- OpVal op = Var (mkPrimOpId op)
+ -- l1 - (l2 + x) ==> (l1-l2) - x
+ (L l1, is_lit_add num_ops -> Just (l2,x))
+ -> Just (mkL (l1-l2) `sub` x)
+ -- l1 - (l2 - x) ==> (l1-l2) + x
+ (L l1, is_sub num_ops -> Just (L l2,x))
+ -> Just (mkL (l1-l2) `add` x)
+ -- l1 - (x - l2) ==> (l1+l2) - x
+ (L l1, is_sub num_ops -> Just (x, L l2))
+ -> Just (mkL (l1+l2) `sub` x)
--- | Match a literal
-pattern L :: Integer -> Arg CoreBndr
-pattern L l <- Lit (isLitValue_maybe -> Just l)
+ -- (l1 + x) - l2 ==> (l1-l2) + x
+ (is_lit_add num_ops -> Just (l1,x), L l2)
+ -> Just (mkL (l1-l2) `add` x)
+
+ -- (l1 - x) - l2 ==> (l1-l2) - x
+ (is_sub num_ops -> Just (L l1,x), L l2)
+ -> Just (mkL (l1-l2) `sub` x)
+
+ -- (x - l1) - l2 ==> x - (l1+l2)
+ (is_sub num_ops -> Just (x,L l1), L l2)
+ -> Just (x `sub` mkL (l1+l2))
+
+
+ -- (l1 + x) - (l2 + y) ==> (l1-l2) + (x-y)
+ (is_lit_add num_ops -> Just (l1,x), is_lit_add num_ops -> Just (l2,y))
+ -> Just (mkL (l1-l2) `add` (x `sub` y))
+
+ -- (l1 + x) - (l2 - y) ==> (l1-l2) + (x+y)
+ (is_lit_add num_ops -> Just (l1,x), is_sub num_ops -> Just (L l2,y))
+ -> Just (mkL (l1-l2) `add` (x `add` y))
+
+ -- (l1 + x) - (y - l2) ==> (l1+l2) + (x-y)
+ (is_lit_add num_ops -> Just (l1,x), is_sub num_ops -> Just (y,L l2))
+ -> Just (mkL (l1+l2) `add` (x `sub` y))
+
+ -- (l1 - x) - (l2 + y) ==> (l1-l2) - (x+y)
+ (is_sub num_ops -> Just (L l1,x), is_lit_add num_ops -> Just (l2,y))
+ -> Just (mkL (l1-l2) `sub` (x `add` y))
+
+ -- (x - l1) - (l2 + y) ==> (0-l1-l2) + (x-y)
+ (is_sub num_ops -> Just (x,L l1), is_lit_add num_ops -> Just (l2,y))
+ -> Just (mkL (0-l1-l2) `add` (x `sub` y))
+
+ -- (l1 - x) - (l2 - y) ==> (l1-l2) + (y-x)
+ (is_sub num_ops -> Just (L l1,x), is_sub num_ops -> Just (L l2,y))
+ -> Just (mkL (l1-l2) `add` (y `sub` x))
+
+ -- (l1 - x) - (y - l2) ==> (l1+l2) - (x+y)
+ (is_sub num_ops -> Just (L l1,x), is_sub num_ops -> Just (y,L l2))
+ -> Just (mkL (l1+l2) `sub` (x `add` y))
+
+ -- (x - l1) - (l2 - y) ==> (0-l1-l2) + (x+y)
+ (is_sub num_ops -> Just (x,L l1), is_sub num_ops -> Just (L l2,y))
+ -> Just (mkL (0-l1-l2) `add` (x `add` y))
+
+ -- (x - l1) - (y - l2) ==> (l2-l1) + (x-y)
+ (is_sub num_ops -> Just (x,L l1), is_sub num_ops -> Just (y,L l2))
+ -> Just (mkL (l2-l1) `add` (x `sub` y))
+
+ -- R4) Simple factorization
--- | Match an addition
-pattern (:+:) :: Arg CoreBndr -> Arg CoreBndr -> CoreExpr
-pattern x :+: y <- BinOpApp x (isAddOp -> True) y
+ -- x - (l1 * x) ==> (1-l1) * x
+ _ | Just l1 <- is_expr_mul num_ops arg1 arg2
+ -> Just (mkL (1-l1) `mul` arg1)
--- | Match an addition with a literal (handle commutativity)
-pattern (:++:) :: Integer -> Arg CoreBndr -> CoreExpr
-pattern l :++: x <- (isAdd -> Just (l,x))
+ -- (l1 * x) - x ==> (l1-1) * x
+ _ | Just l1 <- is_expr_mul num_ops arg2 arg1
+ -> Just (mkL (l1-1) `mul` arg2)
-isAdd :: CoreExpr -> Maybe (Integer,CoreExpr)
-isAdd e = case e of
- L l :+: x -> Just (l,x)
- x :+: L l -> Just (l,x)
- _ -> Nothing
+ -- (l1 * x) - (l2 * x) ==> (l1-l2) * x
+ (is_lit_mul num_ops -> Just (l1,x), is_expr_mul num_ops x -> Just l2)
+ -> Just (mkL (l1-l2) `mul` x)
--- | Match a multiplication
-pattern (:*:) :: Arg CoreBndr -> Arg CoreBndr -> CoreExpr
-pattern x :*: y <- BinOpApp x (isMulOp -> True) y
+ -- R5) +/- propagation: these transformations push literals outwards
+ -- with the hope that other rules can then be applied.
--- | Match a multiplication with a literal (handle commutativity)
-pattern (:**:) :: Integer -> Arg CoreBndr -> CoreExpr
-pattern l :**: x <- (isMul -> Just (l,x))
+ -- In the following rules, x can't be a literal otherwise another
+ -- rule would have combined it with the other literal in arg2. So we
+ -- don't have to check this to avoid loops here.
-isMul :: CoreExpr -> Maybe (Integer,CoreExpr)
-isMul e = case e of
- L l :*: x -> Just (l,x)
- x :*: L l -> Just (l,x)
- _ -> Nothing
+ -- x - (l1 + y) ==> (x - y) - l1
+ (_, is_lit_add num_ops -> Just (l1,y))
+ -> Just ((arg1 `sub` y) `sub` mkL l1)
+ -- (l1 + x) - y ==> l1 + (x - y)
+ (is_lit_add num_ops -> Just (l1,x), _)
+ -> Just (mkL l1 `add` (x `sub` arg2))
--- | Match a subtraction
-pattern (:-:) :: Arg CoreBndr -> Arg CoreBndr -> CoreExpr
-pattern x :-: y <- BinOpApp x (isSubOp -> True) y
+ -- x - (l1 - y) ==> (x + y) - l1
+ (_, is_sub num_ops -> Just (L l1,y))
+ -> Just ((arg1 `add` y) `sub` mkL l1)
-isSubOp :: PrimOp -> Bool
-isSubOp IntSubOp = True
-isSubOp WordSubOp = True
-isSubOp _ = False
+ -- x - (y - l1) ==> l1 + (x - y)
+ (_, is_sub num_ops -> Just (y,L l1))
+ -> Just (mkL l1 `add` (arg1 `sub` y))
-isAddOp :: PrimOp -> Bool
-isAddOp IntAddOp = True
-isAddOp WordAddOp = True
-isAddOp _ = False
+ -- (l1 - x) - y ==> l1 - (x + y)
+ (is_sub num_ops -> Just (L l1,x), _)
+ -> Just (mkL l1 `sub` (x `add` arg2))
-isMulOp :: PrimOp -> Bool
-isMulOp IntMulOp = True
-isMulOp WordMulOp = True
-isMulOp _ = False
+ -- (x - l1) - y ==> (x - y) - l1
+ (is_sub num_ops -> Just (x,L l1), _)
+ -> Just ((x `sub` arg2) `sub` mkL l1)
+
+ _ -> Nothing
+ where
+ mkL = Lit . mkNumLiteral platform num_ops
+ add x y = BinOpApp x (numAdd num_ops) y
+ sub x y = BinOpApp x (numSub num_ops) y
+ mul x y = BinOpApp x (numMul num_ops) y
+
+mulFoldingRules' :: Platform -> CoreExpr -> CoreExpr -> NumOps -> Maybe CoreExpr
+mulFoldingRules' platform arg1 arg2 num_ops = case (arg1,arg2) of
+ -- l1 * (l2 * x) ==> (l1*l2) * x
+ (L l1, is_lit_mul num_ops -> Just (l2,x))
+ -> Just (mkL (l1*l2) `mul` x)
+
+ -- l1 * (l2 + x) ==> (l1*l2) + (l1 * x)
+ (L l1, is_lit_add num_ops -> Just (l2,x))
+ -> Just (mkL (l1*l2) `add` (arg1 `mul` x))
+
+ -- l1 * (l2 - x) ==> (l1*l2) - (l1 * x)
+ (L l1, is_sub num_ops -> Just (L l2,x))
+ -> Just (mkL (l1*l2) `sub` (arg1 `mul` x))
+
+ -- l1 * (x - l2) ==> (l1 * x) - (l1*l2)
+ (L l1, is_sub num_ops -> Just (x, L l2))
+ -> Just ((arg1 `mul` x) `sub` mkL (l1*l2))
+
+ -- (l1 * x) * (l2 * y) ==> (l1*l2) * (x * y)
+ (is_lit_mul num_ops -> Just (l1,x), is_lit_mul num_ops -> Just (l2,y))
+ -> Just (mkL (l1*l2) `mul` (x `mul` y))
+
+ _ -> Nothing
+ where
+ mkL = Lit . mkNumLiteral platform num_ops
+ add x y = BinOpApp x (numAdd num_ops) y
+ sub x y = BinOpApp x (numSub num_ops) y
+ mul x y = BinOpApp x (numMul num_ops) y
+
+is_op :: PrimOp -> CoreExpr -> Maybe (Arg CoreBndr, Arg CoreBndr)
+is_op op e = case e of
+ BinOpApp x op' y | op == op' -> Just (x,y)
+ _ -> Nothing
+
+is_add, is_sub, is_mul :: NumOps -> CoreExpr -> Maybe (Arg CoreBndr, Arg CoreBndr)
+is_add num_ops = is_op (numAdd num_ops)
+is_sub num_ops = is_op (numSub num_ops)
+is_mul num_ops = is_op (numMul num_ops)
+
+-- match addition with a literal (handles commutativity)
+is_lit_add :: NumOps -> CoreExpr -> Maybe (Integer, Arg CoreBndr)
+is_lit_add num_ops e = case is_add num_ops e of
+ Just (L l, x ) -> Just (l,x)
+ Just (x , L l) -> Just (l,x)
+ _ -> Nothing
+
+-- match multiplication with a literal (handles commutativity)
+is_lit_mul :: NumOps -> CoreExpr -> Maybe (Integer, Arg CoreBndr)
+is_lit_mul num_ops e = case is_mul num_ops e of
+ Just (L l, x ) -> Just (l,x)
+ Just (x , L l) -> Just (l,x)
+ _ -> Nothing
+
+-- match given "x": return 1
+-- match "lit * x": return lit value (handles commutativity)
+is_expr_mul :: NumOps -> Expr CoreBndr -> Expr CoreBndr -> Maybe Integer
+is_expr_mul num_ops x e = if
+ | x `cheapEqExpr` e
+ -> Just 1
+ | Just (k,x') <- is_lit_mul num_ops e
+ , x `cheapEqExpr` x'
+ -> return k
+ | otherwise
+ -> Nothing
+
+
+-- | Match the application of a binary primop
+pattern BinOpApp :: Arg CoreBndr -> PrimOp -> Arg CoreBndr -> CoreExpr
+pattern BinOpApp x op y = OpVal op `App` x `App` y
+
+-- | Match a primop
+pattern OpVal:: PrimOp -> Arg CoreBndr
+pattern OpVal op <- Var (isPrimOpId_maybe -> Just op) where
+ OpVal op = Var (mkPrimOpId op)
+
+-- | Match a literal
+pattern L :: Integer -> Arg CoreBndr
+pattern L l <- Lit (isLitValue_maybe -> Just l)
-- | Explicit "type-class"-like dictionary for numeric primops
---
--- Depends on Platform because creating a literal value depends on Platform
-data PrimOps = PrimOps
- { add :: CoreExpr -> CoreExpr -> CoreExpr -- ^ Add two numbers
- , sub :: CoreExpr -> CoreExpr -> CoreExpr -- ^ Sub two numbers
- , mul :: CoreExpr -> CoreExpr -> CoreExpr -- ^ Multiply two numbers
- , mkL :: Integer -> CoreExpr -- ^ Create a literal value
+data NumOps = NumOps
+ { numAdd :: !PrimOp -- ^ Add two numbers
+ , numSub :: !PrimOp -- ^ Sub two numbers
+ , numMul :: !PrimOp -- ^ Multiply two numbers
+ , numLitType :: !LitNumType -- ^ Literal type
}
-intPrimOps :: Platform -> PrimOps
-intPrimOps platform = PrimOps
- { add = \x y -> BinOpApp x IntAddOp y
- , sub = \x y -> BinOpApp x IntSubOp y
- , mul = \x y -> BinOpApp x IntMulOp y
- , mkL = intResult' platform
- }
+-- | Create a numeric literal
+mkNumLiteral :: Platform -> NumOps -> Integer -> Literal
+mkNumLiteral platform ops i = mkLitNumberWrap platform (numLitType ops) i
-wordPrimOps :: Platform -> PrimOps
-wordPrimOps platform = PrimOps
- { add = \x y -> BinOpApp x WordAddOp y
- , sub = \x y -> BinOpApp x WordSubOp y
- , mul = \x y -> BinOpApp x WordMulOp y
- , mkL = wordResult' platform
+intOps :: NumOps
+intOps = NumOps
+ { numAdd = IntAddOp
+ , numSub = IntSubOp
+ , numMul = IntMulOp
+ , numLitType = LitNumInt
}
+wordOps :: NumOps
+wordOps = NumOps
+ { numAdd = WordAddOp
+ , numSub = WordSubOp
+ , numMul = WordMulOp
+ , numLitType = LitNumWord
+ }
--------------------------------------------------------
-- Constant folding through case-expressions
diff --git a/testsuite/tests/simplCore/should_run/NumConstantFolding.hs b/testsuite/tests/simplCore/should_run/NumConstantFolding.hs
new file mode 100644
index 0000000000..6466adfe4d
--- /dev/null
+++ b/testsuite/tests/simplCore/should_run/NumConstantFolding.hs
@@ -0,0 +1,109 @@
+{-# LANGUAGE MagicHash #-}
+
+import GHC.Exts
+import Data.Word
+import Data.Int
+
+(+&) = plusWord#
+(-&) = minusWord#
+(*&) = timesWord#
+
+{-# NOINLINE testsW #-}
+-- NOINLINE otherwise basic constant folding rules (without
+-- variables) are applied
+testsW :: Word# -> Word# -> [Word]
+testsW x y = fmap (\z -> fromIntegral (fromIntegral z :: Word32))
+ -- narrowing to get the same results on both 64- and 32-bit arch
+ [ W# (43## +& (37## +& x))
+ , W# (43## +& (37## -& x))
+ , W# (43## +& (x -& 37##))
+ , W# (43## -& (37## +& x))
+ , W# (43## -& (37## -& x))
+ , W# (43## -& (x -& 37##))
+ , W# ((43## +& x) -& 37##)
+ , W# ((x +& 43##) -& 37##)
+ , W# ((43## -& x) -& 37##)
+ , W# ((x -& 43##) -& 37##)
+
+ , W# ((x +& 43##) +& (y +& 37##))
+ , W# ((x +& 43##) +& (y -& 37##))
+ , W# ((x +& 43##) +& (37## -& y))
+ , W# ((x -& 43##) +& (37## -& y))
+ , W# ((x -& 43##) +& (y -& 37##))
+ , W# ((43## -& x) +& (37## -& y))
+ , W# ((43## -& x) +& (y -& 37##))
+ ]
+
+{-# NOINLINE testsI #-}
+testsI :: Int# -> Int# -> [Int]
+testsI x y = fmap (\z -> fromIntegral (fromIntegral z :: Int32))
+ [ I# (43# +# (37# +# x))
+ , I# (43# +# (37# -# x))
+ , I# (43# +# (x -# 37#))
+ , I# (43# -# (37# +# x))
+ , I# (43# -# (37# -# x))
+ , I# (43# -# (x -# 37#))
+ , I# ((43# +# x) -# 37#)
+ , I# ((x +# 43#) -# 37#)
+ , I# ((43# -# x) -# 37#)
+ , I# ((x -# 43#) -# 37#)
+
+ , I# ((x +# 43#) +# (y +# 37#))
+ , I# ((x +# 43#) +# (y -# 37#))
+ , I# ((x +# 43#) +# (37# -# y))
+ , I# ((x -# 43#) +# (37# -# y))
+ , I# ((x -# 43#) +# (y -# 37#))
+ , I# ((43# -# x) +# (37# -# y))
+ , I# ((43# -# x) +# (y -# 37#))
+
+ , I# ((x +# 43#) -# (y +# 37#))
+ , I# ((x +# 43#) -# (y -# 37#))
+ , I# ((x +# 43#) -# (37# -# y))
+ , I# ((x -# 43#) -# (y +# 37#))
+ , I# ((43# -# x) -# (37# +# y))
+ , I# ((x -# 43#) -# (y -# 37#))
+ , I# ((x -# 43#) -# (37# -# y))
+ , I# ((43# -# x) -# (y -# 37#))
+ , I# ((43# -# x) -# (37# -# y))
+
+ , I# (43# *# (37# *# y))
+ , I# (43# *# (y *# 37#))
+ , I# ((43# *# x) *# (y *# 37#))
+
+ , I# (43# *# (37# +# y))
+ , I# (43# *# (37# -# y))
+ , I# (43# *# (y -# 37#))
+
+ , I# (x +# x)
+ , I# ((43# *# x) +# x)
+ , I# (x +# (43# *# x))
+ , I# ((43# *# x) +# (37# *# x))
+ , I# ((43# *# x) +# (x *# 37#))
+
+ , I# (x -# x)
+ , I# ((43# *# x) -# x)
+ , I# (x -# (43# *# x))
+ , I# ((43# *# x) -# (37# *# x))
+ , I# ((43# *# x) -# (x *# 37#))
+
+ , I# (x +# (37# +# y))
+ , I# (x +# (y +# 37#))
+ , I# (x +# (37# -# y))
+ , I# (x +# (y -# 37#))
+ , I# (x -# (37# +# y))
+ , I# (x -# (y +# 37#))
+ , I# (x -# (37# -# y))
+ , I# (x -# (y -# 37#))
+ , I# ((37# +# y) -# x)
+ , I# ((y +# 37#) -# x)
+ , I# ((37# -# y) -# x)
+ , I# ((y -# 37#) -# x)
+
+ , I# (y *# y)
+ ]
+
+
+main :: IO ()
+main = do
+ print (testsW 7## 13##)
+ print (testsI 7# 13#)
diff --git a/testsuite/tests/simplCore/should_run/NumConstantFolding.stdout b/testsuite/tests/simplCore/should_run/NumConstantFolding.stdout
new file mode 100644
index 0000000000..da6f72855f
--- /dev/null
+++ b/testsuite/tests/simplCore/should_run/NumConstantFolding.stdout
@@ -0,0 +1,2 @@
+[87,73,13,4294967295,13,73,13,13,4294967295,4294967223,100,26,74,4294967284,4294967236,60,12]
+[87,73,13,-1,13,73,13,13,-1,-73,100,26,74,-12,-60,60,12,0,74,26,-86,-14,-12,-60,60,12,20683,20683,144781,2150,1032,-1032,14,308,308,560,560,0,294,-294,42,42,57,57,31,-17,-43,-43,-17,31,43,43,17,-31,169]
diff --git a/testsuite/tests/simplCore/should_run/all.T b/testsuite/tests/simplCore/should_run/all.T
index a04558be89..ea10cd7914 100644
--- a/testsuite/tests/simplCore/should_run/all.T
+++ b/testsuite/tests/simplCore/should_run/all.T
@@ -93,3 +93,4 @@ test('T17151', [], multimod_compile_and_run, ['T17151', ''])
test('T18012', normal, compile_and_run, [''])
test('T17744', normal, compile_and_run, [''])
test('T18638', normal, compile_and_run, [''])
+test('NumConstantFolding', normal, compile_and_run, [''])