diff options
Diffstat (limited to 'compiler/vectorise/Vectorise/Exp.hs')
-rw-r--r-- | compiler/vectorise/Vectorise/Exp.hs | 60 |
1 files changed, 40 insertions, 20 deletions
diff --git a/compiler/vectorise/Vectorise/Exp.hs b/compiler/vectorise/Vectorise/Exp.hs index 83c87100a2..ffc1b9caf2 100644 --- a/compiler/vectorise/Vectorise/Exp.hs +++ b/compiler/vectorise/Vectorise/Exp.hs @@ -31,7 +31,7 @@ import DataCon import TyCon import TcType import Type -import TypeRep +import TyCoRep import Var import VarEnv import VarSet @@ -363,7 +363,7 @@ vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err) | v == pAT_ERROR_ID = do { (vty, lty) <- vectAndLiftType ty - ; return (mkCoreApps (Var v) [Type vty, err'], mkCoreApps (Var v) [Type lty, err']) + ; return (mkCoreApps (Var v) [Type (getLevity "vectExpr" vty), Type vty, err'], mkCoreApps (Var v) [Type lty, err']) } where err' = deAnnotate err @@ -712,11 +712,11 @@ vectScalarDFun var ; return $ mkLams (tvs ++ vThetaBndr) vBody } where - ty = varType var - (tvs, theta, pty) = tcSplitSigmaTy ty -- 'theta' is the instance context - (cls, tys) = tcSplitDFunHead pty -- 'pty' is the instance head - selIds = classAllSelIds cls - dataCon = classDataCon cls + ty = varType var + (tvs, theta, pty) = tcSplitSigmaTy ty -- 'theta' is the instance context + (cls, tys) = tcSplitDFunHead pty -- 'pty' is the instance head + selIds = classAllSelIds cls + dataCon = classDataCon cls -- Build a value of the dictionary before vectorisation from original, unvectorised type and an -- expression computing the vectorised dictionary. @@ -1039,7 +1039,7 @@ unlessVIParrExpr e1 e2 = e1 `unlessVIParr` vectAvoidInfoOf e2 -- * The first argument is the set of free, local variables whose evaluation may entail parallelism. -- vectAvoidInfo :: VarSet -> CoreExprWithFVs -> VM CoreExprWithVectInfo -vectAvoidInfo pvs ce@(fvs, AnnVar v) +vectAvoidInfo pvs ce@(_, AnnVar v) = do { gpvs <- globalParallelVars ; vi <- if v `elemVarSet` pvs || v `elemVarSet` gpvs @@ -1052,15 +1052,19 @@ vectAvoidInfo pvs ce@(fvs, AnnVar v) ; return ((udfmToUfm fvs, vi), AnnVar v) } + where + fvs = freeVarsOf ce -vectAvoidInfo _pvs ce@(fvs, AnnLit lit) +vectAvoidInfo _pvs ce@(_, AnnLit lit) = do { vi <- vectAvoidInfoTypeOf ce ; viTrace ce vi [] ; return ((udfmToUfm fvs, vi), AnnLit lit) } + where + fvs = freeVarsOf ce -vectAvoidInfo pvs ce@(fvs, AnnApp e1 e2) +vectAvoidInfo pvs ce@(_, AnnApp e1 e2) = do { ceVI <- vectAvoidInfoTypeOf ce ; eVI1 <- vectAvoidInfo pvs e1 @@ -1069,8 +1073,10 @@ vectAvoidInfo pvs ce@(fvs, AnnApp e1 e2) -- ; viTrace ce vi [eVI1, eVI2] ; return ((udfmToUfm fvs, vi), AnnApp eVI1 eVI2) } + where + fvs = freeVarsOf ce -vectAvoidInfo pvs (fvs, AnnLam var body) +vectAvoidInfo pvs ce@(_, AnnLam var body) = do { bodyVI <- vectAvoidInfo pvs body ; varVI <- vectAvoidInfoType $ varType var @@ -1078,8 +1084,10 @@ vectAvoidInfo pvs (fvs, AnnLam var body) -- ; viTrace ce vi [bodyVI] ; return ((udfmToUfm fvs, vi), AnnLam var bodyVI) } + where + fvs = freeVarsOf ce -vectAvoidInfo pvs ce@(fvs, AnnLet (AnnNonRec var e) body) +vectAvoidInfo pvs ce@(_, AnnLet (AnnNonRec var e) body) = do { ceVI <- vectAvoidInfoTypeOf ce ; eVI <- vectAvoidInfo pvs e @@ -1096,8 +1104,10 @@ vectAvoidInfo pvs ce@(fvs, AnnLet (AnnNonRec var e) body) -- ; viTrace ce vi [eVI, bodyVI] ; return ((udfmToUfm fvs, vi), AnnLet (AnnNonRec var eVI) bodyVI) } + where + fvs = freeVarsOf ce -vectAvoidInfo pvs ce@(fvs, AnnLet (AnnRec bnds) body) +vectAvoidInfo pvs ce@(_, AnnLet (AnnRec bnds) body) = do { ceVI <- vectAvoidInfoTypeOf ce ; bndsVI <- mapM (vectAvoidInfoBnd pvs) bnds @@ -1119,6 +1129,7 @@ vectAvoidInfo pvs ce@(fvs, AnnLet (AnnRec bnds) body) } } where + fvs = freeVarsOf ce vectAvoidInfoBnd pvs (var, e) = (var,) <$> vectAvoidInfo pvs e isVIParrBnd (var, eVI) @@ -1127,7 +1138,7 @@ vectAvoidInfo pvs ce@(fvs, AnnLet (AnnRec bnds) body) ; return $ isVIParr eVI && not isScalarTy } -vectAvoidInfo pvs ce@(fvs, AnnCase e var ty alts) +vectAvoidInfo pvs ce@(_, AnnCase e var ty alts) = do { ceVI <- vectAvoidInfoTypeOf ce ; eVI <- vectAvoidInfo pvs e @@ -1138,6 +1149,7 @@ vectAvoidInfo pvs ce@(fvs, AnnCase e var ty alts) ; return ((udfmToUfm fvs, vi), AnnCase eVI var ty altsVI) } where + fvs = freeVarsOf ce vectAvoidInfoAlt scrutIsPar (con, bndrs, e) = do { allScalar <- allScalarVarType bndrs @@ -1146,24 +1158,31 @@ vectAvoidInfo pvs ce@(fvs, AnnCase e var ty alts) ; (con, bndrs,) <$> vectAvoidInfo altPvs e } -vectAvoidInfo pvs (fvs, AnnCast e (fvs_ann, ann)) +vectAvoidInfo pvs ce@(_, AnnCast e (fvs_ann, ann)) = do { eVI <- vectAvoidInfo pvs e - ; return ((udfmToUfm fvs, vectAvoidInfoOf eVI) - , AnnCast eVI ((udfmToUfm fvs_ann, VISimple), ann)) + ; return ((udfmToUfm fvs, vectAvoidInfoOf eVI), AnnCast eVI ((udfmToUfm $ freeVarsOfAnn fvs_ann, VISimple), ann)) } + where + fvs = freeVarsOf ce -vectAvoidInfo pvs (fvs, AnnTick tick e) +vectAvoidInfo pvs ce@(_, AnnTick tick e) = do { eVI <- vectAvoidInfo pvs e ; return ((udfmToUfm fvs, vectAvoidInfoOf eVI), AnnTick tick eVI) } + where + fvs = freeVarsOf ce -vectAvoidInfo _pvs (fvs, AnnType ty) +vectAvoidInfo _pvs ce@(_, AnnType ty) = return ((udfmToUfm fvs, VISimple), AnnType ty) + where + fvs = freeVarsOf ce -vectAvoidInfo _pvs (fvs, AnnCoercion coe) +vectAvoidInfo _pvs ce@(_, AnnCoercion coe) = return ((udfmToUfm fvs, VISimple), AnnCoercion coe) + where + fvs = freeVarsOf ce -- Compute vectorisation avoidance information for a type. -- @@ -1212,6 +1231,7 @@ maybeParrTy ty then return True else or <$> mapM maybeParrTy ts } + -- must be a Named ForAllTy because anon ones respond to splitTyConApp_maybe maybeParrTy (ForAllTy _ ty) = maybeParrTy ty maybeParrTy _ = return False |