summaryrefslogtreecommitdiff
path: root/compiler/utils/ListSetOps.hs
blob: c311ac9c854431316770509dae791b72e30947f6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
{-
(c) The University of Glasgow 2006
(c) The GRASP/AQUA Project, Glasgow University, 1992-1998

\section[ListSetOps]{Set-like operations on lists}
-}

{-# LANGUAGE CPP #-}

module ListSetOps (
        unionLists, minusList,

        -- Association lists
        Assoc, assoc, assocMaybe, assocUsing, assocDefault, assocDefaultUsing,

        -- Duplicate handling
        hasNoDups, removeDups, findDupsEq,
        equivClasses,

        -- Indexing
        getNth
   ) where

#include "HsVersions.h"

import GhcPrelude

import Outputable
import Util

import Data.List
import qualified Data.List.NonEmpty as NE
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.Set as S

-- | Return the element of @xs@ at zero-based position @n@, using '!!'.
--
-- The ASSERT2 macro (from HsVersions.h) asserts that
-- @xs `lengthExceeds` n@ holds (presumably: @n@ is a valid index); when the
-- assertion fails it reports both @n@ and @xs@, which is why the element
-- type must be 'Outputable'.
getNth :: Outputable a => [a] -> Int -> a
getNth xs n = ASSERT2( xs `lengthExceeds` n, ppr n $$ ppr xs )
             xs !! n

{-
************************************************************************
*                                                                      *
        Treating lists as sets
        Assumes the lists contain no duplicates, but are unordered
*                                                                      *
************************************************************************
-}


-- | Union of two lists viewed as sets: the elements of @xs@ that are not
-- already present in @ys@, in their original order, followed by all of @ys@.
unionLists :: (Outputable a, Eq a) => [a] -> [a] -> [a]
-- Assumes that the arguments contain no duplicates.
-- Each element of xs is checked against ys with isn'tIn; the WARN flags
-- calls where either input is longer than 100 elements.
unionLists xs ys
  = WARN(lengthExceeds xs 100 || lengthExceeds ys 100, ppr xs $$ ppr ys)
    foldr keepIfAbsent ys xs
  where
    keepIfAbsent x acc
      | isn'tIn "unionLists" x ys = x : acc
      | otherwise                 = acc

-- | @minusList xs ys@ keeps, in their original order, the elements of @xs@
-- that do not occur anywhere in @ys@. Repeated elements of @xs@ that
-- survive are all kept. With /m/ = @length xs@ and /n/ = @length ys@ the
-- general cost is /O((m + n) log n)/, because @ys@ is first turned into a
-- 'S.Set'.
--
-- Tiny inputs skip the set entirely: an empty @xs@ or @ys@ costs /O(1)/,
-- a singleton @xs@ costs /O(n)/, and a singleton @ys@ costs /O(m)/.
minusList :: Ord a => [a] -> [a] -> [a]
-- Building a set only pays off once both lists have at least two elements;
-- for a single lookup, or a single element to remove, a linear scan is
-- cheaper. A plain O(m*n) scan might still win for slightly longer xs
-- (roughly until m and log n become comparable), but that is not tuned.
minusList xs ys = case (xs, ys) of
  ([], _)  -> []
  ([x], _)
    | x `elem` ys -> []
    | otherwise   -> xs
  (_, [])  -> xs
  (_, [y]) -> filter (/= y) xs
  _        -> filter (`S.notMember` ySet) xs
  where
    ySet = S.fromList ys

{-
************************************************************************
*                                                                      *
\subsection[Utils-assoc]{Association lists}
*                                                                      *
************************************************************************

Inefficient finite maps based on association lists and equality.
-}

-- | A finite map represented as a list of key/value pairs. Lookups scan
-- the list from the front, so the first matching pair wins.
type Assoc a b = [(a,b)]

assoc             :: (Eq a) => String -> Assoc a b -> a -> b
assocDefault      :: (Eq a) => b -> Assoc a b -> a -> b
assocUsing        :: (a -> a -> Bool) -> String -> Assoc a b -> a -> b
assocMaybe        :: (Eq a) => Assoc a b -> a -> Maybe b
assocDefaultUsing :: (a -> a -> Bool) -> b -> Assoc a b -> a -> b

-- The general worker: return the value of the first pair whose key k
-- satisfies (k `eq` key), or deflt if no pair does. The comprehension is
-- consumed lazily, so the scan stops at the first match. assoc,
-- assocDefault and assocUsing are all defined in terms of it.
assocDefaultUsing eq deflt alist key =
  case [v | (k, v) <- alist, k `eq` key] of
    (v : _) -> v
    []      -> deflt

-- Keys compared with (==); the result is a panic mentioning crash_msg
-- when the key is absent.
assoc crash_msg = assocUsing (==) crash_msg

-- Keys compared with (==); falls back to deflt when the key is absent.
assocDefault = assocDefaultUsing (==)

-- Keys compared with the supplied eq; the result is a panic mentioning
-- crash_msg when the key is absent.
assocUsing eq crash_msg = assocDefaultUsing eq (panic ("Failed in assoc: " ++ crash_msg))

-- Look up the value of the first pair whose key is (==) to the given key,
-- or Nothing if there is none. This is exactly the standard 'lookup' with
-- its arguments flipped. The previous hand-rolled loop re-implemented it
-- under a local binding that shadowed the Prelude name.
assocMaybe alist key = lookup key alist

{-
************************************************************************
*                                                                      *
\subsection[Utils-dups]{Duplicate-handling}
*                                                                      *
************************************************************************
-}

-- | 'True' iff no element of the list is equal to an earlier one.
-- Every element seen so far is remembered in a list, and each new element
-- is tested against it with 'isIn' (tagged "hasNoDups"). The cost is
-- therefore quadratic in the length of the input.
hasNoDups :: (Eq a) => [a] -> Bool
hasNoDups = go []
  where
    go _    []         = True
    go seen (y : rest)
      | isIn "hasNoDups" y seen = False
      | otherwise               = go (y : seen) rest

-- | Group a list into equivalence classes under the given comparison.
-- The list is sorted with 'sortBy', which is stable, and runs of adjacent
-- elements comparing 'EQ' become one 'NonEmpty' group. Groups appear in
-- ascending order, and within a group elements keep their original
-- relative order.
equivClasses :: (a -> a -> Ordering) -- Comparison
             -> [a]
             -> [NonEmpty a]
equivClasses cmp items = case items of
  []      -> []
  [stuff] -> [stuff :| []]   -- nothing to sort or group
  _       -> NE.groupBy sameClass (sortBy cmp items)
  where
    sameClass a b = cmp a b == EQ

-- | Split a list into one representative per equivalence class (under the
-- given comparison) and the classes that have more than one member.
-- The representative of a class is the first element of the corresponding
-- group from 'equivClasses'. Representatives and duplicate groups both
-- appear in the order of those groups.
removeDups :: (a -> a -> Ordering) -- Comparison function
           -> [a]
           -> ([a],          -- List with no duplicates
               [NonEmpty a]) -- List of duplicate groups.  One representative
                             -- from each group appears in the first result
removeDups _   []  = ([], [])
removeDups _   [x] = ([x], [])
removeDups cmp xs
  = case mapAccumR step [] (equivClasses cmp xs) of
      (dupGroups, reps) -> (reps, dupGroups)
  where
    -- Emit each group's head as its representative. A group with a
    -- non-empty tail is also pushed onto the accumulated duplicate groups.
    -- mapAccumR walks right to left, so the accumulator ends up in
    -- left-to-right order.
    step :: [NonEmpty a] -> NonEmpty a -> ([NonEmpty a], a)
    step acc grp@(rep :| rest)
      | null rest = (acc, rep)
      | otherwise = (grp : acc, rep)

-- | Collect groups of mutually duplicated elements using the given
-- equality. Walking the list left to right, each element @x@ gathers every
-- later element @y@ with @eq x y@ into the group @x :| ys@, and those
-- elements are removed from further consideration. An element that
-- gathers nothing produces no group. This takes a quadratic number of
-- @eq@ calls in the worst case.
findDupsEq :: (a->a->Bool) -> [a] -> [NonEmpty a]
findDupsEq _  []     = []
findDupsEq eq (x:xs) =
  case partition (eq x) xs of
    ([], _)                 -> findDupsEq eq xs
    (matches@(_:_), others) -> (x :| matches) : findDupsEq eq others