summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorM Farkas-Dyck <strake888@proton.me>2022-10-17 22:46:03 -0800
committerMarge Bot <ben+marge-bot@smart-cactus.org>2022-10-19 10:47:13 -0400
commitc3732c6210972a992e1153b0667cf8abf0351acd (patch)
treee215a94569287d58320dad20ca0accbff8023052
parent83638dce4e20097b9b7073534e488a92dce6e88f (diff)
downloadhaskell-c3732c6210972a992e1153b0667cf8abf0351acd.tar.gz
Enforce invariant of `ListBag` constructor.
-rw-r--r--compiler/GHC/Data/Bag.hs42
-rw-r--r--compiler/GHC/Tc/Deriv.hs4
-rw-r--r--compiler/GHC/Utils/Monad.hs55
3 files changed, 65 insertions, 36 deletions
diff --git a/compiler/GHC/Data/Bag.hs b/compiler/GHC/Data/Bag.hs
index 91b079f419..a79bc41e6f 100644
--- a/compiler/GHC/Data/Bag.hs
+++ b/compiler/GHC/Data/Bag.hs
@@ -18,7 +18,7 @@ module GHC.Data.Bag (
concatBag, catBagMaybes, foldBag,
isEmptyBag, isSingletonBag, consBag, snocBag, anyBag, allBag,
listToBag, nonEmptyToBag, bagToList, headMaybe, mapAccumBagL,
- concatMapBag, concatMapBagPair, mapMaybeBag,
+ concatMapBag, concatMapBagPair, mapMaybeBag, unzipBag,
mapBagM, mapBagM_,
flatMapBagM, flatMapBagPairM,
mapAndUnzipBagM, mapAccumBagLM,
@@ -33,9 +33,10 @@ import GHC.Utils.Misc
import GHC.Utils.Monad
import Control.Monad
import Data.Data
-import Data.Maybe( mapMaybe, listToMaybe )
+import Data.Maybe( mapMaybe )
import Data.List ( partition, mapAccumL )
import Data.List.NonEmpty ( NonEmpty(..) )
+import qualified Data.List.NonEmpty as NE
import qualified Data.Semigroup ( (<>) )
infixr 3 `consBag`
@@ -45,7 +46,7 @@ data Bag a
= EmptyBag
| UnitBag a
| TwoBags (Bag a) (Bag a) -- INVARIANT: neither branch is empty
- | ListBag [a] -- INVARIANT: the list is non-empty
+ | ListBag (NonEmpty a)
deriving (Foldable, Functor, Traversable)
emptyBag :: Bag a
@@ -90,7 +91,7 @@ isSingletonBag :: Bag a -> Bool
isSingletonBag EmptyBag = False
isSingletonBag (UnitBag _) = True
isSingletonBag (TwoBags _ _) = False -- Neither is empty
-isSingletonBag (ListBag xs) = isSingleton xs
+isSingletonBag (ListBag (_:|xs)) = null xs
filterBag :: (a -> Bool) -> Bag a -> Bag a
filterBag _ EmptyBag = EmptyBag
@@ -98,7 +99,7 @@ filterBag pred b@(UnitBag val) = if pred val then b else EmptyBag
filterBag pred (TwoBags b1 b2) = sat1 `unionBags` sat2
where sat1 = filterBag pred b1
sat2 = filterBag pred b2
-filterBag pred (ListBag vs) = listToBag (filter pred vs)
+filterBag pred (ListBag vs) = listToBag (filter pred (toList vs))
filterBagM :: Monad m => (a -> m Bool) -> Bag a -> m (Bag a)
filterBagM _ EmptyBag = return EmptyBag
@@ -111,7 +112,7 @@ filterBagM pred (TwoBags b1 b2) = do
sat2 <- filterBagM pred b2
return (sat1 `unionBags` sat2)
filterBagM pred (ListBag vs) = do
- sat <- filterM pred vs
+ sat <- filterM pred (toList vs)
return (listToBag sat)
allBag :: (a -> Bool) -> Bag a -> Bool
@@ -135,9 +136,7 @@ anyBagM p (TwoBags b1 b2) = do flag <- anyBagM p b1
anyBagM p (ListBag xs) = anyM p xs
concatBag :: Bag (Bag a) -> Bag a
-concatBag bss = foldr add emptyBag bss
- where
- add bs rs = bs `unionBags` rs
+concatBag = foldr unionBags emptyBag
catBagMaybes :: Bag (Maybe a) -> Bag a
catBagMaybes bs = foldr add emptyBag bs
@@ -155,7 +154,7 @@ partitionBag pred (TwoBags b1 b2)
where (sat1, fail1) = partitionBag pred b1
(sat2, fail2) = partitionBag pred b2
partitionBag pred (ListBag vs) = (listToBag sats, listToBag fails)
- where (sats, fails) = partition pred vs
+ where (sats, fails) = partition pred (toList vs)
partitionBagWith :: (a -> Either b c) -> Bag a
@@ -171,7 +170,7 @@ partitionBagWith pred (TwoBags b1 b2)
where (sat1, fail1) = partitionBagWith pred b1
(sat2, fail2) = partitionBagWith pred b2
partitionBagWith pred (ListBag vs) = (listToBag sats, listToBag fails)
- where (sats, fails) = partitionWith pred vs
+ where (sats, fails) = partitionWith pred (toList vs)
foldBag :: (r -> r -> r) -- Replace TwoBags with this; should be associative
-> (a -> r) -- Replace UnitBag with this
@@ -220,7 +219,7 @@ mapMaybeBag f (UnitBag x) = case f x of
Nothing -> EmptyBag
Just y -> UnitBag y
mapMaybeBag f (TwoBags b1 b2) = unionBags (mapMaybeBag f b1) (mapMaybeBag f b2)
-mapMaybeBag f (ListBag xs) = ListBag (mapMaybe f xs)
+mapMaybeBag f (ListBag xs) = listToBag $ mapMaybe f (toList xs)
mapBagM :: Monad m => (a -> m b) -> Bag a -> m (Bag b)
mapBagM _ EmptyBag = return EmptyBag
@@ -267,7 +266,7 @@ mapAndUnzipBagM f (TwoBags b1 b2) = do (r1,s1) <- mapAndUnzipBagM f b1
(r2,s2) <- mapAndUnzipBagM f b2
return (TwoBags r1 r2, TwoBags s1 s2)
mapAndUnzipBagM f (ListBag xs) = do ts <- mapM f xs
- let (rs,ss) = unzip ts
+ let (rs,ss) = NE.unzip ts
return (ListBag rs, ListBag ss)
mapAccumBagL ::(acc -> x -> (acc, y)) -- ^ combining function
@@ -298,20 +297,31 @@ mapAccumBagLM f s (ListBag xs) = do { (s', xs') <- mapAccumLM f s xs
listToBag :: [a] -> Bag a
listToBag [] = EmptyBag
listToBag [x] = UnitBag x
-listToBag vs = ListBag vs
+listToBag (x:xs) = ListBag (x:|xs)
nonEmptyToBag :: NonEmpty a -> Bag a
nonEmptyToBag (x :| []) = UnitBag x
-nonEmptyToBag (x :| xs) = ListBag (x : xs)
+nonEmptyToBag xs = ListBag xs
bagToList :: Bag a -> [a]
bagToList b = foldr (:) [] b
+unzipBag :: Bag (a, b) -> (Bag a, Bag b)
+unzipBag EmptyBag = (EmptyBag, EmptyBag)
+unzipBag (UnitBag (a, b)) = (UnitBag a, UnitBag b)
+unzipBag (TwoBags xs1 xs2) = (TwoBags as1 as2, TwoBags bs1 bs2)
+ where
+ (as1, bs1) = unzipBag xs1
+ (as2, bs2) = unzipBag xs2
+unzipBag (ListBag xs) = (ListBag as, ListBag bs)
+ where
+ (as, bs) = NE.unzip xs
+
headMaybe :: Bag a -> Maybe a
headMaybe EmptyBag = Nothing
headMaybe (UnitBag v) = Just v
headMaybe (TwoBags b1 _) = headMaybe b1
-headMaybe (ListBag l) = listToMaybe l
+headMaybe (ListBag (v:|_)) = Just v
instance (Outputable a) => Outputable (Bag a) where
ppr bag = braces (pprWithCommas ppr (bagToList bag))
diff --git a/compiler/GHC/Tc/Deriv.hs b/compiler/GHC/Tc/Deriv.hs
index 0a873ff05e..feefb05ac1 100644
--- a/compiler/GHC/Tc/Deriv.hs
+++ b/compiler/GHC/Tc/Deriv.hs
@@ -289,8 +289,8 @@ renameDeriv inst_infos bagBinds
-- Bring the extra deriving stuff into scope
-- before renaming the instances themselves
; traceTc "rnd" (vcat (map (\i -> pprInstInfoDetails i $$ text "") inst_infos))
- ; (aux_binds, aux_sigs) <- mapAndUnzipBagM return bagBinds
- ; let aux_val_binds = ValBinds NoAnnSortKey aux_binds (bagToList aux_sigs)
+ ; let (aux_binds, aux_sigs) = unzipBag bagBinds
+ aux_val_binds = ValBinds NoAnnSortKey aux_binds (bagToList aux_sigs)
-- Importantly, we use rnLocalValBindsLHS, not rnTopBindsLHS, to rename
-- auxiliary bindings as if they were defined locally.
-- See Note [Auxiliary binders] in GHC.Tc.Deriv.Generate.
diff --git a/compiler/GHC/Utils/Monad.hs b/compiler/GHC/Utils/Monad.hs
index 6f93ac2c27..c5e24794a4 100644
--- a/compiler/GHC/Utils/Monad.hs
+++ b/compiler/GHC/Utils/Monad.hs
@@ -1,3 +1,5 @@
+{-# LANGUAGE MonadComprehensions #-}
+
-- | Utilities related to Monad and Applicative classes
-- Mostly for backwards compatibility.
@@ -28,8 +30,11 @@ import GHC.Prelude
import Control.Monad
import Control.Monad.Fix
import Control.Monad.IO.Class
+import Control.Monad.Trans.State.Strict (StateT (..))
import Data.Foldable (sequenceA_, foldlM, foldrM)
import Data.List (unzip4, unzip5, zipWith4)
+import Data.List.NonEmpty (NonEmpty (..))
+import Data.Tuple (swap)
-------------------------------------------------------------------------------
-- Common functions
@@ -137,18 +142,28 @@ mapAndUnzip5M f xs = unzip5 <$> traverse f xs
-- variant and use it where appropriate.
-- | Monadic version of mapAccumL
-mapAccumLM :: Monad m
+mapAccumLM :: (Monad m, Traversable t)
=> (acc -> x -> m (acc, y)) -- ^ combining function
-> acc -- ^ initial state
- -> [x] -- ^ inputs
- -> m (acc, [y]) -- ^ final state, outputs
-{-# INLINE mapAccumLM #-}
+ -> t x -- ^ inputs
+ -> m (acc, t y) -- ^ final state, outputs
+{-# INLINE [1] mapAccumLM #-}
-- INLINE pragma. mapAccumLM is called in inner loops. Like 'map',
-- we inline it so that we can take advantage of knowing 'f'.
-- This makes a few percent difference (in compiler allocations)
-- when compiling perf/compiler/T9675
-mapAccumLM f s xs =
- go s xs
+mapAccumLM f s = fmap swap . flip runStateT s . traverse f'
+ where
+ f' = StateT . (fmap . fmap) swap . flip f
+{-# RULES "mapAccumLM/List" mapAccumLM = mapAccumLM_List #-}
+{-# RULES "mapAccumLM/NonEmpty" mapAccumLM = mapAccumLM_NonEmpty #-}
+
+mapAccumLM_List
+ :: Monad m
+ => (acc -> x -> m (acc, y))
+ -> acc -> [x] -> m (acc, [y])
+{-# INLINE mapAccumLM_List #-}
+mapAccumLM_List f s = go s
where
go s (x:xs) = do
(s1, x') <- f s x
@@ -156,6 +171,14 @@ mapAccumLM f s xs =
return (s2, x' : xs')
go s [] = return (s, [])
+mapAccumLM_NonEmpty
+ :: Monad m
+ => (acc -> x -> m (acc, y))
+ -> acc -> NonEmpty x -> m (acc, NonEmpty y)
+{-# INLINE mapAccumLM_NonEmpty #-}
+mapAccumLM_NonEmpty f s (x:|xs) =
+ [(s2, x':|xs') | (s1, x') <- f s x, (s2, xs') <- mapAccumLM_List f s1 xs]
+
-- | Monadic version of mapSnd
mapSndM :: (Applicative m, Traversable f) => (b -> m c) -> f (a,b) -> m (f (a,c))
mapSndM = traverse . traverse
@@ -174,25 +197,21 @@ mapMaybeM f = foldr g (pure [])
where g a = liftA2 (maybe id (:)) (f a)
-- | Monadic version of 'any', aborts the computation at the first @True@ value
-anyM :: Monad m => (a -> m Bool) -> [a] -> m Bool
-anyM f xs = go xs
- where
- go [] = return False
- go (x:xs) = do b <- f x
- if b then return True
- else go xs
+anyM :: (Monad m, Foldable f) => (a -> m Bool) -> f a -> m Bool
+anyM f = foldr (orM . f) (pure False)
-- | Monad version of 'all', aborts the computation at the first @False@ value
-allM :: Monad m => (a -> m Bool) -> [a] -> m Bool
-allM f bs = go bs
- where
- go [] = return True
- go (b:bs) = (f b) >>= (\bv -> if bv then go bs else return False)
+allM :: (Monad m, Foldable f) => (a -> m Bool) -> f a -> m Bool
+allM f = foldr (andM . f) (pure True)
-- | Monadic version of or
orM :: Monad m => m Bool -> m Bool -> m Bool
orM m1 m2 = m1 >>= \x -> if x then return True else m2
+-- | Monadic version of and
+andM :: Monad m => m Bool -> m Bool -> m Bool
+andM m1 m2 = m1 >>= \x -> if x then m2 else return False
+
-- | Monadic version of foldl that discards its result
foldlM_ :: (Monad m, Foldable t) => (a -> b -> m a) -> a -> t b -> m ()
foldlM_ = foldM_