diff options
author | Sybren A. Stüvel <sybren@stuvel.eu> | 2011-07-30 19:53:11 +0200 |
---|---|---|
committer | Sybren A. Stüvel <sybren@stuvel.eu> | 2011-07-30 19:53:11 +0200 |
commit | 877fce05e5c26c1c3339b06918a83cb19797ff4b (patch) | |
tree | aec4adeaf4fcead2a1db56e0f56c0c6038318583 | |
parent | a3c476ed5fc469f41d820d7db6eedaa61575f184 (diff) | |
download | rsa-git-877fce05e5c26c1c3339b06918a83cb19797ff4b.tar.gz |
Better type checking in core, casting ASN-ints to Python int
-rw-r--r-- | rsa/core.py | 23 | ||||
-rw-r--r-- | rsa/key.py | 6 |
2 files changed, 22 insertions, 7 deletions
diff --git a/rsa/core.py b/rsa/core.py index b66eaea..cc95f59 100644 --- a/rsa/core.py +++ b/rsa/core.py @@ -21,14 +21,19 @@ mathematically on integers. import types +def assert_int(var, name): + + if type(var) in (types.IntType, types.LongType): + return + + raise TypeError('%s should be an integer, not %s' % (name, var.__class__)) + def encrypt_int(message, ekey, n): """Encrypts a message using encryption key 'ekey', working modulo n""" - if type(message) is types.IntType: - message = long(message) - - if not type(message) is types.LongType: - raise TypeError("You must pass a long or int") + assert_int(message, 'message') + assert_int(ekey, 'ekey') + assert_int(n, 'n') if message < 0: raise ValueError('Only non-negative numbers are supported') @@ -42,6 +47,14 @@ def decrypt_int(cyphertext, dkey, n): """Decrypts a cypher text using the decryption key 'dkey', working modulo n""" + if type(cyphertext) not in (types.IntType, types.LongType): + raise TypeError('cyphertext should be an integer, not %s' % + cyphertext.__type__) + + assert_int(cyphertext, 'cyphertext') + assert_int(dkey, 'dkey') + assert_int(n, 'n') + message = pow(cyphertext, dkey, n) return message @@ -157,7 +157,8 @@ class PublicKey(AbstractKey): # modulus INTEGER, -- n # publicExponent INTEGER, -- e - return cls(*priv) + as_ints = tuple(int(x) for x in priv) + return cls(*as_ints) def save_pkcs1_der(self): '''Saves the public key in PKCS#1 DER format. @@ -330,7 +331,8 @@ class PrivateKey(AbstractKey): if priv[0] != 0: raise ValueError('Unable to read this file, version %s != 0' % priv[0]) - return cls(*priv[1:9]) + as_ints = tuple(int(x) for x in priv[1:9]) + return cls(*as_ints) def save_pkcs1_der(self): '''Saves the private key in PKCS#1 DER format. |