summaryrefslogtreecommitdiff
path: root/compiler/GHC/Tc/Gen
diff options
context:
space:
mode:
authorJakob Bruenker <jakob.bruenker@gmail.com>2022-03-21 00:14:25 +0100
committerJakob Bruenker <jakob.bruenker@gmail.com>2022-04-01 20:31:08 +0200
commit32070e6c2e1b4b7c32530a9566fe14543791f9a6 (patch)
treef0913ef2a69fd660542723ec07240167dbd37961 /compiler/GHC/Tc/Gen
parentd85c7dcb7c457efc23b20ac8f4e4ae88bae5b050 (diff)
downloadhaskell-32070e6c2e1b4b7c32530a9566fe14543791f9a6.tar.gz
Implement \cases (Proposal 302)
This commit implements proposal 302: \cases - Multi-way lambda expressions. This adds a new expression heralded by \cases, which works exactly like \case, but can match multiple apats instead of a single pat. Updates submodule haddock to support the ITlcases token. Closes #20768
Diffstat (limited to 'compiler/GHC/Tc/Gen')
-rw-r--r--compiler/GHC/Tc/Gen/Arrow.hs150
-rw-r--r--compiler/GHC/Tc/Gen/Expr.hs8
-rw-r--r--compiler/GHC/Tc/Gen/Match.hs41
3 files changed, 116 insertions, 83 deletions
diff --git a/compiler/GHC/Tc/Gen/Arrow.hs b/compiler/GHC/Tc/Gen/Arrow.hs
index ad4b67ee88..d3035b5cf2 100644
--- a/compiler/GHC/Tc/Gen/Arrow.hs
+++ b/compiler/GHC/Tc/Gen/Arrow.hs
@@ -1,5 +1,6 @@
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE BlockArguments #-}
{-# OPTIONS_GHC -Wno-incomplete-record-updates #-}
@@ -45,6 +46,8 @@ import GHC.Utils.Outputable
import GHC.Utils.Panic
import GHC.Utils.Misc
+import qualified GHC.Data.Strict as Strict
+
import Control.Monad
{-
@@ -164,19 +167,21 @@ tc_cmd env in_cmd@(HsCmdCase x scrut matches) (stk, res_ty)
= addErrCtxt (cmdCtxt in_cmd) $ do
(scrut', scrut_ty) <- tcInferRho scrut
hasFixedRuntimeRep_MustBeRefl
- (FRRArrow $ ArrowCmdCase { isCmdLamCase = False })
+ (FRRArrow $ ArrowCmdCase)
scrut_ty
matches' <- tcCmdMatches env scrut_ty matches (stk, res_ty)
return (HsCmdCase x scrut' matches')
-tc_cmd env in_cmd@(HsCmdLamCase x matches) (stk, res_ty)
- = addErrCtxt (cmdCtxt in_cmd) $ do
- (co, [scrut_ty], stk') <- matchExpectedCmdArgs 1 stk
- hasFixedRuntimeRep_MustBeRefl
- (FRRArrow $ ArrowCmdCase { isCmdLamCase = True })
- scrut_ty
- matches' <- tcCmdMatches env scrut_ty matches (stk', res_ty)
- return (mkHsCmdWrap (mkWpCastN co) (HsCmdLamCase x matches'))
+tc_cmd env cmd@(HsCmdLamCase x lc_variant match) cmd_ty
+ = addErrCtxt (cmdCtxt cmd)
+ do { let match_ctxt = ArrowLamCaseAlt lc_variant
+ ; checkPatCounts (ArrowMatchCtxt match_ctxt) match
+ ; (wrap, match') <-
+ tcCmdMatchLambda env match_ctxt mk_origin match cmd_ty
+ ; return (mkHsCmdWrap wrap (HsCmdLamCase x lc_variant match')) }
+ where mk_origin = ArrowCmdLamCase . case lc_variant of
+ LamCase -> const Strict.Nothing
+ LamCases -> Strict.Just
tc_cmd env (HsCmdIf x NoSyntaxExprRn pred b1 b2) res_ty -- Ordinary 'if'
= do { pred' <- tcCheckMonoExpr pred boolTy
@@ -269,52 +274,9 @@ tc_cmd env cmd@(HsCmdApp x fun arg) (cmd_stk, res_ty)
-- ------------------------------
-- D;G |-a (\x.cmd) : (t,stk) --> res
-tc_cmd env
- (HsCmdLam x (MG { mg_alts = L l [L mtch_loc
- (match@(Match { m_pats = pats, m_grhss = grhss }))],
- mg_origin = origin }))
- (cmd_stk, res_ty)
- = addErrCtxt (pprMatchInCtxt match) $
- do { (co, arg_tys, cmd_stk') <- matchExpectedCmdArgs n_pats cmd_stk
-
- -- Check the patterns, and the GRHSs inside
- ; (pats', grhss') <- setSrcSpanA mtch_loc $
- tcPats (ArrowMatchCtxt KappaExpr)
- pats (map (unrestricted . mkCheckExpType) arg_tys) $
- tc_grhss grhss cmd_stk' (mkCheckExpType res_ty)
-
- ; let match' = L mtch_loc (Match { m_ext = noAnn
- , m_ctxt = ArrowMatchCtxt KappaExpr
- , m_pats = pats'
- , m_grhss = grhss' })
- arg_tys = map (unrestricted . hsLPatType) pats'
-
- ; zipWithM_
- (\ (Scaled _ arg_ty) i ->
- hasFixedRuntimeRep_MustBeRefl (FRRArrow $ ArrowCmdLam i) arg_ty)
- arg_tys
- [1..]
-
- ; let
- cmd' = HsCmdLam x (MG { mg_alts = L l [match']
- , mg_ext = MatchGroupTc arg_tys res_ty
- , mg_origin = origin })
- ; return (mkHsCmdWrap (mkWpCastN co) cmd') }
- where
- n_pats = length pats
- match_ctxt = ArrowMatchCtxt KappaExpr
- pg_ctxt = PatGuard match_ctxt
-
- tc_grhss (GRHSs x grhss binds) stk_ty res_ty
- = do { (binds', grhss') <- tcLocalBinds binds $
- mapM (wrapLocMA (tc_grhs stk_ty res_ty)) grhss
- ; return (GRHSs x grhss' binds') }
-
- tc_grhs stk_ty res_ty (GRHS x guards body)
- = do { (guards', rhs') <- tcStmtsAndThen pg_ctxt tcGuardStmt guards res_ty $
- \ res_ty -> tcCmd env body
- (stk_ty, checkingExpType "tc_grhs" res_ty)
- ; return (GRHS x guards' rhs') }
+tc_cmd env (HsCmdLam x match) cmd_ty
+ = do { (wrap, match') <- tcCmdMatchLambda env KappaExpr ArrowCmdLam match cmd_ty
+ ; return (mkHsCmdWrap wrap (HsCmdLam x match')) }
-------------------------------------------
-- Do notation
@@ -340,7 +302,7 @@ tc_cmd env (HsCmdDo _ (L l stmts) ) (cmd_stk, res_ty)
-- D; G |-a (| e c1 ... cn |) : stk --> t
tc_cmd env cmd@(HsCmdArrForm x expr f fixity cmd_args) (cmd_stk, res_ty)
- = addErrCtxt (cmdCtxt cmd) $
+ = addErrCtxt (cmdCtxt cmd)
do { (cmd_args', cmd_tys) <- mapAndUnzipM tc_cmd_arg cmd_args
-- We use alphaTyVar for 'w'
; let e_ty = mkInfForAllTy alphaTyVar $
@@ -361,15 +323,7 @@ tc_cmd env cmd@(HsCmdArrForm x expr f fixity cmd_args) (cmd_stk, res_ty)
; cmd' <- tcCmdTop env' names' cmd (stk_ty, res_ty)
; return (cmd', mkCmdArrTy env' (mkPairTy alphaTy stk_ty) res_ty) }
------------------------------------------------------------------
--- Base case for illegal commands
--- This is where expressions that aren't commands get rejected
-
-tc_cmd _ cmd _
- = failWithTc (TcRnArrowCommandExpected cmd)
-
--- | Typechecking for case command alternatives. Used for both
--- 'HsCmdCase' and 'HsCmdLamCase'.
+-- | Typechecking for case command alternatives. Used for 'HsCmdCase'.
tcCmdMatches :: CmdEnv
-> TcType -- ^ Type of the scrutinee.
-- Must have a fixed RuntimeRep as per
@@ -385,6 +339,68 @@ tcCmdMatches env scrut_ty matches (stk, res_ty)
mc_body body res_ty' = do { res_ty' <- expTypeToType res_ty'
; tcCmd env body (stk, res_ty') }
+-- | Typechecking for 'HsCmdLam' and 'HsCmdLamCase'.
+tcCmdMatchLambda :: CmdEnv
+ -> HsArrowMatchContext
+ -> (Int -> FRRArrowOrigin) -- ^ Function that creates an origin
+ -- given the index of a pattern
+ -> MatchGroup GhcRn (LHsCmd GhcRn)
+ -> CmdType
+ -> TcM (HsWrapper, MatchGroup GhcTc (LHsCmd GhcTc))
+tcCmdMatchLambda env
+ ctxt
+ mk_origin
+ mg@MG { mg_alts = L l matches }
+ (cmd_stk, res_ty)
+ = do { (co, arg_tys, cmd_stk') <- matchExpectedCmdArgs n_pats cmd_stk
+
+ ; let check_arg_tys = map (unrestricted . mkCheckExpType) arg_tys
+ ; matches' <- forM matches $
+ addErrCtxt . pprMatchInCtxt . unLoc <*> tc_match check_arg_tys cmd_stk'
+
+ ; let arg_tys' = map unrestricted arg_tys
+ mg' = mg { mg_alts = L l matches'
+ , mg_ext = MatchGroupTc arg_tys' res_ty }
+
+ ; return (mkWpCastN co, mg') }
+ where
+ n_pats | isEmptyMatchGroup mg = 1 -- must be lambda-case
+ | otherwise = matchGroupArity mg
+
+ -- Check the patterns, and the GRHSs inside
+ tc_match arg_tys cmd_stk' (L mtch_loc (Match { m_pats = pats, m_grhss = grhss }))
+ = do { (pats', grhss') <- setSrcSpanA mtch_loc $
+ tcPats match_ctxt pats arg_tys $
+ tc_grhss grhss cmd_stk' (mkCheckExpType res_ty)
+
+ ; let arg_tys' = map (unrestricted . hsLPatType) pats'
+
+ ; zipWithM_
+ (\ (Scaled _ arg_ty) i ->
+ hasFixedRuntimeRep_MustBeRefl (FRRArrow $ mk_origin i) arg_ty)
+ arg_tys'
+ [1..]
+
+ ; return $ L mtch_loc (Match { m_ext = noAnn
+ , m_ctxt = match_ctxt
+ , m_pats = pats'
+ , m_grhss = grhss' }) }
+
+
+ match_ctxt = ArrowMatchCtxt ctxt
+ pg_ctxt = PatGuard match_ctxt
+
+ tc_grhss (GRHSs x grhss binds) stk_ty res_ty
+ = do { (binds', grhss') <- tcLocalBinds binds $
+ mapM (wrapLocMA (tc_grhs stk_ty res_ty)) grhss
+ ; return (GRHSs x grhss' binds') }
+
+ tc_grhs stk_ty res_ty (GRHS x guards body)
+ = do { (guards', rhs') <- tcStmtsAndThen pg_ctxt tcGuardStmt guards res_ty $
+ \ res_ty -> tcCmd env body
+ (stk_ty, checkingExpType "tc_grhs" res_ty)
+ ; return (GRHS x guards' rhs') }
+
matchExpectedCmdArgs :: Arity -> TcType -> TcM (TcCoercionN, [TcType], TcType)
matchExpectedCmdArgs 0 ty
= return (mkTcNomReflCo ty, [], ty)
diff --git a/compiler/GHC/Tc/Gen/Expr.hs b/compiler/GHC/Tc/Gen/Expr.hs
index 5cfe527c70..b5e9982f48 100644
--- a/compiler/GHC/Tc/Gen/Expr.hs
+++ b/compiler/GHC/Tc/Gen/Expr.hs
@@ -264,13 +264,13 @@ tcExpr (HsLam _ match) res_ty
match_ctxt = MC { mc_what = LambdaExpr, mc_body = tcBody }
herald = ExpectedFunTyLam match
-tcExpr e@(HsLamCase x matches) res_ty
+tcExpr e@(HsLamCase x lc_variant matches) res_ty
= do { (wrap, matches')
<- tcMatchLambda herald match_ctxt matches res_ty
- ; return (mkHsWrap wrap $ HsLamCase x matches') }
+ ; return (mkHsWrap wrap $ HsLamCase x lc_variant matches') }
where
- match_ctxt = MC { mc_what = CaseAlt, mc_body = tcBody }
- herald = ExpectedFunTyLamCase e
+ match_ctxt = MC { mc_what = LamCaseAlt lc_variant, mc_body = tcBody }
+ herald = ExpectedFunTyLamCase lc_variant e
diff --git a/compiler/GHC/Tc/Gen/Match.hs b/compiler/GHC/Tc/Gen/Match.hs
index d6f3590910..0763ad2679 100644
--- a/compiler/GHC/Tc/Gen/Match.hs
+++ b/compiler/GHC/Tc/Gen/Match.hs
@@ -31,6 +31,7 @@ module GHC.Tc.Gen.Match
, tcBody
, tcDoStmt
, tcGuardStmt
+ , checkPatCounts
)
where
@@ -105,7 +106,9 @@ tcMatchesFun fun_id matches exp_ty
-- ann-grabbing, because we don't always have annotations in
-- hand when we call tcMatchesFun...
traceTc "tcMatchesFun" (ppr fun_name $$ ppr exp_ty)
- ; checkArgs fun_name matches
+ -- We can't easily call checkPatCounts here because fun_id can be an
+ -- unfilled thunk
+ ; checkArgCounts fun_name matches
; matchExpectedFunTys herald ctxt arity exp_ty $ \ pat_tys rhs_ty ->
-- NB: exp_type may be polymorphic, but
@@ -161,8 +164,10 @@ tcMatchLambda :: ExpectedFunTyOrigin -- see Note [Herald for matchExpectedFunTys
-> ExpRhoType
-> TcM (HsWrapper, MatchGroup GhcTc (LHsExpr GhcTc))
tcMatchLambda herald match_ctxt match res_ty
- = matchExpectedFunTys herald GenSigCtxt n_pats res_ty $ \ pat_tys rhs_ty ->
- tcMatches match_ctxt pat_tys rhs_ty match
+ = do { checkPatCounts (mc_what match_ctxt) match
+ ; matchExpectedFunTys herald GenSigCtxt n_pats res_ty $ \ pat_tys rhs_ty -> do
+ -- checking argument counts since this is also used for \cases
+ tcMatches match_ctxt pat_tys rhs_ty match }
where
n_pats | isEmptyMatchGroup match = 1 -- must be lambda-case
| otherwise = matchGroupArity match
@@ -1132,23 +1137,35 @@ the variables they bind into scope, and typecheck the thing_inside.
* *
************************************************************************
-@sameNoOfArgs@ takes a @[RenamedMatch]@ and decides whether the same
+@checkArgCounts@ takes a @[RenamedMatch]@ and decides whether the same
number of args are used in each equation.
-}
-checkArgs :: AnnoBody body
- => Name -> MatchGroup GhcRn (LocatedA (body GhcRn)) -> TcM ()
-checkArgs _ (MG { mg_alts = L _ [] })
+checkArgCounts :: AnnoBody body
+ => Name -> MatchGroup GhcRn (LocatedA (body GhcRn)) -> TcM ()
+checkArgCounts = check_match_pats . (text "Equations for" <+>) . quotes . ppr
+
+-- @checkPatCounts@ takes a @[RenamedMatch]@ and decides whether the same
+-- number of patterns are used in each alternative
+checkPatCounts :: AnnoBody body
+ => HsMatchContext GhcTc -> MatchGroup GhcRn (LocatedA (body GhcRn))
+ -> TcM ()
+checkPatCounts = check_match_pats . pprMatchContextNouns
+
+check_match_pats :: AnnoBody body
+ => SDoc -> MatchGroup GhcRn (LocatedA (body GhcRn))
+ -> TcM ()
+check_match_pats _ (MG { mg_alts = L _ [] })
= return ()
-checkArgs fun (MG { mg_alts = L _ (match1:matches) })
+check_match_pats err_msg (MG { mg_alts = L _ (match1:matches) })
| null bad_matches
= return ()
| otherwise
= failWithTc $ TcRnUnknownMessage $ mkPlainError noHints $
- (vcat [ text "Equations for" <+> quotes (ppr fun) <+>
- text "have different numbers of arguments"
- , nest 2 (ppr (getLocA match1))
- , nest 2 (ppr (getLocA (head bad_matches)))])
+ (vcat [ err_msg <+>
+ text "have different numbers of arguments"
+ , nest 2 (ppr (getLocA match1))
+ , nest 2 (ppr (getLocA (head bad_matches)))])
where
n_args1 = args_in_match match1
bad_matches = [m | m <- matches, args_in_match m /= n_args1]