summaryrefslogtreecommitdiff
path: root/compiler/GHC/Tc/Gen/Arrow.hs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/GHC/Tc/Gen/Arrow.hs')
-rw-r--r--compiler/GHC/Tc/Gen/Arrow.hs150
1 files changed, 83 insertions, 67 deletions
diff --git a/compiler/GHC/Tc/Gen/Arrow.hs b/compiler/GHC/Tc/Gen/Arrow.hs
index ad4b67ee88..d3035b5cf2 100644
--- a/compiler/GHC/Tc/Gen/Arrow.hs
+++ b/compiler/GHC/Tc/Gen/Arrow.hs
@@ -1,5 +1,6 @@
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE BlockArguments #-}
{-# OPTIONS_GHC -Wno-incomplete-record-updates #-}
@@ -45,6 +46,8 @@ import GHC.Utils.Outputable
import GHC.Utils.Panic
import GHC.Utils.Misc
+import qualified GHC.Data.Strict as Strict
+
import Control.Monad
{-
@@ -164,19 +167,21 @@ tc_cmd env in_cmd@(HsCmdCase x scrut matches) (stk, res_ty)
= addErrCtxt (cmdCtxt in_cmd) $ do
(scrut', scrut_ty) <- tcInferRho scrut
hasFixedRuntimeRep_MustBeRefl
- (FRRArrow $ ArrowCmdCase { isCmdLamCase = False })
+ (FRRArrow $ ArrowCmdCase)
scrut_ty
matches' <- tcCmdMatches env scrut_ty matches (stk, res_ty)
return (HsCmdCase x scrut' matches')
-tc_cmd env in_cmd@(HsCmdLamCase x matches) (stk, res_ty)
- = addErrCtxt (cmdCtxt in_cmd) $ do
- (co, [scrut_ty], stk') <- matchExpectedCmdArgs 1 stk
- hasFixedRuntimeRep_MustBeRefl
- (FRRArrow $ ArrowCmdCase { isCmdLamCase = True })
- scrut_ty
- matches' <- tcCmdMatches env scrut_ty matches (stk', res_ty)
- return (mkHsCmdWrap (mkWpCastN co) (HsCmdLamCase x matches'))
+tc_cmd env cmd@(HsCmdLamCase x lc_variant match) cmd_ty
+ = addErrCtxt (cmdCtxt cmd)
+ do { let match_ctxt = ArrowLamCaseAlt lc_variant
+ ; checkPatCounts (ArrowMatchCtxt match_ctxt) match
+ ; (wrap, match') <-
+ tcCmdMatchLambda env match_ctxt mk_origin match cmd_ty
+ ; return (mkHsCmdWrap wrap (HsCmdLamCase x lc_variant match')) }
+ where mk_origin = ArrowCmdLamCase . case lc_variant of
+ LamCase -> const Strict.Nothing
+ LamCases -> Strict.Just
tc_cmd env (HsCmdIf x NoSyntaxExprRn pred b1 b2) res_ty -- Ordinary 'if'
= do { pred' <- tcCheckMonoExpr pred boolTy
@@ -269,52 +274,9 @@ tc_cmd env cmd@(HsCmdApp x fun arg) (cmd_stk, res_ty)
-- ------------------------------
-- D;G |-a (\x.cmd) : (t,stk) --> res
-tc_cmd env
- (HsCmdLam x (MG { mg_alts = L l [L mtch_loc
- (match@(Match { m_pats = pats, m_grhss = grhss }))],
- mg_origin = origin }))
- (cmd_stk, res_ty)
- = addErrCtxt (pprMatchInCtxt match) $
- do { (co, arg_tys, cmd_stk') <- matchExpectedCmdArgs n_pats cmd_stk
-
- -- Check the patterns, and the GRHSs inside
- ; (pats', grhss') <- setSrcSpanA mtch_loc $
- tcPats (ArrowMatchCtxt KappaExpr)
- pats (map (unrestricted . mkCheckExpType) arg_tys) $
- tc_grhss grhss cmd_stk' (mkCheckExpType res_ty)
-
- ; let match' = L mtch_loc (Match { m_ext = noAnn
- , m_ctxt = ArrowMatchCtxt KappaExpr
- , m_pats = pats'
- , m_grhss = grhss' })
- arg_tys = map (unrestricted . hsLPatType) pats'
-
- ; zipWithM_
- (\ (Scaled _ arg_ty) i ->
- hasFixedRuntimeRep_MustBeRefl (FRRArrow $ ArrowCmdLam i) arg_ty)
- arg_tys
- [1..]
-
- ; let
- cmd' = HsCmdLam x (MG { mg_alts = L l [match']
- , mg_ext = MatchGroupTc arg_tys res_ty
- , mg_origin = origin })
- ; return (mkHsCmdWrap (mkWpCastN co) cmd') }
- where
- n_pats = length pats
- match_ctxt = ArrowMatchCtxt KappaExpr
- pg_ctxt = PatGuard match_ctxt
-
- tc_grhss (GRHSs x grhss binds) stk_ty res_ty
- = do { (binds', grhss') <- tcLocalBinds binds $
- mapM (wrapLocMA (tc_grhs stk_ty res_ty)) grhss
- ; return (GRHSs x grhss' binds') }
-
- tc_grhs stk_ty res_ty (GRHS x guards body)
- = do { (guards', rhs') <- tcStmtsAndThen pg_ctxt tcGuardStmt guards res_ty $
- \ res_ty -> tcCmd env body
- (stk_ty, checkingExpType "tc_grhs" res_ty)
- ; return (GRHS x guards' rhs') }
+tc_cmd env (HsCmdLam x match) cmd_ty
+ = do { (wrap, match') <- tcCmdMatchLambda env KappaExpr ArrowCmdLam match cmd_ty
+ ; return (mkHsCmdWrap wrap (HsCmdLam x match')) }
-------------------------------------------
-- Do notation
@@ -340,7 +302,7 @@ tc_cmd env (HsCmdDo _ (L l stmts) ) (cmd_stk, res_ty)
-- D; G |-a (| e c1 ... cn |) : stk --> t
tc_cmd env cmd@(HsCmdArrForm x expr f fixity cmd_args) (cmd_stk, res_ty)
- = addErrCtxt (cmdCtxt cmd) $
+ = addErrCtxt (cmdCtxt cmd)
do { (cmd_args', cmd_tys) <- mapAndUnzipM tc_cmd_arg cmd_args
-- We use alphaTyVar for 'w'
; let e_ty = mkInfForAllTy alphaTyVar $
@@ -361,15 +323,7 @@ tc_cmd env cmd@(HsCmdArrForm x expr f fixity cmd_args) (cmd_stk, res_ty)
; cmd' <- tcCmdTop env' names' cmd (stk_ty, res_ty)
; return (cmd', mkCmdArrTy env' (mkPairTy alphaTy stk_ty) res_ty) }
------------------------------------------------------------------
--- Base case for illegal commands
--- This is where expressions that aren't commands get rejected
-
-tc_cmd _ cmd _
- = failWithTc (TcRnArrowCommandExpected cmd)
-
--- | Typechecking for case command alternatives. Used for both
--- 'HsCmdCase' and 'HsCmdLamCase'.
+-- | Typechecking for case command alternatives. Used for 'HsCmdCase'.
tcCmdMatches :: CmdEnv
-> TcType -- ^ Type of the scrutinee.
-- Must have a fixed RuntimeRep as per
@@ -385,6 +339,68 @@ tcCmdMatches env scrut_ty matches (stk, res_ty)
mc_body body res_ty' = do { res_ty' <- expTypeToType res_ty'
; tcCmd env body (stk, res_ty') }
+-- | Typechecking for 'HsCmdLam' and 'HsCmdLamCase'.
+tcCmdMatchLambda :: CmdEnv
+ -> HsArrowMatchContext
+ -> (Int -> FRRArrowOrigin) -- ^ Function that creates an origin
+ -- given the index of a pattern
+ -> MatchGroup GhcRn (LHsCmd GhcRn)
+ -> CmdType
+ -> TcM (HsWrapper, MatchGroup GhcTc (LHsCmd GhcTc))
+tcCmdMatchLambda env
+ ctxt
+ mk_origin
+ mg@MG { mg_alts = L l matches }
+ (cmd_stk, res_ty)
+ = do { (co, arg_tys, cmd_stk') <- matchExpectedCmdArgs n_pats cmd_stk
+
+ ; let check_arg_tys = map (unrestricted . mkCheckExpType) arg_tys
+ ; matches' <- forM matches $
+ addErrCtxt . pprMatchInCtxt . unLoc <*> tc_match check_arg_tys cmd_stk'
+
+ ; let arg_tys' = map unrestricted arg_tys
+ mg' = mg { mg_alts = L l matches'
+ , mg_ext = MatchGroupTc arg_tys' res_ty }
+
+ ; return (mkWpCastN co, mg') }
+ where
+ n_pats | isEmptyMatchGroup mg = 1 -- must be lambda-case
+ | otherwise = matchGroupArity mg
+
+ -- Check the patterns, and the GRHSs inside
+ tc_match arg_tys cmd_stk' (L mtch_loc (Match { m_pats = pats, m_grhss = grhss }))
+ = do { (pats', grhss') <- setSrcSpanA mtch_loc $
+ tcPats match_ctxt pats arg_tys $
+ tc_grhss grhss cmd_stk' (mkCheckExpType res_ty)
+
+ ; let arg_tys' = map (unrestricted . hsLPatType) pats'
+
+ ; zipWithM_
+ (\ (Scaled _ arg_ty) i ->
+ hasFixedRuntimeRep_MustBeRefl (FRRArrow $ mk_origin i) arg_ty)
+ arg_tys'
+ [1..]
+
+ ; return $ L mtch_loc (Match { m_ext = noAnn
+ , m_ctxt = match_ctxt
+ , m_pats = pats'
+ , m_grhss = grhss' }) }
+
+
+ match_ctxt = ArrowMatchCtxt ctxt
+ pg_ctxt = PatGuard match_ctxt
+
+ tc_grhss (GRHSs x grhss binds) stk_ty res_ty
+ = do { (binds', grhss') <- tcLocalBinds binds $
+ mapM (wrapLocMA (tc_grhs stk_ty res_ty)) grhss
+ ; return (GRHSs x grhss' binds') }
+
+ tc_grhs stk_ty res_ty (GRHS x guards body)
+ = do { (guards', rhs') <- tcStmtsAndThen pg_ctxt tcGuardStmt guards res_ty $
+ \ res_ty -> tcCmd env body
+ (stk_ty, checkingExpType "tc_grhs" res_ty)
+ ; return (GRHS x guards' rhs') }
+
matchExpectedCmdArgs :: Arity -> TcType -> TcM (TcCoercionN, [TcType], TcType)
matchExpectedCmdArgs 0 ty
= return (mkTcNomReflCo ty, [], ty)