diff options
author | Simon Peyton Jones <simonpj@microsoft.com> | 2017-02-16 09:42:32 +0000 |
---|---|---|
committer | Simon Peyton Jones <simonpj@microsoft.com> | 2017-02-16 14:24:57 +0000 |
commit | 6bab649bde653f13c15eba30d5007bef4a9a9d3a (patch) | |
tree | 9732155c1110fa3e2b3d5e68f249eee4c47a35ed | |
parent | fc9d152b058f21ab03986ea722d0c94688b9969f (diff) | |
download | haskell-6bab649bde653f13c15eba30d5007bef4a9a9d3a.tar.gz |
Improve checking of joins in Core Lint
This patch addresses the rather expensive treatment of join points,
identified in Trac #13220 comment:17
Before we were tracking the "bad joins". Now we track the good ones.
That is easier to think about, and much more efficient; see CoreLint
Note [Join points].
On the way I did some other modest refactoring, among other things
removing a duplicated call of lintIdBndr for let-bindings.
On teh
-rw-r--r-- | compiler/coreSyn/CoreLint.hs | 253 |
1 files changed, 130 insertions, 123 deletions
diff --git a/compiler/coreSyn/CoreLint.hs b/compiler/coreSyn/CoreLint.hs index f87989d482..053ac21d15 100644 --- a/compiler/coreSyn/CoreLint.hs +++ b/compiler/coreSyn/CoreLint.hs @@ -151,7 +151,6 @@ find an occurrence of an Id, we fetch it from the in-scope set. Note [Bad unsafe coercion] ~~~~~~~~~~~~~~~~~~~~~~~~~~ - For discussion see https://ghc.haskell.org/trac/ghc/wiki/BadUnsafeCoercions Linter introduces additional rules that checks improper coercion between different types, called bad coercions. Following coercions are forbidden: @@ -170,12 +169,10 @@ different types, called bad coercions. Following coercions are forbidden: Note [Join points] ~~~~~~~~~~~~~~~~~~ - We check the rules listed in Note [Invariants on join points] in CoreSyn. The only one that causes any difficulty is the first: All occurrences must be tail -calls. To this end, along with the in-scope set, we remember in le_bad_joins the -subset of join ids that are no longer allowed because they were declared "too -far away." For example: +calls. To this end, along with the in-scope set, we remember in le_joins the +subset of in-scope Ids that are valid join ids. For example: join j x = ... in case e of @@ -184,11 +181,11 @@ far away." For example: C -> join h = jump j w in ... -- good D -> let x = jump j v in ... -- BAD -A join point remains valid in case branches, so when checking the A branch, j -is still valid. When we check the scrutinee of the inner case, however, we add j -to le_bad_joins and catch the error. Similarly, join points can occur free in -RHSes of other join points but not the RHSes of value bindings (thunks and -functions). +A join point remains valid in case branches, so when checking the A +branch, j is still valid. When we check the scrutinee of the inner +case, however, we set le_joins to empty, and catch the +error. Similarly, join points can occur free in RHSes of other join +points but not the RHSes of value bindings (thunks and functions). ************************************************************************ * * @@ -387,10 +384,9 @@ lintCoreBindings :: DynFlags -> CoreToDo -> [Var] -> CoreProgram -> (Bag MsgDoc, -- If you edit this function, you may need to update the GHC formalism -- See Note [GHC Formalism] lintCoreBindings dflags pass local_in_scope binds - = initL dflags flags $ - addLoc TopLevelBindings $ - addInScopeVars local_in_scope $ - addInScopeVars binders $ + = initL dflags flags in_scope_set $ + addLoc TopLevelBindings $ + lintIdBndrs TopLevel binders $ -- Put all the top-level binders in scope at the start -- This is because transformation rules can bring something -- into use 'unexpectedly' @@ -398,6 +394,8 @@ lintCoreBindings dflags pass local_in_scope binds ; checkL (null ext_dups) (dupExtVars ext_dups) ; mapM lint_bind binds } where + in_scope_set = mkInScopeSet (mkVarSet local_in_scope) + flags = LF { lf_check_global_ids = check_globals , lf_check_inline_loop_breakers = check_lbs , lf_check_static_ptrs = check_static_ptrs } @@ -463,9 +461,9 @@ lintUnfolding dflags locn vars expr | isEmptyBag errs = Nothing | otherwise = Just (pprMessageBag errs) where - (_warns, errs) = initL dflags defaultLintFlags linter + in_scope = mkInScopeSet vars + (_warns, errs) = initL dflags defaultLintFlags in_scope linter linter = addLoc (ImportedUnfolding locn) $ - addInScopeVarSet vars $ lintCoreExpr expr lintExpr :: DynFlags @@ -477,9 +475,9 @@ lintExpr dflags vars expr | isEmptyBag errs = Nothing | otherwise = Just (pprMessageBag errs) where - (_warns, errs) = initL dflags defaultLintFlags linter + in_scope = mkInScopeSet (mkVarSet vars) + (_warns, errs) = initL dflags defaultLintFlags in_scope linter linter = addLoc TopLevelBindings $ - addInScopeVars vars $ lintCoreExpr expr {- @@ -499,7 +497,6 @@ lintSingleBinding top_lvl_flag rec_flag (binder,rhs) = addLoc (RhsOf binder) $ -- Check the rhs do { ty <- lintRhs binder rhs - ; lint_bndr binder -- Check match to RHS type ; binder_ty <- applySubstTy (idType binder) ; ensureEqTys binder_ty ty (mkRhsMsg binder (text "RHS") ty) @@ -571,11 +568,6 @@ lintSingleBinding top_lvl_flag rec_flag (binder,rhs) -- We should check the unfolding, if any, but this is tricky because -- the unfolding is a SimplifiableCoreExpr. Give up for now. - where - -- If you edit this function, you may need to update the GHC formalism - -- See Note [GHC Formalism] - lint_bndr var | isId var = lintIdBndr top_lvl_flag var $ \_ -> return () - | otherwise = return () -- | Checks the RHS of bindings. It only differs from 'lintCoreExpr' -- in that it doesn't reject occurrences of the function 'makeStatic' when they @@ -680,7 +672,7 @@ lintCoreExpr :: CoreExpr -> LintM OutType -- If you edit this function, you may need to update the GHC formalism -- See Note [GHC Formalism] lintCoreExpr (Var var) - = lintCoreVar var 0 + = lintVarOcc var 0 lintCoreExpr (Lit lit) = return (literalType lit) @@ -726,13 +718,16 @@ lintCoreExpr (Let (NonRec bndr rhs) body) | isId bndr = do { lintSingleBinding NotTopLevel NonRecursive (bndr,rhs) ; addLoc (BodyOfLetRec [bndr]) - (lintIdBndr NotTopLevel bndr $ \_ -> lintCoreExpr body) } + (lintIdBndr NotTopLevel bndr $ \_ -> + addGoodJoins [bndr] $ + lintCoreExpr body) } | otherwise = failWithL (mkLetErr bndr rhs) -- Not quite accurate lintCoreExpr (Let (Rec pairs) body) - = lintIdBndrs bndrs $ \_ -> + = lintIdBndrs NotTopLevel bndrs $ + addGoodJoins bndrs $ do { checkL (null dups) (dupVars dups) ; checkL (all isJoinId bndrs || all (not . isJoinId) bndrs) $ mkInconsistentRecMsg bndrs @@ -812,51 +807,38 @@ lintCoreExpr (Coercion co) = do { (k1, k2, ty1, ty2, role) <- lintInCo co ; return (mkHeteroCoercionType role k1 k2 ty1 ty2) } -lintCoreVar :: Var -> Int -- Number of arguments (type or value) being passed +---------------------- +lintVarOcc :: Var -> Int -- Number of arguments (type or value) being passed -> LintM Type -- returns type of the *variable* -lintCoreVar var nargs +lintVarOcc var nargs = do { checkL (isNonCoVarId var) (text "Non term variable" <+> ppr var) - ; lf <- getLintFlags + -- Cneck that the type of the occurrence is the same + -- as the type of the binding site + ; ty <- applySubstTy (idType var) + ; var' <- lookupIdInScope var + ; let ty' = idType var' + ; ensureEqTys ty ty' $ mkBndrOccTypeMismatchMsg var' var ty' ty + -- Check for a nested occurrence of the StaticPtr constructor. -- See Note [Checking StaticPtrs]. + ; lf <- getLintFlags ; when (nargs /= 0 && lf_check_static_ptrs lf /= AllowAnywhere) $ checkL (idName var /= makeStaticName) $ text "Found makeStatic nested in an expression" ; checkDeadIdOcc var - ; ty <- applySubstTy (idType var) - ; var' <- lookupIdInScope var - ; let ty' = idType var' - ; ensureEqTys ty ty' $ mkBndrOccTypeMismatchMsg var' var ty' ty - ; mb_join_arity - <- case isJoinId_maybe var' of - Just join_arity -> - do { checkL (isJoinId_maybe var == Just join_arity) $ - mkJoinBndrOccMismatchMsg var' var - ; return $ Just join_arity } - Nothing -> - case tailCallInfo (idOccInfo var') of - AlwaysTailCalled join_arity -> return $ Just join_arity - -- This function will be turned into a join point by the - -- simplifier; typecheck it as if it already were one - NoTailCallInfo -> return $ Nothing - ; case mb_join_arity of - Just join_arity -> - do { bad <- isBadJoin var' - ; checkL (not bad) $ mkJoinOutOfScopeMsg var' - ; checkL (nargs == join_arity) $ - mkBadJumpMsg var' join_arity nargs } - Nothing -> - do { checkL (not (isJoinId var)) $ - mkJoinBndrOccMismatchMsg var' var } + ; checkJoinOcc var nargs + ; return (idType var') } -lintCoreFun :: CoreExpr -> Int -- Number of arguments (type or val) being passed - -> LintM Type -- returns type of the *function* +lintCoreFun :: CoreExpr + -> Int -- Number of arguments (type or val) being passed + -> LintM Type -- Returns type of the *function* lintCoreFun (Var var) nargs - = lintCoreVar var nargs + = lintVarOcc var nargs + lintCoreFun (Lam var body) nargs -- Act like lintCoreExpr of Lam, but *don't* call markAllJoinsBad; see -- Note [Beta redexes] @@ -865,10 +847,47 @@ lintCoreFun (Lam var body) nargs lintBinder var $ \ var' -> do { body_ty <- lintCoreFun body (nargs - 1) ; return $ mkLamType var' body_ty } + lintCoreFun expr nargs = markAllJoinsBadIf (nargs /= 0) $ lintCoreExpr expr +------------------ +checkDeadIdOcc :: Id -> LintM () +-- Occurrences of an Id should never be dead.... +-- except when we are checking a case pattern +checkDeadIdOcc id + | isDeadOcc (idOccInfo id) + = do { in_case <- inCasePat + ; checkL in_case + (text "Occurrence of a dead Id" <+> ppr id) } + | otherwise + = return () + +------------------ +checkJoinOcc :: Id -> JoinArity -> LintM () +-- Check that if the occurrence is a JoinId, then so is the +-- binding site, and it's a valid join Id +checkJoinOcc var n_args + | Just join_arity_occ <- isJoinId_maybe var + = do { mb_join_arity_bndr <- lookupJoinId var + ; case mb_join_arity_bndr of { + Nothing -> -- Binder is not a join point + addErrL (invalidJoinOcc var) ; + + Just join_arity_bndr -> + + do { checkL (join_arity_bndr == join_arity_occ) $ + -- Arity differs at binding site and occurrence + mkJoinBndrOccMismatchMsg var join_arity_bndr join_arity_occ + + ; checkL (n_args == join_arity_occ) $ + -- Arity doesn't match #args + mkBadJumpMsg var join_arity_occ n_args } } } + + | otherwise + = return () + {- Note [No alternatives lint check] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1010,17 +1029,6 @@ lintTyKind tyvar arg_ty where tyvar_kind = tyVarKind tyvar -checkDeadIdOcc :: Id -> LintM () --- Occurrences of an Id should never be dead.... --- except when we are checking a case pattern -checkDeadIdOcc id - | isDeadOcc (idOccInfo id) - = do { in_case <- inCasePat - ; checkL in_case - (text "Occurrence of a dead Id" <+> ppr id) } - | otherwise - = return () - {- ************************************************************************ * * @@ -1152,21 +1160,22 @@ lintCoBndr cv thing_inside (text "CoVar with non-coercion type:" <+> pprTyVar cv) ; updateTCvSubst subst' (thing_inside cv') } -lintIdBndrs :: [Var] -> ([Var] -> LintM a) -> LintM a -lintIdBndrs ids linterF +lintIdBndrs :: TopLevelFlag -> [Var] -> LintM a -> LintM a +lintIdBndrs top_lvl ids linterF = go ids where - go [] = linterF [] - go (id:ids) = lintIdBndr NotTopLevel id $ \id -> - lintIdBndrs ids $ \ids -> - linterF (id:ids) + go [] = linterF + go (id:ids) = lintIdBndr top_lvl id $ \_ -> + lintIdBndrs top_lvl ids $ + linterF lintIdBndr :: TopLevelFlag -> InVar -> (OutVar -> LintM a) -> LintM a -- Do substitution on the type of a binder and add the var with this -- new type to the in-scope set of the second argument -- ToDo: lint its rules lintIdBndr top_lvl id linterF - = do { flags <- getLintFlags + = ASSERT2( isId id, ppr id ) + do { flags <- getLintFlags ; checkL (not (lf_check_global_ids flags) || isLocalId id) (text "Non-local Id binder" <+> ppr id) -- See Note [Checking for global Ids] @@ -1784,7 +1793,8 @@ data LintEnv , le_subst :: TCvSubst -- Current type substitution; we also use this -- to keep track of all the variables in scope, -- both Ids and TyVars - , le_bad_joins :: IdSet -- Join points that are no longer valid + , le_joins :: IdSet -- Join points in scope that are valid + -- A subset of teh InScopeSet in le_subst -- See Note [Join points] , le_dynflags :: DynFlags -- DynamicFlags } @@ -1891,13 +1901,17 @@ data LintLocInfo | InType Type -- Inside a type | InCo Coercion -- Inside a coercion -initL :: DynFlags -> LintFlags -> LintM a -> WarnsAndErrs -- Errors and warnings -initL dflags flags m +initL :: DynFlags -> LintFlags -> InScopeSet + -> LintM a -> WarnsAndErrs -- Errors and warnings +initL dflags flags in_scope m = case unLintM m env (emptyBag, emptyBag) of (_, errs) -> errs where - env = LE { le_flags = flags, le_subst = emptyTCvSubst, le_loc = [] - , le_dynflags = dflags, le_bad_joins = emptyVarSet } + env = LE { le_flags = flags + , le_subst = mkEmptyTCvSubst in_scope + , le_joins = emptyVarSet + , le_loc = [] + , le_dynflags = dflags } getLintFlags :: LintM LintFlags getLintFlags = LintM $ \ env errs -> (Just (le_flags env), errs) @@ -1952,29 +1966,12 @@ inCasePat = LintM $ \ env errs -> (Just (is_case_pat env), errs) is_case_pat (LE { le_loc = CasePat {} : _ }) = True is_case_pat _other = False -addInScopeVars :: [Var] -> LintM a -> LintM a -addInScopeVars vars m - = LintM $ \ env errs -> - unLintM m (env { le_subst = extendTCvInScopeList (le_subst env) vars - , le_bad_joins = bad_joins' env }) - errs - where - bad_joins' env = delVarSetList (le_bad_joins env) (filter isJoinId vars) - -addInScopeVarSet :: VarSet -> LintM a -> LintM a -addInScopeVarSet vars m - = LintM $ \ env errs -> - unLintM m (env { le_subst = extendTCvInScopeSet (le_subst env) vars }) - errs - addInScopeVar :: Var -> LintM a -> LintM a addInScopeVar var m = LintM $ \ env errs -> - unLintM m (env { le_subst = extendTCvInScope (le_subst env) var - , le_bad_joins = bad_joins' env }) errs - where - bad_joins' env | isJoinId var = delVarSet (le_bad_joins env) var - | otherwise = le_bad_joins env + unLintM m (env { le_subst = extendTCvInScope (le_subst env) var + , le_joins = delVarSet (le_joins env) var + }) errs extendSubstL :: TyVar -> Type -> LintM a -> LintM a extendSubstL tv ty m @@ -1987,16 +1984,25 @@ updateTCvSubst subst' m markAllJoinsBad :: LintM a -> LintM a markAllJoinsBad m - = LintM $ \ env errs -> unLintM m (marked env) errs - where - marked env = env { le_bad_joins = filterVarSet isJoinId in_set } - where - in_set = getInScopeVars (getTCvInScope (le_subst env)) + = LintM $ \ env errs -> unLintM m (env { le_joins = emptyVarSet }) errs markAllJoinsBadIf :: Bool -> LintM a -> LintM a markAllJoinsBadIf True m = markAllJoinsBad m markAllJoinsBadIf False m = m +addGoodJoins :: [Var] -> LintM a -> LintM a +addGoodJoins vars thing_inside + | null join_ids + = thing_inside + | otherwise + = LintM $ \ env errs -> unLintM thing_inside (add_joins env) errs + where + add_joins env = env { le_joins = le_joins env `extendVarSetList` join_ids } + join_ids = filter isJoinId vars + +getValidJoins :: LintM IdSet +getValidJoins = LintM (\ env errs -> (Just (le_joins env), errs)) + getTCvSubst :: LintM TCvSubst getTCvSubst = LintM (\ env errs -> (Just (le_subst env), errs)) @@ -2022,9 +2028,14 @@ lookupIdInScope id where out_of_scope = pprBndr LetBind id <+> text "is out of scope" -isBadJoin :: Id -> LintM Bool -isBadJoin id = LintM $ \env errs -> (Just (id `elemVarSet` le_bad_joins env), - errs) +lookupJoinId :: Id -> LintM (Maybe JoinArity) +-- Look up an Id which should be a join point, valid here +-- If so, return its arity, if not return Nothing +lookupJoinId id + = do { join_set <- getValidJoins + ; case lookupVarSet join_set id of + Just id' -> return (isJoinId_maybe id') + Nothing -> return Nothing } lintTyCoVarInScope :: Var -> LintM () lintTyCoVarInScope v = lintInScope (text "is out of scope") v @@ -2294,9 +2305,10 @@ mkBadJoinArityMsg var ar nlams text "Join arity:" <+> ppr ar, text "Number of lambdas:" <+> ppr nlams ] -mkJoinOutOfScopeMsg :: Var -> SDoc -mkJoinOutOfScopeMsg var - = text "Join variable no longer in scope:" <+> ppr var +invalidJoinOcc :: Var -> SDoc +invalidJoinOcc var + = vcat [ text "Invalid occurrence of a join variable:" <+> ppr var + , text "The binder is either not a join point, or not valid here" ] mkBadJumpMsg :: Var -> Int -> Int -> SDoc mkBadJumpMsg var ar nargs @@ -2312,17 +2324,12 @@ mkInconsistentRecMsg bndrs where ppr_with_details bndr = ppr bndr <> ppr (idDetails bndr) -mkJoinBndrOccMismatchMsg :: Var -> Var -> SDoc -mkJoinBndrOccMismatchMsg bndr var - = vcat [ text "Mismatch in join point status between binder and occurrence", - text "Var:" <+> ppr bndr, - text "Binder:" <+> ppr_join_status bndr, - text "Occ:" <+> ppr_join_status var ] - where - ppr_join_status v = case details of JoinId _ -> ppr details - _ -> text "not a join id" - where - details = idDetails v +mkJoinBndrOccMismatchMsg :: Var -> JoinArity -> JoinArity -> SDoc +mkJoinBndrOccMismatchMsg bndr join_arity_bndr join_arity_occ + = vcat [ text "Mismatch in join point arity between binder and occurrence" + , text "Var:" <+> ppr bndr + , text "Arity at binding site:" <+> ppr join_arity_bndr + , text "Arity at occurrence: " <+> ppr join_arity_occ ] mkBndrOccTypeMismatchMsg :: Var -> Var -> OutType -> OutType -> SDoc mkBndrOccTypeMismatchMsg bndr var bndr_ty var_ty |