summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSybren A. St?vel <sybren@stuvel.eu>2011-07-24 19:00:56 +0200
committerSybren A. St?vel <sybren@stuvel.eu>2011-07-24 19:00:56 +0200
commitd0c9c2ad88bc4953125b54d7c21f8dcf18cc78d5 (patch)
tree931d7fcb53d7cefdfb72a41c20381ad65e148b41
parentac2d48aefb7b6ec801dcfed8385e9d674b8e2414 (diff)
downloadrsa-d0c9c2ad88bc4953125b54d7c21f8dcf18cc78d5.tar.gz
Added simpler save/load functions
-rw-r--r--rsa/key.py56
-rw-r--r--tests/test_load_save_keys.py16
2 files changed, 58 insertions, 14 deletions
diff --git a/rsa/key.py b/rsa/key.py
index 2aa39f1..7b13fa6 100644
--- a/rsa/key.py
+++ b/rsa/key.py
@@ -25,6 +25,7 @@ of pyasn1.
'''
+import abc
import logging
import rsa.prime
@@ -33,7 +34,53 @@ import rsa.common
log = logging.getLogger(__name__)
-class PublicKey(object):
+class AbstractKey(object):
+ '''Abstract superclass for private and public keys.'''
+
+ @classmethod
+ def load_pkcs1(cls, keyfile, format='pem'):
+ r'''Loads a key in PKCS#1 DER or PEM format.
+
+ @param keyfile: contents of a DER- or PEM-encoded file that contains
+ the public key.
+ @param format: the format of the file to load; 'pem' or 'der'
+ @return: a PublicKey object
+ '''
+
+ methods = {
+ 'pem': cls.load_pkcs1_pem,
+ 'der': cls.load_pkcs1_der,
+ }
+
+ if format not in methods:
+ formats = ', '.join(sorted(methods.keys()))
+ raise ValueError('Unsupported format: %r, try one of %s' % (format,
+ formats))
+
+ method = methods[format]
+ return method(keyfile)
+
+ def save_pkcs1(self, format='pem'):
+ '''Saves the public key in PKCS#1 DER or PEM format.
+
+ @param format: the format to save; 'pem' or 'der'
+ @returns: the DER- or PEM-encoded public key.
+ '''
+
+ methods = {
+ 'pem': self.save_pkcs1_pem,
+ 'der': self.save_pkcs1_der,
+ }
+
+ if format not in methods:
+ formats = ', '.join(sorted(methods.keys()))
+ raise ValueError('Unsupported format: %r, try one of %s' % (format,
+ formats))
+
+ method = methods[format]
+ return method()
+
+class PublicKey(AbstractKey):
'''Represents a public RSA key.
This key is also known as the 'encryption key'. It contains the 'n' and 'e'
@@ -158,8 +205,7 @@ class PublicKey(object):
der = self.save_pkcs1_der()
return rsa.pem.save_pem(der, 'RSA PUBLIC KEY')
-
-class PrivateKey(object):
+class PrivateKey(AbstractKey):
'''Represents a private RSA key.
This key is also known as the 'decryption key'. It contains the 'n', 'e',
@@ -347,7 +393,6 @@ class PrivateKey(object):
return rsa.pem.save_pem(der, 'RSA PRIVATE KEY')
-
def extended_gcd(a, b):
"""Returns a tuple (r, i, j) such that r = gcd(a, b) = ia + jb
"""
@@ -512,8 +557,7 @@ def newkeys(nbits, accurate=True):
PrivateKey(n, e, d, p, q)
)
-__all__ = ['PublicKey', 'PrivateKey', 'newkeys', 'load_private_key_der',
- 'load_private_key_pem', 'save_private_key_der', 'save_private_key_pem']
+__all__ = ['PublicKey', 'PrivateKey', 'newkeys']
if __name__ == '__main__':
import doctest
diff --git a/tests/test_load_save_keys.py b/tests/test_load_save_keys.py
index fca4241..abcfb18 100644
--- a/tests/test_load_save_keys.py
+++ b/tests/test_load_save_keys.py
@@ -54,7 +54,7 @@ class DerTest(unittest.TestCase):
def test_load_private_key(self):
'''Test loading private DER keys.'''
- key = rsa.key.PrivateKey.load_pkcs1_der(PRIVATE_DER)
+ key = rsa.key.PrivateKey.load_pkcs1(PRIVATE_DER, 'der')
expected = rsa.key.PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)
self.assertEqual(expected, key)
@@ -63,14 +63,14 @@ class DerTest(unittest.TestCase):
'''Test saving private DER keys.'''
key = rsa.key.PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)
- der = key.save_pkcs1_der()
+ der = key.save_pkcs1('der')
self.assertEqual(PRIVATE_DER, der)
def test_load_public_key(self):
'''Test loading public DER keys.'''
- key = rsa.key.PublicKey.load_pkcs1_der(PUBLIC_DER)
+ key = rsa.key.PublicKey.load_pkcs1(PUBLIC_DER, 'der')
expected = rsa.key.PublicKey(3727264081, 65537)
self.assertEqual(expected, key)
@@ -79,7 +79,7 @@ class DerTest(unittest.TestCase):
'''Test saving public DER keys.'''
key = rsa.key.PublicKey(3727264081, 65537)
- der = key.save_pkcs1_der()
+ der = key.save_pkcs1('der')
self.assertEqual(PUBLIC_DER, der)
@@ -90,7 +90,7 @@ class PemTest(unittest.TestCase):
def test_load_private_key(self):
'''Test loading private PEM files.'''
- key = rsa.key.PrivateKey.load_pkcs1_pem(PRIVATE_PEM)
+ key = rsa.key.PrivateKey.load_pkcs1(PRIVATE_PEM, 'pem')
expected = rsa.key.PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)
self.assertEqual(expected, key)
@@ -99,14 +99,14 @@ class PemTest(unittest.TestCase):
'''Test saving private PEM files.'''
key = rsa.key.PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)
- pem = key.save_pkcs1_pem()
+ pem = key.save_pkcs1('pem')
self.assertEqual(CLEAN_PRIVATE_PEM, pem)
def test_load_public_key(self):
'''Test loading public PEM files.'''
- key = rsa.key.PublicKey.load_pkcs1_pem(PUBLIC_PEM)
+ key = rsa.key.PublicKey.load_pkcs1(PUBLIC_PEM, 'pem')
expected = rsa.key.PublicKey(3727264081, 65537)
self.assertEqual(expected, key)
@@ -115,7 +115,7 @@ class PemTest(unittest.TestCase):
'''Test saving public PEM files.'''
key = rsa.key.PublicKey(3727264081, 65537)
- pem = key.save_pkcs1_pem()
+ pem = key.save_pkcs1('pem')
self.assertEqual(CLEAN_PUBLIC_PEM, pem)