summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBarry Mead <barrymead@cox.net>2010-02-16 10:12:47 -0700
committerBarry Mead <barrymead@cox.net>2010-02-16 10:12:47 -0700
commit8cb869bef935642beaa87f34a5bb26c41c0b6fac (patch)
tree0eac82f2c6507cf409d2225ab97c330a31d0de49
parent4b2d4e0bc90e46f734220d8f55f5efab3661483d (diff)
downloadrsa-git-8cb869bef935642beaa87f34a5bb26c41c0b6fac.tar.gz
New faster decryption with Chinese Remainder Theorem 2-part exponentiation and combine method
-rw-r--r--rsa/fastrsa.py590
1 files changed, 590 insertions, 0 deletions
diff --git a/rsa/fastrsa.py b/rsa/fastrsa.py
new file mode 100644
index 0000000..1e37273
--- /dev/null
+++ b/rsa/fastrsa.py
@@ -0,0 +1,590 @@
+"""RSA module
+
+Module for calculating large primes, and RSA encryption, decryption,
+signing and verification. Includes generating public and private keys.
+"""
+
+__author__ = "Sybren Stuvel, Marloes de Boer, Ivo Tamboer, and Barry Mead"
+__date__ = "2010-02-08"
+
+import math
+import os
+import random
+import sys
+import types
+
+def gcd(p, q):
+ """Returns the greatest common divisor of p and q
+ >>> gcd(48, 180)
+ 12
+ """
+ # Iterateive Version is faster and uses much less stack space
+ while q != 0:
+ if p < q: (p,q) = (q,p)
+ (p,q) = (q, p % q)
+ return p
+
+
+def bytes2int(bytes):
+ """Converts a list of bytes or a string to an integer
+
+ >>> (((128 * 256) + 64) * 256) + 15
+ 8405007
+ >>> l = [128, 64, 15]
+ >>> bytes2int(l) #same as bytes2int('\x80@\x0f')
+ 8405007
+ """
+
+ if not (type(bytes) is types.ListType or type(bytes) is types.StringType):
+ raise TypeError("You must pass a string or a list")
+
+ # Convert byte stream to integer
+ integer = 0
+ for byte in bytes:
+ integer *= 256
+ if type(byte) is types.StringType: byte = ord(byte)
+ integer += byte
+
+ return integer
+
+def int2bytes(number):
+ """Converts a number to a string of bytes
+
+ >>>int2bytes(123456789)
+ '\x07[\xcd\x15'
+ >>> bytes2int(int2bytes(123456789))
+ 123456789
+ """
+
+ if not (type(number) is types.LongType or type(number) is types.IntType):
+ raise TypeError("You must pass a long or an int")
+
+ string = ""
+
+ while number > 0:
+ string = "%s%s" % (chr(number & 0xFF), string)
+ number /= 256
+
+ return string
+
+def to64(number):
+ """Converts a number in the range of 0 to 63 into base 64 digit
+ character in the range of '0'-'9', 'A'-'Z', 'a'-'z','-','_'.
+
+ >>> to64(10)
+ 'A'
+ """
+
+ if not (type(number) is types.LongType or type(number) is types.IntType):
+ raise TypeError("You must pass a long or an int")
+
+ if 0 <= number <= 9: #00-09 translates to '0' - '9'
+ return chr(number + 48)
+
+ if 10 <= number <= 35:
+ return chr(number + 55) #10-35 translates to 'A' - 'Z'
+
+ if 36 <= number <= 61:
+ return chr(number + 61) #36-61 translates to 'a' - 'z'
+
+ if number == 62: # 62 translates to '-' (minus)
+ return chr(45)
+
+ if number == 63: # 63 translates to '_' (underscore)
+ return chr(95)
+
+
+def from64(number):
+ """Converts an ordinal character value in the range of
+ 0-9,A-Z,a-z,-,_ to a number in the range of 0-63.
+
+ >>> from64(49)
+ 1
+ """
+
+ if not (type(number) is types.LongType or type(number) is types.IntType):
+ raise TypeError("You must pass a long or an int")
+
+ if 48 <= number <= 57: #ord('0') - ord('9') translates to 0-9
+ return(number - 48)
+
+ if 65 <= number <= 90: #ord('A') - ord('Z') translates to 10-35
+ return(number - 55)
+
+ if 97 <= number <= 122: #ord('a') - ord('z') translates to 36-61
+ return(number - 61)
+
+ if number == 45: #ord('-') translates to 62
+ return(62)
+
+ if number == 95: #ord('_') translates to 63
+ return(63)
+
+
+
+def int2str64(number):
+ """Converts a number to a string of base64 encoded characters in
+ the range of '0'-'9','A'-'Z,'a'-'z','-','_'.
+
+ >>> int2str64(123456789)
+ '7MyqL'
+ """
+
+ if not (type(number) is types.LongType or type(number) is types.IntType):
+ raise TypeError("You must pass a long or an int")
+
+ string = ""
+
+ while number > 0:
+ string = "%s%s" % (to64(number & 0x3F), string)
+ number /= 64
+
+ return string
+
+
+def str642int(string):
+ """Converts a base64 encoded string into an integer.
+ The chars of this string in in the range '0'-'9','A'-'Z','a'-'z','-','_'
+
+ >>> str642int('7MyqL')
+ 123456789
+ """
+
+ if not (type(string) is types.ListType or type(string) is types.StringType):
+ raise TypeError("You must pass a string or a list")
+
+ integer = 0
+ for byte in string:
+ integer *= 64
+ if type(byte) is types.StringType: byte = ord(byte)
+ integer += from64(byte)
+
+ return integer
+
+
+def fast_exponentiation(a, e, n):
+ """Calculates r = a^e mod n
+ """
+ #Single loop version is faster and uses less memory
+ #MSB is always 1 so skip testing it and start with result = a
+ msbe = int(math.ceil(math.log(e,2))) - 2 #Find MSB-1 of exponent
+ test = long(1 << msbe)
+ a %= n #Throw away any overflow
+ result = a #Start with result = a (skip MSB test)
+ while test != 0:
+ if e & test != 0: #If exponent bit 1 square and mult by a
+ result = (result * result * a) % n
+ else: #If exponent bit 0 just square
+ result = (result * result) % n
+ test >>= 1 #Move to next exponent bit
+ return result
+
+def read_random_int(nbits):
+ """Reads a random integer of approximately nbits bits rounded up
+ to whole bytes"""
+
+ nbytes = int(math.ceil(nbits/8.))
+ randomdata = os.urandom(nbytes)
+ return bytes2int(randomdata)
+
+def randint(minvalue, maxvalue):
+ """Returns a random integer x with minvalue <= x <= maxvalue"""
+
+ # Safety - get a lot of random data even if the range is fairly
+ # small
+ min_nbits = 32
+
+ # The range of the random numbers we need to generate
+ range = maxvalue - minvalue
+
+ # Which is this number of bytes
+ rangebytes = int(math.ceil(math.log(range, 2) / 8.))
+
+ # Convert to bits, but make sure it's always at least min_nbits*2
+ rangebits = max(rangebytes * 8, min_nbits * 2)
+
+ # Take a random number of bits between min_nbits and rangebits
+ nbits = random.randint(min_nbits, rangebits)
+
+ return (read_random_int(nbits) % range) + minvalue
+
+def jacobi(a, b):
+ """Calculates the value of the Jacobi symbol (a/b)
+ where both a and b are positive integers, and b is odd
+ """
+
+ if a == 0: return 0
+ result = 1
+ while a > 1:
+ if a & 1:
+ if ((a-1)*(b-1) >> 2) & 1:
+ result = -result
+ a, b = b % a, a
+ else:
+ if (((b ** 2) - 1) >> 3) & 1:
+ result = -result
+ a >>= 1
+ if a == 0: return 0
+ return result
+
+def jacobi_witness(x, n):
+ """Returns False if n is an Euler pseudo-prime with base x, and
+ True otherwise.
+ """
+
+ j = jacobi(x, n) % n
+ f = fast_exponentiation(x, (n-1)/2, n)
+
+ if j == f: return False
+ return True
+
+def randomized_primality_testing(n, k):
+ """Calculates whether n is composite (which is always correct) or
+ prime (which is incorrect with error probability 2**-k)
+
+ Returns False if the number if composite, and True if it's
+ probably prime.
+ """
+
+ # 50% of Jacobi-witnesses can report compositness of non-prime numbers
+
+ for i in range(k):
+ x = randint(1, n-1)
+ if jacobi_witness(x, n): return False
+
+ return True
+
+def is_prime(number):
+ """Returns True if the number is prime, and False otherwise.
+
+ >>> is_prime(42)
+ 0
+ >>> is_prime(41)
+ 1
+ """
+
+ if randomized_primality_testing(number, 6):
+ # Prime, according to Jacobi
+ return True
+
+ # Not prime
+ return False
+
+
+def getprime(nbits):
+ """Returns a prime number of max. 'math.ceil(nbits/8)*8' bits. In
+ other words: nbits is rounded up to whole bytes.
+
+ >>> p = getprime(8)
+ >>> is_prime(p-1)
+ 0
+ >>> is_prime(p)
+ 1
+ >>> is_prime(p+1)
+ 0
+ """
+
+ while True:
+ integer = read_random_int(nbits)
+
+ # Make sure it's odd
+ integer |= 1
+
+ # Test for primeness
+ if is_prime(integer): break
+
+ # Retry if not prime
+
+ return integer
+
+def are_relatively_prime(a, b):
+ """Returns True if a and b are relatively prime, and False if they
+ are not.
+
+ >>> are_relatively_prime(2, 3)
+ 1
+ >>> are_relatively_prime(2, 4)
+ 0
+ """
+
+ d = gcd(a, b)
+ return (d == 1)
+
+def find_p_q(nbits):
+ """Returns a tuple of two different primes of nbits bits"""
+ pbits = nbits + (nbits/16) #Make sure that p and q aren't too close
+ qbits = nbits - (nbits/16) #or the factoring programs can factor n
+ p = getprime(pbits)
+ while True:
+ q = getprime(qbits)
+ if not q == p: break
+ return (p, q)
+
+def extended_euclid_gcd(a, b):
+ """Returns a tuple (d, i, j) such that d = gcd(a, b) = ia + jb
+ """
+ # Iterateive Version is faster and uses much less stack space
+ x = 0
+ y = 1
+ lx = 1
+ ly = 0
+ lb = b #Remember modulus (to remove negs)
+ while b != 0:
+ q = long(a/b)
+ (a, b) = (b, a % b)
+ (x, lx) = ((lx - (q * x)),x)
+ (y, ly) = ((ly - (q * y)),y)
+ if (lx < 0): lx += lb #No Negative return values
+ return (a, lx, ly)
+
+# Main function: calculate encryption and decryption keys
+def calculate_keys(p, q, nbits):
+ """Calculates an encryption and a decryption key for p and q, and
+ returns them as a tuple (e, d)"""
+
+ n = p * q
+ phi_n = (p-1) * (q-1)
+
+ while True:
+ # Make sure e has enough bits so we ensure "wrapping" through
+ # modulo n
+ e = max(65537,getprime(nbits/4)) #minimum e is 65537 per RSA spec
+ if are_relatively_prime(e, n) and are_relatively_prime(e, phi_n): break
+
+ (r, dp, j) = extended_euclid_gcd(e, p-1) #Compute exponent dp
+
+ if not r == 1:
+ raise Exception("e (%d) and p-1 (%d) are not relatively prime" % (e, p-1))
+
+ (r, dq, j) = extended_euclid_gcd(e, q-1) #Compute exponent dq
+
+ if not r == 1:
+ raise Exception("e (%d) and q-1 (%d) are not relatively prime" % (e, q-1))
+
+ (r, qi, j) = extended_euclid_gcd(q, p) #Compute coefficent qi
+
+ if not r == 1:
+ raise Exception("q (%d) and p (%d) are not relatively prime" % (q, p))
+
+ return (e, dp, dq, qi)
+
+
+def gen_keys(nbits):
+ """Generate RSA keys of nbits bits. Returns (p, q, e, d).
+
+ Note: this can take a long time, depending on the key size.
+ """
+
+ (p, q) = find_p_q(nbits)
+ (e, dp, dq, qi) = calculate_keys(p, q, nbits)
+
+ return (p, q, e, dp, dq, qi)
+
+def newkeys(nbits):
+ """Generates public and private keys, and returns them as (pub,
+ priv).
+
+ The public key consists of a dict {e: ..., , n: ....). The private
+ key consists of a dict {d: ...., p: ...., q: ....).
+ """
+ nbits = max(9,nbits) #Minimum key size is 9 bit for p and q
+ (p, q, e, dp, dq, qi) = gen_keys(nbits)
+
+ return ( {'e':e,'n':p*q}, {'p':p,'q':q,'dp':dp,'dq':dq,'qi':qi} )
+
+def encrypt_int(message, key):
+ """Encrypts a message using public key 'key', working modulo n"""
+
+ if type(message) is types.IntType:
+ message = long(message)
+
+ if not type(message) is types.LongType:
+ raise TypeError("You must pass a long or int")
+
+ if message < 0 or message > key['n']:
+ raise OverflowError("The message is too long")
+
+ #Note: Bit exponents start at zero (bit counts start at 1) this is correct
+ safebit = int(math.floor(math.log(key['n'],2))) - 1 #compute safe bit (MSB - 1)
+ message += (1 << safebit) #add safebit to ensure folding
+
+ return fast_exponentiation(message, key['e'], key['n'])
+
+def verify_int(cyphertext, key):
+ """Decrypts cyphertext using public key 'key', working modulo n"""
+
+ if type(cyphertext) is types.IntType:
+ cyphertext = long(cyphertext)
+
+ if not type(cyphertext) is types.LongType:
+ raise TypeError("You must pass a long or int")
+
+ message = fast_exponentiation(cyphertext, key['e'], key['n'])
+
+ #Note: Bit exponents start at zero (bit counts start at 1) this is correct
+ safebit = int(math.floor(math.log(key['n'],2))) - 1 #safe bit (MSB - 1)
+ message -= (1 << safebit) #remove safe bit before decode
+
+ return message
+
+def decrypt_int(cyphertext, key):
+ """Decrypts a cypher text using the private key 'key', working
+ modulo n"""
+
+ n = key['p'] * key['q']
+ #Decrypt in 2 parts, using faster Chinese Remainder Theorem method
+ m1 = fast_exponentiation(cyphertext, key['dp'], key['p'])
+ m2 = fast_exponentiation(cyphertext, key['dq'], key['q'])
+ dif = m1 - m2
+ if dif < 0: dif += key['p']
+ h = (key['qi'] * dif) % key['p']
+ message = m2 + (h * key['q'])
+
+ safebit = int(math.floor(math.log(n,2))) - 1 #compute safe bit (MSB - 1)
+ message -= (1 << safebit) #remove safebit before decode
+
+ return message
+
+def sign_int(message, key):
+ """Encrypts a message with the private key 'key', working
+ modulo n"""
+
+ if type(message) is types.IntType:
+ message = long(message)
+
+ if not type(message) is types.LongType:
+ raise TypeError("You must pass a long or int")
+
+ n = key['p'] * key['q'] #computer n from p and q
+
+ if message < 0 or message > n:
+ raise OverflowError("The message is too long")
+
+ safebit = int(math.floor(math.log(n,2))) - 1 #compute safe bit (MSB - 1)
+ message += (1 << safebit) #add safebit before encrypt
+
+ #Encrypt in 2 parts, using faster Chinese Remainder Theorem method
+ c1 = fast_exponentiation(message, key['dp'], key['p'])
+ c2 = fast_exponentiation(message, key['dq'], key['q'])
+ dif = c1 - c2
+ if dif < 0: dif += key['p']
+ h = (key['qi'] * dif) % key['p']
+ cyphertext = c2 + (h * key['q'])
+
+ return cyphertext
+
+def encode64chops(chops):
+ """base64encodes chops and combines them into a ',' delimited string"""
+
+ chips = [] #chips are character chops
+
+ for value in chops:
+ chips.append(int2str64(value))
+
+ encoded = ""
+
+ for string in chips:
+ encoded = encoded + string + ',' #delimit chops with comma
+
+ return encoded
+
+def decode64chops(string):
+ """base64decodes and makes a ',' delimited string into chops"""
+
+ chips = string.split(',') #split chops at commas
+
+ chops = []
+
+ for string in chips: #make char chops (chips) into chops
+ chops.append(str642int(string))
+
+ return chops
+
+def chopstring(message, key, funcref):
+ """Chops the 'message' into integers that fit into n,
+ leaving room for a safebit to be added to ensure that all
+ messages fold during exponentiation. The MSB of the number n
+ is not independant modulo n (setting it could cause overflow), so
+ use the next lower bit for the safebit. Therefore reserve 2-bits
+ in the number n for non-data bits. Calls specified encryption
+ function for each chop.
+
+ Used by 'encrypt' and 'sign'.
+ """
+
+ if key.__contains__('n'):
+ n = key['n'] #Public key has n already
+ else:
+ n = key['p'] * key['q'] #Private key has p & q
+
+ msglen = len(message)
+ mbits = msglen * 8
+ # floor of log deducts 1 bit of n and the - 1, deducts the second bit.
+ nbits = int(math.floor(math.log(n, 2))) - 1 # leave room for safebit
+ nbytes = nbits / 8
+ blocks = msglen / nbytes
+
+ if msglen % nbytes > 0:
+ blocks += 1
+
+ cypher = []
+
+ for bindex in range(blocks):
+ offset = bindex * nbytes
+ block = message[offset:offset+nbytes]
+ value = bytes2int(block)
+ cypher.append(funcref(value, key))
+
+ return encode64chops(cypher) #Encode encrypted ints to base64 strings
+
+def gluechops(string, key, funcref):
+ """Glues chops back together into a string. calls
+ funcref(integer, key) for each chop.
+
+ Used by 'decrypt' and 'verify'.
+ """
+ message = ""
+
+ chops = decode64chops(string) #Decode base64 strings into integer chops
+
+ for cpart in chops:
+ mpart = funcref(cpart, key) #Decrypt each chop
+ message += int2bytes(mpart) #Combine decrypted strings into a msg
+
+ return message
+
+def encrypt(message, key):
+ """Encrypts a string 'message' with the public key 'key'"""
+ if key.__contains__('n'):
+ return chopstring(message, key, encrypt_int)
+ else:
+ raise Exception("You must use the public key with encrypt")
+
+def sign(message, key):
+ """Signs a string 'message' with the private key 'key'"""
+ if key.__contains__('p'):
+ return chopstring(message, key, sign_int)
+ else:
+ raise Exception("You must use the private key with sign")
+
+def decrypt(cypher, key):
+ """Decrypts a cypher with the private key 'key'"""
+ if key.__contains__('p'):
+ return gluechops(cypher, key, decrypt_int)
+ else:
+ raise Exception("You must use the private key with decrypt")
+
+def verify(cypher, key):
+ """Verifies a cypher with the public key 'key'"""
+ if key.__contains__('n'):
+ return gluechops(cypher, key, verify_int)
+ else:
+ raise Exception("You must use the public key with verify")
+
+# Do doctest if we're not imported
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod()
+
+__all__ = ["newkeys", "encrypt", "decrypt", "sign", "verify"]
+