summaryrefslogtreecommitdiff
path: root/compiler/prelude/PrelRules.hs
diff options
context:
space:
mode:
authorSylvain Henry <hsyl20@gmail.com>2018-04-13 13:29:07 -0400
committerBen Gamari <ben@smart-cactus.org>2018-04-13 16:25:43 -0400
commitfea04defa64871caab6339ff3fc5511a272f37c7 (patch)
treed19cfb3062f3edb1229998f93c4e382ff366fd42 /compiler/prelude/PrelRules.hs
parent4b831c27926d643b0b6fad82c1e946d05cde8645 (diff)
downloadhaskell-fea04defa64871caab6339ff3fc5511a272f37c7.tar.gz
Enhanced constant folding
Until now GHC only supported basic constant folding (lit op lit, expr op 0, etc.). This patch uses laws of +/-/* (associativity, commutativity, distributivity) to support some constant folding into nested expressions. Examples of new transformations: - simple nesting: (10 + x) + 10 becomes 20 + x - deep nesting: 5 + x + (y + (z + (t + 5))) becomes 10 + (x + (y + (z + t))) - distribution: (5 + x) * 6 becomes 30 + 6*x - simple factorization: 5 + x + (x + (x + (x + 5))) becomes 10 + (4 *x) - siblings: (5 + 4*x) - (3*x + 2) becomes 3 + x Test Plan: validate Reviewers: simonpj, austin, bgamari Reviewed By: bgamari Subscribers: thomie GHC Trac Issues: #9136 Differential Revision: https://phabricator.haskell.org/D2858
Diffstat (limited to 'compiler/prelude/PrelRules.hs')
-rw-r--r--compiler/prelude/PrelRules.hs305
1 files changed, 296 insertions, 9 deletions
diff --git a/compiler/prelude/PrelRules.hs b/compiler/prelude/PrelRules.hs
index 9fa0db6253..c1250c113b 100644
--- a/compiler/prelude/PrelRules.hs
+++ b/compiler/prelude/PrelRules.hs
@@ -12,7 +12,7 @@ ToDo:
(i1 + i2) only if it results in a valid Float.
-}
-{-# LANGUAGE CPP, RankNTypes #-}
+{-# LANGUAGE CPP, RankNTypes, PatternSynonyms, ViewPatterns, RecordWildCards #-}
{-# OPTIONS_GHC -optc-DNON_POSIX_SOURCE #-}
module PrelRules
@@ -90,13 +90,19 @@ primOpRules nm DataToTagOp = mkPrimOpRule nm 2 [ dataToTagRule ]
-- Int operations
primOpRules nm IntAddOp = mkPrimOpRule nm 2 [ binaryLit (intOp2 (+))
- , identityDynFlags zeroi ]
+ , identityDynFlags zeroi
+ , numFoldingRules IntAddOp intPrimOps
+ ]
primOpRules nm IntSubOp = mkPrimOpRule nm 2 [ binaryLit (intOp2 (-))
, rightIdentityDynFlags zeroi
- , equalArgs >> retLit zeroi ]
+ , equalArgs >> retLit zeroi
+ , numFoldingRules IntSubOp intPrimOps
+ ]
primOpRules nm IntMulOp = mkPrimOpRule nm 2 [ binaryLit (intOp2 (*))
, zeroElem zeroi
- , identityDynFlags onei ]
+ , identityDynFlags onei
+ , numFoldingRules IntMulOp intPrimOps
+ ]
primOpRules nm IntQuotOp = mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (intOp2 quot)
, leftZero zeroi
, rightIdentityDynFlags onei
@@ -131,12 +137,18 @@ primOpRules nm ISrlOp = mkPrimOpRule nm 2 [ shiftRule shiftRightLogical
-- Word operations
primOpRules nm WordAddOp = mkPrimOpRule nm 2 [ binaryLit (wordOp2 (+))
- , identityDynFlags zerow ]
+ , identityDynFlags zerow
+ , numFoldingRules WordAddOp wordPrimOps
+ ]
primOpRules nm WordSubOp = mkPrimOpRule nm 2 [ binaryLit (wordOp2 (-))
, rightIdentityDynFlags zerow
- , equalArgs >> retLit zerow ]
+ , equalArgs >> retLit zerow
+ , numFoldingRules WordSubOp wordPrimOps
+ ]
primOpRules nm WordMulOp = mkPrimOpRule nm 2 [ binaryLit (wordOp2 (*))
- , identityDynFlags onew ]
+ , identityDynFlags onew
+ , numFoldingRules WordMulOp wordPrimOps
+ ]
primOpRules nm WordQuotOp = mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (wordOp2 quot)
, rightIdentityDynFlags onew ]
primOpRules nm WordRemOp = mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (wordOp2 rem)
@@ -548,12 +560,18 @@ isMaxBound _ _ = False
-- | Create an Int literal expression while ensuring the given Integer is in the
-- target Int range
intResult :: DynFlags -> Integer -> Maybe CoreExpr
-intResult dflags result = Just (Lit (mkMachIntWrap dflags result))
+intResult dflags result = Just (intResult' dflags result)
+
+intResult' :: DynFlags -> Integer -> CoreExpr
+intResult' dflags result = Lit (mkMachIntWrap dflags result)
-- | Create a Word literal expression while ensuring the given Integer is in the
-- target Word range
wordResult :: DynFlags -> Integer -> Maybe CoreExpr
-wordResult dflags result = Just (Lit (mkMachWordWrap dflags result))
+wordResult dflags result = Just (wordResult' dflags result)
+
+wordResult' :: DynFlags -> Integer -> CoreExpr
+wordResult' dflags result = Lit (mkMachWordWrap dflags result)
inversePrimOp :: PrimOp -> RuleM CoreExpr
inversePrimOp primop = do
@@ -1478,6 +1496,275 @@ match_smallIntegerTo _ _ _ _ _ = Nothing
--------------------------------------------------------
+-- Note [Constant folding through nested expressions]
+-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+--
+-- We use rewrites rules to perform constant folding. It means that we don't
+-- have a global view of the expression we are trying to optimise. As a
+-- consequence we only perform local (small-step) transformations that either:
+-- 1) reduce the number of operations
+-- 2) rearrange the expression to increase the odds that other rules will
+-- match
+--
+-- We don't try to handle more complex expression optimisation cases that would
+-- require a global view. For example, rewriting expressions to increase
+-- sharing (e.g., Horner's method); optimisations that require local
+-- transformations increasing the number of operations; rearrangements to
+-- cancel/factorize terms (e.g., (a+b-a-b) isn't rearranged to reduce to 0).
+--
+-- We already have rules to perform constant folding on expressions with the
+-- following shape (where a and/or b are literals):
+--
+-- D) op
+-- /\
+-- / \
+-- / \
+-- a b
+--
+-- To support nested expressions, we match three other shapes of expression
+-- trees:
+--
+-- A) op1 B) op1 C) op1
+-- /\ /\ /\
+-- / \ / \ / \
+-- / \ / \ / \
+-- a op2 op2 c op2 op3
+-- /\ /\ /\ /\
+-- / \ / \ / \ / \
+-- b c a b a b c d
+--
+--
+-- R1) +/- simplification:
+-- ops = + or -, two literals (not siblings)
+--
+-- Examples:
+-- A: 5 + (10-x) ==> 15-x
+-- B: (10+x) + 5 ==> 15+x
+-- C: (5+a)-(5-b) ==> 0+(a+b)
+--
+-- R2) * simplification
+-- ops = *, two literals (not siblings)
+--
+-- Examples:
+-- A: 5 * (10*x) ==> 50*x
+-- B: (10*x) * 5 ==> 50*x
+-- C: (5*a)*(5*b) ==> 25*(a*b)
+--
+-- R3) * distribution over +/-
+-- op1 = *, op2 = + or -, two literals (not siblings)
+--
+-- This transformation doesn't reduce the number of operations but switches
+-- the outer and the inner operations so that the outer is (+) or (-) instead
+-- of (*). It increases the odds that other rules will match after this one.
+--
+-- Examples:
+-- A: 5 * (10-x) ==> 50 - (5*x)
+-- B: (10+x) * 5 ==> 50 + (5*x)
+-- C: Not supported as it would increase the number of operations:
+-- (5+a)*(5-b) ==> 25 - 5*b + 5*a - a*b
+--
+-- R4) Simple factorization
+--
+-- op1 = + or -, op2/op3 = *,
+-- one literal for each innermost * operation (except in the D case),
+-- the two other terms are equals
+--
+-- Examples:
+-- A: x - (10*x) ==> (-9)*x
+-- B: (10*x) + x ==> 11*x
+-- C: (5*x)-(x*3) ==> 2*x
+-- D: x+x ==> 2*x
+--
+-- R5) +/- propagation
+--
+-- ops = + or -, one literal
+--
+-- This transformation doesn't reduce the number of operations but propagates
+-- the constant to the outer level. It increases the odds that other rules
+-- will match after this one.
+--
+-- Examples:
+-- A: x - (10-y) ==> (x+y) - 10
+-- B: (10+x) - y ==> 10 + (x-y)
+-- C: N/A (caught by the A and B cases)
+--
+--------------------------------------------------------
+
+-- | Rules to perform constant folding into nested expressions
+--
+--See Note [Constant folding through nested expressions]
+numFoldingRules :: PrimOp -> (DynFlags -> PrimOps) -> RuleM CoreExpr
+numFoldingRules op dict = do
+ [e1,e2] <- getArgs
+ dflags <- getDynFlags
+ let PrimOps{..} = dict dflags
+ if not (gopt Opt_NumConstantFolding dflags)
+ then mzero
+ else case BinOpApp e1 op e2 of
+ -- R1) +/- simplification
+ x :++: (y :++: v) -> return $ mkL (x+y) `add` v
+ x :++: (L y :-: v) -> return $ mkL (x+y) `sub` v
+ x :++: (v :-: L y) -> return $ mkL (x-y) `add` v
+ L x :-: (y :++: v) -> return $ mkL (x-y) `sub` v
+ L x :-: (L y :-: v) -> return $ mkL (x-y) `add` v
+ L x :-: (v :-: L y) -> return $ mkL (x+y) `sub` v
+
+ (y :++: v) :-: L x -> return $ mkL (y-x) `add` v
+ (L y :-: v) :-: L x -> return $ mkL (y-x) `sub` v
+ (v :-: L y) :-: L x -> return $ mkL (0-y-x) `add` v
+
+ (x :++: w) :+: (y :++: v) -> return $ mkL (x+y) `add` (w `add` v)
+ (w :-: L x) :+: (L y :-: v) -> return $ mkL (y-x) `add` (w `sub` v)
+ (w :-: L x) :+: (v :-: L y) -> return $ mkL (0-x-y) `add` (w `add` v)
+ (L x :-: w) :+: (L y :-: v) -> return $ mkL (x+y) `sub` (w `add` v)
+ (L x :-: w) :+: (v :-: L y) -> return $ mkL (x-y) `add` (v `sub` w)
+ (w :-: L x) :+: (y :++: v) -> return $ mkL (y-x) `add` (w `add` v)
+ (L x :-: w) :+: (y :++: v) -> return $ mkL (x+y) `add` (v `sub` w)
+ (y :++: v) :+: (w :-: L x) -> return $ mkL (y-x) `add` (w `add` v)
+ (y :++: v) :+: (L x :-: w) -> return $ mkL (x+y) `add` (v `sub` w)
+
+ (v :-: L y) :-: (w :-: L x) -> return $ mkL (x-y) `add` (v `sub` w)
+ (v :-: L y) :-: (L x :-: w) -> return $ mkL (0-x-y) `add` (v `add` w)
+ (L y :-: v) :-: (w :-: L x) -> return $ mkL (x+y) `sub` (v `add` w)
+ (L y :-: v) :-: (L x :-: w) -> return $ mkL (y-x) `add` (w `add` v)
+ (x :++: w) :-: (y :++: v) -> return $ mkL (x-y) `add` (w `sub` v)
+ (w :-: L x) :-: (y :++: v) -> return $ mkL (0-y-x) `add` (w `sub` v)
+ (L x :-: w) :-: (y :++: v) -> return $ mkL (x-y) `sub` (v `add` w)
+ (y :++: v) :-: (w :-: L x) -> return $ mkL (y+x) `add` (v `sub` w)
+ (y :++: v) :-: (L x :-: w) -> return $ mkL (y-x) `add` (v `add` w)
+
+ -- R2) * simplification
+ x :**: (y :**: v) -> return $ mkL (x*y) `mul` v
+ (x :**: w) :*: (y :**: v) -> return $ mkL (x*y) `mul` (w `mul` v)
+
+ -- R3) * distribution over +/-
+ x :**: (y :++: v) -> return $ mkL (x*y) `add` (mkL x `mul` v)
+ x :**: (L y :-: v) -> return $ mkL (x*y) `sub` (mkL x `mul` v)
+ x :**: (v :-: L y) -> return $ (mkL x `mul` v) `sub` mkL (x*y)
+
+ -- R4) Simple factorization
+ v :+: w
+ | w `cheapEqExpr` v -> return $ mkL 2 `mul` v
+ w :+: (y :**: v)
+ | w `cheapEqExpr` v -> return $ mkL (1+y) `mul` v
+ w :-: (y :**: v)
+ | w `cheapEqExpr` v -> return $ mkL (1-y) `mul` v
+ (y :**: v) :+: w
+ | w `cheapEqExpr` v -> return $ mkL (y+1) `mul` v
+ (y :**: v) :-: w
+ | w `cheapEqExpr` v -> return $ mkL (y-1) `mul` v
+ (x :**: w) :+: (y :**: v)
+ | w `cheapEqExpr` v -> return $ mkL (x+y) `mul` v
+ (x :**: w) :-: (y :**: v)
+ | w `cheapEqExpr` v -> return $ mkL (x-y) `mul` v
+
+ -- R5) +/- propagation
+ w :+: (y :++: v) -> return $ mkL y `add` (w `add` v)
+ (y :++: v) :+: w -> return $ mkL y `add` (w `add` v)
+ w :-: (y :++: v) -> return $ (w `sub` v) `sub` mkL y
+ (y :++: v) :-: w -> return $ mkL y `add` (v `sub` w)
+ w :-: (L y :-: v) -> return $ (w `add` v) `sub` mkL y
+ (L y :-: v) :-: w -> return $ mkL y `sub` (w `add` v)
+ w :+: (L y :-: v) -> return $ mkL y `add` (w `sub` v)
+ w :+: (v :-: L y) -> return $ (w `add` v) `sub` mkL y
+ (L y :-: v) :+: w -> return $ mkL y `add` (w `sub` v)
+ (v :-: L y) :+: w -> return $ (w `add` v) `sub` mkL y
+
+ _ -> mzero
+
+
+
+-- | Match the application of a binary primop
+pattern BinOpApp :: Arg CoreBndr -> PrimOp -> Arg CoreBndr -> CoreExpr
+pattern BinOpApp x op y = OpVal op `App` x `App` y
+
+-- | Match a primop
+pattern OpVal :: PrimOp -> Arg CoreBndr
+pattern OpVal op <- Var (isPrimOpId_maybe -> Just op) where
+ OpVal op = Var (mkPrimOpId op)
+
+
+
+-- | Match a literal
+pattern L :: Integer -> Arg CoreBndr
+pattern L l <- Lit (isLitValue_maybe -> Just l)
+
+-- | Match an addition
+pattern (:+:) :: Arg CoreBndr -> Arg CoreBndr -> CoreExpr
+pattern x :+: y <- BinOpApp x (isAddOp -> True) y
+
+-- | Match an addition with a literal (handle commutativity)
+pattern (:++:) :: Integer -> Arg CoreBndr -> CoreExpr
+pattern l :++: x <- (isAdd -> Just (l,x))
+
+isAdd :: CoreExpr -> Maybe (Integer,CoreExpr)
+isAdd e = case e of
+ L l :+: x -> Just (l,x)
+ x :+: L l -> Just (l,x)
+ _ -> Nothing
+
+-- | Match a multiplication
+pattern (:*:) :: Arg CoreBndr -> Arg CoreBndr -> CoreExpr
+pattern x :*: y <- BinOpApp x (isMulOp -> True) y
+
+-- | Match a multiplication with a literal (handle commutativity)
+pattern (:**:) :: Integer -> Arg CoreBndr -> CoreExpr
+pattern l :**: x <- (isMul -> Just (l,x))
+
+isMul :: CoreExpr -> Maybe (Integer,CoreExpr)
+isMul e = case e of
+ L l :*: x -> Just (l,x)
+ x :*: L l -> Just (l,x)
+ _ -> Nothing
+
+
+-- | Match a subtraction
+pattern (:-:) :: Arg CoreBndr -> Arg CoreBndr -> CoreExpr
+pattern x :-: y <- BinOpApp x (isSubOp -> True) y
+
+isSubOp :: PrimOp -> Bool
+isSubOp IntSubOp = True
+isSubOp WordSubOp = True
+isSubOp _ = False
+
+isAddOp :: PrimOp -> Bool
+isAddOp IntAddOp = True
+isAddOp WordAddOp = True
+isAddOp _ = False
+
+isMulOp :: PrimOp -> Bool
+isMulOp IntMulOp = True
+isMulOp WordMulOp = True
+isMulOp _ = False
+
+-- | Explicit "type-class"-like dictionary for numeric primops
+--
+-- Depends on DynFlags because creating a literal value depends on DynFlags
+data PrimOps = PrimOps
+ { add :: CoreExpr -> CoreExpr -> CoreExpr -- ^ Add two numbers
+ , sub :: CoreExpr -> CoreExpr -> CoreExpr -- ^ Sub two numbers
+ , mul :: CoreExpr -> CoreExpr -> CoreExpr -- ^ Multiply two numbers
+ , mkL :: Integer -> CoreExpr -- ^ Create a literal value
+ }
+
+intPrimOps :: DynFlags -> PrimOps
+intPrimOps dflags = PrimOps
+ { add = \x y -> BinOpApp x IntAddOp y
+ , sub = \x y -> BinOpApp x IntSubOp y
+ , mul = \x y -> BinOpApp x IntMulOp y
+ , mkL = intResult' dflags
+ }
+
+wordPrimOps :: DynFlags -> PrimOps
+wordPrimOps dflags = PrimOps
+ { add = \x y -> BinOpApp x WordAddOp y
+ , sub = \x y -> BinOpApp x WordSubOp y
+ , mul = \x y -> BinOpApp x WordMulOp y
+ , mkL = wordResult' dflags
+ }
+
+
+--------------------------------------------------------
-- Constant folding through case-expressions
--
-- cf Scrutinee Constant Folding in simplCore/SimplUtils