1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
|
module VectUtils (
collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
collectAnnValBinders,
dataConTagZ, mkDataConTag, mkDataConTagLit,
newLocalVVar,
mkBuiltinCo,
mkPADictType, mkPArrayType, mkPReprType,
parrayReprTyCon, parrayReprDataCon, mkVScrut,
prDFunOfTyCon,
paDictArgType, paDictOfType, paDFunType,
paMethod, mkPR, lengthPA, replicatePA, emptyPA, packPA, combinePA, liftPA,
zipScalars, scalarClosure,
polyAbstract, polyApply, polyVApply,
hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
buildClosure, buildClosures,
mkClosureApp
) where
import VectCore
import VectMonad
import MkCore
import CoreSyn
import CoreUtils
import Coercion
import Type
import TypeRep
import TyCon
import DataCon
import Var
import MkId ( unwrapFamInstScrut )
import TysWiredIn
import BasicTypes ( Boxity(..) )
import Literal ( Literal, mkMachInt )
import Outputable
import FastString
import Control.Monad
-- | Peel the trailing type applications off an annotated expression.
--
-- Returns the expression underneath the type applications together with the
-- type arguments in their original left-to-right order.
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs = peel []
  where
    -- The innermost application is reached last, so consing onto the
    -- accumulator restores source order.
    peel acc (_, AnnApp fun (_, AnnType t)) = peel (t : acc) fun
    peel acc other                          = (other, acc)
-- | Collect the leading type-variable binders of an annotated lambda.
--
-- Returns the binders in binding order and the first sub-expression that is
-- not a lambda over a type variable.
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders = gather []
  where
    gather seen (_, AnnLam bndr body)
      | isTyVar bndr = gather (bndr : seen) body
    -- Binders were accumulated innermost-first; flip them back.
    gather seen rest = (reverse seen, rest)
-- | Collect the leading value (Id) binders of an annotated lambda.
--
-- Returns the binders in binding order and the first sub-expression that is
-- not a lambda over an Id.
collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnValBinders = gather []
  where
    gather seen (_, AnnLam bndr body)
      | isId bndr = gather (bndr : seen) body
    -- Binders were accumulated innermost-first; flip them back.
    gather seen rest = (reverse seen, rest)
-- | Is this annotated expression a type argument ('AnnType')?
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg (_, node) =
  case node of
    AnnType _ -> True
    _         -> False
-- | The tag of a data constructor measured relative to 'fIRST_TAG'.
dataConTagZ :: DataCon -> Int
dataConTagZ = subtract fIRST_TAG . dataConTag
-- | The 'dataConTagZ' tag of a data constructor as a machine-int literal.
mkDataConTagLit :: DataCon -> Literal
mkDataConTagLit con = mkMachInt (toInteger (dataConTagZ con))
-- | The 'dataConTagZ' tag of a data constructor as a Core integer expression.
mkDataConTag :: DataCon -> CoreExpr
mkDataConTag con = mkIntLitInt (dataConTagZ con)
-- | If the type is a primitive type constructor applied to no arguments,
-- return that type constructor; otherwise 'Nothing'.
splitPrimTyCon :: Type -> Maybe TyCon
splitPrimTyCon ty =
  case splitTyConApp_maybe ty of
    Just (tycon, [])
      | isPrimTyCon tycon -> Just tycon
    _                     -> Nothing
-- | Apply a builtin type constructor (selected from 'Builtins') to the given
-- type arguments.
mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
mkBuiltinTyConApp get_tc tys = liftM (`mkTyConApp` tys) (builtin get_tc)
-- | Right-nest a binary builtin type constructor over a list of types,
-- ending in the given final type:
--
-- > mkBuiltinTyConApps tc [t1, t2] t  ==>  tc t1 (tc t2 t)
mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
mkBuiltinTyConApps get_tc tys ty = liftM nest (builtin get_tc)
  where
    nest tc = foldr (\arg rest -> mkTyConApp tc [arg, rest]) ty tys
{-
mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
mkBuiltinTyConApps1 _ dft [] = return dft
mkBuiltinTyConApps1 get_tc _ tys
= do
tc <- builtin get_tc
case tys of
[] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
_ -> return $ foldr1 (mk tc) tys
where
mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
mkClosureType :: Type -> Type -> VM Type
mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
-}
-- | Build the nested closure type taking the given argument types to the
-- result type, using the builtin 'closureTyCon'.
mkClosureTypes :: [Type] -> Type -> VM Type
mkClosureTypes arg_tys res_ty = mkBuiltinTyConApps closureTyCon arg_tys res_ty
-- | Apply the builtin 'preprTyCon' to a single type.
mkPReprType :: Type -> VM Type
mkPReprType = mkBuiltinTyConApp preprTyCon . (: [])
-- | Apply the builtin 'paTyCon' to a single type, giving the type of its PA
-- dictionary.
mkPADictType :: Type -> VM Type
mkPADictType = mkBuiltinTyConApp paTyCon . (: [])
-- | The parallel-array type for an element type.
--
-- A nullary primitive type constructor is mapped to its dedicated primitive
-- array tycon (vectorisation fails if none is registered); every other type
-- is wrapped in the builtin 'parrayTyCon'.
mkPArrayType :: Type -> VM Type
mkPArrayType ty =
  case splitPrimTyCon ty of
    Just tycon -> lookupPrimPArray tycon >>= fromPrim tycon
    Nothing    -> mkBuiltinTyConApp parrayTyCon [ty]
  where
    fromPrim _     (Just arr) = return (mkTyConApp arr [])
    fromPrim tycon Nothing    =
      cantVectorise "Primitive tycon not vectorised" (ppr tycon)
-- | The coercion given by a builtin type constructor applied to no arguments.
mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
mkBuiltinCo get_tc = liftM (\tc -> mkTyConApp tc []) (builtin get_tc)
-- | Look up the family instance of the builtin 'parrayTyCon' at the given
-- element type, returning the instance tycon and its type arguments.
parrayReprTyCon :: Type -> VM (TyCon, [Type])
parrayReprTyCon ty = do
  parr_tc <- builtin parrayTyCon
  lookupFamInst parr_tc [ty]
-- | The single data constructor of the parallel-array representation tycon
-- for the given element type, together with the instance's type arguments.
--
-- Panics with the offending tycon if the representation tycon does not have
-- exactly one data constructor.
parrayReprDataCon :: Type -> VM (DataCon, [Type])
parrayReprDataCon ty
  = do
      (tc, arg_tys) <- parrayReprTyCon ty
      -- Bound lazily (as the original irrefutable pattern was), so the check
      -- only fires when the constructor is actually demanded.
      let dc = case tyConDataCons tc of
                 [con] -> con
                 _     -> pprPanic "parrayReprDataCon: expected exactly one data constructor"
                                   (ppr tc)
      return (dc, arg_tys)
-- | Prepare a vectorised expression for use as a case scrutinee.
--
-- The lifted component is passed through 'unwrapFamInstScrut' for the
-- parallel-array representation tycon of the vectorised component's type;
-- that tycon and its type arguments are returned alongside.
mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
mkVScrut (ve, le) = do
  (tc, arg_tys) <- parrayReprTyCon (exprType ve)
  let lifted_scrut = unwrapFamInstScrut tc arg_tys le
  return ((ve, lifted_scrut), tc, arg_tys)
-- | The PR dictionary function registered for a type constructor, as a Core
-- variable expression. Vectorisation fails if none is registered.
prDFunOfTyCon :: TyCon -> VM CoreExpr
prDFunOfTyCon tycon = do
  dfun <- maybeCantVectoriseM "No PR dictionary for tycon" (ppr tycon)
                              (lookupTyConPR tycon)
  return (Var dfun)
-- | The type of the PA dictionary argument needed for a type variable, if any.
--
-- * For a variable of lifted type kind this is the PA dictionary type of the
--   variable itself.
-- * For a variable of function kind @k1 -> k2@ a fresh variable of kind @k1@
--   is introduced; if it needs a dictionary, the result quantifies over it
--   and takes its dictionary to the dictionary for the applied type.
--   Otherwise the argument kind is skipped.
-- * For any other kind no dictionary is needed ('Nothing').
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = dictFor (TyVarTy tv) (tyVarKind tv)
  where
    -- Look through kind synonyms first.
    dictFor ty kind
      | Just kind' <- kindView kind = dictFor ty kind'
    dictFor ty (FunTy arg_kind res_kind) = do
      arg_tv   <- newTyVar (fsLit "a") arg_kind
      arg_dict <- dictFor (TyVarTy arg_tv) arg_kind
      case arg_dict of
        Nothing      -> dictFor ty res_kind
        Just dict_ty -> do
          res_dict <- dictFor (AppTy ty (TyVarTy arg_tv)) res_kind
          return (fmap (ForAllTy arg_tv . FunTy dict_ty) res_dict)
    dictFor ty kind
      | isLiftedTypeKind kind = liftM Just (mkPADictType ty)
    dictFor _ _ = return Nothing
-- | Construct the PA dictionary expression for a type by splitting it into a
-- head and its arguments and delegating to 'paDictOfTyApp'.
paDictOfType :: Type -> VM CoreExpr
paDictOfType ty = uncurry paDictOfTyApp (splitAppTys ty)
-- | Construct the PA dictionary for a type head applied to type arguments.
--
-- The head is first expanded with 'coreView'. A type variable uses the
-- dictionary recorded for it in the environment; a type constructor uses its
-- registered PA dfun. In both cases the dictionary function is applied to
-- the type arguments and their dictionaries. Any other head cannot be
-- vectorised.
paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
paDictOfTyApp ty_fn ty_args
  | Just expanded <- coreView ty_fn = paDictOfTyApp expanded ty_args
  | otherwise =
      case ty_fn of
        TyVarTy tv -> do
          dfun <- maybeV (lookupTyVarPA tv)
          paDFunApply dfun ty_args
        TyConApp tc _ -> do
          dfun <- maybeCantVectoriseM "No PA dictionary for tycon" (ppr tc)
                                      (lookupTyConPA tc)
          paDFunApply (Var dfun) ty_args
        _ -> cantVectorise "Can't construct PA dictionary for type" (ppr ty_fn)
-- | The type of the PA dictionary function for a type constructor:
-- quantified over the tycon's type variables, taking one dictionary argument
-- for each variable that needs one, and returning the PA dictionary of the
-- tycon applied to its own variables.
paDFunType :: TyCon -> VM Type
paDFunType tc = do
    margs  <- mapM paDictArgType tvs
    res_ty <- mkPADictType (mkTyConApp tc (mkTyVarTys tvs))
    let dict_args = [arg | Just arg <- margs]
    return (mkForAllTys tvs (mkFunTys dict_args res_ty))
  where
    tvs = tyConTyVars tc
-- | Apply a dictionary function to type arguments followed by the PA
-- dictionaries of those same types.
paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
paDFunApply dfun tys = liftM (mkApps (mkTyApps dfun tys)) (mapM paDictOfType tys)
-- | A PA method: a selector for its builtin variable paired with the name
-- used to look up primitive-type implementations.
type PAMethod = (Builtins -> Var, String)

pa_length :: PAMethod
pa_length = (lengthPAVar, "lengthPA")

pa_replicate :: PAMethod
pa_replicate = (replicatePAVar, "replicatePA")

pa_empty :: PAMethod
pa_empty = (emptyPAVar, "emptyPA")

pa_pack :: PAMethod
pa_pack = (packPAVar, "packPA")
-- | The implementation of a PA method at a given type.
--
-- For a nullary primitive tycon the method is looked up by name among the
-- primitive methods (vectorisation fails if it is missing). For every other
-- type the builtin method variable is applied to the type and its PA
-- dictionary.
paMethod :: PAMethod -> Type -> VM CoreExpr
paMethod (method, name) ty =
  case splitPrimTyCon ty of
    Just tycon ->
      liftM Var $
        maybeCantVectoriseM "No PA method"
                            (text name <+> text "for" <+> ppr tycon)
                            (lookupPrimMethod tycon name)
    Nothing -> do
      fn   <- builtin method
      dict <- paDictOfType ty
      return (mkApps (Var fn) [Type ty, dict])
-- | Apply the builtin 'mkPRVar' to a type and that type's PA dictionary.
mkPR :: Type -> VM CoreExpr
mkPR ty =
  builtin mkPRVar   >>= \fn ->
  paDictOfType ty   >>= \dict ->
  return (Var fn `App` Type ty `App` dict)
-- | Apply the @lengthPA@ method for the given element type to an expression.
lengthPA :: Type -> CoreExpr -> VM CoreExpr
lengthPA ty x = do
  len_fn <- paMethod pa_length ty
  return (App len_fn x)
-- | Apply the @replicatePA@ method, selected by the type of @x@, to a length
-- expression and @x@.
replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
replicatePA len x = do
  rep_fn <- paMethod pa_replicate (exprType x)
  return (mkApps rep_fn [len, x])
-- | The @emptyPA@ method at the given element type.
emptyPA :: Type -> VM CoreExpr
emptyPA ty = paMethod pa_empty ty
-- | Apply the @packPA@ method for the given element type to the three
-- argument expressions, in the order given.
packPA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> VM CoreExpr
packPA ty xs len sel = do
  pack_fn <- paMethod pa_pack ty
  return (mkApps pack_fn [xs, len, sel])
-- | Apply the n-ary combine method for the given element type, where @n@ is
-- the number of expressions in @xs@. The method receives @len@, @sel@ and
-- @is@ followed by all of @xs@.
combinePA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> [CoreExpr]
          -> VM CoreExpr
combinePA ty len sel is xs = do
    combine_fn <- paMethod (combinePAVar arity, method_name) ty
    return (mkApps combine_fn (len : sel : is : xs))
  where
    arity       = length xs
    method_name = "combine" ++ show arity ++ "PA"
-- | Replicate an expression using the builtin lifting-context variable as the
-- length.
liftPA :: CoreExpr -> VM CoreExpr
liftPA x = builtin liftingContext >>= \lc -> replicatePA (Var lc) x
-- | Build the builtin scalar zip function for the given argument and result
-- types: the zip variable chosen by the number of arguments, applied to all
-- the types (arguments, then result) and to the scalar-class instance found
-- for each of them.
zipScalars :: [Type] -> Type -> VM CoreExpr
zipScalars arg_tys res_ty = do
    scalar <- builtin scalarClass
    insts  <- mapM (\ty -> lookupInst scalar [ty]) all_tys
    zipf   <- builtin (scalarZip (length arg_tys))
    -- Only the dfun component of each instance lookup is needed.
    let dict_args = map (Var . fst) insts
    return (mkApps (mkTyApps (Var zipf) all_tys) dict_args)
  where
    all_tys = arg_tys ++ [res_ty]
-- | Build a closure from a scalar function and its array version using the
-- builtin closure constructor for the given number of arguments.
--
-- The constructor is applied to all argument types plus the result type,
-- then to the PA dictionaries of all argument types except the last, and
-- finally to the two functions.
--
-- NOTE(review): 'init' is partial, so @arg_tys@ is assumed to be non-empty;
-- confirm against callers.
scalarClosure :: [Type] -> Type -> CoreExpr -> CoreExpr -> VM CoreExpr
scalarClosure arg_tys res_ty scalar_fun array_fun = do
    ctr <- builtin (closureCtrFun (length arg_tys))
    pas <- mapM paDictOfType (init arg_tys)
    let value_args = pas ++ [scalar_fun, array_fun]
    return (mkApps (mkTyApps (Var ctr) all_tys) value_args)
  where
    all_tys = arg_tys ++ [res_ty]
-- | Create a fresh pair of local variables sharing one name: the first has
-- the given type, the second the parallel-array type built from it.
newLocalVVar :: FastString -> Type -> VM VVar
newLocalVVar fs vty = do
  lifted_ty  <- mkPArrayType vty
  vect_var   <- newLocalVar fs vty
  lifted_var <- newLocalVar fs lifted_ty
  return (vect_var, lifted_var)
-- | Run a computation in a local environment in which the given type
-- variables are in scope, together with fresh PA dictionary variables for
-- those that need one.
--
-- The computation receives a function that wraps an expression in lambdas
-- over the type variables followed by the dictionary variables.
polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
polyAbstract tvs p = localV $ do
    mdicts <- mapM mk_dict_var tvs
    zipWithM_ register tvs mdicts
    p (mkLams (tvs ++ [dict | Just dict <- mdicts]))
  where
    -- A dictionary variable is only created when the type variable's kind
    -- calls for one.
    mk_dict_var tv =
      paDictArgType tv >>=
        maybe (return Nothing) (liftM Just . newLocalVar (fsLit "dPA"))

    register tv Nothing     = defLocalTyVar tv
    register tv (Just dict) = defLocalTyVarWithPA tv (Var dict)
-- | Apply an expression to type arguments followed by the PA dictionaries of
-- those types.
polyApply :: CoreExpr -> [Type] -> VM CoreExpr
polyApply expr tys = liftM (mkApps (mkTyApps expr tys)) (mapM paDictOfType tys)
-- | Like 'polyApply', but applies both components of a vectorised expression
-- to the same type arguments and PA dictionaries.
polyVApply :: VExpr -> [Type] -> VM VExpr
polyVApply expr tys = do
    dicts <- mapM paDictOfType tys
    return (mapVect (instantiate dicts) expr)
  where
    instantiate dicts e = mkApps (mkTyApps e tys) dicts
-- | Record a binding at the front of the global environment's list of
-- hoisted bindings.
hoistBinding :: Var -> CoreExpr -> VM ()
hoistBinding v e = updGEnv push
  where
    push env = env { global_bindings = (v, e) : global_bindings env }
-- | Hoist an expression: bind it to a fresh variable of the expression's
-- type, record the binding with 'hoistBinding', and return the variable.
hoistExpr :: FastString -> CoreExpr -> VM Var
hoistExpr fs expr =
  newLocalVar fs (exprType expr) >>= \var ->
  hoistBinding var expr          >>
  return var
-- | Hoist both components of a vectorised expression. The variable names are
-- the current binding name prefixed with @v@ and @l@ respectively.
hoistVExpr :: VExpr -> VM VVar
hoistVExpr (ve, le) = do
  base <- getBindName
  vect_var   <- hoistExpr (consFS 'v' base) ve
  lifted_var <- hoistExpr (consFS 'l' base) le
  return (vect_var, lifted_var)
-- | Hoist a polymorphic vectorised expression.
--
-- The expression is built inside 'closedV' and 'polyAbstract', abstracted
-- over the given type variables (and their dictionaries), and hoisted. The
-- result is the hoisted variable re-applied to those same type variables.
hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
hoistPolyVExpr tvs p = do
    closed  <- closedV (polyAbstract tvs abstracted)
    hoisted <- hoistVExpr closed
    polyVApply (vVar hoisted) (mkTyVarTys tvs)
  where
    abstracted wrap = do
      body <- p
      return (mapVect wrap body)
-- | Return the hoisted bindings accumulated so far and reset the global
-- environment's list of them to empty.
takeHoisted :: VM [(Var, CoreExpr)]
takeHoisted = do
  env <- readGEnv id
  setGEnv (env { global_bindings = [] })
  return (global_bindings env)
{-
boxExpr :: Type -> VExpr -> VM VExpr
boxExpr ty (vexpr, lexpr)
| Just (tycon, []) <- splitTyConApp_maybe ty
, isUnLiftedTyCon tycon
= do
r <- lookupBoxedTyCon tycon
case r of
Just tycon' -> let [dc] = tyConDataCons tycon'
in
return (mkConApp dc [vexpr], lexpr)
Nothing -> return (vexpr, lexpr)
-}
-- | Build a closure value in both vectorised and lifted form.
--
-- The builtin 'mkClosureVar' / 'mkClosurePVar' are applied to the argument,
-- result and environment types, then to the environment type's PA
-- dictionary, both function components, and the matching environment
-- component.
mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosure arg_ty res_ty env_ty (vfn, lfn) (venv, lenv) = do
    dict <- paDictOfType env_ty
    mkv  <- builtin mkClosureVar
    mkl  <- builtin mkClosurePVar
    let close ctor env_expr =
          Var ctor `mkTyApps` ty_args `mkApps` [dict, vfn, lfn, env_expr]
    return (close mkv venv, close mkl lenv)
  where
    ty_args = [arg_ty, res_ty, env_ty]
-- | Apply a closure to an argument in both vectorised and lifted form, using
-- the builtin 'applyClosureVar' and 'applyClosurePVar'.
mkClosureApp :: Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosureApp arg_ty res_ty (vclo, lclo) (varg, larg) = do
    vapply <- builtin applyClosureVar
    lapply <- builtin applyClosurePVar
    return (apply vapply vclo varg, apply lapply lclo larg)
  where
    apply fn clo arg = Var fn `mkTyApps` [arg_ty, res_ty] `mkApps` [clo, arg]
-- | Build a chain of closures, one per argument type.
--
-- With no argument types the body is returned as is. With one, a single
-- closure is built. With more, the outer closure takes the first argument
-- and its body is a hoisted function that builds the closures for the
-- remaining arguments, with the new argument added to the captured
-- variables.
buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
buildClosures _ _ [] _ mk_body = mk_body
buildClosures tvs vars [arg_ty] res_ty mk_body =
  liftM vInlineMe (buildClosure tvs vars arg_ty res_ty mk_body)
buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body = do
  res_ty' <- mkClosureTypes arg_tys res_ty
  arg     <- newLocalVVar (fsLit "x") arg_ty
  let captured = vars ++ [arg]
      inner    = do
        lc  <- builtin liftingContext
        clo <- buildClosures tvs captured arg_tys res_ty mk_body
        return (vLams lc captured clo)
  liftM vInlineMe
        (buildClosure tvs vars arg_ty res_ty' (hoistPolyVExpr tvs inner))
-- | Build a single closure over the given free variables.
--
-- Sketch of the result:
--
--   (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
--   where
--     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
--     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
--
-- The function taking environment and argument is hoisted; the closure then
-- pairs it with the environment built from the free variables.
buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
buildClosure tvs vars arg_ty res_ty mk_body = do
  (env_ty, env, bind) <- buildEnv vars
  env_bndr <- newLocalVVar (fsLit "env") env_ty
  arg_bndr <- newLocalVVar (fsLit "arg") arg_ty
  let mk_fn = do
        lc   <- builtin liftingContext
        body <- mk_body
        -- Apply the body to the free variables and the argument, then
        -- unpack the environment around it.
        let applied = vVarApps lc body (vars ++ [arg_bndr])
        unpacked <- bind (vVar env_bndr) applied
        return (vInlineMe (vLamsWithoutLC [env_bndr, arg_bndr] unpacked))
  fn <- hoistPolyVExpr tvs mk_fn
  mkClosure arg_ty res_ty env_ty fn env
-- | Build a closure environment from a list of variables.
--
-- Returns the environment type, the environment expression (vectorised and
-- lifted), and a function that, given an environment expression and a body,
-- rebinds the variables from the environment around the body.
buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
buildEnv vvs = do
    lc <- builtin liftingContext
    let (env_ty, venv, vbind) = mkVectEnv tys vs
    (lenv, lbind) <- mkLiftEnv lc tys ls
    let bind (venv_expr, lenv_expr) (vbody, lbody) = do
          lbody' <- lbind lenv_expr lbody
          return (vbind venv_expr vbody, lbody')
    return (env_ty, (venv, lenv), bind)
  where
    (vs, ls) = unzip vvs
    tys      = map varType vs
-- | The vectorised (non-lifted) environment for a list of variables: its
-- type, the expression that builds it, and a function that unpacks it around
-- a body.
--
-- No variables give the unit value, a single variable is used directly, and
-- several variables are packed into a boxed tuple.
mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
mkVectEnv []   []  = (unitTy, Var unitDataConId, \_ body -> body)
mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
mkVectEnv tys  vs  = (tup_ty, mkCoreTup (map Var vs), unpack)
  where
    tup_ty  = mkCoreTupTy tys
    tup_con = tupleCon Boxed (length vs)
    unpack env body =
      mkWildCase env tup_ty (exprType body) [(DataAlt tup_con, vs, body)]
-- | The lifted environment for a list of variables: the expression that
-- builds it and a function that unpacks it around a body.
--
-- A single variable is used directly; unpacking rebinds it and binds the
-- lifting context to its length. Otherwise the variables are packed with the
-- data constructor of the parallel-array representation of their tuple type,
-- with the lifting context as the first field.
--
-- Panics with the offending tycon if that representation tycon does not have
-- exactly one data constructor.
mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
mkLiftEnv lc [ty] [v]
  = return (Var v, \env body ->
                   do
                     len <- lengthPA ty (Var v)
                     return . Let (NonRec v env)
                            $ Case len lc (exprType body) [(DEFAULT, [], body)])

-- NOTE: this transparently deals with empty environments
mkLiftEnv lc tys vs
  = do
      (env_tc, env_tyargs) <- parrayReprTyCon vty
      -- An empty environment is represented by a single unit field, so a
      -- dummy binder is needed to match it.
      bndrs <- if null vs then do
                                 v <- newDummyVar unitTy
                                 return [v]
                          else return vs
      let -- Bound lazily (as the original irrefutable pattern was), so the
          -- check only fires when the constructor is actually demanded.
          env_con = case tyConDataCons env_tc of
                      [con] -> con
                      _     -> pprPanic "mkLiftEnv: expected exactly one data constructor"
                                        (ppr env_tc)

          env = Var (dataConWrapId env_con)
                `mkTyApps` env_tyargs
                `mkApps`   (Var lc : args)

          bind env_expr body
            = let scrut = unwrapFamInstScrut env_tc env_tyargs env_expr
              in
              return $ mkWildCase scrut (exprType scrut)
                                  (exprType body)
                                  [(DataAlt env_con, lc : bndrs, body)]
      return (env, bind)
  where
    vty = mkCoreTupTy tys

    args | null vs   = [Var unitDataConId]
         | otherwise = map Var vs
|