From c3386502628712339b89ed607b03b14ddf12c9d2 Mon Sep 17 00:00:00 2001 From: Sebastian Graf Date: Thu, 21 Nov 2019 10:24:39 +0100 Subject: Stricten functions ins GHC.Natural This brings `Natural` on par with `Integer` and fixes #17499. Also does some manual CSE for 0 and 1 literals. --- libraries/base/GHC/Natural.hs | 80 ++++++++++++---------- libraries/base/tests/T17499.hs | 16 +++++ libraries/base/tests/all.T | 1 + .../tests/numeric/should_compile/T14465.stdout | 2 +- 4 files changed, 63 insertions(+), 36 deletions(-) create mode 100644 libraries/base/tests/T17499.hs diff --git a/libraries/base/GHC/Natural.hs b/libraries/base/GHC/Natural.hs index 93c67b6c7a..e65b41a7e2 100644 --- a/libraries/base/GHC/Natural.hs +++ b/libraries/base/GHC/Natural.hs @@ -147,6 +147,9 @@ data Natural = NatS# GmpLimb# -- ^ in @[0, maxBound::Word]@ , Ord -- ^ @since 4.8.0.0 ) +zero, one :: Natural +zero = NatS# 0## +one = NatS# 1## -- | Test whether all internal invariants are satisfied by 'Natural' value -- @@ -162,12 +165,12 @@ isValidNatural (NatJ# bn) = isTrue# (isValidBigNat# bn) && isTrue# (sizeofBigNat# bn ># 1#) signumNatural :: Natural -> Natural -signumNatural (NatS# 0##) = NatS# 0## -signumNatural _ = NatS# 1## +signumNatural (NatS# 0##) = zero +signumNatural _ = one -- {-# CONSTANT_FOLDED signumNatural #-} negateNatural :: Natural -> Natural -negateNatural (NatS# 0##) = NatS# 0## +negateNatural (NatS# 0##) = zero negateNatural _ = underflowError -- {-# CONSTANT_FOLDED negateNatural #-} @@ -183,8 +186,8 @@ naturalFromInteger _ = underflowError gcdNatural :: Natural -> Natural -> Natural gcdNatural (NatS# 0##) y = y gcdNatural x (NatS# 0##) = x -gcdNatural (NatS# 1##) _ = NatS# 1## -gcdNatural _ (NatS# 1##) = NatS# 1## +gcdNatural (NatS# 1##) _ = one +gcdNatural _ (NatS# 1##) = one gcdNatural (NatJ# x) (NatJ# y) = bigNatToNatural (gcdBigNat x y) gcdNatural (NatJ# x) (NatS# y) = NatS# (gcdBigNatWord x y) gcdNatural (NatS# x) (NatJ# y) = NatS# (gcdBigNatWord y x) @@ -192,18 +195,20 @@ gcdNatural (NatS# x) (NatS# y) = NatS# (gcdWord x y) -- | Compute least common multiple. lcmNatural :: Natural -> Natural -> Natural -lcmNatural (NatS# 0##) _ = NatS# 0## -lcmNatural _ (NatS# 0##) = NatS# 0## -lcmNatural (NatS# 1##) y = y -lcmNatural x (NatS# 1##) = x -lcmNatural x y = (x `quotNatural` (gcdNatural x y)) `timesNatural` y +-- Make sure we are strict in all arguments (#17499) +lcmNatural (NatS# 0##) !_ = zero +lcmNatural _ (NatS# 0##) = zero +lcmNatural (NatS# 1##) y = y +lcmNatural x (NatS# 1##) = x +lcmNatural x y = (x `quotNatural` (gcdNatural x y)) `timesNatural` y ---------------------------------------------------------------------------- quotRemNatural :: Natural -> Natural -> (Natural, Natural) -quotRemNatural _ (NatS# 0##) = divZeroError -quotRemNatural n (NatS# 1##) = (n,NatS# 0##) -quotRemNatural n@(NatS# _) (NatJ# _) = (NatS# 0##, n) +-- Make sure we are strict in all arguments (#17499) +quotRemNatural !_ (NatS# 0##) = divZeroError +quotRemNatural n (NatS# 1##) = (n,zero) +quotRemNatural n@(NatS# _) (NatJ# _) = (zero, n) quotRemNatural (NatS# n) (NatS# d) = case quotRemWord# n d of (# q, r #) -> (NatS# q, NatS# r) quotRemNatural (NatJ# n) (NatS# d) = case quotRemBigNatWord n d of @@ -213,21 +218,23 @@ quotRemNatural (NatJ# n) (NatJ# d) = case quotRemBigNat n d of -- {-# CONSTANT_FOLDED quotRemNatural #-} quotNatural :: Natural -> Natural -> Natural -quotNatural _ (NatS# 0##) = divZeroError -quotNatural n (NatS# 1##) = n -quotNatural (NatS# _) (NatJ# _) = NatS# 0## -quotNatural (NatS# n) (NatS# d) = NatS# (quotWord# n d) -quotNatural (NatJ# n) (NatS# d) = bigNatToNatural (quotBigNatWord n d) -quotNatural (NatJ# n) (NatJ# d) = bigNatToNatural (quotBigNat n d) +-- Make sure we are strict in all arguments (#17499) +quotNatural !_ (NatS# 0##) = divZeroError +quotNatural n (NatS# 1##) = n +quotNatural (NatS# _) (NatJ# _) = zero +quotNatural (NatS# n) (NatS# d) = NatS# (quotWord# n d) +quotNatural (NatJ# n) (NatS# d) = bigNatToNatural (quotBigNatWord n d) +quotNatural (NatJ# n) (NatJ# d) = bigNatToNatural (quotBigNat n d) -- {-# CONSTANT_FOLDED quotNatural #-} remNatural :: Natural -> Natural -> Natural -remNatural _ (NatS# 0##) = divZeroError -remNatural _ (NatS# 1##) = NatS# 0## -remNatural n@(NatS# _) (NatJ# _) = n -remNatural (NatS# n) (NatS# d) = NatS# (remWord# n d) -remNatural (NatJ# n) (NatS# d) = NatS# (remBigNatWord n d) -remNatural (NatJ# n) (NatJ# d) = bigNatToNatural (remBigNat n d) +-- Make sure we are strict in all arguments (#17499) +remNatural !_ (NatS# 0##) = divZeroError +remNatural _ (NatS# 1##) = zero +remNatural n@(NatS# _) (NatJ# _) = n +remNatural (NatS# n) (NatS# d) = NatS# (remWord# n d) +remNatural (NatJ# n) (NatS# d) = NatS# (remBigNatWord n d) +remNatural (NatJ# n) (NatJ# d) = bigNatToNatural (remBigNat n d) -- {-# CONSTANT_FOLDED remNatural #-} -- | @since 4.X.0.0 @@ -278,7 +285,7 @@ popCountNatural (NatJ# bn) = I# (popCountBigNat bn) shiftLNatural :: Natural -> Int -> Natural shiftLNatural n (I# 0#) = n -shiftLNatural (NatS# 0##) _ = NatS# 0## +shiftLNatural (NatS# 0##) _ = zero shiftLNatural (NatS# 1##) (I# i#) = bitNatural i# shiftLNatural (NatS# w) (I# i#) = bigNatToNatural (shiftLBigNat (wordToBigNat w) i#) @@ -289,7 +296,7 @@ shiftLNatural (NatJ# bn) (I# i#) shiftRNatural :: Natural -> Int -> Natural shiftRNatural n (I# 0#) = n shiftRNatural (NatS# w) (I# i#) - | isTrue# (i# >=# WORD_SIZE_IN_BITS#) = NatS# 0## + | isTrue# (i# >=# WORD_SIZE_IN_BITS#) = zero | True = NatS# (w `uncheckedShiftRL#` i#) shiftRNatural (NatJ# bn) (I# i#) = bigNatToNatural (shiftRBigNat bn i#) -- {-# CONSTANT_FOLDED shiftRNatural #-} @@ -311,8 +318,9 @@ plusNatural (NatJ# x) (NatJ# y) = NatJ# (plusBigNat x y) -- | 'Natural' multiplication timesNatural :: Natural -> Natural -> Natural -timesNatural _ (NatS# 0##) = NatS# 0## -timesNatural (NatS# 0##) _ = NatS# 0## +-- Make sure we are strict in all arguments (#17499) +timesNatural !_ (NatS# 0##) = zero +timesNatural (NatS# 0##) _ = zero timesNatural x (NatS# 1##) = x timesNatural (NatS# 1##) y = y timesNatural (NatS# x) (NatS# y) = case timesWord2# x y of @@ -342,7 +350,8 @@ minusNatural (NatJ# x) (NatJ# y) -- -- @since 4.8.0.0 minusNaturalMaybe :: Natural -> Natural -> Maybe Natural -minusNaturalMaybe x (NatS# 0##) = Just x +-- Make sure we are strict in all arguments (#17499) +minusNaturalMaybe !x (NatS# 0##) = Just x minusNaturalMaybe (NatS# x) (NatS# y) = case subWordC# x y of (# l, 0# #) -> Just (NatS# l) _ -> Nothing @@ -575,11 +584,12 @@ naturalToWordMaybe (Natural i) -- @since 4.8.0.0 powModNatural :: Natural -> Natural -> Natural -> Natural #if defined(MIN_VERSION_integer_gmp) -powModNatural _ _ (NatS# 0##) = divZeroError -powModNatural _ _ (NatS# 1##) = NatS# 0## -powModNatural _ (NatS# 0##) _ = NatS# 1## -powModNatural (NatS# 0##) _ _ = NatS# 0## -powModNatural (NatS# 1##) _ _ = NatS# 1## +-- Make sure we are strict in all arguments (#17499) +powModNatural !_ !_ (NatS# 0##) = divZeroError +powModNatural _ _ (NatS# 1##) = zero +powModNatural _ (NatS# 0##) _ = one +powModNatural (NatS# 0##) _ _ = zero +powModNatural (NatS# 1##) _ _ = one powModNatural (NatS# b) (NatS# e) (NatS# m) = NatS# (powModWord b e m) powModNatural b e (NatS# m) = NatS# (powModBigNatWord (naturalToBigNat b) (naturalToBigNat e) m) diff --git a/libraries/base/tests/T17499.hs b/libraries/base/tests/T17499.hs new file mode 100644 index 0000000000..512140c1b0 --- /dev/null +++ b/libraries/base/tests/T17499.hs @@ -0,0 +1,16 @@ +import Numeric.Natural + +import Control.Exception (evaluate) + +newtype Mod a = Mod a deriving (Show) + +instance Integral a => Num (Mod a) where + Mod a * Mod b = Mod (a * b `mod` 10000000019) + fromInteger n = Mod (fromInteger n `mod` 10000000019) + +main :: IO () +main = do + -- Should not allocate more compared to Integer + -- _ <- evaluate $ product $ map Mod [(1 :: Integer) .. 1000000] + _ <- evaluate $ product $ map Mod [(1 :: Natural) .. 1000000] + return () diff --git a/libraries/base/tests/all.T b/libraries/base/tests/all.T index 32dfaecf31..e5130d0348 100644 --- a/libraries/base/tests/all.T +++ b/libraries/base/tests/all.T @@ -253,3 +253,4 @@ test('T15349', [exit_code(1), expect_broken_for(15349, ['ghci'])], compile_and_r test('T16111', exit_code(1), compile_and_run, ['']) test('T16943a', normal, compile_and_run, ['']) test('T16943b', normal, compile_and_run, ['']) +test('T17499', collect_stats('bytes allocated',5), compile_and_run, ['-O -w']) diff --git a/testsuite/tests/numeric/should_compile/T14465.stdout b/testsuite/tests/numeric/should_compile/T14465.stdout index df97060635..b7c88c40ac 100644 --- a/testsuite/tests/numeric/should_compile/T14465.stdout +++ b/testsuite/tests/numeric/should_compile/T14465.stdout @@ -72,7 +72,7 @@ minusOne NatS# ds1 -> case ds1 of { __DEFAULT -> GHC.Natural.underflowError @ Natural; - 0## -> GHC.Natural.lcmNatural1 + 0## -> GHC.Natural.zero }; NatJ# ipv -> GHC.Natural.underflowError @ Natural } -- cgit v1.2.1