summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsheaf <sam.derbyshire@gmail.com>2023-04-08 13:42:58 +0200
committerMarge Bot <ben+marge-bot@smart-cactus.org>2023-05-11 11:55:22 -0400
commit87eebf98cb485f7c9175330051736e147ade9848 (patch)
treeffa226b3fefa8b0a03e1798fa4f55affbddf654b
parent630b1fea1e41a1e00860a30742b6ab8ade8a0de0 (diff)
downloadhaskell-87eebf98cb485f7c9175330051736e147ade9848.tar.gz
Add fused multiply-add instructions
This patch adds eight new primops that fuse a multiplication and an addition or subtraction: - `{fmadd,fmsub,fnmadd,fnmsub}{Float,Double}#` fmadd x y z is x * y + z, computed with a single rounding step. This patch implements code generation for these primops in the following backends: - X86, AArch64 and PowerPC NCG, - LLVM - C WASM uses the C implementation. The primops are unsupported in the JavaScript backend. The following constant folding rules are also provided: - compute a * b + c when a, b, c are all literals, - x * y + 0 ==> x * y, - ±1 * y + z ==> z ± y and x * ±1 + z ==> z ± x. NB: the constant folding rules incorrectly handle signed zero. This is a known limitation with GHC's floating-point constant folding rules (#21227), which we hope to resolve in the future.
-rw-r--r--compiler/GHC/Builtin/primops.txt.pp69
-rw-r--r--compiler/GHC/Cmm/MachOp.hs38
-rw-r--r--compiler/GHC/Cmm/Parser.y7
-rw-r--r--compiler/GHC/CmmToAsm/AArch64/CodeGen.hs41
-rw-r--r--compiler/GHC/CmmToAsm/AArch64/Instr.hs19
-rw-r--r--compiler/GHC/CmmToAsm/AArch64/Ppr.hs7
-rw-r--r--compiler/GHC/CmmToAsm/PPC/CodeGen.hs35
-rw-r--r--compiler/GHC/CmmToAsm/PPC/Instr.hs11
-rw-r--r--compiler/GHC/CmmToAsm/PPC/Ppr.hs18
-rw-r--r--compiler/GHC/CmmToAsm/Wasm/FromCmm.hs4
-rw-r--r--compiler/GHC/CmmToAsm/X86/CodeGen.hs92
-rw-r--r--compiler/GHC/CmmToAsm/X86/Instr.hs22
-rw-r--r--compiler/GHC/CmmToAsm/X86/Ppr.hs23
-rw-r--r--compiler/GHC/CmmToC.hs22
-rw-r--r--compiler/GHC/CmmToLlvm/CodeGen.hs28
-rw-r--r--compiler/GHC/Core/Opt/ConstantFold.hs156
-rw-r--r--compiler/GHC/Driver/Config/StgToCmm.hs22
-rw-r--r--compiler/GHC/Driver/Pipeline/Execute.hs1
-rw-r--r--compiler/GHC/Driver/Session.hs11
-rw-r--r--compiler/GHC/Llvm/Ppr.hs8
-rw-r--r--compiler/GHC/Llvm/Syntax.hs3
-rw-r--r--compiler/GHC/Llvm/Types.hs1
-rw-r--r--compiler/GHC/StgToCmm/Config.hs2
-rw-r--r--compiler/GHC/StgToCmm/Prim.hs44
-rw-r--r--compiler/GHC/StgToJS/Prim.hs10
-rw-r--r--compiler/GHC/SysTools/Cpp.hs4
-rw-r--r--docs/users_guide/9.8.1-notes.rst18
-rw-r--r--docs/users_guide/using.rst18
-rw-r--r--libraries/ghc-prim/changelog.md18
-rw-r--r--rts/RtsSymbols.c1
-rw-r--r--rts/StgPrimFloat.c1
-rw-r--r--testsuite/driver/cpu_features.py5
-rw-r--r--testsuite/tests/primops/should_run/FMA_ConstantFold.hs236
-rw-r--r--testsuite/tests/primops/should_run/FMA_Primops.hs264
-rw-r--r--testsuite/tests/primops/should_run/all.T11
35 files changed, 1244 insertions, 26 deletions
diff --git a/compiler/GHC/Builtin/primops.txt.pp b/compiler/GHC/Builtin/primops.txt.pp
index 5b730c1943..b2b3f1d8f5 100644
--- a/compiler/GHC/Builtin/primops.txt.pp
+++ b/compiler/GHC/Builtin/primops.txt.pp
@@ -1370,6 +1370,75 @@ primop FloatDecode_IntOp "decodeFloat_Int#" GenPrimOp
with out_of_line = True
------------------------------------------------------------------------
+section "Fused multiply-add operations"
+ { #fma#
+
+ The fused multiply-add primops 'fmaddFloat#' and 'fmaddDouble#'
+ implement the operation
+
+ \[
+ \lambda\ x\ y\ z \rightarrow x * y + z
+ \]
+
+ with a single floating-point rounding operation at the end, as opposed to
+ rounding twice (which can accumulate rounding errors).
+
+ These primops can be compiled directly to a single machine instruction on
+ architectures that support them. Currently, these are:
+
+ 1. x86 with CPUs that support the FMA3 extended instruction set (which
+ includes most processors since 2013).
+ 2. PowerPC.
+ 3. AArch64.
+
+ This requires that users pass the '-mfma' flag to GHC. Otherwise, the primop
+ is implemented by falling back to the C standard library, which might
+ perform software emulation (this may yield results that are not IEEE
+ compliant on some platforms).
+
+ The additional operations 'fmsubFloat#'/'fmsubDouble#',
+ 'fnmaddFloat#'/'fnmaddDouble#' and 'fnmsubFloat#'/'fnmsubDouble#' provide
+ variants on 'fmaddFloat#'/'fmaddDouble#' in which some signs are changed:
+
+ \[
+ \begin{aligned}
+ \mathrm{fmadd}\ x\ y\ z &= \phantom{+} x * y + z \\[8pt]
+ \mathrm{fmsub}\ x\ y\ z &= \phantom{+} x * y - z \\[8pt]
+ \mathrm{fnmadd}\ x\ y\ z &= - x * y + z \\[8pt]
+ \mathrm{fnmsub}\ x\ y\ z &= - x * y - z
+ \end{aligned}
+ \]
+
+ }
+------------------------------------------------------------------------
+
+primop FloatFMAdd "fmaddFloat#" GenPrimOp
+ Float# -> Float# -> Float# -> Float#
+ {Fused multiply-add operation @x*y+z@. See "GHC.Prim#fma".}
+primop FloatFMSub "fmsubFloat#" GenPrimOp
+ Float# -> Float# -> Float# -> Float#
+ {Fused multiply-subtract operation @x*y-z@. See "GHC.Prim#fma".}
+primop FloatFNMAdd "fnmaddFloat#" GenPrimOp
+ Float# -> Float# -> Float# -> Float#
+ {Fused negate-multiply-add operation @-x*y+z@. See "GHC.Prim#fma".}
+primop FloatFNMSub "fnmsubFloat#" GenPrimOp
+ Float# -> Float# -> Float# -> Float#
+ {Fused negate-multiply-subtract operation @-x*y-z@. See "GHC.Prim#fma".}
+
+primop DoubleFMAdd "fmaddDouble#" GenPrimOp
+ Double# -> Double# -> Double# -> Double#
+ {Fused multiply-add operation @x*y+z@. See "GHC.Prim#fma".}
+primop DoubleFMSub "fmsubDouble#" GenPrimOp
+ Double# -> Double# -> Double# -> Double#
+ {Fused multiply-subtract operation @x*y-z@. See "GHC.Prim#fma".}
+primop DoubleFNMAdd "fnmaddDouble#" GenPrimOp
+ Double# -> Double# -> Double# -> Double#
+ {Fused negate-multiply-add operation @-x*y+z@. See "GHC.Prim#fma".}
+primop DoubleFNMSub "fnmsubDouble#" GenPrimOp
+ Double# -> Double# -> Double# -> Double#
+ {Fused negate-multiply-subtract operation @-x*y-z@. See "GHC.Prim#fma".}
+
+------------------------------------------------------------------------
section "Arrays"
{Operations on 'Array#'.}
------------------------------------------------------------------------
diff --git a/compiler/GHC/Cmm/MachOp.hs b/compiler/GHC/Cmm/MachOp.hs
index d134fdc346..29863853c1 100644
--- a/compiler/GHC/Cmm/MachOp.hs
+++ b/compiler/GHC/Cmm/MachOp.hs
@@ -1,3 +1,5 @@
+{-# LANGUAGE LambdaCase #-}
+
{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
module GHC.Cmm.MachOp
@@ -26,6 +28,9 @@ module GHC.Cmm.MachOp
-- Atomic read-modify-write
, MemoryOrdering(..)
, AtomicMachOp(..)
+
+ -- Fused multiply-add
+ , FMASign(..), pprFMASign
)
where
@@ -88,6 +93,10 @@ data MachOp
| MO_F_Mul Width
| MO_F_Quot Width
+ -- Floating-point fused multiply-add operations
+ -- | Fused multiply-add, see 'FMASign'.
+ | MO_FMA FMASign Width
+
-- Floating point comparison
| MO_F_Eq Width
| MO_F_Ne Width
@@ -160,7 +169,30 @@ data MachOp
pprMachOp :: MachOp -> SDoc
pprMachOp mo = text (show mo)
+-- | Where are the signs in a fused multiply-add instruction?
+--
+-- @x*y + z@ vs @x*y - z@ vs @-x*y+z@ vs @-x*y-z@.
+--
+-- Warning: the signs aren't consistent across architectures (X86, PowerPC, AArch64).
+-- The user-facing implementation uses the X86 convention, while the relevant
+-- backends use their corresponding conventions.
+data FMASign
+ -- | Fused multiply-add @x*y + z@.
+ = FMAdd
+ -- | Fused multiply-subtract. On X86: @x*y - z@.
+ | FMSub
+ -- | Fused negate-multiply-add. On X86: @-x*y + z@.
+ | FNMAdd
+ -- | Fused negate-multiply-subtract. On X86: @-x*y - z@.
+ | FNMSub
+ deriving (Eq, Show)
+-- | Render an 'FMASign' as the lower-case mnemonic stem that shares its
+-- name (@fmadd@, @fmsub@, @fnmadd@, @fnmsub@).
+pprFMASign :: IsLine doc => FMASign -> doc
+pprFMASign FMAdd  = text "fmadd"
+pprFMASign FMSub  = text "fmsub"
+pprFMASign FNMAdd = text "fnmadd"
+pprFMASign FNMSub = text "fnmsub"
-- -----------------------------------------------------------------------------
-- Some common MachReps
@@ -398,6 +430,9 @@ machOpResultType platform mop tys =
MO_F_Mul r -> cmmFloat r
MO_F_Quot r -> cmmFloat r
MO_F_Neg r -> cmmFloat r
+
+ MO_FMA _ r -> cmmFloat r
+
MO_F_Eq {} -> comparisonResultRep platform
MO_F_Ne {} -> comparisonResultRep platform
MO_F_Ge {} -> comparisonResultRep platform
@@ -489,6 +524,9 @@ machOpArgReps platform op =
MO_F_Mul r -> [r,r]
MO_F_Quot r -> [r,r]
MO_F_Neg r -> [r]
+
+ MO_FMA _ r -> [r,r,r]
+
MO_F_Eq r -> [r,r]
MO_F_Ne r -> [r,r]
MO_F_Ge r -> [r,r]
diff --git a/compiler/GHC/Cmm/Parser.y b/compiler/GHC/Cmm/Parser.y
index 29870dc647..495e72e37d 100644
--- a/compiler/GHC/Cmm/Parser.y
+++ b/compiler/GHC/Cmm/Parser.y
@@ -1009,7 +1009,7 @@ machOps = listToUFM $
( "eq", MO_Eq ),
( "ne", MO_Ne ),
( "mul", MO_Mul ),
- ( "mulmayoflo", MO_S_MulMayOflo ),
+ ( "mulmayoflo", MO_S_MulMayOflo ),
( "neg", MO_S_Neg ),
( "quot", MO_S_Quot ),
( "rem", MO_S_Rem ),
@@ -1040,6 +1040,11 @@ machOps = listToUFM $
( "fmul", MO_F_Mul ),
( "fquot", MO_F_Quot ),
+ ( "fmadd" , MO_FMA FMAdd ),
+ ( "fmsub" , MO_FMA FMSub ),
+ ( "fnmadd", MO_FMA FNMAdd ),
+ ( "fnmsub", MO_FMA FNMSub ),
+
( "feq", MO_F_Eq ),
( "fne", MO_F_Ne ),
( "fge", MO_F_Ge ),
diff --git a/compiler/GHC/CmmToAsm/AArch64/CodeGen.hs b/compiler/GHC/CmmToAsm/AArch64/CodeGen.hs
index 8ebccaf093..c0e9a7e8d5 100644
--- a/compiler/GHC/CmmToAsm/AArch64/CodeGen.hs
+++ b/compiler/GHC/CmmToAsm/AArch64/CodeGen.hs
@@ -783,7 +783,7 @@ getRegister' config plat expr
where w' = formatToWidth (cmmTypeFormat (cmmRegType reg))
r' = getRegisterReg plat reg
- -- Generic case.
+ -- Generic binary case.
CmmMachOp op [x, y] -> do
-- alright, so we have an operation, and two expressions. And we want to essentially do
-- ensure we get float regs (TODO(Ben): What?)
@@ -956,7 +956,44 @@ getRegister' config plat expr
-- TODO
- op -> pprPanic "getRegister' (unhandled dyadic CmmMachOp): " $ (pprMachOp op) <+> text "in" <+> (pdoc plat expr)
+ op -> pprPanic "getRegister' (unhandled dyadic CmmMachOp): " $
+ (pprMachOp op) <+> text "in" <+> (pdoc plat expr)
+
+ -- Generic ternary case.
+ CmmMachOp op [x, y, z] ->
+
+ case op of
+
+ -- Floating-point fused multiply-add operations
+
+ -- x86 fmadd x * y + z <=> AArch64 fmadd : d = r1 * r2 + r3
+ -- x86 fmsub x * y - z <=> AArch64 fnmsub: d = r1 * r2 - r3
+ -- x86 fnmadd - x * y + z <=> AArch64 fmsub : d = - r1 * r2 + r3
+ -- x86 fnmsub - x * y - z <=> AArch64 fnmadd: d = - r1 * r2 - r3
+
+ MO_FMA var w -> case var of
+ FMAdd -> float3Op w (\d n m a -> unitOL $ FMA FMAdd d n m a)
+ FMSub -> float3Op w (\d n m a -> unitOL $ FMA FNMSub d n m a)
+ FNMAdd -> float3Op w (\d n m a -> unitOL $ FMA FMSub d n m a)
+ FNMSub -> float3Op w (\d n m a -> unitOL $ FMA FNMAdd d n m a)
+
+ _ -> pprPanic "getRegister' (unhandled ternary CmmMachOp): " $
+ (pprMachOp op) <+> text "in" <+> (pdoc plat expr)
+
+ where
+ float3Op w op = do
+ (reg_fx, format_x, code_fx) <- getFloatReg x
+ (reg_fy, format_y, code_fy) <- getFloatReg y
+ (reg_fz, format_z, code_fz) <- getFloatReg z
+ massertPpr (isFloatFormat format_x && isFloatFormat format_y && isFloatFormat format_z) $
+ text "float3Op: non-float"
+ return $
+ Any (floatFormat w) $ \ dst ->
+ code_fx `appOL`
+ code_fy `appOL`
+ code_fz `appOL`
+ op (OpReg w dst) (OpReg w reg_fx) (OpReg w reg_fy) (OpReg w reg_fz)
+
CmmMachOp _op _xs
-> pprPanic "getRegister' (variadic CmmMachOp): " (pdoc plat expr)
diff --git a/compiler/GHC/CmmToAsm/AArch64/Instr.hs b/compiler/GHC/CmmToAsm/AArch64/Instr.hs
index 7bf78becb6..166ab2ca17 100644
--- a/compiler/GHC/CmmToAsm/AArch64/Instr.hs
+++ b/compiler/GHC/CmmToAsm/AArch64/Instr.hs
@@ -142,6 +142,8 @@ regUsageOfInstr platform instr = case instr of
SCVTF dst src -> usage (regOp src, regOp dst)
FCVTZS dst src -> usage (regOp src, regOp dst)
FABS dst src -> usage (regOp src, regOp dst)
+ FMA _ dst src1 src2 src3 ->
+ usage (regOp src1 ++ regOp src2 ++ regOp src3, regOp dst)
_ -> panic $ "regUsageOfInstr: " ++ instrCon instr
@@ -280,6 +282,9 @@ patchRegsOfInstr instr env = case instr of
SCVTF o1 o2 -> SCVTF (patchOp o1) (patchOp o2)
FCVTZS o1 o2 -> FCVTZS (patchOp o1) (patchOp o2)
FABS o1 o2 -> FABS (patchOp o1) (patchOp o2)
+ FMA s o1 o2 o3 o4 ->
+ FMA s (patchOp o1) (patchOp o2) (patchOp o3) (patchOp o4)
+
_ -> panic $ "patchRegsOfInstr: " ++ instrCon instr
where
patchOp :: Operand -> Operand
@@ -650,6 +655,14 @@ data Instr
-- Float ABSolute value
| FABS Operand Operand
+ -- | Floating-point fused multiply-add instructions
+ --
+ -- - fmadd : d = r1 * r2 + r3
+ -- - fnmsub: d = r1 * r2 - r3
+ -- - fmsub : d = - r1 * r2 + r3
+ -- - fnmadd: d = - r1 * r2 - r3
+ | FMA FMASign Operand Operand Operand Operand
+
instrCon :: Instr -> String
instrCon i =
case i of
@@ -715,6 +728,12 @@ instrCon i =
SCVTF{} -> "SCVTF"
FCVTZS{} -> "FCVTZS"
FABS{} -> "FABS"
+ FMA variant _ _ _ _ ->
+ case variant of
+ FMAdd -> "FMADD"
+ FMSub -> "FMSUB"
+ FNMAdd -> "FNMADD"
+ FNMSub -> "FNMSUB"
data Target
= TBlock BlockId
diff --git a/compiler/GHC/CmmToAsm/AArch64/Ppr.hs b/compiler/GHC/CmmToAsm/AArch64/Ppr.hs
index 475324afce..646f914c8d 100644
--- a/compiler/GHC/CmmToAsm/AArch64/Ppr.hs
+++ b/compiler/GHC/CmmToAsm/AArch64/Ppr.hs
@@ -546,6 +546,13 @@ pprInstr platform instr = case instr of
SCVTF o1 o2 -> op2 (text "\tscvtf") o1 o2
FCVTZS o1 o2 -> op2 (text "\tfcvtzs") o1 o2
FABS o1 o2 -> op2 (text "\tfabs") o1 o2
+ FMA variant d r1 r2 r3 ->
+ let fma = case variant of
+ FMAdd -> text "\tfmadd"
+ FMSub -> text "\tfmsub"
+ FNMAdd -> text "\tfnmadd"
+ FNMSub -> text "\tfnmsub"
+ in op4 fma d r1 r2 r3
where op2 op o1 o2 = line $ op <+> pprOp platform o1 <> comma <+> pprOp platform o2
op3 op o1 o2 o3 = line $ op <+> pprOp platform o1 <> comma <+> pprOp platform o2 <> comma <+> pprOp platform o3
op4 op o1 o2 o3 o4 = line $ op <+> pprOp platform o1 <> comma <+> pprOp platform o2 <> comma <+> pprOp platform o3 <> comma <+> pprOp platform o4
diff --git a/compiler/GHC/CmmToAsm/PPC/CodeGen.hs b/compiler/GHC/CmmToAsm/PPC/CodeGen.hs
index 7dac4f221b..f8a726da6c 100644
--- a/compiler/GHC/CmmToAsm/PPC/CodeGen.hs
+++ b/compiler/GHC/CmmToAsm/PPC/CodeGen.hs
@@ -649,6 +649,21 @@ getRegister' _ _ (CmmMachOp mop [x, y]) -- dyadic PrimOps
code <- remainderCode rep sgn tmp x y
return (Any fmt code)
+getRegister' _ _ (CmmMachOp mop [x, y, z]) -- ternary PrimOps
+  = case mop of
+
+      -- The 'FMASign' carried by the Cmm-level 'MO_FMA' follows the x86 naming,
+      -- whereas the 'FMASign' stored in the PPC 'FMADD' instruction is printed
+      -- as-is as the PowerPC mnemonic. The two conventions agree for the
+      -- non-negated variants but are swapped for the negated ones:
+      --
+      -- x86 fmadd    x * y + z <> PPC fmadd    rt =   ra * rc + rb
+      -- x86 fmsub    x * y - z <> PPC fmsub    rt =   ra * rc - rb
+      -- x86 fnmadd - x * y + z ~~ PPC fnmsub   rt = -(ra * rc - rb)
+      -- x86 fnmsub - x * y - z ~~ PPC fnmadd   rt = -(ra * rc + rb)
+
+      MO_FMA variant w ->
+        case variant of
+          FMAdd  -> fma_code w (FMADD FMAdd)  x y z
+          FMSub  -> fma_code w (FMADD FMSub)  x y z
+          -- Negated variants: translate from the x86 to the PPC convention.
+          FNMAdd -> fma_code w (FMADD FNMSub) x y z
+          FNMSub -> fma_code w (FMADD FNMAdd) x y z
+      _ -> panic "PPC.CodeGen.getRegister: no match"
getRegister' _ _ (CmmLit (CmmInt i rep))
| Just imm <- makeImmediate rep True i
@@ -2358,10 +2373,28 @@ trivialUCode rep instr x = do
let code' dst = code `snocOL` instr dst src
return (Any rep code')
+-- | Emit a four-register fused multiply-add instruction.
+--
+-- The three argument expressions are evaluated into registers in the order
+-- @ra@, @rc@, @rb@; the supplied instruction builder is then applied to the
+-- float format for the given width, the destination register and the three
+-- source registers, e.g. @fmadd rt ra rc rb := rt <- ra * rc + rb@.
+fma_code :: Width
+         -> (Format -> Reg -> Reg -> Reg -> Reg -> Instr)
+         -> CmmExpr
+         -> CmmExpr
+         -> CmmExpr
+         -> NatM Register
+fma_code w instr ra rc rb = do
+  (reg_a, code_a) <- getSomeReg ra
+  (reg_c, code_c) <- getSomeReg rc
+  (reg_b, code_b) <- getSomeReg rb
+  let fmt = floatFormat w
+      -- Operand setup code first, then the single FMA instruction.
+      body rt = (code_a `appOL` code_c `appOL` code_b)
+                  `snocOL` instr fmt rt reg_a reg_c reg_b
+  return (Any fmt body)
+
-- There is no "remainder" instruction on the PPC, so we have to do
-- it the hard way.
-- The "sgn" parameter is the signedness for the division instruction
-
remainderCode :: Width -> Bool -> Reg -> CmmExpr -> CmmExpr
-> NatM (Reg -> InstrBlock)
remainderCode rep sgn reg_q arg_x arg_y = do
diff --git a/compiler/GHC/CmmToAsm/PPC/Instr.hs b/compiler/GHC/CmmToAsm/PPC/Instr.hs
index 639ae979f8..3fedcc1fc4 100644
--- a/compiler/GHC/CmmToAsm/PPC/Instr.hs
+++ b/compiler/GHC/CmmToAsm/PPC/Instr.hs
@@ -280,6 +280,14 @@ data Instr
| FABS Reg Reg -- abs is the same for single and double
| FNEG Reg Reg -- negate is the same for single and double prec.
+ -- | Fused multiply-add instructions.
+ --
+ -- - FMADD: @rt = ra * rc + rb@
+ -- - FMSUB: @rt = ra * rc - rb@
+ -- - FNMADD: @rt = -(ra * rc + rb)@
+ -- - FNMSUB: @rt = -(ra * rc - rb)@
+ | FMADD FMASign Format Reg Reg Reg Reg
+
| FCMP Reg Reg
| FCTIWZ Reg Reg -- convert to integer word
@@ -380,6 +388,7 @@ regUsageOfInstr platform instr
MFCR reg -> usage ([], [reg])
MFLR reg -> usage ([], [reg])
FETCHPC reg -> usage ([], [reg])
+ FMADD _ _ rt ra rc rb -> usage ([ra, rc, rb], [rt])
_ -> noUsage
where
usage (src, dst) = RU (filter (interesting platform) src)
@@ -467,6 +476,8 @@ patchRegsOfInstr instr env
FDIV fmt r1 r2 r3 -> FDIV fmt (env r1) (env r2) (env r3)
FABS r1 r2 -> FABS (env r1) (env r2)
FNEG r1 r2 -> FNEG (env r1) (env r2)
+ FMADD sgn fmt r1 r2 r3 r4
+ -> FMADD sgn fmt (env r1) (env r2) (env r3) (env r4)
FCMP r1 r2 -> FCMP (env r1) (env r2)
FCTIWZ r1 r2 -> FCTIWZ (env r1) (env r2)
FCTIDZ r1 r2 -> FCTIDZ (env r1) (env r2)
diff --git a/compiler/GHC/CmmToAsm/PPC/Ppr.hs b/compiler/GHC/CmmToAsm/PPC/Ppr.hs
index ba364df1b0..f1d6733327 100644
--- a/compiler/GHC/CmmToAsm/PPC/Ppr.hs
+++ b/compiler/GHC/CmmToAsm/PPC/Ppr.hs
@@ -934,6 +934,9 @@ pprInstr platform instr = case instr of
FNEG reg1 reg2
-> pprUnary (text "fneg") reg1 reg2
+ FMADD signs fmt dst ra rc rb
+ -> pprTernaryF (pprFMASign signs) fmt dst ra rc rb
+
FCMP reg1 reg2
-> line $ hcat [
char '\t',
@@ -1083,6 +1086,21 @@ pprBinaryF op fmt reg1 reg2 reg3 = line $ hcat [
pprReg reg3
]
+-- | Print a floating-point instruction with one destination and three source
+-- registers: a tab, the mnemonic followed by its format suffix, a tab, then
+-- the four registers separated by @", "@.
+pprTernaryF :: IsDoc doc => Line doc -> Format -> Reg -> Reg -> Reg -> Reg -> doc
+pprTernaryF op fmt rt ra rc rb = line (hcat (mnemonic ++ operands))
+  where
+    sep      = text ", "
+    mnemonic = [char '\t', op, pprFFormat fmt, char '\t']
+    operands = [pprReg rt, sep, pprReg ra, sep, pprReg rc, sep, pprReg rb]
+
pprRI :: IsLine doc => Platform -> RI -> doc
pprRI _ (RIReg r) = pprReg r
pprRI platform (RIImm r) = pprImm platform r
diff --git a/compiler/GHC/CmmToAsm/Wasm/FromCmm.hs b/compiler/GHC/CmmToAsm/Wasm/FromCmm.hs
index 7ca323d72d..9a4c3f34c2 100644
--- a/compiler/GHC/CmmToAsm/Wasm/FromCmm.hs
+++ b/compiler/GHC/CmmToAsm/Wasm/FromCmm.hs
@@ -816,7 +816,9 @@ lower_CmmMachOp lbl (MO_SS_Conv w0 w1) xs = lower_MO_SS_Conv lbl w0 w1 xs
lower_CmmMachOp lbl (MO_UU_Conv w0 w1) xs = lower_MO_UU_Conv lbl w0 w1 xs
lower_CmmMachOp lbl (MO_XX_Conv w0 w1) xs = lower_MO_UU_Conv lbl w0 w1 xs
lower_CmmMachOp lbl (MO_FF_Conv w0 w1) xs = lower_MO_FF_Conv lbl w0 w1 xs
-lower_CmmMachOp _ _ _ = panic "lower_CmmMachOp: unreachable"
+lower_CmmMachOp _ mop _ =
+ pprPanic "lower_CmmMachOp: unreachable" $
+ vcat [ text "offending MachOp:" <+> pprMachOp mop ]
-- | Lower a 'CmmLit'. Note that we don't emit 'f32.const' or
-- 'f64.const' for the time being, and instead emit their relative bit
diff --git a/compiler/GHC/CmmToAsm/X86/CodeGen.hs b/compiler/GHC/CmmToAsm/X86/CodeGen.hs
index d6ef821c9f..859b27e248 100644
--- a/compiler/GHC/CmmToAsm/X86/CodeGen.hs
+++ b/compiler/GHC/CmmToAsm/X86/CodeGen.hs
@@ -901,14 +901,10 @@ getRegister' _ is32Bit (CmmMachOp mop [x, y]) = -- dyadic MachOps
MO_U_Lt _ -> condIntReg LU x y
MO_U_Le _ -> condIntReg LEU x y
- MO_F_Add w -> trivialFCode_sse2 w ADD x y
-
- MO_F_Sub w -> trivialFCode_sse2 w SUB x y
-
- MO_F_Quot w -> trivialFCode_sse2 w FDIV x y
-
- MO_F_Mul w -> trivialFCode_sse2 w MUL x y
-
+ MO_F_Add w -> trivialFCode_sse2 w ADD x y
+ MO_F_Sub w -> trivialFCode_sse2 w SUB x y
+ MO_F_Quot w -> trivialFCode_sse2 w FDIV x y
+ MO_F_Mul w -> trivialFCode_sse2 w MUL x y
MO_Add rep -> add_code rep x y
MO_Sub rep -> sub_code rep x y
@@ -1113,6 +1109,13 @@ getRegister' _ is32Bit (CmmMachOp mop [x, y]) = -- dyadic MachOps
return (Fixed format result code)
+-- Ternary MachOps: only the fused multiply-add family is supported here;
+-- any other three-argument MachOp is a compiler bug.
+getRegister' _plat _is32Bit (CmmMachOp mop [x, y, z])
+  -- Floating point fused multiply-add operations @ ± x*y ± z@
+  | MO_FMA sign width <- mop
+  = genFMA3Code width sign x y z
+  | otherwise
+  = pprPanic "getRegister(x86) - ternary CmmMachOp (1)"
+             (pprMachOp mop)
getRegister' _ _ (CmmLoad mem pk _)
| isFloatType pk
@@ -3151,12 +3154,12 @@ genTrivialCode rep instr a b = do
a_code <- getAnyReg a
tmp <- getNewRegNat rep
let
- -- We want the value of b to stay alive across the computation of a.
- -- But, we want to calculate a straight into the destination register,
+ -- We want the value of 'b' to stay alive across the computation of 'a'.
+ -- But, we want to calculate 'a' straight into the destination register,
-- because the instruction only has two operands (dst := dst `op` src).
- -- The troublesome case is when the result of b is in the same register
- -- as the destination reg. In this case, we have to save b in a
- -- new temporary across the computation of a.
+ -- The troublesome case is when the result of 'b' is in the same register
+ -- as the destination 'reg'. In this case, we have to save 'b' in a
+ -- new temporary across the computation of 'a'.
code dst
| dst `regClashesWithOp` b_op =
b_code `appOL`
@@ -3174,6 +3177,69 @@ reg `regClashesWithOp` OpReg reg2 = reg == reg2
reg `regClashesWithOp` OpAddr amode = any (==reg) (addrModeRegs amode)
_ `regClashesWithOp` _ = False
+-- | Generate code for a fused multiply-add operation, of the form @± x * y ± z@,
+-- with 3 operands (FMA3 instruction set).
+--
+-- The result is computed into the destination register, which is first loaded
+-- with @x@. Because computing @x@ into the destination would overwrite @y@ or
+-- @z@ if either already lives there, such operands are first copied into a
+-- fresh temporary register.
+genFMA3Code :: Width
+            -> FMASign
+            -> CmmExpr -> CmmExpr -> CmmExpr -> NatM Register
+genFMA3Code w signs x y z = do
+
+  -- For the FMA instruction, we want to compute x * y + z
+  --
+  -- There are three possible instructions we could emit:
+  --
+  --   - fmadd213 z y x, result in x, z can be a memory address
+  --   - fmadd132 x z y, result in y, x can be a memory address
+  --   - fmadd231 y x z, result in z, y can be a memory address
+  --
+  -- This suggests two possible optimisations:
+  --
+  --   - OPTIMISATION 1
+  --     If one argument is an address, use the instruction that allows
+  --     a memory address in that position.
+  --
+  --   - OPTIMISATION 2
+  --     If one argument is in a fixed register, use the instruction that puts
+  --     the result in that same register.
+  --
+  -- Currently we follow neither of these optimisations,
+  -- opting to always use fmadd213 for simplicity.
+  let rep = floatFormat w
+  (y_reg, y_code) <- getNonClobberedReg y
+  (z_reg, z_code) <- getNonClobberedReg z
+  x_code <- getAnyReg x
+  y_tmp <- getNewRegNat rep
+  z_tmp <- getNewRegNat rep
+  let
+     -- Operand order: src1 (r/m) = z, src2 (r) = y, dst (r) = x.
+     fma213 = FMA3 rep signs FMA213
+     code dst
+       -- Both y and z live in the destination: save both before
+       -- x is computed into it.
+       | dst == y_reg
+       , dst == z_reg
+       = y_code `appOL`
+         unitOL (MOV rep (OpReg y_reg) (OpReg y_tmp)) `appOL`
+         z_code `appOL`
+         unitOL (MOV rep (OpReg z_reg) (OpReg z_tmp)) `appOL`
+         x_code dst `snocOL`
+         fma213 (OpReg z_tmp) y_tmp dst
+       -- Only y lives in the destination: save it in y_tmp, which is
+       -- the register the FMA instruction reads below.
+       | dst == y_reg
+       = y_code `appOL`
+         unitOL (MOV rep (OpReg y_reg) (OpReg y_tmp)) `appOL`
+         z_code `appOL`
+         x_code dst `snocOL`
+         fma213 (OpReg z_reg) y_tmp dst
+       -- Only z lives in the destination: save it in z_tmp.
+       | dst == z_reg
+       = y_code `appOL`
+         z_code `appOL`
+         unitOL (MOV rep (OpReg z_reg) (OpReg z_tmp)) `appOL`
+         x_code dst `snocOL`
+         fma213 (OpReg z_tmp) y_reg dst
+       -- No clash: use the operand registers directly.
+       | otherwise
+       = y_code `appOL`
+         z_code `appOL`
+         x_code dst `snocOL`
+         fma213 (OpReg z_reg) y_reg dst
+  return (Any rep code)
+
-----------
trivialUCode :: Format -> (Operand -> Instr)
diff --git a/compiler/GHC/CmmToAsm/X86/Instr.hs b/compiler/GHC/CmmToAsm/X86/Instr.hs
index ccb3ce09ba..b4e93a1c5d 100644
--- a/compiler/GHC/CmmToAsm/X86/Instr.hs
+++ b/compiler/GHC/CmmToAsm/X86/Instr.hs
@@ -12,6 +12,7 @@ module GHC.CmmToAsm.X86.Instr
( Instr(..)
, Operand(..)
, PrefetchVariant(..)
+ , FMAPermutation(..)
, JumpDest(..)
, getJumpDestBlockId
, canShortcut
@@ -272,6 +273,10 @@ data Instr
| CVTSI2SS Format Operand Reg -- I32/I64 to F32
| CVTSI2SD Format Operand Reg -- I32/I64 to F64
+ -- | FMA3 fused multiply-add operations.
+ | FMA3 Format FMASign FMAPermutation Operand Reg Reg
+ -- src1 (r/m), src2 (r), dst (r)
+
-- use ADD, SUB, and SQRT for arithmetic. In both cases, operands
-- are Operand Reg.
@@ -351,7 +356,7 @@ data Operand
| OpImm Imm -- immediate value
| OpAddr AddrMode -- memory reference
-
+data FMAPermutation = FMA132 | FMA213 | FMA231
-- | Returns which registers are read and written as a (read, written)
-- pair.
@@ -438,6 +443,8 @@ regUsageOfInstr platform instr
PDEP _ src mask dst -> mkRU (use_R src $ use_R mask []) [dst]
PEXT _ src mask dst -> mkRU (use_R src $ use_R mask []) [dst]
+ FMA3 _ _ _ src1 src2 dst -> usageFMA src1 src2 dst
+
-- note: might be a better way to do this
PREFETCH _ _ src -> mkRU (use_R src []) []
LOCK i -> regUsageOfInstr platform i
@@ -482,6 +489,15 @@ regUsageOfInstr platform instr
usageRMM (OpReg src) (OpAddr ea) (OpReg reg) = mkRU (use_EA ea [src, reg]) [reg]
usageRMM _ _ _ = panic "X86.RegInfo.usageRMM: no match"
+ -- 3 operand form of FMA instructions.
+ --
+ -- The first operand may be a register or a memory address; the other two
+ -- are registers. The destination register is reported as both read and
+ -- written. Any other kind of first operand (e.g. an immediate) panics.
+ usageFMA :: Operand -> Reg -> Reg -> RegUsage
+ usageFMA (OpReg src1) src2 dst
+ = mkRU [src1, src2, dst] [dst]
+ usageFMA (OpAddr ea1) src2 dst
+ = mkRU (use_EA ea1 [src2, dst]) [dst]
+ usageFMA _ _ _
+ = panic "X86.RegInfo.usageFMA: no match"
+
-- 1 operand form; operand Modified
usageM :: Operand -> RegUsage
usageM (OpReg reg) = mkRU [reg] [reg]
@@ -561,6 +577,8 @@ patchRegsOfInstr instr env
JMP op regs -> JMP (patchOp op) regs
JMP_TBL op ids s lbl -> JMP_TBL (patchOp op) ids s lbl
+ FMA3 fmt perm var x1 x2 x3 -> patch3 (FMA3 fmt perm var) x1 x2 x3
+
-- literally only support storing the top x87 stack value st(0)
X87Store fmt dst -> X87Store fmt (lookupAddr dst)
@@ -612,6 +630,8 @@ patchRegsOfInstr instr env
patch1 insn op = insn $! patchOp op
patch2 :: (Operand -> Operand -> a) -> Operand -> Operand -> a
patch2 insn src dst = (insn $! patchOp src) $! patchOp dst
+ -- Patch an instruction taking one operand and two registers: the operand
+ -- goes through 'patchOp', both registers through 'env'. Each patched
+ -- argument is forced ('$!') before being passed to the constructor.
+ patch3 :: (Operand -> Reg -> Reg -> a) -> Operand -> Reg -> Reg -> a
+ patch3 insn src1 src2 dst = ((insn $! patchOp src1) $! env src2) $! env dst
patchOp (OpReg reg) = OpReg $! env reg
patchOp (OpImm imm) = OpImm imm
diff --git a/compiler/GHC/CmmToAsm/X86/Ppr.hs b/compiler/GHC/CmmToAsm/X86/Ppr.hs
index 4a8f55fdf0..0d649f2efb 100644
--- a/compiler/GHC/CmmToAsm/X86/Ppr.hs
+++ b/compiler/GHC/CmmToAsm/X86/Ppr.hs
@@ -838,6 +838,14 @@ pprInstr platform i = case i of
FDIV format op1 op2
-> pprFormatOpOp (text "div") format op1 op2
+ FMA3 format var perm op1 op2 op3
+ -> let mnemo = case var of
+ FMAdd -> text "vfmadd"
+ FMSub -> text "vfmsub"
+ FNMAdd -> text "vfnmadd"
+ FNMSub -> text "vfnmsub"
+ in pprFormatOpRegReg (mnemo <> pprFMAPermutation perm) format op1 op2 op3
+
SQRT format op1 op2
-> pprFormatOpReg (text "sqrt") format op1 op2
@@ -968,6 +976,21 @@ pprInstr platform i = case i of
pprOperand platform format op2
]
+ -- Print a three-operand instruction whose first operand is a general
+ -- 'Operand' and whose other two operands are registers: the mnemonic
+ -- (with its format), then the three operands separated by commas, in the
+ -- order given.
+ pprFormatOpRegReg :: Line doc -> Format -> Operand -> Reg -> Reg -> doc
+ pprFormatOpRegReg name format op1 op2 op3
+ = line $ hcat [
+ pprMnemonic name format,
+ pprOperand platform format op1,
+ comma,
+ pprReg platform format op2,
+ comma,
+ pprReg platform format op3
+ ]
+
+ -- The three-digit operand-order suffix appended to the FMA3 mnemonic
+ -- (e.g. @vfmadd@ ++ @213@).
+ pprFMAPermutation :: FMAPermutation -> Line doc
+ pprFMAPermutation FMA132 = text "132"
+ pprFMAPermutation FMA213 = text "213"
+ pprFMAPermutation FMA231 = text "231"
pprOpOp :: Line doc -> Format -> Operand -> Operand -> doc
pprOpOp name format op1 op2
diff --git a/compiler/GHC/CmmToC.hs b/compiler/GHC/CmmToC.hs
index d4ac5cd4e4..fb8774c77c 100644
--- a/compiler/GHC/CmmToC.hs
+++ b/compiler/GHC/CmmToC.hs
@@ -529,6 +529,11 @@ machOpNeedsCast platform mop args
pprMachOpApp' :: Platform -> MachOp -> [CmmExpr] -> SDoc
pprMachOpApp' platform mop args
= case args of
+
+ -- ternary
+ args@[_,_,_] ->
+ pprMachOp_for_C platform mop <> parens (pprWithCommas pprArg args)
+
-- dyadic
[x,y] -> pprArg x <+> pprMachOp_for_C platform mop <+> pprArg y
@@ -711,13 +716,28 @@ pprMachOp_for_C platform mop = case mop of
MO_U_Quot _ -> char '/'
MO_U_Rem _ -> char '%'
- -- & Floating-point operations
+ -- Floating-point operations
MO_F_Add _ -> char '+'
MO_F_Sub _ -> char '-'
MO_F_Neg _ -> char '-'
MO_F_Mul _ -> char '*'
MO_F_Quot _ -> char '/'
+        -- Floating-point fused multiply-add operations
+        --
+        -- Only the plain 'FMAdd' variant is printed here, as a call to the
+        -- C functions @fmaf@ (32-bit) or @fma@ (64-bit). Any other width,
+        -- or any other sign variant reaching this point, is a panic.
+        MO_FMA FMAdd w ->
+          case w of
+            W32 -> text "fmaf"
+            W64 -> text "fma"
+            _   ->
+              pprTrace "offending mop:"
+                (text "FMAdd")
+                (panic $ "PprC.pprMachOp_for_C: FMAdd unsupported"
+                       ++ " at width " ++ show w)
+        MO_FMA var _width ->
+          pprTrace "offending mop:"
+            (text $ "FMA " ++ show var)
+            (panic $ "PprC.pprMachOp_for_C: should have been handled earlier!")
+
-- Signed comparisons
MO_S_Ge _ -> text ">="
MO_S_Le _ -> text "<="
diff --git a/compiler/GHC/CmmToLlvm/CodeGen.hs b/compiler/GHC/CmmToLlvm/CodeGen.hs
index 1d658b359e..ba72aa1417 100644
--- a/compiler/GHC/CmmToLlvm/CodeGen.hs
+++ b/compiler/GHC/CmmToLlvm/CodeGen.hs
@@ -1469,6 +1469,9 @@ genMachOp _ op [x] = case op of
MO_F_Sub _ -> panicOp
MO_F_Mul _ -> panicOp
MO_F_Quot _ -> panicOp
+
+ MO_FMA _ _ -> panicOp
+
MO_F_Eq _ -> panicOp
MO_F_Ne _ -> panicOp
MO_F_Ge _ -> panicOp
@@ -1652,6 +1655,8 @@ genMachOp_slow opt op [x, y] = case op of
MO_F_Mul _ -> genBinMach LM_MO_FMul
MO_F_Quot _ -> genBinMach LM_MO_FDiv
+ MO_FMA _ _ -> panicOp
+
MO_And _ -> genBinMach LM_MO_And
MO_Or _ -> genBinMach LM_MO_Or
MO_Xor _ -> genBinMach LM_MO_Xor
@@ -1785,8 +1790,27 @@ genMachOp_slow opt op [x, y] = case op of
panicOp = panic $ "LLVM.CodeGen.genMachOp_slow: unary op encountered"
++ "with two arguments! (" ++ show op ++ ")"
--- More than two expression, invalid!
-genMachOp_slow _ _ _ = panic "genMachOp: More than 2 expressions in MachOp!"
+-- Ternary MachOps: only the fused multiply-add family is supported.
+-- All three operands must have the same LLVM type; the result has the type
+-- of the first operand.
+genMachOp_slow _opt op [x, y, z] = case op of
+    MO_FMA var _ -> triLlvmOp getVarType (FMAOp var)
+    _            -> panicOp
+    where
+        -- Evaluate the three operands and combine them with 'mkOp',
+        -- giving the result the type computed by 'ty' from the first operand.
+        triLlvmOp ty mkOp = do
+          platform <- getPlatform
+          runExprData $ do
+            vx <- exprToVarW x
+            vy <- exprToVarW y
+            vz <- exprToVarW z
+
+            if | getVarType vx == getVarType vy
+               , getVarType vx == getVarType vz
+               -> doExprW (ty vx) $ mkOp vx vy vz
+               | otherwise
+               -> pprPanic "triLlvmOp types" (pdoc platform x $$ pdoc platform y $$ pdoc platform z)
+        panicOp = panic $ "LLVM.CodeGen.genMachOp_slow: non-ternary op encountered"
+                       ++ " with three arguments! (" ++ show op ++ ")"
+
+-- More than three expressions, invalid!
+genMachOp_slow _ _ _ = panic "genMachOp_slow: More than 3 expressions in MachOp!"
-- | Handle CmmLoad expression.
diff --git a/compiler/GHC/Core/Opt/ConstantFold.hs b/compiler/GHC/Core/Opt/ConstantFold.hs
index 42ced5a86a..fbccce15bb 100644
--- a/compiler/GHC/Core/Opt/ConstantFold.hs
+++ b/compiler/GHC/Core/Opt/ConstantFold.hs
@@ -65,6 +65,9 @@ import GHC.Builtin.Types
import GHC.Builtin.Types.Prim
import GHC.Builtin.Names
+import GHC.Cmm.MachOp ( FMASign(..) )
+import GHC.Cmm.Type ( Width(..) )
+
import GHC.Data.FastString
import GHC.Data.Maybe ( orElse )
@@ -677,6 +680,11 @@ primOpRules nm = \case
FloatMulOp -> mkPrimOpRule nm 2 [ binaryLit (floatOp2 (*))
, identity onef
, strengthReduction twof FloatAddOp ]
+ FloatFMAdd -> mkPrimOpRule nm 3 (fmaRules FMAdd W32)
+ FloatFMSub -> mkPrimOpRule nm 3 (fmaRules FMSub W32)
+ FloatFNMAdd -> mkPrimOpRule nm 3 (fmaRules FNMAdd W32)
+ FloatFNMSub -> mkPrimOpRule nm 3 (fmaRules FNMSub W32)
+
-- zeroElem zerof doesn't hold because of NaN
FloatDivOp -> mkPrimOpRule nm 2 [ guardFloatDiv >> binaryLit (floatOp2 (/))
, rightIdentity onef ]
@@ -692,6 +700,10 @@ primOpRules nm = \case
DoubleMulOp -> mkPrimOpRule nm 2 [ binaryLit (doubleOp2 (*))
, identity oned
, strengthReduction twod DoubleAddOp ]
+ DoubleFMAdd -> mkPrimOpRule nm 3 (fmaRules FMAdd W64)
+ DoubleFMSub -> mkPrimOpRule nm 3 (fmaRules FMSub W64)
+ DoubleFNMAdd -> mkPrimOpRule nm 3 (fmaRules FNMAdd W64)
+ DoubleFNMSub -> mkPrimOpRule nm 3 (fmaRules FNMSub W64)
-- zeroElem zerod doesn't hold because of NaN
DoubleDivOp -> mkPrimOpRule nm 2 [ guardDoubleDiv >> binaryLit (doubleOp2 (/))
, rightIdentity oned ]
@@ -1139,6 +1151,150 @@ doubleDecodeOp _ _
= Nothing
--------------------------
+
+-- | All constant folding rules that apply to a fused multiply-add primop
+-- with the given sign convention and floating-point width.
+--
+-- The list contains, in this order: literal folding ('fmaLit'),
+-- elimination of a zero addend ('fmaZero_z') and elimination of
+-- a @±1@ factor ('fmaOne').
+fmaRules :: FMASign -> Width -> [RuleM CoreExpr]
+fmaRules signs width =
+  map (\ mkRule -> mkRule signs width)
+    [ fmaLit, fmaZero_z, fmaOne ]
+
+-- | Compute @a * b + c@ when @a@, @b@, @c@ are all literals.
+--
+-- The arithmetic is performed on the 'Rational' payloads of the three
+-- literals (after 'convFloating'), and the result is handed to 'mkFloatVal'
+-- or 'mkDoubleVal' depending on the 'Width'.
+-- The rule does not fire unless all three literals have the constructor
+-- matching the 'Width' ('LitFloat' for 'W32', 'LitDouble' for 'W64').
+fmaLit :: FMASign -> Width -> RuleM CoreExpr
+fmaLit signs width = do
+  env <- getRuleOpts
+  [Lit l1, Lit l2, Lit l3] <- getArgs
+  liftMaybe $
+    foldLits env
+      (convFloating env l1)
+      (convFloating env l2)
+      (convFloating env l3)
+
+  where
+    -- The arithmetic selected by 'signs'.
+    fused a b c =
+      case signs of
+        FMAdd  -> a * b + c
+        FMSub  -> a * b - c
+        FNMAdd -> negate ( a * b ) + c
+        FNMSub -> negate ( a * b ) - c
+
+    -- Only fold when the literal constructors agree with the width.
+    foldLits env (LitFloat a) (LitFloat b) (LitFloat c)
+      | W32 <- width
+      = Just $ mkFloatVal env $ fused a b c
+    foldLits env (LitDouble a) (LitDouble b) (LitDouble c)
+      | W64 <- width
+      = Just $ mkDoubleVal env $ fused a b c
+    foldLits _ _ _ _
+      = Nothing
+
+-- | @x * y + 0 = x * y@.
+--
+-- Fires when the third argument is a literal zero of the given 'Width'.
+-- The addend is dropped and the result is:
+--
+--   * 'FMAdd', 'FMSub':   @x * y@,
+--   * 'FNMAdd', 'FNMSub': @negate (x * y)@.
+fmaZero_z :: FMASign -> Width -> RuleM CoreExpr
+fmaZero_z signs width = do
+  [x, y, Lit z] <- getArgs
+  let
+    -- TODO: we should additionally check the sign of z.
+    --   FMAdd, FNMAdd: should be -0.0.
+    --   FMSub, FNMSub: should be +0.0.
+    ok =
+      case width of
+        W32
+          | LitFloat 0 <- z
+          -> True
+        W64
+          | LitDouble 0 <- z
+          -> True
+        _ -> False
+    neg = case width of
+      W32 -> FloatNegOp
+      W64 -> DoubleNegOp
+      _   -> panic "fmaZero_z: not Float# or Double#"
+    mul = case width of
+      W32 -> FloatMulOp
+      W64 -> DoubleMulOp
+      _   -> panic "fmaZero_z: not Float# or Double#"
+    -- The plain (non-fused) product @x * y@.
+    prod = Var (primOpId mul) `App` x `App` y
+  if ok
+    then return $ case signs of
+      FMAdd  -> prod
+      FMSub  -> prod
+      FNMAdd -> Var (primOpId neg) `App` prod
+      FNMSub -> Var (primOpId neg) `App` prod
+    else mzero
+
+-- | @±1 * y + z ==> z ± y@ and @x * ±1 + z ==> z ± x@.
+--
+-- When one of the two factors is the literal @1@ or @-1@ (the first
+-- argument is checked before the second), the multiplication is dropped.
+-- Writing @t@ for the other factor, the operation is rewritten as follows,
+-- depending on the overall sign of the product term and of @z@:
+--
+--   * @  t + z@  ==>  @add t z@
+--   * @- t + z@  ==>  @sub z t@
+--   * @  t - z@  ==>  @sub t z@
+--   * @- t - z@  ==>  @neg (add t z)@
+fmaOne :: FMASign -> Width -> RuleM CoreExpr
+fmaOne signs width = do
+  [x, y, z] <- getArgs
+  let
+    -- 'Just False' for @1@, 'Just True' for @-1@, 'Nothing' otherwise.
+    posNegOne_maybe :: Rational -> Maybe Bool
+    posNegOne_maybe i
+      | i == 1
+      = Just False
+      | i == -1
+      = Just True
+      | otherwise
+      = Nothing
+    -- The sign of the unit factor, paired with the remaining factor.
+    ok =
+      case width of
+        W32
+          | Lit (LitFloat i) <- x
+          , Just sgn <- posNegOne_maybe i
+          -> Just (sgn, y)
+          | Lit (LitFloat i) <- y
+          , Just sgn <- posNegOne_maybe i
+          -> Just (sgn, x)
+        W64
+          | Lit (LitDouble i) <- x
+          , Just sgn <- posNegOne_maybe i
+          -> Just (sgn, y)
+          | Lit (LitDouble i) <- y
+          , Just sgn <- posNegOne_maybe i
+          -> Just (sgn, x)
+        _ -> Nothing
+    neg = case width of
+      W32 -> FloatNegOp
+      W64 -> DoubleNegOp
+      _   -> panic "fmaOne: not Float# or Double#"
+    add = case width of
+      W32 -> FloatAddOp
+      W64 -> DoubleAddOp
+      _   -> panic "fmaOne: not Float# or Double#"
+    sub = case width of
+      W32 -> FloatSubOp
+      W64 -> DoubleSubOp
+      _   -> panic "fmaOne: not Float# or Double#"
+  case ok of
+    Nothing -> mzero
+    Just (sgn, t) -> return $
+      let
+        -- Is the product term negated overall, and is @z@ subtracted?
+        -- The FNM* operations flip the sign contributed by the unit factor.
+        (prodNegated, zNegated) =
+          case signs of
+            FMAdd  -> (sgn    , False)
+            FMSub  -> (sgn    , True )
+            FNMAdd -> (not sgn, False)
+            FNMSub -> (not sgn, True )
+      in
+        -- Total match on both flags: exhaustiveness is checked by the
+        -- compiler, so no runtime panic fallback is needed.
+        case (prodNegated, zNegated) of
+          -- t + z
+          (False, False) -> Var (primOpId add) `App` t `App` z
+          -- - t + z
+          (True , False) -> Var (primOpId sub) `App` z `App` t
+          -- t - z
+          (False, True ) -> Var (primOpId sub) `App` t `App` z
+          -- - t - z
+          (True , True ) -> Var (primOpId neg) `App` (Var (primOpId add) `App` t `App` z)
+
+
+--------------------------
{- Note [The litEq rule: converting equality to case]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
This stuff turns
diff --git a/compiler/GHC/Driver/Config/StgToCmm.hs b/compiler/GHC/Driver/Config/StgToCmm.hs
index 283ece1d50..f738974c46 100644
--- a/compiler/GHC/Driver/Config/StgToCmm.hs
+++ b/compiler/GHC/Driver/Config/StgToCmm.hs
@@ -1,3 +1,6 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE MultiWayIf #-}
+
module GHC.Driver.Config.StgToCmm
( initStgToCmmConfig
) where
@@ -6,6 +9,7 @@ import GHC.Prelude.Basic
import GHC.StgToCmm.Config
+import GHC.Cmm.MachOp ( FMASign(..))
import GHC.Driver.Backend
import GHC.Driver.Session
import GHC.Platform
@@ -50,6 +54,24 @@ initStgToCmmConfig dflags mod = StgToCmmConfig
, stgToCmmAllowQuotRemInstr = ncg && (x86ish || ppc)
, stgToCmmAllowQuotRem2 = (ncg && (x86ish || ppc)) || llvm
, stgToCmmAllowExtendedAddSubInstrs = (ncg && (x86ish || ppc)) || llvm
+ , stgToCmmAllowFMAInstr =
+ if
+ | not (isFmaEnabled dflags)
+ || not (ncg || llvm)
+ -- If we're not using the native code generator or LLVM,
+ -- fall back to the generic implementation.
+ || platformArch platform == ArchWasm32
+ -- WASM doesn't support native FMA instructions (at the time of writing).
+ -> const False
+
+ -- FNMSub and FNMAdd have different semantics on PowerPC,
+ -- so we avoid using them.
+ | ppc
+ -> \ case { FMAdd -> True; FMSub -> True; _ -> False }
+
+ | otherwise
+ -> const True
+
, stgToCmmAllowIntMul2Instr = (ncg && x86ish) || llvm
-- SIMD flags
, stgToCmmVecInstrsErr = vec_err
diff --git a/compiler/GHC/Driver/Pipeline/Execute.hs b/compiler/GHC/Driver/Pipeline/Execute.hs
index 7694975e80..84113df8eb 100644
--- a/compiler/GHC/Driver/Pipeline/Execute.hs
+++ b/compiler/GHC/Driver/Pipeline/Execute.hs
@@ -1028,6 +1028,7 @@ llvmOptions llvm_config dflags =
++ ["+avx512cd"| isAvx512cdEnabled dflags ]
++ ["+avx512er"| isAvx512erEnabled dflags ]
++ ["+avx512pf"| isAvx512pfEnabled dflags ]
+ ++ ["+fma" | isFmaEnabled dflags ]
++ ["+bmi" | isBmiEnabled dflags ]
++ ["+bmi2" | isBmi2Enabled dflags ]
diff --git a/compiler/GHC/Driver/Session.hs b/compiler/GHC/Driver/Session.hs
index 6a7ba74477..52361abb09 100644
--- a/compiler/GHC/Driver/Session.hs
+++ b/compiler/GHC/Driver/Session.hs
@@ -208,6 +208,7 @@ module GHC.Driver.Session (
isAvx512erEnabled,
isAvx512fEnabled,
isAvx512pfEnabled,
+ isFmaEnabled,
-- * Linker/compiler information
LinkerInfo(..),
@@ -718,6 +719,7 @@ data DynFlags = DynFlags {
avx512er :: Bool, -- Enable AVX-512 Exponential and Reciprocal Instructions.
avx512f :: Bool, -- Enable AVX-512 instructions.
avx512pf :: Bool, -- Enable AVX-512 PreFetch Instructions.
+ fma :: Bool, -- ^ Enable FMA instructions.
-- | Run-time linker information (what options we need, etc.)
rtldInfo :: IORef (Maybe LinkerInfo),
@@ -1303,6 +1305,7 @@ defaultDynFlags mySettings =
avx512er = False,
avx512f = False,
avx512pf = False,
+ fma = False,
rtldInfo = panic "defaultDynFlags: no rtldInfo",
rtccInfo = panic "defaultDynFlags: no rtccInfo",
rtasmInfo = panic "defaultDynFlags: no rtasmInfo",
@@ -2730,6 +2733,7 @@ dynamic_flags_deps = [
, make_ord_flag defGhcFlag "mavx512f" (noArg (\d -> d { avx512f = True }))
, make_ord_flag defGhcFlag "mavx512pf" (noArg (\d ->
d { avx512pf = True }))
+ , make_ord_flag defGhcFlag "mfma" (noArg (\d -> d { fma = True }))
------ Plugin flags ------------------------------------------------
, make_ord_flag defGhcFlag "fplugin-opt" (hasArg addPluginModuleNameOption)
@@ -5008,7 +5012,7 @@ setUnsafeGlobalDynFlags dflags = do
-- -----------------------------------------------------------------------------
--- SSE and AVX
+-- SSE, AVX, FMA
isSse4_2Enabled :: DynFlags -> Bool
isSse4_2Enabled dflags = sseVersion dflags >= Just SSE42
@@ -5031,6 +5035,9 @@ isAvx512fEnabled dflags = avx512f dflags
isAvx512pfEnabled :: DynFlags -> Bool
isAvx512pfEnabled dflags = avx512pf dflags
+isFmaEnabled :: DynFlags -> Bool
+isFmaEnabled dflags = fma dflags
+
-- -----------------------------------------------------------------------------
-- BMI2
@@ -5046,6 +5053,8 @@ isBmi2Enabled dflags = case platformArch (targetPlatform dflags) of
ArchX86 -> bmiVersion dflags >= Just BMI2
_ -> False
+-- -----------------------------------------------------------------------------
+
-- | Indicate if cost-centre profiling is enabled
sccProfilingEnabled :: DynFlags -> Bool
sccProfilingEnabled dflags = profileIsProfiling (targetProfile dflags)
diff --git a/compiler/GHC/Llvm/Ppr.hs b/compiler/GHC/Llvm/Ppr.hs
index 36bfdf3405..f0f09a46ef 100644
--- a/compiler/GHC/Llvm/Ppr.hs
+++ b/compiler/GHC/Llvm/Ppr.hs
@@ -40,6 +40,7 @@ import GHC.Llvm.Types
import Data.List ( intersperse )
import GHC.Utils.Outputable
+import GHC.Cmm.MachOp ( FMASign(..), pprFMASign )
import GHC.CmmToLlvm.Config
import GHC.Utils.Panic
import GHC.Types.Unique
@@ -288,6 +289,7 @@ ppLlvmExpression opts expr
AtomicRMW aop tgt src ordering -> ppAtomicRMW opts aop tgt src ordering
CmpXChg addr old new s_ord f_ord -> ppCmpXChg opts addr old new s_ord f_ord
Phi tp predecessors -> ppPhi opts tp predecessors
+ FMAOp op x y z -> pprFMAOp opts op x y z
Asm asm c ty v se sk -> ppAsm opts asm c ty v se sk
MExpr meta expr -> ppMetaAnnotExpr opts meta expr
{-# SPECIALIZE ppLlvmExpression :: LlvmCgConfig -> LlvmExpression -> SDoc #-}
@@ -374,6 +376,12 @@ ppCmpOp opts op left right =
{-# SPECIALIZE ppCmpOp :: LlvmCgConfig -> LlvmCmpOp -> LlvmVar -> LlvmVar -> SDoc #-}
{-# SPECIALIZE ppCmpOp :: LlvmCgConfig -> LlvmCmpOp -> LlvmVar -> LlvmVar -> HLine #-} -- see Note [SPECIALIZE to HDoc] in GHC.Utils.Outputable
+pprFMAOp :: IsLine doc => LlvmCgConfig -> FMASign -> LlvmVar -> LlvmVar -> LlvmVar -> doc
+pprFMAOp opts signs x y z =
+ pprFMASign signs <+> ppLlvmType (getVarType x)
+ <+> ppName opts x <> comma
+ <+> ppName opts y <> comma
+ <+> ppName opts z
ppAssignment :: IsLine doc => LlvmCgConfig -> LlvmVar -> doc -> doc
ppAssignment opts var expr = ppName opts var <+> equals <+> expr
diff --git a/compiler/GHC/Llvm/Syntax.hs b/compiler/GHC/Llvm/Syntax.hs
index 882cb0660b..390380ad17 100644
--- a/compiler/GHC/Llvm/Syntax.hs
+++ b/compiler/GHC/Llvm/Syntax.hs
@@ -10,6 +10,7 @@ import GHC.Llvm.MetaData
import GHC.Llvm.Types
import GHC.Types.Unique
+import GHC.Cmm.MachOp ( FMASign(..) )
-- | Block labels
type LlvmBlockId = Unique
@@ -338,6 +339,8 @@ data LlvmExpression
-}
| Asm LMString LMString LlvmType [LlvmVar] Bool Bool
+ | FMAOp FMASign LlvmVar LlvmVar LlvmVar
+
{- |
A LLVM expression with metadata attached to it.
-}
diff --git a/compiler/GHC/Llvm/Types.hs b/compiler/GHC/Llvm/Types.hs
index f80b261584..ce2a82ce34 100644
--- a/compiler/GHC/Llvm/Types.hs
+++ b/compiler/GHC/Llvm/Types.hs
@@ -775,7 +775,6 @@ ppLlvmCastOp LM_Bitcast = text "bitcast"
{-# SPECIALIZE ppLlvmCastOp :: LlvmCastOp -> SDoc #-}
{-# SPECIALIZE ppLlvmCastOp :: LlvmCastOp -> HLine #-} -- see Note [SPECIALIZE to HDoc] in GHC.Utils.Outputable
-
-- -----------------------------------------------------------------------------
-- * Floating point conversion
--
diff --git a/compiler/GHC/StgToCmm/Config.hs b/compiler/GHC/StgToCmm/Config.hs
index f2bd349ae7..1465cde2e0 100644
--- a/compiler/GHC/StgToCmm/Config.hs
+++ b/compiler/GHC/StgToCmm/Config.hs
@@ -11,6 +11,7 @@ import GHC.Unit.Module
import GHC.Utils.Outputable
import GHC.Utils.TmpFs
+import GHC.Cmm.MachOp ( FMASign(..) )
import GHC.Prelude
@@ -66,6 +67,7 @@ data StgToCmmConfig = StgToCmmConfig
, stgToCmmAllowQuotRem2 :: !Bool -- ^ Allowed to generate QuotRem
, stgToCmmAllowExtendedAddSubInstrs :: !Bool -- ^ Allowed to generate AddWordC, SubWordC, Add2, etc.
, stgToCmmAllowIntMul2Instr :: !Bool -- ^ Allowed to generate IntMul2 instruction
+ , stgToCmmAllowFMAInstr :: FMASign -> Bool -- ^ Allowed to generate FMA instruction
, stgToCmmTickyAP :: !Bool -- ^ Disable use of precomputed standard thunks.
------------------------------ SIMD flags ------------------------------------
-- Each of these flags checks vector compatibility with the backend requested
diff --git a/compiler/GHC/StgToCmm/Prim.hs b/compiler/GHC/StgToCmm/Prim.hs
index f4a1924f19..d1f6bb0191 100644
--- a/compiler/GHC/StgToCmm/Prim.hs
+++ b/compiler/GHC/StgToCmm/Prim.hs
@@ -1395,6 +1395,11 @@ emitPrimOp cfg primop =
DoubleDivOp -> \args -> opTranslate args (MO_F_Quot W64)
DoubleNegOp -> \args -> opTranslate args (MO_F_Neg W64)
+ DoubleFMAdd -> fmaOp FMAdd W64
+ DoubleFMSub -> fmaOp FMSub W64
+ DoubleFNMAdd -> fmaOp FNMAdd W64
+ DoubleFNMSub -> fmaOp FNMSub W64
+
-- Float ops
FloatEqOp -> \args -> opTranslate args (MO_F_Eq W32)
@@ -1410,6 +1415,11 @@ emitPrimOp cfg primop =
FloatDivOp -> \args -> opTranslate args (MO_F_Quot W32)
FloatNegOp -> \args -> opTranslate args (MO_F_Neg W32)
+ FloatFMAdd -> fmaOp FMAdd W32
+ FloatFMSub -> fmaOp FMSub W32
+ FloatFNMAdd -> fmaOp FNMAdd W32
+ FloatFNMSub -> fmaOp FNMSub W32
+
-- Vector ops
(VecAddOp FloatVec n w) -> \args -> opTranslate args (MO_VF_Add n w)
@@ -1735,6 +1745,27 @@ emitPrimOp cfg primop =
allowExtAdd = stgToCmmAllowExtendedAddSubInstrs cfg
allowInt2Mul = stgToCmmAllowIntMul2Instr cfg
+ allowFMA = stgToCmmAllowFMAInstr cfg
+
+ fmaOp :: FMASign -> Width -> [CmmActual] -> PrimopCmmEmit
+ fmaOp signs w args@[arg_x, arg_y, arg_z]
+ | allowFMA signs
+ = opTranslate args (MO_FMA signs w)
+ | otherwise
+ = case signs of
+
+ -- For fused multiply-add x * y + z, we fall back to the C implementation.
+ FMAdd -> opIntoRegs $ \ [res] -> fmaCCall w res arg_x arg_y arg_z
+
+ -- Other fused multiply-add operations are implemented in terms of fmadd.
+ -- This is sound: it does not lose any precision.
+ FMSub -> fmaOp FMAdd w [arg_x, arg_y, neg arg_z]
+ FNMAdd -> fmaOp FMAdd w [neg arg_x, arg_y, arg_z]
+ FNMSub -> fmaOp FMAdd w [neg arg_x, arg_y, neg arg_z]
+ where
+ neg x = CmmMachOp (MO_F_Neg w) [x]
+ fmaOp _ _ _ = panic "fmaOp: wrong number of arguments (expected 3)"
+
data PrimopCmmEmit
-- | Out of line fake primop that's actually just a foreign call to other
-- (presumably) C--.
@@ -2023,6 +2054,19 @@ genericIntMul2Op [res_c, res_h, res_l] both_args@[arg_x, arg_y]
]
genericIntMul2Op _ _ = panic "genericIntMul2Op"
+-- | Emit a foreign call that computes the fused multiply-add @x * y + z@,
+-- storing the result in @res@.
+--
+-- The callee is the external symbol @fmaf@ for 'W32' and @fma@ for 'W64';
+-- any other width is a compiler bug and panics.
+fmaCCall :: Width -> CmmFormal -> CmmActual -> CmmActual -> CmmActual -> FCode ()
+fmaCCall width res arg_x arg_y arg_z =
+  emitCCall
+    [(res,NoHint)]
+    (CmmLit (CmmLabel fma_lbl))
+    [(arg_x,NoHint), (arg_y,NoHint), (arg_z,NoHint)]
+  where
+    fma_lbl = mkForeignLabel fma_op Nothing ForeignLabelInExternalPackage IsFunction
+    fma_op = case width of
+      W32 -> fsLit "fmaf"
+      W64 -> fsLit "fma"
+      _   -> panic ("fmaCCall: " ++ show width)
+
+
------------------------------------------------------------------------------
-- Helpers for translating various minor variants of array indexing.
diff --git a/compiler/GHC/StgToJS/Prim.hs b/compiler/GHC/StgToJS/Prim.hs
index 36f12e3409..c051318b22 100644
--- a/compiler/GHC/StgToJS/Prim.hs
+++ b/compiler/GHC/StgToJS/Prim.hs
@@ -488,6 +488,11 @@ genPrim prof bound ty op = case op of
DoubleDecode_2IntOp -> \[s,h,l,e] [x] -> PrimInline $ appT [s,h,l,e] "h$decodeDouble2Int" [x]
DoubleDecode_Int64Op -> \[s1,s2,e] [d] -> PrimInline $ appT [e,s1,s2] "h$decodeDoubleInt64" [d]
+ DoubleFMAdd -> unhandledPrimop op
+ DoubleFMSub -> unhandledPrimop op
+ DoubleFNMAdd -> unhandledPrimop op
+ DoubleFNMSub -> unhandledPrimop op
+
------------------------------ Float --------------------------------------------
FloatGtOp -> \[r] [x,y] -> PrimInline $ r |= if10 (x .>. y)
@@ -524,6 +529,11 @@ genPrim prof bound ty op = case op of
FloatToDoubleOp -> \[r] [x] -> PrimInline $ r |= x
FloatDecode_IntOp -> \[s,e] [x] -> PrimInline $ appT [s,e] "h$decodeFloatInt" [x]
+ FloatFMAdd -> unhandledPrimop op
+ FloatFMSub -> unhandledPrimop op
+ FloatFNMAdd -> unhandledPrimop op
+ FloatFNMSub -> unhandledPrimop op
+
------------------------------ Arrays -------------------------------------------
NewArrayOp -> \[r] [l,e] -> PrimInline $ r |= app "h$newArray" [l,e]
diff --git a/compiler/GHC/SysTools/Cpp.hs b/compiler/GHC/SysTools/Cpp.hs
index 61f70342a6..9410410d01 100644
--- a/compiler/GHC/SysTools/Cpp.hs
+++ b/compiler/GHC/SysTools/Cpp.hs
@@ -98,6 +98,9 @@ doCpp logger tmpfs dflags unit_env opts input_fn output_fn = do
[ "-D__SSE2__" | isSse2Enabled platform ] ++
[ "-D__SSE4_2__" | isSse4_2Enabled dflags ]
+ let fma_def =
+ [ "-D__FMA__" | isFmaEnabled dflags ]
+
let avx_defs =
[ "-D__AVX__" | isAvxEnabled dflags ] ++
[ "-D__AVX2__" | isAvx2Enabled dflags ] ++
@@ -140,6 +143,7 @@ doCpp logger tmpfs dflags unit_env opts input_fn output_fn = do
++ map GHC.SysTools.Option th_defs
++ map GHC.SysTools.Option hscpp_opts
++ map GHC.SysTools.Option sse_defs
+ ++ map GHC.SysTools.Option fma_def
++ map GHC.SysTools.Option avx_defs
++ map GHC.SysTools.Option io_manager_defs
++ mb_macro_include
diff --git a/docs/users_guide/9.8.1-notes.rst b/docs/users_guide/9.8.1-notes.rst
index e7e9acdf75..84d9105efd 100644
--- a/docs/users_guide/9.8.1-notes.rst
+++ b/docs/users_guide/9.8.1-notes.rst
@@ -142,6 +142,24 @@ Runtime system
- ``sameMutVar#``, ``sameTVar#``, ``sameMVar#``
- ``sameIOPort#``, ``eqStableName#``.
+- New primops for fused multiply-add operations. These primops combine a
+ multiplication and an addition, compiling to a single instruction when
+ the ``-mfma`` flag is enabled and the architecture supports it.
+
+ The new primops are ``fmaddFloat#, fmsubFloat#, fnmaddFloat#, fnmsubFloat# :: Float# -> Float# -> Float# -> Float#``
+ and ``fmaddDouble#, fmsubDouble#, fnmaddDouble#, fnmsubDouble# :: Double# -> Double# -> Double# -> Double#``.
+
+ These implement the following operations, while performing one single
+ rounding at the end, leading to a more accurate result:
+
+ - ``fmaddFloat# x y z``, ``fmaddDouble# x y z`` compute ``x * y + z``.
+ - ``fmsubFloat# x y z``, ``fmsubDouble# x y z`` compute ``x * y - z``.
+ - ``fnmaddFloat# x y z``, ``fnmaddDouble# x y z`` compute ``- x * y + z``.
+ - ``fnmsubFloat# x y z``, ``fnmsubDouble# x y z`` compute ``- x * y - z``.
+
+ Warning: on unsupported architectures, the software emulation provided by
+ the fallback to the C standard library is not guaranteed to be IEEE-compliant.
+
``ghc`` library
~~~~~~~~~~~~~~~
diff --git a/docs/users_guide/using.rst b/docs/users_guide/using.rst
index 787b6a0503..8de7dd3533 100644
--- a/docs/users_guide/using.rst
+++ b/docs/users_guide/using.rst
@@ -1732,6 +1732,24 @@ Some flags only make sense for particular target platforms.
:ref:`native code generator <native-code-gen>`. The resulting compiled
code will only run on processors that support BMI2 (Intel Haswell and newer, AMD Excavator, Zen and newer).
+.. ghc-flag:: -mfma
+ :shortdesc: Use native FMA instructions for fused multiply-add floating-point operations
+ :type: dynamic
+ :category: platform-options
+
+ :since: 9.8.1
+
+ Use native FMA instructions to implement the fused multiply-add floating-point
+ operations of the form ``x * y + z``.
+ This allows computing a multiplication and addition in a single instruction,
+ without an intermediate rounding step.
+ Supported architectures: X86 with the FMA3 instruction set (this includes
+ most consumer processors since 2013), PowerPC and AArch64.
+
+ When this flag is disabled, GHC falls back to the C implementation of fused
+ multiply-add, which might perform non-IEEE-compliant software emulation on
+ some platforms (depending on the implementation of the C standard library).
+
Haddock
-------
diff --git a/libraries/ghc-prim/changelog.md b/libraries/ghc-prim/changelog.md
index 1cf411c029..39e5face03 100644
--- a/libraries/ghc-prim/changelog.md
+++ b/libraries/ghc-prim/changelog.md
@@ -23,6 +23,24 @@
- `copyAddrToAddrNonOverlapping#`
- `setAddrRange#`
+- New primops for fused multiply-add operations. These primops combine a
+ multiplication and an addition, compiling to a single instruction when
+ the `-mfma` flag is enabled and the architecture supports it.
+
+ The new primops are `fmaddFloat#, fmsubFloat#, fnmaddFloat#, fnmsubFloat# :: Float# -> Float# -> Float# -> Float#`
+ and `fmaddDouble#, fmsubDouble#, fnmaddDouble#, fnmsubDouble# :: Double# -> Double# -> Double# -> Double#`.
+
+ These implement the following operations, while performing one single
+ rounding at the end, leading to a more accurate result:
+
+ - `fmaddFloat# x y z`, `fmaddDouble# x y z` compute `x * y + z`.
+ - `fmsubFloat# x y z`, `fmsubDouble# x y z` compute `x * y - z`.
+ - `fnmaddFloat# x y z`, `fnmaddDouble# x y z` compute `- x * y + z`.
+ - `fnmsubFloat# x y z`, `fnmsubDouble# x y z` compute `- x * y - z`.
+
+ Warning: on unsupported architectures, the software emulation provided by
+ the fallback to the C standard library is not guaranteed to be IEEE-compliant.
+
## 0.10.0
- Shipped with GHC 9.6.1
diff --git a/rts/RtsSymbols.c b/rts/RtsSymbols.c
index dee6c57f5e..d5ed5bb543 100644
--- a/rts/RtsSymbols.c
+++ b/rts/RtsSymbols.c
@@ -930,7 +930,6 @@ extern char **environ;
RTS_USER_SIGNALS_SYMBOLS \
RTS_INTCHAR_SYMBOLS
-
// 64-bit support functions in libgcc.a
#if defined(__GNUC__) && SIZEOF_VOID_P <= 4 && !defined(_ABIN32)
#define RTS_LIBGCC_SYMBOLS \
diff --git a/rts/StgPrimFloat.c b/rts/StgPrimFloat.c
index a8c266ae78..d105a0d76b 100644
--- a/rts/StgPrimFloat.c
+++ b/rts/StgPrimFloat.c
@@ -248,4 +248,3 @@ __decodeFloat_Int (I_ *man, I_ *exp, StgFloat flt)
*man = - *man;
}
}
-
diff --git a/testsuite/driver/cpu_features.py b/testsuite/driver/cpu_features.py
index 43500130db..f6c441d88f 100644
--- a/testsuite/driver/cpu_features.py
+++ b/testsuite/driver/cpu_features.py
@@ -10,6 +10,7 @@ SUPPORTED_CPU_FEATURES = {
# x86:
'sse', 'sse2', 'sse3', 'ssse3', 'sse4_1', 'sse4_2',
'avx1', 'avx2',
+ 'fma',
'popcnt', 'bmi1', 'bmi2'
}
@@ -46,6 +47,10 @@ def get_cpu_features():
check_feature('avx2_0', 'avx2')
return features
+ elif config.arch in [ 'powerpc', 'powerpc64' ]:
+ # Hardcode support for 'fma' on PowerPC
+ return [ 'fma' ]
+
else:
# TODO: Add {Open,Free}BSD support
print('get_cpu_features: Lacking support for your platform')
diff --git a/testsuite/tests/primops/should_run/FMA_ConstantFold.hs b/testsuite/tests/primops/should_run/FMA_ConstantFold.hs
new file mode 100644
index 0000000000..80bac3231c
--- /dev/null
+++ b/testsuite/tests/primops/should_run/FMA_ConstantFold.hs
@@ -0,0 +1,236 @@
+{-# LANGUAGE BlockArguments #-}
+{-# LANGUAGE LexicalNegation #-}
+{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE TypeApplications #-}
+
+module Main where
+
+import Control.Monad
+ ( unless )
+import Data.IORef
+ ( newIORef, readIORef, writeIORef )
+import GHC.Exts
+ ( Float(..), Float#, Double(..), Double#
+ , fmaddFloat# , fmsubFloat# , fnmaddFloat# , fnmsubFloat#
+ , fmaddDouble#, fmsubDouble#, fnmaddDouble#, fnmsubDouble#
+ )
+import GHC.Float
+ ( castFloatToWord32, castDoubleToWord64 )
+import System.Exit
+ ( exitFailure, exitSuccess )
+
+--------------------------------------------------------------------------------
+
+-- NB: This test tests constant folding for fused multiply-add operations.
+-- See FMA_Primops for the test of the primops.
+
+-- Want "on-the-nose" equality to make sure we properly distinguish 0.0 and -0.0.
+class StrictEq a where
+ strictlyEqual :: a -> a -> Bool
+instance StrictEq Float where
+ strictlyEqual x y = castFloatToWord32 x == castFloatToWord32 y
+instance StrictEq Double where
+ strictlyEqual x y = castDoubleToWord64 x == castDoubleToWord64 y
+
+class FMA a where
+ fmadd, fmsub, fnmadd, fnmsub :: a -> a -> a -> a
+instance FMA Float where
+ fmadd (F# x) (F# y) (F# z) = F# ( fmaddFloat# x y z)
+ fmsub (F# x) (F# y) (F# z) = F# ( fmsubFloat# x y z)
+ fnmadd (F# x) (F# y) (F# z) = F# (fnmaddFloat# x y z)
+ fnmsub (F# x) (F# y) (F# z) = F# (fnmsubFloat# x y z)
+ {-# INLINE fmadd #-}
+ {-# INLINE fnmadd #-}
+ {-# INLINE fmsub #-}
+ {-# INLINE fnmsub #-}
+instance FMA Double where
+ fmadd (D# x) (D# y) (D# z) = D# ( fmaddDouble# x y z)
+ fmsub (D# x) (D# y) (D# z) = D# ( fmsubDouble# x y z)
+ fnmadd (D# x) (D# y) (D# z) = D# (fnmaddDouble# x y z)
+ fnmsub (D# x) (D# y) (D# z) = D# (fnmsubDouble# x y z)
+ {-# INLINE fmadd #-}
+ {-# INLINE fnmadd #-}
+ {-# INLINE fmsub #-}
+ {-# INLINE fnmsub #-}
+
+main :: IO ()
+main = do
+
+ exit_ref <- newIORef False
+
+ let
+ it :: forall a. ( StrictEq a, Show a ) => String -> a -> a -> IO ()
+ it desc actual expected =
+ unless (actual `strictlyEqual` expected) do
+ writeIORef exit_ref True
+ putStrLn $ unlines
+ [ "FAIL " ++ desc
+ , " expected: " ++ show expected
+ , " actual: " ++ show actual ]
+
+ -- NB: throughout this test, we are using "round to nearest".
+
+ -- fmadd: x * y + z
+
+ -- Float
+ it "fmaddFloat#: sniff test"
+ ( fmadd @Float 2 3 1 ) 7
+
+ it "fmaddFloat#: excess precision"
+ ( fmadd @Float 0.99999994 1.00000012 -1 ) 5.96046377e-08
+
+ it "fmaddFloat#: +0 + +0 rounds properly"
+ ( fmadd @Float 1 0 0 ) 0
+
+ it "fmaddFloat#: +0 + -0 rounds properly"
+ ( fmadd @Float 1 0 -0 ) 0
+
+ it "fmaddFloat#: -0 + +0 rounds properly"
+ ( fmadd @Float 1 -0 0 ) 0
+
+ it "fmaddFloat#: -0 + -0 rounds properly"
+ ( fmadd @Float 1 -0 -0 ) -0
+
+ -- Double
+ it "fmaddDouble#: sniff test"
+ ( fmadd @Double 2 3 1 ) 7
+
+ it "fmaddDouble#: excess precision"
+ ( fmadd @Double 0.99999999999999989 1.0000000000000002 -1 ) 1.1102230246251563e-16
+
+ it "fmaddDouble#: +0 + +0 rounds properly"
+ ( fmadd @Double 1 0 0 ) 0
+
+ it "fmaddDouble#: +0 + -0 rounds properly"
+ ( fmadd @Double 1 0 -0 ) 0
+
+ it "fmaddDouble#: -0 + +0 rounds properly"
+ ( fmadd @Double 1 -0 0 ) 0
+
+ it "fmaddDouble#: -0 + -0 rounds properly"
+ ( fmadd @Double 1 -0 -0 ) -0
+
+ -- fmsub: x * y - z
+
+ -- Float
+ it "fmsubFloat#: sniff test"
+ ( fmsub @Float 2 3 1 ) 5.0
+
+ it "fmsubFloat#: excess precision"
+ ( fmsub @Float 0.99999994 1.00000012 1 ) 5.96046377e-08
+
+ it "fmsubFloat#: +0 + +0 rounds properly"
+ ( fmsub @Float 1 0 0 ) 0
+
+ it "fmsubFloat#: +0 + -0 rounds properly"
+ ( fmsub @Float 1 0 -0 ) 0
+
+ it "fmsubFloat#: -0 + +0 rounds properly"
+ ( fmsub @Float 1 -0 0 ) -0
+
+ it "fmsubFloat#: -0 + -0 rounds properly"
+ ( fmsub @Float 1 -0 -0 ) 0
+
+ -- Double
+ it "fmsubDouble#: sniff test"
+ ( fmsub @Double 2 3 1 ) 5.0
+
+ it "fmsubDouble#: excess precision"
+ ( fmsub @Double 0.99999999999999989 1.0000000000000002 1 ) 1.1102230246251563e-16
+
+ it "fmsubDouble#: +0 + +0 rounds properly"
+ ( fmsub @Double 1 0 0 ) 0
+
+ it "fmsubDouble#: +0 + -0 rounds properly"
+ ( fmsub @Double 1 0 -0 ) 0
+
+ it "fmsubDouble#: -0 + +0 rounds properly"
+ ( fmsub @Double 1 -0 0 ) -0
+
+ it "fmsubDouble#: -0 + -0 rounds properly"
+ ( fmsub @Double 1 -0 -0 ) 0
+
+ -- fnmadd: - x * y + z
+
+ -- Float
+ it "fnmaddFloat#: sniff test"
+ ( fnmadd @Float 2 3 1 ) -5.0
+
+ it "fnmaddFloat#: excess precision"
+ ( fnmadd @Float 0.99999994 1.00000012 1 ) -5.96046377e-08
+
+ it "fnmaddFloat#: +0 + +0 rounds properly"
+ ( fnmadd @Float 1 0 0 ) 0
+
+ it "fnmaddFloat#: +0 + -0 rounds properly"
+ ( fnmadd @Float 1 0 -0 ) -0
+
+ it "fnmaddFloat#: -0 + +0 rounds properly"
+ ( fnmadd @Float 1 -0 0 ) 0
+
+ it "fnmaddFloat#: -0 + -0 rounds properly"
+ ( fnmadd @Float 1 -0 -0 ) 0
+
+ -- Double
+ it "fnmaddDouble#: sniff test"
+ ( fnmadd @Double 2 3 1 ) -5.0
+
+ it "fnmaddDouble#: excess precision"
+ ( fnmadd @Double 0.99999999999999989 1.0000000000000002 1 ) -1.1102230246251563e-16
+
+ it "fnmaddDouble#: +0 + +0 rounds properly"
+ ( fnmadd @Double 1 0 0 ) 0
+
+ it "fnmaddDouble#: +0 + -0 rounds properly"
+ ( fnmadd @Double 1 0 -0 ) -0
+
+ it "fnmaddDouble#: -0 + +0 rounds properly"
+ ( fnmadd @Double 1 -0 0 ) 0
+
+ it "fnmaddDouble#: -0 + -0 rounds properly"
+ ( fnmadd @Double 1 -0 -0 ) 0
+
+ -- fnmsub: - x * y - z
+
+ -- Float
+ it "fnmsubFloat#: sniff test"
+ ( fnmsub @Float 2 3 1 ) -7
+
+ it "fnmsubFloat#: excess precision"
+ ( fnmsub @Float 0.99999994 1.00000012 -1 ) -5.96046377e-08
+
+ it "fnmsubFloat#: +0 + +0 rounds properly"
+ ( fnmsub @Float 1 0 0 ) -0
+
+ it "fnmsubFloat#: +0 + -0 rounds properly"
+ ( fnmsub @Float 1 0 -0 ) 0
+
+ it "fnmsubFloat#: -0 + +0 rounds properly"
+ ( fnmsub @Float 1 -0 0 ) 0
+
+ it "fnmsubFloat#: -0 + -0 rounds properly"
+ ( fnmsub @Float 1 -0 -0 ) 0
+
+ -- Double
+ it "fnmsubDouble#: sniff test"
+ ( fnmsub @Double 2 3 1 ) -7
+
+ it "fnmsubDouble#: excess precision"
+ ( fnmsub @Double 0.99999999999999989 1.0000000000000002 -1 ) -1.1102230246251563e-16
+
+ it "fnmsubDouble#: +0 + +0 rounds properly"
+ ( fnmsub @Double 1 0 0 ) -0
+
+ it "fnmsubDouble#: +0 + -0 rounds properly"
+ ( fnmsub @Double 1 0 -0 ) 0
+
+ it "fnmsubDouble#: -0 + +0 rounds properly"
+ ( fnmsub @Double 1 -0 0 ) 0
+
+ it "fnmsubDouble#: -0 + -0 rounds properly"
+ ( fnmsub @Double 1 -0 -0 ) 0
+
+ failure <- readIORef exit_ref
+ if failure
+ then exitFailure
+ else exitSuccess
diff --git a/testsuite/tests/primops/should_run/FMA_Primops.hs b/testsuite/tests/primops/should_run/FMA_Primops.hs
new file mode 100644
index 0000000000..a925ff6c3b
--- /dev/null
+++ b/testsuite/tests/primops/should_run/FMA_Primops.hs
@@ -0,0 +1,264 @@
+{-# LANGUAGE BlockArguments #-}
+{-# LANGUAGE LexicalNegation #-}
+{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE TypeApplications #-}
+
+module Main where
+
+import Control.Monad
+ ( unless )
+import Data.IORef
+ ( newIORef, readIORef, writeIORef )
+import GHC.Exts
+ ( Float(..), Float#, Double(..), Double#
+ , fmaddFloat# , fmsubFloat# , fnmaddFloat# , fnmsubFloat#
+ , fmaddDouble#, fmsubDouble#, fnmaddDouble#, fnmsubDouble#
+ )
+import GHC.Float
+ ( castFloatToWord32, castDoubleToWord64 )
+import System.Exit
+ ( exitFailure, exitSuccess )
+
+--------------------------------------------------------------------------------
+
+-- Want "on-the-nose" (bit-for-bit) equality to make sure we properly distinguish 0.0 and -0.0.
+class StrictEq a where
+ strictlyEqual :: a -> a -> Bool -- in the instances below: True iff both arguments have the same bit pattern
+instance StrictEq Float where
+ strictlyEqual x y = castFloatToWord32 x == castFloatToWord32 y -- compare the Word32 encodings
+instance StrictEq Double where
+ strictlyEqual x y = castDoubleToWord64 x == castDoubleToWord64 y -- compare the Word64 encodings
+
+class FMA a where -- boxed wrappers so the tests can pick the Float or Double primops by type application
+ fmadd, fmsub, fnmadd, fnmsub :: a -> a -> a -> a -- on x y z: x * y + z, x * y - z, - x * y + z, - x * y - z
+instance FMA Float where
+ fmadd (F# x) (F# y) (F# z) = F# ( fmaddFloat# x y z)
+ fmsub (F# x) (F# y) (F# z) = F# ( fmsubFloat# x y z)
+ fnmadd (F# x) (F# y) (F# z) = F# (fnmaddFloat# x y z)
+ fnmsub (F# x) (F# y) (F# z) = F# (fnmsubFloat# x y z)
+ {-# INLINE fmadd #-}
+ {-# INLINE fnmadd #-}
+ {-# INLINE fmsub #-}
+ {-# INLINE fnmsub #-}
+instance FMA Double where
+ fmadd (D# x) (D# y) (D# z) = D# ( fmaddDouble# x y z)
+ fmsub (D# x) (D# y) (D# z) = D# ( fmsubDouble# x y z)
+ fnmadd (D# x) (D# y) (D# z) = D# (fnmaddDouble# x y z)
+ fnmsub (D# x) (D# y) (D# z) = D# (fnmsubDouble# x y z)
+ {-# INLINE fmadd #-}
+ {-# INLINE fnmadd #-}
+ {-# INLINE fmsub #-}
+ {-# INLINE fnmsub #-}
+
+f1, f2, f3 :: Float
+f1 = 0.99999994 -- largest Float below 1, i.e. 1 - 2^-24
+f2 = 1.00000012 -- smallest Float above 1, i.e. 1 + 2^-23
+f3 = 5.96046377e-08 -- f1 * f2 - 1 with a single rounding, i.e. 2^-24 - 2^-47 (rounding the product first would give 0)
+d1, d2, d3 :: Double
+d1 = 0.99999999999999989 -- largest Double below 1, i.e. 1 - 2^-53
+d2 = 1.0000000000000002 -- smallest Double above 1, i.e. 1 + 2^-52
+d3 = 1.1102230246251563e-16 -- d1 * d2 - 1 with a single rounding, i.e. 2^-53 - 2^-105 (rounding the product first would give 0)
+zero, one, two, three, five, seven :: Num a => a
+zero = 0
+one = 1
+two = 2
+three = 3
+five = 5
+seven = 7
+
+-- NOINLINE to prevent constant folding, so that the primops themselves are executed at runtime.
+-- (The test FMA_ConstantFold tests constant folding.)
+{-# NOINLINE f1 #-}
+{-# NOINLINE f2 #-}
+{-# NOINLINE f3 #-}
+{-# NOINLINE d1 #-}
+{-# NOINLINE d2 #-}
+{-# NOINLINE d3 #-}
+{-# NOINLINE zero #-}
+{-# NOINLINE one #-}
+{-# NOINLINE two #-}
+{-# NOINLINE three #-}
+{-# NOINLINE five #-}
+{-# NOINLINE seven #-}
+
+main :: IO ()
+main = do
+
+ exit_ref <- newIORef False -- set to True by any failing check; decides the exit code at the end
+
+ let
+ it :: forall a. ( StrictEq a, Show a ) => String -> a -> a -> IO ()
+ it desc actual expected = -- silent on success; on mismatch, records the failure and prints both values
+ unless (actual `strictlyEqual` expected) do
+ writeIORef exit_ref True
+ putStrLn $ unlines
+ [ "FAIL " ++ desc
+ , " expected: " ++ show expected
+ , " actual: " ++ show actual ]
+
+ -- NB: throughout this test, we are using "round to nearest".
+
+ -- fmadd: x * y + z
+
+ -- Float
+ it "fmaddFloat#: sniff test"
+ ( fmadd @Float two three one ) seven
+
+ it "fmaddFloat#: excess precision"
+ ( fmadd @Float f1 f2 -one ) f3
+
+ it "fmaddFloat#: +0 + +0 rounds properly"
+ ( fmadd @Float one zero zero ) zero
+
+ it "fmaddFloat#: +0 + -0 rounds properly"
+ ( fmadd @Float one zero -zero ) zero
+
+ it "fmaddFloat#: -0 + +0 rounds properly"
+ ( fmadd @Float one -zero zero ) zero
+
+ it "fmaddFloat#: -0 + -0 rounds properly"
+ ( fmadd @Float one -zero -zero ) -zero
+
+ -- Double
+ it "fmaddDouble#: sniff test"
+ ( fmadd @Double two three one ) seven
+
+ it "fmaddDouble#: excess precision"
+ ( fmadd @Double d1 d2 -one ) d3
+
+ it "fmaddDouble#: +0 + +0 rounds properly"
+ ( fmadd @Double one zero zero ) zero
+
+ it "fmaddDouble#: +0 + -0 rounds properly"
+ ( fmadd @Double one zero -zero ) zero
+
+ it "fmaddDouble#: -0 + +0 rounds properly"
+ ( fmadd @Double one -zero zero ) zero
+
+ it "fmaddDouble#: -0 + -0 rounds properly"
+ ( fmadd @Double one -zero -zero ) -zero
+
+ -- fmsub: x * y - z
+
+ -- Float
+ it "fmsubFloat#: sniff test"
+ ( fmsub @Float two three one ) five
+
+ it "fmsubFloat#: excess precision"
+ ( fmsub @Float f1 f2 one ) f3
+
+ it "fmsubFloat#: +0 + +0 rounds properly"
+ ( fmsub @Float one zero zero ) zero
+
+ it "fmsubFloat#: +0 + -0 rounds properly"
+ ( fmsub @Float one zero -zero ) zero
+
+ it "fmsubFloat#: -0 + +0 rounds properly"
+ ( fmsub @Float one -zero zero ) -zero
+
+ it "fmsubFloat#: -0 + -0 rounds properly"
+ ( fmsub @Float one -zero -zero ) zero
+
+ -- Double
+ it "fmsubDouble#: sniff test"
+ ( fmsub @Double two three one ) five
+
+ it "fmsubDouble#: excess precision"
+ ( fmsub @Double d1 d2 one ) d3
+
+ it "fmsubDouble#: +0 + +0 rounds properly"
+ ( fmsub @Double one zero zero ) zero
+
+ it "fmsubDouble#: +0 + -0 rounds properly"
+ ( fmsub @Double one zero -zero ) zero
+
+ it "fmsubDouble#: -0 + +0 rounds properly"
+ ( fmsub @Double one -zero zero ) -zero
+
+ it "fmsubDouble#: -0 + -0 rounds properly"
+ ( fmsub @Double one -zero -zero ) zero
+
+ -- fnmadd: - x * y + z
+
+ -- Float
+ it "fnmaddFloat#: sniff test"
+ ( fnmadd @Float two three one ) -five
+
+ it "fnmaddFloat#: excess precision"
+ ( fnmadd @Float f1 f2 one ) -f3
+
+ it "fnmaddFloat#: +0 + +0 rounds properly"
+ ( fnmadd @Float one zero zero ) zero
+
+ it "fnmaddFloat#: +0 + -0 rounds properly"
+ ( fnmadd @Float one zero -zero ) -zero
+
+ it "fnmaddFloat#: -0 + +0 rounds properly"
+ ( fnmadd @Float one -zero zero ) zero
+
+ it "fnmaddFloat#: -0 + -0 rounds properly"
+ ( fnmadd @Float one -zero -zero ) zero
+
+ -- Double
+ it "fnmaddDouble#: sniff test"
+ ( fnmadd @Double two three one ) -five
+
+ it "fnmaddDouble#: excess precision"
+ ( fnmadd @Double d1 d2 one ) -d3
+
+ it "fnmaddDouble#: +0 + +0 rounds properly"
+ ( fnmadd @Double one zero zero ) zero
+
+ it "fnmaddDouble#: +0 + -0 rounds properly"
+ ( fnmadd @Double one zero -zero ) -zero
+
+ it "fnmaddDouble#: -0 + +0 rounds properly"
+ ( fnmadd @Double one -zero zero ) zero
+
+ it "fnmaddDouble#: -0 + -0 rounds properly"
+ ( fnmadd @Double one -zero -zero ) zero
+
+ -- fnmsub: - x * y - z
+
+ -- Float
+ it "fnmsubFloat#: sniff test"
+ ( fnmsub @Float two three one ) -seven
+
+ it "fnmsubFloat#: excess precision"
+ ( fnmsub @Float f1 f2 -one ) -f3
+
+ it "fnmsubFloat#: +0 + +0 rounds properly"
+ ( fnmsub @Float one zero zero ) -zero
+
+ it "fnmsubFloat#: +0 + -0 rounds properly"
+ ( fnmsub @Float one zero -zero ) zero
+
+ it "fnmsubFloat#: -0 + +0 rounds properly"
+ ( fnmsub @Float one -zero zero ) zero
+
+ it "fnmsubFloat#: -0 + -0 rounds properly"
+ ( fnmsub @Float one -zero -zero ) zero
+
+ -- Double
+ it "fnmsubDouble#: sniff test"
+ ( fnmsub @Double two three one ) -seven
+
+ it "fnmsubDouble#: excess precision"
+ ( fnmsub @Double d1 d2 -one ) -d3
+
+ it "fnmsubDouble#: +0 + +0 rounds properly"
+ ( fnmsub @Double one zero zero ) -zero
+
+ it "fnmsubDouble#: +0 + -0 rounds properly"
+ ( fnmsub @Double one zero -zero ) zero
+
+ it "fnmsubDouble#: -0 + +0 rounds properly"
+ ( fnmsub @Double one -zero zero ) zero
+
+ it "fnmsubDouble#: -0 + -0 rounds properly"
+ ( fnmsub @Double one -zero -zero ) zero
+
+ failure <- readIORef exit_ref -- True if at least one check above failed
+ if failure
+ then exitFailure
+ else exitSuccess
diff --git a/testsuite/tests/primops/should_run/all.T b/testsuite/tests/primops/should_run/all.T
index 4148546280..da6378df84 100644
--- a/testsuite/tests/primops/should_run/all.T
+++ b/testsuite/tests/primops/should_run/all.T
@@ -59,5 +59,16 @@ test('UnliftedTVar1', normal, compile_and_run, [''])
test('UnliftedTVar2', normal, compile_and_run, [''])
test('UnliftedWeakPtr', normal, compile_and_run, [''])
+test('FMA_Primops'
+ , [ when(have_cpu_feature('fma'), extra_hc_opts('-mfma')) # add -mfma only when the 'fma' CPU feature is reported
+ , js_skip # JS backend doesn't have an FMA implementation
+ ]
+ , compile_and_run, [''])
+test('FMA_ConstantFold'
+ , [ js_skip # JS backend doesn't have an FMA implementation
+ , expect_broken(21227) # constant folding rules mishandle signed zero (#21227)
+ ]
+ , compile_and_run, ['-O'])
+
test('T21624', normal, compile_and_run, [''])
test('T23071', ignore_stdout, compile_and_run, [''])