diff options
author | Sybren A. St?vel <sybren@stuvel.eu> | 2011-07-30 19:53:11 +0200 |
---|---|---|
committer | Sybren A. St?vel <sybren@stuvel.eu> | 2011-07-30 19:53:11 +0200 |
commit | a63693829a04483e92b367f7669e3ff4646d6a0c (patch) | |
tree | aec4adeaf4fcead2a1db56e0f56c0c6038318583 /rsa | |
parent | f2794752f92634530ddc5a8437c17749b00eb4b1 (diff) | |
download | rsa-a63693829a04483e92b367f7669e3ff4646d6a0c.tar.gz |
Better type checking in core, casting ASN-ints to Python int
Diffstat (limited to 'rsa')
-rw-r--r-- | rsa/core.py | 23 | ||||
-rw-r--r-- | rsa/key.py | 6 |
2 files changed, 22 insertions, 7 deletions
diff --git a/rsa/core.py b/rsa/core.py index b66eaea..cc95f59 100644 --- a/rsa/core.py +++ b/rsa/core.py @@ -21,14 +21,19 @@ mathematically on integers. import types +def assert_int(var, name): + + if type(var) in (types.IntType, types.LongType): + return + + raise TypeError('%s should be an integer, not %s' % (name, var.__class__)) + def encrypt_int(message, ekey, n): """Encrypts a message using encryption key 'ekey', working modulo n""" - if type(message) is types.IntType: - message = long(message) - - if not type(message) is types.LongType: - raise TypeError("You must pass a long or int") + assert_int(message, 'message') + assert_int(ekey, 'ekey') + assert_int(n, 'n') if message < 0: raise ValueError('Only non-negative numbers are supported') @@ -42,6 +47,14 @@ def decrypt_int(cyphertext, dkey, n): """Decrypts a cypher text using the decryption key 'dkey', working modulo n""" + if type(cyphertext) not in (types.IntType, types.LongType): + raise TypeError('cyphertext should be an integer, not %s' % + cyphertext.__type__) + + assert_int(cyphertext, 'cyphertext') + assert_int(dkey, 'dkey') + assert_int(n, 'n') + message = pow(cyphertext, dkey, n) return message @@ -157,7 +157,8 @@ class PublicKey(AbstractKey): # modulus INTEGER, -- n # publicExponent INTEGER, -- e - return cls(*priv) + as_ints = tuple(int(x) for x in priv) + return cls(*as_ints) def save_pkcs1_der(self): '''Saves the public key in PKCS#1 DER format. @@ -330,7 +331,8 @@ class PrivateKey(AbstractKey): if priv[0] != 0: raise ValueError('Unable to read this file, version %s != 0' % priv[0]) - return cls(*priv[1:9]) + as_ints = tuple(int(x) for x in priv[1:9]) + return cls(*as_ints) def save_pkcs1_der(self): '''Saves the private key in PKCS#1 DER format. |