summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Sebastian Graf <sebastian.graf@kit.edu> 2019-05-03 10:34:44 +0200
committer: Sebastian Graf <sebastian.graf@kit.edu> 2019-05-03 10:34:44 +0200
commit: 24c20cd71fbdd0de588a7ea0e06cbc520fd3a97c (patch)
tree: 1d7f6a34d7ead9528f1bd0a976d195abc402791e
parent: 37a4fd9715de4dad8033ea74483432c77818abf5 (diff)
download: haskell-wip/gvn-pmcheck.tar.gz
Pattern match complex expressions by GVN (branch: wip/gvn-pmcheck)
By referential transparency, multiple syntactic occurrences of the same expression evaluate to the same value. Global value numbering (GVN) assigns each such expression the same unique number (a `Name` in our case). Two expressions trivially have the same value if they are assigned the same value number.

The term oracle `TmOracle` of the pattern match checker couldn't handle any complex expression before this patch. It would just give up on anything involving a function application whose head was not a constructor, by falling back to `PmExprOther`. This means it could not determine completeness of the following example:

```haskell
foo
  | True <- id True = 1
  | False <- id True = 2
```

This is simply because `TmOracle` couldn't figure out that `id True` always evaluates to the same `Bool`.

In this patch, we desugar such `PmExprOther`s in pattern guards to `CoreExpr`. We do so in order to utilise a `CoreMap Id` (whose `Id`s supply the representing `Name`s) for a light-weight GVN pass without concern for subexpressions. `TmOracle` only sees the representing variables, like so:

```haskell
x = id True
foo
  | True <- x = 1
  | False <- x = 2
```

So `TmOracle` still doesn't need to decide equality of complex expressions, which allows it to stay dead simple.
-rw-r--r-- compiler/deSugar/Check.hs 122
1 files changed, 79 insertions, 43 deletions
diff --git a/compiler/deSugar/Check.hs b/compiler/deSugar/Check.hs
index e87eb39d26..3ec04effe4 100644
--- a/compiler/deSugar/Check.hs
+++ b/compiler/deSugar/Check.hs
@@ -32,6 +32,7 @@ import TcHsSyn
import Id
import ConLike
import Name
+import NameEnv
import FamInstEnv
import TysPrim (tYPETyCon)
import TysWiredIn
@@ -42,9 +43,13 @@ import Outputable
import FastString
import DataCon
import PatSyn
-import HscTypes (CompleteMatch(..))
+import HscTypes (CompleteMatch(..))
+import CoreMap (CoreMap, emptyCoreMap, lookupCoreMap, extendCoreMap)
+import CoreOpt (simpleOptExpr)
+import CoreUtils (exprType)
import DsMonad
+import {-# SOURCE #-} DsExpr (dsExpr)
import TcSimplify (tcCheckSatisfiability)
import TcType (isStringTy)
import Bag
@@ -60,13 +65,15 @@ import qualified GHC.LanguageExtensions as LangExt
import Data.List (find)
import Data.Maybe (catMaybes, isJust, fromMaybe)
import Control.Monad (forM, when, forM_, zipWithM, filterM)
+import Control.Monad.Trans.State.Strict (StateT (..), evalStateT)
+import Control.Monad.Trans.Class
import Coercion
import TcEvidence
import TcSimplify (tcNormalise)
import IOEnv
import qualified Data.Semigroup as Semi
-import ListT (ListT(..), fold, select)
+import ListT (ListT(..), fold, select)
{-
This module checks pattern matches for:
@@ -140,6 +147,34 @@ getResult ls
go (Just (PmResult _ _ (TypeOfUncovered _) _)) _new
= panic "getResult: No inhabitation candidates"
+data TranslateEnv
+ = TE { te_rep_env :: !(CoreMap Id)
+ -- ^ Representatives for PmExprOther as Core expressions
+ , te_orig_exprs :: NameEnv (HsExpr GhcTc)
+ -- ^ Maps representatives to their represented expression
+ }
+
+initialTE :: TranslateEnv
+initialTE = TE emptyCoreMap emptyNameEnv
+
+-- | Monad in which we translate pattern matches
+type TlM a = StateT TranslateEnv DsM a
+
+representPmExprOther :: PmExpr -> TlM PmExpr
+representPmExprOther (PmExprOther e) = do
+ dflags <- lift getDynFlags
+ core_expr <- simpleOptExpr dflags <$> lift (dsExpr e)
+ StateT $ \env@TE{te_rep_env = cm, te_orig_exprs = origs } -> do
+ (name, env') <-
+ case lookupCoreMap cm core_expr of
+ Just y -> pure (idName y, env)
+ Nothing -> do
+ y <- mkPmId (exprType core_expr)
+ pure (idName y, env { te_rep_env = extendCoreMap cm core_expr y })
+ tracePmD "representPmExprOther" (ppr name <+> text "->" <+> ppr (e, core_expr))
+ pure (PmExprVar name, env' { te_orig_exprs = extendNameEnv origs name e })
+representPmExprOther e = pure e
+
data PatTy = PAT | VA -- Used only as a kind, to index PmPat
-- The *arity* of a PatVec [p1,..,pn] is
@@ -350,9 +385,9 @@ checkSingle dflags ctxt@(DsMatchContext _ locn) var p = do
checkSingle' :: SrcSpan -> Id -> Pat GhcTc -> PmM PmResult
checkSingle' locn var p = do
liftD resetPmIterDs -- set the iter-no to zero
- fam_insts <- liftD dsGetFamInstEnvs
- clause <- liftD $ translatePat fam_insts p
- missing <- mkInitialUncovered [var]
+ fam_insts <- liftD dsGetFamInstEnvs
+ (clause, te) <- liftD $ runStateT (translatePat fam_insts p) initialTE
+ missing <- mkInitialUncovered [var]
tracePm "checkSingle': missing" (vcat (map pprValVecDebug missing))
-- no guards
PartialResult prov cs us ds <- runMany (pmcheckI clause []) missing
@@ -422,8 +457,8 @@ checkMatches' vars matches
go [] missing = return (mempty, [], missing, [])
go (m:ms) missing = do
tracePm "checkMatches': go" (ppr m $$ ppr missing)
- fam_insts <- liftD dsGetFamInstEnvs
- (clause, guards) <- liftD $ translateMatch fam_insts m
+ fam_insts <- liftD dsGetFamInstEnvs
+ ((clause, guards), te) <- liftD $ runStateT (translateMatch fam_insts m) initialTE
r@(PartialResult prov cs missing' ds)
<- runMany (pmcheckI clause guards) missing
tracePm "checkMatches': go: res" (ppr r)
@@ -966,12 +1001,12 @@ mkLitPattern lit = PmLit { pm_lit_lit = PmSLit lit }
-- -----------------------------------------------------------------------
-- * Transform (Pat Id) into (PmPat Id)
-translatePat :: FamInstEnvs -> Pat GhcTc -> DsM PatVec
+translatePat :: FamInstEnvs -> Pat GhcTc -> TlM PatVec
translatePat fam_insts pat = case pat of
- WildPat ty -> mkPmVars [ty]
+ WildPat ty -> lift $ mkPmVars [ty]
VarPat _ id -> return [PmVar (unLoc id)]
ParPat _ p -> translatePat fam_insts (unLoc p)
- LazyPat _ _ -> mkPmVars [hsPatType pat] -- like a variable
+ LazyPat _ _ -> lift $ mkPmVars [hsPatType pat] -- like a variable
-- ignore strictness annotations for now
BangPat _ p -> translatePat fam_insts (unLoc p)
@@ -991,24 +1026,24 @@ translatePat fam_insts pat = case pat of
| WpCast co <- wrapper, isReflexiveCo co -> translatePat fam_insts p
| otherwise -> do
ps <- translatePat fam_insts p
- (xp,xe) <- mkPmId2Forms ty
+ (xp,xe) <- lift $ mkPmId2Forms ty
g <- mkGuard ps (mkHsWrap wrapper (unLoc xe))
return [xp,g]
-- (n + k) ===> x (True <- x >= k) (n <- x-k)
- NPlusKPat ty (dL->L _ _n) _k1 _k2 _ge _minus -> mkCanFailPmPat ty
+ NPlusKPat ty (dL->L _ _n) _k1 _k2 _ge _minus -> lift $ mkCanFailPmPat ty
-- (fun -> pat) ===> x (pat <- fun x)
ViewPat arg_ty lexpr lpat -> do
ps <- translatePat fam_insts (unLoc lpat)
-- See Note [Guards and Approximation]
- res <- allM cantFailPattern ps
+ res <- lift $ allM cantFailPattern ps
case res of
True -> do
- (xp,xe) <- mkPmId2Forms arg_ty
+ (xp,xe) <- lift $ mkPmId2Forms arg_ty
g <- mkGuard ps (HsApp noExt lexpr xe)
return [xp,g]
- False -> mkCanFailPmPat arg_ty
+ False -> lift $ mkCanFailPmPat arg_ty
-- list
ListPat (ListPatTc ty Nothing) ps -> do
@@ -1017,13 +1052,13 @@ translatePat fam_insts pat = case pat of
-- overloaded list
ListPat (ListPatTc _elem_ty (Just (pat_ty, _to_list))) lpats -> do
- dflags <- getDynFlags
+ dflags <- lift $ getDynFlags
if xopt LangExt.RebindableSyntax dflags
- then mkCanFailPmPat pat_ty
+ then lift $ mkCanFailPmPat pat_ty
else case splitListTyConApp_maybe pat_ty of
Just e_ty -> translatePat fam_insts
(ListPat (ListPatTc e_ty Nothing) lpats)
- Nothing -> mkCanFailPmPat pat_ty
+ Nothing -> lift $ mkCanFailPmPat pat_ty
-- (a) In the presence of RebindableSyntax, we don't know anything about
-- `toList`, we should treat `ListPat` as any other view pattern.
--
@@ -1047,9 +1082,9 @@ translatePat fam_insts pat = case pat of
, pat_tvs = ex_tvs
, pat_dicts = dicts
, pat_args = ps } -> do
- groups <- allCompleteMatches con arg_tys
+ groups <- lift $ allCompleteMatches con arg_tys
case groups of
- [] -> mkCanFailPmPat (conLikeResTy con arg_tys)
+ [] -> lift $ mkCanFailPmPat (conLikeResTy con arg_tys)
_ -> do
args <- translateConPatVec fam_insts arg_tys ex_tvs con ps
return [PmCon { pm_con_con = con
@@ -1178,23 +1213,23 @@ from translation in pattern matcher.
-- | Translate a list of patterns (Note: each pattern is translated
-- to a pattern vector but we do not concatenate the results).
-translatePatVec :: FamInstEnvs -> [Pat GhcTc] -> DsM [PatVec]
+translatePatVec :: FamInstEnvs -> [Pat GhcTc] -> TlM [PatVec]
translatePatVec fam_insts pats = mapM (translatePat fam_insts) pats
-- | Translate a constructor pattern
translateConPatVec :: FamInstEnvs -> [Type] -> [TyVar]
- -> ConLike -> HsConPatDetails GhcTc -> DsM PatVec
+ -> ConLike -> HsConPatDetails GhcTc -> TlM PatVec
translateConPatVec fam_insts _univ_tys _ex_tvs _ (PrefixCon ps)
= concat <$> translatePatVec fam_insts (map unLoc ps)
translateConPatVec fam_insts _univ_tys _ex_tvs _ (InfixCon p1 p2)
= concat <$> translatePatVec fam_insts (map unLoc [p1,p2])
translateConPatVec fam_insts univ_tys ex_tvs c (RecCon (HsRecFields fs _))
-- Nothing matched. Make up some fresh term variables
- | null fs = mkPmVars arg_tys
+ | null fs = lift $ mkPmVars arg_tys
-- The data constructor was not defined using record syntax. For the
-- pattern to be in record syntax it should be empty (e.g. Just {}).
-- So just like the previous case.
- | null orig_lbls = ASSERT(null matched_lbls) mkPmVars arg_tys
+ | null orig_lbls = ASSERT(null matched_lbls) lift $ mkPmVars arg_tys
-- Some of the fields appear, in the original order (there may be holes).
-- Generate a simple constructor pattern and make up fresh variables for
-- the rest of the fields
@@ -1202,13 +1237,13 @@ translateConPatVec fam_insts univ_tys ex_tvs c (RecCon (HsRecFields fs _))
= ASSERT(orig_lbls `equalLength` arg_tys)
let translateOne (lbl, ty) = case lookup lbl matched_pats of
Just p -> translatePat fam_insts p
- Nothing -> mkPmVars [ty]
+ Nothing -> lift $ mkPmVars [ty]
in concatMapM translateOne (zip orig_lbls arg_tys)
-- The fields that appear are not in the correct order. Make up fresh
-- variables for all fields and add guards after matching, to force the
-- evaluation in the correct order.
| otherwise = do
- arg_var_pats <- mkPmVars arg_tys
+ arg_var_pats <- lift $ mkPmVars arg_tys
translated_pats <- forM matched_pats $ \(x,pat) -> do
pvec <- translatePat fam_insts pat
return (x, pvec)
@@ -1239,7 +1274,7 @@ translateConPatVec fam_insts univ_tys ex_tvs c (RecCon (HsRecFields fs _))
-- Translate a single match
translateMatch :: FamInstEnvs -> LMatch GhcTc (LHsExpr GhcTc)
- -> DsM (PatVec,[PatVec])
+ -> TlM (PatVec,[PatVec])
translateMatch fam_insts (dL->L _ (Match { m_pats = lpats, m_grhss = grhss })) =
do
pats' <- concat <$> translatePatVec fam_insts pats
@@ -1258,7 +1293,7 @@ translateMatch _ _ = panic "translateMatch"
-- * Transform source guards (GuardStmt Id) to PmPats (Pattern)
-- | Translate a list of guard statements to a pattern vector
-translateGuards :: FamInstEnvs -> [GuardStmt GhcTc] -> DsM PatVec
+translateGuards :: FamInstEnvs -> [GuardStmt GhcTc] -> TlM PatVec
translateGuards fam_insts guards = do
all_guards <- concat <$> mapM (translateGuard fam_insts) guards
let
@@ -1273,7 +1308,7 @@ translateGuards fam_insts guards = do
| otherwise = allM shouldKeep pv
shouldKeep _other_pat = pure False -- let the rest..
- all_handled <- allM shouldKeep all_guards
+ all_handled <- lift $ allM shouldKeep all_guards
-- It should have been @pure all_guards@ but it is too expressive.
-- Since the term oracle does not handle all constraints we generate,
-- we (hackily) replace all constraints the oracle cannot handle with a
@@ -1283,7 +1318,7 @@ translateGuards fam_insts guards = do
if all_handled
then pure all_guards
else do
- kept <- filterM shouldKeep all_guards
+ kept <- lift $ filterM shouldKeep all_guards
pure (PmFake : kept)
-- | Check whether a pattern can fail to match
@@ -1295,7 +1330,7 @@ cantFailPattern (PmGrd pv _e) = allM cantFailPattern pv
cantFailPattern _ = pure False
-- | Translate a guard statement to Pattern
-translateGuard :: FamInstEnvs -> GuardStmt GhcTc -> DsM PatVec
+translateGuard :: FamInstEnvs -> GuardStmt GhcTc -> TlM PatVec
translateGuard fam_insts guard = case guard of
BodyStmt _ e _ _ -> translateBoolGuard e
LetStmt _ binds -> translateLet (unLoc binds)
@@ -1308,18 +1343,18 @@ translateGuard fam_insts guard = case guard of
XStmtLR {} -> panic "translateGuard RecStmt"
-- | Translate let-bindings
-translateLet :: HsLocalBinds GhcTc -> DsM PatVec
+translateLet :: HsLocalBinds GhcTc -> TlM PatVec
translateLet _binds = return []
-- | Translate a pattern guard
-translateBind :: FamInstEnvs -> LPat GhcTc -> LHsExpr GhcTc -> DsM PatVec
+translateBind :: FamInstEnvs -> LPat GhcTc -> LHsExpr GhcTc -> TlM PatVec
translateBind fam_insts (dL->L _ p) e = do
ps <- translatePat fam_insts p
g <- mkGuard ps (unLoc e)
return [g]
-- | Translate a boolean guard
-translateBoolGuard :: LHsExpr GhcTc -> DsM PatVec
+translateBoolGuard :: LHsExpr GhcTc -> TlM PatVec
translateBoolGuard e
| isJust (isTrueLHsExpr e) = return []
-- The formal thing to do would be to generate (True <- True)
@@ -1663,14 +1698,13 @@ mkOneConFull x con = do
-- * More smart constructors and fresh variable generation
-- | Create a guard pattern
-mkGuard :: PatVec -> HsExpr GhcTc -> DsM Pattern
+mkGuard :: PatVec -> HsExpr GhcTc -> TlM Pattern
mkGuard pv e = do
- res <- allM cantFailPattern pv
- let expr = hsExprToPmExpr e
- tracePmD "mkGuard" (vcat [ppr pv, ppr e, ppr res, ppr expr])
- if | res -> pure (PmGrd pv expr)
- | PmExprOther {} <- expr -> pure PmFake
- | otherwise -> pure (PmGrd pv expr)
+ res <- lift $ allM cantFailPattern pv
+ let expr = hsExprToPmExpr e
+ expr' <- representPmExprOther expr
+ traceTl "mkGuard" (vcat [ppr pv, ppr e, ppr res, ppr expr, ppr expr'])
+ pure (PmGrd pv expr')
-- | Create a term equality of the form: `(False ~ (x ~ lit))`
mkNegEq :: Id -> PmLit -> ComplexEq
@@ -2403,8 +2437,8 @@ genCaseTmCs2 :: Maybe (LHsExpr GhcTc) -- Scrutinee
-> [Id] -- MatchVars (should have length 1)
-> DsM (Bag SimpleEq)
genCaseTmCs2 Nothing _ _ = return emptyBag
-genCaseTmCs2 (Just scr) [p] [var] = do
- fam_insts <- dsGetFamInstEnvs
+genCaseTmCs2 (Just scr) [p] [var] = flip evalStateT initialTE $ do
+ fam_insts <- lift $ dsGetFamInstEnvs
[e] <- map vaToPmExpr . coercePatVec <$> translatePat fam_insts p
let scr_e = lhsExprToPmExpr scr
return $ listToBag [(var, e), (var, scr_e)]
@@ -2719,6 +2753,8 @@ involved.
tracePm :: String -> SDoc -> PmM ()
tracePm herald doc = liftD $ tracePmD herald doc
+traceTl :: String -> SDoc -> TlM ()
+traceTl herald doc = lift $ tracePmD herald doc
tracePmD :: String -> SDoc -> DsM ()
tracePmD herald doc = do