From 42c3d8f3aa593c224174558fd6f3d2709e08f7d0 Mon Sep 17 00:00:00 2001 From: weidai Date: Wed, 16 Jul 2003 01:53:45 +0000 Subject: added support for using encoding parameters and key derivation parameters git-svn-id: svn://svn.code.sf.net/p/cryptopp/code/trunk/c5@98 57ff6487-cd31-0410-9ec3-f628ee90f5f0 --- oaep.cpp | 58 ++++++++++++++++++++++------------------------------------ 1 file changed, 22 insertions(+), 36 deletions(-) (limited to 'oaep.cpp') diff --git a/oaep.cpp b/oaep.cpp index 8913631..ddd846d 100644 --- a/oaep.cpp +++ b/oaep.cpp @@ -9,30 +9,12 @@ NAMESPACE_BEGIN(CryptoPP) // ******************************************************** -ANONYMOUS_NAMESPACE_BEGIN - template - struct PHashComputation - { - PHashComputation() {H().CalculateDigest(pHash, P, PLen);} - byte pHash[H::DIGESTSIZE]; - }; - - template - const byte *PHash() - { - static PHashComputation pHash; - return pHash.pHash; - } -NAMESPACE_END - -template -unsigned int OAEP::MaxUnpaddedLength(unsigned int paddedLength) const +unsigned int OAEP_Base::MaxUnpaddedLength(unsigned int paddedLength) const { - return paddedLength/8 > 1+2*H::DIGESTSIZE ? paddedLength/8-1-2*H::DIGESTSIZE : 0; + return SaturatingSubtract(paddedLength/8, 1+2*DigestSize()); } -template -void OAEP::Pad(RandomNumberGenerator &rng, const byte *input, unsigned int inputLength, byte *oaepBlock, unsigned int oaepBlockLen) const +void OAEP_Base::Pad(RandomNumberGenerator &rng, const byte *input, unsigned int inputLength, byte *oaepBlock, unsigned int oaepBlockLen, const NameValuePairs ¶meters) const { assert (inputLength <= MaxUnpaddedLength(oaepBlockLen)); @@ -44,26 +26,28 @@ void OAEP::Pad(RandomNumberGenerator &rng, const byte *input, unsi } oaepBlockLen /= 8; - const unsigned int hLen = H::DIGESTSIZE; + std::auto_ptr pHash(NewHash()); + const unsigned int hLen = pHash->DigestSize(); const unsigned int seedLen = hLen, dbLen = oaepBlockLen-seedLen; byte *const maskedSeed = oaepBlock; byte *const maskedDB = oaepBlock+seedLen; + ConstByteArrayParameter encodingParameters; + parameters.GetValue(Name::EncodingParameters(), encodingParameters); + // DB = pHash || 00 ... || 01 || M - memcpy(maskedDB, PHash(), hLen); + pHash->CalculateDigest(maskedDB, encodingParameters.begin(), encodingParameters.size()); memset(maskedDB+hLen, 0, dbLen-hLen-inputLength-1); maskedDB[dbLen-inputLength-1] = 0x01; memcpy(maskedDB+dbLen-inputLength, input, inputLength); rng.GenerateBlock(maskedSeed, seedLen); - H h; - MGF mgf; - mgf.GenerateAndMask(h, maskedDB, dbLen, maskedSeed, seedLen); - mgf.GenerateAndMask(h, maskedSeed, seedLen, maskedDB, dbLen); + std::auto_ptr pMGF(NewMGF()); + pMGF->GenerateAndMask(*pHash, maskedDB, dbLen, maskedSeed, seedLen); + pMGF->GenerateAndMask(*pHash, maskedSeed, seedLen, maskedDB, dbLen); } -template -DecodingResult OAEP::Unpad(const byte *oaepBlock, unsigned int oaepBlockLen, byte *output) const +DecodingResult OAEP_Base::Unpad(const byte *oaepBlock, unsigned int oaepBlockLen, byte *output, const NameValuePairs ¶meters) const { bool invalid = false; @@ -75,7 +59,8 @@ DecodingResult OAEP::Unpad(const byte *oaepBlock, unsigned int oae } oaepBlockLen /= 8; - const unsigned int hLen = H::DIGESTSIZE; + std::auto_ptr pHash(NewHash()); + const unsigned int hLen = pHash->DigestSize(); const unsigned int seedLen = hLen, dbLen = oaepBlockLen-seedLen; invalid = (oaepBlockLen < 2*hLen+1) || invalid; @@ -84,17 +69,18 @@ DecodingResult OAEP::Unpad(const byte *oaepBlock, unsigned int oae byte *const maskedSeed = t; byte *const maskedDB = t+seedLen; - H h; - MGF mgf; - mgf.GenerateAndMask(h, maskedSeed, seedLen, maskedDB, dbLen); - mgf.GenerateAndMask(h, maskedDB, dbLen, maskedSeed, seedLen); + std::auto_ptr pMGF(NewMGF()); + pMGF->GenerateAndMask(*pHash, maskedSeed, seedLen, maskedDB, dbLen); + pMGF->GenerateAndMask(*pHash, maskedDB, dbLen, maskedSeed, seedLen); - // DB = pHash' || 00 ... || 01 || M + ConstByteArrayParameter encodingParameters; + parameters.GetValue(Name::EncodingParameters(), encodingParameters); + // DB = pHash' || 00 ... || 01 || M byte *M = std::find(maskedDB+hLen, maskedDB+dbLen, 0x01); invalid = (M == maskedDB+dbLen) || invalid; invalid = (std::find_if(maskedDB+hLen, M, std::bind2nd(std::not_equal_to(), 0)) != M) || invalid; - invalid = (memcmp(maskedDB, PHash(), hLen) != 0) || invalid; + invalid = !pHash->VerifyDigest(maskedDB, encodingParameters.begin(), encodingParameters.size()) || invalid; if (invalid) return DecodingResult(); -- cgit v1.2.1