diff options
-rw-r--r-- | lib/Crypto/PublicKey/RSA.py | 2 | ||||
-rw-r--r-- | lib/Crypto/SelfTest/PublicKey/test_RSA.py | 60 |
2 files changed, 62 insertions, 0 deletions
diff --git a/lib/Crypto/PublicKey/RSA.py b/lib/Crypto/PublicKey/RSA.py index bab9288..01ee84f 100644 --- a/lib/Crypto/PublicKey/RSA.py +++ b/lib/Crypto/PublicKey/RSA.py @@ -286,6 +286,8 @@ class _RSAobj(pubkey.pubkey): def __setstate__(self, d): if not hasattr(self, 'implementation'): self.implementation = RSAImplementation() + if not hasattr(self, '_randfunc'): + self._randfunc = Random.new().read t = [] for k in self.keydata: if not d.has_key(k): diff --git a/lib/Crypto/SelfTest/PublicKey/test_RSA.py b/lib/Crypto/SelfTest/PublicKey/test_RSA.py index c971042..e0d1c0a 100644 --- a/lib/Crypto/SelfTest/PublicKey/test_RSA.py +++ b/lib/Crypto/SelfTest/PublicKey/test_RSA.py @@ -28,6 +28,7 @@ __revision__ = "$Id$" import sys import os +import pickle if sys.version_info[0] == 2 and sys.version_info[1] == 1: from Crypto.Util.py21compat import * from Crypto.Util.py3compat import * @@ -87,6 +88,21 @@ class RSATest(unittest.TestCase): ce 33 52 52 4d 04 16 a5 a4 41 e7 00 af 46 15 03 """ + # The same key, in pickled format (from pycrypto 2.3) + # to ensure backward compatibility + pickled_key_2_3 = \ + "(iCrypto.PublicKey.RSA\n_RSAobj\np0\n(dp2\nS'e'\np3\nL17L\nsS'd'\np4"\ + "\nL11646763154293086160147889314553506764606353688284149120983587488"\ + "79382229568306696406525871631480713149376749558222371890533687587223"\ + "51580531956820574156366843733156436163097164007967904900300775223658"\ + "03543233292399245064743971969473468304536714979010219881003396235861"\ + "8370829441895425705728523874962107052993L\nsS'n'\np5\nL1319966490819"\ + "88309815009412231606409998872008467220356704480658206329986017741425"\ + "59273959878490114749026269828326520214759381792655199845793621772998"\ + "40439054838068985140623386496543388290455526885872858516219460533763"\ + "92312680578795692682905599590422046720587710762927130740460442438533"\ + "124053848898103790124491L\nsb." + def setUp(self): global RSA, Random, bytes_to_long from Crypto.PublicKey import RSA @@ -178,6 +194,29 @@ class RSATest(unittest.TestCase): self.assertRaises(ValueError, self.rsa.construct, [self.n, self.e, self.n-1]) + def test_serialization(self): + """RSA (default implementation) serialize/unserialize key""" + rsaObj_orig = self.rsa.generate(1024) + rsaObj = pickle.loads(pickle.dumps(rsaObj_orig)) + self._check_private_key(rsaObj) + self._exercise_primitive(rsaObj) + pub = rsaObj.publickey() + self._check_public_key(pub) + self._exercise_public_primitive(rsaObj) + + plaintext = a2b_hex(self.plaintext) + ciphertext1 = rsaObj_orig.encrypt(plaintext, b("")) + ciphertext2 = rsaObj.encrypt(plaintext, b("")) + self.assertEqual(ciphertext1, ciphertext2) + + def test_serialization_compat(self): + """RSA (default implementation) backward compatibility serialization""" + rsaObj = pickle.loads(self.pickled_key_2_3) + plaintext = a2b_hex(self.plaintext) + ciphertext = a2b_hex(self.ciphertext) + ciphertext_result = rsaObj.encrypt(plaintext, b(""))[0] + self.assertEqual(ciphertext_result, ciphertext) + def _check_private_key(self, rsaObj): # Check capabilities self.assertEqual(1, rsaObj.has_private()) @@ -352,6 +391,18 @@ class RSAFastMathTest(RSATest): def test_factoring(self): RSATest.test_factoring(self) + + def test_serialization(self): + """RSA (_fastmath implementation) serialize/unserialize key + """ + RSATest.test_serialization(self) + + def test_serialization_compat(self): + """RSA (_fastmath implementation) backward compatibility serialization + """ + RSATest.test_serialization_compat(self) + + class RSASlowMathTest(RSATest): def setUp(self): RSATest.setUp(self) @@ -388,6 +439,15 @@ class RSASlowMathTest(RSATest): def test_factoring(self): RSATest.test_factoring(self) + def test_serialization(self): + """RSA (_slowmath implementation) serialize/unserialize key""" + RSATest.test_serialization(self) + + def test_serialization_compat(self): + """RSA (_slowmath implementation) backward compatibility serialization + """ + RSATest.test_serialization_compat(self) + def get_tests(config={}): tests = [] tests += list_test_cases(RSATest) |