summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkeller@cse.unsw.edu.au <unknown>2011-02-14 00:29:45 +0000
committerkeller@cse.unsw.edu.au <unknown>2011-02-14 00:29:45 +0000
commit80cb2c397aec9751586c3a2a753f848e143dbd67 (patch)
tree095c2589f8b775c362eb90ba25a92b6040dc3cea
parent37b0cb1147cadef4d68f3fc61faa3ec11ad47440 (diff)
downloadhaskell-80cb2c397aec9751586c3a2a753f848e143dbd67.tar.gz
Handling of recursive scalar functions in isScalarLam
-rw-r--r--compiler/vectorise/Vectorise.hs39
-rw-r--r--compiler/vectorise/Vectorise/Exp.hs43
-rw-r--r--compiler/vectorise/Vectorise/Monad.hs8
3 files changed, 57 insertions, 33 deletions
diff --git a/compiler/vectorise/Vectorise.hs b/compiler/vectorise/Vectorise.hs
index 8c9579e621..999e8ef9e1 100644
--- a/compiler/vectorise/Vectorise.hs
+++ b/compiler/vectorise/Vectorise.hs
@@ -115,7 +115,7 @@ vectModule guts
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
= do
- (inline, expr') <- vectTopRhs var expr
+ (inline, _, expr') <- vectTopRhs [] var expr
var' <- vectTopBinder var inline expr'
-- Vectorising the body may create other top-level bindings.
@@ -131,15 +131,23 @@ vectTopBind b@(NonRec var expr)
vectTopBind b@(Rec bs)
= do
+ -- pprTrace "in Rec" (ppr vars) $ return ()
(vars', _, exprs')
<- fixV $ \ ~(_, inlines, rhss) ->
do vars' <- sequence [vectTopBinder var inline rhs
| (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
- (inlines', exprs')
- <- mapAndUnzipM (uncurry vectTopRhs) bs
-
- return (vars', inlines', exprs')
-
+ (inlines', areScalars', exprs')
+ <- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs
+ if (and areScalars') || (length bs <= 1)
+ then do
+ -- pprTrace "in Rec - all scalars??" (ppr areScalars') $ return ()
+ return (vars', inlines', exprs')
+ else do
+ -- pprTrace "in Rec - not all scalars" (ppr areScalars') $ return ()
+ mapM deleteGlobalScalar vars
+ (inlines'', _, exprs'') <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs
+ return (vars', inlines'', exprs'')
+
hs <- takeHoisted
cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
@@ -147,7 +155,9 @@ vectTopBind b@(Rec bs)
return b
where
(vars, exprs) = unzip bs
-
+ mapAndUnzip3M f xs = do
+ ys <- mapM f xs
+ return $ unzip3 ys
-- | Make the vectorised version of this top level binder, and add the mapping
-- between it and the original to the state. For some binder @foo@ the vectorised
@@ -182,21 +192,22 @@ vectTopBinder var inline expr
-- | Vectorise the RHS of a top-level binding, in an empty local environment.
vectTopRhs
- :: Var -- ^ Name of the binding.
+ :: [Var] -- ^ Names of all functions in the rec block
+ -> Var -- ^ Name of the binding.
-> CoreExpr -- ^ Body of the binding.
- -> VM (Inline, CoreExpr)
+ -> VM (Inline, Bool, CoreExpr)
-vectTopRhs var expr
+vectTopRhs recFs var expr
= dtrace (vcat [text "vectTopRhs", ppr expr])
$ closedV
$ do (inline, isScalar, vexpr) <- inBind var
- $ pprTrace "vectTopRhs" (ppr var)
- $ vectPolyExpr (isLoopBreaker $ idOccInfo var)
+ -- $ pprTrace "vectTopRhs" (ppr var)
+ $ vectPolyExpr (isLoopBreaker $ idOccInfo var) recFs
(freeVars expr)
if isScalar
then addGlobalScalar var
- else return ()
- return (inline, vectorised vexpr)
+ else deleteGlobalScalar var
+ return (inline, isScalar, vectorised vexpr)
-- | Project out the vectorised version of a binding from some closure,
diff --git a/compiler/vectorise/Vectorise/Exp.hs b/compiler/vectorise/Vectorise/Exp.hs
index b94224ab7b..091a760d79 100644
--- a/compiler/vectorise/Vectorise/Exp.hs
+++ b/compiler/vectorise/Vectorise/Exp.hs
@@ -35,20 +35,21 @@ import Data.List
-- | Vectorise a polymorphic expression.
vectPolyExpr
:: Bool -- ^ When vectorising the RHS of a binding, whether that
- -- binding is a loop breaker.
+ -- binding is a loop breaker.
+ -> [Var]
-> CoreExprWithFVs
-> VM (Inline, Bool, VExpr)
-vectPolyExpr loop_breaker (_, AnnNote note expr)
- = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker expr
+vectPolyExpr loop_breaker recFns (_, AnnNote note expr)
+ = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr
return (inline, isScalarFn, vNote note expr')
-vectPolyExpr loop_breaker expr
+vectPolyExpr loop_breaker recFns expr
= do
arity <- polyArity tvs
polyAbstract tvs $ \args ->
do
- (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker mono
+ (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono
return (addInlineArity inline arity, isScalarFn,
mapVect (mkLams $ tvs ++ args) mono')
where
@@ -117,7 +118,7 @@ vectExpr (_, AnnCase scrut bndr ty alts)
vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
= do
- vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False rhs
+ vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs
(vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
return $ vLet (vNonRec vbndr vrhs) vbody
@@ -134,10 +135,10 @@ vectExpr (_, AnnLet (AnnRec bs) body)
vect_rhs bndr rhs = localV
. inBind bndr
. liftM (\(_,_,z)->z)
- $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs
+ $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) [] rhs
vectExpr e@(_, AnnLam bndr _)
- | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False e
+ | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False [] e
{-
onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
`orElseV` vectLam True fvs bs body
@@ -152,18 +153,19 @@ vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnno
vectFnExpr
:: Bool -- ^ When the RHS of a binding, whether that binding should be inlined.
-> Bool -- ^ Whether the binding is a loop breaker.
+ -> [Var]
-> CoreExprWithFVs -- ^ Expression to vectorise. Must have an outer `AnnLam`.
-> VM (Inline, Bool, VExpr)
-vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
- | isId bndr = pprTrace "vectFnExpr -- id" (ppr fvs )$
+vectFnExpr inline loop_breaker recFns e@(fvs, AnnLam bndr _)
+ | isId bndr = -- pprTrace "vectFnExpr -- id" (ppr fvs )$
onlyIfV True -- (isEmptyVarSet fvs) -- we check for free variables later. TODO: clean up
- (mark DontInline True . vectScalarLam bs $ deAnnotate body)
+ (mark DontInline True . vectScalarLam bs recFns $ deAnnotate body)
`orElseV` mark inlineMe False (vectLam inline loop_breaker fvs bs body)
where
(bs,body) = collectAnnValBinders e
-vectFnExpr _ _ e = pprTrace "vectFnExpr -- otherwise" (ppr "a" )$ mark DontInline False $ vectExpr e
+vectFnExpr _ _ _ e = pprTrace "vectFnExpr -- otherwise" (ppr "a" )$ mark DontInline False $ vectExpr e
mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
@@ -172,13 +174,18 @@ mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
-- | Vectorise a function where are the args have scalar type,
-- that is Int, Float, Double etc.
vectScalarLam
- :: [Var] -- ^ Bound variables of function.
+ :: [Var] -- ^ Bound variables of function
+ -> [Var]
-> CoreExpr -- ^ Function body.
-> VM VExpr
-vectScalarLam args body
- = do scalars <- globalScalars
- pprTrace "vectScalarLam" (ppr $ is_scalar (extendVarSetList scalars args) body) $
+vectScalarLam args recFns body
+ = do scalars' <- globalScalars
+ let scalars = unionVarSet (mkVarSet recFns) scalars'
+{- pprTrace "vectScalarLam uses" (ppr $ uses scalars body) $
+ pprTrace "vectScalarLam is prim res" (ppr $ is_prim_ty res_ty) $
+ pprTrace "vectScalarLam is scalar body" (ppr $ is_scalar (extendVarSetList scalars args) body) $
+ pprTrace "vectScalarLam arg tys" (ppr $ arg_tys) $ -}
onlyIfV (all is_prim_ty arg_tys
&& is_prim_ty res_ty
&& is_scalar (extendVarSetList scalars args) body
@@ -190,7 +197,7 @@ vectScalarLam args body
(zipf `App` Var fn_var)
clo_var <- hoistExpr (fsLit "clo") clo DontInline
lclo <- liftPD (Var clo_var)
- pprTrace " lam is scalar" (ppr "") $
+ {- pprTrace " lam is scalar" (ppr "") $ -}
return (Var clo_var, lclo)
where
arg_tys = map idType args
@@ -214,7 +221,7 @@ vectScalarLam args body
| isPrimTyCon tycon = False
| isAbstractTyCon tycon = True
| isFunTyCon tycon || isProductTyCon tycon || isTupleTyCon tycon = any (maybe_parr_ty' alreadySeen) args
- | isDataTyCon tycon = pprTrace "isDataTyCon" (ppr tycon) $
+ | isDataTyCon tycon = -- pprTrace "isDataTyCon" (ppr tycon) $
any (maybe_parr_ty' alreadySeen) args ||
hasParrDataCon alreadySeen tycon
| otherwise = True
diff --git a/compiler/vectorise/Vectorise/Monad.hs b/compiler/vectorise/Vectorise/Monad.hs
index 77b9b7fdf3..259743058e 100644
--- a/compiler/vectorise/Vectorise/Monad.hs
+++ b/compiler/vectorise/Vectorise/Monad.hs
@@ -17,7 +17,8 @@ module Vectorise.Monad (
maybeCantVectoriseVarM,
dumpVar,
addGlobalScalar,
-
+ deleteGlobalScalar,
+
-- * Primitives
lookupPrimPArray,
lookupPrimMethod
@@ -146,6 +147,11 @@ addGlobalScalar :: Var -> VM ()
addGlobalScalar var
= updGEnv $ \env -> pprTrace "addGLobalScalar" (ppr var) env{global_scalars = extendVarSet (global_scalars env) var}
+deleteGlobalScalar :: Var -> VM ()
+deleteGlobalScalar var
+ = updGEnv $ \env -> pprTrace "deleteGLobalScalar" (ppr var) env{global_scalars = delVarSet (global_scalars env) var}
+
+
-- Primitives -----------------------------------------------------------------
lookupPrimPArray :: TyCon -> VM (Maybe TyCon)
lookupPrimPArray = liftBuiltinDs . primPArray