From 6206cb9287f3f6e70c669660a646a65274870d2b Mon Sep 17 00:00:00 2001 From: Simon Peyton Jones Date: Fri, 23 Dec 2022 14:53:08 +0000 Subject: Make FloatIn robust to shadowing This MR fixes #22622. See the new Note [Shadowing and name capture] I did a bit of refactoring in sepBindsByDropPoint too. The bug doesn't manifest in HEAD, but it did show up in 9.4, so we should backport this patch to 9.4 --- compiler/GHC/Core/Opt/FloatIn.hs | 261 ++++++++++++++++++++++++--------------- 1 file changed, 158 insertions(+), 103 deletions(-) (limited to 'compiler/GHC/Core') diff --git a/compiler/GHC/Core/Opt/FloatIn.hs b/compiler/GHC/Core/Opt/FloatIn.hs index cf3ca726e4..2feef8a617 100644 --- a/compiler/GHC/Core/Opt/FloatIn.hs +++ b/compiler/GHC/Core/Opt/FloatIn.hs @@ -35,9 +35,12 @@ import GHC.Types.Var import GHC.Types.Var.Set import GHC.Utils.Misc -import GHC.Utils.Panic import GHC.Utils.Panic.Plain +import GHC.Utils.Outputable + +import Data.List ( mapAccumL ) + {- Top-level interface function, @floatInwards@. Note that we do not actually float any bindings downwards from the top-level. @@ -124,7 +127,7 @@ the closure for a is not built. ************************************************************************ -} -type FreeVarSet = DIdSet +type FreeVarSet = DVarSet type BoundVarSet = DIdSet data FloatInBind = FB BoundVarSet FreeVarSet FloatBind @@ -132,11 +135,17 @@ data FloatInBind = FB BoundVarSet FreeVarSet FloatBind -- of recursive bindings, the set doesn't include the bound -- variables. -type FloatInBinds = [FloatInBind] - -- In reverse dependency order (innermost binder first) +type FloatInBinds = [FloatInBind] -- In normal dependency order + -- (outermost binder first) +type RevFloatInBinds = [FloatInBind] -- In reverse dependency order + -- (innermost binder first) + +instance Outputable FloatInBind where + ppr (FB bvs fvs _) = text "FB" <> braces (sep [ text "bndrs =" <+> ppr bvs + , text "fvs =" <+> ppr fvs ]) fiExpr :: Platform - -> FloatInBinds -- Binds we're trying to drop + -> RevFloatInBinds -- Binds we're trying to drop -- as far "inwards" as possible -> CoreExprWithFVs -- Input expr -> CoreExpr -- Result @@ -147,13 +156,12 @@ fiExpr _ to_drop (_, AnnType ty) = assert (null to_drop) $ Type ty fiExpr _ to_drop (_, AnnVar v) = wrapFloats to_drop (Var v) fiExpr _ to_drop (_, AnnCoercion co) = wrapFloats to_drop (Coercion co) fiExpr platform to_drop (_, AnnCast expr (co_ann, co)) - = wrapFloats (drop_here ++ co_drop) $ + = wrapFloats drop_here $ Cast (fiExpr platform e_drop expr) co where - [drop_here, e_drop, co_drop] - = sepBindsByDropPoint platform False - [freeVarsOf expr, freeVarsOfAnn co_ann] - to_drop + (drop_here, [e_drop]) + = sepBindsByDropPoint platform False to_drop + (freeVarsOfAnn co_ann) [freeVarsOf expr] {- Applications: we do float inside applications, mainly because we @@ -162,7 +170,7 @@ pull out any silly ones. -} fiExpr platform to_drop ann_expr@(_,AnnApp {}) - = wrapFloats drop_here $ wrapFloats extra_drop $ + = wrapFloats drop_here $ mkTicks ticks $ mkApps (fiExpr platform fun_drop ann_fun) (zipWithEqual "fiExpr" (fiExpr platform) arg_drops ann_args) @@ -170,21 +178,19 @@ fiExpr platform to_drop ann_expr@(_,AnnApp {}) -- length ann_args = length arg_fvs = length arg_drops where (ann_fun, ann_args, ticks) = collectAnnArgsTicks tickishFloatable ann_expr - fun_ty = exprType (deAnnotate ann_fun) fun_fvs = freeVarsOf ann_fun - arg_fvs = map freeVarsOf ann_args - (drop_here : extra_drop : fun_drop : arg_drops) - = sepBindsByDropPoint platform False - (extra_fvs : fun_fvs : arg_fvs) - to_drop + (drop_here, fun_drop : arg_drops) + = sepBindsByDropPoint platform False to_drop + here_fvs (fun_fvs : arg_fvs) + -- Shortcut behaviour: if to_drop is empty, -- sepBindsByDropPoint returns a suitable bunch of empty -- lists without evaluating extra_fvs, and hence without -- peering into each argument - (_, extra_fvs) = foldl' add_arg (fun_ty, extra_fvs0) ann_args - extra_fvs0 = case ann_fun of + (here_fvs, arg_fvs) = mapAccumL add_arg here_fvs0 ann_args + here_fvs0 = case ann_fun of (_, AnnVar _) -> fun_fvs _ -> emptyDVarSet -- Don't float the binding for f into f x y z; see Note [Join points] @@ -192,14 +198,11 @@ fiExpr platform to_drop ann_expr@(_,AnnApp {}) -- join point, floating it in isn't especially harmful but it's -- useless since the simplifier will immediately float it back out.) - add_arg :: (Type,FreeVarSet) -> CoreExprWithFVs -> (Type,FreeVarSet) - add_arg (fun_ty, extra_fvs) (_, AnnType ty) - = (piResultTy fun_ty ty, extra_fvs) - add_arg (fun_ty, extra_fvs) (arg_fvs, arg) - | noFloatIntoArg arg - = (funResultTy fun_ty, extra_fvs `unionDVarSet` arg_fvs) - | otherwise - = (funResultTy fun_ty, extra_fvs) + add_arg :: FreeVarSet -> CoreExprWithFVs -> (FreeVarSet,FreeVarSet) + -- We can't float into some arguments, so put them into the here_fvs + add_arg here_fvs (arg_fvs, arg) + | noFloatIntoArg arg = (here_fvs `unionDVarSet` arg_fvs, emptyDVarSet) + | otherwise = (here_fvs, arg_fvs) {- Note [Dead bindings] ~~~~~~~~~~~~~~~~~~~~~~~ @@ -272,7 +275,6 @@ it's non-recursive, so we float only into non-recursive join points.) Urk! if all are tyvars, and we don't float in, we may miss an opportunity to float inside a nested case branch - Note [Floating coercions] ~~~~~~~~~~~~~~~~~~~~~~~~~ We could, in principle, have a coercion binding like @@ -292,6 +294,36 @@ of the types of all the drop points involved. If any of the floaters bind a coercion variable mentioned in any of the types, that binder must be dropped right away. +Note [Shadowing and name capture] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Suppose we have + let x = y+1 in + case p of + (y:ys) -> ...x... + [] -> blah +It is obviously bogus for FloatIn to transform to + case p of + (y:ys) -> ...(let x = y+1 in x)... + [] -> blah +because the y is captured. This doesn't happen much, because shadowing is +rare, but it did happen in #22662. + +One solution would be to clone as we go. But a simpler one is this: + + at a binding site (like that for (y:ys) above), abandon float-in for + any floating bindings that mention the binders (y, ys in this case) + +We achieve that by calling sepBindsByDropPoint with the binders in +the "used-here" set: + +* In fiExpr (AnnLam ...). For the body there is no need to delete + the lambda-binders from the body_fvs, because any bindings that + mention these binders will be dropped here anyway. + +* In fiExpr (AnnCase ...). Remember to include the case_bndr in the + binders. Again, no need to delete the alt binders from the rhs + free vars, beause any bindings mentioning them will be dropped + here unconditionally. -} fiExpr platform to_drop lam@(_, AnnLam _ _) @@ -300,10 +332,17 @@ fiExpr platform to_drop lam@(_, AnnLam _ _) = wrapFloats to_drop (mkLams bndrs (fiExpr platform [] body)) | otherwise -- Float inside - = mkLams bndrs (fiExpr platform to_drop body) + = wrapFloats drop_here $ + mkLams bndrs (fiExpr platform body_drop body) where (bndrs, body) = collectAnnBndrs lam + body_fvs = freeVarsOf body + + -- Why sepBindsByDropPoint? Because of potential capture + -- See Note [Shadowing and name capture] + (drop_here, [body_drop]) = sepBindsByDropPoint platform False to_drop + (mkDVarSet bndrs) [body_fvs] {- We don't float lets inwards past an SCC. @@ -443,16 +482,16 @@ fiExpr platform to_drop (_, AnnCase scrut case_bndr _ [AnnAlt con alt_bndrs rhs] = wrapFloats shared_binds $ fiExpr platform (case_float : rhs_binds) rhs where - case_float = FB (mkDVarSet (case_bndr : alt_bndrs)) scrut_fvs + case_float = FB all_bndrs scrut_fvs (FloatCase scrut' case_bndr con alt_bndrs) scrut' = fiExpr platform scrut_binds scrut - rhs_fvs = freeVarsOf rhs `delDVarSetList` (case_bndr : alt_bndrs) - scrut_fvs = freeVarsOf scrut + rhs_fvs = freeVarsOf rhs -- No need to delete alt_bndrs + scrut_fvs = freeVarsOf scrut -- See Note [Shadowing and name capture] + all_bndrs = mkDVarSet alt_bndrs `extendDVarSet` case_bndr - [shared_binds, scrut_binds, rhs_binds] - = sepBindsByDropPoint platform False - [scrut_fvs, rhs_fvs] - to_drop + (shared_binds, [scrut_binds, rhs_binds]) + = sepBindsByDropPoint platform False to_drop + all_bndrs [scrut_fvs, rhs_fvs] fiExpr platform to_drop (_, AnnCase scrut case_bndr ty alts) = wrapFloats drop_here1 $ @@ -462,39 +501,43 @@ fiExpr platform to_drop (_, AnnCase scrut case_bndr ty alts) -- use zipWithEqual, we should have length alts_drops_s = length alts where -- Float into the scrut and alts-considered-together just like App - [drop_here1, scrut_drops, alts_drops] - = sepBindsByDropPoint platform False - [scrut_fvs, all_alts_fvs] - to_drop + (drop_here1, [scrut_drops, alts_drops]) + = sepBindsByDropPoint platform False to_drop + all_alt_bndrs [scrut_fvs, all_alt_fvs] + -- all_alt_bndrs: see Note [Shadowing and name capture] -- Float into the alts with the is_case flag set - (drop_here2 : alts_drops_s) - | [ _ ] <- alts = [] : [alts_drops] - | otherwise = sepBindsByDropPoint platform True alts_fvs alts_drops - - scrut_fvs = freeVarsOf scrut - alts_fvs = map alt_fvs alts - all_alts_fvs = unionDVarSets alts_fvs - alt_fvs (AnnAlt _con args rhs) - = foldl' delDVarSet (freeVarsOf rhs) (case_bndr:args) - -- Delete case_bndr and args from free vars of rhs - -- to get free vars of alt + (drop_here2, alts_drops_s) + = sepBindsByDropPoint platform True alts_drops emptyDVarSet alts_fvs + + scrut_fvs = freeVarsOf scrut + + all_alt_bndrs = foldr (unionDVarSet . ann_alt_bndrs) (unitDVarSet case_bndr) alts + ann_alt_bndrs (AnnAlt _ bndrs _) = mkDVarSet bndrs + + alts_fvs :: [DVarSet] + alts_fvs = [freeVarsOf rhs | AnnAlt _ _ rhs <- alts] + -- No need to delete binders + -- See Note [Shadowing and name capture] + + all_alt_fvs :: DVarSet + all_alt_fvs = foldr unionDVarSet (unitDVarSet case_bndr) alts_fvs fi_alt to_drop (AnnAlt con args rhs) = Alt con args (fiExpr platform to_drop rhs) ------------------ fiBind :: Platform - -> FloatInBinds -- Binds we're trying to drop - -- as far "inwards" as possible - -> CoreBindWithFVs -- Input binding - -> DVarSet -- Free in scope of binding - -> ( FloatInBinds -- Land these before - , FloatInBind -- The binding itself - , FloatInBinds) -- Land these after + -> RevFloatInBinds -- Binds we're trying to drop + -- as far "inwards" as possible + -> CoreBindWithFVs -- Input binding + -> DVarSet -- Free in scope of binding + -> ( RevFloatInBinds -- Land these before + , FloatInBind -- The binding itself + , RevFloatInBinds) -- Land these after fiBind platform to_drop (AnnNonRec id ann_rhs@(rhs_fvs, rhs)) body_fvs - = ( extra_binds ++ shared_binds -- Land these before - -- See Note [extra_fvs (1)] and Note [extra_fvs (2)] + = ( shared_binds -- Land these before + -- See Note [extra_fvs (1)] and Note [extra_fvs (2)] , FB (unitDVarSet id) rhs_fvs' -- The new binding itself (FloatLet (NonRec id rhs')) , body_binds ) -- Land these after @@ -512,10 +555,9 @@ fiBind platform to_drop (AnnNonRec id ann_rhs@(rhs_fvs, rhs)) body_fvs -- We *can't* float into ok-for-speculation unlifted RHSs -- But do float into join points - [shared_binds, extra_binds, rhs_binds, body_binds] - = sepBindsByDropPoint platform False - [extra_fvs, rhs_fvs, body_fvs2] - to_drop + (shared_binds, [rhs_binds, body_binds]) + = sepBindsByDropPoint platform False to_drop + extra_fvs [rhs_fvs, body_fvs2] -- Push rhs_binds into the right hand side of the binding rhs' = fiRhs platform rhs_binds id ann_rhs @@ -523,7 +565,7 @@ fiBind platform to_drop (AnnNonRec id ann_rhs@(rhs_fvs, rhs)) body_fvs -- Don't forget the rule_fvs; the binding mentions them! fiBind platform to_drop (AnnRec bindings) body_fvs - = ( extra_binds ++ shared_binds + = ( shared_binds , FB (mkDVarSet ids) rhs_fvs' (FloatLet (Rec (fi_bind rhss_binds bindings))) , body_binds ) @@ -537,17 +579,16 @@ fiBind platform to_drop (AnnRec bindings) body_fvs unionDVarSets [ rhs_fvs | (bndr, (rhs_fvs, rhs)) <- bindings , noFloatIntoRhs Recursive bndr rhs ] - (shared_binds:extra_binds:body_binds:rhss_binds) - = sepBindsByDropPoint platform False - (extra_fvs:body_fvs:rhss_fvs) - to_drop + (shared_binds, body_binds:rhss_binds) + = sepBindsByDropPoint platform False to_drop + extra_fvs (body_fvs:rhss_fvs) rhs_fvs' = unionDVarSets rhss_fvs `unionDVarSet` unionDVarSets (map floatedBindsFVs rhss_binds) `unionDVarSet` rule_fvs -- Don't forget the rule variables! -- Push rhs_binds into the right hand side of the binding - fi_bind :: [FloatInBinds] -- one per "drop pt" conjured w/ fvs_of_rhss + fi_bind :: [RevFloatInBinds] -- One per "drop pt" conjured w/ fvs_of_rhss -> [(Id, CoreExprWithFVs)] -> [(Id, CoreExpr)] @@ -556,7 +597,7 @@ fiBind platform to_drop (AnnRec bindings) body_fvs | ((binder, rhs), to_drop) <- zipEqual "fi_bind" pairs to_drops ] ------------------ -fiRhs :: Platform -> FloatInBinds -> CoreBndr -> CoreExprWithFVs -> CoreExpr +fiRhs :: Platform -> RevFloatInBinds -> CoreBndr -> CoreExprWithFVs -> CoreExpr fiRhs platform to_drop bndr rhs | Just join_arity <- isJoinId_maybe bndr , let (bndrs, body) = collectNAnnBndrs join_arity rhs @@ -656,68 +697,84 @@ point. We have to maintain the order on these drop-point-related lists. -} --- pprFIB :: FloatInBinds -> SDoc +-- pprFIB :: RevFloatInBinds -> SDoc -- pprFIB fibs = text "FIB:" <+> ppr [b | FB _ _ b <- fibs] sepBindsByDropPoint :: Platform - -> Bool -- True <=> is case expression - -> [FreeVarSet] -- One set of FVs per drop point - -- Always at least two long! - -> FloatInBinds -- Candidate floaters - -> [FloatInBinds] -- FIRST one is bindings which must not be floated - -- inside any drop point; the rest correspond - -- one-to-one with the input list of FV sets + -> Bool -- True <=> is case expression + -> RevFloatInBinds -- Candidate floaters + -> FreeVarSet -- here_fvs: if these vars are free in a binding, + -- don't float that binding inside any drop point + -> [FreeVarSet] -- fork_fvs: one set of FVs per drop point + -> ( RevFloatInBinds -- Bindings which must not be floated inside + , [RevFloatInBinds] ) -- Corresponds 1-1 with the input list of FV sets -- Every input floater is returned somewhere in the result; -- none are dropped, not even ones which don't seem to be -- free in *any* of the drop-point fvs. Why? Because, for example, -- a binding (let x = E in B) might have a specialised version of -- x (say x') stored inside x, but x' isn't free in E or B. +-- +-- The here_fvs argument is used for two things: +-- * Avoid shadowing bugs: see Note [Shadowing and name capture] +-- * Drop some of the bindings at the top, e.g. of an application type DropBox = (FreeVarSet, FloatInBinds) -sepBindsByDropPoint platform is_case drop_pts floaters +dropBoxFloats :: DropBox -> RevFloatInBinds +dropBoxFloats (_, floats) = reverse floats + +usedInDropBox :: DIdSet -> DropBox -> Bool +usedInDropBox bndrs (db_fvs, _) = db_fvs `intersectsDVarSet` bndrs + +initDropBox :: DVarSet -> DropBox +initDropBox fvs = (fvs, []) + +sepBindsByDropPoint platform is_case floaters here_fvs fork_fvs | null floaters -- Shortcut common case - = [] : [[] | _ <- drop_pts] + = ([], [[] | _ <- fork_fvs]) | otherwise - = assert (drop_pts `lengthAtLeast` 2) $ - go floaters (map (\fvs -> (fvs, [])) (emptyDVarSet : drop_pts)) + = go floaters (initDropBox here_fvs) (map initDropBox fork_fvs) where - n_alts = length drop_pts + n_alts = length fork_fvs - go :: FloatInBinds -> [DropBox] -> [FloatInBinds] - -- The *first* one in the argument list is the drop_here set - -- The FloatInBinds in the lists are in the reverse of - -- the normal FloatInBinds order; that is, they are the right way round! + go :: RevFloatInBinds -> DropBox -> [DropBox] + -> (RevFloatInBinds, [RevFloatInBinds]) + -- The *first* one in the pair is the drop_here set - go [] drop_boxes = map (reverse . snd) drop_boxes + go [] here_box fork_boxes + = (dropBoxFloats here_box, map dropBoxFloats fork_boxes) - go (bind_w_fvs@(FB bndrs bind_fvs bind) : binds) drop_boxes@(here_box : fork_boxes) - = go binds new_boxes + go (bind_w_fvs@(FB bndrs bind_fvs bind) : binds) here_box fork_boxes + | drop_here = go binds (insert here_box) fork_boxes + | otherwise = go binds here_box new_fork_boxes where -- "here" means the group of bindings dropped at the top of the fork - (used_here : used_in_flags) = [ fvs `intersectsDVarSet` bndrs - | (fvs, _) <- drop_boxes] + used_here = bndrs `usedInDropBox` here_box + used_in_flags = case fork_boxes of + [] -> [] + [_] -> [True] -- Push all bindings into a single branch + -- No need to look at its free vars + _ -> map (bndrs `usedInDropBox`) fork_boxes + -- Short-cut for the singleton case; + -- used for lambdas and singleton cases drop_here = used_here || cant_push n_used_alts = count id used_in_flags -- returns number of Trues in list. cant_push - | is_case = n_used_alts == n_alts -- Used in all, don't push - -- Remember n_alts > 1 + | is_case = (n_alts > 1 && n_used_alts == n_alts) + -- Used in all, muliple branches, don't push || (n_used_alts > 1 && not (floatIsDupable platform bind)) -- floatIsDupable: see Note [Duplicating floats] | otherwise = floatIsCase bind || n_used_alts > 1 -- floatIsCase: see Note [Floating primops] - new_boxes | drop_here = (insert here_box : fork_boxes) - | otherwise = (here_box : new_fork_boxes) - new_fork_boxes = zipWithEqual "FloatIn.sepBinds" insert_maybe fork_boxes used_in_flags @@ -727,8 +784,6 @@ sepBindsByDropPoint platform is_case drop_pts floaters insert_maybe box True = insert box insert_maybe box False = box - go _ _ = panic "sepBindsByDropPoint/go" - {- Note [Duplicating floats] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -745,14 +800,14 @@ If the thing is used in all RHSs there is nothing gained, so we don't duplicate then. -} -floatedBindsFVs :: FloatInBinds -> FreeVarSet +floatedBindsFVs :: RevFloatInBinds -> FreeVarSet floatedBindsFVs binds = mapUnionDVarSet fbFVs binds fbFVs :: FloatInBind -> DVarSet fbFVs (FB _ fvs _) = fvs -wrapFloats :: FloatInBinds -> CoreExpr -> CoreExpr --- Remember FloatInBinds is in *reverse* dependency order +wrapFloats :: RevFloatInBinds -> CoreExpr -> CoreExpr +-- Remember RevFloatInBinds is in *reverse* dependency order wrapFloats [] e = e wrapFloats (FB _ _ fl : bs) e = wrapFloats bs (wrapFloat fl e) -- cgit v1.2.1