diff options
author | Matthew Pickering <matthewtpickering@gmail.com> | 2016-11-29 14:43:43 -0500 |
---|---|---|
committer | Ben Gamari <ben@smart-cactus.org> | 2016-11-29 14:43:44 -0500 |
commit | c2268ba0eeb36a48da77ba95c72525c398c8b306 (patch) | |
tree | b0b550bc91132d81b10db5e904da3b76ee52fd9d | |
parent | 3ec856308cbfb89299daba56337eda866ac88d6e (diff) | |
download | haskell-c2268ba0eeb36a48da77ba95c72525c398c8b306.tar.gz |
Refactor Pattern Match Checker to use ListT
Reviewers: bgamari, austin
Reviewed By: bgamari
Subscribers: thomie
Differential Revision: https://phabricator.haskell.org/D2725
-rw-r--r-- | compiler/deSugar/Check.hs | 326 | ||||
-rw-r--r-- | compiler/ghc.cabal.in | 1 | ||||
-rw-r--r-- | compiler/utils/ListT.hs | 71 |
3 files changed, 287 insertions, 111 deletions
diff --git a/compiler/deSugar/Check.hs b/compiler/deSugar/Check.hs index b5f6eace89..04ba5681b0 100644 --- a/compiler/deSugar/Check.hs +++ b/compiler/deSugar/Check.hs @@ -50,6 +50,8 @@ import Coercion import TcEvidence import IOEnv +import ListT (ListT(..), fold) + {- This module checks pattern matches for: \begin{enumerate} @@ -72,7 +74,25 @@ The algorithm is based on the paper: %************************************************************************ -} -type PmM a = DsM a +-- We use the non-determinism monad to apply the algorithm to several +-- possible sets of constructors. Users can specify complete sets of +-- constructors by using COMPLETE pragmas. +-- The algorithm only picks out constructor +-- sets deep in the bowels which makes a simpler `mapM` more difficult to +-- implement. The non-determinism is only used in one place, see the ConVar +-- case in `pmCheckHd`. + +type PmM a = ListT DsM a + +liftD :: DsM a -> PmM a +liftD m = ListT $ \sk fk -> m >>= \a -> sk a fk + + +myRunListT :: PmM a -> DsM [a] +myRunListT pm = fold pm go (return []) + where + go a mas = + mas >>= \as -> return (a:as) data PatTy = PAT | VA -- Used only as a kind, to index PmPat @@ -122,14 +142,64 @@ type Uncovered = ValSetAbs -- C = True ==> Useful clause (no warning) -- C = False, D = True ==> Clause with inaccessible RHS -- C = False, D = False ==> Redundant clause -type Triple = (Bool, Uncovered, Bool) + +data Covered = Covered | NotCovered + deriving Show + +instance Outputable Covered where + ppr (Covered) = text "Covered" + ppr (NotCovered) = text "NotCovered" + +-- Like the or monoid for booleans +-- Covered = True, Uncovered = False +instance Monoid Covered where + mempty = NotCovered + Covered `mappend` _ = Covered + _ `mappend` Covered = Covered + NotCovered `mappend` NotCovered = NotCovered + +data Diverged = Diverged | NotDiverged + deriving Show + +instance Outputable Diverged where + ppr Diverged = text "Diverged" + ppr NotDiverged = text "NotDiverged" + +instance Monoid Diverged where + mempty = NotDiverged + Diverged `mappend` _ = Diverged + _ `mappend` Diverged = Diverged + NotDiverged `mappend` NotDiverged = NotDiverged + +data PartialResult = PartialResult { + presultCovered :: Covered + , presultUncovered :: Uncovered + , presultDivergent :: Diverged } + +instance Outputable PartialResult where + ppr (PartialResult c vsa d) = text "PartialResult" <+> ppr c + <+> ppr d <+> ppr vsa + +instance Monoid PartialResult where + mempty = PartialResult mempty [] mempty + (PartialResult cs1 vsa1 ds1) + `mappend` (PartialResult cs2 vsa2 ds2) + = PartialResult (cs1 `mappend` cs2) + (vsa1 `mappend` vsa2) + (ds1 `mappend` ds2) + +-- newtype ChoiceOf a = ChoiceOf [a] -- | Pattern check result -- -- * Redundant clauses -- * Not-covered clauses -- * Clauses with inaccessible RHS -type PmResult = ([Located [LPat Id]], Uncovered, [Located [LPat Id]]) +data PmResult = + PmResult { + pmresultRedundant :: [Located [LPat Id]] + , pmresultUncovered :: Uncovered + , pmresultInaccessible :: [Located [LPat Id]] } {- %************************************************************************ @@ -142,63 +212,67 @@ type PmResult = ([Located [LPat Id]], Uncovered, [Located [LPat Id]]) -- | Check a single pattern binding (let) checkSingle :: DynFlags -> DsMatchContext -> Id -> Pat Id -> DsM () checkSingle dflags ctxt@(DsMatchContext _ locn) var p = do - tracePm "checkSingle" (vcat [ppr ctxt, ppr var, ppr p]) - mb_pm_res <- tryM (checkSingle' locn var p) + tracePmD "checkSingle" (vcat [ppr ctxt, ppr var, ppr p]) + mb_pm_res <- tryM (head <$> myRunListT (checkSingle' locn var p)) case mb_pm_res of Left _ -> warnPmIters dflags ctxt Right res -> dsPmWarn dflags ctxt res -- | Check a single pattern binding (let) -checkSingle' :: SrcSpan -> Id -> Pat Id -> DsM PmResult +checkSingle' :: SrcSpan -> Id -> Pat Id -> PmM PmResult checkSingle' locn var p = do - resetPmIterDs -- set the iter-no to zero - fam_insts <- dsGetFamInstEnvs - clause <- translatePat fam_insts p + liftD resetPmIterDs -- set the iter-no to zero + fam_insts <- liftD dsGetFamInstEnvs + clause <- liftD $ translatePat fam_insts p missing <- mkInitialUncovered [var] tracePm "checkSingle: missing" (vcat (map pprValVecDebug missing)) - (cs,us,ds) <- runMany (pmcheckI clause []) missing -- no guards + PartialResult cs us ds <- runMany (pmcheckI clause []) missing -- no guards return $ case (cs,ds) of - (True, _ ) -> ([], us, []) -- useful - (False, False) -> ( m, us, []) -- redundant - (False, True ) -> ([], us, m) -- inaccessible rhs + (Covered, _ ) -> PmResult [] us [] -- useful + (NotCovered, NotDiverged) -> PmResult m us [] -- redundant + (NotCovered, Diverged ) -> PmResult [] us m -- inaccessible rhs where m = [L locn [L locn p]] -- | Check a matchgroup (case, functions, etc.) checkMatches :: DynFlags -> DsMatchContext -> [Id] -> [LMatch Id (LHsExpr Id)] -> DsM () checkMatches dflags ctxt vars matches = do - tracePm "checkMatches" (hang (vcat [ppr ctxt + tracePmD "checkMatches" (hang (vcat [ppr ctxt , ppr vars , text "Matches:"]) 2 (vcat (map ppr matches))) - mb_pm_res <- tryM (checkMatches' vars matches) + mb_pm_res <- tryM (head <$> myRunListT (checkMatches' vars matches)) case mb_pm_res of Left _ -> warnPmIters dflags ctxt Right res -> dsPmWarn dflags ctxt res -- | Check a matchgroup (case, functions, etc.) -checkMatches' :: [Id] -> [LMatch Id (LHsExpr Id)] -> DsM PmResult +checkMatches' :: [Id] -> [LMatch Id (LHsExpr Id)] -> PmM PmResult checkMatches' vars matches - | null matches = return ([], [], []) + | null matches = return $ PmResult [] [] [] | otherwise = do - resetPmIterDs -- set the iter-no to zero + liftD resetPmIterDs -- set the iter-no to zero missing <- mkInitialUncovered vars tracePm "checkMatches: missing" (vcat (map pprValVecDebug missing)) (rs,us,ds) <- go matches missing - return (map hsLMatchToLPats rs, us, map hsLMatchToLPats ds) + return $ PmResult (map hsLMatchToLPats rs) us (map hsLMatchToLPats ds) where + go :: [LMatch Id (LHsExpr Id)] -> Uncovered + -> PmM ([LMatch Id (LHsExpr Id)] , Uncovered , [LMatch Id (LHsExpr Id)]) go [] missing = return ([], missing, []) go (m:ms) missing = do tracePm "checMatches': go" (ppr m $$ ppr missing) - fam_insts <- dsGetFamInstEnvs - (clause, guards) <- translateMatch fam_insts m - (cs, missing', ds) <- runMany (pmcheckI clause guards) missing + fam_insts <- liftD dsGetFamInstEnvs + (clause, guards) <- liftD $ translateMatch fam_insts m + r@(PartialResult cs missing' ds) + <- runMany (pmcheckI clause guards) missing + tracePm "checMatches': go: res" (ppr r) (rs, final_u, is) <- go ms missing' return $ case (cs, ds) of - (True, _ ) -> ( rs, final_u, is) -- useful - (False, False) -> (m:rs, final_u, is) -- redundant - (False, True ) -> ( rs, final_u, m:is) -- inaccessible + (Covered, _ ) -> ( rs, final_u, is) -- useful + (NotCovered, NotDiverged) -> (m:rs, final_u, is) -- redundant + (NotCovered, Diverged ) -> ( rs, final_u, m:is) -- inaccessible hsLMatchToLPats :: LMatch id body -> Located [LPat id] hsLMatchToLPats (L l (Match _ pats _ _)) = L l pats @@ -239,7 +313,7 @@ isFakeGuard [PmCon { pm_con_con = c }] (PmExprOther EWildPat) isFakeGuard _pats _e = False -- | Generate a `canFail` pattern vector of a specific type -mkCanFailPmPat :: Type -> PmM PatVec +mkCanFailPmPat :: Type -> DsM PatVec mkCanFailPmPat ty = do var <- mkPmVar ty return [var, fake_pat] @@ -274,7 +348,7 @@ mkLitPattern lit = PmLit { pm_lit_lit = PmSLit lit } -- ----------------------------------------------------------------------- -- * Transform (Pat Id) into of (PmPat Id) -translatePat :: FamInstEnvs -> Pat Id -> PmM PatVec +translatePat :: FamInstEnvs -> Pat Id -> DsM PatVec translatePat fam_insts pat = case pat of WildPat ty -> mkPmVars [ty] VarPat id -> return [PmVar (unLoc id)] @@ -389,7 +463,7 @@ translatePat fam_insts pat = case pat of -- | Translate an overloaded literal (see `tidyNPat' in deSugar/MatchLit.hs) translateNPat :: FamInstEnvs - -> HsOverLit Id -> Maybe (SyntaxExpr Id) -> Type -> PmM PatVec + -> HsOverLit Id -> Maybe (SyntaxExpr Id) -> Type -> DsM PatVec translateNPat fam_insts (OverLit val False _ ty) mb_neg outer_ty | not type_change, isStringTy ty, HsIsString src s <- val, Nothing <- mb_neg = translatePat fam_insts (LitPat (HsString src s)) @@ -407,12 +481,12 @@ translateNPat _ ol mb_neg _ -- | Translate a list of patterns (Note: each pattern is translated -- to a pattern vector but we do not concatenate the results). -translatePatVec :: FamInstEnvs -> [Pat Id] -> PmM [PatVec] +translatePatVec :: FamInstEnvs -> [Pat Id] -> DsM [PatVec] translatePatVec fam_insts pats = mapM (translatePat fam_insts) pats -- | Translate a constructor pattern translateConPatVec :: FamInstEnvs -> [Type] -> [TyVar] - -> DataCon -> HsConPatDetails Id -> PmM PatVec + -> DataCon -> HsConPatDetails Id -> DsM PatVec translateConPatVec fam_insts _univ_tys _ex_tvs _ (PrefixCon ps) = concat <$> translatePatVec fam_insts (map unLoc ps) translateConPatVec fam_insts _univ_tys _ex_tvs _ (InfixCon p1 p2) @@ -467,7 +541,7 @@ translateConPatVec fam_insts univ_tys ex_tvs c (RecCon (HsRecFields fs _)) | otherwise = subsetOf (x:xs) ys -- Translate a single match -translateMatch :: FamInstEnvs -> LMatch Id (LHsExpr Id) -> PmM (PatVec,[PatVec]) +translateMatch :: FamInstEnvs -> LMatch Id (LHsExpr Id) -> DsM (PatVec,[PatVec]) translateMatch fam_insts (L _ (Match _ lpats _ grhss)) = do pats' <- concat <$> translatePatVec fam_insts pats guards' <- mapM (translateGuards fam_insts) guards @@ -483,7 +557,7 @@ translateMatch fam_insts (L _ (Match _ lpats _ grhss)) = do -- * Transform source guards (GuardStmt Id) to PmPats (Pattern) -- | Translate a list of guard statements to a pattern vector -translateGuards :: FamInstEnvs -> [GuardStmt Id] -> PmM PatVec +translateGuards :: FamInstEnvs -> [GuardStmt Id] -> DsM PatVec translateGuards fam_insts guards = do all_guards <- concat <$> mapM (translateGuard fam_insts) guards return (replace_unhandled all_guards) @@ -523,7 +597,7 @@ cantFailPattern (PmGrd pv _e) cantFailPattern _ = False -- | Translate a guard statement to Pattern -translateGuard :: FamInstEnvs -> GuardStmt Id -> PmM PatVec +translateGuard :: FamInstEnvs -> GuardStmt Id -> DsM PatVec translateGuard fam_insts guard = case guard of BodyStmt e _ _ _ -> translateBoolGuard e LetStmt binds -> translateLet (unLoc binds) @@ -535,17 +609,17 @@ translateGuard fam_insts guard = case guard of ApplicativeStmt {} -> panic "translateGuard ApplicativeLastStmt" -- | Translate let-bindings -translateLet :: HsLocalBinds Id -> PmM PatVec +translateLet :: HsLocalBinds Id -> DsM PatVec translateLet _binds = return [] -- | Translate a pattern guard -translateBind :: FamInstEnvs -> LPat Id -> LHsExpr Id -> PmM PatVec +translateBind :: FamInstEnvs -> LPat Id -> LHsExpr Id -> DsM PatVec translateBind fam_insts (L _ p) e = do ps <- translatePat fam_insts p return [mkGuard ps (unLoc e)] -- | Translate a boolean guard -translateBoolGuard :: LHsExpr Id -> PmM PatVec +translateBoolGuard :: LHsExpr Id -> DsM PatVec translateBoolGuard e | isJust (isTrueLHsExpr e) = return [] -- The formal thing to do would be to generate (True <- True) @@ -675,7 +749,7 @@ pmPatType (PmGrd { pm_grd_pv = pv }) -- | Generate a value abstraction for a given constructor (generate -- fresh variables of the appropriate type for arguments) -mkOneConFull :: Id -> DataCon -> PmM (ValAbs, ComplexEq, Bag EvVar) +mkOneConFull :: Id -> DataCon -> DsM (ValAbs, ComplexEq, Bag EvVar) -- * x :: T tys, where T is an algebraic data type -- NB: in the case of a data familiy, T is the *representation* TyCon -- e.g. data instance T (a,b) = T1 a b @@ -738,17 +812,17 @@ mkPosEq x l = (PmExprVar (idName x), PmExprLit l) {-# INLINE mkPosEq #-} -- | Generate a variable pattern of a given type -mkPmVar :: Type -> PmM (PmPat p) +mkPmVar :: Type -> DsM (PmPat p) mkPmVar ty = PmVar <$> mkPmId ty {-# INLINE mkPmVar #-} -- | Generate many variable patterns, given a list of types -mkPmVars :: [Type] -> PmM PatVec +mkPmVars :: [Type] -> DsM PatVec mkPmVars tys = mapM mkPmVar tys {-# INLINE mkPmVars #-} -- | Generate a fresh `Id` of a given type -mkPmId :: Type -> PmM Id +mkPmId :: Type -> DsM Id mkPmId ty = getUniqueM >>= \unique -> let occname = mkVarOccFS (fsLit (show unique)) name = mkInternalName unique occname noSrcSpan @@ -757,7 +831,7 @@ mkPmId ty = getUniqueM >>= \unique -> -- | Generate a fresh term variable of a given and return it in two forms: -- * A variable pattern -- * A variable expression -mkPmId2Forms :: Type -> PmM (Pattern, LHsExpr Id) +mkPmId2Forms :: Type -> DsM (Pattern, LHsExpr Id) mkPmId2Forms ty = do x <- mkPmId ty return (PmVar x, noLoc (HsVar (noLoc x))) @@ -802,7 +876,7 @@ allConstructors = tyConDataCons . dataConTyCon newEvVar :: Name -> Type -> EvVar newEvVar name ty = mkLocalId name (toTcType ty) -nameType :: String -> Type -> PmM EvVar +nameType :: String -> Type -> DsM EvVar nameType name ty = do unique <- getUniqueM let occname = mkVarOccFS (fsLit (name++"_"++show unique)) @@ -820,7 +894,8 @@ nameType name ty = do -- | Check whether a set of type constraints is satisfiable. tyOracle :: Bag EvVar -> PmM Bool tyOracle evs - = do { ((_warns, errs), res) <- initTcDsForSolver $ tcCheckSatisfiability evs + = liftD $ + do { ((_warns, errs), res) <- initTcDsForSolver $ tcCheckSatisfiability evs ; case res of Just sat -> return sat Nothing -> pprPanic "tyOracle" (vcat $ pprErrMsgBagWithLoc errs) } @@ -861,7 +936,7 @@ Main functions are: are checked, if they are inconsistent, the set is empty, otherwise, the set contains only a vector of variables with the constraints in scope. -* pmcheck :: PatVec -> [PatVec] -> ValVec -> PmM Triple +* pmcheck :: PatVec -> [PatVec] -> ValVec -> PmM PartialResult Checks redundancy, coverage and inaccessibility, using auxilary functions `pmcheckGuards` and `pmcheckHd`. Mainly handles the guard case which is @@ -869,12 +944,12 @@ Main functions are: whole clause is checked, or `pmcheckHd` when the pattern vector does not start with a guard. -* pmcheckGuards :: [PatVec] -> ValVec -> PmM Triple +* pmcheckGuards :: [PatVec] -> ValVec -> PmM PartialResult Processes the guards. * pmcheckHd :: Pattern -> PatVec -> [PatVec] - -> ValAbs -> ValVec -> PmM Triple + -> ValAbs -> ValVec -> PmM PartialResult Worker: This function implements functions `covered`, `uncovered` and `divergent` from the paper at once. Slightly different from the paper because @@ -886,17 +961,20 @@ Main functions are: -- | Lift a pattern matching action from a single value vector abstration to a -- value set abstraction, but calling it on every vector and the combining the -- results. -runMany :: (ValVec -> PmM Triple) -> (Uncovered -> PmM Triple) -runMany pm us = mapAndUnzip3M pm us >>= \(css, uss, dss) -> - return (or css, concat uss, or dss) +runMany :: (ValVec -> PmM PartialResult) -> (Uncovered -> PmM PartialResult) +runMany _ [] = return $ PartialResult mempty mempty mempty +runMany pm (m:ms) = do + (PartialResult c v d) <- pm m + (PartialResult cs vs ds) <- runMany pm ms + return (PartialResult (c `mappend` cs) (v `mappend` vs) (d `mappend` ds)) {-# INLINE runMany #-} -- | Generate the initial uncovered set. It initializes the -- delta with all term and type constraints in scope. mkInitialUncovered :: [Id] -> PmM Uncovered mkInitialUncovered vars = do - ty_cs <- getDictsDs - tm_cs <- map toComplex . bagToList <$> getTmCsDs + ty_cs <- liftD getDictsDs + tm_cs <- map toComplex . bagToList <$> liftD getTmCsDs sat_ty <- tyOracle ty_cs return $ case (sat_ty, tmOracle initialTmState tm_cs) of (True, Just tm_state) -> [ValVec patterns (MkDelta ty_cs tm_state)] @@ -908,41 +986,45 @@ mkInitialUncovered vars = do -- | Increase the counter for elapsed algorithm iterations, check that the -- limit is not exceeded and call `pmcheck` -pmcheckI :: PatVec -> [PatVec] -> ValVec -> PmM Triple +pmcheckI :: PatVec -> [PatVec] -> ValVec -> PmM PartialResult pmcheckI ps guards vva = do - n <- incrCheckPmIterDs + n <- liftD incrCheckPmIterDs tracePm "pmCheck" (ppr n <> colon <+> pprPatVec ps $$ hang (text "guards:") 2 (vcat (map pprPatVec guards)) $$ pprValVecDebug vva) - pmcheck ps guards vva + res <- pmcheck ps guards vva + tracePm "pmCheckResult:" (ppr res) + return res {-# INLINE pmcheckI #-} -- | Increase the counter for elapsed algorithm iterations, check that the -- limit is not exceeded and call `pmcheckGuards` -pmcheckGuardsI :: [PatVec] -> ValVec -> PmM Triple -pmcheckGuardsI gvs vva = incrCheckPmIterDs >> pmcheckGuards gvs vva +pmcheckGuardsI :: [PatVec] -> ValVec -> PmM PartialResult +pmcheckGuardsI gvs vva = liftD incrCheckPmIterDs >> pmcheckGuards gvs vva {-# INLINE pmcheckGuardsI #-} -- | Increase the counter for elapsed algorithm iterations, check that the -- limit is not exceeded and call `pmcheckHd` -pmcheckHdI :: Pattern -> PatVec -> [PatVec] -> ValAbs -> ValVec -> PmM Triple +pmcheckHdI :: Pattern -> PatVec -> [PatVec] -> ValAbs -> ValVec -> PmM PartialResult pmcheckHdI p ps guards va vva = do - n <- incrCheckPmIterDs + n <- liftD incrCheckPmIterDs tracePm "pmCheckHdI" (ppr n <> colon <+> pprPmPatDebug p $$ pprPatVec ps $$ hang (text "guards:") 2 (vcat (map pprPatVec guards)) $$ pprPmPatDebug va $$ pprValVecDebug vva) - pmcheckHd p ps guards va vva + res <- pmcheckHd p ps guards va vva + tracePm "pmCheckHdI: res" (ppr res) + return res {-# INLINE pmcheckHdI #-} -- | Matching function: Check simultaneously a clause (takes separately the -- patterns and the list of guards) for exhaustiveness, redundancy and -- inaccessibility. -pmcheck :: PatVec -> [PatVec] -> ValVec -> PmM Triple +pmcheck :: PatVec -> [PatVec] -> ValVec -> PmM PartialResult pmcheck [] guards vva@(ValVec [] _) - | null guards = return (True, [], False) + | null guards = return $ mempty { presultCovered = Covered } | otherwise = pmcheckGuardsI guards vva -- Guard @@ -953,7 +1035,7 @@ pmcheck (p@(PmGrd pv e) : ps) guards vva@(ValVec vas delta) -- though. So just have these two cases but do not do all the boilerplate | isFakeGuard pv e = forces . mkCons vva <$> pmcheckI ps guards vva | otherwise = do - y <- mkPmId (pmPatType p) + y <- liftD $ mkPmId (pmPatType p) let tm_state = extendSubst y e (delta_tm_cs delta) delta' = delta { delta_tm_cs = tm_state } utail <$> pmcheckI (pv ++ ps) guards (ValVec (PmVar y : vas) delta') @@ -965,41 +1047,44 @@ pmcheck (p:ps) guards (ValVec (va:vva) delta) = pmcheckHdI p ps guards va (ValVec vva delta) -- | Check the list of guards -pmcheckGuards :: [PatVec] -> ValVec -> PmM Triple -pmcheckGuards [] vva = return (False, [vva], False) +pmcheckGuards :: [PatVec] -> ValVec -> PmM PartialResult +pmcheckGuards [] vva = return (usimple [vva]) pmcheckGuards (gv:gvs) vva = do - (cs, vsa, ds ) <- pmcheckI gv [] vva - (css, vsas, dss) <- runMany (pmcheckGuardsI gvs) vsa - return (cs || css, vsas, ds || dss) + (PartialResult cs vsa ds) <- pmcheckI gv [] vva + (PartialResult css vsas dss) <- runMany (pmcheckGuardsI gvs) vsa + return $ PartialResult (cs `mappend` css) vsas (ds `mappend` dss) -- | Worker function: Implements all cases described in the paper for all three -- functions (`covered`, `uncovered` and `divergent`) apart from the `Guard` -- cases which are handled by `pmcheck` -pmcheckHd :: Pattern -> PatVec -> [PatVec] -> ValAbs -> ValVec -> PmM Triple +pmcheckHd :: Pattern -> PatVec -> [PatVec] -> ValAbs -> ValVec -> PmM PartialResult -- Var pmcheckHd (PmVar x) ps guards va (ValVec vva delta) | Just tm_state <- solveOneEq (delta_tm_cs delta) (PmExprVar (idName x), vaToPmExpr va) = ucon va <$> pmcheckI ps guards (ValVec vva (delta {delta_tm_cs = tm_state})) - | otherwise = return (False, [], False) + | otherwise = return mempty -- ConCon pmcheckHd ( p@(PmCon {pm_con_con = c1, pm_con_args = args1})) ps guards (va@(PmCon {pm_con_con = c2, pm_con_args = args2})) (ValVec vva delta) - | c1 /= c2 = return (False, [ValVec (va:vva) delta], False) + | c1 /= c2 = + return (usimple [ValVec (va:vva) delta]) | otherwise = kcon c1 (pm_con_arg_tys p) (pm_con_tvs p) (pm_con_dicts p) <$> pmcheckI (args1 ++ ps) guards (ValVec (args2 ++ vva) delta) -- LitLit -pmcheckHd (PmLit l1) ps guards (va@(PmLit l2)) vva = case eqPmLit l1 l2 of - True -> ucon va <$> pmcheckI ps guards vva - False -> return $ ucon va (False, [vva], False) +pmcheckHd (PmLit l1) ps guards (va@(PmLit l2)) vva = + case eqPmLit l1 l2 of + True -> ucon va <$> pmcheckI ps guards vva + False -> return $ ucon va (usimple [vva]) -- ConVar pmcheckHd (p@(PmCon { pm_con_con = con })) ps guards (PmVar x) (ValVec vva delta) = do - cons_cs <- mapM (mkOneConFull x) (allConstructors con) + cons_cs <- mapM (liftD . mkOneConFull x) (allConstructors con) + inst_vsa <- flip concatMapM cons_cs $ \(va, tm_ct, ty_cs) -> do let ty_state = ty_cs `unionBags` delta_ty_cs delta -- not actually a state sat_ty <- if isEmptyBag ty_cs then return True @@ -1018,13 +1103,13 @@ pmcheckHd (p@(PmLit l)) ps guards (PmVar x) (ValVec vva delta) case solveOneEq (delta_tm_cs delta) (mkPosEq x l) of Just tm_state -> pmcheckHdI p ps guards (PmLit l) $ ValVec vva (delta {delta_tm_cs = tm_state}) - Nothing -> return (False, [], False) + Nothing -> return mempty where us | Just tm_state <- solveOneEq (delta_tm_cs delta) (mkNegEq x l) = [ValVec (PmNLit x [l] : vva) (delta { delta_tm_cs = tm_state })] | otherwise = [] - non_matched = (False, us, False) + non_matched = usimple us -- LitNLit pmcheckHd (p@(PmLit l)) ps guards @@ -1044,7 +1129,7 @@ pmcheckHd (p@(PmLit l)) ps guards = [ValVec (PmNLit x (l:lits) : vva) (delta { delta_tm_cs = tm_state })] | otherwise = [] - non_matched = (False, us, False) + non_matched = usimple us -- ---------------------------------------------------------------------------- -- The following three can happen only in cases like #322 where constructors @@ -1055,14 +1140,14 @@ pmcheckHd (p@(PmLit l)) ps guards -- LitCon pmcheckHd (PmLit l) ps guards (va@(PmCon {})) (ValVec vva delta) - = do y <- mkPmId (pmPatType va) + = do y <- liftD $ mkPmId (pmPatType va) let tm_state = extendSubst y (PmExprLit l) (delta_tm_cs delta) delta' = delta { delta_tm_cs = tm_state } pmcheckHdI (PmVar y) ps guards va (ValVec vva delta') -- ConLit pmcheckHd (p@(PmCon {})) ps guards (PmLit l) (ValVec vva delta) - = do y <- mkPmId (pmPatType p) + = do y <- liftD $ mkPmId (pmPatType p) let tm_state = extendSubst y (PmExprLit l) (delta_tm_cs delta) delta' = delta { delta_tm_cs = tm_state } pmcheckHdI p ps guards (PmVar y) (ValVec vva delta') @@ -1077,54 +1162,66 @@ pmcheckHd (PmGrd {}) _ _ _ _ = panic "pmcheckHd: Guard" -- ---------------------------------------------------------------------------- -- * Utilities for main checking +updateVsa :: (ValSetAbs -> ValSetAbs) -> (PartialResult -> PartialResult) +updateVsa f p@(PartialResult { presultUncovered = old }) + = p { presultUncovered = f old } + + +-- | Initialise with default values for covering and divergent information. +usimple :: ValSetAbs -> PartialResult +usimple vsa = mempty { presultUncovered = vsa } + -- | Take the tail of all value vector abstractions in the uncovered set -utail :: Triple -> Triple -utail (cs, vsa, ds) = (cs, vsa', ds) - where vsa' = [ ValVec vva delta | ValVec (_:vva) delta <- vsa ] +utail :: PartialResult -> PartialResult +utail = updateVsa upd + where upd vsa = [ ValVec vva delta | ValVec (_:vva) delta <- vsa ] -- | Prepend a value abstraction to all value vector abstractions in the -- uncovered set -ucon :: ValAbs -> Triple -> Triple -ucon va (cs, vsa, ds) = (cs, vsa', ds) - where vsa' = [ ValVec (va:vva) delta | ValVec vva delta <- vsa ] +ucon :: ValAbs -> PartialResult -> PartialResult +ucon va = updateVsa upd + where + upd vsa = [ ValVec (va:vva) delta | ValVec vva delta <- vsa ] -- | Given a data constructor of arity `a` and an uncovered set containing -- value vector abstractions of length `(a+n)`, pass the first `n` value -- abstractions to the constructor (Hence, the resulting value vector -- abstractions will have length `n+1`) -kcon :: DataCon -> [Type] -> [TyVar] -> [EvVar] -> Triple -> Triple -kcon con arg_tys ex_tvs dicts (cs, vsa, ds) - = (cs, [ ValVec (va:vva) delta - | ValVec vva' delta <- vsa - , let (args, vva) = splitAt n vva' - , let va = PmCon { pm_con_con = con - , pm_con_arg_tys = arg_tys - , pm_con_tvs = ex_tvs - , pm_con_dicts = dicts - , pm_con_args = args } ] - , ds) - where n = dataConSourceArity con +kcon :: DataCon -> [Type] -> [TyVar] -> [EvVar] + -> PartialResult -> PartialResult +kcon con arg_tys ex_tvs dicts + = let n = dataConSourceArity con + upd vsa = + [ ValVec (va:vva) delta + | ValVec vva' delta <- vsa + , let (args, vva) = splitAt n vva' + , let va = PmCon { pm_con_con = con + , pm_con_arg_tys = arg_tys + , pm_con_tvs = ex_tvs + , pm_con_dicts = dicts + , pm_con_args = args } ] + in updateVsa upd -- | Get the union of two covered, uncovered and divergent value set -- abstractions. Since the covered and divergent sets are represented by a -- boolean, union means computing the logical or (at least one of the two is -- non-empty). -mkUnion :: Triple -> Triple -> Triple -mkUnion (cs1, vsa1, ds1) (cs2, vsa2, ds2) - = (cs1 || cs2, vsa1 ++ vsa2, ds1 || ds2) + +mkUnion :: PartialResult -> PartialResult -> PartialResult +mkUnion = mappend -- | Add a value vector abstraction to a value set abstraction (uncovered). -mkCons :: ValVec -> Triple -> Triple -mkCons vva (cs, vsa, ds) = (cs, vva:vsa, ds) +mkCons :: ValVec -> PartialResult -> PartialResult +mkCons vva = updateVsa (vva:) -- | Set the divergent set to not empty -forces :: Triple -> Triple -forces (cs, us, _) = (cs, us, True) +forces :: PartialResult -> PartialResult +forces pres = pres { presultDivergent = Diverged } -- | Set the divergent set to non-empty if the flag is `True` -force_if :: Bool -> Triple -> Triple -force_if True (cs,us,_) = (cs,us,True) -force_if False triple = triple +force_if :: Bool -> PartialResult -> PartialResult +force_if True pres = forces pres +force_if False pres = pres -- ---------------------------------------------------------------------------- -- * Propagation of term constraints inwards when checking nested matches @@ -1133,7 +1230,7 @@ force_if False triple = triple ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ When checking a match it would be great to have all type and term information available so we can get more precise results. For this reason we have functions -`addDictsDs' and `addTmCsDs' in DsMonad that store in the environment type and +`addDictsDs' and `addTmCsDs' in PmMonad that store in the environment type and term constraints (respectively) as we go deeper. The type constraints we propagate inwards are collected by `collectEvVarsPats' @@ -1275,7 +1372,10 @@ dsPmWarn dflags ctx@(DsMatchContext kind loc) pm_result when exists_u $ putSrcSpanDs loc (warnDs flag_u_reason (pprEqns uncovered)) where - (redundant, uncovered, inaccessible) = pm_result + PmResult + { pmresultRedundant = redundant + , pmresultUncovered = uncovered + , pmresultInaccessible = inaccessible } = pm_result flag_i = wopt Opt_WarnOverlappingPatterns dflags flag_u = exhaustive dflags kind @@ -1298,7 +1398,7 @@ dsPmWarn dflags ctx@(DsMatchContext kind loc) pm_result -- | Issue a warning when the predefined number of iterations is exceeded -- for the pattern match checker -warnPmIters :: DynFlags -> DsMatchContext -> PmM () +warnPmIters :: DynFlags -> DsMatchContext -> DsM () warnPmIters dflags (DsMatchContext kind loc) = when (flag_i || flag_u) $ do iters <- maxPmCheckIterations <$> getDynFlags @@ -1441,7 +1541,11 @@ involved. -- Debugging Infrastructre tracePm :: String -> SDoc -> PmM () -tracePm herald doc = do +tracePm herald doc = liftD $ tracePmD herald doc + + +tracePmD :: String -> SDoc -> DsM () +tracePmD herald doc = do dflags <- getDynFlags printer <- mkPrintUnqualifiedDs liftIO $ dumpIfSet_dyn_printer printer dflags diff --git a/compiler/ghc.cabal.in b/compiler/ghc.cabal.in index 9538e2cb0b..c5ca3132cd 100644 --- a/compiler/ghc.cabal.in +++ b/compiler/ghc.cabal.in @@ -490,6 +490,7 @@ Library GraphPpr IOEnv ListSetOps + ListT Maybes MonadUtils OrdList diff --git a/compiler/utils/ListT.hs b/compiler/utils/ListT.hs new file mode 100644 index 0000000000..2b81db1ed4 --- /dev/null +++ b/compiler/utils/ListT.hs @@ -0,0 +1,71 @@ +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE Rank2Types #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} + +------------------------------------------------------------------------- +-- | +-- Module : Control.Monad.Logic +-- Copyright : (c) Dan Doel +-- License : BSD3 +-- +-- Maintainer : dan.doel@gmail.com +-- Stability : experimental +-- Portability : non-portable (multi-parameter type classes) +-- +-- A backtracking, logic programming monad. +-- +-- Adapted from the paper +-- /Backtracking, Interleaving, and Terminating +-- Monad Transformers/, by +-- Oleg Kiselyov, Chung-chieh Shan, Daniel P. Friedman, Amr Sabry +-- (<http://www.cs.rutgers.edu/~ccshan/logicprog/ListT-icfp2005.pdf>). +------------------------------------------------------------------------- + +module ListT ( + ListT(..), + runListT, + select, + fold + ) where + +import Control.Applicative + +import Control.Monad + +------------------------------------------------------------------------- +-- | A monad transformer for performing backtracking computations +-- layered over another monad 'm' +newtype ListT m a = + ListT { unListT :: forall r. (a -> m r -> m r) -> m r -> m r } + +select :: Monad m => [a] -> ListT m a +select xs = foldr (<|>) mzero (map pure xs) + +fold :: ListT m a -> (a -> m r -> m r) -> m r -> m r +fold = runListT + +------------------------------------------------------------------------- +-- | Runs a ListT computation with the specified initial success and +-- failure continuations. +runListT :: ListT m a -> (a -> m r -> m r) -> m r -> m r +runListT = unListT + +instance Functor (ListT f) where + fmap f lt = ListT $ \sk fk -> unListT lt (sk . f) fk + +instance Applicative (ListT f) where + pure a = ListT $ \sk fk -> sk a fk + f <*> a = ListT $ \sk fk -> unListT f (\g fk' -> unListT a (sk . g) fk') fk + +instance Alternative (ListT f) where + empty = ListT $ \_ fk -> fk + f1 <|> f2 = ListT $ \sk fk -> unListT f1 sk (unListT f2 sk fk) + +instance Monad (ListT m) where + m >>= f = ListT $ \sk fk -> unListT m (\a fk' -> unListT (f a) sk fk') fk + fail _ = ListT $ \_ fk -> fk + +instance MonadPlus (ListT m) where + mzero = ListT $ \_ fk -> fk + m1 `mplus` m2 = ListT $ \sk fk -> unListT m1 sk (unListT m2 sk fk) |