summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDwayne C. Litzenberger <dlitz@dlitz.net>2012-07-03 11:52:59 -0400
committerDwayne C. Litzenberger <dlitz@dlitz.net>2012-07-03 11:52:59 -0400
commit868ec6aa98e797e6c8bb3581bb9218e6ddeab000 (patch)
treeebaa1aba53a5f287b53d2082780bc10de6d5ffba
parentd31f7df39a6d3db73a16909de4669d337b69c40c (diff)
parentb215cc9e18012d9b291ea8c0adc6eefa507dd3b1 (diff)
downloadpycrypto-868ec6aa98e797e6c8bb3581bb9218e6ddeab000.tar.gz
Merge branch 'error-propagation-fixes'
-rw-r--r--lib/Crypto/SelfTest/Util/test_number.py29
-rw-r--r--src/_fastmath.c13
2 files changed, 37 insertions, 5 deletions
diff --git a/lib/Crypto/SelfTest/Util/test_number.py b/lib/Crypto/SelfTest/Util/test_number.py
index bdbc9b1..2201a93 100644
--- a/lib/Crypto/SelfTest/Util/test_number.py
+++ b/lib/Crypto/SelfTest/Util/test_number.py
@@ -32,6 +32,9 @@ if sys.version_info[0] == 2 and sys.version_info[1] == 1:
import unittest
+class MyError(Exception):
+ """Dummy exception used for tests"""
+
# NB: In some places, we compare tuples instead of just output values so that
# if any inputs cause a test failure, we'll be able to tell which ones.
@@ -289,6 +292,32 @@ class FastmathTests(unittest.TestCase):
self.assertEqual(n, k.n)
self.assertEqual(e, k.e)
+ def test_isPrime_randfunc_exception(self):
+ """Test that when isPrime is called, an exception raised in randfunc is propagated."""
+ def randfunc(n):
+ raise MyError
+ prime = 3536384141L # Needs to be large enough so that rabinMillerTest will be invoked
+ self.assertRaises(MyError, number._fastmath.isPrime, prime, randfunc=randfunc)
+
+ def test_getStrongPrime_randfunc_exception(self):
+ """Test that when getStrongPrime is called, an exception raised in randfunc is propagated."""
+ def randfunc(n):
+ raise MyError
+ self.assertRaises(MyError, number._fastmath.getStrongPrime, 512, randfunc=randfunc)
+
+ def test_isPrime_randfunc_bogus(self):
+ """Test that when isPrime is called, an exception is raised if randfunc returns something bogus."""
+ def randfunc(n):
+ return None
+ prime = 3536384141L # Needs to be large enough so that rabinMillerTest will be invoked
+ self.assertRaises(TypeError, number._fastmath.isPrime, prime, randfunc=randfunc)
+
+ def test_getStrongPrime_randfunc_bogus(self):
+ """Test that when getStrongPrime is called, an exception is raised if randfunc returns something bogus."""
+ def randfunc(n):
+ return None
+ self.assertRaises(TypeError, number._fastmath.getStrongPrime, 512, randfunc=randfunc)
+
def get_tests(config={}):
from Crypto.SelfTest.st_common import list_test_cases
tests = list_test_cases(MiscTests)
diff --git a/src/_fastmath.c b/src/_fastmath.c
index b8b24b6..f05e70f 100644
--- a/src/_fastmath.c
+++ b/src/_fastmath.c
@@ -1097,8 +1097,9 @@ cleanup:
mpz_clear (n);
Py_END_ALLOW_THREADS;
- if (result == 0)
- {
+ if (result < 0) {
+ return NULL;
+ } else if (result == 0) {
Py_INCREF(Py_False);
return Py_False;
} else {
@@ -1323,6 +1324,7 @@ sieve_field (char *field, unsigned long int field_size, mpz_t start)
/* Tests if n is prime.
* Returns 0 when n is definitly composite.
* Returns 1 when n is probably prime.
+ * Returns -1 when there is an error.
* every round reduces the chance of a false positive be at least 1/4.
*
* If randfunc is omitted, then the python version Random.new().read is used.
@@ -1335,7 +1337,8 @@ static int
rabinMillerTest (mpz_t n, int rounds, PyObject *randfunc)
{
int base_was_tested;
- unsigned long int i, j, b, composite, return_val=1;
+ unsigned long int i, j, b, composite;
+ int return_val = 1;
mpz_t a, m, z, n_1, tmp;
mpz_t tested[MAX_RABIN_MILLER_ROUNDS];
@@ -1449,13 +1452,13 @@ cleanup:
static PyObject *
getStrongPrime (PyObject *self, PyObject *args, PyObject *kwargs)
{
- unsigned long int i, j, result, bits, x, e=0;
+ unsigned long int i, j, bits, x, e=0;
mpz_t p[2], y[2], R, X;
mpz_t tmp[2], lower_bound, upper_bound, range, increment;
mpf_t tmp_bound;
char *field;
double false_positive_prob;
- int rabin_miller_rounds, is_possible_prime, error = 0;
+ int rabin_miller_rounds, is_possible_prime, error = 0, result;
PyObject *prime, *randfunc=NULL;
static char *kwlist[] = {"N", "e", "false_positive_prob", "randfunc", NULL};
unsigned long int base_size = SIEVE_BASE_SIZE;