summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSylvain Henry <sylvain@haskus.fr>2023-01-27 13:57:11 +0100
committerMarge Bot <ben+marge-bot@smart-cactus.org>2023-04-13 08:50:33 -0400
commit4dd021227559e1bc70cdaed12e45ff5459c33d27 (patch)
tree04497c322c430924c746102f1d679fed3e7396c0
parent593218794199e23cdfc1a94200cbb9f404e28720 (diff)
downloadhaskell-4dd021227559e1bc70cdaed12e45ff5459c33d27.tar.gz
Add quot folding rule (#22152)
(x / l1) / l2 l1 and l2 /= 0 l1*l2 doesn't overflow ==> x / (l1 * l2)
-rw-r--r--compiler/GHC/Core/Opt/ConstantFold.hs97
-rw-r--r--compiler/GHC/Types/Literal.hs8
-rw-r--r--testsuite/tests/primops/should_compile/T22152.stderr5
-rw-r--r--testsuite/tests/primops/should_compile/T22152b.hs38
-rw-r--r--testsuite/tests/primops/should_compile/T22152b.stderr47
-rw-r--r--testsuite/tests/primops/should_compile/all.T1
6 files changed, 177 insertions, 19 deletions
diff --git a/compiler/GHC/Core/Opt/ConstantFold.hs b/compiler/GHC/Core/Opt/ConstantFold.hs
index fb863d65cb..42ced5a86a 100644
--- a/compiler/GHC/Core/Opt/ConstantFold.hs
+++ b/compiler/GHC/Core/Opt/ConstantFold.hs
@@ -121,7 +121,9 @@ primOpRules nm = \case
Int8QuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int8Op2 quot)
, leftZero
, rightIdentity oneI8
- , equalArgs $> Lit oneI8 ]
+ , equalArgs $> Lit oneI8
+ , quotFoldingRules int8Ops
+ ]
Int8RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int8Op2 rem)
, leftZero
, oneLit 1 $> Lit zeroI8
@@ -150,7 +152,9 @@ primOpRules nm = \case
, mulFoldingRules Word8MulOp word8Ops
]
Word8QuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word8Op2 quot)
- , rightIdentity oneW8 ]
+ , rightIdentity oneW8
+ , quotFoldingRules word8Ops
+ ]
Word8RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word8Op2 rem)
, leftZero
, oneLit 1 $> Lit zeroW8
@@ -195,7 +199,9 @@ primOpRules nm = \case
Int16QuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int16Op2 quot)
, leftZero
, rightIdentity oneI16
- , equalArgs $> Lit oneI16 ]
+ , equalArgs $> Lit oneI16
+ , quotFoldingRules int16Ops
+ ]
Int16RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int16Op2 rem)
, leftZero
, oneLit 1 $> Lit zeroI16
@@ -224,7 +230,9 @@ primOpRules nm = \case
, mulFoldingRules Word16MulOp word16Ops
]
Word16QuotOp-> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word16Op2 quot)
- , rightIdentity oneW16 ]
+ , rightIdentity oneW16
+ , quotFoldingRules word16Ops
+ ]
Word16RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word16Op2 rem)
, leftZero
, oneLit 1 $> Lit zeroW16
@@ -269,7 +277,9 @@ primOpRules nm = \case
Int32QuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int32Op2 quot)
, leftZero
, rightIdentity oneI32
- , equalArgs $> Lit oneI32 ]
+ , equalArgs $> Lit oneI32
+ , quotFoldingRules int32Ops
+ ]
Int32RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int32Op2 rem)
, leftZero
, oneLit 1 $> Lit zeroI32
@@ -298,7 +308,9 @@ primOpRules nm = \case
, mulFoldingRules Word32MulOp word32Ops
]
Word32QuotOp-> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word32Op2 quot)
- , rightIdentity oneW32 ]
+ , rightIdentity oneW32
+ , quotFoldingRules word32Ops
+ ]
Word32RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word32Op2 rem)
, leftZero
, oneLit 1 $> Lit zeroW32
@@ -342,7 +354,9 @@ primOpRules nm = \case
Int64QuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int64Op2 quot)
, leftZero
, rightIdentity oneI64
- , equalArgs $> Lit oneI64 ]
+ , equalArgs $> Lit oneI64
+ , quotFoldingRules int64Ops
+ ]
Int64RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int64Op2 rem)
, leftZero
, oneLit 1 $> Lit zeroI64
@@ -371,7 +385,9 @@ primOpRules nm = \case
, mulFoldingRules Word64MulOp word64Ops
]
Word64QuotOp-> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word64Op2 quot)
- , rightIdentity oneW64 ]
+ , rightIdentity oneW64
+ , quotFoldingRules word64Ops
+ ]
Word64RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word64Op2 rem)
, leftZero
, oneLit 1 $> Lit zeroW64
@@ -452,7 +468,9 @@ primOpRules nm = \case
IntQuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (intOp2 quot)
, leftZero
, rightIdentityPlatform onei
- , equalArgs >> retLit onei ]
+ , equalArgs >> retLit onei
+ , quotFoldingRules intOps
+ ]
IntRemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (intOp2 rem)
, leftZero
, oneLit 1 >> retLit zeroi
@@ -504,7 +522,9 @@ primOpRules nm = \case
, mulFoldingRules WordMulOp wordOps
]
WordQuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (wordOp2 quot)
- , rightIdentityPlatform onew ]
+ , rightIdentityPlatform onew
+ , quotFoldingRules wordOps
+ ]
WordRemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (wordOp2 rem)
, leftZero
, oneLit 1 >> retLit zerow
@@ -2653,6 +2673,14 @@ orFoldingRules num_ops = do
(orFoldingRules' platform arg1 arg2 num_ops
<|> orFoldingRules' platform arg2 arg1 num_ops)
+quotFoldingRules :: NumOps -> RuleM CoreExpr
+quotFoldingRules num_ops = do
+ env <- getRuleOpts
+ guard (roNumConstantFolding env)
+ [arg1,arg2] <- getArgs
+ platform <- getPlatform
+ liftMaybe (quotFoldingRules' platform arg1 arg2 num_ops)
+
addFoldingRules' :: Platform -> CoreExpr -> CoreExpr -> NumOps -> Maybe CoreExpr
addFoldingRules' platform arg1 arg2 num_ops = case (arg1, arg2) of
@@ -2943,6 +2971,29 @@ orFoldingRules' platform arg1 arg2 num_ops = case (arg1, arg2) of
mkL = Lit . mkNumLiteral platform num_ops
or x y = BinOpApp x (fromJust (numOr num_ops)) y
+quotFoldingRules' :: Platform -> CoreExpr -> CoreExpr -> NumOps -> Maybe CoreExpr
+quotFoldingRules' platform arg1 arg2 num_ops = case (arg1, arg2) of
+
+ -- (x / l1) / l2
+ -- l1 and l2 /= 0
+ -- l1*l2 doesn't overflow
+ -- ==> x / (l1 * l2)
+ (is_div num_ops -> Just (x, L l1), L l2)
+ | l1 /= 0
+ , l2 /= 0
+ -- check that the result of the multiplication is in range
+ , Just l <- mkNumLiteralMaybe platform num_ops (l1 * l2)
+ -> Just (div x (Lit l))
+ -- NB: we could directly return 0 or (-1) in case of overflow,
+ -- but we would need to know
+ -- (1) if we're dealing with a quot or a div operation
+ -- (2) if it's an underflow or an overflow.
+ -- Left as future work for now.
+
+ _ -> Nothing
+ where
+ div x y = BinOpApp x (fromJust (numDiv num_ops)) y
+
is_binop :: PrimOp -> CoreExpr -> Maybe (Arg CoreBndr, Arg CoreBndr)
is_binop op e = case e of
BinOpApp x op' y | op == op' -> Just (x,y)
@@ -2953,12 +3004,13 @@ is_op op e = case e of
App (OpVal op') x | op == op' -> Just x
_ -> Nothing
-is_add, is_sub, is_mul, is_and, is_or :: NumOps -> CoreExpr -> Maybe (Arg CoreBndr, Arg CoreBndr)
+is_add, is_sub, is_mul, is_and, is_or, is_div :: NumOps -> CoreExpr -> Maybe (Arg CoreBndr, Arg CoreBndr)
is_add num_ops e = is_binop (numAdd num_ops) e
is_sub num_ops e = is_binop (numSub num_ops) e
is_mul num_ops e = is_binop (numMul num_ops) e
is_and num_ops e = numAnd num_ops >>= \op -> is_binop op e
is_or num_ops e = numOr num_ops >>= \op -> is_binop op e
+is_div num_ops e = numDiv num_ops >>= \op -> is_binop op e
is_neg :: NumOps -> CoreExpr -> Maybe (Arg CoreBndr)
is_neg num_ops e = numNeg num_ops >>= \op -> is_op op e
@@ -3007,6 +3059,7 @@ data NumOps = NumOps
{ numAdd :: !PrimOp -- ^ Add two numbers
, numSub :: !PrimOp -- ^ Sub two numbers
, numMul :: !PrimOp -- ^ Multiply two numbers
+ , numDiv :: !(Maybe PrimOp) -- ^ Divide two numbers
, numAnd :: !(Maybe PrimOp) -- ^ And two numbers
, numOr :: !(Maybe PrimOp) -- ^ Or two numbers
, numNeg :: !(Maybe PrimOp) -- ^ Negate a number
@@ -3017,15 +3070,20 @@ data NumOps = NumOps
mkNumLiteral :: Platform -> NumOps -> Integer -> Literal
mkNumLiteral platform ops i = mkLitNumberWrap platform (numLitType ops) i
+-- | Create a numeric literal if it is in range
+mkNumLiteralMaybe :: Platform -> NumOps -> Integer -> Maybe Literal
+mkNumLiteralMaybe platform ops i = mkLitNumberMaybe platform (numLitType ops) i
+
int8Ops :: NumOps
int8Ops = NumOps
{ numAdd = Int8AddOp
, numSub = Int8SubOp
, numMul = Int8MulOp
- , numLitType = LitNumInt8
+ , numDiv = Just Int8QuotOp
, numAnd = Nothing
, numOr = Nothing
, numNeg = Just Int8NegOp
+ , numLitType = LitNumInt8
}
word8Ops :: NumOps
@@ -3033,6 +3091,7 @@ word8Ops = NumOps
{ numAdd = Word8AddOp
, numSub = Word8SubOp
, numMul = Word8MulOp
+ , numDiv = Just Word8QuotOp
, numAnd = Just Word8AndOp
, numOr = Just Word8OrOp
, numNeg = Nothing
@@ -3044,10 +3103,11 @@ int16Ops = NumOps
{ numAdd = Int16AddOp
, numSub = Int16SubOp
, numMul = Int16MulOp
- , numLitType = LitNumInt16
+ , numDiv = Just Int16QuotOp
, numAnd = Nothing
, numOr = Nothing
, numNeg = Just Int16NegOp
+ , numLitType = LitNumInt16
}
word16Ops :: NumOps
@@ -3055,6 +3115,7 @@ word16Ops = NumOps
{ numAdd = Word16AddOp
, numSub = Word16SubOp
, numMul = Word16MulOp
+ , numDiv = Just Word16QuotOp
, numAnd = Just Word16AndOp
, numOr = Just Word16OrOp
, numNeg = Nothing
@@ -3066,10 +3127,11 @@ int32Ops = NumOps
{ numAdd = Int32AddOp
, numSub = Int32SubOp
, numMul = Int32MulOp
- , numLitType = LitNumInt32
+ , numDiv = Just Int32QuotOp
, numAnd = Nothing
, numOr = Nothing
, numNeg = Just Int32NegOp
+ , numLitType = LitNumInt32
}
word32Ops :: NumOps
@@ -3077,6 +3139,7 @@ word32Ops = NumOps
{ numAdd = Word32AddOp
, numSub = Word32SubOp
, numMul = Word32MulOp
+ , numDiv = Just Word32QuotOp
, numAnd = Just Word32AndOp
, numOr = Just Word32OrOp
, numNeg = Nothing
@@ -3088,10 +3151,11 @@ int64Ops = NumOps
{ numAdd = Int64AddOp
, numSub = Int64SubOp
, numMul = Int64MulOp
- , numLitType = LitNumInt64
+ , numDiv = Just Int64QuotOp
, numAnd = Nothing
, numOr = Nothing
, numNeg = Just Int64NegOp
+ , numLitType = LitNumInt64
}
word64Ops :: NumOps
@@ -3099,6 +3163,7 @@ word64Ops = NumOps
{ numAdd = Word64AddOp
, numSub = Word64SubOp
, numMul = Word64MulOp
+ , numDiv = Just Word64QuotOp
, numAnd = Just Word64AndOp
, numOr = Just Word64OrOp
, numNeg = Nothing
@@ -3110,6 +3175,7 @@ intOps = NumOps
{ numAdd = IntAddOp
, numSub = IntSubOp
, numMul = IntMulOp
+ , numDiv = Just IntQuotOp
, numAnd = Just IntAndOp
, numOr = Just IntOrOp
, numNeg = Just IntNegOp
@@ -3121,6 +3187,7 @@ wordOps = NumOps
{ numAdd = WordAddOp
, numSub = WordSubOp
, numMul = WordMulOp
+ , numDiv = Just WordQuotOp
, numAnd = Just WordAndOp
, numOr = Just WordOrOp
, numNeg = Nothing
diff --git a/compiler/GHC/Types/Literal.hs b/compiler/GHC/Types/Literal.hs
index 7581c10602..1bb9ddb31b 100644
--- a/compiler/GHC/Types/Literal.hs
+++ b/compiler/GHC/Types/Literal.hs
@@ -32,7 +32,7 @@ module GHC.Types.Literal
, mkLitFloat, mkLitDouble
, mkLitChar, mkLitString
, mkLitBigNat
- , mkLitNumber, mkLitNumberWrap
+ , mkLitNumber, mkLitNumberWrap, mkLitNumberMaybe
-- ** Operations on Literals
, literalType
@@ -411,6 +411,12 @@ mkLitNumber platform nt i =
assertPpr (litNumCheckRange platform nt i) (integer i)
(LitNumber nt i)
+-- | Create a numeric 'Literal' of the given type if it is in range
+mkLitNumberMaybe :: Platform -> LitNumType -> Integer -> Maybe Literal
+mkLitNumberMaybe platform nt i
+ | litNumCheckRange platform nt i = Just (LitNumber nt i)
+ | otherwise = Nothing
+
-- | Creates a 'Literal' of type @Int#@
mkLitInt :: Platform -> Integer -> Literal
mkLitInt platform x = assertPpr (platformInIntRange platform x) (integer x)
diff --git a/testsuite/tests/primops/should_compile/T22152.stderr b/testsuite/tests/primops/should_compile/T22152.stderr
index 505bca04a7..33ff7721f6 100644
--- a/testsuite/tests/primops/should_compile/T22152.stderr
+++ b/testsuite/tests/primops/should_compile/T22152.stderr
@@ -1,10 +1,9 @@
==================== Tidy Core ====================
Result size of Tidy Core
- = {terms: 11, types: 5, coercions: 0, joins: 0/0}
+ = {terms: 9, types: 5, coercions: 0, joins: 0/0}
-toHours
- = \ t -> case t of { I# x -> I# (quotInt# (quotInt# x 60#) 60#) }
+toHours = \ t -> case t of { I# x -> I# (quotInt# x 3600#) }
diff --git a/testsuite/tests/primops/should_compile/T22152b.hs b/testsuite/tests/primops/should_compile/T22152b.hs
new file mode 100644
index 0000000000..f6ee4fce28
--- /dev/null
+++ b/testsuite/tests/primops/should_compile/T22152b.hs
@@ -0,0 +1,38 @@
+{-# OPTIONS_GHC -O2 -ddump-simpl -dno-typeable-binds -dsuppress-all -dsuppress-uniques #-}
+module T22152b where
+
+import Data.Int
+import Data.Word
+
+a :: Int32 -> Int32
+a x = (x `quot` maxBound) `quot` maxBound -- overflow, mustn't trigger the rewrite rule
+
+b :: Int -> Int
+b x = (x `quot` 10) `quot` 20
+
+c :: Word -> Word
+c x = (x `quot` 10) `quot` 20
+
+d :: Word8 -> Word8
+d x = (x `quot` 10) `quot` 20
+
+e :: Word16 -> Word16
+e x = (x `quot` 10) `quot` 20
+
+f :: Word32 -> Word32
+f x = (x `quot` 10) `quot` 20
+
+g :: Word64 -> Word64
+g x = (x `quot` 10) `quot` 20
+
+h :: Int8 -> Int8
+h x = (x `quot` 10) `quot` 20
+
+i :: Int16 -> Int16
+i x = (x `quot` 10) `quot` 20
+
+j :: Int32 -> Int32
+j x = (x `quot` 10) `quot` 20
+
+k :: Int64 -> Int64
+k x = (x `quot` 10) `quot` 20
diff --git a/testsuite/tests/primops/should_compile/T22152b.stderr b/testsuite/tests/primops/should_compile/T22152b.stderr
new file mode 100644
index 0000000000..0cf317cc32
--- /dev/null
+++ b/testsuite/tests/primops/should_compile/T22152b.stderr
@@ -0,0 +1,47 @@
+
+==================== Tidy Core ====================
+Result size of Tidy Core
+ = {terms: 119, types: 59, coercions: 0, joins: 0/0}
+
+b = \ x -> case x of { I# x1 -> I# (quotInt# x1 200#) }
+
+c = \ x -> case x of { W# x# -> W# (quotWord# x# 200##) }
+
+d = \ x -> case x of { W8# x# -> W8# (quotWord8# x# 200#Word8) }
+
+e = \ x ->
+ case x of { W16# x# -> W16# (quotWord16# x# 200#Word16) }
+
+f = \ x ->
+ case x of { W32# x# -> W32# (quotWord32# x# 200#Word32) }
+
+g = \ x ->
+ case x of { W64# x# ->
+ case quotWord64# x# 10#Word64 of ds1 { __DEFAULT ->
+ case quotWord64# ds1 20#Word64 of ds2 { __DEFAULT -> W64# ds2 }
+ }
+ }
+
+h = \ x ->
+ case x of { I8# x# ->
+ I8# (quotInt8# (quotInt8# x# 10#Int8) 20#Int8)
+ }
+
+i = \ x -> case x of { I16# x# -> I16# (quotInt16# x# 200#Int16) }
+
+j = \ x -> case x of { I32# x# -> I32# (quotInt32# x# 200#Int32) }
+
+a = \ x ->
+ case x of { I32# x# ->
+ I32# (quotInt32# (quotInt32# x# 2147483647#Int32) 2147483647#Int32)
+ }
+
+k = \ x ->
+ case x of { I64# x# ->
+ case quotInt64# x# 10#Int64 of ds { __DEFAULT ->
+ case quotInt64# ds 20#Int64 of ds1 { __DEFAULT -> I64# ds1 }
+ }
+ }
+
+
+
diff --git a/testsuite/tests/primops/should_compile/all.T b/testsuite/tests/primops/should_compile/all.T
index 94ef2b5c4f..9ba0fe40e8 100644
--- a/testsuite/tests/primops/should_compile/all.T
+++ b/testsuite/tests/primops/should_compile/all.T
@@ -7,3 +7,4 @@ test('UnliftedMutVar_Comp', normal, compile, [''])
test('UnliftedStableName', normal, compile, [''])
test('KeepAliveWrapper', normal, compile, ['-O'])
test('T22152', normal, compile, [''])
+test('T22152b', normal, compile, [''])