diff options
author | Simon Peyton Jones <simonpj@microsoft.com> | 2020-07-14 15:37:18 +0100 |
---|---|---|
committer | Simon Peyton Jones <simonpj@microsoft.com> | 2020-07-16 15:27:20 +0100 |
commit | 45af3f080462c297eb3c186303040e21a880d2d2 (patch) | |
tree | a790a3cf48cb4b9e24854bb2f52d4bcc67a37ab8 /compiler | |
parent | ae11bdfd98a10266bfc7de9e16b500be220307ac (diff) | |
download | haskell-wip/T18449.tar.gz |
Refactor the simplification of join binderswip/T18449
This MR (for #18449) refactors the Simplifier's treatment
of join-point binders.
Specifically, it puts together, into
GHC.Core.Opt.Simplify.Env.adjustJoinPointType
two currently-separate ways in which we adjust the type of
a join point. As the comment says:
-- (adjustJoinPointType mult new_res_ty join_id) does two things:
--
-- 1. Set the return type of the join_id to new_res_ty
-- See Note [Return type for join points]
--
-- 2. Adjust the multiplicity of arrows in join_id's type, as
-- directed by 'mult'. See Note [Scaling join point arguments]
I think this actually fixes a latent bug, by ensuring that the
seIdSubst and seInScope have the right multiplicity on the type
of join points.
I did some tidying up while I was at it. No more
setJoinResTy, or modifyJoinResTy: instead it's done locally in
Simplify.Env.adjustJoinPointType
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/GHC/Core.hs | 21 | ||||
-rw-r--r-- | compiler/GHC/Core/Opt/Simplify.hs | 51 | ||||
-rw-r--r-- | compiler/GHC/Core/Opt/Simplify/Env.hs | 164 | ||||
-rw-r--r-- | compiler/GHC/Core/Type.hs | 25 |
4 files changed, 141 insertions, 120 deletions
diff --git a/compiler/GHC/Core.hs b/compiler/GHC/Core.hs index bc61929ed4..a99a1adfd6 100644 --- a/compiler/GHC/Core.hs +++ b/compiler/GHC/Core.hs @@ -798,15 +798,18 @@ and case-of-case] in GHC.Core.Opt.Simplify): in jump j z w -The body of the join point now returns a Bool, so the label `j` has to have its -type updated accordingly. Inconvenient though this may be, it has the advantage -that 'GHC.Core.Utils.exprType' can still return a type for any expression, including -a jump. - -This differs from the paper (see Note [Invariants on join points]). In the -paper, we instead give j the type `Int -> Bool -> forall a. a`. Then each jump -carries the "return type" as a parameter, exactly the way other non-returning -functions like `error` work: +The body of the join point now returns a Bool, so the label `j` has to +have its type updated accordingly, which is done by +GHC.Core.Opt.Simplify.Env.adjustJoinPointType. Inconvenient though +this may be, it has the advantage that 'GHC.Core.Utils.exprType' can +still return a type for any expression, including a jump. + +Relationship to the paper + +This plan differs from the paper (see Note [Invariants on join +points]). In the paper, we instead give j the type `Int -> Bool -> +forall a. a`. Then each jump carries the "return type" as a parameter, +exactly the way other non-returning functions like `error` work: case (join j :: Int -> Bool -> forall a. a diff --git a/compiler/GHC/Core/Opt/Simplify.hs b/compiler/GHC/Core/Opt/Simplify.hs index 355dd256c1..abfad1940f 100644 --- a/compiler/GHC/Core/Opt/Simplify.hs +++ b/compiler/GHC/Core/Opt/Simplify.hs @@ -55,7 +55,7 @@ import GHC.Core.Rules ( lookupRule, getRules, initRuleOpts ) import GHC.Types.Basic import GHC.Utils.Monad ( mapAccumLM, liftIO ) import GHC.Types.Var ( isTyCoVar ) -import GHC.Data.Maybe ( orElse, fromMaybe ) +import GHC.Data.Maybe ( orElse ) import Control.Monad import GHC.Utils.Outputable import GHC.Data.FastString @@ -63,7 +63,6 @@ import GHC.Utils.Misc import GHC.Utils.Error import GHC.Unit.Module ( moduleName, pprModuleName ) import GHC.Core.Multiplicity -import GHC.Core.TyCo.Rep ( TyCoBinder(..) ) import GHC.Builtin.PrimOps ( PrimOp (SeqOp) ) @@ -361,44 +360,8 @@ simplJoinBind :: SimplEnv simplJoinBind env cont old_bndr new_bndr rhs rhs_se = do { let rhs_env = rhs_se `setInScopeFromE` env ; rhs' <- simplJoinRhs rhs_env old_bndr rhs cont - ; let mult = contHoleScaling cont - arity = fromMaybe (pprPanic "simplJoinBind" (ppr new_bndr)) $ - isJoinIdDetails_maybe (idDetails new_bndr) - new_type = scaleJoinPointType mult arity (varType new_bndr) - new_bndr' = setIdType new_bndr new_type - ; completeBind env NotTopLevel (Just cont) old_bndr new_bndr' rhs' } + ; completeBind env NotTopLevel (Just cont) old_bndr new_bndr rhs' } -{- -Note [Scaling join point arguments] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Consider a join point which is linear in its variable, in some context E: - -E[join j :: a #-> a - j x = x - in case v of - A -> j 'x' - B -> <blah>] - -The simplifier changes to: - -join j :: a #-> a - j x = E[x] -in case v of - A -> j 'x' - B -> E[<blah>] - -If E uses its argument in a nonlinear way (e.g. a case['Many]), then -this is wrong: the join point has to change its type to a -> a. -Otherwise, we'd get a linearity error. - -See also Note [Return type for join points] and Note [Join points and case-of-case]. --} -scaleJoinPointType :: Mult -> Int -> Type -> Type -scaleJoinPointType mult arity ty | arity == 0 = ty - | otherwise = case splitPiTy ty of - (binder, ty') -> mkPiTy (scaleBinder binder) (scaleJoinPointType mult (arity-1) ty') - where scaleBinder (Anon af t) = Anon af (scaleScaled mult t) - scaleBinder b@(Named _) = b -------------------------- simplNonRecX :: SimplEnv -> InId -- Old binder; not a JoinId @@ -1726,8 +1689,9 @@ simplNonRecJoinPoint env bndr rhs body cont = wrapJoinCont env cont $ \ env cont -> do { -- We push join_cont into the join RHS and the body; -- and wrap wrap_cont around the whole thing - ; let res_ty = contResultType cont - ; (env1, bndr1) <- simplNonRecJoinBndr env res_ty bndr + ; let mult = contHoleScaling cont + res_ty = contResultType cont + ; (env1, bndr1) <- simplNonRecJoinBndr env bndr mult res_ty ; (env2, bndr2) <- addBndrRules env1 bndr bndr1 (Just cont) ; (floats1, env3) <- simplJoinBind env2 cont bndr bndr2 rhs env ; (floats2, body') <- simplExprF env3 body cont @@ -1740,9 +1704,10 @@ simplRecJoinPoint :: SimplEnv -> [(InId, InExpr)] -> SimplM (SimplFloats, OutExpr) simplRecJoinPoint env pairs body cont = wrapJoinCont env cont $ \ env cont -> - do { let bndrs = map fst pairs + do { let bndrs = map fst pairs + mult = contHoleScaling cont res_ty = contResultType cont - ; env1 <- simplRecJoinBndrs env res_ty bndrs + ; env1 <- simplRecJoinBndrs env bndrs mult res_ty -- NB: bndrs' don't have unfoldings or rules -- We add them as we go down ; (floats1, env2) <- simplRecBind env1 NotTopLevel (Just cont) pairs diff --git a/compiler/GHC/Core/Opt/Simplify/Env.hs b/compiler/GHC/Core/Opt/Simplify/Env.hs index c5f8193b4f..237739e23c 100644 --- a/compiler/GHC/Core/Opt/Simplify/Env.hs +++ b/compiler/GHC/Core/Opt/Simplify/Env.hs @@ -51,6 +51,7 @@ import GHC.Core.Opt.Simplify.Monad import GHC.Core.Opt.Monad ( SimplMode(..) ) import GHC.Core import GHC.Core.Utils +import GHC.Core.Multiplicity ( scaleScaled ) import GHC.Types.Var import GHC.Types.Var.Env import GHC.Types.Var.Set @@ -59,6 +60,7 @@ import GHC.Types.Id as Id import GHC.Core.Make ( mkWildValBinder ) import GHC.Driver.Session ( DynFlags ) import GHC.Builtin.Types +import GHC.Core.TyCo.Rep ( TyCoBinder(..) ) import qualified GHC.Core.Type as Type import GHC.Core.Type hiding ( substTy, substTyVar, substTyVarBndr, extendTvSubst, extendCvSubst ) import qualified GHC.Core.Coercion as Coercion @@ -741,24 +743,14 @@ simplBinder :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr) simplBinder env bndr | isTyVar bndr = do { let (env', tv) = substTyVarBndr env bndr ; seqTyVar tv `seq` return (env', tv) } - | otherwise = do { let (env', id) = substIdBndr Nothing env bndr + | otherwise = do { let (env', id) = substIdBndr env bndr ; seqId id `seq` return (env', id) } --------------- simplNonRecBndr :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr) -- A non-recursive let binder simplNonRecBndr env id - = do { let (env1, id1) = substIdBndr Nothing env id - ; seqId id1 `seq` return (env1, id1) } - ---------------- -simplNonRecJoinBndr :: SimplEnv -> OutType -> InBndr - -> SimplM (SimplEnv, OutBndr) --- A non-recursive let binder for a join point; --- context being pushed inward may change the type --- See Note [Return type for join points] -simplNonRecJoinBndr env res_ty id - = do { let (env1, id1) = substIdBndr (Just res_ty) env id + = do { let (env1, id1) = substIdBndr env id ; seqId id1 `seq` return (env1, id1) } --------------- @@ -766,31 +758,20 @@ simplRecBndrs :: SimplEnv -> [InBndr] -> SimplM SimplEnv -- Recursive let binders simplRecBndrs env@(SimplEnv {}) ids = ASSERT(all (not . isJoinId) ids) - do { let (env1, ids1) = mapAccumL (substIdBndr Nothing) env ids + do { let (env1, ids1) = mapAccumL substIdBndr env ids ; seqIds ids1 `seq` return env1 } ---------------- -simplRecJoinBndrs :: SimplEnv -> OutType -> [InBndr] -> SimplM SimplEnv --- Recursive let binders for join points; --- context being pushed inward may change types --- See Note [Return type for join points] -simplRecJoinBndrs env@(SimplEnv {}) res_ty ids - = ASSERT(all isJoinId ids) - do { let (env1, ids1) = mapAccumL (substIdBndr (Just res_ty)) env ids - ; seqIds ids1 `seq` return env1 } --------------- -substIdBndr :: Maybe OutType -> SimplEnv -> InBndr -> (SimplEnv, OutBndr) +substIdBndr :: SimplEnv -> InBndr -> (SimplEnv, OutBndr) -- Might be a coercion variable -substIdBndr new_res_ty env bndr +substIdBndr env bndr | isCoVar bndr = substCoVarBndr env bndr - | otherwise = substNonCoVarIdBndr new_res_ty env bndr + | otherwise = substNonCoVarIdBndr env bndr --------------- substNonCoVarIdBndr - :: Maybe OutType -- New result type, if a join binder - -- See Note [Return type for join points] - -> SimplEnv + :: SimplEnv -> InBndr -- Env and binder to transform -> (SimplEnv, OutBndr) -- Clone Id if necessary, substitute its type @@ -810,25 +791,28 @@ substNonCoVarIdBndr -- Similar to GHC.Core.Subst.substIdBndr, except that -- the type of id_subst differs -- all fragile info is zapped -substNonCoVarIdBndr new_res_ty - env@(SimplEnv { seInScope = in_scope - , seIdSubst = id_subst }) - old_id +substNonCoVarIdBndr env id = subst_id_bndr env id (\x -> x) + +subst_id_bndr :: SimplEnv + -> InBndr -- Env and binder to transform + -> (OutId -> OutId) -- Adjust the type + -> (SimplEnv, OutBndr) +subst_id_bndr env@(SimplEnv { seInScope = in_scope, seIdSubst = id_subst }) + old_id adjust_type = ASSERT2( not (isCoVar old_id), ppr old_id ) (env { seInScope = in_scope `extendInScopeSet` new_id, seIdSubst = new_subst }, new_id) + -- It's important that both seInScope and seIdSubt are updated with + -- the new_id, /after/ applying adjust_type. That's why adjust_type + -- is done here. If we did adjust_type in simplJoinBndr (the only + -- place that gives a non-identity adjust_type) we'd have to fiddle + -- afresh with both seInScope and seIdSubst where - id1 = uniqAway in_scope old_id - id2 = substIdType env id1 - - id3 | Just res_ty <- new_res_ty - = id2 `setIdType` setJoinResTy (idJoinArity id2) res_ty (idType id2) - -- See Note [Return type for join points] - | otherwise - = id2 - - new_id = zapFragileIdInfo id3 -- Zaps rules, worker-info, unfolding - -- and fragile OccInfo + id1 = uniqAway in_scope old_id + id2 = substIdType env id1 + id3 = zapFragileIdInfo id2 -- Zaps rules, worker-info, unfolding + -- and fragile OccInfo + new_id = adjust_type id3 -- Extend the substitution if the unique has changed, -- or there's some useful occurrence information @@ -889,6 +873,100 @@ the letrec. -} +{- ********************************************************************* +* * + Join points +* * +********************************************************************* -} + +simplNonRecJoinBndr :: SimplEnv -> InBndr + -> Mult -> OutType + -> SimplM (SimplEnv, OutBndr) + +-- A non-recursive let binder for a join point; +-- context being pushed inward may change the type +-- See Note [Return type for join points] +simplNonRecJoinBndr env id mult res_ty + = do { let (env1, id1) = simplJoinBndr mult res_ty env id + ; seqId id1 `seq` return (env1, id1) } + +simplRecJoinBndrs :: SimplEnv -> [InBndr] + -> Mult -> OutType + -> SimplM SimplEnv +-- Recursive let binders for join points; +-- context being pushed inward may change types +-- See Note [Return type for join points] +simplRecJoinBndrs env@(SimplEnv {}) ids mult res_ty + = ASSERT(all isJoinId ids) + do { let (env1, ids1) = mapAccumL (simplJoinBndr mult res_ty) env ids + ; seqIds ids1 `seq` return env1 } + +--------------- +simplJoinBndr :: Mult -> OutType + -> SimplEnv -> InBndr + -> (SimplEnv, OutBndr) +simplJoinBndr mult res_ty env id + = subst_id_bndr env id (adjustJoinPointType mult res_ty) + +--------------- +adjustJoinPointType :: Mult + -> Type -- New result type + -> Id -- Old join-point Id + -> Id -- Adjusted jont-point Id +-- (adjustJoinPointType mult new_res_ty join_id) does two things: +-- +-- 1. Set the return type of the join_id to new_res_ty +-- See Note [Return type for join points] +-- +-- 2. Adjust the multiplicity of arrows in join_id's type, as +-- directed by 'mult'. See Note [Scaling join point arguments] +-- +-- INVARIANT: If any of the first n binders are foralls, those tyvars +-- cannot appear in the original result type. See isValidJoinPointType. +adjustJoinPointType mult new_res_ty join_id + = ASSERT( isJoinId join_id ) + setIdType join_id new_join_ty + where + orig_ar = idJoinArity join_id + orig_ty = idType join_id + + new_join_ty = go orig_ar orig_ty + + go 0 _ = new_res_ty + go n ty | Just (arg_bndr, res_ty) <- splitPiTy_maybe ty + = mkPiTy (scale_bndr arg_bndr) $ + go (n-1) res_ty + | otherwise + = pprPanic "adjustJoinPointType" (ppr orig_ar <+> ppr orig_ty) + + scale_bndr (Anon af t) = Anon af (scaleScaled mult t) + scale_bndr b@(Named _) = b + +{- Note [Scaling join point arguments] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Consider a join point which is linear in its variable, in some context E: + +E[join j :: a #-> a + j x = x + in case v of + A -> j 'x' + B -> <blah>] + +The simplifier changes to: + +join j :: a #-> a + j x = E[x] +in case v of + A -> j 'x' + B -> E[<blah>] + +If E uses its argument in a nonlinear way (e.g. a case['Many]), then +this is wrong: the join point has to change its type to a -> a. +Otherwise, we'd get a linearity error. + +See also Note [Return type for join points] and Note [Join points and case-of-case]. +-} + {- ************************************************************************ * * diff --git a/compiler/GHC/Core/Type.hs b/compiler/GHC/Core/Type.hs index 5bb11a9ee7..f65c7c5770 100644 --- a/compiler/GHC/Core/Type.hs +++ b/compiler/GHC/Core/Type.hs @@ -83,8 +83,6 @@ module GHC.Core.Type ( tyConArgFlags, appTyArgFlags, synTyConResKind, - modifyJoinResTy, setJoinResTy, - -- ** Analyzing types TyCoMapper(..), mapTyCo, mapTyCoX, TyCoFolder(..), foldTyCo, @@ -2860,29 +2858,6 @@ splitVisVarsOfType orig_ty = Pair invis_vars vis_vars splitVisVarsOfTypes :: [Type] -> Pair TyCoVarSet splitVisVarsOfTypes = foldMap splitVisVarsOfType -modifyJoinResTy :: Int -- Number of binders to skip - -> (Type -> Type) -- Function to apply to result type - -> Type -- Type of join point - -> Type -- New type --- INVARIANT: If any of the first n binders are foralls, those tyvars cannot --- appear in the original result type. See isValidJoinPointType. -modifyJoinResTy orig_ar f orig_ty - = go orig_ar orig_ty - where - go 0 ty = f ty - go n ty | Just (arg_bndr, res_ty) <- splitPiTy_maybe ty - = mkPiTy arg_bndr (go (n-1) res_ty) - | otherwise - = pprPanic "modifyJoinResTy" (ppr orig_ar <+> ppr orig_ty) - -setJoinResTy :: Int -- Number of binders to skip - -> Type -- New result type - -> Type -- Type of join point - -> Type -- New type --- INVARIANT: Same as for modifyJoinResTy -setJoinResTy ar new_res_ty ty - = modifyJoinResTy ar (const new_res_ty) ty - {- ************************************************************************ * * |