diff options
23 files changed, 623 insertions, 21 deletions
diff --git a/compiler/cmm/CmmMachOp.hs b/compiler/cmm/CmmMachOp.hs index fdbfd6e857..8ac4a6fa7b 100644 --- a/compiler/cmm/CmmMachOp.hs +++ b/compiler/cmm/CmmMachOp.hs @@ -587,6 +587,8 @@ data CallishMachOp | MO_Memcmp Int | MO_PopCnt Width + | MO_Pdep Width + | MO_Pext Width | MO_Clz Width | MO_Ctz Width diff --git a/compiler/cmm/CmmParse.y b/compiler/cmm/CmmParse.y index 7ffb4fbe42..8afbd2f9d9 100644 --- a/compiler/cmm/CmmParse.y +++ b/compiler/cmm/CmmParse.y @@ -1006,6 +1006,16 @@ callishMachOps = listToUFM $ ( "popcnt32", (,) $ MO_PopCnt W32 ), ( "popcnt64", (,) $ MO_PopCnt W64 ), + ( "pdep8", (,) $ MO_Pdep W8 ), + ( "pdep16", (,) $ MO_Pdep W16 ), + ( "pdep32", (,) $ MO_Pdep W32 ), + ( "pdep64", (,) $ MO_Pdep W64 ), + + ( "pext8", (,) $ MO_Pext W8 ), + ( "pext16", (,) $ MO_Pext W16 ), + ( "pext32", (,) $ MO_Pext W32 ), + ( "pext64", (,) $ MO_Pext W64 ), + ( "cmpxchg8", (,) $ MO_Cmpxchg W8 ), ( "cmpxchg16", (,) $ MO_Cmpxchg W16 ), ( "cmpxchg32", (,) $ MO_Cmpxchg W32 ), diff --git a/compiler/cmm/PprC.hs b/compiler/cmm/PprC.hs index 1ddd1cd266..76e4d4cb94 100644 --- a/compiler/cmm/PprC.hs +++ b/compiler/cmm/PprC.hs @@ -789,6 +789,8 @@ pprCallishMachOp_for_C mop MO_Memcmp _ -> text "memcmp" (MO_BSwap w) -> ptext (sLit $ bSwapLabel w) (MO_PopCnt w) -> ptext (sLit $ popCntLabel w) + (MO_Pext w) -> ptext (sLit $ pextLabel w) + (MO_Pdep w) -> ptext (sLit $ pdepLabel w) (MO_Clz w) -> ptext (sLit $ clzLabel w) (MO_Ctz w) -> ptext (sLit $ ctzLabel w) (MO_AtomicRMW w amop) -> ptext (sLit $ atomicRMWLabel w amop) diff --git a/compiler/codeGen/StgCmmPrim.hs b/compiler/codeGen/StgCmmPrim.hs index 0a6ac9dba3..948af2aba0 100644 --- a/compiler/codeGen/StgCmmPrim.hs +++ b/compiler/codeGen/StgCmmPrim.hs @@ -584,6 +584,20 @@ emitPrimOp _ [res] PopCnt32Op [w] = emitPopCntCall res w W32 emitPrimOp _ [res] PopCnt64Op [w] = emitPopCntCall res w W64 emitPrimOp dflags [res] PopCntOp [w] = emitPopCntCall res w (wordWidth dflags) +-- Parallel bit deposit +emitPrimOp _ [res] Pdep8Op [src, mask] = emitPdepCall res src mask W8 +emitPrimOp _ [res] Pdep16Op [src, mask] = emitPdepCall res src mask W16 +emitPrimOp _ [res] Pdep32Op [src, mask] = emitPdepCall res src mask W32 +emitPrimOp _ [res] Pdep64Op [src, mask] = emitPdepCall res src mask W64 +emitPrimOp dflags [res] PdepOp [src, mask] = emitPdepCall res src mask (wordWidth dflags) + +-- Parallel bit extract +emitPrimOp _ [res] Pext8Op [src, mask] = emitPextCall res src mask W8 +emitPrimOp _ [res] Pext16Op [src, mask] = emitPextCall res src mask W16 +emitPrimOp _ [res] Pext32Op [src, mask] = emitPextCall res src mask W32 +emitPrimOp _ [res] Pext64Op [src, mask] = emitPextCall res src mask W64 +emitPrimOp dflags [res] PextOp [src, mask] = emitPextCall res src mask (wordWidth dflags) + -- count leading zeros emitPrimOp _ [res] Clz8Op [w] = emitClzCall res w W8 emitPrimOp _ [res] Clz16Op [w] = emitClzCall res w W16 @@ -2266,6 +2280,20 @@ emitPopCntCall res x width = do (MO_PopCnt width) [ x ] +emitPdepCall :: LocalReg -> CmmExpr -> CmmExpr -> Width -> FCode () +emitPdepCall res x y width = do + emitPrimCall + [ res ] + (MO_Pdep width) + [ x, y ] + +emitPextCall :: LocalReg -> CmmExpr -> CmmExpr -> Width -> FCode () +emitPextCall res x y width = do + emitPrimCall + [ res ] + (MO_Pext width) + [ x, y ] + emitClzCall :: LocalReg -> CmmExpr -> Width -> FCode () emitClzCall res x width = do emitPrimCall diff --git a/compiler/coreSyn/MkCore.hs b/compiler/coreSyn/MkCore.hs index 49a9e9dbbd..f2b940bfd1 100644 --- a/compiler/coreSyn/MkCore.hs +++ b/compiler/coreSyn/MkCore.hs @@ -867,4 +867,3 @@ mkAbsentErrorApp res_ty err_msg = mkApps (Var aBSENT_ERROR_ID) [ Type res_ty, err_string ] where err_string = Lit (mkMachString err_msg) - diff --git a/compiler/llvmGen/LlvmCodeGen/CodeGen.hs b/compiler/llvmGen/LlvmCodeGen/CodeGen.hs index a88642b531..e812dd445f 100644 --- a/compiler/llvmGen/LlvmCodeGen/CodeGen.hs +++ b/compiler/llvmGen/LlvmCodeGen/CodeGen.hs @@ -46,6 +46,8 @@ import Data.Maybe ( catMaybes ) type Atomic = Bool type LlvmStatements = OrdList LlvmStatement +data Signage = Signed | Unsigned deriving (Eq, Show) + -- ----------------------------------------------------------------------------- -- | Top-level of the LLVM proc Code generator -- @@ -207,7 +209,7 @@ genCall t@(PrimTarget (MO_Prefetch_Data localityInt)) [] args let args_hints' = zip args arg_hints argVars <- arg_varsW args_hints' ([], nilOL, []) fptr <- liftExprData $ getFunPtr funTy t - argVars' <- castVarsW $ zip argVars argTy + argVars' <- castVarsW Signed $ zip argVars argTy doTrashStmts let argSuffix = [mkIntLit i32 0, mkIntLit i32 localityInt, mkIntLit i32 1] @@ -218,6 +220,11 @@ genCall t@(PrimTarget (MO_Prefetch_Data localityInt)) [] args -- and return types genCall t@(PrimTarget (MO_PopCnt w)) dsts args = genCallSimpleCast w t dsts args + +genCall t@(PrimTarget (MO_Pdep w)) dsts args = + genCallSimpleCast2 w t dsts args +genCall t@(PrimTarget (MO_Pext w)) dsts args = + genCallSimpleCast2 w t dsts args genCall t@(PrimTarget (MO_Clz w)) dsts args = genCallSimpleCast w t dsts args genCall t@(PrimTarget (MO_Ctz w)) dsts args = @@ -285,7 +292,7 @@ genCall t@(PrimTarget op) [] args let args_hints = zip args arg_hints argVars <- arg_varsW args_hints ([], nilOL, []) fptr <- getFunPtrW funTy t - argVars' <- castVarsW $ zip argVars argTy + argVars' <- castVarsW Signed $ zip argVars argTy doTrashStmts let alignVal = mkIntLit i32 align @@ -518,7 +525,7 @@ genCallExtract target@(PrimTarget op) w (argA, argB) (llvmTypeA, llvmTypeB) = do -- Process the arguments. let args_hints = zip [argA, argB] (snd $ foreignTargetHints target) (argsV1, args1, top1) <- arg_vars args_hints ([], nilOL, []) - (argsV2, args2) <- castVars $ zip argsV1 argTy + (argsV2, args2) <- castVars Signed $ zip argsV1 argTy -- Get the function and make the call. fname <- cmmPrimOpFunctions op @@ -558,9 +565,9 @@ genCallSimpleCast w t@(PrimTarget op) [dst] args = do let (_, arg_hints) = foreignTargetHints t let args_hints = zip args arg_hints (argsV, stmts2, top2) <- arg_vars args_hints ([], nilOL, []) - (argsV', stmts4) <- castVars $ zip argsV [width] + (argsV', stmts4) <- castVars Signed $ zip argsV [width] (retV, s1) <- doExpr width $ Call StdCall fptr argsV' [] - ([retV'], stmts5) <- castVars [(retV,dstTy)] + ([retV'], stmts5) <- castVars (cmmPrimOpRetValSignage op) [(retV,dstTy)] let s2 = Store retV' dstV let stmts = stmts2 `appOL` stmts4 `snocOL` @@ -569,6 +576,37 @@ genCallSimpleCast w t@(PrimTarget op) [dst] args = do genCallSimpleCast _ _ dsts _ = panic ("genCallSimpleCast: " ++ show (length dsts) ++ " dsts") +-- Handle simple function call that only need simple type casting, of the form: +-- truncate arg >>= \a -> call(a) >>= zext +-- +-- since GHC only really has i32 and i64 types and things like Word8 are backed +-- by an i32 and just present a logical i8 range. So we must handle conversions +-- from i32 to i8 explicitly as LLVM is strict about types. +genCallSimpleCast2 :: Width -> ForeignTarget -> [CmmFormal] -> [CmmActual] + -> LlvmM StmtData +genCallSimpleCast2 w t@(PrimTarget op) [dst] args = do + let width = widthToLlvmInt w + dstTy = cmmToLlvmType $ localRegType dst + + fname <- cmmPrimOpFunctions op + (fptr, _, top3) <- getInstrinct fname width (const width <$> args) + + dstV <- getCmmReg (CmmLocal dst) + + let (_, arg_hints) = foreignTargetHints t + let args_hints = zip args arg_hints + (argsV, stmts2, top2) <- arg_vars args_hints ([], nilOL, []) + (argsV', stmts4) <- castVars Signed $ zip argsV (const width <$> argsV) + (retV, s1) <- doExpr width $ Call StdCall fptr argsV' [] + ([retV'], stmts5) <- castVars (cmmPrimOpRetValSignage op) [(retV,dstTy)] + let s2 = Store retV' dstV + + let stmts = stmts2 `appOL` stmts4 `snocOL` + s1 `appOL` stmts5 `snocOL` s2 + return (stmts, top2 ++ top3) +genCallSimpleCast2 _ _ dsts _ = + panic ("genCallSimpleCast2: " ++ show (length dsts) ++ " dsts") + -- | Create a function pointer from a target. getFunPtrW :: (LMString -> LlvmType) -> ForeignTarget -> WriterT LlvmAccum LlvmM LlvmVar @@ -638,31 +676,32 @@ arg_vars ((e, _):rest) (vars, stmts, tops) -- | Cast a collection of LLVM variables to specific types. -castVarsW :: [(LlvmVar, LlvmType)] +castVarsW :: Signage + -> [(LlvmVar, LlvmType)] -> WriterT LlvmAccum LlvmM [LlvmVar] -castVarsW vars = do - (vars, stmts) <- lift $ castVars vars +castVarsW signage vars = do + (vars, stmts) <- lift $ castVars signage vars tell $ LlvmAccum stmts mempty return vars -- | Cast a collection of LLVM variables to specific types. -castVars :: [(LlvmVar, LlvmType)] +castVars :: Signage -> [(LlvmVar, LlvmType)] -> LlvmM ([LlvmVar], LlvmStatements) -castVars vars = do - done <- mapM (uncurry castVar) vars +castVars signage vars = do + done <- mapM (uncurry (castVar signage)) vars let (vars', stmts) = unzip done return (vars', toOL stmts) -- | Cast an LLVM variable to a specific type, panicing if it can't be done. -castVar :: LlvmVar -> LlvmType -> LlvmM (LlvmVar, LlvmStatement) -castVar v t | getVarType v == t +castVar :: Signage -> LlvmVar -> LlvmType -> LlvmM (LlvmVar, LlvmStatement) +castVar signage v t | getVarType v == t = return (v, Nop) | otherwise = do dflags <- getDynFlags let op = case (getVarType v, t) of (LMInt n, LMInt m) - -> if n < m then LM_Sext else LM_Trunc + -> if n < m then extend else LM_Trunc (vt, _) | isFloat vt && isFloat t -> if llvmWidthInBits dflags vt < llvmWidthInBits dflags t then LM_Fpext else LM_Fptrunc @@ -676,7 +715,16 @@ castVar v t | getVarType v == t (vt, _) -> panic $ "castVars: Can't cast this type (" ++ showSDoc dflags (ppr vt) ++ ") to (" ++ showSDoc dflags (ppr t) ++ ")" doExpr t $ Cast op v t + where extend = case signage of + Signed -> LM_Sext + Unsigned -> LM_Zext + +cmmPrimOpRetValSignage :: CallishMachOp -> Signage +cmmPrimOpRetValSignage mop = case mop of + MO_Pdep _ -> Unsigned + MO_Pext _ -> Unsigned + _ -> Signed -- | Decide what C function to use to implement a CallishMachOp cmmPrimOpFunctions :: CallishMachOp -> LlvmM LMString @@ -735,6 +783,15 @@ cmmPrimOpFunctions mop = do (MO_Clz w) -> fsLit $ "llvm.ctlz." ++ showSDoc dflags (ppr $ widthToLlvmInt w) (MO_Ctz w) -> fsLit $ "llvm.cttz." ++ showSDoc dflags (ppr $ widthToLlvmInt w) + (MO_Pdep w) -> let w' = showSDoc dflags (ppr $ widthInBits w) + in if isBmi2Enabled dflags + then fsLit $ "llvm.x86.bmi.pdep." ++ w' + else fsLit $ "hs_pdep" ++ w' + (MO_Pext w) -> let w' = showSDoc dflags (ppr $ widthInBits w) + in if isBmi2Enabled dflags + then fsLit $ "llvm.x86.bmi.pext." ++ w' + else fsLit $ "hs_pext" ++ w' + (MO_Prefetch_Data _ )-> fsLit "llvm.prefetch" MO_AddIntC w -> fsLit $ "llvm.sadd.with.overflow." @@ -1212,7 +1269,7 @@ genMachOp _ op [x] = case op of negateVec ty v2 negOp = do (vx, stmts1, top) <- exprToVar x - ([vx'], stmts2) <- castVars [(vx, ty)] + ([vx'], stmts2) <- castVars Signed [(vx, ty)] (v1, s1) <- doExpr ty $ LlvmOp negOp v2 vx' return (v1, stmts1 `appOL` stmts2 `snocOL` s1, top) @@ -1275,7 +1332,7 @@ genMachOp_slow :: EOption -> MachOp -> [CmmExpr] -> LlvmM ExprData genMachOp_slow _ (MO_V_Extract l w) [val, idx] = runExprData $ do vval <- exprToVarW val vidx <- exprToVarW idx - [vval'] <- castVarsW [(vval, LMVector l ty)] + [vval'] <- castVarsW Signed [(vval, LMVector l ty)] doExprW ty $ Extract vval' vidx where ty = widthToLlvmInt w @@ -1283,7 +1340,7 @@ genMachOp_slow _ (MO_V_Extract l w) [val, idx] = runExprData $ do genMachOp_slow _ (MO_VF_Extract l w) [val, idx] = runExprData $ do vval <- exprToVarW val vidx <- exprToVarW idx - [vval'] <- castVarsW [(vval, LMVector l ty)] + [vval'] <- castVarsW Signed [(vval, LMVector l ty)] doExprW ty $ Extract vval' vidx where ty = widthToLlvmFloat w @@ -1293,7 +1350,7 @@ genMachOp_slow _ (MO_V_Insert l w) [val, elt, idx] = runExprData $ do vval <- exprToVarW val velt <- exprToVarW elt vidx <- exprToVarW idx - [vval'] <- castVarsW [(vval, ty)] + [vval'] <- castVarsW Signed [(vval, ty)] doExprW ty $ Insert vval' velt vidx where ty = LMVector l (widthToLlvmInt w) @@ -1302,7 +1359,7 @@ genMachOp_slow _ (MO_VF_Insert l w) [val, elt, idx] = runExprData $ do vval <- exprToVarW val velt <- exprToVarW elt vidx <- exprToVarW idx - [vval'] <- castVarsW [(vval, ty)] + [vval'] <- castVarsW Signed [(vval, ty)] doExprW ty $ Insert vval' velt vidx where ty = LMVector l (widthToLlvmFloat w) @@ -1414,7 +1471,7 @@ genMachOp_slow opt op [x, y] = case op of binCastLlvmOp ty binOp = runExprData $ do vx <- exprToVarW x vy <- exprToVarW y - [vx', vy'] <- castVarsW [(vx, ty), (vy, ty)] + [vx', vy'] <- castVarsW Signed [(vx, ty), (vy, ty)] doExprW ty $ binOp vx' vy' -- | Need to use EOption here as Cmm expects word size results from diff --git a/compiler/main/DriverPipeline.hs b/compiler/main/DriverPipeline.hs index 90976b115b..c6c9f9e1f6 100644 --- a/compiler/main/DriverPipeline.hs +++ b/compiler/main/DriverPipeline.hs @@ -848,6 +848,8 @@ llvmOptions dflags = ++ ["+avx512cd"| isAvx512cdEnabled dflags ] ++ ["+avx512er"| isAvx512erEnabled dflags ] ++ ["+avx512pf"| isAvx512pfEnabled dflags ] + ++ ["+bmi" | isBmiEnabled dflags ] + ++ ["+bmi2" | isBmi2Enabled dflags ] -- ----------------------------------------------------------------------------- -- | Each phase in the pipeline returns the next phase to execute, and the diff --git a/compiler/main/DynFlags.hs b/compiler/main/DynFlags.hs index ef4e2f8b85..3324d5532e 100644 --- a/compiler/main/DynFlags.hs +++ b/compiler/main/DynFlags.hs @@ -150,6 +150,8 @@ module DynFlags ( isSseEnabled, isSse2Enabled, isSse4_2Enabled, + isBmiEnabled, + isBmi2Enabled, isAvxEnabled, isAvx2Enabled, isAvx512cdEnabled, @@ -1005,6 +1007,7 @@ data DynFlags = DynFlags { -- | Machine dependent flags (-m<blah> stuff) sseVersion :: Maybe SseVersion, + bmiVersion :: Maybe BmiVersion, avx :: Bool, avx2 :: Bool, avx512cd :: Bool, -- Enable AVX-512 Conflict Detection Instructions. @@ -1806,6 +1809,7 @@ defaultDynFlags mySettings myLlvmTargets = interactivePrint = Nothing, nextWrapperNum = panic "defaultDynFlags: No nextWrapperNum", sseVersion = Nothing, + bmiVersion = Nothing, avx = False, avx2 = False, avx512cd = False, @@ -3201,6 +3205,10 @@ dynamic_flags_deps = [ d { sseVersion = Just SSE4 })) , make_ord_flag defGhcFlag "msse4.2" (noArg (\d -> d { sseVersion = Just SSE42 })) + , make_ord_flag defGhcFlag "mbmi" (noArg (\d -> + d { bmiVersion = Just BMI1 })) + , make_ord_flag defGhcFlag "mbmi2" (noArg (\d -> + d { bmiVersion = Just BMI2 })) , make_ord_flag defGhcFlag "mavx" (noArg (\d -> d { avx = True })) , make_ord_flag defGhcFlag "mavx2" (noArg (\d -> d { avx2 = True })) , make_ord_flag defGhcFlag "mavx512cd" (noArg (\d -> @@ -5447,6 +5455,25 @@ isAvx512pfEnabled :: DynFlags -> Bool isAvx512pfEnabled dflags = avx512pf dflags -- ----------------------------------------------------------------------------- +-- BMI2 + +data BmiVersion = BMI1 + | BMI2 + deriving (Eq, Ord) + +isBmiEnabled :: DynFlags -> Bool +isBmiEnabled dflags = case platformArch (targetPlatform dflags) of + ArchX86_64 -> bmiVersion dflags >= Just BMI1 + ArchX86 -> bmiVersion dflags >= Just BMI1 + _ -> False + +isBmi2Enabled :: DynFlags -> Bool +isBmi2Enabled dflags = case platformArch (targetPlatform dflags) of + ArchX86_64 -> bmiVersion dflags >= Just BMI2 + ArchX86 -> bmiVersion dflags >= Just BMI2 + _ -> False + +-- ----------------------------------------------------------------------------- -- Linker/compiler information -- LinkerInfo contains any extra options needed by the system linker. diff --git a/compiler/nativeGen/CPrim.hs b/compiler/nativeGen/CPrim.hs index ad61a002d3..399d646000 100644 --- a/compiler/nativeGen/CPrim.hs +++ b/compiler/nativeGen/CPrim.hs @@ -5,6 +5,8 @@ module CPrim , atomicRMWLabel , cmpxchgLabel , popCntLabel + , pdepLabel + , pextLabel , bSwapLabel , clzLabel , ctzLabel @@ -26,6 +28,24 @@ popCntLabel w = "hs_popcnt" ++ pprWidth w pprWidth W64 = "64" pprWidth w = pprPanic "popCntLabel: Unsupported word width " (ppr w) +pdepLabel :: Width -> String +pdepLabel w = "hs_pdep" ++ pprWidth w + where + pprWidth W8 = "8" + pprWidth W16 = "16" + pprWidth W32 = "32" + pprWidth W64 = "64" + pprWidth w = pprPanic "pdepLabel: Unsupported word width " (ppr w) + +pextLabel :: Width -> String +pextLabel w = "hs_pext" ++ pprWidth w + where + pprWidth W8 = "8" + pprWidth W16 = "16" + pprWidth W32 = "32" + pprWidth W64 = "64" + pprWidth w = pprPanic "pextLabel: Unsupported word width " (ppr w) + bSwapLabel :: Width -> String bSwapLabel w = "hs_bswap" ++ pprWidth w where diff --git a/compiler/nativeGen/PPC/CodeGen.hs b/compiler/nativeGen/PPC/CodeGen.hs index 898a31a657..e2c568c836 100644 --- a/compiler/nativeGen/PPC/CodeGen.hs +++ b/compiler/nativeGen/PPC/CodeGen.hs @@ -2004,6 +2004,8 @@ genCCall' dflags gcp target dest_regs args MO_BSwap w -> (fsLit $ bSwapLabel w, False) MO_PopCnt w -> (fsLit $ popCntLabel w, False) + MO_Pdep w -> (fsLit $ pdepLabel w, False) + MO_Pext w -> (fsLit $ pextLabel w, False) MO_Clz _ -> unsupported MO_Ctz _ -> unsupported MO_AtomicRMW {} -> unsupported diff --git a/compiler/nativeGen/SPARC/CodeGen.hs b/compiler/nativeGen/SPARC/CodeGen.hs index 55c1d1531d..6dfd58950e 100644 --- a/compiler/nativeGen/SPARC/CodeGen.hs +++ b/compiler/nativeGen/SPARC/CodeGen.hs @@ -654,6 +654,8 @@ outOfLineMachOp_table mop MO_BSwap w -> fsLit $ bSwapLabel w MO_PopCnt w -> fsLit $ popCntLabel w + MO_Pdep w -> fsLit $ pdepLabel w + MO_Pext w -> fsLit $ pextLabel w MO_Clz w -> fsLit $ clzLabel w MO_Ctz w -> fsLit $ ctzLabel w MO_AtomicRMW w amop -> fsLit $ atomicRMWLabel w amop diff --git a/compiler/nativeGen/X86/CodeGen.hs b/compiler/nativeGen/X86/CodeGen.hs index 6c0e0ac783..eb6af1ff41 100644 --- a/compiler/nativeGen/X86/CodeGen.hs +++ b/compiler/nativeGen/X86/CodeGen.hs @@ -1872,6 +1872,72 @@ genCCall dflags is32Bit (PrimTarget (MO_PopCnt width)) dest_regs@[dst] format = intFormat width lbl = mkCmmCodeLabel primUnitId (fsLit (popCntLabel width)) +genCCall dflags is32Bit (PrimTarget (MO_Pdep width)) dest_regs@[dst] + args@[src, mask] = do + let platform = targetPlatform dflags + if isBmi2Enabled dflags + then do code_src <- getAnyReg src + code_mask <- getAnyReg mask + src_r <- getNewRegNat format + mask_r <- getNewRegNat format + let dst_r = getRegisterReg platform False (CmmLocal dst) + return $ code_src src_r `appOL` code_mask mask_r `appOL` + (if width == W8 then + -- The PDEP instruction doesn't take a r/m8 + unitOL (MOVZxL II8 (OpReg src_r ) (OpReg src_r )) `appOL` + unitOL (MOVZxL II8 (OpReg mask_r) (OpReg mask_r)) `appOL` + unitOL (PDEP II16 (OpReg mask_r) (OpReg src_r ) dst_r) + else + unitOL (PDEP format (OpReg mask_r) (OpReg src_r) dst_r)) `appOL` + (if width == W8 || width == W16 then + -- We used a 16-bit destination register above, + -- so zero-extend + unitOL (MOVZxL II16 (OpReg dst_r) (OpReg dst_r)) + else nilOL) + else do + targetExpr <- cmmMakeDynamicReference dflags + CallReference lbl + let target = ForeignTarget targetExpr (ForeignConvention CCallConv + [NoHint] [NoHint] + CmmMayReturn) + genCCall dflags is32Bit target dest_regs args + where + format = intFormat width + lbl = mkCmmCodeLabel primUnitId (fsLit (pdepLabel width)) + +genCCall dflags is32Bit (PrimTarget (MO_Pext width)) dest_regs@[dst] + args@[src, mask] = do + let platform = targetPlatform dflags + if isBmi2Enabled dflags + then do code_src <- getAnyReg src + code_mask <- getAnyReg mask + src_r <- getNewRegNat format + mask_r <- getNewRegNat format + let dst_r = getRegisterReg platform False (CmmLocal dst) + return $ code_src src_r `appOL` code_mask mask_r `appOL` + (if width == W8 then + -- The PEXT instruction doesn't take a r/m8 + unitOL (MOVZxL II8 (OpReg src_r ) (OpReg src_r )) `appOL` + unitOL (MOVZxL II8 (OpReg mask_r) (OpReg mask_r)) `appOL` + unitOL (PEXT II16 (OpReg mask_r) (OpReg src_r) dst_r) + else + unitOL (PEXT format (OpReg mask_r) (OpReg src_r) dst_r)) `appOL` + (if width == W8 || width == W16 then + -- We used a 16-bit destination register above, + -- so zero-extend + unitOL (MOVZxL II16 (OpReg dst_r) (OpReg dst_r)) + else nilOL) + else do + targetExpr <- cmmMakeDynamicReference dflags + CallReference lbl + let target = ForeignTarget targetExpr (ForeignConvention CCallConv + [NoHint] [NoHint] + CmmMayReturn) + genCCall dflags is32Bit target dest_regs args + where + format = intFormat width + lbl = mkCmmCodeLabel primUnitId (fsLit (pextLabel width)) + genCCall dflags is32Bit (PrimTarget (MO_Clz width)) dest_regs@[dst] args@[src] | is32Bit && width == W64 = do -- Fallback to `hs_clz64` on i386 @@ -2689,6 +2755,9 @@ outOfLineCmmOp mop res args MO_Clz w -> fsLit $ clzLabel w MO_Ctz _ -> unsupported + MO_Pdep w -> fsLit $ pdepLabel w + MO_Pext w -> fsLit $ pextLabel w + MO_AtomicRMW _ _ -> fsLit "atomicrmw" MO_AtomicRead _ -> fsLit "atomicread" MO_AtomicWrite _ -> fsLit "atomicwrite" diff --git a/compiler/nativeGen/X86/Instr.hs b/compiler/nativeGen/X86/Instr.hs index c937d4dba0..f4f625b4a5 100644 --- a/compiler/nativeGen/X86/Instr.hs +++ b/compiler/nativeGen/X86/Instr.hs @@ -345,6 +345,10 @@ data Instr | BSF Format Operand Reg -- bit scan forward | BSR Format Operand Reg -- bit scan reverse + -- bit manipulation instructions + | PDEP Format Operand Operand Reg -- [BMI2] deposit bits to the specified mask + | PEXT Format Operand Operand Reg -- [BMI2] extract bits from the specified mask + -- prefetch | PREFETCH PrefetchVariant Format Operand -- prefetch Variant, addr size, address to prefetch -- variant can be NTA, Lvl0, Lvl1, or Lvl2 @@ -464,6 +468,9 @@ x86_regUsageOfInstr platform instr BSF _ src dst -> mkRU (use_R src []) [dst] BSR _ src dst -> mkRU (use_R src []) [dst] + PDEP _ src mask dst -> mkRU (use_R src $ use_R mask []) [dst] + PEXT _ src mask dst -> mkRU (use_R src $ use_R mask []) [dst] + -- note: might be a better way to do this PREFETCH _ _ src -> mkRU (use_R src []) [] LOCK i -> x86_regUsageOfInstr platform i @@ -640,6 +647,8 @@ x86_patchRegsOfInstr instr env CLTD _ -> instr POPCNT fmt src dst -> POPCNT fmt (patchOp src) (env dst) + PDEP fmt src mask dst -> PDEP fmt (patchOp src) (patchOp mask) (env dst) + PEXT fmt src mask dst -> PEXT fmt (patchOp src) (patchOp mask) (env dst) BSF fmt src dst -> BSF fmt (patchOp src) (env dst) BSR fmt src dst -> BSR fmt (patchOp src) (env dst) diff --git a/compiler/nativeGen/X86/Ppr.hs b/compiler/nativeGen/X86/Ppr.hs index 84ce7516b5..f5011b2a95 100644 --- a/compiler/nativeGen/X86/Ppr.hs +++ b/compiler/nativeGen/X86/Ppr.hs @@ -648,6 +648,9 @@ pprInstr (POPCNT format src dst) = pprOpOp (sLit "popcnt") format src (OpReg dst pprInstr (BSF format src dst) = pprOpOp (sLit "bsf") format src (OpReg dst) pprInstr (BSR format src dst) = pprOpOp (sLit "bsr") format src (OpReg dst) +pprInstr (PDEP format src mask dst) = pprFormatOpOpReg (sLit "pdep") format src mask dst +pprInstr (PEXT format src mask dst) = pprFormatOpOpReg (sLit "pext") format src mask dst + pprInstr (PREFETCH NTA format src ) = pprFormatOp_ (sLit "prefetchnta") format src pprInstr (PREFETCH Lvl0 format src) = pprFormatOp_ (sLit "prefetcht0") format src pprInstr (PREFETCH Lvl1 format src) = pprFormatOp_ (sLit "prefetcht1") format src @@ -1262,6 +1265,16 @@ pprFormatRegRegReg name format reg1 reg2 reg3 pprReg format reg3 ] +pprFormatOpOpReg :: LitString -> Format -> Operand -> Operand -> Reg -> SDoc +pprFormatOpOpReg name format op1 op2 reg3 + = hcat [ + pprMnemonic name format, + pprOperand format op1, + comma, + pprOperand format op2, + comma, + pprReg format reg3 + ] pprFormatAddrReg :: LitString -> Format -> AddrMode -> Reg -> SDoc pprFormatAddrReg name format op dst diff --git a/compiler/prelude/primops.txt.pp b/compiler/prelude/primops.txt.pp index d8d7f6e3e1..43e8f535d3 100644 --- a/compiler/prelude/primops.txt.pp +++ b/compiler/prelude/primops.txt.pp @@ -403,6 +403,28 @@ primop PopCnt64Op "popCnt64#" GenPrimOp WORD64 -> Word# primop PopCntOp "popCnt#" Monadic Word# -> Word# {Count the number of set bits in a word.} +primop Pdep8Op "pdep8#" Dyadic Word# -> Word# -> Word# + {Deposit bits to lower 8 bits of a word at locations specified by a mask.} +primop Pdep16Op "pdep16#" Dyadic Word# -> Word# -> Word# + {Deposit bits to lower 16 bits of a word at locations specified by a mask.} +primop Pdep32Op "pdep32#" Dyadic Word# -> Word# -> Word# + {Deposit bits to lower 32 bits of a word at locations specified by a mask.} +primop Pdep64Op "pdep64#" GenPrimOp WORD64 -> WORD64 -> WORD64 + {Deposit bits to a word at locations specified by a mask.} +primop PdepOp "pdep#" Dyadic Word# -> Word# -> Word# + {Deposit bits to a word at locations specified by a mask.} + +primop Pext8Op "pext8#" Dyadic Word# -> Word# -> Word# + {Extract bits from lower 8 bits of a word at locations specified by a mask.} +primop Pext16Op "pext16#" Dyadic Word# -> Word# -> Word# + {Extract bits from lower 16 bits of a word at locations specified by a mask.} +primop Pext32Op "pext32#" Dyadic Word# -> Word# -> Word# + {Extract bits from lower 32 bits of a word at locations specified by a mask.} +primop Pext64Op "pext64#" GenPrimOp WORD64 -> WORD64 -> WORD64 + {Extract bits from a word at locations specified by a mask.} +primop PextOp "pext#" Dyadic Word# -> Word# -> Word# + {Extract bits from a word at locations specified by a mask.} + primop Clz8Op "clz8#" Monadic Word# -> Word# {Count leading zeros in the lower 8 bits of a word.} primop Clz16Op "clz16#" Monadic Word# -> Word# diff --git a/libraries/ghc-prim/cbits/pdep.c b/libraries/ghc-prim/cbits/pdep.c new file mode 100644 index 0000000000..8435ffe186 --- /dev/null +++ b/libraries/ghc-prim/cbits/pdep.c @@ -0,0 +1,48 @@ +#include "Rts.h" +#include "MachDeps.h" + +extern StgWord64 hs_pdep64(StgWord64 src, StgWord64 mask); + +StgWord64 +hs_pdep64(StgWord64 src, StgWord64 mask) +{ + uint64_t result = 0; + + while (1) { + // Mask out all but the lowest bit + const uint64_t lowest = (-mask & mask); + + if (lowest == 0) { + break; + } + + const uint64_t lsb = (uint64_t)((int64_t)(src << 63) >> 63); + + result |= lsb & lowest; + mask &= ~lowest; + src >>= 1; + } + + return result; +} + +extern StgWord hs_pdep32(StgWord src, StgWord mask); +StgWord +hs_pdep32(StgWord src, StgWord mask) +{ + return hs_pdep64(src, mask); +} + +extern StgWord hs_pdep16(StgWord src, StgWord mask); +StgWord +hs_pdep16(StgWord src, StgWord mask) +{ + return hs_pdep64(src, mask); +} + +extern StgWord hs_pdep8(StgWord src, StgWord mask); +StgWord +hs_pdep8(StgWord src, StgWord mask) +{ + return hs_pdep64(src, mask); +} diff --git a/libraries/ghc-prim/cbits/pext.c b/libraries/ghc-prim/cbits/pext.c new file mode 100644 index 0000000000..fe960b1342 --- /dev/null +++ b/libraries/ghc-prim/cbits/pext.c @@ -0,0 +1,44 @@ +#include "Rts.h" +#include "MachDeps.h" + +extern StgWord64 hs_pext64(StgWord64 src, StgWord64 mask); + +StgWord64 +hs_pext64(StgWord64 src, StgWord64 mask) +{ + uint64_t result = 0; + int offset = 0; + + for (int bit = 0; bit != sizeof(uint64_t) * 8; ++bit) { + const uint64_t src_bit = (src >> bit) & 1; + const uint64_t mask_bit = (mask >> bit) & 1; + + if (mask_bit) { + result |= (uint64_t)(src_bit) << offset; + ++offset; + } + } + + return result; +} + +extern StgWord hs_pext32(StgWord src, StgWord mask); +StgWord +hs_pext32(StgWord src, StgWord mask) +{ + return hs_pext64(src, mask); +} + +extern StgWord hs_pext16(StgWord src, StgWord mask); +StgWord +hs_pext16(StgWord src, StgWord mask) +{ + return hs_pext64(src, mask); +} + +extern StgWord hs_pext8(StgWord src, StgWord mask); +StgWord +hs_pext8(StgWord src, StgWord mask) +{ + return hs_pext64(src, mask); +} diff --git a/libraries/ghc-prim/ghc-prim.cabal b/libraries/ghc-prim/ghc-prim.cabal index e99686a10b..9b8c1ac196 100644 --- a/libraries/ghc-prim/ghc-prim.cabal +++ b/libraries/ghc-prim/ghc-prim.cabal @@ -73,6 +73,8 @@ Library cbits/ctz.c cbits/debug.c cbits/longlong.c + cbits/pdep.c + cbits/pext.c cbits/popcnt.c cbits/word2float.c diff --git a/testsuite/tests/codeGen/should_run/all.T b/testsuite/tests/codeGen/should_run/all.T index 214a9d5704..42d8a2f767 100644 --- a/testsuite/tests/codeGen/should_run/all.T +++ b/testsuite/tests/codeGen/should_run/all.T @@ -77,6 +77,8 @@ test('cgrun069', omit_ways(['ghci']), multi_compile_and_run, test('cgrun070', normal, compile_and_run, ['']) test('cgrun071', normal, compile_and_run, ['']) test('cgrun072', normal, compile_and_run, ['']) +test('cgrun075', normal, compile_and_run, ['']) +test('cgrun076', normal, compile_and_run, ['']) test('T1852', normal, compile_and_run, ['']) test('T1861', extra_run_opts('0'), compile_and_run, ['']) diff --git a/testsuite/tests/codeGen/should_run/cgrun075.hs b/testsuite/tests/codeGen/should_run/cgrun075.hs new file mode 100644 index 0000000000..09e35b4d8a --- /dev/null +++ b/testsuite/tests/codeGen/should_run/cgrun075.hs @@ -0,0 +1,115 @@ +{-# LANGUAGE BangPatterns, CPP, MagicHash #-} + +module Main ( main ) where + +import Data.Bits +import GHC.Int +import GHC.Prim +import GHC.Word +import Data.Int +import Data.Word + +#include "MachDeps.h" + +main = putStr + ( test_pdep ++ "\n" + ++ test_pdep8 ++ "\n" + ++ test_pdep16 ++ "\n" + ++ test_pdep32 ++ "\n" + ++ test_pdep64 ++ "\n" + ++ "\n" + ) + +class Pdep a where + pdep :: a -> a -> a + +instance Pdep Word where + pdep (W# src#) (W# mask#) = W# (pdep# src# mask#) + +instance Pdep Word8 where + pdep (W8# src#) (W8# mask#) = W8# (pdep8# src# mask#) + +instance Pdep Word16 where + pdep (W16# src#) (W16# mask#) = W16# (pdep16# src# mask#) + +instance Pdep Word32 where + pdep (W32# src#) (W32# mask#) = W32# (pdep32# src# mask#) + +instance Pdep Word64 where + pdep (W64# src#) (W64# mask#) = W64# (pdep64# src# mask#) + +class SlowPdep a where + slowPdep :: a -> a -> a + +instance SlowPdep Word where + slowPdep s m = fromIntegral (slowPdep64 (fromIntegral s) (fromIntegral m)) + +instance SlowPdep Word8 where + slowPdep s m = fromIntegral (slowPdep64 (fromIntegral s) (fromIntegral m)) + +instance SlowPdep Word16 where + slowPdep s m = fromIntegral (slowPdep64 (fromIntegral s) (fromIntegral m)) + +instance SlowPdep Word32 where + slowPdep s m = fromIntegral (slowPdep64 (fromIntegral s) (fromIntegral m)) + +instance SlowPdep Word64 where + slowPdep s m = fromIntegral (slowPdep64 (fromIntegral s) (fromIntegral m)) + +slowPdep64 :: Word64 -> Word64 -> Word64 +slowPdep64 = slowPdep64' 0 + +slowPdep32 :: Word32 -> Word32 -> Word32 +slowPdep32 s m = fromIntegral (slowPdep64 (fromIntegral s) (fromIntegral m)) + +lsb :: Word64 -> Word64 +lsb src = fromIntegral ((fromIntegral (src `shiftL` 63) :: Int64) `shiftR` 63) + +slowPdep64' :: Word64 -> Word64 -> Word64 -> Word64 +slowPdep64' result src mask = if lowest /= 0 + then slowPdep64' newResult (src `shiftR` 1) (mask .&. complement lowest) + else result + where lowest = (-mask) .&. mask + newResult = (result .|. ((lsb src) .&. lowest)) + +test_pdep = test (0 :: Word ) pdep slowPdep +test_pdep8 = test (0 :: Word8 ) pdep slowPdep +test_pdep16 = test (0 :: Word16) pdep slowPdep +test_pdep32 = test (0 :: Word32) pdep slowPdep +test_pdep64 = test (0 :: Word64) pdep slowPdep + +mask n = (2 ^ n) - 1 + +fst4 :: (a, b, c, d) -> a +fst4 (a, _, _, _) = a + +runCase :: Eq a + => (a -> a -> a) + -> (a -> a -> a) + -> (a, a) + -> (Bool, a, a, (a, a)) +runCase fast slow (x, y) = (slow x y == fast x y, slow x y, fast x y, (x, y)) + +test :: (Show a, Num a, Eq a) => a -> (a -> a -> a) -> (a -> a -> a) -> String +test _ fast slow = case failing of + [] -> "OK" + ((_, e, a, i):xs) -> + "FAIL\n" ++ " Input: " ++ show i ++ "\nExpected: " ++ show e ++ + "\n Actual: " ++ show a + where failing = dropWhile fst4 . map (runCase fast slow) $ cases + cases = (,) <$> numbers <*> numbers + -- 10 random numbers +#if SIZEOF_HSWORD == 4 + numbers = [ 1480294021, 1626858410, 2316287658, 1246556957, 3806579062 + , 65945563 , 1521588071, 791321966 , 1355466914, 2284998160 + ] +#elif SIZEOF_HSWORD == 8 + numbers = [ 11004539497957619752, 5625461252166958202 + , 1799960778872209546 , 16979826074020750638 + , 12789915432197771481, 11680809699809094550 + , 13208678822802632247, 13794454868797172383 + , 13364728999716654549, 17516539991479925226 + ] +#else +# error Unexpected word size +#endif diff --git a/testsuite/tests/codeGen/should_run/cgrun075.stdout b/testsuite/tests/codeGen/should_run/cgrun075.stdout new file mode 100644 index 0000000000..e22e2cd950 --- /dev/null +++ b/testsuite/tests/codeGen/should_run/cgrun075.stdout @@ -0,0 +1,6 @@ +OK +OK +OK +OK +OK + diff --git a/testsuite/tests/codeGen/should_run/cgrun076.hs b/testsuite/tests/codeGen/should_run/cgrun076.hs new file mode 100644 index 0000000000..7fa42d74e0 --- /dev/null +++ b/testsuite/tests/codeGen/should_run/cgrun076.hs @@ -0,0 +1,115 @@ +{-# LANGUAGE BangPatterns, CPP, MagicHash #-} + +module Main ( main ) where + +import Data.Bits +import GHC.Int +import GHC.Prim +import GHC.Word +import Data.Int +import Data.Word + +#include "MachDeps.h" + +main = putStr + ( test_pext ++ "\n" + ++ test_pext8 ++ "\n" + ++ test_pext16 ++ "\n" + ++ test_pext32 ++ "\n" + ++ test_pext64 ++ "\n" + ++ "\n" + ) + +class Pext a where + pext :: a -> a -> a + +instance Pext Word where + pext (W# src#) (W# mask#) = W# (pext# src# mask#) + +instance Pext Word8 where + pext (W8# src#) (W8# mask#) = W8# (pext8# src# mask#) + +instance Pext Word16 where + pext (W16# src#) (W16# mask#) = W16# (pext16# src# mask#) + +instance Pext Word32 where + pext (W32# src#) (W32# mask#) = W32# (pext32# src# mask#) + +instance Pext Word64 where + pext (W64# src#) (W64# mask#) = W64# (pext64# src# mask#) + +class SlowPext a where + slowPext :: a -> a -> a + +instance SlowPext Word where + slowPext s m = fromIntegral (slowPext64 (fromIntegral s) (fromIntegral m)) + +instance SlowPext Word8 where + slowPext s m = fromIntegral (slowPext64 (fromIntegral s) (fromIntegral m)) + +instance SlowPext Word16 where + slowPext s m = fromIntegral (slowPext64 (fromIntegral s) (fromIntegral m)) + +instance SlowPext Word32 where + slowPext s m = fromIntegral (slowPext64 (fromIntegral s) (fromIntegral m)) + +instance SlowPext Word64 where + slowPext s m = fromIntegral (slowPext64 (fromIntegral s) (fromIntegral m)) + +slowPext64 :: Word64 -> Word64 -> Word64 +slowPext64 = slowPext64' 0 0 0 + +slowPext32 :: Word32 -> Word32 -> Word32 +slowPext32 s m = fromIntegral (slowPext64 (fromIntegral s) (fromIntegral m)) + +slowPext64' :: Word64 -> Int -> Int -> Word64 -> Word64 -> Word64 +slowPext64' result offset index src mask = if index /= 64 + then if maskBit /= 0 + then slowPext64' nextResult (offset + 1) (index + 1) src mask + else slowPext64' result offset (index + 1) src mask + else result + where srcBit = (src `shiftR` index) .&. 1 + maskBit = (mask `shiftR` index) .&. 1 + nextResult = result .|. (srcBit `shiftL` offset) + +test_pext = test (0 :: Word ) pext slowPext +test_pext8 = test (0 :: Word8 ) pext slowPext +test_pext16 = test (0 :: Word16) pext slowPext +test_pext32 = test (0 :: Word32) pext slowPext +test_pext64 = test (0 :: Word64) pext slowPext + +mask n = (2 ^ n) - 1 + +fst4 :: (a, b, c, d) -> a +fst4 (a, _, _, _) = a + +runCase :: Eq a + => (a -> a -> a) + -> (a -> a -> a) + -> (a, a) + -> (Bool, a, a, (a, a)) +runCase fast slow (x, y) = (slow x y == fast x y, slow x y, fast x y, (x, y)) + +test :: (Show a, Num a, Eq a) => a -> (a -> a -> a) -> (a -> a -> a) -> String +test _ fast slow = case failing of + [] -> "OK" + ((_, e, a, i):xs) -> + "FAIL\n" ++ " Input: " ++ show i ++ "\nExpected: " ++ show e ++ + "\n Actual: " ++ show a + where failing = dropWhile fst4 . map (runCase fast slow) $ cases + cases = (,) <$> numbers <*> numbers + -- 10 random numbers +#if SIZEOF_HSWORD == 4 + numbers = [ 1480294021, 1626858410, 2316287658, 1246556957, 3806579062 + , 65945563 , 1521588071, 791321966 , 1355466914, 2284998160 + ] +#elif SIZEOF_HSWORD == 8 + numbers = [ 11004539497957619752, 5625461252166958202 + , 1799960778872209546 , 16979826074020750638 + , 12789915432197771481, 11680809699809094550 + , 13208678822802632247, 13794454868797172383 + , 13364728999716654549, 17516539991479925226 + ] +#else +# error Unexpected word size +#endif diff --git a/testsuite/tests/codeGen/should_run/cgrun076.stdout b/testsuite/tests/codeGen/should_run/cgrun076.stdout new file mode 100644 index 0000000000..e22e2cd950 --- /dev/null +++ b/testsuite/tests/codeGen/should_run/cgrun076.stdout @@ -0,0 +1,6 @@ +OK +OK +OK +OK +OK + |