summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBarry Mead <barrymead@cox.net>2010-02-15 03:04:21 -0700
committerBarry Mead <barrymead@cox.net>2010-02-15 03:04:21 -0700
commit24953a5a5eb9c9b2591fd2e963565269404f40e1 (patch)
treef2036dbc1c36fe177f624bfcc8fd8cc93d4d58ef
parent474c221b892435e76d3d7bd65b7230c8acfe2482 (diff)
downloadrsa-24953a5a5eb9c9b2591fd2e963565269404f40e1.tar.gz
faster gcd egcd f_exp
-rw-r--r--rsa/__init__.py337
1 files changed, 214 insertions, 123 deletions
diff --git a/rsa/__init__.py b/rsa/__init__.py
index cd6a760..e6c862d 100644
--- a/rsa/__init__.py
+++ b/rsa/__init__.py
@@ -1,50 +1,37 @@
"""RSA module
-pri = k[1] //Private part of keys d,p,q
Module for calculating large primes, and RSA encryption, decryption,
signing and verification. Includes generating public and private keys.
-
-WARNING: this code implements the mathematics of RSA. It is not suitable for
-real-world secure cryptography purposes. It has not been reviewed by a security
-expert. It does not include padding of data. There are many ways in which the
-output of this module, when used without any modification, can be sucessfully
-attacked.
"""
-__author__ = "Sybren Stuvel, Marloes de Boer and Ivo Tamboer"
-__date__ = "2010-02-05"
-__version__ = '1.3.3'
-
-# NOTE: Python's modulo can return negative numbers. We compensate for
-# this behaviour using the abs() function
+__author__ = "Sybren Stuvel, Marloes de Boer, Ivo Tamboer, and Barry Mead"
+__date__ = "2010-02-08"
-from cPickle import dumps, loads
-import base64
import math
import os
import random
import sys
import types
-import zlib
def gcd(p, q):
"""Returns the greatest common divisor of p and q
-
-
- >>> gcd(42, 6)
- 6
+ >>> gcd(48, 180)
+ 12
"""
- if p<q: return gcd(q, p)
- if q == 0: return p
- return gcd(q, abs(p%q))
+ # Iterateive Version is faster and uses much less stack space
+ while q != 0:
+ if p < q: (p,q) = (q,p)
+ (p,q) = (q, p % q)
+ return p
+
def bytes2int(bytes):
"""Converts a list of bytes or a string to an integer
- >>> (128*256 + 64)*256 + + 15
+ >>> (((128 * 256) + 64) * 256) + 15
8405007
>>> l = [128, 64, 15]
- >>> bytes2int(l)
+ >>> bytes2int(l) #same as bytes2int('\x80@\x0f')
8405007
"""
@@ -63,6 +50,8 @@ def bytes2int(bytes):
def int2bytes(number):
"""Converts a number to a string of bytes
+ >>>int2bytes(123456789)
+ '\x07[\xcd\x15'
>>> bytes2int(int2bytes(123456789))
123456789
"""
@@ -78,32 +67,126 @@ def int2bytes(number):
return string
-def fast_exponentiation(a, p, n):
- """Calculates r = a^p mod n
+def to64(number):
+ """Converts a number in the range of 0 to 63 into base 64 digit
+ character in the range of '0'-'9', 'A'-'Z', 'a'-'z','-','_'.
+
+ >>> to64(10)
+ 'A'
+ """
+
+ if not (type(number) is types.LongType or type(number) is types.IntType):
+ raise TypeError("You must pass a long or an int")
+
+ if 0 <= number <= 9: #00-09 translates to '0' - '9'
+ return chr(number + 48)
+
+ if 10 <= number <= 35:
+ return chr(number + 55) #10-35 translates to 'A' - 'Z'
+
+ if 36 <= number <= 61:
+ return chr(number + 61) #36-61 translates to 'a' - 'z'
+
+ if number == 62: # 62 translates to '-' (minus)
+ return chr(45)
+
+ if number == 63: # 63 translates to '_' (underscore)
+ return chr(95)
+
+
+def from64(number):
+ """Converts an ordinal character value in the range of
+ 0-9,A-Z,a-z,-,_ to a number in the range of 0-63.
+
+ >>> from64(49)
+ 1
+ """
+
+ if not (type(number) is types.LongType or type(number) is types.IntType):
+ raise TypeError("You must pass a long or an int")
+
+ if 48 <= number <= 57: #ord('0') - ord('9') translates to 0-9
+ return(number - 48)
+
+ if 65 <= number <= 90: #ord('A') - ord('Z') translates to 10-35
+ return(number - 55)
+
+ if 97 <= number <= 122: #ord('a') - ord('z') translates to 36-61
+ return(number - 61)
+
+ if number == 45: #ord('-') translates to 62
+ return(62)
+
+ if number == 95: #ord('_') translates to 63
+ return(63)
+
+
+
+def int2str64(number):
+ """Converts a number to a string of base64 encoded characters in
+ the range of '0'-'9','A'-'Z,'a'-'z','-','_'.
+
+ >>> int2str64(123456789)
+ '7MyqL'
"""
- result = a % n
- remainders = []
- while p != 1:
- remainders.append(p & 1)
- p = p >> 1
- while remainders:
- rem = remainders.pop()
- result = ((a ** rem) * result ** 2) % n
+
+ if not (type(number) is types.LongType or type(number) is types.IntType):
+ raise TypeError("You must pass a long or an int")
+
+ string = ""
+
+ while number > 0:
+ string = "%s%s" % (to64(number & 0x3F), string)
+ number /= 64
+
+ return string
+
+
+def str642int(string):
+ """Converts a base64 encoded string into an integer.
+ The chars of this string in in the range '0'-'9','A'-'Z','a'-'z','-','_'
+
+ >>> str642int('7MyqL')
+ 123456789
+ """
+
+ if not (type(string) is types.ListType or type(string) is types.StringType):
+ raise TypeError("You must pass a string or a list")
+
+ integer = 0
+ for byte in string:
+ integer *= 64
+ if type(byte) is types.StringType: byte = ord(byte)
+ integer += from64(byte)
+
+ return integer
+
+
+def fast_exponentiation(a, e, n):
+ """Calculates r = a^e mod n
+ """
+ #Single loop version is faster and uses less memory
+ #MSB is always 1 so skip testing it and start with result = a
+ msbe = int(math.ceil(math.log(e,2))) - 2 #Find MSB-1 of exponent
+ test = long(1 << msbe)
+ a %= n #Throw away any overflow
+ result = a #Start with result = a (skip MSB test)
+ while test != 0:
+ if e & test != 0: #If exponent bit 1 square and mult by a
+ result = (result * result * a) % n
+ else: #If exponent bit 0 just square
+ result = (result * result) % n
+ test >>= 1 #Move to next exponent bit
return result
def read_random_int(nbits):
"""Reads a random integer of approximately nbits bits rounded up
to whole bytes"""
- nbytes = ceil(nbits/8.)
+ nbytes = int(math.ceil(nbits/8.))
randomdata = os.urandom(nbytes)
return bytes2int(randomdata)
-def ceil(x):
- """ceil(x) -> int(math.ceil(x))"""
-
- return int(math.ceil(x))
-
def randint(minvalue, maxvalue):
"""Returns a random integer x with minvalue <= x <= maxvalue"""
@@ -115,7 +198,7 @@ def randint(minvalue, maxvalue):
range = maxvalue - minvalue
# Which is this number of bytes
- rangebytes = ceil(math.log(range, 2) / 8.)
+ rangebytes = int(math.ceil(math.log(range, 2) / 8.))
# Convert to bits, but make sure it's always at least min_nbits*2
rangebits = max(rangebytes * 8, min_nbits * 2)
@@ -125,29 +208,23 @@ def randint(minvalue, maxvalue):
return (read_random_int(nbits) % range) + minvalue
-def fermat_little_theorem(p):
- """Returns 1 if p may be prime, and something else if p definitely
- is not prime"""
-
- a = randint(1, p-1)
- return fast_exponentiation(a, p-1, p)
-
def jacobi(a, b):
"""Calculates the value of the Jacobi symbol (a/b)
+ where both a and b are positive integers, and b is odd
"""
- if a % b == 0:
- return 0
+ if a == 0: return 0
result = 1
while a > 1:
if a & 1:
if ((a-1)*(b-1) >> 2) & 1:
result = -result
- b, a = a, b % a
+ a, b = b % a, a
else:
- if ((b ** 2 - 1) >> 3) & 1:
+ if (((b ** 2) - 1) >> 3) & 1:
result = -result
- a = a >> 1
+ a >>= 1
+ if a == 0: return 0
return result
def jacobi_witness(x, n):
@@ -169,11 +246,9 @@ def randomized_primality_testing(n, k):
probably prime.
"""
- q = 0.5 # Property of the jacobi_witness function
+ # 50% of Jacobi-witnesses can report compositness of non-prime numbers
- # t = int(math.ceil(k / math.log(1/q, 2)))
- t = ceil(k / math.log(1/q, 2))
- for i in range(t+1):
+ for i in range(k):
x = randint(1, n-1)
if jacobi_witness(x, n): return False
@@ -188,13 +263,7 @@ def is_prime(number):
1
"""
- """
- if not fermat_little_theorem(number) == 1:
- # Not prime, according to Fermat's little theorem
- return False
- """
-
- if randomized_primality_testing(number, 5):
+ if randomized_primality_testing(number, 6):
# Prime, according to Jacobi
return True
@@ -215,8 +284,6 @@ def getprime(nbits):
0
"""
- nbytes = int(math.ceil(nbits/8.))
-
while True:
integer = read_random_int(nbits)
@@ -245,26 +312,38 @@ def are_relatively_prime(a, b):
def find_p_q(nbits):
"""Returns a tuple of two different primes of nbits bits"""
-
- p = getprime(nbits)
- while True:
- q = getprime(nbits)
- if not q == p: break
-
+ still_looking = True
+ small_primes=[3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,67,71,73,79]
+ pbits = nbits + (nbits/16) #Make sure that p and q aren't too close
+ qbits = nbits - (nbits/16) #or the factoring programs can factor n
+ while still_looking:
+ p = getprime(pbits)
+ while True:
+ q = getprime(qbits)
+ if not q == p: break
+ #Now verify that phi_n (p-1)*(q-1) is not divisible by small primes
+ phi_n = (p-1)*(q-1)
+ still_looking = False
+ for sp in small_primes:
+ if phi_n % sp == 0: #check each small prime for divisibility
+ still_looking = True
+ break #Any divisible small prime, keep looking
return (p, q)
def extended_euclid_gcd(a, b):
"""Returns a tuple (d, i, j) such that d = gcd(a, b) = ia + jb
"""
-
- if b == 0:
- return (a, 1, 0)
-
- q = abs(a % b)
- r = long(a / b)
- (d, k, l) = extended_euclid_gcd(b, q)
-
- return (d, l, k - l*r)
+ # Iterateive Version is faster and uses much less stack space
+ x = 0
+ y = 1
+ lx = 1
+ ly = 0
+ while b != 0:
+ q = long(a/b)
+ (a, b) = (b, a % b)
+ (x, lx) = ((lx - (q * x)),x)
+ (y, ly) = ((ly - (q * y)),y)
+ return (a, lx, ly)
# Main function: calculate encryption and decryption keys
def calculate_keys(p, q, nbits):
@@ -284,7 +363,7 @@ def calculate_keys(p, q, nbits):
if not d == 1:
raise Exception("e (%d) and phi_n (%d) are not relatively prime" % (e, phi_n))
-
+ if (i < 0): i += phi_n
if not (e * i) % phi_n == 1:
raise Exception("e (%d) and i (%d) are not mult. inv. modulo phi_n (%d)" % (e, i, phi_n))
@@ -297,17 +376,12 @@ def gen_keys(nbits):
Note: this can take a long time, depending on the key size.
"""
- while True:
- (p, q) = find_p_q(nbits)
- (e, d) = calculate_keys(p, q, nbits)
-
- # For some reason, d is sometimes negative. We don't know how
- # to fix it (yet), so we keep trying until everything is shiny
- if d > 0: break
+ (p, q) = find_p_q(nbits)
+ (e, d) = calculate_keys(p, q, nbits)
return (p, q, e, d)
-def gen_pubpriv_keys(nbits):
+def newkeys(nbits):
"""Generates public and private keys, and returns them as (pub,
priv).
@@ -320,60 +394,77 @@ def gen_pubpriv_keys(nbits):
return ( {'e': e, 'n': p*q}, {'d': d, 'p': p, 'q': q} )
def encrypt_int(message, ekey, n):
- """Encrypts a message using encryption key 'ekey', working modulo
- n"""
+ """Encrypts a message using encryption key 'ekey', working modulo n"""
if type(message) is types.IntType:
return encrypt_int(long(message), ekey, n)
if not type(message) is types.LongType:
- raise TypeError("You must pass a long or an int")
+ raise TypeError("You must pass a long or int")
- if message > 0 and \
- math.floor(math.log(message, 2)) > math.floor(math.log(n, 2)):
+ if message < 0 or message > n:
raise OverflowError("The message is too long")
+ #Note: Bit exponents start at zero (bit counts start at 1) this is correct
+ safebit = int(math.floor(math.log(n,2))) - 1 #compute safe bit (MSB - 1)
+ message += (1 << safebit) #add safebit to ensure folding
+
return fast_exponentiation(message, ekey, n)
def decrypt_int(cyphertext, dkey, n):
"""Decrypts a cypher text using the decryption key 'dkey', working
modulo n"""
- return encrypt_int(cyphertext, dkey, n)
+ message = fast_exponentiation(cyphertext, dkey, n)
+
+ safebit = int(math.floor(math.log(n,2))) - 1 #compute safe bit (MSB - 1)
+ message -= (1 << safebit) #remove safebit before decode
+
+ return message
+
+def encode64chops(chops):
+ """base64encodes chops and combines them into a ',' delimited string"""
+
+ chips = [] #chips are character chops
+
+ for value in chops:
+ chips.append(int2str64(value))
-def sign_int(message, dkey, n):
- """Signs 'message' using key 'dkey', working modulo n"""
+ encoded = ""
- return decrypt_int(message, dkey, n)
+ for string in chips:
+ encoded = encoded + string + ',' #delimit chops with comma
-def verify_int(signed, ekey, n):
- """verifies 'signed' using key 'ekey', working modulo n"""
+ return encoded
- return encrypt_int(signed, ekey, n)
+def decode64chops(string):
+ """base64decodes and makes a ',' delimited string into chops"""
-def picklechops(chops):
- """Pickles and base64encodes it's argument chops"""
+ chips = string.split(',') #split chops at commas
- value = zlib.compress(dumps(chops))
- encoded = base64.encodestring(value)
- return encoded.strip()
+ chops = []
-def unpicklechops(string):
- """base64decodes and unpickes it's argument string into chops"""
+ for string in chips: #make char chops (chips) into chops
+ chops.append(str642int(string))
- return loads(zlib.decompress(base64.decodestring(string)))
+ return chops
def chopstring(message, key, n, funcref):
- """Splits 'message' into chops that are at most as long as n,
- converts these into integers, and calls funcref(integer, key, n)
- for each chop.
+ """Chops the 'message' into integers that fit into n,
+ leaving room for a safebit to be added to ensure that all
+ messages fold during exponentiation. The MSB of the number n
+ is not independant modulo n (setting it could cause overflow), so
+ use the next lower bit for the safebit. Therefore reserve 2-bits
+ in the number n for non-data bits. Calls specified encryption
+ function for each chop.
Used by 'encrypt' and 'sign'.
"""
msglen = len(message)
mbits = msglen * 8
- nbits = int(math.floor(math.log(n, 2)))
+ # floor of log deducts 1 bit of n and the - 1, deducts the second bit.
+ nbits = int(math.floor(math.log(n, 2))) - 1 # leave room for safebit
nbytes = nbits / 8
blocks = msglen / nbytes
@@ -388,9 +479,9 @@ def chopstring(message, key, n, funcref):
value = bytes2int(block)
cypher.append(funcref(value, key, n))
- return picklechops(cypher)
+ return encode64chops(cypher) #Encode encrypted ints to base64 strings
-def gluechops(chops, key, n, funcref):
+def gluechops(string, key, n, funcref):
"""Glues chops back together into a string. calls
funcref(integer, key, n) for each chop.
@@ -398,11 +489,11 @@ def gluechops(chops, key, n, funcref):
"""
message = ""
- chops = unpicklechops(chops)
+ chops = decode64chops(string) #Decode base64 strings into integer chops
for cpart in chops:
- mpart = funcref(cpart, key, n)
- message += int2bytes(mpart)
+ mpart = funcref(cpart, key, n) #Decrypt each chop
+ message += int2bytes(mpart) #Combine decrypted strings into a msg
return message
@@ -414,7 +505,7 @@ def encrypt(message, key):
def sign(message, key):
"""Signs a string 'message' with the private key 'key'"""
- return chopstring(message, key['d'], key['p']*key['q'], decrypt_int)
+ return chopstring(message, key['d'], key['p']*key['q'], encrypt_int)
def decrypt(cypher, key):
"""Decrypts a cypher with the private key 'key'"""
@@ -424,12 +515,12 @@ def decrypt(cypher, key):
def verify(cypher, key):
"""Verifies a cypher with the public key 'key'"""
- return gluechops(cypher, key['e'], key['n'], encrypt_int)
+ return gluechops(cypher, key['e'], key['n'], decrypt_int)
# Do doctest if we're not imported
if __name__ == "__main__":
import doctest
doctest.testmod()
-__all__ = ["gen_pubpriv_keys", "encrypt", "decrypt", "sign", "verify"]
+__all__ = ["newkeys", "encrypt", "decrypt", "sign", "verify"]