summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorJoachim Breitner <mail@joachim-breitner.de>2015-03-02 21:20:24 +0100
committerJoachim Breitner <mail@joachim-breitner.de>2015-03-02 23:12:46 +0100
commitc3eee14d31585445d4a7eff5b6c69a815b911059 (patch)
tree052db8973e5fe2c1e4ce823f6f2b267b86dd8d8c /compiler
parent3197018d4efbf7407577300b88897cef26f7f4c6 (diff)
downloadhaskell-c3eee14d31585445d4a7eff5b6c69a815b911059.tar.gz
Improve if-then-else tree for cases on literal values
Previously, in the branch of the if-then-else tree, it would emit a final check if the scrut matches the alternative, even if earlier comparisons alread imply this equality. By keeping track of the bounds we can skip this check. Of course this is only sound for integer types. This closes #10129. Differential Revision: https://phabricator.haskell.org/D693
Diffstat (limited to 'compiler')
-rw-r--r--compiler/basicTypes/Literal.hs11
-rw-r--r--compiler/codeGen/StgCmmUtils.hs29
2 files changed, 34 insertions, 6 deletions
diff --git a/compiler/basicTypes/Literal.hs b/compiler/basicTypes/Literal.hs
index cb0be03402..8198f81078 100644
--- a/compiler/basicTypes/Literal.hs
+++ b/compiler/basicTypes/Literal.hs
@@ -30,6 +30,7 @@ module Literal
, inIntRange, inWordRange, tARGET_MAX_INT, inCharRange
, isZeroLit
, litFitsInChar
+ , onlyWithinBounds
-- ** Coercions
, word2IntLit, int2WordLit
@@ -359,6 +360,16 @@ litIsLifted :: Literal -> Bool
litIsLifted (LitInteger {}) = True
litIsLifted _ = False
+-- | x `onlyWithinBounds` (l,h) is true if l <= y < h ==> x = y
+onlyWithinBounds :: Literal -> (Literal, Literal) -> Bool
+onlyWithinBounds (MachChar x) (MachChar l, MachChar h) = x == l && succ x == h
+onlyWithinBounds (MachInt x) (MachInt l, MachInt h) = x == l && succ x == h
+onlyWithinBounds (MachWord x) (MachWord l, MachWord h) = x == l && succ x == h
+onlyWithinBounds (MachInt64 x) (MachInt64 l, MachInt64 h) = x == l && succ x == h
+onlyWithinBounds (MachWord64 x) (MachWord64 l, MachWord64 h) = x == l && succ x == h
+onlyWithinBounds _ _ = False
+
+
{-
Types
~~~~~
diff --git a/compiler/codeGen/StgCmmUtils.hs b/compiler/codeGen/StgCmmUtils.hs
index 5e8944df4a..763177f297 100644
--- a/compiler/codeGen/StgCmmUtils.hs
+++ b/compiler/codeGen/StgCmmUtils.hs
@@ -652,14 +652,21 @@ emitCmmLitSwitch scrut branches deflt = do
join_lbl <- newLabelC
deflt_lbl <- label_code join_lbl deflt
branches_lbls <- label_branches join_lbl branches
- emit =<< mk_lit_switch scrut' deflt_lbl
+ emit =<< mk_lit_switch scrut' deflt_lbl noBound
(sortBy (comparing fst) branches_lbls)
emitLabel join_lbl
+-- | lower bound (inclusive), upper bound (exclusive)
+type LitBound = (Maybe Literal, Maybe Literal)
+
+noBound :: LitBound
+noBound = (Nothing, Nothing)
+
mk_lit_switch :: CmmExpr -> BlockId
+ -> LitBound
-> [(Literal,BlockId)]
-> FCode CmmAGraph
-mk_lit_switch scrut deflt [(lit,blk)]
+mk_lit_switch scrut deflt bounds [(lit,blk)]
= do
dflags <- getDynFlags
let
@@ -667,12 +674,19 @@ mk_lit_switch scrut deflt [(lit,blk)]
cmm_ty = cmmLitType dflags cmm_lit
rep = typeWidth cmm_ty
ne = if isFloatType cmm_ty then MO_F_Ne rep else MO_Ne rep
- return (mkCbranch (CmmMachOp ne [scrut, CmmLit cmm_lit]) deflt blk)
-mk_lit_switch scrut deflt_blk_id branches
+ return $ if lit `onlyWithinBounds'` bounds
+ then mkBranch blk
+ else mkCbranch (CmmMachOp ne [scrut, CmmLit cmm_lit]) deflt blk
+ where
+ -- If the bounds already imply scrut == lit, then we can skip the final check (#10129)
+ l `onlyWithinBounds'` (Just lo, Just hi) = l `onlyWithinBounds` (lo, hi)
+ l `onlyWithinBounds'` _ = False
+
+mk_lit_switch scrut deflt_blk_id (lo_bound, hi_bound) branches
= do dflags <- getDynFlags
- lo_blk <- mk_lit_switch scrut deflt_blk_id lo_branches
- hi_blk <- mk_lit_switch scrut deflt_blk_id hi_branches
+ lo_blk <- mk_lit_switch scrut deflt_blk_id bounds_lo lo_branches
+ hi_blk <- mk_lit_switch scrut deflt_blk_id bounds_hi hi_branches
mkCmmIfThenElse (cond dflags) lo_blk hi_blk
where
n_branches = length branches
@@ -682,6 +696,9 @@ mk_lit_switch scrut deflt_blk_id branches
(lo_branches, hi_branches) = span is_lo branches
is_lo (t,_) = t < mid_lit
+ bounds_lo = (lo_bound, Just mid_lit)
+ bounds_hi = (Just mid_lit, hi_bound)
+
cond dflags = CmmMachOp (mkLtOp dflags mid_lit)
[scrut, CmmLit (mkSimpleLit dflags mid_lit)]