diff options
author | simonpj@microsoft.com <unknown> | 2009-10-23 16:15:51 +0000 |
---|---|---|
committer | simonpj@microsoft.com <unknown> | 2009-10-23 16:15:51 +0000 |
commit | c43c981705ec33da92a9ce91eb90f2ecf00be9fe (patch) | |
tree | ef6725e233f3481f5121561671370f818c6ec2fa /compiler/specialise | |
parent | 0cffd31b0f25c2a31ed6eff2c0c0b1b1a8a8d507 (diff) | |
download | haskell-c43c981705ec33da92a9ce91eb90f2ecf00be9fe.tar.gz |
Fix Trac #3591: very tricky specialiser bug
There was a subtle bug in the interation of specialisation and floating,
described in Note [Specialisation of dictionary functions].
The net effect was to create a loop where none existed before; plain wrong.
In fixing it, I did quite a bit of house-cleaning in the specialiser, and
added a lot more comments. It's tricky, alas.
Diffstat (limited to 'compiler/specialise')
-rw-r--r-- | compiler/specialise/Specialise.lhs | 488 |
1 files changed, 308 insertions, 180 deletions
diff --git a/compiler/specialise/Specialise.lhs b/compiler/specialise/Specialise.lhs index 64d0cdd7a3..590e689f4a 100644 --- a/compiler/specialise/Specialise.lhs +++ b/compiler/specialise/Specialise.lhs @@ -14,30 +14,17 @@ module Specialise ( specProgram ) where #include "HsVersions.h" -import Id ( Id, idName, idType, mkUserLocal, idCoreRules, idUnfolding, - idInlineActivation, setInlineActivation, setIdUnfolding, - isLocalId, isDataConWorkId, idArity, setIdArity ) -import TcType ( Type, mkTyVarTy, tcSplitSigmaTy, - tyVarsOfTypes, tyVarsOfTheta, isClassPred, - tcCmpType, isUnLiftedType - ) -import CoreSubst ( Subst, mkEmptySubst, extendTvSubstList, lookupIdSubst, - substBndr, substBndrs, substTy, substInScope, - cloneIdBndr, cloneIdBndrs, cloneRecIdBndrs, - extendIdSubst - ) +import Id +import TcType +import CoreSubst import CoreUnfold ( mkUnfolding ) -import Var ( DictId ) import VarSet import VarEnv import CoreSyn import Rules import CoreUtils ( exprIsTrivial, applyTypeToArgs, mkPiTypes ) import CoreFVs ( exprFreeVars, exprsFreeVars, idFreeVars ) -import UniqSupply ( UniqSupply, - UniqSM, initUs_, - MonadUnique(..) - ) +import UniqSupply ( UniqSupply, UniqSM, initUs_, MonadUnique(..) ) import Name import MkId ( voidArgId, realWorldPrimId ) import FiniteMap @@ -575,8 +562,9 @@ Hence, the invariant is this: \begin{code} specProgram :: UniqSupply -> [CoreBind] -> [CoreBind] -specProgram us binds = initSM us (do (binds', uds') <- go binds - return (dumpAllDictBinds uds' binds')) +specProgram us binds = initSM us $ + do { (binds', uds') <- go binds + ; return (wrapDictBinds (ud_binds uds') binds') } where -- We need to start with a Subst that knows all the things -- that are in scope, so that the substitution engine doesn't @@ -609,12 +597,12 @@ specExpr :: Subst -> CoreExpr -> SpecM (CoreExpr, UsageDetails) -- the RHS of specialised bindings (no type-let!) ---------------- First the easy cases -------------------- -specExpr subst (Type ty) = return (Type (substTy subst ty), emptyUDs) +specExpr subst (Type ty) = return (Type (CoreSubst.substTy subst ty), emptyUDs) specExpr subst (Var v) = return (specVar subst v, emptyUDs) specExpr _ (Lit lit) = return (Lit lit, emptyUDs) specExpr subst (Cast e co) = do (e', uds) <- specExpr subst e - return ((Cast e' (substTy subst co)), uds) + return ((Cast e' (CoreSubst.substTy subst co)), uds) specExpr subst (Note note body) = do (body', uds) <- specExpr subst body return (Note (specNote subst note) body', uds) @@ -636,8 +624,8 @@ specExpr subst expr@(App {}) ---------------- Lambda/case require dumping of usage details -------------------- specExpr subst e@(Lam _ _) = do (body', uds) <- specExpr subst' body - let (filtered_uds, body'') = dumpUDs bndrs' uds body' - return (mkLams bndrs' body'', filtered_uds) + let (free_uds, dumped_dbs) = dumpUDs bndrs' uds + return (mkLams bndrs' (wrapDictBindsE dumped_dbs body'), free_uds) where (bndrs, body) = collectBinders e (subst', bndrs') = substBndrs subst bndrs @@ -647,15 +635,16 @@ specExpr subst e@(Lam _ _) = do specExpr subst (Case scrut case_bndr ty alts) = do (scrut', uds_scrut) <- specExpr subst scrut (alts', uds_alts) <- mapAndCombineSM spec_alt alts - return (Case scrut' case_bndr' (substTy subst ty) alts', uds_scrut `plusUDs` uds_alts) + return (Case scrut' case_bndr' (CoreSubst.substTy subst ty) alts', + uds_scrut `plusUDs` uds_alts) where (subst_alt, case_bndr') = substBndr subst case_bndr -- No need to clone case binder; it can't float like a let(rec) spec_alt (con, args, rhs) = do (rhs', uds) <- specExpr subst_rhs rhs - let (uds', rhs'') = dumpUDs args uds rhs' - return ((con, args', rhs''), uds') + let (free_uds, dumped_dbs) = dumpUDs args' uds + return ((con, args', wrapDictBindsE dumped_dbs rhs'), free_uds) where (subst_rhs, args') = substBndrs subst_alt args @@ -691,92 +680,72 @@ specBind :: Subst -- Use this for RHSs -> SpecM ([CoreBind], -- New bindings UsageDetails) -- And info to pass upstream -specBind rhs_subst bind body_uds - = do { (bind', bind_uds) <- specBindItself rhs_subst bind (calls body_uds) - ; return (finishSpecBind bind' bind_uds body_uds) } - -finishSpecBind :: CoreBind -> UsageDetails -> UsageDetails -> ([CoreBind], UsageDetails) -finishSpecBind bind - (MkUD { dict_binds = rhs_dbs, calls = rhs_calls, ud_fvs = rhs_fvs }) - (MkUD { dict_binds = body_dbs, calls = body_calls, ud_fvs = body_fvs }) - | not (mkVarSet bndrs `intersectsVarSet` all_fvs) - -- Common case 1: the bound variables are not - -- mentioned in the dictionary bindings - = ([bind], MkUD { dict_binds = body_dbs `unionBags` rhs_dbs - -- It's important that the `unionBags` is this way round, - -- because body_uds may bind dictionaries that are - -- used in the calls passed to specDefn. So the - -- dictionary bindings in rhs_uds may mention - -- dictionaries bound in body_uds. - , calls = all_calls - , ud_fvs = all_fvs }) - - | case bind of { NonRec {} -> True; Rec {} -> False } - -- Common case 2: no specialisation happened, and binding - -- is non-recursive. But the binding may be - -- mentioned in body_dbs, so we should put it first - = ([], MkUD { dict_binds = rhs_dbs `unionBags` ((bind, b_fvs) `consBag` body_dbs) - , calls = all_calls - , ud_fvs = all_fvs `unionVarSet` b_fvs }) - - | otherwise -- General case: make a huge Rec (sigh) - = ([], MkUD { dict_binds = unitBag (Rec all_db_prs, all_db_fvs) - , calls = all_calls - , ud_fvs = all_fvs `unionVarSet` b_fvs }) - where - all_fvs = rhs_fvs `unionVarSet` body_fvs - all_calls = zapCalls bndrs (rhs_calls `unionCalls` body_calls) +-- Returned UsageDetails: +-- No calls for binders of this bind +specBind rhs_subst (NonRec fn rhs) body_uds + = do { (rhs', rhs_uds) <- specExpr rhs_subst rhs + ; (fn', spec_defns, body_uds1) <- specDefn rhs_subst body_uds fn rhs - bndrs = bindersOf bind - b_fvs = bind_fvs bind + ; let pairs = spec_defns ++ [(fn', rhs')] + -- fn' mentions the spec_defns in its rules, + -- so put the latter first - (all_db_prs, all_db_fvs) = add (bind, b_fvs) $ - foldrBag add ([], emptyVarSet) $ - rhs_dbs `unionBags` body_dbs - add (NonRec b r, b_fvs) (prs, fvs) = ((b,r) : prs, b_fvs `unionVarSet` fvs) - add (Rec b_prs, b_fvs) (prs, fvs) = (b_prs ++ prs, b_fvs `unionVarSet` fvs) + combined_uds = body_uds1 `plusUDs` rhs_uds + -- This way round a call in rhs_uds of a function f + -- at type T will override a call of f at T in body_uds1; and + -- that is good because it'll tend to keep "earlier" calls + -- See Note [Specialisation of dictionary functions] ---------------------------- -specBindItself :: Subst -> CoreBind -> CallDetails -> SpecM (CoreBind, UsageDetails) - --- specBindItself deals with the RHS, specialising it according --- to the calls found in the body (if any) -specBindItself rhs_subst (NonRec fn rhs) call_info - = do { (rhs', rhs_uds) <- specExpr rhs_subst rhs -- Do RHS of original fn - ; (fn', spec_defns, spec_uds) <- specDefn rhs_subst call_info fn rhs - ; if null spec_defns then - return (NonRec fn rhs', rhs_uds) - else - return (Rec ((fn',rhs') : spec_defns), rhs_uds `plusUDs` spec_uds) } - -- bndr' mentions the spec_defns in its SpecEnv - -- Not sure why we couln't just put the spec_defns first - -specBindItself rhs_subst (Rec pairs) call_info + (free_uds, dump_dbs, float_all) = dumpBindUDs [fn] combined_uds + -- See Note [From non-recursive to recursive] + + final_binds | isEmptyBag dump_dbs = [NonRec b r | (b,r) <- pairs] + | otherwise = [Rec (flattenDictBinds dump_dbs pairs)] + + ; if float_all then + -- Rather than discard the calls mentioning the bound variables + -- we float this binding along with the others + return ([], free_uds `snocDictBinds` final_binds) + else + -- No call in final_uds mentions bound variables, + -- so we can just leave the binding here + return (final_binds, free_uds) } + + +specBind rhs_subst (Rec pairs) body_uds -- Note [Specialising a recursive group] = do { let (bndrs,rhss) = unzip pairs ; (rhss', rhs_uds) <- mapAndCombineSM (specExpr rhs_subst) rhss - ; let all_calls = call_info `unionCalls` calls rhs_uds - ; (bndrs1, spec_defns1, spec_uds1) <- specDefns rhs_subst all_calls pairs - - ; if null spec_defns1 then -- Common case: no specialisation - return (Rec (bndrs `zip` rhss'), rhs_uds) - else do -- Specialisation occurred; do it again - { (bndrs2, spec_defns2, spec_uds2) <- - -- pprTrace "specB" (ppr bndrs $$ ppr rhs_uds) $ - specDefns rhs_subst (calls spec_uds1) (bndrs1 `zip` rhss) - - ; let all_defns = spec_defns1 ++ spec_defns2 ++ zip bndrs2 rhss' + ; let scope_uds = body_uds `plusUDs` rhs_uds + -- Includes binds and calls arising from rhss + + ; (bndrs1, spec_defns1, uds1) <- specDefns rhs_subst scope_uds pairs + + ; (bndrs3, spec_defns3, uds3) + <- if null spec_defns1 -- Common case: no specialisation + then return (bndrs1, [], uds1) + else do { -- Specialisation occurred; do it again + (bndrs2, spec_defns2, uds2) + <- specDefns rhs_subst uds1 (bndrs1 `zip` rhss) + ; return (bndrs2, spec_defns2 ++ spec_defns1, uds2) } + + ; let (final_uds, dumped_dbs, float_all) = dumpBindUDs bndrs uds3 + bind = Rec (flattenDictBinds dumped_dbs $ + spec_defns3 ++ zip bndrs3 rhss') - ; return (Rec all_defns, rhs_uds `plusUDs` spec_uds1 `plusUDs` spec_uds2) } } + ; if float_all then + return ([], final_uds `snocDictBind` bind) + else + return ([bind], final_uds) } --------------------------- specDefns :: Subst - -> CallDetails -- Info on how it is used in its scope - -> [(Id,CoreExpr)] -- The things being bound and their un-processed RHS - -> SpecM ([Id], -- Original Ids with RULES added - [(Id,CoreExpr)], -- Extra, specialised bindings - UsageDetails) -- Stuff to fling upwards from the specialised versions + -> UsageDetails -- Info on how it is used in its scope + -> [(Id,CoreExpr)] -- The things being bound and their un-processed RHS + -> SpecM ([Id], -- Original Ids with RULES added + [(Id,CoreExpr)], -- Extra, specialised bindings + UsageDetails) -- Stuff to fling upwards from the specialised versions -- Specialise a list of bindings (the contents of a Rec), but flowing usages -- upwards binding by binding. Example: { f = ...g ...; g = ...f .... } @@ -784,25 +753,22 @@ specDefns :: Subst -- in turn generates a specialised call for 'f', we catch that in this one sweep. -- But not vice versa (it's a fixpoint problem). -specDefns _subst _call_info [] - = return ([], [], emptyUDs) -specDefns subst call_info ((bndr,rhs):pairs) - = do { (bndrs', spec_defns, spec_uds) <- specDefns subst call_info pairs - ; let all_calls = call_info `unionCalls` calls spec_uds - ; (bndr', spec_defns1, spec_uds1) <- specDefn subst all_calls bndr rhs - ; return (bndr' : bndrs', - spec_defns1 ++ spec_defns, - spec_uds1 `plusUDs` spec_uds) } +specDefns _subst uds [] + = return ([], [], uds) +specDefns subst uds ((bndr,rhs):pairs) + = do { (bndrs1, spec_defns1, uds1) <- specDefns subst uds pairs + ; (bndr1, spec_defns2, uds2) <- specDefn subst uds1 bndr rhs + ; return (bndr1 : bndrs1, spec_defns1 ++ spec_defns2, uds2) } --------------------------- specDefn :: Subst - -> CallDetails -- Info on how it is used in its scope + -> UsageDetails -- Info on how it is used in its scope -> Id -> CoreExpr -- The thing being bound and its un-processed RHS -> SpecM (Id, -- Original Id with added RULES [(Id,CoreExpr)], -- Extra, specialised bindings UsageDetails) -- Stuff to fling upwards from the specialised versions -specDefn subst calls fn rhs +specDefn subst body_uds fn rhs -- The first case is the interesting one | rhs_tyvars `lengthIs` n_tyvars -- Rhs of fn's defn has right number of big lambdas && rhs_ids `lengthAtLeast` n_dicts -- and enough dict args @@ -816,12 +782,20 @@ specDefn subst calls fn rhs stuff <- mapM spec_call calls_for_me ; let (spec_defns, spec_uds, spec_rules) = unzip3 (catMaybes stuff) fn' = addIdSpecialisations fn spec_rules - ; return (fn', spec_defns, plusUDList spec_uds) } + final_uds = body_uds_without_me `plusUDs` plusUDList spec_uds + -- It's important that the `plusUDs` is this way + -- round, because body_uds_without_me may bind + -- dictionaries that are used in calls_for_me passed + -- to specDefn. So the dictionary bindings in + -- spec_uds may mention dictionaries bound in + -- body_uds_without_me + + ; return (fn', spec_defns, final_uds) } | otherwise -- No calls or RHS doesn't fit our preconceptions = WARN( notNull calls_for_me, ptext (sLit "Missed specialisation opportunity for") <+> ppr fn ) -- Note [Specialisation shape] - return (fn, [], emptyUDs) + return (fn, [], body_uds_without_me) where fn_type = idType fn @@ -831,6 +805,8 @@ specDefn subst calls fn rhs n_dicts = length theta inline_act = idInlineActivation fn + (body_uds_without_me, calls_for_me) = callsForMe fn body_uds + -- It's important that we "see past" any INLINE pragma -- else we'll fail to specialise an INLINE thing (inline_rhs, rhs_inside) = dropInline rhs @@ -840,10 +816,6 @@ specDefn subst calls fn rhs body = mkLams (drop n_dicts rhs_ids) rhs_body -- Glue back on the non-dict lambdas - calls_for_me = case lookupFM calls fn of - Nothing -> [] - Just cs -> fmToList cs - already_covered :: [CoreExpr] -> Bool already_covered args -- Note [Specialisations already covered] = isJust (lookupRule (const True) (substInScope subst) @@ -857,7 +829,7 @@ specDefn subst calls fn rhs ---------------------------------------------------------- -- Specialise to one particular call pattern - spec_call :: (CallKey, ([DictExpr], VarSet)) -- Call instance + spec_call :: CallInfo -- Call instance -> SpecM (Maybe ((Id,CoreExpr), -- Specialised definition UsageDetails, -- Usage details from specialised body CoreRule)) -- Info for the Id's SpecEnv @@ -885,7 +857,7 @@ specDefn subst calls fn rhs spec_tv_binds = [(tv,ty) | (tv, Just ty) <- rhs_tyvars `zip` call_ts] spec_ty_args = map snd spec_tv_binds ty_args = mk_ty_args call_ts - rhs_subst = extendTvSubstList subst spec_tv_binds + rhs_subst = CoreSubst.extendTvSubstList subst spec_tv_binds ; (rhs_subst1, inst_dict_ids) <- cloneDictBndrs rhs_subst rhs_dict_ids -- Clone rhs_dicts, including instantiating their types @@ -924,7 +896,7 @@ specDefn subst calls fn rhs (mkVarApps (Var spec_f_w_arity) app_args) -- Add the { d1' = dx1; d2' = dx2 } usage stuff - final_uds = foldr addDictBind rhs_uds dx_binds + final_uds = foldr consDictBind rhs_uds dx_binds spec_pr | inline_rhs = (spec_f_w_arity `setInlineActivation` inline_act, Note InlineMe spec_rhs) | otherwise = (spec_f_w_arity, spec_rhs) @@ -944,7 +916,7 @@ bindAuxiliaryDicts :: Subst -> [(DictId,DictId,CoreExpr)] -- (orig_dict, inst_dict, dx) -> (Subst, -- Substitute for all orig_dicts - [(DictId, CoreExpr)]) -- Auxiliary bindings + [CoreBind]) -- Auxiliary bindings -- Bind any dictionary arguments to fresh names, to preserve sharing -- Substitution already substitutes orig_dict -> inst_dict bindAuxiliaryDicts subst triples = go subst [] triples @@ -953,7 +925,7 @@ bindAuxiliaryDicts subst triples = go subst [] triples go subst binds ((d, dx_id, dx) : pairs) | exprIsTrivial dx = go (extendIdSubst subst d dx) binds pairs -- No auxiliary binding necessary - | otherwise = go subst_w_unf ((dx_id,dx) : binds) pairs + | otherwise = go subst_w_unf (NonRec dx_id dx : binds) pairs where dx_id1 = dx_id `setIdUnfolding` mkUnfolding False dx subst_w_unf = extendIdSubst subst d (Var dx_id1) @@ -967,6 +939,96 @@ bindAuxiliaryDicts subst triples = go subst [] triples -- We want that consequent call to look interesting \end{code} +Note [From non-recursive to recursive] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Even in the non-recursive case, if any dict-binds depend on 'fn' we might +have built a recursive knot + + f a d x = <blah> + MkUD { ud_binds = d7 = MkD ..f.. + , ud_calls = ...(f T d7)... } + +The we generate + + Rec { fs x = <blah>[T/a, d7/d] + f a d x = <blah> + RULE f T _ = fs + d7 = ...f... } + +Here the recursion is only through the RULE. + + +Note [Specialisation of dictionary functions] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Here is a nasty example that bit us badly: see Trac #3591 + + dfun a d = MkD a d (meth d) + d4 = <blah> + d2 = dfun T d4 + d1 = $p1 d2 + d3 = dfun T d1 + +None of these definitions is recursive. What happened was that we +generated a specialisation: + + RULE forall d. dfun T d = dT + dT = (MkD a d (meth d)) [T/a, d1/d] + = MkD T d1 (meth d1) + +But now we use the RULE on the RHS of d2, to get + + d2 = dT = MkD d1 (meth d1) + d1 = $p1 d2 + +and now d1 is bottom! The problem is that when specialising 'dfun' we +should first dump "below" the binding all floated dictionary bindings +that mention 'dfun' itself. So d2 and d3 (and hence d1) must be +placed below 'dfun', and thus unavailable to it when specialising +'dfun'. That in turn means that the call (dfun T d1) must be +discarded. On the other hand, the call (dfun T d4) is fine, assuming +d4 doesn't mention dfun. + +But look at this: + + class C a where { foo,bar :: [a] -> [a] } + + instance C Int where + foo x = r_bar x + bar xs = reverse xs + + r_bar :: C a => [a] -> [a] + r_bar xs = bar (xs ++ xs) + +That translates to: + + r_bar a (c::C a) (xs::[a]) = bar a d (xs ++ xs) + + Rec { $fCInt :: C Int = MkC foo_help reverse + foo_help (xs::[Int]) = r_bar Int $fCInt xs } + +The call (r_bar $fCInt) mentions $fCInt, + which mentions foo_help, + which mentions r_bar +But we DO want to specialise r_bar at Int: + + Rec { $fCInt :: C Int = MkC foo_help reverse + foo_help (xs::[Int]) = r_bar Int $fCInt xs + + r_bar a (c::C a) (xs::[a]) = bar a d (xs ++ xs) + RULE r_bar Int _ = r_bar_Int + + r_bar_Int xs = bar Int $fCInt (xs ++ xs) + } + +Note that, because of its RULE, r_bar joins the recursive +group. (In this case it'll unravel a short moment later.) + + +Conclusion: we catch the nasty case using filter_dfuns in +callsForMe To be honest I'm not 100% certain that this is 100% +right, but it works. Sigh. + + Note [Specialising a recursive group] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Consider @@ -1110,24 +1172,24 @@ dropInline rhs = (False, rhs) \begin{code} data UsageDetails = MkUD { - dict_binds :: !(Bag DictBind), + ud_binds :: !(Bag DictBind), -- Floated dictionary bindings -- The order is important; -- in ds1 `union` ds2, bindings in ds2 can depend on those in ds1 -- (Remember, Bags preserve order in GHC.) - calls :: !CallDetails, + ud_calls :: !CallDetails - ud_fvs :: !VarSet -- A superset of the variables mentioned in - -- either dict_binds or calls + -- INVARIANT: suppose bs = bindersOf ud_binds + -- Then 'calls' may *mention* 'bs', + -- but there should be no calls *for* bs } instance Outputable UsageDetails where - ppr (MkUD { dict_binds = dbs, calls = calls, ud_fvs = fvs }) + ppr (MkUD { ud_binds = dbs, ud_calls = calls }) = ptext (sLit "MkUD") <+> braces (sep (punctuate comma [ptext (sLit "binds") <+> equals <+> ppr dbs, - ptext (sLit "calls") <+> equals <+> ppr calls, - ptext (sLit "fvs") <+> equals <+> ppr fvs])) + ptext (sLit "calls") <+> equals <+> ppr calls])) type DictBind = (CoreBind, VarSet) -- The set is the free vars of the binding @@ -1136,10 +1198,10 @@ type DictBind = (CoreBind, VarSet) type DictExpr = CoreExpr emptyUDs :: UsageDetails -emptyUDs = MkUD { dict_binds = emptyBag, calls = emptyFM, ud_fvs = emptyVarSet } +emptyUDs = MkUD { ud_binds = emptyBag, ud_calls = emptyVarEnv } ------------------------------------------------------------ -type CallDetails = FiniteMap Id CallInfo +type CallDetails = IdEnv CallInfoSet newtype CallKey = CallKey [Maybe Type] -- Nothing => unconstrained type argument -- CallInfo uses a FiniteMap, thereby ensuring that @@ -1147,11 +1209,13 @@ newtype CallKey = CallKey [Maybe Type] -- Nothing => unconstrained type argu -- -- The list of types and dictionaries is guaranteed to -- match the type of f -type CallInfo = FiniteMap CallKey ([DictExpr], VarSet) +type CallInfoSet = FiniteMap CallKey ([DictExpr], VarSet) -- Range is dict args and the vars of the whole -- call (including tyvars) -- [*not* include the main id itself, of course] +type CallInfo = (CallKey, ([DictExpr], VarSet)) + instance Outputable CallKey where ppr (CallKey ts) = ppr ts @@ -1169,13 +1233,22 @@ instance Ord CallKey where cmp (Just t1) (Just t2) = tcCmpType t1 t2 unionCalls :: CallDetails -> CallDetails -> CallDetails -unionCalls c1 c2 = plusFM_C plusFM c1 c2 +unionCalls c1 c2 = plusVarEnv_C plusFM c1 c2 + +-- plusCalls :: UsageDetails -> CallDetails -> UsageDetails +-- plusCalls uds call_ds = uds { ud_calls = ud_calls uds `unionCalls` call_ds } + +callDetailsFVs :: CallDetails -> VarSet +callDetailsFVs calls = foldVarEnv (unionVarSet . callInfoFVs) emptyVarSet calls +callInfoFVs :: CallInfoSet -> VarSet +callInfoFVs call_info = foldFM (\_ (_,fv) vs -> unionVarSet fv vs) emptyVarSet call_info + +------------------------------------------------------------ singleCall :: Id -> [Maybe Type] -> [DictExpr] -> UsageDetails singleCall id tys dicts - = MkUD {dict_binds = emptyBag, - calls = unitFM id (unitFM (CallKey tys) (dicts, call_fvs)), - ud_fvs = call_fvs } + = MkUD {ud_binds = emptyBag, + ud_calls = unitVarEnv id (unitFM (CallKey tys) (dicts, call_fvs)) } where call_fvs = exprsFreeVars dicts `unionVarSet` tys_fvs tys_fvs = tyVarsOfTypes (catMaybes tys) @@ -1245,19 +1318,17 @@ interestingDict _ = True \begin{code} plusUDs :: UsageDetails -> UsageDetails -> UsageDetails -plusUDs (MkUD {dict_binds = db1, calls = calls1, ud_fvs = fvs1}) - (MkUD {dict_binds = db2, calls = calls2, ud_fvs = fvs2}) - = MkUD {dict_binds = d, calls = c, ud_fvs = fvs1 `unionVarSet` fvs2} - where - d = db1 `unionBags` db2 - c = calls1 `unionCalls` calls2 +plusUDs (MkUD {ud_binds = db1, ud_calls = calls1}) + (MkUD {ud_binds = db2, ud_calls = calls2}) + = MkUD { ud_binds = db1 `unionBags` db2 + , ud_calls = calls1 `unionCalls` calls2 } plusUDList :: [UsageDetails] -> UsageDetails plusUDList = foldr plusUDs emptyUDs --- zapCalls deletes calls to ids from uds -zapCalls :: [Id] -> CallDetails -> CallDetails -zapCalls ids calls = delListFromFM calls ids +----------------------------- +_dictBindBndrs :: Bag DictBind -> [Id] +_dictBindBndrs dbs = foldrBag ((++) . bindersOf . fst) [] dbs mkDB :: CoreBind -> DictBind mkDB bind = (bind, bind_fvs bind) @@ -1277,45 +1348,97 @@ pair_fvs (bndr, rhs) = exprFreeVars rhs `unionVarSet` idFreeVars bndr -- type T a = Int -- x :: T a = 3 -addDictBind :: (Id,CoreExpr) -> UsageDetails -> UsageDetails -addDictBind (dict,rhs) uds - = uds { dict_binds = db `consBag` dict_binds uds - , ud_fvs = ud_fvs uds `unionVarSet` fvs } +flattenDictBinds :: Bag DictBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)] +flattenDictBinds dbs pairs + = foldrBag add pairs dbs where - db@(_, fvs) = mkDB (NonRec dict rhs) + add (NonRec b r,_) pairs = (b,r) : pairs + add (Rec prs1, _) pairs = prs1 ++ pairs + +snocDictBinds :: UsageDetails -> [CoreBind] -> UsageDetails +-- Add ud_binds to the tail end of the bindings in uds +snocDictBinds uds dbs + = uds { ud_binds = ud_binds uds `unionBags` + foldr (consBag . mkDB) emptyBag dbs } -dumpAllDictBinds :: UsageDetails -> [CoreBind] -> [CoreBind] -dumpAllDictBinds (MkUD {dict_binds = dbs}) binds +consDictBind :: CoreBind -> UsageDetails -> UsageDetails +consDictBind bind uds = uds { ud_binds = mkDB bind `consBag` ud_binds uds } + +snocDictBind :: UsageDetails -> CoreBind -> UsageDetails +snocDictBind uds bind = uds { ud_binds = ud_binds uds `snocBag` mkDB bind } + +wrapDictBinds :: Bag DictBind -> [CoreBind] -> [CoreBind] +wrapDictBinds dbs binds = foldrBag add binds dbs where add (bind,_) binds = bind : binds -dumpUDs :: [CoreBndr] - -> UsageDetails -> CoreExpr - -> (UsageDetails, CoreExpr) -dumpUDs bndrs (MkUD { dict_binds = orig_dbs - , calls = orig_calls - , ud_fvs = fvs}) body - = (new_uds, foldrBag add_let body dump_dbs) - -- This may delete fewer variables - -- than in priciple possible +wrapDictBindsE :: Bag DictBind -> CoreExpr -> CoreExpr +wrapDictBindsE dbs expr + = foldrBag add expr dbs where - new_uds = - MkUD { dict_binds = free_dbs - , calls = free_calls - , ud_fvs = fvs `minusVarSet` bndr_set} - + add (bind,_) expr = Let bind expr + +---------------------- +dumpUDs :: [CoreBndr] -> UsageDetails -> (UsageDetails, Bag DictBind) +-- Used at a lambda or case binder; just dump anything mentioning the binder +dumpUDs bndrs uds@(MkUD { ud_binds = orig_dbs, ud_calls = orig_calls }) + | null bndrs = (uds, emptyBag) -- Common in case alternatives + | otherwise = (free_uds, dump_dbs) + where + free_uds = MkUD { ud_binds = free_dbs, ud_calls = free_calls } bndr_set = mkVarSet bndrs - add_let (bind,_) body = Let bind body + (free_dbs, dump_dbs, dump_set) = splitDictBinds orig_dbs bndr_set + free_calls = deleteCallsMentioning dump_set $ -- Drop calls mentioning bndr_set on the floor + deleteCallsFor bndrs orig_calls -- Discard calls for bndr_set; there should be + -- no calls for any of the dicts in dump_dbs + +dumpBindUDs :: [CoreBndr] -> UsageDetails -> (UsageDetails, Bag DictBind, Bool) +-- Used at a lambda or case binder; just dump anything mentioning the binder +dumpBindUDs bndrs (MkUD { ud_binds = orig_dbs, ud_calls = orig_calls }) + = (free_uds, dump_dbs, float_all) + where + free_uds = MkUD { ud_binds = free_dbs, ud_calls = free_calls } + bndr_set = mkVarSet bndrs + (free_dbs, dump_dbs, dump_set) = splitDictBinds orig_dbs bndr_set + free_calls = deleteCallsFor bndrs orig_calls + float_all = dump_set `intersectsVarSet` callDetailsFVs free_calls + +callsForMe :: Id -> UsageDetails -> (UsageDetails, [CallInfo]) +callsForMe fn (MkUD { ud_binds = orig_dbs, ud_calls = orig_calls }) + = -- pprTrace ("callsForMe") + -- (vcat [ppr fn, + -- text "Orig dbs =" <+> ppr (_dictBindBndrs orig_dbs), + -- text "Orig calls =" <+> ppr orig_calls, + -- text "Dep set =" <+> ppr dep_set, + -- text "Calls for me =" <+> ppr calls_for_me]) $ + (uds_without_me, calls_for_me) + where + uds_without_me = MkUD { ud_binds = orig_dbs, ud_calls = delVarEnv orig_calls fn } + calls_for_me = case lookupVarEnv orig_calls fn of + Nothing -> [] + Just cs -> filter_dfuns (fmToList cs) - (free_dbs, dump_dbs, dump_set) - = foldlBag dump_db (emptyBag, emptyBag, bndr_set) orig_dbs - -- Important that it's foldl not foldr; - -- we're accumulating the set of dumped ids in dump_set + dep_set = foldlBag go (unitVarSet fn) orig_dbs + go dep_set (db,fvs) | fvs `intersectsVarSet` dep_set + = extendVarSetList dep_set (bindersOf db) + | otherwise = fvs - free_calls = filterCalls dump_set orig_calls + -- Note [Specialisation of dictionary functions] + filter_dfuns | isDFunId fn = filter ok_call + | otherwise = \cs -> cs - dump_db (free_dbs, dump_dbs, dump_idset) db@(bind, fvs) + ok_call (_, (_,fvs)) = not (fvs `intersectsVarSet` dep_set) + +---------------------- +splitDictBinds :: Bag DictBind -> IdSet -> (Bag DictBind, Bag DictBind, IdSet) +-- Returns (free_dbs, dump_dbs, dump_set) +splitDictBinds dbs bndr_set + = foldlBag split_db (emptyBag, emptyBag, bndr_set) dbs + -- Important that it's foldl not foldr; + -- we're accumulating the set of dumped ids in dump_set + where + split_db (free_dbs, dump_dbs, dump_idset) db@(bind, fvs) | dump_idset `intersectsVarSet` fvs -- Dump it = (free_dbs, dump_dbs `snocBag` db, extendVarSetList dump_idset (bindersOf bind)) @@ -1323,14 +1446,19 @@ dumpUDs bndrs (MkUD { dict_binds = orig_dbs | otherwise -- Don't dump it = (free_dbs `snocBag` db, dump_dbs, dump_idset) -filterCalls :: VarSet -> CallDetails -> CallDetails --- Remove any calls that mention the variables -filterCalls bs calls - = mapFM (\_ cs -> filter_calls cs) $ - filterFM (\k _ -> not (k `elemVarSet` bs)) calls + +---------------------- +deleteCallsMentioning :: VarSet -> CallDetails -> CallDetails +-- Remove calls *mentioning* bs +deleteCallsMentioning bs calls + = mapVarEnv filter_calls calls where - filter_calls :: CallInfo -> CallInfo + filter_calls :: CallInfoSet -> CallInfoSet filter_calls = filterFM (\_ (_, fvs) -> not (fvs `intersectsVarSet` bs)) + +deleteCallsFor :: [Id] -> CallDetails -> CallDetails +-- Remove calls *for* bs +deleteCallsFor bs calls = delVarEnvList calls bs \end{code} |