summaryrefslogtreecommitdiff
path: root/compiler/GHC
diff options
context:
space:
mode:
authorSimon Peyton Jones <simonpj@microsoft.com>2020-04-21 23:57:33 +0100
committerMarge Bot <ben+marge-bot@smart-cactus.org>2020-05-29 01:39:19 -0400
commit0e3361ca414012e3ec40a260c2323986ce770db6 (patch)
tree2d76064448dd4d24444a9a2232dbd8f09f15217f /compiler/GHC
parentbbeb2389596df61ace5778ec580895ea32cc3c6f (diff)
downloadhaskell-0e3361ca414012e3ec40a260c2323986ce770db6.tar.gz
Make Lint check return type of a join point
Consider join x = rhs in body It's important that the type of 'rhs' is the same as the type of 'body', but Lint wasn't checking that invariant. Now it does! This was exposed by investigation into !3113.
Diffstat (limited to 'compiler/GHC')
-rw-r--r--compiler/GHC/Core/Lint.hs38
1 files changed, 31 insertions, 7 deletions
diff --git a/compiler/GHC/Core/Lint.hs b/compiler/GHC/Core/Lint.hs
index fe3bec8a48..601c0fc38a 100644
--- a/compiler/GHC/Core/Lint.hs
+++ b/compiler/GHC/Core/Lint.hs
@@ -462,7 +462,7 @@ lintCoreBindings dflags pass local_in_scope binds
addLoc TopLevelBindings $
do { checkL (null dups) (dupVars dups)
; checkL (null ext_dups) (dupExtVars ext_dups)
- ; lintRecBindings TopLevel all_pairs $
+ ; lintRecBindings TopLevel all_pairs $ \_ ->
return () }
where
all_pairs = flattenBinds binds
@@ -573,11 +573,11 @@ Check a core binding, returning the list of variables bound.
-}
lintRecBindings :: TopLevelFlag -> [(Id, CoreExpr)]
- -> LintM a -> LintM a
+ -> ([LintedId] -> LintM a) -> LintM a
lintRecBindings top_lvl pairs thing_inside
= lintIdBndrs top_lvl bndrs $ \ bndrs' ->
do { zipWithM_ lint_pair bndrs' rhss
- ; thing_inside }
+ ; thing_inside bndrs' }
where
(bndrs, rhss) = unzip pairs
lint_pair bndr' rhs
@@ -585,6 +585,12 @@ lintRecBindings top_lvl pairs thing_inside
do { rhs_ty <- lintRhs bndr' rhs -- Check the rhs
; lintLetBind top_lvl Recursive bndr' rhs rhs_ty }
+lintLetBody :: [LintedId] -> CoreExpr -> LintM LintedType
+lintLetBody bndrs body
+ = do { body_ty <- addLoc (BodyOfLetRec bndrs) (lintCoreExpr body)
+ ; mapM_ (lintJoinBndrType body_ty) bndrs
+ ; return body_ty }
+
lintLetBind :: TopLevelFlag -> RecFlag -> LintedId
-> CoreExpr -> LintedType -> LintM ()
-- Binder's type, and the RHS, have already been linted
@@ -831,7 +837,7 @@ lintCoreExpr (Let (NonRec bndr rhs) body)
-- Now lint the binder
; lintBinder LetBind bndr $ \bndr' ->
do { lintLetBind NotTopLevel NonRecursive bndr' rhs rhs_ty
- ; addLoc (BodyOfLetRec [bndr]) (lintCoreExpr body) } }
+ ; lintLetBody [bndr'] body } }
| otherwise
= failWithL (mkLetErr bndr rhs) -- Not quite accurate
@@ -848,9 +854,8 @@ lintCoreExpr e@(Let (Rec pairs) body)
; checkL (all isJoinId bndrs || all (not . isJoinId) bndrs) $
mkInconsistentRecMsg bndrs
- ; lintRecBindings NotTopLevel pairs $
- addLoc (BodyOfLetRec bndrs) $
- lintCoreExpr body }
+ ; lintRecBindings NotTopLevel pairs $ \ bndrs' ->
+ lintLetBody bndrs' body }
where
bndrs = map fst pairs
@@ -951,6 +956,25 @@ checkDeadIdOcc id
= return ()
------------------
+lintJoinBndrType :: LintedType -- Type of the body
+ -> LintedId -- Possibly a join Id
+ -> LintM ()
+-- Checks that the return type of a join Id matches the body
+-- E.g. join j x = rhs in body
+-- The type of 'rhs' must be the same as the type of 'body'
+lintJoinBndrType body_ty bndr
+ | Just arity <- isJoinId_maybe bndr
+ , let bndr_ty = idType bndr
+ , (bndrs, res) <- splitPiTys bndr_ty
+ = checkL (length bndrs >= arity
+ && body_ty `eqType` mkPiTys (drop arity bndrs) res) $
+ hang (text "Join point returns different type than body")
+ 2 (vcat [ text "Join bndr:" <+> ppr bndr <+> dcolon <+> ppr (idType bndr)
+ , text "Join arity:" <+> ppr arity
+ , text "Body type:" <+> ppr body_ty ])
+ | otherwise
+ = return ()
+
checkJoinOcc :: Id -> JoinArity -> LintM ()
-- Check that if the occurrence is a JoinId, then so is the
-- binding site, and it's a valid join Id