summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorRoman Leshchinskiy <rl@cse.unsw.edu.au>2009-10-30 00:41:37 +0000
committerRoman Leshchinskiy <rl@cse.unsw.edu.au>2009-10-30 00:41:37 +0000
commit222415a5b658e737a0a1f2c980c6f80635289f75 (patch)
treedca1069b5dc9b378ec190b3247179472f2ad7fa4 /compiler
parentcfccfa67393fcf8cb43aaa465d421b67c7117580 (diff)
downloadhaskell-222415a5b658e737a0a1f2c980c6f80635289f75.tar.gz
Adapt vectoriser to new inlining mechanism
Diffstat (limited to 'compiler')
-rw-r--r--compiler/vectorise/VectCore.hs8
-rw-r--r--compiler/vectorise/VectType.hs131
-rw-r--r--compiler/vectorise/VectUtils.hs74
-rw-r--r--compiler/vectorise/Vectorise.hs115
4 files changed, 202 insertions, 126 deletions
diff --git a/compiler/vectorise/VectCore.hs b/compiler/vectorise/VectCore.hs
index d651526ddf..cdae4dd996 100644
--- a/compiler/vectorise/VectCore.hs
+++ b/compiler/vectorise/VectCore.hs
@@ -10,7 +10,7 @@ module VectCore (
vVar, vType, vNote, vLet,
vLams, vLamsWithoutLC, vVarApps,
- vCaseDEFAULT, vInlineMe
+ vCaseDEFAULT
) where
#include "HsVersions.h"
@@ -18,7 +18,6 @@ module VectCore (
import CoreSyn
import Type ( Type )
import Var
-import Outputable
type Vect a = (a,a)
type VVar = Vect Var
@@ -83,8 +82,3 @@ vCaseDEFAULT (vscrut, lscrut) (vbndr, lbndr) vty lty (vbody, lbody)
where
mkDEFAULT e = [(DEFAULT, [], e)]
-vInlineMe :: VExpr -> VExpr
-vInlineMe (vexpr, lexpr) = (mkInlineMe vexpr, mkInlineMe lexpr)
-
-mkInlineMe :: CoreExpr -> CoreExpr
-mkInlineMe = pprTrace "VectCore.mkInlineMe" (text "Roman: need to replace mkInlineMe with an InlineRule somehow")
diff --git a/compiler/vectorise/VectType.hs b/compiler/vectorise/VectType.hs
index 7b9ec50e83..6e7557e9e2 100644
--- a/compiler/vectorise/VectType.hs
+++ b/compiler/vectorise/VectType.hs
@@ -11,6 +11,7 @@ import VectCore
import HscTypes ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
import CoreSyn
import CoreUtils
+import CoreUnfold
import MkCore ( mkWildCase )
import BuildTyCl
import DataCon
@@ -20,9 +21,11 @@ import TypeRep
import Coercion
import FamInstEnv ( FamInst, mkLocalFamInst )
import OccName
+import Id
import MkId
-import BasicTypes ( StrictnessMark(..), boolToRecFlag )
-import Var ( Var, TyVar )
+import BasicTypes ( StrictnessMark(..), boolToRecFlag,
+ dfunInlinePragma )
+import Var ( Var, TyVar, varType )
import Name ( Name, getOccName )
import NameEnv
@@ -37,7 +40,7 @@ import FastString
import MonadUtils ( zipWith3M, foldrM, concatMapM )
import Control.Monad ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM )
-import Data.List ( inits, tails, zipWith4, zipWith6 )
+import Data.List ( inits, tails, zipWith4, zipWith5 )
-- ----------------------------------------------------------------------------
-- Types
@@ -119,26 +122,28 @@ vectTypeEnv env
let orig_tcs = keep_tcs ++ conv_tcs
vect_tcs = keep_tcs ++ new_tcs
- dfuns <- mapM mkPADFun vect_tcs
- defTyConPAs (zip vect_tcs dfuns)
- reprs <- mapM tyConRepr vect_tcs
- repr_tcs <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
- pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
- binds <- sequence (zipWith6 buildTyConBindings orig_tcs
- vect_tcs
- repr_tcs
- pdata_tcs
- dfuns
- reprs)
-
- let all_new_tcs = new_tcs ++ repr_tcs ++ pdata_tcs
+ (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
+ do
+ defTyConPAs (zipLazy vect_tcs dfuns')
+ reprs <- mapM tyConRepr vect_tcs
+ repr_tcs <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
+ pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
+ dfuns <- sequence $ zipWith5 buildTyConBindings orig_tcs
+ vect_tcs
+ repr_tcs
+ pdata_tcs
+ reprs
+ binds <- takeHoisted
+ return (dfuns, binds, repr_tcs ++ pdata_tcs)
+
+ let all_new_tcs = new_tcs ++ inst_tcs
let new_env = extendTypeEnvList env
(map ATyCon all_new_tcs
++ [ADataCon dc | tc <- all_new_tcs
, dc <- tyConDataCons tc])
- return (new_env, map mkLocalFamInst (repr_tcs ++ pdata_tcs), concat binds)
+ return (new_env, map mkLocalFamInst inst_tcs, binds)
where
tycons = typeEnvTyCons env
groups = tyConGroups tycons
@@ -715,18 +720,12 @@ buildPDataDataCon orig_name vect_tc repr_tc repr
comp_ty r = mkPDataType (compOrigType r)
-mkPADFun :: TyCon -> VM Var
-mkPADFun vect_tc
- = newExportedVar (mkPADFunOcc $ getOccName vect_tc) =<< paDFunType vect_tc
-
-buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> Var -> SumRepr
- -> VM [(Var, CoreExpr)]
-buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc dfun repr
+buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr
+ -> VM Var
+buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
= do
vectDataConWorkers orig_tc vect_tc pdata_tc
- dict <- buildPADict vect_tc prepr_tc pdata_tc repr
- binds <- takeHoisted
- return $ (dfun, dict) : binds
+ buildPADict vect_tc prepr_tc pdata_tc repr
vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
vectDataConWorkers orig_tc vect_tc arr_tc
@@ -781,53 +780,71 @@ vectDataConWorkers orig_tc vect_tc arr_tc
def_worker data_con arg_tys mk_body
= do
+ arity <- polyArity tyvars
body <- closedV
. inBind orig_worker
- . polyAbstract tyvars $ \abstract ->
- liftM (abstract . vectorised)
+ . polyAbstract tyvars $ \args ->
+ liftM (mkLams (tyvars ++ args) . vectorised)
$ buildClosures tyvars [] arg_tys res_ty mk_body
- vect_worker <- cloneId mkVectOcc orig_worker (exprType body)
+ raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
+ let vect_worker = raw_worker `setIdUnfolding`
+ mkInlineRule InlSat body arity
defGlobalVar orig_worker vect_worker
return (vect_worker, body)
where
orig_worker = dataConWorkId data_con
-buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
+buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
buildPADict vect_tc prepr_tc arr_tc repr
- = polyAbstract tvs $ \abstract ->
+ = polyAbstract tvs $ \args ->
do
- meth_binds <- mapM mk_method paMethods
- let meth_exprs = map (Var . fst) meth_binds
+ method_ids <- mapM (method args) paMethods
+
+ pa_tc <- builtin paTyCon
+ pa_con <- builtin paDataCon
+ let dict = mkLams (tvs ++ args)
+ $ mkConApp pa_con
+ $ Type inst_ty : map (method_call args) method_ids
+
+ dfun_ty = mkForAllTys tvs
+ $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
+
+ raw_dfun <- newExportedVar dfun_name dfun_ty
+ let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding pa_con method_ids
+ `setInlinePragma` dfunInlinePragma
- pa_dc <- builtin paDataCon
- let dict = mkConApp pa_dc (Type (mkTyConApp vect_tc arg_tys) : meth_exprs)
- body = Let (Rec meth_binds) dict
- return . mkInlineMe $ abstract body
+ hoistBinding dfun dict
+ return dfun
where
- tvs = tyConTyVars arr_tc
+ tvs = tyConTyVars vect_tc
arg_tys = mkTyVarTys tvs
+ inst_ty = mkTyConApp vect_tc arg_tys
- mk_method (name, build)
+ dfun_name = mkPADFunOcc (getOccName vect_tc)
+
+ method args (name, build)
= localV
$ do
- body <- build vect_tc prepr_tc arr_tc repr
- var <- newLocalVar name (exprType body)
- return (var, mkInlineMe body)
-
--- The InlineMe note has gone away. Instead, you need to use
--- CoreUnfold.mkInlineRule to make an InlineRule for the thing, and
--- attach *that* as the unfolding for the dictionary binder
-mkInlineMe :: CoreExpr -> CoreExpr
-mkInlineMe expr = pprTrace "VectType: Roman, you need to use the new InlineRule story"
- (ppr expr) expr
-
-paMethods :: [(FastString, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
-paMethods = [(fsLit "dictPRepr", buildPRDict),
- (fsLit "toPRepr", buildToPRepr),
- (fsLit "fromPRepr", buildFromPRepr),
- (fsLit "toArrPRepr", buildToArrPRepr),
- (fsLit "fromArrPRepr", buildFromArrPRepr)]
+ expr <- build vect_tc prepr_tc arr_tc repr
+ let body = mkLams (tvs ++ args) expr
+ raw_var <- newExportedVar (method_name name) (exprType body)
+ let var = raw_var
+ `setIdUnfolding` mkInlineRule InlSat body (length args)
+ hoistBinding var body
+ return var
+
+ method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
+
+ method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
+
+
+paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
+paMethods = [("dictPRepr", buildPRDict),
+ ("toPRepr", buildToPRepr),
+ ("fromPRepr", buildFromPRepr),
+ ("toArrPRepr", buildToArrPRepr),
+ ("fromArrPRepr", buildFromArrPRepr)]
-- | Split the given tycons into two sets depending on whether they have to be
-- converted (first list) or not (second list). The first argument contains
diff --git a/compiler/vectorise/VectUtils.hs b/compiler/vectorise/VectUtils.hs
index e5084241d6..9faa0edf73 100644
--- a/compiler/vectorise/VectUtils.hs
+++ b/compiler/vectorise/VectUtils.hs
@@ -15,7 +15,8 @@ module VectUtils (
combinePD,
liftPD,
zipScalars, scalarClosure,
- polyAbstract, polyApply, polyVApply,
+ polyAbstract, polyApply, polyVApply, polyArity,
+ Inline(..), addInlineArity, inlineMe,
hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
buildClosure, buildClosures,
mkClosureApp
@@ -27,6 +28,7 @@ import VectMonad
import MkCore ( mkCoreTup, mkCoreTupTy, mkWildCase )
import CoreSyn
import CoreUtils
+import CoreUnfold ( mkInlineRule )
import Coercion
import Type
import TypeRep
@@ -34,6 +36,7 @@ import TyCon
import DataCon
import Var
import MkId ( unwrapFamInstScrut )
+import Id ( setIdUnfolding )
import TysWiredIn
import BasicTypes ( Boxity(..) )
import Literal ( Literal, mkMachInt )
@@ -43,7 +46,6 @@ import FastString
import Control.Monad
-
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs expr = go expr []
where
@@ -315,13 +317,14 @@ newLocalVVar fs vty
lv <- newLocalVar fs lty
return (vv,lv)
-polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
+polyAbstract :: [TyVar] -> ([Var] -> VM a) -> VM a
polyAbstract tvs p
= localV
$ do
mdicts <- mapM mk_dict_var tvs
- zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
- p (mk_lams mdicts)
+ zipWithM_ (\tv -> maybe (defLocalTyVar tv)
+ (defLocalTyVarWithPA tv . Var)) tvs mdicts
+ p (mk_args mdicts)
where
mk_dict_var tv = do
r <- paDictArgType tv
@@ -329,7 +332,12 @@ polyAbstract tvs p
Just ty -> liftM Just (newLocalVar (fsLit "dPA") ty)
Nothing -> return Nothing
- mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
+ mk_args mdicts = [dict | Just dict <- mdicts]
+
+polyArity :: [TyVar] -> VM Int
+polyArity tvs = do
+ tys <- mapM paDictArgType tvs
+ return $ length [() | Just _ <- tys]
polyApply :: CoreExpr -> [Type] -> VM CoreExpr
polyApply expr tys
@@ -343,31 +351,48 @@ polyVApply expr tys
dicts <- mapM paDictOfType tys
return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
+
+data Inline = Inline Int -- arity
+ | DontInline
+
+addInlineArity :: Inline -> Int -> Inline
+addInlineArity (Inline m) n = Inline (m+n)
+addInlineArity DontInline _ = DontInline
+
+inlineMe :: Inline
+inlineMe = Inline 0
+
hoistBinding :: Var -> CoreExpr -> VM ()
hoistBinding v e = updGEnv $ \env ->
env { global_bindings = (v,e) : global_bindings env }
-hoistExpr :: FastString -> CoreExpr -> VM Var
-hoistExpr fs expr
+hoistExpr :: FastString -> CoreExpr -> Inline -> VM Var
+hoistExpr fs expr inl
= do
- var <- newLocalVar fs (exprType expr)
+ var <- mk_inline `liftM` newLocalVar fs (exprType expr)
hoistBinding var expr
return var
+ where
+ mk_inline var = case inl of
+ Inline arity -> var `setIdUnfolding`
+ mkInlineRule InlSat expr arity
+ DontInline -> var
-hoistVExpr :: VExpr -> VM VVar
-hoistVExpr (ve, le)
+hoistVExpr :: VExpr -> Inline -> VM VVar
+hoistVExpr (ve, le) inl
= do
fs <- getBindName
- vv <- hoistExpr ('v' `consFS` fs) ve
- lv <- hoistExpr ('l' `consFS` fs) le
+ vv <- hoistExpr ('v' `consFS` fs) ve inl
+ lv <- hoistExpr ('l' `consFS` fs) le (addInlineArity inl 1)
return (vv, lv)
-hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
-hoistPolyVExpr tvs p
+hoistPolyVExpr :: [TyVar] -> Inline -> VM VExpr -> VM VExpr
+hoistPolyVExpr tvs inline p
= do
- expr <- closedV . polyAbstract tvs $ \abstract ->
- liftM (mapVect abstract) p
- fn <- hoistVExpr expr
+ inline' <- liftM (addInlineArity inline) (polyArity tvs)
+ expr <- closedV . polyAbstract tvs $ \args ->
+ liftM (mapVect (mkLams $ tvs ++ args)) p
+ fn <- hoistVExpr expr inline'
polyVApply (vVar fn) (mkTyVarTys tvs)
takeHoisted :: VM [(Var, CoreExpr)]
@@ -413,14 +438,15 @@ buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
buildClosures _ _ [] _ mk_body
= mk_body
buildClosures tvs vars [arg_ty] res_ty mk_body
- = liftM vInlineMe (buildClosure tvs vars arg_ty res_ty mk_body)
+ = -- liftM vInlineMe $
+ buildClosure tvs vars arg_ty res_ty mk_body
buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
= do
res_ty' <- mkClosureTypes arg_tys res_ty
arg <- newLocalVVar (fsLit "x") arg_ty
- liftM vInlineMe
- . buildClosure tvs vars arg_ty res_ty'
- . hoistPolyVExpr tvs
+ -- liftM vInlineMe
+ buildClosure tvs vars arg_ty res_ty'
+ . hoistPolyVExpr tvs (Inline (length vars + 1))
$ do
lc <- builtin liftingContext
clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
@@ -438,11 +464,11 @@ buildClosure tvs vars arg_ty res_ty mk_body
env_bndr <- newLocalVVar (fsLit "env") env_ty
arg_bndr <- newLocalVVar (fsLit "arg") arg_ty
- fn <- hoistPolyVExpr tvs
+ fn <- hoistPolyVExpr tvs (Inline 2)
$ do
lc <- builtin liftingContext
body <- mk_body
- return . vInlineMe
+ return -- . vInlineMe
. vLams lc [env_bndr, arg_bndr]
$ bind (vVar env_bndr)
(vVarApps lc body (vars ++ [arg_bndr]))
diff --git a/compiler/vectorise/Vectorise.hs b/compiler/vectorise/Vectorise.hs
index 2bce391a8f..59fded3c4f 100644
--- a/compiler/vectorise/Vectorise.hs
+++ b/compiler/vectorise/Vectorise.hs
@@ -12,6 +12,7 @@ import HscTypes hiding ( MonadThings(..) )
import Module ( PackageId )
import CoreSyn
import CoreUtils
+import CoreUnfold ( mkInlineRule )
import MkCore ( mkWildCase )
import CoreFVs
import CoreMonad ( CoreM, getHscEnv )
@@ -24,6 +25,7 @@ import VarEnv
import VarSet
import Id
import OccName
+import BasicTypes ( isLoopBreaker )
import Literal ( Literal, mkMachInt )
import TysWiredIn
@@ -31,7 +33,8 @@ import TysPrim ( intPrimTy )
import Outputable
import FastString
-import Control.Monad ( liftM, liftM2, zipWithM )
+import Util ( zipLazy )
+import Control.Monad
import Data.List ( sortBy, unzip4 )
vectorise :: PackageId -> ModGuts -> CoreM ModGuts
@@ -67,8 +70,8 @@ vectModule guts
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
= do
- var' <- vectTopBinder var
- expr' <- vectTopRhs var expr
+ (inline, expr') <- vectTopRhs var expr
+ var' <- vectTopBinder var inline expr'
hs <- takeHoisted
cexpr <- tryConvert var var' expr
return . Rec $ (var, cexpr) : (var', expr') : hs
@@ -77,8 +80,13 @@ vectTopBind b@(NonRec var expr)
vectTopBind b@(Rec bs)
= do
- vars' <- mapM vectTopBinder vars
- exprs' <- zipWithM vectTopRhs vars exprs
+ (vars', _, exprs') <- fixV $ \ ~(_, inlines, rhss) ->
+ do
+ vars' <- sequence [vectTopBinder var inline rhs
+ | (var, ~(inline, rhs))
+ <- zipLazy vars (zip inlines rhss)]
+ (inlines', exprs') <- mapAndUnzipM (uncurry vectTopRhs) bs
+ return (vars', inlines', exprs')
hs <- takeHoisted
cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
@@ -87,20 +95,28 @@ vectTopBind b@(Rec bs)
where
(vars, exprs) = unzip bs
-vectTopBinder :: Var -> VM Var
-vectTopBinder var
+-- NOTE: vectTopBinder *MUST* be lazy in inline and expr because of how it is
+-- used inside of fixV in vectTopBind
+vectTopBinder :: Var -> Inline -> CoreExpr -> VM Var
+vectTopBinder var inline expr
= do
vty <- vectType (idType var)
- var' <- cloneId mkVectOcc var vty
+ var' <- liftM (`setIdUnfolding` unfolding) $ cloneId mkVectOcc var vty
defGlobalVar var var'
return var'
+ where
+ unfolding = case inline of
+ Inline arity -> mkInlineRule InlSat expr arity
+ DontInline -> noUnfolding
-vectTopRhs :: Var -> CoreExpr -> VM CoreExpr
+vectTopRhs :: Var -> CoreExpr -> VM (Inline, CoreExpr)
vectTopRhs var expr
- = do
- closedV . liftM vectorised
- . inBind var
- $ vectPolyExpr (freeVars expr)
+ = closedV
+ $ do
+ (inline, vexpr) <- inBind var
+ $ vectPolyExpr (isLoopBreaker $ idOccInfo var)
+ (freeVars expr)
+ return (inline, vectorised vexpr)
tryConvert :: Var -> Var -> CoreExpr -> VM CoreExpr
tryConvert var vect_var rhs
@@ -187,14 +203,19 @@ vectLiteral lit
lexpr <- liftPD (Lit lit)
return (Lit lit, lexpr)
-vectPolyExpr :: CoreExprWithFVs -> VM VExpr
-vectPolyExpr (_, AnnNote note expr)
- = liftM (vNote note) $ vectPolyExpr expr
-vectPolyExpr expr
- = polyAbstract tvs $ \abstract ->
- do
- mono' <- vectFnExpr False mono
- return $ mapVect abstract mono'
+vectPolyExpr :: Bool -> CoreExprWithFVs -> VM (Inline, VExpr)
+vectPolyExpr loop_breaker (_, AnnNote note expr)
+ = do
+ (inline, expr') <- vectPolyExpr loop_breaker expr
+ return (inline, vNote note expr')
+vectPolyExpr loop_breaker expr
+ = do
+ arity <- polyArity tvs
+ polyAbstract tvs $ \args ->
+ do
+ (inline, mono') <- vectFnExpr False loop_breaker mono
+ return (addInlineArity inline arity,
+ mapVect (mkLams $ tvs ++ args) mono')
where
(tvs, mono) = collectAnnTypeBinders expr
@@ -245,7 +266,7 @@ vectExpr (_, AnnCase scrut bndr ty alts)
vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
= do
- vrhs <- localV . inBind bndr $ vectPolyExpr rhs
+ vrhs <- localV . inBind bndr . liftM snd $ vectPolyExpr False rhs
(vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
return $ vLet (vNonRec vbndr vrhs) vbody
@@ -254,17 +275,18 @@ vectExpr (_, AnnLet (AnnRec bs) body)
(vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
$ liftM2 (,)
(zipWithM vect_rhs bndrs rhss)
- (vectPolyExpr body)
+ (vectExpr body)
return $ vLet (vRec vbndrs vrhss) vbody
where
(bndrs, rhss) = unzip bs
vect_rhs bndr rhs = localV
. inBind bndr
- $ vectExpr rhs
+ . liftM snd
+ $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs
vectExpr e@(_, AnnLam bndr _)
- | isId bndr = vectFnExpr True e
+ | isId bndr = liftM snd $ vectFnExpr True False e
{-
onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
`orElseV` vectLam True fvs bs body
@@ -274,14 +296,17 @@ onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)
-vectFnExpr :: Bool -> CoreExprWithFVs -> VM VExpr
-vectFnExpr inline e@(fvs, AnnLam bndr _)
- | isId bndr = onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
- `orElseV` vectLam inline fvs bs body
+vectFnExpr :: Bool -> Bool -> CoreExprWithFVs -> VM (Inline, VExpr)
+vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
+ | isId bndr = onlyIfV (isEmptyVarSet fvs)
+ (mark DontInline . vectScalarLam bs $ deAnnotate body)
+ `orElseV` mark inlineMe (vectLam inline loop_breaker fvs bs body)
where
(bs,body) = collectAnnValBinders e
-vectFnExpr _ e = vectExpr e
+vectFnExpr _ _ e = mark DontInline $ vectExpr e
+mark :: Inline -> VM a -> VM (Inline, a)
+mark b p = do { x <- p; return (b,x) }
vectScalarLam :: [Var] -> CoreExpr -> VM VExpr
vectScalarLam args body
@@ -291,11 +316,11 @@ vectScalarLam args body
&& is_scalar_ty res_ty
&& is_scalar (extendVarSetList scalars args) body)
$ do
- fn_var <- hoistExpr (fsLit "fn") (mkLams args body)
+ fn_var <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
zipf <- zipScalars arg_tys res_ty
clo <- scalarClosure arg_tys res_ty (Var fn_var)
(zipf `App` Var fn_var)
- clo_var <- hoistExpr (fsLit "clo") clo
+ clo_var <- hoistExpr (fsLit "clo") clo DontInline
lclo <- liftPD (Var clo_var)
return (Var clo_var, lclo)
where
@@ -314,8 +339,8 @@ vectScalarLam args body
is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2
is_scalar _ _ = False
-vectLam :: Bool -> VarSet -> [Var] -> CoreExprWithFVs -> VM VExpr
-vectLam inline fvs bs body
+vectLam :: Bool -> Bool -> VarSet -> [Var] -> CoreExprWithFVs -> VM VExpr
+vectLam inline loop_breaker fvs bs body
= do
tyvars <- localTyVars
(vs, vvs) <- readLEnv $ \env ->
@@ -326,14 +351,28 @@ vectLam inline fvs bs body
res_ty <- vectType (exprType $ deAnnotate body)
buildClosures tyvars vvs arg_tys res_ty
- . hoistPolyVExpr tyvars
+ . hoistPolyVExpr tyvars (maybe_inline (length vs + length bs))
$ do
lc <- builtin liftingContext
(vbndrs, vbody) <- vectBndrsIn (vs ++ bs)
(vectExpr body)
- return . maybe_inline $ vLams lc vbndrs vbody
+ vbody' <- break_loop lc res_ty vbody
+ return $ vLams lc vbndrs vbody'
where
- maybe_inline = if inline then vInlineMe else id
+ maybe_inline n | inline = Inline n
+ | otherwise = DontInline
+
+ break_loop lc ty (ve, le)
+ | loop_breaker
+ = do
+ empty <- emptyPD ty
+ lty <- mkPDataType ty
+ return (ve, mkWildCase (Var lc) intPrimTy lty
+ [(DEFAULT, [], le),
+ (LitAlt (mkMachInt 0), [], empty)])
+
+ | otherwise = return (ve, le)
+
vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys
@@ -441,7 +480,7 @@ vectAlgCase tycon _ty_args scrut bndr ty alts
cmp _ DEFAULT = GT
cmp _ _ = panic "vectAlgCase/cmp"
- proc_alt arity sel vty lty (DataAlt dc, bndrs, body)
+ proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
= do
vect_dc <- maybeV (lookupDataCon dc)
let ntag = dataConTagZ vect_dc