summaryrefslogtreecommitdiff
path: root/compiler/vectorise/Vectorise.hs
blob: d4b970f91c132712644649162a1090d59441c46d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
{-# OPTIONS -w #-}
-- The above warning suppression flag is a temporary kludge.
-- While working on this module you are encouraged to remove it and fix
-- any warnings in the module. See
--     http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings
-- for details

module Vectorise( vectorise )
where

#include "HsVersions.h"

import VectMonad
import VectUtils
import VectType
import VectCore

import DynFlags
import HscTypes

import CoreLint             ( showPass, endPass )
import CoreSyn
import CoreUtils
import CoreFVs
import SimplMonad           ( SimplCount, zeroSimplCount )
import Rules                ( RuleBase )
import DataCon
import TyCon
import Type
import FamInstEnv           ( extendFamInstEnvList )
import InstEnv              ( extendInstEnvList )
import Var
import VarEnv
import VarSet
import Name                 ( Name, mkSysTvName, getName )
import NameEnv
import Id
import MkId                 ( unwrapFamInstScrut )
import OccName
import Module               ( Module )

import DsMonad hiding (mapAndUnzipM)
import DsUtils              ( mkCoreTup, mkCoreTupTy )

import Literal              ( Literal )
import PrelNames
import TysWiredIn
import TysPrim              ( intPrimTy )
import BasicTypes           ( Boxity(..) )

import Outputable
import FastString
import Control.Monad        ( liftM, liftM2, zipWithM, mapAndUnzipM )

-- | Entry point of the vectorisation pass.
--
-- Combines the vectorisation info of the home package table with that of
-- the external package state, vectorises the module inside the 'VM' monad,
-- dumps the result (under @-ddump-vect@) and records the new vectorisation
-- info in the returned 'ModGuts'.  The 'UniqSupply' and 'RuleBase'
-- arguments are ignored.  Always reports a zero 'SimplCount'.
--
-- If 'initV' reports failure ('Nothing') the compiler panics with an
-- explicit message instead of the opaque do-pattern-match failure the
-- previous partial binding produced.
vectorise :: HscEnv -> UniqSupply -> RuleBase -> ModGuts
          -> IO (SimplCount, ModGuts)
vectorise hsc_env _ _ guts
  = do
      showPass dflags "Vectorisation"
      eps <- hscEPS hsc_env
      let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps
      result <- initV hsc_env guts info (vectModule guts)
      (info', guts') <- case result of
                          Just r  -> return r
                          Nothing -> panic "vectorise: initV failed to vectorise the module"
      endPass dflags "Vectorisation" Opt_D_dump_vect (mg_binds guts')
      return (zeroSimplCount dflags, guts' { mg_vect_info = info' })
  where
    dflags = hsc_dflags hsc_env

-- | Vectorise a whole module.
--
-- First the type environment is vectorised; the family instances this
-- produces are added to the family-instance environment, which is installed
-- in the global VM environment before any binding is vectorised (presumably
-- so the bindings can see those instances).  The type-constructor bindings
-- produced by 'vectTypeEnv' are prepended as one recursive group.
vectModule :: ModGuts -> VM ModGuts
vectModule guts
  = do
      (types', new_fam_insts, tc_binds) <- vectTypeEnv (mg_types guts)
      let fam_env = extendFamInstEnvList (mg_fam_inst_env guts) new_fam_insts
      updGEnv (setFamInstEnv fam_env)
      vbinds <- mapM vectTopBind (mg_binds guts)
      let all_binds = Rec tc_binds : vbinds
      return guts { mg_types        = types'
                  , mg_binds        = all_binds
                  , mg_fam_inst_env = fam_env
                  , mg_fam_insts    = mg_fam_insts guts ++ new_fam_insts
                  }

-- | Vectorise one top-level binding group.
--
-- On success the result is a single recursive group holding the original
-- bindings, their vectorised counterparts and whatever bindings
-- 'takeHoisted' returns.  If the attempt fails (as handled by 'orElseV')
-- the original binding is returned unchanged.
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
  = attempt `orElseV` return b
  where
    attempt = do
      var'    <- vectTopBinder var
      expr'   <- vectTopRhs var expr
      hoisted <- takeHoisted
      return (Rec ((var, expr) : (var', expr') : hoisted))

vectTopBind b@(Rec bs)
  = attempt `orElseV` return b
  where
    (vars, exprs) = unzip bs

    attempt = do
      vars'   <- mapM vectTopBinder vars
      exprs'  <- zipWithM vectTopRhs vars exprs
      hoisted <- takeHoisted
      return (Rec (bs ++ zip vars' exprs' ++ hoisted))

-- | Create the vectorised counterpart of a top-level binder.
--
-- The new binder is a clone of @var@ (named with 'mkVectOcc') whose type is
-- the vectorised type of @var@.  It is registered via 'defGlobalVar' as the
-- global vectorised version of @var@ and then returned.
vectTopBinder :: Var -> VM Var
vectTopBinder var
  = do
      var' <- cloneId mkVectOcc var =<< vectType (idType var)
      defGlobalVar var var'
      return var'
    
-- | Vectorise the right-hand side of the top-level binder @var@.
--
-- Free variables are computed, the (possibly polymorphic) expression is
-- vectorised inside 'inBind' for @var@, the whole computation runs under
-- 'closedV', and only the 'vectorised' component of the result is kept.
vectTopRhs :: Var -> CoreExpr -> VM CoreExpr
vectTopRhs var expr
  = closedV (liftM vectorised (inBind var (vectPolyExpr (freeVars expr))))

-- ----------------------------------------------------------------------------
-- Bindings

-- | Vectorise a local binder.
--
-- Produces two copies of @v@: one retyped with the vectorised type and one
-- retyped with the parallel-array type of that vectorised type.  The pair
-- is recorded for @v@ in the local environment's 'local_vars' and returned.
vectBndr :: Var -> VM VVar
vectBndr v
  = do
      vty <- vectType (idType v)
      lty <- mkPArrayType vty
      let vvar = (Id.setIdType v vty, Id.setIdType v lty)
      updLEnv $ \env -> env { local_vars = extendVarEnv (local_vars env) v vvar }
      return vvar

-- | Run @p@ with the vectorised binder for @v@ in scope, inside 'localV'.
-- Returns the binder pair together with @p@'s result.  This is the special
-- case of 'vectBndrIn'' where the action ignores the new binder.
vectBndrIn :: Var -> VM a -> VM (VVar, a)
vectBndrIn v p = vectBndrIn' v (const p)

-- | Like 'vectBndrIn', but the action receives the freshly created binder
-- pair.  Everything runs inside 'localV'.
vectBndrIn' :: Var -> (VVar -> VM a) -> VM (VVar, a)
vectBndrIn' v p
  = localV $ vectBndr v >>= \vv -> liftM ((,) vv) (p vv)

-- | Vectorise all binders in @vs@ (left to right) and then run @p@ with them
-- in scope, inside 'localV'.  Returns the binder pairs and @p@'s result.
vectBndrsIn :: [Var] -> VM a -> VM ([VVar], a)
vectBndrsIn vs p
  = localV (liftM2 (,) (mapM vectBndr vs) p)

-- ----------------------------------------------------------------------------
-- Expressions

-- | Vectorise a variable occurrence.
--
-- A local variable already has both a vectorised and a lifted binder.  For
-- a global one, the lifted form is obtained by applying 'liftPA' to the
-- vectorised variable.
vectVar :: Var -> VM VExpr
vectVar v
  = lookupVar v >>= vect
  where
    vect (Local (vv, lv)) = return (Var vv, Var lv)
    vect (Global vv)      = liftM ((,) (Var vv)) (liftPA (Var vv))

-- | Vectorise a variable applied to type arguments.
--
-- The type arguments are vectorised first, then the variable is looked up.
-- Local variables get both of their binders applied (via 'polyApply').  A
-- global variable gets its vectorised form applied and is then lifted with
-- 'liftPA'.
vectPolyVar :: Var -> [Type] -> VM VExpr
vectPolyVar v tys
  = do
      vtys <- mapM vectType tys
      r    <- lookupVar v
      let app f = polyApply (Var f) vtys
      case r of
        Local (vv, lv) -> liftM2 (,) (app vv) (app lv)
        Global poly    -> do
                            vexpr <- app poly
                            lexpr <- liftPA vexpr
                            return (vexpr, lexpr)

-- | Vectorise a literal.  The vectorised form is the literal itself; the
-- lifted form is the literal passed through 'liftPA'.
vectLiteral :: Literal -> VM VExpr
vectLiteral lit
  = liftM ((,) expr) (liftPA expr)
  where
    expr = Lit lit

-- | Vectorise a possibly type-abstracted expression.
--
-- Leading type binders are split off.  The monomorphic body is vectorised
-- with 'vectExpr' under 'polyAbstract', and the abstraction function that
-- 'polyAbstract' supplies is mapped over both halves with 'mapVect'.
vectPolyExpr :: CoreExprWithFVs -> VM VExpr
vectPolyExpr expr
  = polyAbstract tvs (\abstract -> liftM (mapVect abstract) (vectExpr mono))
  where
    (tvs, mono) = collectAnnTypeBinders expr
                
-- | Vectorise a free-variable-annotated Core expression, producing its
-- vectorised/lifted pair.
--
-- Leading type lambdas are expected to have been stripped already (see
-- 'vectPolyExpr').  A lambda over a non-'Id' binder reaching this function
-- panics.
vectExpr :: CoreExprWithFVs -> VM VExpr
vectExpr (_, AnnType ty)
  = liftM vType (vectType ty)

vectExpr (_, AnnVar v) = vectVar v

vectExpr (_, AnnLit lit) = vectLiteral lit

vectExpr (_, AnnNote note expr)
  = liftM (vNote note) (vectExpr expr)

-- Type application(s): collect all type arguments and hand the head to
-- 'vectTyAppExpr'.
vectExpr e@(_, AnnApp _ arg)
  | isAnnTypeArg arg
  = vectTyAppExpr fn tys
  where
    (fn, tys) = collectAnnTypeArgs e

-- Value application: becomes a closure application at the vectorised
-- argument and result types of the function.
vectExpr (_, AnnApp fn arg)
  = do
      arg_ty' <- vectType arg_ty
      res_ty' <- vectType res_ty
      fn'     <- vectExpr fn
      arg'    <- vectExpr arg
      mkClosureApp arg_ty' res_ty' fn' arg'
  where
    (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn

-- Only case expressions whose scrutinee has an algebraic type are handled.
vectExpr (_, AnnCase scrut bndr ty alts)
  | isAlgType scrut_ty
  = vectAlgCase scrut bndr ty alts
  where
    scrut_ty = exprType (deAnnotate scrut)

vectExpr (_, AnnCase expr bndr ty alts)
  = panic "vectExpr: case"

-- The right-hand side may be polymorphic, so it goes through 'vectPolyExpr';
-- the body is vectorised with the new binder in scope.
vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
  = do
      vrhs <- localV . inBind bndr $ vectPolyExpr rhs
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
      return $ vLet (vNonRec vbndr vrhs) vbody

-- Recursive lets mirror the non-recursive case: every binder is in scope
-- in every right-hand side and in the body.  Right-hand sides (which may be
-- type lambdas) use 'vectPolyExpr'; the body uses 'vectExpr'.  These two
-- were previously swapped, so a polymorphic letrec right-hand side reached
-- 'vectExpr' and panicked.
vectExpr (_, AnnLet (AnnRec bs) body)
  = do
      (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
                                $ liftM2 (,)
                                  (zipWithM vect_rhs bndrs rhss)
                                  (vectExpr body)
      return $ vLet (vRec vbndrs vrhss) vbody
  where
    (bndrs, rhss) = unzip bs

    vect_rhs bndr rhs = localV
                      . inBind bndr
                      $ vectPolyExpr rhs

-- Value lambdas: all leading value binders are collected and the
-- lambda is turned into closures by 'vectLam'.
vectExpr e@(fvs, AnnLam bndr _)
  | not (isId bndr) = pprPanic "vectExpr" (ppr $ deAnnotate e)
  | otherwise = vectLam fvs bs body
  where
    (bs,body) = collectAnnValBinders e

-- | Vectorise a value lambda with binders @bs@ and body @body@, where @fvs@
-- is the lambda's free-variable set.
--
-- The free variables that have a local vectorised binding ('local_vars')
-- are captured as the closure environment.  'buildClosures' is given
-- the in-scope local type variables, those captured binders, and the
-- vectorised argument and result types.  The code it wraps is produced
-- under 'hoistPolyVExpr' (presumably hoisting it out of the local scope —
-- confirm in VectUtils).
vectLam :: VarSet -> [Var] -> CoreExprWithFVs -> VM VExpr
vectLam fvs bs body
  = do
      tyvars <- localTyVars
      -- Split the free variables that are locally bound in the vectorised
      -- environment into the original vars (vs) and their binder pairs (vvs);
      -- free variables with no local entry are skipped.
      (vs, vvs) <- readLEnv $ \env ->
                   unzip [(var, vv) | var <- varSetElems fvs
                                    , Just vv <- [lookupVarEnv (local_vars env) var]]

      arg_tys <- mapM (vectType . idType) bs
      res_ty  <- vectType (exprType $ deAnnotate body)

      buildClosures tyvars vvs arg_tys res_ty
        . hoistPolyVExpr tyvars
        $ do
            -- The body sees both the captured variables and the lambda's
            -- own binders, abstracted together with the lifting context.
            lc <- builtin liftingContext
            (vbndrs, vbody) <- vectBndrsIn (vs ++ bs)
                                           (vectExpr body)
            return $ vLams lc vbndrs vbody
  
-- | Vectorise a head expression applied to the type arguments @tys@.  Only
-- a variable head is supported; anything else panics.
vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
vectTyAppExpr e tys
  = case e of
      (_, AnnVar v) -> vectPolyVar v tys
      _             -> pprPanic "vectTyAppExpr" (ppr $ deAnnotate e)

-- | A Core case alternative whose sub-expressions carry a 'VarSet'
-- annotation, i.e. the alternative form used inside annotated case
-- expressions.
type CoreAltWithFVs = AnnAlt Id VarSet

-- We convert
--
--   case e :: t of v { ... }
--
-- to
--
--   V:    let v = e in case v of _ { ... }
--   L:    let v = e in case v `cast` ... of _ { ... }
--
-- When lifting, we have to do it this way because v must have the type
-- [:V(T):] but the scrutinee must be cast to the representation type.
--
-- Arguments: the scrutinee, the case binder, the result type of the case,
-- and the alternatives.  Only two shapes of alternative list are handled:
--
--   * a single DEFAULT alternative with no binders, and
--   * a single data-constructor alternative.
--
-- Any other alternative list falls through every equation and fails with a
-- pattern-match error.

-- FIXME: this is too lazy
vectAlgCase :: CoreExprWithFVs -> Var -> Type -> [CoreAltWithFVs] -> VM VExpr
vectAlgCase scrut bndr ty [(DEFAULT, [], body)]
  = do
      vscrut <- vectExpr scrut
      vty    <- vectType ty
      lty    <- mkPArrayType vty
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
      return $ vCaseDEFAULT vscrut vbndr vty lty vbody

vectAlgCase scrut bndr ty [(DataAlt dc, bndrs, body)]
  = do
      vty <- vectType ty
      lty <- mkPArrayType vty
      vexpr <- vectExpr scrut
      -- The case binder scopes over the constructor's field binders, which
      -- in turn scope over the body.
      (vbndr, (vbndrs, vbody)) <- vectBndrIn bndr
                                . vectBndrsIn bndrs
                                $ vectExpr body

      -- Scrutinise the let-bound case binder rather than the original
      -- scrutinee (see the translation scheme above).
      (vscrut, arr_tc, arg_tys) <- mkVScrut (vVar vbndr)
      vect_dc <- maybeV (lookupDataCon dc)
      let [arr_dc] = tyConDataCons arr_tc
      -- The leading fields of the array constructor that are not matched
      -- by the alternative's binders get fresh "s" binders (presumably the
      -- shape/length information of the array representation — confirm in
      -- VectType).
      let shape_tys = take (dataConRepArity arr_dc - length bndrs)
                           (dataConRepArgTys arr_dc)
      shape_bndrs <- mapM (newLocalVar FSLIT("s")) shape_tys
      return . vLet (vNonRec vbndr vexpr)
             $ vCaseProd vscrut vty lty vect_dc arr_dc shape_bndrs vbndrs vbody