diff options
author | Martin Thomson <mt@lowentropy.net> | 2022-02-14 19:14:36 +0000 |
---|---|---|
committer | Martin Thomson <mt@lowentropy.net> | 2022-02-14 19:14:36 +0000 |
commit | 2c4b67fb3e56f2f37425423622415cc287ccb5e3 (patch) | |
tree | aeeb84bd27c884f67e1619c51afcb0a4ffe6eb11 /gtests/pk11_gtest | |
parent | f79fdd517a1cdd13012e0bdfec6c17935a1aab95 (diff) | |
download | nss-hg-2c4b67fb3e56f2f37425423622415cc287ccb5e3.tar.gz |
Bug 1747957 - Use Wycheproof JSON for RSASSA-PSS, r=nss-reviewers,bbeurdouche
Differential Revision: https://phabricator.services.mozilla.com/D134846
Diffstat (limited to 'gtests/pk11_gtest')
-rw-r--r-- | gtests/pk11_gtest/json.h | 168 | ||||
-rw-r--r-- | gtests/pk11_gtest/pk11_hpke_unittest.cc | 157 | ||||
-rw-r--r-- | gtests/pk11_gtest/pk11_rsapss_unittest.cc | 239 |
3 files changed, 369 insertions, 195 deletions
diff --git a/gtests/pk11_gtest/json.h b/gtests/pk11_gtest/json.h new file mode 100644 index 000000000..1cd6c341c --- /dev/null +++ b/gtests/pk11_gtest/json.h @@ -0,0 +1,168 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#define __json_h__ +#include <vector> +#include "gtest/gtest.h" +#include "nss_scoped_ptrs.h" +#include "pk11pub.h" + +// If we make a few assumptions about the file, parsing JSON can be easy. +// This is not a full parser, it only works on a narrow set of inputs. +class JsonReader { + public: + JsonReader(const std::string &n) : buf_(), available_(0), i_(0) { + f_.reset(PR_Open(n.c_str(), PR_RDONLY, 00600)); + EXPECT_TRUE(f_) << "error opening vectors from: " << n; + buf_[0] = 0; + } + + void next() { i_++; } + uint8_t peek() { + TopUp(); + return buf_[i_]; + } + uint8_t take() { + uint8_t v = peek(); + next(); + return v; + } + + // No input checking, overflow protection, or any safety. + // Returns 0 if there isn't a number here rather than aborting. + uint64_t ReadInt() { + SkipWhitespace(); + uint8_t c = peek(); + uint64_t v = 0; + while (c >= '0' && c <= '9') { + v = v * 10 + c - '0'; + next(); + c = peek(); + } + return v; + } + + // No input checking, no unicode, no escaping (not even \"), just read ASCII. + std::string ReadString() { + SkipWhitespace(); + if (peek() != '"') { + return ""; + } + next(); + + std::string s; + uint8_t c = take(); + while (c != '"') { + s.push_back(c); + c = take(); + } + return s; + } + + std::string ReadLabel() { + std::string s = ReadString(); + SkipWhitespace(); + EXPECT_EQ(take(), ':'); + return s; + } + + std::vector<uint8_t> ReadHex() { + SkipWhitespace(); + uint8_t c = take(); + EXPECT_EQ(c, '"'); + std::vector<uint8_t> v; + c = take(); + while (c != '"') { + v.push_back(JsonReader::Hex(c) << 4 | JsonReader::Hex(take())); + c = take(); + } + return v; + } + + bool NextItem(uint8_t h = '{', uint8_t t = '}') { + SkipWhitespace(); + switch (uint8_t c = take()) { + case ',': + return true; + case '{': + case '[': + EXPECT_EQ(c, h); + SkipWhitespace(); + if (peek() == t) { + next(); + return false; + } + return true; + case '}': + case ']': + EXPECT_EQ(c, t); + return false; + default: + ADD_FAILURE() << "Unexpected '" << c << "'"; + } + return false; + } + + bool NextItemArray() { return NextItem('[', ']'); } + + void SkipValue() { + uint8_t c = take(); + if (c == '"') { + do { + c = take(); + } while (c != '"'); + } else if (c >= '0' && c <= '9') { + c = peek(); + while (c >= '0' && c <= '9') { + next(); + c = peek(); + } + } else { + ADD_FAILURE() << "No idea how to skip'" << c << "'"; + } + } + + private: + void TopUp() { + if (available_ > i_) { + return; + } + i_ = 0; + if (!f_) { + return; + } + PRInt32 res = PR_Read(f_.get(), buf_, sizeof(buf_)); + if (res > 0) { + available_ = static_cast<size_t>(res); + } else { + available_ = 1; + f_.reset(nullptr); + buf_[0] = 0; + } + } + + void SkipWhitespace() { + uint8_t c = peek(); + while (c && (c == ' ' || c == '\t' || c == '\r' || c == '\n')) { + next(); + c = peek(); + } + } + + // This only handles lowercase. + uint8_t Hex(uint8_t c) { + if (c >= '0' && c <= '9') { + return c - '0'; + } + EXPECT_TRUE(c >= 'a' && c <= 'f'); + return c - 'a' + 10; + } + + ScopedPRFileDesc f_; + uint8_t buf_[4096]; + size_t available_; + size_t i_; +}; diff --git a/gtests/pk11_gtest/pk11_hpke_unittest.cc b/gtests/pk11_gtest/pk11_hpke_unittest.cc index b08a6a5cd..0ed81f899 100644 --- a/gtests/pk11_gtest/pk11_hpke_unittest.cc +++ b/gtests/pk11_gtest/pk11_hpke_unittest.cc @@ -7,6 +7,7 @@ #include <memory> #include "blapi.h" #include "gtest/gtest.h" +#include "json.h" #include "nss.h" #include "nss_scoped_ptrs.h" #include "pk11hpke.h" @@ -224,156 +225,6 @@ class HpkeTest { } }; -// If we make a few assumptions about the file, parsing JSON can be easy. -// This is not a full parser, it only works on a narrow set of inputs. -class JsonReader { - public: - JsonReader(const std::string &n) : buf_(), available_(0), i_(0) { - f_.reset(PR_Open(n.c_str(), PR_RDONLY, 00600)); - EXPECT_TRUE(f_) << "error opening vectors from: " << n; - buf_[0] = 0; - } - - void next() { i_++; } - uint8_t peek() { - TopUp(); - return buf_[i_]; - } - uint8_t take() { - uint8_t v = peek(); - next(); - return v; - } - - // No input checking, overflow protection, or any safety. - // Returns 0 if there isn't a number here rather than aborting. - uint64_t ReadInt() { - SkipWhitespace(); - uint8_t c = peek(); - uint64_t v = 0; - while (c >= '0' && c <= '9') { - v = v * 10 + c - '0'; - next(); - c = peek(); - } - return v; - } - - // No input checking, no unicode, no escaping (not even \"), just read ASCII. - std::string ReadLabel() { - SkipWhitespace(); - if (peek() != '"') { - return ""; - } - next(); - - std::string s; - uint8_t c = take(); - while (c != '"') { - s.push_back(c); - c = take(); - } - SkipWhitespace(); - EXPECT_EQ(take(), ':'); - return s; - } - - std::vector<uint8_t> ReadHex() { - SkipWhitespace(); - uint8_t c = take(); - EXPECT_EQ(c, '"'); - std::vector<uint8_t> v; - c = take(); - while (c != '"') { - v.push_back(JsonReader::Hex(c) << 4 | JsonReader::Hex(take())); - c = take(); - } - return v; - } - - bool NextItem(uint8_t h = '{', uint8_t t = '}') { - SkipWhitespace(); - switch (uint8_t c = take()) { - case ',': - return true; - case '{': - case '[': - EXPECT_EQ(c, h); - SkipWhitespace(); - if (peek() == t) { - next(); - return false; - } - return true; - case '}': - case ']': - EXPECT_EQ(c, t); - return false; - default: - ADD_FAILURE() << "Unexpected '" << c << "'"; - } - return false; - } - - void SkipValue() { - uint8_t c = take(); - if (c == '"') { - do { - c = take(); - } while (c != '"'); - } else if (c >= '0' && c <= '9') { - c = peek(); - while (c >= '0' && c <= '9') { - next(); - c = peek(); - } - } else { - ADD_FAILURE() << "No idea how to skip'" << c << "'"; - } - } - - private: - void TopUp() { - if (available_ > i_) { - return; - } - i_ = 0; - if (!f_) { - return; - } - PRInt32 res = PR_Read(f_.get(), buf_, sizeof(buf_)); - if (res > 0) { - available_ = static_cast<size_t>(res); - } else { - available_ = 1; - f_.reset(nullptr); - buf_[0] = 0; - } - } - - void SkipWhitespace() { - uint8_t c = peek(); - while (c && (c == ' ' || c == '\t' || c == '\r' || c == '\n')) { - next(); - c = peek(); - } - } - - // This only handles lowercase. - uint8_t Hex(uint8_t c) { - if (c >= '0' && c <= '9') { - return c - '0'; - } - EXPECT_TRUE(c >= 'a' && c <= 'f'); - return c - 'a' + 10; - } - - ScopedPRFileDesc f_; - uint8_t buf_[4096]; - size_t available_; - size_t i_; -}; - struct HpkeEncryptVector { std::vector<uint8_t> pt; std::vector<uint8_t> aad; @@ -382,7 +233,7 @@ struct HpkeEncryptVector { static std::vector<HpkeEncryptVector> ReadVec(JsonReader &r) { std::vector<HpkeEncryptVector> all; - while (r.NextItem('[', ']')) { + while (r.NextItemArray()) { HpkeEncryptVector enc; while (r.NextItem()) { std::string n = r.ReadLabel(); @@ -414,7 +265,7 @@ struct HpkeExportVector { static std::vector<HpkeExportVector> ReadVec(JsonReader &r) { std::vector<HpkeExportVector> all; - while (r.NextItem('[', ']')) { + while (r.NextItemArray()) { HpkeExportVector exp; while (r.NextItem()) { std::string n = r.ReadLabel(); @@ -476,7 +327,7 @@ struct HpkeVector { std::vector<HpkeVector> all_tests; uint32_t test_id = 0; - while (r.NextItem('[', ']')) { + while (r.NextItemArray()) { HpkeVector vec = {0}; uint32_t fields = 0; enum class RequiredFields { diff --git a/gtests/pk11_gtest/pk11_rsapss_unittest.cc b/gtests/pk11_gtest/pk11_rsapss_unittest.cc index e8428f794..53e4b342e 100644 --- a/gtests/pk11_gtest/pk11_rsapss_unittest.cc +++ b/gtests/pk11_gtest/pk11_rsapss_unittest.cc @@ -8,6 +8,7 @@ #include "nss.h" #include "pk11pub.h" #include "sechash.h" +#include "json.h" #include "databuffer.h" @@ -16,14 +17,9 @@ #include "pk11_signature_test.h" #include "pk11_rsapss_vectors.h" +#include "testvectors_base/test-structs.h" -#include "testvectors/rsa_pss_2048_sha256_mgf1_32-vectors.h" -#include "testvectors/rsa_pss_2048_sha1_mgf1_20-vectors.h" -#include "testvectors/rsa_pss_2048_sha256_mgf1_0-vectors.h" -#include "testvectors/rsa_pss_3072_sha256_mgf1_32-vectors.h" -#include "testvectors/rsa_pss_4096_sha256_mgf1_32-vectors.h" -#include "testvectors/rsa_pss_4096_sha512_mgf1_32-vectors.h" -#include "testvectors/rsa_pss_misc-vectors.h" +extern std::string g_source_dir; namespace nss_test { @@ -46,7 +42,7 @@ CK_MECHANISM_TYPE RsaPssMapCombo(SECOidTag hashOid) { } class Pkcs11RsaPssTestBase : public Pk11SignatureTest { - protected: + public: Pkcs11RsaPssTestBase(SECOidTag hashOid, CK_RSA_PKCS_MGF_TYPE mgf, int sLen) : Pk11SignatureTest(CKM_RSA_PKCS_PSS, hashOid, RsaPssMapCombo(hashOid)) { pss_params_.hashAlg = PK11_AlgtagToMechanism(hashOid); @@ -80,13 +76,183 @@ class Pkcs11RsaPssTest : public Pkcs11RsaPssTestBase { : Pkcs11RsaPssTestBase(SEC_OID_SHA1, CKG_MGF1_SHA1, SHA1_LENGTH) {} }; -class Pkcs11RsaPssTestWycheproof - : public Pkcs11RsaPssTestBase, - public ::testing::WithParamInterface<RsaPssTestVector> { +class Pkcs11RsaPssTestWycheproof : public ::testing::Test { public: - Pkcs11RsaPssTestWycheproof() - : Pkcs11RsaPssTestBase(GetParam().hash_oid, GetParam().mgf_hash, - GetParam().sLen) {} + Pkcs11RsaPssTestWycheproof() {} + + void Run(const std::string& file) { + std::string basename = "rsa_pss_" + file + "_test.json"; + std::string dir = ::g_source_dir + "/../common/wycheproof/source_vectors/"; + std::cout << "Running tests from: " << basename << std::endl; + + JsonReader r(dir + basename); + while (r.NextItem()) { + std::string n = r.ReadLabel(); + if (n == "") { + break; + } + if (n == "algorithm") { + ASSERT_EQ("RSASSA-PSS", r.ReadString()); + } else if (n == "generatorVersion") { + (void)r.ReadString(); + } else if (n == "numberOfTests") { + (void)r.ReadInt(); + } else if (n == "header") { + while (r.NextItemArray()) { + std::cout << r.ReadString() << std::endl; + } + } else if (n == "notes") { + while (r.NextItem()) { + std::string note = r.ReadLabel(); + if (note == "") { + break; + } + std::cout << note << ": " << r.ReadString() << std::endl; + } + } else if (n == "schema") { + ASSERT_EQ("rsassa_pss_verify_schema.json", r.ReadString()); + } else if (n == "testGroups") { + while (r.NextItemArray()) { + RunGroup(r); + } + } + } + } + + private: + struct TestVector { + uint64_t id; + std::vector<uint8_t> msg; + std::vector<uint8_t> sig; + bool valid; + }; + + class Pkcs11RsaPssTestWrap : public Pkcs11RsaPssTestBase { + public: + Pkcs11RsaPssTestWrap(SECOidTag hash, CK_RSA_PKCS_MGF_TYPE mgf, int s_len) + : Pkcs11RsaPssTestBase(hash, mgf, s_len) {} + + void TestBody() {} + + void Verify(const Pkcs11SignatureTestParams& params, bool valid) { + Pk11SignatureTest::Verify(params, valid); + } + }; + + void ReadTests(JsonReader& r, std::vector<TestVector>* tests) { + while (r.NextItemArray()) { + TestVector t; + while (r.NextItem()) { + std::string n = r.ReadLabel(); + if (n == "") { + break; + } + if (n == "tcId") { + t.id = r.ReadInt(); + } else if (n == "comment") { + (void)r.ReadString(); + } else if (n == "msg") { + t.msg = r.ReadHex(); + } else if (n == "sig") { + t.sig = r.ReadHex(); + } else if (n == "result") { + std::string s = r.ReadString(); + t.valid = (s == "valid" || s == "acceptable"); + } else if (n == "flags") { + while (r.NextItemArray()) { + (void)r.ReadString(); + } + } else { + FAIL() << "unknown test entry attribute"; + } + } + tests->push_back(t); + } + } + + void RunTests(const std::vector<uint8_t>& public_key, SECOidTag hash, + CK_RSA_PKCS_MGF_TYPE mgf, int s_len, + const std::vector<TestVector>& tests) { + ASSERT_NE(0u, public_key.size()); + ASSERT_NE(SEC_OID_UNKNOWN, hash); + ASSERT_NE(CKM_INVALID_MECHANISM, mgf); + ASSERT_NE(0u, tests.size()); + + for (auto& v : tests) { + std::cout << "Running tcid: " << v.id << std::endl; + + Pkcs11RsaPssTestWrap test(hash, mgf, s_len); + Pkcs11SignatureTestParams params = { + DataBuffer(), DataBuffer(public_key.data(), public_key.size()), + DataBuffer(v.msg.data(), v.msg.size()), + DataBuffer(v.sig.data(), v.sig.size())}; + test.Verify(params, v.valid); + } + } + + void RunGroup(JsonReader& r) { + std::vector<uint8_t> public_key; + SECOidTag hash = SEC_OID_UNKNOWN; + CK_RSA_PKCS_MGF_TYPE mgf = CKM_INVALID_MECHANISM; + int s_len = 0; + std::vector<TestVector> tests; + while (r.NextItem()) { + std::string n = r.ReadLabel(); + if (n == "") { + break; + } + if (n == "e" || n == "keyAsn" || n == "keyPem" || n == "n") { + (void)r.ReadString(); + } else if (n == "keyDer") { + public_key = r.ReadHex(); + } else if (n == "keysize") { + (void)r.ReadInt(); + } else if (n == "mgf") { + std::string s = r.ReadString(); + ASSERT_EQ(s, "MGF1"); + } else if (n == "mgfSha") { + std::string s = r.ReadString(); + if (s == "SHA-1") { + mgf = CKG_MGF1_SHA1; + } else if (s == "SHA-224") { + mgf = CKG_MGF1_SHA224; + } else if (s == "SHA-256") { + mgf = CKG_MGF1_SHA256; + } else if (s == "SHA-384") { + mgf = CKG_MGF1_SHA384; + } else if (s == "SHA-512") { + mgf = CKG_MGF1_SHA512; + } else { + FAIL() << "unsupported MGF hash"; + } + } else if (n == "sLen") { + s_len = static_cast<unsigned int>(r.ReadInt()); + } else if (n == "sha") { + std::string s = r.ReadString(); + if (s == "SHA-1") { + hash = SEC_OID_SHA1; + } else if (s == "SHA-224") { + hash = SEC_OID_SHA224; + } else if (s == "SHA-256") { + hash = SEC_OID_SHA256; + } else if (s == "SHA-384") { + hash = SEC_OID_SHA384; + } else if (s == "SHA-512") { + hash = SEC_OID_SHA512; + } else { + FAIL() << "unsupported hash"; + } + } else if (n == "type") { + ASSERT_EQ("RsassaPssVerify", r.ReadString()); + } else if (n == "tests") { + ReadTests(r, &tests); + } else { + FAIL() << "unknown test group attribute: " << n; + } + } + + RunTests(public_key, hash, mgf, s_len, tests); + } }; TEST_F(Pkcs11RsaPssTest, GenerateAndSignAndVerify) { @@ -213,33 +379,22 @@ static const Pkcs11SignatureTestParams kRsaPssVectors[] = { INSTANTIATE_TEST_SUITE_P(RsaPssSignVerify, Pkcs11RsaPssVectorTest, ::testing::ValuesIn(kRsaPssVectors)); -TEST_P(Pkcs11RsaPssTestWycheproof, Verify) { Verify(GetParam()); } - -INSTANTIATE_TEST_SUITE_P( - Wycheproof2048RsaPssSha120Test, Pkcs11RsaPssTestWycheproof, - ::testing::ValuesIn(kRsaPss2048Sha120WycheproofVectors)); - -INSTANTIATE_TEST_SUITE_P( - Wycheproof2048RsaPssSha25632Test, Pkcs11RsaPssTestWycheproof, - ::testing::ValuesIn(kRsaPss2048Sha25632WycheproofVectors)); - -INSTANTIATE_TEST_SUITE_P( - Wycheproof2048RsaPssSha2560Test, Pkcs11RsaPssTestWycheproof, - ::testing::ValuesIn(kRsaPss2048Sha2560WycheproofVectors)); - -INSTANTIATE_TEST_SUITE_P( - Wycheproof3072RsaPssSha25632Test, Pkcs11RsaPssTestWycheproof, - ::testing::ValuesIn(kRsaPss3072Sha25632WycheproofVectors)); - -INSTANTIATE_TEST_SUITE_P( - Wycheproof4096RsaPssSha25632Test, Pkcs11RsaPssTestWycheproof, - ::testing::ValuesIn(kRsaPss4096Sha25632WycheproofVectors)); - -INSTANTIATE_TEST_SUITE_P( - Wycheproof4096RsaPssSha51232Test, Pkcs11RsaPssTestWycheproof, - ::testing::ValuesIn(kRsaPss4096Sha51232WycheproofVectors)); - -INSTANTIATE_TEST_SUITE_P(WycheproofRsaPssMiscTest, Pkcs11RsaPssTestWycheproof, - ::testing::ValuesIn(kRsaPssMiscWycheproofVectors)); +TEST_F(Pkcs11RsaPssTestWycheproof, RsaPss2048Sha1) { Run("2048_sha1_mgf1_20"); } +TEST_F(Pkcs11RsaPssTestWycheproof, RsaPss2048Sha256_0) { + Run("2048_sha256_mgf1_0"); +} +TEST_F(Pkcs11RsaPssTestWycheproof, RsaPss2048Sha256_32) { + Run("2048_sha256_mgf1_32"); +} +TEST_F(Pkcs11RsaPssTestWycheproof, RsaPss3072Sha256) { + Run("3072_sha256_mgf1_32"); +} +TEST_F(Pkcs11RsaPssTestWycheproof, RsaPss4096Sha256) { + Run("4096_sha256_mgf1_32"); +} +TEST_F(Pkcs11RsaPssTestWycheproof, RsaPss4096Sha512) { + Run("4096_sha512_mgf1_32"); +} +TEST_F(Pkcs11RsaPssTestWycheproof, RsaPssMisc) { Run("misc"); } } // namespace nss_test |