author    klebinger.andreas@gmx.at <klebinger.andreas@gmx.at>  2019-02-18 00:28:39 +0100
committer Ben Gamari <ben@smart-cactus.org>                    2019-10-13 15:07:06 -0400
commit    9a90f27f080123c636a3401b103ea68aa843e34d (patch)
tree      581dabf4a1720457bce175902d8e9b5f6f1a0d93
parent    d584e3f08cfee6e28b70bf53c573d86e44f326f8 (diff)
download  haskell-wip/gc/base.tar.gz

Add loop level analysis to the NCG backend. [wip/gc/base]
For backends maintaining the CFG during codegen we can now find loops
and their nesting level. This is based on the Cmm CFG and dominator
analysis. As a result we can estimate edge frequencies a lot better
for methods, resulting in far better code layout.

Speedup on nofib: ~1.5%
Increase in compile times: ~1.9%

To make this feasible this commit adds:
* Dominator analysis based on the Lengauer-Tarjan algorithm.
* An algorithm estimating global edge frequencies from branch
  probabilities - in CFG.hs.

A few static branch prediction heuristics:
* Expect to take the backedge in loops.
* Expect to take the branch NOT exiting a loop.
* Expect integer vs constant comparisons to be false.

We also treat heap/stack checks specially for branch prediction to
avoid them being treated as loops.

(cherry picked from commit 056aa12d60f34ee90db2527586c82fc6f16eba39)
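To illustrate the branch prediction heuristics listed above, here is a minimal, self-contained Haskell sketch. The BranchKind type and the concrete probabilities are illustrative assumptions for this note only, not values taken from the commit; the real logic lives in compiler/nativeGen/CFG.hs.

    -- Illustrative sketch only; the type and the numbers are assumptions.
    module BranchHeuristicsSketch where

    data BranchKind
      = BackEdge    -- a jump back to a loop header
      | LoopExit    -- a branch leaving a loop body
      | IntVsConst  -- an integer compared against a constant
      | Other
      deriving (Show)

    -- Probability that the branch is taken, per the heuristics above.
    takenProb :: BranchKind -> Double
    takenProb BackEdge   = 0.9  -- expect to take the backedge in loops
    takenProb LoopExit   = 0.1  -- expect NOT to take the branch exiting a loop
    takenProb IntVsConst = 0.3  -- expect integer vs constant comparisons to be false
    takenProb Other      = 0.5  -- no information: even odds

    main :: IO ()
    main = mapM_ (\k -> putStrLn (show k ++ ": " ++ show (takenProb k)))
                 [BackEdge, LoopExit, IntVsConst, Other]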
-rw-r--r--  compiler/cmm/Hoopl/Dataflow.hs                 |   5
-rw-r--r--  compiler/ghc.cabal.in                          |   1
-rw-r--r--  compiler/nativeGen/AsmCodeGen.hs               |   6
-rw-r--r--  compiler/nativeGen/BlockLayout.hs              | 638
-rw-r--r--  compiler/nativeGen/CFG.hs                      | 748
-rw-r--r--  compiler/nativeGen/RegAlloc/Graph/SpillCost.hs | 101
-rw-r--r--  compiler/nativeGen/X86/CodeGen.hs              |   2
-rw-r--r--  compiler/utils/Dominators.hs                   | 588
-rw-r--r--  compiler/utils/OrdList.hs                      |  60
9 files changed, 1775 insertions, 374 deletions
diff --git a/compiler/cmm/Hoopl/Dataflow.hs b/compiler/cmm/Hoopl/Dataflow.hs
index 2a2bb72dcc..9762a84e20 100644
--- a/compiler/cmm/Hoopl/Dataflow.hs
+++ b/compiler/cmm/Hoopl/Dataflow.hs
@@ -6,8 +6,6 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
-{-# OPTIONS_GHC -fprof-auto-top #-}
-
--
-- Copyright (c) 2010, João Dias, Simon Marlow, Simon Peyton Jones,
-- and Norman Ramsey
@@ -108,6 +106,7 @@ analyzeCmm
-> FactBase f
-> FactBase f
analyzeCmm dir lattice transfer cmmGraph initFact =
+ {-# SCC analyzeCmm #-}
let entry = g_entry cmmGraph
hooplGraph = g_graph cmmGraph
blockMap =
@@ -169,7 +168,7 @@ rewriteCmm
-> CmmGraph
-> FactBase f
-> UniqSM (CmmGraph, FactBase f)
-rewriteCmm dir lattice rwFun cmmGraph initFact = do
+rewriteCmm dir lattice rwFun cmmGraph initFact = {-# SCC rewriteCmm #-} do
let entry = g_entry cmmGraph
hooplGraph = g_graph cmmGraph
blockMap1 =
diff --git a/compiler/ghc.cabal.in b/compiler/ghc.cabal.in
index a612733de6..3ff27ea098 100644
--- a/compiler/ghc.cabal.in
+++ b/compiler/ghc.cabal.in
@@ -593,6 +593,7 @@ Library
Instruction
BlockLayout
CFG
+ Dominators
Format
Reg
RegClass
diff --git a/compiler/nativeGen/AsmCodeGen.hs b/compiler/nativeGen/AsmCodeGen.hs
index e033a4c218..68a094a9b9 100644
--- a/compiler/nativeGen/AsmCodeGen.hs
+++ b/compiler/nativeGen/AsmCodeGen.hs
@@ -563,7 +563,7 @@ cmmNativeGen dflags this_mod modLoc ncgImpl us fileIds dbgMap cmm count
Opt_D_dump_asm_native "Native code"
(vcat $ map (pprNatCmmDecl ncgImpl) native)
- dumpIfSet_dyn dflags
+ when (not $ null nativeCfgWeights) $ dumpIfSet_dyn dflags
Opt_D_dump_cfg_weights "CFG Weights"
(pprEdgeWeights nativeCfgWeights)
@@ -691,7 +691,7 @@ cmmNativeGen dflags this_mod modLoc ncgImpl us fileIds dbgMap cmm count
{-# SCC "generateJumpTables" #-}
generateJumpTables ncgImpl alloced
- dumpIfSet_dyn dflags
+ when (not $ null nativeCfgWeights) $ dumpIfSet_dyn dflags
Opt_D_dump_cfg_weights "CFG Update information"
( text "stack:" <+> ppr stack_updt_blks $$
text "linearAlloc:" <+> ppr cfgRegAllocUpdates )
@@ -704,7 +704,7 @@ cmmNativeGen dflags this_mod modLoc ncgImpl us fileIds dbgMap cmm count
let optimizedCFG =
optimizeCFG (cfgWeightInfo dflags) cmm postShortCFG
- dumpIfSet_dyn dflags
+ when (not $ null nativeCfgWeights) $ dumpIfSet_dyn dflags
Opt_D_dump_cfg_weights "CFG Final Weights"
( pprEdgeWeights optimizedCFG )
diff --git a/compiler/nativeGen/BlockLayout.hs b/compiler/nativeGen/BlockLayout.hs
index 5e34b28793..2216d45f48 100644
--- a/compiler/nativeGen/BlockLayout.hs
+++ b/compiler/nativeGen/BlockLayout.hs
@@ -6,6 +6,8 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE FlexibleContexts #-}
module BlockLayout
( sequenceTop )
@@ -22,7 +24,6 @@ import BlockId
import Cmm
import Hoopl.Collections
import Hoopl.Label
-import Hoopl.Block
import DynFlags (gopt, GeneralFlag(..), DynFlags, backendMaintainsCfg)
import UniqFM
@@ -41,11 +42,30 @@ import ListSetOps (removeDups)
import OrdList
import Data.List
import Data.Foldable (toList)
-import Hoopl.Graph
import qualified Data.Set as Set
+import Data.STRef
+import Control.Monad.ST.Strict
+import Control.Monad (foldM)
{-
+ Note [CFG based code layout]
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ The major steps in placing blocks are as follows:
+ * Compute a CFG based on the Cmm AST, see getCfgProc.
+ This CFG will have edge weights representing a guess
+ on how important they are.
+ * After we convert Cmm to Asm we run `optimizeCFG` which
+ adds a few more "educated guesses" to the equation.
+ * Then we run loop analysis on the CFG (`loopInfo`) which tells us
+ about loop headers, loop nesting levels and the like.
+ * Based on the CFG and loop information we refine the edge weights
+ in the CFG and normalize them relative to the most often visited
+ node. (See `mkGlobalWeights`)
+ * Feed this CFG into the block layout code (`sequenceTop`) in this
+ module, which will then produce a code layout based on the input
+ weights. (A toy sketch of this pipeline follows this note.)
+
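A toy rendition of the pipeline above, assuming blocks are plain Ints and a CFG is a list of weighted edges. This is a sketch of the idea only; the real code operates on CFG values and GenBasicBlocks, and the chain construction is far more involved (see the notes below).

    import Data.List (sortBy)
    import Data.Ord (comparing, Down(..))

    type Block = Int
    type Edge  = (Block, Block, Double)  -- (from, to, weight)

    -- Stand-in for optimizeCFG/mkGlobalWeights: refine the edge weights.
    -- A no-op here; the real code applies heuristics and loop analysis.
    optimize :: [Edge] -> [Edge]
    optimize = id

    -- Stand-in for sequenceTop: greedily follow the heaviest unplaced successor.
    layout :: Block -> [Edge] -> [Block]
    layout entry edges = go [entry] entry
      where
        heaviestFirst = sortBy (comparing (\(_, _, w) -> Down w)) edges
        go placed cur =
          case [t | (f, t, _) <- heaviestFirst, f == cur, t `notElem` placed] of
            (next:_) -> go (placed ++ [next]) next
            []       -> placed

    main :: IO ()
    main = print $ layout 0 (optimize [(0,1,0.9), (0,2,0.1), (1,2,1.0)])
    -- [0,1,2]: the hot path 0 -> 1 -> 2 is placed sequentially.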
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~ Note [Chain based CFG serialization]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -60,8 +80,8 @@ import qualified Data.Set as Set
but also how much a block would benefit from being placed sequentially after
its predecessor.
For example blocks which are preceded by an info table are more likely to end
- up in a different cache line than their predecessor. So there is less benefit
- in placing them sequentially.
+ up in a different cache line than their predecessor and we can't eliminate the jump
+ so there is less benefit to placing them sequentially.
Consider this example:
@@ -81,56 +101,83 @@ import qualified Data.Set as Set
Eg for our example we might end up with two chains like:
[A->B->C->X],[D]. Blocks inside chains will always be placed sequentially.
However there is no particular order in which chains are placed since
- (hopefully) the blocks for which sequentially is important have already
+ (hopefully) the blocks for which sequentiality is important have already
been placed in the same chain.
-----------------------------------------------------------------------------
- First try to create a lists of good chains.
+ 1) First try to create a list of good chains.
-----------------------------------------------------------------------------
- We do so by taking a block not yet placed in a chain and
- looking at these cases:
+ Good chains are those which allow us to eliminate jump instructions.
+ We prefer to eliminate often executed jumps first.
+
+ We do so by:
+
+ *) Ignore edges which represent instructions which cannot be replaced
+ by fall through control flow. Primarily calls and edges to blocks which
+ are prefixed by an info table we have to jump across.
+
+ *) Then process remaining edges in order of frequency taken and:
+
+ +) If source and target have not been placed build a new chain from them.
+
+ +) If source and target have been placed, and are ends of differing chains
+ try to merge the two chains.
- *) Check if the best predecessor of the block is at the end of a chain.
- If so add the current block to the end of that chain.
+ +) If one side of the edge is at the end/front of a chain, add the other
+ block of the edge to the same chain.
- Eg if we look at block C and already have the chain (A -> B)
- then we extend the chain to (A -> B -> C).
+ Eg if we look at edge (B -> C) and already have the chain (A -> B)
+ then we extend the chain to (A -> B -> C).
- Combined with the fact that we process blocks in reverse post order
- this means loop bodies and trivially sequential control flow already
- ends up as a single chain.
+ +) If the edge was used to modify or build a new chain remove the edge from
+ our working list.
- *) Otherwise we create a singleton chain from the block we are looking at.
- Eg if we have from the example above already constructed (A->B)
- and look at D we create the chain (D) resulting in the chains [A->B, D]
+ *) If any blocks have not been placed into a chain after these steps we
+ place each of them into a chain consisting of only that block.
+
+ Since we rank edges by their taken frequency, if two edges compete for
+ the fall through on the same target block, the one taken more often will
+ automatically win out, resulting in fewer jump instructions being
+ executed. (See the toy sketch following this step's description.)
+
+ Creating singleton chains is required for situations where we have code of the
+ form:
+
+ A: goto B:
+ <infoTable>
+ B: goto C:
+ <infoTable>
+ C: ...
+
+ As the code in block B is only connected to the rest of the program via edges
+ which will be ignored in this step we make sure that B still ends up in a chain
+ this way.
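A toy version of this step, assuming blocks are Ints, chains are plain lists, and the edges arrive pre-sorted by descending frequency. Starting every block as a singleton chain makes creating, extending and merging chains a single operation; the real code below uses OrdList chains and STRef-backed maps instead of lists for performance.

    import Data.List (find, delete)

    type Block = Int
    type Chain = [Block]

    buildChainsToy :: [(Block, Block)] -> [Block] -> [Chain]
    buildChainsToy edges blocks = foldl step [[b] | b <- blocks] edges
      where
        step chains (from, to)
          | Just c1 <- find ((== from) . last) chains  -- 'from' ends a chain
          , Just c2 <- find ((== to)   . head) chains  -- 'to' starts a chain
          , c1 /= c2                                   -- don't create a cycle
          = (c1 ++ c2) : delete c1 (delete c2 chains)
          | otherwise = chains  -- edge can't extend or merge a chain: skip it

    main :: IO ()
    main = print $ buildChainsToy [(0,1), (1,2)] [0,1,2,3]
    -- [[0,1,2],[3]]: 0 -> 1 -> 2 becomes one chain, 3 stays a singleton.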
-----------------------------------------------------------------------------
- We then try to fuse chains.
+ 2) We also try to fuse chains.
-----------------------------------------------------------------------------
- There are edge cases which result in two chains being created which trivially
- represent linear control flow. For example we might have the chains
- [(A-B-C),(D-E)] with an cfg triangle:
+ As a result of the above step we still end up with multiple chains which
+ represent sequential control flow chunks. But they are not yet suitable for
+ code layout as we need to place *all* blocks into a single sequence.
- A----->C->D->E
- \->B-/
+ In this step we combine the chains resulting from the above step as follows:
- We also get three independent chains if two branches end with a jump
- to a common successor.
+ *) Look at the ranked list of *all* edges, including calls/jumps across info tables
+ and the like.
- We take care of these cases by fusing chains which are connected by an
- edge.
+ *) Look at each edge and
- We do so by looking at the list of edges sorted by weight.
- Given the edge (C -> D) we try to find two chains such that:
- * C is at the end of chain one.
- * D is in front of chain two.
- * If two such chains exist we fuse them.
- We then remove the edge and repeat the process for the rest of the edges.
+ +) Given an edge (A -> B) try to find two chains for which
+ * Block A is at the end of one chain
+ * Block B is at the front of the other chain.
+ +) If we find such chains we "fuse" them into a single chain, remove the
+ edge from the working set and continue.
+ +) If we can't find such chains we skip the edge and continue.
+ (A toy continuation of the earlier sketch follows.)
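Continuing the toy from step 1, under the same assumptions: the fusing pass performs the same end/front matching, but over the full edge list, so the remaining chains get glued together.

    -- Same matching as buildChainsToy, run on already-built chains.
    fuseToy :: [(Block, Block)] -> [Chain] -> [Chain]
    fuseToy edges chains = foldl fuse chains edges
      where
        fuse cs (from, to)
          | Just c1 <- find ((== from) . last) cs
          , Just c2 <- find ((== to)   . head) cs
          , c1 /= c2  = (c1 ++ c2) : delete c1 (delete c2 cs)
          | otherwise = cs  -- no matching end/front pair: skip the edge

    -- fuseToy [(2,3)] [[0,1,2],[3]]  ==  [[0,1,2,3]]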
-----------------------------------------------------------------------------
- Place indirect successors (neighbours) after each other
+ 3) Place indirect successors (neighbours) after each other
-----------------------------------------------------------------------------
We might have chains [A,B,C,X],[E] in a CFG of the sort:
@@ -141,15 +188,11 @@ import qualified Data.Set as Set
While E does not follow X it's still beneficial to place them near each other.
This can be advantageous if eg C,X,E will end up in the same cache line.
- TODO: If we remove edges as we use them (eg if we build up A->B remove A->B
- from the list) we could save some more work in later phases.
-
-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~ Note [Triangle Control Flow]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- Checking if an argument is already evaluating leads to a somewhat
+ Checking if an argument is already evaluated leads to a somewhat
special case which looks like this:
A:
@@ -204,11 +247,6 @@ import qualified Data.Set as Set
neighbourOverlapp :: Int
neighbourOverlapp = 2
--- | Only edges heavier than this are considered
--- for fusing two chains into a single chain.
-fuseEdgeThreshold :: EdgeWeight
-fuseEdgeThreshold = 0
-
-- | Maps blocks near the end of a chain to its chain AND
-- the other blocks near the end.
-- [A,B,C,D,E] Gives entries like (B -> ([A,B], [A,B,C,D,E]))
@@ -224,40 +262,24 @@ type FrontierMap = LabelMap ([BlockId],BlockChain)
newtype BlockChain
= BlockChain { chainBlocks :: (OrdList BlockId) }
-instance Eq (BlockChain) where
- (BlockChain blks1) == (BlockChain blks2)
- = fromOL blks1 == fromOL blks2
+-- All chains are constructed the same way, so comparing their
+-- structure directly (instead of via lists) is sound and faster.
+instance Eq BlockChain where
+ BlockChain b1 == BlockChain b2 = strictlyEqOL b1 b2
-- Useful for things like sets and debugging purposes, sorts by blocks
-- in the chain.
instance Ord (BlockChain) where
(BlockChain lbls1) `compare` (BlockChain lbls2)
- = (fromOL lbls1) `compare` (fromOL lbls2)
+ = ASSERT(toList lbls1 /= toList lbls2 || lbls1 `strictlyEqOL` lbls2)
+ strictlyOrdOL lbls1 lbls2
instance Outputable (BlockChain) where
ppr (BlockChain blks) =
parens (text "Chain:" <+> ppr (fromOL $ blks) )
-data WeightedEdge = WeightedEdge !BlockId !BlockId EdgeWeight deriving (Eq)
-
-
--- | Non deterministic! (Uniques) Sorts edges by weight and nodes.
-instance Ord WeightedEdge where
- compare (WeightedEdge from1 to1 weight1)
- (WeightedEdge from2 to2 weight2)
- | weight1 < weight2 || weight1 == weight2 && from1 < from2 ||
- weight1 == weight2 && from1 == from2 && to1 < to2
- = LT
- | from1 == from2 && to1 == to2 && weight1 == weight2
- = EQ
- | otherwise
- = GT
-
-instance Outputable WeightedEdge where
- ppr (WeightedEdge from to info) =
- ppr from <> text "->" <> ppr to <> brackets (ppr info)
-
-type WeightedEdgeList = [WeightedEdge]
+chainFoldl :: (b -> BlockId -> b) -> b -> BlockChain -> b
+chainFoldl f z (BlockChain blocks) = foldl' f z blocks
noDups :: [BlockChain] -> Bool
noDups chains =
@@ -270,19 +292,21 @@ inFront :: BlockId -> BlockChain -> Bool
inFront bid (BlockChain seq)
= headOL seq == bid
-chainMember :: BlockId -> BlockChain -> Bool
-chainMember bid chain
- = elem bid $ fromOL . chainBlocks $ chain
--- = setMember bid . chainMembers $ chain
-
chainSingleton :: BlockId -> BlockChain
chainSingleton lbl
= BlockChain (unitOL lbl)
+chainFromList :: [BlockId] -> BlockChain
+chainFromList = BlockChain . toOL
+
chainSnoc :: BlockChain -> BlockId -> BlockChain
chainSnoc (BlockChain blks) lbl
= BlockChain (blks `snocOL` lbl)
+chainCons :: BlockId -> BlockChain -> BlockChain
+chainCons lbl (BlockChain blks)
+ = BlockChain (lbl `consOL` blks)
+
chainConcat :: BlockChain -> BlockChain -> BlockChain
chainConcat (BlockChain blks1) (BlockChain blks2)
= BlockChain (blks1 `appOL` blks2)
@@ -311,52 +335,14 @@ takeL :: Int -> BlockChain -> [BlockId]
takeL n (BlockChain blks) =
take n . fromOL $ blks
--- | For a given list of chains try to fuse chains with strong
--- edges between them into a single chain.
--- Returns the list of fused chains together with a set of
--- used edges. The set of edges is indirectly encoded in the
--- chains so doesn't need to be considered for later passes.
-fuseChains :: WeightedEdgeList -> LabelMap BlockChain
- -> (LabelMap BlockChain, Set.Set WeightedEdge)
-fuseChains weights chains
- = let fronts = mapFromList $
- map (\chain -> (headOL . chainBlocks $ chain,chain)) $
- mapElems chains :: LabelMap BlockChain
- (chains', used, _) = applyEdges weights chains fronts Set.empty
- in (chains', used)
- where
- applyEdges :: WeightedEdgeList -> LabelMap BlockChain
- -> LabelMap BlockChain -> Set.Set WeightedEdge
- -> (LabelMap BlockChain, Set.Set WeightedEdge, LabelMap BlockChain)
- applyEdges [] chainsEnd chainsFront used
- = (chainsEnd, used, chainsFront)
- applyEdges (edge@(WeightedEdge from to w):edges) chainsEnd chainsFront used
- --Since we order edges descending by weight we can stop here
- | w <= fuseEdgeThreshold
- = ( chainsEnd, used, chainsFront)
- --Fuse the two chains
- | Just c1 <- mapLookup from chainsEnd
- , Just c2 <- mapLookup to chainsFront
- , c1 /= c2
- = let newChain = chainConcat c1 c2
- front = headOL . chainBlocks $ newChain
- end = lastOL . chainBlocks $ newChain
- chainsFront' = mapInsert front newChain $
- mapDelete to chainsFront
- chainsEnd' = mapInsert end newChain $
- mapDelete from chainsEnd
- in applyEdges edges chainsEnd' chainsFront'
- (Set.insert edge used)
- | otherwise
- --Check next edge
- = applyEdges edges chainsEnd chainsFront used
-
+-- Note [Combining neighborhood chains]
+-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-- See also Note [Chain based CFG serialization]
-- We have the chains (A-B-C-D) and (E-F) and an Edge C->E.
--
--- While placing the later after the former doesn't result in sequential
--- control flow it is still be benefical since block C and E might end
+-- While placing the latter after the former doesn't result in sequential
+-- control flow it is still beneficial, as block C and E might end
-- up in the same cache line.
--
-- So we place these chains next to each other even if we can't fuse them.
@@ -365,7 +351,7 @@ fuseChains weights chains
-- v
-- - -> E -> F ...
--
--- Simple heuristic to chose which chains we want to combine:
+-- A simple heuristic to choose which chains we want to combine:
-- * Process edges in descending priority.
-- * Check if there is an edge near the end of one chain which goes
-- to a block near the start of another chain.
@@ -375,14 +361,22 @@ fuseChains weights chains
-- us to find all edges between two chains, check the distance for all edges,
-- rank them based on the distance and only then select two chains
-- to combine, which would add a lot of complexity for little gain.
+--
+-- So instead we just rank by the strength of the edge and use the first pair we
+-- find.
-- | For a given list of chains and edges try to combine chains with strong
-- edges between them.
-combineNeighbourhood :: WeightedEdgeList -> [BlockChain]
- -> [BlockChain]
+combineNeighbourhood :: [CfgEdge] -- ^ Edges to consider
+ -> [BlockChain] -- ^ Current chains of blocks
+ -> ([BlockChain], Set.Set (BlockId,BlockId))
+ -- ^ Resulting list of block chains, and a set of edges which
+ -- were used to fuse chains and as such no longer need to be
+ -- considered.
combineNeighbourhood edges chains
= -- pprTraceIt "Neigbours" $
- applyEdges edges endFrontier startFrontier
+ -- pprTrace "combineNeighbours" (ppr edges) $
+ applyEdges edges endFrontier startFrontier (Set.empty)
where
--Build maps from chain ends to chains
endFrontier, startFrontier :: FrontierMap
@@ -396,14 +390,14 @@ combineNeighbourhood edges chains
let front = getFronts chain
entry = (front,chain)
in map (\x -> (x,entry)) front) chains
- applyEdges :: WeightedEdgeList -> FrontierMap -> FrontierMap
- -> [BlockChain]
- applyEdges [] chainEnds _chainFronts =
- ordNub $ map snd $ mapElems chainEnds
- applyEdges ((WeightedEdge from to _w):edges) chainEnds chainFronts
+ applyEdges :: [CfgEdge] -> FrontierMap -> FrontierMap -> Set.Set (BlockId, BlockId)
+ -> ([BlockChain], Set.Set (BlockId,BlockId))
+ applyEdges [] chainEnds _chainFronts combined =
+ (ordNub $ map snd $ mapElems chainEnds, combined)
+ applyEdges ((CfgEdge from to _w):edges) chainEnds chainFronts combined
| Just (c1_e,c1) <- mapLookup from chainEnds
, Just (c2_f,c2) <- mapLookup to chainFronts
- , c1 /= c2 -- Avoid trying to concat a short chain with itself.
+ , c1 /= c2 -- Avoid trying to concat a chain with itself.
= let newChain = chainConcat c1 c2
newChainFrontier = getFronts newChain
newChainEnds = getEnds newChain
@@ -437,165 +431,299 @@ combineNeighbourhood edges chains
-- text "fronts" <+> ppr newFronts $$
-- text "ends" <+> ppr newEnds
-- )
- applyEdges edges newEnds newFronts
+ applyEdges edges newEnds newFronts (Set.insert (from,to) combined)
| otherwise
- = --pprTrace "noNeigbours" (ppr ()) $
- applyEdges edges chainEnds chainFronts
+ = applyEdges edges chainEnds chainFronts combined
where
getFronts chain = takeL neighbourOverlapp chain
getEnds chain = takeR neighbourOverlapp chain
-
-
--- See [Chain based CFG serialization]
-buildChains :: CFG -> [BlockId]
- -> ( LabelMap BlockChain -- Resulting chains.
+-- In the last step we combine all chains into a single one,
+-- trying to place chains with strong edges next to each other.
+mergeChains :: [CfgEdge] -> [BlockChain]
+ -> (BlockChain)
+mergeChains edges chains
+ = -- pprTrace "combine" (ppr edges) $
+ runST $ do
+ let addChain m0 chain = do
+ ref <- newSTRef chain
+ return $ chainFoldl (\m' b -> mapInsert b ref m') m0 chain
+ chainMap' <- foldM (\m0 c -> addChain m0 c) mapEmpty chains
+ merge edges chainMap'
+ where
+ -- We keep a map from ALL blocks to their respective chain (sigh)
+ -- This is required since when looking at an edge we need to find
+ -- the associated chains quickly.
+ -- We use a map of STRefs, maintaining an invariant of one STRef per chain.
+ -- When merging chains we can update the
+ -- STRef of one chain once (instead of writing to the map for each block).
+ -- We then overwrite the STRefs for the other chain so there is again only
+ -- a single STRef for the combined chain.
+ -- The difference in terms of allocations saved is ~0.2% with -O so actually
+ -- significant compared to using a regular map.
+
+ merge :: forall s. [CfgEdge] -> LabelMap (STRef s BlockChain) -> ST s BlockChain
+ merge [] chains = do
+ chains' <- ordNub <$> (mapM readSTRef $ mapElems chains) :: ST s [BlockChain]
+ return $ foldl' chainConcat (head chains') (tail chains')
+ merge ((CfgEdge from to _):edges) chains
+ -- | pprTrace "merge" (ppr (from,to) <> ppr chains) False
+ -- = undefined
+ | cFrom == cTo
+ = merge edges chains
+ | otherwise
+ = do
+ chains' <- mergeComb cFrom cTo
+ merge edges chains'
+ where
+ mergeComb :: STRef s BlockChain -> STRef s BlockChain -> ST s (LabelMap (STRef s BlockChain))
+ mergeComb refFrom refTo = do
+ cRight <- readSTRef refTo
+ chain <- pure chainConcat <*> readSTRef refFrom <*> pure cRight
+ writeSTRef refFrom chain
+ return $ chainFoldl (\m b -> mapInsert b refFrom m) chains cRight
+
+ cFrom = expectJust "mergeChains:chainMap:from" $ mapLookup from chains
+ cTo = expectJust "mergeChains:chainMap:to" $ mapLookup to chains
+
+
+-- See Note [Chain based CFG serialization] for the general idea.
+-- This creates and fuses chains at the same time for performance reasons.
+
+-- Try to build chains from a list of edges.
+-- Edges must be sorted **descending** by their priority.
+-- Returns the constructed chains, along with all edges which
+-- are irrelevant past this point. This information doesn't need
+-- to be complete - it's only used to speed up the process.
+-- An edge is irrelevant if its ends are part of the same chain;
+-- we say these edges are already linked.
+buildChains :: [CfgEdge] -> [BlockId]
+ -> ( LabelMap BlockChain -- Resulting chains, indexed by the end of the chain.
, Set.Set (BlockId, BlockId)) --List of fused edges.
-buildChains succWeights blocks
- = let (_, fusedEdges, chains) = buildNext setEmpty mapEmpty blocks Set.empty
- in (chains, fusedEdges)
+buildChains edges blocks
+ = runST $ buildNext setEmpty mapEmpty mapEmpty edges Set.empty
where
- -- We keep a map from the last block in a chain to the chain itself.
- -- This we we can easily check if an block should be appened to an
+ -- buildNext builds up chains from edges one at a time.
+
+ -- We keep a map from the ends of chains to the chains.
+ -- This way we can easily check if a block should be appended to an
-- existing chain!
- buildNext :: LabelSet
- -> LabelMap BlockChain -- Map from last element to chain.
- -> [BlockId] -- Blocks to place
- -> Set.Set (BlockId, BlockId)
- -> ( [BlockChain] -- Placed Blocks
- , Set.Set (BlockId, BlockId) --List of fused edges
- , LabelMap BlockChain
- )
- buildNext _placed chains [] linked =
- ([], linked, chains)
- buildNext placed chains (block:todo) linked
- | setMember block placed
- = buildNext placed chains todo linked
+ -- We store them using STRefs so we don't have to rebuild the spine of both
+ -- maps every time we update a chain.
+ buildNext :: forall s. LabelSet
+ -> LabelMap (STRef s BlockChain) -- Map from end of chain to chain.
+ -> LabelMap (STRef s BlockChain) -- Map from start of chain to chain.
+ -> [CfgEdge] -- Edges to check - ordered by decreasing weight
+ -> Set.Set (BlockId, BlockId) -- Used edges
+ -> ST s ( LabelMap BlockChain -- Chains by end
+ , Set.Set (BlockId, BlockId) --List of fused edges
+ )
+ buildNext placed _chainStarts chainEnds [] linked = do
+ ends' <- sequence $ mapMap readSTRef chainEnds :: ST s (LabelMap BlockChain)
+ -- Any remaining blocks have to be turned into singleton chains.
+ -- They might be combined with other chains later on outside this function.
+ let unplaced = filter (\x -> not (setMember x placed)) blocks
+ singletons = map (\x -> (x,chainSingleton x)) unplaced :: [(BlockId,BlockChain)]
+ return (foldl' (\m (k,v) -> mapInsert k v m) ends' singletons , linked)
+ buildNext placed chainStarts chainEnds (edge:todo) linked
+ | from == to
+ -- We skip self edges
+ = buildNext placed chainStarts chainEnds todo (Set.insert (from,to) linked)
+ | not (alreadyPlaced from) &&
+ not (alreadyPlaced to)
+ = do
+ --pprTraceM "Edge-Chain:" (ppr edge)
+ chain' <- newSTRef $ chainFromList [from,to]
+ buildNext
+ (setInsert to (setInsert from placed))
+ (mapInsert from chain' chainStarts)
+ (mapInsert to chain' chainEnds)
+ todo
+ (Set.insert (from,to) linked)
+
+ | (alreadyPlaced from) &&
+ (alreadyPlaced to)
+ , Just predChain <- mapLookup from chainEnds
+ , Just succChain <- mapLookup to chainStarts
+ , predChain /= succChain -- Otherwise we would create a cycle.
+ = do
+ -- pprTraceM "Fusing edge" (ppr edge)
+ fuseChain predChain succChain
+
+ | (alreadyPlaced from) &&
+ (alreadyPlaced to)
+ = --pprTraceM "Skipping:" (ppr edge) >>
+ buildNext placed chainStarts chainEnds todo linked
+
| otherwise
- = buildNext placed' chains' todo linked'
+ = do -- pprTraceM "Finding chain for:" (ppr edge $$
+ -- text "placed" <+> ppr placed)
+ findChain
where
- placed' = (foldl' (flip setInsert) placed placedBlocks)
- linked' = Set.union linked linkedEdges
- (placedBlocks, chains', linkedEdges) = findChain block
-
- --Add the block to a existing or new chain
- --Returns placed blocks, list of resulting chains
- --and fused edges
- findChain :: BlockId
- -> ([BlockId],LabelMap BlockChain, Set.Set (BlockId, BlockId))
- findChain block
- -- B) place block at end of existing chain if
- -- there is no better block to append.
- | (pred:_) <- preds
- , alreadyPlaced pred
- , Just predChain <- mapLookup pred chains
- , (best:_) <- filter (not . alreadyPlaced) $ getSuccs pred
- , best == lbl
- = --pprTrace "B.2)" (ppr (pred,lbl)) $
- let newChain = chainSnoc predChain block
- chainMap = mapInsert lbl newChain $ mapDelete pred chains
- in ( [lbl]
- , chainMap
- , Set.singleton (pred,lbl) )
-
+ from = edgeFrom edge
+ to = edgeTo edge
+ alreadyPlaced blkId = (setMember blkId placed)
+
+ -- Combine two chains into a single one.
+ fuseChain :: STRef s BlockChain -> STRef s BlockChain
+ -> ST s ( LabelMap BlockChain -- Chains by end
+ , Set.Set (BlockId, BlockId) --List of fused edges
+ )
+ fuseChain fromRef toRef = do
+ fromChain <- readSTRef fromRef
+ toChain <- readSTRef toRef
+ let newChain = chainConcat fromChain toChain
+ ref <- newSTRef newChain
+ let start = head $ takeL 1 newChain
+ let end = head $ takeR 1 newChain
+ -- chains <- sequence $ mapMap readSTRef chainStarts
+ -- pprTraceM "pre-fuse chains:" $ ppr chains
+ buildNext
+ placed
+ (mapInsert start ref $ mapDelete to $ chainStarts)
+ (mapInsert end ref $ mapDelete from $ chainEnds)
+ todo
+ (Set.insert (from,to) linked)
+
+
+ --Add the block to an existing chain or create a new chain
+ findChain :: ST s ( LabelMap BlockChain -- Chains by end
+ , Set.Set (BlockId, BlockId) --List of fused edges
+ )
+ findChain
+ -- We can attach the block to the end of a chain
+ | alreadyPlaced from
+ , Just predChain <- mapLookup from chainEnds
+ = do
+ chain <- readSTRef predChain
+ let newChain = chainSnoc chain to
+ writeSTRef predChain newChain
+ let chainEnds' = mapInsert to predChain $ mapDelete from chainEnds
+ -- chains <- sequence $ mapMap readSTRef chainStarts
+ -- pprTraceM "from chains:" $ ppr chains
+ buildNext (setInsert to placed) chainStarts chainEnds' todo (Set.insert (from,to) linked)
+ -- We can attach it to the front of a chain
+ | alreadyPlaced to
+ , Just succChain <- mapLookup to chainStarts
+ = do
+ chain <- readSTRef succChain
+ let newChain = from `chainCons` chain
+ writeSTRef succChain newChain
+ let chainStarts' = mapInsert from succChain $ mapDelete to chainStarts
+ -- chains <- sequence $ mapMap readSTRef chainStarts'
+ -- pprTraceM "to chains:" $ ppr chains
+ buildNext (setInsert from placed) chainStarts' chainEnds todo (Set.insert (from,to) linked)
+ -- The placed end of the edge is part of a chain already and not an end.
| otherwise
- = --pprTrace "single" (ppr lbl)
- ( [lbl]
- , mapInsert lbl (chainSingleton lbl) chains
- , Set.empty)
+ = do
+ let block = if alreadyPlaced to then from else to
+ --pprTraceM "Singleton" $ ppr block
+ let newChain = chainSingleton block
+ ref <- newSTRef newChain
+ buildNext (setInsert block placed) (mapInsert block ref chainStarts)
+ (mapInsert block ref chainEnds) todo (linked)
where
alreadyPlaced blkId = (setMember blkId placed)
- lbl = block
- getSuccs = map fst . getSuccEdgesSorted succWeights
- preds = map fst $ getSuccEdgesSorted predWeights lbl
- --For efficiency we also create the map to look up predecessors here
- predWeights = reverseEdges succWeights
-
-
--- We make the CFG a Hoopl Graph, so we can reuse revPostOrder.
-newtype BlockNode (e :: Extensibility) (x :: Extensibility) = BN (BlockId,[BlockId])
-instance NonLocal (BlockNode) where
- entryLabel (BN (lbl,_)) = lbl
- successors (BN (_,succs)) = succs
-
-fromNode :: BlockNode C C -> BlockId
-fromNode (BN x) = fst x
-
-sequenceChain :: forall a i. (Instruction i, Outputable i) => LabelMap a -> CFG
- -> [GenBasicBlock i] -> [GenBasicBlock i]
+-- | Place basic blocks based on the given CFG.
+-- See Note [Chain based CFG serialization]
+sequenceChain :: forall a i. (Instruction i, Outputable i)
+ => LabelMap a -- ^ Keys indicate an info table on the block.
+ -> CFG -- ^ Control flow graph and some meta data.
+ -> [GenBasicBlock i] -- ^ List of basic blocks to be placed.
+ -> [GenBasicBlock i] -- ^ Blocks placed in sequence.
sequenceChain _info _weights [] = []
sequenceChain _info _weights [x] = [x]
sequenceChain info weights' blocks@((BasicBlock entry _):_) =
- --Optimization, delete edges of weight <= 0.
- --This significantly improves performance whenever
- --we iterate over all edges, which is a few times!
let weights :: CFG
- weights
- = filterEdges (\_f _t edgeInfo -> edgeWeight edgeInfo > 0) weights'
+ weights = --pprTrace "cfg'" (pprEdgeWeights cfg')
+ cfg'
+ where
+ (_, globalEdgeWeights) = {-# SCC mkGlobalWeights #-} mkGlobalWeights entry weights'
+ cfg' = {-# SCC rewriteEdges #-}
+ mapFoldlWithKey
+ (\cfg from m ->
+ mapFoldlWithKey
+ (\cfg to w -> setEdgeWeight cfg (EdgeWeight w) from to )
+ cfg m )
+ weights'
+ globalEdgeWeights
+
+ directEdges :: [CfgEdge]
+ directEdges = sortBy (flip compare) $ catMaybes . map relevantWeight $ (infoEdgeList weights)
+ where
+ relevantWeight :: CfgEdge -> Maybe CfgEdge
+ relevantWeight edge@(CfgEdge from to edgeInfo)
+ | (EdgeInfo CmmSource { trans_cmmNode = CmmCall {} } _) <- edgeInfo
+ -- Ignore edges across calls
+ = Nothing
+ | mapMember to info
+ , w <- edgeWeight edgeInfo
+ -- The payoff is small if we jump over an info table
+ = Just (CfgEdge from to edgeInfo { edgeWeight = w/8 })
+ | otherwise
+ = Just edge
+
blockMap :: LabelMap (GenBasicBlock i)
blockMap
= foldl' (\m blk@(BasicBlock lbl _ins) ->
mapInsert lbl blk m)
mapEmpty blocks
- toNode :: BlockId -> BlockNode C C
- toNode bid =
- -- sorted such that heavier successors come first.
- BN (bid,map fst . getSuccEdgesSorted weights' $ bid)
-
- orderedBlocks :: [BlockId]
- orderedBlocks
- = map fromNode $
- revPostorderFrom (fmap (toNode . blockId) blockMap) entry
-
(builtChains, builtEdges)
= {-# SCC "buildChains" #-}
--pprTraceIt "generatedChains" $
- --pprTrace "orderedBlocks" (ppr orderedBlocks) $
- buildChains weights orderedBlocks
+ --pprTrace "blocks" (ppr (mapKeys blockMap)) $
+ buildChains directEdges (mapKeys blockMap)
- rankedEdges :: WeightedEdgeList
- -- Sort edges descending, remove fused eges
+ rankedEdges :: [CfgEdge]
+ -- Sort descending by weight, remove fused edges
rankedEdges =
- map (\(from, to, weight) -> WeightedEdge from to weight) .
- filter (\(from, to, _)
- -> not (Set.member (from,to) builtEdges)) .
- sortWith (\(_,_,w) -> - w) $ weightedEdgeList weights
+ filter (\edge -> not (Set.member (edgeFrom edge,edgeTo edge) builtEdges)) $
+ directEdges
- (fusedChains, fusedEdges)
+ (neighbourChains, combined)
= ASSERT(noDups $ mapElems builtChains)
- {-# SCC "fuseChains" #-}
- --(pprTrace "RankedEdges" $ ppr rankedEdges) $
- --pprTraceIt "FusedChains" $
- fuseChains rankedEdges builtChains
-
- rankedEdges' =
- filter (\edge -> not $ Set.member edge fusedEdges) $ rankedEdges
-
- neighbourChains
- = ASSERT(noDups $ mapElems fusedChains)
{-# SCC "groupNeighbourChains" #-}
- --pprTraceIt "ResultChains" $
- combineNeighbourhood rankedEdges' (mapElems fusedChains)
+ -- pprTraceIt "NeighbourChains" $
+ combineNeighbourhood rankedEdges (mapElems builtChains)
+
+
+ allEdges :: [CfgEdge]
+ allEdges = {-# SCC allEdges #-}
+ sortOn (relevantWeight) $ filter (not . deadEdge) $ (infoEdgeList weights)
+ where
+ deadEdge :: CfgEdge -> Bool
+ deadEdge (CfgEdge from to _) = let e = (from,to) in Set.member e combined || Set.member e builtEdges
+ relevantWeight :: CfgEdge -> EdgeWeight
+ relevantWeight (CfgEdge _ _ edgeInfo)
+ | EdgeInfo (CmmSource { trans_cmmNode = CmmCall {}}) _ <- edgeInfo
+ -- Penalize edges across calls
+ = weight/(64.0)
+ | otherwise
+ = weight
+ where
+ -- negate to sort descending
+ weight = negate (edgeWeight edgeInfo)
+
+ masterChain =
+ {-# SCC "mergeChains" #-}
+ -- pprTraceIt "MergedChains" $
+ mergeChains allEdges neighbourChains
--Make sure the first block stays first
- ([entryChain],chains')
- = ASSERT(noDups $ neighbourChains)
- partition (chainMember entry) neighbourChains
- (entryChain':entryRest)
- | inFront entry entryChain = [entryChain]
- | (rest,entry) <- breakChainAt entry entryChain
+ prepedChains
+ | inFront entry masterChain
+ = [masterChain]
+ | (rest,entry) <- breakChainAt entry masterChain
= [entry,rest]
| otherwise = pprPanic "Entry point eliminated" $
- ppr ([entryChain],chains')
+ ppr masterChain
- prepedChains
- = entryChain':(entryRest++chains') :: [BlockChain]
blockList
- -- = (concatMap chainToBlocks prepedChains)
- = (concatMap fromOL $ map chainBlocks prepedChains)
+ = ASSERT(noDups [masterChain])
+ (concatMap fromOL $ map chainBlocks prepedChains)
--chainPlaced = setFromList $ map blockId blockList :: LabelSet
chainPlaced = setFromList $ blockList :: LabelSet
@@ -605,14 +733,22 @@ sequenceChain info weights' blocks@((BasicBlock entry _):_) =
in filter (\block -> not (isPlaced block)) blocks
placedBlocks =
+ -- We want debug builds to catch this as it's a good indicator for
+ -- issues with CFG invariants. But we don't want to blow up production
+ -- builds if something slips through.
+ ASSERT(null unplaced)
--pprTraceIt "placedBlocks" $
- blockList ++ unplaced
+ -- ++ [] is still kinda expensive
+ if null unplaced then blockList else blockList ++ unplaced
getBlock bid = expectJust "Block placment" $ mapLookup bid blockMap
in
--Assert we placed all blocks given as input
ASSERT(all (\bid -> mapMember bid blockMap) placedBlocks)
dropJumps info $ map getBlock placedBlocks
+{-# SCC dropJumps #-}
+-- | Remove redundant jumps between blocks when we can rely on
+-- fall through.
dropJumps :: forall a i. Instruction i => LabelMap a -> [GenBasicBlock i]
-> [GenBasicBlock i]
dropJumps _ [] = []
@@ -639,8 +775,10 @@ dropJumps info ((BasicBlock lbl ins):todo)
sequenceTop
:: (Instruction instr, Outputable instr)
=> DynFlags --Use new layout code
- -> NcgImpl statics instr jumpDest -> CFG
- -> NatCmmDecl statics instr -> NatCmmDecl statics instr
+ -> NcgImpl statics instr jumpDest
+ -> CFG
+ -> NatCmmDecl statics instr
+ -> NatCmmDecl statics instr
sequenceTop _ _ _ top@(CmmData _ _) = top
sequenceTop dflags ncgImpl edgeWeights
@@ -648,10 +786,12 @@ sequenceTop dflags ncgImpl edgeWeights
| (gopt Opt_CfgBlocklayout dflags) && backendMaintainsCfg dflags
--Use chain based algorithm
= CmmProc info lbl live ( ListGraph $ ncgMakeFarBranches ncgImpl info $
+ {-# SCC layoutBlocks #-}
sequenceChain info edgeWeights blocks )
| otherwise
--Use old algorithm
= CmmProc info lbl live ( ListGraph $ ncgMakeFarBranches ncgImpl info $
+ {-# SCC layoutBlocks #-}
sequenceBlocks cfg info blocks)
where
cfg
diff --git a/compiler/nativeGen/CFG.hs b/compiler/nativeGen/CFG.hs
index 44ddecd216..e1251b76f2 100644
--- a/compiler/nativeGen/CFG.hs
+++ b/compiler/nativeGen/CFG.hs
@@ -6,31 +6,40 @@
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE CPP #-}
+{-# LANGUAGE Rank2Types #-}
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE DataKinds #-}
module CFG
( CFG, CfgEdge(..), EdgeInfo(..), EdgeWeight(..)
, TransitionSource(..)
--Modify the CFG
- , addWeightEdge, addEdge, delEdge
+ , addWeightEdge, addEdge
+ , delEdge, delNode
, addNodesBetween, shortcutWeightMap
, reverseEdges, filterEdges
, addImmediateSuccessor
- , mkWeightInfo, adjustEdgeWeight
+ , mkWeightInfo, adjustEdgeWeight, setEdgeWeight
--Query the CFG
, infoEdgeList, edgeList
, getSuccessorEdges, getSuccessors
- , getSuccEdgesSorted, weightedEdgeList
+ , getSuccEdgesSorted
, getEdgeInfo
, getCfgNodes, hasNode
- , loopMembers
+
+ -- Loop Information
+ , loopMembers, loopLevels, loopInfo
--Construction/Misc
, getCfg, getCfgProc, pprEdgeWeights, sanityCheckCfg
--Find backedges and update their weight
- , optimizeCFG )
+ , optimizeCFG
+ , mkGlobalWeights
+
+ )
where
#include "HsVersions.h"
@@ -38,9 +47,8 @@ where
import GhcPrelude
import BlockId
-import Cmm ( RawCmmDecl, GenCmmDecl( .. ), CmmBlock, succ, g_entry
- , CmmGraph )
-import CmmNode
+import Cmm
+
import CmmUtils
import CmmSwitch
import Hoopl.Collections
@@ -50,10 +58,24 @@ import qualified Hoopl.Graph as G
import Util
import Digraph
+import Maybes
+
+import Unique
+import qualified Dominators as Dom
+import Data.IntMap.Strict (IntMap)
+import Data.IntSet (IntSet)
+
+import qualified Data.IntMap.Strict as IM
+import qualified Data.Map as M
+import qualified Data.IntSet as IS
+import qualified Data.Set as S
+import Data.Tree
+import Data.Bifunctor
import Outputable
-- DEBUGGING ONLY
--import Debug
+-- import Debug.Trace
--import OrdList
--import Debug.Trace
import PprCmm () -- For Outputable instances
@@ -61,17 +83,28 @@ import qualified DynFlags as D
import Data.List
--- import qualified Data.IntMap.Strict as M --TODO: LabelMap
+import Data.STRef.Strict
+import Control.Monad.ST
+
+import Data.Array.MArray
+import Data.Array.ST
+import Data.Array.IArray
+import Data.Array.Unsafe (unsafeFreeze)
+import Data.Array.Base (unsafeRead, unsafeWrite)
+
+import Control.Monad
+
+type Prob = Double
type Edge = (BlockId, BlockId)
type Edges = [Edge]
newtype EdgeWeight
- = EdgeWeight Int
- deriving (Eq,Ord,Enum,Num,Real,Integral)
+ = EdgeWeight { weightToDouble :: Double }
+ deriving (Eq,Ord,Enum,Num,Real,Fractional)
instance Outputable EdgeWeight where
- ppr (EdgeWeight w) = ppr w
+ ppr (EdgeWeight w) = doublePrec 5 w
type EdgeInfoMap edgeInfo = LabelMap (LabelMap edgeInfo)
@@ -108,15 +141,28 @@ instance Outputable CfgEdge where
= parens (ppr from1 <+> text "-(" <> ppr edgeInfo <> text ")->" <+> ppr to1)
-- | Can we trace back an edge to a specific Cmm Node
--- or has it been introduced for codegen. We use this to maintain
+-- or has it been introduced during assembly codegen. We use this to maintain
-- some information which would otherwise be lost during the
-- Cmm <-> asm transition.
-- See also Note [Inverting Conditional Branches]
data TransitionSource
- = CmmSource (CmmNode O C)
+ = CmmSource { trans_cmmNode :: (CmmNode O C)
+ , trans_info :: BranchInfo }
| AsmCodeGen
deriving (Eq)
+data BranchInfo = NoInfo -- ^ Unknown, but not heap or stack check.
+ | HeapStackCheck -- ^ Heap or stack check
+ deriving Eq
+
+instance Outputable BranchInfo where
+ ppr NoInfo = text "regular"
+ ppr HeapStackCheck = text "heap/stack"
+
+isHeapOrStackCheck :: TransitionSource -> Bool
+isHeapOrStackCheck (CmmSource { trans_info = HeapStackCheck}) = True
+isHeapOrStackCheck _ = False
+
-- | Information about edges
data EdgeInfo
= EdgeInfo
@@ -127,12 +173,10 @@ data EdgeInfo
instance Outputable EdgeInfo where
ppr edgeInfo = text "weight:" <+> ppr (edgeWeight edgeInfo)
--- Allow specialization
-{-# INLINEABLE mkWeightInfo #-}
-- | Convenience function, generate edge info based
-- on weight not originating from cmm.
-mkWeightInfo :: Integral n => n -> EdgeInfo
-mkWeightInfo = EdgeInfo AsmCodeGen . fromIntegral
+mkWeightInfo :: EdgeWeight -> EdgeInfo
+mkWeightInfo = EdgeInfo AsmCodeGen
-- | Adjust the weight between the blocks using the given function.
-- If there is no such edge returns the original map.
@@ -140,12 +184,25 @@ adjustEdgeWeight :: CFG -> (EdgeWeight -> EdgeWeight)
-> BlockId -> BlockId -> CFG
adjustEdgeWeight cfg f from to
| Just info <- getEdgeInfo from to cfg
- , weight <- edgeWeight info
- = addEdge from to (info { edgeWeight = f weight}) cfg
+ , !weight <- edgeWeight info
+ , !newWeight <- f weight
+ = addEdge from to (info { edgeWeight = newWeight}) cfg
+ | otherwise = cfg
+
+-- | Set the weight between the blocks to the given weight.
+-- If there is no such edge returns the original map.
+setEdgeWeight :: CFG -> EdgeWeight
+ -> BlockId -> BlockId -> CFG
+setEdgeWeight cfg !weight from to
+ | Just info <- getEdgeInfo from to cfg
+ = addEdge from to (info { edgeWeight = weight}) cfg
| otherwise = cfg
+
+
getCfgNodes :: CFG -> LabelSet
-getCfgNodes m = mapFoldMapWithKey (\k v -> setFromList (k:mapKeys v)) m
+getCfgNodes m =
+ mapFoldlWithKey (\s k toMap -> mapFoldlWithKey (\s k _ -> setInsert k s) (setInsert k s) toMap ) setEmpty m
hasNode :: CFG -> BlockId -> Bool
hasNode m node = mapMember node m || any (mapMember node) m
@@ -294,6 +351,11 @@ delEdge from to m =
remDest Nothing = Nothing
remDest (Just wm) = Just $ mapDelete to wm
+delNode :: BlockId -> CFG -> CFG
+delNode node cfg =
+ fmap (mapDelete node) -- < Edges to the node
+ (mapDelete node cfg) -- < Edges from the node
+
-- | Destinations from bid ordered by weight (descending)
getSuccEdgesSorted :: CFG -> BlockId -> [(BlockId,EdgeInfo)]
getSuccEdgesSorted m bid =
@@ -315,36 +377,54 @@ getEdgeInfo from to m
| otherwise
= Nothing
+getEdgeWeight :: CFG -> BlockId -> BlockId -> EdgeWeight
+getEdgeWeight cfg from to =
+ edgeWeight $ expectJust "Edgeweight for nonexisting block" $
+ getEdgeInfo from to cfg
+
+getTransitionSource :: BlockId -> BlockId -> CFG -> TransitionSource
+getTransitionSource from to cfg = transitionSource $ expectJust "Source info for nonexisting block" $
+ getEdgeInfo from to cfg
+
reverseEdges :: CFG -> CFG
-reverseEdges cfg = foldr add mapEmpty flatElems
+reverseEdges cfg = mapFoldlWithKey (\cfg from toMap -> go (addNode cfg from) from toMap) mapEmpty cfg
where
- elems = mapToList $ fmap mapToList cfg :: [(BlockId,[(BlockId,EdgeInfo)])]
- flatElems =
- concatMap (\(from,ws) -> map (\(to,info) -> (to,from,info)) ws ) elems
- add (to,from,info) m = addEdge to from info m
+ -- We preserve nodes without outgoing edges!
+ addNode :: CFG -> BlockId -> CFG
+ addNode cfg b = mapInsertWith mapUnion b mapEmpty cfg
+ go :: CFG -> BlockId -> (LabelMap EdgeInfo) -> CFG
+ go cfg from toMap = mapFoldlWithKey (\cfg to info -> addEdge to from info cfg) cfg toMap :: CFG
+
-- | Returns an unordered list of all edges with info
infoEdgeList :: CFG -> [CfgEdge]
infoEdgeList m =
- mapFoldMapWithKey
- (\from toMap ->
- map (\(to,info) -> CfgEdge from to info) (mapToList toMap))
- m
-
--- | Unordered list of edges with weight as Tuple (from,to,weight)
-weightedEdgeList :: CFG -> [(BlockId,BlockId,EdgeWeight)]
-weightedEdgeList m =
- mapFoldMapWithKey
- (\from toMap ->
- map (\(to,info) ->
- (from,to, edgeWeight info)) (mapToList toMap))
- m
- -- (\(from, tos) -> map (\(to,info) -> (from,to, edgeWeight info)) tos )
+ go (mapToList m) []
+ where
+ -- We avoid foldMap to avoid thunk buildup
+ go :: [(BlockId,LabelMap EdgeInfo)] -> [CfgEdge] -> [CfgEdge]
+ go [] acc = acc
+ go ((from,toMap):xs) acc
+ = go' xs from (mapToList toMap) acc
+ go' :: [(BlockId,LabelMap EdgeInfo)] -> BlockId -> [(BlockId,EdgeInfo)] -> [CfgEdge] -> [CfgEdge]
+ go' froms _ [] acc = go froms acc
+ go' froms from ((to,info):tos) acc
+ = go' froms from tos (CfgEdge from to info : acc)
-- | Returns an unordered list of all edges without weights
edgeList :: CFG -> [Edge]
edgeList m =
- mapFoldMapWithKey (\from toMap -> fmap (from,) (mapKeys toMap)) m
+ go (mapToList m) []
+ where
+ -- We avoid foldMap to avoid thunk buildup
+ go :: [(BlockId,LabelMap EdgeInfo)] -> [Edge] -> [Edge]
+ go [] acc = acc
+ go ((from,toMap):xs) acc
+ = go' xs from (mapKeys toMap) acc
+ go' :: [(BlockId,LabelMap EdgeInfo)] -> BlockId -> [BlockId] -> [Edge] -> [Edge]
+ go' froms _ [] acc = go froms acc
+ go' froms from (to:tos) acc
+ = go' froms from tos ((from,to) : acc)
-- | Get successors of a given node without edge weights.
getSuccessors :: CFG -> BlockId -> [BlockId]
@@ -355,8 +435,8 @@ getSuccessors m bid
pprEdgeWeights :: CFG -> SDoc
pprEdgeWeights m =
- let edges = sort $ weightedEdgeList m
- printEdge (from, to, weight)
+ let edges = sort $ infoEdgeList m :: [CfgEdge]
+ printEdge (CfgEdge from to (EdgeInfo { edgeWeight = weight }))
= text "\t" <> ppr from <+> text "->" <+> ppr to <>
text "[label=\"" <> ppr weight <> text "\",weight=\"" <>
ppr weight <> text "\"];\n"
@@ -365,7 +445,7 @@ pprEdgeWeights m =
--to immediately see it when it does.
printNode node
= text "\t" <> ppr node <> text ";\n"
- getEdgeNodes (from, to, _weight) = [from,to]
+ getEdgeNodes (CfgEdge from to _) = [from,to]
edgeNodes = setFromList $ concatMap getEdgeNodes edges :: LabelSet
nodes = filter (\n -> (not . setMember n) edgeNodes) . mapKeys $ mapFilter null m
in
@@ -378,8 +458,8 @@ pprEdgeWeights m =
updateEdgeWeight :: (EdgeWeight -> EdgeWeight) -> Edge -> CFG -> CFG
updateEdgeWeight f (from, to) cfg
| Just oldInfo <- getEdgeInfo from to cfg
- = let oldWeight = edgeWeight oldInfo
- newWeight = f oldWeight
+ = let !oldWeight = edgeWeight oldInfo
+ !newWeight = f oldWeight
in addEdge from to (oldInfo {edgeWeight = newWeight}) cfg
| otherwise
= panic "Trying to update invalid edge"
@@ -447,9 +527,7 @@ addNodesBetween m updates =
Should A or B be placed in front of C? The block layout algorithm
decides this based on which edge (A,C)/(B,C) is heavier. So we
- make a educated guess how often execution will transer control
- along each edge as well as how much we gain by placing eg A before
- C.
+ make an educated guess on which branch should be preferred.
We rank edges in this order:
* Unconditional Control Transfer - They will always
@@ -478,7 +556,6 @@ addNodesBetween m updates =
address. This reduces the chance that we return to the same
cache line further.
-
-}
-- | Generate weights for a Cmm proc based on some simple heuristics.
getCfgProc :: D.CfgWeights -> RawCmmDecl -> CFG
@@ -514,13 +591,24 @@ getCfg weights graph =
getBlockEdges block =
case branch of
CmmBranch dest -> [mkEdge dest uncondWeight]
- CmmCondBranch _c t f l
+ CmmCondBranch cond t f l
| l == Nothing ->
[mkEdge f condBranchWeight, mkEdge t condBranchWeight]
| l == Just True ->
[mkEdge f unlikelyCondWeight, mkEdge t likelyCondWeight]
| l == Just False ->
[mkEdge f likelyCondWeight, mkEdge t unlikelyCondWeight]
+ where
+ mkEdgeInfo = -- pprTrace "Info" (ppr branchInfo <+> ppr cond)
+ EdgeInfo (CmmSource branch branchInfo) . fromIntegral
+ mkEdge target weight = ((bid,target), mkEdgeInfo weight)
+ branchInfo =
+ foldRegsUsed
+ (panic "foldRegsDynFlags")
+ (\info r -> if r == SpLim || r == HpLim || r == BaseReg
+ then HeapStackCheck else info)
+ NoInfo cond
+
(CmmSwitch _e ids) ->
let switchTargets = switchTargetsToList ids
--Compiler performance hack - for very wide switches don't
@@ -538,7 +626,7 @@ getCfg weights graph =
map (\x -> ((bid,x),mkEdgeInfo 0)) $ G.successors other
where
bid = G.entryLabel block
- mkEdgeInfo = EdgeInfo (CmmSource branch) . fromIntegral
+ mkEdgeInfo = EdgeInfo (CmmSource branch NoInfo) . fromIntegral
mkEdge target weight = ((bid,target), mkEdgeInfo weight)
branch = lastNode block :: CmmNode O C
@@ -560,6 +648,11 @@ findBackEdges root cfg =
optimizeCFG :: D.CfgWeights -> RawCmmDecl -> CFG -> CFG
optimizeCFG _ (CmmData {}) cfg = cfg
optimizeCFG weights (CmmProc info _lab _live graph) cfg =
+ {-# SCC optimizeCFG #-}
+ -- pprTrace "Initial:" (pprEdgeWeights cfg) $
+ -- pprTrace "Initial:" (ppr $ mkGlobalWeights (g_entry graph) cfg) $
+
+ -- pprTrace "LoopInfo:" (ppr $ loopInfo cfg (g_entry graph)) $
favourFewerPreds .
penalizeInfoTables info .
increaseBackEdgeWeight (g_entry graph) $ cfg
@@ -589,12 +682,8 @@ optimizeCFG weights (CmmProc info _lab _live graph) cfg =
= weight - (fromIntegral $ D.infoTablePenalty weights)
| otherwise = weight
-
-{- Note [Optimize for Fallthrough]
-
--}
-- | If a block has two successors, favour the one with fewer
- -- predecessors. (As that one is more likely to become a fallthrough)
+ -- predecessors and/or the one allowing fall through.
favourFewerPreds :: CFG -> CFG
favourFewerPreds cfg =
let
@@ -611,16 +700,17 @@ optimizeCFG weights (CmmProc info _lab _live graph) cfg =
| preds1 == preds2 = ( 0, 0)
| otherwise = (-1, 1)
+ update :: CFG -> BlockId -> CFG
update cfg node
| [(s1,e1),(s2,e2)] <- getSuccessorEdges cfg node
- , w1 <- edgeWeight e1
- , w2 <- edgeWeight e2
+ , !w1 <- edgeWeight e1
+ , !w2 <- edgeWeight e2
--Only change the weights if there isn't already an ordering.
, w1 == w2
, (mod1,mod2) <- modifiers (predCount s1) (predCount s2)
= (\cfg' ->
(adjustEdgeWeight cfg' (+mod2) node s2))
- (adjustEdgeWeight cfg (+mod1) node s1)
+ (adjustEdgeWeight cfg (+mod1) node s1)
| otherwise
= cfg
in setFoldl update cfg nodes
@@ -629,13 +719,12 @@ optimizeCFG weights (CmmProc info _lab _live graph) cfg =
fallthroughTarget to (EdgeInfo source _weight)
| mapMember to info = False
| AsmCodeGen <- source = True
- | CmmSource (CmmBranch {}) <- source = True
- | CmmSource (CmmCondBranch {}) <- source = True
+ | CmmSource { trans_cmmNode = CmmBranch {} } <- source = True
+ | CmmSource { trans_cmmNode = CmmCondBranch {} } <- source = True
| otherwise = False
-- | Determine loop membership of blocks based on SCC analysis
--- Ideally we would replace this with a variant giving us loop
--- levels instead but the SCC code will do for now.
+-- This is faster but only gives yes/no answers.
loopMembers :: CFG -> LabelMap Bool
loopMembers cfg =
foldl' (flip setLevel) mapEmpty sccs
@@ -649,3 +738,534 @@ loopMembers cfg =
setLevel :: SCC BlockId -> LabelMap Bool -> LabelMap Bool
setLevel (AcyclicSCC bid) m = mapInsert bid False m
setLevel (CyclicSCC bids) m = foldl' (\m k -> mapInsert k True m) m bids
+
+loopLevels :: CFG -> BlockId -> LabelMap Int
+loopLevels cfg root = liLevels $ loopInfo cfg root
+
+data LoopInfo = LoopInfo
+ { liBackEdges :: [(Edge)] -- ^ List of back edges
+ , liLevels :: LabelMap Int -- ^ BlockId -> LoopLevel mapping
+ , liLoops :: [(Edge, LabelSet)] -- ^ (backEdge, loopBody), body includes header
+ }
+
+instance Outputable LoopInfo where
+ ppr (LoopInfo _ _lvls loops) =
+ text "Loops:(backEdge, bodyNodes)" $$
+ (vcat $ map ppr loops)
+
+-- | Determine loop membership of blocks based on Dominator analysis.
+-- This is slower but gives loop levels instead of just loop membership.
+-- However it only detects natural loops. Irreducible control flow is not
+-- recognized even if it loops. But that is rare enough that we don't have
+-- to care about that special case.
+loopInfo :: CFG -> BlockId -> LoopInfo
+loopInfo cfg root = LoopInfo { liBackEdges = backEdges
+ , liLevels = mapFromList loopCounts
+ , liLoops = loopBodies }
+ where
+ revCfg = reverseEdges cfg
+ graph = fmap (setFromList . mapKeys ) cfg :: LabelMap LabelSet
+
+ --TODO - This should be a no op: Export constructors? Use unsafeCoerce? ...
+ rooted = ( fromBlockId root
+ , toIntMap $ fmap toIntSet graph) :: (Int, IntMap IntSet)
+ -- rooted = unsafeCoerce (root, graph)
+ tree = fmap toBlockId $ Dom.domTree rooted :: Tree BlockId
+
+ -- Map from Nodes to their dominators
+ domMap :: LabelMap LabelSet
+ domMap = mkDomMap tree
+
+ edges = edgeList cfg :: [(BlockId, BlockId)]
+ -- We can't recompute this from the edges; there might be blocks not connected via edges.
+ nodes = getCfgNodes cfg :: LabelSet
+
+ -- identify back edges
+ isBackEdge (from,to)
+ | Just doms <- mapLookup from domMap
+ , setMember to doms
+ = True
+ | otherwise = False
+
+ -- determine the loop body for a back edge
+ findBody edge@(tail, head)
+ = ( edge, setInsert head $ go (setSingleton tail) (setSingleton tail) )
+ where
+ -- The reversed cfg makes it easier to look up predecessors
+ cfg' = delNode head revCfg
+ go :: LabelSet -> LabelSet -> LabelSet
+ go found current
+ | setNull current = found
+ | otherwise = go (setUnion newSuccessors found)
+ newSuccessors
+ where
+ newSuccessors = setFilter (\n -> not $ setMember n found) successors :: LabelSet
+ successors = setFromList $ concatMap
+ (getSuccessors cfg')
+ (setElems current) :: LabelSet
+
+ backEdges = filter isBackEdge edges
+ loopBodies = map findBody backEdges :: [(Edge, LabelSet)]
+
+ -- Block b is part of n loop bodies => loop nest level of n
+ loopCounts =
+ let bodies = map (first snd) loopBodies -- [(Header, Body)]
+ loopCount n = length $ nub . map fst . filter (setMember n . snd) $ bodies
+ in map (\n -> (n, loopCount n)) $ setElems nodes :: [(BlockId, Int)]
+
+ toIntSet :: LabelSet -> IntSet
+ toIntSet s = IS.fromList . map fromBlockId . setElems $ s
+ toIntMap :: LabelMap a -> IntMap a
+ toIntMap m = IM.fromList $ map (\(x,y) -> (fromBlockId x,y)) $ mapToList m
+
+ mkDomMap :: Tree BlockId -> LabelMap LabelSet
+ mkDomMap root = mapFromList $ go setEmpty root
+ where
+ go :: LabelSet -> Tree BlockId -> [(Label,LabelSet)]
+ go parents (Node lbl [])
+ = [(lbl, parents)]
+ go parents (Node _ leaves)
+ = let nodes = map rootLabel leaves
+ entries = map (\x -> (x,parents)) nodes
+ in entries ++ concatMap
+ (\n -> go (setInsert (rootLabel n) parents) n)
+ leaves
+
+ fromBlockId :: BlockId -> Int
+ fromBlockId = getKey . getUnique
+
+ toBlockId :: Int -> BlockId
+ toBlockId = mkBlockId . mkUniqueGrimily
+
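A small self-contained demonstration of the back edge criterion used above, assuming the dominator map has already been computed: an edge (from, to) is a back edge iff 'to' dominates 'from'.

    import qualified Data.Map as M
    import qualified Data.Set as S

    type Block = Int

    isBackEdgeToy :: M.Map Block (S.Set Block) -> (Block, Block) -> Bool
    isBackEdgeToy doms (from, to) =
      maybe False (S.member to) (M.lookup from doms)

    main :: IO ()
    main = do
      -- Loop 0 -> 1 -> 2 -> 1: node 1 dominates node 2.
      let doms = M.fromList [ (0, S.empty)
                            , (1, S.fromList [0])
                            , (2, S.fromList [0, 1]) ]
      print $ isBackEdgeToy doms (2, 1)  -- True: 1 dominates 2, back edge
      print $ isBackEdgeToy doms (1, 2)  -- False: forward edge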
+-- We make the CFG a Hoopl Graph, so we can reuse revPostOrder.
+newtype BlockNode (e :: Extensibility) (x :: Extensibility) = BN (BlockId,[BlockId])
+
+instance G.NonLocal (BlockNode) where
+ entryLabel (BN (lbl,_)) = lbl
+ successors (BN (_,succs)) = succs
+
+revPostorderFrom :: CFG -> BlockId -> [BlockId]
+revPostorderFrom cfg root =
+ map fromNode $ G.revPostorderFrom hooplGraph root
+ where
+ nodes = getCfgNodes cfg
+ hooplGraph = setFoldl (\m n -> mapInsert n (toNode n) m) mapEmpty nodes
+
+ fromNode :: BlockNode C C -> BlockId
+ fromNode (BN x) = fst x
+
+ toNode :: BlockId -> BlockNode C C
+ toNode bid =
+ BN (bid,getSuccessors cfg $ bid)
+
+
+-- | We take in a CFG which has on its edges weights which are
+-- relative only to other edges originating from the same node.
+--
+-- We return a CFG for which each edge represents a GLOBAL weight.
+-- This means edge weights are comparable across the whole graph.
+--
+-- For irreducible control flow results might be imprecise, otherwise they
+-- are reliable.
+--
+-- The algorithm is based on the Paper
+-- "Static Branch Prediction and Program Profile Analysis" by Y Wu, JR Larus
+-- The only big change is that we go over the nodes in the body of loops in
+-- reverse post order, which is required for diamond control flow to work properly.
+--
+-- We also apply a few prediction heuristics (based on the same paper)
+
+{-# SCC mkGlobalWeights #-}
+mkGlobalWeights :: BlockId -> CFG -> (LabelMap Double, LabelMap (LabelMap Double))
+mkGlobalWeights root localCfg
+ | null localCfg = panic "Error - Empty CFG"
+ | otherwise
+ = --pprTrace "revOrder" (ppr revOrder) $
+ -- undefined --propagate (mapSingleton root 1) (revOrder)
+ (blockFreqs', edgeFreqs')
+ where
+ -- Calculate fixpoints
+ (blockFreqs, edgeFreqs) = calcFreqs nodeProbs backEdges' bodies' revOrder'
+ blockFreqs' = mapFromList $ map (first fromVertex) (assocs blockFreqs) :: LabelMap Double
+ edgeFreqs' = fmap fromVertexMap $ fromVertexMap edgeFreqs
+
+ fromVertexMap :: IM.IntMap x -> LabelMap x
+ fromVertexMap m = mapFromList . map (first fromVertex) $ IM.toList m
+
+ revOrder = revPostorderFrom localCfg root :: [BlockId]
+ loopinfo@(LoopInfo backedges _levels bodies) = loopInfo localCfg root
+
+ revOrder' = map toVertex revOrder
+ backEdges' = map (bimap toVertex toVertex) backedges
+ bodies' = map calcBody bodies
+
+ estimatedCfg = staticBranchPrediction root loopinfo localCfg
+ -- Normalize the weights to probabilities and apply heuristics
+ nodeProbs = cfgEdgeProbabilities estimatedCfg toVertex
+
+ -- By mapping vertices to numbers in reverse post order we can bring any subset into reverse post
+ -- order simply by sorting.
+ -- TODO: The sort is redundant if we can guarantee that setElems returns elements ascending
+ calcBody (backedge, blocks) =
+ (toVertex $ snd backedge, sort . map toVertex $ (setElems blocks))
+
+ vertexMapping = mapFromList $ zip revOrder [0..] :: LabelMap Int
+ blockMapping = listArray (0,mapSize vertexMapping - 1) revOrder :: Array Int BlockId
+    -- Map from blockId to indices starting at zero
+ toVertex :: BlockId -> Int
+ toVertex blockId = expectJust "mkGlobalWeights" $ mapLookup blockId vertexMapping
+    -- Map from indices starting at zero to blockIds
+ fromVertex :: Int -> BlockId
+ fromVertex vertex = blockMapping ! vertex
+
+{- Note [Static Branch Prediction]
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The work here has been based on the paper
+"Static Branch Prediction and Program Profile Analysis" by Y Wu, JR Larus.
+
+The primary difference is that we do not apply any of the heuristics
+if we branch on the result of a heap check.
+The reason is simple: such branches look like loops in the control flow
+graph but are usually never entered, and if they are, at most once.
+
+Currently implemented is a heuristic to predict that we do not exit
+loops (lehPredicts) and one to predict that backedges are more likely
+than any other edge.
+
+The back edge case is special as it supersedes any other heuristic if it
+applies.
+
+Do NOT rely solely on nofib results for benchmarking this. I recommend at least
+comparing megaparsec and container benchmarks. Nofib does not seem to have
+many instances of "loopy" Cmm where these make a difference.
+
+TODO:
+* The paper contains more heuristics which should be implemented.
+* If we turn the likelihood of if/else branches into a probability
+  instead of true/false we could implement this as a Cmm pass.
+  + The complete Cmm code still exists and can be accessed by the heuristics.
+  + There is no chance of register allocation/codegen inserting branches/blocks,
+    making the TransitionSource info wrong.
+  + Potential to use this information in Cmm passes.
+  - Requires refactoring of all the code relying on the binary nature of likelihood.
+  - Requires refactoring `loopInfo` to work on both Cmm graphs and the backend CFG.
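+
+As a worked example of the update rule used in applyHeuristic: with equal
+current probabilities s1 = s2 = 0.5 and a heuristic predicting s1 with
+likelihood 0.75, the normalization factor is d = 0.5*0.75 + 0.5*0.25 = 0.5,
+giving updated probabilities s1' = 0.5*0.75/0.5 = 0.75 and s2' = 0.25.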
+-}
+
+-- | Combination of target node id and information about the branch
+-- we are looking at.
+type TargetNodeInfo = (BlockId, EdgeInfo)
+
+
+-- | Update branch weights based on certain heuristics.
+-- See Note [Static Branch Prediction]
+-- TODO: This should be combined with optimizeCFG
+{-# SCC staticBranchPrediction #-}
+staticBranchPrediction :: BlockId -> LoopInfo -> CFG -> CFG
+staticBranchPrediction _root (LoopInfo l_backEdges loopLevels l_loops) cfg =
+ -- pprTrace "staticEstimatesOn" (ppr (cfg)) $
+ setFoldl update cfg nodes
+ where
+ nodes = getCfgNodes cfg
+ backedges = S.fromList $ l_backEdges
+ -- Loops keyed by their back edge
+ loops = M.fromList $ l_loops :: M.Map Edge LabelSet
+ loopHeads = S.fromList $ map snd $ M.keys loops
+
+ update :: CFG -> BlockId -> CFG
+ update cfg node
+ -- No successors, nothing to do.
+ | null successors = cfg
+
+ -- Mix of backedges and others:
+ -- Always predict the backedges.
+ | not (null m) && length m < length successors
+ -- Heap/Stack checks "loop", but only once.
+ -- So we simply exclude any case involving them.
+ , not $ any (isHeapOrStackCheck . transitionSource . snd) successors
+ = let loopChance = repeat $! pred_LBH / (fromIntegral $ length m)
+ exitChance = repeat $! (1 - pred_LBH) / fromIntegral (length not_m)
+ updates = zip (map fst m) loopChance ++ zip (map fst not_m) exitChance
+ in -- pprTrace "mix" (ppr (node,successors)) $
+ foldl' (\cfg (to,weight) -> setEdgeWeight cfg weight node to) cfg updates
+
+ -- For (regular) non-binary branches we keep the weights from the STG -> Cmm translation.
+ | length successors /= 2
+ = cfg
+
+ -- Only backedges - no need to adjust
+ | length m > 0
+ = cfg
+
+      -- A regular binary branch, we can plug additional predictors in here.
+ | [(s1,s1_info),(s2,s2_info)] <- successors
+ , not $ any (isHeapOrStackCheck . transitionSource . snd) successors
+ = -- Normalize weights to total of 1
+ let !w1 = max (edgeWeight s1_info) (0)
+ !w2 = max (edgeWeight s2_info) (0)
+            -- If both weights are <= 0 we set both to 0.5
+ normalizeWeight w = if w1 + w2 == 0 then 0.5 else w/(w1+w2)
+ !cfg' = setEdgeWeight cfg (normalizeWeight w1) node s1
+ !cfg'' = setEdgeWeight cfg' (normalizeWeight w2) node s2
+
+ -- Figure out which heuristics apply to these successors
+ heuristics = map ($ ((s1,s1_info),(s2,s2_info)))
+ [lehPredicts, phPredicts, ohPredicts, ghPredicts, lhhPredicts, chPredicts
+ , shPredicts, rhPredicts]
+            -- Apply the result of a heuristic. The argument is the likelihood
+            -- predicted for s1.
+ applyHeuristic :: CFG -> Maybe Prob -> CFG
+ applyHeuristic cfg Nothing = cfg
+ applyHeuristic cfg (Just (s1_pred :: Double))
+ | s1_old == 0 || s2_old == 0 ||
+ isHeapOrStackCheck (transitionSource s1_info) ||
+ isHeapOrStackCheck (transitionSource s2_info)
+ = cfg
+ | otherwise =
+ let -- Predictions from heuristic
+ s1_prob = EdgeWeight s1_pred :: EdgeWeight
+ s2_prob = 1.0 - s1_prob
+ -- Update
+ d = (s1_old * s1_prob) + (s2_old * s2_prob) :: EdgeWeight
+ s1_prob' = s1_old * s1_prob / d
+ !s2_prob' = s2_old * s2_prob / d
+ !cfg_s1 = setEdgeWeight cfg s1_prob' node s1
+ in -- pprTrace "Applying heuristic!" (ppr (node,s1,s2) $$ ppr (s1_prob', s2_prob')) $
+ setEdgeWeight cfg_s1 s2_prob' node s2
+ where
+ -- Old weights
+ s1_old = getEdgeWeight cfg node s1
+ s2_old = getEdgeWeight cfg node s2
+
+ in
+ -- pprTraceIt "RegularCfgResult" $
+ foldl' applyHeuristic cfg'' heuristics
+
+ -- Branch on heap/stack check
+ | otherwise = cfg
+
+ where
+        -- Chance that a loop's backedge is taken.
+ pred_LBH = 0.875
+ -- successors
+ successors = getSuccessorEdges cfg node
+ -- backedges
+ (m,not_m) = partition (\succ -> S.member (node, fst succ) backedges) successors
+
+        -- Heuristics return Nothing if they don't say anything about this branch,
+        -- or Just prob_s1 where prob_s1 is the likelihood of s1 being the
+        -- taken branch. s1 is the successor taken in the true case.
+
+ -- Loop exit heuristic.
+ -- We are unlikely to leave a loop unless it's to enter another one.
+ pred_LEH = 0.75
+        -- If and only if no successor is a loop header,
+        -- then we will likely not exit the current loop body.
+ lehPredicts :: (TargetNodeInfo,TargetNodeInfo) -> Maybe Prob
+ lehPredicts ((s1,_s1_info),(s2,_s2_info))
+ | S.member s1 loopHeads || S.member s2 loopHeads
+ = Nothing
+
+ | otherwise
+ = --pprTrace "lehPredict:" (ppr $ compare s1Level s2Level) $
+ case compare s1Level s2Level of
+ EQ -> Nothing
+ LT -> Just (1-pred_LEH) --s1 exits to a shallower loop level (exits loop)
+ GT -> Just (pred_LEH) --s1 exits to a deeper loop level
+ where
+ s1Level = mapLookup s1 loopLevels
+ s2Level = mapLookup s2 loopLevels
+
+        -- A comparison against a constant is unlikely to result in equality.
+ ohPredicts (s1,_s2)
+ | CmmSource { trans_cmmNode = src1 } <- getTransitionSource node (fst s1) cfg
+ , CmmCondBranch cond ltrue _lfalse likely <- src1
+ , likely == Nothing
+ , CmmMachOp mop args <- cond
+ , MO_Eq {} <- mop
+ , not (null [x | x@CmmLit{} <- args])
+ = if fst s1 == ltrue then Just 0.3 else Just 0.7
+
+ | otherwise
+ = Nothing
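+
+        -- For example (illustrative): for a branch on @x == 42@ with no
+        -- existing likelihood annotation, ohPredicts predicts the equality
+        -- to fail, i.e. 0.3 for the true successor and 0.7 for the false one.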
+
+ -- TODO: These are all the other heuristics from the paper.
+ -- Not all will apply, for now we just stub them out as Nothing.
+ phPredicts = const Nothing
+ ghPredicts = const Nothing
+ lhhPredicts = const Nothing
+ chPredicts = const Nothing
+ shPredicts = const Nothing
+ rhPredicts = const Nothing
+
+-- We normalize all edge weights as probabilities between 0 and 1.
+-- Ignoring rounding errors, the weights of all outgoing edges sum up to 1.
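+--
+-- For example (illustrative): outgoing edges with weights 30 and 10 are
+-- normalized to probabilities 0.75 and 0.25. If all weights are
+-- non-positive we fall back to a uniform distribution, e.g. 0.5 each
+-- for two successors.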
+cfgEdgeProbabilities :: CFG -> (BlockId -> Int) -> IM.IntMap (IM.IntMap Prob)
+cfgEdgeProbabilities cfg toVertex
+ = mapFoldlWithKey foldEdges IM.empty cfg
+ where
+ foldEdges = (\m from toMap -> IM.insert (toVertex from) (normalize toMap) m)
+
+ normalize :: (LabelMap EdgeInfo) -> (IM.IntMap Prob)
+ normalize weightMap
+ | edgeCount <= 1 = mapFoldlWithKey (\m k _ -> IM.insert (toVertex k) 1.0 m) IM.empty weightMap
+ | otherwise = mapFoldlWithKey (\m k _ -> IM.insert (toVertex k) (normalWeight k) m) IM.empty weightMap
+ where
+ edgeCount = mapSize weightMap
+ -- Negative weights are generally allowed but are mapped to zero.
+ -- We then check if there is at least one non-zero edge and if not
+ -- assign uniform weights to all branches.
+ minWeight = 0 :: Prob
+ weightMap' = fmap (\w -> max (weightToDouble . edgeWeight $ w) minWeight) weightMap
+ totalWeight = sum weightMap'
+
+ normalWeight :: BlockId -> Prob
+ normalWeight bid
+ | totalWeight == 0
+ = 1.0 / fromIntegral edgeCount
+ | Just w <- mapLookup bid weightMap'
+ = w/totalWeight
+ | otherwise = panic "impossible"
+
+-- This is the fixpoint algorithm from
+-- "Static Branch Prediction and Program Profile Analysis" by Y Wu, JR Larus
+-- The adaptation to Haskell is my own.
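+--
+-- As a small worked example (illustrative): a loop header entered once
+-- whose backedge is taken with probability 0.875 ends up with frequency
+-- 1 / (1 - 0.875) = 8, i.e. we expect the loop body to run about eight
+-- times per entry.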
+calcFreqs :: IM.IntMap (IM.IntMap Prob) -> [(Int,Int)] -> [(Int, [Int])] -> [Int]
+ -> (Array Int Double, IM.IntMap (IM.IntMap Prob))
+calcFreqs graph backEdges loops revPostOrder = runST $ do
+ visitedNodes <- newArray (0,nodeCount-1) False :: ST s (STUArray s Int Bool)
+ blockFreqs <- newArray (0,nodeCount-1) 0.0 :: ST s (STUArray s Int Double)
+ edgeProbs <- newSTRef graph
+ edgeBackProbs <- newSTRef graph
+
+ -- let traceArray a = do
+ -- vs <- forM [0..nodeCount-1] $ \i -> readArray a i >>= (\v -> return (i,v))
+ -- trace ("array: " ++ show vs) $ return ()
+
+    let -- See #1600: we need to inline, or unboxing makes performance worse.
+ -- {-# INLINE getFreq #-}
+ {-# INLINE visited #-}
+ visited b = unsafeRead visitedNodes b
+ getFreq b = unsafeRead blockFreqs b
+ -- setFreq :: forall s. Int -> Double -> ST s ()
+ setFreq b f = unsafeWrite blockFreqs b f
+ -- setVisited :: forall s. Node -> ST s ()
+ setVisited b = unsafeWrite visitedNodes b True
+ -- Frequency/probability that edge is taken.
+ getProb' arr b1 b2 = readSTRef arr >>=
+ (\graph ->
+ return .
+ fromMaybe (error "getFreq 1") .
+ IM.lookup b2 .
+ fromMaybe (error "getFreq 2") $
+ (IM.lookup b1 graph)
+ )
+ setProb' arr b1 b2 prob = do
+ g <- readSTRef arr
+ let !m = fromMaybe (error "Foo") $ IM.lookup b1 g
+ !m' = IM.insert b2 prob m
+ writeSTRef arr $! (IM.insert b1 m' g)
+
+ getEdgeFreq b1 b2 = getProb' edgeProbs b1 b2
+ setEdgeFreq b1 b2 = setProb' edgeProbs b1 b2
+ getProb b1 b2 = fromMaybe (error "getProb") $ do
+ m' <- IM.lookup b1 graph
+ IM.lookup b2 m'
+
+ getBackProb b1 b2 = getProb' edgeBackProbs b1 b2
+ setBackProb b1 b2 = setProb' edgeBackProbs b1 b2
+
+
+ let -- calcOutFreqs :: Node -> ST s ()
+ calcOutFreqs bhead block = do
+ !f <- getFreq block
+ forM (successors block) $ \bi -> do
+ let !prob = getProb block bi
+ let !succFreq = f * prob
+ setEdgeFreq block bi succFreq
+ -- traceM $ "SetOut: " ++ show (block, bi, f, prob, succFreq)
+ when (bi == bhead) $ setBackProb block bi succFreq
+
+
+ let propFreq block head = do
+ -- traceM ("prop:" ++ show (block,head))
+ -- traceShowM block
+
+ !v <- visited block
+ if v then
+            return () -- Don't look at nodes twice
+ else if block == head then
+ setFreq block 1.0 -- Loop header frequency is always 1
+ else do
+ let preds = IS.elems $ predecessors block
+ irreducible <- (fmap or) $ forM preds $ \bp -> do
+ !bp_visited <- visited bp
+ let bp_backedge = isBackEdge bp block
+ return (not bp_visited && not bp_backedge)
+
+ if irreducible
+              then return () -- Rare, we don't care
+ else do
+ setFreq block 0
+ !cycleProb <- sum <$> (forM preds $ \pred -> do
+ if isBackEdge pred block
+ then
+ getBackProb pred block
+ else do
+ !f <- getFreq block
+ !prob <- getEdgeFreq pred block
+ setFreq block $! f + prob
+ return 0)
+ -- traceM $ "cycleProb:" ++ show cycleProb
+              let limit = 1 - 1/512 -- Paper uses 1 - epsilon, but this works.
+                                    -- Determines how large likelihoods in loops can grow.
+              !cycleProb <- return $ min cycleProb limit
+ -- traceM $ "cycleProb:" ++ show cycleProb
+
+ !f <- getFreq block
+ setFreq block (f / (1.0 - cycleProb))
+
+ setVisited block
+ calcOutFreqs head block
+
+ -- Loops, by nesting, inner to outer
+ forM_ loops $ \(head, body) -> do
+ forM_ [0 .. nodeCount - 1] (\i -> unsafeWrite visitedNodes i True) -- Mark all nodes as visited.
+ forM_ body (\i -> unsafeWrite visitedNodes i False) -- Mark all blocks reachable from head as not visited
+ forM_ body $ \block -> propFreq block head
+
+ -- After dealing with all loops, deal with non-looping parts of the CFG
+ forM_ [0 .. nodeCount - 1] (\i -> unsafeWrite visitedNodes i False) -- Everything in revPostOrder is reachable
+ forM_ revPostOrder $ \block -> propFreq block (head revPostOrder)
+
+ -- trace ("Final freqs:") $ return ()
+ -- let freqString = pprFreqs freqs
+ -- trace (unlines freqString) $ return ()
+ -- trace (pprFre) $ return ()
+ graph' <- readSTRef edgeProbs
+ freqs' <- unsafeFreeze blockFreqs
+
+ return (freqs', graph')
+ where
+ predecessors :: Int -> IS.IntSet
+ predecessors b = fromMaybe IS.empty $ IM.lookup b revGraph
+    successors b = fromMaybe (lookupError "succ" b graph) $ IM.keys <$> IM.lookup b graph
+ lookupError s b g = pprPanic ("Lookup error " ++ s) $
+ ( text "node" <+> ppr b $$
+ text "graph" <+>
+ vcat (map (\(k,m) -> ppr (k,m :: IM.IntMap Double)) $ IM.toList g)
+ )
+
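+    -- Count the graph's nodes: every key, plus edge targets that have no
+    -- entry of their own (e.g. exit blocks without successors).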
+ nodeCount = IM.foldl' (\count toMap -> IM.foldlWithKey' countTargets count toMap) (IM.size graph) graph
+ where
+ countTargets = (\count k _ -> countNode k + count )
+ countNode n = if IM.member n graph then 0 else 1
+
+ isBackEdge from to = S.member (from,to) backEdgeSet
+ backEdgeSet = S.fromList backEdges
+
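+    -- Reversed adjacency map: for every edge from -> to, record from as a
+    -- predecessor of to.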
+ revGraph :: IntMap IntSet
+ revGraph = IM.foldlWithKey' (\m from toMap -> addEdges m from toMap) IM.empty graph
+ where
+ addEdges m0 from toMap = IM.foldlWithKey' (\m k _ -> addEdge m from k) m0 toMap
+ addEdge m0 from to = IM.insertWith IS.union to (IS.singleton from) m0
diff --git a/compiler/nativeGen/RegAlloc/Graph/SpillCost.hs b/compiler/nativeGen/RegAlloc/Graph/SpillCost.hs
index 9c6e24d320..52f590948a 100644
--- a/compiler/nativeGen/RegAlloc/Graph/SpillCost.hs
+++ b/compiler/nativeGen/RegAlloc/Graph/SpillCost.hs
@@ -1,4 +1,4 @@
-{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE ScopedTypeVariables, GADTs, BangPatterns #-}
module RegAlloc.Graph.SpillCost (
SpillCostRecord,
plusSpillCostRecord,
@@ -23,6 +23,7 @@ import Reg
import GraphBase
import Hoopl.Collections (mapLookup)
+import Hoopl.Label
import Cmm
import UniqFM
import UniqSet
@@ -49,9 +50,6 @@ type SpillCostRecord
type SpillCostInfo
= UniqFM SpillCostRecord
--- | Block membership in a loop
-type LoopMember = Bool
-
type SpillCostState = State (UniqFM SpillCostRecord) ()
-- | An empty map of spill costs.
@@ -88,45 +86,49 @@ slurpSpillCostInfo platform cfg cmm
where
countCmm CmmData{} = return ()
countCmm (CmmProc info _ _ sccs)
- = mapM_ (countBlock info)
+ = mapM_ (countBlock info freqMap)
$ flattenSCCs sccs
+ where
+ LiveInfo _ entries _ _ = info
+ freqMap = (fst . mkGlobalWeights (head entries)) <$> cfg
-- Lookup the regs that are live on entry to this block in
-- the info table from the CmmProc.
- countBlock info (BasicBlock blockId instrs)
+ countBlock info freqMap (BasicBlock blockId instrs)
| LiveInfo _ _ blockLive _ <- info
, Just rsLiveEntry <- mapLookup blockId blockLive
, rsLiveEntry_virt <- takeVirtuals rsLiveEntry
- = countLIs (loopMember blockId) rsLiveEntry_virt instrs
+ = countLIs (ceiling $ blockFreq freqMap blockId) rsLiveEntry_virt instrs
| otherwise
= error "RegAlloc.SpillCost.slurpSpillCostInfo: bad block"
- countLIs :: LoopMember -> UniqSet VirtualReg -> [LiveInstr instr] -> SpillCostState
+
+ countLIs :: Int -> UniqSet VirtualReg -> [LiveInstr instr] -> SpillCostState
countLIs _ _ []
= return ()
-- Skip over comment and delta pseudo instrs.
- countLIs inLoop rsLive (LiveInstr instr Nothing : lis)
+ countLIs scale rsLive (LiveInstr instr Nothing : lis)
| isMetaInstr instr
- = countLIs inLoop rsLive lis
+ = countLIs scale rsLive lis
| otherwise
= pprPanic "RegSpillCost.slurpSpillCostInfo"
$ text "no liveness information on instruction " <> ppr instr
- countLIs inLoop rsLiveEntry (LiveInstr instr (Just live) : lis)
+ countLIs scale rsLiveEntry (LiveInstr instr (Just live) : lis)
= do
-- Increment the lifetime counts for regs live on entry to this instr.
- mapM_ (incLifetime (loopCount inLoop)) $ nonDetEltsUniqSet rsLiveEntry
+ mapM_ incLifetime $ nonDetEltsUniqSet rsLiveEntry
-- This is non-deterministic but we do not
-- currently support deterministic code-generation.
-- See Note [Unique Determinism and code generation]
-- Increment counts for what regs were read/written from.
let (RU read written) = regUsageOfInstr platform instr
- mapM_ (incUses (loopCount inLoop)) $ catMaybes $ map takeVirtualReg $ nub read
- mapM_ (incDefs (loopCount inLoop)) $ catMaybes $ map takeVirtualReg $ nub written
+ mapM_ (incUses scale) $ catMaybes $ map takeVirtualReg $ nub read
+ mapM_ (incDefs scale) $ catMaybes $ map takeVirtualReg $ nub written
-- Compute liveness for entry to next instruction.
let liveDieRead_virt = takeVirtuals (liveDieRead live)
@@ -140,21 +142,18 @@ slurpSpillCostInfo platform cfg cmm
= (rsLiveAcross `unionUniqSets` liveBorn_virt)
`minusUniqSet` liveDieWrite_virt
- countLIs inLoop rsLiveNext lis
+ countLIs scale rsLiveNext lis
- loopCount inLoop
- | inLoop = 10
- | otherwise = 1
incDefs count reg = modify $ \s -> addToUFM_C plusSpillCostRecord s reg (reg, count, 0, 0)
incUses count reg = modify $ \s -> addToUFM_C plusSpillCostRecord s reg (reg, 0, count, 0)
- incLifetime count reg = modify $ \s -> addToUFM_C plusSpillCostRecord s reg (reg, 0, 0, count)
+ incLifetime reg = modify $ \s -> addToUFM_C plusSpillCostRecord s reg (reg, 0, 0, 1)
- loopBlocks = CFG.loopMembers <$> cfg
- loopMember bid
- | Just isMember <- join (mapLookup bid <$> loopBlocks)
- = isMember
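+    -- Turn the estimated global block frequency into a scale factor for the
+    -- spill costs; e.g. a frequency of 0.01 gives max 1.0 (10000 * 0.01) = 100.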
+ blockFreq :: Maybe (LabelMap Double) -> Label -> Double
+ blockFreq freqs bid
+ | Just freq <- join (mapLookup bid <$> freqs)
+ = max 1.0 (10000 * freq)
| otherwise
- = False
+ = 1.0 -- Only if no cfg given
-- | Take all the virtual registers from this set.
takeVirtuals :: UniqSet Reg -> UniqSet VirtualReg
@@ -215,31 +214,39 @@ chooseSpill info graph
-- Without live range splitting, it's better to spill from the outside
-- in, so set the cost of very long live ranges to zero
--
-{-
-spillCost_chaitin
- :: SpillCostInfo
- -> Graph Reg RegClass Reg
- -> Reg
- -> Float
-spillCost_chaitin info graph reg
- -- Spilling a live range that only lives for 1 instruction
- -- isn't going to help us at all - and we definitely want to avoid
- -- trying to re-spill previously inserted spill code.
- | lifetime <= 1 = 1/0
-
- -- It's unlikely that we'll find a reg for a live range this long
- -- better to spill it straight up and not risk trying to keep it around
- -- and have to go through the build/color cycle again.
- | lifetime > allocatableRegsInClass (regClass reg) * 10
- = 0
+-- spillCost_chaitin
+-- :: SpillCostInfo
+-- -> Graph VirtualReg RegClass RealReg
+-- -> VirtualReg
+-- -> Float
+
+-- spillCost_chaitin info graph reg
+-- -- Spilling a live range that only lives for 1 instruction
+-- -- isn't going to help us at all - and we definitely want to avoid
+-- -- trying to re-spill previously inserted spill code.
+-- | lifetime <= 1 = 1/0
+
+-- -- It's unlikely that we'll find a reg for a live range this long
+-- -- better to spill it straight up and not risk trying to keep it around
+-- -- and have to go through the build/color cycle again.
+
+-- -- To facilitate this we scale down the spill cost of long ranges.
+-- -- This makes sure long ranges are still spilled first.
+-- -- But this way spill cost remains relevant for long live
+-- -- ranges.
+-- | lifetime >= 128
+-- = (spillCost / conflicts) / 10.0
+
+
+-- -- Otherwise revert to chaitin's regular cost function.
+-- | otherwise = (spillCost / conflicts)
+-- where
+-- !spillCost = fromIntegral (uses + defs) :: Float
+-- conflicts = fromIntegral (nodeDegree classOfVirtualReg graph reg)
+-- (_, defs, uses, lifetime)
+-- = fromMaybe (reg, 0, 0, 0) $ lookupUFM info reg
- -- Otherwise revert to chaitin's regular cost function.
- | otherwise = fromIntegral (uses + defs)
- / fromIntegral (nodeDegree graph reg)
- where (_, defs, uses, lifetime)
- = fromMaybe (reg, 0, 0, 0) $ lookupUFM info reg
--}
-- Just spill the longest live range.
spillCost_length
diff --git a/compiler/nativeGen/X86/CodeGen.hs b/compiler/nativeGen/X86/CodeGen.hs
index 670950d754..672f6d3b62 100644
--- a/compiler/nativeGen/X86/CodeGen.hs
+++ b/compiler/nativeGen/X86/CodeGen.hs
@@ -3421,7 +3421,7 @@ invertCondBranches cfg keep bs =
, Just edgeInfo2 <- getEdgeInfo lbl1 target2 cfg
-- Both jumps come from the same cmm statement
, transitionSource edgeInfo1 == transitionSource edgeInfo2
- , (CmmSource cmmCondBranch) <- transitionSource edgeInfo1
+ , CmmSource {trans_cmmNode = cmmCondBranch} <- transitionSource edgeInfo1
-- Int comparisons are invertible
, CmmCondBranch (CmmMachOp op _args) _ _ _ <- cmmCondBranch
diff --git a/compiler/utils/Dominators.hs b/compiler/utils/Dominators.hs
new file mode 100644
index 0000000000..9877c2c1f0
--- /dev/null
+++ b/compiler/utils/Dominators.hs
@@ -0,0 +1,588 @@
+{-# LANGUAGE RankNTypes, BangPatterns, FlexibleContexts, Strict #-}
+
+{- |
+ Module : Dominators
+ Copyright : (c) Matt Morrow 2009
+ License : BSD3
+ Maintainer : <morrow@moonpatio.com>
+ Stability : experimental
+ Portability : portable
+
+ Taken from the dom-lt package.
+
+ The Lengauer-Tarjan graph dominators algorithm.
+
+ \[1\] Lengauer, Tarjan,
+ /A Fast Algorithm for Finding Dominators in a Flowgraph/, 1979.
+
+ \[2\] Muchnick,
+ /Advanced Compiler Design and Implementation/, 1997.
+
+ \[3\] Brisk, Sarrafzadeh,
+ /Interference Graphs for Procedures in Static Single/
+ /Information Form are Interval Graphs/, 2007.
+-}
+
+module Dominators (
+ Node,Path,Edge
+ ,Graph,Rooted
+ ,idom,ipdom
+ ,domTree,pdomTree
+ ,dom,pdom
+ ,pddfs,rpddfs
+ ,fromAdj,fromEdges
+ ,toAdj,toEdges
+ ,asTree,asGraph
+ ,parents,ancestors
+) where
+
+import GhcPrelude
+
+import Data.Bifunctor
+import Data.Tuple (swap)
+
+import Data.Tree
+import Data.IntMap(IntMap)
+import Data.IntSet(IntSet)
+import qualified Data.IntMap.Strict as IM
+import qualified Data.IntSet as IS
+
+import Control.Monad
+import Control.Monad.ST.Strict
+
+import Data.Array.ST
+import Data.Array.Base
+ (unsafeNewArray_
+ ,unsafeWrite,unsafeRead)
+
+-----------------------------------------------------------------------------
+
+type Node = Int
+type Path = [Node]
+type Edge = (Node,Node)
+type Graph = IntMap IntSet
+type Rooted = (Node, Graph)
+
+-----------------------------------------------------------------------------
+
+-- | /Dominators/.
+-- Complexity as for @idom@
+dom :: Rooted -> [(Node, Path)]
+dom = ancestors . domTree
+
+-- | /Post-dominators/.
+-- Complexity as for @idom@.
+pdom :: Rooted -> [(Node, Path)]
+pdom = ancestors . pdomTree
+
+-- | /Dominator tree/.
+-- Complexity as for @idom@.
+domTree :: Rooted -> Tree Node
+domTree a@(r,_) =
+ let is = filter ((/=r).fst) (idom a)
+ tg = fromEdges (fmap swap is)
+ in asTree (r,tg)
+
+-- | /Post-dominator tree/.
+-- Complexity as for @idom@.
+pdomTree :: Rooted -> Tree Node
+pdomTree a@(r,_) =
+ let is = filter ((/=r).fst) (ipdom a)
+ tg = fromEdges (fmap swap is)
+ in asTree (r,tg)
+
+-- | /Immediate dominators/.
+-- /O(|E|*alpha(|E|,|V|))/, where /alpha(m,n)/ is
+-- \"a functional inverse of Ackermann's function\".
+--
+-- This complexity bound assumes /O(1)/ indexing. Since we're
+-- using @IntMap@, it has an additional /lg |V|/ factor
+-- somewhere in there. I'm not sure where.
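+--
+-- A small example (illustrative; result order may differ):
+--
+-- > idom (0, fromAdj [(0,[1,2]),(1,[3]),(2,[3]),(3,[])])
+--
+-- yields @[(0,0),(1,0),(2,0),(3,0)]@: in this diamond every node is
+-- immediately dominated by the root (the root maps to itself).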
+idom :: Rooted -> [(Node,Node)]
+idom rg = runST (evalS idomM =<< initEnv (pruneReach rg))
+
+-- | /Immediate post-dominators/.
+-- Complexity as for @idom@.
+ipdom :: Rooted -> [(Node,Node)]
+ipdom rg = runST (evalS idomM =<< initEnv (pruneReach (second predG rg)))
+
+-----------------------------------------------------------------------------
+
+-- | /Post-dominated depth-first search/.
+pddfs :: Rooted -> [Node]
+pddfs = reverse . rpddfs
+
+-- | /Reverse post-dominated depth-first search/.
+rpddfs :: Rooted -> [Node]
+rpddfs = concat . levels . pdomTree
+
+-----------------------------------------------------------------------------
+
+type Dom s a = S s (Env s) a
+type NodeSet = IntSet
+type NodeMap a = IntMap a
+data Env s = Env
+ {succE :: !Graph
+ ,predE :: !Graph
+ ,bucketE :: !Graph
+ ,dfsE :: {-# UNPACK #-}!Int
+ ,zeroE :: {-# UNPACK #-}!Node
+ ,rootE :: {-# UNPACK #-}!Node
+ ,labelE :: {-# UNPACK #-}!(Arr s Node)
+ ,parentE :: {-# UNPACK #-}!(Arr s Node)
+ ,ancestorE :: {-# UNPACK #-}!(Arr s Node)
+ ,childE :: {-# UNPACK #-}!(Arr s Node)
+ ,ndfsE :: {-# UNPACK #-}!(Arr s Node)
+ ,dfnE :: {-# UNPACK #-}!(Arr s Int)
+ ,sdnoE :: {-# UNPACK #-}!(Arr s Int)
+ ,sizeE :: {-# UNPACK #-}!(Arr s Int)
+ ,domE :: {-# UNPACK #-}!(Arr s Node)
+ ,rnE :: {-# UNPACK #-}!(Arr s Node)}
+
+-----------------------------------------------------------------------------
+
+idomM :: Dom s [(Node,Node)]
+idomM = do
+ dfsDom =<< rootM
+ n <- gets dfsE
+ forM_ [n,n-1..1] (\i-> do
+ w <- ndfsM i
+ sw <- sdnoM w
+ ps <- predsM w
+ forM_ ps (\v-> do
+ u <- eval v
+ su <- sdnoM u
+ when (su < sw)
+ (store sdnoE w su))
+ z <- ndfsM =<< sdnoM w
+ modify(\e->e{bucketE=IM.adjust
+ (w`IS.insert`)
+ z (bucketE e)})
+ pw <- parentM w
+ link pw w
+ bps <- bucketM pw
+ forM_ bps (\v-> do
+ u <- eval v
+ su <- sdnoM u
+ sv <- sdnoM v
+ let dv = case su < sv of
+ True-> u
+ False-> pw
+ store domE v dv))
+ forM_ [1..n] (\i-> do
+ w <- ndfsM i
+ j <- sdnoM w
+ z <- ndfsM j
+ dw <- domM w
+ when (dw /= z)
+ (do ddw <- domM dw
+ store domE w ddw))
+ fromEnv
+
+-----------------------------------------------------------------------------
+
+eval :: Node -> Dom s Node
+eval v = do
+ n0 <- zeroM
+ a <- ancestorM v
+ case a==n0 of
+ True-> labelM v
+ False-> do
+ compress v
+ a <- ancestorM v
+ l <- labelM v
+ la <- labelM a
+ sl <- sdnoM l
+ sla <- sdnoM la
+ case sl <= sla of
+ True-> return l
+ False-> return la
+
+compress :: Node -> Dom s ()
+compress v = do
+ n0 <- zeroM
+ a <- ancestorM v
+ aa <- ancestorM a
+ when (aa /= n0) (do
+ compress a
+ a <- ancestorM v
+ aa <- ancestorM a
+ l <- labelM v
+ la <- labelM a
+ sl <- sdnoM l
+ sla <- sdnoM la
+ when (sla < sl)
+ (store labelE v la)
+ store ancestorE v aa)
+
+-----------------------------------------------------------------------------
+
+link :: Node -> Node -> Dom s ()
+link v w = do
+ n0 <- zeroM
+ lw <- labelM w
+ slw <- sdnoM lw
+ let balance s = do
+ c <- childM s
+ lc <- labelM c
+ slc <- sdnoM lc
+ case slw < slc of
+ False-> return s
+ True-> do
+ zs <- sizeM s
+ zc <- sizeM c
+ cc <- childM c
+ zcc <- sizeM cc
+ case 2*zc <= zs+zcc of
+ True-> do
+ store ancestorE c s
+ store childE s cc
+ balance s
+ False-> do
+ store sizeE c zs
+ store ancestorE s c
+ balance c
+ s <- balance w
+ lw <- labelM w
+ zw <- sizeM w
+ store labelE s lw
+ store sizeE v . (+zw) =<< sizeM v
+ let follow s = do
+ when (s /= n0) (do
+ store ancestorE s v
+ follow =<< childM s)
+ zv <- sizeM v
+ follow =<< case zv < 2*zw of
+ False-> return s
+ True-> do
+ cv <- childM v
+ store childE v s
+ return cv
+
+-----------------------------------------------------------------------------
+
+dfsDom :: Node -> Dom s ()
+dfsDom i = do
+ _ <- go i
+ n0 <- zeroM
+ r <- rootM
+ store parentE r n0
+ where go i = do
+ n <- nextM
+ store dfnE i n
+ store sdnoE i n
+ store ndfsE n i
+ store labelE i i
+ ss <- succsM i
+ forM_ ss (\j-> do
+ s <- sdnoM j
+ case s==0 of
+ False-> return()
+ True-> do
+ store parentE j i
+ go j)
+
+-----------------------------------------------------------------------------
+
+initEnv :: Rooted -> ST s (Env s)
+initEnv (r0,g0) = do
+ let (g,rnmap) = renum 1 g0
+ pred = predG g
+ r = rnmap IM.! r0
+ n = IM.size g
+ ns = [0..n]
+ m = n+1
+
+ let bucket = IM.fromList
+ (zip ns (repeat mempty))
+
+ rna <- newI m
+ writes rna (fmap swap
+ (IM.toList rnmap))
+
+ doms <- newI m
+ sdno <- newI m
+ size <- newI m
+ parent <- newI m
+ ancestor <- newI m
+ child <- newI m
+ label <- newI m
+ ndfs <- newI m
+ dfn <- newI m
+
+ forM_ [0..n] (doms.=0)
+ forM_ [0..n] (sdno.=0)
+ forM_ [1..n] (size.=1)
+ forM_ [0..n] (ancestor.=0)
+ forM_ [0..n] (child.=0)
+
+ (doms.=r) r
+ (size.=0) 0
+ (label.=0) 0
+
+ return (Env
+ {rnE = rna
+ ,dfsE = 0
+ ,zeroE = 0
+ ,rootE = r
+ ,labelE = label
+ ,parentE = parent
+ ,ancestorE = ancestor
+ ,childE = child
+ ,ndfsE = ndfs
+ ,dfnE = dfn
+ ,sdnoE = sdno
+ ,sizeE = size
+ ,succE = g
+ ,predE = pred
+ ,bucketE = bucket
+ ,domE = doms})
+
+fromEnv :: Dom s [(Node,Node)]
+fromEnv = do
+ dom <- gets domE
+ rn <- gets rnE
+ -- r <- gets rootE
+ (_,n) <- st (getBounds dom)
+ forM [1..n] (\i-> do
+ j <- st(rn!:i)
+ d <- st(dom!:i)
+ k <- st(rn!:d)
+ return (j,k))
+
+-----------------------------------------------------------------------------
+
+zeroM :: Dom s Node
+zeroM = gets zeroE
+domM :: Node -> Dom s Node
+domM = fetch domE
+rootM :: Dom s Node
+rootM = gets rootE
+succsM :: Node -> Dom s [Node]
+succsM i = gets (IS.toList . (!i) . succE)
+predsM :: Node -> Dom s [Node]
+predsM i = gets (IS.toList . (!i) . predE)
+bucketM :: Node -> Dom s [Node]
+bucketM i = gets (IS.toList . (!i) . bucketE)
+sizeM :: Node -> Dom s Int
+sizeM = fetch sizeE
+sdnoM :: Node -> Dom s Int
+sdnoM = fetch sdnoE
+-- dfnM :: Node -> Dom s Int
+-- dfnM = fetch dfnE
+ndfsM :: Int -> Dom s Node
+ndfsM = fetch ndfsE
+childM :: Node -> Dom s Node
+childM = fetch childE
+ancestorM :: Node -> Dom s Node
+ancestorM = fetch ancestorE
+parentM :: Node -> Dom s Node
+parentM = fetch parentE
+labelM :: Node -> Dom s Node
+labelM = fetch labelE
+nextM :: Dom s Int
+nextM = do
+ n <- gets dfsE
+ let n' = n+1
+ modify(\e->e{dfsE=n'})
+ return n'
+
+-----------------------------------------------------------------------------
+
+type A = STUArray
+type Arr s a = A s Int a
+
+infixl 9 !:
+infixr 2 .=
+
+(.=) :: (MArray (A s) a (ST s))
+ => Arr s a -> a -> Int -> ST s ()
+(v .= x) i = unsafeWrite v i x
+
+(!:) :: (MArray (A s) a (ST s))
+ => A s Int a -> Int -> ST s a
+a !: i = do
+ o <- unsafeRead a i
+ return $! o
+
+new :: (MArray (A s) a (ST s))
+ => Int -> ST s (Arr s a)
+new n = unsafeNewArray_ (0,n-1)
+
+newI :: Int -> ST s (Arr s Int)
+newI = new
+
+-- newD :: Int -> ST s (Arr s Double)
+-- newD = new
+
+-- dump :: (MArray (A s) a (ST s)) => Arr s a -> ST s [a]
+-- dump a = do
+-- (m,n) <- getBounds a
+-- forM [m..n] (\i -> a!:i)
+
+writes :: (MArray (A s) a (ST s))
+ => Arr s a -> [(Int,a)] -> ST s ()
+writes a xs = forM_ xs (\(i,x) -> (a.=x) i)
+
+-- arr :: (MArray (A s) a (ST s)) => [a] -> ST s (Arr s a)
+-- arr xs = do
+-- let n = length xs
+-- a <- new n
+-- go a n 0 xs
+-- return a
+-- where go _ _ _ [] = return ()
+-- go a n i (x:xs)
+-- | i <= n = (a.=x) i >> go a n (i+1) xs
+-- | otherwise = return ()
+
+-----------------------------------------------------------------------------
+
+(!) :: Monoid a => IntMap a -> Int -> a
+(!) g n = maybe mempty id (IM.lookup n g)
+
+fromAdj :: [(Node, [Node])] -> Graph
+fromAdj = IM.fromList . fmap (second IS.fromList)
+
+fromEdges :: [Edge] -> Graph
+fromEdges = collectI IS.union fst (IS.singleton . snd)
+
+toAdj :: Graph -> [(Node, [Node])]
+toAdj = fmap (second IS.toList) . IM.toList
+
+toEdges :: Graph -> [Edge]
+toEdges = concatMap (uncurry (fmap . (,))) . toAdj
+
+predG :: Graph -> Graph
+predG g = IM.unionWith IS.union (go g) g0
+ where g0 = fmap (const mempty) g
+ f :: IntMap IntSet -> Int -> IntSet -> IntMap IntSet
+ f m i a = foldl' (\m p -> IM.insertWith mappend p
+ (IS.singleton i) m)
+ m
+ (IS.toList a)
+ go :: IntMap IntSet -> IntMap IntSet
+ go = flip IM.foldlWithKey' mempty f
+
+pruneReach :: Rooted -> Rooted
+pruneReach (r,g) = (r,g2)
+ where is = reachable
+ (maybe mempty id
+ . flip IM.lookup g) $ r
+ g2 = IM.fromList
+ . fmap (second (IS.filter (`IS.member`is)))
+ . filter ((`IS.member`is) . fst)
+ . IM.toList $ g
+
+tip :: Tree a -> (a, [Tree a])
+tip (Node a ts) = (a, ts)
+
+parents :: Tree a -> [(a, a)]
+parents (Node i xs) = p i xs
+ ++ concatMap parents xs
+ where p i = fmap (flip (,) i . rootLabel)
+
+ancestors :: Tree a -> [(a, [a])]
+ancestors = go []
+ where go acc (Node i xs)
+ = let acc' = i:acc
+ in p acc' xs ++ concatMap (go acc') xs
+ p is = fmap (flip (,) is . rootLabel)
+
+asGraph :: Tree Node -> Rooted
+asGraph t@(Node a _) = let g = go t in (a, fromAdj g)
+ where go (Node a ts) = let as = (fst . unzip . fmap tip) ts
+ in (a, as) : concatMap go ts
+
+asTree :: Rooted -> Tree Node
+asTree (r,g) = let go a = Node a (fmap go ((IS.toList . f) a))
+ f = (g !)
+ in go r
+
+reachable :: (Node -> NodeSet) -> (Node -> NodeSet)
+reachable f a = go (IS.singleton a) a
+ where go seen a = let s = f a
+ as = IS.toList (s `IS.difference` seen)
+ in foldl' go (s `IS.union` seen) as
+
+collectI :: (c -> c -> c)
+ -> (a -> Int) -> (a -> c) -> [a] -> IntMap c
+collectI (<>) f g
+ = foldl' (\m a -> IM.insertWith (<>)
+ (f a)
+ (g a) m) mempty
+
+-- collect :: (Ord b) => (c -> c -> c)
+-- -> (a -> b) -> (a -> c) -> [a] -> Map b c
+-- collect (<>) f g
+-- = foldl' (\m a -> SM.insertWith (<>)
+-- (f a)
+-- (g a) m) mempty
+
+-- Renumber the graph starting from the given index.
+-- Returns the renamed graph together with the old -> new mapping.
+renum :: Int -> Graph -> (Graph, NodeMap Node)
+renum from = (\(_,m,g)->(g,m))
+ . IM.foldlWithKey'
+ f (from,mempty,mempty)
+ where
+ f :: (Int, NodeMap Node, IntMap IntSet) -> Node -> IntSet
+ -> (Int, NodeMap Node, IntMap IntSet)
+ f (!n,!env,!new) i ss =
+ let (j,n2,env2) = go n env i
+ (n3,env3,ss2) = IS.fold
+ (\k (!n,!env,!new)->
+ case go n env k of
+ (l,n2,env2)-> (n2,env2,l `IS.insert` new))
+ (n2,env2,mempty) ss
+ new2 = IM.insertWith IS.union j ss2 new
+ in (n3,env3,new2)
+ go :: Int
+ -> NodeMap Node
+ -> Node
+ -> (Node,Int,NodeMap Node)
+ go !n !env i =
+ case IM.lookup i env of
+ Just j -> (j,n,env)
+ Nothing -> (n,n+1,IM.insert i n env)
+
+-----------------------------------------------------------------------------
+
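+-- A small continuation-passing state monad over 'ST': it threads the
+-- mutable 'Env' through the algorithm without allocating state tuples.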
+newtype S z s a = S {unS :: forall o. (a -> s -> ST z o) -> s -> ST z o}
+instance Functor (S z s) where
+ fmap f (S g) = S (\k -> g (k . f))
+instance Monad (S z s) where
+ return = pure
+ S g >>= f = S (\k -> g (\a -> unS (f a) k))
+instance Applicative (S z s) where
+ pure a = S (\k -> k a)
+ (<*>) = ap
+-- get :: S z s s
+-- get = S (\k s -> k s s)
+gets :: (s -> a) -> S z s a
+gets f = S (\k s -> k (f s) s)
+-- set :: s -> S z s ()
+-- set s = S (\k _ -> k () s)
+modify :: (s -> s) -> S z s ()
+modify f = S (\k -> k () . f)
+-- runS :: S z s a -> s -> ST z (a, s)
+-- runS (S g) = g (\a s -> return (a,s))
+evalS :: S z s a -> s -> ST z a
+evalS (S g) = g ((return .) . const)
+-- execS :: S z s a -> s -> ST z s
+-- execS (S g) = g ((return .) . flip const)
+st :: ST z a -> S z s a
+st m = S (\k s-> do
+ a <- m
+ k a s)
+store :: (MArray (A z) a (ST z))
+ => (s -> Arr z a) -> Int -> a -> S z s ()
+store f i x = do
+ a <- gets f
+ st ((a.=x) i)
+fetch :: (MArray (A z) a (ST z))
+ => (s -> Arr z a) -> Int -> S z s a
+fetch f i = do
+ a <- gets f
+ st (a!:i)
+
diff --git a/compiler/utils/OrdList.hs b/compiler/utils/OrdList.hs
index e8b50e5968..8da5038b2c 100644
--- a/compiler/utils/OrdList.hs
+++ b/compiler/utils/OrdList.hs
@@ -10,14 +10,18 @@ can be appended in linear time.
-}
{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE BangPatterns #-}
+
module OrdList (
OrdList,
nilOL, isNilOL, unitOL, appOL, consOL, snocOL, concatOL, lastOL,
headOL,
- mapOL, fromOL, toOL, foldrOL, foldlOL, reverseOL, fromOLReverse
+ mapOL, fromOL, toOL, foldrOL, foldlOL, reverseOL, fromOLReverse,
+ strictlyEqOL, strictlyOrdOL
) where
import GhcPrelude
+import Data.Foldable
import Outputable
@@ -49,7 +53,11 @@ instance Monoid (OrdList a) where
mconcat = concatOL
instance Foldable OrdList where
- foldr = foldrOL
+ foldr = foldrOL
+ foldl' = foldlOL
+ toList = fromOL
+ null = isNilOL
+ length = lengthOL
instance Traversable OrdList where
traverse f xs = toOL <$> traverse f (fromOL xs)
@@ -64,7 +72,7 @@ appOL :: OrdList a -> OrdList a -> OrdList a
concatOL :: [OrdList a] -> OrdList a
headOL :: OrdList a -> a
lastOL :: OrdList a -> a
-
+lengthOL :: OrdList a -> Int
nilOL = None
unitOL as = One as
@@ -86,6 +94,13 @@ lastOL (Cons _ as) = lastOL as
lastOL (Snoc _ a) = a
lastOL (Two _ as) = lastOL as
+lengthOL None = 0
+lengthOL (One _) = 1
+lengthOL (Many as) = length as
+lengthOL (Cons _ as) = 1 + length as
+lengthOL (Snoc as _) = 1 + length as
+lengthOL (Two as bs) = length as + length bs
+
isNilOL None = True
isNilOL _ = False
@@ -126,13 +141,14 @@ foldrOL k z (Snoc xs x) = foldrOL k (k x z) xs
foldrOL k z (Two b1 b2) = foldrOL k (foldrOL k z b2) b1
foldrOL k z (Many xs) = foldr k z xs
+-- | Strict left fold.
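+--
+-- The accumulator is forced at every step (note the bang patterns below),
+-- so folding a long list does not build up a chain of thunks.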
foldlOL :: (b->a->b) -> b -> OrdList a -> b
foldlOL _ z None = z
foldlOL k z (One x) = k z x
-foldlOL k z (Cons x xs) = foldlOL k (k z x) xs
-foldlOL k z (Snoc xs x) = k (foldlOL k z xs) x
-foldlOL k z (Two b1 b2) = foldlOL k (foldlOL k z b1) b2
-foldlOL k z (Many xs) = foldl k z xs
+foldlOL k z (Cons x xs) = let !z' = (k z x) in foldlOL k z' xs
+foldlOL k z (Snoc xs x) = let !z' = (foldlOL k z xs) in k z' x
+foldlOL k z (Two b1 b2) = let !z' = (foldlOL k z b1) in foldlOL k z' b2
+foldlOL k z (Many xs) = foldl' k z xs
toOL :: [a] -> OrdList a
toOL [] = None
@@ -146,3 +162,33 @@ reverseOL (Cons a b) = Snoc (reverseOL b) a
reverseOL (Snoc a b) = Cons b (reverseOL a)
reverseOL (Two a b) = Two (reverseOL b) (reverseOL a)
reverseOL (Many xs) = Many (reverse xs)
+
+-- | Compare not only the values but also the structure of two lists
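+--
+-- For example (illustrative): @Cons 1 (One 2)@ and @Snoc (One 1) 2@ contain
+-- the same elements, but 'strictlyEqOL' distinguishes their structure and
+-- returns 'False'.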
+strictlyEqOL :: Eq a => OrdList a -> OrdList a -> Bool
+strictlyEqOL None None = True
+strictlyEqOL (One x) (One y) = x == y
+strictlyEqOL (Cons a as) (Cons b bs) = a == b && as `strictlyEqOL` bs
+strictlyEqOL (Snoc as a) (Snoc bs b) = a == b && as `strictlyEqOL` bs
+strictlyEqOL (Two a1 a2) (Two b1 b2) = a1 `strictlyEqOL` b1 && a2 `strictlyEqOL` b2
+strictlyEqOL (Many as) (Many bs) = as == bs
+strictlyEqOL _ _ = False
+
+-- | Compare not only the values but also the structure of two lists
+strictlyOrdOL :: Ord a => OrdList a -> OrdList a -> Ordering
+-- Constructors are ordered None < One < Cons < Snoc < Two < Many;
+-- the mirror cases keep the ordering antisymmetric.
+strictlyOrdOL None None = EQ
+strictlyOrdOL None _ = LT
+strictlyOrdOL _ None = GT
+strictlyOrdOL (One x) (One y) = compare x y
+strictlyOrdOL (One _) _ = LT
+strictlyOrdOL _ (One _) = GT
+strictlyOrdOL (Cons a as) (Cons b bs) =
+  compare a b `mappend` strictlyOrdOL as bs
+strictlyOrdOL (Cons _ _) _ = LT
+strictlyOrdOL _ (Cons _ _) = GT
+strictlyOrdOL (Snoc as a) (Snoc bs b) =
+  compare a b `mappend` strictlyOrdOL as bs
+strictlyOrdOL (Snoc _ _) _ = LT
+strictlyOrdOL _ (Snoc _ _) = GT
+strictlyOrdOL (Two a1 a2) (Two b1 b2) =
+  (strictlyOrdOL a1 b1) `mappend` (strictlyOrdOL a2 b2)
+strictlyOrdOL (Two _ _) _ = LT
+strictlyOrdOL (Many as) (Many bs) = compare as bs
+strictlyOrdOL (Many _ ) _ = GT
+
+