summaryrefslogtreecommitdiff
path: root/compiler/cmm/CmmSink.hs
blob: 02195c91e13928b98c1c7adb4909341a52f56bf9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
{-# LANGUAGE GADTs #-}
module CmmSink (
     cmmSink
  ) where

import Cmm
import BlockId
import CmmLive
import CmmUtils
import Hoopl

import UniqFM
-- import Outputable

import Data.List (partition)
import qualified Data.Set as Set

-- -----------------------------------------------------------------------------
-- Sinking

-- This is an optimisation pass that
--  (a) moves assignments closer to their uses, to reduce register pressure
--  (b) pushes assignments into a single branch of a conditional if possible

-- It is particularly helpful in the Cmm generated by the Stg->Cmm
-- code generator, in which every function starts with a copyIn
-- sequence like:
--
--    x1 = R1
--    x2 = Sp[8]
--    x3 = Sp[16]
--    if (Sp - 32 < SpLim) then L1 else L2
--
-- we really want to push the x1..x3 assignments into the L2 branch.
--
-- Algorithm:
--
--  * Start by doing liveness analysis.
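--  * Keep a list of pending assignments, most recent first; an assignment
--    may refer to registers defined by assignments after it in the list
--    (i.e. earlier in program order)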
--  * Walk forwards through the graph;
--    * At an assignment:
--      * pick up the assignment and add it to the list
--    * At a store:
--      * drop any assignments that the store refers to
--      * drop any assignments that refer to memory that may be written
--        by the store
--      * do this recursively, dropping dependent assignments
--    * At a multi-way branch:
--      * drop any assignments that are live on more than one branch
--      * if any successor has more than one predecessor, drop everything
--        live in that successor
-- 
-- As a side-effect we'll delete some dead assignments (transitively,
-- even).  Maybe we could do without removeDeadAssignments?

-- If we do this *before* stack layout, we might be able to avoid
-- saving some things across calls/procpoints.
--
-- *but*, that will invalidate the liveness analysis, and we'll have
-- to re-do it.

-- An assignment @r = e@ that has been picked up and is being sunk.  The
-- third component approximates the memory read by @e@ (shouldSink fills
-- it in with exprMem) and is consulted by 'conflicts' when deciding
-- whether the assignment may be moved past a store or call.
type Assignment = (LocalReg, CmmExpr, AbsMem)

-- | Sink assignments towards their uses across a whole procedure.  The
-- blocks are processed in the order given by postorderDfs and then
-- reassembled around the original entry label.
cmmSink :: CmmGraph -> CmmGraph
cmmSink graph = ofBlockList (g_entry graph) $ sink mapEmpty $ blocks
  where
  -- Registers live on entry to each block; a label with no entry is
  -- treated as having nothing live.
  liveness = cmmLiveness graph
  getLive l = mapFindWithDefault Set.empty l liveness

  -- NOTE(review): 'sink' only looks up the sunk assignments for a block
  -- when it reaches that block in this list, so anything recorded in
  -- sunk' for a block that was already processed (e.g. reached by a back
  -- edge) is never emitted.  This presumably relies on every non-join
  -- successor appearing after its predecessor here -- TODO confirm.
  blocks = postorderDfs graph

  -- Labels that occur more than once as a successor (join points).
  join_pts = findJoinPoints blocks

  -- Rebuild each block in turn, threading a map from block label to the
  -- assignments that a predecessor has sunk into that block.
  sink :: BlockEnv [Assignment] -> [CmmBlock] -> [CmmBlock]
  sink _ [] = []
  sink sunk (b:bs) =
    -- pprTrace "sink" (ppr lbl) $
    blockJoin first final_middle final_last : sink sunk' bs
    where
      lbl = entryLabel b
      (first, middle, last) = blockSplit b

      succs = successors last

      -- Annotate the middle nodes with the registers live *after*
      -- the node.  This will help us decide whether we can inline
      -- an assignment in the current node or not.
      -- 'live' is the set live after the last node (the union of the
      -- successors' live-in sets); gen_kill steps it back across 'last'.
      live = Set.unions (map getLive succs)
      live_middle = gen_kill last live
      ann_middles = annotate live_middle (blockToList middle)

      -- Now sink and inline in this block
      (middle', assigs) = walk ann_middles (mapFindWithDefault [] lbl sunk)
      (final_last, assigs') = tryToInline live last assigs

      -- We cannot sink into join points (successors with more than
      -- one predecessor), so identify the join points and the set
      -- of registers live in them.
      (joins, nonjoins) = partition (`mapMember` join_pts) succs
      live_in_joins = Set.unions (map getLive joins)

      -- We do not want to sink an assignment into multiple branches,
      -- so identify the set of registers live in multiple successors.
      -- This is made more complicated because when we sink an assignment
      -- into one branch, this might change the set of registers that are
      -- now live in multiple branches.
      init_live_sets = map getLive nonjoins
      live_in_multi live_sets r =
         case filter (Set.member r) live_sets of
           (_one:_two:_) -> True
           _ -> False

      -- Now, drop any assignments that we will not sink any further.
      (dropped_last, assigs'') = dropAssignments drop_if init_live_sets assigs'

      -- Decide whether to emit an assignment here.  The state is the list
      -- of per-successor live sets; when an assignment is sunk rather than
      -- dropped, every successor where its register is live now also
      -- needs the registers its rhs reads.
      drop_if a@(r,rhs,_) live_sets = (should_drop, live_sets')
          where
            should_drop =  a `conflicts` final_last
                        || {- not (isTiny rhs) && -} live_in_multi live_sets r
                        || r `Set.member` live_in_joins

            live_sets' | should_drop = live_sets
                       | otherwise   = map upd live_sets

            upd set | r `Set.member` set = set `Set.union` live_rhs
                    | otherwise          = set

            live_rhs = foldRegsUsed extendRegSet emptyRegSet rhs

      -- Dropped assignments are appended to the middle, i.e. emitted
      -- before the last node.
      final_middle = foldl blockSnoc middle' dropped_last

      -- Hand each successor the surviving assignments it needs (those
      -- whose register is live there, plus their dependencies, as
      -- selected by filterAssignments).
      sunk' = mapUnion sunk $
                 mapFromList [ (l, filterAssignments (getLive l) assigs'')
                             | l <- succs ]

{-
-- tiny: an expression we don't mind duplicating
isTiny :: CmmExpr -> Bool
isTiny (CmmReg _) = True
isTiny (CmmLit _) = True
isTiny _other     = False
-}

--
-- Pair each node with the set of registers live immediately *after* it,
-- given the set live after the final node.  The list is processed from
-- the end, stepping the live set backwards across each node with gen_kill.
--
annotate :: RegSet -> [CmmNode O O] -> [(RegSet, CmmNode O O)]
annotate live_out = snd . foldr step (live_out, [])
  where
    step node (live_after, annotated) =
      (gen_kill node live_after, (live_after, node) : annotated)

--
-- Find the join points among the given blocks: labels with more than one
-- incoming edge.  Every occurrence of a label in some block's successor
-- list counts as one edge, so a block reached twice from the same
-- predecessor (e.g. both arms of a conditional) is also a join point.
-- The result maps each join point to its number of incoming edges.
--
findJoinPoints :: [CmmBlock] -> BlockEnv Int
findJoinPoints blocks = mapFilter (> 1) pred_counts
  where
    -- Number of times each label occurs as a successor.
    pred_counts :: BlockEnv Int
    pred_counts = foldr bump mapEmpty (concatMap successors blocks)

    bump l counts = mapInsertWith (+) l 1 counts

--
-- Keep only the assignments a continuation needs: those whose register is
-- in the given live set, plus any assignment that conflicts with one we
-- have already kept (for instance, one whose register a kept right-hand
-- side reads).  The result is in the same order as the input.
--
filterAssignments :: RegSet -> [Assignment] -> [Assignment]
filterAssignments live = reverse . foldl keep []
  where
    keep acc a@(r, _, _)
      | wanted    = a : acc
      | otherwise = acc
      where
        -- An assignment that is not itself live may still be referred to
        -- by one we are keeping, so it must be kept as well.
        wanted = r `Set.member` live || any (a `conflicts`) (map toNode acc)

-- -----------------------------------------------------------------------------
-- Walk through the nodes of a block, sinking and inlining assignments
-- as we go.

walk :: [(RegSet, CmmNode O O)]         -- nodes of the block, each paired
                                        -- with the registers live *after*
                                        -- that node.

     -> [Assignment]                    -- assignments currently being sunk,
                                        -- most recently picked up first; an
                                        -- assignment may refer to registers
                                        -- defined by those after it.

     -> ( Block CmmNode O O             -- the rebuilt block
        , [Assignment]                  -- assignments to sink further
        )

walk nodes assigs = step emptyBlock assigs nodes
 where
   step out pending [] = (out, pending)
   step out pending ((live_after, node) : rest)
     | dead_assign = step out pending rest
     | otherwise   =
         case shouldSink inlined of
           Just picked -> step out (picked : pending') rest
           Nothing     -> step out' remaining rest
     where
       -- Throw away assignments to local registers that are dead
       -- afterwards.  This is weaker than removeDeadAssignments (which
       -- would need several passes to remove all dead code), but it
       -- catches the common case of superfluous reloads from the stack
       -- left behind by the stack allocator.
       dead_assign = case node of
         CmmAssign (CmmLocal r) _ -> not (r `Set.member` live_after)
         _otherwise               -> False

       (inlined, pending') = tryToInline live_after node pending

       -- Whatever cannot be moved past this node is emitted just before it.
       (flushed, remaining) = dropAssignmentsSimple (`conflicts` inlined) pending'
       out' = foldl blockSnoc out flushed `blockSnoc` inlined

--
-- Decide whether to pick up an assignment and sink it.  Every assignment
-- to a local register is picked up, together with an approximation of
-- the memory its right-hand side reads.  Assignments to global registers
-- are left alone because the liveness analysis does not track them (yet).
-- A stricter test that rejected right-hand sides mentioning any local
-- register (foldRegsUsed (\_ _ -> False) True e) is currently disabled.
--
shouldSink :: CmmNode e x -> Maybe Assignment
shouldSink node = case node of
  CmmAssign (CmmLocal r) e -> Just (r, e, exprMem e)
  _other                   -> Nothing

-- Turn a pending assignment back into the Cmm assignment node.
toNode :: Assignment -> CmmNode O O
toNode assig = case assig of
  (reg, expr, _mem) -> CmmAssign (CmmLocal reg) expr

-- dropAssignments specialised to a drop predicate that needs no state.
dropAssignmentsSimple :: (Assignment -> Bool) -> [Assignment]
                      -> ([CmmNode O O], [Assignment])
dropAssignmentsSimple f = dropAssignments stateless ()
  where stateless a s = (f a, s)

-- Split pending assignments into those to emit now and those to keep
-- sinking.  should_drop is consulted for each assignment in list order,
-- threading a caller-supplied state.  An assignment is also dropped when
-- it conflicts with one already dropped, so dependencies are emitted
-- together.  Returns the dropped assignments as nodes (in the reverse of
-- the order they were visited) and the kept ones in their original order.
dropAssignments :: (Assignment -> s -> (Bool, s)) -> s -> [Assignment]
                -> ([CmmNode O O], [Assignment])
dropAssignments should_drop state0 assigs = (dropped, reverse kept)
 where
   (_, dropped, kept) = foldl visit (state0, [], []) assigs

   visit (st, drops, keeps) assig
     | must_drop = (st', toNode assig : drops, keeps)
     | otherwise = (st', drops, assig : keeps)
     where
       (dropit, st') = should_drop assig st
       must_drop     = dropit || any (assig `conflicts`) drops


-- -----------------------------------------------------------------------------
-- Try to inline assignments into a node.

-- Try to inline pending assignments into a node.  An assignment
-- @l = rhs@ is inlined (and discarded) when @l@ is not live after the
-- node and is used exactly once, counting uses in the node and in the
-- right-hand sides of assignments inlined so far.  Assignments that are
-- not inlined are returned, in their original order, for further sinking.
--
-- The list is most-recent-first, so an assignment further down the list
-- happened *earlier* in the program than those before it.  Two hazards
-- follow from that, and both are guarded against below:
--
--  * a kept assignment @y = ...x...@ still needs @x@, so @x@ must not be
--    treated as "used once" and discarded on its behalf;
--  * an earlier assignment @x = ...r...@ must not be moved into the node
--    past a kept later assignment @r = ...@ (nor past a kept later
--    assignment to @x@ itself).
tryToInline
   :: RegSet                    -- set of registers live after this
                                -- node.  We cannot inline anything
                                -- that is live after the node.
   -> CmmNode O x               -- The node to inline into
   -> [Assignment]              -- Assignments to inline
   -> (
        CmmNode O x             -- New node
      , [Assignment]            -- Remaining assignments
      )

tryToInline live node assigs = go live usages node emptyRegSet assigs
 where
  usages :: UniqFM Int
  usages = foldRegsUsed addUsage emptyUFM node

  -- 'skipped' holds the registers assigned by the assignments we have
  -- decided to keep (not inline) so far.
  go _live _usages node _skipped [] = (node, [])

  go live usages node skipped (a@(l,rhs,_) : rest)
   | cannot_inline            = dont_inline
   | occurs_once_in_this_node = inline_and_discard
   | otherwise                = dont_inline
     -- Inlining tiny right-hand sides while keeping the assignment
     -- (isTiny) seemed to make things slightly worse, so it is not done.
   where
        inline_and_discard = go live' usages' node' skipped rest
          where live'   = foldRegsUsed extendRegSet live rhs
                usages' = foldRegsUsed addUsage usages rhs

        -- Keep the assignment pending.  Its rhs will be evaluated later,
        -- so none of the registers it reads may be inlined away: mark
        -- them as used (at least) twice.
        dont_inline = (node'', a : rest')
          where (node'', rest') = go live usages' node (extendRegSet skipped l) rest
                usages' = foldRegsUsed (\m r -> addToUFM m r 2) usages rhs

        -- Moving this (earlier) assignment into the node would cross a
        -- kept (later) assignment that redefines a register its rhs
        -- reads, or that redefines l itself.
        cannot_inline = foldRegsUsed (\b r -> b || r `elemRegSet` skipped) False rhs
                        || l `elemRegSet` skipped

        occurs_once_in_this_node =
         not (l `elemRegSet` live) &&  lookupUFM usages l == Just 1

        node' = mapExpDeep inline node
           where inline (CmmReg    (CmmLocal l'))     | l == l' = rhs
                 inline (CmmRegOff (CmmLocal l') off) | l == l'
                    = cmmOffset rhs off
                 inline other = other

-- Record one more use of a register in a usage-count map.
addUsage :: UniqFM Int -> LocalReg -> UniqFM Int
addUsage counts reg = addToUFM_C (\old new -> old + new) counts reg 1


-- -----------------------------------------------------------------------------

-- | @conflicts (r,e,m) stmt@ is @False@ only if the assignment @r = e@
-- (whose right-hand side reads the memory approximated by @m@) can be
-- safely commuted past @stmt@.  When unsure it answers @True@, which can
-- only cost optimisation opportunities, never correctness.
--
conflicts :: Assignment -> CmmNode O x -> Bool
(r, rhs, addr) `conflicts` node

  -- (1) the node (re)defines a register that the rhs reads.  Besides
  -- ordinary assignments this covers calls returning results in local
  -- registers: @y = x + 1@ must not be moved past @(x) = call f()@.
  | CmmAssign reg  _ <- node, reg `regUsedIn` rhs                 = True
  | foldRegsDefd (\b l -> b || CmmLocal l `regUsedIn` rhs) False node = True

  -- (2) the node uses the register defined by the assignment
  | foldRegsUsed (\b r' -> r == r' || b) False node               = True

  -- (3) a store to an address conflicts with a read of the same memory
  | CmmStore addr' e <- node, memConflicts addr (loadAddr addr' (cmmExprWidth e)) = True

  -- (4) an assignment to Hp/Sp conflicts with a heap/stack read respectively
  | HeapMem    <- addr, CmmAssign (CmmGlobal Hp) _ <- node         = True
  | StackMem   <- addr, CmmAssign (CmmGlobal Sp) _ <- node         = True
  | SpMem{}    <- addr, CmmAssign (CmmGlobal Sp) _ <- node         = True

  -- (5) a call may write arbitrary memory, so an rhs that reads any
  -- memory must not be moved past one.
  | CmmUnsafeForeignCall{} <- node, memConflicts addr AnyMem      = True
  | CmmForeignCall{}       <- node, memConflicts addr AnyMem      = True
  | CmmCall{}              <- node, memConflicts addr AnyMem      = True

  -- (6) otherwise, no conflict
  | otherwise = False


-- An abstraction of memory read or written.  Values are combined with
-- bothMems and tested for possible overlap with memConflicts.
data AbsMem
  = NoMem            -- no memory accessed
  | AnyMem           -- arbitrary memory
  | HeapMem          -- definitely heap memory
  | StackMem         -- definitely stack memory
  | SpMem            -- <size>[Sp+n]
       {-# UNPACK #-} !Int   -- n: byte offset from Sp
       {-# UNPACK #-} !Int   -- size: width of the access in bytes

-- Having SpMem is important because it lets us float loads from Sp
-- past stores to Sp as long as they don't overlap, and this helps to
-- unravel some long sequences of
--    x1 = [Sp + 8]
--    x2 = [Sp + 16]
--    ...
--    [Sp + 8]  = xi
--    [Sp + 16] = xj
--
-- Note that SpMem is invalidated if Sp is changed, but the definition
-- of 'conflicts' above handles that.

-- The least abstraction covering both accesses.  Two Sp-relative reads
-- at the same offset keep the wider width; any other mix of stack
-- accesses widens to StackMem, and a mix of unrelated kinds to AnyMem.
bothMems :: AbsMem -> AbsMem -> AbsMem
bothMems m1 m2 = case (m1, m2) of
    (NoMem, _)         -> m2
    (_, NoMem)         -> m1
    (HeapMem, HeapMem) -> HeapMem
    (SpMem o1 w1, SpMem o2 w2)
      | o1 == o2       -> SpMem o1 (max w1 w2)
      | otherwise      -> StackMem
    _ | onStack m1, onStack m2 -> StackMem
      | otherwise              -> AnyMem
  where
    onStack StackMem = True
    onStack SpMem{}  = True
    onStack _        = False

-- Could two accesses described by these abstractions touch the same
-- memory?  NoMem never conflicts, heap and stack accesses are disjoint,
-- and two Sp-relative accesses conflict only if their byte ranges
-- intersect.  Everything else is conservatively assumed to conflict.
memConflicts :: AbsMem -> AbsMem -> Bool
memConflicts m1 m2 = case (m1, m2) of
    (NoMem, _) -> False
    (_, NoMem) -> False
    (SpMem o1 w1, SpMem o2 w2)
      | o1 < o2   -> o1 + w1 > o2
      | otherwise -> o2 + w2 > o1
    _ | isHeap m1 && onStack m2 -> False
      | onStack m1 && isHeap m2 -> False
      | otherwise               -> True
  where
    isHeap HeapMem = True
    isHeap _       = False

    onStack StackMem = True
    onStack SpMem{}  = True
    onStack _        = False

-- The memory read while evaluating an expression: loads contribute the
-- location they read (and whatever their address expression reads),
-- machine operations combine the reads of their arguments, and all other
-- expressions read no memory.
exprMem :: CmmExpr -> AbsMem
exprMem expr = case expr of
  CmmLoad addr ty  -> loadAddr addr (typeWidth ty) `bothMems` exprMem addr
  CmmMachOp _ args -> foldr (bothMems . exprMem) NoMem args
  _                -> NoMem

-- Classify an access of the given width at an address expression.
-- A register or register-plus-offset address is classified by regAddr;
-- any other address mentioning Sp is some stack location, and anything
-- else may be arbitrary memory.
loadAddr :: CmmExpr -> Width -> AbsMem
loadAddr (CmmReg r)      w = regAddr r 0 w
loadAddr (CmmRegOff r i) w = regAddr r i w
loadAddr e _
  | CmmGlobal Sp `regUsedIn` e = StackMem
  | otherwise                  = AnyMem

-- Classify an access of the given width at byte offset @off@ from a
-- register.  Sp-relative accesses keep their offset and size, Hp and
-- registers of GC-pointer type are treated as heap memory (so the GC
-- pointer type information pays for itself), and anything else is AnyMem.
regAddr :: CmmReg -> Int -> Width -> AbsMem
regAddr r off w = case r of
  CmmGlobal Sp -> SpMem off (widthInBytes w)
  CmmGlobal Hp -> HeapMem
  _ | isGcPtrType (cmmRegType r) -> HeapMem
    | otherwise                  -> AnyMem