diff options
author | Dennis Jackson <djackson@mozilla.com> | 2022-06-16 11:22:49 +0000 |
---|---|---|
committer | Dennis Jackson <djackson@mozilla.com> | 2022-06-16 11:22:49 +0000 |
commit | d0681e2eacbe08c4b02de798cddffcae2c003a61 (patch) | |
tree | 39251cd4c321ef25bf4497fd5d7cbf8140e8bf40 /gtests/ssl_gtest | |
parent | 275120fccb522a8c54d84d417e70fc061048df34 (diff) | |
download | nss-hg-d0681e2eacbe08c4b02de798cddffcae2c003a61.tar.gz |
Bug 1617956 - Add support for asynchronous client auth hooks. r=mt
Differential Revision: https://phabricator.services.mozilla.com/D138149
Diffstat (limited to 'gtests/ssl_gtest')
-rw-r--r-- | gtests/ssl_gtest/ssl_auth_unittest.cc | 274 | ||||
-rw-r--r-- | gtests/ssl_gtest/tls_agent.cc | 124 | ||||
-rw-r--r-- | gtests/ssl_gtest/tls_agent.h | 20 |
3 files changed, 337 insertions, 81 deletions
diff --git a/gtests/ssl_gtest/ssl_auth_unittest.cc b/gtests/ssl_gtest/ssl_auth_unittest.cc index 925b82721..c71c0062e 100644 --- a/gtests/ssl_gtest/ssl_auth_unittest.cc +++ b/gtests/ssl_gtest/ssl_auth_unittest.cc @@ -169,13 +169,6 @@ TEST_P(TlsConnectGenericPre13, ServerAuthRejectAsync) { server_->ExpectReceiveAlert(kTlsAlertCloseNotify, kTlsAlertWarning); } -TEST_P(TlsConnectGeneric, ClientAuth) { - client_->SetupClientAuth(); - server_->RequestClientAuth(true); - Connect(); - CheckKeys(); -} - class TlsCertificateRequestContextRecorder : public TlsHandshakeFilter { public: TlsCertificateRequestContextRecorder(const std::shared_ptr<TlsAgent>& a, @@ -204,16 +197,135 @@ class TlsCertificateRequestContextRecorder : public TlsHandshakeFilter { bool filtered_; }; -// All stream only tests; DTLS isn't supported yet. +using ClientAuthParam = + std::tuple<SSLProtocolVariant, uint16_t, ClientAuthCallbackType>; + +class TlsConnectClientAuth + : public TlsConnectTestBase, + public testing::WithParamInterface<ClientAuthParam> { + public: + TlsConnectClientAuth() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} +}; + +// Wrapper classes for tests that target specific versions + +class TlsConnectClientAuth13 : public TlsConnectClientAuth {}; + +class TlsConnectClientAuth12 : public TlsConnectClientAuth {}; + +class TlsConnectClientAuthStream13 : public TlsConnectClientAuth {}; + +class TlsConnectClientAuthPre13 : public TlsConnectClientAuth {}; + +class TlsConnectClientAuth12Plus : public TlsConnectClientAuth {}; + +std::string getClientAuthTestName( + testing::TestParamInfo<ClientAuthParam> info) { + auto param = info.param; + auto variant = std::get<0>(param); + auto version = std::get<1>(param); + auto callback_type = std::get<2>(param); + + std::string output = std::string(); + switch (variant) { + case ssl_variant_stream: + output.append("TLS"); + break; + case ssl_variant_datagram: + output.append("DTLS"); + break; + } + output.append(VersionString(version).replace(1, 1, "")); + switch (callback_type) { + case ClientAuthCallbackType::kAsyncImmediate: + output.append("AsyncImmediate"); + break; + case ClientAuthCallbackType::kAsyncDelay: + output.append("AsyncDelay"); + break; + case ClientAuthCallbackType::kSync: + output.append("Sync"); + break; + case ClientAuthCallbackType::kNone: + output.append("None"); + break; + } + return output; +} + +auto kClientAuthCallbacks = testing::Values( + ClientAuthCallbackType::kAsyncImmediate, + ClientAuthCallbackType::kAsyncDelay, ClientAuthCallbackType::kSync, + ClientAuthCallbackType::kNone); + +INSTANTIATE_TEST_SUITE_P( + ClientAuthGenericStream, TlsConnectClientAuth, + testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll, kClientAuthCallbacks), + getClientAuthTestName); + +INSTANTIATE_TEST_SUITE_P( + ClientAuthGenericDatagram, TlsConnectClientAuth, + testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11Plus, kClientAuthCallbacks), + getClientAuthTestName); + +INSTANTIATE_TEST_SUITE_P(ClientAuth13, TlsConnectClientAuth13, + testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV13, + kClientAuthCallbacks), + getClientAuthTestName); + +INSTANTIATE_TEST_SUITE_P( + ClientAuth13, TlsConnectClientAuthStream13, + testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV13, kClientAuthCallbacks), + getClientAuthTestName); + +INSTANTIATE_TEST_SUITE_P(ClientAuth12, TlsConnectClientAuth12, + testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV12, + kClientAuthCallbacks), + getClientAuthTestName); + +INSTANTIATE_TEST_SUITE_P( + ClientAuthPre13Stream, TlsConnectClientAuthPre13, + testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10ToV12, kClientAuthCallbacks), + getClientAuthTestName); + +INSTANTIATE_TEST_SUITE_P( + ClientAuthPre13Datagram, TlsConnectClientAuthPre13, + testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11V12, kClientAuthCallbacks), + getClientAuthTestName); + +INSTANTIATE_TEST_SUITE_P(ClientAuth12Plus, TlsConnectClientAuth12Plus, + testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV12Plus, + kClientAuthCallbacks), + getClientAuthTestName); + +TEST_P(TlsConnectClientAuth, ClientAuth) { + EnsureTlsSetup(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); + server_->RequestClientAuth(true); + Connect(); + CheckKeys(); + client_->CheckClientAuthCompleted(); +} + +// All stream only tests; PostHandshakeAuth isn't supported for DTLS. -TEST_F(TlsConnectStreamTls13, PostHandshakeAuth) { +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuth) { EnsureTlsSetup(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); auto capture_cert_req = MakeTlsFilter<TlsCertificateRequestContextRecorder>( server_, kTlsHandshakeCertificateRequest); auto capture_certificate = MakeTlsFilter<TlsCertificateRequestContextRecorder>( client_, kTlsHandshakeCertificate); - client_->SetupClientAuth(); client_->SetOption(SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE); size_t called = 0; server_->SetAuthCertificateCallback( @@ -232,11 +344,15 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuth) { // handled on both client and server. server_->SendData(50); client_->ReadBytes(50); + client_->ClientAuthCallbackComplete(); client_->SendData(50); server_->ReadBytes(50); + EXPECT_EQ(1U, called); - EXPECT_TRUE(capture_cert_req->filtered()); - EXPECT_TRUE(capture_certificate->filtered()); + ASSERT_TRUE(capture_cert_req->filtered()); + ASSERT_TRUE(capture_certificate->filtered()); + + client_->CheckClientAuthCompleted(); // Check if a non-empty request context is generated and it is // properly sent back. EXPECT_LT(0U, capture_cert_req->buffer().len()); @@ -252,7 +368,7 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuth) { EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert)); } -TEST_F(TlsConnectStreamTls13, PostHandshakeAuthAfterResumption) { +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthAfterResumption) { ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); Connect(); @@ -267,7 +383,7 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthAfterResumption) { ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); ExpectResumption(RESUME_TICKET); - client_->SetupClientAuth(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); client_->SetOption(SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE); Connect(); SendReceive(); @@ -280,10 +396,14 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthAfterResumption) { }); EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + server_->SendData(50); client_->ReadBytes(50); + client_->ClientAuthCallbackComplete(); client_->SendData(50); server_->ReadBytes(50); + + client_->CheckClientAuthCompleted(); EXPECT_EQ(1U, called); ScopedCERTCertificate cert1(SSL_PeerCertificate(server_->ssl_fd())); @@ -429,8 +549,8 @@ TEST_P(TlsConnectTls12, AutoClientSelectDsa) { EXPECT_TRUE(dsa.hookCalled); } -TEST_F(TlsConnectStreamTls13, PostHandshakeAuthMultiple) { - client_->SetupClientAuth(); +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthMultiple) { + client_->SetupClientAuth(std::get<2>(GetParam()), true); EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); size_t called = 0; @@ -445,36 +565,42 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthMultiple) { // Send 1st CertificateRequest. EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + server_->SendData(50); client_->ReadBytes(50); + client_->ClientAuthCallbackComplete(); + client_->ReadBytes(50); client_->SendData(50); server_->ReadBytes(50); EXPECT_EQ(1U, called); + client_->CheckClientAuthCompleted(1); ScopedCERTCertificate cert1(SSL_PeerCertificate(server_->ssl_fd())); ASSERT_NE(nullptr, cert1.get()); ScopedCERTCertificate cert2(SSL_LocalCertificate(client_->ssl_fd())); ASSERT_NE(nullptr, cert2.get()); EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert)); // Send 2nd CertificateRequest. - EXPECT_EQ(SECSuccess, SSL_GetClientAuthDataHook( - client_->ssl_fd(), GetClientAuthDataHook, nullptr)); + client_->SetupClientAuth(std::get<2>(GetParam()), true); EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + server_->SendData(50); client_->ReadBytes(50); + client_->ClientAuthCallbackComplete(); + client_->ReadBytes(50); client_->SendData(50); server_->ReadBytes(50); + client_->CheckClientAuthCompleted(2); EXPECT_EQ(2U, called); ScopedCERTCertificate cert3(SSL_PeerCertificate(server_->ssl_fd())); ASSERT_NE(nullptr, cert3.get()); ScopedCERTCertificate cert4(SSL_LocalCertificate(client_->ssl_fd())); ASSERT_NE(nullptr, cert4.get()); EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert3->derCert, &cert4->derCert)); - EXPECT_FALSE(SECITEM_ItemsAreEqual(&cert3->derCert, &cert1->derCert)); } -TEST_F(TlsConnectStreamTls13, PostHandshakeAuthConcurrent) { - client_->SetupClientAuth(); +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthConcurrent) { + client_->SetupClientAuth(std::get<2>(GetParam()), true); EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); Connect(); @@ -486,8 +612,8 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthConcurrent) { EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); } -TEST_F(TlsConnectStreamTls13, PostHandshakeAuthBeforeKeyUpdate) { - client_->SetupClientAuth(); +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthBeforeKeyUpdate) { + client_->SetupClientAuth(std::get<2>(GetParam()), true); EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); Connect(); @@ -499,8 +625,9 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthBeforeKeyUpdate) { EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); } -TEST_F(TlsConnectStreamTls13, PostHandshakeAuthDuringClientKeyUpdate) { - client_->SetupClientAuth(); +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthDuringClientKeyUpdate) { + client_->SetupClientAuth(std::get<2>(GetParam()), true); + ; EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); Connect(); @@ -514,11 +641,14 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthDuringClientKeyUpdate) { client_->SendData(50); // client sends KeyUpdate server_->ReadBytes(50); // server receives KeyUpdate and defers response CheckEpochs(4, 3); - client_->ReadBytes(50); // client receives CertificateRequest + client_->ReadBytes(60); // client receives CertificateRequest + client_->ClientAuthCallbackComplete(); + client_->ReadBytes(50); // Finish reading the remaining bytes client_->SendData( 50); // client sends Certificate, CertificateVerify, Finished server_->ReadBytes( 50); // server receives Certificate, CertificateVerify, Finished + client_->CheckClientAuthCompleted(); client_->CheckEpochs(3, 4); server_->CheckEpochs(4, 4); server_->SendData(50); // server sends KeyUpdate @@ -526,8 +656,8 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthDuringClientKeyUpdate) { client_->CheckEpochs(4, 4); } -TEST_F(TlsConnectStreamTls13, PostHandshakeAuthMissingExtension) { - client_->SetupClientAuth(); +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthMissingExtension) { + client_->SetupClientAuth(std::get<2>(GetParam()), true); Connect(); // Send CertificateRequest, should fail due to missing // post_handshake_auth extension. @@ -535,8 +665,8 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthMissingExtension) { EXPECT_EQ(SSL_ERROR_MISSING_POST_HANDSHAKE_AUTH_EXTENSION, PORT_GetError()); } -TEST_F(TlsConnectStreamTls13, PostHandshakeAuthAfterClientAuth) { - client_->SetupClientAuth(); +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthAfterClientAuth) { + client_->SetupClientAuth(std::get<2>(GetParam()), true); server_->RequestClientAuth(true); EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); @@ -593,10 +723,10 @@ class TlsDamageCertificateRequestContextFilter : public TlsHandshakeFilter { } }; -TEST_F(TlsConnectStreamTls13, PostHandshakeAuthContextMismatch) { +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthContextMismatch) { EnsureTlsSetup(); MakeTlsFilter<TlsDamageCertificateRequestContextFilter>(server_); - client_->SetupClientAuth(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); Connect(); @@ -605,6 +735,8 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthContextMismatch) { << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); server_->SendData(50); client_->ReadBytes(50); + client_->ClientAuthCallbackComplete(); + client_->ReadBytes(50); client_->SendData(50); server_->ExpectSendAlert(kTlsAlertIllegalParameter); server_->ReadBytes(50); @@ -636,10 +768,10 @@ class TlsDamageSignatureFilter : public TlsHandshakeFilter { } }; -TEST_F(TlsConnectStreamTls13, PostHandshakeAuthBadSignature) { +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthBadSignature) { EnsureTlsSetup(); MakeTlsFilter<TlsDamageSignatureFilter>(client_); - client_->SetupClientAuth(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); Connect(); @@ -648,20 +780,22 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthBadSignature) { << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); server_->SendData(50); client_->ReadBytes(50); + client_->ClientAuthCallbackComplete(); client_->SendData(50); + client_->CheckClientAuthCompleted(); server_->ExpectSendAlert(kTlsAlertDecodeError); server_->ReadBytes(50); EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CERT_VERIFY, PORT_GetError()); } -TEST_F(TlsConnectStreamTls13, PostHandshakeAuthDecline) { +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthDecline) { EnsureTlsSetup(); auto capture_cert_req = MakeTlsFilter<TlsCertificateRequestContextRecorder>( server_, kTlsHandshakeCertificateRequest); auto capture_certificate = MakeTlsFilter<TlsCertificateRequestContextRecorder>( client_, kTlsHandshakeCertificate); - client_->SetupClientAuth(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); EXPECT_EQ(SECSuccess, @@ -709,9 +843,10 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthDecline) { // Check if post-handshake auth still works when session tickets are enabled: // https://bugzilla.mozilla.org/show_bug.cgi?id=1553443 -TEST_F(TlsConnectStreamTls13, PostHandshakeAuthWithSessionTicketsEnabled) { +TEST_P(TlsConnectClientAuthStream13, + PostHandshakeAuthWithSessionTicketsEnabled) { EnsureTlsSetup(); - client_->SetupClientAuth(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), @@ -743,7 +878,8 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthWithSessionTicketsEnabled) { EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert)); } -TEST_P(TlsConnectGenericPre13, ClientAuthRequiredRejected) { +TEST_P(TlsConnectClientAuthPre13, ClientAuthRequiredRejected) { + client_->SetupClientAuth(std::get<2>(GetParam()), false); server_->RequestClientAuth(true); ConnectExpectAlert(server_, kTlsAlertBadCertificate); client_->CheckErrorCode(SSL_ERROR_BAD_CERT_ALERT); @@ -752,13 +888,16 @@ TEST_P(TlsConnectGenericPre13, ClientAuthRequiredRejected) { // In TLS 1.3, the client will claim that the connection is done and then // receive the alert afterwards. So drive the handshake manually. -TEST_P(TlsConnectTls13, ClientAuthRequiredRejected) { +TEST_P(TlsConnectClientAuth13, ClientAuthRequiredRejected) { + client_->SetupClientAuth(std::get<2>(GetParam()), false); server_->RequestClientAuth(true); StartConnect(); client_->Handshake(); // CH server_->Handshake(); // SH.. (no resumption) + client_->Handshake(); // Next message ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + client_->CheckClientAuthCompleted(); ExpectAlert(server_, kTlsAlertCertificateRequired); server_->Handshake(); // Alert server_->CheckErrorCode(SSL_ERROR_NO_CERTIFICATE); @@ -766,33 +905,34 @@ TEST_P(TlsConnectTls13, ClientAuthRequiredRejected) { client_->CheckErrorCode(SSL_ERROR_RX_CERTIFICATE_REQUIRED_ALERT); } -TEST_P(TlsConnectGeneric, ClientAuthRequestedRejected) { +TEST_P(TlsConnectClientAuth, ClientAuthRequestedRejected) { + client_->SetupClientAuth(std::get<2>(GetParam()), false); server_->RequestClientAuth(false); Connect(); CheckKeys(); } -TEST_P(TlsConnectGeneric, ClientAuthEcdsa) { +TEST_P(TlsConnectClientAuth, ClientAuthEcdsa) { Reset(TlsAgent::kServerEcdsa256); - client_->SetupClientAuth(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); server_->RequestClientAuth(true); Connect(); CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa); } -TEST_P(TlsConnectGeneric, ClientAuthWithEch) { +TEST_P(TlsConnectClientAuth, ClientAuthWithEch) { Reset(TlsAgent::kServerEcdsa256); EnsureTlsSetup(); SetupEch(client_, server_); - client_->SetupClientAuth(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); server_->RequestClientAuth(true); Connect(); CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa); } -TEST_P(TlsConnectGeneric, ClientAuthBigRsa) { +TEST_P(TlsConnectClientAuth, ClientAuthBigRsa) { Reset(TlsAgent::kServerRsa, TlsAgent::kRsa2048); - client_->SetupClientAuth(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); server_->RequestClientAuth(true); Connect(); CheckKeys(); @@ -835,11 +975,11 @@ TEST_P(TlsConnectTls12, ServerAuthCheckSigAlg) { 1024); } -TEST_P(TlsConnectTls12, ClientAuthCheckSigAlg) { +TEST_P(TlsConnectClientAuth12, ClientAuthCheckSigAlg) { EnsureTlsSetup(); auto capture_cert_verify = MakeTlsFilter<TlsHandshakeRecorder>( client_, kTlsHandshakeCertificateVerify); - client_->SetupClientAuth(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); server_->RequestClientAuth(true); Connect(); CheckKeys(); @@ -847,11 +987,11 @@ TEST_P(TlsConnectTls12, ClientAuthCheckSigAlg) { CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pkcs1_sha1, 1024); } -TEST_P(TlsConnectTls12, ClientAuthBigRsaCheckSigAlg) { +TEST_P(TlsConnectClientAuth12, ClientAuthBigRsaCheckSigAlg) { Reset(TlsAgent::kServerRsa, TlsAgent::kRsa2048); auto capture_cert_verify = MakeTlsFilter<TlsHandshakeRecorder>( client_, kTlsHandshakeCertificateVerify); - client_->SetupClientAuth(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); server_->RequestClientAuth(true); Connect(); CheckKeys(); @@ -886,7 +1026,7 @@ class TlsReplaceSignatureSchemeFilter : public TlsHandshakeFilter { // This only works under TLS 1.2, because PSS doesn't work with TLS // 1.0 or TLS 1.1 and the TLS 1.3 1-RTT handshake is partially // successful at the client side. -TEST_P(TlsConnectTls12, ClientAuthInconsistentRsaeSignatureScheme) { +TEST_P(TlsConnectClientAuth12, ClientAuthInconsistentRsaeSignatureScheme) { static const SSLSignatureScheme kSignatureSchemePss[] = { ssl_sig_rsa_pss_pss_sha256, ssl_sig_rsa_pss_rsae_sha256}; @@ -895,7 +1035,7 @@ TEST_P(TlsConnectTls12, ClientAuthInconsistentRsaeSignatureScheme) { PR_ARRAY_SIZE(kSignatureSchemePss)); server_->SetSignatureSchemes(kSignatureSchemePss, PR_ARRAY_SIZE(kSignatureSchemePss)); - client_->SetupClientAuth(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); server_->RequestClientAuth(true); EnsureTlsSetup(); @@ -912,7 +1052,7 @@ TEST_P(TlsConnectTls12, ClientAuthInconsistentRsaeSignatureScheme) { // This only works under TLS 1.2, because PSS doesn't work with TLS // 1.0 or TLS 1.1 and the TLS 1.3 1-RTT handshake is partially // successful at the client side. -TEST_P(TlsConnectTls12, ClientAuthInconsistentPssSignatureScheme) { +TEST_P(TlsConnectClientAuth12, ClientAuthInconsistentPssSignatureScheme) { static const SSLSignatureScheme kSignatureSchemePss[] = { ssl_sig_rsa_pss_rsae_sha256, ssl_sig_rsa_pss_pss_sha256}; @@ -921,7 +1061,7 @@ TEST_P(TlsConnectTls12, ClientAuthInconsistentPssSignatureScheme) { PR_ARRAY_SIZE(kSignatureSchemePss)); server_->SetSignatureSchemes(kSignatureSchemePss, PR_ARRAY_SIZE(kSignatureSchemePss)); - client_->SetupClientAuth(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); server_->RequestClientAuth(true); EnsureTlsSetup(); @@ -932,7 +1072,7 @@ TEST_P(TlsConnectTls12, ClientAuthInconsistentPssSignatureScheme) { ConnectExpectAlert(server_, kTlsAlertIllegalParameter); } -TEST_P(TlsConnectTls13, ClientAuthPkcs1SignatureScheme) { +TEST_P(TlsConnectClientAuth13, ClientAuthPkcs1SignatureScheme) { static const SSLSignatureScheme kSignatureScheme[] = { ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pss_rsae_sha256}; @@ -941,7 +1081,7 @@ TEST_P(TlsConnectTls13, ClientAuthPkcs1SignatureScheme) { PR_ARRAY_SIZE(kSignatureScheme)); server_->SetSignatureSchemes(kSignatureScheme, PR_ARRAY_SIZE(kSignatureScheme)); - client_->SetupClientAuth(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); server_->RequestClientAuth(true); auto capture_cert_verify = MakeTlsFilter<TlsHandshakeRecorder>( @@ -954,14 +1094,14 @@ TEST_P(TlsConnectTls13, ClientAuthPkcs1SignatureScheme) { } // Client should refuse to connect without a usable signature scheme. -TEST_P(TlsConnectTls13, ClientAuthPkcs1SignatureSchemeOnly) { +TEST_P(TlsConnectClientAuth13, ClientAuthPkcs1SignatureSchemeOnly) { static const SSLSignatureScheme kSignatureScheme[] = { ssl_sig_rsa_pkcs1_sha256}; Reset(TlsAgent::kServerRsa, "rsa"); client_->SetSignatureSchemes(kSignatureScheme, PR_ARRAY_SIZE(kSignatureScheme)); - client_->SetupClientAuth(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); client_->StartConnect(); client_->Handshake(); EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); @@ -970,14 +1110,14 @@ TEST_P(TlsConnectTls13, ClientAuthPkcs1SignatureSchemeOnly) { // Though the client has a usable signature scheme, when a certificate is // requested, it can't produce one. -TEST_P(TlsConnectTls13, ClientAuthPkcs1AndEcdsaScheme) { +TEST_P(TlsConnectClientAuth13, ClientAuthPkcs1AndEcdsaScheme) { static const SSLSignatureScheme kSignatureScheme[] = { ssl_sig_rsa_pkcs1_sha256, ssl_sig_ecdsa_secp256r1_sha256}; Reset(TlsAgent::kServerRsa, "rsa"); client_->SetSignatureSchemes(kSignatureScheme, PR_ARRAY_SIZE(kSignatureScheme)); - client_->SetupClientAuth(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); server_->RequestClientAuth(true); ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); @@ -1031,12 +1171,12 @@ class TlsZeroCertificateRequestSigAlgsFilter : public TlsHandshakeFilter { // Check that we send an alert when the server doesn't provide any // supported_signature_algorithms in the CertificateRequest message. -TEST_P(TlsConnectTls12, ClientAuthNoSigAlgs) { +TEST_P(TlsConnectClientAuth12, ClientAuthNoSigAlgs) { EnsureTlsSetup(); MakeTlsFilter<TlsZeroCertificateRequestSigAlgsFilter>(server_); auto capture_cert_verify = MakeTlsFilter<TlsHandshakeRecorder>( client_, kTlsHandshakeCertificateVerify); - client_->SetupClientAuth(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); server_->RequestClientAuth(true); ConnectExpectAlert(client_, kTlsAlertHandshakeFailure); @@ -1061,9 +1201,9 @@ static SECStatus GetEcClientAuthDataHook(void* self, PRFileDesc* fd, return SECSuccess; } -TEST_P(TlsConnectTls12Plus, ClientAuthDisjointSchemes) { +TEST_P(TlsConnectClientAuth12Plus, ClientAuthDisjointSchemes) { EnsureTlsSetup(); - client_->SetupClientAuth(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); server_->RequestClientAuth(true); SSLSignatureScheme server_scheme = ssl_sig_rsa_pss_rsae_sha256; @@ -1103,7 +1243,7 @@ TEST_P(TlsConnectTls12Plus, ClientAuthDisjointSchemes) { } } -TEST_F(TlsConnectStreamTls13, PostHandshakeAuthDisjointSchemes) { +TEST_P(TlsConnectClientAuthStream13, PostHandshakeAuthDisjointSchemes) { EnsureTlsSetup(); SSLSignatureScheme server_scheme = ssl_sig_rsa_pss_rsae_sha256; std::vector<SSLSignatureScheme> client_schemes{ @@ -1116,7 +1256,7 @@ TEST_F(TlsConnectStreamTls13, PostHandshakeAuthDisjointSchemes) { static_cast<unsigned int>(client_schemes.size())); EXPECT_EQ(SECSuccess, rv); - client_->SetupClientAuth(); + client_->SetupClientAuth(std::get<2>(GetParam()), true); client_->SetOption(SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE); // Select an EC cert that's incompatible with server schemes. diff --git a/gtests/ssl_gtest/tls_agent.cc b/gtests/ssl_gtest/tls_agent.cc index 650be8449..8ec2f40f7 100644 --- a/gtests/ssl_gtest/tls_agent.cc +++ b/gtests/ssl_gtest/tls_agent.cc @@ -324,10 +324,19 @@ void TlsAgent::SetAntiReplayContext(ScopedSSLAntiReplayContext& ctx) { EXPECT_EQ(SECSuccess, SSL_SetAntiReplayContext(ssl_fd(), ctx.get())); } -void TlsAgent::SetupClientAuth() { +// Defaults to a Sync callback returning success +void TlsAgent::SetupClientAuth(ClientAuthCallbackType callbackType, + bool callbackSuccess) { EXPECT_TRUE(EnsureTlsSetup()); ASSERT_EQ(CLIENT, role_); + client_auth_callback_type_ = callbackType; + client_auth_callback_success_ = callbackSuccess; + + if (callbackType == ClientAuthCallbackType::kNone && !callbackSuccess) { + // Don't set a callback for this case. + return; + } EXPECT_EQ(SECSuccess, SSL_GetClientAuthDataHook(ssl_fd(), GetClientAuthDataHook, reinterpret_cast<void*>(this))); @@ -344,26 +353,95 @@ void CheckCertReqAgainstDefaultCAs(const CERTDistNames* caNames) { } } +// Complete processing of Client Certificate Selection +// A No-op if the agent is using synchronous client cert selection. +// Otherwise, calls SSL_ClientCertCallbackComplete. +// kAsyncDelay triggers a call to SSL_ForceHandshake prior to completion to +// ensure that the socket is correctly blocked. +void TlsAgent::ClientAuthCallbackComplete() { + ASSERT_EQ(CLIENT, role_); + + if (client_auth_callback_type_ != ClientAuthCallbackType::kAsyncDelay && + client_auth_callback_type_ != ClientAuthCallbackType::kAsyncImmediate) { + return; + } + client_auth_callback_fired_++; + EXPECT_TRUE(client_auth_callback_awaiting_); + + std::cerr << "client: calling SSL_ClientCertCallbackComplete with status " + << (client_auth_callback_success_ ? "success" : "failed") + << std::endl; + + client_auth_callback_awaiting_ = false; + + if (client_auth_callback_type_ == ClientAuthCallbackType::kAsyncDelay) { + std::cerr + << "Running Handshake prior to running SSL_ClientCertCallbackComplete" + << std::endl; + SECStatus rv = SSL_ForceHandshake(ssl_fd()); + EXPECT_EQ(rv, SECFailure); + EXPECT_EQ(PORT_GetError(), PR_WOULD_BLOCK_ERROR); + } + + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey priv; + if (client_auth_callback_success_) { + ASSERT_TRUE(TlsAgent::LoadCertificate(name(), &cert, &priv)); + EXPECT_EQ(SECSuccess, + SSL_ClientCertCallbackComplete(ssl_fd(), SECSuccess, + priv.release(), cert.release())); + } else { + EXPECT_EQ(SECSuccess, SSL_ClientCertCallbackComplete(ssl_fd(), SECFailure, + nullptr, nullptr)); + } +} + SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd, CERTDistNames* caNames, CERTCertificate** clientCert, SECKEYPrivateKey** clientKey) { TlsAgent* agent = reinterpret_cast<TlsAgent*>(self); - ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd())); - EXPECT_TRUE(peerCert) << "Client should be able to see the server cert"; + EXPECT_EQ(CLIENT, agent->role_); + agent->client_auth_callback_fired_++; + + switch (agent->client_auth_callback_type_) { + case ClientAuthCallbackType::kAsyncDelay: + case ClientAuthCallbackType::kAsyncImmediate: + std::cerr << "Waiting for complete call" << std::endl; + agent->client_auth_callback_awaiting_ = true; + return SECWouldBlock; + case ClientAuthCallbackType::kSync: + case ClientAuthCallbackType::kNone: + // Handle the sync case. None && Success is treated as Sync and Success. + if (!agent->client_auth_callback_success_) { + return SECFailure; + } + ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd())); + EXPECT_TRUE(peerCert) << "Client should be able to see the server cert"; - // See bug 1573945 - // CheckCertReqAgainstDefaultCAs(caNames); + // See bug 1573945 + // CheckCertReqAgainstDefaultCAs(caNames); - ScopedCERTCertificate cert; - ScopedSECKEYPrivateKey priv; - if (!TlsAgent::LoadCertificate(agent->name(), &cert, &priv)) { - return SECFailure; + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey priv; + if (!TlsAgent::LoadCertificate(agent->name(), &cert, &priv)) { + return SECFailure; + } + + *clientCert = cert.release(); + *clientKey = priv.release(); + return SECSuccess; } + /* This is unreachable, but some old compilers can't tell that. */ + PORT_Assert(0); + PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); + return SECFailure; +} - *clientCert = cert.release(); - *clientKey = priv.release(); - return SECSuccess; +// Increments by 1 for each callback +bool TlsAgent::CheckClientAuthCallbacksCompleted(uint8_t expected) { + EXPECT_EQ(CLIENT, role_); + return expected == client_auth_callback_fired_; } bool TlsAgent::GetPeerChainLength(size_t* count) { @@ -954,6 +1032,24 @@ void TlsAgent::Connected() { SetState(STATE_CONNECTED); } +void TlsAgent::CheckClientAuthCompleted(uint8_t handshakes) { + EXPECT_FALSE(client_auth_callback_awaiting_); + switch (client_auth_callback_type_) { + case ClientAuthCallbackType::kNone: + if (!client_auth_callback_success_) { + EXPECT_TRUE(CheckClientAuthCallbacksCompleted(0)); + break; + } + case ClientAuthCallbackType::kSync: + EXPECT_TRUE(CheckClientAuthCallbacksCompleted(handshakes)); + break; + case ClientAuthCallbackType::kAsyncDelay: + case ClientAuthCallbackType::kAsyncImmediate: + EXPECT_TRUE(CheckClientAuthCallbacksCompleted(2 * handshakes)); + break; + } +} + void TlsAgent::EnableExtendedMasterSecret() { SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE); } @@ -988,6 +1084,10 @@ void TlsAgent::SetDowngradeCheckVersion(uint16_t ver) { void TlsAgent::Handshake() { LOGV("Handshake"); SECStatus rv = SSL_ForceHandshake(ssl_fd()); + if (client_auth_callback_awaiting_) { + ClientAuthCallbackComplete(); + rv = SSL_ForceHandshake(ssl_fd()); + } if (rv == SECSuccess) { Connected(); Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, diff --git a/gtests/ssl_gtest/tls_agent.h b/gtests/ssl_gtest/tls_agent.h index 8b54155a5..35375e0c1 100644 --- a/gtests/ssl_gtest/tls_agent.h +++ b/gtests/ssl_gtest/tls_agent.h @@ -39,6 +39,13 @@ enum SessionResumptionMode { RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET }; +enum class ClientAuthCallbackType { + kAsyncImmediate, + kAsyncDelay, + kSync, + kNone, +}; + class PacketFilter; class TlsAgent; class TlsCipherSpec; @@ -144,9 +151,13 @@ class TlsAgent : public PollTarget { bool ConfigServerCertWithChain(const std::string& name); bool EnsureTlsSetup(PRFileDesc* modelSocket = nullptr); - void SetupClientAuth(); + void SetupClientAuth( + ClientAuthCallbackType callbackType = ClientAuthCallbackType::kSync, + bool callbackSuccess = true); void RequestClientAuth(bool requireAuth); - + void ClientAuthCallbackComplete(); + bool CheckClientAuthCallbacksCompleted(uint8_t expected); + void CheckClientAuthCompleted(uint8_t handshakes = 1); void SetOption(int32_t option, int value); void ConfigureSessionCache(SessionResumptionMode mode); void Set0RttEnabled(bool en); @@ -465,6 +476,11 @@ class TlsAgent : public PollTarget { std::vector<uint8_t> resumption_token_; NssPolicy policy_; NssOption option_; + ClientAuthCallbackType client_auth_callback_type_ = + ClientAuthCallbackType::kNone; + bool client_auth_callback_success_ = false; + uint8_t client_auth_callback_fired_ = 0; + bool client_auth_callback_awaiting_ = false; }; inline std::ostream& operator<<(std::ostream& stream, |