From aa33502883c1fc2a150d4adec4b0205e3112fd21 Mon Sep 17 00:00:00 2001 From: weidai Date: Mon, 16 Apr 2007 00:39:56 +0000 Subject: Test: Encode now tests decryption also git-svn-id: svn://svn.code.sf.net/p/cryptopp/code/trunk/c5@318 57ff6487-cd31-0410-9ec3-f628ee90f5f0 --- datatest.cpp | 168 ++++++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 109 insertions(+), 59 deletions(-) (limited to 'datatest.cpp') diff --git a/datatest.cpp b/datatest.cpp index 1942326..4a32609 100644 --- a/datatest.cpp +++ b/datatest.cpp @@ -12,7 +12,6 @@ USING_NAMESPACE(CryptoPP) USING_NAMESPACE(std) RandomPool & GlobalRNG(); -void RegisterFactories(); typedef std::map TestData; @@ -44,6 +43,60 @@ static void SignalTestError() throw Exception(Exception::OTHER_ERROR, "Unexpected error during validation test"); } +const std::string & GetRequiredDatum(const TestData &data, const char *name) +{ + TestData::const_iterator i = data.find(name); + if (i == data.end()) + SignalTestError(); + return i->second; +} + +void PutDecodedDatumInto(const TestData &data, const char *name, BufferedTransformation &target) +{ + std::string s1 = GetRequiredDatum(data, name), s2; + + while (!s1.empty()) + { + while (s1[0] == ' ') + s1 = s1.substr(1); + + int repeat = 1; + if (s1[0] == 'r') + { + repeat = atoi(s1.c_str()+1); + s1 = s1.substr(s1.find(' ')+1); + } + + s2.clear(); + + if (s1[0] == '\"') + { + s2 = s1.substr(1, s1.find('\"', 1)-1); + s1 = s1.substr(s2.length() + 2); + } + else if (s1.substr(0, 2) == "0x") + { + StringSource(s1.substr(2, s1.find(' ')), true, new HexDecoder(new StringSink(s2))); + s1 = s1.substr(STDMIN(s1.find(' '), s1.length())); + } + else + { + StringSource(s1.substr(0, s1.find(' ')), true, new HexDecoder(new StringSink(s2))); + s1 = s1.substr(STDMIN(s1.find(' '), s1.length())); + } + + while (repeat--) + target.Put((const byte *)s2.data(), s2.size()); + } +} + +std::string GetDecodedDatum(const TestData &data, const char *name) +{ + std::string s; + PutDecodedDatumInto(data, name, StringSink(s).Ref()); + return s; +} + class TestDataNameValuePairs : public NameValuePairs { public: @@ -64,13 +117,13 @@ public: else if (valueType == typeid(ConstByteArrayParameter)) { m_temp.resize(0); - StringSource(value, true, new HexDecoder(new StringSink(m_temp))); + PutDecodedDatumInto(m_data, name, StringSink(m_temp).Ref()); reinterpret_cast(pValue)->Assign((const byte *)m_temp.data(), m_temp.size(), true); } else if (valueType == typeid(const byte *)) { m_temp.resize(0); - StringSource(value, true, new HexDecoder(new StringSink(m_temp))); + PutDecodedDatumInto(m_data, name, StringSink(m_temp).Ref()); *reinterpret_cast(pValue) = (const byte *)m_temp.data(); } else @@ -84,43 +137,6 @@ private: mutable std::string m_temp; }; -const std::string & GetRequiredDatum(const TestData &data, const char *name) -{ - TestData::const_iterator i = data.find(name); - if (i == data.end()) - SignalTestError(); - return i->second; -} - -void PutDecodedDatumInto(const TestData &data, const char *name, BufferedTransformation &target) -{ - std::string s1 = GetRequiredDatum(data, name), s2; - - int repeat = 1; - if (s1[0] == 'r') - { - repeat = atoi(s1.c_str()+1); - s1 = s1.substr(s1.find(' ')+1); - } - - if (s1[0] == '\"') - s2 = s1.substr(1, s1.find('\"', 1)-1); - else if (s1.substr(0, 2) == "0x") - StringSource(s1.substr(2), true, new HexDecoder(new StringSink(s2))); - else - StringSource(s1, true, new HexDecoder(new StringSink(s2))); - - while (repeat--) - target.Put((const byte *)s2.data(), s2.size()); -} - -std::string GetDecodedDatum(const TestData &data, const char *name) -{ - std::string s; - PutDecodedDatumInto(data, name, StringSink(s).Ref()); - return s; -} - void TestKeyPairValidAndConsistent(CryptoMaterial &pub, const CryptoMaterial &priv) { if (!pub.Validate(GlobalRNG(), 3)) @@ -256,40 +272,70 @@ void TestSymmetricCipher(TestData &v) std::string test = GetRequiredDatum(v, "Test"); std::string key = GetDecodedDatum(v, "Key"); - std::string ciphertext = GetDecodedDatum(v, "Ciphertext"); std::string plaintext = GetDecodedDatum(v, "Plaintext"); TestDataNameValuePairs pairs(v); - if (test == "Encrypt") + if (test == "Encrypt" || test == "EncryptXorDigest") { std::auto_ptr encryptor(ObjectFactoryRegistry::Registry().CreateObject(name.c_str())); + std::auto_ptr decryptor(ObjectFactoryRegistry::Registry().CreateObject(name.c_str())); ConstByteArrayParameter iv; if (pairs.GetValue(Name::IV(), iv) && iv.size() != encryptor->IVSize()) SignalTestFailure(); encryptor->SetKey((const byte *)key.data(), key.size(), pairs); - int seek = pairs.GetIntValueWithDefault("Seek", 0); - if (seek) - encryptor->Seek(seek); - std::string encrypted; - StringSource ss(plaintext, true, new StreamTransformationFilter(*encryptor, new StringSink(encrypted), StreamTransformationFilter::NO_PADDING)); - if (encrypted != ciphertext) - SignalTestFailure(); - } - else if (test == "Decrypt") - { - std::auto_ptr decryptor(ObjectFactoryRegistry::Registry().CreateObject(name.c_str())); - ConstByteArrayParameter iv; - if (pairs.GetValue(Name::IV(), iv) && iv.size() != decryptor->IVSize()) - SignalTestFailure(); decryptor->SetKey((const byte *)key.data(), key.size(), pairs); int seek = pairs.GetIntValueWithDefault("Seek", 0); if (seek) + { + encryptor->Seek(seek); decryptor->Seek(seek); + } + std::string encrypted, xorDigest, ciphertext, ciphertextXorDigest; + StringSource ss(plaintext, false, new StreamTransformationFilter(*encryptor, new StringSink(encrypted), StreamTransformationFilter::NO_PADDING)); + ss.Pump(plaintext.size()/2 + 1); + ss.PumpAll(); + /*{ + std::string z; + encryptor->Seek(seek); + StringSource ss(plaintext, false, new StreamTransformationFilter(*encryptor, new StringSink(z), StreamTransformationFilter::NO_PADDING)); + while (ss.Pump(64)) {} + ss.PumpAll(); + for (int i=0; i hash; HashTransformation *pHash = NULL; + TestDataNameValuePairs pairs(v); + if (testDigest) { hash.reset(ObjectFactoryRegistry::Registry().CreateObject(name.c_str())); @@ -316,8 +364,11 @@ void TestDigestOrMAC(TestData &v, bool testDigest) { mac.reset(ObjectFactoryRegistry::Registry().CreateObject(name.c_str())); pHash = mac.get(); + ConstByteArrayParameter iv; + if (pairs.GetValue(Name::IV(), iv) && iv.size() != mac->IVSize()) + SignalTestFailure(); std::string key = GetDecodedDatum(v, "Key"); - mac->SetKey((const byte *)key.c_str(), key.size()); + mac->SetKey((const byte *)key.c_str(), key.size(), pairs); } if (test == "Verify" || test == "VerifyTruncated" || test == "NotVerify") @@ -499,7 +550,6 @@ void TestDataFile(const std::string &filename, unsigned int &totalTests, unsigne bool RunTestDataFile(const char *filename) { - RegisterFactories(); unsigned int totalTests = 0, failedTests = 0; TestDataFile(filename, totalTests, failedTests); cout << "\nTests complete. Total tests = " << totalTests << ". Failed tests = " << failedTests << ".\n"; -- cgit v1.2.1