summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSimon Peyton Jones <simonpj@microsoft.com>2020-07-14 15:37:18 +0100
committerSimon Peyton Jones <simonpj@microsoft.com>2020-07-16 15:27:20 +0100
commit45af3f080462c297eb3c186303040e21a880d2d2 (patch)
treea790a3cf48cb4b9e24854bb2f52d4bcc67a37ab8
parentae11bdfd98a10266bfc7de9e16b500be220307ac (diff)
downloadhaskell-wip/T18449.tar.gz
Refactor the simplification of join binderswip/T18449
This MR (for #18449) refactors the Simplifier's treatment of join-point binders. Specifically, it puts together, into GHC.Core.Opt.Simplify.Env.adjustJoinPointType two currently-separate ways in which we adjust the type of a join point. As the comment says: -- (adjustJoinPointType mult new_res_ty join_id) does two things: -- -- 1. Set the return type of the join_id to new_res_ty -- See Note [Return type for join points] -- -- 2. Adjust the multiplicity of arrows in join_id's type, as -- directed by 'mult'. See Note [Scaling join point arguments] I think this actually fixes a latent bug, by ensuring that the seIdSubst and seInScope have the right multiplicity on the type of join points. I did some tidying up while I was at it. No more setJoinResTy, or modifyJoinResTy: instead it's done locally in Simplify.Env.adjustJoinPointType
-rw-r--r--compiler/GHC/Core.hs21
-rw-r--r--compiler/GHC/Core/Opt/Simplify.hs51
-rw-r--r--compiler/GHC/Core/Opt/Simplify/Env.hs164
-rw-r--r--compiler/GHC/Core/Type.hs25
4 files changed, 141 insertions, 120 deletions
diff --git a/compiler/GHC/Core.hs b/compiler/GHC/Core.hs
index bc61929ed4..a99a1adfd6 100644
--- a/compiler/GHC/Core.hs
+++ b/compiler/GHC/Core.hs
@@ -798,15 +798,18 @@ and case-of-case] in GHC.Core.Opt.Simplify):
in
jump j z w
-The body of the join point now returns a Bool, so the label `j` has to have its
-type updated accordingly. Inconvenient though this may be, it has the advantage
-that 'GHC.Core.Utils.exprType' can still return a type for any expression, including
-a jump.
-
-This differs from the paper (see Note [Invariants on join points]). In the
-paper, we instead give j the type `Int -> Bool -> forall a. a`. Then each jump
-carries the "return type" as a parameter, exactly the way other non-returning
-functions like `error` work:
+The body of the join point now returns a Bool, so the label `j` has to
+have its type updated accordingly, which is done by
+GHC.Core.Opt.Simplify.Env.adjustJoinPointType. Inconvenient though
+this may be, it has the advantage that 'GHC.Core.Utils.exprType' can
+still return a type for any expression, including a jump.
+
+Relationship to the paper
+
+This plan differs from the paper (see Note [Invariants on join
+points]). In the paper, we instead give j the type `Int -> Bool ->
+forall a. a`. Then each jump carries the "return type" as a parameter,
+exactly the way other non-returning functions like `error` work:
case (join
j :: Int -> Bool -> forall a. a
diff --git a/compiler/GHC/Core/Opt/Simplify.hs b/compiler/GHC/Core/Opt/Simplify.hs
index 355dd256c1..abfad1940f 100644
--- a/compiler/GHC/Core/Opt/Simplify.hs
+++ b/compiler/GHC/Core/Opt/Simplify.hs
@@ -55,7 +55,7 @@ import GHC.Core.Rules ( lookupRule, getRules, initRuleOpts )
import GHC.Types.Basic
import GHC.Utils.Monad ( mapAccumLM, liftIO )
import GHC.Types.Var ( isTyCoVar )
-import GHC.Data.Maybe ( orElse, fromMaybe )
+import GHC.Data.Maybe ( orElse )
import Control.Monad
import GHC.Utils.Outputable
import GHC.Data.FastString
@@ -63,7 +63,6 @@ import GHC.Utils.Misc
import GHC.Utils.Error
import GHC.Unit.Module ( moduleName, pprModuleName )
import GHC.Core.Multiplicity
-import GHC.Core.TyCo.Rep ( TyCoBinder(..) )
import GHC.Builtin.PrimOps ( PrimOp (SeqOp) )
@@ -361,44 +360,8 @@ simplJoinBind :: SimplEnv
simplJoinBind env cont old_bndr new_bndr rhs rhs_se
= do { let rhs_env = rhs_se `setInScopeFromE` env
; rhs' <- simplJoinRhs rhs_env old_bndr rhs cont
- ; let mult = contHoleScaling cont
- arity = fromMaybe (pprPanic "simplJoinBind" (ppr new_bndr)) $
- isJoinIdDetails_maybe (idDetails new_bndr)
- new_type = scaleJoinPointType mult arity (varType new_bndr)
- new_bndr' = setIdType new_bndr new_type
- ; completeBind env NotTopLevel (Just cont) old_bndr new_bndr' rhs' }
+ ; completeBind env NotTopLevel (Just cont) old_bndr new_bndr rhs' }
-{-
-Note [Scaling join point arguments]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-Consider a join point which is linear in its variable, in some context E:
-
-E[join j :: a #-> a
- j x = x
- in case v of
- A -> j 'x'
- B -> <blah>]
-
-The simplifier changes to:
-
-join j :: a #-> a
- j x = E[x]
-in case v of
- A -> j 'x'
- B -> E[<blah>]
-
-If E uses its argument in a nonlinear way (e.g. a case['Many]), then
-this is wrong: the join point has to change its type to a -> a.
-Otherwise, we'd get a linearity error.
-
-See also Note [Return type for join points] and Note [Join points and case-of-case].
--}
-scaleJoinPointType :: Mult -> Int -> Type -> Type
-scaleJoinPointType mult arity ty | arity == 0 = ty
- | otherwise = case splitPiTy ty of
- (binder, ty') -> mkPiTy (scaleBinder binder) (scaleJoinPointType mult (arity-1) ty')
- where scaleBinder (Anon af t) = Anon af (scaleScaled mult t)
- scaleBinder b@(Named _) = b
--------------------------
simplNonRecX :: SimplEnv
-> InId -- Old binder; not a JoinId
@@ -1726,8 +1689,9 @@ simplNonRecJoinPoint env bndr rhs body cont
= wrapJoinCont env cont $ \ env cont ->
do { -- We push join_cont into the join RHS and the body;
-- and wrap wrap_cont around the whole thing
- ; let res_ty = contResultType cont
- ; (env1, bndr1) <- simplNonRecJoinBndr env res_ty bndr
+ ; let mult = contHoleScaling cont
+ res_ty = contResultType cont
+ ; (env1, bndr1) <- simplNonRecJoinBndr env bndr mult res_ty
; (env2, bndr2) <- addBndrRules env1 bndr bndr1 (Just cont)
; (floats1, env3) <- simplJoinBind env2 cont bndr bndr2 rhs env
; (floats2, body') <- simplExprF env3 body cont
@@ -1740,9 +1704,10 @@ simplRecJoinPoint :: SimplEnv -> [(InId, InExpr)]
-> SimplM (SimplFloats, OutExpr)
simplRecJoinPoint env pairs body cont
= wrapJoinCont env cont $ \ env cont ->
- do { let bndrs = map fst pairs
+ do { let bndrs = map fst pairs
+ mult = contHoleScaling cont
res_ty = contResultType cont
- ; env1 <- simplRecJoinBndrs env res_ty bndrs
+ ; env1 <- simplRecJoinBndrs env bndrs mult res_ty
-- NB: bndrs' don't have unfoldings or rules
-- We add them as we go down
; (floats1, env2) <- simplRecBind env1 NotTopLevel (Just cont) pairs
diff --git a/compiler/GHC/Core/Opt/Simplify/Env.hs b/compiler/GHC/Core/Opt/Simplify/Env.hs
index c5f8193b4f..237739e23c 100644
--- a/compiler/GHC/Core/Opt/Simplify/Env.hs
+++ b/compiler/GHC/Core/Opt/Simplify/Env.hs
@@ -51,6 +51,7 @@ import GHC.Core.Opt.Simplify.Monad
import GHC.Core.Opt.Monad ( SimplMode(..) )
import GHC.Core
import GHC.Core.Utils
+import GHC.Core.Multiplicity ( scaleScaled )
import GHC.Types.Var
import GHC.Types.Var.Env
import GHC.Types.Var.Set
@@ -59,6 +60,7 @@ import GHC.Types.Id as Id
import GHC.Core.Make ( mkWildValBinder )
import GHC.Driver.Session ( DynFlags )
import GHC.Builtin.Types
+import GHC.Core.TyCo.Rep ( TyCoBinder(..) )
import qualified GHC.Core.Type as Type
import GHC.Core.Type hiding ( substTy, substTyVar, substTyVarBndr, extendTvSubst, extendCvSubst )
import qualified GHC.Core.Coercion as Coercion
@@ -741,24 +743,14 @@ simplBinder :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
simplBinder env bndr
| isTyVar bndr = do { let (env', tv) = substTyVarBndr env bndr
; seqTyVar tv `seq` return (env', tv) }
- | otherwise = do { let (env', id) = substIdBndr Nothing env bndr
+ | otherwise = do { let (env', id) = substIdBndr env bndr
; seqId id `seq` return (env', id) }
---------------
simplNonRecBndr :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
-- A non-recursive let binder
simplNonRecBndr env id
- = do { let (env1, id1) = substIdBndr Nothing env id
- ; seqId id1 `seq` return (env1, id1) }
-
----------------
-simplNonRecJoinBndr :: SimplEnv -> OutType -> InBndr
- -> SimplM (SimplEnv, OutBndr)
--- A non-recursive let binder for a join point;
--- context being pushed inward may change the type
--- See Note [Return type for join points]
-simplNonRecJoinBndr env res_ty id
- = do { let (env1, id1) = substIdBndr (Just res_ty) env id
+ = do { let (env1, id1) = substIdBndr env id
; seqId id1 `seq` return (env1, id1) }
---------------
@@ -766,31 +758,20 @@ simplRecBndrs :: SimplEnv -> [InBndr] -> SimplM SimplEnv
-- Recursive let binders
simplRecBndrs env@(SimplEnv {}) ids
= ASSERT(all (not . isJoinId) ids)
- do { let (env1, ids1) = mapAccumL (substIdBndr Nothing) env ids
+ do { let (env1, ids1) = mapAccumL substIdBndr env ids
; seqIds ids1 `seq` return env1 }
----------------
-simplRecJoinBndrs :: SimplEnv -> OutType -> [InBndr] -> SimplM SimplEnv
--- Recursive let binders for join points;
--- context being pushed inward may change types
--- See Note [Return type for join points]
-simplRecJoinBndrs env@(SimplEnv {}) res_ty ids
- = ASSERT(all isJoinId ids)
- do { let (env1, ids1) = mapAccumL (substIdBndr (Just res_ty)) env ids
- ; seqIds ids1 `seq` return env1 }
---------------
-substIdBndr :: Maybe OutType -> SimplEnv -> InBndr -> (SimplEnv, OutBndr)
+substIdBndr :: SimplEnv -> InBndr -> (SimplEnv, OutBndr)
-- Might be a coercion variable
-substIdBndr new_res_ty env bndr
+substIdBndr env bndr
| isCoVar bndr = substCoVarBndr env bndr
- | otherwise = substNonCoVarIdBndr new_res_ty env bndr
+ | otherwise = substNonCoVarIdBndr env bndr
---------------
substNonCoVarIdBndr
- :: Maybe OutType -- New result type, if a join binder
- -- See Note [Return type for join points]
- -> SimplEnv
+ :: SimplEnv
-> InBndr -- Env and binder to transform
-> (SimplEnv, OutBndr)
-- Clone Id if necessary, substitute its type
@@ -810,25 +791,28 @@ substNonCoVarIdBndr
-- Similar to GHC.Core.Subst.substIdBndr, except that
-- the type of id_subst differs
-- all fragile info is zapped
-substNonCoVarIdBndr new_res_ty
- env@(SimplEnv { seInScope = in_scope
- , seIdSubst = id_subst })
- old_id
+substNonCoVarIdBndr env id = subst_id_bndr env id (\x -> x)
+
+subst_id_bndr :: SimplEnv
+ -> InBndr -- Env and binder to transform
+ -> (OutId -> OutId) -- Adjust the type
+ -> (SimplEnv, OutBndr)
+subst_id_bndr env@(SimplEnv { seInScope = in_scope, seIdSubst = id_subst })
+ old_id adjust_type
= ASSERT2( not (isCoVar old_id), ppr old_id )
(env { seInScope = in_scope `extendInScopeSet` new_id,
seIdSubst = new_subst }, new_id)
+ -- It's important that both seInScope and seIdSubt are updated with
+ -- the new_id, /after/ applying adjust_type. That's why adjust_type
+ -- is done here. If we did adjust_type in simplJoinBndr (the only
+ -- place that gives a non-identity adjust_type) we'd have to fiddle
+ -- afresh with both seInScope and seIdSubst
where
- id1 = uniqAway in_scope old_id
- id2 = substIdType env id1
-
- id3 | Just res_ty <- new_res_ty
- = id2 `setIdType` setJoinResTy (idJoinArity id2) res_ty (idType id2)
- -- See Note [Return type for join points]
- | otherwise
- = id2
-
- new_id = zapFragileIdInfo id3 -- Zaps rules, worker-info, unfolding
- -- and fragile OccInfo
+ id1 = uniqAway in_scope old_id
+ id2 = substIdType env id1
+ id3 = zapFragileIdInfo id2 -- Zaps rules, worker-info, unfolding
+ -- and fragile OccInfo
+ new_id = adjust_type id3
-- Extend the substitution if the unique has changed,
-- or there's some useful occurrence information
@@ -889,6 +873,100 @@ the letrec.
-}
+{- *********************************************************************
+* *
+ Join points
+* *
+********************************************************************* -}
+
+simplNonRecJoinBndr :: SimplEnv -> InBndr
+ -> Mult -> OutType
+ -> SimplM (SimplEnv, OutBndr)
+
+-- A non-recursive let binder for a join point;
+-- context being pushed inward may change the type
+-- See Note [Return type for join points]
+simplNonRecJoinBndr env id mult res_ty
+ = do { let (env1, id1) = simplJoinBndr mult res_ty env id
+ ; seqId id1 `seq` return (env1, id1) }
+
+simplRecJoinBndrs :: SimplEnv -> [InBndr]
+ -> Mult -> OutType
+ -> SimplM SimplEnv
+-- Recursive let binders for join points;
+-- context being pushed inward may change types
+-- See Note [Return type for join points]
+simplRecJoinBndrs env@(SimplEnv {}) ids mult res_ty
+ = ASSERT(all isJoinId ids)
+ do { let (env1, ids1) = mapAccumL (simplJoinBndr mult res_ty) env ids
+ ; seqIds ids1 `seq` return env1 }
+
+---------------
+simplJoinBndr :: Mult -> OutType
+ -> SimplEnv -> InBndr
+ -> (SimplEnv, OutBndr)
+simplJoinBndr mult res_ty env id
+ = subst_id_bndr env id (adjustJoinPointType mult res_ty)
+
+---------------
+adjustJoinPointType :: Mult
+ -> Type -- New result type
+ -> Id -- Old join-point Id
+ -> Id -- Adjusted jont-point Id
+-- (adjustJoinPointType mult new_res_ty join_id) does two things:
+--
+-- 1. Set the return type of the join_id to new_res_ty
+-- See Note [Return type for join points]
+--
+-- 2. Adjust the multiplicity of arrows in join_id's type, as
+-- directed by 'mult'. See Note [Scaling join point arguments]
+--
+-- INVARIANT: If any of the first n binders are foralls, those tyvars
+-- cannot appear in the original result type. See isValidJoinPointType.
+adjustJoinPointType mult new_res_ty join_id
+ = ASSERT( isJoinId join_id )
+ setIdType join_id new_join_ty
+ where
+ orig_ar = idJoinArity join_id
+ orig_ty = idType join_id
+
+ new_join_ty = go orig_ar orig_ty
+
+ go 0 _ = new_res_ty
+ go n ty | Just (arg_bndr, res_ty) <- splitPiTy_maybe ty
+ = mkPiTy (scale_bndr arg_bndr) $
+ go (n-1) res_ty
+ | otherwise
+ = pprPanic "adjustJoinPointType" (ppr orig_ar <+> ppr orig_ty)
+
+ scale_bndr (Anon af t) = Anon af (scaleScaled mult t)
+ scale_bndr b@(Named _) = b
+
+{- Note [Scaling join point arguments]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider a join point which is linear in its variable, in some context E:
+
+E[join j :: a #-> a
+ j x = x
+ in case v of
+ A -> j 'x'
+ B -> <blah>]
+
+The simplifier changes to:
+
+join j :: a #-> a
+ j x = E[x]
+in case v of
+ A -> j 'x'
+ B -> E[<blah>]
+
+If E uses its argument in a nonlinear way (e.g. a case['Many]), then
+this is wrong: the join point has to change its type to a -> a.
+Otherwise, we'd get a linearity error.
+
+See also Note [Return type for join points] and Note [Join points and case-of-case].
+-}
+
{-
************************************************************************
* *
diff --git a/compiler/GHC/Core/Type.hs b/compiler/GHC/Core/Type.hs
index 5bb11a9ee7..f65c7c5770 100644
--- a/compiler/GHC/Core/Type.hs
+++ b/compiler/GHC/Core/Type.hs
@@ -83,8 +83,6 @@ module GHC.Core.Type (
tyConArgFlags, appTyArgFlags,
synTyConResKind,
- modifyJoinResTy, setJoinResTy,
-
-- ** Analyzing types
TyCoMapper(..), mapTyCo, mapTyCoX,
TyCoFolder(..), foldTyCo,
@@ -2860,29 +2858,6 @@ splitVisVarsOfType orig_ty = Pair invis_vars vis_vars
splitVisVarsOfTypes :: [Type] -> Pair TyCoVarSet
splitVisVarsOfTypes = foldMap splitVisVarsOfType
-modifyJoinResTy :: Int -- Number of binders to skip
- -> (Type -> Type) -- Function to apply to result type
- -> Type -- Type of join point
- -> Type -- New type
--- INVARIANT: If any of the first n binders are foralls, those tyvars cannot
--- appear in the original result type. See isValidJoinPointType.
-modifyJoinResTy orig_ar f orig_ty
- = go orig_ar orig_ty
- where
- go 0 ty = f ty
- go n ty | Just (arg_bndr, res_ty) <- splitPiTy_maybe ty
- = mkPiTy arg_bndr (go (n-1) res_ty)
- | otherwise
- = pprPanic "modifyJoinResTy" (ppr orig_ar <+> ppr orig_ty)
-
-setJoinResTy :: Int -- Number of binders to skip
- -> Type -- New result type
- -> Type -- Type of join point
- -> Type -- New type
--- INVARIANT: Same as for modifyJoinResTy
-setJoinResTy ar new_res_ty ty
- = modifyJoinResTy ar (const new_res_ty) ty
-
{-
************************************************************************
* *