From 87eebf98cb485f7c9175330051736e147ade9848 Mon Sep 17 00:00:00 2001 From: sheaf Date: Sat, 8 Apr 2023 13:42:58 +0200 Subject: Add fused multiply-add instructions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This patch adds eight new primops that fuse a multiplication and an addition or subtraction: - `{fmadd,fmsub,fnmadd,fnmsub}{Float,Double}#` fmadd x y z is x * y + z, computed with a single rounding step. This patch implements code generation for these primops in the following backends: - X86, AArch64 and PowerPC NCG, - LLVM - C WASM uses the C implementation. The primops are unsupported in the JavaScript backend. The following constant folding rules are also provided: - compute a * b + c when a, b, c are all literals, - x * y + 0 ==> x * y, - ±1 * y + z ==> z ± y and x * ±1 + z ==> z ± x. NB: the constant folding rules incorrectly handle signed zero. This is a known limitation with GHC's floating-point constant folding rules (#21227), which we hope to resolve in the future. 
--- compiler/GHC/Builtin/primops.txt.pp | 69 ++++++ compiler/GHC/Cmm/MachOp.hs | 38 +++ compiler/GHC/Cmm/Parser.y | 7 +- compiler/GHC/CmmToAsm/AArch64/CodeGen.hs | 41 +++- compiler/GHC/CmmToAsm/AArch64/Instr.hs | 19 ++ compiler/GHC/CmmToAsm/AArch64/Ppr.hs | 7 + compiler/GHC/CmmToAsm/PPC/CodeGen.hs | 35 ++- compiler/GHC/CmmToAsm/PPC/Instr.hs | 11 + compiler/GHC/CmmToAsm/PPC/Ppr.hs | 18 ++ compiler/GHC/CmmToAsm/Wasm/FromCmm.hs | 4 +- compiler/GHC/CmmToAsm/X86/CodeGen.hs | 92 ++++++- compiler/GHC/CmmToAsm/X86/Instr.hs | 22 +- compiler/GHC/CmmToAsm/X86/Ppr.hs | 23 ++ compiler/GHC/CmmToC.hs | 22 +- compiler/GHC/CmmToLlvm/CodeGen.hs | 28 ++- compiler/GHC/Core/Opt/ConstantFold.hs | 156 ++++++++++++ compiler/GHC/Driver/Config/StgToCmm.hs | 22 ++ compiler/GHC/Driver/Pipeline/Execute.hs | 1 + compiler/GHC/Driver/Session.hs | 11 +- compiler/GHC/Llvm/Ppr.hs | 8 + compiler/GHC/Llvm/Syntax.hs | 3 + compiler/GHC/Llvm/Types.hs | 1 - compiler/GHC/StgToCmm/Config.hs | 2 + compiler/GHC/StgToCmm/Prim.hs | 44 ++++ compiler/GHC/StgToJS/Prim.hs | 10 + compiler/GHC/SysTools/Cpp.hs | 4 + docs/users_guide/9.8.1-notes.rst | 18 ++ docs/users_guide/using.rst | 18 ++ libraries/ghc-prim/changelog.md | 18 ++ rts/RtsSymbols.c | 1 - rts/StgPrimFloat.c | 1 - testsuite/driver/cpu_features.py | 5 + .../tests/primops/should_run/FMA_ConstantFold.hs | 236 ++++++++++++++++++ testsuite/tests/primops/should_run/FMA_Primops.hs | 264 +++++++++++++++++++++ testsuite/tests/primops/should_run/all.T | 11 + 35 files changed, 1244 insertions(+), 26 deletions(-) create mode 100644 testsuite/tests/primops/should_run/FMA_ConstantFold.hs create mode 100644 testsuite/tests/primops/should_run/FMA_Primops.hs diff --git a/compiler/GHC/Builtin/primops.txt.pp b/compiler/GHC/Builtin/primops.txt.pp index 5b730c1943..b2b3f1d8f5 100644 --- a/compiler/GHC/Builtin/primops.txt.pp +++ b/compiler/GHC/Builtin/primops.txt.pp @@ -1369,6 +1369,75 @@ primop FloatDecode_IntOp "decodeFloat_Int#" GenPrimOp First 'Int#' in result is the 
mantissa; second is the exponent.} with out_of_line = True +------------------------------------------------------------------------ +section "Fused multiply-add operations" + { #fma# + + The fused multiply-add primops 'fmaddFloat#' and 'fmaddDouble#' + implement the operation + + \[ + \lambda\ x\ y\ z \rightarrow x * y + z + \] + + with a single floating-point rounding operation at the end, as opposed to + rounding twice (which can accumulate rounding errors). + + These primops can be compiled directly to a single machine instruction on + architectures that support them. Currently, these are: + + 1. x86 with CPUs that support the FMA3 extended instruction set (which + includes most processors since 2013). + 2. PowerPC. + 3. AArch64. + + This requires users pass the '-mfma' flag to GHC. Otherwise, the primop + is implemented by falling back to the C standard library, which might + perform software emulation (this may yield results that are not IEEE + compliant on some platforms). + + The additional operations 'fmsubFloat#'/'fmsubDouble#', + 'fnmaddFloat#'/'fnmaddDouble#' and 'fnmsubFloat#'/'fnmsubDouble#' provide + variants on 'fmaddFloat#'/'fmaddDouble#' in which some signs are changed: + + \[ + \begin{aligned} + \mathrm{fmadd}\ x\ y\ z &= \phantom{+} x * y + z \\[8pt] + \mathrm{fmsub}\ x\ y\ z &= \phantom{+} x * y - z \\[8pt] + \mathrm{fnmadd}\ x\ y\ z &= - x * y + z \\[8pt] + \mathrm{fnmsub}\ x\ y\ z &= - x * y - z + \end{aligned} + \] + + } +------------------------------------------------------------------------ + +primop FloatFMAdd "fmaddFloat#" GenPrimOp + Float# -> Float# -> Float# -> Float# + {Fused multiply-add operation @x*y+z@. See "GHC.Prim#fma".} +primop FloatFMSub "fmsubFloat#" GenPrimOp + Float# -> Float# -> Float# -> Float# + {Fused multiply-subtract operation @x*y-z@. See "GHC.Prim#fma".} +primop FloatFNMAdd "fnmaddFloat#" GenPrimOp + Float# -> Float# -> Float# -> Float# + {Fused negate-multiply-add operation @-x*y+z@. 
See "GHC.Prim#fma".} +primop FloatFNMSub "fnmsubFloat#" GenPrimOp + Float# -> Float# -> Float# -> Float# + {Fused negate-multiply-subtract operation @-x*y-z@. See "GHC.Prim#fma".} + +primop DoubleFMAdd "fmaddDouble#" GenPrimOp + Double# -> Double# -> Double# -> Double# + {Fused multiply-add operation @x*y+z@. See "GHC.Prim#fma".} +primop DoubleFMSub "fmsubDouble#" GenPrimOp + Double# -> Double# -> Double# -> Double# + {Fused multiply-subtract operation @x*y-z@. See "GHC.Prim#fma".} +primop DoubleFNMAdd "fnmaddDouble#" GenPrimOp + Double# -> Double# -> Double# -> Double# + {Fused negate-multiply-add operation @-x*y+z@. See "GHC.Prim#fma".} +primop DoubleFNMSub "fnmsubDouble#" GenPrimOp + Double# -> Double# -> Double# -> Double# + {Fused negate-multiply-subtract operation @-x*y-z@. See "GHC.Prim#fma".} + ------------------------------------------------------------------------ section "Arrays" {Operations on 'Array#'.} diff --git a/compiler/GHC/Cmm/MachOp.hs b/compiler/GHC/Cmm/MachOp.hs index d134fdc346..29863853c1 100644 --- a/compiler/GHC/Cmm/MachOp.hs +++ b/compiler/GHC/Cmm/MachOp.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE LambdaCase #-} + {-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-} module GHC.Cmm.MachOp @@ -26,6 +28,9 @@ module GHC.Cmm.MachOp -- Atomic read-modify-write , MemoryOrdering(..) , AtomicMachOp(..) + + -- Fused multiply-add + , FMASign(..), pprFMASign ) where @@ -88,6 +93,10 @@ data MachOp | MO_F_Mul Width | MO_F_Quot Width + -- Floating-point fused multiply-add operations + -- | Fused multiply-add, see 'FMASign'. + | MO_FMA FMASign Width + -- Floating point comparison | MO_F_Eq Width | MO_F_Ne Width @@ -160,7 +169,30 @@ data MachOp pprMachOp :: MachOp -> SDoc pprMachOp mo = text (show mo) +-- | Where are the signs in a fused multiply-add instruction? +-- +-- @x*y + z@ vs @x*y - z@ vs @-x*y+z@ vs @-x*y-z@. +-- +-- Warning: the signs aren't consistent across architectures (X86, PowerPC, AArch64). 
+-- The user-facing implementation uses the X86 convention, while the relevant +-- backends use their corresponding conventions. +data FMASign + -- | Fused multiply-add @x*y + z@. + = FMAdd + -- | Fused multiply-subtract. On X86: @x*y - z@. + | FMSub + -- | Fused negate-multiply-add. On X86: @-x*y + z@. + | FNMAdd + -- | Fused negate-multiply-subtract. On X86: @-x*y - z@. + | FNMSub + deriving (Eq, Show) +pprFMASign :: IsLine doc => FMASign -> doc +pprFMASign = \case + FMAdd -> text "fmadd" + FMSub -> text "fmsub" + FNMAdd -> text "fnmadd" + FNMSub -> text "fnmsub" -- ----------------------------------------------------------------------------- -- Some common MachReps @@ -398,6 +430,9 @@ machOpResultType platform mop tys = MO_F_Mul r -> cmmFloat r MO_F_Quot r -> cmmFloat r MO_F_Neg r -> cmmFloat r + + MO_FMA _ r -> cmmFloat r + MO_F_Eq {} -> comparisonResultRep platform MO_F_Ne {} -> comparisonResultRep platform MO_F_Ge {} -> comparisonResultRep platform @@ -489,6 +524,9 @@ machOpArgReps platform op = MO_F_Mul r -> [r,r] MO_F_Quot r -> [r,r] MO_F_Neg r -> [r] + + MO_FMA _ r -> [r,r,r] + MO_F_Eq r -> [r,r] MO_F_Ne r -> [r,r] MO_F_Ge r -> [r,r] diff --git a/compiler/GHC/Cmm/Parser.y b/compiler/GHC/Cmm/Parser.y index 29870dc647..495e72e37d 100644 --- a/compiler/GHC/Cmm/Parser.y +++ b/compiler/GHC/Cmm/Parser.y @@ -1009,7 +1009,7 @@ machOps = listToUFM $ ( "eq", MO_Eq ), ( "ne", MO_Ne ), ( "mul", MO_Mul ), - ( "mulmayoflo", MO_S_MulMayOflo ), + ( "mulmayoflo", MO_S_MulMayOflo ), ( "neg", MO_S_Neg ), ( "quot", MO_S_Quot ), ( "rem", MO_S_Rem ), @@ -1040,6 +1040,11 @@ machOps = listToUFM $ ( "fmul", MO_F_Mul ), ( "fquot", MO_F_Quot ), + ( "fmadd" , MO_FMA FMAdd ), + ( "fmsub" , MO_FMA FMSub ), + ( "fnmadd", MO_FMA FNMAdd ), + ( "fnmsub", MO_FMA FNMSub ), + ( "feq", MO_F_Eq ), ( "fne", MO_F_Ne ), ( "fge", MO_F_Ge ), diff --git a/compiler/GHC/CmmToAsm/AArch64/CodeGen.hs b/compiler/GHC/CmmToAsm/AArch64/CodeGen.hs index 8ebccaf093..c0e9a7e8d5 100644 --- 
a/compiler/GHC/CmmToAsm/AArch64/CodeGen.hs +++ b/compiler/GHC/CmmToAsm/AArch64/CodeGen.hs @@ -783,7 +783,7 @@ getRegister' config plat expr where w' = formatToWidth (cmmTypeFormat (cmmRegType reg)) r' = getRegisterReg plat reg - -- Generic case. + -- Generic binary case. CmmMachOp op [x, y] -> do -- alright, so we have an operation, and two expressions. And we want to essentially do -- ensure we get float regs (TODO(Ben): What?) @@ -956,7 +956,44 @@ getRegister' config plat expr -- TODO - op -> pprPanic "getRegister' (unhandled dyadic CmmMachOp): " $ (pprMachOp op) <+> text "in" <+> (pdoc plat expr) + op -> pprPanic "getRegister' (unhandled dyadic CmmMachOp): " $ + (pprMachOp op) <+> text "in" <+> (pdoc plat expr) + + -- Generic ternary case. + CmmMachOp op [x, y, z] -> + + case op of + + -- Floating-point fused multiply-add operations + + -- x86 fmadd x * y + z <=> AArch64 fmadd : d = r1 * r2 + r3 + -- x86 fmsub x * y - z <=> AArch64 fnmsub: d = r1 * r2 - r3 + -- x86 fnmadd - x * y + z <=> AArch64 fmsub : d = - r1 * r2 + r3 + -- x86 fnmsub - x * y - z <=> AArch64 fnmadd: d = - r1 * r2 - r3 + + MO_FMA var w -> case var of + FMAdd -> float3Op w (\d n m a -> unitOL $ FMA FMAdd d n m a) + FMSub -> float3Op w (\d n m a -> unitOL $ FMA FNMSub d n m a) + FNMAdd -> float3Op w (\d n m a -> unitOL $ FMA FMSub d n m a) + FNMSub -> float3Op w (\d n m a -> unitOL $ FMA FNMAdd d n m a) + + _ -> pprPanic "getRegister' (unhandled ternary CmmMachOp): " $ + (pprMachOp op) <+> text "in" <+> (pdoc plat expr) + + where + float3Op w op = do + (reg_fx, format_x, code_fx) <- getFloatReg x + (reg_fy, format_y, code_fy) <- getFloatReg y + (reg_fz, format_z, code_fz) <- getFloatReg z + massertPpr (isFloatFormat format_x && isFloatFormat format_y && isFloatFormat format_z) $ + text "float3Op: non-float" + return $ + Any (floatFormat w) $ \ dst -> + code_fx `appOL` + code_fy `appOL` + code_fz `appOL` + op (OpReg w dst) (OpReg w reg_fx) (OpReg w reg_fy) (OpReg w reg_fz) + CmmMachOp _op _xs -> 
pprPanic "getRegister' (variadic CmmMachOp): " (pdoc plat expr) diff --git a/compiler/GHC/CmmToAsm/AArch64/Instr.hs b/compiler/GHC/CmmToAsm/AArch64/Instr.hs index 7bf78becb6..166ab2ca17 100644 --- a/compiler/GHC/CmmToAsm/AArch64/Instr.hs +++ b/compiler/GHC/CmmToAsm/AArch64/Instr.hs @@ -142,6 +142,8 @@ regUsageOfInstr platform instr = case instr of SCVTF dst src -> usage (regOp src, regOp dst) FCVTZS dst src -> usage (regOp src, regOp dst) FABS dst src -> usage (regOp src, regOp dst) + FMA _ dst src1 src2 src3 -> + usage (regOp src1 ++ regOp src2 ++ regOp src3, regOp dst) _ -> panic $ "regUsageOfInstr: " ++ instrCon instr @@ -280,6 +282,9 @@ patchRegsOfInstr instr env = case instr of SCVTF o1 o2 -> SCVTF (patchOp o1) (patchOp o2) FCVTZS o1 o2 -> FCVTZS (patchOp o1) (patchOp o2) FABS o1 o2 -> FABS (patchOp o1) (patchOp o2) + FMA s o1 o2 o3 o4 -> + FMA s (patchOp o1) (patchOp o2) (patchOp o3) (patchOp o4) + _ -> panic $ "patchRegsOfInstr: " ++ instrCon instr where patchOp :: Operand -> Operand @@ -650,6 +655,14 @@ data Instr -- Float ABSolute value | FABS Operand Operand + -- | Floating-point fused multiply-add instructions + -- + -- - fmadd : d = r1 * r2 + r3 + -- - fnmsub: d = r1 * r2 - r3 + -- - fmsub : d = - r1 * r2 + r3 + -- - fnmadd: d = - r1 * r2 - r3 + | FMA FMASign Operand Operand Operand Operand + instrCon :: Instr -> String instrCon i = case i of @@ -715,6 +728,12 @@ instrCon i = SCVTF{} -> "SCVTF" FCVTZS{} -> "FCVTZS" FABS{} -> "FABS" + FMA variant _ _ _ _ -> + case variant of + FMAdd -> "FMADD" + FMSub -> "FMSUB" + FNMAdd -> "FNMADD" + FNMSub -> "FNMSUB" data Target = TBlock BlockId diff --git a/compiler/GHC/CmmToAsm/AArch64/Ppr.hs b/compiler/GHC/CmmToAsm/AArch64/Ppr.hs index 475324afce..646f914c8d 100644 --- a/compiler/GHC/CmmToAsm/AArch64/Ppr.hs +++ b/compiler/GHC/CmmToAsm/AArch64/Ppr.hs @@ -546,6 +546,13 @@ pprInstr platform instr = case instr of SCVTF o1 o2 -> op2 (text "\tscvtf") o1 o2 FCVTZS o1 o2 -> op2 (text "\tfcvtzs") o1 o2 FABS o1 o2 -> op2 
(text "\tfabs") o1 o2 + FMA variant d r1 r2 r3 -> + let fma = case variant of + FMAdd -> text "\tfmadd" + FMSub -> text "\tfmsub" + FNMAdd -> text "\tfnmadd" + FNMSub -> text "\tfnmsub" + in op4 fma d r1 r2 r3 where op2 op o1 o2 = line $ op <+> pprOp platform o1 <> comma <+> pprOp platform o2 op3 op o1 o2 o3 = line $ op <+> pprOp platform o1 <> comma <+> pprOp platform o2 <> comma <+> pprOp platform o3 op4 op o1 o2 o3 o4 = line $ op <+> pprOp platform o1 <> comma <+> pprOp platform o2 <> comma <+> pprOp platform o3 <> comma <+> pprOp platform o4 diff --git a/compiler/GHC/CmmToAsm/PPC/CodeGen.hs b/compiler/GHC/CmmToAsm/PPC/CodeGen.hs index 7dac4f221b..f8a726da6c 100644 --- a/compiler/GHC/CmmToAsm/PPC/CodeGen.hs +++ b/compiler/GHC/CmmToAsm/PPC/CodeGen.hs @@ -649,6 +649,21 @@ getRegister' _ _ (CmmMachOp mop [x, y]) -- dyadic PrimOps code <- remainderCode rep sgn tmp x y return (Any fmt code) +getRegister' _ _ (CmmMachOp mop [x, y, z]) -- ternary PrimOps + = case mop of + + -- x86 fmadd x * y + z <> PPC fmadd rt = ra * rc + rb + -- x86 fmsub x * y - z <> PPC fmsub rt = ra * rc - rb + -- x86 fnmadd - x * y + z ~~ PPC fnmsub rt = -(ra * rc - rb) + -- x86 fnmsub - x * y - z ~~ PPC fnmadd rt = -(ra * rc + rb) + + MO_FMA variant w -> + case variant of + FMAdd -> fma_code w (FMADD FMAdd) x y z + FMSub -> fma_code w (FMADD FMSub) x y z + FNMAdd -> fma_code w (FMADD FNMAdd) x y z + FNMSub -> fma_code w (FMADD FNMSub) x y z + _ -> panic "PPC.CodeGen.getRegister: no match" getRegister' _ _ (CmmLit (CmmInt i rep)) | Just imm <- makeImmediate rep True i @@ -2358,10 +2373,28 @@ trivialUCode rep instr x = do let code' dst = code `snocOL` instr dst src return (Any rep code') +-- | Generate code for a 4-register FMA instruction, +-- e.g. @fmadd rt ra rc rb := rt <- ra * rc + rb@. 
+fma_code :: Width + -> (Format -> Reg -> Reg -> Reg -> Reg -> Instr) + -> CmmExpr + -> CmmExpr + -> CmmExpr + -> NatM Register +fma_code w instr ra rc rb = do + let rep = floatFormat w + (src1, code1) <- getSomeReg ra + (src2, code2) <- getSomeReg rc + (src3, code3) <- getSomeReg rb + let instrCode rt = + code1 `appOL` + code2 `appOL` + code3 `snocOL` instr rep rt src1 src2 src3 + return $ Any rep instrCode + -- There is no "remainder" instruction on the PPC, so we have to do -- it the hard way. -- The "sgn" parameter is the signedness for the division instruction - remainderCode :: Width -> Bool -> Reg -> CmmExpr -> CmmExpr -> NatM (Reg -> InstrBlock) remainderCode rep sgn reg_q arg_x arg_y = do diff --git a/compiler/GHC/CmmToAsm/PPC/Instr.hs b/compiler/GHC/CmmToAsm/PPC/Instr.hs index 639ae979f8..3fedcc1fc4 100644 --- a/compiler/GHC/CmmToAsm/PPC/Instr.hs +++ b/compiler/GHC/CmmToAsm/PPC/Instr.hs @@ -280,6 +280,14 @@ data Instr | FABS Reg Reg -- abs is the same for single and double | FNEG Reg Reg -- negate is the same for single and double prec. + -- | Fused multiply-add instructions. 
+ -- + -- - FMADD: @rt = ra * rc + rb@ + -- - FMSUB: @rt = ra * rc - rb@ + -- - FNMADD: @rt = -(ra * rc + rb)@ + -- - FNMSUB: @rt = -(ra * rc - rb)@ + | FMADD FMASign Format Reg Reg Reg Reg + | FCMP Reg Reg | FCTIWZ Reg Reg -- convert to integer word @@ -380,6 +388,7 @@ regUsageOfInstr platform instr MFCR reg -> usage ([], [reg]) MFLR reg -> usage ([], [reg]) FETCHPC reg -> usage ([], [reg]) + FMADD _ _ rt ra rc rb -> usage ([ra, rc, rb], [rt]) _ -> noUsage where usage (src, dst) = RU (filter (interesting platform) src) @@ -467,6 +476,8 @@ patchRegsOfInstr instr env FDIV fmt r1 r2 r3 -> FDIV fmt (env r1) (env r2) (env r3) FABS r1 r2 -> FABS (env r1) (env r2) FNEG r1 r2 -> FNEG (env r1) (env r2) + FMADD sgn fmt r1 r2 r3 r4 + -> FMADD sgn fmt (env r1) (env r2) (env r3) (env r4) FCMP r1 r2 -> FCMP (env r1) (env r2) FCTIWZ r1 r2 -> FCTIWZ (env r1) (env r2) FCTIDZ r1 r2 -> FCTIDZ (env r1) (env r2) diff --git a/compiler/GHC/CmmToAsm/PPC/Ppr.hs b/compiler/GHC/CmmToAsm/PPC/Ppr.hs index ba364df1b0..f1d6733327 100644 --- a/compiler/GHC/CmmToAsm/PPC/Ppr.hs +++ b/compiler/GHC/CmmToAsm/PPC/Ppr.hs @@ -934,6 +934,9 @@ pprInstr platform instr = case instr of FNEG reg1 reg2 -> pprUnary (text "fneg") reg1 reg2 + FMADD signs fmt dst ra rc rb + -> pprTernaryF (pprFMASign signs) fmt dst ra rc rb + FCMP reg1 reg2 -> line $ hcat [ char '\t', @@ -1083,6 +1086,21 @@ pprBinaryF op fmt reg1 reg2 reg3 = line $ hcat [ pprReg reg3 ] +pprTernaryF :: IsDoc doc => Line doc -> Format -> Reg -> Reg -> Reg -> Reg -> doc +pprTernaryF op fmt rt ra rc rb = line $ hcat [ + char '\t', + op, + pprFFormat fmt, + char '\t', + pprReg rt, + text ", ", + pprReg ra, + text ", ", + pprReg rc, + text ", ", + pprReg rb + ] + pprRI :: IsLine doc => Platform -> RI -> doc pprRI _ (RIReg r) = pprReg r pprRI platform (RIImm r) = pprImm platform r diff --git a/compiler/GHC/CmmToAsm/Wasm/FromCmm.hs b/compiler/GHC/CmmToAsm/Wasm/FromCmm.hs index 7ca323d72d..9a4c3f34c2 100644 --- a/compiler/GHC/CmmToAsm/Wasm/FromCmm.hs +++ 
b/compiler/GHC/CmmToAsm/Wasm/FromCmm.hs @@ -816,7 +816,9 @@ lower_CmmMachOp lbl (MO_SS_Conv w0 w1) xs = lower_MO_SS_Conv lbl w0 w1 xs lower_CmmMachOp lbl (MO_UU_Conv w0 w1) xs = lower_MO_UU_Conv lbl w0 w1 xs lower_CmmMachOp lbl (MO_XX_Conv w0 w1) xs = lower_MO_UU_Conv lbl w0 w1 xs lower_CmmMachOp lbl (MO_FF_Conv w0 w1) xs = lower_MO_FF_Conv lbl w0 w1 xs -lower_CmmMachOp _ _ _ = panic "lower_CmmMachOp: unreachable" +lower_CmmMachOp _ mop _ = + pprPanic "lower_CmmMachOp: unreachable" $ + vcat [ text "offending MachOp:" <+> pprMachOp mop ] -- | Lower a 'CmmLit'. Note that we don't emit 'f32.const' or -- 'f64.const' for the time being, and instead emit their relative bit diff --git a/compiler/GHC/CmmToAsm/X86/CodeGen.hs b/compiler/GHC/CmmToAsm/X86/CodeGen.hs index d6ef821c9f..859b27e248 100644 --- a/compiler/GHC/CmmToAsm/X86/CodeGen.hs +++ b/compiler/GHC/CmmToAsm/X86/CodeGen.hs @@ -901,14 +901,10 @@ getRegister' _ is32Bit (CmmMachOp mop [x, y]) = -- dyadic MachOps MO_U_Lt _ -> condIntReg LU x y MO_U_Le _ -> condIntReg LEU x y - MO_F_Add w -> trivialFCode_sse2 w ADD x y - - MO_F_Sub w -> trivialFCode_sse2 w SUB x y - - MO_F_Quot w -> trivialFCode_sse2 w FDIV x y - - MO_F_Mul w -> trivialFCode_sse2 w MUL x y - + MO_F_Add w -> trivialFCode_sse2 w ADD x y + MO_F_Sub w -> trivialFCode_sse2 w SUB x y + MO_F_Quot w -> trivialFCode_sse2 w FDIV x y + MO_F_Mul w -> trivialFCode_sse2 w MUL x y MO_Add rep -> add_code rep x y MO_Sub rep -> sub_code rep x y @@ -1113,6 +1109,13 @@ getRegister' _ is32Bit (CmmMachOp mop [x, y]) = -- dyadic MachOps return (Fixed format result code) +getRegister' _plat _is32Bit (CmmMachOp mop [x, y, z]) = -- ternary MachOps + case mop of + -- Floating point fused multiply-add operations @ ± x*y ± z@ + MO_FMA var w -> genFMA3Code w var x y z + + _other -> pprPanic "getRegister(x86) - ternary CmmMachOp (1)" + (pprMachOp mop) getRegister' _ _ (CmmLoad mem pk _) | isFloatType pk @@ -3151,12 +3154,12 @@ genTrivialCode rep instr a b = do a_code <- getAnyReg a 
tmp <- getNewRegNat rep let - -- We want the value of b to stay alive across the computation of a. - -- But, we want to calculate a straight into the destination register, + -- We want the value of 'b' to stay alive across the computation of 'a'. + -- But, we want to calculate 'a' straight into the destination register, -- because the instruction only has two operands (dst := dst `op` src). - -- The troublesome case is when the result of b is in the same register - -- as the destination reg. In this case, we have to save b in a - -- new temporary across the computation of a. + -- The troublesome case is when the result of 'b' is in the same register + -- as the destination 'reg'. In this case, we have to save 'b' in a + -- new temporary across the computation of 'a'. code dst | dst `regClashesWithOp` b_op = b_code `appOL` @@ -3174,6 +3177,69 @@ reg `regClashesWithOp` OpReg reg2 = reg == reg2 reg `regClashesWithOp` OpAddr amode = any (==reg) (addrModeRegs amode) _ `regClashesWithOp` _ = False +-- | Generate code for a fused multiply-add operation, of the form @± x * y ± z@, +-- with 3 operands (FMA3 instruction set). +genFMA3Code :: Width + -> FMASign + -> CmmExpr -> CmmExpr -> CmmExpr -> NatM Register +genFMA3Code w signs x y z = do + + -- For the FMA instruction, we want to compute x * y + z + -- + -- There are three possible instructions we could emit: + -- + -- - fmadd213 z y x, result in x, z can be a memory address + -- - fmadd132 x z y, result in y, x can be a memory address + -- - fmadd231 y x z, result in z, y can be a memory address + -- + -- This suggests two possible optimisations: + -- + -- - OPTIMISATION 1 + -- If one argument is an address, use the instruction that allows + -- a memory address in that position. + -- + -- - OPTIMISATION 2 + -- If one argument is in a fixed register, use the instruction that puts + -- the result in that same register. 
+ -- + -- Currently we follow neither of these optimisations, + -- opting to always use fmadd213 for simplicity. As 'x' is computed straight into 'dst', any of 'y'/'z' that already lives in 'dst' is first copied to its own temporary ('y_tmp'/'z_tmp'). + let rep = floatFormat w + (y_reg, y_code) <- getNonClobberedReg y + (z_reg, z_code) <- getNonClobberedReg z + x_code <- getAnyReg x + y_tmp <- getNewRegNat rep + z_tmp <- getNewRegNat rep + let + fma213 = FMA3 rep signs FMA213 + code dst + | dst == y_reg + , dst == z_reg + = y_code `appOL` + unitOL (MOV rep (OpReg y_reg) (OpReg y_tmp)) `appOL` + z_code `appOL` + unitOL (MOV rep (OpReg z_reg) (OpReg z_tmp)) `appOL` + x_code dst `snocOL` + fma213 (OpReg z_tmp) y_tmp dst + | dst == y_reg + = y_code `appOL` + unitOL (MOV rep (OpReg y_reg) (OpReg y_tmp)) `appOL` + z_code `appOL` + x_code dst `snocOL` + fma213 (OpReg z_reg) y_tmp dst + | dst == z_reg + = y_code `appOL` + z_code `appOL` + unitOL (MOV rep (OpReg z_reg) (OpReg z_tmp)) `appOL` + x_code dst `snocOL` + fma213 (OpReg z_tmp) y_reg dst + | otherwise + = y_code `appOL` + z_code `appOL` + x_code dst `snocOL` + fma213 (OpReg z_reg) y_reg dst + return (Any rep code) + ----------- trivialUCode :: Format -> (Operand -> Instr) diff --git a/compiler/GHC/CmmToAsm/X86/Instr.hs b/compiler/GHC/CmmToAsm/X86/Instr.hs index ccb3ce09ba..b4e93a1c5d 100644 --- a/compiler/GHC/CmmToAsm/X86/Instr.hs +++ b/compiler/GHC/CmmToAsm/X86/Instr.hs @@ -12,6 +12,7 @@ module GHC.CmmToAsm.X86.Instr ( Instr(..) , Operand(..) , PrefetchVariant(..) + , FMAPermutation(..) , JumpDest(..) , getJumpDestBlockId , canShortcut @@ -272,6 +273,10 @@ data Instr | CVTSI2SS Format Operand Reg -- I32/I64 to F32 | CVTSI2SD Format Operand Reg -- I32/I64 to F64 + -- | FMA3 fused multiply-add operations. + | FMA3 Format FMASign FMAPermutation Operand Reg Reg + -- src1 (r/m), src2 (r), dst (r) + -- use ADD, SUB, and SQRT for arithmetic. In both cases, operands -- are Operand Reg. 
@@ -351,7 +356,7 @@ data Operand | OpImm Imm -- immediate value | OpAddr AddrMode -- memory reference - +data FMAPermutation = FMA132 | FMA213 | FMA231 -- | Returns which registers are read and written as a (read, written) -- pair. @@ -438,6 +443,8 @@ regUsageOfInstr platform instr PDEP _ src mask dst -> mkRU (use_R src $ use_R mask []) [dst] PEXT _ src mask dst -> mkRU (use_R src $ use_R mask []) [dst] + FMA3 _ _ _ src1 src2 dst -> usageFMA src1 src2 dst + -- note: might be a better way to do this PREFETCH _ _ src -> mkRU (use_R src []) [] LOCK i -> regUsageOfInstr platform i @@ -482,6 +489,15 @@ regUsageOfInstr platform instr usageRMM (OpReg src) (OpAddr ea) (OpReg reg) = mkRU (use_EA ea [src, reg]) [reg] usageRMM _ _ _ = panic "X86.RegInfo.usageRMM: no match" + -- 3 operand form of FMA instructions. + usageFMA :: Operand -> Reg -> Reg -> RegUsage + usageFMA (OpReg src1) src2 dst + = mkRU [src1, src2, dst] [dst] + usageFMA (OpAddr ea1) src2 dst + = mkRU (use_EA ea1 [src2, dst]) [dst] + usageFMA _ _ _ + = panic "X86.RegInfo.usageFMA: no match" + -- 1 operand form; operand Modified usageM :: Operand -> RegUsage usageM (OpReg reg) = mkRU [reg] [reg] @@ -561,6 +577,8 @@ patchRegsOfInstr instr env JMP op regs -> JMP (patchOp op) regs JMP_TBL op ids s lbl -> JMP_TBL (patchOp op) ids s lbl + FMA3 fmt perm var x1 x2 x3 -> patch3 (FMA3 fmt perm var) x1 x2 x3 + -- literally only support storing the top x87 stack value st(0) X87Store fmt dst -> X87Store fmt (lookupAddr dst) @@ -612,6 +630,8 @@ patchRegsOfInstr instr env patch1 insn op = insn $! patchOp op patch2 :: (Operand -> Operand -> a) -> Operand -> Operand -> a patch2 insn src dst = (insn $! patchOp src) $! patchOp dst + patch3 :: (Operand -> Reg -> Reg -> a) -> Operand -> Reg -> Reg -> a + patch3 insn src1 src2 dst = ((insn $! patchOp src1) $! env src2) $! env dst patchOp (OpReg reg) = OpReg $! 
env reg patchOp (OpImm imm) = OpImm imm diff --git a/compiler/GHC/CmmToAsm/X86/Ppr.hs b/compiler/GHC/CmmToAsm/X86/Ppr.hs index 4a8f55fdf0..0d649f2efb 100644 --- a/compiler/GHC/CmmToAsm/X86/Ppr.hs +++ b/compiler/GHC/CmmToAsm/X86/Ppr.hs @@ -838,6 +838,14 @@ pprInstr platform i = case i of FDIV format op1 op2 -> pprFormatOpOp (text "div") format op1 op2 + FMA3 format var perm op1 op2 op3 + -> let mnemo = case var of + FMAdd -> text "vfmadd" + FMSub -> text "vfmsub" + FNMAdd -> text "vfnmadd" + FNMSub -> text "vfnmsub" + in pprFormatOpRegReg (mnemo <> pprFMAPermutation perm) format op1 op2 op3 + SQRT format op1 op2 -> pprFormatOpReg (text "sqrt") format op1 op2 @@ -968,6 +976,21 @@ pprInstr platform i = case i of pprOperand platform format op2 ] + pprFormatOpRegReg :: Line doc -> Format -> Operand -> Reg -> Reg -> doc + pprFormatOpRegReg name format op1 op2 op3 + = line $ hcat [ + pprMnemonic name format, + pprOperand platform format op1, + comma, + pprReg platform format op2, + comma, + pprReg platform format op3 + ] + + pprFMAPermutation :: FMAPermutation -> Line doc + pprFMAPermutation FMA132 = text "132" + pprFMAPermutation FMA213 = text "213" + pprFMAPermutation FMA231 = text "231" pprOpOp :: Line doc -> Format -> Operand -> Operand -> doc pprOpOp name format op1 op2 diff --git a/compiler/GHC/CmmToC.hs b/compiler/GHC/CmmToC.hs index d4ac5cd4e4..fb8774c77c 100644 --- a/compiler/GHC/CmmToC.hs +++ b/compiler/GHC/CmmToC.hs @@ -529,6 +529,11 @@ machOpNeedsCast platform mop args pprMachOpApp' :: Platform -> MachOp -> [CmmExpr] -> SDoc pprMachOpApp' platform mop args = case args of + + -- ternary + args@[_,_,_] -> + pprMachOp_for_C platform mop <> parens (pprWithCommas pprArg args) + -- dyadic [x,y] -> pprArg x <+> pprMachOp_for_C platform mop <+> pprArg y @@ -711,13 +716,28 @@ pprMachOp_for_C platform mop = case mop of MO_U_Quot _ -> char '/' MO_U_Rem _ -> char '%' - -- & Floating-point operations + -- Floating-point operations MO_F_Add _ -> char '+' MO_F_Sub _ -> char 
'-' MO_F_Neg _ -> char '-' MO_F_Mul _ -> char '*' MO_F_Quot _ -> char '/' + -- Floating-point fused multiply-add operations + MO_FMA FMAdd w -> + case w of + W32 -> text "fmaf" + W64 -> text "fma" + _ -> + pprTrace "offending mop:" + (text "FMAdd") + (panic $ "PprC.pprMachOp_for_C: FMAdd unsupported" + ++ "at width " ++ show w) + MO_FMA var _width -> + pprTrace "offending mop:" + (text $ "FMA " ++ show var) + (panic $ "PprC.pprMachOp_for_C: should have been handled earlier!") + -- Signed comparisons MO_S_Ge _ -> text ">=" MO_S_Le _ -> text "<=" diff --git a/compiler/GHC/CmmToLlvm/CodeGen.hs b/compiler/GHC/CmmToLlvm/CodeGen.hs index 1d658b359e..ba72aa1417 100644 --- a/compiler/GHC/CmmToLlvm/CodeGen.hs +++ b/compiler/GHC/CmmToLlvm/CodeGen.hs @@ -1469,6 +1469,9 @@ genMachOp _ op [x] = case op of MO_F_Sub _ -> panicOp MO_F_Mul _ -> panicOp MO_F_Quot _ -> panicOp + + MO_FMA _ _ -> panicOp + MO_F_Eq _ -> panicOp MO_F_Ne _ -> panicOp MO_F_Ge _ -> panicOp @@ -1652,6 +1655,8 @@ genMachOp_slow opt op [x, y] = case op of MO_F_Mul _ -> genBinMach LM_MO_FMul MO_F_Quot _ -> genBinMach LM_MO_FDiv + MO_FMA _ _ -> panicOp + MO_And _ -> genBinMach LM_MO_And MO_Or _ -> genBinMach LM_MO_Or MO_Xor _ -> genBinMach LM_MO_Xor @@ -1785,8 +1790,27 @@ genMachOp_slow opt op [x, y] = case op of panicOp = panic $ "LLVM.CodeGen.genMachOp_slow: unary op encountered" ++ "with two arguments! (" ++ show op ++ ")" --- More than two expression, invalid! -genMachOp_slow _ _ _ = panic "genMachOp: More than 2 expressions in MachOp!" 
+genMachOp_slow _opt op [x, y, z] = case op of -- ternary MachOps: only fused multiply-add is handled + MO_FMA var _ -> triLlvmOp getVarType (FMAOp var) + _ -> panicOp + where + triLlvmOp ty op = do + platform <- getPlatform + runExprData $ do + vx <- exprToVarW x + vy <- exprToVarW y + vz <- exprToVarW z + + if | getVarType vx == getVarType vy + , getVarType vx == getVarType vz + -> doExprW (ty vx) $ op vx vy vz + | otherwise + -> pprPanic "triLlvmOp types" (pdoc platform x $$ pdoc platform y $$ pdoc platform z) + panicOp = panic $ "LLVM.CodeGen.genMachOp_slow: non-ternary op encountered " + ++ "with three arguments! (" ++ show op ++ ")" + +-- More than three expressions, invalid! +genMachOp_slow _ _ _ = panic "genMachOp_slow: More than 3 expressions in MachOp!" -- | Handle CmmLoad expression. diff --git a/compiler/GHC/Core/Opt/ConstantFold.hs b/compiler/GHC/Core/Opt/ConstantFold.hs index 42ced5a86a..fbccce15bb 100644 --- a/compiler/GHC/Core/Opt/ConstantFold.hs +++ b/compiler/GHC/Core/Opt/ConstantFold.hs @@ -65,6 +65,9 @@ import GHC.Builtin.Types import GHC.Builtin.Types.Prim import GHC.Builtin.Names +import GHC.Cmm.MachOp ( FMASign(..) ) +import GHC.Cmm.Type ( Width(..) 
) + import GHC.Data.FastString import GHC.Data.Maybe ( orElse ) @@ -677,6 +680,11 @@ primOpRules nm = \case FloatMulOp -> mkPrimOpRule nm 2 [ binaryLit (floatOp2 (*)) , identity onef , strengthReduction twof FloatAddOp ] + FloatFMAdd -> mkPrimOpRule nm 3 (fmaRules FMAdd W32) + FloatFMSub -> mkPrimOpRule nm 3 (fmaRules FMSub W32) + FloatFNMAdd -> mkPrimOpRule nm 3 (fmaRules FNMAdd W32) + FloatFNMSub -> mkPrimOpRule nm 3 (fmaRules FNMSub W32) + -- zeroElem zerof doesn't hold because of NaN FloatDivOp -> mkPrimOpRule nm 2 [ guardFloatDiv >> binaryLit (floatOp2 (/)) , rightIdentity onef ] @@ -692,6 +700,10 @@ primOpRules nm = \case DoubleMulOp -> mkPrimOpRule nm 2 [ binaryLit (doubleOp2 (*)) , identity oned , strengthReduction twod DoubleAddOp ] + DoubleFMAdd -> mkPrimOpRule nm 3 (fmaRules FMAdd W64) + DoubleFMSub -> mkPrimOpRule nm 3 (fmaRules FMSub W64) + DoubleFNMAdd -> mkPrimOpRule nm 3 (fmaRules FNMAdd W64) + DoubleFNMSub -> mkPrimOpRule nm 3 (fmaRules FNMSub W64) -- zeroElem zerod doesn't hold because of NaN DoubleDivOp -> mkPrimOpRule nm 2 [ guardDoubleDiv >> binaryLit (doubleOp2 (/)) , rightIdentity oned ] @@ -1138,6 +1150,150 @@ doubleDecodeOp env (LitDouble ((decodeFloat . fromRational @Double) -> (m, e))) doubleDecodeOp _ _ = Nothing +-------------------------- + +-- | Constant folding rules for fused multiply-add operations. +fmaRules :: FMASign -> Width -> [RuleM CoreExpr] +fmaRules signs width = + [ fmaLit signs width + , fmaZero_z signs width + , fmaOne signs width ] + +-- | Compute @a * b + c@ when @a@, @b@, @c@ are all literals. 
+fmaLit :: FMASign -> Width -> RuleM CoreExpr +fmaLit signs width = do + env <- getRuleOpts + [Lit l1, Lit l2, Lit l3] <- getArgs + liftMaybe $ + op env + (convFloating env l1) + (convFloating env l2) + (convFloating env l3) + + where + op env l1 l2 l3 = + case width of + W32 + | LitFloat x <- l1 + , LitFloat y <- l2 + , LitFloat z <- l3 + -> Just $ mkFloatVal env $ + case signs of + FMAdd -> x * y + z + FMSub -> x * y - z + FNMAdd -> negate ( x * y ) + z + FNMSub -> negate ( x * y ) - z + W64 + | LitDouble x <- l1 + , LitDouble y <- l2 + , LitDouble z <- l3 + -> Just $ mkDoubleVal env $ + case signs of + FMAdd -> x * y + z + FMSub -> x * y - z + FNMAdd -> negate ( x * y ) + z + FNMSub -> negate ( x * y ) - z + _ -> Nothing + +-- | @± x * y ± 0 ==> ± x * y@: drop a literal zero addend, keeping the negation for the FNM variants. +fmaZero_z :: FMASign -> Width -> RuleM CoreExpr +fmaZero_z signs width = do + [x, y, Lit z] <- getArgs + let + -- TODO: we should additionally check the sign of z. + -- FMAdd, FNMAdd: should be -0.0. + -- FMSub, FNMSub: should be +0.0. + ok = + case width of + W32 + | LitFloat 0 <- z + -> True + W64 + | LitDouble 0 <- z + -> True + _ -> False + neg = case width of + W32 -> FloatNegOp + W64 -> DoubleNegOp + _ -> panic "fmaZero_z: not Float# or Double#" + mul = case width of + W32 -> FloatMulOp + W64 -> DoubleMulOp + _ -> panic "fmaZero_z: not Float# or Double#" + if ok + then return $ case signs of + FMAdd -> Var (primOpId mul) `App` x `App` y + FMSub -> Var (primOpId mul) `App` x `App` y + FNMAdd -> Var (primOpId neg) `App` (Var (primOpId mul) `App` x `App` y) + FNMSub -> Var (primOpId neg) `App` (Var (primOpId mul) `App` x `App` y) + else mzero + +-- | @±1 * y + z ==> z ± y@ and @x * ±1 + z ==> z ± x@. 
+fmaOne :: FMASign -> Width -> RuleM CoreExpr +fmaOne signs width = do + [x, y, z] <- getArgs + let + posNegOne_maybe :: Rational -> Maybe Bool + posNegOne_maybe i + | i == 1 + = Just False + | i == -1 + = Just True + | otherwise + = Nothing + ok = + case width of + W32 + | Lit (LitFloat i) <- x + , Just sgn <- posNegOne_maybe i + -> Just (sgn, y) + | Lit (LitFloat i) <- y + , Just sgn <- posNegOne_maybe i + -> Just (sgn, x) + W64 + | Lit (LitDouble i) <- x + , Just sgn <- posNegOne_maybe i + -> Just (sgn, y) + | Lit (LitDouble i) <- y + , Just sgn <- posNegOne_maybe i + -> Just (sgn, x) + _ -> Nothing + neg = case width of + W32 -> FloatNegOp + W64 -> DoubleNegOp + _ -> panic "fmaOne: not Float# or Double#" + add = case width of + W32 -> FloatAddOp + W64 -> DoubleAddOp + _ -> panic "fmaOne: not Float# or Double#" + sub = case width of + W32 -> FloatSubOp + W64 -> DoubleSubOp + _ -> panic "fmaOne: not Float# or Double#" + case ok of + Nothing -> mzero + Just (sgn, t) -> return $ + if -- t + z + | ( signs == FMAdd && sgn == False ) + || ( signs == FNMAdd && sgn == True ) + -> Var (primOpId add) `App` t `App` z + -- - t + z + | signs == FMAdd + || signs == FNMAdd + -> Var (primOpId sub) `App` z `App` t + -- t - z + | ( signs == FMSub && sgn == False ) + || ( signs == FNMSub && sgn == True ) + -> Var (primOpId sub) `App` t `App` z + -- - t - z + | signs == FMSub + || signs == FNMSub + -> Var (primOpId neg) `App` (Var (primOpId add) `App` t `App` z) + | otherwise + -> pprPanic "fmaOne: non-exhaustive pattern match" $ + vcat [ text "signs:" <+> text (show signs) + , text "sign:" <+> ppr sgn ] + -------------------------- {- Note [The litEq rule: converting equality to case] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/compiler/GHC/Driver/Config/StgToCmm.hs b/compiler/GHC/Driver/Config/StgToCmm.hs index 283ece1d50..f738974c46 100644 --- a/compiler/GHC/Driver/Config/StgToCmm.hs +++ b/compiler/GHC/Driver/Config/StgToCmm.hs @@ -1,3 +1,6 @@ +{-# 
LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiWayIf #-} + module GHC.Driver.Config.StgToCmm ( initStgToCmmConfig ) where @@ -6,6 +9,7 @@ import GHC.Prelude.Basic import GHC.StgToCmm.Config +import GHC.Cmm.MachOp ( FMASign(..)) import GHC.Driver.Backend import GHC.Driver.Session import GHC.Platform @@ -50,6 +54,24 @@ initStgToCmmConfig dflags mod = StgToCmmConfig , stgToCmmAllowQuotRemInstr = ncg && (x86ish || ppc) , stgToCmmAllowQuotRem2 = (ncg && (x86ish || ppc)) || llvm , stgToCmmAllowExtendedAddSubInstrs = (ncg && (x86ish || ppc)) || llvm + , stgToCmmAllowFMAInstr = + if + | not (isFmaEnabled dflags) + || not (ncg || llvm) + -- If we're not using the native code generator or LLVM, + -- fall back to the generic implementation. + || platformArch platform == ArchWasm32 + -- WASM doesn't support native FMA instructions (at the time of writing). + -> const False + + -- FNMSub and FNMAdd have different semantics on PowerPC, + -- so we avoid using them. + | ppc + -> \ case { FMAdd -> True; FMSub -> True; _ -> False } + + | otherwise + -> const True + , stgToCmmAllowIntMul2Instr = (ncg && x86ish) || llvm -- SIMD flags , stgToCmmVecInstrsErr = vec_err diff --git a/compiler/GHC/Driver/Pipeline/Execute.hs b/compiler/GHC/Driver/Pipeline/Execute.hs index 7694975e80..84113df8eb 100644 --- a/compiler/GHC/Driver/Pipeline/Execute.hs +++ b/compiler/GHC/Driver/Pipeline/Execute.hs @@ -1028,6 +1028,7 @@ llvmOptions llvm_config dflags = ++ ["+avx512cd"| isAvx512cdEnabled dflags ] ++ ["+avx512er"| isAvx512erEnabled dflags ] ++ ["+avx512pf"| isAvx512pfEnabled dflags ] + ++ ["+fma" | isFmaEnabled dflags ] ++ ["+bmi" | isBmiEnabled dflags ] ++ ["+bmi2" | isBmi2Enabled dflags ] diff --git a/compiler/GHC/Driver/Session.hs b/compiler/GHC/Driver/Session.hs index 6a7ba74477..52361abb09 100644 --- a/compiler/GHC/Driver/Session.hs +++ b/compiler/GHC/Driver/Session.hs @@ -208,6 +208,7 @@ module GHC.Driver.Session ( isAvx512erEnabled, isAvx512fEnabled, isAvx512pfEnabled, + isFmaEnabled, -- * 
Linker/compiler information LinkerInfo(..), @@ -718,6 +719,7 @@ data DynFlags = DynFlags { avx512er :: Bool, -- Enable AVX-512 Exponential and Reciprocal Instructions. avx512f :: Bool, -- Enable AVX-512 instructions. avx512pf :: Bool, -- Enable AVX-512 PreFetch Instructions. + fma :: Bool, -- ^ Enable FMA instructions. -- | Run-time linker information (what options we need, etc.) rtldInfo :: IORef (Maybe LinkerInfo), @@ -1303,6 +1305,7 @@ defaultDynFlags mySettings = avx512er = False, avx512f = False, avx512pf = False, + fma = False, rtldInfo = panic "defaultDynFlags: no rtldInfo", rtccInfo = panic "defaultDynFlags: no rtccInfo", rtasmInfo = panic "defaultDynFlags: no rtasmInfo", @@ -2730,6 +2733,7 @@ dynamic_flags_deps = [ , make_ord_flag defGhcFlag "mavx512f" (noArg (\d -> d { avx512f = True })) , make_ord_flag defGhcFlag "mavx512pf" (noArg (\d -> d { avx512pf = True })) + , make_ord_flag defGhcFlag "mfma" (noArg (\d -> d { fma = True })) ------ Plugin flags ------------------------------------------------ , make_ord_flag defGhcFlag "fplugin-opt" (hasArg addPluginModuleNameOption) @@ -5008,7 +5012,7 @@ setUnsafeGlobalDynFlags dflags = do -- ----------------------------------------------------------------------------- --- SSE and AVX +-- SSE, AVX, FMA isSse4_2Enabled :: DynFlags -> Bool isSse4_2Enabled dflags = sseVersion dflags >= Just SSE42 @@ -5031,6 +5035,9 @@ isAvx512fEnabled dflags = avx512f dflags isAvx512pfEnabled :: DynFlags -> Bool isAvx512pfEnabled dflags = avx512pf dflags +isFmaEnabled :: DynFlags -> Bool +isFmaEnabled dflags = fma dflags + -- ----------------------------------------------------------------------------- -- BMI2 @@ -5046,6 +5053,8 @@ isBmi2Enabled dflags = case platformArch (targetPlatform dflags) of ArchX86 -> bmiVersion dflags >= Just BMI2 _ -> False +-- ----------------------------------------------------------------------------- + -- | Indicate if cost-centre profiling is enabled sccProfilingEnabled :: DynFlags -> Bool 
sccProfilingEnabled dflags = profileIsProfiling (targetProfile dflags) diff --git a/compiler/GHC/Llvm/Ppr.hs b/compiler/GHC/Llvm/Ppr.hs index 36bfdf3405..f0f09a46ef 100644 --- a/compiler/GHC/Llvm/Ppr.hs +++ b/compiler/GHC/Llvm/Ppr.hs @@ -40,6 +40,7 @@ import GHC.Llvm.Types import Data.List ( intersperse ) import GHC.Utils.Outputable +import GHC.Cmm.MachOp ( FMASign(..), pprFMASign ) import GHC.CmmToLlvm.Config import GHC.Utils.Panic import GHC.Types.Unique @@ -288,6 +289,7 @@ ppLlvmExpression opts expr AtomicRMW aop tgt src ordering -> ppAtomicRMW opts aop tgt src ordering CmpXChg addr old new s_ord f_ord -> ppCmpXChg opts addr old new s_ord f_ord Phi tp predecessors -> ppPhi opts tp predecessors + FMAOp op x y z -> pprFMAOp opts op x y z Asm asm c ty v se sk -> ppAsm opts asm c ty v se sk MExpr meta expr -> ppMetaAnnotExpr opts meta expr {-# SPECIALIZE ppLlvmExpression :: LlvmCgConfig -> LlvmExpression -> SDoc #-} @@ -374,6 +376,12 @@ ppCmpOp opts op left right = {-# SPECIALIZE ppCmpOp :: LlvmCgConfig -> LlvmCmpOp -> LlvmVar -> LlvmVar -> SDoc #-} {-# SPECIALIZE ppCmpOp :: LlvmCgConfig -> LlvmCmpOp -> LlvmVar -> LlvmVar -> HLine #-} -- see Note [SPECIALIZE to HDoc] in GHC.Utils.Outputable +pprFMAOp :: IsLine doc => LlvmCgConfig -> FMASign -> LlvmVar -> LlvmVar -> LlvmVar -> doc +pprFMAOp opts signs x y z = + pprFMASign signs <+> ppLlvmType (getVarType x) + <+> ppName opts x <> comma + <+> ppName opts y <> comma + <+> ppName opts z ppAssignment :: IsLine doc => LlvmCgConfig -> LlvmVar -> doc -> doc ppAssignment opts var expr = ppName opts var <+> equals <+> expr diff --git a/compiler/GHC/Llvm/Syntax.hs b/compiler/GHC/Llvm/Syntax.hs index 882cb0660b..390380ad17 100644 --- a/compiler/GHC/Llvm/Syntax.hs +++ b/compiler/GHC/Llvm/Syntax.hs @@ -10,6 +10,7 @@ import GHC.Llvm.MetaData import GHC.Llvm.Types import GHC.Types.Unique +import GHC.Cmm.MachOp ( FMASign(..) 
) -- | Block labels type LlvmBlockId = Unique @@ -338,6 +339,8 @@ data LlvmExpression -} | Asm LMString LMString LlvmType [LlvmVar] Bool Bool + | FMAOp FMASign LlvmVar LlvmVar LlvmVar + {- | A LLVM expression with metadata attached to it. -} diff --git a/compiler/GHC/Llvm/Types.hs b/compiler/GHC/Llvm/Types.hs index f80b261584..ce2a82ce34 100644 --- a/compiler/GHC/Llvm/Types.hs +++ b/compiler/GHC/Llvm/Types.hs @@ -775,7 +775,6 @@ ppLlvmCastOp LM_Bitcast = text "bitcast" {-# SPECIALIZE ppLlvmCastOp :: LlvmCastOp -> SDoc #-} {-# SPECIALIZE ppLlvmCastOp :: LlvmCastOp -> HLine #-} -- see Note [SPECIALIZE to HDoc] in GHC.Utils.Outputable - -- ----------------------------------------------------------------------------- -- * Floating point conversion -- diff --git a/compiler/GHC/StgToCmm/Config.hs b/compiler/GHC/StgToCmm/Config.hs index f2bd349ae7..1465cde2e0 100644 --- a/compiler/GHC/StgToCmm/Config.hs +++ b/compiler/GHC/StgToCmm/Config.hs @@ -11,6 +11,7 @@ import GHC.Unit.Module import GHC.Utils.Outputable import GHC.Utils.TmpFs +import GHC.Cmm.MachOp ( FMASign(..) ) import GHC.Prelude @@ -66,6 +67,7 @@ data StgToCmmConfig = StgToCmmConfig , stgToCmmAllowQuotRem2 :: !Bool -- ^ Allowed to generate QuotRem , stgToCmmAllowExtendedAddSubInstrs :: !Bool -- ^ Allowed to generate AddWordC, SubWordC, Add2, etc. , stgToCmmAllowIntMul2Instr :: !Bool -- ^ Allowed to generate IntMul2 instruction + , stgToCmmAllowFMAInstr :: FMASign -> Bool -- ^ Allowed to generate FMA instruction , stgToCmmTickyAP :: !Bool -- ^ Disable use of precomputed standard thunks. 
------------------------------ SIMD flags ------------------------------------ -- Each of these flags checks vector compatibility with the backend requested diff --git a/compiler/GHC/StgToCmm/Prim.hs b/compiler/GHC/StgToCmm/Prim.hs index f4a1924f19..d1f6bb0191 100644 --- a/compiler/GHC/StgToCmm/Prim.hs +++ b/compiler/GHC/StgToCmm/Prim.hs @@ -1395,6 +1395,11 @@ emitPrimOp cfg primop = DoubleDivOp -> \args -> opTranslate args (MO_F_Quot W64) DoubleNegOp -> \args -> opTranslate args (MO_F_Neg W64) + DoubleFMAdd -> fmaOp FMAdd W64 + DoubleFMSub -> fmaOp FMSub W64 + DoubleFNMAdd -> fmaOp FNMAdd W64 + DoubleFNMSub -> fmaOp FNMSub W64 + -- Float ops FloatEqOp -> \args -> opTranslate args (MO_F_Eq W32) @@ -1410,6 +1415,11 @@ emitPrimOp cfg primop = FloatDivOp -> \args -> opTranslate args (MO_F_Quot W32) FloatNegOp -> \args -> opTranslate args (MO_F_Neg W32) + FloatFMAdd -> fmaOp FMAdd W32 + FloatFMSub -> fmaOp FMSub W32 + FloatFNMAdd -> fmaOp FNMAdd W32 + FloatFNMSub -> fmaOp FNMSub W32 + -- Vector ops (VecAddOp FloatVec n w) -> \args -> opTranslate args (MO_VF_Add n w) @@ -1735,6 +1745,27 @@ emitPrimOp cfg primop = allowExtAdd = stgToCmmAllowExtendedAddSubInstrs cfg allowInt2Mul = stgToCmmAllowIntMul2Instr cfg + allowFMA = stgToCmmAllowFMAInstr cfg + + fmaOp :: FMASign -> Width -> [CmmActual] -> PrimopCmmEmit + fmaOp signs w args@[arg_x, arg_y, arg_z] + | allowFMA signs + = opTranslate args (MO_FMA signs w) + | otherwise + = case signs of + + -- For fused multiply-add x * y + z, we fall back to the C implementation. + FMAdd -> opIntoRegs $ \ [res] -> fmaCCall w res arg_x arg_y arg_z + + -- Other fused multiply-add operations are implemented in terms of fmadd + -- This is sound: it does not lose any precision. 
+        FMSub  -> fmaOp FMAdd w [arg_x, arg_y, neg arg_z]
+        FNMAdd -> fmaOp FMAdd w [neg arg_x, arg_y, arg_z]
+        FNMSub -> fmaOp FMAdd w [neg arg_x, arg_y, neg arg_z]
+    where
+      neg x = CmmMachOp (MO_F_Neg w) [x]
+  fmaOp _ _ _ = panic "fmaOp: wrong number of arguments (expected 3)"
+
 data PrimopCmmEmit
   -- | Out of line fake primop that's actually just a foreign call to other
   -- (presumably) C--.
@@ -2023,6 +2054,19 @@ genericIntMul2Op [res_c, res_h, res_l] both_args@[arg_x, arg_y]
         ]
 genericIntMul2Op _ _ = panic "genericIntMul2Op"
 
+fmaCCall :: Width -> CmmFormal -> CmmActual -> CmmActual -> CmmActual -> FCode ()
+fmaCCall width res arg_x arg_y arg_z = -- emit a foreign call @res = fma(arg_x, arg_y, arg_z)@
+  emitCCall
+    [(res,NoHint)]
+    (CmmLit (CmmLabel fma_lbl))
+    [(arg_x,NoHint), (arg_y,NoHint), (arg_z,NoHint)]
+  where
+    fma_lbl = mkForeignLabel fma_op Nothing ForeignLabelInExternalPackage IsFunction
+    fma_op = case width of
+      W32 -> fsLit "fmaf" -- C library single-precision fused multiply-add
+      W64 -> fsLit "fma" -- C library double-precision fused multiply-add
+      _   -> panic ("fmaCCall: " ++ show width)
+
 ------------------------------------------------------------------------------
 -- Helpers for translating various minor variants of array indexing.
 
diff --git a/compiler/GHC/StgToJS/Prim.hs b/compiler/GHC/StgToJS/Prim.hs
index 36f12e3409..c051318b22 100644
--- a/compiler/GHC/StgToJS/Prim.hs
+++ b/compiler/GHC/StgToJS/Prim.hs
@@ -488,6 +488,11 @@ genPrim prof bound ty op = case op of
   DoubleDecode_2IntOp  -> \[s,h,l,e] [x] -> PrimInline $ appT [s,h,l,e] "h$decodeDouble2Int" [x]
   DoubleDecode_Int64Op -> \[s1,s2,e] [d] -> PrimInline $ appT [e,s1,s2] "h$decodeDoubleInt64" [d]
 
+  DoubleFMAdd  -> unhandledPrimop op
+  DoubleFMSub  -> unhandledPrimop op
+  DoubleFNMAdd -> unhandledPrimop op
+  DoubleFNMSub -> unhandledPrimop op
+
 ------------------------------ Float --------------------------------------------
 
   FloatGtOp         -> \[r] [x,y] -> PrimInline $ r |= if10 (x .>. 
y) @@ -524,6 +529,11 @@ genPrim prof bound ty op = case op of FloatToDoubleOp -> \[r] [x] -> PrimInline $ r |= x FloatDecode_IntOp -> \[s,e] [x] -> PrimInline $ appT [s,e] "h$decodeFloatInt" [x] + FloatFMAdd -> unhandledPrimop op + FloatFMSub -> unhandledPrimop op + FloatFNMAdd -> unhandledPrimop op + FloatFNMSub -> unhandledPrimop op + ------------------------------ Arrays ------------------------------------------- NewArrayOp -> \[r] [l,e] -> PrimInline $ r |= app "h$newArray" [l,e] diff --git a/compiler/GHC/SysTools/Cpp.hs b/compiler/GHC/SysTools/Cpp.hs index 61f70342a6..9410410d01 100644 --- a/compiler/GHC/SysTools/Cpp.hs +++ b/compiler/GHC/SysTools/Cpp.hs @@ -98,6 +98,9 @@ doCpp logger tmpfs dflags unit_env opts input_fn output_fn = do [ "-D__SSE2__" | isSse2Enabled platform ] ++ [ "-D__SSE4_2__" | isSse4_2Enabled dflags ] + let fma_def = + [ "-D__FMA__" | isFmaEnabled dflags ] + let avx_defs = [ "-D__AVX__" | isAvxEnabled dflags ] ++ [ "-D__AVX2__" | isAvx2Enabled dflags ] ++ @@ -140,6 +143,7 @@ doCpp logger tmpfs dflags unit_env opts input_fn output_fn = do ++ map GHC.SysTools.Option th_defs ++ map GHC.SysTools.Option hscpp_opts ++ map GHC.SysTools.Option sse_defs + ++ map GHC.SysTools.Option fma_def ++ map GHC.SysTools.Option avx_defs ++ map GHC.SysTools.Option io_manager_defs ++ mb_macro_include diff --git a/docs/users_guide/9.8.1-notes.rst b/docs/users_guide/9.8.1-notes.rst index e7e9acdf75..84d9105efd 100644 --- a/docs/users_guide/9.8.1-notes.rst +++ b/docs/users_guide/9.8.1-notes.rst @@ -142,6 +142,24 @@ Runtime system - ``sameMutVar#``, ``sameTVar#``, ``sameMVar#`` - ``sameIOPort#``, ``eqStableName#``. +- New primops for fused multiply-add operations. These primops combine a + multiplication and an addition, compiling to a single instruction when + the ``-mfma`` flag is enabled and the architecture supports it. 
+ + The new primops are ``fmaddFloat#, fmsubFloat#, fnmaddFloat#, fnmsubFloat# :: Float# -> Float# -> Float# -> Float#`` + and ``fmaddDouble#, fmsubDouble#, fnmaddDouble#, fnmsubDouble# :: Double# -> Double# -> Double# -> Double#``. + + These implement the following operations, while performing one single + rounding at the end, leading to a more accurate result: + + - ``fmaddFloat# x y z``, ``fmaddDouble# x y z`` compute ``x * y + z``. + - ``fmsubFloat# x y z``, ``fmsubDouble# x y z`` compute ``x * y - z``. + - ``fnmaddFloat# x y z``, ``fnmaddDouble# x y z`` compute ``- x * y + z``. + - ``fnmsubFloat# x y z``, ``fnmsubDouble# x y z`` compute ``- x * y - z``. + + Warning: on unsupported architectures, the software emulation provided by + the fallback to the C standard library is not guaranteed to be IEEE-compliant. + ``ghc`` library ~~~~~~~~~~~~~~~ diff --git a/docs/users_guide/using.rst b/docs/users_guide/using.rst index 787b6a0503..8de7dd3533 100644 --- a/docs/users_guide/using.rst +++ b/docs/users_guide/using.rst @@ -1732,6 +1732,24 @@ Some flags only make sense for particular target platforms. :ref:`native code generator `. The resulting compiled code will only run on processors that support BMI2 (Intel Haswell and newer, AMD Excavator, Zen and newer). +.. ghc-flag:: -mfma + :shortdesc: Use native FMA instructions for fused multiply-add floating-point operations + :type: dynamic + :category: platform-options + + :since: 9.8.1 + + Use native FMA instructions to implement the fused multiply-add floating-point + operations of the form ``x * y + z``. + This allows computing a multiplication and addition in a single instruction, + without an intermediate rounding step. + Supported architectures: X86 with the FMA3 instruction set (this includes + most consumer processors since 2013), PowerPC and AArch64. 
+ + When this flag is disabled, GHC falls back to the C implementation of fused + multiply-add, which might perform non-IEEE-compliant software emulation on + some platforms (depending on the implementation of the C standard library). + Haddock ------- diff --git a/libraries/ghc-prim/changelog.md b/libraries/ghc-prim/changelog.md index 1cf411c029..39e5face03 100644 --- a/libraries/ghc-prim/changelog.md +++ b/libraries/ghc-prim/changelog.md @@ -23,6 +23,24 @@ - `copyAddrToAddrNonOverlapping#` - `setAddrRange#` +- New primops for fused multiply-add operations. These primops combine a + multiplication and an addition, compiling to a single instruction when + the `-mfma` flag is enabled and the architecture supports it. + + The new primops are `fmaddFloat#, fmsubFloat#, fnmaddFloat#, fnmsubFloat# :: Float# -> Float# -> Float# -> Float#` + and `fmaddDouble#, fmsubDouble#, fnmaddDouble#, fnmsubDouble# :: Double# -> Double# -> Double# -> Double#`. + + These implement the following operations, while performing one single + rounding at the end, leading to a more accurate result: + + - `fmaddFloat# x y z`, `fmaddDouble# x y z` compute `x * y + z`. + - `fmsubFloat# x y z`, `fmsubDouble# x y z` compute `x * y - z`. + - `fnmaddFloat# x y z`, `fnmaddDouble# x y z` compute `- x * y + z`. + - `fnmsubFloat# x y z`, `fnmsubDouble# x y z` compute `- x * y - z`. + + Warning: on unsupported architectures, the software emulation provided by + the fallback to the C standard library is not guaranteed to be IEEE-compliant. 
+ ## 0.10.0 - Shipped with GHC 9.6.1 diff --git a/rts/RtsSymbols.c b/rts/RtsSymbols.c index dee6c57f5e..d5ed5bb543 100644 --- a/rts/RtsSymbols.c +++ b/rts/RtsSymbols.c @@ -930,7 +930,6 @@ extern char **environ; RTS_USER_SIGNALS_SYMBOLS \ RTS_INTCHAR_SYMBOLS - // 64-bit support functions in libgcc.a #if defined(__GNUC__) && SIZEOF_VOID_P <= 4 && !defined(_ABIN32) #define RTS_LIBGCC_SYMBOLS \ diff --git a/rts/StgPrimFloat.c b/rts/StgPrimFloat.c index a8c266ae78..d105a0d76b 100644 --- a/rts/StgPrimFloat.c +++ b/rts/StgPrimFloat.c @@ -248,4 +248,3 @@ __decodeFloat_Int (I_ *man, I_ *exp, StgFloat flt) *man = - *man; } } - diff --git a/testsuite/driver/cpu_features.py b/testsuite/driver/cpu_features.py index 43500130db..f6c441d88f 100644 --- a/testsuite/driver/cpu_features.py +++ b/testsuite/driver/cpu_features.py @@ -10,6 +10,7 @@ SUPPORTED_CPU_FEATURES = { # x86: 'sse', 'sse2', 'sse3', 'ssse3', 'sse4_1', 'sse4_2', 'avx1', 'avx2', + 'fma', 'popcnt', 'bmi1', 'bmi2' } @@ -46,6 +47,10 @@ def get_cpu_features(): check_feature('avx2_0', 'avx2') return features + elif config.arch in [ 'powerpc', 'powerpc64' ]: + # Hardcode support for 'fma' on PowerPC + return [ 'fma' ] + else: # TODO: Add {Open,Free}BSD support print('get_cpu_features: Lacking support for your platform') diff --git a/testsuite/tests/primops/should_run/FMA_ConstantFold.hs b/testsuite/tests/primops/should_run/FMA_ConstantFold.hs new file mode 100644 index 0000000000..80bac3231c --- /dev/null +++ b/testsuite/tests/primops/should_run/FMA_ConstantFold.hs @@ -0,0 +1,236 @@ +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE LexicalNegation #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE TypeApplications #-} + +module Main where + +import Control.Monad + ( unless ) +import Data.IORef + ( newIORef, readIORef, writeIORef ) +import GHC.Exts + ( Float(..), Float#, Double(..), Double# + , fmaddFloat# , fmsubFloat# , fnmaddFloat# , fnmsubFloat# + , fmaddDouble#, fmsubDouble#, fnmaddDouble#, fnmsubDouble# + ) +import GHC.Float + 
( castFloatToWord32, castDoubleToWord64 ) +import System.Exit + ( exitFailure, exitSuccess ) + +-------------------------------------------------------------------------------- + +-- NB: This test tests constant folding for fused multiply-add operations. +-- See FMA_Primops for the test of the primops. + +-- Want "on-the-nose" equality to make sure we properly distinguish 0.0 and -0.0. +class StrictEq a where + strictlyEqual :: a -> a -> Bool +instance StrictEq Float where + strictlyEqual x y = castFloatToWord32 x == castFloatToWord32 y +instance StrictEq Double where + strictlyEqual x y = castDoubleToWord64 x == castDoubleToWord64 y + +class FMA a where + fmadd, fmsub, fnmadd, fnmsub :: a -> a -> a -> a +instance FMA Float where + fmadd (F# x) (F# y) (F# z) = F# ( fmaddFloat# x y z) + fmsub (F# x) (F# y) (F# z) = F# ( fmsubFloat# x y z) + fnmadd (F# x) (F# y) (F# z) = F# (fnmaddFloat# x y z) + fnmsub (F# x) (F# y) (F# z) = F# (fnmsubFloat# x y z) + {-# INLINE fmadd #-} + {-# INLINE fnmadd #-} + {-# INLINE fmsub #-} + {-# INLINE fnmsub #-} +instance FMA Double where + fmadd (D# x) (D# y) (D# z) = D# ( fmaddDouble# x y z) + fmsub (D# x) (D# y) (D# z) = D# ( fmsubDouble# x y z) + fnmadd (D# x) (D# y) (D# z) = D# (fnmaddDouble# x y z) + fnmsub (D# x) (D# y) (D# z) = D# (fnmsubDouble# x y z) + {-# INLINE fmadd #-} + {-# INLINE fnmadd #-} + {-# INLINE fmsub #-} + {-# INLINE fnmsub #-} + +main :: IO () +main = do + + exit_ref <- newIORef False + + let + it :: forall a. ( StrictEq a, Show a ) => String -> a -> a -> IO () + it desc actual expected = + unless (actual `strictlyEqual` expected) do + writeIORef exit_ref True + putStrLn $ unlines + [ "FAIL " ++ desc + , " expected: " ++ show expected + , " actual: " ++ show actual ] + + -- NB: throughout this test, we are using "round to nearest". 
+ + -- fmadd: x * y + z + + -- Float + it "fmaddFloat#: sniff test" + ( fmadd @Float 2 3 1 ) 7 + + it "fmaddFloat#: excess precision" + ( fmadd @Float 0.99999994 1.00000012 -1 ) 5.96046377e-08 + + it "fmaddFloat#: +0 + +0 rounds properly" + ( fmadd @Float 1 0 0 ) 0 + + it "fmaddFloat#: +0 + -0 rounds properly" + ( fmadd @Float 1 0 -0 ) 0 + + it "fmaddFloat#: -0 + +0 rounds properly" + ( fmadd @Float 1 -0 0 ) 0 + + it "fmaddFloat#: -0 + -0 rounds properly" + ( fmadd @Float 1 -0 -0 ) -0 + + -- Double + it "fmaddDouble#: sniff test" + ( fmadd @Double 2 3 1 ) 7 + + it "fmaddDouble#: excess precision" + ( fmadd @Double 0.99999999999999989 1.0000000000000002 -1 ) 1.1102230246251563e-16 + + it "fmaddDouble#: +0 + +0 rounds properly" + ( fmadd @Double 1 0 0 ) 0 + + it "fmaddDouble#: +0 + -0 rounds properly" + ( fmadd @Double 1 0 -0 ) 0 + + it "fmaddDouble#: -0 + +0 rounds properly" + ( fmadd @Double 1 -0 0 ) 0 + + it "fmaddDouble#: -0 + -0 rounds properly" + ( fmadd @Double 1 -0 -0 ) -0 + + -- fmsub: x * y - z + + -- Float + it "fmsubFloat#: sniff test" + ( fmsub @Float 2 3 1 ) 5.0 + + it "fmsubFloat#: excess precision" + ( fmsub @Float 0.99999994 1.00000012 1 ) 5.96046377e-08 + + it "fmsubFloat#: +0 + +0 rounds properly" + ( fmsub @Float 1 0 0 ) 0 + + it "fmsubFloat#: +0 + -0 rounds properly" + ( fmsub @Float 1 0 -0 ) 0 + + it "fmsubFloat#: -0 + +0 rounds properly" + ( fmsub @Float 1 -0 0 ) -0 + + it "fmsubFloat#: -0 + -0 rounds properly" + ( fmsub @Float 1 -0 -0 ) 0 + + -- Double + it "fmsubDouble#: sniff test" + ( fmsub @Double 2 3 1 ) 5.0 + + it "fmsubDouble#: excess precision" + ( fmsub @Double 0.99999999999999989 1.0000000000000002 1 ) 1.1102230246251563e-16 + + it "fmsubDouble#: +0 + +0 rounds properly" + ( fmsub @Double 1 0 0 ) 0 + + it "fmsubDouble#: +0 + -0 rounds properly" + ( fmsub @Double 1 0 -0 ) 0 + + it "fmsubDouble#: -0 + +0 rounds properly" + ( fmsub @Double 1 -0 0 ) -0 + + it "fmsubDouble#: -0 + -0 rounds properly" + ( fmsub @Double 1 -0 -0 ) 0 + + -- 
fnmadd: - x * y + z + + -- Float + it "fnmaddFloat#: sniff test" + ( fnmadd @Float 2 3 1 ) -5.0 + + it "fnmaddFloat#: excess precision" + ( fnmadd @Float 0.99999994 1.00000012 1 ) -5.96046377e-08 + + it "fnmaddFloat#: +0 + +0 rounds properly" + ( fnmadd @Float 1 0 0 ) 0 + + it "fnmaddFloat#: +0 + -0 rounds properly" + ( fnmadd @Float 1 0 -0 ) -0 + + it "fnmaddFloat#: -0 + +0 rounds properly" + ( fnmadd @Float 1 -0 0 ) 0 + + it "fnmaddFloat#: -0 + -0 rounds properly" + ( fnmadd @Float 1 -0 -0 ) 0 + + -- Double + it "fnmaddDouble#: sniff test" + ( fnmadd @Double 2 3 1 ) -5.0 + + it "fnmaddDouble#: excess precision" + ( fnmadd @Double 0.99999999999999989 1.0000000000000002 1 ) -1.1102230246251563e-16 + + it "fnmaddDouble#: +0 + +0 rounds properly" + ( fnmadd @Double 1 0 0 ) 0 + + it "fnmaddDouble#: +0 + -0 rounds properly" + ( fnmadd @Double 1 0 -0 ) -0 + + it "fnmaddDouble#: -0 + +0 rounds properly" + ( fnmadd @Double 1 -0 0 ) 0 + + it "fnmaddDouble#: -0 + -0 rounds properly" + ( fnmadd @Double 1 -0 -0 ) 0 + + -- fnmsub: - x * y - z + + -- Float + it "fnmsubFloat#: sniff test" + ( fnmsub @Float 2 3 1 ) -7 + + it "fnmsubFloat#: excess precision" + ( fnmsub @Float 0.99999994 1.00000012 -1 ) -5.96046377e-08 + + it "fnmsubFloat#: +0 + +0 rounds properly" + ( fnmsub @Float 1 0 0 ) -0 + + it "fnmsubFloat#: +0 + -0 rounds properly" + ( fnmsub @Float 1 0 -0 ) 0 + + it "fnmsubFloat#: -0 + +0 rounds properly" + ( fnmsub @Float 1 -0 0 ) 0 + + it "fnmsubFloat#: -0 + -0 rounds properly" + ( fnmsub @Float 1 -0 -0 ) 0 + + -- Double + it "fnmsubDouble#: sniff test" + ( fnmsub @Double 2 3 1 ) -7 + + it "fnmsubDouble#: excess precision" + ( fnmsub @Double 0.99999999999999989 1.0000000000000002 -1 ) -1.1102230246251563e-16 + + it "fnmsubDouble#: +0 + +0 rounds properly" + ( fnmsub @Double 1 0 0 ) -0 + + it "fnmsubDouble#: +0 + -0 rounds properly" + ( fnmsub @Double 1 0 -0 ) 0 + + it "fnmsubDouble#: -0 + +0 rounds properly" + ( fnmsub @Double 1 -0 0 ) 0 + + it "fnmsubDouble#: -0 + -0 
rounds properly" + ( fnmsub @Double 1 -0 -0 ) 0 + + failure <- readIORef exit_ref + if failure + then exitFailure + else exitSuccess diff --git a/testsuite/tests/primops/should_run/FMA_Primops.hs b/testsuite/tests/primops/should_run/FMA_Primops.hs new file mode 100644 index 0000000000..a925ff6c3b --- /dev/null +++ b/testsuite/tests/primops/should_run/FMA_Primops.hs @@ -0,0 +1,264 @@ +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE LexicalNegation #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE TypeApplications #-} + +module Main where + +import Control.Monad + ( unless ) +import Data.IORef + ( newIORef, readIORef, writeIORef ) +import GHC.Exts + ( Float(..), Float#, Double(..), Double# + , fmaddFloat# , fmsubFloat# , fnmaddFloat# , fnmsubFloat# + , fmaddDouble#, fmsubDouble#, fnmaddDouble#, fnmsubDouble# + ) +import GHC.Float + ( castFloatToWord32, castDoubleToWord64 ) +import System.Exit + ( exitFailure, exitSuccess ) + +-------------------------------------------------------------------------------- + +-- Want "on-the-nose" equality to make sure we properly distinguish 0.0 and -0.0. 
+class StrictEq a where + strictlyEqual :: a -> a -> Bool +instance StrictEq Float where + strictlyEqual x y = castFloatToWord32 x == castFloatToWord32 y +instance StrictEq Double where + strictlyEqual x y = castDoubleToWord64 x == castDoubleToWord64 y + +class FMA a where + fmadd, fmsub, fnmadd, fnmsub :: a -> a -> a -> a +instance FMA Float where + fmadd (F# x) (F# y) (F# z) = F# ( fmaddFloat# x y z) + fmsub (F# x) (F# y) (F# z) = F# ( fmsubFloat# x y z) + fnmadd (F# x) (F# y) (F# z) = F# (fnmaddFloat# x y z) + fnmsub (F# x) (F# y) (F# z) = F# (fnmsubFloat# x y z) + {-# INLINE fmadd #-} + {-# INLINE fnmadd #-} + {-# INLINE fmsub #-} + {-# INLINE fnmsub #-} +instance FMA Double where + fmadd (D# x) (D# y) (D# z) = D# ( fmaddDouble# x y z) + fmsub (D# x) (D# y) (D# z) = D# ( fmsubDouble# x y z) + fnmadd (D# x) (D# y) (D# z) = D# (fnmaddDouble# x y z) + fnmsub (D# x) (D# y) (D# z) = D# (fnmsubDouble# x y z) + {-# INLINE fmadd #-} + {-# INLINE fnmadd #-} + {-# INLINE fmsub #-} + {-# INLINE fnmsub #-} + +f1, f2, f3 :: Float +f1 = 0.99999994 -- float before 1 +f2 = 1.00000012 -- float after 1 +f3 = 5.96046377e-08 -- float after 0 +d1, d2, d3 :: Double +d1 = 0.99999999999999989 -- double before 1 +d2 = 1.0000000000000002 -- double after 1 +d3 = 1.1102230246251563e-16 -- double after 0 +zero, one, two, three, five, seven :: Num a => a +zero = 0 +one = 1 +two = 2 +three = 3 +five = 5 +seven = 7 + +-- NOINLINE to prevent constant folding +-- (The test FMA_ConstantFold tests constant folding.) +{-# NOINLINE f1 #-} +{-# NOINLINE f2 #-} +{-# NOINLINE f3 #-} +{-# NOINLINE d1 #-} +{-# NOINLINE d2 #-} +{-# NOINLINE d3 #-} +{-# NOINLINE zero #-} +{-# NOINLINE one #-} +{-# NOINLINE two #-} +{-# NOINLINE three #-} +{-# NOINLINE five #-} +{-# NOINLINE seven #-} + +main :: IO () +main = do + + exit_ref <- newIORef False + + let + it :: forall a. 
( StrictEq a, Show a ) => String -> a -> a -> IO () + it desc actual expected = + unless (actual `strictlyEqual` expected) do + writeIORef exit_ref True + putStrLn $ unlines + [ "FAIL " ++ desc + , " expected: " ++ show expected + , " actual: " ++ show actual ] + + -- NB: throughout this test, we are using "round to nearest". + + -- fmadd: x * y + z + + -- Float + it "fmaddFloat#: sniff test" + ( fmadd @Float two three one ) seven + + it "fmaddFloat#: excess precision" + ( fmadd @Float f1 f2 -one ) f3 + + it "fmaddFloat#: +0 + +0 rounds properly" + ( fmadd @Float one zero zero ) zero + + it "fmaddFloat#: +0 + -0 rounds properly" + ( fmadd @Float one zero -zero ) zero + + it "fmaddFloat#: -0 + +0 rounds properly" + ( fmadd @Float one -zero zero ) zero + + it "fmaddFloat#: -0 + -0 rounds properly" + ( fmadd @Float one -zero -zero ) -zero + + -- Double + it "fmaddDouble#: sniff test" + ( fmadd @Double two three one ) seven + + it "fmaddDouble#: excess precision" + ( fmadd @Double d1 d2 -one ) d3 + + it "fmaddDouble#: +0 + +0 rounds properly" + ( fmadd @Double one zero zero ) zero + + it "fmaddDouble#: +0 + -0 rounds properly" + ( fmadd @Double one zero -zero ) zero + + it "fmaddDouble#: -0 + +0 rounds properly" + ( fmadd @Double one -zero zero ) zero + + it "fmaddDouble#: -0 + -0 rounds properly" + ( fmadd @Double one -zero -zero ) -zero + + -- fmsub: x * y - z + + -- Float + it "fmsubFloat#: sniff test" + ( fmsub @Float two three one ) five + + it "fmsubFloat#: excess precision" + ( fmsub @Float f1 f2 one ) f3 + + it "fmsubFloat#: +0 + +0 rounds properly" + ( fmsub @Float one zero zero ) zero + + it "fmsubFloat#: +0 + -0 rounds properly" + ( fmsub @Float one zero -zero ) zero + + it "fmsubFloat#: -0 + +0 rounds properly" + ( fmsub @Float one -zero zero ) -zero + + it "fmsubFloat#: -0 + -0 rounds properly" + ( fmsub @Float one -zero -zero ) zero + + -- Double + it "fmsubDouble#: sniff test" + ( fmsub @Double two three one ) five + + it "fmsubDouble#: excess precision" 
+ ( fmsub @Double d1 d2 one ) d3 + + it "fmsubDouble#: +0 + +0 rounds properly" + ( fmsub @Double one zero zero ) zero + + it "fmsubDouble#: +0 + -0 rounds properly" + ( fmsub @Double one zero -zero ) zero + + it "fmsubDouble#: -0 + +0 rounds properly" + ( fmsub @Double one -zero zero ) -zero + + it "fmsubDouble#: -0 + -0 rounds properly" + ( fmsub @Double one -zero -zero ) zero + + -- fnmadd: - x * y + z + + -- Float + it "fnmaddFloat#: sniff test" + ( fnmadd @Float two three one ) -five + + it "fnmaddFloat#: excess precision" + ( fnmadd @Float f1 f2 one ) -f3 + + it "fnmaddFloat#: +0 + +0 rounds properly" + ( fnmadd @Float one zero zero ) zero + + it "fnmaddFloat#: +0 + -0 rounds properly" + ( fnmadd @Float one zero -zero ) -zero + + it "fnmaddFloat#: -0 + +0 rounds properly" + ( fnmadd @Float one -zero zero ) zero + + it "fnmaddFloat#: -0 + -0 rounds properly" + ( fnmadd @Float one -zero -zero ) zero + + -- Double + it "fnmaddDouble#: sniff test" + ( fnmadd @Double two three one ) -five + + it "fnmaddDouble#: excess precision" + ( fnmadd @Double d1 d2 one ) -d3 + + it "fnmaddDouble#: +0 + +0 rounds properly" + ( fnmadd @Double one zero zero ) zero + + it "fnmaddDouble#: +0 + -0 rounds properly" + ( fnmadd @Double one zero -zero ) -zero + + it "fnmaddDouble#: -0 + +0 rounds properly" + ( fnmadd @Double one -zero zero ) zero + + it "fnmaddDouble#: -0 + -0 rounds properly" + ( fnmadd @Double one -zero -zero ) zero + + -- fnmsub: - x * y - z + + -- Float + it "fnmsubFloat#: sniff test" + ( fnmsub @Float two three one ) -seven + + it "fnmsubFloat#: excess precision" + ( fnmsub @Float f1 f2 -one ) -f3 + + it "fnmsubFloat#: +0 + +0 rounds properly" + ( fnmsub @Float one zero zero ) -zero + + it "fnmsubFloat#: +0 + -0 rounds properly" + ( fnmsub @Float one zero -zero ) zero + + it "fnmsubFloat#: -0 + +0 rounds properly" + ( fnmsub @Float one -zero zero ) zero + + it "fnmsubFloat#: -0 + -0 rounds properly" + ( fnmsub @Float one -zero -zero ) zero + + -- Double + it 
"fnmsubDouble#: sniff test" + ( fnmsub @Double two three one ) -seven + + it "fnmsubDouble#: excess precision" + ( fnmsub @Double d1 d2 -one ) -d3 + + it "fnmsubDouble#: +0 + +0 rounds properly" + ( fnmsub @Double one zero zero ) -zero + + it "fnmsubDouble#: +0 + -0 rounds properly" + ( fnmsub @Double one zero -zero ) zero + + it "fnmsubDouble#: -0 + +0 rounds properly" + ( fnmsub @Double one -zero zero ) zero + + it "fnmsubDouble#: -0 + -0 rounds properly" + ( fnmsub @Double one -zero -zero ) zero + + failure <- readIORef exit_ref + if failure + then exitFailure + else exitSuccess diff --git a/testsuite/tests/primops/should_run/all.T b/testsuite/tests/primops/should_run/all.T index 4148546280..da6378df84 100644 --- a/testsuite/tests/primops/should_run/all.T +++ b/testsuite/tests/primops/should_run/all.T @@ -59,5 +59,16 @@ test('UnliftedTVar1', normal, compile_and_run, ['']) test('UnliftedTVar2', normal, compile_and_run, ['']) test('UnliftedWeakPtr', normal, compile_and_run, ['']) +test('FMA_Primops' + , [ when(have_cpu_feature('fma'), extra_hc_opts('-mfma')) + , js_skip # JS backend doesn't have an FMA implementation + ] + , compile_and_run, ['']) +test('FMA_ConstantFold' + , [ js_skip # JS backend doesn't have an FMA implementation ] + , expect_broken(21227) + ] + , compile_and_run, ['-O']) + test('T21624', normal, compile_and_run, ['']) test('T23071', ignore_stdout, compile_and_run, ['']) -- cgit v1.2.1