From 535a88e1f348112c44e27e0083a68f054b055619 Mon Sep 17 00:00:00 2001 From: "klebinger.andreas@gmx.at" Date: Mon, 18 Feb 2019 00:28:39 +0100 Subject: Add loop level analysis to the NCG backend. For backends maintaining the CFG during codegen we can now find loops and their nesting level. This is based on the Cmm CFG and dominator analysis. As a result we can estimate edge frequencies a lot better for methods, resulting in far better code layout. Speedup on nofib: ~1.5% Increase in compile times: ~1.9% To make this feasible this commit adds: * Dominator analysis based on the Lengauer-Tarjan Algorithm. * An algorithm estimating global edge frequences from branch probabilities - In CFG.hs A few static branch prediction heuristics: * Expect to take the backedge in loops. * Expect to take the branch NOT exiting a loop. * Expect integer vs constant comparisons to be false. We also treat heap/stack checks special for branch prediction to avoid them being treated as loops. --- compiler/cmm/Hoopl/Dataflow.hs | 5 +- compiler/ghc.cabal.in | 1 + compiler/nativeGen/AsmCodeGen.hs | 9 +- compiler/nativeGen/BlockLayout.hs | 635 +++++++++++++-------- compiler/nativeGen/CFG.hs | 748 ++++++++++++++++++++++--- compiler/nativeGen/RegAlloc/Graph/SpillCost.hs | 101 ++-- compiler/nativeGen/X86/CodeGen.hs | 2 +- compiler/utils/Dominators.hs | 588 +++++++++++++++++++ compiler/utils/OrdList.hs | 60 +- 9 files changed, 1775 insertions(+), 374 deletions(-) create mode 100644 compiler/utils/Dominators.hs diff --git a/compiler/cmm/Hoopl/Dataflow.hs b/compiler/cmm/Hoopl/Dataflow.hs index 2a2bb72dcc..9762a84e20 100644 --- a/compiler/cmm/Hoopl/Dataflow.hs +++ b/compiler/cmm/Hoopl/Dataflow.hs @@ -6,8 +6,6 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} -{-# OPTIONS_GHC -fprof-auto-top #-} - -- -- Copyright (c) 2010, João Dias, Simon Marlow, Simon Peyton Jones, -- and Norman Ramsey @@ -108,6 +106,7 @@ analyzeCmm -> FactBase f -> FactBase f analyzeCmm dir lattice transfer cmmGraph initFact = + {-# SCC analyzeCmm #-} let entry = g_entry cmmGraph hooplGraph = g_graph cmmGraph blockMap = @@ -169,7 +168,7 @@ rewriteCmm -> CmmGraph -> FactBase f -> UniqSM (CmmGraph, FactBase f) -rewriteCmm dir lattice rwFun cmmGraph initFact = do +rewriteCmm dir lattice rwFun cmmGraph initFact = {-# SCC rewriteCmm #-} do let entry = g_entry cmmGraph hooplGraph = g_graph cmmGraph blockMap1 = diff --git a/compiler/ghc.cabal.in b/compiler/ghc.cabal.in index a612733de6..3ff27ea098 100644 --- a/compiler/ghc.cabal.in +++ b/compiler/ghc.cabal.in @@ -593,6 +593,7 @@ Library Instruction BlockLayout CFG + Dominators Format Reg RegClass diff --git a/compiler/nativeGen/AsmCodeGen.hs b/compiler/nativeGen/AsmCodeGen.hs index 6b7727a426..4c883e7185 100644 --- a/compiler/nativeGen/AsmCodeGen.hs +++ b/compiler/nativeGen/AsmCodeGen.hs @@ -562,7 +562,7 @@ cmmNativeGen dflags this_mod modLoc ncgImpl us fileIds dbgMap cmm count Opt_D_dump_asm_native "Native code" (vcat $ map (pprNatCmmDecl ncgImpl) native) - dumpIfSet_dyn dflags + when (not $ null nativeCfgWeights) $ dumpIfSet_dyn dflags Opt_D_dump_cfg_weights "CFG Weights" (pprEdgeWeights nativeCfgWeights) @@ -691,7 +691,7 @@ cmmNativeGen dflags this_mod modLoc ncgImpl us fileIds dbgMap cmm count {-# SCC "generateJumpTables" #-} generateJumpTables ncgImpl alloced - dumpIfSet_dyn dflags + when (not $ null nativeCfgWeights) $ dumpIfSet_dyn dflags Opt_D_dump_cfg_weights "CFG Update information" ( text "stack:" <+> ppr stack_updt_blks $$ text "linearAlloc:" <+> ppr cfgRegAllocUpdates ) @@ -705,8 +705,9 @@ cmmNativeGen dflags this_mod modLoc ncgImpl us fileIds dbgMap cmm count optimizedCFG = optimizeCFG (cfgWeightInfo dflags) cmm <$!> postShortCFG - maybe (return ()) - (dumpIfSet_dyn dflags Opt_D_dump_cfg_weights "CFG Final Weights" . pprEdgeWeights) + maybe (return ()) (\cfg-> + dumpIfSet_dyn dflags Opt_D_dump_cfg_weights "CFG Final Weights" + ( pprEdgeWeights cfg )) optimizedCFG --TODO: Partially check validity of the cfg. diff --git a/compiler/nativeGen/BlockLayout.hs b/compiler/nativeGen/BlockLayout.hs index 7a39071541..56e3177dd8 100644 --- a/compiler/nativeGen/BlockLayout.hs +++ b/compiler/nativeGen/BlockLayout.hs @@ -6,6 +6,8 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE FlexibleContexts #-} module BlockLayout ( sequenceTop ) @@ -22,7 +24,6 @@ import BlockId import Cmm import Hoopl.Collections import Hoopl.Label -import Hoopl.Block import DynFlags (gopt, GeneralFlag(..), DynFlags, backendMaintainsCfg) import UniqFM @@ -41,11 +42,30 @@ import ListSetOps (removeDups) import OrdList import Data.List import Data.Foldable (toList) -import Hoopl.Graph import qualified Data.Set as Set +import Data.STRef +import Control.Monad.ST.Strict +import Control.Monad (foldM) {- + Note [CFG based code layout] + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + The major steps in placing blocks are as follow: + * Compute a CFG based on the Cmm AST, see getCfgProc. + This CFG will have edge weights representing a guess + on how important they are. + * After we convert Cmm to Asm we run `optimizeCFG` which + adds a few more "educated guesses" to the equation. + * Then we run loop analysis on the CFG (`loopInfo`) which tells us + about loop headers, loop nesting levels and the sort. + * Based on the CFG and loop information refine the edge weights + in the CFG and normalize them relative to the most often visited + node. (See `mkGlobalWeights`) + * Feed this CFG into the block layout code (`sequenceTop`) in this + module. Which will then produce a code layout based on the input weights. + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~ Note [Chain based CFG serialization] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -60,8 +80,8 @@ import qualified Data.Set as Set but also how much a block would benefit from being placed sequentially after it's predecessor. For example blocks which are preceeded by an info table are more likely to end - up in a different cache line than their predecessor. So there is less benefit - in placing them sequentially. + up in a different cache line than their predecessor and we can't eliminate the jump + so there is less benefit to placing them sequentially. For example consider this example: @@ -81,56 +101,83 @@ import qualified Data.Set as Set Eg for our example we might end up with two chains like: [A->B->C->X],[D]. Blocks inside chains will always be placed sequentially. However there is no particular order in which chains are placed since - (hopefully) the blocks for which sequentially is important have already + (hopefully) the blocks for which sequentiality is important have already been placed in the same chain. ----------------------------------------------------------------------------- - First try to create a lists of good chains. + 1) First try to create a list of good chains. ----------------------------------------------------------------------------- - We do so by taking a block not yet placed in a chain and - looking at these cases: + Good chains are these which allow us to eliminate jump instructions. + Which further eliminate often executed jumps first. + + We do so by: + + *) Ignore edges which represent instructions which can not be replaced + by fall through control flow. Primarily calls and edges to blocks which + are prefixed by a info table we have to jump across. + + *) Then process remaining edges in order of frequency taken and: + + +) If source and target have not been placed build a new chain from them. + + +) If source and target have been placed, and are ends of differing chains + try to merge the two chains. - *) Check if the best predecessor of the block is at the end of a chain. - If so add the current block to the end of that chain. + +) If one side of the edge is a end/front of a chain, add the other block of + to edge to the same chain - Eg if we look at block C and already have the chain (A -> B) - then we extend the chain to (A -> B -> C). + Eg if we look at edge (B -> C) and already have the chain (A -> B) + then we extend the chain to (A -> B -> C). - Combined with the fact that we process blocks in reverse post order - this means loop bodies and trivially sequential control flow already - ends up as a single chain. + +) If the edge was used to modify or build a new chain remove the edge from + our working list. - *) Otherwise we create a singleton chain from the block we are looking at. - Eg if we have from the example above already constructed (A->B) - and look at D we create the chain (D) resulting in the chains [A->B, D] + *) If there any blocks not being placed into a chain after these steps we place + them into a chain consisting of only this block. + + Ranking edges by their taken frequency, if + two edges compete for fall through on the same target block, the one taken + more often will automatically win out. Resulting in fewer instructions being + executed. + + Creating singleton chains is required for situations where we have code of the + form: + + A: goto B: + + B: goto C: + + C: ... + + As the code in block B is only connected to the rest of the program via edges + which will be ignored in this step we make sure that B still ends up in a chain + this way. ----------------------------------------------------------------------------- - We then try to fuse chains. + 2) We also try to fuse chains. ----------------------------------------------------------------------------- - There are edge cases which result in two chains being created which trivially - represent linear control flow. For example we might have the chains - [(A-B-C),(D-E)] with an cfg triangle: + As a result from the above step we still end up with multiple chains which + represent sequential control flow chunks. But they are not yet suitable for + code layout as we need to place *all* blocks into a single sequence. - A----->C->D->E - \->B-/ + In this step we combine chains result from the above step via these steps: - We also get three independent chains if two branches end with a jump - to a common successor. + *) Look at the ranked list of *all* edges, including calls/jumps across info tables + and the like. - We take care of these cases by fusing chains which are connected by an - edge. + *) Look at each edge and - We do so by looking at the list of edges sorted by weight. - Given the edge (C -> D) we try to find two chains such that: - * C is at the end of chain one. - * D is in front of chain two. - * If two such chains exist we fuse them. - We then remove the edge and repeat the process for the rest of the edges. + +) Given an edge (A -> B) try to find two chains for which + * Block A is at the end of one chain + * Block B is at the front of the other chain. + +) If we find such a chain we "fuse" them into a single chain, remove the + edge from working set and continue. + +) If we can't find such chains we skip the edge and continue. ----------------------------------------------------------------------------- - Place indirect successors (neighbours) after each other + 3) Place indirect successors (neighbours) after each other ----------------------------------------------------------------------------- We might have chains [A,B,C,X],[E] in a CFG of the sort: @@ -141,15 +188,11 @@ import qualified Data.Set as Set While E does not follow X it's still beneficial to place them near each other. This can be advantageous if eg C,X,E will end up in the same cache line. - TODO: If we remove edges as we use them (eg if we build up A->B remove A->B - from the list) we could save some more work in later phases. - - ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~ Note [Triangle Control Flow] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - Checking if an argument is already evaluating leads to a somewhat + Checking if an argument is already evaluated leads to a somewhat special case which looks like this: A: @@ -204,11 +247,6 @@ import qualified Data.Set as Set neighbourOverlapp :: Int neighbourOverlapp = 2 --- | Only edges heavier than this are considered --- for fusing two chains into a single chain. -fuseEdgeThreshold :: EdgeWeight -fuseEdgeThreshold = 0 - -- | Maps blocks near the end of a chain to it's chain AND -- the other blocks near the end. -- [A,B,C,D,E] Gives entries like (B -> ([A,B], [A,B,C,D,E])) @@ -224,40 +262,24 @@ type FrontierMap = LabelMap ([BlockId],BlockChain) newtype BlockChain = BlockChain { chainBlocks :: (OrdList BlockId) } -instance Eq (BlockChain) where - (BlockChain blks1) == (BlockChain blks2) - = fromOL blks1 == fromOL blks2 +-- All chains are constructed the same way so comparison +-- including structure is faster. +instance Eq BlockChain where + BlockChain b1 == BlockChain b2 = strictlyEqOL b1 b2 -- Useful for things like sets and debugging purposes, sorts by blocks -- in the chain. instance Ord (BlockChain) where (BlockChain lbls1) `compare` (BlockChain lbls2) - = (fromOL lbls1) `compare` (fromOL lbls2) + = ASSERT(toList lbls1 /= toList lbls2 || lbls1 `strictlyEqOL` lbls2) + strictlyOrdOL lbls1 lbls2 instance Outputable (BlockChain) where ppr (BlockChain blks) = parens (text "Chain:" <+> ppr (fromOL $ blks) ) -data WeightedEdge = WeightedEdge !BlockId !BlockId EdgeWeight deriving (Eq) - - --- | Non deterministic! (Uniques) Sorts edges by weight and nodes. -instance Ord WeightedEdge where - compare (WeightedEdge from1 to1 weight1) - (WeightedEdge from2 to2 weight2) - | weight1 < weight2 || weight1 == weight2 && from1 < from2 || - weight1 == weight2 && from1 == from2 && to1 < to2 - = LT - | from1 == from2 && to1 == to2 && weight1 == weight2 - = EQ - | otherwise - = GT - -instance Outputable WeightedEdge where - ppr (WeightedEdge from to info) = - ppr from <> text "->" <> ppr to <> brackets (ppr info) - -type WeightedEdgeList = [WeightedEdge] +chainFoldl :: (b -> BlockId -> b) -> b -> BlockChain -> b +chainFoldl f z (BlockChain blocks) = foldl' f z blocks noDups :: [BlockChain] -> Bool noDups chains = @@ -270,19 +292,21 @@ inFront :: BlockId -> BlockChain -> Bool inFront bid (BlockChain seq) = headOL seq == bid -chainMember :: BlockId -> BlockChain -> Bool -chainMember bid chain - = elem bid $ fromOL . chainBlocks $ chain --- = setMember bid . chainMembers $ chain - chainSingleton :: BlockId -> BlockChain chainSingleton lbl = BlockChain (unitOL lbl) +chainFromList :: [BlockId] -> BlockChain +chainFromList = BlockChain . toOL + chainSnoc :: BlockChain -> BlockId -> BlockChain chainSnoc (BlockChain blks) lbl = BlockChain (blks `snocOL` lbl) +chainCons :: BlockId -> BlockChain -> BlockChain +chainCons lbl (BlockChain blks) + = BlockChain (lbl `consOL` blks) + chainConcat :: BlockChain -> BlockChain -> BlockChain chainConcat (BlockChain blks1) (BlockChain blks2) = BlockChain (blks1 `appOL` blks2) @@ -311,52 +335,14 @@ takeL :: Int -> BlockChain -> [BlockId] takeL n (BlockChain blks) = take n . fromOL $ blks --- | For a given list of chains try to fuse chains with strong --- edges between them into a single chain. --- Returns the list of fused chains together with a set of --- used edges. The set of edges is indirectly encoded in the --- chains so doesn't need to be considered for later passes. -fuseChains :: WeightedEdgeList -> LabelMap BlockChain - -> (LabelMap BlockChain, Set.Set WeightedEdge) -fuseChains weights chains - = let fronts = mapFromList $ - map (\chain -> (headOL . chainBlocks $ chain,chain)) $ - mapElems chains :: LabelMap BlockChain - (chains', used, _) = applyEdges weights chains fronts Set.empty - in (chains', used) - where - applyEdges :: WeightedEdgeList -> LabelMap BlockChain - -> LabelMap BlockChain -> Set.Set WeightedEdge - -> (LabelMap BlockChain, Set.Set WeightedEdge, LabelMap BlockChain) - applyEdges [] chainsEnd chainsFront used - = (chainsEnd, used, chainsFront) - applyEdges (edge@(WeightedEdge from to w):edges) chainsEnd chainsFront used - --Since we order edges descending by weight we can stop here - | w <= fuseEdgeThreshold - = ( chainsEnd, used, chainsFront) - --Fuse the two chains - | Just c1 <- mapLookup from chainsEnd - , Just c2 <- mapLookup to chainsFront - , c1 /= c2 - = let newChain = chainConcat c1 c2 - front = headOL . chainBlocks $ newChain - end = lastOL . chainBlocks $ newChain - chainsFront' = mapInsert front newChain $ - mapDelete to chainsFront - chainsEnd' = mapInsert end newChain $ - mapDelete from chainsEnd - in applyEdges edges chainsEnd' chainsFront' - (Set.insert edge used) - | otherwise - --Check next edge - = applyEdges edges chainsEnd chainsFront used - +-- Note [Combining neighborhood chains] +-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -- See also Note [Chain based CFG serialization] -- We have the chains (A-B-C-D) and (E-F) and an Edge C->E. -- --- While placing the later after the former doesn't result in sequential --- control flow it is still be benefical since block C and E might end +-- While placing the latter after the former doesn't result in sequential +-- control flow it is still benefical. As block C and E might end -- up in the same cache line. -- -- So we place these chains next to each other even if we can't fuse them. @@ -365,7 +351,7 @@ fuseChains weights chains -- v -- - -> E -> F ... -- --- Simple heuristic to chose which chains we want to combine: +-- A simple heuristic to chose which chains we want to combine: -- * Process edges in descending priority. -- * Check if there is a edge near the end of one chain which goes -- to a block near the start of another edge. @@ -375,14 +361,22 @@ fuseChains weights chains -- us to find all edges between two chains, check the distance for all edges, -- rank them based on the distance and and only then we can select two chains -- to combine. Which would add a lot of complexity for little gain. +-- +-- So instead we just rank by the strength of the edge and use the first pair we +-- find. -- | For a given list of chains and edges try to combine chains with strong -- edges between them. -combineNeighbourhood :: WeightedEdgeList -> [BlockChain] - -> [BlockChain] +combineNeighbourhood :: [CfgEdge] -- ^ Edges to consider + -> [BlockChain] -- ^ Current chains of blocks + -> ([BlockChain], Set.Set (BlockId,BlockId)) + -- ^ Resulting list of block chains, and a set of edges which + -- were used to fuse chains and as such no longer need to be + -- considered. combineNeighbourhood edges chains = -- pprTraceIt "Neigbours" $ - applyEdges edges endFrontier startFrontier + -- pprTrace "combineNeighbours" (ppr edges) $ + applyEdges edges endFrontier startFrontier (Set.empty) where --Build maps from chain ends to chains endFrontier, startFrontier :: FrontierMap @@ -396,14 +390,14 @@ combineNeighbourhood edges chains let front = getFronts chain entry = (front,chain) in map (\x -> (x,entry)) front) chains - applyEdges :: WeightedEdgeList -> FrontierMap -> FrontierMap - -> [BlockChain] - applyEdges [] chainEnds _chainFronts = - ordNub $ map snd $ mapElems chainEnds - applyEdges ((WeightedEdge from to _w):edges) chainEnds chainFronts + applyEdges :: [CfgEdge] -> FrontierMap -> FrontierMap -> Set.Set (BlockId, BlockId) + -> ([BlockChain], Set.Set (BlockId,BlockId)) + applyEdges [] chainEnds _chainFronts combined = + (ordNub $ map snd $ mapElems chainEnds, combined) + applyEdges ((CfgEdge from to _w):edges) chainEnds chainFronts combined | Just (c1_e,c1) <- mapLookup from chainEnds , Just (c2_f,c2) <- mapLookup to chainFronts - , c1 /= c2 -- Avoid trying to concat a short chain with itself. + , c1 /= c2 -- Avoid trying to concat a chain with itself. = let newChain = chainConcat c1 c2 newChainFrontier = getFronts newChain newChainEnds = getEnds newChain @@ -437,165 +431,299 @@ combineNeighbourhood edges chains -- text "fronts" <+> ppr newFronts $$ -- text "ends" <+> ppr newEnds -- ) - applyEdges edges newEnds newFronts + applyEdges edges newEnds newFronts (Set.insert (from,to) combined) | otherwise - = --pprTrace "noNeigbours" (ppr ()) $ - applyEdges edges chainEnds chainFronts + = applyEdges edges chainEnds chainFronts combined where getFronts chain = takeL neighbourOverlapp chain getEnds chain = takeR neighbourOverlapp chain - - --- See [Chain based CFG serialization] -buildChains :: CFG -> [BlockId] - -> ( LabelMap BlockChain -- Resulting chains. +-- In the last stop we combine all chains into a single one. +-- Trying to place chains with strong edges next to each other. +mergeChains :: [CfgEdge] -> [BlockChain] + -> (BlockChain) +mergeChains edges chains + = -- pprTrace "combine" (ppr edges) $ + runST $ do + let addChain m0 chain = do + ref <- newSTRef chain + return $ chainFoldl (\m' b -> mapInsert b ref m') m0 chain + chainMap' <- foldM (\m0 c -> addChain m0 c) mapEmpty chains + merge edges chainMap' + where + -- We keep a map from ALL blocks to their respective chain (sigh) + -- This is required since when looking at an edge we need to find + -- the associated chains quickly. + -- We use a map of STRefs, maintaining a invariant of one STRef per chain. + -- When merging chains we can update the + -- STRef of one chain once (instead of writing to the map for each block). + -- We then overwrite the STRefs for the other chain so there is again only + -- a single STRef for the combined chain. + -- The difference in terms of allocations saved is ~0.2% with -O so actually + -- significant compared to using a regular map. + + merge :: forall s. [CfgEdge] -> LabelMap (STRef s BlockChain) -> ST s BlockChain + merge [] chains = do + chains' <- ordNub <$> (mapM readSTRef $ mapElems chains) :: ST s [BlockChain] + return $ foldl' chainConcat (head chains') (tail chains') + merge ((CfgEdge from to _):edges) chains + -- | pprTrace "merge" (ppr (from,to) <> ppr chains) False + -- = undefined + | cFrom == cTo + = merge edges chains + | otherwise + = do + chains' <- mergeComb cFrom cTo + merge edges chains' + where + mergeComb :: STRef s BlockChain -> STRef s BlockChain -> ST s (LabelMap (STRef s BlockChain)) + mergeComb refFrom refTo = do + cRight <- readSTRef refTo + chain <- pure chainConcat <*> readSTRef refFrom <*> pure cRight + writeSTRef refFrom chain + return $ chainFoldl (\m b -> mapInsert b refFrom m) chains cRight + + cFrom = expectJust "mergeChains:chainMap:from" $ mapLookup from chains + cTo = expectJust "mergeChains:chainMap:to" $ mapLookup to chains + + +-- See Note [Chain based CFG serialization] for the general idea. +-- This creates and fuses chains at the same time for performance reasons. + +-- Try to build chains from a list of edges. +-- Edges must be sorted **descending** by their priority. +-- Returns the constructed chains, along with all edges which +-- are irrelevant past this point, this information doesn't need +-- to be complete - it's only used to speed up the process. +-- An Edge is irrelevant if the ends are part of the same chain. +-- We say these edges are already linked +buildChains :: [CfgEdge] -> [BlockId] + -> ( LabelMap BlockChain -- Resulting chains, indexd by end if chain. , Set.Set (BlockId, BlockId)) --List of fused edges. -buildChains succWeights blocks - = let (_, fusedEdges, chains) = buildNext setEmpty mapEmpty blocks Set.empty - in (chains, fusedEdges) +buildChains edges blocks + = runST $ buildNext setEmpty mapEmpty mapEmpty edges Set.empty where - -- We keep a map from the last block in a chain to the chain itself. - -- This we we can easily check if an block should be appened to an + -- buildNext builds up chains from edges one at a time. + + -- We keep a map from the ends of chains to the chains. + -- This we we can easily check if an block should be appended to an -- existing chain! - buildNext :: LabelSet - -> LabelMap BlockChain -- Map from last element to chain. - -> [BlockId] -- Blocks to place - -> Set.Set (BlockId, BlockId) - -> ( [BlockChain] -- Placed Blocks - , Set.Set (BlockId, BlockId) --List of fused edges - , LabelMap BlockChain - ) - buildNext _placed chains [] linked = - ([], linked, chains) - buildNext placed chains (block:todo) linked - | setMember block placed - = buildNext placed chains todo linked + -- We store them using STRefs so we don't have to rebuild the spine of both + -- maps every time we update a chain. + buildNext :: forall s. LabelSet + -> LabelMap (STRef s BlockChain) -- Map from end of chain to chain. + -> LabelMap (STRef s BlockChain) -- Map from start of chain to chain. + -> [CfgEdge] -- Edges to check - ordered by decreasing weight + -> Set.Set (BlockId, BlockId) -- Used edges + -> ST s ( LabelMap BlockChain -- Chains by end + , Set.Set (BlockId, BlockId) --List of fused edges + ) + buildNext placed _chainStarts chainEnds [] linked = do + ends' <- sequence $ mapMap readSTRef chainEnds :: ST s (LabelMap BlockChain) + -- Any remaining blocks have to be made to singleton chains. + -- They might be combined with other chains later on outside this function. + let unplaced = filter (\x -> not (setMember x placed)) blocks + singletons = map (\x -> (x,chainSingleton x)) unplaced :: [(BlockId,BlockChain)] + return (foldl' (\m (k,v) -> mapInsert k v m) ends' singletons , linked) + buildNext placed chainStarts chainEnds (edge:todo) linked + | from == to + -- We skip self edges + = buildNext placed chainStarts chainEnds todo (Set.insert (from,to) linked) + | not (alreadyPlaced from) && + not (alreadyPlaced to) + = do + --pprTraceM "Edge-Chain:" (ppr edge) + chain' <- newSTRef $ chainFromList [from,to] + buildNext + (setInsert to (setInsert from placed)) + (mapInsert from chain' chainStarts) + (mapInsert to chain' chainEnds) + todo + (Set.insert (from,to) linked) + + | (alreadyPlaced from) && + (alreadyPlaced to) + , Just predChain <- mapLookup from chainEnds + , Just succChain <- mapLookup to chainStarts + , predChain /= succChain -- Otherwise we try to create a cycle. + = do + -- pprTraceM "Fusing edge" (ppr edge) + fuseChain predChain succChain + + | (alreadyPlaced from) && + (alreadyPlaced to) + = --pprTraceM "Skipping:" (ppr edge) >> + buildNext placed chainStarts chainEnds todo linked + | otherwise - = buildNext placed' chains' todo linked' + = do -- pprTraceM "Finding chain for:" (ppr edge $$ + -- text "placed" <+> ppr placed) + findChain where - placed' = (foldl' (flip setInsert) placed placedBlocks) - linked' = Set.union linked linkedEdges - (placedBlocks, chains', linkedEdges) = findChain block - - --Add the block to a existing or new chain - --Returns placed blocks, list of resulting chains - --and fused edges - findChain :: BlockId - -> ([BlockId],LabelMap BlockChain, Set.Set (BlockId, BlockId)) - findChain block - -- B) place block at end of existing chain if - -- there is no better block to append. - | (pred:_) <- preds - , alreadyPlaced pred - , Just predChain <- mapLookup pred chains - , (best:_) <- filter (not . alreadyPlaced) $ getSuccs pred - , best == lbl - = --pprTrace "B.2)" (ppr (pred,lbl)) $ - let newChain = chainSnoc predChain block - chainMap = mapInsert lbl newChain $ mapDelete pred chains - in ( [lbl] - , chainMap - , Set.singleton (pred,lbl) ) - + from = edgeFrom edge + to = edgeTo edge + alreadyPlaced blkId = (setMember blkId placed) + + -- Combine two chains into a single one. + fuseChain :: STRef s BlockChain -> STRef s BlockChain + -> ST s ( LabelMap BlockChain -- Chains by end + , Set.Set (BlockId, BlockId) --List of fused edges + ) + fuseChain fromRef toRef = do + fromChain <- readSTRef fromRef + toChain <- readSTRef toRef + let newChain = chainConcat fromChain toChain + ref <- newSTRef newChain + let start = head $ takeL 1 newChain + let end = head $ takeR 1 newChain + -- chains <- sequence $ mapMap readSTRef chainStarts + -- pprTraceM "pre-fuse chains:" $ ppr chains + buildNext + placed + (mapInsert start ref $ mapDelete to $ chainStarts) + (mapInsert end ref $ mapDelete from $ chainEnds) + todo + (Set.insert (from,to) linked) + + + --Add the block to a existing chain or creates a new chain + findChain :: ST s ( LabelMap BlockChain -- Chains by end + , Set.Set (BlockId, BlockId) --List of fused edges + ) + findChain + -- We can attach the block to the end of a chain + | alreadyPlaced from + , Just predChain <- mapLookup from chainEnds + = do + chain <- readSTRef predChain + let newChain = chainSnoc chain to + writeSTRef predChain newChain + let chainEnds' = mapInsert to predChain $ mapDelete from chainEnds + -- chains <- sequence $ mapMap readSTRef chainStarts + -- pprTraceM "from chains:" $ ppr chains + buildNext (setInsert to placed) chainStarts chainEnds' todo (Set.insert (from,to) linked) + -- We can attack it to the front of a chain + | alreadyPlaced to + , Just succChain <- mapLookup to chainStarts + = do + chain <- readSTRef succChain + let newChain = from `chainCons` chain + writeSTRef succChain newChain + let chainStarts' = mapInsert from succChain $ mapDelete to chainStarts + -- chains <- sequence $ mapMap readSTRef chainStarts' + -- pprTraceM "to chains:" $ ppr chains + buildNext (setInsert from placed) chainStarts' chainEnds todo (Set.insert (from,to) linked) + -- The placed end of the edge is part of a chain already and not an end. | otherwise - = --pprTrace "single" (ppr lbl) - ( [lbl] - , mapInsert lbl (chainSingleton lbl) chains - , Set.empty) + = do + let block = if alreadyPlaced to then from else to + --pprTraceM "Singleton" $ ppr block + let newChain = chainSingleton block + ref <- newSTRef newChain + buildNext (setInsert block placed) (mapInsert block ref chainStarts) + (mapInsert block ref chainEnds) todo (linked) where alreadyPlaced blkId = (setMember blkId placed) - lbl = block - getSuccs = map fst . getSuccEdgesSorted succWeights - preds = map fst $ getSuccEdgesSorted predWeights lbl - --For efficiency we also create the map to look up predecessors here - predWeights = reverseEdges succWeights - - --- We make the CFG a Hoopl Graph, so we can reuse revPostOrder. -newtype BlockNode (e :: Extensibility) (x :: Extensibility) = BN (BlockId,[BlockId]) -instance NonLocal (BlockNode) where - entryLabel (BN (lbl,_)) = lbl - successors (BN (_,succs)) = succs - -fromNode :: BlockNode C C -> BlockId -fromNode (BN x) = fst x - -sequenceChain :: forall a i. (Instruction i, Outputable i) => LabelMap a -> CFG - -> [GenBasicBlock i] -> [GenBasicBlock i] +-- | Place basic blocks based on the given CFG. +-- See Note [Chain based CFG serialization] +sequenceChain :: forall a i. (Instruction i, Outputable i) + => LabelMap a -- ^ Keys indicate an info table on the block. + -> CFG -- ^ Control flow graph and some meta data. + -> [GenBasicBlock i] -- ^ List of basic blocks to be placed. + -> [GenBasicBlock i] -- ^ Blocks placed in sequence. sequenceChain _info _weights [] = [] sequenceChain _info _weights [x] = [x] sequenceChain info weights' blocks@((BasicBlock entry _):_) = - --Optimization, delete edges of weight <= 0. - --This significantly improves performance whenever - --we iterate over all edges, which is a few times! let weights :: CFG - weights - = filterEdges (\_f _t edgeInfo -> edgeWeight edgeInfo > 0) weights' + weights = --pprTrace "cfg'" (pprEdgeWeights cfg') + cfg' + where + (_, globalEdgeWeights) = {-# SCC mkGlobalWeights #-} mkGlobalWeights entry weights' + cfg' = {-# SCC rewriteEdges #-} + mapFoldlWithKey + (\cfg from m -> + mapFoldlWithKey + (\cfg to w -> setEdgeWeight cfg (EdgeWeight w) from to ) + cfg m ) + weights' + globalEdgeWeights + + directEdges :: [CfgEdge] + directEdges = sortBy (flip compare) $ catMaybes . map relevantWeight $ (infoEdgeList weights) + where + relevantWeight :: CfgEdge -> Maybe CfgEdge + relevantWeight edge@(CfgEdge from to edgeInfo) + | (EdgeInfo CmmSource { trans_cmmNode = CmmCall {} } _) <- edgeInfo + -- Ignore edges across calls + = Nothing + | mapMember to info + , w <- edgeWeight edgeInfo + -- The payoff is small if we jump over an info table + = Just (CfgEdge from to edgeInfo { edgeWeight = w/8 }) + | otherwise + = Just edge + blockMap :: LabelMap (GenBasicBlock i) blockMap = foldl' (\m blk@(BasicBlock lbl _ins) -> mapInsert lbl blk m) mapEmpty blocks - toNode :: BlockId -> BlockNode C C - toNode bid = - -- sorted such that heavier successors come first. - BN (bid,map fst . getSuccEdgesSorted weights' $ bid) - - orderedBlocks :: [BlockId] - orderedBlocks - = map fromNode $ - revPostorderFrom (fmap (toNode . blockId) blockMap) entry - (builtChains, builtEdges) = {-# SCC "buildChains" #-} --pprTraceIt "generatedChains" $ - --pprTrace "orderedBlocks" (ppr orderedBlocks) $ - buildChains weights orderedBlocks + --pprTrace "blocks" (ppr (mapKeys blockMap)) $ + buildChains directEdges (mapKeys blockMap) - rankedEdges :: WeightedEdgeList - -- Sort edges descending, remove fused eges + rankedEdges :: [CfgEdge] + -- Sort descending by weight, remove fused edges rankedEdges = - map (\(from, to, weight) -> WeightedEdge from to weight) . - filter (\(from, to, _) - -> not (Set.member (from,to) builtEdges)) . - sortWith (\(_,_,w) -> - w) $ weightedEdgeList weights + filter (\edge -> not (Set.member (edgeFrom edge,edgeTo edge) builtEdges)) $ + directEdges - (fusedChains, fusedEdges) + (neighbourChains, combined) = ASSERT(noDups $ mapElems builtChains) - {-# SCC "fuseChains" #-} - --(pprTrace "RankedEdges" $ ppr rankedEdges) $ - --pprTraceIt "FusedChains" $ - fuseChains rankedEdges builtChains - - rankedEdges' = - filter (\edge -> not $ Set.member edge fusedEdges) $ rankedEdges - - neighbourChains - = ASSERT(noDups $ mapElems fusedChains) {-# SCC "groupNeighbourChains" #-} - --pprTraceIt "ResultChains" $ - combineNeighbourhood rankedEdges' (mapElems fusedChains) + -- pprTraceIt "NeighbourChains" $ + combineNeighbourhood rankedEdges (mapElems builtChains) + + + allEdges :: [CfgEdge] + allEdges = {-# SCC allEdges #-} + sortOn (relevantWeight) $ filter (not . deadEdge) $ (infoEdgeList weights) + where + deadEdge :: CfgEdge -> Bool + deadEdge (CfgEdge from to _) = let e = (from,to) in Set.member e combined || Set.member e builtEdges + relevantWeight :: CfgEdge -> EdgeWeight + relevantWeight (CfgEdge _ _ edgeInfo) + | EdgeInfo (CmmSource { trans_cmmNode = CmmCall {}}) _ <- edgeInfo + -- Penalize edges across calls + = weight/(64.0) + | otherwise + = weight + where + -- negate to sort descending + weight = negate (edgeWeight edgeInfo) + + masterChain = + {-# SCC "mergeChains" #-} + -- pprTraceIt "MergedChains" $ + mergeChains allEdges neighbourChains --Make sure the first block stays first - ([entryChain],chains') - = ASSERT(noDups $ neighbourChains) - partition (chainMember entry) neighbourChains - (entryChain':entryRest) - | inFront entry entryChain = [entryChain] - | (rest,entry) <- breakChainAt entry entryChain + prepedChains + | inFront entry masterChain + = [masterChain] + | (rest,entry) <- breakChainAt entry masterChain = [entry,rest] | otherwise = pprPanic "Entry point eliminated" $ - ppr ([entryChain],chains') + ppr masterChain - prepedChains - = entryChain':(entryRest++chains') :: [BlockChain] blockList - -- = (concatMap chainToBlocks prepedChains) - = (concatMap fromOL $ map chainBlocks prepedChains) + = ASSERT(noDups [masterChain]) + (concatMap fromOL $ map chainBlocks prepedChains) --chainPlaced = setFromList $ map blockId blockList :: LabelSet chainPlaced = setFromList $ blockList :: LabelSet @@ -605,14 +733,22 @@ sequenceChain info weights' blocks@((BasicBlock entry _):_) = in filter (\block -> not (isPlaced block)) blocks placedBlocks = + -- We want debug builds to catch this as it's a good indicator for + -- issues with CFG invariants. But we don't want to blow up production + -- builds if something slips through. + ASSERT(null unplaced) --pprTraceIt "placedBlocks" $ - blockList ++ unplaced + -- ++ [] is stil kinda expensive + if null unplaced then blockList else blockList ++ unplaced getBlock bid = expectJust "Block placment" $ mapLookup bid blockMap in --Assert we placed all blocks given as input ASSERT(all (\bid -> mapMember bid blockMap) placedBlocks) dropJumps info $ map getBlock placedBlocks +{-# SCC dropJumps #-} +-- | Remove redundant jumps between blocks when we can rely on +-- fall through. dropJumps :: forall a i. Instruction i => LabelMap a -> [GenBasicBlock i] -> [GenBasicBlock i] dropJumps _ [] = [] @@ -641,7 +777,8 @@ sequenceTop => DynFlags -- Determine which layout algo to use -> NcgImpl statics instr jumpDest -> Maybe CFG -- ^ CFG if we have one. - -> NatCmmDecl statics instr -> NatCmmDecl statics instr + -> NatCmmDecl statics instr -- ^ Function to serialize + -> NatCmmDecl statics instr sequenceTop _ _ _ top@(CmmData _ _) = top sequenceTop dflags ncgImpl edgeWeights @@ -650,11 +787,13 @@ sequenceTop dflags ncgImpl edgeWeights --Use chain based algorithm , Just cfg <- edgeWeights = CmmProc info lbl live ( ListGraph $ ncgMakeFarBranches ncgImpl info $ + {-# SCC layoutBlocks #-} sequenceChain info cfg blocks ) | otherwise --Use old algorithm = let cfg = if dontUseCfg then Nothing else edgeWeights in CmmProc info lbl live ( ListGraph $ ncgMakeFarBranches ncgImpl info $ + {-# SCC layoutBlocks #-} sequenceBlocks cfg info blocks) where dontUseCfg = gopt Opt_WeightlessBlocklayout dflags || diff --git a/compiler/nativeGen/CFG.hs b/compiler/nativeGen/CFG.hs index fee47188ac..8eb69a9dbf 100644 --- a/compiler/nativeGen/CFG.hs +++ b/compiler/nativeGen/CFG.hs @@ -6,31 +6,40 @@ {-# LANGUAGE TupleSections #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE CPP #-} +{-# LANGUAGE Rank2Types #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} module CFG ( CFG, CfgEdge(..), EdgeInfo(..), EdgeWeight(..) , TransitionSource(..) --Modify the CFG - , addWeightEdge, addEdge, delEdge + , addWeightEdge, addEdge + , delEdge, delNode , addNodesBetween, shortcutWeightMap , reverseEdges, filterEdges , addImmediateSuccessor - , mkWeightInfo, adjustEdgeWeight + , mkWeightInfo, adjustEdgeWeight, setEdgeWeight --Query the CFG , infoEdgeList, edgeList , getSuccessorEdges, getSuccessors - , getSuccEdgesSorted, weightedEdgeList + , getSuccEdgesSorted , getEdgeInfo , getCfgNodes, hasNode - , loopMembers + + -- Loop Information + , loopMembers, loopLevels, loopInfo --Construction/Misc , getCfg, getCfgProc, pprEdgeWeights, sanityCheckCfg --Find backedges and update their weight - , optimizeCFG ) + , optimizeCFG + , mkGlobalWeights + + ) where #include "HsVersions.h" @@ -38,9 +47,8 @@ where import GhcPrelude import BlockId -import Cmm ( RawCmmDecl, GenCmmDecl( .. ), CmmBlock, succ, g_entry - , CmmGraph ) -import CmmNode +import Cmm + import CmmUtils import CmmSwitch import Hoopl.Collections @@ -50,10 +58,24 @@ import qualified Hoopl.Graph as G import Util import Digraph +import Maybes + +import Unique +import qualified Dominators as Dom +import Data.IntMap.Strict (IntMap) +import Data.IntSet (IntSet) + +import qualified Data.IntMap.Strict as IM +import qualified Data.Map as M +import qualified Data.IntSet as IS +import qualified Data.Set as S +import Data.Tree +import Data.Bifunctor import Outputable -- DEBUGGING ONLY --import Debug +-- import Debug.Trace --import OrdList --import Debug.Trace import PprCmm () -- For Outputable instances @@ -61,17 +83,28 @@ import qualified DynFlags as D import Data.List --- import qualified Data.IntMap.Strict as M --TODO: LabelMap +import Data.STRef.Strict +import Control.Monad.ST + +import Data.Array.MArray +import Data.Array.ST +import Data.Array.IArray +import Data.Array.Unsafe (unsafeFreeze) +import Data.Array.Base (unsafeRead, unsafeWrite) + +import Control.Monad + +type Prob = Double type Edge = (BlockId, BlockId) type Edges = [Edge] newtype EdgeWeight - = EdgeWeight Int - deriving (Eq,Ord,Enum,Num,Real,Integral) + = EdgeWeight { weightToDouble :: Double } + deriving (Eq,Ord,Enum,Num,Real,Fractional) instance Outputable EdgeWeight where - ppr (EdgeWeight w) = ppr w + ppr (EdgeWeight w) = doublePrec 5 w type EdgeInfoMap edgeInfo = LabelMap (LabelMap edgeInfo) @@ -108,15 +141,28 @@ instance Outputable CfgEdge where = parens (ppr from1 <+> text "-(" <> ppr edgeInfo <> text ")->" <+> ppr to1) -- | Can we trace back a edge to a specific Cmm Node --- or has it been introduced for codegen. We use this to maintain +-- or has it been introduced during assembly codegen. We use this to maintain -- some information which would otherwise be lost during the -- Cmm <-> asm transition. -- See also Note [Inverting Conditional Branches] data TransitionSource - = CmmSource (CmmNode O C) + = CmmSource { trans_cmmNode :: (CmmNode O C) + , trans_info :: BranchInfo } | AsmCodeGen deriving (Eq) +data BranchInfo = NoInfo -- ^ Unknown, but not heap or stack check. + | HeapStackCheck -- ^ Heap or stack check + deriving Eq + +instance Outputable BranchInfo where + ppr NoInfo = text "regular" + ppr HeapStackCheck = text "heap/stack" + +isHeapOrStackCheck :: TransitionSource -> Bool +isHeapOrStackCheck (CmmSource { trans_info = HeapStackCheck}) = True +isHeapOrStackCheck _ = False + -- | Information about edges data EdgeInfo = EdgeInfo @@ -127,12 +173,10 @@ data EdgeInfo instance Outputable EdgeInfo where ppr edgeInfo = text "weight:" <+> ppr (edgeWeight edgeInfo) --- Allow specialization -{-# INLINEABLE mkWeightInfo #-} -- | Convenience function, generate edge info based -- on weight not originating from cmm. -mkWeightInfo :: Integral n => n -> EdgeInfo -mkWeightInfo = EdgeInfo AsmCodeGen . fromIntegral +mkWeightInfo :: EdgeWeight -> EdgeInfo +mkWeightInfo = EdgeInfo AsmCodeGen -- | Adjust the weight between the blocks using the given function. -- If there is no such edge returns the original map. @@ -140,12 +184,25 @@ adjustEdgeWeight :: CFG -> (EdgeWeight -> EdgeWeight) -> BlockId -> BlockId -> CFG adjustEdgeWeight cfg f from to | Just info <- getEdgeInfo from to cfg - , weight <- edgeWeight info - = addEdge from to (info { edgeWeight = f weight}) cfg + , !weight <- edgeWeight info + , !newWeight <- f weight + = addEdge from to (info { edgeWeight = newWeight}) cfg + | otherwise = cfg + +-- | Set the weight between the blocks to the given weight. +-- If there is no such edge returns the original map. +setEdgeWeight :: CFG -> EdgeWeight + -> BlockId -> BlockId -> CFG +setEdgeWeight cfg !weight from to + | Just info <- getEdgeInfo from to cfg + = addEdge from to (info { edgeWeight = weight}) cfg | otherwise = cfg + + getCfgNodes :: CFG -> LabelSet -getCfgNodes m = mapFoldMapWithKey (\k v -> setFromList (k:mapKeys v)) m +getCfgNodes m = + mapFoldlWithKey (\s k toMap -> mapFoldlWithKey (\s k _ -> setInsert k s) (setInsert k s) toMap ) setEmpty m hasNode :: CFG -> BlockId -> Bool hasNode m node = mapMember node m || any (mapMember node) m @@ -294,6 +351,11 @@ delEdge from to m = remDest Nothing = Nothing remDest (Just wm) = Just $ mapDelete to wm +delNode :: BlockId -> CFG -> CFG +delNode node cfg = + fmap (mapDelete node) -- < Edges to the node + (mapDelete node cfg) -- < Edges from the node + -- | Destinations from bid ordered by weight (descending) getSuccEdgesSorted :: CFG -> BlockId -> [(BlockId,EdgeInfo)] getSuccEdgesSorted m bid = @@ -315,36 +377,54 @@ getEdgeInfo from to m | otherwise = Nothing +getEdgeWeight :: CFG -> BlockId -> BlockId -> EdgeWeight +getEdgeWeight cfg from to = + edgeWeight $ expectJust "Edgeweight for noexisting block" $ + getEdgeInfo from to cfg + +getTransitionSource :: BlockId -> BlockId -> CFG -> TransitionSource +getTransitionSource from to cfg = transitionSource $ expectJust "Source info for noexisting block" $ + getEdgeInfo from to cfg + reverseEdges :: CFG -> CFG -reverseEdges cfg = foldr add mapEmpty flatElems +reverseEdges cfg = mapFoldlWithKey (\cfg from toMap -> go (addNode cfg from) from toMap) mapEmpty cfg where - elems = mapToList $ fmap mapToList cfg :: [(BlockId,[(BlockId,EdgeInfo)])] - flatElems = - concatMap (\(from,ws) -> map (\(to,info) -> (to,from,info)) ws ) elems - add (to,from,info) m = addEdge to from info m + -- We preserve nodes without outgoing edges! + addNode :: CFG -> BlockId -> CFG + addNode cfg b = mapInsertWith mapUnion b mapEmpty cfg + go :: CFG -> BlockId -> (LabelMap EdgeInfo) -> CFG + go cfg from toMap = mapFoldlWithKey (\cfg to info -> addEdge to from info cfg) cfg toMap :: CFG + -- | Returns a unordered list of all edges with info infoEdgeList :: CFG -> [CfgEdge] infoEdgeList m = - mapFoldMapWithKey - (\from toMap -> - map (\(to,info) -> CfgEdge from to info) (mapToList toMap)) - m - --- | Unordered list of edges with weight as Tuple (from,to,weight) -weightedEdgeList :: CFG -> [(BlockId,BlockId,EdgeWeight)] -weightedEdgeList m = - mapFoldMapWithKey - (\from toMap -> - map (\(to,info) -> - (from,to, edgeWeight info)) (mapToList toMap)) - m - -- (\(from, tos) -> map (\(to,info) -> (from,to, edgeWeight info)) tos ) + go (mapToList m) [] + where + -- We avoid foldMap to avoid thunk buildup + go :: [(BlockId,LabelMap EdgeInfo)] -> [CfgEdge] -> [CfgEdge] + go [] acc = acc + go ((from,toMap):xs) acc + = go' xs from (mapToList toMap) acc + go' :: [(BlockId,LabelMap EdgeInfo)] -> BlockId -> [(BlockId,EdgeInfo)] -> [CfgEdge] -> [CfgEdge] + go' froms _ [] acc = go froms acc + go' froms from ((to,info):tos) acc + = go' froms from tos (CfgEdge from to info : acc) -- | Returns a unordered list of all edges without weights edgeList :: CFG -> [Edge] edgeList m = - mapFoldMapWithKey (\from toMap -> fmap (from,) (mapKeys toMap)) m + go (mapToList m) [] + where + -- We avoid foldMap to avoid thunk buildup + go :: [(BlockId,LabelMap EdgeInfo)] -> [Edge] -> [Edge] + go [] acc = acc + go ((from,toMap):xs) acc + = go' xs from (mapKeys toMap) acc + go' :: [(BlockId,LabelMap EdgeInfo)] -> BlockId -> [BlockId] -> [Edge] -> [Edge] + go' froms _ [] acc = go froms acc + go' froms from (to:tos) acc + = go' froms from tos ((from,to) : acc) -- | Get successors of a given node without edge weights. getSuccessors :: CFG -> BlockId -> [BlockId] @@ -355,8 +435,8 @@ getSuccessors m bid pprEdgeWeights :: CFG -> SDoc pprEdgeWeights m = - let edges = sort $ weightedEdgeList m - printEdge (from, to, weight) + let edges = sort $ infoEdgeList m :: [CfgEdge] + printEdge (CfgEdge from to (EdgeInfo { edgeWeight = weight })) = text "\t" <> ppr from <+> text "->" <+> ppr to <> text "[label=\"" <> ppr weight <> text "\",weight=\"" <> ppr weight <> text "\"];\n" @@ -365,7 +445,7 @@ pprEdgeWeights m = --to immediately see it when it does. printNode node = text "\t" <> ppr node <> text ";\n" - getEdgeNodes (from, to, _weight) = [from,to] + getEdgeNodes (CfgEdge from to _) = [from,to] edgeNodes = setFromList $ concatMap getEdgeNodes edges :: LabelSet nodes = filter (\n -> (not . setMember n) edgeNodes) . mapKeys $ mapFilter null m in @@ -378,8 +458,8 @@ pprEdgeWeights m = updateEdgeWeight :: (EdgeWeight -> EdgeWeight) -> Edge -> CFG -> CFG updateEdgeWeight f (from, to) cfg | Just oldInfo <- getEdgeInfo from to cfg - = let oldWeight = edgeWeight oldInfo - newWeight = f oldWeight + = let !oldWeight = edgeWeight oldInfo + !newWeight = f oldWeight in addEdge from to (oldInfo {edgeWeight = newWeight}) cfg | otherwise = panic "Trying to update invalid edge" @@ -448,9 +528,7 @@ addNodesBetween m updates = Should A or B be placed in front of C? The block layout algorithm decides this based on which edge (A,C)/(B,C) is heavier. So we - make a educated guess how often execution will transer control - along each edge as well as how much we gain by placing eg A before - C. + make a educated guess on which branch should be preferred. We rank edges in this order: * Unconditional Control Transfer - They will always @@ -479,7 +557,6 @@ addNodesBetween m updates = address. This reduces the chance that we return to the same cache line further. - -} -- | Generate weights for a Cmm proc based on some simple heuristics. getCfgProc :: D.CfgWeights -> RawCmmDecl -> CFG @@ -515,13 +592,24 @@ getCfg weights graph = getBlockEdges block = case branch of CmmBranch dest -> [mkEdge dest uncondWeight] - CmmCondBranch _c t f l + CmmCondBranch cond t f l | l == Nothing -> [mkEdge f condBranchWeight, mkEdge t condBranchWeight] | l == Just True -> [mkEdge f unlikelyCondWeight, mkEdge t likelyCondWeight] | l == Just False -> [mkEdge f likelyCondWeight, mkEdge t unlikelyCondWeight] + where + mkEdgeInfo = -- pprTrace "Info" (ppr branchInfo <+> ppr cond) + EdgeInfo (CmmSource branch branchInfo) . fromIntegral + mkEdge target weight = ((bid,target), mkEdgeInfo weight) + branchInfo = + foldRegsUsed + (panic "foldRegsDynFlags") + (\info r -> if r == SpLim || r == HpLim || r == BaseReg + then HeapStackCheck else info) + NoInfo cond + (CmmSwitch _e ids) -> let switchTargets = switchTargetsToList ids --Compiler performance hack - for very wide switches don't @@ -539,7 +627,7 @@ getCfg weights graph = map (\x -> ((bid,x),mkEdgeInfo 0)) $ G.successors other where bid = G.entryLabel block - mkEdgeInfo = EdgeInfo (CmmSource branch) . fromIntegral + mkEdgeInfo = EdgeInfo (CmmSource branch NoInfo) . fromIntegral mkEdge target weight = ((bid,target), mkEdgeInfo weight) branch = lastNode block :: CmmNode O C @@ -561,6 +649,11 @@ findBackEdges root cfg = optimizeCFG :: D.CfgWeights -> RawCmmDecl -> CFG -> CFG optimizeCFG _ (CmmData {}) cfg = cfg optimizeCFG weights (CmmProc info _lab _live graph) cfg = + {-# SCC optimizeCFG #-} + -- pprTrace "Initial:" (pprEdgeWeights cfg) $ + -- pprTrace "Initial:" (ppr $ mkGlobalWeights (g_entry graph) cfg) $ + + -- pprTrace "LoopInfo:" (ppr $ loopInfo cfg (g_entry graph)) $ favourFewerPreds . penalizeInfoTables info . increaseBackEdgeWeight (g_entry graph) $ cfg @@ -590,12 +683,8 @@ optimizeCFG weights (CmmProc info _lab _live graph) cfg = = weight - (fromIntegral $ D.infoTablePenalty weights) | otherwise = weight - -{- Note [Optimize for Fallthrough] - --} -- | If a block has two successors, favour the one with fewer - -- predecessors. (As that one is more likely to become a fallthrough) + -- predecessors and/or the one allowing fall through. favourFewerPreds :: CFG -> CFG favourFewerPreds cfg = let @@ -612,16 +701,17 @@ optimizeCFG weights (CmmProc info _lab _live graph) cfg = | preds1 == preds2 = ( 0, 0) | otherwise = (-1, 1) + update :: CFG -> BlockId -> CFG update cfg node | [(s1,e1),(s2,e2)] <- getSuccessorEdges cfg node - , w1 <- edgeWeight e1 - , w2 <- edgeWeight e2 + , !w1 <- edgeWeight e1 + , !w2 <- edgeWeight e2 --Only change the weights if there isn't already a ordering. , w1 == w2 , (mod1,mod2) <- modifiers (predCount s1) (predCount s2) = (\cfg' -> (adjustEdgeWeight cfg' (+mod2) node s2)) - (adjustEdgeWeight cfg (+mod1) node s1) + (adjustEdgeWeight cfg (+mod1) node s1) | otherwise = cfg in setFoldl update cfg nodes @@ -630,13 +720,12 @@ optimizeCFG weights (CmmProc info _lab _live graph) cfg = fallthroughTarget to (EdgeInfo source _weight) | mapMember to info = False | AsmCodeGen <- source = True - | CmmSource (CmmBranch {}) <- source = True - | CmmSource (CmmCondBranch {}) <- source = True + | CmmSource { trans_cmmNode = CmmBranch {} } <- source = True + | CmmSource { trans_cmmNode = CmmCondBranch {} } <- source = True | otherwise = False -- | Determine loop membership of blocks based on SCC analysis --- Ideally we would replace this with a variant giving us loop --- levels instead but the SCC code will do for now. +-- This is faster but only gives yes/no answers. loopMembers :: CFG -> LabelMap Bool loopMembers cfg = foldl' (flip setLevel) mapEmpty sccs @@ -650,3 +739,534 @@ loopMembers cfg = setLevel :: SCC BlockId -> LabelMap Bool -> LabelMap Bool setLevel (AcyclicSCC bid) m = mapInsert bid False m setLevel (CyclicSCC bids) m = foldl' (\m k -> mapInsert k True m) m bids + +loopLevels :: CFG -> BlockId -> LabelMap Int +loopLevels cfg root = liLevels $ loopInfo cfg root + +data LoopInfo = LoopInfo + { liBackEdges :: [(Edge)] -- ^ List of back edges + , liLevels :: LabelMap Int -- ^ BlockId -> LoopLevel mapping + , liLoops :: [(Edge, LabelSet)] -- ^ (backEdge, loopBody), body includes header + } + +instance Outputable LoopInfo where + ppr (LoopInfo _ _lvls loops) = + text "Loops:(backEdge, bodyNodes)" $$ + (vcat $ map ppr loops) + +-- | Determine loop membership of blocks based on Dominator analysis. +-- This is slower but gives loop levels instead of just loop membership. +-- However it only detects natural loops. Irreducible control flow is not +-- recognized even if it loops. But that is rare enough that we don't have +-- to care about that special case. +loopInfo :: CFG -> BlockId -> LoopInfo +loopInfo cfg root = LoopInfo { liBackEdges = backEdges + , liLevels = mapFromList loopCounts + , liLoops = loopBodies } + where + revCfg = reverseEdges cfg + graph = fmap (setFromList . mapKeys ) cfg :: LabelMap LabelSet + + --TODO - This should be a no op: Export constructors? Use unsafeCoerce? ... + rooted = ( fromBlockId root + , toIntMap $ fmap toIntSet graph) :: (Int, IntMap IntSet) + -- rooted = unsafeCoerce (root, graph) + tree = fmap toBlockId $ Dom.domTree rooted :: Tree BlockId + + -- Map from Nodes to their dominators + domMap :: LabelMap LabelSet + domMap = mkDomMap tree + + edges = edgeList cfg :: [(BlockId, BlockId)] + -- We can't recompute this from the edges, there might be blocks not connected via edges. + nodes = getCfgNodes cfg :: LabelSet + + -- identify back edges + isBackEdge (from,to) + | Just doms <- mapLookup from domMap + , setMember to doms + = True + | otherwise = False + + -- determine the loop body for a back edge + findBody edge@(tail, head) + = ( edge, setInsert head $ go (setSingleton tail) (setSingleton tail) ) + where + -- The reversed cfg makes it easier to look up predecessors + cfg' = delNode head revCfg + go :: LabelSet -> LabelSet -> LabelSet + go found current + | setNull current = found + | otherwise = go (setUnion newSuccessors found) + newSuccessors + where + newSuccessors = setFilter (\n -> not $ setMember n found) successors :: LabelSet + successors = setFromList $ concatMap + (getSuccessors cfg') + (setElems current) :: LabelSet + + backEdges = filter isBackEdge edges + loopBodies = map findBody backEdges :: [(Edge, LabelSet)] + + -- Block b is part of n loop bodies => loop nest level of n + loopCounts = + let bodies = map (first snd) loopBodies -- [(Header, Body)] + loopCount n = length $ nub . map fst . filter (setMember n . snd) $ bodies + in map (\n -> (n, loopCount n)) $ setElems nodes :: [(BlockId, Int)] + + toIntSet :: LabelSet -> IntSet + toIntSet s = IS.fromList . map fromBlockId . setElems $ s + toIntMap :: LabelMap a -> IntMap a + toIntMap m = IM.fromList $ map (\(x,y) -> (fromBlockId x,y)) $ mapToList m + + mkDomMap :: Tree BlockId -> LabelMap LabelSet + mkDomMap root = mapFromList $ go setEmpty root + where + go :: LabelSet -> Tree BlockId -> [(Label,LabelSet)] + go parents (Node lbl []) + = [(lbl, parents)] + go parents (Node _ leaves) + = let nodes = map rootLabel leaves + entries = map (\x -> (x,parents)) nodes + in entries ++ concatMap + (\n -> go (setInsert (rootLabel n) parents) n) + leaves + + fromBlockId :: BlockId -> Int + fromBlockId = getKey . getUnique + + toBlockId :: Int -> BlockId + toBlockId = mkBlockId . mkUniqueGrimily + +-- We make the CFG a Hoopl Graph, so we can reuse revPostOrder. +newtype BlockNode (e :: Extensibility) (x :: Extensibility) = BN (BlockId,[BlockId]) + +instance G.NonLocal (BlockNode) where + entryLabel (BN (lbl,_)) = lbl + successors (BN (_,succs)) = succs + +revPostorderFrom :: CFG -> BlockId -> [BlockId] +revPostorderFrom cfg root = + map fromNode $ G.revPostorderFrom hooplGraph root + where + nodes = getCfgNodes cfg + hooplGraph = setFoldl (\m n -> mapInsert n (toNode n) m) mapEmpty nodes + + fromNode :: BlockNode C C -> BlockId + fromNode (BN x) = fst x + + toNode :: BlockId -> BlockNode C C + toNode bid = + BN (bid,getSuccessors cfg $ bid) + + +-- | We take in a CFG which has on its edges weights which are +-- relative only to other edges originating from the same node. +-- +-- We return a CFG for which each edge represents a GLOBAL weight. +-- This means edge weights are comparable across the whole graph. +-- +-- For irreducible control flow results might be imprecise, otherwise they +-- are reliable. +-- +-- The algorithm is based on the Paper +-- "Static Branch Prediction and Program Profile Analysis" by Y Wu, JR Larus +-- The only big change is that we go over the nodes in the body of loops in +-- reverse post order. Which is required for diamond control flow to work probably. +-- +-- We also apply a few prediction heuristics (based on the same paper) + +{-# SCC mkGlobalWeights #-} +mkGlobalWeights :: BlockId -> CFG -> (LabelMap Double, LabelMap (LabelMap Double)) +mkGlobalWeights root localCfg + | null localCfg = panic "Error - Empty CFG" + | otherwise + = --pprTrace "revOrder" (ppr revOrder) $ + -- undefined --propagate (mapSingleton root 1) (revOrder) + (blockFreqs', edgeFreqs') + where + -- Calculate fixpoints + (blockFreqs, edgeFreqs) = calcFreqs nodeProbs backEdges' bodies' revOrder' + blockFreqs' = mapFromList $ map (first fromVertex) (assocs blockFreqs) :: LabelMap Double + edgeFreqs' = fmap fromVertexMap $ fromVertexMap edgeFreqs + + fromVertexMap :: IM.IntMap x -> LabelMap x + fromVertexMap m = mapFromList . map (first fromVertex) $ IM.toList m + + revOrder = revPostorderFrom localCfg root :: [BlockId] + loopinfo@(LoopInfo backedges _levels bodies) = loopInfo localCfg root + + revOrder' = map toVertex revOrder + backEdges' = map (bimap toVertex toVertex) backedges + bodies' = map calcBody bodies + + estimatedCfg = staticBranchPrediction root loopinfo localCfg + -- Normalize the weights to probabilities and apply heuristics + nodeProbs = cfgEdgeProbabilities estimatedCfg toVertex + + -- By mapping vertices to numbers in reverse post order we can bring any subset into reverse post + -- order simply by sorting. + -- TODO: The sort is redundant if we can guarantee that setElems returns elements ascending + calcBody (backedge, blocks) = + (toVertex $ snd backedge, sort . map toVertex $ (setElems blocks)) + + vertexMapping = mapFromList $ zip revOrder [0..] :: LabelMap Int + blockMapping = listArray (0,mapSize vertexMapping - 1) revOrder :: Array Int BlockId + -- Map from blockId to indicies starting at zero + toVertex :: BlockId -> Int + toVertex blockId = expectJust "mkGlobalWeights" $ mapLookup blockId vertexMapping + -- Map from indicies starting at zero to blockIds + fromVertex :: Int -> BlockId + fromVertex vertex = blockMapping ! vertex + +{- Note [Static Branch Prediction] + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The work here has been based on the paper +"Static Branch Prediction and Program Profile Analysis" by Y Wu, JR Larus. + +The primary differences are that if we branch on the result of a heap +check we do not apply any of the heuristics. +The reason is simple: They look like loops in the control flow graph +but are usually never entered, and if at most once. + +Currently implemented is a heuristic to predict that we do not exit +loops (lehPredicts) and one to predict that backedges are more likely +than any other edge. + +The back edge case is special as it superceeds any other heuristic if it +applies. + +Do NOT rely solely on nofib results for benchmarking this. I recommend at least +comparing megaparsec and container benchmarks. Nofib does not seeem to have +many instances of "loopy" Cmm where these make a difference. + +TODO: +* The paper containers more benchmarks which should be implemented. +* If we turn the likelyhood on if/else branches into a probability + instead of true/false we could implement this as a Cmm pass. + + The complete Cmm code still exists and can be accessed by the heuristics + + There is no chance of register allocation/codegen inserting branches/blocks + + making the TransitionSource info wrong. + + potential to use this information in CmmPasses. + - Requires refactoring of all the code relying on the binary nature of likelyhood. + - Requires refactoring `loopInfo` to work on both, Cmm Graphs and the backend CFG. +-} + +-- | Combination of target node id and information about the branch +-- we are looking at. +type TargetNodeInfo = (BlockId, EdgeInfo) + + +-- | Update branch weights based on certain heuristics. +-- See Note [Static Branch Prediction] +-- TODO: This should be combined with optimizeCFG +{-# SCC staticBranchPrediction #-} +staticBranchPrediction :: BlockId -> LoopInfo -> CFG -> CFG +staticBranchPrediction _root (LoopInfo l_backEdges loopLevels l_loops) cfg = + -- pprTrace "staticEstimatesOn" (ppr (cfg)) $ + setFoldl update cfg nodes + where + nodes = getCfgNodes cfg + backedges = S.fromList $ l_backEdges + -- Loops keyed by their back edge + loops = M.fromList $ l_loops :: M.Map Edge LabelSet + loopHeads = S.fromList $ map snd $ M.keys loops + + update :: CFG -> BlockId -> CFG + update cfg node + -- No successors, nothing to do. + | null successors = cfg + + -- Mix of backedges and others: + -- Always predict the backedges. + | not (null m) && length m < length successors + -- Heap/Stack checks "loop", but only once. + -- So we simply exclude any case involving them. + , not $ any (isHeapOrStackCheck . transitionSource . snd) successors + = let loopChance = repeat $! pred_LBH / (fromIntegral $ length m) + exitChance = repeat $! (1 - pred_LBH) / fromIntegral (length not_m) + updates = zip (map fst m) loopChance ++ zip (map fst not_m) exitChance + in -- pprTrace "mix" (ppr (node,successors)) $ + foldl' (\cfg (to,weight) -> setEdgeWeight cfg weight node to) cfg updates + + -- For (regular) non-binary branches we keep the weights from the STG -> Cmm translation. + | length successors /= 2 + = cfg + + -- Only backedges - no need to adjust + | length m > 0 + = cfg + + -- A regular binary branch, we can plug addition predictors in here. + | [(s1,s1_info),(s2,s2_info)] <- successors + , not $ any (isHeapOrStackCheck . transitionSource . snd) successors + = -- Normalize weights to total of 1 + let !w1 = max (edgeWeight s1_info) (0) + !w2 = max (edgeWeight s2_info) (0) + -- Of both weights are <= 0 we set both to 0.5 + normalizeWeight w = if w1 + w2 == 0 then 0.5 else w/(w1+w2) + !cfg' = setEdgeWeight cfg (normalizeWeight w1) node s1 + !cfg'' = setEdgeWeight cfg' (normalizeWeight w2) node s2 + + -- Figure out which heuristics apply to these successors + heuristics = map ($ ((s1,s1_info),(s2,s2_info))) + [lehPredicts, phPredicts, ohPredicts, ghPredicts, lhhPredicts, chPredicts + , shPredicts, rhPredicts] + -- Apply result of a heuristic. Argument is the likelyhood + -- predicted for s1. + applyHeuristic :: CFG -> Maybe Prob -> CFG + applyHeuristic cfg Nothing = cfg + applyHeuristic cfg (Just (s1_pred :: Double)) + | s1_old == 0 || s2_old == 0 || + isHeapOrStackCheck (transitionSource s1_info) || + isHeapOrStackCheck (transitionSource s2_info) + = cfg + | otherwise = + let -- Predictions from heuristic + s1_prob = EdgeWeight s1_pred :: EdgeWeight + s2_prob = 1.0 - s1_prob + -- Update + d = (s1_old * s1_prob) + (s2_old * s2_prob) :: EdgeWeight + s1_prob' = s1_old * s1_prob / d + !s2_prob' = s2_old * s2_prob / d + !cfg_s1 = setEdgeWeight cfg s1_prob' node s1 + in -- pprTrace "Applying heuristic!" (ppr (node,s1,s2) $$ ppr (s1_prob', s2_prob')) $ + setEdgeWeight cfg_s1 s2_prob' node s2 + where + -- Old weights + s1_old = getEdgeWeight cfg node s1 + s2_old = getEdgeWeight cfg node s2 + + in + -- pprTraceIt "RegularCfgResult" $ + foldl' applyHeuristic cfg'' heuristics + + -- Branch on heap/stack check + | otherwise = cfg + + where + -- Chance that loops are taken. + pred_LBH = 0.875 + -- successors + successors = getSuccessorEdges cfg node + -- backedges + (m,not_m) = partition (\succ -> S.member (node, fst succ) backedges) successors + + -- Heuristics return nothing if they don't say anything about this branch + -- or Just (prob_s1) where prob_s1 is the likelyhood for s1 to be the + -- taken branch. s1 is the branch in the true case. + + -- Loop exit heuristic. + -- We are unlikely to leave a loop unless it's to enter another one. + pred_LEH = 0.75 + -- If and only if no successor is a loopheader, + -- then we will likely not exit the current loop body. + lehPredicts :: (TargetNodeInfo,TargetNodeInfo) -> Maybe Prob + lehPredicts ((s1,_s1_info),(s2,_s2_info)) + | S.member s1 loopHeads || S.member s2 loopHeads + = Nothing + + | otherwise + = --pprTrace "lehPredict:" (ppr $ compare s1Level s2Level) $ + case compare s1Level s2Level of + EQ -> Nothing + LT -> Just (1-pred_LEH) --s1 exits to a shallower loop level (exits loop) + GT -> Just (pred_LEH) --s1 exits to a deeper loop level + where + s1Level = mapLookup s1 loopLevels + s2Level = mapLookup s2 loopLevels + + -- Comparing to a constant is unlikely to be equal. + ohPredicts (s1,_s2) + | CmmSource { trans_cmmNode = src1 } <- getTransitionSource node (fst s1) cfg + , CmmCondBranch cond ltrue _lfalse likely <- src1 + , likely == Nothing + , CmmMachOp mop args <- cond + , MO_Eq {} <- mop + , not (null [x | x@CmmLit{} <- args]) + = if fst s1 == ltrue then Just 0.3 else Just 0.7 + + | otherwise + = Nothing + + -- TODO: These are all the other heuristics from the paper. + -- Not all will apply, for now we just stub them out as Nothing. + phPredicts = const Nothing + ghPredicts = const Nothing + lhhPredicts = const Nothing + chPredicts = const Nothing + shPredicts = const Nothing + rhPredicts = const Nothing + +-- We normalize all edge weights as probabilities between 0 and 1. +-- Ignoring rounding errors all outgoing edges sum up to 1. +cfgEdgeProbabilities :: CFG -> (BlockId -> Int) -> IM.IntMap (IM.IntMap Prob) +cfgEdgeProbabilities cfg toVertex + = mapFoldlWithKey foldEdges IM.empty cfg + where + foldEdges = (\m from toMap -> IM.insert (toVertex from) (normalize toMap) m) + + normalize :: (LabelMap EdgeInfo) -> (IM.IntMap Prob) + normalize weightMap + | edgeCount <= 1 = mapFoldlWithKey (\m k _ -> IM.insert (toVertex k) 1.0 m) IM.empty weightMap + | otherwise = mapFoldlWithKey (\m k _ -> IM.insert (toVertex k) (normalWeight k) m) IM.empty weightMap + where + edgeCount = mapSize weightMap + -- Negative weights are generally allowed but are mapped to zero. + -- We then check if there is at least one non-zero edge and if not + -- assign uniform weights to all branches. + minWeight = 0 :: Prob + weightMap' = fmap (\w -> max (weightToDouble . edgeWeight $ w) minWeight) weightMap + totalWeight = sum weightMap' + + normalWeight :: BlockId -> Prob + normalWeight bid + | totalWeight == 0 + = 1.0 / fromIntegral edgeCount + | Just w <- mapLookup bid weightMap' + = w/totalWeight + | otherwise = panic "impossible" + +-- This is the fixpoint algorithm from +-- "Static Branch Prediction and Program Profile Analysis" by Y Wu, JR Larus +-- The adaption to Haskell is my own. +calcFreqs :: IM.IntMap (IM.IntMap Prob) -> [(Int,Int)] -> [(Int, [Int])] -> [Int] + -> (Array Int Double, IM.IntMap (IM.IntMap Prob)) +calcFreqs graph backEdges loops revPostOrder = runST $ do + visitedNodes <- newArray (0,nodeCount-1) False :: ST s (STUArray s Int Bool) + blockFreqs <- newArray (0,nodeCount-1) 0.0 :: ST s (STUArray s Int Double) + edgeProbs <- newSTRef graph + edgeBackProbs <- newSTRef graph + + -- let traceArray a = do + -- vs <- forM [0..nodeCount-1] $ \i -> readArray a i >>= (\v -> return (i,v)) + -- trace ("array: " ++ show vs) $ return () + + let -- See #1600, we need to inline or unboxing makes perf worse. + -- {-# INLINE getFreq #-} + {-# INLINE visited #-} + visited b = unsafeRead visitedNodes b + getFreq b = unsafeRead blockFreqs b + -- setFreq :: forall s. Int -> Double -> ST s () + setFreq b f = unsafeWrite blockFreqs b f + -- setVisited :: forall s. Node -> ST s () + setVisited b = unsafeWrite visitedNodes b True + -- Frequency/probability that edge is taken. + getProb' arr b1 b2 = readSTRef arr >>= + (\graph -> + return . + fromMaybe (error "getFreq 1") . + IM.lookup b2 . + fromMaybe (error "getFreq 2") $ + (IM.lookup b1 graph) + ) + setProb' arr b1 b2 prob = do + g <- readSTRef arr + let !m = fromMaybe (error "Foo") $ IM.lookup b1 g + !m' = IM.insert b2 prob m + writeSTRef arr $! (IM.insert b1 m' g) + + getEdgeFreq b1 b2 = getProb' edgeProbs b1 b2 + setEdgeFreq b1 b2 = setProb' edgeProbs b1 b2 + getProb b1 b2 = fromMaybe (error "getProb") $ do + m' <- IM.lookup b1 graph + IM.lookup b2 m' + + getBackProb b1 b2 = getProb' edgeBackProbs b1 b2 + setBackProb b1 b2 = setProb' edgeBackProbs b1 b2 + + + let -- calcOutFreqs :: Node -> ST s () + calcOutFreqs bhead block = do + !f <- getFreq block + forM (successors block) $ \bi -> do + let !prob = getProb block bi + let !succFreq = f * prob + setEdgeFreq block bi succFreq + -- traceM $ "SetOut: " ++ show (block, bi, f, prob, succFreq) + when (bi == bhead) $ setBackProb block bi succFreq + + + let propFreq block head = do + -- traceM ("prop:" ++ show (block,head)) + -- traceShowM block + + !v <- visited block + if v then + return () --Dont look at nodes twice + else if block == head then + setFreq block 1.0 -- Loop header frequency is always 1 + else do + let preds = IS.elems $ predecessors block + irreducible <- (fmap or) $ forM preds $ \bp -> do + !bp_visited <- visited bp + let bp_backedge = isBackEdge bp block + return (not bp_visited && not bp_backedge) + + if irreducible + then return () -- Rare we don't care + else do + setFreq block 0 + !cycleProb <- sum <$> (forM preds $ \pred -> do + if isBackEdge pred block + then + getBackProb pred block + else do + !f <- getFreq block + !prob <- getEdgeFreq pred block + setFreq block $! f + prob + return 0) + -- traceM $ "cycleProb:" ++ show cycleProb + let limit = 1 - 1/512 -- Paper uses 1 - epsilon, but this works. + -- determines how large likelyhoods in loops can grow. + !cycleProb <- return $ min cycleProb limit -- <- return $ if cycleProb > limit then limit else cycleProb + -- traceM $ "cycleProb:" ++ show cycleProb + + !f <- getFreq block + setFreq block (f / (1.0 - cycleProb)) + + setVisited block + calcOutFreqs head block + + -- Loops, by nesting, inner to outer + forM_ loops $ \(head, body) -> do + forM_ [0 .. nodeCount - 1] (\i -> unsafeWrite visitedNodes i True) -- Mark all nodes as visited. + forM_ body (\i -> unsafeWrite visitedNodes i False) -- Mark all blocks reachable from head as not visited + forM_ body $ \block -> propFreq block head + + -- After dealing with all loops, deal with non-looping parts of the CFG + forM_ [0 .. nodeCount - 1] (\i -> unsafeWrite visitedNodes i False) -- Everything in revPostOrder is reachable + forM_ revPostOrder $ \block -> propFreq block (head revPostOrder) + + -- trace ("Final freqs:") $ return () + -- let freqString = pprFreqs freqs + -- trace (unlines freqString) $ return () + -- trace (pprFre) $ return () + graph' <- readSTRef edgeProbs + freqs' <- unsafeFreeze blockFreqs + + return (freqs', graph') + where + predecessors :: Int -> IS.IntSet + predecessors b = fromMaybe IS.empty $ IM.lookup b revGraph + successors b = fromMaybe (lookupError "succ" b graph)$ IM.keys <$> IM.lookup b graph + lookupError s b g = pprPanic ("Lookup error " ++ s) $ + ( text "node" <+> ppr b $$ + text "graph" <+> + vcat (map (\(k,m) -> ppr (k,m :: IM.IntMap Double)) $ IM.toList g) + ) + + nodeCount = IM.foldl' (\count toMap -> IM.foldlWithKey' countTargets count toMap) (IM.size graph) graph + where + countTargets = (\count k _ -> countNode k + count ) + countNode n = if IM.member n graph then 0 else 1 + + isBackEdge from to = S.member (from,to) backEdgeSet + backEdgeSet = S.fromList backEdges + + revGraph :: IntMap IntSet + revGraph = IM.foldlWithKey' (\m from toMap -> addEdges m from toMap) IM.empty graph + where + addEdges m0 from toMap = IM.foldlWithKey' (\m k _ -> addEdge m from k) m0 toMap + addEdge m0 from to = IM.insertWith IS.union to (IS.singleton from) m0 diff --git a/compiler/nativeGen/RegAlloc/Graph/SpillCost.hs b/compiler/nativeGen/RegAlloc/Graph/SpillCost.hs index 9c6e24d320..52f590948a 100644 --- a/compiler/nativeGen/RegAlloc/Graph/SpillCost.hs +++ b/compiler/nativeGen/RegAlloc/Graph/SpillCost.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE ScopedTypeVariables, GADTs, BangPatterns #-} module RegAlloc.Graph.SpillCost ( SpillCostRecord, plusSpillCostRecord, @@ -23,6 +23,7 @@ import Reg import GraphBase import Hoopl.Collections (mapLookup) +import Hoopl.Label import Cmm import UniqFM import UniqSet @@ -49,9 +50,6 @@ type SpillCostRecord type SpillCostInfo = UniqFM SpillCostRecord --- | Block membership in a loop -type LoopMember = Bool - type SpillCostState = State (UniqFM SpillCostRecord) () -- | An empty map of spill costs. @@ -88,45 +86,49 @@ slurpSpillCostInfo platform cfg cmm where countCmm CmmData{} = return () countCmm (CmmProc info _ _ sccs) - = mapM_ (countBlock info) + = mapM_ (countBlock info freqMap) $ flattenSCCs sccs + where + LiveInfo _ entries _ _ = info + freqMap = (fst . mkGlobalWeights (head entries)) <$> cfg -- Lookup the regs that are live on entry to this block in -- the info table from the CmmProc. - countBlock info (BasicBlock blockId instrs) + countBlock info freqMap (BasicBlock blockId instrs) | LiveInfo _ _ blockLive _ <- info , Just rsLiveEntry <- mapLookup blockId blockLive , rsLiveEntry_virt <- takeVirtuals rsLiveEntry - = countLIs (loopMember blockId) rsLiveEntry_virt instrs + = countLIs (ceiling $ blockFreq freqMap blockId) rsLiveEntry_virt instrs | otherwise = error "RegAlloc.SpillCost.slurpSpillCostInfo: bad block" - countLIs :: LoopMember -> UniqSet VirtualReg -> [LiveInstr instr] -> SpillCostState + + countLIs :: Int -> UniqSet VirtualReg -> [LiveInstr instr] -> SpillCostState countLIs _ _ [] = return () -- Skip over comment and delta pseudo instrs. - countLIs inLoop rsLive (LiveInstr instr Nothing : lis) + countLIs scale rsLive (LiveInstr instr Nothing : lis) | isMetaInstr instr - = countLIs inLoop rsLive lis + = countLIs scale rsLive lis | otherwise = pprPanic "RegSpillCost.slurpSpillCostInfo" $ text "no liveness information on instruction " <> ppr instr - countLIs inLoop rsLiveEntry (LiveInstr instr (Just live) : lis) + countLIs scale rsLiveEntry (LiveInstr instr (Just live) : lis) = do -- Increment the lifetime counts for regs live on entry to this instr. - mapM_ (incLifetime (loopCount inLoop)) $ nonDetEltsUniqSet rsLiveEntry + mapM_ incLifetime $ nonDetEltsUniqSet rsLiveEntry -- This is non-deterministic but we do not -- currently support deterministic code-generation. -- See Note [Unique Determinism and code generation] -- Increment counts for what regs were read/written from. let (RU read written) = regUsageOfInstr platform instr - mapM_ (incUses (loopCount inLoop)) $ catMaybes $ map takeVirtualReg $ nub read - mapM_ (incDefs (loopCount inLoop)) $ catMaybes $ map takeVirtualReg $ nub written + mapM_ (incUses scale) $ catMaybes $ map takeVirtualReg $ nub read + mapM_ (incDefs scale) $ catMaybes $ map takeVirtualReg $ nub written -- Compute liveness for entry to next instruction. let liveDieRead_virt = takeVirtuals (liveDieRead live) @@ -140,21 +142,18 @@ slurpSpillCostInfo platform cfg cmm = (rsLiveAcross `unionUniqSets` liveBorn_virt) `minusUniqSet` liveDieWrite_virt - countLIs inLoop rsLiveNext lis + countLIs scale rsLiveNext lis - loopCount inLoop - | inLoop = 10 - | otherwise = 1 incDefs count reg = modify $ \s -> addToUFM_C plusSpillCostRecord s reg (reg, count, 0, 0) incUses count reg = modify $ \s -> addToUFM_C plusSpillCostRecord s reg (reg, 0, count, 0) - incLifetime count reg = modify $ \s -> addToUFM_C plusSpillCostRecord s reg (reg, 0, 0, count) + incLifetime reg = modify $ \s -> addToUFM_C plusSpillCostRecord s reg (reg, 0, 0, 1) - loopBlocks = CFG.loopMembers <$> cfg - loopMember bid - | Just isMember <- join (mapLookup bid <$> loopBlocks) - = isMember + blockFreq :: Maybe (LabelMap Double) -> Label -> Double + blockFreq freqs bid + | Just freq <- join (mapLookup bid <$> freqs) + = max 1.0 (10000 * freq) | otherwise - = False + = 1.0 -- Only if no cfg given -- | Take all the virtual registers from this set. takeVirtuals :: UniqSet Reg -> UniqSet VirtualReg @@ -215,31 +214,39 @@ chooseSpill info graph -- Without live range splitting, its's better to spill from the outside -- in so set the cost of very long live ranges to zero -- -{- -spillCost_chaitin - :: SpillCostInfo - -> Graph Reg RegClass Reg - -> Reg - -> Float -spillCost_chaitin info graph reg - -- Spilling a live range that only lives for 1 instruction - -- isn't going to help us at all - and we definitely want to avoid - -- trying to re-spill previously inserted spill code. - | lifetime <= 1 = 1/0 - - -- It's unlikely that we'll find a reg for a live range this long - -- better to spill it straight up and not risk trying to keep it around - -- and have to go through the build/color cycle again. - | lifetime > allocatableRegsInClass (regClass reg) * 10 - = 0 +-- spillCost_chaitin +-- :: SpillCostInfo +-- -> Graph VirtualReg RegClass RealReg +-- -> VirtualReg +-- -> Float + +-- spillCost_chaitin info graph reg +-- -- Spilling a live range that only lives for 1 instruction +-- -- isn't going to help us at all - and we definitely want to avoid +-- -- trying to re-spill previously inserted spill code. +-- | lifetime <= 1 = 1/0 + +-- -- It's unlikely that we'll find a reg for a live range this long +-- -- better to spill it straight up and not risk trying to keep it around +-- -- and have to go through the build/color cycle again. + +-- -- To facility this we scale down the spill cost of long ranges. +-- -- This makes sure long ranges are still spilled first. +-- -- But this way spill cost remains relevant for long live +-- -- ranges. +-- | lifetime >= 128 +-- = (spillCost / conflicts) / 10.0 + + +-- -- Otherwise revert to chaitin's regular cost function. +-- | otherwise = (spillCost / conflicts) +-- where +-- !spillCost = fromIntegral (uses + defs) :: Float +-- conflicts = fromIntegral (nodeDegree classOfVirtualReg graph reg) +-- (_, defs, uses, lifetime) +-- = fromMaybe (reg, 0, 0, 0) $ lookupUFM info reg - -- Otherwise revert to chaitin's regular cost function. - | otherwise = fromIntegral (uses + defs) - / fromIntegral (nodeDegree graph reg) - where (_, defs, uses, lifetime) - = fromMaybe (reg, 0, 0, 0) $ lookupUFM info reg --} -- Just spill the longest live range. spillCost_length diff --git a/compiler/nativeGen/X86/CodeGen.hs b/compiler/nativeGen/X86/CodeGen.hs index 7a2d59993b..b1dd9c58ad 100644 --- a/compiler/nativeGen/X86/CodeGen.hs +++ b/compiler/nativeGen/X86/CodeGen.hs @@ -3529,7 +3529,7 @@ invertCondBranches (Just cfg) keep bs = , Just edgeInfo2 <- getEdgeInfo lbl1 target2 cfg -- Both jumps come from the same cmm statement , transitionSource edgeInfo1 == transitionSource edgeInfo2 - , (CmmSource cmmCondBranch) <- transitionSource edgeInfo1 + , CmmSource {trans_cmmNode = cmmCondBranch} <- transitionSource edgeInfo1 --Int comparisons are invertable , CmmCondBranch (CmmMachOp op _args) _ _ _ <- cmmCondBranch diff --git a/compiler/utils/Dominators.hs b/compiler/utils/Dominators.hs new file mode 100644 index 0000000000..9877c2c1f0 --- /dev/null +++ b/compiler/utils/Dominators.hs @@ -0,0 +1,588 @@ +{-# LANGUAGE RankNTypes, BangPatterns, FlexibleContexts, Strict #-} + +{- | + Module : Dominators + Copyright : (c) Matt Morrow 2009 + License : BSD3 + Maintainer : + Stability : experimental + Portability : portable + + Taken from the dom-lt package. + + The Lengauer-Tarjan graph dominators algorithm. + + \[1\] Lengauer, Tarjan, + /A Fast Algorithm for Finding Dominators in a Flowgraph/, 1979. + + \[2\] Muchnick, + /Advanced Compiler Design and Implementation/, 1997. + + \[3\] Brisk, Sarrafzadeh, + /Interference Graphs for Procedures in Static Single/ + /Information Form are Interval Graphs/, 2007. + + Originally taken from the dom-lt package. +-} + +module Dominators ( + Node,Path,Edge + ,Graph,Rooted + ,idom,ipdom + ,domTree,pdomTree + ,dom,pdom + ,pddfs,rpddfs + ,fromAdj,fromEdges + ,toAdj,toEdges + ,asTree,asGraph + ,parents,ancestors +) where + +import GhcPrelude + +import Data.Bifunctor +import Data.Tuple (swap) + +import Data.Tree +import Data.IntMap(IntMap) +import Data.IntSet(IntSet) +import qualified Data.IntMap.Strict as IM +import qualified Data.IntSet as IS + +import Control.Monad +import Control.Monad.ST.Strict + +import Data.Array.ST +import Data.Array.Base + (unsafeNewArray_ + ,unsafeWrite,unsafeRead) + +----------------------------------------------------------------------------- + +type Node = Int +type Path = [Node] +type Edge = (Node,Node) +type Graph = IntMap IntSet +type Rooted = (Node, Graph) + +----------------------------------------------------------------------------- + +-- | /Dominators/. +-- Complexity as for @idom@ +dom :: Rooted -> [(Node, Path)] +dom = ancestors . domTree + +-- | /Post-dominators/. +-- Complexity as for @idom@. +pdom :: Rooted -> [(Node, Path)] +pdom = ancestors . pdomTree + +-- | /Dominator tree/. +-- Complexity as for @idom@. +domTree :: Rooted -> Tree Node +domTree a@(r,_) = + let is = filter ((/=r).fst) (idom a) + tg = fromEdges (fmap swap is) + in asTree (r,tg) + +-- | /Post-dominator tree/. +-- Complexity as for @idom@. +pdomTree :: Rooted -> Tree Node +pdomTree a@(r,_) = + let is = filter ((/=r).fst) (ipdom a) + tg = fromEdges (fmap swap is) + in asTree (r,tg) + +-- | /Immediate dominators/. +-- /O(|E|*alpha(|E|,|V|))/, where /alpha(m,n)/ is +-- \"a functional inverse of Ackermann's function\". +-- +-- This Complexity bound assumes /O(1)/ indexing. Since we're +-- using @IntMap@, it has an additional /lg |V|/ factor +-- somewhere in there. I'm not sure where. +idom :: Rooted -> [(Node,Node)] +idom rg = runST (evalS idomM =<< initEnv (pruneReach rg)) + +-- | /Immediate post-dominators/. +-- Complexity as for @idom@. +ipdom :: Rooted -> [(Node,Node)] +ipdom rg = runST (evalS idomM =<< initEnv (pruneReach (second predG rg))) + +----------------------------------------------------------------------------- + +-- | /Post-dominated depth-first search/. +pddfs :: Rooted -> [Node] +pddfs = reverse . rpddfs + +-- | /Reverse post-dominated depth-first search/. +rpddfs :: Rooted -> [Node] +rpddfs = concat . levels . pdomTree + +----------------------------------------------------------------------------- + +type Dom s a = S s (Env s) a +type NodeSet = IntSet +type NodeMap a = IntMap a +data Env s = Env + {succE :: !Graph + ,predE :: !Graph + ,bucketE :: !Graph + ,dfsE :: {-# UNPACK #-}!Int + ,zeroE :: {-# UNPACK #-}!Node + ,rootE :: {-# UNPACK #-}!Node + ,labelE :: {-# UNPACK #-}!(Arr s Node) + ,parentE :: {-# UNPACK #-}!(Arr s Node) + ,ancestorE :: {-# UNPACK #-}!(Arr s Node) + ,childE :: {-# UNPACK #-}!(Arr s Node) + ,ndfsE :: {-# UNPACK #-}!(Arr s Node) + ,dfnE :: {-# UNPACK #-}!(Arr s Int) + ,sdnoE :: {-# UNPACK #-}!(Arr s Int) + ,sizeE :: {-# UNPACK #-}!(Arr s Int) + ,domE :: {-# UNPACK #-}!(Arr s Node) + ,rnE :: {-# UNPACK #-}!(Arr s Node)} + +----------------------------------------------------------------------------- + +idomM :: Dom s [(Node,Node)] +idomM = do + dfsDom =<< rootM + n <- gets dfsE + forM_ [n,n-1..1] (\i-> do + w <- ndfsM i + sw <- sdnoM w + ps <- predsM w + forM_ ps (\v-> do + u <- eval v + su <- sdnoM u + when (su < sw) + (store sdnoE w su)) + z <- ndfsM =<< sdnoM w + modify(\e->e{bucketE=IM.adjust + (w`IS.insert`) + z (bucketE e)}) + pw <- parentM w + link pw w + bps <- bucketM pw + forM_ bps (\v-> do + u <- eval v + su <- sdnoM u + sv <- sdnoM v + let dv = case su < sv of + True-> u + False-> pw + store domE v dv)) + forM_ [1..n] (\i-> do + w <- ndfsM i + j <- sdnoM w + z <- ndfsM j + dw <- domM w + when (dw /= z) + (do ddw <- domM dw + store domE w ddw)) + fromEnv + +----------------------------------------------------------------------------- + +eval :: Node -> Dom s Node +eval v = do + n0 <- zeroM + a <- ancestorM v + case a==n0 of + True-> labelM v + False-> do + compress v + a <- ancestorM v + l <- labelM v + la <- labelM a + sl <- sdnoM l + sla <- sdnoM la + case sl <= sla of + True-> return l + False-> return la + +compress :: Node -> Dom s () +compress v = do + n0 <- zeroM + a <- ancestorM v + aa <- ancestorM a + when (aa /= n0) (do + compress a + a <- ancestorM v + aa <- ancestorM a + l <- labelM v + la <- labelM a + sl <- sdnoM l + sla <- sdnoM la + when (sla < sl) + (store labelE v la) + store ancestorE v aa) + +----------------------------------------------------------------------------- + +link :: Node -> Node -> Dom s () +link v w = do + n0 <- zeroM + lw <- labelM w + slw <- sdnoM lw + let balance s = do + c <- childM s + lc <- labelM c + slc <- sdnoM lc + case slw < slc of + False-> return s + True-> do + zs <- sizeM s + zc <- sizeM c + cc <- childM c + zcc <- sizeM cc + case 2*zc <= zs+zcc of + True-> do + store ancestorE c s + store childE s cc + balance s + False-> do + store sizeE c zs + store ancestorE s c + balance c + s <- balance w + lw <- labelM w + zw <- sizeM w + store labelE s lw + store sizeE v . (+zw) =<< sizeM v + let follow s = do + when (s /= n0) (do + store ancestorE s v + follow =<< childM s) + zv <- sizeM v + follow =<< case zv < 2*zw of + False-> return s + True-> do + cv <- childM v + store childE v s + return cv + +----------------------------------------------------------------------------- + +dfsDom :: Node -> Dom s () +dfsDom i = do + _ <- go i + n0 <- zeroM + r <- rootM + store parentE r n0 + where go i = do + n <- nextM + store dfnE i n + store sdnoE i n + store ndfsE n i + store labelE i i + ss <- succsM i + forM_ ss (\j-> do + s <- sdnoM j + case s==0 of + False-> return() + True-> do + store parentE j i + go j) + +----------------------------------------------------------------------------- + +initEnv :: Rooted -> ST s (Env s) +initEnv (r0,g0) = do + let (g,rnmap) = renum 1 g0 + pred = predG g + r = rnmap IM.! r0 + n = IM.size g + ns = [0..n] + m = n+1 + + let bucket = IM.fromList + (zip ns (repeat mempty)) + + rna <- newI m + writes rna (fmap swap + (IM.toList rnmap)) + + doms <- newI m + sdno <- newI m + size <- newI m + parent <- newI m + ancestor <- newI m + child <- newI m + label <- newI m + ndfs <- newI m + dfn <- newI m + + forM_ [0..n] (doms.=0) + forM_ [0..n] (sdno.=0) + forM_ [1..n] (size.=1) + forM_ [0..n] (ancestor.=0) + forM_ [0..n] (child.=0) + + (doms.=r) r + (size.=0) 0 + (label.=0) 0 + + return (Env + {rnE = rna + ,dfsE = 0 + ,zeroE = 0 + ,rootE = r + ,labelE = label + ,parentE = parent + ,ancestorE = ancestor + ,childE = child + ,ndfsE = ndfs + ,dfnE = dfn + ,sdnoE = sdno + ,sizeE = size + ,succE = g + ,predE = pred + ,bucketE = bucket + ,domE = doms}) + +fromEnv :: Dom s [(Node,Node)] +fromEnv = do + dom <- gets domE + rn <- gets rnE + -- r <- gets rootE + (_,n) <- st (getBounds dom) + forM [1..n] (\i-> do + j <- st(rn!:i) + d <- st(dom!:i) + k <- st(rn!:d) + return (j,k)) + +----------------------------------------------------------------------------- + +zeroM :: Dom s Node +zeroM = gets zeroE +domM :: Node -> Dom s Node +domM = fetch domE +rootM :: Dom s Node +rootM = gets rootE +succsM :: Node -> Dom s [Node] +succsM i = gets (IS.toList . (!i) . succE) +predsM :: Node -> Dom s [Node] +predsM i = gets (IS.toList . (!i) . predE) +bucketM :: Node -> Dom s [Node] +bucketM i = gets (IS.toList . (!i) . bucketE) +sizeM :: Node -> Dom s Int +sizeM = fetch sizeE +sdnoM :: Node -> Dom s Int +sdnoM = fetch sdnoE +-- dfnM :: Node -> Dom s Int +-- dfnM = fetch dfnE +ndfsM :: Int -> Dom s Node +ndfsM = fetch ndfsE +childM :: Node -> Dom s Node +childM = fetch childE +ancestorM :: Node -> Dom s Node +ancestorM = fetch ancestorE +parentM :: Node -> Dom s Node +parentM = fetch parentE +labelM :: Node -> Dom s Node +labelM = fetch labelE +nextM :: Dom s Int +nextM = do + n <- gets dfsE + let n' = n+1 + modify(\e->e{dfsE=n'}) + return n' + +----------------------------------------------------------------------------- + +type A = STUArray +type Arr s a = A s Int a + +infixl 9 !: +infixr 2 .= + +(.=) :: (MArray (A s) a (ST s)) + => Arr s a -> a -> Int -> ST s () +(v .= x) i = unsafeWrite v i x + +(!:) :: (MArray (A s) a (ST s)) + => A s Int a -> Int -> ST s a +a !: i = do + o <- unsafeRead a i + return $! o + +new :: (MArray (A s) a (ST s)) + => Int -> ST s (Arr s a) +new n = unsafeNewArray_ (0,n-1) + +newI :: Int -> ST s (Arr s Int) +newI = new + +-- newD :: Int -> ST s (Arr s Double) +-- newD = new + +-- dump :: (MArray (A s) a (ST s)) => Arr s a -> ST s [a] +-- dump a = do +-- (m,n) <- getBounds a +-- forM [m..n] (\i -> a!:i) + +writes :: (MArray (A s) a (ST s)) + => Arr s a -> [(Int,a)] -> ST s () +writes a xs = forM_ xs (\(i,x) -> (a.=x) i) + +-- arr :: (MArray (A s) a (ST s)) => [a] -> ST s (Arr s a) +-- arr xs = do +-- let n = length xs +-- a <- new n +-- go a n 0 xs +-- return a +-- where go _ _ _ [] = return () +-- go a n i (x:xs) +-- | i <= n = (a.=x) i >> go a n (i+1) xs +-- | otherwise = return () + +----------------------------------------------------------------------------- + +(!) :: Monoid a => IntMap a -> Int -> a +(!) g n = maybe mempty id (IM.lookup n g) + +fromAdj :: [(Node, [Node])] -> Graph +fromAdj = IM.fromList . fmap (second IS.fromList) + +fromEdges :: [Edge] -> Graph +fromEdges = collectI IS.union fst (IS.singleton . snd) + +toAdj :: Graph -> [(Node, [Node])] +toAdj = fmap (second IS.toList) . IM.toList + +toEdges :: Graph -> [Edge] +toEdges = concatMap (uncurry (fmap . (,))) . toAdj + +predG :: Graph -> Graph +predG g = IM.unionWith IS.union (go g) g0 + where g0 = fmap (const mempty) g + f :: IntMap IntSet -> Int -> IntSet -> IntMap IntSet + f m i a = foldl' (\m p -> IM.insertWith mappend p + (IS.singleton i) m) + m + (IS.toList a) + go :: IntMap IntSet -> IntMap IntSet + go = flip IM.foldlWithKey' mempty f + +pruneReach :: Rooted -> Rooted +pruneReach (r,g) = (r,g2) + where is = reachable + (maybe mempty id + . flip IM.lookup g) $ r + g2 = IM.fromList + . fmap (second (IS.filter (`IS.member`is))) + . filter ((`IS.member`is) . fst) + . IM.toList $ g + +tip :: Tree a -> (a, [Tree a]) +tip (Node a ts) = (a, ts) + +parents :: Tree a -> [(a, a)] +parents (Node i xs) = p i xs + ++ concatMap parents xs + where p i = fmap (flip (,) i . rootLabel) + +ancestors :: Tree a -> [(a, [a])] +ancestors = go [] + where go acc (Node i xs) + = let acc' = i:acc + in p acc' xs ++ concatMap (go acc') xs + p is = fmap (flip (,) is . rootLabel) + +asGraph :: Tree Node -> Rooted +asGraph t@(Node a _) = let g = go t in (a, fromAdj g) + where go (Node a ts) = let as = (fst . unzip . fmap tip) ts + in (a, as) : concatMap go ts + +asTree :: Rooted -> Tree Node +asTree (r,g) = let go a = Node a (fmap go ((IS.toList . f) a)) + f = (g !) + in go r + +reachable :: (Node -> NodeSet) -> (Node -> NodeSet) +reachable f a = go (IS.singleton a) a + where go seen a = let s = f a + as = IS.toList (s `IS.difference` seen) + in foldl' go (s `IS.union` seen) as + +collectI :: (c -> c -> c) + -> (a -> Int) -> (a -> c) -> [a] -> IntMap c +collectI (<>) f g + = foldl' (\m a -> IM.insertWith (<>) + (f a) + (g a) m) mempty + +-- collect :: (Ord b) => (c -> c -> c) +-- -> (a -> b) -> (a -> c) -> [a] -> Map b c +-- collect (<>) f g +-- = foldl' (\m a -> SM.insertWith (<>) +-- (f a) +-- (g a) m) mempty + +-- (renamed, old -> new) +renum :: Int -> Graph -> (Graph, NodeMap Node) +renum from = (\(_,m,g)->(g,m)) + . IM.foldlWithKey' + f (from,mempty,mempty) + where + f :: (Int, NodeMap Node, IntMap IntSet) -> Node -> IntSet + -> (Int, NodeMap Node, IntMap IntSet) + f (!n,!env,!new) i ss = + let (j,n2,env2) = go n env i + (n3,env3,ss2) = IS.fold + (\k (!n,!env,!new)-> + case go n env k of + (l,n2,env2)-> (n2,env2,l `IS.insert` new)) + (n2,env2,mempty) ss + new2 = IM.insertWith IS.union j ss2 new + in (n3,env3,new2) + go :: Int + -> NodeMap Node + -> Node + -> (Node,Int,NodeMap Node) + go !n !env i = + case IM.lookup i env of + Just j -> (j,n,env) + Nothing -> (n,n+1,IM.insert i n env) + +----------------------------------------------------------------------------- + +newtype S z s a = S {unS :: forall o. (a -> s -> ST z o) -> s -> ST z o} +instance Functor (S z s) where + fmap f (S g) = S (\k -> g (k . f)) +instance Monad (S z s) where + return = pure + S g >>= f = S (\k -> g (\a -> unS (f a) k)) +instance Applicative (S z s) where + pure a = S (\k -> k a) + (<*>) = ap +-- get :: S z s s +-- get = S (\k s -> k s s) +gets :: (s -> a) -> S z s a +gets f = S (\k s -> k (f s) s) +-- set :: s -> S z s () +-- set s = S (\k _ -> k () s) +modify :: (s -> s) -> S z s () +modify f = S (\k -> k () . f) +-- runS :: S z s a -> s -> ST z (a, s) +-- runS (S g) = g (\a s -> return (a,s)) +evalS :: S z s a -> s -> ST z a +evalS (S g) = g ((return .) . const) +-- execS :: S z s a -> s -> ST z s +-- execS (S g) = g ((return .) . flip const) +st :: ST z a -> S z s a +st m = S (\k s-> do + a <- m + k a s) +store :: (MArray (A z) a (ST z)) + => (s -> Arr z a) -> Int -> a -> S z s () +store f i x = do + a <- gets f + st ((a.=x) i) +fetch :: (MArray (A z) a (ST z)) + => (s -> Arr z a) -> Int -> S z s a +fetch f i = do + a <- gets f + st (a!:i) + diff --git a/compiler/utils/OrdList.hs b/compiler/utils/OrdList.hs index e8b50e5968..8da5038b2c 100644 --- a/compiler/utils/OrdList.hs +++ b/compiler/utils/OrdList.hs @@ -10,14 +10,18 @@ can be appended in linear time. -} {-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE BangPatterns #-} + module OrdList ( OrdList, nilOL, isNilOL, unitOL, appOL, consOL, snocOL, concatOL, lastOL, headOL, - mapOL, fromOL, toOL, foldrOL, foldlOL, reverseOL, fromOLReverse + mapOL, fromOL, toOL, foldrOL, foldlOL, reverseOL, fromOLReverse, + strictlyEqOL, strictlyOrdOL ) where import GhcPrelude +import Data.Foldable import Outputable @@ -49,7 +53,11 @@ instance Monoid (OrdList a) where mconcat = concatOL instance Foldable OrdList where - foldr = foldrOL + foldr = foldrOL + foldl' = foldlOL + toList = fromOL + null = isNilOL + length = lengthOL instance Traversable OrdList where traverse f xs = toOL <$> traverse f (fromOL xs) @@ -64,7 +72,7 @@ appOL :: OrdList a -> OrdList a -> OrdList a concatOL :: [OrdList a] -> OrdList a headOL :: OrdList a -> a lastOL :: OrdList a -> a - +lengthOL :: OrdList a -> Int nilOL = None unitOL as = One as @@ -86,6 +94,13 @@ lastOL (Cons _ as) = lastOL as lastOL (Snoc _ a) = a lastOL (Two _ as) = lastOL as +lengthOL None = 0 +lengthOL (One _) = 1 +lengthOL (Many as) = length as +lengthOL (Cons _ as) = 1 + length as +lengthOL (Snoc as _) = 1 + length as +lengthOL (Two as bs) = length as + length bs + isNilOL None = True isNilOL _ = False @@ -126,13 +141,14 @@ foldrOL k z (Snoc xs x) = foldrOL k (k x z) xs foldrOL k z (Two b1 b2) = foldrOL k (foldrOL k z b2) b1 foldrOL k z (Many xs) = foldr k z xs +-- | Strict left fold. foldlOL :: (b->a->b) -> b -> OrdList a -> b foldlOL _ z None = z foldlOL k z (One x) = k z x -foldlOL k z (Cons x xs) = foldlOL k (k z x) xs -foldlOL k z (Snoc xs x) = k (foldlOL k z xs) x -foldlOL k z (Two b1 b2) = foldlOL k (foldlOL k z b1) b2 -foldlOL k z (Many xs) = foldl k z xs +foldlOL k z (Cons x xs) = let !z' = (k z x) in foldlOL k z' xs +foldlOL k z (Snoc xs x) = let !z' = (foldlOL k z xs) in k z' x +foldlOL k z (Two b1 b2) = let !z' = (foldlOL k z b1) in foldlOL k z' b2 +foldlOL k z (Many xs) = foldl' k z xs toOL :: [a] -> OrdList a toOL [] = None @@ -146,3 +162,33 @@ reverseOL (Cons a b) = Snoc (reverseOL b) a reverseOL (Snoc a b) = Cons b (reverseOL a) reverseOL (Two a b) = Two (reverseOL b) (reverseOL a) reverseOL (Many xs) = Many (reverse xs) + +-- | Compare not only the values but also the structure of two lists +strictlyEqOL :: Eq a => OrdList a -> OrdList a -> Bool +strictlyEqOL None None = True +strictlyEqOL (One x) (One y) = x == y +strictlyEqOL (Cons a as) (Cons b bs) = a == b && as `strictlyEqOL` bs +strictlyEqOL (Snoc as a) (Snoc bs b) = a == b && as `strictlyEqOL` bs +strictlyEqOL (Two a1 a2) (Two b1 b2) = a1 `strictlyEqOL` b1 && a2 `strictlyEqOL` b2 +strictlyEqOL (Many as) (Many bs) = as == bs +strictlyEqOL _ _ = False + +-- | Compare not only the values but also the structure of two lists +strictlyOrdOL :: Ord a => OrdList a -> OrdList a -> Ordering +strictlyOrdOL None None = EQ +strictlyOrdOL None _ = LT +strictlyOrdOL (One x) (One y) = compare x y +strictlyOrdOL (One _) _ = LT +strictlyOrdOL (Cons a as) (Cons b bs) = + compare a b `mappend` strictlyOrdOL as bs +strictlyOrdOL (Cons _ _) _ = LT +strictlyOrdOL (Snoc as a) (Snoc bs b) = + compare a b `mappend` strictlyOrdOL as bs +strictlyOrdOL (Snoc _ _) _ = LT +strictlyOrdOL (Two a1 a2) (Two b1 b2) = + (strictlyOrdOL a1 b1) `mappend` (strictlyOrdOL a2 b2) +strictlyOrdOL (Two _ _) _ = LT +strictlyOrdOL (Many as) (Many bs) = compare as bs +strictlyOrdOL (Many _ ) _ = GT + + -- cgit v1.2.1