summaryrefslogtreecommitdiff
path: root/rsa/keygen.py
blob: 1f1c9a7ad579924b9cac2e421be4c14713ef1cdf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
'''RSA key generation code.

Create new keys with the newkeys() function.

The private key consists of a dict {d: ..., p: ..., q: ..., n: p*q}.

The public key consists of a dict {e: ..., n: p*q}.


'''

import rsa.prime

def extended_gcd(a, b):
    """Returns a tuple (r, i, j) where r = gcd(a, b).

    The cofactors are computed with the iterative extended Euclidean
    algorithm, so inside the loop they satisfy r = i*a + j*b. Before
    returning, a negative i has the original b added to it and a negative j
    has the original a added to it. After that wrap the identity is only
    guaranteed as a congruence:

        i*a is congruent to r modulo the original b
        j*b is congruent to r modulo the original a

    In particular, when r == 1, i is a multiplicative inverse of a modulo b
    and j is a multiplicative inverse of b modulo a.

    >>> extended_gcd(3, 7)
    (1, 5, 1)

    @param a: first integer.
    @param b: second integer.
    @return: the tuple (r, i, j) described above.
    """
    # The iterative version is faster than a recursive one and uses much
    # less stack space.
    x, last_x = 0, 1
    y, last_y = 1, 0

    # Remember the original arguments; they are the moduli used to wrap
    # negative cofactors after the loop has overwritten a and b.
    orig_a, orig_b = a, b

    while b != 0:
        # Floor division of integers already yields an integer on every
        # Python version, so no explicit long() conversion is needed (and
        # the long builtin does not exist on Python 3).
        q = a // b
        (a, b) = (b, a % b)
        (x, last_x) = (last_x - q * x, x)
        (y, last_y) = (last_y - q * y, y)

    if last_x < 0:
        last_x += orig_b               # wrap modulo the original b
    if last_y < 0:
        last_y += orig_a               # wrap modulo the original a

    return (a, last_x, last_y)



def find_p_q(nbits):
    '''Returns a tuple (p, q) of two different primes sized around nbits bits.

    p is requested from rsa.prime.getprime() with nbits + nbits//16 bits and
    q with nbits - nbits//16 bits, so the two sizes add up to 2*nbits.

    >>> (p, q) = find_p_q(128)

    The resulting p and q should be very close to 2*nbits bits, and no more
    than 2*nbits bits:
    >>> from rsa import common
    >>> common.bit_size(p * q) <= 256
    True
    >>> common.bit_size(p * q) > 240
    True

    @param nbits: the bit size both primes are centred on.
    @return: the tuple (p, q).
    '''

    # Skew the two sizes apart: primes that are too close together make
    # n = p*q easy for factoring programs to break.
    offset = nbits // 16
    q_size = nbits - offset

    p = rsa.prime.getprime(nbits + offset)

    # Keep drawing q until it differs from p.
    q = rsa.prime.getprime(q_size)
    while q == p:
        q = rsa.prime.getprime(q_size)

    return (p, q)

def calculate_keys(p, q, nbits):
    """Derives the exponent pair (e, d) for the primes p and q.

    The encryption exponent e is fixed at 65537. The decryption exponent d
    is the second element returned by extended_gcd(e, (p-1)*(q-1)), which is
    then checked to be an inverse of e modulo (p-1)*(q-1).

    @param p: first prime.
    @param q: second prime.
    @param nbits: not used by this function.
    @return: the tuple (e, d).
    @raise Exception: when e and (p-1)*(q-1) are not relatively prime, when
        extended_gcd() hands back a negative inverse, or when e*d is not 1
        modulo (p-1)*(q-1).
    """

    # A very common choice for e is 65537
    e = 65537
    phi_n = (p - 1) * (q - 1)

    (divisor, inverse, _) = extended_gcd(e, phi_n)

    if divisor != 1:
        raise Exception("e (%d) and phi_n (%d) are not relatively prime" % (e, phi_n))

    if inverse < 0:
        raise Exception("New extended_gcd shouldn't return negative values")

    if (e * inverse) % phi_n != 1:
        raise Exception("e (%d) and i (%d) are not mult. inv. modulo phi_n (%d)" % (e, inverse, phi_n))

    return (e, inverse)


def gen_keys(nbits):
    """Generate the components of an RSA key pair. Returns (p, q, e, d).

    Note: this can take a long time, depending on the key size.

    Half of ``nbits`` is handed to find_p_q() to pick the primes, and the
    same value is passed on to calculate_keys() together with those primes.
    Any exception raised by calculate_keys() propagates to the caller.

    @param nbits: the combined number of bits of ``p`` and ``q``.
    @return: the tuple (p, q, e, d).
    """

    half_bits = nbits // 2

    (p, q) = find_p_q(half_bits)
    (e, d) = calculate_keys(p, q, half_bits)

    return (p, q, e, d)

def newkeys(nbits):
    """Generates a public and a private key and returns them as (pub, priv).

    The public key is the dict {'e': e, 'n': n}. The private key is the dict
    {'d': d, 'p': p, 'q': q, 'n': n}. In both, n = p*q.

    @param nbits: the number of bits required to store ``n = p*q``. Values
        below 9 are raised to 9.
    @return: the tuple (pub, priv) of key dicts.
    """

    # Don't let nbits go below 9 bits
    (p, q, e, d) = gen_keys(max(9, nbits))

    n = p * q

    pub = {'e': e, 'n': n}
    priv = {'d': d, 'p': p, 'q': q, 'n': n}

    return (pub, priv)

    
if __name__ == '__main__':
    # Self-test driver: key generation is randomised, so the module's
    # doctests are repeated up to 1000 times, stopping at the first failure.
    #
    # print is called with a single parenthesised argument so that the same
    # source parses as the print statement on Python 2 and as the print
    # function on Python 3, producing the same output on both.
    print('Running doctests 1000x or until failure')
    import doctest

    for count in range(1000):
        # testmod() returns (failure count, number of tests attempted);
        # only the failure count is needed here.
        (failures, _) = doctest.testmod()
        if failures:
            break

        # Progress report every 100 runs (skipping the very first one).
        if count and count % 100 == 0:
            print('%i times' % count)

    print('Doctests done')