diff options
author | M Farkas-Dyck <strake888@proton.me> | 2022-10-17 22:46:03 -0800 |
---|---|---|
committer | Marge Bot <ben+marge-bot@smart-cactus.org> | 2022-10-19 10:47:13 -0400 |
commit | c3732c6210972a992e1153b0667cf8abf0351acd (patch) | |
tree | e215a94569287d58320dad20ca0accbff8023052 | |
parent | 83638dce4e20097b9b7073534e488a92dce6e88f (diff) | |
download | haskell-c3732c6210972a992e1153b0667cf8abf0351acd.tar.gz |
Enforce invariant of `ListBag` constructor.
-rw-r--r-- | compiler/GHC/Data/Bag.hs | 42 | ||||
-rw-r--r-- | compiler/GHC/Tc/Deriv.hs | 4 | ||||
-rw-r--r-- | compiler/GHC/Utils/Monad.hs | 55 |
3 files changed, 65 insertions, 36 deletions
diff --git a/compiler/GHC/Data/Bag.hs b/compiler/GHC/Data/Bag.hs index 91b079f419..a79bc41e6f 100644 --- a/compiler/GHC/Data/Bag.hs +++ b/compiler/GHC/Data/Bag.hs @@ -18,7 +18,7 @@ module GHC.Data.Bag ( concatBag, catBagMaybes, foldBag, isEmptyBag, isSingletonBag, consBag, snocBag, anyBag, allBag, listToBag, nonEmptyToBag, bagToList, headMaybe, mapAccumBagL, - concatMapBag, concatMapBagPair, mapMaybeBag, + concatMapBag, concatMapBagPair, mapMaybeBag, unzipBag, mapBagM, mapBagM_, flatMapBagM, flatMapBagPairM, mapAndUnzipBagM, mapAccumBagLM, @@ -33,9 +33,10 @@ import GHC.Utils.Misc import GHC.Utils.Monad import Control.Monad import Data.Data -import Data.Maybe( mapMaybe, listToMaybe ) +import Data.Maybe( mapMaybe ) import Data.List ( partition, mapAccumL ) import Data.List.NonEmpty ( NonEmpty(..) ) +import qualified Data.List.NonEmpty as NE import qualified Data.Semigroup ( (<>) ) infixr 3 `consBag` @@ -45,7 +46,7 @@ data Bag a = EmptyBag | UnitBag a | TwoBags (Bag a) (Bag a) -- INVARIANT: neither branch is empty - | ListBag [a] -- INVARIANT: the list is non-empty + | ListBag (NonEmpty a) deriving (Foldable, Functor, Traversable) emptyBag :: Bag a @@ -90,7 +91,7 @@ isSingletonBag :: Bag a -> Bool isSingletonBag EmptyBag = False isSingletonBag (UnitBag _) = True isSingletonBag (TwoBags _ _) = False -- Neither is empty -isSingletonBag (ListBag xs) = isSingleton xs +isSingletonBag (ListBag (_:|xs)) = null xs filterBag :: (a -> Bool) -> Bag a -> Bag a filterBag _ EmptyBag = EmptyBag @@ -98,7 +99,7 @@ filterBag pred b@(UnitBag val) = if pred val then b else EmptyBag filterBag pred (TwoBags b1 b2) = sat1 `unionBags` sat2 where sat1 = filterBag pred b1 sat2 = filterBag pred b2 -filterBag pred (ListBag vs) = listToBag (filter pred vs) +filterBag pred (ListBag vs) = listToBag (filter pred (toList vs)) filterBagM :: Monad m => (a -> m Bool) -> Bag a -> m (Bag a) filterBagM _ EmptyBag = return EmptyBag @@ -111,7 +112,7 @@ filterBagM pred (TwoBags b1 b2) = do sat2 <- filterBagM pred b2 return (sat1 `unionBags` sat2) filterBagM pred (ListBag vs) = do - sat <- filterM pred vs + sat <- filterM pred (toList vs) return (listToBag sat) allBag :: (a -> Bool) -> Bag a -> Bool @@ -135,9 +136,7 @@ anyBagM p (TwoBags b1 b2) = do flag <- anyBagM p b1 anyBagM p (ListBag xs) = anyM p xs concatBag :: Bag (Bag a) -> Bag a -concatBag bss = foldr add emptyBag bss - where - add bs rs = bs `unionBags` rs +concatBag = foldr unionBags emptyBag catBagMaybes :: Bag (Maybe a) -> Bag a catBagMaybes bs = foldr add emptyBag bs @@ -155,7 +154,7 @@ partitionBag pred (TwoBags b1 b2) where (sat1, fail1) = partitionBag pred b1 (sat2, fail2) = partitionBag pred b2 partitionBag pred (ListBag vs) = (listToBag sats, listToBag fails) - where (sats, fails) = partition pred vs + where (sats, fails) = partition pred (toList vs) partitionBagWith :: (a -> Either b c) -> Bag a @@ -171,7 +170,7 @@ partitionBagWith pred (TwoBags b1 b2) where (sat1, fail1) = partitionBagWith pred b1 (sat2, fail2) = partitionBagWith pred b2 partitionBagWith pred (ListBag vs) = (listToBag sats, listToBag fails) - where (sats, fails) = partitionWith pred vs + where (sats, fails) = partitionWith pred (toList vs) foldBag :: (r -> r -> r) -- Replace TwoBags with this; should be associative -> (a -> r) -- Replace UnitBag with this @@ -220,7 +219,7 @@ mapMaybeBag f (UnitBag x) = case f x of Nothing -> EmptyBag Just y -> UnitBag y mapMaybeBag f (TwoBags b1 b2) = unionBags (mapMaybeBag f b1) (mapMaybeBag f b2) -mapMaybeBag f (ListBag xs) = ListBag (mapMaybe f xs) +mapMaybeBag f (ListBag xs) = listToBag $ mapMaybe f (toList xs) mapBagM :: Monad m => (a -> m b) -> Bag a -> m (Bag b) mapBagM _ EmptyBag = return EmptyBag @@ -267,7 +266,7 @@ mapAndUnzipBagM f (TwoBags b1 b2) = do (r1,s1) <- mapAndUnzipBagM f b1 (r2,s2) <- mapAndUnzipBagM f b2 return (TwoBags r1 r2, TwoBags s1 s2) mapAndUnzipBagM f (ListBag xs) = do ts <- mapM f xs - let (rs,ss) = unzip ts + let (rs,ss) = NE.unzip ts return (ListBag rs, ListBag ss) mapAccumBagL ::(acc -> x -> (acc, y)) -- ^ combining function @@ -298,20 +297,31 @@ mapAccumBagLM f s (ListBag xs) = do { (s', xs') <- mapAccumLM f s xs listToBag :: [a] -> Bag a listToBag [] = EmptyBag listToBag [x] = UnitBag x -listToBag vs = ListBag vs +listToBag (x:xs) = ListBag (x:|xs) nonEmptyToBag :: NonEmpty a -> Bag a nonEmptyToBag (x :| []) = UnitBag x -nonEmptyToBag (x :| xs) = ListBag (x : xs) +nonEmptyToBag xs = ListBag xs bagToList :: Bag a -> [a] bagToList b = foldr (:) [] b +unzipBag :: Bag (a, b) -> (Bag a, Bag b) +unzipBag EmptyBag = (EmptyBag, EmptyBag) +unzipBag (UnitBag (a, b)) = (UnitBag a, UnitBag b) +unzipBag (TwoBags xs1 xs2) = (TwoBags as1 as2, TwoBags bs1 bs2) + where + (as1, bs1) = unzipBag xs1 + (as2, bs2) = unzipBag xs2 +unzipBag (ListBag xs) = (ListBag as, ListBag bs) + where + (as, bs) = NE.unzip xs + headMaybe :: Bag a -> Maybe a headMaybe EmptyBag = Nothing headMaybe (UnitBag v) = Just v headMaybe (TwoBags b1 _) = headMaybe b1 -headMaybe (ListBag l) = listToMaybe l +headMaybe (ListBag (v:|_)) = Just v instance (Outputable a) => Outputable (Bag a) where ppr bag = braces (pprWithCommas ppr (bagToList bag)) diff --git a/compiler/GHC/Tc/Deriv.hs b/compiler/GHC/Tc/Deriv.hs index 0a873ff05e..feefb05ac1 100644 --- a/compiler/GHC/Tc/Deriv.hs +++ b/compiler/GHC/Tc/Deriv.hs @@ -289,8 +289,8 @@ renameDeriv inst_infos bagBinds -- Bring the extra deriving stuff into scope -- before renaming the instances themselves ; traceTc "rnd" (vcat (map (\i -> pprInstInfoDetails i $$ text "") inst_infos)) - ; (aux_binds, aux_sigs) <- mapAndUnzipBagM return bagBinds - ; let aux_val_binds = ValBinds NoAnnSortKey aux_binds (bagToList aux_sigs) + ; let (aux_binds, aux_sigs) = unzipBag bagBinds + aux_val_binds = ValBinds NoAnnSortKey aux_binds (bagToList aux_sigs) -- Importantly, we use rnLocalValBindsLHS, not rnTopBindsLHS, to rename -- auxiliary bindings as if they were defined locally. -- See Note [Auxiliary binders] in GHC.Tc.Deriv.Generate. diff --git a/compiler/GHC/Utils/Monad.hs b/compiler/GHC/Utils/Monad.hs index 6f93ac2c27..c5e24794a4 100644 --- a/compiler/GHC/Utils/Monad.hs +++ b/compiler/GHC/Utils/Monad.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE MonadComprehensions #-} + -- | Utilities related to Monad and Applicative classes -- Mostly for backwards compatibility. @@ -28,8 +30,11 @@ import GHC.Prelude import Control.Monad import Control.Monad.Fix import Control.Monad.IO.Class +import Control.Monad.Trans.State.Strict (StateT (..)) import Data.Foldable (sequenceA_, foldlM, foldrM) import Data.List (unzip4, unzip5, zipWith4) +import Data.List.NonEmpty (NonEmpty (..)) +import Data.Tuple (swap) ------------------------------------------------------------------------------- -- Common functions @@ -137,18 +142,28 @@ mapAndUnzip5M f xs = unzip5 <$> traverse f xs -- variant and use it where appropriate. -- | Monadic version of mapAccumL -mapAccumLM :: Monad m +mapAccumLM :: (Monad m, Traversable t) => (acc -> x -> m (acc, y)) -- ^ combining function -> acc -- ^ initial state - -> [x] -- ^ inputs - -> m (acc, [y]) -- ^ final state, outputs -{-# INLINE mapAccumLM #-} + -> t x -- ^ inputs + -> m (acc, t y) -- ^ final state, outputs +{-# INLINE [1] mapAccumLM #-} -- INLINE pragma. mapAccumLM is called in inner loops. Like 'map', -- we inline it so that we can take advantage of knowing 'f'. -- This makes a few percent difference (in compiler allocations) -- when compiling perf/compiler/T9675 -mapAccumLM f s xs = - go s xs +mapAccumLM f s = fmap swap . flip runStateT s . traverse f' + where + f' = StateT . (fmap . fmap) swap . flip f +{-# RULES "mapAccumLM/List" mapAccumLM = mapAccumLM_List #-} +{-# RULES "mapAccumLM/NonEmpty" mapAccumLM = mapAccumLM_NonEmpty #-} + +mapAccumLM_List + :: Monad m + => (acc -> x -> m (acc, y)) + -> acc -> [x] -> m (acc, [y]) +{-# INLINE mapAccumLM_List #-} +mapAccumLM_List f s = go s where go s (x:xs) = do (s1, x') <- f s x @@ -156,6 +171,14 @@ mapAccumLM f s xs = return (s2, x' : xs') go s [] = return (s, []) +mapAccumLM_NonEmpty + :: Monad m + => (acc -> x -> m (acc, y)) + -> acc -> NonEmpty x -> m (acc, NonEmpty y) +{-# INLINE mapAccumLM_NonEmpty #-} +mapAccumLM_NonEmpty f s (x:|xs) = + [(s2, x':|xs') | (s1, x') <- f s x, (s2, xs') <- mapAccumLM_List f s1 xs] + -- | Monadic version of mapSnd mapSndM :: (Applicative m, Traversable f) => (b -> m c) -> f (a,b) -> m (f (a,c)) mapSndM = traverse . traverse @@ -174,25 +197,21 @@ mapMaybeM f = foldr g (pure []) where g a = liftA2 (maybe id (:)) (f a) -- | Monadic version of 'any', aborts the computation at the first @True@ value -anyM :: Monad m => (a -> m Bool) -> [a] -> m Bool -anyM f xs = go xs - where - go [] = return False - go (x:xs) = do b <- f x - if b then return True - else go xs +anyM :: (Monad m, Foldable f) => (a -> m Bool) -> f a -> m Bool +anyM f = foldr (orM . f) (pure False) -- | Monad version of 'all', aborts the computation at the first @False@ value -allM :: Monad m => (a -> m Bool) -> [a] -> m Bool -allM f bs = go bs - where - go [] = return True - go (b:bs) = (f b) >>= (\bv -> if bv then go bs else return False) +allM :: (Monad m, Foldable f) => (a -> m Bool) -> f a -> m Bool +allM f = foldr (andM . f) (pure True) -- | Monadic version of or orM :: Monad m => m Bool -> m Bool -> m Bool orM m1 m2 = m1 >>= \x -> if x then return True else m2 +-- | Monadic version of and +andM :: Monad m => m Bool -> m Bool -> m Bool +andM m1 m2 = m1 >>= \x -> if x then m2 else return False + -- | Monadic version of foldl that discards its result foldlM_ :: (Monad m, Foldable t) => (a -> b -> m a) -> a -> t b -> m () foldlM_ = foldM_ |