diff options
-rw-r--r-- | compiler/simplCore/Exitify.hs | 100 |
1 files changed, 78 insertions, 22 deletions
diff --git a/compiler/simplCore/Exitify.hs b/compiler/simplCore/Exitify.hs index 53434bf107..714baaacf2 100644 --- a/compiler/simplCore/Exitify.hs +++ b/compiler/simplCore/Exitify.hs @@ -47,6 +47,7 @@ import VarSet import VarEnv import CoreFVs import FastString +import TrieMap import Type import Data.Bifunctor @@ -95,7 +96,7 @@ exitifyProgram binds = map goTopLvl binds -- join-points outside the joinrec. exitify :: InScopeSet -> [(Var,CoreExpr)] -> (CoreExpr -> CoreExpr) exitify in_scope pairs = - \body ->mkExitLets exits (mkLetRec pairs' body) + \body -> mkExitLets exits (mkLetRec pairs' body) where mkExitLets ((exitId, exitRhs):exits') = mkLetNonRec exitId exitRhs . mkExitLets exits' mkExitLets [] = id @@ -108,7 +109,7 @@ exitify in_scope pairs = -- Which are the recursive calls? recursive_calls = mkVarSet $ map fst pairs - (pairs',exits) = (`runState` []) $ do + (pairs',exits) = runExitifyState in_scope $ forM ann_pairs $ \(x,rhs) -> do -- go past the lambdas of the join point let (args, body) = collectNAnnBndrs (idJoinArity x) rhs @@ -123,7 +124,7 @@ exitify in_scope pairs = -- It uses a state monad to keep track of floated binds go :: [Var] -- ^ variables to abstract over -> CoreExprWithFVs -- ^ current expression in tail position - -> State [(Id, CoreExpr)] CoreExpr + -> ExitifyM CoreExpr go captured ann_e -- Do not touch an expression that is already a join jump where all arguments @@ -146,10 +147,10 @@ exitify in_scope pairs = | is_exit = do -- Assemble the RHS of the exit join point let rhs = mkLams args e - ty = exprType rhs - let avoid = in_scope `extendInScopeSetList` captured + -- Remember what is in scope here + nowInScope captured -- Remember this binding under a suitable name - v <- addExit avoid ty (length args) rhs + v <- addExit (length args) rhs -- And jump to it from here return $ mkVarApps (Var v) args where @@ -214,16 +215,41 @@ exitify in_scope pairs = go _ ann_e = return (deAnnotate ann_e) +type ExitifyM = State ExitifyState +data ExitifyState = ExitifyState + { es_in_scope_acc :: InScopeSet -- ^ combined in_scope_set of all call sites + , es_in_scope :: InScopeSet -- ^ final in_scope_set + , es_joins :: [(JoinId, CoreExpr)] -- ^ exit join points + , es_map :: CoreMap JoinId + -- ^ reverse lookup map, see Note [Avoid duplicate exit points] + } + +-- Runs the ExitifyM monad, and feeds in the final es_in_scope_acc as the +-- es_in_scope to use +runExitifyState :: InScopeSet -> ExitifyM a -> (a, [(JoinId, CoreExpr)]) +runExitifyState in_scope_init f = (res, es_joins state) + where + (res, state) = runState f (ExitifyState in_scope_init in_scope [] emptyTM) + in_scope = es_in_scope_acc state + +-- Keeps track of what is in scope at all the various positions where +-- we want to jump to an exit join point +nowInScope :: [Var] -> ExitifyM () +nowInScope captured = do + st <- get + put (st { es_in_scope_acc = es_in_scope_acc st `extendInScopeSetList` captured}) + -- Picks a new unique, which is disjoint from -- * the free variables of the whole joinrec -- * any bound variables (captured) -- * any exit join points created so far. -mkExitJoinId :: InScopeSet -> Type -> JoinArity -> ExitifyM JoinId -mkExitJoinId in_scope ty join_arity = do - fs <- get - let avoid = in_scope `extendInScopeSetList` (map fst fs) - `extendInScopeSet` exit_id_tmpl -- just cosmetics - return (uniqAway avoid exit_id_tmpl) +mkExitJoinId :: Type -> JoinArity -> ExitifyM JoinId +mkExitJoinId ty join_arity = do + st <- get + let in_scope = es_in_scope st `extendInScopeSet` exit_id_tmpl -- cosmetic only + let v = uniqAway in_scope exit_id_tmpl + put (st { es_in_scope = es_in_scope st `extendInScopeSet` v}) + return v where exit_id_tmpl = mkSysLocal (fsLit "exit") initExitJoinUnique ty `asJoinId` join_arity @@ -236,16 +262,22 @@ mkExitJoinId in_scope ty join_arity = do , occ_int_cxt = False , occ_tail = AlwaysTailCalled join_arity } -addExit :: InScopeSet -> Type -> JoinArity -> CoreExpr -> ExitifyM JoinId -addExit in_scope ty join_arity rhs = do - -- Pick a suitable name - v <- mkExitJoinId in_scope ty join_arity - fs <- get - put ((v,rhs):fs) - return v - - -type ExitifyM = State [(JoinId, CoreExpr)] +-- Adds a new exit join point +-- (or re-uses an existing one) +addExit :: JoinArity -> CoreExpr -> ExitifyM JoinId +addExit join_arity rhs = do + st <- get + -- See Note [Avoid duplicate exit points] + case lookupTM rhs (es_map st) of + Just v -> return v + Nothing -> do + -- Pick a suitable name + v <- mkExitJoinId (exprType rhs) join_arity + st <- get + put (st { es_joins = (v,rhs) : es_joins st + , es_map = insertTM rhs v (es_map st) + }) + return v {- Note [Interesting expression] @@ -389,4 +421,28 @@ this kind of inlining. In the `final` run of the simplifier, we do allow inlining of exit join points, via a `SimplifierMode` flag. + +Note [Avoid duplicate exit points] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If we have + + joinrec go 0 x y = t (x*x) + go 10 x y = t (x*x) + go (n-1) x y = jump go (n-1) (x+y) + in … + +we want to create only _one_ exit join point: + + join exit x = t (x*x) + joinrec go 0 x y = jump exit x + go 10 x y = jump exit x + go (n-1) x y = jump go (n-1) (x+y) + in … + +we do so by keeping a `CoreMap JoinId` around, and `addExit` checks for +if we can re-use an already created exit join point. + +Note that (at the time of writing), CSE does *not* handle join points. +See Note [CSE for join points?] -} |