summaryrefslogtreecommitdiff
path: root/compiler/utils/UniqDSet.hs
blob: c2ace5787f3cfc1f901d399ffc1f95899d6c97fa (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
-- (c) Bartosz Nitka, Facebook, 2015

-- |
-- Specialised deterministic sets, for things with @Uniques@
--
-- Based on 'UniqDFM's (as you would expect).
-- See Note [Deterministic UniqFM] in UniqDFM for explanation why we need it.
--
-- Basically, the things need to be in class 'Uniquable'.

{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE DeriveDataTypeable #-}

module UniqDSet (
        -- * Unique set type
        UniqDSet,    -- newtype wrapper around UniqDFM a
        getUniqDSet,
        pprUniqDSet,

        -- ** Manipulating these sets
        delOneFromUniqDSet, delListFromUniqDSet,
        emptyUniqDSet,
        unitUniqDSet,
        mkUniqDSet,
        addOneToUniqDSet, addListToUniqDSet,
        unionUniqDSets, unionManyUniqDSets,
        minusUniqDSet, uniqDSetMinusUniqSet,
        intersectUniqDSets, uniqDSetIntersectUniqSet,
        foldUniqDSet,
        elementOfUniqDSet,
        filterUniqDSet,
        sizeUniqDSet,
        isEmptyUniqDSet,
        lookupUniqDSet,
        uniqDSetToList,
        partitionUniqDSet,
        mapUniqDSet
    ) where

import GhcPrelude

import Outputable
import UniqDFM
import UniqSet
import Unique

import Data.Coerce
import Data.Data
import qualified Data.Semigroup as Semi

-- See Note [UniqSet invariant] in UniqSet.hs for why we want a newtype here.
-- Beyond preserving invariants, we may also want to 'override' typeclass
-- instances.

-- | A set of 'Uniquable' things, represented as a 'UniqDFM' in which every
-- element is stored as the value under its own key (see 'unitUniqDSet' and
-- 'addOneToUniqDSet'). The constructor is not exported; use 'getUniqDSet'
-- to reach the underlying map.
--
-- 'Data', 'Semi.Semigroup' and 'Monoid' are derived here; 'Eq' and
-- 'Outputable' are given by hand further down in this module.
newtype UniqDSet a = UniqDSet {getUniqDSet' :: UniqDFM a}
                   deriving (Data, Semi.Semigroup, Monoid)

-- | The set with no elements: 'emptyUDFM' wrapped in the newtype.
emptyUniqDSet :: UniqDSet a
emptyUniqDSet = UniqDSet { getUniqDSet' = emptyUDFM }

-- | Build a set holding exactly one element. The element is passed to
-- 'unitUDFM' both as the key and as the stored value.
unitUniqDSet :: Uniquable a => a -> UniqDSet a
unitUniqDSet x = UniqDSet singletonMap
  where
    singletonMap = unitUDFM x x

-- | Build a set from a list by adding the elements, left to right, to
-- 'emptyUniqDSet' (a strict left fold of 'addOneToUniqDSet').
mkUniqDSet :: Uniquable a => [a] -> UniqDSet a
mkUniqDSet xs = addListToUniqDSet emptyUniqDSet xs

-- | Add a single element to a set, storing it in the underlying map as the
-- value under its own key.
--
-- The new element always goes to the right of existing ones.
addOneToUniqDSet :: Uniquable a => UniqDSet a -> a -> UniqDSet a
addOneToUniqDSet set x = UniqDSet (addToUDFM (getUniqDSet set) x x)

-- | Add every element of a list to a set, one at a time from the head of
-- the list, using a strict left fold of 'addOneToUniqDSet'.
addListToUniqDSet :: Uniquable a => UniqDSet a -> [a] -> UniqDSet a
addListToUniqDSet set xs = foldl' addOneToUniqDSet set xs

-- | Remove a single element from a set by delegating to 'delFromUDFM' on
-- the underlying map.
delOneFromUniqDSet :: Uniquable a => UniqDSet a -> a -> UniqDSet a
delOneFromUniqDSet set x = UniqDSet (delFromUDFM (getUniqDSet set) x)

-- | Remove every element of a list from a set by delegating to
-- 'delListFromUDFM' on the underlying map.
delListFromUniqDSet :: Uniquable a => UniqDSet a -> [a] -> UniqDSet a
delListFromUniqDSet set xs = UniqDSet (delListFromUDFM (getUniqDSet set) xs)

-- | Union of two sets: the underlying maps are combined with 'plusUDFM',
-- left operand first.
unionUniqDSets :: UniqDSet a -> UniqDSet a -> UniqDSet a
unionUniqDSets lhs rhs = UniqDSet (plusUDFM (getUniqDSet lhs) (getUniqDSet rhs))

-- | Union of a whole list of sets. An empty list yields 'emptyUniqDSet';
-- otherwise the sets are combined with 'unionUniqDSets', associating to the
-- right.
unionManyUniqDSets :: [UniqDSet a] -> UniqDSet a
unionManyUniqDSets sets = case sets of
  [] -> emptyUniqDSet
  -- 'foldr1' is safe here: this branch is only reached for non-empty lists.
  _  -> foldr1 unionUniqDSets sets

-- | Difference of two sets: applies 'minusUDFM' to the underlying maps,
-- with the first argument as the set being subtracted from.
minusUniqDSet :: UniqDSet a -> UniqDSet a -> UniqDSet a
minusUniqDSet lhs rhs = UniqDSet (minusUDFM (getUniqDSet lhs) (getUniqDSet rhs))

-- | Subtract a 'UniqSet' from a 'UniqDSet'. Both are unwrapped to their
-- underlying maps and handed to 'udfmMinusUFM'; the element types of the two
-- sets need not agree.
uniqDSetMinusUniqSet :: UniqDSet a -> UniqSet b -> UniqDSet a
uniqDSetMinusUniqSet (UniqDSet dmap) ys = UniqDSet (udfmMinusUFM dmap toRemove)
  where
    toRemove = getUniqSet ys

-- | Intersection of two sets: applies 'intersectUDFM' to the underlying
-- maps, left operand first.
intersectUniqDSets :: UniqDSet a -> UniqDSet a -> UniqDSet a
intersectUniqDSets lhs rhs =
  UniqDSet (intersectUDFM (getUniqDSet lhs) (getUniqDSet rhs))

-- | Intersect a 'UniqDSet' with a 'UniqSet'. Both are unwrapped to their
-- underlying maps and handed to 'udfmIntersectUFM'; the element types of the
-- two sets need not agree.
uniqDSetIntersectUniqSet :: UniqDSet a -> UniqSet b -> UniqDSet a
uniqDSetIntersectUniqSet (UniqDSet dmap) ys = UniqDSet (udfmIntersectUFM dmap toKeep)
  where
    toKeep = getUniqSet ys

-- | Fold over the elements of a set with a combining function and a
-- starting value, by delegating to 'foldUDFM' on the underlying map.
foldUniqDSet :: (a -> b -> b) -> b -> UniqDSet a -> b
foldUniqDSet c n = foldUDFM c n . getUniqDSet

-- | Membership test: asks 'elemUDFM' whether the given element is a key of
-- the underlying map.
elementOfUniqDSet :: Uniquable a => a -> UniqDSet a -> Bool
elementOfUniqDSet k (UniqDSet dmap) = elemUDFM k dmap

-- | Filter a set with a predicate by running 'filterUDFM' over the
-- underlying map and re-wrapping the result.
filterUniqDSet :: (a -> Bool) -> UniqDSet a -> UniqDSet a
filterUniqDSet p = UniqDSet . filterUDFM p . getUniqDSet

-- | Size of the set, as reported by 'sizeUDFM' on the underlying map.
sizeUniqDSet :: UniqDSet a -> Int
sizeUniqDSet (UniqDSet dmap) = sizeUDFM dmap

-- | Emptiness test, answered by 'isNullUDFM' on the underlying map.
isEmptyUniqDSet :: UniqDSet a -> Bool
isEmptyUniqDSet (UniqDSet dmap) = isNullUDFM dmap

-- | Look an element up in the set, using it as the key for 'lookupUDFM' on
-- the underlying map. The result is whatever value is stored under that key,
-- or 'Nothing'.
lookupUniqDSet :: Uniquable a => UniqDSet a -> a -> Maybe a
lookupUniqDSet (UniqDSet dmap) x = lookupUDFM dmap x

-- | The elements of the set as a list, as produced by 'eltsUDFM' on the
-- underlying map.
uniqDSetToList :: UniqDSet a -> [a]
uniqDSetToList (UniqDSet dmap) = eltsUDFM dmap

-- | Split a set in two with a predicate, by delegating to 'partitionUDFM'.
-- The resulting pair of maps is converted to a pair of sets with 'coerce',
-- so no re-wrapping work happens at runtime.
partitionUniqDSet :: (a -> Bool) -> UniqDSet a -> (UniqDSet a, UniqDSet a)
partitionUniqDSet p (UniqDSet dmap) = coerce (partitionUDFM p dmap)

-- | Apply a function to every element and build a fresh set from the
-- results. The set is rebuilt from scratch with 'mkUniqDSet', so each
-- result is stored under its own key.
--
-- See Note [UniqSet invariant] in UniqSet.hs
mapUniqDSet :: Uniquable b => (a -> b) -> UniqDSet a -> UniqDSet b
mapUniqDSet f set = mkUniqDSet [f x | x <- uniqDSetToList set]

-- Equality on 'UniqDSet' is decided purely by 'equalKeysUDFM': two sets
-- are equal when they contain the same uniques. No 'Eq' instance on the
-- element type is required.
instance Eq (UniqDSet a) where
  lhs == rhs = equalKeysUDFM (getUniqDSet lhs) (getUniqDSet rhs)

-- | Unwrap a set to the 'UniqDFM' it is built on. This is the exported
-- accessor; the record field @getUniqDSet'@ is not in the export list.
getUniqDSet :: UniqDSet a -> UniqDFM a
getUniqDSet (UniqDSet dmap) = dmap

-- Pretty-print a set by rendering each element with its own 'ppr' and
-- laying the results out via 'pprUniqDSet'.
instance Outputable a => Outputable (UniqDSet a) where
  ppr set = pprUniqDSet ppr set

-- | Pretty-print a set given a printer for its elements: the element list
-- from 'uniqDSetToList' is passed through 'pprWithCommas' and the result is
-- wrapped with 'braces'.
pprUniqDSet :: (a -> SDoc) -> UniqDSet a -> SDoc
pprUniqDSet f set = braces (pprWithCommas f (uniqDSetToList set))