-- Source: compiler/GHC/Stg/Lift/Monad.hs
-- (blob 6110e3b809d5c3778bd625a77eb15ed83dc986ab)

{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}

-- | Hides away distracting bookkeeping while lambda lifting into a 'LiftM'
-- monad.
module GHC.Stg.Lift.Monad (
    decomposeStgBinding, mkStgBinding,
    Env (..),
    -- * #floats# Handling floats
    -- $floats
    FloatLang (..), collectFloats, -- Exported just for the docs
    -- * Transformation monad
    LiftM, runLiftM,
    -- ** Adding bindings
    startBindingGroup, endBindingGroup, addTopStringLit, addLiftedBinding, addStaticPtrBinding,
    -- ** Substitution and binders
    withSubstBndr, withSubstBndrs, withLiftedBndr, withLiftedBndrs,
    -- ** Occurrences
    substOcc, isLifted, formerFreeVars, liftedIdsExpander
  ) where

import GHC.Prelude

import GHC.Types.Basic
import GHC.Types.CostCentre ( isCurrentCCS, dontCareCCS )
import GHC.Driver.Session
import GHC.Data.FastString
import GHC.Types.Id
import GHC.Types.Name
import GHC.Utils.Outputable
import GHC.Data.OrdList
import GHC.Stg.Subst
import GHC.Stg.Syntax
import GHC.Core.Utils
import GHC.Types.Unique.Supply
import GHC.Utils.Panic
import GHC.Utils.Panic.Plain
import GHC.Types.Var.Env
import GHC.Types.Var.Set
import GHC.Core.Multiplicity

import Control.Arrow ( second )
import Control.Monad.Trans.Class
import Control.Monad.Trans.RWS.Strict ( RWST, runRWST )
import qualified Control.Monad.Trans.RWS.Strict as RWS
import Control.Monad.Trans.Cont ( ContT (..) )
import Data.ByteString ( ByteString )
import GHC.Core.TyCo.Rep
import GHC.Core.Type
import GHC.Builtin.Types
import GHC.Linker.Types
import GHC.Types.Unique
import GHC.Fingerprint
import GHC.Unit.Module
import Data.List (intercalate)
import GHC.Types.Literal
import GHC.Platform
import Data.Maybe
import GHC.LanguageExtensions
import GHC.Types.SrcLoc

-- | Splits a binding into its 'RecFlag' and its list of binder\/right-hand
-- side pairs. A non-recursive binding yields a singleton list.
--
-- Law: @uncurry 'mkStgBinding' . 'decomposeStgBinding' = id@
decomposeStgBinding :: GenStgBinding pass -> (RecFlag, [(BinderP pass, GenStgRhs pass)])
decomposeStgBinding binding = case binding of
  StgNonRec bndr rhs -> (NonRecursive, [(bndr, rhs)])
  StgRec pairs       -> (Recursive, pairs)

-- | Rebuilds a binding from a 'RecFlag' and a list of binder\/right-hand side
-- pairs; the counterpart of 'decomposeStgBinding'.
--
-- * 'Recursive': all pairs are wrapped in a single 'StgRec'.
--
-- * 'NonRecursive': the first pair becomes a 'StgNonRec'; any further pairs
--   are ignored. An empty list is a caller bug and panics with a message
--   naming this function (instead of the anonymous failure of 'head').
mkStgBinding :: RecFlag -> [(BinderP pass, GenStgRhs pass)] -> GenStgBinding pass
mkStgBinding Recursive pairs = StgRec pairs
mkStgBinding NonRecursive pairs = case pairs of
  (bndr, rhs) : _ -> StgNonRec bndr rhs
  [] -> pprPanic "mkStgBinding" (text "non-recursive binding without a binder/RHS pair")

-- | Environment threaded around in a scoped, @Reader@-like fashion.
data Env
  = Env
  { e_dflags     :: !DynFlags
  -- ^ Read-only.
  , e_subst      :: !Subst
  -- ^ We need to track the renamings of local 'InId's to their lifted 'OutId',
  -- because shadowing might make a closure's free variables unavailable at its
  -- call sites. Consider:
  -- @
  --    let f y = x + y in let x = 4 in f x
  -- @
  -- Here, @f@ can't be lifted to top-level, because its free variable @x@ isn't
  -- available at its call site.
  , e_expansions :: !(IdEnv DIdSet)
  -- ^ Lifted 'Id's don't occur as free variables in any closure anymore, because
  -- they are bound at the top-level. Every occurrence must supply the formerly
  -- free variables of the lifted 'Id', so they in turn become free variables of
  -- the call sites. This environment tracks this expansion from lifted 'Id's to
  -- their free variables.
  --
  -- 'InId's to 'OutId's.
  --
  -- Invariant: 'Id's not present in this map won't be substituted.
  , e_mod :: !Module
  -- ^ The module handed to 'runLiftM'. 'addStaticPtrBinding' reads it to name
  -- the static pointer binders it creates and to compute their fingerprints.
  }

-- | The initial environment: no substitution and no expansions recorded yet,
-- only the given module and flags.
emptyEnv :: Module -> DynFlags -> Env
emptyEnv this_mod dflags = Env
  { e_dflags     = dflags
  , e_subst      = emptySubst
  , e_expansions = emptyVarEnv
  , e_mod        = this_mod
  }


-- Note [Handling floats]
-- ~~~~~~~~~~~~~~~~~~~~~~
-- $floats
-- Consider the following expression:
--
-- @
--     f x =
--       let g y = ... f y ...
--       in g x
-- @
--
-- What happens when we want to lift @g@? Normally, we'd put the lifted @l_g@
-- binding above the binding for @f@:
--
-- @
--     g f y = ... f y ...
--     f x = g f x
-- @
--
-- But this very unnecessarily turns a known call to @f@ into an unknown one, in
-- addition to complicating matters for the analysis.
-- Instead, we'd really like to put both functions in the same recursive group,
-- thereby preserving the known call:
--
-- @
--     Rec {
--       g y = ... f y ...
--       f x = g x
--     }
-- @
--
-- But we don't want this to happen for just /any/ binding. That would create
-- possibly huge recursive groups in the process, calling for an occurrence
-- analyser on STG.
-- So, we need to track when we lift a binding out of a recursive RHS and add
-- the binding to the same recursive group as the enclosing recursive binding
-- (which must have either already been at the top-level or decided to be
-- lifted itself in order to preserve the known call).
--
-- This is done by expressing this kind of nesting structure as a 'Writer' over
-- @['FloatLang']@ and flattening this expression in 'runLiftM' by a call to
-- 'collectFloats'.
-- API-wise, the analysis will not need to know about the whole 'FloatLang'
-- business and will just manipulate it indirectly through actions in 'LiftM'.

-- | We need to detect when we are lifting something out of the RHS of a
-- recursive binding (c.f. "GHC.Stg.Lift.Monad#floats"), in which case that
-- binding needs to be added to the same top-level recursive group. This
-- requires we detect a certain nesting structure, which is encoded by
-- 'StartBindingGroup' and 'EndBindingGroup'.
--
-- Although 'collectFloats' will only ever care if the current binding to be
-- lifted (through 'LiftedBinding') will occur inside such a binding group or
-- not, e.g. doesn't care about the nesting level as long as it's greater than 0.
data FloatLang
  = StartBindingGroup
  -- ^ Opens a binding group; written by 'startBindingGroup'.
  | EndBindingGroup
  -- ^ Closes a binding group; written by 'endBindingGroup'.
  | PlainTopBinding OutStgTopBinding
  -- ^ A top-level binding that 'collectFloats' emits unchanged. It must not
  -- occur inside a binding group ('collectFloats' panics otherwise).
  | LiftedBinding OutStgBinding
  -- ^ A lifted binding. Outside of any group 'collectFloats' emits it on its
  -- own; inside a group it is merged into that group's single 'StgRec'.
  | LiftedStaticBinding SptEntry Id StgRhs
  -- ^ A non-recursive binding of the given 'Id' to the given RHS, placed by
  -- 'collectFloats' like a 'LiftedBinding'. The 'SptEntry' is what
  -- 'collectSPTEntries' extracts.

-- | Compact debug rendering: parentheses for group delimiters, @\<str\>@ for
-- top-level string literals, and @r@\/@n@ (recursive\/non-recursive) followed
-- by the bound binders for lifted bindings. Static bindings show only their
-- binder.
instance Outputable FloatLang where
  ppr float = case float of
    StartBindingGroup -> char '('
    EndBindingGroup -> char ')'
    PlainTopBinding StgTopStringLit{} -> text "<str>"
    PlainTopBinding (StgTopLifted b) -> ppr (LiftedBinding b)
    LiftedBinding bind ->
      let (rec_flag, pairs) = decomposeStgBinding bind
          tag
            | isRec rec_flag = char 'r'
            | otherwise = char 'n'
      in tag <+> ppr (map fst pairs)
    LiftedStaticBinding _spt binder _bind -> ppr binder

-- | Turns the flat @['FloatLang']@ stream into an STG program, see
-- "GHC.Stg.Lift.Monad#floats".
--
-- Pre-conditions:
--
-- * Every 'StartBindingGroup' is matched by an 'EndBindingGroup'; an
--   unbalanced stream panics.
--
-- * No 'PlainTopBinding' occurs inside a group; this panics as well.
--
-- * Every group contains at least one recursive binding, otherwise there was
--   no point in announcing the group and an assertion triggers.
--
-- Only the outermost group matters: everything collected between the
-- outermost start and its matching end is merged into a single 'StgRec'.
collectFloats :: [FloatLang] -> [OutStgTopBinding]
collectFloats = walk (0 :: Int) []
  where
    -- @walk depth pending floats@: @depth@ is the current group nesting level,
    -- @pending@ holds the bindings collected for the open outermost group
    -- (most recent first).
    walk 0 [] [] = []
    walk _ _ [] = pprPanic "collectFloats" (text "unterminated group")
    walk depth pending (float : floats) = case float of
      StartBindingGroup -> walk (depth + 1) pending floats
      EndBindingGroup -> case depth of
        0 -> pprPanic "collectFloats" (text "no group to end")
        1 -> StgTopLifted (mergeGroup pending) : walk 0 [] floats
        _ -> walk (depth - 1) pending floats
      PlainTopBinding top_bind
        | depth == 0 -> top_bind : walk depth pending floats
        | otherwise -> pprPanic "collectFloats" (text "plain top binding inside group")
      LiftedBinding bind
        | depth == 0 -> StgTopLifted (dropCCCS bind) : walk depth pending floats
        | otherwise -> walk depth (bind : pending) floats
      LiftedStaticBinding _ binder rhs
        | depth == 0 -> StgTopLifted (StgNonRec binder rhs) : walk depth pending floats
        | otherwise -> walk depth (StgNonRec binder rhs : pending) floats

    -- Applies 'removeRhsCCCS' to every right-hand side of a binding.
    dropCCCS bind =
      let (rec_flag, pairs) = decomposeStgBinding bind
      in mkStgBinding rec_flag (map (second removeRhsCCCS) pairs)

    -- Flattens all pending bindings of a group into one recursive binding.
    mergeGroup pending =
      assert (any isRecBinding pending) $
        StgRec (concatMap (\bind -> snd (decomposeStgBinding (dropCCCS bind))) pending)

    isRecBinding bind = case bind of
      StgRec{} -> True
      _ -> False

-- | Extracts, in order, the 'SptEntry' of every 'LiftedStaticBinding' in the
-- stream; all other floats are skipped.
collectSPTEntries :: [FloatLang] -> [SptEntry]
collectSPTEntries floats =
  [ spt_entry | LiftedStaticBinding spt_entry _ _ <- floats ]

-- | Replaces a /current/ cost-centre stack annotation on a right-hand side
-- by 'dontCareCCS'; any other right-hand side is returned unchanged.
--
-- Omitting this makes for strange closure allocation schemes that crash the
-- GC.
removeRhsCCCS :: GenStgRhs pass -> GenStgRhs pass
removeRhsCCCS rhs = case rhs of
  StgRhsClosure ext ccs upd bndrs body
    | isCurrentCCS ccs -> StgRhsClosure ext dontCareCCS upd bndrs body
  StgRhsCon ccs con mu ts args
    | isCurrentCCS ccs -> StgRhsCon dontCareCCS con mu ts args
  _ -> rhs

-- | The analysis monad consists of the following 'RWST' components:
--
--     * 'Env': Reader-like context. Contains a substitution, info about
--       how lifted identifiers are to be expanded into applications and details
--       such as 'DynFlags'.
--
--     * @'OrdList' 'FloatLang'@: Writer output for the resulting STG program.
--
--     * No pure state component (the state type is @()@).
--
--     * But wrapping around 'UniqSM' for generating fresh lifted binders.
--       (The @uniqAway@ approach could give the same name to two different
--       lifted binders, so this is necessary.)
newtype LiftM a
  = LiftM { unwrapLiftM :: RWST Env (OrdList FloatLang) () UniqSM a }
  deriving (Functor, Applicative, Monad)

-- | The flags are read from the 'e_dflags' field of the environment.
instance HasDynFlags LiftM where
  getDynFlags = LiftM $ e_dflags <$> RWS.ask

-- | All unique generation is delegated to the underlying 'UniqSM'.
instance MonadUnique LiftM where
  getUniqueSupplyM = LiftM . lift $ getUniqueSupplyM
  getUniqueM       = LiftM . lift $ getUniqueM
  getUniquesM      = LiftM . lift $ getUniquesM

-- | Runs a 'LiftM' action with an empty environment and the given unique
-- supply. Returns the top-level bindings obtained by flattening the written
-- floats with 'collectFloats', together with the static pointer table
-- entries found among those floats. The entry list is empty unless the
-- @StaticPointers@ extension is enabled in the given flags.
runLiftM :: Module -> DynFlags -> UniqSupply -> LiftM () -> ([OutStgTopBinding], [SptEntry])
runLiftM this_mod dflags us (LiftM m) = (collectFloats float_list, spt_entries)
  where
    initial_env = emptyEnv this_mod dflags
    (_, _, float_ol) = initUs_ us (runRWST m initial_env ())
    float_list = fromOL float_ol
    spt_entries =
      if xopt StaticPointers dflags
        then collectSPTEntries float_list
        else []

-- | Emits a top-level string literal binding ('StgTopStringLit') for the
-- given binder as a 'PlainTopBinding'.
addTopStringLit :: OutId -> ByteString -> LiftM ()
addTopStringLit id bytes =
  LiftM (RWS.tell (unitOL (PlainTopBinding (StgTopStringLit id bytes))))

-- | Marks the beginning of a recursive binding group in the output. See
-- "GHC.Stg.Lift.Monad#floats" and 'collectFloats'.
startBindingGroup :: LiftM ()
startBindingGroup = LiftM (RWS.tell (unitOL StartBindingGroup))

-- | Marks the end of a recursive binding group in the output. See
-- "GHC.Stg.Lift.Monad#floats" and 'collectFloats'.
endBindingGroup :: LiftM ()
endBindingGroup = LiftM (RWS.tell (unitOL EndBindingGroup))

-- | Writes a binding that is to be lifted to top-level. If it is written
-- inside a recursive binding group (see "GHC.Stg.Lift.Monad#floats" and
-- 'collectFloats'), it ends up in that group's top-level recursive binding
-- rather than as a binding of its own.
addLiftedBinding :: OutStgBinding -> LiftM ()
addLiftedBinding bind = LiftM (RWS.tell (unitOL (LiftedBinding bind)))

-- | Writes a 'LiftedStaticBinding' float: a binder with its right-hand side,
-- tagged with the static pointer table entry that belongs to it.
addLiftedStaticPtrBinding :: SptEntry -> Id -> StgRhs -> LiftM ()
addLiftedStaticPtrBinding spt_entry binder bind = LiftM (RWS.tell (unitOL float))
  where
    float = LiftedStaticBinding spt_entry binder bind

-- | Creates a global 'Id' with an external name in the given module for a
-- static pointer. The occurrence name is @$static_ptr@ followed by the shown
-- unique, and the type is the static pointer type constructor applied to the
-- given type.
newStaticPtrBndr :: Unique -> Module -> Type -> Id
newStaticPtrBndr uniq this_mod ty = mkVanillaGlobal ext_name ptr_ty
  where
    occ = mkVarOcc ("$static_ptr" ++ show uniq)
    -- This makes an external name but *doesn't* add it to the name cache.
    -- That is safe because the name is only, and can only be, referenced from
    -- the SPT init stub.
    ext_name = mkExternalName uniq this_mod occ noSrcSpan
    ptr_ty = mkTyConApp staticPtrTyCon [ty]

-- | Emits the two bindings that make up a static pointer and returns an
-- expression referring to the pointer binder.
--
-- The first argument is used as the last field of the info constructor
-- (after the unit and module name string literals); the second argument is
-- the payload and becomes the last field of the pointer constructor.
--
-- Two bindings are written:
--
-- 1. the info binding, through 'addLiftedBinding';
--
-- 2. the pointer binding, through 'addLiftedStaticPtrBinding', together with
--    an 'SptEntry' pairing the pointer binder with its fingerprint.
addStaticPtrBinding :: StgArg -> StgArg -> LiftM StgExpr
addStaticPtrBinding loc payload = do
  ptr_uniq <- getUniqueM
  info_uniq <- getUniqueM
  env <- LiftM RWS.ask
  let this_mod = e_mod env
      platform = targetPlatform (e_dflags env)

      binder_ptr = newStaticPtrBndr ptr_uniq this_mod (stgArgType payload)
      binder_info =
        mkSysLocal (mkFastString "$static_ptr_info") info_uniq Many
                   (mkTyConTy staticPtrInfoTyCon)

      fp = staticPtrFingerprint this_mod ptr_uniq
      Fingerprint w0 w1 = fp

      unit_lit = LitString (bytesFS (unitFS (moduleUnit this_mod)))
      mod_name_lit = LitString (bytesFS (moduleNameFS (moduleName this_mod)))

      info_rhs =
        StgRhsCon dontCareCCS staticPtrInfoDataCon NoNumber []
          [StgLitArg unit_lit, StgLitArg mod_name_lit, loc]

      ptr_rhs =
        StgRhsCon dontCareCCS staticPtrDataCon NoNumber []
          [ StgLitArg (fingerprintWordLit platform (toInteger w0))
          , StgLitArg (fingerprintWordLit platform (toInteger w1))
          , StgVarArg binder_info
          , payload
          ]

  addLiftedBinding (StgNonRec binder_info info_rhs)
  addLiftedStaticPtrBinding (SptEntry binder_ptr fp) binder_ptr ptr_rhs
  pure (StgApp binder_ptr [])
  where
    -- The fingerprint of the colon-separated unit, module name and unique.
    staticPtrFingerprint :: Module -> Unique -> Fingerprint
    staticPtrFingerprint this_mod n =
      fingerprintString $ intercalate ":"
        [ unitString (moduleUnit this_mod)
        , moduleNameString (moduleName this_mod)
        , show n
        ]

    -- Choose either 'Word64#' or 'Word#' to represent the arguments of the
    -- 'Fingerprint' data constructor, depending on the platform word size.
    fingerprintWordLit :: Platform -> Integer -> Literal
    fingerprintWordLit platform n = case platformWordSize platform of
      PW4 -> mkLitWord64 n
      PW8 -> mkLitWord platform n

-- | Substitutes the given binder and runs the continuation on the
-- substituted binder, in a 'LiftM' context whose substitution has been
-- extended accordingly. This is scoped like 'RWS.local': once the
-- continuation has finished, the extended substitution is no longer in
-- effect.
withSubstBndr :: Id -> (Id -> LiftM a) -> LiftM a
withSubstBndr bndr inner = LiftM $ do
  env <- RWS.ask
  let (bndr', subst') = substBndr bndr (e_subst env)
  RWS.local (const (env { e_subst = subst' })) (unwrapLiftM (inner bndr'))

-- | 'withSubstBndr' for a whole traversable of binders: the binders are
-- substituted left to right and the continuation receives all substituted
-- binders at once.
withSubstBndrs :: Traversable f => f Id -> (f Id -> LiftM a) -> LiftM a
withSubstBndrs bndrs = runContT (traverse (ContT . withSubstBndr) bndrs)

-- | Like 'withSubstBndr', but for a binder that is going to be lifted. Takes
-- the set of variables to abstract over, the binder to lift and a
-- continuation that is run with the fresh, lifted binder in scope.
--
-- The lifted binder gets a fresh unique, the original occurrence name
-- prefixed with @$l@, and the original type abstracted over the given
-- variables. Inside the continuation the original binder is substituted by
-- the lifted one and its expansion to the abstracted variables is recorded.
withLiftedBndr :: DIdSet -> Id -> (Id -> LiftM a) -> LiftM a
withLiftedBndr abs_ids bndr inner = do
  uniq <- getUniqueM
  let abs_vars = dVarSetElems abs_ids
      lifted_name = mkFastString ("$l" ++ occNameString (getOccName bndr))
      lifted_ty = mkLamTypes abs_vars (idType bndr)
      -- See Note [transferPolyIdInfo] in GHC.Types.Id. We need to do this at
      -- least for arity information.
      bndr' =
        transferPolyIdInfo bndr abs_vars
          (mkSysLocal lifted_name uniq Many lifted_ty)
      extend env = env
        { e_subst = extendSubst bndr bndr' (extendInScope bndr' (e_subst env))
        , e_expansions = extendVarEnv (e_expansions env) bndr abs_ids
        }
  LiftM (RWS.local extend (unwrapLiftM (inner bndr')))

-- NB: the exported vanilla 'Id' for a static pointer is created by
-- 'newStaticPtrBndr'.

-- | 'withLiftedBndr' for a whole traversable of binders, all abstracting over
-- the same set of variables. The continuation receives all lifted binders at
-- once.
withLiftedBndrs :: Traversable f => DIdSet -> f Id -> (f Id -> LiftM a) -> LiftM a
withLiftedBndrs abs_ids bndrs =
  runContT (traverse (ContT . withLiftedBndr abs_ids) bndrs)

-- | Looks up a binder /occurrence/ in the current substitution. The binder
-- was brought in scope earlier by 'withSubstBndr' \/ 'withLiftedBndr'.
substOcc :: Id -> LiftM Id
substOcc id = LiftM $ do
  subst <- RWS.asks e_subst
  pure (lookupIdSubst id subst)

-- | Whether the given binder was decided to be lambda lifted, i.e. whether
-- an expansion is recorded for it in the environment.
isLifted :: InId -> LiftM Bool
isLifted bndr = LiftM $ elemVarEnv bndr <$> RWS.asks e_expansions

-- | The local variables a lifted binding abstracts over, i.e. exactly the
-- additional arguments needed at adjusted call sites. Returns the empty list
-- for a binding that was not lifted.
formerFreeVars :: InId -> LiftM [OutId]
formerFreeVars f = LiftM $ do
  expansions <- RWS.asks e_expansions
  pure (maybe [] dVarSetElems (lookupVarEnv expansions f))

-- | Builds an /expander function/ from the current environment. Applied to a
-- set of 'InId's, the expander replaces every lifted binder by the former
-- free variables it abstracts over, and every other variable by its
-- corresponding 'OutId' under the current substitution.
liftedIdsExpander :: LiftM (DIdSet -> DIdSet)
liftedIdsExpander = LiftM $ do
  env <- RWS.ask
  let expansions = e_expansions env
      subst = e_subst env
      -- We use @noWarnLookupIdSubst@ here in order to suppress "not in scope"
      -- warnings generated by 'lookupIdSubst' due to local bindings within
      -- RHS. These are not in the InScopeSet of @subst@ and extending the
      -- InScopeSet in @goodToLift@/@closureGrowth@ before passing it on to
      -- the expander is too much trouble.
      step acc fv = case lookupVarEnv expansions fv of
        Just abstracted -> unionDVarSet acc abstracted
        Nothing -> extendDVarSet acc (noWarnLookupIdSubst fv subst) -- Not lifted
  pure (\fvs -> foldl' step emptyDVarSet (dVarSetElems fvs))