summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compiler/simplCore/Exitify.hs100
1 files changed, 78 insertions, 22 deletions
diff --git a/compiler/simplCore/Exitify.hs b/compiler/simplCore/Exitify.hs
index 53434bf107..714baaacf2 100644
--- a/compiler/simplCore/Exitify.hs
+++ b/compiler/simplCore/Exitify.hs
@@ -47,6 +47,7 @@ import VarSet
import VarEnv
import CoreFVs
import FastString
+import TrieMap
import Type
import Data.Bifunctor
@@ -95,7 +96,7 @@ exitifyProgram binds = map goTopLvl binds
-- join-points outside the joinrec.
exitify :: InScopeSet -> [(Var,CoreExpr)] -> (CoreExpr -> CoreExpr)
exitify in_scope pairs =
- \body ->mkExitLets exits (mkLetRec pairs' body)
+ \body -> mkExitLets exits (mkLetRec pairs' body)
where
mkExitLets ((exitId, exitRhs):exits') = mkLetNonRec exitId exitRhs . mkExitLets exits'
mkExitLets [] = id
@@ -108,7 +109,7 @@ exitify in_scope pairs =
-- Which are the recursive calls?
recursive_calls = mkVarSet $ map fst pairs
- (pairs',exits) = (`runState` []) $ do
+ (pairs',exits) = runExitifyState in_scope $
forM ann_pairs $ \(x,rhs) -> do
-- go past the lambdas of the join point
let (args, body) = collectNAnnBndrs (idJoinArity x) rhs
@@ -123,7 +124,7 @@ exitify in_scope pairs =
-- It uses a state monad to keep track of floated binds
go :: [Var] -- ^ variables to abstract over
-> CoreExprWithFVs -- ^ current expression in tail position
- -> State [(Id, CoreExpr)] CoreExpr
+ -> ExitifyM CoreExpr
go captured ann_e
-- Do not touch an expression that is already a join jump where all arguments
@@ -146,10 +147,10 @@ exitify in_scope pairs =
| is_exit = do
-- Assemble the RHS of the exit join point
let rhs = mkLams args e
- ty = exprType rhs
- let avoid = in_scope `extendInScopeSetList` captured
+ -- Remember what is in scope here
+ nowInScope captured
-- Remember this binding under a suitable name
- v <- addExit avoid ty (length args) rhs
+ v <- addExit (length args) rhs
-- And jump to it from here
return $ mkVarApps (Var v) args
where
@@ -214,16 +215,41 @@ exitify in_scope pairs =
go _ ann_e = return (deAnnotate ann_e)
+type ExitifyM = State ExitifyState
+data ExitifyState = ExitifyState
+ { es_in_scope_acc :: InScopeSet -- ^ combined in_scope_set of all call sites
+ , es_in_scope :: InScopeSet -- ^ final in_scope_set
+ , es_joins :: [(JoinId, CoreExpr)] -- ^ exit join points
+ , es_map :: CoreMap JoinId
+ -- ^ reverse lookup map, see Note [Avoid duplicate exit points]
+ }
+
+-- Runs the ExitifyM monad, and feeds in the final es_in_scope_acc as the
+-- es_in_scope to use
+runExitifyState :: InScopeSet -> ExitifyM a -> (a, [(JoinId, CoreExpr)])
+runExitifyState in_scope_init f = (res, es_joins state)
+ where
+ (res, state) = runState f (ExitifyState in_scope_init in_scope [] emptyTM)
+ in_scope = es_in_scope_acc state
+
+-- Keeps track of what is in scope at all the various positions where
+-- we want to jump to an exit join point
+nowInScope :: [Var] -> ExitifyM ()
+nowInScope captured = do
+ st <- get
+ put (st { es_in_scope_acc = es_in_scope_acc st `extendInScopeSetList` captured})
+
-- Picks a new unique, which is disjoint from
-- * the free variables of the whole joinrec
-- * any bound variables (captured)
-- * any exit join points created so far.
-mkExitJoinId :: InScopeSet -> Type -> JoinArity -> ExitifyM JoinId
-mkExitJoinId in_scope ty join_arity = do
- fs <- get
- let avoid = in_scope `extendInScopeSetList` (map fst fs)
- `extendInScopeSet` exit_id_tmpl -- just cosmetics
- return (uniqAway avoid exit_id_tmpl)
+mkExitJoinId :: Type -> JoinArity -> ExitifyM JoinId
+mkExitJoinId ty join_arity = do
+ st <- get
+ let in_scope = es_in_scope st `extendInScopeSet` exit_id_tmpl -- cosmetic only
+ let v = uniqAway in_scope exit_id_tmpl
+ put (st { es_in_scope = es_in_scope st `extendInScopeSet` v})
+ return v
where
exit_id_tmpl = mkSysLocal (fsLit "exit") initExitJoinUnique ty
`asJoinId` join_arity
@@ -236,16 +262,22 @@ mkExitJoinId in_scope ty join_arity = do
, occ_int_cxt = False
, occ_tail = AlwaysTailCalled join_arity }
-addExit :: InScopeSet -> Type -> JoinArity -> CoreExpr -> ExitifyM JoinId
-addExit in_scope ty join_arity rhs = do
- -- Pick a suitable name
- v <- mkExitJoinId in_scope ty join_arity
- fs <- get
- put ((v,rhs):fs)
- return v
-
-
-type ExitifyM = State [(JoinId, CoreExpr)]
+-- Adds a new exit join point
+-- (or re-uses an existing one)
+addExit :: JoinArity -> CoreExpr -> ExitifyM JoinId
+addExit join_arity rhs = do
+ st <- get
+ -- See Note [Avoid duplicate exit points]
+ case lookupTM rhs (es_map st) of
+ Just v -> return v
+ Nothing -> do
+ -- Pick a suitable name
+ v <- mkExitJoinId (exprType rhs) join_arity
+ st <- get
+ put (st { es_joins = (v,rhs) : es_joins st
+ , es_map = insertTM rhs v (es_map st)
+ })
+ return v
{-
Note [Interesting expression]
@@ -389,4 +421,28 @@ this kind of inlining.
In the `final` run of the simplifier, we do allow inlining of exit join points,
via a `SimplifierMode` flag.
+
+Note [Avoid duplicate exit points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+If we have
+
+ joinrec go 0 x y = t (x*x)
+ go 10 x y = t (x*x)
+ go (n-1) x y = jump go (n-1) (x+y)
+ in …
+
+we want to create only _one_ exit join point:
+
+ join exit x = t (x*x)
+ joinrec go 0 x y = jump exit x
+ go 10 x y = jump exit x
+ go (n-1) x y = jump go (n-1) (x+y)
+ in …
+
+we do so by keeping a `CoreMap JoinId` around, and `addExit` checks for
+if we can re-use an already created exit join point.
+
+Note that (at the time of writing), CSE does *not* handle join points.
+See Note [CSE for join points?]
-}