diff options
author | Sylvain Henry <sylvain@haskus.fr> | 2023-01-27 13:57:11 +0100 |
---|---|---|
committer | Marge Bot <ben+marge-bot@smart-cactus.org> | 2023-04-13 08:50:33 -0400 |
commit | 4dd021227559e1bc70cdaed12e45ff5459c33d27 (patch) | |
tree | 04497c322c430924c746102f1d679fed3e7396c0 | |
parent | 593218794199e23cdfc1a94200cbb9f404e28720 (diff) | |
download | haskell-4dd021227559e1bc70cdaed12e45ff5459c33d27.tar.gz |
Add quot folding rule (#22152)
(x / l1) / l2
l1 and l2 /= 0
l1*l2 doesn't overflow
==> x / (l1 * l2)
-rw-r--r-- | compiler/GHC/Core/Opt/ConstantFold.hs | 97 | ||||
-rw-r--r-- | compiler/GHC/Types/Literal.hs | 8 | ||||
-rw-r--r-- | testsuite/tests/primops/should_compile/T22152.stderr | 5 | ||||
-rw-r--r-- | testsuite/tests/primops/should_compile/T22152b.hs | 38 | ||||
-rw-r--r-- | testsuite/tests/primops/should_compile/T22152b.stderr | 47 | ||||
-rw-r--r-- | testsuite/tests/primops/should_compile/all.T | 1 |
6 files changed, 177 insertions, 19 deletions
diff --git a/compiler/GHC/Core/Opt/ConstantFold.hs b/compiler/GHC/Core/Opt/ConstantFold.hs index fb863d65cb..42ced5a86a 100644 --- a/compiler/GHC/Core/Opt/ConstantFold.hs +++ b/compiler/GHC/Core/Opt/ConstantFold.hs @@ -121,7 +121,9 @@ primOpRules nm = \case Int8QuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int8Op2 quot) , leftZero , rightIdentity oneI8 - , equalArgs $> Lit oneI8 ] + , equalArgs $> Lit oneI8 + , quotFoldingRules int8Ops + ] Int8RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int8Op2 rem) , leftZero , oneLit 1 $> Lit zeroI8 @@ -150,7 +152,9 @@ primOpRules nm = \case , mulFoldingRules Word8MulOp word8Ops ] Word8QuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word8Op2 quot) - , rightIdentity oneW8 ] + , rightIdentity oneW8 + , quotFoldingRules word8Ops + ] Word8RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word8Op2 rem) , leftZero , oneLit 1 $> Lit zeroW8 @@ -195,7 +199,9 @@ primOpRules nm = \case Int16QuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int16Op2 quot) , leftZero , rightIdentity oneI16 - , equalArgs $> Lit oneI16 ] + , equalArgs $> Lit oneI16 + , quotFoldingRules int16Ops + ] Int16RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int16Op2 rem) , leftZero , oneLit 1 $> Lit zeroI16 @@ -224,7 +230,9 @@ primOpRules nm = \case , mulFoldingRules Word16MulOp word16Ops ] Word16QuotOp-> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word16Op2 quot) - , rightIdentity oneW16 ] + , rightIdentity oneW16 + , quotFoldingRules word16Ops + ] Word16RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word16Op2 rem) , leftZero , oneLit 1 $> Lit zeroW16 @@ -269,7 +277,9 @@ primOpRules nm = \case Int32QuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int32Op2 quot) , leftZero , rightIdentity oneI32 - , equalArgs $> Lit oneI32 ] + , equalArgs $> Lit oneI32 + , quotFoldingRules int32Ops + ] Int32RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int32Op2 rem) , leftZero , oneLit 1 $> Lit zeroI32 @@ -298,7 +308,9 @@ primOpRules nm = \case , mulFoldingRules Word32MulOp word32Ops ] Word32QuotOp-> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word32Op2 quot) - , rightIdentity oneW32 ] + , rightIdentity oneW32 + , quotFoldingRules word32Ops + ] Word32RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word32Op2 rem) , leftZero , oneLit 1 $> Lit zeroW32 @@ -342,7 +354,9 @@ primOpRules nm = \case Int64QuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int64Op2 quot) , leftZero , rightIdentity oneI64 - , equalArgs $> Lit oneI64 ] + , equalArgs $> Lit oneI64 + , quotFoldingRules int64Ops + ] Int64RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int64Op2 rem) , leftZero , oneLit 1 $> Lit zeroI64 @@ -371,7 +385,9 @@ primOpRules nm = \case , mulFoldingRules Word64MulOp word64Ops ] Word64QuotOp-> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word64Op2 quot) - , rightIdentity oneW64 ] + , rightIdentity oneW64 + , quotFoldingRules word64Ops + ] Word64RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word64Op2 rem) , leftZero , oneLit 1 $> Lit zeroW64 @@ -452,7 +468,9 @@ primOpRules nm = \case IntQuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (intOp2 quot) , leftZero , rightIdentityPlatform onei - , equalArgs >> retLit onei ] + , equalArgs >> retLit onei + , quotFoldingRules intOps + ] IntRemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (intOp2 rem) , leftZero , oneLit 1 >> retLit zeroi @@ -504,7 +522,9 @@ primOpRules nm = \case , mulFoldingRules WordMulOp wordOps ] WordQuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (wordOp2 quot) - , rightIdentityPlatform onew ] + , rightIdentityPlatform onew + , quotFoldingRules wordOps + ] WordRemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (wordOp2 rem) , leftZero , oneLit 1 >> retLit zerow @@ -2653,6 +2673,14 @@ orFoldingRules num_ops = do (orFoldingRules' platform arg1 arg2 num_ops <|> orFoldingRules' platform arg2 arg1 num_ops) +quotFoldingRules :: NumOps -> RuleM CoreExpr +quotFoldingRules num_ops = do + env <- getRuleOpts + guard (roNumConstantFolding env) + [arg1,arg2] <- getArgs + platform <- getPlatform + liftMaybe (quotFoldingRules' platform arg1 arg2 num_ops) + addFoldingRules' :: Platform -> CoreExpr -> CoreExpr -> NumOps -> Maybe CoreExpr addFoldingRules' platform arg1 arg2 num_ops = case (arg1, arg2) of @@ -2943,6 +2971,29 @@ orFoldingRules' platform arg1 arg2 num_ops = case (arg1, arg2) of mkL = Lit . mkNumLiteral platform num_ops or x y = BinOpApp x (fromJust (numOr num_ops)) y +quotFoldingRules' :: Platform -> CoreExpr -> CoreExpr -> NumOps -> Maybe CoreExpr +quotFoldingRules' platform arg1 arg2 num_ops = case (arg1, arg2) of + + -- (x / l1) / l2 + -- l1 and l2 /= 0 + -- l1*l2 doesn't overflow + -- ==> x / (l1 * l2) + (is_div num_ops -> Just (x, L l1), L l2) + | l1 /= 0 + , l2 /= 0 + -- check that the result of the multiplication is in range + , Just l <- mkNumLiteralMaybe platform num_ops (l1 * l2) + -> Just (div x (Lit l)) + -- NB: we could directly return 0 or (-1) in case of overflow, + -- but we would need to know + -- (1) if we're dealing with a quot or a div operation + -- (2) if it's an underflow or an overflow. + -- Left as future work for now. + + _ -> Nothing + where + div x y = BinOpApp x (fromJust (numDiv num_ops)) y + is_binop :: PrimOp -> CoreExpr -> Maybe (Arg CoreBndr, Arg CoreBndr) is_binop op e = case e of BinOpApp x op' y | op == op' -> Just (x,y) @@ -2953,12 +3004,13 @@ is_op op e = case e of App (OpVal op') x | op == op' -> Just x _ -> Nothing -is_add, is_sub, is_mul, is_and, is_or :: NumOps -> CoreExpr -> Maybe (Arg CoreBndr, Arg CoreBndr) +is_add, is_sub, is_mul, is_and, is_or, is_div :: NumOps -> CoreExpr -> Maybe (Arg CoreBndr, Arg CoreBndr) is_add num_ops e = is_binop (numAdd num_ops) e is_sub num_ops e = is_binop (numSub num_ops) e is_mul num_ops e = is_binop (numMul num_ops) e is_and num_ops e = numAnd num_ops >>= \op -> is_binop op e is_or num_ops e = numOr num_ops >>= \op -> is_binop op e +is_div num_ops e = numDiv num_ops >>= \op -> is_binop op e is_neg :: NumOps -> CoreExpr -> Maybe (Arg CoreBndr) is_neg num_ops e = numNeg num_ops >>= \op -> is_op op e @@ -3007,6 +3059,7 @@ data NumOps = NumOps { numAdd :: !PrimOp -- ^ Add two numbers , numSub :: !PrimOp -- ^ Sub two numbers , numMul :: !PrimOp -- ^ Multiply two numbers + , numDiv :: !(Maybe PrimOp) -- ^ Divide two numbers , numAnd :: !(Maybe PrimOp) -- ^ And two numbers , numOr :: !(Maybe PrimOp) -- ^ Or two numbers , numNeg :: !(Maybe PrimOp) -- ^ Negate a number @@ -3017,15 +3070,20 @@ data NumOps = NumOps mkNumLiteral :: Platform -> NumOps -> Integer -> Literal mkNumLiteral platform ops i = mkLitNumberWrap platform (numLitType ops) i +-- | Create a numeric literal if it is in range +mkNumLiteralMaybe :: Platform -> NumOps -> Integer -> Maybe Literal +mkNumLiteralMaybe platform ops i = mkLitNumberMaybe platform (numLitType ops) i + int8Ops :: NumOps int8Ops = NumOps { numAdd = Int8AddOp , numSub = Int8SubOp , numMul = Int8MulOp - , numLitType = LitNumInt8 + , numDiv = Just Int8QuotOp , numAnd = Nothing , numOr = Nothing , numNeg = Just Int8NegOp + , numLitType = LitNumInt8 } word8Ops :: NumOps @@ -3033,6 +3091,7 @@ word8Ops = NumOps { numAdd = Word8AddOp , numSub = Word8SubOp , numMul = Word8MulOp + , numDiv = Just Word8QuotOp , numAnd = Just Word8AndOp , numOr = Just Word8OrOp , numNeg = Nothing @@ -3044,10 +3103,11 @@ int16Ops = NumOps { numAdd = Int16AddOp , numSub = Int16SubOp , numMul = Int16MulOp - , numLitType = LitNumInt16 + , numDiv = Just Int16QuotOp , numAnd = Nothing , numOr = Nothing , numNeg = Just Int16NegOp + , numLitType = LitNumInt16 } word16Ops :: NumOps @@ -3055,6 +3115,7 @@ word16Ops = NumOps { numAdd = Word16AddOp , numSub = Word16SubOp , numMul = Word16MulOp + , numDiv = Just Word16QuotOp , numAnd = Just Word16AndOp , numOr = Just Word16OrOp , numNeg = Nothing @@ -3066,10 +3127,11 @@ int32Ops = NumOps { numAdd = Int32AddOp , numSub = Int32SubOp , numMul = Int32MulOp - , numLitType = LitNumInt32 + , numDiv = Just Int32QuotOp , numAnd = Nothing , numOr = Nothing , numNeg = Just Int32NegOp + , numLitType = LitNumInt32 } word32Ops :: NumOps @@ -3077,6 +3139,7 @@ word32Ops = NumOps { numAdd = Word32AddOp , numSub = Word32SubOp , numMul = Word32MulOp + , numDiv = Just Word32QuotOp , numAnd = Just Word32AndOp , numOr = Just Word32OrOp , numNeg = Nothing @@ -3088,10 +3151,11 @@ int64Ops = NumOps { numAdd = Int64AddOp , numSub = Int64SubOp , numMul = Int64MulOp - , numLitType = LitNumInt64 + , numDiv = Just Int64QuotOp , numAnd = Nothing , numOr = Nothing , numNeg = Just Int64NegOp + , numLitType = LitNumInt64 } word64Ops :: NumOps @@ -3099,6 +3163,7 @@ word64Ops = NumOps { numAdd = Word64AddOp , numSub = Word64SubOp , numMul = Word64MulOp + , numDiv = Just Word64QuotOp , numAnd = Just Word64AndOp , numOr = Just Word64OrOp , numNeg = Nothing @@ -3110,6 +3175,7 @@ intOps = NumOps { numAdd = IntAddOp , numSub = IntSubOp , numMul = IntMulOp + , numDiv = Just IntQuotOp , numAnd = Just IntAndOp , numOr = Just IntOrOp , numNeg = Just IntNegOp @@ -3121,6 +3187,7 @@ wordOps = NumOps { numAdd = WordAddOp , numSub = WordSubOp , numMul = WordMulOp + , numDiv = Just WordQuotOp , numAnd = Just WordAndOp , numOr = Just WordOrOp , numNeg = Nothing diff --git a/compiler/GHC/Types/Literal.hs b/compiler/GHC/Types/Literal.hs index 7581c10602..1bb9ddb31b 100644 --- a/compiler/GHC/Types/Literal.hs +++ b/compiler/GHC/Types/Literal.hs @@ -32,7 +32,7 @@ module GHC.Types.Literal , mkLitFloat, mkLitDouble , mkLitChar, mkLitString , mkLitBigNat - , mkLitNumber, mkLitNumberWrap + , mkLitNumber, mkLitNumberWrap, mkLitNumberMaybe -- ** Operations on Literals , literalType @@ -411,6 +411,12 @@ mkLitNumber platform nt i = assertPpr (litNumCheckRange platform nt i) (integer i) (LitNumber nt i) +-- | Create a numeric 'Literal' of the given type if it is in range +mkLitNumberMaybe :: Platform -> LitNumType -> Integer -> Maybe Literal +mkLitNumberMaybe platform nt i + | litNumCheckRange platform nt i = Just (LitNumber nt i) + | otherwise = Nothing + -- | Creates a 'Literal' of type @Int#@ mkLitInt :: Platform -> Integer -> Literal mkLitInt platform x = assertPpr (platformInIntRange platform x) (integer x) diff --git a/testsuite/tests/primops/should_compile/T22152.stderr b/testsuite/tests/primops/should_compile/T22152.stderr index 505bca04a7..33ff7721f6 100644 --- a/testsuite/tests/primops/should_compile/T22152.stderr +++ b/testsuite/tests/primops/should_compile/T22152.stderr @@ -1,10 +1,9 @@ ==================== Tidy Core ==================== Result size of Tidy Core - = {terms: 11, types: 5, coercions: 0, joins: 0/0} + = {terms: 9, types: 5, coercions: 0, joins: 0/0} -toHours - = \ t -> case t of { I# x -> I# (quotInt# (quotInt# x 60#) 60#) } +toHours = \ t -> case t of { I# x -> I# (quotInt# x 3600#) } diff --git a/testsuite/tests/primops/should_compile/T22152b.hs b/testsuite/tests/primops/should_compile/T22152b.hs new file mode 100644 index 0000000000..f6ee4fce28 --- /dev/null +++ b/testsuite/tests/primops/should_compile/T22152b.hs @@ -0,0 +1,38 @@ +{-# OPTIONS_GHC -O2 -ddump-simpl -dno-typeable-binds -dsuppress-all -dsuppress-uniques #-} +module T22152b where + +import Data.Int +import Data.Word + +a :: Int32 -> Int32 +a x = (x `quot` maxBound) `quot` maxBound -- overflow, mustn't trigger the rewrite rule + +b :: Int -> Int +b x = (x `quot` 10) `quot` 20 + +c :: Word -> Word +c x = (x `quot` 10) `quot` 20 + +d :: Word8 -> Word8 +d x = (x `quot` 10) `quot` 20 + +e :: Word16 -> Word16 +e x = (x `quot` 10) `quot` 20 + +f :: Word32 -> Word32 +f x = (x `quot` 10) `quot` 20 + +g :: Word64 -> Word64 +g x = (x `quot` 10) `quot` 20 + +h :: Int8 -> Int8 +h x = (x `quot` 10) `quot` 20 + +i :: Int16 -> Int16 +i x = (x `quot` 10) `quot` 20 + +j :: Int32 -> Int32 +j x = (x `quot` 10) `quot` 20 + +k :: Int64 -> Int64 +k x = (x `quot` 10) `quot` 20 diff --git a/testsuite/tests/primops/should_compile/T22152b.stderr b/testsuite/tests/primops/should_compile/T22152b.stderr new file mode 100644 index 0000000000..0cf317cc32 --- /dev/null +++ b/testsuite/tests/primops/should_compile/T22152b.stderr @@ -0,0 +1,47 @@ + +==================== Tidy Core ==================== +Result size of Tidy Core + = {terms: 119, types: 59, coercions: 0, joins: 0/0} + +b = \ x -> case x of { I# x1 -> I# (quotInt# x1 200#) } + +c = \ x -> case x of { W# x# -> W# (quotWord# x# 200##) } + +d = \ x -> case x of { W8# x# -> W8# (quotWord8# x# 200#Word8) } + +e = \ x -> + case x of { W16# x# -> W16# (quotWord16# x# 200#Word16) } + +f = \ x -> + case x of { W32# x# -> W32# (quotWord32# x# 200#Word32) } + +g = \ x -> + case x of { W64# x# -> + case quotWord64# x# 10#Word64 of ds1 { __DEFAULT -> + case quotWord64# ds1 20#Word64 of ds2 { __DEFAULT -> W64# ds2 } + } + } + +h = \ x -> + case x of { I8# x# -> + I8# (quotInt8# (quotInt8# x# 10#Int8) 20#Int8) + } + +i = \ x -> case x of { I16# x# -> I16# (quotInt16# x# 200#Int16) } + +j = \ x -> case x of { I32# x# -> I32# (quotInt32# x# 200#Int32) } + +a = \ x -> + case x of { I32# x# -> + I32# (quotInt32# (quotInt32# x# 2147483647#Int32) 2147483647#Int32) + } + +k = \ x -> + case x of { I64# x# -> + case quotInt64# x# 10#Int64 of ds { __DEFAULT -> + case quotInt64# ds 20#Int64 of ds1 { __DEFAULT -> I64# ds1 } + } + } + + + diff --git a/testsuite/tests/primops/should_compile/all.T b/testsuite/tests/primops/should_compile/all.T index 94ef2b5c4f..9ba0fe40e8 100644 --- a/testsuite/tests/primops/should_compile/all.T +++ b/testsuite/tests/primops/should_compile/all.T @@ -7,3 +7,4 @@ test('UnliftedMutVar_Comp', normal, compile, ['']) test('UnliftedStableName', normal, compile, ['']) test('KeepAliveWrapper', normal, compile, ['-O']) test('T22152', normal, compile, ['']) +test('T22152b', normal, compile, ['']) |