diff options
Diffstat (limited to 'compiler/vectorise')
-rw-r--r-- | compiler/vectorise/VectBuiltIn.hs | 90 | ||||
-rw-r--r-- | compiler/vectorise/VectMonad.hs | 2 | ||||
-rw-r--r-- | compiler/vectorise/VectType.hs | 4 | ||||
-rw-r--r-- | compiler/vectorise/VectUtils.hs | 45 | ||||
-rw-r--r-- | compiler/vectorise/Vectorise.hs | 10 |
5 files changed, 61 insertions, 90 deletions
diff --git a/compiler/vectorise/VectBuiltIn.hs b/compiler/vectorise/VectBuiltIn.hs index 4f27b1e5c5..05b1289ee8 100644 --- a/compiler/vectorise/VectBuiltIn.hs +++ b/compiler/vectorise/VectBuiltIn.hs @@ -23,7 +23,8 @@ import TypeRep ( funTyCon ) import Type ( Type ) import TysPrim import TysWiredIn ( unitTyCon, tupleTyCon, intTyConName ) -import PrelNames +import Module ( Module, mkModule, mkModuleNameFS ) +import PackageConfig ( ndpPackageId ) import BasicTypes ( Boxity(..) ) import FastString @@ -38,6 +39,15 @@ mAX_NDP_PROD = 3 mAX_NDP_SUM :: Int mAX_NDP_SUM = 3 +mkNDPModule :: FastString -> Module +mkNDPModule m = mkModule ndpPackageId (mkModuleNameFS m) + +nDP_PARRAY = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.PArray") +nDP_REPR = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.Repr") +nDP_CLOSURE = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.Closure") +nDP_PRIM = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.Prim") +nDP_INSTANCES = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.Instances") + data Builtins = Builtins { parrayTyCon :: TyCon , paTyCon :: TyCon @@ -80,33 +90,33 @@ prodTyCon n bi initBuiltins :: DsM Builtins initBuiltins = do - parrayTyCon <- dsLookupTyCon parrayTyConName - paTyCon <- dsLookupTyCon paTyConName + parrayTyCon <- externalTyCon nDP_PARRAY FSLIT("PArray") + paTyCon <- externalTyCon nDP_PARRAY FSLIT("PA") let [paDataCon] = tyConDataCons paTyCon - preprTyCon <- dsLookupTyCon preprTyConName - prTyCon <- dsLookupTyCon prTyConName + preprTyCon <- externalTyCon nDP_PARRAY FSLIT("PRepr") + prTyCon <- externalTyCon nDP_PARRAY FSLIT("PR") let [prDataCon] = tyConDataCons prTyCon - parrayIntPrimTyCon <- dsLookupTyCon parrayIntPrimTyConName - closureTyCon <- dsLookupTyCon closureTyConName + parrayIntPrimTyCon <- externalTyCon nDP_PRIM FSLIT("PArray_Int#") + closureTyCon <- externalTyCon nDP_CLOSURE FSLIT(":->") - voidTyCon <- lookupExternalTyCon nDP_REPR FSLIT("Void") - wrapTyCon <- lookupExternalTyCon nDP_REPR FSLIT("Wrap") - sum_tcs <- mapM (lookupExternalTyCon nDP_REPR) + voidTyCon <- externalTyCon nDP_REPR FSLIT("Void") + wrapTyCon <- externalTyCon nDP_REPR FSLIT("Wrap") + sum_tcs <- mapM (externalTyCon nDP_REPR) [mkFastString ("Sum" ++ show i) | i <- [2..mAX_NDP_SUM]] let sumTyCons = listArray (2, mAX_NDP_SUM) sum_tcs - voidVar <- lookupExternalVar nDP_REPR FSLIT("void") - mkPRVar <- dsLookupGlobalId mkPRName - mkClosureVar <- dsLookupGlobalId mkClosureName - applyClosureVar <- dsLookupGlobalId applyClosureName - mkClosurePVar <- dsLookupGlobalId mkClosurePName - applyClosurePVar <- dsLookupGlobalId applyClosurePName - replicatePAIntPrimVar <- dsLookupGlobalId replicatePAIntPrimName - upToPAIntPrimVar <- dsLookupGlobalId upToPAIntPrimName - lengthPAVar <- dsLookupGlobalId lengthPAName - replicatePAVar <- dsLookupGlobalId replicatePAName - emptyPAVar <- dsLookupGlobalId emptyPAName + voidVar <- externalVar nDP_REPR FSLIT("void") + mkPRVar <- externalVar nDP_PARRAY FSLIT("mkPR") + mkClosureVar <- externalVar nDP_CLOSURE FSLIT("mkClosure") + applyClosureVar <- externalVar nDP_CLOSURE FSLIT("$:") + mkClosurePVar <- externalVar nDP_CLOSURE FSLIT("mkClosureP") + applyClosurePVar <- externalVar nDP_CLOSURE FSLIT("$:^") + replicatePAIntPrimVar <- externalVar nDP_PRIM FSLIT("replicatePA_Int#") + upToPAIntPrimVar <- externalVar nDP_PRIM FSLIT("upToPA_Int#") + lengthPAVar <- externalVar nDP_PARRAY FSLIT("lengthPA") + replicatePAVar <- externalVar nDP_PARRAY FSLIT("replicatePA") + emptyPAVar <- externalVar nDP_PARRAY FSLIT("emptyPA") -- packPAVar <- dsLookupGlobalId packPAName -- combinePAVar <- dsLookupGlobalId combinePAName @@ -141,21 +151,13 @@ initBuiltins , liftingContext = liftingContext } -initBuiltinTyCons :: DsM [(Name, TyCon)] -initBuiltinTyCons - = do - vects <- sequence vs - return (zip origs vects) - where - (origs, vs) = unzip builtinTyCons - -builtinTyCons :: [(Name, DsM TyCon)] -builtinTyCons = [(tyConName funTyCon, dsLookupTyCon closureTyConName)] +initBuiltinTyCons :: Builtins -> [(Name, TyCon)] +initBuiltinTyCons bi = [(tyConName funTyCon, closureTyCon bi)] initBuiltinDicts :: [(Name, Module, FastString)] -> DsM [(Name, Var)] initBuiltinDicts ps = do - dicts <- zipWithM lookupExternalVar mods fss + dicts <- zipWithM externalVar mods fss return $ zip tcs dicts where (tcs, mods, fss) = unzip3 ps @@ -165,11 +167,11 @@ initBuiltinPAs = initBuiltinDicts . builtinPAs builtinPAs :: Builtins -> [(Name, Module, FastString)] builtinPAs bi = [ - mk closureTyConName nDP_CLOSURE FSLIT("dPA_Clo") - , mk (tyConName $ voidTyCon bi) nDP_REPR FSLIT("dPA_Void") - , mk unitTyConName nDP_INSTANCES FSLIT("dPA_Unit") + mk (tyConName $ closureTyCon bi) nDP_CLOSURE FSLIT("dPA_Clo") + , mk (tyConName $ voidTyCon bi) nDP_REPR FSLIT("dPA_Void") + , mk unitTyConName nDP_INSTANCES FSLIT("dPA_Unit") - , mk intTyConName nDP_INSTANCES FSLIT("dPA_Int") + , mk intTyConName nDP_INSTANCES FSLIT("dPA_Int") ] ++ tups where @@ -185,10 +187,10 @@ initBuiltinPRs = initBuiltinDicts . builtinPRs builtinPRs :: Builtins -> [(Name, Module, FastString)] builtinPRs bi = [ - mk (tyConName unitTyCon) nDP_REPR FSLIT("dPR_Unit") - , mk (tyConName $ voidTyCon bi) nDP_REPR FSLIT("dPR_Void") - , mk (tyConName $ wrapTyCon bi) nDP_REPR FSLIT("dPR_Wrap") - , mk closureTyConName nDP_CLOSURE FSLIT("dPR_Clo") + mk (tyConName unitTyCon) nDP_REPR FSLIT("dPR_Unit") + , mk (tyConName $ voidTyCon bi) nDP_REPR FSLIT("dPR_Void") + , mk (tyConName $ wrapTyCon bi) nDP_REPR FSLIT("dPR_Wrap") + , mk (tyConName $ closureTyCon bi) nDP_CLOSURE FSLIT("dPR_Clo") -- temporary , mk intTyConName nDP_INSTANCES FSLIT("dPR_Int") @@ -205,12 +207,12 @@ builtinPRs bi = mk_prod n = (tyConName $ prodTyCon n bi, nDP_REPR, mkFastString ("dPR_" ++ show n)) -lookupExternalVar :: Module -> FastString -> DsM Var -lookupExternalVar mod fs +externalVar :: Module -> FastString -> DsM Var +externalVar mod fs = dsLookupGlobalId =<< lookupOrig mod (mkVarOccFS fs) -lookupExternalTyCon :: Module -> FastString -> DsM TyCon -lookupExternalTyCon mod fs +externalTyCon :: Module -> FastString -> DsM TyCon +externalTyCon mod fs = dsLookupTyCon =<< lookupOrig mod (mkOccNameFS tcName fs) unitTyConName = tyConName unitTyCon diff --git a/compiler/vectorise/VectMonad.hs b/compiler/vectorise/VectMonad.hs index cf71a00e55..56aeb141c1 100644 --- a/compiler/vectorise/VectMonad.hs +++ b/compiler/vectorise/VectMonad.hs @@ -462,7 +462,7 @@ initV hsc_env guts info p go = do builtins <- initBuiltins - builtin_tycons <- initBuiltinTyCons + let builtin_tycons = initBuiltinTyCons builtins builtin_pas <- initBuiltinPAs builtins builtin_prs <- initBuiltinPRs builtins diff --git a/compiler/vectorise/VectType.hs b/compiler/vectorise/VectType.hs index 77cb4295ed..aa8e4f8d29 100644 --- a/compiler/vectorise/VectType.hs +++ b/compiler/vectorise/VectType.hs @@ -585,11 +585,11 @@ buildToArrPRepr repr vect_tc prepr_tc arr_tc $ mkConApp data_con [Var len_var, Var repr_var] to_prod repr_vars@(r : _) - (ProdRepr { prod_components = tys + (ProdRepr { prod_components = tys@(ty : _) , prod_arr_tycon = tycon , prod_arr_data_con = data_con }) = do - len <- lengthPA (Var r) + len <- lengthPA ty (Var r) return . wrapFamInstBody tycon tys . mkConApp data_con $ map Type tys ++ len : map Var repr_vars diff --git a/compiler/vectorise/VectUtils.hs b/compiler/vectorise/VectUtils.hs index 1fb268f8a5..42bcab37bb 100644 --- a/compiler/vectorise/VectUtils.hs +++ b/compiler/vectorise/VectUtils.hs @@ -2,7 +2,6 @@ module VectUtils ( collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg, collectAnnValBinders, mkDataConTag, mkDataConTagLit, - splitClosureTy, mkBuiltinCo, mkPADictType, mkPArrayType, mkPReprType, @@ -75,36 +74,6 @@ mkDataConTagLit con mkDataConTag :: DataCon -> CoreExpr mkDataConTag con = mkIntLitInt (dataConTag con - fIRST_TAG) -splitUnTy :: String -> Name -> Type -> Type -splitUnTy s name ty - | Just (tc, [ty']) <- splitTyConApp_maybe ty - , tyConName tc == name - = ty' - - | otherwise = pprPanic s (ppr ty) - -splitBinTy :: String -> Name -> Type -> (Type, Type) -splitBinTy s name ty - | Just (tc, [ty1, ty2]) <- splitTyConApp_maybe ty - , tyConName tc == name - = (ty1, ty2) - - | otherwise = pprPanic s (ppr ty) - -splitFixedTyConApp :: TyCon -> Type -> [Type] -splitFixedTyConApp tc ty - | Just (tc', tys) <- splitTyConApp_maybe ty - , tc == tc' - = tys - - | otherwise = pprPanic "splitFixedTyConApp" (ppr tc <+> ppr ty) - -splitClosureTy :: Type -> (Type, Type) -splitClosureTy = splitBinTy "splitClosureTy" closureTyConName - -splitPArrayTy :: Type -> Type -splitPArrayTy = splitUnTy "splitPArrayTy" parrayTyConName - splitPrimTyCon :: Type -> Maybe TyCon splitPrimTyCon ty | Just (tycon, []) <- splitTyConApp_maybe ty @@ -267,10 +236,8 @@ mkPR ty dict <- paDictOfType ty return $ mkApps (Var fn) [Type ty, dict] -lengthPA :: CoreExpr -> VM CoreExpr -lengthPA x = liftM (`App` x) (paMethod pa_length ty) - where - ty = splitPArrayTy (exprType x) +lengthPA :: Type -> CoreExpr -> VM CoreExpr +lengthPA ty x = liftM (`App` x) (paMethod pa_length ty) replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr replicatePA len x = liftM (`mkApps` [len,x]) @@ -364,15 +331,13 @@ mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv) return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv], Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv]) -mkClosureApp :: VExpr -> VExpr -> VM VExpr -mkClosureApp (vclo, lclo) (varg, larg) +mkClosureApp :: Type -> Type -> VExpr -> VExpr -> VM VExpr +mkClosureApp arg_ty res_ty (vclo, lclo) (varg, larg) = do vapply <- builtin applyClosureVar lapply <- builtin applyClosurePVar return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg], Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg]) - where - (arg_ty, res_ty) = splitClosureTy (exprType vclo) buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr buildClosures tvs vars [] res_ty mk_body @@ -441,7 +406,7 @@ mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM mkLiftEnv lc [ty] [v] = return (Var v, \env body -> do - len <- lengthPA (Var v) + len <- lengthPA ty (Var v) return . Let (NonRec v env) $ Case len lc (exprType body) [(DEFAULT, [], body)]) diff --git a/compiler/vectorise/Vectorise.hs b/compiler/vectorise/Vectorise.hs index 85f4e4612a..ada4956b8d 100644 --- a/compiler/vectorise/Vectorise.hs +++ b/compiler/vectorise/Vectorise.hs @@ -211,9 +211,13 @@ vectExpr e@(_, AnnApp _ arg) vectExpr (_, AnnApp fn arg) = do - fn' <- vectExpr fn - arg' <- vectExpr arg - mkClosureApp fn' arg' + arg_ty' <- vectType arg_ty + res_ty' <- vectType res_ty + fn' <- vectExpr fn + arg' <- vectExpr arg + mkClosureApp arg_ty' res_ty' fn' arg' + where + (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn vectExpr (_, AnnCase scrut bndr ty alts) | isAlgType scrut_ty |