summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorZiyang Liu <unsafeFixIO@gmail.com>2022-01-27 15:49:20 -0800
committerMarge Bot <ben+marge-bot@smart-cactus.org>2022-02-24 04:53:34 -0500
commit7d426148c48d95c45ee3e5b050a4804700b08206 (patch)
tree182f717525f3f022afe1f83bc4c197e09506c084
parent9ed3bc6e1c09c558bba20e045e61582ba8fbadc7 (diff)
downloadhaskell-7d426148c48d95c45ee3e5b050a4804700b08206.tar.gz
Allow `return` in more cases in ApplicativeDo
The doc says that the last statement of an ado-block can be one of `return E`, `return $ E`, `pure E` and `pure $ E`. But `return` is not accepted in a few cases such as: ```haskell -- The ado-block only has one statement x :: F () x = do return () -- The ado-block only has let-statements besides the `return` y :: F () y = do let a = True return () ``` These currently require `Monad` instances. This MR fixes it. Normally `return` is accepted as the last statement because it is stripped in constructing an `ApplicativeStmt`, but this cannot be done in the above cases, so instead we replace `return` by `pure`. A similar but different issue (when the ado-block contains `BindStmt` or `BodyStmt`, the second last statement cannot be `LetStmt`, even if the last statement uses `pure`) is fixed in !6786.
-rw-r--r--compiler/GHC/Rename/Expr.hs82
-rw-r--r--testsuite/tests/ado/T20540.hs19
-rw-r--r--testsuite/tests/ado/ado004.stderr4
-rw-r--r--testsuite/tests/ado/all.T1
4 files changed, 88 insertions, 18 deletions
diff --git a/compiler/GHC/Rename/Expr.hs b/compiler/GHC/Rename/Expr.hs
index bb529c8066..b34581dd8e 100644
--- a/compiler/GHC/Rename/Expr.hs
+++ b/compiler/GHC/Rename/Expr.hs
@@ -1826,6 +1826,12 @@ dsDo {(arg_1 | ... | arg_n); stmts} expr =
<*> ...
<*> argexpr(arg_n)
+== Special cases ==
+
+If a do-expression contains only "return E" or "return $ E" plus
+zero or more let-statements, we replace the "return" with "pure".
+See Section 3.6 of the paper.
+
= Relevant modules in the rest of the compiler =
ApplicativeDo touches a few phases in the compiler:
@@ -1873,7 +1879,17 @@ rearrangeForApplicativeDo
-> RnM ([ExprLStmt GhcRn], FreeVars)
rearrangeForApplicativeDo _ [] = return ([], emptyNameSet)
-rearrangeForApplicativeDo _ [(one,_)] = return ([one], emptyNameSet)
+-- If the do-block contains a single @return@ statement, change it to
+-- @pure@ if ApplicativeDo is turned on. See Note [ApplicativeDo].
+rearrangeForApplicativeDo ctxt [(one,_)] = do
+ (return_name, _) <- lookupQualifiedDoName (HsDoStmt ctxt) returnMName
+ (pure_name, _) <- lookupQualifiedDoName (HsDoStmt ctxt) pureAName
+ let pure_expr = nl_HsVar pure_name
+ let monad_names = MonadNames { return_name = return_name
+ , pure_name = pure_name }
+ return $ case needJoin monad_names [one] (Just pure_expr) of
+ (False, one') -> (one', emptyNameSet)
+ (True, _) -> ([one], emptyNameSet)
rearrangeForApplicativeDo ctxt stmts0 = do
optimal_ado <- goptM Opt_OptimalApplicativeDo
let stmt_tree | optimal_ado = mkStmtTreeOptimal stmts
@@ -2007,9 +2023,12 @@ stmtTreeToStmts
-- In the spec, but we do it here rather than in the desugarer,
-- because we need the typechecker to typecheck the <$> form rather than
-- the bind form, which would give rise to a Monad constraint.
+--
+-- If we have a single let, and the last statement is @return E@ or @return $ E@,
+-- change the @return@ to @pure@.
stmtTreeToStmts monad_names ctxt (StmtTreeOne (L _ (BindStmt xbs pat rhs), _))
tail _tail_fvs
- | not (isStrictPattern pat), (False,tail') <- needJoin monad_names tail
+ | not (isStrictPattern pat), (False,tail') <- needJoin monad_names tail Nothing
-- See Note [ApplicativeDo and strict patterns]
= mkApplicativeStmt ctxt [ApplicativeArgOne
{ xarg_app_arg_one = xbsrn_failOp xbs
@@ -2020,7 +2039,7 @@ stmtTreeToStmts monad_names ctxt (StmtTreeOne (L _ (BindStmt xbs pat rhs), _))
False tail'
stmtTreeToStmts monad_names ctxt (StmtTreeOne (L _ (BodyStmt _ rhs _ _),_))
tail _tail_fvs
- | (False,tail') <- needJoin monad_names tail
+ | (False,tail') <- needJoin monad_names tail Nothing
= mkApplicativeStmt ctxt
[ApplicativeArgOne
{ xarg_app_arg_one = Nothing
@@ -2028,6 +2047,12 @@ stmtTreeToStmts monad_names ctxt (StmtTreeOne (L _ (BodyStmt _ rhs _ _),_))
, arg_expr = rhs
, is_body_stmt = True
}] False tail'
+stmtTreeToStmts monad_names ctxt (StmtTreeOne (let_stmt@(L _ LetStmt{}),_))
+ tail _tail_fvs = do
+ (pure_expr, _) <- lookupQualifiedDoExpr (HsDoStmt ctxt) pureAName
+ return $ case needJoin monad_names tail (Just pure_expr) of
+ (False, tail') -> (let_stmt : tail', emptyNameSet)
+ (True, _) -> (let_stmt : tail, emptyNameSet)
stmtTreeToStmts _monad_names _ctxt (StmtTreeOne (s,_)) tail _tail_fvs =
return (s : tail, emptyNameSet)
@@ -2046,7 +2071,7 @@ stmtTreeToStmts monad_names ctxt (StmtTreeApplicative trees) tail tail_fvs = do
-- See Note [ApplicativeDo and refutable patterns]
if any (hasRefutablePattern dflags) stmts'
then (True, tail)
- else needJoin monad_names tail
+ else needJoin monad_names tail Nothing
(stmts, fvs) <- mkApplicativeStmt ctxt stmts' need_join tail'
return (stmts, unionNameSets (fvs:fvss))
@@ -2304,25 +2329,47 @@ mkApplicativeStmt ctxt args need_join body_stmts
-- | Given the statements following an ApplicativeStmt, determine whether
-- we need a @join@ or not, and remove the @return@ if necessary.
+--
+-- We don't need @join@ if there's a single @LastStmt@ in the form of
+-- @return E@, @return $ E@, @pure E@ or @pure $ E@.
needJoin :: MonadNames
-> [ExprLStmt GhcRn]
+ -- If this is @Just pure@, replace return by pure
+ -- If this is @Nothing@, strip the return/pure
+ -> Maybe (HsExpr GhcRn)
-> (Bool, [ExprLStmt GhcRn])
-needJoin _monad_names [] = (False, []) -- we're in an ApplicativeArg
-needJoin monad_names [L loc (LastStmt _ e _ t)]
- | Just (arg, wasDollar) <- isReturnApp monad_names e =
- (False, [L loc (LastStmt noExtField arg (Just wasDollar) t)])
-needJoin _monad_names stmts = (True, stmts)
-
--- | @(Just e, False)@, if the expression is @return e@
--- @(Just e, True)@ if the expression is @return $ e@,
+needJoin _monad_names [] _mb_pure = (False, []) -- we're in an ApplicativeArg
+needJoin monad_names [L loc (LastStmt _ e _ t)] mb_pure
+ | Just (arg, noret) <- isReturnApp monad_names e mb_pure =
+ (False, [L loc (LastStmt noExtField arg noret t)])
+needJoin _monad_names stmts _mb_pure = (True, stmts)
+
+-- | @(Just e, Just False)@, if the expression is @return/pure e@,
+-- and the third argument is Nothing,
+-- @(Just e, Just True)@ if the expression is @return/pure $ e@,
+-- and the third argument is Nothing,
+-- @(Just (pure e), Nothing)@ if the expression is @return/pure e@,
+-- and the third argument is @Just pure_expr@,
+-- @(Just (pure $ e), Nothing)@ if the expression is @return/pure $ e@,
+-- and the third argument is @Just pure_expr@,
-- otherwise @Nothing@.
isReturnApp :: MonadNames
-> LHsExpr GhcRn
- -> Maybe (LHsExpr GhcRn, Bool)
-isReturnApp monad_names (L _ (HsPar _ _ expr _)) = isReturnApp monad_names expr
-isReturnApp monad_names (L _ e) = case e of
- OpApp _ l op r | is_return l, is_dollar op -> Just (r, True)
- HsApp _ f arg | is_return f -> Just (arg, False)
+ -- If this is @Just pure@, replace return by pure
+ -- If this is @Nothing@, strip the return/pure
+ -> Maybe (HsExpr GhcRn)
+ -> Maybe (LHsExpr GhcRn, Maybe Bool)
+isReturnApp monad_names (L _ (HsPar _ _ expr _)) mb_pure =
+ isReturnApp monad_names expr mb_pure
+isReturnApp monad_names (L loc e) mb_pure = case e of
+ OpApp x l op r
+ | Just pure_expr <- mb_pure, is_return l, is_dollar op ->
+ Just (L loc (OpApp x (to_pure l pure_expr) op r), Nothing)
+ | is_return l, is_dollar op -> Just (r, Just True)
+ HsApp x f arg
+ | Just pure_expr <- mb_pure, is_return f ->
+ Just (L loc (HsApp x (to_pure f pure_expr) arg), Nothing)
+ | is_return f -> Just (arg, Just False)
_otherwise -> Nothing
where
is_var f (L _ (HsPar _ _ e _)) = is_var f e
@@ -2333,6 +2380,7 @@ isReturnApp monad_names (L _ e) = case e of
is_return = is_var (\n -> n == return_name monad_names
|| n == pure_name monad_names)
+ to_pure (L loc _) pure_expr = L loc pure_expr
is_dollar = is_var (`hasKey` dollarIdKey)
{-
diff --git a/testsuite/tests/ado/T20540.hs b/testsuite/tests/ado/T20540.hs
new file mode 100644
index 0000000000..5af2be029f
--- /dev/null
+++ b/testsuite/tests/ado/T20540.hs
@@ -0,0 +1,19 @@
+{-# LANGUAGE ApplicativeDo, DerivingStrategies, GeneralizedNewtypeDeriving #-}
+
+module T20540 where
+
+import Data.Functor.Identity
+
+newtype F a = F (Identity a)
+ deriving newtype (Functor, Applicative, Show)
+
+x :: F Int
+x = do
+ return 3
+
+y :: F Int
+y = do
+ let a = 3
+ let b = 4
+ let c = 5
+ return $ a + b + c
diff --git a/testsuite/tests/ado/ado004.stderr b/testsuite/tests/ado/ado004.stderr
index d3f33c81f6..61b8cee912 100644
--- a/testsuite/tests/ado/ado004.stderr
+++ b/testsuite/tests/ado/ado004.stderr
@@ -16,7 +16,9 @@ TYPE SIGNATURES
(Functor f, Num t, Num b) =>
(t -> f b) -> f b
test2b ::
- forall {m :: * -> *} {t} {a}. (Monad m, Num t) => (t -> a) -> m a
+ forall {f :: * -> *} {t} {a}.
+ (Applicative f, Num t) =>
+ (t -> a) -> f a
test2c ::
forall {f :: * -> *} {t} {b}.
(Functor f, Num t, Num b) =>
diff --git a/testsuite/tests/ado/all.T b/testsuite/tests/ado/all.T
index 86e18998b0..7369f9e986 100644
--- a/testsuite/tests/ado/all.T
+++ b/testsuite/tests/ado/all.T
@@ -18,4 +18,5 @@ test('T14163', normal, compile_and_run, [''])
test('T15344', normal, compile_and_run, [''])
test('T16628', normal, compile_fail, [''])
test('T17835', normal, compile, [''])
+test('T20540', normal, compile, [''])
test('T16135', when(compiler_debugged(),expect_broken(16135)), compile, [''])