diff options
Diffstat (limited to 'compiler/simplCore/FloatIn.hs')
-rw-r--r-- | compiler/simplCore/FloatIn.hs | 139 |
1 files changed, 110 insertions, 29 deletions
diff --git a/compiler/simplCore/FloatIn.hs b/compiler/simplCore/FloatIn.hs index 2593b1d7a1..04e4d32f5e 100644 --- a/compiler/simplCore/FloatIn.hs +++ b/compiler/simplCore/FloatIn.hs @@ -26,8 +26,9 @@ import MkCore import HscTypes ( ModGuts(..) ) import CoreUtils import CoreFVs +import CoreUnfold import CoreMonad ( CoreM ) -import Id ( isOneShotBndr, idType, isJoinId, isJoinId_maybe ) +import Id import Var import Type import VarSet @@ -151,7 +152,7 @@ fiExpr dflags to_drop (_, AnnCast expr (co_ann, co)) Cast (fiExpr dflags e_drop expr) co where [drop_here, e_drop, co_drop] - = sepBindsByDropPoint dflags False + = sepBindsByDropPoint dflags SepVanilla [freeVarsOf expr, freeVarsOfAnn co_ann] to_drop @@ -173,7 +174,7 @@ fiExpr dflags to_drop ann_expr@(_,AnnApp {}) arg_fvs = map freeVarsOf ann_args (drop_here : extra_drop : fun_drop : arg_drops) - = sepBindsByDropPoint dflags False + = sepBindsByDropPoint dflags SepVanilla (extra_fvs : fun_fvs : arg_fvs) to_drop -- Shortcut behaviour: if to_drop is empty, @@ -446,7 +447,7 @@ fiExpr dflags to_drop (_, AnnCase scrut case_bndr _ [(con,alt_bndrs,rhs)]) scrut_fvs = freeVarsOf scrut [shared_binds, scrut_binds, rhs_binds] - = sepBindsByDropPoint dflags False + = sepBindsByDropPoint dflags SepVanilla [scrut_fvs, rhs_fvs] to_drop @@ -456,16 +457,17 @@ fiExpr dflags to_drop (_, AnnCase scrut case_bndr ty alts) Case (fiExpr dflags scrut_drops scrut) case_bndr ty (zipWith fi_alt alts_drops_s alts) where - -- Float into the scrut and alts-considered-together just like App + -- Float into the scrut and alts-considered-together just like App [drop_here1, scrut_drops, alts_drops] - = sepBindsByDropPoint dflags False + = sepBindsByDropPoint dflags SepVanilla [scrut_fvs, all_alts_fvs] to_drop - -- Float into the alts with the is_case flag set + -- Float into the alts with the SepCase context set (drop_here2 : alts_drops_s) | [ _ ] <- alts = [] : [alts_drops] - | otherwise = sepBindsByDropPoint dflags True alts_fvs alts_drops + | otherwise = sepBindsByDropPoint dflags SepCase + alts_fvs alts_drops scrut_fvs = freeVarsOf scrut alts_fvs = map alt_fvs alts @@ -491,7 +493,7 @@ fiBind dflags to_drop (AnnNonRec id ann_rhs@(rhs_fvs, rhs)) body_fvs = ( extra_binds ++ shared_binds -- Land these before -- See Note [extra_fvs (1,2)] , FB (unitDVarSet id) rhs_fvs' -- The new binding itself - (FloatLet (NonRec id rhs')) + (FloatLet (NonRec id rhs')) , body_binds ) -- Land these after where @@ -508,7 +510,8 @@ fiBind dflags to_drop (AnnNonRec id ann_rhs@(rhs_fvs, rhs)) body_fvs -- But do float into join points [shared_binds, extra_binds, rhs_binds, body_binds] - = sepBindsByDropPoint dflags False + = sepBindsByDropPoint dflags + (if isJoinId id then SepNonRecJoin else SepVanilla) [extra_fvs, rhs_fvs, body_fvs2] to_drop @@ -533,7 +536,7 @@ fiBind dflags to_drop (AnnRec bindings) body_fvs , noFloatIntoRhs Recursive bndr rhs ] (shared_binds:extra_binds:body_binds:rhss_binds) - = sepBindsByDropPoint dflags False + = sepBindsByDropPoint dflags SepVanilla (extra_fvs:body_fvs:rhss_fvs) to_drop @@ -654,9 +657,41 @@ We have to maintain the order on these drop-point-related lists. -- pprFIB :: FloatInBinds -> SDoc -- pprFIB fibs = text "FIB:" <+> ppr [b | FB _ _ b <- fibs] +data SepCtxt + = SepCase + | SepNonRecJoin + | SepVanilla + +{- Note [Floating join points] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +We push join-point bindings inwards merrily, just like let-bindings. +They may get floated out again; e.g. + join j1 x = e1 + in join j2 y = ...j1... + in ... +==> + join j2 y = join { j1 x = e1 } in ...j1... + in ... + +Here we might float j1 out again. But we must float it in in case it +allows an ordinary let-binding to go too. E.g. + let x = <thunk> + in join j1 x = e1 + in join j2 y = ...j1... + in ... +===> + join j2 y = let { x = <thunk } + in join { j1 x = e1 } + in ...j1... + in ... + +Ths is important; now the thunk for 'x' may not be allocated on the +paths that don't involve j2. +-} + sepBindsByDropPoint :: DynFlags - -> Bool -- True <=> is case expression + -> SepCtxt -> [FreeVarSet] -- One set of FVs per drop point -- Always at least two long! -> FloatInBinds -- Candidate floaters @@ -672,15 +707,15 @@ sepBindsByDropPoint type DropBox = (FreeVarSet, FloatInBinds) -sepBindsByDropPoint dflags is_case drop_pts floaters +sepBindsByDropPoint dflags sep_ctxt drop_pts floaters | null floaters -- Shortcut common case = [] : [[] | _ <- drop_pts] | otherwise - = ASSERT( drop_pts `lengthAtLeast` 2 ) + = ASSERT( n_alts >= 2 ) -- Invariant on caller go floaters (map (\fvs -> (fvs, [])) (emptyDVarSet : drop_pts)) where - n_alts = length drop_pts + n_alts = length drop_pts -- n_alts >= 2 go :: FloatInBinds -> [DropBox] -> [FloatInBinds] -- The *first* one in the argument list is the drop_here set @@ -697,18 +732,29 @@ sepBindsByDropPoint dflags is_case drop_pts floaters (used_here : used_in_flags) = [ fvs `intersectsDVarSet` bndrs | (fvs, _) <- drop_boxes] - drop_here = used_here || cant_push + drop_here = used_here || not want_push + want_push = case sep_ctxt of + SepCase -> want_case_push + SepNonRecJoin -> want_join_push + SepVanilla -> want_let_push n_used_alts = count id used_in_flags -- returns number of Trues in list. - cant_push - | is_case = n_used_alts == n_alts -- Used in all, don't push + no_duplication = n_used_alts <= 1 -- See Note [Duplicating floats] + + duplicable_float = floatIsDupable dflags bind + -- True <=> duplication does not dup much code + -- (but it might still duplicate work!) + -- See Note [Duplicating floats] + + want_case_push = -- n_used_alts < n_alts && -- Used in all alts, don't push -- Remember n_alts > 1 - || (n_used_alts > 1 && not (floatIsDupable dflags bind)) - -- floatIsDupable: see Note [Duplicating floats] + (no_duplication || duplicable_float) - | otherwise = floatIsCase bind || n_used_alts > 1 - -- floatIsCase: see Note [Floating primops] + want_let_push = not (floatIsCase bind) -- See Note [Floating primops] + && no_duplication + + want_join_push = no_duplication -- See Note [Floating join points] new_boxes | drop_here = (insert here_box : fork_boxes) | otherwise = (here_box : new_fork_boxes) @@ -727,18 +773,43 @@ sepBindsByDropPoint dflags is_case drop_pts floaters {- Note [Duplicating floats] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +no_duplication is true if the binding us used in at most one +alternative. (Zero is rare; it means the binding is dead.) + +If no_duplication is false, we may still float: -For case expressions we duplicate the binding if it is reasonably -small, and if it is not used in all the RHSs This is good for -situations like +* For /case expressions/ only (SepCase) we duplicate the binding if it + is reasonably small, and if it is not used in all the RHSs. This is + good for situations like let x = I# y in case e of C -> error x D -> error x E -> ...not mentioning x... -If the thing is used in all RHSs there is nothing gained, -so we don't duplicate then. + If the thing is used in all RHSs there is nothing gained, + so we don't duplicate then. + +* This is NOT GOOD for other float-in places, like lets (SepVanilla). + Consider + let x = <small> in + let v = ...x... + in ...x... + + We definitely don't want to duplicate x into the RHS of v and the + body! At least, it would be OK if <small> was a value; but we don't + test that. + +* For non-recursive join bindings (SepNonRecJoin) we must be equally + careful. Eg + let x = <small> in + join j = ...x... + in case f x of + A -> j + B -> something else + C -> j + Here we must not duplicate the let-x binding into the RHS of j + and the body, or we'll duplicate the redex. -} floatedBindsFVs :: FloatInBinds -> FreeVarSet @@ -754,9 +825,19 @@ wrapFloats (FB _ _ fl : bs) e = wrapFloats bs (wrapFloat fl e) floatIsDupable :: DynFlags -> FloatBind -> Bool floatIsDupable dflags (FloatCase scrut _ _ _) = exprIsDupable dflags scrut -floatIsDupable dflags (FloatLet (Rec prs)) = all (exprIsDupable dflags . snd) prs -floatIsDupable dflags (FloatLet (NonRec _ r)) = exprIsDupable dflags r +floatIsDupable dflags (FloatLet (Rec prs)) = -- all (exprIsDupable dflags . snd) prs + all (smallEnough dflags) prs +floatIsDupable dflags (FloatLet (NonRec b r)) = -- exprIsDupable dflags r + smallEnough dflags (b,r) + +smallEnough :: DynFlags -> (Id,CoreExpr) -> Bool +smallEnough dflags (_,rhs) + = couldBeSmallEnoughToInline dflags (ufUseThreshold dflags) rhs floatIsCase :: FloatBind -> Bool floatIsCase (FloatCase {}) = True floatIsCase (FloatLet {}) = False + +--floatIsJoin :: FloatBind -> Bool +--floatIsJoin (FloatCase {}) = False +--floatIsJoin (FloatLet b) = isJoinBind b |