summaryrefslogtreecommitdiff
path: root/compiler/specialise
diff options
context:
space:
mode:
authorsimonpj@microsoft.com <unknown>2009-10-23 16:15:51 +0000
committersimonpj@microsoft.com <unknown>2009-10-23 16:15:51 +0000
commitc43c981705ec33da92a9ce91eb90f2ecf00be9fe (patch)
treeef6725e233f3481f5121561671370f818c6ec2fa /compiler/specialise
parent0cffd31b0f25c2a31ed6eff2c0c0b1b1a8a8d507 (diff)
downloadhaskell-c43c981705ec33da92a9ce91eb90f2ecf00be9fe.tar.gz
Fix Trac #3591: very tricky specialiser bug
There was a subtle bug in the interation of specialisation and floating, described in Note [Specialisation of dictionary functions]. The net effect was to create a loop where none existed before; plain wrong. In fixing it, I did quite a bit of house-cleaning in the specialiser, and added a lot more comments. It's tricky, alas.
Diffstat (limited to 'compiler/specialise')
-rw-r--r--compiler/specialise/Specialise.lhs488
1 files changed, 308 insertions, 180 deletions
diff --git a/compiler/specialise/Specialise.lhs b/compiler/specialise/Specialise.lhs
index 64d0cdd7a3..590e689f4a 100644
--- a/compiler/specialise/Specialise.lhs
+++ b/compiler/specialise/Specialise.lhs
@@ -14,30 +14,17 @@ module Specialise ( specProgram ) where
#include "HsVersions.h"
-import Id ( Id, idName, idType, mkUserLocal, idCoreRules, idUnfolding,
- idInlineActivation, setInlineActivation, setIdUnfolding,
- isLocalId, isDataConWorkId, idArity, setIdArity )
-import TcType ( Type, mkTyVarTy, tcSplitSigmaTy,
- tyVarsOfTypes, tyVarsOfTheta, isClassPred,
- tcCmpType, isUnLiftedType
- )
-import CoreSubst ( Subst, mkEmptySubst, extendTvSubstList, lookupIdSubst,
- substBndr, substBndrs, substTy, substInScope,
- cloneIdBndr, cloneIdBndrs, cloneRecIdBndrs,
- extendIdSubst
- )
+import Id
+import TcType
+import CoreSubst
import CoreUnfold ( mkUnfolding )
-import Var ( DictId )
import VarSet
import VarEnv
import CoreSyn
import Rules
import CoreUtils ( exprIsTrivial, applyTypeToArgs, mkPiTypes )
import CoreFVs ( exprFreeVars, exprsFreeVars, idFreeVars )
-import UniqSupply ( UniqSupply,
- UniqSM, initUs_,
- MonadUnique(..)
- )
+import UniqSupply ( UniqSupply, UniqSM, initUs_, MonadUnique(..) )
import Name
import MkId ( voidArgId, realWorldPrimId )
import FiniteMap
@@ -575,8 +562,9 @@ Hence, the invariant is this:
\begin{code}
specProgram :: UniqSupply -> [CoreBind] -> [CoreBind]
-specProgram us binds = initSM us (do (binds', uds') <- go binds
- return (dumpAllDictBinds uds' binds'))
+specProgram us binds = initSM us $
+ do { (binds', uds') <- go binds
+ ; return (wrapDictBinds (ud_binds uds') binds') }
where
-- We need to start with a Subst that knows all the things
-- that are in scope, so that the substitution engine doesn't
@@ -609,12 +597,12 @@ specExpr :: Subst -> CoreExpr -> SpecM (CoreExpr, UsageDetails)
-- the RHS of specialised bindings (no type-let!)
---------------- First the easy cases --------------------
-specExpr subst (Type ty) = return (Type (substTy subst ty), emptyUDs)
+specExpr subst (Type ty) = return (Type (CoreSubst.substTy subst ty), emptyUDs)
specExpr subst (Var v) = return (specVar subst v, emptyUDs)
specExpr _ (Lit lit) = return (Lit lit, emptyUDs)
specExpr subst (Cast e co) = do
(e', uds) <- specExpr subst e
- return ((Cast e' (substTy subst co)), uds)
+ return ((Cast e' (CoreSubst.substTy subst co)), uds)
specExpr subst (Note note body) = do
(body', uds) <- specExpr subst body
return (Note (specNote subst note) body', uds)
@@ -636,8 +624,8 @@ specExpr subst expr@(App {})
---------------- Lambda/case require dumping of usage details --------------------
specExpr subst e@(Lam _ _) = do
(body', uds) <- specExpr subst' body
- let (filtered_uds, body'') = dumpUDs bndrs' uds body'
- return (mkLams bndrs' body'', filtered_uds)
+ let (free_uds, dumped_dbs) = dumpUDs bndrs' uds
+ return (mkLams bndrs' (wrapDictBindsE dumped_dbs body'), free_uds)
where
(bndrs, body) = collectBinders e
(subst', bndrs') = substBndrs subst bndrs
@@ -647,15 +635,16 @@ specExpr subst e@(Lam _ _) = do
specExpr subst (Case scrut case_bndr ty alts) = do
(scrut', uds_scrut) <- specExpr subst scrut
(alts', uds_alts) <- mapAndCombineSM spec_alt alts
- return (Case scrut' case_bndr' (substTy subst ty) alts', uds_scrut `plusUDs` uds_alts)
+ return (Case scrut' case_bndr' (CoreSubst.substTy subst ty) alts',
+ uds_scrut `plusUDs` uds_alts)
where
(subst_alt, case_bndr') = substBndr subst case_bndr
-- No need to clone case binder; it can't float like a let(rec)
spec_alt (con, args, rhs) = do
(rhs', uds) <- specExpr subst_rhs rhs
- let (uds', rhs'') = dumpUDs args uds rhs'
- return ((con, args', rhs''), uds')
+ let (free_uds, dumped_dbs) = dumpUDs args' uds
+ return ((con, args', wrapDictBindsE dumped_dbs rhs'), free_uds)
where
(subst_rhs, args') = substBndrs subst_alt args
@@ -691,92 +680,72 @@ specBind :: Subst -- Use this for RHSs
-> SpecM ([CoreBind], -- New bindings
UsageDetails) -- And info to pass upstream
-specBind rhs_subst bind body_uds
- = do { (bind', bind_uds) <- specBindItself rhs_subst bind (calls body_uds)
- ; return (finishSpecBind bind' bind_uds body_uds) }
-
-finishSpecBind :: CoreBind -> UsageDetails -> UsageDetails -> ([CoreBind], UsageDetails)
-finishSpecBind bind
- (MkUD { dict_binds = rhs_dbs, calls = rhs_calls, ud_fvs = rhs_fvs })
- (MkUD { dict_binds = body_dbs, calls = body_calls, ud_fvs = body_fvs })
- | not (mkVarSet bndrs `intersectsVarSet` all_fvs)
- -- Common case 1: the bound variables are not
- -- mentioned in the dictionary bindings
- = ([bind], MkUD { dict_binds = body_dbs `unionBags` rhs_dbs
- -- It's important that the `unionBags` is this way round,
- -- because body_uds may bind dictionaries that are
- -- used in the calls passed to specDefn. So the
- -- dictionary bindings in rhs_uds may mention
- -- dictionaries bound in body_uds.
- , calls = all_calls
- , ud_fvs = all_fvs })
-
- | case bind of { NonRec {} -> True; Rec {} -> False }
- -- Common case 2: no specialisation happened, and binding
- -- is non-recursive. But the binding may be
- -- mentioned in body_dbs, so we should put it first
- = ([], MkUD { dict_binds = rhs_dbs `unionBags` ((bind, b_fvs) `consBag` body_dbs)
- , calls = all_calls
- , ud_fvs = all_fvs `unionVarSet` b_fvs })
-
- | otherwise -- General case: make a huge Rec (sigh)
- = ([], MkUD { dict_binds = unitBag (Rec all_db_prs, all_db_fvs)
- , calls = all_calls
- , ud_fvs = all_fvs `unionVarSet` b_fvs })
- where
- all_fvs = rhs_fvs `unionVarSet` body_fvs
- all_calls = zapCalls bndrs (rhs_calls `unionCalls` body_calls)
+-- Returned UsageDetails:
+-- No calls for binders of this bind
+specBind rhs_subst (NonRec fn rhs) body_uds
+ = do { (rhs', rhs_uds) <- specExpr rhs_subst rhs
+ ; (fn', spec_defns, body_uds1) <- specDefn rhs_subst body_uds fn rhs
- bndrs = bindersOf bind
- b_fvs = bind_fvs bind
+ ; let pairs = spec_defns ++ [(fn', rhs')]
+ -- fn' mentions the spec_defns in its rules,
+ -- so put the latter first
- (all_db_prs, all_db_fvs) = add (bind, b_fvs) $
- foldrBag add ([], emptyVarSet) $
- rhs_dbs `unionBags` body_dbs
- add (NonRec b r, b_fvs) (prs, fvs) = ((b,r) : prs, b_fvs `unionVarSet` fvs)
- add (Rec b_prs, b_fvs) (prs, fvs) = (b_prs ++ prs, b_fvs `unionVarSet` fvs)
+ combined_uds = body_uds1 `plusUDs` rhs_uds
+ -- This way round a call in rhs_uds of a function f
+ -- at type T will override a call of f at T in body_uds1; and
+ -- that is good because it'll tend to keep "earlier" calls
+ -- See Note [Specialisation of dictionary functions]
----------------------------
-specBindItself :: Subst -> CoreBind -> CallDetails -> SpecM (CoreBind, UsageDetails)
-
--- specBindItself deals with the RHS, specialising it according
--- to the calls found in the body (if any)
-specBindItself rhs_subst (NonRec fn rhs) call_info
- = do { (rhs', rhs_uds) <- specExpr rhs_subst rhs -- Do RHS of original fn
- ; (fn', spec_defns, spec_uds) <- specDefn rhs_subst call_info fn rhs
- ; if null spec_defns then
- return (NonRec fn rhs', rhs_uds)
- else
- return (Rec ((fn',rhs') : spec_defns), rhs_uds `plusUDs` spec_uds) }
- -- bndr' mentions the spec_defns in its SpecEnv
- -- Not sure why we couln't just put the spec_defns first
-
-specBindItself rhs_subst (Rec pairs) call_info
+ (free_uds, dump_dbs, float_all) = dumpBindUDs [fn] combined_uds
+ -- See Note [From non-recursive to recursive]
+
+ final_binds | isEmptyBag dump_dbs = [NonRec b r | (b,r) <- pairs]
+ | otherwise = [Rec (flattenDictBinds dump_dbs pairs)]
+
+ ; if float_all then
+ -- Rather than discard the calls mentioning the bound variables
+ -- we float this binding along with the others
+ return ([], free_uds `snocDictBinds` final_binds)
+ else
+ -- No call in final_uds mentions bound variables,
+ -- so we can just leave the binding here
+ return (final_binds, free_uds) }
+
+
+specBind rhs_subst (Rec pairs) body_uds
-- Note [Specialising a recursive group]
= do { let (bndrs,rhss) = unzip pairs
; (rhss', rhs_uds) <- mapAndCombineSM (specExpr rhs_subst) rhss
- ; let all_calls = call_info `unionCalls` calls rhs_uds
- ; (bndrs1, spec_defns1, spec_uds1) <- specDefns rhs_subst all_calls pairs
-
- ; if null spec_defns1 then -- Common case: no specialisation
- return (Rec (bndrs `zip` rhss'), rhs_uds)
- else do -- Specialisation occurred; do it again
- { (bndrs2, spec_defns2, spec_uds2) <-
- -- pprTrace "specB" (ppr bndrs $$ ppr rhs_uds) $
- specDefns rhs_subst (calls spec_uds1) (bndrs1 `zip` rhss)
-
- ; let all_defns = spec_defns1 ++ spec_defns2 ++ zip bndrs2 rhss'
+ ; let scope_uds = body_uds `plusUDs` rhs_uds
+ -- Includes binds and calls arising from rhss
+
+ ; (bndrs1, spec_defns1, uds1) <- specDefns rhs_subst scope_uds pairs
+
+ ; (bndrs3, spec_defns3, uds3)
+ <- if null spec_defns1 -- Common case: no specialisation
+ then return (bndrs1, [], uds1)
+ else do { -- Specialisation occurred; do it again
+ (bndrs2, spec_defns2, uds2)
+ <- specDefns rhs_subst uds1 (bndrs1 `zip` rhss)
+ ; return (bndrs2, spec_defns2 ++ spec_defns1, uds2) }
+
+ ; let (final_uds, dumped_dbs, float_all) = dumpBindUDs bndrs uds3
+ bind = Rec (flattenDictBinds dumped_dbs $
+ spec_defns3 ++ zip bndrs3 rhss')
- ; return (Rec all_defns, rhs_uds `plusUDs` spec_uds1 `plusUDs` spec_uds2) } }
+ ; if float_all then
+ return ([], final_uds `snocDictBind` bind)
+ else
+ return ([bind], final_uds) }
---------------------------
specDefns :: Subst
- -> CallDetails -- Info on how it is used in its scope
- -> [(Id,CoreExpr)] -- The things being bound and their un-processed RHS
- -> SpecM ([Id], -- Original Ids with RULES added
- [(Id,CoreExpr)], -- Extra, specialised bindings
- UsageDetails) -- Stuff to fling upwards from the specialised versions
+ -> UsageDetails -- Info on how it is used in its scope
+ -> [(Id,CoreExpr)] -- The things being bound and their un-processed RHS
+ -> SpecM ([Id], -- Original Ids with RULES added
+ [(Id,CoreExpr)], -- Extra, specialised bindings
+ UsageDetails) -- Stuff to fling upwards from the specialised versions
-- Specialise a list of bindings (the contents of a Rec), but flowing usages
-- upwards binding by binding. Example: { f = ...g ...; g = ...f .... }
@@ -784,25 +753,22 @@ specDefns :: Subst
-- in turn generates a specialised call for 'f', we catch that in this one sweep.
-- But not vice versa (it's a fixpoint problem).
-specDefns _subst _call_info []
- = return ([], [], emptyUDs)
-specDefns subst call_info ((bndr,rhs):pairs)
- = do { (bndrs', spec_defns, spec_uds) <- specDefns subst call_info pairs
- ; let all_calls = call_info `unionCalls` calls spec_uds
- ; (bndr', spec_defns1, spec_uds1) <- specDefn subst all_calls bndr rhs
- ; return (bndr' : bndrs',
- spec_defns1 ++ spec_defns,
- spec_uds1 `plusUDs` spec_uds) }
+specDefns _subst uds []
+ = return ([], [], uds)
+specDefns subst uds ((bndr,rhs):pairs)
+ = do { (bndrs1, spec_defns1, uds1) <- specDefns subst uds pairs
+ ; (bndr1, spec_defns2, uds2) <- specDefn subst uds1 bndr rhs
+ ; return (bndr1 : bndrs1, spec_defns1 ++ spec_defns2, uds2) }
---------------------------
specDefn :: Subst
- -> CallDetails -- Info on how it is used in its scope
+ -> UsageDetails -- Info on how it is used in its scope
-> Id -> CoreExpr -- The thing being bound and its un-processed RHS
-> SpecM (Id, -- Original Id with added RULES
[(Id,CoreExpr)], -- Extra, specialised bindings
UsageDetails) -- Stuff to fling upwards from the specialised versions
-specDefn subst calls fn rhs
+specDefn subst body_uds fn rhs
-- The first case is the interesting one
| rhs_tyvars `lengthIs` n_tyvars -- Rhs of fn's defn has right number of big lambdas
&& rhs_ids `lengthAtLeast` n_dicts -- and enough dict args
@@ -816,12 +782,20 @@ specDefn subst calls fn rhs
stuff <- mapM spec_call calls_for_me
; let (spec_defns, spec_uds, spec_rules) = unzip3 (catMaybes stuff)
fn' = addIdSpecialisations fn spec_rules
- ; return (fn', spec_defns, plusUDList spec_uds) }
+ final_uds = body_uds_without_me `plusUDs` plusUDList spec_uds
+ -- It's important that the `plusUDs` is this way
+ -- round, because body_uds_without_me may bind
+ -- dictionaries that are used in calls_for_me passed
+ -- to specDefn. So the dictionary bindings in
+ -- spec_uds may mention dictionaries bound in
+ -- body_uds_without_me
+
+ ; return (fn', spec_defns, final_uds) }
| otherwise -- No calls or RHS doesn't fit our preconceptions
= WARN( notNull calls_for_me, ptext (sLit "Missed specialisation opportunity for") <+> ppr fn )
-- Note [Specialisation shape]
- return (fn, [], emptyUDs)
+ return (fn, [], body_uds_without_me)
where
fn_type = idType fn
@@ -831,6 +805,8 @@ specDefn subst calls fn rhs
n_dicts = length theta
inline_act = idInlineActivation fn
+ (body_uds_without_me, calls_for_me) = callsForMe fn body_uds
+
-- It's important that we "see past" any INLINE pragma
-- else we'll fail to specialise an INLINE thing
(inline_rhs, rhs_inside) = dropInline rhs
@@ -840,10 +816,6 @@ specDefn subst calls fn rhs
body = mkLams (drop n_dicts rhs_ids) rhs_body
-- Glue back on the non-dict lambdas
- calls_for_me = case lookupFM calls fn of
- Nothing -> []
- Just cs -> fmToList cs
-
already_covered :: [CoreExpr] -> Bool
already_covered args -- Note [Specialisations already covered]
= isJust (lookupRule (const True) (substInScope subst)
@@ -857,7 +829,7 @@ specDefn subst calls fn rhs
----------------------------------------------------------
-- Specialise to one particular call pattern
- spec_call :: (CallKey, ([DictExpr], VarSet)) -- Call instance
+ spec_call :: CallInfo -- Call instance
-> SpecM (Maybe ((Id,CoreExpr), -- Specialised definition
UsageDetails, -- Usage details from specialised body
CoreRule)) -- Info for the Id's SpecEnv
@@ -885,7 +857,7 @@ specDefn subst calls fn rhs
spec_tv_binds = [(tv,ty) | (tv, Just ty) <- rhs_tyvars `zip` call_ts]
spec_ty_args = map snd spec_tv_binds
ty_args = mk_ty_args call_ts
- rhs_subst = extendTvSubstList subst spec_tv_binds
+ rhs_subst = CoreSubst.extendTvSubstList subst spec_tv_binds
; (rhs_subst1, inst_dict_ids) <- cloneDictBndrs rhs_subst rhs_dict_ids
-- Clone rhs_dicts, including instantiating their types
@@ -924,7 +896,7 @@ specDefn subst calls fn rhs
(mkVarApps (Var spec_f_w_arity) app_args)
-- Add the { d1' = dx1; d2' = dx2 } usage stuff
- final_uds = foldr addDictBind rhs_uds dx_binds
+ final_uds = foldr consDictBind rhs_uds dx_binds
spec_pr | inline_rhs = (spec_f_w_arity `setInlineActivation` inline_act, Note InlineMe spec_rhs)
| otherwise = (spec_f_w_arity, spec_rhs)
@@ -944,7 +916,7 @@ bindAuxiliaryDicts
:: Subst
-> [(DictId,DictId,CoreExpr)] -- (orig_dict, inst_dict, dx)
-> (Subst, -- Substitute for all orig_dicts
- [(DictId, CoreExpr)]) -- Auxiliary bindings
+ [CoreBind]) -- Auxiliary bindings
-- Bind any dictionary arguments to fresh names, to preserve sharing
-- Substitution already substitutes orig_dict -> inst_dict
bindAuxiliaryDicts subst triples = go subst [] triples
@@ -953,7 +925,7 @@ bindAuxiliaryDicts subst triples = go subst [] triples
go subst binds ((d, dx_id, dx) : pairs)
| exprIsTrivial dx = go (extendIdSubst subst d dx) binds pairs
-- No auxiliary binding necessary
- | otherwise = go subst_w_unf ((dx_id,dx) : binds) pairs
+ | otherwise = go subst_w_unf (NonRec dx_id dx : binds) pairs
where
dx_id1 = dx_id `setIdUnfolding` mkUnfolding False dx
subst_w_unf = extendIdSubst subst d (Var dx_id1)
@@ -967,6 +939,96 @@ bindAuxiliaryDicts subst triples = go subst [] triples
-- We want that consequent call to look interesting
\end{code}
+Note [From non-recursive to recursive]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Even in the non-recursive case, if any dict-binds depend on 'fn' we might
+have built a recursive knot
+
+ f a d x = <blah>
+ MkUD { ud_binds = d7 = MkD ..f..
+ , ud_calls = ...(f T d7)... }
+
+The we generate
+
+ Rec { fs x = <blah>[T/a, d7/d]
+ f a d x = <blah>
+ RULE f T _ = fs
+ d7 = ...f... }
+
+Here the recursion is only through the RULE.
+
+
+Note [Specialisation of dictionary functions]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Here is a nasty example that bit us badly: see Trac #3591
+
+ dfun a d = MkD a d (meth d)
+ d4 = <blah>
+ d2 = dfun T d4
+ d1 = $p1 d2
+ d3 = dfun T d1
+
+None of these definitions is recursive. What happened was that we
+generated a specialisation:
+
+ RULE forall d. dfun T d = dT
+ dT = (MkD a d (meth d)) [T/a, d1/d]
+ = MkD T d1 (meth d1)
+
+But now we use the RULE on the RHS of d2, to get
+
+ d2 = dT = MkD d1 (meth d1)
+ d1 = $p1 d2
+
+and now d1 is bottom! The problem is that when specialising 'dfun' we
+should first dump "below" the binding all floated dictionary bindings
+that mention 'dfun' itself. So d2 and d3 (and hence d1) must be
+placed below 'dfun', and thus unavailable to it when specialising
+'dfun'. That in turn means that the call (dfun T d1) must be
+discarded. On the other hand, the call (dfun T d4) is fine, assuming
+d4 doesn't mention dfun.
+
+But look at this:
+
+ class C a where { foo,bar :: [a] -> [a] }
+
+ instance C Int where
+ foo x = r_bar x
+ bar xs = reverse xs
+
+ r_bar :: C a => [a] -> [a]
+ r_bar xs = bar (xs ++ xs)
+
+That translates to:
+
+ r_bar a (c::C a) (xs::[a]) = bar a d (xs ++ xs)
+
+ Rec { $fCInt :: C Int = MkC foo_help reverse
+ foo_help (xs::[Int]) = r_bar Int $fCInt xs }
+
+The call (r_bar $fCInt) mentions $fCInt,
+ which mentions foo_help,
+ which mentions r_bar
+But we DO want to specialise r_bar at Int:
+
+ Rec { $fCInt :: C Int = MkC foo_help reverse
+ foo_help (xs::[Int]) = r_bar Int $fCInt xs
+
+ r_bar a (c::C a) (xs::[a]) = bar a d (xs ++ xs)
+ RULE r_bar Int _ = r_bar_Int
+
+ r_bar_Int xs = bar Int $fCInt (xs ++ xs)
+ }
+
+Note that, because of its RULE, r_bar joins the recursive
+group. (In this case it'll unravel a short moment later.)
+
+
+Conclusion: we catch the nasty case using filter_dfuns in
+callsForMe To be honest I'm not 100% certain that this is 100%
+right, but it works. Sigh.
+
+
Note [Specialising a recursive group]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Consider
@@ -1110,24 +1172,24 @@ dropInline rhs = (False, rhs)
\begin{code}
data UsageDetails
= MkUD {
- dict_binds :: !(Bag DictBind),
+ ud_binds :: !(Bag DictBind),
-- Floated dictionary bindings
-- The order is important;
-- in ds1 `union` ds2, bindings in ds2 can depend on those in ds1
-- (Remember, Bags preserve order in GHC.)
- calls :: !CallDetails,
+ ud_calls :: !CallDetails
- ud_fvs :: !VarSet -- A superset of the variables mentioned in
- -- either dict_binds or calls
+ -- INVARIANT: suppose bs = bindersOf ud_binds
+ -- Then 'calls' may *mention* 'bs',
+ -- but there should be no calls *for* bs
}
instance Outputable UsageDetails where
- ppr (MkUD { dict_binds = dbs, calls = calls, ud_fvs = fvs })
+ ppr (MkUD { ud_binds = dbs, ud_calls = calls })
= ptext (sLit "MkUD") <+> braces (sep (punctuate comma
[ptext (sLit "binds") <+> equals <+> ppr dbs,
- ptext (sLit "calls") <+> equals <+> ppr calls,
- ptext (sLit "fvs") <+> equals <+> ppr fvs]))
+ ptext (sLit "calls") <+> equals <+> ppr calls]))
type DictBind = (CoreBind, VarSet)
-- The set is the free vars of the binding
@@ -1136,10 +1198,10 @@ type DictBind = (CoreBind, VarSet)
type DictExpr = CoreExpr
emptyUDs :: UsageDetails
-emptyUDs = MkUD { dict_binds = emptyBag, calls = emptyFM, ud_fvs = emptyVarSet }
+emptyUDs = MkUD { ud_binds = emptyBag, ud_calls = emptyVarEnv }
------------------------------------------------------------
-type CallDetails = FiniteMap Id CallInfo
+type CallDetails = IdEnv CallInfoSet
newtype CallKey = CallKey [Maybe Type] -- Nothing => unconstrained type argument
-- CallInfo uses a FiniteMap, thereby ensuring that
@@ -1147,11 +1209,13 @@ newtype CallKey = CallKey [Maybe Type] -- Nothing => unconstrained type argu
--
-- The list of types and dictionaries is guaranteed to
-- match the type of f
-type CallInfo = FiniteMap CallKey ([DictExpr], VarSet)
+type CallInfoSet = FiniteMap CallKey ([DictExpr], VarSet)
-- Range is dict args and the vars of the whole
-- call (including tyvars)
-- [*not* include the main id itself, of course]
+type CallInfo = (CallKey, ([DictExpr], VarSet))
+
instance Outputable CallKey where
ppr (CallKey ts) = ppr ts
@@ -1169,13 +1233,22 @@ instance Ord CallKey where
cmp (Just t1) (Just t2) = tcCmpType t1 t2
unionCalls :: CallDetails -> CallDetails -> CallDetails
-unionCalls c1 c2 = plusFM_C plusFM c1 c2
+unionCalls c1 c2 = plusVarEnv_C plusFM c1 c2
+
+-- plusCalls :: UsageDetails -> CallDetails -> UsageDetails
+-- plusCalls uds call_ds = uds { ud_calls = ud_calls uds `unionCalls` call_ds }
+
+callDetailsFVs :: CallDetails -> VarSet
+callDetailsFVs calls = foldVarEnv (unionVarSet . callInfoFVs) emptyVarSet calls
+callInfoFVs :: CallInfoSet -> VarSet
+callInfoFVs call_info = foldFM (\_ (_,fv) vs -> unionVarSet fv vs) emptyVarSet call_info
+
+------------------------------------------------------------
singleCall :: Id -> [Maybe Type] -> [DictExpr] -> UsageDetails
singleCall id tys dicts
- = MkUD {dict_binds = emptyBag,
- calls = unitFM id (unitFM (CallKey tys) (dicts, call_fvs)),
- ud_fvs = call_fvs }
+ = MkUD {ud_binds = emptyBag,
+ ud_calls = unitVarEnv id (unitFM (CallKey tys) (dicts, call_fvs)) }
where
call_fvs = exprsFreeVars dicts `unionVarSet` tys_fvs
tys_fvs = tyVarsOfTypes (catMaybes tys)
@@ -1245,19 +1318,17 @@ interestingDict _ = True
\begin{code}
plusUDs :: UsageDetails -> UsageDetails -> UsageDetails
-plusUDs (MkUD {dict_binds = db1, calls = calls1, ud_fvs = fvs1})
- (MkUD {dict_binds = db2, calls = calls2, ud_fvs = fvs2})
- = MkUD {dict_binds = d, calls = c, ud_fvs = fvs1 `unionVarSet` fvs2}
- where
- d = db1 `unionBags` db2
- c = calls1 `unionCalls` calls2
+plusUDs (MkUD {ud_binds = db1, ud_calls = calls1})
+ (MkUD {ud_binds = db2, ud_calls = calls2})
+ = MkUD { ud_binds = db1 `unionBags` db2
+ , ud_calls = calls1 `unionCalls` calls2 }
plusUDList :: [UsageDetails] -> UsageDetails
plusUDList = foldr plusUDs emptyUDs
--- zapCalls deletes calls to ids from uds
-zapCalls :: [Id] -> CallDetails -> CallDetails
-zapCalls ids calls = delListFromFM calls ids
+-----------------------------
+_dictBindBndrs :: Bag DictBind -> [Id]
+_dictBindBndrs dbs = foldrBag ((++) . bindersOf . fst) [] dbs
mkDB :: CoreBind -> DictBind
mkDB bind = (bind, bind_fvs bind)
@@ -1277,45 +1348,97 @@ pair_fvs (bndr, rhs) = exprFreeVars rhs `unionVarSet` idFreeVars bndr
-- type T a = Int
-- x :: T a = 3
-addDictBind :: (Id,CoreExpr) -> UsageDetails -> UsageDetails
-addDictBind (dict,rhs) uds
- = uds { dict_binds = db `consBag` dict_binds uds
- , ud_fvs = ud_fvs uds `unionVarSet` fvs }
+flattenDictBinds :: Bag DictBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)]
+flattenDictBinds dbs pairs
+ = foldrBag add pairs dbs
where
- db@(_, fvs) = mkDB (NonRec dict rhs)
+ add (NonRec b r,_) pairs = (b,r) : pairs
+ add (Rec prs1, _) pairs = prs1 ++ pairs
+
+snocDictBinds :: UsageDetails -> [CoreBind] -> UsageDetails
+-- Add ud_binds to the tail end of the bindings in uds
+snocDictBinds uds dbs
+ = uds { ud_binds = ud_binds uds `unionBags`
+ foldr (consBag . mkDB) emptyBag dbs }
-dumpAllDictBinds :: UsageDetails -> [CoreBind] -> [CoreBind]
-dumpAllDictBinds (MkUD {dict_binds = dbs}) binds
+consDictBind :: CoreBind -> UsageDetails -> UsageDetails
+consDictBind bind uds = uds { ud_binds = mkDB bind `consBag` ud_binds uds }
+
+snocDictBind :: UsageDetails -> CoreBind -> UsageDetails
+snocDictBind uds bind = uds { ud_binds = ud_binds uds `snocBag` mkDB bind }
+
+wrapDictBinds :: Bag DictBind -> [CoreBind] -> [CoreBind]
+wrapDictBinds dbs binds
= foldrBag add binds dbs
where
add (bind,_) binds = bind : binds
-dumpUDs :: [CoreBndr]
- -> UsageDetails -> CoreExpr
- -> (UsageDetails, CoreExpr)
-dumpUDs bndrs (MkUD { dict_binds = orig_dbs
- , calls = orig_calls
- , ud_fvs = fvs}) body
- = (new_uds, foldrBag add_let body dump_dbs)
- -- This may delete fewer variables
- -- than in priciple possible
+wrapDictBindsE :: Bag DictBind -> CoreExpr -> CoreExpr
+wrapDictBindsE dbs expr
+ = foldrBag add expr dbs
where
- new_uds =
- MkUD { dict_binds = free_dbs
- , calls = free_calls
- , ud_fvs = fvs `minusVarSet` bndr_set}
-
+ add (bind,_) expr = Let bind expr
+
+----------------------
+dumpUDs :: [CoreBndr] -> UsageDetails -> (UsageDetails, Bag DictBind)
+-- Used at a lambda or case binder; just dump anything mentioning the binder
+dumpUDs bndrs uds@(MkUD { ud_binds = orig_dbs, ud_calls = orig_calls })
+ | null bndrs = (uds, emptyBag) -- Common in case alternatives
+ | otherwise = (free_uds, dump_dbs)
+ where
+ free_uds = MkUD { ud_binds = free_dbs, ud_calls = free_calls }
bndr_set = mkVarSet bndrs
- add_let (bind,_) body = Let bind body
+ (free_dbs, dump_dbs, dump_set) = splitDictBinds orig_dbs bndr_set
+ free_calls = deleteCallsMentioning dump_set $ -- Drop calls mentioning bndr_set on the floor
+ deleteCallsFor bndrs orig_calls -- Discard calls for bndr_set; there should be
+ -- no calls for any of the dicts in dump_dbs
+
+dumpBindUDs :: [CoreBndr] -> UsageDetails -> (UsageDetails, Bag DictBind, Bool)
+-- Used at a lambda or case binder; just dump anything mentioning the binder
+dumpBindUDs bndrs (MkUD { ud_binds = orig_dbs, ud_calls = orig_calls })
+ = (free_uds, dump_dbs, float_all)
+ where
+ free_uds = MkUD { ud_binds = free_dbs, ud_calls = free_calls }
+ bndr_set = mkVarSet bndrs
+ (free_dbs, dump_dbs, dump_set) = splitDictBinds orig_dbs bndr_set
+ free_calls = deleteCallsFor bndrs orig_calls
+ float_all = dump_set `intersectsVarSet` callDetailsFVs free_calls
+
+callsForMe :: Id -> UsageDetails -> (UsageDetails, [CallInfo])
+callsForMe fn (MkUD { ud_binds = orig_dbs, ud_calls = orig_calls })
+ = -- pprTrace ("callsForMe")
+ -- (vcat [ppr fn,
+ -- text "Orig dbs =" <+> ppr (_dictBindBndrs orig_dbs),
+ -- text "Orig calls =" <+> ppr orig_calls,
+ -- text "Dep set =" <+> ppr dep_set,
+ -- text "Calls for me =" <+> ppr calls_for_me]) $
+ (uds_without_me, calls_for_me)
+ where
+ uds_without_me = MkUD { ud_binds = orig_dbs, ud_calls = delVarEnv orig_calls fn }
+ calls_for_me = case lookupVarEnv orig_calls fn of
+ Nothing -> []
+ Just cs -> filter_dfuns (fmToList cs)
- (free_dbs, dump_dbs, dump_set)
- = foldlBag dump_db (emptyBag, emptyBag, bndr_set) orig_dbs
- -- Important that it's foldl not foldr;
- -- we're accumulating the set of dumped ids in dump_set
+ dep_set = foldlBag go (unitVarSet fn) orig_dbs
+ go dep_set (db,fvs) | fvs `intersectsVarSet` dep_set
+ = extendVarSetList dep_set (bindersOf db)
+ | otherwise = fvs
- free_calls = filterCalls dump_set orig_calls
+ -- Note [Specialisation of dictionary functions]
+ filter_dfuns | isDFunId fn = filter ok_call
+ | otherwise = \cs -> cs
- dump_db (free_dbs, dump_dbs, dump_idset) db@(bind, fvs)
+ ok_call (_, (_,fvs)) = not (fvs `intersectsVarSet` dep_set)
+
+----------------------
+splitDictBinds :: Bag DictBind -> IdSet -> (Bag DictBind, Bag DictBind, IdSet)
+-- Returns (free_dbs, dump_dbs, dump_set)
+splitDictBinds dbs bndr_set
+ = foldlBag split_db (emptyBag, emptyBag, bndr_set) dbs
+ -- Important that it's foldl not foldr;
+ -- we're accumulating the set of dumped ids in dump_set
+ where
+ split_db (free_dbs, dump_dbs, dump_idset) db@(bind, fvs)
| dump_idset `intersectsVarSet` fvs -- Dump it
= (free_dbs, dump_dbs `snocBag` db,
extendVarSetList dump_idset (bindersOf bind))
@@ -1323,14 +1446,19 @@ dumpUDs bndrs (MkUD { dict_binds = orig_dbs
| otherwise -- Don't dump it
= (free_dbs `snocBag` db, dump_dbs, dump_idset)
-filterCalls :: VarSet -> CallDetails -> CallDetails
--- Remove any calls that mention the variables
-filterCalls bs calls
- = mapFM (\_ cs -> filter_calls cs) $
- filterFM (\k _ -> not (k `elemVarSet` bs)) calls
+
+----------------------
+deleteCallsMentioning :: VarSet -> CallDetails -> CallDetails
+-- Remove calls *mentioning* bs
+deleteCallsMentioning bs calls
+ = mapVarEnv filter_calls calls
where
- filter_calls :: CallInfo -> CallInfo
+ filter_calls :: CallInfoSet -> CallInfoSet
filter_calls = filterFM (\_ (_, fvs) -> not (fvs `intersectsVarSet` bs))
+
+deleteCallsFor :: [Id] -> CallDetails -> CallDetails
+-- Remove calls *for* bs
+deleteCallsFor bs calls = delVarEnvList calls bs
\end{code}