summaryrefslogtreecommitdiff
path: root/rsa/key.py
diff options
context:
space:
mode:
authorSybren A. Stüvel <sybren@stuvel.eu>2019-08-04 16:41:01 +0200
committerSybren A. Stüvel <sybren@stuvel.eu>2019-08-04 17:05:58 +0200
commitb6cebd53fcafd3088fc8361f6d3466166f75410b (patch)
treea1a3912fb9e91e249e433df0a9b79572f46340f3 /rsa/key.py
parent6760eb76e665dc81863a82110164c4b3b38e7ee9 (diff)
downloadrsa-git-b6cebd53fcafd3088fc8361f6d3466166f75410b.tar.gz
Added type annotations + some fixes to get them correct
One functional change: `CryptoOperation.read_infile()` now reads bytes from `sys.stdin` instead of text. This is necessary to be consistent with the rest of the code, which all deals with bytes.
Diffstat (limited to 'rsa/key.py')
-rw-r--r--rsa/key.py95
1 files changed, 49 insertions, 46 deletions
diff --git a/rsa/key.py b/rsa/key.py
index 1565967..05c77ef 100644
--- a/rsa/key.py
+++ b/rsa/key.py
@@ -34,6 +34,7 @@ of pyasn1.
"""
import logging
+import typing
import warnings
import rsa.prime
@@ -47,17 +48,17 @@ log = logging.getLogger(__name__)
DEFAULT_EXPONENT = 65537
-class AbstractKey(object):
+class AbstractKey:
"""Abstract superclass for private and public keys."""
__slots__ = ('n', 'e')
- def __init__(self, n, e):
+ def __init__(self, n: int, e: int) -> None:
self.n = n
self.e = e
@classmethod
- def _load_pkcs1_pem(cls, keyfile):
+ def _load_pkcs1_pem(cls, keyfile: bytes) -> 'AbstractKey':
"""Loads a key in PKCS#1 PEM format, implement in a subclass.
:param keyfile: contents of a PEM-encoded file that contains
@@ -69,7 +70,7 @@ class AbstractKey(object):
"""
@classmethod
- def _load_pkcs1_der(cls, keyfile):
+ def _load_pkcs1_der(cls, keyfile: bytes) -> 'AbstractKey':
"""Loads a key in PKCS#1 PEM format, implement in a subclass.
:param keyfile: contents of a DER-encoded file that contains
@@ -80,14 +81,14 @@ class AbstractKey(object):
:rtype: AbstractKey
"""
- def _save_pkcs1_pem(self):
+ def _save_pkcs1_pem(self) -> bytes:
"""Saves the key in PKCS#1 PEM format, implement in a subclass.
:returns: the PEM-encoded key.
:rtype: bytes
"""
- def _save_pkcs1_der(self):
+ def _save_pkcs1_der(self) -> bytes:
"""Saves the key in PKCS#1 DER format, implement in a subclass.
:returns: the DER-encoded key.
@@ -95,7 +96,7 @@ class AbstractKey(object):
"""
@classmethod
- def load_pkcs1(cls, keyfile, format='PEM'):
+ def load_pkcs1(cls, keyfile: bytes, format='PEM') -> 'AbstractKey':
"""Loads a key in PKCS#1 DER or PEM format.
:param keyfile: contents of a DER- or PEM-encoded file that contains
@@ -117,7 +118,7 @@ class AbstractKey(object):
return method(keyfile)
@staticmethod
- def _assert_format_exists(file_format, methods):
+ def _assert_format_exists(file_format: str, methods: typing.Mapping[str, typing.Callable]) -> typing.Callable:
"""Checks whether the given file format exists in 'methods'.
"""
@@ -128,7 +129,7 @@ class AbstractKey(object):
raise ValueError('Unsupported format: %r, try one of %s' % (file_format,
formats))
- def save_pkcs1(self, format='PEM'):
+ def save_pkcs1(self, format='PEM') -> bytes:
"""Saves the key in PKCS#1 DER or PEM format.
:param format: the format to save; 'PEM' or 'DER'
@@ -145,7 +146,7 @@ class AbstractKey(object):
method = self._assert_format_exists(format, methods)
return method()
- def blind(self, message, r):
+ def blind(self, message: int, r: int) -> int:
"""Performs blinding on the message using random number 'r'.
:param message: the message, as integer, to blind.
@@ -162,7 +163,7 @@ class AbstractKey(object):
return (message * pow(r, self.e, self.n)) % self.n
- def unblind(self, blinded, r):
+ def unblind(self, blinded: int, r: int) -> int:
"""Performs blinding on the message using random number 'r'.
:param blinded: the blinded message, as integer, to unblind.
@@ -206,18 +207,18 @@ class PublicKey(AbstractKey):
def __getitem__(self, key):
return getattr(self, key)
- def __repr__(self):
+ def __repr__(self) -> str:
return 'PublicKey(%i, %i)' % (self.n, self.e)
- def __getstate__(self):
+ def __getstate__(self) -> typing.Tuple[int, int]:
"""Returns the key as tuple for pickling."""
return self.n, self.e
- def __setstate__(self, state):
+ def __setstate__(self, state: typing.Tuple[int, int]) -> None:
"""Sets the key from tuple."""
self.n, self.e = state
- def __eq__(self, other):
+ def __eq__(self, other: typing.Any) -> bool:
if other is None:
return False
@@ -226,14 +227,14 @@ class PublicKey(AbstractKey):
return self.n == other.n and self.e == other.e
- def __ne__(self, other):
+ def __ne__(self, other: typing.Any) -> bool:
return not (self == other)
- def __hash__(self):
+ def __hash__(self) -> int:
return hash((self.n, self.e))
@classmethod
- def _load_pkcs1_der(cls, keyfile):
+ def _load_pkcs1_der(cls, keyfile: bytes) -> 'PublicKey':
"""Loads a key in PKCS#1 DER format.
:param keyfile: contents of a DER-encoded file that contains the public
@@ -259,7 +260,7 @@ class PublicKey(AbstractKey):
(priv, _) = decoder.decode(keyfile, asn1Spec=AsnPubKey())
return cls(n=int(priv['modulus']), e=int(priv['publicExponent']))
- def _save_pkcs1_der(self):
+ def _save_pkcs1_der(self) -> bytes:
"""Saves the public key in PKCS#1 DER format.
:returns: the DER-encoded public key.
@@ -277,7 +278,7 @@ class PublicKey(AbstractKey):
return encoder.encode(asn_key)
@classmethod
- def _load_pkcs1_pem(cls, keyfile):
+ def _load_pkcs1_pem(cls, keyfile: bytes) -> 'PublicKey':
"""Loads a PKCS#1 PEM-encoded public key file.
The contents of the file before the "-----BEGIN RSA PUBLIC KEY-----" and
@@ -291,7 +292,7 @@ class PublicKey(AbstractKey):
der = rsa.pem.load_pem(keyfile, 'RSA PUBLIC KEY')
return cls._load_pkcs1_der(der)
- def _save_pkcs1_pem(self):
+ def _save_pkcs1_pem(self) -> bytes:
"""Saves a PKCS#1 PEM-encoded public key file.
:return: contents of a PEM-encoded file that contains the public key.
@@ -302,7 +303,7 @@ class PublicKey(AbstractKey):
return rsa.pem.save_pem(der, 'RSA PUBLIC KEY')
@classmethod
- def load_pkcs1_openssl_pem(cls, keyfile):
+ def load_pkcs1_openssl_pem(cls, keyfile: bytes) -> 'PublicKey':
"""Loads a PKCS#1.5 PEM-encoded public key file from OpenSSL.
These files can be recognised in that they start with BEGIN PUBLIC KEY
@@ -321,14 +322,12 @@ class PublicKey(AbstractKey):
return cls.load_pkcs1_openssl_der(der)
@classmethod
- def load_pkcs1_openssl_der(cls, keyfile):
+ def load_pkcs1_openssl_der(cls, keyfile: bytes) -> 'PublicKey':
"""Loads a PKCS#1 DER-encoded public key file from OpenSSL.
:param keyfile: contents of a DER-encoded file that contains the public
key, from OpenSSL.
:return: a PublicKey object
- :rtype: bytes
-
"""
from rsa.asn1 import OpenSSLPubKey
@@ -369,7 +368,7 @@ class PrivateKey(AbstractKey):
__slots__ = ('n', 'e', 'd', 'p', 'q', 'exp1', 'exp2', 'coef')
- def __init__(self, n, e, d, p, q):
+ def __init__(self, n: int, e: int, d: int, p: int, q: int) -> None:
AbstractKey.__init__(self, n, e)
self.d = d
self.p = p
@@ -383,18 +382,18 @@ class PrivateKey(AbstractKey):
def __getitem__(self, key):
return getattr(self, key)
- def __repr__(self):
- return 'PrivateKey(%(n)i, %(e)i, %(d)i, %(p)i, %(q)i)' % self
+ def __repr__(self) -> str:
+ return 'PrivateKey(%i, %i, %i, %i, %i)' % (self.n, self.e, self.d, self.p, self.q)
- def __getstate__(self):
+ def __getstate__(self) -> typing.Tuple[int, int, int, int, int, int, int, int]:
"""Returns the key as tuple for pickling."""
return self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef
- def __setstate__(self, state):
+ def __setstate__(self, state: typing.Tuple[int, int, int, int, int, int, int, int]):
"""Sets the key from tuple."""
self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef = state
- def __eq__(self, other):
+ def __eq__(self, other: typing.Any) -> bool:
if other is None:
return False
@@ -410,13 +409,13 @@ class PrivateKey(AbstractKey):
self.exp2 == other.exp2 and
self.coef == other.coef)
- def __ne__(self, other):
+ def __ne__(self, other: typing.Any) -> bool:
return not (self == other)
- def __hash__(self):
+ def __hash__(self) -> int:
return hash((self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef))
- def blinded_decrypt(self, encrypted):
+ def blinded_decrypt(self, encrypted: int) -> int:
"""Decrypts the message using blinding to prevent side-channel attacks.
:param encrypted: the encrypted message
@@ -432,7 +431,7 @@ class PrivateKey(AbstractKey):
return self.unblind(decrypted, blind_r)
- def blinded_encrypt(self, message):
+ def blinded_encrypt(self, message: int) -> int:
"""Encrypts the message using blinding to prevent side-channel attacks.
:param message: the message to encrypt
@@ -448,7 +447,7 @@ class PrivateKey(AbstractKey):
return self.unblind(encrypted, blind_r)
@classmethod
- def _load_pkcs1_der(cls, keyfile):
+ def _load_pkcs1_der(cls, keyfile: bytes) -> 'PrivateKey':
"""Loads a key in PKCS#1 DER format.
:param keyfile: contents of a DER-encoded file that contains the private
@@ -505,7 +504,7 @@ class PrivateKey(AbstractKey):
return key
- def _save_pkcs1_der(self):
+ def _save_pkcs1_der(self) -> bytes:
"""Saves the private key in PKCS#1 DER format.
:returns: the DER-encoded private key.
@@ -543,7 +542,7 @@ class PrivateKey(AbstractKey):
return encoder.encode(asn_key)
@classmethod
- def _load_pkcs1_pem(cls, keyfile):
+ def _load_pkcs1_pem(cls, keyfile: bytes) -> 'PrivateKey':
"""Loads a PKCS#1 PEM-encoded private key file.
The contents of the file before the "-----BEGIN RSA PRIVATE KEY-----" and
@@ -558,7 +557,7 @@ class PrivateKey(AbstractKey):
der = rsa.pem.load_pem(keyfile, b'RSA PRIVATE KEY')
return cls._load_pkcs1_der(der)
- def _save_pkcs1_pem(self):
+ def _save_pkcs1_pem(self) -> bytes:
"""Saves a PKCS#1 PEM-encoded private key file.
:return: contents of a PEM-encoded file that contains the private key.
@@ -569,7 +568,7 @@ class PrivateKey(AbstractKey):
return rsa.pem.save_pem(der, b'RSA PRIVATE KEY')
-def find_p_q(nbits, getprime_func=rsa.prime.getprime, accurate=True):
+def find_p_q(nbits: int, getprime_func=rsa.prime.getprime, accurate=True) -> typing.Tuple[int, int]:
"""Returns a tuple of two different primes of nbits bits each.
The resulting p * q has exacty 2 * nbits bits, and the returned p and q
@@ -647,7 +646,7 @@ def find_p_q(nbits, getprime_func=rsa.prime.getprime, accurate=True):
return max(p, q), min(p, q)
-def calculate_keys_custom_exponent(p, q, exponent):
+def calculate_keys_custom_exponent(p: int, q: int, exponent: int) -> typing.Tuple[int, int]:
"""Calculates an encryption and a decryption key given p, q and an exponent,
and returns them as a tuple (e, d)
@@ -677,7 +676,7 @@ def calculate_keys_custom_exponent(p, q, exponent):
return exponent, d
-def calculate_keys(p, q):
+def calculate_keys(p: int, q: int) -> typing.Tuple[int, int]:
"""Calculates an encryption and a decryption key given p and q, and
returns them as a tuple (e, d)
@@ -690,7 +689,10 @@ def calculate_keys(p, q):
return calculate_keys_custom_exponent(p, q, DEFAULT_EXPONENT)
-def gen_keys(nbits, getprime_func, accurate=True, exponent=DEFAULT_EXPONENT):
+def gen_keys(nbits: int,
+ getprime_func: typing.Callable[[int], int],
+ accurate=True,
+ exponent=DEFAULT_EXPONENT) -> typing.Tuple[int, int, int, int]:
"""Generate RSA keys of nbits bits. Returns (p, q, e, d).
Note: this can take a long time, depending on the key size.
@@ -718,7 +720,8 @@ def gen_keys(nbits, getprime_func, accurate=True, exponent=DEFAULT_EXPONENT):
return p, q, e, d
-def newkeys(nbits, accurate=True, poolsize=1, exponent=DEFAULT_EXPONENT):
+def newkeys(nbits: int, accurate=True, poolsize=1, exponent=DEFAULT_EXPONENT) \
+ -> typing.Tuple[PublicKey, PrivateKey]:
"""Generates public and private keys, and returns them as (pub, priv).
The public key is also known as the 'encryption key', and is a
@@ -753,9 +756,9 @@ def newkeys(nbits, accurate=True, poolsize=1, exponent=DEFAULT_EXPONENT):
# Determine which getprime function to use
if poolsize > 1:
from rsa import parallel
- import functools
- getprime_func = functools.partial(parallel.getprime, poolsize=poolsize)
+ def getprime_func(nbits):
+ return parallel.getprime(nbits, poolsize=poolsize)
else:
getprime_func = rsa.prime.getprime