diff options
-rw-r--r-- | compiler/vectorise/Vectorise.hs | 4 | ||||
-rw-r--r-- | compiler/vectorise/Vectorise/Exp.hs | 197 |
2 files changed, 117 insertions, 84 deletions
diff --git a/compiler/vectorise/Vectorise.hs b/compiler/vectorise/Vectorise.hs index c721414e37..cf11c046ba 100644 --- a/compiler/vectorise/Vectorise.hs +++ b/compiler/vectorise/Vectorise.hs @@ -210,7 +210,9 @@ vectTopBind b@(Rec binds) ; cantVectorise dflags noVectoriseErr (ppr b) } else do - { -- For all bindings *with* a pragma, just use the pragma-supplied vectorised expression + { traceVt "[Vanilla]" $ vcat [ppr var <+> char '=' <+> ppr expr | (var, expr) <- binds] + + -- For all bindings *with* a pragma, just use the pragma-supplied vectorised expression ; newBindsWPragma <- concat <$> sequence [ vectTopBindAndConvert bind inlineMe expr' | (bind, (_, Just (_, expr'))) <- zip binds vectDecls] diff --git a/compiler/vectorise/Vectorise/Exp.hs b/compiler/vectorise/Vectorise/Exp.hs index a32390260c..6ce1b9b025 100644 --- a/compiler/vectorise/Vectorise/Exp.hs +++ b/compiler/vectorise/Vectorise/Exp.hs @@ -99,11 +99,19 @@ vectTopExprs binds = do { exprVIs <- mapM (vectAvoidAndEncapsulate emptyVarSet) exprs ; if all isVIEncaps exprVIs - then - return Nothing + -- if all bindings are scalar => don't vectorise this group of bindings + then return Nothing else do - { (areVIParr, vExprs) <- unzip <$> mapM encapsulateAndVect binds - ; return $ Just (or areVIParr, vExprs) + { -- non-scalar bindings need to be vectorised + ; let areVIParr = any isVIParr exprVIs + ; revised_exprVIs <- if not areVIParr + -- if no binding is parallel => 'exprVIs' is ready for vectorisation + then return exprVIs + -- if any binding is parallel => recompute the vectorisation info + else mapM (vectAvoidAndEncapsulate (mkVarSet vars)) exprs + + ; vExprs <- zipWithM vect vars revised_exprVIs + ; return $ Just (areVIParr, vExprs) } } where @@ -111,14 +119,13 @@ vectTopExprs binds vectAvoidAndEncapsulate pvs = encapsulateScalars <=< vectAvoidInfo pvs . freeVars - encapsulateAndVect (var, expr) + vect var exprVI = do - { exprVI <- vectAvoidAndEncapsulate (mkVarSet vars) expr - ; vExpr <- closedV $ + { vExpr <- closedV $ inBind var $ vectAnnPolyExpr (isStrongLoopBreaker $ idOccInfo var) exprVI ; inline <- computeInline exprVI - ; return (isVIParr exprVI, (inline, vectorised vExpr)) + ; return (inline, vectorised vExpr) } -- |Vectorise a polymorphic expression annotated with vectorisation information. @@ -302,8 +309,8 @@ vectExpr (_, AnnVar v) vectExpr (_, AnnLit lit) = vectConst $ Lit lit -vectExpr aexpr@(_, AnnLam bndr _) - = vectFnExpr True False aexpr +vectExpr aexpr@(_, AnnLam _ _) + = traceVt "vectExpr [AnnLam]:" (ppr . deAnnotate $ aexpr) >> vectFnExpr True False aexpr -- SPECIAL CASE: Vectorise/lift 'patError @ ty err' by only vectorising/lifting the type 'ty'; -- its only purpose is to abort the program, but we need to adjust the type to keep CoreLint @@ -368,19 +375,24 @@ vectExpr (_, AnnCase scrut bndr ty alts) vectExpr (_, AnnLet (AnnNonRec bndr rhs) body) = do - { vrhs <- localV $ + { traceVt "let binding (non-recursive)" empty + ; vrhs <- localV $ inBind bndr $ vectAnnPolyExpr False rhs + ; traceVt "let body (non-recursive)" empty ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body) ; return $ vLet (vNonRec vbndr vrhs) vbody } vectExpr (_, AnnLet (AnnRec bs) body) = do - { (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs - $ liftM2 (,) - (zipWithM vect_rhs bndrs rhss) - (vectExpr body) + { (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs $ do + { traceVt "let bindings (recursive)" empty + ; vrhss <- zipWithM vect_rhs bndrs rhss + ; traceVt "let body (recursive)" empty + ; vbody <- vectExpr body + ; return (vrhss, vbody) + } ; return $ vLet (vRec vbndrs vrhss) vbody } where @@ -442,7 +454,7 @@ vectFnExpr _ _ aexpr = vectScalarFun . deAnnotate $ aexpr | otherwise -- not an abstraction: vectorise as a non-scalar vanilla expression - -- NB: we can get here legitimately due to the recursion in the first case above + -- NB: we can get here due to the recursion in the first case above and from 'vectAnnPolyExpr' = vectExpr aexpr -- |Vectorise type and dictionary applications. @@ -570,7 +582,7 @@ vectDictExpr (Coercion coe) vectScalarFun :: CoreExpr -> VM VExpr vectScalarFun expr = do - { traceVt "vectorise scalar functions:" (ppr expr) + { traceVt "vectScalarFun:" (ppr expr) ; let (arg_tys, res_ty) = splitFunTys (exprType expr) ; mkScalarFun arg_tys res_ty expr } @@ -700,7 +712,9 @@ vectLam :: Bool -- ^ Should the RHS of a binding be inlined? -> CoreExprWithVectInfo -- ^ Body of abstraction. -> VM VExpr vectLam inline loop_breaker expr@((fvs, _vi), AnnLam _ _) - = do { let (bndrs, body) = collectAnnValBinders expr + = do { traceVt "fully vectorise a lambda expression" (ppr . deAnnotate $ expr) + + ; let (bndrs, body) = collectAnnValBinders expr -- grab the in-scope type variables ; tyvars <- localTyVars @@ -769,40 +783,47 @@ vectLam _ _ _ = panic "Vectorise.Exp.vectLam: not a lambda" -- have to handle the case where v is a wild var correctly. -- --- FIXME: this is too lazy +-- FIXME: this is too lazy...is it? vectAlgCase :: TyCon -> [Type] -> CoreExprWithVectInfo -> Var -> Type -> [(AltCon, [Var], CoreExprWithVectInfo)] -> VM VExpr vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)] = do - vscrut <- vectExpr scrut - (vty, lty) <- vectAndLiftType ty - (vbndr, vbody) <- vectBndrIn bndr (vectExpr body) - return $ vCaseDEFAULT vscrut vbndr vty lty vbody - + { traceVt "scrutinee (DEFAULT only)" empty + ; vscrut <- vectExpr scrut + ; (vty, lty) <- vectAndLiftType ty + ; traceVt "alternative body (DEFAULT only)" empty + ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body) + ; return $ vCaseDEFAULT vscrut vbndr vty lty vbody + } vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)] = do - vscrut <- vectExpr scrut - (vty, lty) <- vectAndLiftType ty - (vbndr, vbody) <- vectBndrIn bndr (vectExpr body) - return $ vCaseDEFAULT vscrut vbndr vty lty vbody - + { traceVt "scrutinee (one shot w/o binders)" empty + ; vscrut <- vectExpr scrut + ; (vty, lty) <- vectAndLiftType ty + ; traceVt "alternative body (one shot w/o binders)" empty + ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body) + ; return $ vCaseDEFAULT vscrut vbndr vty lty vbody + } vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)] = do - (vty, lty) <- vectAndLiftType ty - vexpr <- vectExpr scrut - (vbndr, (vbndrs, (vect_body, lift_body))) - <- vect_scrut_bndr - . vectBndrsIn bndrs - $ vectExpr body - let (vect_bndrs, lift_bndrs) = unzip vbndrs - (vscrut, lscrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr) - vect_dc <- maybeV dataConErr (lookupDataCon dc) - - let vcase = mk_wild_case vscrut vty vect_dc vect_bndrs vect_body + { traceVt "scrutinee (one shot w/ binders)" empty + ; vexpr <- vectExpr scrut + ; (vty, lty) <- vectAndLiftType ty + ; traceVt "alternative body (one shot w/ binders)" empty + ; (vbndr, (vbndrs, (vect_body, lift_body))) + <- vect_scrut_bndr + . vectBndrsIn bndrs + $ vectExpr body + ; let (vect_bndrs, lift_bndrs) = unzip vbndrs + ; (vscrut, lscrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr) + ; vect_dc <- maybeV dataConErr (lookupDataCon dc) + + ; let vcase = mk_wild_case vscrut vty vect_dc vect_bndrs vect_body lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body - return $ vLet (vNonRec vbndr vexpr) (vcase, lcase) + ; return $ vLet (vNonRec vbndr vexpr) (vcase, lcase) + } where vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut") | otherwise = vectBndrIn bndr @@ -814,36 +835,40 @@ vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)] vectAlgCase tycon _ty_args scrut bndr ty alts = do - vect_tc <- vectTyCon tycon - (vty, lty) <- vectAndLiftType ty + { traceVt "scrutinee (general case)" empty + ; vexpr <- vectExpr scrut + + ; vect_tc <- vectTyCon tycon + ; (vty, lty) <- vectAndLiftType ty - let arity = length (tyConDataCons vect_tc) - sel_ty <- builtin (selTy arity) - sel_bndr <- newLocalVar (fsLit "sel") sel_ty - let sel = Var sel_bndr + ; let arity = length (tyConDataCons vect_tc) + ; sel_ty <- builtin (selTy arity) + ; sel_bndr <- newLocalVar (fsLit "sel") sel_ty + ; let sel = Var sel_bndr - (vbndr, valts) <- vect_scrut_bndr + ; traceVt "alternatives' body (general case)" empty + ; (vbndr, valts) <- vect_scrut_bndr $ mapM (proc_alt arity sel vty lty) alts' - let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts + ; let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts - vexpr <- vectExpr scrut - (vect_scrut, lift_scrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr) + ; (vect_scrut, lift_scrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr) - let (vect_bodies, lift_bodies) = unzip vbodies + ; let (vect_bodies, lift_bodies) = unzip vbodies - vdummy <- newDummyVar (exprType vect_scrut) - ldummy <- newDummyVar (exprType lift_scrut) - let vect_case = Case vect_scrut vdummy vty + ; vdummy <- newDummyVar (exprType vect_scrut) + ; ldummy <- newDummyVar (exprType lift_scrut) + ; let vect_case = Case vect_scrut vdummy vty (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies) - lc <- builtin liftingContext - lbody <- combinePD vty (Var lc) sel lift_bodies - let lift_case = Case lift_scrut ldummy lty + ; lc <- builtin liftingContext + ; lbody <- combinePD vty (Var lc) sel lift_bodies + ; let lift_case = Case lift_scrut ldummy lty [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss, lbody)] - return . vLet (vNonRec vbndr vexpr) + ; return . vLet (vNonRec vbndr vexpr) $ (vect_case, lift_case) + } where vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut") | otherwise = vectBndrIn bndr @@ -871,12 +896,14 @@ vectAlgCase tycon _ty_args scrut bndr ty alts <- vectBndrsIn bndrs . localV $ do - binds <- mapM (pack_var (Var lc) sel_tags tag) + { binds <- mapM (pack_var (Var lc) sel_tags tag) . filter isLocalId $ varSetElems fvs - (ve, le) <- vectExpr body - return (ve, Case (elems `App` sel) lc lty + ; traceVt "case alternative:" (ppr . deAnnotate $ body) + ; (ve, le) <- vectExpr body + ; return (ve, Case (elems `App` sel) lc lty [(DEFAULT, [], (mkLets (concat binds) le))]) + } -- empty <- emptyPD vty -- return (ve, Case (elems `App` sel) lc lty -- [(DEFAULT, [], Let (NonRec flags_var flags_expr) @@ -887,25 +914,26 @@ vectAlgCase tycon _ty_args scrut bndr ty alts where dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc) - proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt" mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body) + -- Pack a variable for a case alternative context *if* the variable is vectorised. If it + -- isn't, ignore it as scalar variables don't need to be packed. pack_var len tags t v = do - r <- lookupVar v - case r of - Local (vv, lv) -> + { r <- lookupVar_maybe v + ; case r of + Just (Local (vv, lv)) -> do - lv' <- cloneVar lv - expr <- packByTagPD (idType vv) (Var lv) len tags t - updLEnv (\env -> env { local_vars = extendVarEnv - (local_vars env) v (vv, lv') }) - return [(NonRec lv' expr)] - + { lv' <- cloneVar lv + ; expr <- packByTagPD (idType vv) (Var lv) len tags t + ; updLEnv (\env -> env { local_vars = extendVarEnv (local_vars env) v (vv, lv') }) + ; return [(NonRec lv' expr)] + } _ -> return [] - + } + -- Support to compute information for vectorisation avoidance ------------------ @@ -972,7 +1000,10 @@ vectAvoidInfo pvs ce@(fvs, AnnVar v) ; vi <- if v `elemVarSet` pvs || v `elemVarSet` gpvs then return VIParr else vectAvoidInfoTypeOf ce - ; viTrace ce vi [] + ; viTrace ce vi [] + ; when (vi == VIParr) $ + traceVt " reason:" $ if v `elemVarSet` pvs then text "local" else + if v `elemVarSet` gpvs then text "global" else text "parallel type" ; return ((fvs, vi), AnnVar v) } @@ -990,16 +1021,16 @@ vectAvoidInfo pvs ce@(fvs, AnnApp e1 e2) ; eVI1 <- vectAvoidInfo pvs e1 ; eVI2 <- vectAvoidInfo pvs e2 ; let vi = ceVI `unlessVIParrExpr` eVI1 `unlessVIParrExpr` eVI2 - ; viTrace ce vi [eVI1, eVI2] + -- ; viTrace ce vi [eVI1, eVI2] ; return ((fvs, vi), AnnApp eVI1 eVI2) } -vectAvoidInfo pvs ce@(fvs, AnnLam var body) +vectAvoidInfo pvs (fvs, AnnLam var body) = do { bodyVI <- vectAvoidInfo pvs body ; varVI <- vectAvoidInfoType $ varType var ; let vi = vectAvoidInfoOf bodyVI `unlessVIParr` varVI - ; viTrace ce vi [bodyVI] + -- ; viTrace ce vi [bodyVI] ; return ((fvs, vi), AnnLam var bodyVI) } @@ -1010,14 +1041,14 @@ vectAvoidInfo pvs ce@(fvs, AnnLet (AnnNonRec var e) body) ; isScalarTy <- isScalar $ varType var ; (bodyVI, vi) <- if isVIParr eVI && not isScalarTy then do -- binding is parallel - { bodyVI <- vectAvoidInfo (fvs `extendVarSet` var) body + { bodyVI <- vectAvoidInfo (pvs `extendVarSet` var) body ; return (bodyVI, VIParr) } else do -- binding doesn't affect parallelism - { bodyVI <- vectAvoidInfo fvs body + { bodyVI <- vectAvoidInfo pvs body ; return (bodyVI, ceVI `unlessVIParrExpr` bodyVI) } - ; viTrace ce vi [eVI, bodyVI] + -- ; viTrace ce vi [eVI, bodyVI] ; return ((fvs, vi), AnnLet (AnnNonRec var eVI) bodyVI) } @@ -1032,13 +1063,13 @@ vectAvoidInfo pvs ce@(fvs, AnnLet (AnnRec bnds) body) ; let extendedPvs = pvs `extendVarSetList` new_pvs ; bndsVI <- mapM (vectAvoidInfoBnd extendedPvs) bnds ; bodyVI <- vectAvoidInfo extendedPvs body - ; viTrace ce VIParr (map snd bndsVI ++ [bodyVI]) + -- ; viTrace ce VIParr (map snd bndsVI ++ [bodyVI]) ; return ((fvs, VIParr), AnnLet (AnnRec bndsVI) bodyVI) } else do -- demanded bindings cannot trigger parallelism { bodyVI <- vectAvoidInfo pvs body ; let vi = ceVI `unlessVIParrExpr` bodyVI - ; viTrace ce vi (map snd bndsVI ++ [bodyVI]) + -- ; viTrace ce vi (map snd bndsVI ++ [bodyVI]) ; return ((fvs, vi), AnnLet (AnnRec bndsVI) bodyVI) } } @@ -1058,7 +1089,7 @@ vectAvoidInfo pvs ce@(fvs, AnnCase e var ty alts) ; altsVI <- mapM (vectAvoidInfoAlt (isVIParr eVI)) alts ; let alteVIs = [eVI | (_, _, eVI) <- altsVI] vi = foldl unlessVIParrExpr ceVI (eVI:alteVIs) -- NB: same effect as in the paper - ; viTrace ce vi (eVI : alteVIs) + -- ; viTrace ce vi (eVI : alteVIs) ; return ((fvs, vi), AnnCase eVI var ty altsVI) } where |