summaryrefslogtreecommitdiff
path: root/gtests/pk11_gtest
diff options
context:
space:
mode:
authorMartin Thomson <mt@lowentropy.net>2022-02-14 19:14:36 +0000
committerMartin Thomson <mt@lowentropy.net>2022-02-14 19:14:36 +0000
commit2c4b67fb3e56f2f37425423622415cc287ccb5e3 (patch)
treeaeeb84bd27c884f67e1619c51afcb0a4ffe6eb11 /gtests/pk11_gtest
parentf79fdd517a1cdd13012e0bdfec6c17935a1aab95 (diff)
downloadnss-hg-2c4b67fb3e56f2f37425423622415cc287ccb5e3.tar.gz
Bug 1747957 - Use Wycheproof JSON for RSASSA-PSS, r=nss-reviewers,bbeurdouche
Differential Revision: https://phabricator.services.mozilla.com/D134846
Diffstat (limited to 'gtests/pk11_gtest')
-rw-r--r--gtests/pk11_gtest/json.h168
-rw-r--r--gtests/pk11_gtest/pk11_hpke_unittest.cc157
-rw-r--r--gtests/pk11_gtest/pk11_rsapss_unittest.cc239
3 files changed, 369 insertions, 195 deletions
diff --git a/gtests/pk11_gtest/json.h b/gtests/pk11_gtest/json.h
new file mode 100644
index 000000000..1cd6c341c
--- /dev/null
+++ b/gtests/pk11_gtest/json.h
@@ -0,0 +1,168 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#define __json_h__
+#include <vector>
+#include "gtest/gtest.h"
+#include "nss_scoped_ptrs.h"
+#include "pk11pub.h"
+
+// If we make a few assumptions about the file, parsing JSON can be easy.
+// This is not a full parser, it only works on a narrow set of inputs.
+class JsonReader {
+ public:
+ JsonReader(const std::string &n) : buf_(), available_(0), i_(0) {
+ f_.reset(PR_Open(n.c_str(), PR_RDONLY, 00600));
+ EXPECT_TRUE(f_) << "error opening vectors from: " << n;
+ buf_[0] = 0;
+ }
+
+ void next() { i_++; }
+ uint8_t peek() {
+ TopUp();
+ return buf_[i_];
+ }
+ uint8_t take() {
+ uint8_t v = peek();
+ next();
+ return v;
+ }
+
+ // No input checking, overflow protection, or any safety.
+ // Returns 0 if there isn't a number here rather than aborting.
+ uint64_t ReadInt() {
+ SkipWhitespace();
+ uint8_t c = peek();
+ uint64_t v = 0;
+ while (c >= '0' && c <= '9') {
+ v = v * 10 + c - '0';
+ next();
+ c = peek();
+ }
+ return v;
+ }
+
+ // No input checking, no unicode, no escaping (not even \"), just read ASCII.
+ std::string ReadString() {
+ SkipWhitespace();
+ if (peek() != '"') {
+ return "";
+ }
+ next();
+
+ std::string s;
+ uint8_t c = take();
+ while (c != '"') {
+ s.push_back(c);
+ c = take();
+ }
+ return s;
+ }
+
+ std::string ReadLabel() {
+ std::string s = ReadString();
+ SkipWhitespace();
+ EXPECT_EQ(take(), ':');
+ return s;
+ }
+
+ std::vector<uint8_t> ReadHex() {
+ SkipWhitespace();
+ uint8_t c = take();
+ EXPECT_EQ(c, '"');
+ std::vector<uint8_t> v;
+ c = take();
+ while (c != '"') {
+ v.push_back(JsonReader::Hex(c) << 4 | JsonReader::Hex(take()));
+ c = take();
+ }
+ return v;
+ }
+
+ bool NextItem(uint8_t h = '{', uint8_t t = '}') {
+ SkipWhitespace();
+ switch (uint8_t c = take()) {
+ case ',':
+ return true;
+ case '{':
+ case '[':
+ EXPECT_EQ(c, h);
+ SkipWhitespace();
+ if (peek() == t) {
+ next();
+ return false;
+ }
+ return true;
+ case '}':
+ case ']':
+ EXPECT_EQ(c, t);
+ return false;
+ default:
+ ADD_FAILURE() << "Unexpected '" << c << "'";
+ }
+ return false;
+ }
+
+ bool NextItemArray() { return NextItem('[', ']'); }
+
+ void SkipValue() {
+ uint8_t c = take();
+ if (c == '"') {
+ do {
+ c = take();
+ } while (c != '"');
+ } else if (c >= '0' && c <= '9') {
+ c = peek();
+ while (c >= '0' && c <= '9') {
+ next();
+ c = peek();
+ }
+ } else {
+ ADD_FAILURE() << "No idea how to skip'" << c << "'";
+ }
+ }
+
+ private:
+ void TopUp() {
+ if (available_ > i_) {
+ return;
+ }
+ i_ = 0;
+ if (!f_) {
+ return;
+ }
+ PRInt32 res = PR_Read(f_.get(), buf_, sizeof(buf_));
+ if (res > 0) {
+ available_ = static_cast<size_t>(res);
+ } else {
+ available_ = 1;
+ f_.reset(nullptr);
+ buf_[0] = 0;
+ }
+ }
+
+ void SkipWhitespace() {
+ uint8_t c = peek();
+ while (c && (c == ' ' || c == '\t' || c == '\r' || c == '\n')) {
+ next();
+ c = peek();
+ }
+ }
+
+ // This only handles lowercase.
+ uint8_t Hex(uint8_t c) {
+ if (c >= '0' && c <= '9') {
+ return c - '0';
+ }
+ EXPECT_TRUE(c >= 'a' && c <= 'f');
+ return c - 'a' + 10;
+ }
+
+ ScopedPRFileDesc f_;
+ uint8_t buf_[4096];
+ size_t available_;
+ size_t i_;
+};
diff --git a/gtests/pk11_gtest/pk11_hpke_unittest.cc b/gtests/pk11_gtest/pk11_hpke_unittest.cc
index b08a6a5cd..0ed81f899 100644
--- a/gtests/pk11_gtest/pk11_hpke_unittest.cc
+++ b/gtests/pk11_gtest/pk11_hpke_unittest.cc
@@ -7,6 +7,7 @@
#include <memory>
#include "blapi.h"
#include "gtest/gtest.h"
+#include "json.h"
#include "nss.h"
#include "nss_scoped_ptrs.h"
#include "pk11hpke.h"
@@ -224,156 +225,6 @@ class HpkeTest {
}
};
-// If we make a few assumptions about the file, parsing JSON can be easy.
-// This is not a full parser, it only works on a narrow set of inputs.
-class JsonReader {
- public:
- JsonReader(const std::string &n) : buf_(), available_(0), i_(0) {
- f_.reset(PR_Open(n.c_str(), PR_RDONLY, 00600));
- EXPECT_TRUE(f_) << "error opening vectors from: " << n;
- buf_[0] = 0;
- }
-
- void next() { i_++; }
- uint8_t peek() {
- TopUp();
- return buf_[i_];
- }
- uint8_t take() {
- uint8_t v = peek();
- next();
- return v;
- }
-
- // No input checking, overflow protection, or any safety.
- // Returns 0 if there isn't a number here rather than aborting.
- uint64_t ReadInt() {
- SkipWhitespace();
- uint8_t c = peek();
- uint64_t v = 0;
- while (c >= '0' && c <= '9') {
- v = v * 10 + c - '0';
- next();
- c = peek();
- }
- return v;
- }
-
- // No input checking, no unicode, no escaping (not even \"), just read ASCII.
- std::string ReadLabel() {
- SkipWhitespace();
- if (peek() != '"') {
- return "";
- }
- next();
-
- std::string s;
- uint8_t c = take();
- while (c != '"') {
- s.push_back(c);
- c = take();
- }
- SkipWhitespace();
- EXPECT_EQ(take(), ':');
- return s;
- }
-
- std::vector<uint8_t> ReadHex() {
- SkipWhitespace();
- uint8_t c = take();
- EXPECT_EQ(c, '"');
- std::vector<uint8_t> v;
- c = take();
- while (c != '"') {
- v.push_back(JsonReader::Hex(c) << 4 | JsonReader::Hex(take()));
- c = take();
- }
- return v;
- }
-
- bool NextItem(uint8_t h = '{', uint8_t t = '}') {
- SkipWhitespace();
- switch (uint8_t c = take()) {
- case ',':
- return true;
- case '{':
- case '[':
- EXPECT_EQ(c, h);
- SkipWhitespace();
- if (peek() == t) {
- next();
- return false;
- }
- return true;
- case '}':
- case ']':
- EXPECT_EQ(c, t);
- return false;
- default:
- ADD_FAILURE() << "Unexpected '" << c << "'";
- }
- return false;
- }
-
- void SkipValue() {
- uint8_t c = take();
- if (c == '"') {
- do {
- c = take();
- } while (c != '"');
- } else if (c >= '0' && c <= '9') {
- c = peek();
- while (c >= '0' && c <= '9') {
- next();
- c = peek();
- }
- } else {
- ADD_FAILURE() << "No idea how to skip'" << c << "'";
- }
- }
-
- private:
- void TopUp() {
- if (available_ > i_) {
- return;
- }
- i_ = 0;
- if (!f_) {
- return;
- }
- PRInt32 res = PR_Read(f_.get(), buf_, sizeof(buf_));
- if (res > 0) {
- available_ = static_cast<size_t>(res);
- } else {
- available_ = 1;
- f_.reset(nullptr);
- buf_[0] = 0;
- }
- }
-
- void SkipWhitespace() {
- uint8_t c = peek();
- while (c && (c == ' ' || c == '\t' || c == '\r' || c == '\n')) {
- next();
- c = peek();
- }
- }
-
- // This only handles lowercase.
- uint8_t Hex(uint8_t c) {
- if (c >= '0' && c <= '9') {
- return c - '0';
- }
- EXPECT_TRUE(c >= 'a' && c <= 'f');
- return c - 'a' + 10;
- }
-
- ScopedPRFileDesc f_;
- uint8_t buf_[4096];
- size_t available_;
- size_t i_;
-};
-
struct HpkeEncryptVector {
std::vector<uint8_t> pt;
std::vector<uint8_t> aad;
@@ -382,7 +233,7 @@ struct HpkeEncryptVector {
static std::vector<HpkeEncryptVector> ReadVec(JsonReader &r) {
std::vector<HpkeEncryptVector> all;
- while (r.NextItem('[', ']')) {
+ while (r.NextItemArray()) {
HpkeEncryptVector enc;
while (r.NextItem()) {
std::string n = r.ReadLabel();
@@ -414,7 +265,7 @@ struct HpkeExportVector {
static std::vector<HpkeExportVector> ReadVec(JsonReader &r) {
std::vector<HpkeExportVector> all;
- while (r.NextItem('[', ']')) {
+ while (r.NextItemArray()) {
HpkeExportVector exp;
while (r.NextItem()) {
std::string n = r.ReadLabel();
@@ -476,7 +327,7 @@ struct HpkeVector {
std::vector<HpkeVector> all_tests;
uint32_t test_id = 0;
- while (r.NextItem('[', ']')) {
+ while (r.NextItemArray()) {
HpkeVector vec = {0};
uint32_t fields = 0;
enum class RequiredFields {
diff --git a/gtests/pk11_gtest/pk11_rsapss_unittest.cc b/gtests/pk11_gtest/pk11_rsapss_unittest.cc
index e8428f794..53e4b342e 100644
--- a/gtests/pk11_gtest/pk11_rsapss_unittest.cc
+++ b/gtests/pk11_gtest/pk11_rsapss_unittest.cc
@@ -8,6 +8,7 @@
#include "nss.h"
#include "pk11pub.h"
#include "sechash.h"
+#include "json.h"
#include "databuffer.h"
@@ -16,14 +17,9 @@
#include "pk11_signature_test.h"
#include "pk11_rsapss_vectors.h"
+#include "testvectors_base/test-structs.h"
-#include "testvectors/rsa_pss_2048_sha256_mgf1_32-vectors.h"
-#include "testvectors/rsa_pss_2048_sha1_mgf1_20-vectors.h"
-#include "testvectors/rsa_pss_2048_sha256_mgf1_0-vectors.h"
-#include "testvectors/rsa_pss_3072_sha256_mgf1_32-vectors.h"
-#include "testvectors/rsa_pss_4096_sha256_mgf1_32-vectors.h"
-#include "testvectors/rsa_pss_4096_sha512_mgf1_32-vectors.h"
-#include "testvectors/rsa_pss_misc-vectors.h"
+extern std::string g_source_dir;
namespace nss_test {
@@ -46,7 +42,7 @@ CK_MECHANISM_TYPE RsaPssMapCombo(SECOidTag hashOid) {
}
class Pkcs11RsaPssTestBase : public Pk11SignatureTest {
- protected:
+ public:
Pkcs11RsaPssTestBase(SECOidTag hashOid, CK_RSA_PKCS_MGF_TYPE mgf, int sLen)
: Pk11SignatureTest(CKM_RSA_PKCS_PSS, hashOid, RsaPssMapCombo(hashOid)) {
pss_params_.hashAlg = PK11_AlgtagToMechanism(hashOid);
@@ -80,13 +76,183 @@ class Pkcs11RsaPssTest : public Pkcs11RsaPssTestBase {
: Pkcs11RsaPssTestBase(SEC_OID_SHA1, CKG_MGF1_SHA1, SHA1_LENGTH) {}
};
-class Pkcs11RsaPssTestWycheproof
- : public Pkcs11RsaPssTestBase,
- public ::testing::WithParamInterface<RsaPssTestVector> {
+class Pkcs11RsaPssTestWycheproof : public ::testing::Test {
public:
- Pkcs11RsaPssTestWycheproof()
- : Pkcs11RsaPssTestBase(GetParam().hash_oid, GetParam().mgf_hash,
- GetParam().sLen) {}
+ Pkcs11RsaPssTestWycheproof() {}
+
+ void Run(const std::string& file) {
+ std::string basename = "rsa_pss_" + file + "_test.json";
+ std::string dir = ::g_source_dir + "/../common/wycheproof/source_vectors/";
+ std::cout << "Running tests from: " << basename << std::endl;
+
+ JsonReader r(dir + basename);
+ while (r.NextItem()) {
+ std::string n = r.ReadLabel();
+ if (n == "") {
+ break;
+ }
+ if (n == "algorithm") {
+ ASSERT_EQ("RSASSA-PSS", r.ReadString());
+ } else if (n == "generatorVersion") {
+ (void)r.ReadString();
+ } else if (n == "numberOfTests") {
+ (void)r.ReadInt();
+ } else if (n == "header") {
+ while (r.NextItemArray()) {
+ std::cout << r.ReadString() << std::endl;
+ }
+ } else if (n == "notes") {
+ while (r.NextItem()) {
+ std::string note = r.ReadLabel();
+ if (note == "") {
+ break;
+ }
+ std::cout << note << ": " << r.ReadString() << std::endl;
+ }
+ } else if (n == "schema") {
+ ASSERT_EQ("rsassa_pss_verify_schema.json", r.ReadString());
+ } else if (n == "testGroups") {
+ while (r.NextItemArray()) {
+ RunGroup(r);
+ }
+ }
+ }
+ }
+
+ private:
+ struct TestVector {
+ uint64_t id;
+ std::vector<uint8_t> msg;
+ std::vector<uint8_t> sig;
+ bool valid;
+ };
+
+ class Pkcs11RsaPssTestWrap : public Pkcs11RsaPssTestBase {
+ public:
+ Pkcs11RsaPssTestWrap(SECOidTag hash, CK_RSA_PKCS_MGF_TYPE mgf, int s_len)
+ : Pkcs11RsaPssTestBase(hash, mgf, s_len) {}
+
+ void TestBody() {}
+
+ void Verify(const Pkcs11SignatureTestParams& params, bool valid) {
+ Pk11SignatureTest::Verify(params, valid);
+ }
+ };
+
+ void ReadTests(JsonReader& r, std::vector<TestVector>* tests) {
+ while (r.NextItemArray()) {
+ TestVector t;
+ while (r.NextItem()) {
+ std::string n = r.ReadLabel();
+ if (n == "") {
+ break;
+ }
+ if (n == "tcId") {
+ t.id = r.ReadInt();
+ } else if (n == "comment") {
+ (void)r.ReadString();
+ } else if (n == "msg") {
+ t.msg = r.ReadHex();
+ } else if (n == "sig") {
+ t.sig = r.ReadHex();
+ } else if (n == "result") {
+ std::string s = r.ReadString();
+ t.valid = (s == "valid" || s == "acceptable");
+ } else if (n == "flags") {
+ while (r.NextItemArray()) {
+ (void)r.ReadString();
+ }
+ } else {
+ FAIL() << "unknown test entry attribute";
+ }
+ }
+ tests->push_back(t);
+ }
+ }
+
+ void RunTests(const std::vector<uint8_t>& public_key, SECOidTag hash,
+ CK_RSA_PKCS_MGF_TYPE mgf, int s_len,
+ const std::vector<TestVector>& tests) {
+ ASSERT_NE(0u, public_key.size());
+ ASSERT_NE(SEC_OID_UNKNOWN, hash);
+ ASSERT_NE(CKM_INVALID_MECHANISM, mgf);
+ ASSERT_NE(0u, tests.size());
+
+ for (auto& v : tests) {
+ std::cout << "Running tcid: " << v.id << std::endl;
+
+ Pkcs11RsaPssTestWrap test(hash, mgf, s_len);
+ Pkcs11SignatureTestParams params = {
+ DataBuffer(), DataBuffer(public_key.data(), public_key.size()),
+ DataBuffer(v.msg.data(), v.msg.size()),
+ DataBuffer(v.sig.data(), v.sig.size())};
+ test.Verify(params, v.valid);
+ }
+ }
+
+ void RunGroup(JsonReader& r) {
+ std::vector<uint8_t> public_key;
+ SECOidTag hash = SEC_OID_UNKNOWN;
+ CK_RSA_PKCS_MGF_TYPE mgf = CKM_INVALID_MECHANISM;
+ int s_len = 0;
+ std::vector<TestVector> tests;
+ while (r.NextItem()) {
+ std::string n = r.ReadLabel();
+ if (n == "") {
+ break;
+ }
+ if (n == "e" || n == "keyAsn" || n == "keyPem" || n == "n") {
+ (void)r.ReadString();
+ } else if (n == "keyDer") {
+ public_key = r.ReadHex();
+ } else if (n == "keysize") {
+ (void)r.ReadInt();
+ } else if (n == "mgf") {
+ std::string s = r.ReadString();
+ ASSERT_EQ(s, "MGF1");
+ } else if (n == "mgfSha") {
+ std::string s = r.ReadString();
+ if (s == "SHA-1") {
+ mgf = CKG_MGF1_SHA1;
+ } else if (s == "SHA-224") {
+ mgf = CKG_MGF1_SHA224;
+ } else if (s == "SHA-256") {
+ mgf = CKG_MGF1_SHA256;
+ } else if (s == "SHA-384") {
+ mgf = CKG_MGF1_SHA384;
+ } else if (s == "SHA-512") {
+ mgf = CKG_MGF1_SHA512;
+ } else {
+ FAIL() << "unsupported MGF hash";
+ }
+ } else if (n == "sLen") {
+ s_len = static_cast<unsigned int>(r.ReadInt());
+ } else if (n == "sha") {
+ std::string s = r.ReadString();
+ if (s == "SHA-1") {
+ hash = SEC_OID_SHA1;
+ } else if (s == "SHA-224") {
+ hash = SEC_OID_SHA224;
+ } else if (s == "SHA-256") {
+ hash = SEC_OID_SHA256;
+ } else if (s == "SHA-384") {
+ hash = SEC_OID_SHA384;
+ } else if (s == "SHA-512") {
+ hash = SEC_OID_SHA512;
+ } else {
+ FAIL() << "unsupported hash";
+ }
+ } else if (n == "type") {
+ ASSERT_EQ("RsassaPssVerify", r.ReadString());
+ } else if (n == "tests") {
+ ReadTests(r, &tests);
+ } else {
+ FAIL() << "unknown test group attribute: " << n;
+ }
+ }
+
+ RunTests(public_key, hash, mgf, s_len, tests);
+ }
};
TEST_F(Pkcs11RsaPssTest, GenerateAndSignAndVerify) {
@@ -213,33 +379,22 @@ static const Pkcs11SignatureTestParams kRsaPssVectors[] = {
INSTANTIATE_TEST_SUITE_P(RsaPssSignVerify, Pkcs11RsaPssVectorTest,
::testing::ValuesIn(kRsaPssVectors));
-TEST_P(Pkcs11RsaPssTestWycheproof, Verify) { Verify(GetParam()); }
-
-INSTANTIATE_TEST_SUITE_P(
- Wycheproof2048RsaPssSha120Test, Pkcs11RsaPssTestWycheproof,
- ::testing::ValuesIn(kRsaPss2048Sha120WycheproofVectors));
-
-INSTANTIATE_TEST_SUITE_P(
- Wycheproof2048RsaPssSha25632Test, Pkcs11RsaPssTestWycheproof,
- ::testing::ValuesIn(kRsaPss2048Sha25632WycheproofVectors));
-
-INSTANTIATE_TEST_SUITE_P(
- Wycheproof2048RsaPssSha2560Test, Pkcs11RsaPssTestWycheproof,
- ::testing::ValuesIn(kRsaPss2048Sha2560WycheproofVectors));
-
-INSTANTIATE_TEST_SUITE_P(
- Wycheproof3072RsaPssSha25632Test, Pkcs11RsaPssTestWycheproof,
- ::testing::ValuesIn(kRsaPss3072Sha25632WycheproofVectors));
-
-INSTANTIATE_TEST_SUITE_P(
- Wycheproof4096RsaPssSha25632Test, Pkcs11RsaPssTestWycheproof,
- ::testing::ValuesIn(kRsaPss4096Sha25632WycheproofVectors));
-
-INSTANTIATE_TEST_SUITE_P(
- Wycheproof4096RsaPssSha51232Test, Pkcs11RsaPssTestWycheproof,
- ::testing::ValuesIn(kRsaPss4096Sha51232WycheproofVectors));
-
-INSTANTIATE_TEST_SUITE_P(WycheproofRsaPssMiscTest, Pkcs11RsaPssTestWycheproof,
- ::testing::ValuesIn(kRsaPssMiscWycheproofVectors));
+TEST_F(Pkcs11RsaPssTestWycheproof, RsaPss2048Sha1) { Run("2048_sha1_mgf1_20"); }
+TEST_F(Pkcs11RsaPssTestWycheproof, RsaPss2048Sha256_0) {
+ Run("2048_sha256_mgf1_0");
+}
+TEST_F(Pkcs11RsaPssTestWycheproof, RsaPss2048Sha256_32) {
+ Run("2048_sha256_mgf1_32");
+}
+TEST_F(Pkcs11RsaPssTestWycheproof, RsaPss3072Sha256) {
+ Run("3072_sha256_mgf1_32");
+}
+TEST_F(Pkcs11RsaPssTestWycheproof, RsaPss4096Sha256) {
+ Run("4096_sha256_mgf1_32");
+}
+TEST_F(Pkcs11RsaPssTestWycheproof, RsaPss4096Sha512) {
+ Run("4096_sha512_mgf1_32");
+}
+TEST_F(Pkcs11RsaPssTestWycheproof, RsaPssMisc) { Run("misc"); }
} // namespace nss_test