summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlexis King <lexi.lambda@gmail.com>2020-04-19 20:11:37 -0500
committerMarge Bot <ben+marge-bot@smart-cactus.org>2020-04-30 01:57:35 -0400
commita48cd2a045695c5f34ed7b00184a8d91c4fcac0e (patch)
tree526d1b2e19ce1b8ffcaeb73688999a255de2ef2e
parent71484b09fa3c676e99785b3d68f69d4cfb14266e (diff)
downloadhaskell-a48cd2a045695c5f34ed7b00184a8d91c4fcac0e.tar.gz
Allow LambdaCase to be used as a command in proc notation
-rw-r--r--compiler/GHC/Hs/Expr.hs12
-rw-r--r--compiler/GHC/Hs/Extension.hs1
-rw-r--r--compiler/GHC/HsToCore/Arrows.hs91
-rw-r--r--compiler/GHC/HsToCore/Coverage.hs2
-rw-r--r--compiler/GHC/Iface/Ext/Ast.hs3
-rw-r--r--compiler/GHC/Parser.y7
-rw-r--r--compiler/GHC/Parser/PostProcess.hs5
-rw-r--r--compiler/GHC/Rename/Expr.hs6
-rw-r--r--compiler/GHC/Tc/Gen/Arrow.hs27
-rw-r--r--compiler/GHC/Tc/Utils/Zonk.hs4
-rw-r--r--testsuite/tests/arrows/should_run/ArrowLambdaCase.hs18
-rw-r--r--testsuite/tests/arrows/should_run/ArrowLambdaCase.stdout3
-rw-r--r--testsuite/tests/arrows/should_run/all.T2
-rw-r--r--testsuite/tests/parser/should_compile/ParserArrowLambdaCase.hs8
-rw-r--r--testsuite/tests/parser/should_compile/all.T1
15 files changed, 143 insertions, 47 deletions
diff --git a/compiler/GHC/Hs/Expr.hs b/compiler/GHC/Hs/Expr.hs
index a03c0aa50d..97eab7d3aa 100644
--- a/compiler/GHC/Hs/Expr.hs
+++ b/compiler/GHC/Hs/Expr.hs
@@ -1330,6 +1330,14 @@ data HsCmd id
-- For details on above see note [Api annotations] in GHC.Parser.Annotation
+ | HsCmdLamCase (XCmdLamCase id)
+ (MatchGroup id (LHsCmd id)) -- bodies are HsCmd's
+ -- ^ - 'ApiAnnotation.AnnKeywordId' : 'ApiAnnotation.AnnLam',
+ -- 'ApiAnnotation.AnnCase','ApiAnnotation.AnnOpen' @'{'@,
+ -- 'ApiAnnotation.AnnClose' @'}'@
+
+ -- For details on above see note [Api annotations] in GHC.Parser.Annotation
+
| HsCmdIf (XCmdIf id)
(SyntaxExpr id) -- cond function
(LHsExpr id) -- predicate
@@ -1371,6 +1379,7 @@ type instance XCmdApp (GhcPass _) = NoExtField
type instance XCmdLam (GhcPass _) = NoExtField
type instance XCmdPar (GhcPass _) = NoExtField
type instance XCmdCase (GhcPass _) = NoExtField
+type instance XCmdLamCase (GhcPass _) = NoExtField
type instance XCmdIf (GhcPass _) = NoExtField
type instance XCmdLet (GhcPass _) = NoExtField
@@ -1460,6 +1469,9 @@ ppr_cmd (HsCmdCase _ expr matches)
= sep [ sep [text "case", nest 4 (ppr expr), ptext (sLit "of")],
nest 2 (pprMatches matches) ]
+ppr_cmd (HsCmdLamCase _ matches)
+ = sep [ text "\\case", nest 2 (pprMatches matches) ]
+
ppr_cmd (HsCmdIf _ _ e ct ce)
= sep [hsep [text "if", nest 2 (ppr e), ptext (sLit "then")],
nest 4 (ppr ct),
diff --git a/compiler/GHC/Hs/Extension.hs b/compiler/GHC/Hs/Extension.hs
index 57cd67e65a..0de2ac35a6 100644
--- a/compiler/GHC/Hs/Extension.hs
+++ b/compiler/GHC/Hs/Extension.hs
@@ -599,6 +599,7 @@ type family XCmdApp x
type family XCmdLam x
type family XCmdPar x
type family XCmdCase x
+type family XCmdLamCase x
type family XCmdIf x
type family XCmdLet x
type family XCmdDo x
diff --git a/compiler/GHC/HsToCore/Arrows.hs b/compiler/GHC/HsToCore/Arrows.hs
index 733ae86d6e..444989a18e 100644
--- a/compiler/GHC/HsToCore/Arrows.hs
+++ b/compiler/GHC/HsToCore/Arrows.hs
@@ -447,45 +447,12 @@ dsCmd ids local_vars stack_ty res_ty (HsCmdApp _ cmd arg) env_ids = do
free_vars `unionDVarSet`
(exprFreeIdsDSet core_arg `uniqDSetIntersectUniqSet` local_vars))
--- D; ys |-a cmd : stk t'
--- -----------------------------------------------
--- D; xs |-a \ p1 ... pk -> cmd : (t1,...(tk,stk)...) t'
---
--- ---> premap (\ ((xs), (p1, ... (pk,stk)...)) -> ((ys),stk)) cmd
-
dsCmd ids local_vars stack_ty res_ty
(HsCmdLam _ (MG { mg_alts
= (L _ [L _ (Match { m_pats = pats
, m_grhss = GRHSs _ [L _ (GRHS _ [] body)] _ })]) }))
- env_ids = do
- let pat_vars = mkVarSet (collectPatsBinders pats)
- let
- local_vars' = pat_vars `unionVarSet` local_vars
- (pat_tys, stack_ty') = splitTypeAt (length pats) stack_ty
- (core_body, free_vars, env_ids')
- <- dsfixCmd ids local_vars' stack_ty' res_ty body
- param_ids <- mapM newSysLocalDsNoLP pat_tys
- stack_id' <- newSysLocalDs stack_ty'
-
- -- the expression is built from the inside out, so the actions
- -- are presented in reverse order
-
- let
- -- build a new environment, plus what's left of the stack
- core_expr = buildEnvStack env_ids' stack_id'
- in_ty = envStackType env_ids stack_ty
- in_ty' = envStackType env_ids' stack_ty'
-
- fail_expr <- mkFailExpr LambdaExpr in_ty'
- -- match the patterns against the parameters
- match_code <- matchSimplys (map Var param_ids) LambdaExpr pats core_expr
- fail_expr
- -- match the parameters against the top of the old stack
- (stack_id, param_code) <- matchVarStack param_ids stack_id' match_code
- -- match the old environment and stack against the input
- select_code <- matchEnvStack env_ids stack_id param_code
- return (do_premap ids in_ty in_ty' res_ty select_code core_body,
- free_vars `uniqDSetMinusUniqSet` pat_vars)
+ env_ids
+ = dsCmdLam ids local_vars stack_ty res_ty pats body env_ids
dsCmd ids local_vars stack_ty res_ty (HsCmdPar _ cmd) env_ids
= dsLCmd ids local_vars stack_ty res_ty cmd env_ids
@@ -626,6 +593,12 @@ dsCmd ids local_vars stack_ty res_ty
return (do_premap ids in_ty sum_ty res_ty core_matches core_choices,
exprFreeIdsDSet core_body `uniqDSetIntersectUniqSet` local_vars)
+dsCmd ids local_vars stack_ty res_ty
+ (HsCmdLamCase _ mg@MG { mg_ext = MatchGroupTc [arg_ty] _ }) env_ids = do
+ arg_id <- newSysLocalDs arg_ty
+ let case_cmd = noLoc $ HsCmdCase noExtField (nlHsVar arg_id) mg
+ dsCmdLam ids local_vars stack_ty res_ty [nlVarPat arg_id] case_cmd env_ids
+
-- D; ys |-a cmd : stk --> t
-- ----------------------------------
-- D; xs |-a let binds in cmd : stk --> t
@@ -693,7 +666,7 @@ dsCmd ids local_vars stack_ty res_ty (XCmd (HsWrap wrap cmd)) env_ids = do
core_wrap <- dsHsWrapper wrap
return (core_wrap core_cmd, env_ids')
-dsCmd _ _ _ _ _ c = pprPanic "dsCmd" (ppr c)
+dsCmd _ _ _ _ c _ = pprPanic "dsCmd" (ppr c)
-- D; ys |-a c : stk --> t (ys <= xs)
-- ---------------------
@@ -753,6 +726,52 @@ trimInput build_arrow
(core_cmd, free_vars) <- build_arrow env_ids
return (core_cmd, free_vars, dVarSetElems free_vars))
+-- Desugaring for both HsCmdLam and HsCmdLamCase.
+--
+-- D; ys |-a cmd : stk t'
+-- -----------------------------------------------
+-- D; xs |-a \ p1 ... pk -> cmd : (t1,...(tk,stk)...) t'
+--
+-- ---> premap (\ ((xs), (p1, ... (pk,stk)...)) -> ((ys),stk)) cmd
+dsCmdLam :: DsCmdEnv -- arrow combinators
+ -> IdSet -- set of local vars available to this command
+ -> Type -- type of the stack (right-nested tuple)
+ -> Type -- return type of the command
+ -> [LPat GhcTc] -- argument patterns to desugar
+ -> LHsCmd GhcTc -- body to desugar
+ -> [Id] -- list of vars in the input to this command
+ -- This is typically fed back,
+ -- so don't pull on it too early
+ -> DsM (CoreExpr, -- desugared expression
+ DIdSet) -- subset of local vars that occur free
+dsCmdLam ids local_vars stack_ty res_ty pats body env_ids = do
+ let pat_vars = mkVarSet (collectPatsBinders pats)
+ let local_vars' = pat_vars `unionVarSet` local_vars
+ (pat_tys, stack_ty') = splitTypeAt (length pats) stack_ty
+ (core_body, free_vars, env_ids')
+ <- dsfixCmd ids local_vars' stack_ty' res_ty body
+ param_ids <- mapM newSysLocalDsNoLP pat_tys
+ stack_id' <- newSysLocalDs stack_ty'
+
+ -- the expression is built from the inside out, so the actions
+ -- are presented in reverse order
+
+ let -- build a new environment, plus what's left of the stack
+ core_expr = buildEnvStack env_ids' stack_id'
+ in_ty = envStackType env_ids stack_ty
+ in_ty' = envStackType env_ids' stack_ty'
+
+ fail_expr <- mkFailExpr LambdaExpr in_ty'
+ -- match the patterns against the parameters
+ match_code <- matchSimplys (map Var param_ids) LambdaExpr pats core_expr
+ fail_expr
+ -- match the parameters against the top of the old stack
+ (stack_id, param_code) <- matchVarStack param_ids stack_id' match_code
+ -- match the old environment and stack against the input
+ select_code <- matchEnvStack env_ids stack_id param_code
+ return (do_premap ids in_ty in_ty' res_ty select_code core_body,
+ free_vars `uniqDSetMinusUniqSet` pat_vars)
+
{-
Translation of command judgements of the form
diff --git a/compiler/GHC/HsToCore/Coverage.hs b/compiler/GHC/HsToCore/Coverage.hs
index 8130565837..d8b83bb25e 100644
--- a/compiler/GHC/HsToCore/Coverage.hs
+++ b/compiler/GHC/HsToCore/Coverage.hs
@@ -861,6 +861,8 @@ addTickHsCmd (HsCmdCase x e mgs) =
liftM2 (HsCmdCase x)
(addTickLHsExpr e)
(addTickCmdMatchGroup mgs)
+addTickHsCmd (HsCmdLamCase x mgs) =
+ liftM (HsCmdLamCase x) (addTickCmdMatchGroup mgs)
addTickHsCmd (HsCmdIf x cnd e1 c2 c3) =
liftM3 (HsCmdIf x cnd)
(addBinTickLHsExpr (BinBox CondBinBox) e1)
diff --git a/compiler/GHC/Iface/Ext/Ast.hs b/compiler/GHC/Iface/Ext/Ast.hs
index ffd7d26415..ddb29ce63d 100644
--- a/compiler/GHC/Iface/Ext/Ast.hs
+++ b/compiler/GHC/Iface/Ext/Ast.hs
@@ -1240,6 +1240,9 @@ instance ( a ~ GhcPass p
[ toHie expr
, toHie alts
]
+ HsCmdLamCase _ alts ->
+ [ toHie alts
+ ]
HsCmdIf _ _ a b c ->
[ toHie a
, toHie b
diff --git a/compiler/GHC/Parser.y b/compiler/GHC/Parser.y
index 34d46fd4db..bafed741be 100644
--- a/compiler/GHC/Parser.y
+++ b/compiler/GHC/Parser.y
@@ -2765,11 +2765,10 @@ aexp :: { ECP }
(mj AnnLet $1:mj AnnIn $3
:(fst $ unLoc $2)) }
| '\\' 'lcase' altslist
- {% runPV $3 >>= \ $3 ->
- fmap ecpFromExp $
- ams (sLL $1 $> $ HsLamCase noExtField
+ { ECP $ $3 >>= \ $3 ->
+ amms (mkHsLamCasePV (comb2 $1 $>)
(mkMatchGroup FromSource (snd $ unLoc $3)))
- (mj AnnLam $1:mj AnnCase $2:(fst $ unLoc $3)) }
+ (mj AnnLam $1:mj AnnCase $2:(fst $ unLoc $3)) }
| 'if' exp optSemi 'then' exp optSemi 'else' exp
{% runECP_P $2 >>= \ $2 ->
return $ ECP $
diff --git a/compiler/GHC/Parser/PostProcess.hs b/compiler/GHC/Parser/PostProcess.hs
index fdc3085e3d..94137f07ba 100644
--- a/compiler/GHC/Parser/PostProcess.hs
+++ b/compiler/GHC/Parser/PostProcess.hs
@@ -1760,6 +1760,8 @@ class b ~ (Body b) GhcPs => DisambECP b where
mkHsOpAppPV :: SrcSpan -> Located b -> Located (InfixOp b) -> Located b -> PV (Located b)
-- | Disambiguate "case ... of ..."
mkHsCasePV :: SrcSpan -> LHsExpr GhcPs -> MatchGroup GhcPs (Located b) -> PV (Located b)
+ -- | Disambiguate @\\case ...@ (lambda case)
+ mkHsLamCasePV :: SrcSpan -> MatchGroup GhcPs (Located b) -> PV (Located b)
-- | Function argument representation
type FunArg b
-- | Bring superclass constraints on FunArg into scope.
@@ -1874,6 +1876,7 @@ instance DisambECP (HsCmd GhcPs) where
let cmdArg c = L (getLoc c) $ HsCmdTop noExtField c
return $ L l $ HsCmdArrForm noExtField op Infix Nothing [cmdArg c1, cmdArg c2]
mkHsCasePV l c mg = return $ L l (HsCmdCase noExtField c mg)
+ mkHsLamCasePV l mg = return $ L l (HsCmdLamCase noExtField mg)
type FunArg (HsCmd GhcPs) = HsExpr GhcPs
superFunArg m = m
mkHsAppPV l c e = do
@@ -1930,6 +1933,7 @@ instance DisambECP (HsExpr GhcPs) where
mkHsOpAppPV l e1 op e2 = do
return $ L l $ OpApp noExtField e1 op e2
mkHsCasePV l e mg = return $ L l (HsCase noExtField e mg)
+ mkHsLamCasePV l mg = return $ L l (HsLamCase noExtField mg)
type FunArg (HsExpr GhcPs) = HsExpr GhcPs
superFunArg m = m
mkHsAppPV l e1 e2 = do
@@ -2014,6 +2018,7 @@ instance DisambECP (PatBuilder GhcPs) where
superInfixOp m = m
mkHsOpAppPV l p1 op p2 = return $ L l $ PatBuilderOpApp p1 op p2
mkHsCasePV l _ _ = addFatalError l $ text "(case ... of ...)-syntax in pattern"
+ mkHsLamCasePV l _ = addFatalError l $ text "(\\case ...)-syntax in pattern"
type FunArg (PatBuilder GhcPs) = PatBuilder GhcPs
superFunArg m = m
mkHsAppPV l p1 p2 = return $ L l (PatBuilderApp p1 p2)
diff --git a/compiler/GHC/Rename/Expr.hs b/compiler/GHC/Rename/Expr.hs
index 773b194db8..6ec473134d 100644
--- a/compiler/GHC/Rename/Expr.hs
+++ b/compiler/GHC/Rename/Expr.hs
@@ -495,6 +495,10 @@ rnCmd (HsCmdCase x expr matches)
; (new_matches, ms_fvs) <- rnMatchGroup CaseAlt rnLCmd matches
; return (HsCmdCase x new_expr new_matches, e_fvs `plusFV` ms_fvs) }
+rnCmd (HsCmdLamCase x matches)
+ = do { (new_matches, ms_fvs) <- rnMatchGroup CaseAlt rnLCmd matches
+ ; return (HsCmdLamCase x new_matches, ms_fvs) }
+
rnCmd (HsCmdIf x _ p b1 b2)
= do { (p', fvP) <- rnLExpr p
; (b1', fvB1) <- rnLCmd b1
@@ -540,6 +544,8 @@ methodNamesCmd (HsCmdLam _ match) = methodNamesMatch match
methodNamesCmd (HsCmdCase _ _ matches)
= methodNamesMatch matches `addOneFV` choiceAName
+methodNamesCmd (HsCmdLamCase _ matches)
+ = methodNamesMatch matches `addOneFV` choiceAName
--methodNamesCmd _ = emptyFVs
-- Other forms can't occur in commands, but it's not convenient
diff --git a/compiler/GHC/Tc/Gen/Arrow.hs b/compiler/GHC/Tc/Gen/Arrow.hs
index 5d26927ed4..6ac42a76d0 100644
--- a/compiler/GHC/Tc/Gen/Arrow.hs
+++ b/compiler/GHC/Tc/Gen/Arrow.hs
@@ -151,13 +151,14 @@ tc_cmd env (HsCmdLet x (L l binds) (L body_loc body)) res_ty
tc_cmd env in_cmd@(HsCmdCase x scrut matches) (stk, res_ty)
= addErrCtxt (cmdCtxt in_cmd) $ do
(scrut', scrut_ty) <- tcInferRho scrut
- matches' <- tcMatchesCase match_ctxt scrut_ty matches (mkCheckExpType res_ty)
+ matches' <- tcCmdMatches env scrut_ty matches (stk, res_ty)
return (HsCmdCase x scrut' matches')
- where
- match_ctxt = MC { mc_what = CaseAlt,
- mc_body = mc_body }
- mc_body body res_ty' = do { res_ty' <- expTypeToType res_ty'
- ; tcCmd env body (stk, res_ty') }
+
+tc_cmd env in_cmd@(HsCmdLamCase x matches) (stk, res_ty)
+ = addErrCtxt (cmdCtxt in_cmd) $ do
+ (co, [scrut_ty], stk') <- matchExpectedCmdArgs 1 stk
+ matches' <- tcCmdMatches env scrut_ty matches (stk', res_ty)
+ return (mkHsCmdWrap (mkWpCastN co) (HsCmdLamCase x matches'))
tc_cmd env (HsCmdIf x NoSyntaxExprRn pred b1 b2) res_ty -- Ordinary 'if'
= do { pred' <- tcLExpr pred (mkCheckExpType boolTy)
@@ -330,6 +331,20 @@ tc_cmd _ cmd _
= failWithTc (vcat [text "The expression", nest 2 (ppr cmd),
text "was found where an arrow command was expected"])
+-- | Typechecking for case command alternatives. Used for both
+-- 'HsCmdCase' and 'HsCmdLamCase'.
+tcCmdMatches :: CmdEnv
+ -> TcType -- ^ type of the scrutinee
+ -> MatchGroup GhcRn (LHsCmd GhcRn) -- ^ case alternatives
+ -> CmdType
+ -> TcM (MatchGroup GhcTcId (LHsCmd GhcTcId))
+tcCmdMatches env scrut_ty matches (stk, res_ty)
+ = tcMatchesCase match_ctxt scrut_ty matches (mkCheckExpType res_ty)
+ where
+ match_ctxt = MC { mc_what = CaseAlt,
+ mc_body = mc_body }
+ mc_body body res_ty' = do { res_ty' <- expTypeToType res_ty'
+ ; tcCmd env body (stk, res_ty') }
matchExpectedCmdArgs :: Arity -> TcType -> TcM (TcCoercionN, [TcType], TcType)
matchExpectedCmdArgs 0 ty
diff --git a/compiler/GHC/Tc/Utils/Zonk.hs b/compiler/GHC/Tc/Utils/Zonk.hs
index 8fbbba22b1..4372a39e9d 100644
--- a/compiler/GHC/Tc/Utils/Zonk.hs
+++ b/compiler/GHC/Tc/Utils/Zonk.hs
@@ -995,6 +995,10 @@ zonkCmd env (HsCmdCase x expr ms)
new_ms <- zonkMatchGroup env zonkLCmd ms
return (HsCmdCase x new_expr new_ms)
+zonkCmd env (HsCmdLamCase x ms)
+ = do new_ms <- zonkMatchGroup env zonkLCmd ms
+ return (HsCmdLamCase x new_ms)
+
zonkCmd env (HsCmdIf x eCond ePred cThen cElse)
= do { (env1, new_eCond) <- zonkSyntaxExpr env eCond
; new_ePred <- zonkLExpr env1 ePred
diff --git a/testsuite/tests/arrows/should_run/ArrowLambdaCase.hs b/testsuite/tests/arrows/should_run/ArrowLambdaCase.hs
new file mode 100644
index 0000000000..c678339890
--- /dev/null
+++ b/testsuite/tests/arrows/should_run/ArrowLambdaCase.hs
@@ -0,0 +1,18 @@
+{-# LANGUAGE Arrows, LambdaCase #-}
+module Main (main) where
+
+import Control.Arrow
+
+main :: IO ()
+main = do
+ putStrLn $ foo (Just 42)
+ putStrLn $ foo (Just 500)
+ putStrLn $ foo Nothing
+
+foo :: ArrowChoice p => p (Maybe Int) String
+foo = proc x ->
+ (| id (\case
+ Just x | x > 100 -> returnA -< "big " ++ show x
+ | otherwise -> returnA -< "small " ++ show x
+ Nothing -> returnA -< "none")
+ |) x
diff --git a/testsuite/tests/arrows/should_run/ArrowLambdaCase.stdout b/testsuite/tests/arrows/should_run/ArrowLambdaCase.stdout
new file mode 100644
index 0000000000..09e50cf6d7
--- /dev/null
+++ b/testsuite/tests/arrows/should_run/ArrowLambdaCase.stdout
@@ -0,0 +1,3 @@
+small 42
+big 500
+none
diff --git a/testsuite/tests/arrows/should_run/all.T b/testsuite/tests/arrows/should_run/all.T
index 2faabff765..0a1e32e34c 100644
--- a/testsuite/tests/arrows/should_run/all.T
+++ b/testsuite/tests/arrows/should_run/all.T
@@ -3,4 +3,4 @@ test('arrowrun002', when(fast(), skip), compile_and_run, [''])
test('arrowrun003', normal, compile_and_run, [''])
test('arrowrun004', when(fast(), skip), compile_and_run, [''])
test('T3822', normal, compile_and_run, [''])
-
+test('ArrowLambdaCase', normal, compile_and_run, [''])
diff --git a/testsuite/tests/parser/should_compile/ParserArrowLambdaCase.hs b/testsuite/tests/parser/should_compile/ParserArrowLambdaCase.hs
new file mode 100644
index 0000000000..b16eb7579b
--- /dev/null
+++ b/testsuite/tests/parser/should_compile/ParserArrowLambdaCase.hs
@@ -0,0 +1,8 @@
+{-# LANGUAGE Arrows, LambdaCase #-}
+module ParserArrowLambdaCase where
+
+import Control.Arrow
+
+foo :: () -> ()
+foo = proc () -> (| id (\case
+ () -> () >- returnA) |) ()
diff --git a/testsuite/tests/parser/should_compile/all.T b/testsuite/tests/parser/should_compile/all.T
index fd69d32f0f..1568a341ec 100644
--- a/testsuite/tests/parser/should_compile/all.T
+++ b/testsuite/tests/parser/should_compile/all.T
@@ -94,6 +94,7 @@ test('mc15', normal, compile, [''])
test('mc16', normal, compile, [''])
test('EmptyDecls', normal, compile, [''])
test('ParserLambdaCase', [], compile, [''])
+test('ParserArrowLambdaCase', [], compile, [''])
test('ColumnPragma', normal, compile, [''])
test('T5243', [], multimod_compile, ['T5243', ''])