summaryrefslogtreecommitdiff
path: root/rsa/pkcs1.py
diff options
context:
space:
mode:
Diffstat (limited to 'rsa/pkcs1.py')
-rw-r--r--rsa/pkcs1.py117
1 files changed, 63 insertions, 54 deletions
diff --git a/rsa/pkcs1.py b/rsa/pkcs1.py
index 9adad90..5992c7f 100644
--- a/rsa/pkcs1.py
+++ b/rsa/pkcs1.py
@@ -41,37 +41,41 @@ else:
# ASN.1 codes that describe the hash algorithm used.
HASH_ASN1 = {
- 'MD5': b'\x30\x20\x30\x0c\x06\x08\x2a\x86\x48\x86\xf7\x0d\x02\x05\x05\x00\x04\x10',
- 'SHA-1': b'\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14',
- 'SHA-224': b'\x30\x2d\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x04\x05\x00\x04\x1c',
- 'SHA-256': b'\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20',
- 'SHA-384': b'\x30\x41\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x02\x05\x00\x04\x30',
- 'SHA-512': b'\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40',
+ "MD5": b"\x30\x20\x30\x0c\x06\x08\x2a\x86\x48\x86\xf7\x0d\x02\x05\x05\x00\x04\x10",
+ "SHA-1": b"\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14",
+ "SHA-224": b"\x30\x2d\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x04\x05\x00\x04\x1c",
+ "SHA-256": b"\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20",
+ "SHA-384": b"\x30\x41\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x02\x05\x00\x04\x30",
+ "SHA-512": b"\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40",
}
HASH_METHODS: typing.Dict[str, typing.Callable[[], HashType]] = {
- 'MD5': hashlib.md5,
- 'SHA-1': hashlib.sha1,
- 'SHA-224': hashlib.sha224,
- 'SHA-256': hashlib.sha256,
- 'SHA-384': hashlib.sha384,
- 'SHA-512': hashlib.sha512,
+ "MD5": hashlib.md5,
+ "SHA-1": hashlib.sha1,
+ "SHA-224": hashlib.sha224,
+ "SHA-256": hashlib.sha256,
+ "SHA-384": hashlib.sha384,
+ "SHA-512": hashlib.sha512,
}
if sys.version_info >= (3, 6):
# Python 3.6 introduced SHA3 support.
- HASH_ASN1.update({
- 'SHA3-256': b'\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x08\x05\x00\x04\x20',
- 'SHA3-384': b'\x30\x41\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x09\x05\x00\x04\x30',
- 'SHA3-512': b'\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x0a\x05\x00\x04\x40',
- })
-
- HASH_METHODS.update({
- 'SHA3-256': hashlib.sha3_256,
- 'SHA3-384': hashlib.sha3_384,
- 'SHA3-512': hashlib.sha3_512,
- })
+ HASH_ASN1.update(
+ {
+ "SHA3-256": b"\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x08\x05\x00\x04\x20",
+ "SHA3-384": b"\x30\x41\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x09\x05\x00\x04\x30",
+ "SHA3-512": b"\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x0a\x05\x00\x04\x40",
+ }
+ )
+
+ HASH_METHODS.update(
+ {
+ "SHA3-256": hashlib.sha3_256,
+ "SHA3-384": hashlib.sha3_384,
+ "SHA3-512": hashlib.sha3_512,
+ }
+ )
class CryptoError(Exception):
@@ -105,11 +109,13 @@ def _pad_for_encryption(message: bytes, target_length: int) -> bytes:
msglength = len(message)
if msglength > max_msglength:
- raise OverflowError('%i bytes needed for message, but there is only'
- ' space for %i' % (msglength, max_msglength))
+ raise OverflowError(
+ "%i bytes needed for message, but there is only"
+ " space for %i" % (msglength, max_msglength)
+ )
# Get random padding
- padding = b''
+ padding = b""
padding_length = target_length - msglength - 3
# We remove 0-bytes, so we'll end up with less padding than we've asked for,
@@ -121,15 +127,12 @@ def _pad_for_encryption(message: bytes, target_length: int) -> bytes:
# after removing the 0-bytes. This increases the chance of getting
# enough bytes, especially when needed_bytes is small
new_padding = os.urandom(needed_bytes + 5)
- new_padding = new_padding.replace(b'\x00', b'')
+ new_padding = new_padding.replace(b"\x00", b"")
padding = padding + new_padding[:needed_bytes]
assert len(padding) == padding_length
- return b''.join([b'\x00\x02',
- padding,
- b'\x00',
- message])
+ return b"".join([b"\x00\x02", padding, b"\x00", message])
def _pad_for_signing(message: bytes, target_length: int) -> bytes:
@@ -155,15 +158,14 @@ def _pad_for_signing(message: bytes, target_length: int) -> bytes:
msglength = len(message)
if msglength > max_msglength:
- raise OverflowError('%i bytes needed for message, but there is only'
- ' space for %i' % (msglength, max_msglength))
+ raise OverflowError(
+ "%i bytes needed for message, but there is only"
+ " space for %i" % (msglength, max_msglength)
+ )
padding_length = target_length - msglength - 3
- return b''.join([b'\x00\x01',
- padding_length * b'\xff',
- b'\x00',
- message])
+ return b"".join([b"\x00\x01", padding_length * b"\xff", b"\x00", message])
def encrypt(message: bytes, pub_key: key.PublicKey) -> bytes:
@@ -259,13 +261,13 @@ def decrypt(crypto: bytes, priv_key: key.PrivateKey) -> bytes:
# integer). This fixes CVE-2020-13757.
if len(crypto) > blocksize:
# This is operating on public information, so doesn't need to be constant-time.
- raise DecryptionError('Decryption failed')
+ raise DecryptionError("Decryption failed")
# If we can't find the cleartext marker, decryption failed.
- cleartext_marker_bad = not compare_digest(cleartext[:2], b'\x00\x02')
+ cleartext_marker_bad = not compare_digest(cleartext[:2], b"\x00\x02")
# Find the 00 separator between the padding and the message
- sep_idx = cleartext.find(b'\x00', 2)
+ sep_idx = cleartext.find(b"\x00", 2)
# sep_idx indicates the position of the `\x00` separator that separates the
# padding from the actual message. The padding should be at least 8 bytes
@@ -276,9 +278,9 @@ def decrypt(crypto: bytes, priv_key: key.PrivateKey) -> bytes:
anything_bad = cleartext_marker_bad | sep_idx_bad
if anything_bad:
- raise DecryptionError('Decryption failed')
+ raise DecryptionError("Decryption failed")
- return cleartext[sep_idx + 1:]
+ return cleartext[sep_idx + 1 :]
def sign_hash(hash_value: bytes, priv_key: key.PrivateKey, hash_method: str) -> bytes:
@@ -299,7 +301,7 @@ def sign_hash(hash_value: bytes, priv_key: key.PrivateKey, hash_method: str) ->
# Get the ASN1 code for this hash method
if hash_method not in HASH_ASN1:
- raise ValueError('Invalid hash method: %s' % hash_method)
+ raise ValueError("Invalid hash method: %s" % hash_method)
asn1code = HASH_ASN1[hash_method]
# Encrypt the hash with the private key
@@ -365,11 +367,11 @@ def verify(message: bytes, signature: bytes, pub_key: key.PublicKey) -> str:
expected = _pad_for_signing(cleartext, keylength)
if len(signature) != keylength:
- raise VerificationError('Verification failed')
+ raise VerificationError("Verification failed")
# Compare with the signed one
if expected != clearsig:
- raise VerificationError('Verification failed')
+ raise VerificationError("Verification failed")
return method_name
@@ -426,7 +428,7 @@ def compute_hash(message: typing.Union[bytes, typing.BinaryIO], method_name: str
"""
if method_name not in HASH_METHODS:
- raise ValueError('Invalid hash method: %s' % method_name)
+ raise ValueError("Invalid hash method: %s" % method_name)
method = HASH_METHODS[method_name]
hasher = method()
@@ -434,7 +436,7 @@ def compute_hash(message: typing.Union[bytes, typing.BinaryIO], method_name: str
if isinstance(message, bytes):
hasher.update(message)
else:
- assert hasattr(message, 'read') and hasattr(message.read, '__call__')
+ assert hasattr(message, "read") and hasattr(message.read, "__call__")
# read as 1K blocks
for block in yield_fixedblocks(message, 1024):
hasher.update(block)
@@ -454,14 +456,21 @@ def _find_method_hash(clearsig: bytes) -> str:
if asn1code in clearsig:
return hashname
- raise VerificationError('Verification failed')
+ raise VerificationError("Verification failed")
-__all__ = ['encrypt', 'decrypt', 'sign', 'verify',
- 'DecryptionError', 'VerificationError', 'CryptoError']
+__all__ = [
+ "encrypt",
+ "decrypt",
+ "sign",
+ "verify",
+ "DecryptionError",
+ "VerificationError",
+ "CryptoError",
+]
-if __name__ == '__main__':
- print('Running doctests 1000x or until failure')
+if __name__ == "__main__":
+ print("Running doctests 1000x or until failure")
import doctest
for count in range(1000):
@@ -470,6 +479,6 @@ if __name__ == '__main__':
break
if count % 100 == 0 and count:
- print('%i times' % count)
+ print("%i times" % count)
- print('Doctests done')
+ print("Doctests done")