diff options
author | Simon Marlow <marlowsd@gmail.com> | 2015-03-13 16:39:58 +0000 |
---|---|---|
committer | Simon Marlow <marlowsd@gmail.com> | 2015-09-17 16:52:03 +0100 |
commit | 8ecf6d8f7dfee9e5b1844cd196f83f00f3b6b879 (patch) | |
tree | 9bf2b8601fefa7e1eaac11079d27660824b1466f /compiler | |
parent | 43eb1dc52a4d3cbba9617f5a26177b8251d84b6a (diff) | |
download | haskell-8ecf6d8f7dfee9e5b1844cd196f83f00f3b6b879.tar.gz |
ApplicativeDo transformation
Summary:
This is an implementation of the ApplicativeDo proposal. See the Note
[ApplicativeDo] in RnExpr for details on the current implementation,
and the wiki page https://ghc.haskell.org/trac/ghc/wiki/ApplicativeDo
for design notes.
Test Plan: validate
Reviewers: simonpj, goldfire, austin
Subscribers: thomie
Differential Revision: https://phabricator.haskell.org/D729
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/coreSyn/MkCore.hs | 47 | ||||
-rw-r--r-- | compiler/deSugar/Coverage.hs | 31 | ||||
-rw-r--r-- | compiler/deSugar/DsArrows.hs | 11 | ||||
-rw-r--r-- | compiler/deSugar/DsExpr.hs | 41 | ||||
-rw-r--r-- | compiler/deSugar/DsGRHSs.hs | 2 | ||||
-rw-r--r-- | compiler/deSugar/DsListComp.hs | 27 | ||||
-rw-r--r-- | compiler/deSugar/DsMeta.hs | 2 | ||||
-rw-r--r-- | compiler/deSugar/DsUtils.hs | 18 | ||||
-rw-r--r-- | compiler/hsSyn/HsExpr.hs | 91 | ||||
-rw-r--r-- | compiler/hsSyn/HsUtils.hs | 76 | ||||
-rw-r--r-- | compiler/main/DynFlags.hs | 2 | ||||
-rw-r--r-- | compiler/parser/RdrHsSyn.hs | 4 | ||||
-rw-r--r-- | compiler/rename/RnBinds.hs | 19 | ||||
-rw-r--r-- | compiler/rename/RnExpr.hs | 528 | ||||
-rw-r--r-- | compiler/typecheck/TcArrows.hs | 4 | ||||
-rw-r--r-- | compiler/typecheck/TcHsSyn.hs | 27 | ||||
-rw-r--r-- | compiler/typecheck/TcMatches.hs | 126 |
17 files changed, 898 insertions, 158 deletions
diff --git a/compiler/coreSyn/MkCore.hs b/compiler/coreSyn/MkCore.hs index 69410cd6cd..fb797f11ce 100644 --- a/compiler/coreSyn/MkCore.hs +++ b/compiler/coreSyn/MkCore.hs @@ -22,10 +22,6 @@ module MkCore ( -- * Constructing equality evidence boxes mkEqBox, - -- * Constructing general big tuples - -- $big_tuples - mkChunkified, - -- * Constructing small tuples mkCoreVarTup, mkCoreVarTupTy, mkCoreTup, @@ -67,6 +63,7 @@ import HscTypes import TysWiredIn import PrelNames +import HsUtils ( mkChunkified, chunkify ) import TcType ( mkSigmaTy ) import Type import Coercion @@ -82,7 +79,6 @@ import UniqSupply import BasicTypes import Util import Pair -import Constants import DynFlags import Data.Char ( ord ) @@ -319,47 +315,6 @@ mkEqBox co = ASSERT2( typeKind ty2 `eqKind` k, ppr co $$ ppr ty1 $$ ppr ty2 $$ p ************************************************************************ -} --- $big_tuples --- #big_tuples# --- --- GHCs built in tuples can only go up to 'mAX_TUPLE_SIZE' in arity, but --- we might concievably want to build such a massive tuple as part of the --- output of a desugaring stage (notably that for list comprehensions). --- --- We call tuples above this size \"big tuples\", and emulate them by --- creating and pattern matching on >nested< tuples that are expressible --- by GHC. --- --- Nesting policy: it's better to have a 2-tuple of 10-tuples (3 objects) --- than a 10-tuple of 2-tuples (11 objects), so we want the leaves of any --- construction to be big. --- --- If you just use the 'mkBigCoreTup', 'mkBigCoreVarTupTy', 'mkTupleSelector' --- and 'mkTupleCase' functions to do all your work with tuples you should be --- fine, and not have to worry about the arity limitation at all. - --- | Lifts a \"small\" constructor into a \"big\" constructor by recursive decompositon -mkChunkified :: ([a] -> a) -- ^ \"Small\" constructor function, of maximum input arity 'mAX_TUPLE_SIZE' - -> [a] -- ^ Possible \"big\" list of things to construct from - -> a -- ^ Constructed thing made possible by recursive decomposition -mkChunkified small_tuple as = mk_big_tuple (chunkify as) - where - -- Each sub-list is short enough to fit in a tuple - mk_big_tuple [as] = small_tuple as - mk_big_tuple as_s = mk_big_tuple (chunkify (map small_tuple as_s)) - -chunkify :: [a] -> [[a]] --- ^ Split a list into lists that are small enough to have a corresponding --- tuple arity. The sub-lists of the result all have length <= 'mAX_TUPLE_SIZE' --- But there may be more than 'mAX_TUPLE_SIZE' sub-lists -chunkify xs - | n_xs <= mAX_TUPLE_SIZE = [xs] - | otherwise = split xs - where - n_xs = length xs - split [] = [] - split xs = take mAX_TUPLE_SIZE xs : split (drop mAX_TUPLE_SIZE xs) - {- Creating tuples and their types for Core expressions diff --git a/compiler/deSugar/Coverage.hs b/compiler/deSugar/Coverage.hs index f5a9290e48..4ee205ec4c 100644 --- a/compiler/deSugar/Coverage.hs +++ b/compiler/deSugar/Coverage.hs @@ -3,7 +3,7 @@ (c) University of Glasgow, 2007 -} -{-# LANGUAGE NondecreasingIndentation #-} +{-# LANGUAGE CPP, NondecreasingIndentation #-} module Coverage (addTicksToBinds, hpcInitCode) where @@ -660,9 +660,10 @@ addTickLStmts' isGuard lstmts res ; return (lstmts', a) } addTickStmt :: (Maybe (Bool -> BoxLabel)) -> Stmt Id (LHsExpr Id) -> TM (Stmt Id (LHsExpr Id)) -addTickStmt _isGuard (LastStmt e ret) = do - liftM2 LastStmt +addTickStmt _isGuard (LastStmt e noret ret) = do + liftM3 LastStmt (addTickLHsExpr e) + (pure noret) (addTickSyntaxExpr hpcSrcSpan ret) addTickStmt _isGuard (BindStmt pat e bind fail) = do liftM4 BindStmt @@ -684,6 +685,9 @@ addTickStmt isGuard (ParStmt pairs mzipExpr bindExpr) = do (mapM (addTickStmtAndBinders isGuard) pairs) (addTickSyntaxExpr hpcSrcSpan mzipExpr) (addTickSyntaxExpr hpcSrcSpan bindExpr) +addTickStmt isGuard (ApplicativeStmt args mb_join body_ty) = do + args' <- mapM (addTickApplicativeArg isGuard) args + return (ApplicativeStmt args' mb_join body_ty) addTickStmt isGuard stmt@(TransStmt { trS_stmts = stmts , trS_by = by, trS_using = using @@ -710,6 +714,20 @@ addTick :: Maybe (Bool -> BoxLabel) -> LHsExpr Id -> TM (LHsExpr Id) addTick isGuard e | Just fn <- isGuard = addBinTickLHsExpr fn e | otherwise = addTickLHsExprRHS e +addTickApplicativeArg + :: Maybe (Bool -> BoxLabel) -> (SyntaxExpr Id, ApplicativeArg Id Id) + -> TM (SyntaxExpr Id, ApplicativeArg Id Id) +addTickApplicativeArg isGuard (op, arg) = + liftM2 (,) (addTickSyntaxExpr hpcSrcSpan op) (addTickArg arg) + where + addTickArg (ApplicativeArgOne pat expr) = + ApplicativeArgOne <$> addTickLPat pat <*> addTickLHsExpr expr + addTickArg (ApplicativeArgMany stmts ret pat) = + ApplicativeArgMany + <$> addTickLStmts isGuard stmts + <*> addTickSyntaxExpr hpcSrcSpan ret + <*> addTickLPat pat + addTickStmtAndBinders :: Maybe (Bool -> BoxLabel) -> ParStmtBlock Id Id -> TM (ParStmtBlock Id Id) addTickStmtAndBinders isGuard (ParStmtBlock stmts ids returnExpr) = @@ -872,9 +890,10 @@ addTickCmdStmt (BindStmt pat c bind fail) = do (addTickLHsCmd c) (return bind) (return fail) -addTickCmdStmt (LastStmt c ret) = do - liftM2 LastStmt +addTickCmdStmt (LastStmt c noret ret) = do + liftM3 LastStmt (addTickLHsCmd c) + (pure noret) (addTickSyntaxExpr hpcSrcSpan ret) addTickCmdStmt (BodyStmt c bind' guard' ty) = do liftM4 BodyStmt @@ -892,6 +911,8 @@ addTickCmdStmt stmt@(RecStmt {}) ; bind' <- addTickSyntaxExpr hpcSrcSpan (recS_bind_fn stmt) ; return (stmt { recS_stmts = stmts', recS_ret_fn = ret' , recS_mfix_fn = mfix', recS_bind_fn = bind' }) } +addTickCmdStmt ApplicativeStmt{} = + panic "ToDo: addTickCmdStmt ApplicativeLastStmt" -- Others should never happen in a command context. addTickCmdStmt stmt = pprPanic "addTickHsCmd" (ppr stmt) diff --git a/compiler/deSugar/DsArrows.hs b/compiler/deSugar/DsArrows.hs index 44795b9dfa..1657a5f49d 100644 --- a/compiler/deSugar/DsArrows.hs +++ b/compiler/deSugar/DsArrows.hs @@ -18,6 +18,7 @@ import DsMonad import HsSyn hiding (collectPatBinders, collectPatsBinders, collectLStmtsBinders, collectLStmtBinders, collectStmtBinders ) import TcHsSyn +import qualified HsUtils -- NB: The desugarer, which straddles the source and Core worlds, sometimes -- needs to see source types (newtypes etc), and sometimes not @@ -694,7 +695,7 @@ dsCmdDo _ _ _ [] _ = panic "dsCmdDo" -- -- ---> premap (\ (xs) -> ((xs), ())) c -dsCmdDo ids local_vars res_ty [L _ (LastStmt body _)] env_ids = do +dsCmdDo ids local_vars res_ty [L _ (LastStmt body _ _)] env_ids = do (core_body, env_ids') <- dsLCmd ids local_vars unitTy res_ty body env_ids let env_ty = mkBigCoreVarTupTy env_ids env_var <- newSysLocalDs env_ty @@ -1167,11 +1168,5 @@ collectLStmtBinders :: LStmt Id body -> [Id] collectLStmtBinders = collectStmtBinders . unLoc collectStmtBinders :: Stmt Id body -> [Id] -collectStmtBinders (BindStmt pat _ _ _) = collectPatBinders pat -collectStmtBinders (LetStmt binds) = collectLocalBinders binds -collectStmtBinders (BodyStmt {}) = [] -collectStmtBinders (LastStmt {}) = [] -collectStmtBinders (ParStmt xs _ _) = collectLStmtsBinders - $ [ s | ParStmtBlock ss _ _ <- xs, s <- ss] -collectStmtBinders (TransStmt { trS_stmts = stmts }) = collectLStmtsBinders stmts collectStmtBinders (RecStmt { recS_later_ids = later_ids }) = later_ids +collectStmtBinders stmt = HsUtils.collectStmtBinders stmt diff --git a/compiler/deSugar/DsExpr.hs b/compiler/deSugar/DsExpr.hs index 433a13ee37..d3a8156f88 100644 --- a/compiler/deSugar/DsExpr.hs +++ b/compiler/deSugar/DsExpr.hs @@ -33,6 +33,7 @@ import TcType import Coercion ( Role(..) ) import TcEvidence import TcRnMonad +import TcHsSyn import Type import CoreSyn import CoreUtils @@ -819,7 +820,7 @@ dsDo stmts goL [] = panic "dsDo" goL (L loc stmt:lstmts) = putSrcSpanDs loc (go loc stmt lstmts) - go _ (LastStmt body _) stmts + go _ (LastStmt body _ _) stmts = ASSERT( null stmts ) dsLExpr body -- The 'return' op isn't used for 'do' expressions @@ -846,13 +847,45 @@ dsDo stmts ; match_code <- handle_failure pat match fail_op ; return (mkApps bind_op' [rhs', Lam var match_code]) } + go _ (ApplicativeStmt args mb_join body_ty) stmts + = do { + let + (pats, rhss) = unzip (map (do_arg . snd) args) + + do_arg (ApplicativeArgOne pat expr) = + (pat, dsLExpr expr) + do_arg (ApplicativeArgMany stmts ret pat) = + (pat, dsDo (stmts ++ [noLoc $ mkLastStmt (noLoc ret)])) + + arg_tys = map hsLPatType pats + + ; rhss' <- sequence rhss + ; ops' <- mapM dsExpr (map fst args) + + ; let body' = noLoc $ HsDo DoExpr stmts body_ty + + ; let fun = L noSrcSpan $ HsLam $ + MG { mg_alts = [mkSimpleMatch pats body'] + , mg_arg_tys = arg_tys + , mg_res_ty = body_ty + , mg_origin = Generated } + + ; fun' <- dsLExpr fun + ; let mk_ap_call l (op,r) = mkApps op [l,r] + expr = foldl mk_ap_call fun' (zip ops' rhss') + ; case mb_join of + Nothing -> return expr + Just join_op -> + do { join_op' <- dsExpr join_op + ; return (App join_op' expr) } } + go loc (RecStmt { recS_stmts = rec_stmts, recS_later_ids = later_ids , recS_rec_ids = rec_ids, recS_ret_fn = return_op , recS_mfix_fn = mfix_op, recS_bind_fn = bind_op , recS_rec_rets = rec_rets, recS_ret_ty = body_ty }) stmts = goL (new_bind_stmt : stmts) -- rec_ids can be empty; eg rec { print 'x' } where - new_bind_stmt = L loc $ BindStmt (mkBigLHsPatTup later_pats) + new_bind_stmt = L loc $ BindStmt (mkBigLHsPatTupId later_pats) mfix_app bind_op noSyntaxExpr -- Tuple cannot fail @@ -865,9 +898,9 @@ dsDo stmts mfix_arg = noLoc $ HsLam (MG { mg_alts = [mkSimpleMatch [mfix_pat] body] , mg_arg_tys = [tup_ty], mg_res_ty = body_ty , mg_origin = Generated }) - mfix_pat = noLoc $ LazyPat $ mkBigLHsPatTup rec_tup_pats + mfix_pat = noLoc $ LazyPat $ mkBigLHsPatTupId rec_tup_pats body = noLoc $ HsDo DoExpr (rec_stmts ++ [ret_stmt]) body_ty - ret_app = nlHsApp (noLoc return_op) (mkBigLHsTup rets) + ret_app = nlHsApp (noLoc return_op) (mkBigLHsTupId rets) ret_stmt = noLoc $ mkLastStmt ret_app -- This LastStmt will be desugared with dsDo, -- which ignores the return_op in the LastStmt, diff --git a/compiler/deSugar/DsGRHSs.hs b/compiler/deSugar/DsGRHSs.hs index 1346f8af5e..6e4056a7c3 100644 --- a/compiler/deSugar/DsGRHSs.hs +++ b/compiler/deSugar/DsGRHSs.hs @@ -123,6 +123,8 @@ matchGuards (LastStmt {} : _) _ _ _ = panic "matchGuards LastStmt" matchGuards (ParStmt {} : _) _ _ _ = panic "matchGuards ParStmt" matchGuards (TransStmt {} : _) _ _ _ = panic "matchGuards TransStmt" matchGuards (RecStmt {} : _) _ _ _ = panic "matchGuards RecStmt" +matchGuards (ApplicativeStmt {} : _) _ _ _ = + panic "matchGuards ApplicativeLastStmt" isTrueLHsExpr :: LHsExpr Id -> Maybe (CoreExpr -> DsM CoreExpr) diff --git a/compiler/deSugar/DsListComp.hs b/compiler/deSugar/DsListComp.hs index 79d6f47612..985b12e19f 100644 --- a/compiler/deSugar/DsListComp.hs +++ b/compiler/deSugar/DsListComp.hs @@ -81,7 +81,7 @@ dsListComp lquals res_ty = do -- and the type of the elements that it outputs (tuples of binders) dsInnerListComp :: (ParStmtBlock Id Id) -> DsM (CoreExpr, Type) dsInnerListComp (ParStmtBlock stmts bndrs _) - = do { expr <- dsListComp (stmts ++ [noLoc $ mkLastStmt (mkBigLHsVarTup bndrs)]) + = do { expr <- dsListComp (stmts ++ [noLoc $ mkLastStmt (mkBigLHsVarTupId bndrs)]) (mkListTy bndrs_tuple_type) ; return (expr, bndrs_tuple_type) } where @@ -133,7 +133,7 @@ dsTransStmt (TransStmt { trS_form = form, trS_stmts = stmts, trS_bndrs = binderM -- Build a pattern that ensures the consumer binds into the NEW binders, -- which hold lists rather than single values - let pat = mkBigLHsVarPatTup to_bndrs + let pat = mkBigLHsVarPatTupId to_bndrs return (bound_unzipped_inner_list_expr, pat) dsTransStmt _ = panic "dsTransStmt: Not given a TransStmt" @@ -208,7 +208,7 @@ deListComp :: [ExprStmt Id] -> CoreExpr -> DsM CoreExpr deListComp [] _ = panic "deListComp" -deListComp (LastStmt body _ : quals) list +deListComp (LastStmt body _ _ : quals) list = -- Figure 7.4, SLPJ, p 135, rule C above ASSERT( null quals ) do { core_body <- dsLExpr body @@ -246,11 +246,14 @@ deListComp (ParStmt stmtss_w_bndrs _ _ : quals) list bndrs_s = [bs | ParStmtBlock _ bs _ <- stmtss_w_bndrs] -- pat is the pattern ((x1,..,xn), (y1,..,ym)) in the example above - pat = mkBigLHsPatTup pats - pats = map mkBigLHsVarPatTup bndrs_s + pat = mkBigLHsPatTupId pats + pats = map mkBigLHsVarPatTupId bndrs_s deListComp (RecStmt {} : _) _ = panic "deListComp RecStmt" +deListComp (ApplicativeStmt {} : _) _ = + panic "deListComp ApplicativeStmt" + deBindComp :: OutPat Id -> CoreExpr -> [ExprStmt Id] @@ -312,7 +315,7 @@ dfListComp :: Id -> Id -- 'c' and 'n' dfListComp _ _ [] = panic "dfListComp" -dfListComp c_id n_id (LastStmt body _ : quals) +dfListComp c_id n_id (LastStmt body _ _ : quals) = ASSERT( null quals ) do { core_body <- dsLExpr body ; return (mkApps (Var c_id) [core_body, Var n_id]) } @@ -342,6 +345,8 @@ dfListComp c_id n_id (BindStmt pat list1 _ _ : quals) = do dfListComp _ _ (ParStmt {} : _) = panic "dfListComp ParStmt" dfListComp _ _ (RecStmt {} : _) = panic "dfListComp RecStmt" +dfListComp _ _ (ApplicativeStmt {} : _) = + panic "dfListComp ApplicativeStmt" dfBindComp :: Id -> Id -- 'c' and 'n' -> (LPat Id, CoreExpr) @@ -510,7 +515,7 @@ dePArrComp [] _ _ = panic "dePArrComp" -- -- <<[:e' | :]>> pa ea = mapP (\pa -> e') ea -- -dePArrComp (LastStmt e' _ : quals) pa cea +dePArrComp (LastStmt e' _ _ : quals) pa cea = ASSERT( null quals ) do { mapP <- dsDPHBuiltin mapPVar ; let ty = parrElemType cea @@ -589,6 +594,8 @@ dePArrComp (ParStmt {} : _) _ _ = panic "DsListComp.dePArrComp: malformed comprehension AST: ParStmt" dePArrComp (TransStmt {} : _) _ _ = panic "DsListComp.dePArrComp: TransStmt" dePArrComp (RecStmt {} : _) _ _ = panic "DsListComp.dePArrComp: RecStmt" +dePArrComp (ApplicativeStmt {} : _) _ _ = + panic "DsListComp.dePArrComp: ApplicativeStmt" -- <<[:e' | qs | qss:]>> pa ea = -- <<[:e' | qss:]>> (pa, (x_1, ..., x_n)) @@ -666,7 +673,7 @@ dsMcStmts (L loc stmt : lstmts) = putSrcSpanDs loc (dsMcStmt stmt lstmts) --------------- dsMcStmt :: ExprStmt Id -> [ExprLStmt Id] -> DsM CoreExpr -dsMcStmt (LastStmt body ret_op) stmts +dsMcStmt (LastStmt body _ ret_op) stmts = ASSERT( null stmts ) do { body' <- dsLExpr body ; ret_op' <- dsExpr ret_op @@ -761,7 +768,7 @@ dsMcStmt (ParStmt blocks mzip_op bind_op) stmts_rest ; mzip_op' <- dsExpr mzip_op ; let -- The pattern variables - pats = [ mkBigLHsVarPatTup bs | ParStmtBlock _ bs _ <- blocks] + pats = [ mkBigLHsVarPatTupId bs | ParStmtBlock _ bs _ <- blocks] -- Pattern with tuples of variables -- [v1,v2,v3] => (v1, (v2, v3)) pat = foldr1 (\p1 p2 -> mkLHsPatTup [p1, p2]) pats @@ -834,7 +841,7 @@ dsInnerMonadComp :: [ExprLStmt Id] -> HsExpr Id -- The monomorphic "return" operator -> DsM CoreExpr dsInnerMonadComp stmts bndrs ret_op - = dsMcStmts (stmts ++ [noLoc (LastStmt (mkBigLHsVarTup bndrs) ret_op)]) + = dsMcStmts (stmts ++ [noLoc (LastStmt (mkBigLHsVarTupId bndrs) False ret_op)]) -- The `unzip` function for `GroupStmt` in a monad comprehensions -- diff --git a/compiler/deSugar/DsMeta.hs b/compiler/deSugar/DsMeta.hs index 867f900a78..a762810419 100644 --- a/compiler/deSugar/DsMeta.hs +++ b/compiler/deSugar/DsMeta.hs @@ -1279,7 +1279,7 @@ repSts (ParStmt stmt_blocks _ _ : ss) = do { (ss1, zs) <- repSts (map unLoc stmts) ; zs1 <- coreList stmtQTyConName zs ; return (ss1, zs1) } -repSts [LastStmt e _] +repSts [LastStmt e _ _] = do { e2 <- repLE e ; z <- repNoBindSt e2 ; return ([], [z]) } diff --git a/compiler/deSugar/DsUtils.hs b/compiler/deSugar/DsUtils.hs index 819944312b..bce5186f08 100644 --- a/compiler/deSugar/DsUtils.hs +++ b/compiler/deSugar/DsUtils.hs @@ -30,7 +30,7 @@ module DsUtils ( -- LHs tuples mkLHsVarPatTup, mkLHsPatTup, mkVanillaTuplePat, - mkBigLHsVarTup, mkBigLHsTup, mkBigLHsVarPatTup, mkBigLHsPatTup, + mkBigLHsVarTupId, mkBigLHsTupId, mkBigLHsVarPatTupId, mkBigLHsPatTupId, mkSelectorBinds, @@ -717,18 +717,18 @@ mkVanillaTuplePat :: [OutPat Id] -> Boxity -> Pat Id mkVanillaTuplePat pats box = TuplePat pats box (map hsLPatType pats) -- The Big equivalents for the source tuple expressions -mkBigLHsVarTup :: [Id] -> LHsExpr Id -mkBigLHsVarTup ids = mkBigLHsTup (map nlHsVar ids) +mkBigLHsVarTupId :: [Id] -> LHsExpr Id +mkBigLHsVarTupId ids = mkBigLHsTupId (map nlHsVar ids) -mkBigLHsTup :: [LHsExpr Id] -> LHsExpr Id -mkBigLHsTup = mkChunkified mkLHsTupleExpr +mkBigLHsTupId :: [LHsExpr Id] -> LHsExpr Id +mkBigLHsTupId = mkChunkified mkLHsTupleExpr -- The Big equivalents for the source tuple patterns -mkBigLHsVarPatTup :: [Id] -> LPat Id -mkBigLHsVarPatTup bs = mkBigLHsPatTup (map nlVarPat bs) +mkBigLHsVarPatTupId :: [Id] -> LPat Id +mkBigLHsVarPatTupId bs = mkBigLHsPatTupId (map nlVarPat bs) -mkBigLHsPatTup :: [LPat Id] -> LPat Id -mkBigLHsPatTup = mkChunkified mkLHsPatTup +mkBigLHsPatTupId :: [LPat Id] -> LPat Id +mkBigLHsPatTupId = mkChunkified mkLHsPatTup {- ************************************************************************ diff --git a/compiler/hsSyn/HsExpr.hs b/compiler/hsSyn/HsExpr.hs index 8b8b9df255..a3c1f6ce5b 100644 --- a/compiler/hsSyn/HsExpr.hs +++ b/compiler/hsSyn/HsExpr.hs @@ -39,6 +39,7 @@ import Type -- libraries: import Data.Data hiding (Fixity) +import Data.Maybe (isNothing) {- ************************************************************************ @@ -1266,12 +1267,15 @@ data StmtLR idL idR body -- body should always be (LHs**** idR) = LastStmt -- Always the last Stmt in ListComp, MonadComp, PArrComp, -- and (after the renamer) DoExpr, MDoExpr -- Not used for GhciStmtCtxt, PatGuard, which scope over other stuff - body - (SyntaxExpr idR) -- The return operator, used only for MonadComp - -- For ListComp, PArrComp, we use the baked-in 'return' - -- For DoExpr, MDoExpr, we don't apply a 'return' at all - -- See Note [Monad Comprehensions] - -- | - 'ApiAnnotation.AnnKeywordId' : 'ApiAnnotation.AnnLarrow' + body + Bool -- True <=> return was stripped by ApplicativeDo + (SyntaxExpr idR) -- The return operator, used only for + -- MonadComp For ListComp, PArrComp, we + -- use the baked-in 'return' For DoExpr, + -- MDoExpr, we don't apply a 'return' at + -- all See Note [Monad Comprehensions] | + -- - 'ApiAnnotation.AnnKeywordId' : + -- 'ApiAnnotation.AnnLarrow' -- For details on above see note [Api annotations] in ApiAnnotation | BindStmt (LPat idL) @@ -1281,6 +1285,20 @@ data StmtLR idL idR body -- body should always be (LHs**** idR) -- The fail operator is noSyntaxExpr -- if the pattern match can't fail + -- | 'ApplicativeStmt' represents an applicative expression built with + -- <$> and <*>. It is generated by the renamer, and is desugared into the + -- appropriate applicative expression by the desugarer, but it is intended + -- to be invisible in error messages. + -- + -- For full details, see Note [ApplicativeDo] in RnExpr + -- + | ApplicativeStmt + [ ( SyntaxExpr idR + , ApplicativeArg idL idR) ] + -- [(<$>, e1), (<*>, e2), ..., (<*>, en)] + (Maybe (SyntaxExpr idR)) -- 'join', if necessary + (PostTc idR Type) -- Type of the body + | BodyStmt body -- See Note [BodyStmt] (SyntaxExpr idR) -- The (>>) operator (SyntaxExpr idR) -- The `guard` operator; used only in MonadComp @@ -1375,6 +1393,17 @@ data ParStmtBlock idL idR deriving( Typeable ) deriving instance (DataId idL, DataId idR) => Data (ParStmtBlock idL idR) +data ApplicativeArg idL idR + = ApplicativeArgOne -- pat <- expr (pat must be irrefutable) + (LPat idL) + (LHsExpr idL) + | ApplicativeArgMany -- do { stmts; return vars } + [ExprLStmt idL] -- stmts + (SyntaxExpr idL) -- return (v1,..,vn), or just (v1,..,vn) + (LPat idL) -- (v1,...,vn) + deriving( Typeable ) +deriving instance (DataId idL, DataId idR) => Data (ApplicativeArg idL idR) + {- Note [The type of bind in Stmts] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1520,9 +1549,12 @@ instance (OutputableBndr idL, OutputableBndr idR, Outputable body) => Outputable (StmtLR idL idR body) where ppr stmt = pprStmt stmt -pprStmt :: (OutputableBndr idL, OutputableBndr idR, Outputable body) +pprStmt :: forall idL idR body . (OutputableBndr idL, OutputableBndr idR, Outputable body) => (StmtLR idL idR body) -> SDoc -pprStmt (LastStmt expr _) = ifPprDebug (ptext (sLit "[last]")) <+> ppr expr +pprStmt (LastStmt expr ret_stripped _) + = ifPprDebug (ptext (sLit "[last]")) <+> + (if ret_stripped then ptext (sLit "return") else empty) <+> + ppr expr pprStmt (BindStmt pat expr _ _) = hsep [ppr pat, larrow, ppr expr] pprStmt (LetStmt binds) = hsep [ptext (sLit "let"), pprBinds binds] pprStmt (BodyStmt expr _ _ _) = ppr expr @@ -1538,6 +1570,45 @@ pprStmt (RecStmt { recS_stmts = segment, recS_rec_ids = rec_ids , ifPprDebug (vcat [ ptext (sLit "rec_ids=") <> ppr rec_ids , ptext (sLit "later_ids=") <> ppr later_ids])] +pprStmt (ApplicativeStmt args mb_join _) + = getPprStyle $ \style -> + if userStyle style + then pp_for_user + else pp_debug + where + -- make all the Applicative stuff invisible in error messages by + -- flattening the whole ApplicativeStmt nest back to a sequence + -- of statements. + pp_for_user = vcat $ punctuate semi $ concatMap flattenArg args + + -- ppr directly rather than transforming here, becuase we need to + -- inject a "return" which is hard when we're polymorphic in the id + -- type. + flattenStmt :: ExprLStmt idL -> [SDoc] + flattenStmt (L _ (ApplicativeStmt args _ _)) = concatMap flattenArg args + flattenStmt stmt = [ppr stmt] + + flattenArg (_, ApplicativeArgOne pat expr) = + [ppr (BindStmt pat expr noSyntaxExpr noSyntaxExpr :: ExprStmt idL)] + flattenArg (_, ApplicativeArgMany stmts _ _) = + concatMap flattenStmt stmts + + pp_debug = + let + ap_expr = sep (punctuate (ptext (sLit " |")) (map pp_arg args)) + in + if isNothing mb_join + then ap_expr + else ptext (sLit "join") <+> parens ap_expr + + pp_arg (_, ApplicativeArgOne pat expr) = + ppr (BindStmt pat expr noSyntaxExpr noSyntaxExpr :: ExprStmt idL) + pp_arg (_, ApplicativeArgMany stmts return pat) = + ppr pat <+> + ptext (sLit "<-") <+> + ppr (HsDo DoExpr (stmts ++ [noLoc (LastStmt (noLoc return) False noSyntaxExpr)]) + (error "pprStmt")) + pprTransformStmt :: OutputableBndr id => [id] -> LHsExpr id -> Maybe (LHsExpr id) -> SDoc pprTransformStmt bndrs using by = sep [ ptext (sLit "then") <+> ifPprDebug (braces (ppr bndrs)) @@ -1577,7 +1648,7 @@ pprComp :: (OutputableBndr id, Outputable body) => [LStmt id body] -> SDoc pprComp quals -- Prints: body | qual1, ..., qualn | not (null quals) - , L _ (LastStmt body _) <- last quals + , L _ (LastStmt body _ _) <- last quals = hang (ppr body <+> char '|') 2 (pprQuals (dropTail 1 quals)) | otherwise = pprPanic "pprComp" (pprQuals quals) @@ -1962,7 +2033,7 @@ pprMatchInCtxt ctxt match = hang (ptext (sLit "In") <+> pprMatchContext ctxt <> pprStmtInCtxt :: (OutputableBndr idL, OutputableBndr idR, Outputable body) => HsStmtContext idL -> StmtLR idL idR body -> SDoc -pprStmtInCtxt ctxt (LastStmt e _) +pprStmtInCtxt ctxt (LastStmt e _ _) | isListCompExpr ctxt -- For [ e | .. ], do not mutter about "stmts" = hang (ptext (sLit "In the expression:")) 2 (ppr e) diff --git a/compiler/hsSyn/HsUtils.hs b/compiler/hsSyn/HsUtils.hs index 2242d10f76..b45156288f 100644 --- a/compiler/hsSyn/HsUtils.hs +++ b/compiler/hsSyn/HsUtils.hs @@ -32,8 +32,13 @@ module HsUtils( mkLHsTupleExpr, mkLHsVarTuple, missingTupArg, toHsType, toHsKind, + -- * Constructing general big tuples + -- $big_tuples + mkChunkified, chunkify, + -- Bindings - mkFunBind, mkVarBind, mkHsVarBind, mk_easy_FunBind, mkTopFunBind, mkPatSynBind, + mkFunBind, mkVarBind, mkHsVarBind, mk_easy_FunBind, mkTopFunBind, + mkPatSynBind, -- Literals mkHsIntegral, mkHsFractional, mkHsIsString, mkHsString, @@ -42,6 +47,7 @@ module HsUtils( mkNPat, mkNPlusKPat, nlVarPat, nlLitPat, nlConVarPat, nlConPat, nlConPatName, nlInfixConPat, nlNullaryConPat, nlWildConPat, nlWildPat, nlWildPatName, nlWildPatId, nlTuplePat, mkParPat, + mkBigLHsVarTup, mkBigLHsTup, mkBigLHsVarPatTup, mkBigLHsPatTup, -- Types mkHsAppTy, userHsTyVarBndrs, @@ -99,6 +105,7 @@ import FastString import Util import Bag import Outputable +import Constants import Data.Either import Data.Function @@ -254,7 +261,7 @@ mkTransformByStmt ss u b = emptyTransStmt { trS_form = ThenForm, trS_stmts = s mkGroupUsingStmt ss u = emptyTransStmt { trS_form = GroupForm, trS_stmts = ss, trS_using = u } mkGroupByUsingStmt ss b u = emptyTransStmt { trS_form = GroupForm, trS_stmts = ss, trS_using = u, trS_by = Just b } -mkLastStmt body = LastStmt body noSyntaxExpr +mkLastStmt body = LastStmt body False noSyntaxExpr mkBodyStmt body = BodyStmt body noSyntaxExpr noSyntaxExpr placeHolderType mkBindStmt pat body = BindStmt pat body noSyntaxExpr noSyntaxExpr @@ -425,6 +432,66 @@ nlTuplePat pats box = noLoc (TuplePat pats box []) missingTupArg :: HsTupArg RdrName missingTupArg = Missing placeHolderType +mkLHsPatTup :: [LPat id] -> LPat id +mkLHsPatTup [] = noLoc $ TuplePat [] Boxed [] +mkLHsPatTup [lpat] = lpat +mkLHsPatTup lpats = L (getLoc (head lpats)) $ TuplePat lpats Boxed [] + +-- The Big equivalents for the source tuple expressions +mkBigLHsVarTup :: [id] -> LHsExpr id +mkBigLHsVarTup ids = mkBigLHsTup (map nlHsVar ids) + +mkBigLHsTup :: [LHsExpr id] -> LHsExpr id +mkBigLHsTup = mkChunkified mkLHsTupleExpr + +-- The Big equivalents for the source tuple patterns +mkBigLHsVarPatTup :: [id] -> LPat id +mkBigLHsVarPatTup bs = mkBigLHsPatTup (map nlVarPat bs) + +mkBigLHsPatTup :: [LPat id] -> LPat id +mkBigLHsPatTup = mkChunkified mkLHsPatTup + +-- $big_tuples +-- #big_tuples# +-- +-- GHCs built in tuples can only go up to 'mAX_TUPLE_SIZE' in arity, but +-- we might concievably want to build such a massive tuple as part of the +-- output of a desugaring stage (notably that for list comprehensions). +-- +-- We call tuples above this size \"big tuples\", and emulate them by +-- creating and pattern matching on >nested< tuples that are expressible +-- by GHC. +-- +-- Nesting policy: it's better to have a 2-tuple of 10-tuples (3 objects) +-- than a 10-tuple of 2-tuples (11 objects), so we want the leaves of any +-- construction to be big. +-- +-- If you just use the 'mkBigCoreTup', 'mkBigCoreVarTupTy', 'mkTupleSelector' +-- and 'mkTupleCase' functions to do all your work with tuples you should be +-- fine, and not have to worry about the arity limitation at all. + +-- | Lifts a \"small\" constructor into a \"big\" constructor by recursive decompositon +mkChunkified :: ([a] -> a) -- ^ \"Small\" constructor function, of maximum input arity 'mAX_TUPLE_SIZE' + -> [a] -- ^ Possible \"big\" list of things to construct from + -> a -- ^ Constructed thing made possible by recursive decomposition +mkChunkified small_tuple as = mk_big_tuple (chunkify as) + where + -- Each sub-list is short enough to fit in a tuple + mk_big_tuple [as] = small_tuple as + mk_big_tuple as_s = mk_big_tuple (chunkify (map small_tuple as_s)) + +chunkify :: [a] -> [[a]] +-- ^ Split a list into lists that are small enough to have a corresponding +-- tuple arity. The sub-lists of the result all have length <= 'mAX_TUPLE_SIZE' +-- But there may be more than 'mAX_TUPLE_SIZE' sub-lists +chunkify xs + | n_xs <= mAX_TUPLE_SIZE = [xs] + | otherwise = split xs + where + n_xs = length xs + split [] = [] + split xs = take mAX_TUPLE_SIZE xs : split (drop mAX_TUPLE_SIZE xs) + {- ************************************************************************ * * @@ -670,6 +737,7 @@ collectStmtBinders (ParStmt xs _ _) = collectLStmtsBinders $ [s | ParStmtBlock ss _ _ <- xs, s <- ss] collectStmtBinders (TransStmt { trS_stmts = stmts }) = collectLStmtsBinders stmts collectStmtBinders (RecStmt { recS_stmts = ss }) = collectLStmtsBinders ss +collectStmtBinders ApplicativeStmt{} = [] ----------------- Patterns -------------------------- @@ -877,7 +945,11 @@ lStmtsImplicits = hs_lstmts hs_lstmts :: [LStmtLR Name idR (Located (body idR))] -> NameSet hs_lstmts = foldr (\stmt rest -> unionNameSet (hs_stmt (unLoc stmt)) rest) emptyNameSet + hs_stmt :: StmtLR Name idR (Located (body idR)) -> NameSet hs_stmt (BindStmt pat _ _ _) = lPatImplicits pat + hs_stmt (ApplicativeStmt args _ _) = unionNameSets (map do_arg args) + where do_arg (_, ApplicativeArgOne pat _) = lPatImplicits pat + do_arg (_, ApplicativeArgMany stmts _ _) = hs_lstmts stmts hs_stmt (LetStmt binds) = hs_local_binds binds hs_stmt (BodyStmt {}) = emptyNameSet hs_stmt (LastStmt {}) = emptyNameSet diff --git a/compiler/main/DynFlags.hs b/compiler/main/DynFlags.hs index 01effa8bd9..802f264e36 100644 --- a/compiler/main/DynFlags.hs +++ b/compiler/main/DynFlags.hs @@ -602,6 +602,7 @@ data ExtensionFlag | Opt_PolyKinds -- Kind polymorphism | Opt_DataKinds -- Datatype promotion | Opt_InstanceSigs + | Opt_ApplicativeDo | Opt_StandaloneDeriving | Opt_DeriveDataTypeable @@ -3158,6 +3159,7 @@ xFlags = [ flagSpec' "IncoherentInstances" Opt_IncoherentInstances setIncoherentInsts, flagSpec "InstanceSigs" Opt_InstanceSigs, + flagSpec "ApplicativeDo" Opt_ApplicativeDo, flagSpec "InterruptibleFFI" Opt_InterruptibleFFI, flagSpec "JavaScriptFFI" Opt_JavaScriptFFI, flagSpec "KindSignatures" Opt_KindSignatures, diff --git a/compiler/parser/RdrHsSyn.hs b/compiler/parser/RdrHsSyn.hs index edc8a63bad..beb3b3bffa 100644 --- a/compiler/parser/RdrHsSyn.hs +++ b/compiler/parser/RdrHsSyn.hs @@ -1119,8 +1119,8 @@ checkCmdLStmt :: ExprLStmt RdrName -> P (CmdLStmt RdrName) checkCmdLStmt = locMap checkCmdStmt checkCmdStmt :: SrcSpan -> ExprStmt RdrName -> P (CmdStmt RdrName) -checkCmdStmt _ (LastStmt e r) = - checkCommand e >>= (\c -> return $ LastStmt c r) +checkCmdStmt _ (LastStmt e s r) = + checkCommand e >>= (\c -> return $ LastStmt c s r) checkCmdStmt _ (BindStmt pat e b f) = checkCommand e >>= (\c -> return $ BindStmt pat c b f) checkCmdStmt _ (BodyStmt e t g ty) = diff --git a/compiler/rename/RnBinds.hs b/compiler/rename/RnBinds.hs index 62a2472586..10c5b7bb03 100644 --- a/compiler/rename/RnBinds.hs +++ b/compiler/rename/RnBinds.hs @@ -197,13 +197,13 @@ rnTopBindsBoot b = pprPanic "rnTopBindsBoot" (ppr b) -} rnLocalBindsAndThen :: HsLocalBinds RdrName - -> (HsLocalBinds Name -> RnM (result, FreeVars)) + -> (HsLocalBinds Name -> FreeVars -> RnM (result, FreeVars)) -> RnM (result, FreeVars) -- This version (a) assumes that the binding vars are *not* already in scope -- (b) removes the binders from the free vars of the thing inside -- The parser doesn't produce ThenBinds -rnLocalBindsAndThen EmptyLocalBinds thing_inside - = thing_inside EmptyLocalBinds +rnLocalBindsAndThen EmptyLocalBinds thing_inside = + thing_inside EmptyLocalBinds emptyNameSet rnLocalBindsAndThen (HsValBinds val_binds) thing_inside = rnLocalValBindsAndThen val_binds $ \ val_binds' -> @@ -211,7 +211,7 @@ rnLocalBindsAndThen (HsValBinds val_binds) thing_inside rnLocalBindsAndThen (HsIPBinds binds) thing_inside = do (binds',fv_binds) <- rnIPBinds binds - (thing, fvs_thing) <- thing_inside (HsIPBinds binds') + (thing, fvs_thing) <- thing_inside (HsIPBinds binds') fv_binds return (thing, fvs_thing `plusFV` fv_binds) rnIPBinds :: HsIPBinds RdrName -> RnM (HsIPBinds Name, FreeVars) @@ -322,9 +322,10 @@ rnLocalValBindsRHS bound_names binds -- -- here there are no local fixity decls passed in; -- the local fixity decls come from the ValBinds sigs -rnLocalValBindsAndThen :: HsValBinds RdrName - -> (HsValBinds Name -> RnM (result, FreeVars)) - -> RnM (result, FreeVars) +rnLocalValBindsAndThen + :: HsValBinds RdrName + -> (HsValBinds Name -> FreeVars -> RnM (result, FreeVars)) + -> RnM (result, FreeVars) rnLocalValBindsAndThen binds@(ValBindsIn _ sigs) thing_inside = do { -- (A) Create the local fixity environment new_fixities <- makeMiniFixityEnv [L loc sig @@ -339,7 +340,7 @@ rnLocalValBindsAndThen binds@(ValBindsIn _ sigs) thing_inside { -- (C) Do the RHS and thing inside (binds', dus) <- rnLocalValBindsRHS (mkNameSet bound_names) new_lhs - ; (result, result_fvs) <- thing_inside binds' + ; (result, result_fvs) <- thing_inside binds' (allUses dus) -- Report unused bindings based on the (accurate) -- findUses. E.g. @@ -1091,7 +1092,7 @@ rnGRHSs :: HsMatchContext Name -> GRHSs RdrName (Located (body RdrName)) -> RnM (GRHSs Name (Located (body Name)), FreeVars) rnGRHSs ctxt rnBody (GRHSs grhss binds) - = rnLocalBindsAndThen binds $ \ binds' -> do + = rnLocalBindsAndThen binds $ \ binds' _ -> do (grhss', fvGRHSs) <- mapFvRn (rnGRHS ctxt rnBody) grhss return (GRHSs grhss' binds', fvGRHSs) diff --git a/compiler/rename/RnExpr.hs b/compiler/rename/RnExpr.hs index da0d38754d..aaac8f10de 100644 --- a/compiler/rename/RnExpr.hs +++ b/compiler/rename/RnExpr.hs @@ -10,7 +10,7 @@ general, all of these functions return a renamed thing, and a set of free variables. -} -{-# LANGUAGE CPP, ScopedTypeVariables #-} +{-# LANGUAGE CPP, ScopedTypeVariables, RecordWildCards #-} module RnExpr ( rnLExpr, rnExpr, rnStmts @@ -28,9 +28,9 @@ import RnSplice ( rnBracket, rnSpliceExpr, checkThLocalName ) import RnTypes import RnPat import DynFlags -import BasicTypes ( FixityDirection(..), Fixity(..), minPrecedence ) import PrelNames +import BasicTypes import Name import NameSet import RdrName @@ -212,12 +212,15 @@ rnExpr (HsCase expr matches) ; return (HsCase new_expr new_matches, e_fvs `plusFV` ms_fvs) } rnExpr (HsLet binds expr) - = rnLocalBindsAndThen binds $ \binds' -> do + = rnLocalBindsAndThen binds $ \binds' _ -> do { (expr',fvExpr) <- rnLExpr expr ; return (HsLet binds' expr', fvExpr) } rnExpr (HsDo do_or_lc stmts _) - = do { ((stmts', _), fvs) <- rnStmts do_or_lc rnLExpr stmts (\ _ -> return ((), emptyFVs)) + = do { ((stmts', _), fvs) <- + rnStmtsWithPostProcessing do_or_lc rnLExpr + postProcessStmtsForApplicativeDo stmts + (\ _ -> return ((), emptyFVs)) ; return ( HsDo do_or_lc stmts' placeHolderType, fvs ) } rnExpr (ExplicitList _ _ exps) @@ -512,12 +515,13 @@ rnCmd (HsCmdIf _ p b1 b2) ; return (HsCmdIf mb_ite p' b1' b2', plusFVs [fvITE, fvP, fvB1, fvB2]) } rnCmd (HsCmdLet binds cmd) - = rnLocalBindsAndThen binds $ \ binds' -> do + = rnLocalBindsAndThen binds $ \ binds' _ -> do { (cmd',fvExpr) <- rnLCmd cmd ; return (HsCmdLet binds' cmd', fvExpr) } rnCmd (HsCmdDo stmts _) - = do { ((stmts', _), fvs) <- rnStmts ArrowExpr rnLCmd stmts (\ _ -> return ((), emptyFVs)) + = do { ((stmts', _), fvs) <- + rnStmts ArrowExpr rnLCmd stmts (\ _ -> return ((), emptyFVs)) ; return ( HsCmdDo stmts' placeHolderType, fvs ) } rnCmd cmd@(HsCmdCast {}) = pprPanic "rnCmd" (ppr cmd) @@ -583,15 +587,17 @@ methodNamesLStmt :: Located (StmtLR Name Name (LHsCmd Name)) -> FreeVars methodNamesLStmt = methodNamesStmt . unLoc methodNamesStmt :: StmtLR Name Name (LHsCmd Name) -> FreeVars -methodNamesStmt (LastStmt cmd _) = methodNamesLCmd cmd +methodNamesStmt (LastStmt cmd _ _) = methodNamesLCmd cmd methodNamesStmt (BodyStmt cmd _ _ _) = methodNamesLCmd cmd methodNamesStmt (BindStmt _ cmd _ _) = methodNamesLCmd cmd -methodNamesStmt (RecStmt { recS_stmts = stmts }) = methodNamesStmts stmts `addOneFV` loopAName +methodNamesStmt (RecStmt { recS_stmts = stmts }) = + methodNamesStmts stmts `addOneFV` loopAName methodNamesStmt (LetStmt {}) = emptyFVs methodNamesStmt (ParStmt {}) = emptyFVs methodNamesStmt (TransStmt {}) = emptyFVs - -- ParStmt and TransStmt can't occur in commands, but it's not convenient to error - -- here so we just do what's convenient +methodNamesStmt ApplicativeStmt{} = emptyFVs + -- ParStmt and TransStmt can't occur in commands, but it's not + -- convenient to error here so we just do what's convenient {- ************************************************************************ @@ -631,20 +637,86 @@ rnArithSeq (FromThenTo expr1 expr2 expr3) ************************************************************************ -} -rnStmts :: Outputable (body RdrName) => HsStmtContext Name +-- | Rename some Stmts +rnStmts :: Outputable (body RdrName) + => HsStmtContext Name -> (Located (body RdrName) -> RnM (Located (body Name), FreeVars)) + -- ^ How to rename the body of each statement (e.g. rnLExpr) -> [LStmt RdrName (Located (body RdrName))] + -- ^ Statements -> ([Name] -> RnM (thing, FreeVars)) + -- ^ if these statements scope over something, this renames it + -- and returns the result. -> RnM (([LStmt Name (Located (body Name))], thing), FreeVars) +rnStmts ctxt rnBody = rnStmtsWithPostProcessing ctxt rnBody noPostProcessStmts + +-- | like 'rnStmts' but applies a post-processing step to the renamed Stmts +rnStmtsWithPostProcessing + :: Outputable (body RdrName) + => HsStmtContext Name + -> (Located (body RdrName) -> RnM (Located (body Name), FreeVars)) + -- ^ How to rename the body of each statement (e.g. rnLExpr) + -> (HsStmtContext Name + -> [(LStmt Name (Located (body Name)), FreeVars)] + -> RnM ([LStmt Name (Located (body Name))], FreeVars)) + -- ^ postprocess the statements + -> [LStmt RdrName (Located (body RdrName))] + -- ^ Statements + -> ([Name] -> RnM (thing, FreeVars)) + -- ^ if these statements scope over something, this renames it + -- and returns the result. + -> RnM (([LStmt Name (Located (body Name))], thing), FreeVars) +rnStmtsWithPostProcessing ctxt rnBody ppStmts stmts thing_inside + = do { ((stmts', thing), fvs) <- + rnStmtsWithFreeVars ctxt rnBody stmts thing_inside + ; (pp_stmts, fvs') <- ppStmts ctxt stmts' + ; return ((pp_stmts, thing), fvs `plusFV` fvs') + } + +-- | maybe rearrange statements according to the ApplicativeDo transformation +postProcessStmtsForApplicativeDo + :: HsStmtContext Name + -> [(LStmt Name (LHsExpr Name), FreeVars)] + -> RnM ([LStmt Name (LHsExpr Name)], FreeVars) +postProcessStmtsForApplicativeDo ctxt stmts + = do { + -- rearrange the statements using ApplicativeStmt if + -- -XApplicativeDo is on. Also strip out the FreeVars attached + -- to each Stmt body. + ado_is_on <- xoptM Opt_ApplicativeDo + ; let is_do_expr | DoExpr <- ctxt = True + | otherwise = False + ; if ado_is_on && is_do_expr + then rearrangeForApplicativeDo ctxt stmts + else noPostProcessStmts ctxt stmts } + +-- | strip the FreeVars annotations from statements +noPostProcessStmts + :: HsStmtContext Name + -> [(LStmt Name (Located (body Name)), FreeVars)] + -> RnM ([LStmt Name (Located (body Name))], FreeVars) +noPostProcessStmts _ stmts = return (map fst stmts, emptyNameSet) + + +rnStmtsWithFreeVars :: Outputable (body RdrName) + => HsStmtContext Name + -> (Located (body RdrName) -> RnM (Located (body Name), FreeVars)) + -> [LStmt RdrName (Located (body RdrName))] + -> ([Name] -> RnM (thing, FreeVars)) + -> RnM ( ([(LStmt Name (Located (body Name)), FreeVars)], thing) + , FreeVars) +-- Each Stmt body is annotated with its FreeVars, so that +-- we can rearrange statements for ApplicativeDo. +-- -- Variables bound by the Stmts, and mentioned in thing_inside, -- do not appear in the result FreeVars -rnStmts ctxt _ [] thing_inside +rnStmtsWithFreeVars ctxt _ [] thing_inside = do { checkEmptyStmts ctxt ; (thing, fvs) <- thing_inside [] ; return (([], thing), fvs) } -rnStmts MDoExpr rnBody stmts thing_inside -- Deal with mdo +rnStmtsWithFreeVars MDoExpr rnBody stmts thing_inside -- Deal with mdo = -- Behave like do { rec { ...all but last... }; last } do { ((stmts1, (stmts2, thing)), fvs) <- rnStmt MDoExpr rnBody (noLoc $ mkRecStmt all_but_last) $ \ _ -> @@ -654,7 +726,7 @@ rnStmts MDoExpr rnBody stmts thing_inside -- Deal with mdo where Just (all_but_last, last_stmt) = snocView stmts -rnStmts ctxt rnBody (lstmt@(L loc _) : lstmts) thing_inside +rnStmtsWithFreeVars ctxt rnBody (lstmt@(L loc _) : lstmts) thing_inside | null lstmts = setSrcSpan loc $ do { lstmt' <- checkLastStmt ctxt lstmt @@ -665,24 +737,29 @@ rnStmts ctxt rnBody (lstmt@(L loc _) : lstmts) thing_inside <- setSrcSpan loc $ do { checkStmt ctxt lstmt ; rnStmt ctxt rnBody lstmt $ \ bndrs1 -> - rnStmts ctxt rnBody lstmts $ \ bndrs2 -> + rnStmtsWithFreeVars ctxt rnBody lstmts $ \ bndrs2 -> thing_inside (bndrs1 ++ bndrs2) } ; return (((stmts1 ++ stmts2), thing), fvs) } ---------------------- -rnStmt :: Outputable (body RdrName) => HsStmtContext Name +rnStmt :: Outputable (body RdrName) + => HsStmtContext Name -> (Located (body RdrName) -> RnM (Located (body Name), FreeVars)) + -- ^ How to rename the body of the statement -> LStmt RdrName (Located (body RdrName)) + -- ^ The statement -> ([Name] -> RnM (thing, FreeVars)) - -> RnM (([LStmt Name (Located (body Name))], thing), FreeVars) + -- ^ Rename the stuff that this statement scopes over + -> RnM ( ([(LStmt Name (Located (body Name)), FreeVars)], thing) + , FreeVars) -- Variables bound by the Stmt, and mentioned in thing_inside, -- do not appear in the result FreeVars -rnStmt ctxt rnBody (L loc (LastStmt body _)) thing_inside +rnStmt ctxt rnBody (L loc (LastStmt body noret _)) thing_inside = do { (body', fv_expr) <- rnBody body ; (ret_op, fvs1) <- lookupStmtName ctxt returnMName ; (thing, fvs3) <- thing_inside [] - ; return (([L loc (LastStmt body' ret_op)], thing), + ; return (([(L loc (LastStmt body' noret ret_op), fv_expr)], thing), fv_expr `plusFV` fvs1 `plusFV` fvs3) } rnStmt ctxt rnBody (L loc (BodyStmt body _ _ _)) thing_inside @@ -695,7 +772,8 @@ rnStmt ctxt rnBody (L loc (BodyStmt body _ _ _)) thing_inside -- Also for sub-stmts of same eg [ e | x<-xs, gd | blah ] -- Here "gd" is a guard ; (thing, fvs3) <- thing_inside [] - ; return (([L loc (BodyStmt body' then_op guard_op placeHolderType)], thing), + ; return (([(L loc (BodyStmt body' + then_op guard_op placeHolderType), fv_expr)], thing), fv_expr `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) } rnStmt ctxt rnBody (L loc (BindStmt pat body _ _)) thing_inside @@ -705,15 +783,16 @@ rnStmt ctxt rnBody (L loc (BindStmt pat body _ _)) thing_inside ; (fail_op, fvs2) <- lookupStmtName ctxt failMName ; rnPat (StmtCtxt ctxt) pat $ \ pat' -> do { (thing, fvs3) <- thing_inside (collectPatBinders pat') - ; return (([L loc (BindStmt pat' body' bind_op fail_op)], thing), + ; return (( [(L loc (BindStmt pat' body' bind_op fail_op), fv_expr)] + , thing), fv_expr `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) }} -- fv_expr shouldn't really be filtered by the rnPatsAndThen -- but it does not matter because the names are unique rnStmt _ _ (L loc (LetStmt binds)) thing_inside - = do { rnLocalBindsAndThen binds $ \binds' -> do + = do { rnLocalBindsAndThen binds $ \binds' bind_fvs -> do { (thing, fvs) <- thing_inside (collectLocalBinders binds') - ; return (([L loc (LetStmt binds')], thing), fvs) } } + ; return (([(L loc (LetStmt binds'), bind_fvs)], thing), fvs) } } rnStmt ctxt rnBody (L loc (RecStmt { recS_stmts = rec_stmts })) thing_inside = do { (return_op, fvs1) <- lookupStmtName ctxt returnMName @@ -737,14 +816,17 @@ rnStmt ctxt rnBody (L loc (RecStmt { recS_stmts = rec_stmts })) thing_inside emptyNameSet segs ; (thing, fvs_later) <- thing_inside bndrs ; let (rec_stmts', fvs) = segmentRecStmts loc ctxt empty_rec_stmt segs fvs_later - ; return ((rec_stmts', thing), fvs `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) } } + -- We aren't going to try to group RecStmts with + -- ApplicativeDo, so attaching empty FVs is fine. + ; return ( ((zip rec_stmts' (repeat emptyNameSet)), thing) + , fvs `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) } } rnStmt ctxt _ (L loc (ParStmt segs _ _)) thing_inside = do { (mzip_op, fvs1) <- lookupStmtName ctxt mzipName ; (bind_op, fvs2) <- lookupStmtName ctxt bindMName ; (return_op, fvs3) <- lookupStmtName ctxt returnMName ; ((segs', thing), fvs4) <- rnParallelStmts (ParStmtCtxt ctxt) return_op segs thing_inside - ; return ( ([L loc (ParStmt segs' mzip_op bind_op)], thing) + ; return ( ([(L loc (ParStmt segs' mzip_op bind_op), fvs4)], thing) , fvs1 `plusFV` fvs2 `plusFV` fvs3 `plusFV` fvs4) } rnStmt ctxt _ (L loc (TransStmt { trS_stmts = stmts, trS_by = by, trS_form = form @@ -777,10 +859,13 @@ rnStmt ctxt _ (L loc (TransStmt { trS_stmts = stmts, trS_by = by, trS_form = for -- See Note [TransStmt binder map] in HsExpr ; traceRn (text "rnStmt: implicitly rebound these used binders:" <+> ppr bndr_map) - ; return (([L loc (TransStmt { trS_stmts = stmts', trS_bndrs = bndr_map + ; return (([(L loc (TransStmt { trS_stmts = stmts', trS_bndrs = bndr_map , trS_by = by', trS_using = using', trS_form = form , trS_ret = return_op, trS_bind = bind_op - , trS_fmap = fmap_op })], thing), all_fvs) } + , trS_fmap = fmap_op }), fvs2)], thing), all_fvs) } + +rnStmt _ _ (L _ ApplicativeStmt{}) _ = + panic "rnStmt: ApplicativeStmt" rnParallelStmts :: forall thing. HsStmtContext Name -> SyntaxExpr Name @@ -844,8 +929,9 @@ Renaming parallel statements is painful. Given, say [ a+c | a <- as, bs <- bss | c <- bs, a <- ds ] Note that - (a) In order to report "Defined by not used" about 'bs', we must rename - each group of Stmts with a thing_inside whose FreeVars include at least {a,c} + (a) In order to report "Defined but not used" about 'bs', we must + rename each group of Stmts with a thing_inside whose FreeVars + include at least {a,c} (b) We want to report that 'a' is illegally bound in both branches @@ -874,11 +960,13 @@ type Segment stmts = (Defs, -- wrapper that does both the left- and right-hand sides rnRecStmtsAndThen :: Outputable (body RdrName) => - (Located (body RdrName) -> RnM (Located (body Name), FreeVars)) + (Located (body RdrName) + -> RnM (Located (body Name), FreeVars)) -> [LStmt RdrName (Located (body RdrName))] -- assumes that the FreeVars returned includes -- the FreeVars of the Segments - -> ([Segment (LStmt Name (Located (body Name)))] -> RnM (a, FreeVars)) + -> ([Segment (LStmt Name (Located (body Name)))] + -> RnM (a, FreeVars)) -> RnM (a, FreeVars) rnRecStmtsAndThen rnBody s cont = do { -- (A) Make the mini fixity env for all of the stmts @@ -922,8 +1010,8 @@ rn_rec_stmt_lhs :: Outputable body => MiniFixityEnv rn_rec_stmt_lhs _ (L loc (BodyStmt body a b c)) = return [(L loc (BodyStmt body a b c), emptyFVs)] -rn_rec_stmt_lhs _ (L loc (LastStmt body a)) - = return [(L loc (LastStmt body a), emptyFVs)] +rn_rec_stmt_lhs _ (L loc (LastStmt body noret a)) + = return [(L loc (LastStmt body noret a), emptyFVs)] rn_rec_stmt_lhs fix_env (L loc (BindStmt pat body a b)) = do @@ -952,6 +1040,9 @@ rn_rec_stmt_lhs _ stmt@(L _ (ParStmt {})) -- Syntactically illegal in mdo rn_rec_stmt_lhs _ stmt@(L _ (TransStmt {})) -- Syntactically illegal in mdo = pprPanic "rn_rec_stmt" (ppr stmt) +rn_rec_stmt_lhs _ stmt@(L _ (ApplicativeStmt {})) -- Shouldn't appear yet + = pprPanic "rn_rec_stmt" (ppr stmt) + rn_rec_stmt_lhs _ (L _ (LetStmt EmptyLocalBinds)) = panic "rn_rec_stmt LetStmt EmptyLocalBinds" @@ -978,11 +1069,11 @@ rn_rec_stmt :: (Outputable (body RdrName)) => -- Rename a Stmt that is inside a RecStmt (or mdo) -- Assumes all binders are already in scope -- Turns each stmt into a singleton Stmt -rn_rec_stmt rnBody _ (L loc (LastStmt body _), _) +rn_rec_stmt rnBody _ (L loc (LastStmt body noret _), _) = do { (body', fv_expr) <- rnBody body ; (ret_op, fvs1) <- lookupSyntaxName returnMName ; return [(emptyNameSet, fv_expr `plusFV` fvs1, emptyNameSet, - L loc (LastStmt body' ret_op))] } + L loc (LastStmt body' noret ret_op))] } rn_rec_stmt rnBody _ (L loc (BodyStmt body _ _ _), _) = do { (body', fvs) <- rnBody body @@ -1005,8 +1096,9 @@ rn_rec_stmt _ _ (L _ (LetStmt binds@(HsIPBinds _)), _) rn_rec_stmt _ all_bndrs (L loc (LetStmt (HsValBinds binds')), _) = do { (binds', du_binds) <- rnLocalValBindsRHS (mkNameSet all_bndrs) binds' -- fixities and unused are handled above in rnRecStmtsAndThen - ; return [(duDefs du_binds, allUses du_binds, - emptyNameSet, L loc (LetStmt (HsValBinds binds')))] } + ; let fvs = allUses du_binds + ; return [(duDefs du_binds, fvs, emptyNameSet, + L loc (LetStmt (HsValBinds binds')))] } -- no RecStmt case because they get flattened above when doing the LHSes rn_rec_stmt _ _ stmt@(L _ (RecStmt {}), _) @@ -1021,6 +1113,9 @@ rn_rec_stmt _ _ stmt@(L _ (TransStmt {}), _) -- Syntactically illegal in mdo rn_rec_stmt _ _ (L _ (LetStmt EmptyLocalBinds), _) = panic "rn_rec_stmt: LetStmt EmptyLocalBinds" +rn_rec_stmt _ _ stmt@(L _ (ApplicativeStmt {}), _) + = pprPanic "rn_rec_stmt: ApplicativeStmt" (ppr stmt) + rn_rec_stmts :: Outputable (body RdrName) => (Located (body RdrName) -> RnM (Located (body Name), FreeVars)) -> [Name] @@ -1042,7 +1137,7 @@ segmentRecStmts loc ctxt empty_rec_stmt segs fvs_later | MDoExpr <- ctxt = segsToStmts empty_rec_stmt grouped_segs fvs_later - -- Step 4: Turn the segments into Stmts + -- Step 4: Turn the segments into Stmts -- Use RecStmt when and only when there are fwd refs -- Also gather up the uses from the end towards the -- start, so we can tell the RecStmt which things are @@ -1186,6 +1281,360 @@ segsToStmts empty_rec_stmt ((defs, uses, fwds, ss) : segs) fvs_later {- ************************************************************************ * * +ApplicativeDo +* * +************************************************************************ + +Note [ApplicativeDo] + += Example = + +For a sequence of statements + + do + x <- A + y <- B x + z <- C + return (f x y z) + +We want to transform this to + + (\(x,y) z -> f x y z) <$> (do x <- A; y <- B x; return (x,y)) <*> C + +It would be easy to notice that "y <- B x" and "z <- C" are +independent and do something like this: + + do + x <- A + (y,z) <- (,) <$> B x <*> C + return (f x y z) + +But this isn't enough! A and C were also independent, and this +transformation loses the ability to do A and C in parallel. + +The algorithm works by first splitting the sequence of statements into +independent "segments", and a separate "tail" (the final statement). In +our example above, the segements would be + + [ x <- A + , y <- B x ] + + [ z <- C ] + +and the tail is: + + return (f x y z) + +Then we take these segments and make an Applicative expression from them: + + (\(x,y) z -> return (f x y z)) + <$> do { x <- A; y <- B x; return (x,y) } + <*> C + +Finally, we recursively apply the transformation to each segment, to +discover any nested parallelism. + += Syntax & spec = + + expr ::= ... | do {stmt_1; ..; stmt_n} expr | ... + + stmt ::= pat <- expr + | (arg_1 | ... | arg_n) -- applicative composition, n>=1 + | ... -- other kinds of statement (e.g. let) + + arg ::= pat <- expr + | {stmt_1; ..; stmt_n} {var_1..var_n} + +(note that in the actual implementation,the expr in a do statement is +represented by a LastStmt as the final stmt, this is just a +representational issue and may change later.) + +== Transformation to introduce applicative stmts == + +ado {} tail = tail +ado {pat <- expr} {return expr'} = (mkArg(pat <- expr)); return expr' +ado {one} tail = one : tail +ado stmts tail + | n == 1 = ado before (ado after tail) + where (before,after) = split(stmts_1) + | n > 1 = (mkArg(stmts_1) | ... | mkArg(stmts_n)); tail + where + {stmts_1 .. stmts_n} = segments(stmts) + +segments(stmts) = + -- divide stmts into segments with no interdependencies + +mkArg({pat <- expr}) = (pat <- expr) +mkArg({stmt_1; ...; stmt_n}) = + {stmt_1; ...; stmt_n} {vars(stmt_1) u .. u vars(stmt_n)} + +split({stmt_1; ..; stmt_n) = + ({stmt_1; ..; stmt_i}, {stmt_i+1; ..; stmt_n}) + -- 1 <= i <= n + -- i is a good place to insert a bind + +== Desugaring for do == + +dsDo {} expr = expr + +dsDo {pat <- rhs; stmts} expr = + rhs >>= \pat -> dsDo stmts expr + +dsDo {(arg_1 | ... | arg_n)} (return expr) = + (\argpat (arg_1) .. argpat(arg_n) -> expr) + <$> argexpr(arg_1) + <*> ... + <*> argexpr(arg_n) + +dsDo {(arg_1 | ... | arg_n); stmts} expr = + join (\argpat (arg_1) .. argpat(arg_n) -> dsDo stmts expr) + <$> argexpr(arg_1) + <*> ... + <*> argexpr(arg_n) + +-} + +-- | rearrange a list of statements using ApplicativeDoStmt. See +-- Note [ApplicativeDo]. +rearrangeForApplicativeDo + :: HsStmtContext Name + -> [(LStmt Name (LHsExpr Name), FreeVars)] + -> RnM ([LStmt Name (LHsExpr Name)], FreeVars) + +rearrangeForApplicativeDo _ [] = return ([], emptyNameSet) +rearrangeForApplicativeDo ctxt stmts0 = do + (stmts', fvs) <- ado ctxt stmts [last] last_fvs + return (stmts', fvs) + where (stmts,(last,last_fvs)) = findLast stmts0 + findLast [] = error "findLast" + findLast [last] = ([],last) + findLast (x:xs) = (x:rest,last) where (rest,last) = findLast xs + +-- | The ApplicativeDo transformation. +ado + :: HsStmtContext Name + -> [(LStmt Name (LHsExpr Name), FreeVars)] -- ^ input statements + -> [LStmt Name (LHsExpr Name)] -- ^ the "tail" + -> FreeVars -- ^ free variables of the tail + -> RnM ( [LStmt Name (LHsExpr Name)] -- ( output statements, + , FreeVars ) -- , things we needed + -- e.g. <$>, <*>, join ) + +ado _ctxt [] tail _ = return (tail, emptyNameSet) + +-- If we have a single bind, and we can do it without a join, transform +-- to an ApplicativeStmt. This corresponds to the rule +-- dsBlock [pat <- rhs] (return expr) = expr <$> rhs +-- In the spec, but we do it here rather than in the desugarer, +-- because we need the typechecker to typecheck the <$> form rather than +-- the bind form, which would give rise to a Monad constraint. +ado ctxt [(L _ (BindStmt pat rhs _ _),_)] tail _ + | isIrrefutableHsPat pat, (False,tail') <- needJoin tail + = mkApplicativeStmt ctxt [ApplicativeArgOne pat rhs] False tail' + +ado _ctxt [(one,_)] tail _ = return (one:tail, emptyNameSet) + +ado ctxt stmts tail tail_fvs = + case segments stmts of -- chop into segments + [] -> panic "ado" + [one] -> + -- one indivisible segment, divide it by adding a bind + adoSegment ctxt one tail tail_fvs + segs -> + -- multiple segments; recursively transform the segments, and + -- combine into an ApplicativeStmt + do { pairs <- mapM (adoSegmentArg ctxt tail_fvs) segs + ; let (stmts', fvss) = unzip pairs + ; let (need_join, tail') = needJoin tail + ; (stmts, fvs) <- mkApplicativeStmt ctxt stmts' need_join tail' + ; return (stmts, unionNameSets (fvs:fvss)) } + +-- | Deal with an indivisible segment. We pick a place to insert a +-- bind (it will actually be a join), and recursively transform the +-- two halves. +adoSegment + :: HsStmtContext Name + -> [(LStmt Name (LHsExpr Name), FreeVars)] + -> [LStmt Name (LHsExpr Name)] + -> FreeVars + -> RnM ( [LStmt Name (LHsExpr Name)], FreeVars ) +adoSegment ctxt stmts tail tail_fvs + = do { -- choose somewhere to put a bind + let (before,after) = splitSegment stmts + ; (stmts1, fvs1) <- ado ctxt after tail tail_fvs + ; let tail1_fvs = unionNameSets (tail_fvs : map snd after) + ; (stmts2, fvs2) <- ado ctxt before stmts1 tail1_fvs + ; return (stmts2, fvs1 `plusFV` fvs2) } + +-- | Given a segment, make an ApplicativeArg. Here we recursively +-- call adoSegment on the segment's contents to extract any further +-- available parallelism. +adoSegmentArg + :: HsStmtContext Name + -> FreeVars + -> [(LStmt Name (LHsExpr Name), FreeVars)] + -> RnM (ApplicativeArg Name Name, FreeVars) +adoSegmentArg _ _ [(L _ (BindStmt pat exp _ _),_)] = + return (ApplicativeArgOne pat exp, emptyFVs) +adoSegmentArg ctxt tail_fvs stmts = + do { let pvarset = mkNameSet (concatMap (collectStmtBinders.unLoc.fst) stmts) + `intersectNameSet` tail_fvs + pvars = nameSetElems pvarset + pat = mkBigLHsVarPatTup pvars + tup = mkBigLHsVarTup pvars + ; (stmts',fvs2) <- adoSegment ctxt stmts [] pvarset + ; (mb_ret, fvs1) <- case () of + _ | L _ ApplicativeStmt{} <- last stmts' -> + return (unLoc tup, emptyNameSet) + | otherwise -> do + (ret,fvs) <- lookupStmtName ctxt returnMName + return (HsApp (noLoc ret) tup, fvs) + ; return ( ApplicativeArgMany stmts' mb_ret pat + , fvs1 `plusFV` fvs2) } + +-- | Divide a sequence of statements into segments, where no segment +-- depends on any variables defined by a statement in another segment. +segments + :: [(LStmt Name (LHsExpr Name), FreeVars)] + -> [[(LStmt Name (LHsExpr Name), FreeVars)]] +segments stmts = map fst $ merge $ reverse $ map reverse $ walk (reverse stmts) + where + allvars = mkNameSet (concatMap (collectStmtBinders.unLoc.fst) stmts) + + -- We would rather not have a segment that just has LetStmts in + -- it, so combine those with an adjacent segment where possible. + merge [] = [] + merge (seg : segs) + = case rest of + [] -> [(seg,all_lets)] + ((s,s_lets):ss) | all_lets || s_lets + -> (seg ++ s, all_lets && s_lets) : ss + _otherwise -> (seg,all_lets) : rest + where + rest = merge segs + all_lets = all (not . isBindStmt . fst) seg + + walk [] = [] + walk ((stmt,fvs) : stmts) = ((stmt,fvs) : seg) : walk rest + where (seg,rest) = chunter (fvs `intersectNameSet` allvars) stmts + + chunter _ [] = ([], []) + chunter vars ((stmt,fvs) : rest) + | not (isEmptyNameSet vars) + = ((stmt,fvs) : chunk, rest') + where (chunk,rest') = chunter vars' rest + evars = fvs `intersectNameSet` allvars + pvars = mkNameSet (collectStmtBinders (unLoc stmt)) + vars' = (vars `minusNameSet` pvars) `unionNameSet` evars + chunter _ rest = ([], rest) + + isBindStmt (L _ BindStmt{}) = True + isBindStmt _ = False + +-- | Find a "good" place to insert a bind in an indivisible segment. +-- This is the only place where we use heuristics. The current +-- heuristic is to peel off the first group of independent statements +-- and put the bind after those. +splitSegment + :: [(LStmt Name (LHsExpr Name), FreeVars)] + -> ( [(LStmt Name (LHsExpr Name), FreeVars)] + , [(LStmt Name (LHsExpr Name), FreeVars)] ) +splitSegment stmts + | Just (lets,binds,rest) <- slurpIndependentStmts stmts + = if not (null lets) + then (lets, binds++rest) + else (lets++binds, rest) + | otherwise + = case stmts of + (x:xs) -> ([x],xs) + _other -> (stmts,[]) + +slurpIndependentStmts + :: [(LStmt Name (Located (body Name)), FreeVars)] + -> Maybe ( [(LStmt Name (Located (body Name)), FreeVars)] -- LetStmts + , [(LStmt Name (Located (body Name)), FreeVars)] -- BindStmts + , [(LStmt Name (Located (body Name)), FreeVars)] ) +slurpIndependentStmts stmts = go [] [] emptyNameSet stmts + where + -- If we encounter a BindStmt that doesn't depend on a previous BindStmt + -- in this group, then add it to the group. + go lets indep bndrs ((L loc (BindStmt pat body bind_op fail_op), fvs) : rest) + | isEmptyNameSet (bndrs `intersectNameSet` fvs) + = go lets ((L loc (BindStmt pat body bind_op fail_op), fvs) : indep) + bndrs' rest + where bndrs' = bndrs `unionNameSet` mkNameSet (collectPatBinders pat) + -- If we encounter a LetStmt that doesn't depend on a BindStmt in this + -- group, then move it to the beginning, so that it doesn't interfere with + -- grouping more BindStmts. + -- TODO: perhaps we shouldn't do this if there are any strict bindings, + -- because we might be moving evaluation earlier. + go lets indep bndrs ((L loc (LetStmt binds), fvs) : rest) + | isEmptyNameSet (bndrs `intersectNameSet` fvs) + = go ((L loc (LetStmt binds), fvs) : lets) indep bndrs rest + go _ [] _ _ = Nothing + go _ [_] _ _ = Nothing + go lets indep _ stmts = Just (reverse lets, reverse indep, stmts) + +-- | Build an ApplicativeStmt, and strip the "return" from the tail +-- if necessary. +-- +-- For example, if we start with +-- do x <- E1; y <- E2; return (f x y) +-- then we get +-- do (E1[x] | E2[y]); f x y +-- +-- the LastStmt in this case has the return removed, but we set the +-- flag on the LastStmt to indicate this, so that we can print out the +-- original statement correctly in error messages. It is easier to do +-- it this way rather than try to ignore the return later in both the +-- typechecker and the desugarer (I tried it that way first!). +mkApplicativeStmt + :: HsStmtContext Name + -> [ApplicativeArg Name Name] -- ^ The args + -> Bool -- ^ True <=> need a join + -> [LStmt Name (LHsExpr Name)] -- ^ The body statements + -> RnM ([LStmt Name (LHsExpr Name)], FreeVars) +mkApplicativeStmt ctxt args need_join body_stmts + = do { (fmap_op, fvs1) <- lookupStmtName ctxt fmapName + ; (ap_op, fvs2) <- lookupStmtName ctxt apAName + ; (mb_join, fvs3) <- + if need_join then + do { (join_op, fvs) <- lookupStmtName ctxt joinMName + ; return (Just join_op, fvs) } + else + return (Nothing, emptyNameSet) + ; let applicative_stmt = noLoc $ ApplicativeStmt + (zip (fmap_op : repeat ap_op) args) + mb_join + placeHolderType + ; return ( applicative_stmt : body_stmts + , fvs1 `plusFV` fvs2 `plusFV` fvs3) } + +-- | Given the statements following an ApplicativeStmt, determine whether +-- we need a @join@ or not, and remove the @return@ if necessary. +needJoin :: [LStmt Name (LHsExpr Name)] -> (Bool, [LStmt Name (LHsExpr Name)]) +needJoin [] = (False, []) -- we're in an ApplicativeArg +needJoin [L loc (LastStmt e _ t)] + | Just arg <- isReturnApp e = (False, [L loc (LastStmt arg True t)]) +needJoin stmts = (True, stmts) + +-- | @Just e@, if the expression is @return e@, otherwise @Nothing@ +isReturnApp :: LHsExpr Name -> Maybe (LHsExpr Name) +isReturnApp (L _ (HsPar expr)) = isReturnApp expr +isReturnApp (L _ (HsApp f arg)) + | is_return f = Just arg + | otherwise = Nothing + where + is_return (L _ (HsPar e)) = is_return e + is_return (L _ (HsVar r)) = r == returnMName + -- TODO: I don't know how to get this right for rebindable syntax + is_return _ = False +isReturnApp _ = Nothing + + +{- +************************************************************************ +* * \subsubsection{Errors} * * ************************************************************************ @@ -1257,6 +1706,7 @@ pprStmtCat (BindStmt {}) = ptext (sLit "binding") pprStmtCat (LetStmt {}) = ptext (sLit "let") pprStmtCat (RecStmt {}) = ptext (sLit "rec") pprStmtCat (ParStmt {}) = ptext (sLit "parallel") +pprStmtCat (ApplicativeStmt {}) = panic "pprStmtCat: ApplicativeStmt" ------------ emptyInvalid :: Validity -- Payload is the empty document @@ -1322,6 +1772,7 @@ okCompStmt dflags _ stmt | otherwise -> NotValid (ptext (sLit "Use TransformListComp")) RecStmt {} -> emptyInvalid LastStmt {} -> emptyInvalid -- Should not happen (dealt with by checkLastStmt) + ApplicativeStmt {} -> emptyInvalid ---------------- okPArrStmt dflags _ stmt @@ -1335,6 +1786,7 @@ okPArrStmt dflags _ stmt TransStmt {} -> emptyInvalid RecStmt {} -> emptyInvalid LastStmt {} -> emptyInvalid -- Should not happen (dealt with by checkLastStmt) + ApplicativeStmt {} -> emptyInvalid --------- checkTupleSection :: [LHsTupArg RdrName] -> RnM () diff --git a/compiler/typecheck/TcArrows.hs b/compiler/typecheck/TcArrows.hs index 9ad65722cd..dc2a38229c 100644 --- a/compiler/typecheck/TcArrows.hs +++ b/compiler/typecheck/TcArrows.hs @@ -340,10 +340,10 @@ matchExpectedCmdArgs n ty -- (b) no rebindable syntax tcArrDoStmt :: CmdEnv -> TcCmdStmtChecker -tcArrDoStmt env _ (LastStmt rhs _) res_ty thing_inside +tcArrDoStmt env _ (LastStmt rhs noret _) res_ty thing_inside = do { rhs' <- tcCmd env rhs (unitTy, res_ty) ; thing <- thing_inside (panic "tcArrDoStmt") - ; return (LastStmt rhs' noSyntaxExpr, thing) } + ; return (LastStmt rhs' noret noSyntaxExpr, thing) } tcArrDoStmt env _ (BodyStmt rhs _ _ _) res_ty thing_inside = do { (rhs', elt_ty) <- tc_arr_rhs env rhs diff --git a/compiler/typecheck/TcHsSyn.hs b/compiler/typecheck/TcHsSyn.hs index c461d513e2..abe367dcc0 100644 --- a/compiler/typecheck/TcHsSyn.hs +++ b/compiler/typecheck/TcHsSyn.hs @@ -953,10 +953,10 @@ zonkStmt env zBody (BodyStmt body then_op guard_op ty) new_ty <- zonkTcTypeToType env ty return (env, BodyStmt new_body new_then new_guard new_ty) -zonkStmt env zBody (LastStmt body ret_op) +zonkStmt env zBody (LastStmt body noret ret_op) = do new_body <- zBody env body new_ret <- zonkExpr env ret_op - return (env, LastStmt new_body new_ret) + return (env, LastStmt new_body noret new_ret) zonkStmt env _ (TransStmt { trS_stmts = stmts, trS_bndrs = binderMap , trS_by = by, trS_form = form, trS_using = using @@ -989,6 +989,29 @@ zonkStmt env zBody (BindStmt pat body bind_op fail_op) ; new_fail <- zonkExpr env fail_op ; return (env1, BindStmt new_pat new_body new_bind new_fail) } +zonkStmt env _zBody (ApplicativeStmt args mb_join body_ty) + = do { (env', args') <- zonk_args env args + ; new_mb_join <- traverse (zonkExpr env) mb_join + ; new_body_ty <- zonkTcTypeToType env' body_ty + ; return (env', ApplicativeStmt args' new_mb_join new_body_ty) } + where + zonk_args env [] = return (env, []) + zonk_args env ((op, arg) : groups) + = do { (env1, arg') <- zonk_arg env arg + ; op' <- zonkExpr env1 op + ; (env2, ss) <- zonk_args env1 groups + ; return (env2, (op', arg') : ss) } + + zonk_arg env (ApplicativeArgOne pat expr) + = do { (env1, new_pat) <- zonkPat env pat + ; new_expr <- zonkLExpr env expr + ; return (env1, ApplicativeArgOne new_pat new_expr) } + zonk_arg env (ApplicativeArgMany stmts ret pat) + = do { (env1, new_stmts) <- zonkStmts env zonkLExpr stmts + ; new_ret <- zonkExpr env1 ret + ; (env2, new_pat) <- zonkPat env pat + ; return (env2, ApplicativeArgMany new_stmts new_ret new_pat) } + ------------------------------------------------------------------------- zonkRecFields :: ZonkEnv -> HsRecordBinds TcId -> TcM (HsRecordBinds TcId) zonkRecFields env (HsRecFields flds dd) diff --git a/compiler/typecheck/TcMatches.hs b/compiler/typecheck/TcMatches.hs index 386a08d282..ebb7797673 100644 --- a/compiler/typecheck/TcMatches.hs +++ b/compiler/typecheck/TcMatches.hs @@ -322,16 +322,27 @@ tcStmtsAndThen _ _ [] res_ty thing_inside -- LetStmts are handled uniformly, regardless of context tcStmtsAndThen ctxt stmt_chk (L loc (LetStmt binds) : stmts) res_ty thing_inside = do { (binds', (stmts',thing)) <- tcLocalBinds binds $ - tcStmtsAndThen ctxt stmt_chk stmts res_ty thing_inside + tcStmtsAndThen ctxt stmt_chk stmts res_ty thing_inside ; return (L loc (LetStmt binds') : stmts', thing) } --- For the vanilla case, handle the location-setting part +-- Don't set the error context for an ApplicativeStmt. It ought to be +-- possible to do this with a popErrCtxt in the tcStmt case for +-- ApplicativeStmt, but it did someting strange and broke a test (ado002). tcStmtsAndThen ctxt stmt_chk (L loc stmt : stmts) res_ty thing_inside + | ApplicativeStmt{} <- stmt + = do { (stmt', (stmts', thing)) <- + stmt_chk ctxt stmt res_ty $ \ res_ty' -> + tcStmtsAndThen ctxt stmt_chk stmts res_ty' $ + thing_inside + ; return (L loc stmt' : stmts', thing) } + + -- For the vanilla case, handle the location-setting part + | otherwise = do { (stmt', (stmts', thing)) <- setSrcSpan loc $ - addErrCtxt (pprStmtInCtxt ctxt stmt) $ + addErrCtxt (pprStmtInCtxt ctxt stmt) $ stmt_chk ctxt stmt res_ty $ \ res_ty' -> - popErrCtxt $ + popErrCtxt $ tcStmtsAndThen ctxt stmt_chk stmts res_ty' $ thing_inside ; return (L loc stmt' : stmts', thing) } @@ -373,10 +384,10 @@ tcGuardStmt _ stmt _ _ tcLcStmt :: TyCon -- The list/Parray type constructor ([] or PArray) -> TcExprStmtChecker -tcLcStmt _ _ (LastStmt body _) elt_ty thing_inside +tcLcStmt _ _ (LastStmt body noret _) elt_ty thing_inside = do { body' <- tcMonoExprNC body elt_ty ; thing <- thing_inside (panic "tcLcStmt: thing_inside") - ; return (LastStmt body' noSyntaxExpr, thing) } + ; return (LastStmt body' noret noSyntaxExpr, thing) } -- A generator, pat <- rhs tcLcStmt m_tc ctxt (BindStmt pat rhs _ _) elt_ty thing_inside @@ -480,13 +491,13 @@ tcLcStmt _ _ stmt _ _ tcMcStmt :: TcExprStmtChecker -tcMcStmt _ (LastStmt body return_op) res_ty thing_inside +tcMcStmt _ (LastStmt body noret return_op) res_ty thing_inside = do { a_ty <- newFlexiTyVarTy liftedTypeKind ; return_op' <- tcSyntaxOp MCompOrigin return_op (a_ty `mkFunTy` res_ty) ; body' <- tcMonoExprNC body a_ty ; thing <- thing_inside (panic "tcMcStmt: thing_inside") - ; return (LastStmt body' return_op', thing) } + ; return (LastStmt body' noret return_op', thing) } -- Generators for monad comprehensions ( pat <- rhs ) -- @@ -729,10 +740,10 @@ tcMcStmt _ stmt _ _ tcDoStmt :: TcExprStmtChecker -tcDoStmt _ (LastStmt body _) res_ty thing_inside +tcDoStmt _ (LastStmt body noret _) res_ty thing_inside = do { body' <- tcMonoExprNC body res_ty ; thing <- thing_inside (panic "tcDoStmt: thing_inside") - ; return (LastStmt body' noSyntaxExpr, thing) } + ; return (LastStmt body' noret noSyntaxExpr, thing) } tcDoStmt ctxt (BindStmt pat rhs bind_op fail_op) res_ty thing_inside = do { -- Deal with rebindable syntax: @@ -762,6 +773,20 @@ tcDoStmt ctxt (BindStmt pat rhs bind_op fail_op) res_ty thing_inside ; return (BindStmt pat' rhs' bind_op' fail_op', thing) } +tcDoStmt ctxt (ApplicativeStmt pairs mb_join _) res_ty thing_inside + = do { + ; (mb_join', rhs_ty) <- case mb_join of + Nothing -> return (Nothing, res_ty) + Just join_op -> + do { rhs_ty <- newFlexiTyVarTy liftedTypeKind + ; join_op' <- tcSyntaxOp DoOrigin join_op + (mkFunTy rhs_ty res_ty) + ; return (Just join_op', rhs_ty) } + + ; (pairs', body_ty, thing) <- + tcApplicativeStmts ctxt pairs rhs_ty thing_inside + + ; return (ApplicativeStmt pairs' mb_join' body_ty, thing) } tcDoStmt _ (BodyStmt rhs then_op _ _) res_ty thing_inside = do { -- Deal with rebindable syntax; @@ -829,8 +854,89 @@ pushing info from the context into the RHS. To do this, we check the rebindable syntax first, and push that information into (tcMonoExprNC rhs). Otherwise the error shows up when cheking the rebindable syntax, and the expected/inferred stuff is back to front (see Trac #3613). +-} +{- +Note [typechecking ApplicativeStmt] + +join ((\pat1 ... patn -> body) <$> e1 <*> ... <*> en) + +fresh type variables: + pat_ty_1..pat_ty_n + exp_ty_1..exp_ty_n + t_1..t_(n-1) + +body :: body_ty +(\pat1 ... patn -> body) :: pat_ty_1 -> ... -> pat_ty_n -> body_ty +pat_i :: pat_ty_i +e_i :: exp_ty_i +<$> :: (pat_ty_1 -> ... -> pat_ty_n -> body_ty) -> exp_ty_1 -> t_1 +<*>_i :: t_(i-1) -> exp_ty_i -> t_i +join :: tn -> res_ty +-} +tcApplicativeStmts + :: HsStmtContext Name + -> [(HsExpr Name, ApplicativeArg Name Name)] + -> Type -- rhs_ty + -> (Type -> TcM t) -- thing_inside + -> TcM ([(HsExpr TcId, ApplicativeArg TcId TcId)], Type, t) + +tcApplicativeStmts ctxt pairs rhs_ty thing_inside + = do { body_ty <- newFlexiTyVarTy liftedTypeKind + ; let arity = length pairs + ; ts <- replicateM (arity-1) $ newFlexiTyVarTy liftedTypeKind + ; exp_tys <- replicateM arity $ newFlexiTyVarTy liftedTypeKind + ; pat_tys <- replicateM arity $ newFlexiTyVarTy liftedTypeKind + ; let fun_ty = mkFunTys pat_tys body_ty + + -- NB. do the <$>,<*> operators first, we don't want type errors here + ; let (ops, args) = unzip pairs + ; ops' <- goOps fun_ty (zip3 ops (ts ++ [rhs_ty]) exp_tys) + + ; (args', thing) <- goArgs (zip3 args pat_tys exp_tys) $ + thing_inside body_ty + ; return (zip ops' args', body_ty, thing) } + where + goOps _ [] = return [] + goOps t_left ((op,t_i,exp_ty) : ops) + = do { op' <- tcSyntaxOp DoOrigin op (mkFunTys [t_left, exp_ty] t_i) + ; ops' <- goOps t_i ops + ; return (op' : ops') } + + goArgs + :: [(ApplicativeArg Name Name, Type, Type)] + -> TcM t + -> TcM ([ApplicativeArg TcId TcId], t) + + goArgs [] thing_inside + = do { thing <- thing_inside + ; return ([],thing) + } + goArgs ((ApplicativeArgOne pat rhs, pat_ty, exp_ty) : rest) thing_inside + = do { let stmt :: ExprStmt Name + stmt = BindStmt pat rhs noSyntaxExpr noSyntaxExpr + ; setSrcSpan (combineSrcSpans (getLoc pat) (getLoc rhs)) $ + addErrCtxt (pprStmtInCtxt ctxt stmt) $ + do { rhs' <- tcMonoExprNC rhs exp_ty + ; (pat',(pairs, thing)) <- + tcPat (StmtCtxt ctxt) pat pat_ty $ + popErrCtxt $ + goArgs rest thing_inside + ; return (ApplicativeArgOne pat' rhs' : pairs, thing) } } + + goArgs ((ApplicativeArgMany stmts ret pat, pat_ty, exp_ty) : rest) + thing_inside + = do { (stmts', (ret',pat',rest',thing)) <- + tcStmtsAndThen ctxt tcDoStmt stmts exp_ty $ \res_ty -> do + { L _ ret' <- tcMonoExprNC (noLoc ret) res_ty + ; (pat',(rest', thing)) <- + tcPat (StmtCtxt ctxt) pat pat_ty $ + goArgs rest thing_inside + ; return (ret', pat', rest', thing) + } + ; return (ApplicativeArgMany stmts' ret' pat' : rest', thing) } +{- ************************************************************************ * * \subsection{Errors and contexts} |