summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSimon Peyton Jones <simon.peytonjones@gmail.com>2022-12-23 14:53:08 +0000
committerSimon Peyton Jones <simon.peytonjones@gmail.com>2023-01-05 12:58:14 +0000
commit4e78c3bfffd105d1d14579598668a1415d1457eb (patch)
tree7a27a2cbc0e89e26db5246d2458201d645f2ad1f
parent00dc51060881df81258ba3b3bdf447294618a4de (diff)
downloadhaskell-wip/T22662.tar.gz
Make FloatIn robust to shadowingwip/T22662
This MR fixes #22622. See the new Note [Shadowing and name capture] I did a bit of refactoring in sepBindsByDropPoint too. The bug doesn't manifest in HEAD, but it did show up in 9.4, so we should backport this patch to 9.4
-rw-r--r--compiler/GHC/Core/Opt/FloatIn.hs261
-rw-r--r--testsuite/tests/simplCore/should_compile/T22662.hs6
-rw-r--r--testsuite/tests/simplCore/should_compile/all.T1
3 files changed, 165 insertions, 103 deletions
diff --git a/compiler/GHC/Core/Opt/FloatIn.hs b/compiler/GHC/Core/Opt/FloatIn.hs
index cf3ca726e4..2feef8a617 100644
--- a/compiler/GHC/Core/Opt/FloatIn.hs
+++ b/compiler/GHC/Core/Opt/FloatIn.hs
@@ -35,9 +35,12 @@ import GHC.Types.Var
import GHC.Types.Var.Set
import GHC.Utils.Misc
-import GHC.Utils.Panic
import GHC.Utils.Panic.Plain
+import GHC.Utils.Outputable
+
+import Data.List ( mapAccumL )
+
{-
Top-level interface function, @floatInwards@. Note that we do not
actually float any bindings downwards from the top-level.
@@ -124,7 +127,7 @@ the closure for a is not built.
************************************************************************
-}
-type FreeVarSet = DIdSet
+type FreeVarSet = DVarSet
type BoundVarSet = DIdSet
data FloatInBind = FB BoundVarSet FreeVarSet FloatBind
@@ -132,11 +135,17 @@ data FloatInBind = FB BoundVarSet FreeVarSet FloatBind
-- of recursive bindings, the set doesn't include the bound
-- variables.
-type FloatInBinds = [FloatInBind]
- -- In reverse dependency order (innermost binder first)
+type FloatInBinds = [FloatInBind] -- In normal dependency order
+ -- (outermost binder first)
+type RevFloatInBinds = [FloatInBind] -- In reverse dependency order
+ -- (innermost binder first)
+
+instance Outputable FloatInBind where
+ ppr (FB bvs fvs _) = text "FB" <> braces (sep [ text "bndrs =" <+> ppr bvs
+ , text "fvs =" <+> ppr fvs ])
fiExpr :: Platform
- -> FloatInBinds -- Binds we're trying to drop
+ -> RevFloatInBinds -- Binds we're trying to drop
-- as far "inwards" as possible
-> CoreExprWithFVs -- Input expr
-> CoreExpr -- Result
@@ -147,13 +156,12 @@ fiExpr _ to_drop (_, AnnType ty) = assert (null to_drop) $ Type ty
fiExpr _ to_drop (_, AnnVar v) = wrapFloats to_drop (Var v)
fiExpr _ to_drop (_, AnnCoercion co) = wrapFloats to_drop (Coercion co)
fiExpr platform to_drop (_, AnnCast expr (co_ann, co))
- = wrapFloats (drop_here ++ co_drop) $
+ = wrapFloats drop_here $
Cast (fiExpr platform e_drop expr) co
where
- [drop_here, e_drop, co_drop]
- = sepBindsByDropPoint platform False
- [freeVarsOf expr, freeVarsOfAnn co_ann]
- to_drop
+ (drop_here, [e_drop])
+ = sepBindsByDropPoint platform False to_drop
+ (freeVarsOfAnn co_ann) [freeVarsOf expr]
{-
Applications: we do float inside applications, mainly because we
@@ -162,7 +170,7 @@ pull out any silly ones.
-}
fiExpr platform to_drop ann_expr@(_,AnnApp {})
- = wrapFloats drop_here $ wrapFloats extra_drop $
+ = wrapFloats drop_here $
mkTicks ticks $
mkApps (fiExpr platform fun_drop ann_fun)
(zipWithEqual "fiExpr" (fiExpr platform) arg_drops ann_args)
@@ -170,21 +178,19 @@ fiExpr platform to_drop ann_expr@(_,AnnApp {})
-- length ann_args = length arg_fvs = length arg_drops
where
(ann_fun, ann_args, ticks) = collectAnnArgsTicks tickishFloatable ann_expr
- fun_ty = exprType (deAnnotate ann_fun)
fun_fvs = freeVarsOf ann_fun
- arg_fvs = map freeVarsOf ann_args
- (drop_here : extra_drop : fun_drop : arg_drops)
- = sepBindsByDropPoint platform False
- (extra_fvs : fun_fvs : arg_fvs)
- to_drop
+ (drop_here, fun_drop : arg_drops)
+ = sepBindsByDropPoint platform False to_drop
+ here_fvs (fun_fvs : arg_fvs)
+
-- Shortcut behaviour: if to_drop is empty,
-- sepBindsByDropPoint returns a suitable bunch of empty
-- lists without evaluating extra_fvs, and hence without
-- peering into each argument
- (_, extra_fvs) = foldl' add_arg (fun_ty, extra_fvs0) ann_args
- extra_fvs0 = case ann_fun of
+ (here_fvs, arg_fvs) = mapAccumL add_arg here_fvs0 ann_args
+ here_fvs0 = case ann_fun of
(_, AnnVar _) -> fun_fvs
_ -> emptyDVarSet
-- Don't float the binding for f into f x y z; see Note [Join points]
@@ -192,14 +198,11 @@ fiExpr platform to_drop ann_expr@(_,AnnApp {})
-- join point, floating it in isn't especially harmful but it's
-- useless since the simplifier will immediately float it back out.)
- add_arg :: (Type,FreeVarSet) -> CoreExprWithFVs -> (Type,FreeVarSet)
- add_arg (fun_ty, extra_fvs) (_, AnnType ty)
- = (piResultTy fun_ty ty, extra_fvs)
- add_arg (fun_ty, extra_fvs) (arg_fvs, arg)
- | noFloatIntoArg arg
- = (funResultTy fun_ty, extra_fvs `unionDVarSet` arg_fvs)
- | otherwise
- = (funResultTy fun_ty, extra_fvs)
+ add_arg :: FreeVarSet -> CoreExprWithFVs -> (FreeVarSet,FreeVarSet)
+ -- We can't float into some arguments, so put them into the here_fvs
+ add_arg here_fvs (arg_fvs, arg)
+ | noFloatIntoArg arg = (here_fvs `unionDVarSet` arg_fvs, emptyDVarSet)
+ | otherwise = (here_fvs, arg_fvs)
{- Note [Dead bindings]
~~~~~~~~~~~~~~~~~~~~~~~
@@ -272,7 +275,6 @@ it's non-recursive, so we float only into non-recursive join points.)
Urk! if all are tyvars, and we don't float in, we may miss an
opportunity to float inside a nested case branch
-
Note [Floating coercions]
~~~~~~~~~~~~~~~~~~~~~~~~~
We could, in principle, have a coercion binding like
@@ -292,6 +294,36 @@ of the types of all the drop points involved. If any of the floaters
bind a coercion variable mentioned in any of the types, that binder must
be dropped right away.
+Note [Shadowing and name capture]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Suppose we have
+ let x = y+1 in
+ case p of
+ (y:ys) -> ...x...
+ [] -> blah
+It is obviously bogus for FloatIn to transform to
+ case p of
+ (y:ys) -> ...(let x = y+1 in x)...
+ [] -> blah
+because the y is captured. This doesn't happen much, because shadowing is
+rare, but it did happen in #22662.
+
+One solution would be to clone as we go. But a simpler one is this:
+
+ at a binding site (like that for (y:ys) above), abandon float-in for
+ any floating bindings that mention the binders (y, ys in this case)
+
+We achieve that by calling sepBindsByDropPoint with the binders in
+the "used-here" set:
+
+* In fiExpr (AnnLam ...). For the body there is no need to delete
+ the lambda-binders from the body_fvs, because any bindings that
+ mention these binders will be dropped here anyway.
+
+* In fiExpr (AnnCase ...). Remember to include the case_bndr in the
+ binders. Again, no need to delete the alt binders from the rhs
+ free vars, beause any bindings mentioning them will be dropped
+ here unconditionally.
-}
fiExpr platform to_drop lam@(_, AnnLam _ _)
@@ -300,10 +332,17 @@ fiExpr platform to_drop lam@(_, AnnLam _ _)
= wrapFloats to_drop (mkLams bndrs (fiExpr platform [] body))
| otherwise -- Float inside
- = mkLams bndrs (fiExpr platform to_drop body)
+ = wrapFloats drop_here $
+ mkLams bndrs (fiExpr platform body_drop body)
where
(bndrs, body) = collectAnnBndrs lam
+ body_fvs = freeVarsOf body
+
+ -- Why sepBindsByDropPoint? Because of potential capture
+ -- See Note [Shadowing and name capture]
+ (drop_here, [body_drop]) = sepBindsByDropPoint platform False to_drop
+ (mkDVarSet bndrs) [body_fvs]
{-
We don't float lets inwards past an SCC.
@@ -443,16 +482,16 @@ fiExpr platform to_drop (_, AnnCase scrut case_bndr _ [AnnAlt con alt_bndrs rhs]
= wrapFloats shared_binds $
fiExpr platform (case_float : rhs_binds) rhs
where
- case_float = FB (mkDVarSet (case_bndr : alt_bndrs)) scrut_fvs
+ case_float = FB all_bndrs scrut_fvs
(FloatCase scrut' case_bndr con alt_bndrs)
scrut' = fiExpr platform scrut_binds scrut
- rhs_fvs = freeVarsOf rhs `delDVarSetList` (case_bndr : alt_bndrs)
- scrut_fvs = freeVarsOf scrut
+ rhs_fvs = freeVarsOf rhs -- No need to delete alt_bndrs
+ scrut_fvs = freeVarsOf scrut -- See Note [Shadowing and name capture]
+ all_bndrs = mkDVarSet alt_bndrs `extendDVarSet` case_bndr
- [shared_binds, scrut_binds, rhs_binds]
- = sepBindsByDropPoint platform False
- [scrut_fvs, rhs_fvs]
- to_drop
+ (shared_binds, [scrut_binds, rhs_binds])
+ = sepBindsByDropPoint platform False to_drop
+ all_bndrs [scrut_fvs, rhs_fvs]
fiExpr platform to_drop (_, AnnCase scrut case_bndr ty alts)
= wrapFloats drop_here1 $
@@ -462,39 +501,43 @@ fiExpr platform to_drop (_, AnnCase scrut case_bndr ty alts)
-- use zipWithEqual, we should have length alts_drops_s = length alts
where
-- Float into the scrut and alts-considered-together just like App
- [drop_here1, scrut_drops, alts_drops]
- = sepBindsByDropPoint platform False
- [scrut_fvs, all_alts_fvs]
- to_drop
+ (drop_here1, [scrut_drops, alts_drops])
+ = sepBindsByDropPoint platform False to_drop
+ all_alt_bndrs [scrut_fvs, all_alt_fvs]
+ -- all_alt_bndrs: see Note [Shadowing and name capture]
-- Float into the alts with the is_case flag set
- (drop_here2 : alts_drops_s)
- | [ _ ] <- alts = [] : [alts_drops]
- | otherwise = sepBindsByDropPoint platform True alts_fvs alts_drops
-
- scrut_fvs = freeVarsOf scrut
- alts_fvs = map alt_fvs alts
- all_alts_fvs = unionDVarSets alts_fvs
- alt_fvs (AnnAlt _con args rhs)
- = foldl' delDVarSet (freeVarsOf rhs) (case_bndr:args)
- -- Delete case_bndr and args from free vars of rhs
- -- to get free vars of alt
+ (drop_here2, alts_drops_s)
+ = sepBindsByDropPoint platform True alts_drops emptyDVarSet alts_fvs
+
+ scrut_fvs = freeVarsOf scrut
+
+ all_alt_bndrs = foldr (unionDVarSet . ann_alt_bndrs) (unitDVarSet case_bndr) alts
+ ann_alt_bndrs (AnnAlt _ bndrs _) = mkDVarSet bndrs
+
+ alts_fvs :: [DVarSet]
+ alts_fvs = [freeVarsOf rhs | AnnAlt _ _ rhs <- alts]
+ -- No need to delete binders
+ -- See Note [Shadowing and name capture]
+
+ all_alt_fvs :: DVarSet
+ all_alt_fvs = foldr unionDVarSet (unitDVarSet case_bndr) alts_fvs
fi_alt to_drop (AnnAlt con args rhs) = Alt con args (fiExpr platform to_drop rhs)
------------------
fiBind :: Platform
- -> FloatInBinds -- Binds we're trying to drop
- -- as far "inwards" as possible
- -> CoreBindWithFVs -- Input binding
- -> DVarSet -- Free in scope of binding
- -> ( FloatInBinds -- Land these before
- , FloatInBind -- The binding itself
- , FloatInBinds) -- Land these after
+ -> RevFloatInBinds -- Binds we're trying to drop
+ -- as far "inwards" as possible
+ -> CoreBindWithFVs -- Input binding
+ -> DVarSet -- Free in scope of binding
+ -> ( RevFloatInBinds -- Land these before
+ , FloatInBind -- The binding itself
+ , RevFloatInBinds) -- Land these after
fiBind platform to_drop (AnnNonRec id ann_rhs@(rhs_fvs, rhs)) body_fvs
- = ( extra_binds ++ shared_binds -- Land these before
- -- See Note [extra_fvs (1)] and Note [extra_fvs (2)]
+ = ( shared_binds -- Land these before
+ -- See Note [extra_fvs (1)] and Note [extra_fvs (2)]
, FB (unitDVarSet id) rhs_fvs' -- The new binding itself
(FloatLet (NonRec id rhs'))
, body_binds ) -- Land these after
@@ -512,10 +555,9 @@ fiBind platform to_drop (AnnNonRec id ann_rhs@(rhs_fvs, rhs)) body_fvs
-- We *can't* float into ok-for-speculation unlifted RHSs
-- But do float into join points
- [shared_binds, extra_binds, rhs_binds, body_binds]
- = sepBindsByDropPoint platform False
- [extra_fvs, rhs_fvs, body_fvs2]
- to_drop
+ (shared_binds, [rhs_binds, body_binds])
+ = sepBindsByDropPoint platform False to_drop
+ extra_fvs [rhs_fvs, body_fvs2]
-- Push rhs_binds into the right hand side of the binding
rhs' = fiRhs platform rhs_binds id ann_rhs
@@ -523,7 +565,7 @@ fiBind platform to_drop (AnnNonRec id ann_rhs@(rhs_fvs, rhs)) body_fvs
-- Don't forget the rule_fvs; the binding mentions them!
fiBind platform to_drop (AnnRec bindings) body_fvs
- = ( extra_binds ++ shared_binds
+ = ( shared_binds
, FB (mkDVarSet ids) rhs_fvs'
(FloatLet (Rec (fi_bind rhss_binds bindings)))
, body_binds )
@@ -537,17 +579,16 @@ fiBind platform to_drop (AnnRec bindings) body_fvs
unionDVarSets [ rhs_fvs | (bndr, (rhs_fvs, rhs)) <- bindings
, noFloatIntoRhs Recursive bndr rhs ]
- (shared_binds:extra_binds:body_binds:rhss_binds)
- = sepBindsByDropPoint platform False
- (extra_fvs:body_fvs:rhss_fvs)
- to_drop
+ (shared_binds, body_binds:rhss_binds)
+ = sepBindsByDropPoint platform False to_drop
+ extra_fvs (body_fvs:rhss_fvs)
rhs_fvs' = unionDVarSets rhss_fvs `unionDVarSet`
unionDVarSets (map floatedBindsFVs rhss_binds) `unionDVarSet`
rule_fvs -- Don't forget the rule variables!
-- Push rhs_binds into the right hand side of the binding
- fi_bind :: [FloatInBinds] -- one per "drop pt" conjured w/ fvs_of_rhss
+ fi_bind :: [RevFloatInBinds] -- One per "drop pt" conjured w/ fvs_of_rhss
-> [(Id, CoreExprWithFVs)]
-> [(Id, CoreExpr)]
@@ -556,7 +597,7 @@ fiBind platform to_drop (AnnRec bindings) body_fvs
| ((binder, rhs), to_drop) <- zipEqual "fi_bind" pairs to_drops ]
------------------
-fiRhs :: Platform -> FloatInBinds -> CoreBndr -> CoreExprWithFVs -> CoreExpr
+fiRhs :: Platform -> RevFloatInBinds -> CoreBndr -> CoreExprWithFVs -> CoreExpr
fiRhs platform to_drop bndr rhs
| Just join_arity <- isJoinId_maybe bndr
, let (bndrs, body) = collectNAnnBndrs join_arity rhs
@@ -656,68 +697,84 @@ point.
We have to maintain the order on these drop-point-related lists.
-}
--- pprFIB :: FloatInBinds -> SDoc
+-- pprFIB :: RevFloatInBinds -> SDoc
-- pprFIB fibs = text "FIB:" <+> ppr [b | FB _ _ b <- fibs]
sepBindsByDropPoint
:: Platform
- -> Bool -- True <=> is case expression
- -> [FreeVarSet] -- One set of FVs per drop point
- -- Always at least two long!
- -> FloatInBinds -- Candidate floaters
- -> [FloatInBinds] -- FIRST one is bindings which must not be floated
- -- inside any drop point; the rest correspond
- -- one-to-one with the input list of FV sets
+ -> Bool -- True <=> is case expression
+ -> RevFloatInBinds -- Candidate floaters
+ -> FreeVarSet -- here_fvs: if these vars are free in a binding,
+ -- don't float that binding inside any drop point
+ -> [FreeVarSet] -- fork_fvs: one set of FVs per drop point
+ -> ( RevFloatInBinds -- Bindings which must not be floated inside
+ , [RevFloatInBinds] ) -- Corresponds 1-1 with the input list of FV sets
-- Every input floater is returned somewhere in the result;
-- none are dropped, not even ones which don't seem to be
-- free in *any* of the drop-point fvs. Why? Because, for example,
-- a binding (let x = E in B) might have a specialised version of
-- x (say x') stored inside x, but x' isn't free in E or B.
+--
+-- The here_fvs argument is used for two things:
+-- * Avoid shadowing bugs: see Note [Shadowing and name capture]
+-- * Drop some of the bindings at the top, e.g. of an application
type DropBox = (FreeVarSet, FloatInBinds)
-sepBindsByDropPoint platform is_case drop_pts floaters
+dropBoxFloats :: DropBox -> RevFloatInBinds
+dropBoxFloats (_, floats) = reverse floats
+
+usedInDropBox :: DIdSet -> DropBox -> Bool
+usedInDropBox bndrs (db_fvs, _) = db_fvs `intersectsDVarSet` bndrs
+
+initDropBox :: DVarSet -> DropBox
+initDropBox fvs = (fvs, [])
+
+sepBindsByDropPoint platform is_case floaters here_fvs fork_fvs
| null floaters -- Shortcut common case
- = [] : [[] | _ <- drop_pts]
+ = ([], [[] | _ <- fork_fvs])
| otherwise
- = assert (drop_pts `lengthAtLeast` 2) $
- go floaters (map (\fvs -> (fvs, [])) (emptyDVarSet : drop_pts))
+ = go floaters (initDropBox here_fvs) (map initDropBox fork_fvs)
where
- n_alts = length drop_pts
+ n_alts = length fork_fvs
- go :: FloatInBinds -> [DropBox] -> [FloatInBinds]
- -- The *first* one in the argument list is the drop_here set
- -- The FloatInBinds in the lists are in the reverse of
- -- the normal FloatInBinds order; that is, they are the right way round!
+ go :: RevFloatInBinds -> DropBox -> [DropBox]
+ -> (RevFloatInBinds, [RevFloatInBinds])
+ -- The *first* one in the pair is the drop_here set
- go [] drop_boxes = map (reverse . snd) drop_boxes
+ go [] here_box fork_boxes
+ = (dropBoxFloats here_box, map dropBoxFloats fork_boxes)
- go (bind_w_fvs@(FB bndrs bind_fvs bind) : binds) drop_boxes@(here_box : fork_boxes)
- = go binds new_boxes
+ go (bind_w_fvs@(FB bndrs bind_fvs bind) : binds) here_box fork_boxes
+ | drop_here = go binds (insert here_box) fork_boxes
+ | otherwise = go binds here_box new_fork_boxes
where
-- "here" means the group of bindings dropped at the top of the fork
- (used_here : used_in_flags) = [ fvs `intersectsDVarSet` bndrs
- | (fvs, _) <- drop_boxes]
+ used_here = bndrs `usedInDropBox` here_box
+ used_in_flags = case fork_boxes of
+ [] -> []
+ [_] -> [True] -- Push all bindings into a single branch
+ -- No need to look at its free vars
+ _ -> map (bndrs `usedInDropBox`) fork_boxes
+ -- Short-cut for the singleton case;
+ -- used for lambdas and singleton cases
drop_here = used_here || cant_push
n_used_alts = count id used_in_flags -- returns number of Trues in list.
cant_push
- | is_case = n_used_alts == n_alts -- Used in all, don't push
- -- Remember n_alts > 1
+ | is_case = (n_alts > 1 && n_used_alts == n_alts)
+ -- Used in all, muliple branches, don't push
|| (n_used_alts > 1 && not (floatIsDupable platform bind))
-- floatIsDupable: see Note [Duplicating floats]
| otherwise = floatIsCase bind || n_used_alts > 1
-- floatIsCase: see Note [Floating primops]
- new_boxes | drop_here = (insert here_box : fork_boxes)
- | otherwise = (here_box : new_fork_boxes)
-
new_fork_boxes = zipWithEqual "FloatIn.sepBinds" insert_maybe
fork_boxes used_in_flags
@@ -727,8 +784,6 @@ sepBindsByDropPoint platform is_case drop_pts floaters
insert_maybe box True = insert box
insert_maybe box False = box
- go _ _ = panic "sepBindsByDropPoint/go"
-
{- Note [Duplicating floats]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -745,14 +800,14 @@ If the thing is used in all RHSs there is nothing gained,
so we don't duplicate then.
-}
-floatedBindsFVs :: FloatInBinds -> FreeVarSet
+floatedBindsFVs :: RevFloatInBinds -> FreeVarSet
floatedBindsFVs binds = mapUnionDVarSet fbFVs binds
fbFVs :: FloatInBind -> DVarSet
fbFVs (FB _ fvs _) = fvs
-wrapFloats :: FloatInBinds -> CoreExpr -> CoreExpr
--- Remember FloatInBinds is in *reverse* dependency order
+wrapFloats :: RevFloatInBinds -> CoreExpr -> CoreExpr
+-- Remember RevFloatInBinds is in *reverse* dependency order
wrapFloats [] e = e
wrapFloats (FB _ _ fl : bs) e = wrapFloats bs (wrapFloat fl e)
diff --git a/testsuite/tests/simplCore/should_compile/T22662.hs b/testsuite/tests/simplCore/should_compile/T22662.hs
new file mode 100644
index 0000000000..101603634a
--- /dev/null
+++ b/testsuite/tests/simplCore/should_compile/T22662.hs
@@ -0,0 +1,6 @@
+module T22662 where
+
+import Data.Set
+
+foo x = sequence_ [ f y | y <- x ]
+ where f _ = return (fromList [0])
diff --git a/testsuite/tests/simplCore/should_compile/all.T b/testsuite/tests/simplCore/should_compile/all.T
index edbefd6145..c5f63d6e7a 100644
--- a/testsuite/tests/simplCore/should_compile/all.T
+++ b/testsuite/tests/simplCore/should_compile/all.T
@@ -461,3 +461,4 @@ test('T21476', normal, compile, [''])
test('T22272', normal, multimod_compile, ['T22272', '-O -fexpose-all-unfoldings -fno-omit-interface-pragmas -fno-ignore-interface-pragmas'])
test('T22459', normal, compile, [''])
test('T22623', normal, multimod_compile, ['T22623', '-O -v0'])
+test('T22662', normal, compile, [''])