summaryrefslogtreecommitdiff
path: root/gtests/ssl_gtest
diff options
context:
space:
mode:
authorDennis Jackson <djackson@mozilla.com>2022-06-16 11:22:49 +0000
committerDennis Jackson <djackson@mozilla.com>2022-06-16 11:22:49 +0000
commitd0681e2eacbe08c4b02de798cddffcae2c003a61 (patch)
tree39251cd4c321ef25bf4497fd5d7cbf8140e8bf40 /gtests/ssl_gtest
parent275120fccb522a8c54d84d417e70fc061048df34 (diff)
downloadnss-hg-d0681e2eacbe08c4b02de798cddffcae2c003a61.tar.gz
Bug 1617956 - Add support for asynchronous client auth hooks. r=mt
Differential Revision: https://phabricator.services.mozilla.com/D138149
Diffstat (limited to 'gtests/ssl_gtest')
-rw-r--r--gtests/ssl_gtest/ssl_auth_unittest.cc274
-rw-r--r--gtests/ssl_gtest/tls_agent.cc124
-rw-r--r--gtests/ssl_gtest/tls_agent.h20
3 files changed, 337 insertions, 81 deletions
diff --git a/gtests/ssl_gtest/ssl_auth_unittest.cc b/gtests/ssl_gtest/ssl_auth_unittest.cc
index 925b82721..c71c0062e 100644
--- a/gtests/ssl_gtest/ssl_auth_unittest.cc
+++ b/gtests/ssl_gtest/ssl_auth_unittest.cc
@@ -169,13 +169,6 @@ TEST_P(TlsConnectGenericPre13, ServerAuthRejectAsync) {
server_->ExpectReceiveAlert(kTlsAlertCloseNotify, kTlsAlertWarning);
}
-TEST_P(TlsConnectGeneric, ClientAuth) {
- client_->SetupClientAuth();
- server_->RequestClientAuth(true);
- Connect();
- CheckKeys();
-}
-
class TlsCertificateRequestContextRecorder : public TlsHandshakeFilter {
public:
TlsCertificateRequestContextRecorder(const std::shared_ptr<TlsAgent>& a,
@@ -204,16 +197,135 @@ class TlsCertificateRequestContextRecorder : public TlsHandshakeFilter {
bool filtered_;
};
-// All stream only tests; DTLS isn't supported yet.
+using ClientAuthParam =
+ std::tuple<SSLProtocolVariant, uint16_t, ClientAuthCallbackType>;
+
+class TlsConnectClientAuth
+ : public TlsConnectTestBase,
+ public testing::WithParamInterface<ClientAuthParam> {
+ public:
+ TlsConnectClientAuth()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+};
+
+// Wrapper classes for tests that target specific versions
+
+class TlsConnectClientAuth13 : public TlsConnectClientAuth {};
+
+class TlsConnectClientAuth12 : public TlsConnectClientAuth {};
+
+class TlsConnectClientAuthStream13 : public TlsConnectClientAuth {};
+
+class TlsConnectClientAuthPre13 : public TlsConnectClientAuth {};
+
+class TlsConnectClientAuth12Plus : public TlsConnectClientAuth {};
+
+std::string getClientAuthTestName(
+ testing::TestParamInfo<ClientAuthParam> info) {
+ auto param = info.param;
+ auto variant = std::get<0>(param);
+ auto version = std::get<1>(param);
+ auto callback_type = std::get<2>(param);
+
+ std::string output = std::string();
+ switch (variant) {
+ case ssl_variant_stream:
+ output.append("TLS");
+ break;
+ case ssl_variant_datagram:
+ output.append("DTLS");
+ break;
+ }
+ output.append(VersionString(version).replace(1, 1, ""));
+ switch (callback_type) {
+ case ClientAuthCallbackType::kAsyncImmediate:
+ output.append("AsyncImmediate");
+ break;
+ case ClientAuthCallbackType::kAsyncDelay:
+ output.append("AsyncDelay");
+ break;
+ case ClientAuthCallbackType::kSync:
+ output.append("Sync");
+ break;
+ case ClientAuthCallbackType::kNone:
+ output.append("None");
+ break;
+ }
+ return output;
+}
+
+auto kClientAuthCallbacks = testing::Values(
+ ClientAuthCallbackType::kAsyncImmediate,
+ ClientAuthCallbackType::kAsyncDelay, ClientAuthCallbackType::kSync,
+ ClientAuthCallbackType::kNone);
+
+INSTANTIATE_TEST_SUITE_P(
+ ClientAuthGenericStream, TlsConnectClientAuth,
+ testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsVAll, kClientAuthCallbacks),
+ getClientAuthTestName);
+
+INSTANTIATE_TEST_SUITE_P(
+ ClientAuthGenericDatagram, TlsConnectClientAuth,
+ testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11Plus, kClientAuthCallbacks),
+ getClientAuthTestName);
+
+INSTANTIATE_TEST_SUITE_P(ClientAuth13, TlsConnectClientAuth13,
+ testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV13,
+ kClientAuthCallbacks),
+ getClientAuthTestName);
+
+INSTANTIATE_TEST_SUITE_P(
+ ClientAuth13, TlsConnectClientAuthStream13,
+ testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV13, kClientAuthCallbacks),
+ getClientAuthTestName);
+
+INSTANTIATE_TEST_SUITE_P(ClientAuth12, TlsConnectClientAuth12,
+ testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV12,
+ kClientAuthCallbacks),
+ getClientAuthTestName);
+
+INSTANTIATE_TEST_SUITE_P(
+ ClientAuthPre13Stream, TlsConnectClientAuthPre13,
+ testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV10ToV12, kClientAuthCallbacks),
+ getClientAuthTestName);
+
+INSTANTIATE_TEST_SUITE_P(
+ ClientAuthPre13Datagram, TlsConnectClientAuthPre13,
+ testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11V12, kClientAuthCallbacks),
+ getClientAuthTestName);
+
+INSTANTIATE_TEST_SUITE_P(ClientAuth12Plus, TlsConnectClientAuth12Plus,
+ testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV12Plus,
+ kClientAuthCallbacks),
+ getClientAuthTestName);
+
+TEST_P(TlsConnectClientAuth, ClientAuth) {
+ EnsureTlsSetup();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ server_->RequestClientAuth(true);
+ Connect();
+ CheckKeys();
+ client_->CheckClientAuthCompleted();
+}
+
+// All stream only tests; PostHandshakeAuth isn't supported for DTLS.
-TEST_F(TlsConnectStreamTls13, PostHandshakeAuth) {
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuth) {
EnsureTlsSetup();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
auto capture_cert_req = MakeTlsFilter<TlsCertificateRequestContextRecorder>(
server_, kTlsHandshakeCertificateRequest);
auto capture_certificate =
MakeTlsFilter<TlsCertificateRequestContextRecorder>(
client_, kTlsHandshakeCertificate);
- client_->SetupClientAuth();
client_->SetOption(SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE);
size_t called = 0;
server_->SetAuthCertificateCallback(
@@ -232,11 +344,15 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuth) {
// handled on both client and server.
server_->SendData(50);
client_->ReadBytes(50);
+ client_->ClientAuthCallbackComplete();
client_->SendData(50);
server_->ReadBytes(50);
+
EXPECT_EQ(1U, called);
- EXPECT_TRUE(capture_cert_req->filtered());
- EXPECT_TRUE(capture_certificate->filtered());
+ ASSERT_TRUE(capture_cert_req->filtered());
+ ASSERT_TRUE(capture_certificate->filtered());
+
+ client_->CheckClientAuthCompleted();
// Check if a non-empty request context is generated and it is
// properly sent back.
EXPECT_LT(0U, capture_cert_req->buffer().len());
@@ -252,7 +368,7 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuth) {
EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert));
}
-TEST_F(TlsConnectStreamTls13, PostHandshakeAuthAfterResumption) {
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthAfterResumption) {
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
Connect();
@@ -267,7 +383,7 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthAfterResumption) {
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
ExpectResumption(RESUME_TICKET);
- client_->SetupClientAuth();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
client_->SetOption(SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE);
Connect();
SendReceive();
@@ -280,10 +396,14 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthAfterResumption) {
});
EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd()))
<< "Unexpected error: " << PORT_ErrorToName(PORT_GetError());
+
server_->SendData(50);
client_->ReadBytes(50);
+ client_->ClientAuthCallbackComplete();
client_->SendData(50);
server_->ReadBytes(50);
+
+ client_->CheckClientAuthCompleted();
EXPECT_EQ(1U, called);
ScopedCERTCertificate cert1(SSL_PeerCertificate(server_->ssl_fd()));
@@ -429,8 +549,8 @@ TEST_P(TlsConnectTls12, AutoClientSelectDsa) {
EXPECT_TRUE(dsa.hookCalled);
}
-TEST_F(TlsConnectStreamTls13, PostHandshakeAuthMultiple) {
- client_->SetupClientAuth();
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthMultiple) {
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE));
size_t called = 0;
@@ -445,36 +565,42 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthMultiple) {
// Send 1st CertificateRequest.
EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd()))
<< "Unexpected error: " << PORT_ErrorToName(PORT_GetError());
+
server_->SendData(50);
client_->ReadBytes(50);
+ client_->ClientAuthCallbackComplete();
+ client_->ReadBytes(50);
client_->SendData(50);
server_->ReadBytes(50);
EXPECT_EQ(1U, called);
+ client_->CheckClientAuthCompleted(1);
ScopedCERTCertificate cert1(SSL_PeerCertificate(server_->ssl_fd()));
ASSERT_NE(nullptr, cert1.get());
ScopedCERTCertificate cert2(SSL_LocalCertificate(client_->ssl_fd()));
ASSERT_NE(nullptr, cert2.get());
EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert));
// Send 2nd CertificateRequest.
- EXPECT_EQ(SECSuccess, SSL_GetClientAuthDataHook(
- client_->ssl_fd(), GetClientAuthDataHook, nullptr));
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd()))
<< "Unexpected error: " << PORT_ErrorToName(PORT_GetError());
+
server_->SendData(50);
client_->ReadBytes(50);
+ client_->ClientAuthCallbackComplete();
+ client_->ReadBytes(50);
client_->SendData(50);
server_->ReadBytes(50);
+ client_->CheckClientAuthCompleted(2);
EXPECT_EQ(2U, called);
ScopedCERTCertificate cert3(SSL_PeerCertificate(server_->ssl_fd()));
ASSERT_NE(nullptr, cert3.get());
ScopedCERTCertificate cert4(SSL_LocalCertificate(client_->ssl_fd()));
ASSERT_NE(nullptr, cert4.get());
EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert3->derCert, &cert4->derCert));
- EXPECT_FALSE(SECITEM_ItemsAreEqual(&cert3->derCert, &cert1->derCert));
}
-TEST_F(TlsConnectStreamTls13, PostHandshakeAuthConcurrent) {
- client_->SetupClientAuth();
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthConcurrent) {
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE));
Connect();
@@ -486,8 +612,8 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthConcurrent) {
EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
}
-TEST_F(TlsConnectStreamTls13, PostHandshakeAuthBeforeKeyUpdate) {
- client_->SetupClientAuth();
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthBeforeKeyUpdate) {
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE));
Connect();
@@ -499,8 +625,9 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthBeforeKeyUpdate) {
EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
}
-TEST_F(TlsConnectStreamTls13, PostHandshakeAuthDuringClientKeyUpdate) {
- client_->SetupClientAuth();
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthDuringClientKeyUpdate) {
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
+ ;
EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE));
Connect();
@@ -514,11 +641,14 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthDuringClientKeyUpdate) {
client_->SendData(50); // client sends KeyUpdate
server_->ReadBytes(50); // server receives KeyUpdate and defers response
CheckEpochs(4, 3);
- client_->ReadBytes(50); // client receives CertificateRequest
+ client_->ReadBytes(60); // client receives CertificateRequest
+ client_->ClientAuthCallbackComplete();
+ client_->ReadBytes(50); // Finish reading the remaining bytes
client_->SendData(
50); // client sends Certificate, CertificateVerify, Finished
server_->ReadBytes(
50); // server receives Certificate, CertificateVerify, Finished
+ client_->CheckClientAuthCompleted();
client_->CheckEpochs(3, 4);
server_->CheckEpochs(4, 4);
server_->SendData(50); // server sends KeyUpdate
@@ -526,8 +656,8 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthDuringClientKeyUpdate) {
client_->CheckEpochs(4, 4);
}
-TEST_F(TlsConnectStreamTls13, PostHandshakeAuthMissingExtension) {
- client_->SetupClientAuth();
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthMissingExtension) {
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
Connect();
// Send CertificateRequest, should fail due to missing
// post_handshake_auth extension.
@@ -535,8 +665,8 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthMissingExtension) {
EXPECT_EQ(SSL_ERROR_MISSING_POST_HANDSHAKE_AUTH_EXTENSION, PORT_GetError());
}
-TEST_F(TlsConnectStreamTls13, PostHandshakeAuthAfterClientAuth) {
- client_->SetupClientAuth();
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthAfterClientAuth) {
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
server_->RequestClientAuth(true);
EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE));
@@ -593,10 +723,10 @@ class TlsDamageCertificateRequestContextFilter : public TlsHandshakeFilter {
}
};
-TEST_F(TlsConnectStreamTls13, PostHandshakeAuthContextMismatch) {
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthContextMismatch) {
EnsureTlsSetup();
MakeTlsFilter<TlsDamageCertificateRequestContextFilter>(server_);
- client_->SetupClientAuth();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE));
Connect();
@@ -605,6 +735,8 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthContextMismatch) {
<< "Unexpected error: " << PORT_ErrorToName(PORT_GetError());
server_->SendData(50);
client_->ReadBytes(50);
+ client_->ClientAuthCallbackComplete();
+ client_->ReadBytes(50);
client_->SendData(50);
server_->ExpectSendAlert(kTlsAlertIllegalParameter);
server_->ReadBytes(50);
@@ -636,10 +768,10 @@ class TlsDamageSignatureFilter : public TlsHandshakeFilter {
}
};
-TEST_F(TlsConnectStreamTls13, PostHandshakeAuthBadSignature) {
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthBadSignature) {
EnsureTlsSetup();
MakeTlsFilter<TlsDamageSignatureFilter>(client_);
- client_->SetupClientAuth();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE));
Connect();
@@ -648,20 +780,22 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthBadSignature) {
<< "Unexpected error: " << PORT_ErrorToName(PORT_GetError());
server_->SendData(50);
client_->ReadBytes(50);
+ client_->ClientAuthCallbackComplete();
client_->SendData(50);
+ client_->CheckClientAuthCompleted();
server_->ExpectSendAlert(kTlsAlertDecodeError);
server_->ReadBytes(50);
EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CERT_VERIFY, PORT_GetError());
}
-TEST_F(TlsConnectStreamTls13, PostHandshakeAuthDecline) {
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthDecline) {
EnsureTlsSetup();
auto capture_cert_req = MakeTlsFilter<TlsCertificateRequestContextRecorder>(
server_, kTlsHandshakeCertificateRequest);
auto capture_certificate =
MakeTlsFilter<TlsCertificateRequestContextRecorder>(
client_, kTlsHandshakeCertificate);
- client_->SetupClientAuth();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE));
EXPECT_EQ(SECSuccess,
@@ -709,9 +843,10 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthDecline) {
// Check if post-handshake auth still works when session tickets are enabled:
// https://bugzilla.mozilla.org/show_bug.cgi?id=1553443
-TEST_F(TlsConnectStreamTls13, PostHandshakeAuthWithSessionTicketsEnabled) {
+TEST_P(TlsConnectClientAuthStream13,
+ PostHandshakeAuthWithSessionTicketsEnabled) {
EnsureTlsSetup();
- client_->SetupClientAuth();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE));
EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
@@ -743,7 +878,8 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthWithSessionTicketsEnabled) {
EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert));
}
-TEST_P(TlsConnectGenericPre13, ClientAuthRequiredRejected) {
+TEST_P(TlsConnectClientAuthPre13, ClientAuthRequiredRejected) {
+ client_->SetupClientAuth(std::get<2>(GetParam()), false);
server_->RequestClientAuth(true);
ConnectExpectAlert(server_, kTlsAlertBadCertificate);
client_->CheckErrorCode(SSL_ERROR_BAD_CERT_ALERT);
@@ -752,13 +888,16 @@ TEST_P(TlsConnectGenericPre13, ClientAuthRequiredRejected) {
// In TLS 1.3, the client will claim that the connection is done and then
// receive the alert afterwards. So drive the handshake manually.
-TEST_P(TlsConnectTls13, ClientAuthRequiredRejected) {
+TEST_P(TlsConnectClientAuth13, ClientAuthRequiredRejected) {
+ client_->SetupClientAuth(std::get<2>(GetParam()), false);
server_->RequestClientAuth(true);
StartConnect();
client_->Handshake(); // CH
server_->Handshake(); // SH.. (no resumption)
+
client_->Handshake(); // Next message
ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ client_->CheckClientAuthCompleted();
ExpectAlert(server_, kTlsAlertCertificateRequired);
server_->Handshake(); // Alert
server_->CheckErrorCode(SSL_ERROR_NO_CERTIFICATE);
@@ -766,33 +905,34 @@ TEST_P(TlsConnectTls13, ClientAuthRequiredRejected) {
client_->CheckErrorCode(SSL_ERROR_RX_CERTIFICATE_REQUIRED_ALERT);
}
-TEST_P(TlsConnectGeneric, ClientAuthRequestedRejected) {
+TEST_P(TlsConnectClientAuth, ClientAuthRequestedRejected) {
+ client_->SetupClientAuth(std::get<2>(GetParam()), false);
server_->RequestClientAuth(false);
Connect();
CheckKeys();
}
-TEST_P(TlsConnectGeneric, ClientAuthEcdsa) {
+TEST_P(TlsConnectClientAuth, ClientAuthEcdsa) {
Reset(TlsAgent::kServerEcdsa256);
- client_->SetupClientAuth();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
server_->RequestClientAuth(true);
Connect();
CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa);
}
-TEST_P(TlsConnectGeneric, ClientAuthWithEch) {
+TEST_P(TlsConnectClientAuth, ClientAuthWithEch) {
Reset(TlsAgent::kServerEcdsa256);
EnsureTlsSetup();
SetupEch(client_, server_);
- client_->SetupClientAuth();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
server_->RequestClientAuth(true);
Connect();
CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa);
}
-TEST_P(TlsConnectGeneric, ClientAuthBigRsa) {
+TEST_P(TlsConnectClientAuth, ClientAuthBigRsa) {
Reset(TlsAgent::kServerRsa, TlsAgent::kRsa2048);
- client_->SetupClientAuth();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
server_->RequestClientAuth(true);
Connect();
CheckKeys();
@@ -835,11 +975,11 @@ TEST_P(TlsConnectTls12, ServerAuthCheckSigAlg) {
1024);
}
-TEST_P(TlsConnectTls12, ClientAuthCheckSigAlg) {
+TEST_P(TlsConnectClientAuth12, ClientAuthCheckSigAlg) {
EnsureTlsSetup();
auto capture_cert_verify = MakeTlsFilter<TlsHandshakeRecorder>(
client_, kTlsHandshakeCertificateVerify);
- client_->SetupClientAuth();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
server_->RequestClientAuth(true);
Connect();
CheckKeys();
@@ -847,11 +987,11 @@ TEST_P(TlsConnectTls12, ClientAuthCheckSigAlg) {
CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pkcs1_sha1, 1024);
}
-TEST_P(TlsConnectTls12, ClientAuthBigRsaCheckSigAlg) {
+TEST_P(TlsConnectClientAuth12, ClientAuthBigRsaCheckSigAlg) {
Reset(TlsAgent::kServerRsa, TlsAgent::kRsa2048);
auto capture_cert_verify = MakeTlsFilter<TlsHandshakeRecorder>(
client_, kTlsHandshakeCertificateVerify);
- client_->SetupClientAuth();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
server_->RequestClientAuth(true);
Connect();
CheckKeys();
@@ -886,7 +1026,7 @@ class TlsReplaceSignatureSchemeFilter : public TlsHandshakeFilter {
// This only works under TLS 1.2, because PSS doesn't work with TLS
// 1.0 or TLS 1.1 and the TLS 1.3 1-RTT handshake is partially
// successful at the client side.
-TEST_P(TlsConnectTls12, ClientAuthInconsistentRsaeSignatureScheme) {
+TEST_P(TlsConnectClientAuth12, ClientAuthInconsistentRsaeSignatureScheme) {
static const SSLSignatureScheme kSignatureSchemePss[] = {
ssl_sig_rsa_pss_pss_sha256, ssl_sig_rsa_pss_rsae_sha256};
@@ -895,7 +1035,7 @@ TEST_P(TlsConnectTls12, ClientAuthInconsistentRsaeSignatureScheme) {
PR_ARRAY_SIZE(kSignatureSchemePss));
server_->SetSignatureSchemes(kSignatureSchemePss,
PR_ARRAY_SIZE(kSignatureSchemePss));
- client_->SetupClientAuth();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
server_->RequestClientAuth(true);
EnsureTlsSetup();
@@ -912,7 +1052,7 @@ TEST_P(TlsConnectTls12, ClientAuthInconsistentRsaeSignatureScheme) {
// This only works under TLS 1.2, because PSS doesn't work with TLS
// 1.0 or TLS 1.1 and the TLS 1.3 1-RTT handshake is partially
// successful at the client side.
-TEST_P(TlsConnectTls12, ClientAuthInconsistentPssSignatureScheme) {
+TEST_P(TlsConnectClientAuth12, ClientAuthInconsistentPssSignatureScheme) {
static const SSLSignatureScheme kSignatureSchemePss[] = {
ssl_sig_rsa_pss_rsae_sha256, ssl_sig_rsa_pss_pss_sha256};
@@ -921,7 +1061,7 @@ TEST_P(TlsConnectTls12, ClientAuthInconsistentPssSignatureScheme) {
PR_ARRAY_SIZE(kSignatureSchemePss));
server_->SetSignatureSchemes(kSignatureSchemePss,
PR_ARRAY_SIZE(kSignatureSchemePss));
- client_->SetupClientAuth();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
server_->RequestClientAuth(true);
EnsureTlsSetup();
@@ -932,7 +1072,7 @@ TEST_P(TlsConnectTls12, ClientAuthInconsistentPssSignatureScheme) {
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
}
-TEST_P(TlsConnectTls13, ClientAuthPkcs1SignatureScheme) {
+TEST_P(TlsConnectClientAuth13, ClientAuthPkcs1SignatureScheme) {
static const SSLSignatureScheme kSignatureScheme[] = {
ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pss_rsae_sha256};
@@ -941,7 +1081,7 @@ TEST_P(TlsConnectTls13, ClientAuthPkcs1SignatureScheme) {
PR_ARRAY_SIZE(kSignatureScheme));
server_->SetSignatureSchemes(kSignatureScheme,
PR_ARRAY_SIZE(kSignatureScheme));
- client_->SetupClientAuth();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
server_->RequestClientAuth(true);
auto capture_cert_verify = MakeTlsFilter<TlsHandshakeRecorder>(
@@ -954,14 +1094,14 @@ TEST_P(TlsConnectTls13, ClientAuthPkcs1SignatureScheme) {
}
// Client should refuse to connect without a usable signature scheme.
-TEST_P(TlsConnectTls13, ClientAuthPkcs1SignatureSchemeOnly) {
+TEST_P(TlsConnectClientAuth13, ClientAuthPkcs1SignatureSchemeOnly) {
static const SSLSignatureScheme kSignatureScheme[] = {
ssl_sig_rsa_pkcs1_sha256};
Reset(TlsAgent::kServerRsa, "rsa");
client_->SetSignatureSchemes(kSignatureScheme,
PR_ARRAY_SIZE(kSignatureScheme));
- client_->SetupClientAuth();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
client_->StartConnect();
client_->Handshake();
EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state());
@@ -970,14 +1110,14 @@ TEST_P(TlsConnectTls13, ClientAuthPkcs1SignatureSchemeOnly) {
// Though the client has a usable signature scheme, when a certificate is
// requested, it can't produce one.
-TEST_P(TlsConnectTls13, ClientAuthPkcs1AndEcdsaScheme) {
+TEST_P(TlsConnectClientAuth13, ClientAuthPkcs1AndEcdsaScheme) {
static const SSLSignatureScheme kSignatureScheme[] = {
ssl_sig_rsa_pkcs1_sha256, ssl_sig_ecdsa_secp256r1_sha256};
Reset(TlsAgent::kServerRsa, "rsa");
client_->SetSignatureSchemes(kSignatureScheme,
PR_ARRAY_SIZE(kSignatureScheme));
- client_->SetupClientAuth();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
server_->RequestClientAuth(true);
ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
@@ -1031,12 +1171,12 @@ class TlsZeroCertificateRequestSigAlgsFilter : public TlsHandshakeFilter {
// Check that we send an alert when the server doesn't provide any
// supported_signature_algorithms in the CertificateRequest message.
-TEST_P(TlsConnectTls12, ClientAuthNoSigAlgs) {
+TEST_P(TlsConnectClientAuth12, ClientAuthNoSigAlgs) {
EnsureTlsSetup();
MakeTlsFilter<TlsZeroCertificateRequestSigAlgsFilter>(server_);
auto capture_cert_verify = MakeTlsFilter<TlsHandshakeRecorder>(
client_, kTlsHandshakeCertificateVerify);
- client_->SetupClientAuth();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
server_->RequestClientAuth(true);
ConnectExpectAlert(client_, kTlsAlertHandshakeFailure);
@@ -1061,9 +1201,9 @@ static SECStatus GetEcClientAuthDataHook(void* self, PRFileDesc* fd,
return SECSuccess;
}
-TEST_P(TlsConnectTls12Plus, ClientAuthDisjointSchemes) {
+TEST_P(TlsConnectClientAuth12Plus, ClientAuthDisjointSchemes) {
EnsureTlsSetup();
- client_->SetupClientAuth();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
server_->RequestClientAuth(true);
SSLSignatureScheme server_scheme = ssl_sig_rsa_pss_rsae_sha256;
@@ -1103,7 +1243,7 @@ TEST_P(TlsConnectTls12Plus, ClientAuthDisjointSchemes) {
}
}
-TEST_F(TlsConnectStreamTls13, PostHandshakeAuthDisjointSchemes) {
+TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthDisjointSchemes) {
EnsureTlsSetup();
SSLSignatureScheme server_scheme = ssl_sig_rsa_pss_rsae_sha256;
std::vector<SSLSignatureScheme> client_schemes{
@@ -1116,7 +1256,7 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthDisjointSchemes) {
static_cast<unsigned int>(client_schemes.size()));
EXPECT_EQ(SECSuccess, rv);
- client_->SetupClientAuth();
+ client_->SetupClientAuth(std::get<2>(GetParam()), true);
client_->SetOption(SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE);
// Select an EC cert that's incompatible with server schemes.
diff --git a/gtests/ssl_gtest/tls_agent.cc b/gtests/ssl_gtest/tls_agent.cc
index 650be8449..8ec2f40f7 100644
--- a/gtests/ssl_gtest/tls_agent.cc
+++ b/gtests/ssl_gtest/tls_agent.cc
@@ -324,10 +324,19 @@ void TlsAgent::SetAntiReplayContext(ScopedSSLAntiReplayContext& ctx) {
EXPECT_EQ(SECSuccess, SSL_SetAntiReplayContext(ssl_fd(), ctx.get()));
}
-void TlsAgent::SetupClientAuth() {
+// Defaults to a Sync callback returning success
+void TlsAgent::SetupClientAuth(ClientAuthCallbackType callbackType,
+ bool callbackSuccess) {
EXPECT_TRUE(EnsureTlsSetup());
ASSERT_EQ(CLIENT, role_);
+ client_auth_callback_type_ = callbackType;
+ client_auth_callback_success_ = callbackSuccess;
+
+ if (callbackType == ClientAuthCallbackType::kNone && !callbackSuccess) {
+ // Don't set a callback for this case.
+ return;
+ }
EXPECT_EQ(SECSuccess,
SSL_GetClientAuthDataHook(ssl_fd(), GetClientAuthDataHook,
reinterpret_cast<void*>(this)));
@@ -344,26 +353,95 @@ void CheckCertReqAgainstDefaultCAs(const CERTDistNames* caNames) {
}
}
+// Complete processing of Client Certificate Selection
+// A No-op if the agent is using synchronous client cert selection.
+// Otherwise, calls SSL_ClientCertCallbackComplete.
+// kAsyncDelay triggers a call to SSL_ForceHandshake prior to completion to
+// ensure that the socket is correctly blocked.
+void TlsAgent::ClientAuthCallbackComplete() {
+ ASSERT_EQ(CLIENT, role_);
+
+ if (client_auth_callback_type_ != ClientAuthCallbackType::kAsyncDelay &&
+ client_auth_callback_type_ != ClientAuthCallbackType::kAsyncImmediate) {
+ return;
+ }
+ client_auth_callback_fired_++;
+ EXPECT_TRUE(client_auth_callback_awaiting_);
+
+ std::cerr << "client: calling SSL_ClientCertCallbackComplete with status "
+ << (client_auth_callback_success_ ? "success" : "failed")
+ << std::endl;
+
+ client_auth_callback_awaiting_ = false;
+
+ if (client_auth_callback_type_ == ClientAuthCallbackType::kAsyncDelay) {
+ std::cerr
+ << "Running Handshake prior to running SSL_ClientCertCallbackComplete"
+ << std::endl;
+ SECStatus rv = SSL_ForceHandshake(ssl_fd());
+ EXPECT_EQ(rv, SECFailure);
+ EXPECT_EQ(PORT_GetError(), PR_WOULD_BLOCK_ERROR);
+ }
+
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey priv;
+ if (client_auth_callback_success_) {
+ ASSERT_TRUE(TlsAgent::LoadCertificate(name(), &cert, &priv));
+ EXPECT_EQ(SECSuccess,
+ SSL_ClientCertCallbackComplete(ssl_fd(), SECSuccess,
+ priv.release(), cert.release()));
+ } else {
+ EXPECT_EQ(SECSuccess, SSL_ClientCertCallbackComplete(ssl_fd(), SECFailure,
+ nullptr, nullptr));
+ }
+}
+
SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd,
CERTDistNames* caNames,
CERTCertificate** clientCert,
SECKEYPrivateKey** clientKey) {
TlsAgent* agent = reinterpret_cast<TlsAgent*>(self);
- ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd()));
- EXPECT_TRUE(peerCert) << "Client should be able to see the server cert";
+ EXPECT_EQ(CLIENT, agent->role_);
+ agent->client_auth_callback_fired_++;
+
+ switch (agent->client_auth_callback_type_) {
+ case ClientAuthCallbackType::kAsyncDelay:
+ case ClientAuthCallbackType::kAsyncImmediate:
+ std::cerr << "Waiting for complete call" << std::endl;
+ agent->client_auth_callback_awaiting_ = true;
+ return SECWouldBlock;
+ case ClientAuthCallbackType::kSync:
+ case ClientAuthCallbackType::kNone:
+ // Handle the sync case. None && Success is treated as Sync and Success.
+ if (!agent->client_auth_callback_success_) {
+ return SECFailure;
+ }
+ ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd()));
+ EXPECT_TRUE(peerCert) << "Client should be able to see the server cert";
- // See bug 1573945
- // CheckCertReqAgainstDefaultCAs(caNames);
+ // See bug 1573945
+ // CheckCertReqAgainstDefaultCAs(caNames);
- ScopedCERTCertificate cert;
- ScopedSECKEYPrivateKey priv;
- if (!TlsAgent::LoadCertificate(agent->name(), &cert, &priv)) {
- return SECFailure;
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey priv;
+ if (!TlsAgent::LoadCertificate(agent->name(), &cert, &priv)) {
+ return SECFailure;
+ }
+
+ *clientCert = cert.release();
+ *clientKey = priv.release();
+ return SECSuccess;
}
+ /* This is unreachable, but some old compilers can't tell that. */
+ PORT_Assert(0);
+ PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
+ return SECFailure;
+}
- *clientCert = cert.release();
- *clientKey = priv.release();
- return SECSuccess;
+// Increments by 1 for each callback
+bool TlsAgent::CheckClientAuthCallbacksCompleted(uint8_t expected) {
+ EXPECT_EQ(CLIENT, role_);
+ return expected == client_auth_callback_fired_;
}
bool TlsAgent::GetPeerChainLength(size_t* count) {
@@ -954,6 +1032,24 @@ void TlsAgent::Connected() {
SetState(STATE_CONNECTED);
}
+void TlsAgent::CheckClientAuthCompleted(uint8_t handshakes) {
+ EXPECT_FALSE(client_auth_callback_awaiting_);
+ switch (client_auth_callback_type_) {
+ case ClientAuthCallbackType::kNone:
+ if (!client_auth_callback_success_) {
+ EXPECT_TRUE(CheckClientAuthCallbacksCompleted(0));
+ break;
+ }
+ case ClientAuthCallbackType::kSync:
+ EXPECT_TRUE(CheckClientAuthCallbacksCompleted(handshakes));
+ break;
+ case ClientAuthCallbackType::kAsyncDelay:
+ case ClientAuthCallbackType::kAsyncImmediate:
+ EXPECT_TRUE(CheckClientAuthCallbacksCompleted(2 * handshakes));
+ break;
+ }
+}
+
void TlsAgent::EnableExtendedMasterSecret() {
SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE);
}
@@ -988,6 +1084,10 @@ void TlsAgent::SetDowngradeCheckVersion(uint16_t ver) {
void TlsAgent::Handshake() {
LOGV("Handshake");
SECStatus rv = SSL_ForceHandshake(ssl_fd());
+ if (client_auth_callback_awaiting_) {
+ ClientAuthCallbackComplete();
+ rv = SSL_ForceHandshake(ssl_fd());
+ }
if (rv == SECSuccess) {
Connected();
Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
diff --git a/gtests/ssl_gtest/tls_agent.h b/gtests/ssl_gtest/tls_agent.h
index 8b54155a5..35375e0c1 100644
--- a/gtests/ssl_gtest/tls_agent.h
+++ b/gtests/ssl_gtest/tls_agent.h
@@ -39,6 +39,13 @@ enum SessionResumptionMode {
RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET
};
+enum class ClientAuthCallbackType {
+ kAsyncImmediate,
+ kAsyncDelay,
+ kSync,
+ kNone,
+};
+
class PacketFilter;
class TlsAgent;
class TlsCipherSpec;
@@ -144,9 +151,13 @@ class TlsAgent : public PollTarget {
bool ConfigServerCertWithChain(const std::string& name);
bool EnsureTlsSetup(PRFileDesc* modelSocket = nullptr);
- void SetupClientAuth();
+ void SetupClientAuth(
+ ClientAuthCallbackType callbackType = ClientAuthCallbackType::kSync,
+ bool callbackSuccess = true);
void RequestClientAuth(bool requireAuth);
-
+ void ClientAuthCallbackComplete();
+ bool CheckClientAuthCallbacksCompleted(uint8_t expected);
+ void CheckClientAuthCompleted(uint8_t handshakes = 1);
void SetOption(int32_t option, int value);
void ConfigureSessionCache(SessionResumptionMode mode);
void Set0RttEnabled(bool en);
@@ -465,6 +476,11 @@ class TlsAgent : public PollTarget {
std::vector<uint8_t> resumption_token_;
NssPolicy policy_;
NssOption option_;
+ ClientAuthCallbackType client_auth_callback_type_ =
+ ClientAuthCallbackType::kNone;
+ bool client_auth_callback_success_ = false;
+ uint8_t client_auth_callback_fired_ = 0;
+ bool client_auth_callback_awaiting_ = false;
};
inline std::ostream& operator<<(std::ostream& stream,