summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compiler/basicTypes/Literal.hs41
-rw-r--r--compiler/main/DynFlags.hs3
-rw-r--r--compiler/prelude/PrelRules.hs57
-rw-r--r--compiler/simplCore/SimplUtils.hs76
-rw-r--r--docs/users_guide/using-optimisation.rst21
-rw-r--r--testsuite/tests/perf/compiler/T12877.hs117
-rw-r--r--testsuite/tests/perf/compiler/T12877.stdout1
-rw-r--r--testsuite/tests/perf/compiler/all.T13
-rw-r--r--utils/mkUserGuidePart/Options/Optimizations.hs5
9 files changed, 322 insertions, 12 deletions
diff --git a/compiler/basicTypes/Literal.hs b/compiler/basicTypes/Literal.hs
index 813759628c..14ef785905 100644
--- a/compiler/basicTypes/Literal.hs
+++ b/compiler/basicTypes/Literal.hs
@@ -29,7 +29,7 @@ module Literal
, inIntRange, inWordRange, tARGET_MAX_INT, inCharRange
, isZeroLit
, litFitsInChar
- , litValue
+ , litValue, isLitValue, isLitValue_maybe, mapLitValue
-- ** Coercions
, word2IntLit, int2WordLit
@@ -59,6 +59,7 @@ import Data.ByteString (ByteString)
import Data.Int
import Data.Word
import Data.Char
+import Data.Maybe ( isJust )
import Data.Data ( Data )
import Numeric ( fromRat )
@@ -271,13 +272,37 @@ isZeroLit _ = False
-- | Returns the 'Integer' contained in the 'Literal', for when that makes
-- sense, i.e. for 'Char', 'Int', 'Word' and 'LitInteger'.
litValue :: Literal -> Integer
-litValue (MachChar c) = toInteger $ ord c
-litValue (MachInt i) = i
-litValue (MachInt64 i) = i
-litValue (MachWord i) = i
-litValue (MachWord64 i) = i
-litValue (LitInteger i _) = i
-litValue l = pprPanic "litValue" (ppr l)
+litValue l = case isLitValue_maybe l of
+ Just x -> x
+ Nothing -> pprPanic "litValue" (ppr l)
+
+-- | Returns 'Just' the 'Integer' contained in the 'Literal' when that makes
+-- sense, i.e. for 'Char', 'Int', 'Word' and 'LitInteger', and 'Nothing'
+-- otherwise.
+isLitValue_maybe :: Literal -> Maybe Integer
+isLitValue_maybe (MachChar c) = Just $ toInteger $ ord c
+isLitValue_maybe (MachInt i) = Just i
+isLitValue_maybe (MachInt64 i) = Just i
+isLitValue_maybe (MachWord i) = Just i
+isLitValue_maybe (MachWord64 i) = Just i
+isLitValue_maybe (LitInteger i _) = Just i
+isLitValue_maybe _ = Nothing
+
+-- | Apply a function to the 'Integer' contained in the 'Literal', for when that
+-- makes sense, e.g. for 'Char', 'Int', 'Word' and 'LitInteger'. Panics on any
+-- other kind of 'Literal'.
+mapLitValue :: (Integer -> Integer) -> Literal -> Literal
+mapLitValue f (MachChar c) = MachChar (fchar c)
+ where fchar = chr . fromInteger . f . toInteger . ord
+mapLitValue f (MachInt i) = MachInt (f i)
+mapLitValue f (MachInt64 i) = MachInt64 (f i)
+mapLitValue f (MachWord i) = MachWord (f i)
+mapLitValue f (MachWord64 i) = MachWord64 (f i)
+mapLitValue f (LitInteger i t) = LitInteger (f i) t
+mapLitValue _ l = pprPanic "mapLitValue" (ppr l)
+
+-- | Indicate if the 'Literal' contains an 'Integer' value, e.g. 'Char',
+-- 'Int', 'Word' and 'LitInteger'.
+isLitValue :: Literal -> Bool
+isLitValue = isJust . isLitValue_maybe
{-
Coercions
diff --git a/compiler/main/DynFlags.hs b/compiler/main/DynFlags.hs
index cbf247c49d..d7cde29557 100644
--- a/compiler/main/DynFlags.hs
+++ b/compiler/main/DynFlags.hs
@@ -445,6 +445,7 @@ data GeneralFlag
| Opt_IgnoreAsserts
| Opt_DoEtaReduction
| Opt_CaseMerge
+ | Opt_CaseFolding -- Constant folding through case-expressions
| Opt_UnboxStrictFields
| Opt_UnboxSmallStrictFields
| Opt_DictsCheap
@@ -3561,6 +3562,7 @@ fFlagsDeps = [
flagSpec "building-cabal-package" Opt_BuildingCabalPackage,
flagSpec "call-arity" Opt_CallArity,
flagSpec "case-merge" Opt_CaseMerge,
+ flagSpec "case-folding" Opt_CaseFolding,
flagSpec "cmm-elim-common-blocks" Opt_CmmElimCommonBlocks,
flagSpec "cmm-sink" Opt_CmmSink,
flagSpec "cse" Opt_CSE,
@@ -4012,6 +4014,7 @@ optLevelFlags -- see Note [Documenting optimisation flags]
, ([1,2], Opt_CallArity)
, ([1,2], Opt_CaseMerge)
+ , ([1,2], Opt_CaseFolding)
, ([1,2], Opt_CmmElimCommonBlocks)
, ([1,2], Opt_CmmSink)
, ([1,2], Opt_CSE)
diff --git a/compiler/prelude/PrelRules.hs b/compiler/prelude/PrelRules.hs
index 8868047005..e98fd9f6a3 100644
--- a/compiler/prelude/PrelRules.hs
+++ b/compiler/prelude/PrelRules.hs
@@ -15,7 +15,12 @@ ToDo:
{-# LANGUAGE CPP, RankNTypes #-}
{-# OPTIONS_GHC -optc-DNON_POSIX_SOURCE #-}
-module PrelRules ( primOpRules, builtinRules ) where
+module PrelRules
+ ( primOpRules
+ , builtinRules
+ , caseRules
+ )
+where
#include "HsVersions.h"
#include "../includes/MachDeps.h"
@@ -1385,3 +1390,53 @@ match_smallIntegerTo primOp _ _ _ [App (Var x) y]
| idName x == smallIntegerName
= Just $ App (Var (mkPrimOpId primOp)) y
match_smallIntegerTo _ _ _ _ _ = Nothing
+
+
+
+--------------------------------------------------------
+-- Constant folding through case-expressions
+--
+-- cf Scrutinee Constant Folding in simplCore/SimplUtils
+--------------------------------------------------------
+
+-- | Match the scrutinee of a case and potentially return a new scrutinee and a
+-- function to apply to each literal alternative.
+caseRules :: CoreExpr -> Maybe (CoreExpr, Integer -> Integer)
+caseRules scrut = case scrut of
+
+ -- v `op` x#
+ App (App (Var f) v) (Lit l)
+ | Just op <- isPrimOpId_maybe f
+ , Just x <- isLitValue_maybe l ->
+ case op of
+ WordAddOp -> Just (v, \y -> y-x )
+ IntAddOp -> Just (v, \y -> y-x )
+ WordSubOp -> Just (v, \y -> y+x )
+ IntSubOp -> Just (v, \y -> y+x )
+ XorOp -> Just (v, \y -> y `xor` x)
+ XorIOp -> Just (v, \y -> y `xor` x)
+ _ -> Nothing
+
+ -- x# `op` v
+ App (App (Var f) (Lit l)) v
+ | Just op <- isPrimOpId_maybe f
+ , Just x <- isLitValue_maybe l ->
+ case op of
+ WordAddOp -> Just (v, \y -> y-x )
+ IntAddOp -> Just (v, \y -> y-x )
+ WordSubOp -> Just (v, \y -> x-y )
+ IntSubOp -> Just (v, \y -> x-y )
+ XorOp -> Just (v, \y -> y `xor` x)
+ XorIOp -> Just (v, \y -> y `xor` x)
+ _ -> Nothing
+
+ -- op v
+ App (Var f) v
+ | Just op <- isPrimOpId_maybe f ->
+ case op of
+ NotOp -> Just (v, \y -> complement y)
+ NotIOp -> Just (v, \y -> complement y)
+ IntNegOp -> Just (v, \y -> negate y )
+ _ -> Nothing
+
+ _ -> Nothing
diff --git a/compiler/simplCore/SimplUtils.hs b/compiler/simplCore/SimplUtils.hs
index 48dce1d090..6c4737507a 100644
--- a/compiler/simplCore/SimplUtils.hs
+++ b/compiler/simplCore/SimplUtils.hs
@@ -60,6 +60,8 @@ import Util
import MonadUtils
import Outputable
import Pair
+import PrelRules
+import Literal
import Control.Monad ( when )
@@ -1752,9 +1754,46 @@ mkCase tries these things
False -> False
and similar friends.
+
+3. Scrutinee Constant Folding
+
+ case x op# k# of _ { ===> case x of _ {
+ a1# -> e1 (a1# inv_op# k#) -> e1
+ a2# -> e2 (a2# inv_op# k#) -> e2
+ ... ...
+ DEFAULT -> ed DEFAULT -> ed
+
+ where (x op# k#) inv_op# k# == x
+
+ And similarly for commuted arguments and for some unary operations.
+
+ The purpose of this transformation is not only to avoid an arithmetic
+ operation at runtime but to allow other transformations to apply in cascade.
+
+ Example with the "Merge Nested Cases" optimization (from #12877):
+
+ main = case t of t0
+ 0## -> ...
+ DEFAULT -> case t0 `minusWord#` 1## of t1
+ 0## -> ...
+ DEFAULT -> case t1 `minusWord#` 1## of t2
+ 0## -> ...
+ DEFAULT -> case t2 `minusWord#` 1## of _
+ 0## -> ...
+ DEFAULT -> ...
+
+ becomes:
+
+ main = case t of _
+ 0## -> ...
+ 1## -> ...
+ 2## -> ...
+ 3## -> ...
+ DEFAULT -> ...
+
-}
-mkCase, mkCase1, mkCase2
+mkCase, mkCase1, mkCase2, mkCase3
:: DynFlags
-> OutExpr -> OutId
-> OutType -> [OutAlt] -- Alternatives in standard (increasing) order
@@ -1848,9 +1887,42 @@ mkCase1 _dflags scrut case_bndr _ alts@((_,_,rhs1) : _) -- Identity case
mkCase1 dflags scrut bndr alts_ty alts = mkCase2 dflags scrut bndr alts_ty alts
--------------------------------------------------
+-- 3. Scrutinee Constant Folding
+--------------------------------------------------
+
+mkCase2 dflags scrut bndr alts_ty alts
+ | gopt Opt_CaseFolding dflags
+ , Just (scrut',f) <- caseRules scrut
+ = mkCase3 dflags scrut' bndr alts_ty (map (mapAlt f) alts)
+ | otherwise
+ = mkCase3 dflags scrut bndr alts_ty alts
+ where
+ -- We need to keep the correct association between the scrutinee and its
+ -- binder if the latter isn't dead. Hence we wrap rhs of alternatives with
+ -- "let bndr = ... in":
+ --
+ -- case v + 10 of y =====> case v of y
+ -- 20 -> e1 10 -> let y = 20 in e1
+ -- DEFAULT -> e2 DEFAULT -> let y = v + 10 in e2
+ --
+ -- Other transformations give: =====> case v of y'
+ -- 10 -> let y = 20 in e1
+ -- DEFAULT -> let y = y' + 10 in e2
+ --
+ wrap_rhs l rhs
+ | isDeadBinder bndr = rhs
+ | otherwise = Let (NonRec bndr l) rhs
+
+ mapAlt f alt@(c,bs,e) = case c of
+ DEFAULT -> (c, bs, wrap_rhs scrut e)
+ LitAlt l
+ | isLitValue l -> (LitAlt (mapLitValue f l), bs, wrap_rhs (Lit l) e)
+ _ -> pprPanic "Unexpected alternative (mkCase2)" (ppr alt)
+
+--------------------------------------------------
-- Catch-all
--------------------------------------------------
-mkCase2 _dflags scrut bndr alts_ty alts
+mkCase3 _dflags scrut bndr alts_ty alts
= return (Case scrut bndr alts_ty alts)
{-
diff --git a/docs/users_guide/using-optimisation.rst b/docs/users_guide/using-optimisation.rst
index 6b58093513..3e660c19e9 100644
--- a/docs/users_guide/using-optimisation.rst
+++ b/docs/users_guide/using-optimisation.rst
@@ -115,7 +115,7 @@ list.
:default: on
- Merge immediately-nested case expressions that scrutinse the same variable.
+ Merge immediately-nested case expressions that scrutinise the same variable.
For example, ::
case x of
@@ -131,6 +131,25 @@ list.
Blue -> e2
Green -> e2
+.. ghc-flag:: -fcase-folding
+
+ :default: on
+
+ Allow constant folding in case expressions that scrutinise some primops.
+ For example, ::
+
+ case x `minusWord#` 10## of
+ 10## -> e1
+ 20## -> e2
+ v -> e3
+
+ Is transformed to, ::
+
+ case x of
+ 20## -> e1
+ 30## -> e2
+ _ -> let v = x `minusWord#` 10## in e3
+
.. ghc-flag:: -fcall-arity
:default: on
diff --git a/testsuite/tests/perf/compiler/T12877.hs b/testsuite/tests/perf/compiler/T12877.hs
new file mode 100644
index 0000000000..2fc7d58dd4
--- /dev/null
+++ b/testsuite/tests/perf/compiler/T12877.hs
@@ -0,0 +1,117 @@
+-- This ugly cascading case reduces to:
+-- case x of
+-- 0 -> "0"
+-- 1 -> "1"
+-- _ -> "n"
+--
+-- but only if GHC's case-folding reduction kicks in.
+
+{-# NOINLINE test #-}
+test :: Word -> String
+test x = case x of
+ 0 -> "0"
+ 1 -> "1"
+ t -> case t + 1 of
+ 1 -> "0"
+ 2 -> "1"
+ t -> case t + 1 of
+ 2 -> "0"
+ 3 -> "1"
+ t -> case t + 1 of
+ 3 -> "0"
+ 4 -> "1"
+ t -> case t + 1 of
+ 4 -> "0"
+ 5 -> "1"
+ t -> case t + 1 of
+ 5 -> "0"
+ 6 -> "1"
+ t -> case t + 1 of
+ 6 -> "0"
+ 7 -> "1"
+ t -> case t + 1 of
+ 7 -> "0"
+ 8 -> "1"
+ t -> case t + 1 of
+ 8 -> "0"
+ 9 -> "1"
+ t -> case t + 1 of
+ 10 -> "0"
+ 11 -> "1"
+ t -> case t + 1 of
+ 11 -> "0"
+ 12 -> "1"
+ t -> case t + 1 of
+ 12 -> "0"
+ 13 -> "1"
+ t -> case t + 1 of
+ 13 -> "0"
+ 14 -> "1"
+ t -> case t + 1 of
+ 14 -> "0"
+ 15 -> "1"
+ t -> case t + 1 of
+ 15 -> "0"
+ 16 -> "1"
+ t -> case t + 1 of
+ 16 -> "0"
+ 17 -> "1"
+ t -> case t + 1 of
+ 17 -> "0"
+ 18 -> "1"
+ t -> case t + 1 of
+ 18 -> "0"
+ 19 -> "1"
+ t -> case t + 1 of
+ 19 -> "0"
+ 20 -> "1"
+ t -> case t + 1 of
+ 20 -> "0"
+ 21 -> "1"
+ t -> case t + 1 of
+ 21 -> "0"
+ 22 -> "1"
+ t -> case t + 1 of
+ 22 -> "0"
+ 23 -> "1"
+ t -> case t + 1 of
+ 23 -> "0"
+ 24 -> "1"
+ t -> case t + 1 of
+ 24 -> "0"
+ 25 -> "1"
+ t -> case t + 1 of
+ 25 -> "0"
+ 26 -> "1"
+ t -> case t + 1 of
+ 26 -> "0"
+ 27 -> "1"
+ t -> case t + 1 of
+ 27 -> "0"
+ 28 -> "1"
+ t -> case t + 1 of
+ 28 -> "0"
+ 29 -> "1"
+ t -> case t + 1 of
+ 29 -> "0"
+ 30 -> "1"
+ t -> case t + 1 of
+ 30 -> "0"
+ 31 -> "1"
+ t -> case t + 1 of
+ 31 -> "0"
+ 32 -> "1"
+ t -> case t + 1 of
+ 32 -> "0"
+ 33 -> "1"
+ t -> case t + 1 of
+ 33 -> "0"
+ 34 -> "1"
+ t -> case t + 1 of
+ 34 -> "0"
+ 35 -> "1"
+ _ -> "n"
+
+main :: IO ()
+main = do
+ putStrLn [last (concat (fmap test [0..12345678]))]
diff --git a/testsuite/tests/perf/compiler/T12877.stdout b/testsuite/tests/perf/compiler/T12877.stdout
new file mode 100644
index 0000000000..8ba3a16384
--- /dev/null
+++ b/testsuite/tests/perf/compiler/T12877.stdout
@@ -0,0 +1 @@
+n
diff --git a/testsuite/tests/perf/compiler/all.T b/testsuite/tests/perf/compiler/all.T
index 0ccde15106..38cbdd0311 100644
--- a/testsuite/tests/perf/compiler/all.T
+++ b/testsuite/tests/perf/compiler/all.T
@@ -895,3 +895,16 @@ test('T12234',
compile,
[''])
+test('T12877',
+ [ stats_num_field('bytes allocated',
+ [(wordsize(64), 197582248, 5),
+ # initial: 197582248 (Linux)
+ ])
+ , compiler_stats_num_field('bytes allocated',
+ [(wordsize(64), 135979000, 5),
+ # initial: 135979000 (Linux)
+ ]),
+ ],
+ compile_and_run,
+ ['-O2'])
+
diff --git a/utils/mkUserGuidePart/Options/Optimizations.hs b/utils/mkUserGuidePart/Options/Optimizations.hs
index 29d35a0377..b0f9bc5ac8 100644
--- a/utils/mkUserGuidePart/Options/Optimizations.hs
+++ b/utils/mkUserGuidePart/Options/Optimizations.hs
@@ -15,6 +15,11 @@ optimizationsOptions =
, flagType = DynamicFlag
, flagReverse = "-fno-case-merge"
}
+ , flag { flagName = "-fcase-folding"
+ , flagDescription = "Enable constant folding in case expressions. Implied by :ghc-flag:`-O`."
+ , flagType = DynamicFlag
+ , flagReverse = "-fno-case-folding"
+ }
, flag { flagName = "-fcmm-elim-common-blocks"
, flagDescription =
"Enable Cmm common block elimination. Implied by :ghc-flag:`-O`."