diff options
Diffstat (limited to 'rsa/key.py')
-rw-r--r-- | rsa/key.py | 215 |
1 files changed, 117 insertions, 98 deletions
@@ -50,7 +50,7 @@ DEFAULT_EXPONENT = 65537 class AbstractKey: """Abstract superclass for private and public keys.""" - __slots__ = ('n', 'e', 'blindfac', 'blindfac_inverse', 'mutex') + __slots__ = ("n", "e", "blindfac", "blindfac_inverse", "mutex") def __init__(self, n: int, e: int) -> None: self.n = n @@ -64,7 +64,7 @@ class AbstractKey: self.mutex = threading.Lock() @classmethod - def _load_pkcs1_pem(cls, keyfile: bytes) -> 'AbstractKey': + def _load_pkcs1_pem(cls, keyfile: bytes) -> "AbstractKey": """Loads a key in PKCS#1 PEM format, implement in a subclass. :param keyfile: contents of a PEM-encoded file that contains @@ -76,7 +76,7 @@ class AbstractKey: """ @classmethod - def _load_pkcs1_der(cls, keyfile: bytes) -> 'AbstractKey': + def _load_pkcs1_der(cls, keyfile: bytes) -> "AbstractKey": """Loads a key in PKCS#1 PEM format, implement in a subclass. :param keyfile: contents of a DER-encoded file that contains @@ -102,7 +102,7 @@ class AbstractKey: """ @classmethod - def load_pkcs1(cls, keyfile: bytes, format: str = 'PEM') -> 'AbstractKey': + def load_pkcs1(cls, keyfile: bytes, format: str = "PEM") -> "AbstractKey": """Loads a key in PKCS#1 DER or PEM format. :param keyfile: contents of a DER- or PEM-encoded file that contains @@ -116,27 +116,28 @@ class AbstractKey: """ methods = { - 'PEM': cls._load_pkcs1_pem, - 'DER': cls._load_pkcs1_der, + "PEM": cls._load_pkcs1_pem, + "DER": cls._load_pkcs1_der, } method = cls._assert_format_exists(format, methods) return method(keyfile) @staticmethod - def _assert_format_exists(file_format: str, methods: typing.Mapping[str, typing.Callable]) \ - -> typing.Callable: - """Checks whether the given file format exists in 'methods'. - """ + def _assert_format_exists( + file_format: str, methods: typing.Mapping[str, typing.Callable] + ) -> typing.Callable: + """Checks whether the given file format exists in 'methods'.""" try: return methods[file_format] except KeyError as ex: - formats = ', '.join(sorted(methods.keys())) - raise ValueError('Unsupported format: %r, try one of %s' % (file_format, - formats)) from ex + formats = ", ".join(sorted(methods.keys())) + raise ValueError( + "Unsupported format: %r, try one of %s" % (file_format, formats) + ) from ex - def save_pkcs1(self, format: str = 'PEM') -> bytes: + def save_pkcs1(self, format: str = "PEM") -> bytes: """Saves the key in PKCS#1 DER or PEM format. :param format: the format to save; 'PEM' or 'DER' @@ -146,8 +147,8 @@ class AbstractKey: """ methods = { - 'PEM': self._save_pkcs1_pem, - 'DER': self._save_pkcs1_der, + "PEM": self._save_pkcs1_pem, + "DER": self._save_pkcs1_der, } method = self._assert_format_exists(format, methods) @@ -186,7 +187,7 @@ class AbstractKey: blind_r = rsa.randnum.randint(self.n - 1) if rsa.prime.are_relatively_prime(self.n, blind_r): return blind_r - raise RuntimeError('unable to find blinding factor') + raise RuntimeError("unable to find blinding factor") def _update_blinding_factor(self) -> typing.Tuple[int, int]: """Update blinding factors. @@ -212,6 +213,7 @@ class AbstractKey: return self.blindfac, self.blindfac_inverse + class PublicKey(AbstractKey): """Represents a public RSA key. @@ -236,13 +238,13 @@ class PublicKey(AbstractKey): """ - __slots__ = ('n', 'e') + __slots__ = ("n", "e") def __getitem__(self, key: str) -> int: return getattr(self, key) def __repr__(self) -> str: - return 'PublicKey(%i, %i)' % (self.n, self.e) + return "PublicKey(%i, %i)" % (self.n, self.e) def __getstate__(self) -> typing.Tuple[int, int]: """Returns the key as tuple for pickling.""" @@ -269,7 +271,7 @@ class PublicKey(AbstractKey): return hash((self.n, self.e)) @classmethod - def _load_pkcs1_der(cls, keyfile: bytes) -> 'PublicKey': + def _load_pkcs1_der(cls, keyfile: bytes) -> "PublicKey": """Loads a key in PKCS#1 DER format. :param keyfile: contents of a DER-encoded file that contains the public @@ -293,7 +295,7 @@ class PublicKey(AbstractKey): from rsa.asn1 import AsnPubKey (priv, _) = decoder.decode(keyfile, asn1Spec=AsnPubKey()) - return cls(n=int(priv['modulus']), e=int(priv['publicExponent'])) + return cls(n=int(priv["modulus"]), e=int(priv["publicExponent"])) def _save_pkcs1_der(self) -> bytes: """Saves the public key in PKCS#1 DER format. @@ -307,13 +309,13 @@ class PublicKey(AbstractKey): # Create the ASN object asn_key = AsnPubKey() - asn_key.setComponentByName('modulus', self.n) - asn_key.setComponentByName('publicExponent', self.e) + asn_key.setComponentByName("modulus", self.n) + asn_key.setComponentByName("publicExponent", self.e) return encoder.encode(asn_key) @classmethod - def _load_pkcs1_pem(cls, keyfile: bytes) -> 'PublicKey': + def _load_pkcs1_pem(cls, keyfile: bytes) -> "PublicKey": """Loads a PKCS#1 PEM-encoded public key file. The contents of the file before the "-----BEGIN RSA PUBLIC KEY-----" and @@ -324,7 +326,7 @@ class PublicKey(AbstractKey): :return: a PublicKey object """ - der = rsa.pem.load_pem(keyfile, 'RSA PUBLIC KEY') + der = rsa.pem.load_pem(keyfile, "RSA PUBLIC KEY") return cls._load_pkcs1_der(der) def _save_pkcs1_pem(self) -> bytes: @@ -335,10 +337,10 @@ class PublicKey(AbstractKey): """ der = self._save_pkcs1_der() - return rsa.pem.save_pem(der, 'RSA PUBLIC KEY') + return rsa.pem.save_pem(der, "RSA PUBLIC KEY") @classmethod - def load_pkcs1_openssl_pem(cls, keyfile: bytes) -> 'PublicKey': + def load_pkcs1_openssl_pem(cls, keyfile: bytes) -> "PublicKey": """Loads a PKCS#1.5 PEM-encoded public key file from OpenSSL. These files can be recognised in that they start with BEGIN PUBLIC KEY @@ -353,11 +355,11 @@ class PublicKey(AbstractKey): :return: a PublicKey object """ - der = rsa.pem.load_pem(keyfile, 'PUBLIC KEY') + der = rsa.pem.load_pem(keyfile, "PUBLIC KEY") return cls.load_pkcs1_openssl_der(der) @classmethod - def load_pkcs1_openssl_der(cls, keyfile: bytes) -> 'PublicKey': + def load_pkcs1_openssl_der(cls, keyfile: bytes) -> "PublicKey": """Loads a PKCS#1 DER-encoded public key file from OpenSSL. :param keyfile: contents of a DER-encoded file that contains the public @@ -371,10 +373,10 @@ class PublicKey(AbstractKey): (keyinfo, _) = decoder.decode(keyfile, asn1Spec=OpenSSLPubKey()) - if keyinfo['header']['oid'] != univ.ObjectIdentifier('1.2.840.113549.1.1.1'): + if keyinfo["header"]["oid"] != univ.ObjectIdentifier("1.2.840.113549.1.1.1"): raise TypeError("This is not a DER-encoded OpenSSL-compatible public key") - return cls._load_pkcs1_der(keyinfo['key'][1:]) + return cls._load_pkcs1_der(keyinfo["key"][1:]) class PrivateKey(AbstractKey): @@ -401,7 +403,7 @@ class PrivateKey(AbstractKey): """ - __slots__ = ('n', 'e', 'd', 'p', 'q', 'exp1', 'exp2', 'coef') + __slots__ = ("n", "e", "d", "p", "q", "exp1", "exp2", "coef") def __init__(self, n: int, e: int, d: int, p: int, q: int) -> None: AbstractKey.__init__(self, n, e) @@ -418,7 +420,13 @@ class PrivateKey(AbstractKey): return getattr(self, key) def __repr__(self) -> str: - return 'PrivateKey(%i, %i, %i, %i, %i)' % (self.n, self.e, self.d, self.p, self.q) + return "PrivateKey(%i, %i, %i, %i, %i)" % ( + self.n, + self.e, + self.d, + self.p, + self.q, + ) def __getstate__(self) -> typing.Tuple[int, int, int, int, int, int, int, int]: """Returns the key as tuple for pickling.""" @@ -436,14 +444,16 @@ class PrivateKey(AbstractKey): if not isinstance(other, PrivateKey): return False - return (self.n == other.n and - self.e == other.e and - self.d == other.d and - self.p == other.p and - self.q == other.q and - self.exp1 == other.exp1 and - self.exp2 == other.exp2 and - self.coef == other.coef) + return ( + self.n == other.n + and self.e == other.e + and self.d == other.d + and self.p == other.p + and self.q == other.q + and self.exp1 == other.exp1 + and self.exp2 == other.exp2 + and self.coef == other.coef + ) def __ne__(self, other: typing.Any) -> bool: return not (self == other) @@ -481,7 +491,7 @@ class PrivateKey(AbstractKey): return self.unblind(encrypted, blindfac_inverse) @classmethod - def _load_pkcs1_der(cls, keyfile: bytes) -> 'PrivateKey': + def _load_pkcs1_der(cls, keyfile: bytes) -> "PrivateKey": """Loads a key in PKCS#1 DER format. :param keyfile: contents of a DER-encoded file that contains the private @@ -503,6 +513,7 @@ class PrivateKey(AbstractKey): """ from pyasn1.codec.der import decoder + (priv, _) = decoder.decode(keyfile) # ASN.1 contents of DER encoded private key: @@ -521,7 +532,7 @@ class PrivateKey(AbstractKey): # } if priv[0] != 0: - raise ValueError('Unable to read this file, version %s != 0' % priv[0]) + raise ValueError("Unable to read this file, version %s != 0" % priv[0]) as_ints = map(int, priv[1:6]) key = cls(*as_ints) @@ -530,9 +541,9 @@ class PrivateKey(AbstractKey): if (key.exp1, key.exp2, key.coef) != (exp1, exp2, coef): warnings.warn( - 'You have provided a malformed keyfile. Either the exponents ' - 'or the coefficient are incorrect. Using the correct values ' - 'instead.', + "You have provided a malformed keyfile. Either the exponents " + "or the coefficient are incorrect. Using the correct values " + "instead.", UserWarning, ) @@ -550,33 +561,33 @@ class PrivateKey(AbstractKey): class AsnPrivKey(univ.Sequence): componentType = namedtype.NamedTypes( - namedtype.NamedType('version', univ.Integer()), - namedtype.NamedType('modulus', univ.Integer()), - namedtype.NamedType('publicExponent', univ.Integer()), - namedtype.NamedType('privateExponent', univ.Integer()), - namedtype.NamedType('prime1', univ.Integer()), - namedtype.NamedType('prime2', univ.Integer()), - namedtype.NamedType('exponent1', univ.Integer()), - namedtype.NamedType('exponent2', univ.Integer()), - namedtype.NamedType('coefficient', univ.Integer()), + namedtype.NamedType("version", univ.Integer()), + namedtype.NamedType("modulus", univ.Integer()), + namedtype.NamedType("publicExponent", univ.Integer()), + namedtype.NamedType("privateExponent", univ.Integer()), + namedtype.NamedType("prime1", univ.Integer()), + namedtype.NamedType("prime2", univ.Integer()), + namedtype.NamedType("exponent1", univ.Integer()), + namedtype.NamedType("exponent2", univ.Integer()), + namedtype.NamedType("coefficient", univ.Integer()), ) # Create the ASN object asn_key = AsnPrivKey() - asn_key.setComponentByName('version', 0) - asn_key.setComponentByName('modulus', self.n) - asn_key.setComponentByName('publicExponent', self.e) - asn_key.setComponentByName('privateExponent', self.d) - asn_key.setComponentByName('prime1', self.p) - asn_key.setComponentByName('prime2', self.q) - asn_key.setComponentByName('exponent1', self.exp1) - asn_key.setComponentByName('exponent2', self.exp2) - asn_key.setComponentByName('coefficient', self.coef) + asn_key.setComponentByName("version", 0) + asn_key.setComponentByName("modulus", self.n) + asn_key.setComponentByName("publicExponent", self.e) + asn_key.setComponentByName("privateExponent", self.d) + asn_key.setComponentByName("prime1", self.p) + asn_key.setComponentByName("prime2", self.q) + asn_key.setComponentByName("exponent1", self.exp1) + asn_key.setComponentByName("exponent2", self.exp2) + asn_key.setComponentByName("coefficient", self.coef) return encoder.encode(asn_key) @classmethod - def _load_pkcs1_pem(cls, keyfile: bytes) -> 'PrivateKey': + def _load_pkcs1_pem(cls, keyfile: bytes) -> "PrivateKey": """Loads a PKCS#1 PEM-encoded private key file. The contents of the file before the "-----BEGIN RSA PRIVATE KEY-----" and @@ -588,7 +599,7 @@ class PrivateKey(AbstractKey): :return: a PrivateKey object """ - der = rsa.pem.load_pem(keyfile, b'RSA PRIVATE KEY') + der = rsa.pem.load_pem(keyfile, b"RSA PRIVATE KEY") return cls._load_pkcs1_der(der) def _save_pkcs1_pem(self) -> bytes: @@ -599,12 +610,14 @@ class PrivateKey(AbstractKey): """ der = self._save_pkcs1_der() - return rsa.pem.save_pem(der, b'RSA PRIVATE KEY') + return rsa.pem.save_pem(der, b"RSA PRIVATE KEY") -def find_p_q(nbits: int, - getprime_func: typing.Callable[[int], int] = rsa.prime.getprime, - accurate: bool = True) -> typing.Tuple[int, int]: +def find_p_q( + nbits: int, + getprime_func: typing.Callable[[int], int] = rsa.prime.getprime, + accurate: bool = True, +) -> typing.Tuple[int, int]: """Returns a tuple of two different primes of nbits bits each. The resulting p * q has exacty 2 * nbits bits, and the returned p and q @@ -644,16 +657,16 @@ def find_p_q(nbits: int, qbits = nbits - shift # Choose the two initial primes - log.debug('find_p_q(%i): Finding p', nbits) + log.debug("find_p_q(%i): Finding p", nbits) p = getprime_func(pbits) - log.debug('find_p_q(%i): Finding q', nbits) + log.debug("find_p_q(%i): Finding q", nbits) q = getprime_func(qbits) def is_acceptable(p: int, q: int) -> bool: """Returns True iff p and q are acceptable: - - p and q differ - - (p * q) has the right nr of bits (when accurate=True) + - p and q differ + - (p * q) has the right nr of bits (when accurate=True) """ if p == q: @@ -701,13 +714,17 @@ def calculate_keys_custom_exponent(p: int, q: int, exponent: int) -> typing.Tupl d = rsa.common.inverse(exponent, phi_n) except rsa.common.NotRelativePrimeError as ex: raise rsa.common.NotRelativePrimeError( - exponent, phi_n, ex.d, - msg="e (%d) and phi_n (%d) are not relatively prime (divider=%i)" % - (exponent, phi_n, ex.d)) from ex + exponent, + phi_n, + ex.d, + msg="e (%d) and phi_n (%d) are not relatively prime (divider=%i)" + % (exponent, phi_n, ex.d), + ) from ex if (exponent * d) % phi_n != 1: - raise ValueError("e (%d) and d (%d) are not mult. inv. modulo " - "phi_n (%d)" % (exponent, d, phi_n)) + raise ValueError( + "e (%d) and d (%d) are not mult. inv. modulo " "phi_n (%d)" % (exponent, d, phi_n) + ) return exponent, d @@ -725,10 +742,12 @@ def calculate_keys(p: int, q: int) -> typing.Tuple[int, int]: return calculate_keys_custom_exponent(p, q, DEFAULT_EXPONENT) -def gen_keys(nbits: int, - getprime_func: typing.Callable[[int], int], - accurate: bool = True, - exponent: int = DEFAULT_EXPONENT) -> typing.Tuple[int, int, int, int]: +def gen_keys( + nbits: int, + getprime_func: typing.Callable[[int], int], + accurate: bool = True, + exponent: int = DEFAULT_EXPONENT, +) -> typing.Tuple[int, int, int, int]: """Generate RSA keys of nbits bits. Returns (p, q, e, d). Note: this can take a long time, depending on the key size. @@ -756,10 +775,12 @@ def gen_keys(nbits: int, return p, q, e, d -def newkeys(nbits: int, - accurate: bool = True, - poolsize: int = 1, - exponent: int = DEFAULT_EXPONENT) -> typing.Tuple[PublicKey, PrivateKey]: +def newkeys( + nbits: int, + accurate: bool = True, + poolsize: int = 1, + exponent: int = DEFAULT_EXPONENT, +) -> typing.Tuple[PublicKey, PrivateKey]: """Generates public and private keys, and returns them as (pub, priv). The public key is also known as the 'encryption key', and is a @@ -786,10 +807,10 @@ def newkeys(nbits: int, """ if nbits < 16: - raise ValueError('Key too small') + raise ValueError("Key too small") if poolsize < 1: - raise ValueError('Pool size (%i) should be >= 1' % poolsize) + raise ValueError("Pool size (%i) should be >= 1" % poolsize) # Determine which getprime function to use if poolsize > 1: @@ -797,6 +818,7 @@ def newkeys(nbits: int, def getprime_func(nbits: int) -> int: return parallel.getprime(nbits, poolsize=poolsize) + else: getprime_func = rsa.prime.getprime @@ -806,15 +828,12 @@ def newkeys(nbits: int, # Create the key objects n = p * q - return ( - PublicKey(n, e), - PrivateKey(n, e, d, p, q) - ) + return (PublicKey(n, e), PrivateKey(n, e, d, p, q)) -__all__ = ['PublicKey', 'PrivateKey', 'newkeys'] +__all__ = ["PublicKey", "PrivateKey", "newkeys"] -if __name__ == '__main__': +if __name__ == "__main__": import doctest try: @@ -824,8 +843,8 @@ if __name__ == '__main__': break if (count % 10 == 0 and count) or count == 1: - print('%i times' % count) + print("%i times" % count) except KeyboardInterrupt: - print('Aborted') + print("Aborted") else: - print('Doctests done') + print("Doctests done") |