summaryrefslogtreecommitdiff
path: root/compiler/simplCore
diff options
context:
space:
mode:
authorJoachim Breitner <mail@joachim-breitner.de>2018-04-06 17:26:45 -0400
committerJoachim Breitner <mail@joachim-breitner.de>2018-04-09 11:25:06 -0400
commitb14c03737574895718eed786a60dfdfd42ab49ce (patch)
treecb649dc1d68784fe8a8b8e8bdc5d37a252fdfcf4 /compiler/simplCore
parent8b823f270e53627ddca1a993c05f1ab556742d96 (diff)
downloadhaskell-b14c03737574895718eed786a60dfdfd42ab49ce.tar.gz
Some cleanup of the Exitification code
based on a thorough review by Simon in comments https://ghc.haskell.org/trac/ghc/ticket/14152#comment:33 through 37. The changes are: * `isExitJoinId` is moved to `SimplUtils`, because it is only valid when occurrence information is up-to-date. * Abstracted variables are properly sorted using `sortQuantVars` * Exitification does not set occ info. And then minor quibles to notes and avoiding some unhelpful shadowing of local names. Differential Revision: https://phabricator.haskell.org/D4576
Diffstat (limited to 'compiler/simplCore')
-rw-r--r--compiler/simplCore/Exitify.hs45
-rw-r--r--compiler/simplCore/SimplUtils.hs12
2 files changed, 36 insertions, 21 deletions
diff --git a/compiler/simplCore/Exitify.hs b/compiler/simplCore/Exitify.hs
index cf6a930d3e..570186e219 100644
--- a/compiler/simplCore/Exitify.hs
+++ b/compiler/simplCore/Exitify.hs
@@ -48,16 +48,19 @@ import VarEnv
import CoreFVs
import FastString
import Type
+import MkCore ( sortQuantVars )
import Data.Bifunctor
import Control.Monad
-- | Traverses the AST, simply to find all joinrecs and call 'exitify' on them.
+-- The really interesting function is exitify
exitifyProgram :: CoreProgram -> CoreProgram
exitifyProgram binds = map goTopLvl binds
where
goTopLvl (NonRec v e) = NonRec v (go in_scope_toplvl e)
goTopLvl (Rec pairs) = Rec (map (second (go in_scope_toplvl)) pairs)
+ -- Top-level bindings are never join points
in_scope_toplvl = emptyInScopeSet `extendInScopeSetList` bindersOfBinds binds
@@ -91,6 +94,10 @@ exitifyProgram binds = map goTopLvl binds
is_join_rec = any (isJoinId . fst) pairs
in_scope' = in_scope `extendInScopeSetList` bindersOf (Rec pairs)
+
+-- | State Monad used inside `exitify`
+type ExitifyM = State [(JoinId, CoreExpr)]
+
-- | Given a recursive group of a joinrec, identifies “exit paths” and binds them as
-- join-points outside the joinrec.
exitify :: InScopeSet -> [(Var,CoreExpr)] -> (CoreExpr -> CoreExpr)
@@ -120,11 +127,13 @@ exitify in_scope pairs =
-- checks if there are no more recursive calls, if so, abstracts over
-- variables bound on the way and lifts it out as a join point.
--
- -- It uses a state monad to keep track of floated binds
+ -- ExitifyM is a state monad to keep track of floated binds
go :: [Var] -- ^ variables to abstract over
-> CoreExprWithFVs -- ^ current expression in tail position
- -> State [(Id, CoreExpr)] CoreExpr
+ -> ExitifyM CoreExpr
+ -- We first look at the expression (no matter what it shape is)
+ -- and determine if we can turn it into a exit join point
go captured ann_e
-- Do not touch an expression that is already a join jump where all arguments
-- are captured variables. See Note [Idempotency]
@@ -145,13 +154,13 @@ exitify in_scope pairs =
-- We have something to float out!
| is_exit = do
-- Assemble the RHS of the exit join point
- let rhs = mkLams args e
+ let rhs = mkLams abs_vars e
ty = exprType rhs
let avoid = in_scope `extendInScopeSetList` captured
-- Remember this binding under a suitable name
- v <- addExit avoid ty (length args) rhs
+ v <- addExit avoid ty (length abs_vars) rhs
-- And jump to it from here
- return $ mkVarApps (Var v) args
+ return $ mkVarApps (Var v) abs_vars
where
-- An exit expression has no recursive calls
is_exit = disjointVarSet fvs recursive_calls
@@ -166,14 +175,17 @@ exitify in_scope pairs =
is_interesting = anyVarSet isLocalId (fvs `minusVarSet` mkVarSet captured)
-- The possible arguments of this exit join point
- args = filter (`elemVarSet` fvs) captured
+ abs_vars = sortQuantVars $ filter (`elemVarSet` fvs) captured
-- We cannot abstract over join points
- captures_join_points = any isJoinId args
+ captures_join_points = any isJoinId abs_vars
e = deAnnotate ann_e
fvs = dVarSetToVarSet (freeVarsOf ann_e)
+ -- We could not turn it into a exit joint point. So now recurse
+ -- into all expression where eligible exit join points might sit,
+ -- i.e. into all tail-call positions:
-- Case right hand sides are in tail-call position
go captured (_, AnnCase scrut bndr ty alts) = do
@@ -211,6 +223,8 @@ exitify in_scope pairs =
return $ Let bind body'
where bind = deAnnBind ann_bind
+ -- Cannot be turned into an exit join point, but also has no
+ -- tail-call subexpression. Nothing to do here.
go _ ann_e = return (deAnnotate ann_e)
@@ -227,14 +241,6 @@ mkExitJoinId in_scope ty join_arity = do
where
exit_id_tmpl = mkSysLocal (fsLit "exit") initExitJoinUnique ty
`asJoinId` join_arity
- `setIdOccInfo` exit_occ_info
-
- -- See Note [Do not inline exit join points]
- exit_occ_info =
- OneOcc { occ_in_lam = True
- , occ_one_br = True
- , occ_int_cxt = False
- , occ_tail = AlwaysTailCalled join_arity }
addExit :: InScopeSet -> Type -> JoinArity -> CoreExpr -> ExitifyM JoinId
addExit in_scope ty join_arity rhs = do
@@ -245,8 +251,6 @@ addExit in_scope ty join_arity rhs = do
return v
-type ExitifyM = State [(JoinId, CoreExpr)]
-
{-
Note [Interesting expression]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -381,6 +385,8 @@ joinrecs are nested.
Further downside of A: If the exitify function returns annotated expressions,
it would have to ensure that the annotations are correct.
+We therefore choose B, and calculate the free variables in `exitify`.
+
Note [Do not inline exit join points]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -399,7 +405,8 @@ To prevent this, we need to recognize exit join points, and then disable
inlining.
Exit join points, recognizeable using `isExitJoinId` are join points with an
-occurence in a recursive group, and can be recognized using `isExitJoinId`.
+occurence in a recursive group, and can be recognized (after the occurence
+analyzer ran!) using `isExitJoinId`.
This function detects joinpoints with `occ_in_lam (idOccinfo id) == True`,
because the lambdas of a non-recursive join point are not considered for
`occ_in_lam`. For example, in the following code, `j1` is /not/ marked
@@ -408,8 +415,6 @@ occ_in_lam, because `j2` is called only once.
join j1 x = x+1
join j2 y = join j1 (y+2)
-We create exit join point ids with such an `OccInfo`, see `exit_occ_info`.
-
To prevent inlining, we check for isExitJoinId
* In `preInlineUnconditionally` directly.
* In `simplLetUnfolding` we simply give exit join points no unfolding, which
diff --git a/compiler/simplCore/SimplUtils.hs b/compiler/simplCore/SimplUtils.hs
index db26af426e..7c0689d9be 100644
--- a/compiler/simplCore/SimplUtils.hs
+++ b/compiler/simplCore/SimplUtils.hs
@@ -30,7 +30,10 @@ module SimplUtils (
addValArgTo, addCastTo, addTyArgTo,
argInfoExpr, argInfoAppArgs, pushSimplifiedArgs,
- abstractFloats
+ abstractFloats,
+
+ -- Utilities
+ isExitJoinId
) where
#include "HsVersions.h"
@@ -2199,6 +2202,13 @@ in PrelRules)
mkCase3 _dflags scrut bndr alts_ty alts
= return (Case scrut bndr alts_ty alts)
+-- See Note [Exitification] and Note [Do not inline exit join points] in Exitify.hs
+-- This lives here (and not in Id) becuase occurrence info is only valid on
+-- InIds, so it's crucial that isExitJoinId is only called on freshly
+-- occ-analysed code. It's not a generic function you can call anywhere.
+isExitJoinId :: Var -> Bool
+isExitJoinId id = isJoinId id && isOneOcc (idOccInfo id) && occ_in_lam (idOccInfo id)
+
{-
Note [Dead binders]
~~~~~~~~~~~~~~~~~~~~