From 87eebf98cb485f7c9175330051736e147ade9848 Mon Sep 17 00:00:00 2001
From: sheaf
Date: Sat, 8 Apr 2023 13:42:58 +0200
Subject: Add fused multiply-add instructions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch adds eight new primops that fuse a multiplication and an
addition or subtraction:

  - `{fmadd,fmsub,fnmadd,fnmsub}{Float,Double}#`

fmadd x y z is x * y + z, computed with a single rounding step.

This patch implements code generation for these primops in the following
backends:

  - X86, AArch64 and PowerPC NCG,
  - LLVM
  - C

WASM uses the C implementation. The primops are unsupported in the
JavaScript backend.

The following constant folding rules are also provided:

  - compute a * b + c when a, b, c are all literals,
  - x * y + 0 ==> x * y,
  - ±1 * y + z ==> z ± y and x * ±1 + z ==> z ± x.

NB: the constant folding rules incorrectly handle signed zero.
This is a known limitation with GHC's floating-point constant folding
rules (#21227), which we hope to resolve in the future.
--- testsuite/driver/cpu_features.py | 5 + .../tests/primops/should_run/FMA_ConstantFold.hs | 236 ++++++++++++++++++ testsuite/tests/primops/should_run/FMA_Primops.hs | 264 +++++++++++++++++++++ testsuite/tests/primops/should_run/all.T | 11 + 4 files changed, 516 insertions(+) create mode 100644 testsuite/tests/primops/should_run/FMA_ConstantFold.hs create mode 100644 testsuite/tests/primops/should_run/FMA_Primops.hs (limited to 'testsuite') diff --git a/testsuite/driver/cpu_features.py b/testsuite/driver/cpu_features.py index 43500130db..f6c441d88f 100644 --- a/testsuite/driver/cpu_features.py +++ b/testsuite/driver/cpu_features.py @@ -10,6 +10,7 @@ SUPPORTED_CPU_FEATURES = { # x86: 'sse', 'sse2', 'sse3', 'ssse3', 'sse4_1', 'sse4_2', 'avx1', 'avx2', + 'fma', 'popcnt', 'bmi1', 'bmi2' } @@ -46,6 +47,10 @@ def get_cpu_features(): check_feature('avx2_0', 'avx2') return features + elif config.arch in [ 'powerpc', 'powerpc64' ]: + # Hardcode support for 'fma' on PowerPC + return [ 'fma' ] + else: # TODO: Add {Open,Free}BSD support print('get_cpu_features: Lacking support for your platform') diff --git a/testsuite/tests/primops/should_run/FMA_ConstantFold.hs b/testsuite/tests/primops/should_run/FMA_ConstantFold.hs new file mode 100644 index 0000000000..80bac3231c --- /dev/null +++ b/testsuite/tests/primops/should_run/FMA_ConstantFold.hs @@ -0,0 +1,236 @@ +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE LexicalNegation #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE TypeApplications #-} + +module Main where + +import Control.Monad + ( unless ) +import Data.IORef + ( newIORef, readIORef, writeIORef ) +import GHC.Exts + ( Float(..), Float#, Double(..), Double# + , fmaddFloat# , fmsubFloat# , fnmaddFloat# , fnmsubFloat# + , fmaddDouble#, fmsubDouble#, fnmaddDouble#, fnmsubDouble# + ) +import GHC.Float + ( castFloatToWord32, castDoubleToWord64 ) +import System.Exit + ( exitFailure, exitSuccess ) + +-------------------------------------------------------------------------------- + 
+-- NB: This test tests constant folding for fused multiply-add operations. +-- See FMA_Primops for the test of the primops. + +-- Want "on-the-nose" equality to make sure we properly distinguish 0.0 and -0.0. +class StrictEq a where + strictlyEqual :: a -> a -> Bool +instance StrictEq Float where + strictlyEqual x y = castFloatToWord32 x == castFloatToWord32 y +instance StrictEq Double where + strictlyEqual x y = castDoubleToWord64 x == castDoubleToWord64 y + +class FMA a where + fmadd, fmsub, fnmadd, fnmsub :: a -> a -> a -> a +instance FMA Float where + fmadd (F# x) (F# y) (F# z) = F# ( fmaddFloat# x y z) + fmsub (F# x) (F# y) (F# z) = F# ( fmsubFloat# x y z) + fnmadd (F# x) (F# y) (F# z) = F# (fnmaddFloat# x y z) + fnmsub (F# x) (F# y) (F# z) = F# (fnmsubFloat# x y z) + {-# INLINE fmadd #-} + {-# INLINE fnmadd #-} + {-# INLINE fmsub #-} + {-# INLINE fnmsub #-} +instance FMA Double where + fmadd (D# x) (D# y) (D# z) = D# ( fmaddDouble# x y z) + fmsub (D# x) (D# y) (D# z) = D# ( fmsubDouble# x y z) + fnmadd (D# x) (D# y) (D# z) = D# (fnmaddDouble# x y z) + fnmsub (D# x) (D# y) (D# z) = D# (fnmsubDouble# x y z) + {-# INLINE fmadd #-} + {-# INLINE fnmadd #-} + {-# INLINE fmsub #-} + {-# INLINE fnmsub #-} + +main :: IO () +main = do + + exit_ref <- newIORef False + + let + it :: forall a. ( StrictEq a, Show a ) => String -> a -> a -> IO () + it desc actual expected = + unless (actual `strictlyEqual` expected) do + writeIORef exit_ref True + putStrLn $ unlines + [ "FAIL " ++ desc + , " expected: " ++ show expected + , " actual: " ++ show actual ] + + -- NB: throughout this test, we are using "round to nearest". 
+ + -- fmadd: x * y + z + + -- Float + it "fmaddFloat#: sniff test" + ( fmadd @Float 2 3 1 ) 7 + + it "fmaddFloat#: excess precision" + ( fmadd @Float 0.99999994 1.00000012 -1 ) 5.96046377e-08 + + it "fmaddFloat#: +0 + +0 rounds properly" + ( fmadd @Float 1 0 0 ) 0 + + it "fmaddFloat#: +0 + -0 rounds properly" + ( fmadd @Float 1 0 -0 ) 0 + + it "fmaddFloat#: -0 + +0 rounds properly" + ( fmadd @Float 1 -0 0 ) 0 + + it "fmaddFloat#: -0 + -0 rounds properly" + ( fmadd @Float 1 -0 -0 ) -0 + + -- Double + it "fmaddDouble#: sniff test" + ( fmadd @Double 2 3 1 ) 7 + + it "fmaddDouble#: excess precision" + ( fmadd @Double 0.99999999999999989 1.0000000000000002 -1 ) 1.1102230246251563e-16 + + it "fmaddDouble#: +0 + +0 rounds properly" + ( fmadd @Double 1 0 0 ) 0 + + it "fmaddDouble#: +0 + -0 rounds properly" + ( fmadd @Double 1 0 -0 ) 0 + + it "fmaddDouble#: -0 + +0 rounds properly" + ( fmadd @Double 1 -0 0 ) 0 + + it "fmaddDouble#: -0 + -0 rounds properly" + ( fmadd @Double 1 -0 -0 ) -0 + + -- fmsub: x * y - z + + -- Float + it "fmsubFloat#: sniff test" + ( fmsub @Float 2 3 1 ) 5.0 + + it "fmsubFloat#: excess precision" + ( fmsub @Float 0.99999994 1.00000012 1 ) 5.96046377e-08 + + it "fmsubFloat#: +0 + +0 rounds properly" + ( fmsub @Float 1 0 0 ) 0 + + it "fmsubFloat#: +0 + -0 rounds properly" + ( fmsub @Float 1 0 -0 ) 0 + + it "fmsubFloat#: -0 + +0 rounds properly" + ( fmsub @Float 1 -0 0 ) -0 + + it "fmsubFloat#: -0 + -0 rounds properly" + ( fmsub @Float 1 -0 -0 ) 0 + + -- Double + it "fmsubDouble#: sniff test" + ( fmsub @Double 2 3 1 ) 5.0 + + it "fmsubDouble#: excess precision" + ( fmsub @Double 0.99999999999999989 1.0000000000000002 1 ) 1.1102230246251563e-16 + + it "fmsubDouble#: +0 + +0 rounds properly" + ( fmsub @Double 1 0 0 ) 0 + + it "fmsubDouble#: +0 + -0 rounds properly" + ( fmsub @Double 1 0 -0 ) 0 + + it "fmsubDouble#: -0 + +0 rounds properly" + ( fmsub @Double 1 -0 0 ) -0 + + it "fmsubDouble#: -0 + -0 rounds properly" + ( fmsub @Double 1 -0 -0 ) 0 + + -- 
fnmadd: - x * y + z + + -- Float + it "fnmaddFloat#: sniff test" + ( fnmadd @Float 2 3 1 ) -5.0 + + it "fnmaddFloat#: excess precision" + ( fnmadd @Float 0.99999994 1.00000012 1 ) -5.96046377e-08 + + it "fnmaddFloat#: +0 + +0 rounds properly" + ( fnmadd @Float 1 0 0 ) 0 + + it "fnmaddFloat#: +0 + -0 rounds properly" + ( fnmadd @Float 1 0 -0 ) -0 + + it "fnmaddFloat#: -0 + +0 rounds properly" + ( fnmadd @Float 1 -0 0 ) 0 + + it "fnmaddFloat#: -0 + -0 rounds properly" + ( fnmadd @Float 1 -0 -0 ) 0 + + -- Double + it "fnmaddDouble#: sniff test" + ( fnmadd @Double 2 3 1 ) -5.0 + + it "fnmaddDouble#: excess precision" + ( fnmadd @Double 0.99999999999999989 1.0000000000000002 1 ) -1.1102230246251563e-16 + + it "fnmaddDouble#: +0 + +0 rounds properly" + ( fnmadd @Double 1 0 0 ) 0 + + it "fnmaddDouble#: +0 + -0 rounds properly" + ( fnmadd @Double 1 0 -0 ) -0 + + it "fnmaddDouble#: -0 + +0 rounds properly" + ( fnmadd @Double 1 -0 0 ) 0 + + it "fnmaddDouble#: -0 + -0 rounds properly" + ( fnmadd @Double 1 -0 -0 ) 0 + + -- fnmsub: - x * y - z + + -- Float + it "fnmsubFloat#: sniff test" + ( fnmsub @Float 2 3 1 ) -7 + + it "fnmsubFloat#: excess precision" + ( fnmsub @Float 0.99999994 1.00000012 -1 ) -5.96046377e-08 + + it "fnmsubFloat#: +0 + +0 rounds properly" + ( fnmsub @Float 1 0 0 ) -0 + + it "fnmsubFloat#: +0 + -0 rounds properly" + ( fnmsub @Float 1 0 -0 ) 0 + + it "fnmsubFloat#: -0 + +0 rounds properly" + ( fnmsub @Float 1 -0 0 ) 0 + + it "fnmsubFloat#: -0 + -0 rounds properly" + ( fnmsub @Float 1 -0 -0 ) 0 + + -- Double + it "fnmsubDouble#: sniff test" + ( fnmsub @Double 2 3 1 ) -7 + + it "fnmsubDouble#: excess precision" + ( fnmsub @Double 0.99999999999999989 1.0000000000000002 -1 ) -1.1102230246251563e-16 + + it "fnmsubDouble#: +0 + +0 rounds properly" + ( fnmsub @Double 1 0 0 ) -0 + + it "fnmsubDouble#: +0 + -0 rounds properly" + ( fnmsub @Double 1 0 -0 ) 0 + + it "fnmsubDouble#: -0 + +0 rounds properly" + ( fnmsub @Double 1 -0 0 ) 0 + + it "fnmsubDouble#: -0 + -0 
rounds properly" + ( fnmsub @Double 1 -0 -0 ) 0 + + failure <- readIORef exit_ref + if failure + then exitFailure + else exitSuccess diff --git a/testsuite/tests/primops/should_run/FMA_Primops.hs b/testsuite/tests/primops/should_run/FMA_Primops.hs new file mode 100644 index 0000000000..a925ff6c3b --- /dev/null +++ b/testsuite/tests/primops/should_run/FMA_Primops.hs @@ -0,0 +1,264 @@ +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE LexicalNegation #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE TypeApplications #-} + +module Main where + +import Control.Monad + ( unless ) +import Data.IORef + ( newIORef, readIORef, writeIORef ) +import GHC.Exts + ( Float(..), Float#, Double(..), Double# + , fmaddFloat# , fmsubFloat# , fnmaddFloat# , fnmsubFloat# + , fmaddDouble#, fmsubDouble#, fnmaddDouble#, fnmsubDouble# + ) +import GHC.Float + ( castFloatToWord32, castDoubleToWord64 ) +import System.Exit + ( exitFailure, exitSuccess ) + +-------------------------------------------------------------------------------- + +-- Want "on-the-nose" equality to make sure we properly distinguish 0.0 and -0.0. 
+class StrictEq a where + strictlyEqual :: a -> a -> Bool +instance StrictEq Float where + strictlyEqual x y = castFloatToWord32 x == castFloatToWord32 y +instance StrictEq Double where + strictlyEqual x y = castDoubleToWord64 x == castDoubleToWord64 y + +class FMA a where + fmadd, fmsub, fnmadd, fnmsub :: a -> a -> a -> a +instance FMA Float where + fmadd (F# x) (F# y) (F# z) = F# ( fmaddFloat# x y z) + fmsub (F# x) (F# y) (F# z) = F# ( fmsubFloat# x y z) + fnmadd (F# x) (F# y) (F# z) = F# (fnmaddFloat# x y z) + fnmsub (F# x) (F# y) (F# z) = F# (fnmsubFloat# x y z) + {-# INLINE fmadd #-} + {-# INLINE fnmadd #-} + {-# INLINE fmsub #-} + {-# INLINE fnmsub #-} +instance FMA Double where + fmadd (D# x) (D# y) (D# z) = D# ( fmaddDouble# x y z) + fmsub (D# x) (D# y) (D# z) = D# ( fmsubDouble# x y z) + fnmadd (D# x) (D# y) (D# z) = D# (fnmaddDouble# x y z) + fnmsub (D# x) (D# y) (D# z) = D# (fnmsubDouble# x y z) + {-# INLINE fmadd #-} + {-# INLINE fnmadd #-} + {-# INLINE fmsub #-} + {-# INLINE fnmsub #-} + +f1, f2, f3 :: Float +f1 = 0.99999994 -- float before 1 +f2 = 1.00000012 -- float after 1 +f3 = 5.96046377e-08 -- expected f1 * f2 - 1 with a single rounding +d1, d2, d3 :: Double +d1 = 0.99999999999999989 -- double before 1 +d2 = 1.0000000000000002 -- double after 1 +d3 = 1.1102230246251563e-16 -- expected d1 * d2 - 1 with a single rounding +zero, one, two, three, five, seven :: Num a => a +zero = 0 +one = 1 +two = 2 +three = 3 +five = 5 +seven = 7 + +-- NOINLINE to prevent constant folding +-- (The test FMA_ConstantFold tests constant folding.) +{-# NOINLINE f1 #-} +{-# NOINLINE f2 #-} +{-# NOINLINE f3 #-} +{-# NOINLINE d1 #-} +{-# NOINLINE d2 #-} +{-# NOINLINE d3 #-} +{-# NOINLINE zero #-} +{-# NOINLINE one #-} +{-# NOINLINE two #-} +{-# NOINLINE three #-} +{-# NOINLINE five #-} +{-# NOINLINE seven #-} + +main :: IO () +main = do + + exit_ref <- newIORef False + + let + it :: forall a. 
( StrictEq a, Show a ) => String -> a -> a -> IO () + it desc actual expected = + unless (actual `strictlyEqual` expected) do + writeIORef exit_ref True + putStrLn $ unlines + [ "FAIL " ++ desc + , " expected: " ++ show expected + , " actual: " ++ show actual ] + + -- NB: throughout this test, we are using "round to nearest". + + -- fmadd: x * y + z + + -- Float + it "fmaddFloat#: sniff test" + ( fmadd @Float two three one ) seven + + it "fmaddFloat#: excess precision" + ( fmadd @Float f1 f2 -one ) f3 + + it "fmaddFloat#: +0 + +0 rounds properly" + ( fmadd @Float one zero zero ) zero + + it "fmaddFloat#: +0 + -0 rounds properly" + ( fmadd @Float one zero -zero ) zero + + it "fmaddFloat#: -0 + +0 rounds properly" + ( fmadd @Float one -zero zero ) zero + + it "fmaddFloat#: -0 + -0 rounds properly" + ( fmadd @Float one -zero -zero ) -zero + + -- Double + it "fmaddDouble#: sniff test" + ( fmadd @Double two three one ) seven + + it "fmaddDouble#: excess precision" + ( fmadd @Double d1 d2 -one ) d3 + + it "fmaddDouble#: +0 + +0 rounds properly" + ( fmadd @Double one zero zero ) zero + + it "fmaddDouble#: +0 + -0 rounds properly" + ( fmadd @Double one zero -zero ) zero + + it "fmaddDouble#: -0 + +0 rounds properly" + ( fmadd @Double one -zero zero ) zero + + it "fmaddDouble#: -0 + -0 rounds properly" + ( fmadd @Double one -zero -zero ) -zero + + -- fmsub: x * y - z + + -- Float + it "fmsubFloat#: sniff test" + ( fmsub @Float two three one ) five + + it "fmsubFloat#: excess precision" + ( fmsub @Float f1 f2 one ) f3 + + it "fmsubFloat#: +0 + +0 rounds properly" + ( fmsub @Float one zero zero ) zero + + it "fmsubFloat#: +0 + -0 rounds properly" + ( fmsub @Float one zero -zero ) zero + + it "fmsubFloat#: -0 + +0 rounds properly" + ( fmsub @Float one -zero zero ) -zero + + it "fmsubFloat#: -0 + -0 rounds properly" + ( fmsub @Float one -zero -zero ) zero + + -- Double + it "fmsubDouble#: sniff test" + ( fmsub @Double two three one ) five + + it "fmsubDouble#: excess precision" 
+ ( fmsub @Double d1 d2 one ) d3 + + it "fmsubDouble#: +0 + +0 rounds properly" + ( fmsub @Double one zero zero ) zero + + it "fmsubDouble#: +0 + -0 rounds properly" + ( fmsub @Double one zero -zero ) zero + + it "fmsubDouble#: -0 + +0 rounds properly" + ( fmsub @Double one -zero zero ) -zero + + it "fmsubDouble#: -0 + -0 rounds properly" + ( fmsub @Double one -zero -zero ) zero + + -- fnmadd: -(x * y) + z + + -- Float + it "fnmaddFloat#: sniff test" + ( fnmadd @Float two three one ) -five + + it "fnmaddFloat#: excess precision" + ( fnmadd @Float f1 f2 one ) -f3 + + it "fnmaddFloat#: +0 + +0 rounds properly" + ( fnmadd @Float one zero zero ) zero + + it "fnmaddFloat#: +0 + -0 rounds properly" + ( fnmadd @Float one zero -zero ) -zero + + it "fnmaddFloat#: -0 + +0 rounds properly" + ( fnmadd @Float one -zero zero ) zero + + it "fnmaddFloat#: -0 + -0 rounds properly" + ( fnmadd @Float one -zero -zero ) zero + + -- Double + it "fnmaddDouble#: sniff test" + ( fnmadd @Double two three one ) -five + + it "fnmaddDouble#: excess precision" + ( fnmadd @Double d1 d2 one ) -d3 + + it "fnmaddDouble#: +0 + +0 rounds properly" + ( fnmadd @Double one zero zero ) zero + + it "fnmaddDouble#: +0 + -0 rounds properly" + ( fnmadd @Double one zero -zero ) -zero + + it "fnmaddDouble#: -0 + +0 rounds properly" + ( fnmadd @Double one -zero zero ) zero + + it "fnmaddDouble#: -0 + -0 rounds properly" + ( fnmadd @Double one -zero -zero ) zero + + -- fnmsub: -(x * y) - z + + -- Float + it "fnmsubFloat#: sniff test" + ( fnmsub @Float two three one ) -seven + + it "fnmsubFloat#: excess precision" + ( fnmsub @Float f1 f2 -one ) -f3 + + it "fnmsubFloat#: +0 + +0 rounds properly" + ( fnmsub @Float one zero zero ) -zero + + it "fnmsubFloat#: +0 + -0 rounds properly" + ( fnmsub @Float one zero -zero ) zero + + it "fnmsubFloat#: -0 + +0 rounds properly" + ( fnmsub @Float one -zero zero ) zero + + it "fnmsubFloat#: -0 + -0 rounds properly" + ( fnmsub @Float one -zero -zero ) zero + + -- Double + it 
"fnmsubDouble#: sniff test" + ( fnmsub @Double two three one ) -seven + + it "fnmsubDouble#: excess precision" + ( fnmsub @Double d1 d2 -one ) -d3 + + it "fnmsubDouble#: +0 + +0 rounds properly" + ( fnmsub @Double one zero zero ) -zero + + it "fnmsubDouble#: +0 + -0 rounds properly" + ( fnmsub @Double one zero -zero ) zero + + it "fnmsubDouble#: -0 + +0 rounds properly" + ( fnmsub @Double one -zero zero ) zero + + it "fnmsubDouble#: -0 + -0 rounds properly" + ( fnmsub @Double one -zero -zero ) zero + + failure <- readIORef exit_ref + if failure + then exitFailure + else exitSuccess diff --git a/testsuite/tests/primops/should_run/all.T b/testsuite/tests/primops/should_run/all.T index 4148546280..da6378df84 100644 --- a/testsuite/tests/primops/should_run/all.T +++ b/testsuite/tests/primops/should_run/all.T @@ -59,5 +59,16 @@ test('UnliftedTVar1', normal, compile_and_run, ['']) test('UnliftedTVar2', normal, compile_and_run, ['']) test('UnliftedWeakPtr', normal, compile_and_run, ['']) +test('FMA_Primops' + , [ when(have_cpu_feature('fma'), extra_hc_opts('-mfma')) + , js_skip # JS backend doesn't have an FMA implementation + ] + , compile_and_run, ['']) +test('FMA_ConstantFold' + , [ js_skip # JS backend doesn't have an FMA implementation + , expect_broken(21227) + ] + , compile_and_run, ['-O']) + test('T21624', normal, compile_and_run, ['']) test('T23071', ignore_stdout, compile_and_run, ['']) -- cgit v1.2.1