diff options
author | Jakob Bruenker <jakob.bruenker@gmail.com> | 2022-03-21 00:14:25 +0100 |
---|---|---|
committer | Jakob Bruenker <jakob.bruenker@gmail.com> | 2022-04-01 20:31:08 +0200 |
commit | 32070e6c2e1b4b7c32530a9566fe14543791f9a6 (patch) | |
tree | f0913ef2a69fd660542723ec07240167dbd37961 /compiler/GHC/Tc/Gen | |
parent | d85c7dcb7c457efc23b20ac8f4e4ae88bae5b050 (diff) | |
download | haskell-32070e6c2e1b4b7c32530a9566fe14543791f9a6.tar.gz |
Implement \cases (Proposal 302)
This commit implements proposal 302: \cases - Multi-way lambda
expressions.
This adds a new expression heralded by \cases, which works exactly like
\case, but can match multiple apats instead of a single pat.
Updates submodule haddock to support the ITlcases token.
Closes #20768
Diffstat (limited to 'compiler/GHC/Tc/Gen')
-rw-r--r-- | compiler/GHC/Tc/Gen/Arrow.hs | 150 | ||||
-rw-r--r-- | compiler/GHC/Tc/Gen/Expr.hs | 8 | ||||
-rw-r--r-- | compiler/GHC/Tc/Gen/Match.hs | 41 |
3 files changed, 116 insertions, 83 deletions
diff --git a/compiler/GHC/Tc/Gen/Arrow.hs b/compiler/GHC/Tc/Gen/Arrow.hs index ad4b67ee88..d3035b5cf2 100644 --- a/compiler/GHC/Tc/Gen/Arrow.hs +++ b/compiler/GHC/Tc/Gen/Arrow.hs @@ -1,5 +1,6 @@ -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE BlockArguments #-} {-# OPTIONS_GHC -Wno-incomplete-record-updates #-} @@ -45,6 +46,8 @@ import GHC.Utils.Outputable import GHC.Utils.Panic import GHC.Utils.Misc +import qualified GHC.Data.Strict as Strict + import Control.Monad {- @@ -164,19 +167,21 @@ tc_cmd env in_cmd@(HsCmdCase x scrut matches) (stk, res_ty) = addErrCtxt (cmdCtxt in_cmd) $ do (scrut', scrut_ty) <- tcInferRho scrut hasFixedRuntimeRep_MustBeRefl - (FRRArrow $ ArrowCmdCase { isCmdLamCase = False }) + (FRRArrow $ ArrowCmdCase) scrut_ty matches' <- tcCmdMatches env scrut_ty matches (stk, res_ty) return (HsCmdCase x scrut' matches') -tc_cmd env in_cmd@(HsCmdLamCase x matches) (stk, res_ty) - = addErrCtxt (cmdCtxt in_cmd) $ do - (co, [scrut_ty], stk') <- matchExpectedCmdArgs 1 stk - hasFixedRuntimeRep_MustBeRefl - (FRRArrow $ ArrowCmdCase { isCmdLamCase = True }) - scrut_ty - matches' <- tcCmdMatches env scrut_ty matches (stk', res_ty) - return (mkHsCmdWrap (mkWpCastN co) (HsCmdLamCase x matches')) +tc_cmd env cmd@(HsCmdLamCase x lc_variant match) cmd_ty + = addErrCtxt (cmdCtxt cmd) + do { let match_ctxt = ArrowLamCaseAlt lc_variant + ; checkPatCounts (ArrowMatchCtxt match_ctxt) match + ; (wrap, match') <- + tcCmdMatchLambda env match_ctxt mk_origin match cmd_ty + ; return (mkHsCmdWrap wrap (HsCmdLamCase x lc_variant match')) } + where mk_origin = ArrowCmdLamCase . case lc_variant of + LamCase -> const Strict.Nothing + LamCases -> Strict.Just tc_cmd env (HsCmdIf x NoSyntaxExprRn pred b1 b2) res_ty -- Ordinary 'if' = do { pred' <- tcCheckMonoExpr pred boolTy @@ -269,52 +274,9 @@ tc_cmd env cmd@(HsCmdApp x fun arg) (cmd_stk, res_ty) -- ------------------------------ -- D;G |-a (\x.cmd) : (t,stk) --> res -tc_cmd env - (HsCmdLam x (MG { mg_alts = L l [L mtch_loc - (match@(Match { m_pats = pats, m_grhss = grhss }))], - mg_origin = origin })) - (cmd_stk, res_ty) - = addErrCtxt (pprMatchInCtxt match) $ - do { (co, arg_tys, cmd_stk') <- matchExpectedCmdArgs n_pats cmd_stk - - -- Check the patterns, and the GRHSs inside - ; (pats', grhss') <- setSrcSpanA mtch_loc $ - tcPats (ArrowMatchCtxt KappaExpr) - pats (map (unrestricted . mkCheckExpType) arg_tys) $ - tc_grhss grhss cmd_stk' (mkCheckExpType res_ty) - - ; let match' = L mtch_loc (Match { m_ext = noAnn - , m_ctxt = ArrowMatchCtxt KappaExpr - , m_pats = pats' - , m_grhss = grhss' }) - arg_tys = map (unrestricted . hsLPatType) pats' - - ; zipWithM_ - (\ (Scaled _ arg_ty) i -> - hasFixedRuntimeRep_MustBeRefl (FRRArrow $ ArrowCmdLam i) arg_ty) - arg_tys - [1..] - - ; let - cmd' = HsCmdLam x (MG { mg_alts = L l [match'] - , mg_ext = MatchGroupTc arg_tys res_ty - , mg_origin = origin }) - ; return (mkHsCmdWrap (mkWpCastN co) cmd') } - where - n_pats = length pats - match_ctxt = ArrowMatchCtxt KappaExpr - pg_ctxt = PatGuard match_ctxt - - tc_grhss (GRHSs x grhss binds) stk_ty res_ty - = do { (binds', grhss') <- tcLocalBinds binds $ - mapM (wrapLocMA (tc_grhs stk_ty res_ty)) grhss - ; return (GRHSs x grhss' binds') } - - tc_grhs stk_ty res_ty (GRHS x guards body) - = do { (guards', rhs') <- tcStmtsAndThen pg_ctxt tcGuardStmt guards res_ty $ - \ res_ty -> tcCmd env body - (stk_ty, checkingExpType "tc_grhs" res_ty) - ; return (GRHS x guards' rhs') } +tc_cmd env (HsCmdLam x match) cmd_ty + = do { (wrap, match') <- tcCmdMatchLambda env KappaExpr ArrowCmdLam match cmd_ty + ; return (mkHsCmdWrap wrap (HsCmdLam x match')) } ------------------------------------------- -- Do notation @@ -340,7 +302,7 @@ tc_cmd env (HsCmdDo _ (L l stmts) ) (cmd_stk, res_ty) -- D; G |-a (| e c1 ... cn |) : stk --> t tc_cmd env cmd@(HsCmdArrForm x expr f fixity cmd_args) (cmd_stk, res_ty) - = addErrCtxt (cmdCtxt cmd) $ + = addErrCtxt (cmdCtxt cmd) do { (cmd_args', cmd_tys) <- mapAndUnzipM tc_cmd_arg cmd_args -- We use alphaTyVar for 'w' ; let e_ty = mkInfForAllTy alphaTyVar $ @@ -361,15 +323,7 @@ tc_cmd env cmd@(HsCmdArrForm x expr f fixity cmd_args) (cmd_stk, res_ty) ; cmd' <- tcCmdTop env' names' cmd (stk_ty, res_ty) ; return (cmd', mkCmdArrTy env' (mkPairTy alphaTy stk_ty) res_ty) } ------------------------------------------------------------------ --- Base case for illegal commands --- This is where expressions that aren't commands get rejected - -tc_cmd _ cmd _ - = failWithTc (TcRnArrowCommandExpected cmd) - --- | Typechecking for case command alternatives. Used for both --- 'HsCmdCase' and 'HsCmdLamCase'. +-- | Typechecking for case command alternatives. Used for 'HsCmdCase'. tcCmdMatches :: CmdEnv -> TcType -- ^ Type of the scrutinee. -- Must have a fixed RuntimeRep as per @@ -385,6 +339,68 @@ tcCmdMatches env scrut_ty matches (stk, res_ty) mc_body body res_ty' = do { res_ty' <- expTypeToType res_ty' ; tcCmd env body (stk, res_ty') } +-- | Typechecking for 'HsCmdLam' and 'HsCmdLamCase'. +tcCmdMatchLambda :: CmdEnv + -> HsArrowMatchContext + -> (Int -> FRRArrowOrigin) -- ^ Function that creates an origin + -- given the index of a pattern + -> MatchGroup GhcRn (LHsCmd GhcRn) + -> CmdType + -> TcM (HsWrapper, MatchGroup GhcTc (LHsCmd GhcTc)) +tcCmdMatchLambda env + ctxt + mk_origin + mg@MG { mg_alts = L l matches } + (cmd_stk, res_ty) + = do { (co, arg_tys, cmd_stk') <- matchExpectedCmdArgs n_pats cmd_stk + + ; let check_arg_tys = map (unrestricted . mkCheckExpType) arg_tys + ; matches' <- forM matches $ + addErrCtxt . pprMatchInCtxt . unLoc <*> tc_match check_arg_tys cmd_stk' + + ; let arg_tys' = map unrestricted arg_tys + mg' = mg { mg_alts = L l matches' + , mg_ext = MatchGroupTc arg_tys' res_ty } + + ; return (mkWpCastN co, mg') } + where + n_pats | isEmptyMatchGroup mg = 1 -- must be lambda-case + | otherwise = matchGroupArity mg + + -- Check the patterns, and the GRHSs inside + tc_match arg_tys cmd_stk' (L mtch_loc (Match { m_pats = pats, m_grhss = grhss })) + = do { (pats', grhss') <- setSrcSpanA mtch_loc $ + tcPats match_ctxt pats arg_tys $ + tc_grhss grhss cmd_stk' (mkCheckExpType res_ty) + + ; let arg_tys' = map (unrestricted . hsLPatType) pats' + + ; zipWithM_ + (\ (Scaled _ arg_ty) i -> + hasFixedRuntimeRep_MustBeRefl (FRRArrow $ mk_origin i) arg_ty) + arg_tys' + [1..] + + ; return $ L mtch_loc (Match { m_ext = noAnn + , m_ctxt = match_ctxt + , m_pats = pats' + , m_grhss = grhss' }) } + + + match_ctxt = ArrowMatchCtxt ctxt + pg_ctxt = PatGuard match_ctxt + + tc_grhss (GRHSs x grhss binds) stk_ty res_ty + = do { (binds', grhss') <- tcLocalBinds binds $ + mapM (wrapLocMA (tc_grhs stk_ty res_ty)) grhss + ; return (GRHSs x grhss' binds') } + + tc_grhs stk_ty res_ty (GRHS x guards body) + = do { (guards', rhs') <- tcStmtsAndThen pg_ctxt tcGuardStmt guards res_ty $ + \ res_ty -> tcCmd env body + (stk_ty, checkingExpType "tc_grhs" res_ty) + ; return (GRHS x guards' rhs') } + matchExpectedCmdArgs :: Arity -> TcType -> TcM (TcCoercionN, [TcType], TcType) matchExpectedCmdArgs 0 ty = return (mkTcNomReflCo ty, [], ty) diff --git a/compiler/GHC/Tc/Gen/Expr.hs b/compiler/GHC/Tc/Gen/Expr.hs index 5cfe527c70..b5e9982f48 100644 --- a/compiler/GHC/Tc/Gen/Expr.hs +++ b/compiler/GHC/Tc/Gen/Expr.hs @@ -264,13 +264,13 @@ tcExpr (HsLam _ match) res_ty match_ctxt = MC { mc_what = LambdaExpr, mc_body = tcBody } herald = ExpectedFunTyLam match -tcExpr e@(HsLamCase x matches) res_ty +tcExpr e@(HsLamCase x lc_variant matches) res_ty = do { (wrap, matches') <- tcMatchLambda herald match_ctxt matches res_ty - ; return (mkHsWrap wrap $ HsLamCase x matches') } + ; return (mkHsWrap wrap $ HsLamCase x lc_variant matches') } where - match_ctxt = MC { mc_what = CaseAlt, mc_body = tcBody } - herald = ExpectedFunTyLamCase e + match_ctxt = MC { mc_what = LamCaseAlt lc_variant, mc_body = tcBody } + herald = ExpectedFunTyLamCase lc_variant e diff --git a/compiler/GHC/Tc/Gen/Match.hs b/compiler/GHC/Tc/Gen/Match.hs index d6f3590910..0763ad2679 100644 --- a/compiler/GHC/Tc/Gen/Match.hs +++ b/compiler/GHC/Tc/Gen/Match.hs @@ -31,6 +31,7 @@ module GHC.Tc.Gen.Match , tcBody , tcDoStmt , tcGuardStmt + , checkPatCounts ) where @@ -105,7 +106,9 @@ tcMatchesFun fun_id matches exp_ty -- ann-grabbing, because we don't always have annotations in -- hand when we call tcMatchesFun... traceTc "tcMatchesFun" (ppr fun_name $$ ppr exp_ty) - ; checkArgs fun_name matches + -- We can't easily call checkPatCounts here because fun_id can be an + -- unfilled thunk + ; checkArgCounts fun_name matches ; matchExpectedFunTys herald ctxt arity exp_ty $ \ pat_tys rhs_ty -> -- NB: exp_type may be polymorphic, but @@ -161,8 +164,10 @@ tcMatchLambda :: ExpectedFunTyOrigin -- see Note [Herald for matchExpectedFunTys -> ExpRhoType -> TcM (HsWrapper, MatchGroup GhcTc (LHsExpr GhcTc)) tcMatchLambda herald match_ctxt match res_ty - = matchExpectedFunTys herald GenSigCtxt n_pats res_ty $ \ pat_tys rhs_ty -> - tcMatches match_ctxt pat_tys rhs_ty match + = do { checkPatCounts (mc_what match_ctxt) match + ; matchExpectedFunTys herald GenSigCtxt n_pats res_ty $ \ pat_tys rhs_ty -> do + -- checking argument counts since this is also used for \cases + tcMatches match_ctxt pat_tys rhs_ty match } where n_pats | isEmptyMatchGroup match = 1 -- must be lambda-case | otherwise = matchGroupArity match @@ -1132,23 +1137,35 @@ the variables they bind into scope, and typecheck the thing_inside. * * ************************************************************************ -@sameNoOfArgs@ takes a @[RenamedMatch]@ and decides whether the same +@checkArgCounts@ takes a @[RenamedMatch]@ and decides whether the same number of args are used in each equation. -} -checkArgs :: AnnoBody body - => Name -> MatchGroup GhcRn (LocatedA (body GhcRn)) -> TcM () -checkArgs _ (MG { mg_alts = L _ [] }) +checkArgCounts :: AnnoBody body + => Name -> MatchGroup GhcRn (LocatedA (body GhcRn)) -> TcM () +checkArgCounts = check_match_pats . (text "Equations for" <+>) . quotes . ppr + +-- @checkPatCounts@ takes a @[RenamedMatch]@ and decides whether the same +-- number of patterns are used in each alternative +checkPatCounts :: AnnoBody body + => HsMatchContext GhcTc -> MatchGroup GhcRn (LocatedA (body GhcRn)) + -> TcM () +checkPatCounts = check_match_pats . pprMatchContextNouns + +check_match_pats :: AnnoBody body + => SDoc -> MatchGroup GhcRn (LocatedA (body GhcRn)) + -> TcM () +check_match_pats _ (MG { mg_alts = L _ [] }) = return () -checkArgs fun (MG { mg_alts = L _ (match1:matches) }) +check_match_pats err_msg (MG { mg_alts = L _ (match1:matches) }) | null bad_matches = return () | otherwise = failWithTc $ TcRnUnknownMessage $ mkPlainError noHints $ - (vcat [ text "Equations for" <+> quotes (ppr fun) <+> - text "have different numbers of arguments" - , nest 2 (ppr (getLocA match1)) - , nest 2 (ppr (getLocA (head bad_matches)))]) + (vcat [ err_msg <+> + text "have different numbers of arguments" + , nest 2 (ppr (getLocA match1)) + , nest 2 (ppr (getLocA (head bad_matches)))]) where n_args1 = args_in_match match1 bad_matches = [m | m <- matches, args_in_match m /= n_args1] |