summaryrefslogtreecommitdiff
path: root/rsa
diff options
context:
space:
mode:
authorSybren A. St?vel <sybren@stuvel.eu>2011-07-30 19:53:11 +0200
committerSybren A. St?vel <sybren@stuvel.eu>2011-07-30 19:53:11 +0200
commita63693829a04483e92b367f7669e3ff4646d6a0c (patch)
treeaec4adeaf4fcead2a1db56e0f56c0c6038318583 /rsa
parentf2794752f92634530ddc5a8437c17749b00eb4b1 (diff)
downloadrsa-a63693829a04483e92b367f7669e3ff4646d6a0c.tar.gz
Better type checking in core, casting ASN-ints to Python int
Diffstat (limited to 'rsa')
-rw-r--r--rsa/core.py23
-rw-r--r--rsa/key.py6
2 files changed, 22 insertions, 7 deletions
diff --git a/rsa/core.py b/rsa/core.py
index b66eaea..cc95f59 100644
--- a/rsa/core.py
+++ b/rsa/core.py
@@ -21,14 +21,19 @@ mathematically on integers.
import types
+def assert_int(var, name):
+
+ if type(var) in (types.IntType, types.LongType):
+ return
+
+ raise TypeError('%s should be an integer, not %s' % (name, var.__class__))
+
def encrypt_int(message, ekey, n):
"""Encrypts a message using encryption key 'ekey', working modulo n"""
- if type(message) is types.IntType:
- message = long(message)
-
- if not type(message) is types.LongType:
- raise TypeError("You must pass a long or int")
+ assert_int(message, 'message')
+ assert_int(ekey, 'ekey')
+ assert_int(n, 'n')
if message < 0:
raise ValueError('Only non-negative numbers are supported')
@@ -42,6 +47,14 @@ def decrypt_int(cyphertext, dkey, n):
"""Decrypts a cypher text using the decryption key 'dkey', working
modulo n"""
+ if type(cyphertext) not in (types.IntType, types.LongType):
+ raise TypeError('cyphertext should be an integer, not %s' %
+ cyphertext.__type__)
+
+ assert_int(cyphertext, 'cyphertext')
+ assert_int(dkey, 'dkey')
+ assert_int(n, 'n')
+
message = pow(cyphertext, dkey, n)
return message
diff --git a/rsa/key.py b/rsa/key.py
index 9e01958..d1fc2ec 100644
--- a/rsa/key.py
+++ b/rsa/key.py
@@ -157,7 +157,8 @@ class PublicKey(AbstractKey):
# modulus INTEGER, -- n
# publicExponent INTEGER, -- e
- return cls(*priv)
+ as_ints = tuple(int(x) for x in priv)
+ return cls(*as_ints)
def save_pkcs1_der(self):
'''Saves the public key in PKCS#1 DER format.
@@ -330,7 +331,8 @@ class PrivateKey(AbstractKey):
if priv[0] != 0:
raise ValueError('Unable to read this file, version %s != 0' % priv[0])
- return cls(*priv[1:9])
+ as_ints = tuple(int(x) for x in priv[1:9])
+ return cls(*as_ints)
def save_pkcs1_der(self):
'''Saves the private key in PKCS#1 DER format.