diff options
author | Herbert Valerio Riedel <hvr@gnu.org> | 2014-01-09 00:19:31 +0100 |
---|---|---|
committer | Herbert Valerio Riedel <hvr@gnu.org> | 2014-01-13 12:42:02 +0100 |
commit | 7fdd02695d7cff44b534d82dfcdb5a98191e35cf (patch) | |
tree | 2fa096d715216416b771dd50adb2220bc9d88132 /libraries/integer-gmp | |
parent | 069a49cc578fd8d7227e00af5e56b56c6a01efe6 (diff) | |
download | haskell-7fdd02695d7cff44b534d82dfcdb5a98191e35cf.tar.gz |
Allocate initial 1-limb mpz_t on the Stack and introduce MPZ# type
We now allocate a 1-limb mpz_t on the stack instead of doing a more
expensive heap-allocation (especially if the heap-allocated copy becomes
garbage right away); this addresses #8647.
In order to delay heap allocations of 1-limb `ByteArray#`s, a 3-tuple
`(# Int#, ByteArray#, Word# #)` is now returned instead of the previous
`(# Int#, ByteArray# #)` pair. This tuple is given the
type-synonym `MPZ#`.
This 3-tuple representation uses either the 1st and the 2nd element, or
the 1st and the 3rd element to represent the limb(s) (NB: undefined
`ByteArray#` elements must not be accessed as they don't point to a
proper `ByteArray#`, see also `DUMMY_BYTE_ARR`); more specifically, the
following encoding is used (where `⊥` means undefined/unused):
- (# 0#, ⊥, 0## #) -> value = 0
- (# 1#, ⊥, w #) -> value = w
- (# -1#, ⊥, w #) -> value = -w
- (# s#, d, 0## #) -> value = J# s d
The `mpzToInteger` helper takes care of converting `MPZ#` into an
`Integer`, and allocating a 1-limb `ByteArray#` in case the
value (`w`/`-w`) doesn't fit the `S# Int#` representation.
The following nofib benchmarks benefit from this optimization:
Program Size Allocs Runtime Elapsed TotalMem
------------------------------------------------------------------
bernouilli +0.2% -5.2% 0.12 0.12 +0.0%
gamteb +0.2% -1.7% 0.03 0.03 +0.0%
kahan +0.3% -13.2% 0.17 0.17 +0.0%
mandel +0.2% -24.6% 0.04 0.04 +0.0%
power +0.2% -2.6% -2.0% -2.0% -8.3%
primetest +0.1% -17.3% 0.06 0.06 +0.0%
rsa +0.2% -18.5% 0.02 0.02 +0.0%
scs +0.1% -2.9% -0.1% -0.1% +0.0%
sphere +0.3% -0.8% 0.03 0.03 +0.0%
symalg +0.2% -3.1% 0.01 0.01 +0.0%
------------------------------------------------------------------
Min +0.1% -24.6% -4.6% -4.6% -8.3%
Max +0.3% +0.0% +5.9% +5.9% +4.5%
Geometric Mean +0.2% -1.0% +0.2% +0.2% -0.0%
Signed-off-by: Herbert Valerio Riedel <hvr@gnu.org>
Diffstat (limited to 'libraries/integer-gmp')
-rw-r--r-- | libraries/integer-gmp/GHC/Integer/GMP/Prim.hs | 105 | ||||
-rw-r--r-- | libraries/integer-gmp/GHC/Integer/Type.lhs | 164 | ||||
-rw-r--r-- | libraries/integer-gmp/cbits/gmp-wrappers.cmm | 224 |
3 files changed, 306 insertions, 187 deletions
diff --git a/libraries/integer-gmp/GHC/Integer/GMP/Prim.hs b/libraries/integer-gmp/GHC/Integer/GMP/Prim.hs index 261df29897..3790345dcb 100644 --- a/libraries/integer-gmp/GHC/Integer/GMP/Prim.hs +++ b/libraries/integer-gmp/GHC/Integer/GMP/Prim.hs @@ -4,6 +4,8 @@ #include "MachDeps.h" module GHC.Integer.GMP.Prim ( + MPZ#, + cmpInteger#, cmpIntegerInt#, @@ -79,6 +81,41 @@ import GHC.Types -- Double isn't available yet, and we shouldn't be using defaults anyway: default () +-- | This is represents a @mpz_t@ value in a heap-saving way. +-- +-- The first tuple element, @/s/@, encodes the sign of the integer +-- @/i/@ (i.e. @signum /s/ == signum /i/@), and the number of /limbs/ +-- used to represent the magnitude. If @abs /s/ > 1@, the 'ByteArray#' +-- contains @abs /s/@ limbs encoding the integer. Otherwise, if @abs +-- /s/ < 2@, the single limb is stored in the 'Word#' element instead +-- (and the 'ByteArray#' element is undefined and MUST NOT be accessed +-- as it doesn't point to a proper 'ByteArray#' but rather to an +-- unsafe-coerced 'Int' in order be polite to the GC -- see +-- @DUMMY_BYTE_ARR@ in gmp-wrappers.cmm) +-- +-- More specifically, the following encoding is used (where `⊥` means +-- undefined/unused): +-- +-- * (# 0#, ⊥, 0## #) -> value = 0 +-- * (# 1#, ⊥, w #) -> value = w +-- * (# -1#, ⊥, w #) -> value = -w +-- * (# s#, d, 0## #) -> value = J# s d +-- +-- This representation allows to avoid temporary heap allocations +-- (-> Trac #8647) of 1-limb 'ByteArray#'s which fit into the +-- 'S#'-constructor. Moreover, this allows to delays 1-limb +-- 'ByteArray#' heap allocations, as such 1-limb `mpz_t`s can be +-- optimistically allocated on the Cmm stack and returned as a @#word@ +-- in case the `mpz_t` wasn't grown beyond 1 limb by the GMP +-- operation. 
+-- +-- See also the 'GHC.Integer.Type.mpzToInteger' function which ought +-- to be used for converting 'MPZ#'s to 'Integer's and the +-- @MP_INT_1LIMB_RETURN()@ macro in @gmp-wrappers.cmm@ which +-- constructs 'MPZ#' values in the first place for implementation +-- details. +type MPZ# = (# Int#, ByteArray#, Word# #) + -- | Returns -1,0,1 according as first argument is less than, equal to, or greater than second argument. -- foreign import prim "integer_cmm_cmpIntegerzh" cmpInteger# @@ -92,87 +129,87 @@ foreign import prim "integer_cmm_cmpIntegerIntzh" cmpIntegerInt# -- | -- foreign import prim "integer_cmm_plusIntegerzh" plusInteger# - :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ# -- | Optimized version of 'plusInteger#' for summing big-ints with small-ints -- foreign import prim "integer_cmm_plusIntegerIntzh" plusIntegerInt# - :: Int# -> ByteArray# -> Int# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> MPZ# -- | -- foreign import prim "integer_cmm_minusIntegerzh" minusInteger# - :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ# -- | Optimized version of 'minusInteger#' for substracting small-ints from big-ints -- foreign import prim "integer_cmm_minusIntegerIntzh" minusIntegerInt# - :: Int# -> ByteArray# -> Int# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> MPZ# -- | -- foreign import prim "integer_cmm_timesIntegerzh" timesInteger# - :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ# -- | Optimized version of 'timesInteger#' for multiplying big-ints with small-ints -- foreign import prim "integer_cmm_timesIntegerIntzh" timesIntegerInt# - :: Int# -> ByteArray# -> Int# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> MPZ# -- | Compute div and mod simultaneously, where div rounds towards negative -- 
infinity and\ @(q,r) = divModInteger#(x,y)@ implies -- @plusInteger# (timesInteger# q y) r = x@. -- foreign import prim "integer_cmm_quotRemIntegerzh" quotRemInteger# - :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray#, Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# MPZ#, MPZ# #) -- | Variant of 'quotRemInteger#' -- foreign import prim "integer_cmm_quotRemIntegerWordzh" quotRemIntegerWord# - :: Int# -> ByteArray# -> Word# -> (# Int#, ByteArray#, Int#, ByteArray# #) + :: Int# -> ByteArray# -> Word# -> (# MPZ#, MPZ# #) -- | Rounds towards zero. -- foreign import prim "integer_cmm_quotIntegerzh" quotInteger# - :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ# -- | Rounds towards zero. foreign import prim "integer_cmm_quotIntegerWordzh" quotIntegerWord# - :: Int# -> ByteArray# -> Word# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Word# -> MPZ# -- | Satisfies \texttt{plusInteger\# (timesInteger\# (quotInteger\# x y) y) (remInteger\# x y) == x}. -- foreign import prim "integer_cmm_remIntegerzh" remInteger# - :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ# -- | Variant of 'remInteger#' foreign import prim "integer_cmm_remIntegerWordzh" remIntegerWord# - :: Int# -> ByteArray# -> Word# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Word# -> MPZ# -- | Compute div and mod simultaneously, where div rounds towards negative infinity -- and\texttt{(q,r) = divModInteger\#(x,y)} implies \texttt{plusInteger\# (timesInteger\# q y) r = x}. 
-- foreign import prim "integer_cmm_divModIntegerzh" divModInteger# - :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray#, Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# MPZ#, MPZ# #) foreign import prim "integer_cmm_divIntegerzh" divInteger# - :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ# foreign import prim "integer_cmm_modIntegerzh" modInteger# - :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ# -- | Divisor is guaranteed to be a factor of dividend. -- foreign import prim "integer_cmm_divExactIntegerzh" divExactInteger# - :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ# -- | Greatest common divisor. -- foreign import prim "integer_cmm_gcdIntegerzh" gcdInteger# - :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ# -- | Extended greatest common divisor. -- foreign import prim "integer_cmm_gcdExtIntegerzh" gcdExtInteger# - :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray#, Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# MPZ#, MPZ# #) -- | Greatest common divisor, where second argument is an ordinary {\tt Int\#}. -- @@ -189,32 +226,34 @@ foreign import prim "integer_cmm_gcdIntzh" gcdInt# -- represent an {\tt Integer\#} holding the mantissa. -- foreign import prim "integer_cmm_decodeDoublezh" decodeDouble# - :: Double# -> (# Int#, Int#, ByteArray# #) + :: Double# -> (# Int#, MPZ# #) -- | -- +-- Note: This primitive doesn't use 'MPZ#' because its purpose is to instantiate a 'J#'-value. foreign import prim "integer_cmm_int2Integerzh" int2Integer# :: Int# -> (# Int#, ByteArray# #) -- | -- +-- Note: This primitive doesn't use 'MPZ#' because its purpose is to instantiate a 'J#'-value. 
foreign import prim "integer_cmm_word2Integerzh" word2Integer# :: Word# -> (# Int#, ByteArray# #) -- | -- foreign import prim "integer_cmm_andIntegerzh" andInteger# - :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ# -- | -- foreign import prim "integer_cmm_orIntegerzh" orInteger# - :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ# -- | -- foreign import prim "integer_cmm_xorIntegerzh" xorInteger# - :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ# -- | -- @@ -224,37 +263,37 @@ foreign import prim "integer_cmm_testBitIntegerzh" testBitInteger# -- | -- foreign import prim "integer_cmm_mul2ExpIntegerzh" mul2ExpInteger# - :: Int# -> ByteArray# -> Int# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> MPZ# -- | -- foreign import prim "integer_cmm_fdivQ2ExpIntegerzh" fdivQ2ExpInteger# - :: Int# -> ByteArray# -> Int# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> MPZ# -- | -- foreign import prim "integer_cmm_powIntegerzh" powInteger# - :: Int# -> ByteArray# -> Word# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Word# -> MPZ# -- | -- foreign import prim "integer_cmm_powModIntegerzh" powModInteger# - :: Int# -> ByteArray# -> Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ# -- | -- foreign import prim "integer_cmm_powModSecIntegerzh" powModSecInteger# - :: Int# -> ByteArray# -> Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ# -- | -- foreign import prim "integer_cmm_recipModIntegerzh" recipModInteger# - :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ# -- | -- foreign 
import prim "integer_cmm_nextPrimeIntegerzh" nextPrimeInteger# - :: Int# -> ByteArray# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> MPZ# -- | -- @@ -269,12 +308,12 @@ foreign import prim "integer_cmm_sizeInBasezh" sizeInBaseInteger# -- | -- foreign import prim "integer_cmm_importIntegerFromByteArrayzh" importIntegerFromByteArray# - :: ByteArray# -> Word# -> Word# -> Int# -> (# Int#, ByteArray# #) + :: ByteArray# -> Word# -> Word# -> Int# -> MPZ# -- | -- foreign import prim "integer_cmm_importIntegerFromAddrzh" importIntegerFromAddr# - :: Addr# -> Word# -> Int# -> State# s -> (# State# s, Int#, ByteArray# #) + :: Addr# -> Word# -> Int# -> State# s -> (# State# s, MPZ# #) -- | -- @@ -289,12 +328,14 @@ foreign import prim "integer_cmm_exportIntegerToAddrzh" exportIntegerToAddr# -- | -- foreign import prim "integer_cmm_complementIntegerzh" complementInteger# - :: Int# -> ByteArray# -> (# Int#, ByteArray# #) + :: Int# -> ByteArray# -> MPZ# #if WORD_SIZE_IN_BITS < 64 +-- Note: This primitive doesn't use 'MPZ#' because its purpose is to instantiate a 'J#'-value. foreign import prim "integer_cmm_int64ToIntegerzh" int64ToInteger# :: Int64# -> (# Int#, ByteArray# #) +-- Note: This primitive doesn't use 'MPZ#' because its purpose is to instantiate a 'J#'-value. 
foreign import prim "integer_cmm_word64ToIntegerzh" word64ToInteger# :: Word64# -> (# Int#, ByteArray# #) diff --git a/libraries/integer-gmp/GHC/Integer/Type.lhs b/libraries/integer-gmp/GHC/Integer/Type.lhs index 731c5fcdf1..ab4fe9d67d 100644 --- a/libraries/integer-gmp/GHC/Integer/Type.lhs +++ b/libraries/integer-gmp/GHC/Integer/Type.lhs @@ -37,6 +37,7 @@ import GHC.Prim ( import GHC.Integer.GMP.Prim ( -- GMP-related primitives + MPZ#, cmpInteger#, cmpIntegerInt#, plusInteger#, plusIntegerInt#, minusInteger#, minusIntegerInt#, timesInteger#, timesIntegerInt#, @@ -172,6 +173,37 @@ smartJ# (-1#) mb# | isTrue# (v <# 0#) = S# v where v = negateInt# (indexIntArray# mb# 0#) smartJ# s# mb# = J# s# mb# + +-- |Construct 'Integer' out of a 'MPZ#' as returned by GMP wrapper primops +-- +-- IMPORTANT: The 'ByteArray#' element MUST NOT be accessed unless the +-- size-element indicates more than one limb! +-- +-- See notes at definition site of 'MPZ#' in "GHC.Integer.GMP.Prim" +-- for more details. +mpzToInteger :: MPZ# -> Integer +mpzToInteger (# 0#, _, _ #) = S# 0# +mpzToInteger (# 1#, _, w# #) | isTrue# (v# >=# 0#) = S# v# + | True = case word2Integer# w# of (# _, d #) -> J# 1# d + where + v# = word2Int# w# +mpzToInteger (# -1#, _, w# #) | isTrue# (v# <=# 0#) = S# v# + | True = case word2Integer# w# of (# _, d #) -> J# -1# d + where + v# = negateInt# (word2Int# w#) +mpzToInteger (# s#, mb#, _ #) = J# s# mb# + +-- | Variant of 'mpzToInteger' for pairs of 'Integer's +mpzToInteger2 :: (# MPZ#, MPZ# #) -> (# Integer, Integer #) +mpzToInteger2 (# mpz1, mpz2 #) = (# i1, i2 #) + where + !i1 = mpzToInteger mpz1 -- This use of `!` avoids creating thunks, + !i2 = mpzToInteger mpz2 -- see also Note [Use S# if possible]. 
+ +-- |Negate MPZ# +mpzNeg :: MPZ# -> MPZ# +mpzNeg (# s#, mb#, w# #) = (# negateInt# s#, mb#, w# #) + \end{code} Note [Use S# if possible] @@ -221,26 +253,19 @@ Just using smartJ# in this way has good results: {-# NOINLINE quotRemInteger #-} quotRemInteger :: Integer -> Integer -> (# Integer, Integer #) -quotRemInteger a@(S# INT_MINBOUND) b = quotRemInteger (toBig a) b +quotRemInteger (S# INT_MINBOUND) b = quotRemInteger minIntAsBig b quotRemInteger (S# i) (S# j) = case quotRemInt# i j of (# q, r #) -> (# S# q, S# r #) quotRemInteger (J# s1 d1) (S# b) | isTrue# (b <# 0#) = case quotRemIntegerWord# s1 d1 (int2Word# (negateInt# b)) of - (# s3, d3, s4, d4 #) -> let !q = smartJ# (negateInt# s3) d3 - !r = smartJ# s4 d4 - in (# q, r #) + (# q, r #) -> let !q' = mpzToInteger(mpzNeg q) + !r' = mpzToInteger(mpzNeg r) + in (# q', r' #) quotRemInteger (J# s1 d1) (S# b) - = case quotRemIntegerWord# s1 d1 (int2Word# b) of - (# s3, d3, s4, d4 #) -> let !q = smartJ# s3 d3 - !r = smartJ# s4 d4 - in (# q, r #) + = mpzToInteger2(quotRemIntegerWord# s1 d1 (int2Word# b)) quotRemInteger i1@(S# _) i2@(J# _ _) = quotRemInteger (toBig i1) i2 quotRemInteger (J# s1 d1) (J# s2 d2) - = case (quotRemInteger# s1 d1 s2 d2) of - (# s3, d3, s4, d4 #) -> let !q = smartJ# s3 d3 - !r = smartJ# s4 d4 - in (# q, r #) - -- See Note [Use S# if possible] + = mpzToInteger2(quotRemInteger# s1 d1 s2 d2) -- See Note [Use S# if possible] {-# NOINLINE divModInteger #-} divModInteger :: Integer -> Integer -> (# Integer, Integer #) @@ -256,11 +281,7 @@ divModInteger (S# i) (S# j) = (# S# d, S# m #) divModInteger i1@(J# _ _) i2@(S# _) = divModInteger i1 (toBig i2) divModInteger i1@(S# _) i2@(J# _ _) = divModInteger (toBig i1) i2 -divModInteger (J# s1 d1) (J# s2 d2) - = case (divModInteger# s1 d1 s2 d2) of - (# s3, d3, s4, d4 #) -> let !q = smartJ# s3 d3 - !r = smartJ# s4 d4 - in (# q, r #) +divModInteger (J# s1 d1) (J# s2 d2) = mpzToInteger2 (divModInteger# s1 d1 s2 d2) {-# NOINLINE remInteger #-} remInteger :: 
Integer -> Integer -> Integer @@ -276,12 +297,11 @@ remInteger ia@(S# a) (J# sb b) -} remInteger ia@(S# _) ib@(J# _ _) = remInteger (toBig ia) ib remInteger (J# sa a) (S# b) - = case remIntegerWord# sa a w of - (# sr, r #) -> smartJ# sr r + = mpzToInteger (remIntegerWord# sa a w) where w = int2Word# (if isTrue# (b <# 0#) then negateInt# b else b) remInteger (J# sa a) (J# sb b) - = case remInteger# sa a sb b of (# sr, r #) -> smartJ# sr r + = mpzToInteger (remInteger# sa a sb b) {-# NOINLINE quotInteger #-} quotInteger :: Integer -> Integer -> Integer @@ -295,13 +315,11 @@ quotInteger (S# a) (J# sb b) -} quotInteger ia@(S# _) ib@(J# _ _) = quotInteger (toBig ia) ib quotInteger (J# sa a) (S# b) | isTrue# (b <# 0#) - = case quotIntegerWord# sa a (int2Word# (negateInt# b)) of - (# sq, q #) -> smartJ# (negateInt# sq) q + = mpzToInteger (mpzNeg (quotIntegerWord# sa a (int2Word# (negateInt# b)))) quotInteger (J# sa a) (S# b) - = case quotIntegerWord# sa a (int2Word# b) of - (# sq, q #) -> smartJ# sq q + = mpzToInteger (quotIntegerWord# sa a (int2Word# b)) quotInteger (J# sa a) (J# sb b) - = case quotInteger# sa a sb b of (# sg, g #) -> smartJ# sg g + = mpzToInteger (quotInteger# sa a sb b) {-# NOINLINE modInteger #-} modInteger :: Integer -> Integer -> Integer @@ -310,10 +328,9 @@ modInteger (S# a) (S# b) = S# (modInt# a b) modInteger ia@(S# _) ib@(J# _ _) = modInteger (toBig ia) ib modInteger (J# sa a) (S# b) = case int2Integer# b of { (# sb, b' #) -> - case modInteger# sa a sb b' of { (# sr, r #) -> - S# (integer2Int# sr r) }} + mpzToInteger (modInteger# sa a sb b') } modInteger (J# sa a) (J# sb b) - = case modInteger# sa a sb b of (# sr, r #) -> smartJ# sr r + = mpzToInteger (modInteger# sa a sb b) {-# NOINLINE divInteger #-} divInteger :: Integer -> Integer -> Integer @@ -321,10 +338,9 @@ divInteger (S# INT_MINBOUND) b = divInteger minIntAsBig b divInteger (S# a) (S# b) = S# (divInt# a b) divInteger ia@(S# _) ib@(J# _ _) = divInteger (toBig ia) ib divInteger (J# sa a) 
(S# b) - = case int2Integer# b of { (# sb, b' #) -> - case divInteger# sa a sb b' of (# sq, q #) -> smartJ# sq q } + = case int2Integer# b of { (# sb, b' #) -> mpzToInteger (divInteger# sa a sb b') } divInteger (J# sa a) (J# sb b) - = case divInteger# sa a sb b of (# sg, g #) -> smartJ# sg g + = mpzToInteger (divInteger# sa a sb b) \end{code} @@ -344,8 +360,7 @@ gcdInteger ia@(S# a) ib@(J# sb b) where !absA = if isTrue# (a <# 0#) then negateInt# a else a !absSb = if isTrue# (sb <# 0#) then negateInt# sb else sb gcdInteger ia@(J# _ _) ib@(S# _) = gcdInteger ib ia -gcdInteger (J# sa a) (J# sb b) - = case gcdInteger# sa a sb b of (# sg, g #) -> smartJ# sg g +gcdInteger (J# sa a) (J# sb b) = mpzToInteger (gcdInteger# sa a sb b) -- | Extended euclidean algorithm. -- @@ -356,11 +371,7 @@ gcdExtInteger :: Integer -> Integer -> (# Integer, Integer #) gcdExtInteger a@(S# _) b@(S# _) = gcdExtInteger (toBig a) (toBig b) gcdExtInteger a@(S# _) b@(J# _ _) = gcdExtInteger (toBig a) b gcdExtInteger a@(J# _ _) b@(S# _) = gcdExtInteger a (toBig b) -gcdExtInteger (J# sa a) (J# sb b) - = case gcdExtInteger# sa a sb b of - (# sg, g, ss, s #) -> let !g' = smartJ# sg g - !s' = smartJ# ss s - in (# g', s' #) +gcdExtInteger (J# sa a) (J# sb b) = mpzToInteger2 (gcdExtInteger# sa a sb b) -- | Compute least common multiple. 
{-# NOINLINE lcmInteger #-} @@ -387,10 +398,8 @@ divExact (S# a) (J# sb b) = S# (quotInt# a (integer2Int# sb b)) divExact (J# sa a) (S# b) = case int2Integer# b of - (# sb, b' #) -> case divExactInteger# sa a sb b' of - (# sd, d #) -> smartJ# sd d -divExact (J# sa a) (J# sb b) - = case divExactInteger# sa a sb b of (# sd, d #) -> smartJ# sd d + (# sb, b' #) -> mpzToInteger (divExactInteger# sa a sb b') +divExact (J# sa a) (J# sb b) = mpzToInteger (divExactInteger# sa a sb b) \end{code} @@ -529,14 +538,11 @@ plusInteger (S# i) (S# j) = case addIntC# i j of if isTrue# (c ==# 0#) then S# r else case int2Integer# i of - (# s, d #) -> case plusIntegerInt# s d j of - (# s', d' #) -> J# s' d' + (# s, d #) -> mpzToInteger (plusIntegerInt# s d j) plusInteger i1@(J# _ _) (S# 0#) = i1 -plusInteger (J# s1 d1) (S# j) = case plusIntegerInt# s1 d1 j of - (# s, d #) -> smartJ# s d +plusInteger (J# s1 d1) (S# j) = mpzToInteger (plusIntegerInt# s1 d1 j) plusInteger i1@(S# _) i2@(J# _ _) = plusInteger i2 i1 -plusInteger (J# s1 d1) (J# s2 d2) = case plusInteger# s1 d1 s2 d2 of - (# s, d #) -> smartJ# s d +plusInteger (J# s1 d1) (J# s2 d2) = mpzToInteger (plusInteger# s1 d1 s2 d2) {-# NOINLINE minusInteger #-} minusInteger :: Integer -> Integer -> Integer @@ -544,32 +550,25 @@ minusInteger (S# i) (S# j) = case subIntC# i j of (# r, c #) -> if isTrue# (c ==# 0#) then S# r else case int2Integer# i of - (# s, d #) -> case minusIntegerInt# s d j of - (# s', d' #) -> J# s' d' + (# s, d #) -> mpzToInteger (minusIntegerInt# s d j) minusInteger i1@(J# _ _) (S# 0#) = i1 -minusInteger (J# s1 d1) (S# j) = case minusIntegerInt# s1 d1 j of - (# s, d #) -> smartJ# s d +minusInteger (J# s1 d1) (S# j) = mpzToInteger (minusIntegerInt# s1 d1 j) minusInteger (S# 0#) (J# s2 d2) = J# (negateInt# s2) d2 -minusInteger (S# i) (J# s2 d2) = case plusIntegerInt# (negateInt# s2) d2 i of - (# s, d #) -> smartJ# s d -minusInteger (J# s1 d1) (J# s2 d2) = case minusInteger# s1 d1 s2 d2 of - (# s, d #) -> smartJ# s d 
+minusInteger (S# i) (J# s2 d2) = mpzToInteger (plusIntegerInt# (negateInt# s2) d2 i) +minusInteger (J# s1 d1) (J# s2 d2) = mpzToInteger (minusInteger# s1 d1 s2 d2) {-# NOINLINE timesInteger #-} timesInteger :: Integer -> Integer -> Integer timesInteger (S# i) (S# j) = if isTrue# (mulIntMayOflo# i j ==# 0#) then S# (i *# j) else case int2Integer# i of - (# s, d #) -> case timesIntegerInt# s d j of - (# s', d' #) -> smartJ# s' d' + (# s, d #) -> mpzToInteger (timesIntegerInt# s d j) timesInteger (S# 0#) _ = S# 0# timesInteger (S# -1#) i2 = negateInteger i2 timesInteger (S# 1#) i2 = i2 -timesInteger (S# i1) (J# s2 d2) = case timesIntegerInt# s2 d2 i1 of - (# s, d #) -> J# s d +timesInteger (S# i1) (J# s2 d2) = mpzToInteger (timesIntegerInt# s2 d2 i1) timesInteger i1@(J# _ _) i2@(S# _) = timesInteger i2 i1 -- swap args & retry -timesInteger (J# s1 d1) (J# s2 d2) = case timesInteger# s1 d1 s2 d2 of - (# s, d #) -> J# s d +timesInteger (J# s1 d1) (J# s2 d2) = mpzToInteger (timesInteger# s1 d1 s2 d2) {-# NOINLINE negateInteger #-} negateInteger :: Integer -> Integer @@ -599,8 +598,8 @@ encodeDoubleInteger (J# s# d#) e = encodeDouble# s# d# e {-# NOINLINE decodeDoubleInteger #-} decodeDoubleInteger :: Double# -> (# Integer, Int# #) decodeDoubleInteger d = case decodeDouble# d of - (# exp#, s#, d# #) -> let !s = smartJ# s# d# - in (# s, exp# #) + (# exp#, man# #) -> let !man = mpzToInteger man# + in (# man, exp# #) -- previous code: doubleFromInteger n = fromInteger n = encodeFloat n 0 -- doesn't work too well, because encodeFloat is defined in @@ -646,8 +645,7 @@ andInteger :: Integer -> Integer -> Integer x@(S# _) `andInteger` y@(J# _ _) = toBig x `andInteger` y x@(J# _ _) `andInteger` y@(S# _) = x `andInteger` toBig y (J# s1 d1) `andInteger` (J# s2 d2) = - case andInteger# s1 d1 s2 d2 of - (# s, d #) -> smartJ# s d + mpzToInteger (andInteger# s1 d1 s2 d2) {-# NOINLINE orInteger #-} orInteger :: Integer -> Integer -> Integer @@ -655,8 +653,7 @@ orInteger :: Integer -> 
Integer -> Integer x@(S# _) `orInteger` y@(J# _ _) = toBig x `orInteger` y x@(J# _ _) `orInteger` y@(S# _) = x `orInteger` toBig y (J# s1 d1) `orInteger` (J# s2 d2) = - case orInteger# s1 d1 s2 d2 of - (# s, d #) -> J# s d + mpzToInteger (orInteger# s1 d1 s2 d2) {-# NOINLINE xorInteger #-} xorInteger :: Integer -> Integer -> Integer @@ -664,27 +661,24 @@ xorInteger :: Integer -> Integer -> Integer x@(S# _) `xorInteger` y@(J# _ _) = toBig x `xorInteger` y x@(J# _ _) `xorInteger` y@(S# _) = x `xorInteger` toBig y (J# s1 d1) `xorInteger` (J# s2 d2) = - case xorInteger# s1 d1 s2 d2 of - (# s, d #) -> smartJ# s d + mpzToInteger (xorInteger# s1 d1 s2 d2) {-# NOINLINE complementInteger #-} complementInteger :: Integer -> Integer complementInteger (S# x) = S# (word2Int# (int2Word# x `xor#` int2Word# (0# -# 1#))) complementInteger (J# s d) - = case complementInteger# s d of (# s', d' #) -> smartJ# s' d' + = mpzToInteger (complementInteger# s d) {-# NOINLINE shiftLInteger #-} shiftLInteger :: Integer -> Int# -> Integer shiftLInteger j@(S# _) i = shiftLInteger (toBig j) i -shiftLInteger (J# s d) i = case mul2ExpInteger# s d i of - (# s', d' #) -> J# s' d' +shiftLInteger (J# s d) i = mpzToInteger (mul2ExpInteger# s d i) {-# NOINLINE shiftRInteger #-} shiftRInteger :: Integer -> Int# -> Integer shiftRInteger j@(S# _) i = shiftRInteger (toBig j) i -shiftRInteger (J# s d) i = case fdivQ2ExpInteger# s d i of - (# s', d' #) -> smartJ# s' d' +shiftRInteger (J# s d) i = mpzToInteger (fdivQ2ExpInteger# s d i) {-# NOINLINE testBitInteger #-} testBitInteger :: Integer -> Int# -> Bool @@ -695,8 +689,7 @@ testBitInteger (J# s d) i = isTrue# (testBitInteger# s d i /=# 0#) {-# NOINLINE powInteger #-} powInteger :: Integer -> Word# -> Integer powInteger j@(S# _) e = powInteger (toBig j) e -powInteger (J# s d) e = case powInteger# s d e of - (# s', d' #) -> smartJ# s' d' +powInteger (J# s d) e = mpzToInteger (powInteger# s d e) -- | \"@'powModInteger' /b/ /e/ /m/@\" computes base @/b/@ raised 
to -- exponent @/e/@ modulo @/m/@. @@ -709,8 +702,7 @@ powInteger (J# s d) e = case powInteger# s d e of {-# NOINLINE powModInteger #-} powModInteger :: Integer -> Integer -> Integer -> Integer powModInteger (J# s1 d1) (J# s2 d2) (J# s3 d3) = - case powModInteger# s1 d1 s2 d2 s3 d3 of - (# s', d' #) -> smartJ# s' d' + mpzToInteger (powModInteger# s1 d1 s2 d2 s3 d3) powModInteger b e m = powModInteger (toBig b) (toBig e) (toBig m) -- | \"@'powModSecInteger' /b/ /e/ /m/@\" computes base @/b/@ raised to @@ -724,8 +716,7 @@ powModInteger b e m = powModInteger (toBig b) (toBig e) (toBig m) {-# NOINLINE powModSecInteger #-} powModSecInteger :: Integer -> Integer -> Integer -> Integer powModSecInteger (J# s1 d1) (J# s2 d2) (J# s3 d3) = - case powModSecInteger# s1 d1 s2 d2 s3 d3 of - (# s', d' #) -> J# s' d' + mpzToInteger (powModSecInteger# s1 d1 s2 d2 s3 d3) powModSecInteger b e m = powModSecInteger (toBig b) (toBig e) (toBig m) -- | \"@'recipModInteger' /x/ /m/@\" computes the inverse of @/x/@ modulo @/m/@. If @@ -740,8 +731,7 @@ recipModInteger :: Integer -> Integer -> Integer recipModInteger j@(S# _) m@(S# _) = recipModInteger (toBig j) (toBig m) recipModInteger j@(S# _) m@(J# _ _) = recipModInteger (toBig j) m recipModInteger j@(J# _ _) m@(S# _) = recipModInteger j (toBig m) -recipModInteger (J# s d) (J# ms md) = case recipModInteger# s d ms md of - (# s', d' #) -> smartJ# s' d' +recipModInteger (J# s d) (J# ms md) = mpzToInteger (recipModInteger# s d ms md) -- | Probalistic Miller-Rabin primality test. -- @@ -771,7 +761,7 @@ testPrimeInteger (J# s d) reps = testPrimeInteger# s d reps {-# NOINLINE nextPrimeInteger #-} nextPrimeInteger :: Integer -> Integer nextPrimeInteger j@(S# _) = nextPrimeInteger (toBig j) -nextPrimeInteger (J# s d) = case nextPrimeInteger# s d of (# s', d' #) -> smartJ# s' d' +nextPrimeInteger (J# s d) = mpzToInteger (nextPrimeInteger# s d) -- | Compute number of digits (without sign) in given @/base/@. 
-- @@ -861,7 +851,7 @@ exportIntegerToAddr j@(S# _) addr o e = exportIntegerToAddr (toBig j) addr o e - -- * returns a new 'Integer' {-# NOINLINE importIntegerFromByteArray #-} importIntegerFromByteArray :: ByteArray# -> Word# -> Word# -> Int# -> Integer -importIntegerFromByteArray ba o l e = case importIntegerFromByteArray# ba o l e of (# s', d' #) -> smartJ# s' d' +importIntegerFromByteArray ba o l e = mpzToInteger (importIntegerFromByteArray# ba o l e) -- | Read 'Integer' (without sign) from memory location at @/addr/@ in -- base-256 representation. @@ -874,7 +864,7 @@ importIntegerFromByteArray ba o l e = case importIntegerFromByteArray# ba o l e {-# NOINLINE importIntegerFromAddr #-} importIntegerFromAddr :: Addr# -> Word# -> Int# -> State# s -> (# State# s, Integer #) importIntegerFromAddr addr l e st = case importIntegerFromAddr# addr l e st of - (# st', s', d' #) -> let !j = smartJ# s' d' in (# st', j #) + (# st', mpz #) -> let !j = mpzToInteger mpz in (# st', j #) \end{code} diff --git a/libraries/integer-gmp/cbits/gmp-wrappers.cmm b/libraries/integer-gmp/cbits/gmp-wrappers.cmm index 2c9bbd25fd..28c1333938 100644 --- a/libraries/integer-gmp/cbits/gmp-wrappers.cmm +++ b/libraries/integer-gmp/cbits/gmp-wrappers.cmm @@ -28,7 +28,6 @@ #include "Cmm.h" #include "GmpDerivedConstants.h" -import "integer-gmp" __gmpz_init; import "integer-gmp" __gmpz_add; import "integer-gmp" __gmpz_add_ui; import "integer-gmp" __gmpz_sub; @@ -68,6 +67,8 @@ import "integer-gmp" __gmpz_export; import "integer-gmp" integer_cbits_decodeDouble; +import "integer-gmp" stg_INTLIKE_closure; + /* ----------------------------------------------------------------------------- Arbitrary-precision Integer operations. @@ -75,6 +76,15 @@ import "integer-gmp" integer_cbits_decodeDouble; the case for all the platforms that GHC supports, currently. 
-------------------------------------------------------------------------- */ +/* This is used when a dummy pointer is needed for a ByteArray# return value + + Ideally this would be a statically allocated 'ByteArray#' + containing SIZEOF_W 0-bytes. However, since in those cases when a + dummy value is needed, the 'ByteArray#' is not supposed to be + accessed anyway, this is should be a tolerable hack. + */ +#define DUMMY_BYTE_ARR (stg_INTLIKE_closure+1) + /* set mpz_t from Int#/ByteArray# */ #define MP_INT_SET_FROM_BA(mp_ptr,i,ba) \ MP_INT__mp_alloc(mp_ptr) = W_TO_INT(BYTE_ARR_WDS(ba)); \ @@ -85,42 +95,103 @@ import "integer-gmp" integer_cbits_decodeDouble; #define MP_INT_AS_PAIR(mp_ptr) \ TO_W_(MP_INT__mp_size(mp_ptr)),(MP_INT__mp_d(mp_ptr)-SIZEOF_StgArrWords) +#define MP_INT_TO_BA(mp_ptr) \ + (MP_INT__mp_d(mp_ptr)-SIZEOF_StgArrWords) + +/* Size of mpz_t with single limb */ +#define SIZEOF_MP_INT_1LIMB (SIZEOF_MP_INT+WDS(1)) + +/* Initialize 0-valued single-limb mpz_t at mp_ptr */ +#define MP_INT_1LIMB_INIT0(mp_ptr) \ + MP_INT__mp_alloc(mp_ptr) = W_TO_INT(1); \ + MP_INT__mp_size(mp_ptr) = W_TO_INT(0); \ + MP_INT__mp_d(mp_ptr) = (mp_ptr+SIZEOF_MP_INT) + + +/* return mpz_t as (# s::Int#, d::ByteArray#, l1::Word# #) tuple + * + * semantics: + * + * (# 0, _, 0 #) -> value = 0 + * (# 1, _, w #) -> value = w + * (# -1, _, w #) -> value = -w + * (# s, d, 0 #) -> value = J# s d + * + */ +#define MP_INT_1LIMB_RETURN(mp_ptr) \ + CInt __mp_s; \ + __mp_s = MP_INT__mp_size(mp_ptr); \ + \ + if (__mp_s == W_TO_INT(0)) \ + { \ + return (0,DUMMY_BYTE_ARR,0); \ + } \ + \ + if (__mp_s == W_TO_INT(-1) || __mp_s == W_TO_INT(1)) \ + { \ + return (TO_W_(__mp_s),DUMMY_BYTE_ARR,W_[MP_INT__mp_d(mp_ptr)]); \ + } \ + \ + return (TO_W_(__mp_s),MP_INT_TO_BA(mp_ptr),0) + +/* Helper macro used by MP_INT_1LIMB_RETURN2 */ +#define MP_INT_1LIMB_AS_TUP3(s,d,w,mp_ptr) \ + CInt s; P_ d; W_ w; \ + s = MP_INT__mp_size(mp_ptr); \ + \ + if (s == W_TO_INT(0)) \ + { \ + d = DUMMY_BYTE_ARR; w = 0; \ + } else 
{ \ + if (s == W_TO_INT(-1) || s == W_TO_INT(1)) \ + { \ + d = DUMMY_BYTE_ARR; w = W_[MP_INT__mp_d(mp_ptr)]; \ + } else { \ + d = MP_INT_TO_BA(mp_ptr); w = 0; \ + } \ + } -/* :: ByteArray# -> Word# -> Word# -> Int# -> (# Int#, ByteArray# #) */ +#define MP_INT_1LIMB_RETURN2(mp_ptr1,mp_ptr2) \ + MP_INT_1LIMB_AS_TUP3(__r1s,__r1d,__r1w,mp_ptr1); \ + MP_INT_1LIMB_AS_TUP3(__r2s,__r2d,__r2w,mp_ptr2); \ + return (TO_W_(__r1s),__r1d,__r1w, TO_W_(__r2s),__r2d,__r2w) + +/* :: ByteArray# -> Word# -> Word# -> Int# -> (# Int#, ByteArray#, Word# #) */ integer_cmm_importIntegerFromByteArrayzh (P_ ba, W_ of, W_ sz, W_ e) { W_ src_ptr; W_ mp_result; again: - STK_CHK_GEN_N (SIZEOF_MP_INT); + STK_CHK_GEN_N (SIZEOF_MP_INT_1LIMB); MAYBE_GC(again); - mp_result = Sp - SIZEOF_MP_INT; + mp_result = Sp - SIZEOF_MP_INT_1LIMB; + MP_INT_1LIMB_INIT0(mp_result); src_ptr = BYTE_ARR_CTS(ba) + of; - ccall __gmpz_init(mp_result "ptr"); ccall __gmpz_import(mp_result "ptr", sz, W_TO_INT(e), W_TO_INT(1), W_TO_INT(0), 0, src_ptr "ptr"); - return(MP_INT_AS_PAIR(mp_result)); + MP_INT_1LIMB_RETURN(mp_result); } -/* :: Addr# -> Word# -> Int# -> State# s -> (# State# s, Int#, ByteArray# #) */ +/* :: Addr# -> Word# -> Int# -> State# s -> (# State# s, Int#, ByteArray#, Word# #) */ integer_cmm_importIntegerFromAddrzh (W_ src_ptr, W_ sz, W_ e) { W_ mp_result; again: - STK_CHK_GEN_N (SIZEOF_MP_INT); + STK_CHK_GEN_N (SIZEOF_MP_INT_1LIMB); MAYBE_GC(again); - mp_result = Sp - SIZEOF_MP_INT; + mp_result = Sp - SIZEOF_MP_INT_1LIMB; + + MP_INT_1LIMB_INIT0(mp_result); - ccall __gmpz_init(mp_result "ptr"); ccall __gmpz_import(mp_result "ptr", sz, W_TO_INT(e), W_TO_INT(1), W_TO_INT(0), 0, src_ptr "ptr"); - return(MP_INT_AS_PAIR(mp_result)); + MP_INT_1LIMB_RETURN(mp_result); } /* :: Int# -> ByteArray# -> MutableByteArray# s -> Word# -> Int# -> State# s -> (# State# s, Word# #) */ @@ -329,22 +400,22 @@ name (W_ ws1, P_ d1, W_ ws2, P_ d2) \ W_ mp_result1; \ \ again: \ - STK_CHK_GEN_N (3 * SIZEOF_MP_INT); \ + STK_CHK_GEN_N 
(2*SIZEOF_MP_INT + SIZEOF_MP_INT_1LIMB); \ MAYBE_GC(again); \ \ - mp_tmp1 = Sp - 1 * SIZEOF_MP_INT; \ - mp_tmp2 = Sp - 2 * SIZEOF_MP_INT; \ - mp_result1 = Sp - 3 * SIZEOF_MP_INT; \ + mp_tmp1 = Sp - 1*SIZEOF_MP_INT; \ + mp_tmp2 = Sp - 2*SIZEOF_MP_INT; \ + mp_result1 = Sp - 2*SIZEOF_MP_INT - SIZEOF_MP_INT_1LIMB; \ \ MP_INT_SET_FROM_BA(mp_tmp1,ws1,d1); \ MP_INT_SET_FROM_BA(mp_tmp2,ws2,d2); \ \ - ccall __gmpz_init(mp_result1 "ptr"); \ + MP_INT_1LIMB_INIT0(mp_result1); \ \ /* Perform the operation */ \ ccall mp_fun(mp_result1 "ptr",mp_tmp1 "ptr",mp_tmp2 "ptr"); \ \ - return (MP_INT_AS_PAIR(mp_result1)); \ + MP_INT_1LIMB_RETURN(mp_result1); \ } #define GMP_TAKE3_RET1(name,mp_fun) \ @@ -356,25 +427,25 @@ name (W_ ws1, P_ d1, W_ ws2, P_ d2, W_ ws3, P_ d3) \ W_ mp_result1; \ \ again: \ - STK_CHK_GEN_N (4 * SIZEOF_MP_INT); \ + STK_CHK_GEN_N (3*SIZEOF_MP_INT + SIZEOF_MP_INT_1LIMB); \ MAYBE_GC(again); \ \ - mp_tmp1 = Sp - 1 * SIZEOF_MP_INT; \ - mp_tmp2 = Sp - 2 * SIZEOF_MP_INT; \ - mp_tmp3 = Sp - 3 * SIZEOF_MP_INT; \ - mp_result1 = Sp - 4 * SIZEOF_MP_INT; \ + mp_tmp1 = Sp - 1*SIZEOF_MP_INT; \ + mp_tmp2 = Sp - 2*SIZEOF_MP_INT; \ + mp_tmp3 = Sp - 3*SIZEOF_MP_INT; \ + mp_result1 = Sp - 3*SIZEOF_MP_INT - SIZEOF_MP_INT_1LIMB; \ \ MP_INT_SET_FROM_BA(mp_tmp1,ws1,d1); \ MP_INT_SET_FROM_BA(mp_tmp2,ws2,d2); \ MP_INT_SET_FROM_BA(mp_tmp3,ws3,d3); \ \ - ccall __gmpz_init(mp_result1 "ptr"); \ + MP_INT_1LIMB_INIT0(mp_result1); \ \ /* Perform the operation */ \ - ccall mp_fun(mp_result1 "ptr",mp_tmp1 "ptr",mp_tmp2 "ptr", \ - mp_tmp3 "ptr"); \ + ccall mp_fun(mp_result1 "ptr", \ + mp_tmp1 "ptr", mp_tmp2 "ptr", mp_tmp3 "ptr"); \ \ - return (MP_INT_AS_PAIR(mp_result1)); \ + MP_INT_1LIMB_RETURN(mp_result1); \ } #define GMP_TAKE1_UL1_RET1(name,mp_fun) \ @@ -385,20 +456,20 @@ name (W_ ws1, P_ d1, W_ wul) \ \ /* call doYouWantToGC() */ \ again: \ - STK_CHK_GEN_N (2 * SIZEOF_MP_INT); \ + STK_CHK_GEN_N (SIZEOF_MP_INT + SIZEOF_MP_INT_1LIMB); \ MAYBE_GC(again); \ \ - mp_tmp = Sp - 1 * SIZEOF_MP_INT; \ - 
mp_result = Sp - 2 * SIZEOF_MP_INT; \ + mp_tmp = Sp - SIZEOF_MP_INT; \ + mp_result = Sp - SIZEOF_MP_INT - SIZEOF_MP_INT_1LIMB; \ \ MP_INT_SET_FROM_BA(mp_tmp,ws1,d1); \ \ - ccall __gmpz_init(mp_result "ptr"); \ + MP_INT_1LIMB_INIT0(mp_result); \ \ /* Perform the operation */ \ ccall mp_fun(mp_result "ptr", mp_tmp "ptr", W_TO_LONG(wul)); \ \ - return (MP_INT_AS_PAIR(mp_result)); \ + MP_INT_1LIMB_RETURN(mp_result); \ } #define GMP_TAKE1_I1_RETI1(name,mp_fun) \ @@ -446,20 +517,20 @@ name (W_ ws1, P_ d1) \ W_ mp_result1; \ \ again: \ - STK_CHK_GEN_N (2 * SIZEOF_MP_INT); \ + STK_CHK_GEN_N (SIZEOF_MP_INT + SIZEOF_MP_INT_1LIMB); \ MAYBE_GC(again); \ \ - mp_tmp1 = Sp - 1 * SIZEOF_MP_INT; \ - mp_result1 = Sp - 2 * SIZEOF_MP_INT; \ + mp_tmp1 = Sp - SIZEOF_MP_INT; \ + mp_result1 = Sp - SIZEOF_MP_INT - SIZEOF_MP_INT_1LIMB; \ \ MP_INT_SET_FROM_BA(mp_tmp1,ws1,d1); \ \ - ccall __gmpz_init(mp_result1 "ptr"); \ + MP_INT_1LIMB_INIT0(mp_result1); \ \ /* Perform the operation */ \ ccall mp_fun(mp_result1 "ptr",mp_tmp1 "ptr"); \ \ - return(MP_INT_AS_PAIR(mp_result1)); \ + MP_INT_1LIMB_RETURN(mp_result1); \ } #define GMP_TAKE2_RET2(name,mp_fun) \ @@ -471,24 +542,25 @@ name (W_ ws1, P_ d1, W_ ws2, P_ d2) \ W_ mp_result2; \ \ again: \ - STK_CHK_GEN_N (4 * SIZEOF_MP_INT); \ + STK_CHK_GEN_N (2*SIZEOF_MP_INT + 2*SIZEOF_MP_INT_1LIMB); \ MAYBE_GC(again); \ \ - mp_tmp1 = Sp - 1 * SIZEOF_MP_INT; \ - mp_tmp2 = Sp - 2 * SIZEOF_MP_INT; \ - mp_result1 = Sp - 3 * SIZEOF_MP_INT; \ - mp_result2 = Sp - 4 * SIZEOF_MP_INT; \ + mp_tmp1 = Sp - 1*SIZEOF_MP_INT; \ + mp_tmp2 = Sp - 2*SIZEOF_MP_INT; \ + mp_result1 = Sp - 2*SIZEOF_MP_INT - 1*SIZEOF_MP_INT_1LIMB; \ + mp_result2 = Sp - 2*SIZEOF_MP_INT - 2*SIZEOF_MP_INT_1LIMB; \ \ MP_INT_SET_FROM_BA(mp_tmp1,ws1,d1); \ MP_INT_SET_FROM_BA(mp_tmp2,ws2,d2); \ \ - ccall __gmpz_init(mp_result1 "ptr"); \ - ccall __gmpz_init(mp_result2 "ptr"); \ + MP_INT_1LIMB_INIT0(mp_result1); \ + MP_INT_1LIMB_INIT0(mp_result2); \ \ /* Perform the operation */ \ - ccall mp_fun(mp_result1 
"ptr",mp_result2 "ptr",mp_tmp1 "ptr",mp_tmp2 "ptr"); \ + ccall mp_fun(mp_result1 "ptr", mp_result2 "ptr", \ + mp_tmp1 "ptr", mp_tmp2 "ptr"); \ \ - return (MP_INT_AS_PAIR(mp_result1),MP_INT_AS_PAIR(mp_result2)); \ + MP_INT_1LIMB_RETURN2(mp_result1, mp_result2); \ } #define GMP_TAKE1_UL1_RET2(name,mp_fun) \ @@ -499,23 +571,23 @@ name (W_ ws1, P_ d1, W_ wul2) \ W_ mp_result2; \ \ again: \ - STK_CHK_GEN_N (3 * SIZEOF_MP_INT); \ + STK_CHK_GEN_N (SIZEOF_MP_INT + 2*SIZEOF_MP_INT_1LIMB); \ MAYBE_GC(again); \ \ - mp_tmp1 = Sp - 1 * SIZEOF_MP_INT; \ - mp_result1 = Sp - 2 * SIZEOF_MP_INT; \ - mp_result2 = Sp - 3 * SIZEOF_MP_INT; \ + mp_tmp1 = Sp - SIZEOF_MP_INT; \ + mp_result1 = Sp - SIZEOF_MP_INT - 1*SIZEOF_MP_INT_1LIMB; \ + mp_result2 = Sp - SIZEOF_MP_INT - 2*SIZEOF_MP_INT_1LIMB; \ \ MP_INT_SET_FROM_BA(mp_tmp1,ws1,d1); \ \ - ccall __gmpz_init(mp_result1 "ptr"); \ - ccall __gmpz_init(mp_result2 "ptr"); \ + MP_INT_1LIMB_INIT0(mp_result1); \ + MP_INT_1LIMB_INIT0(mp_result2); \ \ /* Perform the operation */ \ ccall mp_fun(mp_result1 "ptr", mp_result2 "ptr", \ mp_tmp1 "ptr", W_TO_LONG(wul2)); \ \ - return (MP_INT_AS_PAIR(mp_result1),MP_INT_AS_PAIR(mp_result2)); \ + MP_INT_1LIMB_RETURN2(mp_result1, mp_result2); \ } GMP_TAKE2_RET1(integer_cmm_plusIntegerzh, __gmpz_add) @@ -657,16 +729,17 @@ integer_cmm_cmpIntegerzh (W_ usize, P_ d1, W_ vsize, P_ d2) integer_cmm_decodeDoublezh (D_ arg) { - D_ arg; - W_ p; W_ mp_tmp1; W_ mp_tmp_w; - STK_CHK_GEN_N (2 * SIZEOF_MP_INT); +#if SIZEOF_DOUBLE != SIZEOF_W + W_ p; + + STK_CHK_GEN_N (SIZEOF_MP_INT + SIZEOF_W); ALLOC_PRIM (ARR_SIZE); - mp_tmp1 = Sp - 1 * SIZEOF_MP_INT; - mp_tmp_w = Sp - 2 * SIZEOF_MP_INT; + mp_tmp1 = Sp - SIZEOF_MP_INT; + mp_tmp_w = Sp - SIZEOF_MP_INT - SIZEOF_W; /* Be prepared to tell Lennart-coded integer_cbits_decodeDouble where mantissa.d can be put (it does not care about the rest) */ @@ -675,14 +748,29 @@ integer_cmm_decodeDoublezh (D_ arg) StgArrWords_bytes(p) = DOUBLE_MANTISSA_SIZE; MP_INT__mp_d(mp_tmp1) = 
BYTE_ARR_CTS(p); +#else + /* When SIZEOF_DOUBLE == SIZEOF_W == 8, the result will fit into a + single 8-byte limb, and so we avoid allocating on the Heap and + use only the Stack instead */ + + STK_CHK_GEN_N (SIZEOF_MP_INT_1LIMB + SIZEOF_W); + + mp_tmp1 = Sp - SIZEOF_MP_INT_1LIMB; + mp_tmp_w = Sp - SIZEOF_MP_INT_1LIMB - SIZEOF_W; + + MP_INT_1LIMB_INIT0(mp_tmp1); +#endif + /* Perform the operation */ - ccall integer_cbits_decodeDouble(mp_tmp1 "ptr", mp_tmp_w "ptr",arg); + ccall integer_cbits_decodeDouble(mp_tmp1 "ptr", mp_tmp_w "ptr", arg); + + /* returns: (Int# (expn), MPZ#) */ + MP_INT_1LIMB_AS_TUP3(r1s, r1d, r1w, mp_tmp1); - /* returns: (Int# (expn), Int#, ByteArray#) */ - return (W_[mp_tmp_w], TO_W_(MP_INT__mp_size(mp_tmp1)), p); + return (W_[mp_tmp_w], TO_W_(r1s), r1d, r1w); } -/* :: Int# -> ByteArray# -> Int# -> (# Int#, ByteArray# #) */ +/* :: Int# -> ByteArray# -> Int# -> (# Int#, ByteArray#, Word# #) */ #define GMPX_TAKE1_UL1_RET1(name,pos_arg_fun,neg_arg_fun) \ name(W_ ws1, P_ d1, W_ wl) \ { \ @@ -690,23 +778,23 @@ name(W_ ws1, P_ d1, W_ wl) \ W_ mp_result; \ \ again: \ - STK_CHK_GEN_N (2 * SIZEOF_MP_INT); \ + STK_CHK_GEN_N (SIZEOF_MP_INT + SIZEOF_MP_INT_1LIMB); \ MAYBE_GC(again); \ \ - mp_tmp = Sp - 1 * SIZEOF_MP_INT; \ - mp_result = Sp - 2 * SIZEOF_MP_INT; \ + mp_tmp = Sp - SIZEOF_MP_INT; \ + mp_result = Sp - SIZEOF_MP_INT - SIZEOF_MP_INT_1LIMB; \ \ MP_INT_SET_FROM_BA(mp_tmp,ws1,d1); \ \ - ccall __gmpz_init(mp_result "ptr"); \ + MP_INT_1LIMB_INIT0(mp_result); \ \ if(%lt(wl,0)) { \ ccall neg_arg_fun(mp_result "ptr", mp_tmp "ptr", W_TO_LONG(-wl)); \ - return(MP_INT_AS_PAIR(mp_result)); \ + } else { \ + ccall pos_arg_fun(mp_result "ptr", mp_tmp "ptr", W_TO_LONG(wl)); \ } \ \ - ccall pos_arg_fun(mp_result "ptr", mp_tmp "ptr", W_TO_LONG(wl)); \ - return(MP_INT_AS_PAIR(mp_result)); \ + MP_INT_1LIMB_RETURN(mp_result); \ } /* NB: We need both primitives as we can't express 'minusIntegerInt#' |