summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrank Sievertsen <packaging@fx5.de>2012-11-19 08:21:21 +0100
committerDwayne Litzenberger <dlitz@dlitz.net>2013-02-16 11:14:09 -0800
commitdb52ac71e804e85ace3edd772d2e78055972ac2c (patch)
treeb165d36a8c919262cd17dc3c7714b7aac614df18
parent20cf8add93ccd6f4e4dc60691585256ba36e7810 (diff)
downloadpycrypto-db52ac71e804e85ace3edd772d2e78055972ac2c.tar.gz
Fix RSA object serialization
-rw-r--r--lib/Crypto/PublicKey/RSA.py2
-rw-r--r--lib/Crypto/SelfTest/PublicKey/test_RSA.py60
2 files changed, 62 insertions, 0 deletions
diff --git a/lib/Crypto/PublicKey/RSA.py b/lib/Crypto/PublicKey/RSA.py
index bab9288..01ee84f 100644
--- a/lib/Crypto/PublicKey/RSA.py
+++ b/lib/Crypto/PublicKey/RSA.py
@@ -286,6 +286,8 @@ class _RSAobj(pubkey.pubkey):
def __setstate__(self, d):
if not hasattr(self, 'implementation'):
self.implementation = RSAImplementation()
+ if not hasattr(self, '_randfunc'):
+ self._randfunc = Random.new().read
t = []
for k in self.keydata:
if not d.has_key(k):
diff --git a/lib/Crypto/SelfTest/PublicKey/test_RSA.py b/lib/Crypto/SelfTest/PublicKey/test_RSA.py
index c971042..e0d1c0a 100644
--- a/lib/Crypto/SelfTest/PublicKey/test_RSA.py
+++ b/lib/Crypto/SelfTest/PublicKey/test_RSA.py
@@ -28,6 +28,7 @@ __revision__ = "$Id$"
import sys
import os
+import pickle
if sys.version_info[0] == 2 and sys.version_info[1] == 1:
from Crypto.Util.py21compat import *
from Crypto.Util.py3compat import *
@@ -87,6 +88,21 @@ class RSATest(unittest.TestCase):
ce 33 52 52 4d 04 16 a5 a4 41 e7 00 af 46 15 03
"""
+ # The same key, in pickled format (from pycrypto 2.3)
+ # to ensure backward compatibility
+ pickled_key_2_3 = \
+ "(iCrypto.PublicKey.RSA\n_RSAobj\np0\n(dp2\nS'e'\np3\nL17L\nsS'd'\np4"\
+ "\nL11646763154293086160147889314553506764606353688284149120983587488"\
+ "79382229568306696406525871631480713149376749558222371890533687587223"\
+ "51580531956820574156366843733156436163097164007967904900300775223658"\
+ "03543233292399245064743971969473468304536714979010219881003396235861"\
+ "8370829441895425705728523874962107052993L\nsS'n'\np5\nL1319966490819"\
+ "88309815009412231606409998872008467220356704480658206329986017741425"\
+ "59273959878490114749026269828326520214759381792655199845793621772998"\
+ "40439054838068985140623386496543388290455526885872858516219460533763"\
+ "92312680578795692682905599590422046720587710762927130740460442438533"\
+ "124053848898103790124491L\nsb."
+
def setUp(self):
global RSA, Random, bytes_to_long
from Crypto.PublicKey import RSA
@@ -178,6 +194,29 @@ class RSATest(unittest.TestCase):
self.assertRaises(ValueError, self.rsa.construct, [self.n, self.e, self.n-1])
+ def test_serialization(self):
+ """RSA (default implementation) serialize/unserialize key"""
+ rsaObj_orig = self.rsa.generate(1024)
+ rsaObj = pickle.loads(pickle.dumps(rsaObj_orig))
+ self._check_private_key(rsaObj)
+ self._exercise_primitive(rsaObj)
+ pub = rsaObj.publickey()
+ self._check_public_key(pub)
+ self._exercise_public_primitive(rsaObj)
+
+ plaintext = a2b_hex(self.plaintext)
+ ciphertext1 = rsaObj_orig.encrypt(plaintext, b(""))
+ ciphertext2 = rsaObj.encrypt(plaintext, b(""))
+ self.assertEqual(ciphertext1, ciphertext2)
+
+ def test_serialization_compat(self):
+ """RSA (default implementation) backward compatibility serialization"""
+ rsaObj = pickle.loads(self.pickled_key_2_3)
+ plaintext = a2b_hex(self.plaintext)
+ ciphertext = a2b_hex(self.ciphertext)
+ ciphertext_result = rsaObj.encrypt(plaintext, b(""))[0]
+ self.assertEqual(ciphertext_result, ciphertext)
+
def _check_private_key(self, rsaObj):
# Check capabilities
self.assertEqual(1, rsaObj.has_private())
@@ -352,6 +391,18 @@ class RSAFastMathTest(RSATest):
def test_factoring(self):
RSATest.test_factoring(self)
+
+ def test_serialization(self):
+ """RSA (_fastmath implementation) serialize/unserialize key
+ """
+ RSATest.test_serialization(self)
+
+ def test_serialization_compat(self):
+ """RSA (_fastmath implementation) backward compatibility serialization
+ """
+ RSATest.test_serialization_compat(self)
+
+
class RSASlowMathTest(RSATest):
def setUp(self):
RSATest.setUp(self)
@@ -388,6 +439,15 @@ class RSASlowMathTest(RSATest):
def test_factoring(self):
RSATest.test_factoring(self)
+ def test_serialization(self):
+ """RSA (_slowmath implementation) serialize/unserialize key"""
+ RSATest.test_serialization(self)
+
+ def test_serialization_compat(self):
+ """RSA (_slowmath implementation) backward compatibility serialization
+ """
+ RSATest.test_serialization_compat(self)
+
def get_tests(config={}):
tests = []
tests += list_test_cases(RSATest)