1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
|
module Vectorise( vectorise )
where
import VectMonad
import VectUtils
import VectType
import VectCore
import HscTypes hiding ( MonadThings(..) )
import Module ( PackageId )
import CoreSyn
import CoreUtils
import MkCore ( mkWildCase )
import CoreFVs
import CoreMonad ( CoreM, getHscEnv )
import DataCon
import TyCon
import Type
import FamInstEnv ( extendFamInstEnvList )
import Var
import VarEnv
import VarSet
import Id
import OccName
import Literal ( Literal, mkMachInt )
import TysWiredIn
import TysPrim ( intPrimTy )
import Outputable
import FastString
import Control.Monad ( liftM, liftM2, zipWithM )
import Data.List ( sortBy, unzip4 )
-- | Entry point of the vectoriser inside the 'CoreM' monad.
--
-- Fetches the current 'HscEnv' and runs 'vectoriseIO' on the given module
-- with the given backend package.
vectorise :: PackageId -> ModGuts -> CoreM ModGuts
vectorise backend guts
  = getHscEnv >>= \hsc_env ->
      liftIO (vectoriseIO backend hsc_env guts)
-- | Vectorise a module in 'IO'.
--
-- The vectorisation info used to initialise the 'VM' computation is the
-- combination of the info from the home package table and from the external
-- package state.  On success the resulting info is stored in the
-- @mg_vect_info@ field of the returned module.
--
-- If 'initV' yields 'Nothing' the function panics with a descriptive
-- message instead of failing with an anonymous pattern-match error.
vectoriseIO :: PackageId -> HscEnv -> ModGuts -> IO ModGuts
vectoriseIO backend hsc_env guts
  = do
      eps <- hscEPS hsc_env
      let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps
      result <- initV backend hsc_env guts info (vectModule guts)
      case result of
        Just (info', guts') -> return (guts' { mg_vect_info = info' })
        -- Previously an irrefutable-looking bind (@Just ... <- initV ...@)
        -- turned this case into an uninformative IO pattern-match failure.
        Nothing -> panic "Vectorise.vectoriseIO: vectorisation of the module failed"
-- | Vectorise the contents of a module.
--
-- First vectorises the type environment, extends the family instance
-- environment with the instances that produces (also recording the extended
-- environment in the global vectoriser environment), then vectorises every
-- top-level binding group.  The bindings returned by 'vectTypeEnv' are put
-- in front of the vectorised bindings as a single recursive group.
vectModule :: ModGuts -> VM ModGuts
vectModule guts
  = do
      (new_types, new_fam_insts, type_binds) <- vectTypeEnv (mg_types guts)

      let extended_env = extendFamInstEnvList (mg_fam_inst_env guts)
                                              new_fam_insts
      updGEnv (setFamInstEnv extended_env)

      -- dicts   <- mapM buildPADict pa_insts
      -- workers <- mapM vectDataConWorkers pa_insts
      top_binds <- mapM vectTopBind (mg_binds guts)

      let guts' = guts { mg_types        = new_types
                       , mg_binds        = Rec type_binds : top_binds
                       , mg_fam_inst_env = extended_env
                       , mg_fam_insts    = mg_fam_insts guts ++ new_fam_insts
                       }
      return guts'
-- | Vectorise a top-level binding group.
--
-- On success the result is a single recursive group containing, for every
-- original binder, the original binder bound to the result of 'tryConvert',
-- the vectorised binder bound to the vectorised right-hand side, and all
-- bindings collected by 'takeHoisted'.  If vectorisation fails (via
-- 'orElseV') the original binding is returned unchanged.
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
  = vectorised `orElseV` return b
  where
    vectorised = do
      vect_var <- vectTopBinder var
      vect_rhs <- vectTopRhs var expr
      hoisted  <- takeHoisted
      conv_rhs <- tryConvert var vect_var expr
      return (Rec ((var, conv_rhs) : (vect_var, vect_rhs) : hoisted))

vectTopBind b@(Rec bs)
  = vectorised `orElseV` return b
  where
    (vars, exprs) = unzip bs

    vectorised = do
      vect_vars <- mapM vectTopBinder vars
      vect_rhss <- zipWithM vectTopRhs vars exprs
      hoisted   <- takeHoisted
      conv_rhss <- sequence (zipWith3 tryConvert vars vect_vars exprs)
      return (Rec (zip vars conv_rhss ++ zip vect_vars vect_rhss ++ hoisted))
-- | Create the vectorised counterpart of a top-level binder.
--
-- The new variable is a clone of the original (named via 'mkVectOcc') whose
-- type is the vectorised type of the original.  The pair is registered with
-- 'defGlobalVar' before the new variable is returned.
vectTopBinder :: Var -> VM Var
vectTopBinder var
  = vectType (idType var) >>= cloneId mkVectOcc var >>= \vect_var ->
      defGlobalVar var vect_var >> return vect_var
-- | Vectorise the right-hand side of a top-level binding.
--
-- The expression is annotated with its free variables, vectorised with
-- 'vectPolyExpr' in the context of the binder ('inBind') and inside
-- 'closedV'; only the 'vectorised' component of the result is kept.
vectTopRhs :: Var -> CoreExpr -> VM CoreExpr
vectTopRhs var expr
  = closedV (liftM vectorised (inBind var (vectPolyExpr (freeVars expr))))
-- | Try to express the original binder in terms of its vectorised version
-- using 'fromVect' at the original binder's type; if that fails, fall back
-- to the original right-hand side.
tryConvert :: Var -> Var -> CoreExpr -> VM CoreExpr
tryConvert var vect_var rhs
  = converted `orElseV` return rhs
  where
    converted = fromVect (idType var) (Var vect_var)
-- ----------------------------------------------------------------------------
-- Bindings
-- | Vectorise a local binder.
--
-- Produces a pair of variables that reuse the original variable but carry
-- the vectorised and the lifted type respectively, and records the mapping
-- from the original variable to that pair in the local environment.
vectBndr :: Var -> VM VVar
vectBndr v
  = do
      (vect_ty, lift_ty) <- vectAndLiftType (idType v)
      let vvar = (Id.setIdType v vect_ty, Id.setIdType v lift_ty)
      updLEnv $ \env ->
        env { local_vars = extendVarEnv (local_vars env) v vvar }
      return vvar
-- | Vectorise a local binder by allocating a fresh vectorised variable
-- (via 'newLocalVVar', using the given name) at the vectorised type of the
-- original, and record the mapping in the local environment.
vectBndrNew :: Var -> FastString -> VM VVar
vectBndrNew v fs
  = do
      vect_ty <- vectType (idType v)
      fresh   <- newLocalVVar fs vect_ty
      updLEnv $ \env ->
        env { local_vars = extendVarEnv (local_vars env) v fresh }
      return fresh
-- | Run a computation with a binder vectorised by 'vectBndr' in scope.
-- Both steps run inside 'localV'; the vectorised binder is returned
-- together with the computation's result.
vectBndrIn :: Var -> VM a -> VM (VVar, a)
vectBndrIn v p
  = localV (liftM2 (,) (vectBndr v) p)
-- | Like 'vectBndrIn', but the binder is vectorised with 'vectBndrNew',
-- i.e. a fresh variable with the given name is introduced for it.
vectBndrNewIn :: Var -> FastString -> VM a -> VM (VVar, a)
vectBndrNewIn v fs p
  = localV (liftM2 (,) (vectBndrNew v fs) p)
-- | Run a computation with a list of binders vectorised by 'vectBndr' in
-- scope.  Both steps run inside 'localV'; the vectorised binders are
-- returned together with the computation's result.
vectBndrsIn :: [Var] -> VM a -> VM ([VVar], a)
vectBndrsIn vs p
  = localV (liftM2 (,) (mapM vectBndr vs) p)
-- ----------------------------------------------------------------------------
-- Expressions
-- | Vectorise a variable occurrence.
--
-- A locally bound variable maps to its vectorised and lifted variables.
-- A global variable maps to its vectorised variable, with the lifted
-- expression obtained by applying 'liftPD' to it.
vectVar :: Var -> VM VExpr
vectVar v
  = lookupVar v >>= resolve
  where
    resolve (Local (vect_v, lift_v))
      = return (Var vect_v, Var lift_v)
    resolve (Global glob_v)
      = do
          let scalar = Var glob_v
          lifted <- liftPD scalar
          return (scalar, lifted)
-- | Vectorise a variable applied to type arguments.
--
-- The type arguments are vectorised first.  For a local variable both the
-- vectorised and the lifted variable are applied to them with 'polyApply';
-- for a global variable only the vectorised one is, and the lifted
-- expression is obtained from the result via 'liftPD'.
vectPolyVar :: Var -> [Type] -> VM VExpr
vectPolyVar v tys
  = do
      vect_tys <- mapM vectType tys
      scope    <- lookupVar v
      case scope of
        Local (vect_v, lift_v) ->
          do
            vect_e <- polyApply (Var vect_v) vect_tys
            lift_e <- polyApply (Var lift_v) vect_tys
            return (vect_e, lift_e)
        Global glob_v ->
          do
            vect_e <- polyApply (Var glob_v) vect_tys
            lift_e <- liftPD vect_e
            return (vect_e, lift_e)
-- | Vectorise a literal: the vectorised version is the literal itself and
-- the lifted version is obtained by applying 'liftPD' to it.
vectLiteral :: Literal -> VM VExpr
vectLiteral lit
  = liftM ((,) lit_expr) (liftPD lit_expr)
  where
    lit_expr = Lit lit
-- | Vectorise a possibly type-abstracted expression.
--
-- Notes are preserved around the vectorised expression.  Otherwise the
-- leading type binders are split off, the remaining expression is
-- vectorised with 'vectFnExpr' (without inlining), and the abstraction
-- supplied by 'polyAbstract' is mapped over both result components.
vectPolyExpr :: CoreExprWithFVs -> VM VExpr
vectPolyExpr (_, AnnNote note expr)
  = do
      inner <- vectPolyExpr expr
      return (vNote note inner)

vectPolyExpr expr
  = polyAbstract tvs (\abstract -> liftM (mapVect abstract) (vectFnExpr False mono))
  where
    (tvs, mono) = collectAnnTypeBinders expr
-- | Vectorise an expression annotated with free variables, producing its
-- vectorised and lifted versions.
--
-- Expressions not covered by any clause (including case expressions whose
-- scrutinee type is not an algebraic type constructor application, and
-- lambdas over non-'Id' binders) are rejected with 'cantVectorise'.
vectExpr :: CoreExprWithFVs -> VM VExpr
vectExpr (_, AnnType ty)
  = liftM vType (vectType ty)

vectExpr (_, AnnVar v) = vectVar v

vectExpr (_, AnnLit lit) = vectLiteral lit

vectExpr (_, AnnNote note expr)
  = liftM (vNote note) (vectExpr expr)

-- Application to a type argument: collect all type arguments and treat the
-- whole thing as a type application.
vectExpr e@(_, AnnApp _ arg)
  | isAnnTypeArg arg
  = vectTyAppExpr fn tys
  where
    (fn, tys) = collectAnnTypeArgs e

-- One of the listed boxing constructors applied directly to a literal is
-- kept as is; the lifted version comes from 'liftPD'.
vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
  | Just con <- isDataConId_maybe v
  , is_special_con con
  = do
      let vexpr = App (Var v) (Lit lit)
      lexpr <- liftPD vexpr
      return (vexpr, lexpr)
  where
    is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]

-- Ordinary application becomes a closure application.
vectExpr (_, AnnApp fn arg)
  = do
      arg_ty' <- vectType arg_ty
      res_ty' <- vectType res_ty
      fn'     <- vectExpr fn
      arg'    <- vectExpr arg
      mkClosureApp arg_ty' res_ty' fn' arg'
  where
    (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn

vectExpr (_, AnnCase scrut bndr ty alts)
  | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
  , isAlgTyCon tycon
  = vectAlgCase tycon ty_args scrut bndr ty alts
  where
    scrut_ty = exprType (deAnnotate scrut)

vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
  = do
      vrhs <- localV . inBind bndr $ vectPolyExpr rhs
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
      return $ vLet (vNonRec vbndr vrhs) vbody

-- Recursive let.  As in the non-recursive clause, right-hand sides go
-- through 'vectPolyExpr' (so type-abstracted right-hand sides are handled)
-- and the body through 'vectExpr'.  The two calls used to be the other way
-- round, which sent a type-abstracted rhs to the 'cantVectorise' fallback.
vectExpr (_, AnnLet (AnnRec bs) body)
  = do
      (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
                                $ liftM2 (,)
                                  (zipWithM vect_rhs bndrs rhss)
                                  (vectExpr body)
      return $ vLet (vRec vbndrs vrhss) vbody
  where
    (bndrs, rhss) = unzip bs

    vect_rhs bndr rhs = localV
                      . inBind bndr
                      $ vectPolyExpr rhs

vectExpr e@(_, AnnLam bndr _)
  | isId bndr = vectFnExpr True e
{-
onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
`orElseV` vectLam True fvs bs body
where
(bs,body) = collectAnnValBinders e
-}

vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)
-- | Vectorise an expression that may be a value lambda.
--
-- For a lambda over an 'Id' with no free variables, 'vectScalarLam' is
-- tried first; if that is not applicable or fails, the lambda is
-- vectorised with 'vectLam' (passing on the inline flag).  Everything else
-- is handed to 'vectExpr'.
vectFnExpr :: Bool -> CoreExprWithFVs -> VM VExpr
vectFnExpr inline e@(fvs, AnnLam bndr _)
  | isId bndr = as_scalar `orElseV` as_closure
  where
    (val_bndrs, lam_body) = collectAnnValBinders e

    as_scalar  = onlyIfV (isEmptyVarSet fvs)
                         (vectScalarLam val_bndrs (deAnnotate lam_body))
    as_closure = vectLam inline fvs val_bndrs lam_body

vectFnExpr _ e = vectExpr e
-- | Vectorise a lambda that only computes on scalars.
--
-- Applies only when every argument type and the result type is one of
-- @Int@, @Float@ or @Double@ (type constructors without arguments) and the
-- body consists solely of applications of variables that are either lambda
-- arguments or members of 'globalScalars', plus literals of scalar type.
-- In that case the original lambda is hoisted as @fn@, a closure built by
-- 'scalarClosure' is hoisted as @clo@, and the lifted version is obtained
-- via 'liftPD'.
vectScalarLam :: [Var] -> CoreExpr -> VM VExpr
vectScalarLam args body
  = do
      scalars <- globalScalars
      onlyIfV (eligible scalars) build_closure
  where
    arg_tys = map idType args
    res_ty  = exprType body

    eligible scalars
      =  all is_scalar_ty arg_tys
      && is_scalar_ty res_ty
      && is_scalar (extendVarSetList scalars args) body

    build_closure
      = do
          fn_var  <- hoistExpr (fsLit "fn") (mkLams args body)
          zipf    <- zipScalars arg_tys res_ty
          clo     <- scalarClosure arg_tys res_ty (Var fn_var)
                                                  (App zipf (Var fn_var))
          clo_var <- hoistExpr (fsLit "clo") clo
          lclo    <- liftPD (Var clo_var)
          return (Var clo_var, lclo)

    is_scalar_ty ty
      = case splitTyConApp_maybe ty of
          Just (tycon, []) -> tycon `elem` [intTyCon, floatTyCon, doubleTyCon]
          _                -> False

    is_scalar known (Var v)     = v `elemVarSet` known
    is_scalar _     lit@(Lit _) = is_scalar_ty (exprType lit)
    is_scalar known (App f x)   = is_scalar known f && is_scalar known x
    is_scalar _     _           = False
-- | Vectorise a value lambda into a closure.
--
-- The free variables that have an entry in the local environment are
-- passed to 'buildClosures' (as their vectorised variables) and are bound
-- in front of the lambda's own binders in the hoisted function.  When the
-- inline flag is set the generated lambda is wrapped with 'vInlineMe'.
vectLam :: Bool -> VarSet -> [Var] -> CoreExprWithFVs -> VM VExpr
vectLam inline fvs bs body
  = do
      tyvars                  <- localTyVars
      (free_vars, free_vvars) <- readLEnv (unzip . captured)

      arg_tys <- mapM (vectType . idType) bs
      res_ty  <- vectType (exprType (deAnnotate body))

      buildClosures tyvars free_vvars arg_tys res_ty
        (hoistPolyVExpr tyvars (lam_body free_vars))
  where
    -- Free variables that are bound in the local environment, paired with
    -- what they are bound to.
    captured env
      = [ (var, vv)
        | var <- varSetElems fvs
        , Just vv <- [lookupVarEnv (local_vars env) var]
        ]

    lam_body free_vars
      = do
          lc <- builtin liftingContext
          (vbndrs, vbody) <- vectBndrsIn (free_vars ++ bs) (vectExpr body)
          return (wrap (vLams lc vbndrs vbody))

    wrap | inline    = vInlineMe
         | otherwise = id
-- | Vectorise a type application.  Only a variable in head position is
-- supported (delegating to 'vectPolyVar'); anything else is rejected with
-- 'cantVectorise'.
vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
vectTyAppExpr e tys
  = case e of
      (_, AnnVar v) -> vectPolyVar v tys
      _             -> cantVectorise "Can't vectorise expression"
                                     (ppr (mkTyApps (deAnnotate e) tys))
-- We convert
--
-- case e :: t of v { ... }
--
-- to
--
-- V: let v' = e in case v' of _ { ... }
-- L: let v' = e in case v' `cast` ... of _ { ... }
--
-- When lifting, we have to do it this way because v must have the type
-- [:V(T):] but the scrutinee must be cast to the representation type. We also
-- have to handle the case where v is a wild var correctly.
--
-- FIXME: this is too lazy
-- | Vectorise a case expression over a scrutinee of algebraic type.
--
-- Four shapes of alternative lists are distinguished: a lone @DEFAULT@,
-- a lone constructor alternative without fields, a lone constructor
-- alternative with fields, and the general case.
vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
            -> [(AltCon, [Var], CoreExprWithFVs)]
            -> VM VExpr
-- Lone DEFAULT alternative: only the case binder needs binding.
vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
  = vectCaseBinderOnly scrut bndr ty body

-- A lone constructor alternative without fields is handled identically.
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
  = vectCaseBinderOnly scrut bndr ty body

-- Lone constructor alternative with fields: bind the scrutinee with a let
-- and match on it with a wild-binder case in both versions.
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
  = do
      (vty, lty) <- vectAndLiftType ty
      scrut_rhs  <- vectExpr scrut
      (vbndr, (field_vvars, (vect_body, lift_body)))
        <- bind_scrut (vectBndrsIn bndrs (vectExpr body))
      let (vect_fields, lift_fields) = unzip field_vvars
      (vscrut, lscrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
      vect_dc <- maybeV (lookupDataCon dc)
      let [pdata_dc] = tyConDataCons pdata_tc
      let vcase = wild_case vscrut vty vect_dc  vect_fields vect_body
          lcase = wild_case lscrut lty pdata_dc lift_fields lift_body
      return (vLet (vNonRec vbndr scrut_rhs) (vcase, lcase))
  where
    -- A dead case binder gets a fresh variable named "scrut".
    bind_scrut
      | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
      | otherwise         = vectBndrIn bndr

    wild_case expr res_ty con fields rhs
      = mkWildCase expr (exprType expr) res_ty [(DataAlt con, fields, rhs)]

-- General case: several alternatives.
vectAlgCase tycon _ty_args scrut bndr ty alts
  = do
      vect_tc    <- maybeV (lookupTyCon tycon)
      (vty, lty) <- vectAndLiftType ty

      let n_cons = length (tyConDataCons vect_tc)
      sel_ty  <- builtin (selTy n_cons)
      sel_var <- newLocalVar (fsLit "sel") sel_ty
      let sel_expr = Var sel_var

      (vbndr, alt_results)
        <- bind_scrut (mapM (vect_alt n_cons sel_expr vty lty) sorted_alts)
      let (vect_dcs, vect_fieldss, lift_fieldss, vbodies) = unzip4 alt_results

      scrut_rhs <- vectExpr scrut
      (vect_scrut, lift_scrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
      let [pdata_dc] = tyConDataCons pdata_tc

      let (vect_bodies, lift_bodies) = unzip vbodies

      vdummy <- newDummyVar (exprType vect_scrut)
      ldummy <- newDummyVar (exprType lift_scrut)
      let vect_case = Case vect_scrut vdummy vty
                           (zipWith3 vect_alt_of vect_dcs vect_fieldss vect_bodies)

      lc    <- builtin liftingContext
      lbody <- combinePD vty (Var lc) sel_expr lift_bodies
      let lift_case = Case lift_scrut ldummy lty
                           [ ( DataAlt pdata_dc
                             , sel_var : concat lift_fieldss
                             , lbody ) ]

      return (vLet (vNonRec vbndr scrut_rhs) (vect_case, lift_case))
  where
    -- A dead case binder gets a fresh variable named "scrut".
    bind_scrut
      | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
      | otherwise         = vectBndrIn bndr

    -- Alternatives ordered by constructor tag, DEFAULT first.
    sorted_alts = sortBy (\(con1, _, _) (con2, _, _) -> by_tag con1 con2) alts

    by_tag (DataAlt dc1) (DataAlt dc2) = compare (dataConTag dc1) (dataConTag dc2)
    by_tag DEFAULT       DEFAULT       = EQ
    by_tag DEFAULT       _             = LT
    by_tag _             DEFAULT       = GT
    by_tag _             _             = panic "vectAlgCase/cmp"

    -- Vectorise one constructor alternative.  In the lifted body, the free
    -- local variables of the alternative are rebound to versions packed by
    -- this constructor's tag before the body is vectorised.
    vect_alt arity sel vty lty (DataAlt dc, bndrs, body)
      = do
          vect_dc <- maybeV (lookupDataCon dc)
          let ntag = dataConTagZ vect_dc
              tag  = mkDataConTag vect_dc
              fvs  = freeVarsOf body `delVarSetList` bndrs

          sel_tags <- liftM (\f -> f `App` sel) (builtin (selTags arity))
          lc       <- builtin liftingContext
          elems    <- builtin (selElements arity ntag)

          (field_vvars, vbody)
            <- vectBndrsIn bndrs
             . localV
             $ do
                 packed   <- mapM (pack_var (Var lc) sel_tags tag)
                                  (filter isLocalId (varSetElems fvs))
                 (ve, le) <- vectExpr body
                 return ( ve
                        , Case (elems `App` sel) lc lty
                               [(DEFAULT, [], mkLets (concat packed) le)] )
                 -- empty <- emptyPD vty
                 -- return (ve, Case (elems `App` sel) lc lty
                 --             [(DEFAULT, [], Let (NonRec flags_var flags_expr)
                 --                             $ mkLets (concat binds) le),
                 --               (LitAlt (mkMachInt 0), [], empty)])
          let (vect_fields, lift_fields) = unzip field_vvars
          return (vect_dc, vect_fields, lift_fields, vbody)

    vect_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"

    vect_alt_of vect_dc fields rhs = (DataAlt vect_dc, fields, rhs)

    -- For a locally bound variable, bind a clone of its lifted variable to
    -- the packed value and make the clone the lifted variable from here
    -- on.  Other variables produce no bindings.
    pack_var len tags t v
      = do
          scope <- lookupVar v
          case scope of
            Local (vv, lv) ->
              do
                lv'    <- cloneVar lv
                packed <- packByTagPD (idType vv) (Var lv) len tags t
                updLEnv $ \env ->
                  env { local_vars = extendVarEnv (local_vars env) v (vv, lv') }
                return [NonRec lv' packed]
            _ -> return []

-- | Shared implementation for case expressions whose single alternative
-- binds no fields: vectorise the scrutinee, then the body with the case
-- binder in scope, and combine them with 'vCaseDEFAULT'.
vectCaseBinderOnly :: CoreExprWithFVs -> Var -> Type -> CoreExprWithFVs
                   -> VM VExpr
vectCaseBinderOnly scrut bndr ty body
  = do
      vscrut         <- vectExpr scrut
      (vty, lty)     <- vectAndLiftType ty
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
      return (vCaseDEFAULT vscrut vbndr vty lty vbody)
|