summaryrefslogtreecommitdiff
path: root/compiler/vectorise/Vectorise/Monad.hs
blob: 04a0bf24defd56711e76c801eafbad898844a36b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
module Vectorise.Monad (
  module Vectorise.Monad.Base,
  module Vectorise.Monad.Naming,
  module Vectorise.Monad.Local,
  module Vectorise.Monad.Global,
  module Vectorise.Monad.InstEnv,
  initV,

  -- * Builtins
  liftBuiltinDs,
  builtin,
  builtins,
  
  -- * Variables
  lookupVar,
  lookupVar_maybe,
  addGlobalParallelVar, 
  addGlobalParallelTyCon, 
) where

import Vectorise.Monad.Base
import Vectorise.Monad.Naming
import Vectorise.Monad.Local
import Vectorise.Monad.Global
import Vectorise.Monad.InstEnv
import Vectorise.Builtins
import Vectorise.Env

import CoreSyn
import DsMonad
import HscTypes hiding ( MonadThings(..) )
import DynFlags
import MonadUtils (liftIO)
import InstEnv
import Class
import TyCon
import NameSet
import VarSet
import VarEnv
import Var
import Id
import Name
import ErrUtils
import Outputable


-- |Run a vectorisation computation.
--
-- The computation is executed inside the desugarer monad (via 'initDs') for the module described
-- by the given 'ModGuts', starting from the supplied 'VectInfo'.
--
-- * On success, the result is paired with the 'VectInfo' computed by 'modVectInfo' from the final
--   global environment.
--
-- * If the vectorisation computation itself fails (the 'No' case), a warning carrying the reason
--   is printed and 'Nothing' is returned.
--
-- * If 'initDs' does not produce a result at all, this is likewise reported as 'Nothing' (rather
--   than aborting with a pattern-match failure in 'IO').
--
-- The incoming and outgoing 'VectInfo' are dumped when 'Opt_D_dump_vt_trace' is set.
--
initV :: HscEnv
      -> ModGuts
      -> VectInfo
      -> VM a
      -> IO (Maybe (VectInfo, a))
initV hsc_env guts info thing_inside
  = do { dumpIfVtTrace "Incoming VectInfo" (ppr info)

       ; let type_env = typeEnvFromEntities ids (mg_tcs guts) (mg_fam_insts guts)
       ; (_, mb_res) <- initDs hsc_env (mg_module guts)
                                       (mg_rdr_env guts) type_env go

           -- 'initDs' wraps the result of 'go' in another 'Maybe'.  Binding it with the refutable
           -- pattern 'Just res' would abort with a pattern-match failure whenever 'initDs' yields
           -- 'Nothing' (presumably when the desugarer run recorded errors -- confirm against
           -- 'initDs'); instead, collapse both layers so that case is treated as a failure, too.
       ; let res = case mb_res of
                     Just inner -> inner
                     Nothing    -> Nothing

       ; case res of
           Nothing
             -> dumpIfVtTrace "Vectorisation FAILED!" empty
           Just (info', _)
             -> dumpIfVtTrace "Outgoing VectInfo" (ppr info')

       ; return res
       }
  where
    dflags = hsc_dflags hsc_env

    -- Only emits output if the vectoriser trace dump flag is set.
    dumpIfVtTrace = dumpIfSet_dyn dflags Opt_D_dump_vt_trace

    -- The binders introduced by a single binding group.
    bindsToIds (NonRec v _)   = [v]
    bindsToIds (Rec    binds) = map fst binds

    -- All binders of the module's top-level bindings.
    ids = concatMap bindsToIds (mg_binds guts)

    -- The actual work, run inside the desugarer monad.
    go
      = do {   -- set up tables of builtin entities
           ; bi           <- initBuiltins
           ; builtin_vars <- initBuiltinVars bi

               -- set up class and type family environments
           ; eps <- liftIO $ hscEPS hsc_env
           ; let famInstEnvs = (eps_fam_inst_env eps, mg_fam_inst_env guts)
                 instEnvs    = (eps_inst_env     eps, mg_inst_env     guts)
                 builtin_pas = initClassDicts instEnvs (paClass bi)  -- grab all 'PA' and..
                 builtin_prs = initClassDicts instEnvs (prClass bi)  -- ..'PR' class instances

               -- construct the initial global environment
           ; let genv = extendImportedVarsEnv builtin_vars
                        . setPAFunsEnv        builtin_pas
                        . setPRFunsEnv        builtin_prs
                        $ initGlobalEnv (gopt Opt_VectorisationAvoidance dflags)
                                        info (mg_vect_decls guts) instEnvs famInstEnvs

               -- perform vectorisation
           ; r <- runVM thing_inside bi genv emptyLocalEnv
           ; case r of
               Yes genv' _ x -> return $ Just (new_info genv', x)
               No reason     -> do { unqual <- mkPrintUnqualifiedDs
                                   ; liftIO $
                                       printInfoForUser dflags unqual $
                                         mkDumpDoc "Warning: vectorisation failure:" reason
                                   ; return Nothing
                                   }
           }

    -- The 'VectInfo' to hand back, derived from the final global environment.
    new_info final_genv = modVectInfo final_genv ids (mg_tcs guts) (mg_vect_decls guts) info

    -- For a given DPH class, produce a mapping from type constructor (in head position) to the
    -- instance dfun for that type constructor and class.  (DPH class instances cannot overlap in
    -- head constructors.)
    --
    initClassDicts :: (InstEnv, InstEnv) -> Class -> [(Name, Var)]
    initClassDicts insts cls = map headTyConDFun $ classInstances insts cls
      where
        headTyConDFun i | [Just tc] <- instanceRoughTcs i = (tc, instanceDFunId i)
                        | otherwise                       = pprPanic invalidInstance (ppr i)

    invalidInstance = "Invalid DPH instance (overlapping in head constructor)"


-- Builtins -------------------------------------------------------------------

-- |Embed a desugarer action that depends on the `Builtins` in the vectorisation monad.  The
-- global and local environments are passed through unchanged.
--
liftBuiltinDs :: (Builtins -> DsM a) -> VM a
liftBuiltinDs p = VM run
  where
    run bi genv lenv = p bi >>= return . Yes genv lenv

-- |Apply a projection to the set of builtins and return the projected value; both environments
-- are left untouched.
--
builtin :: (Builtins -> a) -> VM a
builtin f = VM $ \bi genv lenv ->
  let projected = f bi
  in  return $ Yes genv lenv projected

-- |Turn a function that takes the `Builtins` as its second argument into a one-argument function
-- inside the vectorisation monad, with the `Builtins` already supplied.
--
builtins :: (a -> Builtins -> b) -> VM (a -> b)
builtins f = VM run
  where
    run bi genv lenv = return $ Yes genv lenv (\x -> f x bi)


-- Var ------------------------------------------------------------------------

-- |Look up the vectorised, and for local variables also the lifted, version of a variable.
--
-- * A variable found in the local environment yields both the vectorised and the lifted version.
-- * A variable found in the global environment yields the vectorised version only.
-- * A variable found in neither environment is reported via 'dumpVar'.
--
lookupVar :: Var -> VM (Scope Var (Var, Var))
lookupVar v = lookupVar_maybe v >>= maybe notVectorised return
  where
    notVectorised = getDynFlags >>= \dflags -> dumpVar dflags v

-- |Like 'lookupVar', but yields 'Nothing' if the variable is in neither environment.  The local
-- environment is consulted first; the global one only if the local lookup misses.
--
lookupVar_maybe :: Var -> VM (Maybe (Scope Var (Var, Var)))
lookupVar_maybe v
  = readLEnv inLocal >>= maybe viaGlobal (return . Just . Local)
  where
    inLocal  lenv = lookupVarEnv (local_vars  lenv) v
    inGlobal genv = lookupVarEnv (global_vars genv) v

    viaGlobal = fmap Global <$> readGEnv inGlobal

-- |Report, via 'cantVectorise', a variable for which no vectorised version is available.  The
-- message heading distinguishes class operation ids from all other variables.
--
dumpVar :: DynFlags -> Var -> a
dumpVar dflags var
  = case isClassOpId_maybe var of
      Just _  -> report "ClassOpId not vectorised:"
      Nothing -> report "Variable not vectorised:"
  where
    report heading = cantVectorise dflags heading (ppr var)


-- Global parallel entities ----------------------------------------------------

-- |Record the given variable as parallel in the global environment — i.e., executing the
-- associated code might involve parallel array computations.  The addition is traced first.
--
addGlobalParallelVar :: Var -> VM ()
addGlobalParallelVar var
  = traceVt "addGlobalParallelVar" (ppr var) >> updGEnv markParallel
  where
    markParallel genv
      = genv { global_parallel_vars = global_parallel_vars genv `extendVarSet` var }

-- |Record the name of the given type constructor as parallel in the global environment — i.e.,
-- its values might embed parallel arrays.  The addition is traced first.
--
addGlobalParallelTyCon :: TyCon -> VM ()
addGlobalParallelTyCon tycon
  = traceVt "addGlobalParallelTyCon" (ppr tycon) >> updGEnv markParallel
  where
    markParallel genv
      = genv { global_parallel_tycons
                 = global_parallel_tycons genv `addOneToNameSet` tyConName tycon }