diff options
author | Manuel M T Chakravarty <chak@cse.unsw.edu.au> | 2012-12-05 15:28:19 +1100 |
---|---|---|
committer | Manuel M T Chakravarty <chak@cse.unsw.edu.au> | 2012-12-05 15:28:19 +1100 |
commit | b77da25ef0d95e776a43779bbb4843eb01d33552 (patch) | |
tree | 4aeb4d158a5e66d033bca83f2a804b2ce394b5ad | |
parent | 2a7217e3fa39410ac61e17f5c8e2ce3976bec1a9 (diff) | |
download | haskell-b77da25ef0d95e776a43779bbb4843eb01d33552.tar.gz |
Rewrote vectorisation avoidance (based on the HS paper)
* Vectorisation avoidance is now the default
* Types and values from unvectorised modules are permitted in scalar code
* Simplified the VECTORISE pragmas (see http://hackage.haskell.org/trac/ghc/wiki/DataParallel/VectPragma for the spec)
* Vectorisation information is now included in the annotated Core AST
27 files changed, 1080 insertions, 1077 deletions
diff --git a/compiler/coreSyn/CoreFVs.lhs b/compiler/coreSyn/CoreFVs.lhs index d2bb6ed57a..2a11723fa9 100644 --- a/compiler/coreSyn/CoreFVs.lhs +++ b/compiler/coreSyn/CoreFVs.lhs @@ -328,12 +328,11 @@ breaker, which is perfectly inlinable. vectsFreeVars :: [CoreVect] -> VarSet vectsFreeVars = foldr (unionVarSet . vectFreeVars) emptyVarSet where - vectFreeVars (Vect _ Nothing) = noFVs - vectFreeVars (Vect _ (Just rhs)) = expr_fvs rhs isLocalId emptyVarSet - vectFreeVars (NoVect _) = noFVs - vectFreeVars (VectType _ _ _) = noFVs - vectFreeVars (VectClass _) = noFVs - vectFreeVars (VectInst _) = noFVs + vectFreeVars (Vect _ rhs) = expr_fvs rhs isLocalId emptyVarSet + vectFreeVars (NoVect _) = noFVs + vectFreeVars (VectType _ _ _) = noFVs + vectFreeVars (VectClass _) = noFVs + vectFreeVars (VectInst _) = noFVs -- this function is only concerned with values, not types \end{code} diff --git a/compiler/coreSyn/CoreSubst.lhs b/compiler/coreSyn/CoreSubst.lhs index a8de9c2b16..f92b63bfe0 100644 --- a/compiler/coreSyn/CoreSubst.lhs +++ b/compiler/coreSyn/CoreSubst.lhs @@ -749,12 +749,11 @@ substVects subst = map (substVect subst) ------------------ substVect :: Subst -> CoreVect -> CoreVect -substVect _subst (Vect v Nothing) = Vect v Nothing -substVect subst (Vect v (Just rhs)) = Vect v (Just (simpleOptExprWith subst rhs)) -substVect _subst vd@(NoVect _) = vd -substVect _subst vd@(VectType _ _ _) = vd -substVect _subst vd@(VectClass _) = vd -substVect _subst vd@(VectInst _) = vd +substVect subst (Vect v rhs) = Vect v (simpleOptExprWith subst rhs) +substVect _subst vd@(NoVect _) = vd +substVect _subst vd@(VectType _ _ _) = vd +substVect _subst vd@(VectClass _) = vd +substVect _subst vd@(VectInst _) = vd ------------------ substVarSet :: Subst -> VarSet -> VarSet diff --git a/compiler/coreSyn/CoreSyn.lhs b/compiler/coreSyn/CoreSyn.lhs index a84a29a6c0..5216bebd23 100644 --- a/compiler/coreSyn/CoreSyn.lhs +++ b/compiler/coreSyn/CoreSyn.lhs @@ -592,11 +592,11 @@ Representation of desugared vectorisation declarations that are fed to the vecto 'ModGuts'). \begin{code} -data CoreVect = Vect Id (Maybe CoreExpr) +data CoreVect = Vect Id CoreExpr | NoVect Id | VectType Bool TyCon (Maybe TyCon) | VectClass TyCon -- class tycon - | VectInst Id -- instance dfun (always SCALAR) + | VectInst Id -- instance dfun (always SCALAR) !!!FIXME: should be superfluous now \end{code} diff --git a/compiler/coreSyn/PprCore.lhs b/compiler/coreSyn/PprCore.lhs index 3ca8c48855..a8ed4b875a 100644 --- a/compiler/coreSyn/PprCore.lhs +++ b/compiler/coreSyn/PprCore.lhs @@ -494,8 +494,7 @@ instance Outputable id => Outputable (Tickish id) where \begin{code} instance Outputable CoreVect where - ppr (Vect var Nothing) = ptext (sLit "VECTORISE SCALAR") <+> ppr var - ppr (Vect var (Just e)) = hang (ptext (sLit "VECTORISE") <+> ppr var <+> char '=') + ppr (Vect var e) = hang (ptext (sLit "VECTORISE") <+> ppr var <+> char '=') 4 (pprCoreExpr e) ppr (NoVect var) = ptext (sLit "NOVECTORISE") <+> ppr var ppr (VectType False var Nothing) = ptext (sLit "VECTORISE type") <+> ppr var diff --git a/compiler/deSugar/Desugar.lhs b/compiler/deSugar/Desugar.lhs index 28b0582076..78c95ceb88 100644 --- a/compiler/deSugar/Desugar.lhs +++ b/compiler/deSugar/Desugar.lhs @@ -432,7 +432,7 @@ the rule is precisly to optimise them: dsVect :: LVectDecl Id -> DsM CoreVect dsVect (L loc (HsVect (L _ v) rhs)) = putSrcSpanDs loc $ - do { rhs' <- fmapMaybeM dsLExpr rhs + do { rhs' <- dsLExpr rhs ; return $ Vect v rhs' } dsVect (L _loc (HsNoVect (L _ v))) diff --git a/compiler/hsSyn/HsDecls.lhs b/compiler/hsSyn/HsDecls.lhs index bac9ec6348..58a0aced14 100644 --- a/compiler/hsSyn/HsDecls.lhs +++ b/compiler/hsSyn/HsDecls.lhs @@ -1111,7 +1111,7 @@ type LVectDecl name = Located (VectDecl name) data VectDecl name = HsVect (Located name) - (Maybe (LHsExpr name)) -- 'Nothing' => SCALAR declaration + (LHsExpr name) | HsNoVect (Located name) | HsVectTypeIn -- pre type-checking @@ -1126,9 +1126,9 @@ data VectDecl name (Located name) | HsVectClassOut -- post type-checking Class - | HsVectInstIn -- pre type-checking (always SCALAR) + | HsVectInstIn -- pre type-checking (always SCALAR) !!!FIXME: should be superfluous now (LHsType name) - | HsVectInstOut -- post type-checking (always SCALAR) + | HsVectInstOut -- post type-checking (always SCALAR) !!!FIXME: should be superfluous now ClsInst deriving (Data, Typeable) @@ -1148,9 +1148,7 @@ lvectInstDecl (L _ (HsVectInstOut _)) = True lvectInstDecl _ = False instance OutputableBndr name => Outputable (VectDecl name) where - ppr (HsVect v Nothing) - = sep [text "{-# VECTORISE SCALAR" <+> ppr v <+> text "#-}" ] - ppr (HsVect v (Just rhs)) + ppr (HsVect v rhs) = sep [text "{-# VECTORISE" <+> ppr v, nest 4 $ pprExpr (unLoc rhs) <+> text "#-}" ] diff --git a/compiler/iface/LoadIface.lhs b/compiler/iface/LoadIface.lhs index 6c5e7d38d9..f3dadbfc53 100644 --- a/compiler/iface/LoadIface.lhs +++ b/compiler/iface/LoadIface.lhs @@ -750,18 +750,18 @@ pprFixities fixes = ptext (sLit "fixities") <+> pprWithCommas pprFix fixes pprFix (occ,fix) = ppr fix <+> ppr occ pprVectInfo :: IfaceVectInfo -> SDoc -pprVectInfo (IfaceVectInfo { ifaceVectInfoVar = vars - , ifaceVectInfoTyCon = tycons - , ifaceVectInfoTyConReuse = tyconsReuse - , ifaceVectInfoScalarVars = scalarVars - , ifaceVectInfoScalarTyCons = scalarTyCons +pprVectInfo (IfaceVectInfo { ifaceVectInfoVar = vars + , ifaceVectInfoTyCon = tycons + , ifaceVectInfoTyConReuse = tyconsReuse + , ifaceVectInfoParallelVars = parallelVars + , ifaceVectInfoParallelTyCons = parallelTyCons }) = vcat [ ptext (sLit "vectorised variables:") <+> hsep (map ppr vars) , ptext (sLit "vectorised tycons:") <+> hsep (map ppr tycons) , ptext (sLit "vectorised reused tycons:") <+> hsep (map ppr tyconsReuse) - , ptext (sLit "scalar variables:") <+> hsep (map ppr scalarVars) - , ptext (sLit "scalar tycons:") <+> hsep (map ppr scalarTyCons) + , ptext (sLit "parallel variables:") <+> hsep (map ppr parallelVars) + , ptext (sLit "parallel tycons:") <+> hsep (map ppr parallelTyCons) ] pprTrustInfo :: IfaceTrustInfo -> SDoc diff --git a/compiler/iface/MkIface.lhs b/compiler/iface/MkIface.lhs index ce07b375b3..6aed1b2be4 100644 --- a/compiler/iface/MkIface.lhs +++ b/compiler/iface/MkIface.lhs @@ -373,17 +373,17 @@ mkIface_ hsc_env maybe_old_fingerprint ifFamInstTcName = ifFamInstFam - flattenVectInfo (VectInfo { vectInfoVar = vVar - , vectInfoTyCon = vTyCon - , vectInfoScalarVars = vScalarVars - , vectInfoScalarTyCons = vScalarTyCons + flattenVectInfo (VectInfo { vectInfoVar = vVar + , vectInfoTyCon = vTyCon + , vectInfoParallelVars = vParallelVars + , vectInfoParallelTyCons = vParallelTyCons }) = IfaceVectInfo - { ifaceVectInfoVar = [Var.varName v | (v, _ ) <- varEnvElts vVar] - , ifaceVectInfoTyCon = [tyConName t | (t, t_v) <- nameEnvElts vTyCon, t /= t_v] - , ifaceVectInfoTyConReuse = [tyConName t | (t, t_v) <- nameEnvElts vTyCon, t == t_v] - , ifaceVectInfoScalarVars = [Var.varName v | v <- varSetElems vScalarVars] - , ifaceVectInfoScalarTyCons = nameSetToList vScalarTyCons + { ifaceVectInfoVar = [Var.varName v | (v, _ ) <- varEnvElts vVar] + , ifaceVectInfoTyCon = [tyConName t | (t, t_v) <- nameEnvElts vTyCon, t /= t_v] + , ifaceVectInfoTyConReuse = [tyConName t | (t, t_v) <- nameEnvElts vTyCon, t == t_v] + , ifaceVectInfoParallelVars = [Var.varName v | v <- varSetElems vParallelVars] + , ifaceVectInfoParallelTyCons = nameSetToList vParallelTyCons } ----------------------------- diff --git a/compiler/iface/TcIface.lhs b/compiler/iface/TcIface.lhs index 80c2029a70..8fdda45498 100644 --- a/compiler/iface/TcIface.lhs +++ b/compiler/iface/TcIface.lhs @@ -748,25 +748,25 @@ tcIfaceAnnTarget (ModuleTarget mod) = do -- tcIfaceVectInfo :: Module -> TypeEnv -> IfaceVectInfo -> IfL VectInfo tcIfaceVectInfo mod typeEnv (IfaceVectInfo - { ifaceVectInfoVar = vars - , ifaceVectInfoTyCon = tycons - , ifaceVectInfoTyConReuse = tyconsReuse - , ifaceVectInfoScalarVars = scalarVars - , ifaceVectInfoScalarTyCons = scalarTyCons + { ifaceVectInfoVar = vars + , ifaceVectInfoTyCon = tycons + , ifaceVectInfoTyConReuse = tyconsReuse + , ifaceVectInfoParallelVars = parallelVars + , ifaceVectInfoParallelTyCons = parallelTyCons }) - = do { let scalarTyConsSet = mkNameSet scalarTyCons - ; vVars <- mapM vectVarMapping vars + = do { let parallelTyConsSet = mkNameSet parallelTyCons + ; vVars <- mapM vectVarMapping vars ; let varsSet = mkVarSet (map fst vVars) - ; tyConRes1 <- mapM (vectTyConVectMapping varsSet) tycons - ; tyConRes2 <- mapM (vectTyConReuseMapping varsSet) tyconsReuse - ; vScalarVars <- mapM vectVar scalarVars + ; tyConRes1 <- mapM (vectTyConVectMapping varsSet) tycons + ; tyConRes2 <- mapM (vectTyConReuseMapping varsSet) tyconsReuse + ; vParallelVars <- mapM vectVar parallelVars ; let (vTyCons, vDataCons, vScSels) = unzip3 (tyConRes1 ++ tyConRes2) ; return $ VectInfo - { vectInfoVar = mkVarEnv vVars `extendVarEnvList` concat vScSels - , vectInfoTyCon = mkNameEnv vTyCons - , vectInfoDataCon = mkNameEnv (concat vDataCons) - , vectInfoScalarVars = mkVarSet vScalarVars - , vectInfoScalarTyCons = scalarTyConsSet + { vectInfoVar = mkVarEnv vVars `extendVarEnvList` concat vScSels + , vectInfoTyCon = mkNameEnv vTyCons + , vectInfoDataCon = mkNameEnv (concat vDataCons) + , vectInfoParallelVars = mkVarSet vParallelVars + , vectInfoParallelTyCons = parallelTyConsSet } } where diff --git a/compiler/main/HscTypes.lhs b/compiler/main/HscTypes.lhs index 343df00540..f32975fd4c 100644 --- a/compiler/main/HscTypes.lhs +++ b/compiler/main/HscTypes.lhs @@ -1968,11 +1968,11 @@ on just the OccName easily in a Core pass. -- data VectInfo = VectInfo - { vectInfoVar :: VarEnv (Var , Var ) -- ^ @(f, f_v)@ keyed on @f@ - , vectInfoTyCon :: NameEnv (TyCon , TyCon) -- ^ @(T, T_v)@ keyed on @T@ - , vectInfoDataCon :: NameEnv (DataCon, DataCon) -- ^ @(C, C_v)@ keyed on @C@ - , vectInfoScalarVars :: VarSet -- ^ set of purely scalar variables - , vectInfoScalarTyCons :: NameSet -- ^ set of scalar type constructors + { vectInfoVar :: VarEnv (Var , Var ) -- ^ @(f, f_v)@ keyed on @f@ + , vectInfoTyCon :: NameEnv (TyCon , TyCon) -- ^ @(T, T_v)@ keyed on @T@ + , vectInfoDataCon :: NameEnv (DataCon, DataCon) -- ^ @(C, C_v)@ keyed on @C@ + , vectInfoParallelVars :: VarSet -- ^ set of parallel variables + , vectInfoParallelTyCons :: NameSet -- ^ set of parallel type constructors } -- |Vectorisation information for 'ModIface'; i.e, the vectorisation information propagated @@ -1986,18 +1986,18 @@ data VectInfo -- data IfaceVectInfo = IfaceVectInfo - { ifaceVectInfoVar :: [Name] -- ^ All variables in here have a vectorised variant - , ifaceVectInfoTyCon :: [Name] -- ^ All 'TyCon's in here have a vectorised variant; - -- the name of the vectorised variant and those of its - -- data constructors are determined by - -- 'OccName.mkVectTyConOcc' and - -- 'OccName.mkVectDataConOcc'; the names of the - -- isomorphisms are determined by 'OccName.mkVectIsoOcc' - , ifaceVectInfoTyConReuse :: [Name] -- ^ The vectorised form of all the 'TyCon's in here - -- coincides with the unconverted form; the name of the - -- isomorphisms is determined by 'OccName.mkVectIsoOcc' - , ifaceVectInfoScalarVars :: [Name] -- iface version of 'vectInfoScalarVar' - , ifaceVectInfoScalarTyCons :: [Name] -- iface version of 'vectInfoScalarTyCon' + { ifaceVectInfoVar :: [Name] -- ^ All variables in here have a vectorised variant + , ifaceVectInfoTyCon :: [Name] -- ^ All 'TyCon's in here have a vectorised variant; + -- the name of the vectorised variant and those of its + -- data constructors are determined by + -- 'OccName.mkVectTyConOcc' and + -- 'OccName.mkVectDataConOcc'; the names of the + -- isomorphisms are determined by 'OccName.mkVectIsoOcc' + , ifaceVectInfoTyConReuse :: [Name] -- ^ The vectorised form of all the 'TyCon's in here + -- coincides with the unconverted form; the name of the + -- isomorphisms is determined by 'OccName.mkVectIsoOcc' + , ifaceVectInfoParallelVars :: [Name] -- iface version of 'vectInfoParallelVar' + , ifaceVectInfoParallelTyCons :: [Name] -- iface version of 'vectInfoParallelTyCon' } noVectInfo :: VectInfo @@ -2006,11 +2006,11 @@ noVectInfo plusVectInfo :: VectInfo -> VectInfo -> VectInfo plusVectInfo vi1 vi2 = - VectInfo (vectInfoVar vi1 `plusVarEnv` vectInfoVar vi2) - (vectInfoTyCon vi1 `plusNameEnv` vectInfoTyCon vi2) - (vectInfoDataCon vi1 `plusNameEnv` vectInfoDataCon vi2) - (vectInfoScalarVars vi1 `unionVarSet` vectInfoScalarVars vi2) - (vectInfoScalarTyCons vi1 `unionNameSets` vectInfoScalarTyCons vi2) + VectInfo (vectInfoVar vi1 `plusVarEnv` vectInfoVar vi2) + (vectInfoTyCon vi1 `plusNameEnv` vectInfoTyCon vi2) + (vectInfoDataCon vi1 `plusNameEnv` vectInfoDataCon vi2) + (vectInfoParallelVars vi1 `unionVarSet` vectInfoParallelVars vi2) + (vectInfoParallelTyCons vi1 `unionNameSets` vectInfoParallelTyCons vi2) concatVectInfo :: [VectInfo] -> VectInfo concatVectInfo = foldr plusVectInfo noVectInfo @@ -2024,11 +2024,11 @@ isNoIfaceVectInfo (IfaceVectInfo l1 l2 l3 l4 l5) instance Outputable VectInfo where ppr info = vcat - [ ptext (sLit "variables :") <+> ppr (vectInfoVar info) - , ptext (sLit "tycons :") <+> ppr (vectInfoTyCon info) - , ptext (sLit "datacons :") <+> ppr (vectInfoDataCon info) - , ptext (sLit "scalar vars :") <+> ppr (vectInfoScalarVars info) - , ptext (sLit "scalar tycons :") <+> ppr (vectInfoScalarTyCons info) + [ ptext (sLit "variables :") <+> ppr (vectInfoVar info) + , ptext (sLit "tycons :") <+> ppr (vectInfoTyCon info) + , ptext (sLit "datacons :") <+> ppr (vectInfoDataCon info) + , ptext (sLit "parallel vars :") <+> ppr (vectInfoParallelVars info) + , ptext (sLit "parallel tycons :") <+> ppr (vectInfoParallelTyCons info) ] \end{code} diff --git a/compiler/main/TidyPgm.lhs b/compiler/main/TidyPgm.lhs index 85127e63f6..da732c687e 100644 --- a/compiler/main/TidyPgm.lhs +++ b/compiler/main/TidyPgm.lhs @@ -542,10 +542,10 @@ tidyInstances tidy_dfun ispecs \begin{code} tidyVectInfo :: TidyEnv -> VectInfo -> VectInfo tidyVectInfo (_, var_env) info@(VectInfo { vectInfoVar = vars - , vectInfoScalarVars = scalarVars + , vectInfoParallelVars = parallelVars }) = info { vectInfoVar = tidy_vars - , vectInfoScalarVars = tidy_scalarVars + , vectInfoParallelVars = tidy_parallelVars } where -- we only export mappings whose domain and co-domain is exported (otherwise, the iface is @@ -559,9 +559,9 @@ tidyVectInfo (_, var_env) info@(VectInfo { vectInfoVar = vars , isDataConWorkId var || not (isImplicitId var) ] - tidy_scalarVars = mkVarSet [ lookup_var var - | var <- varSetElems scalarVars - , isGlobalId var || isExportedId var] + tidy_parallelVars = mkVarSet [ lookup_var var + | var <- varSetElems parallelVars + , isGlobalId var || isExportedId var] lookup_var var = lookupWithDefaultVarEnv var_env var var \end{code} diff --git a/compiler/parser/Parser.y.pp b/compiler/parser/Parser.y.pp index 6c19812762..690b005d76 100644 --- a/compiler/parser/Parser.y.pp +++ b/compiler/parser/Parser.y.pp @@ -577,8 +577,7 @@ topdecl :: { OrdList (LHsDecl RdrName) } | '{-# DEPRECATED' deprecations '#-}' { $2 } | '{-# WARNING' warnings '#-}' { $2 } | '{-# RULES' rules '#-}' { $2 } - | '{-# VECTORISE_SCALAR' qvar '#-}' { unitOL $ LL $ VectD (HsVect $2 Nothing) } - | '{-# VECTORISE' qvar '=' exp '#-}' { unitOL $ LL $ VectD (HsVect $2 (Just $4)) } + | '{-# VECTORISE' qvar '=' exp '#-}' { unitOL $ LL $ VectD (HsVect $2 $4) } | '{-# NOVECTORISE' qvar '#-}' { unitOL $ LL $ VectD (HsNoVect $2) } | '{-# VECTORISE' 'type' gtycon '#-}' { unitOL $ LL $ @@ -593,8 +592,6 @@ topdecl :: { OrdList (LHsDecl RdrName) } { unitOL $ LL $ VectD (HsVectTypeIn True $3 (Just $5)) } | '{-# VECTORISE' 'class' gtycon '#-}' { unitOL $ LL $ VectD (HsVectClassIn $3) } - | '{-# VECTORISE_SCALAR' 'instance' type '#-}' - { unitOL $ LL $ VectD (HsVectInstIn $3) } | annotation { unitOL $1 } | decl { unLoc $1 } diff --git a/compiler/rename/RnSource.lhs b/compiler/rename/RnSource.lhs index 595f4653d3..0d897f3f0b 100644 --- a/compiler/rename/RnSource.lhs +++ b/compiler/rename/RnSource.lhs @@ -723,18 +723,14 @@ badRuleLhsErr name lhs bad_e \begin{code} rnHsVectDecl :: VectDecl RdrName -> RnM (VectDecl Name, FreeVars) -rnHsVectDecl (HsVect var Nothing) - = do { var' <- lookupLocatedOccRn var - ; return (HsVect var' Nothing, unitFV (unLoc var')) - } -- FIXME: For the moment, the right-hand side is restricted to be a variable as we cannot properly -- typecheck a complex right-hand side without invoking 'vectType' from the vectoriser. -rnHsVectDecl (HsVect var (Just rhs@(L _ (HsVar _)))) +rnHsVectDecl (HsVect var rhs@(L _ (HsVar _))) = do { var' <- lookupLocatedOccRn var ; (rhs', fv_rhs) <- rnLExpr rhs - ; return (HsVect var' (Just rhs'), fv_rhs `addOneFV` unLoc var') + ; return (HsVect var' rhs', fv_rhs `addOneFV` unLoc var') } -rnHsVectDecl (HsVect _var (Just _rhs)) +rnHsVectDecl (HsVect _var _rhs) = failWith $ vcat [ ptext (sLit "IMPLEMENTATION RESTRICTION: right-hand side of a VECTORISE pragma") , ptext (sLit "must be an identifier") diff --git a/compiler/typecheck/TcBinds.lhs b/compiler/typecheck/TcBinds.lhs index 5eb8e150ef..b30aecab5f 100644 --- a/compiler/typecheck/TcBinds.lhs +++ b/compiler/typecheck/TcBinds.lhs @@ -739,17 +739,12 @@ tcVect :: VectDecl Name -> TcM (VectDecl TcId) -- during type checking. Instead, constrain the rhs of a vectorisation declaration to be a single -- identifier (this is checked in 'rnHsVectDecl'). Fix this by enabling the use of 'vectType' -- from the vectoriser here. -tcVect (HsVect name Nothing) - = addErrCtxt (vectCtxt name) $ - do { var <- wrapLocM tcLookupId name - ; return $ HsVect var Nothing - } -tcVect (HsVect name (Just rhs)) +tcVect (HsVect name rhs) = addErrCtxt (vectCtxt name) $ do { var <- wrapLocM tcLookupId name ; let L rhs_loc (HsVar rhs_var_name) = rhs ; rhs_id <- tcLookupId rhs_var_name - ; return $ HsVect var (Just $ L rhs_loc (HsVar rhs_id)) + ; return $ HsVect var (L rhs_loc (HsVar rhs_id)) } {- OLD CODE: diff --git a/compiler/typecheck/TcHsSyn.lhs b/compiler/typecheck/TcHsSyn.lhs index d1a82b225d..9714b4e7fb 100644 --- a/compiler/typecheck/TcHsSyn.lhs +++ b/compiler/typecheck/TcHsSyn.lhs @@ -1081,7 +1081,7 @@ zonkVects env = mappM (wrapLocM (zonkVect env)) zonkVect :: ZonkEnv -> VectDecl TcId -> TcM (VectDecl Id) zonkVect env (HsVect v e) = do { v' <- wrapLocM (zonkIdBndr env) v - ; e' <- fmapMaybeM (zonkLExpr env) e + ; e' <- zonkLExpr env e ; return $ HsVect v' e' } zonkVect env (HsNoVect v) diff --git a/compiler/vectorise/Vectorise.hs b/compiler/vectorise/Vectorise.hs index 8b7e817826..e6c4b1e0cf 100644 --- a/compiler/vectorise/Vectorise.hs +++ b/compiler/vectorise/Vectorise.hs @@ -13,26 +13,22 @@ import Vectorise.Type.Type import Vectorise.Convert import Vectorise.Utils.Hoisting import Vectorise.Exp -import Vectorise.Vect import Vectorise.Env import Vectorise.Monad import HscTypes hiding ( MonadThings(..) ) import CoreUnfold ( mkInlineUnfolding ) -import CoreFVs import PprCore import CoreSyn import CoreMonad ( CoreM, getHscEnv ) import Type import Id import DynFlags -import BasicTypes ( isStrongLoopBreaker ) import Outputable import Util ( zipLazy ) import MonadUtils import Control.Monad -import Data.Maybe -- |Vectorise a single module. @@ -69,7 +65,7 @@ vectModule guts@(ModGuts { mg_tcs = tycons = do { dumpOptVt Opt_D_dump_vt_trace "Before vectorisation" $ pprCoreBindings binds - -- Pick out all 'VECTORISE type' and 'VECTORISE class' pragmas + -- Pick out all 'VECTORISE [SCALAR] type' and 'VECTORISE class' pragmas ; let ty_vect_decls = [vd | vd@(VectType _ _ _) <- vect_decls] cls_vect_decls = [vd | vd@(VectClass _) <- vect_decls] @@ -87,8 +83,7 @@ vectModule guts@(ModGuts { mg_tcs = tycons -- Vectorise all the top level bindings and VECTORISE declarations on imported identifiers -- NB: Need to vectorise the imported bindings first (local bindings may depend on them). - ; let impBinds = [imp_id | Vect imp_id _ <- vect_decls, isGlobalId imp_id] ++ - [imp_id | VectInst imp_id <- vect_decls, isGlobalId imp_id] + ; let impBinds = [(imp_id, expr) | Vect imp_id expr <- vect_decls, isGlobalId imp_id] ; binds_imp <- mapM vectImpBind impBinds ; binds_top <- mapM vectTopBind binds @@ -101,7 +96,8 @@ vectModule guts@(ModGuts { mg_tcs = tycons } } --- Try to vectorise a top-level binding. If it doesn't vectorise then return it unharmed. +-- Try to vectorise a top-level binding. If it doesn't vectorise, or if it is entirely scalar, then +-- omit vectorisation of that binding. -- -- For example, for the binding -- @@ -125,129 +121,173 @@ vectModule guts@(ModGuts { mg_tcs = tycons -- lfoo = ... -- @ -- --- @vfoo@ is the "vectorised", or scalar, version that does the same as the original --- function foo, but takes an explicit environment. +-- @vfoo@ is the "vectorised", or scalar, version that does the same as the original function foo, +-- but takes an explicit environment. -- -- @lfoo@ is the "lifted" version that works on arrays. -- --- @v_foo@ combines both of these into a `Closure` that also contains the --- environment. +-- @v_foo@ combines both of these into a `Closure` that also contains the environment. -- --- The original binding @foo@ is rewritten to call the vectorised version --- present in the closure. +-- The original binding @foo@ is rewritten to call the vectorised version present in the closure. -- -- Vectorisation may be surpressed by annotating a binding with a 'NOVECTORISE' pragma. If this -- pragma is used in a group of mutually recursive bindings, either all or no binding must have --- the pragma. If only some bindings are annotated, a fatal error is being raised. +-- the pragma. If only some bindings are annotated, a fatal error is being raised. (In the case of +-- scalar bindings, we only omit vectorisation if all bindings in a group are scalar.) +-- -- FIXME: Once we support partial vectorisation, we may be able to vectorise parts of a group, or -- we may emit a warning and refrain from vectorising the entire group. -- vectTopBind :: CoreBind -> VM CoreBind vectTopBind b@(NonRec var expr) - = unlessNoVectDecl $ - do { -- Vectorise the right-hand side, create an appropriate top-level binding and add it - -- to the vectorisation map. - ; (inline, isScalar, expr') <- vectTopRhs [] var expr - ; var' <- vectTopBinder var inline expr' - ; when isScalar $ - addGlobalScalarVar var - - -- We replace the original top-level binding by a value projected from the vectorised - -- closure and add any newly created hoisted top-level bindings. - ; cexpr <- tryConvert var var' expr - ; hs <- takeHoisted - ; return . Rec $ (var, cexpr) : (var', expr') : hs - } - `orElseErrV` - do { emitVt " Could NOT vectorise top-level binding" $ ppr var - ; return b + = do + { traceVt "= Vectorise non-recursive top-level variable" (ppr var) + + ; (hasNoVect, vectDecl) <- lookupVectDecl var + ; if hasNoVect + then do + { -- 'NOVECTORISE' pragma => leave this binding as it is + ; traceVt "NOVECTORISE" $ ppr var + ; return b + } + else do + { vectRhs <- case vectDecl of + Just (_, expr') -> + -- 'VECTORISE' pragma => just use the provided vectorised rhs + do + { traceVt "VECTORISE" $ ppr var + ; return $ Just (False, inlineMe, expr') + } + Nothing -> + -- no pragma => standard vectorisation of rhs + do + { traceVt "[Vanilla]" $ ppr var <+> char '=' <+> ppr expr + ; vectTopExpr var expr + } + ; hs <- takeHoisted -- make sure we clean those out (even if we skip) + ; case vectRhs of + { Nothing -> + -- scalar binding => leave this binding as it is + do + { traceVt "scalar binding [skip]" $ ppr var + ; return b + } + ; Just (parBind, inline, expr') -> do + { + -- vanilla case => create an appropriate top-level binding & add it to the vectorisation map + ; when parBind $ + addGlobalParallelVar var + ; var' <- vectTopBinder var inline expr' + + -- We replace the original top-level binding by a value projected from the vectorised + -- closure and add any newly created hoisted top-level bindings. + ; cexpr <- tryConvert var var' expr + ; return . Rec $ (var, cexpr) : (var', expr') : hs + } } } } + `orElseErrV` + do + { emitVt " Could NOT vectorise top-level binding" $ ppr var + ; return b + } +vectTopBind b@(Rec binds) + = do + { traceVt "= Vectorise recursive top-level variables" $ ppr vars + + ; vectDecls <- mapM lookupVectDecl vars + ; let hasNoVects = map fst vectDecls + ; if and hasNoVects + then do + { -- 'NOVECTORISE' pragmas => leave this entire binding group as it is + ; traceVt "NOVECTORISE" $ ppr vars + ; return b + } + else do + { if or hasNoVects + then do + { -- Inconsistent 'NOVECTORISE' pragmas => bail out + ; dflags <- getDynFlags + ; cantVectorise dflags noVectoriseErr (ppr b) } - where - unlessNoVectDecl vectorise - = do { hasNoVectDecl <- noVectDecl var - ; when hasNoVectDecl $ - traceVt "NOVECTORISE" $ ppr var - ; if hasNoVectDecl then return b else vectorise - } -vectTopBind b@(Rec bs) - = unlessSomeNoVectDecl $ - do { (vars', _, exprs', hs) <- fixV $ - \ ~(_, inlines, rhss, _) -> - do { -- Vectorise the right-hand sides, create an appropriate top-level bindings - -- and add them to the vectorisation map. - ; vars' <- sequence [vectTopBinder var inline rhs - | (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)] - ; (inlines, areScalars, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs - ; hs <- takeHoisted - ; if and areScalars - then -- (1) Entire recursive group is scalar - -- => add all variables to the global set of scalars - do { mapM_ addGlobalScalarVar vars - ; return (vars', inlines, exprs', hs) - } - else -- (2) At least one binding is not scalar - -- => vectorise again with empty set of local scalars - do { (inlines, _, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs - ; hs <- takeHoisted - ; return (vars', inlines, exprs', hs) - } - } - - -- Replace the original top-level bindings by a values projected from the vectorised - -- closures and add any newly created hoisted top-level bindings to the group. - ; cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs - ; return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs - } - `orElseErrV` - return b - where - (vars, exprs) = unzip bs + else do + { -- For all bindings *with* a pragma, just use the pragma-supplied vectorised expression + ; newBindsWPragma <- concat <$> + sequence [ vectTopBindAndConvert bind inlineMe expr' + | (bind, (_, Just (_, expr'))) <- zip binds vectDecls] + + -- Standard vectorisation of all rhses that are *without* a pragma. + -- NB: The reason for 'fixV' is rather subtle: 'vectTopBindAndConvert' adds entries for + -- the bound variables in the recursive group to the vectorisation map, which in turn + -- are needed by 'vectPolyExprs' (unless it returns 'Nothing'). + ; let bindsWOPragma = [bind | (bind, (_, Nothing)) <- zip binds vectDecls] + ; (newBinds, _) <- fixV $ + \ ~(_, exprs') -> + do + { -- Create appropriate top-level bindings, enter them into the vectorisation map, and + -- vectorise the right-hand sides + ; newBindsWOPragma <- concat <$> + sequence [vectTopBindAndConvert bind inline expr + | (bind, ~(inline, expr)) <- zipLazy bindsWOPragma exprs'] + -- irrefutable pattern and 'zipLazy' to tie the knot; + -- hence, can't use 'zipWithM' + ; vectRhses <- vectTopExprs bindsWOPragma + ; hs <- takeHoisted -- make sure we clean those out (even if we skip) - unlessSomeNoVectDecl vectorise - = do { hasNoVectDecls <- mapM noVectDecl vars - ; when (and hasNoVectDecls) $ - traceVt "NOVECTORISE" $ ppr vars - ; if and hasNoVectDecls - then return b -- all bindings have 'NOVECTORISE' - else if or hasNoVectDecls - then do dflags <- getDynFlags - cantVectorise dflags noVectoriseErr (ppr b) -- some (but not all) have 'NOVECTORISE' - else vectorise -- no binding has a 'NOVECTORISE' decl - } + ; case vectRhses of + Nothing -> + -- scalar bindings => skip all bindings except those with pragmas and retract the + -- entries into the vectorisation map for the scalar bindings + do + { traceVt "scalar bindings [skip]" $ ppr vars + ; mapM_ (undefGlobalVar . fst) bindsWOPragma + ; return (bindsWOPragma ++ newBindsWPragma, exprs') + } + Just (parBind, exprs') -> + -- vanilla case => record parallel variables and return the final bindings + do + { when parBind $ + mapM_ addGlobalParallelVar vars + ; return (newBindsWOPragma ++ newBindsWPragma ++ hs, exprs') + } + } + ; return $ Rec newBinds + } } } + `orElseErrV` + do + { emitVt " Could NOT vectorise top-level bindings" $ ppr vars + ; return b + } + where + vars = map fst binds noVectoriseErr = "NOVECTORISE must be used on all or no bindings of a recursive group" + + -- Replace the original top-level bindings by a values projected from the vectorised + -- closures and add any newly created hoisted top-level bindings to the group. + vectTopBindAndConvert (var, expr) inline expr' + = do + { var' <- vectTopBinder var inline expr' + ; cexpr <- tryConvert var var' expr + ; return [(var, cexpr), (var', expr')] + } --- Add a vectorised binding to an imported top-level variable that has a VECTORISE [SCALAR] pragma +-- Add a vectorised binding to an imported top-level variable that has a VECTORISE pragma -- in this module. -- --- RESTIRCTION: Currently, we cannot use the pragma vor mutually recursive definitions. +-- RESTIRCTION: Currently, we cannot use the pragma for mutually recursive definitions. -- -vectImpBind :: Id -> VM CoreBind -vectImpBind var - = do { -- Vectorise the right-hand side, create an appropriate top-level binding and add it - -- to the vectorisation map. For the non-lifted version, we refer to the original - -- definition — i.e., 'Var var'. - -- NB: To support recursive definitions, we tie a lazy knot. - ; (var', _, expr') <- fixV $ - \ ~(_, inline, rhs) -> - do { var' <- vectTopBinder var inline rhs - ; (inline, isScalar, expr') <- vectTopRhs [] var (Var var) - - ; when isScalar $ - addGlobalScalarVar var - ; return (var', inline, expr') - } +vectImpBind :: (Id, CoreExpr) -> VM CoreBind +vectImpBind (var, expr) + = do + { traceVt "= Add vectorised binding to imported variable" (ppr var) - -- We add any newly created hoisted top-level bindings. - ; hs <- takeHoisted - ; return . Rec $ (var', expr') : hs - } - --- | Make the vectorised version of this top level binder, and add the mapping --- between it and the original to the state. For some binder @foo@ the vectorised --- version is @$v_foo@ + ; var' <- vectTopBinder var inlineMe expr + ; return $ NonRec var' expr + } + +-- |Make the vectorised version of this top level binder, and add the mapping between it and the +-- original to the state. For some binder @foo@ the vectorised version is @$v_foo@ -- --- NOTE: 'vectTopBinder' *MUST* be lazy in inline and expr because of how it is --- used inside of 'fixV' in 'vectTopBind'. +-- NOTE: 'vectTopBinder' *MUST* be lazy in inline and expr because of how it is used inside of +-- 'fixV' in 'vectTopBind'. -- vectTopBinder :: Var -- ^ Name of the binding. -> Inline -- ^ Whether it should be inlined, used to annotate it. @@ -257,20 +297,20 @@ vectTopBinder var inline expr = do { -- Vectorise the type attached to the var. ; vty <- vectType (idType var) - -- If there is a vectorisation declartion for this binding, make sure that its type - -- matches - ; vectDecl <- lookupVectDecl var + -- If there is a vectorisation declartion for this binding, make sure its type matches + ; (_, vectDecl) <- lookupVectDecl var ; case vectDecl of Nothing -> return () Just (vdty, _) | eqType vty vdty -> return () | otherwise -> - do dflags <- getDynFlags - cantVectorise dflags ("Type mismatch in vectorisation pragma for " ++ showPpr dflags var) $ - (text "Expected type" <+> ppr vty) - $$ - (text "Inferred type" <+> ppr vdty) - + do + { dflags <- getDynFlags + ; cantVectorise dflags ("Type mismatch in vectorisation pragma for " ++ showPpr dflags var) $ + (text "Expected type" <+> ppr vty) + $$ + (text "Inferred type" <+> ppr vdty) + } -- Make the vectorised version of binding's name, and set the unfolding used for inlining ; var' <- liftM (`setIdUnfoldingLazily` unfolding) $ mkVectId var vty @@ -297,113 +337,17 @@ vectTopBinder var inline expr `setInlinePragma` dfunInlinePragma -} --- | Vectorise the RHS of a top-level binding, in an empty local environment. +-- |Project out the vectorised version of a binding from some closure, or return the original body +-- if that doesn't work. -- --- We need to distinguish four cases: --- --- (1) We have a (non-scalar) vectorisation declaration for the variable (which explicitly provides --- vectorised code implemented by the user) --- => no automatic vectorisation & instead use the user-supplied code --- --- (2) We have a scalar vectorisation declaration for a variable that is no dfun --- => generate vectorised code that uses a scalar 'map'/'zipWith' to lift the computation --- --- (3) We have a scalar vectorisation declaration for a variable that *is* a dfun --- => generate vectorised code according to the the "Note [Scalar dfuns]" below --- --- (4) There is no vectorisation declaration for the variable --- => perform automatic vectorisation of the RHS (the definition may or may not be a dfun; --- vectorisation proceeds differently depending on which it is) --- --- Note [Scalar dfuns] --- ~~~~~~~~~~~~~~~~~~~ --- --- Here is the translation scheme for scalar dfuns — assume the instance declaration: --- --- instance Num Int where --- (+) = primAdd --- {-# VECTORISE SCALAR instance Num Int #-} --- --- It desugars to --- --- $dNumInt :: Num Int --- $dNumInt = D:Num primAdd --- --- We vectorise it to --- --- $v$dNumInt :: V:Num Int --- $v$dNumInt = D:V:Num (closure2 ((+) $dNumInt) (scalar_zipWith ((+) $dNumInt)))) --- --- while adding the following entry to the vectorisation map: '$dNumInt' --> '$v$dNumInt'. --- --- See "Note [Vectorising classes]" in 'Vectorise.Type.Env' for the definition of 'V:Num'. --- --- NB: The outlined vectorisation scheme does not require the right-hand side of the original dfun. --- In fact, we definitely want to refer to the dfn variable instead of the right-hand side to --- ensure that the dictionary selection rules fire. --- -vectTopRhs :: [Var] -- ^ Names of all functions in the rec block - -> Var -- ^ Name of the binding. - -> CoreExpr -- ^ Body of the binding. - -> VM ( Inline -- (1) inline specification for the binding - , Bool -- (2) whether the right-hand side is a scalar computation - , CoreExpr) -- (3) the vectorised right-hand side -vectTopRhs recFs var expr - = closedV - $ do { globalScalar <- isGlobalScalarVar var - ; vectDecl <- lookupVectDecl var - ; dflags <- getDynFlags - ; let isDFun = isDFunId var - - ; traceVt ("vectTopRhs of " ++ showPpr dflags var ++ info globalScalar isDFun vectDecl ++ ":") $ - ppr expr - - ; rhs globalScalar isDFun vectDecl - } - where - rhs _globalScalar _isDFun (Just (_, expr')) -- Case (1) - = return (inlineMe, False, expr') - rhs True False Nothing -- Case (2) - = do { expr' <- vectScalarFun expr - ; return (inlineMe, True, vectorised expr') - } - rhs True True Nothing -- Case (3) - = do { expr' <- vectScalarDFun var - ; return (DontInline, True, expr') - } - rhs False False Nothing -- Case (4) — not a dfun - = do { let exprFvs = freeVars expr - ; (inline, isScalar, vexpr) - <- inBind var $ - vectPolyExpr (isStrongLoopBreaker $ idOccInfo var) recFs exprFvs Nothing - ; return (inline, isScalar, vectorised vexpr) - } - rhs False True Nothing -- Case (4) — is a dfun - = do { expr' <- vectDictExpr expr - ; return (DontInline, True, expr') - } - - info True False _ = " [VECTORISE SCALAR]" - info True True _ = " [VECTORISE SCALAR instance]" - info False _ vectDecl | isJust vectDecl = " [VECTORISE]" - | otherwise = " (no pragma)" - --- |Project out the vectorised version of a binding from some closure, --- or return the original body if that doesn't work or the binding is scalar. --- -tryConvert :: Var -- ^ Name of the original binding (eg @foo@) - -> Var -- ^ Name of vectorised version of binding (eg @$vfoo@) - -> CoreExpr -- ^ The original body of the binding. +tryConvert :: Var -- ^Name of the original binding (eg @foo@) + -> Var -- ^Name of vectorised version of binding (eg @$vfoo@) + -> CoreExpr -- ^The original body of the binding. -> VM CoreExpr tryConvert var vect_var rhs - = do { globalScalar <- isGlobalScalarVar var - ; if globalScalar - then - return rhs - else - fromVect (idType var) (Var vect_var) - `orElseErrV` - do { emitVt " Could NOT call vectorised from original version" $ ppr var - ; return rhs - } - } + = fromVect (idType var) (Var vect_var) + `orElseErrV` + do + { emitVt " Could NOT call vectorised from original version" $ ppr var + ; return rhs + } diff --git a/compiler/vectorise/Vectorise/Convert.hs b/compiler/vectorise/Vectorise/Convert.hs index 048362d59c..f21f5cac86 100644 --- a/compiler/vectorise/Vectorise/Convert.hs +++ b/compiler/vectorise/Vectorise/Convert.hs @@ -84,16 +84,16 @@ identityConv (AppTy {}) = noV $ text "identityConv: type appl. changes under identityConv (FunTy {}) = noV $ text "identityConv: function type changes under vectorisation" identityConv (ForAllTy {}) = noV $ text "identityConv: quantified type changes under vectorisation" --- |Check that this type constructor is neutral under type vectorisation — i.e., it is not altered --- by vectorisation as they contain no parallel arrays. +-- |Check that this type constructor is not changed by vectorisation — i.e., it does not embed any +-- parallel arrays. -- identityConvTyCon :: TyCon -> VM () identityConvTyCon tc - | isBoxedTupleTyCon tc = return () - | isUnLiftedTyCon tc = return () - | otherwise - = do tc' <- maybeV notVectErr (lookupTyCon tc) - if tc == tc' then return () else noV idErr + = do + { tc' <- lookupTyCon tc + ; case tc' of + Nothing -> return () + Just _ -> noV idErr + } where - notVectErr = text "identityConvTyCon: no vectorised version for type constructor" <+> ppr tc - idErr = text "identityConvTyCon: type constructor contains parallel arrays" <+> ppr tc + idErr = text "identityConvTyCon: type constructor contains parallel arrays" <+> ppr tc diff --git a/compiler/vectorise/Vectorise/Env.hs b/compiler/vectorise/Vectorise/Env.hs index d58ec8f800..345b4ba1c3 100644 --- a/compiler/vectorise/Vectorise/Env.hs +++ b/compiler/vectorise/Vectorise/Env.hs @@ -31,7 +31,7 @@ import Name import NameEnv import FastString import TysPrim -import TysWiredIn +--import TysWiredIn import Data.Maybe @@ -60,7 +60,8 @@ data LocalEnv -- ^Mapping from tyvars to their PA dictionaries. , local_bind_name :: FastString - -- ^Local binding name. + -- ^Local binding name. This is only used to generate better names for hoisted + -- expressions. } -- |Create an empty local environment. @@ -84,35 +85,34 @@ data GlobalEnv -- ^Mapping from global variables to their vectorised versions — aka the /vectorisation -- map/. - , global_vect_decls :: VarEnv (Type, CoreExpr) - -- ^Mapping from global variables that have a vectorisation declaration to the right-hand - -- side of that declaration and its type. This mapping only applies to non-scalar - -- vectorisation declarations. All variables with a scalar vectorisation declaration are - -- mentioned in 'global_scalars_vars'. - - , global_scalar_vars :: VarSet - -- ^Purely scalar variables. Code which mentions only these variables doesn't have to be - -- lifted. This includes variables from the current module that have a scalar - -- vectorisation declaration and those that the vectoriser determines to be scalar. - - , global_scalar_tycons :: NameSet - -- ^Type constructors whose values can only contain scalar data. This includes type - -- constructors that appear in a 'VECTORISE SCALAR type' pragma or 'VECTORISE type' pragma - -- *without* a right-hand side in the current or an imported module as well as type - -- constructors that are automatically identified as scalar by the vectoriser (in - -- 'Vectorise.Type.Env'). Scalar code may only operate on such data. + , global_parallel_vars :: VarSet + -- ^The domain of 'global_vars'. -- - -- NB: Not all type constructors in that set are members of the 'Scalar' type class - -- (which can be trivially marshalled across scalar code boundaries). - - , global_novect_vars :: VarSet - -- ^Variables that are not vectorised. (They may be referenced in the right-hand sides - -- of vectorisation declarations, though.) + -- This information is not redundant as it is impossible to extract the domain from a + -- 'VarEnv' (which is keyed on uniques alone). Moreover, we have mapped variables that + -- do not involve parallelism — e.g., the workers of vectorised, but scalar data types. + -- In addition, workers of parallel data types that we could not vectorise also need to + -- be tracked. + + , global_vect_decls :: VarEnv (Maybe (Type, CoreExpr)) + -- ^Mapping from global variables that have a vectorisation declaration to the right-hand + -- side of that declaration and its type and mapping variables that have NOVECTORISE + -- declarations to 'Nothing'. , global_tycons :: NameEnv TyCon - -- ^Mapping from TyCons to their vectorised versions. - -- TyCons which do not have to be vectorised are mapped to themselves. + -- ^Mapping from TyCons to their vectorised versions. The vectorised version will be + -- identical to the original version if it is not changed by vectorisation. In any case, + -- if a tycon appears in the domain of this mapping, it was successfully vectorised. + , global_parallel_tycons :: NameSet + -- ^Type constructors whose definition directly or indirectly includes a parallel type, + -- such as '[::]'. + -- + -- NB: This information is not redundant as some types have got a mapping in + -- 'global_tycons' (to a type other than themselves) and are still not parallel. An + -- example is '(->)'. Moreover, some types have *not* got a mapping in 'global_tycons' + -- (because they couldn't be vectorised), but still contain parallel types. + , global_datacons :: NameEnv DataCon -- ^Mapping from DataCons to their vectorised versions. @@ -129,7 +129,7 @@ data GlobalEnv -- ^External package inst-env & home-package inst-env for family instances. , global_bindings :: [(Var, CoreExpr)] - -- ^Hoisted bindings. + -- ^Hoisted bindings — temporary storage for toplevel bindings during code gen. } -- |Create an initial global environment. @@ -143,9 +143,8 @@ initGlobalEnv info vectDecls instEnvs famInstEnvs = GlobalEnv { global_vars = mapVarEnv snd $ vectInfoVar info , global_vect_decls = mkVarEnv vects - , global_scalar_vars = vectInfoScalarVars info `extendVarSetList` scalar_vars - , global_scalar_tycons = vectInfoScalarTyCons info `addListToNameSet` scalar_tycons - , global_novect_vars = mkVarSet novects + , global_parallel_vars = vectInfoParallelVars info + , global_parallel_tycons = vectInfoParallelTyCons info , global_tycons = mapNameEnv snd $ vectInfoTyCon info , global_datacons = mapNameEnv snd $ vectInfoDataCon info , global_pa_funs = emptyNameEnv @@ -155,23 +154,12 @@ initGlobalEnv info vectDecls instEnvs famInstEnvs , global_bindings = [] } where - vects = [(var, (ty, exp)) | Vect var (Just exp@(Var rhs_var)) <- vectDecls - , let ty = varType rhs_var] + vects = [(var, Just (ty, exp)) | Vect var exp@(Var rhs_var) <- vectDecls + , let ty = varType rhs_var] ++ -- FIXME: we currently only allow RHSes consisting of a -- single variable to be able to obtain the type without -- inference — see also 'TcBinds.tcVect' - scalar_vars = [var | Vect var Nothing <- vectDecls] ++ - [var | VectInst var <- vectDecls] ++ - [dataConWrapId doubleDataCon, dataConWrapId floatDataCon, dataConWrapId intDataCon] -- TODO: fix this hack - novects = [var | NoVect var <- vectDecls] - scalar_tycons = [tyConName tycon | VectType True tycon Nothing <- vectDecls] ++ - [tyConName tycon | VectType _ tycon (Just tycon') <- vectDecls - , tycon == tycon'] ++ - map tyConName [doublePrimTyCon, intPrimTyCon, floatPrimTyCon] -- TODO: fix this hack - -- - for 'VectType True tycon Nothing', we checked that the type does not - -- contain arrays (or type variables that could be instatiated to arrays) - -- - for 'VectType _ tycon (Just tycon')', where the two tycons are the same, - -- we also know that there can be no embedded arrays + [(var, Nothing) | NoVect var <- vectDecls] -- Operators on Global Environments ------------------------------------------- @@ -210,11 +198,11 @@ setPRFunsEnv ps genv = genv { global_pr_funs = mkNameEnv ps } modVectInfo :: GlobalEnv -> [Id] -> [TyCon] -> [CoreVect]-> VectInfo -> VectInfo modVectInfo env mg_ids mg_tyCons vectDecls info = info - { vectInfoVar = mk_env ids (global_vars env) - , vectInfoTyCon = mk_env tyCons (global_tycons env) - , vectInfoDataCon = mk_env dataCons (global_datacons env) - , vectInfoScalarVars = global_scalar_vars env `minusVarSet` vectInfoScalarVars info - , vectInfoScalarTyCons = global_scalar_tycons env `minusNameSet` vectInfoScalarTyCons info + { vectInfoVar = mk_env ids (global_vars env) + , vectInfoTyCon = mk_env tyCons (global_tycons env) + , vectInfoDataCon = mk_env dataCons (global_datacons env) + , vectInfoParallelVars = global_parallel_vars env `minusVarSet` vectInfoParallelVars info + , vectInfoParallelTyCons = global_parallel_tycons env `minusNameSet` vectInfoParallelTyCons info } where vectIds = [id | Vect id _ <- vectDecls] ++ diff --git a/compiler/vectorise/Vectorise/Exp.hs b/compiler/vectorise/Vectorise/Exp.hs index 8c5ef0045d..88f123210b 100644 --- a/compiler/vectorise/Vectorise/Exp.hs +++ b/compiler/vectorise/Vectorise/Exp.hs @@ -3,10 +3,9 @@ -- |Vectorisation of expressions. module Vectorise.Exp - ( -- * Vectorise polymorphic expressions with special cases for right-hand sides of particular - -- variable bindings - vectPolyExpr - , vectDictExpr + ( -- * Vectorise right-hand sides of toplevel bindings + vectTopExpr + , vectTopExprs , vectScalarFun , vectScalarDFun ) @@ -32,393 +31,404 @@ import DataCon import TyCon import TcType import Type -import PrelNames +import TypeRep import Var import VarEnv import VarSet +import NameSet import Id import BasicTypes( isStrongLoopBreaker ) import Literal -import TysWiredIn import TysPrim import Outputable import FastString +import DynFlags +import Util +import MonadUtils + import Control.Monad -import Control.Applicative import Data.Maybe import Data.List -import TcRnMonad (doptM) -import DynFlags -import Util -- Main entry point to vectorise expressions ----------------------------------- --- |Vectorise a polymorphic expression. +-- |Vectorise a polymorphic expression that forms a *non-recursive* binding. +-- +-- Return 'Nothing' if the expression is scalar; otherwise, the first component of the result +-- (which is of type 'Bool') indicates whether the expression is parallel (i.e., whether it is +-- tagged as 'VIParr'). -- --- If not yet available, precompute vectorisation avoidance information before vectorising. If --- the vectorisation avoidance optimisation is enabled, also use the vectorisation avoidance --- information to encapsulated subexpression that do not need to be vectorised. +-- We have got the non-recursive case as a special case as it doesn't require to compute +-- vectorisation information twice. -- -vectPolyExpr :: Bool -> [Var] -> CoreExprWithFVs -> Maybe VITree - -> VM (Inline, Bool, VExpr) - -- precompute vectorisation avoidance information (and possibly encapsulated subexpressions) -vectPolyExpr loop_breaker recFns expr Nothing +vectTopExpr :: Var -> CoreExpr -> VM (Maybe (Bool, Inline, CoreExpr)) +vectTopExpr var expr = do - { vectAvoidance <- liftDs $ doptM Opt_AvoidVect - ; vi <- vectAvoidInfo expr - ; (expr', vi') <- - if vectAvoidance - then do - { (expr', vi') <- encapsulateScalars vi expr - ; traceVt "vectPolyExpr encapsulated:" (ppr $ deAnnotate expr') - ; return (expr', vi') - } - else return (expr, vi) - ; vectPolyExpr loop_breaker recFns expr' (Just vi') + { exprVI <- encapsulateScalars <=< vectAvoidInfo emptyVarSet . freeVars $ expr + ; if isVIEncaps exprVI + then + return Nothing + else do + { vExpr <- closedV $ + inBind var $ + vectAnnPolyExpr False exprVI + ; inline <- computeInline exprVI + ; return $ Just (isVIParr exprVI, inline, vectorised vExpr) + } } - -- traverse through ticks -vectPolyExpr loop_breaker recFns (_, AnnTick tickish expr) (Just (VITNode _ [vit])) - = do - { (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr (Just vit) - ; return (inline, isScalarFn, vTick tickish expr') - } +-- Compute the inlining hint for the right-hand side of a top-level binding. +-- +computeInline :: CoreExprWithVectInfo -> VM Inline +computeInline ((_, VIDict), _) = return $ DontInline +computeInline (_, AnnTick _ expr) = computeInline expr +computeInline expr@(_, AnnLam _ _) = Inline <$> polyArity tvs + where + (tvs, _) = collectAnnTypeBinders expr +computeInline _expr = return $ DontInline - -- collect and vectorise type abstractions; then, descent into the body -vectPolyExpr loop_breaker recFns expr (Just vit) - = do - { let (tvs, mono) = collectAnnTypeBinders expr - vit' = stripLevels (length tvs) vit - ; arity <- polyArity tvs - ; polyAbstract tvs $ \args -> - do - { (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono vit' - ; return (addInlineArity inline arity, isScalarFn, mapVect (mkLams $ tvs ++ args) mono') - } +-- |Vectorise a recursive group of top-level polymorphic expressions. +-- +-- Return 'Nothing' if the expression group is scalar; otherwise, the first component of the result +-- (which is of type 'Bool') indicates whether the expressions are parallel (i.e., whether they are +-- tagged as 'VIParr'). +-- +vectTopExprs :: [(Var, CoreExpr)] -> VM (Maybe (Bool, [(Inline, CoreExpr)])) +vectTopExprs binds + = do + { exprVIs <- mapM (vectAvoidAndEncapsulate emptyVarSet) exprs + ; if all isVIEncaps exprVIs + then + return Nothing + else do + { (areVIParr, vExprs) <- unzip <$> mapM encapsulateAndVect binds + ; return $ Just (or areVIParr, vExprs) + } } where - stripLevels 0 vit = vit - stripLevels n (VITNode _ [vit]) = stripLevels (n - 1) vit - stripLevels _ vit = pprPanic "vectPolyExpr: stripLevels:" (text (show vit)) + (vars, exprs) = unzip binds + + vectAvoidAndEncapsulate pvs = encapsulateScalars <=< vectAvoidInfo pvs . freeVars + + encapsulateAndVect (var, expr) + = do + { exprVI <- vectAvoidAndEncapsulate (mkVarSet vars) expr + ; vExpr <- closedV $ + inBind var $ + vectAnnPolyExpr (isStrongLoopBreaker $ idOccInfo var) exprVI + ; inline <- computeInline exprVI + ; return (isVIParr exprVI, (inline, vectorised vExpr)) + } + +-- |Vectorise a polymorphic expression annotated with vectorisation information. +-- +-- The special case of dictionary functions is currently handled separately. (Would be neater to +-- integrate them, though!) +-- +vectAnnPolyExpr :: Bool -> CoreExprWithVectInfo -> VM VExpr +vectAnnPolyExpr loop_breaker (_, AnnTick tickish expr) + -- traverse through ticks + = vTick tickish <$> vectAnnPolyExpr loop_breaker expr +vectAnnPolyExpr loop_breaker expr + | isVIDict expr + -- special case the right-hand side of dictionary functions + = (, undefined) <$> vectDictExpr (deAnnotate expr) + | otherwise + -- collect and vectorise type abstractions; then, descent into the body + = polyAbstract tvs $ \args -> + mapVect (mkLams $ tvs ++ args) <$> vectFnExpr False loop_breaker mono + where + (tvs, mono) = collectAnnTypeBinders expr -- Encapsulate every purely sequential subexpression of a (potentially) parallel expression into a --- into a lambda abstraction over all its free variables followed by the corresponding application --- to those variables. We can, then, avoid the vectorisation of the ensapsulated subexpressions. +-- lambda abstraction over all its free variables followed by the corresponding application to those +-- variables. We can, then, avoid the vectorisation of the ensapsulated subexpressions. -- -- Preconditions: -- -- * All free variables and the result type must be /simple/ types. --- * The expression is sufficientlt complex (top warrant special treatment). For now, that is +-- * The expression is sufficiently complex (to warrant special treatment). For now, that is -- every expression that is not constant and contains at least one operation. -- -encapsulateScalars :: VITree -> CoreExprWithFVs -> VM (CoreExprWithFVs, VITree) -encapsulateScalars vit ce@(_, AnnType _ty) - = return (ce, vit) - -encapsulateScalars vit ce@(_, AnnVar _v) - = return (ce, vit) - -encapsulateScalars vit ce@(_, AnnLit _) - = return (ce, vit) - -encapsulateScalars (VITNode vi [vit]) (fvs, AnnTick tck expr) - = do { (extExpr, vit') <- encapsulateScalars vit expr - ; return ((fvs, AnnTick tck extExpr), VITNode vi [vit']) - } - -encapsulateScalars _ (_fvs, AnnTick _tck _expr) - = panic "encapsulateScalar AnnTick doesn't match up" - -encapsulateScalars (VITNode vi [vit]) ce@(fvs, AnnLam bndr expr) - = do { varsS <- varsSimple fvs - ; case (vi, varsS) of - (VISimple, True) -> do { let (e', vit') = liftSimple vit ce - ; return (e', vit') - } - _ -> do { (extExpr, vit') <- encapsulateScalars vit expr - ; return ((fvs, AnnLam bndr extExpr), VITNode vi [vit']) - } - } - -encapsulateScalars _ (_fvs, AnnLam _bndr _expr) - = panic "encapsulateScalars AnnLam doesn't match up" - -encapsulateScalars vt@(VITNode vi [vit1, vit2]) ce@(fvs, AnnApp ce1 ce2) - = do { varsS <- varsSimple fvs - ; case (vi, varsS) of - (VISimple, True) -> do { let (e', vt') = liftSimple vt ce - -- ; checkTreeAnnM vt' e' - -- ; traceVt "Passed checkTree test!!" (ppr $ deAnnotate e') - ; return (e', vt') - } - _ -> do { (etaCe1, vit1') <- encapsulateScalars vit1 ce1 - ; (etaCe2, vit2') <- encapsulateScalars vit2 ce2 - ; return ((fvs, AnnApp etaCe1 etaCe2), VITNode vi [vit1', vit2']) - } - } - -encapsulateScalars _ (_fvs, AnnApp _ce1 _ce2) - = panic "encapsulateScalars AnnApp doesn't match up" - -encapsulateScalars vt@(VITNode vi (scrutVit : altVits)) ce@(fvs, AnnCase scrut bndr ty alts) - = do { varsS <- varsSimple fvs - ; case (vi, varsS) of - (VISimple, True) -> return $ liftSimple vt ce - _ -> do { (extScrut, scrutVit') <- encapsulateScalars scrutVit scrut - ; extAltsVits <- zipWithM expAlt altVits alts - ; let (extAlts, altVits') = unzip extAltsVits - ; return ((fvs, AnnCase extScrut bndr ty extAlts), VITNode vi (scrutVit': altVits')) - } - } +encapsulateScalars :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo +encapsulateScalars ce@(_, AnnType _ty) + = return ce +encapsulateScalars ce@((_, VISimple), AnnVar v) + | isFunTy . varType $ v -- NB: diverts from the paper: encapsulate scalar function types + = liftSimpleAndCase ce +encapsulateScalars ce@(_, AnnVar _v) + = return ce +encapsulateScalars ce@(_, AnnLit _) + = return ce +encapsulateScalars ((fvs, vi), AnnTick tck expr) + = do + { encExpr <- encapsulateScalars expr + ; return ((fvs, vi), AnnTick tck encExpr) + } +encapsulateScalars ce@((fvs, vi), AnnLam bndr expr) + = do + { varsS <- allScalarVarTypeSet fvs + ; case (vi, varsS) of + (VISimple, True) -> liftSimpleAndCase ce + _ -> do + { encExpr <- encapsulateScalars expr + ; return ((fvs, vi), AnnLam bndr encExpr) + } + } +encapsulateScalars ce@((fvs, vi), AnnApp ce1 ce2) + = do + { varsS <- allScalarVarTypeSet fvs + ; case (vi, varsS) of + (VISimple, True) -> liftSimpleAndCase ce + _ -> do + { encCe1 <- encapsulateScalars ce1 + ; encCe2 <- encapsulateScalars ce2 + ; return ((fvs, vi), AnnApp encCe1 encCe2) + } + } +encapsulateScalars ce@((fvs, vi), AnnCase scrut bndr ty alts) + = do + { varsS <- allScalarVarTypeSet fvs + ; case (vi, varsS) of + (VISimple, True) -> liftSimpleAndCase ce + _ -> do + { encScrut <- encapsulateScalars scrut + ; encAlts <- mapM encAlt alts + ; return ((fvs, vi), AnnCase encScrut bndr ty encAlts) + } + } where - expAlt vt (con, bndrs, expr) - = do { (extExpr, vt') <- encapsulateScalars vt expr - ; return ((con, bndrs, extExpr), vt') - } - -encapsulateScalars _ (_fvs, AnnCase _scrut _bndr _ty _alts) - = panic "encapsulateScalars AnnCase doesn't match up" - -encapsulateScalars vt@(VITNode vi [vt1, vt2]) ce@(fvs, AnnLet (AnnNonRec bndr expr1) expr2) - = do { varsS <- varsSimple fvs - ; case (vi, varsS) of - (VISimple, True) -> return $ liftSimple vt ce - _ -> do { (extExpr1, vt1') <- encapsulateScalars vt1 expr1 - ; (extExpr2, vt2') <- encapsulateScalars vt2 expr2 - ; return ((fvs, AnnLet (AnnNonRec bndr extExpr1) extExpr2), VITNode vi [vt1', vt2']) - } - } - -encapsulateScalars _ (_fvs, AnnLet (AnnNonRec _bndr _expr1) _expr2) - = panic "encapsulateScalars AnnLet nonrec doesn't match up" - -encapsulateScalars vt@(VITNode vi (vtB : vtBnds)) ce@(fvs, AnnLet (AnnRec bndngs) expr) - = do { varsS <- varsSimple fvs - ; case (vi, varsS) of - (VISimple, True) -> return $ liftSimple vt ce - _ -> do { extBndsVts <- zipWithM expBndg vtBnds bndngs - ; let (extBnds, vtBnds') = unzip extBndsVts - ; (extExpr, vtB') <- encapsulateScalars vtB expr - ; let vt' = VITNode vi (vtB':vtBnds') - ; return ((fvs, AnnLet (AnnRec extBnds) extExpr), vt') - } - } - where - expBndg vit (bndr, expr) - = do { (extExpr, vit') <- encapsulateScalars vit expr - ; return ((bndr, extExpr), vit') - } - -encapsulateScalars _ (_fvs, AnnLet (AnnRec _) _expr2) - = panic "encapsulateScalars AnnLet rec doesn't match up" - -encapsulateScalars (VITNode vi [vit]) (fvs, AnnCast expr coercion) - = do { (extExpr, vit') <- encapsulateScalars vit expr - ; return ((fvs, AnnCast extExpr coercion), VITNode vi [vit']) - } - -encapsulateScalars _ (_fvs, AnnCast _expr _coercion) - = panic "encapsulateScalars AnnCast rec doesn't match up" - -encapsulateScalars _ _ - = panic "encapsulateScalars case not handled" + encAlt (con, bndrs, expr) = (con, bndrs,) <$> encapsulateScalars expr +encapsulateScalars ce@((fvs, vi), AnnLet (AnnNonRec bndr expr1) expr2) + = do + { varsS <- allScalarVarTypeSet fvs + ; case (vi, varsS) of + (VISimple, True) -> liftSimpleAndCase ce + _ -> do + { encExpr1 <- encapsulateScalars expr1 + ; encExpr2 <- encapsulateScalars expr2 + ; return ((fvs, vi), AnnLet (AnnNonRec bndr encExpr1) encExpr2) + } + } +encapsulateScalars ce@((fvs, vi), AnnLet (AnnRec binds) expr) + = do + { varsS <- allScalarVarTypeSet fvs + ; case (vi, varsS) of + (VISimple, True) -> liftSimpleAndCase ce + _ -> do + { encBinds <- mapM encBind binds + ; encExpr <- encapsulateScalars expr + ; return ((fvs, vi), AnnLet (AnnRec encBinds) encExpr) + } + } + where + encBind (bndr, expr) = (bndr,) <$> encapsulateScalars expr +encapsulateScalars ((fvs, vi), AnnCast expr coercion) + = do + { encExpr <- encapsulateScalars expr + ; return ((fvs, vi), AnnCast encExpr coercion) + } +encapsulateScalars _ + = panic "Vectorise.Exp.encapsulateScalars: unknown constructor" --- Lambda-lift the given expression and apply it to the abstracted free variables. +-- Lambda-lift the given simple expression and apply it to the abstracted free variables. -- --- If the expression is a case expression scrutinising anything but a primitive type, then lift +-- If the expression is a case expression scrutinising anything, but a scalar type, then lift -- each alternative individually. -- -liftSimple :: VITree -> CoreExprWithFVs -> (CoreExprWithFVs, VITree) -liftSimple (VITNode vi (scrutVit : altVits)) (fvs, AnnCase expr bndr t alts) - | Just (c,_) <- splitTyConApp_maybe (exprType $ deAnnotate $ expr), - (not $ elem c [boolTyCon, intTyCon, doubleTyCon, floatTyCon]) -- FIXME: shouldn't be hardcoded - = ((fvs, AnnCase expr bndr t alts'), VITNode vi (scrutVit : altVits')) - where - (alts', altVits') = unzip $ map (\(ac,bndrs, (alt, avi)) -> ((ac,bndrs,alt), avi)) $ - zipWith (\(ac, bndrs, aex) -> \altVi -> (ac, bndrs, liftSimple altVi aex)) alts altVits - -liftSimple viTree ae@(fvs, _annEx) - = (mkAnnApps (mkAnnLams ae vars) vars, viTree') +liftSimpleAndCase :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo +liftSimpleAndCase aexpr@((fvs, _vi), AnnCase expr bndr t alts) + = do + { vi <- vectAvoidInfoTypeOf expr + ; if (vi == VISimple) + then + return $ liftSimple aexpr -- if the scrutinee is scalar, we need no special treatment + else do + { alts' <- mapM (\(ac, bndrs, aexpr) -> (ac, bndrs,) <$> liftSimpleAndCase aexpr) alts + ; return ((fvs, vi), AnnCase expr bndr t alts') + } + } +liftSimpleAndCase aexpr = return $ liftSimple aexpr + +liftSimple :: CoreExprWithVectInfo -> CoreExprWithVectInfo +liftSimple ((fvs, vi), expr) + = ASSERT(vi == VISimple) + mkAnnApps (mkAnnLams vars fvs expr) vars where - mkViTreeLams (VITNode _ vits) [] = VITNode VIEncaps vits - mkViTreeLams vi (_:vs) = VITNode VIEncaps [mkViTreeLams vi vs] + vars = varSetElems fvs - mkViTreeApps vi [] = vi - mkViTreeApps vi (_:vs) = VITNode VISimple [mkViTreeApps vi vs, VITNode VISimple []] - - vars = varSetElems fvs - viTree' = mkViTreeApps (mkViTreeLams viTree vars) vars - - mkAnnLam :: bndr -> AnnExpr bndr VarSet -> AnnExpr' bndr VarSet - mkAnnLam bndr ce = AnnLam bndr ce - - mkAnnLams:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs - mkAnnLams (fv, aex') [] = (fv, aex') -- fv should be empty. check! - mkAnnLams (fv, aex') (v:vs) = mkAnnLams (delVarSet fv v, (mkAnnLam v ((delVarSet fv v), aex'))) vs - - mkAnnApp :: (AnnExpr bndr VarSet) -> Var -> (AnnExpr' bndr VarSet) - mkAnnApp aex v = AnnApp aex (unitVarSet v, (AnnVar v)) + mkAnnLams :: [Var] -> VarSet -> AnnExpr' Var (VarSet, VectAvoidInfo) -> CoreExprWithVectInfo + mkAnnLams [] fvs expr = ASSERT(isEmptyVarSet fvs) + ((emptyVarSet, VIEncaps), expr) + mkAnnLams (v:vs) fvs expr = mkAnnLams vs (fvs `delVarSet` v) (AnnLam v ((fvs, VIEncaps), expr)) - mkAnnApps:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs - mkAnnApps (fv, aex') [] = (fv, aex') - mkAnnApps ae (v:vs) = - let - (fv, aex') = mkAnnApps ae vs - in (extendVarSet fv v, mkAnnApp (fv, aex') v) + mkAnnApps :: CoreExprWithVectInfo -> [Var] -> CoreExprWithVectInfo + mkAnnApps aexpr [] = aexpr + mkAnnApps aexpr (v:vs) = mkAnnApps (mkAnnApp aexpr v) vs + + mkAnnApp :: CoreExprWithVectInfo -> Var -> CoreExprWithVectInfo + mkAnnApp aexpr@((fvs, _vi), _expr) v + = ((fvs `extendVarSet` v, VISimple), AnnApp aexpr ((unitVarSet v, VISimple), AnnVar v)) -- |Vectorise an expression. -- -vectExpr :: CoreExprWithFVs -> VITree -> VM VExpr --- vectExpr e vi | not (checkTree vi (deAnnotate e)) --- = pprPanic "vectExpr" (ppr $ deAnnotate e) - -vectExpr (_, AnnVar v) _ +vectExpr :: CoreExprWithVectInfo -> VM VExpr + +vectExpr (_, AnnVar v) = vectVar v -vectExpr (_, AnnLit lit) _ +vectExpr (_, AnnLit lit) = vectConst $ Lit lit -vectExpr e@(_, AnnLam bndr _) vt - | isId bndr = (\(_, _, ve) -> ve) <$> vectFnExpr True False [] e vt - | otherwise = do dflags <- getDynFlags - cantVectorise dflags "Unexpected type lambda (vectExpr)" (ppr (deAnnotate e)) +vectExpr e@(_, AnnLam bndr _) + | isId bndr = vectFnExpr True False e + | otherwise + = do + { dflags <- getDynFlags + ; cantVectorise dflags "Unexpected type lambda (vectExpr)" $ ppr (deAnnotate e) + } -- SPECIAL CASE: Vectorise/lift 'patError @ ty err' by only vectorising/lifting the type 'ty'; -- its only purpose is to abort the program, but we need to adjust the type to keep CoreLint -- happy. -- FIXME: can't be do this with a VECTORISE pragma on 'pAT_ERROR_ID' now? -vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err) _ +vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err) | v == pAT_ERROR_ID - = do { (vty, lty) <- vectAndLiftType ty - ; return (mkCoreApps (Var v) [Type vty, err'], mkCoreApps (Var v) [Type lty, err']) - } + = do + { (vty, lty) <- vectAndLiftType ty + ; return (mkCoreApps (Var v) [Type vty, err'], mkCoreApps (Var v) [Type lty, err']) + } where err' = deAnnotate err -- type application (handle multiple consecutive type applications simultaneously to ensure the -- PA dictionaries are put at the right places) -vectExpr e@(_, AnnApp _ arg) (VITNode _ [_, _]) +vectExpr e@(_, AnnApp _ arg) | isAnnTypeArg arg = vectPolyApp e - - -- 'Int', 'Float', or 'Double' literal - -- FIXME: this needs to be generalised -vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit)) _ - | Just con <- isDataConId_maybe v - , is_special_con con + + -- Lifted literal +vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit)) + | Just _con <- isDataConId_maybe v = do - let vexpr = App (Var v) (Lit lit) - lexpr <- liftPD vexpr - return (vexpr, lexpr) - where - is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon] + { let vexpr = App (Var v) (Lit lit) + ; lexpr <- liftPD vexpr + ; return (vexpr, lexpr) + } -- value application (dictionary or user value) -vectExpr e@(_, AnnApp fn arg) (VITNode _ [vit1, vit2]) +vectExpr e@(_, AnnApp fn arg) | isPredTy arg_ty -- dictionary application (whose result is not a dictionary) = vectPolyApp e | otherwise -- user value - = do { -- vectorise the types - ; varg_ty <- vectType arg_ty - ; vres_ty <- vectType res_ty + = do + { -- vectorise the types + ; varg_ty <- vectType arg_ty + ; vres_ty <- vectType res_ty - -- vectorise the function and argument expression - ; vfn <- vectExpr fn vit1 - ; varg <- vectExpr arg vit2 + -- vectorise the function and argument expression + ; vfn <- vectExpr fn + ; varg <- vectExpr arg - -- the vectorised function is a closure; apply it to the vectorised argument - ; mkClosureApp varg_ty vres_ty vfn varg - } + -- the vectorised function is a closure; apply it to the vectorised argument + ; mkClosureApp varg_ty vres_ty vfn varg + } where (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn -vectExpr (_, AnnCase scrut bndr ty alts) vt +vectExpr (_, AnnCase scrut bndr ty alts) | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty , isAlgTyCon tycon - = vectAlgCase tycon ty_args scrut bndr ty alts vt - | otherwise = do dflags <- getDynFlags - cantVectorise dflags "Can't vectorise expression" (ppr scrut_ty) + = vectAlgCase tycon ty_args scrut bndr ty alts + | otherwise + = do + { dflags <- getDynFlags + ; cantVectorise dflags "Can't vectorise expression (no algebraic type constructor)" $ + ppr scrut_ty + } where scrut_ty = exprType (deAnnotate scrut) -vectExpr (_, AnnLet (AnnNonRec bndr rhs) body) (VITNode _ [vt1, vt2]) +vectExpr (_, AnnLet (AnnNonRec bndr rhs) body) = do - vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs (Just vt1) - (vbndr, vbody) <- vectBndrIn bndr (vectExpr body vt2) - return $ vLet (vNonRec vbndr vrhs) vbody + { vrhs <- localV $ + inBind bndr $ + vectAnnPolyExpr False rhs + ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body) + ; return $ vLet (vNonRec vbndr vrhs) vbody + } -vectExpr (_, AnnLet (AnnRec bs) body) (VITNode _ (vtB : vtBnds)) +vectExpr (_, AnnLet (AnnRec bs) body) = do - (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs + { (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs $ liftM2 (,) - (zipWith3M vect_rhs bndrs rhss vtBnds) - (vectExpr body vtB) - return $ vLet (vRec vbndrs vrhss) vbody + (zipWithM vect_rhs bndrs rhss) + (vectExpr body) + ; return $ vLet (vRec vbndrs vrhss) vbody + } where (bndrs, rhss) = unzip bs - vect_rhs bndr rhs vt = localV - . inBind bndr - . liftM (\(_,_,z)->z) - $ vectPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) [] rhs (Just vt) - zipWith3M f xs ys zs = zipWithM (\x -> \(y,z) -> (f x y z)) xs (zip ys zs) + vect_rhs bndr rhs = localV $ + inBind bndr $ + vectAnnPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) rhs -vectExpr (_, AnnTick tickish expr) (VITNode _ [vit]) - = liftM (vTick tickish) (vectExpr expr vit) +vectExpr (_, AnnTick tickish expr) + = vTick tickish <$> vectExpr expr -vectExpr (_, AnnType ty) _ - = liftM vType (vectType ty) +vectExpr (_, AnnType ty) + = vType <$> vectType ty -vectExpr e vit = do dflags <- getDynFlags - cantVectorise dflags "Can't vectorise expression (vectExpr)" (ppr (deAnnotate e) $$ text (" " ++ show vit)) +vectExpr e + = do + { dflags <- getDynFlags + ; cantVectorise dflags "Can't vectorise expression (vectExpr)" $ ppr (deAnnotate e) + } --- |Vectorise an expression that *may* have an outer lambda abstraction. +-- |Vectorise an expression that *may* have an outer lambda abstraction. If the expression is marked +-- as encapsulated ('VIEncaps'), vectorise it as a scalar computation (using a generalised scalar +-- zip). -- -- We do not handle type variables at this point, as they will already have been stripped off by --- 'vectPolyExpr'. We also only have to worry about one set of dictionary arguments as we (1) only +-- 'vectPolyExpr'. We also only have to worry about one set of dictionary arguments as we (1) only -- deal with Haskell 2011 and (2) class selectors are vectorised elsewhere. -- -vectFnExpr :: Bool -- ^ If we process the RHS of a binding, whether that binding should - -- be inlined - -> Bool -- ^ Whether the binding is a loop breaker - -> [Var] -- ^ Names of function in same recursive binding group - -> CoreExprWithFVs -- ^ Expression to vectorise; must have an outer `AnnLam` - -> VITree - -> VM (Inline, Bool, VExpr) --- vectFnExpr _ _ _ e vi | not (checkTree vi (deAnnotate e)) --- = pprPanic "vectFnExpr" (ppr $ deAnnotate e) -vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr body) vt@(VITNode _ [vt']) - -- predicate abstraction: leave as a normal abstraction, but vectorise the predicate type +vectFnExpr :: Bool -- ^If we process the RHS of a binding, whether that binding + -- should be inlined + -> Bool -- ^Whether the binding is a loop breaker + -> CoreExprWithVectInfo -- ^Expression to vectorise; must have an outer `AnnLam` + -> VM VExpr +vectFnExpr inline loop_breaker expr@(_ann, AnnLam bndr body) + -- predicate abstraction: leave as a normal abstraction, but vectorise the predicate type | isId bndr && isPredTy (idType bndr) - = do { vBndr <- vectBndr bndr - ; (inline, isScalarFn, vbody) <- vectFnExpr inline loop_breaker recFns body vt' - ; return (inline, isScalarFn, mapVect (mkLams [vectorised vBndr]) vbody) - } - -- non-predicate abstraction: vectorise (try to vectorise as a scalar computation) + = do + { vBndr <- vectBndr bndr + ; vbody <- vectFnExpr inline loop_breaker body + ; return $ mapVect (mkLams [vectorised vBndr]) vbody + } + -- non-predicate abstraction: vectorise as a scalar computation + | isId bndr && isVIEncaps expr + = vectScalarFun . deAnnotate $ expr + -- non-predicate abstraction: vectorise as a non-scalar computation | isId bndr - = mark DontInline True (vectScalarFunMaybe (deAnnotate expr) vt) - `orElseV` - mark inlineMe False (vectLam inline loop_breaker expr vt) -vectFnExpr _ _ _ e vt - -- not an abstraction: vectorise as a vanilla expression - = mark DontInline False $ vectExpr e vt - -mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a) -mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) } + = vectLam inline loop_breaker expr +vectFnExpr _ _ expr + -- not an abstraction: vectorise as a vanilla expression + = vectExpr expr -- |Vectorise type and dictionary applications. -- -- These are always headed by a variable (as we don't support higher-rank polymorphism), but may --- involve two sets of type variables and dictionaries. Consider, +-- involve two sets of type variables and dictionaries. Consider, -- -- > class C a where -- > m :: D b => b -> a -- -- The type of 'm' is 'm :: forall a. C a => forall b. D b => b -> a'. -- -vectPolyApp :: CoreExprWithFVs -> VM VExpr +vectPolyApp :: CoreExprWithVectInfo -> VM VExpr vectPolyApp e0 = case e4 of (_, AnnVar var) @@ -530,21 +540,6 @@ vectDictExpr (Coercion coe) -- instead they become dictionaries of vectorised methods). We treat them differently, though see -- "Note [Scalar dfuns]" in 'Vectorise'. -- -vectScalarFunMaybe :: CoreExpr -- ^ Expression to be vectorised - -> VITree -- ^ Vectorisation information - -> VM VExpr -vectScalarFunMaybe expr (VITNode VIEncaps _) = vectScalarFun expr -vectScalarFunMaybe _expr _ = noV $ ptext (sLit "not a scalar function") - --- |Vectorise an expression of functional type by lifting it by an application of a member of the --- zipWith family (i.e., 'map', 'zipWith', zipWith3', etc.) This is only a valid strategy if the --- function does not contain parallel subcomputations and has only 'Scalar' types in its result and --- arguments — this is a predcondition for calling this function. --- --- Dictionary functions are also scalar functions (as dictionaries themselves are not vectorised, --- instead they become dictionaries of vectorised methods). We treat them differently, though see --- "Note [Scalar dfuns]" in 'Vectorise'. --- vectScalarFun :: CoreExpr -> VM VExpr vectScalarFun expr = do @@ -673,12 +668,11 @@ unVectDict ty e -- variables are passed explicit (as conventional arguments) into the body during closure -- construction. -- -vectLam :: Bool -- ^ When the RHS of a binding, whether that binding should be inlined. - -> Bool -- ^ Whether the binding is a loop breaker. - -> CoreExprWithFVs -- ^ Body of abstraction. - -> VITree +vectLam :: Bool -- ^ Should the RHS of a binding be inlined? + -> Bool -- ^ Whether the binding is a loop breaker. + -> CoreExprWithVectInfo -- ^ Body of abstraction. -> VM VExpr -vectLam inline loop_breaker expr@(fvs, AnnLam _ _) vi +vectLam inline loop_breaker expr@((fvs, _vi), AnnLam _ _) = do { let (bndrs, body) = collectAnnValBinders expr -- grab the in-scope type variables @@ -706,18 +700,13 @@ vectLam inline loop_breaker expr@(fvs, AnnLam _ _) vi . hoistPolyVExpr tyvars vfvs_dict' (maybe_inline arity) $ do { -- generate the vectorised body of the lambda abstraction ; lc <- builtin liftingContext - ; let viBody = stripLams expr vi - -- ; checkTreeAnnM vi expr - ; (vbndrs, vbody) <- vectBndrsIn (fvs_nondict ++ bndrs) (vectExpr body viBody) + ; (vbndrs, vbody) <- vectBndrsIn (fvs_nondict ++ bndrs) $ vectExpr body ; vbody' <- break_loop lc res_ty vbody ; return $ vLams lc vbndrs vbody' } } where - stripLams (_, AnnLam _ e) (VITNode _ [vt]) = stripLams e vt - stripLams _ vi = vi - maybe_inline n | inline = Inline n | otherwise = DontInline @@ -735,7 +724,7 @@ vectLam inline loop_breaker expr@(fvs, AnnLam _ _) vi (LitAlt (mkMachInt 0), [], empty)]) } | otherwise = return (ve, le) -vectLam _ _ _ _ = panic "vectLam" +vectLam _ _ _ = panic "Vectorise.Exp.vectLam: not a lambda" -- Vectorise an algebraic case expression. -- @@ -754,31 +743,31 @@ vectLam _ _ _ _ = panic "vectLam" -- -- FIXME: this is too lazy -vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs-> Var -> Type - -> [(AltCon, [Var], CoreExprWithFVs)] -> VITree +vectAlgCase :: TyCon -> [Type] -> CoreExprWithVectInfo -> Var -> Type + -> [(AltCon, [Var], CoreExprWithVectInfo)] -> VM VExpr -vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)] (VITNode _ (scrutVit : [altVit])) +vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)] = do - vscrut <- vectExpr scrut scrutVit + vscrut <- vectExpr scrut (vty, lty) <- vectAndLiftType ty - (vbndr, vbody) <- vectBndrIn bndr (vectExpr body altVit) + (vbndr, vbody) <- vectBndrIn bndr (vectExpr body) return $ vCaseDEFAULT vscrut vbndr vty lty vbody -vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)] (VITNode _ (scrutVit : [altVit])) +vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)] = do - vscrut <- vectExpr scrut scrutVit + vscrut <- vectExpr scrut (vty, lty) <- vectAndLiftType ty - (vbndr, vbody) <- vectBndrIn bndr (vectExpr body altVit) + (vbndr, vbody) <- vectBndrIn bndr (vectExpr body) return $ vCaseDEFAULT vscrut vbndr vty lty vbody -vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)] (VITNode _ (scrutVit : [altVit])) +vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)] = do (vty, lty) <- vectAndLiftType ty - vexpr <- vectExpr scrut scrutVit + vexpr <- vectExpr scrut (vbndr, (vbndrs, (vect_body, lift_body))) <- vect_scrut_bndr . vectBndrsIn bndrs - $ vectExpr body altVit + $ vectExpr body let (vect_bndrs, lift_bndrs) = unzip vbndrs (vscrut, lscrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr) vect_dc <- maybeV dataConErr (lookupDataCon dc) @@ -796,9 +785,9 @@ vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)] (VITNode _ dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc) -vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits)) +vectAlgCase tycon _ty_args scrut bndr ty alts = do - vect_tc <- maybeV tyConErr (lookupTyCon tycon) + vect_tc <- vectTyCon tycon (vty, lty) <- vectAndLiftType ty let arity = length (tyConDataCons vect_tc) @@ -807,10 +796,10 @@ vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits)) let sel = Var sel_bndr (vbndr, valts) <- vect_scrut_bndr - $ mapM (proc_alt arity sel vty lty) (zip alts' altVits) + $ mapM (proc_alt arity sel vty lty) alts' let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts - vexpr <- vectExpr scrut scrutVit + vexpr <- vectExpr scrut (vect_scrut, lift_scrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr) let (vect_bodies, lift_bodies) = unzip vbodies @@ -829,8 +818,6 @@ vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits)) return . vLet (vNonRec vbndr vexpr) $ (vect_case, lift_case) where - tyConErr = (text "vectAlgCase: type constructor not vectorised" <+> ppr tycon) - vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut") | otherwise = vectBndrIn bndr @@ -842,12 +829,12 @@ vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits)) cmp _ DEFAULT = GT cmp _ _ = panic "vectAlgCase/cmp" - proc_alt arity sel _ lty ((DataAlt dc, bndrs, body), vi) + proc_alt arity sel _ lty (DataAlt dc, bndrs, body@((fvs_body, _), _)) = do vect_dc <- maybeV dataConErr (lookupDataCon dc) let ntag = dataConTagZ vect_dc tag = mkDataConTag vect_dc - fvs = freeVarsOf body `delVarSetList` bndrs + fvs = fvs_body `delVarSetList` bndrs sel_tags <- liftM (`App` sel) (builtin (selTags arity)) lc <- builtin liftingContext @@ -860,7 +847,7 @@ vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits)) binds <- mapM (pack_var (Var lc) sel_tags tag) . filter isLocalId $ varSetElems fvs - (ve, le) <- vectExpr body vi + (ve, le) <- vectExpr body return (ve, Case (elems `App` sel) lc lty [(DEFAULT, [], (mkLets (concat binds) le))]) -- empty <- emptyPD vty @@ -892,9 +879,6 @@ vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits)) _ -> return [] -vectAlgCase tycon _ty_args _scrut _bndr _ty _alts (VITNode _ _) - = pprPanic "vectAlgCase (mismatched node information)" (ppr tycon) - -- Support to compute information for vectorisation avoidance ------------------ @@ -905,202 +889,248 @@ data VectAvoidInfo = VIParr -- tree contains parallel computations | VISimple -- result type is scalar & no parallel subcomputation | VIComplex -- any result type, no parallel subcomputation | VIEncaps -- tree encapsulated by 'liftSimple' + | VIDict -- dictionary computation (never parallel) deriving (Eq, Show) --- Instead of integrating the vectorisation avoidance information into Core expression, we keep --- them in a separate tree (that structurally mirrors the Core expression that it annotates). +-- Core expression annotated with free variables and vectorisation-specific information. -- -data VITree = VITNode VectAvoidInfo [VITree] - deriving (Show) +type CoreExprWithVectInfo = AnnExpr Id (VarSet, VectAvoidInfo) --- Is any of the tree nodes a 'VIPArr' node? +-- Yield the type of an annotated core expression. -- -anyVIPArr :: [VITree] -> Bool -anyVIPArr = or . (map (\(VITNode vi _) -> vi == VIParr)) +annExprType :: AnnExpr Var ann -> Type +annExprType = exprType . deAnnotate --- Compute Core annotations to determine for which subexpressions we can avoid vectorisation +-- Project the vectorisation information from an annotated Core expression. -- --- FIXME: free scalar vars don't actually need to be passed through, since encapsulations makes sure, --- that there are no free variables in encapsulated lambda expressions -vectAvoidInfo :: CoreExprWithFVs -> VM VITree -vectAvoidInfo ce@(_, AnnVar v) - = do { vi <- vectAvoidInfoType $ exprType $ deAnnotate ce - ; viTrace ce vi [] - ; traceVt "vectAvoidInfo AnnVar" ((ppr v) <+> (ppr $ exprType $ deAnnotate ce)) - ; return $ VITNode vi [] - } +vectAvoidInfoOf :: CoreExprWithVectInfo -> VectAvoidInfo +vectAvoidInfoOf ((_, vi), _) = vi -vectAvoidInfo ce@(_, AnnLit _) - = do { vi <- vectAvoidInfoType $ exprType $ deAnnotate ce - ; viTrace ce vi [] - ; traceVt "vectAvoidInfo AnnLit" (ppr $ exprType $ deAnnotate ce) - ; return $ VITNode vi [] - } +-- Is this a 'VIParr' node? +-- +isVIParr :: CoreExprWithVectInfo -> Bool +isVIParr = (== VIParr) . vectAvoidInfoOf -vectAvoidInfo ce@(_, AnnApp e1 e2) - = do { vt1 <- vectAvoidInfo e1 - ; vt2 <- vectAvoidInfo e2 - ; vi <- if anyVIPArr [vt1, vt2] - then return VIParr - else vectAvoidInfoType $ exprType $ deAnnotate ce - ; viTrace ce vi [vt1, vt2] - ; return $ VITNode vi [vt1, vt2] - } +-- Is this a 'VIEncaps' node? +-- +isVIEncaps :: CoreExprWithVectInfo -> Bool +isVIEncaps = (== VIEncaps) . vectAvoidInfoOf -vectAvoidInfo ce@(_, AnnLam _var body) - = do { vt@(VITNode vi _) <- vectAvoidInfo body - ; viTrace ce vi [vt] - ; let resultVI | vi == VIParr = VIParr - | otherwise = VIComplex - ; return $ VITNode resultVI [vt] - } +-- Is this a 'VIDict' node? +-- +isVIDict :: CoreExprWithVectInfo -> Bool +isVIDict = (== VIDict) . vectAvoidInfoOf -vectAvoidInfo ce@(_, AnnLet (AnnNonRec _var expr) body) - = do { vtE <- vectAvoidInfo expr - ; vtB <- vectAvoidInfo body - ; vi <- if anyVIPArr [vtE, vtB] - then return VIParr - else vectAvoidInfoType $ exprType $ deAnnotate ce - ; viTrace ce vi [vtE, vtB] - ; return $ VITNode vi [vtE, vtB] - } +-- 'VIParr' if either argument is 'VIParr'; otherwise, the first argument. +-- +unlessVIParr :: VectAvoidInfo -> VectAvoidInfo -> VectAvoidInfo +unlessVIParr _ VIParr = VIParr +unlessVIParr vi _ = vi -vectAvoidInfo ce@(_, AnnLet (AnnRec bnds) body) - = do { let (_, exprs) = unzip bnds - ; vtBnds <- mapM (\e -> vectAvoidInfo e) exprs - ; if (anyVIPArr vtBnds) - then do { vtBnds' <- mapM (\e -> vectAvoidInfo e) exprs - ; vtB <- vectAvoidInfo body - ; return (VITNode VIParr (vtB: vtBnds')) - } - else do { vtB@(VITNode vib _) <- vectAvoidInfo body - ; ni <- if (vib == VIParr) - then return VIParr - else vectAvoidInfoType $ exprType $ deAnnotate ce - ; viTrace ce ni (vtB : vtBnds) - ; return $ VITNode ni (vtB : vtBnds) - } - } +-- 'VIParr' if either arguments vectorisation information is 'VIParr'; otherwise, the vectorisation +-- information of the first argument is produced. +-- +unlessVIParrExpr :: VectAvoidInfo -> CoreExprWithVectInfo -> VectAvoidInfo +infixl `unlessVIParrExpr` +unlessVIParrExpr e1 e2 = e1 `unlessVIParr` vectAvoidInfoOf e2 -vectAvoidInfo ce@(_, AnnCase expr _var _ty alts) - = do { vtExpr <- vectAvoidInfo expr - ; vtAlts <- mapM (\(_, _, e) -> vectAvoidInfo e) alts - ; ni <- if anyVIPArr (vtExpr : vtAlts) - then return VIParr - else vectAvoidInfoType $ exprType $ deAnnotate ce - ; viTrace ce ni (vtExpr : vtAlts) - ; return $ VITNode ni (vtExpr: vtAlts) - } +-- Compute Core annotations to determine for which subexpressions we can avoid vectorisation. +-- +-- * The first argument is the set of free, local variables whose evaluation may entail parallelism. +-- +vectAvoidInfo :: VarSet -> CoreExprWithFVs -> VM CoreExprWithVectInfo +vectAvoidInfo pvs ce@(fvs, AnnVar v) + = do + { gpvs <- globalParallelVars + ; vi <- if v `elemVarSet` pvs || v `elemVarSet` gpvs + then return VIParr + else vectAvoidInfoTypeOf ce + ; viTrace ce vi [] -vectAvoidInfo (_, AnnCast expr _) - = do { vt@(VITNode vi _) <- vectAvoidInfo expr - ; return $ VITNode vi [vt] - } + ; vit <- vectAvoidInfoTypeOf ce -- TEMPORARY + ; traceVt (" AnnVar: vectAvoidInfoTypeOf: " ++ show vit) empty -vectAvoidInfo (_, AnnTick _ expr) - = do { vt@(VITNode vi _) <- vectAvoidInfo expr - ; return $ VITNode vi [vt] - } + ; return ((fvs, vi), AnnVar v) + } -vectAvoidInfo (_, AnnType {}) - = return $ VITNode VISimple [] +vectAvoidInfo _pvs ce@(fvs, AnnLit lit) + = do + { vi <- vectAvoidInfoTypeOf ce + ; viTrace ce vi [] + ; return ((fvs, vi), AnnLit lit) + } -vectAvoidInfo (_, AnnCoercion {}) - = return $ VITNode VISimple [] +vectAvoidInfo pvs ce@(fvs, AnnApp e1 e2) + = do + { ceVI <- vectAvoidInfoTypeOf ce + ; eVI1 <- vectAvoidInfo pvs e1 + ; eVI2 <- vectAvoidInfo pvs e2 + ; let vi = ceVI `unlessVIParrExpr` eVI1 `unlessVIParrExpr` eVI2 + ; viTrace ce vi [eVI1, eVI2] + ; return ((fvs, vi), AnnApp eVI1 eVI2) + } + +vectAvoidInfo pvs ce@(fvs, AnnLam var body) + = do + { bodyVI <- vectAvoidInfo pvs body + ; varVI <- vectAvoidInfoType $ varType var + ; let vi = vectAvoidInfoOf bodyVI `unlessVIParr` varVI + ; viTrace ce vi [bodyVI] + ; return ((fvs, vi), AnnLam var bodyVI) + } + +vectAvoidInfo pvs ce@(fvs, AnnLet (AnnNonRec var e) body) + = do + { ceVI <- vectAvoidInfoTypeOf ce + ; eVI <- vectAvoidInfo pvs e + ; isScalarTy <- isScalar $ varType var + ; (bodyVI, vi) <- if isVIParr eVI && not isScalarTy + then do -- binding is parallel + { bodyVI <- vectAvoidInfo (fvs `extendVarSet` var) body + ; return (bodyVI, VIParr) + } + else do -- binding doesn't affect parallelism + { bodyVI <- vectAvoidInfo fvs body + ; return (bodyVI, ceVI `unlessVIParrExpr` bodyVI) + } + ; viTrace ce vi [eVI, bodyVI] + ; return ((fvs, vi), AnnLet (AnnNonRec var eVI) bodyVI) + } + +vectAvoidInfo pvs ce@(fvs, AnnLet (AnnRec bnds) body) + = do + { ceVI <- vectAvoidInfoTypeOf ce + ; bndsVI <- mapM (vectAvoidInfoBnd pvs) bnds + ; parrBndrs <- map fst <$> filterM isVIParrBnd bndsVI + ; if not . null $ parrBndrs + then do -- body may trigger parallelism via at least one binding + { new_pvs <- filterM ((not <$>) . isScalar . varType) parrBndrs + ; let extendedPvs = pvs `extendVarSetList` new_pvs + ; bndsVI <- mapM (vectAvoidInfoBnd extendedPvs) bnds + ; bodyVI <- vectAvoidInfo extendedPvs body + ; viTrace ce VIParr (map snd bndsVI ++ [bodyVI]) + ; return ((fvs, VIParr), AnnLet (AnnRec bndsVI) bodyVI) + } + else do -- demanded bindings cannot trigger parallelism + { bodyVI <- vectAvoidInfo pvs body + ; let vi = ceVI `unlessVIParrExpr` bodyVI + ; viTrace ce vi (map snd bndsVI ++ [bodyVI]) + ; return ((fvs, vi), AnnLet (AnnRec bndsVI) bodyVI) + } + } + where + vectAvoidInfoBnd pvs (var, e) = (var,) <$> vectAvoidInfo pvs e + + isVIParrBnd (var, eVI) + = do + { isScalarTy <- isScalar (varType var) + ; return $ isVIParr eVI && not isScalarTy + } + +vectAvoidInfo pvs ce@(fvs, AnnCase e var ty alts) + = do + { ceVI <- vectAvoidInfoTypeOf ce + ; eVI <- vectAvoidInfo pvs e + ; isScalarTy <- isScalar . annExprType $ e + ; altsVI <- mapM (vectAvoidInfoAlt (isVIParr eVI && not isScalarTy)) alts + ; allScalarBndrs <- anyM allScalarAltBndrs altsVI + ; let alteVIs = [eVI | (_, _, eVI) <- altsVI] + vi | isVIParr eVI && not allScalarBndrs = VIParr + | otherwise + = foldl unlessVIParrExpr ceVI alteVIs + ; viTrace ce vi (eVI : alteVIs) + ; return ((fvs, vi), AnnCase eVI var ty altsVI) + } + where + vectAvoidInfoAlt isScalarScrut (con, bndrs, e) = (con, bndrs,) <$> vectAvoidInfo altPvs e + where + altPvs | isScalarScrut = pvs + | otherwise = pvs `extendVarSetList` bndrs + + allScalarAltBndrs (_, bndrs, _) = allScalarVarType bndrs + +vectAvoidInfo pvs (fvs, AnnCast e (fvs_ann, ann)) + = do + { eVI <- vectAvoidInfo pvs e + ; return ((fvs, vectAvoidInfoOf eVI), AnnCast eVI ((fvs_ann, VISimple), ann)) + } + +vectAvoidInfo pvs (fvs, AnnTick tick e) + = do + { eVI <- vectAvoidInfo pvs e + ; return ((fvs, vectAvoidInfoOf eVI), AnnTick tick eVI) + } + +vectAvoidInfo _pvs (fvs, AnnType ty) + = return ((fvs, VISimple), AnnType ty) + +vectAvoidInfo _pvs (fvs, AnnCoercion coe) + = return ((fvs, VISimple), AnnCoercion coe) -- Compute vectorisation avoidance information for a type. -- vectAvoidInfoType :: Type -> VM VectAvoidInfo -vectAvoidInfoType ty - | maybeParrTy ty = return VIParr - | otherwise - = do { sType <- isSimpleType ty - ; if sType - then return VISimple - else return VIComplex - } +vectAvoidInfoType ty + | isPredTy ty + = return VIDict + | Just (arg, res) <- splitFunTy_maybe ty + = do + { argVI <- vectAvoidInfoType arg + ; resVI <- vectAvoidInfoType res + ; case (argVI, resVI) of + (VISimple, VISimple) -> return VISimple -- NB: diverts from the paper: scalar functions + (_ , VIDict) -> return VIDict + _ -> return $ VIComplex `unlessVIParr` argVI `unlessVIParr` resVI + } + | otherwise + = do + { parr <- maybeParrTy ty + ; if parr + then return VIParr + else do + { scalar <- isScalar ty + ; if scalar + then return VISimple + else return VIComplex + } } + +-- Compute vectorisation avoidance information for the type of a Core expression (with FVs). +-- +vectAvoidInfoTypeOf :: AnnExpr Var ann -> VM VectAvoidInfo +vectAvoidInfoTypeOf = vectAvoidInfoType . annExprType --- Checks whether the type might be a parallel array type. In particular, if the outermost --- constructor is a type family, we conservatively assume that it may be a parallel array type. +-- Checks whether the type might be a parallel array type. -- -maybeParrTy :: Type -> Bool +maybeParrTy :: Type -> VM Bool maybeParrTy ty - | Just ty' <- coreView ty = maybeParrTy ty' - | Just (tyCon, ts) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon - || or (map maybeParrTy ts) -maybeParrTy _ = False - --- FIXME: This should not be hardcoded. -isSimpleType :: Type -> VM Bool -isSimpleType ty - | Just (c, _cs) <- splitTyConApp_maybe ty - = return $ (tyConName c) `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName, floatTyConName] -{- - = do { globals <- globalScalarTyCons - ; traceVt ("isSimpleType " ++ (show (elemNameSet (tyConName c) globals ))) (ppr c) - ; return (elemNameSet (tyConName c) globals ) - } - -} - | Nothing <- splitTyConApp_maybe ty - = return False -isSimpleType ty - = pprPanic "Vectorise.Exp.isSimpleType not handled" (ppr ty) - -varsSimple :: VarSet -> VM Bool -varsSimple vs - = do { varTypes <- mapM isSimpleType $ map varType $ varSetElems vs - ; return $ and varTypes - } - -viTrace :: CoreExprWithFVs -> VectAvoidInfo -> [VITree] -> VM () -viTrace ce vi vTs - = traceVt ("vitrace " ++ (show vi) ++ "[" ++ (concat $ map (\(VITNode vi _) -> show vi ++ " ") vTs) ++"]") - (ppr $ deAnnotate ce) - + -- looking through newtypes + | Just ty' <- coreView ty + = (== VIParr) <$> vectAvoidInfoType ty' + -- decompose constructor applications + | Just (tc, ts) <- splitTyConApp_maybe ty + = do + { isParallel <- (tyConName tc `elemNameSet`) <$> globalParallelTyCons + ; if isParallel + then return True + else or <$> mapM maybeParrTy ts + } +maybeParrTy (ForAllTy _ ty) = maybeParrTy ty +maybeParrTy _ = return False -{- ----- Sanity check of the tree, for debugging only -checkTree :: VITree -> CoreExpr -> Bool -checkTree (VITNode _ []) (Type _ty) - = True - -checkTree (VITNode _ []) (Var _v) - = True - -checkTree (VITNode _ []) (Lit _) - = True - -checkTree (VITNode _ [vit]) (Tick _ expr) - = checkTree vit expr - -checkTree (VITNode _ [vit]) (Lam _ expr) - = checkTree vit expr - -checkTree (VITNode _ [vit1, vit2]) (App ce1 ce2) - = (checkTree vit1 ce1) && (checkTree vit2 ce2) - -checkTree (VITNode _ (scrutVit : altVits)) (Case scrut _ _ alts) - = (checkTree scrutVit scrut) && (and $ zipWith checkAlt altVits alts) - where - checkAlt vt (_, _, expr) = checkTree vt expr - -checkTree (VITNode _ [vt1, vt2]) (Let (NonRec _ expr1) expr2) - = (checkTree vt1 expr1) && (checkTree vt2 expr2) - -checkTree (VITNode _ (vtB : vtBnds)) (Let (Rec bndngs) expr) - = (and $ zipWith checkBndr vtBnds bndngs) && - (checkTree vtB expr) - where - checkBndr vt (_, e) = checkTree vt e - -checkTree (VITNode _ [vit]) (Cast expr _) - = checkTree vit expr +-- Are the types of all variables in the 'Scalar' class? +-- +allScalarVarType :: [Var] -> VM Bool +allScalarVarType vs = and <$> mapM (isScalar . varType) vs -checkTree _ _ = False +-- Are the types of all variables in the set in the 'Scalar' class? +-- +allScalarVarTypeSet :: VarSet -> VM Bool +allScalarVarTypeSet = allScalarVarType . varSetElems -checkTreeAnnM:: VITree -> CoreExprWithFVs -> VM () -checkTreeAnnM vi e = - if not (checkTree vi $ deAnnotate e) - then error ("checkTreeAnnM : \n " ++ show vi) - else return () --} +-- Debugging support +-- +viTrace :: CoreExprWithFVs -> VectAvoidInfo -> [CoreExprWithVectInfo] -> VM () +viTrace ce vi vTs + = traceVt ("vect info: " ++ show vi ++ "[" ++ + (concat $ map ((++ " ") . show . vectAvoidInfoOf) vTs) ++ "]") + (ppr $ deAnnotate ce) diff --git a/compiler/vectorise/Vectorise/Monad.hs b/compiler/vectorise/Vectorise/Monad.hs index 375b0af85e..6b5e9cc354 100644 --- a/compiler/vectorise/Vectorise/Monad.hs +++ b/compiler/vectorise/Vectorise/Monad.hs @@ -14,8 +14,8 @@ module Vectorise.Monad ( -- * Variables lookupVar, lookupVar_maybe, - addGlobalScalarVar, - addGlobalScalarTyCon, + addGlobalParallelVar, + addGlobalParallelTyCon, ) where import Vectorise.Monad.Base @@ -172,22 +172,22 @@ dumpVar dflags var = cantVectorise dflags "Variable not vectorised:" (ppr var) --- Global scalars -------------------------------------------------------------- +-- Global parallel entities ---------------------------------------------------- --- |Mark the given variable as scalar — i.e., executing the associated code does not involve any +-- |Mark the given variable as parallel — i.e., executing the associated code might involve -- parallel array computations. -- -addGlobalScalarVar :: Var -> VM () -addGlobalScalarVar var - = do { traceVt "addGlobalScalarVar" (ppr var) - ; updGEnv $ \env -> env{global_scalar_vars = extendVarSet (global_scalar_vars env) var} +addGlobalParallelVar :: Var -> VM () +addGlobalParallelVar var + = do { traceVt "addGlobalParallelVar" (ppr var) + ; updGEnv $ \env -> env{global_parallel_vars = extendVarSet (global_parallel_vars env) var} } --- |Mark the given type constructor as scalar — i.e., its values cannot embed parallel arrays. +-- |Mark the given type constructor as parallel — i.e., its values might embed parallel arrays. -- -addGlobalScalarTyCon :: TyCon -> VM () -addGlobalScalarTyCon tycon - = do { traceVt "addGlobalScalarTyCon" (ppr tycon) +addGlobalParallelTyCon :: TyCon -> VM () +addGlobalParallelTyCon tycon + = do { traceVt "addGlobalParallelTyCon" (ppr tycon) ; updGEnv $ \env -> - env{global_scalar_tycons = addOneToNameSet (global_scalar_tycons env) (tyConName tycon)} + env{global_parallel_tycons = addOneToNameSet (global_parallel_tycons env) (tyConName tycon)} } diff --git a/compiler/vectorise/Vectorise/Monad/Global.hs b/compiler/vectorise/Vectorise/Monad/Global.hs index a5c8449fc2..0fe460ad73 100644 --- a/compiler/vectorise/Vectorise/Monad/Global.hs +++ b/compiler/vectorise/Vectorise/Monad/Global.hs @@ -6,13 +6,13 @@ module Vectorise.Monad.Global ( updGEnv, -- * Vars - defGlobalVar, + defGlobalVar, undefGlobalVar, -- * Vectorisation declarations - lookupVectDecl, noVectDecl, + lookupVectDecl, -- * Scalars - globalScalarVars, isGlobalScalarVar, globalScalarTyCons, + globalParallelVars, globalParallelTyCons, -- * TyCons lookupTyCon, @@ -93,48 +93,54 @@ defGlobalVar v v' | otherwise = ptext (sLit "in the current module") +-- |Remove the mapping of a variable in the vectorisation map. +-- +undefGlobalVar :: Var -> VM () +undefGlobalVar v + = do + { traceVt "REMOVING global var mapping:" (ppr v) + ; updGEnv $ \env -> env { global_vars = delVarEnv (global_vars env) v } + } + -- Vectorisation declarations ------------------------------------------------- --- |Check whether a variable has a (non-scalar) vectorisation declaration. +-- |Check whether a variable has a vectorisation declaration. -- -lookupVectDecl :: Var -> VM (Maybe (Type, CoreExpr)) -lookupVectDecl var = readGEnv $ \env -> lookupVarEnv (global_vect_decls env) var - --- |Check whether a variable has a 'NOVECTORISE' declaration. +-- The first component of the result indicates whether the variable has a 'NOVECTORISE' declaration. +-- The second component contains the given type and expression in case of a 'VECTORISE' declaration. -- -noVectDecl :: Var -> VM Bool -noVectDecl var = readGEnv $ \env -> elemVarSet var (global_novect_vars env) +lookupVectDecl :: Var -> VM (Bool, Maybe (Type, CoreExpr)) +lookupVectDecl var + = readGEnv $ \env -> + case lookupVarEnv (global_vect_decls env) var of + Nothing -> (False, Nothing) + Just Nothing -> (True, Nothing) + Just vectDecl -> (False, vectDecl) --- Scalars -------------------------------------------------------------------- +-- Parallel entities ----------------------------------------------------------- --- |Get the set of global scalar variables. +-- |Get the set of global parallel variables. -- -globalScalarVars :: VM VarSet -globalScalarVars = readGEnv global_scalar_vars +globalParallelVars :: VM VarSet +globalParallelVars = readGEnv global_parallel_vars --- |Check whether a given variable is in the set of global scalar variables. +-- |Get the set of all parallel type constructors (those that may embed parallelism) including both +-- both those parallel type constructors declared in an imported module and those declared in the +-- current module. -- -isGlobalScalarVar :: Var -> VM Bool -isGlobalScalarVar var = readGEnv $ \env -> var `elemVarSet` global_scalar_vars env - --- |Get the set of global scalar type constructors including both those scalar type constructors --- declared in an imported module and those declared in the current module. --- -globalScalarTyCons :: VM NameSet -globalScalarTyCons = readGEnv global_scalar_tycons +globalParallelTyCons :: VM NameSet +globalParallelTyCons = readGEnv global_parallel_tycons -- TyCons --------------------------------------------------------------------- --- |Lookup the vectorised version of a `TyCon` from the global environment. +-- |Determine the vectorised version of a `TyCon`. The vectorisation map in the global environment +-- contains a vectorised version if the original `TyCon` embeds any parallel arrays. -- lookupTyCon :: TyCon -> VM (Maybe TyCon) lookupTyCon tc - | isUnLiftedTyCon tc || isTupleTyCon tc - = return (Just tc) - | otherwise = readGEnv $ \env -> lookupNameEnv (global_tycons env) (tyConName tc) -- |Add a mapping between plain and vectorised `TyCon`s to the global environment. diff --git a/compiler/vectorise/Vectorise/Monad/InstEnv.hs b/compiler/vectorise/Vectorise/Monad/InstEnv.hs index fc12ee567c..95546bf503 100644 --- a/compiler/vectorise/Vectorise/Monad/InstEnv.hs +++ b/compiler/vectorise/Vectorise/Monad/InstEnv.hs @@ -1,5 +1,6 @@ module Vectorise.Monad.InstEnv - ( lookupInst + ( existsInst + , lookupInst , lookupFamInst ) where @@ -21,6 +22,14 @@ import Util #include "HsVersions.h" +-- Check whether a unique class instance for a given class and type arguments exists. +-- +existsInst :: Class -> [Type] -> VM Bool +existsInst cls tys + = do { instEnv <- readGEnv global_inst_env + ; return $ either (const False) (const True) (lookupUniqueInstEnv instEnv cls tys) + } + -- Look up the dfun of a class instance. -- -- The match must be unique —i.e., match exactly one instance— but the @@ -64,6 +73,6 @@ lookupFamInst tycon tys [(fam_inst, rep_tys)] -> return ( fam_inst, rep_tys) _other -> do dflags <- getDynFlags - cantVectorise dflags "VectMonad.lookupFamInst: not found: " + cantVectorise dflags "Vectorise.Monad.InstEnv.lookupFamInst: not found: " (ppr $ mkTyConApp tycon tys) } diff --git a/compiler/vectorise/Vectorise/Monad/Local.hs b/compiler/vectorise/Vectorise/Monad/Local.hs index 8b3c1dcf19..5415c5691d 100644 --- a/compiler/vectorise/Vectorise/Monad/Local.hs +++ b/compiler/vectorise/Vectorise/Monad/Local.hs @@ -44,20 +44,24 @@ updLEnv f = VM $ \_ genv lenv -> return (Yes genv (f lenv) ()) -- localV :: VM a -> VM a localV p - = do env <- readLEnv id - x <- p - setLEnv env - return x + = do + { env <- readLEnv id + ; x <- p + ; setLEnv env + ; return x + } -- |Perform a computation in an empty local environment. -- closedV :: VM a -> VM a closedV p - = do env <- readLEnv id - setLEnv (emptyLocalEnv { local_bind_name = local_bind_name env }) - x <- p - setLEnv env - return x + = do + { env <- readLEnv id + ; setLEnv (emptyLocalEnv { local_bind_name = local_bind_name env }) + ; x <- p + ; setLEnv env + ; return x + } -- |Get the name of the local binding currently being vectorised. -- diff --git a/compiler/vectorise/Vectorise/Type/Classify.hs b/compiler/vectorise/Vectorise/Type/Classify.hs index 0cab706cf4..e1cd43ac3c 100644 --- a/compiler/vectorise/Vectorise/Type/Classify.hs +++ b/compiler/vectorise/Vectorise/Type/Classify.hs @@ -13,10 +13,12 @@ -- types. As '([::])' is being vectorised, any type constructor whose definition involves -- '([::])', either directly or indirectly, will be vectorised. -module Vectorise.Type.Classify ( - classifyTyCons -) where +module Vectorise.Type.Classify + ( classifyTyCons + ) +where +import NameSet import UniqSet import UniqFM import DataCon @@ -29,7 +31,7 @@ import Digraph -- |From a list of type constructors, extract those that can be vectorised, returning them in two -- sets, where the first result list /must be/ vectorised and the second result list /need not be/ --- vectorised. The third result list are those type constructors that we cannot convert (either +-- vectorised. The third result list are those type constructors that we cannot convert (either -- because they use language extensions or because they dependent on type constructors for which -- no vectorised version is available). @@ -37,28 +39,40 @@ import Digraph -- -- * tycons which have converted versions are mapped to 'True' -- * tycons which are not changed by vectorisation are mapped to 'False' --- * tycons which can't be converted are not elements of the map +-- * tycons which haven't been converted (because they can't or weren't vectorised) are not +-- elements of the map -- -classifyTyCons :: UniqFM Bool -- ^type constructor conversion status - -> [TyCon] -- ^type constructors that need to be classified - -> ([TyCon], [TyCon], [TyCon]) -- ^tycons to be converted & not to be converted -classifyTyCons convStatus tcs = classify [] [] [] convStatus (tyConGroups tcs) +classifyTyCons :: UniqFM Bool -- ^type constructor vectorisation status + -> NameSet -- ^tycons involving parallel arrays + -> [TyCon] -- ^type constructors that need to be classified + -> ( [TyCon] -- to be converted + , [TyCon] -- need not be converted (but could be) + , [TyCon] -- can't be converted, but involve parallel arrays + , [TyCon] -- can't be converted and have no parallel arrays + ) +classifyTyCons convStatus parTyCons tcs = classify [] [] [] [] convStatus parTyCons (tyConGroups tcs) where - classify conv keep ignored _ [] = (conv, keep, ignored) - classify conv keep ignored cs ((tcs, ds) : rs) + classify conv keep par novect _ _ [] = (conv, keep, par, novect) + classify conv keep par novect cs pts ((tcs, ds) : rs) | can_convert && must_convert - = classify (tcs ++ conv) keep ignored (cs `addListToUFM` [(tc, True) | tc <- tcs]) rs + = classify (tcs ++ conv) keep par novect (cs `addListToUFM` [(tc, True) | tc <- tcs]) pts' rs | can_convert - = classify conv (tcs ++ keep) ignored (cs `addListToUFM` [(tc, False) | tc <- tcs]) rs + = classify conv (tcs ++ keep) par novect (cs `addListToUFM` [(tc, False) | tc <- tcs]) pts' rs + | has_parr + = classify conv keep (tcs ++ par) novect cs pts' rs | otherwise - = classify conv keep (tcs ++ ignored) cs rs + = classify conv keep par (tcs ++ novect) cs pts' rs where refs = ds `delListFromUniqSet` tcs + + pts' | has_parr = pts `addListToNameSet` map tyConName tcs + | otherwise = pts can_convert = (isNullUFM (refs `minusUFM` cs) && all convertable tcs) || isShowClass tcs must_convert = foldUFM (||) False (intersectUFM_C const cs refs) && (not . isShowClass $ tcs) + has_parr = any ((`elemNameSet` parTyCons) . tyConName) . eltsUFM $ refs -- We currently admit Haskell 2011-style data and newtype declarations as well as type -- constructors representing classes. diff --git a/compiler/vectorise/Vectorise/Type/Env.hs b/compiler/vectorise/Vectorise/Type/Env.hs index 0051d072a4..faa80a8629 100644 --- a/compiler/vectorise/Vectorise/Type/Env.hs +++ b/compiler/vectorise/Vectorise/Type/Env.hs @@ -32,7 +32,9 @@ import Id import MkId import NameEnv import NameSet +import UniqFM import OccName +import Unique import Util import Outputable @@ -47,69 +49,85 @@ import Data.List -- Note [Pragmas to vectorise tycons] -- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -- --- VECTORISE pragmas for type constructors cover three different flavours of vectorising data type +-- All imported type constructors that are not mapped to a vectorised type in the vectorisation map +-- (possibly because the defining module was not compiled with vectorisation) may be used in scalar +-- code encapsulated in vectorised code. If a such a type constructor 'T' is a member of the +-- 'Scalar' class (and hence also of 'PData' and 'PRepr'), it may also be used in vectorised code, +-- where 'T' represents itself, but the representation of 'T' still remains opaque in vectorised +-- code (i.e., it can only be used in scalar code). +-- +-- An example is the treatment of 'Int'. 'Int's can be used in vectorised code and remain unchanged +-- by vectorisation. However, the representation of 'Int' by the 'I#' data constructor wrapping an +-- 'Int#' is not exposed in vectorised code. Instead, computations involving the representation need +-- to be confined to scalar code. +-- +-- VECTORISE pragmas for type constructors cover four different flavours of vectorising data type -- constructors: -- --- (1) Data type constructor 'T' that may be used in vectorised code, where 'T' represents itself, --- but the representation of 'T' is opaque in vectorised code. +-- (1) Data type constructor 'T' that together with its constructors 'Cn' may be used in vectorised +-- code, where 'T' and the 'Cn' are automatically vectorised in the same manner as data types +-- declared in a vectorised module. This includes the case where the vectoriser determines that +-- the original representation of 'T' may be used in vectorised code (as it does not embed any +-- parallel arrays.) This case is for type constructors that are *imported* from a non- +-- vectorised module, but that we want to use with full vectorisation support. -- --- An example is the treatment of 'Int'. 'Int's can be used in vectorised code and remain --- unchanged by vectorisation. However, the representation of 'Int' by the 'I#' data --- constructor wrapping an 'Int#' is not exposed in vectorised code. Instead, computations --- involving the representation need to be confined to scalar code. +-- An example is the treatment of 'Ordering' and '[]'. The former remains unchanged by +-- vectorisation, whereas the latter is fully vectorised. -- --- 'PData' and 'PRepr' instances need to be explicitly supplied for 'T' (they are not generated --- by the vectoriser). +-- 'PData' and 'PRepr' instances are automatically generated by the vectoriser. -- --- Type constructors declared with {-# VECTORISE SCALAR type T #-} are treated in this manner. --- (The vectoriser never treats a type constructor automatically in this manner.) +-- Type constructors declared with {-# VECTORISE type T #-} are treated in this manner. -- -- (2) Data type constructor 'T' that may be used in vectorised code, where 'T' is represented by an --- explicitly given 'Tv', but the representation of 'T' is opaque in vectorised code. +-- explicitly given 'Tv', but the representation of 'T' is opaque in vectorised code (i.e., the +-- constructors of 'T' may not occur in vectorised code). -- --- An example is the treatment of '[::]'. '[::]'s can be used in vectorised code and is --- vectorised to 'PArray'. However, the representation of '[::]' is not exposed in vectorised --- code. Instead, computations involving the representation need to be confined to scalar code. +-- An example is the treatment of '[::]'. The type '[::]' can be used in vectorised code and is +-- vectorised to 'PArray'. However, the representation of '[::]' is not exposed in vectorised +-- code. Instead, computations involving the representation need to be confined to scalar code. -- -- 'PData' and 'PRepr' instances need to be explicitly supplied for 'T' (they are not generated -- by the vectoriser). -- --- Type constructors declared with {-# VECTORISE SCALAR type T = T' #-} are treated in this +-- Type constructors declared with {-# VECTORISE type T = Tv #-} are treated in this manner -- manner. (The vectoriser never treats a type constructor automatically in this manner.) -- --- (3) Data type constructor 'T' that together with its constructors 'Cn' may be used in vectorised --- code, where 'T' and the 'Cn' are automatically vectorised in the same manner as data types --- declared in a vectorised module. This includes the case where the vectoriser determines that --- the original representation of 'T' may be used in vectorised code (as it does not embed any --- parallel arrays.) This case is for type constructors that are *imported* from a non- --- vectorised module, but that we want to use with full vectorisation support. +-- (3) Data type constructor 'T' that does not contain any parallel arrays and has explicitly +-- provided 'PData' and 'PRepr' instances (and maybe also a 'Scalar' instance), which together +-- with the type's constructors 'Cn' may be used in vectorised code. The type 'T' and its +-- constructors 'Cn' are represented by themselves in vectorised code. -- --- An example is the treatment of 'Ordering' and '[]'. The former remains unchanged by --- vectorisation, whereas the latter is fully vectorised. - --- 'PData' and 'PRepr' instances are automatically generated by the vectoriser. +-- An example is 'Bool', which is represented by itself in vectorised code (as it cannot embed +-- any parallel arrays). However, we do not want any automatic generation of class and family +-- instances, which is why Case (1) does not apply. -- --- Type constructors declared with {-# VECTORISE type T #-} are treated in this manner. +-- 'PData' and 'PRepr' instances need to be explicitly supplied for 'T' (they are not generated +-- by the vectoriser). +-- +-- Type constructors declared with {-# VECTORISE SCALAR type T #-} are treated in this manner. -- --- (4) Data type constructor 'T' that together with its constructors 'Cn' may be used in vectorised --- code, where 'T' is represented by an explicitly given 'Tv' whose constructors 'Cvn' represent --- the original constructors in vectorised code. As a special case, we can have 'Tv = T' +-- (4) Data type constructor 'T' that does not contain any parallel arrays and that, in vectorised +-- code, is represented by an explicitly given 'Tv', but the representation of 'T' is opaque in +-- vectorised code and 'T' is regarded to be scalar — i.e., it may be used in encapsulated +-- scalar subcomputations. -- --- An example is the treatment of 'Bool', which is represented by itself in vectorised code --- (as it cannot embed any parallel arrays). However, we do not want any automatic generation --- of class and family instances, which is why Case (3) does not apply. +-- An example is the treatment of '(->)'. Types '(->)' can be used in vectorised code and are +-- vectorised to '(:->)'. However, the representation of '(->)' is not exposed in vectorised +-- code. Instead, computations involving the representation need to be confined to scalar code +-- and may be part of encapsulated scalar computations. -- -- 'PData' and 'PRepr' instances need to be explicitly supplied for 'T' (they are not generated -- by the vectoriser). -- --- Type constructors declared with {-# VECTORISE type T = T' #-} are treated in this manner. +-- Type constructors declared with {-# VECTORISE SCALAR type T = Tv #-} are treated in this +-- manner. (The vectoriser never treats a type constructor automatically in this manner.) -- -- In addition, we have also got a single pragma form for type classes: {-# VECTORISE class C #-}. -- It implies that the class type constructor may be used in vectorised code together with its data -- constructor. We generally produce a vectorised version of the data type and data constructor. -- We do not generate 'PData' and 'PRepr' instances for class type constructors. This pragma is the --- default for all type classes declared in this module, but the pragma can also be used explitly on --- imported classes. +-- default for all type classes declared in a vectorised module, but the pragma can also be used +-- explitly on imported classes. -- Note [Vectorising classes] -- ~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -147,38 +165,36 @@ vectTypeEnv :: [TyCon] -- Type constructors defined in this mod vectTypeEnv tycons vectTypeDecls vectClassDecls = do { traceVt "** vectTypeEnv" $ ppr tycons - ; let -- {-# VECTORISE SCALAR type T -#} (imported and local tycons) - localAbstractTyCons = [tycon | VectType True tycon Nothing <- vectTypeDecls] - - -- {-# VECTORISE type T -#} (ONLY the imported tycons) + ; let -- {-# VECTORISE type T -#} (ONLY the imported tycons) impVectTyCons = ( [tycon | VectType False tycon Nothing <- vectTypeDecls] ++ [tycon | VectClass tycon <- vectClassDecls]) \\ tycons + + -- {-# VECTORISE [SCALAR] type T = Tv -#} (imported & local tycons with an /RHS/) + vectTyConsWithRHS = [ (tycon, rhs, isScalar) + | VectType isScalar tycon (Just rhs) <- vectTypeDecls] - -- {-# VECTORISE [SCALAR] type T = T' -#} (imported and local tycons) - vectTyConsWithRHS = [ (tycon, rhs, isAbstract) - | VectType isAbstract tycon (Just rhs) <- vectTypeDecls] + -- {-# VECTORISE SCALAR type T -#} (imported & local /scalar/ tycons without an RHS) + scalarTyConsNoRHS = [tycon | VectType True tycon Nothing <- vectTypeDecls] - -- filter VECTORISE SCALAR tycons and VECTORISE tycons with explicit rhses - vectSpecialTyConNames = mkNameSet . map tyConName $ - localAbstractTyCons ++ map fst3 vectTyConsWithRHS + -- Check that is not a VECTORISE SCALAR tycon nor VECTORISE tycons with explicit rhs? + vectSpecialTyConNames = mkNameSet . map tyConName $ + scalarTyConsNoRHS ++ map fst3 vectTyConsWithRHS notVectSpecialTyCon tc = not $ (tyConName tc) `elemNameSet` vectSpecialTyConNames - -- Build a map containing all vectorised type constructor. If they are scalar, they are - -- mapped to 'False' (vectorised type constructor == original type constructor). - ; allScalarTyConNames <- globalScalarTyCons -- covers both current and imported modules + -- Build a map containing all vectorised type constructor. If the vectorised type + -- constructor differs from the original one, then it is mapped to 'True'; if they are + -- both the same, then it maps to 'False'. ; vectTyCons <- globalVectTyCons - ; let vectTyConBase = mapNameEnv (const True) vectTyCons -- by default fully vectorised + ; let vectTyConBase = mapUFM_Directly isDistinct vectTyCons -- 'True' iff tc /= V[[tc]] + isDistinct u tc = u /= getUnique tc vectTyConFlavour = vectTyConBase `plusNameEnv` mkNameEnv [ (tyConName tycon, True) | (tycon, _, _) <- vectTyConsWithRHS] `plusNameEnv` - mkNameEnv [ (tcName, False) -- original representation - | tcName <- nameSetToList allScalarTyConNames] - `plusNameEnv` mkNameEnv [ (tyConName tycon, False) -- original representation - | tycon <- localAbstractTyCons] + | tycon <- scalarTyConsNoRHS] -- Split the list of 'TyCons' into the ones (1) that we must vectorise and those (2) @@ -189,11 +205,15 @@ vectTypeEnv tycons vectTypeDecls vectClassDecls -- these are being handled separately. NB: Some type constructors may be marked SCALAR -- /and/ have an explicit right-hand side.) -- - -- Furthermore, 'drop_tcs' are those type constructors that we cannot vectorise. - ; let maybeVectoriseTyCons = filter notVectSpecialTyCon tycons ++ impVectTyCons - (conv_tcs, keep_tcs, drop_tcs) = classifyTyCons vectTyConFlavour maybeVectoriseTyCons + -- Furthermore, 'par_tcs' and 'drop_tcs' are those type constructors that we cannot + -- vectorise, and of those, only the 'par_tcs' involve parallel arrays. + ; parallelTyCons <- globalParallelTyCons + ; let maybeVectoriseTyCons = filter notVectSpecialTyCon tycons ++ impVectTyCons + (conv_tcs, keep_tcs, par_tcs, drop_tcs) + = classifyTyCons vectTyConFlavour parallelTyCons maybeVectoriseTyCons - ; traceVt " VECT SCALAR : " $ ppr localAbstractTyCons + ; traceVt " VECT SCALAR : " $ ppr (scalarTyConsNoRHS ++ + [tycon | (tycon, _, True) <- vectTyConsWithRHS]) ; traceVt " VECT [class] : " $ ppr impVectTyCons ; traceVt " VECT with rhs : " $ ppr (map fst3 vectTyConsWithRHS) ; traceVt " -- after classification (local and VECT [class] tycons) --" empty @@ -203,26 +223,22 @@ vectTypeEnv tycons vectTypeDecls vectClassDecls -- warn the user about unvectorised type constructors ; let explanation = ptext (sLit "(They use unsupported language extensions") $$ ptext (sLit "or depend on type constructors that are not vectorised)") - drop_tcs_nosyn = filter (not . isSynTyCon) drop_tcs + drop_tcs_nosyn = filter (not . isSynTyCon) (par_tcs ++ drop_tcs) ; unless (null drop_tcs_nosyn) $ emitVt "Warning: cannot vectorise these type constructors:" $ pprQuotedList drop_tcs_nosyn $$ explanation - ; mapM_ addGlobalScalarTyCon keep_tcs + ; mapM_ addParallelTyConAndCons $ conv_tcs ++ par_tcs ; let mapping = - -- Type constructors that we don't need to vectorise, use the same + -- Type constructors that we found we don't need to vectorise and those + -- declared VECTORISE SCALAR /without/ an explicit right-hand side, use the same -- representation in both unvectorised and vectorised code; they are not -- abstract. - [(tycon, tycon, False) | tycon <- keep_tcs] + [(tycon, tycon, False) | tycon <- keep_tcs ++ scalarTyConsNoRHS] -- We do the same for type constructors declared VECTORISE SCALAR /without/ - -- an explicit right-hand side, but ignore their representation (data - -- constructors) as they are abstract. - ++ [(tycon, tycon, True) | tycon <- localAbstractTyCons] - -- Type constructors declared VECTORISE /with/ an explicit vectorised type, - -- we map from the original to the given type; whether they are abstract depends - -- on whether the vectorisation declaration was SCALAR. - ++ vectTyConsWithRHS + -- an explicit right-hand side + ++ [(tycon, vTycon, True) | (tycon, vTycon, _) <- vectTyConsWithRHS] ; syn_tcs <- catMaybes <$> mapM defTyConDataCons mapping -- Vectorise all the data type declarations that we can and must vectorise (enter the @@ -263,17 +279,15 @@ vectTypeEnv tycons vectTypeDecls vectClassDecls do { defTyConPAs (zipLazy vect_tcs dfuns) -- Query the 'PData' instance type constructors for type constructors that have a - -- VECTORISE pragma with an explicit right-hand side (this is Item (4) of - -- "Note [Pragmas to vectorise tycons]" above). - ; let (withRHS_non_abstract, vwithRHS_non_abstract) - = unzip [(tycon, vtycon) | (tycon, vtycon, False) <- vectTyConsWithRHS] - ; pdata_withRHS_tcs <- mapM pdataReprTyConExact withRHS_non_abstract + -- VECTORISE SCALAR type pragma without an explicit right-hand side (this is Item + -- (3) of "Note [Pragmas to vectorise tycons]" above). + ; pdata_scalar_tcs <- mapM pdataReprTyConExact scalarTyConsNoRHS -- Build workers for all vectorised data constructors (except abstract ones) ; sequence_ $ - zipWith3 vectDataConWorkers (orig_tcs ++ withRHS_non_abstract) - (vect_tcs ++ vwithRHS_non_abstract) - (pdata_tcs ++ pdata_withRHS_tcs) + zipWith3 vectDataConWorkers (orig_tcs ++ scalarTyConsNoRHS) + (vect_tcs ++ scalarTyConsNoRHS) + (pdata_tcs ++ pdata_scalar_tcs) -- Build a 'PA' dictionary for all type constructors (except abstract ones & those -- defined with an explicit right-hand side where the dictionary is user-supplied) @@ -295,6 +309,12 @@ vectTypeEnv tycons vectTypeDecls vectClassDecls } where fst3 (a, _, _) = a + + addParallelTyConAndCons tycon + = do + { addGlobalParallelTyCon tycon + ; mapM_ addGlobalParallelVar . concatMap dataConImplicitIds . tyConDataCons $ tycon + } -- Add a mapping from the original to vectorised type constructor to the vectorisation map. -- Unless the type constructor is abstract, also mappings from the orignal's data constructors @@ -307,21 +327,22 @@ vectTypeEnv tycons vectTypeDecls vectClassDecls -- right type constructor when reading vectorisation information from interface files). -- defTyConDataCons (origTyCon, vectTyCon, isAbstract) - = do { canonName <- mkLocalisedName mkVectTyConOcc origName - ; if origName == vectName -- Case (1) - || vectName == canonName -- Case (2) - then do - { defTyCon origTyCon vectTyCon -- T --> vT - ; defDataCons -- Ci --> vCi - ; return Nothing - } - else do -- Case (3) - { let synTyCon = mkSyn canonName (mkTyConTy vectTyCon) -- type S = vT - ; defTyCon origTyCon synTyCon -- T --> S - ; defDataCons -- Ci --> vCi - ; return $ Just synTyCon - } - } + = do + { canonName <- mkLocalisedName mkVectTyConOcc origName + ; if origName == vectName -- Case (1) + || vectName == canonName -- Case (2) + then do + { defTyCon origTyCon vectTyCon -- T --> vT + ; defDataCons -- Ci --> vCi + ; return Nothing + } + else do -- Case (3) + { let synTyCon = mkSyn canonName (mkTyConTy vectTyCon) -- type S = vT + ; defTyCon origTyCon synTyCon -- T --> S + ; defDataCons -- Ci --> vCi + ; return $ Just synTyCon + } + } where origName = tyConName origTyCon vectName = tyConName vectTyCon @@ -343,9 +364,9 @@ buildTyConPADict vect_tc prepr_ax pdata_tc pdatas_tc = tyConRepr vect_tc >>= buildPADict vect_tc prepr_ax pdata_tc pdatas_tc -- Produce a custom-made worker for the data constructors of a vectorised data type. This includes --- all data constructors that may be used in vetcorised code — i.e., all data constructors of data --- types other than scalar ones. Also adds a mapping from the original to vectorised worker into --- the vectorisation map. +-- all data constructors that may be used in vectorised code — i.e., all data constructors of data +-- types with 'VECTORISE [SCALAR] type' pragmas with an explicit right-hand side. Also adds a mapping +-- from the original to vectorised worker into the vectorisation map. -- -- FIXME: It's not nice that we need create a special worker after the data constructors has -- already been constructed. Also, I don't think the worker is properly added to the data diff --git a/compiler/vectorise/Vectorise/Type/Type.hs b/compiler/vectorise/Vectorise/Type/Type.hs index a7ec86a296..ebb09e663c 100644 --- a/compiler/vectorise/Vectorise/Type/Type.hs +++ b/compiler/vectorise/Vectorise/Type/Type.hs @@ -14,21 +14,16 @@ import TcType import Type import TypeRep import TyCon -import Outputable import Control.Monad import Control.Applicative import Data.Maybe --- | Vectorise a type constructor. + +-- |Vectorise a type constructor. Unless there is a vectorised version (stripped of embedded +-- parallel arrays), the vectorised version is the same as the original. -- vectTyCon :: TyCon -> VM TyCon -vectTyCon tc - | isFunTyCon tc = builtin closureTyCon - | isBoxedTupleTyCon tc = return tc - | isUnLiftedTyCon tc = return tc - | otherwise - = maybeCantVectoriseM "Tycon not vectorised: " (ppr tc) - $ lookupTyCon tc +vectTyCon tc = maybe tc id <$> lookupTyCon tc -- |Produce the vectorised and lifted versions of a type. -- diff --git a/compiler/vectorise/Vectorise/Utils.hs b/compiler/vectorise/Vectorise/Utils.hs index c5f1cb7cb1..fafce7a67d 100644 --- a/compiler/vectorise/Vectorise/Utils.hs +++ b/compiler/vectorise/Vectorise/Utils.hs @@ -17,7 +17,7 @@ module Vectorise.Utils ( combinePD, liftPD, -- * Scalars - zipScalars, scalarClosure, + isScalar, zipScalars, scalarClosure, -- * Naming newLocalVar @@ -137,20 +137,29 @@ liftPD x -- Scalars -------------------------------------------------------------------- +isScalar :: Type -> VM Bool +isScalar ty + = do + { scalar <- builtin scalarClass + ; existsInst scalar [ty] + } + zipScalars :: [Type] -> Type -> VM CoreExpr zipScalars arg_tys res_ty - = do - scalar <- builtin scalarClass - (dfuns, _) <- mapAndUnzipM (\ty -> lookupInst scalar [ty]) ty_args - zipf <- builtin (scalarZip $ length arg_tys) - return $ Var zipf `mkTyApps` ty_args `mkApps` map Var dfuns + = do + { scalar <- builtin scalarClass + ; (dfuns, _) <- mapAndUnzipM (\ty -> lookupInst scalar [ty]) ty_args + ; zipf <- builtin (scalarZip $ length arg_tys) + ; return $ Var zipf `mkTyApps` ty_args `mkApps` map Var dfuns + } where ty_args = arg_tys ++ [res_ty] scalarClosure :: [Type] -> Type -> CoreExpr -> CoreExpr -> VM CoreExpr scalarClosure arg_tys res_ty scalar_fun array_fun = do - ctr <- builtin (closureCtrFun $ length arg_tys) - pas <- mapM paDictOfType (init arg_tys) - return $ Var ctr `mkTyApps` (arg_tys ++ [res_ty]) + { ctr <- builtin (closureCtrFun $ length arg_tys) + ; pas <- mapM paDictOfType (init arg_tys) + ; return $ Var ctr `mkTyApps` (arg_tys ++ [res_ty]) `mkApps` (pas ++ [scalar_fun, array_fun]) + } |