diff options
Diffstat (limited to 'src/_fastmath.c')
-rw-r--r-- | src/_fastmath.c | 35 |
1 files changed, 25 insertions, 10 deletions
diff --git a/src/_fastmath.c b/src/_fastmath.c index 4b5dede..b8b24b6 100644 --- a/src/_fastmath.c +++ b/src/_fastmath.c @@ -66,19 +66,26 @@ static void longObjToMPZ (mpz_t m, PyLongObject * p) { int size, i; + long negative; mpz_t temp, temp2; mpz_init (temp); mpz_init (temp2); #ifdef IS_PY3K - if (p->ob_base.ob_size > 0) + if (p->ob_base.ob_size > 0) { size = p->ob_base.ob_size; - else + negative = 1; + } else { size = -p->ob_base.ob_size; + negative = -1; + } #else - if (p->ob_size > 0) + if (p->ob_size > 0) { size = p->ob_size; - else + negative = 1; + } else { size = -p->ob_size; + negative = -1; + } #endif mpz_set_ui (m, 0); for (i = 0; i < size; i++) @@ -91,6 +98,7 @@ longObjToMPZ (mpz_t m, PyLongObject * p) #endif mpz_add (m, m, temp2); } + mpz_mul_si(m, m, negative); mpz_clear (temp); mpz_clear (temp2); } @@ -104,12 +112,15 @@ mpzToLongObj (mpz_t m) #else int size = (mpz_sizeinbase (m, 2) + SHIFT - 1) / SHIFT; #endif + int sgn; int i; mpz_t temp; PyLongObject *l = _PyLong_New (size); if (!l) return NULL; - mpz_init_set (temp, m); + sgn = mpz_sgn(m); + mpz_init(temp); + mpz_mul_si(temp, m, sgn); for (i = 0; i < size; i++) { #ifdef IS_PY3K @@ -124,9 +135,9 @@ mpzToLongObj (mpz_t m) while ((i > 0) && (l->ob_digit[i - 1] == 0)) i--; #ifdef IS_PY3K - l->ob_base.ob_size = i; + l->ob_base.ob_size = i * sgn; #else - l->ob_size = i; + l->ob_size = i * sgn; #endif mpz_clear (temp); return (PyObject *) l; @@ -1062,7 +1073,7 @@ isPrime (PyObject * self, PyObject * args, PyObject * kwargs) longObjToMPZ (n, (PyLongObject *) l); Py_BEGIN_ALLOW_THREADS; - /* first check if n is known to be prime and do some trail division */ + /* first check if n is known to be prime and do some trial division */ for (i = 0; i < SIEVE_BASE_SIZE; ++i) { if (mpz_cmp_ui (n, sieve_base[i]) == 0) @@ -1342,8 +1353,12 @@ rabinMillerTest (mpz_t n, int rounds, PyObject *randfunc) } Py_BEGIN_ALLOW_THREADS; - if ((mpz_tstbit (n, 0) == 0) || (mpz_cmp_ui (n, 3) < 0)) - return (mpz_cmp_ui (n, 2) == 0); + /* check special cases (n==2, n even, n < 2) */ + if ((mpz_tstbit (n, 0) == 0) || (mpz_cmp_ui (n, 3) < 0)) { + return_val = (mpz_cmp_ui (n, 2) == 0); + Py_BLOCK_THREADS; + return return_val; + } mpz_init (tmp); mpz_init (n_1); |