summaryrefslogtreecommitdiff
path: root/compiler/coreSyn
diff options
context:
space:
mode:
authorSimon Peyton Jones <simonpj@microsoft.com>2017-02-17 15:03:01 +0000
committerBen Gamari <ben@smart-cactus.org>2017-02-21 09:31:17 -0500
commite790126cd57ab39649b1fd42996733fafe20eb34 (patch)
tree6f2ff2f6274834cda888532254c99513dd86b9a8 /compiler/coreSyn
parent82694e6765da1db4e7596ec410e4c41d3bf1ca94 (diff)
downloadhaskell-e790126cd57ab39649b1fd42996733fafe20eb34.tar.gz
Improve Core Lint, mainly for join points
* lintSingleBinding: check that join points have a valid join-point type (Trac #13281) * lintIdBinder: check that a JoinId is bound by a non-top-level let i.e. not a top level binder not lambda/case binder * Check for empty Rec [] bindings * Rename lintIdBndrs to lintLetBndrs
Diffstat (limited to 'compiler/coreSyn')
-rw-r--r--compiler/coreSyn/CoreLint.hs107
1 files changed, 69 insertions, 38 deletions
diff --git a/compiler/coreSyn/CoreLint.hs b/compiler/coreSyn/CoreLint.hs
index 4aa7d44713..aed938220c 100644
--- a/compiler/coreSyn/CoreLint.hs
+++ b/compiler/coreSyn/CoreLint.hs
@@ -392,7 +392,7 @@ lintCoreBindings :: DynFlags -> CoreToDo -> [Var] -> CoreProgram -> (Bag MsgDoc,
lintCoreBindings dflags pass local_in_scope binds
= initL dflags flags in_scope_set $
addLoc TopLevelBindings $
- lintIdBndrs TopLevel binders $
+ lintLetBndrs TopLevel binders $
-- Put all the top-level binders in scope at the start
-- This is because transformation rules can bring something
-- into use 'unexpectedly'
@@ -531,9 +531,12 @@ lintSingleBinding top_lvl_flag rec_flag (binder,rhs)
; flags <- getLintFlags
- -- Check that if the binder is top-level, it's not a join point
- ; checkL (not (isJoinId binder && isTopLevel top_lvl_flag))
- (mkTopJoinMsg binder)
+ -- Check that a join-point binder has a valid type
+ -- NB: lintIdBinder has checked that it is not top-level bound
+ ; case isJoinId_maybe binder of
+ Nothing -> return ()
+ Just arity -> checkL (isValidJoinPointType arity binder_ty)
+ (mkInvalidJoinPointMsg binder binder_ty)
; when (lf_check_inline_loop_breakers flags
&& isStrongLoopBreaker (idOccInfo binder)
@@ -591,11 +594,13 @@ lintRhs bndr rhs
where
lint_join_lams 0 _ _ rhs
= lintCoreExpr rhs
+
lint_join_lams n tot enforce (Lam var expr)
= addLoc (LambdaBodyOf var) $
- lintBinder var $ \ var' ->
+ lintBinder LambdaBind var $ \ var' ->
do { body_ty <- lint_join_lams (n-1) tot enforce expr
; return $ mkLamType var' body_ty }
+
lint_join_lams n tot True _other
= failWithL $ mkBadJoinArityMsg bndr tot (tot-n)
lint_join_lams _ _ False rhs
@@ -617,7 +622,7 @@ lintRhs _bndr rhs = fmap lf_check_static_ptrs getLintFlags >>= go
-- imitate @lintCoreExpr (Lam ...)@
(\var loopBinders ->
addLoc (LambdaBodyOf var) $
- lintBinder var $ \var' ->
+ lintBinder LambdaBind var $ \var' ->
do { body_ty <- loopBinders
; return $ mkLamType var' body_ty }
)
@@ -636,7 +641,7 @@ lintIdUnfolding bndr bndr_ty (CoreUnfolding { uf_tmpl = rhs, uf_src = src })
lintIdUnfolding bndr bndr_ty (DFunUnfolding { df_con = con, df_bndrs = bndrs
, df_args = args })
- = do { ty <- lintBinders bndrs $ \ bndrs' ->
+ = do { ty <- lintBinders LambdaBind bndrs $ \ bndrs' ->
do { res_ty <- lintCoreArgs (dataConRepType con) args
; return (mkLamTypes bndrs' res_ty) }
; ensureEqTys bndr_ty ty (mkRhsMsg bndr (text "dfun unfolding") ty) }
@@ -724,19 +729,26 @@ lintCoreExpr (Let (NonRec bndr rhs) body)
| isId bndr
= do { lintSingleBinding NotTopLevel NonRecursive (bndr,rhs)
; addLoc (BodyOfLetRec [bndr])
- (lintIdBndr NotTopLevel bndr $ \_ ->
+ (lintIdBndr NotTopLevel LetBind bndr $ \_ ->
addGoodJoins [bndr] $
lintCoreExpr body) }
| otherwise
= failWithL (mkLetErr bndr rhs) -- Not quite accurate
-lintCoreExpr (Let (Rec pairs) body)
- = lintIdBndrs NotTopLevel bndrs $
+lintCoreExpr e@(Let (Rec pairs) body)
+ = lintLetBndrs NotTopLevel bndrs $
addGoodJoins bndrs $
- do { checkL (null dups) (dupVars dups)
+ do { -- Check that the list of pairs is non-empty
+ checkL (not (null pairs)) (emptyRec e)
+
+ -- Check that there are no duplicated binders
+ ; checkL (null dups) (dupVars dups)
+
+ -- Check that either all the binders are joins, or none
; checkL (all isJoinId bndrs || all (not . isJoinId) bndrs) $
mkInconsistentRecMsg bndrs
+
; mapM_ (lintSingleBinding NotTopLevel Recursive) pairs
; addLoc (BodyOfLetRec bndrs) (lintCoreExpr body) }
where
@@ -753,7 +765,7 @@ lintCoreExpr e@(App _ _)
lintCoreExpr (Lam var expr)
= addLoc (LambdaBodyOf var) $
markAllJoinsBad $
- lintBinder var $ \ var' ->
+ lintBinder LambdaBind var $ \ var' ->
do { body_ty <- lintCoreExpr expr
; return $ mkLamType var' body_ty }
@@ -798,7 +810,7 @@ lintCoreExpr e@(Case scrut var alt_ty alts) =
; subst <- getTCvSubst
; ensureEqTys var_ty scrut_ty (mkScrutMsg var var_ty scrut_ty subst)
- ; lintIdBndr NotTopLevel var $ \_ ->
+ ; lintIdBndr NotTopLevel CaseBind var $ \_ ->
do { -- Check the alternatives
mapM_ (lintCoreAlt scrut_ty alt_ty) alts
; checkCaseAlts e scrut_ty alts
@@ -850,7 +862,7 @@ lintCoreFun (Lam var body) nargs
-- Note [Beta redexes]
| nargs /= 0
= addLoc (LambdaBodyOf var) $
- lintBinder var $ \ var' ->
+ lintBinder LambdaBind var $ \ var' ->
do { body_ty <- lintCoreFun body (nargs - 1)
; return $ mkLamType var' body_ty }
@@ -1117,7 +1129,7 @@ lintCoreAlt scrut_ty alt_ty alt@(DataAlt con, args, rhs)
; let con_payload_ty = piResultTys (dataConRepType con) tycon_arg_tys
-- And now bring the new binders into scope
- ; lintBinders args $ \ args' -> do
+ ; lintBinders CasePatBind args $ \ args' -> do
{ addLoc (CasePat alt) (lintAltBinders scrut_ty con_payload_ty args')
; lintAltExpr rhs alt_ty } }
@@ -1136,19 +1148,19 @@ lintCoreAlt scrut_ty alt_ty alt@(DataAlt con, args, rhs)
-- 1. Lint var types or kinds (possibly substituting)
-- 2. Add the binder to the in scope set, and if its a coercion var,
-- we may extend the substitution to reflect its (possibly) new kind
-lintBinders :: [Var] -> ([Var] -> LintM a) -> LintM a
-lintBinders [] linterF = linterF []
-lintBinders (var:vars) linterF = lintBinder var $ \var' ->
- lintBinders vars $ \ vars' ->
- linterF (var':vars')
+lintBinders :: BindingSite -> [Var] -> ([Var] -> LintM a) -> LintM a
+lintBinders _ [] linterF = linterF []
+lintBinders site (var:vars) linterF = lintBinder site var $ \var' ->
+ lintBinders site vars $ \ vars' ->
+ linterF (var':vars')
-- If you edit this function, you may need to update the GHC formalism
-- See Note [GHC Formalism]
-lintBinder :: Var -> (Var -> LintM a) -> LintM a
-lintBinder var linterF
- | isTyVar var = lintTyBndr var linterF
- | isCoVar var = lintCoBndr var linterF
- | otherwise = lintIdBndr NotTopLevel var linterF
+lintBinder :: BindingSite -> Var -> (Var -> LintM a) -> LintM a
+lintBinder site var linterF
+ | isTyVar var = lintTyBndr var linterF
+ | isCoVar var = lintCoBndr var linterF
+ | otherwise = lintIdBndr NotTopLevel site var linterF
lintTyBndr :: InTyVar -> (OutTyVar -> LintM a) -> LintM a
lintTyBndr tv thing_inside
@@ -1166,20 +1178,20 @@ lintCoBndr cv thing_inside
(text "CoVar with non-coercion type:" <+> pprTyVar cv)
; updateTCvSubst subst' (thing_inside cv') }
-lintIdBndrs :: TopLevelFlag -> [Var] -> LintM a -> LintM a
-lintIdBndrs top_lvl ids linterF
+lintLetBndrs :: TopLevelFlag -> [Var] -> LintM a -> LintM a
+lintLetBndrs top_lvl ids linterF
= go ids
where
go [] = linterF
- go (id:ids) = lintIdBndr top_lvl id $ \_ ->
- lintIdBndrs top_lvl ids $
- linterF
+ go (id:ids) = lintIdBndr top_lvl LetBind id $ \_ ->
+ go ids
-lintIdBndr :: TopLevelFlag -> InVar -> (OutVar -> LintM a) -> LintM a
+lintIdBndr :: TopLevelFlag -> BindingSite
+ -> InVar -> (OutVar -> LintM a) -> LintM a
-- Do substitution on the type of a binder and add the var with this
-- new type to the in-scope set of the second argument
-- ToDo: lint its rules
-lintIdBndr top_lvl id linterF
+lintIdBndr top_lvl bind_site id linterF
= ASSERT2( isId id, ppr id )
do { flags <- getLintFlags
; checkL (not (lf_check_global_ids flags) || isLocalId id)
@@ -1187,11 +1199,11 @@ lintIdBndr top_lvl id linterF
-- See Note [Checking for global Ids]
-- Check that if the binder is nested, it is not marked as exported
- ; checkL (not (isExportedId id) || isTopLevel top_lvl)
+ ; checkL (not (isExportedId id) || is_top_lvl)
(mkNonTopExportedMsg id)
-- Check that if the binder is nested, it does not have an external name
- ; checkL (not (isExternalName (Var.varName id)) || isTopLevel top_lvl)
+ ; checkL (not (isExternalName (Var.varName id)) || is_top_lvl)
(mkNonTopExternalNameMsg id)
; (ty, k) <- lintInTy (idType id)
@@ -1200,8 +1212,18 @@ lintIdBndr top_lvl id linterF
(text "Levity-polymorphic binder:" <+>
(ppr id <+> dcolon <+> parens (ppr ty <+> dcolon <+> ppr k)))
+ -- Check that a join-id is a not-top-level let-binding
+ ; when (isJoinId id) $
+ checkL (not is_top_lvl && is_let_bind) $
+ mkBadJoinBindMsg id
+
; let id' = setIdType id ty
; addInScopeVar id' $ (linterF id') }
+ where
+ is_top_lvl = isTopLevel top_lvl
+ is_let_bind = case bind_site of
+ LetBind -> True
+ _ -> False
{-
%************************************************************************
@@ -1387,7 +1409,7 @@ lintCoreRule _ _ (BuiltinRule {})
lintCoreRule fun fun_ty rule@(Rule { ru_name = name, ru_bndrs = bndrs
, ru_args = args, ru_rhs = rhs })
- = lintBinders bndrs $ \ _ ->
+ = lintBinders LambdaBind bndrs $ \ _ ->
do { lhs_ty <- foldM lintCoreArg fun_ty args
; rhs_ty <- case isJoinId_maybe fun of
Just join_arity
@@ -2225,6 +2247,9 @@ mkTyAppMsg ty arg_ty
hang (text "Arg type:")
4 (ppr arg_ty <+> dcolon <+> ppr (typeKind arg_ty))]
+emptyRec :: CoreExpr -> MsgDoc
+emptyRec e = hang (text "Empty Rec binding:") 2 (ppr e)
+
mkRhsMsg :: Id -> SDoc -> Type -> MsgDoc
mkRhsMsg binder what ty
= vcat
@@ -2311,9 +2336,15 @@ mkBadTyVarMsg tv
= text "Non-tyvar used in TyVarTy:"
<+> ppr tv <+> dcolon <+> ppr (varType tv)
-mkTopJoinMsg :: Var -> SDoc
-mkTopJoinMsg var
- = text "Join point at top level:" <+> ppr var
+mkBadJoinBindMsg :: Var -> SDoc
+mkBadJoinBindMsg var
+ = vcat [ text "Bad join point binding:" <+> ppr var
+ , text "Join points can be bound only by a non-top-level let" ]
+
+mkInvalidJoinPointMsg :: Var -> Type -> SDoc
+mkInvalidJoinPointMsg var ty
+ = hang (text "Join point has invalid type:")
+ 2 (ppr var <+> dcolon <+> ppr ty)
mkBadJoinArityMsg :: Var -> Int -> Int -> SDoc
mkBadJoinArityMsg var ar nlams