| author | Matthew Pickering <matthewtpickering@gmail.com> | 2021-09-15 13:09:17 +0100 |
|---|---|---|
| committer | Marge Bot <ben+marge-bot@smart-cactus.org> | 2021-09-17 09:44:53 -0400 |
| commit | 53dc8e41a424909c8c3eccc43695fee0cdcc1555 (patch) | |
| tree | 6d96a223f8a7f76053d15de09fc1e0eff0e3689b /compiler | |
| parent | b041ea7784f036dd7cfc5fae6380db4f3c392ab4 (diff) | |
Code Gen: Use more efficient block merging algorithm
The previous algorithm scaled poorly when there was a large number of
blocks and edges.
The algorithm links together block chains which have edges between them
in the CFG. The new algorithm uses a union-find data structure to
efficiently merge blocks and calculate which block chain each block id
belongs to.
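
As a rough sketch of that merging step (not the patch's exact code: `BlockId`, `Chain`, `chainOf`, and `mergeStep` are made-up stand-ins, and a plain `Data.Map` stands in for GHC's `LabelMap`; only the `GHC.Data.UnionFind` API is the one added by this patch), processing one CFG edge touches just the two equivalence classes involved:

```haskell
import Control.Monad (unless)
import Control.Monad.ST (ST)
import qualified Data.Map as M

import GHC.Data.UnionFind (Point, equivalent, find, fresh, union)

-- Made-up stand-ins for GHC's BlockId/BlockChain types.
type BlockId = Int
type Chain   = [BlockId]

-- Process one CFG edge: if both endpoints already belong to the same
-- class, do nothing; otherwise concatenate the two chains and make both
-- classes point at the merged result ('union' keeps its second argument's
-- descriptor, so the merged chain becomes the canonical element).
mergeStep :: M.Map BlockId (Point s Chain) -> (BlockId, BlockId) -> ST s ()
mergeStep chainOf (from, to) = do
  let pFrom = chainOf M.! from
      pTo   = chainOf M.! to
  same <- equivalent pFrom pTo
  unless same $ do
    left   <- find pFrom
    right  <- find pTo
    merged <- fresh (left ++ right)   -- stand-in for chainConcat
    union pTo   merged
    union pFrom merged
```
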
I copied the UnionFind data structure which already existed in Cabal
into the GHC library rather than reimplement it myself.
This change results in a very significant reduction in allocations when
compiling the mmark package.
Ticket: #19471
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/GHC/CmmToAsm/BlockLayout.hs | 44
-rw-r--r-- | compiler/GHC/Data/UnionFind.hs | 91
-rw-r--r-- | compiler/ghc.cabal.in | 1
3 files changed, 109 insertions, 27 deletions
```diff
diff --git a/compiler/GHC/CmmToAsm/BlockLayout.hs b/compiler/GHC/CmmToAsm/BlockLayout.hs
index 6e2e7e2189..70e131c717 100644
--- a/compiler/GHC/CmmToAsm/BlockLayout.hs
+++ b/compiler/GHC/CmmToAsm/BlockLayout.hs
@@ -41,12 +41,13 @@ import GHC.Utils.Panic
 import GHC.Utils.Panic.Plain
 import GHC.Utils.Misc
 
-import Data.List (sortOn, sortBy)
+import Data.List (sortOn, sortBy, nub)
 import Data.Foldable (toList)
 
 import qualified Data.Set as Set
 import Data.STRef
 import Control.Monad.ST.Strict
-import Control.Monad (foldM)
+import Control.Monad (foldM, unless)
+import GHC.Data.UnionFind
 
 {-
   Note [CFG based code layout]
@@ -480,10 +481,9 @@ combineNeighbourhood edges chains
 
 mergeChains :: [CfgEdge] -> [BlockChain] -> (BlockChain)
 mergeChains edges chains
-    = -- pprTrace "combine" (ppr edges) $
-      runST $ do
+    = runST $ do
         let addChain m0 chain = do
-                ref <- newSTRef chain
+                ref <- fresh chain
                 return $ chainFoldl (\m' b -> mapInsert b ref m') m0 chain
         chainMap' <- foldM (\m0 c -> addChain m0 c) mapEmpty chains
         merge edges chainMap'
@@ -491,35 +491,25 @@ mergeChains edges chains
     -- We keep a map from ALL blocks to their respective chain (sigh)
     -- This is required since when looking at an edge we need to find
    -- the associated chains quickly.
-    -- We use a map of STRefs, maintaining a invariant of one STRef per chain.
-    -- When merging chains we can update the
-    -- STRef of one chain once (instead of writing to the map for each block).
-    -- We then overwrite the STRefs for the other chain so there is again only
-    -- a single STRef for the combined chain.
-    -- The difference in terms of allocations saved is ~0.2% with -O so actually
-    -- significant compared to using a regular map.
-
-    merge :: forall s. [CfgEdge] -> LabelMap (STRef s BlockChain) -> ST s BlockChain
+    -- We use a union-find data structure to do this efficiently.
+
+    merge :: forall s. [CfgEdge] -> LabelMap (Point s BlockChain) -> ST s BlockChain
     merge [] chains = do
-        chains' <- ordNub <$> (mapM readSTRef $ mapElems chains) :: ST s [BlockChain]
+        chains' <- mapM find =<< (nub <$> (mapM repr $ mapElems chains)) :: ST s [BlockChain]
         return $ foldl' chainConcat (head chains') (tail chains')
 
     merge ((CfgEdge from to _):edges) chains
     --   | pprTrace "merge" (ppr (from,to) <> ppr chains) False
     --   = undefined
-      | cFrom == cTo
-      = merge edges chains
-      | otherwise
       = do
-        chains' <- mergeComb cFrom cTo
-        merge edges chains'
+        same <- equivalent cFrom cTo
+        unless same $ do
+          cRight <- find cTo
+          cLeft <- find cFrom
+          new_point <- fresh (chainConcat cLeft cRight)
+          union cTo new_point
+          union cFrom new_point
+        merge edges chains
       where
-        mergeComb :: STRef s BlockChain -> STRef s BlockChain -> ST s (LabelMap (STRef s BlockChain))
-        mergeComb refFrom refTo = do
-            cRight <- readSTRef refTo
-            chain <- pure chainConcat <*> readSTRef refFrom <*> pure cRight
-            writeSTRef refFrom chain
-            return $ chainFoldl (\m b -> mapInsert b refFrom m) chains cRight
-
         cFrom = expectJust "mergeChains:chainMap:from" $ mapLookup from chains
         cTo = expectJust "mergeChains:chainMap:to" $ mapLookup to chains
diff --git a/compiler/GHC/Data/UnionFind.hs b/compiler/GHC/Data/UnionFind.hs
new file mode 100644
index 0000000000..21687b5a09
--- /dev/null
+++ b/compiler/GHC/Data/UnionFind.hs
@@ -0,0 +1,91 @@
+{- Union-find data structure compiled from Distribution.Utils.UnionFind -}
+module GHC.Data.UnionFind where
+
+import GHC.Prelude
+import Data.STRef
+import Control.Monad.ST
+import Control.Monad
+
+-- | A variable which can be unified; alternately, this can be thought
+-- of as an equivalence class with a distinguished representative.
+newtype Point s a = Point (STRef s (Link s a))
+  deriving (Eq)
+
+-- | Mutable write to a 'Point'
+writePoint :: Point s a -> Link s a -> ST s ()
+writePoint (Point v) = writeSTRef v
+
+-- | Read the current value of 'Point'.
+readPoint :: Point s a -> ST s (Link s a)
+readPoint (Point v) = readSTRef v
+
+-- | The internal data structure for a 'Point', which either records
+-- the representative element of an equivalence class, or a link to
+-- the 'Point' that actually stores the representative type.
+data Link s a
+    -- NB: it is too bad we can't say STRef Int#; the weights remain boxed
+    = Info {-# UNPACK #-} !(STRef s Int) {-# UNPACK #-} !(STRef s a)
+    | Link {-# UNPACK #-} !(Point s a)
+
+-- | Create a fresh equivalence class with one element.
+fresh :: a -> ST s (Point s a)
+fresh desc = do
+    weight <- newSTRef 1
+    descriptor <- newSTRef desc
+    Point `fmap` newSTRef (Info weight descriptor)
+
+-- | Flatten any chains of links, returning a 'Point'
+-- which points directly to the canonical representation.
+repr :: Point s a -> ST s (Point s a)
+repr point = readPoint point >>= \r ->
+    case r of
+        Link point' -> do
+            point'' <- repr point'
+            when (point'' /= point') $ do
+                writePoint point =<< readPoint point'
+            return point''
+        Info _ _ -> return point
+
+-- | Return the canonical element of an equivalence
+-- class 'Point'.
+find :: Point s a -> ST s a
+find point =
+    -- Optimize length 0 and 1 case at expense of
+    -- general case
+    readPoint point >>= \r ->
+        case r of
+            Info _ d_ref -> readSTRef d_ref
+            Link point' -> readPoint point' >>= \r' ->
+                case r' of
+                    Info _ d_ref -> readSTRef d_ref
+                    Link _ -> repr point >>= find
+
+-- | Unify two equivalence classes, so that they share
+-- a canonical element. Keeps the descriptor of point2.
+union :: Point s a -> Point s a -> ST s ()
+union refpoint1 refpoint2 = do
+    point1 <- repr refpoint1
+    point2 <- repr refpoint2
+    when (point1 /= point2) $ do
+        l1 <- readPoint point1
+        l2 <- readPoint point2
+        case (l1, l2) of
+            (Info wref1 dref1, Info wref2 dref2) -> do
+                weight1 <- readSTRef wref1
+                weight2 <- readSTRef wref2
+                -- Should be able to optimize the == case separately
+                if weight1 >= weight2
+                    then do
+                        writePoint point2 (Link point1)
+                        -- The weight calculation here seems a bit dodgy
+                        writeSTRef wref1 (weight1 + weight2)
+                        writeSTRef dref1 =<< readSTRef dref2
+                    else do
+                        writePoint point1 (Link point2)
+                        writeSTRef wref2 (weight1 + weight2)
+            _ -> error "UnionFind.union: repr invariant broken"
+
+-- | Test if two points are in the same equivalence class.
+equivalent :: Point s a -> Point s a -> ST s Bool
+equivalent point1 point2 = liftM2 (==) (repr point1) (repr point2)
+
diff --git a/compiler/ghc.cabal.in b/compiler/ghc.cabal.in
index 2023fbe3da..b1acf005da 100644
--- a/compiler/ghc.cabal.in
+++ b/compiler/ghc.cabal.in
@@ -407,6 +407,7 @@ Library
         GHC.Data.Strict
         GHC.Data.StringBuffer
         GHC.Data.TrieMap
+        GHC.Data.UnionFind
         GHC.Driver.Backend
         GHC.Driver.Backpack
         GHC.Driver.Backpack.Syntax
```
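
For reference, a minimal, hypothetical usage example of the union-find API added by this patch (assuming the `GHC.Data.UnionFind` module exactly as shown in the diff); it illustrates that `union` keeps the descriptor of its second argument and that both points end up in one equivalence class:

```haskell
import Control.Monad.ST (runST)

import GHC.Data.UnionFind (equivalent, find, fresh, union)

-- Build two singleton classes, merge them, and inspect the result.
example :: (String, Bool)
example = runST $ do
  a <- fresh "chain-A"
  b <- fresh "chain-B"
  union a b               -- keeps the descriptor of the second point
  desc <- find a          -- canonical descriptor of a's class
  same <- equivalent a b  -- both points now share one representative
  return (desc, same)     -- ("chain-B", True)
```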