summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/GHC/CmmToAsm/BlockLayout.hs44
-rw-r--r--compiler/GHC/Data/UnionFind.hs91
-rw-r--r--compiler/ghc.cabal.in1
3 files changed, 109 insertions, 27 deletions
diff --git a/compiler/GHC/CmmToAsm/BlockLayout.hs b/compiler/GHC/CmmToAsm/BlockLayout.hs
index 6e2e7e2189..70e131c717 100644
--- a/compiler/GHC/CmmToAsm/BlockLayout.hs
+++ b/compiler/GHC/CmmToAsm/BlockLayout.hs
@@ -41,12 +41,13 @@ import GHC.Utils.Panic
import GHC.Utils.Panic.Plain
import GHC.Utils.Misc
-import Data.List (sortOn, sortBy)
+import Data.List (sortOn, sortBy, nub)
import Data.Foldable (toList)
import qualified Data.Set as Set
import Data.STRef
import Control.Monad.ST.Strict
-import Control.Monad (foldM)
+import Control.Monad (foldM, unless)
+import GHC.Data.UnionFind
{-
Note [CFG based code layout]
@@ -480,10 +481,9 @@ combineNeighbourhood edges chains
mergeChains :: [CfgEdge] -> [BlockChain]
-> (BlockChain)
mergeChains edges chains
- = -- pprTrace "combine" (ppr edges) $
- runST $ do
+ = runST $ do
let addChain m0 chain = do
- ref <- newSTRef chain
+ ref <- fresh chain
return $ chainFoldl (\m' b -> mapInsert b ref m') m0 chain
chainMap' <- foldM (\m0 c -> addChain m0 c) mapEmpty chains
merge edges chainMap'
@@ -491,35 +491,25 @@ mergeChains edges chains
-- We keep a map from ALL blocks to their respective chain (sigh)
-- This is required since when looking at an edge we need to find
-- the associated chains quickly.
- -- We use a map of STRefs, maintaining a invariant of one STRef per chain.
- -- When merging chains we can update the
- -- STRef of one chain once (instead of writing to the map for each block).
- -- We then overwrite the STRefs for the other chain so there is again only
- -- a single STRef for the combined chain.
- -- The difference in terms of allocations saved is ~0.2% with -O so actually
- -- significant compared to using a regular map.
-
- merge :: forall s. [CfgEdge] -> LabelMap (STRef s BlockChain) -> ST s BlockChain
+ -- We use a union-find data structure to do this efficiently.
+
+ merge :: forall s. [CfgEdge] -> LabelMap (Point s BlockChain) -> ST s BlockChain
merge [] chains = do
- chains' <- ordNub <$> (mapM readSTRef $ mapElems chains) :: ST s [BlockChain]
+ chains' <- mapM find =<< (nub <$> (mapM repr $ mapElems chains)) :: ST s [BlockChain]
return $ foldl' chainConcat (head chains') (tail chains')
merge ((CfgEdge from to _):edges) chains
-- | pprTrace "merge" (ppr (from,to) <> ppr chains) False
-- = undefined
- | cFrom == cTo
- = merge edges chains
- | otherwise
= do
- chains' <- mergeComb cFrom cTo
- merge edges chains'
+ same <- equivalent cFrom cTo
+ unless same $ do
+ cRight <- find cTo
+ cLeft <- find cFrom
+ new_point <- fresh (chainConcat cLeft cRight)
+ union cTo new_point
+ union cFrom new_point
+ merge edges chains
where
- mergeComb :: STRef s BlockChain -> STRef s BlockChain -> ST s (LabelMap (STRef s BlockChain))
- mergeComb refFrom refTo = do
- cRight <- readSTRef refTo
- chain <- pure chainConcat <*> readSTRef refFrom <*> pure cRight
- writeSTRef refFrom chain
- return $ chainFoldl (\m b -> mapInsert b refFrom m) chains cRight
-
cFrom = expectJust "mergeChains:chainMap:from" $ mapLookup from chains
cTo = expectJust "mergeChains:chainMap:to" $ mapLookup to chains
diff --git a/compiler/GHC/Data/UnionFind.hs b/compiler/GHC/Data/UnionFind.hs
new file mode 100644
index 0000000000..21687b5a09
--- /dev/null
+++ b/compiler/GHC/Data/UnionFind.hs
@@ -0,0 +1,91 @@
+{- Union-find data structure compiled from Distribution.Utils.UnionFind -}
+module GHC.Data.UnionFind where
+
+import GHC.Prelude
+import Data.STRef
+import Control.Monad.ST
+import Control.Monad
+
+-- | A variable which can be unified; alternately, this can be thought
+-- of as an equivalence class with a distinguished representative.
+newtype Point s a = Point (STRef s (Link s a))
+ deriving (Eq)
+
+-- | Mutable write to a 'Point'
+writePoint :: Point s a -> Link s a -> ST s ()
+writePoint (Point v) = writeSTRef v
+
+-- | Read the current value of 'Point'.
+readPoint :: Point s a -> ST s (Link s a)
+readPoint (Point v) = readSTRef v
+
+-- | The internal data structure for a 'Point', which either records
+-- the representative element of an equivalence class, or a link to
+-- the 'Point' that actually stores the representative type.
+data Link s a
+ -- NB: it is too bad we can't say STRef Int#; the weights remain boxed
+ = Info {-# UNPACK #-} !(STRef s Int) {-# UNPACK #-} !(STRef s a)
+ | Link {-# UNPACK #-} !(Point s a)
+
+-- | Create a fresh equivalence class with one element.
+fresh :: a -> ST s (Point s a)
+fresh desc = do
+ weight <- newSTRef 1
+ descriptor <- newSTRef desc
+ Point `fmap` newSTRef (Info weight descriptor)
+
+-- | Flatten any chains of links, returning a 'Point'
+-- which points directly to the canonical representation.
+repr :: Point s a -> ST s (Point s a)
+repr point = readPoint point >>= \r ->
+ case r of
+ Link point' -> do
+ point'' <- repr point'
+ when (point'' /= point') $ do
+ writePoint point =<< readPoint point'
+ return point''
+ Info _ _ -> return point
+
+-- | Return the canonical element of an equivalence
+-- class 'Point'.
+find :: Point s a -> ST s a
+find point =
+ -- Optimize length 0 and 1 case at expense of
+ -- general case
+ readPoint point >>= \r ->
+ case r of
+ Info _ d_ref -> readSTRef d_ref
+ Link point' -> readPoint point' >>= \r' ->
+ case r' of
+ Info _ d_ref -> readSTRef d_ref
+ Link _ -> repr point >>= find
+
+-- | Unify two equivalence classes, so that they share
+-- a canonical element. Keeps the descriptor of point2.
+union :: Point s a -> Point s a -> ST s ()
+union refpoint1 refpoint2 = do
+ point1 <- repr refpoint1
+ point2 <- repr refpoint2
+ when (point1 /= point2) $ do
+ l1 <- readPoint point1
+ l2 <- readPoint point2
+ case (l1, l2) of
+ (Info wref1 dref1, Info wref2 dref2) -> do
+ weight1 <- readSTRef wref1
+ weight2 <- readSTRef wref2
+ -- Should be able to optimize the == case separately
+ if weight1 >= weight2
+ then do
+ writePoint point2 (Link point1)
+ -- The weight calculation here seems a bit dodgy
+ writeSTRef wref1 (weight1 + weight2)
+ writeSTRef dref1 =<< readSTRef dref2
+ else do
+ writePoint point1 (Link point2)
+ writeSTRef wref2 (weight1 + weight2)
+ _ -> error "UnionFind.union: repr invariant broken"
+
+-- | Test if two points are in the same equivalence class.
+equivalent :: Point s a -> Point s a -> ST s Bool
+equivalent point1 point2 = liftM2 (==) (repr point1) (repr point2)
+
diff --git a/compiler/ghc.cabal.in b/compiler/ghc.cabal.in
index 2023fbe3da..b1acf005da 100644
--- a/compiler/ghc.cabal.in
+++ b/compiler/ghc.cabal.in
@@ -407,6 +407,7 @@ Library
GHC.Data.Strict
GHC.Data.StringBuffer
GHC.Data.TrieMap
+ GHC.Data.UnionFind
GHC.Driver.Backend
GHC.Driver.Backpack
GHC.Driver.Backpack.Syntax