summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorApoorv Ingle <apoorv-ingle@uiowa.edu>2023-03-08 18:40:25 -0600
committerApoorv Ingle <apoorv-ingle@uiowa.edu>2023-04-20 19:04:34 -0500
commit4d1360f2a299a48e0d732c0b2f46ea489df4923a (patch)
treebe70476ec1f81d0a2d85fe1f056d413f2896f275
parent1db30fe1dd38dd8ffedfadf3845706fcde02933b (diff)
downloadhaskell-4d1360f2a299a48e0d732c0b2f46ea489df4923a.tar.gz
HsExpand for HsDo
Fixes for #18324 - fixed rec do blocks to use mfix - make sure fail is used for pattern match failures in bind statments
-rw-r--r--compiler/GHC/Hs/Expr.hs14
-rw-r--r--compiler/GHC/Rename/Expr.hs179
-rw-r--r--testsuite/tests/rebindable/T18324.hs21
-rw-r--r--testsuite/tests/rebindable/all.T2
-rw-r--r--testsuite/tests/rebindable/pattern-fails.hs9
5 files changed, 214 insertions, 11 deletions
diff --git a/compiler/GHC/Hs/Expr.hs b/compiler/GHC/Hs/Expr.hs
index be7af5002a..e8b6b4d745 100644
--- a/compiler/GHC/Hs/Expr.hs
+++ b/compiler/GHC/Hs/Expr.hs
@@ -1087,11 +1087,12 @@ data HsExpansion orig expanded
= HsExpanded orig expanded
deriving Data
--- | Just print the original expression (the @a@).
+-- | Just print the original expression (the @a@) with the expanded version (the @b@)
instance (Outputable a, Outputable b) => Outputable (HsExpansion a b) where
ppr (HsExpanded orig expanded)
- = ifPprDebug (vcat [ppr orig, braces (text "Expansion:" <+> ppr expanded)])
- (ppr orig)
+ -- = ifPprDebug (vcat [ppr orig, braces (text "Expansion:" <+> ppr expanded)])
+ -- (ppr orig)
+ = ppr orig <+> braces (text "Expansion:" <+> ppr expanded)
{-
@@ -1993,6 +1994,13 @@ matchDoContextErrString (MDoExpr m) = prependQualified m (text "'mdo' block")
matchDoContextErrString ListComp = text "list comprehension"
matchDoContextErrString MonadComp = text "monad comprehension"
+instance Outputable HsDoFlavour where
+ ppr (DoExpr m) = text "DoExpr" <+> parens (ppr m)
+ ppr (MDoExpr m) = text "MDoExpr" <+> parens (ppr m)
+ ppr GhciStmtCtxt = text "GhciStmtCtxt"
+ ppr ListComp = text "ListComp"
+ ppr MonadComp = text "MonadComp"
+
pprMatchInCtxt :: (OutputableBndrId idR, Outputable body)
=> Match (GhcPass idR) body -> SDoc
pprMatchInCtxt match = hang (text "In" <+> pprMatchContext (m_ctxt match)
diff --git a/compiler/GHC/Rename/Expr.hs b/compiler/GHC/Rename/Expr.hs
index b68ff6a492..4e40d05948 100644
--- a/compiler/GHC/Rename/Expr.hs
+++ b/compiler/GHC/Rename/Expr.hs
@@ -59,6 +59,7 @@ import GHC.Builtin.Types ( nilDataConName )
import GHC.Types.FieldLabel
import GHC.Types.Fixity
import GHC.Types.Id.Make
+import GHC.Types.Basic(Origin(..))
import GHC.Types.Name
import GHC.Types.Name.Set
import GHC.Types.Name.Reader
@@ -78,7 +79,7 @@ import qualified GHC.LanguageExtensions as LangExt
import Language.Haskell.Syntax.Basic (FieldLabelString(..))
import Control.Monad
-import Data.List (unzip4, minimumBy)
+import Data.List (unzip4, minimumBy, (\\))
import Data.List.NonEmpty ( NonEmpty(..), nonEmpty )
import Control.Arrow (first)
import Data.Ord
@@ -436,9 +437,25 @@ rnExpr (HsDo _ do_or_lc (L l stmts))
= do { ((stmts1, _), fvs1) <-
rnStmtsWithFreeVars (HsDoStmt do_or_lc) rnExpr stmts
(\ _ -> return ((), emptyFVs))
- ; (pp_stmts, fvs2) <- postProcessStmtsForApplicativeDo do_or_lc stmts1
- ; return ( HsDo noExtField do_or_lc (L l pp_stmts), fvs1 `plusFV` fvs2 ) }
-
+ ; ((pp_stmts, fvs2), is_app_do) <- postProcessStmtsForApplicativeDo do_or_lc stmts1
+ ; let orig_do_block = HsDo noExtField do_or_lc (L l pp_stmts)
+ ; return $ case do_or_lc of
+ DoExpr {} -> (if is_app_do
+ -- TODO i don't want to thing about applicative stmt rearrangements yet
+ then orig_do_block
+ else let expd_do_block = expand_do_stmts do_or_lc pp_stmts
+ in mkExpandedExpr orig_do_block expd_do_block
+ , fvs1 `plusFV` fvs2 )
+ MDoExpr {} -> (if is_app_do
+ -- TODO i don't want to thing about applicative stmt rearrangements yet
+ then orig_do_block
+ else let expd_do_block = expand_do_stmts do_or_lc pp_stmts
+ in mkExpandedExpr orig_do_block expd_do_block
+ , fvs1 `plusFV` fvs2 )
+ _ -> (orig_do_block, fvs1 `plusFV` fvs2)
+ -- ListComp -> (orig_do_block, fvs1 `plusFV` fvs2)
+ -- GhciStmtCtxt -> (orig_do_block, fvs1 `plusFV` fvs2)
+ }
-- ExplicitList: see Note [Handling overloaded and rebindable constructs]
rnExpr (ExplicitList _ exps)
= do { (exps', fvs) <- rnExprs exps
@@ -1072,7 +1089,7 @@ rnStmts ctxt rnBody stmts thing_inside
postProcessStmtsForApplicativeDo
:: HsDoFlavour
-> [(ExprLStmt GhcRn, FreeVars)]
- -> RnM ([ExprLStmt GhcRn], FreeVars)
+ -> RnM (([ExprLStmt GhcRn], FreeVars), Bool) -- True <=> applicative do statement
postProcessStmtsForApplicativeDo ctxt stmts
= do {
-- rearrange the statements using ApplicativeStmt if
@@ -1086,8 +1103,10 @@ postProcessStmtsForApplicativeDo ctxt stmts
; in_th_bracket <- isBrackStage <$> getStage
; if ado_is_on && is_do_expr && not in_th_bracket
then do { traceRn "ppsfa" (ppr stmts)
- ; rearrangeForApplicativeDo ctxt stmts }
- else noPostProcessStmts (HsDoStmt ctxt) stmts }
+ ; ado_stmts_and_fvs <- rearrangeForApplicativeDo ctxt stmts
+ ; return (ado_stmts_and_fvs, True) }
+ else do { do_stmts_and_fvs <- noPostProcessStmts (HsDoStmt ctxt) stmts
+ ; return (do_stmts_and_fvs, False) } }
-- | strip the FreeVars annotations from statements
noPostProcessStmts
@@ -1180,7 +1199,7 @@ rnStmt ctxt rnBody (L loc (LastStmt _ (L lb body) noret _)) thing_inside
else return (noSyntaxExpr, emptyFVs)
-- The 'return' in a LastStmt is used only
-- for MonadComp; and we don't want to report
- -- "non in scope: return" in other cases
+ -- "not in scope: return" in other cases
-- #15607
; (thing, fvs3) <- thing_inside []
@@ -2718,6 +2737,150 @@ mkExpandedExpr
-> HsExpr GhcRn -- ^ suitably wrapped 'HsExpansion'
mkExpandedExpr a b = XExpr (HsExpanded a b)
+
+
+-- | Expand the Do statments so that it works fine with Quicklook
+-- See Note[Rebindable Do Expanding Statements]
+-- ANI Questions: 1. What should be the location information in the expanded expression? Currently the error is displayed on the expanded expr and not on the unexpanded expr
+expand_do_stmts :: HsDoFlavour -> [ExprLStmt GhcRn] -> HsExpr GhcRn
+
+expand_do_stmts do_flavour [L _ (LastStmt _ body _ NoSyntaxExprRn)]
+ -- if it is a last statement of a list comprehension, we need to explicitly return it -- See Note [TODO]
+ -- genHsApp (genHsVar returnMName) body
+ | ListComp <- do_flavour
+ = genHsApp (genHsVar returnMName) body
+ | MonadComp <- do_flavour
+ = unLoc body -- genHsApp (genHsVar returnMName) body
+ | otherwise
+ -- Last statement is just body if we are not in ListComp context. See Syntax.Expr.LastStmt
+ = unLoc body
+
+expand_do_stmts _ [L _ (LastStmt _ body _ (SyntaxExprRn ret))]
+--
+-- ------------------------------------------------
+-- return e ~~> return e
+-- definitely works T18324.hs
+ = unLoc $ mkHsApp (noLocA ret) body
+
+expand_do_stmts do_or_lc ((L _ (BindStmt xbsrn x e)): lstmts)
+ | SyntaxExprRn bind_op <- xbsrn_bindOp xbsrn
+ , Just (SyntaxExprRn fail_op) <- xbsrn_failOp xbsrn =
+-- the pattern binding x can fail
+-- stmts ~~> stmt' let f x = stmts'; f _ = fail ".."
+-- -------------------------------------------------------
+-- x <- e ; stmts ~~> (Prelude.>>=) e f
+
+ foldl genHsApp bind_op -- (>>=)
+ [ e
+ , noLocA $ failable_expr x (expand_do_stmts do_or_lc lstmts) fail_op
+ ]
+ | SyntaxExprRn bop <- xbsrn_bindOp xbsrn
+ , Nothing <- xbsrn_failOp xbsrn = -- irrefutable pattern so no failure
+-- stmts ~~> stmt'
+-- ------------------------------------------------
+-- x <- e ; stmts ~~> (Prelude.>>=) e (\ x -> stmts')
+ foldl genHsApp bop -- (>>=)
+ [ e
+ , mkHsLam [x] (noLocA $ expand_do_stmts do_or_lc lstmts) -- (\ x -> stmts')
+ ]
+
+ | otherwise = -- just use the polymorhpic bindop. TODO: Necessary?
+ genHsApps bindMName -- (Prelude.>>=)
+ [ e
+ , mkHsLam [x] (noLocA $ expand_do_stmts do_or_lc lstmts) -- (\ x -> stmts')
+ ]
+
+ where
+ failable_expr :: LPat GhcRn -> HsExpr GhcRn -> HsExpr GhcRn -> HsExpr GhcRn
+ failable_expr pat expr fail_op = HsLam noExtField $
+ mkMatchGroup Generated
+ (noLocA [ mkHsCaseAlt pat (noLocA expr)
+ , mkHsCaseAlt nlWildPatName
+ (noLocA $ genHsApp fail_op
+ (nlHsLit $ mkHsString "fail pattern")) ])
+
+expand_do_stmts do_or_lc (L _ (LetStmt _ bnds) : lstmts) =
+-- stmts ~~> stmts'
+-- ------------------------------------------------
+-- let x = e ; stmts ~~> let x = e in stmts'
+ HsLet NoExtField noHsTok bnds noHsTok
+ $ noLocA (expand_do_stmts do_or_lc lstmts)
+
+
+expand_do_stmts do_or_lc ((L _ (BodyStmt _ e (SyntaxExprRn f) _)) : lstmts) =
+-- stmts ~~> stmts'
+-- ----------------------------------------------
+-- e ; stmts ~~> (Prelude.>>) e stmt'
+ unLoc $ nlHsApp (nlHsApp (noLocA f) -- (>>) See Note [BodyStmt]
+ e)
+ $ (noLocA $ expand_do_stmts do_or_lc lstmts)
+
+expand_do_stmts do_or_lc ((L l (RecStmt { recS_stmts = rec_stmts
+ , recS_later_ids = later_ids -- forward referenced local ids
+ , recS_rec_ids = local_ids -- ids referenced outside of the rec block
+ , recS_mfix_fn = SyntaxExprRn mfix_fun -- the `mfix` expr
+ , recS_ret_fn = SyntaxExprRn return_fun -- the `return` expr
+ -- use it explicitly
+ -- at the end of expanded rec block
+ }))
+ : lstmts) =
+-- See Note [Typing a RecStmt]
+-- stmts ~~> stmts'
+-- -------------------------------------------------------------------------------------------
+-- rec { later_ids, local_ids, rec_block } ; stmts
+-- ~~> (Prelude.>>=) (mfix (\[ local_ids ++ later_ids ]
+-- -> do { rec_stmts
+-- ; return (later_ids, local_ids) } ))
+-- (\ [ local_ids ++ later_ids ] -> stmts')
+
+ genHsApps bindMName -- (Prelude.>>=)
+ [ (noLocA mfix_fun) `mkHsApp` mfix_expr -- mfix (do block)
+ , mkHsLam [ mkBigLHsVarPatTup all_ids ] -- (\ x -> stmts')
+ (L l $ expand_do_stmts do_or_lc lstmts)
+ ]
+ where
+ local_only_ids = local_ids \\ later_ids -- get unique local rec ids; local rec ids and later ids overlap
+ all_ids = local_only_ids ++ later_ids -- put local ids before return ids
+
+ return_stmt :: ExprLStmt GhcRn
+ return_stmt = noLocA $ LastStmt noExtField
+ (mkHsApp (noLocA return_fun)
+ $ mkBigLHsTup (map nlHsVar all_ids) noExtField)
+ Nothing
+ (SyntaxExprRn return_fun)
+ do_stmts :: XRec GhcRn [ExprLStmt GhcRn]
+ do_stmts = noLocA $ (unLoc rec_stmts) ++ [return_stmt]
+ do_block :: LHsExpr GhcRn
+ do_block = noLocA $ HsDo noExtField (DoExpr Nothing) $ do_stmts
+ mfix_expr :: LHsExpr GhcRn
+ mfix_expr = mkHsLam [ mkBigLHsVarPatTup all_ids ] $ do_block
+
+expand_do_stmts _ (stmt@(L _ (RecStmt {})):_) =
+ pprPanic "expand_do_stmts: impossible happened RecStmt" $ ppr stmt
+
+
+expand_do_stmts _ (stmt@(L _ (TransStmt {})):_) =
+ pprPanic "expand_do_stmts: impossible happened TransStmt" $ ppr stmt
+
+expand_do_stmts _ (stmt@(L _ (ParStmt {})):_) =
+-- See See Note [Monad Comprehensions]
+-- Parallel statements only appear in
+-- stmts ~~> stmts'
+-- -------------------------------------------------------------------------------------------
+-- ; stmts
+-- ~~> (Prelude.>>=) (mfix (\[ local_ids ++ later_ids ]
+-- -> do { rec_stmts
+-- ; return (later_ids, local_ids) } ))
+-- (\ [ local_ids ++ later_ids ] -> stmts')
+ pprPanic "expand_do_stmts: impossible happened ParStmt" $ ppr stmt
+
+expand_do_stmts _ (stmt@(L _ (ApplicativeStmt {})):_) =
+-- See Note [Applicative BodyStmt]
+
+ pprPanic "expand_do_stmts: impossible happened ApplicativeStmt" $ ppr stmt
+
+expand_do_stmts do_flavor stmts = pprPanic "expand_do_stmts: impossible happened" $ (ppr do_flavor $$ ppr stmts)
+
-----------------------------------------
-- Bits and pieces for RecordDotSyntax.
--
diff --git a/testsuite/tests/rebindable/T18324.hs b/testsuite/tests/rebindable/T18324.hs
new file mode 100644
index 0000000000..53712084ae
--- /dev/null
+++ b/testsuite/tests/rebindable/T18324.hs
@@ -0,0 +1,21 @@
+{-# LANGUAGE ImpredicativeTypes, DeriveAnyClass #-}
+-- {-# LANGUAGE MonadComprehensions, RecursiveDo #-}
+module Main where
+
+
+type Id = forall a. a -> a
+
+t :: IO Id
+t = return id
+
+p :: Id -> (Bool, Int)
+p f = (f True, f 3)
+
+foo1 = t >>= \x -> return (p x)
+
+foo2 = do { x <- t ; return (p x) }
+
+
+main = do x <- foo2
+ putStrLn $ show x
+
diff --git a/testsuite/tests/rebindable/all.T b/testsuite/tests/rebindable/all.T
index b5123102e9..0ef4c4942f 100644
--- a/testsuite/tests/rebindable/all.T
+++ b/testsuite/tests/rebindable/all.T
@@ -42,3 +42,5 @@ test('T14670', expect_broken(14670), compile, [''])
test('T19167', normal, compile, [''])
test('T19918', normal, compile_and_run, [''])
test('T20126', normal, compile_fail, [''])
+test('T18324', normal, compile_and_run, [''])
+test('pattern-fails', normal, compile_and_run, [''])
diff --git a/testsuite/tests/rebindable/pattern-fails.hs b/testsuite/tests/rebindable/pattern-fails.hs
new file mode 100644
index 0000000000..d7693a0337
--- /dev/null
+++ b/testsuite/tests/rebindable/pattern-fails.hs
@@ -0,0 +1,9 @@
+module Main where
+
+
+main :: IO ()
+main = putStrLn . show $ qqq ['c']
+
+qqq :: [a] -> Maybe (a, [a])
+qqq ts = do { (a:b:as) <- Just ts
+ ; return (a, as) }