summaryrefslogtreecommitdiff
path: root/compiler/vectorise/Vectorise/Exp.hs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/vectorise/Vectorise/Exp.hs')
-rw-r--r--compiler/vectorise/Vectorise/Exp.hs1048
1 files changed, 539 insertions, 509 deletions
diff --git a/compiler/vectorise/Vectorise/Exp.hs b/compiler/vectorise/Vectorise/Exp.hs
index 8c5ef0045d..88f123210b 100644
--- a/compiler/vectorise/Vectorise/Exp.hs
+++ b/compiler/vectorise/Vectorise/Exp.hs
@@ -3,10 +3,9 @@
-- |Vectorisation of expressions.
module Vectorise.Exp
- ( -- * Vectorise polymorphic expressions with special cases for right-hand sides of particular
- -- variable bindings
- vectPolyExpr
- , vectDictExpr
+ ( -- * Vectorise right-hand sides of toplevel bindings
+ vectTopExpr
+ , vectTopExprs
, vectScalarFun
, vectScalarDFun
)
@@ -32,393 +31,404 @@ import DataCon
import TyCon
import TcType
import Type
-import PrelNames
+import TypeRep
import Var
import VarEnv
import VarSet
+import NameSet
import Id
import BasicTypes( isStrongLoopBreaker )
import Literal
-import TysWiredIn
import TysPrim
import Outputable
import FastString
+import DynFlags
+import Util
+import MonadUtils
+
import Control.Monad
-import Control.Applicative
import Data.Maybe
import Data.List
-import TcRnMonad (doptM)
-import DynFlags
-import Util
-- Main entry point to vectorise expressions -----------------------------------
--- |Vectorise a polymorphic expression.
+-- |Vectorise a polymorphic expression that forms a *non-recursive* binding.
+--
+-- Return 'Nothing' if the expression is scalar; otherwise, the first component of the result
+-- (which is of type 'Bool') indicates whether the expression is parallel (i.e., whether it is
+-- tagged as 'VIParr').
--
--- If not yet available, precompute vectorisation avoidance information before vectorising. If
--- the vectorisation avoidance optimisation is enabled, also use the vectorisation avoidance
--- information to encapsulated subexpression that do not need to be vectorised.
+-- We have got the non-recursive case as a special case as it doesn't require to compute
+-- vectorisation information twice.
--
-vectPolyExpr :: Bool -> [Var] -> CoreExprWithFVs -> Maybe VITree
- -> VM (Inline, Bool, VExpr)
- -- precompute vectorisation avoidance information (and possibly encapsulated subexpressions)
-vectPolyExpr loop_breaker recFns expr Nothing
+vectTopExpr :: Var -> CoreExpr -> VM (Maybe (Bool, Inline, CoreExpr))
+vectTopExpr var expr
= do
- { vectAvoidance <- liftDs $ doptM Opt_AvoidVect
- ; vi <- vectAvoidInfo expr
- ; (expr', vi') <-
- if vectAvoidance
- then do
- { (expr', vi') <- encapsulateScalars vi expr
- ; traceVt "vectPolyExpr encapsulated:" (ppr $ deAnnotate expr')
- ; return (expr', vi')
- }
- else return (expr, vi)
- ; vectPolyExpr loop_breaker recFns expr' (Just vi')
+ { exprVI <- encapsulateScalars <=< vectAvoidInfo emptyVarSet . freeVars $ expr
+ ; if isVIEncaps exprVI
+ then
+ return Nothing
+ else do
+ { vExpr <- closedV $
+ inBind var $
+ vectAnnPolyExpr False exprVI
+ ; inline <- computeInline exprVI
+ ; return $ Just (isVIParr exprVI, inline, vectorised vExpr)
+ }
}
- -- traverse through ticks
-vectPolyExpr loop_breaker recFns (_, AnnTick tickish expr) (Just (VITNode _ [vit]))
- = do
- { (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr (Just vit)
- ; return (inline, isScalarFn, vTick tickish expr')
- }
+-- Compute the inlining hint for the right-hand side of a top-level binding.
+--
+computeInline :: CoreExprWithVectInfo -> VM Inline
+computeInline ((_, VIDict), _) = return $ DontInline
+computeInline (_, AnnTick _ expr) = computeInline expr
+computeInline expr@(_, AnnLam _ _) = Inline <$> polyArity tvs
+ where
+ (tvs, _) = collectAnnTypeBinders expr
+computeInline _expr = return $ DontInline
- -- collect and vectorise type abstractions; then, descent into the body
-vectPolyExpr loop_breaker recFns expr (Just vit)
- = do
- { let (tvs, mono) = collectAnnTypeBinders expr
- vit' = stripLevels (length tvs) vit
- ; arity <- polyArity tvs
- ; polyAbstract tvs $ \args ->
- do
- { (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono vit'
- ; return (addInlineArity inline arity, isScalarFn, mapVect (mkLams $ tvs ++ args) mono')
- }
+-- |Vectorise a recursive group of top-level polymorphic expressions.
+--
+-- Return 'Nothing' if the expression group is scalar; otherwise, the first component of the result
+-- (which is of type 'Bool') indicates whether the expressions are parallel (i.e., whether they are
+-- tagged as 'VIParr').
+--
+vectTopExprs :: [(Var, CoreExpr)] -> VM (Maybe (Bool, [(Inline, CoreExpr)]))
+vectTopExprs binds
+ = do
+ { exprVIs <- mapM (vectAvoidAndEncapsulate emptyVarSet) exprs
+ ; if all isVIEncaps exprVIs
+ then
+ return Nothing
+ else do
+ { (areVIParr, vExprs) <- unzip <$> mapM encapsulateAndVect binds
+ ; return $ Just (or areVIParr, vExprs)
+ }
}
where
- stripLevels 0 vit = vit
- stripLevels n (VITNode _ [vit]) = stripLevels (n - 1) vit
- stripLevels _ vit = pprPanic "vectPolyExpr: stripLevels:" (text (show vit))
+ (vars, exprs) = unzip binds
+
+ vectAvoidAndEncapsulate pvs = encapsulateScalars <=< vectAvoidInfo pvs . freeVars
+
+ encapsulateAndVect (var, expr)
+ = do
+ { exprVI <- vectAvoidAndEncapsulate (mkVarSet vars) expr
+ ; vExpr <- closedV $
+ inBind var $
+ vectAnnPolyExpr (isStrongLoopBreaker $ idOccInfo var) exprVI
+ ; inline <- computeInline exprVI
+ ; return (isVIParr exprVI, (inline, vectorised vExpr))
+ }
+
+-- |Vectorise a polymorphic expression annotated with vectorisation information.
+--
+-- The special case of dictionary functions is currently handled separately. (Would be neater to
+-- integrate them, though!)
+--
+vectAnnPolyExpr :: Bool -> CoreExprWithVectInfo -> VM VExpr
+vectAnnPolyExpr loop_breaker (_, AnnTick tickish expr)
+ -- traverse through ticks
+ = vTick tickish <$> vectAnnPolyExpr loop_breaker expr
+vectAnnPolyExpr loop_breaker expr
+ | isVIDict expr
+ -- special case the right-hand side of dictionary functions
+ = (, undefined) <$> vectDictExpr (deAnnotate expr)
+ | otherwise
+ -- collect and vectorise type abstractions; then, descent into the body
+ = polyAbstract tvs $ \args ->
+ mapVect (mkLams $ tvs ++ args) <$> vectFnExpr False loop_breaker mono
+ where
+ (tvs, mono) = collectAnnTypeBinders expr
-- Encapsulate every purely sequential subexpression of a (potentially) parallel expression into a
--- into a lambda abstraction over all its free variables followed by the corresponding application
--- to those variables. We can, then, avoid the vectorisation of the ensapsulated subexpressions.
+-- lambda abstraction over all its free variables followed by the corresponding application to those
+-- variables. We can, then, avoid the vectorisation of the ensapsulated subexpressions.
--
-- Preconditions:
--
-- * All free variables and the result type must be /simple/ types.
--- * The expression is sufficientlt complex (top warrant special treatment). For now, that is
+-- * The expression is sufficiently complex (to warrant special treatment). For now, that is
-- every expression that is not constant and contains at least one operation.
--
-encapsulateScalars :: VITree -> CoreExprWithFVs -> VM (CoreExprWithFVs, VITree)
-encapsulateScalars vit ce@(_, AnnType _ty)
- = return (ce, vit)
-
-encapsulateScalars vit ce@(_, AnnVar _v)
- = return (ce, vit)
-
-encapsulateScalars vit ce@(_, AnnLit _)
- = return (ce, vit)
-
-encapsulateScalars (VITNode vi [vit]) (fvs, AnnTick tck expr)
- = do { (extExpr, vit') <- encapsulateScalars vit expr
- ; return ((fvs, AnnTick tck extExpr), VITNode vi [vit'])
- }
-
-encapsulateScalars _ (_fvs, AnnTick _tck _expr)
- = panic "encapsulateScalar AnnTick doesn't match up"
-
-encapsulateScalars (VITNode vi [vit]) ce@(fvs, AnnLam bndr expr)
- = do { varsS <- varsSimple fvs
- ; case (vi, varsS) of
- (VISimple, True) -> do { let (e', vit') = liftSimple vit ce
- ; return (e', vit')
- }
- _ -> do { (extExpr, vit') <- encapsulateScalars vit expr
- ; return ((fvs, AnnLam bndr extExpr), VITNode vi [vit'])
- }
- }
-
-encapsulateScalars _ (_fvs, AnnLam _bndr _expr)
- = panic "encapsulateScalars AnnLam doesn't match up"
-
-encapsulateScalars vt@(VITNode vi [vit1, vit2]) ce@(fvs, AnnApp ce1 ce2)
- = do { varsS <- varsSimple fvs
- ; case (vi, varsS) of
- (VISimple, True) -> do { let (e', vt') = liftSimple vt ce
- -- ; checkTreeAnnM vt' e'
- -- ; traceVt "Passed checkTree test!!" (ppr $ deAnnotate e')
- ; return (e', vt')
- }
- _ -> do { (etaCe1, vit1') <- encapsulateScalars vit1 ce1
- ; (etaCe2, vit2') <- encapsulateScalars vit2 ce2
- ; return ((fvs, AnnApp etaCe1 etaCe2), VITNode vi [vit1', vit2'])
- }
- }
-
-encapsulateScalars _ (_fvs, AnnApp _ce1 _ce2)
- = panic "encapsulateScalars AnnApp doesn't match up"
-
-encapsulateScalars vt@(VITNode vi (scrutVit : altVits)) ce@(fvs, AnnCase scrut bndr ty alts)
- = do { varsS <- varsSimple fvs
- ; case (vi, varsS) of
- (VISimple, True) -> return $ liftSimple vt ce
- _ -> do { (extScrut, scrutVit') <- encapsulateScalars scrutVit scrut
- ; extAltsVits <- zipWithM expAlt altVits alts
- ; let (extAlts, altVits') = unzip extAltsVits
- ; return ((fvs, AnnCase extScrut bndr ty extAlts), VITNode vi (scrutVit': altVits'))
- }
- }
+encapsulateScalars :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
+encapsulateScalars ce@(_, AnnType _ty)
+ = return ce
+encapsulateScalars ce@((_, VISimple), AnnVar v)
+ | isFunTy . varType $ v -- NB: diverts from the paper: encapsulate scalar function types
+ = liftSimpleAndCase ce
+encapsulateScalars ce@(_, AnnVar _v)
+ = return ce
+encapsulateScalars ce@(_, AnnLit _)
+ = return ce
+encapsulateScalars ((fvs, vi), AnnTick tck expr)
+ = do
+ { encExpr <- encapsulateScalars expr
+ ; return ((fvs, vi), AnnTick tck encExpr)
+ }
+encapsulateScalars ce@((fvs, vi), AnnLam bndr expr)
+ = do
+ { varsS <- allScalarVarTypeSet fvs
+ ; case (vi, varsS) of
+ (VISimple, True) -> liftSimpleAndCase ce
+ _ -> do
+ { encExpr <- encapsulateScalars expr
+ ; return ((fvs, vi), AnnLam bndr encExpr)
+ }
+ }
+encapsulateScalars ce@((fvs, vi), AnnApp ce1 ce2)
+ = do
+ { varsS <- allScalarVarTypeSet fvs
+ ; case (vi, varsS) of
+ (VISimple, True) -> liftSimpleAndCase ce
+ _ -> do
+ { encCe1 <- encapsulateScalars ce1
+ ; encCe2 <- encapsulateScalars ce2
+ ; return ((fvs, vi), AnnApp encCe1 encCe2)
+ }
+ }
+encapsulateScalars ce@((fvs, vi), AnnCase scrut bndr ty alts)
+ = do
+ { varsS <- allScalarVarTypeSet fvs
+ ; case (vi, varsS) of
+ (VISimple, True) -> liftSimpleAndCase ce
+ _ -> do
+ { encScrut <- encapsulateScalars scrut
+ ; encAlts <- mapM encAlt alts
+ ; return ((fvs, vi), AnnCase encScrut bndr ty encAlts)
+ }
+ }
where
- expAlt vt (con, bndrs, expr)
- = do { (extExpr, vt') <- encapsulateScalars vt expr
- ; return ((con, bndrs, extExpr), vt')
- }
-
-encapsulateScalars _ (_fvs, AnnCase _scrut _bndr _ty _alts)
- = panic "encapsulateScalars AnnCase doesn't match up"
-
-encapsulateScalars vt@(VITNode vi [vt1, vt2]) ce@(fvs, AnnLet (AnnNonRec bndr expr1) expr2)
- = do { varsS <- varsSimple fvs
- ; case (vi, varsS) of
- (VISimple, True) -> return $ liftSimple vt ce
- _ -> do { (extExpr1, vt1') <- encapsulateScalars vt1 expr1
- ; (extExpr2, vt2') <- encapsulateScalars vt2 expr2
- ; return ((fvs, AnnLet (AnnNonRec bndr extExpr1) extExpr2), VITNode vi [vt1', vt2'])
- }
- }
-
-encapsulateScalars _ (_fvs, AnnLet (AnnNonRec _bndr _expr1) _expr2)
- = panic "encapsulateScalars AnnLet nonrec doesn't match up"
-
-encapsulateScalars vt@(VITNode vi (vtB : vtBnds)) ce@(fvs, AnnLet (AnnRec bndngs) expr)
- = do { varsS <- varsSimple fvs
- ; case (vi, varsS) of
- (VISimple, True) -> return $ liftSimple vt ce
- _ -> do { extBndsVts <- zipWithM expBndg vtBnds bndngs
- ; let (extBnds, vtBnds') = unzip extBndsVts
- ; (extExpr, vtB') <- encapsulateScalars vtB expr
- ; let vt' = VITNode vi (vtB':vtBnds')
- ; return ((fvs, AnnLet (AnnRec extBnds) extExpr), vt')
- }
- }
- where
- expBndg vit (bndr, expr)
- = do { (extExpr, vit') <- encapsulateScalars vit expr
- ; return ((bndr, extExpr), vit')
- }
-
-encapsulateScalars _ (_fvs, AnnLet (AnnRec _) _expr2)
- = panic "encapsulateScalars AnnLet rec doesn't match up"
-
-encapsulateScalars (VITNode vi [vit]) (fvs, AnnCast expr coercion)
- = do { (extExpr, vit') <- encapsulateScalars vit expr
- ; return ((fvs, AnnCast extExpr coercion), VITNode vi [vit'])
- }
-
-encapsulateScalars _ (_fvs, AnnCast _expr _coercion)
- = panic "encapsulateScalars AnnCast rec doesn't match up"
-
-encapsulateScalars _ _
- = panic "encapsulateScalars case not handled"
+ encAlt (con, bndrs, expr) = (con, bndrs,) <$> encapsulateScalars expr
+encapsulateScalars ce@((fvs, vi), AnnLet (AnnNonRec bndr expr1) expr2)
+ = do
+ { varsS <- allScalarVarTypeSet fvs
+ ; case (vi, varsS) of
+ (VISimple, True) -> liftSimpleAndCase ce
+ _ -> do
+ { encExpr1 <- encapsulateScalars expr1
+ ; encExpr2 <- encapsulateScalars expr2
+ ; return ((fvs, vi), AnnLet (AnnNonRec bndr encExpr1) encExpr2)
+ }
+ }
+encapsulateScalars ce@((fvs, vi), AnnLet (AnnRec binds) expr)
+ = do
+ { varsS <- allScalarVarTypeSet fvs
+ ; case (vi, varsS) of
+ (VISimple, True) -> liftSimpleAndCase ce
+ _ -> do
+ { encBinds <- mapM encBind binds
+ ; encExpr <- encapsulateScalars expr
+ ; return ((fvs, vi), AnnLet (AnnRec encBinds) encExpr)
+ }
+ }
+ where
+ encBind (bndr, expr) = (bndr,) <$> encapsulateScalars expr
+encapsulateScalars ((fvs, vi), AnnCast expr coercion)
+ = do
+ { encExpr <- encapsulateScalars expr
+ ; return ((fvs, vi), AnnCast encExpr coercion)
+ }
+encapsulateScalars _
+ = panic "Vectorise.Exp.encapsulateScalars: unknown constructor"
--- Lambda-lift the given expression and apply it to the abstracted free variables.
+-- Lambda-lift the given simple expression and apply it to the abstracted free variables.
--
--- If the expression is a case expression scrutinising anything but a primitive type, then lift
+-- If the expression is a case expression scrutinising anything, but a scalar type, then lift
-- each alternative individually.
--
-liftSimple :: VITree -> CoreExprWithFVs -> (CoreExprWithFVs, VITree)
-liftSimple (VITNode vi (scrutVit : altVits)) (fvs, AnnCase expr bndr t alts)
- | Just (c,_) <- splitTyConApp_maybe (exprType $ deAnnotate $ expr),
- (not $ elem c [boolTyCon, intTyCon, doubleTyCon, floatTyCon]) -- FIXME: shouldn't be hardcoded
- = ((fvs, AnnCase expr bndr t alts'), VITNode vi (scrutVit : altVits'))
- where
- (alts', altVits') = unzip $ map (\(ac,bndrs, (alt, avi)) -> ((ac,bndrs,alt), avi)) $
- zipWith (\(ac, bndrs, aex) -> \altVi -> (ac, bndrs, liftSimple altVi aex)) alts altVits
-
-liftSimple viTree ae@(fvs, _annEx)
- = (mkAnnApps (mkAnnLams ae vars) vars, viTree')
+liftSimpleAndCase :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
+liftSimpleAndCase aexpr@((fvs, _vi), AnnCase expr bndr t alts)
+ = do
+ { vi <- vectAvoidInfoTypeOf expr
+ ; if (vi == VISimple)
+ then
+ return $ liftSimple aexpr -- if the scrutinee is scalar, we need no special treatment
+ else do
+ { alts' <- mapM (\(ac, bndrs, aexpr) -> (ac, bndrs,) <$> liftSimpleAndCase aexpr) alts
+ ; return ((fvs, vi), AnnCase expr bndr t alts')
+ }
+ }
+liftSimpleAndCase aexpr = return $ liftSimple aexpr
+
+liftSimple :: CoreExprWithVectInfo -> CoreExprWithVectInfo
+liftSimple ((fvs, vi), expr)
+ = ASSERT(vi == VISimple)
+ mkAnnApps (mkAnnLams vars fvs expr) vars
where
- mkViTreeLams (VITNode _ vits) [] = VITNode VIEncaps vits
- mkViTreeLams vi (_:vs) = VITNode VIEncaps [mkViTreeLams vi vs]
+ vars = varSetElems fvs
- mkViTreeApps vi [] = vi
- mkViTreeApps vi (_:vs) = VITNode VISimple [mkViTreeApps vi vs, VITNode VISimple []]
-
- vars = varSetElems fvs
- viTree' = mkViTreeApps (mkViTreeLams viTree vars) vars
-
- mkAnnLam :: bndr -> AnnExpr bndr VarSet -> AnnExpr' bndr VarSet
- mkAnnLam bndr ce = AnnLam bndr ce
-
- mkAnnLams:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs
- mkAnnLams (fv, aex') [] = (fv, aex') -- fv should be empty. check!
- mkAnnLams (fv, aex') (v:vs) = mkAnnLams (delVarSet fv v, (mkAnnLam v ((delVarSet fv v), aex'))) vs
-
- mkAnnApp :: (AnnExpr bndr VarSet) -> Var -> (AnnExpr' bndr VarSet)
- mkAnnApp aex v = AnnApp aex (unitVarSet v, (AnnVar v))
+ mkAnnLams :: [Var] -> VarSet -> AnnExpr' Var (VarSet, VectAvoidInfo) -> CoreExprWithVectInfo
+ mkAnnLams [] fvs expr = ASSERT(isEmptyVarSet fvs)
+ ((emptyVarSet, VIEncaps), expr)
+ mkAnnLams (v:vs) fvs expr = mkAnnLams vs (fvs `delVarSet` v) (AnnLam v ((fvs, VIEncaps), expr))
- mkAnnApps:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs
- mkAnnApps (fv, aex') [] = (fv, aex')
- mkAnnApps ae (v:vs) =
- let
- (fv, aex') = mkAnnApps ae vs
- in (extendVarSet fv v, mkAnnApp (fv, aex') v)
+ mkAnnApps :: CoreExprWithVectInfo -> [Var] -> CoreExprWithVectInfo
+ mkAnnApps aexpr [] = aexpr
+ mkAnnApps aexpr (v:vs) = mkAnnApps (mkAnnApp aexpr v) vs
+
+ mkAnnApp :: CoreExprWithVectInfo -> Var -> CoreExprWithVectInfo
+ mkAnnApp aexpr@((fvs, _vi), _expr) v
+ = ((fvs `extendVarSet` v, VISimple), AnnApp aexpr ((unitVarSet v, VISimple), AnnVar v))
-- |Vectorise an expression.
--
-vectExpr :: CoreExprWithFVs -> VITree -> VM VExpr
--- vectExpr e vi | not (checkTree vi (deAnnotate e))
--- = pprPanic "vectExpr" (ppr $ deAnnotate e)
-
-vectExpr (_, AnnVar v) _
+vectExpr :: CoreExprWithVectInfo -> VM VExpr
+
+vectExpr (_, AnnVar v)
= vectVar v
-vectExpr (_, AnnLit lit) _
+vectExpr (_, AnnLit lit)
= vectConst $ Lit lit
-vectExpr e@(_, AnnLam bndr _) vt
- | isId bndr = (\(_, _, ve) -> ve) <$> vectFnExpr True False [] e vt
- | otherwise = do dflags <- getDynFlags
- cantVectorise dflags "Unexpected type lambda (vectExpr)" (ppr (deAnnotate e))
+vectExpr e@(_, AnnLam bndr _)
+ | isId bndr = vectFnExpr True False e
+ | otherwise
+ = do
+ { dflags <- getDynFlags
+ ; cantVectorise dflags "Unexpected type lambda (vectExpr)" $ ppr (deAnnotate e)
+ }
-- SPECIAL CASE: Vectorise/lift 'patError @ ty err' by only vectorising/lifting the type 'ty';
-- its only purpose is to abort the program, but we need to adjust the type to keep CoreLint
-- happy.
-- FIXME: can't be do this with a VECTORISE pragma on 'pAT_ERROR_ID' now?
-vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err) _
+vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err)
| v == pAT_ERROR_ID
- = do { (vty, lty) <- vectAndLiftType ty
- ; return (mkCoreApps (Var v) [Type vty, err'], mkCoreApps (Var v) [Type lty, err'])
- }
+ = do
+ { (vty, lty) <- vectAndLiftType ty
+ ; return (mkCoreApps (Var v) [Type vty, err'], mkCoreApps (Var v) [Type lty, err'])
+ }
where
err' = deAnnotate err
-- type application (handle multiple consecutive type applications simultaneously to ensure the
-- PA dictionaries are put at the right places)
-vectExpr e@(_, AnnApp _ arg) (VITNode _ [_, _])
+vectExpr e@(_, AnnApp _ arg)
| isAnnTypeArg arg
= vectPolyApp e
-
- -- 'Int', 'Float', or 'Double' literal
- -- FIXME: this needs to be generalised
-vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit)) _
- | Just con <- isDataConId_maybe v
- , is_special_con con
+
+ -- Lifted literal
+vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
+ | Just _con <- isDataConId_maybe v
= do
- let vexpr = App (Var v) (Lit lit)
- lexpr <- liftPD vexpr
- return (vexpr, lexpr)
- where
- is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]
+ { let vexpr = App (Var v) (Lit lit)
+ ; lexpr <- liftPD vexpr
+ ; return (vexpr, lexpr)
+ }
-- value application (dictionary or user value)
-vectExpr e@(_, AnnApp fn arg) (VITNode _ [vit1, vit2])
+vectExpr e@(_, AnnApp fn arg)
| isPredTy arg_ty -- dictionary application (whose result is not a dictionary)
= vectPolyApp e
| otherwise -- user value
- = do { -- vectorise the types
- ; varg_ty <- vectType arg_ty
- ; vres_ty <- vectType res_ty
+ = do
+ { -- vectorise the types
+ ; varg_ty <- vectType arg_ty
+ ; vres_ty <- vectType res_ty
- -- vectorise the function and argument expression
- ; vfn <- vectExpr fn vit1
- ; varg <- vectExpr arg vit2
+ -- vectorise the function and argument expression
+ ; vfn <- vectExpr fn
+ ; varg <- vectExpr arg
- -- the vectorised function is a closure; apply it to the vectorised argument
- ; mkClosureApp varg_ty vres_ty vfn varg
- }
+ -- the vectorised function is a closure; apply it to the vectorised argument
+ ; mkClosureApp varg_ty vres_ty vfn varg
+ }
where
(arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn
-vectExpr (_, AnnCase scrut bndr ty alts) vt
+vectExpr (_, AnnCase scrut bndr ty alts)
| Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
, isAlgTyCon tycon
- = vectAlgCase tycon ty_args scrut bndr ty alts vt
- | otherwise = do dflags <- getDynFlags
- cantVectorise dflags "Can't vectorise expression" (ppr scrut_ty)
+ = vectAlgCase tycon ty_args scrut bndr ty alts
+ | otherwise
+ = do
+ { dflags <- getDynFlags
+ ; cantVectorise dflags "Can't vectorise expression (no algebraic type constructor)" $
+ ppr scrut_ty
+ }
where
scrut_ty = exprType (deAnnotate scrut)
-vectExpr (_, AnnLet (AnnNonRec bndr rhs) body) (VITNode _ [vt1, vt2])
+vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
= do
- vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs (Just vt1)
- (vbndr, vbody) <- vectBndrIn bndr (vectExpr body vt2)
- return $ vLet (vNonRec vbndr vrhs) vbody
+ { vrhs <- localV $
+ inBind bndr $
+ vectAnnPolyExpr False rhs
+ ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
+ ; return $ vLet (vNonRec vbndr vrhs) vbody
+ }
-vectExpr (_, AnnLet (AnnRec bs) body) (VITNode _ (vtB : vtBnds))
+vectExpr (_, AnnLet (AnnRec bs) body)
= do
- (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
+ { (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
$ liftM2 (,)
- (zipWith3M vect_rhs bndrs rhss vtBnds)
- (vectExpr body vtB)
- return $ vLet (vRec vbndrs vrhss) vbody
+ (zipWithM vect_rhs bndrs rhss)
+ (vectExpr body)
+ ; return $ vLet (vRec vbndrs vrhss) vbody
+ }
where
(bndrs, rhss) = unzip bs
- vect_rhs bndr rhs vt = localV
- . inBind bndr
- . liftM (\(_,_,z)->z)
- $ vectPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) [] rhs (Just vt)
- zipWith3M f xs ys zs = zipWithM (\x -> \(y,z) -> (f x y z)) xs (zip ys zs)
+ vect_rhs bndr rhs = localV $
+ inBind bndr $
+ vectAnnPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) rhs
-vectExpr (_, AnnTick tickish expr) (VITNode _ [vit])
- = liftM (vTick tickish) (vectExpr expr vit)
+vectExpr (_, AnnTick tickish expr)
+ = vTick tickish <$> vectExpr expr
-vectExpr (_, AnnType ty) _
- = liftM vType (vectType ty)
+vectExpr (_, AnnType ty)
+ = vType <$> vectType ty
-vectExpr e vit = do dflags <- getDynFlags
- cantVectorise dflags "Can't vectorise expression (vectExpr)" (ppr (deAnnotate e) $$ text (" " ++ show vit))
+vectExpr e
+ = do
+ { dflags <- getDynFlags
+ ; cantVectorise dflags "Can't vectorise expression (vectExpr)" $ ppr (deAnnotate e)
+ }
--- |Vectorise an expression that *may* have an outer lambda abstraction.
+-- |Vectorise an expression that *may* have an outer lambda abstraction. If the expression is marked
+-- as encapsulated ('VIEncaps'), vectorise it as a scalar computation (using a generalised scalar
+-- zip).
--
-- We do not handle type variables at this point, as they will already have been stripped off by
--- 'vectPolyExpr'. We also only have to worry about one set of dictionary arguments as we (1) only
+-- 'vectPolyExpr'. We also only have to worry about one set of dictionary arguments as we (1) only
-- deal with Haskell 2011 and (2) class selectors are vectorised elsewhere.
--
-vectFnExpr :: Bool -- ^ If we process the RHS of a binding, whether that binding should
- -- be inlined
- -> Bool -- ^ Whether the binding is a loop breaker
- -> [Var] -- ^ Names of function in same recursive binding group
- -> CoreExprWithFVs -- ^ Expression to vectorise; must have an outer `AnnLam`
- -> VITree
- -> VM (Inline, Bool, VExpr)
--- vectFnExpr _ _ _ e vi | not (checkTree vi (deAnnotate e))
--- = pprPanic "vectFnExpr" (ppr $ deAnnotate e)
-vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr body) vt@(VITNode _ [vt'])
- -- predicate abstraction: leave as a normal abstraction, but vectorise the predicate type
+vectFnExpr :: Bool -- ^If we process the RHS of a binding, whether that binding
+ -- should be inlined
+ -> Bool -- ^Whether the binding is a loop breaker
+ -> CoreExprWithVectInfo -- ^Expression to vectorise; must have an outer `AnnLam`
+ -> VM VExpr
+vectFnExpr inline loop_breaker expr@(_ann, AnnLam bndr body)
+ -- predicate abstraction: leave as a normal abstraction, but vectorise the predicate type
| isId bndr
&& isPredTy (idType bndr)
- = do { vBndr <- vectBndr bndr
- ; (inline, isScalarFn, vbody) <- vectFnExpr inline loop_breaker recFns body vt'
- ; return (inline, isScalarFn, mapVect (mkLams [vectorised vBndr]) vbody)
- }
- -- non-predicate abstraction: vectorise (try to vectorise as a scalar computation)
+ = do
+ { vBndr <- vectBndr bndr
+ ; vbody <- vectFnExpr inline loop_breaker body
+ ; return $ mapVect (mkLams [vectorised vBndr]) vbody
+ }
+ -- non-predicate abstraction: vectorise as a scalar computation
+ | isId bndr && isVIEncaps expr
+ = vectScalarFun . deAnnotate $ expr
+ -- non-predicate abstraction: vectorise as a non-scalar computation
| isId bndr
- = mark DontInline True (vectScalarFunMaybe (deAnnotate expr) vt)
- `orElseV`
- mark inlineMe False (vectLam inline loop_breaker expr vt)
-vectFnExpr _ _ _ e vt
- -- not an abstraction: vectorise as a vanilla expression
- = mark DontInline False $ vectExpr e vt
-
-mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
-mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
+ = vectLam inline loop_breaker expr
+vectFnExpr _ _ expr
+ -- not an abstraction: vectorise as a vanilla expression
+ = vectExpr expr
-- |Vectorise type and dictionary applications.
--
-- These are always headed by a variable (as we don't support higher-rank polymorphism), but may
--- involve two sets of type variables and dictionaries. Consider,
+-- involve two sets of type variables and dictionaries. Consider,
--
-- > class C a where
-- > m :: D b => b -> a
--
-- The type of 'm' is 'm :: forall a. C a => forall b. D b => b -> a'.
--
-vectPolyApp :: CoreExprWithFVs -> VM VExpr
+vectPolyApp :: CoreExprWithVectInfo -> VM VExpr
vectPolyApp e0
= case e4 of
(_, AnnVar var)
@@ -530,21 +540,6 @@ vectDictExpr (Coercion coe)
-- instead they become dictionaries of vectorised methods). We treat them differently, though see
-- "Note [Scalar dfuns]" in 'Vectorise'.
--
-vectScalarFunMaybe :: CoreExpr -- ^ Expression to be vectorised
- -> VITree -- ^ Vectorisation information
- -> VM VExpr
-vectScalarFunMaybe expr (VITNode VIEncaps _) = vectScalarFun expr
-vectScalarFunMaybe _expr _ = noV $ ptext (sLit "not a scalar function")
-
--- |Vectorise an expression of functional type by lifting it by an application of a member of the
--- zipWith family (i.e., 'map', 'zipWith', zipWith3', etc.) This is only a valid strategy if the
--- function does not contain parallel subcomputations and has only 'Scalar' types in its result and
--- arguments — this is a predcondition for calling this function.
---
--- Dictionary functions are also scalar functions (as dictionaries themselves are not vectorised,
--- instead they become dictionaries of vectorised methods). We treat them differently, though see
--- "Note [Scalar dfuns]" in 'Vectorise'.
---
vectScalarFun :: CoreExpr -> VM VExpr
vectScalarFun expr
= do
@@ -673,12 +668,11 @@ unVectDict ty e
-- variables are passed explicit (as conventional arguments) into the body during closure
-- construction.
--
-vectLam :: Bool -- ^ When the RHS of a binding, whether that binding should be inlined.
- -> Bool -- ^ Whether the binding is a loop breaker.
- -> CoreExprWithFVs -- ^ Body of abstraction.
- -> VITree
+vectLam :: Bool -- ^ Should the RHS of a binding be inlined?
+ -> Bool -- ^ Whether the binding is a loop breaker.
+ -> CoreExprWithVectInfo -- ^ Body of abstraction.
-> VM VExpr
-vectLam inline loop_breaker expr@(fvs, AnnLam _ _) vi
+vectLam inline loop_breaker expr@((fvs, _vi), AnnLam _ _)
= do { let (bndrs, body) = collectAnnValBinders expr
-- grab the in-scope type variables
@@ -706,18 +700,13 @@ vectLam inline loop_breaker expr@(fvs, AnnLam _ _) vi
. hoistPolyVExpr tyvars vfvs_dict' (maybe_inline arity)
$ do { -- generate the vectorised body of the lambda abstraction
; lc <- builtin liftingContext
- ; let viBody = stripLams expr vi
- -- ; checkTreeAnnM vi expr
- ; (vbndrs, vbody) <- vectBndrsIn (fvs_nondict ++ bndrs) (vectExpr body viBody)
+ ; (vbndrs, vbody) <- vectBndrsIn (fvs_nondict ++ bndrs) $ vectExpr body
; vbody' <- break_loop lc res_ty vbody
; return $ vLams lc vbndrs vbody'
}
}
where
- stripLams (_, AnnLam _ e) (VITNode _ [vt]) = stripLams e vt
- stripLams _ vi = vi
-
maybe_inline n | inline = Inline n
| otherwise = DontInline
@@ -735,7 +724,7 @@ vectLam inline loop_breaker expr@(fvs, AnnLam _ _) vi
(LitAlt (mkMachInt 0), [], empty)])
}
| otherwise = return (ve, le)
-vectLam _ _ _ _ = panic "vectLam"
+vectLam _ _ _ = panic "Vectorise.Exp.vectLam: not a lambda"
-- Vectorise an algebraic case expression.
--
@@ -754,31 +743,31 @@ vectLam _ _ _ _ = panic "vectLam"
--
-- FIXME: this is too lazy
-vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs-> Var -> Type
- -> [(AltCon, [Var], CoreExprWithFVs)] -> VITree
+vectAlgCase :: TyCon -> [Type] -> CoreExprWithVectInfo -> Var -> Type
+ -> [(AltCon, [Var], CoreExprWithVectInfo)]
-> VM VExpr
-vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)] (VITNode _ (scrutVit : [altVit]))
+vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
= do
- vscrut <- vectExpr scrut scrutVit
+ vscrut <- vectExpr scrut
(vty, lty) <- vectAndLiftType ty
- (vbndr, vbody) <- vectBndrIn bndr (vectExpr body altVit)
+ (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
return $ vCaseDEFAULT vscrut vbndr vty lty vbody
-vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)] (VITNode _ (scrutVit : [altVit]))
+vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
= do
- vscrut <- vectExpr scrut scrutVit
+ vscrut <- vectExpr scrut
(vty, lty) <- vectAndLiftType ty
- (vbndr, vbody) <- vectBndrIn bndr (vectExpr body altVit)
+ (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
return $ vCaseDEFAULT vscrut vbndr vty lty vbody
-vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)] (VITNode _ (scrutVit : [altVit]))
+vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
= do
(vty, lty) <- vectAndLiftType ty
- vexpr <- vectExpr scrut scrutVit
+ vexpr <- vectExpr scrut
(vbndr, (vbndrs, (vect_body, lift_body)))
<- vect_scrut_bndr
. vectBndrsIn bndrs
- $ vectExpr body altVit
+ $ vectExpr body
let (vect_bndrs, lift_bndrs) = unzip vbndrs
(vscrut, lscrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
vect_dc <- maybeV dataConErr (lookupDataCon dc)
@@ -796,9 +785,9 @@ vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)] (VITNode _
dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)
-vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits))
+vectAlgCase tycon _ty_args scrut bndr ty alts
= do
- vect_tc <- maybeV tyConErr (lookupTyCon tycon)
+ vect_tc <- vectTyCon tycon
(vty, lty) <- vectAndLiftType ty
let arity = length (tyConDataCons vect_tc)
@@ -807,10 +796,10 @@ vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits))
let sel = Var sel_bndr
(vbndr, valts) <- vect_scrut_bndr
- $ mapM (proc_alt arity sel vty lty) (zip alts' altVits)
+ $ mapM (proc_alt arity sel vty lty) alts'
let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
- vexpr <- vectExpr scrut scrutVit
+ vexpr <- vectExpr scrut
(vect_scrut, lift_scrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
let (vect_bodies, lift_bodies) = unzip vbodies
@@ -829,8 +818,6 @@ vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits))
return . vLet (vNonRec vbndr vexpr)
$ (vect_case, lift_case)
where
- tyConErr = (text "vectAlgCase: type constructor not vectorised" <+> ppr tycon)
-
vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
| otherwise = vectBndrIn bndr
@@ -842,12 +829,12 @@ vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits))
cmp _ DEFAULT = GT
cmp _ _ = panic "vectAlgCase/cmp"
- proc_alt arity sel _ lty ((DataAlt dc, bndrs, body), vi)
+ proc_alt arity sel _ lty (DataAlt dc, bndrs, body@((fvs_body, _), _))
= do
vect_dc <- maybeV dataConErr (lookupDataCon dc)
let ntag = dataConTagZ vect_dc
tag = mkDataConTag vect_dc
- fvs = freeVarsOf body `delVarSetList` bndrs
+ fvs = fvs_body `delVarSetList` bndrs
sel_tags <- liftM (`App` sel) (builtin (selTags arity))
lc <- builtin liftingContext
@@ -860,7 +847,7 @@ vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits))
binds <- mapM (pack_var (Var lc) sel_tags tag)
. filter isLocalId
$ varSetElems fvs
- (ve, le) <- vectExpr body vi
+ (ve, le) <- vectExpr body
return (ve, Case (elems `App` sel) lc lty
[(DEFAULT, [], (mkLets (concat binds) le))])
-- empty <- emptyPD vty
@@ -892,9 +879,6 @@ vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits))
_ -> return []
-vectAlgCase tycon _ty_args _scrut _bndr _ty _alts (VITNode _ _)
- = pprPanic "vectAlgCase (mismatched node information)" (ppr tycon)
-
-- Support to compute information for vectorisation avoidance ------------------
@@ -905,202 +889,248 @@ data VectAvoidInfo = VIParr -- tree contains parallel computations
| VISimple -- result type is scalar & no parallel subcomputation
| VIComplex -- any result type, no parallel subcomputation
| VIEncaps -- tree encapsulated by 'liftSimple'
+ | VIDict -- dictionary computation (never parallel)
deriving (Eq, Show)
--- Instead of integrating the vectorisation avoidance information into Core expression, we keep
--- them in a separate tree (that structurally mirrors the Core expression that it annotates).
+-- Core expression annotated with free variables and vectorisation-specific information.
--
-data VITree = VITNode VectAvoidInfo [VITree]
- deriving (Show)
+type CoreExprWithVectInfo = AnnExpr Id (VarSet, VectAvoidInfo)
--- Is any of the tree nodes a 'VIPArr' node?
+-- Yield the type of an annotated core expression.
--
-anyVIPArr :: [VITree] -> Bool
-anyVIPArr = or . (map (\(VITNode vi _) -> vi == VIParr))
+annExprType :: AnnExpr Var ann -> Type
+annExprType = exprType . deAnnotate
--- Compute Core annotations to determine for which subexpressions we can avoid vectorisation
+-- Project the vectorisation information from an annotated Core expression.
--
--- FIXME: free scalar vars don't actually need to be passed through, since encapsulations makes sure,
--- that there are no free variables in encapsulated lambda expressions
-vectAvoidInfo :: CoreExprWithFVs -> VM VITree
-vectAvoidInfo ce@(_, AnnVar v)
- = do { vi <- vectAvoidInfoType $ exprType $ deAnnotate ce
- ; viTrace ce vi []
- ; traceVt "vectAvoidInfo AnnVar" ((ppr v) <+> (ppr $ exprType $ deAnnotate ce))
- ; return $ VITNode vi []
- }
+vectAvoidInfoOf :: CoreExprWithVectInfo -> VectAvoidInfo
+vectAvoidInfoOf ((_, vi), _) = vi
-vectAvoidInfo ce@(_, AnnLit _)
- = do { vi <- vectAvoidInfoType $ exprType $ deAnnotate ce
- ; viTrace ce vi []
- ; traceVt "vectAvoidInfo AnnLit" (ppr $ exprType $ deAnnotate ce)
- ; return $ VITNode vi []
- }
+-- Is this a 'VIParr' node?
+--
+isVIParr :: CoreExprWithVectInfo -> Bool
+isVIParr = (== VIParr) . vectAvoidInfoOf
-vectAvoidInfo ce@(_, AnnApp e1 e2)
- = do { vt1 <- vectAvoidInfo e1
- ; vt2 <- vectAvoidInfo e2
- ; vi <- if anyVIPArr [vt1, vt2]
- then return VIParr
- else vectAvoidInfoType $ exprType $ deAnnotate ce
- ; viTrace ce vi [vt1, vt2]
- ; return $ VITNode vi [vt1, vt2]
- }
+-- Is this a 'VIEncaps' node?
+--
+isVIEncaps :: CoreExprWithVectInfo -> Bool
+isVIEncaps = (== VIEncaps) . vectAvoidInfoOf
-vectAvoidInfo ce@(_, AnnLam _var body)
- = do { vt@(VITNode vi _) <- vectAvoidInfo body
- ; viTrace ce vi [vt]
- ; let resultVI | vi == VIParr = VIParr
- | otherwise = VIComplex
- ; return $ VITNode resultVI [vt]
- }
+-- Is this a 'VIDict' node?
+--
+isVIDict :: CoreExprWithVectInfo -> Bool
+isVIDict = (== VIDict) . vectAvoidInfoOf
-vectAvoidInfo ce@(_, AnnLet (AnnNonRec _var expr) body)
- = do { vtE <- vectAvoidInfo expr
- ; vtB <- vectAvoidInfo body
- ; vi <- if anyVIPArr [vtE, vtB]
- then return VIParr
- else vectAvoidInfoType $ exprType $ deAnnotate ce
- ; viTrace ce vi [vtE, vtB]
- ; return $ VITNode vi [vtE, vtB]
- }
+-- 'VIParr' if either argument is 'VIParr'; otherwise, the first argument.
+--
+unlessVIParr :: VectAvoidInfo -> VectAvoidInfo -> VectAvoidInfo
+unlessVIParr _ VIParr = VIParr
+unlessVIParr vi _ = vi
-vectAvoidInfo ce@(_, AnnLet (AnnRec bnds) body)
- = do { let (_, exprs) = unzip bnds
- ; vtBnds <- mapM (\e -> vectAvoidInfo e) exprs
- ; if (anyVIPArr vtBnds)
- then do { vtBnds' <- mapM (\e -> vectAvoidInfo e) exprs
- ; vtB <- vectAvoidInfo body
- ; return (VITNode VIParr (vtB: vtBnds'))
- }
- else do { vtB@(VITNode vib _) <- vectAvoidInfo body
- ; ni <- if (vib == VIParr)
- then return VIParr
- else vectAvoidInfoType $ exprType $ deAnnotate ce
- ; viTrace ce ni (vtB : vtBnds)
- ; return $ VITNode ni (vtB : vtBnds)
- }
- }
+-- 'VIParr' if either arguments vectorisation information is 'VIParr'; otherwise, the vectorisation
+-- information of the first argument is produced.
+--
+unlessVIParrExpr :: VectAvoidInfo -> CoreExprWithVectInfo -> VectAvoidInfo
+infixl `unlessVIParrExpr`
+unlessVIParrExpr e1 e2 = e1 `unlessVIParr` vectAvoidInfoOf e2
-vectAvoidInfo ce@(_, AnnCase expr _var _ty alts)
- = do { vtExpr <- vectAvoidInfo expr
- ; vtAlts <- mapM (\(_, _, e) -> vectAvoidInfo e) alts
- ; ni <- if anyVIPArr (vtExpr : vtAlts)
- then return VIParr
- else vectAvoidInfoType $ exprType $ deAnnotate ce
- ; viTrace ce ni (vtExpr : vtAlts)
- ; return $ VITNode ni (vtExpr: vtAlts)
- }
+-- Compute Core annotations to determine for which subexpressions we can avoid vectorisation.
+--
+-- * The first argument is the set of free, local variables whose evaluation may entail parallelism.
+--
+vectAvoidInfo :: VarSet -> CoreExprWithFVs -> VM CoreExprWithVectInfo
+vectAvoidInfo pvs ce@(fvs, AnnVar v)
+ = do
+ { gpvs <- globalParallelVars
+ ; vi <- if v `elemVarSet` pvs || v `elemVarSet` gpvs
+ then return VIParr
+ else vectAvoidInfoTypeOf ce
+ ; viTrace ce vi []
-vectAvoidInfo (_, AnnCast expr _)
- = do { vt@(VITNode vi _) <- vectAvoidInfo expr
- ; return $ VITNode vi [vt]
- }
+ ; vit <- vectAvoidInfoTypeOf ce -- TEMPORARY
+ ; traceVt (" AnnVar: vectAvoidInfoTypeOf: " ++ show vit) empty
-vectAvoidInfo (_, AnnTick _ expr)
- = do { vt@(VITNode vi _) <- vectAvoidInfo expr
- ; return $ VITNode vi [vt]
- }
+ ; return ((fvs, vi), AnnVar v)
+ }
-vectAvoidInfo (_, AnnType {})
- = return $ VITNode VISimple []
+vectAvoidInfo _pvs ce@(fvs, AnnLit lit)
+ = do
+ { vi <- vectAvoidInfoTypeOf ce
+ ; viTrace ce vi []
+ ; return ((fvs, vi), AnnLit lit)
+ }
-vectAvoidInfo (_, AnnCoercion {})
- = return $ VITNode VISimple []
+vectAvoidInfo pvs ce@(fvs, AnnApp e1 e2)
+ = do
+ { ceVI <- vectAvoidInfoTypeOf ce
+ ; eVI1 <- vectAvoidInfo pvs e1
+ ; eVI2 <- vectAvoidInfo pvs e2
+ ; let vi = ceVI `unlessVIParrExpr` eVI1 `unlessVIParrExpr` eVI2
+ ; viTrace ce vi [eVI1, eVI2]
+ ; return ((fvs, vi), AnnApp eVI1 eVI2)
+ }
+
+vectAvoidInfo pvs ce@(fvs, AnnLam var body)
+ = do
+ { bodyVI <- vectAvoidInfo pvs body
+ ; varVI <- vectAvoidInfoType $ varType var
+ ; let vi = vectAvoidInfoOf bodyVI `unlessVIParr` varVI
+ ; viTrace ce vi [bodyVI]
+ ; return ((fvs, vi), AnnLam var bodyVI)
+ }
+
+vectAvoidInfo pvs ce@(fvs, AnnLet (AnnNonRec var e) body)
+ = do
+ { ceVI <- vectAvoidInfoTypeOf ce
+ ; eVI <- vectAvoidInfo pvs e
+ ; isScalarTy <- isScalar $ varType var
+ ; (bodyVI, vi) <- if isVIParr eVI && not isScalarTy
+ then do -- binding is parallel
+ { bodyVI <- vectAvoidInfo (fvs `extendVarSet` var) body
+ ; return (bodyVI, VIParr)
+ }
+ else do -- binding doesn't affect parallelism
+ { bodyVI <- vectAvoidInfo fvs body
+ ; return (bodyVI, ceVI `unlessVIParrExpr` bodyVI)
+ }
+ ; viTrace ce vi [eVI, bodyVI]
+ ; return ((fvs, vi), AnnLet (AnnNonRec var eVI) bodyVI)
+ }
+
+vectAvoidInfo pvs ce@(fvs, AnnLet (AnnRec bnds) body)
+ = do
+ { ceVI <- vectAvoidInfoTypeOf ce
+ ; bndsVI <- mapM (vectAvoidInfoBnd pvs) bnds
+ ; parrBndrs <- map fst <$> filterM isVIParrBnd bndsVI
+ ; if not . null $ parrBndrs
+ then do -- body may trigger parallelism via at least one binding
+ { new_pvs <- filterM ((not <$>) . isScalar . varType) parrBndrs
+ ; let extendedPvs = pvs `extendVarSetList` new_pvs
+ ; bndsVI <- mapM (vectAvoidInfoBnd extendedPvs) bnds
+ ; bodyVI <- vectAvoidInfo extendedPvs body
+ ; viTrace ce VIParr (map snd bndsVI ++ [bodyVI])
+ ; return ((fvs, VIParr), AnnLet (AnnRec bndsVI) bodyVI)
+ }
+ else do -- demanded bindings cannot trigger parallelism
+ { bodyVI <- vectAvoidInfo pvs body
+ ; let vi = ceVI `unlessVIParrExpr` bodyVI
+ ; viTrace ce vi (map snd bndsVI ++ [bodyVI])
+ ; return ((fvs, vi), AnnLet (AnnRec bndsVI) bodyVI)
+ }
+ }
+ where
+ vectAvoidInfoBnd pvs (var, e) = (var,) <$> vectAvoidInfo pvs e
+
+ isVIParrBnd (var, eVI)
+ = do
+ { isScalarTy <- isScalar (varType var)
+ ; return $ isVIParr eVI && not isScalarTy
+ }
+
+vectAvoidInfo pvs ce@(fvs, AnnCase e var ty alts)
+ = do
+ { ceVI <- vectAvoidInfoTypeOf ce
+ ; eVI <- vectAvoidInfo pvs e
+ ; isScalarTy <- isScalar . annExprType $ e
+ ; altsVI <- mapM (vectAvoidInfoAlt (isVIParr eVI && not isScalarTy)) alts
+ ; allScalarBndrs <- anyM allScalarAltBndrs altsVI
+ ; let alteVIs = [eVI | (_, _, eVI) <- altsVI]
+ vi | isVIParr eVI && not allScalarBndrs = VIParr
+ | otherwise
+ = foldl unlessVIParrExpr ceVI alteVIs
+ ; viTrace ce vi (eVI : alteVIs)
+ ; return ((fvs, vi), AnnCase eVI var ty altsVI)
+ }
+ where
+ vectAvoidInfoAlt isScalarScrut (con, bndrs, e) = (con, bndrs,) <$> vectAvoidInfo altPvs e
+ where
+ altPvs | isScalarScrut = pvs
+ | otherwise = pvs `extendVarSetList` bndrs
+
+ allScalarAltBndrs (_, bndrs, _) = allScalarVarType bndrs
+
+vectAvoidInfo pvs (fvs, AnnCast e (fvs_ann, ann))
+ = do
+ { eVI <- vectAvoidInfo pvs e
+ ; return ((fvs, vectAvoidInfoOf eVI), AnnCast eVI ((fvs_ann, VISimple), ann))
+ }
+
+vectAvoidInfo pvs (fvs, AnnTick tick e)
+ = do
+ { eVI <- vectAvoidInfo pvs e
+ ; return ((fvs, vectAvoidInfoOf eVI), AnnTick tick eVI)
+ }
+
+vectAvoidInfo _pvs (fvs, AnnType ty)
+ = return ((fvs, VISimple), AnnType ty)
+
+vectAvoidInfo _pvs (fvs, AnnCoercion coe)
+ = return ((fvs, VISimple), AnnCoercion coe)
-- Compute vectorisation avoidance information for a type.
--
vectAvoidInfoType :: Type -> VM VectAvoidInfo
-vectAvoidInfoType ty
- | maybeParrTy ty = return VIParr
- | otherwise
- = do { sType <- isSimpleType ty
- ; if sType
- then return VISimple
- else return VIComplex
- }
+vectAvoidInfoType ty
+ | isPredTy ty
+ = return VIDict
+ | Just (arg, res) <- splitFunTy_maybe ty
+ = do
+ { argVI <- vectAvoidInfoType arg
+ ; resVI <- vectAvoidInfoType res
+ ; case (argVI, resVI) of
+ (VISimple, VISimple) -> return VISimple -- NB: diverts from the paper: scalar functions
+ (_ , VIDict) -> return VIDict
+ _ -> return $ VIComplex `unlessVIParr` argVI `unlessVIParr` resVI
+ }
+ | otherwise
+ = do
+ { parr <- maybeParrTy ty
+ ; if parr
+ then return VIParr
+ else do
+ { scalar <- isScalar ty
+ ; if scalar
+ then return VISimple
+ else return VIComplex
+ } }
+
+-- Compute vectorisation avoidance information for the type of a Core expression (with FVs).
+--
+vectAvoidInfoTypeOf :: AnnExpr Var ann -> VM VectAvoidInfo
+vectAvoidInfoTypeOf = vectAvoidInfoType . annExprType
--- Checks whether the type might be a parallel array type. In particular, if the outermost
--- constructor is a type family, we conservatively assume that it may be a parallel array type.
+-- Checks whether the type might be a parallel array type.
--
-maybeParrTy :: Type -> Bool
+maybeParrTy :: Type -> VM Bool
maybeParrTy ty
- | Just ty' <- coreView ty = maybeParrTy ty'
- | Just (tyCon, ts) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon
- || or (map maybeParrTy ts)
-maybeParrTy _ = False
-
--- FIXME: This should not be hardcoded.
-isSimpleType :: Type -> VM Bool
-isSimpleType ty
- | Just (c, _cs) <- splitTyConApp_maybe ty
- = return $ (tyConName c) `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName, floatTyConName]
-{-
- = do { globals <- globalScalarTyCons
- ; traceVt ("isSimpleType " ++ (show (elemNameSet (tyConName c) globals ))) (ppr c)
- ; return (elemNameSet (tyConName c) globals )
- }
- -}
- | Nothing <- splitTyConApp_maybe ty
- = return False
-isSimpleType ty
- = pprPanic "Vectorise.Exp.isSimpleType not handled" (ppr ty)
-
-varsSimple :: VarSet -> VM Bool
-varsSimple vs
- = do { varTypes <- mapM isSimpleType $ map varType $ varSetElems vs
- ; return $ and varTypes
- }
-
-viTrace :: CoreExprWithFVs -> VectAvoidInfo -> [VITree] -> VM ()
-viTrace ce vi vTs
- = traceVt ("vitrace " ++ (show vi) ++ "[" ++ (concat $ map (\(VITNode vi _) -> show vi ++ " ") vTs) ++"]")
- (ppr $ deAnnotate ce)
-
+ -- looking through newtypes
+ | Just ty' <- coreView ty
+ = (== VIParr) <$> vectAvoidInfoType ty'
+ -- decompose constructor applications
+ | Just (tc, ts) <- splitTyConApp_maybe ty
+ = do
+ { isParallel <- (tyConName tc `elemNameSet`) <$> globalParallelTyCons
+ ; if isParallel
+ then return True
+ else or <$> mapM maybeParrTy ts
+ }
+maybeParrTy (ForAllTy _ ty) = maybeParrTy ty
+maybeParrTy _ = return False
-{-
----- Sanity check of the tree, for debugging only
-checkTree :: VITree -> CoreExpr -> Bool
-checkTree (VITNode _ []) (Type _ty)
- = True
-
-checkTree (VITNode _ []) (Var _v)
- = True
-
-checkTree (VITNode _ []) (Lit _)
- = True
-
-checkTree (VITNode _ [vit]) (Tick _ expr)
- = checkTree vit expr
-
-checkTree (VITNode _ [vit]) (Lam _ expr)
- = checkTree vit expr
-
-checkTree (VITNode _ [vit1, vit2]) (App ce1 ce2)
- = (checkTree vit1 ce1) && (checkTree vit2 ce2)
-
-checkTree (VITNode _ (scrutVit : altVits)) (Case scrut _ _ alts)
- = (checkTree scrutVit scrut) && (and $ zipWith checkAlt altVits alts)
- where
- checkAlt vt (_, _, expr) = checkTree vt expr
-
-checkTree (VITNode _ [vt1, vt2]) (Let (NonRec _ expr1) expr2)
- = (checkTree vt1 expr1) && (checkTree vt2 expr2)
-
-checkTree (VITNode _ (vtB : vtBnds)) (Let (Rec bndngs) expr)
- = (and $ zipWith checkBndr vtBnds bndngs) &&
- (checkTree vtB expr)
- where
- checkBndr vt (_, e) = checkTree vt e
-
-checkTree (VITNode _ [vit]) (Cast expr _)
- = checkTree vit expr
+-- Are the types of all variables in the 'Scalar' class?
+--
+allScalarVarType :: [Var] -> VM Bool
+allScalarVarType vs = and <$> mapM (isScalar . varType) vs
-checkTree _ _ = False
+-- Are the types of all variables in the set in the 'Scalar' class?
+--
+allScalarVarTypeSet :: VarSet -> VM Bool
+allScalarVarTypeSet = allScalarVarType . varSetElems
-checkTreeAnnM:: VITree -> CoreExprWithFVs -> VM ()
-checkTreeAnnM vi e =
- if not (checkTree vi $ deAnnotate e)
- then error ("checkTreeAnnM : \n " ++ show vi)
- else return ()
--}
+-- Debugging support
+--
+viTrace :: CoreExprWithFVs -> VectAvoidInfo -> [CoreExprWithVectInfo] -> VM ()
+viTrace ce vi vTs
+ = traceVt ("vect info: " ++ show vi ++ "[" ++
+ (concat $ map ((++ " ") . show . vectAvoidInfoOf) vTs) ++ "]")
+ (ppr $ deAnnotate ce)