diff options
author | Sebastian Graf <sebastian.graf@kit.edu> | 2022-11-14 17:40:01 +0100 |
---|---|---|
committer | Marge Bot <ben+marge-bot@smart-cactus.org> | 2023-01-12 15:51:47 -0500 |
commit | 905d0b6e1db714b306a940fb58a570c9294aa88d (patch) | |
tree | db8750594a94ad305f7f98fe87e7ccf8d2b17d6f | |
parent | 9ffd5d57a7cc19bcd6ea0139b00c77639566ba82 (diff) | |
download | haskell-905d0b6e1db714b306a940fb58a570c9294aa88d.tar.gz |
Fix contification with stable unfoldings (#22428)
Many functions now return a `TailUsageDetails` that adorns a `UsageDetails` with
a `JoinArity` that reflects the number of join point binders around the body
for which the `UsageDetails` was computed. `TailUsageDetails` is now returned by
`occAnalLamTail` as well as `occAnalUnfolding` and `occAnalRules`.
I adjusted `Note [Join points and unfoldings/rules]` and
`Note [Adjusting right-hand sides]` to account for the new machinery.
I also wrote a new `Note [Join arity prediction based on joinRhsArity]`
and refer to it when we combine `TailUsageDetails` for a recursive RHS.
I also renamed
* `occAnalLam` to `occAnalLamTail`
* `adjustRhsUsage` to `adjustTailUsage`
* a few other less important functions
and properly documented the that each call of `occAnalLamTail` must pair up with
`adjustTailUsage`.
I removed `Note [Unfoldings and join points]` because it was redundant with
`Note [Occurrences in stable unfoldings]`.
While in town, I refactored `mkLoopBreakerNodes` so that it returns a condensed
`NodeDetails` called `SimpleNodeDetails`.
Fixes #22428.
The refactoring seems to have quite beneficial effect on ghc/alloc performance:
```
CoOpt_Read(normal) ghc/alloc 784,778,420 768,091,176 -2.1% GOOD
T12150(optasm) ghc/alloc 77,762,270 75,986,720 -2.3% GOOD
T12425(optasm) ghc/alloc 85,740,186 84,641,712 -1.3% GOOD
T13056(optasm) ghc/alloc 306,104,656 299,811,632 -2.1% GOOD
T13253(normal) ghc/alloc 350,233,952 346,004,008 -1.2%
T14683(normal) ghc/alloc 2,800,514,792 2,754,651,360 -1.6%
T15304(normal) ghc/alloc 1,230,883,318 1,215,978,336 -1.2%
T15630(normal) ghc/alloc 153,379,590 151,796,488 -1.0%
T16577(normal) ghc/alloc 7,356,797,056 7,244,194,416 -1.5%
T17516(normal) ghc/alloc 1,718,941,448 1,692,157,288 -1.6%
T19695(normal) ghc/alloc 1,485,794,632 1,458,022,112 -1.9%
T21839c(normal) ghc/alloc 437,562,314 431,295,896 -1.4% GOOD
T21839r(normal) ghc/alloc 446,927,580 440,615,776 -1.4% GOOD
geo. mean -0.6%
minimum -2.4%
maximum -0.0%
```
Metric Decrease:
CoOpt_Read
T10421
T12150
T12425
T13056
T18698a
T18698b
T21839c
T21839r
T9961
-rw-r--r-- | compiler/GHC/Core/Opt/Arity.hs | 3 | ||||
-rw-r--r-- | compiler/GHC/Core/Opt/OccurAnal.hs | 639 | ||||
-rw-r--r-- | compiler/GHC/Data/Graph/Directed.hs | 3 | ||||
-rw-r--r-- | compiler/GHC/Utils/Misc.hs | 13 | ||||
-rw-r--r-- | testsuite/tests/simplCore/should_compile/T22428.hs | 9 | ||||
-rw-r--r-- | testsuite/tests/simplCore/should_compile/T22428.stderr | 45 | ||||
-rw-r--r-- | testsuite/tests/simplCore/should_compile/all.T | 4 |
7 files changed, 472 insertions, 244 deletions
diff --git a/compiler/GHC/Core/Opt/Arity.hs b/compiler/GHC/Core/Opt/Arity.hs index 832fba354c..5ed015281a 100644 --- a/compiler/GHC/Core/Opt/Arity.hs +++ b/compiler/GHC/Core/Opt/Arity.hs @@ -132,6 +132,9 @@ joinRhsArity :: CoreExpr -> JoinArity -- Join points are supposed to have manifestly-visible -- lambdas at the top: no ticks, no casts, nothing -- Moreover, type lambdas count in JoinArity +-- NB: For non-recursive bindings, the join arity of the binding may actually be +-- less that the number of manifestly-visible lambdas. +-- See Note [Join arity prediction based on joinRhsArity] in GHC.Core.Opt.OccurAnal joinRhsArity (Lam _ e) = 1 + joinRhsArity e joinRhsArity _ = 0 diff --git a/compiler/GHC/Core/Opt/OccurAnal.hs b/compiler/GHC/Core/Opt/OccurAnal.hs index 539074e698..fc374adb99 100644 --- a/compiler/GHC/Core/Opt/OccurAnal.hs +++ b/compiler/GHC/Core/Opt/OccurAnal.hs @@ -59,7 +59,7 @@ import GHC.Builtin.Names( runRWKey ) import GHC.Unit.Module( Module ) import Data.List (mapAccumL, mapAccumR) -import Data.List.NonEmpty (NonEmpty (..), nonEmpty) +import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as NE {- @@ -510,6 +510,16 @@ of the file. at any of the definitions. This is done by Simplify.simplRecBind, when it calls addLetIdInfo. +Note [TailUsageDetails when forming Rec groups] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +The `TailUsageDetails` stored in the `nd_uds` field of a `NodeDetails` is +computed by `occAnalLamTail` applied to the RHS, not `occAnalExpr`. +That is because the binding might still become a *non-recursive join point* in +the AcyclicSCC case of dependency analysis! +Hence we do the delayed `adjustTailUsage` in `occAnalRec`/`tagRecBinders` to get +a regular, adjusted UsageDetails. +See Note [Join points and unfoldings/rules] for more details on the contract. + Note [Stable unfoldings] ~~~~~~~~~~~~~~~~~~~~~~~~ None of the above stuff about RULES applies to a stable unfolding @@ -608,6 +618,65 @@ tail call with `n` arguments (counting both value and type arguments). Otherwise 'occ_tail' will be 'NoTailCallInfo'. The tail call info flows bottom-up with the rest of 'OccInfo' until it goes on the binder. +Note [Join arity prediction based on joinRhsArity] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +In general, the join arity from tail occurrences of a join point (O) may be +higher or lower than the manifest join arity of the join body (M). E.g., + + -- M > O: + let f x y = x + y -- M = 2 + in if b then f 1 else f 2 -- O = 1 + ==> { Contify for join arity 1 } + join f x = \y -> x + y + in if b then jump f 1 else jump f 2 + + -- M < O + let f = id -- M = 0 + in if ... then f 12 else f 13 -- O = 1 + ==> { Contify for join arity 1, eta-expand f } + join f x = id x + in if b then jump f 12 else jump f 13 + +But for *recursive* let, it is crucial that both arities match up, consider + + letrec f x y = if ... then f x else True + in f 42 + +Here, M=2 but O=1. If we settled for a joinrec arity of 1, the recursive jump +would not happen in a tail context! Contification is invalid here. +So indeed it is crucial to demand that M=O. + +(Side note: Actually, we could be more specific: Let O1 be the join arity of +occurrences from the letrec RHS and O2 the join arity from the let body. Then +we need M=O1 and M<=O2 and could simply eta-expand the RHS to match O2 later. +M=O is the specific case where we don't want to eta-expand. Neither the join +points paper nor GHC does this at the moment.) + +We can capitalise on this observation and conclude that *if* f could become a +joinrec (without eta-expansion), it will have join arity M. +Now, M is just the result of 'joinRhsArity', a rather simple, local analysis. +It is also the join arity inside the 'TailUsageDetails' returned by +'occAnalLamTail', so we can predict join arity without doing any fixed-point +iteration or really doing any deep traversal of let body or RHS at all. +We check for M in the 'adjustTailUsage' call inside 'tagRecBinders'. + +All this is quite apparent if you look at the contification transformation in +Fig. 5 of "Compiling without Continuations" (which does not account for +eta-expansion at all, mind you). The letrec case looks like this + + letrec f = /\as.\xs. L[us] in L'[es] + ... and a bunch of conditions establishing that f only occurs + in app heads of join arity (len as + len xs) inside us and es ... + +The syntactic form `/\as.\xs. L[us]` forces M=O iff `f` occurs in `us`. However, +for non-recursive functions, this is the definition of contification from the +paper: + + let f = /\as.\xs.u in L[es] ... conditions ... + +Note that u could be a lambda itself, as we have seen. No relationship between M +and O to exploit here. + Note [Join points and unfoldings/rules] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Consider @@ -618,8 +687,10 @@ Consider Before j is inlined, we'll have occurrences of j2 in both j's RHS and in its stable unfolding. We want to discover -j2 as a join point. So we must do the adjustRhsUsage thing -on j's RHS. That's why we pass mb_join_arity to calcUnfolding. +j2 as a join point. So 'occAnalUnfolding' returns an unadjusted +'TailUsageDetails', like 'occAnalLamTail'. We adjust the usage details of the +unfolding to the actual join arity using the same 'adjustTailArity' as for +the RHS, see Note [Adjusting right-hand sides]. Same with rules. Suppose we have: @@ -636,14 +707,31 @@ up. So provided the join-point arity of k matches the args of the rule we can allow the tail-call info from the RHS of the rule to propagate. -* Wrinkle for Rec case. In the recursive case we don't know the - join-point arity in advance, when calling occAnalUnfolding and - occAnalRules. (See makeNode.) We don't want to pass Nothing, - because then a recursive joinrec might lose its join-poin-hood - when SpecConstr adds a RULE. So we just make do with the - *current* join-poin-hood, stored in the Id. +* Note that the join arity of the RHS and that of the unfolding or RULE might + mismatch: + + let j x y = j2 (x+x) + {-# INLINE[2] j = \x. g #-} + {-# RULE forall x y z. j x y z = h 17 #-} + in j 1 2 - In the non-recursive case things are simple: see occAnalNonRecBind + So it is crucial that we adjust each TailUsageDetails individually + with the actual join arity 2 here before we combine with `andUDs`. + Here, that means losing tail call info on `g` and `h`. + +* Wrinkle for Rec case: We store one TailUsageDetails in the node Details for + RHS, unfolding and RULE combined. Clearly, if they don't agree on their join + arity, we have to do some adjusting. We choose to adjust to the join arity + of the RHS, because that is likely the join arity that the join point will + have; see Note [Join arity prediction based on joinRhsArity]. + + If the guess is correct, then tail calls in the RHS are preserved; a necessary + condition for the whole binding becoming a joinrec. + The guess can only be incorrect in the 'AcyclicSCC' case when the binding + becomes a non-recursive join point with a different join arity. But then the + eventual call to 'adjustTailUsage' in 'tagRecBinders'/'occAnalRec' will + be with a different join arity and destroy unsound tail call info with + 'markNonTail'. * Wrinkle for RULES. Suppose the example was a bit different: let j :: Int -> Int @@ -669,28 +757,21 @@ propagate. This appears to be very rare in practice. TODO Perhaps we should gather statistics to be sure. -Note [Unfoldings and join points] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -We assume that anything in an unfolding occurs multiple times, since -unfoldings are often copied (that's the whole point!). But we still -need to track tail calls for the purpose of finding join points. - - ------------------------------------------------------------ Note [Adjusting right-hand sides] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ There's a bit of a dance we need to do after analysing a lambda expression or a right-hand side. In particular, we need to - a) call 'markAllInsideLam' *unless* the binding is for a thunk, a one-shot - lambda, or a non-recursive join point; and - b) call 'markAllNonTail' *unless* the binding is for a join point, and - the RHS has the right arity; e.g. + a) call 'markAllNonTail' *unless* the binding is for a join point, and + the TailUsageDetails from the RHS has the right join arity; e.g. join j x y = case ... of A -> j2 p B -> j2 q in j a b Here we want the tail calls to j2 to be tail calls of the whole expression + b) call 'markAllInsideLam' *unless* the binding is for a thunk, a one-shot + lambda, or a non-recursive join point Some examples, with how the free occurrences in e (assumed not to be a value lambda) get marked: @@ -707,26 +788,39 @@ lambda) get marked: There are a few other caveats; most importantly, if we're marking a binding as 'AlwaysTailCalled', it's *going* to be a join point, so we treat it as one so that the effect cascades properly. Consequently, at the time the RHS is -analysed, we won't know what adjustments to make; thus 'occAnalLamOrRhs' must -return the unadjusted 'UsageDetails', to be adjusted by 'adjustRhsUsage' once -join-point-hood has been decided. - -Thus the overall sequence taking place in 'occAnalNonRecBind' and -'occAnalRecBind' is as follows: - - 1. Call 'occAnalLamOrRhs' to find usage information for the RHS. - 2. Call 'tagNonRecBinder' or 'tagRecBinders', which decides whether to make +analysed, we won't know what adjustments to make; thus 'occAnalLamTail' must +return the unadjusted 'TailUsageDetails', to be adjusted by 'adjustTailUsage' +once join-point-hood has been decided and eventual one-shot annotations have +been added through 'markNonRecJoinOneShots'. + +It is not so simple to see that 'occAnalNonRecBind' and 'occAnalRecBind' indeed +perform a similar sequence of steps. Thus, here is an interleaving of events +of both functions, serving as a specification: + + 1. Call 'occAnalLamTail' to find usage information for the RHS. + Recursive case: 'makeNode' + Non-recursive case: 'occAnalNonRecBind' + 2. (Analyse the binding's scope. Done in 'occAnalBind'/`occAnal Let{}`. + Same whether recursive or not.) + 3. Call 'tagNonRecBinder' or 'tagRecBinders', which decides whether to make the binding a join point. - 3. Call 'adjustRhsUsage' accordingly. (Done as part of 'tagRecBinders' when - recursive.) - -(In the recursive case, this logic is spread between 'makeNode' and -'occAnalRec'.) + Cyclic Recursive case: 'mkLoopBreakerNodes' + Acyclic Recursive case: `occAnalRec AcyclicSCC{}` + Non-recursive case: 'occAnalNonRecBind' + 4. Non-recursive join point: Call 'markNonRecJoinOneShots' so that e.g., + FloatOut sees one-shot annotations on lambdas + Acyclic Recursive case: `occAnalRec AcyclicSCC{}` calls 'adjustNonRecRhs' + Non-recursive case: 'occAnalNonRecBind' calls 'adjustNonRecRhs' + 5. Call 'adjustTailUsage' accordingly. + Cyclic Recursive case: 'tagRecBinders' + Acyclic Recursive case: 'adjustNonRecRhs' + Non-recursive case: 'adjustNonRecRhs' -} - data WithUsageDetails a = WithUsageDetails !UsageDetails !a +data WithTailUsageDetails a = WithTailUsageDetails !TailUsageDetails !a + ------------------------------------------------------------------ -- occAnalBind ------------------------------------------------------------------ @@ -750,19 +844,17 @@ occAnalNonRecBind !env lvl imp_rule_edges bndr rhs body_usage | isTyVar bndr -- A type let; we don't gather usage info = WithUsageDetails body_usage [NonRec bndr rhs] - | not (bndr `usedIn` body_usage) -- It's not mentioned - = WithUsageDetails body_usage [] + | not (bndr `usedIn` body_usage) + = WithUsageDetails body_usage [] -- See Note [Dead code] | otherwise -- It's mentioned in the body - = WithUsageDetails (body_usage' `andUDs` rhs_usage) [NonRec final_bndr rhs'] + = WithUsageDetails (body_usage' `andUDs` rhs_usage) [NonRec final_bndr final_rhs] where - (body_usage', tagged_bndr) = tagNonRecBinder lvl body_usage bndr - final_bndr = tagged_bndr `setIdUnfolding` unf' - `setIdSpecialisation` mkRuleInfo rules' - rhs_usage = rhs_uds `andUDs` unf_uds `andUDs` rule_uds + WithUsageDetails body_usage' tagged_bndr = tagNonRecBinder lvl body_usage bndr -- Get the join info from the *new* decision -- See Note [Join points and unfoldings/rules] + -- => join arity O of Note [Join arity prediction based on joinRhsArity] mb_join_arity = willBeJoinId_maybe tagged_bndr is_join_point = isJust mb_join_arity @@ -773,17 +865,28 @@ occAnalNonRecBind !env lvl imp_rule_edges bndr rhs body_usage -- See Note [Sources of one-shot information] rhs_env = env1 { occ_one_shots = argOneShots dmd } - (WithUsageDetails rhs_uds rhs') = occAnalRhs rhs_env NonRecursive mb_join_arity rhs + -- See Note [Join arity prediction based on joinRhsArity] + -- Match join arity O from mb_join_arity with manifest join arity M as + -- returned by of occAnalLamTail. It's totally OK for them to mismatch; + -- hence adjust the UDs from the RHS + WithUsageDetails adj_rhs_uds final_rhs + = adjustNonRecRhs mb_join_arity $ occAnalLamTail rhs_env rhs + rhs_usage = adj_rhs_uds `andUDs` adj_unf_uds `andUDs` adj_rule_uds + final_bndr = tagged_bndr `setIdSpecialisation` mkRuleInfo rules' + `setIdUnfolding` unf2 --------- Unfolding --------- - -- See Note [Unfoldings and join points] + -- See Note [Join points and unfoldings/rules] unf | isId bndr = idUnfolding bndr | otherwise = NoUnfolding - (WithUsageDetails unf_uds unf') = occAnalUnfolding rhs_env NonRecursive mb_join_arity unf + WithTailUsageDetails unf_uds unf1 = occAnalUnfolding rhs_env unf + unf2 = markNonRecUnfoldingOneShots mb_join_arity unf1 + adj_unf_uds = adjustTailArity mb_join_arity unf_uds --------- Rules --------- -- See Note [Rules are extra RHSs] and Note [Rule dependency info] - rules_w_uds = occAnalRules rhs_env mb_join_arity bndr + -- and Note [Join points and unfoldings/rules] + rules_w_uds = occAnalRules rhs_env bndr rules' = map fstOf3 rules_w_uds imp_rule_uds = impRulesScopeUsage (lookupImpRules imp_rule_edges bndr) -- imp_rule_uds: consider @@ -794,8 +897,9 @@ occAnalNonRecBind !env lvl imp_rule_edges bndr rhs body_usage -- that g is (since the RULE might turn g into h), so -- we make g mention h. - rule_uds = foldr add_rule_uds imp_rule_uds rules_w_uds - add_rule_uds (_, l, r) uds = l `andUDs` r `andUDs` uds + adj_rule_uds = foldr add_rule_uds imp_rule_uds rules_w_uds + add_rule_uds (_, l, r) uds + = l `andUDs` adjustTailArity mb_join_arity r `andUDs` uds ---------- occ = idOccInfo tagged_bndr @@ -820,7 +924,7 @@ occAnalRecBind :: OccEnv -> TopLevelFlag -> ImpRuleEdges -> [(Var,CoreExpr)] occAnalRecBind !env lvl imp_rule_edges pairs body_usage = foldr (occAnalRec rhs_env lvl) (WithUsageDetails body_usage []) sccs where - sccs :: [SCC Details] + sccs :: [SCC NodeDetails] sccs = {-# SCC "occAnalBind.scc" #-} stronglyConnCompFromEdgedVerticesUniq nodes @@ -832,48 +936,62 @@ occAnalRecBind !env lvl imp_rule_edges pairs body_usage bndr_set = mkVarSet bndrs rhs_env = env `addInScope` bndrs +adjustNonRecRhs :: Maybe JoinArity -> WithTailUsageDetails CoreExpr -> WithUsageDetails CoreExpr +-- ^ This function concentrates shared logic between occAnalNonRecBind and the +-- AcyclicSCC case of occAnalRec. +-- * It applies 'markNonRecJoinOneShots' to the RHS +-- * and returns the adjusted rhs UsageDetails combined with the body usage +adjustNonRecRhs mb_join_arity (WithTailUsageDetails rhs_tuds rhs) + = WithUsageDetails rhs_uds' rhs' + where + --------- Marking (non-rec) join binders one-shot --------- + !rhs' | Just ja <- mb_join_arity = markNonRecJoinOneShots ja rhs + | otherwise = rhs + --------- Adjusting right-hand side usage --------- + rhs_uds' = adjustTailUsage mb_join_arity rhs' rhs_tuds + +bindersOfSCC :: SCC NodeDetails -> [Var] +bindersOfSCC (AcyclicSCC nd) = [nd_bndr nd] +bindersOfSCC (CyclicSCC ds) = map nd_bndr ds ----------------------------- occAnalRec :: OccEnv -> TopLevelFlag - -> SCC Details + -> SCC NodeDetails -> WithUsageDetails [CoreBind] -> WithUsageDetails [CoreBind] - -- The NonRec case is just like a Let (NonRec ...) above -occAnalRec !_ lvl (AcyclicSCC (ND { nd_bndr = bndr, nd_rhs = rhs - , nd_uds = rhs_uds })) - (WithUsageDetails body_uds binds) - | not (bndr `usedIn` body_uds) - = WithUsageDetails body_uds binds -- See Note [Dead code] +-- Check for Note [Dead code] +-- NB: Only look at body_uds, ignoring uses in the SCC +occAnalRec !_ _ scc (WithUsageDetails body_uds binds) + | not (any (`usedIn` body_uds) (bindersOfSCC scc)) + = WithUsageDetails body_uds binds - | otherwise -- It's mentioned in the body - = WithUsageDetails (body_uds' `andUDs` rhs_uds') - (NonRec tagged_bndr rhs : binds) +-- The NonRec case is just like a Let (NonRec ...) above +occAnalRec !_ lvl + (AcyclicSCC (ND { nd_bndr = bndr, nd_rhs = wtuds })) + (WithUsageDetails body_uds binds) + = WithUsageDetails (body_uds' `andUDs` rhs_uds') (NonRec bndr' rhs' : binds) where - (body_uds', tagged_bndr) = tagNonRecBinder lvl body_uds bndr - rhs_uds' = adjustRhsUsage mb_join_arity rhs rhs_uds + WithUsageDetails body_uds' tagged_bndr = tagNonRecBinder lvl body_uds bndr mb_join_arity = willBeJoinId_maybe tagged_bndr + WithUsageDetails rhs_uds' rhs' = adjustNonRecRhs mb_join_arity wtuds + !unf' = markNonRecUnfoldingOneShots mb_join_arity (idUnfolding tagged_bndr) + !bndr' = tagged_bndr `setIdUnfolding` unf' - -- The Rec case is the interesting one - -- See Note [Recursive bindings: the grand plan] - -- See Note [Loop breaking] +-- The Rec case is the interesting one +-- See Note [Recursive bindings: the grand plan] +-- See Note [Loop breaking] occAnalRec env lvl (CyclicSCC details_s) (WithUsageDetails body_uds binds) - | not (any (`usedIn` body_uds) bndrs) -- NB: look at body_uds, not total_uds - = WithUsageDetails body_uds binds -- See Note [Dead code] - - | otherwise -- At this point we always build a single Rec = -- pprTrace "occAnalRec" (ppr loop_breaker_nodes) WithUsageDetails final_uds (Rec pairs : binds) - where - bndrs = map nd_bndr details_s all_simple = all nd_simple details_s ------------------------------ -- Make the nodes for the loop-breaker analysis -- See Note [Choosing loop breakers] for loop_breaker_nodes final_uds :: UsageDetails - loop_breaker_nodes :: [LetrecNode] + loop_breaker_nodes :: [LoopBreakerNode] (WithUsageDetails final_uds loop_breaker_nodes) = mkLoopBreakerNodes env lvl body_uds details_s ------------------------------ @@ -1102,7 +1220,7 @@ type Binding = (Id,CoreExpr) loopBreakNodes :: Int -> VarSet -- Binders whose dependencies may be "missing" -- See Note [Weak loop breakers] - -> [LetrecNode] + -> [LoopBreakerNode] -> [Binding] -- Append these to the end -> [Binding] @@ -1121,7 +1239,7 @@ loopBreakNodes depth weak_fvs nodes binds CyclicSCC nodes -> reOrderNodes depth weak_fvs nodes binds ---------------------------------- -reOrderNodes :: Int -> VarSet -> [LetrecNode] -> [Binding] -> [Binding] +reOrderNodes :: Int -> VarSet -> [LoopBreakerNode] -> [Binding] -> [Binding] -- Choose a loop breaker, mark it no-inline, -- and call loopBreakNodes on the rest reOrderNodes _ _ [] _ = panic "reOrderNodes" @@ -1133,7 +1251,7 @@ reOrderNodes depth weak_fvs (node : nodes) binds (map (nodeBinding mk_loop_breaker) chosen_nodes ++ binds) where (chosen_nodes, unchosen) = chooseLoopBreaker approximate_lb - (nd_score (node_payload node)) + (snd_score (node_payload node)) [node] [] nodes approximate_lb = depth >= 2 @@ -1142,8 +1260,8 @@ reOrderNodes depth weak_fvs (node : nodes) binds -- After two iterations (d=0, d=1) give up -- and approximate, returning to d=0 -nodeBinding :: (Id -> Id) -> LetrecNode -> Binding -nodeBinding set_id_occ (node_payload -> ND { nd_bndr = bndr, nd_rhs = rhs}) +nodeBinding :: (Id -> Id) -> LoopBreakerNode -> Binding +nodeBinding set_id_occ (node_payload -> SND { snd_bndr = bndr, snd_rhs = rhs}) = (set_id_occ bndr, rhs) mk_loop_breaker :: Id -> Id @@ -1163,13 +1281,13 @@ mk_non_loop_breaker weak_fvs bndr tail_info = tailCallInfo (idOccInfo bndr) ---------------------------------- -chooseLoopBreaker :: Bool -- True <=> Too many iterations, - -- so approximate - -> NodeScore -- Best score so far - -> [LetrecNode] -- Nodes with this score - -> [LetrecNode] -- Nodes with higher scores - -> [LetrecNode] -- Unprocessed nodes - -> ([LetrecNode], [LetrecNode]) +chooseLoopBreaker :: Bool -- True <=> Too many iterations, + -- so approximate + -> NodeScore -- Best score so far + -> [LoopBreakerNode] -- Nodes with this score + -> [LoopBreakerNode] -- Nodes with higher scores + -> [LoopBreakerNode] -- Unprocessed nodes + -> ([LoopBreakerNode], [LoopBreakerNode]) -- This loop looks for the bind with the lowest score -- to pick as the loop breaker. The rest accumulate in chooseLoopBreaker _ _ loop_nodes acc [] @@ -1189,7 +1307,7 @@ chooseLoopBreaker approx_lb loop_sc loop_nodes acc (node : nodes) | otherwise -- Worse score so don't pick it = chooseLoopBreaker approx_lb loop_sc loop_nodes (node : acc) nodes where - sc = nd_score (node_payload node) + sc = snd_score (node_payload node) {- Note [Complexity of loop breaking] @@ -1322,16 +1440,21 @@ ToDo: try using the occurrence info for the inline'd binder. ************************************************************************ -} -type LetrecNode = Node Unique Details -- Node comes from Digraph - -- The Unique key is gotten from the Id -data Details - = ND { nd_bndr :: Id -- Binder +-- | Digraph node as constructed by 'makeNode' and consumed by 'occAnalRec'. +-- The Unique key is gotten from the Id. +type LetrecNode = Node Unique NodeDetails - , nd_rhs :: CoreExpr -- RHS, already occ-analysed +-- | Node details as consumed by 'occAnalRec'. +data NodeDetails + = ND { nd_bndr :: Id -- Binder - , nd_uds :: UsageDetails -- Usage from RHS, and RULES, and stable unfoldings - -- ignoring phase (ie assuming all are active) - -- See Note [Forming Rec groups] + , nd_rhs :: !(WithTailUsageDetails CoreExpr) + -- ^ RHS, already occ-analysed + -- With TailUsageDetails from RHS, and RULES, and stable unfoldings, + -- ignoring phase (ie assuming all are active). + -- NB: Unadjusted TailUsageDetails, as if this Node becomes a + -- non-recursive join point! + -- See Note [TailUsageDetails when forming Rec groups] , nd_inl :: IdSet -- Free variables of the stable unfolding and the RHS -- but excluding any RULES @@ -1348,18 +1471,33 @@ data Details , nd_active_rule_fvs :: IdSet -- Variables bound in this Rec group that are free -- in the RHS of an active rule for this bndr -- See Note [Rules and loop breakers] - - , nd_score :: NodeScore } -instance Outputable Details where +instance Outputable NodeDetails where ppr nd = text "ND" <> braces (sep [ text "bndr =" <+> ppr (nd_bndr nd) - , text "uds =" <+> ppr (nd_uds nd) + , text "uds =" <+> ppr uds , text "inl =" <+> ppr (nd_inl nd) , text "simple =" <+> ppr (nd_simple nd) , text "active_rule_fvs =" <+> ppr (nd_active_rule_fvs nd) - , text "score =" <+> ppr (nd_score nd) + ]) + where WithTailUsageDetails uds _ = nd_rhs nd + +-- | Digraph with simplified and completely occurrence analysed +-- 'SimpleNodeDetails', retaining just the info we need for breaking loops. +type LoopBreakerNode = Node Unique SimpleNodeDetails + +-- | Condensed variant of 'NodeDetails' needed during loop breaking. +data SimpleNodeDetails + = SND { snd_bndr :: IdWithOccInfo -- OccInfo accurate + , snd_rhs :: CoreExpr -- properly occur-analysed + , snd_score :: NodeScore + } + +instance Outputable SimpleNodeDetails where + ppr nd = text "SND" <> braces + (sep [ text "bndr =" <+> ppr (snd_bndr nd) + , text "score =" <+> ppr (snd_score nd) ]) -- The NodeScore is compared lexicographically; @@ -1387,52 +1525,59 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs) -- explained in Note [Deterministic SCC] in GHC.Data.Graph.Directed. where details = ND { nd_bndr = bndr' - , nd_rhs = rhs' - , nd_uds = scope_uds + , nd_rhs = WithTailUsageDetails scope_uds rhs' , nd_inl = inl_fvs , nd_simple = null rules_w_uds && null imp_rule_info , nd_weak_fvs = weak_fvs - , nd_active_rule_fvs = active_rule_fvs - , nd_score = pprPanic "makeNodeDetails" (ppr bndr) } + , nd_active_rule_fvs = active_rule_fvs } bndr' = bndr `setIdUnfolding` unf' `setIdSpecialisation` mkRuleInfo rules' - inl_uds = rhs_uds `andUDs` unf_uds - scope_uds = inl_uds `andUDs` rule_uds + -- NB: Both adj_unf_uds and adj_rule_uds have been adjusted to match the + -- JoinArity rhs_ja of unadj_rhs_uds. + unadj_inl_uds = unadj_rhs_uds `andUDs` adj_unf_uds + unadj_scope_uds = unadj_inl_uds `andUDs` adj_rule_uds + scope_uds = TUD rhs_ja unadj_scope_uds -- Note [Rules are extra RHSs] -- Note [Rule dependency info] - scope_fvs = udFreeVars bndr_set scope_uds + scope_fvs = udFreeVars bndr_set unadj_scope_uds -- scope_fvs: all occurrences from this binder: RHS, unfolding, -- and RULES, both LHS and RHS thereof, active or inactive - inl_fvs = udFreeVars bndr_set inl_uds + inl_fvs = udFreeVars bndr_set unadj_inl_uds -- inl_fvs: vars that would become free if the function was inlined. -- We conservatively approximate that by thefree vars from the RHS -- and the unfolding together. -- See Note [inl_fvs] - mb_join_arity = isJoinId_maybe bndr - -- Get join point info from the *current* decision - -- We don't know what the new decision will be! - -- Using the old decision at least allows us to - -- preserve existing join point, even RULEs are added - -- See Note [Join points and unfoldings/rules] --------- Right hand side --------- -- Constructing the edges for the main Rec computation -- See Note [Forming Rec groups] - -- Do not use occAnalRhs because we don't yet know the final - -- answer for mb_join_arity; instead, do the occAnalLam call from - -- occAnalRhs, and postpone adjustRhsUsage until occAnalRec - rhs_env = rhsCtxt env - (WithUsageDetails rhs_uds rhs') = occAnalLam rhs_env rhs + -- and Note [TailUsageDetails when forming Rec groups] + -- Compared to occAnalNonRecBind, we can't yet adjust the RHS because + -- (a) we don't yet know the final joinpointhood. It might not become a + -- join point after all! + -- (b) we don't even know whether it stays a recursive RHS after the SCC + -- analysis we are about to seed! So we can't markAllInsideLam in + -- advance, because if it ends up as a non-recursive join point we'll + -- consider it as one-shot and don't need to markAllInsideLam. + -- Instead, do the occAnalLamTail call here and postpone adjustTailUsage + -- until occAnalRec. In effect, we pretend that the RHS becomes a + -- non-recursive join point and fix up later with adjustTailUsage. + rhs_env = rhsCtxt env + WithTailUsageDetails (TUD rhs_ja unadj_rhs_uds) rhs' = occAnalLamTail rhs_env rhs + -- corresponding call to adjustTailUsage in occAnalRec and tagRecBinders --------- Unfolding --------- - -- See Note [Unfoldings and join points] + -- See Note [Join points and unfoldings/rules] unf = realIdUnfolding bndr -- realIdUnfolding: Ignore loop-breaker-ness -- here because that is what we are setting! - (WithUsageDetails unf_uds unf') = occAnalUnfolding rhs_env Recursive mb_join_arity unf + WithTailUsageDetails unf_tuds unf' = occAnalUnfolding rhs_env unf + adj_unf_uds = adjustTailArity (Just rhs_ja) unf_tuds + -- `rhs_ja` is `joinRhsArity rhs` and is the prediction for source M + -- of Note [Join arity prediction based on joinRhsArity] --------- IMP-RULES -------- is_active = occ_rule_act env :: Activation -> Bool @@ -1441,11 +1586,15 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs) imp_rule_fvs = impRulesActiveFvs is_active bndr_set imp_rule_info --------- All rules -------- + -- See Note [Join points and unfoldings/rules] + -- `rhs_ja` is `joinRhsArity rhs'` and is the prediction for source M + -- of Note [Join arity prediction based on joinRhsArity] rules_w_uds :: [(CoreRule, UsageDetails, UsageDetails)] - rules_w_uds = occAnalRules rhs_env mb_join_arity bndr + rules_w_uds = [ (r,l,adjustTailArity (Just rhs_ja) rhs_tuds) + | (r,l,rhs_tuds) <- occAnalRules rhs_env bndr ] rules' = map fstOf3 rules_w_uds - rule_uds = foldr add_rule_uds imp_rule_uds rules_w_uds + adj_rule_uds = foldr add_rule_uds imp_rule_uds rules_w_uds add_rule_uds (_, l, r) uds = l `andUDs` r `andUDs` uds -------- active_rule_fvs ------------ @@ -1463,8 +1612,8 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs) mkLoopBreakerNodes :: OccEnv -> TopLevelFlag -> UsageDetails -- for BODY of let - -> [Details] - -> WithUsageDetails [LetrecNode] -- adjusted + -> [NodeDetails] + -> WithUsageDetails [LoopBreakerNode] -- with OccInfo up-to-date -- See Note [Choosing loop breakers] -- This function primarily creates the Nodes for the -- loop-breaker SCC analysis. More specifically: @@ -1477,10 +1626,10 @@ mkLoopBreakerNodes :: OccEnv -> TopLevelFlag mkLoopBreakerNodes !env lvl body_uds details_s = WithUsageDetails final_uds (zipWithEqual "mkLoopBreakerNodes" mk_lb_node details_s bndrs') where - (final_uds, bndrs') = tagRecBinders lvl body_uds details_s + WithUsageDetails final_uds bndrs' = tagRecBinders lvl body_uds details_s mk_lb_node nd@(ND { nd_bndr = old_bndr, nd_inl = inl_fvs }) new_bndr - = DigraphNode { node_payload = new_nd + = DigraphNode { node_payload = simple_nd , node_key = varUnique old_bndr , node_dependencies = nonDetKeysUniqSet lb_deps } -- It's OK to use nonDetKeysUniqSet here as @@ -1488,7 +1637,8 @@ mkLoopBreakerNodes !env lvl body_uds details_s -- in nondeterministic order as explained in -- Note [Deterministic SCC] in GHC.Data.Graph.Directed. where - new_nd = nd { nd_bndr = new_bndr, nd_score = score } + WithTailUsageDetails _ rhs = nd_rhs nd + simple_nd = SND { snd_bndr = new_bndr, snd_rhs = rhs, snd_score = score } score = nodeScore env new_bndr lb_deps nd lb_deps = extendFvs_ rule_fv_env inl_fvs -- See Note [Loop breaker dependencies] @@ -1524,10 +1674,10 @@ group { f1 = e1; ...; fn = en } are: nodeScore :: OccEnv -> Id -- Binder with new occ-info -> VarSet -- Loop-breaker dependencies - -> Details + -> NodeDetails -> NodeScore nodeScore !env new_bndr lb_deps - (ND { nd_bndr = old_bndr, nd_rhs = bind_rhs }) + (ND { nd_bndr = old_bndr, nd_rhs = WithTailUsageDetails _ bind_rhs }) | not (isId old_bndr) -- A type or coercion variable is never a loop breaker = (100, 0, False) @@ -1748,7 +1898,7 @@ lambda and casts, e.g. * Occurrence analyser: we just mark each binder in the lambda-group (here: x,y,z) with its occurrence info in the *body* of the - lambda-group. See occAnalLam. + lambda-group. See occAnalLamTail. * Simplifier. The simplifier is careful when partially applying lambda-groups. See the call to zapLambdaBndrs in @@ -1804,25 +1954,31 @@ zapLambdaBndrs fun arg_count zap_bndr b | isTyVar b = b | otherwise = zapLamIdInfo b -occAnalLam :: OccEnv -> CoreExpr -> (WithUsageDetails CoreExpr) --- See Note [Occurrence analysis for lambda binders] +occAnalLamTail :: OccEnv -> CoreExpr -> WithTailUsageDetails CoreExpr +-- ^ See Note [Occurrence analysis for lambda binders]. -- It does the following: -- * Sets one-shot info on the lambda binder from the OccEnv, and -- removes that one-shot info from the OccEnv -- * Sets the OccEnv to OccVanilla when going under a value lambda -- * Tags each lambda with its occurrence information -- * Walks through casts +-- * Package up the analysed lambda with its manifest join arity +-- -- This function does /not/ do -- markAllInsideLam or -- markAllNonTail --- The caller does that, either in occAnal (Lam {}), or in adjustRhsUsage +-- The caller does that, via adjustTailUsage (mostly calls go through +-- adjustNonRecRhs). Every call to occAnalLamTail must ultimately call +-- adjustTailUsage to discharge the assumed join arity. +-- +-- In effect, the analysis result is for a non-recursive join point with +-- manifest arity and adjustTailUsage does the fixup. -- See Note [Adjusting right-hand sides] - -occAnalLam env (Lam bndr expr) +occAnalLamTail env (Lam bndr expr) | isTyVar bndr - = let env1 = addOneInScope env bndr - WithUsageDetails usage expr' = occAnalLam env1 expr - in WithUsageDetails usage (Lam bndr expr') + , let env1 = addOneInScope env bndr + , WithTailUsageDetails (TUD ja usage) expr' <- occAnalLamTail env1 expr + = WithTailUsageDetails (TUD (ja+1) usage) (Lam bndr expr') -- Important: Keep the 'env' unchanged so that with a RHS like -- \(@ x) -> K @x (f @x) -- we'll see that (K @x (f @x)) is in a OccRhs, and hence refrain @@ -1840,14 +1996,14 @@ occAnalLam env (Lam bndr expr) env1 = env { occ_encl = OccVanilla, occ_one_shots = env_one_shots' } env2 = addOneInScope env1 bndr - (WithUsageDetails usage expr') = occAnalLam env2 expr + WithTailUsageDetails (TUD ja usage) expr' = occAnalLamTail env2 expr (usage', bndr2) = tagLamBinder usage bndr1 - in WithUsageDetails usage' (Lam bndr2 expr') + in WithTailUsageDetails (TUD (ja+1) usage') (Lam bndr2 expr') -- For casts, keep going in the same lambda-group -- See Note [Occurrence analysis for lambda binders] -occAnalLam env (Cast expr co) - = let (WithUsageDetails usage expr') = occAnalLam env expr +occAnalLamTail env (Cast expr co) + = let WithTailUsageDetails (TUD ja usage) expr' = occAnalLamTail env expr -- usage1: see Note [Gather occurrences of coercion variables] usage1 = addManyOccs usage (coVarsOfCo co) @@ -1857,15 +2013,16 @@ occAnalLam env (Cast expr co) _ -> usage1 -- usage3: you might think this was not necessary, because of - -- the markAllNonTail in adjustRhsUsage; but not so! For a - -- join point, adjustRhsUsage doesn't do this; yet if there is + -- the markAllNonTail in adjustTailUsage; but not so! For a + -- join point, adjustTailUsage doesn't do this; yet if there is -- a cast, we must! Also: why markAllNonTail? See -- GHC.Core.Lint: Note Note [Join points and casts] usage3 = markAllNonTail usage2 - in WithUsageDetails usage3 (Cast expr' co) + in WithTailUsageDetails (TUD ja usage3) (Cast expr' co) -occAnalLam env expr = occAnal env expr +occAnalLamTail env expr = case occAnal env expr of + WithUsageDetails usage expr' -> WithTailUsageDetails (TUD 0 usage) expr' {- Note [Occ-anal and cast worker/wrapper] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1885,8 +2042,8 @@ RHS. So it'll get a Many occ-info. (Maybe Cast w/w should create a stable unfolding, which would obviate this Note; but that seems a bit of a heavyweight solution.) -We only need to this in occAnalLam, not occAnal, because the top leve -of a right hand side is handled by occAnalLam. +We only need to this in occAnalLamTail, not occAnal, because the top leve +of a right hand side is handled by occAnalLamTail. -} @@ -1896,57 +2053,23 @@ of a right hand side is handled by occAnalLam. * * ********************************************************************* -} -occAnalRhs :: OccEnv -> RecFlag -> Maybe JoinArity - -> CoreExpr -- RHS - -> WithUsageDetails CoreExpr -occAnalRhs !env is_rec mb_join_arity rhs - = let (WithUsageDetails usage rhs1) = occAnalLam env rhs - -- We call occAnalLam here, not occAnalExpr, so that it doesn't - -- do the markAllInsideLam and markNonTailCall stuff before - -- we've had a chance to help with join points; that comes next - rhs2 = markJoinOneShots is_rec mb_join_arity rhs1 - rhs_usage = adjustRhsUsage mb_join_arity rhs2 usage - in WithUsageDetails rhs_usage rhs2 - - - -markJoinOneShots :: RecFlag -> Maybe JoinArity -> CoreExpr -> CoreExpr --- For a /non-recursive/ join point we can mark all --- its join-lambda as one-shot; and it's a good idea to do so -markJoinOneShots NonRecursive (Just join_arity) rhs - = go join_arity rhs - where - go 0 rhs = rhs - go n (Lam b rhs) = Lam (if isId b then setOneShotLambda b else b) - (go (n-1) rhs) - go _ rhs = rhs -- Not enough lambdas. This can legitimately happen. - -- e.g. let j = case ... in j True - -- This will become an arity-1 join point after the - -- simplifier has eta-expanded it; but it may not have - -- enough lambdas /yet/. (Lint checks that JoinIds do - -- have enough lambdas.) -markJoinOneShots _ _ rhs - = rhs - occAnalUnfolding :: OccEnv - -> RecFlag - -> Maybe JoinArity -- See Note [Join points and unfoldings/rules] -> Unfolding - -> WithUsageDetails Unfolding + -> WithTailUsageDetails Unfolding -- Occurrence-analyse a stable unfolding; --- discard a non-stable one altogether. -occAnalUnfolding !env is_rec mb_join_arity unf +-- discard a non-stable one altogether and return empty usage details. +occAnalUnfolding !env unf = case unf of unf@(CoreUnfolding { uf_tmpl = rhs, uf_src = src }) | isStableSource src -> let - (WithUsageDetails usage rhs') = occAnalRhs env is_rec mb_join_arity rhs + WithTailUsageDetails (TUD rhs_ja usage) rhs' = occAnalLamTail env rhs unf' | noBinderSwaps env = unf -- Note [Unfoldings and rules] | otherwise = unf { uf_tmpl = rhs' } - in WithUsageDetails (markAllMany usage) unf' + in WithTailUsageDetails (TUD rhs_ja (markAllMany usage)) unf' -- markAllMany: see Note [Occurrences in stable unfoldings] - | otherwise -> WithUsageDetails emptyDetails unf + | otherwise -> WithTailUsageDetails (TUD 0 emptyDetails) unf -- For non-Stable unfoldings we leave them undisturbed, but -- don't count their usage because the simplifier will discard them. -- We leave them undisturbed because nodeScore uses their size info @@ -1955,29 +2078,26 @@ occAnalUnfolding !env is_rec mb_join_arity unf -- scope remain in scope; there is no cloning etc. unf@(DFunUnfolding { df_bndrs = bndrs, df_args = args }) - -> WithUsageDetails final_usage (unf { df_args = args' }) + -> WithTailUsageDetails (TUD 0 final_usage) (unf { df_args = args' }) where env' = env `addInScope` bndrs (WithUsageDetails usage args') = occAnalList env' args - final_usage = markAllManyNonTail (delDetailsList usage bndrs) - `addLamCoVarOccs` bndrs - `delDetailsList` bndrs + final_usage = usage `addLamCoVarOccs` bndrs `delDetailsList` bndrs -- delDetailsList; no need to use tagLamBinders because we -- never inline DFuns so the occ-info on binders doesn't matter - unf -> WithUsageDetails emptyDetails unf + unf -> WithTailUsageDetails (TUD 0 emptyDetails) unf occAnalRules :: OccEnv - -> Maybe JoinArity -- See Note [Join points and unfoldings/rules] -> Id -- Get rules from here -> [(CoreRule, -- Each (non-built-in) rule UsageDetails, -- Usage details for LHS - UsageDetails)] -- Usage details for RHS -occAnalRules !env mb_join_arity bndr + TailUsageDetails)] -- Usage details for RHS +occAnalRules !env bndr = map occ_anal_rule (idCoreRules bndr) where occ_anal_rule rule@(Rule { ru_bndrs = bndrs, ru_args = args, ru_rhs = rhs }) - = (rule', lhs_uds', rhs_uds') + = (rule', lhs_uds', TUD rhs_ja rhs_uds') where env' = env `addInScope` bndrs rule' | noBinderSwaps env = rule -- Note [Unfoldings and rules] @@ -1990,14 +2110,11 @@ occAnalRules !env mb_join_arity bndr (WithUsageDetails rhs_uds rhs') = occAnal env' rhs -- Note [Rules are extra RHSs] -- Note [Rule dependency info] - rhs_uds' = markAllNonTailIf (not exact_join) $ - markAllMany $ + rhs_uds' = markAllMany $ rhs_uds `delDetailsList` bndrs + rhs_ja = length args -- See Note [Join points and unfoldings/rules] - exact_join = exactJoin mb_join_arity args - -- See Note [Join points and unfoldings/rules] - - occ_anal_rule other_rule = (other_rule, emptyDetails, emptyDetails) + occ_anal_rule other_rule = (other_rule, emptyDetails, TUD 0 emptyDetails) {- Note [Join point RHSs] ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -2032,6 +2149,8 @@ Another way to think about it: if we inlined g as-is into multiple call sites, now there's be multiple calls to f. Bottom line: treat all occurrences in a stable unfolding as "Many". +We still leave tail call information intact, though, as to not spoil +potential join points. Note [Unfoldings and rules] ~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -2209,10 +2328,7 @@ occAnal env app@(App _ _) = occAnalApp env (collectArgsTicks tickishFloatable app) occAnal env expr@(Lam {}) - = let (WithUsageDetails usage expr') = occAnalLam env expr - final_usage = markAllInsideLamIf (not (isOneShotFun expr')) $ - markAllNonTail usage - in WithUsageDetails final_usage expr' + = adjustNonRecRhs Nothing $ occAnalLamTail env expr -- mb_join_arity == Nothing <=> markAllManyNonTail occAnal env (Case scrut bndr ty alts) = let @@ -2287,7 +2403,7 @@ occAnalApp !env (Var fun, args, ticks) -- This caused #18296 | fun `hasKey` runRWKey , [t1, t2, arg] <- args - , let (WithUsageDetails usage arg') = occAnalRhs env NonRecursive (Just 1) arg + , WithUsageDetails usage arg' <- adjustNonRecRhs (Just 1) $ occAnalLamTail env arg = WithUsageDetails usage (mkTicks ticks $ mkApps (Var fun) [t1, t2, arg']) occAnalApp env (Var fun_id, args, ticks) @@ -2872,7 +2988,6 @@ lookupBndrSwap env@(OccEnv { occ_bs_env = bs_env }) bndr case lookupBndrSwap env bndr1 of (fun, fun_id) -> (mkCastMCo fun mco, fun_id) } - {- Historical note [Proxy let-bindings] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ We used to do the binder-swap transformation by introducing @@ -2998,6 +3113,19 @@ data UsageDetails instance Outputable UsageDetails where ppr ud = ppr (ud_env (flattenUsageDetails ud)) +-- | Captures the result of applying 'occAnalLamTail' to a function `\xyz.body`. +-- The TailUsageDetails records +-- * the number of lambdas (including type lambdas: a JoinArity) +-- * UsageDetails for the `body`, unadjusted by `adjustTailUsage`. +-- If the binding turns out to be a join point with the indicated join +-- arity, this unadjusted usage details is just what we need; otherwise we +-- need to discard tail calls. That's what `adjustTailUsage` does. +data TailUsageDetails = TUD !JoinArity !UsageDetails + +instance Outputable TailUsageDetails where + ppr (TUD ja uds) = lambda <> ppr ja <> ppr uds + + ------------------- -- UsageDetails API @@ -3139,24 +3267,49 @@ flattenUsageDetails ud@(UD { ud_env = env }) ------------------- -- See Note [Adjusting right-hand sides] -adjustRhsUsage :: Maybe JoinArity - -> CoreExpr -- Rhs, AFTER occ anal - -> UsageDetails -- From body of lambda +adjustTailUsage :: Maybe JoinArity + -> CoreExpr -- Rhs, AFTER occAnalLamTail + -> TailUsageDetails -- From body of lambda -> UsageDetails -adjustRhsUsage mb_join_arity rhs usage +adjustTailUsage mb_join_arity rhs (TUD rhs_ja usage) = -- c.f. occAnal (Lam {}) markAllInsideLamIf (not one_shot) $ markAllNonTailIf (not exact_join) $ usage where one_shot = isOneShotFun rhs - exact_join = exactJoin mb_join_arity bndrs - (bndrs,_) = collectBinders rhs + exact_join = mb_join_arity == Just rhs_ja + +adjustTailArity :: Maybe JoinArity -> TailUsageDetails -> UsageDetails +adjustTailArity mb_rhs_ja (TUD ud_ja usage) = + markAllNonTailIf (mb_rhs_ja /= Just ud_ja) usage + +markNonRecJoinOneShots :: JoinArity -> CoreExpr -> CoreExpr +-- For a /non-recursive/ join point we can mark all +-- its join-lambda as one-shot; and it's a good idea to do so +markNonRecJoinOneShots join_arity rhs + = go join_arity rhs + where + go 0 rhs = rhs + go n (Lam b rhs) = Lam (if isId b then setOneShotLambda b else b) + (go (n-1) rhs) + go _ rhs = rhs -- Not enough lambdas. This can legitimately happen. + -- e.g. let j = case ... in j True + -- This will become an arity-1 join point after the + -- simplifier has eta-expanded it; but it may not have + -- enough lambdas /yet/. (Lint checks that JoinIds do + -- have enough lambdas.) -exactJoin :: Maybe JoinArity -> [a] -> Bool -exactJoin Nothing _ = False -exactJoin (Just join_arity) args = args `lengthIs` join_arity - -- Remember join_arity includes type binders +markNonRecUnfoldingOneShots :: Maybe JoinArity -> Unfolding -> Unfolding +-- ^ Apply 'markNonRecJoinOneShots' to a stable unfolding +markNonRecUnfoldingOneShots mb_join_arity unf + | Just ja <- mb_join_arity + , CoreUnfolding{uf_src=src,uf_tmpl=tmpl} <- unf + , isStableSource src + , let !tmpl' = markNonRecJoinOneShots ja tmpl + = unf{uf_tmpl=tmpl'} + | otherwise + = unf type IdWithOccInfo = Id @@ -3192,8 +3345,8 @@ tagLamBinder usage bndr tagNonRecBinder :: TopLevelFlag -- At top level? -> UsageDetails -- Of scope -> CoreBndr -- Binder - -> (UsageDetails, -- Details with binder removed - IdWithOccInfo) -- Tagged binder + -> WithUsageDetails -- Details with binder removed + IdWithOccInfo -- Tagged binder tagNonRecBinder lvl usage binder = let @@ -3205,37 +3358,34 @@ tagNonRecBinder lvl usage binder binder' = setBinderOcc occ' binder usage' = usage `delDetails` binder in - usage' `seq` (usage', binder') + WithUsageDetails usage' binder' tagRecBinders :: TopLevelFlag -- At top level? -> UsageDetails -- Of body of let ONLY - -> [Details] - -> (UsageDetails, -- Adjusted details for whole scope, + -> [NodeDetails] + -> WithUsageDetails -- Adjusted details for whole scope, -- with binders removed - [IdWithOccInfo]) -- Tagged binders + [IdWithOccInfo] -- Tagged binders -- Substantially more complicated than non-recursive case. Need to adjust RHS -- details *before* tagging binders (because the tags depend on the RHSes). tagRecBinders lvl body_uds details_s = let bndrs = map nd_bndr details_s - rhs_udss = map nd_uds details_s - - -- 1. Determine join-point-hood of whole group, as determined by - -- the *unadjusted* usage details - unadj_uds = foldr andUDs body_uds rhs_udss - -- This is only used in `mb_join_arity`, to adjust each `Details` in `details_s`, thus, - -- when `bndrs` is non-empty. So, we only write `maybe False` as `decideJoinPointHood` - -- takes a `NonEmpty CoreBndr`; the default value `False` won't affect program behavior. - will_be_joins = maybe False (decideJoinPointHood lvl unadj_uds) (nonEmpty bndrs) + -- 1. See Note [Join arity prediction based on joinRhsArity] + -- Determine possible join-point-hood of whole group, by testing for + -- manifest join arity M. + -- This (re-)asserts that makeNode had made tuds for that same arity M! + unadj_uds = foldr (andUDs . test_manifest_arity) body_uds details_s + test_manifest_arity ND{nd_rhs=WithTailUsageDetails tuds rhs} + = adjustTailArity (Just (joinRhsArity rhs)) tuds - -- 2. Adjust usage details of each RHS, taking into account the - -- join-point-hood decision - rhs_udss' = [ adjustRhsUsage (mb_join_arity bndr) rhs rhs_uds - | ND { nd_bndr = bndr, nd_uds = rhs_uds - , nd_rhs = rhs } <- details_s ] + bndr_ne = expectNonEmpty "List of binders is never empty" bndrs + will_be_joins = decideJoinPointHood lvl unadj_uds bndr_ne mb_join_arity :: Id -> Maybe JoinArity + -- mb_join_arity: See Note [Join arity prediction based on joinRhsArity] + -- This is the source O mb_join_arity bndr -- Can't use willBeJoinId_maybe here because we haven't tagged -- the binder yet (the tag depends on these adjustments!) @@ -3247,6 +3397,12 @@ tagRecBinders lvl body_uds details_s = assert (not will_be_joins) -- Should be AlwaysTailCalled if Nothing -- we are making join points! + -- 2. Adjust usage details of each RHS, taking into account the + -- join-point-hood decision + rhs_udss' = [ adjustTailUsage (mb_join_arity bndr) rhs rhs_tuds -- matching occAnalLamTail in makeNode + | ND { nd_bndr = bndr, nd_rhs = WithTailUsageDetails rhs_tuds rhs } + <- details_s ] + -- 3. Compute final usage details from adjusted RHS details adj_uds = foldr andUDs body_uds rhs_udss' @@ -3257,7 +3413,7 @@ tagRecBinders lvl body_uds details_s -- 5. Drop the binders from the adjusted details and return usage' = adj_uds `delDetailsList` bndrs in - (usage', bndrs') + WithUsageDetails usage' bndrs' setBinderOcc :: OccInfo -> CoreBndr -> CoreBndr setBinderOcc occ_info bndr @@ -3271,12 +3427,13 @@ setBinderOcc occ_info bndr | otherwise = setIdOccInfo bndr occ_info --- | Decide whether some bindings should be made into join points or not. +-- | Decide whether some bindings should be made into join points or not, based +-- on its occurrences. This is -- Returns `False` if they can't be join points. Note that it's an -- all-or-nothing decision, as if multiple binders are given, they're -- assumed to be mutually recursive. -- --- It must, however, be a final decision. If we say "True" for 'f', +-- It must, however, be a final decision. If we say `True` for 'f', -- and then subsequently decide /not/ make 'f' into a join point, then -- the decision about another binding 'g' might be invalidated if (say) -- 'f' tail-calls 'g'. diff --git a/compiler/GHC/Data/Graph/Directed.hs b/compiler/GHC/Data/Graph/Directed.hs index 1f4202038e..915180b9e9 100644 --- a/compiler/GHC/Data/Graph/Directed.hs +++ b/compiler/GHC/Data/Graph/Directed.hs @@ -4,6 +4,7 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE DeriveFunctor #-} module GHC.Data.Graph.Directed ( Graph, graphFromEdgedVerticesOrd, graphFromEdgedVerticesUniq, @@ -108,7 +109,7 @@ data Node key payload = DigraphNode { node_payload :: payload, -- ^ User data node_key :: key, -- ^ User defined node id node_dependencies :: [key] -- ^ Dependencies/successors of the node - } + } deriving Functor instance (Outputable a, Outputable b) => Outputable (Node a b) where diff --git a/compiler/GHC/Utils/Misc.hs b/compiler/GHC/Utils/Misc.hs index a115c61336..09111ccc47 100644 --- a/compiler/GHC/Utils/Misc.hs +++ b/compiler/GHC/Utils/Misc.hs @@ -35,7 +35,7 @@ module GHC.Utils.Misc ( equalLength, compareLength, leLength, ltLength, isSingleton, only, expectOnly, GHC.Utils.Misc.singleton, - notNull, snocView, + notNull, expectNonEmpty, snocView, chunkList, @@ -481,7 +481,6 @@ expectOnly _ (a:_) = a #endif expectOnly msg _ = panic ("expectOnly: " ++ msg) - -- | Split a list into chunks of /n/ elements chunkList :: Int -> [a] -> [[a]] chunkList _ [] = [] @@ -500,6 +499,16 @@ changeLast [] _ = panic "changeLast" changeLast [_] x = [x] changeLast (x:xs) x' = x : changeLast xs x' +-- | Like @expectJust msg . nonEmpty@; a better alternative to 'NE.fromList'. +expectNonEmpty :: HasCallStack => String -> [a] -> NonEmpty a +{-# INLINE expectNonEmpty #-} +expectNonEmpty _ (x:xs) = x:|xs +expectNonEmpty msg [] = expectNonEmptyPanic msg + +expectNonEmptyPanic :: String -> a +expectNonEmptyPanic msg = panic ("expectNonEmpty: " ++ msg) +{-# NOINLINE expectNonEmptyPanic #-} + -- | Apply an effectful function to the last list element. mapLastM :: Functor f => (a -> f a) -> NonEmpty a -> f (NonEmpty a) mapLastM f (x:|[]) = NE.singleton <$> f x diff --git a/testsuite/tests/simplCore/should_compile/T22428.hs b/testsuite/tests/simplCore/should_compile/T22428.hs new file mode 100644 index 0000000000..02cccb7f3a --- /dev/null +++ b/testsuite/tests/simplCore/should_compile/T22428.hs @@ -0,0 +1,9 @@ +module T22428 where + +f :: Integer -> Integer -> Integer +f x y = go y + where + go :: Integer -> Integer + go 0 = x + go n = go (n-1) + {-# INLINE go #-} diff --git a/testsuite/tests/simplCore/should_compile/T22428.stderr b/testsuite/tests/simplCore/should_compile/T22428.stderr new file mode 100644 index 0000000000..48ea278ae0 --- /dev/null +++ b/testsuite/tests/simplCore/should_compile/T22428.stderr @@ -0,0 +1,45 @@ + +==================== Tidy Core ==================== +Result size of Tidy Core + = {terms: 32, types: 14, coercions: 0, joins: 1/1} + +-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0} +T22428.f1 :: Integer +[GblId, + Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True, + WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 10}] +T22428.f1 = GHC.Num.Integer.IS 1# + +-- RHS size: {terms: 28, types: 10, coercions: 0, joins: 1/1} +f :: Integer -> Integer -> Integer +[GblId, + Arity=2, + Str=<SL><1L>, + Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True, + WorkFree=True, Expandable=True, Guidance=IF_ARGS [0 0] 156 0}] +f = \ (x :: Integer) (y :: Integer) -> + joinrec { + go [InlPrag=INLINE (sat-args=1), Occ=LoopBreaker, Dmd=SC(S,L)] + :: Integer -> Integer + [LclId[JoinId(1)(Just [!])], + Arity=1, + Str=<1L>, + Unf=Unf{Src=StableUser, TopLvl=False, Value=True, ConLike=True, + WorkFree=True, Expandable=True, + Guidance=ALWAYS_IF(arity=1,unsat_ok=False,boring_ok=False)}] + go (ds :: Integer) + = case ds of wild { + GHC.Num.Integer.IS x1 -> + case x1 of { + __DEFAULT -> jump go (GHC.Num.Integer.integerSub wild T22428.f1); + 0# -> x + }; + GHC.Num.Integer.IP x1 -> + jump go (GHC.Num.Integer.integerSub wild T22428.f1); + GHC.Num.Integer.IN x1 -> + jump go (GHC.Num.Integer.integerSub wild T22428.f1) + }; } in + jump go y + + + diff --git a/testsuite/tests/simplCore/should_compile/all.T b/testsuite/tests/simplCore/should_compile/all.T index 4fd57c5301..1b25c1b00b 100644 --- a/testsuite/tests/simplCore/should_compile/all.T +++ b/testsuite/tests/simplCore/should_compile/all.T @@ -459,6 +459,10 @@ test('T22494', [grep_errmsg(r'case') ], compile, ['-O -ddump-simpl -dsuppress-un test('T22491', normal, compile, ['-O2']) test('T21476', normal, compile, ['']) test('T22272', normal, multimod_compile, ['T22272', '-O -fexpose-all-unfoldings -fno-omit-interface-pragmas -fno-ignore-interface-pragmas']) + +# go should become a join point +test('T22428', [grep_errmsg(r'jump go') ], compile, ['-O -ddump-simpl -dsuppress-uniques -dno-typeable-binds -dsuppress-unfoldings']) + test('T22459', normal, compile, ['']) test('T22623', normal, multimod_compile, ['T22623', '-O -v0']) test('T22662', normal, compile, ['']) |