diff options
Diffstat (limited to 'lib/Crypto/Signature/PKCS1_PSS.py')
-rw-r--r-- | lib/Crypto/Signature/PKCS1_PSS.py | 19 |
1 files changed, 16 insertions, 3 deletions
diff --git a/lib/Crypto/Signature/PKCS1_PSS.py b/lib/Crypto/Signature/PKCS1_PSS.py index cd9eaf3..3840959 100644 --- a/lib/Crypto/Signature/PKCS1_PSS.py +++ b/lib/Crypto/Signature/PKCS1_PSS.py @@ -72,6 +72,7 @@ if sys.version_info[0] == 2 and sys.version_info[1] == 1: import Crypto.Util.number from Crypto.Util.number import ceil_shift, ceil_div, long_to_bytes from Crypto.Util.strxor import strxor +from Crypto.Hash import new as Hash_new class PSS_SigScheme: """This signature scheme can perform PKCS#1 PSS RSA signature or verification.""" @@ -203,7 +204,11 @@ def MGF1(mgfSeed, maskLen, hash): T = b("") for counter in xrange(ceil_div(maskLen, hash.digest_size)): c = long_to_bytes(counter, 4) - T = T + hash.new(mgfSeed + c).digest() + try: + T = T + hash.new(mgfSeed + c).digest() + except AttributeError: + # hash object doesn't have a "new" method. Use Crypto.Hash.new() to instantiate it + T = T + Hash_new(hash, mgfSeed + c).digest() assert(len(T)>=maskLen) return T[:maskLen] @@ -253,7 +258,11 @@ def EMSA_PSS_ENCODE(mhash, emBits, randFunc, mgf, sLen): if randFunc and sLen>0: salt = randFunc(sLen) # Step 5 and 6 - h = mhash.new(bchr(0x00)*8 + mhash.digest() + salt) + try: + h = mhash.new(bchr(0x00)*8 + mhash.digest() + salt) + except AttributeError: + # hash object doesn't have a "new" method. Use Crypto.Hash.new() to instantiate it + h = Hash_new(mhash, bchr(0x00)*8 + mhash.digest() + salt) # Step 7 and 8 db = bchr(0x00)*(emLen-sLen-mhash.digest_size-2) + bchr(0x01) + salt # Step 9 @@ -328,7 +337,11 @@ def EMSA_PSS_VERIFY(mhash, em, emBits, mgf, sLen): salt = b("") if sLen: salt = db[-sLen:] # Step 12 and 13 - hp = mhash.new(bchr(0x00)*8 + mhash.digest() + salt).digest() + try: + hp = mhash.new(bchr(0x00)*8 + mhash.digest() + salt).digest() + except AttributeError: + # hash object doesn't have a "new" method. Use Crypto.Hash.new() to instantiate it + hp = Hash_new(mhash, bchr(0x00)*8 + mhash.digest() + salt).digest() # Step 14 if h!=hp: return False |