diff options
Diffstat (limited to 'compiler/vectorise/Vectorise/Exp.hs')
-rw-r--r-- | compiler/vectorise/Vectorise/Exp.hs | 1048 |
1 files changed, 539 insertions, 509 deletions
diff --git a/compiler/vectorise/Vectorise/Exp.hs b/compiler/vectorise/Vectorise/Exp.hs index 8c5ef0045d..88f123210b 100644 --- a/compiler/vectorise/Vectorise/Exp.hs +++ b/compiler/vectorise/Vectorise/Exp.hs @@ -3,10 +3,9 @@ -- |Vectorisation of expressions. module Vectorise.Exp - ( -- * Vectorise polymorphic expressions with special cases for right-hand sides of particular - -- variable bindings - vectPolyExpr - , vectDictExpr + ( -- * Vectorise right-hand sides of toplevel bindings + vectTopExpr + , vectTopExprs , vectScalarFun , vectScalarDFun ) @@ -32,393 +31,404 @@ import DataCon import TyCon import TcType import Type -import PrelNames +import TypeRep import Var import VarEnv import VarSet +import NameSet import Id import BasicTypes( isStrongLoopBreaker ) import Literal -import TysWiredIn import TysPrim import Outputable import FastString +import DynFlags +import Util +import MonadUtils + import Control.Monad -import Control.Applicative import Data.Maybe import Data.List -import TcRnMonad (doptM) -import DynFlags -import Util -- Main entry point to vectorise expressions ----------------------------------- --- |Vectorise a polymorphic expression. +-- |Vectorise a polymorphic expression that forms a *non-recursive* binding. +-- +-- Return 'Nothing' if the expression is scalar; otherwise, the first component of the result +-- (which is of type 'Bool') indicates whether the expression is parallel (i.e., whether it is +-- tagged as 'VIParr'). -- --- If not yet available, precompute vectorisation avoidance information before vectorising. If --- the vectorisation avoidance optimisation is enabled, also use the vectorisation avoidance --- information to encapsulated subexpression that do not need to be vectorised. +-- We have got the non-recursive case as a special case as it doesn't require to compute +-- vectorisation information twice. -- -vectPolyExpr :: Bool -> [Var] -> CoreExprWithFVs -> Maybe VITree - -> VM (Inline, Bool, VExpr) - -- precompute vectorisation avoidance information (and possibly encapsulated subexpressions) -vectPolyExpr loop_breaker recFns expr Nothing +vectTopExpr :: Var -> CoreExpr -> VM (Maybe (Bool, Inline, CoreExpr)) +vectTopExpr var expr = do - { vectAvoidance <- liftDs $ doptM Opt_AvoidVect - ; vi <- vectAvoidInfo expr - ; (expr', vi') <- - if vectAvoidance - then do - { (expr', vi') <- encapsulateScalars vi expr - ; traceVt "vectPolyExpr encapsulated:" (ppr $ deAnnotate expr') - ; return (expr', vi') - } - else return (expr, vi) - ; vectPolyExpr loop_breaker recFns expr' (Just vi') + { exprVI <- encapsulateScalars <=< vectAvoidInfo emptyVarSet . freeVars $ expr + ; if isVIEncaps exprVI + then + return Nothing + else do + { vExpr <- closedV $ + inBind var $ + vectAnnPolyExpr False exprVI + ; inline <- computeInline exprVI + ; return $ Just (isVIParr exprVI, inline, vectorised vExpr) + } } - -- traverse through ticks -vectPolyExpr loop_breaker recFns (_, AnnTick tickish expr) (Just (VITNode _ [vit])) - = do - { (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr (Just vit) - ; return (inline, isScalarFn, vTick tickish expr') - } +-- Compute the inlining hint for the right-hand side of a top-level binding. +-- +computeInline :: CoreExprWithVectInfo -> VM Inline +computeInline ((_, VIDict), _) = return $ DontInline +computeInline (_, AnnTick _ expr) = computeInline expr +computeInline expr@(_, AnnLam _ _) = Inline <$> polyArity tvs + where + (tvs, _) = collectAnnTypeBinders expr +computeInline _expr = return $ DontInline - -- collect and vectorise type abstractions; then, descent into the body -vectPolyExpr loop_breaker recFns expr (Just vit) - = do - { let (tvs, mono) = collectAnnTypeBinders expr - vit' = stripLevels (length tvs) vit - ; arity <- polyArity tvs - ; polyAbstract tvs $ \args -> - do - { (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono vit' - ; return (addInlineArity inline arity, isScalarFn, mapVect (mkLams $ tvs ++ args) mono') - } +-- |Vectorise a recursive group of top-level polymorphic expressions. +-- +-- Return 'Nothing' if the expression group is scalar; otherwise, the first component of the result +-- (which is of type 'Bool') indicates whether the expressions are parallel (i.e., whether they are +-- tagged as 'VIParr'). +-- +vectTopExprs :: [(Var, CoreExpr)] -> VM (Maybe (Bool, [(Inline, CoreExpr)])) +vectTopExprs binds + = do + { exprVIs <- mapM (vectAvoidAndEncapsulate emptyVarSet) exprs + ; if all isVIEncaps exprVIs + then + return Nothing + else do + { (areVIParr, vExprs) <- unzip <$> mapM encapsulateAndVect binds + ; return $ Just (or areVIParr, vExprs) + } } where - stripLevels 0 vit = vit - stripLevels n (VITNode _ [vit]) = stripLevels (n - 1) vit - stripLevels _ vit = pprPanic "vectPolyExpr: stripLevels:" (text (show vit)) + (vars, exprs) = unzip binds + + vectAvoidAndEncapsulate pvs = encapsulateScalars <=< vectAvoidInfo pvs . freeVars + + encapsulateAndVect (var, expr) + = do + { exprVI <- vectAvoidAndEncapsulate (mkVarSet vars) expr + ; vExpr <- closedV $ + inBind var $ + vectAnnPolyExpr (isStrongLoopBreaker $ idOccInfo var) exprVI + ; inline <- computeInline exprVI + ; return (isVIParr exprVI, (inline, vectorised vExpr)) + } + +-- |Vectorise a polymorphic expression annotated with vectorisation information. +-- +-- The special case of dictionary functions is currently handled separately. (Would be neater to +-- integrate them, though!) +-- +vectAnnPolyExpr :: Bool -> CoreExprWithVectInfo -> VM VExpr +vectAnnPolyExpr loop_breaker (_, AnnTick tickish expr) + -- traverse through ticks + = vTick tickish <$> vectAnnPolyExpr loop_breaker expr +vectAnnPolyExpr loop_breaker expr + | isVIDict expr + -- special case the right-hand side of dictionary functions + = (, undefined) <$> vectDictExpr (deAnnotate expr) + | otherwise + -- collect and vectorise type abstractions; then, descent into the body + = polyAbstract tvs $ \args -> + mapVect (mkLams $ tvs ++ args) <$> vectFnExpr False loop_breaker mono + where + (tvs, mono) = collectAnnTypeBinders expr -- Encapsulate every purely sequential subexpression of a (potentially) parallel expression into a --- into a lambda abstraction over all its free variables followed by the corresponding application --- to those variables. We can, then, avoid the vectorisation of the ensapsulated subexpressions. +-- lambda abstraction over all its free variables followed by the corresponding application to those +-- variables. We can, then, avoid the vectorisation of the ensapsulated subexpressions. -- -- Preconditions: -- -- * All free variables and the result type must be /simple/ types. --- * The expression is sufficientlt complex (top warrant special treatment). For now, that is +-- * The expression is sufficiently complex (to warrant special treatment). For now, that is -- every expression that is not constant and contains at least one operation. -- -encapsulateScalars :: VITree -> CoreExprWithFVs -> VM (CoreExprWithFVs, VITree) -encapsulateScalars vit ce@(_, AnnType _ty) - = return (ce, vit) - -encapsulateScalars vit ce@(_, AnnVar _v) - = return (ce, vit) - -encapsulateScalars vit ce@(_, AnnLit _) - = return (ce, vit) - -encapsulateScalars (VITNode vi [vit]) (fvs, AnnTick tck expr) - = do { (extExpr, vit') <- encapsulateScalars vit expr - ; return ((fvs, AnnTick tck extExpr), VITNode vi [vit']) - } - -encapsulateScalars _ (_fvs, AnnTick _tck _expr) - = panic "encapsulateScalar AnnTick doesn't match up" - -encapsulateScalars (VITNode vi [vit]) ce@(fvs, AnnLam bndr expr) - = do { varsS <- varsSimple fvs - ; case (vi, varsS) of - (VISimple, True) -> do { let (e', vit') = liftSimple vit ce - ; return (e', vit') - } - _ -> do { (extExpr, vit') <- encapsulateScalars vit expr - ; return ((fvs, AnnLam bndr extExpr), VITNode vi [vit']) - } - } - -encapsulateScalars _ (_fvs, AnnLam _bndr _expr) - = panic "encapsulateScalars AnnLam doesn't match up" - -encapsulateScalars vt@(VITNode vi [vit1, vit2]) ce@(fvs, AnnApp ce1 ce2) - = do { varsS <- varsSimple fvs - ; case (vi, varsS) of - (VISimple, True) -> do { let (e', vt') = liftSimple vt ce - -- ; checkTreeAnnM vt' e' - -- ; traceVt "Passed checkTree test!!" (ppr $ deAnnotate e') - ; return (e', vt') - } - _ -> do { (etaCe1, vit1') <- encapsulateScalars vit1 ce1 - ; (etaCe2, vit2') <- encapsulateScalars vit2 ce2 - ; return ((fvs, AnnApp etaCe1 etaCe2), VITNode vi [vit1', vit2']) - } - } - -encapsulateScalars _ (_fvs, AnnApp _ce1 _ce2) - = panic "encapsulateScalars AnnApp doesn't match up" - -encapsulateScalars vt@(VITNode vi (scrutVit : altVits)) ce@(fvs, AnnCase scrut bndr ty alts) - = do { varsS <- varsSimple fvs - ; case (vi, varsS) of - (VISimple, True) -> return $ liftSimple vt ce - _ -> do { (extScrut, scrutVit') <- encapsulateScalars scrutVit scrut - ; extAltsVits <- zipWithM expAlt altVits alts - ; let (extAlts, altVits') = unzip extAltsVits - ; return ((fvs, AnnCase extScrut bndr ty extAlts), VITNode vi (scrutVit': altVits')) - } - } +encapsulateScalars :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo +encapsulateScalars ce@(_, AnnType _ty) + = return ce +encapsulateScalars ce@((_, VISimple), AnnVar v) + | isFunTy . varType $ v -- NB: diverts from the paper: encapsulate scalar function types + = liftSimpleAndCase ce +encapsulateScalars ce@(_, AnnVar _v) + = return ce +encapsulateScalars ce@(_, AnnLit _) + = return ce +encapsulateScalars ((fvs, vi), AnnTick tck expr) + = do + { encExpr <- encapsulateScalars expr + ; return ((fvs, vi), AnnTick tck encExpr) + } +encapsulateScalars ce@((fvs, vi), AnnLam bndr expr) + = do + { varsS <- allScalarVarTypeSet fvs + ; case (vi, varsS) of + (VISimple, True) -> liftSimpleAndCase ce + _ -> do + { encExpr <- encapsulateScalars expr + ; return ((fvs, vi), AnnLam bndr encExpr) + } + } +encapsulateScalars ce@((fvs, vi), AnnApp ce1 ce2) + = do + { varsS <- allScalarVarTypeSet fvs + ; case (vi, varsS) of + (VISimple, True) -> liftSimpleAndCase ce + _ -> do + { encCe1 <- encapsulateScalars ce1 + ; encCe2 <- encapsulateScalars ce2 + ; return ((fvs, vi), AnnApp encCe1 encCe2) + } + } +encapsulateScalars ce@((fvs, vi), AnnCase scrut bndr ty alts) + = do + { varsS <- allScalarVarTypeSet fvs + ; case (vi, varsS) of + (VISimple, True) -> liftSimpleAndCase ce + _ -> do + { encScrut <- encapsulateScalars scrut + ; encAlts <- mapM encAlt alts + ; return ((fvs, vi), AnnCase encScrut bndr ty encAlts) + } + } where - expAlt vt (con, bndrs, expr) - = do { (extExpr, vt') <- encapsulateScalars vt expr - ; return ((con, bndrs, extExpr), vt') - } - -encapsulateScalars _ (_fvs, AnnCase _scrut _bndr _ty _alts) - = panic "encapsulateScalars AnnCase doesn't match up" - -encapsulateScalars vt@(VITNode vi [vt1, vt2]) ce@(fvs, AnnLet (AnnNonRec bndr expr1) expr2) - = do { varsS <- varsSimple fvs - ; case (vi, varsS) of - (VISimple, True) -> return $ liftSimple vt ce - _ -> do { (extExpr1, vt1') <- encapsulateScalars vt1 expr1 - ; (extExpr2, vt2') <- encapsulateScalars vt2 expr2 - ; return ((fvs, AnnLet (AnnNonRec bndr extExpr1) extExpr2), VITNode vi [vt1', vt2']) - } - } - -encapsulateScalars _ (_fvs, AnnLet (AnnNonRec _bndr _expr1) _expr2) - = panic "encapsulateScalars AnnLet nonrec doesn't match up" - -encapsulateScalars vt@(VITNode vi (vtB : vtBnds)) ce@(fvs, AnnLet (AnnRec bndngs) expr) - = do { varsS <- varsSimple fvs - ; case (vi, varsS) of - (VISimple, True) -> return $ liftSimple vt ce - _ -> do { extBndsVts <- zipWithM expBndg vtBnds bndngs - ; let (extBnds, vtBnds') = unzip extBndsVts - ; (extExpr, vtB') <- encapsulateScalars vtB expr - ; let vt' = VITNode vi (vtB':vtBnds') - ; return ((fvs, AnnLet (AnnRec extBnds) extExpr), vt') - } - } - where - expBndg vit (bndr, expr) - = do { (extExpr, vit') <- encapsulateScalars vit expr - ; return ((bndr, extExpr), vit') - } - -encapsulateScalars _ (_fvs, AnnLet (AnnRec _) _expr2) - = panic "encapsulateScalars AnnLet rec doesn't match up" - -encapsulateScalars (VITNode vi [vit]) (fvs, AnnCast expr coercion) - = do { (extExpr, vit') <- encapsulateScalars vit expr - ; return ((fvs, AnnCast extExpr coercion), VITNode vi [vit']) - } - -encapsulateScalars _ (_fvs, AnnCast _expr _coercion) - = panic "encapsulateScalars AnnCast rec doesn't match up" - -encapsulateScalars _ _ - = panic "encapsulateScalars case not handled" + encAlt (con, bndrs, expr) = (con, bndrs,) <$> encapsulateScalars expr +encapsulateScalars ce@((fvs, vi), AnnLet (AnnNonRec bndr expr1) expr2) + = do + { varsS <- allScalarVarTypeSet fvs + ; case (vi, varsS) of + (VISimple, True) -> liftSimpleAndCase ce + _ -> do + { encExpr1 <- encapsulateScalars expr1 + ; encExpr2 <- encapsulateScalars expr2 + ; return ((fvs, vi), AnnLet (AnnNonRec bndr encExpr1) encExpr2) + } + } +encapsulateScalars ce@((fvs, vi), AnnLet (AnnRec binds) expr) + = do + { varsS <- allScalarVarTypeSet fvs + ; case (vi, varsS) of + (VISimple, True) -> liftSimpleAndCase ce + _ -> do + { encBinds <- mapM encBind binds + ; encExpr <- encapsulateScalars expr + ; return ((fvs, vi), AnnLet (AnnRec encBinds) encExpr) + } + } + where + encBind (bndr, expr) = (bndr,) <$> encapsulateScalars expr +encapsulateScalars ((fvs, vi), AnnCast expr coercion) + = do + { encExpr <- encapsulateScalars expr + ; return ((fvs, vi), AnnCast encExpr coercion) + } +encapsulateScalars _ + = panic "Vectorise.Exp.encapsulateScalars: unknown constructor" --- Lambda-lift the given expression and apply it to the abstracted free variables. +-- Lambda-lift the given simple expression and apply it to the abstracted free variables. -- --- If the expression is a case expression scrutinising anything but a primitive type, then lift +-- If the expression is a case expression scrutinising anything, but a scalar type, then lift -- each alternative individually. -- -liftSimple :: VITree -> CoreExprWithFVs -> (CoreExprWithFVs, VITree) -liftSimple (VITNode vi (scrutVit : altVits)) (fvs, AnnCase expr bndr t alts) - | Just (c,_) <- splitTyConApp_maybe (exprType $ deAnnotate $ expr), - (not $ elem c [boolTyCon, intTyCon, doubleTyCon, floatTyCon]) -- FIXME: shouldn't be hardcoded - = ((fvs, AnnCase expr bndr t alts'), VITNode vi (scrutVit : altVits')) - where - (alts', altVits') = unzip $ map (\(ac,bndrs, (alt, avi)) -> ((ac,bndrs,alt), avi)) $ - zipWith (\(ac, bndrs, aex) -> \altVi -> (ac, bndrs, liftSimple altVi aex)) alts altVits - -liftSimple viTree ae@(fvs, _annEx) - = (mkAnnApps (mkAnnLams ae vars) vars, viTree') +liftSimpleAndCase :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo +liftSimpleAndCase aexpr@((fvs, _vi), AnnCase expr bndr t alts) + = do + { vi <- vectAvoidInfoTypeOf expr + ; if (vi == VISimple) + then + return $ liftSimple aexpr -- if the scrutinee is scalar, we need no special treatment + else do + { alts' <- mapM (\(ac, bndrs, aexpr) -> (ac, bndrs,) <$> liftSimpleAndCase aexpr) alts + ; return ((fvs, vi), AnnCase expr bndr t alts') + } + } +liftSimpleAndCase aexpr = return $ liftSimple aexpr + +liftSimple :: CoreExprWithVectInfo -> CoreExprWithVectInfo +liftSimple ((fvs, vi), expr) + = ASSERT(vi == VISimple) + mkAnnApps (mkAnnLams vars fvs expr) vars where - mkViTreeLams (VITNode _ vits) [] = VITNode VIEncaps vits - mkViTreeLams vi (_:vs) = VITNode VIEncaps [mkViTreeLams vi vs] + vars = varSetElems fvs - mkViTreeApps vi [] = vi - mkViTreeApps vi (_:vs) = VITNode VISimple [mkViTreeApps vi vs, VITNode VISimple []] - - vars = varSetElems fvs - viTree' = mkViTreeApps (mkViTreeLams viTree vars) vars - - mkAnnLam :: bndr -> AnnExpr bndr VarSet -> AnnExpr' bndr VarSet - mkAnnLam bndr ce = AnnLam bndr ce - - mkAnnLams:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs - mkAnnLams (fv, aex') [] = (fv, aex') -- fv should be empty. check! - mkAnnLams (fv, aex') (v:vs) = mkAnnLams (delVarSet fv v, (mkAnnLam v ((delVarSet fv v), aex'))) vs - - mkAnnApp :: (AnnExpr bndr VarSet) -> Var -> (AnnExpr' bndr VarSet) - mkAnnApp aex v = AnnApp aex (unitVarSet v, (AnnVar v)) + mkAnnLams :: [Var] -> VarSet -> AnnExpr' Var (VarSet, VectAvoidInfo) -> CoreExprWithVectInfo + mkAnnLams [] fvs expr = ASSERT(isEmptyVarSet fvs) + ((emptyVarSet, VIEncaps), expr) + mkAnnLams (v:vs) fvs expr = mkAnnLams vs (fvs `delVarSet` v) (AnnLam v ((fvs, VIEncaps), expr)) - mkAnnApps:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs - mkAnnApps (fv, aex') [] = (fv, aex') - mkAnnApps ae (v:vs) = - let - (fv, aex') = mkAnnApps ae vs - in (extendVarSet fv v, mkAnnApp (fv, aex') v) + mkAnnApps :: CoreExprWithVectInfo -> [Var] -> CoreExprWithVectInfo + mkAnnApps aexpr [] = aexpr + mkAnnApps aexpr (v:vs) = mkAnnApps (mkAnnApp aexpr v) vs + + mkAnnApp :: CoreExprWithVectInfo -> Var -> CoreExprWithVectInfo + mkAnnApp aexpr@((fvs, _vi), _expr) v + = ((fvs `extendVarSet` v, VISimple), AnnApp aexpr ((unitVarSet v, VISimple), AnnVar v)) -- |Vectorise an expression. -- -vectExpr :: CoreExprWithFVs -> VITree -> VM VExpr --- vectExpr e vi | not (checkTree vi (deAnnotate e)) --- = pprPanic "vectExpr" (ppr $ deAnnotate e) - -vectExpr (_, AnnVar v) _ +vectExpr :: CoreExprWithVectInfo -> VM VExpr + +vectExpr (_, AnnVar v) = vectVar v -vectExpr (_, AnnLit lit) _ +vectExpr (_, AnnLit lit) = vectConst $ Lit lit -vectExpr e@(_, AnnLam bndr _) vt - | isId bndr = (\(_, _, ve) -> ve) <$> vectFnExpr True False [] e vt - | otherwise = do dflags <- getDynFlags - cantVectorise dflags "Unexpected type lambda (vectExpr)" (ppr (deAnnotate e)) +vectExpr e@(_, AnnLam bndr _) + | isId bndr = vectFnExpr True False e + | otherwise + = do + { dflags <- getDynFlags + ; cantVectorise dflags "Unexpected type lambda (vectExpr)" $ ppr (deAnnotate e) + } -- SPECIAL CASE: Vectorise/lift 'patError @ ty err' by only vectorising/lifting the type 'ty'; -- its only purpose is to abort the program, but we need to adjust the type to keep CoreLint -- happy. -- FIXME: can't be do this with a VECTORISE pragma on 'pAT_ERROR_ID' now? -vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err) _ +vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err) | v == pAT_ERROR_ID - = do { (vty, lty) <- vectAndLiftType ty - ; return (mkCoreApps (Var v) [Type vty, err'], mkCoreApps (Var v) [Type lty, err']) - } + = do + { (vty, lty) <- vectAndLiftType ty + ; return (mkCoreApps (Var v) [Type vty, err'], mkCoreApps (Var v) [Type lty, err']) + } where err' = deAnnotate err -- type application (handle multiple consecutive type applications simultaneously to ensure the -- PA dictionaries are put at the right places) -vectExpr e@(_, AnnApp _ arg) (VITNode _ [_, _]) +vectExpr e@(_, AnnApp _ arg) | isAnnTypeArg arg = vectPolyApp e - - -- 'Int', 'Float', or 'Double' literal - -- FIXME: this needs to be generalised -vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit)) _ - | Just con <- isDataConId_maybe v - , is_special_con con + + -- Lifted literal +vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit)) + | Just _con <- isDataConId_maybe v = do - let vexpr = App (Var v) (Lit lit) - lexpr <- liftPD vexpr - return (vexpr, lexpr) - where - is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon] + { let vexpr = App (Var v) (Lit lit) + ; lexpr <- liftPD vexpr + ; return (vexpr, lexpr) + } -- value application (dictionary or user value) -vectExpr e@(_, AnnApp fn arg) (VITNode _ [vit1, vit2]) +vectExpr e@(_, AnnApp fn arg) | isPredTy arg_ty -- dictionary application (whose result is not a dictionary) = vectPolyApp e | otherwise -- user value - = do { -- vectorise the types - ; varg_ty <- vectType arg_ty - ; vres_ty <- vectType res_ty + = do + { -- vectorise the types + ; varg_ty <- vectType arg_ty + ; vres_ty <- vectType res_ty - -- vectorise the function and argument expression - ; vfn <- vectExpr fn vit1 - ; varg <- vectExpr arg vit2 + -- vectorise the function and argument expression + ; vfn <- vectExpr fn + ; varg <- vectExpr arg - -- the vectorised function is a closure; apply it to the vectorised argument - ; mkClosureApp varg_ty vres_ty vfn varg - } + -- the vectorised function is a closure; apply it to the vectorised argument + ; mkClosureApp varg_ty vres_ty vfn varg + } where (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn -vectExpr (_, AnnCase scrut bndr ty alts) vt +vectExpr (_, AnnCase scrut bndr ty alts) | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty , isAlgTyCon tycon - = vectAlgCase tycon ty_args scrut bndr ty alts vt - | otherwise = do dflags <- getDynFlags - cantVectorise dflags "Can't vectorise expression" (ppr scrut_ty) + = vectAlgCase tycon ty_args scrut bndr ty alts + | otherwise + = do + { dflags <- getDynFlags + ; cantVectorise dflags "Can't vectorise expression (no algebraic type constructor)" $ + ppr scrut_ty + } where scrut_ty = exprType (deAnnotate scrut) -vectExpr (_, AnnLet (AnnNonRec bndr rhs) body) (VITNode _ [vt1, vt2]) +vectExpr (_, AnnLet (AnnNonRec bndr rhs) body) = do - vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs (Just vt1) - (vbndr, vbody) <- vectBndrIn bndr (vectExpr body vt2) - return $ vLet (vNonRec vbndr vrhs) vbody + { vrhs <- localV $ + inBind bndr $ + vectAnnPolyExpr False rhs + ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body) + ; return $ vLet (vNonRec vbndr vrhs) vbody + } -vectExpr (_, AnnLet (AnnRec bs) body) (VITNode _ (vtB : vtBnds)) +vectExpr (_, AnnLet (AnnRec bs) body) = do - (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs + { (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs $ liftM2 (,) - (zipWith3M vect_rhs bndrs rhss vtBnds) - (vectExpr body vtB) - return $ vLet (vRec vbndrs vrhss) vbody + (zipWithM vect_rhs bndrs rhss) + (vectExpr body) + ; return $ vLet (vRec vbndrs vrhss) vbody + } where (bndrs, rhss) = unzip bs - vect_rhs bndr rhs vt = localV - . inBind bndr - . liftM (\(_,_,z)->z) - $ vectPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) [] rhs (Just vt) - zipWith3M f xs ys zs = zipWithM (\x -> \(y,z) -> (f x y z)) xs (zip ys zs) + vect_rhs bndr rhs = localV $ + inBind bndr $ + vectAnnPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) rhs -vectExpr (_, AnnTick tickish expr) (VITNode _ [vit]) - = liftM (vTick tickish) (vectExpr expr vit) +vectExpr (_, AnnTick tickish expr) + = vTick tickish <$> vectExpr expr -vectExpr (_, AnnType ty) _ - = liftM vType (vectType ty) +vectExpr (_, AnnType ty) + = vType <$> vectType ty -vectExpr e vit = do dflags <- getDynFlags - cantVectorise dflags "Can't vectorise expression (vectExpr)" (ppr (deAnnotate e) $$ text (" " ++ show vit)) +vectExpr e + = do + { dflags <- getDynFlags + ; cantVectorise dflags "Can't vectorise expression (vectExpr)" $ ppr (deAnnotate e) + } --- |Vectorise an expression that *may* have an outer lambda abstraction. +-- |Vectorise an expression that *may* have an outer lambda abstraction. If the expression is marked +-- as encapsulated ('VIEncaps'), vectorise it as a scalar computation (using a generalised scalar +-- zip). -- -- We do not handle type variables at this point, as they will already have been stripped off by --- 'vectPolyExpr'. We also only have to worry about one set of dictionary arguments as we (1) only +-- 'vectPolyExpr'. We also only have to worry about one set of dictionary arguments as we (1) only -- deal with Haskell 2011 and (2) class selectors are vectorised elsewhere. -- -vectFnExpr :: Bool -- ^ If we process the RHS of a binding, whether that binding should - -- be inlined - -> Bool -- ^ Whether the binding is a loop breaker - -> [Var] -- ^ Names of function in same recursive binding group - -> CoreExprWithFVs -- ^ Expression to vectorise; must have an outer `AnnLam` - -> VITree - -> VM (Inline, Bool, VExpr) --- vectFnExpr _ _ _ e vi | not (checkTree vi (deAnnotate e)) --- = pprPanic "vectFnExpr" (ppr $ deAnnotate e) -vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr body) vt@(VITNode _ [vt']) - -- predicate abstraction: leave as a normal abstraction, but vectorise the predicate type +vectFnExpr :: Bool -- ^If we process the RHS of a binding, whether that binding + -- should be inlined + -> Bool -- ^Whether the binding is a loop breaker + -> CoreExprWithVectInfo -- ^Expression to vectorise; must have an outer `AnnLam` + -> VM VExpr +vectFnExpr inline loop_breaker expr@(_ann, AnnLam bndr body) + -- predicate abstraction: leave as a normal abstraction, but vectorise the predicate type | isId bndr && isPredTy (idType bndr) - = do { vBndr <- vectBndr bndr - ; (inline, isScalarFn, vbody) <- vectFnExpr inline loop_breaker recFns body vt' - ; return (inline, isScalarFn, mapVect (mkLams [vectorised vBndr]) vbody) - } - -- non-predicate abstraction: vectorise (try to vectorise as a scalar computation) + = do + { vBndr <- vectBndr bndr + ; vbody <- vectFnExpr inline loop_breaker body + ; return $ mapVect (mkLams [vectorised vBndr]) vbody + } + -- non-predicate abstraction: vectorise as a scalar computation + | isId bndr && isVIEncaps expr + = vectScalarFun . deAnnotate $ expr + -- non-predicate abstraction: vectorise as a non-scalar computation | isId bndr - = mark DontInline True (vectScalarFunMaybe (deAnnotate expr) vt) - `orElseV` - mark inlineMe False (vectLam inline loop_breaker expr vt) -vectFnExpr _ _ _ e vt - -- not an abstraction: vectorise as a vanilla expression - = mark DontInline False $ vectExpr e vt - -mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a) -mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) } + = vectLam inline loop_breaker expr +vectFnExpr _ _ expr + -- not an abstraction: vectorise as a vanilla expression + = vectExpr expr -- |Vectorise type and dictionary applications. -- -- These are always headed by a variable (as we don't support higher-rank polymorphism), but may --- involve two sets of type variables and dictionaries. Consider, +-- involve two sets of type variables and dictionaries. Consider, -- -- > class C a where -- > m :: D b => b -> a -- -- The type of 'm' is 'm :: forall a. C a => forall b. D b => b -> a'. -- -vectPolyApp :: CoreExprWithFVs -> VM VExpr +vectPolyApp :: CoreExprWithVectInfo -> VM VExpr vectPolyApp e0 = case e4 of (_, AnnVar var) @@ -530,21 +540,6 @@ vectDictExpr (Coercion coe) -- instead they become dictionaries of vectorised methods). We treat them differently, though see -- "Note [Scalar dfuns]" in 'Vectorise'. -- -vectScalarFunMaybe :: CoreExpr -- ^ Expression to be vectorised - -> VITree -- ^ Vectorisation information - -> VM VExpr -vectScalarFunMaybe expr (VITNode VIEncaps _) = vectScalarFun expr -vectScalarFunMaybe _expr _ = noV $ ptext (sLit "not a scalar function") - --- |Vectorise an expression of functional type by lifting it by an application of a member of the --- zipWith family (i.e., 'map', 'zipWith', zipWith3', etc.) This is only a valid strategy if the --- function does not contain parallel subcomputations and has only 'Scalar' types in its result and --- arguments — this is a predcondition for calling this function. --- --- Dictionary functions are also scalar functions (as dictionaries themselves are not vectorised, --- instead they become dictionaries of vectorised methods). We treat them differently, though see --- "Note [Scalar dfuns]" in 'Vectorise'. --- vectScalarFun :: CoreExpr -> VM VExpr vectScalarFun expr = do @@ -673,12 +668,11 @@ unVectDict ty e -- variables are passed explicit (as conventional arguments) into the body during closure -- construction. -- -vectLam :: Bool -- ^ When the RHS of a binding, whether that binding should be inlined. - -> Bool -- ^ Whether the binding is a loop breaker. - -> CoreExprWithFVs -- ^ Body of abstraction. - -> VITree +vectLam :: Bool -- ^ Should the RHS of a binding be inlined? + -> Bool -- ^ Whether the binding is a loop breaker. + -> CoreExprWithVectInfo -- ^ Body of abstraction. -> VM VExpr -vectLam inline loop_breaker expr@(fvs, AnnLam _ _) vi +vectLam inline loop_breaker expr@((fvs, _vi), AnnLam _ _) = do { let (bndrs, body) = collectAnnValBinders expr -- grab the in-scope type variables @@ -706,18 +700,13 @@ vectLam inline loop_breaker expr@(fvs, AnnLam _ _) vi . hoistPolyVExpr tyvars vfvs_dict' (maybe_inline arity) $ do { -- generate the vectorised body of the lambda abstraction ; lc <- builtin liftingContext - ; let viBody = stripLams expr vi - -- ; checkTreeAnnM vi expr - ; (vbndrs, vbody) <- vectBndrsIn (fvs_nondict ++ bndrs) (vectExpr body viBody) + ; (vbndrs, vbody) <- vectBndrsIn (fvs_nondict ++ bndrs) $ vectExpr body ; vbody' <- break_loop lc res_ty vbody ; return $ vLams lc vbndrs vbody' } } where - stripLams (_, AnnLam _ e) (VITNode _ [vt]) = stripLams e vt - stripLams _ vi = vi - maybe_inline n | inline = Inline n | otherwise = DontInline @@ -735,7 +724,7 @@ vectLam inline loop_breaker expr@(fvs, AnnLam _ _) vi (LitAlt (mkMachInt 0), [], empty)]) } | otherwise = return (ve, le) -vectLam _ _ _ _ = panic "vectLam" +vectLam _ _ _ = panic "Vectorise.Exp.vectLam: not a lambda" -- Vectorise an algebraic case expression. -- @@ -754,31 +743,31 @@ vectLam _ _ _ _ = panic "vectLam" -- -- FIXME: this is too lazy -vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs-> Var -> Type - -> [(AltCon, [Var], CoreExprWithFVs)] -> VITree +vectAlgCase :: TyCon -> [Type] -> CoreExprWithVectInfo -> Var -> Type + -> [(AltCon, [Var], CoreExprWithVectInfo)] -> VM VExpr -vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)] (VITNode _ (scrutVit : [altVit])) +vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)] = do - vscrut <- vectExpr scrut scrutVit + vscrut <- vectExpr scrut (vty, lty) <- vectAndLiftType ty - (vbndr, vbody) <- vectBndrIn bndr (vectExpr body altVit) + (vbndr, vbody) <- vectBndrIn bndr (vectExpr body) return $ vCaseDEFAULT vscrut vbndr vty lty vbody -vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)] (VITNode _ (scrutVit : [altVit])) +vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)] = do - vscrut <- vectExpr scrut scrutVit + vscrut <- vectExpr scrut (vty, lty) <- vectAndLiftType ty - (vbndr, vbody) <- vectBndrIn bndr (vectExpr body altVit) + (vbndr, vbody) <- vectBndrIn bndr (vectExpr body) return $ vCaseDEFAULT vscrut vbndr vty lty vbody -vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)] (VITNode _ (scrutVit : [altVit])) +vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)] = do (vty, lty) <- vectAndLiftType ty - vexpr <- vectExpr scrut scrutVit + vexpr <- vectExpr scrut (vbndr, (vbndrs, (vect_body, lift_body))) <- vect_scrut_bndr . vectBndrsIn bndrs - $ vectExpr body altVit + $ vectExpr body let (vect_bndrs, lift_bndrs) = unzip vbndrs (vscrut, lscrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr) vect_dc <- maybeV dataConErr (lookupDataCon dc) @@ -796,9 +785,9 @@ vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)] (VITNode _ dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc) -vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits)) +vectAlgCase tycon _ty_args scrut bndr ty alts = do - vect_tc <- maybeV tyConErr (lookupTyCon tycon) + vect_tc <- vectTyCon tycon (vty, lty) <- vectAndLiftType ty let arity = length (tyConDataCons vect_tc) @@ -807,10 +796,10 @@ vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits)) let sel = Var sel_bndr (vbndr, valts) <- vect_scrut_bndr - $ mapM (proc_alt arity sel vty lty) (zip alts' altVits) + $ mapM (proc_alt arity sel vty lty) alts' let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts - vexpr <- vectExpr scrut scrutVit + vexpr <- vectExpr scrut (vect_scrut, lift_scrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr) let (vect_bodies, lift_bodies) = unzip vbodies @@ -829,8 +818,6 @@ vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits)) return . vLet (vNonRec vbndr vexpr) $ (vect_case, lift_case) where - tyConErr = (text "vectAlgCase: type constructor not vectorised" <+> ppr tycon) - vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut") | otherwise = vectBndrIn bndr @@ -842,12 +829,12 @@ vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits)) cmp _ DEFAULT = GT cmp _ _ = panic "vectAlgCase/cmp" - proc_alt arity sel _ lty ((DataAlt dc, bndrs, body), vi) + proc_alt arity sel _ lty (DataAlt dc, bndrs, body@((fvs_body, _), _)) = do vect_dc <- maybeV dataConErr (lookupDataCon dc) let ntag = dataConTagZ vect_dc tag = mkDataConTag vect_dc - fvs = freeVarsOf body `delVarSetList` bndrs + fvs = fvs_body `delVarSetList` bndrs sel_tags <- liftM (`App` sel) (builtin (selTags arity)) lc <- builtin liftingContext @@ -860,7 +847,7 @@ vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits)) binds <- mapM (pack_var (Var lc) sel_tags tag) . filter isLocalId $ varSetElems fvs - (ve, le) <- vectExpr body vi + (ve, le) <- vectExpr body return (ve, Case (elems `App` sel) lc lty [(DEFAULT, [], (mkLets (concat binds) le))]) -- empty <- emptyPD vty @@ -892,9 +879,6 @@ vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits)) _ -> return [] -vectAlgCase tycon _ty_args _scrut _bndr _ty _alts (VITNode _ _) - = pprPanic "vectAlgCase (mismatched node information)" (ppr tycon) - -- Support to compute information for vectorisation avoidance ------------------ @@ -905,202 +889,248 @@ data VectAvoidInfo = VIParr -- tree contains parallel computations | VISimple -- result type is scalar & no parallel subcomputation | VIComplex -- any result type, no parallel subcomputation | VIEncaps -- tree encapsulated by 'liftSimple' + | VIDict -- dictionary computation (never parallel) deriving (Eq, Show) --- Instead of integrating the vectorisation avoidance information into Core expression, we keep --- them in a separate tree (that structurally mirrors the Core expression that it annotates). +-- Core expression annotated with free variables and vectorisation-specific information. -- -data VITree = VITNode VectAvoidInfo [VITree] - deriving (Show) +type CoreExprWithVectInfo = AnnExpr Id (VarSet, VectAvoidInfo) --- Is any of the tree nodes a 'VIPArr' node? +-- Yield the type of an annotated core expression. -- -anyVIPArr :: [VITree] -> Bool -anyVIPArr = or . (map (\(VITNode vi _) -> vi == VIParr)) +annExprType :: AnnExpr Var ann -> Type +annExprType = exprType . deAnnotate --- Compute Core annotations to determine for which subexpressions we can avoid vectorisation +-- Project the vectorisation information from an annotated Core expression. -- --- FIXME: free scalar vars don't actually need to be passed through, since encapsulations makes sure, --- that there are no free variables in encapsulated lambda expressions -vectAvoidInfo :: CoreExprWithFVs -> VM VITree -vectAvoidInfo ce@(_, AnnVar v) - = do { vi <- vectAvoidInfoType $ exprType $ deAnnotate ce - ; viTrace ce vi [] - ; traceVt "vectAvoidInfo AnnVar" ((ppr v) <+> (ppr $ exprType $ deAnnotate ce)) - ; return $ VITNode vi [] - } +vectAvoidInfoOf :: CoreExprWithVectInfo -> VectAvoidInfo +vectAvoidInfoOf ((_, vi), _) = vi -vectAvoidInfo ce@(_, AnnLit _) - = do { vi <- vectAvoidInfoType $ exprType $ deAnnotate ce - ; viTrace ce vi [] - ; traceVt "vectAvoidInfo AnnLit" (ppr $ exprType $ deAnnotate ce) - ; return $ VITNode vi [] - } +-- Is this a 'VIParr' node? +-- +isVIParr :: CoreExprWithVectInfo -> Bool +isVIParr = (== VIParr) . vectAvoidInfoOf -vectAvoidInfo ce@(_, AnnApp e1 e2) - = do { vt1 <- vectAvoidInfo e1 - ; vt2 <- vectAvoidInfo e2 - ; vi <- if anyVIPArr [vt1, vt2] - then return VIParr - else vectAvoidInfoType $ exprType $ deAnnotate ce - ; viTrace ce vi [vt1, vt2] - ; return $ VITNode vi [vt1, vt2] - } +-- Is this a 'VIEncaps' node? +-- +isVIEncaps :: CoreExprWithVectInfo -> Bool +isVIEncaps = (== VIEncaps) . vectAvoidInfoOf -vectAvoidInfo ce@(_, AnnLam _var body) - = do { vt@(VITNode vi _) <- vectAvoidInfo body - ; viTrace ce vi [vt] - ; let resultVI | vi == VIParr = VIParr - | otherwise = VIComplex - ; return $ VITNode resultVI [vt] - } +-- Is this a 'VIDict' node? +-- +isVIDict :: CoreExprWithVectInfo -> Bool +isVIDict = (== VIDict) . vectAvoidInfoOf -vectAvoidInfo ce@(_, AnnLet (AnnNonRec _var expr) body) - = do { vtE <- vectAvoidInfo expr - ; vtB <- vectAvoidInfo body - ; vi <- if anyVIPArr [vtE, vtB] - then return VIParr - else vectAvoidInfoType $ exprType $ deAnnotate ce - ; viTrace ce vi [vtE, vtB] - ; return $ VITNode vi [vtE, vtB] - } +-- 'VIParr' if either argument is 'VIParr'; otherwise, the first argument. +-- +unlessVIParr :: VectAvoidInfo -> VectAvoidInfo -> VectAvoidInfo +unlessVIParr _ VIParr = VIParr +unlessVIParr vi _ = vi -vectAvoidInfo ce@(_, AnnLet (AnnRec bnds) body) - = do { let (_, exprs) = unzip bnds - ; vtBnds <- mapM (\e -> vectAvoidInfo e) exprs - ; if (anyVIPArr vtBnds) - then do { vtBnds' <- mapM (\e -> vectAvoidInfo e) exprs - ; vtB <- vectAvoidInfo body - ; return (VITNode VIParr (vtB: vtBnds')) - } - else do { vtB@(VITNode vib _) <- vectAvoidInfo body - ; ni <- if (vib == VIParr) - then return VIParr - else vectAvoidInfoType $ exprType $ deAnnotate ce - ; viTrace ce ni (vtB : vtBnds) - ; return $ VITNode ni (vtB : vtBnds) - } - } +-- 'VIParr' if either arguments vectorisation information is 'VIParr'; otherwise, the vectorisation +-- information of the first argument is produced. +-- +unlessVIParrExpr :: VectAvoidInfo -> CoreExprWithVectInfo -> VectAvoidInfo +infixl `unlessVIParrExpr` +unlessVIParrExpr e1 e2 = e1 `unlessVIParr` vectAvoidInfoOf e2 -vectAvoidInfo ce@(_, AnnCase expr _var _ty alts) - = do { vtExpr <- vectAvoidInfo expr - ; vtAlts <- mapM (\(_, _, e) -> vectAvoidInfo e) alts - ; ni <- if anyVIPArr (vtExpr : vtAlts) - then return VIParr - else vectAvoidInfoType $ exprType $ deAnnotate ce - ; viTrace ce ni (vtExpr : vtAlts) - ; return $ VITNode ni (vtExpr: vtAlts) - } +-- Compute Core annotations to determine for which subexpressions we can avoid vectorisation. +-- +-- * The first argument is the set of free, local variables whose evaluation may entail parallelism. +-- +vectAvoidInfo :: VarSet -> CoreExprWithFVs -> VM CoreExprWithVectInfo +vectAvoidInfo pvs ce@(fvs, AnnVar v) + = do + { gpvs <- globalParallelVars + ; vi <- if v `elemVarSet` pvs || v `elemVarSet` gpvs + then return VIParr + else vectAvoidInfoTypeOf ce + ; viTrace ce vi [] -vectAvoidInfo (_, AnnCast expr _) - = do { vt@(VITNode vi _) <- vectAvoidInfo expr - ; return $ VITNode vi [vt] - } + ; vit <- vectAvoidInfoTypeOf ce -- TEMPORARY + ; traceVt (" AnnVar: vectAvoidInfoTypeOf: " ++ show vit) empty -vectAvoidInfo (_, AnnTick _ expr) - = do { vt@(VITNode vi _) <- vectAvoidInfo expr - ; return $ VITNode vi [vt] - } + ; return ((fvs, vi), AnnVar v) + } -vectAvoidInfo (_, AnnType {}) - = return $ VITNode VISimple [] +vectAvoidInfo _pvs ce@(fvs, AnnLit lit) + = do + { vi <- vectAvoidInfoTypeOf ce + ; viTrace ce vi [] + ; return ((fvs, vi), AnnLit lit) + } -vectAvoidInfo (_, AnnCoercion {}) - = return $ VITNode VISimple [] +vectAvoidInfo pvs ce@(fvs, AnnApp e1 e2) + = do + { ceVI <- vectAvoidInfoTypeOf ce + ; eVI1 <- vectAvoidInfo pvs e1 + ; eVI2 <- vectAvoidInfo pvs e2 + ; let vi = ceVI `unlessVIParrExpr` eVI1 `unlessVIParrExpr` eVI2 + ; viTrace ce vi [eVI1, eVI2] + ; return ((fvs, vi), AnnApp eVI1 eVI2) + } + +vectAvoidInfo pvs ce@(fvs, AnnLam var body) + = do + { bodyVI <- vectAvoidInfo pvs body + ; varVI <- vectAvoidInfoType $ varType var + ; let vi = vectAvoidInfoOf bodyVI `unlessVIParr` varVI + ; viTrace ce vi [bodyVI] + ; return ((fvs, vi), AnnLam var bodyVI) + } + +vectAvoidInfo pvs ce@(fvs, AnnLet (AnnNonRec var e) body) + = do + { ceVI <- vectAvoidInfoTypeOf ce + ; eVI <- vectAvoidInfo pvs e + ; isScalarTy <- isScalar $ varType var + ; (bodyVI, vi) <- if isVIParr eVI && not isScalarTy + then do -- binding is parallel + { bodyVI <- vectAvoidInfo (fvs `extendVarSet` var) body + ; return (bodyVI, VIParr) + } + else do -- binding doesn't affect parallelism + { bodyVI <- vectAvoidInfo fvs body + ; return (bodyVI, ceVI `unlessVIParrExpr` bodyVI) + } + ; viTrace ce vi [eVI, bodyVI] + ; return ((fvs, vi), AnnLet (AnnNonRec var eVI) bodyVI) + } + +vectAvoidInfo pvs ce@(fvs, AnnLet (AnnRec bnds) body) + = do + { ceVI <- vectAvoidInfoTypeOf ce + ; bndsVI <- mapM (vectAvoidInfoBnd pvs) bnds + ; parrBndrs <- map fst <$> filterM isVIParrBnd bndsVI + ; if not . null $ parrBndrs + then do -- body may trigger parallelism via at least one binding + { new_pvs <- filterM ((not <$>) . isScalar . varType) parrBndrs + ; let extendedPvs = pvs `extendVarSetList` new_pvs + ; bndsVI <- mapM (vectAvoidInfoBnd extendedPvs) bnds + ; bodyVI <- vectAvoidInfo extendedPvs body + ; viTrace ce VIParr (map snd bndsVI ++ [bodyVI]) + ; return ((fvs, VIParr), AnnLet (AnnRec bndsVI) bodyVI) + } + else do -- demanded bindings cannot trigger parallelism + { bodyVI <- vectAvoidInfo pvs body + ; let vi = ceVI `unlessVIParrExpr` bodyVI + ; viTrace ce vi (map snd bndsVI ++ [bodyVI]) + ; return ((fvs, vi), AnnLet (AnnRec bndsVI) bodyVI) + } + } + where + vectAvoidInfoBnd pvs (var, e) = (var,) <$> vectAvoidInfo pvs e + + isVIParrBnd (var, eVI) + = do + { isScalarTy <- isScalar (varType var) + ; return $ isVIParr eVI && not isScalarTy + } + +vectAvoidInfo pvs ce@(fvs, AnnCase e var ty alts) + = do + { ceVI <- vectAvoidInfoTypeOf ce + ; eVI <- vectAvoidInfo pvs e + ; isScalarTy <- isScalar . annExprType $ e + ; altsVI <- mapM (vectAvoidInfoAlt (isVIParr eVI && not isScalarTy)) alts + ; allScalarBndrs <- anyM allScalarAltBndrs altsVI + ; let alteVIs = [eVI | (_, _, eVI) <- altsVI] + vi | isVIParr eVI && not allScalarBndrs = VIParr + | otherwise + = foldl unlessVIParrExpr ceVI alteVIs + ; viTrace ce vi (eVI : alteVIs) + ; return ((fvs, vi), AnnCase eVI var ty altsVI) + } + where + vectAvoidInfoAlt isScalarScrut (con, bndrs, e) = (con, bndrs,) <$> vectAvoidInfo altPvs e + where + altPvs | isScalarScrut = pvs + | otherwise = pvs `extendVarSetList` bndrs + + allScalarAltBndrs (_, bndrs, _) = allScalarVarType bndrs + +vectAvoidInfo pvs (fvs, AnnCast e (fvs_ann, ann)) + = do + { eVI <- vectAvoidInfo pvs e + ; return ((fvs, vectAvoidInfoOf eVI), AnnCast eVI ((fvs_ann, VISimple), ann)) + } + +vectAvoidInfo pvs (fvs, AnnTick tick e) + = do + { eVI <- vectAvoidInfo pvs e + ; return ((fvs, vectAvoidInfoOf eVI), AnnTick tick eVI) + } + +vectAvoidInfo _pvs (fvs, AnnType ty) + = return ((fvs, VISimple), AnnType ty) + +vectAvoidInfo _pvs (fvs, AnnCoercion coe) + = return ((fvs, VISimple), AnnCoercion coe) -- Compute vectorisation avoidance information for a type. -- vectAvoidInfoType :: Type -> VM VectAvoidInfo -vectAvoidInfoType ty - | maybeParrTy ty = return VIParr - | otherwise - = do { sType <- isSimpleType ty - ; if sType - then return VISimple - else return VIComplex - } +vectAvoidInfoType ty + | isPredTy ty + = return VIDict + | Just (arg, res) <- splitFunTy_maybe ty + = do + { argVI <- vectAvoidInfoType arg + ; resVI <- vectAvoidInfoType res + ; case (argVI, resVI) of + (VISimple, VISimple) -> return VISimple -- NB: diverts from the paper: scalar functions + (_ , VIDict) -> return VIDict + _ -> return $ VIComplex `unlessVIParr` argVI `unlessVIParr` resVI + } + | otherwise + = do + { parr <- maybeParrTy ty + ; if parr + then return VIParr + else do + { scalar <- isScalar ty + ; if scalar + then return VISimple + else return VIComplex + } } + +-- Compute vectorisation avoidance information for the type of a Core expression (with FVs). +-- +vectAvoidInfoTypeOf :: AnnExpr Var ann -> VM VectAvoidInfo +vectAvoidInfoTypeOf = vectAvoidInfoType . annExprType --- Checks whether the type might be a parallel array type. In particular, if the outermost --- constructor is a type family, we conservatively assume that it may be a parallel array type. +-- Checks whether the type might be a parallel array type. -- -maybeParrTy :: Type -> Bool +maybeParrTy :: Type -> VM Bool maybeParrTy ty - | Just ty' <- coreView ty = maybeParrTy ty' - | Just (tyCon, ts) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon - || or (map maybeParrTy ts) -maybeParrTy _ = False - --- FIXME: This should not be hardcoded. -isSimpleType :: Type -> VM Bool -isSimpleType ty - | Just (c, _cs) <- splitTyConApp_maybe ty - = return $ (tyConName c) `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName, floatTyConName] -{- - = do { globals <- globalScalarTyCons - ; traceVt ("isSimpleType " ++ (show (elemNameSet (tyConName c) globals ))) (ppr c) - ; return (elemNameSet (tyConName c) globals ) - } - -} - | Nothing <- splitTyConApp_maybe ty - = return False -isSimpleType ty - = pprPanic "Vectorise.Exp.isSimpleType not handled" (ppr ty) - -varsSimple :: VarSet -> VM Bool -varsSimple vs - = do { varTypes <- mapM isSimpleType $ map varType $ varSetElems vs - ; return $ and varTypes - } - -viTrace :: CoreExprWithFVs -> VectAvoidInfo -> [VITree] -> VM () -viTrace ce vi vTs - = traceVt ("vitrace " ++ (show vi) ++ "[" ++ (concat $ map (\(VITNode vi _) -> show vi ++ " ") vTs) ++"]") - (ppr $ deAnnotate ce) - + -- looking through newtypes + | Just ty' <- coreView ty + = (== VIParr) <$> vectAvoidInfoType ty' + -- decompose constructor applications + | Just (tc, ts) <- splitTyConApp_maybe ty + = do + { isParallel <- (tyConName tc `elemNameSet`) <$> globalParallelTyCons + ; if isParallel + then return True + else or <$> mapM maybeParrTy ts + } +maybeParrTy (ForAllTy _ ty) = maybeParrTy ty +maybeParrTy _ = return False -{- ----- Sanity check of the tree, for debugging only -checkTree :: VITree -> CoreExpr -> Bool -checkTree (VITNode _ []) (Type _ty) - = True - -checkTree (VITNode _ []) (Var _v) - = True - -checkTree (VITNode _ []) (Lit _) - = True - -checkTree (VITNode _ [vit]) (Tick _ expr) - = checkTree vit expr - -checkTree (VITNode _ [vit]) (Lam _ expr) - = checkTree vit expr - -checkTree (VITNode _ [vit1, vit2]) (App ce1 ce2) - = (checkTree vit1 ce1) && (checkTree vit2 ce2) - -checkTree (VITNode _ (scrutVit : altVits)) (Case scrut _ _ alts) - = (checkTree scrutVit scrut) && (and $ zipWith checkAlt altVits alts) - where - checkAlt vt (_, _, expr) = checkTree vt expr - -checkTree (VITNode _ [vt1, vt2]) (Let (NonRec _ expr1) expr2) - = (checkTree vt1 expr1) && (checkTree vt2 expr2) - -checkTree (VITNode _ (vtB : vtBnds)) (Let (Rec bndngs) expr) - = (and $ zipWith checkBndr vtBnds bndngs) && - (checkTree vtB expr) - where - checkBndr vt (_, e) = checkTree vt e - -checkTree (VITNode _ [vit]) (Cast expr _) - = checkTree vit expr +-- Are the types of all variables in the 'Scalar' class? +-- +allScalarVarType :: [Var] -> VM Bool +allScalarVarType vs = and <$> mapM (isScalar . varType) vs -checkTree _ _ = False +-- Are the types of all variables in the set in the 'Scalar' class? +-- +allScalarVarTypeSet :: VarSet -> VM Bool +allScalarVarTypeSet = allScalarVarType . varSetElems -checkTreeAnnM:: VITree -> CoreExprWithFVs -> VM () -checkTreeAnnM vi e = - if not (checkTree vi $ deAnnotate e) - then error ("checkTreeAnnM : \n " ++ show vi) - else return () --} +-- Debugging support +-- +viTrace :: CoreExprWithFVs -> VectAvoidInfo -> [CoreExprWithVectInfo] -> VM () +viTrace ce vi vTs + = traceVt ("vect info: " ++ show vi ++ "[" ++ + (concat $ map ((++ " ") . show . vectAvoidInfoOf) vTs) ++ "]") + (ppr $ deAnnotate ce) |