diff options
author | Barry Mead <barrymead@cox.net> | 2010-02-15 03:04:21 -0700 |
---|---|---|
committer | Barry Mead <barrymead@cox.net> | 2010-02-15 03:04:21 -0700 |
commit | 24953a5a5eb9c9b2591fd2e963565269404f40e1 (patch) | |
tree | f2036dbc1c36fe177f624bfcc8fd8cc93d4d58ef | |
parent | 474c221b892435e76d3d7bd65b7230c8acfe2482 (diff) | |
download | rsa-24953a5a5eb9c9b2591fd2e963565269404f40e1.tar.gz |
faster gcd egcd f_exp
-rw-r--r-- | rsa/__init__.py | 337 |
1 files changed, 214 insertions, 123 deletions
diff --git a/rsa/__init__.py b/rsa/__init__.py index cd6a760..e6c862d 100644 --- a/rsa/__init__.py +++ b/rsa/__init__.py @@ -1,50 +1,37 @@ """RSA module -pri = k[1] //Private part of keys d,p,q Module for calculating large primes, and RSA encryption, decryption, signing and verification. Includes generating public and private keys. - -WARNING: this code implements the mathematics of RSA. It is not suitable for -real-world secure cryptography purposes. It has not been reviewed by a security -expert. It does not include padding of data. There are many ways in which the -output of this module, when used without any modification, can be sucessfully -attacked. """ -__author__ = "Sybren Stuvel, Marloes de Boer and Ivo Tamboer" -__date__ = "2010-02-05" -__version__ = '1.3.3' - -# NOTE: Python's modulo can return negative numbers. We compensate for -# this behaviour using the abs() function +__author__ = "Sybren Stuvel, Marloes de Boer, Ivo Tamboer, and Barry Mead" +__date__ = "2010-02-08" -from cPickle import dumps, loads -import base64 import math import os import random import sys import types -import zlib def gcd(p, q): """Returns the greatest common divisor of p and q - - - >>> gcd(42, 6) - 6 + >>> gcd(48, 180) + 12 """ - if p<q: return gcd(q, p) - if q == 0: return p - return gcd(q, abs(p%q)) + # Iterateive Version is faster and uses much less stack space + while q != 0: + if p < q: (p,q) = (q,p) + (p,q) = (q, p % q) + return p + def bytes2int(bytes): """Converts a list of bytes or a string to an integer - >>> (128*256 + 64)*256 + + 15 + >>> (((128 * 256) + 64) * 256) + 15 8405007 >>> l = [128, 64, 15] - >>> bytes2int(l) + >>> bytes2int(l) #same as bytes2int('\x80@\x0f') 8405007 """ @@ -63,6 +50,8 @@ def bytes2int(bytes): def int2bytes(number): """Converts a number to a string of bytes + >>>int2bytes(123456789) + '\x07[\xcd\x15' >>> bytes2int(int2bytes(123456789)) 123456789 """ @@ -78,32 +67,126 @@ def int2bytes(number): return string -def fast_exponentiation(a, p, n): - """Calculates r = a^p mod n +def to64(number): + """Converts a number in the range of 0 to 63 into base 64 digit + character in the range of '0'-'9', 'A'-'Z', 'a'-'z','-','_'. + + >>> to64(10) + 'A' + """ + + if not (type(number) is types.LongType or type(number) is types.IntType): + raise TypeError("You must pass a long or an int") + + if 0 <= number <= 9: #00-09 translates to '0' - '9' + return chr(number + 48) + + if 10 <= number <= 35: + return chr(number + 55) #10-35 translates to 'A' - 'Z' + + if 36 <= number <= 61: + return chr(number + 61) #36-61 translates to 'a' - 'z' + + if number == 62: # 62 translates to '-' (minus) + return chr(45) + + if number == 63: # 63 translates to '_' (underscore) + return chr(95) + + +def from64(number): + """Converts an ordinal character value in the range of + 0-9,A-Z,a-z,-,_ to a number in the range of 0-63. + + >>> from64(49) + 1 + """ + + if not (type(number) is types.LongType or type(number) is types.IntType): + raise TypeError("You must pass a long or an int") + + if 48 <= number <= 57: #ord('0') - ord('9') translates to 0-9 + return(number - 48) + + if 65 <= number <= 90: #ord('A') - ord('Z') translates to 10-35 + return(number - 55) + + if 97 <= number <= 122: #ord('a') - ord('z') translates to 36-61 + return(number - 61) + + if number == 45: #ord('-') translates to 62 + return(62) + + if number == 95: #ord('_') translates to 63 + return(63) + + + +def int2str64(number): + """Converts a number to a string of base64 encoded characters in + the range of '0'-'9','A'-'Z,'a'-'z','-','_'. + + >>> int2str64(123456789) + '7MyqL' """ - result = a % n - remainders = [] - while p != 1: - remainders.append(p & 1) - p = p >> 1 - while remainders: - rem = remainders.pop() - result = ((a ** rem) * result ** 2) % n + + if not (type(number) is types.LongType or type(number) is types.IntType): + raise TypeError("You must pass a long or an int") + + string = "" + + while number > 0: + string = "%s%s" % (to64(number & 0x3F), string) + number /= 64 + + return string + + +def str642int(string): + """Converts a base64 encoded string into an integer. + The chars of this string in in the range '0'-'9','A'-'Z','a'-'z','-','_' + + >>> str642int('7MyqL') + 123456789 + """ + + if not (type(string) is types.ListType or type(string) is types.StringType): + raise TypeError("You must pass a string or a list") + + integer = 0 + for byte in string: + integer *= 64 + if type(byte) is types.StringType: byte = ord(byte) + integer += from64(byte) + + return integer + + +def fast_exponentiation(a, e, n): + """Calculates r = a^e mod n + """ + #Single loop version is faster and uses less memory + #MSB is always 1 so skip testing it and start with result = a + msbe = int(math.ceil(math.log(e,2))) - 2 #Find MSB-1 of exponent + test = long(1 << msbe) + a %= n #Throw away any overflow + result = a #Start with result = a (skip MSB test) + while test != 0: + if e & test != 0: #If exponent bit 1 square and mult by a + result = (result * result * a) % n + else: #If exponent bit 0 just square + result = (result * result) % n + test >>= 1 #Move to next exponent bit return result def read_random_int(nbits): """Reads a random integer of approximately nbits bits rounded up to whole bytes""" - nbytes = ceil(nbits/8.) + nbytes = int(math.ceil(nbits/8.)) randomdata = os.urandom(nbytes) return bytes2int(randomdata) -def ceil(x): - """ceil(x) -> int(math.ceil(x))""" - - return int(math.ceil(x)) - def randint(minvalue, maxvalue): """Returns a random integer x with minvalue <= x <= maxvalue""" @@ -115,7 +198,7 @@ def randint(minvalue, maxvalue): range = maxvalue - minvalue # Which is this number of bytes - rangebytes = ceil(math.log(range, 2) / 8.) + rangebytes = int(math.ceil(math.log(range, 2) / 8.)) # Convert to bits, but make sure it's always at least min_nbits*2 rangebits = max(rangebytes * 8, min_nbits * 2) @@ -125,29 +208,23 @@ def randint(minvalue, maxvalue): return (read_random_int(nbits) % range) + minvalue -def fermat_little_theorem(p): - """Returns 1 if p may be prime, and something else if p definitely - is not prime""" - - a = randint(1, p-1) - return fast_exponentiation(a, p-1, p) - def jacobi(a, b): """Calculates the value of the Jacobi symbol (a/b) + where both a and b are positive integers, and b is odd """ - if a % b == 0: - return 0 + if a == 0: return 0 result = 1 while a > 1: if a & 1: if ((a-1)*(b-1) >> 2) & 1: result = -result - b, a = a, b % a + a, b = b % a, a else: - if ((b ** 2 - 1) >> 3) & 1: + if (((b ** 2) - 1) >> 3) & 1: result = -result - a = a >> 1 + a >>= 1 + if a == 0: return 0 return result def jacobi_witness(x, n): @@ -169,11 +246,9 @@ def randomized_primality_testing(n, k): probably prime. """ - q = 0.5 # Property of the jacobi_witness function + # 50% of Jacobi-witnesses can report compositness of non-prime numbers - # t = int(math.ceil(k / math.log(1/q, 2))) - t = ceil(k / math.log(1/q, 2)) - for i in range(t+1): + for i in range(k): x = randint(1, n-1) if jacobi_witness(x, n): return False @@ -188,13 +263,7 @@ def is_prime(number): 1 """ - """ - if not fermat_little_theorem(number) == 1: - # Not prime, according to Fermat's little theorem - return False - """ - - if randomized_primality_testing(number, 5): + if randomized_primality_testing(number, 6): # Prime, according to Jacobi return True @@ -215,8 +284,6 @@ def getprime(nbits): 0 """ - nbytes = int(math.ceil(nbits/8.)) - while True: integer = read_random_int(nbits) @@ -245,26 +312,38 @@ def are_relatively_prime(a, b): def find_p_q(nbits): """Returns a tuple of two different primes of nbits bits""" - - p = getprime(nbits) - while True: - q = getprime(nbits) - if not q == p: break - + still_looking = True + small_primes=[3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,67,71,73,79] + pbits = nbits + (nbits/16) #Make sure that p and q aren't too close + qbits = nbits - (nbits/16) #or the factoring programs can factor n + while still_looking: + p = getprime(pbits) + while True: + q = getprime(qbits) + if not q == p: break + #Now verify that phi_n (p-1)*(q-1) is not divisible by small primes + phi_n = (p-1)*(q-1) + still_looking = False + for sp in small_primes: + if phi_n % sp == 0: #check each small prime for divisibility + still_looking = True + break #Any divisible small prime, keep looking return (p, q) def extended_euclid_gcd(a, b): """Returns a tuple (d, i, j) such that d = gcd(a, b) = ia + jb """ - - if b == 0: - return (a, 1, 0) - - q = abs(a % b) - r = long(a / b) - (d, k, l) = extended_euclid_gcd(b, q) - - return (d, l, k - l*r) + # Iterateive Version is faster and uses much less stack space + x = 0 + y = 1 + lx = 1 + ly = 0 + while b != 0: + q = long(a/b) + (a, b) = (b, a % b) + (x, lx) = ((lx - (q * x)),x) + (y, ly) = ((ly - (q * y)),y) + return (a, lx, ly) # Main function: calculate encryption and decryption keys def calculate_keys(p, q, nbits): @@ -284,7 +363,7 @@ def calculate_keys(p, q, nbits): if not d == 1: raise Exception("e (%d) and phi_n (%d) are not relatively prime" % (e, phi_n)) - + if (i < 0): i += phi_n if not (e * i) % phi_n == 1: raise Exception("e (%d) and i (%d) are not mult. inv. modulo phi_n (%d)" % (e, i, phi_n)) @@ -297,17 +376,12 @@ def gen_keys(nbits): Note: this can take a long time, depending on the key size. """ - while True: - (p, q) = find_p_q(nbits) - (e, d) = calculate_keys(p, q, nbits) - - # For some reason, d is sometimes negative. We don't know how - # to fix it (yet), so we keep trying until everything is shiny - if d > 0: break + (p, q) = find_p_q(nbits) + (e, d) = calculate_keys(p, q, nbits) return (p, q, e, d) -def gen_pubpriv_keys(nbits): +def newkeys(nbits): """Generates public and private keys, and returns them as (pub, priv). @@ -320,60 +394,77 @@ def gen_pubpriv_keys(nbits): return ( {'e': e, 'n': p*q}, {'d': d, 'p': p, 'q': q} ) def encrypt_int(message, ekey, n): - """Encrypts a message using encryption key 'ekey', working modulo - n""" + """Encrypts a message using encryption key 'ekey', working modulo n""" if type(message) is types.IntType: return encrypt_int(long(message), ekey, n) if not type(message) is types.LongType: - raise TypeError("You must pass a long or an int") + raise TypeError("You must pass a long or int") - if message > 0 and \ - math.floor(math.log(message, 2)) > math.floor(math.log(n, 2)): + if message < 0 or message > n: raise OverflowError("The message is too long") + #Note: Bit exponents start at zero (bit counts start at 1) this is correct + safebit = int(math.floor(math.log(n,2))) - 1 #compute safe bit (MSB - 1) + message += (1 << safebit) #add safebit to ensure folding + return fast_exponentiation(message, ekey, n) def decrypt_int(cyphertext, dkey, n): """Decrypts a cypher text using the decryption key 'dkey', working modulo n""" - return encrypt_int(cyphertext, dkey, n) + message = fast_exponentiation(cyphertext, dkey, n) + + safebit = int(math.floor(math.log(n,2))) - 1 #compute safe bit (MSB - 1) + message -= (1 << safebit) #remove safebit before decode + + return message + +def encode64chops(chops): + """base64encodes chops and combines them into a ',' delimited string""" + + chips = [] #chips are character chops + + for value in chops: + chips.append(int2str64(value)) -def sign_int(message, dkey, n): - """Signs 'message' using key 'dkey', working modulo n""" + encoded = "" - return decrypt_int(message, dkey, n) + for string in chips: + encoded = encoded + string + ',' #delimit chops with comma -def verify_int(signed, ekey, n): - """verifies 'signed' using key 'ekey', working modulo n""" + return encoded - return encrypt_int(signed, ekey, n) +def decode64chops(string): + """base64decodes and makes a ',' delimited string into chops""" -def picklechops(chops): - """Pickles and base64encodes it's argument chops""" + chips = string.split(',') #split chops at commas - value = zlib.compress(dumps(chops)) - encoded = base64.encodestring(value) - return encoded.strip() + chops = [] -def unpicklechops(string): - """base64decodes and unpickes it's argument string into chops""" + for string in chips: #make char chops (chips) into chops + chops.append(str642int(string)) - return loads(zlib.decompress(base64.decodestring(string))) + return chops def chopstring(message, key, n, funcref): - """Splits 'message' into chops that are at most as long as n, - converts these into integers, and calls funcref(integer, key, n) - for each chop. + """Chops the 'message' into integers that fit into n, + leaving room for a safebit to be added to ensure that all + messages fold during exponentiation. The MSB of the number n + is not independant modulo n (setting it could cause overflow), so + use the next lower bit for the safebit. Therefore reserve 2-bits + in the number n for non-data bits. Calls specified encryption + function for each chop. Used by 'encrypt' and 'sign'. """ msglen = len(message) mbits = msglen * 8 - nbits = int(math.floor(math.log(n, 2))) + # floor of log deducts 1 bit of n and the - 1, deducts the second bit. + nbits = int(math.floor(math.log(n, 2))) - 1 # leave room for safebit nbytes = nbits / 8 blocks = msglen / nbytes @@ -388,9 +479,9 @@ def chopstring(message, key, n, funcref): value = bytes2int(block) cypher.append(funcref(value, key, n)) - return picklechops(cypher) + return encode64chops(cypher) #Encode encrypted ints to base64 strings -def gluechops(chops, key, n, funcref): +def gluechops(string, key, n, funcref): """Glues chops back together into a string. calls funcref(integer, key, n) for each chop. @@ -398,11 +489,11 @@ def gluechops(chops, key, n, funcref): """ message = "" - chops = unpicklechops(chops) + chops = decode64chops(string) #Decode base64 strings into integer chops for cpart in chops: - mpart = funcref(cpart, key, n) - message += int2bytes(mpart) + mpart = funcref(cpart, key, n) #Decrypt each chop + message += int2bytes(mpart) #Combine decrypted strings into a msg return message @@ -414,7 +505,7 @@ def encrypt(message, key): def sign(message, key): """Signs a string 'message' with the private key 'key'""" - return chopstring(message, key['d'], key['p']*key['q'], decrypt_int) + return chopstring(message, key['d'], key['p']*key['q'], encrypt_int) def decrypt(cypher, key): """Decrypts a cypher with the private key 'key'""" @@ -424,12 +515,12 @@ def decrypt(cypher, key): def verify(cypher, key): """Verifies a cypher with the public key 'key'""" - return gluechops(cypher, key['e'], key['n'], encrypt_int) + return gluechops(cypher, key['e'], key['n'], decrypt_int) # Do doctest if we're not imported if __name__ == "__main__": import doctest doctest.testmod() -__all__ = ["gen_pubpriv_keys", "encrypt", "decrypt", "sign", "verify"] +__all__ = ["newkeys", "encrypt", "decrypt", "sign", "verify"] |