summaryrefslogtreecommitdiff
path: root/compiler/simplCore/FloatIn.hs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/simplCore/FloatIn.hs')
-rw-r--r--compiler/simplCore/FloatIn.hs139
1 files changed, 110 insertions, 29 deletions
diff --git a/compiler/simplCore/FloatIn.hs b/compiler/simplCore/FloatIn.hs
index 2593b1d7a1..04e4d32f5e 100644
--- a/compiler/simplCore/FloatIn.hs
+++ b/compiler/simplCore/FloatIn.hs
@@ -26,8 +26,9 @@ import MkCore
import HscTypes ( ModGuts(..) )
import CoreUtils
import CoreFVs
+import CoreUnfold
import CoreMonad ( CoreM )
-import Id ( isOneShotBndr, idType, isJoinId, isJoinId_maybe )
+import Id
import Var
import Type
import VarSet
@@ -151,7 +152,7 @@ fiExpr dflags to_drop (_, AnnCast expr (co_ann, co))
Cast (fiExpr dflags e_drop expr) co
where
[drop_here, e_drop, co_drop]
- = sepBindsByDropPoint dflags False
+ = sepBindsByDropPoint dflags SepVanilla
[freeVarsOf expr, freeVarsOfAnn co_ann]
to_drop
@@ -173,7 +174,7 @@ fiExpr dflags to_drop ann_expr@(_,AnnApp {})
arg_fvs = map freeVarsOf ann_args
(drop_here : extra_drop : fun_drop : arg_drops)
- = sepBindsByDropPoint dflags False
+ = sepBindsByDropPoint dflags SepVanilla
(extra_fvs : fun_fvs : arg_fvs)
to_drop
-- Shortcut behaviour: if to_drop is empty,
@@ -446,7 +447,7 @@ fiExpr dflags to_drop (_, AnnCase scrut case_bndr _ [(con,alt_bndrs,rhs)])
scrut_fvs = freeVarsOf scrut
[shared_binds, scrut_binds, rhs_binds]
- = sepBindsByDropPoint dflags False
+ = sepBindsByDropPoint dflags SepVanilla
[scrut_fvs, rhs_fvs]
to_drop
@@ -456,16 +457,17 @@ fiExpr dflags to_drop (_, AnnCase scrut case_bndr ty alts)
Case (fiExpr dflags scrut_drops scrut) case_bndr ty
(zipWith fi_alt alts_drops_s alts)
where
- -- Float into the scrut and alts-considered-together just like App
+ -- Float into the scrut and alts-considered-together just like App
[drop_here1, scrut_drops, alts_drops]
- = sepBindsByDropPoint dflags False
+ = sepBindsByDropPoint dflags SepVanilla
[scrut_fvs, all_alts_fvs]
to_drop
- -- Float into the alts with the is_case flag set
+ -- Float into the alts with the SepCase context set
(drop_here2 : alts_drops_s)
| [ _ ] <- alts = [] : [alts_drops]
- | otherwise = sepBindsByDropPoint dflags True alts_fvs alts_drops
+ | otherwise = sepBindsByDropPoint dflags SepCase
+ alts_fvs alts_drops
scrut_fvs = freeVarsOf scrut
alts_fvs = map alt_fvs alts
@@ -491,7 +493,7 @@ fiBind dflags to_drop (AnnNonRec id ann_rhs@(rhs_fvs, rhs)) body_fvs
= ( extra_binds ++ shared_binds -- Land these before
-- See Note [extra_fvs (1,2)]
, FB (unitDVarSet id) rhs_fvs' -- The new binding itself
- (FloatLet (NonRec id rhs'))
+ (FloatLet (NonRec id rhs'))
, body_binds ) -- Land these after
where
@@ -508,7 +510,8 @@ fiBind dflags to_drop (AnnNonRec id ann_rhs@(rhs_fvs, rhs)) body_fvs
-- But do float into join points
[shared_binds, extra_binds, rhs_binds, body_binds]
- = sepBindsByDropPoint dflags False
+ = sepBindsByDropPoint dflags
+ (if isJoinId id then SepNonRecJoin else SepVanilla)
[extra_fvs, rhs_fvs, body_fvs2]
to_drop
@@ -533,7 +536,7 @@ fiBind dflags to_drop (AnnRec bindings) body_fvs
, noFloatIntoRhs Recursive bndr rhs ]
(shared_binds:extra_binds:body_binds:rhss_binds)
- = sepBindsByDropPoint dflags False
+ = sepBindsByDropPoint dflags SepVanilla
(extra_fvs:body_fvs:rhss_fvs)
to_drop
@@ -654,9 +657,41 @@ We have to maintain the order on these drop-point-related lists.
-- pprFIB :: FloatInBinds -> SDoc
-- pprFIB fibs = text "FIB:" <+> ppr [b | FB _ _ b <- fibs]
+data SepCtxt
+ = SepCase
+ | SepNonRecJoin
+ | SepVanilla
+
+{- Note [Floating join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+We push join-point bindings inwards merrily, just like let-bindings.
+They may get floated out again; e.g.
+ join j1 x = e1
+ in join j2 y = ...j1...
+ in ...
+==>
+ join j2 y = join { j1 x = e1 } in ...j1...
+ in ...
+
+Here we might float j1 out again. But we must float it in in case it
+allows an ordinary let-binding to go too. E.g.
+ let x = <thunk>
+ in join j1 x = e1
+ in join j2 y = ...j1...
+ in ...
+===>
+ join j2 y = let { x = <thunk }
+ in join { j1 x = e1 }
+ in ...j1...
+ in ...
+
+Ths is important; now the thunk for 'x' may not be allocated on the
+paths that don't involve j2.
+-}
+
sepBindsByDropPoint
:: DynFlags
- -> Bool -- True <=> is case expression
+ -> SepCtxt
-> [FreeVarSet] -- One set of FVs per drop point
-- Always at least two long!
-> FloatInBinds -- Candidate floaters
@@ -672,15 +707,15 @@ sepBindsByDropPoint
type DropBox = (FreeVarSet, FloatInBinds)
-sepBindsByDropPoint dflags is_case drop_pts floaters
+sepBindsByDropPoint dflags sep_ctxt drop_pts floaters
| null floaters -- Shortcut common case
= [] : [[] | _ <- drop_pts]
| otherwise
- = ASSERT( drop_pts `lengthAtLeast` 2 )
+ = ASSERT( n_alts >= 2 ) -- Invariant on caller
go floaters (map (\fvs -> (fvs, [])) (emptyDVarSet : drop_pts))
where
- n_alts = length drop_pts
+ n_alts = length drop_pts -- n_alts >= 2
go :: FloatInBinds -> [DropBox] -> [FloatInBinds]
-- The *first* one in the argument list is the drop_here set
@@ -697,18 +732,29 @@ sepBindsByDropPoint dflags is_case drop_pts floaters
(used_here : used_in_flags) = [ fvs `intersectsDVarSet` bndrs
| (fvs, _) <- drop_boxes]
- drop_here = used_here || cant_push
+ drop_here = used_here || not want_push
+ want_push = case sep_ctxt of
+ SepCase -> want_case_push
+ SepNonRecJoin -> want_join_push
+ SepVanilla -> want_let_push
n_used_alts = count id used_in_flags -- returns number of Trues in list.
- cant_push
- | is_case = n_used_alts == n_alts -- Used in all, don't push
+ no_duplication = n_used_alts <= 1 -- See Note [Duplicating floats]
+
+ duplicable_float = floatIsDupable dflags bind
+ -- True <=> duplication does not dup much code
+ -- (but it might still duplicate work!)
+ -- See Note [Duplicating floats]
+
+ want_case_push = -- n_used_alts < n_alts && -- Used in all alts, don't push
-- Remember n_alts > 1
- || (n_used_alts > 1 && not (floatIsDupable dflags bind))
- -- floatIsDupable: see Note [Duplicating floats]
+ (no_duplication || duplicable_float)
- | otherwise = floatIsCase bind || n_used_alts > 1
- -- floatIsCase: see Note [Floating primops]
+ want_let_push = not (floatIsCase bind) -- See Note [Floating primops]
+ && no_duplication
+
+ want_join_push = no_duplication -- See Note [Floating join points]
new_boxes | drop_here = (insert here_box : fork_boxes)
| otherwise = (here_box : new_fork_boxes)
@@ -727,18 +773,43 @@ sepBindsByDropPoint dflags is_case drop_pts floaters
{- Note [Duplicating floats]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+no_duplication is true if the binding us used in at most one
+alternative. (Zero is rare; it means the binding is dead.)
+
+If no_duplication is false, we may still float:
-For case expressions we duplicate the binding if it is reasonably
-small, and if it is not used in all the RHSs This is good for
-situations like
+* For /case expressions/ only (SepCase) we duplicate the binding if it
+ is reasonably small, and if it is not used in all the RHSs. This is
+ good for situations like
let x = I# y in
case e of
C -> error x
D -> error x
E -> ...not mentioning x...
-If the thing is used in all RHSs there is nothing gained,
-so we don't duplicate then.
+ If the thing is used in all RHSs there is nothing gained,
+ so we don't duplicate then.
+
+* This is NOT GOOD for other float-in places, like lets (SepVanilla).
+ Consider
+ let x = <small> in
+ let v = ...x...
+ in ...x...
+
+ We definitely don't want to duplicate x into the RHS of v and the
+ body! At least, it would be OK if <small> was a value; but we don't
+ test that.
+
+* For non-recursive join bindings (SepNonRecJoin) we must be equally
+ careful. Eg
+ let x = <small> in
+ join j = ...x...
+ in case f x of
+ A -> j
+ B -> something else
+ C -> j
+ Here we must not duplicate the let-x binding into the RHS of j
+ and the body, or we'll duplicate the redex.
-}
floatedBindsFVs :: FloatInBinds -> FreeVarSet
@@ -754,9 +825,19 @@ wrapFloats (FB _ _ fl : bs) e = wrapFloats bs (wrapFloat fl e)
floatIsDupable :: DynFlags -> FloatBind -> Bool
floatIsDupable dflags (FloatCase scrut _ _ _) = exprIsDupable dflags scrut
-floatIsDupable dflags (FloatLet (Rec prs)) = all (exprIsDupable dflags . snd) prs
-floatIsDupable dflags (FloatLet (NonRec _ r)) = exprIsDupable dflags r
+floatIsDupable dflags (FloatLet (Rec prs)) = -- all (exprIsDupable dflags . snd) prs
+ all (smallEnough dflags) prs
+floatIsDupable dflags (FloatLet (NonRec b r)) = -- exprIsDupable dflags r
+ smallEnough dflags (b,r)
+
+smallEnough :: DynFlags -> (Id,CoreExpr) -> Bool
+smallEnough dflags (_,rhs)
+ = couldBeSmallEnoughToInline dflags (ufUseThreshold dflags) rhs
floatIsCase :: FloatBind -> Bool
floatIsCase (FloatCase {}) = True
floatIsCase (FloatLet {}) = False
+
+--floatIsJoin :: FloatBind -> Bool
+--floatIsJoin (FloatCase {}) = False
+--floatIsJoin (FloatLet b) = isJoinBind b