summaryrefslogtreecommitdiff
path: root/compiler/utils
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/utils')
-rw-r--r--compiler/utils/UniqDFM.hs8
-rw-r--r--compiler/utils/UniqDSet.hs79
-rw-r--r--compiler/utils/UniqFM.hs2
3 files changed, 61 insertions, 28 deletions
diff --git a/compiler/utils/UniqDFM.hs b/compiler/utils/UniqDFM.hs
index 82a67f351b..bd530b76c3 100644
--- a/compiler/utils/UniqDFM.hs
+++ b/compiler/utils/UniqDFM.hs
@@ -46,12 +46,13 @@ module UniqDFM (
intersectUDFM, udfmIntersectUFM,
intersectsUDFM,
disjointUDFM, disjointUdfmUfm,
+ equalKeysUDFM,
minusUDFM,
listToUDFM,
udfmMinusUFM,
partitionUDFM,
anyUDFM, allUDFM,
- pprUDFM,
+ pprUniqDFM, pprUDFM,
udfmToList,
udfmToUfm,
@@ -66,6 +67,7 @@ import Outputable
import qualified Data.IntMap as M
import Data.Data
+import Data.Functor.Classes (Eq1 (..))
import Data.List (sortBy)
import Data.Function (on)
import qualified Data.Semigroup as Semi
@@ -288,6 +290,10 @@ udfmToList (UDFM m _i) =
[ (getUnique k, taggedFst v)
| (k, v) <- sortBy (compare `on` (taggedSnd . snd)) $ M.toList m ]
+-- Determines whether two 'UniqDFM's contain the same keys.
+equalKeysUDFM :: UniqDFM a -> UniqDFM b -> Bool
+equalKeysUDFM (UDFM m1 _) (UDFM m2 _) = liftEq (\_ _ -> True) m1 m2
+
isNullUDFM :: UniqDFM elt -> Bool
isNullUDFM (UDFM m _) = M.null m
diff --git a/compiler/utils/UniqDSet.hs b/compiler/utils/UniqDSet.hs
index aa53194331..4be437c1ee 100644
--- a/compiler/utils/UniqDSet.hs
+++ b/compiler/utils/UniqDSet.hs
@@ -3,14 +3,19 @@
-- |
-- Specialised deterministic sets, for things with @Uniques@
--
--- Based on @UniqDFMs@ (as you would expect).
+-- Based on 'UniqDFM's (as you would expect).
-- See Note [Deterministic UniqFM] in UniqDFM for explanation why we need it.
--
--- Basically, the things need to be in class @Uniquable@.
+-- Basically, the things need to be in class 'Uniquable'.
+
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE DeriveDataTypeable #-}
module UniqDSet (
-- * Unique set type
UniqDSet, -- type synonym for UniqFM a
+ getUniqDSet,
+ pprUniqDSet,
-- ** Manipulating these sets
delOneFromUniqDSet, delListFromUniqDSet,
@@ -21,7 +26,6 @@ module UniqDSet (
unionUniqDSets, unionManyUniqDSets,
minusUniqDSet, uniqDSetMinusUniqSet,
intersectUniqDSets, uniqDSetIntersectUniqSet,
- intersectsUniqDSets,
foldUniqDSet,
elementOfUniqDSet,
filterUniqDSet,
@@ -34,76 +38,99 @@ module UniqDSet (
import GhcPrelude
+import Outputable
import UniqDFM
import UniqSet
import Unique
-type UniqDSet a = UniqDFM a
+import Data.Coerce
+import Data.Data
+import qualified Data.Semigroup as Semi
+
+-- See Note [UniqSet invariant] in UniqSet.hs for why we want a newtype here.
+-- Beyond preserving invariants, we may also want to 'override' typeclass
+-- instances.
+
+newtype UniqDSet a = UniqDSet {getUniqDSet' :: UniqDFM a}
+ deriving (Data, Semi.Semigroup, Monoid)
emptyUniqDSet :: UniqDSet a
-emptyUniqDSet = emptyUDFM
+emptyUniqDSet = UniqDSet emptyUDFM
unitUniqDSet :: Uniquable a => a -> UniqDSet a
-unitUniqDSet x = unitUDFM x x
+unitUniqDSet x = UniqDSet (unitUDFM x x)
-mkUniqDSet :: Uniquable a => [a] -> UniqDSet a
+mkUniqDSet :: Uniquable a => [a] -> UniqDSet a
mkUniqDSet = foldl' addOneToUniqDSet emptyUniqDSet
-- The new element always goes to the right of existing ones.
addOneToUniqDSet :: Uniquable a => UniqDSet a -> a -> UniqDSet a
-addOneToUniqDSet set x = addToUDFM set x x
+addOneToUniqDSet (UniqDSet set) x = UniqDSet (addToUDFM set x x)
addListToUniqDSet :: Uniquable a => UniqDSet a -> [a] -> UniqDSet a
addListToUniqDSet = foldl' addOneToUniqDSet
delOneFromUniqDSet :: Uniquable a => UniqDSet a -> a -> UniqDSet a
-delOneFromUniqDSet = delFromUDFM
+delOneFromUniqDSet (UniqDSet s) = UniqDSet . delFromUDFM s
delListFromUniqDSet :: Uniquable a => UniqDSet a -> [a] -> UniqDSet a
-delListFromUniqDSet = delListFromUDFM
+delListFromUniqDSet (UniqDSet s) = UniqDSet . delListFromUDFM s
unionUniqDSets :: UniqDSet a -> UniqDSet a -> UniqDSet a
-unionUniqDSets = plusUDFM
+unionUniqDSets (UniqDSet s) (UniqDSet t) = UniqDSet (plusUDFM s t)
unionManyUniqDSets :: [UniqDSet a] -> UniqDSet a
unionManyUniqDSets [] = emptyUniqDSet
unionManyUniqDSets sets = foldr1 unionUniqDSets sets
minusUniqDSet :: UniqDSet a -> UniqDSet a -> UniqDSet a
-minusUniqDSet = minusUDFM
+minusUniqDSet (UniqDSet s) (UniqDSet t) = UniqDSet (minusUDFM s t)
uniqDSetMinusUniqSet :: UniqDSet a -> UniqSet b -> UniqDSet a
-uniqDSetMinusUniqSet xs ys = udfmMinusUFM xs (getUniqSet ys)
+uniqDSetMinusUniqSet xs ys
+ = UniqDSet (udfmMinusUFM (getUniqDSet xs) (getUniqSet ys))
intersectUniqDSets :: UniqDSet a -> UniqDSet a -> UniqDSet a
-intersectUniqDSets = intersectUDFM
+intersectUniqDSets (UniqDSet s) (UniqDSet t) = UniqDSet (intersectUDFM s t)
uniqDSetIntersectUniqSet :: UniqDSet a -> UniqSet b -> UniqDSet a
-uniqDSetIntersectUniqSet xs ys = xs `udfmIntersectUFM` getUniqSet ys
-
-intersectsUniqDSets :: UniqDSet a -> UniqDSet a -> Bool
-intersectsUniqDSets = intersectsUDFM
+uniqDSetIntersectUniqSet xs ys
+ = UniqDSet (udfmIntersectUFM (getUniqDSet xs) (getUniqSet ys))
foldUniqDSet :: (a -> b -> b) -> b -> UniqDSet a -> b
-foldUniqDSet = foldUDFM
+foldUniqDSet c n (UniqDSet s) = foldUDFM c n s
elementOfUniqDSet :: Uniquable a => a -> UniqDSet a -> Bool
-elementOfUniqDSet = elemUDFM
+elementOfUniqDSet k = elemUDFM k . getUniqDSet
filterUniqDSet :: (a -> Bool) -> UniqDSet a -> UniqDSet a
-filterUniqDSet = filterUDFM
+filterUniqDSet p (UniqDSet s) = UniqDSet (filterUDFM p s)
sizeUniqDSet :: UniqDSet a -> Int
-sizeUniqDSet = sizeUDFM
+sizeUniqDSet = sizeUDFM . getUniqDSet
isEmptyUniqDSet :: UniqDSet a -> Bool
-isEmptyUniqDSet = isNullUDFM
+isEmptyUniqDSet = isNullUDFM . getUniqDSet
lookupUniqDSet :: Uniquable a => UniqDSet a -> a -> Maybe a
-lookupUniqDSet = lookupUDFM
+lookupUniqDSet = lookupUDFM . getUniqDSet
uniqDSetToList :: UniqDSet a -> [a]
-uniqDSetToList = eltsUDFM
+uniqDSetToList = eltsUDFM . getUniqDSet
partitionUniqDSet :: (a -> Bool) -> UniqDSet a -> (UniqDSet a, UniqDSet a)
-partitionUniqDSet = partitionUDFM
+partitionUniqDSet p = coerce . partitionUDFM p . getUniqDSet
+
+-- Two 'UniqDSet's are considered equal if they contain the same
+-- uniques.
+instance Eq (UniqDSet a) where
+ UniqDSet a == UniqDSet b = equalKeysUDFM a b
+
+getUniqDSet :: UniqDSet a -> UniqDFM a
+getUniqDSet = getUniqDSet'
+
+instance Outputable a => Outputable (UniqDSet a) where
+ ppr = pprUniqDSet ppr
+
+pprUniqDSet :: (a -> SDoc) -> UniqDSet a -> SDoc
+pprUniqDSet f (UniqDSet s) = pprUniqDFM f s
diff --git a/compiler/utils/UniqFM.hs b/compiler/utils/UniqFM.hs
index d4a024d34c..33d73cc60c 100644
--- a/compiler/utils/UniqFM.hs
+++ b/compiler/utils/UniqFM.hs
@@ -336,7 +336,7 @@ nonDetUFMToList (UFM m) = map (\(k, v) -> (getUnique k, v)) $ M.toList m
ufmToIntMap :: UniqFM elt -> M.IntMap elt
ufmToIntMap (UFM m) = m
--- Determines whether two 'UniqFm's contain the same keys.
+-- Determines whether two 'UniqFM's contain the same keys.
equalKeysUFM :: UniqFM a -> UniqFM b -> Bool
equalKeysUFM (UFM m1) (UFM m2) = liftEq (\_ _ -> True) m1 m2