summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Graf <sebastian.graf@kit.edu>2022-04-14 21:21:27 +0200
committerSebastian Graf <sebastian.graf@kit.edu>2022-04-19 12:49:26 +0200
commit78aae867e3b3e8d84bcbffde18857efff06acd3d (patch)
tree0618d52bcbe966bc449cd32c893a2518c6e18f1c
parent79d69ad7548ef103974e77df2db6444ea031dd4b (diff)
downloadhaskell-wip/T21392.tar.gz
A fix for #21392wip/T21392
One that also refactors worker/wrapper to transfer demands onto workers
-rw-r--r--compiler/GHC/Core/Opt/SetLevels.hs23
-rw-r--r--compiler/GHC/Core/Opt/SpecConstr.hs2
-rw-r--r--compiler/GHC/Core/Opt/WorkWrap.hs23
-rw-r--r--compiler/GHC/Core/Opt/WorkWrap/Utils.hs158
-rw-r--r--compiler/GHC/Core/Utils.hs2
-rw-r--r--compiler/GHC/Stg/Lift/Analysis.hs2
-rw-r--r--compiler/GHC/Types/Demand.hs74
-rw-r--r--compiler/GHC/Types/Id.hs3
8 files changed, 190 insertions, 97 deletions
diff --git a/compiler/GHC/Core/Opt/SetLevels.hs b/compiler/GHC/Core/Opt/SetLevels.hs
index eab4d0ef4e..c0ae50b406 100644
--- a/compiler/GHC/Core/Opt/SetLevels.hs
+++ b/compiler/GHC/Core/Opt/SetLevels.hs
@@ -104,7 +104,7 @@ import GHC.Types.Unique.Set ( nonDetStrictFoldUniqSet )
import GHC.Types.Unique.DSet ( getUniqDSet )
import GHC.Types.Var.Env
import GHC.Types.Literal ( litIsTrivial )
-import GHC.Types.Demand ( DmdSig, Demand, isStrUsedDmd, splitDmdSig, prependArgsDmdSig )
+import GHC.Types.Demand
import GHC.Types.Cpr ( mkCprSig, botCpr )
import GHC.Types.Name ( getOccName, mkSystemVarName )
import GHC.Types.Name.Occurrence ( occNameString )
@@ -730,7 +730,7 @@ lvlMFE env strict_ctxt ann_expr
-- See Note [Bottoming floats]
-- esp Bottoming floats (2)
expr_ok_for_spec = exprOkForSpeculation expr
- dest_lvl = destLevel env fvs fvs_ty is_function is_bot False
+ dest_lvl = destLevel env fvs fvs_ty is_function is_bot False False
abs_vars = abstractVars dest_lvl env fvs
-- float_is_new_lam: the floated thing will be a new value lambda
@@ -1175,7 +1175,8 @@ lvlBind env (AnnNonRec bndr rhs)
rhs_fvs = freeVarsOf rhs
bind_fvs = rhs_fvs `unionDVarSet` dIdFreeVars bndr
abs_vars = abstractVars dest_lvl env bind_fvs
- dest_lvl = destLevel env bind_fvs ty_fvs (isFunction rhs) is_bot is_join
+ frag_dmd = hasFragileDmdSig bndr
+ dest_lvl = destLevel env bind_fvs ty_fvs (isFunction rhs) is_bot is_join frag_dmd
deann_rhs = deAnnotate rhs
mb_bot_str = exprBotStrictness_maybe deann_rhs
@@ -1275,7 +1276,8 @@ lvlBind env (AnnRec pairs)
bndrs
ty_fvs = foldr (unionVarSet . tyCoVarsOfType . idType) emptyVarSet bndrs
- dest_lvl = destLevel env bind_fvs ty_fvs is_fun is_bot is_join
+ frag_dmd = any hasFragileDmdSig bndrs
+ dest_lvl = destLevel env bind_fvs ty_fvs is_fun is_bot is_join frag_dmd
abs_vars = abstractVars dest_lvl env bind_fvs
profitableFloat :: LevelEnv -> Level -> Bool
@@ -1283,6 +1285,15 @@ profitableFloat env dest_lvl
= (dest_lvl `ltMajLvl` le_ctxt_lvl env) -- Escapes a value lambda
|| isTopLvl dest_lvl -- Going all the way to top level
+-- | The 'idDmdSig' of a join point is fragile if it is not top and was computed
+-- assuming an interesting 'idDemandInfo' on the join body that would be lost by
+-- floating the join point to the top-level.
+hasFragileDmdSig :: Id -> Bool
+hasFragileDmdSig join_bndr
+ = not (isTopSig (idDmdSig join_bndr)) && topSubDmd /= body_sd
+ where
+ _ :* join_sd = idDemandInfo join_bndr
+ (_, body_sd) = peelManyCalls (idArity join_bndr) join_sd
----------------------------------------------------
-- Three help functions for the type-abstraction case
@@ -1445,11 +1456,13 @@ destLevel :: LevelEnv
-> Bool -- True <=> is function
-> Bool -- True <=> is bottom
-> Bool -- True <=> is a join point
+ -> Bool -- True <=> if join point, then demand info is fragile
-> Level
-- INVARIANT: if is_join=True then result >= join_ceiling
-destLevel env fvs fvs_ty is_function is_bot is_join
+destLevel env fvs fvs_ty is_function is_bot is_join frag_dmd
| isTopLvl max_fv_id_level -- Float even joins if they get to top level
-- See Note [Floating join point bindings]
+ , is_bot || not (is_join && frag_dmd)
= tOP_LEVEL
| is_join -- Never float a join point past the join ceiling
diff --git a/compiler/GHC/Core/Opt/SpecConstr.hs b/compiler/GHC/Core/Opt/SpecConstr.hs
index 90f492ffea..1ce501e0b0 100644
--- a/compiler/GHC/Core/Opt/SpecConstr.hs
+++ b/compiler/GHC/Core/Opt/SpecConstr.hs
@@ -1785,7 +1785,7 @@ spec_one env fn arg_bndrs body (call_pat, rule_number)
(spec_lam_args, spec_call_args, spec_arity, spec_join_arity)
| needsVoidWorkerArg fn arg_bndrs spec_lam_args1
- , (spec_lam_args, spec_call_args, _) <- addVoidWorkerArg spec_lam_args1 []
+ , (spec_lam_args, spec_call_args, _, _) <- addVoidWorkerArg spec_lam_args1 [] []
-- needsVoidWorkerArg: usual w/w hack to avoid generating
-- a spec_rhs of unlifted type and no args.
-- Unlike W/W we don't turn functions into strict workers
diff --git a/compiler/GHC/Core/Opt/WorkWrap.hs b/compiler/GHC/Core/Opt/WorkWrap.hs
index 93c4c31995..68a34a634d 100644
--- a/compiler/GHC/Core/Opt/WorkWrap.hs
+++ b/compiler/GHC/Core/Opt/WorkWrap.hs
@@ -753,7 +753,7 @@ splitFun ww_opts fn_id rhs
= warnPprTrace (not (wrap_dmds `lengthIs` (arityInfo fn_info)))
"splitFun"
(ppr fn_id <+> (ppr wrap_dmds $$ ppr cpr)) $
- do { mb_stuff <- mkWwBodies ww_opts fn_id arg_vars (exprType body) wrap_dmds cpr
+ do { mb_stuff <- mkWwBodies ww_opts fn_id arg_vars (exprType body) wrap_dmds fun_dmd cpr
; case mb_stuff of
Nothing -> -- No useful wrapper; leave the binding alone
return [(fn_id, rhs)]
@@ -786,14 +786,15 @@ splitFun ww_opts fn_id rhs
(ppr fn_id <> colon <+> text "ct_arty:" <+> int (ct_arty cpr_ty)
<+> text "arityInfo:" <+> ppr (arityInfo fn_info)) $
ct_cpr cpr_ty
+ fun_dmd = idDemandInfo fn_id
mkWWBindPair :: WwOpts -> Id -> IdInfo
-> [Var] -> CoreExpr -> Unique -> Divergence
- -> ([Demand],[CbvMark], JoinArity, Id -> CoreExpr, Expr CoreBndr -> CoreExpr)
+ -> ([Demand],[CbvMark], JoinArity, Demand, Id -> CoreExpr, Expr CoreBndr -> CoreExpr)
-> [(Id, CoreExpr)]
mkWWBindPair ww_opts fn_id fn_info fn_args fn_body work_uniq div
- (work_demands, cbv_marks :: [CbvMark], join_arity, wrap_fn, work_fn)
- = -- pprTrace "mkWWBindPair" (ppr fn_id <+> ppr wrap_id <+> ppr work_id $$ ppr wrap_rhs) $
+ (work_demands, cbv_marks :: [CbvMark], join_arity, work_dmd, wrap_fn, work_fn)
+ = pprTrace "mkWWBindPair" (ppr fn_id <+> ppr wrap_id <+> ppr work_id <+> ppr work_dmd $$ ppr wrap_rhs) $
[(work_id, work_rhs), (wrap_id, wrap_rhs)]
-- Worker first, because wrapper mentions it
where
@@ -840,7 +841,7 @@ mkWWBindPair ww_opts fn_id fn_info fn_args fn_body work_uniq div
`setIdCprSig` topCprSig
- `setIdDemandInfo` worker_demand
+ `setIdDemandInfo` work_dmd
`setIdArity` work_arity
-- Set the arity so that the Core Lint check that the
@@ -855,9 +856,9 @@ mkWWBindPair ww_opts fn_id fn_info fn_args fn_body work_uniq div
work_arity = length work_demands :: Int
-- See Note [Demand on the worker]
- single_call = saturatedByOneShots arity (demandInfo fn_info)
- worker_demand | single_call = mkWorkerDemand work_arity
- | otherwise = topDmd
+ --single_call = saturatedByOneShots arity (demandInfo fn_info)
+ --worker_demand | single_call = mkWorkerDemand work_arity
+ -- | otherwise = topDmd
wrap_rhs = wrap_fn work_id
wrap_prag = mkStrWrapperInlinePrag fn_inl_prag fn_rules
@@ -898,12 +899,10 @@ mkStrWrapperInlinePrag (InlinePragma { inl_act = act, inl_rule = rule_info }) ru
{-
Note [Demand on the worker]
~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
If the original function is called once, according to its demand info, then
so is the worker. This is important so that the occurrence analyser can
attach OneShot annotations to the worker’s lambda binders.
-
Example:
-- Original function
@@ -1033,8 +1032,8 @@ splitThunk :: WwOpts -> RecFlag -> Var -> Expr Var -> UniqSM [(Var, Expr Var)]
splitThunk ww_opts is_rec x rhs
= assert (not (isJoinId x)) $
do { let x' = localiseId x -- See comment above
- ; (useful,_args, wrap_fn, fn_arg)
- <- mkWWstr_one ww_opts x' NotMarkedCbv
+ ; (useful, _args, wrap_fn, fn_arg)
+ <- mkWWstr_one ww_opts x' NotMarkedCbv C_0N
; let res = [ (x, Let (NonRec x' rhs) (wrap_fn fn_arg)) ]
; if useful then assertPpr (isNonRec is_rec) (ppr x) -- The thunk must be non-recursive
return res
diff --git a/compiler/GHC/Core/Opt/WorkWrap/Utils.hs b/compiler/GHC/Core/Opt/WorkWrap/Utils.hs
index 63ac670418..edc31e55a9 100644
--- a/compiler/GHC/Core/Opt/WorkWrap/Utils.hs
+++ b/compiler/GHC/Core/Opt/WorkWrap/Utils.hs
@@ -50,6 +50,7 @@ import GHC.Types.Unique.Supply
import GHC.Types.Name ( getOccFS )
import GHC.Data.FastString
+import GHC.Data.Maybe
import GHC.Data.OrdList
import GHC.Data.List.SetOps
@@ -163,6 +164,7 @@ type WwResult
= ([Demand], -- Demands for worker (value) args
[CbvMark], -- Cbv semantics for worker (value) args
JoinArity, -- Number of worker (type OR value) args
+ Demand, -- Demand on the worker
Id -> CoreExpr, -- Wrapper body, lacking only the worker Id
CoreExpr -> CoreExpr) -- Worker body, lacking the original function rhs
@@ -175,7 +177,8 @@ mkWwBodies :: WwOpts
-> [Var] -- ^ Manifest args of original function
-> Type -- ^ Result type of the original function,
-- after being stripped of args
- -> [Demand] -- ^ Strictness of original function
+ -> [Demand] -- ^ Argument Demands of original function
+ -> Demand -- ^ Demand on the original function
-> Cpr -- ^ Info about function result
-> UniqSM (Maybe WwResult)
-- ^ Given a function definition
@@ -218,9 +221,9 @@ mkWwBodies :: WwOpts
-- and beta-redexes]), which allows us to apply the same split to function body
-- and its unfolding(s) alike.
--
-mkWwBodies opts fun_id arg_vars res_ty demands res_cpr
- = do { massertPpr (filter isId arg_vars `equalLength` demands)
- (text "wrong wrapper arity" $$ ppr fun_id $$ ppr arg_vars $$ ppr res_ty $$ ppr demands)
+mkWwBodies opts fun_id arg_vars res_ty arg_dmds fun_dmd res_cpr
+ = do { massertPpr (filter isId arg_vars `equalLength` arg_dmds)
+ (text "wrong wrapper arity" $$ ppr fun_id $$ ppr arg_vars $$ ppr res_ty $$ ppr arg_dmds)
-- Clone and prepare arg_vars of the original fun RHS
-- See Note [Freshen WW arguments]
@@ -233,31 +236,40 @@ mkWwBodies opts fun_id arg_vars res_ty demands res_cpr
res_ty' = GHC.Core.Subst.substTy subst res_ty
init_cbv_marks = map (const NotMarkedCbv) cloned_arg_vars
- ; (useful1, work_args_cbv, wrap_fn_str, fn_args)
- <- mkWWstr opts cloned_arg_vars init_cbv_marks
+ ; let fun_card :* fun_sd = fun_dmd
+ (arg_cards, res_sd) = peelManyCalls (length arg_dmds) fun_sd
+ init_arg_cards = intersperseTyArgs cloned_arg_vars C_11 arg_cards
- ; let (work_args, work_marks) = unzip work_args_cbv
+ ; (useful1, work_args_cbv_cards, wrap_fn_str, fn_args)
+ <- mkWWstr opts cloned_arg_vars init_cbv_marks init_arg_cards
+
+ ; let (work_args, work_marks, work_cards) = unzip3 work_args_cbv_cards
+ ; let (work_val_cards, _work_ty_cards) = partitionByList (map isId work_args) work_cards
-- Do CPR w/w. See Note [Always do CPR w/w]
- ; (useful2, wrap_fn_cpr, work_fn_cpr)
- <- mkWWcpr_entry opts res_ty' res_cpr
+ ; (useful2, work_res_sd, wrap_fn_cpr, work_fn_cpr)
+ <- mkWWcpr_entry opts res_ty' res_sd res_cpr
- ; let (work_lam_args, work_call_args, work_call_cbv)
+ ; let (work_lam_args, work_call_args, work_call_cbv, work_call_cards)
| needsVoidWorkerArg fun_id arg_vars work_args
- = addVoidWorkerArg work_args work_marks
+ = addVoidWorkerArg work_args work_marks work_val_cards
| otherwise
- = (work_args, work_args, work_marks)
+ = (work_args, work_args, work_marks, work_val_cards)
- call_work work_fn = mkVarApps (Var work_fn) work_call_args
- call_rhs fn_rhs = mkAppsBeta fn_rhs fn_args
+ call_work work_fn = mkVarApps (Var work_fn) work_call_args
+ call_rhs fn_rhs = mkAppsBeta fn_rhs fn_args
-- See Note [Join points and beta-redexes]
wrapper_body = mkLams cloned_arg_vars . wrap_fn_cpr . wrap_fn_str . call_work
- worker_body = mkLams work_lam_args . work_fn_cpr . call_rhs
- (worker_args_dmds, work_val_cbvs)= unzip [(idDemandInfo v,cbv) | (v,cbv) <- zipEqual "mkWwBodies" work_call_args work_call_cbv, isId v]
+ worker_body = mkLams work_lam_args . work_fn_cpr . call_rhs
+ (worker_args_dmds, work_val_cbvs) =
+ unzip [(idDemandInfo v,cbv) | (v,cbv) <- zipEqual "mkWwBodies" work_call_args work_call_cbv, isId v]
+ worker_demand = fun_card :* mkCalls work_call_cards work_res_sd
+
+ ; pprTraceM "worker demand" (ppr fun_id $$ ppr worker_demand $$ ppr work_call_cards)
; if ((useful1 && not only_one_void_argument) || useful2)
then return (Just (worker_args_dmds, work_val_cbvs, length work_call_args,
- wrapper_body, worker_body))
+ worker_demand, wrapper_body, worker_body))
else return Nothing
}
-- We use an INLINE unconditionally, even if the wrapper turns out to be
@@ -275,7 +287,7 @@ mkWwBodies opts fun_id arg_vars res_ty demands res_cpr
-- Note [Do not split void functions]
only_one_void_argument
- | [d] <- demands
+ | [d] <- arg_dmds
, [v] <- filter isId arg_vars
, isAbsDmd d && isZeroBitTy (idType v)
= True
@@ -299,6 +311,13 @@ isWorkerSmallEnough max_worker_args old_n_args vars
-- Also if the function took 82 arguments before (old_n_args), it's fine if
-- it takes <= 82 arguments afterwards.
+intersperseTyArgs :: HasDebugCallStack => [Var] -> a -> [a] -> [a]
+intersperseTyArgs [] _ as = as
+intersperseTyArgs (v:vs) def as
+ | isTyVar v = def : intersperseTyArgs vs def as
+ | a:as' <- as = a : intersperseTyArgs vs def as'
+intersperseTyArgs _ _ _ = panic "intersperseTyArgs ran out of as"
+
{-
Note [Always do CPR w/w]
~~~~~~~~~~~~~~~~~~~~~~~~
@@ -396,12 +415,13 @@ needsVoidWorkerArg fn_id wrap_args work_args
needs_float_barrier = wrap_had_barrier && not work_has_barrier
-- | Inserts a `Void#` arg before the first value argument (but after leading type args).
-addVoidWorkerArg :: [Var] -> [CbvMark]
+addVoidWorkerArg :: [Var] -> [CbvMark] -> [Card]
-> ([Var], -- Lambda bound args
[Var], -- Args at call site
- [CbvMark]) -- cbv semantics for the worker args.
-addVoidWorkerArg work_args cbv_marks
- = (ty_args ++ voidArgId:rest, ty_args ++ voidPrimId:rest, NotMarkedCbv:cbv_marks)
+ [CbvMark], -- cbv semantics for the worker args.
+ [Card]) -- worker arg cardinalities
+addVoidWorkerArg work_args cbv_marks arg_cards
+ = (ty_args ++ voidArgId:rest, ty_args ++ voidPrimId:rest, NotMarkedCbv:cbv_marks, C_0N:arg_cards)
where
(ty_args, rest) = break isId work_args
@@ -728,7 +748,7 @@ real-world example involves unsafeCoerce:
foreign import ccall "c_exit" c_exit :: IO ()
Here CPR will tell you that `foo` returns a () constructor for sure, but trying
to create a worker/wrapper for type `a` obviously fails.
-(This was a real example until ee8e792 in libraries/base.)
+(This was a real example until ee8e792 in libraries/base.)
It does not seem feasible to avoid all such cases already in the analyser (and
after all, the analysis is not really wrong), so we simply do nothing here in
@@ -811,28 +831,32 @@ mkWWstr :: WwOpts
-> [Var] -- Wrapper args; have their demand info on them
-- *Includes type variables*
-> [CbvMark] -- cbv info for arguments
+ -> [Card] -- Call cardinalities for arguments derived from
+ -- wrapper function's demand
-> UniqSM (Bool, -- Will this result in a useful worker
- [(Var,CbvMark)], -- Worker args/their call-by-value semantics.
+ [(Var,CbvMark,Card)], -- Worker args, their call-by-value semantics
+ -- and their call cardinalities
CoreExpr -> CoreExpr, -- Wrapper body, lacking the worker call
-- and without its lambdas
-- This fn adds the unboxing
[CoreExpr]) -- Reboxed args for the call to the
-- original RHS. Corresponds one-to-one
-- with the wrapper arg vars
-mkWWstr opts args cbv_info
- = go args cbv_info
+mkWWstr opts args cbv_info cards
+ = go args cbv_info cards
where
- go_one arg cbv = mkWWstr_one opts arg cbv
+ go_one arg cbv card = mkWWstr_one opts arg cbv card
- go [] _ = return (badWorker, [], nop_fn, [])
- go (arg : args) (cbv:cbvs)
- = do { (useful1, args1, wrap_fn1, wrap_arg) <- go_one arg cbv
- ; (useful2, args2, wrap_fn2, wrap_args) <- go args cbvs
+ go [] _ _ = return (badWorker, [], nop_fn, [])
+ go (arg : args) (cbv:cbvs) (card:cards)
+ = do { (useful1, args1, wrap_fn1, wrap_arg) <- go_one arg cbv card
+ ; (useful2, args2, wrap_fn2, wrap_args) <- go args cbvs cards
; return ( useful1 || useful2
, args1 ++ args2
, wrap_fn1 . wrap_fn2
, wrap_arg:wrap_args ) }
- go _ _ = panic "mkWWstr: Impossible - cbv/arg length missmatch"
+ go args cbvs cards = pprPanic "mkWWstr: Impossible - cbv/card/arg length missmatch"
+ $ ppr args $$ ppr cbvs $$ ppr cards
----------------------
-- mkWWstr_one wrap_var = (useful, work_args, wrap_fn, wrap_arg)
@@ -844,8 +868,9 @@ mkWWstr opts args cbv_info
mkWWstr_one :: WwOpts
-> Var
-> CbvMark
- -> UniqSM (Bool, [(Var,CbvMark)], CoreExpr -> CoreExpr, CoreExpr)
-mkWWstr_one opts arg marked_cbv =
+ -> Card
+ -> UniqSM (Bool, [(Var,CbvMark,Card)], CoreExpr -> CoreExpr, CoreExpr)
+mkWWstr_one opts arg marked_cbv card =
case wantToUnboxArg True fam_envs arg_ty arg_dmd of
_ | isTyVar arg -> do_nothing
@@ -857,10 +882,10 @@ mkWWstr_one opts arg marked_cbv =
-- (that's what mkAbsentFiller does)
-> return (goodWorker, [], nop_fn, absent_filler)
- Unbox dcpc ds -> unbox_one_arg opts arg ds dcpc marked_cbv
+ Unbox dcpc ds -> unbox_one_arg opts arg ds dcpc marked_cbv card
Unlift -> return ( wwForUnlifting opts
- , [(setIdUnfolding arg evaldUnfolding, MarkedCbv)]
+ , [(setIdUnfolding arg evaldUnfolding, MarkedCbv, card)]
, nop_fn
, varToCoreExpr arg)
@@ -872,18 +897,19 @@ mkWWstr_one opts arg marked_cbv =
arg_dmd = idDemandInfo arg
-- Type args don't get cbv marks
arg_cbv = if isTyVar arg then NotMarkedCbv else marked_cbv
- do_nothing = return (badWorker, [(arg,arg_cbv)], nop_fn, varToCoreExpr arg)
+ do_nothing = return (badWorker, [(arg,arg_cbv,card)], nop_fn, varToCoreExpr arg)
unbox_one_arg :: WwOpts
-> Var
-> [Demand]
-> DataConPatContext
-> CbvMark
- -> UniqSM (Bool, [(Var,CbvMark)], CoreExpr -> CoreExpr, CoreExpr)
+ -> Card
+ -> UniqSM (Bool, [(Var,CbvMark,Card)], CoreExpr -> CoreExpr, CoreExpr)
unbox_one_arg opts arg_var ds
DataConPatContext { dcpc_dc = dc, dcpc_tc_args = tc_args
, dcpc_co = co }
- _marked_cbv
+ _marked_cbv card
= do { pat_bndrs_uniqs <- getUniquesM
; let ex_name_fss = map getOccFS $ dataConExTyCoVars dc
-- Create new arguments we get when unboxing dc
@@ -897,7 +923,8 @@ unbox_one_arg opts arg_var ds
cbv_arg_marks = zipWithEqual "unbox_one_arg" bangToMark (dataConRepStrictness dc) arg_ids'
unf_args = zipWith setEvald arg_ids' cbv_arg_marks
cbv_marks = (map (const NotMarkedCbv) ex_tvs') ++ cbv_arg_marks
- ; (_sub_args_quality, worker_args, wrap_fn, wrap_args) <- mkWWstr opts (ex_tvs' ++ unf_args) cbv_marks
+ cards = takeList arg_ids' (card:repeat C_11)
+ ; (_sub_args_quality, worker_args, wrap_fn, wrap_args) <- mkWWstr opts (ex_tvs' ++ unf_args) cbv_marks cards
; let wrap_arg = mkConApp dc (map Type tc_args ++ wrap_args) `mkCast` mkSymCo co
; return (goodWorker, worker_args, unbox_fn . wrap_fn, wrap_arg) }
-- Don't pass the arg, rebox instead
@@ -1338,16 +1365,18 @@ See Note [Worker/wrapper for CPR] for an overview.
mkWWcpr_entry
:: WwOpts
-> Type -- function body
+ -> SubDemand -- How the function (body) was eval'd
-> Cpr -- CPR analysis results
- -> UniqSM (Bool, -- Is w/w'ing useful?
+ -> UniqSM (Bool, -- Is w/w'ing useful?
+ SubDemand, -- How the worker function (body) is eval'd
CoreExpr -> CoreExpr, -- New wrapper. 'nop_fn' if not useful
CoreExpr -> CoreExpr) -- New worker. 'nop_fn' if not useful
-- ^ Entrypoint to CPR W/W. See Note [Worker/wrapper for CPR] for an overview.
-mkWWcpr_entry opts body_ty body_cpr
- | not (wo_cpr_anal opts) = return (badWorker, nop_fn, nop_fn)
+mkWWcpr_entry opts body_ty body_sd body_cpr
+ | not (wo_cpr_anal opts) = return (badWorker, body_sd, nop_fn, nop_fn)
| otherwise = do
-- Part (1)
- res_bndr <- mk_res_bndr body_ty
+ res_bndr <- mk_res_bndr body_ty body_sd
let bind_res_bndr body scope = mkDefaultCase body res_bndr scope
-- Part (2)
@@ -1355,22 +1384,23 @@ mkWWcpr_entry opts body_ty body_cpr
mkWWcpr_one opts res_bndr body_cpr
-- Part (3)
- let (unbox_transit_tup, transit_tup) = move_transit_vars transit_vars
+ let (work_body_sd, unbox_transit_tup, transit_tup) = move_transit_vars transit_vars
-- Stacking unboxer (work_fn) and builder (wrap_fn) together
let wrap_fn = unbox_transit_tup rebuilt_result -- 3 2
work_fn body = bind_res_bndr body (work_unpack_res transit_tup) -- 1 2 3
return $ if not useful
- then (badWorker, nop_fn, nop_fn)
- else (goodWorker, wrap_fn, work_fn)
+ then (badWorker, body_sd, nop_fn, nop_fn)
+ else (goodWorker, work_body_sd, wrap_fn, work_fn)
-- | Part (1) of Note [Worker/wrapper for CPR].
-mk_res_bndr :: Type -> UniqSM Id
-mk_res_bndr body_ty = do
+mk_res_bndr :: Type -> SubDemand -> UniqSM Id
+mk_res_bndr body_ty body_sd = do
-- See Note [Linear types and CPR]
bndr <- mkSysLocalOrCoVarM ww_prefix cprCaseBndrMult body_ty
-- See Note [Record evaluated-ness in worker/wrapper]
- pure (setCaseBndrEvald MarkedStrict bndr)
+ pure $ setCaseBndrEvald MarkedStrict bndr
+ `setIdDemandInfo` C_11 :* body_sd
-- | What part (2) of Note [Worker/wrapper for CPR] collects.
--
@@ -1378,11 +1408,11 @@ mk_res_bndr body_ty = do
-- 2. The list of transit variables (see the Note).
-- 3. The result builder expression for the wrapper. The original case binder if not useful.
-- 4. The result unpacking expression for the worker. 'nop_fn' if not useful.
-type CprWwResultOne = (Bool, OrdList Var, CoreExpr , CoreExpr -> CoreExpr)
+type CprWwResultOne = (Bool, OrdList Var, CoreExpr, CoreExpr -> CoreExpr)
type CprWwResultMany = (Bool, OrdList Var, [CoreExpr], CoreExpr -> CoreExpr)
mkWWcpr :: WwOpts -> [Id] -> [Cpr] -> UniqSM CprWwResultMany
-mkWWcpr _opts vars [] =
+mkWWcpr _opts vars [] =
-- special case: No CPRs means all top (for example from FlatConCpr),
-- hence stop WW.
return (badWorker, toOL vars, map varToCoreExpr vars, nop_fn)
@@ -1422,13 +1452,19 @@ unbox_one_result opts res_bndr arg_cprs
dataConRepFSInstPat (repeat ww_prefix) pat_bndrs_uniqs cprCaseBndrMult dc tc_args
massert (null _exs) -- Should have been caught by wantToUnboxResult
+ -- Set the demand on arg_ids
+ let arity = dataConRepArity dc
+ _ :* sd = strictifyDmd $ idDemandInfo res_bndr -- strictifyDmd: Because we eval res_bndr
+ arg_dmds = (snd <$> viewProd arity sd) `orElse` replicate arity topDmd
+ arg_ids' = zipWith setIdDemandInfo arg_ids arg_dmds
+
(nested_useful, transit_vars, con_args, work_unbox_res) <-
- mkWWcpr opts arg_ids arg_cprs
+ mkWWcpr opts arg_ids' arg_cprs
let -- rebuilt_result = (C a b |> sym co)
rebuilt_result = mkConApp dc (map Type tc_args ++ con_args) `mkCast` mkSymCo co
-- this_work_unbox_res alt = (case res_bndr |> co of C a b -> <alt>[a,b])
- this_work_unbox_res = mkUnpackCase (Var res_bndr) co cprCaseBndrMult dc arg_ids
+ this_work_unbox_res = mkUnpackCase (Var res_bndr) co cprCaseBndrMult dc arg_ids'
-- Don't try to WW an unboxed tuple return type when there's nothing inside
-- to unbox further.
@@ -1442,14 +1478,16 @@ unbox_one_result opts res_bndr arg_cprs
-- | Implements part (3) of Note [Worker/wrapper for CPR].
--
--- If `move_transit_vars [a,b] = (unbox, tup)` then
+-- If `move_transit_vars [a,b] = (sd, unbox, tup)` then
-- * `a` and `b` are the *transit vars* to be returned from the worker
-- to the wrapper
-- * `unbox scrut alt = (case <scrut> of (# a, b #) -> <alt>)`
-- * `tup = (# a, b #)`
+-- * `sd` is the SubDemand on the worker body, reconstructed form the demand
+-- on transit vars that we carefully transfered over in mkWWcpr_one
-- There is a special case for when there's 1 transit var,
-- see Note [No unboxed tuple for single, unlifted transit var].
-move_transit_vars :: [Id] -> (CoreExpr -> CoreExpr -> CoreExpr, CoreExpr)
+move_transit_vars :: [Id] -> (SubDemand, CoreExpr -> CoreExpr -> CoreExpr, CoreExpr)
move_transit_vars vars
| [var] <- vars
, let var_ty = idType var
@@ -1457,21 +1495,25 @@ move_transit_vars vars
-- See Note [No unboxed tuple for single, unlifted transit var]
-- * Wrapper: `unbox scrut alt = (case <scrut> of a -> <alt>)`
-- * Worker: `tup = a`
- = ( \build_res wkr_call -> mkDefaultCase wkr_call var build_res
+ = ( subDemandIfEvaluated (strictifyDmd (idDemandInfo var))
+ , \build_res wkr_call -> mkDefaultCase wkr_call var build_res
, varToCoreExpr var ) -- varToCoreExpr important here: var can be a coercion
-- Lacking this caused #10658
| otherwise
-- The general case: Just return an unboxed tuple from the worker
-- * Wrapper: `unbox scrut alt = (case <scrut> of (# a, b #) -> <alt>)`
-- * Worker: `tup = (# a, b #)`
- = ( \build_res wkr_call -> mkSingleAltCase wkr_call case_bndr
+ = ( body_sd
+ , \build_res wkr_call -> mkSingleAltCase wkr_call case_bndr
(DataAlt tup_con) vars build_res
, ubx_tup_app )
where
ubx_tup_app = mkCoreUbxTup (map idType vars) (map varToCoreExpr vars)
tup_con = tupleDataCon Unboxed (length vars)
+ body_sd = mkProd Unboxed (map idDemandInfo vars)
-- See also Note [Linear types and CPR]
case_bndr = mkWildValBinder cprCaseBndrMult (exprType ubx_tup_app)
+ `setIdDemandInfo` C_11 :* body_sd
{- Note [Worker/wrapper for CPR]
diff --git a/compiler/GHC/Core/Utils.hs b/compiler/GHC/Core/Utils.hs
index f6656e602a..462ed3bdcb 100644
--- a/compiler/GHC/Core/Utils.hs
+++ b/compiler/GHC/Core/Utils.hs
@@ -2593,7 +2593,7 @@ tryEtaReduce bndrs body eval_sd
-- ... and that the function can be eta reduced to arity 0
-- without violating invariants of Core and GHC
&& canEtaReduceToArity fun 0 0 -- criteria (L), (J), (W), (B)
- all_calls_with_arity n = isStrict (peelManyCalls n eval_sd)
+ all_calls_with_arity n = isStrict (enterManyCalls n eval_sd)
-- See Note [Eta reduction based on evaluation context]
---------------
diff --git a/compiler/GHC/Stg/Lift/Analysis.hs b/compiler/GHC/Stg/Lift/Analysis.hs
index 6fc116c8bc..e4ba2f077c 100644
--- a/compiler/GHC/Stg/Lift/Analysis.hs
+++ b/compiler/GHC/Stg/Lift/Analysis.hs
@@ -326,7 +326,7 @@ tagSkeletonRhs bndr (StgRhsClosure fvs ccs upd bndrs body)
rhsCard :: Id -> Card
rhsCard bndr
| is_thunk = oneifyCard n
- | otherwise = peelManyCalls (idArity bndr) cd
+ | otherwise = n `multCard` enterManyCalls (idArity bndr) cd
where
is_thunk = idArity bndr == 0
-- Let's pray idDemandInfo is still OK after unarise...
diff --git a/compiler/GHC/Types/Demand.hs b/compiler/GHC/Types/Demand.hs
index cecd2ccd1c..1720bf5ec3 100644
--- a/compiler/GHC/Types/Demand.hs
+++ b/compiler/GHC/Types/Demand.hs
@@ -22,6 +22,7 @@ module GHC.Types.Demand (
Demand(AbsDmd, BotDmd, (:*)),
SubDemand(Prod, Poly), mkProd, viewProd,
-- ** Algebra
+ botCard, topCard,
absDmd, topDmd, botDmd, seqDmd, topSubDmd,
-- *** Least upper bound
lubCard, lubDmd, lubSubDmd,
@@ -39,7 +40,8 @@ module GHC.Types.Demand (
lazyApply1Dmd, lazyApply2Dmd, strictOnceApply1Dmd, strictManyApply1Dmd,
-- ** Other @Demand@ operations
oneifyCard, oneifyDmd, strictifyDmd, strictifyDictDmd, lazifyDmd,
- peelCallDmd, peelManyCalls, mkCalledOnceDmd, mkCalledOnceDmds,
+ peelCallDmd, peelManyCalls, enterManyCalls,
+ mkCall, mkCalls, mkCalledOnceDmd, mkCalledOnceDmds,
mkWorkerDemand, subDemandIfEvaluated,
-- ** Extracting one-shot information
argOneShots, argsOneShots, saturatedByOneShots,
@@ -536,8 +538,8 @@ pattern C_0N = Card 0b111
{-# COMPLETE C_00, C_01, C_0N, C_10, C_11, C_1N :: Card #-}
-_botCard, topCard :: Card
-_botCard = C_10
+botCard, topCard :: Card
+botCard = C_10
topCard = C_0N
-- | True <=> lower bound is 1.
@@ -803,13 +805,14 @@ viewProd _ _
-- for Arity. Otherwise, #18304 bites us.
-- | A smart constructor for 'Call', applying rewrite rules along the semantic
--- equality @Call C_0N (Poly C_0N) === Poly C_0N@, simplifying to 'Poly' 'SubDemand's
+-- equality @Call n (Poly n) === Poly n@, simplifying to 'Poly' 'SubDemand's
-- when possible.
-mkCall :: CardNonAbs -> SubDemand -> SubDemand
-mkCall C_1N sd@(Poly Boxed C_1N) = sd
-mkCall C_0N sd@(Poly Boxed C_0N) = sd
-mkCall n sd = assertPpr (isCardNonAbs n) (ppr n $$ ppr sd) $
- Call n sd
+mkCall :: Card -> SubDemand -> SubDemand
+mkCall n sd@(Poly Boxed m) | n == m = sd
+mkCall n sd = Call n sd & assertPpr (isCardNonAbs n) (ppr n $$ ppr sd)
+
+mkCalls :: [Card] -> SubDemand -> SubDemand
+mkCalls ns sd = foldr mkCall sd ns
-- | @viewCall sd@ interprets @sd@ as a 'Call', expanding 'Poly' subdemands as
-- necessary.
@@ -1096,14 +1099,49 @@ mkCalledOnceDmds arity sd = iterate mkCalledOnceDmd sd !! arity
peelCallDmd :: SubDemand -> (Card, SubDemand)
peelCallDmd sd = viewCall sd `orElse` (topCard, topSubDmd)
--- Peels multiple nestings of 'Call' sub-demands and also returns
--- whether it was unsaturated in the form of a 'Card'inality, denoting
--- how many times the lambda body was entered.
+-- A "fusion helper" that allows efficient implementation of peelManyCalls and
+-- enterManyCalls (and possibly other folds in the future).
+peelManyCallsFB
+ :: Arity
+ -> SubDemand
+ -> (Card -> r -> r) -- "cons" / Call
+ -> (SubDemand -> r) -- "nil" / out of arity
+ -> r
+peelManyCallsFB k sd call end = peel k sd
+ where
+ peel 0 sd = end sd
+ peel k (peelCallDmd -> (n, sd)) = n `call` (peel (k-1) sd)
+{-# INLINE peelManyCallsFB #-}
+
+-- | 'peelManyCalls k sd' iterates 'peelCallDmd' `k` times on (the 'Call' body
+-- of) `sd`. E.g.
+--
+-- > peelManyCalls 3 CL(CM(CS(P(A)))) = ([L,M,S], P(A))
+--
+peelManyCalls :: Arity -> SubDemand -> ([Card], SubDemand)
+peelManyCalls k sd = peelManyCallsFB k sd cons nil
+ where
+ nil sd = ([], sd)
+ cons n (ns, sd) = (n:ns, sd)
+
+-- | 'enterManyCalls k sd' returns as a 'Card' how often `sd` was called with
+-- `k` many args. A more efficient variant of
+--
+-- > let (cards, _) = peelManyCalls in
+-- > in foldr multCard C_11 cards
+--
+-- If `sd` represents undersaturated calls (e.g., there are less than `k`
+-- 'Call's) then the resulting cardinality is 'topCard'. Examples:
+--
+-- > enterManyCalls 2 CS(C1(CM(A))) = S
+-- > enterManyCalls 3 CS(C1(CM(A))) = L
+-- > enterManyCalls 2 C1(C1(CM(A))) = 1
+-- > enterManyCalls 3 C1(C1(CM(A))) = M
+-- > enterManyCalls 3 C1(C1(A)) = A
+--
-- See Note [Demands from unsaturated function calls].
-peelManyCalls :: Int -> SubDemand -> Card
-peelManyCalls 0 _ = C_11
-peelManyCalls n (viewCall -> Just (m, sd)) = m `multCard` peelManyCalls (n-1) sd
-peelManyCalls _ _ = C_0N
+enterManyCalls :: Arity -> SubDemand -> Card
+enterManyCalls k sd = peelManyCallsFB k sd multCard (\_sd -> C_11)
-- | Extract the 'SubDemand' of a 'Demand'.
-- PRECONDITION: The SubDemand must be used in a context where the expression
@@ -1153,7 +1191,7 @@ argOneShots (_ :* sd) = go sd
saturatedByOneShots :: Int -> Demand -> Bool
saturatedByOneShots _ AbsDmd = True
saturatedByOneShots _ BotDmd = True
-saturatedByOneShots n (_ :* sd) = isUsedOnce (peelManyCalls n sd)
+saturatedByOneShots n (_ :* sd) = isUsedOnce (enterManyCalls n sd)
{- Note [Strict demands]
~~~~~~~~~~~~~~~~~~~~~~~~
@@ -2198,7 +2236,7 @@ type DmdTransformer = SubDemand -> DmdType
-- return how the function evaluates its free variables and arguments.
dmdTransformSig :: DmdSig -> DmdTransformer
dmdTransformSig (DmdSig dmd_ty@(DmdType _ arg_ds _)) sd
- = multDmdType (peelManyCalls (length arg_ds) sd) dmd_ty
+ = multDmdType (enterManyCalls (length arg_ds) sd) dmd_ty
-- see Note [Demands from unsaturated function calls]
-- and Note [What are demand signatures?]
diff --git a/compiler/GHC/Types/Id.hs b/compiler/GHC/Types/Id.hs
index 4d04c82a35..26a056dc35 100644
--- a/compiler/GHC/Types/Id.hs
+++ b/compiler/GHC/Types/Id.hs
@@ -188,7 +188,8 @@ infixl 1 `setIdUnfolding`,
`asJoinId`,
`asJoinId_maybe`,
- `setIdCbvMarks`
+ `setIdCbvMarks`,
+ `setCaseBndrEvald`
{-
************************************************************************