summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compiler/backpack/RnModIface.hs4
-rw-r--r--compiler/basicTypes/BasicTypes.hs126
-rw-r--r--compiler/basicTypes/Demand.hs8
-rw-r--r--compiler/basicTypes/Id.hs62
-rw-r--r--compiler/basicTypes/IdInfo.hs34
-rw-r--r--compiler/basicTypes/IdInfo.hs-boot2
-rw-r--r--compiler/basicTypes/Var.hs18
-rw-r--r--compiler/basicTypes/VarEnv.hs10
-rw-r--r--compiler/coreSyn/CoreArity.hs158
-rw-r--r--compiler/coreSyn/CoreArity.hs-boot6
-rw-r--r--compiler/coreSyn/CoreLint.hs337
-rw-r--r--compiler/coreSyn/CorePrep.hs120
-rw-r--r--compiler/coreSyn/CoreStats.hs44
-rw-r--r--compiler/coreSyn/CoreSubst.hs33
-rw-r--r--compiler/coreSyn/CoreSyn.hs223
-rw-r--r--compiler/coreSyn/CoreUnfold.hs47
-rw-r--r--compiler/coreSyn/CoreUtils.hs19
-rw-r--r--compiler/coreSyn/MkCore.hs1
-rw-r--r--compiler/coreSyn/PprCore.hs33
-rw-r--r--compiler/deSugar/DsUtils.hs14
-rw-r--r--compiler/iface/IfaceSyn.hs36
-rw-r--r--compiler/iface/TcIface.hs13
-rw-r--r--compiler/iface/ToIface.hs5
-rw-r--r--compiler/simplCore/CSE.hs16
-rw-r--r--compiler/simplCore/CoreMonad.hs2
-rw-r--r--compiler/simplCore/FloatIn.hs79
-rw-r--r--compiler/simplCore/FloatOut.hs261
-rw-r--r--compiler/simplCore/LiberateCase.hs24
-rw-r--r--compiler/simplCore/OccurAnal.hs956
-rw-r--r--compiler/simplCore/SetLevels.hs363
-rw-r--r--compiler/simplCore/SimplCore.hs17
-rw-r--r--compiler/simplCore/SimplEnv.hs204
-rw-r--r--compiler/simplCore/SimplUtils.hs29
-rw-r--r--compiler/simplCore/Simplify.hs554
-rw-r--r--compiler/specialise/Rules.hs6
-rw-r--r--compiler/specialise/SpecConstr.hs23
-rw-r--r--compiler/specialise/Specialise.hs23
-rw-r--r--compiler/stgSyn/CoreToStg.hs286
-rw-r--r--compiler/stranal/DmdAnal.hs16
-rw-r--r--compiler/stranal/WorkWrap.hs74
-rw-r--r--compiler/stranal/WwLib.hs96
-rw-r--r--compiler/types/Type.hs66
-rw-r--r--compiler/utils/Outputable.hs13
-rw-r--r--compiler/utils/UniqFM.hs10
-rw-r--r--testsuite/tests/deSugar/should_compile/T2431.stderr29
-rw-r--r--testsuite/tests/deriving/perf/all.T4
-rw-r--r--testsuite/tests/numeric/should_compile/T7116.stdout21
-rw-r--r--testsuite/tests/perf/compiler/all.T18
-rw-r--r--testsuite/tests/perf/haddock/all.T6
-rw-r--r--testsuite/tests/perf/join_points/Makefile3
-rw-r--r--testsuite/tests/perf/join_points/all.T28
-rw-r--r--testsuite/tests/perf/join_points/join001.hs16
-rw-r--r--testsuite/tests/perf/join_points/join002.hs51
-rw-r--r--testsuite/tests/perf/join_points/join002.stdout1
-rw-r--r--testsuite/tests/perf/join_points/join003.hs69
-rw-r--r--testsuite/tests/perf/join_points/join003.stdout1
-rw-r--r--testsuite/tests/perf/join_points/join004.hs30
-rw-r--r--testsuite/tests/perf/join_points/join004.stdout1
-rw-r--r--testsuite/tests/perf/join_points/join005.hs23
-rw-r--r--testsuite/tests/perf/join_points/join006.hs22
-rw-r--r--testsuite/tests/perf/join_points/join007.hs42
-rw-r--r--testsuite/tests/perf/join_points/join007.stdout1
-rw-r--r--testsuite/tests/perf/should_run/all.T6
-rw-r--r--testsuite/tests/roles/should_compile/Roles13.stderr41
-rw-r--r--testsuite/tests/simplCore/should_compile/Makefile3
-rw-r--r--testsuite/tests/simplCore/should_compile/T13156.hs37
-rw-r--r--testsuite/tests/simplCore/should_compile/T13156.stdout4
-rw-r--r--testsuite/tests/simplCore/should_compile/T3717.stderr17
-rw-r--r--testsuite/tests/simplCore/should_compile/T3772.stdout17
-rw-r--r--testsuite/tests/simplCore/should_compile/T4908.stderr19
-rw-r--r--testsuite/tests/simplCore/should_compile/T4930.stderr28
-rw-r--r--testsuite/tests/simplCore/should_compile/T5658b.stdout2
-rw-r--r--testsuite/tests/simplCore/should_compile/T7360.stderr47
-rw-r--r--testsuite/tests/simplCore/should_compile/T9400.stderr15
-rw-r--r--testsuite/tests/simplCore/should_compile/all.T3
-rw-r--r--testsuite/tests/simplCore/should_compile/par01.stderr15
-rw-r--r--testsuite/tests/simplCore/should_compile/spec-inline.stderr29
77 files changed, 3969 insertions, 1151 deletions
diff --git a/compiler/backpack/RnModIface.hs b/compiler/backpack/RnModIface.hs
index a6d6eddf49..e32bb74e48 100644
--- a/compiler/backpack/RnModIface.hs
+++ b/compiler/backpack/RnModIface.hs
@@ -606,8 +606,8 @@ rnIfaceConAlt (IfaceDataAlt data_occ) = IfaceDataAlt <$> rnIfaceGlobal data_occ
rnIfaceConAlt alt = pure alt
rnIfaceLetBndr :: Rename IfaceLetBndr
-rnIfaceLetBndr (IfLetBndr fs ty info)
- = IfLetBndr fs <$> rnIfaceType ty <*> rnIfaceIdInfo info
+rnIfaceLetBndr (IfLetBndr fs ty info jpi)
+ = IfLetBndr fs <$> rnIfaceType ty <*> rnIfaceIdInfo info <*> pure jpi
rnIfaceLamBndr :: Rename IfaceLamBndr
rnIfaceLamBndr (bndr, oneshot) = (,) <$> rnIfaceBndr bndr <*> pure oneshot
diff --git a/compiler/basicTypes/BasicTypes.hs b/compiler/basicTypes/BasicTypes.hs
index cf4c9702bc..ff4d2c7cce 100644
--- a/compiler/basicTypes/BasicTypes.hs
+++ b/compiler/basicTypes/BasicTypes.hs
@@ -24,7 +24,7 @@ module BasicTypes(
ConTag, ConTagZ, fIRST_TAG,
- Arity, RepArity,
+ Arity, RepArity, JoinArity,
Alignment,
@@ -64,13 +64,15 @@ module BasicTypes(
noOneShotInfo, hasNoOneShotInfo, isOneShotInfo,
bestOneShot, worstOneShot,
- OccInfo(..), seqOccInfo, zapFragileOcc, isOneOcc,
- isDeadOcc, isStrongLoopBreaker, isWeakLoopBreaker, isNoOcc,
+ OccInfo(..), noOccInfo, seqOccInfo, zapFragileOcc, isOneOcc,
+ isDeadOcc, isStrongLoopBreaker, isWeakLoopBreaker, isManyOccs,
strongLoopBreaker, weakLoopBreaker,
InsideLam, insideLam, notInsideLam,
OneBranch, oneBranch, notOneBranch,
InterestingCxt,
+ TailCallInfo(..), tailCallInfo, zapOccTailCallInfo,
+ isAlwaysTailCalled,
EP(..),
@@ -154,6 +156,12 @@ type Arity = Int
-- \(# x, y #) -> fib (x + y) has representation arity 2
type RepArity = Int
+-- | The number of arguments that a join point takes. Unlike the arity of a
+-- function, this is a purely syntactic property and is fixed when the join
+-- point is created (or converted from a value). Both type and value arguments
+-- are counted.
+type JoinArity = Int
+
{-
************************************************************************
* *
@@ -808,20 +816,23 @@ defn of OccInfo here, safely at the bottom
-- | identifier Occurrence Information
data OccInfo
- = NoOccInfo -- ^ There are many occurrences, or unknown occurrences
+ = ManyOccs { occ_tail :: !TailCallInfo }
+ -- ^ There are many occurrences, or unknown occurrences
| IAmDead -- ^ Marks unused variables. Sometimes useful for
-- lambda and case-bound variables.
- | OneOcc
- !InsideLam
- !OneBranch
- !InterestingCxt -- ^ Occurs exactly once, not inside a rule
+ | OneOcc { occ_in_lam :: !InsideLam
+ , occ_one_br :: !OneBranch
+ , occ_int_cxt :: !InterestingCxt
+ , occ_tail :: !TailCallInfo }
+ -- ^ Occurs exactly once (per branch), not inside a rule
-- | This identifier breaks a loop of mutually recursive functions. The field
-- marks whether it is only a loop breaker due to a reference in a rule
- | IAmALoopBreaker -- Note [LoopBreaker OccInfo]
- !RulesOnly
+ | IAmALoopBreaker { occ_rules_only :: !RulesOnly
+ , occ_tail :: !TailCallInfo }
+ -- Note [LoopBreaker OccInfo]
deriving (Eq)
@@ -839,9 +850,12 @@ Note [LoopBreaker OccInfo]
See OccurAnal Note [Weak loop breakers]
-}
-isNoOcc :: OccInfo -> Bool
-isNoOcc NoOccInfo = True
-isNoOcc _ = False
+noOccInfo :: OccInfo
+noOccInfo = ManyOccs { occ_tail = NoTailCallInfo }
+
+isManyOccs :: OccInfo -> Bool
+isManyOccs ManyOccs{} = True
+isManyOccs _ = False
seqOccInfo :: OccInfo -> ()
seqOccInfo occ = occ `seq` ()
@@ -868,17 +882,41 @@ oneBranch, notOneBranch :: OneBranch
oneBranch = True
notOneBranch = False
+-----------------
+data TailCallInfo = AlwaysTailCalled JoinArity -- See Note [TailCallInfo]
+ | NoTailCallInfo
+ deriving (Eq)
+
+tailCallInfo :: OccInfo -> TailCallInfo
+tailCallInfo IAmDead = NoTailCallInfo
+tailCallInfo other = occ_tail other
+
+zapOccTailCallInfo :: OccInfo -> OccInfo
+zapOccTailCallInfo IAmDead = IAmDead
+zapOccTailCallInfo occ = occ { occ_tail = NoTailCallInfo }
+
+isAlwaysTailCalled :: OccInfo -> Bool
+isAlwaysTailCalled occ
+ = case tailCallInfo occ of AlwaysTailCalled{} -> True
+ NoTailCallInfo -> False
+
+instance Outputable TailCallInfo where
+ ppr (AlwaysTailCalled ar) = sep [ text "Tail", int ar ]
+ ppr _ = empty
+
+-----------------
strongLoopBreaker, weakLoopBreaker :: OccInfo
-strongLoopBreaker = IAmALoopBreaker False
-weakLoopBreaker = IAmALoopBreaker True
+strongLoopBreaker = IAmALoopBreaker False NoTailCallInfo
+weakLoopBreaker = IAmALoopBreaker True NoTailCallInfo
isWeakLoopBreaker :: OccInfo -> Bool
-isWeakLoopBreaker (IAmALoopBreaker _) = True
+isWeakLoopBreaker (IAmALoopBreaker{}) = True
isWeakLoopBreaker _ = False
isStrongLoopBreaker :: OccInfo -> Bool
-isStrongLoopBreaker (IAmALoopBreaker False) = True -- Loop-breaker that breaks a non-rule cycle
-isStrongLoopBreaker _ = False
+isStrongLoopBreaker (IAmALoopBreaker { occ_rules_only = False }) = True
+ -- Loop-breaker that breaks a non-rule cycle
+isStrongLoopBreaker _ = False
isDeadOcc :: OccInfo -> Bool
isDeadOcc IAmDead = True
@@ -889,16 +927,21 @@ isOneOcc (OneOcc {}) = True
isOneOcc _ = False
zapFragileOcc :: OccInfo -> OccInfo
-zapFragileOcc (OneOcc {}) = NoOccInfo
-zapFragileOcc occ = occ
+-- Keep only the most robust data: deadness, loop-breaker-hood
+zapFragileOcc (OneOcc {}) = noOccInfo
+zapFragileOcc occ = zapOccTailCallInfo occ
instance Outputable OccInfo where
-- only used for debugging; never parsed. KSW 1999-07
- ppr NoOccInfo = empty
- ppr (IAmALoopBreaker ro) = text "LoopBreaker" <> if ro then char '!' else empty
+ ppr (ManyOccs tails) = pprShortTailCallInfo tails
ppr IAmDead = text "Dead"
- ppr (OneOcc inside_lam one_branch int_cxt)
- = text "Once" <> pp_lam <> pp_br <> pp_args
+ ppr (IAmALoopBreaker rule_only tails)
+ = text "LoopBreaker" <> pp_ro <> pprShortTailCallInfo tails
+ where
+ pp_ro | rule_only = char '!'
+ | otherwise = empty
+ ppr (OneOcc inside_lam one_branch int_cxt tail_info)
+ = text "Once" <> pp_lam <> pp_br <> pp_args <> pp_tail
where
pp_lam | inside_lam = char 'L'
| otherwise = empty
@@ -906,8 +949,43 @@ instance Outputable OccInfo where
| otherwise = char '*'
pp_args | int_cxt = char '!'
| otherwise = empty
+ pp_tail = pprShortTailCallInfo tail_info
+
+pprShortTailCallInfo :: TailCallInfo -> SDoc
+pprShortTailCallInfo (AlwaysTailCalled ar) = char 'T' <> brackets (int ar)
+pprShortTailCallInfo NoTailCallInfo = empty
{-
+Note [TailCallInfo]
+~~~~~~~~~~~~~~~~~~~
+The occurrence analyser determines what can be made into a join point, but it
+doesn't change the binder into a JoinId because then it would be inconsistent
+with the occurrences. Thus it's left to the simplifier (or to simpleOptExpr) to
+change the IdDetails.
+
+The AlwaysTailCalled marker actually means slightly more than simply that the
+function is always tail-called. See Note [Invariants on join points].
+
+This info is quite fragile and should not be relied upon unless the occurrence
+analyser has *just* run. Use 'Id.isJoinId_maybe' for the permanent state of
+the join-point-hood of a binder; a join id itself will not be marked
+AlwaysTailCalled.
+
+Note that there is a 'TailCallInfo' on a 'ManyOccs' value. One might expect that
+being tail-called would mean that the variable could only appear once per branch
+(thus getting a `OneOcc { occ_one_br = True }` occurrence info), but a join
+point can also be invoked from other join points, not just from case branches:
+
+ let j1 x = ...
+ j2 y = ... j1 z {- tail call -} ...
+ in case w of
+ A -> j1 v
+ B -> j2 u
+ C -> j2 q
+
+Here both 'j1' and 'j2' will get marked AlwaysTailCalled, but j1 will get
+ManyOccs and j2 will get `OneOcc { occ_one_br = True }`.
+
************************************************************************
* *
Default method specification
diff --git a/compiler/basicTypes/Demand.hs b/compiler/basicTypes/Demand.hs
index c72bf3909d..8cacf2270c 100644
--- a/compiler/basicTypes/Demand.hs
+++ b/compiler/basicTypes/Demand.hs
@@ -304,7 +304,9 @@ splitArgStrProdDmd n (Str _ s) = splitStrProdDmd n s
splitStrProdDmd :: Int -> StrDmd -> Maybe [ArgStr]
splitStrProdDmd n HyperStr = Just (replicate n strBot)
splitStrProdDmd n HeadStr = Just (replicate n strTop)
-splitStrProdDmd n (SProd ds) = ASSERT( ds `lengthIs` n) Just ds
+splitStrProdDmd n (SProd ds) = WARN( not (ds `lengthIs` n),
+ text "splitStrProdDmd" $$ ppr n $$ ppr ds )
+ Just ds
splitStrProdDmd _ (SCall {}) = Nothing
-- This can happen when the programmer uses unsafeCoerce,
-- and we don't then want to crash the compiler (Trac #9208)
@@ -586,7 +588,9 @@ seqArgUse _ = ()
splitUseProdDmd :: Int -> UseDmd -> Maybe [ArgUse]
splitUseProdDmd n Used = Just (replicate n useTop)
splitUseProdDmd n UHead = Just (replicate n Abs)
-splitUseProdDmd n (UProd ds) = ASSERT2( ds `lengthIs` n, text "splitUseProdDmd" $$ ppr n $$ ppr ds )
+splitUseProdDmd n (UProd ds) = WARN( not (ds `lengthIs` n),
+ text "splitUseProdDmd" $$ ppr n
+ $$ ppr ds )
Just ds
splitUseProdDmd _ (UCall _ _) = Nothing
-- This can happen when the programmer uses unsafeCoerce,
diff --git a/compiler/basicTypes/Id.hs b/compiler/basicTypes/Id.hs
index 2b1bdfd51b..acb22e8c9b 100644
--- a/compiler/basicTypes/Id.hs
+++ b/compiler/basicTypes/Id.hs
@@ -52,7 +52,7 @@ module Id (
globaliseId, localiseId,
setIdInfo, lazySetIdInfo, modifyIdInfo, maybeModifyIdInfo,
zapLamIdInfo, zapIdDemandInfo, zapIdUsageInfo, zapIdUsageEnvInfo,
- zapIdUsedOnceInfo,
+ zapIdUsedOnceInfo, zapIdTailCallInfo,
zapFragileIdInfo, zapIdStrictness,
transferPolyIdInfo,
@@ -73,6 +73,10 @@ module Id (
-- ** Evidence variables
DictId, isDictId, isEvVar,
+ -- ** Join variables
+ JoinId, isJoinId, isJoinId_maybe, idJoinArity,
+ asJoinId, asJoinId_maybe, zapJoinId,
+
-- ** Inline pragma stuff
idInlinePragma, setInlinePragma, modifyInlinePragma,
idInlineActivation, setInlineActivation, idRuleMatchInfo,
@@ -118,11 +122,12 @@ import IdInfo
import BasicTypes
-- Imported and re-exported
-import Var( Id, CoVar, DictId,
+import Var( Id, CoVar, DictId, JoinId,
InId, InVar,
OutId, OutVar,
- idInfo, idDetails, globaliseId, varType,
- isId, isLocalId, isGlobalId, isExportedId )
+ idInfo, idDetails, setIdDetails, globaliseId, varType,
+ isId, isLocalId, isGlobalId, isExportedId,
+ isJoinId, isJoinId_maybe )
import qualified Var
import Type
@@ -157,7 +162,10 @@ infixl 1 `setIdUnfolding`,
`idCafInfo`,
`setIdDemandInfo`,
- `setIdStrictness`
+ `setIdStrictness`,
+
+ `asJoinId`,
+ `asJoinId_maybe`
{-
************************************************************************
@@ -546,6 +554,40 @@ isDictId id = isDictTy (idType id)
{-
************************************************************************
* *
+ Join variables
+* *
+************************************************************************
+-}
+
+idJoinArity :: JoinId -> JoinArity
+idJoinArity id = isJoinId_maybe id `orElse` pprPanic "idJoinArity" (ppr id)
+
+asJoinId :: Id -> JoinArity -> JoinId
+asJoinId id arity = WARN(not (isLocalId id),
+ text "global id being marked as join var:" <+> ppr id)
+ WARN(not (is_vanilla_or_join id),
+ ppr id <+> pprIdDetails (idDetails id))
+ id `setIdDetails` JoinId arity
+ where
+ is_vanilla_or_join id = case Var.idDetails id of
+ VanillaId -> True
+ JoinId {} -> True
+ _ -> False
+
+zapJoinId :: Id -> Id
+-- May be a regular id already
+zapJoinId jid | isJoinId jid = zapIdTailCallInfo (jid `setIdDetails` VanillaId)
+ -- Core Lint may complain if still marked
+ -- as AlwaysTailCalled
+ | otherwise = jid
+
+asJoinId_maybe :: Id -> Maybe JoinArity -> Id
+asJoinId_maybe id (Just arity) = asJoinId id arity
+asJoinId_maybe id Nothing = zapJoinId id
+
+{-
+************************************************************************
+* *
\subsection{IdInfo stuff}
* *
************************************************************************
@@ -590,9 +632,11 @@ zapIdStrictness id = modifyIdInfo (`setStrictnessInfo` nopSig) id
isStrictId :: Id -> Bool
isStrictId id
= ASSERT2( isId id, text "isStrictId: not an id: " <+> ppr id )
+ not (isJoinId id) && (
(isStrictType (idType id)) ||
-- Take the best of both strictnesses - old and new
(isStrictDmd (idDemandInfo id))
+ )
---------------------------------
-- UNFOLDING
@@ -660,7 +704,7 @@ setIdOccInfo :: Id -> OccInfo -> Id
setIdOccInfo id occ_info = modifyIdInfo (`setOccInfo` occ_info) id
zapIdOccInfo :: Id -> Id
-zapIdOccInfo b = b `setIdOccInfo` NoOccInfo
+zapIdOccInfo b = b `setIdOccInfo` noOccInfo
{-
---------------------------------
@@ -804,6 +848,9 @@ zapIdUsageEnvInfo = zapInfo zapUsageEnvInfo
zapIdUsedOnceInfo :: Id -> Id
zapIdUsedOnceInfo = zapInfo zapUsedOnceInfo
+zapIdTailCallInfo :: Id -> Id
+zapIdTailCallInfo = zapInfo zapTailCallInfo
+
{-
Note [transferPolyIdInfo]
~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -869,13 +916,14 @@ transferPolyIdInfo old_id abstract_wrt new_id
old_inline_prag = inlinePragInfo old_info
old_occ_info = occInfo old_info
new_arity = old_arity + arity_increase
+ new_occ_info = zapOccTailCallInfo old_occ_info
old_strictness = strictnessInfo old_info
new_strictness = increaseStrictSigArity arity_increase old_strictness
transfer new_info = new_info `setArityInfo` new_arity
`setInlinePragInfo` old_inline_prag
- `setOccInfo` old_occ_info
+ `setOccInfo` new_occ_info
`setStrictnessInfo` new_strictness
isNeverLevPolyId :: Id -> Bool
diff --git a/compiler/basicTypes/IdInfo.hs b/compiler/basicTypes/IdInfo.hs
index 44815393e3..f29fba7db1 100644
--- a/compiler/basicTypes/IdInfo.hs
+++ b/compiler/basicTypes/IdInfo.hs
@@ -14,6 +14,7 @@ Haskell. [WDP 94/11])
module IdInfo (
-- * The IdDetails type
IdDetails(..), pprIdDetails, coVarDetails, isCoVarDetails,
+ JoinArity, isJoinIdDetails_maybe,
RecSelParent(..),
-- * The IdInfo type
@@ -28,6 +29,7 @@ module IdInfo (
-- ** Zapping various forms of Info
zapLamInfo, zapFragileInfo,
zapDemandInfo, zapUsageInfo, zapUsageEnvInfo, zapUsedOnceInfo,
+ zapTailCallInfo,
-- ** The ArityInfo type
ArityInfo,
@@ -55,6 +57,9 @@ module IdInfo (
InsideLam, OneBranch,
insideLam, notInsideLam, oneBranch, notOneBranch,
+ TailCallInfo(..),
+ tailCallInfo, isAlwaysTailCalled,
+
-- ** The RuleInfo type
RuleInfo(..),
emptyRuleInfo,
@@ -153,6 +158,8 @@ data IdDetails
| CoVarId -- ^ A coercion variable
-- This only covers /un-lifted/ coercions, of type
-- (t1 ~# t2) or (t1 ~R# t2), not their lifted variants
+ | JoinId JoinArity -- ^ An 'Id' for a join point taking n arguments
+ -- Note [Join points] in CoreSyn
-- | Recursive Selector Parent
data RecSelParent = RecSelData TyCon | RecSelPatSyn PatSyn deriving Eq
@@ -176,6 +183,10 @@ isCoVarDetails :: IdDetails -> Bool
isCoVarDetails CoVarId = True
isCoVarDetails _ = False
+isJoinIdDetails_maybe :: IdDetails -> Maybe JoinArity
+isJoinIdDetails_maybe (JoinId join_arity) = Just join_arity
+isJoinIdDetails_maybe _ = Nothing
+
instance Outputable IdDetails where
ppr = pprIdDetails
@@ -195,6 +206,7 @@ pprIdDetails other = brackets (pp other)
= brackets $ text "RecSel" <>
ppWhen is_naughty (text "(naughty)")
pp CoVarId = text "CoVarId"
+ pp (JoinId arity) = text "JoinId" <> parens (int arity)
{-
************************************************************************
@@ -285,7 +297,7 @@ vanillaIdInfo
unfoldingInfo = noUnfolding,
oneShotInfo = NoOneShotInfo,
inlinePragInfo = defaultInlinePragma,
- occInfo = NoOccInfo,
+ occInfo = noOccInfo,
demandInfo = topDmd,
strictnessInfo = nopSig,
callArityInfo = unknownArity,
@@ -482,12 +494,16 @@ zapLamInfo info@(IdInfo {occInfo = occ, demandInfo = demand})
where
-- The "unsafe" occ info is the ones that say I'm not in a lambda
-- because that might not be true for an unsaturated lambda
- is_safe_occ (OneOcc in_lam _ _) = in_lam
- is_safe_occ _other = True
+ is_safe_occ occ | isAlwaysTailCalled occ = False
+ is_safe_occ (OneOcc { occ_in_lam = in_lam }) = in_lam
+ is_safe_occ _other = True
safe_occ = case occ of
- OneOcc _ once int_cxt -> OneOcc insideLam once int_cxt
- _other -> occ
+ OneOcc{} -> occ { occ_in_lam = True
+ , occ_tail = NoTailCallInfo }
+ IAmALoopBreaker{}
+ -> occ { occ_tail = NoTailCallInfo }
+ _other -> occ
is_safe_dmd dmd = not (isStrictDmd dmd)
@@ -529,6 +545,14 @@ zapFragileUnfolding unf
| isFragileUnfolding unf = noUnfolding
| otherwise = unf
+zapTailCallInfo :: IdInfo -> Maybe IdInfo
+zapTailCallInfo info
+ = case occInfo info of
+ occ | isAlwaysTailCalled occ -> Just (info `setOccInfo` safe_occ)
+ | otherwise -> Nothing
+ where
+ safe_occ = occ { occ_tail = NoTailCallInfo }
+
{-
************************************************************************
* *
diff --git a/compiler/basicTypes/IdInfo.hs-boot b/compiler/basicTypes/IdInfo.hs-boot
index 0fabad3bbb..27c1217e15 100644
--- a/compiler/basicTypes/IdInfo.hs-boot
+++ b/compiler/basicTypes/IdInfo.hs-boot
@@ -1,4 +1,5 @@
module IdInfo where
+import BasicTypes
import Outputable
data IdInfo
data IdDetails
@@ -6,5 +7,6 @@ data IdDetails
vanillaIdInfo :: IdInfo
coVarDetails :: IdDetails
isCoVarDetails :: IdDetails -> Bool
+isJoinIdDetails_maybe :: IdDetails -> Maybe JoinArity
pprIdDetails :: IdDetails -> SDoc
diff --git a/compiler/basicTypes/Var.hs b/compiler/basicTypes/Var.hs
index 3f78c2800f..2b728afa4f 100644
--- a/compiler/basicTypes/Var.hs
+++ b/compiler/basicTypes/Var.hs
@@ -34,7 +34,7 @@
module Var (
-- * The main data type and synonyms
- Var, CoVar, Id, NcId, DictId, DFunId, EvVar, EqVar, EvId, IpId,
+ Var, CoVar, Id, NcId, DictId, DFunId, EvVar, EqVar, EvId, IpId, JoinId,
TyVar, TypeVar, KindVar, TKVar, TyCoVar,
-- * In and Out variants
@@ -57,6 +57,7 @@ module Var (
-- ** Predicates
isId, isTyVar, isTcTyVar,
isLocalVar, isLocalId, isCoVar, isNonCoVarId, isTyCoVar,
+ isJoinId, isJoinId_maybe,
isGlobalId, isExportedId,
mustHaveLocalBinding,
@@ -83,8 +84,11 @@ module Var (
import {-# SOURCE #-} TyCoRep( Type, Kind, pprKind )
import {-# SOURCE #-} TcType( TcTyVarDetails, pprTcTyVarDetails, vanillaSkolemTv )
-import {-# SOURCE #-} IdInfo( IdDetails, IdInfo, coVarDetails, isCoVarDetails, vanillaIdInfo, pprIdDetails )
+import {-# SOURCE #-} IdInfo( IdDetails, IdInfo, coVarDetails, isCoVarDetails,
+ isJoinIdDetails_maybe,
+ vanillaIdInfo, pprIdDetails )
+import BasicTypes ( JoinArity )
import Name hiding (varName)
import Unique ( Uniquable, Unique, getKey, getUnique
, mkUniqueGrimily, nonDetCmpUnique )
@@ -92,6 +96,7 @@ import Util
import Binary
import DynFlags
import Outputable
+import Maybes
import Data.Data
@@ -149,6 +154,7 @@ type IpId = EvId -- A term-level implicit parameter
-- | Equality Variable
type EqVar = EvId -- Boxed equality evidence
+type JoinId = Id -- A join variable
-- | Type or Coercion Variable
type TyCoVar = Id -- Type, *or* coercion variable
@@ -612,6 +618,14 @@ isNonCoVarId :: Var -> Bool
isNonCoVarId (Id { id_details = details }) = not (isCoVarDetails details)
isNonCoVarId _ = False
+isJoinId :: Var -> Bool
+isJoinId (Id { id_details = details }) = isJust (isJoinIdDetails_maybe details)
+isJoinId _ = False
+
+isJoinId_maybe :: Var -> Maybe JoinArity
+isJoinId_maybe (Id { id_details = details }) = isJoinIdDetails_maybe details
+isJoinId_maybe _ = Nothing
+
isLocalId :: Var -> Bool
isLocalId (Id { idScope = LocalId _ }) = True
isLocalId _ = False
diff --git a/compiler/basicTypes/VarEnv.hs b/compiler/basicTypes/VarEnv.hs
index dcb64a9c2d..64357d77fa 100644
--- a/compiler/basicTypes/VarEnv.hs
+++ b/compiler/basicTypes/VarEnv.hs
@@ -12,8 +12,8 @@ module VarEnv (
elemVarEnv,
extendVarEnv, extendVarEnv_C, extendVarEnv_Acc, extendVarEnv_Directly,
extendVarEnvList,
- plusVarEnv, plusVarEnv_C, plusVarEnv_CD, plusVarEnvList,
- alterVarEnv,
+ plusVarEnv, plusVarEnv_C, plusVarEnv_CD, plusMaybeVarEnv_C,
+ plusVarEnvList, alterVarEnv,
delVarEnvList, delVarEnv, delVarEnv_Directly,
minusVarEnv, intersectsVarEnv,
lookupVarEnv, lookupVarEnv_NF, lookupWithDefaultVarEnv,
@@ -41,6 +41,7 @@ module VarEnv (
unitDVarEnv,
delDVarEnv,
delDVarEnvList,
+ minusDVarEnv,
partitionDVarEnv,
anyDVarEnv,
@@ -450,6 +451,7 @@ minusVarEnv :: VarEnv a -> VarEnv b -> VarEnv a
intersectsVarEnv :: VarEnv a -> VarEnv a -> Bool
plusVarEnv_C :: (a -> a -> a) -> VarEnv a -> VarEnv a -> VarEnv a
plusVarEnv_CD :: (a -> a -> a) -> VarEnv a -> a -> VarEnv a -> a -> VarEnv a
+plusMaybeVarEnv_C :: (a -> a -> Maybe a) -> VarEnv a -> VarEnv a -> VarEnv a
mapVarEnv :: (a -> b) -> VarEnv a -> VarEnv b
modifyVarEnv :: (a -> a) -> VarEnv a -> Var -> VarEnv a
@@ -471,6 +473,7 @@ extendVarEnv_Directly = addToUFM_Directly
extendVarEnvList = addListToUFM
plusVarEnv_C = plusUFM_C
plusVarEnv_CD = plusUFM_CD
+plusMaybeVarEnv_C = plusMaybeUFM_C
delVarEnvList = delListFromUFM
delVarEnv = delFromUFM
minusVarEnv = minusUFM
@@ -541,6 +544,9 @@ mkDVarEnv = listToUDFM
extendDVarEnv :: DVarEnv a -> Var -> a -> DVarEnv a
extendDVarEnv = addToUDFM
+minusDVarEnv :: DVarEnv a -> DVarEnv a' -> DVarEnv a
+minusDVarEnv = minusUDFM
+
lookupDVarEnv :: DVarEnv a -> Var -> Maybe a
lookupDVarEnv = lookupUDFM
diff --git a/compiler/coreSyn/CoreArity.hs b/compiler/coreSyn/CoreArity.hs
index 0d6f4b6627..49f58c66ae 100644
--- a/compiler/coreSyn/CoreArity.hs
+++ b/compiler/coreSyn/CoreArity.hs
@@ -11,7 +11,8 @@
-- | Arity and eta expansion
module CoreArity (
manifestArity, exprArity, typeArity, exprBotStrictness_maybe,
- exprEtaExpandArity, findRhsArity, CheapFun, etaExpand
+ exprEtaExpandArity, findRhsArity, CheapFun, etaExpand,
+ etaExpandToJoinPoint, etaExpandToJoinPointRule
) where
#include "HsVersions.h"
@@ -952,11 +953,17 @@ etaInfoApp subst (Case e b ty alts) eis
etaInfoApp subst (Let b e) eis
= Let b' (etaInfoApp subst' e eis)
where
- (subst', b') = subst_bind subst b
+ (subst', b') = etaInfoAppBind subst b eis
etaInfoApp subst (Tick t e) eis
= Tick (substTickish subst t) (etaInfoApp subst e eis)
+etaInfoApp subst expr _
+ | (Var fun, _) <- collectArgs expr
+ , Var fun' <- lookupIdSubst (text "etaInfoApp" <+> ppr fun) subst fun
+ , isJoinId fun'
+ = subst_expr subst expr
+
etaInfoApp subst e eis
= go (subst_expr subst e) eis
where
@@ -965,6 +972,94 @@ etaInfoApp subst e eis
go e (EtaCo co : eis) = go (Cast e co) eis
--------------
+-- | Apply the eta info to a local binding. Mostly delegates to
+-- `etaInfoAppLocalBndr` and `etaInfoAppRhs`.
+etaInfoAppBind :: Subst -> CoreBind -> [EtaInfo] -> (Subst, CoreBind)
+etaInfoAppBind subst (NonRec bndr rhs) eis
+ = (subst', NonRec bndr' rhs')
+ where
+ bndr_w_new_type = etaInfoAppLocalBndr bndr eis
+ (subst', bndr1) = substBndr subst bndr_w_new_type
+ rhs' = etaInfoAppRhs subst bndr1 rhs eis
+ bndr' | isJoinId bndr = bndr1 `setIdArity` manifestArity rhs'
+ -- Arity may have changed
+ -- (see etaInfoAppRhs example)
+ | otherwise = bndr1
+etaInfoAppBind subst (Rec pairs) eis
+ = (subst', Rec (bndrs' `zip` rhss'))
+ where
+ (bndrs, rhss) = unzip pairs
+ bndrs_w_new_types = map (\bndr -> etaInfoAppLocalBndr bndr eis) bndrs
+ (subst', bndrs1) = substRecBndrs subst bndrs_w_new_types
+ rhss' = zipWith process bndrs1 rhss
+ process bndr' rhs = etaInfoAppRhs subst' bndr' rhs eis
+ bndrs' | isJoinId (head bndrs)
+ = [ bndr1 `setIdArity` manifestArity rhs'
+ | (bndr1, rhs') <- bndrs1 `zip` rhss' ]
+ -- Arities may have changed
+ -- (see etaInfoAppRhs example)
+ | otherwise
+ = bndrs1
+
+--------------
+-- | Apply the eta info to a binder's RHS. Only interesting for a join point,
+-- where we might have this:
+-- join j :: a -> [a] -> [a]
+-- j x = \xs -> x : xs in jump j z
+-- Eta-expanding produces this:
+-- \ys -> (join j :: a -> [a] -> [a]
+-- j x = \xs -> x : xs in jump j z) ys
+-- Now when we push the application to ys inward (see Note [No crap in
+-- eta-expanded code]), it goes to the body of the RHS of the join point (after
+-- the lambda x!):
+-- \ys -> join j :: a -> [a]
+-- j x = x : ys in jump j z
+-- Note that the type and arity of j have both changed.
+etaInfoAppRhs :: Subst -> CoreBndr -> CoreExpr -> [EtaInfo] -> CoreExpr
+etaInfoAppRhs subst bndr expr eis
+ | Just arity <- isJoinId_maybe bndr
+ = do_join_point arity
+ | otherwise
+ = subst_expr subst expr
+ where
+ do_join_point arity = mkLams join_bndrs' join_body'
+ where
+ (join_bndrs, join_body) = collectNBinders arity expr
+ (subst', join_bndrs') = substBndrs subst join_bndrs
+ join_body' = etaInfoApp subst' join_body eis
+
+
+--------------
+-- | Apply the eta info to a local binder. A join point will have the EtaInfos
+-- applied to its RHS, so its type may change. See comment on etaInfoAppRhs for
+-- an example. See Note [No crap in eta-expanded code] for why all this is
+-- necessary.
+etaInfoAppLocalBndr :: CoreBndr -> [EtaInfo] -> CoreBndr
+etaInfoAppLocalBndr bndr orig_eis
+ = case isJoinId_maybe bndr of
+ Just arity -> bndr `setIdType` modifyJoinResTy arity (app orig_eis) ty
+ Nothing -> bndr
+ where
+ ty = idType bndr
+
+ -- | Apply the given EtaInfos to the result type of the join point.
+ app :: [EtaInfo] -- To apply
+ -> Type -- Result type of join point
+ -> Type -- New result type
+ app [] ty
+ = ty
+ app (EtaVar v : eis) ty
+ | isId v = app eis (funResultTy ty)
+ | otherwise = app eis (piResultTy ty (mkTyVarTy v))
+ app (EtaCo co : eis) ty
+ = ASSERT2(from_ty `eqType` ty, fsep ([text "can't apply", ppr orig_eis,
+ text "to", ppr bndr <+> dcolon <+>
+ ppr (idType bndr)]))
+ app eis to_ty
+ where
+ Pair from_ty to_ty = coercionKind co
+
+--------------
mkEtaWW :: Arity -> CoreExpr -> InScopeSet -> Type
-> (InScopeSet, [EtaInfo])
-- EtaInfo contains fresh variables,
@@ -1018,14 +1113,65 @@ mkEtaWW orig_n orig_expr in_scope orig_ty
--------------
--- Avoiding unnecessary substitution; use short-cutting versions
+-- Don't use short-cutting substitution - we may be changing the types of join
+-- points, so applying the in-scope set is necessary
+-- TODO Check if we actually *are* changing any join points' types
subst_expr :: Subst -> CoreExpr -> CoreExpr
-subst_expr = substExprSC (text "CoreArity:substExpr")
+subst_expr = substExpr (text "CoreArity:substExpr")
+
+
+--------------
-subst_bind :: Subst -> CoreBind -> (Subst, CoreBind)
-subst_bind = substBindSC
+-- | Split an expression into the given number of binders and a body,
+-- eta-expanding if necessary. Counts value *and* type binders.
+etaExpandToJoinPoint :: JoinArity -> CoreExpr -> ([CoreBndr], CoreExpr)
+etaExpandToJoinPoint join_arity expr
+ = go join_arity [] expr
+ where
+ go 0 rev_bs e = (reverse rev_bs, e)
+ go n rev_bs (Lam b e) = go (n-1) (b : rev_bs) e
+ go n rev_bs e = case etaBodyForJoinPoint n e of
+ (bs, e') -> (reverse rev_bs ++ bs, e')
+
+etaExpandToJoinPointRule :: JoinArity -> CoreRule -> CoreRule
+etaExpandToJoinPointRule _ rule@(BuiltinRule {})
+ = WARN(True, (sep [text "Can't eta-expand built-in rule:", ppr rule]))
+ -- How did a local binding get a built-in rule anyway? Probably a plugin.
+ rule
+etaExpandToJoinPointRule join_arity rule@(Rule { ru_bndrs = bndrs, ru_rhs = rhs
+ , ru_args = args })
+ | need_args == 0
+ = rule
+ | need_args < 0
+ = pprPanic "etaExpandToJoinPointRule" (ppr join_arity $$ ppr rule)
+ | otherwise
+ = rule { ru_bndrs = bndrs ++ new_bndrs, ru_args = args ++ new_args
+ , ru_rhs = new_rhs }
+ where
+ need_args = join_arity - length args
+ (new_bndrs, new_rhs) = etaBodyForJoinPoint need_args rhs
+ new_args = varsToCoreExprs new_bndrs
+
+-- Adds as many binders as asked for; assumes expr is not a lambda
+etaBodyForJoinPoint :: Int -> CoreExpr -> ([CoreBndr], CoreExpr)
+etaBodyForJoinPoint need_args body
+ = go need_args (exprType body) (init_subst body) [] body
+ where
+ go 0 _ _ rev_bs e
+ = (reverse rev_bs, e)
+ go n ty subst rev_bs e
+ | Just (tv, res_ty) <- splitForAllTy_maybe ty
+ , let (subst', tv') = Type.substTyVarBndr subst tv
+ = go (n-1) res_ty subst' (tv' : rev_bs) (e `App` Type (mkTyVarTy tv'))
+ | Just (arg_ty, res_ty) <- splitFunTy_maybe ty
+ , let (subst', b) = freshEtaId n subst arg_ty
+ = go (n-1) res_ty subst' (b : rev_bs) (e `App` Var b)
+ | otherwise
+ = pprPanic "etaBodyForJoinPoint" $ int need_args $$
+ ppr body $$ ppr (exprType body)
+ init_subst e = mkEmptyTCvSubst (mkInScopeSet (exprFreeVars e))
--------------
freshEtaId :: Int -> TCvSubst -> Type -> (TCvSubst, Id)
diff --git a/compiler/coreSyn/CoreArity.hs-boot b/compiler/coreSyn/CoreArity.hs-boot
new file mode 100644
index 0000000000..4c155daa9c
--- /dev/null
+++ b/compiler/coreSyn/CoreArity.hs-boot
@@ -0,0 +1,6 @@
+module CoreArity where
+
+import BasicTypes
+import CoreSyn
+
+etaExpandToJoinPoint :: JoinArity -> CoreExpr -> ([CoreBndr], CoreExpr)
diff --git a/compiler/coreSyn/CoreLint.hs b/compiler/coreSyn/CoreLint.hs
index c09b4a0288..a776038f6b 100644
--- a/compiler/coreSyn/CoreLint.hs
+++ b/compiler/coreSyn/CoreLint.hs
@@ -37,6 +37,7 @@ import VarEnv
import VarSet
import Name
import Id
+import IdInfo
import PprCore
import ErrUtils
import Coercion
@@ -168,6 +169,28 @@ different types, called bad coercions. Following coercions are forbidden:
coerced to (# B_1,..,B_m #) if n=m and for each pair A_i, B_i rules
(a-e) holds.
+Note [Join points]
+~~~~~~~~~~~~~~~~~~
+
+We check the rules listed in Note [Invariants on join points] in CoreSyn. The
+only one that causes any difficulty is the first: All occurrences must be tail
+calls. To this end, along with the in-scope set, we remember in le_bad_joins the
+subset of join ids that are no longer allowed because they were declared "too
+far away." For example:
+
+ join j x = ... in
+ case e of
+ A -> jump j y -- good
+ B -> case (jump j z) of -- BAD
+ C -> join h = jump j w in ... -- good
+ D -> let x = jump j v in ... -- BAD
+
+A join point remains valid in case branches, so when checking the A branch, j
+is still valid. When we check the scrutinee of the inner case, however, we add j
+to le_bad_joins and catch the error. Similarly, join points can occur free in
+RHSes of other join points but not the RHSes of value bindings (thunks and
+functions).
+
************************************************************************
* *
Beginning and ending passes
@@ -251,6 +274,7 @@ coreDumpFlag CoreDesugar = Just Opt_D_dump_ds
coreDumpFlag CoreDesugarOpt = Just Opt_D_dump_ds
coreDumpFlag CoreTidy = Just Opt_D_dump_simpl
coreDumpFlag CorePrep = Just Opt_D_dump_prep
+coreDumpFlag CoreOccurAnal = Just Opt_D_dump_occur_anal
coreDumpFlag CoreDoPrintCore = Nothing
coreDumpFlag (CoreDoRuleCheck {}) = Nothing
@@ -473,7 +497,7 @@ lintSingleBinding :: TopLevelFlag -> RecFlag -> (Id, CoreExpr) -> LintM ()
lintSingleBinding top_lvl_flag rec_flag (binder,rhs)
= addLoc (RhsOf binder) $
-- Check the rhs
- do { ty <- lintRhs rhs
+ do { ty <- lintRhs binder rhs
; lint_bndr binder -- Check match to RHS type
; binder_ty <- applySubstTy (idType binder)
; ensureEqTys binder_ty ty (mkRhsMsg binder (text "RHS") ty)
@@ -481,6 +505,7 @@ lintSingleBinding top_lvl_flag rec_flag (binder,rhs)
-- Check the let/app invariant
-- See Note [CoreSyn let/app invariant] in CoreSyn
; checkL (not (isUnliftedType binder_ty)
+ || isJoinId binder
|| (isNonRec rec_flag && exprOkForSpeculation rhs)
|| exprIsLiteralString rhs)
(mkRhsPrimMsg binder rhs)
@@ -501,6 +526,11 @@ lintSingleBinding top_lvl_flag rec_flag (binder,rhs)
(mkTopNonLitStrMsg binder)
; flags <- getLintFlags
+
+ -- Check that if the binder is top-level, it's not a join point
+ ; checkL (not (isJoinId binder && isTopLevel top_lvl_flag))
+ (mkTopJoinMsg binder)
+
; when (lf_check_inline_loop_breakers flags
&& isStrongLoopBreaker (idOccInfo binder)
&& isInlinePragma (idInlinePragma binder))
@@ -535,7 +565,7 @@ lintSingleBinding top_lvl_flag rec_flag (binder,rhs)
ppr binder)
_ -> return ()
- ; mapM_ (lintCoreRule binder_ty) (idCoreRules binder)
+ ; mapM_ (lintCoreRule binder binder_ty) (idCoreRules binder)
; lintIdUnfolding binder binder_ty (idUnfolding binder) }
-- We should check the unfolding, if any, but this is tricky because
@@ -546,20 +576,45 @@ lintSingleBinding top_lvl_flag rec_flag (binder,rhs)
lint_bndr var | isId var = lintIdBndr top_lvl_flag var $ \_ -> return ()
| otherwise = return ()
--- | Checks the RHS of top-level bindings. It only differs from 'lintCoreExpr'
+-- | Checks the RHS of bindings. It only differs from 'lintCoreExpr'
-- in that it doesn't reject occurrences of the function 'makeStatic' when they
--- appear at the top level and @lf_check_static_ptrs == AllowAtTopLevel@.
+-- appear at the top level and @lf_check_static_ptrs == AllowAtTopLevel@, and
+-- for join points, it skips the outer lambdas that take arguments to the
+-- join point.
--
-- See Note [Checking StaticPtrs].
-lintRhs :: CoreExpr -> LintM OutType
-lintRhs rhs = fmap lf_check_static_ptrs getLintFlags >>= go
+lintRhs :: Id -> CoreExpr -> LintM OutType
+lintRhs bndr rhs
+ | Just arity <- isJoinId_maybe bndr
+ = lint_join_lams arity arity True rhs
+ | AlwaysTailCalled arity <- tailCallInfo (idOccInfo bndr)
+ = lint_join_lams arity arity False rhs
+ where
+ lint_join_lams 0 _ _ rhs
+ = lintCoreExpr rhs
+ lint_join_lams n tot enforce (Lam var expr)
+ = addLoc (LambdaBodyOf var) $
+ lintBinder var $ \ var' ->
+ do { body_ty <- lint_join_lams (n-1) tot enforce expr
+ ; return $ mkLamType var' body_ty }
+ lint_join_lams n tot True _other
+ = failWithL $ mkBadJoinArityMsg bndr tot (tot-n)
+ lint_join_lams _ _ False rhs
+ = markAllJoinsBad $ lintCoreExpr rhs
+ -- Future join point, not yet eta-expanded
+ -- Body is not a tail position
+
+-- Allow applications of the data constructor @StaticPtr@ at the top
+-- but produce errors otherwise.
+lintRhs _bndr rhs = fmap lf_check_static_ptrs getLintFlags >>= go
where
-- Allow occurrences of 'makeStatic' at the top-level but produce errors
-- otherwise.
go AllowAtTopLevel
| (binders0, rhs') <- collectTyBinders rhs
, Just (fun, t, info, e) <- collectMakeStaticArgs rhs'
- = foldr
+ = markAllJoinsBad $
+ foldr
-- imitate @lintCoreExpr (Lam ...)@
(\var loopBinders ->
addLoc (LambdaBodyOf var) $
@@ -572,12 +627,12 @@ lintRhs rhs = fmap lf_check_static_ptrs getLintFlags >>= go
addLoc (AnExpr rhs') $ lintCoreArgs fun_ty [Type t, info, e]
)
binders0
- go _ = lintCoreExpr rhs
+ go _ = markAllJoinsBad $ lintCoreExpr rhs
lintIdUnfolding :: Id -> Type -> Unfolding -> LintM ()
lintIdUnfolding bndr bndr_ty (CoreUnfolding { uf_tmpl = rhs, uf_src = src })
| isStableSource src
- = do { ty <- lintCoreExpr rhs
+ = do { ty <- lintRhs bndr rhs
; ensureEqTys bndr_ty ty (mkRhsMsg bndr (text "unfolding") ty) }
lintIdUnfolding bndr bndr_ty (DFunUnfolding { df_con = con, df_bndrs = bndrs
@@ -624,18 +679,13 @@ lintCoreExpr :: CoreExpr -> LintM OutType
-- If you edit this function, you may need to update the GHC formalism
-- See Note [GHC Formalism]
lintCoreExpr (Var var)
- = do { checkL (isNonCoVarId var)
- (text "Non term variable" <+> ppr var)
-
- ; checkDeadIdOcc var
- ; var' <- lookupIdInScope var
- ; return (idType var') }
+ = lintCoreVar var 0
lintCoreExpr (Lit lit)
= return (literalType lit)
lintCoreExpr (Cast expr co)
- = do { expr_ty <- lintCoreExpr expr
+ = do { expr_ty <- markAllJoinsBad $ lintCoreExpr expr
; co' <- applySubstCo co
; (_, k2, from_ty, to_ty, r) <- lintCoercion co'
; lintL (classifiesTypeWithValues k2)
@@ -644,14 +694,20 @@ lintCoreExpr (Cast expr co)
; ensureEqTys from_ty expr_ty (mkCastErr expr co' from_ty expr_ty)
; return to_ty }
-lintCoreExpr (Tick (Breakpoint _ ids) expr)
- = do forM_ ids $ \id -> do
- checkDeadIdOcc id
- lookupIdInScope id
- lintCoreExpr expr
-
-lintCoreExpr (Tick _other_tickish expr)
- = lintCoreExpr expr
+lintCoreExpr (Tick tickish expr)
+ = do case tickish of
+ Breakpoint _ ids -> forM_ ids $ \id -> do
+ checkDeadIdOcc id
+ lookupIdInScope id
+ _ -> return ()
+ markAllJoinsBadIf block_joins $ lintCoreExpr expr
+ where
+ block_joins = not (tickish `tickishScopesLike` SoftScope)
+ -- TODO Consider whether this is the correct rule. It is consistent with
+ -- the simplifier's behaviour - cost-centre-scoped ticks become part of
+ -- the continuation, and thus they behave like part of an evaluation
+ -- context, but soft-scoped and non-scoped ticks simply wrap the result
+ -- (see Simplify.simplTick).
lintCoreExpr (Let (NonRec tv (Type ty)) body)
| isTyVar tv
@@ -661,7 +717,7 @@ lintCoreExpr (Let (NonRec tv (Type ty)) body)
do { addLoc (RhsOf tv) $ lintTyKind tv' ty'
-- Now extend the substitution so we
-- take advantage of it in the body
- ; extendSubstL tv' ty' $
+ ; extendSubstL tv ty' $
addLoc (BodyOfLetRec [tv]) $
lintCoreExpr body } }
@@ -677,6 +733,8 @@ lintCoreExpr (Let (NonRec bndr rhs) body)
lintCoreExpr (Let (Rec pairs) body)
= lintIdBndrs bndrs $ \_ ->
do { checkL (null dups) (dupVars dups)
+ ; checkL (all isJoinId bndrs || all (not . isJoinId) bndrs) $
+ mkInconsistentRecMsg bndrs
; mapM_ (lintSingleBinding NotTopLevel Recursive) pairs
; addLoc (BodyOfLetRec bndrs) (lintCoreExpr body) }
where
@@ -684,24 +742,15 @@ lintCoreExpr (Let (Rec pairs) body)
(_, dups) = removeDups compare bndrs
lintCoreExpr e@(App _ _)
- = do lf <- getLintFlags
- -- Check for a nested occurrence of the StaticPtr constructor.
- -- See Note [Checking StaticPtrs].
- case fun of
- Var b | lf_check_static_ptrs lf /= AllowAnywhere
- , idName b == makeStaticName
- -> do
- failWithL $ text "Found makeStatic nested in an expression: " <+>
- ppr e
- _ -> go
+ = addLoc (AnExpr e) $
+ do { fun_ty <- lintCoreFun fun (length args)
+ ; lintCoreArgs fun_ty args }
where
- go = do { fun_ty <- lintCoreExpr fun
- ; addLoc (AnExpr e) $ lintCoreArgs fun_ty args }
-
(fun, args) = collectArgs e
lintCoreExpr (Lam var expr)
= addLoc (LambdaBodyOf var) $
+ markAllJoinsBad $
lintBinder var $ \ var' ->
do { body_ty <- lintCoreExpr expr
; return $ mkLamType var' body_ty }
@@ -709,7 +758,7 @@ lintCoreExpr (Lam var expr)
lintCoreExpr e@(Case scrut var alt_ty alts) =
-- Check the scrutinee
do { let scrut_diverges = exprIsBottom scrut
- ; scrut_ty <- lintCoreExpr scrut
+ ; scrut_ty <- markAllJoinsBad $ lintCoreExpr scrut
; (alt_ty, _) <- lintInTy alt_ty
; (var_ty, _) <- lintInTy (idType var)
@@ -762,6 +811,63 @@ lintCoreExpr (Coercion co)
= do { (k1, k2, ty1, ty2, role) <- lintInCo co
; return (mkHeteroCoercionType role k1 k2 ty1 ty2) }
+lintCoreVar :: Var -> Int -- Number of arguments (type or value) being passed
+ -> LintM Type -- returns type of the *variable*
+lintCoreVar var nargs
+ = do { checkL (isNonCoVarId var)
+ (text "Non term variable" <+> ppr var)
+
+ ; lf <- getLintFlags
+ -- Check for a nested occurrence of the StaticPtr constructor.
+ -- See Note [Checking StaticPtrs].
+ ; when (nargs /= 0 && lf_check_static_ptrs lf /= AllowAnywhere) $
+ checkL (idName var /= makeStaticName) $
+ text "Found makeStatic nested in an expression"
+
+ ; checkDeadIdOcc var
+ ; ty <- applySubstTy (idType var)
+ ; var' <- lookupIdInScope var
+ ; let ty' = idType var'
+ ; ensureEqTys ty ty' $ mkBndrOccTypeMismatchMsg var' var ty' ty
+ ; mb_join_arity
+ <- case isJoinId_maybe var' of
+ Just join_arity ->
+ do { checkL (isJoinId_maybe var == Just join_arity) $
+ mkJoinBndrOccMismatchMsg var' var
+ ; return $ Just join_arity }
+ Nothing ->
+ case tailCallInfo (idOccInfo var') of
+ AlwaysTailCalled join_arity -> return $ Just join_arity
+ -- This function will be turned into a join point by the
+ -- simplifier; typecheck it as if it already were one
+ NoTailCallInfo -> return $ Nothing
+ ; case mb_join_arity of
+ Just join_arity ->
+ do { bad <- isBadJoin var'
+ ; checkL (not bad) $ mkJoinOutOfScopeMsg var'
+ ; checkL (nargs == join_arity) $
+ mkBadJumpMsg var' join_arity nargs }
+ Nothing ->
+ do { checkL (not (isJoinId var)) $
+ mkJoinBndrOccMismatchMsg var' var }
+ ; return (idType var') }
+
+lintCoreFun :: CoreExpr -> Int -- Number of arguments (type or val) being passed
+ -> LintM Type -- returns type of the *function*
+lintCoreFun (Var var) nargs
+ = lintCoreVar var nargs
+lintCoreFun (Lam var body) nargs
+ -- Act like lintCoreExpr of Lam, but *don't* call markAllJoinsBad; see
+ -- Note [Beta redexes]
+ | nargs /= 0
+ = addLoc (LambdaBodyOf var) $
+ lintBinder var $ \ var' ->
+ do { body_ty <- lintCoreFun body (nargs - 1)
+ ; return $ mkLamType var' body_ty }
+lintCoreFun expr nargs
+ = markAllJoinsBadIf (nargs /= 0) $
+ lintCoreExpr expr
+
{-
Note [No alternatives lint check]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -783,6 +889,33 @@ correct, but that exprIsBottom is unable to see it. In particular, the
empty-type check in exprIsBottom is an approximation. Therefore, this
check is not fully reliable, and we keep both around.
+Note [Beta redexes]
+~~~~~~~~~~~~~~~~~~~
+Consider:
+
+ join j @x y z = ... in
+ (\@x y z -> jump j @x y z) @t e1 e2
+
+This is clearly ill-typed, since the jump is inside both an application and a
+lambda, either of which is enough to disqualify it as a tail call (see Note
+[Invariants on join points] in CoreSyn). However, strictly from a
+lambda-calculus perspective, the term doesn't go wrong---after the two beta
+reductions, the jump *is* a tail call and everything is fine.
+
+Why would we want to allow this when we have let? One reason is that a compound
+beta redex (that is, one with more than one argument) has different scoping
+rules: naively reducing the above example using lets will capture any free
+occurrence of y in e2. More fundamentally, type lets are tricky; many passes,
+such as Float Out, tacitly assume that the incoming program's type lets have
+all been dealt with by the simplifier. Thus we don't want to let-bind any types
+in, say, CoreSubst.simpleOptPgm, which in some circumstances can run immediately
+before Float Out.
+
+All that said, currently CoreSubst.simpleOptPgm is the only thing using this
+loophole, doing so to avoid re-traversing large functions (beta-reducing a type
+lambda without introducing a type let requires a substitution). TODO: Improve
+simpleOptPgm so that we can forget all this ever happened.
+
************************************************************************
* *
\subsection[lintCoreArgs]{lintCoreArgs}
@@ -806,7 +939,7 @@ lintCoreArg fun_ty (Type arg_ty)
; lintTyApp fun_ty arg_ty' }
lintCoreArg fun_ty arg
- = do { arg_ty <- lintCoreExpr arg
+ = do { arg_ty <- markAllJoinsBad $ lintCoreExpr arg
-- See Note [Levity polymorphism invariants] in CoreSyn
; lintL (not (isTypeLevPoly arg_ty))
(text "Levity-polymorphic argument:" <+>
@@ -1225,15 +1358,21 @@ lint_app doc kfn kas
* *
********************************************************************* -}
-lintCoreRule :: OutType -> CoreRule -> LintM ()
-lintCoreRule _ (BuiltinRule {})
+lintCoreRule :: OutVar -> OutType -> CoreRule -> LintM ()
+lintCoreRule _ _ (BuiltinRule {})
= return () -- Don't bother
-lintCoreRule fun_ty (Rule { ru_name = name, ru_bndrs = bndrs
- , ru_args = args, ru_rhs = rhs })
+lintCoreRule fun fun_ty rule@(Rule { ru_name = name, ru_bndrs = bndrs
+ , ru_args = args, ru_rhs = rhs })
= lintBinders bndrs $ \ _ ->
do { lhs_ty <- foldM lintCoreArg fun_ty args
- ; rhs_ty <- lintCoreExpr rhs
+ ; rhs_ty <- case isJoinId_maybe fun of
+ Just join_arity
+ -> do { checkL (args `lengthIs` join_arity) $
+ mkBadJoinPointRuleMsg fun join_arity rule
+ -- See Note [Rules for join points]
+ ; lintCoreExpr rhs }
+ _ -> markAllJoinsBad $ lintCoreExpr rhs
; ensureEqTys lhs_ty rhs_ty $
(rule_doc <+> vcat [ text "lhs type:" <+> ppr lhs_ty
, text "rhs type:" <+> ppr rhs_ty ])
@@ -1273,6 +1412,26 @@ we'll end up with
RULE forall x y. f ($gw y) = $gw (x+1)
This seems sufficiently obscure that there isn't enough payoff to
try to trim the forall'd binder list.
+
+Note [Rules for join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+A join point cannot be partially applied. However, the left-hand side of a rule
+for a join point is effectively a *pattern*, not a piece of code, so there's an
+argument to be made for allowing a situation like this:
+
+ join $sj :: Int -> Int -> String
+ $sj n m = ...
+ j :: forall a. Eq a => a -> a -> String
+ {-# RULES "SPEC j" jump j @ Int $dEq = jump $sj #-}
+ j @a $dEq x y = ...
+
+Applying this rule can't turn a well-typed program into an ill-typed one, so
+conceivably we could allow it. But we can always eta-expand such an
+"undersaturated" rule (see 'CoreArity.etaExpandToJoinPointRule'), and in fact
+the simplifier would have to in order to deal with the RHS. So we take a
+conservative view and don't allow undersaturated rules for join points. See
+Note [Rules and join points] in OccurAnal for further discussion.
-}
{-
@@ -1624,6 +1783,8 @@ data LintEnv
, le_subst :: TCvSubst -- Current type substitution; we also use this
-- to keep track of all the variables in scope,
-- both Ids and TyVars
+ , le_bad_joins :: IdSet -- Join points that are no longer valid
+ -- See Note [Join points]
, le_dynflags :: DynFlags -- DynamicFlags
}
@@ -1734,7 +1895,8 @@ initL dflags flags m
= case unLintM m env (emptyBag, emptyBag) of
(_, errs) -> errs
where
- env = LE { le_flags = flags, le_subst = emptyTCvSubst, le_loc = [], le_dynflags = dflags }
+ env = LE { le_flags = flags, le_subst = emptyTCvSubst, le_loc = []
+ , le_dynflags = dflags, le_bad_joins = emptyVarSet }
getLintFlags :: LintM LintFlags
getLintFlags = LintM $ \ env errs -> (Just (le_flags env), errs)
@@ -1791,8 +1953,11 @@ inCasePat = LintM $ \ env errs -> (Just (is_case_pat env), errs)
addInScopeVars :: [Var] -> LintM a -> LintM a
addInScopeVars vars m
= LintM $ \ env errs ->
- unLintM m (env { le_subst = extendTCvInScopeList (le_subst env) vars })
+ unLintM m (env { le_subst = extendTCvInScopeList (le_subst env) vars
+ , le_bad_joins = bad_joins' env })
errs
+ where
+ bad_joins' env = delVarSetList (le_bad_joins env) (filter isJoinId vars)
addInScopeVarSet :: VarSet -> LintM a -> LintM a
addInScopeVarSet vars m
@@ -1803,7 +1968,11 @@ addInScopeVarSet vars m
addInScopeVar :: Var -> LintM a -> LintM a
addInScopeVar var m
= LintM $ \ env errs ->
- unLintM m (env { le_subst = extendTCvInScope (le_subst env) var }) errs
+ unLintM m (env { le_subst = extendTCvInScope (le_subst env) var
+ , le_bad_joins = bad_joins' env }) errs
+ where
+ bad_joins' env | isJoinId var = delVarSet (le_bad_joins env) var
+ | otherwise = le_bad_joins env
extendSubstL :: TyVar -> Type -> LintM a -> LintM a
extendSubstL tv ty m
@@ -1814,6 +1983,18 @@ updateTCvSubst :: TCvSubst -> LintM a -> LintM a
updateTCvSubst subst' m
= LintM $ \ env errs -> unLintM m (env { le_subst = subst' }) errs
+markAllJoinsBad :: LintM a -> LintM a
+markAllJoinsBad m
+ = LintM $ \ env errs -> unLintM m (marked env) errs
+ where
+ marked env = env { le_bad_joins = filterVarSet isJoinId in_set }
+ where
+ in_set = getInScopeVars (getTCvInScope (le_subst env))
+
+markAllJoinsBadIf :: Bool -> LintM a -> LintM a
+markAllJoinsBadIf True m = markAllJoinsBad m
+markAllJoinsBadIf False m = m
+
getTCvSubst :: LintM TCvSubst
getTCvSubst = LintM (\ env errs -> (Just (le_subst env), errs))
@@ -1839,6 +2020,10 @@ lookupIdInScope id
where
out_of_scope = pprBndr LetBind id <+> text "is out of scope"
+isBadJoin :: Id -> LintM Bool
+isBadJoin id = LintM $ \env errs -> (Just (id `elemVarSet` le_bad_joins env),
+ errs)
+
lintTyCoVarInScope :: Var -> LintM ()
lintTyCoVarInScope v = lintInScope (text "is out of scope") v
@@ -2096,6 +2281,62 @@ mkBadTyVarMsg tv
= text "Non-tyvar used in TyVarTy:"
<+> ppr tv <+> dcolon <+> ppr (varType tv)
+mkTopJoinMsg :: Var -> SDoc
+mkTopJoinMsg var
+ = text "Join point at top level:" <+> ppr var
+
+mkBadJoinArityMsg :: Var -> Int -> Int -> SDoc
+mkBadJoinArityMsg var ar nlams
+ = vcat [ text "Join point has too few lambdas",
+ text "Join var:" <+> ppr var,
+ text "Join arity:" <+> ppr ar,
+ text "Number of lambdas:" <+> ppr nlams ]
+
+mkJoinOutOfScopeMsg :: Var -> SDoc
+mkJoinOutOfScopeMsg var
+ = text "Join variable no longer in scope:" <+> ppr var
+
+mkBadJumpMsg :: Var -> Int -> Int -> SDoc
+mkBadJumpMsg var ar nargs
+ = vcat [ text "Join point invoked with wrong number of arguments",
+ text "Join var:" <+> ppr var,
+ text "Join arity:" <+> ppr ar,
+ text "Number of arguments:" <+> int nargs ]
+
+mkInconsistentRecMsg :: [Var] -> SDoc
+mkInconsistentRecMsg bndrs
+ = vcat [ text "Recursive let binders mix values and join points",
+ text "Binders:" <+> hsep (map ppr_with_details bndrs) ]
+ where
+ ppr_with_details bndr = ppr bndr <> ppr (idDetails bndr)
+
+mkJoinBndrOccMismatchMsg :: Var -> Var -> SDoc
+mkJoinBndrOccMismatchMsg bndr var
+ = vcat [ text "Mismatch in join point status between binder and occurrence",
+ text "Var:" <+> ppr bndr,
+ text "Binder:" <+> ppr_join_status bndr,
+ text "Occ:" <+> ppr_join_status var ]
+ where
+ ppr_join_status v = case details of JoinId _ -> ppr details
+ _ -> text "not a join id"
+ where
+ details = idDetails v
+
+mkBndrOccTypeMismatchMsg :: Var -> Var -> OutType -> OutType -> SDoc
+mkBndrOccTypeMismatchMsg bndr var bndr_ty var_ty
+ = vcat [ text "Mismatch in type between binder and occurrence"
+ , text "Var:" <+> ppr bndr
+ , text "Binder type:" <+> ppr bndr_ty
+ , text "Occurrence type:" <+> ppr var_ty
+ , text " Before subst:" <+> ppr (idType var) ]
+
+mkBadJoinPointRuleMsg :: JoinId -> JoinArity -> CoreRule -> SDoc
+mkBadJoinPointRuleMsg bndr join_arity rule
+ = vcat [ text "Join point has rule with wrong number of arguments"
+ , text "Var:" <+> ppr bndr
+ , text "Join arity:" <+> ppr join_arity
+ , text "Rule:" <+> ppr rule ]
+
pprLeftOrRight :: LeftOrRight -> MsgDoc
pprLeftOrRight CLeft = text "left"
pprLeftOrRight CRight = text "right"
diff --git a/compiler/coreSyn/CorePrep.hs b/compiler/coreSyn/CorePrep.hs
index 4e4cbb9ff1..74de5af82d 100644
--- a/compiler/coreSyn/CorePrep.hs
+++ b/compiler/coreSyn/CorePrep.hs
@@ -204,9 +204,13 @@ corePrepTopBinds initialCorePrepEnv binds
= go initialCorePrepEnv binds
where
go _ [] = return emptyFloats
- go env (bind : binds) = do (env', bind') <- cpeBind TopLevel env bind
- binds' <- go env' binds
- return (bind' `appendFloats` binds')
+ go env (bind : binds) = do (env', floats, maybe_new_bind)
+ <- cpeBind TopLevel env bind
+ MASSERT(isNothing maybe_new_bind)
+ -- Only join points get returned this way by
+ -- cpeBind, and no join point may float to top
+ floatss <- go env' binds
+ return (floats `appendFloats` floatss)
mkDataConWorkers :: DynFlags -> ModLocation -> [TyCon] -> [CoreBind]
-- See Note [Data constructor workers]
@@ -280,6 +284,29 @@ This is all very gruesome and horrible. It would be better to figure
out CafInfo later, after CorePrep. We'll do that in due course.
Meanwhile this horrible hack works.
+Note [Join points and floating]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Join points can float out of other join points but not out of value bindings:
+
+ let z =
+ let w = ... in -- can float
+ join k = ... in -- can't float
+ ... jump k ...
+ join j x1 ... xn =
+ let y = ... in -- can float (but don't want to)
+ join h = ... in -- can float (but not much point)
+ ... jump h ...
+ in ...
+
+Here, the jump to h remains valid if h is floated outward, but the jump to k
+does not.
+
+We don't float *out* of join points. It would only be safe to float out of
+nullary join points (or ones where the arguments are all either type arguments
+or dead binders). Nullary join points aren't ever recursive, so they're always
+effectively one-shot functions, which we don't float out of. We *could* float
+join points from nullary join points, but there's no clear benefit at this
+stage.
Note [Data constructor workers]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -369,8 +396,12 @@ Into this one:
-}
cpeBind :: TopLevelFlag -> CorePrepEnv -> CoreBind
- -> UniqSM (CorePrepEnv, Floats)
+ -> UniqSM (CorePrepEnv,
+ Floats, -- Floating value bindings
+ Maybe CoreBind) -- Just bind' <=> returned new bind; no float
+ -- Nothing <=> added bind' to floats instead
cpeBind top_lvl env (NonRec bndr rhs)
+ | not (isJoinId bndr)
= do { (_, bndr1) <- cpCloneBndr env bndr
; let dmd = idDemandInfo bndr
is_unlifted = isUnliftedType (idType bndr)
@@ -380,7 +411,7 @@ cpeBind top_lvl env (NonRec bndr rhs)
env bndr1 rhs
-- See Note [Inlining in CorePrep]
; if exprIsTrivial rhs2 && isNotTopLevel top_lvl
- then return (extendCorePrepEnvExpr env bndr rhs2, floats)
+ then return (extendCorePrepEnvExpr env bndr rhs2, floats, Nothing)
else do {
; let new_float = mkFloat dmd is_unlifted bndr2 rhs2
@@ -388,19 +419,38 @@ cpeBind top_lvl env (NonRec bndr rhs)
-- We want bndr'' in the envt, because it records
-- the evaluated-ness of the binder
; return (extendCorePrepEnv env bndr bndr2,
- addFloat floats new_float) }}
+ addFloat floats new_float,
+ Nothing) }}
+ | otherwise -- See Note [Join points and floating]
+ = ASSERT(not (isTopLevel top_lvl)) -- can't have top-level join point
+ do { (_, bndr1) <- cpCloneBndr env bndr
+ ; (bndr2, rhs1) <- cpeJoinPair env bndr1 rhs
+ ; return (extendCorePrepEnv env bndr bndr2,
+ emptyFloats,
+ Just (NonRec bndr2 rhs1)) }
cpeBind top_lvl env (Rec pairs)
- = do { let (bndrs,rhss) = unzip pairs
- ; (env', bndrs1) <- cpCloneBndrs env (map fst pairs)
+ | not (isJoinId (head bndrs))
+ = do { (env', bndrs1) <- cpCloneBndrs env bndrs
; stuff <- zipWithM (cpePair top_lvl Recursive topDmd False env') bndrs1 rhss
; let (floats_s, bndrs2, rhss2) = unzip3 stuff
all_pairs = foldrOL add_float (bndrs2 `zip` rhss2)
(concatFloats floats_s)
; return (extendCorePrepEnvList env (bndrs `zip` bndrs2),
- unitFloat (FloatLet (Rec all_pairs))) }
+ unitFloat (FloatLet (Rec all_pairs)),
+ Nothing) }
+ | otherwise -- See Note [Join points and floating]
+ = do { (env', bndrs1) <- cpCloneBndrs env bndrs
+ ; pairs1 <- zipWithM (cpeJoinPair env') bndrs1 rhss
+
+ ; let bndrs2 = map fst pairs1
+ ; return (extendCorePrepEnvList env' (bndrs `zip` bndrs2),
+ emptyFloats,
+ Just (Rec pairs1)) }
where
+ (bndrs, rhss) = unzip pairs
+
-- Flatten all the floats, and the currrent
-- group into a single giant Rec
add_float (FloatLet (NonRec b r)) prs2 = (b,r) : prs2
@@ -413,7 +463,8 @@ cpePair :: TopLevelFlag -> RecFlag -> Demand -> Bool
-> UniqSM (Floats, Id, CpeRhs)
-- Used for all bindings
cpePair top_lvl is_rec dmd is_unlifted env bndr rhs
- = do { (floats1, rhs1) <- cpeRhsE env rhs
+ = ASSERT(not (isJoinId bndr)) -- those should use cpeJoinPair
+ do { (floats1, rhs1) <- cpeRhsE env rhs
-- See if we are allowed to float this stuff out of the RHS
; (floats2, rhs2) <- float_from_rhs floats1 rhs1
@@ -496,6 +547,45 @@ When InlineMe notes go away this won't happen any more. But
it seems good for CorePrep to be robust.
-}
+---------------
+cpeJoinPair :: CorePrepEnv -> JoinId -> CoreExpr
+ -> UniqSM (JoinId, CpeRhs)
+-- Used for all join bindings
+cpeJoinPair env bndr rhs
+ = ASSERT(isJoinId bndr)
+ do { let Just join_arity = isJoinId_maybe bndr
+ (bndrs, body) = collectNBinders join_arity rhs
+
+ ; (env', bndrs') <- cpCloneBndrs env bndrs
+
+ ; body' <- cpeBodyNF env' body -- Will let-bind the body if it starts
+ -- with a lambda
+
+ ; let rhs' = mkCoreLams bndrs' body'
+ bndr' = bndr `setIdUnfolding` evaldUnfolding
+ `setIdArity` count isId bndrs
+ -- See Note [Arity and join points]
+
+ ; return (bndr', rhs') }
+
+{-
+Note [Arity and join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Up to now, we've allowed a join point to have an arity greater than its join
+arity (minus type arguments), since this is what's useful for eta expansion.
+However, for code gen purposes, its arity must be exactly the number of value
+arguments it will be called with, and it must have exactly that many value
+lambdas. Hence if there are extra lambdas we must let-bind the body of the RHS:
+
+ join j x y z = \w -> ... in ...
+ =>
+ join j x y z = (let f = \w -> ... in f) in ...
+
+This is also what happens with Note [Silly extra arguments]. Note that it's okay
+for us to mess with the arity because a join point is never exported.
+-}
+
-- ---------------------------------------------------------------------------
-- CpeRhs: produces a result satisfying CpeRhs
-- ---------------------------------------------------------------------------
@@ -518,10 +608,12 @@ cpeRhsE _env expr@(Lit {}) = return (emptyFloats, expr)
cpeRhsE env expr@(Var {}) = cpeApp env expr
cpeRhsE env expr@(App {}) = cpeApp env expr
-cpeRhsE env (Let bind expr)
- = do { (env', new_binds) <- cpeBind NotTopLevel env bind
- ; (floats, body) <- cpeRhsE env' expr
- ; return (new_binds `appendFloats` floats, body) }
+cpeRhsE env (Let bind body)
+ = do { (env', bind_floats, maybe_bind') <- cpeBind NotTopLevel env bind
+ ; (body_floats, body') <- cpeRhsE env' body
+ ; let expr' = case maybe_bind' of Just bind' -> Let bind' body'
+ Nothing -> body'
+ ; return (bind_floats `appendFloats` body_floats, expr') }
cpeRhsE env (Tick tickish expr)
| tickishPlace tickish == PlaceNonLam && tickish `tickishScopesLike` SoftScope
diff --git a/compiler/coreSyn/CoreStats.hs b/compiler/coreSyn/CoreStats.hs
index 9ad83214ce..4da81fdb03 100644
--- a/compiler/coreSyn/CoreStats.hs
+++ b/compiler/coreSyn/CoreStats.hs
@@ -11,50 +11,64 @@ module CoreStats (
CoreStats(..), coreBindsStats, exprStats,
) where
+import BasicTypes
import CoreSyn
import Outputable
import Coercion
import Var
import Type (Type, typeSize, seqType)
-import Id (idType)
+import Id (idType, isJoinId)
import CoreSeq (megaSeqIdInfo)
data CoreStats = CS { cs_tm :: Int -- Terms
, cs_ty :: Int -- Types
- , cs_co :: Int } -- Coercions
+ , cs_co :: Int -- Coercions
+ , cs_vb :: Int -- Local value bindings
+ , cs_jb :: Int } -- Local join bindings
instance Outputable CoreStats where
- ppr (CS { cs_tm = i1, cs_ty = i2, cs_co = i3 })
+ ppr (CS { cs_tm = i1, cs_ty = i2, cs_co = i3, cs_vb = i4, cs_jb = i5 })
= braces (sep [text "terms:" <+> intWithCommas i1 <> comma,
text "types:" <+> intWithCommas i2 <> comma,
- text "coercions:" <+> intWithCommas i3])
+ text "coercions:" <+> intWithCommas i3 <> comma,
+ text "joins:" <+> intWithCommas i5 <> char '/' <>
+ intWithCommas (i4 + i5) ])
plusCS :: CoreStats -> CoreStats -> CoreStats
-plusCS (CS { cs_tm = p1, cs_ty = q1, cs_co = r1 })
- (CS { cs_tm = p2, cs_ty = q2, cs_co = r2 })
- = CS { cs_tm = p1+p2, cs_ty = q1+q2, cs_co = r1+r2 }
+plusCS (CS { cs_tm = p1, cs_ty = q1, cs_co = r1, cs_vb = v1, cs_jb = j1 })
+ (CS { cs_tm = p2, cs_ty = q2, cs_co = r2, cs_vb = v2, cs_jb = j2 })
+ = CS { cs_tm = p1+p2, cs_ty = q1+q2, cs_co = r1+r2, cs_vb = v1+v2
+ , cs_jb = j1+j2 }
zeroCS, oneTM :: CoreStats
-zeroCS = CS { cs_tm = 0, cs_ty = 0, cs_co = 0 }
+zeroCS = CS { cs_tm = 0, cs_ty = 0, cs_co = 0, cs_vb = 0, cs_jb = 0 }
oneTM = zeroCS { cs_tm = 1 }
sumCS :: (a -> CoreStats) -> [a] -> CoreStats
sumCS f = foldr (plusCS . f) zeroCS
coreBindsStats :: [CoreBind] -> CoreStats
-coreBindsStats = sumCS bindStats
+coreBindsStats = sumCS (bindStats TopLevel)
-bindStats :: CoreBind -> CoreStats
-bindStats (NonRec v r) = bindingStats v r
-bindStats (Rec prs) = sumCS (\(v,r) -> bindingStats v r) prs
+bindStats :: TopLevelFlag -> CoreBind -> CoreStats
+bindStats top_lvl (NonRec v r) = bindingStats top_lvl v r
+bindStats top_lvl (Rec prs) = sumCS (\(v,r) -> bindingStats top_lvl v r) prs
-bindingStats :: Var -> CoreExpr -> CoreStats
-bindingStats v r = bndrStats v `plusCS` exprStats r
+bindingStats :: TopLevelFlag -> Var -> CoreExpr -> CoreStats
+bindingStats top_lvl v r = letBndrStats top_lvl v `plusCS` exprStats r
bndrStats :: Var -> CoreStats
bndrStats v = oneTM `plusCS` tyStats (varType v)
+letBndrStats :: TopLevelFlag -> Var -> CoreStats
+letBndrStats top_lvl v
+ | isTyVar v || isTopLevel top_lvl = bndrStats v
+ | isJoinId v = oneTM { cs_jb = 1 } `plusCS` ty_stats
+ | otherwise = oneTM { cs_vb = 1 } `plusCS` ty_stats
+ where
+ ty_stats = tyStats (varType v)
+
exprStats :: CoreExpr -> CoreStats
exprStats (Var {}) = oneTM
exprStats (Lit {}) = oneTM
@@ -62,7 +76,7 @@ exprStats (Type t) = tyStats t
exprStats (Coercion c) = coStats c
exprStats (App f a) = exprStats f `plusCS` exprStats a
exprStats (Lam b e) = bndrStats b `plusCS` exprStats e
-exprStats (Let b e) = bindStats b `plusCS` exprStats e
+exprStats (Let b e) = bindStats NotTopLevel b `plusCS` exprStats e
exprStats (Case e b _ as) = exprStats e `plusCS` bndrStats b
`plusCS` sumCS altStats as
exprStats (Cast e co) = coStats co `plusCS` exprStats e
diff --git a/compiler/coreSyn/CoreSubst.hs b/compiler/coreSyn/CoreSubst.hs
index 72df704e1c..9d69493d9e 100644
--- a/compiler/coreSyn/CoreSubst.hs
+++ b/compiler/coreSyn/CoreSubst.hs
@@ -39,6 +39,10 @@ module CoreSubst (
#include "HsVersions.h"
+import {-# SOURCE #-} CoreArity ( etaExpandToJoinPoint )
+ -- Needed for simpleOptPgm to convert bindings to join
+ -- points, but CoreArity uses substitutions throughout
+
import CoreSyn
import CoreFVs
import CoreSeq
@@ -867,6 +871,9 @@ simpleOptExpr :: CoreExpr -> CoreExpr
-- We also inline bindings that bind a Eq# box: see
-- See Note [Getting the map/coerce RULE to work].
--
+-- Also we convert functions to join points where possible (as
+-- the occurrence analyser does most of the work anyway).
+--
-- The result is NOT guaranteed occurrence-analysed, because
-- in (let x = y in ....) we substitute for x; so y's occ-info
-- may change radically
@@ -1012,8 +1019,9 @@ simple_opt_bind' subst (Rec prs)
= (subst'', res_bind)
where
res_bind = Just (Rec (reverse rev_prs'))
- (subst', bndrs') = subst_opt_bndrs subst (map fst prs)
- (subst'', rev_prs') = foldl do_pr (subst', []) (prs `zip` bndrs')
+ prs' = map (uncurry convert_if_marked) prs
+ (subst', bndrs') = subst_opt_bndrs subst (map fst prs')
+ (subst'', rev_prs') = foldl do_pr (subst', []) (prs' `zip` bndrs')
do_pr (subst, prs) ((b,r), b')
= case maybe_substitute subst b r2 of
Just subst' -> (subst', prs)
@@ -1023,7 +1031,20 @@ simple_opt_bind' subst (Rec prs)
r2 = simple_opt_expr subst r
simple_opt_bind' subst (NonRec b r)
- = simple_opt_out_bind subst (b, simple_opt_expr subst r)
+ = simple_opt_out_bind subst (b', simple_opt_expr subst r')
+ where
+ (b', r') = convert_if_marked b r
+
+convert_if_marked :: InVar -> InExpr -> (InVar, InExpr)
+convert_if_marked bndr rhs
+ | isId bndr
+ , AlwaysTailCalled ar <- tailCallInfo (idOccInfo bndr)
+ -- Marked to become a join point
+ , (bndrs, body) <- etaExpandToJoinPoint ar rhs
+ = -- Tail call info now unnecessary
+ (zapIdTailCallInfo (bndr `asJoinId` ar), mkLams bndrs body)
+ | otherwise
+ = (bndr, rhs)
----------------------
simple_opt_out_bind :: Subst -> (InVar, OutExpr) -> (Subst, Maybe CoreBind)
@@ -1072,8 +1093,10 @@ maybe_substitute subst b r
safe_to_inline :: OccInfo -> Bool
safe_to_inline (IAmALoopBreaker {}) = False
safe_to_inline IAmDead = True
- safe_to_inline (OneOcc in_lam one_br _) = (not in_lam && one_br) || trivial
- safe_to_inline NoOccInfo = trivial
+ safe_to_inline occ@(OneOcc {}) = (not (occ_in_lam occ) &&
+ occ_one_br occ)
+ || trivial
+ safe_to_inline (ManyOccs {}) = trivial
trivial | exprIsTrivial r = True
| (Var fun, args) <- collectArgs r
diff --git a/compiler/coreSyn/CoreSyn.hs b/compiler/coreSyn/CoreSyn.hs
index 333a55b8e3..f74e3e585a 100644
--- a/compiler/coreSyn/CoreSyn.hs
+++ b/compiler/coreSyn/CoreSyn.hs
@@ -3,7 +3,7 @@
(c) The GRASP/AQUA Project, Glasgow University, 1992-1998
-}
-{-# LANGUAGE CPP, DeriveDataTypeable #-}
+{-# LANGUAGE CPP, DeriveDataTypeable, FlexibleContexts #-}
-- | CoreSyn holds all the main data types for use by for the Glasgow Haskell Compiler midsection
module CoreSyn (
@@ -21,7 +21,7 @@ module CoreSyn (
-- ** 'Expr' construction
mkLets, mkLams,
- mkApps, mkTyApps, mkCoApps, mkVarApps,
+ mkApps, mkTyApps, mkCoApps, mkVarApps, mkTyArg,
mkIntLit, mkIntLitInt,
mkWordLit, mkWordLitWord,
@@ -38,6 +38,7 @@ module CoreSyn (
-- ** Simple 'Expr' access functions and predicates
bindersOf, bindersOfBinds, rhssOfBind, rhssOfAlts,
collectBinders, collectTyBinders, collectTyAndValBinders,
+ collectNBinders,
collectArgs, collectArgsTicks, flattenBinds,
exprToType, exprToCoercion_maybe,
@@ -75,7 +76,8 @@ module CoreSyn (
collectAnnArgs, collectAnnArgsTicks,
-- ** Operations on annotations
- deAnnotate, deAnnotate', deAnnAlt, collectAnnBndrs,
+ deAnnotate, deAnnotate', deAnnAlt,
+ collectAnnBndrs, collectNAnnBndrs,
-- * Orphanhood
IsOrphan(..), isOrphan, notOrphan, chooseOrphanAnchor,
@@ -408,7 +410,8 @@ The let/app invariant
the right hand side of a non-recursive 'Let', and
the argument of an 'App',
/may/ be of unlifted type, but only if
- the expression is ok-for-speculation.
+ the expression is ok-for-speculation
+ or the 'Let' is for a join point.
This means that the let can be floated around
without difficulty. For example, this is OK:
@@ -510,6 +513,181 @@ this exhaustive list can be empty!
conversion; remember STG is un-typed, so there is no need for
the empty case to do the type conversion.
+Note [Join points]
+~~~~~~~~~~~~~~~~~~
+In Core, a *join point* is a specially tagged function whose only occurrences
+are saturated tail calls. A tail call can appear in these places:
+
+ 1. In the branches (not the scrutinee) of a case
+ 2. Underneath a let (value or join point)
+ 3. Inside another join point
+
+We write a join-point declaration as
+ join j @a @b x y = e1 in e2,
+like a let binding but with "join" instead (or "join rec" for "let rec"). Note
+that we put the parameters before the = rather than using lambdas; this is
+because it's relevant how many parameters the join point takes *as a join
+point.* This number is called the *join arity,* distinct from arity because it
+counts types as well as values. Note that a join point may return a lambda! So
+ join j x = x + 1
+is different from
+ join j = \x -> x + 1
+The former has join arity 1, while the latter has join arity 0.
+
+The identifier for a join point is called a join id or a *label.* An invocation
+is called a *jump.* We write a jump using the jump keyword:
+
+ jump j 3
+
+The words *label* and *jump* are evocative of assembly code (or Cmm) for a
+reason: join points are indeed compiled as labeled blocks, and jumps become
+actual jumps (plus argument passing and stack adjustment). There is no closure
+allocated and only a fraction of the function-call overhead. Hence we would
+like as many functions as possible to become join points (see OccurAnal) and
+the type rules for join points ensure we preserve the properties that make them
+efficient.
+
+In the actual AST, a join point is indicated by the IdDetails of the binder: a
+local value binding gets 'VanillaId' but a join point gets a 'JoinId' with its
+join arity.
+
+For more details, see the paper:
+
+ Luke Maurer, Paul Downen, Zena Ariola, and Simon Peyton Jones. "Compiling
+ without continuations." Submitted to PLDI'17.
+
+ https://www.microsoft.com/en-us/research/publication/compiling-without-continuations/
+
+Note [Invariants on join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Join points must follow these invariants:
+
+ 1. All occurrences must be tail calls. Each of these tail calls must pass the
+ same number of arguments, counting both types and values; we call this the
+ "join arity" (to distinguish from regular arity, which only counts values).
+ 2. For join arity n, the right-hand side must begin with at least n lambdas.
+ 3. If the binding is recursive, then all other bindings in the recursive group
+ must also be join points.
+ 4. The binding's type must not be polymorphic in its return type (as defined
+ in Note [The polymorphism rule of join points]).
+
+Examples:
+
+ join j1 x = 1 + x in jump j (jump j x) -- Fails 1: non-tail call
+ join j1' x = 1 + x in if even a
+ then jump j1 a
+ else jump j1 a b -- Fails 1: inconsistent calls
+ join j2 x = flip (+) x in j2 1 2 -- Fails 2: not enough lambdas
+ join j2' x = \y -> x + y in j3 1 -- Passes: extra lams ok
+ join j @a (x :: a) = x -- Fails 4: polymorphic in ret type
+
+Invariant 1 applies to left-hand sides of rewrite rules, so a rule for a join
+point must have an exact call as its LHS.
+
+Strictly speaking, invariant 3 is redundant, since a call from inside a lazy
+binding isn't a tail call. Since a let-bound value can't invoke a free join
+point, then, they can't be mutually recursive. (A Core binding group *can*
+include spurious extra bindings if the occurrence analyser hasn't run, so
+invariant 3 does still need to be checked.) For the rigorous definition of
+"tail call", see Section 3 of the paper (Note [Join points]).
+
+Invariant 4 is subtle; see Note [The polymorphism rule of join points].
+
+Core Lint will check these invariants, anticipating that any binder whose
+OccInfo is marked AlwaysTailCalled will become a join point as soon as the
+simplifier (or simpleOptPgm) runs.
+
+Note [The type of a join point]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+A join point has the same type it would have as a function. That is, if it takes
+an Int and a Bool and its body produces a String, its type is `Int -> Bool ->
+String`. Natural as this may seem, it can be awkward. A join point shouldn't be
+thought to "return" in the same sense a function does---a jump is one-way. This
+is crucial for understanding how case-of-case interacts with join points:
+
+ case (join
+ j :: Int -> Bool -> String
+ j x y = ...
+ in
+ jump j z w) of
+ "" -> True
+ _ -> False
+
+The simplifier will pull the case into the join point (see Note [Case-of-case
+and join points] in Simplify):
+
+ join
+ j :: Int -> Bool -> Bool -- changed!
+ j x y = case ... of "" -> True
+ _ -> False
+ in
+ jump j z w
+
+The body of the join point now returns a Bool, so the label `j` has to have its
+type updated accordingly. Inconvenient though this may be, it has the advantage
+that 'CoreUtils.exprType' can still return a type for any expression, including
+a jump.
+
+This differs from the paper (see Note [Invariants on join points]). In the
+paper, we instead give j the type `Int -> Bool -> forall a. a`. Then each jump
+carries the "return type" as a parameter, exactly the way other non-returning
+functions like `error` work:
+
+ case (join
+ j :: Int -> Bool -> forall a. a
+ j x y = ...
+ in
+ jump j z w @String) of
+ "" -> True
+ _ -> False
+
+Now we can move the case inward and we only have to change the jump:
+
+ join
+ j :: Int -> Bool -> forall a. a
+ j x y = case ... of "" -> True
+ _ -> False
+ in
+ jump j z w @Bool
+
+(Core Lint would still check that the body of the join point has the right type;
+that type would simply not be reflected in the join id.)
+
+Note [The polymorphism rule of join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Invariant 4 of Note [Invariants on join points] forbids a join point to be
+polymorphic in its return type. That is, if its type is
+
+ forall a1 ... ak. t1 -> ... -> tn -> r
+
+where its join arity is k+n, none of the type parameters ai may occur free in r.
+The most direct explanation is that given
+
+ join j @a1 ... @ak x1 ... xn = e1 in e2
+
+our typing rules require `e1` and `e2` to have the same type. Therefore the type
+of `e1`---the return type of the join point---must be the same as the type of
+e2. Since the type variables aren't bound in `e2`, its type can't include them,
+and thus neither can the type of `e1`.
+
+There's a deeper explanation in terms of the sequent calculus in Section 5.3 of
+a previous paper:
+
+ Paul Downen, Luke Maurer, Zena Ariola, and Simon Peyton Jones. "Sequent
+ calculus as a compiler intermediate language." ICFP'16.
+
+ https://www.microsoft.com/en-us/research/wp-content/uploads/2016/04/sequent-calculus-icfp16.pdf
+
+The quick version: Consider the CPS term (the paper uses the sequent calculus,
+but we can translate readily):
+
+ \k -> join j @a1 ... @ak x1 ... xn = e1 k in e2 k
+
+Since `j` is a join point, it doesn't bind a continuation variable but reuses
+the variable `k` from the context. But the parameters `ai` are not in `k`'s
+scope, and `k`'s type determines the return type of `j`; thus the `ai`s don't
+appear in the return type of `j`. (Also, since `e1` and `e2` are passed the same
+continuation, they must have the same type; hence the direct explanation above.)
************************************************************************
* *
@@ -1534,10 +1712,16 @@ type TaggedAlt t = Alt (TaggedBndr t)
instance Outputable b => Outputable (TaggedBndr b) where
ppr (TB b l) = char '<' <> ppr b <> comma <> ppr l <> char '>'
-instance Outputable b => OutputableBndr (TaggedBndr b) where
+-- OutputableBndr Var is declared separately in PprCore; using a FlexibleContext
+-- to avoid circularity
+instance (OutputableBndr Var, Outputable b) =>
+ OutputableBndr (TaggedBndr b) where
pprBndr _ b = ppr b -- Simple
pprInfixOcc b = ppr b
pprPrefixOcc b = ppr b
+ pprNonRecBndrKeyword (TB b _) = pprNonRecBndrKeyword b
+ pprRecBndrKeyword (TB b _) = pprRecBndrKeyword b
+ pprLamsOnLhs (TB b _) = pprLamsOnLhs b
deTagExpr :: TaggedExpr t -> CoreExpr
deTagExpr (Var v) = Var v
@@ -1584,17 +1768,17 @@ mkCoApps f args = foldl (\ e a -> App e (Coercion a)) f args
mkVarApps f vars = foldl (\ e a -> App e (varToCoreExpr a)) f vars
mkConApp con args = mkApps (Var (dataConWorkId con)) args
-mkTyApps f args = foldl (\ e a -> App e (typeOrCoercion a)) f args
- where
- typeOrCoercion ty
- | Just co <- isCoercionTy_maybe ty = Coercion co
- | otherwise = Type ty
+mkTyApps f args = foldl (\ e a -> App e (mkTyArg a)) f args
mkConApp2 :: DataCon -> [Type] -> [Var] -> Expr b
mkConApp2 con tys arg_ids = Var (dataConWorkId con)
`mkApps` map Type tys
`mkApps` map varToCoreExpr arg_ids
+mkTyArg :: Type -> Expr b
+mkTyArg ty
+ | Just co <- isCoercionTy_maybe ty = Coercion co
+ | otherwise = Type ty
-- | Create a machine integer literal expression of type @Int#@ from an @Integer@.
-- If you want an expression of type @Int@ use 'MkCore.mkIntExpr'
@@ -1750,6 +1934,9 @@ collectBinders :: Expr b -> ([b], Expr b)
collectTyBinders :: CoreExpr -> ([TyVar], CoreExpr)
collectValBinders :: CoreExpr -> ([Id], CoreExpr)
collectTyAndValBinders :: CoreExpr -> ([TyVar], [Id], CoreExpr)
+-- | Strip off exactly N leading lambdas (type or value). Good for use with
+-- join points.
+collectNBinders :: Int -> Expr b -> ([b], Expr b)
collectBinders expr
= go [] expr
@@ -1775,6 +1962,13 @@ collectTyAndValBinders expr
(tvs, body1) = collectTyBinders expr
(ids, body) = collectValBinders body1
+collectNBinders orig_n orig_expr
+ = go orig_n [] orig_expr
+ where
+ go 0 bs expr = (reverse bs, expr)
+ go n bs (Lam b e) = go (n-1) (b:bs) e
+ go _ _ _ = pprPanic "collectNBinders" $ int orig_n
+
-- | Takes a nested application expression and returns the the function
-- being applied and the arguments to which it is applied
collectArgs :: Expr b -> (Expr b, [Arg b])
@@ -1929,3 +2123,12 @@ collectAnnBndrs e
where
collect bs (_, AnnLam b body) = collect (b:bs) body
collect bs body = (reverse bs, body)
+
+-- | As 'collectNBinders' but for 'AnnExpr' rather than 'Expr'
+collectNAnnBndrs :: Int -> AnnExpr bndr annot -> ([bndr], AnnExpr bndr annot)
+collectNAnnBndrs orig_n e
+ = collect orig_n [] e
+ where
+ collect 0 bs body = (reverse bs, body)
+ collect n bs (_, AnnLam b body) = collect (n-1) (b:bs) body
+ collect _ _ _ = pprPanic "collectNBinders" $ int orig_n
diff --git a/compiler/coreSyn/CoreUnfold.hs b/compiler/coreSyn/CoreUnfold.hs
index 574d8418d6..11c4a5e92d 100644
--- a/compiler/coreSyn/CoreUnfold.hs
+++ b/compiler/coreSyn/CoreUnfold.hs
@@ -523,15 +523,13 @@ sizeExpr dflags bOMB_OUT_SIZE top_args expr
| otherwise = size_up e
size_up (Let (NonRec binder rhs) body)
- = size_up rhs `addSizeNSD`
- size_up body `addSizeN`
- (if isUnliftedType (idType binder) then 0 else 10)
- -- For the allocation
- -- If the binder has an unlifted type there is no allocation
+ = size_up_rhs (binder, rhs) `addSizeNSD`
+ size_up body `addSizeN`
+ size_up_alloc binder
size_up (Let (Rec pairs) body)
- = foldr (addSizeNSD . size_up . snd)
- (size_up body `addSizeN` (10 * length pairs)) -- (length pairs) for the allocation
+ = foldr (addSizeNSD . size_up_rhs)
+ (size_up body `addSizeN` sum (map (size_up_alloc . fst) pairs))
pairs
size_up (Case e _ _ alts)
@@ -606,6 +604,14 @@ sizeExpr dflags bOMB_OUT_SIZE top_args expr
| otherwise
= False
+ size_up_rhs (bndr, rhs)
+ | Just join_arity <- isJoinId_maybe bndr
+ -- Skip arguments to join point
+ , (_bndrs, body) <- collectNBinders join_arity rhs
+ = size_up body
+ | otherwise
+ = size_up rhs
+
------------
-- size_up_app is used when there's ONE OR MORE value args
size_up_app (App fun arg) args voids
@@ -642,6 +648,16 @@ sizeExpr dflags bOMB_OUT_SIZE top_args expr
-- A good example is Foreign.C.Error.errrnoToIOError
------------
+ -- Cost to allocate binding with given binder
+ size_up_alloc bndr
+ | isTyVar bndr -- Doesn't exist at runtime
+ || isUnliftedType (idType bndr) -- Doesn't live in heap
+ || isJoinId bndr -- Not allocated at all
+ = 0
+ | otherwise
+ = 10
+
+ ------------
-- These addSize things have to be here because
-- I don't want to give them bOMB_OUT_SIZE as an argument
addSizeN TooBig _ = TooBig
@@ -706,6 +722,17 @@ callSize
-> Int
callSize n_val_args voids = 10 * (1 + n_val_args - voids)
+-- | The size of a jump to a join point
+jumpSize
+ :: Int -- ^ number of value args
+ -> Int -- ^ number of value args that are void
+ -> Int
+jumpSize n_val_args voids = 2 * (1 + n_val_args - voids)
+ -- A jump is 20% the size of a function call. Making jumps free reopens
+ -- bug #6048, but making them any more expensive loses a 21% improvement in
+ -- spectral/puzzle. TODO Perhaps adjusting the default threshold would be a
+ -- better solution?
+
funSize :: DynFlags -> [Id] -> Id -> Int -> Int -> ExprSize
-- Size for functions that are not constructors or primops
-- Note [Function applications]
@@ -715,9 +742,11 @@ funSize dflags top_args fun n_val_args voids
| otherwise = SizeIs size arg_discount res_discount
where
some_val_args = n_val_args > 0
+ is_join = isJoinId fun
- size | some_val_args = callSize n_val_args voids
- | otherwise = 0
+ size | is_join = jumpSize n_val_args voids
+ | not some_val_args = 0
+ | otherwise = callSize n_val_args voids
-- The 1+ is for the function itself
-- Add 1 for each non-trivial arg;
-- the allocation cost, as in let(rec)
diff --git a/compiler/coreSyn/CoreUtils.hs b/compiler/coreSyn/CoreUtils.hs
index d856e3d3d5..4eef079b32 100644
--- a/compiler/coreSyn/CoreUtils.hs
+++ b/compiler/coreSyn/CoreUtils.hs
@@ -49,7 +49,10 @@ module CoreUtils (
stripTicksE, stripTicksT,
-- * StaticPtr
- collectMakeStaticArgs
+ collectMakeStaticArgs,
+
+ -- * Join points
+ isJoinBind
) where
#include "HsVersions.h"
@@ -2304,3 +2307,17 @@ collectMakeStaticArgs e
| (fun@(Var b), [Type t, loc, arg], _) <- collectArgsTicks (const True) e
, idName b == makeStaticName = Just (fun, t, loc, arg)
collectMakeStaticArgs _ = Nothing
+
+{-
+************************************************************************
+* *
+\subsection{Join points}
+* *
+************************************************************************
+-}
+
+-- | Does this binding bind a join point (or a recursive group of join points)?
+isJoinBind :: CoreBind -> Bool
+isJoinBind (NonRec b _) = isJoinId b
+isJoinBind (Rec ((b, _) : _)) = isJoinId b
+isJoinBind _ = False
diff --git a/compiler/coreSyn/MkCore.hs b/compiler/coreSyn/MkCore.hs
index 882faa7f92..7d2420245a 100644
--- a/compiler/coreSyn/MkCore.hs
+++ b/compiler/coreSyn/MkCore.hs
@@ -107,6 +107,7 @@ sortQuantVars vs = sorted_tcvs ++ ids
mkCoreLet :: CoreBind -> CoreExpr -> CoreExpr
mkCoreLet (NonRec bndr rhs) body -- See Note [CoreSyn let/app invariant]
| needsCaseBinding (idType bndr) rhs
+ , not (isJoinId bndr)
= Case rhs bndr (exprType body) [(DEFAULT,[],body)]
mkCoreLet bind body
= Let bind body
diff --git a/compiler/coreSyn/PprCore.hs b/compiler/coreSyn/PprCore.hs
index 152a701991..196a9b9973 100644
--- a/compiler/coreSyn/PprCore.hs
+++ b/compiler/coreSyn/PprCore.hs
@@ -29,6 +29,7 @@ import Type
import Coercion
import DynFlags
import BasicTypes
+import Maybes
import Util
import Outputable
import FastString
@@ -113,7 +114,14 @@ ppr_bind ann (Rec binds) = vcat (map pp binds)
ppr_binding :: OutputableBndr b => Annotation b -> (b, Expr b) -> SDoc
ppr_binding ann (val_bdr, expr)
= ann expr $$ pprBndr LetBind val_bdr $$
- hang (ppr val_bdr <+> equals) 2 (pprCoreExpr expr)
+ hang (ppr val_bdr <+> sep (map (pprBndr LambdaBind) lhs_bndrs) <+> equals) 2
+ (pprCoreExpr rhs)
+ where
+ (bndrs, body) = collectBinders expr
+ (lhs_bndrs, rhs_bndrs) = splitAt (pprLamsOnLhs val_bdr) bndrs
+ rhs = mkLams rhs_bndrs body
+ -- Returns ([], expr) unless it's a join point, in which
+ -- case we want the args before the =
pprParendExpr expr = ppr_expr parens expr
pprCoreExpr expr = ppr_expr noParens expr
@@ -131,7 +139,8 @@ ppr_expr :: OutputableBndr b => (SDoc -> SDoc) -> Expr b -> SDoc
-- The function adds parens in context that need
-- an atomic value (e.g. function args)
-ppr_expr _ (Var name) = ppr name
+ppr_expr _ (Var name) = ppWhen (isJoinId name) (text "jump") <+>
+ ppr name
ppr_expr add_par (Type ty) = add_par (text "TYPE:" <+> ppr ty) -- Weird
ppr_expr add_par (Coercion co) = add_par (text "CO:" <+> ppr co)
ppr_expr add_par (Lit lit) = pprLiteral add_par lit
@@ -172,7 +181,10 @@ ppr_expr add_par expr@(App {})
tc = dataConTyCon dc
saturated = val_args `lengthIs` idArity f
- _ -> parens (hang (ppr f) 2 pp_args)
+ _ -> parens (hang fun_doc 2 pp_args)
+ where
+ fun_doc | isJoinId f = text "jump" <+> ppr f
+ | otherwise = ppr f
_ -> parens (hang (pprParendExpr fun) 2 pp_args)
}
@@ -239,12 +251,14 @@ ppr_expr add_par (Let bind@(NonRec val_bdr rhs) expr@(Let _ _))
-- General case (recursive case, too)
ppr_expr add_par (Let bind expr)
= add_par $
- sep [hang (ptext keyword) 2 (ppr_bind noAnn bind <+> text "} in"),
+ sep [hang (keyword <+> char '{') 2 (ppr_bind noAnn bind <+> text "} in"),
pprCoreExpr expr]
where
keyword = case bind of
- Rec _ -> (sLit "letrec {")
- NonRec _ _ -> (sLit "let {")
+ NonRec b _ -> pprNonRecBndrKeyword b
+ Rec ((b,_):_) -> pprRecBndrKeyword b
+ Rec [] -> text "let" -- This *shouldn't* happen, but
+ -- let's be tolerant here
ppr_expr add_par (Tick tickish expr)
= sdocWithDynFlags $ \dflags ->
@@ -315,6 +329,11 @@ instance OutputableBndr Var where
pprBndr = pprCoreBinder
pprInfixOcc = pprInfixName . varName
pprPrefixOcc = pprPrefixName . varName
+ pprNonRecBndrKeyword bndr | isJoinId bndr = text "join"
+ | otherwise = text "let"
+ pprRecBndrKeyword bndr | isJoinId bndr = text "joinrec"
+ | otherwise = text "letrec"
+ pprLamsOnLhs bndr = isJoinId_maybe bndr `orElse` 0
pprCoreBinder :: BindingSite -> Var -> SDoc
pprCoreBinder LetBind binder
@@ -398,7 +417,7 @@ pprIdBndrInfo info
lbv_info = oneShotInfo info
has_prag = not (isDefaultInlinePragma prag_info)
- has_occ = not (isNoOcc occ_info)
+ has_occ = not (isManyOccs occ_info)
has_dmd = not $ isTopDmd dmd_info
has_lbv = not (hasNoOneShotInfo lbv_info)
diff --git a/compiler/deSugar/DsUtils.hs b/compiler/deSugar/DsUtils.hs
index 0d336adbd9..165130aa94 100644
--- a/compiler/deSugar/DsUtils.hs
+++ b/compiler/deSugar/DsUtils.hs
@@ -893,6 +893,15 @@ for the primitive case:
\end{verbatim}
Now @fail.33@ is a function, so it can be let-bound.
+
+We would *like* to use join points here; in fact, these "fail variables" are
+paradigmatic join points! Sadly, this breaks pattern synonyms, which desugar as
+CPS functions - i.e. they take "join points" as parameters. It's not impossible
+to imagine extending our type system to allow passing join points around (very
+carefully), but we certainly don't support it now.
+
+99.99% of the time, the fail variables wind up as join points in short order
+anyway, and the Void# doesn't do much harm.
-}
mkFailurePair :: CoreExpr -- Result type of the whole case expression
@@ -912,6 +921,11 @@ mkFailurePair expr
{-
Note [Failure thunks and CPR]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+(This note predates join points as formal entities (hence the quotation marks).
+We can't use actual join points here (see above); if we did, this would also
+solve the CPR problem, since join points don't get CPR'd. See Note [Don't CPR
+join points] in WorkWrap.)
+
When we make a failure point we ensure that it
does not look like a thunk. Example:
diff --git a/compiler/iface/IfaceSyn.hs b/compiler/iface/IfaceSyn.hs
index d4dd51e1b7..7740977263 100644
--- a/compiler/iface/IfaceSyn.hs
+++ b/compiler/iface/IfaceSyn.hs
@@ -10,7 +10,7 @@ module IfaceSyn (
IfaceDecl(..), IfaceFamTyConFlav(..), IfaceClassOp(..), IfaceAT(..),
IfaceConDecl(..), IfaceConDecls(..), IfaceEqSpec,
- IfaceExpr(..), IfaceAlt, IfaceLetBndr(..),
+ IfaceExpr(..), IfaceAlt, IfaceLetBndr(..), IfaceJoinInfo(..),
IfaceBinding(..), IfaceConAlt(..),
IfaceIdInfo(..), IfaceIdDetails(..), IfaceUnfolding(..),
IfaceInfoItem(..), IfaceRule(..), IfaceAnnotation(..), IfaceAnnTarget,
@@ -502,7 +502,10 @@ data IfaceBinding
-- IfaceLetBndr is like IfaceIdBndr, but has IdInfo too
-- It's used for *non-top-level* let/rec binders
-- See Note [IdInfo on nested let-bindings]
-data IfaceLetBndr = IfLetBndr IfLclName IfaceType IfaceIdInfo
+data IfaceLetBndr = IfLetBndr IfLclName IfaceType IfaceIdInfo IfaceJoinInfo
+
+data IfaceJoinInfo = IfaceNotJoinPoint
+ | IfaceJoinPoint JoinArity
{-
Note [Empty case alternatives]
@@ -1158,8 +1161,8 @@ ppr_con_bs :: IfaceConAlt -> [IfLclName] -> SDoc
ppr_con_bs con bs = ppr con <+> hsep (map ppr bs)
ppr_bind :: (IfaceLetBndr, IfaceExpr) -> SDoc
-ppr_bind (IfLetBndr b ty info, rhs)
- = sep [hang (ppr b <+> dcolon <+> ppr ty) 2 (ppr info),
+ppr_bind (IfLetBndr b ty info ji, rhs)
+ = sep [hang (ppr b <+> dcolon <+> ppr ty) 2 (ppr ji <+> ppr info),
equals <+> pprIfaceExpr noParens rhs]
------------------
@@ -1207,6 +1210,10 @@ instance Outputable IfaceInfoItem where
ppr HsNoCafRefs = text "HasNoCafRefs"
ppr HsLevity = text "Never levity-polymorphic"
+instance Outputable IfaceJoinInfo where
+ ppr IfaceNotJoinPoint = empty
+ ppr (IfaceJoinPoint ar) = angleBrackets (text "join" <+> ppr ar)
+
instance Outputable IfaceUnfolding where
ppr (IfCompulsory e) = text "<compulsory>" <+> parens (ppr e)
ppr (IfCoreUnfold s e) = (if s
@@ -1407,8 +1414,8 @@ freeNamesIfLetBndr :: IfaceLetBndr -> NameSet
-- Remember IfaceLetBndr is used only for *nested* bindings
-- The IdInfo can contain an unfolding (in the case of
-- local INLINE pragmas), so look there too
-freeNamesIfLetBndr (IfLetBndr _name ty info) = freeNamesIfType ty
- &&& freeNamesIfIdInfo info
+freeNamesIfLetBndr (IfLetBndr _name ty info _ji) = freeNamesIfType ty
+ &&& freeNamesIfIdInfo info
freeNamesIfTvBndr :: IfaceTvBndr -> NameSet
freeNamesIfTvBndr (_fs,k) = freeNamesIfKind k
@@ -2075,14 +2082,27 @@ instance Binary IfaceBinding where
_ -> do { ac <- get bh; return (IfaceRec ac) }
instance Binary IfaceLetBndr where
- put_ bh (IfLetBndr a b c) = do
+ put_ bh (IfLetBndr a b c d) = do
put_ bh a
put_ bh b
put_ bh c
+ put_ bh d
get bh = do a <- get bh
b <- get bh
c <- get bh
- return (IfLetBndr a b c)
+ d <- get bh
+ return (IfLetBndr a b c d)
+
+instance Binary IfaceJoinInfo where
+ put_ bh IfaceNotJoinPoint = putByte bh 0
+ put_ bh (IfaceJoinPoint ar) = do
+ putByte bh 1
+ put_ bh ar
+ get bh = do
+ h <- getByte bh
+ case h of
+ 0 -> return IfaceNotJoinPoint
+ _ -> liftM IfaceJoinPoint $ get bh
instance Binary IfaceTyConParent where
put_ bh IfNoParent = putByte bh 0
diff --git a/compiler/iface/TcIface.hs b/compiler/iface/TcIface.hs
index e08a3d71f6..f6a4f41965 100644
--- a/compiler/iface/TcIface.hs
+++ b/compiler/iface/TcIface.hs
@@ -1367,12 +1367,13 @@ tcIfaceExpr (IfaceCase scrut case_bndr alts) = do
alts' <- mapM (tcIfaceAlt scrut' tc_app) alts
return (Case scrut' case_bndr' (coreAltsType alts') alts')
-tcIfaceExpr (IfaceLet (IfaceNonRec (IfLetBndr fs ty info) rhs) body)
+tcIfaceExpr (IfaceLet (IfaceNonRec (IfLetBndr fs ty info ji) rhs) body)
= do { name <- newIfaceName (mkVarOccFS fs)
; ty' <- tcIfaceType ty
; id_info <- tcIdInfo False {- Don't ignore prags; we are inside one! -}
name ty' info
; let id = mkLocalIdOrCoVarWithInfo name ty' id_info
+ `asJoinId_maybe` tcJoinInfo ji
; rhs' <- tcIfaceExpr rhs
; body' <- extendIfaceIdEnv [id] (tcIfaceExpr body)
; return (Let (NonRec id rhs') body') }
@@ -1384,11 +1385,11 @@ tcIfaceExpr (IfaceLet (IfaceRec pairs) body)
; body' <- tcIfaceExpr body
; return (Let (Rec pairs') body') } }
where
- tc_rec_bndr (IfLetBndr fs ty _)
+ tc_rec_bndr (IfLetBndr fs ty _ ji)
= do { name <- newIfaceName (mkVarOccFS fs)
; ty' <- tcIfaceType ty
- ; return (mkLocalIdOrCoVar name ty') }
- tc_pair (IfLetBndr _ _ info, rhs) id
+ ; return (mkLocalIdOrCoVar name ty' `asJoinId_maybe` tcJoinInfo ji) }
+ tc_pair (IfLetBndr _ _ info _, rhs) id
= do { rhs' <- tcIfaceExpr rhs
; id_info <- tcIdInfo False {- Don't ignore prags; we are inside one! -}
(idName id) (idType id) info
@@ -1509,6 +1510,10 @@ tcIdInfo ignore_prags name ty info = do
| otherwise = info
; return (info1 `setUnfoldingInfo` unf) }
+tcJoinInfo :: IfaceJoinInfo -> Maybe JoinArity
+tcJoinInfo (IfaceJoinPoint ar) = Just ar
+tcJoinInfo IfaceNotJoinPoint = Nothing
+
tcUnfolding :: Name -> Type -> IdInfo -> IfaceUnfolding -> IfL Unfolding
tcUnfolding name _ info (IfCoreUnfold stable if_expr)
= do { dflags <- getDynFlags
diff --git a/compiler/iface/ToIface.hs b/compiler/iface/ToIface.hs
index 696d0ffc0f..37d41f4393 100644
--- a/compiler/iface/ToIface.hs
+++ b/compiler/iface/ToIface.hs
@@ -325,6 +325,7 @@ toIfaceLetBndr :: Id -> IfaceLetBndr
toIfaceLetBndr id = IfLetBndr (occNameFS (getOccName id))
(toIfaceType (idType id))
(toIfaceIdInfo (idInfo id))
+ (toIfaceJoinInfo (isJoinId_maybe id))
-- Put into the interface file any IdInfo that CoreTidy.tidyLetBndr
-- has left on the Id. See Note [IdInfo on nested let-bindings] in IfaceSyn
@@ -382,6 +383,10 @@ toIfaceIdInfo id_info
levity_hsinfo | isNeverLevPolyIdInfo id_info = Just HsLevity
| otherwise = Nothing
+toIfaceJoinInfo :: Maybe JoinArity -> IfaceJoinInfo
+toIfaceJoinInfo (Just ar) = IfaceJoinPoint ar
+toIfaceJoinInfo Nothing = IfaceNotJoinPoint
+
--------------------------
toIfUnfolding :: Bool -> Unfolding -> Maybe IfaceInfoItem
toIfUnfolding lb (CoreUnfolding { uf_tmpl = rhs
diff --git a/compiler/simplCore/CSE.hs b/compiler/simplCore/CSE.hs
index f9314bd362..971b3e0ea6 100644
--- a/compiler/simplCore/CSE.hs
+++ b/compiler/simplCore/CSE.hs
@@ -11,7 +11,7 @@ module CSE (cseProgram) where
#include "HsVersions.h"
import CoreSubst
-import Var ( Var )
+import Var ( Var, isJoinId )
import Id ( Id, idType, idUnfolding, idInlineActivation
, zapIdOccInfo, zapIdUsageInfo )
import CoreUtils ( mkAltExpr
@@ -245,6 +245,18 @@ not if you are using unsafe casts. I actually found a case where we
had
(x :: HValue) |> (UnsafeCo :: HValue ~ Array# Int)
+Note [CSE for join points?]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+We must not be naive about join points in CSE:
+ join j = e in
+ if b then jump j else 1 + e
+The expression (1 + jump j) is not good (see Note [Invariants on join points] in
+CoreSyn). This seems to come up quite seldom, but it happens (first seen
+compiling ppHtml in Haddock.Backends.Xhtml).
+
+We could try and be careful by tracking which join points are still valid at
+each subexpression, but since join points aren't allocated or shared, there's
+less to gain by trying to CSE them.
************************************************************************
* *
@@ -304,6 +316,8 @@ addBinding env in_id out_id rhs'
-- See Note [CSE for INLINE and NOINLINE]
|| isStableUnfolding (idUnfolding out_id)
-- See Note [CSE for stable unfoldings]
+ || isJoinId in_id
+ -- See Note [CSE for join points?]
-- Should we use SUBSTITUTE or EXTEND?
-- See Note [CSE for bindings]
diff --git a/compiler/simplCore/CoreMonad.hs b/compiler/simplCore/CoreMonad.hs
index 12e69b97e2..7b807765a8 100644
--- a/compiler/simplCore/CoreMonad.hs
+++ b/compiler/simplCore/CoreMonad.hs
@@ -133,6 +133,7 @@ data CoreToDo -- These are diff core-to-core passes,
| CoreTidy
| CorePrep
+ | CoreOccurAnal
instance Outputable CoreToDo where
ppr (CoreDoSimplify _ _) = text "Simplifier"
@@ -152,6 +153,7 @@ instance Outputable CoreToDo where
ppr CoreDesugarOpt = text "Desugar (after optimization)"
ppr CoreTidy = text "Tidy Core"
ppr CorePrep = text "CorePrep"
+ ppr CoreOccurAnal = text "Occurrence analysis"
ppr CoreDoPrintCore = text "Print core"
ppr (CoreDoRuleCheck {}) = text "Rule check"
ppr CoreDoNothing = text "CoreDoNothing"
diff --git a/compiler/simplCore/FloatIn.hs b/compiler/simplCore/FloatIn.hs
index f32b5a387b..1fd969e638 100644
--- a/compiler/simplCore/FloatIn.hs
+++ b/compiler/simplCore/FloatIn.hs
@@ -23,7 +23,7 @@ import MkCore
import CoreUtils ( exprIsDupable, exprIsExpandable,
exprOkForSideEffects, mkTicks )
import CoreFVs
-import Id ( isOneShotBndr, idType )
+import Id ( isJoinId, isJoinId_maybe, isOneShotBndr, idType )
import Var
import Type ( isUnliftedType )
import VarSet
@@ -31,6 +31,7 @@ import Util
import DynFlags
import Outputable
import Data.List( mapAccumL )
+import BasicTypes ( RecFlag(..), isRec )
{-
Top-level interface function, @floatInwards@. Note that we do not
@@ -160,18 +161,25 @@ fiExpr dflags to_drop ann_expr@(_,AnnApp {})
(zipWith (fiExpr dflags) arg_drops ann_args)
where
(ann_fun, ann_args, ticks) = collectAnnArgsTicks tickishFloatable ann_expr
- (extra_fvs, arg_fvs) = mapAccumL mk_arg_fvs emptyDVarSet ann_args
+ (extra_fvs0, fun_fvs)
+ | (_, AnnVar _) <- ann_fun = (freeVarsOf ann_fun, emptyDVarSet)
+ -- Don't float the binding for f into f x y z; see Note [Join points]
+ -- for why we *can't* do it when f is a join point. (If f isn't a
+ -- join point, floating it in isn't especially harmful but it's
+ -- useless since the simplifier will immediately float it back out.)
+ | otherwise = (emptyDVarSet, freeVarsOf ann_fun)
+ (extra_fvs, arg_fvs) = mapAccumL mk_arg_fvs extra_fvs0 ann_args
mk_arg_fvs :: FreeVarSet -> CoreExprWithFVs -> (FreeVarSet, FreeVarSet)
mk_arg_fvs extra_fvs ann_arg
- | noFloatIntoRhs ann_arg
+ | noFloatIntoRhs False NonRecursive ann_arg
= (extra_fvs `unionDVarSet` freeVarsOf ann_arg, emptyDVarSet)
| otherwise
= (extra_fvs, freeVarsOf ann_arg)
drop_here : extra_drop : fun_drop : arg_drops
= sepBindsByDropPoint dflags False
- (extra_fvs : freeVarsOf ann_fun : arg_fvs)
+ (extra_fvs : fun_fvs : arg_fvs)
(freeVarsOfType ann_fun `unionDVarSet`
mapUnionDVarSet freeVarsOfType ann_args)
to_drop
@@ -186,6 +194,28 @@ We don't want to float bindings into here
because that might destroy the let/app invariant, which requires
unlifted function arguments to be ok-for-speculation.
+Note [Join points]
+~~~~~~~~~~~~~~~~~~
+
+Generally, we don't need to worry about join points - there are places we're
+not allowed to float them, but since they can't have occurrences in those
+places, we're not tempted.
+
+We do need to be careful about jumps, however:
+
+ joinrec j x y z = ... in
+ jump j a b c
+
+Previous versions often floated the definition of a recursive function into its
+only non-recursive occurrence. But for a join point, this is a disaster:
+
+ (joinrec j x y z = ... in
+ jump j) a b c -- wrong!
+
+Every jump must be exact, so the jump to j must have three arguments. Hence
+we're careful not to float into the target of a jump (though we can float into
+the arguments just fine).
+
Note [Floating in past a lambda group]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* We must be careful about floating inside a value lambda.
@@ -221,6 +251,9 @@ So we treat lambda in groups, using the following rule:
This is what the 'go' function in the AnnLam case is doing.
+(Join points are handled similarly: a join point is considered one-shot iff
+it's non-recursive, so we float only into non-recursive join points.)
+
Urk! if all are tyvars, and we don't float in, we may miss an
opportunity to float inside a nested case branch
-}
@@ -308,11 +341,14 @@ fiExpr dflags to_drop (_,AnnLet (AnnNonRec id rhs) body)
rhs_fvs = freeVarsOf rhs
rule_fvs = idRuleAndUnfoldingVarsDSet id -- See Note [extra_fvs (2): free variables of rules]
- extra_fvs | noFloatIntoRhs rhs = rule_fvs `unionDVarSet` freeVarsOf rhs
- | otherwise = rule_fvs
+ extra_fvs | noFloatIntoRhs (isJoinId id) NonRecursive rhs
+ = rule_fvs `unionDVarSet` freeVarsOf rhs
+ | otherwise
+ = rule_fvs
-- See Note [extra_fvs (1): avoid floating into RHS]
-- No point in floating in only to float straight out again
- -- Ditto ok-for-speculation unlifted RHSs
+ -- We *can't* float into ok-for-speculation unlifted RHSs
+ -- But do float into join points
[shared_binds, extra_binds, rhs_binds, body_binds]
= sepBindsByDropPoint dflags False
@@ -327,7 +363,7 @@ fiExpr dflags to_drop (_,AnnLet (AnnNonRec id rhs) body)
shared_binds -- the bindings used both in rhs and body
-- Push rhs_binds into the right hand side of the binding
- rhs' = fiExpr dflags rhs_binds rhs
+ rhs' = fiRhs dflags rhs_binds id rhs
rhs_fvs' = rhs_fvs `unionDVarSet` floatedBindsFVs rhs_binds `unionDVarSet` rule_fvs
-- Don't forget the rule_fvs; the binding mentions them!
@@ -341,8 +377,8 @@ fiExpr dflags to_drop (_,AnnLet (AnnRec bindings) body)
-- See Note [extra_fvs (1,2)]
rule_fvs = mapUnionDVarSet idRuleAndUnfoldingVarsDSet ids
extra_fvs = rule_fvs `unionDVarSet`
- unionDVarSets [ freeVarsOf rhs | rhs@(_, rhs') <- rhss
- , noFloatIntoExpr rhs' ]
+ unionDVarSets [ freeVarsOf rhs | (bndr, rhs) <- bindings
+ , noFloatIntoRhs (isJoinId bndr) Recursive rhs ]
(shared_binds:extra_binds:body_binds:rhss_binds)
= sepBindsByDropPoint dflags False
@@ -367,7 +403,7 @@ fiExpr dflags to_drop (_,AnnLet (AnnRec bindings) body)
-> [(Id, CoreExpr)]
fi_bind to_drops pairs
- = [ (binder, fiExpr dflags to_drop rhs)
+ = [ (binder, fiRhs dflags to_drop binder rhs)
| ((binder, rhs), to_drop) <- zipEqual "fi_bind" pairs to_drops ]
{-
@@ -418,7 +454,8 @@ fiExpr dflags to_drop (_, AnnCase scrut case_bndr ty alts)
-- Float into the alts with the is_case flag set
(drop_here2 : alts_drops_s)
- = sepBindsByDropPoint dflags True alts_fvs all_alts_ty_fvs alts_drops
+ = sepBindsByDropPoint dflags True alts_fvs all_alts_ty_fvs
+ alts_drops
scrut_fvs = freeVarsOf scrut
alts_fvs = map alt_fvs alts
@@ -434,17 +471,29 @@ fiExpr dflags to_drop (_, AnnCase scrut case_bndr ty alts)
fi_alt to_drop (con, args, rhs) = (con, args, fiExpr dflags to_drop rhs)
+fiRhs :: DynFlags -> FloatInBinds -> CoreBndr -> CoreExprWithFVs -> CoreExpr
+fiRhs dflags to_drop bndr rhs
+ | Just join_arity <- isJoinId_maybe bndr
+ , let (bndrs, body) = collectNAnnBndrs join_arity rhs
+ = mkLams bndrs (fiExpr dflags to_drop body)
+ | otherwise
+ = fiExpr dflags to_drop rhs
+
okToFloatInside :: [Var] -> Bool
okToFloatInside bndrs = all ok bndrs
where
ok b = not (isId b) || isOneShotBndr b
-- Push the floats inside there are no non-one-shot value binders
-noFloatIntoRhs :: CoreExprWithFVs -> Bool
+noFloatIntoRhs :: Bool -> RecFlag -> CoreExprWithFVs -> Bool
-- ^ True if it's a bad idea to float bindings into this RHS
-- Preconditio: rhs :: rhs_ty
-noFloatIntoRhs rhs@(_, rhs')
- = isUnliftedType rhs_ty -- See Note [Do not destroy the let/app invariant]
+noFloatIntoRhs is_join is_rec rhs@(_, rhs')
+ | is_join
+ = isRec is_rec -- Joins are one-shot iff non-recursive
+ | otherwise
+ = isUnliftedType rhs_ty
+ -- See Note [Do not destroy the let/app invariant]
|| noFloatIntoExpr rhs'
where
rhs_ty = exprTypeFV rhs
diff --git a/compiler/simplCore/FloatOut.hs b/compiler/simplCore/FloatOut.hs
index 10955d2861..17ffba404c 100644
--- a/compiler/simplCore/FloatOut.hs
+++ b/compiler/simplCore/FloatOut.hs
@@ -19,7 +19,8 @@ import CoreMonad ( FloatOutSwitches(..) )
import DynFlags
import ErrUtils ( dumpIfSet_dyn )
-import Id ( Id, idArity, isBottomingId )
+import Id ( Id, idArity, idType, isBottomingId,
+ isJoinId, isJoinId_maybe )
import Var ( Var )
import SetLevels
import UniqSupply ( UniqSupply )
@@ -27,8 +28,11 @@ import Bag
import Util
import Maybes
import Outputable
+import Type
import qualified Data.IntMap as M
+import Data.List ( partition )
+
#include "HsVersions.h"
{-
@@ -104,6 +108,52 @@ vwhich might usefully be separated to
@
Well, maybe. We don't do this at the moment.
+Note [Join points]
+~~~~~~~~~~~~~~~~~~
+Every occurrence of a join point must be a tail call (see Note [Invariants on
+join points] in CoreSyn), so we must be careful with how far we float them. The
+mechanism for doing so is the *join ceiling*, detailed in Note [Join ceiling]
+in SetLevels. For us, the significance is that a binder might be marked to be
+dropped at the nearest boundary between tail calls and non-tail calls. For
+example:
+
+ (< join j = ... in
+ let x = < ... > in
+ case < ... > of
+ A -> ...
+ B -> ...
+ >) < ... > < ... >
+
+Here the join ceilings are marked with angle brackets. Either side of an
+application is a join ceiling, as is the scrutinee position of a case
+expression or the RHS of a let binding (but not a join point).
+
+Why do we *want* do float join points at all? After all, they're never
+allocated, so there's no sharing to be gained by floating them. However, the
+other benefit of floating is making RHSes small, and this can have a significant
+impact. In particular, stream fusion has been known to produce nested loops like
+this:
+
+ joinrec j1 x1 =
+ joinrec j2 x2 =
+ joinrec j3 x3 = ... jump j1 (x3 + 1) ... jump j2 (x3 + 1) ...
+ in jump j3 x2
+ in jump j2 x1
+ in jump j1 x
+
+(Assume x1 and x2 do *not* occur free in j3.)
+
+Here j1 and j2 are wholly superfluous---each of them merely forwards its
+argument to j3. Since j3 only refers to x3, we can float j2 and j3 to make
+everything one big mutual recursion:
+
+ joinrec j1 x1 = jump j2 x1
+ j2 x2 = jump j3 x2
+ j3 x3 = ... jump j1 (x3 + 1) ... jump j2 (x3 + 1) ...
+ in jump j1 x
+
+Now the simplifier will happily inline the trivial j1 and j2, leaving only j3.
+Without floating, we're stuck with three loops instead of one.
************************************************************************
* *
@@ -141,8 +191,11 @@ floatTopBind bind
= case (floatBind bind) of { (fs, floats, bind') ->
let float_bag = flattenTopFloats floats
in case bind' of
- Rec prs -> (fs, unitBag (Rec (addTopFloatPairs float_bag prs)))
- NonRec {} -> (fs, float_bag `snocBag` bind') }
+ -- bind' can't have unlifted values or join points, so can only be one
+ -- value bind, rec or non-rec (see comment on floatBind)
+ [Rec prs] -> (fs, unitBag (Rec (addTopFloatPairs float_bag prs)))
+ [NonRec b e] -> (fs, float_bag `snocBag` NonRec b e)
+ _ -> pprPanic "floatTopBind" (ppr bind') }
{-
************************************************************************
@@ -152,42 +205,76 @@ floatTopBind bind
************************************************************************
-}
-floatBind :: LevelledBind -> (FloatStats, FloatBinds, CoreBind)
+floatBind :: LevelledBind -> (FloatStats, FloatBinds, [CoreBind])
+ -- Returns a list with either
+ -- * A single non-recursive binding (value or join point), or
+ -- * The following, in order:
+ -- * Zero or more non-rec unlifted bindings
+ -- * One or both of:
+ -- * A recursive group of join binds
+ -- * A recursive group of value binds
+ -- See Note [Floating out of Rec rhss] for why things get arranged this way.
floatBind (NonRec (TB var _) rhs)
- = case (floatExpr rhs) of { (fs, rhs_floats, rhs') ->
+ = case (floatRhs var rhs) of { (fs, rhs_floats, rhs') ->
-- A tiresome hack:
-- see Note [Bottoming floats: eta expansion] in SetLevels
let rhs'' | isBottomingId var = etaExpand (idArity var) rhs'
| otherwise = rhs'
- in (fs, rhs_floats, NonRec var rhs'') }
+ in (fs, rhs_floats, [NonRec var rhs'']) }
floatBind (Rec pairs)
= case floatList do_pair pairs of { (fs, rhs_floats, new_pairs) ->
- (fs, rhs_floats, Rec (concat new_pairs)) }
+ let (new_ul_pairss, new_other_pairss) = unzip new_pairs
+ (new_join_pairs, new_l_pairs) = partition (isJoinId . fst)
+ (concat new_other_pairss)
+ -- Can't put the join points and the values in the same rec group
+ new_rec_binds | null new_join_pairs = [ Rec new_l_pairs ]
+ | null new_l_pairs = [ Rec new_join_pairs ]
+ | otherwise = [ Rec new_l_pairs
+ , Rec new_join_pairs ]
+ new_non_rec_binds = [ NonRec b e | (b, e) <- concat new_ul_pairss ]
+ in
+ (fs, rhs_floats, new_non_rec_binds ++ new_rec_binds) }
where
+ do_pair :: (LevelledBndr, LevelledExpr)
+ -> (FloatStats, FloatBinds,
+ ([(Id,CoreExpr)], -- Non-recursive unlifted value bindings
+ [(Id,CoreExpr)])) -- Join points and lifted value bindings
do_pair (TB name spec, rhs)
| isTopLvl dest_lvl -- See Note [floatBind for top level]
- = case (floatExpr rhs) of { (fs, rhs_floats, rhs') ->
- (fs, emptyFloats, addTopFloatPairs (flattenTopFloats rhs_floats) [(name, rhs')])}
+ = case (floatRhs name rhs) of { (fs, rhs_floats, rhs') ->
+ (fs, emptyFloats, ([], addTopFloatPairs (flattenTopFloats rhs_floats)
+ [(name, rhs')]))}
| otherwise -- Note [Floating out of Rec rhss]
- = case (floatExpr rhs) of { (fs, rhs_floats, rhs') ->
+ = case (floatRhs name rhs) of { (fs, rhs_floats, rhs') ->
case (partitionByLevel dest_lvl rhs_floats) of { (rhs_floats', heres) ->
- case (splitRecFloats heres) of { (pairs, case_heres) ->
- (fs, rhs_floats', (name, installUnderLambdas case_heres rhs') : pairs) }}}
+ case (splitRecFloats heres) of { (ul_pairs, pairs, case_heres) ->
+ let pairs' = (name, installUnderLambdas case_heres rhs') : pairs in
+ (fs, rhs_floats', (ul_pairs, pairs')) }}}
where
dest_lvl = floatSpecLevel spec
-splitRecFloats :: Bag FloatBind -> ([(Id,CoreExpr)], Bag FloatBind)
+splitRecFloats :: Bag FloatBind
+ -> ([(Id,CoreExpr)], -- Non-recursive unlifted value bindings
+ [(Id,CoreExpr)], -- Join points and lifted value bindings
+ Bag FloatBind) -- A tail of further bindings
-- The "tail" begins with a case
-- See Note [Floating out of Rec rhss]
splitRecFloats fs
- = go [] (bagToList fs)
+ = go [] [] (bagToList fs)
where
- go prs (FloatLet (NonRec b r) : fs) = go ((b,r):prs) fs
- go prs (FloatLet (Rec prs') : fs) = go (prs' ++ prs) fs
- go prs fs = (prs, listToBag fs)
+ go ul_prs prs (FloatLet (NonRec b r) : fs) | isUnliftedType (idType b)
+ , not (isJoinId b)
+ = go ((b,r):ul_prs) prs fs
+ | otherwise
+ = go ul_prs ((b,r):prs) fs
+ go ul_prs prs (FloatLet (Rec prs') : fs) = go ul_prs (prs' ++ prs) fs
+ go ul_prs prs fs = (reverse ul_prs, prs,
+ listToBag fs)
+ -- Order only matters for
+ -- non-rec
installUnderLambdas :: Bag FloatBind -> CoreExpr -> CoreExpr
-- Note [Floating out of Rec rhss]
@@ -227,6 +314,31 @@ So, gruesomely, we split the floats into
This loses full-laziness the rare situation where there is a
FloatCase and a Rec interacting.
+If there are unlifted FloatLets (that *aren't* join points) among the floats,
+we can't add them to the recursive group without angering Core Lint, but since
+they must be ok-for-speculation, they can't actually be making any recursive
+calls, so we can safely pull them out and keep them non-recursive.
+
+(Why is something getting floated to <1,0> that doesn't make a recursive call?
+The case that came up in testing was that f *and* the unlifted binding were
+getting floated *to the same place*:
+
+ \x<2,0> ->
+ ... <3,0>
+ letrec { f<F<2,0>> =
+ ... let x'<F<2,0>> = x +# 1# in ...
+ } in ...
+
+Everything gets labeled "float to <2,0>" because it all depends on x, but this
+makes f and x' look mutually recursive when they're not.
+
+The test was shootout/k-nucleotide, as compiled using commit 47d5dd68 on the
+wip/join-points branch.
+
+TODO: This can probably be solved somehow in SetLevels. The difference between
+"this *is at* level <2,0>" and "this *depends on* level <2,0>" is very
+important.)
+
Note [floatBind for top level]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
We may have a *nested* binding whose destination level is (FloatMe tOP_LEVEL), thus
@@ -285,27 +397,28 @@ floatExpr (Coercion co) = (zeroStats, emptyFloats, Coercion co)
floatExpr (Lit lit) = (zeroStats, emptyFloats, Lit lit)
floatExpr (App e a)
- = case (floatExpr e) of { (fse, floats_e, e') ->
- case (floatExpr a) of { (fsa, floats_a, a') ->
+ = case (atJoinCeiling $ floatExpr e) of { (fse, floats_e, e') ->
+ case (atJoinCeiling $ floatExpr a) of { (fsa, floats_a, a') ->
(fse `add_stats` fsa, floats_e `plusFloats` floats_a, App e' a') }}
floatExpr lam@(Lam (TB _ lam_spec) _)
= let (bndrs_w_lvls, body) = collectBinders lam
bndrs = [b | TB b _ <- bndrs_w_lvls]
- bndr_lvl = floatSpecLevel lam_spec
+ bndr_lvl = asJoinCeilLvl (floatSpecLevel lam_spec)
-- All the binders have the same level
-- See SetLevels.lvlLamBndrs
+ -- Use asJoinCeilLvl to make this the join ceiling
in
case (floatBody bndr_lvl body) of { (fs, floats, body') ->
(add_to_stats fs floats, floats, mkLams bndrs body') }
floatExpr (Tick tickish expr)
| tickish `tickishScopesLike` SoftScope -- not scoped, can just float
- = case (floatExpr expr) of { (fs, floating_defns, expr') ->
+ = case (atJoinCeiling $ floatExpr expr) of { (fs, floating_defns, expr') ->
(fs, floating_defns, Tick tickish expr') }
| not (tickishCounts tickish) || tickishCanSplit tickish
- = case (floatExpr expr) of { (fs, floating_defns, expr') ->
+ = case (atJoinCeiling $ floatExpr expr) of { (fs, floating_defns, expr') ->
let -- Annotate bindings floated outwards past an scc expression
-- with the cc. We mark that cc as "duplicated", though.
annotated_defns = wrapTick (mkNoCount tickish) floating_defns
@@ -321,25 +434,27 @@ floatExpr (Tick tickish expr)
= pprPanic "floatExpr tick" (ppr tickish)
floatExpr (Cast expr co)
- = case (floatExpr expr) of { (fs, floating_defns, expr') ->
+ = case (atJoinCeiling $ floatExpr expr) of { (fs, floating_defns, expr') ->
(fs, floating_defns, Cast expr' co) }
floatExpr (Let bind body)
= case bind_spec of
FloatMe dest_lvl
- -> case (floatBind bind) of { (fsb, bind_floats, bind') ->
+ -> case (floatBind bind) of { (fsb, bind_floats, binds') ->
case (floatExpr body) of { (fse, body_floats, body') ->
+ let new_bind_floats = foldr plusFloats emptyFloats
+ (map (unitLetFloat dest_lvl) binds') in
( add_stats fsb fse
- , bind_floats `plusFloats` unitLetFloat dest_lvl bind'
+ , bind_floats `plusFloats` new_bind_floats
`plusFloats` body_floats
, body') }}
StayPut bind_lvl -- See Note [Avoiding unnecessary floating]
- -> case (floatBind bind) of { (fsb, bind_floats, bind') ->
+ -> case (floatBind bind) of { (fsb, bind_floats, binds') ->
case (floatBody bind_lvl body) of { (fse, body_floats, body') ->
( add_stats fsb fse
, bind_floats `plusFloats` body_floats
- , Let bind' body') }}
+ , foldr Let body' binds' ) }}
where
bind_spec = case bind of
NonRec (TB _ s) _ -> s
@@ -350,8 +465,8 @@ floatExpr (Case scrut (TB case_bndr case_spec) ty alts)
= case case_spec of
FloatMe dest_lvl -- Case expression moves
| [(con@(DataAlt {}), bndrs, rhs)] <- alts
- -> case floatExpr scrut of { (fse, fde, scrut') ->
- case floatExpr rhs of { (fsb, fdb, rhs') ->
+ -> case atJoinCeiling $ floatExpr scrut of { (fse, fde, scrut') ->
+ case floatExpr rhs of { (fsb, fdb, rhs') ->
let
float = unitCaseFloat dest_lvl scrut'
case_bndr con [b | TB b _ <- bndrs]
@@ -361,7 +476,7 @@ floatExpr (Case scrut (TB case_bndr case_spec) ty alts)
-> pprPanic "Floating multi-case" (ppr alts)
StayPut bind_lvl -- Case expression stays put
- -> case floatExpr scrut of { (fse, fde, scrut') ->
+ -> case atJoinCeiling $ floatExpr scrut of { (fse, fde, scrut') ->
case floatList (float_alt bind_lvl) alts of { (fsa, fda, alts') ->
(add_stats fse fsa, fda `plusFloats` fde, Case scrut' case_bndr ty alts')
}}
@@ -370,6 +485,25 @@ floatExpr (Case scrut (TB case_bndr case_spec) ty alts)
= case (floatBody bind_lvl rhs) of { (fs, rhs_floats, rhs') ->
(fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }
+floatRhs :: CoreBndr
+ -> LevelledExpr
+ -> (FloatStats, FloatBinds, CoreExpr)
+floatRhs bndr rhs
+ | Just join_arity <- isJoinId_maybe bndr
+ , Just (bndrs, body) <- try_collect join_arity rhs []
+ = case bndrs of
+ [] -> floatExpr rhs
+ (TB _ lam_spec):_ ->
+ let lvl = floatSpecLevel lam_spec in
+ case floatBody lvl body of { (fs, floats, body') ->
+ (fs, floats, mkLams [b | TB b _ <- bndrs] body') }
+ | otherwise
+ = atJoinCeiling $ floatExpr rhs
+ where
+ try_collect 0 expr acc = Just (reverse acc, expr)
+ try_collect n (Lam b e) acc = try_collect (n-1) e (b:acc)
+ try_collect _ _ _ = Nothing
+
{-
Note [Avoiding unnecessary floating]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -439,8 +573,10 @@ add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
= FlS (a1 + a2) (b1 + b2) (c1 + c2)
add_to_stats :: FloatStats -> FloatBinds -> FloatStats
-add_to_stats (FlS a b c) (FB tops others)
- = FlS (a + lengthBag tops) (b + lengthBag (flattenMajor others)) (c + 1)
+add_to_stats (FlS a b c) (FB tops ceils others)
+ = FlS (a + lengthBag tops)
+ (b + lengthBag ceils + lengthBag (flattenMajor others))
+ (c + 1)
{-
************************************************************************
@@ -474,18 +610,21 @@ type MajorEnv = M.IntMap MinorEnv -- Keyed by major level
type MinorEnv = M.IntMap (Bag FloatBind) -- Keyed by minor level
data FloatBinds = FB !(Bag FloatLet) -- Destined for top level
- !MajorEnv -- Levels other than top
+ !(Bag FloatBind) -- Destined for join ceiling
+ !MajorEnv -- Other levels
-- See Note [Representation of FloatBinds]
instance Outputable FloatBinds where
- ppr (FB fbs defs)
+ ppr (FB fbs ceils defs)
= text "FB" <+> (braces $ vcat
[ text "tops =" <+> ppr fbs
+ , text "ceils =" <+> ppr ceils
, text "non-tops =" <+> ppr defs ])
flattenTopFloats :: FloatBinds -> Bag CoreBind
-flattenTopFloats (FB tops defs)
+flattenTopFloats (FB tops ceils defs)
= ASSERT2( isEmptyBag (flattenMajor defs), ppr defs )
+ ASSERT2( isEmptyBag ceils, ppr ceils )
tops
addTopFloatPairs :: Bag CoreBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)]
@@ -502,22 +641,29 @@ flattenMinor :: MinorEnv -> Bag FloatBind
flattenMinor = M.foldr unionBags emptyBag
emptyFloats :: FloatBinds
-emptyFloats = FB emptyBag M.empty
+emptyFloats = FB emptyBag emptyBag M.empty
unitCaseFloat :: Level -> CoreExpr -> Id -> AltCon -> [Var] -> FloatBinds
-unitCaseFloat (Level major minor) e b con bs
- = FB emptyBag (M.singleton major (M.singleton minor (unitBag (FloatCase e b con bs))))
+unitCaseFloat (Level major minor t) e b con bs
+ | t == JoinCeilLvl
+ = FB emptyBag floats M.empty
+ | otherwise
+ = FB emptyBag emptyBag (M.singleton major (M.singleton minor floats))
+ where
+ floats = unitBag (FloatCase e b con bs)
unitLetFloat :: Level -> FloatLet -> FloatBinds
-unitLetFloat lvl@(Level major minor) b
- | isTopLvl lvl = FB (unitBag b) M.empty
- | otherwise = FB emptyBag (M.singleton major (M.singleton minor floats))
+unitLetFloat lvl@(Level major minor t) b
+ | isTopLvl lvl = FB (unitBag b) emptyBag M.empty
+ | t == JoinCeilLvl = FB emptyBag floats M.empty
+ | otherwise = FB emptyBag emptyBag (M.singleton major
+ (M.singleton minor floats))
where
floats = unitBag (FloatLet b)
plusFloats :: FloatBinds -> FloatBinds -> FloatBinds
-plusFloats (FB t1 l1) (FB t2 l2)
- = FB (t1 `unionBags` t2) (l1 `plusMajor` l2)
+plusFloats (FB t1 c1 l1) (FB t2 c2 l2)
+ = FB (t1 `unionBags` t2) (c1 `unionBags` c2) (l1 `plusMajor` l2)
plusMajor :: MajorEnv -> MajorEnv -> MajorEnv
plusMajor = M.unionWith plusMinor
@@ -557,9 +703,10 @@ partitionByMajorLevel (Level major _) (FB tops defns)
Just h -> flattenMinor h
-}
-partitionByLevel (Level major minor) (FB tops defns)
- = (FB tops (outer_maj `plusMajor` M.singleton major outer_min),
- here_min `unionBags` flattenMinor inner_min
+partitionByLevel (Level major minor typ) (FB tops ceils defns)
+ = (FB tops ceils' (outer_maj `plusMajor` M.singleton major outer_min),
+ here_min `unionBags` here_ceil
+ `unionBags` flattenMinor inner_min
`unionBags` flattenMajor inner_maj)
where
@@ -568,10 +715,28 @@ partitionByLevel (Level major minor) (FB tops defns)
Nothing -> (M.empty, Nothing, M.empty)
Just min_defns -> M.splitLookup minor min_defns
here_min = mb_here_min `orElse` emptyBag
+ (here_ceil, ceils') | typ == JoinCeilLvl = (ceils, emptyBag)
+ | otherwise = (emptyBag, ceils)
+
+-- Like partitionByLevel, but instead split out the bindings that are marked
+-- to float to the nearest join ceiling (see Note [Join points])
+partitionAtJoinCeiling :: FloatBinds -> (FloatBinds, Bag FloatBind)
+partitionAtJoinCeiling (FB tops ceils defs)
+ = (FB tops emptyBag defs, ceils)
+
+-- Perform some action at a join ceiling, i.e., don't let join points float out
+-- (see Note [Join points])
+atJoinCeiling :: (FloatStats, FloatBinds, CoreExpr)
+ -> (FloatStats, FloatBinds, CoreExpr)
+atJoinCeiling (fs, floats, expr')
+ = (fs, floats', install ceils expr')
+ where
+ (floats', ceils) = partitionAtJoinCeiling floats
wrapTick :: Tickish Id -> FloatBinds -> FloatBinds
-wrapTick t (FB tops defns)
- = FB (mapBag wrap_bind tops) (M.map (M.map wrap_defns) defns)
+wrapTick t (FB tops ceils defns)
+ = FB (mapBag wrap_bind tops) (wrap_defns ceils)
+ (M.map (M.map wrap_defns) defns)
where
wrap_defns = mapBag wrap_one
diff --git a/compiler/simplCore/LiberateCase.hs b/compiler/simplCore/LiberateCase.hs
index 1df1405329..1776db51fd 100644
--- a/compiler/simplCore/LiberateCase.hs
+++ b/compiler/simplCore/LiberateCase.hs
@@ -197,10 +197,13 @@ libCase :: LibCaseEnv
-> CoreExpr
-> CoreExpr
-libCase env (Var v) = libCaseId env v
+libCase env (Var v) = libCaseApp env v []
libCase _ (Lit lit) = Lit lit
libCase _ (Type ty) = Type ty
libCase _ (Coercion co) = Coercion co
+libCase env e@(App {}) | let (fun, args) = collectArgs e
+ , Var v <- fun
+ = libCaseApp env v args
libCase env (App fun arg) = App (libCase env fun) (libCase env arg)
libCase env (Tick tickish body) = Tick tickish (libCase env body)
libCase env (Cast e co) = Cast (libCase env e) co
@@ -228,20 +231,31 @@ libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
{-
Ids
~~~
+
+To unfold, we can't just wrap the id itself in its binding if it's a join point:
+
+ jump j a b c => (joinrec j x y z = ... in jump j) a b c -- wrong!!!
+
+Every jump must provide all arguments, so we have to be careful to wrap the
+whole jump instead:
+
+ jump j a b c => joinrec j x y z = ... in jump j a b c -- right
+
-}
-libCaseId :: LibCaseEnv -> Id -> CoreExpr
-libCaseId env v
+libCaseApp :: LibCaseEnv -> Id -> [CoreExpr] -> CoreExpr
+libCaseApp env v args
| Just the_bind <- lookupRecId env v -- It's a use of a recursive thing
, notNull free_scruts -- with free vars scrutinised in RHS
- = Let the_bind (Var v)
+ = Let the_bind expr'
| otherwise
- = Var v
+ = expr'
where
rec_id_level = lookupLevel env v
free_scruts = freeScruts env rec_id_level
+ expr' = mkApps (Var v) (map (libCase env) args)
freeScruts :: LibCaseEnv
-> LibCaseLevel -- Level of the recursive Id
diff --git a/compiler/simplCore/OccurAnal.hs b/compiler/simplCore/OccurAnal.hs
index a50fe223f1..864d468a35 100644
--- a/compiler/simplCore/OccurAnal.hs
+++ b/compiler/simplCore/OccurAnal.hs
@@ -11,7 +11,7 @@ The occurrence analyser re-typechecks a core expression, returning a new
core expression with (hopefully) improved usage information.
-}
-{-# LANGUAGE CPP, BangPatterns #-}
+{-# LANGUAGE CPP, BangPatterns, MultiWayIf #-}
module OccurAnal (
occurAnalysePgm, occurAnalyseExpr, occurAnalyseExpr_NoBinderSwap
@@ -24,16 +24,17 @@ import CoreFVs
import CoreUtils ( exprIsTrivial, isDefaultAlt, isExpandableApp,
stripTicksTopE, mkTicks )
import Id
+import IdInfo
import Name( localiseName )
import BasicTypes
import Module( Module )
import Coercion
+import Type
import VarSet
import VarEnv
import Var
import Demand ( argOneShots, argsOneShots )
-import Maybes ( orElse )
import Digraph ( SCC(..), Node
, stronglyConnCompFromEdgedVerticesUniq
, stronglyConnCompFromEdgedVerticesUniqR )
@@ -59,7 +60,7 @@ occurAnalysePgm :: Module -- Used only in debug output
-> [CoreRule] -> [CoreVect] -> VarSet
-> CoreProgram -> CoreProgram
occurAnalysePgm this_mod active_rule imp_rules vects vectVars binds
- | isEmptyVarEnv final_usage
+ | isEmptyDetails final_usage
= occ_anald_binds
| otherwise -- See Note [Glomming]
@@ -69,14 +70,15 @@ occurAnalysePgm this_mod active_rule imp_rules vects vectVars binds
where
init_env = initOccEnv active_rule
(final_usage, occ_anald_binds) = go init_env binds
- (_, occ_anald_glommed_binds) = occAnalRecBind init_env imp_rule_edges
+ (_, occ_anald_glommed_binds) = occAnalRecBind init_env TopLevel
+ imp_rule_edges
(flattenBinds occ_anald_binds)
initial_uds
-- It's crucial to re-analyse the glommed-together bindings
-- so that we establish the right loop breakers. Otherwise
-- we can easily create an infinite loop (Trac #9583 is an example)
- initial_uds = addIdOccs emptyDetails
+ initial_uds = addManyOccsSet emptyDetails
(rulesFreeVars imp_rules `unionVarSet`
vectsFreeVars vects `unionVarSet`
vectVars)
@@ -100,7 +102,8 @@ occurAnalysePgm this_mod active_rule imp_rules vects vectVars binds
= (final_usage, bind' ++ binds')
where
(bs_usage, binds') = go env binds
- (final_usage, bind') = occAnalBind env imp_rule_edges bind bs_usage
+ (final_usage, bind') = occAnalBind env TopLevel imp_rule_edges bind
+ bs_usage
occurAnalyseExpr :: CoreExpr -> CoreExpr
-- Do occurrence analysis, and discard occurrence info returned
@@ -640,6 +643,133 @@ But watch out! If 'fs' is not chosen as a loop breaker, we may get an infinite
- now there's another opportunity to apply the RULE
This showed up when compiling Control.Concurrent.Chan.getChanContents.
+
+------------------------------------------------------------
+Note [Finding join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+It's the occurrence analyser's job to find bindings that we can turn into join
+points, but it doesn't perform that transformation right away. Rather, it marks
+the eligible bindings as part of their occurrence data, leaving it to the
+simplifier (or to simpleOptPgm) to actually change the binder's 'IdDetails'.
+The simplifier then eta-expands the RHS if needed and then updates the
+occurrence sites. Dividing the work this way means that the occurrence analyser
+still only takes one pass, yet one can always tell the difference between a
+function call and a jump by looking at the occurrence (because the same pass
+changes the 'IdDetails' and propagates the binders to their occurrence sites).
+
+To track potential join points, we use the 'occ_tail' field of OccInfo. A value
+of `AlwaysTailCalled n` indicates that every occurrence of the variable is a
+tail call with `n` arguments (counting both value and type arguments). Otherwise
+'occ_tail' will be 'NoTailCallInfo'. The tail call info flows bottom-up with the
+rest of 'OccInfo' until it goes on the binder.
+
+Note [Rules and join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Things get fiddly with rules. Suppose we have:
+
+ let j :: Int -> Int
+ j y = 2 * y
+ k :: Int -> Int -> Int
+ {-# RULES "SPEC k 0" k 0 = j #-}
+ k x y = x + 2 * y
+ in ...
+
+Now suppose that both j and k appear only as saturated tail calls in the body.
+Thus we would like to make them both join points. The rule complicates matters,
+though, as its RHS has an unapplied occurrence of j. *However*, if we were to
+eta-expand the rule, all would be well:
+
+ {-# RULES "SPEC k 0" forall a. k 0 a = j a #-}
+
+So conceivably we could notice that a potential join point would have an
+"undersaturated" rule and account for it. This would mean we could make
+something that's been specialised a join point, for instance. But local bindings
+are rarely specialised, and being overly cautious about rules only
+costs us anything when, for some `j`:
+
+ * Before specialisation, `j` has non-tail calls, so it can't be a join point.
+ * During specialisation, `j` gets specialised and thus acquires rules.
+ * Sometime afterward, the non-tail calls to `j` disappear (as dead code, say),
+ and so now `j` *could* become a join point.
+
+This appears to be very rare in practice. TODO Perhaps we should gather
+statistics to be sure.
+
+Note [Excess polymorphism and join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In principle, if a function would be a join point except that it fails
+the polymorphism rule (see Note [The polymorphism rule of join points] in
+CoreSyn), it can still be made a join point with some effort. This is because
+all tail calls must return the same type (they return to the same context!), and
+thus if the return type depends on an argument, that argument must always be the
+same.
+
+For instance, consider:
+
+ let f :: forall a. a -> Char -> [a]
+ f @a x c = ... f @a x 'a' ...
+ in ... f @Int 1 'b' ... f @Int 2 'c' ...
+
+(where the calls are tail calls). `f` fails the polymorphism rule because its
+return type is [a], where [a] is bound. But since the type argument is always
+'Int', we can rewrite it as:
+
+ let f' :: Int -> Char -> [Int]
+ f' x c = ... f' x 'a' ...
+ in ... f' 1 'b' ... f 2 'c' ...
+
+and now we can make f' a join point:
+
+ join f' :: Int -> Char -> [Int]
+ f' x c = ... jump f' x 'a' ...
+ in ... jump f' 1 'b' ... jump f' 2 'c' ...
+
+It's not clear that this comes up often, however. TODO: Measure how often and
+add this analysis if necessary.
+
+------------------------------------------------------------
+Note [Adjusting for lambdas]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+There's a bit of a dance we need to do after analysing a lambda expression or
+a right-hand side. In particular, we need to
+
+ a) call 'markAllInsideLam' *unless* the binding is for a thunk, a one-shot
+ lambda, or a non-recursive join point; and
+ b) call 'markAllNonTailCalled' *unless* the binding is for a join point.
+
+Some examples, with how the free occurrences in e (assumed not to be a value
+lambda) get marked:
+
+ inside lam non-tail-called
+ ------------------------------------------------------------
+ let x = e No Yes
+ let f = \x -> e Yes Yes
+ let f = \x{OneShot} -> e No Yes
+ \x -> e Yes Yes
+ join j x = e No No
+ joinrec j x = e Yes No
+
+There are a few other caveats; most importantly, if we're marking a binding as
+'AlwaysTailCalled', it's *going* to be a join point, so we treat it as one so
+that the effect cascades properly. Consequently, at the time the RHS is
+analysed, we won't know what adjustments to make; thus 'occAnalLamOrRhs' must
+return the unadjusted 'UsageDetails', to be adjusted by 'adjustRhsUsage' once
+join-point-hood has been decided.
+
+Thus the overall sequence taking place in 'occAnalNonRecBind' and
+'occAnalRecBind' is as follows:
+
+ 1. Call 'occAnalLamOrRhs' to find usage information for the RHS.
+ 2. Call 'tagNonRecBinder' or 'tagRecBinders', which decides whether to make
+ the binding a join point.
+ 3. Call 'adjustRhsUsage' accordingly. (Done as part of 'tagRecBinders' when
+ recursive.)
+
+(In the recursive case, this logic is spread between 'makeNode' and
+'occAnalRec'.)
-}
------------------------------------------------------------------
@@ -647,21 +777,22 @@ This showed up when compiling Control.Concurrent.Chan.getChanContents.
------------------------------------------------------------------
occAnalBind :: OccEnv -- The incoming OccEnv
+ -> TopLevelFlag
-> ImpRuleEdges
-> CoreBind
-> UsageDetails -- Usage details of scope
-> (UsageDetails, -- Of the whole let(rec)
[CoreBind])
-occAnalBind env top_env (NonRec binder rhs) body_usage
- = occAnalNonRecBind env top_env binder rhs body_usage
-occAnalBind env top_env (Rec pairs) body_usage
- = occAnalRecBind env top_env pairs body_usage
+occAnalBind env lvl top_env (NonRec binder rhs) body_usage
+ = occAnalNonRecBind env lvl top_env binder rhs body_usage
+occAnalBind env lvl top_env (Rec pairs) body_usage
+ = occAnalRecBind env lvl top_env pairs body_usage
-----------------
-occAnalNonRecBind :: OccEnv -> ImpRuleEdges -> Var -> CoreExpr
+occAnalNonRecBind :: OccEnv -> TopLevelFlag -> ImpRuleEdges -> Var -> CoreExpr
-> UsageDetails -> (UsageDetails, [CoreBind])
-occAnalNonRecBind env imp_rule_edges binder rhs body_usage
+occAnalNonRecBind env lvl imp_rule_edges binder rhs body_usage
| isTyVar binder -- A type let; we don't gather usage info
= (body_usage, [NonRec binder rhs])
@@ -669,24 +800,36 @@ occAnalNonRecBind env imp_rule_edges binder rhs body_usage
= (body_usage, [])
| otherwise -- It's mentioned in the body
- = (body_usage' +++ rhs_usage4, [NonRec tagged_binder rhs'])
+ = (body_usage' +++ rhs_usage', [NonRec tagged_binder rhs'])
where
- (body_usage', tagged_binder) = tagBinder body_usage binder
- (rhs_usage1, rhs') = occAnalNonRecRhs env tagged_binder rhs
- rhs_usage2 = addIdOccs rhs_usage1 (idUnfoldingVars binder)
-
- rhs_usage3 = addIdOccs rhs_usage2 (idRuleVars binder)
+ (bndrs, body) = collectBinders rhs
+ (body_usage', tagged_binder) = tagNonRecBinder lvl body_usage binder
+ (rhs_usage1, bndrs', body') = occAnalNonRecRhs env tagged_binder bndrs body
+ rhs' = mkLams bndrs' body'
+ rhs_usage2 = case occAnalUnfolding env NonRecursive binder of
+ Just unf_usage -> rhs_usage1 +++ unf_usage
+ Nothing -> rhs_usage1
+ -- See Note [Unfoldings and join points]
+
+ mb_join_arity = willBeJoinId_maybe tagged_binder
+ rules_w_uds = occAnalRules env mb_join_arity NonRecursive tagged_binder
+
+ rhs_usage3 = rhs_usage2 +++ combineUsageDetailsList
+ (map (\(_, l, r) -> l +++ r) rules_w_uds)
-- See Note [Rules are extra RHSs] and Note [Rule dependency info]
- rhs_usage4 = maybe rhs_usage3 (addIdOccs rhs_usage3) $
+ rhs_usage4 = maybe rhs_usage3 (addManyOccsSet rhs_usage3) $
lookupVarEnv imp_rule_edges binder
-- See Note [Preventing loops due to imported functions rules]
+ rhs_usage' = adjustRhsUsage (willBeJoinId_maybe tagged_binder) NonRecursive
+ bndrs' rhs_usage4
+
-----------------
-occAnalRecBind :: OccEnv -> ImpRuleEdges -> [(Var,CoreExpr)]
+occAnalRecBind :: OccEnv -> TopLevelFlag -> ImpRuleEdges -> [(Var,CoreExpr)]
-> UsageDetails -> (UsageDetails, [CoreBind])
-occAnalRecBind env imp_rule_edges pairs body_usage
- = foldr occAnalRec (body_usage, []) sccs
+occAnalRecBind env lvl imp_rule_edges pairs body_usage
+ = foldr (occAnalRec lvl) (body_usage, []) sccs
-- For a recursive group, we
-- * occ-analyse all the RHSs
-- * compute strongly-connected components
@@ -703,27 +846,40 @@ occAnalRecBind env imp_rule_edges pairs body_usage
bndr_set = mkVarSet (map fst pairs)
+{-
+Note [Unfoldings and join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+We assume that anything in an unfolding occurs multiple times, since unfoldings
+are often copied (that's the whole point!). But we still need to track tail
+calls for the purpose of finding join points.
+-}
+
-----------------------------
-occAnalRec :: SCC Details
+occAnalRec :: TopLevelFlag
+ -> SCC Details
-> (UsageDetails, [CoreBind])
-> (UsageDetails, [CoreBind])
-- The NonRec case is just like a Let (NonRec ...) above
-occAnalRec (AcyclicSCC (ND { nd_bndr = bndr, nd_rhs = rhs, nd_uds = rhs_uds}))
+occAnalRec lvl (AcyclicSCC (ND { nd_bndr = bndr, nd_rhs = rhs
+ , nd_uds = rhs_uds, nd_rhs_bndrs = rhs_bndrs }))
(body_uds, binds)
| not (bndr `usedIn` body_uds)
= (body_uds, binds) -- See Note [Dead code]
| otherwise -- It's mentioned in the body
- = (body_uds' +++ rhs_uds,
+ = (body_uds' +++ rhs_uds',
NonRec tagged_bndr rhs : binds)
where
- (body_uds', tagged_bndr) = tagBinder body_uds bndr
+ (body_uds', tagged_bndr) = tagNonRecBinder lvl body_uds bndr
+ rhs_uds' = adjustRhsUsage (willBeJoinId_maybe tagged_bndr) NonRecursive
+ rhs_bndrs rhs_uds
-- The Rec case is the interesting one
-- See Note [Recursive bindings: the grand plan]
-- See Note [Loop breaking]
-occAnalRec (CyclicSCC details_s) (body_uds, binds)
+occAnalRec lvl (CyclicSCC details_s) (body_uds, binds)
| not (any (`usedIn` body_uds) bndrs) -- NB: look at body_uds, not total_uds
= (body_uds, binds) -- See Note [Dead code]
@@ -738,16 +894,12 @@ occAnalRec (CyclicSCC details_s) (body_uds, binds)
bndrs = map nd_bndr details_s
bndr_set = mkVarSet bndrs
- ----------------------------
- -- Compute usage details
- total_uds = foldl add_uds body_uds details_s
- final_uds = total_uds `minusVarEnv` bndr_set
- add_uds usage_so_far nd = usage_so_far +++ nd_uds nd
-
------------------------------
-- See Note [Choosing loop breakers] for loop_breaker_nodes
+ final_uds :: UsageDetails
loop_breaker_nodes :: [LetrecNode]
- loop_breaker_nodes = mkLoopBreakerNodes bndr_set total_uds details_s
+ (final_uds, loop_breaker_nodes)
+ = mkLoopBreakerNodes lvl bndr_set body_uds details_s
------------------------------
weak_fvs :: VarSet
@@ -832,13 +984,18 @@ reOrderNodes depth bndr_set weak_fvs (node : nodes) binds
mk_loop_breaker :: LetrecNode -> Binding
mk_loop_breaker (ND { nd_bndr = bndr, nd_rhs = rhs}, _, _)
- = (setIdOccInfo bndr strongLoopBreaker, rhs)
+ = (bndr `setIdOccInfo` strongLoopBreaker { occ_tail = tail_info }, rhs)
+ where
+ tail_info = tailCallInfo (idOccInfo bndr)
mk_non_loop_breaker :: VarSet -> LetrecNode -> Binding
-- See Note [Weak loop breakers]
mk_non_loop_breaker weak_fvs (ND { nd_bndr = bndr, nd_rhs = rhs}, _, _)
- | bndr `elemVarSet` weak_fvs = (setIdOccInfo bndr weakLoopBreaker, rhs)
+ | bndr `elemVarSet` weak_fvs = (setIdOccInfo bndr occ', rhs)
| otherwise = (bndr, rhs)
+ where
+ occ' = weakLoopBreaker { occ_tail = tail_info }
+ tail_info = tailCallInfo (idOccInfo bndr)
----------------------------------
chooseLoopBreaker :: Bool -- True <=> Too many iterations,
@@ -982,7 +1139,7 @@ we choose 'plus1' as the loop breaker (which is entirely possible
otherwise), the loop does not unravel nicely.
-@occAnalRhs@ deals with the question of bindings where the Id is marked
+@occAnalUnfolding@ deals with the question of bindings where the Id is marked
by an INLINE pragma. For these we record that anything which occurs
in its RHS occurs many times. This pessimistically assumes that ths
inlined binder also occurs many times in its scope, but if it doesn't
@@ -1010,6 +1167,9 @@ type LetrecNode = Node Unique Details -- Node comes from Digraph
data Details
= ND { nd_bndr :: Id -- Binder
, nd_rhs :: CoreExpr -- RHS, already occ-analysed
+ , nd_rhs_bndrs :: [CoreBndr] -- Outer lambdas of RHS
+ -- INVARIANT: (nd_rhs_bndrs nd, _) ==
+ -- collectBinders (nd_rhs nd)
, nd_uds :: UsageDetails -- Usage from RHS, and RULES, and stable unfoldings
-- ignoring phase (ie assuming all are active)
@@ -1064,6 +1224,7 @@ makeNode env imp_rule_edges bndr_set (bndr, rhs)
where
details = ND { nd_bndr = bndr
, nd_rhs = rhs'
+ , nd_rhs_bndrs = bndrs'
, nd_uds = rhs_usage3
, nd_inl = inl_fvs
, nd_weak = node_fvs `minusVarSet` inl_fvs
@@ -1072,54 +1233,66 @@ makeNode env imp_rule_edges bndr_set (bndr, rhs)
-- Constructing the edges for the main Rec computation
-- See Note [Forming Rec groups]
- (rhs_usage1, rhs') = occAnalRecRhs env rhs
- rhs_usage2 = addIdOccs rhs_usage1 all_rule_fvs -- Note [Rules are extra RHSs]
- -- Note [Rule dependency info]
- rhs_usage3 = case mb_unf_fvs of
- Just unf_fvs -> addIdOccs rhs_usage2 unf_fvs
+ (bndrs, body) = collectBinders rhs
+ (rhs_usage1, bndrs', body') = occAnalRecRhs env bndrs body
+ rhs' = mkLams bndrs' body'
+ rhs_usage2 = rhs_usage1 +++ all_rule_uds
+ -- Note [Rules are extra RHSs]
+ -- Note [Rule dependency info]
+ rhs_usage3 = case mb_unf_uds of
+ Just unf_uds -> rhs_usage2 +++ unf_uds
Nothing -> rhs_usage2
node_fvs = udFreeVars bndr_set rhs_usage3
-- Finding the free variables of the rules
is_active = occ_rule_act env :: Activation -> Bool
- rules = filterOut isBuiltinRule (idCoreRules bndr)
- rules_w_fvs :: [(Activation, VarSet)] -- Find the RHS fvs
- rules_w_fvs = maybe id (\ids -> ((AlwaysActive, ids):)) (lookupVarEnv imp_rule_edges bndr)
- -- See Note [Preventing loops due to imported functions rules]
- [ (ru_act rule, fvs)
- | rule <- rules
- , let fvs = exprFreeVars (ru_rhs rule)
- `delVarSetList` ru_bndrs rule
- , not (isEmptyVarSet fvs) ]
- all_rule_fvs = rule_lhs_fvs `unionVarSet` rule_rhs_fvs
- rule_rhs_fvs = mapUnionVarSet snd rules_w_fvs
- rule_lhs_fvs = mapUnionVarSet (\ru -> exprsFreeVars (ru_args ru)
- `delVarSetList` ru_bndrs ru) rules
- active_rule_fvs = unionVarSets [fvs | (a,fvs) <- rules_w_fvs, is_active a]
-
- -- Finding the free variables of the INLINE pragma (if any)
- unf = realIdUnfolding bndr -- Ignore any current loop-breaker flag
- mb_unf_fvs = stableUnfoldingVars unf
+
+ rules_w_uds :: [(CoreRule, UsageDetails, UsageDetails)]
+ rules_w_uds = occAnalRules env (Just (length bndrs)) Recursive bndr
+
+ rules_w_rhs_fvs :: [(Activation, VarSet)] -- Find the RHS fvs
+ rules_w_rhs_fvs = maybe id (\ids -> ((AlwaysActive, ids):))
+ (lookupVarEnv imp_rule_edges bndr)
+ -- See Note [Preventing loops due to imported functions rules]
+ [ (ru_act rule, udFreeVars bndr_set rhs_uds)
+ | (rule, _, rhs_uds) <- rules_w_uds ]
+ all_rule_uds = combineUsageDetailsList $
+ concatMap (\(_, l, r) -> [l, r]) rules_w_uds
+ active_rule_fvs = unionVarSets [fvs | (a,fvs) <- rules_w_rhs_fvs
+ , is_active a]
+
+ -- Finding the usage details of the INLINE pragma (if any)
+ mb_unf_uds = occAnalUnfolding env Recursive bndr
-- Find the "nd_inl" free vars; for the loop-breaker phase
- inl_fvs = case mb_unf_fvs of
+ inl_fvs = case mb_unf_uds of
Nothing -> udFreeVars bndr_set rhs_usage1 -- No INLINE, use RHS
- Just unf_fvs -> unf_fvs
+ Just unf_uds -> udFreeVars bndr_set unf_uds
-- We could check for an *active* INLINE (returning
-- emptyVarSet for an inactive one), but is_active
-- isn't the right thing (it tells about
-- RULE activation), so we'd need more plumbing
-mkLoopBreakerNodes :: VarSet -> UsageDetails -> [Details] -> [LetrecNode]
--- Does three things
+mkLoopBreakerNodes :: TopLevelFlag
+ -> VarSet
+ -> UsageDetails -- for BODY of let
+ -> [Details]
+ -> (UsageDetails, -- adjusted
+ [LetrecNode])
+-- Does four things
-- a) tag each binder with its occurrence info
-- b) add a NodeScore to each node
-- c) make a Node with the right dependency edges for
-- the loop-breaker SCC analysis
-mkLoopBreakerNodes bndr_set total_uds details_s
- = map mk_lb_node details_s
+-- d) adjust each RHS's usage details according to
+-- the binder's (new) shotness and join-point-hood
+mkLoopBreakerNodes lvl bndr_set body_uds details_s
+ = (final_uds, zipWith mk_lb_node details_s bndrs')
where
- mk_lb_node nd@(ND { nd_bndr = bndr, nd_rhs = rhs, nd_inl = inl_fvs })
+ (final_uds, bndrs') = tagRecBinders lvl body_uds
+ [ (nd_bndr nd, nd_uds nd, nd_rhs_bndrs nd)
+ | nd <- details_s ]
+ mk_lb_node nd@(ND { nd_bndr = bndr, nd_rhs = rhs, nd_inl = inl_fvs }) bndr'
= (nd', varUnique bndr, nonDetKeysUFM lb_deps)
-- It's OK to use nonDetKeysUFM here as
-- stronglyConnCompFromEdgedVerticesR is still deterministic with edges
@@ -1127,7 +1300,6 @@ mkLoopBreakerNodes bndr_set total_uds details_s
-- Note [Deterministic SCC] in Digraph.
where
nd' = nd { nd_bndr = bndr', nd_score = score }
- bndr' = setBinderOcc total_uds bndr
score = nodeScore bndr bndr' rhs lb_deps
lb_deps = extendFvs_ rule_fv_env inl_fvs
@@ -1156,59 +1328,57 @@ nodeScore old_bndr new_bndr bind_rhs lb_deps
| old_bndr `elemVarSet` lb_deps -- Self-recursive things are great loop breakers
= (0, 0, True) -- See Note [Self-recursion and loop breakers]
- | otherwise -- An Id has an unfolding
- = case id_unfolding of
- DFunUnfolding { df_args = args }
- -- Never choose a DFun as a loop breaker
- -- Note [DFuns should not be loop breakers]
- -> (9, length args, is_lb)
-
- CoreUnfolding { uf_src = src, uf_tmpl = unf_rhs, uf_guidance = guide }
- | isStableSource src
- -> case guide of
- UnfWhen {} -> (6, cheapExprSize unf_rhs, is_lb)
- UnfIfGoodArgs { ug_size = size} -> (3, size, is_lb)
- UnfNever -> (0, 0, is_lb)
- -- See Note [Loop breakers and INLINE/INLINABLE pragmas] for
- -- the 6 vs 3 choice
-
- -- Note that this case hits /all/ stable unfoldings, so we
- -- never look at 'bind_rhs' for stable unfoldings. That's right, because
- -- 'rhs' is irrelevant for inlining things with a stable unfolding
-
- -- Data structures are more important than INLINE pragmas
- -- so that dictionary/method recursion unravels
-
- _ | exprIsTrivial bind_rhs
- -> mk_score 10 -- Practically certain to be inlined
- -- Used to have also: && not (isExportedId bndr)
- -- But I found this sometimes cost an extra iteration when we have
- -- rec { d = (a,b); a = ...df...; b = ...df...; df = d }
- -- where df is the exported dictionary. Then df makes a really
- -- bad choice for loop breaker
-
- | is_con_app bind_rhs -- Data types help with cases: Note [Constructor applications]
- -> mk_score 5
-
- | isOneOcc (idOccInfo new_bndr)
- -> mk_score 2 -- Likely to be inlined
-
- | canUnfold id_unfolding -- The Id has some kind of unfolding
- -> mk_score 1
+ | exprIsTrivial rhs
+ = mk_score 10 -- Practically certain to be inlined
+ -- Used to have also: && not (isExportedId bndr)
+ -- But I found this sometimes cost an extra iteration when we have
+ -- rec { d = (a,b); a = ...df...; b = ...df...; df = d }
+ -- where df is the exported dictionary. Then df makes a really
+ -- bad choice for loop breaker
- | otherwise
- -> (0, 0, is_lb)
+ | DFunUnfolding { df_args = args } <- id_unfolding
+ -- Never choose a DFun as a loop breaker
+ -- Note [DFuns should not be loop breakers]
+ = (9, length args, is_lb)
+
+ -- Data structures are more important than INLINE pragmas
+ -- so that dictionary/method recursion unravels
+
+ | CoreUnfolding { uf_guidance = UnfWhen {} } <- id_unfolding
+ = mk_score 6
+
+ | is_con_app rhs -- Data types help with cases:
+ = mk_score 5 -- Note [Constructor applications]
+
+ | isStableUnfolding id_unfolding
+ , canUnfold id_unfolding
+ = mk_score 3
+
+ | isOneOcc (idOccInfo new_bndr)
+ = mk_score 2 -- Likely to be inlined
+
+ | canUnfold id_unfolding -- The Id has some kind of unfolding
+ = mk_score 1
+
+ | otherwise
+ = (0, 0, is_lb)
where
mk_score :: Int -> NodeScore
mk_score rank = (rank, rhs_size, is_lb)
is_lb = isStrongLoopBreaker (idOccInfo old_bndr)
+ rhs = case id_unfolding of
+ CoreUnfolding { uf_src = src, uf_tmpl = unf_rhs }
+ | isStableSource src
+ -> unf_rhs
+ _ -> bind_rhs
+ -- 'bind_rhs' is irrelevant for inlining things with a stable unfolding
rhs_size = case id_unfolding of
CoreUnfolding { uf_guidance = guidance }
| UnfIfGoodArgs { ug_size = size } <- guidance
-> size
- _ -> cheapExprSize bind_rhs
+ _ -> cheapExprSize rhs
id_unfolding = realIdUnfolding old_bndr
-- realIdUnfolding: Ignore loop-breaker-ness here because
@@ -1349,20 +1519,29 @@ Hence the is_lb field of NodeScore
************************************************************************
-}
-occAnalRecRhs :: OccEnv -> CoreExpr -- Rhs
- -> (UsageDetails, CoreExpr)
+occAnalRhs :: OccEnv -> RecFlag -> Id -> [CoreBndr] -> CoreExpr
+ -> (UsageDetails, [CoreBndr], CoreExpr)
-- Returned usage details covers only the RHS,
-- and *not* the RULE or INLINE template for the Id
-occAnalRecRhs env rhs = occAnal (rhsCtxt env) rhs
+occAnalRhs env Recursive _ bndrs body
+ = occAnalRecRhs env bndrs body
+occAnalRhs env NonRecursive id bndrs body
+ = occAnalNonRecRhs env id bndrs body
+
+occAnalRecRhs :: OccEnv -> [CoreBndr] -> CoreExpr -- Rhs lambdas, body
+ -> (UsageDetails, [CoreBndr], CoreExpr)
+ -- Returned usage details covers only the RHS,
+ -- and *not* the RULE or INLINE template for the Id
+occAnalRecRhs env bndrs body = occAnalLamOrRhs (rhsCtxt env) bndrs body
occAnalNonRecRhs :: OccEnv
- -> Id -> CoreExpr -- Binder and rhs
+ -> Id -> [CoreBndr] -> CoreExpr -- Binder; rhs lams, body
-- Binder is already tagged with occurrence info
- -> (UsageDetails, CoreExpr)
+ -> (UsageDetails, [CoreBndr], CoreExpr)
-- Returned usage details covers only the RHS,
-- and *not* the RULE or INLINE template for the Id
-occAnalNonRecRhs env bndr rhs
- = occAnal rhs_env rhs
+occAnalNonRecRhs env bndr bndrs body
+ = occAnalLamOrRhs rhs_env bndrs body
where
-- See Note [Cascading inlines]
env1 | certainly_inline = env
@@ -1374,13 +1553,70 @@ occAnalNonRecRhs env bndr rhs
certainly_inline -- See Note [Cascading inlines]
= case idOccInfo bndr of
- OneOcc in_lam one_br _ -> not in_lam && one_br && active && not_stable
+ OneOcc { occ_in_lam = in_lam, occ_one_br = one_br }
+ -> not in_lam && one_br && active && not_stable
_ -> False
dmd = idDemandInfo bndr
active = isAlwaysActive (idInlineActivation bndr)
not_stable = not (isStableUnfolding (idUnfolding bndr))
+occAnalUnfolding :: OccEnv
+ -> RecFlag
+ -> Id
+ -> Maybe UsageDetails
+ -- Just the analysis, not a new unfolding. The unfolding
+ -- got analysed when it was created and we don't need to
+ -- update it.
+occAnalUnfolding env rec_flag id
+ = case realIdUnfolding id of -- ignore previous loop-breaker flag
+ CoreUnfolding { uf_tmpl = rhs, uf_src = src }
+ | not (isStableSource src)
+ -> Nothing
+ | otherwise
+ -> Just $ zapDetails usage
+ where
+ (bndrs, body) = collectBinders rhs
+ (usage, _, _) = occAnalRhs env rec_flag id bndrs body
+
+ DFunUnfolding { df_bndrs = bndrs, df_args = args }
+ -> Just $ zapDetails (delDetailsList usage bndrs)
+ where
+ usage = foldr (+++) emptyDetails (map (fst . occAnal env) args)
+
+ _ -> Nothing
+
+occAnalRules :: OccEnv
+ -> Maybe JoinArity -- If the binder is (or MAY become) a join
+ -- point, what its join arity is (or WOULD
+ -- become). See Note [Rules and join points].
+ -> RecFlag
+ -> Id
+ -> [(CoreRule, -- Each (non-built-in) rule
+ UsageDetails, -- Usage details for LHS
+ UsageDetails)] -- Usage details for RHS
+occAnalRules env mb_expected_join_arity rec_flag id
+ = [ (rule, lhs_uds, rhs_uds) | rule@Rule {} <- idCoreRules id
+ , let (lhs_uds, rhs_uds) = occ_anal_rule rule ]
+ where
+ occ_anal_rule (Rule { ru_bndrs = bndrs, ru_args = args, ru_rhs = rhs })
+ = (lhs_uds, final_rhs_uds)
+ where
+ lhs_uds = addManyOccsSet emptyDetails $
+ (exprsFreeVars args `delVarSetList` bndrs)
+ (rhs_bndrs, rhs_body) = collectBinders rhs
+ (rhs_uds, _, _) = occAnalRhs env rec_flag id rhs_bndrs rhs_body
+ -- Note [Rules are extra RHSs]
+ -- Note [Rule dependency info]
+ final_rhs_uds = adjust_tail_info bndrs $ markAllMany $
+ (rhs_uds `delDetailsList` bndrs)
+ occ_anal_rule _
+ = (emptyDetails, emptyDetails)
+
+ adjust_tail_info bndrs uds -- see Note [Rules and join points]
+ = case mb_expected_join_arity of
+ Just ar | bndrs `lengthIs` ar -> uds
+ _ -> markAllNonTailCalled uds
{-
Note [Cascading inlines]
~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1437,8 +1673,8 @@ occAnal :: OccEnv
occAnal _ expr@(Type _) = (emptyDetails, expr)
occAnal _ expr@(Lit _) = (emptyDetails, expr)
-occAnal env expr@(Var v) = (mkOneOcc env v False, expr)
- -- At one stage, I gathered the idRuleVars for v here too,
+occAnal env expr@(Var _) = occAnalApp env (expr, [], [])
+ -- At one stage, I gathered the idRuleVars for the variable here too,
-- which in a way is the right thing to do.
-- But that went wrong right after specialisation, when
-- the *occurrences* of the overloaded function didn't have any
@@ -1446,7 +1682,7 @@ occAnal env expr@(Var v) = (mkOneOcc env v False, expr)
-- weren't used at all.
occAnal _ (Coercion co)
- = (addIdOccs emptyDetails (coVarsOfCo co), Coercion co)
+ = (addManyOccsSet emptyDetails (coVarsOfCo co), Coercion co)
-- See Note [Gather occurrences of coercion variables]
{-
@@ -1458,10 +1694,10 @@ we can sort them into the right place when doing dependency analysis.
occAnal env (Tick tickish body)
| tickish `tickishScopesLike` SoftScope
- = (usage, Tick tickish body')
+ = (markAllNonTailCalled usage, Tick tickish body')
| Breakpoint _ ids <- tickish
- = (usage_lam +++ mkVarEnv (zip ids (repeat NoOccInfo)), Tick tickish body')
+ = (usage_lam +++ foldr addManyOccs emptyDetails ids, Tick tickish body')
-- never substitute for any of the Ids in a Breakpoint
| otherwise
@@ -1469,14 +1705,20 @@ occAnal env (Tick tickish body)
where
!(usage,body') = occAnal env body
-- for a non-soft tick scope, we can inline lambdas only
- usage_lam = mapVarEnv markInsideLam usage
+ usage_lam = markAllNonTailCalled (markAllInsideLam usage)
+ -- TODO There may be ways to make ticks and join points play
+ -- nicer together, but right now there are problems:
+ -- let j x = ... in tick<t> (j 1)
+ -- Making j a join point may cause the simplifier to drop t
+ -- (if the tick is put into the continuation). So we don't
+ -- count j 1 as a tail call.
occAnal env (Cast expr co)
= case occAnal env expr of { (usage, expr') ->
- let usage1 = markManyIf (isRhsEnv env) usage
- usage2 = addIdOccs usage1 (coVarsOfCo co)
+ let usage1 = zapDetailsIf (isRhsEnv env) usage
+ usage2 = addManyOccsSet usage1 (coVarsOfCo co)
-- See Note [Gather occurrences of coercion variables]
- in (usage2, Cast expr' co)
+ in (markAllNonTailCalled usage2, Cast expr' co)
-- If we see let x = y `cast` co
-- then mark y as 'Many' so that we don't
-- immediately inline y again.
@@ -1491,7 +1733,7 @@ occAnal env app@(App _ _)
occAnal env (Lam x body) | isTyVar x
= case occAnal env body of { (body_usage, body') ->
- (body_usage, Lam x body')
+ (markAllNonTailCalled body_usage, Lam x body')
}
-- For value lambdas we do a special hack. Consider
@@ -1504,19 +1746,17 @@ occAnal env (Lam x body) | isTyVar x
-- Then, the simplifier is careful when partially applying lambdas.
occAnal env expr@(Lam _ _)
- = case occAnal env_body body of { (body_usage, body') ->
+ = case occAnalLamOrRhs env binders body of { (usage, tagged_binders, body') ->
let
- (final_usage, tagged_binders) = tagLamBinders body_usage binders'
- -- Use binders' to put one-shot info on the lambdas
-
- really_final_usage
- | all isOneShotBndr binders' = final_usage
- | otherwise = mapVarEnv markInsideLam final_usage
+ expr' = mkLams tagged_binders body'
+ final_usage | all isOneShotBndr tagged_binders
+ = markAllNonTailCalled usage
+ | otherwise
+ = markAllInsideLam $ markAllNonTailCalled usage
in
- (really_final_usage, mkLams tagged_binders body') }
+ (final_usage, expr') }
where
(binders, body) = collectBinders expr
- (env_body, binders') = oneShotGroup env binders
occAnal env (Case scrut bndr ty alts)
= case occ_anal_scrut scrut alts of { (scrut_usage, scrut') ->
@@ -1524,7 +1764,8 @@ occAnal env (Case scrut bndr ty alts)
let
alts_usage = foldr combineAltsUsageDetails emptyDetails alts_usage_s
(alts_usage1, tagged_bndr) = tag_case_bndr alts_usage bndr
- total_usage = scrut_usage +++ alts_usage1
+ total_usage = markAllNonTailCalled scrut_usage +++ alts_usage1
+ -- Alts can have tail calls, but the scrutinee can't
in
total_usage `seq` (total_usage, Case scrut' tagged_bndr ty alts') }}
where
@@ -1538,18 +1779,21 @@ occAnal env (Case scrut bndr ty alts)
-- into
-- case x of w { (p,q) -> f (p,q) }
tag_case_bndr usage bndr
- = case lookupVarEnv usage bndr of
- Nothing -> (usage, setIdOccInfo bndr IAmDead)
- Just _ -> (usage `delVarEnv` bndr, setIdOccInfo bndr NoOccInfo)
+ = (usage', setIdOccInfo bndr final_occ_info)
+ where
+ occ_info = lookupDetails usage bndr
+ usage' = usage `delDetails` bndr
+ final_occ_info = case occ_info of IAmDead -> IAmDead
+ _ -> noOccInfo
alt_env = mkAltEnv env scrut bndr
occ_anal_alt = occAnalAlt alt_env
occ_anal_scrut (Var v) (alt1 : other_alts)
| not (null other_alts) || not (isDefaultAlt alt1)
- = (mkOneOcc env v True, Var v) -- The 'True' says that the variable occurs
- -- in an interesting context; the case has
- -- at least one non-default alternative
+ = (mkOneOcc env v True 0, Var v)
+ -- The 'True' says that the variable occurs in an interesting
+ -- context; the case has at least one non-default alternative
occ_anal_scrut (Tick t e) alts
| t `tickishScopesLike` SoftScope
-- No reason to not look through all ticks here, but only
@@ -1561,8 +1805,10 @@ occAnal env (Case scrut bndr ty alts)
= occAnal (vanillaCtxt env) scrut -- No need for rhsCtxt
occAnal env (Let bind body)
- = case occAnal env body of { (body_usage, body') ->
- case occAnalBind env noImpRuleEdges bind body_usage of { (final_usage, new_binds) ->
+ = case occAnal env body of { (body_usage, body') ->
+ case occAnalBind env NotTopLevel
+ noImpRuleEdges bind
+ body_usage of { (final_usage, new_binds) ->
(final_usage, mkLets new_binds body') }}
occAnalArgs :: OccEnv -> [CoreExpr] -> [OneShots] -> (UsageDetails, [CoreExpr])
@@ -1608,8 +1854,9 @@ occAnalApp env (Var fun, args, ticks)
!(args_uds, args') = occAnalArgs env args one_shots
!final_args_uds
- | isRhsEnv env && is_exp = mapVarEnv markInsideLam args_uds
- | otherwise = args_uds
+ | isRhsEnv env && is_exp = markAllNonTailCalled $
+ markAllInsideLam args_uds
+ | otherwise = markAllNonTailCalled args_uds
-- We mark the free vars of the argument of a constructor or PAP
-- as "inside-lambda", if it is the RHS of a let(rec).
-- This means that nothing gets inlined into a constructor or PAP
@@ -1621,7 +1868,8 @@ occAnalApp env (Var fun, args, ticks)
-- See Note [Arguments of let-bound constructors]
n_val_args = valArgCount args
- fun_uds = mkOneOcc env fun (n_val_args > 0)
+ n_args = length args
+ fun_uds = mkOneOcc env fun (n_val_args > 0) n_args
is_exp = isExpandableApp fun n_val_args
-- See Note [CONLIKE pragma] in BasicTypes
-- The definition of is_exp should match that in
@@ -1631,7 +1879,8 @@ occAnalApp env (Var fun, args, ticks)
-- See Note [Use one-shot info]
occAnalApp env (fun, args, ticks)
- = (fun_uds +++ args_uds, mkTicks ticks $ mkApps fun' args')
+ = (markAllNonTailCalled (fun_uds +++ args_uds),
+ mkTicks ticks $ mkApps fun' args')
where
!(fun_uds, fun') = occAnal (addAppCtxt env args) fun
-- The addAppCtxt is a bit cunning. One iteration of the simplifier
@@ -1642,11 +1891,11 @@ occAnalApp env (fun, args, ticks)
-- onto the context stack.
!(args_uds, args') = occAnalArgs env args []
-markManyIf :: Bool -- If this is true
- -> UsageDetails -- Then do markMany on this
- -> UsageDetails
-markManyIf True uds = mapVarEnv markMany uds
-markManyIf False uds = uds
+zapDetailsIf :: Bool -- If this is true
+ -> UsageDetails -- Then do zapDetails on this
+ -> UsageDetails
+zapDetailsIf True uds = zapDetails uds
+zapDetailsIf False uds = uds
{-
Note [Use one-shot information]
@@ -1690,6 +1939,28 @@ life, beause it binds 'y' to (a,b) (imagine got inlined and
scrutinised y).
-}
+occAnalLamOrRhs :: OccEnv -> [CoreBndr] -> CoreExpr
+ -> (UsageDetails, [CoreBndr], CoreExpr)
+occAnalLamOrRhs env [] body
+ = case occAnal env body of (body_usage, body') -> (body_usage, [], body')
+ -- RHS of thunk or nullary join point
+occAnalLamOrRhs env (bndr:bndrs) body
+ | isTyVar bndr
+ = -- Important: Keep the environment so that we don't inline into an RHS like
+ -- \(@ x) -> C @x (f @x)
+ -- (see the beginning of Note [Cascading inlines]).
+ case occAnalLamOrRhs env bndrs body of
+ (body_usage, bndrs', body') -> (body_usage, bndr:bndrs', body')
+occAnalLamOrRhs env binders body
+ = case occAnal env_body body of { (body_usage, body') ->
+ let
+ (final_usage, tagged_binders) = tagLamBinders body_usage binders'
+ -- Use binders' to put one-shot info on the lambdas
+ in
+ (final_usage, tagged_binders, body') }
+ where
+ (env_body, binders') = oneShotGroup env binders
+
occAnalAlt :: (OccEnv, Maybe (Id, CoreExpr))
-> CoreAlt
-> (UsageDetails, Alt IdWithOccInfo)
@@ -1722,7 +1993,7 @@ wrapAltRHS env (Just (scrut_var, let_rhs)) alt_usg bndrs alt_rhs
-- if the scrutinee was a cast, so we must gather their
-- usage. See Note [Gather occurrences of coercion variables]
(let_rhs_usg, let_rhs') = occAnal env let_rhs
- (alt_usg', tagged_scrut_var) = tagBinder alt_usg scrut_var
+ (alt_usg', [tagged_scrut_var]) = tagLamBinders alt_usg [scrut_var]
wrapAltRHS _ _ alt_usg _ alt_rhs
= (alt_usg, alt_rhs)
@@ -2054,48 +2325,191 @@ mkAltEnv env@(OccEnv { occ_gbl_scrut = pe }) scrut case_bndr
\subsection[OccurAnal-types]{OccEnv}
* *
************************************************************************
+
+Note [UsageDetails and zapping]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+On many occasions, we must modify all gathered occurrence data at once. For
+instance, all occurrences underneath a (non-one-shot) lambda set the
+'occ_in_lam' flag to become 'True'. We could use 'mapVarEnv' to do this, but
+that takes O(n) time and we will do this often---in particular, there are many
+places where tail calls are not allowed, and each of these causes all variables
+to get marked with 'NoTailCallInfo'.
+
+Instead of relying on `mapVarEnv`, then, we carry three 'IdEnv's around along
+with the 'OccInfoEnv'. Each of these extra environments is a "zapped set"
+recording which variables have been zapped in some way. Zapping all occurrence
+info then simply means setting the corresponding zapped set to the whole
+'OccInfoEnv', a fast O(1) operation.
-}
-type UsageDetails = IdEnv OccInfo -- A finite map from ids to their usage
+type OccInfoEnv = IdEnv OccInfo -- A finite map from ids to their usage
-- INVARIANT: never IAmDead
-- (Deadness is signalled by not being in the map at all)
+type ZappedSet = OccInfoEnv -- Values are ignored
+
+data UsageDetails
+ = UD { ud_env :: !OccInfoEnv
+ , ud_z_many :: ZappedSet -- apply 'markMany' to these
+ , ud_z_in_lam :: ZappedSet -- apply 'markInsideLam' to these
+ , ud_z_no_tail :: ZappedSet } -- apply 'markNonTailCalled' to these
+ -- INVARIANT: All three zapped sets are subsets of the OccInfoEnv
+
+instance Outputable UsageDetails where
+ ppr ud = ppr (ud_env (flattenUsageDetails ud))
+
+-------------------
+-- UsageDetails API
+
(+++), combineAltsUsageDetails
:: UsageDetails -> UsageDetails -> UsageDetails
+(+++) = combineUsageDetailsWith addOccInfo
+combineAltsUsageDetails = combineUsageDetailsWith orOccInfo
-(+++) usage1 usage2
- = plusVarEnv_C addOccInfo usage1 usage2
-
-combineAltsUsageDetails usage1 usage2
- = plusVarEnv_C orOccInfo usage1 usage2
+combineUsageDetailsList :: [UsageDetails] -> UsageDetails
+combineUsageDetailsList = foldl (+++) emptyDetails
-addOneOcc :: UsageDetails -> Id -> OccInfo -> UsageDetails
-addOneOcc usage id info
- = plusVarEnv_C addOccInfo usage (unitVarEnv id info)
- -- ToDo: make this more efficient
+mkOneOcc :: OccEnv -> Id -> InterestingCxt -> JoinArity -> UsageDetails
+mkOneOcc env id int_cxt arity
+ | isLocalId id
+ = singleton $ OneOcc { occ_in_lam = False
+ , occ_one_br = True
+ , occ_int_cxt = int_cxt
+ , occ_tail = AlwaysTailCalled arity }
+ | id `elemVarEnv` occ_gbl_scrut env
+ = singleton noOccInfo
-emptyDetails :: UsageDetails
-emptyDetails = (emptyVarEnv :: UsageDetails)
+ | otherwise
+ = emptyDetails
+ where
+ singleton info = emptyDetails { ud_env = unitVarEnv id info }
-usedIn :: Id -> UsageDetails -> Bool
-v `usedIn` details = isExportedId v || v `elemVarEnv` details
+addOneOcc :: UsageDetails -> Id -> OccInfo -> UsageDetails
+addOneOcc ud id info
+ = ud { ud_env = extendVarEnv_C plus_zapped (ud_env ud) id info }
+ `alterZappedSets` (`delVarEnv` id)
+ where
+ plus_zapped old new = doZapping ud id old `addOccInfo` new
-addIdOccs :: UsageDetails -> VarSet -> UsageDetails
-addIdOccs usage id_set = nonDetFoldUFM addIdOcc usage id_set
- -- It's OK to use nonDetFoldUFM here because addIdOcc commutes
+addManyOccsSet :: UsageDetails -> VarSet -> UsageDetails
+addManyOccsSet usage id_set = nonDetFoldUFM addManyOccs usage id_set
+ -- It's OK to use nonDetFoldUFM here because addManyOccs commutes
-addIdOcc :: Id -> UsageDetails -> UsageDetails
-addIdOcc v u | isId v = addOneOcc u v NoOccInfo
- | otherwise = u
- -- Give a non-committal binder info (i.e NoOccInfo) because
+-- Add several occurrences, assumed not to be tail calls
+addManyOccs :: Var -> UsageDetails -> UsageDetails
+addManyOccs v u | isId v = addOneOcc u v noOccInfo
+ | otherwise = u
+ -- Give a non-committal binder info (i.e noOccInfo) because
-- a) Many copies of the specialised thing can appear
-- b) We don't want to substitute a BIG expression inside a RULE
-- even if that's the only occurrence of the thing
-- (Same goes for INLINE.)
+delDetails :: UsageDetails -> Id -> UsageDetails
+delDetails ud bndr
+ = ud `alterUsageDetails` (`delVarEnv` bndr)
+
+delDetailsList :: UsageDetails -> [Id] -> UsageDetails
+delDetailsList ud bndrs
+ = ud `alterUsageDetails` (`delVarEnvList` bndrs)
+
+emptyDetails :: UsageDetails
+emptyDetails = UD { ud_env = emptyVarEnv
+ , ud_z_many = emptyVarEnv
+ , ud_z_in_lam = emptyVarEnv
+ , ud_z_no_tail = emptyVarEnv }
+
+isEmptyDetails :: UsageDetails -> Bool
+isEmptyDetails = isEmptyVarEnv . ud_env
+
+markAllMany, markAllInsideLam, markAllNonTailCalled, zapDetails
+ :: UsageDetails -> UsageDetails
+markAllMany ud = ud { ud_z_many = ud_env ud }
+markAllInsideLam ud = ud { ud_z_in_lam = ud_env ud }
+markAllNonTailCalled ud = ud { ud_z_no_tail = ud_env ud }
+
+zapDetails = markAllMany . markAllNonTailCalled -- effectively sets to noOccInfo
+
+lookupDetails :: UsageDetails -> Id -> OccInfo
+lookupDetails ud id
+ = case lookupVarEnv (ud_env ud) id of
+ Just occ -> doZapping ud id occ
+ Nothing -> IAmDead
+
+usedIn :: Id -> UsageDetails -> Bool
+v `usedIn` ud = isExportedId v || v `elemVarEnv` ud_env ud
+
udFreeVars :: VarSet -> UsageDetails -> VarSet
-- Find the subset of bndrs that are mentioned in uds
-udFreeVars bndrs uds = intersectUFM_C (\b _ -> b) bndrs uds
+udFreeVars bndrs ud = intersectUFM_C (\b _ -> b) bndrs (ud_env ud)
+
+-------------------
+-- Auxiliary functions for UsageDetails implementation
+
+combineUsageDetailsWith :: (OccInfo -> OccInfo -> OccInfo)
+ -> UsageDetails -> UsageDetails -> UsageDetails
+combineUsageDetailsWith plus_occ_info ud1 ud2
+ | isEmptyDetails ud1 = ud2
+ | isEmptyDetails ud2 = ud1
+ | otherwise
+ = UD { ud_env = plusVarEnv_C plus_occ_info (ud_env ud1) (ud_env ud2)
+ , ud_z_many = plusVarEnv (ud_z_many ud1) (ud_z_many ud2)
+ , ud_z_in_lam = plusVarEnv (ud_z_in_lam ud1) (ud_z_in_lam ud2)
+ , ud_z_no_tail = plusVarEnv (ud_z_no_tail ud1) (ud_z_no_tail ud2) }
+
+doZapping :: UsageDetails -> Var -> OccInfo -> OccInfo
+doZapping ud var occ
+ = doZappingByUnique ud (varUnique var) occ
+
+doZappingByUnique :: UsageDetails -> Unique -> OccInfo -> OccInfo
+doZappingByUnique ud uniq
+ = (if | in_subset ud_z_many -> markMany
+ | in_subset ud_z_in_lam -> markInsideLam
+ | otherwise -> id) .
+ (if | in_subset ud_z_no_tail -> markNonTailCalled
+ | otherwise -> id)
+ where
+ in_subset field = uniq `elemVarEnvByKey` field ud
+
+alterZappedSets :: UsageDetails -> (ZappedSet -> ZappedSet) -> UsageDetails
+alterZappedSets ud f
+ = ud { ud_z_many = f (ud_z_many ud)
+ , ud_z_in_lam = f (ud_z_in_lam ud)
+ , ud_z_no_tail = f (ud_z_no_tail ud) }
+
+alterUsageDetails :: UsageDetails -> (OccInfoEnv -> OccInfoEnv) -> UsageDetails
+alterUsageDetails ud f
+ = ud { ud_env = f (ud_env ud) }
+ `alterZappedSets` f
+
+flattenUsageDetails :: UsageDetails -> UsageDetails
+flattenUsageDetails ud
+ = ud { ud_env = mapUFM_Directly (doZappingByUnique ud) (ud_env ud) }
+ `alterZappedSets` const emptyVarEnv
+
+-------------------
+-- See Note [Adjusting right-hand sides]
+adjustRhsUsage :: Maybe JoinArity -> RecFlag
+ -> [CoreBndr] -- Outer lambdas, AFTER occ anal
+ -> UsageDetails -> UsageDetails
+adjustRhsUsage mb_join_arity rec_flag bndrs usage
+ = maybe_mark_lam (maybe_drop_tails usage)
+ where
+ maybe_mark_lam ud | one_shot = ud
+ | otherwise = markAllInsideLam ud
+ maybe_drop_tails ud | exact_join = ud
+ | otherwise = markAllNonTailCalled ud
+
+ one_shot = case mb_join_arity of
+ Just join_arity
+ | isRec rec_flag -> False
+ | otherwise -> all isOneShotBndr (drop join_arity bndrs)
+ Nothing -> all isOneShotBndr bndrs
+
+ exact_join = case mb_join_arity of
+ Just join_arity -> join_arity == length bndrs
+ _ -> False
type IdWithOccInfo = Id
@@ -2109,37 +2523,145 @@ tagLamBinders :: UsageDetails -- Of scope
tagLamBinders usage binders = usage' `seq` (usage', bndrs')
where
(usage', bndrs') = mapAccumR tag_lam usage binders
- tag_lam usage bndr = (usage2, setBinderOcc usage bndr)
+ tag_lam usage bndr = (usage2, bndr')
where
- usage1 = usage `delVarEnv` bndr
- usage2 | isId bndr = addIdOccs usage1 (idUnfoldingVars bndr)
+ occ = lookupDetails usage bndr
+ bndr' = setBinderOcc (markNonTailCalled occ) bndr
+ -- Don't try to make an argument into a join point
+ usage1 = usage `delDetails` bndr
+ usage2 | isId bndr = addManyOccsSet usage1 (idUnfoldingVars bndr)
+ -- This is effectively the RHS of a
+ -- non-join-point binding, so it's okay to use
+ -- addManyOccsSet, which assumes no tail calls
| otherwise = usage1
-tagBinder :: UsageDetails -- Of scope
- -> Id -- Binders
- -> (UsageDetails, -- Details with binders removed
- IdWithOccInfo) -- Tagged binders
+tagNonRecBinder :: TopLevelFlag -- At top level?
+ -> UsageDetails -- Of scope
+ -> CoreBndr -- Binder
+ -> (UsageDetails, -- Details with binder removed
+ IdWithOccInfo) -- Tagged binder
-tagBinder usage binder
+tagNonRecBinder lvl usage binder
= let
- usage' = usage `delVarEnv` binder
- binder' = setBinderOcc usage binder
+ occ = lookupDetails usage binder
+ will_be_join = decideJoinPointHood lvl usage [binder]
+ occ' | will_be_join = occ -- must already be marked AlwaysTailCalled
+ | otherwise = markNonTailCalled occ
+ binder' = setBinderOcc occ' binder
+ usage' = usage `delDetails` binder
in
usage' `seq` (usage', binder')
-setBinderOcc :: UsageDetails -> CoreBndr -> CoreBndr
-setBinderOcc usage bndr
+tagRecBinders :: TopLevelFlag -- At top level?
+ -> UsageDetails -- Of body of let ONLY
+ -> [(CoreBndr, -- Binder
+ UsageDetails, -- RHS usage details
+ [CoreBndr])] -- Lambdas in new RHS
+ -> (UsageDetails, -- Adjusted details for whole scope,
+ -- with binders removed
+ [IdWithOccInfo]) -- Tagged binders
+-- Substantially more complicated than non-recursive case. Need to adjust RHS
+-- details *before* tagging binders (because the tags depend on the RHSes).
+tagRecBinders lvl body_uds triples
+ = let
+ (bndrs, rhs_udss, _) = unzip3 triples
+
+ -- 1. Determine join-point-hood of whole group, as determined by
+ -- the *unadjusted* usage details
+ unadj_uds = body_uds +++ combineUsageDetailsList rhs_udss
+ will_be_joins = decideJoinPointHood lvl unadj_uds bndrs
+
+ -- 2. Adjust usage details of each RHS, taking into account the
+ -- join-point-hood decision
+ rhs_udss' = map adjust triples
+ adjust (bndr, rhs_uds, rhs_bndrs)
+ = adjustRhsUsage mb_join_arity Recursive rhs_bndrs rhs_uds
+ where
+ -- Can't use willBeJoinId_maybe here because we haven't tagged the
+ -- binder yet (the tag depends on these adjustments!)
+ mb_join_arity
+ | will_be_joins
+ , let occ = lookupDetails unadj_uds bndr
+ , AlwaysTailCalled arity <- tailCallInfo occ
+ = Just arity
+ | otherwise
+ = ASSERT(not will_be_joins) -- Should be AlwaysTailCalled if we're
+ -- making join points!
+ Nothing
+
+ -- 3. Compute final usage details from adjusted RHS details
+ adj_uds = body_uds +++ combineUsageDetailsList rhs_udss'
+
+ -- 4. Tag each binder with its adjusted details modulo the
+ -- join-point-hood decision
+ occs = map (lookupDetails adj_uds) bndrs
+ occs' | will_be_joins = occs
+ | otherwise = map markNonTailCalled occs
+ bndrs' = zipWith setBinderOcc occs' bndrs
+
+ -- 5. Drop the binders from the adjusted details and return
+ usage' = adj_uds `delDetailsList` bndrs
+ in
+ (usage', bndrs')
+
+setBinderOcc :: OccInfo -> CoreBndr -> CoreBndr
+setBinderOcc occ_info bndr
| isTyVar bndr = bndr
- | isExportedId bndr = case idOccInfo bndr of
- NoOccInfo -> bndr
- _ -> setIdOccInfo bndr NoOccInfo
+ | isExportedId bndr = if isManyOccs (idOccInfo bndr)
+ then bndr
+ else setIdOccInfo bndr noOccInfo
-- Don't use local usage info for visible-elsewhere things
-- BUT *do* erase any IAmALoopBreaker annotation, because we're
-- about to re-generate it and it shouldn't be "sticky"
| otherwise = setIdOccInfo bndr occ_info
+
+-- | Decide whether some bindings should be made into join points or not.
+-- Returns `False` if they can't be join points. Note that it's an
+-- all-or-nothing decision, as if multiple binders are given, they're assumed to
+-- be mutually recursive.
+--
+-- See Note [Invariants for join points] in CoreSyn.
+decideJoinPointHood :: TopLevelFlag -> UsageDetails
+ -> [CoreBndr]
+ -> Bool
+decideJoinPointHood TopLevel _ _
+ = False
+decideJoinPointHood NotTopLevel usage bndrs
+ | isJoinId (head bndrs)
+ = WARN(not all_ok, text "OccurAnal failed to rediscover join point(s):" <+>
+ ppr bndrs)
+ all_ok
+ | otherwise
+ = all_ok
where
- occ_info = lookupVarEnv usage bndr `orElse` IAmDead
+ -- See Note [Invariants on join points]; invariants cited by number below.
+ -- Invariant 2 is always satisfiable by the simplifier by eta expansion.
+ all_ok = -- Invariant 3: Either all are join points or none are
+ all ok bndrs
+
+ ok bndr
+ | -- Invariant 1: Only tail calls, all same join arity
+ AlwaysTailCalled arity <- tailCallInfo (lookupDetails usage bndr)
+ , -- Invariant 1 as applied to LHSes of rules
+ all (ok_rule arity) (idCoreRules bndr)
+ -- Invariant 4: Satisfies polymorphism rule
+ , isValidJoinPointType arity (idType bndr)
+ = True
+ | otherwise
+ = False
+
+ ok_rule _ BuiltinRule{} = False -- only possible with plugin shenanigans
+ ok_rule join_arity (Rule { ru_args = args })
+ = length args == join_arity
+ -- Invariant 1 as applied to LHSes of rules
+
+willBeJoinId_maybe :: CoreBndr -> Maybe JoinArity
+willBeJoinId_maybe bndr
+ | AlwaysTailCalled arity <- tailCallInfo (idOccInfo bndr)
+ = Just arity
+ | otherwise
+ = isJoinId_maybe bndr
{-
************************************************************************
@@ -2149,37 +2671,41 @@ setBinderOcc usage bndr
************************************************************************
-}
-mkOneOcc :: OccEnv -> Id -> InterestingCxt -> UsageDetails
-mkOneOcc env id int_cxt
- | isLocalId id
- = unitVarEnv id (OneOcc False True int_cxt)
+markMany, markInsideLam, markNonTailCalled :: OccInfo -> OccInfo
- | id `elemVarEnv` occ_gbl_scrut env
- = unitVarEnv id NoOccInfo
-
- | otherwise
- = emptyDetails
-
-markMany, markInsideLam :: OccInfo -> OccInfo
+markMany IAmDead = IAmDead
+markMany occ = ManyOccs { occ_tail = occ_tail occ }
-markMany _ = NoOccInfo
+markInsideLam occ@(OneOcc {}) = occ { occ_in_lam = True }
+markInsideLam occ = occ
-markInsideLam (OneOcc _ one_br int_cxt) = OneOcc True one_br int_cxt
-markInsideLam occ = occ
+markNonTailCalled IAmDead = IAmDead
+markNonTailCalled occ = occ { occ_tail = NoTailCallInfo }
addOccInfo, orOccInfo :: OccInfo -> OccInfo -> OccInfo
addOccInfo a1 a2 = ASSERT( not (isDeadOcc a1 || isDeadOcc a2) )
- NoOccInfo -- Both branches are at least One
+ ManyOccs { occ_tail = tailCallInfo a1 `andTailCallInfo`
+ tailCallInfo a2 }
+ -- Both branches are at least One
-- (Argument is never IAmDead)
-- (orOccInfo orig new) is used
-- when combining occurrence info from branches of a case
-orOccInfo (OneOcc in_lam1 _ int_cxt1)
- (OneOcc in_lam2 _ int_cxt2)
- = OneOcc (in_lam1 || in_lam2)
- False -- False, because it occurs in both branches
- (int_cxt1 && int_cxt2)
+orOccInfo (OneOcc { occ_in_lam = in_lam1, occ_int_cxt = int_cxt1
+ , occ_tail = tail1 })
+ (OneOcc { occ_in_lam = in_lam2, occ_int_cxt = int_cxt2
+ , occ_tail = tail2 })
+ = OneOcc { occ_in_lam = in_lam1 || in_lam2
+ , occ_one_br = False -- False, because it occurs in both branches
+ , occ_int_cxt = int_cxt1 && int_cxt2
+ , occ_tail = tail1 `andTailCallInfo` tail2 }
orOccInfo a1 a2 = ASSERT( not (isDeadOcc a1 || isDeadOcc a2) )
- NoOccInfo
+ ManyOccs { occ_tail = tailCallInfo a1 `andTailCallInfo`
+ tailCallInfo a2 }
+
+andTailCallInfo :: TailCallInfo -> TailCallInfo -> TailCallInfo
+andTailCallInfo info@(AlwaysTailCalled arity1) (AlwaysTailCalled arity2)
+ | arity1 == arity2 = info
+andTailCallInfo _ _ = NoTailCallInfo
diff --git a/compiler/simplCore/SetLevels.hs b/compiler/simplCore/SetLevels.hs
index c0d6e8d862..d1ff3fc18b 100644
--- a/compiler/simplCore/SetLevels.hs
+++ b/compiler/simplCore/SetLevels.hs
@@ -49,11 +49,11 @@
the scrutinee of the case, and we can inline it.
-}
-{-# LANGUAGE CPP #-}
+{-# LANGUAGE CPP, MultiWayIf #-}
module SetLevels (
setLevels,
- Level(..), tOP_LEVEL,
+ Level(..), LevelType(..), tOP_LEVEL, isJoinCeilLvl, asJoinCeilLvl,
LevelledBind, LevelledExpr, LevelledBndr,
FloatSpec(..), floatSpecLevel,
@@ -74,6 +74,7 @@ import CoreArity ( exprBotStrictness_maybe )
import CoreFVs -- all of it
import CoreSubst
import MkCore ( sortQuantVars )
+
import Id
import IdInfo
import Var
@@ -84,7 +85,7 @@ import Demand ( StrictSig, increaseStrictSigArity )
import Name ( getOccName, mkSystemVarName )
import OccName ( occNameString )
import Type ( isUnliftedType, Type, mkLamTypes, splitTyConApp_maybe )
-import BasicTypes ( Arity, RecFlag(..) )
+import BasicTypes ( Arity, RecFlag(..), isRec )
import DataCon ( dataConOrigResTy )
import TysWiredIn
import UniqSupply
@@ -95,6 +96,8 @@ import UniqDFM
import FV
import Data.Maybe
+import Control.Monad ( zipWithM )
+
{-
************************************************************************
* *
@@ -107,10 +110,12 @@ type LevelledExpr = TaggedExpr FloatSpec
type LevelledBind = TaggedBind FloatSpec
type LevelledBndr = TaggedBndr FloatSpec
-data Level = Level Int -- Major level: number of enclosing value lambdas
- Int -- Minor level: number of big-lambda and/or case
- -- expressions between here and the nearest
- -- enclosing value lambda
+data Level = Level Int -- Level number of enclosing lambdas
+ Int -- Number of big-lambda and/or case expressions and/or
+ -- context boundaries between
+ -- here and the nearest enclosing lambda
+ LevelType -- Binder or join ceiling?
+data LevelType = BndrLvl | JoinCeilLvl deriving (Eq)
data FloatSpec
= FloatMe Level -- Float to just inside the binding
@@ -139,7 +144,7 @@ a_0 = let b_? = ... in
x_1 = ... b ... in ...
\end{verbatim}
-The main function @lvlExpr@ carries a ``context level'' (@ctxt_lvl@).
+The main function @lvlExpr@ carries a ``context level'' (@le_ctxt_lvl@).
That's meant to be the level number of the enclosing binder in the
final (floated) program. If the level number of a sub-expression is
less than that of the context, then it might be worth let-binding the
@@ -176,6 +181,26 @@ One particular case is that of workers: we don't want to float the
call to the worker outside the wrapper, otherwise the worker might get
inlined into the floated expression, and an importing module won't see
the worker at all.
+
+Note [Join ceiling]
+~~~~~~~~~~~~~~~~~~~
+Join points can't float very far; too far, and they can't remain join points
+(though see Note [When to ruin a join point]). So, suppose we have:
+
+ f x =
+ (joinrec j y = ... x ... in jump j x) + 1
+
+One may be tempted to float j out to the top of f's RHS, but then the jump
+would not be a tail call. Thus we keep track of a level called the *join
+ceiling* past which join points are not allowed to float.
+
+The troublesome thing is that, unlike most levels to which something might
+float, there is not necessarily an identifier to which the join ceiling is
+attached. Fortunately, if something is to be floated to a join ceiling, it must
+be dropped at the *nearest* join ceiling. Thus each level is marked as to
+whether it is a join ceiling, so that FloatOut can tell which binders are being
+floated to the nearest join ceiling and which to a particular binder (or set of
+binders).
-}
instance Outputable FloatSpec where
@@ -183,36 +208,44 @@ instance Outputable FloatSpec where
ppr (StayPut l) = ppr l
tOP_LEVEL :: Level
-tOP_LEVEL = Level 0 0
+tOP_LEVEL = Level 0 0 BndrLvl
incMajorLvl :: Level -> Level
-incMajorLvl (Level major _) = Level (major + 1) 0
+incMajorLvl (Level major _ _) = Level (major + 1) 0 BndrLvl
incMinorLvl :: Level -> Level
-incMinorLvl (Level major minor) = Level major (minor+1)
+incMinorLvl (Level major minor _) = Level major (minor+1) BndrLvl
+
+asJoinCeilLvl :: Level -> Level
+asJoinCeilLvl (Level major minor _) = Level major minor JoinCeilLvl
maxLvl :: Level -> Level -> Level
-maxLvl l1@(Level maj1 min1) l2@(Level maj2 min2)
+maxLvl l1@(Level maj1 min1 _) l2@(Level maj2 min2 _)
| (maj1 > maj2) || (maj1 == maj2 && min1 > min2) = l1
| otherwise = l2
ltLvl :: Level -> Level -> Bool
-ltLvl (Level maj1 min1) (Level maj2 min2)
+ltLvl (Level maj1 min1 _) (Level maj2 min2 _)
= (maj1 < maj2) || (maj1 == maj2 && min1 < min2)
ltMajLvl :: Level -> Level -> Bool
-- Tells if one level belongs to a difft *lambda* level to another
-ltMajLvl (Level maj1 _) (Level maj2 _) = maj1 < maj2
+ltMajLvl (Level maj1 _ _) (Level maj2 _ _) = maj1 < maj2
isTopLvl :: Level -> Bool
-isTopLvl (Level 0 0) = True
-isTopLvl _ = False
+isTopLvl (Level 0 0 _) = True
+isTopLvl _ = False
+
+isJoinCeilLvl :: Level -> Bool
+isJoinCeilLvl (Level _ _ t) = t == JoinCeilLvl
instance Outputable Level where
- ppr (Level maj min) = hcat [ char '<', int maj, char ',', int min, char '>' ]
+ ppr (Level maj min typ)
+ = hcat [ char '<', int maj, char ',', int min, char '>'
+ , ppWhen (typ == JoinCeilLvl) (char 'C') ]
instance Eq Level where
- (Level maj1 min1) == (Level maj2 min2) = maj1 == maj2 && min1 == min2
+ (Level maj1 min1 _) == (Level maj2 min2 _) = maj1 == maj2 && min1 == min2
{-
************************************************************************
@@ -241,14 +274,14 @@ setLevels float_lams binds us
lvlTopBind :: LevelEnv -> Bind Id -> LvlM (LevelledBind, LevelEnv)
lvlTopBind env (NonRec bndr rhs)
- = do { rhs' <- lvlExpr env (freeVars rhs)
+ = do { rhs' <- lvlNonTailExpr env (freeVars rhs)
; let (env', [bndr']) = substAndLvlBndrs NonRecursive env tOP_LEVEL [bndr]
; return (NonRec bndr' rhs', env') }
lvlTopBind env (Rec pairs)
= do let (bndrs,rhss) = unzip pairs
(env', bndrs') = substAndLvlBndrs Recursive env tOP_LEVEL bndrs
- rhss' <- mapM (lvlExpr env' . freeVars) rhss
+ rhss' <- mapM (lvlNonTailExpr env' . freeVars) rhss
return (Rec (bndrs' `zip` rhss'), env')
{-
@@ -278,16 +311,16 @@ lvlExpr :: LevelEnv -- Context
-> LvlM LevelledExpr -- Result expression
{-
-The @ctxt_lvl@ is, roughly, the level of the innermost enclosing
+The @le_ctxt_lvl@ is, roughly, the level of the innermost enclosing
binder. Here's an example
v = \x -> ...\y -> let r = case (..x..) of
..x..
in ..
-When looking at the rhs of @r@, @ctxt_lvl@ will be 1 because that's
+When looking at the rhs of @r@, @le_ctxt_lvl@ will be 1 because that's
the level of @r@, even though it's inside a level-2 @\y@. It's
-important that @ctxt_lvl@ is 1 and not 2 in @r@'s rhs, because we
+important that @le_ctxt_lvl@ is 1 and not 2 in @r@'s rhs, because we
don't want @lvlExpr@ to turn the scrutinee of the @case@ into an MFE
--- because it isn't a *maximal* free expression.
@@ -300,11 +333,11 @@ lvlExpr env (_, AnnVar v) = return (lookupVar env v)
lvlExpr _ (_, AnnLit lit) = return (Lit lit)
lvlExpr env (_, AnnCast expr (_, co)) = do
- expr' <- lvlExpr env expr
+ expr' <- lvlNonTailExpr env expr
return (Cast expr' (substCo (le_subst env) co))
lvlExpr env (_, AnnTick tickish expr) = do
- expr' <- lvlExpr env expr
+ expr' <- lvlNonTailExpr env expr
let tickish' = substTickish (le_subst env) tickish
return (Tick tickish' expr')
@@ -319,8 +352,8 @@ lvlExpr env expr@(_, AnnApp _ _) = do
, Nothing <- isClassOpId_maybe f ->
do
let (lapp, rargs) = left (n_val_args - arity) expr []
- rargs' <- mapM (lvlMFE False env) rargs
- lapp' <- lvlMFE False env lapp
+ rargs' <- mapM (lvlNonTailMFE False env) rargs
+ lapp' <- lvlNonTailMFE False env lapp
return (foldl App lapp' rargs')
where
n_val_args = count (isValArg . deAnnotate) args
@@ -338,8 +371,8 @@ lvlExpr env expr@(_, AnnApp _ _) = do
-- No PAPs that we can float: just carry on with the
-- arguments and the function.
_otherwise -> do
- args' <- mapM (lvlMFE False env) args
- fun' <- lvlExpr env fun
+ args' <- mapM (lvlNonTailMFE False env) args
+ fun' <- lvlNonTailExpr env fun
return (foldl App fun' args')
-- We don't split adjacent lambdas. That is, given
@@ -350,7 +383,7 @@ lvlExpr env expr@(_, AnnApp _ _) = do
-- lambdas makes them more expensive.
lvlExpr env expr@(_, AnnLam {})
- = do { new_body <- lvlMFE True new_env body
+ = do { new_body <- lvlNonTailMFE True new_env body
; return (mkLams new_bndrs new_body) }
where
(bndrs, body) = collectAnnBndrs expr
@@ -372,9 +405,15 @@ lvlExpr env (_, AnnLet bind body)
; return (Let bind' body') }
lvlExpr env (_, AnnCase scrut case_bndr ty alts)
- = do { scrut' <- lvlMFE True env scrut
+ = do { scrut' <- lvlNonTailMFE True env scrut
; lvlCase env (freeVarsOf scrut) scrut' case_bndr ty alts }
+lvlNonTailExpr :: LevelEnv -- Context
+ -> CoreExprWithFVs -- Input expression
+ -> LvlM LevelledExpr -- Result expression
+lvlNonTailExpr env expr
+ = lvlExpr (placeJoinCeiling env) expr
+
-------------------------------------------
lvlCase :: LevelEnv -- Level of in-scope names/tyvars
-> DVarSet -- Free vars of input scrutinee
@@ -394,14 +433,16 @@ lvlCase env scrut_fvs scrut' case_bndr ty alts
; let rhs_env = extendCaseBndrEnv env1 case_bndr scrut'
; body' <- lvlMFE True rhs_env body
; let alt' = (con, [TB b (StayPut dest_lvl) | b <- bs'], body')
- ; return (Case scrut' (TB case_bndr' (FloatMe dest_lvl)) ty [alt']) }
+ ; return (Case scrut' (TB case_bndr' (FloatMe dest_lvl)) ty' [alt']) }
| otherwise -- Stays put
= do { let (alts_env1, [case_bndr']) = substAndLvlBndrs NonRecursive env incd_lvl [case_bndr]
alts_env = extendCaseBndrEnv alts_env1 case_bndr scrut'
; alts' <- mapM (lvl_alt alts_env) alts
- ; return (Case scrut' case_bndr' ty alts') }
+ ; return (Case scrut' case_bndr' ty' alts') }
where
+ ty' = substTy (le_subst env) ty
+
incd_lvl = incMinorLvl (le_ctxt_lvl env)
dest_lvl = maxFvLevel (const True) env scrut_fvs
-- Don't abstact over type variables, hence const True
@@ -487,6 +528,7 @@ lvlMFE True env e@(_, AnnCase {})
lvlMFE strict_ctxt env ann_expr
| floatTopLvlOnly env && not (isTopLvl dest_lvl)
-- Only floating to the top level is allowed.
+ || isTopLvl dest_lvl && need_join -- Can't put join point at top level
|| isExprLevPoly expr
-- We can't let-bind levity polymorphic expressions
-- See Note [Levity polymorphism invariants] in CoreSyn
@@ -496,10 +538,11 @@ lvlMFE strict_ctxt env ann_expr
lvlExpr env ann_expr
| Just (wrap_float, wrap_use)
- <- canFloat_maybe rhs_env strict_ctxt float_is_lam expr
- = do { expr1 <- lvlExpr rhs_env ann_expr
+ <- canFloat_maybe rhs_env strict_ctxt (float_is_lam || need_join) expr
+ = do { expr1 <- if need_join then lvlExpr rhs_env ann_expr
+ else lvlNonTailExpr rhs_env ann_expr
; let abs_expr = mkLams abs_vars_w_lvls (wrap_float expr1)
- ; var <- newLvlVar abs_expr
+ ; var <- newLvlVar abs_expr join_arity_maybe
; let var2 = annotateBotStr var float_n_lams mb_bot_str
; return (Let (NonRec (TB var2 (FloatMe dest_lvl)) abs_expr)
(wrap_use (mkVarApps (Var var2) abs_vars))) }
@@ -514,13 +557,18 @@ lvlMFE strict_ctxt env ann_expr
mb_bot_str = exprBotStrictness_maybe expr
-- See Note [Bottoming floats]
-- esp Bottoming floats (2)
- dest_lvl = destLevel env fvs (isFunction ann_expr) is_bot
+ dest_lvl = destLevel env fvs (isFunction ann_expr) is_bot need_join
abs_vars = abstractVars dest_lvl env fvs
float_is_lam = float_n_lams > 0 -- The floated thing will be a value lambda
float_n_lams = count isId abs_vars -- so nothing is shared; the only benefit
-- is getting it to the top level
(rhs_env, abs_vars_w_lvls) = lvlLamBndrs env dest_lvl abs_vars
+ -- Note [Join points and MFEs]
+ need_join = any (\v -> isId v && remainsJoinId env v) (dVarSetElems fvs)
+ join_arity_maybe | need_join = Just (length abs_vars)
+ | otherwise = Nothing
+
-- A decision to float entails let-binding this thing, and we only do
-- that if we'll escape a value lambda, or will go to the top level.
float_me = (dest_lvl `ltMajLvl` (le_ctxt_lvl env) -- Escapes a value lambda
@@ -542,6 +590,14 @@ lvlMFE strict_ctxt env ann_expr
-- concat = /\ a -> lvl a
-- which is pretty stupid. Hence the strict_ctxt test
+lvlNonTailMFE :: Bool -- True <=> strict context [body of case
+ -- or let]
+ -> LevelEnv -- Level of in-scope names/tyvars
+ -> CoreExprWithFVs -- input expression
+ -> LvlM LevelledExpr -- Result expression
+lvlNonTailMFE strict_ctxt env ann_expr
+ = lvlMFE strict_ctxt (placeJoinCeiling env) ann_expr
+
canFloat_maybe :: LevelEnv
-> Bool -- Strict context
-> Bool -- The float has a value lambda
@@ -553,6 +609,7 @@ canFloat_maybe env strict_ctxt float_is_lam expr
| float_is_lam || exprIsTopLevelBindable expr
= Just (id, id) -- No wrapping needed if the type is lifted, or
-- if we are wrapping it in one or more value lambdas
+ -- or making it a join point
-- OK, so the float has an unlifted type and no value lambdas
| strict_ctxt
@@ -668,6 +725,43 @@ Because in doing so we share a tiny bit of computation (the switch) but
in exchange we build a thunk, which is bad. This case reduces allocation
by 7% in spectral/puzzle (a rather strange benchmark) and 1.2% in real/fem.
Doesn't change any other allocation at all.
+
+Note [Join points and MFEs]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+When we create an MFE float, if it has a free join variable, the new binding
+must be a join point:
+
+ let join j x = ...
+ in case a of A -> ...
+ B -> j 3
+
+ =>
+
+ let join j x = ...
+ join k = j 3 -- only valid because k is a join point
+ in case a of A -> ...
+ B -> k
+
+Normally we're very circumspect about floating join points, but in this case
+it's definitely safe because we can only be floating it as far as another join
+binding. In other words, one might worry about a situation like:
+
+ let join j x = ...
+ in case a of A -> ...
+ B -> f (j 3)
+
+ =>
+
+ let join j x = ...
+ in case a of A -> ...
+ B -> f (let join k = j 3 in k)
+
+Here we have created the MFE float k, and are contemplating floating it up to
+j. This would indeed be an invalid operation on a join point like k. However,
+this example is ill-typed to begin with, since this time the call to j is not a
+tail call. In summary, the very occurrence of the join variable in the MFE is
+proof that we can float the MFE as far as that binding.
-}
annotateBotStr :: Id -> Arity -> Maybe (Arity, StrictSig) -> Id
@@ -779,8 +873,9 @@ lvlBind env (AnnNonRec bndr rhs)
-- We can't float an unlifted binding to top level, so we don't
-- float it at all. It's a bit brutal, but unlifted bindings
-- aren't expensive either
+
= -- No float
- do { rhs' <- lvlExpr env rhs
+ do { rhs' <- lvlRhs env NonRecursive False mb_join_arity rhs
; let bind_lvl = incMinorLvl (le_ctxt_lvl env)
(env', [bndr']) = substAndLvlBndrs NonRecursive env bind_lvl [bndr]
; return (NonRec bndr' rhs', env') }
@@ -788,15 +883,19 @@ lvlBind env (AnnNonRec bndr rhs)
-- Otherwise we are going to float
| null abs_vars
= do { -- No type abstraction; clone existing binder
- rhs' <- lvlExpr (setCtxtLvl env dest_lvl) rhs
- ; (env', [bndr']) <- cloneLetVars NonRecursive env dest_lvl [bndr]
+ rhs' <- lvlRhs (setCtxtLvl env dest_lvl) NonRecursive
+ zapping_join mb_join_arity rhs
+ ; (env', [bndr']) <- cloneLetVars NonRecursive env dest_lvl
+ zapping_join [bndr]
; let bndr2 = annotateBotStr bndr' 0 mb_bot_str
; return (NonRec (TB bndr2 (FloatMe dest_lvl)) rhs', env') }
| otherwise
= do { -- Yes, type abstraction; create a new binder, extend substitution, etc
- rhs' <- lvlFloatRhs abs_vars dest_lvl env rhs
- ; (env', [bndr']) <- newPolyBndrs dest_lvl env abs_vars [bndr]
+ rhs' <- lvlFloatRhs abs_vars dest_lvl env NonRecursive
+ zapping_join mb_join_arity rhs
+ ; (env', [bndr']) <- newPolyBndrs dest_lvl env abs_vars
+ zapping_join [bndr]
; let bndr2 = annotateBotStr bndr' n_extra mb_bot_str
; return (NonRec (TB bndr2 (FloatMe dest_lvl)) rhs', env') }
@@ -805,24 +904,34 @@ lvlBind env (AnnNonRec bndr rhs)
bind_fvs = rhs_fvs `unionDVarSet` dIdFreeVars bndr
abs_vars = abstractVars dest_lvl env bind_fvs
dest_lvl = destLevel env bind_fvs (isFunction rhs) is_bot
+ is_unfloatable_join
mb_bot_str = exprBotStrictness_maybe (deAnnotate rhs)
-- See Note [Bottoming floats]
-- esp Bottoming floats (2)
is_bot = isJust mb_bot_str
n_extra = count isId abs_vars
+ mb_join_arity = isJoinId_maybe bndr
+ is_unfloatable_join = case mb_join_arity of Just ar -> ar > 0
+ Nothing -> False
+ -- See Note [When to ruin a join point]
+ zapping_join = dest_lvl `ltLvl` joinCeilingLevel env
+
lvlBind env (AnnRec pairs)
| floatTopLvlOnly env && not (isTopLvl dest_lvl)
-- Only floating to the top level is allowed.
|| not (profitableFloat env dest_lvl)
= do { let bind_lvl = incMinorLvl (le_ctxt_lvl env)
(env', bndrs') = substAndLvlBndrs Recursive env bind_lvl bndrs
- ; rhss' <- mapM (lvlExpr env') rhss
+ ; rhss' <- zipWithM (lvlRhs env' Recursive False) mb_join_arities rhss
; return (Rec (bndrs' `zip` rhss'), env') }
| null abs_vars
- = do { (new_env, new_bndrs) <- cloneLetVars Recursive env dest_lvl bndrs
- ; new_rhss <- mapM (lvlExpr (setCtxtLvl new_env dest_lvl)) rhss
+ = do { (new_env, new_bndrs) <- cloneLetVars Recursive env dest_lvl
+ zapping_joins bndrs
+ ; let env_rhs = setCtxtLvl new_env dest_lvl
+ ; new_rhss <- zipWithM (lvlRhs env_rhs Recursive zapping_joins)
+ mb_join_arities rhss
; return ( Rec ([TB b (FloatMe dest_lvl) | b <- new_bndrs] `zip` new_rhss)
, new_env) }
@@ -843,13 +952,17 @@ lvlBind env (AnnRec pairs)
let (rhs_env, abs_vars_w_lvls) = lvlLamBndrs env dest_lvl abs_vars
rhs_lvl = le_ctxt_lvl rhs_env
- (rhs_env', [new_bndr]) <- cloneLetVars Recursive rhs_env rhs_lvl [bndr]
+ (rhs_env', [new_bndr]) <- cloneLetVars Recursive rhs_env rhs_lvl
+ zapping_joins [bndr]
let
(lam_bndrs, rhs_body) = collectAnnBndrs rhs
(body_env1, lam_bndrs1) = substBndrsSL NonRecursive rhs_env' lam_bndrs
(body_env2, lam_bndrs2) = lvlLamBndrs body_env1 rhs_lvl lam_bndrs1
- new_rhs_body <- lvlExpr body_env2 rhs_body
- (poly_env, [poly_bndr]) <- newPolyBndrs dest_lvl env abs_vars [bndr]
+ mb_join_arity = isJoinId_maybe bndr
+ new_rhs_body <- lvlRhs body_env2 Recursive zapping_joins
+ mb_join_arity rhs_body
+ (poly_env, [poly_bndr]) <- newPolyBndrs dest_lvl env abs_vars
+ zapping_joins [bndr]
return (Rec [(TB poly_bndr (FloatMe dest_lvl)
, mkLams abs_vars_w_lvls $
mkLams lam_bndrs2 $
@@ -859,8 +972,11 @@ lvlBind env (AnnRec pairs)
, poly_env)
| otherwise -- Non-null abs_vars
- = do { (new_env, new_bndrs) <- newPolyBndrs dest_lvl env abs_vars bndrs
- ; new_rhss <- mapM (lvlFloatRhs abs_vars dest_lvl new_env) rhss
+ = do { (new_env, new_bndrs) <- newPolyBndrs dest_lvl env abs_vars
+ zapping_joins bndrs
+ ; new_rhss <- zipWithM (lvlFloatRhs abs_vars dest_lvl new_env
+ Recursive zapping_joins)
+ mb_join_arities rhss
; return ( Rec ([TB b (FloatMe dest_lvl) | b <- new_bndrs] `zip` new_rhss)
, new_env) }
@@ -876,26 +992,72 @@ lvlBind env (AnnRec pairs)
bndrs
dest_lvl = destLevel env bind_fvs (all isFunction rhss) False
+ has_unfloatable_join
abs_vars = abstractVars dest_lvl env bind_fvs
+ mb_join_arities = map isJoinId_maybe bndrs
+ has_unfloatable_join
+ = any (\mb_ar -> case mb_ar of Just ar -> ar > 0
+ Nothing -> False) mb_join_arities
+ zapping_joins = dest_lvl `ltLvl` joinCeilingLevel env
+
+lvlRhs :: LevelEnv
+ -> RecFlag
+ -> Bool -- True <=> we're zapping a join point back to a value
+ -> Maybe JoinArity
+ -> CoreExprWithFVs
+ -> LvlM LevelledExpr
+lvlRhs env rec_flag zapping_join mb_join_arity expr
+ = lvlFloatRhs [] (le_ctxt_lvl env) env rec_flag zapping_join
+ mb_join_arity expr
+
profitableFloat :: LevelEnv -> Level -> Bool
profitableFloat env dest_lvl
= (dest_lvl `ltMajLvl` le_ctxt_lvl env) -- Escapes a value lambda
|| isTopLvl dest_lvl -- Going all the way to top level
+
+{-
+Note [When to ruin a join point]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Generally, we protect join points zealously. However, there are two situations
+in which it can pay to promote a join point to a function:
+
+1. If the join point has no value arguments, then floating it outward will make
+ it a *thunk*, not a function, so we might get increased sharing.
+2. If we float the join point all the way to the top level, it still won't be
+ allocated, so the cost is much less.
+
+Refusing to lose a join point in either of these cases can be disastrous---for
+instance, allocation in imaginary/x2n1 *triples* because $w$s^ becomes too big
+to inline, which prevents Float In from making a particular binding strictly
+demanded.
+-}
+
----------------------------------------------------
-- Three help functions for the type-abstraction case
-lvlFloatRhs :: [OutVar] -> Level -> LevelEnv -> CoreExprWithFVs
- -> UniqSM (Expr LevelledBndr)
-lvlFloatRhs abs_vars dest_lvl env rhs
- = do { body' <- lvlExpr rhs_env body
+lvlFloatRhs :: [OutVar] -> Level -> LevelEnv -> RecFlag -> Bool
+ -> Maybe JoinArity -> CoreExprWithFVs
+ -> LvlM (Expr LevelledBndr)
+lvlFloatRhs abs_vars dest_lvl env rec zapping_joins mb_join_arity rhs
+ = do { body' <- if | Just _ <- mb_join_arity, not zapping_joins
+ -> lvlExpr rhs_env body
+ | otherwise
+ -> lvlNonTailExpr rhs_env body
; return (mkLams all_bndrs_w_lvls body') }
where
- (bndrs, body) = collectAnnBndrs rhs
+ (bndrs, body) | Just join_arity <- mb_join_arity
+ = collectNAnnBndrs join_arity rhs
+ | otherwise
+ = collectAnnBndrs rhs
(env1, bndrs1) = substBndrsSL NonRecursive env bndrs
all_bndrs = abs_vars ++ bndrs1
- (rhs_env, all_bndrs_w_lvls) = lvlLamBndrs env1 dest_lvl all_bndrs
+ (rhs_env, all_bndrs_w_lvls) | Just _ <- mb_join_arity
+ = lvlJoinBndrs env1 dest_lvl rec all_bndrs
+ | otherwise
+ = lvlLamBndrs env1 dest_lvl all_bndrs
-- The important thing here is that we call lvlLamBndrs on
-- all these binders at once (abs_vars and bndrs), so they
-- all get the same major level. Otherwise we create stupid
@@ -941,13 +1103,21 @@ lvlLamBndrs env lvl bndrs
-- probable one-shot lambda"
-- See Note [Computing one-shot info] in Demand.hs
+lvlJoinBndrs :: LevelEnv -> Level -> RecFlag -> [OutVar]
+ -> (LevelEnv, [LevelledBndr])
+lvlJoinBndrs env lvl rec bndrs
+ = lvlBndrs env new_lvl bndrs
+ where
+ new_lvl | isRec rec = incMajorLvl lvl
+ | otherwise = incMinorLvl lvl
+ -- Non-recursive join points are one-shot; recursive ones are not
lvlBndrs :: LevelEnv -> Level -> [CoreBndr] -> (LevelEnv, [LevelledBndr])
-- The binders returned are exactly the same as the ones passed,
-- apart from applying the substitution, but they are now paired
-- with a (StayPut level)
--
--- The returned envt has ctxt_lvl updated to the new_lvl
+-- The returned envt has le_ctxt_lvl updated to the new_lvl
--
-- All the new binders get the same level, because
-- any floating binding is either going to float past
@@ -964,8 +1134,9 @@ lvlBndrs env@(LE { le_lvl_env = lvl_env }) new_lvl bndrs
destLevel :: LevelEnv -> DVarSet
-> Bool -- True <=> is function
-> Bool -- True <=> is bottom
+ -> Bool -- True <=> is join point (or can be floated anyway)
-> Level
-destLevel env fvs is_function is_bot
+destLevel env fvs is_function is_bot is_join
| is_bot = tOP_LEVEL -- Send bottoming bindings to the top
-- regardless; see Note [Bottoming floats]
-- Esp Bottoming floats (1)
@@ -975,9 +1146,16 @@ destLevel env fvs is_function is_bot
, countFreeIds fvs <= n_args
= tOP_LEVEL -- Send functions to top level; see
-- the comments with isFunction
+ | is_join, hits_ceiling = join_ceiling
+ | otherwise = max_fv_level
+ where
+ max_fv_level = maxFvLevel isId env fvs -- Max over Ids only; the tyvars
+ -- will be abstracted
- | otherwise = maxFvLevel isId env fvs -- Max over Ids only; the tyvars
- -- will be abstracted
+ hits_ceiling = max_fv_level `ltLvl` join_ceiling &&
+ not (isTopLvl max_fv_level)
+ -- Note [When to ruin a join point]
+ join_ceiling = joinCeilingLevel env
isFunction :: CoreExprWithFVs -> Bool
-- The idea here is that we want to float *functions* to
@@ -1019,6 +1197,7 @@ data LevelEnv
= LE { le_switches :: FloatOutSwitches
, le_ctxt_lvl :: Level -- The current level
, le_lvl_env :: VarEnv Level -- Domain is *post-cloned* TyVars and Ids
+ , le_join_ceil:: Level -- Highest level to which joins float
, le_subst :: Subst -- Domain is pre-cloned TyVars and Ids
-- The Id -> CoreExpr in the Subst is ignored
-- (since we want to substitute a LevelledExpr for
@@ -1050,6 +1229,7 @@ initialEnv :: FloatOutSwitches -> LevelEnv
initialEnv float_lams
= LE { le_switches = float_lams
, le_ctxt_lvl = tOP_LEVEL
+ , le_join_ceil = panic "initialEnv"
, le_lvl_env = emptyVarEnv
, le_subst = emptySubst
, le_env = emptyVarEnv }
@@ -1087,6 +1267,13 @@ extendCaseBndrEnv le@(LE { le_subst = subst, le_env = id_env })
, le_env = add_id id_env (case_bndr, scrut_var) }
extendCaseBndrEnv env _ _ = env
+-- See Note [Join ceiling]
+placeJoinCeiling :: LevelEnv -> LevelEnv
+placeJoinCeiling le@(LE { le_ctxt_lvl = lvl })
+ = le { le_ctxt_lvl = lvl', le_join_ceil = lvl' }
+ where
+ lvl' = asJoinCeilLvl (incMinorLvl lvl)
+
maxFvLevel :: (Var -> Bool) -> LevelEnv -> DVarSet -> Level
maxFvLevel max_me (LE { le_lvl_env = lvl_env, le_env = id_env }) var_set
= foldDVarSet max_in tOP_LEVEL var_set
@@ -1107,6 +1294,18 @@ lookupVar le v = case lookupVarEnv (le_env le) v of
Just (_, expr) -> expr
_ -> Var v
+-- Level to which join points are allowed to float (boundary of current tail
+-- context). See Note [Join ceiling]
+joinCeilingLevel :: LevelEnv -> Level
+joinCeilingLevel = le_join_ceil
+
+remainsJoinId :: LevelEnv -> Id -> Bool
+remainsJoinId le v = case lookupVarEnv (le_env le) v of
+ Just (v':_, _) -> isJoinId v'
+ Nothing -> isJoinId v
+ Just ([], e) -> pprPanic "remainsJoinId" $
+ ppr v $$ ppr e
+
abstractVars :: Level -> LevelEnv -> DVarSet -> [OutVar]
-- Find the variables in fvs, free vars of the target expression,
-- whose level is greater than the destination level
@@ -1154,12 +1353,13 @@ type LvlM result = UniqSM result
initLvl :: UniqSupply -> UniqSM a -> a
initLvl = initUs_
-newPolyBndrs :: Level -> LevelEnv -> [OutVar] -> [InId] -> UniqSM (LevelEnv, [OutId])
+newPolyBndrs :: Level -> LevelEnv -> [OutVar] -> Bool -> [InId]
+ -> LvlM (LevelEnv, [OutId])
-- The envt is extended to bind the new bndrs to dest_lvl, but
--- the ctxt_lvl is unaffected
+-- the le_ctxt_lvl is unaffected
newPolyBndrs dest_lvl
env@(LE { le_lvl_env = lvl_env, le_subst = subst, le_env = id_env })
- abs_vars bndrs
+ abs_vars zapping_joins bndrs
= ASSERT( all (not . isCoVar) bndrs ) -- What would we add to the CoSubst in this case. No easy answer.
do { uniqs <- getUniquesM
; let new_bndrs = zipWith mk_poly_bndr bndrs uniqs
@@ -1173,17 +1373,28 @@ newPolyBndrs dest_lvl
add_id env (v, v') = extendVarEnv env v ((v':abs_vars), mkVarApps (Var v') abs_vars)
mk_poly_bndr bndr uniq = transferPolyIdInfo bndr abs_vars $ -- Note [transferPolyIdInfo] in Id.hs
+ maybe_transfer_join_info bndr $
mkSysLocalOrCoVar (mkFastString str) uniq poly_ty
where
str = "poly_" ++ occNameString (getOccName bndr)
poly_ty = mkLamTypes abs_vars (CoreSubst.substTy subst (idType bndr))
+ maybe_transfer_join_info bndr new_bndr
+ | not zapping_joins
+ , Just join_arity <- isJoinId_maybe bndr
+ = new_bndr `asJoinId`
+ join_arity + length abs_vars
+ | otherwise
+ = new_bndr
newLvlVar :: LevelledExpr -- The RHS of the new binding
+ -> Maybe JoinArity -- Its join arity, if it is a join point
-> LvlM Id
-newLvlVar lvld_rhs
+newLvlVar lvld_rhs join_arity_maybe
= do { uniq <- getUniqueM
- ; return (mk_id uniq rhs_ty) }
+ ; return (add_join_info (mk_id uniq rhs_ty))
+ }
where
+ add_join_info var = var `asJoinId_maybe` join_arity_maybe
de_tagged_rhs = deTagExpr lvld_rhs
rhs_ty = exprType de_tagged_rhs
@@ -1208,25 +1419,30 @@ cloneCaseBndrs env@(LE { le_subst = subst, le_lvl_env = lvl_env, le_env = id_env
; return (env', vs') }
-cloneLetVars :: RecFlag -> LevelEnv -> Level -> [Var] -> LvlM (LevelEnv, [Var])
+cloneLetVars :: RecFlag -> LevelEnv -> Level -> Bool -> [InVar]
+ -> LvlM (LevelEnv, [OutVar])
-- See Note [Need for cloning during float-out]
-- Works for Ids bound by let(rec)
-- The dest_lvl is attributed to the binders in the new env,
--- but cloneVars doesn't affect the ctxt_lvl of the incoming env
+-- but cloneVars doesn't affect the le_ctxt_lvl of the incoming env
cloneLetVars is_rec
env@(LE { le_subst = subst, le_lvl_env = lvl_env, le_env = id_env })
- dest_lvl vs
+ dest_lvl zapping_joins vs
= do { us <- getUniqueSupplyM
- ; let (subst', vs1) = case is_rec of
- NonRecursive -> cloneBndrs subst us vs
- Recursive -> cloneRecIdBndrs subst us vs
- vs2 = map zap_demand_info vs1 -- See Note [Zapping the demand info]
+ ; let vs1 = map (zap_demand_info . maybe_zap_join) vs
+ -- See Note [Zapping the demand info]
+ (subst', vs2) = case is_rec of
+ NonRecursive -> cloneBndrs subst us vs1
+ Recursive -> cloneRecIdBndrs subst us vs1
prs = vs `zip` vs2
env' = env { le_lvl_env = addLvls dest_lvl lvl_env vs2
, le_subst = subst'
, le_env = foldl add_id id_env prs }
; return (env', vs2) }
+ where
+ maybe_zap_join v | isId v, zapping_joins = zapJoinId v
+ | otherwise = v
add_id :: IdEnv ([Var], LevelledExpr) -> (Var, Var) -> IdEnv ([Var], LevelledExpr)
add_id id_env (v, v1)
@@ -1247,4 +1463,7 @@ binding site. Eg
f :: Int -> Int
f x = let v = 3*4 in v+x
Here v is strict; but if we float v to top level, it isn't any more.
+
+Similarly, if we're floating a join point, it won't be one anymore, so we zap
+join point information as well.
-}
diff --git a/compiler/simplCore/SimplCore.hs b/compiler/simplCore/SimplCore.hs
index 304dc5a346..f032aad95c 100644
--- a/compiler/simplCore/SimplCore.hs
+++ b/compiler/simplCore/SimplCore.hs
@@ -207,12 +207,16 @@ getCoreToDo dflags
-- Static forms are moved to the top level with the FloatOut pass.
-- See Note [Grand plan for static forms] in StaticPtrTable.
static_ptrs_float_outwards =
- runWhen static_ptrs $ CoreDoFloatOutwards FloatOutSwitches
- { floatOutLambdas = Just 0
- , floatOutConstants = True
- , floatOutOverSatApps = False
- , floatToTopLevelOnly = True
- }
+ runWhen static_ptrs $ CoreDoPasses
+ [ simpl_gently -- Float Out can't handle type lets (sometimes created
+ -- by simpleOptPgm via mkParallelBindings)
+ , CoreDoFloatOutwards FloatOutSwitches
+ { floatOutLambdas = Just 0
+ , floatOutConstants = True
+ , floatOutOverSatApps = False
+ , floatToTopLevelOnly = True
+ }
+ ]
core_todo =
if opt_level == 0 then
@@ -704,6 +708,7 @@ simplifyPgmIO pass@(CoreDoSimplify max_iterations mode)
} ;
Err.dumpIfSet_dyn dflags Opt_D_dump_occur_anal "Occurrence analysis"
(pprCoreBindings tagged_binds);
+ lintPassResult hsc_env CoreOccurAnal tagged_binds;
-- Get any new rules, and extend the rule base
-- See Note [Overall plumbing for rules] in Rules.hs
diff --git a/compiler/simplCore/SimplEnv.hs b/compiler/simplCore/SimplEnv.hs
index 99d8291491..f35d120af9 100644
--- a/compiler/simplCore/SimplEnv.hs
+++ b/compiler/simplCore/SimplEnv.hs
@@ -20,17 +20,22 @@ module SimplEnv (
-- * Substitution results
SimplSR(..), mkContEx, substId, lookupRecBndr, refineFromInScope,
+ isJoinIdInEnv_maybe,
-- * Simplifying 'Id' binders
- simplNonRecBndr, simplRecBndrs,
+ simplNonRecBndr, simplNonRecJoinBndr, simplRecBndrs, simplRecJoinBndrs,
simplBinder, simplBinders,
substTy, substTyVar, getTCvSubst,
substCo, substCoVar,
-- * Floats
- Floats, emptyFloats, isEmptyFloats, addNonRec, addFloats, extendFloats,
+ Floats, emptyFloats, isEmptyFloats,
+ addNonRec, addFloats, extendFloats,
wrapFloats, setFloats, zapFloats, addRecFloats, mapFloats,
- doFloatFromRhs, getFloatBinds
+ doFloatFromRhs, getFloatBinds,
+
+ JoinFloats, emptyJoinFloats, isEmptyJoinFloats,
+ wrapJoinFloats, zapJoinFloats, restoreJoinFloats, getJoinFloatBinds,
) where
#include "HsVersions.h"
@@ -54,6 +59,7 @@ import BasicTypes
import MonadUtils
import Outputable
import Util
+import UniqFM ( pprUniqFM )
import Data.List
@@ -86,8 +92,10 @@ data SimplEnv
-- They are all OutVars, and all bound in this module
seInScope :: InScopeSet, -- OutVars only
-- Includes all variables bound by seFloats
- seFloats :: Floats
+ seFloats :: Floats,
-- See Note [Simplifier floats]
+ seJoinFloats :: JoinFloats
+ -- Handled separately; they don't go very far
}
type StaticEnv = SimplEnv -- Just the static part is relevant
@@ -97,17 +105,24 @@ pprSimplEnv :: SimplEnv -> SDoc
pprSimplEnv env
= vcat [text "TvSubst:" <+> ppr (seTvSubst env),
text "CvSubst:" <+> ppr (seCvSubst env),
- text "IdSubst:" <+> ppr (seIdSubst env),
+ text "IdSubst:" <+> id_subst_doc,
text "InScope:" <+> in_scope_vars_doc
]
where
+ id_subst_doc = pprUniqFM ppr_id_subst (seIdSubst env)
+ ppr_id_subst (m_ar, sr) = arity_part <+> ppr sr
+ where arity_part = case m_ar of Just ar -> brackets $
+ text "join" <+> int ar
+ Nothing -> empty
+
in_scope_vars_doc = pprVarSet (getInScopeVars (seInScope env))
(vcat . map ppr_one)
ppr_one v | isId v = ppr v <+> ppr (idUnfolding v)
| otherwise = ppr v
-type SimplIdSubst = IdEnv SimplSR -- IdId |--> OutExpr
+type SimplIdSubst = IdEnv (Maybe JoinArity, SimplSR) -- IdId |--> OutExpr
-- See Note [Extending the Subst] in CoreSubst
+ -- See Note [Join arity in SimplIdSubst]
-- | A substitution result.
data SimplSR
@@ -192,6 +207,20 @@ seIdSubst:
map to the same target: x->x, y->x. Notably:
case y of x { ... }
That's why the "set" is actually a VarEnv Var
+
+Note [Join arity in SimplIdSubst]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+We have to remember which incoming variables are join points (the occurrences
+may not be marked correctly yet; we're in change of propagating the change if
+OccurAnal makes something a join point). Normally the in-scope set is where we
+keep the latest information, but the in-scope set tracks only OutVars; if a
+binding is unconditionally inlined, it never makes it into the in-scope set,
+and we need to know at the occurrence site that the variable is a join point so
+that we know to drop the context. Thus we remember which join points we're
+substituting. Clumsily, finding whether an InVar is a join variable may require
+looking in both the substitution *and* the in-scope set (see
+'isJoinIdInEnv_maybe').
-}
mkSimplEnv :: SimplifierMode -> SimplEnv
@@ -199,6 +228,7 @@ mkSimplEnv mode
= SimplEnv { seMode = mode
, seInScope = init_in_scope
, seFloats = emptyFloats
+ , seJoinFloats = emptyJoinFloats
, seTvSubst = emptyVarEnv
, seCvSubst = emptyVarEnv
, seIdSubst = emptyVarEnv }
@@ -241,7 +271,7 @@ updMode upd env = env { seMode = upd (seMode env) }
extendIdSubst :: SimplEnv -> Id -> SimplSR -> SimplEnv
extendIdSubst env@(SimplEnv {seIdSubst = subst}) var res
= ASSERT2( isId var && not (isCoVar var), ppr var )
- env {seIdSubst = extendVarEnv subst var res}
+ env { seIdSubst = extendVarEnv subst var (isJoinId_maybe var, res) }
extendTvSubst :: SimplEnv -> TyVar -> Type -> SimplEnv
extendTvSubst env@(SimplEnv {seTvSubst = tsubst}) var res
@@ -264,13 +294,22 @@ setInScope :: SimplEnv -> SimplEnv -> SimplEnv
-- Set the in-scope set, and *zap* the floats
setInScope env env_with_scope
= env { seInScope = seInScope env_with_scope,
- seFloats = emptyFloats }
+ seFloats = emptyFloats,
+ seJoinFloats = emptyJoinFloats }
setFloats :: SimplEnv -> SimplEnv -> SimplEnv
-- Set the in-scope set *and* the floats
setFloats env env_with_floats
= env { seInScope = seInScope env_with_floats,
- seFloats = seFloats env_with_floats }
+ seFloats = seFloats env_with_floats,
+ seJoinFloats = seJoinFloats env_with_floats }
+
+restoreJoinFloats :: SimplEnv -> SimplEnv -> SimplEnv
+-- Put back floats previously zapped
+-- Unlike 'setFloats', does *not* update the in-scope set, since the right-hand
+-- env is assumed to be *older*
+restoreJoinFloats env old_env
+ = env { seJoinFloats = seJoinFloats old_env }
addNewInScopeIds :: SimplEnv -> [CoreBndr] -> SimplEnv
-- The new Ids are guaranteed to be freshly allocated
@@ -331,6 +370,8 @@ Can't happen:
data Floats = Floats (OrdList OutBind) FloatFlag
-- See Note [Simplifier floats]
+type JoinFloats = OrdList OutBind
+
data FloatFlag
= FltLifted -- All bindings are lifted and lazy *or*
-- consist of a single primitive string literal
@@ -389,9 +430,13 @@ so we must take the 'or' of the two.
emptyFloats :: Floats
emptyFloats = Floats nilOL FltLifted
+emptyJoinFloats :: JoinFloats
+emptyJoinFloats = nilOL
+
unitFloat :: OutBind -> Floats
-- This key function constructs a singleton float with the right form
-unitFloat bind = Floats (unitOL bind) (flag bind)
+unitFloat bind = ASSERT(all (not . isJoinId) (bindersOf bind))
+ Floats (unitOL bind) (flag bind)
where
flag (Rec {}) = FltLifted
flag (NonRec bndr rhs)
@@ -404,6 +449,10 @@ unitFloat bind = Floats (unitOL bind) (flag bind)
FltCareful
-- Unlifted binders can only be let-bound if exprOkForSpeculation holds
+unitJoinFloat :: OutBind -> JoinFloats
+unitJoinFloat bind = ASSERT(all isJoinId (bindersOf bind))
+ unitOL bind
+
addNonRec :: SimplEnv -> OutId -> OutExpr -> SimplEnv
-- Add a non-recursive binding and extend the in-scope set
-- The latter is important; the binder may already be in the
@@ -412,58 +461,104 @@ addNonRec :: SimplEnv -> OutId -> OutExpr -> SimplEnv
addNonRec env id rhs
= id `seq` -- This seq forces the Id, and hence its IdInfo,
-- and hence any inner substitutions
- env { seFloats = seFloats env `addFlts` unitFloat (NonRec id rhs),
+ env { seFloats = floats',
+ seJoinFloats = jfloats',
seInScope = extendInScopeSet (seInScope env) id }
+ where
+ bind = NonRec id rhs
+
+ floats' | isJoinId id = seFloats env
+ | otherwise = seFloats env `addFlts` unitFloat bind
+ jfloats' | isJoinId id = seJoinFloats env `addJoinFlts` unitJoinFloat bind
+ | otherwise = seJoinFloats env
extendFloats :: SimplEnv -> OutBind -> SimplEnv
-- Add these bindings to the floats, and extend the in-scope env too
extendFloats env bind
- = env { seFloats = seFloats env `addFlts` unitFloat bind,
+ = ASSERT(all (not . isJoinId) (bindersOf bind))
+ env { seFloats = floats',
+ seJoinFloats = jfloats',
seInScope = extendInScopeSetList (seInScope env) bndrs }
where
bndrs = bindersOf bind
+ floats' | isJoinBind bind = seFloats env
+ | otherwise = seFloats env `addFlts` unitFloat bind
+ jfloats' | isJoinBind bind = seJoinFloats env `addJoinFlts`
+ unitJoinFloat bind
+ | otherwise = seJoinFloats env
+
addFloats :: SimplEnv -> SimplEnv -> SimplEnv
-- Add the floats for env2 to env1;
-- *plus* the in-scope set for env2, which is bigger
-- than that for env1
addFloats env1 env2
= env1 {seFloats = seFloats env1 `addFlts` seFloats env2,
+ seJoinFloats = seJoinFloats env1 `addJoinFlts` seJoinFloats env2,
seInScope = seInScope env2 }
addFlts :: Floats -> Floats -> Floats
addFlts (Floats bs1 l1) (Floats bs2 l2)
= Floats (bs1 `appOL` bs2) (l1 `andFF` l2)
+addJoinFlts :: JoinFloats -> JoinFloats -> JoinFloats
+addJoinFlts = appOL
+
zapFloats :: SimplEnv -> SimplEnv
-zapFloats env = env { seFloats = emptyFloats }
+zapFloats env = env { seFloats = emptyFloats
+ , seJoinFloats = emptyJoinFloats }
+
+zapJoinFloats :: SimplEnv -> SimplEnv
+zapJoinFloats env = env { seJoinFloats = emptyJoinFloats }
addRecFloats :: SimplEnv -> SimplEnv -> SimplEnv
-- Flattens the floats from env2 into a single Rec group,
-- prepends the floats from env1, and puts the result back in env2
-- This is all very specific to the way recursive bindings are
-- handled; see Simplify.simplRecBind
-addRecFloats env1 env2@(SimplEnv {seFloats = Floats bs ff})
+addRecFloats env1 env2@(SimplEnv {seFloats = Floats bs ff
+ ,seJoinFloats = jbs })
= ASSERT2( case ff of { FltLifted -> True; _ -> False }, ppr (fromOL bs) )
- env2 {seFloats = seFloats env1 `addFlts` unitFloat (Rec (flattenBinds (fromOL bs)))}
+ env2 {seFloats = seFloats env1 `addFlts` floats'
+ ,seJoinFloats = seJoinFloats env1 `addJoinFlts` jfloats'}
+ where
+ floats' | isNilOL bs = emptyFloats
+ | otherwise = unitFloat (Rec (flattenBinds (fromOL bs)))
+ jfloats' | isNilOL jbs = emptyJoinFloats
+ | otherwise = unitJoinFloat (Rec (flattenBinds (fromOL jbs)))
wrapFloats :: SimplEnv -> OutExpr -> OutExpr
-- Wrap the floats around the expression; they should all
-- satisfy the let/app invariant, so mkLets should do the job just fine
-wrapFloats (SimplEnv {seFloats = Floats bs _}) body
- = foldrOL Let body bs
+wrapFloats env@(SimplEnv {seFloats = Floats bs _}) body
+ = foldrOL Let (wrapJoinFloats env body) bs
+ -- Note: Always safe to put the joins on the inside since the values
+ -- can't refer to them
+
+wrapJoinFloats :: SimplEnv -> OutExpr -> OutExpr
+wrapJoinFloats (SimplEnv {seJoinFloats = jbs}) body
+ = foldrOL Let body jbs
getFloatBinds :: SimplEnv -> [CoreBind]
-getFloatBinds (SimplEnv {seFloats = Floats bs _})
- = fromOL bs
+getFloatBinds env@(SimplEnv {seFloats = Floats bs _})
+ = fromOL bs ++ getJoinFloatBinds env
+
+getJoinFloatBinds :: SimplEnv -> [CoreBind]
+getJoinFloatBinds (SimplEnv {seJoinFloats = jbs})
+ = fromOL jbs
isEmptyFloats :: SimplEnv -> Bool
-isEmptyFloats (SimplEnv {seFloats = Floats bs _})
- = isNilOL bs
+isEmptyFloats env@(SimplEnv {seFloats = Floats bs _})
+ = isNilOL bs && isEmptyJoinFloats env
+
+isEmptyJoinFloats :: SimplEnv -> Bool
+isEmptyJoinFloats (SimplEnv {seJoinFloats = jbs})
+ = isNilOL jbs
mapFloats :: SimplEnv -> ((Id,CoreExpr) -> (Id,CoreExpr)) -> SimplEnv
-mapFloats env@SimplEnv { seFloats = Floats fs ff } fun
- = env { seFloats = Floats (mapOL app fs) ff }
+mapFloats env@SimplEnv { seFloats = Floats fs ff, seJoinFloats = jfs } fun
+ = env { seFloats = Floats (mapOL app fs) ff
+ , seJoinFloats = mapOL app jfs }
where
app (NonRec b e) = case fun (b,e) of (b',e') -> NonRec b' e'
app (Rec bs) = Rec (map fun bs)
@@ -490,7 +585,7 @@ find that it has been substituted by b. (Or conceivably cloned.)
substId :: SimplEnv -> InId -> SimplSR
-- Returns DoneEx only on a non-Var expression
substId (SimplEnv { seInScope = in_scope, seIdSubst = ids }) v
- = case lookupVarEnv ids v of -- Note [Global Ids in the substitution]
+ = case snd <$> lookupVarEnv ids v of -- Note [Global Ids in the substitution]
Nothing -> DoneId (refineFromInScope in_scope v)
Just (DoneId v) -> DoneId (refineFromInScope in_scope v)
Just (DoneEx (Var v)) -> DoneId (refineFromInScope in_scope v)
@@ -499,6 +594,15 @@ substId (SimplEnv { seInScope = in_scope, seIdSubst = ids }) v
-- Get the most up-to-date thing from the in-scope set
-- Even though it isn't in the substitution, it may be in
-- the in-scope set with better IdInfo
+
+isJoinIdInEnv_maybe :: SimplEnv -> InId -> Maybe JoinArity
+isJoinIdInEnv_maybe (SimplEnv { seInScope = inScope, seIdSubst = ids }) v
+ | not (isLocalId v) = Nothing
+ | Just (m_ar, _) <- lookupVarEnv ids v = m_ar
+ | Just v' <- lookupInScope inScope v = isJoinId_maybe v'
+ | otherwise = WARN( True , ppr v )
+ isJoinId_maybe v
+
refineFromInScope :: InScopeSet -> Var -> Var
refineFromInScope in_scope v
| isLocalId v = case lookupInScope in_scope v of
@@ -511,7 +615,7 @@ lookupRecBndr :: SimplEnv -> InId -> OutId
-- but where we have not yet done its RHS
lookupRecBndr (SimplEnv { seInScope = in_scope, seIdSubst = ids }) v
= case lookupVarEnv ids v of
- Just (DoneId v) -> v
+ Just (_, DoneId v) -> v
Just _ -> pprPanic "lookupRecBndr" (ppr v)
Nothing -> refineFromInScope in_scope v
@@ -539,33 +643,53 @@ simplBinder :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
simplBinder env bndr
| isTyVar bndr = do { let (env', tv) = substTyVarBndr env bndr
; seqTyVar tv `seq` return (env', tv) }
- | otherwise = do { let (env', id) = substIdBndr env bndr
+ | otherwise = do { let (env', id) = substIdBndr Nothing env bndr
; seqId id `seq` return (env', id) }
---------------
simplNonRecBndr :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
-- A non-recursive let binder
simplNonRecBndr env id
- = do { let (env1, id1) = substIdBndr env id
+ = do { let (env1, id1) = substIdBndr Nothing env id
+ ; seqId id1 `seq` return (env1, id1) }
+
+---------------
+simplNonRecJoinBndr :: SimplEnv -> OutType -> InBndr
+ -> SimplM (SimplEnv, OutBndr)
+-- A non-recursive let binder for a join point; context being pushed inward may
+-- change the type
+simplNonRecJoinBndr env res_ty id
+ = do { let (env1, id1) = substIdBndr (Just res_ty) env id
; seqId id1 `seq` return (env1, id1) }
---------------
simplRecBndrs :: SimplEnv -> [InBndr] -> SimplM SimplEnv
-- Recursive let binders
simplRecBndrs env@(SimplEnv {}) ids
- = do { let (env1, ids1) = mapAccumL substIdBndr env ids
+ = ASSERT(all (not . isJoinId) ids)
+ do { let (env1, ids1) = mapAccumL (substIdBndr Nothing) env ids
+ ; seqIds ids1 `seq` return env1 }
+
+---------------
+simplRecJoinBndrs :: SimplEnv -> OutType -> [InBndr] -> SimplM SimplEnv
+-- Recursive let binders for join points; context being pushed inward may
+-- change types
+simplRecJoinBndrs env@(SimplEnv {}) res_ty ids
+ = ASSERT(all isJoinId ids)
+ do { let (env1, ids1) = mapAccumL (substIdBndr (Just res_ty)) env ids
; seqIds ids1 `seq` return env1 }
---------------
-substIdBndr :: SimplEnv -> InBndr -> (SimplEnv, OutBndr)
+substIdBndr :: Maybe OutType -> SimplEnv -> InBndr -> (SimplEnv, OutBndr)
-- Might be a coercion variable
-substIdBndr env bndr
+substIdBndr new_res_ty env bndr
| isCoVar bndr = substCoVarBndr env bndr
- | otherwise = substNonCoVarIdBndr env bndr
+ | otherwise = substNonCoVarIdBndr new_res_ty env bndr
---------------
substNonCoVarIdBndr
- :: SimplEnv
+ :: Maybe OutType -- New result type, if a join binder
+ -> SimplEnv
-> InBndr -- Env and binder to transform
-> (SimplEnv, OutBndr)
-- Clone Id if necessary, substitute its type
@@ -585,7 +709,9 @@ substNonCoVarIdBndr
-- Similar to CoreSubst.substIdBndr, except that
-- the type of id_subst differs
-- all fragile info is zapped
-substNonCoVarIdBndr env@(SimplEnv { seInScope = in_scope, seIdSubst = id_subst })
+substNonCoVarIdBndr new_res_ty
+ env@(SimplEnv { seInScope = in_scope
+ , seIdSubst = id_subst })
old_id
= ASSERT2( not (isCoVar old_id), ppr old_id )
(env { seInScope = in_scope `extendInScopeSet` new_id,
@@ -593,14 +719,19 @@ substNonCoVarIdBndr env@(SimplEnv { seInScope = in_scope, seIdSubst = id_subst }
where
id1 = uniqAway in_scope old_id
id2 = substIdType env id1
- new_id = zapFragileIdInfo id2 -- Zaps rules, worker-info, unfolding
+ id3 | Just res_ty <- new_res_ty
+ = id2 `setIdType` setJoinResTy (idJoinArity id2) res_ty (idType id2)
+ | otherwise
+ = id2
+ new_id = zapFragileIdInfo id3 -- Zaps rules, worker-info, unfolding
-- and fragile OccInfo
-- Extend the substitution if the unique has changed,
-- or there's some useful occurrence information
-- See the notes with substTyVarBndr for the delSubstEnv
new_subst | new_id /= old_id
- = extendVarEnv id_subst old_id (DoneId new_id)
+ = extendVarEnv id_subst old_id
+ (isJoinId_maybe new_id, DoneId new_id)
| otherwise
= delVarEnv id_subst old_id
@@ -664,7 +795,8 @@ the letrec.
-}
getTCvSubst :: SimplEnv -> TCvSubst
-getTCvSubst (SimplEnv { seInScope = in_scope, seTvSubst = tv_env, seCvSubst = cv_env })
+getTCvSubst (SimplEnv { seInScope = in_scope, seTvSubst = tv_env
+ , seCvSubst = cv_env })
= mkTCvSubst in_scope (tv_env, cv_env)
substTy :: SimplEnv -> Type -> Type
diff --git a/compiler/simplCore/SimplUtils.hs b/compiler/simplCore/SimplUtils.hs
index 3b48924ed1..2e985c5713 100644
--- a/compiler/simplCore/SimplUtils.hs
+++ b/compiler/simplCore/SimplUtils.hs
@@ -19,7 +19,7 @@ module SimplUtils (
-- The continuation type
SimplCont(..), DupFlag(..),
isSimplified,
- contIsDupable, contResultType, contHoleType,
+ contIsDupable, contResultType, contHoleType, applyContToJoinType,
contIsTrivial, contArgs,
countArgs,
mkBoringStop, mkRhsStop, mkLazyArgStop, contIsRhsOrArg,
@@ -47,6 +47,7 @@ import CoreArity
import CoreUnfold
import Name
import Id
+import IdInfo
import Var
import Demand
import SimplMonad
@@ -361,6 +362,10 @@ contHoleType (ApplyToVal { sc_arg = e, sc_env = se, sc_dup = dup, sc_cont = k })
contHoleType (Select { sc_dup = d, sc_bndr = b, sc_env = se })
= perhapsSubstTy d se (idType b)
+applyContToJoinType :: JoinArity -> SimplCont -> OutType -> OutType
+applyContToJoinType ar cont ty
+ = setJoinResTy ar (contResultType cont) ty
+
-------------------
countArgs :: SimplCont -> Int
-- Count all arguments, including types, coercions, and other values
@@ -629,7 +634,7 @@ interestingArg env e = go env 0 e
-- n is # value args to which the expression is applied
go env n (Var v)
| SimplEnv { seIdSubst = ids, seInScope = in_scope } <- env
- = case lookupVarEnv ids v of
+ = case snd <$> lookupVarEnv ids v of
Nothing -> go_var n (refineFromInScope in_scope v)
Just (DoneId v') -> go_var n (refineFromInScope in_scope v')
Just (DoneEx e) -> go (zapSubstEnv env) n e
@@ -1054,7 +1059,9 @@ preInlineUnconditionally dflags env top_lvl bndr rhs
| isCoVar bndr = False -- Note [Do not inline CoVars unconditionally]
| otherwise = case idOccInfo bndr of
IAmDead -> True -- Happens in ((\x.1) v)
- OneOcc in_lam True int_cxt -> try_once in_lam int_cxt
+ occ@OneOcc { occ_one_br = True }
+ -> try_once (occ_in_lam occ)
+ (occ_int_cxt occ)
_ -> False
where
mode = getMode env
@@ -1180,7 +1187,8 @@ postInlineUnconditionally dflags env top_lvl bndr occ_info rhs unfolding
-- False -> case x of ...
-- This is very important in practice; e.g. wheel-seive1 doubles
-- in allocation if you miss this out
- OneOcc in_lam _one_br int_cxt -- OneOcc => no code-duplication issue
+ OneOcc { occ_in_lam = in_lam, occ_int_cxt = int_cxt }
+ -- OneOcc => no code-duplication issue
-> smallEnoughToInline dflags unfolding -- Small enough to dup
-- ToDo: consider discount on smallEnoughToInline if int_cxt is true
--
@@ -1398,9 +1406,10 @@ because the latter is not well-kinded.
************************************************************************
-}
-tryEtaExpandRhs :: SimplEnv -> OutId -> OutExpr -> SimplM (Arity, OutExpr)
+tryEtaExpandRhs :: SimplEnv -> RecFlag -> OutId -> OutExpr
+ -> SimplM (Arity, OutExpr)
-- See Note [Eta-expanding at let bindings]
-tryEtaExpandRhs env bndr rhs
+tryEtaExpandRhs env is_rec bndr rhs
= do { dflags <- getDynFlags
; (new_arity, new_rhs) <- try_expand dflags
@@ -1419,8 +1428,12 @@ tryEtaExpandRhs env bndr rhs
new_arity2 = idCallArity bndr
new_arity = max new_arity1 new_arity2
, new_arity > old_arity -- And the current manifest arity isn't enough
- = do { tick (EtaExpansion bndr)
- ; return (new_arity, etaExpand new_arity rhs) }
+ = if is_rec == Recursive && isJoinId bndr
+ then WARN(True, text "Can't eta-expand recursive join point:" <+>
+ ppr bndr)
+ return (old_arity, rhs)
+ else do { tick (EtaExpansion bndr)
+ ; return (new_arity, etaExpand new_arity rhs) }
| otherwise
= return (old_arity, rhs)
diff --git a/compiler/simplCore/Simplify.hs b/compiler/simplCore/Simplify.hs
index c1f2a9f705..7c6f8757cc 100644
--- a/compiler/simplCore/Simplify.hs
+++ b/compiler/simplCore/Simplify.hs
@@ -18,7 +18,7 @@ import SimplUtils
import FamInstEnv ( FamInstEnv )
import Literal ( litIsLifted ) --, mkMachInt ) -- temporalily commented out. See #8326
import Id
-import MkId ( seqId, voidPrimId )
+import MkId ( seqId )
import MkCore ( mkImpossibleExpr, castBottomExpr )
import IdInfo
import Name ( Name, mkSystemVarName, isExternalName, getOccFS )
@@ -37,10 +37,11 @@ import CoreArity
import CoreSubst ( pushCoTyArg, pushCoValArg )
--import PrimOp ( tagToEnumKey ) -- temporalily commented out. See #8326
import Rules ( mkRuleInfo, lookupRule, getRules )
-import TysPrim ( voidPrimTy ) --, intPrimTy ) -- temporalily commented out. See #8326
-import BasicTypes ( TopLevelFlag(..), isTopLevel, RecFlag(..) )
+--import TysPrim ( intPrimTy ) -- temporalily commented out. See #8326
+import BasicTypes ( TopLevelFlag(..), isNotTopLevel, isTopLevel,
+ RecFlag(..) )
import MonadUtils ( foldlM, mapAccumLM, liftIO )
-import Maybes ( orElse )
+import Maybes ( isJust, fromJust, orElse )
--import Unique ( hasKey ) -- temporalily commented out. See #8326
import Control.Monad
import Outputable
@@ -203,6 +204,35 @@ we should eta expand wherever we find a (value) lambda? Then the eta
expansion at a let RHS can concentrate solely on the PAP case.
+Case-of-case and join points
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+When we perform the case-of-case transform (or otherwise push continuations
+inward), we want to treat join points specially. Since they're always
+tail-called and we want to maintain this invariant, we can do this (for any
+evaluation context E):
+
+ E[join j = e
+ in case ... of
+ A -> jump j 1
+ B -> jump j 2
+ C -> f 3]
+
+ -->
+
+ join j = E[e]
+ in case ... of
+ A -> jump j 1
+ B -> jump j 2
+ C -> E[f 3]
+
+As is evident from the example, there are two components to this behavior:
+
+ 1. When entering the RHS of a join point, copy the context inside.
+ 2. When a join point is invoked, discard the outer context.
+
+Clearly we need to be very careful here to remain consistent---neither part is
+optional!
+
************************************************************************
* *
\subsection{Bindings}
@@ -232,9 +262,11 @@ simplTopBinds env0 binds0
simpl_binds env (bind:binds) = do { env' <- simpl_bind env bind
; simpl_binds env' binds }
- simpl_bind env (Rec pairs) = simplRecBind env TopLevel pairs
+ simpl_bind env (Rec pairs) = simplRecBind env TopLevel Nothing pairs
simpl_bind env (NonRec b r) = do { (env', b') <- addBndrRules env b (lookupRecBndr env b)
- ; simplRecOrTopPair env' TopLevel NonRecursive b b' r }
+ ; simplRecOrTopPair env' TopLevel
+ NonRecursive Nothing
+ b b' r }
{-
************************************************************************
@@ -247,10 +279,10 @@ simplRecBind is used for
* recursive bindings only
-}
-simplRecBind :: SimplEnv -> TopLevelFlag
+simplRecBind :: SimplEnv -> TopLevelFlag -> Maybe SimplCont
-> [(InId, InExpr)]
-> SimplM SimplEnv
-simplRecBind env0 top_lvl pairs0
+simplRecBind env0 top_lvl mb_cont pairs0
= do { (env_with_info, triples) <- mapAccumLM add_rules env0 pairs0
; env1 <- go (zapFloats env_with_info) triples
; return (env0 `addRecFloats` env1) }
@@ -266,7 +298,8 @@ simplRecBind env0 top_lvl pairs0
go env [] = return env
go env ((old_bndr, new_bndr, rhs) : pairs)
- = do { env' <- simplRecOrTopPair env top_lvl Recursive old_bndr new_bndr rhs
+ = do { env' <- simplRecOrTopPair env top_lvl Recursive mb_cont
+ old_bndr new_bndr rhs
; go env' pairs }
{-
@@ -278,18 +311,18 @@ It assumes the binder has already been simplified, but not its IdInfo.
-}
simplRecOrTopPair :: SimplEnv
- -> TopLevelFlag -> RecFlag
+ -> TopLevelFlag -> RecFlag -> Maybe SimplCont
-> InId -> OutBndr -> InExpr -- Binder and rhs
-> SimplM SimplEnv -- Returns an env that includes the binding
-simplRecOrTopPair env top_lvl is_rec old_bndr new_bndr rhs
+simplRecOrTopPair env top_lvl is_rec mb_cont old_bndr new_bndr rhs
= do { dflags <- getDynFlags
; trace_bind dflags $
if preInlineUnconditionally dflags env top_lvl old_bndr rhs
-- Check for unconditional inline
then do tick (PreInlineUnconditionally old_bndr)
return (extendIdSubst env old_bndr (mkContEx env rhs))
- else simplLazyBind env top_lvl is_rec old_bndr new_bndr rhs env }
+ else simplBind env top_lvl is_rec mb_cont old_bndr new_bndr rhs env }
where
trace_bind dflags thing_inside
| not (dopt Opt_D_verbose_core2core dflags)
@@ -300,7 +333,7 @@ simplRecOrTopPair env top_lvl is_rec old_bndr new_bndr rhs
-- helps to locate the tracing for inlining and rule firing
{-
-simplLazyBind is used for
+simplBind is used for
* [simplRecOrTopPair] recursive bindings (whether top level or not)
* [simplRecOrTopPair] top-level non-recursive bindings
* [simplNonRecE] non-top-level *lazy* non-recursive bindings
@@ -315,6 +348,19 @@ Nota bene:
that should have been done already.
-}
+simplBind :: SimplEnv
+ -> TopLevelFlag -> RecFlag -> Maybe SimplCont
+ -> InId -> OutId -- Binder, both pre-and post simpl
+ -- The OutId has IdInfo, except arity, unfolding
+ -> InExpr -> SimplEnv -- The RHS and its environment
+ -> SimplM SimplEnv
+simplBind env top_lvl is_rec mb_cont bndr bndr1 rhs rhs_se
+ | isJoinId bndr1
+ = ASSERT(isNotTopLevel top_lvl && isJust mb_cont)
+ simplJoinBind env is_rec (fromJust mb_cont) bndr bndr1 rhs rhs_se
+ | otherwise
+ = simplLazyBind env top_lvl is_rec bndr bndr1 rhs rhs_se
+
simplLazyBind :: SimplEnv
-> TopLevelFlag -> RecFlag
-> InId -> OutId -- Binder, both pre-and post simpl
@@ -346,7 +392,10 @@ simplLazyBind env top_lvl is_rec bndr bndr1 rhs rhs_se
-- Simplify the RHS
; let rhs_cont = mkRhsStop (substTy body_env (exprType body))
- ; (body_env1, body1) <- simplExprF body_env body rhs_cont
+ ; (body_env0, body0) <- simplExprF (zapJoinFloats body_env)
+ body rhs_cont
+ ; let body1 = wrapJoinFloats body_env0 body0
+ body_env1 = body_env0 `restoreJoinFloats` body_env
-- ANF-ise a constructor or PAP rhs
; (body_env2, body2) <- prepareRhs top_lvl body_env1 bndr1 body1
@@ -367,7 +416,24 @@ simplLazyBind env top_lvl is_rec bndr bndr1 rhs rhs_se
; env' <- foldlM (addPolyBind top_lvl) env poly_binds
; return (env', rhs') }
- ; completeBind env' top_lvl bndr bndr1 rhs' }
+ ; completeBind env' top_lvl is_rec Nothing bndr bndr1 rhs' }
+
+simplJoinBind :: SimplEnv
+ -> RecFlag
+ -> SimplCont
+ -> InId -> OutId -- Binder, both pre-and post simpl
+ -- The OutId has IdInfo, except arity,
+ -- unfolding
+ -> InExpr -> SimplEnv -- The RHS and its environment
+ -> SimplM SimplEnv
+simplJoinBind env is_rec cont bndr bndr1 rhs rhs_se
+ = -- pprTrace "simplLazyBind" ((ppr bndr <+> ppr bndr1) $$
+ -- ppr rhs $$ ppr (seIdSubst rhs_se)) $
+ do { let rhs_env = rhs_se `setInScope` env
+
+ -- Simplify the RHS
+ ; rhs' <- simplJoinRhs rhs_env cont bndr rhs
+ ; completeBind env NotTopLevel is_rec (Just cont) bndr bndr1 rhs' }
{-
A specialised variant of simplNonRec used when the RHS is already simplified,
@@ -402,13 +468,15 @@ completeNonRecX :: TopLevelFlag -> SimplEnv
-- See Note [CoreSyn let/app invariant] in CoreSyn
completeNonRecX top_lvl env is_strict old_bndr new_bndr new_rhs
- = do { (env1, rhs1) <- prepareRhs top_lvl (zapFloats env) new_bndr new_rhs
+ = ASSERT(not (isJoinId new_bndr))
+ do { (env1, rhs1) <- prepareRhs top_lvl (zapFloats env) new_bndr new_rhs
; (env2, rhs2) <-
if doFloatFromRhs NotTopLevel NonRecursive is_strict rhs1 env1
then do { tick LetFloatFromLet
; return (addFloats env env1, rhs1) } -- Add the floats to the main env
else return (env, wrapFloats env1 rhs1) -- Wrap the floats around the RHS
- ; completeBind env2 NotTopLevel old_bndr new_bndr rhs2 }
+ ; completeBind env2 NotTopLevel NonRecursive Nothing
+ old_bndr new_bndr rhs2 }
{-
{- No, no, no! Do not try preInlineUnconditionally in completeNonRecX
@@ -664,6 +732,8 @@ Nor does it do the atomic-argument thing
completeBind :: SimplEnv
-> TopLevelFlag -- Flag stuck into unfolding
+ -> RecFlag -- Recursive binding?
+ -> Maybe SimplCont -- Required only for join point
-> InId -- Old binder
-> OutId -> OutExpr -- New binder and RHS
-> SimplM SimplEnv
@@ -672,7 +742,7 @@ completeBind :: SimplEnv
-- * or by adding to the floats in the envt
--
-- Precondition: rhs obeys the let/app invariant
-completeBind env top_lvl old_bndr new_bndr new_rhs
+completeBind env top_lvl is_rec mb_cont old_bndr new_bndr new_rhs
| isCoVar old_bndr
= case new_rhs of
Coercion co -> return (extendCvSubst env old_bndr co)
@@ -686,10 +756,15 @@ completeBind env top_lvl old_bndr new_bndr new_rhs
-- Do eta-expansion on the RHS of the binding
-- See Note [Eta-expanding at let bindings] in SimplUtils
- ; (new_arity, final_rhs) <- tryEtaExpandRhs env new_bndr new_rhs
+ ; (new_arity, final_rhs) <- if isJoinId new_bndr
+ then return (manifestArity new_rhs, new_rhs)
+ -- Note [Don't eta-expand join points]
+ else tryEtaExpandRhs env is_rec
+ new_bndr new_rhs
-- Simplify the unfolding
- ; new_unfolding <- simplLetUnfolding env top_lvl old_bndr final_rhs old_unf
+ ; new_unfolding <- simplLetUnfolding env top_lvl mb_cont old_bndr
+ final_rhs old_unf
; dflags <- getDynFlags
; if postInlineUnconditionally dflags env top_lvl new_bndr occ_info
@@ -740,7 +815,8 @@ addPolyBind :: TopLevelFlag -> SimplEnv -> OutBind -> SimplM SimplEnv
-- INVARIANT: the arity is correct on the incoming binders
addPolyBind top_lvl env (NonRec poly_id rhs)
- = do { unfolding <- simplLetUnfolding env top_lvl poly_id rhs noUnfolding
+ = do { unfolding <- simplLetUnfolding env top_lvl Nothing poly_id rhs
+ noUnfolding
-- Assumes that poly_id did not have an INLINE prag
-- which is perhaps wrong. ToDo: think about this
; let final_id = setIdInfo poly_id $
@@ -793,6 +869,44 @@ After inlining f at some of its call sites the original binding may
(for example) be no longer strictly demanded.
The solution here is a bit ad hoc...
+Note [Don't eta-expand join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Similarly to CPR (see Note [Don't CPR join points] in WorkWrap), a join point
+stands well to gain from its outer binding's eta-expansion, and eta-expanding a
+join point is fraught with issues like how to deal with a cast:
+
+ let join $j1 :: IO ()
+ $j1 = ...
+ $j2 :: Int -> IO ()
+ $j2 n = if n > 0 then $j1
+ else ...
+
+ =>
+
+ let join $j1 :: IO ()
+ $j1 = (\eta -> ...)
+ `cast` N:IO :: State# RealWorld -> (# State# RealWorld, ())
+ ~ IO ()
+ $j2 :: Int -> IO ()
+ $j2 n = (\eta -> if n > 0 then $j1
+ else ...)
+ `cast` N:IO :: State# RealWorld -> (# State# RealWorld, ())
+ ~ IO ()
+
+The cast here can't be pushed inside the lambda (since it's not casting to a
+function type), so the lambda has to stay, but it can't because it contains a
+reference to a join point. In fact, $j2 can't be eta-expanded at all. Rather
+than try and detect this situation (and whatever other situations crop up!), we
+don't bother; again, any surrounding eta-expansion will improve these join
+points anyway, since an outer cast can *always* be pushed inside. By the time
+CorePrep comes around, the code is very likely to look more like this:
+
+ let join $j1 :: State# RealWorld -> (# State# RealWorld, ())
+ $j1 = (...) eta
+ $j2 :: Int -> State# RealWorld -> (# State# RealWorld, ())
+ $j2 = if n > 0 then $j1
+ else (...) eta
************************************************************************
* *
@@ -917,17 +1031,33 @@ simplExprF1 env (Case scrut bndr _ alts) cont
, sc_env = env, sc_cont = cont })
simplExprF1 env (Let (Rec pairs) body) cont
- = do { env' <- simplRecBndrs env (map fst pairs)
- -- NB: bndrs' don't have unfoldings or rules
- -- We add them as we go down
-
- ; env'' <- simplRecBind env' NotTopLevel pairs
- ; simplExprF env'' body cont }
+ = simplRecE env pairs body cont
simplExprF1 env (Let (NonRec bndr rhs) body) cont
= simplNonRecE env bndr (rhs, env) ([], body) cont
---------------------------------
+-- Simplify a join point, adding the context.
+-- Context goes *inside* the lambdas. IOW, if the join point has arity n, we do:
+-- \x1 .. xn -> e => \x1 .. xn -> E[e]
+-- Note that we need the arity of the join point, since e may be a lambda
+-- (though this is unlikely). See Note [Case-of-case and join points].
+simplJoinRhs :: SimplEnv -> SimplCont -> InId -> InExpr
+ -> SimplM OutExpr
+simplJoinRhs env cont bndr expr
+ | Just arity <- isJoinId_maybe bndr
+ = simpl_join_lams arity
+ | otherwise
+ = pprPanic "simplJoinRhs" (ppr bndr)
+ where
+ simpl_join_lams arity
+ = do { (env', join_bndrs') <- simplLamBndrs env join_bndrs
+ ; join_body' <- simplExprC env' join_body cont
+ ; return $ mkLams join_bndrs' join_body' }
+ where
+ (join_bndrs, join_body) = collectNBinders arity expr
+
+---------------------------------
simplType :: SimplEnv -> InType -> SimplM OutType
-- Kept monadic just so we can do the seqType
simplType env ty
@@ -1270,7 +1400,7 @@ simplLamBndr :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
simplLamBndr env bndr
| isId bndr && isFragileUnfolding old_unf -- Special case
= do { (env1, bndr1) <- simplBinder env bndr
- ; unf' <- simplUnfolding env1 NotTopLevel bndr old_unf
+ ; unf' <- simplUnfolding env1 NotTopLevel Nothing bndr old_unf
; let bndr2 = bndr1 `setIdUnfolding` unf'
; return (modifyInScope env1 bndr2, bndr2) }
@@ -1322,6 +1452,25 @@ simplNonRecE env bndr (rhs, rhs_se) (bndrs, body) cont
-> simplExprF (rhs_se `setFloats` env) rhs
(StrictBind bndr bndrs body env cont)
+ | Just (bndr', rhs') <- matchOrConvertToJoinPoint bndr rhs
+ -> do { let cont_dup_res_ty = resultTypeOfDupableCont (getMode env)
+ [bndr'] cont
+ ; (env1, bndr1) <- simplNonRecJoinBndr env
+ cont_dup_res_ty bndr'
+ ; (env2, bndr2) <- addBndrRules env1 bndr' bndr1
+ ; (env3, cont_dup, cont_nodup)
+ <- prepareLetCont (zapJoinFloats env2) [bndr'] cont
+ ; MASSERT2(cont_dup_res_ty `eqType` contResultType cont_dup,
+ ppr cont_dup_res_ty $$ blankLine $$
+ ppr cont $$ blankLine $$
+ ppr cont_dup $$ blankLine $$
+ ppr cont_nodup)
+ ; env4 <- simplJoinBind env3 NonRecursive cont_dup bndr' bndr2
+ rhs' rhs_se
+ ; (env5, expr) <- simplLam env4 bndrs body cont_dup
+ ; rebuild (env5 `restoreJoinFloats` env2)
+ (wrapJoinFloats env5 expr) cont_nodup }
+
| otherwise
-> ASSERT( not (isTyVar bndr) )
do { (env1, bndr1) <- simplNonRecBndr env bndr
@@ -1329,6 +1478,64 @@ simplNonRecE env bndr (rhs, rhs_se) (bndrs, body) cont
; env3 <- simplLazyBind env2 NotTopLevel NonRecursive bndr bndr2 rhs rhs_se
; simplLam env3 bndrs body cont }
+------------------
+simplRecE :: SimplEnv
+ -> [(InId, InExpr)]
+ -> InExpr
+ -> SimplCont
+ -> SimplM (SimplEnv, OutExpr)
+
+-- simplRecE is used for
+-- * non-top-level recursive lets in expressions
+simplRecE env pairs body cont
+ | Just pairs' <- matchOrConvertToJoinPoints pairs
+ = do { let bndrs' = map fst pairs'
+ cont_dup_res_ty = resultTypeOfDupableCont (getMode env)
+ bndrs' cont
+ ; env1 <- simplRecJoinBndrs env cont_dup_res_ty bndrs'
+ -- NB: bndrs' don't have unfoldings or rules
+ -- We add them as we go down
+ ; (env2, cont_dup, cont_nodup) <- prepareLetCont (zapJoinFloats env1)
+ bndrs' cont
+ ; MASSERT2(cont_dup_res_ty `eqType` contResultType cont_dup,
+ ppr cont_dup_res_ty $$ blankLine $$
+ ppr cont $$ blankLine $$
+ ppr cont_dup $$ blankLine $$
+ ppr cont_nodup)
+ ; env3 <- simplRecBind env2 NotTopLevel (Just cont_dup) pairs'
+ ; (env4, expr) <- simplExprF env3 body cont_dup
+ ; rebuild (env4 `restoreJoinFloats` env1)
+ (wrapJoinFloats env4 expr) cont_nodup }
+ | otherwise
+ = do { let bndrs = map fst pairs
+ ; MASSERT(all (not . isJoinId) bndrs)
+ ; env1 <- simplRecBndrs env bndrs
+ -- NB: bndrs' don't have unfoldings or rules
+ -- We add them as we go down
+ ; env2 <- simplRecBind env1 NotTopLevel (Just cont) pairs
+ ; simplExprF env2 body cont }
+
+-- | Perform the conversion of a value binding to a join point if it's marked
+-- as 'AlwaysTailCalled'. If it's already a join point, return it as is.
+-- Otherwise return 'Nothing'.
+matchOrConvertToJoinPoint :: InBndr -> InExpr -> Maybe (JoinId, InExpr)
+matchOrConvertToJoinPoint bndr rhs
+ | not (isId bndr)
+ = Nothing
+ | isJoinId bndr
+ = -- No point in keeping tailCallInfo around; very fragile
+ Just (zapIdTailCallInfo bndr, rhs)
+ | AlwaysTailCalled join_arity <- tailCallInfo (idOccInfo bndr)
+ , (bndrs, body) <- etaExpandToJoinPoint join_arity rhs
+ = Just (zapIdTailCallInfo (bndr `asJoinId` join_arity),
+ mkLams bndrs body)
+ | otherwise
+ = Nothing
+
+matchOrConvertToJoinPoints :: [(InBndr, InExpr)] -> Maybe [(InBndr, InExpr)]
+matchOrConvertToJoinPoints bndrs
+ = mapM (uncurry matchOrConvertToJoinPoint) bndrs
+
{-
************************************************************************
* *
@@ -1351,9 +1558,11 @@ simplVar env var
simplIdF :: SimplEnv -> InId -> SimplCont -> SimplM (SimplEnv, OutExpr)
simplIdF env var cont
= case substId env var of
- DoneEx e -> simplExprF (zapSubstEnv env) e cont
+ DoneEx e -> simplExprF (zapSubstEnv env) e trimmed_cont
ContEx tvs cvs ids e -> simplExprF (setSubstEnv env tvs cvs ids) e cont
- DoneId var1 -> completeCall env var1 cont
+ -- Don't trim; haven't already simplified
+ -- the join, so the cont was never copied
+ DoneId var1 -> completeCall env var1 trimmed_cont
-- Note [zapSubstEnv]
-- The template is already simplified, so don't re-substitute.
-- This is VITAL. Consider
@@ -1363,6 +1572,24 @@ simplIdF env var cont
-- We'll clone the inner \x, adding x->x' in the id_subst
-- Then when we inline y, we must *not* replace x by x' in
-- the inlined copy!!
+ where
+ trimmed_cont | Just arity <- isJoinIdInEnv_maybe env var
+ = trim_cont arity cont
+ | otherwise
+ = cont
+
+ -- Drop outer context from join point invocation
+ -- Note [Case-of-case and join points]
+ trim_cont 0 cont@(Stop {})
+ = cont
+ trim_cont 0 cont
+ = mkBoringStop (contResultType cont)
+ trim_cont n cont@(ApplyToVal { sc_cont = k })
+ = cont { sc_cont = trim_cont (n-1) k }
+ trim_cont n cont@(ApplyToTy { sc_cont = k })
+ = cont { sc_cont = trim_cont (n-1) k } -- join arity counts types!
+ trim_cont _ cont
+ = pprPanic "completeCall" $ ppr var $$ ppr cont
---------------------------------------------------------
-- Dealing with a call site
@@ -1935,7 +2162,8 @@ rebuildCase env scrut case_bndr alts cont
reallyRebuildCase env scrut case_bndr alts cont
= do { -- Prepare the continuation;
-- The new subst_env is in place
- (env', dup_cont, nodup_cont) <- prepareCaseCont env alts cont
+ (env', dup_cont, nodup_cont) <- prepareCaseCont (zapJoinFloats env)
+ alts cont
-- Simplify the alternatives
; (scrut', case_bndr', alts') <- simplAlts env' scrut case_bndr alts dup_cont
@@ -1947,7 +2175,8 @@ reallyRebuildCase env scrut case_bndr alts cont
-- Notice that rebuild gets the in-scope set from env', not alt_env
-- (which in any case is only build in simplAlts)
-- The case binder *not* scope over the whole returned case-expression
- ; rebuild env' case_expr nodup_cont }
+ ; rebuild (env' `restoreJoinFloats` env)
+ (wrapJoinFloats env' case_expr) nodup_cont }
{-
simplCaseBinder checks whether the scrutinee is a variable, v. If so,
@@ -2348,23 +2577,87 @@ prepareCaseCont :: SimplEnv
-- The idea is that we'll transform thus:
-- Knodup[ (case _ of { p1 -> Kdup[r1]; ...; pn -> Kdup[rn] }
--
--- We may also return some extra bindings in SimplEnv (that scope over
--- the entire continuation)
+-- We may also return some extra value bindings in SimplEnv (that scope over
+-- the entire continuation) as well as some join points (thus must *not* float
+-- past the continuation!).
+-- Hence, the full story is this:
+-- K[case _ of { p1 -> r1; ...; pn -> rn }] ==>
+-- F_v[Knodup[F_j[ (case _ of { p1 -> Kdup[r1]; ...; pn -> Kdup[rn] }) ]]]
+-- Here F_v represents some values that got floated out and F_j represents some
+-- join points that got floated out.
--
-- When case-of-case is off, just make the entire continuation non-dupable
prepareCaseCont env alts cont
- | not (sm_case_case (getMode env)) = return (env, mkBoringStop (contHoleType cont), cont)
- | not (many_alts alts) = return (env, cont, mkBoringStop (contResultType cont))
- | otherwise = mkDupableCont env cont
- where
- many_alts :: [InAlt] -> Bool -- True iff strictly > 1 non-bottom alternative
- many_alts [] = False -- See Note [Bottom alternatives]
- many_alts [_] = False
- many_alts (alt:alts)
- | is_bot_alt alt = many_alts alts
- | otherwise = not (all is_bot_alt alts)
+ | not (sm_case_case (getMode env))
+ = return (env, mkBoringStop (contHoleType cont), cont)
+ | not (altsWouldDup alts)
+ = return (env, cont, mkBoringStop (contResultType cont))
+ | otherwise
+ = mkDupableCont env cont
+
+prepareLetCont :: SimplEnv
+ -> [InBndr] -> SimplCont
+ -> SimplM (SimplEnv,
+ SimplCont, -- Dupable part
+ SimplCont) -- Non-dupable part
+
+-- Similar to prepareCaseCont, only for
+-- K[let { j1 = r1; ...; jn -> rn } in _]
+-- If the js are join points, this will turn into
+-- Knodup[join { j1 = Kdup[r1]; ...; jn = Kdup[rn] } in Kdup[_]].
+--
+-- When case-of-case is off and it's a join binding, just make the entire
+-- continuation non-dupable. This is necessary because otherwise
+-- case (join j = ... in case e of { A -> jump j 1; ... }) of { B -> ... }
+-- becomes
+-- join j = case ... of { B -> ... } in
+-- case (case e of { A -> jump j 1; ... }) of { B -> ... },
+-- and the reference to j is invalid.
+prepareLetCont env bndrs cont
+ | not (isJoinId (head bndrs))
+ = return (env, cont, mkBoringStop (contResultType cont))
+ | not (sm_case_case (getMode env))
+ = return (env, mkBoringStop (contHoleType cont), cont)
+ | otherwise
+ = mkDupableCont env cont
+
+-- Predict the result type of the dupable cont returned by prepareLetCont (= the
+-- hole type of the non-dupable part). Ugly, but sadly necessary so that we can
+-- know what the new type of a recursive join point will be before we start
+-- simplifying it.
+resultTypeOfDupableCont :: SimplifierMode
+ -> [InBndr]
+ -> SimplCont
+ -> OutType -- INVARIANT: Result type of dupable cont
+ -- returned by prepareLetCont
+-- IMPORTANT: This must be kept in sync with mkDupableCont!
+resultTypeOfDupableCont mode bndrs cont
+ | not (any isJoinId bndrs) = contResultType cont
+ | not (sm_case_case mode) = contHoleType cont
+ | otherwise = go cont
+ where
+ go cont | contIsDupable cont = contResultType cont
+ go (Stop {}) = panic "typeOfDupableCont" -- Handled by previous eqn
+ go (CastIt _ cont) = go cont
+ go cont@(TickIt {}) = contHoleType cont
+ go cont@(StrictBind {}) = contHoleType cont
+ go (StrictArg _ _ cont) = go cont
+ go cont@(ApplyToTy {}) = go (sc_cont cont)
+ go cont@(ApplyToVal {}) = go (sc_cont cont)
+ go (Select { sc_alts = alts, sc_cont = cont })
+ | not (sm_case_case mode) = contHoleType cont
+ | not (altsWouldDup alts) = contResultType cont
+ | otherwise = go cont
+
+altsWouldDup :: [InAlt] -> Bool -- True iff strictly > 1 non-bottom alternative
+altsWouldDup [] = False -- See Note [Bottom alternatives]
+altsWouldDup [_] = False
+altsWouldDup (alt:alts)
+ | is_bot_alt alt = altsWouldDup alts
+ | otherwise = not (all is_bot_alt alts)
+ where
is_bot_alt (_,_,rhs) = exprIsBottom rhs
{-
@@ -2375,9 +2668,7 @@ When we have
of alts
then we can just duplicate those alts because the A and C cases
will disappear immediately. This is more direct than creating
-join points and inlining them away; and in some cases we would
-not even create the join points (see Note [Single-alternative case])
-and we would keep the case-of-case which is silly. See Trac #4930.
+join points and inlining them away. See Trac #4930.
-}
mkDupableCont :: SimplEnv -> SimplCont
@@ -2423,15 +2714,6 @@ mkDupableCont env (ApplyToVal { sc_arg = arg, sc_dup = dup, sc_env = se, sc_cont
, sc_dup = OkToDup, sc_cont = dup_cont }
; return (env'', app_cont, nodup_cont) }
-mkDupableCont env cont@(Select { sc_bndr = case_bndr, sc_alts = [(_, bs, _rhs)] })
--- See Note [Single-alternative case]
--- | not (exprIsDupable rhs && contIsDupable case_cont)
--- | not (isDeadBinder case_bndr)
- | all isDeadBinder bs -- InIds
- && not (isUnliftedType (idType case_bndr))
- -- Note [Single-alternative-unlifted]
- = return (env, mkBoringStop (contHoleType cont), cont)
-
mkDupableCont env (Select { sc_bndr = case_bndr, sc_alts = alts
, sc_env = se, sc_cont = cont })
= -- e.g. (case [...hole...] of { pi -> ei })
@@ -2509,19 +2791,16 @@ mkDupableAlt env case_bndr (con, bndrs', rhs') = do
-- The case binder is alive but trivial, so why has
-- it not been substituted away?
- used_bndrs' | isDeadBinder case_bndr = filter abstract_over bndrs'
- | otherwise = bndrs' ++ [case_bndr_w_unf]
+ final_bndrs'
+ | isDeadBinder case_bndr = filter abstract_over bndrs'
+ | otherwise = bndrs' ++ [case_bndr_w_unf]
abstract_over bndr
| isTyVar bndr = True -- Abstract over all type variables just in case
| otherwise = not (isDeadBinder bndr)
-- The deadness info on the new Ids is preserved by simplBinders
-
- ; (final_bndrs', final_args) -- Note [Join point abstraction]
- <- if (any isId used_bndrs')
- then return (used_bndrs', varsToCoreExprs used_bndrs')
- else do { rw_id <- newId (fsLit "w") voidPrimTy
- ; return ([setOneShotLambda rw_id], [Var voidPrimId]) }
+ final_args -- Note [Join point abstraction]
+ = varsToCoreExprs final_bndrs'
; join_bndr <- newId (fsLit "$j") (mkLamTypes final_bndrs' rhs_ty')
-- Note [Funky mkLamTypes]
@@ -2534,10 +2813,14 @@ mkDupableAlt env case_bndr (con, bndrs', rhs') = do
one_shot v | isId v = setOneShotLambda v
| otherwise = v
join_rhs = mkLams really_final_bndrs rhs'
- join_arity = exprArity join_rhs
- join_call = mkApps (Var join_bndr) final_args
-
- ; env' <- addPolyBind NotTopLevel env (NonRec (join_bndr `setIdArity` join_arity) join_rhs)
+ arity = length (filter (not . isTyVar) final_bndrs')
+ join_arity = length final_bndrs'
+ final_join_bndr = (join_bndr `setIdArity` arity)
+ `asJoinId` join_arity
+ join_call = mkApps (Var final_join_bndr) final_args
+ final_join_bind = NonRec final_join_bndr join_rhs
+
+ ; env' <- addPolyBind NotTopLevel env final_join_bind
; return (env', (con, bndrs', join_call)) }
-- See Note [Duplicated env]
@@ -2660,6 +2943,12 @@ type variables as well as term variables.
Note [Join point abstraction]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+NB: This note is now historical. Now that "join point" is not a fuzzy concept
+but a formal syntactic construct (as distinguished by the JoinId constructor of
+IdDetails), each of these concerns is handled separately, with no need for a
+vestigial extra argument.
+
Join points always have at least one value argument,
for several reasons
@@ -2769,114 +3058,6 @@ Unlike StrictArg, there doesn't seem anything to gain from
duplicating a StrictBind continuation, so we don't.
-Note [Single-alternative cases]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-This case is just like the ArgOf case. Here's an example:
- data T a = MkT !a
- ...(MkT (abs x))...
-Then we get
- case (case x of I# x' ->
- case x' <# 0# of
- True -> I# (negate# x')
- False -> I# x') of y {
- DEFAULT -> MkT y
-Because the (case x) has only one alternative, we'll transform to
- case x of I# x' ->
- case (case x' <# 0# of
- True -> I# (negate# x')
- False -> I# x') of y {
- DEFAULT -> MkT y
-But now we do *NOT* want to make a join point etc, giving
- case x of I# x' ->
- let $j = \y -> MkT y
- in case x' <# 0# of
- True -> $j (I# (negate# x'))
- False -> $j (I# x')
-In this case the $j will inline again, but suppose there was a big
-strict computation enclosing the orginal call to MkT. Then, it won't
-"see" the MkT any more, because it's big and won't get duplicated.
-And, what is worse, nothing was gained by the case-of-case transform.
-
-So, in circumstances like these, we don't want to build join points
-and push the outer case into the branches of the inner one. Instead,
-don't duplicate the continuation.
-
-When should we use this strategy? We should not use it on *every*
-single-alternative case:
- e.g. case (case ....) of (a,b) -> (# a,b #)
-Here we must push the outer case into the inner one!
-Other choices:
-
- * Match [(DEFAULT,_,_)], but in the common case of Int,
- the alternative-filling-in code turned the outer case into
- case (...) of y { I# _ -> MkT y }
-
- * Match on single alternative plus (not (isDeadBinder case_bndr))
- Rationale: pushing the case inwards won't eliminate the construction.
- But there's a risk of
- case (...) of y { (a,b) -> let z=(a,b) in ... }
- Now y looks dead, but it'll come alive again. Still, this
- seems like the best option at the moment.
-
- * Match on single alternative plus (all (isDeadBinder bndrs))
- Rationale: this is essentially seq.
-
- * Match when the rhs is *not* duplicable, and hence would lead to a
- join point. This catches the disaster-case above. We can test
- the *un-simplified* rhs, which is fine. It might get bigger or
- smaller after simplification; if it gets smaller, this case might
- fire next time round. NB also that we must test contIsDupable
- case_cont *too, because case_cont might be big!
-
- HOWEVER: I found that this version doesn't work well, because
- we can get let x = case (...) of { small } in ...case x...
- When x is inlined into its full context, we find that it was a bad
- idea to have pushed the outer case inside the (...) case.
-
-There is a cost to not doing case-of-case; see Trac #10626.
-
-Note [Single-alternative-unlifted]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-Here's another single-alternative where we really want to do case-of-case:
-
-data Mk1 = Mk1 Int# | Mk2 Int#
-
-M1.f =
- \r [x_s74 y_s6X]
- case
- case y_s6X of tpl_s7m {
- M1.Mk1 ipv_s70 -> ipv_s70;
- M1.Mk2 ipv_s72 -> ipv_s72;
- }
- of
- wild_s7c
- { __DEFAULT ->
- case
- case x_s74 of tpl_s7n {
- M1.Mk1 ipv_s77 -> ipv_s77;
- M1.Mk2 ipv_s79 -> ipv_s79;
- }
- of
- wild1_s7b
- { __DEFAULT -> ==# [wild1_s7b wild_s7c];
- };
- };
-
-So the outer case is doing *nothing at all*, other than serving as a
-join-point. In this case we really want to do case-of-case and decide
-whether to use a real join point or just duplicate the continuation:
-
- let $j s7c = case x of
- Mk1 ipv77 -> (==) s7c ipv77
- Mk1 ipv79 -> (==) s7c ipv79
- in
- case y of
- Mk1 ipv70 -> $j ipv70
- Mk2 ipv72 -> $j ipv72
-
-Hence: check whether the case binder's type is unlifted, because then
-the outer case is *not* a seq.
-
************************************************************************
* *
Unfoldings
@@ -2885,12 +3066,13 @@ the outer case is *not* a seq.
-}
simplLetUnfolding :: SimplEnv-> TopLevelFlag
+ -> Maybe SimplCont
-> InId
-> OutExpr
-> Unfolding -> SimplM Unfolding
-simplLetUnfolding env top_lvl id new_rhs unf
+simplLetUnfolding env top_lvl cont_mb id new_rhs unf
| isStableUnfolding unf
- = simplUnfolding env top_lvl id unf
+ = simplUnfolding env top_lvl cont_mb id unf
| otherwise
= is_bottoming `seq` -- See Note [Force bottoming field]
do { dflags <- getDynFlags
@@ -2905,9 +3087,10 @@ simplLetUnfolding env top_lvl id new_rhs unf
is_top_lvl = isTopLevel top_lvl
is_bottoming = isBottomingId id
-simplUnfolding :: SimplEnv-> TopLevelFlag -> InId -> Unfolding -> SimplM Unfolding
+simplUnfolding :: SimplEnv -> TopLevelFlag -> Maybe SimplCont -> InId
+ -> Unfolding -> SimplM Unfolding
-- Note [Setting the new unfolding]
-simplUnfolding env top_lvl id unf
+simplUnfolding env top_lvl cont_mb id unf
= case unf of
NoUnfolding -> return unf
BootUnfolding -> return unf
@@ -2920,7 +3103,10 @@ simplUnfolding env top_lvl id unf
CoreUnfolding { uf_tmpl = expr, uf_src = src, uf_guidance = guide }
| isStableSource src
- -> do { expr' <- simplExpr rule_env expr
+ -> do { expr' <- if isJoinId id
+ then let Just cont = cont_mb
+ in simplJoinRhs rule_env cont id expr
+ else simplExpr rule_env expr
; case guide of
UnfWhen { ug_arity = arity, ug_unsat_ok = sat_ok } -- Happens for INLINE things
-> let guide' = UnfWhen { ug_arity = arity, ug_unsat_ok = sat_ok
@@ -3010,9 +3196,11 @@ simplRules env mb_new_nm rules
simpl_rule rule@(Rule { ru_bndrs = bndrs, ru_args = args
, ru_fn = fn_name, ru_rhs = rhs })
= do { (env', bndrs') <- simplBinders env bndrs
- ; let rule_env = updMode updModeForRules env'
+ ; let rhs_ty = substTy env' (exprType rhs)
+ rule_cont = mkBoringStop rhs_ty
+ rule_env = updMode updModeForRules env'
; args' <- mapM (simplExpr rule_env) args
- ; rhs' <- simplExpr rule_env rhs
+ ; rhs' <- simplExprC rule_env rhs rule_cont
; return (rule { ru_bndrs = bndrs'
, ru_fn = mb_new_nm `orElse` fn_name
, ru_args = args'
diff --git a/compiler/specialise/Rules.hs b/compiler/specialise/Rules.hs
index 42cb13e8df..ba44794db4 100644
--- a/compiler/specialise/Rules.hs
+++ b/compiler/specialise/Rules.hs
@@ -35,7 +35,8 @@ import OccurAnal ( occurAnalyseExpr )
import CoreFVs ( exprFreeVars, exprsFreeVars, bindFreeVars
, rulesFreeVarsDSet, exprsOrphNames, exprFreeVarsList )
import CoreUtils ( exprType, eqExpr, mkTick, mkTicks,
- stripTicksTopT, stripTicksTopE )
+ stripTicksTopT, stripTicksTopE,
+ isJoinBind )
import PprCore ( pprRules )
import Type ( Type, substTy, mkTCvSubst )
import TcType ( tcSplitTyConApp_maybe )
@@ -728,7 +729,8 @@ match renv subst e1 (Var v2) -- Note [Expanding variables]
match renv subst e1 (Let bind e2)
| -- pprTrace "match:Let" (vcat [ppr bind, ppr $ okToFloat (rv_lcl renv) (bindFreeVars bind)]) $
- okToFloat (rv_lcl renv) (bindFreeVars bind) -- See Note [Matching lets]
+ not (isJoinBind bind) -- can't float join point out of argument position
+ , okToFloat (rv_lcl renv) (bindFreeVars bind) -- See Note [Matching lets]
= match (renv { rv_fltR = flt_subst' })
(subst { rs_binds = rs_binds subst . Let bind'
, rs_bndrs = extendVarSetList (rs_bndrs subst) new_bndrs })
diff --git a/compiler/specialise/SpecConstr.hs b/compiler/specialise/SpecConstr.hs
index 71d2d4b25d..5ee2dec594 100644
--- a/compiler/specialise/SpecConstr.hs
+++ b/compiler/specialise/SpecConstr.hs
@@ -1625,15 +1625,8 @@ spec_one env fn arg_bndrs body (call_pat@(qvars, pats), rule_number)
-- return ()
-- And build the results
- ; let spec_id = mkLocalIdOrCoVar spec_name (mkLamTypes spec_lam_args body_ty)
- -- See Note [Transfer strictness]
- `setIdStrictness` spec_str
- `setIdArity` count isId spec_lam_args
- spec_str = calcSpecStrictness fn spec_lam_args pats
-
-
- -- Conditionally use result of new worker-wrapper transform
- (spec_lam_args, spec_call_args) = mkWorkerArgs (sc_dflags env) qvars body_ty
+ ; let (spec_lam_args, spec_call_args) = mkWorkerArgs (sc_dflags env)
+ qvars body_ty
-- Usual w/w hack to avoid generating
-- a spec_rhs of unlifted type and no args
@@ -1641,6 +1634,18 @@ spec_one env fn arg_bndrs body (call_pat@(qvars, pats), rule_number)
-- Annotate the variables with the strictness information from
-- the function (see Note [Strictness information in worker binders])
+ spec_join_arity | isJoinId fn = Just (length spec_lam_args)
+ | otherwise = Nothing
+ spec_id = mkLocalIdOrCoVar spec_name
+ (mkLamTypes spec_lam_args body_ty)
+ -- See Note [Transfer strictness]
+ `setIdStrictness` spec_str
+ `setIdArity` count isId spec_lam_args
+ `asJoinId_maybe` spec_join_arity
+ spec_str = calcSpecStrictness fn spec_lam_args pats
+
+
+ -- Conditionally use result of new worker-wrapper transform
spec_rhs = mkLams spec_lam_args_str spec_body
body_ty = exprType spec_body
rule_rhs = mkVarApps (Var spec_id) spec_call_args
diff --git a/compiler/specialise/Specialise.hs b/compiler/specialise/Specialise.hs
index a2b1604950..2b4d9f5185 100644
--- a/compiler/specialise/Specialise.hs
+++ b/compiler/specialise/Specialise.hs
@@ -23,6 +23,7 @@ import CoreSyn
import Rules
import CoreUtils ( exprIsTrivial, applyTypeToArgs, mkCast )
import CoreFVs ( exprFreeVars, exprsFreeVars, idFreeVars, exprsFreeIdsList )
+import CoreArity ( etaExpandToJoinPointRule )
import UniqSupply
import Name
import MkId ( voidArgId, voidPrimId )
@@ -1270,11 +1271,17 @@ specCalls mb_mod env rules_for_me calls_for_me fn rhs
let body_ty = applyTypeToArgs rhs fn_type rule_args
(lam_args, app_args) -- Add a dummy argument if body_ty is unlifted
| isUnliftedType body_ty -- C.f. WwLib.mkWorkerArgs
+ , not (isJoinId fn)
= (poly_tyvars ++ [voidArgId], poly_tyvars ++ [voidPrimId])
| otherwise = (poly_tyvars, poly_tyvars)
spec_id_ty = mkLamTypes lam_args body_ty
+ join_arity_change = length app_args - length rule_args
+ spec_join_arity | Just orig_join_arity <- isJoinId_maybe fn
+ = Just (orig_join_arity + join_arity_change)
+ | otherwise
+ = Nothing
- ; spec_f <- newSpecIdSM fn spec_id_ty
+ ; spec_f <- newSpecIdSM fn spec_id_ty spec_join_arity
; (spec_rhs, rhs_uds) <- specExpr rhs_env2 (mkLams lam_args body)
; this_mod <- getModule
; let
@@ -1292,7 +1299,7 @@ specCalls mb_mod env rules_for_me calls_for_me fn rhs
-- otherwise uniques end up there, making builds
-- less deterministic (See #4012 comment:61 ff)
- spec_env_rule = mkRule
+ rule_wout_eta = mkRule
this_mod
True {- Auto generated -}
is_local
@@ -1303,6 +1310,12 @@ specCalls mb_mod env rules_for_me calls_for_me fn rhs
rule_args
(mkVarApps (Var spec_f) app_args)
+ spec_env_rule
+ = case isJoinId_maybe fn of
+ Just join_arity -> etaExpandToJoinPointRule join_arity
+ rule_wout_eta
+ Nothing -> rule_wout_eta
+
-- Add the { d1' = dx1; d2' = dx2 } usage stuff
final_uds = foldr consDictBind rhs_uds dx_binds
@@ -1332,6 +1345,7 @@ specCalls mb_mod env rules_for_me calls_for_me fn rhs
spec_f_w_arity = spec_f `setIdArity` max 0 (fn_arity - n_dicts)
`setInlinePragma` spec_inl_prag
`setIdUnfolding` spec_unf
+ `asJoinId_maybe` spec_join_arity
; return (Just ((spec_f_w_arity, spec_rhs), final_uds, spec_env_rule)) } }
@@ -2260,13 +2274,14 @@ newDictBndr env b = do { uniq <- getUniqueM
ty' = substTy env (idType b)
; return (mkUserLocalOrCoVar (nameOccName n) uniq ty' (getSrcSpan n)) }
-newSpecIdSM :: Id -> Type -> SpecM Id
+newSpecIdSM :: Id -> Type -> Maybe JoinArity -> SpecM Id
-- Give the new Id a similar occurrence name to the old one
-newSpecIdSM old_id new_ty
+newSpecIdSM old_id new_ty join_arity_maybe
= do { uniq <- getUniqueM
; let name = idName old_id
new_occ = mkSpecOcc (nameOccName name)
new_id = mkUserLocalOrCoVar new_occ uniq new_ty (getSrcSpan name)
+ `asJoinId_maybe` join_arity_maybe
; return new_id }
{-
diff --git a/compiler/stgSyn/CoreToStg.hs b/compiler/stgSyn/CoreToStg.hs
index 37df9e2146..900d23f7b5 100644
--- a/compiler/stgSyn/CoreToStg.hs
+++ b/compiler/stgSyn/CoreToStg.hs
@@ -16,7 +16,7 @@ module CoreToStg ( coreToStg, coreExprToStg ) where
#include "HsVersions.h"
import CoreSyn
-import CoreUtils ( exprType, findDefault )
+import CoreUtils ( exprType, findDefault, isJoinBind )
import CoreArity ( manifestArity )
import StgSyn
@@ -28,11 +28,10 @@ import Id
import IdInfo
import DataCon
import CostCentre ( noCCS )
-import VarSet
import VarEnv
import Module
-import Name ( getOccName, isExternalName, nameOccName )
-import OccName ( occNameString, occNameFS )
+import Name ( isExternalName, nameOccName )
+import OccName ( occNameFS )
import BasicTypes ( Arity )
import TysWiredIn ( unboxedUnitDataCon )
import Literal
@@ -139,6 +138,10 @@ import Control.Monad (liftM, ap)
-- Note [What is a non-escaping let]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
--
+-- NB: Nowadays this is recognized by the occurrence analyser by turning a
+-- "non-escaping let" into a join point. The following is then an operational
+-- account of join points.
+--
-- Consider:
--
-- let x = fvs \ args -> e
@@ -155,8 +158,7 @@ import Control.Monad (liftM, ap)
-- to the code for `x'.
--
-- All of this is provided x is:
--- 1. non-updatable - it must have at least one parameter (see Note
--- [Join point abstraction]);
+-- 1. non-updatable;
-- 2. guaranteed to be entered before the stack retreats -- ie x is not
-- buried in a heap-allocated closure, or passed as an argument to
-- something;
@@ -203,7 +205,7 @@ coreToStg dflags this_mod pgm
coreExprToStg :: CoreExpr -> StgExpr
coreExprToStg expr
- = new_expr where (new_expr,_,_) = initLne emptyVarEnv (coreToStgExpr expr)
+ = new_expr where (new_expr,_) = initCts emptyVarEnv (coreToStgExpr expr)
coreTopBindsToStg
@@ -244,7 +246,7 @@ coreTopBindToStg dflags this_mod env body_fvs (NonRec id rhs)
how_bound = LetBound TopLet $! manifestArity rhs
(stg_rhs, fvs') =
- initLne env $ do
+ initCts env $ do
(stg_rhs, fvs') <- coreToTopStgRhs dflags this_mod body_fvs (id,rhs)
return (stg_rhs, fvs')
@@ -267,7 +269,7 @@ coreTopBindToStg dflags this_mod env body_fvs (Rec pairs)
env' = extendVarEnvList env extra_env'
(stg_rhss, fvs')
- = initLne env' $ do
+ = initCts env' $ do
(stg_rhss, fvss') <- mapAndUnzipM (coreToTopStgRhs dflags this_mod body_fvs) pairs
let fvs' = unionFVInfos fvss'
return (stg_rhss, fvs')
@@ -298,10 +300,10 @@ coreToTopStgRhs
-> Module
-> FreeVarsInfo -- Free var info for the scope of the binding
-> (Id,CoreExpr)
- -> LneM (StgRhs, FreeVarsInfo)
+ -> CtsM (StgRhs, FreeVarsInfo)
coreToTopStgRhs dflags this_mod scope_fv_info (bndr, rhs)
- = do { (new_rhs, rhs_fvs, _) <- coreToStgExpr rhs
+ = do { (new_rhs, rhs_fvs) <- coreToStgExpr rhs
; let stg_rhs = mkTopStgRhs dflags this_mod rhs_fvs bndr bndr_info new_rhs
stg_arity = stgRhsArity stg_rhs
@@ -343,13 +345,8 @@ mkTopStgRhs dflags this_mod = mkStgRhs' con_updateable
coreToStgExpr
:: CoreExpr
- -> LneM (StgExpr, -- Decorated STG expr
- FreeVarsInfo, -- Its free vars (NB free, not live)
- EscVarsSet) -- Its escapees, a subset of its free vars;
- -- also a subset of the domain of the envt
- -- because we are only interested in the escapees
- -- for vars which might be turned into
- -- let-no-escaped ones.
+ -> CtsM (StgExpr, -- Decorated STG expr
+ FreeVarsInfo) -- Its free vars (NB free, not live)
-- The second and third components can be derived in a simple bottom up pass, not
-- dependent on any decisions about which variables will be let-no-escaped or
@@ -360,7 +357,7 @@ coreToStgExpr
-- No LitInteger's should be left by the time this is called. CorePrep
-- should have converted them all to a real core representation.
coreToStgExpr (Lit (LitInteger {})) = panic "coreToStgExpr: LitInteger"
-coreToStgExpr (Lit l) = return (StgLit l, emptyFVInfo, emptyVarSet)
+coreToStgExpr (Lit l) = return (StgLit l, emptyFVInfo)
coreToStgExpr (Var v) = coreToStgApp Nothing v [] []
coreToStgExpr (Coercion _) = coreToStgApp Nothing coercionTokenId [] []
@@ -374,15 +371,14 @@ coreToStgExpr expr@(Lam _ _)
(args, body) = myCollectBinders expr
args' = filterStgBinders args
in
- extendVarEnvLne [ (a, LambdaBound) | a <- args' ] $ do
- (body, body_fvs, body_escs) <- coreToStgExpr body
+ extendVarEnvCts [ (a, LambdaBound) | a <- args' ] $ do
+ (body, body_fvs) <- coreToStgExpr body
let
fvs = args' `minusFVBinders` body_fvs
- escs = body_escs `delVarSetList` args'
result_expr | null args' = body
| otherwise = StgLam args' body
- return (result_expr, fvs, escs)
+ return (result_expr, fvs)
coreToStgExpr (Tick tick expr)
= do case tick of
@@ -390,8 +386,8 @@ coreToStgExpr (Tick tick expr)
ProfNote{} -> return ()
SourceNote{} -> return ()
Breakpoint{} -> panic "coreToStgExpr: breakpoint should not happen"
- (expr2, fvs, escs) <- coreToStgExpr expr
- return (StgTick tick expr2, fvs, escs)
+ (expr2, fvs) <- coreToStgExpr expr
+ return (StgTick tick expr2, fvs)
coreToStgExpr (Cast expr _)
= coreToStgExpr expr
@@ -411,12 +407,11 @@ coreToStgExpr (Case scrut _ _ [])
coreToStgExpr (Case scrut bndr _ alts) = do
- (alts2, alts_fvs, alts_escs)
- <- extendVarEnvLne [(bndr, LambdaBound)] $ do
- (alts2, fvs_s, escs_s) <- mapAndUnzip3M vars_alt alts
+ (alts2, alts_fvs)
+ <- extendVarEnvCts [(bndr, LambdaBound)] $ do
+ (alts2, fvs_s) <- mapAndUnzipM vars_alt alts
return ( alts2,
- unionFVInfos fvs_s,
- unionVarSets escs_s )
+ unionFVInfos fvs_s )
let
-- Determine whether the default binder is dead or not
-- This helps the code generator to avoid generating an assignment
@@ -428,19 +423,14 @@ coreToStgExpr (Case scrut bndr _ alts) = do
-- since this is from the point of view of the case expr, where
-- the default binder is not free.
alts_fvs_wo_bndr = bndr `minusFVBinder` alts_fvs
- alts_escs_wo_bndr = alts_escs `delVarSet` bndr
-- We tell the scrutinee that everything
-- live in the alts is live in it, too.
- (scrut2, scrut_fvs, _scrut_escs) <- coreToStgExpr scrut
+ (scrut2, scrut_fvs) <- coreToStgExpr scrut
return (
StgCase scrut2 bndr' (mkStgAltType bndr alts) alts2,
- scrut_fvs `unionFVInfo` alts_fvs_wo_bndr,
- alts_escs_wo_bndr `unionVarSet` getFVSet scrut_fvs
- -- You might think we should have scrut_escs, not
- -- (getFVSet scrut_fvs), but actually we can't call, and
- -- then return from, a let-no-escape thing.
+ scrut_fvs `unionFVInfo` alts_fvs_wo_bndr
)
where
vars_alt (con, binders, rhs)
@@ -449,32 +439,19 @@ coreToStgExpr (Case scrut bndr _ alts) = do
-- See Note [Nullary unboxed tuple] in Type.hs
-- where a nullary tuple is mapped to (State# World#)
ASSERT( null binders )
- do { (rhs2, rhs_fvs, rhs_escs) <- coreToStgExpr rhs
- ; return ((DEFAULT, [], rhs2), rhs_fvs, rhs_escs) }
+ do { (rhs2, rhs_fvs) <- coreToStgExpr rhs
+ ; return ((DEFAULT, [], rhs2), rhs_fvs) }
| otherwise
= let -- Remove type variables
binders' = filterStgBinders binders
in
- extendVarEnvLne [(b, LambdaBound) | b <- binders'] $ do
- (rhs2, rhs_fvs, rhs_escs) <- coreToStgExpr rhs
+ extendVarEnvCts [(b, LambdaBound) | b <- binders'] $ do
+ (rhs2, rhs_fvs) <- coreToStgExpr rhs
return ( (con, binders', rhs2),
- binders' `minusFVBinders` rhs_fvs,
- rhs_escs `delVarSetList` binders' )
- -- ToDo: remove the delVarSet;
- -- since escs won't include any of these binders
-
--- Lets not only take quite a bit of work, but this is where we convert
--- then to let-no-escapes, if we wish.
--- (Meanwhile, we don't expect to see let-no-escapes...)
-
+ binders' `minusFVBinders` rhs_fvs )
coreToStgExpr (Let bind body) = do
- (new_let, fvs, escs, _)
- <- mfix (\ ~(_, _, _, no_binder_escapes) ->
- coreToStgLet no_binder_escapes bind body
- )
-
- return (new_let, fvs, escs)
+ coreToStgLet bind body
coreToStgExpr e = pprPanic "coreToStgExpr" (ppr e)
@@ -530,12 +507,12 @@ coreToStgApp
-> Id -- Function
-> [CoreArg] -- Arguments
-> [Tickish Id] -- Debug ticks
- -> LneM (StgExpr, FreeVarsInfo, EscVarsSet)
+ -> CtsM (StgExpr, FreeVarsInfo)
coreToStgApp _ f args ticks = do
(args', args_fvs, ticks') <- coreToStgArgs args
- how_bound <- lookupVarLne f
+ how_bound <- lookupVarCts f
let
n_val_args = valArgCount args
@@ -560,25 +537,6 @@ coreToStgApp _ f args ticks = do
| f_arity > 0 && saturated = stgSatOcc -- Saturated or over-saturated function call
| otherwise = stgUnsatOcc -- Unsaturated function or thunk
- fun_escs
- | not_letrec_bound = emptyVarSet -- Only letrec-bound escapees are interesting
- | f_arity == n_val_args = emptyVarSet -- A function *or thunk* with an exactly
- -- saturated call doesn't escape
- -- (let-no-escape applies to 'thunks' too)
-
- | otherwise = unitVarSet f -- Inexact application; it does escape
-
- -- At the moment of the call:
-
- -- either the function is *not* let-no-escaped, in which case
- -- nothing is live except live_in_cont
- -- or the function *is* let-no-escaped in which case the
- -- variables it uses are live, but still the function
- -- itself is not. PS. In this case, the function's
- -- live vars should already include those of the
- -- continuation, but it does no harm to just union the
- -- two regardless.
-
res_ty = exprType (mkApps (Var f) args)
app = case idDetails f of
DataConWorkId dc
@@ -602,18 +560,14 @@ coreToStgApp _ f args ticks = do
TickBoxOpId {} -> pprPanic "coreToStg TickBox" $ ppr (f,args')
_other -> StgApp f args'
fvs = fun_fvs `unionFVInfo` args_fvs
- vars = fun_escs `unionVarSet` (getFVSet args_fvs)
- -- All the free vars of the args are disqualified
- -- from being let-no-escaped.
tapp = foldr StgTick app (ticks ++ ticks')
-- Forcing these fixes a leak in the code generator, noticed while
-- profiling for trac #4367
- app `seq` fvs `seq` seqVarSet vars `seq` return (
+ app `seq` fvs `seq` return (
tapp,
- fvs,
- vars
+ fvs
)
@@ -623,7 +577,7 @@ coreToStgApp _ f args ticks = do
-- This is the guy that turns applications into A-normal form
-- ---------------------------------------------------------------------------
-coreToStgArgs :: [CoreArg] -> LneM ([StgArg], FreeVarsInfo, [Tickish Id])
+coreToStgArgs :: [CoreArg] -> CtsM ([StgArg], FreeVarsInfo, [Tickish Id])
coreToStgArgs []
= return ([], emptyFVInfo, [])
@@ -642,7 +596,7 @@ coreToStgArgs (Tick t e : args)
coreToStgArgs (arg : args) = do -- Non-type argument
(stg_args, args_fvs, ticks) <- coreToStgArgs args
- (arg', arg_fvs, _escs) <- coreToStgExpr arg
+ (arg', arg_fvs) <- coreToStgExpr arg
let
fvs = args_fvs `unionFVInfo` arg_fvs
@@ -682,69 +636,40 @@ coreToStgArgs (arg : args) = do -- Non-type argument
-- ---------------------------------------------------------------------------
coreToStgLet
- :: Bool -- True <=> yes, we are let-no-escaping this let
- -> CoreBind -- bindings
+ :: CoreBind -- bindings
-> CoreExpr -- body
- -> LneM (StgExpr, -- new let
- FreeVarsInfo, -- variables free in the whole let
- EscVarsSet, -- variables that escape from the whole let
- Bool) -- True <=> none of the binders in the bindings
- -- is among the escaping vars
-
-coreToStgLet let_no_escape bind body = do
- (bind2, bind_fvs, bind_escs,
- body2, body_fvs, body_escs)
- <- mfix $ \ ~(_, _, _, _, rec_body_fvs, _) -> do
-
- ( bind2, bind_fvs, bind_escs, env_ext)
+ -> CtsM (StgExpr, -- new let
+ FreeVarsInfo) -- variables free in the whole let
+
+coreToStgLet bind body = do
+ (bind2, bind_fvs,
+ body2, body_fvs)
+ <- mfix $ \ ~(_, _, _, rec_body_fvs) -> do
+
+ ( bind2, bind_fvs, env_ext)
<- vars_bind rec_body_fvs bind
-- Do the body
- extendVarEnvLne env_ext $ do
- (body2, body_fvs, body_escs) <- coreToStgExpr body
+ extendVarEnvCts env_ext $ do
+ (body2, body_fvs) <- coreToStgExpr body
- return (bind2, bind_fvs, bind_escs,
- body2, body_fvs, body_escs)
+ return (bind2, bind_fvs,
+ body2, body_fvs)
-- Compute the new let-expression
let
- new_let | let_no_escape = StgLetNoEscape bind2 body2
- | otherwise = StgLet bind2 body2
+ new_let | isJoinBind bind = StgLetNoEscape bind2 body2
+ | otherwise = StgLet bind2 body2
free_in_whole_let
= binders `minusFVBinders` (bind_fvs `unionFVInfo` body_fvs)
- real_bind_escs = if let_no_escape then
- bind_escs
- else
- getFVSet bind_fvs
- -- Everything escapes which is free in the bindings
-
- let_escs = (real_bind_escs `unionVarSet` body_escs) `delVarSetList` binders
-
- all_escs = bind_escs `unionVarSet` body_escs -- Still includes binders of
- -- this let(rec)
-
- no_binder_escapes = isEmptyVarSet (set_of_binders `intersectVarSet` all_escs)
-
- -- Debugging code as requested by Andrew Kennedy
- checked_no_binder_escapes
- | debugIsOn && not no_binder_escapes && any is_join_var binders
- = pprTrace "Interesting! A join var that isn't let-no-escaped" (ppr binders)
- False
- | otherwise = no_binder_escapes
-
- -- Mustn't depend on the passed-in let_no_escape flag, since
- -- no_binder_escapes is used by the caller to derive the flag!
return (
new_let,
- free_in_whole_let,
- let_escs,
- checked_no_binder_escapes
+ free_in_whole_let
)
where
- set_of_binders = mkVarSet binders
binders = bindersOf bind
mk_binding binder rhs
@@ -752,53 +677,44 @@ coreToStgLet let_no_escape bind body = do
vars_bind :: FreeVarsInfo -- Free var info for body of binding
-> CoreBind
- -> LneM (StgBinding,
+ -> CtsM (StgBinding,
FreeVarsInfo,
- EscVarsSet, -- free vars; escapee vars
[(Id, HowBound)]) -- extension to environment
vars_bind body_fvs (NonRec binder rhs) = do
- (rhs2, bind_fvs, escs) <- coreToStgRhs body_fvs (binder,rhs)
+ (rhs2, bind_fvs) <- coreToStgRhs body_fvs (binder,rhs)
let
env_ext_item = mk_binding binder rhs
return (StgNonRec binder rhs2,
- bind_fvs, escs, [env_ext_item])
+ bind_fvs, [env_ext_item])
vars_bind body_fvs (Rec pairs)
- = mfix $ \ ~(_, rec_rhs_fvs, _, _) ->
+ = mfix $ \ ~(_, rec_rhs_fvs, _) ->
let
rec_scope_fvs = unionFVInfo body_fvs rec_rhs_fvs
binders = map fst pairs
env_ext = [ mk_binding b rhs
| (b,rhs) <- pairs ]
in
- extendVarEnvLne env_ext $ do
- (rhss2, fvss, escss)
- <- mapAndUnzip3M (coreToStgRhs rec_scope_fvs) pairs
+ extendVarEnvCts env_ext $ do
+ (rhss2, fvss)
+ <- mapAndUnzipM (coreToStgRhs rec_scope_fvs) pairs
let
bind_fvs = unionFVInfos fvss
- escs = unionVarSets escss
return (StgRec (binders `zip` rhss2),
- bind_fvs, escs, env_ext)
-
-
-is_join_var :: Id -> Bool
--- A hack (used only for compiler debuggging) to tell if
--- a variable started life as a join point ($j)
-is_join_var j = occNameString (getOccName j) == "$j"
+ bind_fvs, env_ext)
coreToStgRhs :: FreeVarsInfo -- Free var info for the scope of the binding
-> (Id,CoreExpr)
- -> LneM (StgRhs, FreeVarsInfo, EscVarsSet)
+ -> CtsM (StgRhs, FreeVarsInfo)
coreToStgRhs scope_fv_info (bndr, rhs) = do
- (new_rhs, rhs_fvs, rhs_escs) <- coreToStgExpr rhs
- return (mkStgRhs rhs_fvs bndr bndr_info new_rhs,
- rhs_fvs, rhs_escs)
+ (new_rhs, rhs_fvs) <- coreToStgExpr rhs
+ return (mkStgRhs rhs_fvs bndr bndr_info new_rhs, rhs_fvs)
where
bndr_info = lookupFVInfo scope_fv_info bndr
@@ -814,6 +730,12 @@ mkStgRhs' con_updateable rhs_fvs bndr binder_info rhs
(getFVs rhs_fvs)
ReEntrant
bndrs body
+ | isJoinId bndr -- must be nullary join point
+ = ASSERT(idJoinArity bndr == 0)
+ StgRhsClosure noCCS binder_info
+ (getFVs rhs_fvs)
+ ReEntrant -- ignored for LNE
+ [] rhs
| StgConApp con args _ <- unticked_rhs
, not (con_updateable con args)
= -- CorePrep does this right, but just to make sure
@@ -883,19 +805,18 @@ isPAP env _ = False
-}
-- ---------------------------------------------------------------------------
--- A little monad for this let-no-escaping pass
+-- A monad for the core-to-STG pass
-- ---------------------------------------------------------------------------
--- There's a lot of stuff to pass around, so we use this LneM monad to
--- help. All the stuff here is only passed *down*.
+-- There's a lot of stuff to pass around, so we use this CtsM
+-- ("core-to-STG monad") monad to help. All the stuff here is only passed
+-- *down*.
-newtype LneM a = LneM
- { unLneM :: IdEnv HowBound
+newtype CtsM a = CtsM
+ { unCtsM :: IdEnv HowBound
-> a
}
-type EscVarsSet = IdSet
-
data HowBound
= ImportBound -- Used only as a response to lookupBinding; never
-- exists in the range of the (IdEnv HowBound)
@@ -937,45 +858,45 @@ topLevelBound _ = False
-- The std monad functions:
-initLne :: IdEnv HowBound -> LneM a -> a
-initLne env m = unLneM m env
+initCts :: IdEnv HowBound -> CtsM a -> a
+initCts env m = unCtsM m env
-{-# INLINE thenLne #-}
-{-# INLINE returnLne #-}
+{-# INLINE thenCts #-}
+{-# INLINE returnCts #-}
-returnLne :: a -> LneM a
-returnLne e = LneM $ \_ -> e
+returnCts :: a -> CtsM a
+returnCts e = CtsM $ \_ -> e
-thenLne :: LneM a -> (a -> LneM b) -> LneM b
-thenLne m k = LneM $ \env
- -> unLneM (k (unLneM m env)) env
+thenCts :: CtsM a -> (a -> CtsM b) -> CtsM b
+thenCts m k = CtsM $ \env
+ -> unCtsM (k (unCtsM m env)) env
-instance Functor LneM where
+instance Functor CtsM where
fmap = liftM
-instance Applicative LneM where
- pure = returnLne
+instance Applicative CtsM where
+ pure = returnCts
(<*>) = ap
-instance Monad LneM where
- (>>=) = thenLne
+instance Monad CtsM where
+ (>>=) = thenCts
-instance MonadFix LneM where
- mfix expr = LneM $ \env ->
- let result = unLneM (expr result) env
+instance MonadFix CtsM where
+ mfix expr = CtsM $ \env ->
+ let result = unCtsM (expr result) env
in result
-- Functions specific to this monad:
-extendVarEnvLne :: [(Id, HowBound)] -> LneM a -> LneM a
-extendVarEnvLne ids_w_howbound expr
- = LneM $ \env
- -> unLneM expr (extendVarEnvList env ids_w_howbound)
+extendVarEnvCts :: [(Id, HowBound)] -> CtsM a -> CtsM a
+extendVarEnvCts ids_w_howbound expr
+ = CtsM $ \env
+ -> unCtsM expr (extendVarEnvList env ids_w_howbound)
-lookupVarLne :: Id -> LneM HowBound
-lookupVarLne v = LneM $ \env -> lookupBinding env v
+lookupVarCts :: Id -> CtsM HowBound
+lookupVarCts v = CtsM $ \env -> lookupBinding env v
lookupBinding :: IdEnv HowBound -> Id -> HowBound
lookupBinding env v = case lookupVarEnv env v of
@@ -1057,9 +978,6 @@ getFVs fvs = [id | (id, how_bound, _) <- nonDetEltsUFM fvs,
-- See Note [Unique Determinism and code generation]
not (topLevelBound how_bound) ]
-getFVSet :: FreeVarsInfo -> VarSet
-getFVSet fvs = mkVarSet (getFVs fvs)
-
plusFVInfo :: (Var, HowBound, StgBinderInfo)
-> (Var, HowBound, StgBinderInfo)
-> (Var, HowBound, StgBinderInfo)
diff --git a/compiler/stranal/DmdAnal.hs b/compiler/stranal/DmdAnal.hs
index 79ae20f8fb..212767e531 100644
--- a/compiler/stranal/DmdAnal.hs
+++ b/compiler/stranal/DmdAnal.hs
@@ -268,7 +268,7 @@ dmdAnal' env dmd (Case scrut case_bndr ty alts)
-- This is used for a non-recursive local let without manifest lambdas.
-- This is the LetUp rule in the paper “Higher-Order Cardinality Analysis”.
dmdAnal' env dmd (Let (NonRec id rhs) body)
- | useLetUp rhs
+ | useLetUp id rhs
, Nothing <- unpackTrivial rhs
-- dmdAnalRhsLetDown treats trivial right hand sides specially
-- so if we have a trival right hand side, fall through to that.
@@ -632,7 +632,7 @@ dmdAnalRhsLetDown top_lvl rec_flag env id rhs
trim_sums = not (isTopLevel top_lvl) -- See Note [CPR for sum types]
-- See Note [CPR for thunks]
- is_thunk = not (exprIsHNF rhs)
+ is_thunk = not (exprIsHNF rhs) && not (isJoinId id)
not_strict
= isTopLevel top_lvl -- Top level and recursive things don't
|| isJust rec_flag -- get their demandInfo set at all
@@ -654,11 +654,13 @@ unpackTrivial _ = Nothing
-- down (rhs before body).
--
-- We use LetDown if there is a chance to get a useful strictness signature.
--- This is the case when there are manifest value lambdas.
-useLetUp :: CoreExpr -> Bool
-useLetUp (Lam v e) | isTyVar v = useLetUp e
-useLetUp (Lam _ _) = False
-useLetUp _ = True
+-- This is the case when there are manifest value lambdas or the binding is a
+-- join point (hence always acts like a function, not a value).
+useLetUp :: Var -> CoreExpr -> Bool
+useLetUp f _ | isJoinId f = False
+useLetUp f (Lam v e) | isTyVar v = useLetUp f e
+useLetUp _ (Lam _ _) = False
+useLetUp _ _ = True
{-
diff --git a/compiler/stranal/WorkWrap.hs b/compiler/stranal/WorkWrap.hs
index d50bb223f6..0963df0d06 100644
--- a/compiler/stranal/WorkWrap.hs
+++ b/compiler/stranal/WorkWrap.hs
@@ -14,6 +14,7 @@ import CoreFVs ( exprFreeVars )
import Var
import Id
import IdInfo
+import Type
import UniqSupply
import BasicTypes
import DynFlags
@@ -237,6 +238,48 @@ There is an infelicity though. We may get something like
The code for f duplicates that for g, without any real benefit. It
won't really be executed, because calls to f will go via the inlining.
+Note [Don't CPR join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+There's no point in doing CPR on a join point. If the whole function is getting
+CPR'd, then the case expression around the worker function will get pushed into
+the join point by the simplifier, which will have the same effect that CPR would
+have - the result will be returned in an unboxed tuple.
+
+ f z = let join j x y = (x+1, y+1)
+ in case z of A -> j 1 2
+ B -> j 2 3
+
+ =>
+
+ f z = case $wf z of (# a, b #) -> (a, b)
+ $wf z = case (let join j x y = (x+1, y+1)
+ in case z of A -> j 1 2
+ B -> j 2 3) of (a, b) -> (# a, b #)
+
+ =>
+
+ f z = case $wf z of (# a, b #) -> (a, b)
+ $wf z = let join j x y = (# x+1, y+1 #)
+ in case z of A -> j 1 2
+ B -> j 2 3
+
+Doing CPR on a join point would be tricky anyway, as the worker could not be
+a join point because it would not be tail-called. However, doing the *argument*
+part of W/W still works for join points, since the wrapper body will make a tail
+call:
+
+ f z = let join j x y = x + y
+ in ...
+
+ =>
+
+ f z = let join $wj x# y# = x# +# y#
+ j x y = case x of I# x# ->
+ case y of I# y# ->
+ $wj x# y#
+ in ...
+
Note [Wrapper activation]
~~~~~~~~~~~~~~~~~~~~~~~~~
When should the wrapper inlining be active? It must not be active
@@ -312,8 +355,9 @@ tryWW dflags fam_envs is_rec fn_id rhs
-- See Note [Zapping DmdEnv after Demand Analyzer] and
-- See Note [Zapping Used Once info in WorkWrap]
- is_fun = notNull wrap_dmds
- is_thunk = not is_fun && not (exprIsHNF rhs)
+ is_fun = notNull wrap_dmds || isJoinId fn_id
+ is_thunk = not is_fun && not (exprIsHNF rhs) && not (isJoinId fn_id)
+ && not (isUnliftedType (idType fn_id))
{-
Note [Zapping DmdEnv after Demand Analyzer]
@@ -362,9 +406,10 @@ splitFun :: DynFlags -> FamInstEnvs -> Id -> IdInfo -> [Demand] -> DmdResult ->
splitFun dflags fam_envs fn_id fn_info wrap_dmds res_info rhs
= WARN( not (wrap_dmds `lengthIs` arity), ppr fn_id <+> (ppr arity $$ ppr wrap_dmds $$ ppr res_info) ) do
-- The arity should match the signature
- stuff <- mkWwBodies dflags fam_envs rhs_fvs fun_ty wrap_dmds res_info
+ stuff <- mkWwBodies dflags fam_envs rhs_fvs mb_join_arity fun_ty
+ wrap_dmds use_res_info
case stuff of
- Just (work_demands, wrap_fn, work_fn) -> do
+ Just (work_demands, join_arity, wrap_fn, work_fn) -> do
work_uniq <- getUniqueM
let work_rhs = work_fn rhs
work_prag = InlinePragma { inl_src = SourceText "{-# INLINE"
@@ -375,7 +420,10 @@ splitFun dflags fam_envs fn_id fn_info wrap_dmds res_info rhs
-- idl_inline: copy from fn_id; see Note [Worker-wrapper for INLINABLE functions]
-- idl_act: see Note [Activation for INLINABLE workers]
-- inl_rule: it does not make sense for workers to be constructorlike.
-
+ work_join_arity | isJoinId fn_id = Just join_arity
+ | otherwise = Nothing
+ -- worker is join point iff wrapper is join point
+ -- (see Note [Don't CPR join points])
work_id = mkWorkerId work_uniq fn_id (exprType work_rhs)
`setIdOccInfo` occInfo fn_info
-- Copy over occurrence info from parent
@@ -396,6 +444,9 @@ splitFun dflags fam_envs fn_id fn_info wrap_dmds res_info rhs
`setIdArity` work_arity
-- Set the arity so that the Core Lint check that the
+ -- arity is consistent with the demand type goes
+ -- through
+ `asJoinId_maybe` work_join_arity
work_arity = length work_demands
@@ -404,7 +455,6 @@ splitFun dflags fam_envs fn_id fn_info wrap_dmds res_info rhs
worker_demand | single_call = mkWorkerDemand work_arity
| otherwise = topDmd
- -- arity is consistent with the demand type goes through
wrap_act = ActiveAfter NoSourceText 0
wrap_rhs = wrap_fn work_id
@@ -418,7 +468,7 @@ splitFun dflags fam_envs fn_id fn_info wrap_dmds res_info rhs
wrap_id = fn_id `setIdUnfolding` mkWwInlineRule wrap_rhs arity
`setInlinePragma` wrap_prag
- `setIdOccInfo` NoOccInfo
+ `setIdOccInfo` noOccInfo
-- Zap any loop-breaker-ness, to avoid bleating from Lint
-- about a loop breaker with an INLINE rule
@@ -429,6 +479,7 @@ splitFun dflags fam_envs fn_id fn_info wrap_dmds res_info rhs
Nothing -> return [(fn_id, rhs)]
where
+ mb_join_arity = isJoinId_maybe fn_id
rhs_fvs = exprFreeVars rhs
fun_ty = idType fn_id
inl_prag = inlinePragInfo fn_info
@@ -437,7 +488,11 @@ splitFun dflags fam_envs fn_id fn_info wrap_dmds res_info rhs
-- The arity is set by the simplifier using exprEtaExpandArity
-- So it may be more than the number of top-level-visible lambdas
- work_res_info = case returnsCPR_maybe res_info of
+ use_res_info | isJoinId fn_id = topRes -- Note [Don't CPR join points]
+ | otherwise = res_info
+ work_res_info | isJoinId fn_id = res_info -- Worker remains CPR-able
+ | otherwise
+ = case returnsCPR_maybe res_info of
Just _ -> topRes -- Cpr stuff done by wrapper; kill it here
Nothing -> res_info -- Preserve exception/divergence
@@ -536,7 +591,8 @@ then the splitting will go deeper too.
splitThunk :: DynFlags -> FamInstEnvs -> RecFlag -> Var -> Expr Var -> UniqSM [(Var, Expr Var)]
splitThunk dflags fam_envs is_rec fn_id rhs
- = do { (useful,_, wrap_fn, work_fn) <- mkWWstr dflags fam_envs [fn_id]
+ = ASSERT(not (isJoinId fn_id))
+ do { (useful,_, wrap_fn, work_fn) <- mkWWstr dflags fam_envs [fn_id]
; let res = [ (fn_id, Let (NonRec fn_id rhs) (wrap_fn (work_fn (Var fn_id)))) ]
; if useful then ASSERT2( isNonRec is_rec, ppr fn_id ) -- The thunk must be non-recursive
return res
diff --git a/compiler/stranal/WwLib.hs b/compiler/stranal/WwLib.hs
index fd0826c5fd..8b02ba0862 100644
--- a/compiler/stranal/WwLib.hs
+++ b/compiler/stranal/WwLib.hs
@@ -16,10 +16,11 @@ module WwLib ( mkWwBodies, mkWWstr, mkWorkerArgs
import CoreSyn
import CoreUtils ( exprType, mkCast )
import Id
-import IdInfo ( vanillaIdInfo )
+import IdInfo ( JoinArity, vanillaIdInfo )
import DataCon
import Demand
-import MkCore ( mkRuntimeErrorApp, aBSENT_ERROR_ID, mkCoreUbxTup )
+import MkCore ( mkRuntimeErrorApp, aBSENT_ERROR_ID, mkCoreUbxTup
+ , mkCoreApp, mkCoreLet )
import MkId ( voidArgId, voidPrimId )
import TysPrim ( voidPrimTy )
import TysWiredIn ( tupleDataCon )
@@ -112,6 +113,7 @@ the unusable strictness-info into the interfaces.
type WwResult
= ([Demand], -- Demands for worker (value) args
+ JoinArity, -- Number of worker (type OR value) args
Id -> CoreExpr, -- Wrapper body, lacking only the worker Id
CoreExpr -> CoreExpr) -- Worker body, lacking the original function rhs
@@ -119,6 +121,7 @@ mkWwBodies :: DynFlags
-> FamInstEnvs
-> VarSet -- Free vars of RHS
-- See Note [Freshen WW arguments]
+ -> Maybe JoinArity -- Just ar <=> is join point with join arity ar
-> Type -- Type of original function
-> [Demand] -- Strictness of original function
-> DmdResult -- Info about function result
@@ -135,7 +138,7 @@ mkWwBodies :: DynFlags
-- let x = (a,b) in
-- E
-mkWwBodies dflags fam_envs rhs_fvs fun_ty demands res_info
+mkWwBodies dflags fam_envs rhs_fvs mb_join_arity fun_ty demands res_info
= do { let empty_subst = mkEmptyTCvSubst (mkInScopeSet rhs_fvs)
-- See Note [Freshen WW arguments]
@@ -152,8 +155,10 @@ mkWwBodies dflags fam_envs rhs_fvs fun_ty demands res_info
worker_body = mkLams work_lam_args. work_fn_str . work_fn_cpr . work_fn_args
; if isWorkerSmallEnough dflags work_args
+ && not (too_many_args_for_join_point wrap_args)
&& (useful1 && not only_one_void_argument || useful2)
- then return (Just (worker_args_dmds, wrapper_body, worker_body))
+ then return (Just (worker_args_dmds, length work_call_args,
+ wrapper_body, worker_body))
else return Nothing
}
-- We use an INLINE unconditionally, even if the wrapper turns out to be
@@ -173,6 +178,17 @@ mkWwBodies dflags fam_envs rhs_fvs fun_ty demands res_info
| otherwise
= False
+ -- Note [Join points returning functions]
+ too_many_args_for_join_point wrap_args
+ | Just join_arity <- mb_join_arity
+ , wrap_args `lengthExceeds` join_arity
+ = WARN(True, text "Unable to worker/wrapper join point with arity " <+>
+ int join_arity <+> text "but" <+>
+ int (length wrap_args) <+> text "args")
+ True
+ | otherwise
+ = False
+
-- See Note [Limit w/w arity]
isWorkerSmallEnough :: DynFlags -> [Var] -> Bool
isWorkerSmallEnough dflags vars = count isId vars <= maxWorkerArgs dflags
@@ -264,6 +280,67 @@ create a space leak. 2) It can prevent inlining *under a lambda*. If w/w
removes the last argument from a function f, then f now looks like a thunk, and
so f can't be inlined *under a lambda*.
+Note [Join points and beta-redexes]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Originally, the worker would invoke the original function by calling it with
+arguments, thus producing a beta-redex for the simplifier to munch away:
+
+ \x y z -> e => (\x y z -> e) wx wy wz
+
+Now that we have special rules about join points, however, this is Not Good if
+the original function is itself a join point, as then it may contain invocations
+of other join points:
+
+ join j1 x = ...
+ join j2 y = if y == 0 then 0 else j1 y
+
+ =>
+
+ join j1 x = ...
+ join $wj2 y# = let wy = I# y# in (\y -> if y == 0 then 0 else jump j1 y) wy
+ join j2 y = case y of I# y# -> jump $wj2 y#
+
+There can't be an intervening lambda between a join point's declaration and its
+occurrences, so $wj2 here is wrong. But of course, this is easy enough to fix:
+
+ ...
+ let join $wj2 y# = let wy = I# y# in let y = wy in if y == 0 then 0 else j1 y
+ ...
+
+Hence we simply do the beta-reduction here. (This would be harder if we had to
+worry about hygiene, but luckily wy is freshly generated.)
+
+Note [Join points returning functions]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+It is crucial that the arity of a join point depends on its *callers,* not its
+own syntax. What this means is that a join point can have "extra lambdas":
+
+f :: Int -> Int -> (Int, Int) -> Int
+f x y = join j (z, w) = \(u, v) -> ...
+ in jump j (x, y)
+
+Typically this happens with functions that are seen as computing functions,
+rather than being curried. (The real-life example was GraphOps.addConflicts.)
+
+When we create the wrapper, it *must* be in "eta-contracted" form so that the
+jump has the right number of arguments:
+
+f x y = join $wj z' w' = \u' v' -> let {z = z'; w = w'; u = u'; v = v'} in ...
+ j (z, w) = jump $wj z w
+
+(See Note [Join points and beta-redexes] for where the lets come from.) If j
+were a function, we would instead say
+
+f x y = let $wj = \z' w' u' v' -> let {z = z'; w = w'; u = u'; v = v'} in ...
+ j (z, w) (u, v) = $wj z w u v
+
+Notice that the worker ends up with the same lambdas; it's only the wrapper we
+have to be concerned about.
+
+FIXME Currently the functionality to produce "eta-contracted" wrappers is
+unimplemented; we simply give up.
************************************************************************
* *
@@ -324,7 +401,7 @@ mkWWargs subst fun_ty demands
<- mkWWargs subst fun_ty' demands'
; return (id : wrap_args,
Lam id . wrap_fn_args,
- work_fn_args . (`App` varToCoreExpr id),
+ apply_or_bind_then work_fn_args (varToCoreExpr id),
res_ty) }
| Just (tv, fun_ty') <- splitForAllTy_maybe fun_ty
@@ -335,7 +412,7 @@ mkWWargs subst fun_ty demands
<- mkWWargs subst' fun_ty' demands
; return (tv' : wrap_args,
Lam tv' . wrap_fn_args,
- work_fn_args . (`mkTyApps` [mkTyVarTy tv']),
+ apply_or_bind_then work_fn_args (mkTyArg (mkTyVarTy tv')),
res_ty) }
| Just (co, rep_ty) <- topNormaliseNewType_maybe fun_ty
@@ -358,7 +435,12 @@ mkWWargs subst fun_ty demands
| otherwise
= WARN( True, ppr fun_ty ) -- Should not happen: if there is a demand
return ([], id, id, substTy subst fun_ty) -- then there should be a function arrow
-
+ where
+ -- See Note [Join points and beta-redexes]
+ apply_or_bind_then k arg (Lam bndr body)
+ = mkCoreLet (NonRec bndr arg) (k body) -- Important that arg is fresh!
+ apply_or_bind_then k arg fun
+ = k $ mkCoreApp (text "mkWWargs") fun arg
applyToVars :: [Var] -> CoreExpr -> CoreExpr
applyToVars vars fn = mkVarApps fn vars
diff --git a/compiler/types/Type.hs b/compiler/types/Type.hs
index ad1b11f625..86d6eaecd0 100644
--- a/compiler/types/Type.hs
+++ b/compiler/types/Type.hs
@@ -58,6 +58,8 @@ module Type (
filterOutInvisibleTyVars, partitionInvisibles,
synTyConResKind,
+ modifyJoinResTy, setJoinResTy,
+
-- Analyzing types
TyCoMapper(..), mapType, mapCoercion,
@@ -101,6 +103,8 @@ module Type (
isCoercionTy_maybe, isCoercionType, isForAllTy,
isPiTy, isTauTy, isFamFreeTy,
+ isValidJoinPointType,
+
-- (Lifting and boxity)
isLiftedType_maybe, isUnliftedType, isUnboxedTupleType, isUnboxedSumType,
isAlgType, isClosedAlgType,
@@ -197,6 +201,8 @@ module Type (
#include "HsVersions.h"
+import BasicTypes
+
-- We import the representation and primitive functions from TyCoRep.
-- Many things are reexported, but not the representation!
@@ -1952,6 +1958,43 @@ isPrimitiveType ty = case splitTyConApp_maybe ty of
{-
************************************************************************
* *
+\subsection{Join points}
+* *
+************************************************************************
+-}
+
+-- | Determine whether a type could be the type of a join point of given total
+-- arity, according to the polymorphism rule. A join point cannot be polymorphic
+-- in its return type, since given
+-- join j @a @b x y z = e1 in e2,
+-- the types of e1 and e2 must be the same, and a and b are not in scope for e2.
+-- (See Note [The polymorphism rule of join points] in CoreSyn.) Returns False
+-- also if the type simply doesn't have enough arguments.
+--
+-- Note that we need to know how many arguments (type *and* value) the putative
+-- join point takes; for instance, if
+-- j :: forall a. a -> Int
+-- then j could be a binary join point returning an Int, but it could *not* be a
+-- unary join point returning a -> Int.
+--
+-- TODO: See Note [Excess polymorphism and join points]
+isValidJoinPointType :: JoinArity -> Type -> Bool
+isValidJoinPointType arity ty
+ = valid_under emptyVarSet arity ty
+ where
+ valid_under tvs arity ty
+ | arity == 0
+ = isEmptyVarSet (tvs `intersectVarSet` tyCoVarsOfType ty)
+ | Just (t, ty') <- splitForAllTy_maybe ty
+ = valid_under (tvs `extendVarSet` t) (arity-1) ty'
+ | Just (_, res_ty) <- splitFunTy_maybe ty
+ = valid_under tvs (arity-1) res_ty
+ | otherwise
+ = False
+
+{-
+************************************************************************
+* *
\subsection{Sequencing on types}
* *
************************************************************************
@@ -2303,3 +2346,26 @@ splitVisVarsOfType orig_ty = Pair invis_vars vis_vars
splitVisVarsOfTypes :: [Type] -> Pair TyCoVarSet
splitVisVarsOfTypes = foldMap splitVisVarsOfType
+
+modifyJoinResTy :: Int -- Number of binders to skip
+ -> (Type -> Type) -- Function to apply to result type
+ -> Type -- Type of join point
+ -> Type -- New type
+-- INVARIANT: If any of the first n binders are foralls, those tyvars cannot
+-- appear in the original result type. See isValidJoinPointType.
+modifyJoinResTy orig_ar f orig_ty
+ = go orig_ar orig_ty
+ where
+ go 0 ty = f ty
+ go n ty | Just (arg_bndr, res_ty) <- splitPiTy_maybe ty
+ = mkPiTy arg_bndr (go (n-1) res_ty)
+ | otherwise
+ = pprPanic "modifyJoinResTy" (ppr orig_ar <+> ppr orig_ty)
+
+setJoinResTy :: Int -- Number of binders to skip
+ -> Type -- New result type
+ -> Type -- Type of join point
+ -> Type -- New type
+-- INVARIANT: Same as for modifyJoinResTy
+setJoinResTy ar new_res_ty ty
+ = modifyJoinResTy ar (const new_res_ty) ty
diff --git a/compiler/utils/Outputable.hs b/compiler/utils/Outputable.hs
index 7d79f93bb2..3f94a68413 100644
--- a/compiler/utils/Outputable.hs
+++ b/compiler/utils/Outputable.hs
@@ -944,6 +944,19 @@ class Outputable a => OutputableBndr a where
-- prefix position of an application, thus (f a b) or ((+) x)
-- or infix position, thus (a `f` b) or (x + y)
+ pprNonRecBndrKeyword, pprRecBndrKeyword :: a -> SDoc
+ -- Print which keyword introduces the binder in Core code. This should be
+ -- "let" or "letrec" for a value but "join" or "joinrec" for a join point.
+ pprNonRecBndrKeyword _ = text "let"
+ pprRecBndrKeyword _ = text "letrec"
+
+ pprLamsOnLhs :: a -> Int
+ -- For a join point of join arity n, we want to print j = \x1 ... xn -> e
+ -- as "j x1 ... xn = e" to differentiate when a join point returns a
+ -- lambda (the first rendering looks like a nullary join point returning
+ -- an n-argument function).
+ pprLamsOnLhs _ = 0
+
{-
************************************************************************
* *
diff --git a/compiler/utils/UniqFM.hs b/compiler/utils/UniqFM.hs
index 38d94342ad..49ceb89d90 100644
--- a/compiler/utils/UniqFM.hs
+++ b/compiler/utils/UniqFM.hs
@@ -49,6 +49,7 @@ module UniqFM (
plusUFM,
plusUFM_C,
plusUFM_CD,
+ plusMaybeUFM_C,
plusUFMList,
minusUFM,
intersectUFM,
@@ -217,6 +218,15 @@ plusUFM_CD f (UFM xm) dx (UFM ym) dy
(M.map (\y -> dx `f` y))
xm ym
+plusMaybeUFM_C :: (elt -> elt -> Maybe elt)
+ -> UniqFM elt -> UniqFM elt -> UniqFM elt
+plusMaybeUFM_C f (UFM xm) (UFM ym)
+ = UFM $ M.mergeWithKey
+ (\_ x y -> x `f` y)
+ id
+ id
+ xm ym
+
plusUFMList :: [UniqFM elt] -> UniqFM elt
plusUFMList = foldl' plusUFM emptyUFM
diff --git a/testsuite/tests/deSugar/should_compile/T2431.stderr b/testsuite/tests/deSugar/should_compile/T2431.stderr
index a8da44b73f..83826408cf 100644
--- a/testsuite/tests/deSugar/should_compile/T2431.stderr
+++ b/testsuite/tests/deSugar/should_compile/T2431.stderr
@@ -1,8 +1,9 @@
==================== Tidy Core ====================
-Result size of Tidy Core = {terms: 44, types: 34, coercions: 1}
+Result size of Tidy Core
+ = {terms: 44, types: 34, coercions: 1, joins: 0/0}
--- RHS size: {terms: 2, types: 4, coercions: 1}
+-- RHS size: {terms: 2, types: 4, coercions: 1, joins: 0/0}
T2431.$WRefl [InlPrag=INLINE] :: forall a. a :~: a
[GblId[DataConWrapper],
Caf=NoCafRefs,
@@ -16,47 +17,47 @@ T2431.$WRefl =
\ (@ a) ->
T2431.Refl @ a @ a @~ (<a>_N :: (a :: *) GHC.Prim.~# (a :: *))
--- RHS size: {terms: 4, types: 8, coercions: 0}
+-- RHS size: {terms: 4, types: 8, coercions: 0, joins: 0/0}
absurd :: forall a. (Int :~: Bool) -> a
[GblId, Arity=1, Caf=NoCafRefs, Str=<L,U>x]
absurd = \ (@ a) (x :: Int :~: Bool) -> case x of { }
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
$trModule1 :: GHC.Prim.Addr#
[GblId, Caf=NoCafRefs]
$trModule1 = "main"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
$trModule2 :: GHC.Types.TrName
[GblId, Caf=NoCafRefs]
$trModule2 = GHC.Types.TrNameS $trModule1
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
$trModule3 :: GHC.Prim.Addr#
[GblId, Caf=NoCafRefs]
$trModule3 = "T2431"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
$trModule4 :: GHC.Types.TrName
[GblId, Caf=NoCafRefs]
$trModule4 = GHC.Types.TrNameS $trModule3
--- RHS size: {terms: 3, types: 0, coercions: 0}
+-- RHS size: {terms: 3, types: 0, coercions: 0, joins: 0/0}
T2431.$trModule :: GHC.Types.Module
[GblId, Caf=NoCafRefs]
T2431.$trModule = GHC.Types.Module $trModule2 $trModule4
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
$tc'Refl1 :: GHC.Prim.Addr#
[GblId, Caf=NoCafRefs]
$tc'Refl1 = "'Refl"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
$tc'Refl2 :: GHC.Types.TrName
[GblId, Caf=NoCafRefs]
$tc'Refl2 = GHC.Types.TrNameS $tc'Refl1
--- RHS size: {terms: 5, types: 0, coercions: 0}
+-- RHS size: {terms: 5, types: 0, coercions: 0, joins: 0/0}
T2431.$tc'Refl :: GHC.Types.TyCon
[GblId, Caf=NoCafRefs]
T2431.$tc'Refl =
@@ -66,17 +67,17 @@ T2431.$tc'Refl =
T2431.$trModule
$tc'Refl2
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
$tc:~:1 :: GHC.Prim.Addr#
[GblId, Caf=NoCafRefs]
$tc:~:1 = ":~:"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
$tc:~:2 :: GHC.Types.TrName
[GblId, Caf=NoCafRefs]
$tc:~:2 = GHC.Types.TrNameS $tc:~:1
--- RHS size: {terms: 5, types: 0, coercions: 0}
+-- RHS size: {terms: 5, types: 0, coercions: 0, joins: 0/0}
T2431.$tc:~: :: GHC.Types.TyCon
[GblId, Caf=NoCafRefs]
T2431.$tc:~: =
diff --git a/testsuite/tests/deriving/perf/all.T b/testsuite/tests/deriving/perf/all.T
index 0c3e9a4d3e..4c4bb97101 100644
--- a/testsuite/tests/deriving/perf/all.T
+++ b/testsuite/tests/deriving/perf/all.T
@@ -1,6 +1,8 @@
test('T10858',
[compiler_stats_num_field('bytes allocated',
- [ (wordsize(64), 222312440, 8) ]),
+ [ (wordsize(64), 247768192, 8) ]),
+ # Initial: 222312440
+ # 2016-12-19 247768192 Join points (#19288)
only_ways(['normal'])],
compile,
['-O'])
diff --git a/testsuite/tests/numeric/should_compile/T7116.stdout b/testsuite/tests/numeric/should_compile/T7116.stdout
index 7fe4d93d87..bc2f85b85f 100644
--- a/testsuite/tests/numeric/should_compile/T7116.stdout
+++ b/testsuite/tests/numeric/should_compile/T7116.stdout
@@ -1,8 +1,9 @@
==================== Tidy Core ====================
-Result size of Tidy Core = {terms: 50, types: 25, coercions: 0}
+Result size of Tidy Core
+ = {terms: 50, types: 25, coercions: 0, joins: 0/0}
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
T7116.$trModule4 :: GHC.Prim.Addr#
[GblId,
Caf=NoCafRefs,
@@ -10,7 +11,7 @@ T7116.$trModule4 :: GHC.Prim.Addr#
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 20 0}]
T7116.$trModule4 = "main"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
T7116.$trModule3 :: GHC.Types.TrName
[GblId,
Caf=NoCafRefs,
@@ -19,7 +20,7 @@ T7116.$trModule3 :: GHC.Types.TrName
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 20}]
T7116.$trModule3 = GHC.Types.TrNameS T7116.$trModule4
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
T7116.$trModule2 :: GHC.Prim.Addr#
[GblId,
Caf=NoCafRefs,
@@ -27,7 +28,7 @@ T7116.$trModule2 :: GHC.Prim.Addr#
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 30 0}]
T7116.$trModule2 = "T7116"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
T7116.$trModule1 :: GHC.Types.TrName
[GblId,
Caf=NoCafRefs,
@@ -36,7 +37,7 @@ T7116.$trModule1 :: GHC.Types.TrName
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 20}]
T7116.$trModule1 = GHC.Types.TrNameS T7116.$trModule2
--- RHS size: {terms: 3, types: 0, coercions: 0}
+-- RHS size: {terms: 3, types: 0, coercions: 0, joins: 0/0}
T7116.$trModule :: GHC.Types.Module
[GblId,
Caf=NoCafRefs,
@@ -46,7 +47,7 @@ T7116.$trModule :: GHC.Types.Module
T7116.$trModule =
GHC.Types.Module T7116.$trModule3 T7116.$trModule1
--- RHS size: {terms: 8, types: 3, coercions: 0}
+-- RHS size: {terms: 8, types: 3, coercions: 0, joins: 0/0}
dr :: Double -> Double
[GblId,
Arity=1,
@@ -63,7 +64,7 @@ dr =
\ (x :: Double) ->
case x of { GHC.Types.D# x1 -> GHC.Types.D# (GHC.Prim.+## x1 x1) }
--- RHS size: {terms: 8, types: 3, coercions: 0}
+-- RHS size: {terms: 8, types: 3, coercions: 0, joins: 0/0}
dl :: Double -> Double
[GblId,
Arity=1,
@@ -78,7 +79,7 @@ dl =
\ (x :: Double) ->
case x of { GHC.Types.D# y -> GHC.Types.D# (GHC.Prim.+## y y) }
--- RHS size: {terms: 8, types: 3, coercions: 0}
+-- RHS size: {terms: 8, types: 3, coercions: 0, joins: 0/0}
fr :: Float -> Float
[GblId,
Arity=1,
@@ -97,7 +98,7 @@ fr =
GHC.Types.F# (GHC.Prim.plusFloat# x1 x1)
}
--- RHS size: {terms: 8, types: 3, coercions: 0}
+-- RHS size: {terms: 8, types: 3, coercions: 0, joins: 0/0}
fl :: Float -> Float
[GblId,
Arity=1,
diff --git a/testsuite/tests/perf/compiler/all.T b/testsuite/tests/perf/compiler/all.T
index d9b0be509a..822ccb0026 100644
--- a/testsuite/tests/perf/compiler/all.T
+++ b/testsuite/tests/perf/compiler/all.T
@@ -67,7 +67,7 @@ test('T1969',
# 2014-06-29 5949188 (x86/Linux)
# 2015-07-11 6241108 (x86/Linux, 64bit machine) use +RTS -G1
# 2016-04-06 9093608 (x86/Linux, 64bit machine)
- (wordsize(64), 17285216, 15)]),
+ (wordsize(64), 19924328, 15)]),
# 2014-09-10 10463640, 10 # post-AMP-update (somewhat stabelish)
# looks like the peak is around ~10M, but we're
# unlikely to GC exactly on the peak.
@@ -79,6 +79,7 @@ test('T1969',
# 2015-07-11 11670120 (amd64/Linux)
# 2015-10-28 15017528 (amd64/Linux) emit typeable at definition site
# 2016-10-12 17285216 (amd64/Linux) it's not entirely clear why
+ # 2017-02-01 19924328 (amd64/Linux) Join points (#12988)
compiler_stats_num_field('bytes allocated',
[(platform('i386-unknown-mingw32'), 301784492, 5),
# 215582916 (x86/Windows)
@@ -439,9 +440,10 @@ test('T5631',
test('parsing001',
[compiler_stats_num_field('bytes allocated',
[(wordsize(32), 274000576, 10),
- (wordsize(64), 581551384, 5)]),
+ (wordsize(64), 493730288, 5)]),
# expected value: 587079016 (amd64/Linux)
# 2016-09-01: 581551384 (amd64/Linux) Restore w/w limit (#11565)
+ # 2016-12-19: 493730288 (amd64/Linux) Join points (#12988)
only_ways(['normal']),
],
compile_fail, [''])
@@ -503,7 +505,7 @@ test('T5321Fun',
# 2014-09-03: 299656164 (specialisation and inlining)
# 10/12/2014: 206406188 # Improvements in constraint solver
# 2016-04-06: 279922360 x86/Linux
- (wordsize(64), 525895608, 5)])
+ (wordsize(64), 498135752, 5)])
# prev: 585521080
# 29/08/2012: 713385808 # (increase due to new codegen)
# 15/05/2013: 628341952 # (reason for decrease unknown)
@@ -526,6 +528,7 @@ test('T5321Fun',
# change, however. Namely I am
# quite skeptical of the downward
# "drift" reported above
+ # 31/01/2017: 498135752 # Join points (#12988)
],
compile,[''])
@@ -802,7 +805,7 @@ test('T9872d',
test('T9961',
[ only_ways(['normal']),
compiler_stats_num_field('bytes allocated',
- [(wordsize(64), 537297968, 5),
+ [(wordsize(64), 571246936, 5),
# 2015-01-12 807117816 Initally created
# 2015-spring 772510192 Got better
# 2015-05-22 663978160 Fix for #10370 improves it more
@@ -811,6 +814,7 @@ test('T9961',
# 2016-03-20 519436672 x64_64/Linux Don't use build desugaring for large lists (#11707)
# 2016-03-24 568526784 x64_64/Linux Add eqInt* variants (#11688)
# 2016-09-01 537297968 x64_64/Linux Restore w/w limit (#11565)
+ # 2016-12-19 571246936 x64_64/Linux Join points (#12988)
(wordsize(32), 275264188, 5)
# was 375647160
# 2016-04-06 275264188 x86/Linux
@@ -934,8 +938,9 @@ test('T13035',
test('T13056',
[ only_ways(['optasm']),
compiler_stats_num_field('bytes allocated',
- [(wordsize(64), 520166912, 5),
+ [(wordsize(64), 546800240, 5),
# 2017-01-06 520166912 initial
+ # 2017-01-31 546800240 Join points (#12988)
]),
],
compile,
@@ -943,9 +948,10 @@ test('T13056',
test('T12707',
[ compiler_stats_num_field('bytes allocated',
- [(wordsize(64), 1348865648, 5),
+ [(wordsize(64), 1280336112, 5),
# initial: 1271577192
# 2017-01-22: 1348865648 Allow top-level strings in Core
+ # 2017-01-31: 1280336112 Join points (#12988)
]),
],
compile,
diff --git a/testsuite/tests/perf/haddock/all.T b/testsuite/tests/perf/haddock/all.T
index 8ec02cefcc..f037954263 100644
--- a/testsuite/tests/perf/haddock/all.T
+++ b/testsuite/tests/perf/haddock/all.T
@@ -5,7 +5,7 @@
test('haddock.base',
[unless(in_tree_compiler(), skip), req_haddock
,stats_num_field('bytes allocated',
- [(wordsize(64), 32855223200, 5)
+ [(wordsize(64), 31115778088 , 5)
# 2012-08-14: 5920822352 (amd64/Linux)
# 2012-09-20: 5829972376 (amd64/Linux)
# 2012-10-08: 5902601224 (amd64/Linux)
@@ -30,6 +30,7 @@ test('haddock.base',
# 2015-12-17: 27812188000 (x86_64/Linux) - Move Data.Functor.* into base
# 2016-02-25: 30987348040 (x86_64/Linux) - RuntimeRep
# 2016-05-12: 32855223200 (x86_64/Linux) - Make Generic1 poly-kinded
+ # 2017-01-11: 31115778088 (x86_64/Linux) - Join points (#12988)
,(platform('i386-unknown-mingw32'), 4434804940, 5)
# 2013-02-10: 3358693084 (x86/Windows)
@@ -52,7 +53,7 @@ test('haddock.base',
test('haddock.Cabal',
[unless(in_tree_compiler(), skip), req_haddock
,stats_num_field('bytes allocated',
- [(wordsize(64), 25478853176 , 5)
+ [(wordsize(64), 23272708864, 5)
# 2012-08-14: 3255435248 (amd64/Linux)
# 2012-08-29: 3324606664 (amd64/Linux, new codegen)
# 2012-10-08: 3373401360 (amd64/Linux)
@@ -92,6 +93,7 @@ test('haddock.Cabal',
# 2016-10-03: 21554874976 (amd64/Linux) - Cabal update
# 2016-10-06: 23706190072 (amd64/Linux) - Cabal update
# 2016-12-20: 25478853176 (amd64/Linux) - Cabal update
+ # 2017-01-14: 23272708864 (amd64/Linux) - Join points (#12988)
,(platform('i386-unknown-mingw32'), 3293415576, 5)
# 2012-10-30: 1733638168 (x86/Windows)
diff --git a/testsuite/tests/perf/join_points/Makefile b/testsuite/tests/perf/join_points/Makefile
new file mode 100644
index 0000000000..9101fbd40a
--- /dev/null
+++ b/testsuite/tests/perf/join_points/Makefile
@@ -0,0 +1,3 @@
+TOP=../../..
+include $(TOP)/mk/boilerplate.mk
+include $(TOP)/mk/test.mk
diff --git a/testsuite/tests/perf/join_points/all.T b/testsuite/tests/perf/join_points/all.T
new file mode 100644
index 0000000000..b6f6e40699
--- /dev/null
+++ b/testsuite/tests/perf/join_points/all.T
@@ -0,0 +1,28 @@
+# Only compile with optimisation
+def f( name, opts ):
+ opts.only_ways = ['optasm']
+
+setTestOpts(f)
+
+test('join001', normal, compile, [''])
+
+test('join002',
+ [stats_num_field('bytes allocated', [(wordsize(64), 2000290792, 5)])],
+ compile_and_run,
+ [''])
+test('join003',
+ [stats_num_field('bytes allocated', [(wordsize(64), 2000290792, 5)])],
+ compile_and_run,
+ [''])
+test('join004',
+ [stats_num_field('bytes allocated', [(wordsize(64), 48146720, 5)])],
+ compile_and_run,
+ [''])
+
+test('join005', normal, compile, [''])
+test('join006', normal, compile, [''])
+
+test('join007',
+ [stats_num_field('bytes allocated', [(wordsize(64), 50944, 5)])],
+ compile_and_run,
+ [''])
diff --git a/testsuite/tests/perf/join_points/join001.hs b/testsuite/tests/perf/join_points/join001.hs
new file mode 100644
index 0000000000..04dce36dc9
--- /dev/null
+++ b/testsuite/tests/perf/join_points/join001.hs
@@ -0,0 +1,16 @@
+{-# LANGUAGE BangPatterns #-}
+
+module Main where
+
+findDivBy :: Int -> [Int] -> Maybe Int
+findDivBy p ns
+ -- go should be a join point and should get worker/wrappered; the worker
+ -- must also be a join point (since it's mutually recursive with one).
+ = let go !p ns = case ns of n:ns' -> case n `mod` p of 0 -> Just n
+ _ -> go p ns'
+ [] -> Nothing
+ in case p of
+ 0 -> error "div by zero"
+ _ -> go p ns
+
+main = print $ findDivBy 7 [1..10]
diff --git a/testsuite/tests/perf/join_points/join002.hs b/testsuite/tests/perf/join_points/join002.hs
new file mode 100644
index 0000000000..49aecdcc7f
--- /dev/null
+++ b/testsuite/tests/perf/join_points/join002.hs
@@ -0,0 +1,51 @@
+module Main where
+
+import Data.List
+
+-- These four functions should all wind up the same; they represent successive
+-- simplifications that should happen. (Actual details may vary, since find
+-- isn't quite defined this way, but the differences disappear by the end.)
+
+firstEvenIsPositive1 :: [Int] -> Bool
+firstEvenIsPositive1 = maybe False (> 0) . find even
+
+-- After inlining:
+
+firstEvenIsPositive2 :: [Int] -> Bool
+firstEvenIsPositive2 xs =
+ let go xs = case xs of x:xs' -> if even x then Just x else go xs'
+ [] -> Nothing
+ in case go xs of Just n -> n > 0
+ Nothing -> False
+
+-- Note that go *could* be a join point if it were declared inside the scrutinee
+-- instead of outside. So it's now Float In's job to move the binding inward a
+-- smidge. *But* if it goes too far inward (as it would until recently), it will
+-- wrap only "go" instead of "go xs", which won't let us mark go as a join point
+-- since join points can't be partially invoked.
+--
+-- After Float In:
+
+firstEvenIsPositive3 :: [Int] -> Bool
+firstEvenIsPositive3 xs =
+ case let {-join-} go xs = case xs of x:xs' -> if even x then Just x
+ else go xs'
+ [] -> Nothing
+ in go xs of
+ Just n -> n > 0
+ Nothing -> False
+
+-- After the simplifier:
+
+firstEvenIsPositive4 :: [Int] -> Bool
+firstEvenIsPositive4 xs =
+ let {-join-} go xs = case xs of x:xs' -> if even x then x > 0 else go xs'
+ [] -> False
+ in go xs
+
+-- This only worked because go was a join point so that the case gets moved
+-- inside.
+
+{-# NOINLINE firstEvenIsPositive1 #-}
+
+main = print $ or $ [firstEvenIsPositive1 [1,3..n] | n <- [1..10000]]
diff --git a/testsuite/tests/perf/join_points/join002.stdout b/testsuite/tests/perf/join_points/join002.stdout
new file mode 100644
index 0000000000..bc59c12aa1
--- /dev/null
+++ b/testsuite/tests/perf/join_points/join002.stdout
@@ -0,0 +1 @@
+False
diff --git a/testsuite/tests/perf/join_points/join003.hs b/testsuite/tests/perf/join_points/join003.hs
new file mode 100644
index 0000000000..051c2d8bfe
--- /dev/null
+++ b/testsuite/tests/perf/join_points/join003.hs
@@ -0,0 +1,69 @@
+{-
+ - A variation on join002.hs that avoids Float Out issues. The join points in
+ - join002.hs may get floated to top level, which is necessary to allow in
+ - general, but which makes them into functions rather than join points, thus
+ - messing up the test.
+ -}
+
+module Main (
+ firstMultIsPositive1, firstMultIsPositive2, firstMultIsPositive3,
+ firstMultIsPositive4,
+
+ main
+) where
+
+import Data.List
+
+divides :: Int -> Int -> Bool
+p `divides` n = n `mod` p == 0
+
+infix 4 `divides`
+
+-- These four functions should all wind up the same; they represent successive
+-- simplifications that should happen. (Actual details may vary, since find
+-- isn't quite defined this way, but the differences disappear by the end.)
+
+firstMultIsPositive1 :: Int -> [Int] -> Bool
+firstMultIsPositive1 p = maybe False (> 0) . find (p `divides`)
+
+-- After inlining:
+
+firstMultIsPositive2 :: Int -> [Int] -> Bool
+firstMultIsPositive2 p xs =
+ let go xs = case xs of x:xs' -> if p `divides` x then Just x else go xs'
+ [] -> Nothing
+ in case go xs of Just n -> n > 0
+ Nothing -> False
+
+-- Note that go *could* be a join point if it were declared inside the scrutinee
+-- instead of outside. So it's now Float In's job to move the binding inward a
+-- smidge. *But* if it goes too far inward (as it would until recently), it will
+-- wrap only "go" instead of "go xs", which won't let us mark go as a join point
+-- since join points can't be partially invoked.
+--
+-- After Float In:
+
+firstMultIsPositive3 :: Int -> [Int] -> Bool
+firstMultIsPositive3 p xs =
+ case let {-join-} go xs = case xs of x:xs' -> if p `divides` x then Just x
+ else go xs'
+ [] -> Nothing
+ in go xs of
+ Just n -> n > 0
+ Nothing -> False
+
+-- After the simplifier:
+
+firstMultIsPositive4 :: Int -> [Int] -> Bool
+firstMultIsPositive4 p xs =
+ let {-join-} go xs = case xs of x:xs' -> if p `divides` x then x > 0
+ else go xs'
+ [] -> False
+ in go xs
+
+-- This only worked because go was a join point so that the case gets moved
+-- inside.
+
+{-# NOINLINE firstMultIsPositive1 #-}
+
+main = print $ or $ [firstMultIsPositive1 2 [1,3..n] | n <- [1..10000]]
diff --git a/testsuite/tests/perf/join_points/join003.stdout b/testsuite/tests/perf/join_points/join003.stdout
new file mode 100644
index 0000000000..bc59c12aa1
--- /dev/null
+++ b/testsuite/tests/perf/join_points/join003.stdout
@@ -0,0 +1 @@
+False
diff --git a/testsuite/tests/perf/join_points/join004.hs b/testsuite/tests/perf/join_points/join004.hs
new file mode 100644
index 0000000000..1962cc266e
--- /dev/null
+++ b/testsuite/tests/perf/join_points/join004.hs
@@ -0,0 +1,30 @@
+{-
+ - A rather contrived example demonstrating the virtues of not floating join
+ - points outward.
+ -}
+
+module Main (main) where
+
+-- Calculate n `div` d `div` d by looping.
+
+{-# NOINLINE slowDivDiv #-}
+slowDivDiv :: Int -> Int -> Int
+slowDivDiv n d
+ = let {-# NOINLINE divPos #-}
+ divPos :: Int -> Int
+ divPos n0
+ = -- This function is a join point (all calls are tail calls), so it
+ -- never causes a closure allocation, so it doesn't help to float it
+ -- out. Thus -fno-join-points causes a ~25% jump in allocations.
+ let go n' i
+ = case n' >= d of True -> go (n' - d) (i + 1)
+ False -> i
+ in go n0 0
+ in case n >= 0 of True -> divPos (divPos n)
+ False -> divPos (-(divPos (-n)))
+ -- It's important that divPos be called twice
+ -- because otherwise it'd be a one-shot lambda
+ -- and so the join point would be floated
+ -- back in again.
+
+main = print $ sum [slowDivDiv n d | n <- [1..1000], d <- [1..1000]]
diff --git a/testsuite/tests/perf/join_points/join004.stdout b/testsuite/tests/perf/join_points/join004.stdout
new file mode 100644
index 0000000000..a3abc727e8
--- /dev/null
+++ b/testsuite/tests/perf/join_points/join004.stdout
@@ -0,0 +1 @@
+793515
diff --git a/testsuite/tests/perf/join_points/join005.hs b/testsuite/tests/perf/join_points/join005.hs
new file mode 100644
index 0000000000..7feec98c28
--- /dev/null
+++ b/testsuite/tests/perf/join_points/join005.hs
@@ -0,0 +1,23 @@
+{- Test of Worker/Wrapper operating on join points -}
+
+module Main where
+
+sumOfMultiplesOf :: Int -> [Int] -> Int
+sumOfMultiplesOf p ns
+ = {- This is a join point (and it will stay that way---it won't get floated to
+ top level because p occurs free). It should get worker/wrappered. -}
+ let go ns acc
+ = case ns of [] -> acc
+ n:ns' -> case n `mod` p of 0 -> go ns' (acc + n)
+ _ -> go ns' acc
+ in go ns 0
+
+{-
+It's hard to test for this, but what should happen is that go gets W/W'd and the
+worker is a join point (else Core Lint will complain). Interestingly, go is
+*not* CPR'd, because then the worker couldn't be a join point, but once the
+simplifier runs, the worker ends up returning Int# anyway. See Note [Don't CPR
+join points] in WorkWrap.hs.
+-}
+
+main = print $ sumOfMultiplesOf 2 [1..10]
diff --git a/testsuite/tests/perf/join_points/join006.hs b/testsuite/tests/perf/join_points/join006.hs
new file mode 100644
index 0000000000..3c0b2ceecd
--- /dev/null
+++ b/testsuite/tests/perf/join_points/join006.hs
@@ -0,0 +1,22 @@
+module Main where
+
+{-# NOINLINE foo #-}
+foo :: Int -> Int
+foo x = x
+
+{-# RULES "foo/5" forall (f :: Int -> Int). foo (f 5) = foo (f 42) #-}
+-- highly suspect, of course!
+
+main = print $ foo (let {-# NOINLINE j #-}
+ j :: Int -> Int
+ j n = n + 1 in j 5)
+
+{-
+If we're not careful, this will get rewritten to
+
+ main = print $ let <join> j n = n + 1 in foo (j 42)
+
+which violates the join point invariant (can't invoke a join point from
+non-tail context). Solution is to refuse to float join points when matching
+RULES.
+-}
diff --git a/testsuite/tests/perf/join_points/join007.hs b/testsuite/tests/perf/join_points/join007.hs
new file mode 100644
index 0000000000..aa2f68c0bc
--- /dev/null
+++ b/testsuite/tests/perf/join_points/join007.hs
@@ -0,0 +1,42 @@
+-- Test of fusion in unfold/destroy style. Originally, unfold/destroy supported
+-- filter, but the constructors (here Done and Yield) couldn't be compiled away.
+-- Join points let us do this by pulling the case from sumS into the loop in
+-- filterS.
+
+{-# LANGUAGE GADTs #-}
+
+module Main (main) where
+
+data Stream a where Stream :: (s -> Step a s) -> s -> Stream a
+data Step a s = Done | Yield a s
+
+{-# INLINE sumS #-}
+sumS :: Stream Int -> Int
+sumS (Stream next s0) = go s0 0
+ where
+ go s acc = case next s of Done -> acc
+ Yield a s' -> go s' (acc + a)
+
+{-# INLINE filterS #-}
+filterS :: (a -> Bool) -> Stream a -> Stream a
+filterS p (Stream next s0) = Stream fnext s0
+ where
+ fnext s = seek s
+ where
+ -- should be a join point!
+ seek s = case next s of Done -> Done
+ Yield a s' | p a -> Yield a s'
+ | otherwise -> seek s'
+
+{-# INLINE enumFromToS #-}
+enumFromToS :: Int -> Int -> Stream Int
+enumFromToS lo hi = Stream next lo
+ where
+ next n | n > hi = Done
+ | otherwise = Yield n (n+1)
+
+{-# NOINLINE test #-} -- for -ddump-simpl
+test :: Int -> Int -> Int
+test lo hi = sumS (filterS even (enumFromToS lo hi))
+
+main = print $ test 1 10000000
diff --git a/testsuite/tests/perf/join_points/join007.stdout b/testsuite/tests/perf/join_points/join007.stdout
new file mode 100644
index 0000000000..0a7ad1072c
--- /dev/null
+++ b/testsuite/tests/perf/join_points/join007.stdout
@@ -0,0 +1 @@
+25000005000000
diff --git a/testsuite/tests/perf/should_run/all.T b/testsuite/tests/perf/should_run/all.T
index 382c317a9b..592e63c274 100644
--- a/testsuite/tests/perf/should_run/all.T
+++ b/testsuite/tests/perf/should_run/all.T
@@ -436,10 +436,11 @@ test('T9203',
[ (wordsize(32), 84345136 , 5)
# was
# 2016-04-06 84345136 (i386/Debian) not sure
- , (wordsize(64), 95451192, 5) ]),
+ , (wordsize(64), 84620888, 5) ]),
# was 95747304
# 2019-09-10 94547280 post-AMP cleanup
# 2015-10-28 95451192 emit Typeable at definition site
+ # 2016-12-19 84620888 Join points
only_ways(['normal'])],
compile_and_run,
['-O2'])
@@ -447,9 +448,10 @@ test('T9203',
test('T9339',
[stats_num_field('bytes allocated',
[ (wordsize(32), 40046844, 5)
- , (wordsize(64), 80050760, 5) ]),
+ , (wordsize(64), 50728, 5) ]),
# w/o fusing last: 320005080
# 2014-07-22: 80050760
+ # 2016-08-17: 50728 Join points (#12988)
only_ways(['normal'])],
compile_and_run,
['-O2'])
diff --git a/testsuite/tests/roles/should_compile/Roles13.stderr b/testsuite/tests/roles/should_compile/Roles13.stderr
index 20206e28df..7e510d442e 100644
--- a/testsuite/tests/roles/should_compile/Roles13.stderr
+++ b/testsuite/tests/roles/should_compile/Roles13.stderr
@@ -1,13 +1,14 @@
==================== Tidy Core ====================
-Result size of Tidy Core = {terms: 63, types: 26, coercions: 5}
+Result size of Tidy Core
+ = {terms: 63, types: 26, coercions: 5, joins: 0/0}
--- RHS size: {terms: 2, types: 2, coercions: 0}
+-- RHS size: {terms: 2, types: 2, coercions: 0, joins: 0/0}
convert1 :: Wrap Age -> Wrap Age
[GblId, Arity=1, Caf=NoCafRefs]
convert1 = \ (ds :: Wrap Age) -> ds
--- RHS size: {terms: 1, types: 0, coercions: 5}
+-- RHS size: {terms: 1, types: 0, coercions: 5, joins: 0/0}
convert :: Wrap Age -> Int
[GblId, Arity=1, Caf=NoCafRefs]
convert =
@@ -15,42 +16,42 @@ convert =
`cast` (<Wrap Age>_R -> Roles13.N:Wrap[0] Roles13.N:Age[0]
:: ((Wrap Age -> Wrap Age) :: *) ~R# ((Wrap Age -> Int) :: *))
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
$trModule1 :: GHC.Prim.Addr#
[GblId, Caf=NoCafRefs]
$trModule1 = "main"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
$trModule2 :: GHC.Types.TrName
[GblId, Caf=NoCafRefs]
$trModule2 = GHC.Types.TrNameS $trModule1
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
$trModule3 :: GHC.Prim.Addr#
[GblId, Caf=NoCafRefs]
$trModule3 = "Roles13"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
$trModule4 :: GHC.Types.TrName
[GblId, Caf=NoCafRefs]
$trModule4 = GHC.Types.TrNameS $trModule3
--- RHS size: {terms: 3, types: 0, coercions: 0}
+-- RHS size: {terms: 3, types: 0, coercions: 0, joins: 0/0}
Roles13.$trModule :: GHC.Types.Module
[GblId, Caf=NoCafRefs]
Roles13.$trModule = GHC.Types.Module $trModule2 $trModule4
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
$tc'MkAge1 :: GHC.Prim.Addr#
[GblId, Caf=NoCafRefs]
$tc'MkAge1 = "'MkAge"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
$tc'MkAge2 :: GHC.Types.TrName
[GblId, Caf=NoCafRefs]
$tc'MkAge2 = GHC.Types.TrNameS $tc'MkAge1
--- RHS size: {terms: 5, types: 0, coercions: 0}
+-- RHS size: {terms: 5, types: 0, coercions: 0, joins: 0/0}
Roles13.$tc'MkAge :: GHC.Types.TyCon
[GblId, Caf=NoCafRefs]
Roles13.$tc'MkAge =
@@ -60,17 +61,17 @@ Roles13.$tc'MkAge =
Roles13.$trModule
$tc'MkAge2
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
$tcAge1 :: GHC.Prim.Addr#
[GblId, Caf=NoCafRefs]
$tcAge1 = "Age"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
$tcAge2 :: GHC.Types.TrName
[GblId, Caf=NoCafRefs]
$tcAge2 = GHC.Types.TrNameS $tcAge1
--- RHS size: {terms: 5, types: 0, coercions: 0}
+-- RHS size: {terms: 5, types: 0, coercions: 0, joins: 0/0}
Roles13.$tcAge :: GHC.Types.TyCon
[GblId, Caf=NoCafRefs]
Roles13.$tcAge =
@@ -80,17 +81,17 @@ Roles13.$tcAge =
Roles13.$trModule
$tcAge2
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
$tc'MkWrap1 :: GHC.Prim.Addr#
[GblId, Caf=NoCafRefs]
$tc'MkWrap1 = "'MkWrap"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
$tc'MkWrap2 :: GHC.Types.TrName
[GblId, Caf=NoCafRefs]
$tc'MkWrap2 = GHC.Types.TrNameS $tc'MkWrap1
--- RHS size: {terms: 5, types: 0, coercions: 0}
+-- RHS size: {terms: 5, types: 0, coercions: 0, joins: 0/0}
Roles13.$tc'MkWrap :: GHC.Types.TyCon
[GblId, Caf=NoCafRefs]
Roles13.$tc'MkWrap =
@@ -100,17 +101,17 @@ Roles13.$tc'MkWrap =
Roles13.$trModule
$tc'MkWrap2
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
$tcWrap1 :: GHC.Prim.Addr#
[GblId, Caf=NoCafRefs]
$tcWrap1 = "Wrap"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
$tcWrap2 :: GHC.Types.TrName
[GblId, Caf=NoCafRefs]
$tcWrap2 = GHC.Types.TrNameS $tcWrap1
--- RHS size: {terms: 5, types: 0, coercions: 0}
+-- RHS size: {terms: 5, types: 0, coercions: 0, joins: 0/0}
Roles13.$tcWrap :: GHC.Types.TyCon
[GblId, Caf=NoCafRefs]
Roles13.$tcWrap =
diff --git a/testsuite/tests/simplCore/should_compile/Makefile b/testsuite/tests/simplCore/should_compile/Makefile
index 5a465d9818..ef3e74ad7f 100644
--- a/testsuite/tests/simplCore/should_compile/Makefile
+++ b/testsuite/tests/simplCore/should_compile/Makefile
@@ -47,6 +47,7 @@ T5658b:
$(RM) -f T5658b.o T5658b.hi
'$(TEST_HC)' $(TEST_HC_OPTS) -O -c T5658b.hs -ddump-simpl | grep -c indexIntArray
# Trac 5658 meant that there were three calls to indexIntArray instead of two
+# (now four due to join-point discount causing W/W to stabilize unfolding)
T5776:
$(RM) -f T5776.o T5776.hi
@@ -121,7 +122,7 @@ T13155:
T13156:
$(RM) -f T13156.hi T13156.o
'$(TEST_HC)' $(TEST_HC_OPTS) -c T13156.hs -O -ddump-prep -dsuppress-uniques | grep "case"
- # There should be a single 'case case'
+ # There should be a single 'case r @ GHC.Types.Any'
.PHONY: T4138
T4138:
diff --git a/testsuite/tests/simplCore/should_compile/T13156.hs b/testsuite/tests/simplCore/should_compile/T13156.hs
index cdc322af38..2ddfa2cefb 100644
--- a/testsuite/tests/simplCore/should_compile/T13156.hs
+++ b/testsuite/tests/simplCore/should_compile/T13156.hs
@@ -4,39 +4,42 @@ f g x = let r :: [a] -> [a]
r = case g x of True -> reverse . reverse
False -> reverse
in
- r `seq` r `seq` True
+ r `seq` r `seq` r
{- Expected -ddump-prep looks like this.
- (Room for improvement on the case (case ..) line.)
+ (Case-of-type-lambda an oddity of Core Prep.)
--- RHS size: {terms: 9, types: 9, coercions: 0}
+-- RHS size: {terms: 9, types: 9, coercions: 0, joins: 0/0}
T13156.f1 :: forall a. [a] -> [a]
[GblId, Arity=1, Caf=NoCafRefs, Str=<S,1*U>, Unf=OtherCon []]
T13156.f1 =
- \ (@ a_aC4) (x_sNG [Occ=Once] :: [a]) ->
- case GHC.List.reverse @ a x_sNG of sat_sNH { __DEFAULT ->
- GHC.List.reverse1 @ a sat_sNH (GHC.Types.[] @ a)
+ \ (@ a) (x [Occ=Once] :: [a]) ->
+ case GHC.List.reverse @ a x of sat { __DEFAULT ->
+ GHC.List.reverse1 @ a sat (GHC.Types.[] @ a)
}
--- RHS size: {terms: 13, types: 20, coercions: 0}
-T13156.f :: forall p. (p -> GHC.Types.Bool) -> p -> GHC.Types.Bool
+-- RHS size: {terms: 17, types: 28, coercions: 0, joins: 0/0}
+T13156.f
+ :: forall p.
+ (p -> GHC.Types.Bool) -> p -> [GHC.Types.Int] -> [GHC.Types.Int]
[GblId,
Arity=2,
Caf=NoCafRefs,
Str=<C(S),1*C1(U)><L,U>,
Unf=OtherCon []]
T13156.f =
- \ (@ p_aBS)
- (g_sNI [Occ=Once!] :: p -> GHC.Types.Bool)
- (x_sNJ [Occ=Once] :: p) ->
- case case g_sNI x_sNJ of {
- GHC.Types.False -> GHC.List.reverse @ GHC.Types.Any;
- GHC.Types.True -> T13156.f1 @ GHC.Types.Any
- }
- of
+ \ (@ p)
+ (g [Occ=Once!] :: p -> GHC.Types.Bool)
+ (x [Occ=Once] :: p) ->
+ case \ (@ a) ->
+ case g x of {
+ GHC.Types.False -> GHC.List.reverse @ a;
+ GHC.Types.True -> T13156.f1 @ a
+ }
+ of r [Dmd=<S,U>]
{ __DEFAULT ->
- GHC.Types.True
+ case r @ GHC.Types.Any of { __DEFAULT -> r @ GHC.Types.Int }
}
-}
diff --git a/testsuite/tests/simplCore/should_compile/T13156.stdout b/testsuite/tests/simplCore/should_compile/T13156.stdout
index 51d10c851f..5aa8f6aa38 100644
--- a/testsuite/tests/simplCore/should_compile/T13156.stdout
+++ b/testsuite/tests/simplCore/should_compile/T13156.stdout
@@ -1,2 +1,4 @@
case GHC.List.reverse @ a x of sat { __DEFAULT ->
- case case g x of {
+ case \ (@ a1) ->
+ case g x of {
+ case r @ GHC.Types.Any of { __DEFAULT -> r @ a }
diff --git a/testsuite/tests/simplCore/should_compile/T3717.stderr b/testsuite/tests/simplCore/should_compile/T3717.stderr
index f9adeb28da..9bcc4f31ac 100644
--- a/testsuite/tests/simplCore/should_compile/T3717.stderr
+++ b/testsuite/tests/simplCore/should_compile/T3717.stderr
@@ -1,8 +1,9 @@
==================== Tidy Core ====================
-Result size of Tidy Core = {terms: 36, types: 15, coercions: 0}
+Result size of Tidy Core
+ = {terms: 36, types: 15, coercions: 0, joins: 0/0}
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
T3717.$trModule4 :: GHC.Prim.Addr#
[GblId,
Caf=NoCafRefs,
@@ -10,7 +11,7 @@ T3717.$trModule4 :: GHC.Prim.Addr#
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 20 0}]
T3717.$trModule4 = "main"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
T3717.$trModule3 :: GHC.Types.TrName
[GblId,
Caf=NoCafRefs,
@@ -19,7 +20,7 @@ T3717.$trModule3 :: GHC.Types.TrName
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 20}]
T3717.$trModule3 = GHC.Types.TrNameS T3717.$trModule4
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
T3717.$trModule2 :: GHC.Prim.Addr#
[GblId,
Caf=NoCafRefs,
@@ -27,7 +28,7 @@ T3717.$trModule2 :: GHC.Prim.Addr#
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 30 0}]
T3717.$trModule2 = "T3717"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
T3717.$trModule1 :: GHC.Types.TrName
[GblId,
Caf=NoCafRefs,
@@ -36,7 +37,7 @@ T3717.$trModule1 :: GHC.Types.TrName
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 20}]
T3717.$trModule1 = GHC.Types.TrNameS T3717.$trModule2
--- RHS size: {terms: 3, types: 0, coercions: 0}
+-- RHS size: {terms: 3, types: 0, coercions: 0, joins: 0/0}
T3717.$trModule :: GHC.Types.Module
[GblId,
Caf=NoCafRefs,
@@ -47,7 +48,7 @@ T3717.$trModule =
GHC.Types.Module T3717.$trModule3 T3717.$trModule1
Rec {
--- RHS size: {terms: 10, types: 2, coercions: 0}
+-- RHS size: {terms: 10, types: 2, coercions: 0, joins: 0/0}
T3717.$wfoo [InlPrag=[0], Occ=LoopBreaker]
:: GHC.Prim.Int# -> GHC.Prim.Int#
[GblId, Arity=1, Caf=NoCafRefs, Str=<S,1*U>]
@@ -59,7 +60,7 @@ T3717.$wfoo =
}
end Rec }
--- RHS size: {terms: 10, types: 4, coercions: 0}
+-- RHS size: {terms: 10, types: 4, coercions: 0, joins: 0/0}
foo [InlPrag=INLINE[0]] :: Int -> Int
[GblId,
Arity=1,
diff --git a/testsuite/tests/simplCore/should_compile/T3772.stdout b/testsuite/tests/simplCore/should_compile/T3772.stdout
index 76936e336f..98a809d95f 100644
--- a/testsuite/tests/simplCore/should_compile/T3772.stdout
+++ b/testsuite/tests/simplCore/should_compile/T3772.stdout
@@ -1,9 +1,10 @@
==================== Tidy Core ====================
-Result size of Tidy Core = {terms: 40, types: 16, coercions: 0}
+Result size of Tidy Core
+ = {terms: 40, types: 16, coercions: 0, joins: 0/0}
Rec {
--- RHS size: {terms: 10, types: 2, coercions: 0}
+-- RHS size: {terms: 10, types: 2, coercions: 0, joins: 0/0}
$wxs :: GHC.Prim.Int# -> ()
[GblId, Arity=1, Caf=NoCafRefs, Str=<S,1*U>]
$wxs =
@@ -14,7 +15,7 @@ $wxs =
}
end Rec }
--- RHS size: {terms: 14, types: 5, coercions: 0}
+-- RHS size: {terms: 14, types: 5, coercions: 0, joins: 0/0}
foo [InlPrag=NOINLINE] :: Int -> ()
[GblId, Arity=1, Caf=NoCafRefs, Str=<S(S),1*U(U)>]
foo =
@@ -26,7 +27,7 @@ foo =
}
}
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
T3772.$trModule2 :: GHC.Prim.Addr#
[GblId,
Caf=NoCafRefs,
@@ -34,7 +35,7 @@ T3772.$trModule2 :: GHC.Prim.Addr#
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 30 0}]
T3772.$trModule2 = "T3772"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
T3772.$trModule1 :: GHC.Types.TrName
[GblId,
Caf=NoCafRefs,
@@ -43,7 +44,7 @@ T3772.$trModule1 :: GHC.Types.TrName
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 20}]
T3772.$trModule1 = GHC.Types.TrNameS T3772.$trModule2
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
T3772.$trModule4 :: GHC.Prim.Addr#
[GblId,
Caf=NoCafRefs,
@@ -51,7 +52,7 @@ T3772.$trModule4 :: GHC.Prim.Addr#
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 20 0}]
T3772.$trModule4 = "main"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
T3772.$trModule3 :: GHC.Types.TrName
[GblId,
Caf=NoCafRefs,
@@ -60,7 +61,7 @@ T3772.$trModule3 :: GHC.Types.TrName
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 20}]
T3772.$trModule3 = GHC.Types.TrNameS T3772.$trModule4
--- RHS size: {terms: 3, types: 0, coercions: 0}
+-- RHS size: {terms: 3, types: 0, coercions: 0, joins: 0/0}
T3772.$trModule :: GHC.Types.Module
[GblId,
Caf=NoCafRefs,
diff --git a/testsuite/tests/simplCore/should_compile/T4908.stderr b/testsuite/tests/simplCore/should_compile/T4908.stderr
index e9957bf9de..185b9b3529 100644
--- a/testsuite/tests/simplCore/should_compile/T4908.stderr
+++ b/testsuite/tests/simplCore/should_compile/T4908.stderr
@@ -1,8 +1,9 @@
==================== Tidy Core ====================
-Result size of Tidy Core = {terms: 68, types: 43, coercions: 0}
+Result size of Tidy Core
+ = {terms: 68, types: 43, coercions: 0, joins: 0/0}
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
T4908.$trModule4 :: Addr#
[GblId,
Caf=NoCafRefs,
@@ -10,7 +11,7 @@ T4908.$trModule4 :: Addr#
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 20 0}]
T4908.$trModule4 = "main"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
T4908.$trModule3 :: TrName
[GblId,
Caf=NoCafRefs,
@@ -19,7 +20,7 @@ T4908.$trModule3 :: TrName
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 20}]
T4908.$trModule3 = GHC.Types.TrNameS T4908.$trModule4
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
T4908.$trModule2 :: Addr#
[GblId,
Caf=NoCafRefs,
@@ -27,7 +28,7 @@ T4908.$trModule2 :: Addr#
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 30 0}]
T4908.$trModule2 = "T4908"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
T4908.$trModule1 :: TrName
[GblId,
Caf=NoCafRefs,
@@ -36,7 +37,7 @@ T4908.$trModule1 :: TrName
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 20}]
T4908.$trModule1 = GHC.Types.TrNameS T4908.$trModule2
--- RHS size: {terms: 3, types: 0, coercions: 0}
+-- RHS size: {terms: 3, types: 0, coercions: 0, joins: 0/0}
T4908.$trModule :: Module
[GblId,
Caf=NoCafRefs,
@@ -47,7 +48,7 @@ T4908.$trModule =
GHC.Types.Module T4908.$trModule3 T4908.$trModule1
Rec {
--- RHS size: {terms: 19, types: 5, coercions: 0}
+-- RHS size: {terms: 19, types: 5, coercions: 0, joins: 0/0}
T4908.f_$s$wf [Occ=LoopBreaker] :: Int -> Int# -> Int# -> Bool
[GblId, Arity=3, Caf=NoCafRefs, Str=<L,A><L,1*U><S,1*U>]
T4908.f_$s$wf =
@@ -62,7 +63,7 @@ T4908.f_$s$wf =
}
end Rec }
--- RHS size: {terms: 24, types: 13, coercions: 0}
+-- RHS size: {terms: 24, types: 13, coercions: 0, joins: 0/0}
T4908.$wf [InlPrag=[0]] :: Int# -> (Int, Int) -> Bool
[GblId,
Arity=2,
@@ -85,7 +86,7 @@ T4908.$wf =
0# -> GHC.Types.True
}
--- RHS size: {terms: 8, types: 6, coercions: 0}
+-- RHS size: {terms: 8, types: 6, coercions: 0, joins: 0/0}
f [InlPrag=INLINE[0]] :: Int -> (Int, Int) -> Bool
[GblId,
Arity=2,
diff --git a/testsuite/tests/simplCore/should_compile/T4930.stderr b/testsuite/tests/simplCore/should_compile/T4930.stderr
index 365584d3d0..9db97a5e1f 100644
--- a/testsuite/tests/simplCore/should_compile/T4930.stderr
+++ b/testsuite/tests/simplCore/should_compile/T4930.stderr
@@ -1,8 +1,9 @@
==================== Tidy Core ====================
-Result size of Tidy Core = {terms: 49, types: 19, coercions: 0}
+Result size of Tidy Core
+ = {terms: 44, types: 17, coercions: 0, joins: 0/0}
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
T4930.$trModule4 :: GHC.Prim.Addr#
[GblId,
Caf=NoCafRefs,
@@ -10,7 +11,7 @@ T4930.$trModule4 :: GHC.Prim.Addr#
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 20 0}]
T4930.$trModule4 = "main"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
T4930.$trModule3 :: GHC.Types.TrName
[GblId,
Caf=NoCafRefs,
@@ -19,7 +20,7 @@ T4930.$trModule3 :: GHC.Types.TrName
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 20}]
T4930.$trModule3 = GHC.Types.TrNameS T4930.$trModule4
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
T4930.$trModule2 :: GHC.Prim.Addr#
[GblId,
Caf=NoCafRefs,
@@ -27,7 +28,7 @@ T4930.$trModule2 :: GHC.Prim.Addr#
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 30 0}]
T4930.$trModule2 = "T4930"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
T4930.$trModule1 :: GHC.Types.TrName
[GblId,
Caf=NoCafRefs,
@@ -36,7 +37,7 @@ T4930.$trModule1 :: GHC.Types.TrName
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 20}]
T4930.$trModule1 = GHC.Types.TrNameS T4930.$trModule2
--- RHS size: {terms: 3, types: 0, coercions: 0}
+-- RHS size: {terms: 3, types: 0, coercions: 0, joins: 0/0}
T4930.$trModule :: GHC.Types.Module
[GblId,
Caf=NoCafRefs,
@@ -47,24 +48,19 @@ T4930.$trModule =
GHC.Types.Module T4930.$trModule3 T4930.$trModule1
Rec {
--- RHS size: {terms: 23, types: 6, coercions: 0}
+-- RHS size: {terms: 18, types: 4, coercions: 0, joins: 0/0}
T4930.$wfoo [InlPrag=[0], Occ=LoopBreaker]
:: GHC.Prim.Int# -> GHC.Prim.Int#
[GblId, Arity=1, Caf=NoCafRefs, Str=<S,U>]
T4930.$wfoo =
\ (ww :: GHC.Prim.Int#) ->
- case case GHC.Prim.tagToEnum# @ Bool (GHC.Prim.<# ww 5#) of {
- False -> GHC.Types.I# (GHC.Prim.+# ww 2#);
- True ->
- case T4930.$wfoo ww of ww1 { __DEFAULT -> GHC.Types.I# ww1 }
- }
- of
- { GHC.Types.I# ipv ->
- GHC.Prim.+# ww 5#
+ case GHC.Prim.tagToEnum# @ Bool (GHC.Prim.<# ww 5#) of {
+ False -> GHC.Prim.+# ww 5#;
+ True -> case T4930.$wfoo ww of { __DEFAULT -> GHC.Prim.+# ww 5# }
}
end Rec }
--- RHS size: {terms: 10, types: 4, coercions: 0}
+-- RHS size: {terms: 10, types: 4, coercions: 0, joins: 0/0}
foo [InlPrag=INLINE[0]] :: Int -> Int
[GblId,
Arity=1,
diff --git a/testsuite/tests/simplCore/should_compile/T5658b.stdout b/testsuite/tests/simplCore/should_compile/T5658b.stdout
index 0cfbf08886..b8626c4cff 100644
--- a/testsuite/tests/simplCore/should_compile/T5658b.stdout
+++ b/testsuite/tests/simplCore/should_compile/T5658b.stdout
@@ -1 +1 @@
-2
+4
diff --git a/testsuite/tests/simplCore/should_compile/T7360.stderr b/testsuite/tests/simplCore/should_compile/T7360.stderr
index 2e387b27bc..b35c39931c 100644
--- a/testsuite/tests/simplCore/should_compile/T7360.stderr
+++ b/testsuite/tests/simplCore/should_compile/T7360.stderr
@@ -1,8 +1,9 @@
==================== Tidy Core ====================
-Result size of Tidy Core = {terms: 94, types: 48, coercions: 0}
+Result size of Tidy Core
+ = {terms: 94, types: 48, coercions: 0, joins: 0/0}
--- RHS size: {terms: 6, types: 3, coercions: 0}
+-- RHS size: {terms: 6, types: 3, coercions: 0, joins: 0/0}
T7360.$WFoo3 [InlPrag=INLINE] :: Int -> Foo
[GblId[DataConWrapper],
Arity=1,
@@ -17,19 +18,19 @@ T7360.$WFoo3 =
\ (dt [Occ=Once!] :: Int) ->
case dt of { GHC.Types.I# dt [Occ=Once] -> T7360.Foo3 dt }
--- RHS size: {terms: 5, types: 2, coercions: 0}
+-- RHS size: {terms: 5, types: 2, coercions: 0, joins: 0/0}
fun1 [InlPrag=NOINLINE] :: Foo -> ()
[GblId, Arity=1, Caf=NoCafRefs, Str=<S,1*U>]
fun1 = \ (x :: Foo) -> case x of { __DEFAULT -> GHC.Tuple.() }
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
T7360.fun5 :: ()
[GblId,
Unf=Unf{Src=<vanilla>, TopLvl=True, Value=False, ConLike=False,
WorkFree=False, Expandable=False, Guidance=IF_ARGS [] 20 0}]
T7360.fun5 = fun1 T7360.Foo1
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
T7360.fun4 :: Int
[GblId,
Caf=NoCafRefs,
@@ -38,7 +39,7 @@ T7360.fun4 :: Int
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 20}]
T7360.fun4 = GHC.Types.I# 0#
--- RHS size: {terms: 16, types: 13, coercions: 0}
+-- RHS size: {terms: 16, types: 13, coercions: 0, joins: 0/0}
fun2 :: forall a. [a] -> ((), Int)
[GblId,
Arity=1,
@@ -66,7 +67,7 @@ fun2 =
}
})
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
T7360.$trModule4 :: GHC.Prim.Addr#
[GblId,
Caf=NoCafRefs,
@@ -74,7 +75,7 @@ T7360.$trModule4 :: GHC.Prim.Addr#
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 20 0}]
T7360.$trModule4 = "main"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
T7360.$trModule3 :: GHC.Types.TrName
[GblId,
Caf=NoCafRefs,
@@ -83,7 +84,7 @@ T7360.$trModule3 :: GHC.Types.TrName
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 20}]
T7360.$trModule3 = GHC.Types.TrNameS T7360.$trModule4
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
T7360.$trModule2 :: GHC.Prim.Addr#
[GblId,
Caf=NoCafRefs,
@@ -91,7 +92,7 @@ T7360.$trModule2 :: GHC.Prim.Addr#
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 30 0}]
T7360.$trModule2 = "T7360"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
T7360.$trModule1 :: GHC.Types.TrName
[GblId,
Caf=NoCafRefs,
@@ -100,7 +101,7 @@ T7360.$trModule1 :: GHC.Types.TrName
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 20}]
T7360.$trModule1 = GHC.Types.TrNameS T7360.$trModule2
--- RHS size: {terms: 3, types: 0, coercions: 0}
+-- RHS size: {terms: 3, types: 0, coercions: 0, joins: 0/0}
T7360.$trModule :: GHC.Types.Module
[GblId,
Caf=NoCafRefs,
@@ -110,7 +111,7 @@ T7360.$trModule :: GHC.Types.Module
T7360.$trModule =
GHC.Types.Module T7360.$trModule3 T7360.$trModule1
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
T7360.$tc'Foo9 :: GHC.Prim.Addr#
[GblId,
Caf=NoCafRefs,
@@ -118,7 +119,7 @@ T7360.$tc'Foo9 :: GHC.Prim.Addr#
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 30 0}]
T7360.$tc'Foo9 = "'Foo3"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
T7360.$tc'Foo8 :: GHC.Types.TrName
[GblId,
Caf=NoCafRefs,
@@ -127,7 +128,7 @@ T7360.$tc'Foo8 :: GHC.Types.TrName
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 20}]
T7360.$tc'Foo8 = GHC.Types.TrNameS T7360.$tc'Foo9
--- RHS size: {terms: 5, types: 0, coercions: 0}
+-- RHS size: {terms: 5, types: 0, coercions: 0, joins: 0/0}
T7360.$tc'Foo3 :: GHC.Types.TyCon
[GblId,
Caf=NoCafRefs,
@@ -141,7 +142,7 @@ T7360.$tc'Foo3 =
T7360.$trModule
T7360.$tc'Foo8
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
T7360.$tc'Foo7 :: GHC.Prim.Addr#
[GblId,
Caf=NoCafRefs,
@@ -149,7 +150,7 @@ T7360.$tc'Foo7 :: GHC.Prim.Addr#
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 30 0}]
T7360.$tc'Foo7 = "'Foo2"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
T7360.$tc'Foo6 :: GHC.Types.TrName
[GblId,
Caf=NoCafRefs,
@@ -158,7 +159,7 @@ T7360.$tc'Foo6 :: GHC.Types.TrName
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 20}]
T7360.$tc'Foo6 = GHC.Types.TrNameS T7360.$tc'Foo7
--- RHS size: {terms: 5, types: 0, coercions: 0}
+-- RHS size: {terms: 5, types: 0, coercions: 0, joins: 0/0}
T7360.$tc'Foo2 :: GHC.Types.TyCon
[GblId,
Caf=NoCafRefs,
@@ -172,7 +173,7 @@ T7360.$tc'Foo2 =
T7360.$trModule
T7360.$tc'Foo6
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
T7360.$tc'Foo5 :: GHC.Prim.Addr#
[GblId,
Caf=NoCafRefs,
@@ -180,7 +181,7 @@ T7360.$tc'Foo5 :: GHC.Prim.Addr#
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 30 0}]
T7360.$tc'Foo5 = "'Foo1"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
T7360.$tc'Foo4 :: GHC.Types.TrName
[GblId,
Caf=NoCafRefs,
@@ -189,7 +190,7 @@ T7360.$tc'Foo4 :: GHC.Types.TrName
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 20}]
T7360.$tc'Foo4 = GHC.Types.TrNameS T7360.$tc'Foo5
--- RHS size: {terms: 5, types: 0, coercions: 0}
+-- RHS size: {terms: 5, types: 0, coercions: 0, joins: 0/0}
T7360.$tc'Foo1 :: GHC.Types.TyCon
[GblId,
Caf=NoCafRefs,
@@ -203,7 +204,7 @@ T7360.$tc'Foo1 =
T7360.$trModule
T7360.$tc'Foo4
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
T7360.$tcFoo2 :: GHC.Prim.Addr#
[GblId,
Caf=NoCafRefs,
@@ -211,7 +212,7 @@ T7360.$tcFoo2 :: GHC.Prim.Addr#
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 20 0}]
T7360.$tcFoo2 = "Foo"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
T7360.$tcFoo1 :: GHC.Types.TrName
[GblId,
Caf=NoCafRefs,
@@ -220,7 +221,7 @@ T7360.$tcFoo1 :: GHC.Types.TrName
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 20}]
T7360.$tcFoo1 = GHC.Types.TrNameS T7360.$tcFoo2
--- RHS size: {terms: 5, types: 0, coercions: 0}
+-- RHS size: {terms: 5, types: 0, coercions: 0, joins: 0/0}
T7360.$tcFoo :: GHC.Types.TyCon
[GblId,
Caf=NoCafRefs,
diff --git a/testsuite/tests/simplCore/should_compile/T9400.stderr b/testsuite/tests/simplCore/should_compile/T9400.stderr
index 92979b36b1..a8004dce8b 100644
--- a/testsuite/tests/simplCore/should_compile/T9400.stderr
+++ b/testsuite/tests/simplCore/should_compile/T9400.stderr
@@ -1,33 +1,34 @@
==================== Tidy Core ====================
-Result size of Tidy Core = {terms: 37, types: 22, coercions: 0}
+Result size of Tidy Core
+ = {terms: 37, types: 22, coercions: 0, joins: 0/0}
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
$trModule1 :: Addr#
[GblId, Caf=NoCafRefs]
$trModule1 = "main"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
$trModule2 :: TrName
[GblId, Caf=NoCafRefs]
$trModule2 = GHC.Types.TrNameS $trModule1
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
$trModule3 :: Addr#
[GblId, Caf=NoCafRefs]
$trModule3 = "T9400"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
$trModule4 :: TrName
[GblId, Caf=NoCafRefs]
$trModule4 = GHC.Types.TrNameS $trModule3
--- RHS size: {terms: 3, types: 0, coercions: 0}
+-- RHS size: {terms: 3, types: 0, coercions: 0, joins: 0/0}
T9400.$trModule :: Module
[GblId, Caf=NoCafRefs]
T9400.$trModule = GHC.Types.Module $trModule2 $trModule4
--- RHS size: {terms: 22, types: 15, coercions: 0}
+-- RHS size: {terms: 22, types: 15, coercions: 0, joins: 0/0}
main :: IO ()
[GblId]
main =
diff --git a/testsuite/tests/simplCore/should_compile/all.T b/testsuite/tests/simplCore/should_compile/all.T
index d63d0d1958..4cc11de737 100644
--- a/testsuite/tests/simplCore/should_compile/all.T
+++ b/testsuite/tests/simplCore/should_compile/all.T
@@ -24,7 +24,6 @@ test('simpl020', [], multimod_compile, ['simpl020', '-v0'])
test('simpl-T1370', normal, compile, [''])
test('T2520', normal, compile, [''])
-
test('spec001', when(fast(), skip), compile, [''])
test('spec002', normal, compile, [''])
test('spec003', normal, compile, [''])
@@ -99,7 +98,7 @@ test('T4918', [], run_command, ['$MAKE -s --no-print-directory T4918'])
# result of -ddump-simpl, which is never advertised to
# be very stable
test('T4945',
- expect_broken(4945),
+ normal,
run_command,
['$MAKE -s --no-print-directory T4945'])
diff --git a/testsuite/tests/simplCore/should_compile/par01.stderr b/testsuite/tests/simplCore/should_compile/par01.stderr
index 4ccb9d892b..bbcb9ef4fd 100644
--- a/testsuite/tests/simplCore/should_compile/par01.stderr
+++ b/testsuite/tests/simplCore/should_compile/par01.stderr
@@ -1,9 +1,10 @@
==================== CorePrep ====================
-Result size of CorePrep = {terms: 22, types: 10, coercions: 0}
+Result size of CorePrep
+ = {terms: 22, types: 10, coercions: 0, joins: 0/0}
Rec {
--- RHS size: {terms: 7, types: 3, coercions: 0}
+-- RHS size: {terms: 7, types: 3, coercions: 0, joins: 0/0}
Par01.depth [Occ=LoopBreaker] :: GHC.Types.Int -> GHC.Types.Int
[GblId, Arity=1, Caf=NoCafRefs, Str=<L,U>, Unf=OtherCon []]
Par01.depth =
@@ -13,27 +14,27 @@ Par01.depth =
}
end Rec }
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
Par01.$trModule4 :: GHC.Prim.Addr#
[GblId, Caf=NoCafRefs, Unf=OtherCon []]
Par01.$trModule4 = "main"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
Par01.$trModule3 :: GHC.Types.TrName
[GblId, Caf=NoCafRefs, Str=m1, Unf=OtherCon []]
Par01.$trModule3 = GHC.Types.TrNameS Par01.$trModule4
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
Par01.$trModule2 :: GHC.Prim.Addr#
[GblId, Caf=NoCafRefs, Unf=OtherCon []]
Par01.$trModule2 = "Par01"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
Par01.$trModule1 :: GHC.Types.TrName
[GblId, Caf=NoCafRefs, Str=m1, Unf=OtherCon []]
Par01.$trModule1 = GHC.Types.TrNameS Par01.$trModule2
--- RHS size: {terms: 3, types: 0, coercions: 0}
+-- RHS size: {terms: 3, types: 0, coercions: 0, joins: 0/0}
Par01.$trModule :: GHC.Types.Module
[GblId, Caf=NoCafRefs, Str=m, Unf=OtherCon []]
Par01.$trModule =
diff --git a/testsuite/tests/simplCore/should_compile/spec-inline.stderr b/testsuite/tests/simplCore/should_compile/spec-inline.stderr
index 0de46d181d..dda28c8926 100644
--- a/testsuite/tests/simplCore/should_compile/spec-inline.stderr
+++ b/testsuite/tests/simplCore/should_compile/spec-inline.stderr
@@ -1,8 +1,9 @@
==================== Tidy Core ====================
-Result size of Tidy Core = {terms: 178, types: 68, coercions: 0}
+Result size of Tidy Core
+ = {terms: 178, types: 68, coercions: 0, joins: 0/2}
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
Roman.$trModule4 :: GHC.Prim.Addr#
[GblId,
Caf=NoCafRefs,
@@ -10,7 +11,7 @@ Roman.$trModule4 :: GHC.Prim.Addr#
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 20 0}]
Roman.$trModule4 = "main"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
Roman.$trModule3 :: GHC.Types.TrName
[GblId,
Caf=NoCafRefs,
@@ -19,7 +20,7 @@ Roman.$trModule3 :: GHC.Types.TrName
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 20}]
Roman.$trModule3 = GHC.Types.TrNameS Roman.$trModule4
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
Roman.$trModule2 :: GHC.Prim.Addr#
[GblId,
Caf=NoCafRefs,
@@ -27,7 +28,7 @@ Roman.$trModule2 :: GHC.Prim.Addr#
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 30 0}]
Roman.$trModule2 = "Roman"#
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
Roman.$trModule1 :: GHC.Types.TrName
[GblId,
Caf=NoCafRefs,
@@ -36,7 +37,7 @@ Roman.$trModule1 :: GHC.Types.TrName
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 20}]
Roman.$trModule1 = GHC.Types.TrNameS Roman.$trModule2
--- RHS size: {terms: 3, types: 0, coercions: 0}
+-- RHS size: {terms: 3, types: 0, coercions: 0, joins: 0/0}
Roman.$trModule :: GHC.Types.Module
[GblId,
Caf=NoCafRefs,
@@ -46,19 +47,19 @@ Roman.$trModule :: GHC.Types.Module
Roman.$trModule =
GHC.Types.Module Roman.$trModule3 Roman.$trModule1
--- RHS size: {terms: 1, types: 0, coercions: 0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
lvl :: GHC.Prim.Addr#
[GblId, Caf=NoCafRefs]
lvl = "spec-inline.hs:(19,5)-(29,25)|function go"#
--- RHS size: {terms: 2, types: 2, coercions: 0}
+-- RHS size: {terms: 2, types: 2, coercions: 0, joins: 0/0}
Roman.foo3 :: Int
[GblId, Str=x]
Roman.foo3 =
Control.Exception.Base.patError @ 'GHC.Types.LiftedRep @ Int lvl
Rec {
--- RHS size: {terms: 55, types: 9, coercions: 0}
+-- RHS size: {terms: 55, types: 9, coercions: 0, joins: 0/1}
Roman.foo_$s$wgo [Occ=LoopBreaker]
:: GHC.Prim.Int# -> GHC.Prim.Int# -> GHC.Prim.Int#
[GblId, Arity=2, Caf=NoCafRefs, Str=<S,U><S,U>]
@@ -88,7 +89,7 @@ Roman.foo_$s$wgo =
}
end Rec }
--- RHS size: {terms: 74, types: 22, coercions: 0}
+-- RHS size: {terms: 74, types: 22, coercions: 0, joins: 0/1}
Roman.$wgo [InlPrag=[0]] :: Maybe Int -> Maybe Int -> GHC.Prim.Int#
[GblId,
Arity=2,
@@ -132,7 +133,7 @@ Roman.$wgo =
}
}
--- RHS size: {terms: 9, types: 5, coercions: 0}
+-- RHS size: {terms: 9, types: 5, coercions: 0, joins: 0/0}
Roman.foo_go [InlPrag=INLINE[0]] :: Maybe Int -> Maybe Int -> Int
[GblId,
Arity=2,
@@ -146,7 +147,7 @@ Roman.foo_go =
\ (w :: Maybe Int) (w1 :: Maybe Int) ->
case Roman.$wgo w w1 of ww { __DEFAULT -> GHC.Types.I# ww }
--- RHS size: {terms: 2, types: 0, coercions: 0}
+-- RHS size: {terms: 2, types: 0, coercions: 0, joins: 0/0}
Roman.foo2 :: Int
[GblId,
Caf=NoCafRefs,
@@ -155,7 +156,7 @@ Roman.foo2 :: Int
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 20}]
Roman.foo2 = GHC.Types.I# 6#
--- RHS size: {terms: 2, types: 1, coercions: 0}
+-- RHS size: {terms: 2, types: 1, coercions: 0, joins: 0/0}
Roman.foo1 :: Maybe Int
[GblId,
Caf=NoCafRefs,
@@ -164,7 +165,7 @@ Roman.foo1 :: Maybe Int
WorkFree=True, Expandable=True, Guidance=IF_ARGS [] 10 20}]
Roman.foo1 = GHC.Base.Just @ Int Roman.foo2
--- RHS size: {terms: 11, types: 4, coercions: 0}
+-- RHS size: {terms: 11, types: 4, coercions: 0, joins: 0/0}
foo :: Int -> Int
[GblId,
Arity=1,