diff options
author | Ziyang Liu <unsafeFixIO@gmail.com> | 2022-01-27 15:49:20 -0800 |
---|---|---|
committer | Marge Bot <ben+marge-bot@smart-cactus.org> | 2022-02-24 04:53:34 -0500 |
commit | 7d426148c48d95c45ee3e5b050a4804700b08206 (patch) | |
tree | 182f717525f3f022afe1f83bc4c197e09506c084 | |
parent | 9ed3bc6e1c09c558bba20e045e61582ba8fbadc7 (diff) | |
download | haskell-7d426148c48d95c45ee3e5b050a4804700b08206.tar.gz |
Allow `return` in more cases in ApplicativeDo
The doc says that the last statement of an ado-block can be one of
`return E`, `return $ E`, `pure E` and `pure $ E`. But `return`
is not accepted in a few cases such as:
```haskell
-- The ado-block only has one statement
x :: F ()
x = do
return ()
-- The ado-block only has let-statements besides the `return`
y :: F ()
y = do
let a = True
return ()
```
These currently require `Monad` instances. This MR fixes it.
Normally `return` is accepted as the last statement because it is
stripped in constructing an `ApplicativeStmt`, but this cannot be
done in the above cases, so instead we replace `return` by `pure`.
A similar but different issue (when the ado-block contains `BindStmt`
or `BodyStmt`, the second last statement cannot be `LetStmt`, even if
the last statement uses `pure`) is fixed in !6786.
-rw-r--r-- | compiler/GHC/Rename/Expr.hs | 82 | ||||
-rw-r--r-- | testsuite/tests/ado/T20540.hs | 19 | ||||
-rw-r--r-- | testsuite/tests/ado/ado004.stderr | 4 | ||||
-rw-r--r-- | testsuite/tests/ado/all.T | 1 |
4 files changed, 88 insertions, 18 deletions
diff --git a/compiler/GHC/Rename/Expr.hs b/compiler/GHC/Rename/Expr.hs index bb529c8066..b34581dd8e 100644 --- a/compiler/GHC/Rename/Expr.hs +++ b/compiler/GHC/Rename/Expr.hs @@ -1826,6 +1826,12 @@ dsDo {(arg_1 | ... | arg_n); stmts} expr = <*> ... <*> argexpr(arg_n) +== Special cases == + +If a do-expression contains only "return E" or "return $ E" plus +zero or more let-statements, we replace the "return" with "pure". +See Section 3.6 of the paper. + = Relevant modules in the rest of the compiler = ApplicativeDo touches a few phases in the compiler: @@ -1873,7 +1879,17 @@ rearrangeForApplicativeDo -> RnM ([ExprLStmt GhcRn], FreeVars) rearrangeForApplicativeDo _ [] = return ([], emptyNameSet) -rearrangeForApplicativeDo _ [(one,_)] = return ([one], emptyNameSet) +-- If the do-block contains a single @return@ statement, change it to +-- @pure@ if ApplicativeDo is turned on. See Note [ApplicativeDo]. +rearrangeForApplicativeDo ctxt [(one,_)] = do + (return_name, _) <- lookupQualifiedDoName (HsDoStmt ctxt) returnMName + (pure_name, _) <- lookupQualifiedDoName (HsDoStmt ctxt) pureAName + let pure_expr = nl_HsVar pure_name + let monad_names = MonadNames { return_name = return_name + , pure_name = pure_name } + return $ case needJoin monad_names [one] (Just pure_expr) of + (False, one') -> (one', emptyNameSet) + (True, _) -> ([one], emptyNameSet) rearrangeForApplicativeDo ctxt stmts0 = do optimal_ado <- goptM Opt_OptimalApplicativeDo let stmt_tree | optimal_ado = mkStmtTreeOptimal stmts @@ -2007,9 +2023,12 @@ stmtTreeToStmts -- In the spec, but we do it here rather than in the desugarer, -- because we need the typechecker to typecheck the <$> form rather than -- the bind form, which would give rise to a Monad constraint. +-- +-- If we have a single let, and the last statement is @return E@ or @return $ E@, +-- change the @return@ to @pure@. stmtTreeToStmts monad_names ctxt (StmtTreeOne (L _ (BindStmt xbs pat rhs), _)) tail _tail_fvs - | not (isStrictPattern pat), (False,tail') <- needJoin monad_names tail + | not (isStrictPattern pat), (False,tail') <- needJoin monad_names tail Nothing -- See Note [ApplicativeDo and strict patterns] = mkApplicativeStmt ctxt [ApplicativeArgOne { xarg_app_arg_one = xbsrn_failOp xbs @@ -2020,7 +2039,7 @@ stmtTreeToStmts monad_names ctxt (StmtTreeOne (L _ (BindStmt xbs pat rhs), _)) False tail' stmtTreeToStmts monad_names ctxt (StmtTreeOne (L _ (BodyStmt _ rhs _ _),_)) tail _tail_fvs - | (False,tail') <- needJoin monad_names tail + | (False,tail') <- needJoin monad_names tail Nothing = mkApplicativeStmt ctxt [ApplicativeArgOne { xarg_app_arg_one = Nothing @@ -2028,6 +2047,12 @@ stmtTreeToStmts monad_names ctxt (StmtTreeOne (L _ (BodyStmt _ rhs _ _),_)) , arg_expr = rhs , is_body_stmt = True }] False tail' +stmtTreeToStmts monad_names ctxt (StmtTreeOne (let_stmt@(L _ LetStmt{}),_)) + tail _tail_fvs = do + (pure_expr, _) <- lookupQualifiedDoExpr (HsDoStmt ctxt) pureAName + return $ case needJoin monad_names tail (Just pure_expr) of + (False, tail') -> (let_stmt : tail', emptyNameSet) + (True, _) -> (let_stmt : tail, emptyNameSet) stmtTreeToStmts _monad_names _ctxt (StmtTreeOne (s,_)) tail _tail_fvs = return (s : tail, emptyNameSet) @@ -2046,7 +2071,7 @@ stmtTreeToStmts monad_names ctxt (StmtTreeApplicative trees) tail tail_fvs = do -- See Note [ApplicativeDo and refutable patterns] if any (hasRefutablePattern dflags) stmts' then (True, tail) - else needJoin monad_names tail + else needJoin monad_names tail Nothing (stmts, fvs) <- mkApplicativeStmt ctxt stmts' need_join tail' return (stmts, unionNameSets (fvs:fvss)) @@ -2304,25 +2329,47 @@ mkApplicativeStmt ctxt args need_join body_stmts -- | Given the statements following an ApplicativeStmt, determine whether -- we need a @join@ or not, and remove the @return@ if necessary. +-- +-- We don't need @join@ if there's a single @LastStmt@ in the form of +-- @return E@, @return $ E@, @pure E@ or @pure $ E@. needJoin :: MonadNames -> [ExprLStmt GhcRn] + -- If this is @Just pure@, replace return by pure + -- If this is @Nothing@, strip the return/pure + -> Maybe (HsExpr GhcRn) -> (Bool, [ExprLStmt GhcRn]) -needJoin _monad_names [] = (False, []) -- we're in an ApplicativeArg -needJoin monad_names [L loc (LastStmt _ e _ t)] - | Just (arg, wasDollar) <- isReturnApp monad_names e = - (False, [L loc (LastStmt noExtField arg (Just wasDollar) t)]) -needJoin _monad_names stmts = (True, stmts) - --- | @(Just e, False)@, if the expression is @return e@ --- @(Just e, True)@ if the expression is @return $ e@, +needJoin _monad_names [] _mb_pure = (False, []) -- we're in an ApplicativeArg +needJoin monad_names [L loc (LastStmt _ e _ t)] mb_pure + | Just (arg, noret) <- isReturnApp monad_names e mb_pure = + (False, [L loc (LastStmt noExtField arg noret t)]) +needJoin _monad_names stmts _mb_pure = (True, stmts) + +-- | @(Just e, Just False)@, if the expression is @return/pure e@, +-- and the third argument is Nothing, +-- @(Just e, Just True)@ if the expression is @return/pure $ e@, +-- and the third argument is Nothing, +-- @(Just (pure e), Nothing)@ if the expression is @return/pure e@, +-- and the third argument is @Just pure_expr@, +-- @(Just (pure $ e), Nothing)@ if the expression is @return/pure $ e@, +-- and the third argument is @Just pure_expr@, -- otherwise @Nothing@. isReturnApp :: MonadNames -> LHsExpr GhcRn - -> Maybe (LHsExpr GhcRn, Bool) -isReturnApp monad_names (L _ (HsPar _ _ expr _)) = isReturnApp monad_names expr -isReturnApp monad_names (L _ e) = case e of - OpApp _ l op r | is_return l, is_dollar op -> Just (r, True) - HsApp _ f arg | is_return f -> Just (arg, False) + -- If this is @Just pure@, replace return by pure + -- If this is @Nothing@, strip the return/pure + -> Maybe (HsExpr GhcRn) + -> Maybe (LHsExpr GhcRn, Maybe Bool) +isReturnApp monad_names (L _ (HsPar _ _ expr _)) mb_pure = + isReturnApp monad_names expr mb_pure +isReturnApp monad_names (L loc e) mb_pure = case e of + OpApp x l op r + | Just pure_expr <- mb_pure, is_return l, is_dollar op -> + Just (L loc (OpApp x (to_pure l pure_expr) op r), Nothing) + | is_return l, is_dollar op -> Just (r, Just True) + HsApp x f arg + | Just pure_expr <- mb_pure, is_return f -> + Just (L loc (HsApp x (to_pure f pure_expr) arg), Nothing) + | is_return f -> Just (arg, Just False) _otherwise -> Nothing where is_var f (L _ (HsPar _ _ e _)) = is_var f e @@ -2333,6 +2380,7 @@ isReturnApp monad_names (L _ e) = case e of is_return = is_var (\n -> n == return_name monad_names || n == pure_name monad_names) + to_pure (L loc _) pure_expr = L loc pure_expr is_dollar = is_var (`hasKey` dollarIdKey) {- diff --git a/testsuite/tests/ado/T20540.hs b/testsuite/tests/ado/T20540.hs new file mode 100644 index 0000000000..5af2be029f --- /dev/null +++ b/testsuite/tests/ado/T20540.hs @@ -0,0 +1,19 @@ +{-# LANGUAGE ApplicativeDo, DerivingStrategies, GeneralizedNewtypeDeriving #-} + +module T20540 where + +import Data.Functor.Identity + +newtype F a = F (Identity a) + deriving newtype (Functor, Applicative, Show) + +x :: F Int +x = do + return 3 + +y :: F Int +y = do + let a = 3 + let b = 4 + let c = 5 + return $ a + b + c diff --git a/testsuite/tests/ado/ado004.stderr b/testsuite/tests/ado/ado004.stderr index d3f33c81f6..61b8cee912 100644 --- a/testsuite/tests/ado/ado004.stderr +++ b/testsuite/tests/ado/ado004.stderr @@ -16,7 +16,9 @@ TYPE SIGNATURES (Functor f, Num t, Num b) => (t -> f b) -> f b test2b :: - forall {m :: * -> *} {t} {a}. (Monad m, Num t) => (t -> a) -> m a + forall {f :: * -> *} {t} {a}. + (Applicative f, Num t) => + (t -> a) -> f a test2c :: forall {f :: * -> *} {t} {b}. (Functor f, Num t, Num b) => diff --git a/testsuite/tests/ado/all.T b/testsuite/tests/ado/all.T index 86e18998b0..7369f9e986 100644 --- a/testsuite/tests/ado/all.T +++ b/testsuite/tests/ado/all.T @@ -18,4 +18,5 @@ test('T14163', normal, compile_and_run, ['']) test('T15344', normal, compile_and_run, ['']) test('T16628', normal, compile_fail, ['']) test('T17835', normal, compile, ['']) +test('T20540', normal, compile, ['']) test('T16135', when(compiler_debugged(),expect_broken(16135)), compile, ['']) |