diff options
Diffstat (limited to 'compiler/deSugar/Check.hs')
-rw-r--r-- | compiler/deSugar/Check.hs | 160 |
1 files changed, 85 insertions, 75 deletions
diff --git a/compiler/deSugar/Check.hs b/compiler/deSugar/Check.hs index 38626a486a..73f0177342 100644 --- a/compiler/deSugar/Check.hs +++ b/compiler/deSugar/Check.hs @@ -31,6 +31,7 @@ import Id import ConLike import DataCon import Name +import FamInstEnv import TysWiredIn import TyCon import SrcLoc @@ -148,7 +149,8 @@ type PmResult = ( [[LPat Id]] checkSingle :: Id -> Pat Id -> DsM PmResult checkSingle var p = do let lp = [noLoc p] - vec <- liftUs (translatePat p) + fam_insts <- dsGetFamInstEnvs + vec <- liftUs (translatePat fam_insts p) vsa <- initial_uncovered [var] (c,d,us') <- patVectProc False (vec,[]) vsa -- no guards us <- pruneVSA us' @@ -171,7 +173,8 @@ checkMatches oversimplify vars matches return ([], [], missing') go (m:ms) missing = do - clause <- liftUs (translateMatch m) + fam_insts <- dsGetFamInstEnvs + clause <- liftUs (translateMatch fam_insts m) (c, d, us ) <- patVectProc oversimplify clause missing (rs, is, us') <- go ms us return $ case (c,d) of @@ -209,7 +212,8 @@ noFailingGuards clauses = sum [ countPatVecs gvs | (_, gvs) <- clauses ] computeNoGuards :: [LMatch Id (LHsExpr Id)] -> PmM Int computeNoGuards matches = do - matches' <- mapM (liftUs . translateMatch) matches + fam_insts <- dsGetFamInstEnvs + matches' <- mapM (liftUs . translateMatch fam_insts) matches return (noFailingGuards matches') maximum_failing_guards :: Int @@ -264,46 +268,47 @@ mkLitPattern lit = PmLit { pm_lit_lit = PmSLit lit } -- ----------------------------------------------------------------------- -- * Transform (Pat Id) into of (PmPat Id) -translatePat :: Pat Id -> UniqSM PatVec -translatePat pat = case pat of +translatePat :: FamInstEnvs -> Pat Id -> UniqSM PatVec +translatePat fam_insts pat = case pat of WildPat ty -> mkPmVarsSM [ty] VarPat id -> return [PmVar (unLoc id)] - ParPat p -> translatePat (unLoc p) + ParPat p -> translatePat fam_insts (unLoc p) LazyPat _ -> mkPmVarsSM [hsPatType pat] -- like a variable -- ignore strictness annotations for now - BangPat p -> translatePat (unLoc p) + BangPat p -> translatePat fam_insts (unLoc p) AsPat lid p -> do -- Note [Translating As Patterns] - ps <- translatePat (unLoc p) + ps <- translatePat fam_insts (unLoc p) let [e] = map valAbsToPmExpr (coercePatVec ps) g = PmGrd [PmVar (unLoc lid)] e return (ps ++ [g]) - SigPatOut p _ty -> translatePat (unLoc p) + SigPatOut p _ty -> translatePat fam_insts (unLoc p) -- See Note [Translate CoPats] CoPat wrapper p ty - | isIdHsWrapper wrapper -> translatePat p - | WpCast co <- wrapper, isReflexiveCo co -> translatePat p + | isIdHsWrapper wrapper -> translatePat fam_insts p + | WpCast co <- wrapper, isReflexiveCo co -> translatePat fam_insts p | otherwise -> do - ps <- translatePat p + ps <- translatePat fam_insts p (xp,xe) <- mkPmId2FormsSM ty let g = mkGuard ps (HsWrap wrapper (unLoc xe)) return [xp,g] -- (n + k) ===> x (True <- x >= k) (n <- x-k) - NPlusKPat (L _ n) k ge minus -> do - (xp, xe) <- mkPmId2FormsSM (idType n) - let ke = L (getLoc k) (HsOverLit (unLoc k)) - g1 = mkGuard [truePattern] (OpApp xe (noLoc ge) no_fixity ke) - g2 = mkGuard [PmVar n] (OpApp xe (noLoc minus) no_fixity ke) + NPlusKPat (L _ n) k1 k2 ge minus ty -> do + (xp, xe) <- mkPmId2FormsSM ty + let ke1 = L (getLoc k1) (HsOverLit (unLoc k1)) + ke2 = L (getLoc k1) (HsOverLit k2) + g1 = mkGuard [truePattern] (unLoc $ nlHsSyntaxApps ge [xe, ke1]) + g2 = mkGuard [PmVar n] (unLoc $ nlHsSyntaxApps minus [xe, ke2]) return [xp, g1, g2] -- (fun -> pat) ===> x (pat <- fun x) ViewPat lexpr lpat arg_ty -> do - ps <- translatePat (unLoc lpat) + ps <- translatePat fam_insts (unLoc lpat) -- See Note [Guards and Approximation] case all cantFailPattern ps of True -> do @@ -316,15 +321,18 @@ translatePat pat = case pat of -- list ListPat ps ty Nothing -> do - foldr (mkListPatVec ty) [nilPattern ty] <$> translatePatVec (map unLoc ps) + foldr (mkListPatVec ty) [nilPattern ty] <$> translatePatVec fam_insts (map unLoc ps) -- overloaded list ListPat lpats elem_ty (Just (pat_ty, _to_list)) - | Just e_ty <- splitListTyConApp_maybe pat_ty, elem_ty `eqType` e_ty -> + | Just e_ty <- splitListTyConApp_maybe pat_ty + , (_, norm_elem_ty) <- normaliseType fam_insts Nominal elem_ty + -- elem_ty is frequently something like `Item [Int]`, but we prefer `Int` + , norm_elem_ty `eqType` e_ty -> -- We have to ensure that the element types are exactly the same. -- Otherwise, one may give an instance IsList [Int] (more specific than -- the default IsList [a]) with a different implementation for `toList' - translatePat (ListPat lpats e_ty Nothing) + translatePat fam_insts (ListPat lpats e_ty Nothing) | otherwise -> do -- See Note [Guards and Approximation] var <- mkPmVarSM pat_ty @@ -345,29 +353,29 @@ translatePat pat = case pat of , pat_tvs = ex_tvs , pat_dicts = dicts , pat_args = ps } -> do - args <- translateConPatVec arg_tys ex_tvs con ps + args <- translateConPatVec fam_insts arg_tys ex_tvs con ps return [PmCon { pm_con_con = con , pm_con_arg_tys = arg_tys , pm_con_tvs = ex_tvs , pm_con_dicts = dicts , pm_con_args = args }] - NPat (L _ ol) mb_neg _eq -> translateNPat ol mb_neg + NPat (L _ ol) mb_neg _eq ty -> translateNPat fam_insts ol mb_neg ty LitPat lit -- If it is a string then convert it to a list of characters | HsString src s <- lit -> foldr (mkListPatVec charTy) [nilPattern charTy] <$> - translatePatVec (map (LitPat . HsChar src) (unpackFS s)) + translatePatVec fam_insts (map (LitPat . HsChar src) (unpackFS s)) | otherwise -> return [mkLitPattern lit] PArrPat ps ty -> do - tidy_ps <- translatePatVec (map unLoc ps) + tidy_ps <- translatePatVec fam_insts (map unLoc ps) let fake_con = parrFakeCon (length ps) return [vanillaConPattern fake_con [ty] (concat tidy_ps)] TuplePat ps boxity tys -> do - tidy_ps <- translatePatVec (map unLoc ps) + tidy_ps <- translatePatVec fam_insts (map unLoc ps) let tuple_con = tupleDataCon boxity (length ps) return [vanillaConPattern tuple_con tys (concat tidy_ps)] @@ -378,33 +386,35 @@ translatePat pat = case pat of SigPatIn {} -> panic "Check.translatePat: SigPatIn" -- | Translate an overloaded literal (see `tidyNPat' in deSugar/MatchLit.hs) -translateNPat :: HsOverLit Id -> Maybe (SyntaxExpr Id) -> UniqSM PatVec -translateNPat (OverLit val False _ ty) mb_neg - | isStringTy ty, HsIsString src s <- val, Nothing <- mb_neg - = translatePat (LitPat (HsString src s)) - | isIntTy ty, HsIntegral src i <- val - = translatePat (mk_num_lit HsInt src i) - | isWordTy ty, HsIntegral src i <- val - = translatePat (mk_num_lit HsWordPrim src i) +translateNPat :: FamInstEnvs + -> HsOverLit Id -> Maybe (SyntaxExpr Id) -> Type -> UniqSM PatVec +translateNPat fam_insts (OverLit val False _ ty) mb_neg outer_ty + | not type_change, isStringTy ty, HsIsString src s <- val, Nothing <- mb_neg + = translatePat fam_insts (LitPat (HsString src s)) + | not type_change, isIntTy ty, HsIntegral src i <- val + = translatePat fam_insts (mk_num_lit HsInt src i) + | not type_change, isWordTy ty, HsIntegral src i <- val + = translatePat fam_insts (mk_num_lit HsWordPrim src i) where + type_change = not (outer_ty `eqType` ty) mk_num_lit c src i = LitPat $ case mb_neg of Nothing -> c src i Just _ -> c src (-i) -translateNPat ol mb_neg +translateNPat _ ol mb_neg _ = return [PmLit { pm_lit_lit = PmOLit (isJust mb_neg) ol }] -- | Translate a list of patterns (Note: each pattern is translated -- to a pattern vector but we do not concatenate the results). -translatePatVec :: [Pat Id] -> UniqSM [PatVec] -translatePatVec pats = mapM translatePat pats +translatePatVec :: FamInstEnvs -> [Pat Id] -> UniqSM [PatVec] +translatePatVec fam_insts pats = mapM (translatePat fam_insts) pats -translateConPatVec :: [Type] -> [TyVar] +translateConPatVec :: FamInstEnvs -> [Type] -> [TyVar] -> DataCon -> HsConPatDetails Id -> UniqSM PatVec -translateConPatVec _univ_tys _ex_tvs _ (PrefixCon ps) - = concat <$> translatePatVec (map unLoc ps) -translateConPatVec _univ_tys _ex_tvs _ (InfixCon p1 p2) - = concat <$> translatePatVec (map unLoc [p1,p2]) -translateConPatVec univ_tys ex_tvs c (RecCon (HsRecFields fs _)) +translateConPatVec fam_insts _univ_tys _ex_tvs _ (PrefixCon ps) + = concat <$> translatePatVec fam_insts (map unLoc ps) +translateConPatVec fam_insts _univ_tys _ex_tvs _ (InfixCon p1 p2) + = concat <$> translatePatVec fam_insts (map unLoc [p1,p2]) +translateConPatVec fam_insts univ_tys ex_tvs c (RecCon (HsRecFields fs _)) -- Nothing matched. Make up some fresh term variables | null fs = mkPmVarsSM arg_tys -- The data constructor was not defined using record syntax. For the @@ -417,7 +427,7 @@ translateConPatVec univ_tys ex_tvs c (RecCon (HsRecFields fs _)) | matched_lbls `subsetOf` orig_lbls = ASSERT(length orig_lbls == length arg_tys) let translateOne (lbl, ty) = case lookup lbl matched_pats of - Just p -> translatePat p + Just p -> translatePat fam_insts p Nothing -> mkPmVarsSM [ty] in concatMapM translateOne (zip orig_lbls arg_tys) -- The fields that appear are not in the correct order. Make up fresh @@ -426,7 +436,7 @@ translateConPatVec univ_tys ex_tvs c (RecCon (HsRecFields fs _)) | otherwise = do arg_var_pats <- mkPmVarsSM arg_tys translated_pats <- forM matched_pats $ \(x,pat) -> do - pvec <- translatePat pat + pvec <- translatePat fam_insts pat return (x, pvec) let zipped = zip orig_lbls [ x | PmVar x <- arg_var_pats ] @@ -453,10 +463,10 @@ translateConPatVec univ_tys ex_tvs c (RecCon (HsRecFields fs _)) | x == y = subsetOf xs ys | otherwise = subsetOf (x:xs) ys -translateMatch :: LMatch Id (LHsExpr Id) -> UniqSM (PatVec,[PatVec]) -translateMatch (L _ (Match _ lpats _ grhss)) = do - pats' <- concat <$> translatePatVec pats - guards' <- mapM translateGuards guards +translateMatch :: FamInstEnvs -> LMatch Id (LHsExpr Id) -> UniqSM (PatVec,[PatVec]) +translateMatch fam_insts (L _ (Match _ lpats _ grhss)) = do + pats' <- concat <$> translatePatVec fam_insts pats + guards' <- mapM (translateGuards fam_insts) guards return (pats', guards') where extractGuards :: LGRHS Id (LHsExpr Id) -> [GuardStmt Id] @@ -469,9 +479,9 @@ translateMatch (L _ (Match _ lpats _ grhss)) = do -- * Transform source guards (GuardStmt Id) to PmPats (Pattern) -- | Translate a list of guard statements to a pattern vector -translateGuards :: [GuardStmt Id] -> UniqSM PatVec -translateGuards guards = do - all_guards <- concat <$> mapM translateGuard guards +translateGuards :: FamInstEnvs -> [GuardStmt Id] -> UniqSM PatVec +translateGuards fam_insts guards = do + all_guards <- concat <$> mapM (translateGuard fam_insts) guards return (replace_unhandled all_guards) -- It should have been (return $ all_guards) but it is too expressive. -- Since the term oracle does not handle all constraints we generate, @@ -509,24 +519,24 @@ cantFailPattern (PmGrd pv _e) cantFailPattern _ = False -- | Translate a guard statement to Pattern -translateGuard :: GuardStmt Id -> UniqSM PatVec -translateGuard (BodyStmt e _ _ _) = translateBoolGuard e -translateGuard (LetStmt binds) = translateLet (unLoc binds) -translateGuard (BindStmt p e _ _) = translateBind p e -translateGuard (LastStmt {}) = panic "translateGuard LastStmt" -translateGuard (ParStmt {}) = panic "translateGuard ParStmt" -translateGuard (TransStmt {}) = panic "translateGuard TransStmt" -translateGuard (RecStmt {}) = panic "translateGuard RecStmt" -translateGuard (ApplicativeStmt {}) = panic "translateGuard ApplicativeLastStmt" +translateGuard :: FamInstEnvs -> GuardStmt Id -> UniqSM PatVec +translateGuard _ (BodyStmt e _ _ _) = translateBoolGuard e +translateGuard _ (LetStmt binds) = translateLet (unLoc binds) +translateGuard fam_insts (BindStmt p e _ _ _) = translateBind fam_insts p e +translateGuard _ (LastStmt {}) = panic "translateGuard LastStmt" +translateGuard _ (ParStmt {}) = panic "translateGuard ParStmt" +translateGuard _ (TransStmt {}) = panic "translateGuard TransStmt" +translateGuard _ (RecStmt {}) = panic "translateGuard RecStmt" +translateGuard _ (ApplicativeStmt {}) = panic "translateGuard ApplicativeLastStmt" -- | Translate let-bindings translateLet :: HsLocalBinds Id -> UniqSM PatVec translateLet _binds = return [] -- | Translate a pattern guard -translateBind :: LPat Id -> LHsExpr Id -> UniqSM PatVec -translateBind (L _ p) e = do - ps <- translatePat p +translateBind :: FamInstEnvs -> LPat Id -> LHsExpr Id -> UniqSM PatVec +translateBind fam_insts (L _ p) e = do + ps <- translatePat fam_insts p return [mkGuard ps (unLoc e)] -- | Translate a boolean guard @@ -600,7 +610,8 @@ below is the *right thing to do*: The case with literals is a bit different. a literal @l@ should be translated to @x (True <- x == from l)@. Since we want to have better warnings for overloaded literals as it is a very common feature, we treat them differently. -They are mainly covered in Note [Undecidable Equality on Overloaded Literals]. +They are mainly covered in Note [Undecidable Equality on Overloaded Literals] +in PmExpr. 4. N+K Patterns & Pattern Synonyms ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -845,9 +856,6 @@ coercePmPat (PmCon { pm_con_con = con, pm_con_arg_tys = arg_tys , pm_con_args = coercePatVec args }] coercePmPat (PmGrd {}) = [] -- drop the guards -no_fixity :: a -- TODO: Can we retrieve the fixity from the operator name? -no_fixity = panic "Check: no fixity" - -- Get all constructors in the family (including given) allConstructors :: DataCon -> [DataCon] allConstructors = tyConDataCons . dataConTyCon @@ -1101,7 +1109,7 @@ cMatcher us gvsa (p@(PmCon { pm_con_con = c1, pm_con_args = args1 })) ps -- CLitLit cMatcher us gvsa (PmLit l1) ps (va@(PmLit l2)) vsa = case eqPmLit l1 l2 of - -- See Note [Undecidable Equality for Overloaded Literals] + -- See Note [Undecidable Equality for Overloaded Literals] in PmExpr True -> va `mkCons` covered us gvsa ps vsa -- match False -> Empty -- mismatch @@ -1172,7 +1180,7 @@ uMatcher us gvsa ( p@(PmCon { pm_con_con = c1, pm_con_args = args1 })) ps -- ULitLit uMatcher us gvsa (PmLit l1) ps (va@(PmLit l2)) vsa = case eqPmLit l1 l2 of - -- See Note [Undecidable Equality for Overloaded Literals] + -- See Note [Undecidable Equality for Overloaded Literals] in PmExpr True -> va `mkCons` uncovered us gvsa ps vsa -- match False -> va `mkCons` vsa -- mismatch @@ -1256,7 +1264,7 @@ dMatcher us gvsa (p@(PmCon { pm_con_con = c1, pm_con_args = args1 })) ps -- DLitLit dMatcher us gvsa (PmLit l1) ps (va@(PmLit l2)) vsa = case eqPmLit l1 l2 of - -- See Note [Undecidable Equality for Overloaded Literals] + -- See Note [Undecidable Equality for Overloaded Literals] in PmExpr True -> va `mkCons` divergent us gvsa ps vsa -- match False -> Empty -- mismatch @@ -1331,10 +1339,12 @@ genCaseTmCs2 :: Maybe (LHsExpr Id) -- Scrutinee -> [Id] -- MatchVars (should have length 1) -> DsM (Bag SimpleEq) genCaseTmCs2 Nothing _ _ = return emptyBag -genCaseTmCs2 (Just scr) [p] [var] = liftUs $ do - [e] <- map valAbsToPmExpr . coercePatVec <$> translatePat p - let scr_e = lhsExprToPmExpr scr - return $ listToBag [(var, e), (var, scr_e)] +genCaseTmCs2 (Just scr) [p] [var] = do + fam_insts <- dsGetFamInstEnvs + liftUs $ do + [e] <- map valAbsToPmExpr . coercePatVec <$> translatePat fam_insts p + let scr_e = lhsExprToPmExpr scr + return $ listToBag [(var, e), (var, scr_e)] genCaseTmCs2 _ _ _ = panic "genCaseTmCs2: HsCase" -- | Generate a simple equality when checking a case expression: |