diff options
author | Sebastian Graf <sebastian.graf@kit.edu> | 2019-05-03 10:34:44 +0200 |
---|---|---|
committer | Sebastian Graf <sebastian.graf@kit.edu> | 2019-05-03 10:34:44 +0200 |
commit | 24c20cd71fbdd0de588a7ea0e06cbc520fd3a97c (patch) | |
tree | 1d7f6a34d7ead9528f1bd0a976d195abc402791e | |
parent | 37a4fd9715de4dad8033ea74483432c77818abf5 (diff) | |
download | haskell-wip/gvn-pmcheck.tar.gz |
Pattern match complex expressions by GVNwip/gvn-pmcheck
By referential transparency, multiple syntactic occurrences of the same
expression evaluate to the same value. Global value numbering (GVN)
assigns each such expression the same unique number (a `Name` in our
case). Two expressions trivially have the same value if they are
assigned the same value number.
The term oracle `TmOracle` of the pattern match checker couldn't handle
any complex expression before this patch. It would just give up on
anything involving a function application whose head was not a
constructor, by falling back to `PmExprOther`. This means it could not
determine completeness of the following example:
```haskell
foo
| True <- id True
= 1
| False <- id True
= 2
```
This is simply because `TmOracle` couldn't figure out that `id True`
always evaluates to the same `Bool`.
In this patch, we desugar such `PmExprOther`s in pattern guards to
`CoreExpr`. We do so in order to utilise `CoreMap Name` for a
light-weight GVN pass without concern for subexpressions. `TmOracle`
only sees the representing variables, like so:
```haskell
x = id True
foo
| True <- x
= 1
| False <- x
= 2
```
So `TmOracle` still doesn't need to decide equality of complex
expressions, which allows it to stay dead simple.
-rw-r--r-- | compiler/deSugar/Check.hs | 122 |
1 files changed, 79 insertions, 43 deletions
diff --git a/compiler/deSugar/Check.hs b/compiler/deSugar/Check.hs index e87eb39d26..3ec04effe4 100644 --- a/compiler/deSugar/Check.hs +++ b/compiler/deSugar/Check.hs @@ -32,6 +32,7 @@ import TcHsSyn import Id import ConLike import Name +import NameEnv import FamInstEnv import TysPrim (tYPETyCon) import TysWiredIn @@ -42,9 +43,13 @@ import Outputable import FastString import DataCon import PatSyn -import HscTypes (CompleteMatch(..)) +import HscTypes (CompleteMatch(..)) +import CoreMap (CoreMap, emptyCoreMap, lookupCoreMap, extendCoreMap) +import CoreOpt (simpleOptExpr) +import CoreUtils (exprType) import DsMonad +import {-# SOURCE #-} DsExpr (dsExpr) import TcSimplify (tcCheckSatisfiability) import TcType (isStringTy) import Bag @@ -60,13 +65,15 @@ import qualified GHC.LanguageExtensions as LangExt import Data.List (find) import Data.Maybe (catMaybes, isJust, fromMaybe) import Control.Monad (forM, when, forM_, zipWithM, filterM) +import Control.Monad.Trans.State.Strict (StateT (..), evalStateT) +import Control.Monad.Trans.Class import Coercion import TcEvidence import TcSimplify (tcNormalise) import IOEnv import qualified Data.Semigroup as Semi -import ListT (ListT(..), fold, select) +import ListT (ListT(..), fold, select) {- This module checks pattern matches for: @@ -140,6 +147,34 @@ getResult ls go (Just (PmResult _ _ (TypeOfUncovered _) _)) _new = panic "getResult: No inhabitation candidates" +data TranslateEnv + = TE { te_rep_env :: !(CoreMap Id) + -- ^ Representatives for PmExprOther as Core expressions + , te_orig_exprs :: NameEnv (HsExpr GhcTc) + -- ^ Maps representatives to their represented expression + } + +initialTE :: TranslateEnv +initialTE = TE emptyCoreMap emptyNameEnv + +-- | Monad in which we translate pattern matches +type TlM a = StateT TranslateEnv DsM a + +representPmExprOther :: PmExpr -> TlM PmExpr +representPmExprOther (PmExprOther e) = do + dflags <- lift getDynFlags + core_expr <- simpleOptExpr dflags <$> lift (dsExpr e) + StateT $ \env@TE{te_rep_env = cm, te_orig_exprs = origs } -> do + (name, env') <- + case lookupCoreMap cm core_expr of + Just y -> pure (idName y, env) + Nothing -> do + y <- mkPmId (exprType core_expr) + pure (idName y, env { te_rep_env = extendCoreMap cm core_expr y }) + tracePmD "representPmExprOther" (ppr name <+> text "->" <+> ppr (e, core_expr)) + pure (PmExprVar name, env' { te_orig_exprs = extendNameEnv origs name e }) +representPmExprOther e = pure e + data PatTy = PAT | VA -- Used only as a kind, to index PmPat -- The *arity* of a PatVec [p1,..,pn] is @@ -350,9 +385,9 @@ checkSingle dflags ctxt@(DsMatchContext _ locn) var p = do checkSingle' :: SrcSpan -> Id -> Pat GhcTc -> PmM PmResult checkSingle' locn var p = do liftD resetPmIterDs -- set the iter-no to zero - fam_insts <- liftD dsGetFamInstEnvs - clause <- liftD $ translatePat fam_insts p - missing <- mkInitialUncovered [var] + fam_insts <- liftD dsGetFamInstEnvs + (clause, te) <- liftD $ runStateT (translatePat fam_insts p) initialTE + missing <- mkInitialUncovered [var] tracePm "checkSingle': missing" (vcat (map pprValVecDebug missing)) -- no guards PartialResult prov cs us ds <- runMany (pmcheckI clause []) missing @@ -422,8 +457,8 @@ checkMatches' vars matches go [] missing = return (mempty, [], missing, []) go (m:ms) missing = do tracePm "checkMatches': go" (ppr m $$ ppr missing) - fam_insts <- liftD dsGetFamInstEnvs - (clause, guards) <- liftD $ translateMatch fam_insts m + fam_insts <- liftD dsGetFamInstEnvs + ((clause, guards), te) <- liftD $ runStateT (translateMatch fam_insts m) initialTE r@(PartialResult prov cs missing' ds) <- runMany (pmcheckI clause guards) missing tracePm "checkMatches': go: res" (ppr r) @@ -966,12 +1001,12 @@ mkLitPattern lit = PmLit { pm_lit_lit = PmSLit lit } -- ----------------------------------------------------------------------- -- * Transform (Pat Id) into of (PmPat Id) -translatePat :: FamInstEnvs -> Pat GhcTc -> DsM PatVec +translatePat :: FamInstEnvs -> Pat GhcTc -> TlM PatVec translatePat fam_insts pat = case pat of - WildPat ty -> mkPmVars [ty] + WildPat ty -> lift $ mkPmVars [ty] VarPat _ id -> return [PmVar (unLoc id)] ParPat _ p -> translatePat fam_insts (unLoc p) - LazyPat _ _ -> mkPmVars [hsPatType pat] -- like a variable + LazyPat _ _ -> lift $ mkPmVars [hsPatType pat] -- like a variable -- ignore strictness annotations for now BangPat _ p -> translatePat fam_insts (unLoc p) @@ -991,24 +1026,24 @@ translatePat fam_insts pat = case pat of | WpCast co <- wrapper, isReflexiveCo co -> translatePat fam_insts p | otherwise -> do ps <- translatePat fam_insts p - (xp,xe) <- mkPmId2Forms ty + (xp,xe) <- lift $ mkPmId2Forms ty g <- mkGuard ps (mkHsWrap wrapper (unLoc xe)) return [xp,g] -- (n + k) ===> x (True <- x >= k) (n <- x-k) - NPlusKPat ty (dL->L _ _n) _k1 _k2 _ge _minus -> mkCanFailPmPat ty + NPlusKPat ty (dL->L _ _n) _k1 _k2 _ge _minus -> lift $ mkCanFailPmPat ty -- (fun -> pat) ===> x (pat <- fun x) ViewPat arg_ty lexpr lpat -> do ps <- translatePat fam_insts (unLoc lpat) -- See Note [Guards and Approximation] - res <- allM cantFailPattern ps + res <- lift $ allM cantFailPattern ps case res of True -> do - (xp,xe) <- mkPmId2Forms arg_ty + (xp,xe) <- lift $ mkPmId2Forms arg_ty g <- mkGuard ps (HsApp noExt lexpr xe) return [xp,g] - False -> mkCanFailPmPat arg_ty + False -> lift $ mkCanFailPmPat arg_ty -- list ListPat (ListPatTc ty Nothing) ps -> do @@ -1017,13 +1052,13 @@ translatePat fam_insts pat = case pat of -- overloaded list ListPat (ListPatTc _elem_ty (Just (pat_ty, _to_list))) lpats -> do - dflags <- getDynFlags + dflags <- lift $ getDynFlags if xopt LangExt.RebindableSyntax dflags - then mkCanFailPmPat pat_ty + then lift $ mkCanFailPmPat pat_ty else case splitListTyConApp_maybe pat_ty of Just e_ty -> translatePat fam_insts (ListPat (ListPatTc e_ty Nothing) lpats) - Nothing -> mkCanFailPmPat pat_ty + Nothing -> lift $ mkCanFailPmPat pat_ty -- (a) In the presence of RebindableSyntax, we don't know anything about -- `toList`, we should treat `ListPat` as any other view pattern. -- @@ -1047,9 +1082,9 @@ translatePat fam_insts pat = case pat of , pat_tvs = ex_tvs , pat_dicts = dicts , pat_args = ps } -> do - groups <- allCompleteMatches con arg_tys + groups <- lift $ allCompleteMatches con arg_tys case groups of - [] -> mkCanFailPmPat (conLikeResTy con arg_tys) + [] -> lift $ mkCanFailPmPat (conLikeResTy con arg_tys) _ -> do args <- translateConPatVec fam_insts arg_tys ex_tvs con ps return [PmCon { pm_con_con = con @@ -1178,23 +1213,23 @@ from translation in pattern matcher. -- | Translate a list of patterns (Note: each pattern is translated -- to a pattern vector but we do not concatenate the results). -translatePatVec :: FamInstEnvs -> [Pat GhcTc] -> DsM [PatVec] +translatePatVec :: FamInstEnvs -> [Pat GhcTc] -> TlM [PatVec] translatePatVec fam_insts pats = mapM (translatePat fam_insts) pats -- | Translate a constructor pattern translateConPatVec :: FamInstEnvs -> [Type] -> [TyVar] - -> ConLike -> HsConPatDetails GhcTc -> DsM PatVec + -> ConLike -> HsConPatDetails GhcTc -> TlM PatVec translateConPatVec fam_insts _univ_tys _ex_tvs _ (PrefixCon ps) = concat <$> translatePatVec fam_insts (map unLoc ps) translateConPatVec fam_insts _univ_tys _ex_tvs _ (InfixCon p1 p2) = concat <$> translatePatVec fam_insts (map unLoc [p1,p2]) translateConPatVec fam_insts univ_tys ex_tvs c (RecCon (HsRecFields fs _)) -- Nothing matched. Make up some fresh term variables - | null fs = mkPmVars arg_tys + | null fs = lift $ mkPmVars arg_tys -- The data constructor was not defined using record syntax. For the -- pattern to be in record syntax it should be empty (e.g. Just {}). -- So just like the previous case. - | null orig_lbls = ASSERT(null matched_lbls) mkPmVars arg_tys + | null orig_lbls = ASSERT(null matched_lbls) lift $ mkPmVars arg_tys -- Some of the fields appear, in the original order (there may be holes). -- Generate a simple constructor pattern and make up fresh variables for -- the rest of the fields @@ -1202,13 +1237,13 @@ translateConPatVec fam_insts univ_tys ex_tvs c (RecCon (HsRecFields fs _)) = ASSERT(orig_lbls `equalLength` arg_tys) let translateOne (lbl, ty) = case lookup lbl matched_pats of Just p -> translatePat fam_insts p - Nothing -> mkPmVars [ty] + Nothing -> lift $ mkPmVars [ty] in concatMapM translateOne (zip orig_lbls arg_tys) -- The fields that appear are not in the correct order. Make up fresh -- variables for all fields and add guards after matching, to force the -- evaluation in the correct order. | otherwise = do - arg_var_pats <- mkPmVars arg_tys + arg_var_pats <- lift $ mkPmVars arg_tys translated_pats <- forM matched_pats $ \(x,pat) -> do pvec <- translatePat fam_insts pat return (x, pvec) @@ -1239,7 +1274,7 @@ translateConPatVec fam_insts univ_tys ex_tvs c (RecCon (HsRecFields fs _)) -- Translate a single match translateMatch :: FamInstEnvs -> LMatch GhcTc (LHsExpr GhcTc) - -> DsM (PatVec,[PatVec]) + -> TlM (PatVec,[PatVec]) translateMatch fam_insts (dL->L _ (Match { m_pats = lpats, m_grhss = grhss })) = do pats' <- concat <$> translatePatVec fam_insts pats @@ -1258,7 +1293,7 @@ translateMatch _ _ = panic "translateMatch" -- * Transform source guards (GuardStmt Id) to PmPats (Pattern) -- | Translate a list of guard statements to a pattern vector -translateGuards :: FamInstEnvs -> [GuardStmt GhcTc] -> DsM PatVec +translateGuards :: FamInstEnvs -> [GuardStmt GhcTc] -> TlM PatVec translateGuards fam_insts guards = do all_guards <- concat <$> mapM (translateGuard fam_insts) guards let @@ -1273,7 +1308,7 @@ translateGuards fam_insts guards = do | otherwise = allM shouldKeep pv shouldKeep _other_pat = pure False -- let the rest.. - all_handled <- allM shouldKeep all_guards + all_handled <- lift $ allM shouldKeep all_guards -- It should have been @pure all_guards@ but it is too expressive. -- Since the term oracle does not handle all constraints we generate, -- we (hackily) replace all constraints the oracle cannot handle with a @@ -1283,7 +1318,7 @@ translateGuards fam_insts guards = do if all_handled then pure all_guards else do - kept <- filterM shouldKeep all_guards + kept <- lift $ filterM shouldKeep all_guards pure (PmFake : kept) -- | Check whether a pattern can fail to match @@ -1295,7 +1330,7 @@ cantFailPattern (PmGrd pv _e) = allM cantFailPattern pv cantFailPattern _ = pure False -- | Translate a guard statement to Pattern -translateGuard :: FamInstEnvs -> GuardStmt GhcTc -> DsM PatVec +translateGuard :: FamInstEnvs -> GuardStmt GhcTc -> TlM PatVec translateGuard fam_insts guard = case guard of BodyStmt _ e _ _ -> translateBoolGuard e LetStmt _ binds -> translateLet (unLoc binds) @@ -1308,18 +1343,18 @@ translateGuard fam_insts guard = case guard of XStmtLR {} -> panic "translateGuard RecStmt" -- | Translate let-bindings -translateLet :: HsLocalBinds GhcTc -> DsM PatVec +translateLet :: HsLocalBinds GhcTc -> TlM PatVec translateLet _binds = return [] -- | Translate a pattern guard -translateBind :: FamInstEnvs -> LPat GhcTc -> LHsExpr GhcTc -> DsM PatVec +translateBind :: FamInstEnvs -> LPat GhcTc -> LHsExpr GhcTc -> TlM PatVec translateBind fam_insts (dL->L _ p) e = do ps <- translatePat fam_insts p g <- mkGuard ps (unLoc e) return [g] -- | Translate a boolean guard -translateBoolGuard :: LHsExpr GhcTc -> DsM PatVec +translateBoolGuard :: LHsExpr GhcTc -> TlM PatVec translateBoolGuard e | isJust (isTrueLHsExpr e) = return [] -- The formal thing to do would be to generate (True <- True) @@ -1663,14 +1698,13 @@ mkOneConFull x con = do -- * More smart constructors and fresh variable generation -- | Create a guard pattern -mkGuard :: PatVec -> HsExpr GhcTc -> DsM Pattern +mkGuard :: PatVec -> HsExpr GhcTc -> TlM Pattern mkGuard pv e = do - res <- allM cantFailPattern pv - let expr = hsExprToPmExpr e - tracePmD "mkGuard" (vcat [ppr pv, ppr e, ppr res, ppr expr]) - if | res -> pure (PmGrd pv expr) - | PmExprOther {} <- expr -> pure PmFake - | otherwise -> pure (PmGrd pv expr) + res <- lift $ allM cantFailPattern pv + let expr = hsExprToPmExpr e + expr' <- representPmExprOther expr + traceTl "mkGuard" (vcat [ppr pv, ppr e, ppr res, ppr expr, ppr expr']) + pure (PmGrd pv expr') -- | Create a term equality of the form: `(False ~ (x ~ lit))` mkNegEq :: Id -> PmLit -> ComplexEq @@ -2403,8 +2437,8 @@ genCaseTmCs2 :: Maybe (LHsExpr GhcTc) -- Scrutinee -> [Id] -- MatchVars (should have length 1) -> DsM (Bag SimpleEq) genCaseTmCs2 Nothing _ _ = return emptyBag -genCaseTmCs2 (Just scr) [p] [var] = do - fam_insts <- dsGetFamInstEnvs +genCaseTmCs2 (Just scr) [p] [var] = flip evalStateT initialTE $ do + fam_insts <- lift $ dsGetFamInstEnvs [e] <- map vaToPmExpr . coercePatVec <$> translatePat fam_insts p let scr_e = lhsExprToPmExpr scr return $ listToBag [(var, e), (var, scr_e)] @@ -2719,6 +2753,8 @@ involved. tracePm :: String -> SDoc -> PmM () tracePm herald doc = liftD $ tracePmD herald doc +traceTl :: String -> SDoc -> TlM () +traceTl herald doc = lift $ tracePmD herald doc tracePmD :: String -> SDoc -> DsM () tracePmD herald doc = do |