diff options
-rw-r--r-- | rsa/key.py | 56 | ||||
-rw-r--r-- | tests/test_load_save_keys.py | 16 |
2 files changed, 58 insertions, 14 deletions
@@ -25,6 +25,7 @@ of pyasn1. ''' +import abc import logging import rsa.prime @@ -33,7 +34,53 @@ import rsa.common log = logging.getLogger(__name__) -class PublicKey(object): +class AbstractKey(object): + '''Abstract superclass for private and public keys.''' + + @classmethod + def load_pkcs1(cls, keyfile, format='pem'): + r'''Loads a key in PKCS#1 DER or PEM format. + + @param keyfile: contents of a DER- or PEM-encoded file that contains + the public key. + @param format: the format of the file to load; 'pem' or 'der' + @return: a PublicKey object + ''' + + methods = { + 'pem': cls.load_pkcs1_pem, + 'der': cls.load_pkcs1_der, + } + + if format not in methods: + formats = ', '.join(sorted(methods.keys())) + raise ValueError('Unsupported format: %r, try one of %s' % (format, + formats)) + + method = methods[format] + return method(keyfile) + + def save_pkcs1(self, format='pem'): + '''Saves the public key in PKCS#1 DER or PEM format. + + @param format: the format to save; 'pem' or 'der' + @returns: the DER- or PEM-encoded public key. + ''' + + methods = { + 'pem': self.save_pkcs1_pem, + 'der': self.save_pkcs1_der, + } + + if format not in methods: + formats = ', '.join(sorted(methods.keys())) + raise ValueError('Unsupported format: %r, try one of %s' % (format, + formats)) + + method = methods[format] + return method() + +class PublicKey(AbstractKey): '''Represents a public RSA key. This key is also known as the 'encryption key'. It contains the 'n' and 'e' @@ -158,8 +205,7 @@ class PublicKey(object): der = self.save_pkcs1_der() return rsa.pem.save_pem(der, 'RSA PUBLIC KEY') - -class PrivateKey(object): +class PrivateKey(AbstractKey): '''Represents a private RSA key. This key is also known as the 'decryption key'. It contains the 'n', 'e', @@ -347,7 +393,6 @@ class PrivateKey(object): return rsa.pem.save_pem(der, 'RSA PRIVATE KEY') - def extended_gcd(a, b): """Returns a tuple (r, i, j) such that r = gcd(a, b) = ia + jb """ @@ -512,8 +557,7 @@ def newkeys(nbits, accurate=True): PrivateKey(n, e, d, p, q) ) -__all__ = ['PublicKey', 'PrivateKey', 'newkeys', 'load_private_key_der', - 'load_private_key_pem', 'save_private_key_der', 'save_private_key_pem'] +__all__ = ['PublicKey', 'PrivateKey', 'newkeys'] if __name__ == '__main__': import doctest diff --git a/tests/test_load_save_keys.py b/tests/test_load_save_keys.py index fca4241..abcfb18 100644 --- a/tests/test_load_save_keys.py +++ b/tests/test_load_save_keys.py @@ -54,7 +54,7 @@ class DerTest(unittest.TestCase): def test_load_private_key(self): '''Test loading private DER keys.''' - key = rsa.key.PrivateKey.load_pkcs1_der(PRIVATE_DER) + key = rsa.key.PrivateKey.load_pkcs1(PRIVATE_DER, 'der') expected = rsa.key.PrivateKey(3727264081, 65537, 3349121513, 65063, 57287) self.assertEqual(expected, key) @@ -63,14 +63,14 @@ class DerTest(unittest.TestCase): '''Test saving private DER keys.''' key = rsa.key.PrivateKey(3727264081, 65537, 3349121513, 65063, 57287) - der = key.save_pkcs1_der() + der = key.save_pkcs1('der') self.assertEqual(PRIVATE_DER, der) def test_load_public_key(self): '''Test loading public DER keys.''' - key = rsa.key.PublicKey.load_pkcs1_der(PUBLIC_DER) + key = rsa.key.PublicKey.load_pkcs1(PUBLIC_DER, 'der') expected = rsa.key.PublicKey(3727264081, 65537) self.assertEqual(expected, key) @@ -79,7 +79,7 @@ class DerTest(unittest.TestCase): '''Test saving public DER keys.''' key = rsa.key.PublicKey(3727264081, 65537) - der = key.save_pkcs1_der() + der = key.save_pkcs1('der') self.assertEqual(PUBLIC_DER, der) @@ -90,7 +90,7 @@ class PemTest(unittest.TestCase): def test_load_private_key(self): '''Test loading private PEM files.''' - key = rsa.key.PrivateKey.load_pkcs1_pem(PRIVATE_PEM) + key = rsa.key.PrivateKey.load_pkcs1(PRIVATE_PEM, 'pem') expected = rsa.key.PrivateKey(3727264081, 65537, 3349121513, 65063, 57287) self.assertEqual(expected, key) @@ -99,14 +99,14 @@ class PemTest(unittest.TestCase): '''Test saving private PEM files.''' key = rsa.key.PrivateKey(3727264081, 65537, 3349121513, 65063, 57287) - pem = key.save_pkcs1_pem() + pem = key.save_pkcs1('pem') self.assertEqual(CLEAN_PRIVATE_PEM, pem) def test_load_public_key(self): '''Test loading public PEM files.''' - key = rsa.key.PublicKey.load_pkcs1_pem(PUBLIC_PEM) + key = rsa.key.PublicKey.load_pkcs1(PUBLIC_PEM, 'pem') expected = rsa.key.PublicKey(3727264081, 65537) self.assertEqual(expected, key) @@ -115,7 +115,7 @@ class PemTest(unittest.TestCase): '''Test saving public PEM files.''' key = rsa.key.PublicKey(3727264081, 65537) - pem = key.save_pkcs1_pem() + pem = key.save_pkcs1('pem') self.assertEqual(CLEAN_PUBLIC_PEM, pem) |