summaryrefslogtreecommitdiff
path: root/compiler/GHC/Core
diff options
context:
space:
mode:
authorSimon Peyton Jones <simon.peytonjones@gmail.com>2022-12-23 14:53:08 +0000
committerMarge Bot <ben+marge-bot@smart-cactus.org>2023-01-07 12:14:40 -0500
commit6206cb9287f3f6e70c669660a646a65274870d2b (patch)
tree2874fab5d79349b90ea364a058c7420e075ab020 /compiler/GHC/Core
parent2459c3587bfe8105c628f9733bf32d1d3c903375 (diff)
downloadhaskell-6206cb9287f3f6e70c669660a646a65274870d2b.tar.gz
Make FloatIn robust to shadowing
This MR fixes #22622. See the new Note [Shadowing and name capture] I did a bit of refactoring in sepBindsByDropPoint too. The bug doesn't manifest in HEAD, but it did show up in 9.4, so we should backport this patch to 9.4
Diffstat (limited to 'compiler/GHC/Core')
-rw-r--r--compiler/GHC/Core/Opt/FloatIn.hs261
1 files changed, 158 insertions, 103 deletions
diff --git a/compiler/GHC/Core/Opt/FloatIn.hs b/compiler/GHC/Core/Opt/FloatIn.hs
index cf3ca726e4..2feef8a617 100644
--- a/compiler/GHC/Core/Opt/FloatIn.hs
+++ b/compiler/GHC/Core/Opt/FloatIn.hs
@@ -35,9 +35,12 @@ import GHC.Types.Var
import GHC.Types.Var.Set
import GHC.Utils.Misc
-import GHC.Utils.Panic
import GHC.Utils.Panic.Plain
+import GHC.Utils.Outputable
+
+import Data.List ( mapAccumL )
+
{-
Top-level interface function, @floatInwards@. Note that we do not
actually float any bindings downwards from the top-level.
@@ -124,7 +127,7 @@ the closure for a is not built.
************************************************************************
-}
-type FreeVarSet = DIdSet
+type FreeVarSet = DVarSet
type BoundVarSet = DIdSet
data FloatInBind = FB BoundVarSet FreeVarSet FloatBind
@@ -132,11 +135,17 @@ data FloatInBind = FB BoundVarSet FreeVarSet FloatBind
-- of recursive bindings, the set doesn't include the bound
-- variables.
-type FloatInBinds = [FloatInBind]
- -- In reverse dependency order (innermost binder first)
+type FloatInBinds = [FloatInBind] -- In normal dependency order
+ -- (outermost binder first)
+type RevFloatInBinds = [FloatInBind] -- In reverse dependency order
+ -- (innermost binder first)
+
+instance Outputable FloatInBind where
+ ppr (FB bvs fvs _) = text "FB" <> braces (sep [ text "bndrs =" <+> ppr bvs
+ , text "fvs =" <+> ppr fvs ])
fiExpr :: Platform
- -> FloatInBinds -- Binds we're trying to drop
+ -> RevFloatInBinds -- Binds we're trying to drop
-- as far "inwards" as possible
-> CoreExprWithFVs -- Input expr
-> CoreExpr -- Result
@@ -147,13 +156,12 @@ fiExpr _ to_drop (_, AnnType ty) = assert (null to_drop) $ Type ty
fiExpr _ to_drop (_, AnnVar v) = wrapFloats to_drop (Var v)
fiExpr _ to_drop (_, AnnCoercion co) = wrapFloats to_drop (Coercion co)
fiExpr platform to_drop (_, AnnCast expr (co_ann, co))
- = wrapFloats (drop_here ++ co_drop) $
+ = wrapFloats drop_here $
Cast (fiExpr platform e_drop expr) co
where
- [drop_here, e_drop, co_drop]
- = sepBindsByDropPoint platform False
- [freeVarsOf expr, freeVarsOfAnn co_ann]
- to_drop
+ (drop_here, [e_drop])
+ = sepBindsByDropPoint platform False to_drop
+ (freeVarsOfAnn co_ann) [freeVarsOf expr]
{-
Applications: we do float inside applications, mainly because we
@@ -162,7 +170,7 @@ pull out any silly ones.
-}
fiExpr platform to_drop ann_expr@(_,AnnApp {})
- = wrapFloats drop_here $ wrapFloats extra_drop $
+ = wrapFloats drop_here $
mkTicks ticks $
mkApps (fiExpr platform fun_drop ann_fun)
(zipWithEqual "fiExpr" (fiExpr platform) arg_drops ann_args)
@@ -170,21 +178,19 @@ fiExpr platform to_drop ann_expr@(_,AnnApp {})
-- length ann_args = length arg_fvs = length arg_drops
where
(ann_fun, ann_args, ticks) = collectAnnArgsTicks tickishFloatable ann_expr
- fun_ty = exprType (deAnnotate ann_fun)
fun_fvs = freeVarsOf ann_fun
- arg_fvs = map freeVarsOf ann_args
- (drop_here : extra_drop : fun_drop : arg_drops)
- = sepBindsByDropPoint platform False
- (extra_fvs : fun_fvs : arg_fvs)
- to_drop
+ (drop_here, fun_drop : arg_drops)
+ = sepBindsByDropPoint platform False to_drop
+ here_fvs (fun_fvs : arg_fvs)
+
-- Shortcut behaviour: if to_drop is empty,
-- sepBindsByDropPoint returns a suitable bunch of empty
-- lists without evaluating extra_fvs, and hence without
-- peering into each argument
- (_, extra_fvs) = foldl' add_arg (fun_ty, extra_fvs0) ann_args
- extra_fvs0 = case ann_fun of
+ (here_fvs, arg_fvs) = mapAccumL add_arg here_fvs0 ann_args
+ here_fvs0 = case ann_fun of
(_, AnnVar _) -> fun_fvs
_ -> emptyDVarSet
-- Don't float the binding for f into f x y z; see Note [Join points]
@@ -192,14 +198,11 @@ fiExpr platform to_drop ann_expr@(_,AnnApp {})
-- join point, floating it in isn't especially harmful but it's
-- useless since the simplifier will immediately float it back out.)
- add_arg :: (Type,FreeVarSet) -> CoreExprWithFVs -> (Type,FreeVarSet)
- add_arg (fun_ty, extra_fvs) (_, AnnType ty)
- = (piResultTy fun_ty ty, extra_fvs)
- add_arg (fun_ty, extra_fvs) (arg_fvs, arg)
- | noFloatIntoArg arg
- = (funResultTy fun_ty, extra_fvs `unionDVarSet` arg_fvs)
- | otherwise
- = (funResultTy fun_ty, extra_fvs)
+ add_arg :: FreeVarSet -> CoreExprWithFVs -> (FreeVarSet,FreeVarSet)
+ -- We can't float into some arguments, so put them into the here_fvs
+ add_arg here_fvs (arg_fvs, arg)
+ | noFloatIntoArg arg = (here_fvs `unionDVarSet` arg_fvs, emptyDVarSet)
+ | otherwise = (here_fvs, arg_fvs)
{- Note [Dead bindings]
~~~~~~~~~~~~~~~~~~~~~~~
@@ -272,7 +275,6 @@ it's non-recursive, so we float only into non-recursive join points.)
Urk! if all are tyvars, and we don't float in, we may miss an
opportunity to float inside a nested case branch
-
Note [Floating coercions]
~~~~~~~~~~~~~~~~~~~~~~~~~
We could, in principle, have a coercion binding like
@@ -292,6 +294,36 @@ of the types of all the drop points involved. If any of the floaters
bind a coercion variable mentioned in any of the types, that binder must
be dropped right away.
+Note [Shadowing and name capture]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Suppose we have
+ let x = y+1 in
+ case p of
+ (y:ys) -> ...x...
+ [] -> blah
+It is obviously bogus for FloatIn to transform to
+ case p of
+ (y:ys) -> ...(let x = y+1 in x)...
+ [] -> blah
+because the y is captured. This doesn't happen much, because shadowing is
+rare, but it did happen in #22662.
+
+One solution would be to clone as we go. But a simpler one is this:
+
+ at a binding site (like that for (y:ys) above), abandon float-in for
+ any floating bindings that mention the binders (y, ys in this case)
+
+We achieve that by calling sepBindsByDropPoint with the binders in
+the "used-here" set:
+
+* In fiExpr (AnnLam ...). For the body there is no need to delete
+ the lambda-binders from the body_fvs, because any bindings that
+ mention these binders will be dropped here anyway.
+
+* In fiExpr (AnnCase ...). Remember to include the case_bndr in the
+ binders. Again, no need to delete the alt binders from the rhs
+ free vars, beause any bindings mentioning them will be dropped
+ here unconditionally.
-}
fiExpr platform to_drop lam@(_, AnnLam _ _)
@@ -300,10 +332,17 @@ fiExpr platform to_drop lam@(_, AnnLam _ _)
= wrapFloats to_drop (mkLams bndrs (fiExpr platform [] body))
| otherwise -- Float inside
- = mkLams bndrs (fiExpr platform to_drop body)
+ = wrapFloats drop_here $
+ mkLams bndrs (fiExpr platform body_drop body)
where
(bndrs, body) = collectAnnBndrs lam
+ body_fvs = freeVarsOf body
+
+ -- Why sepBindsByDropPoint? Because of potential capture
+ -- See Note [Shadowing and name capture]
+ (drop_here, [body_drop]) = sepBindsByDropPoint platform False to_drop
+ (mkDVarSet bndrs) [body_fvs]
{-
We don't float lets inwards past an SCC.
@@ -443,16 +482,16 @@ fiExpr platform to_drop (_, AnnCase scrut case_bndr _ [AnnAlt con alt_bndrs rhs]
= wrapFloats shared_binds $
fiExpr platform (case_float : rhs_binds) rhs
where
- case_float = FB (mkDVarSet (case_bndr : alt_bndrs)) scrut_fvs
+ case_float = FB all_bndrs scrut_fvs
(FloatCase scrut' case_bndr con alt_bndrs)
scrut' = fiExpr platform scrut_binds scrut
- rhs_fvs = freeVarsOf rhs `delDVarSetList` (case_bndr : alt_bndrs)
- scrut_fvs = freeVarsOf scrut
+ rhs_fvs = freeVarsOf rhs -- No need to delete alt_bndrs
+ scrut_fvs = freeVarsOf scrut -- See Note [Shadowing and name capture]
+ all_bndrs = mkDVarSet alt_bndrs `extendDVarSet` case_bndr
- [shared_binds, scrut_binds, rhs_binds]
- = sepBindsByDropPoint platform False
- [scrut_fvs, rhs_fvs]
- to_drop
+ (shared_binds, [scrut_binds, rhs_binds])
+ = sepBindsByDropPoint platform False to_drop
+ all_bndrs [scrut_fvs, rhs_fvs]
fiExpr platform to_drop (_, AnnCase scrut case_bndr ty alts)
= wrapFloats drop_here1 $
@@ -462,39 +501,43 @@ fiExpr platform to_drop (_, AnnCase scrut case_bndr ty alts)
-- use zipWithEqual, we should have length alts_drops_s = length alts
where
-- Float into the scrut and alts-considered-together just like App
- [drop_here1, scrut_drops, alts_drops]
- = sepBindsByDropPoint platform False
- [scrut_fvs, all_alts_fvs]
- to_drop
+ (drop_here1, [scrut_drops, alts_drops])
+ = sepBindsByDropPoint platform False to_drop
+ all_alt_bndrs [scrut_fvs, all_alt_fvs]
+ -- all_alt_bndrs: see Note [Shadowing and name capture]
-- Float into the alts with the is_case flag set
- (drop_here2 : alts_drops_s)
- | [ _ ] <- alts = [] : [alts_drops]
- | otherwise = sepBindsByDropPoint platform True alts_fvs alts_drops
-
- scrut_fvs = freeVarsOf scrut
- alts_fvs = map alt_fvs alts
- all_alts_fvs = unionDVarSets alts_fvs
- alt_fvs (AnnAlt _con args rhs)
- = foldl' delDVarSet (freeVarsOf rhs) (case_bndr:args)
- -- Delete case_bndr and args from free vars of rhs
- -- to get free vars of alt
+ (drop_here2, alts_drops_s)
+ = sepBindsByDropPoint platform True alts_drops emptyDVarSet alts_fvs
+
+ scrut_fvs = freeVarsOf scrut
+
+ all_alt_bndrs = foldr (unionDVarSet . ann_alt_bndrs) (unitDVarSet case_bndr) alts
+ ann_alt_bndrs (AnnAlt _ bndrs _) = mkDVarSet bndrs
+
+ alts_fvs :: [DVarSet]
+ alts_fvs = [freeVarsOf rhs | AnnAlt _ _ rhs <- alts]
+ -- No need to delete binders
+ -- See Note [Shadowing and name capture]
+
+ all_alt_fvs :: DVarSet
+ all_alt_fvs = foldr unionDVarSet (unitDVarSet case_bndr) alts_fvs
fi_alt to_drop (AnnAlt con args rhs) = Alt con args (fiExpr platform to_drop rhs)
------------------
fiBind :: Platform
- -> FloatInBinds -- Binds we're trying to drop
- -- as far "inwards" as possible
- -> CoreBindWithFVs -- Input binding
- -> DVarSet -- Free in scope of binding
- -> ( FloatInBinds -- Land these before
- , FloatInBind -- The binding itself
- , FloatInBinds) -- Land these after
+ -> RevFloatInBinds -- Binds we're trying to drop
+ -- as far "inwards" as possible
+ -> CoreBindWithFVs -- Input binding
+ -> DVarSet -- Free in scope of binding
+ -> ( RevFloatInBinds -- Land these before
+ , FloatInBind -- The binding itself
+ , RevFloatInBinds) -- Land these after
fiBind platform to_drop (AnnNonRec id ann_rhs@(rhs_fvs, rhs)) body_fvs
- = ( extra_binds ++ shared_binds -- Land these before
- -- See Note [extra_fvs (1)] and Note [extra_fvs (2)]
+ = ( shared_binds -- Land these before
+ -- See Note [extra_fvs (1)] and Note [extra_fvs (2)]
, FB (unitDVarSet id) rhs_fvs' -- The new binding itself
(FloatLet (NonRec id rhs'))
, body_binds ) -- Land these after
@@ -512,10 +555,9 @@ fiBind platform to_drop (AnnNonRec id ann_rhs@(rhs_fvs, rhs)) body_fvs
-- We *can't* float into ok-for-speculation unlifted RHSs
-- But do float into join points
- [shared_binds, extra_binds, rhs_binds, body_binds]
- = sepBindsByDropPoint platform False
- [extra_fvs, rhs_fvs, body_fvs2]
- to_drop
+ (shared_binds, [rhs_binds, body_binds])
+ = sepBindsByDropPoint platform False to_drop
+ extra_fvs [rhs_fvs, body_fvs2]
-- Push rhs_binds into the right hand side of the binding
rhs' = fiRhs platform rhs_binds id ann_rhs
@@ -523,7 +565,7 @@ fiBind platform to_drop (AnnNonRec id ann_rhs@(rhs_fvs, rhs)) body_fvs
-- Don't forget the rule_fvs; the binding mentions them!
fiBind platform to_drop (AnnRec bindings) body_fvs
- = ( extra_binds ++ shared_binds
+ = ( shared_binds
, FB (mkDVarSet ids) rhs_fvs'
(FloatLet (Rec (fi_bind rhss_binds bindings)))
, body_binds )
@@ -537,17 +579,16 @@ fiBind platform to_drop (AnnRec bindings) body_fvs
unionDVarSets [ rhs_fvs | (bndr, (rhs_fvs, rhs)) <- bindings
, noFloatIntoRhs Recursive bndr rhs ]
- (shared_binds:extra_binds:body_binds:rhss_binds)
- = sepBindsByDropPoint platform False
- (extra_fvs:body_fvs:rhss_fvs)
- to_drop
+ (shared_binds, body_binds:rhss_binds)
+ = sepBindsByDropPoint platform False to_drop
+ extra_fvs (body_fvs:rhss_fvs)
rhs_fvs' = unionDVarSets rhss_fvs `unionDVarSet`
unionDVarSets (map floatedBindsFVs rhss_binds) `unionDVarSet`
rule_fvs -- Don't forget the rule variables!
-- Push rhs_binds into the right hand side of the binding
- fi_bind :: [FloatInBinds] -- one per "drop pt" conjured w/ fvs_of_rhss
+ fi_bind :: [RevFloatInBinds] -- One per "drop pt" conjured w/ fvs_of_rhss
-> [(Id, CoreExprWithFVs)]
-> [(Id, CoreExpr)]
@@ -556,7 +597,7 @@ fiBind platform to_drop (AnnRec bindings) body_fvs
| ((binder, rhs), to_drop) <- zipEqual "fi_bind" pairs to_drops ]
------------------
-fiRhs :: Platform -> FloatInBinds -> CoreBndr -> CoreExprWithFVs -> CoreExpr
+fiRhs :: Platform -> RevFloatInBinds -> CoreBndr -> CoreExprWithFVs -> CoreExpr
fiRhs platform to_drop bndr rhs
| Just join_arity <- isJoinId_maybe bndr
, let (bndrs, body) = collectNAnnBndrs join_arity rhs
@@ -656,68 +697,84 @@ point.
We have to maintain the order on these drop-point-related lists.
-}
--- pprFIB :: FloatInBinds -> SDoc
+-- pprFIB :: RevFloatInBinds -> SDoc
-- pprFIB fibs = text "FIB:" <+> ppr [b | FB _ _ b <- fibs]
sepBindsByDropPoint
:: Platform
- -> Bool -- True <=> is case expression
- -> [FreeVarSet] -- One set of FVs per drop point
- -- Always at least two long!
- -> FloatInBinds -- Candidate floaters
- -> [FloatInBinds] -- FIRST one is bindings which must not be floated
- -- inside any drop point; the rest correspond
- -- one-to-one with the input list of FV sets
+ -> Bool -- True <=> is case expression
+ -> RevFloatInBinds -- Candidate floaters
+ -> FreeVarSet -- here_fvs: if these vars are free in a binding,
+ -- don't float that binding inside any drop point
+ -> [FreeVarSet] -- fork_fvs: one set of FVs per drop point
+ -> ( RevFloatInBinds -- Bindings which must not be floated inside
+ , [RevFloatInBinds] ) -- Corresponds 1-1 with the input list of FV sets
-- Every input floater is returned somewhere in the result;
-- none are dropped, not even ones which don't seem to be
-- free in *any* of the drop-point fvs. Why? Because, for example,
-- a binding (let x = E in B) might have a specialised version of
-- x (say x') stored inside x, but x' isn't free in E or B.
+--
+-- The here_fvs argument is used for two things:
+-- * Avoid shadowing bugs: see Note [Shadowing and name capture]
+-- * Drop some of the bindings at the top, e.g. of an application
type DropBox = (FreeVarSet, FloatInBinds)
-sepBindsByDropPoint platform is_case drop_pts floaters
+dropBoxFloats :: DropBox -> RevFloatInBinds
+dropBoxFloats (_, floats) = reverse floats
+
+usedInDropBox :: DIdSet -> DropBox -> Bool
+usedInDropBox bndrs (db_fvs, _) = db_fvs `intersectsDVarSet` bndrs
+
+initDropBox :: DVarSet -> DropBox
+initDropBox fvs = (fvs, [])
+
+sepBindsByDropPoint platform is_case floaters here_fvs fork_fvs
| null floaters -- Shortcut common case
- = [] : [[] | _ <- drop_pts]
+ = ([], [[] | _ <- fork_fvs])
| otherwise
- = assert (drop_pts `lengthAtLeast` 2) $
- go floaters (map (\fvs -> (fvs, [])) (emptyDVarSet : drop_pts))
+ = go floaters (initDropBox here_fvs) (map initDropBox fork_fvs)
where
- n_alts = length drop_pts
+ n_alts = length fork_fvs
- go :: FloatInBinds -> [DropBox] -> [FloatInBinds]
- -- The *first* one in the argument list is the drop_here set
- -- The FloatInBinds in the lists are in the reverse of
- -- the normal FloatInBinds order; that is, they are the right way round!
+ go :: RevFloatInBinds -> DropBox -> [DropBox]
+ -> (RevFloatInBinds, [RevFloatInBinds])
+ -- The *first* one in the pair is the drop_here set
- go [] drop_boxes = map (reverse . snd) drop_boxes
+ go [] here_box fork_boxes
+ = (dropBoxFloats here_box, map dropBoxFloats fork_boxes)
- go (bind_w_fvs@(FB bndrs bind_fvs bind) : binds) drop_boxes@(here_box : fork_boxes)
- = go binds new_boxes
+ go (bind_w_fvs@(FB bndrs bind_fvs bind) : binds) here_box fork_boxes
+ | drop_here = go binds (insert here_box) fork_boxes
+ | otherwise = go binds here_box new_fork_boxes
where
-- "here" means the group of bindings dropped at the top of the fork
- (used_here : used_in_flags) = [ fvs `intersectsDVarSet` bndrs
- | (fvs, _) <- drop_boxes]
+ used_here = bndrs `usedInDropBox` here_box
+ used_in_flags = case fork_boxes of
+ [] -> []
+ [_] -> [True] -- Push all bindings into a single branch
+ -- No need to look at its free vars
+ _ -> map (bndrs `usedInDropBox`) fork_boxes
+ -- Short-cut for the singleton case;
+ -- used for lambdas and singleton cases
drop_here = used_here || cant_push
n_used_alts = count id used_in_flags -- returns number of Trues in list.
cant_push
- | is_case = n_used_alts == n_alts -- Used in all, don't push
- -- Remember n_alts > 1
+ | is_case = (n_alts > 1 && n_used_alts == n_alts)
+ -- Used in all, muliple branches, don't push
|| (n_used_alts > 1 && not (floatIsDupable platform bind))
-- floatIsDupable: see Note [Duplicating floats]
| otherwise = floatIsCase bind || n_used_alts > 1
-- floatIsCase: see Note [Floating primops]
- new_boxes | drop_here = (insert here_box : fork_boxes)
- | otherwise = (here_box : new_fork_boxes)
-
new_fork_boxes = zipWithEqual "FloatIn.sepBinds" insert_maybe
fork_boxes used_in_flags
@@ -727,8 +784,6 @@ sepBindsByDropPoint platform is_case drop_pts floaters
insert_maybe box True = insert box
insert_maybe box False = box
- go _ _ = panic "sepBindsByDropPoint/go"
-
{- Note [Duplicating floats]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -745,14 +800,14 @@ If the thing is used in all RHSs there is nothing gained,
so we don't duplicate then.
-}
-floatedBindsFVs :: FloatInBinds -> FreeVarSet
+floatedBindsFVs :: RevFloatInBinds -> FreeVarSet
floatedBindsFVs binds = mapUnionDVarSet fbFVs binds
fbFVs :: FloatInBind -> DVarSet
fbFVs (FB _ fvs _) = fvs
-wrapFloats :: FloatInBinds -> CoreExpr -> CoreExpr
--- Remember FloatInBinds is in *reverse* dependency order
+wrapFloats :: RevFloatInBinds -> CoreExpr -> CoreExpr
+-- Remember RevFloatInBinds is in *reverse* dependency order
wrapFloats [] e = e
wrapFloats (FB _ _ fl : bs) e = wrapFloats bs (wrapFloat fl e)