diff options
author | Simon Peyton Jones <simon.peytonjones@gmail.com> | 2023-01-13 16:42:01 +0000 |
---|---|---|
committer | Simon Peyton Jones <simon.peytonjones@gmail.com> | 2023-02-19 19:40:57 +0000 |
commit | b3cb11cc8416f443c3d1eaa5064c623a048a23ab (patch) | |
tree | cebb7b31c9656276d6688068acfb4e4616ccce05 | |
parent | 7267a5a0a1d4fdaae48d005f8531e11d3241e75a (diff) | |
download | haskell-b3cb11cc8416f443c3d1eaa5064c623a048a23ab.tar.gz |
Refactor WithTailJoinDetails
-rw-r--r-- | compiler/GHC/Core/Opt/OccurAnal.hs | 148 |
1 files changed, 68 insertions, 80 deletions
diff --git a/compiler/GHC/Core/Opt/OccurAnal.hs b/compiler/GHC/Core/Opt/OccurAnal.hs index 270cc5b8c6..26f2112119 100644 --- a/compiler/GHC/Core/Opt/OccurAnal.hs +++ b/compiler/GHC/Core/Opt/OccurAnal.hs @@ -819,7 +819,18 @@ of both functions, serving as a specification: data WithUsageDetails a = WithUsageDetails !UsageDetails !a -data WithTailUsageDetails a = WithTailUsageDetails !TailUsageDetails !a +-- | Captures the result of applying 'occAnalLamTail' to a function `\xyz.body`. +-- The TailUsageDetails records +-- * the number of lambdas (including type lambdas: a JoinArity) +-- * UsageDetails for the `body`, unadjusted by `adjustTailUsage`. +-- If the binding turns out to be a join point with the indicated join +-- arity, this unadjusted usage details is just what we need; otherwise we +-- need to discard tail calls. That's what `adjustTailUsage` does. +data Tail a = TE !JoinArity !a + +instance Outputable a => Outputable (Tail a) where + ppr (TE ja rhs) = text "TE" <> braces(ppr ja) <+> ppr rhs + ------------------------------------------------------------------ -- occAnalBind @@ -935,9 +946,9 @@ occAnalNonRecIdBind !env imp_rule_edges tagged_bndr rhs --------- Unfolding --------- -- See Note [Join points and unfoldings/rules] unf = idUnfolding tagged_bndr - WithTailUsageDetails unf_uds unf1 = occAnalUnfolding rhs_env unf + unf_wuds@(WithUsageDetails _ (TE _ unf1)) = occAnalUnfolding rhs_env unf unf2 = markNonRecUnfoldingOneShots mb_join_arity unf1 - adj_unf_uds = adjustTailArity mb_join_arity unf_uds + adj_unf_uds = adjustTailArity mb_join_arity unf_wuds --------- Rules --------- -- See Note [Rules are extra RHSs] and Note [Rule dependency info] @@ -991,19 +1002,22 @@ occAnalRecBind !rhs_env lvl imp_rule_edges pairs body_usage bndrs = map fst pairs bndr_set = mkVarSet bndrs -adjustNonRecRhs :: Maybe JoinArity -> WithTailUsageDetails CoreExpr -> WithUsageDetails CoreExpr +adjustNonRecRhs :: Maybe JoinArity -> WithUsageDetails (Tail CoreExpr) + -> WithUsageDetails CoreExpr -- ^ This function concentrates shared logic between occAnalNonRecBind and the -- AcyclicSCC case of occAnalRec. -- * It applies 'markNonRecJoinOneShots' to the RHS -- * and returns the adjusted rhs UsageDetails combined with the body usage -adjustNonRecRhs mb_join_arity (WithTailUsageDetails rhs_tuds rhs) +adjustNonRecRhs mb_join_arity (WithUsageDetails rhs_usage (TE ja rhs)) = WithUsageDetails rhs_uds' rhs' where --------- Marking (non-rec) join binders one-shot --------- !rhs' | Just ja <- mb_join_arity = markNonRecJoinOneShots ja rhs | otherwise = rhs + --------- Adjusting right-hand side usage --------- - rhs_uds' = adjustTailUsage mb_join_arity rhs' rhs_tuds + rhs_uds' = adjustTailUsage mb_join_arity $ + WithUsageDetails rhs_usage (TE ja rhs') bindersOfSCC :: SCC NodeDetails -> [Var] bindersOfSCC (AcyclicSCC nd) = [nd_bndr nd] @@ -1504,7 +1518,7 @@ type LetrecNode = Node Unique NodeDetails data NodeDetails = ND { nd_bndr :: Id -- Binder - , nd_rhs :: !(WithTailUsageDetails CoreExpr) + , nd_rhs :: !(WithUsageDetails (Tail CoreExpr)) -- ^ RHS, already occ-analysed -- With TailUsageDetails from RHS, and RULES, and stable unfoldings, -- ignoring phase (ie assuming all are active). @@ -1537,7 +1551,8 @@ instance Outputable NodeDetails where , text "simple =" <+> ppr (nd_simple nd) , text "active_rule_fvs =" <+> ppr (nd_active_rule_fvs nd) ]) - where WithTailUsageDetails uds _ = nd_rhs nd + where + WithUsageDetails uds _ = nd_rhs nd -- | Digraph with simplified and completely occurrence analysed -- 'SimpleNodeDetails', retaining just the info we need for breaking loops. @@ -1581,7 +1596,7 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs) -- explained in Note [Deterministic SCC] in GHC.Data.Graph.Directed. where details = ND { nd_bndr = bndr' - , nd_rhs = WithTailUsageDetails scope_uds rhs' + , nd_rhs = WithUsageDetails unadj_scope_uds rhs_te , nd_inl = inl_fvs , nd_simple = null rules_w_uds && null imp_rule_info , nd_weak_fvs = weak_fvs @@ -1594,7 +1609,6 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs) -- JoinArity rhs_ja of unadj_rhs_uds. unadj_inl_uds = unadj_rhs_uds `andUDs` adj_unf_uds unadj_scope_uds = unadj_inl_uds `andUDs` adj_rule_uds - scope_uds = TUD rhs_ja unadj_scope_uds -- Note [Rules are extra RHSs] -- Note [Rule dependency info] scope_fvs = udFreeVars bndr_set unadj_scope_uds @@ -1623,15 +1637,15 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs) -- until occAnalRec. In effect, we pretend that the RHS becomes a -- non-recursive join point and fix up later with adjustTailUsage. rhs_env = setRhsCtxt OccRhs env - WithTailUsageDetails (TUD rhs_ja unadj_rhs_uds) rhs' = occAnalLamTail rhs_env rhs + WithUsageDetails unadj_rhs_uds rhs_te@(TE rhs_ja _) = occAnalLamTail rhs_env rhs -- corresponding call to adjustTailUsage in occAnalRec and tagRecBinders --------- Unfolding --------- -- See Note [Join points and unfoldings/rules] unf = realIdUnfolding bndr -- realIdUnfolding: Ignore loop-breaker-ness -- here because that is what we are setting! - WithTailUsageDetails unf_tuds unf' = occAnalUnfolding rhs_env unf - adj_unf_uds = adjustTailArity (Just rhs_ja) unf_tuds + unf_wuds@(WithUsageDetails _ (TE _ unf')) = occAnalUnfolding rhs_env unf + adj_unf_uds = adjustTailArity (Just rhs_ja) unf_wuds -- `rhs_ja` is `joinRhsArity rhs` and is the prediction for source M -- of Note [Join arity prediction based on joinRhsArity] @@ -1646,8 +1660,8 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs) -- `rhs_ja` is `joinRhsArity rhs'` and is the prediction for source M -- of Note [Join arity prediction based on joinRhsArity] rules_w_uds :: [(CoreRule, UsageDetails, UsageDetails)] - rules_w_uds = [ (r,l,adjustTailArity (Just rhs_ja) rhs_tuds) - | (r,l,rhs_tuds) <- occAnalRules rhs_env bndr ] + rules_w_uds = [ (r,l,adjustTailArity (Just rhs_ja) rhs_wuds) + | (r,l,rhs_wuds) <- occAnalRules rhs_env bndr ] rules' = map fstOf3 rules_w_uds adj_rule_uds = foldr add_rule_uds imp_rule_uds rules_w_uds @@ -1693,7 +1707,7 @@ mkLoopBreakerNodes !env lvl body_uds details_s -- in nondeterministic order as explained in -- Note [Deterministic SCC] in GHC.Data.Graph.Directed. where - WithTailUsageDetails _ rhs = nd_rhs nd + WithUsageDetails _ (TE _ rhs) = nd_rhs nd simple_nd = SND { snd_bndr = new_bndr, snd_rhs = rhs, snd_score = score } score = nodeScore env new_bndr lb_deps nd lb_deps = extendFvs_ rule_fv_env inl_fvs @@ -1733,7 +1747,7 @@ nodeScore :: OccEnv -> NodeDetails -> NodeScore nodeScore !env new_bndr lb_deps - (ND { nd_bndr = old_bndr, nd_rhs = WithTailUsageDetails _ bind_rhs }) + (ND { nd_bndr = old_bndr, nd_rhs = WithUsageDetails _ (TE _ bind_rhs) }) | not (isId old_bndr) -- A type or coercion variable is never a loop breaker = (100, 0, False) @@ -2010,7 +2024,7 @@ zapLambdaBndrs fun arg_count zap_bndr b | isTyVar b = b | otherwise = zapLamIdInfo b -occAnalLamTail :: OccEnv -> CoreExpr -> WithTailUsageDetails CoreExpr +occAnalLamTail :: OccEnv -> CoreExpr -> WithUsageDetails (Tail CoreExpr) -- ^ See Note [Occurrence analysis for lambda binders]. -- It does the following: -- * Sets one-shot info on the lambda binder from the OccEnv, and @@ -2032,16 +2046,16 @@ occAnalLamTail :: OccEnv -> CoreExpr -> WithTailUsageDetails CoreExpr -- See Note [Adjusting right-hand sides] occAnalLamTail env (Lam bndr expr) | isTyVar bndr - = addInScopeTail env [bndr] $ \env -> - let WithTailUsageDetails (TUD ja usage) expr' = occAnalLamTail env expr - in WithTailUsageDetails (TUD (ja+1) usage) (Lam bndr expr') + = addInScope env [bndr] $ \env -> + let WithUsageDetails usage (TE ja expr') = occAnalLamTail env expr + in WithUsageDetails usage (TE (ja+1) (Lam bndr expr')) -- Important: Do not modify occ_encl, so that with a RHS like -- \(@ x) -> K @x (f @x) -- we'll see that (K @x (f @x)) is in a OccRhs, and hence refrain -- from inlining f. See the beginning of Note [Cascading inlines]. | otherwise -- So 'bndr' is an Id - = addInScopeTail env [bndr] $ \env -> + = addInScope env [bndr] $ \env -> let (env_one_shots', bndr1) = case occ_one_shots env of [] -> ([], bndr) @@ -2052,14 +2066,14 @@ occAnalLamTail env (Lam bndr expr) -- See Note [The oneShot function] env1 = env { occ_encl = OccVanilla, occ_one_shots = env_one_shots' } - WithTailUsageDetails (TUD ja usage) expr' = occAnalLamTail env1 expr + WithUsageDetails usage (TE ja expr') = occAnalLamTail env1 expr bndr2 = tagLamBinder usage bndr1 - in WithTailUsageDetails (TUD (ja+1) usage) (Lam bndr2 expr') + in WithUsageDetails usage (TE (ja+1) (Lam bndr2 expr')) -- For casts, keep going in the same lambda-group -- See Note [Occurrence analysis for lambda binders] occAnalLamTail env (Cast expr co) - = let WithTailUsageDetails (TUD ja usage) expr' = occAnalLamTail env expr + = let WithUsageDetails usage (TE ja expr') = occAnalLamTail env expr -- usage1: see Note [Gather occurrences of coercion variables] usage1 = addManyOccs usage (coVarsOfCo co) @@ -2075,10 +2089,10 @@ occAnalLamTail env (Cast expr co) -- GHC.Core.Lint: Note Note [Join points and casts] usage3 = markAllNonTail usage2 - in WithTailUsageDetails (TUD ja usage3) (Cast expr' co) + in WithUsageDetails usage3 (TE ja (Cast expr' co)) occAnalLamTail env expr = case occAnal env expr of - WithUsageDetails usage expr' -> WithTailUsageDetails (TUD 0 usage) expr' + WithUsageDetails usage expr' -> WithUsageDetails usage (TE 0 expr') {- Note [Occ-anal and cast worker/wrapper] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -2111,7 +2125,7 @@ of a right hand side is handled by occAnalLamTail. occAnalUnfolding :: OccEnv -> Unfolding - -> WithTailUsageDetails Unfolding + -> WithUsageDetails (Tail Unfolding) -- Occurrence-analyse a stable unfolding; -- discard a non-stable one altogether and return empty usage details. occAnalUnfolding !env unf @@ -2119,13 +2133,14 @@ occAnalUnfolding !env unf unf@(CoreUnfolding { uf_tmpl = rhs, uf_src = src }) | isStableSource src -> let - WithTailUsageDetails (TUD rhs_ja usage) rhs' = occAnalLamTail env rhs + WithUsageDetails usage (TE rhs_ja rhs') = occAnalLamTail env rhs unf' | noBinderSwaps env = unf -- Note [Unfoldings and rules] | otherwise = unf { uf_tmpl = rhs' } - in WithTailUsageDetails (TUD rhs_ja (markAllMany usage)) unf' + in WithUsageDetails (markAllMany usage) (TE rhs_ja unf') -- markAllMany: see Note [Occurrences in stable unfoldings] - | otherwise -> WithTailUsageDetails (TUD 0 emptyDetails) unf + + | otherwise -> WithUsageDetails emptyDetails (TE 0 unf) -- For non-Stable unfoldings we leave them undisturbed, but -- don't count their usage because the simplifier will discard them. -- We leave them undisturbed because nodeScore uses their size info @@ -2136,22 +2151,22 @@ occAnalUnfolding !env unf unf@(DFunUnfolding { df_bndrs = bndrs, df_args = args }) -> let WithUsageDetails uds args' = addInScope env bndrs $ \ env -> occAnalList env args - in WithTailUsageDetails (TUD 0 uds) (unf { df_args = args' }) + in WithUsageDetails uds (TE 0 (unf { df_args = args' })) -- No need to use tagLamBinders because we -- never inline DFuns so the occ-info on binders doesn't matter - unf -> WithTailUsageDetails (TUD 0 emptyDetails) unf + unf -> WithUsageDetails emptyDetails (TE 0 unf) occAnalRules :: OccEnv -> Id -- Get rules from here -> [(CoreRule, -- Each (non-built-in) rule UsageDetails, -- Usage details for LHS - TailUsageDetails)] -- Usage details for RHS + WithUsageDetails (Tail ()))] -- Usage details for RHS occAnalRules !env bndr = map occ_anal_rule (idCoreRules bndr) where occ_anal_rule rule@(Rule { ru_bndrs = bndrs, ru_args = args, ru_rhs = rhs }) - = (rule', lhs_uds', TUD rhs_ja rhs_uds') + = (rule', lhs_uds', WithUsageDetails rhs_uds' (TE rhs_ja ())) where rule' | noBinderSwaps env = rule -- Note [Unfoldings and rules] | otherwise = rule { ru_args = args', ru_rhs = rhs' } @@ -2167,7 +2182,8 @@ occAnalRules !env bndr rhs_uds' = markAllMany rhs_uds rhs_ja = length args -- See Note [Join points and unfoldings/rules] - occ_anal_rule other_rule = (other_rule, emptyDetails, TUD 0 emptyDetails) + occ_anal_rule other_rule = ( other_rule, emptyDetails + , WithUsageDetails emptyDetails (TE 0 ())) {- Note [Join point RHSs] ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -2742,25 +2758,10 @@ isRhsEnv (OccEnv { occ_encl = cxt }) = case cxt of addInScope :: OccEnv -> [Var] -> (OccEnv -> WithUsageDetails a) -> WithUsageDetails a -addInScope = add_in_scope fix_up - where - fix_up fix_up_usage (WithUsageDetails usage res) - = WithUsageDetails (fix_up_usage usage) res - -addInScopeTail :: OccEnv -> [Var] -> (OccEnv -> WithTailUsageDetails a) - -> WithTailUsageDetails a -addInScopeTail = add_in_scope fix_up - where - fix_up fix_up_usage (WithTailUsageDetails (TUD ja usage) res) - = WithTailUsageDetails (TUD ja (fix_up_usage usage)) res - -add_in_scope :: ((UsageDetails -> UsageDetails) -> res -> res) - -> OccEnv -> [Var] -> (OccEnv -> res) -> res -- Needed for all Vars not just Ids; a TyVar might have a CoVars in its kind -add_in_scope fix_up_result - env@(OccEnv { occ_join_points = join_points }) - bndrs thing_inside - = fix_up_result fix_up_uds $ thing_inside $ +addInScope env@(OccEnv { occ_join_points = join_points }) + bndrs thing_inside + = fix_up_uds $ thing_inside $ drop_shadowed_swaps $ drop_shadowed_joins env where @@ -2776,11 +2777,11 @@ add_in_scope fix_up_result -- See Note [Occurrence analysis for join points] drop_shadowed_joins env = env { occ_join_points = good_joins `delVarEnvList` bndrs} - fix_up_uds :: UsageDetails -> UsageDetails + fix_up_uds :: WithUsageDetails a -> WithUsageDetails a -- Remove usage for bndrs -- Add usage info for (a) CoVars used in the types of bndrs -- and (b) occ_join_points that we cannot push inwards because of shadowing - fix_up_uds uds = with_joins + fix_up_uds (WithUsageDetails uds res) = WithUsageDetails with_joins res where trimmed_uds = uds `delDetails` bndrs with_co_var_occs = trimmed_uds `addManyOccs` coVarOccs bndrs @@ -3231,18 +3232,6 @@ data UsageDetails instance Outputable UsageDetails where ppr ud = ppr (ud_env (flattenUsageDetails ud)) --- | Captures the result of applying 'occAnalLamTail' to a function `\xyz.body`. --- The TailUsageDetails records --- * the number of lambdas (including type lambdas: a JoinArity) --- * UsageDetails for the `body`, unadjusted by `adjustTailUsage`. --- If the binding turns out to be a join point with the indicated join --- arity, this unadjusted usage details is just what we need; otherwise we --- need to discard tail calls. That's what `adjustTailUsage` does. -data TailUsageDetails = TUD !JoinArity !UsageDetails - -instance Outputable TailUsageDetails where - ppr (TUD ja uds) = lambda <> ppr ja <> ppr uds - ------------------- -- UsageDetails API @@ -3395,10 +3384,9 @@ flattenUsageDetails ud@(UD { ud_env = env }) ------------------- -- See Note [Adjusting right-hand sides] adjustTailUsage :: Maybe JoinArity - -> CoreExpr -- Rhs, AFTER occAnalLamTail - -> TailUsageDetails -- From body of lambda - -> UsageDetails -adjustTailUsage mb_join_arity rhs (TUD rhs_ja usage) + -> WithUsageDetails (Tail CoreExpr) -- Rhs, AFTER occAnalLamTail + -> UsageDetails +adjustTailUsage mb_join_arity (WithUsageDetails usage (TE rhs_ja rhs)) = -- c.f. occAnal (Lam {}) markAllInsideLamIf (not one_shot) $ markAllNonTailIf (not exact_join) $ @@ -3407,9 +3395,9 @@ adjustTailUsage mb_join_arity rhs (TUD rhs_ja usage) one_shot = isOneShotFun rhs exact_join = mb_join_arity == Just rhs_ja -adjustTailArity :: Maybe JoinArity -> TailUsageDetails -> UsageDetails -adjustTailArity mb_rhs_ja (TUD ud_ja usage) - = markAllNonTailIf (mb_rhs_ja /= Just ud_ja) usage +adjustTailArity :: Maybe JoinArity -> WithUsageDetails (Tail a) -> UsageDetails +adjustTailArity mb_rhs_ja (WithUsageDetails usage (TE ja _)) + = markAllNonTailIf (mb_rhs_ja /= Just ja) usage markNonRecJoinOneShots :: JoinArity -> CoreExpr -> CoreExpr -- For a /non-recursive/ join point we can mark all @@ -3490,8 +3478,8 @@ tagRecBinders lvl body_uds details_s -- manifest join arity M. -- This (re-)asserts that makeNode had made tuds for that same arity M! unadj_uds = foldr (andUDs . test_manifest_arity) body_uds details_s - test_manifest_arity ND{nd_rhs=WithTailUsageDetails tuds rhs} - = adjustTailArity (Just (joinRhsArity rhs)) tuds + test_manifest_arity ND{nd_rhs = wud_rhs@(WithUsageDetails _ (TE _ rhs))} + = adjustTailArity (Just (joinRhsArity rhs)) wud_rhs bndr_ne = expectNonEmpty "List of binders is never empty" bndrs will_be_joins = decideJoinPointHood lvl unadj_uds bndr_ne @@ -3512,9 +3500,9 @@ tagRecBinders lvl body_uds details_s -- 2. Adjust usage details of each RHS, taking into account the -- join-point-hood decision - rhs_udss' = [ adjustTailUsage (mb_join_arity bndr) rhs rhs_tuds -- matching occAnalLamTail in makeNode - | ND { nd_bndr = bndr, nd_rhs = WithTailUsageDetails rhs_tuds rhs } - <- details_s ] + rhs_udss' = [ adjustTailUsage (mb_join_arity bndr) rhs_wuds + -- Matching occAnalLamTail in makeNode + | ND { nd_bndr = bndr, nd_rhs = rhs_wuds } <- details_s ] -- 3. Compute final usage details from adjusted RHS details adj_uds = foldr andUDs body_uds rhs_udss' |