summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compiler/vectorise/Vectorise.hs4
-rw-r--r--compiler/vectorise/Vectorise/Exp.hs197
2 files changed, 117 insertions, 84 deletions
diff --git a/compiler/vectorise/Vectorise.hs b/compiler/vectorise/Vectorise.hs
index c721414e37..cf11c046ba 100644
--- a/compiler/vectorise/Vectorise.hs
+++ b/compiler/vectorise/Vectorise.hs
@@ -210,7 +210,9 @@ vectTopBind b@(Rec binds)
; cantVectorise dflags noVectoriseErr (ppr b)
}
else do
- { -- For all bindings *with* a pragma, just use the pragma-supplied vectorised expression
+ { traceVt "[Vanilla]" $ vcat [ppr var <+> char '=' <+> ppr expr | (var, expr) <- binds]
+
+ -- For all bindings *with* a pragma, just use the pragma-supplied vectorised expression
; newBindsWPragma <- concat <$>
sequence [ vectTopBindAndConvert bind inlineMe expr'
| (bind, (_, Just (_, expr'))) <- zip binds vectDecls]
diff --git a/compiler/vectorise/Vectorise/Exp.hs b/compiler/vectorise/Vectorise/Exp.hs
index a32390260c..6ce1b9b025 100644
--- a/compiler/vectorise/Vectorise/Exp.hs
+++ b/compiler/vectorise/Vectorise/Exp.hs
@@ -99,11 +99,19 @@ vectTopExprs binds
= do
{ exprVIs <- mapM (vectAvoidAndEncapsulate emptyVarSet) exprs
; if all isVIEncaps exprVIs
- then
- return Nothing
+ -- if all bindings are scalar => don't vectorise this group of bindings
+ then return Nothing
else do
- { (areVIParr, vExprs) <- unzip <$> mapM encapsulateAndVect binds
- ; return $ Just (or areVIParr, vExprs)
+ { -- non-scalar bindings need to be vectorised
+ ; let areVIParr = any isVIParr exprVIs
+ ; revised_exprVIs <- if not areVIParr
+ -- if no binding is parallel => 'exprVIs' is ready for vectorisation
+ then return exprVIs
+ -- if any binding is parallel => recompute the vectorisation info
+ else mapM (vectAvoidAndEncapsulate (mkVarSet vars)) exprs
+
+ ; vExprs <- zipWithM vect vars revised_exprVIs
+ ; return $ Just (areVIParr, vExprs)
}
}
where
@@ -111,14 +119,13 @@ vectTopExprs binds
vectAvoidAndEncapsulate pvs = encapsulateScalars <=< vectAvoidInfo pvs . freeVars
- encapsulateAndVect (var, expr)
+ vect var exprVI
= do
- { exprVI <- vectAvoidAndEncapsulate (mkVarSet vars) expr
- ; vExpr <- closedV $
+ { vExpr <- closedV $
inBind var $
vectAnnPolyExpr (isStrongLoopBreaker $ idOccInfo var) exprVI
; inline <- computeInline exprVI
- ; return (isVIParr exprVI, (inline, vectorised vExpr))
+ ; return (inline, vectorised vExpr)
}
-- |Vectorise a polymorphic expression annotated with vectorisation information.
@@ -302,8 +309,8 @@ vectExpr (_, AnnVar v)
vectExpr (_, AnnLit lit)
= vectConst $ Lit lit
-vectExpr aexpr@(_, AnnLam bndr _)
- = vectFnExpr True False aexpr
+vectExpr aexpr@(_, AnnLam _ _)
+ = traceVt "vectExpr [AnnLam]:" (ppr . deAnnotate $ aexpr) >> vectFnExpr True False aexpr
-- SPECIAL CASE: Vectorise/lift 'patError @ ty err' by only vectorising/lifting the type 'ty';
-- its only purpose is to abort the program, but we need to adjust the type to keep CoreLint
@@ -368,19 +375,24 @@ vectExpr (_, AnnCase scrut bndr ty alts)
vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
= do
- { vrhs <- localV $
+ { traceVt "let binding (non-recursive)" empty
+ ; vrhs <- localV $
inBind bndr $
vectAnnPolyExpr False rhs
+ ; traceVt "let body (non-recursive)" empty
; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
; return $ vLet (vNonRec vbndr vrhs) vbody
}
vectExpr (_, AnnLet (AnnRec bs) body)
= do
- { (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
- $ liftM2 (,)
- (zipWithM vect_rhs bndrs rhss)
- (vectExpr body)
+ { (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs $ do
+ { traceVt "let bindings (recursive)" empty
+ ; vrhss <- zipWithM vect_rhs bndrs rhss
+ ; traceVt "let body (recursive)" empty
+ ; vbody <- vectExpr body
+ ; return (vrhss, vbody)
+ }
; return $ vLet (vRec vbndrs vrhss) vbody
}
where
@@ -442,7 +454,7 @@ vectFnExpr _ _ aexpr
= vectScalarFun . deAnnotate $ aexpr
| otherwise
-- not an abstraction: vectorise as a non-scalar vanilla expression
- -- NB: we can get here legitimately due to the recursion in the first case above
+ -- NB: we can get here due to the recursion in the first case above and from 'vectAnnPolyExpr'
= vectExpr aexpr
-- |Vectorise type and dictionary applications.
@@ -570,7 +582,7 @@ vectDictExpr (Coercion coe)
vectScalarFun :: CoreExpr -> VM VExpr
vectScalarFun expr
= do
- { traceVt "vectorise scalar functions:" (ppr expr)
+ { traceVt "vectScalarFun:" (ppr expr)
; let (arg_tys, res_ty) = splitFunTys (exprType expr)
; mkScalarFun arg_tys res_ty expr
}
@@ -700,7 +712,9 @@ vectLam :: Bool -- ^ Should the RHS of a binding be inlined?
-> CoreExprWithVectInfo -- ^ Body of abstraction.
-> VM VExpr
vectLam inline loop_breaker expr@((fvs, _vi), AnnLam _ _)
- = do { let (bndrs, body) = collectAnnValBinders expr
+ = do { traceVt "fully vectorise a lambda expression" (ppr . deAnnotate $ expr)
+
+ ; let (bndrs, body) = collectAnnValBinders expr
-- grab the in-scope type variables
; tyvars <- localTyVars
@@ -769,40 +783,47 @@ vectLam _ _ _ = panic "Vectorise.Exp.vectLam: not a lambda"
-- have to handle the case where v is a wild var correctly.
--
--- FIXME: this is too lazy
+-- FIXME: this is too lazy...is it?
vectAlgCase :: TyCon -> [Type] -> CoreExprWithVectInfo -> Var -> Type
-> [(AltCon, [Var], CoreExprWithVectInfo)]
-> VM VExpr
vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
= do
- vscrut <- vectExpr scrut
- (vty, lty) <- vectAndLiftType ty
- (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
- return $ vCaseDEFAULT vscrut vbndr vty lty vbody
-
+ { traceVt "scrutinee (DEFAULT only)" empty
+ ; vscrut <- vectExpr scrut
+ ; (vty, lty) <- vectAndLiftType ty
+ ; traceVt "alternative body (DEFAULT only)" empty
+ ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
+ ; return $ vCaseDEFAULT vscrut vbndr vty lty vbody
+ }
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
= do
- vscrut <- vectExpr scrut
- (vty, lty) <- vectAndLiftType ty
- (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
- return $ vCaseDEFAULT vscrut vbndr vty lty vbody
-
+ { traceVt "scrutinee (one shot w/o binders)" empty
+ ; vscrut <- vectExpr scrut
+ ; (vty, lty) <- vectAndLiftType ty
+ ; traceVt "alternative body (one shot w/o binders)" empty
+ ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
+ ; return $ vCaseDEFAULT vscrut vbndr vty lty vbody
+ }
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
= do
- (vty, lty) <- vectAndLiftType ty
- vexpr <- vectExpr scrut
- (vbndr, (vbndrs, (vect_body, lift_body)))
- <- vect_scrut_bndr
- . vectBndrsIn bndrs
- $ vectExpr body
- let (vect_bndrs, lift_bndrs) = unzip vbndrs
- (vscrut, lscrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
- vect_dc <- maybeV dataConErr (lookupDataCon dc)
-
- let vcase = mk_wild_case vscrut vty vect_dc vect_bndrs vect_body
+ { traceVt "scrutinee (one shot w/ binders)" empty
+ ; vexpr <- vectExpr scrut
+ ; (vty, lty) <- vectAndLiftType ty
+ ; traceVt "alternative body (one shot w/ binders)" empty
+ ; (vbndr, (vbndrs, (vect_body, lift_body)))
+ <- vect_scrut_bndr
+ . vectBndrsIn bndrs
+ $ vectExpr body
+ ; let (vect_bndrs, lift_bndrs) = unzip vbndrs
+ ; (vscrut, lscrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
+ ; vect_dc <- maybeV dataConErr (lookupDataCon dc)
+
+ ; let vcase = mk_wild_case vscrut vty vect_dc vect_bndrs vect_body
lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body
- return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
+ ; return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
+ }
where
vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
| otherwise = vectBndrIn bndr
@@ -814,36 +835,40 @@ vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
vectAlgCase tycon _ty_args scrut bndr ty alts
= do
- vect_tc <- vectTyCon tycon
- (vty, lty) <- vectAndLiftType ty
+ { traceVt "scrutinee (general case)" empty
+ ; vexpr <- vectExpr scrut
+
+ ; vect_tc <- vectTyCon tycon
+ ; (vty, lty) <- vectAndLiftType ty
- let arity = length (tyConDataCons vect_tc)
- sel_ty <- builtin (selTy arity)
- sel_bndr <- newLocalVar (fsLit "sel") sel_ty
- let sel = Var sel_bndr
+ ; let arity = length (tyConDataCons vect_tc)
+ ; sel_ty <- builtin (selTy arity)
+ ; sel_bndr <- newLocalVar (fsLit "sel") sel_ty
+ ; let sel = Var sel_bndr
- (vbndr, valts) <- vect_scrut_bndr
+ ; traceVt "alternatives' body (general case)" empty
+ ; (vbndr, valts) <- vect_scrut_bndr
$ mapM (proc_alt arity sel vty lty) alts'
- let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
+ ; let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
- vexpr <- vectExpr scrut
- (vect_scrut, lift_scrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
+ ; (vect_scrut, lift_scrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
- let (vect_bodies, lift_bodies) = unzip vbodies
+ ; let (vect_bodies, lift_bodies) = unzip vbodies
- vdummy <- newDummyVar (exprType vect_scrut)
- ldummy <- newDummyVar (exprType lift_scrut)
- let vect_case = Case vect_scrut vdummy vty
+ ; vdummy <- newDummyVar (exprType vect_scrut)
+ ; ldummy <- newDummyVar (exprType lift_scrut)
+ ; let vect_case = Case vect_scrut vdummy vty
(zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)
- lc <- builtin liftingContext
- lbody <- combinePD vty (Var lc) sel lift_bodies
- let lift_case = Case lift_scrut ldummy lty
+ ; lc <- builtin liftingContext
+ ; lbody <- combinePD vty (Var lc) sel lift_bodies
+ ; let lift_case = Case lift_scrut ldummy lty
[(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
lbody)]
- return . vLet (vNonRec vbndr vexpr)
+ ; return . vLet (vNonRec vbndr vexpr)
$ (vect_case, lift_case)
+ }
where
vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
| otherwise = vectBndrIn bndr
@@ -871,12 +896,14 @@ vectAlgCase tycon _ty_args scrut bndr ty alts
<- vectBndrsIn bndrs
. localV
$ do
- binds <- mapM (pack_var (Var lc) sel_tags tag)
+ { binds <- mapM (pack_var (Var lc) sel_tags tag)
. filter isLocalId
$ varSetElems fvs
- (ve, le) <- vectExpr body
- return (ve, Case (elems `App` sel) lc lty
+ ; traceVt "case alternative:" (ppr . deAnnotate $ body)
+ ; (ve, le) <- vectExpr body
+ ; return (ve, Case (elems `App` sel) lc lty
[(DEFAULT, [], (mkLets (concat binds) le))])
+ }
-- empty <- emptyPD vty
-- return (ve, Case (elems `App` sel) lc lty
-- [(DEFAULT, [], Let (NonRec flags_var flags_expr)
@@ -887,25 +914,26 @@ vectAlgCase tycon _ty_args scrut bndr ty alts
where
dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)
-
proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"
mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)
+ -- Pack a variable for a case alternative context *if* the variable is vectorised. If it
+ -- isn't, ignore it as scalar variables don't need to be packed.
pack_var len tags t v
= do
- r <- lookupVar v
- case r of
- Local (vv, lv) ->
+ { r <- lookupVar_maybe v
+ ; case r of
+ Just (Local (vv, lv)) ->
do
- lv' <- cloneVar lv
- expr <- packByTagPD (idType vv) (Var lv) len tags t
- updLEnv (\env -> env { local_vars = extendVarEnv
- (local_vars env) v (vv, lv') })
- return [(NonRec lv' expr)]
-
+ { lv' <- cloneVar lv
+ ; expr <- packByTagPD (idType vv) (Var lv) len tags t
+ ; updLEnv (\env -> env { local_vars = extendVarEnv (local_vars env) v (vv, lv') })
+ ; return [(NonRec lv' expr)]
+ }
_ -> return []
-
+ }
+
-- Support to compute information for vectorisation avoidance ------------------
@@ -972,7 +1000,10 @@ vectAvoidInfo pvs ce@(fvs, AnnVar v)
; vi <- if v `elemVarSet` pvs || v `elemVarSet` gpvs
then return VIParr
else vectAvoidInfoTypeOf ce
- ; viTrace ce vi []
+ ; viTrace ce vi []
+ ; when (vi == VIParr) $
+ traceVt " reason:" $ if v `elemVarSet` pvs then text "local" else
+ if v `elemVarSet` gpvs then text "global" else text "parallel type"
; return ((fvs, vi), AnnVar v)
}
@@ -990,16 +1021,16 @@ vectAvoidInfo pvs ce@(fvs, AnnApp e1 e2)
; eVI1 <- vectAvoidInfo pvs e1
; eVI2 <- vectAvoidInfo pvs e2
; let vi = ceVI `unlessVIParrExpr` eVI1 `unlessVIParrExpr` eVI2
- ; viTrace ce vi [eVI1, eVI2]
+ -- ; viTrace ce vi [eVI1, eVI2]
; return ((fvs, vi), AnnApp eVI1 eVI2)
}
-vectAvoidInfo pvs ce@(fvs, AnnLam var body)
+vectAvoidInfo pvs (fvs, AnnLam var body)
= do
{ bodyVI <- vectAvoidInfo pvs body
; varVI <- vectAvoidInfoType $ varType var
; let vi = vectAvoidInfoOf bodyVI `unlessVIParr` varVI
- ; viTrace ce vi [bodyVI]
+ -- ; viTrace ce vi [bodyVI]
; return ((fvs, vi), AnnLam var bodyVI)
}
@@ -1010,14 +1041,14 @@ vectAvoidInfo pvs ce@(fvs, AnnLet (AnnNonRec var e) body)
; isScalarTy <- isScalar $ varType var
; (bodyVI, vi) <- if isVIParr eVI && not isScalarTy
then do -- binding is parallel
- { bodyVI <- vectAvoidInfo (fvs `extendVarSet` var) body
+ { bodyVI <- vectAvoidInfo (pvs `extendVarSet` var) body
; return (bodyVI, VIParr)
}
else do -- binding doesn't affect parallelism
- { bodyVI <- vectAvoidInfo fvs body
+ { bodyVI <- vectAvoidInfo pvs body
; return (bodyVI, ceVI `unlessVIParrExpr` bodyVI)
}
- ; viTrace ce vi [eVI, bodyVI]
+ -- ; viTrace ce vi [eVI, bodyVI]
; return ((fvs, vi), AnnLet (AnnNonRec var eVI) bodyVI)
}
@@ -1032,13 +1063,13 @@ vectAvoidInfo pvs ce@(fvs, AnnLet (AnnRec bnds) body)
; let extendedPvs = pvs `extendVarSetList` new_pvs
; bndsVI <- mapM (vectAvoidInfoBnd extendedPvs) bnds
; bodyVI <- vectAvoidInfo extendedPvs body
- ; viTrace ce VIParr (map snd bndsVI ++ [bodyVI])
+ -- ; viTrace ce VIParr (map snd bndsVI ++ [bodyVI])
; return ((fvs, VIParr), AnnLet (AnnRec bndsVI) bodyVI)
}
else do -- demanded bindings cannot trigger parallelism
{ bodyVI <- vectAvoidInfo pvs body
; let vi = ceVI `unlessVIParrExpr` bodyVI
- ; viTrace ce vi (map snd bndsVI ++ [bodyVI])
+ -- ; viTrace ce vi (map snd bndsVI ++ [bodyVI])
; return ((fvs, vi), AnnLet (AnnRec bndsVI) bodyVI)
}
}
@@ -1058,7 +1089,7 @@ vectAvoidInfo pvs ce@(fvs, AnnCase e var ty alts)
; altsVI <- mapM (vectAvoidInfoAlt (isVIParr eVI)) alts
; let alteVIs = [eVI | (_, _, eVI) <- altsVI]
vi = foldl unlessVIParrExpr ceVI (eVI:alteVIs) -- NB: same effect as in the paper
- ; viTrace ce vi (eVI : alteVIs)
+ -- ; viTrace ce vi (eVI : alteVIs)
; return ((fvs, vi), AnnCase eVI var ty altsVI)
}
where