summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSimon Peyton Jones <simon.peytonjones@gmail.com>2023-01-13 16:42:01 +0000
committerSimon Peyton Jones <simon.peytonjones@gmail.com>2023-02-19 19:40:57 +0000
commitb3cb11cc8416f443c3d1eaa5064c623a048a23ab (patch)
treecebb7b31c9656276d6688068acfb4e4616ccce05
parent7267a5a0a1d4fdaae48d005f8531e11d3241e75a (diff)
downloadhaskell-b3cb11cc8416f443c3d1eaa5064c623a048a23ab.tar.gz
Refactor WithTailJoinDetails
-rw-r--r--compiler/GHC/Core/Opt/OccurAnal.hs148
1 files changed, 68 insertions, 80 deletions
diff --git a/compiler/GHC/Core/Opt/OccurAnal.hs b/compiler/GHC/Core/Opt/OccurAnal.hs
index 270cc5b8c6..26f2112119 100644
--- a/compiler/GHC/Core/Opt/OccurAnal.hs
+++ b/compiler/GHC/Core/Opt/OccurAnal.hs
@@ -819,7 +819,18 @@ of both functions, serving as a specification:
data WithUsageDetails a = WithUsageDetails !UsageDetails !a
-data WithTailUsageDetails a = WithTailUsageDetails !TailUsageDetails !a
+-- | Captures the result of applying 'occAnalLamTail' to a function `\xyz.body`.
+-- The TailUsageDetails records
+-- * the number of lambdas (including type lambdas: a JoinArity)
+-- * UsageDetails for the `body`, unadjusted by `adjustTailUsage`.
+-- If the binding turns out to be a join point with the indicated join
+-- arity, this unadjusted usage details is just what we need; otherwise we
+-- need to discard tail calls. That's what `adjustTailUsage` does.
+data Tail a = TE !JoinArity !a
+
+instance Outputable a => Outputable (Tail a) where
+ ppr (TE ja rhs) = text "TE" <> braces(ppr ja) <+> ppr rhs
+
------------------------------------------------------------------
-- occAnalBind
@@ -935,9 +946,9 @@ occAnalNonRecIdBind !env imp_rule_edges tagged_bndr rhs
--------- Unfolding ---------
-- See Note [Join points and unfoldings/rules]
unf = idUnfolding tagged_bndr
- WithTailUsageDetails unf_uds unf1 = occAnalUnfolding rhs_env unf
+ unf_wuds@(WithUsageDetails _ (TE _ unf1)) = occAnalUnfolding rhs_env unf
unf2 = markNonRecUnfoldingOneShots mb_join_arity unf1
- adj_unf_uds = adjustTailArity mb_join_arity unf_uds
+ adj_unf_uds = adjustTailArity mb_join_arity unf_wuds
--------- Rules ---------
-- See Note [Rules are extra RHSs] and Note [Rule dependency info]
@@ -991,19 +1002,22 @@ occAnalRecBind !rhs_env lvl imp_rule_edges pairs body_usage
bndrs = map fst pairs
bndr_set = mkVarSet bndrs
-adjustNonRecRhs :: Maybe JoinArity -> WithTailUsageDetails CoreExpr -> WithUsageDetails CoreExpr
+adjustNonRecRhs :: Maybe JoinArity -> WithUsageDetails (Tail CoreExpr)
+ -> WithUsageDetails CoreExpr
-- ^ This function concentrates shared logic between occAnalNonRecBind and the
-- AcyclicSCC case of occAnalRec.
-- * It applies 'markNonRecJoinOneShots' to the RHS
-- * and returns the adjusted rhs UsageDetails combined with the body usage
-adjustNonRecRhs mb_join_arity (WithTailUsageDetails rhs_tuds rhs)
+adjustNonRecRhs mb_join_arity (WithUsageDetails rhs_usage (TE ja rhs))
= WithUsageDetails rhs_uds' rhs'
where
--------- Marking (non-rec) join binders one-shot ---------
!rhs' | Just ja <- mb_join_arity = markNonRecJoinOneShots ja rhs
| otherwise = rhs
+
--------- Adjusting right-hand side usage ---------
- rhs_uds' = adjustTailUsage mb_join_arity rhs' rhs_tuds
+ rhs_uds' = adjustTailUsage mb_join_arity $
+ WithUsageDetails rhs_usage (TE ja rhs')
bindersOfSCC :: SCC NodeDetails -> [Var]
bindersOfSCC (AcyclicSCC nd) = [nd_bndr nd]
@@ -1504,7 +1518,7 @@ type LetrecNode = Node Unique NodeDetails
data NodeDetails
= ND { nd_bndr :: Id -- Binder
- , nd_rhs :: !(WithTailUsageDetails CoreExpr)
+ , nd_rhs :: !(WithUsageDetails (Tail CoreExpr))
-- ^ RHS, already occ-analysed
-- With TailUsageDetails from RHS, and RULES, and stable unfoldings,
-- ignoring phase (ie assuming all are active).
@@ -1537,7 +1551,8 @@ instance Outputable NodeDetails where
, text "simple =" <+> ppr (nd_simple nd)
, text "active_rule_fvs =" <+> ppr (nd_active_rule_fvs nd)
])
- where WithTailUsageDetails uds _ = nd_rhs nd
+ where
+ WithUsageDetails uds _ = nd_rhs nd
-- | Digraph with simplified and completely occurrence analysed
-- 'SimpleNodeDetails', retaining just the info we need for breaking loops.
@@ -1581,7 +1596,7 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs)
-- explained in Note [Deterministic SCC] in GHC.Data.Graph.Directed.
where
details = ND { nd_bndr = bndr'
- , nd_rhs = WithTailUsageDetails scope_uds rhs'
+ , nd_rhs = WithUsageDetails unadj_scope_uds rhs_te
, nd_inl = inl_fvs
, nd_simple = null rules_w_uds && null imp_rule_info
, nd_weak_fvs = weak_fvs
@@ -1594,7 +1609,6 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs)
-- JoinArity rhs_ja of unadj_rhs_uds.
unadj_inl_uds = unadj_rhs_uds `andUDs` adj_unf_uds
unadj_scope_uds = unadj_inl_uds `andUDs` adj_rule_uds
- scope_uds = TUD rhs_ja unadj_scope_uds
-- Note [Rules are extra RHSs]
-- Note [Rule dependency info]
scope_fvs = udFreeVars bndr_set unadj_scope_uds
@@ -1623,15 +1637,15 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs)
-- until occAnalRec. In effect, we pretend that the RHS becomes a
-- non-recursive join point and fix up later with adjustTailUsage.
rhs_env = setRhsCtxt OccRhs env
- WithTailUsageDetails (TUD rhs_ja unadj_rhs_uds) rhs' = occAnalLamTail rhs_env rhs
+ WithUsageDetails unadj_rhs_uds rhs_te@(TE rhs_ja _) = occAnalLamTail rhs_env rhs
-- corresponding call to adjustTailUsage in occAnalRec and tagRecBinders
--------- Unfolding ---------
-- See Note [Join points and unfoldings/rules]
unf = realIdUnfolding bndr -- realIdUnfolding: Ignore loop-breaker-ness
-- here because that is what we are setting!
- WithTailUsageDetails unf_tuds unf' = occAnalUnfolding rhs_env unf
- adj_unf_uds = adjustTailArity (Just rhs_ja) unf_tuds
+ unf_wuds@(WithUsageDetails _ (TE _ unf')) = occAnalUnfolding rhs_env unf
+ adj_unf_uds = adjustTailArity (Just rhs_ja) unf_wuds
-- `rhs_ja` is `joinRhsArity rhs` and is the prediction for source M
-- of Note [Join arity prediction based on joinRhsArity]
@@ -1646,8 +1660,8 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs)
-- `rhs_ja` is `joinRhsArity rhs'` and is the prediction for source M
-- of Note [Join arity prediction based on joinRhsArity]
rules_w_uds :: [(CoreRule, UsageDetails, UsageDetails)]
- rules_w_uds = [ (r,l,adjustTailArity (Just rhs_ja) rhs_tuds)
- | (r,l,rhs_tuds) <- occAnalRules rhs_env bndr ]
+ rules_w_uds = [ (r,l,adjustTailArity (Just rhs_ja) rhs_wuds)
+ | (r,l,rhs_wuds) <- occAnalRules rhs_env bndr ]
rules' = map fstOf3 rules_w_uds
adj_rule_uds = foldr add_rule_uds imp_rule_uds rules_w_uds
@@ -1693,7 +1707,7 @@ mkLoopBreakerNodes !env lvl body_uds details_s
-- in nondeterministic order as explained in
-- Note [Deterministic SCC] in GHC.Data.Graph.Directed.
where
- WithTailUsageDetails _ rhs = nd_rhs nd
+ WithUsageDetails _ (TE _ rhs) = nd_rhs nd
simple_nd = SND { snd_bndr = new_bndr, snd_rhs = rhs, snd_score = score }
score = nodeScore env new_bndr lb_deps nd
lb_deps = extendFvs_ rule_fv_env inl_fvs
@@ -1733,7 +1747,7 @@ nodeScore :: OccEnv
-> NodeDetails
-> NodeScore
nodeScore !env new_bndr lb_deps
- (ND { nd_bndr = old_bndr, nd_rhs = WithTailUsageDetails _ bind_rhs })
+ (ND { nd_bndr = old_bndr, nd_rhs = WithUsageDetails _ (TE _ bind_rhs) })
| not (isId old_bndr) -- A type or coercion variable is never a loop breaker
= (100, 0, False)
@@ -2010,7 +2024,7 @@ zapLambdaBndrs fun arg_count
zap_bndr b | isTyVar b = b
| otherwise = zapLamIdInfo b
-occAnalLamTail :: OccEnv -> CoreExpr -> WithTailUsageDetails CoreExpr
+occAnalLamTail :: OccEnv -> CoreExpr -> WithUsageDetails (Tail CoreExpr)
-- ^ See Note [Occurrence analysis for lambda binders].
-- It does the following:
-- * Sets one-shot info on the lambda binder from the OccEnv, and
@@ -2032,16 +2046,16 @@ occAnalLamTail :: OccEnv -> CoreExpr -> WithTailUsageDetails CoreExpr
-- See Note [Adjusting right-hand sides]
occAnalLamTail env (Lam bndr expr)
| isTyVar bndr
- = addInScopeTail env [bndr] $ \env ->
- let WithTailUsageDetails (TUD ja usage) expr' = occAnalLamTail env expr
- in WithTailUsageDetails (TUD (ja+1) usage) (Lam bndr expr')
+ = addInScope env [bndr] $ \env ->
+ let WithUsageDetails usage (TE ja expr') = occAnalLamTail env expr
+ in WithUsageDetails usage (TE (ja+1) (Lam bndr expr'))
-- Important: Do not modify occ_encl, so that with a RHS like
-- \(@ x) -> K @x (f @x)
-- we'll see that (K @x (f @x)) is in a OccRhs, and hence refrain
-- from inlining f. See the beginning of Note [Cascading inlines].
| otherwise -- So 'bndr' is an Id
- = addInScopeTail env [bndr] $ \env ->
+ = addInScope env [bndr] $ \env ->
let (env_one_shots', bndr1)
= case occ_one_shots env of
[] -> ([], bndr)
@@ -2052,14 +2066,14 @@ occAnalLamTail env (Lam bndr expr)
-- See Note [The oneShot function]
env1 = env { occ_encl = OccVanilla, occ_one_shots = env_one_shots' }
- WithTailUsageDetails (TUD ja usage) expr' = occAnalLamTail env1 expr
+ WithUsageDetails usage (TE ja expr') = occAnalLamTail env1 expr
bndr2 = tagLamBinder usage bndr1
- in WithTailUsageDetails (TUD (ja+1) usage) (Lam bndr2 expr')
+ in WithUsageDetails usage (TE (ja+1) (Lam bndr2 expr'))
-- For casts, keep going in the same lambda-group
-- See Note [Occurrence analysis for lambda binders]
occAnalLamTail env (Cast expr co)
- = let WithTailUsageDetails (TUD ja usage) expr' = occAnalLamTail env expr
+ = let WithUsageDetails usage (TE ja expr') = occAnalLamTail env expr
-- usage1: see Note [Gather occurrences of coercion variables]
usage1 = addManyOccs usage (coVarsOfCo co)
@@ -2075,10 +2089,10 @@ occAnalLamTail env (Cast expr co)
-- GHC.Core.Lint: Note Note [Join points and casts]
usage3 = markAllNonTail usage2
- in WithTailUsageDetails (TUD ja usage3) (Cast expr' co)
+ in WithUsageDetails usage3 (TE ja (Cast expr' co))
occAnalLamTail env expr = case occAnal env expr of
- WithUsageDetails usage expr' -> WithTailUsageDetails (TUD 0 usage) expr'
+ WithUsageDetails usage expr' -> WithUsageDetails usage (TE 0 expr')
{- Note [Occ-anal and cast worker/wrapper]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -2111,7 +2125,7 @@ of a right hand side is handled by occAnalLamTail.
occAnalUnfolding :: OccEnv
-> Unfolding
- -> WithTailUsageDetails Unfolding
+ -> WithUsageDetails (Tail Unfolding)
-- Occurrence-analyse a stable unfolding;
-- discard a non-stable one altogether and return empty usage details.
occAnalUnfolding !env unf
@@ -2119,13 +2133,14 @@ occAnalUnfolding !env unf
unf@(CoreUnfolding { uf_tmpl = rhs, uf_src = src })
| isStableSource src ->
let
- WithTailUsageDetails (TUD rhs_ja usage) rhs' = occAnalLamTail env rhs
+ WithUsageDetails usage (TE rhs_ja rhs') = occAnalLamTail env rhs
unf' | noBinderSwaps env = unf -- Note [Unfoldings and rules]
| otherwise = unf { uf_tmpl = rhs' }
- in WithTailUsageDetails (TUD rhs_ja (markAllMany usage)) unf'
+ in WithUsageDetails (markAllMany usage) (TE rhs_ja unf')
-- markAllMany: see Note [Occurrences in stable unfoldings]
- | otherwise -> WithTailUsageDetails (TUD 0 emptyDetails) unf
+
+ | otherwise -> WithUsageDetails emptyDetails (TE 0 unf)
-- For non-Stable unfoldings we leave them undisturbed, but
-- don't count their usage because the simplifier will discard them.
-- We leave them undisturbed because nodeScore uses their size info
@@ -2136,22 +2151,22 @@ occAnalUnfolding !env unf
unf@(DFunUnfolding { df_bndrs = bndrs, df_args = args })
-> let WithUsageDetails uds args' = addInScope env bndrs $ \ env ->
occAnalList env args
- in WithTailUsageDetails (TUD 0 uds) (unf { df_args = args' })
+ in WithUsageDetails uds (TE 0 (unf { df_args = args' }))
-- No need to use tagLamBinders because we
-- never inline DFuns so the occ-info on binders doesn't matter
- unf -> WithTailUsageDetails (TUD 0 emptyDetails) unf
+ unf -> WithUsageDetails emptyDetails (TE 0 unf)
occAnalRules :: OccEnv
-> Id -- Get rules from here
-> [(CoreRule, -- Each (non-built-in) rule
UsageDetails, -- Usage details for LHS
- TailUsageDetails)] -- Usage details for RHS
+ WithUsageDetails (Tail ()))] -- Usage details for RHS
occAnalRules !env bndr
= map occ_anal_rule (idCoreRules bndr)
where
occ_anal_rule rule@(Rule { ru_bndrs = bndrs, ru_args = args, ru_rhs = rhs })
- = (rule', lhs_uds', TUD rhs_ja rhs_uds')
+ = (rule', lhs_uds', WithUsageDetails rhs_uds' (TE rhs_ja ()))
where
rule' | noBinderSwaps env = rule -- Note [Unfoldings and rules]
| otherwise = rule { ru_args = args', ru_rhs = rhs' }
@@ -2167,7 +2182,8 @@ occAnalRules !env bndr
rhs_uds' = markAllMany rhs_uds
rhs_ja = length args -- See Note [Join points and unfoldings/rules]
- occ_anal_rule other_rule = (other_rule, emptyDetails, TUD 0 emptyDetails)
+ occ_anal_rule other_rule = ( other_rule, emptyDetails
+ , WithUsageDetails emptyDetails (TE 0 ()))
{- Note [Join point RHSs]
~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -2742,25 +2758,10 @@ isRhsEnv (OccEnv { occ_encl = cxt }) = case cxt of
addInScope :: OccEnv -> [Var] -> (OccEnv -> WithUsageDetails a)
-> WithUsageDetails a
-addInScope = add_in_scope fix_up
- where
- fix_up fix_up_usage (WithUsageDetails usage res)
- = WithUsageDetails (fix_up_usage usage) res
-
-addInScopeTail :: OccEnv -> [Var] -> (OccEnv -> WithTailUsageDetails a)
- -> WithTailUsageDetails a
-addInScopeTail = add_in_scope fix_up
- where
- fix_up fix_up_usage (WithTailUsageDetails (TUD ja usage) res)
- = WithTailUsageDetails (TUD ja (fix_up_usage usage)) res
-
-add_in_scope :: ((UsageDetails -> UsageDetails) -> res -> res)
- -> OccEnv -> [Var] -> (OccEnv -> res) -> res
-- Needed for all Vars not just Ids; a TyVar might have a CoVars in its kind
-add_in_scope fix_up_result
- env@(OccEnv { occ_join_points = join_points })
- bndrs thing_inside
- = fix_up_result fix_up_uds $ thing_inside $
+addInScope env@(OccEnv { occ_join_points = join_points })
+ bndrs thing_inside
+ = fix_up_uds $ thing_inside $
drop_shadowed_swaps $ drop_shadowed_joins env
where
@@ -2776,11 +2777,11 @@ add_in_scope fix_up_result
-- See Note [Occurrence analysis for join points]
drop_shadowed_joins env = env { occ_join_points = good_joins `delVarEnvList` bndrs}
- fix_up_uds :: UsageDetails -> UsageDetails
+ fix_up_uds :: WithUsageDetails a -> WithUsageDetails a
-- Remove usage for bndrs
-- Add usage info for (a) CoVars used in the types of bndrs
-- and (b) occ_join_points that we cannot push inwards because of shadowing
- fix_up_uds uds = with_joins
+ fix_up_uds (WithUsageDetails uds res) = WithUsageDetails with_joins res
where
trimmed_uds = uds `delDetails` bndrs
with_co_var_occs = trimmed_uds `addManyOccs` coVarOccs bndrs
@@ -3231,18 +3232,6 @@ data UsageDetails
instance Outputable UsageDetails where
ppr ud = ppr (ud_env (flattenUsageDetails ud))
--- | Captures the result of applying 'occAnalLamTail' to a function `\xyz.body`.
--- The TailUsageDetails records
--- * the number of lambdas (including type lambdas: a JoinArity)
--- * UsageDetails for the `body`, unadjusted by `adjustTailUsage`.
--- If the binding turns out to be a join point with the indicated join
--- arity, this unadjusted usage details is just what we need; otherwise we
--- need to discard tail calls. That's what `adjustTailUsage` does.
-data TailUsageDetails = TUD !JoinArity !UsageDetails
-
-instance Outputable TailUsageDetails where
- ppr (TUD ja uds) = lambda <> ppr ja <> ppr uds
-
-------------------
-- UsageDetails API
@@ -3395,10 +3384,9 @@ flattenUsageDetails ud@(UD { ud_env = env })
-------------------
-- See Note [Adjusting right-hand sides]
adjustTailUsage :: Maybe JoinArity
- -> CoreExpr -- Rhs, AFTER occAnalLamTail
- -> TailUsageDetails -- From body of lambda
- -> UsageDetails
-adjustTailUsage mb_join_arity rhs (TUD rhs_ja usage)
+ -> WithUsageDetails (Tail CoreExpr) -- Rhs, AFTER occAnalLamTail
+ -> UsageDetails
+adjustTailUsage mb_join_arity (WithUsageDetails usage (TE rhs_ja rhs))
= -- c.f. occAnal (Lam {})
markAllInsideLamIf (not one_shot) $
markAllNonTailIf (not exact_join) $
@@ -3407,9 +3395,9 @@ adjustTailUsage mb_join_arity rhs (TUD rhs_ja usage)
one_shot = isOneShotFun rhs
exact_join = mb_join_arity == Just rhs_ja
-adjustTailArity :: Maybe JoinArity -> TailUsageDetails -> UsageDetails
-adjustTailArity mb_rhs_ja (TUD ud_ja usage)
- = markAllNonTailIf (mb_rhs_ja /= Just ud_ja) usage
+adjustTailArity :: Maybe JoinArity -> WithUsageDetails (Tail a) -> UsageDetails
+adjustTailArity mb_rhs_ja (WithUsageDetails usage (TE ja _))
+ = markAllNonTailIf (mb_rhs_ja /= Just ja) usage
markNonRecJoinOneShots :: JoinArity -> CoreExpr -> CoreExpr
-- For a /non-recursive/ join point we can mark all
@@ -3490,8 +3478,8 @@ tagRecBinders lvl body_uds details_s
-- manifest join arity M.
-- This (re-)asserts that makeNode had made tuds for that same arity M!
unadj_uds = foldr (andUDs . test_manifest_arity) body_uds details_s
- test_manifest_arity ND{nd_rhs=WithTailUsageDetails tuds rhs}
- = adjustTailArity (Just (joinRhsArity rhs)) tuds
+ test_manifest_arity ND{nd_rhs = wud_rhs@(WithUsageDetails _ (TE _ rhs))}
+ = adjustTailArity (Just (joinRhsArity rhs)) wud_rhs
bndr_ne = expectNonEmpty "List of binders is never empty" bndrs
will_be_joins = decideJoinPointHood lvl unadj_uds bndr_ne
@@ -3512,9 +3500,9 @@ tagRecBinders lvl body_uds details_s
-- 2. Adjust usage details of each RHS, taking into account the
-- join-point-hood decision
- rhs_udss' = [ adjustTailUsage (mb_join_arity bndr) rhs rhs_tuds -- matching occAnalLamTail in makeNode
- | ND { nd_bndr = bndr, nd_rhs = WithTailUsageDetails rhs_tuds rhs }
- <- details_s ]
+ rhs_udss' = [ adjustTailUsage (mb_join_arity bndr) rhs_wuds
+ -- Matching occAnalLamTail in makeNode
+ | ND { nd_bndr = bndr, nd_rhs = rhs_wuds } <- details_s ]
-- 3. Compute final usage details from adjusted RHS details
adj_uds = foldr andUDs body_uds rhs_udss'