summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/coreSyn/CoreFVs.lhs11
-rw-r--r--compiler/coreSyn/CoreSubst.lhs11
-rw-r--r--compiler/coreSyn/CoreSyn.lhs4
-rw-r--r--compiler/coreSyn/PprCore.lhs3
-rw-r--r--compiler/deSugar/Desugar.lhs2
-rw-r--r--compiler/hsSyn/HsDecls.lhs10
-rw-r--r--compiler/iface/LoadIface.lhs14
-rw-r--r--compiler/iface/MkIface.lhs18
-rw-r--r--compiler/iface/TcIface.lhs30
-rw-r--r--compiler/main/HscTypes.lhs54
-rw-r--r--compiler/main/TidyPgm.lhs10
-rw-r--r--compiler/parser/Parser.y.pp5
-rw-r--r--compiler/rename/RnSource.lhs10
-rw-r--r--compiler/typecheck/TcBinds.lhs9
-rw-r--r--compiler/typecheck/TcHsSyn.lhs2
-rw-r--r--compiler/vectorise/Vectorise.hs396
-rw-r--r--compiler/vectorise/Vectorise/Convert.hs18
-rw-r--r--compiler/vectorise/Vectorise/Env.hs88
-rw-r--r--compiler/vectorise/Vectorise/Exp.hs1048
-rw-r--r--compiler/vectorise/Vectorise/Monad.hs26
-rw-r--r--compiler/vectorise/Vectorise/Monad/Global.hs60
-rw-r--r--compiler/vectorise/Vectorise/Monad/InstEnv.hs13
-rw-r--r--compiler/vectorise/Vectorise/Monad/Local.hs22
-rw-r--r--compiler/vectorise/Vectorise/Type/Classify.hs42
-rw-r--r--compiler/vectorise/Vectorise/Type/Env.hs211
-rw-r--r--compiler/vectorise/Vectorise/Type/Type.hs13
-rw-r--r--compiler/vectorise/Vectorise/Utils.hs27
27 files changed, 1080 insertions, 1077 deletions
diff --git a/compiler/coreSyn/CoreFVs.lhs b/compiler/coreSyn/CoreFVs.lhs
index d2bb6ed57a..2a11723fa9 100644
--- a/compiler/coreSyn/CoreFVs.lhs
+++ b/compiler/coreSyn/CoreFVs.lhs
@@ -328,12 +328,11 @@ breaker, which is perfectly inlinable.
vectsFreeVars :: [CoreVect] -> VarSet
vectsFreeVars = foldr (unionVarSet . vectFreeVars) emptyVarSet
where
- vectFreeVars (Vect _ Nothing) = noFVs
- vectFreeVars (Vect _ (Just rhs)) = expr_fvs rhs isLocalId emptyVarSet
- vectFreeVars (NoVect _) = noFVs
- vectFreeVars (VectType _ _ _) = noFVs
- vectFreeVars (VectClass _) = noFVs
- vectFreeVars (VectInst _) = noFVs
+ vectFreeVars (Vect _ rhs) = expr_fvs rhs isLocalId emptyVarSet
+ vectFreeVars (NoVect _) = noFVs
+ vectFreeVars (VectType _ _ _) = noFVs
+ vectFreeVars (VectClass _) = noFVs
+ vectFreeVars (VectInst _) = noFVs
-- this function is only concerned with values, not types
\end{code}
diff --git a/compiler/coreSyn/CoreSubst.lhs b/compiler/coreSyn/CoreSubst.lhs
index a8de9c2b16..f92b63bfe0 100644
--- a/compiler/coreSyn/CoreSubst.lhs
+++ b/compiler/coreSyn/CoreSubst.lhs
@@ -749,12 +749,11 @@ substVects subst = map (substVect subst)
------------------
substVect :: Subst -> CoreVect -> CoreVect
-substVect _subst (Vect v Nothing) = Vect v Nothing
-substVect subst (Vect v (Just rhs)) = Vect v (Just (simpleOptExprWith subst rhs))
-substVect _subst vd@(NoVect _) = vd
-substVect _subst vd@(VectType _ _ _) = vd
-substVect _subst vd@(VectClass _) = vd
-substVect _subst vd@(VectInst _) = vd
+substVect subst (Vect v rhs) = Vect v (simpleOptExprWith subst rhs)
+substVect _subst vd@(NoVect _) = vd
+substVect _subst vd@(VectType _ _ _) = vd
+substVect _subst vd@(VectClass _) = vd
+substVect _subst vd@(VectInst _) = vd
------------------
substVarSet :: Subst -> VarSet -> VarSet
diff --git a/compiler/coreSyn/CoreSyn.lhs b/compiler/coreSyn/CoreSyn.lhs
index a84a29a6c0..5216bebd23 100644
--- a/compiler/coreSyn/CoreSyn.lhs
+++ b/compiler/coreSyn/CoreSyn.lhs
@@ -592,11 +592,11 @@ Representation of desugared vectorisation declarations that are fed to the vecto
'ModGuts').
\begin{code}
-data CoreVect = Vect Id (Maybe CoreExpr)
+data CoreVect = Vect Id CoreExpr
| NoVect Id
| VectType Bool TyCon (Maybe TyCon)
| VectClass TyCon -- class tycon
- | VectInst Id -- instance dfun (always SCALAR)
+ | VectInst Id -- instance dfun (always SCALAR) !!!FIXME: should be superfluous now
\end{code}
diff --git a/compiler/coreSyn/PprCore.lhs b/compiler/coreSyn/PprCore.lhs
index 3ca8c48855..a8ed4b875a 100644
--- a/compiler/coreSyn/PprCore.lhs
+++ b/compiler/coreSyn/PprCore.lhs
@@ -494,8 +494,7 @@ instance Outputable id => Outputable (Tickish id) where
\begin{code}
instance Outputable CoreVect where
- ppr (Vect var Nothing) = ptext (sLit "VECTORISE SCALAR") <+> ppr var
- ppr (Vect var (Just e)) = hang (ptext (sLit "VECTORISE") <+> ppr var <+> char '=')
+ ppr (Vect var e) = hang (ptext (sLit "VECTORISE") <+> ppr var <+> char '=')
4 (pprCoreExpr e)
ppr (NoVect var) = ptext (sLit "NOVECTORISE") <+> ppr var
ppr (VectType False var Nothing) = ptext (sLit "VECTORISE type") <+> ppr var
diff --git a/compiler/deSugar/Desugar.lhs b/compiler/deSugar/Desugar.lhs
index 28b0582076..78c95ceb88 100644
--- a/compiler/deSugar/Desugar.lhs
+++ b/compiler/deSugar/Desugar.lhs
@@ -432,7 +432,7 @@ the rule is precisly to optimise them:
dsVect :: LVectDecl Id -> DsM CoreVect
dsVect (L loc (HsVect (L _ v) rhs))
= putSrcSpanDs loc $
- do { rhs' <- fmapMaybeM dsLExpr rhs
+ do { rhs' <- dsLExpr rhs
; return $ Vect v rhs'
}
dsVect (L _loc (HsNoVect (L _ v)))
diff --git a/compiler/hsSyn/HsDecls.lhs b/compiler/hsSyn/HsDecls.lhs
index bac9ec6348..58a0aced14 100644
--- a/compiler/hsSyn/HsDecls.lhs
+++ b/compiler/hsSyn/HsDecls.lhs
@@ -1111,7 +1111,7 @@ type LVectDecl name = Located (VectDecl name)
data VectDecl name
= HsVect
(Located name)
- (Maybe (LHsExpr name)) -- 'Nothing' => SCALAR declaration
+ (LHsExpr name)
| HsNoVect
(Located name)
| HsVectTypeIn -- pre type-checking
@@ -1126,9 +1126,9 @@ data VectDecl name
(Located name)
| HsVectClassOut -- post type-checking
Class
- | HsVectInstIn -- pre type-checking (always SCALAR)
+ | HsVectInstIn -- pre type-checking (always SCALAR) !!!FIXME: should be superfluous now
(LHsType name)
- | HsVectInstOut -- post type-checking (always SCALAR)
+ | HsVectInstOut -- post type-checking (always SCALAR) !!!FIXME: should be superfluous now
ClsInst
deriving (Data, Typeable)
@@ -1148,9 +1148,7 @@ lvectInstDecl (L _ (HsVectInstOut _)) = True
lvectInstDecl _ = False
instance OutputableBndr name => Outputable (VectDecl name) where
- ppr (HsVect v Nothing)
- = sep [text "{-# VECTORISE SCALAR" <+> ppr v <+> text "#-}" ]
- ppr (HsVect v (Just rhs))
+ ppr (HsVect v rhs)
= sep [text "{-# VECTORISE" <+> ppr v,
nest 4 $
pprExpr (unLoc rhs) <+> text "#-}" ]
diff --git a/compiler/iface/LoadIface.lhs b/compiler/iface/LoadIface.lhs
index 6c5e7d38d9..f3dadbfc53 100644
--- a/compiler/iface/LoadIface.lhs
+++ b/compiler/iface/LoadIface.lhs
@@ -750,18 +750,18 @@ pprFixities fixes = ptext (sLit "fixities") <+> pprWithCommas pprFix fixes
pprFix (occ,fix) = ppr fix <+> ppr occ
pprVectInfo :: IfaceVectInfo -> SDoc
-pprVectInfo (IfaceVectInfo { ifaceVectInfoVar = vars
- , ifaceVectInfoTyCon = tycons
- , ifaceVectInfoTyConReuse = tyconsReuse
- , ifaceVectInfoScalarVars = scalarVars
- , ifaceVectInfoScalarTyCons = scalarTyCons
+pprVectInfo (IfaceVectInfo { ifaceVectInfoVar = vars
+ , ifaceVectInfoTyCon = tycons
+ , ifaceVectInfoTyConReuse = tyconsReuse
+ , ifaceVectInfoParallelVars = parallelVars
+ , ifaceVectInfoParallelTyCons = parallelTyCons
}) =
vcat
[ ptext (sLit "vectorised variables:") <+> hsep (map ppr vars)
, ptext (sLit "vectorised tycons:") <+> hsep (map ppr tycons)
, ptext (sLit "vectorised reused tycons:") <+> hsep (map ppr tyconsReuse)
- , ptext (sLit "scalar variables:") <+> hsep (map ppr scalarVars)
- , ptext (sLit "scalar tycons:") <+> hsep (map ppr scalarTyCons)
+ , ptext (sLit "parallel variables:") <+> hsep (map ppr parallelVars)
+ , ptext (sLit "parallel tycons:") <+> hsep (map ppr parallelTyCons)
]
pprTrustInfo :: IfaceTrustInfo -> SDoc
diff --git a/compiler/iface/MkIface.lhs b/compiler/iface/MkIface.lhs
index ce07b375b3..6aed1b2be4 100644
--- a/compiler/iface/MkIface.lhs
+++ b/compiler/iface/MkIface.lhs
@@ -373,17 +373,17 @@ mkIface_ hsc_env maybe_old_fingerprint
ifFamInstTcName = ifFamInstFam
- flattenVectInfo (VectInfo { vectInfoVar = vVar
- , vectInfoTyCon = vTyCon
- , vectInfoScalarVars = vScalarVars
- , vectInfoScalarTyCons = vScalarTyCons
+ flattenVectInfo (VectInfo { vectInfoVar = vVar
+ , vectInfoTyCon = vTyCon
+ , vectInfoParallelVars = vParallelVars
+ , vectInfoParallelTyCons = vParallelTyCons
}) =
IfaceVectInfo
- { ifaceVectInfoVar = [Var.varName v | (v, _ ) <- varEnvElts vVar]
- , ifaceVectInfoTyCon = [tyConName t | (t, t_v) <- nameEnvElts vTyCon, t /= t_v]
- , ifaceVectInfoTyConReuse = [tyConName t | (t, t_v) <- nameEnvElts vTyCon, t == t_v]
- , ifaceVectInfoScalarVars = [Var.varName v | v <- varSetElems vScalarVars]
- , ifaceVectInfoScalarTyCons = nameSetToList vScalarTyCons
+ { ifaceVectInfoVar = [Var.varName v | (v, _ ) <- varEnvElts vVar]
+ , ifaceVectInfoTyCon = [tyConName t | (t, t_v) <- nameEnvElts vTyCon, t /= t_v]
+ , ifaceVectInfoTyConReuse = [tyConName t | (t, t_v) <- nameEnvElts vTyCon, t == t_v]
+ , ifaceVectInfoParallelVars = [Var.varName v | v <- varSetElems vParallelVars]
+ , ifaceVectInfoParallelTyCons = nameSetToList vParallelTyCons
}
-----------------------------
diff --git a/compiler/iface/TcIface.lhs b/compiler/iface/TcIface.lhs
index 80c2029a70..8fdda45498 100644
--- a/compiler/iface/TcIface.lhs
+++ b/compiler/iface/TcIface.lhs
@@ -748,25 +748,25 @@ tcIfaceAnnTarget (ModuleTarget mod) = do
--
tcIfaceVectInfo :: Module -> TypeEnv -> IfaceVectInfo -> IfL VectInfo
tcIfaceVectInfo mod typeEnv (IfaceVectInfo
- { ifaceVectInfoVar = vars
- , ifaceVectInfoTyCon = tycons
- , ifaceVectInfoTyConReuse = tyconsReuse
- , ifaceVectInfoScalarVars = scalarVars
- , ifaceVectInfoScalarTyCons = scalarTyCons
+ { ifaceVectInfoVar = vars
+ , ifaceVectInfoTyCon = tycons
+ , ifaceVectInfoTyConReuse = tyconsReuse
+ , ifaceVectInfoParallelVars = parallelVars
+ , ifaceVectInfoParallelTyCons = parallelTyCons
})
- = do { let scalarTyConsSet = mkNameSet scalarTyCons
- ; vVars <- mapM vectVarMapping vars
+ = do { let parallelTyConsSet = mkNameSet parallelTyCons
+ ; vVars <- mapM vectVarMapping vars
; let varsSet = mkVarSet (map fst vVars)
- ; tyConRes1 <- mapM (vectTyConVectMapping varsSet) tycons
- ; tyConRes2 <- mapM (vectTyConReuseMapping varsSet) tyconsReuse
- ; vScalarVars <- mapM vectVar scalarVars
+ ; tyConRes1 <- mapM (vectTyConVectMapping varsSet) tycons
+ ; tyConRes2 <- mapM (vectTyConReuseMapping varsSet) tyconsReuse
+ ; vParallelVars <- mapM vectVar parallelVars
; let (vTyCons, vDataCons, vScSels) = unzip3 (tyConRes1 ++ tyConRes2)
; return $ VectInfo
- { vectInfoVar = mkVarEnv vVars `extendVarEnvList` concat vScSels
- , vectInfoTyCon = mkNameEnv vTyCons
- , vectInfoDataCon = mkNameEnv (concat vDataCons)
- , vectInfoScalarVars = mkVarSet vScalarVars
- , vectInfoScalarTyCons = scalarTyConsSet
+ { vectInfoVar = mkVarEnv vVars `extendVarEnvList` concat vScSels
+ , vectInfoTyCon = mkNameEnv vTyCons
+ , vectInfoDataCon = mkNameEnv (concat vDataCons)
+ , vectInfoParallelVars = mkVarSet vParallelVars
+ , vectInfoParallelTyCons = parallelTyConsSet
}
}
where
diff --git a/compiler/main/HscTypes.lhs b/compiler/main/HscTypes.lhs
index 343df00540..f32975fd4c 100644
--- a/compiler/main/HscTypes.lhs
+++ b/compiler/main/HscTypes.lhs
@@ -1968,11 +1968,11 @@ on just the OccName easily in a Core pass.
--
data VectInfo
= VectInfo
- { vectInfoVar :: VarEnv (Var , Var ) -- ^ @(f, f_v)@ keyed on @f@
- , vectInfoTyCon :: NameEnv (TyCon , TyCon) -- ^ @(T, T_v)@ keyed on @T@
- , vectInfoDataCon :: NameEnv (DataCon, DataCon) -- ^ @(C, C_v)@ keyed on @C@
- , vectInfoScalarVars :: VarSet -- ^ set of purely scalar variables
- , vectInfoScalarTyCons :: NameSet -- ^ set of scalar type constructors
+ { vectInfoVar :: VarEnv (Var , Var ) -- ^ @(f, f_v)@ keyed on @f@
+ , vectInfoTyCon :: NameEnv (TyCon , TyCon) -- ^ @(T, T_v)@ keyed on @T@
+ , vectInfoDataCon :: NameEnv (DataCon, DataCon) -- ^ @(C, C_v)@ keyed on @C@
+ , vectInfoParallelVars :: VarSet -- ^ set of parallel variables
+ , vectInfoParallelTyCons :: NameSet -- ^ set of parallel type constructors
}
-- |Vectorisation information for 'ModIface'; i.e, the vectorisation information propagated
@@ -1986,18 +1986,18 @@ data VectInfo
--
data IfaceVectInfo
= IfaceVectInfo
- { ifaceVectInfoVar :: [Name] -- ^ All variables in here have a vectorised variant
- , ifaceVectInfoTyCon :: [Name] -- ^ All 'TyCon's in here have a vectorised variant;
- -- the name of the vectorised variant and those of its
- -- data constructors are determined by
- -- 'OccName.mkVectTyConOcc' and
- -- 'OccName.mkVectDataConOcc'; the names of the
- -- isomorphisms are determined by 'OccName.mkVectIsoOcc'
- , ifaceVectInfoTyConReuse :: [Name] -- ^ The vectorised form of all the 'TyCon's in here
- -- coincides with the unconverted form; the name of the
- -- isomorphisms is determined by 'OccName.mkVectIsoOcc'
- , ifaceVectInfoScalarVars :: [Name] -- iface version of 'vectInfoScalarVar'
- , ifaceVectInfoScalarTyCons :: [Name] -- iface version of 'vectInfoScalarTyCon'
+ { ifaceVectInfoVar :: [Name] -- ^ All variables in here have a vectorised variant
+ , ifaceVectInfoTyCon :: [Name] -- ^ All 'TyCon's in here have a vectorised variant;
+ -- the name of the vectorised variant and those of its
+ -- data constructors are determined by
+ -- 'OccName.mkVectTyConOcc' and
+ -- 'OccName.mkVectDataConOcc'; the names of the
+ -- isomorphisms are determined by 'OccName.mkVectIsoOcc'
+ , ifaceVectInfoTyConReuse :: [Name] -- ^ The vectorised form of all the 'TyCon's in here
+ -- coincides with the unconverted form; the name of the
+ -- isomorphisms is determined by 'OccName.mkVectIsoOcc'
+ , ifaceVectInfoParallelVars :: [Name] -- iface version of 'vectInfoParallelVar'
+ , ifaceVectInfoParallelTyCons :: [Name] -- iface version of 'vectInfoParallelTyCon'
}
noVectInfo :: VectInfo
@@ -2006,11 +2006,11 @@ noVectInfo
plusVectInfo :: VectInfo -> VectInfo -> VectInfo
plusVectInfo vi1 vi2 =
- VectInfo (vectInfoVar vi1 `plusVarEnv` vectInfoVar vi2)
- (vectInfoTyCon vi1 `plusNameEnv` vectInfoTyCon vi2)
- (vectInfoDataCon vi1 `plusNameEnv` vectInfoDataCon vi2)
- (vectInfoScalarVars vi1 `unionVarSet` vectInfoScalarVars vi2)
- (vectInfoScalarTyCons vi1 `unionNameSets` vectInfoScalarTyCons vi2)
+ VectInfo (vectInfoVar vi1 `plusVarEnv` vectInfoVar vi2)
+ (vectInfoTyCon vi1 `plusNameEnv` vectInfoTyCon vi2)
+ (vectInfoDataCon vi1 `plusNameEnv` vectInfoDataCon vi2)
+ (vectInfoParallelVars vi1 `unionVarSet` vectInfoParallelVars vi2)
+ (vectInfoParallelTyCons vi1 `unionNameSets` vectInfoParallelTyCons vi2)
concatVectInfo :: [VectInfo] -> VectInfo
concatVectInfo = foldr plusVectInfo noVectInfo
@@ -2024,11 +2024,11 @@ isNoIfaceVectInfo (IfaceVectInfo l1 l2 l3 l4 l5)
instance Outputable VectInfo where
ppr info = vcat
- [ ptext (sLit "variables :") <+> ppr (vectInfoVar info)
- , ptext (sLit "tycons :") <+> ppr (vectInfoTyCon info)
- , ptext (sLit "datacons :") <+> ppr (vectInfoDataCon info)
- , ptext (sLit "scalar vars :") <+> ppr (vectInfoScalarVars info)
- , ptext (sLit "scalar tycons :") <+> ppr (vectInfoScalarTyCons info)
+ [ ptext (sLit "variables :") <+> ppr (vectInfoVar info)
+ , ptext (sLit "tycons :") <+> ppr (vectInfoTyCon info)
+ , ptext (sLit "datacons :") <+> ppr (vectInfoDataCon info)
+ , ptext (sLit "parallel vars :") <+> ppr (vectInfoParallelVars info)
+ , ptext (sLit "parallel tycons :") <+> ppr (vectInfoParallelTyCons info)
]
\end{code}
diff --git a/compiler/main/TidyPgm.lhs b/compiler/main/TidyPgm.lhs
index 85127e63f6..da732c687e 100644
--- a/compiler/main/TidyPgm.lhs
+++ b/compiler/main/TidyPgm.lhs
@@ -542,10 +542,10 @@ tidyInstances tidy_dfun ispecs
\begin{code}
tidyVectInfo :: TidyEnv -> VectInfo -> VectInfo
tidyVectInfo (_, var_env) info@(VectInfo { vectInfoVar = vars
- , vectInfoScalarVars = scalarVars
+ , vectInfoParallelVars = parallelVars
})
= info { vectInfoVar = tidy_vars
- , vectInfoScalarVars = tidy_scalarVars
+ , vectInfoParallelVars = tidy_parallelVars
}
where
-- we only export mappings whose domain and co-domain is exported (otherwise, the iface is
@@ -559,9 +559,9 @@ tidyVectInfo (_, var_env) info@(VectInfo { vectInfoVar = vars
, isDataConWorkId var || not (isImplicitId var)
]
- tidy_scalarVars = mkVarSet [ lookup_var var
- | var <- varSetElems scalarVars
- , isGlobalId var || isExportedId var]
+ tidy_parallelVars = mkVarSet [ lookup_var var
+ | var <- varSetElems parallelVars
+ , isGlobalId var || isExportedId var]
lookup_var var = lookupWithDefaultVarEnv var_env var var
\end{code}
diff --git a/compiler/parser/Parser.y.pp b/compiler/parser/Parser.y.pp
index 6c19812762..690b005d76 100644
--- a/compiler/parser/Parser.y.pp
+++ b/compiler/parser/Parser.y.pp
@@ -577,8 +577,7 @@ topdecl :: { OrdList (LHsDecl RdrName) }
| '{-# DEPRECATED' deprecations '#-}' { $2 }
| '{-# WARNING' warnings '#-}' { $2 }
| '{-# RULES' rules '#-}' { $2 }
- | '{-# VECTORISE_SCALAR' qvar '#-}' { unitOL $ LL $ VectD (HsVect $2 Nothing) }
- | '{-# VECTORISE' qvar '=' exp '#-}' { unitOL $ LL $ VectD (HsVect $2 (Just $4)) }
+ | '{-# VECTORISE' qvar '=' exp '#-}' { unitOL $ LL $ VectD (HsVect $2 $4) }
| '{-# NOVECTORISE' qvar '#-}' { unitOL $ LL $ VectD (HsNoVect $2) }
| '{-# VECTORISE' 'type' gtycon '#-}'
{ unitOL $ LL $
@@ -593,8 +592,6 @@ topdecl :: { OrdList (LHsDecl RdrName) }
{ unitOL $ LL $
VectD (HsVectTypeIn True $3 (Just $5)) }
| '{-# VECTORISE' 'class' gtycon '#-}' { unitOL $ LL $ VectD (HsVectClassIn $3) }
- | '{-# VECTORISE_SCALAR' 'instance' type '#-}'
- { unitOL $ LL $ VectD (HsVectInstIn $3) }
| annotation { unitOL $1 }
| decl { unLoc $1 }
diff --git a/compiler/rename/RnSource.lhs b/compiler/rename/RnSource.lhs
index 595f4653d3..0d897f3f0b 100644
--- a/compiler/rename/RnSource.lhs
+++ b/compiler/rename/RnSource.lhs
@@ -723,18 +723,14 @@ badRuleLhsErr name lhs bad_e
\begin{code}
rnHsVectDecl :: VectDecl RdrName -> RnM (VectDecl Name, FreeVars)
-rnHsVectDecl (HsVect var Nothing)
- = do { var' <- lookupLocatedOccRn var
- ; return (HsVect var' Nothing, unitFV (unLoc var'))
- }
-- FIXME: For the moment, the right-hand side is restricted to be a variable as we cannot properly
-- typecheck a complex right-hand side without invoking 'vectType' from the vectoriser.
-rnHsVectDecl (HsVect var (Just rhs@(L _ (HsVar _))))
+rnHsVectDecl (HsVect var rhs@(L _ (HsVar _)))
= do { var' <- lookupLocatedOccRn var
; (rhs', fv_rhs) <- rnLExpr rhs
- ; return (HsVect var' (Just rhs'), fv_rhs `addOneFV` unLoc var')
+ ; return (HsVect var' rhs', fv_rhs `addOneFV` unLoc var')
}
-rnHsVectDecl (HsVect _var (Just _rhs))
+rnHsVectDecl (HsVect _var _rhs)
= failWith $ vcat
[ ptext (sLit "IMPLEMENTATION RESTRICTION: right-hand side of a VECTORISE pragma")
, ptext (sLit "must be an identifier")
diff --git a/compiler/typecheck/TcBinds.lhs b/compiler/typecheck/TcBinds.lhs
index 5eb8e150ef..b30aecab5f 100644
--- a/compiler/typecheck/TcBinds.lhs
+++ b/compiler/typecheck/TcBinds.lhs
@@ -739,17 +739,12 @@ tcVect :: VectDecl Name -> TcM (VectDecl TcId)
-- during type checking. Instead, constrain the rhs of a vectorisation declaration to be a single
-- identifier (this is checked in 'rnHsVectDecl'). Fix this by enabling the use of 'vectType'
-- from the vectoriser here.
-tcVect (HsVect name Nothing)
- = addErrCtxt (vectCtxt name) $
- do { var <- wrapLocM tcLookupId name
- ; return $ HsVect var Nothing
- }
-tcVect (HsVect name (Just rhs))
+tcVect (HsVect name rhs)
= addErrCtxt (vectCtxt name) $
do { var <- wrapLocM tcLookupId name
; let L rhs_loc (HsVar rhs_var_name) = rhs
; rhs_id <- tcLookupId rhs_var_name
- ; return $ HsVect var (Just $ L rhs_loc (HsVar rhs_id))
+ ; return $ HsVect var (L rhs_loc (HsVar rhs_id))
}
{- OLD CODE:
diff --git a/compiler/typecheck/TcHsSyn.lhs b/compiler/typecheck/TcHsSyn.lhs
index d1a82b225d..9714b4e7fb 100644
--- a/compiler/typecheck/TcHsSyn.lhs
+++ b/compiler/typecheck/TcHsSyn.lhs
@@ -1081,7 +1081,7 @@ zonkVects env = mappM (wrapLocM (zonkVect env))
zonkVect :: ZonkEnv -> VectDecl TcId -> TcM (VectDecl Id)
zonkVect env (HsVect v e)
= do { v' <- wrapLocM (zonkIdBndr env) v
- ; e' <- fmapMaybeM (zonkLExpr env) e
+ ; e' <- zonkLExpr env e
; return $ HsVect v' e'
}
zonkVect env (HsNoVect v)
diff --git a/compiler/vectorise/Vectorise.hs b/compiler/vectorise/Vectorise.hs
index 8b7e817826..e6c4b1e0cf 100644
--- a/compiler/vectorise/Vectorise.hs
+++ b/compiler/vectorise/Vectorise.hs
@@ -13,26 +13,22 @@ import Vectorise.Type.Type
import Vectorise.Convert
import Vectorise.Utils.Hoisting
import Vectorise.Exp
-import Vectorise.Vect
import Vectorise.Env
import Vectorise.Monad
import HscTypes hiding ( MonadThings(..) )
import CoreUnfold ( mkInlineUnfolding )
-import CoreFVs
import PprCore
import CoreSyn
import CoreMonad ( CoreM, getHscEnv )
import Type
import Id
import DynFlags
-import BasicTypes ( isStrongLoopBreaker )
import Outputable
import Util ( zipLazy )
import MonadUtils
import Control.Monad
-import Data.Maybe
-- |Vectorise a single module.
@@ -69,7 +65,7 @@ vectModule guts@(ModGuts { mg_tcs = tycons
= do { dumpOptVt Opt_D_dump_vt_trace "Before vectorisation" $
pprCoreBindings binds
- -- Pick out all 'VECTORISE type' and 'VECTORISE class' pragmas
+ -- Pick out all 'VECTORISE [SCALAR] type' and 'VECTORISE class' pragmas
; let ty_vect_decls = [vd | vd@(VectType _ _ _) <- vect_decls]
cls_vect_decls = [vd | vd@(VectClass _) <- vect_decls]
@@ -87,8 +83,7 @@ vectModule guts@(ModGuts { mg_tcs = tycons
-- Vectorise all the top level bindings and VECTORISE declarations on imported identifiers
-- NB: Need to vectorise the imported bindings first (local bindings may depend on them).
- ; let impBinds = [imp_id | Vect imp_id _ <- vect_decls, isGlobalId imp_id] ++
- [imp_id | VectInst imp_id <- vect_decls, isGlobalId imp_id]
+ ; let impBinds = [(imp_id, expr) | Vect imp_id expr <- vect_decls, isGlobalId imp_id]
; binds_imp <- mapM vectImpBind impBinds
; binds_top <- mapM vectTopBind binds
@@ -101,7 +96,8 @@ vectModule guts@(ModGuts { mg_tcs = tycons
}
}
--- Try to vectorise a top-level binding. If it doesn't vectorise then return it unharmed.
+-- Try to vectorise a top-level binding. If it doesn't vectorise, or if it is entirely scalar, then
+-- omit vectorisation of that binding.
--
-- For example, for the binding
--
@@ -125,129 +121,173 @@ vectModule guts@(ModGuts { mg_tcs = tycons
-- lfoo = ...
-- @
--
--- @vfoo@ is the "vectorised", or scalar, version that does the same as the original
--- function foo, but takes an explicit environment.
+-- @vfoo@ is the "vectorised", or scalar, version that does the same as the original function foo,
+-- but takes an explicit environment.
--
-- @lfoo@ is the "lifted" version that works on arrays.
--
--- @v_foo@ combines both of these into a `Closure` that also contains the
--- environment.
+-- @v_foo@ combines both of these into a `Closure` that also contains the environment.
--
--- The original binding @foo@ is rewritten to call the vectorised version
--- present in the closure.
+-- The original binding @foo@ is rewritten to call the vectorised version present in the closure.
--
-- Vectorisation may be surpressed by annotating a binding with a 'NOVECTORISE' pragma. If this
-- pragma is used in a group of mutually recursive bindings, either all or no binding must have
--- the pragma. If only some bindings are annotated, a fatal error is being raised.
+-- the pragma. If only some bindings are annotated, a fatal error is being raised. (In the case of
+-- scalar bindings, we only omit vectorisation if all bindings in a group are scalar.)
+--
-- FIXME: Once we support partial vectorisation, we may be able to vectorise parts of a group, or
-- we may emit a warning and refrain from vectorising the entire group.
--
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
- = unlessNoVectDecl $
- do { -- Vectorise the right-hand side, create an appropriate top-level binding and add it
- -- to the vectorisation map.
- ; (inline, isScalar, expr') <- vectTopRhs [] var expr
- ; var' <- vectTopBinder var inline expr'
- ; when isScalar $
- addGlobalScalarVar var
-
- -- We replace the original top-level binding by a value projected from the vectorised
- -- closure and add any newly created hoisted top-level bindings.
- ; cexpr <- tryConvert var var' expr
- ; hs <- takeHoisted
- ; return . Rec $ (var, cexpr) : (var', expr') : hs
- }
- `orElseErrV`
- do { emitVt " Could NOT vectorise top-level binding" $ ppr var
- ; return b
+ = do
+ { traceVt "= Vectorise non-recursive top-level variable" (ppr var)
+
+ ; (hasNoVect, vectDecl) <- lookupVectDecl var
+ ; if hasNoVect
+ then do
+ { -- 'NOVECTORISE' pragma => leave this binding as it is
+ ; traceVt "NOVECTORISE" $ ppr var
+ ; return b
+ }
+ else do
+ { vectRhs <- case vectDecl of
+ Just (_, expr') ->
+ -- 'VECTORISE' pragma => just use the provided vectorised rhs
+ do
+ { traceVt "VECTORISE" $ ppr var
+ ; return $ Just (False, inlineMe, expr')
+ }
+ Nothing ->
+ -- no pragma => standard vectorisation of rhs
+ do
+ { traceVt "[Vanilla]" $ ppr var <+> char '=' <+> ppr expr
+ ; vectTopExpr var expr
+ }
+ ; hs <- takeHoisted -- make sure we clean those out (even if we skip)
+ ; case vectRhs of
+ { Nothing ->
+ -- scalar binding => leave this binding as it is
+ do
+ { traceVt "scalar binding [skip]" $ ppr var
+ ; return b
+ }
+ ; Just (parBind, inline, expr') -> do
+ {
+ -- vanilla case => create an appropriate top-level binding & add it to the vectorisation map
+ ; when parBind $
+ addGlobalParallelVar var
+ ; var' <- vectTopBinder var inline expr'
+
+ -- We replace the original top-level binding by a value projected from the vectorised
+ -- closure and add any newly created hoisted top-level bindings.
+ ; cexpr <- tryConvert var var' expr
+ ; return . Rec $ (var, cexpr) : (var', expr') : hs
+ } } } }
+ `orElseErrV`
+ do
+ { emitVt " Could NOT vectorise top-level binding" $ ppr var
+ ; return b
+ }
+vectTopBind b@(Rec binds)
+ = do
+ { traceVt "= Vectorise recursive top-level variables" $ ppr vars
+
+ ; vectDecls <- mapM lookupVectDecl vars
+ ; let hasNoVects = map fst vectDecls
+ ; if and hasNoVects
+ then do
+ { -- 'NOVECTORISE' pragmas => leave this entire binding group as it is
+ ; traceVt "NOVECTORISE" $ ppr vars
+ ; return b
+ }
+ else do
+ { if or hasNoVects
+ then do
+ { -- Inconsistent 'NOVECTORISE' pragmas => bail out
+ ; dflags <- getDynFlags
+ ; cantVectorise dflags noVectoriseErr (ppr b)
}
- where
- unlessNoVectDecl vectorise
- = do { hasNoVectDecl <- noVectDecl var
- ; when hasNoVectDecl $
- traceVt "NOVECTORISE" $ ppr var
- ; if hasNoVectDecl then return b else vectorise
- }
-vectTopBind b@(Rec bs)
- = unlessSomeNoVectDecl $
- do { (vars', _, exprs', hs) <- fixV $
- \ ~(_, inlines, rhss, _) ->
- do { -- Vectorise the right-hand sides, create an appropriate top-level bindings
- -- and add them to the vectorisation map.
- ; vars' <- sequence [vectTopBinder var inline rhs
- | (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
- ; (inlines, areScalars, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs
- ; hs <- takeHoisted
- ; if and areScalars
- then -- (1) Entire recursive group is scalar
- -- => add all variables to the global set of scalars
- do { mapM_ addGlobalScalarVar vars
- ; return (vars', inlines, exprs', hs)
- }
- else -- (2) At least one binding is not scalar
- -- => vectorise again with empty set of local scalars
- do { (inlines, _, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs
- ; hs <- takeHoisted
- ; return (vars', inlines, exprs', hs)
- }
- }
-
- -- Replace the original top-level bindings by a values projected from the vectorised
- -- closures and add any newly created hoisted top-level bindings to the group.
- ; cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
- ; return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
- }
- `orElseErrV`
- return b
- where
- (vars, exprs) = unzip bs
+ else do
+ { -- For all bindings *with* a pragma, just use the pragma-supplied vectorised expression
+ ; newBindsWPragma <- concat <$>
+ sequence [ vectTopBindAndConvert bind inlineMe expr'
+ | (bind, (_, Just (_, expr'))) <- zip binds vectDecls]
+
+ -- Standard vectorisation of all rhses that are *without* a pragma.
+ -- NB: The reason for 'fixV' is rather subtle: 'vectTopBindAndConvert' adds entries for
+ -- the bound variables in the recursive group to the vectorisation map, which in turn
+ -- are needed by 'vectPolyExprs' (unless it returns 'Nothing').
+ ; let bindsWOPragma = [bind | (bind, (_, Nothing)) <- zip binds vectDecls]
+ ; (newBinds, _) <- fixV $
+ \ ~(_, exprs') ->
+ do
+ { -- Create appropriate top-level bindings, enter them into the vectorisation map, and
+ -- vectorise the right-hand sides
+ ; newBindsWOPragma <- concat <$>
+ sequence [vectTopBindAndConvert bind inline expr
+ | (bind, ~(inline, expr)) <- zipLazy bindsWOPragma exprs']
+ -- irrefutable pattern and 'zipLazy' to tie the knot;
+ -- hence, can't use 'zipWithM'
+ ; vectRhses <- vectTopExprs bindsWOPragma
+ ; hs <- takeHoisted -- make sure we clean those out (even if we skip)
- unlessSomeNoVectDecl vectorise
- = do { hasNoVectDecls <- mapM noVectDecl vars
- ; when (and hasNoVectDecls) $
- traceVt "NOVECTORISE" $ ppr vars
- ; if and hasNoVectDecls
- then return b -- all bindings have 'NOVECTORISE'
- else if or hasNoVectDecls
- then do dflags <- getDynFlags
- cantVectorise dflags noVectoriseErr (ppr b) -- some (but not all) have 'NOVECTORISE'
- else vectorise -- no binding has a 'NOVECTORISE' decl
- }
+ ; case vectRhses of
+ Nothing ->
+ -- scalar bindings => skip all bindings except those with pragmas and retract the
+ -- entries into the vectorisation map for the scalar bindings
+ do
+ { traceVt "scalar bindings [skip]" $ ppr vars
+ ; mapM_ (undefGlobalVar . fst) bindsWOPragma
+ ; return (bindsWOPragma ++ newBindsWPragma, exprs')
+ }
+ Just (parBind, exprs') ->
+ -- vanilla case => record parallel variables and return the final bindings
+ do
+ { when parBind $
+ mapM_ addGlobalParallelVar vars
+ ; return (newBindsWOPragma ++ newBindsWPragma ++ hs, exprs')
+ }
+ }
+ ; return $ Rec newBinds
+ } } }
+ `orElseErrV`
+ do
+ { emitVt " Could NOT vectorise top-level bindings" $ ppr vars
+ ; return b
+ }
+ where
+ vars = map fst binds
noVectoriseErr = "NOVECTORISE must be used on all or no bindings of a recursive group"
+
+ -- Replace the original top-level bindings by a values projected from the vectorised
+ -- closures and add any newly created hoisted top-level bindings to the group.
+ vectTopBindAndConvert (var, expr) inline expr'
+ = do
+ { var' <- vectTopBinder var inline expr'
+ ; cexpr <- tryConvert var var' expr
+ ; return [(var, cexpr), (var', expr')]
+ }
--- Add a vectorised binding to an imported top-level variable that has a VECTORISE [SCALAR] pragma
+-- Add a vectorised binding to an imported top-level variable that has a VECTORISE pragma
-- in this module.
--
--- RESTIRCTION: Currently, we cannot use the pragma vor mutually recursive definitions.
+-- RESTIRCTION: Currently, we cannot use the pragma for mutually recursive definitions.
--
-vectImpBind :: Id -> VM CoreBind
-vectImpBind var
- = do { -- Vectorise the right-hand side, create an appropriate top-level binding and add it
- -- to the vectorisation map. For the non-lifted version, we refer to the original
- -- definition — i.e., 'Var var'.
- -- NB: To support recursive definitions, we tie a lazy knot.
- ; (var', _, expr') <- fixV $
- \ ~(_, inline, rhs) ->
- do { var' <- vectTopBinder var inline rhs
- ; (inline, isScalar, expr') <- vectTopRhs [] var (Var var)
-
- ; when isScalar $
- addGlobalScalarVar var
- ; return (var', inline, expr')
- }
+vectImpBind :: (Id, CoreExpr) -> VM CoreBind
+vectImpBind (var, expr)
+ = do
+ { traceVt "= Add vectorised binding to imported variable" (ppr var)
- -- We add any newly created hoisted top-level bindings.
- ; hs <- takeHoisted
- ; return . Rec $ (var', expr') : hs
- }
-
--- | Make the vectorised version of this top level binder, and add the mapping
--- between it and the original to the state. For some binder @foo@ the vectorised
--- version is @$v_foo@
+ ; var' <- vectTopBinder var inlineMe expr
+ ; return $ NonRec var' expr
+ }
+
+-- |Make the vectorised version of this top level binder, and add the mapping between it and the
+-- original to the state. For some binder @foo@ the vectorised version is @$v_foo@
--
--- NOTE: 'vectTopBinder' *MUST* be lazy in inline and expr because of how it is
--- used inside of 'fixV' in 'vectTopBind'.
+-- NOTE: 'vectTopBinder' *MUST* be lazy in inline and expr because of how it is used inside of
+-- 'fixV' in 'vectTopBind'.
--
vectTopBinder :: Var -- ^ Name of the binding.
-> Inline -- ^ Whether it should be inlined, used to annotate it.
@@ -257,20 +297,20 @@ vectTopBinder var inline expr
= do { -- Vectorise the type attached to the var.
; vty <- vectType (idType var)
- -- If there is a vectorisation declartion for this binding, make sure that its type
- -- matches
- ; vectDecl <- lookupVectDecl var
+ -- If there is a vectorisation declartion for this binding, make sure its type matches
+ ; (_, vectDecl) <- lookupVectDecl var
; case vectDecl of
Nothing -> return ()
Just (vdty, _)
| eqType vty vdty -> return ()
| otherwise ->
- do dflags <- getDynFlags
- cantVectorise dflags ("Type mismatch in vectorisation pragma for " ++ showPpr dflags var) $
- (text "Expected type" <+> ppr vty)
- $$
- (text "Inferred type" <+> ppr vdty)
-
+ do
+ { dflags <- getDynFlags
+ ; cantVectorise dflags ("Type mismatch in vectorisation pragma for " ++ showPpr dflags var) $
+ (text "Expected type" <+> ppr vty)
+ $$
+ (text "Inferred type" <+> ppr vdty)
+ }
-- Make the vectorised version of binding's name, and set the unfolding used for inlining
; var' <- liftM (`setIdUnfoldingLazily` unfolding)
$ mkVectId var vty
@@ -297,113 +337,17 @@ vectTopBinder var inline expr
`setInlinePragma` dfunInlinePragma
-}
--- | Vectorise the RHS of a top-level binding, in an empty local environment.
+-- |Project out the vectorised version of a binding from some closure, or return the original body
+-- if that doesn't work.
--
--- We need to distinguish four cases:
---
--- (1) We have a (non-scalar) vectorisation declaration for the variable (which explicitly provides
--- vectorised code implemented by the user)
--- => no automatic vectorisation & instead use the user-supplied code
---
--- (2) We have a scalar vectorisation declaration for a variable that is no dfun
--- => generate vectorised code that uses a scalar 'map'/'zipWith' to lift the computation
---
--- (3) We have a scalar vectorisation declaration for a variable that *is* a dfun
--- => generate vectorised code according to the the "Note [Scalar dfuns]" below
---
--- (4) There is no vectorisation declaration for the variable
--- => perform automatic vectorisation of the RHS (the definition may or may not be a dfun;
--- vectorisation proceeds differently depending on which it is)
---
--- Note [Scalar dfuns]
--- ~~~~~~~~~~~~~~~~~~~
---
--- Here is the translation scheme for scalar dfuns — assume the instance declaration:
---
--- instance Num Int where
--- (+) = primAdd
--- {-# VECTORISE SCALAR instance Num Int #-}
---
--- It desugars to
---
--- $dNumInt :: Num Int
--- $dNumInt = D:Num primAdd
---
--- We vectorise it to
---
--- $v$dNumInt :: V:Num Int
--- $v$dNumInt = D:V:Num (closure2 ((+) $dNumInt) (scalar_zipWith ((+) $dNumInt))))
---
--- while adding the following entry to the vectorisation map: '$dNumInt' --> '$v$dNumInt'.
---
--- See "Note [Vectorising classes]" in 'Vectorise.Type.Env' for the definition of 'V:Num'.
---
--- NB: The outlined vectorisation scheme does not require the right-hand side of the original dfun.
--- In fact, we definitely want to refer to the dfn variable instead of the right-hand side to
--- ensure that the dictionary selection rules fire.
---
-vectTopRhs :: [Var] -- ^ Names of all functions in the rec block
- -> Var -- ^ Name of the binding.
- -> CoreExpr -- ^ Body of the binding.
- -> VM ( Inline -- (1) inline specification for the binding
- , Bool -- (2) whether the right-hand side is a scalar computation
- , CoreExpr) -- (3) the vectorised right-hand side
-vectTopRhs recFs var expr
- = closedV
- $ do { globalScalar <- isGlobalScalarVar var
- ; vectDecl <- lookupVectDecl var
- ; dflags <- getDynFlags
- ; let isDFun = isDFunId var
-
- ; traceVt ("vectTopRhs of " ++ showPpr dflags var ++ info globalScalar isDFun vectDecl ++ ":") $
- ppr expr
-
- ; rhs globalScalar isDFun vectDecl
- }
- where
- rhs _globalScalar _isDFun (Just (_, expr')) -- Case (1)
- = return (inlineMe, False, expr')
- rhs True False Nothing -- Case (2)
- = do { expr' <- vectScalarFun expr
- ; return (inlineMe, True, vectorised expr')
- }
- rhs True True Nothing -- Case (3)
- = do { expr' <- vectScalarDFun var
- ; return (DontInline, True, expr')
- }
- rhs False False Nothing -- Case (4) — not a dfun
- = do { let exprFvs = freeVars expr
- ; (inline, isScalar, vexpr)
- <- inBind var $
- vectPolyExpr (isStrongLoopBreaker $ idOccInfo var) recFs exprFvs Nothing
- ; return (inline, isScalar, vectorised vexpr)
- }
- rhs False True Nothing -- Case (4) — is a dfun
- = do { expr' <- vectDictExpr expr
- ; return (DontInline, True, expr')
- }
-
- info True False _ = " [VECTORISE SCALAR]"
- info True True _ = " [VECTORISE SCALAR instance]"
- info False _ vectDecl | isJust vectDecl = " [VECTORISE]"
- | otherwise = " (no pragma)"
-
--- |Project out the vectorised version of a binding from some closure,
--- or return the original body if that doesn't work or the binding is scalar.
---
-tryConvert :: Var -- ^ Name of the original binding (eg @foo@)
- -> Var -- ^ Name of vectorised version of binding (eg @$vfoo@)
- -> CoreExpr -- ^ The original body of the binding.
+tryConvert :: Var -- ^Name of the original binding (eg @foo@)
+ -> Var -- ^Name of vectorised version of binding (eg @$vfoo@)
+ -> CoreExpr -- ^The original body of the binding.
-> VM CoreExpr
tryConvert var vect_var rhs
- = do { globalScalar <- isGlobalScalarVar var
- ; if globalScalar
- then
- return rhs
- else
- fromVect (idType var) (Var vect_var)
- `orElseErrV`
- do { emitVt " Could NOT call vectorised from original version" $ ppr var
- ; return rhs
- }
- }
+ = fromVect (idType var) (Var vect_var)
+ `orElseErrV`
+ do
+ { emitVt " Could NOT call vectorised from original version" $ ppr var
+ ; return rhs
+ }
diff --git a/compiler/vectorise/Vectorise/Convert.hs b/compiler/vectorise/Vectorise/Convert.hs
index 048362d59c..f21f5cac86 100644
--- a/compiler/vectorise/Vectorise/Convert.hs
+++ b/compiler/vectorise/Vectorise/Convert.hs
@@ -84,16 +84,16 @@ identityConv (AppTy {}) = noV $ text "identityConv: type appl. changes under
identityConv (FunTy {}) = noV $ text "identityConv: function type changes under vectorisation"
identityConv (ForAllTy {}) = noV $ text "identityConv: quantified type changes under vectorisation"
--- |Check that this type constructor is neutral under type vectorisation — i.e., it is not altered
--- by vectorisation as they contain no parallel arrays.
+-- |Check that this type constructor is not changed by vectorisation — i.e., it does not embed any
+-- parallel arrays.
--
identityConvTyCon :: TyCon -> VM ()
identityConvTyCon tc
- | isBoxedTupleTyCon tc = return ()
- | isUnLiftedTyCon tc = return ()
- | otherwise
- = do tc' <- maybeV notVectErr (lookupTyCon tc)
- if tc == tc' then return () else noV idErr
+ = do
+ { tc' <- lookupTyCon tc
+ ; case tc' of
+ Nothing -> return ()
+ Just _ -> noV idErr
+ }
where
- notVectErr = text "identityConvTyCon: no vectorised version for type constructor" <+> ppr tc
- idErr = text "identityConvTyCon: type constructor contains parallel arrays" <+> ppr tc
+ idErr = text "identityConvTyCon: type constructor contains parallel arrays" <+> ppr tc
diff --git a/compiler/vectorise/Vectorise/Env.hs b/compiler/vectorise/Vectorise/Env.hs
index d58ec8f800..345b4ba1c3 100644
--- a/compiler/vectorise/Vectorise/Env.hs
+++ b/compiler/vectorise/Vectorise/Env.hs
@@ -31,7 +31,7 @@ import Name
import NameEnv
import FastString
import TysPrim
-import TysWiredIn
+--import TysWiredIn
import Data.Maybe
@@ -60,7 +60,8 @@ data LocalEnv
-- ^Mapping from tyvars to their PA dictionaries.
, local_bind_name :: FastString
- -- ^Local binding name.
+ -- ^Local binding name. This is only used to generate better names for hoisted
+ -- expressions.
}
-- |Create an empty local environment.
@@ -84,35 +85,34 @@ data GlobalEnv
-- ^Mapping from global variables to their vectorised versions — aka the /vectorisation
-- map/.
- , global_vect_decls :: VarEnv (Type, CoreExpr)
- -- ^Mapping from global variables that have a vectorisation declaration to the right-hand
- -- side of that declaration and its type. This mapping only applies to non-scalar
- -- vectorisation declarations. All variables with a scalar vectorisation declaration are
- -- mentioned in 'global_scalars_vars'.
-
- , global_scalar_vars :: VarSet
- -- ^Purely scalar variables. Code which mentions only these variables doesn't have to be
- -- lifted. This includes variables from the current module that have a scalar
- -- vectorisation declaration and those that the vectoriser determines to be scalar.
-
- , global_scalar_tycons :: NameSet
- -- ^Type constructors whose values can only contain scalar data. This includes type
- -- constructors that appear in a 'VECTORISE SCALAR type' pragma or 'VECTORISE type' pragma
- -- *without* a right-hand side in the current or an imported module as well as type
- -- constructors that are automatically identified as scalar by the vectoriser (in
- -- 'Vectorise.Type.Env'). Scalar code may only operate on such data.
+ , global_parallel_vars :: VarSet
+ -- ^The domain of 'global_vars'.
--
- -- NB: Not all type constructors in that set are members of the 'Scalar' type class
- -- (which can be trivially marshalled across scalar code boundaries).
-
- , global_novect_vars :: VarSet
- -- ^Variables that are not vectorised. (They may be referenced in the right-hand sides
- -- of vectorisation declarations, though.)
+ -- This information is not redundant as it is impossible to extract the domain from a
+ -- 'VarEnv' (which is keyed on uniques alone). Moreover, we have mapped variables that
+ -- do not involve parallelism — e.g., the workers of vectorised, but scalar data types.
+ -- In addition, workers of parallel data types that we could not vectorise also need to
+ -- be tracked.
+
+ , global_vect_decls :: VarEnv (Maybe (Type, CoreExpr))
+ -- ^Mapping from global variables that have a vectorisation declaration to the right-hand
+ -- side of that declaration and its type and mapping variables that have NOVECTORISE
+ -- declarations to 'Nothing'.
, global_tycons :: NameEnv TyCon
- -- ^Mapping from TyCons to their vectorised versions.
- -- TyCons which do not have to be vectorised are mapped to themselves.
+ -- ^Mapping from TyCons to their vectorised versions. The vectorised version will be
+ -- identical to the original version if it is not changed by vectorisation. In any case,
+ -- if a tycon appears in the domain of this mapping, it was successfully vectorised.
+ , global_parallel_tycons :: NameSet
+ -- ^Type constructors whose definition directly or indirectly includes a parallel type,
+ -- such as '[::]'.
+ --
+ -- NB: This information is not redundant as some types have got a mapping in
+ -- 'global_tycons' (to a type other than themselves) and are still not parallel. An
+ -- example is '(->)'. Moreover, some types have *not* got a mapping in 'global_tycons'
+ -- (because they couldn't be vectorised), but still contain parallel types.
+
, global_datacons :: NameEnv DataCon
-- ^Mapping from DataCons to their vectorised versions.
@@ -129,7 +129,7 @@ data GlobalEnv
-- ^External package inst-env & home-package inst-env for family instances.
, global_bindings :: [(Var, CoreExpr)]
- -- ^Hoisted bindings.
+ -- ^Hoisted bindings — temporary storage for toplevel bindings during code gen.
}
-- |Create an initial global environment.
@@ -143,9 +143,8 @@ initGlobalEnv info vectDecls instEnvs famInstEnvs
= GlobalEnv
{ global_vars = mapVarEnv snd $ vectInfoVar info
, global_vect_decls = mkVarEnv vects
- , global_scalar_vars = vectInfoScalarVars info `extendVarSetList` scalar_vars
- , global_scalar_tycons = vectInfoScalarTyCons info `addListToNameSet` scalar_tycons
- , global_novect_vars = mkVarSet novects
+ , global_parallel_vars = vectInfoParallelVars info
+ , global_parallel_tycons = vectInfoParallelTyCons info
, global_tycons = mapNameEnv snd $ vectInfoTyCon info
, global_datacons = mapNameEnv snd $ vectInfoDataCon info
, global_pa_funs = emptyNameEnv
@@ -155,23 +154,12 @@ initGlobalEnv info vectDecls instEnvs famInstEnvs
, global_bindings = []
}
where
- vects = [(var, (ty, exp)) | Vect var (Just exp@(Var rhs_var)) <- vectDecls
- , let ty = varType rhs_var]
+ vects = [(var, Just (ty, exp)) | Vect var exp@(Var rhs_var) <- vectDecls
+ , let ty = varType rhs_var] ++
-- FIXME: we currently only allow RHSes consisting of a
-- single variable to be able to obtain the type without
-- inference — see also 'TcBinds.tcVect'
- scalar_vars = [var | Vect var Nothing <- vectDecls] ++
- [var | VectInst var <- vectDecls] ++
- [dataConWrapId doubleDataCon, dataConWrapId floatDataCon, dataConWrapId intDataCon] -- TODO: fix this hack
- novects = [var | NoVect var <- vectDecls]
- scalar_tycons = [tyConName tycon | VectType True tycon Nothing <- vectDecls] ++
- [tyConName tycon | VectType _ tycon (Just tycon') <- vectDecls
- , tycon == tycon'] ++
- map tyConName [doublePrimTyCon, intPrimTyCon, floatPrimTyCon] -- TODO: fix this hack
- -- - for 'VectType True tycon Nothing', we checked that the type does not
- -- contain arrays (or type variables that could be instatiated to arrays)
- -- - for 'VectType _ tycon (Just tycon')', where the two tycons are the same,
- -- we also know that there can be no embedded arrays
+ [(var, Nothing) | NoVect var <- vectDecls]
-- Operators on Global Environments -------------------------------------------
@@ -210,11 +198,11 @@ setPRFunsEnv ps genv = genv { global_pr_funs = mkNameEnv ps }
modVectInfo :: GlobalEnv -> [Id] -> [TyCon] -> [CoreVect]-> VectInfo -> VectInfo
modVectInfo env mg_ids mg_tyCons vectDecls info
= info
- { vectInfoVar = mk_env ids (global_vars env)
- , vectInfoTyCon = mk_env tyCons (global_tycons env)
- , vectInfoDataCon = mk_env dataCons (global_datacons env)
- , vectInfoScalarVars = global_scalar_vars env `minusVarSet` vectInfoScalarVars info
- , vectInfoScalarTyCons = global_scalar_tycons env `minusNameSet` vectInfoScalarTyCons info
+ { vectInfoVar = mk_env ids (global_vars env)
+ , vectInfoTyCon = mk_env tyCons (global_tycons env)
+ , vectInfoDataCon = mk_env dataCons (global_datacons env)
+ , vectInfoParallelVars = global_parallel_vars env `minusVarSet` vectInfoParallelVars info
+ , vectInfoParallelTyCons = global_parallel_tycons env `minusNameSet` vectInfoParallelTyCons info
}
where
vectIds = [id | Vect id _ <- vectDecls] ++
diff --git a/compiler/vectorise/Vectorise/Exp.hs b/compiler/vectorise/Vectorise/Exp.hs
index 8c5ef0045d..88f123210b 100644
--- a/compiler/vectorise/Vectorise/Exp.hs
+++ b/compiler/vectorise/Vectorise/Exp.hs
@@ -3,10 +3,9 @@
-- |Vectorisation of expressions.
module Vectorise.Exp
- ( -- * Vectorise polymorphic expressions with special cases for right-hand sides of particular
- -- variable bindings
- vectPolyExpr
- , vectDictExpr
+ ( -- * Vectorise right-hand sides of toplevel bindings
+ vectTopExpr
+ , vectTopExprs
, vectScalarFun
, vectScalarDFun
)
@@ -32,393 +31,404 @@ import DataCon
import TyCon
import TcType
import Type
-import PrelNames
+import TypeRep
import Var
import VarEnv
import VarSet
+import NameSet
import Id
import BasicTypes( isStrongLoopBreaker )
import Literal
-import TysWiredIn
import TysPrim
import Outputable
import FastString
+import DynFlags
+import Util
+import MonadUtils
+
import Control.Monad
-import Control.Applicative
import Data.Maybe
import Data.List
-import TcRnMonad (doptM)
-import DynFlags
-import Util
-- Main entry point to vectorise expressions -----------------------------------
--- |Vectorise a polymorphic expression.
+-- |Vectorise a polymorphic expression that forms a *non-recursive* binding.
+--
+-- Return 'Nothing' if the expression is scalar; otherwise, the first component of the result
+-- (which is of type 'Bool') indicates whether the expression is parallel (i.e., whether it is
+-- tagged as 'VIParr').
--
--- If not yet available, precompute vectorisation avoidance information before vectorising. If
--- the vectorisation avoidance optimisation is enabled, also use the vectorisation avoidance
--- information to encapsulated subexpression that do not need to be vectorised.
+-- We have got the non-recursive case as a special case as it doesn't require to compute
+-- vectorisation information twice.
--
-vectPolyExpr :: Bool -> [Var] -> CoreExprWithFVs -> Maybe VITree
- -> VM (Inline, Bool, VExpr)
- -- precompute vectorisation avoidance information (and possibly encapsulated subexpressions)
-vectPolyExpr loop_breaker recFns expr Nothing
+vectTopExpr :: Var -> CoreExpr -> VM (Maybe (Bool, Inline, CoreExpr))
+vectTopExpr var expr
= do
- { vectAvoidance <- liftDs $ doptM Opt_AvoidVect
- ; vi <- vectAvoidInfo expr
- ; (expr', vi') <-
- if vectAvoidance
- then do
- { (expr', vi') <- encapsulateScalars vi expr
- ; traceVt "vectPolyExpr encapsulated:" (ppr $ deAnnotate expr')
- ; return (expr', vi')
- }
- else return (expr, vi)
- ; vectPolyExpr loop_breaker recFns expr' (Just vi')
+ { exprVI <- encapsulateScalars <=< vectAvoidInfo emptyVarSet . freeVars $ expr
+ ; if isVIEncaps exprVI
+ then
+ return Nothing
+ else do
+ { vExpr <- closedV $
+ inBind var $
+ vectAnnPolyExpr False exprVI
+ ; inline <- computeInline exprVI
+ ; return $ Just (isVIParr exprVI, inline, vectorised vExpr)
+ }
}
- -- traverse through ticks
-vectPolyExpr loop_breaker recFns (_, AnnTick tickish expr) (Just (VITNode _ [vit]))
- = do
- { (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr (Just vit)
- ; return (inline, isScalarFn, vTick tickish expr')
- }
+-- Compute the inlining hint for the right-hand side of a top-level binding.
+--
+computeInline :: CoreExprWithVectInfo -> VM Inline
+computeInline ((_, VIDict), _) = return $ DontInline
+computeInline (_, AnnTick _ expr) = computeInline expr
+computeInline expr@(_, AnnLam _ _) = Inline <$> polyArity tvs
+ where
+ (tvs, _) = collectAnnTypeBinders expr
+computeInline _expr = return $ DontInline
- -- collect and vectorise type abstractions; then, descent into the body
-vectPolyExpr loop_breaker recFns expr (Just vit)
- = do
- { let (tvs, mono) = collectAnnTypeBinders expr
- vit' = stripLevels (length tvs) vit
- ; arity <- polyArity tvs
- ; polyAbstract tvs $ \args ->
- do
- { (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono vit'
- ; return (addInlineArity inline arity, isScalarFn, mapVect (mkLams $ tvs ++ args) mono')
- }
+-- |Vectorise a recursive group of top-level polymorphic expressions.
+--
+-- Return 'Nothing' if the expression group is scalar; otherwise, the first component of the result
+-- (which is of type 'Bool') indicates whether the expressions are parallel (i.e., whether they are
+-- tagged as 'VIParr').
+--
+vectTopExprs :: [(Var, CoreExpr)] -> VM (Maybe (Bool, [(Inline, CoreExpr)]))
+vectTopExprs binds
+ = do
+ { exprVIs <- mapM (vectAvoidAndEncapsulate emptyVarSet) exprs
+ ; if all isVIEncaps exprVIs
+ then
+ return Nothing
+ else do
+ { (areVIParr, vExprs) <- unzip <$> mapM encapsulateAndVect binds
+ ; return $ Just (or areVIParr, vExprs)
+ }
}
where
- stripLevels 0 vit = vit
- stripLevels n (VITNode _ [vit]) = stripLevels (n - 1) vit
- stripLevels _ vit = pprPanic "vectPolyExpr: stripLevels:" (text (show vit))
+ (vars, exprs) = unzip binds
+
+ vectAvoidAndEncapsulate pvs = encapsulateScalars <=< vectAvoidInfo pvs . freeVars
+
+ encapsulateAndVect (var, expr)
+ = do
+ { exprVI <- vectAvoidAndEncapsulate (mkVarSet vars) expr
+ ; vExpr <- closedV $
+ inBind var $
+ vectAnnPolyExpr (isStrongLoopBreaker $ idOccInfo var) exprVI
+ ; inline <- computeInline exprVI
+ ; return (isVIParr exprVI, (inline, vectorised vExpr))
+ }
+
+-- |Vectorise a polymorphic expression annotated with vectorisation information.
+--
+-- The special case of dictionary functions is currently handled separately. (Would be neater to
+-- integrate them, though!)
+--
+vectAnnPolyExpr :: Bool -> CoreExprWithVectInfo -> VM VExpr
+vectAnnPolyExpr loop_breaker (_, AnnTick tickish expr)
+ -- traverse through ticks
+ = vTick tickish <$> vectAnnPolyExpr loop_breaker expr
+vectAnnPolyExpr loop_breaker expr
+ | isVIDict expr
+ -- special case the right-hand side of dictionary functions
+ = (, undefined) <$> vectDictExpr (deAnnotate expr)
+ | otherwise
+ -- collect and vectorise type abstractions; then, descent into the body
+ = polyAbstract tvs $ \args ->
+ mapVect (mkLams $ tvs ++ args) <$> vectFnExpr False loop_breaker mono
+ where
+ (tvs, mono) = collectAnnTypeBinders expr
-- Encapsulate every purely sequential subexpression of a (potentially) parallel expression into a
--- into a lambda abstraction over all its free variables followed by the corresponding application
--- to those variables. We can, then, avoid the vectorisation of the ensapsulated subexpressions.
+-- lambda abstraction over all its free variables followed by the corresponding application to those
+-- variables. We can, then, avoid the vectorisation of the ensapsulated subexpressions.
--
-- Preconditions:
--
-- * All free variables and the result type must be /simple/ types.
--- * The expression is sufficientlt complex (top warrant special treatment). For now, that is
+-- * The expression is sufficiently complex (to warrant special treatment). For now, that is
-- every expression that is not constant and contains at least one operation.
--
-encapsulateScalars :: VITree -> CoreExprWithFVs -> VM (CoreExprWithFVs, VITree)
-encapsulateScalars vit ce@(_, AnnType _ty)
- = return (ce, vit)
-
-encapsulateScalars vit ce@(_, AnnVar _v)
- = return (ce, vit)
-
-encapsulateScalars vit ce@(_, AnnLit _)
- = return (ce, vit)
-
-encapsulateScalars (VITNode vi [vit]) (fvs, AnnTick tck expr)
- = do { (extExpr, vit') <- encapsulateScalars vit expr
- ; return ((fvs, AnnTick tck extExpr), VITNode vi [vit'])
- }
-
-encapsulateScalars _ (_fvs, AnnTick _tck _expr)
- = panic "encapsulateScalar AnnTick doesn't match up"
-
-encapsulateScalars (VITNode vi [vit]) ce@(fvs, AnnLam bndr expr)
- = do { varsS <- varsSimple fvs
- ; case (vi, varsS) of
- (VISimple, True) -> do { let (e', vit') = liftSimple vit ce
- ; return (e', vit')
- }
- _ -> do { (extExpr, vit') <- encapsulateScalars vit expr
- ; return ((fvs, AnnLam bndr extExpr), VITNode vi [vit'])
- }
- }
-
-encapsulateScalars _ (_fvs, AnnLam _bndr _expr)
- = panic "encapsulateScalars AnnLam doesn't match up"
-
-encapsulateScalars vt@(VITNode vi [vit1, vit2]) ce@(fvs, AnnApp ce1 ce2)
- = do { varsS <- varsSimple fvs
- ; case (vi, varsS) of
- (VISimple, True) -> do { let (e', vt') = liftSimple vt ce
- -- ; checkTreeAnnM vt' e'
- -- ; traceVt "Passed checkTree test!!" (ppr $ deAnnotate e')
- ; return (e', vt')
- }
- _ -> do { (etaCe1, vit1') <- encapsulateScalars vit1 ce1
- ; (etaCe2, vit2') <- encapsulateScalars vit2 ce2
- ; return ((fvs, AnnApp etaCe1 etaCe2), VITNode vi [vit1', vit2'])
- }
- }
-
-encapsulateScalars _ (_fvs, AnnApp _ce1 _ce2)
- = panic "encapsulateScalars AnnApp doesn't match up"
-
-encapsulateScalars vt@(VITNode vi (scrutVit : altVits)) ce@(fvs, AnnCase scrut bndr ty alts)
- = do { varsS <- varsSimple fvs
- ; case (vi, varsS) of
- (VISimple, True) -> return $ liftSimple vt ce
- _ -> do { (extScrut, scrutVit') <- encapsulateScalars scrutVit scrut
- ; extAltsVits <- zipWithM expAlt altVits alts
- ; let (extAlts, altVits') = unzip extAltsVits
- ; return ((fvs, AnnCase extScrut bndr ty extAlts), VITNode vi (scrutVit': altVits'))
- }
- }
+encapsulateScalars :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
+encapsulateScalars ce@(_, AnnType _ty)
+ = return ce
+encapsulateScalars ce@((_, VISimple), AnnVar v)
+ | isFunTy . varType $ v -- NB: diverts from the paper: encapsulate scalar function types
+ = liftSimpleAndCase ce
+encapsulateScalars ce@(_, AnnVar _v)
+ = return ce
+encapsulateScalars ce@(_, AnnLit _)
+ = return ce
+encapsulateScalars ((fvs, vi), AnnTick tck expr)
+ = do
+ { encExpr <- encapsulateScalars expr
+ ; return ((fvs, vi), AnnTick tck encExpr)
+ }
+encapsulateScalars ce@((fvs, vi), AnnLam bndr expr)
+ = do
+ { varsS <- allScalarVarTypeSet fvs
+ ; case (vi, varsS) of
+ (VISimple, True) -> liftSimpleAndCase ce
+ _ -> do
+ { encExpr <- encapsulateScalars expr
+ ; return ((fvs, vi), AnnLam bndr encExpr)
+ }
+ }
+encapsulateScalars ce@((fvs, vi), AnnApp ce1 ce2)
+ = do
+ { varsS <- allScalarVarTypeSet fvs
+ ; case (vi, varsS) of
+ (VISimple, True) -> liftSimpleAndCase ce
+ _ -> do
+ { encCe1 <- encapsulateScalars ce1
+ ; encCe2 <- encapsulateScalars ce2
+ ; return ((fvs, vi), AnnApp encCe1 encCe2)
+ }
+ }
+encapsulateScalars ce@((fvs, vi), AnnCase scrut bndr ty alts)
+ = do
+ { varsS <- allScalarVarTypeSet fvs
+ ; case (vi, varsS) of
+ (VISimple, True) -> liftSimpleAndCase ce
+ _ -> do
+ { encScrut <- encapsulateScalars scrut
+ ; encAlts <- mapM encAlt alts
+ ; return ((fvs, vi), AnnCase encScrut bndr ty encAlts)
+ }
+ }
where
- expAlt vt (con, bndrs, expr)
- = do { (extExpr, vt') <- encapsulateScalars vt expr
- ; return ((con, bndrs, extExpr), vt')
- }
-
-encapsulateScalars _ (_fvs, AnnCase _scrut _bndr _ty _alts)
- = panic "encapsulateScalars AnnCase doesn't match up"
-
-encapsulateScalars vt@(VITNode vi [vt1, vt2]) ce@(fvs, AnnLet (AnnNonRec bndr expr1) expr2)
- = do { varsS <- varsSimple fvs
- ; case (vi, varsS) of
- (VISimple, True) -> return $ liftSimple vt ce
- _ -> do { (extExpr1, vt1') <- encapsulateScalars vt1 expr1
- ; (extExpr2, vt2') <- encapsulateScalars vt2 expr2
- ; return ((fvs, AnnLet (AnnNonRec bndr extExpr1) extExpr2), VITNode vi [vt1', vt2'])
- }
- }
-
-encapsulateScalars _ (_fvs, AnnLet (AnnNonRec _bndr _expr1) _expr2)
- = panic "encapsulateScalars AnnLet nonrec doesn't match up"
-
-encapsulateScalars vt@(VITNode vi (vtB : vtBnds)) ce@(fvs, AnnLet (AnnRec bndngs) expr)
- = do { varsS <- varsSimple fvs
- ; case (vi, varsS) of
- (VISimple, True) -> return $ liftSimple vt ce
- _ -> do { extBndsVts <- zipWithM expBndg vtBnds bndngs
- ; let (extBnds, vtBnds') = unzip extBndsVts
- ; (extExpr, vtB') <- encapsulateScalars vtB expr
- ; let vt' = VITNode vi (vtB':vtBnds')
- ; return ((fvs, AnnLet (AnnRec extBnds) extExpr), vt')
- }
- }
- where
- expBndg vit (bndr, expr)
- = do { (extExpr, vit') <- encapsulateScalars vit expr
- ; return ((bndr, extExpr), vit')
- }
-
-encapsulateScalars _ (_fvs, AnnLet (AnnRec _) _expr2)
- = panic "encapsulateScalars AnnLet rec doesn't match up"
-
-encapsulateScalars (VITNode vi [vit]) (fvs, AnnCast expr coercion)
- = do { (extExpr, vit') <- encapsulateScalars vit expr
- ; return ((fvs, AnnCast extExpr coercion), VITNode vi [vit'])
- }
-
-encapsulateScalars _ (_fvs, AnnCast _expr _coercion)
- = panic "encapsulateScalars AnnCast rec doesn't match up"
-
-encapsulateScalars _ _
- = panic "encapsulateScalars case not handled"
+ encAlt (con, bndrs, expr) = (con, bndrs,) <$> encapsulateScalars expr
+encapsulateScalars ce@((fvs, vi), AnnLet (AnnNonRec bndr expr1) expr2)
+ = do
+ { varsS <- allScalarVarTypeSet fvs
+ ; case (vi, varsS) of
+ (VISimple, True) -> liftSimpleAndCase ce
+ _ -> do
+ { encExpr1 <- encapsulateScalars expr1
+ ; encExpr2 <- encapsulateScalars expr2
+ ; return ((fvs, vi), AnnLet (AnnNonRec bndr encExpr1) encExpr2)
+ }
+ }
+encapsulateScalars ce@((fvs, vi), AnnLet (AnnRec binds) expr)
+ = do
+ { varsS <- allScalarVarTypeSet fvs
+ ; case (vi, varsS) of
+ (VISimple, True) -> liftSimpleAndCase ce
+ _ -> do
+ { encBinds <- mapM encBind binds
+ ; encExpr <- encapsulateScalars expr
+ ; return ((fvs, vi), AnnLet (AnnRec encBinds) encExpr)
+ }
+ }
+ where
+ encBind (bndr, expr) = (bndr,) <$> encapsulateScalars expr
+encapsulateScalars ((fvs, vi), AnnCast expr coercion)
+ = do
+ { encExpr <- encapsulateScalars expr
+ ; return ((fvs, vi), AnnCast encExpr coercion)
+ }
+encapsulateScalars _
+ = panic "Vectorise.Exp.encapsulateScalars: unknown constructor"
--- Lambda-lift the given expression and apply it to the abstracted free variables.
+-- Lambda-lift the given simple expression and apply it to the abstracted free variables.
--
--- If the expression is a case expression scrutinising anything but a primitive type, then lift
+-- If the expression is a case expression scrutinising anything, but a scalar type, then lift
-- each alternative individually.
--
-liftSimple :: VITree -> CoreExprWithFVs -> (CoreExprWithFVs, VITree)
-liftSimple (VITNode vi (scrutVit : altVits)) (fvs, AnnCase expr bndr t alts)
- | Just (c,_) <- splitTyConApp_maybe (exprType $ deAnnotate $ expr),
- (not $ elem c [boolTyCon, intTyCon, doubleTyCon, floatTyCon]) -- FIXME: shouldn't be hardcoded
- = ((fvs, AnnCase expr bndr t alts'), VITNode vi (scrutVit : altVits'))
- where
- (alts', altVits') = unzip $ map (\(ac,bndrs, (alt, avi)) -> ((ac,bndrs,alt), avi)) $
- zipWith (\(ac, bndrs, aex) -> \altVi -> (ac, bndrs, liftSimple altVi aex)) alts altVits
-
-liftSimple viTree ae@(fvs, _annEx)
- = (mkAnnApps (mkAnnLams ae vars) vars, viTree')
+liftSimpleAndCase :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
+liftSimpleAndCase aexpr@((fvs, _vi), AnnCase expr bndr t alts)
+ = do
+ { vi <- vectAvoidInfoTypeOf expr
+ ; if (vi == VISimple)
+ then
+ return $ liftSimple aexpr -- if the scrutinee is scalar, we need no special treatment
+ else do
+ { alts' <- mapM (\(ac, bndrs, aexpr) -> (ac, bndrs,) <$> liftSimpleAndCase aexpr) alts
+ ; return ((fvs, vi), AnnCase expr bndr t alts')
+ }
+ }
+liftSimpleAndCase aexpr = return $ liftSimple aexpr
+
+liftSimple :: CoreExprWithVectInfo -> CoreExprWithVectInfo
+liftSimple ((fvs, vi), expr)
+ = ASSERT(vi == VISimple)
+ mkAnnApps (mkAnnLams vars fvs expr) vars
where
- mkViTreeLams (VITNode _ vits) [] = VITNode VIEncaps vits
- mkViTreeLams vi (_:vs) = VITNode VIEncaps [mkViTreeLams vi vs]
+ vars = varSetElems fvs
- mkViTreeApps vi [] = vi
- mkViTreeApps vi (_:vs) = VITNode VISimple [mkViTreeApps vi vs, VITNode VISimple []]
-
- vars = varSetElems fvs
- viTree' = mkViTreeApps (mkViTreeLams viTree vars) vars
-
- mkAnnLam :: bndr -> AnnExpr bndr VarSet -> AnnExpr' bndr VarSet
- mkAnnLam bndr ce = AnnLam bndr ce
-
- mkAnnLams:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs
- mkAnnLams (fv, aex') [] = (fv, aex') -- fv should be empty. check!
- mkAnnLams (fv, aex') (v:vs) = mkAnnLams (delVarSet fv v, (mkAnnLam v ((delVarSet fv v), aex'))) vs
-
- mkAnnApp :: (AnnExpr bndr VarSet) -> Var -> (AnnExpr' bndr VarSet)
- mkAnnApp aex v = AnnApp aex (unitVarSet v, (AnnVar v))
+ mkAnnLams :: [Var] -> VarSet -> AnnExpr' Var (VarSet, VectAvoidInfo) -> CoreExprWithVectInfo
+ mkAnnLams [] fvs expr = ASSERT(isEmptyVarSet fvs)
+ ((emptyVarSet, VIEncaps), expr)
+ mkAnnLams (v:vs) fvs expr = mkAnnLams vs (fvs `delVarSet` v) (AnnLam v ((fvs, VIEncaps), expr))
- mkAnnApps:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs
- mkAnnApps (fv, aex') [] = (fv, aex')
- mkAnnApps ae (v:vs) =
- let
- (fv, aex') = mkAnnApps ae vs
- in (extendVarSet fv v, mkAnnApp (fv, aex') v)
+ mkAnnApps :: CoreExprWithVectInfo -> [Var] -> CoreExprWithVectInfo
+ mkAnnApps aexpr [] = aexpr
+ mkAnnApps aexpr (v:vs) = mkAnnApps (mkAnnApp aexpr v) vs
+
+ mkAnnApp :: CoreExprWithVectInfo -> Var -> CoreExprWithVectInfo
+ mkAnnApp aexpr@((fvs, _vi), _expr) v
+ = ((fvs `extendVarSet` v, VISimple), AnnApp aexpr ((unitVarSet v, VISimple), AnnVar v))
-- |Vectorise an expression.
--
-vectExpr :: CoreExprWithFVs -> VITree -> VM VExpr
--- vectExpr e vi | not (checkTree vi (deAnnotate e))
--- = pprPanic "vectExpr" (ppr $ deAnnotate e)
-
-vectExpr (_, AnnVar v) _
+vectExpr :: CoreExprWithVectInfo -> VM VExpr
+
+vectExpr (_, AnnVar v)
= vectVar v
-vectExpr (_, AnnLit lit) _
+vectExpr (_, AnnLit lit)
= vectConst $ Lit lit
-vectExpr e@(_, AnnLam bndr _) vt
- | isId bndr = (\(_, _, ve) -> ve) <$> vectFnExpr True False [] e vt
- | otherwise = do dflags <- getDynFlags
- cantVectorise dflags "Unexpected type lambda (vectExpr)" (ppr (deAnnotate e))
+vectExpr e@(_, AnnLam bndr _)
+ | isId bndr = vectFnExpr True False e
+ | otherwise
+ = do
+ { dflags <- getDynFlags
+ ; cantVectorise dflags "Unexpected type lambda (vectExpr)" $ ppr (deAnnotate e)
+ }
-- SPECIAL CASE: Vectorise/lift 'patError @ ty err' by only vectorising/lifting the type 'ty';
-- its only purpose is to abort the program, but we need to adjust the type to keep CoreLint
-- happy.
-- FIXME: can't be do this with a VECTORISE pragma on 'pAT_ERROR_ID' now?
-vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err) _
+vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err)
| v == pAT_ERROR_ID
- = do { (vty, lty) <- vectAndLiftType ty
- ; return (mkCoreApps (Var v) [Type vty, err'], mkCoreApps (Var v) [Type lty, err'])
- }
+ = do
+ { (vty, lty) <- vectAndLiftType ty
+ ; return (mkCoreApps (Var v) [Type vty, err'], mkCoreApps (Var v) [Type lty, err'])
+ }
where
err' = deAnnotate err
-- type application (handle multiple consecutive type applications simultaneously to ensure the
-- PA dictionaries are put at the right places)
-vectExpr e@(_, AnnApp _ arg) (VITNode _ [_, _])
+vectExpr e@(_, AnnApp _ arg)
| isAnnTypeArg arg
= vectPolyApp e
-
- -- 'Int', 'Float', or 'Double' literal
- -- FIXME: this needs to be generalised
-vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit)) _
- | Just con <- isDataConId_maybe v
- , is_special_con con
+
+ -- Lifted literal
+vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
+ | Just _con <- isDataConId_maybe v
= do
- let vexpr = App (Var v) (Lit lit)
- lexpr <- liftPD vexpr
- return (vexpr, lexpr)
- where
- is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]
+ { let vexpr = App (Var v) (Lit lit)
+ ; lexpr <- liftPD vexpr
+ ; return (vexpr, lexpr)
+ }
-- value application (dictionary or user value)
-vectExpr e@(_, AnnApp fn arg) (VITNode _ [vit1, vit2])
+vectExpr e@(_, AnnApp fn arg)
| isPredTy arg_ty -- dictionary application (whose result is not a dictionary)
= vectPolyApp e
| otherwise -- user value
- = do { -- vectorise the types
- ; varg_ty <- vectType arg_ty
- ; vres_ty <- vectType res_ty
+ = do
+ { -- vectorise the types
+ ; varg_ty <- vectType arg_ty
+ ; vres_ty <- vectType res_ty
- -- vectorise the function and argument expression
- ; vfn <- vectExpr fn vit1
- ; varg <- vectExpr arg vit2
+ -- vectorise the function and argument expression
+ ; vfn <- vectExpr fn
+ ; varg <- vectExpr arg
- -- the vectorised function is a closure; apply it to the vectorised argument
- ; mkClosureApp varg_ty vres_ty vfn varg
- }
+ -- the vectorised function is a closure; apply it to the vectorised argument
+ ; mkClosureApp varg_ty vres_ty vfn varg
+ }
where
(arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn
-vectExpr (_, AnnCase scrut bndr ty alts) vt
+vectExpr (_, AnnCase scrut bndr ty alts)
| Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
, isAlgTyCon tycon
- = vectAlgCase tycon ty_args scrut bndr ty alts vt
- | otherwise = do dflags <- getDynFlags
- cantVectorise dflags "Can't vectorise expression" (ppr scrut_ty)
+ = vectAlgCase tycon ty_args scrut bndr ty alts
+ | otherwise
+ = do
+ { dflags <- getDynFlags
+ ; cantVectorise dflags "Can't vectorise expression (no algebraic type constructor)" $
+ ppr scrut_ty
+ }
where
scrut_ty = exprType (deAnnotate scrut)
-vectExpr (_, AnnLet (AnnNonRec bndr rhs) body) (VITNode _ [vt1, vt2])
+vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
= do
- vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs (Just vt1)
- (vbndr, vbody) <- vectBndrIn bndr (vectExpr body vt2)
- return $ vLet (vNonRec vbndr vrhs) vbody
+ { vrhs <- localV $
+ inBind bndr $
+ vectAnnPolyExpr False rhs
+ ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
+ ; return $ vLet (vNonRec vbndr vrhs) vbody
+ }
-vectExpr (_, AnnLet (AnnRec bs) body) (VITNode _ (vtB : vtBnds))
+vectExpr (_, AnnLet (AnnRec bs) body)
= do
- (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
+ { (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
$ liftM2 (,)
- (zipWith3M vect_rhs bndrs rhss vtBnds)
- (vectExpr body vtB)
- return $ vLet (vRec vbndrs vrhss) vbody
+ (zipWithM vect_rhs bndrs rhss)
+ (vectExpr body)
+ ; return $ vLet (vRec vbndrs vrhss) vbody
+ }
where
(bndrs, rhss) = unzip bs
- vect_rhs bndr rhs vt = localV
- . inBind bndr
- . liftM (\(_,_,z)->z)
- $ vectPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) [] rhs (Just vt)
- zipWith3M f xs ys zs = zipWithM (\x -> \(y,z) -> (f x y z)) xs (zip ys zs)
+ vect_rhs bndr rhs = localV $
+ inBind bndr $
+ vectAnnPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) rhs
-vectExpr (_, AnnTick tickish expr) (VITNode _ [vit])
- = liftM (vTick tickish) (vectExpr expr vit)
+vectExpr (_, AnnTick tickish expr)
+ = vTick tickish <$> vectExpr expr
-vectExpr (_, AnnType ty) _
- = liftM vType (vectType ty)
+vectExpr (_, AnnType ty)
+ = vType <$> vectType ty
-vectExpr e vit = do dflags <- getDynFlags
- cantVectorise dflags "Can't vectorise expression (vectExpr)" (ppr (deAnnotate e) $$ text (" " ++ show vit))
+vectExpr e
+ = do
+ { dflags <- getDynFlags
+ ; cantVectorise dflags "Can't vectorise expression (vectExpr)" $ ppr (deAnnotate e)
+ }
--- |Vectorise an expression that *may* have an outer lambda abstraction.
+-- |Vectorise an expression that *may* have an outer lambda abstraction. If the expression is marked
+-- as encapsulated ('VIEncaps'), vectorise it as a scalar computation (using a generalised scalar
+-- zip).
--
-- We do not handle type variables at this point, as they will already have been stripped off by
--- 'vectPolyExpr'. We also only have to worry about one set of dictionary arguments as we (1) only
+-- 'vectPolyExpr'. We also only have to worry about one set of dictionary arguments as we (1) only
-- deal with Haskell 2011 and (2) class selectors are vectorised elsewhere.
--
-vectFnExpr :: Bool -- ^ If we process the RHS of a binding, whether that binding should
- -- be inlined
- -> Bool -- ^ Whether the binding is a loop breaker
- -> [Var] -- ^ Names of function in same recursive binding group
- -> CoreExprWithFVs -- ^ Expression to vectorise; must have an outer `AnnLam`
- -> VITree
- -> VM (Inline, Bool, VExpr)
--- vectFnExpr _ _ _ e vi | not (checkTree vi (deAnnotate e))
--- = pprPanic "vectFnExpr" (ppr $ deAnnotate e)
-vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr body) vt@(VITNode _ [vt'])
- -- predicate abstraction: leave as a normal abstraction, but vectorise the predicate type
+vectFnExpr :: Bool -- ^If we process the RHS of a binding, whether that binding
+ -- should be inlined
+ -> Bool -- ^Whether the binding is a loop breaker
+ -> CoreExprWithVectInfo -- ^Expression to vectorise; must have an outer `AnnLam`
+ -> VM VExpr
+vectFnExpr inline loop_breaker expr@(_ann, AnnLam bndr body)
+ -- predicate abstraction: leave as a normal abstraction, but vectorise the predicate type
| isId bndr
&& isPredTy (idType bndr)
- = do { vBndr <- vectBndr bndr
- ; (inline, isScalarFn, vbody) <- vectFnExpr inline loop_breaker recFns body vt'
- ; return (inline, isScalarFn, mapVect (mkLams [vectorised vBndr]) vbody)
- }
- -- non-predicate abstraction: vectorise (try to vectorise as a scalar computation)
+ = do
+ { vBndr <- vectBndr bndr
+ ; vbody <- vectFnExpr inline loop_breaker body
+ ; return $ mapVect (mkLams [vectorised vBndr]) vbody
+ }
+ -- non-predicate abstraction: vectorise as a scalar computation
+ | isId bndr && isVIEncaps expr
+ = vectScalarFun . deAnnotate $ expr
+ -- non-predicate abstraction: vectorise as a non-scalar computation
| isId bndr
- = mark DontInline True (vectScalarFunMaybe (deAnnotate expr) vt)
- `orElseV`
- mark inlineMe False (vectLam inline loop_breaker expr vt)
-vectFnExpr _ _ _ e vt
- -- not an abstraction: vectorise as a vanilla expression
- = mark DontInline False $ vectExpr e vt
-
-mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
-mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
+ = vectLam inline loop_breaker expr
+vectFnExpr _ _ expr
+ -- not an abstraction: vectorise as a vanilla expression
+ = vectExpr expr
-- |Vectorise type and dictionary applications.
--
-- These are always headed by a variable (as we don't support higher-rank polymorphism), but may
--- involve two sets of type variables and dictionaries. Consider,
+-- involve two sets of type variables and dictionaries. Consider,
--
-- > class C a where
-- > m :: D b => b -> a
--
-- The type of 'm' is 'm :: forall a. C a => forall b. D b => b -> a'.
--
-vectPolyApp :: CoreExprWithFVs -> VM VExpr
+vectPolyApp :: CoreExprWithVectInfo -> VM VExpr
vectPolyApp e0
= case e4 of
(_, AnnVar var)
@@ -530,21 +540,6 @@ vectDictExpr (Coercion coe)
-- instead they become dictionaries of vectorised methods). We treat them differently, though see
-- "Note [Scalar dfuns]" in 'Vectorise'.
--
-vectScalarFunMaybe :: CoreExpr -- ^ Expression to be vectorised
- -> VITree -- ^ Vectorisation information
- -> VM VExpr
-vectScalarFunMaybe expr (VITNode VIEncaps _) = vectScalarFun expr
-vectScalarFunMaybe _expr _ = noV $ ptext (sLit "not a scalar function")
-
--- |Vectorise an expression of functional type by lifting it by an application of a member of the
--- zipWith family (i.e., 'map', 'zipWith', zipWith3', etc.) This is only a valid strategy if the
--- function does not contain parallel subcomputations and has only 'Scalar' types in its result and
--- arguments — this is a predcondition for calling this function.
---
--- Dictionary functions are also scalar functions (as dictionaries themselves are not vectorised,
--- instead they become dictionaries of vectorised methods). We treat them differently, though see
--- "Note [Scalar dfuns]" in 'Vectorise'.
---
vectScalarFun :: CoreExpr -> VM VExpr
vectScalarFun expr
= do
@@ -673,12 +668,11 @@ unVectDict ty e
-- variables are passed explicit (as conventional arguments) into the body during closure
-- construction.
--
-vectLam :: Bool -- ^ When the RHS of a binding, whether that binding should be inlined.
- -> Bool -- ^ Whether the binding is a loop breaker.
- -> CoreExprWithFVs -- ^ Body of abstraction.
- -> VITree
+vectLam :: Bool -- ^ Should the RHS of a binding be inlined?
+ -> Bool -- ^ Whether the binding is a loop breaker.
+ -> CoreExprWithVectInfo -- ^ Body of abstraction.
-> VM VExpr
-vectLam inline loop_breaker expr@(fvs, AnnLam _ _) vi
+vectLam inline loop_breaker expr@((fvs, _vi), AnnLam _ _)
= do { let (bndrs, body) = collectAnnValBinders expr
-- grab the in-scope type variables
@@ -706,18 +700,13 @@ vectLam inline loop_breaker expr@(fvs, AnnLam _ _) vi
. hoistPolyVExpr tyvars vfvs_dict' (maybe_inline arity)
$ do { -- generate the vectorised body of the lambda abstraction
; lc <- builtin liftingContext
- ; let viBody = stripLams expr vi
- -- ; checkTreeAnnM vi expr
- ; (vbndrs, vbody) <- vectBndrsIn (fvs_nondict ++ bndrs) (vectExpr body viBody)
+ ; (vbndrs, vbody) <- vectBndrsIn (fvs_nondict ++ bndrs) $ vectExpr body
; vbody' <- break_loop lc res_ty vbody
; return $ vLams lc vbndrs vbody'
}
}
where
- stripLams (_, AnnLam _ e) (VITNode _ [vt]) = stripLams e vt
- stripLams _ vi = vi
-
maybe_inline n | inline = Inline n
| otherwise = DontInline
@@ -735,7 +724,7 @@ vectLam inline loop_breaker expr@(fvs, AnnLam _ _) vi
(LitAlt (mkMachInt 0), [], empty)])
}
| otherwise = return (ve, le)
-vectLam _ _ _ _ = panic "vectLam"
+vectLam _ _ _ = panic "Vectorise.Exp.vectLam: not a lambda"
-- Vectorise an algebraic case expression.
--
@@ -754,31 +743,31 @@ vectLam _ _ _ _ = panic "vectLam"
--
-- FIXME: this is too lazy
-vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs-> Var -> Type
- -> [(AltCon, [Var], CoreExprWithFVs)] -> VITree
+vectAlgCase :: TyCon -> [Type] -> CoreExprWithVectInfo -> Var -> Type
+ -> [(AltCon, [Var], CoreExprWithVectInfo)]
-> VM VExpr
-vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)] (VITNode _ (scrutVit : [altVit]))
+vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
= do
- vscrut <- vectExpr scrut scrutVit
+ vscrut <- vectExpr scrut
(vty, lty) <- vectAndLiftType ty
- (vbndr, vbody) <- vectBndrIn bndr (vectExpr body altVit)
+ (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
return $ vCaseDEFAULT vscrut vbndr vty lty vbody
-vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)] (VITNode _ (scrutVit : [altVit]))
+vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
= do
- vscrut <- vectExpr scrut scrutVit
+ vscrut <- vectExpr scrut
(vty, lty) <- vectAndLiftType ty
- (vbndr, vbody) <- vectBndrIn bndr (vectExpr body altVit)
+ (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
return $ vCaseDEFAULT vscrut vbndr vty lty vbody
-vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)] (VITNode _ (scrutVit : [altVit]))
+vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
= do
(vty, lty) <- vectAndLiftType ty
- vexpr <- vectExpr scrut scrutVit
+ vexpr <- vectExpr scrut
(vbndr, (vbndrs, (vect_body, lift_body)))
<- vect_scrut_bndr
. vectBndrsIn bndrs
- $ vectExpr body altVit
+ $ vectExpr body
let (vect_bndrs, lift_bndrs) = unzip vbndrs
(vscrut, lscrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
vect_dc <- maybeV dataConErr (lookupDataCon dc)
@@ -796,9 +785,9 @@ vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)] (VITNode _
dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)
-vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits))
+vectAlgCase tycon _ty_args scrut bndr ty alts
= do
- vect_tc <- maybeV tyConErr (lookupTyCon tycon)
+ vect_tc <- vectTyCon tycon
(vty, lty) <- vectAndLiftType ty
let arity = length (tyConDataCons vect_tc)
@@ -807,10 +796,10 @@ vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits))
let sel = Var sel_bndr
(vbndr, valts) <- vect_scrut_bndr
- $ mapM (proc_alt arity sel vty lty) (zip alts' altVits)
+ $ mapM (proc_alt arity sel vty lty) alts'
let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
- vexpr <- vectExpr scrut scrutVit
+ vexpr <- vectExpr scrut
(vect_scrut, lift_scrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
let (vect_bodies, lift_bodies) = unzip vbodies
@@ -829,8 +818,6 @@ vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits))
return . vLet (vNonRec vbndr vexpr)
$ (vect_case, lift_case)
where
- tyConErr = (text "vectAlgCase: type constructor not vectorised" <+> ppr tycon)
-
vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
| otherwise = vectBndrIn bndr
@@ -842,12 +829,12 @@ vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits))
cmp _ DEFAULT = GT
cmp _ _ = panic "vectAlgCase/cmp"
- proc_alt arity sel _ lty ((DataAlt dc, bndrs, body), vi)
+ proc_alt arity sel _ lty (DataAlt dc, bndrs, body@((fvs_body, _), _))
= do
vect_dc <- maybeV dataConErr (lookupDataCon dc)
let ntag = dataConTagZ vect_dc
tag = mkDataConTag vect_dc
- fvs = freeVarsOf body `delVarSetList` bndrs
+ fvs = fvs_body `delVarSetList` bndrs
sel_tags <- liftM (`App` sel) (builtin (selTags arity))
lc <- builtin liftingContext
@@ -860,7 +847,7 @@ vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits))
binds <- mapM (pack_var (Var lc) sel_tags tag)
. filter isLocalId
$ varSetElems fvs
- (ve, le) <- vectExpr body vi
+ (ve, le) <- vectExpr body
return (ve, Case (elems `App` sel) lc lty
[(DEFAULT, [], (mkLets (concat binds) le))])
-- empty <- emptyPD vty
@@ -892,9 +879,6 @@ vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits))
_ -> return []
-vectAlgCase tycon _ty_args _scrut _bndr _ty _alts (VITNode _ _)
- = pprPanic "vectAlgCase (mismatched node information)" (ppr tycon)
-
-- Support to compute information for vectorisation avoidance ------------------
@@ -905,202 +889,248 @@ data VectAvoidInfo = VIParr -- tree contains parallel computations
| VISimple -- result type is scalar & no parallel subcomputation
| VIComplex -- any result type, no parallel subcomputation
| VIEncaps -- tree encapsulated by 'liftSimple'
+ | VIDict -- dictionary computation (never parallel)
deriving (Eq, Show)
--- Instead of integrating the vectorisation avoidance information into Core expression, we keep
--- them in a separate tree (that structurally mirrors the Core expression that it annotates).
+-- Core expression annotated with free variables and vectorisation-specific information.
--
-data VITree = VITNode VectAvoidInfo [VITree]
- deriving (Show)
+type CoreExprWithVectInfo = AnnExpr Id (VarSet, VectAvoidInfo)
--- Is any of the tree nodes a 'VIPArr' node?
+-- Yield the type of an annotated core expression.
--
-anyVIPArr :: [VITree] -> Bool
-anyVIPArr = or . (map (\(VITNode vi _) -> vi == VIParr))
+annExprType :: AnnExpr Var ann -> Type
+annExprType = exprType . deAnnotate
--- Compute Core annotations to determine for which subexpressions we can avoid vectorisation
+-- Project the vectorisation information from an annotated Core expression.
--
--- FIXME: free scalar vars don't actually need to be passed through, since encapsulations makes sure,
--- that there are no free variables in encapsulated lambda expressions
-vectAvoidInfo :: CoreExprWithFVs -> VM VITree
-vectAvoidInfo ce@(_, AnnVar v)
- = do { vi <- vectAvoidInfoType $ exprType $ deAnnotate ce
- ; viTrace ce vi []
- ; traceVt "vectAvoidInfo AnnVar" ((ppr v) <+> (ppr $ exprType $ deAnnotate ce))
- ; return $ VITNode vi []
- }
+vectAvoidInfoOf :: CoreExprWithVectInfo -> VectAvoidInfo
+vectAvoidInfoOf ((_, vi), _) = vi
-vectAvoidInfo ce@(_, AnnLit _)
- = do { vi <- vectAvoidInfoType $ exprType $ deAnnotate ce
- ; viTrace ce vi []
- ; traceVt "vectAvoidInfo AnnLit" (ppr $ exprType $ deAnnotate ce)
- ; return $ VITNode vi []
- }
+-- Is this a 'VIParr' node?
+--
+isVIParr :: CoreExprWithVectInfo -> Bool
+isVIParr = (== VIParr) . vectAvoidInfoOf
-vectAvoidInfo ce@(_, AnnApp e1 e2)
- = do { vt1 <- vectAvoidInfo e1
- ; vt2 <- vectAvoidInfo e2
- ; vi <- if anyVIPArr [vt1, vt2]
- then return VIParr
- else vectAvoidInfoType $ exprType $ deAnnotate ce
- ; viTrace ce vi [vt1, vt2]
- ; return $ VITNode vi [vt1, vt2]
- }
+-- Is this a 'VIEncaps' node?
+--
+isVIEncaps :: CoreExprWithVectInfo -> Bool
+isVIEncaps = (== VIEncaps) . vectAvoidInfoOf
-vectAvoidInfo ce@(_, AnnLam _var body)
- = do { vt@(VITNode vi _) <- vectAvoidInfo body
- ; viTrace ce vi [vt]
- ; let resultVI | vi == VIParr = VIParr
- | otherwise = VIComplex
- ; return $ VITNode resultVI [vt]
- }
+-- Is this a 'VIDict' node?
+--
+isVIDict :: CoreExprWithVectInfo -> Bool
+isVIDict = (== VIDict) . vectAvoidInfoOf
-vectAvoidInfo ce@(_, AnnLet (AnnNonRec _var expr) body)
- = do { vtE <- vectAvoidInfo expr
- ; vtB <- vectAvoidInfo body
- ; vi <- if anyVIPArr [vtE, vtB]
- then return VIParr
- else vectAvoidInfoType $ exprType $ deAnnotate ce
- ; viTrace ce vi [vtE, vtB]
- ; return $ VITNode vi [vtE, vtB]
- }
+-- 'VIParr' if either argument is 'VIParr'; otherwise, the first argument.
+--
+unlessVIParr :: VectAvoidInfo -> VectAvoidInfo -> VectAvoidInfo
+unlessVIParr _ VIParr = VIParr
+unlessVIParr vi _ = vi
-vectAvoidInfo ce@(_, AnnLet (AnnRec bnds) body)
- = do { let (_, exprs) = unzip bnds
- ; vtBnds <- mapM (\e -> vectAvoidInfo e) exprs
- ; if (anyVIPArr vtBnds)
- then do { vtBnds' <- mapM (\e -> vectAvoidInfo e) exprs
- ; vtB <- vectAvoidInfo body
- ; return (VITNode VIParr (vtB: vtBnds'))
- }
- else do { vtB@(VITNode vib _) <- vectAvoidInfo body
- ; ni <- if (vib == VIParr)
- then return VIParr
- else vectAvoidInfoType $ exprType $ deAnnotate ce
- ; viTrace ce ni (vtB : vtBnds)
- ; return $ VITNode ni (vtB : vtBnds)
- }
- }
+-- 'VIParr' if either arguments vectorisation information is 'VIParr'; otherwise, the vectorisation
+-- information of the first argument is produced.
+--
+unlessVIParrExpr :: VectAvoidInfo -> CoreExprWithVectInfo -> VectAvoidInfo
+infixl `unlessVIParrExpr`
+unlessVIParrExpr e1 e2 = e1 `unlessVIParr` vectAvoidInfoOf e2
-vectAvoidInfo ce@(_, AnnCase expr _var _ty alts)
- = do { vtExpr <- vectAvoidInfo expr
- ; vtAlts <- mapM (\(_, _, e) -> vectAvoidInfo e) alts
- ; ni <- if anyVIPArr (vtExpr : vtAlts)
- then return VIParr
- else vectAvoidInfoType $ exprType $ deAnnotate ce
- ; viTrace ce ni (vtExpr : vtAlts)
- ; return $ VITNode ni (vtExpr: vtAlts)
- }
+-- Compute Core annotations to determine for which subexpressions we can avoid vectorisation.
+--
+-- * The first argument is the set of free, local variables whose evaluation may entail parallelism.
+--
+vectAvoidInfo :: VarSet -> CoreExprWithFVs -> VM CoreExprWithVectInfo
+vectAvoidInfo pvs ce@(fvs, AnnVar v)
+ = do
+ { gpvs <- globalParallelVars
+ ; vi <- if v `elemVarSet` pvs || v `elemVarSet` gpvs
+ then return VIParr
+ else vectAvoidInfoTypeOf ce
+ ; viTrace ce vi []
-vectAvoidInfo (_, AnnCast expr _)
- = do { vt@(VITNode vi _) <- vectAvoidInfo expr
- ; return $ VITNode vi [vt]
- }
+ ; vit <- vectAvoidInfoTypeOf ce -- TEMPORARY
+ ; traceVt (" AnnVar: vectAvoidInfoTypeOf: " ++ show vit) empty
-vectAvoidInfo (_, AnnTick _ expr)
- = do { vt@(VITNode vi _) <- vectAvoidInfo expr
- ; return $ VITNode vi [vt]
- }
+ ; return ((fvs, vi), AnnVar v)
+ }
-vectAvoidInfo (_, AnnType {})
- = return $ VITNode VISimple []
+vectAvoidInfo _pvs ce@(fvs, AnnLit lit)
+ = do
+ { vi <- vectAvoidInfoTypeOf ce
+ ; viTrace ce vi []
+ ; return ((fvs, vi), AnnLit lit)
+ }
-vectAvoidInfo (_, AnnCoercion {})
- = return $ VITNode VISimple []
+vectAvoidInfo pvs ce@(fvs, AnnApp e1 e2)
+ = do
+ { ceVI <- vectAvoidInfoTypeOf ce
+ ; eVI1 <- vectAvoidInfo pvs e1
+ ; eVI2 <- vectAvoidInfo pvs e2
+ ; let vi = ceVI `unlessVIParrExpr` eVI1 `unlessVIParrExpr` eVI2
+ ; viTrace ce vi [eVI1, eVI2]
+ ; return ((fvs, vi), AnnApp eVI1 eVI2)
+ }
+
+vectAvoidInfo pvs ce@(fvs, AnnLam var body)
+ = do
+ { bodyVI <- vectAvoidInfo pvs body
+ ; varVI <- vectAvoidInfoType $ varType var
+ ; let vi = vectAvoidInfoOf bodyVI `unlessVIParr` varVI
+ ; viTrace ce vi [bodyVI]
+ ; return ((fvs, vi), AnnLam var bodyVI)
+ }
+
+vectAvoidInfo pvs ce@(fvs, AnnLet (AnnNonRec var e) body)
+ = do
+ { ceVI <- vectAvoidInfoTypeOf ce
+ ; eVI <- vectAvoidInfo pvs e
+ ; isScalarTy <- isScalar $ varType var
+ ; (bodyVI, vi) <- if isVIParr eVI && not isScalarTy
+ then do -- binding is parallel
+ { bodyVI <- vectAvoidInfo (fvs `extendVarSet` var) body
+ ; return (bodyVI, VIParr)
+ }
+ else do -- binding doesn't affect parallelism
+ { bodyVI <- vectAvoidInfo fvs body
+ ; return (bodyVI, ceVI `unlessVIParrExpr` bodyVI)
+ }
+ ; viTrace ce vi [eVI, bodyVI]
+ ; return ((fvs, vi), AnnLet (AnnNonRec var eVI) bodyVI)
+ }
+
+vectAvoidInfo pvs ce@(fvs, AnnLet (AnnRec bnds) body)
+ = do
+ { ceVI <- vectAvoidInfoTypeOf ce
+ ; bndsVI <- mapM (vectAvoidInfoBnd pvs) bnds
+ ; parrBndrs <- map fst <$> filterM isVIParrBnd bndsVI
+ ; if not . null $ parrBndrs
+ then do -- body may trigger parallelism via at least one binding
+ { new_pvs <- filterM ((not <$>) . isScalar . varType) parrBndrs
+ ; let extendedPvs = pvs `extendVarSetList` new_pvs
+ ; bndsVI <- mapM (vectAvoidInfoBnd extendedPvs) bnds
+ ; bodyVI <- vectAvoidInfo extendedPvs body
+ ; viTrace ce VIParr (map snd bndsVI ++ [bodyVI])
+ ; return ((fvs, VIParr), AnnLet (AnnRec bndsVI) bodyVI)
+ }
+ else do -- demanded bindings cannot trigger parallelism
+ { bodyVI <- vectAvoidInfo pvs body
+ ; let vi = ceVI `unlessVIParrExpr` bodyVI
+ ; viTrace ce vi (map snd bndsVI ++ [bodyVI])
+ ; return ((fvs, vi), AnnLet (AnnRec bndsVI) bodyVI)
+ }
+ }
+ where
+ vectAvoidInfoBnd pvs (var, e) = (var,) <$> vectAvoidInfo pvs e
+
+ isVIParrBnd (var, eVI)
+ = do
+ { isScalarTy <- isScalar (varType var)
+ ; return $ isVIParr eVI && not isScalarTy
+ }
+
+vectAvoidInfo pvs ce@(fvs, AnnCase e var ty alts)
+ = do
+ { ceVI <- vectAvoidInfoTypeOf ce
+ ; eVI <- vectAvoidInfo pvs e
+ ; isScalarTy <- isScalar . annExprType $ e
+ ; altsVI <- mapM (vectAvoidInfoAlt (isVIParr eVI && not isScalarTy)) alts
+ ; allScalarBndrs <- anyM allScalarAltBndrs altsVI
+ ; let alteVIs = [eVI | (_, _, eVI) <- altsVI]
+ vi | isVIParr eVI && not allScalarBndrs = VIParr
+ | otherwise
+ = foldl unlessVIParrExpr ceVI alteVIs
+ ; viTrace ce vi (eVI : alteVIs)
+ ; return ((fvs, vi), AnnCase eVI var ty altsVI)
+ }
+ where
+ vectAvoidInfoAlt isScalarScrut (con, bndrs, e) = (con, bndrs,) <$> vectAvoidInfo altPvs e
+ where
+ altPvs | isScalarScrut = pvs
+ | otherwise = pvs `extendVarSetList` bndrs
+
+ allScalarAltBndrs (_, bndrs, _) = allScalarVarType bndrs
+
+vectAvoidInfo pvs (fvs, AnnCast e (fvs_ann, ann))
+ = do
+ { eVI <- vectAvoidInfo pvs e
+ ; return ((fvs, vectAvoidInfoOf eVI), AnnCast eVI ((fvs_ann, VISimple), ann))
+ }
+
+vectAvoidInfo pvs (fvs, AnnTick tick e)
+ = do
+ { eVI <- vectAvoidInfo pvs e
+ ; return ((fvs, vectAvoidInfoOf eVI), AnnTick tick eVI)
+ }
+
+vectAvoidInfo _pvs (fvs, AnnType ty)
+ = return ((fvs, VISimple), AnnType ty)
+
+vectAvoidInfo _pvs (fvs, AnnCoercion coe)
+ = return ((fvs, VISimple), AnnCoercion coe)
-- Compute vectorisation avoidance information for a type.
--
vectAvoidInfoType :: Type -> VM VectAvoidInfo
-vectAvoidInfoType ty
- | maybeParrTy ty = return VIParr
- | otherwise
- = do { sType <- isSimpleType ty
- ; if sType
- then return VISimple
- else return VIComplex
- }
+vectAvoidInfoType ty
+ | isPredTy ty
+ = return VIDict
+ | Just (arg, res) <- splitFunTy_maybe ty
+ = do
+ { argVI <- vectAvoidInfoType arg
+ ; resVI <- vectAvoidInfoType res
+ ; case (argVI, resVI) of
+ (VISimple, VISimple) -> return VISimple -- NB: diverts from the paper: scalar functions
+ (_ , VIDict) -> return VIDict
+ _ -> return $ VIComplex `unlessVIParr` argVI `unlessVIParr` resVI
+ }
+ | otherwise
+ = do
+ { parr <- maybeParrTy ty
+ ; if parr
+ then return VIParr
+ else do
+ { scalar <- isScalar ty
+ ; if scalar
+ then return VISimple
+ else return VIComplex
+ } }
+
+-- Compute vectorisation avoidance information for the type of a Core expression (with FVs).
+--
+vectAvoidInfoTypeOf :: AnnExpr Var ann -> VM VectAvoidInfo
+vectAvoidInfoTypeOf = vectAvoidInfoType . annExprType
--- Checks whether the type might be a parallel array type. In particular, if the outermost
--- constructor is a type family, we conservatively assume that it may be a parallel array type.
+-- Checks whether the type might be a parallel array type.
--
-maybeParrTy :: Type -> Bool
+maybeParrTy :: Type -> VM Bool
maybeParrTy ty
- | Just ty' <- coreView ty = maybeParrTy ty'
- | Just (tyCon, ts) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon
- || or (map maybeParrTy ts)
-maybeParrTy _ = False
-
--- FIXME: This should not be hardcoded.
-isSimpleType :: Type -> VM Bool
-isSimpleType ty
- | Just (c, _cs) <- splitTyConApp_maybe ty
- = return $ (tyConName c) `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName, floatTyConName]
-{-
- = do { globals <- globalScalarTyCons
- ; traceVt ("isSimpleType " ++ (show (elemNameSet (tyConName c) globals ))) (ppr c)
- ; return (elemNameSet (tyConName c) globals )
- }
- -}
- | Nothing <- splitTyConApp_maybe ty
- = return False
-isSimpleType ty
- = pprPanic "Vectorise.Exp.isSimpleType not handled" (ppr ty)
-
-varsSimple :: VarSet -> VM Bool
-varsSimple vs
- = do { varTypes <- mapM isSimpleType $ map varType $ varSetElems vs
- ; return $ and varTypes
- }
-
-viTrace :: CoreExprWithFVs -> VectAvoidInfo -> [VITree] -> VM ()
-viTrace ce vi vTs
- = traceVt ("vitrace " ++ (show vi) ++ "[" ++ (concat $ map (\(VITNode vi _) -> show vi ++ " ") vTs) ++"]")
- (ppr $ deAnnotate ce)
-
+ -- looking through newtypes
+ | Just ty' <- coreView ty
+ = (== VIParr) <$> vectAvoidInfoType ty'
+ -- decompose constructor applications
+ | Just (tc, ts) <- splitTyConApp_maybe ty
+ = do
+ { isParallel <- (tyConName tc `elemNameSet`) <$> globalParallelTyCons
+ ; if isParallel
+ then return True
+ else or <$> mapM maybeParrTy ts
+ }
+maybeParrTy (ForAllTy _ ty) = maybeParrTy ty
+maybeParrTy _ = return False
-{-
----- Sanity check of the tree, for debugging only
-checkTree :: VITree -> CoreExpr -> Bool
-checkTree (VITNode _ []) (Type _ty)
- = True
-
-checkTree (VITNode _ []) (Var _v)
- = True
-
-checkTree (VITNode _ []) (Lit _)
- = True
-
-checkTree (VITNode _ [vit]) (Tick _ expr)
- = checkTree vit expr
-
-checkTree (VITNode _ [vit]) (Lam _ expr)
- = checkTree vit expr
-
-checkTree (VITNode _ [vit1, vit2]) (App ce1 ce2)
- = (checkTree vit1 ce1) && (checkTree vit2 ce2)
-
-checkTree (VITNode _ (scrutVit : altVits)) (Case scrut _ _ alts)
- = (checkTree scrutVit scrut) && (and $ zipWith checkAlt altVits alts)
- where
- checkAlt vt (_, _, expr) = checkTree vt expr
-
-checkTree (VITNode _ [vt1, vt2]) (Let (NonRec _ expr1) expr2)
- = (checkTree vt1 expr1) && (checkTree vt2 expr2)
-
-checkTree (VITNode _ (vtB : vtBnds)) (Let (Rec bndngs) expr)
- = (and $ zipWith checkBndr vtBnds bndngs) &&
- (checkTree vtB expr)
- where
- checkBndr vt (_, e) = checkTree vt e
-
-checkTree (VITNode _ [vit]) (Cast expr _)
- = checkTree vit expr
+-- Are the types of all variables in the 'Scalar' class?
+--
+allScalarVarType :: [Var] -> VM Bool
+allScalarVarType vs = and <$> mapM (isScalar . varType) vs
-checkTree _ _ = False
+-- Are the types of all variables in the set in the 'Scalar' class?
+--
+allScalarVarTypeSet :: VarSet -> VM Bool
+allScalarVarTypeSet = allScalarVarType . varSetElems
-checkTreeAnnM:: VITree -> CoreExprWithFVs -> VM ()
-checkTreeAnnM vi e =
- if not (checkTree vi $ deAnnotate e)
- then error ("checkTreeAnnM : \n " ++ show vi)
- else return ()
--}
+-- Debugging support
+--
+viTrace :: CoreExprWithFVs -> VectAvoidInfo -> [CoreExprWithVectInfo] -> VM ()
+viTrace ce vi vTs
+ = traceVt ("vect info: " ++ show vi ++ "[" ++
+ (concat $ map ((++ " ") . show . vectAvoidInfoOf) vTs) ++ "]")
+ (ppr $ deAnnotate ce)
diff --git a/compiler/vectorise/Vectorise/Monad.hs b/compiler/vectorise/Vectorise/Monad.hs
index 375b0af85e..6b5e9cc354 100644
--- a/compiler/vectorise/Vectorise/Monad.hs
+++ b/compiler/vectorise/Vectorise/Monad.hs
@@ -14,8 +14,8 @@ module Vectorise.Monad (
-- * Variables
lookupVar,
lookupVar_maybe,
- addGlobalScalarVar,
- addGlobalScalarTyCon,
+ addGlobalParallelVar,
+ addGlobalParallelTyCon,
) where
import Vectorise.Monad.Base
@@ -172,22 +172,22 @@ dumpVar dflags var
= cantVectorise dflags "Variable not vectorised:" (ppr var)
--- Global scalars --------------------------------------------------------------
+-- Global parallel entities ----------------------------------------------------
--- |Mark the given variable as scalar — i.e., executing the associated code does not involve any
+-- |Mark the given variable as parallel — i.e., executing the associated code might involve
-- parallel array computations.
--
-addGlobalScalarVar :: Var -> VM ()
-addGlobalScalarVar var
- = do { traceVt "addGlobalScalarVar" (ppr var)
- ; updGEnv $ \env -> env{global_scalar_vars = extendVarSet (global_scalar_vars env) var}
+addGlobalParallelVar :: Var -> VM ()
+addGlobalParallelVar var
+ = do { traceVt "addGlobalParallelVar" (ppr var)
+ ; updGEnv $ \env -> env{global_parallel_vars = extendVarSet (global_parallel_vars env) var}
}
--- |Mark the given type constructor as scalar — i.e., its values cannot embed parallel arrays.
+-- |Mark the given type constructor as parallel — i.e., its values might embed parallel arrays.
--
-addGlobalScalarTyCon :: TyCon -> VM ()
-addGlobalScalarTyCon tycon
- = do { traceVt "addGlobalScalarTyCon" (ppr tycon)
+addGlobalParallelTyCon :: TyCon -> VM ()
+addGlobalParallelTyCon tycon
+ = do { traceVt "addGlobalParallelTyCon" (ppr tycon)
; updGEnv $ \env ->
- env{global_scalar_tycons = addOneToNameSet (global_scalar_tycons env) (tyConName tycon)}
+ env{global_parallel_tycons = addOneToNameSet (global_parallel_tycons env) (tyConName tycon)}
}
diff --git a/compiler/vectorise/Vectorise/Monad/Global.hs b/compiler/vectorise/Vectorise/Monad/Global.hs
index a5c8449fc2..0fe460ad73 100644
--- a/compiler/vectorise/Vectorise/Monad/Global.hs
+++ b/compiler/vectorise/Vectorise/Monad/Global.hs
@@ -6,13 +6,13 @@ module Vectorise.Monad.Global (
updGEnv,
-- * Vars
- defGlobalVar,
+ defGlobalVar, undefGlobalVar,
-- * Vectorisation declarations
- lookupVectDecl, noVectDecl,
+ lookupVectDecl,
-- * Scalars
- globalScalarVars, isGlobalScalarVar, globalScalarTyCons,
+ globalParallelVars, globalParallelTyCons,
-- * TyCons
lookupTyCon,
@@ -93,48 +93,54 @@ defGlobalVar v v'
| otherwise
= ptext (sLit "in the current module")
+-- |Remove the mapping of a variable in the vectorisation map.
+--
+undefGlobalVar :: Var -> VM ()
+undefGlobalVar v
+ = do
+ { traceVt "REMOVING global var mapping:" (ppr v)
+ ; updGEnv $ \env -> env { global_vars = delVarEnv (global_vars env) v }
+ }
+
-- Vectorisation declarations -------------------------------------------------
--- |Check whether a variable has a (non-scalar) vectorisation declaration.
+-- |Check whether a variable has a vectorisation declaration.
--
-lookupVectDecl :: Var -> VM (Maybe (Type, CoreExpr))
-lookupVectDecl var = readGEnv $ \env -> lookupVarEnv (global_vect_decls env) var
-
--- |Check whether a variable has a 'NOVECTORISE' declaration.
+-- The first component of the result indicates whether the variable has a 'NOVECTORISE' declaration.
+-- The second component contains the given type and expression in case of a 'VECTORISE' declaration.
--
-noVectDecl :: Var -> VM Bool
-noVectDecl var = readGEnv $ \env -> elemVarSet var (global_novect_vars env)
+lookupVectDecl :: Var -> VM (Bool, Maybe (Type, CoreExpr))
+lookupVectDecl var
+ = readGEnv $ \env ->
+ case lookupVarEnv (global_vect_decls env) var of
+ Nothing -> (False, Nothing)
+ Just Nothing -> (True, Nothing)
+ Just vectDecl -> (False, vectDecl)
--- Scalars --------------------------------------------------------------------
+-- Parallel entities -----------------------------------------------------------
--- |Get the set of global scalar variables.
+-- |Get the set of global parallel variables.
--
-globalScalarVars :: VM VarSet
-globalScalarVars = readGEnv global_scalar_vars
+globalParallelVars :: VM VarSet
+globalParallelVars = readGEnv global_parallel_vars
--- |Check whether a given variable is in the set of global scalar variables.
+-- |Get the set of all parallel type constructors (those that may embed parallelism) including both
+-- both those parallel type constructors declared in an imported module and those declared in the
+-- current module.
--
-isGlobalScalarVar :: Var -> VM Bool
-isGlobalScalarVar var = readGEnv $ \env -> var `elemVarSet` global_scalar_vars env
-
--- |Get the set of global scalar type constructors including both those scalar type constructors
--- declared in an imported module and those declared in the current module.
---
-globalScalarTyCons :: VM NameSet
-globalScalarTyCons = readGEnv global_scalar_tycons
+globalParallelTyCons :: VM NameSet
+globalParallelTyCons = readGEnv global_parallel_tycons
-- TyCons ---------------------------------------------------------------------
--- |Lookup the vectorised version of a `TyCon` from the global environment.
+-- |Determine the vectorised version of a `TyCon`. The vectorisation map in the global environment
+-- contains a vectorised version if the original `TyCon` embeds any parallel arrays.
--
lookupTyCon :: TyCon -> VM (Maybe TyCon)
lookupTyCon tc
- | isUnLiftedTyCon tc || isTupleTyCon tc
- = return (Just tc)
- | otherwise
= readGEnv $ \env -> lookupNameEnv (global_tycons env) (tyConName tc)
-- |Add a mapping between plain and vectorised `TyCon`s to the global environment.
diff --git a/compiler/vectorise/Vectorise/Monad/InstEnv.hs b/compiler/vectorise/Vectorise/Monad/InstEnv.hs
index fc12ee567c..95546bf503 100644
--- a/compiler/vectorise/Vectorise/Monad/InstEnv.hs
+++ b/compiler/vectorise/Vectorise/Monad/InstEnv.hs
@@ -1,5 +1,6 @@
module Vectorise.Monad.InstEnv
- ( lookupInst
+ ( existsInst
+ , lookupInst
, lookupFamInst
)
where
@@ -21,6 +22,14 @@ import Util
#include "HsVersions.h"
+-- Check whether a unique class instance for a given class and type arguments exists.
+--
+existsInst :: Class -> [Type] -> VM Bool
+existsInst cls tys
+ = do { instEnv <- readGEnv global_inst_env
+ ; return $ either (const False) (const True) (lookupUniqueInstEnv instEnv cls tys)
+ }
+
-- Look up the dfun of a class instance.
--
-- The match must be unique —i.e., match exactly one instance— but the
@@ -64,6 +73,6 @@ lookupFamInst tycon tys
[(fam_inst, rep_tys)] -> return ( fam_inst, rep_tys)
_other ->
do dflags <- getDynFlags
- cantVectorise dflags "VectMonad.lookupFamInst: not found: "
+ cantVectorise dflags "Vectorise.Monad.InstEnv.lookupFamInst: not found: "
(ppr $ mkTyConApp tycon tys)
}
diff --git a/compiler/vectorise/Vectorise/Monad/Local.hs b/compiler/vectorise/Vectorise/Monad/Local.hs
index 8b3c1dcf19..5415c5691d 100644
--- a/compiler/vectorise/Vectorise/Monad/Local.hs
+++ b/compiler/vectorise/Vectorise/Monad/Local.hs
@@ -44,20 +44,24 @@ updLEnv f = VM $ \_ genv lenv -> return (Yes genv (f lenv) ())
--
localV :: VM a -> VM a
localV p
- = do env <- readLEnv id
- x <- p
- setLEnv env
- return x
+ = do
+ { env <- readLEnv id
+ ; x <- p
+ ; setLEnv env
+ ; return x
+ }
-- |Perform a computation in an empty local environment.
--
closedV :: VM a -> VM a
closedV p
- = do env <- readLEnv id
- setLEnv (emptyLocalEnv { local_bind_name = local_bind_name env })
- x <- p
- setLEnv env
- return x
+ = do
+ { env <- readLEnv id
+ ; setLEnv (emptyLocalEnv { local_bind_name = local_bind_name env })
+ ; x <- p
+ ; setLEnv env
+ ; return x
+ }
-- |Get the name of the local binding currently being vectorised.
--
diff --git a/compiler/vectorise/Vectorise/Type/Classify.hs b/compiler/vectorise/Vectorise/Type/Classify.hs
index 0cab706cf4..e1cd43ac3c 100644
--- a/compiler/vectorise/Vectorise/Type/Classify.hs
+++ b/compiler/vectorise/Vectorise/Type/Classify.hs
@@ -13,10 +13,12 @@
-- types. As '([::])' is being vectorised, any type constructor whose definition involves
-- '([::])', either directly or indirectly, will be vectorised.
-module Vectorise.Type.Classify (
- classifyTyCons
-) where
+module Vectorise.Type.Classify
+ ( classifyTyCons
+ )
+where
+import NameSet
import UniqSet
import UniqFM
import DataCon
@@ -29,7 +31,7 @@ import Digraph
-- |From a list of type constructors, extract those that can be vectorised, returning them in two
-- sets, where the first result list /must be/ vectorised and the second result list /need not be/
--- vectorised. The third result list are those type constructors that we cannot convert (either
+-- vectorised. The third result list are those type constructors that we cannot convert (either
-- because they use language extensions or because they dependent on type constructors for which
-- no vectorised version is available).
@@ -37,28 +39,40 @@ import Digraph
--
-- * tycons which have converted versions are mapped to 'True'
-- * tycons which are not changed by vectorisation are mapped to 'False'
--- * tycons which can't be converted are not elements of the map
+-- * tycons which haven't been converted (because they can't or weren't vectorised) are not
+-- elements of the map
--
-classifyTyCons :: UniqFM Bool -- ^type constructor conversion status
- -> [TyCon] -- ^type constructors that need to be classified
- -> ([TyCon], [TyCon], [TyCon]) -- ^tycons to be converted & not to be converted
-classifyTyCons convStatus tcs = classify [] [] [] convStatus (tyConGroups tcs)
+classifyTyCons :: UniqFM Bool -- ^type constructor vectorisation status
+ -> NameSet -- ^tycons involving parallel arrays
+ -> [TyCon] -- ^type constructors that need to be classified
+ -> ( [TyCon] -- to be converted
+ , [TyCon] -- need not be converted (but could be)
+ , [TyCon] -- can't be converted, but involve parallel arrays
+ , [TyCon] -- can't be converted and have no parallel arrays
+ )
+classifyTyCons convStatus parTyCons tcs = classify [] [] [] [] convStatus parTyCons (tyConGroups tcs)
where
- classify conv keep ignored _ [] = (conv, keep, ignored)
- classify conv keep ignored cs ((tcs, ds) : rs)
+ classify conv keep par novect _ _ [] = (conv, keep, par, novect)
+ classify conv keep par novect cs pts ((tcs, ds) : rs)
| can_convert && must_convert
- = classify (tcs ++ conv) keep ignored (cs `addListToUFM` [(tc, True) | tc <- tcs]) rs
+ = classify (tcs ++ conv) keep par novect (cs `addListToUFM` [(tc, True) | tc <- tcs]) pts' rs
| can_convert
- = classify conv (tcs ++ keep) ignored (cs `addListToUFM` [(tc, False) | tc <- tcs]) rs
+ = classify conv (tcs ++ keep) par novect (cs `addListToUFM` [(tc, False) | tc <- tcs]) pts' rs
+ | has_parr
+ = classify conv keep (tcs ++ par) novect cs pts' rs
| otherwise
- = classify conv keep (tcs ++ ignored) cs rs
+ = classify conv keep par (tcs ++ novect) cs pts' rs
where
refs = ds `delListFromUniqSet` tcs
+
+ pts' | has_parr = pts `addListToNameSet` map tyConName tcs
+ | otherwise = pts
can_convert = (isNullUFM (refs `minusUFM` cs) && all convertable tcs)
|| isShowClass tcs
must_convert = foldUFM (||) False (intersectUFM_C const cs refs)
&& (not . isShowClass $ tcs)
+ has_parr = any ((`elemNameSet` parTyCons) . tyConName) . eltsUFM $ refs
-- We currently admit Haskell 2011-style data and newtype declarations as well as type
-- constructors representing classes.
diff --git a/compiler/vectorise/Vectorise/Type/Env.hs b/compiler/vectorise/Vectorise/Type/Env.hs
index 0051d072a4..faa80a8629 100644
--- a/compiler/vectorise/Vectorise/Type/Env.hs
+++ b/compiler/vectorise/Vectorise/Type/Env.hs
@@ -32,7 +32,9 @@ import Id
import MkId
import NameEnv
import NameSet
+import UniqFM
import OccName
+import Unique
import Util
import Outputable
@@ -47,69 +49,85 @@ import Data.List
-- Note [Pragmas to vectorise tycons]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
--
--- VECTORISE pragmas for type constructors cover three different flavours of vectorising data type
+-- All imported type constructors that are not mapped to a vectorised type in the vectorisation map
+-- (possibly because the defining module was not compiled with vectorisation) may be used in scalar
+-- code encapsulated in vectorised code. If a such a type constructor 'T' is a member of the
+-- 'Scalar' class (and hence also of 'PData' and 'PRepr'), it may also be used in vectorised code,
+-- where 'T' represents itself, but the representation of 'T' still remains opaque in vectorised
+-- code (i.e., it can only be used in scalar code).
+--
+-- An example is the treatment of 'Int'. 'Int's can be used in vectorised code and remain unchanged
+-- by vectorisation. However, the representation of 'Int' by the 'I#' data constructor wrapping an
+-- 'Int#' is not exposed in vectorised code. Instead, computations involving the representation need
+-- to be confined to scalar code.
+--
+-- VECTORISE pragmas for type constructors cover four different flavours of vectorising data type
-- constructors:
--
--- (1) Data type constructor 'T' that may be used in vectorised code, where 'T' represents itself,
--- but the representation of 'T' is opaque in vectorised code.
+-- (1) Data type constructor 'T' that together with its constructors 'Cn' may be used in vectorised
+-- code, where 'T' and the 'Cn' are automatically vectorised in the same manner as data types
+-- declared in a vectorised module. This includes the case where the vectoriser determines that
+-- the original representation of 'T' may be used in vectorised code (as it does not embed any
+-- parallel arrays.) This case is for type constructors that are *imported* from a non-
+-- vectorised module, but that we want to use with full vectorisation support.
--
--- An example is the treatment of 'Int'. 'Int's can be used in vectorised code and remain
--- unchanged by vectorisation. However, the representation of 'Int' by the 'I#' data
--- constructor wrapping an 'Int#' is not exposed in vectorised code. Instead, computations
--- involving the representation need to be confined to scalar code.
+-- An example is the treatment of 'Ordering' and '[]'. The former remains unchanged by
+-- vectorisation, whereas the latter is fully vectorised.
--
--- 'PData' and 'PRepr' instances need to be explicitly supplied for 'T' (they are not generated
--- by the vectoriser).
+-- 'PData' and 'PRepr' instances are automatically generated by the vectoriser.
--
--- Type constructors declared with {-# VECTORISE SCALAR type T #-} are treated in this manner.
--- (The vectoriser never treats a type constructor automatically in this manner.)
+-- Type constructors declared with {-# VECTORISE type T #-} are treated in this manner.
--
-- (2) Data type constructor 'T' that may be used in vectorised code, where 'T' is represented by an
--- explicitly given 'Tv', but the representation of 'T' is opaque in vectorised code.
+-- explicitly given 'Tv', but the representation of 'T' is opaque in vectorised code (i.e., the
+-- constructors of 'T' may not occur in vectorised code).
--
--- An example is the treatment of '[::]'. '[::]'s can be used in vectorised code and is
--- vectorised to 'PArray'. However, the representation of '[::]' is not exposed in vectorised
--- code. Instead, computations involving the representation need to be confined to scalar code.
+-- An example is the treatment of '[::]'. The type '[::]' can be used in vectorised code and is
+-- vectorised to 'PArray'. However, the representation of '[::]' is not exposed in vectorised
+-- code. Instead, computations involving the representation need to be confined to scalar code.
--
-- 'PData' and 'PRepr' instances need to be explicitly supplied for 'T' (they are not generated
-- by the vectoriser).
--
--- Type constructors declared with {-# VECTORISE SCALAR type T = T' #-} are treated in this
+-- Type constructors declared with {-# VECTORISE type T = Tv #-} are treated in this manner
-- manner. (The vectoriser never treats a type constructor automatically in this manner.)
--
--- (3) Data type constructor 'T' that together with its constructors 'Cn' may be used in vectorised
--- code, where 'T' and the 'Cn' are automatically vectorised in the same manner as data types
--- declared in a vectorised module. This includes the case where the vectoriser determines that
--- the original representation of 'T' may be used in vectorised code (as it does not embed any
--- parallel arrays.) This case is for type constructors that are *imported* from a non-
--- vectorised module, but that we want to use with full vectorisation support.
+-- (3) Data type constructor 'T' that does not contain any parallel arrays and has explicitly
+-- provided 'PData' and 'PRepr' instances (and maybe also a 'Scalar' instance), which together
+-- with the type's constructors 'Cn' may be used in vectorised code. The type 'T' and its
+-- constructors 'Cn' are represented by themselves in vectorised code.
--
--- An example is the treatment of 'Ordering' and '[]'. The former remains unchanged by
--- vectorisation, whereas the latter is fully vectorised.
-
--- 'PData' and 'PRepr' instances are automatically generated by the vectoriser.
+-- An example is 'Bool', which is represented by itself in vectorised code (as it cannot embed
+-- any parallel arrays). However, we do not want any automatic generation of class and family
+-- instances, which is why Case (1) does not apply.
--
--- Type constructors declared with {-# VECTORISE type T #-} are treated in this manner.
+-- 'PData' and 'PRepr' instances need to be explicitly supplied for 'T' (they are not generated
+-- by the vectoriser).
+--
+-- Type constructors declared with {-# VECTORISE SCALAR type T #-} are treated in this manner.
--
--- (4) Data type constructor 'T' that together with its constructors 'Cn' may be used in vectorised
--- code, where 'T' is represented by an explicitly given 'Tv' whose constructors 'Cvn' represent
--- the original constructors in vectorised code. As a special case, we can have 'Tv = T'
+-- (4) Data type constructor 'T' that does not contain any parallel arrays and that, in vectorised
+-- code, is represented by an explicitly given 'Tv', but the representation of 'T' is opaque in
+-- vectorised code and 'T' is regarded to be scalar — i.e., it may be used in encapsulated
+-- scalar subcomputations.
--
--- An example is the treatment of 'Bool', which is represented by itself in vectorised code
--- (as it cannot embed any parallel arrays). However, we do not want any automatic generation
--- of class and family instances, which is why Case (3) does not apply.
+-- An example is the treatment of '(->)'. Types '(->)' can be used in vectorised code and are
+-- vectorised to '(:->)'. However, the representation of '(->)' is not exposed in vectorised
+-- code. Instead, computations involving the representation need to be confined to scalar code
+-- and may be part of encapsulated scalar computations.
--
-- 'PData' and 'PRepr' instances need to be explicitly supplied for 'T' (they are not generated
-- by the vectoriser).
--
--- Type constructors declared with {-# VECTORISE type T = T' #-} are treated in this manner.
+-- Type constructors declared with {-# VECTORISE SCALAR type T = Tv #-} are treated in this
+-- manner. (The vectoriser never treats a type constructor automatically in this manner.)
--
-- In addition, we have also got a single pragma form for type classes: {-# VECTORISE class C #-}.
-- It implies that the class type constructor may be used in vectorised code together with its data
-- constructor. We generally produce a vectorised version of the data type and data constructor.
-- We do not generate 'PData' and 'PRepr' instances for class type constructors. This pragma is the
--- default for all type classes declared in this module, but the pragma can also be used explitly on
--- imported classes.
+-- default for all type classes declared in a vectorised module, but the pragma can also be used
+-- explitly on imported classes.
-- Note [Vectorising classes]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -147,38 +165,36 @@ vectTypeEnv :: [TyCon] -- Type constructors defined in this mod
vectTypeEnv tycons vectTypeDecls vectClassDecls
= do { traceVt "** vectTypeEnv" $ ppr tycons
- ; let -- {-# VECTORISE SCALAR type T -#} (imported and local tycons)
- localAbstractTyCons = [tycon | VectType True tycon Nothing <- vectTypeDecls]
-
- -- {-# VECTORISE type T -#} (ONLY the imported tycons)
+ ; let -- {-# VECTORISE type T -#} (ONLY the imported tycons)
impVectTyCons = ( [tycon | VectType False tycon Nothing <- vectTypeDecls]
++ [tycon | VectClass tycon <- vectClassDecls])
\\ tycons
+
+ -- {-# VECTORISE [SCALAR] type T = Tv -#} (imported & local tycons with an /RHS/)
+ vectTyConsWithRHS = [ (tycon, rhs, isScalar)
+ | VectType isScalar tycon (Just rhs) <- vectTypeDecls]
- -- {-# VECTORISE [SCALAR] type T = T' -#} (imported and local tycons)
- vectTyConsWithRHS = [ (tycon, rhs, isAbstract)
- | VectType isAbstract tycon (Just rhs) <- vectTypeDecls]
+ -- {-# VECTORISE SCALAR type T -#} (imported & local /scalar/ tycons without an RHS)
+ scalarTyConsNoRHS = [tycon | VectType True tycon Nothing <- vectTypeDecls]
- -- filter VECTORISE SCALAR tycons and VECTORISE tycons with explicit rhses
- vectSpecialTyConNames = mkNameSet . map tyConName $
- localAbstractTyCons ++ map fst3 vectTyConsWithRHS
+ -- Check that is not a VECTORISE SCALAR tycon nor VECTORISE tycons with explicit rhs?
+ vectSpecialTyConNames = mkNameSet . map tyConName $
+ scalarTyConsNoRHS ++ map fst3 vectTyConsWithRHS
notVectSpecialTyCon tc = not $ (tyConName tc) `elemNameSet` vectSpecialTyConNames
- -- Build a map containing all vectorised type constructor. If they are scalar, they are
- -- mapped to 'False' (vectorised type constructor == original type constructor).
- ; allScalarTyConNames <- globalScalarTyCons -- covers both current and imported modules
+ -- Build a map containing all vectorised type constructor. If the vectorised type
+ -- constructor differs from the original one, then it is mapped to 'True'; if they are
+ -- both the same, then it maps to 'False'.
; vectTyCons <- globalVectTyCons
- ; let vectTyConBase = mapNameEnv (const True) vectTyCons -- by default fully vectorised
+ ; let vectTyConBase = mapUFM_Directly isDistinct vectTyCons -- 'True' iff tc /= V[[tc]]
+ isDistinct u tc = u /= getUnique tc
vectTyConFlavour = vectTyConBase
`plusNameEnv`
mkNameEnv [ (tyConName tycon, True)
| (tycon, _, _) <- vectTyConsWithRHS]
`plusNameEnv`
- mkNameEnv [ (tcName, False) -- original representation
- | tcName <- nameSetToList allScalarTyConNames]
- `plusNameEnv`
mkNameEnv [ (tyConName tycon, False) -- original representation
- | tycon <- localAbstractTyCons]
+ | tycon <- scalarTyConsNoRHS]
-- Split the list of 'TyCons' into the ones (1) that we must vectorise and those (2)
@@ -189,11 +205,15 @@ vectTypeEnv tycons vectTypeDecls vectClassDecls
-- these are being handled separately. NB: Some type constructors may be marked SCALAR
-- /and/ have an explicit right-hand side.)
--
- -- Furthermore, 'drop_tcs' are those type constructors that we cannot vectorise.
- ; let maybeVectoriseTyCons = filter notVectSpecialTyCon tycons ++ impVectTyCons
- (conv_tcs, keep_tcs, drop_tcs) = classifyTyCons vectTyConFlavour maybeVectoriseTyCons
+ -- Furthermore, 'par_tcs' and 'drop_tcs' are those type constructors that we cannot
+ -- vectorise, and of those, only the 'par_tcs' involve parallel arrays.
+ ; parallelTyCons <- globalParallelTyCons
+ ; let maybeVectoriseTyCons = filter notVectSpecialTyCon tycons ++ impVectTyCons
+ (conv_tcs, keep_tcs, par_tcs, drop_tcs)
+ = classifyTyCons vectTyConFlavour parallelTyCons maybeVectoriseTyCons
- ; traceVt " VECT SCALAR : " $ ppr localAbstractTyCons
+ ; traceVt " VECT SCALAR : " $ ppr (scalarTyConsNoRHS ++
+ [tycon | (tycon, _, True) <- vectTyConsWithRHS])
; traceVt " VECT [class] : " $ ppr impVectTyCons
; traceVt " VECT with rhs : " $ ppr (map fst3 vectTyConsWithRHS)
; traceVt " -- after classification (local and VECT [class] tycons) --" empty
@@ -203,26 +223,22 @@ vectTypeEnv tycons vectTypeDecls vectClassDecls
-- warn the user about unvectorised type constructors
; let explanation = ptext (sLit "(They use unsupported language extensions") $$
ptext (sLit "or depend on type constructors that are not vectorised)")
- drop_tcs_nosyn = filter (not . isSynTyCon) drop_tcs
+ drop_tcs_nosyn = filter (not . isSynTyCon) (par_tcs ++ drop_tcs)
; unless (null drop_tcs_nosyn) $
emitVt "Warning: cannot vectorise these type constructors:" $
pprQuotedList drop_tcs_nosyn $$ explanation
- ; mapM_ addGlobalScalarTyCon keep_tcs
+ ; mapM_ addParallelTyConAndCons $ conv_tcs ++ par_tcs
; let mapping =
- -- Type constructors that we don't need to vectorise, use the same
+ -- Type constructors that we found we don't need to vectorise and those
+ -- declared VECTORISE SCALAR /without/ an explicit right-hand side, use the same
-- representation in both unvectorised and vectorised code; they are not
-- abstract.
- [(tycon, tycon, False) | tycon <- keep_tcs]
+ [(tycon, tycon, False) | tycon <- keep_tcs ++ scalarTyConsNoRHS]
-- We do the same for type constructors declared VECTORISE SCALAR /without/
- -- an explicit right-hand side, but ignore their representation (data
- -- constructors) as they are abstract.
- ++ [(tycon, tycon, True) | tycon <- localAbstractTyCons]
- -- Type constructors declared VECTORISE /with/ an explicit vectorised type,
- -- we map from the original to the given type; whether they are abstract depends
- -- on whether the vectorisation declaration was SCALAR.
- ++ vectTyConsWithRHS
+ -- an explicit right-hand side
+ ++ [(tycon, vTycon, True) | (tycon, vTycon, _) <- vectTyConsWithRHS]
; syn_tcs <- catMaybes <$> mapM defTyConDataCons mapping
-- Vectorise all the data type declarations that we can and must vectorise (enter the
@@ -263,17 +279,15 @@ vectTypeEnv tycons vectTypeDecls vectClassDecls
do { defTyConPAs (zipLazy vect_tcs dfuns)
-- Query the 'PData' instance type constructors for type constructors that have a
- -- VECTORISE pragma with an explicit right-hand side (this is Item (4) of
- -- "Note [Pragmas to vectorise tycons]" above).
- ; let (withRHS_non_abstract, vwithRHS_non_abstract)
- = unzip [(tycon, vtycon) | (tycon, vtycon, False) <- vectTyConsWithRHS]
- ; pdata_withRHS_tcs <- mapM pdataReprTyConExact withRHS_non_abstract
+ -- VECTORISE SCALAR type pragma without an explicit right-hand side (this is Item
+ -- (3) of "Note [Pragmas to vectorise tycons]" above).
+ ; pdata_scalar_tcs <- mapM pdataReprTyConExact scalarTyConsNoRHS
-- Build workers for all vectorised data constructors (except abstract ones)
; sequence_ $
- zipWith3 vectDataConWorkers (orig_tcs ++ withRHS_non_abstract)
- (vect_tcs ++ vwithRHS_non_abstract)
- (pdata_tcs ++ pdata_withRHS_tcs)
+ zipWith3 vectDataConWorkers (orig_tcs ++ scalarTyConsNoRHS)
+ (vect_tcs ++ scalarTyConsNoRHS)
+ (pdata_tcs ++ pdata_scalar_tcs)
-- Build a 'PA' dictionary for all type constructors (except abstract ones & those
-- defined with an explicit right-hand side where the dictionary is user-supplied)
@@ -295,6 +309,12 @@ vectTypeEnv tycons vectTypeDecls vectClassDecls
}
where
fst3 (a, _, _) = a
+
+ addParallelTyConAndCons tycon
+ = do
+ { addGlobalParallelTyCon tycon
+ ; mapM_ addGlobalParallelVar . concatMap dataConImplicitIds . tyConDataCons $ tycon
+ }
-- Add a mapping from the original to vectorised type constructor to the vectorisation map.
-- Unless the type constructor is abstract, also mappings from the orignal's data constructors
@@ -307,21 +327,22 @@ vectTypeEnv tycons vectTypeDecls vectClassDecls
-- right type constructor when reading vectorisation information from interface files).
--
defTyConDataCons (origTyCon, vectTyCon, isAbstract)
- = do { canonName <- mkLocalisedName mkVectTyConOcc origName
- ; if origName == vectName -- Case (1)
- || vectName == canonName -- Case (2)
- then do
- { defTyCon origTyCon vectTyCon -- T --> vT
- ; defDataCons -- Ci --> vCi
- ; return Nothing
- }
- else do -- Case (3)
- { let synTyCon = mkSyn canonName (mkTyConTy vectTyCon) -- type S = vT
- ; defTyCon origTyCon synTyCon -- T --> S
- ; defDataCons -- Ci --> vCi
- ; return $ Just synTyCon
- }
- }
+ = do
+ { canonName <- mkLocalisedName mkVectTyConOcc origName
+ ; if origName == vectName -- Case (1)
+ || vectName == canonName -- Case (2)
+ then do
+ { defTyCon origTyCon vectTyCon -- T --> vT
+ ; defDataCons -- Ci --> vCi
+ ; return Nothing
+ }
+ else do -- Case (3)
+ { let synTyCon = mkSyn canonName (mkTyConTy vectTyCon) -- type S = vT
+ ; defTyCon origTyCon synTyCon -- T --> S
+ ; defDataCons -- Ci --> vCi
+ ; return $ Just synTyCon
+ }
+ }
where
origName = tyConName origTyCon
vectName = tyConName vectTyCon
@@ -343,9 +364,9 @@ buildTyConPADict vect_tc prepr_ax pdata_tc pdatas_tc
= tyConRepr vect_tc >>= buildPADict vect_tc prepr_ax pdata_tc pdatas_tc
-- Produce a custom-made worker for the data constructors of a vectorised data type. This includes
--- all data constructors that may be used in vetcorised code — i.e., all data constructors of data
--- types other than scalar ones. Also adds a mapping from the original to vectorised worker into
--- the vectorisation map.
+-- all data constructors that may be used in vectorised code — i.e., all data constructors of data
+-- types with 'VECTORISE [SCALAR] type' pragmas with an explicit right-hand side. Also adds a mapping
+-- from the original to vectorised worker into the vectorisation map.
--
-- FIXME: It's not nice that we need create a special worker after the data constructors has
-- already been constructed. Also, I don't think the worker is properly added to the data
diff --git a/compiler/vectorise/Vectorise/Type/Type.hs b/compiler/vectorise/Vectorise/Type/Type.hs
index a7ec86a296..ebb09e663c 100644
--- a/compiler/vectorise/Vectorise/Type/Type.hs
+++ b/compiler/vectorise/Vectorise/Type/Type.hs
@@ -14,21 +14,16 @@ import TcType
import Type
import TypeRep
import TyCon
-import Outputable
import Control.Monad
import Control.Applicative
import Data.Maybe
--- | Vectorise a type constructor.
+
+-- |Vectorise a type constructor. Unless there is a vectorised version (stripped of embedded
+-- parallel arrays), the vectorised version is the same as the original.
--
vectTyCon :: TyCon -> VM TyCon
-vectTyCon tc
- | isFunTyCon tc = builtin closureTyCon
- | isBoxedTupleTyCon tc = return tc
- | isUnLiftedTyCon tc = return tc
- | otherwise
- = maybeCantVectoriseM "Tycon not vectorised: " (ppr tc)
- $ lookupTyCon tc
+vectTyCon tc = maybe tc id <$> lookupTyCon tc
-- |Produce the vectorised and lifted versions of a type.
--
diff --git a/compiler/vectorise/Vectorise/Utils.hs b/compiler/vectorise/Vectorise/Utils.hs
index c5f1cb7cb1..fafce7a67d 100644
--- a/compiler/vectorise/Vectorise/Utils.hs
+++ b/compiler/vectorise/Vectorise/Utils.hs
@@ -17,7 +17,7 @@ module Vectorise.Utils (
combinePD, liftPD,
-- * Scalars
- zipScalars, scalarClosure,
+ isScalar, zipScalars, scalarClosure,
-- * Naming
newLocalVar
@@ -137,20 +137,29 @@ liftPD x
-- Scalars --------------------------------------------------------------------
+isScalar :: Type -> VM Bool
+isScalar ty
+ = do
+ { scalar <- builtin scalarClass
+ ; existsInst scalar [ty]
+ }
+
zipScalars :: [Type] -> Type -> VM CoreExpr
zipScalars arg_tys res_ty
- = do
- scalar <- builtin scalarClass
- (dfuns, _) <- mapAndUnzipM (\ty -> lookupInst scalar [ty]) ty_args
- zipf <- builtin (scalarZip $ length arg_tys)
- return $ Var zipf `mkTyApps` ty_args `mkApps` map Var dfuns
+ = do
+ { scalar <- builtin scalarClass
+ ; (dfuns, _) <- mapAndUnzipM (\ty -> lookupInst scalar [ty]) ty_args
+ ; zipf <- builtin (scalarZip $ length arg_tys)
+ ; return $ Var zipf `mkTyApps` ty_args `mkApps` map Var dfuns
+ }
where
ty_args = arg_tys ++ [res_ty]
scalarClosure :: [Type] -> Type -> CoreExpr -> CoreExpr -> VM CoreExpr
scalarClosure arg_tys res_ty scalar_fun array_fun
= do
- ctr <- builtin (closureCtrFun $ length arg_tys)
- pas <- mapM paDictOfType (init arg_tys)
- return $ Var ctr `mkTyApps` (arg_tys ++ [res_ty])
+ { ctr <- builtin (closureCtrFun $ length arg_tys)
+ ; pas <- mapM paDictOfType (init arg_tys)
+ ; return $ Var ctr `mkTyApps` (arg_tys ++ [res_ty])
`mkApps` (pas ++ [scalar_fun, array_fun])
+ }