summaryrefslogtreecommitdiff
path: root/testsuite
diff options
context:
space:
mode:
authorsheaf <sam.derbyshire@gmail.com>2023-04-08 13:42:58 +0200
committerMarge Bot <ben+marge-bot@smart-cactus.org>2023-05-11 11:55:22 -0400
commit87eebf98cb485f7c9175330051736e147ade9848 (patch)
treeffa226b3fefa8b0a03e1798fa4f55affbddf654b /testsuite
parent630b1fea1e41a1e00860a30742b6ab8ade8a0de0 (diff)
downloadhaskell-87eebf98cb485f7c9175330051736e147ade9848.tar.gz
Add fused multiply-add instructions
This patch adds eight new primops that fuse a multiplication and an addition or subtraction: - `{fmadd,fmsub,fnmadd,fnmsub}{Float,Double}#` fmadd x y z is x * y + z, computed with a single rounding step. This patch implements code generation for these primops in the following backends: - X86, AArch64 and PowerPC NCG, - LLVM - C WASM uses the C implementation. The primops are unsupported in the JavaScript backend. The following constant folding rules are also provided: - compute a * b + c when a, b, c are all literals, - x * y + 0 ==> x * y, - ±1 * y + z ==> z ± y and x * ±1 + z ==> z ± x. NB: the constant folding rules incorrectly handle signed zero. This is a known limitation with GHC's floating-point constant folding rules (#21227), which we hope to resolve in the future.
Diffstat (limited to 'testsuite')
-rw-r--r--testsuite/driver/cpu_features.py5
-rw-r--r--testsuite/tests/primops/should_run/FMA_ConstantFold.hs236
-rw-r--r--testsuite/tests/primops/should_run/FMA_Primops.hs264
-rw-r--r--testsuite/tests/primops/should_run/all.T11
4 files changed, 516 insertions, 0 deletions
diff --git a/testsuite/driver/cpu_features.py b/testsuite/driver/cpu_features.py
index 43500130db..f6c441d88f 100644
--- a/testsuite/driver/cpu_features.py
+++ b/testsuite/driver/cpu_features.py
@@ -10,6 +10,7 @@ SUPPORTED_CPU_FEATURES = {
# x86:
'sse', 'sse2', 'sse3', 'ssse3', 'sse4_1', 'sse4_2',
'avx1', 'avx2',
+ 'fma',
'popcnt', 'bmi1', 'bmi2'
}
@@ -46,6 +47,10 @@ def get_cpu_features():
check_feature('avx2_0', 'avx2')
return features
+ elif config.arch in [ 'powerpc', 'powerpc64' ]:
+ # Hardcode support for 'fma' on PowerPC
+ return [ 'fma' ]
+
else:
# TODO: Add {Open,Free}BSD support
print('get_cpu_features: Lacking support for your platform')
diff --git a/testsuite/tests/primops/should_run/FMA_ConstantFold.hs b/testsuite/tests/primops/should_run/FMA_ConstantFold.hs
new file mode 100644
index 0000000000..80bac3231c
--- /dev/null
+++ b/testsuite/tests/primops/should_run/FMA_ConstantFold.hs
@@ -0,0 +1,236 @@
+{-# LANGUAGE BlockArguments #-}
+{-# LANGUAGE LexicalNegation #-}
+{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE TypeApplications #-}
+
+module Main where
+
+import Control.Monad
+ ( unless )
+import Data.IORef
+ ( newIORef, readIORef, writeIORef )
+import GHC.Exts
+ ( Float(..), Float#, Double(..), Double#
+ , fmaddFloat# , fmsubFloat# , fnmaddFloat# , fnmsubFloat#
+ , fmaddDouble#, fmsubDouble#, fnmaddDouble#, fnmsubDouble#
+ )
+import GHC.Float
+ ( castFloatToWord32, castDoubleToWord64 )
+import System.Exit
+ ( exitFailure, exitSuccess )
+
+--------------------------------------------------------------------------------
+
+-- NB: This test tests constant folding for fused multiply-add operations.
+-- See FMA_Primops for the test of the primops.
+
+-- Want "on-the-nose" equality to make sure we properly distinguish 0.0 and -0.0.
+class StrictEq a where
+ strictlyEqual :: a -> a -> Bool
+instance StrictEq Float where
+ strictlyEqual x y = castFloatToWord32 x == castFloatToWord32 y
+instance StrictEq Double where
+ strictlyEqual x y = castDoubleToWord64 x == castDoubleToWord64 y
+
+class FMA a where
+ fmadd, fmsub, fnmadd, fnmsub :: a -> a -> a -> a
+instance FMA Float where
+ fmadd (F# x) (F# y) (F# z) = F# ( fmaddFloat# x y z)
+ fmsub (F# x) (F# y) (F# z) = F# ( fmsubFloat# x y z)
+ fnmadd (F# x) (F# y) (F# z) = F# (fnmaddFloat# x y z)
+ fnmsub (F# x) (F# y) (F# z) = F# (fnmsubFloat# x y z)
+ {-# INLINE fmadd #-}
+ {-# INLINE fnmadd #-}
+ {-# INLINE fmsub #-}
+ {-# INLINE fnmsub #-}
+instance FMA Double where
+ fmadd (D# x) (D# y) (D# z) = D# ( fmaddDouble# x y z)
+ fmsub (D# x) (D# y) (D# z) = D# ( fmsubDouble# x y z)
+ fnmadd (D# x) (D# y) (D# z) = D# (fnmaddDouble# x y z)
+ fnmsub (D# x) (D# y) (D# z) = D# (fnmsubDouble# x y z)
+ {-# INLINE fmadd #-}
+ {-# INLINE fnmadd #-}
+ {-# INLINE fmsub #-}
+ {-# INLINE fnmsub #-}
+
+main :: IO ()
+main = do
+
+ exit_ref <- newIORef False
+
+ let
+ it :: forall a. ( StrictEq a, Show a ) => String -> a -> a -> IO ()
+ it desc actual expected =
+ unless (actual `strictlyEqual` expected) do
+ writeIORef exit_ref True
+ putStrLn $ unlines
+ [ "FAIL " ++ desc
+ , " expected: " ++ show expected
+ , " actual: " ++ show actual ]
+
+ -- NB: throughout this test, we are using "round to nearest".
+
+ -- fmadd: x * y + z
+
+ -- Float
+ it "fmaddFloat#: sniff test"
+ ( fmadd @Float 2 3 1 ) 7
+
+ it "fmaddFloat#: excess precision"
+ ( fmadd @Float 0.99999994 1.00000012 -1 ) 5.96046377e-08
+
+ it "fmaddFloat#: +0 + +0 rounds properly"
+ ( fmadd @Float 1 0 0 ) 0
+
+ it "fmaddFloat#: +0 + -0 rounds properly"
+ ( fmadd @Float 1 0 -0 ) 0
+
+ it "fmaddFloat#: -0 + +0 rounds properly"
+ ( fmadd @Float 1 -0 0 ) 0
+
+ it "fmaddFloat#: -0 + -0 rounds properly"
+ ( fmadd @Float 1 -0 -0 ) -0
+
+ -- Double
+ it "fmaddDouble#: sniff test"
+ ( fmadd @Double 2 3 1 ) 7
+
+ it "fmaddDouble#: excess precision"
+ ( fmadd @Double 0.99999999999999989 1.0000000000000002 -1 ) 1.1102230246251563e-16
+
+ it "fmaddDouble#: +0 + +0 rounds properly"
+ ( fmadd @Double 1 0 0 ) 0
+
+ it "fmaddDouble#: +0 + -0 rounds properly"
+ ( fmadd @Double 1 0 -0 ) 0
+
+ it "fmaddDouble#: -0 + +0 rounds properly"
+ ( fmadd @Double 1 -0 0 ) 0
+
+ it "fmaddDouble#: -0 + -0 rounds properly"
+ ( fmadd @Double 1 -0 -0 ) -0
+
+ -- fmsub: x * y - z
+
+ -- Float
+ it "fmsubFloat#: sniff test"
+ ( fmsub @Float 2 3 1 ) 5.0
+
+ it "fmsubFloat#: excess precision"
+ ( fmsub @Float 0.99999994 1.00000012 1 ) 5.96046377e-08
+
+ it "fmsubFloat#: +0 + +0 rounds properly"
+ ( fmsub @Float 1 0 0 ) 0
+
+ it "fmsubFloat#: +0 + -0 rounds properly"
+ ( fmsub @Float 1 0 -0 ) 0
+
+ it "fmsubFloat#: -0 + +0 rounds properly"
+ ( fmsub @Float 1 -0 0 ) -0
+
+ it "fmsubFloat#: -0 + -0 rounds properly"
+ ( fmsub @Float 1 -0 -0 ) 0
+
+ -- Double
+ it "fmsubDouble#: sniff test"
+ ( fmsub @Double 2 3 1 ) 5.0
+
+ it "fmsubDouble#: excess precision"
+ ( fmsub @Double 0.99999999999999989 1.0000000000000002 1 ) 1.1102230246251563e-16
+
+ it "fmsubDouble#: +0 + +0 rounds properly"
+ ( fmsub @Double 1 0 0 ) 0
+
+ it "fmsubDouble#: +0 + -0 rounds properly"
+ ( fmsub @Double 1 0 -0 ) 0
+
+ it "fmsubDouble#: -0 + +0 rounds properly"
+ ( fmsub @Double 1 -0 0 ) -0
+
+ it "fmsubDouble#: -0 + -0 rounds properly"
+ ( fmsub @Double 1 -0 -0 ) 0
+
+ -- fnmadd: - x * y + z
+
+ -- Float
+ it "fnmaddFloat#: sniff test"
+ ( fnmadd @Float 2 3 1 ) -5.0
+
+ it "fnmaddFloat#: excess precision"
+ ( fnmadd @Float 0.99999994 1.00000012 1 ) -5.96046377e-08
+
+ it "fnmaddFloat#: +0 + +0 rounds properly"
+ ( fnmadd @Float 1 0 0 ) 0
+
+ it "fnmaddFloat#: +0 + -0 rounds properly"
+ ( fnmadd @Float 1 0 -0 ) -0
+
+ it "fnmaddFloat#: -0 + +0 rounds properly"
+ ( fnmadd @Float 1 -0 0 ) 0
+
+ it "fnmaddFloat#: -0 + -0 rounds properly"
+ ( fnmadd @Float 1 -0 -0 ) 0
+
+ -- Double
+ it "fnmaddDouble#: sniff test"
+ ( fnmadd @Double 2 3 1 ) -5.0
+
+ it "fnmaddDouble#: excess precision"
+ ( fnmadd @Double 0.99999999999999989 1.0000000000000002 1 ) -1.1102230246251563e-16
+
+ it "fnmaddDouble#: +0 + +0 rounds properly"
+ ( fnmadd @Double 1 0 0 ) 0
+
+ it "fnmaddDouble#: +0 + -0 rounds properly"
+ ( fnmadd @Double 1 0 -0 ) -0
+
+ it "fnmaddDouble#: -0 + +0 rounds properly"
+ ( fnmadd @Double 1 -0 0 ) 0
+
+ it "fnmaddDouble#: -0 + -0 rounds properly"
+ ( fnmadd @Double 1 -0 -0 ) 0
+
+ -- fnmsub: - x * y - z
+
+ -- Float
+ it "fnmsubFloat#: sniff test"
+ ( fnmsub @Float 2 3 1 ) -7
+
+ it "fnmsubFloat#: excess precision"
+ ( fnmsub @Float 0.99999994 1.00000012 -1 ) -5.96046377e-08
+
+ it "fnmsubFloat#: +0 + +0 rounds properly"
+ ( fnmsub @Float 1 0 0 ) -0
+
+ it "fnmsubFloat#: +0 + -0 rounds properly"
+ ( fnmsub @Float 1 0 -0 ) 0
+
+ it "fnmsubFloat#: -0 + +0 rounds properly"
+ ( fnmsub @Float 1 -0 0 ) 0
+
+ it "fnmsubFloat#: -0 + -0 rounds properly"
+ ( fnmsub @Float 1 -0 -0 ) 0
+
+ -- Double
+ it "fnmsubDouble#: sniff test"
+ ( fnmsub @Double 2 3 1 ) -7
+
+ it "fnmsubDouble#: excess precision"
+ ( fnmsub @Double 0.99999999999999989 1.0000000000000002 -1 ) -1.1102230246251563e-16
+
+ it "fnmsubDouble#: +0 + +0 rounds properly"
+ ( fnmsub @Double 1 0 0 ) -0
+
+ it "fnmsubDouble#: +0 + -0 rounds properly"
+ ( fnmsub @Double 1 0 -0 ) 0
+
+ it "fnmsubDouble#: -0 + +0 rounds properly"
+ ( fnmsub @Double 1 -0 0 ) 0
+
+ it "fnmsubDouble#: -0 + -0 rounds properly"
+ ( fnmsub @Double 1 -0 -0 ) 0
+
+ failure <- readIORef exit_ref
+ if failure
+ then exitFailure
+ else exitSuccess
diff --git a/testsuite/tests/primops/should_run/FMA_Primops.hs b/testsuite/tests/primops/should_run/FMA_Primops.hs
new file mode 100644
index 0000000000..a925ff6c3b
--- /dev/null
+++ b/testsuite/tests/primops/should_run/FMA_Primops.hs
@@ -0,0 +1,264 @@
+{-# LANGUAGE BlockArguments #-}
+{-# LANGUAGE LexicalNegation #-}
+{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE TypeApplications #-}
+
+module Main where
+
+import Control.Monad
+ ( unless )
+import Data.IORef
+ ( newIORef, readIORef, writeIORef )
+import GHC.Exts
+ ( Float(..), Float#, Double(..), Double#
+ , fmaddFloat# , fmsubFloat# , fnmaddFloat# , fnmsubFloat#
+ , fmaddDouble#, fmsubDouble#, fnmaddDouble#, fnmsubDouble#
+ )
+import GHC.Float
+ ( castFloatToWord32, castDoubleToWord64 )
+import System.Exit
+ ( exitFailure, exitSuccess )
+
+--------------------------------------------------------------------------------
+
+-- Want "on-the-nose" equality to make sure we properly distinguish 0.0 and -0.0.
+class StrictEq a where
+ strictlyEqual :: a -> a -> Bool
+instance StrictEq Float where
+ strictlyEqual x y = castFloatToWord32 x == castFloatToWord32 y
+instance StrictEq Double where
+ strictlyEqual x y = castDoubleToWord64 x == castDoubleToWord64 y
+
+class FMA a where
+ fmadd, fmsub, fnmadd, fnmsub :: a -> a -> a -> a
+instance FMA Float where
+ fmadd (F# x) (F# y) (F# z) = F# ( fmaddFloat# x y z)
+ fmsub (F# x) (F# y) (F# z) = F# ( fmsubFloat# x y z)
+ fnmadd (F# x) (F# y) (F# z) = F# (fnmaddFloat# x y z)
+ fnmsub (F# x) (F# y) (F# z) = F# (fnmsubFloat# x y z)
+ {-# INLINE fmadd #-}
+ {-# INLINE fnmadd #-}
+ {-# INLINE fmsub #-}
+ {-# INLINE fnmsub #-}
+instance FMA Double where
+ fmadd (D# x) (D# y) (D# z) = D# ( fmaddDouble# x y z)
+ fmsub (D# x) (D# y) (D# z) = D# ( fmsubDouble# x y z)
+ fnmadd (D# x) (D# y) (D# z) = D# (fnmaddDouble# x y z)
+ fnmsub (D# x) (D# y) (D# z) = D# (fnmsubDouble# x y z)
+ {-# INLINE fmadd #-}
+ {-# INLINE fnmadd #-}
+ {-# INLINE fmsub #-}
+ {-# INLINE fnmsub #-}
+
+f1, f2, f3 :: Float
+f1 = 0.99999994 -- float before 1
+f2 = 1.00000012 -- float after 1
+f3 = 5.96046377e-08 -- float after 0
+d1, d2, d3 :: Double
+d1 = 0.99999999999999989 -- double before 1
+d2 = 1.0000000000000002 -- double after 1
+d3 = 1.1102230246251563e-16 -- double after 0
+zero, one, two, three, five, seven :: Num a => a
+zero = 0
+one = 1
+two = 2
+three = 3
+five = 5
+seven = 7
+
+-- NOINLINE to prevent constant folding
+-- (The test FMA_ConstantFold tests constant folding.)
+{-# NOINLINE f1 #-}
+{-# NOINLINE f2 #-}
+{-# NOINLINE f3 #-}
+{-# NOINLINE d1 #-}
+{-# NOINLINE d2 #-}
+{-# NOINLINE d3 #-}
+{-# NOINLINE zero #-}
+{-# NOINLINE one #-}
+{-# NOINLINE two #-}
+{-# NOINLINE three #-}
+{-# NOINLINE five #-}
+{-# NOINLINE seven #-}
+
+main :: IO ()
+main = do
+
+ exit_ref <- newIORef False
+
+ let
+ it :: forall a. ( StrictEq a, Show a ) => String -> a -> a -> IO ()
+ it desc actual expected =
+ unless (actual `strictlyEqual` expected) do
+ writeIORef exit_ref True
+ putStrLn $ unlines
+ [ "FAIL " ++ desc
+ , " expected: " ++ show expected
+ , " actual: " ++ show actual ]
+
+ -- NB: throughout this test, we are using "round to nearest".
+
+ -- fmadd: x * y + z
+
+ -- Float
+ it "fmaddFloat#: sniff test"
+ ( fmadd @Float two three one ) seven
+
+ it "fmaddFloat#: excess precision"
+ ( fmadd @Float f1 f2 -one ) f3
+
+ it "fmaddFloat#: +0 + +0 rounds properly"
+ ( fmadd @Float one zero zero ) zero
+
+ it "fmaddFloat#: +0 + -0 rounds properly"
+ ( fmadd @Float one zero -zero ) zero
+
+ it "fmaddFloat#: -0 + +0 rounds properly"
+ ( fmadd @Float one -zero zero ) zero
+
+ it "fmaddFloat#: -0 + -0 rounds properly"
+ ( fmadd @Float one -zero -zero ) -zero
+
+ -- Double
+ it "fmaddDouble#: sniff test"
+ ( fmadd @Double two three one ) seven
+
+ it "fmaddDouble#: excess precision"
+ ( fmadd @Double d1 d2 -one ) d3
+
+ it "fmaddDouble#: +0 + +0 rounds properly"
+ ( fmadd @Double one zero zero ) zero
+
+ it "fmaddDouble#: +0 + -0 rounds properly"
+ ( fmadd @Double one zero -zero ) zero
+
+ it "fmaddDouble#: -0 + +0 rounds properly"
+ ( fmadd @Double one -zero zero ) zero
+
+ it "fmaddDouble#: -0 + -0 rounds properly"
+ ( fmadd @Double one -zero -zero ) -zero
+
+ -- fmsub: x * y - z
+
+ -- Float
+ it "fmsubFloat#: sniff test"
+ ( fmsub @Float two three one ) five
+
+ it "fmsubFloat#: excess precision"
+ ( fmsub @Float f1 f2 one ) f3
+
+ it "fmsubFloat#: +0 + +0 rounds properly"
+ ( fmsub @Float one zero zero ) zero
+
+ it "fmsubFloat#: +0 + -0 rounds properly"
+ ( fmsub @Float one zero -zero ) zero
+
+ it "fmsubFloat#: -0 + +0 rounds properly"
+ ( fmsub @Float one -zero zero ) -zero
+
+ it "fmsubFloat#: -0 + -0 rounds properly"
+ ( fmsub @Float one -zero -zero ) zero
+
+ -- Double
+ it "fmsubDouble#: sniff test"
+ ( fmsub @Double two three one ) five
+
+ it "fmsubDouble#: excess precision"
+ ( fmsub @Double d1 d2 one ) d3
+
+ it "fmsubDouble#: +0 + +0 rounds properly"
+ ( fmsub @Double one zero zero ) zero
+
+ it "fmsubDouble#: +0 + -0 rounds properly"
+ ( fmsub @Double one zero -zero ) zero
+
+ it "fmsubDouble#: -0 + +0 rounds properly"
+ ( fmsub @Double one -zero zero ) -zero
+
+ it "fmsubDouble#: -0 + -0 rounds properly"
+ ( fmsub @Double one -zero -zero ) zero
+
+ -- fnmadd: - x * y + z
+
+ -- Float
+ it "fnmaddFloat#: sniff test"
+ ( fnmadd @Float two three one ) -five
+
+ it "fnmaddFloat#: excess precision"
+ ( fnmadd @Float f1 f2 one ) -f3
+
+ it "fnmaddFloat#: +0 + +0 rounds properly"
+ ( fnmadd @Float one zero zero ) zero
+
+ it "fnmaddFloat#: +0 + -0 rounds properly"
+ ( fnmadd @Float one zero -zero ) -zero
+
+ it "fnmaddFloat#: -0 + +0 rounds properly"
+ ( fnmadd @Float one -zero zero ) zero
+
+ it "fnmaddFloat#: -0 + -0 rounds properly"
+ ( fnmadd @Float one -zero -zero ) zero
+
+ -- Double
+ it "fnmaddDouble#: sniff test"
+ ( fnmadd @Double two three one ) -five
+
+ it "fnmaddDouble#: excess precision"
+ ( fnmadd @Double d1 d2 one ) -d3
+
+ it "fnmaddDouble#: +0 + +0 rounds properly"
+ ( fnmadd @Double one zero zero ) zero
+
+ it "fnmaddDouble#: +0 + -0 rounds properly"
+ ( fnmadd @Double one zero -zero ) -zero
+
+ it "fnmaddDouble#: -0 + +0 rounds properly"
+ ( fnmadd @Double one -zero zero ) zero
+
+ it "fnmaddDouble#: -0 + -0 rounds properly"
+ ( fnmadd @Double one -zero -zero ) zero
+
+ -- fnmsub: - x * y - z
+
+ -- Float
+ it "fnmsubFloat#: sniff test"
+ ( fnmsub @Float two three one ) -seven
+
+ it "fnmsubFloat#: excess precision"
+ ( fnmsub @Float f1 f2 -one ) -f3
+
+ it "fnmsubFloat#: +0 + +0 rounds properly"
+ ( fnmsub @Float one zero zero ) -zero
+
+ it "fnmsubFloat#: +0 + -0 rounds properly"
+ ( fnmsub @Float one zero -zero ) zero
+
+ it "fnmsubFloat#: -0 + +0 rounds properly"
+ ( fnmsub @Float one -zero zero ) zero
+
+ it "fnmsubFloat#: -0 + -0 rounds properly"
+ ( fnmsub @Float one -zero -zero ) zero
+
+ -- Double
+ it "fnmsubDouble#: sniff test"
+ ( fnmsub @Double two three one ) -seven
+
+ it "fnmsubDouble#: excess precision"
+ ( fnmsub @Double d1 d2 -one ) -d3
+
+ it "fnmsubDouble#: +0 + +0 rounds properly"
+ ( fnmsub @Double one zero zero ) -zero
+
+ it "fnmsubDouble#: +0 + -0 rounds properly"
+ ( fnmsub @Double one zero -zero ) zero
+
+ it "fnmsubDouble#: -0 + +0 rounds properly"
+ ( fnmsub @Double one -zero zero ) zero
+
+ it "fnmsubDouble#: -0 + -0 rounds properly"
+ ( fnmsub @Double one -zero -zero ) zero
+
+ failure <- readIORef exit_ref
+ if failure
+ then exitFailure
+ else exitSuccess
diff --git a/testsuite/tests/primops/should_run/all.T b/testsuite/tests/primops/should_run/all.T
index 4148546280..da6378df84 100644
--- a/testsuite/tests/primops/should_run/all.T
+++ b/testsuite/tests/primops/should_run/all.T
@@ -59,5 +59,16 @@ test('UnliftedTVar1', normal, compile_and_run, [''])
test('UnliftedTVar2', normal, compile_and_run, [''])
test('UnliftedWeakPtr', normal, compile_and_run, [''])
+test('FMA_Primops'
+ , [ when(have_cpu_feature('fma'), extra_hc_opts('-mfma'))
+ , js_skip # JS backend doesn't have an FMA implementation
+ ]
+ , compile_and_run, [''])
+test('FMA_ConstantFold'
+ , [ js_skip # JS backend doesn't have an FMA implementation ]
+ , expect_broken(21227)
+ ]
+ , compile_and_run, ['-O'])
+
test('T21624', normal, compile_and_run, [''])
test('T23071', ignore_stdout, compile_and_run, [''])