summaryrefslogtreecommitdiff
path: root/gtests/ssl_gtest
diff options
context:
space:
mode:
Diffstat (limited to 'gtests/ssl_gtest')
-rw-r--r--gtests/ssl_gtest/bloomfilter_unittest.cc2
-rw-r--r--gtests/ssl_gtest/ssl_0rtt_unittest.cc10
-rw-r--r--gtests/ssl_gtest/ssl_agent_unittest.cc6
-rw-r--r--gtests/ssl_gtest/ssl_auth_unittest.cc100
-rw-r--r--gtests/ssl_gtest/ssl_cert_ext_unittest.cc10
-rw-r--r--gtests/ssl_gtest/ssl_ciphersuite_unittest.cc2
-rw-r--r--gtests/ssl_gtest/ssl_custext_unittest.cc31
-rw-r--r--gtests/ssl_gtest/ssl_damage_unittest.cc25
-rw-r--r--gtests/ssl_gtest/ssl_dhe_unittest.cc79
-rw-r--r--gtests/ssl_gtest/ssl_drop_unittest.cc78
-rw-r--r--gtests/ssl_gtest/ssl_ecdh_unittest.cc34
-rw-r--r--gtests/ssl_gtest/ssl_extension_unittest.cc212
-rw-r--r--gtests/ssl_gtest/ssl_fragment_unittest.cc4
-rw-r--r--gtests/ssl_gtest/ssl_fuzz_unittest.cc53
-rw-r--r--gtests/ssl_gtest/ssl_hrr_unittest.cc106
-rw-r--r--gtests/ssl_gtest/ssl_keylog_unittest.cc4
-rw-r--r--gtests/ssl_gtest/ssl_loopback_unittest.cc60
-rw-r--r--gtests/ssl_gtest/ssl_record_unittest.cc16
-rw-r--r--gtests/ssl_gtest/ssl_resumption_unittest.cc113
-rw-r--r--gtests/ssl_gtest/ssl_skip_unittest.cc94
-rw-r--r--gtests/ssl_gtest/ssl_staticrsa_unittest.cc32
-rw-r--r--gtests/ssl_gtest/ssl_tls13compat_unittest.cc54
-rw-r--r--gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc4
-rw-r--r--gtests/ssl_gtest/ssl_version_unittest.cc42
-rw-r--r--gtests/ssl_gtest/ssl_versionpolicy_unittest.cc6
-rw-r--r--gtests/ssl_gtest/test_io.cc4
-rw-r--r--gtests/ssl_gtest/test_io.h6
-rw-r--r--gtests/ssl_gtest/tls_agent.h14
-rw-r--r--gtests/ssl_gtest/tls_connect.cc16
-rw-r--r--gtests/ssl_gtest/tls_connect.h6
-rw-r--r--gtests/ssl_gtest/tls_filter.cc8
-rw-r--r--gtests/ssl_gtest/tls_filter.h130
32 files changed, 733 insertions, 628 deletions
diff --git a/gtests/ssl_gtest/bloomfilter_unittest.cc b/gtests/ssl_gtest/bloomfilter_unittest.cc
index 110cfa13a..6efe06ec7 100644
--- a/gtests/ssl_gtest/bloomfilter_unittest.cc
+++ b/gtests/ssl_gtest/bloomfilter_unittest.cc
@@ -105,4 +105,4 @@ static const BloomFilterConfig kBloomFilterConfigurations[] = {
INSTANTIATE_TEST_CASE_P(BloomFilterConfigurations, BloomFilterTest,
::testing::ValuesIn(kBloomFilterConfigurations));
-} // namespace nspr_test
+} // namespace nss_test
diff --git a/gtests/ssl_gtest/ssl_0rtt_unittest.cc b/gtests/ssl_gtest/ssl_0rtt_unittest.cc
index 7d6120dc1..ded388f57 100644
--- a/gtests/ssl_gtest/ssl_0rtt_unittest.cc
+++ b/gtests/ssl_gtest/ssl_0rtt_unittest.cc
@@ -94,7 +94,7 @@ class TlsZeroRttReplayTest : public TlsConnectTls13 {
// Now run a true 0-RTT handshake, but capture the first packet.
auto first_packet = std::make_shared<SaveFirstPacket>();
- client_->SetPacketFilter(first_packet);
+ client_->SetFilter(first_packet);
client_->Set0RttEnabled(true);
server_->Set0RttEnabled(true);
ExpectResumption(RESUME_TICKET);
@@ -115,9 +115,9 @@ class TlsZeroRttReplayTest : public TlsConnectTls13 {
server_->Set0RttEnabled(true);
// Capture the early_data extension, which should not appear.
- auto early_data_ext =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_early_data_xtn);
- server_->SetPacketFilter(early_data_ext);
+ auto early_data_ext = std::make_shared<TlsExtensionCapture>(
+ server_, ssl_tls13_early_data_xtn);
+ server_->SetFilter(early_data_ext);
// Finally, replay the ClientHello and force the server to consume it. Stop
// after the server sends its first flight; the client will not be able to
@@ -607,7 +607,7 @@ TEST_P(TlsConnectTls13, ZeroRttOrdering) {
// Now, coalesce the next three things from the client: early data, second
// flight and 1-RTT data.
auto coalesce = std::make_shared<PacketCoalesceFilter>();
- client_->SetPacketFilter(coalesce);
+ client_->SetFilter(coalesce);
// Send (and hold) early data.
static const std::vector<uint8_t> early_data = {3, 2, 1};
diff --git a/gtests/ssl_gtest/ssl_agent_unittest.cc b/gtests/ssl_gtest/ssl_agent_unittest.cc
index 0aa9a4c78..c7a841c68 100644
--- a/gtests/ssl_gtest/ssl_agent_unittest.cc
+++ b/gtests/ssl_gtest/ssl_agent_unittest.cc
@@ -160,9 +160,9 @@ TEST_F(TlsAgentStreamTestClient, Set0RttOptionThenWrite) {
SSL_LIBRARY_VERSION_TLS_1_3);
agent_->StartConnect();
agent_->Set0RttEnabled(true);
- auto filter = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeClientHello);
- agent_->SetPacketFilter(filter);
+ auto filter =
+ std::make_shared<TlsHandshakeRecorder>(agent_, kTlsHandshakeClientHello);
+ agent_->SetFilter(filter);
PRInt32 rv = PR_Write(agent_->ssl_fd(), k0RttData, strlen(k0RttData));
EXPECT_EQ(-1, rv);
int32_t err = PORT_GetError();
diff --git a/gtests/ssl_gtest/ssl_auth_unittest.cc b/gtests/ssl_gtest/ssl_auth_unittest.cc
index c44a18161..d9e31f78f 100644
--- a/gtests/ssl_gtest/ssl_auth_unittest.cc
+++ b/gtests/ssl_gtest/ssl_auth_unittest.cc
@@ -95,10 +95,9 @@ TEST_P(TlsConnectGeneric, ClientAuthBigRsa) {
}
// Offset is the position in the captured buffer where the signature sits.
-static void CheckSigScheme(
- std::shared_ptr<TlsInspectorRecordHandshakeMessage>& capture, size_t offset,
- std::shared_ptr<TlsAgent>& peer, uint16_t expected_scheme,
- size_t expected_size) {
+static void CheckSigScheme(std::shared_ptr<TlsHandshakeRecorder>& capture,
+ size_t offset, std::shared_ptr<TlsAgent>& peer,
+ uint16_t expected_scheme, size_t expected_size) {
EXPECT_LT(offset + 2U, capture->buffer().len());
uint32_t scheme = 0;
@@ -114,9 +113,9 @@ static void CheckSigScheme(
// in the default certificate.
TEST_P(TlsConnectTls12, ServerAuthCheckSigAlg) {
EnsureTlsSetup();
- auto capture_ske = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeServerKeyExchange);
- server_->SetPacketFilter(capture_ske);
+ auto capture_ske = std::make_shared<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
+ server_->SetFilter(capture_ske);
Connect();
CheckKeys();
@@ -133,10 +132,9 @@ TEST_P(TlsConnectTls12, ServerAuthCheckSigAlg) {
TEST_P(TlsConnectTls12, ClientAuthCheckSigAlg) {
EnsureTlsSetup();
- auto capture_cert_verify =
- std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeCertificateVerify);
- client_->SetPacketFilter(capture_cert_verify);
+ auto capture_cert_verify = std::make_shared<TlsHandshakeRecorder>(
+ client_, kTlsHandshakeCertificateVerify);
+ client_->SetFilter(capture_cert_verify);
client_->SetupClientAuth();
server_->RequestClientAuth(true);
Connect();
@@ -147,10 +145,9 @@ TEST_P(TlsConnectTls12, ClientAuthCheckSigAlg) {
TEST_P(TlsConnectTls12, ClientAuthBigRsaCheckSigAlg) {
Reset(TlsAgent::kServerRsa, TlsAgent::kRsa2048);
- auto capture_cert_verify =
- std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeCertificateVerify);
- client_->SetPacketFilter(capture_cert_verify);
+ auto capture_cert_verify = std::make_shared<TlsHandshakeRecorder>(
+ client_, kTlsHandshakeCertificateVerify);
+ client_->SetFilter(capture_cert_verify);
client_->SetupClientAuth();
server_->RequestClientAuth(true);
Connect();
@@ -161,8 +158,8 @@ TEST_P(TlsConnectTls12, ClientAuthBigRsaCheckSigAlg) {
class TlsZeroCertificateRequestSigAlgsFilter : public TlsHandshakeFilter {
public:
- TlsZeroCertificateRequestSigAlgsFilter()
- : TlsHandshakeFilter({kTlsHandshakeCertificateRequest}) {}
+ TlsZeroCertificateRequestSigAlgsFilter(const std::shared_ptr<TlsAgent>& agent)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeCertificateRequest}) {}
virtual PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) {
@@ -207,12 +204,12 @@ class TlsZeroCertificateRequestSigAlgsFilter : public TlsHandshakeFilter {
// supported_signature_algorithms in the CertificateRequest message.
TEST_P(TlsConnectTls12, ClientAuthNoSigAlgsFallback) {
EnsureTlsSetup();
- auto filter = std::make_shared<TlsZeroCertificateRequestSigAlgsFilter>();
- server_->SetPacketFilter(filter);
- auto capture_cert_verify =
- std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeCertificateVerify);
- client_->SetPacketFilter(capture_cert_verify);
+ auto filter =
+ std::make_shared<TlsZeroCertificateRequestSigAlgsFilter>(server_);
+ server_->SetFilter(filter);
+ auto capture_cert_verify = std::make_shared<TlsHandshakeRecorder>(
+ client_, kTlsHandshakeCertificateVerify);
+ client_->SetFilter(capture_cert_verify);
client_->SetupClientAuth();
server_->RequestClientAuth(true);
@@ -360,8 +357,8 @@ TEST_P(TlsConnectPre12, SignatureAlgorithmNoOverlapEcdsa) {
// The signature_algorithms extension is mandatory in TLS 1.3.
TEST_P(TlsConnectTls13, SignatureAlgorithmDrop) {
- client_->SetPacketFilter(
- std::make_shared<TlsExtensionDropper>(ssl_signature_algorithms_xtn));
+ client_->SetFilter(std::make_shared<TlsExtensionDropper>(
+ client_, ssl_signature_algorithms_xtn));
ConnectExpectAlert(server_, kTlsAlertMissingExtension);
client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
server_->CheckErrorCode(SSL_ERROR_MISSING_SIGNATURE_ALGORITHMS_EXTENSION);
@@ -370,8 +367,8 @@ TEST_P(TlsConnectTls13, SignatureAlgorithmDrop) {
// TLS 1.2 has trouble detecting this sort of modification: it uses SHA1 and
// only fails when the Finished is checked.
TEST_P(TlsConnectTls12, SignatureAlgorithmDrop) {
- client_->SetPacketFilter(
- std::make_shared<TlsExtensionDropper>(ssl_signature_algorithms_xtn));
+ client_->SetFilter(std::make_shared<TlsExtensionDropper>(
+ client_, ssl_signature_algorithms_xtn));
ConnectExpectAlert(server_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
@@ -389,11 +386,11 @@ class BeforeFinished : public TlsRecordFilter {
enum HandshakeState { BEFORE_CCS, AFTER_CCS, DONE };
public:
- BeforeFinished(std::shared_ptr<TlsAgent>& client,
- std::shared_ptr<TlsAgent>& server, VoidFunction before_ccs,
- VoidFunction before_finished)
- : client_(client),
- server_(server),
+ BeforeFinished(const std::shared_ptr<TlsAgent>& client,
+ const std::shared_ptr<TlsAgent>& server,
+ VoidFunction before_ccs, VoidFunction before_finished)
+ : TlsRecordFilter(server),
+ client_(client),
before_ccs_(before_ccs),
before_finished_(before_finished),
state_(BEFORE_CCS) {}
@@ -413,7 +410,7 @@ class BeforeFinished : public TlsRecordFilter {
// but that means that they both get processed together.
DataBuffer ccs;
header.Write(&ccs, 0, body);
- server_.lock()->SendDirect(ccs);
+ agent()->SendDirect(ccs);
client_.lock()->Handshake();
state_ = AFTER_CCS;
// Request that the original record be dropped by the filter.
@@ -438,7 +435,6 @@ class BeforeFinished : public TlsRecordFilter {
private:
std::weak_ptr<TlsAgent> client_;
- std::weak_ptr<TlsAgent> server_;
VoidFunction before_ccs_;
VoidFunction before_finished_;
HandshakeState state_;
@@ -463,8 +459,8 @@ class BeforeFinished13 : public PacketFilter {
};
public:
- BeforeFinished13(std::shared_ptr<TlsAgent>& client,
- std::shared_ptr<TlsAgent>& server,
+ BeforeFinished13(const std::shared_ptr<TlsAgent>& client,
+ const std::shared_ptr<TlsAgent>& server,
VoidFunction before_finished)
: client_(client),
server_(server),
@@ -514,7 +510,7 @@ static SECStatus AuthCompleteBlock(TlsAgent*, PRBool, PRBool) {
// processed by the client, SSL_AuthCertificateComplete() is called.
TEST_F(TlsConnectDatagram13, AuthCompleteBeforeFinished) {
client_->SetAuthCertificateCallback(AuthCompleteBlock);
- server_->SetPacketFilter(
+ server_->SetFilter(
std::make_shared<BeforeFinished13>(client_, server_, [this]() {
EXPECT_EQ(SECSuccess,
SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
@@ -546,7 +542,7 @@ TEST_F(TlsConnectDatagram13, AuthCompleteAfterFinished) {
TEST_P(TlsConnectGenericPre13, ClientWriteBetweenCCSAndFinishedWithFalseStart) {
client_->EnableFalseStart();
- server_->SetPacketFilter(std::make_shared<BeforeFinished>(
+ server_->SetFilter(std::make_shared<BeforeFinished>(
client_, server_,
[this]() { EXPECT_TRUE(client_->can_falsestart_hook_called()); },
[this]() {
@@ -562,7 +558,7 @@ TEST_P(TlsConnectGenericPre13, ClientWriteBetweenCCSAndFinishedWithFalseStart) {
TEST_P(TlsConnectGenericPre13, AuthCompleteBeforeFinishedWithFalseStart) {
client_->EnableFalseStart();
client_->SetAuthCertificateCallback(AuthCompleteBlock);
- server_->SetPacketFilter(std::make_shared<BeforeFinished>(
+ server_->SetFilter(std::make_shared<BeforeFinished>(
client_, server_,
[]() {
// Do nothing before CCS
@@ -608,7 +604,7 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) {
EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
// The client should send nothing from here on.
- client_->SetPacketFilter(std::make_shared<EnforceNoActivity>());
+ client_->SetFilter(std::make_shared<EnforceNoActivity>());
client_->Handshake();
EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
@@ -618,8 +614,8 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) {
EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
- // Remove this before closing or the close_notify alert will trigger it.
- client_->DeletePacketFilter();
+ // Remove filter before closing or the close_notify alert will trigger it.
+ client_->ClearFilter();
}
TEST_P(TlsConnectGenericPre13, AuthCompleteFailDelayed) {
@@ -634,12 +630,12 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteFailDelayed) {
EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
// The client should send nothing from here on.
- client_->SetPacketFilter(std::make_shared<EnforceNoActivity>());
+ client_->SetFilter(std::make_shared<EnforceNoActivity>());
client_->Handshake();
EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
// Report failure.
- client_->DeletePacketFilter();
+ client_->ClearFilter();
client_->ExpectSendAlert(kTlsAlertBadCertificate);
EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(),
SSL_ERROR_BAD_CERTIFICATE));
@@ -659,12 +655,12 @@ TEST_P(TlsConnectTls13, AuthCompleteDelayed) {
EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
// The client will send nothing until AuthCertificateComplete is called.
- client_->SetPacketFilter(std::make_shared<EnforceNoActivity>());
+ client_->SetFilter(std::make_shared<EnforceNoActivity>());
client_->Handshake();
EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
// This should allow the handshake to complete now.
- client_->DeletePacketFilter();
+ client_->ClearFilter();
EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
client_->Handshake(); // Send Finished
server_->Handshake(); // Transition to connected and send NewSessionTicket
@@ -682,12 +678,12 @@ TEST_P(TlsConnectTls13, AuthCompleteFailDelayed) {
EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
// The client will send nothing until AuthCertificateComplete is called.
- client_->SetPacketFilter(std::make_shared<EnforceNoActivity>());
+ client_->SetFilter(std::make_shared<EnforceNoActivity>());
client_->Handshake();
EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
// Report failure.
- client_->DeletePacketFilter();
+ client_->ClearFilter();
ExpectAlert(client_, kTlsAlertBadCertificate);
EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(),
SSL_ERROR_BAD_CERTIFICATE));
@@ -831,9 +827,9 @@ TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigServer) {
TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigClient) {
Reset(certificate_);
- auto capture =
- std::make_shared<TlsExtensionCapture>(ssl_signature_algorithms_xtn);
- client_->SetPacketFilter(capture);
+ auto capture = std::make_shared<TlsExtensionCapture>(
+ client_, ssl_signature_algorithms_xtn);
+ client_->SetFilter(capture);
TestSignatureSchemeConfig(client_);
const DataBuffer& ext = capture->extension();
@@ -907,4 +903,4 @@ INSTANTIATE_TEST_CASE_P(
TlsAgent::kServerEcdsa384),
::testing::Values(ssl_auth_ecdsa),
::testing::Values(ssl_sig_ecdsa_sha1)));
-}
+} // namespace nss_test
diff --git a/gtests/ssl_gtest/ssl_cert_ext_unittest.cc b/gtests/ssl_gtest/ssl_cert_ext_unittest.cc
index 36ee104af..00b55a8c9 100644
--- a/gtests/ssl_gtest/ssl_cert_ext_unittest.cc
+++ b/gtests/ssl_gtest/ssl_cert_ext_unittest.cc
@@ -181,8 +181,8 @@ TEST_P(TlsConnectGenericPre13, OcspMangled) {
static const uint8_t val[] = {1};
auto replacer = std::make_shared<TlsExtensionReplacer>(
- ssl_cert_status_xtn, DataBuffer(val, sizeof(val)));
- server_->SetPacketFilter(replacer);
+ server_, ssl_cert_status_xtn, DataBuffer(val, sizeof(val)));
+ server_->SetFilter(replacer);
ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
@@ -192,8 +192,8 @@ TEST_P(TlsConnectGeneric, OcspSuccess) {
EnsureTlsSetup();
client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE);
auto capture_ocsp =
- std::make_shared<TlsExtensionCapture>(ssl_cert_status_xtn);
- server_->SetPacketFilter(capture_ocsp);
+ std::make_shared<TlsExtensionCapture>(server_, ssl_cert_status_xtn);
+ server_->SetFilter(capture_ocsp);
// The value should be available during the AuthCertificateCallback
client_->SetAuthCertificateCallback([](TlsAgent* agent, bool checksig,
@@ -245,4 +245,4 @@ TEST_P(TlsConnectGeneric, OcspHugeSuccess) {
Connect();
}
-} // namespace nspr_test
+} // namespace nss_test
diff --git a/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc b/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
index 206ee1961..fa2238be7 100644
--- a/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
+++ b/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
@@ -466,4 +466,4 @@ static const SecStatusParams kSecStatusTestValuesArr[] = {
INSTANTIATE_TEST_CASE_P(TestSecurityStatus, SecurityStatusTest,
::testing::ValuesIn(kSecStatusTestValuesArr));
-} // namespace nspr_test
+} // namespace nss_test
diff --git a/gtests/ssl_gtest/ssl_custext_unittest.cc b/gtests/ssl_gtest/ssl_custext_unittest.cc
index dad944a1f..7233b2218 100644
--- a/gtests/ssl_gtest/ssl_custext_unittest.cc
+++ b/gtests/ssl_gtest/ssl_custext_unittest.cc
@@ -150,9 +150,9 @@ TEST_F(TlsConnectStreamTls13, CustomExtensionWriterDisable) {
client_->ssl_fd(), ssl_signed_cert_timestamp_xtn, NoopExtensionWriter,
nullptr, NoopExtensionHandler, nullptr);
EXPECT_EQ(SECSuccess, rv);
- auto capture =
- std::make_shared<TlsExtensionCapture>(ssl_signed_cert_timestamp_xtn);
- client_->SetPacketFilter(capture);
+ auto capture = std::make_shared<TlsExtensionCapture>(
+ client_, ssl_signed_cert_timestamp_xtn);
+ client_->SetFilter(capture);
Connect();
// So nothing will be sent.
@@ -204,9 +204,9 @@ TEST_F(TlsConnectStreamTls13, CustomExtensionOverride) {
EXPECT_EQ(SECSuccess, rv);
// Capture it to see what we got.
- auto capture =
- std::make_shared<TlsExtensionCapture>(ssl_signed_cert_timestamp_xtn);
- client_->SetPacketFilter(capture);
+ auto capture = std::make_shared<TlsExtensionCapture>(
+ client_, ssl_signed_cert_timestamp_xtn);
+ client_->SetFilter(capture);
ConnectExpectAlert(server_, kTlsAlertDecodeError);
@@ -246,8 +246,8 @@ TEST_F(TlsConnectStreamTls13, CustomExtensionClientToServer) {
EXPECT_EQ(SECSuccess, rv);
// Capture it to see what we got.
- auto capture = std::make_shared<TlsExtensionCapture>(extension_code);
- client_->SetPacketFilter(capture);
+ auto capture = std::make_shared<TlsExtensionCapture>(client_, extension_code);
+ client_->SetFilter(capture);
// Handle it so that the handshake completes.
rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
@@ -290,9 +290,9 @@ TEST_F(TlsConnectStreamTls13, CustomExtensionServerToClientSH) {
EXPECT_EQ(SECSuccess, rv);
// Capture the extension from the ServerHello only and check it.
- auto capture = std::make_shared<TlsExtensionCapture>(extension_code);
+ auto capture = std::make_shared<TlsExtensionCapture>(server_, extension_code);
capture->SetHandshakeTypes({kTlsHandshakeServerHello});
- server_->SetPacketFilter(capture);
+ server_->SetFilter(capture);
Connect();
@@ -329,9 +329,10 @@ TEST_F(TlsConnectStreamTls13, CustomExtensionServerToClientEE) {
EXPECT_EQ(SECSuccess, rv);
// Capture the extension from the EncryptedExtensions only and check it.
- auto capture = std::make_shared<TlsExtensionCapture>(extension_code);
+ auto capture = std::make_shared<TlsExtensionCapture>(server_, extension_code);
capture->SetHandshakeTypes({kTlsHandshakeEncryptedExtensions});
- server_->SetTlsRecordFilter(capture);
+ capture->EnableDecryption();
+ server_->SetFilter(capture);
Connect();
@@ -350,8 +351,8 @@ TEST_F(TlsConnectStreamTls13, CustomExtensionUnsolicitedServer) {
EXPECT_EQ(SECSuccess, rv);
// Capture it to see what we got.
- auto capture = std::make_shared<TlsExtensionCapture>(extension_code);
- server_->SetPacketFilter(capture);
+ auto capture = std::make_shared<TlsExtensionCapture>(server_, extension_code);
+ server_->SetFilter(capture);
client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
@@ -500,4 +501,4 @@ TEST_F(TlsConnectStreamTls13, CustomExtensionOverrunBuffer) {
client_->CheckErrorCode(SEC_ERROR_APPLICATION_CALLBACK_ERROR);
}
-} // namespace "nss_test"
+} // namespace nss_test
diff --git a/gtests/ssl_gtest/ssl_damage_unittest.cc b/gtests/ssl_gtest/ssl_damage_unittest.cc
index d1668b823..e9e625a03 100644
--- a/gtests/ssl_gtest/ssl_damage_unittest.cc
+++ b/gtests/ssl_gtest/ssl_damage_unittest.cc
@@ -50,7 +50,7 @@ TEST_F(TlsConnectTest, DamageSecretHandleServerFinished) {
SSL_LIBRARY_VERSION_TLS_1_3);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
SSL_LIBRARY_VERSION_TLS_1_3);
- server_->SetPacketFilter(std::make_shared<AfterRecordN>(
+ server_->SetFilter(std::make_shared<AfterRecordN>(
server_, client_,
0, // ServerHello.
[this]() { SSLInt_DamageServerHsTrafficSecret(client_->ssl_fd()); }));
@@ -60,9 +60,10 @@ TEST_F(TlsConnectTest, DamageSecretHandleServerFinished) {
TEST_P(TlsConnectGenericPre13, DamageServerSignature) {
EnsureTlsSetup();
- auto filter =
- std::make_shared<TlsLastByteDamager>(kTlsHandshakeServerKeyExchange);
- server_->SetTlsRecordFilter(filter);
+ auto filter = std::make_shared<TlsLastByteDamager>(
+ server_, kTlsHandshakeServerKeyExchange);
+ filter->EnableDecryption();
+ server_->SetFilter(filter);
ExpectAlert(client_, kTlsAlertDecryptError);
ConnectExpectFail();
client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
@@ -71,9 +72,10 @@ TEST_P(TlsConnectGenericPre13, DamageServerSignature) {
TEST_P(TlsConnectTls13, DamageServerSignature) {
EnsureTlsSetup();
- auto filter =
- std::make_shared<TlsLastByteDamager>(kTlsHandshakeCertificateVerify);
- server_->SetTlsRecordFilter(filter);
+ auto filter = std::make_shared<TlsLastByteDamager>(
+ server_, kTlsHandshakeCertificateVerify);
+ filter->EnableDecryption();
+ server_->SetFilter(filter);
ConnectExpectAlert(client_, kTlsAlertDecryptError);
client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
}
@@ -82,9 +84,10 @@ TEST_P(TlsConnectGeneric, DamageClientSignature) {
EnsureTlsSetup();
client_->SetupClientAuth();
server_->RequestClientAuth(true);
- auto filter =
- std::make_shared<TlsLastByteDamager>(kTlsHandshakeCertificateVerify);
- client_->SetTlsRecordFilter(filter);
+ auto filter = std::make_shared<TlsLastByteDamager>(
+ client_, kTlsHandshakeCertificateVerify);
+ filter->EnableDecryption();
+ client_->SetFilter(filter);
server_->ExpectSendAlert(kTlsAlertDecryptError);
// Do these handshakes by hand to avoid race condition on
// the client processing the server's alert.
@@ -100,4 +103,4 @@ TEST_P(TlsConnectGeneric, DamageClientSignature) {
server_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
}
-} // namespace nspr_test
+} // namespace nss_test
diff --git a/gtests/ssl_gtest/ssl_dhe_unittest.cc b/gtests/ssl_gtest/ssl_dhe_unittest.cc
index 899720607..b61728203 100644
--- a/gtests/ssl_gtest/ssl_dhe_unittest.cc
+++ b/gtests/ssl_gtest/ssl_dhe_unittest.cc
@@ -32,12 +32,12 @@ TEST_P(TlsConnectTls13, SharesForBothEcdheAndDhe) {
client_->ConfigNamedGroups(kAllDHEGroups);
auto groups_capture =
- std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn);
auto shares_capture =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn);
std::vector<std::shared_ptr<PacketFilter>> captures = {groups_capture,
shares_capture};
- client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures));
+ client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures));
Connect();
@@ -61,12 +61,12 @@ TEST_P(TlsConnectGeneric, ConnectFfdheClient) {
EnableOnlyDheCiphers();
client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
auto groups_capture =
- std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn);
auto shares_capture =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn);
std::vector<std::shared_ptr<PacketFilter>> captures = {groups_capture,
shares_capture};
- client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures));
+ client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures));
Connect();
@@ -103,8 +103,8 @@ TEST_P(TlsConnectGenericPre13, ConnectFfdheServer) {
class TlsDheServerKeyExchangeDamager : public TlsHandshakeFilter {
public:
- TlsDheServerKeyExchangeDamager()
- : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}) {}
+ TlsDheServerKeyExchangeDamager(const std::shared_ptr<TlsAgent>& agent)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}) {}
virtual PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) {
@@ -122,7 +122,7 @@ class TlsDheServerKeyExchangeDamager : public TlsHandshakeFilter {
TEST_P(TlsConnectGenericPre13, DamageServerKeyShare) {
EnableOnlyDheCiphers();
client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
- server_->SetPacketFilter(std::make_shared<TlsDheServerKeyExchangeDamager>());
+ server_->SetFilter(std::make_shared<TlsDheServerKeyExchangeDamager>(server_));
ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
@@ -141,8 +141,9 @@ class TlsDheSkeChangeY : public TlsHandshakeFilter {
kYZeroPad
};
- TlsDheSkeChangeY(uint8_t handshake_type, ChangeYTo change)
- : TlsHandshakeFilter({handshake_type}), change_Y_(change) {}
+ TlsDheSkeChangeY(const std::shared_ptr<TlsAgent>& agent,
+ uint8_t handshake_type, ChangeYTo change)
+ : TlsHandshakeFilter(agent, {handshake_type}), change_Y_(change) {}
protected:
void ChangeY(const DataBuffer& input, DataBuffer* output, size_t offset,
@@ -207,8 +208,9 @@ class TlsDheSkeChangeY : public TlsHandshakeFilter {
class TlsDheSkeChangeYServer : public TlsDheSkeChangeY {
public:
- TlsDheSkeChangeYServer(ChangeYTo change, bool modify)
- : TlsDheSkeChangeY(kTlsHandshakeServerKeyExchange, change),
+ TlsDheSkeChangeYServer(const std::shared_ptr<TlsAgent>& agent,
+ ChangeYTo change, bool modify)
+ : TlsDheSkeChangeY(agent, kTlsHandshakeServerKeyExchange, change),
modify_(modify),
p_() {}
@@ -245,9 +247,9 @@ class TlsDheSkeChangeYServer : public TlsDheSkeChangeY {
class TlsDheSkeChangeYClient : public TlsDheSkeChangeY {
public:
TlsDheSkeChangeYClient(
- ChangeYTo change,
+ const std::shared_ptr<TlsAgent>& agent, ChangeYTo change,
std::shared_ptr<const TlsDheSkeChangeYServer> server_filter)
- : TlsDheSkeChangeY(kTlsHandshakeClientKeyExchange, change),
+ : TlsDheSkeChangeY(agent, kTlsHandshakeClientKeyExchange, change),
server_filter_(server_filter) {}
protected:
@@ -282,8 +284,8 @@ TEST_P(TlsDamageDHYTest, DamageServerY) {
client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
}
TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam());
- server_->SetPacketFilter(
- std::make_shared<TlsDheSkeChangeYServer>(change, true));
+ server_->SetFilter(
+ std::make_shared<TlsDheSkeChangeYServer>(server_, change, true));
if (change == TlsDheSkeChangeY::kYZeroPad) {
ExpectAlert(client_, kTlsAlertDecryptError);
@@ -312,14 +314,14 @@ TEST_P(TlsDamageDHYTest, DamageClientY) {
client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
}
// The filter on the server is required to capture the prime.
- auto server_filter =
- std::make_shared<TlsDheSkeChangeYServer>(TlsDheSkeChangeY::kYZero, false);
- server_->SetPacketFilter(server_filter);
+ auto server_filter = std::make_shared<TlsDheSkeChangeYServer>(
+ server_, TlsDheSkeChangeY::kYZero, false);
+ server_->SetFilter(server_filter);
// The client filter does the damage.
TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam());
- client_->SetPacketFilter(
- std::make_shared<TlsDheSkeChangeYClient>(change, server_filter));
+ client_->SetFilter(
+ std::make_shared<TlsDheSkeChangeYClient>(client_, change, server_filter));
if (change == TlsDheSkeChangeY::kYZeroPad) {
ExpectAlert(server_, kTlsAlertDecryptError);
@@ -358,7 +360,9 @@ INSTANTIATE_TEST_CASE_P(
class TlsDheSkeMakePEven : public TlsHandshakeFilter {
public:
- TlsDheSkeMakePEven() : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}) {}
+ TlsDheSkeMakePEven(const std::shared_ptr<TlsAgent>& agent)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}) {}
+
virtual PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) {
@@ -379,7 +383,7 @@ class TlsDheSkeMakePEven : public TlsHandshakeFilter {
// Even without requiring named groups, an even value for p is bad news.
TEST_P(TlsConnectGenericPre13, MakeDhePEven) {
EnableOnlyDheCiphers();
- server_->SetPacketFilter(std::make_shared<TlsDheSkeMakePEven>());
+ server_->SetFilter(std::make_shared<TlsDheSkeMakePEven>(server_));
ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
@@ -389,7 +393,9 @@ TEST_P(TlsConnectGenericPre13, MakeDhePEven) {
class TlsDheSkeZeroPadP : public TlsHandshakeFilter {
public:
- TlsDheSkeZeroPadP() : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}) {}
+ TlsDheSkeZeroPadP(const std::shared_ptr<TlsAgent>& agent)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}) {}
+
virtual PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) {
@@ -407,7 +413,7 @@ class TlsDheSkeZeroPadP : public TlsHandshakeFilter {
// Zero padding only causes signature failure.
TEST_P(TlsConnectGenericPre13, PadDheP) {
EnableOnlyDheCiphers();
- server_->SetPacketFilter(std::make_shared<TlsDheSkeZeroPadP>());
+ server_->SetFilter(std::make_shared<TlsDheSkeZeroPadP>(server_));
ConnectExpectAlert(client_, kTlsAlertDecryptError);
@@ -529,12 +535,12 @@ TEST_P(TlsConnectTls13, ResumeFfdhe) {
Reset();
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
EnableOnlyDheCiphers();
- auto clientCapture =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
- client_->SetPacketFilter(clientCapture);
- auto serverCapture =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
- server_->SetPacketFilter(serverCapture);
+ auto clientCapture = std::make_shared<TlsExtensionCapture>(
+ client_, ssl_tls13_pre_shared_key_xtn);
+ client_->SetFilter(clientCapture);
+ auto serverCapture = std::make_shared<TlsExtensionCapture>(
+ server_, ssl_tls13_pre_shared_key_xtn);
+ server_->SetFilter(serverCapture);
ExpectResumption(RESUME_TICKET);
Connect();
CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign,
@@ -545,8 +551,9 @@ TEST_P(TlsConnectTls13, ResumeFfdhe) {
class TlsDheSkeChangeSignature : public TlsHandshakeFilter {
public:
- TlsDheSkeChangeSignature(uint16_t version, const uint8_t* data, size_t len)
- : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}),
+ TlsDheSkeChangeSignature(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t version, const uint8_t* data, size_t len)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}),
version_(version),
data_(data),
len_(len) {}
@@ -595,8 +602,8 @@ TEST_P(TlsConnectGenericPre13, InvalidDERSignatureFfdhe) {
const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ffdhe_2048};
client_->ConfigNamedGroups(client_groups);
- server_->SetPacketFilter(std::make_shared<TlsDheSkeChangeSignature>(
- version_, kBogusDheSignature, sizeof(kBogusDheSignature)));
+ server_->SetFilter(std::make_shared<TlsDheSkeChangeSignature>(
+ server_, version_, kBogusDheSignature, sizeof(kBogusDheSignature)));
ConnectExpectAlert(client_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
diff --git a/gtests/ssl_gtest/ssl_drop_unittest.cc b/gtests/ssl_gtest/ssl_drop_unittest.cc
index c059e9938..ee8906deb 100644
--- a/gtests/ssl_gtest/ssl_drop_unittest.cc
+++ b/gtests/ssl_gtest/ssl_drop_unittest.cc
@@ -22,13 +22,13 @@ extern "C" {
namespace nss_test {
TEST_P(TlsConnectDatagramPre13, DropClientFirstFlightOnce) {
- client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x1));
+ client_->SetFilter(std::make_shared<SelectiveDropFilter>(0x1));
Connect();
SendReceive();
}
TEST_P(TlsConnectDatagramPre13, DropServerFirstFlightOnce) {
- server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x1));
+ server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x1));
Connect();
SendReceive();
}
@@ -37,32 +37,32 @@ TEST_P(TlsConnectDatagramPre13, DropServerFirstFlightOnce) {
// flights that they send. Note: In DTLS 1.3, the shorter handshake means that
// this will also drop some application data, so we can't call SendReceive().
TEST_P(TlsConnectDatagramPre13, DropAllFirstTransmissions) {
- client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x15));
- server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x5));
+ client_->SetFilter(std::make_shared<SelectiveDropFilter>(0x15));
+ server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x5));
Connect();
}
// This drops the server's first flight three times.
TEST_P(TlsConnectDatagramPre13, DropServerFirstFlightThrice) {
- server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x7));
+ server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x7));
Connect();
}
// This drops the client's second flight once
TEST_P(TlsConnectDatagramPre13, DropClientSecondFlightOnce) {
- client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x2));
+ client_->SetFilter(std::make_shared<SelectiveDropFilter>(0x2));
Connect();
}
// This drops the client's second flight three times.
TEST_P(TlsConnectDatagramPre13, DropClientSecondFlightThrice) {
- client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0xe));
+ client_->SetFilter(std::make_shared<SelectiveDropFilter>(0xe));
Connect();
}
// This drops the server's second flight three times.
TEST_P(TlsConnectDatagramPre13, DropServerSecondFlightThrice) {
- server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0xe));
+ server_->SetFilter(std::make_shared<SelectiveDropFilter>(0xe));
Connect();
}
@@ -74,7 +74,7 @@ class TlsDropDatagram13 : public TlsConnectDatagram13 {
expected_client_acks_(0),
expected_server_acks_(1) {}
- void SetUp() {
+ void SetUp() override {
TlsConnectDatagram13::SetUp();
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
SetFilters();
@@ -82,12 +82,8 @@ class TlsDropDatagram13 : public TlsConnectDatagram13 {
void SetFilters() {
EnsureTlsSetup();
- client_->SetPacketFilter(client_filters_.chain_);
- client_filters_.ack_->SetAgent(client_.get());
- client_filters_.ack_->EnableDecryption();
- server_->SetPacketFilter(server_filters_.chain_);
- server_filters_.ack_->SetAgent(server_.get());
- server_filters_.ack_->EnableDecryption();
+ client_filters_.Init(client_);
+ server_filters_.Init(server_);
}
void HandshakeAndAck(const std::shared_ptr<TlsAgent>& agent) {
@@ -119,11 +115,17 @@ class TlsDropDatagram13 : public TlsConnectDatagram13 {
class DropAckChain {
public:
DropAckChain()
- : records_(std::make_shared<TlsRecordRecorder>()),
- ack_(std::make_shared<TlsRecordRecorder>(content_ack)),
- drop_(std::make_shared<SelectiveRecordDropFilter>(0, false)),
- chain_(std::make_shared<ChainedPacketFilter>(
- ChainedPacketFilterInit({records_, ack_, drop_}))) {}
+ : records_(nullptr), ack_(nullptr), drop_(nullptr), chain_(nullptr) {}
+
+ void Init(const std::shared_ptr<TlsAgent>& agent) {
+ records_ = std::make_shared<TlsRecordRecorder>(agent);
+ ack_ = std::make_shared<TlsRecordRecorder>(agent, content_ack);
+ ack_->EnableDecryption();
+ drop_ = std::make_shared<SelectiveRecordDropFilter>(agent, 0, false);
+ chain_ = std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit({records_, ack_, drop_}));
+ agent->SetFilter(chain_);
+ }
const TlsRecord& record(size_t i) const { return records_->record(i); }
@@ -227,7 +229,7 @@ TEST_F(TlsDropDatagram13, DropServerSecondRecordOnce) {
HandshakeAndAck(client_);
expected_client_acks_ = 1;
CheckedHandshakeSendReceive();
- CheckAcks(client_filters_, 0, {0});
+ CheckAcks(client_filters_, 0, {0}); // ServerHello
CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
}
@@ -257,7 +259,7 @@ TEST_F(TlsDropDatagram13, DropServerAckOnce) {
CheckPostHandshake();
// There should be two copies of the finished ACK
CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
- CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+ CheckAcks(server_filters_, 1, {0x0002000000000000ULL});
}
// Drop the client certificate verify.
@@ -276,10 +278,9 @@ TEST_F(TlsDropDatagram13, DropClientCertVerify) {
// Ack of the whole client handshake.
CheckAcks(
server_filters_, 1,
- {0x0002000000000000ULL, // CH (we drop everything after this on client)
- 0x0002000000000003ULL, // CT (2)
- 0x0002000000000004ULL} // FIN (2)
- );
+ {0x0002000000000000ULL, // CH (we drop everything after this on client)
+ 0x0002000000000003ULL, // CT (2)
+ 0x0002000000000004ULL}); // FIN (2)
}
// Shrink the MTU down so that certs get split and drop the first piece.
@@ -303,10 +304,9 @@ TEST_F(TlsDropDatagram13, DropFirstHalfOfServerCertificate) {
EXPECT_EQ(ct1_size, server_filters_.record(0).buffer.len());
CheckedHandshakeSendReceive();
CheckAcks(client_filters_, 0,
- {0, // SH
- 0x0002000000000000ULL, // EE
- 0x0002000000000002ULL} // CT2
- );
+ {0, // SH
+ 0x0002000000000000ULL, // EE
+ 0x0002000000000002ULL}); // CT2
CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
}
@@ -540,7 +540,10 @@ TEST_F(TlsDropDatagram13, NoDropsDuringZeroRtt) {
ExpectEarlyDataAccepted(true);
CheckConnected();
SendReceive();
- CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+ EXPECT_EQ(0U, client_filters_.ack_->count());
+ CheckAcks(server_filters_, 0,
+ {0x0001000000000001ULL, // EOED
+ 0x0002000000000000ULL}); // Finished
}
TEST_F(TlsDropDatagram13, DropEEDuringZeroRtt) {
@@ -558,7 +561,9 @@ TEST_F(TlsDropDatagram13, DropEEDuringZeroRtt) {
CheckConnected();
SendReceive();
CheckAcks(client_filters_, 0, {0});
- CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+ CheckAcks(server_filters_, 0,
+ {0x0001000000000002ULL, // EOED
+ 0x0002000000000000ULL}); // Finished
}
class TlsReorderDatagram13 : public TlsDropDatagram13 {
@@ -688,6 +693,7 @@ TEST_F(TlsDropDatagram13, SendOutOfOrderHsNonsenseWithHandshakeKey) {
kTlsHandshakeType, DataBuffer(buf, sizeof(buf))));
server_->Handshake();
EXPECT_EQ(2UL, server_filters_.ack_->count());
+ // The server acknowledges client Finished twice.
CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
CheckAcks(server_filters_, 1, {0x0002000000000000ULL});
}
@@ -746,7 +752,9 @@ TEST_F(TlsReorderDatagram13, DataAfterEOEDDuringZeroRtt) {
ReSend(TlsAgent::CLIENT, std::vector<size_t>({1, 0, 2}));
server_->Handshake();
CheckConnected();
- CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+ EXPECT_EQ(0U, client_filters_.ack_->count());
+ // Acknowledgements for EOED and Finished.
+ CheckAcks(server_filters_, 0, {0x0001000000000002ULL, 0x0002000000000000ULL});
uint8_t buf[8];
rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf));
EXPECT_EQ(-1, rv);
@@ -783,7 +791,9 @@ TEST_F(TlsReorderDatagram13, DataAfterFinDuringZeroRtt) {
ReSend(TlsAgent::CLIENT, std::vector<size_t>({1, 2, 0}));
server_->Handshake();
CheckConnected();
- CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+ EXPECT_EQ(0U, client_filters_.ack_->count());
+ // Acknowledgements for EOED and Finished.
+ CheckAcks(server_filters_, 0, {0x0001000000000002ULL, 0x0002000000000000ULL});
uint8_t buf[8];
rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf));
EXPECT_EQ(-1, rv);
diff --git a/gtests/ssl_gtest/ssl_ecdh_unittest.cc b/gtests/ssl_gtest/ssl_ecdh_unittest.cc
index 8caa7bf71..f0b499dc4 100644
--- a/gtests/ssl_gtest/ssl_ecdh_unittest.cc
+++ b/gtests/ssl_gtest/ssl_ecdh_unittest.cc
@@ -75,9 +75,9 @@ TEST_P(TlsConnectGeneric, ConnectEcdheP384Client) {
// This causes a HelloRetryRequest in TLS 1.3. Earlier versions don't care.
TEST_P(TlsConnectGeneric, ConnectEcdheP384Server) {
EnsureTlsSetup();
- auto hrr_capture = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeHelloRetryRequest);
- server_->SetPacketFilter(hrr_capture);
+ auto hrr_capture = std::make_shared<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeHelloRetryRequest);
+ server_->SetFilter(hrr_capture);
const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
server_->ConfigNamedGroups(groups);
Connect();
@@ -193,8 +193,8 @@ TEST_P(TlsConnectGenericPre13, P384PriorityFromModelSocket) {
class TlsKeyExchangeGroupCapture : public TlsHandshakeFilter {
public:
- TlsKeyExchangeGroupCapture()
- : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}),
+ TlsKeyExchangeGroupCapture(const std::shared_ptr<TlsAgent> &agent)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}),
group_(ssl_grp_none) {}
SSLNamedGroup group() const { return group_; }
@@ -221,10 +221,10 @@ class TlsKeyExchangeGroupCapture : public TlsHandshakeFilter {
// P-256 is supported by the client (<= 1.2 only).
TEST_P(TlsConnectGenericPre13, DropSupportedGroupExtensionP256) {
EnsureTlsSetup();
- client_->SetPacketFilter(
- std::make_shared<TlsExtensionDropper>(ssl_supported_groups_xtn));
- auto group_capture = std::make_shared<TlsKeyExchangeGroupCapture>();
- server_->SetPacketFilter(group_capture);
+ client_->SetFilter(
+ std::make_shared<TlsExtensionDropper>(client_, ssl_supported_groups_xtn));
+ auto group_capture = std::make_shared<TlsKeyExchangeGroupCapture>(server_);
+ server_->SetFilter(group_capture);
ConnectExpectAlert(server_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
@@ -236,8 +236,8 @@ TEST_P(TlsConnectGenericPre13, DropSupportedGroupExtensionP256) {
// Supported groups is mandatory in TLS 1.3.
TEST_P(TlsConnectTls13, DropSupportedGroupExtension) {
EnsureTlsSetup();
- client_->SetPacketFilter(
- std::make_shared<TlsExtensionDropper>(ssl_supported_groups_xtn));
+ client_->SetFilter(
+ std::make_shared<TlsExtensionDropper>(client_, ssl_supported_groups_xtn));
ConnectExpectAlert(server_, kTlsAlertMissingExtension);
client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
server_->CheckErrorCode(SSL_ERROR_MISSING_SUPPORTED_GROUPS_EXTENSION);
@@ -516,7 +516,8 @@ TEST_P(TlsKeyExchangeTest13, MultipleClientShares) {
// Replace the point in the client key exchange message with an empty one
class ECCClientKEXFilter : public TlsHandshakeFilter {
public:
- ECCClientKEXFilter() : TlsHandshakeFilter({kTlsHandshakeClientKeyExchange}) {}
+ ECCClientKEXFilter(const std::shared_ptr<TlsAgent> &client)
+ : TlsHandshakeFilter(client, {kTlsHandshakeClientKeyExchange}) {}
protected:
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header,
@@ -532,7 +533,8 @@ class ECCClientKEXFilter : public TlsHandshakeFilter {
// Replace the point in the server key exchange message with an empty one
class ECCServerKEXFilter : public TlsHandshakeFilter {
public:
- ECCServerKEXFilter() : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}) {}
+ ECCServerKEXFilter(const std::shared_ptr<TlsAgent> &server)
+ : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}) {}
protected:
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header,
@@ -550,15 +552,13 @@ class ECCServerKEXFilter : public TlsHandshakeFilter {
};
TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyServerPoint) {
- // add packet filter
- server_->SetPacketFilter(std::make_shared<ECCServerKEXFilter>());
+ server_->SetFilter(std::make_shared<ECCServerKEXFilter>(server_));
ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_KEY_EXCH);
}
TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyClientPoint) {
- // add packet filter
- client_->SetPacketFilter(std::make_shared<ECCClientKEXFilter>());
+ client_->SetFilter(std::make_shared<ECCClientKEXFilter>(client_));
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_KEY_EXCH);
}
diff --git a/gtests/ssl_gtest/ssl_extension_unittest.cc b/gtests/ssl_gtest/ssl_extension_unittest.cc
index a9daeed82..295090cdd 100644
--- a/gtests/ssl_gtest/ssl_extension_unittest.cc
+++ b/gtests/ssl_gtest/ssl_extension_unittest.cc
@@ -19,8 +19,9 @@ namespace nss_test {
class TlsExtensionTruncator : public TlsExtensionFilter {
public:
- TlsExtensionTruncator(uint16_t extension, size_t length)
- : extension_(extension), length_(length) {}
+ TlsExtensionTruncator(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t extension, size_t length)
+ : TlsExtensionFilter(agent), extension_(extension), length_(length) {}
virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
const DataBuffer& input,
DataBuffer* output) {
@@ -42,8 +43,9 @@ class TlsExtensionTruncator : public TlsExtensionFilter {
class TlsExtensionDamager : public TlsExtensionFilter {
public:
- TlsExtensionDamager(uint16_t extension, size_t index)
- : extension_(extension), index_(index) {}
+ TlsExtensionDamager(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t extension, size_t index)
+ : TlsExtensionFilter(agent), extension_(extension), index_(index) {}
virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
const DataBuffer& input,
DataBuffer* output) {
@@ -63,8 +65,11 @@ class TlsExtensionDamager : public TlsExtensionFilter {
class TlsExtensionAppender : public TlsHandshakeFilter {
public:
- TlsExtensionAppender(uint8_t handshake_type, uint16_t ext, DataBuffer& data)
- : TlsHandshakeFilter({handshake_type}), extension_(ext), data_(data) {}
+ TlsExtensionAppender(const std::shared_ptr<TlsAgent>& agent,
+ uint8_t handshake_type, uint16_t ext, DataBuffer& data)
+ : TlsHandshakeFilter(agent, {handshake_type}),
+ extension_(ext),
+ data_(data) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -124,13 +129,13 @@ class TlsExtensionTestBase : public TlsConnectTestBase {
void ClientHelloErrorTest(std::shared_ptr<PacketFilter> filter,
uint8_t desc = kTlsAlertDecodeError) {
- client_->SetPacketFilter(filter);
+ client_->SetFilter(filter);
ConnectExpectAlert(server_, desc);
}
void ServerHelloErrorTest(std::shared_ptr<PacketFilter> filter,
uint8_t desc = kTlsAlertDecodeError) {
- server_->SetPacketFilter(filter);
+ server_->SetFilter(filter);
ConnectExpectAlert(client_, desc);
}
@@ -156,7 +161,7 @@ class TlsExtensionTestBase : public TlsConnectTestBase {
StartConnect();
client_->Handshake(); // Send ClientHello
server_->Handshake(); // Send HRR.
- client_->SetPacketFilter(std::make_shared<TlsExtensionDropper>(type));
+ client_->SetFilter(std::make_shared<TlsExtensionDropper>(client_, type));
Handshake();
client_->CheckErrorCode(client_error);
server_->CheckErrorCode(server_error);
@@ -197,8 +202,8 @@ class TlsExtensionTest13
void ConnectWithBogusVersionList(const uint8_t* buf, size_t len) {
DataBuffer versions_buf(buf, len);
- client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>(
- ssl_tls13_supported_versions_xtn, versions_buf));
+ client_->SetFilter(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_tls13_supported_versions_xtn, versions_buf));
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
@@ -209,8 +214,8 @@ class TlsExtensionTest13
size_t index = versions_buf.Write(0, 2, 1);
versions_buf.Write(index, version, 2);
- client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>(
- ssl_tls13_supported_versions_xtn, versions_buf));
+ client_->SetFilter(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_tls13_supported_versions_xtn, versions_buf));
ConnectExpectFail();
}
};
@@ -241,26 +246,26 @@ class TlsExtensionTestPre13 : public TlsExtensionTestBase,
TEST_P(TlsExtensionTestGeneric, DamageSniLength) {
ClientHelloErrorTest(
- std::make_shared<TlsExtensionDamager>(ssl_server_name_xtn, 1));
+ std::make_shared<TlsExtensionDamager>(client_, ssl_server_name_xtn, 1));
}
TEST_P(TlsExtensionTestGeneric, DamageSniHostLength) {
ClientHelloErrorTest(
- std::make_shared<TlsExtensionDamager>(ssl_server_name_xtn, 4));
+ std::make_shared<TlsExtensionDamager>(client_, ssl_server_name_xtn, 4));
}
TEST_P(TlsExtensionTestGeneric, TruncateSni) {
ClientHelloErrorTest(
- std::make_shared<TlsExtensionTruncator>(ssl_server_name_xtn, 7));
+ std::make_shared<TlsExtensionTruncator>(client_, ssl_server_name_xtn, 7));
}
// A valid extension that appears twice will be reported as unsupported.
TEST_P(TlsExtensionTestGeneric, RepeatSni) {
DataBuffer extension;
InitSimpleSni(&extension);
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionInjector>(ssl_server_name_xtn, extension),
- kTlsAlertIllegalParameter);
+ ClientHelloErrorTest(std::make_shared<TlsExtensionInjector>(
+ client_, ssl_server_name_xtn, extension),
+ kTlsAlertIllegalParameter);
}
// An SNI entry with zero length is considered invalid (strangely, not if it is
@@ -272,23 +277,23 @@ TEST_P(TlsExtensionTestGeneric, BadSni) {
extension.Allocate(simple.len() + 3);
extension.Write(0, static_cast<uint32_t>(0), 3);
extension.Write(3, simple);
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionReplacer>(ssl_server_name_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_server_name_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, EmptySni) {
DataBuffer extension;
extension.Allocate(2);
extension.Write(0, static_cast<uint32_t>(0), 2);
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionReplacer>(ssl_server_name_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_server_name_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, EmptyAlpnExtension) {
EnableAlpn();
DataBuffer extension;
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension),
+ client_, ssl_app_layer_protocol_xtn, extension),
kTlsAlertIllegalParameter);
}
@@ -299,21 +304,21 @@ TEST_P(TlsExtensionTestGeneric, EmptyAlpnList) {
const uint8_t val[] = {0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension),
+ client_, ssl_app_layer_protocol_xtn, extension),
kTlsAlertNoApplicationProtocol);
}
TEST_P(TlsExtensionTestGeneric, OneByteAlpn) {
EnableAlpn();
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionTruncator>(ssl_app_layer_protocol_xtn, 1));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>(
+ client_, ssl_app_layer_protocol_xtn, 1));
}
TEST_P(TlsExtensionTestGeneric, AlpnMissingValue) {
EnableAlpn();
// This will leave the length of the second entry, but no value.
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionTruncator>(ssl_app_layer_protocol_xtn, 5));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>(
+ client_, ssl_app_layer_protocol_xtn, 5));
}
TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) {
@@ -321,7 +326,7 @@ TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) {
const uint8_t val[] = {0x01, 0x61, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ client_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, AlpnMismatch) {
@@ -340,7 +345,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyList) {
const uint8_t val[] = {0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) {
@@ -348,7 +353,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) {
const uint8_t val[] = {0x00, 0x01, 0x00};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) {
@@ -356,7 +361,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) {
const uint8_t val[] = {0x00, 0x02, 0x01, 0x61, 0x00};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) {
@@ -364,7 +369,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) {
const uint8_t val[] = {0x00, 0x04, 0x01, 0x61, 0x01, 0x62};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) {
@@ -372,7 +377,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) {
const uint8_t val[] = {0x00, 0x99, 0x01, 0x61, 0x00};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) {
@@ -380,7 +385,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) {
const uint8_t val[] = {0x00, 0x02, 0x99, 0x61};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedUnknownName) {
@@ -388,43 +393,43 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedUnknownName) {
const uint8_t val[] = {0x00, 0x02, 0x01, 0x67};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension),
+ server_, ssl_app_layer_protocol_xtn, extension),
kTlsAlertIllegalParameter);
}
TEST_P(TlsExtensionTestDtls, SrtpShort) {
EnableSrtp();
ClientHelloErrorTest(
- std::make_shared<TlsExtensionTruncator>(ssl_use_srtp_xtn, 3));
+ std::make_shared<TlsExtensionTruncator>(client_, ssl_use_srtp_xtn, 3));
}
TEST_P(TlsExtensionTestDtls, SrtpOdd) {
EnableSrtp();
const uint8_t val[] = {0x00, 0x01, 0xff, 0x00};
DataBuffer extension(val, sizeof(val));
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionReplacer>(ssl_use_srtp_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_use_srtp_xtn, extension));
}
TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsBadLength) {
const uint8_t val[] = {0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_signature_algorithms_xtn, extension));
+ client_, ssl_signature_algorithms_xtn, extension));
}
TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsTrailingData) {
const uint8_t val[] = {0x00, 0x02, 0x04, 0x01, 0x00}; // sha-256, rsa
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_signature_algorithms_xtn, extension));
+ client_, ssl_signature_algorithms_xtn, extension));
}
TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsEmpty) {
const uint8_t val[] = {0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_signature_algorithms_xtn, extension),
+ client_, ssl_signature_algorithms_xtn, extension),
kTlsAlertHandshakeFailure);
}
@@ -432,7 +437,7 @@ TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsNoOverlap) {
const uint8_t val[] = {0x00, 0x02, 0xff, 0xff};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_signature_algorithms_xtn, extension),
+ client_, ssl_signature_algorithms_xtn, extension),
kTlsAlertHandshakeFailure);
}
@@ -440,12 +445,12 @@ TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsOddLength) {
const uint8_t val[] = {0x00, 0x01, 0x04};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_signature_algorithms_xtn, extension));
+ client_, ssl_signature_algorithms_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, NoSupportedGroups) {
ClientHelloErrorTest(
- std::make_shared<TlsExtensionDropper>(ssl_supported_groups_xtn),
+ std::make_shared<TlsExtensionDropper>(client_, ssl_supported_groups_xtn),
version_ < SSL_LIBRARY_VERSION_TLS_1_3 ? kTlsAlertDecryptError
: kTlsAlertMissingExtension);
}
@@ -454,63 +459,63 @@ TEST_P(TlsExtensionTestGeneric, SupportedCurvesShort) {
const uint8_t val[] = {0x00, 0x01, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_elliptic_curves_xtn, extension));
+ client_, ssl_elliptic_curves_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, SupportedCurvesBadLength) {
const uint8_t val[] = {0x09, 0x99, 0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_elliptic_curves_xtn, extension));
+ client_, ssl_elliptic_curves_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, SupportedCurvesTrailingData) {
const uint8_t val[] = {0x00, 0x02, 0x00, 0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_elliptic_curves_xtn, extension));
+ client_, ssl_elliptic_curves_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, SupportedPointsEmpty) {
const uint8_t val[] = {0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_ec_point_formats_xtn, extension));
+ client_, ssl_ec_point_formats_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, SupportedPointsBadLength) {
const uint8_t val[] = {0x99, 0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_ec_point_formats_xtn, extension));
+ client_, ssl_ec_point_formats_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, SupportedPointsTrailingData) {
const uint8_t val[] = {0x01, 0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_ec_point_formats_xtn, extension));
+ client_, ssl_ec_point_formats_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, RenegotiationInfoBadLength) {
const uint8_t val[] = {0x99};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_renegotiation_info_xtn, extension));
+ client_, ssl_renegotiation_info_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, RenegotiationInfoMismatch) {
const uint8_t val[] = {0x01, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_renegotiation_info_xtn, extension));
+ client_, ssl_renegotiation_info_xtn, extension));
}
// The extension has to contain a length.
TEST_P(TlsExtensionTestPre13, RenegotiationInfoExtensionEmpty) {
DataBuffer extension;
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_renegotiation_info_xtn, extension));
+ client_, ssl_renegotiation_info_xtn, extension));
}
// This only works on TLS 1.2, since it relies on static RSA; otherwise libssl
@@ -519,10 +524,10 @@ TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) {
const SSLSignatureScheme schemes[] = {ssl_sig_rsa_pss_rsae_sha512,
ssl_sig_rsa_pss_rsae_sha384};
- auto capture =
- std::make_shared<TlsExtensionCapture>(ssl_signature_algorithms_xtn);
+ auto capture = std::make_shared<TlsExtensionCapture>(
+ client_, ssl_signature_algorithms_xtn);
client_->SetSignatureSchemes(schemes, PR_ARRAY_SIZE(schemes));
- client_->SetPacketFilter(capture);
+ client_->SetFilter(capture);
EnableOnlyStaticRsaCiphers();
Connect();
@@ -540,9 +545,9 @@ TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) {
// Temporary test to verify that we choke on an empty ClientKeyShare.
// This test will fail when we implement HelloRetryRequest.
TEST_P(TlsExtensionTest13, EmptyClientKeyShare) {
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionTruncator>(ssl_tls13_key_share_xtn, 2),
- kTlsAlertHandshakeFailure);
+ ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>(
+ client_, ssl_tls13_key_share_xtn, 2),
+ kTlsAlertHandshakeFailure);
}
// These tests only work in stream mode because the client sends a
@@ -551,8 +556,8 @@ TEST_P(TlsExtensionTest13, EmptyClientKeyShare) {
// packet gets dropped.
TEST_F(TlsExtensionTest13Stream, DropServerKeyShare) {
EnsureTlsSetup();
- server_->SetPacketFilter(
- std::make_shared<TlsExtensionDropper>(ssl_tls13_key_share_xtn));
+ server_->SetFilter(
+ std::make_shared<TlsExtensionDropper>(server_, ssl_tls13_key_share_xtn));
client_->ExpectSendAlert(kTlsAlertMissingExtension);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
@@ -572,8 +577,8 @@ TEST_F(TlsExtensionTest13Stream, WrongServerKeyShare) {
0x02};
DataBuffer buf(key_share, sizeof(key_share));
EnsureTlsSetup();
- server_->SetPacketFilter(
- std::make_shared<TlsExtensionReplacer>(ssl_tls13_key_share_xtn, buf));
+ server_->SetFilter(std::make_shared<TlsExtensionReplacer>(
+ server_, ssl_tls13_key_share_xtn, buf));
client_->ExpectSendAlert(kTlsAlertIllegalParameter);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
@@ -594,8 +599,8 @@ TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) {
0x02};
DataBuffer buf(key_share, sizeof(key_share));
EnsureTlsSetup();
- server_->SetPacketFilter(
- std::make_shared<TlsExtensionReplacer>(ssl_tls13_key_share_xtn, buf));
+ server_->SetFilter(std::make_shared<TlsExtensionReplacer>(
+ server_, ssl_tls13_key_share_xtn, buf));
client_->ExpectSendAlert(kTlsAlertMissingExtension);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
@@ -606,8 +611,8 @@ TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) {
TEST_F(TlsExtensionTest13Stream, AddServerSignatureAlgorithmsOnResumption) {
SetupForResume();
DataBuffer empty;
- server_->SetPacketFilter(std::make_shared<TlsExtensionInjector>(
- ssl_signature_algorithms_xtn, empty));
+ server_->SetFilter(std::make_shared<TlsExtensionInjector>(
+ server_, ssl_signature_algorithms_xtn, empty));
client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
@@ -627,8 +632,12 @@ typedef std::function<void(TlsPreSharedKeyReplacer*)>
class TlsPreSharedKeyReplacer : public TlsExtensionFilter {
public:
- TlsPreSharedKeyReplacer(TlsPreSharedKeyReplacerFunc function)
- : identities_(), binders_(), function_(function) {}
+ TlsPreSharedKeyReplacer(const std::shared_ptr<TlsAgent>& agent,
+ TlsPreSharedKeyReplacerFunc function)
+ : TlsExtensionFilter(agent),
+ identities_(),
+ binders_(),
+ function_(function) {}
static size_t CopyAndMaybeReplace(TlsParser* parser, size_t size,
const std::unique_ptr<DataBuffer>& replace,
@@ -742,8 +751,10 @@ class TlsPreSharedKeyReplacer : public TlsExtensionFilter {
TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) {
SetupForResume();
- client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>([](
- TlsPreSharedKeyReplacer* r) { r->identities_[0].identity.Truncate(0); }));
+ client_->SetFilter(std::make_shared<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
+ r->identities_[0].identity.Truncate(0);
+ }));
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
@@ -753,8 +764,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) {
TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) {
SetupForResume();
- client_->SetPacketFilter(
- std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) {
+ client_->SetFilter(std::make_shared<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
r->binders_[0].Write(0, r->binders_[0].data()[0] ^ 0xff, 1);
}));
ConnectExpectAlert(server_, kTlsAlertDecryptError);
@@ -766,8 +777,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) {
TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) {
SetupForResume();
- client_->SetPacketFilter(
- std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) {
+ client_->SetFilter(std::make_shared<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
r->binders_[0].Write(r->binders_[0].len(), 0xff, 1);
}));
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
@@ -779,7 +790,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) {
TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) {
SetupForResume();
- client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>(
+ client_->SetFilter(std::make_shared<TlsPreSharedKeyReplacer>(
+ client_,
[](TlsPreSharedKeyReplacer* r) { r->binders_[0].Truncate(31); }));
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
@@ -791,8 +803,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) {
TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) {
SetupForResume();
- client_->SetPacketFilter(
- std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) {
+ client_->SetFilter(std::make_shared<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
r->identities_.push_back(r->identities_[0]);
r->binders_.push_back(r->binders_[0]);
}));
@@ -806,8 +818,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) {
TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) {
SetupForResume();
- client_->SetPacketFilter(
- std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) {
+ client_->SetFilter(std::make_shared<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
r->identities_.push_back(r->identities_[0]);
}));
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
@@ -818,8 +830,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) {
TEST_F(TlsExtensionTest13Stream, ResumeOneIdentityTwoBinders) {
SetupForResume();
- client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>([](
- TlsPreSharedKeyReplacer* r) { r->binders_.push_back(r->binders_[0]); }));
+ client_->SetFilter(std::make_shared<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
+ r->binders_.push_back(r->binders_[0]);
+ }));
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
@@ -831,8 +845,8 @@ TEST_F(TlsExtensionTest13Stream, ResumePskExtensionNotLast) {
const uint8_t empty_buf[] = {0};
DataBuffer empty(empty_buf, 0);
// Inject an unused extension after the PSK extension.
- client_->SetPacketFilter(std::make_shared<TlsExtensionAppender>(
- kTlsHandshakeClientHello, 0xffff, empty));
+ client_->SetFilter(std::make_shared<TlsExtensionAppender>(
+ client_, kTlsHandshakeClientHello, 0xffff, empty));
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
@@ -842,8 +856,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeNoKeModes) {
SetupForResume();
DataBuffer empty;
- client_->SetPacketFilter(std::make_shared<TlsExtensionDropper>(
- ssl_tls13_psk_key_exchange_modes_xtn));
+ client_->SetFilter(std::make_shared<TlsExtensionDropper>(
+ client_, ssl_tls13_psk_key_exchange_modes_xtn));
ConnectExpectAlert(server_, kTlsAlertMissingExtension);
client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
server_->CheckErrorCode(SSL_ERROR_MISSING_PSK_KEY_EXCHANGE_MODES);
@@ -858,8 +872,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) {
kTls13PskKe};
DataBuffer modes(ke_modes, sizeof(ke_modes));
- client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>(
- ssl_tls13_psk_key_exchange_modes_xtn, modes));
+ client_->SetFilter(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_tls13_psk_key_exchange_modes_xtn, modes));
client_->ExpectSendAlert(kTlsAlertBadRecordMac);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
@@ -870,8 +884,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) {
TEST_P(TlsExtensionTest13, NoKeModesIfResumptionOff) {
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
auto capture = std::make_shared<TlsExtensionCapture>(
- ssl_tls13_psk_key_exchange_modes_xtn);
- client_->SetPacketFilter(capture);
+ client_, ssl_tls13_psk_key_exchange_modes_xtn);
+ client_->SetFilter(capture);
Connect();
EXPECT_FALSE(capture->captured());
}
@@ -966,12 +980,11 @@ class TlsBogusExtensionTest : public TlsConnectTestBase,
void AddFilter(uint8_t message, uint16_t extension) {
static uint8_t empty_buf[1] = {0};
DataBuffer empty(empty_buf, 0);
- auto filter =
- std::make_shared<TlsExtensionAppender>(message, extension, empty);
+ auto filter = std::make_shared<TlsExtensionAppender>(server_, message,
+ extension, empty);
+ server_->SetFilter(filter);
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
- server_->SetTlsRecordFilter(filter);
- } else {
- server_->SetPacketFilter(filter);
+ filter->EnableDecryption();
}
}
@@ -1087,8 +1100,9 @@ TEST_P(TlsConnectStream, IncludePadding) {
SECStatus rv = SSL_SetURL(client_->ssl_fd(), long_name);
EXPECT_EQ(SECSuccess, rv);
- auto capture = std::make_shared<TlsExtensionCapture>(ssl_padding_xtn);
- client_->SetPacketFilter(capture);
+ auto capture =
+ std::make_shared<TlsExtensionCapture>(client_, ssl_padding_xtn);
+ client_->SetFilter(capture);
client_->StartConnect();
client_->Handshake();
EXPECT_TRUE(capture->captured());
diff --git a/gtests/ssl_gtest/ssl_fragment_unittest.cc b/gtests/ssl_gtest/ssl_fragment_unittest.cc
index 64b824786..f4940bf28 100644
--- a/gtests/ssl_gtest/ssl_fragment_unittest.cc
+++ b/gtests/ssl_gtest/ssl_fragment_unittest.cc
@@ -149,13 +149,13 @@ class RecordFragmenter : public PacketFilter {
};
TEST_P(TlsConnectDatagram, FragmentClientPackets) {
- client_->SetPacketFilter(std::make_shared<RecordFragmenter>());
+ client_->SetFilter(std::make_shared<RecordFragmenter>());
Connect();
SendReceive();
}
TEST_P(TlsConnectDatagram, FragmentServerPackets) {
- server_->SetPacketFilter(std::make_shared<RecordFragmenter>());
+ server_->SetFilter(std::make_shared<RecordFragmenter>());
Connect();
SendReceive();
}
diff --git a/gtests/ssl_gtest/ssl_fuzz_unittest.cc b/gtests/ssl_gtest/ssl_fuzz_unittest.cc
index ab4c0eab7..03c98d245 100644
--- a/gtests/ssl_gtest/ssl_fuzz_unittest.cc
+++ b/gtests/ssl_gtest/ssl_fuzz_unittest.cc
@@ -27,7 +27,8 @@ class TlsFuzzTest : public ::testing::Test {};
// Record the application data stream.
class TlsApplicationDataRecorder : public TlsRecordFilter {
public:
- TlsApplicationDataRecorder() : buffer_() {}
+ TlsApplicationDataRecorder(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent), buffer_() {}
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
@@ -106,16 +107,18 @@ FUZZ_P(TlsConnectGeneric, DeterministicTranscript) {
DisableECDHEServerKeyReuse();
DataBuffer buffer;
- client_->SetPacketFilter(std::make_shared<TlsConversationRecorder>(buffer));
- server_->SetPacketFilter(std::make_shared<TlsConversationRecorder>(buffer));
+ client_->SetFilter(
+ std::make_shared<TlsConversationRecorder>(client_, buffer));
+ server_->SetFilter(
+ std::make_shared<TlsConversationRecorder>(server_, buffer));
// Reset the RNG state.
EXPECT_EQ(SECSuccess, RNG_RandomUpdate(NULL, 0));
Connect();
// Ensure the filters go away before |buffer| does.
- client_->DeletePacketFilter();
- server_->DeletePacketFilter();
+ client_->ClearFilter();
+ server_->ClearFilter();
if (last.len() > 0) {
EXPECT_EQ(last, buffer);
@@ -133,10 +136,10 @@ FUZZ_P(TlsConnectGeneric, ConnectSendReceive_NullCipher) {
EnsureTlsSetup();
// Set up app data filters.
- auto client_recorder = std::make_shared<TlsApplicationDataRecorder>();
- client_->SetPacketFilter(client_recorder);
- auto server_recorder = std::make_shared<TlsApplicationDataRecorder>();
- server_->SetPacketFilter(server_recorder);
+ auto client_recorder = std::make_shared<TlsApplicationDataRecorder>(client_);
+ client_->SetFilter(client_recorder);
+ auto server_recorder = std::make_shared<TlsApplicationDataRecorder>(server_);
+ server_->SetFilter(server_recorder);
Connect();
@@ -161,10 +164,10 @@ FUZZ_P(TlsConnectGeneric, ConnectSendReceive_NullCipher) {
FUZZ_P(TlsConnectGeneric, BogusClientFinished) {
EnsureTlsSetup();
- auto i1 = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
- kTlsHandshakeFinished,
+ auto filter = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
+ client_, kTlsHandshakeFinished,
DataBuffer(kShortEmptyFinished, sizeof(kShortEmptyFinished)));
- client_->SetPacketFilter(i1);
+ client_->SetFilter(filter);
Connect();
SendReceive();
}
@@ -173,10 +176,10 @@ FUZZ_P(TlsConnectGeneric, BogusClientFinished) {
FUZZ_P(TlsConnectGeneric, BogusServerFinished) {
EnsureTlsSetup();
- auto i1 = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
- kTlsHandshakeFinished,
+ auto filter = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
+ server_, kTlsHandshakeFinished,
DataBuffer(kLongEmptyFinished, sizeof(kLongEmptyFinished)));
- server_->SetPacketFilter(i1);
+ server_->SetFilter(filter);
Connect();
SendReceive();
}
@@ -187,7 +190,7 @@ FUZZ_P(TlsConnectGeneric, BogusServerAuthSignature) {
uint8_t msg_type = version_ == SSL_LIBRARY_VERSION_TLS_1_3
? kTlsHandshakeCertificateVerify
: kTlsHandshakeServerKeyExchange;
- server_->SetPacketFilter(std::make_shared<TlsLastByteDamager>(msg_type));
+ server_->SetFilter(std::make_shared<TlsLastByteDamager>(server_, msg_type));
Connect();
SendReceive();
}
@@ -197,8 +200,8 @@ FUZZ_P(TlsConnectGeneric, BogusClientAuthSignature) {
EnsureTlsSetup();
client_->SetupClientAuth();
server_->RequestClientAuth(true);
- client_->SetPacketFilter(
- std::make_shared<TlsLastByteDamager>(kTlsHandshakeCertificateVerify));
+ client_->SetFilter(std::make_shared<TlsLastByteDamager>(
+ client_, kTlsHandshakeCertificateVerify));
Connect();
}
@@ -219,29 +222,29 @@ FUZZ_P(TlsConnectGeneric, SessionTicketResumption) {
FUZZ_P(TlsConnectGeneric, UnencryptedSessionTickets) {
ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET);
- auto i1 = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeNewSessionTicket);
- server_->SetPacketFilter(i1);
+ auto filter = std::make_shared<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeNewSessionTicket);
+ server_->SetFilter(filter);
Connect();
- std::cerr << "ticket" << i1->buffer() << std::endl;
+ std::cerr << "ticket" << filter->buffer() << std::endl;
size_t offset = 4; /* lifetime */
if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) {
offset += 4; /* ticket_age_add */
uint32_t nonce_len = 0;
- EXPECT_TRUE(i1->buffer().Read(offset, 1, &nonce_len));
+ EXPECT_TRUE(filter->buffer().Read(offset, 1, &nonce_len));
offset += 1 + nonce_len;
}
offset += 2 + /* ticket length */
2; /* TLS_EX_SESS_TICKET_VERSION */
// Check the protocol version number.
uint32_t tls_version = 0;
- EXPECT_TRUE(i1->buffer().Read(offset, sizeof(version_), &tls_version));
+ EXPECT_TRUE(filter->buffer().Read(offset, sizeof(version_), &tls_version));
EXPECT_EQ(version_, static_cast<decltype(version_)>(tls_version));
// Check the cipher suite.
uint32_t suite = 0;
- EXPECT_TRUE(i1->buffer().Read(offset + sizeof(version_), 2, &suite));
+ EXPECT_TRUE(filter->buffer().Read(offset + sizeof(version_), 2, &suite));
client_->CheckCipherSuite(static_cast<uint16_t>(suite));
}
}
diff --git a/gtests/ssl_gtest/ssl_hrr_unittest.cc b/gtests/ssl_gtest/ssl_hrr_unittest.cc
index 93e19a720..55b6d6d90 100644
--- a/gtests/ssl_gtest/ssl_hrr_unittest.cc
+++ b/gtests/ssl_gtest/ssl_hrr_unittest.cc
@@ -35,17 +35,17 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) {
// Send first ClientHello and send 0-RTT data
auto capture_early_data =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_early_data_xtn);
- client_->SetPacketFilter(capture_early_data);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_early_data_xtn);
+ client_->SetFilter(capture_early_data);
client_->Handshake();
EXPECT_EQ(k0RttDataLen, PR_Write(client_->ssl_fd(), k0RttData,
k0RttDataLen)); // 0-RTT write.
EXPECT_TRUE(capture_early_data->captured());
// Send the HelloRetryRequest
- auto hrr_capture = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeHelloRetryRequest);
- server_->SetPacketFilter(hrr_capture);
+ auto hrr_capture = std::make_shared<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeHelloRetryRequest);
+ server_->SetFilter(hrr_capture);
server_->Handshake();
EXPECT_LT(0U, hrr_capture->buffer().len());
@@ -56,8 +56,8 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) {
// Make a new capture for the early data.
capture_early_data =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_early_data_xtn);
- client_->SetPacketFilter(capture_early_data);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_early_data_xtn);
+ client_->SetFilter(capture_early_data);
// Complete the handshake successfully
Handshake();
@@ -71,6 +71,10 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) {
// packet. If the record is split into two packets, or there are multiple
// handshake packets, this will break.
class CorrectMessageSeqAfterHrrFilter : public TlsRecordFilter {
+ public:
+ CorrectMessageSeqAfterHrrFilter(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent) {}
+
protected:
PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& record, size_t* offset,
@@ -131,8 +135,8 @@ TEST_P(TlsConnectTls13, SecondClientHelloRejectEarlyDataXtn) {
// Correct the DTLS message sequence number after an HRR.
if (variant_ == ssl_variant_datagram) {
- client_->SetPacketFilter(
- std::make_shared<CorrectMessageSeqAfterHrrFilter>());
+ client_->SetFilter(
+ std::make_shared<CorrectMessageSeqAfterHrrFilter>(client_));
}
server_->SetPeer(client_);
@@ -151,7 +155,8 @@ TEST_P(TlsConnectTls13, SecondClientHelloRejectEarlyDataXtn) {
class KeyShareReplayer : public TlsExtensionFilter {
public:
- KeyShareReplayer() {}
+ KeyShareReplayer(const std::shared_ptr<TlsAgent>& agent)
+ : TlsExtensionFilter(agent) {}
virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
const DataBuffer& input,
@@ -178,7 +183,7 @@ class KeyShareReplayer : public TlsExtensionFilter {
// server should reject this.
TEST_P(TlsConnectTls13, RetryWithSameKeyShare) {
EnsureTlsSetup();
- client_->SetPacketFilter(std::make_shared<KeyShareReplayer>());
+ client_->SetFilter(std::make_shared<KeyShareReplayer>(client_));
static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
ssl_grp_ec_secp521r1};
server_->ConfigNamedGroups(groups);
@@ -192,7 +197,7 @@ TEST_P(TlsConnectTls13, RetryWithSameKeyShare) {
TEST_P(TlsConnectTls13, RetryWithTwoShares) {
EnsureTlsSetup();
EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
- client_->SetPacketFilter(std::make_shared<KeyShareReplayer>());
+ client_->SetFilter(std::make_shared<KeyShareReplayer>(client_));
static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
ssl_grp_ec_secp521r1};
@@ -238,9 +243,10 @@ TEST_P(TlsConnectTls13, RetryCallbackAcceptGroupMismatch) {
return ssl_hello_retry_accept;
};
- auto capture = std::make_shared<TlsExtensionCapture>(ssl_tls13_cookie_xtn);
+ auto capture =
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_cookie_xtn);
capture->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
- server_->SetPacketFilter(capture);
+ server_->SetFilter(capture);
static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
server_->ConfigNamedGroups(groups);
@@ -359,14 +365,14 @@ SSLHelloRetryRequestAction RetryHello(PRBool firstHello,
TEST_P(TlsConnectTls13, RetryCallbackRetry) {
EnsureTlsSetup();
- auto capture_hrr = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- ssl_hs_hello_retry_request);
+ auto capture_hrr = std::make_shared<TlsHandshakeRecorder>(
+ server_, ssl_hs_hello_retry_request);
auto capture_key_share =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
std::vector<std::shared_ptr<PacketFilter>> chain = {capture_hrr,
capture_key_share};
- server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(chain));
+ server_->SetFilter(std::make_shared<ChainedPacketFilter>(chain));
size_t cb_called = 0;
EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
@@ -383,8 +389,8 @@ TEST_P(TlsConnectTls13, RetryCallbackRetry) {
<< "no key_share extension expected";
auto capture_cookie =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_cookie_xtn);
- client_->SetPacketFilter(capture_cookie);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_cookie_xtn);
+ client_->SetFilter(capture_cookie);
Handshake();
CheckConnected();
@@ -413,9 +419,9 @@ TEST_P(TlsConnectTls13, RetryCallbackRetryWithAdditionalShares) {
EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
auto capture_server =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
capture_server->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
- server_->SetPacketFilter(capture_server);
+ server_->SetFilter(capture_server);
size_t cb_called = 0;
EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
@@ -431,8 +437,8 @@ TEST_P(TlsConnectTls13, RetryCallbackRetryWithAdditionalShares) {
<< "no key_share extension expected from server";
auto capture_client_2nd =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
- client_->SetPacketFilter(capture_client_2nd);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn);
+ client_->SetFilter(capture_client_2nd);
Handshake();
CheckConnected();
@@ -449,12 +455,12 @@ TEST_P(TlsConnectTls13, RetryCallbackRetryWithGroupMismatch) {
EnsureTlsSetup();
auto capture_cookie =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_cookie_xtn);
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_cookie_xtn);
capture_cookie->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
auto capture_key_share =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
- server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(
+ server_->SetFilter(std::make_shared<ChainedPacketFilter>(
ChainedPacketFilterInit{capture_cookie, capture_key_share}));
static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
@@ -493,9 +499,9 @@ TEST_P(TlsConnectTls13, RetryCallbackRetryWithToken) {
EnsureTlsSetup();
auto capture_key_share =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
- server_->SetPacketFilter(capture_key_share);
+ server_->SetFilter(capture_key_share);
size_t cb_called = 0;
EXPECT_EQ(SECSuccess,
@@ -513,9 +519,9 @@ TEST_P(TlsConnectTls13, RetryCallbackRetryWithTokenAndGroupMismatch) {
server_->ConfigNamedGroups(groups);
auto capture_key_share =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
- server_->SetPacketFilter(capture_key_share);
+ server_->SetFilter(capture_key_share);
size_t cb_called = 0;
EXPECT_EQ(SECSuccess,
@@ -589,8 +595,8 @@ TEST_P(TlsConnectTls13, RetryStatefulDropCookie) {
EnsureTlsSetup();
TriggerHelloRetryRequest(client_, server_);
- client_->SetPacketFilter(
- std::make_shared<TlsExtensionDropper>(ssl_tls13_cookie_xtn));
+ client_->SetFilter(
+ std::make_shared<TlsExtensionDropper>(client_, ssl_tls13_cookie_xtn));
ExpectAlert(server_, kTlsAlertMissingExtension);
Handshake();
@@ -603,8 +609,9 @@ TEST_F(TlsConnectStreamTls13, RetryStatelessDamageFirstClientHello) {
ConfigureSelfEncrypt();
EnsureTlsSetup();
- auto damage_ch = std::make_shared<TlsExtensionInjector>(0xfff3, DataBuffer());
- client_->SetPacketFilter(damage_ch);
+ auto damage_ch =
+ std::make_shared<TlsExtensionInjector>(client_, 0xfff3, DataBuffer());
+ client_->SetFilter(damage_ch);
TriggerHelloRetryRequest(client_, server_);
MakeNewServer();
@@ -625,8 +632,9 @@ TEST_F(TlsConnectStreamTls13, RetryStatelessDamageSecondClientHello) {
TriggerHelloRetryRequest(client_, server_);
MakeNewServer();
- auto damage_ch = std::make_shared<TlsExtensionInjector>(0xfff3, DataBuffer());
- client_->SetPacketFilter(damage_ch);
+ auto damage_ch =
+ std::make_shared<TlsExtensionInjector>(client_, 0xfff3, DataBuffer());
+ client_->SetFilter(damage_ch);
// Key exchange fails when the handshake continues because client and server
// disagree about the transcript.
@@ -640,7 +648,7 @@ TEST_F(TlsConnectStreamTls13, RetryStatelessDamageSecondClientHello) {
// Read the cipher suite from the HRR and disable it on the identified agent.
static void DisableSuiteFromHrr(
std::shared_ptr<TlsAgent>& agent,
- std::shared_ptr<TlsInspectorRecordHandshakeMessage>& capture_hrr) {
+ std::shared_ptr<TlsHandshakeRecorder>& capture_hrr) {
uint32_t tmp;
size_t offset = 2 + 32; // skip version + server_random
ASSERT_TRUE(
@@ -657,9 +665,9 @@ TEST_P(TlsConnectTls13, RetryStatelessDisableSuiteClient) {
ConfigureSelfEncrypt();
EnsureTlsSetup();
- auto capture_hrr = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- ssl_hs_hello_retry_request);
- server_->SetPacketFilter(capture_hrr);
+ auto capture_hrr = std::make_shared<TlsHandshakeRecorder>(
+ server_, ssl_hs_hello_retry_request);
+ server_->SetFilter(capture_hrr);
TriggerHelloRetryRequest(client_, server_);
MakeNewServer();
@@ -678,9 +686,9 @@ TEST_P(TlsConnectTls13, RetryStatelessDisableSuiteServer) {
ConfigureSelfEncrypt();
EnsureTlsSetup();
- auto capture_hrr = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- ssl_hs_hello_retry_request);
- server_->SetPacketFilter(capture_hrr);
+ auto capture_hrr = std::make_shared<TlsHandshakeRecorder>(
+ server_, ssl_hs_hello_retry_request);
+ server_->SetFilter(capture_hrr);
TriggerHelloRetryRequest(client_, server_);
MakeNewServer();
@@ -761,8 +769,8 @@ TEST_F(TlsConnectStreamTls13, RetryWithDifferentCipherSuite) {
static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
server_->ConfigNamedGroups(groups);
// Then switch out the default suite (TLS_AES_128_GCM_SHA256).
- server_->SetPacketFilter(std::make_shared<SelectedCipherSuiteReplacer>(
- TLS_CHACHA20_POLY1305_SHA256));
+ server_->SetFilter(std::make_shared<SelectedCipherSuiteReplacer>(
+ server_, TLS_CHACHA20_POLY1305_SHA256));
client_->ExpectSendAlert(kTlsAlertIllegalParameter);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
@@ -777,7 +785,7 @@ TEST_F(TlsConnectDatagram13, DropClientSecondFlightWithHelloRetry) {
static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
ssl_grp_ec_secp521r1};
server_->ConfigNamedGroups(groups);
- server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x2));
+ server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x2));
Connect();
}
@@ -833,9 +841,9 @@ TEST_P(TlsKeyExchange13,
EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
auto capture_server =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
capture_server->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
- server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(
+ server_->SetFilter(std::make_shared<ChainedPacketFilter>(
ChainedPacketFilterInit{capture_hrr_, capture_server}));
size_t cb_called = 0;
diff --git a/gtests/ssl_gtest/ssl_keylog_unittest.cc b/gtests/ssl_gtest/ssl_keylog_unittest.cc
index 8ed342305..322b64837 100644
--- a/gtests/ssl_gtest/ssl_keylog_unittest.cc
+++ b/gtests/ssl_gtest/ssl_keylog_unittest.cc
@@ -20,8 +20,8 @@ static const std::string keylog_env = "SSLKEYLOGFILE=" + keylog_file_path;
class KeyLogFileTest : public TlsConnectGeneric {
public:
- void SetUp() {
- TlsConnectTestBase::SetUp();
+ void SetUp() override {
+ TlsConnectGeneric::SetUp();
// Remove previous results (if any).
(void)remove(keylog_file_path.c_str());
PR_SetEnv(keylog_env.c_str());
diff --git a/gtests/ssl_gtest/ssl_loopback_unittest.cc b/gtests/ssl_gtest/ssl_loopback_unittest.cc
index f1a789367..227c06234 100644
--- a/gtests/ssl_gtest/ssl_loopback_unittest.cc
+++ b/gtests/ssl_gtest/ssl_loopback_unittest.cc
@@ -56,7 +56,8 @@ TEST_P(TlsConnectGeneric, CipherSuiteMismatch) {
class TlsAlertRecorder : public TlsRecordFilter {
public:
- TlsAlertRecorder() : level_(255), description_(255) {}
+ TlsAlertRecorder(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent), level_(255), description_(255) {}
PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
@@ -86,9 +87,9 @@ class TlsAlertRecorder : public TlsRecordFilter {
class HelloTruncator : public TlsHandshakeFilter {
public:
- HelloTruncator()
+ HelloTruncator(const std::shared_ptr<TlsAgent>& agent)
: TlsHandshakeFilter(
- {kTlsHandshakeClientHello, kTlsHandshakeServerHello}) {}
+ agent, {kTlsHandshakeClientHello, kTlsHandshakeServerHello}) {}
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output) override {
@@ -99,9 +100,9 @@ class HelloTruncator : public TlsHandshakeFilter {
// Verify that when NSS reports that an alert is sent, it is actually sent.
TEST_P(TlsConnectGeneric, CaptureAlertServer) {
- client_->SetPacketFilter(std::make_shared<HelloTruncator>());
- auto alert_recorder = std::make_shared<TlsAlertRecorder>();
- server_->SetPacketFilter(alert_recorder);
+ client_->SetFilter(std::make_shared<HelloTruncator>(client_));
+ auto alert_recorder = std::make_shared<TlsAlertRecorder>(server_);
+ server_->SetFilter(alert_recorder);
ConnectExpectAlert(server_, kTlsAlertDecodeError);
EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
@@ -109,9 +110,9 @@ TEST_P(TlsConnectGeneric, CaptureAlertServer) {
}
TEST_P(TlsConnectGenericPre13, CaptureAlertClient) {
- server_->SetPacketFilter(std::make_shared<HelloTruncator>());
- auto alert_recorder = std::make_shared<TlsAlertRecorder>();
- client_->SetPacketFilter(alert_recorder);
+ server_->SetFilter(std::make_shared<HelloTruncator>(server_));
+ auto alert_recorder = std::make_shared<TlsAlertRecorder>(client_);
+ client_->SetFilter(alert_recorder);
ConnectExpectAlert(client_, kTlsAlertDecodeError);
EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
@@ -120,9 +121,9 @@ TEST_P(TlsConnectGenericPre13, CaptureAlertClient) {
// In TLS 1.3, the server can't read the client alert.
TEST_P(TlsConnectTls13, CaptureAlertClient) {
- server_->SetPacketFilter(std::make_shared<HelloTruncator>());
- auto alert_recorder = std::make_shared<TlsAlertRecorder>();
- client_->SetPacketFilter(alert_recorder);
+ server_->SetFilter(std::make_shared<HelloTruncator>(server_));
+ auto alert_recorder = std::make_shared<TlsAlertRecorder>(client_);
+ client_->SetFilter(alert_recorder);
StartConnect();
@@ -173,7 +174,8 @@ TEST_P(TlsConnectGeneric, ConnectSendReceive) {
class SaveTlsRecord : public TlsRecordFilter {
public:
- SaveTlsRecord(size_t index) : index_(index), count_(0), contents_() {}
+ SaveTlsRecord(const std::shared_ptr<TlsAgent>& agent, size_t index)
+ : TlsRecordFilter(agent), index_(index), count_(0), contents_() {}
const DataBuffer& contents() const { return contents_; }
@@ -198,8 +200,9 @@ class SaveTlsRecord : public TlsRecordFilter {
TEST_F(TlsConnectStreamTls13, DecryptRecordClient) {
EnsureTlsSetup();
// 0 = ClientHello, 1 = Finished, 2 = SendReceive, 3 = SendBuffer
- auto saved = std::make_shared<SaveTlsRecord>(3);
- client_->SetTlsRecordFilter(saved);
+ auto saved = std::make_shared<SaveTlsRecord>(client_, 3);
+ saved->EnableDecryption();
+ client_->SetFilter(saved);
Connect();
SendReceive();
@@ -215,8 +218,9 @@ TEST_F(TlsConnectStreamTls13, DecryptRecordServer) {
EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(),
SSL_ENABLE_SESSION_TICKETS, PR_FALSE));
// 0 = ServerHello, 1 = other handshake, 2 = SendReceive, 3 = SendBuffer
- auto saved = std::make_shared<SaveTlsRecord>(3);
- server_->SetTlsRecordFilter(saved);
+ auto saved = std::make_shared<SaveTlsRecord>(server_, 3);
+ saved->EnableDecryption();
+ server_->SetFilter(saved);
Connect();
SendReceive();
@@ -228,7 +232,8 @@ TEST_F(TlsConnectStreamTls13, DecryptRecordServer) {
class DropTlsRecord : public TlsRecordFilter {
public:
- DropTlsRecord(size_t index) : index_(index), count_(0) {}
+ DropTlsRecord(const std::shared_ptr<TlsAgent>& agent, size_t index)
+ : TlsRecordFilter(agent), index_(index), count_(0) {}
protected:
PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
@@ -253,7 +258,9 @@ TEST_F(TlsConnectStreamTls13, DropRecordServer) {
SSL_ENABLE_SESSION_TICKETS, PR_FALSE));
// 0 = ServerHello, 1 = other handshake, 2 = first write
- server_->SetTlsRecordFilter(std::make_shared<DropTlsRecord>(2));
+ auto filter = std::make_shared<DropTlsRecord>(server_, 2);
+ filter->EnableDecryption();
+ server_->SetFilter(filter);
Connect();
server_->SendData(23, 23); // This should be dropped, so it won't be counted.
server_->ResetSentBytes();
@@ -263,7 +270,9 @@ TEST_F(TlsConnectStreamTls13, DropRecordServer) {
TEST_F(TlsConnectStreamTls13, DropRecordClient) {
EnsureTlsSetup();
// 0 = ClientHello, 1 = Finished, 2 = first write
- client_->SetTlsRecordFilter(std::make_shared<DropTlsRecord>(2));
+ auto filter = std::make_shared<DropTlsRecord>(client_, 2);
+ filter->EnableDecryption();
+ client_->SetFilter(filter);
Connect();
client_->SendData(26, 26); // This should be dropped, so it won't be counted.
client_->ResetSentBytes();
@@ -371,7 +380,8 @@ TEST_P(TlsHolddownTest, TestDtlsHolddownExpiryResumption) {
class TlsPreCCSHeaderInjector : public TlsRecordFilter {
public:
- TlsPreCCSHeaderInjector() {}
+ TlsPreCCSHeaderInjector(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent) {}
virtual PacketFilter::Action FilterRecord(
const TlsRecordHeader& record_header, const DataBuffer& input,
size_t* offset, DataBuffer* output) override {
@@ -388,14 +398,14 @@ class TlsPreCCSHeaderInjector : public TlsRecordFilter {
};
TEST_P(TlsConnectStreamPre13, ClientFinishedHeaderBeforeCCS) {
- client_->SetPacketFilter(std::make_shared<TlsPreCCSHeaderInjector>());
+ client_->SetFilter(std::make_shared<TlsPreCCSHeaderInjector>(client_));
ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage);
client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER);
}
TEST_P(TlsConnectStreamPre13, ServerFinishedHeaderBeforeCCS) {
- server_->SetPacketFilter(std::make_shared<TlsPreCCSHeaderInjector>());
+ server_->SetFilter(std::make_shared<TlsPreCCSHeaderInjector>(server_));
StartConnect();
ExpectAlert(client_, kTlsAlertUnexpectedMessage);
Handshake();
@@ -476,8 +486,8 @@ TEST_F(TlsConnectTest, OneNRecordSplitting) {
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_0);
EnsureTlsSetup();
ConnectWithCipherSuite(TLS_RSA_WITH_AES_128_CBC_SHA);
- auto records = std::make_shared<TlsRecordRecorder>();
- server_->SetPacketFilter(records);
+ auto records = std::make_shared<TlsRecordRecorder>(server_);
+ server_->SetFilter(records);
// This should be split into 1, 16384 and 20.
DataBuffer big_buffer;
big_buffer.Allocate(1 + 16384 + 20);
diff --git a/gtests/ssl_gtest/ssl_record_unittest.cc b/gtests/ssl_gtest/ssl_record_unittest.cc
index d1d496f49..5aab1b352 100644
--- a/gtests/ssl_gtest/ssl_record_unittest.cc
+++ b/gtests/ssl_gtest/ssl_record_unittest.cc
@@ -103,8 +103,8 @@ TEST_P(TlsPaddingTest, LastByteOfPadWrong) {
class RecordReplacer : public TlsRecordFilter {
public:
- RecordReplacer(size_t size)
- : TlsRecordFilter(), enabled_(false), size_(size) {}
+ RecordReplacer(const std::shared_ptr<TlsAgent>& agent, size_t size)
+ : TlsRecordFilter(agent), enabled_(false), size_(size) {}
PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& data,
@@ -135,8 +135,9 @@ TEST_F(TlsConnectStreamTls13, LargeRecord) {
EnsureTlsSetup();
const size_t record_limit = 16384;
- auto replacer = std::make_shared<RecordReplacer>(record_limit);
- client_->SetTlsRecordFilter(replacer);
+ auto replacer = std::make_shared<RecordReplacer>(client_, record_limit);
+ replacer->EnableDecryption();
+ client_->SetFilter(replacer);
Connect();
replacer->Enable();
@@ -149,8 +150,9 @@ TEST_F(TlsConnectStreamTls13, TooLargeRecord) {
EnsureTlsSetup();
const size_t record_limit = 16384;
- auto replacer = std::make_shared<RecordReplacer>(record_limit + 1);
- client_->SetTlsRecordFilter(replacer);
+ auto replacer = std::make_shared<RecordReplacer>(client_, record_limit + 1);
+ replacer->EnableDecryption();
+ client_->SetFilter(replacer);
Connect();
replacer->Enable();
@@ -177,4 +179,4 @@ auto kTrueFalse = ::testing::ValuesIn(kTrueFalseArr);
INSTANTIATE_TEST_CASE_P(TlsPadding, TlsPaddingTest,
::testing::Combine(kContentSizes, kTrueFalse));
-} // namespace nspr_test
+} // namespace nss_test
diff --git a/gtests/ssl_gtest/ssl_resumption_unittest.cc b/gtests/ssl_gtest/ssl_resumption_unittest.cc
index 4f3a98ad4..9a80f00e6 100644
--- a/gtests/ssl_gtest/ssl_resumption_unittest.cc
+++ b/gtests/ssl_gtest/ssl_resumption_unittest.cc
@@ -219,8 +219,8 @@ TEST_P(TlsConnectGenericResumption, ConnectWithExpiredTicketAtClient) {
SSLExtensionType xtn = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3)
? ssl_tls13_pre_shared_key_xtn
: ssl_session_ticket_xtn;
- auto capture = std::make_shared<TlsExtensionCapture>(xtn);
- client_->SetPacketFilter(capture);
+ auto capture = std::make_shared<TlsExtensionCapture>(client_, xtn);
+ client_->SetFilter(capture);
Connect();
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
@@ -245,8 +245,8 @@ TEST_P(TlsConnectGeneric, ConnectWithExpiredTicketAtServer) {
SSLExtensionType xtn = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3)
? ssl_tls13_pre_shared_key_xtn
: ssl_session_ticket_xtn;
- auto capture = std::make_shared<TlsExtensionCapture>(xtn);
- client_->SetPacketFilter(capture);
+ auto capture = std::make_shared<TlsExtensionCapture>(client_, xtn);
+ client_->SetFilter(capture);
StartConnect();
client_->Handshake();
EXPECT_TRUE(capture->captured());
@@ -327,25 +327,25 @@ TEST_P(TlsConnectGeneric, ServerSNICertTypeSwitch) {
// Prior to TLS 1.3, we were not fully ephemeral; though 1.3 fixes that
TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceReuseKey) {
- auto i1 = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeServerKeyExchange);
- server_->SetPacketFilter(i1);
+ auto filter = std::make_shared<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
+ server_->SetFilter(filter);
Connect();
CheckKeys();
TlsServerKeyExchangeEcdhe dhe1;
- EXPECT_TRUE(dhe1.Parse(i1->buffer()));
+ EXPECT_TRUE(dhe1.Parse(filter->buffer()));
// Restart
Reset();
- auto i2 = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeServerKeyExchange);
- server_->SetPacketFilter(i2);
+ auto filter2 = std::make_shared<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
+ server_->SetFilter(filter2);
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
Connect();
CheckKeys();
TlsServerKeyExchangeEcdhe dhe2;
- EXPECT_TRUE(dhe2.Parse(i2->buffer()));
+ EXPECT_TRUE(dhe2.Parse(filter2->buffer()));
// Make sure they are the same.
EXPECT_EQ(dhe1.public_key_.len(), dhe2.public_key_.len());
@@ -356,26 +356,26 @@ TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceReuseKey) {
// This test parses the ServerKeyExchange, which isn't in 1.3
TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceNewKey) {
server_->SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
- auto i1 = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeServerKeyExchange);
- server_->SetPacketFilter(i1);
+ auto filter = std::make_shared<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
+ server_->SetFilter(filter);
Connect();
CheckKeys();
TlsServerKeyExchangeEcdhe dhe1;
- EXPECT_TRUE(dhe1.Parse(i1->buffer()));
+ EXPECT_TRUE(dhe1.Parse(filter->buffer()));
// Restart
Reset();
server_->SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
- auto i2 = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeServerKeyExchange);
- server_->SetPacketFilter(i2);
+ auto filter2 = std::make_shared<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
+ server_->SetFilter(filter2);
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
Connect();
CheckKeys();
TlsServerKeyExchangeEcdhe dhe2;
- EXPECT_TRUE(dhe2.Parse(i2->buffer()));
+ EXPECT_TRUE(dhe2.Parse(filter2->buffer()));
// Make sure they are different.
EXPECT_FALSE((dhe1.public_key_.len() == dhe2.public_key_.len()) &&
@@ -434,8 +434,9 @@ TEST_P(TlsConnectGenericResumption, TestResumeClientDifferentCipher) {
} else {
ticket_extension = ssl_session_ticket_xtn;
}
- auto ticket_capture = std::make_shared<TlsExtensionCapture>(ticket_extension);
- client_->SetPacketFilter(ticket_capture);
+ auto ticket_capture =
+ std::make_shared<TlsExtensionCapture>(client_, ticket_extension);
+ client_->SetFilter(ticket_capture);
Connect();
CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
EXPECT_EQ(0U, ticket_capture->extension().len());
@@ -468,8 +469,8 @@ TEST_P(TlsConnectStream, TestResumptionOverrideCipher) {
Reset();
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
- server_->SetPacketFilter(std::make_shared<SelectedCipherSuiteReplacer>(
- ChooseAnotherCipher(version_)));
+ server_->SetFilter(std::make_shared<SelectedCipherSuiteReplacer>(
+ server_, ChooseAnotherCipher(version_)));
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
client_->ExpectSendAlert(kTlsAlertIllegalParameter);
@@ -490,8 +491,10 @@ TEST_P(TlsConnectStream, TestResumptionOverrideCipher) {
class SelectedVersionReplacer : public TlsHandshakeFilter {
public:
- SelectedVersionReplacer(uint16_t version)
- : TlsHandshakeFilter({kTlsHandshakeServerHello}), version_(version) {}
+ SelectedVersionReplacer(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t version)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeServerHello}),
+ version_(version) {}
protected:
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
@@ -543,8 +546,8 @@ TEST_P(TlsConnectGenericPre13, TestResumptionOverrideVersion) {
// Enable the lower version on the client.
client_->SetVersionRange(version_ - 1, version_);
server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA);
- server_->SetPacketFilter(
- std::make_shared<SelectedVersionReplacer>(override_version));
+ server_->SetFilter(
+ std::make_shared<SelectedVersionReplacer>(server_, override_version));
ConnectExpectAlert(client_, kTlsAlertHandshakeFailure);
client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
@@ -567,8 +570,9 @@ TEST_F(TlsConnectTest, TestTls13ResumptionTwice) {
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
ExpectResumption(RESUME_TICKET);
- auto c1 = std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
- client_->SetPacketFilter(c1);
+ auto c1 = std::make_shared<TlsExtensionCapture>(client_,
+ ssl_tls13_pre_shared_key_xtn);
+ client_->SetFilter(c1);
Connect();
SendReceive();
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
@@ -584,8 +588,9 @@ TEST_F(TlsConnectTest, TestTls13ResumptionTwice) {
ClearStats();
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
- auto c2 = std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
- client_->SetPacketFilter(c2);
+ auto c2 = std::make_shared<TlsExtensionCapture>(client_,
+ ssl_tls13_pre_shared_key_xtn);
+ client_->SetFilter(c2);
ExpectResumption(RESUME_TICKET);
Connect();
SendReceive();
@@ -656,9 +661,10 @@ TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNSTWithToken) {
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
- auto nst_capture = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- ssl_hs_new_session_ticket);
- server_->SetTlsRecordFilter(nst_capture);
+ auto nst_capture = std::make_shared<TlsHandshakeRecorder>(
+ server_, ssl_hs_new_session_ticket);
+ nst_capture->EnableDecryption();
+ server_->SetFilter(nst_capture);
Connect();
// Clear the session ticket keys to invalidate the old ticket.
@@ -678,9 +684,9 @@ TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNSTWithToken) {
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
ExpectResumption(RESUME_TICKET);
- auto psk_capture =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
- client_->SetPacketFilter(psk_capture);
+ auto psk_capture = std::make_shared<TlsExtensionCapture>(
+ client_, ssl_tls13_pre_shared_key_xtn);
+ client_->SetFilter(psk_capture);
Connect();
SendReceive();
@@ -696,9 +702,10 @@ TEST_F(TlsConnectTest, SendSessionTicketWithTicketsDisabled) {
EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(),
SSL_ENABLE_SESSION_TICKETS, PR_FALSE));
- auto nst_capture = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- ssl_hs_new_session_ticket);
- server_->SetTlsRecordFilter(nst_capture);
+ auto nst_capture = std::make_shared<TlsHandshakeRecorder>(
+ server_, ssl_hs_new_session_ticket);
+ nst_capture->EnableDecryption();
+ server_->SetFilter(nst_capture);
Connect();
EXPECT_EQ(0U, nst_capture->buffer().len()) << "expect nothing captured yet";
@@ -714,9 +721,9 @@ TEST_F(TlsConnectTest, SendSessionTicketWithTicketsDisabled) {
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
ExpectResumption(RESUME_TICKET);
- auto psk_capture =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
- client_->SetPacketFilter(psk_capture);
+ auto psk_capture = std::make_shared<TlsExtensionCapture>(
+ client_, ssl_tls13_pre_shared_key_xtn);
+ client_->SetFilter(psk_capture);
Connect();
SendReceive();
@@ -819,20 +826,20 @@ TEST_F(TlsConnectTest, TestTls13ResumptionForcedDowngrade) {
// We will eventually fail the (sid.version == SH.version) check.
std::vector<std::shared_ptr<PacketFilter>> filters;
filters.push_back(std::make_shared<SelectedCipherSuiteReplacer>(
- TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256));
- filters.push_back(
- std::make_shared<SelectedVersionReplacer>(SSL_LIBRARY_VERSION_TLS_1_2));
+ server_, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256));
+ filters.push_back(std::make_shared<SelectedVersionReplacer>(
+ server_, SSL_LIBRARY_VERSION_TLS_1_2));
// Drop a bunch of extensions so that we get past the SH processing. The
// version extension says TLS 1.3, which is counter to our goal, the others
// are not permitted in TLS 1.2 handshakes.
+ filters.push_back(std::make_shared<TlsExtensionDropper>(
+ server_, ssl_tls13_supported_versions_xtn));
filters.push_back(
- std::make_shared<TlsExtensionDropper>(ssl_tls13_supported_versions_xtn));
- filters.push_back(
- std::make_shared<TlsExtensionDropper>(ssl_tls13_key_share_xtn));
- filters.push_back(
- std::make_shared<TlsExtensionDropper>(ssl_tls13_pre_shared_key_xtn));
- server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(filters));
+ std::make_shared<TlsExtensionDropper>(server_, ssl_tls13_key_share_xtn));
+ filters.push_back(std::make_shared<TlsExtensionDropper>(
+ server_, ssl_tls13_pre_shared_key_xtn));
+ server_->SetFilter(std::make_shared<ChainedPacketFilter>(filters));
// The client here generates an unexpected_message alert when it receives an
// encrypted handshake message from the server (EncryptedExtension). The
diff --git a/gtests/ssl_gtest/ssl_skip_unittest.cc b/gtests/ssl_gtest/ssl_skip_unittest.cc
index 335bfecfa..e4a9e5aed 100644
--- a/gtests/ssl_gtest/ssl_skip_unittest.cc
+++ b/gtests/ssl_gtest/ssl_skip_unittest.cc
@@ -22,8 +22,11 @@ namespace nss_test {
class TlsHandshakeSkipFilter : public TlsRecordFilter {
public:
// A TLS record filter that skips handshake messages of the identified type.
- TlsHandshakeSkipFilter(uint8_t handshake_type)
- : handshake_type_(handshake_type), skipped_(false) {}
+ TlsHandshakeSkipFilter(const std::shared_ptr<TlsAgent>& agent,
+ uint8_t handshake_type)
+ : TlsRecordFilter(agent),
+ handshake_type_(handshake_type),
+ skipped_(false) {}
protected:
// Takes a record; if it is a handshake record, it removes the first handshake
@@ -92,9 +95,14 @@ class TlsSkipTest : public TlsConnectTestBase,
TlsSkipTest()
: TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+ void SetUp() override {
+ TlsConnectTestBase::SetUp();
+ EnsureTlsSetup();
+ }
+
void ServerSkipTest(std::shared_ptr<PacketFilter> filter,
uint8_t alert = kTlsAlertUnexpectedMessage) {
- server_->SetPacketFilter(filter);
+ server_->SetFilter(filter);
ConnectExpectAlert(client_, alert);
}
};
@@ -105,9 +113,14 @@ class Tls13SkipTest : public TlsConnectTestBase,
Tls13SkipTest()
: TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}
- void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
+ void SetUp() override {
+ TlsConnectTestBase::SetUp();
EnsureTlsSetup();
- server_->SetTlsRecordFilter(filter);
+ }
+
+ void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
+ filter->EnableDecryption();
+ server_->SetFilter(filter);
ExpectAlert(client_, kTlsAlertUnexpectedMessage);
ConnectExpectFail();
client_->CheckErrorCode(error);
@@ -115,8 +128,8 @@ class Tls13SkipTest : public TlsConnectTestBase,
}
void ClientSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
- EnsureTlsSetup();
- client_->SetTlsRecordFilter(filter);
+ filter->EnableDecryption();
+ client_->SetFilter(filter);
server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
ConnectExpectFailOneSide(TlsAgent::SERVER);
@@ -129,48 +142,49 @@ class Tls13SkipTest : public TlsConnectTestBase,
TEST_P(TlsSkipTest, SkipCertificateRsa) {
EnableOnlyStaticRsaCiphers();
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipCertificateDhe) {
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
}
TEST_P(TlsSkipTest, SkipCertificateEcdhe) {
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
}
TEST_P(TlsSkipTest, SkipCertificateEcdsa) {
Reset(TlsAgent::kServerEcdsa256);
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
}
TEST_P(TlsSkipTest, SkipServerKeyExchange) {
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) {
Reset(TlsAgent::kServerEcdsa256);
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipCertAndKeyExch) {
- auto chain = std::make_shared<ChainedPacketFilter>(ChainedPacketFilterInit{
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate),
- std::make_shared<TlsHandshakeSkipFilter>(
- kTlsHandshakeServerKeyExchange)});
+ auto chain = std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit{std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate),
+ std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange)});
ServerSkipTest(chain);
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
@@ -178,48 +192,48 @@ TEST_P(TlsSkipTest, SkipCertAndKeyExch) {
TEST_P(TlsSkipTest, SkipCertAndKeyExchEcdsa) {
Reset(TlsAgent::kServerEcdsa256);
auto chain = std::make_shared<ChainedPacketFilter>();
- chain->Add(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
- chain->Add(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange));
+ chain->Add(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
+ chain->Add(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange));
ServerSkipTest(chain);
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(Tls13SkipTest, SkipEncryptedExtensions) {
ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
- kTlsHandshakeEncryptedExtensions),
+ server_, kTlsHandshakeEncryptedExtensions),
SSL_ERROR_RX_UNEXPECTED_CERTIFICATE);
}
TEST_P(Tls13SkipTest, SkipServerCertificate) {
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate),
- SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate),
+ SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
}
TEST_P(Tls13SkipTest, SkipServerCertificateVerify) {
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify),
- SSL_ERROR_RX_UNEXPECTED_FINISHED);
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificateVerify),
+ SSL_ERROR_RX_UNEXPECTED_FINISHED);
}
TEST_P(Tls13SkipTest, SkipClientCertificate) {
client_->SetupClientAuth();
server_->RequestClientAuth(true);
client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
- ClientSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate),
- SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
+ ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ client_, kTlsHandshakeCertificate),
+ SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
}
TEST_P(Tls13SkipTest, SkipClientCertificateVerify) {
client_->SetupClientAuth();
server_->RequestClientAuth(true);
client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
- ClientSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify),
- SSL_ERROR_RX_UNEXPECTED_FINISHED);
+ ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ client_, kTlsHandshakeCertificateVerify),
+ SSL_ERROR_RX_UNEXPECTED_FINISHED);
}
INSTANTIATE_TEST_CASE_P(
diff --git a/gtests/ssl_gtest/ssl_staticrsa_unittest.cc b/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
index e7fe44d92..c614af99f 100644
--- a/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
+++ b/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
@@ -48,10 +48,9 @@ TEST_P(TlsConnectGenericPre13, ConnectStaticRSA) {
// This test is stream so we can catch the bad_record_mac alert.
TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusCKE) {
EnableOnlyStaticRsaCiphers();
- auto i1 = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
- kTlsHandshakeClientKeyExchange,
- DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange)));
- client_->SetPacketFilter(i1);
+ client_->SetFilter(std::make_shared<TlsInspectorReplaceHandshakeMessage>(
+ client_, kTlsHandshakeClientKeyExchange,
+ DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange))));
ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
}
@@ -59,8 +58,8 @@ TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusCKE) {
// This test is stream so we can catch the bad_record_mac alert.
TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusPMSVersionDetect) {
EnableOnlyStaticRsaCiphers();
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
+ client_->SetFilter(
+ std::make_shared<TlsClientHelloVersionChanger>(client_, server_));
ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
}
@@ -69,8 +68,8 @@ TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusPMSVersionDetect) {
// ConnectStaticRSABogusPMSVersionDetect.
TEST_P(TlsConnectGenericPre13, ConnectStaticRSABogusPMSVersionIgnore) {
EnableOnlyStaticRsaCiphers();
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
+ client_->SetFilter(
+ std::make_shared<TlsClientHelloVersionChanger>(client_, server_));
server_->SetOption(SSL_ROLLBACK_DETECTION, PR_FALSE);
Connect();
}
@@ -79,10 +78,9 @@ TEST_P(TlsConnectGenericPre13, ConnectStaticRSABogusPMSVersionIgnore) {
TEST_P(TlsConnectStreamPre13, ConnectExtendedMasterSecretStaticRSABogusCKE) {
EnableOnlyStaticRsaCiphers();
EnableExtendedMasterSecret();
- auto inspect = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
- kTlsHandshakeClientKeyExchange,
- DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange)));
- client_->SetPacketFilter(inspect);
+ client_->SetFilter(std::make_shared<TlsInspectorReplaceHandshakeMessage>(
+ client_, kTlsHandshakeClientKeyExchange,
+ DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange))));
ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
}
@@ -91,8 +89,8 @@ TEST_P(TlsConnectStreamPre13,
ConnectExtendedMasterSecretStaticRSABogusPMSVersionDetect) {
EnableOnlyStaticRsaCiphers();
EnableExtendedMasterSecret();
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
+ client_->SetFilter(
+ std::make_shared<TlsClientHelloVersionChanger>(client_, server_));
ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
}
@@ -100,10 +98,10 @@ TEST_P(TlsConnectStreamPre13,
ConnectExtendedMasterSecretStaticRSABogusPMSVersionIgnore) {
EnableOnlyStaticRsaCiphers();
EnableExtendedMasterSecret();
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
+ client_->SetFilter(
+ std::make_shared<TlsClientHelloVersionChanger>(client_, server_));
server_->SetOption(SSL_ROLLBACK_DETECTION, PR_FALSE);
Connect();
}
-} // namespace nspr_test
+} // namespace nss_test
diff --git a/gtests/ssl_gtest/ssl_tls13compat_unittest.cc b/gtests/ssl_gtest/ssl_tls13compat_unittest.cc
index 75cee52fc..43f502fae 100644
--- a/gtests/ssl_gtest/ssl_tls13compat_unittest.cc
+++ b/gtests/ssl_gtest/ssl_tls13compat_unittest.cc
@@ -67,10 +67,7 @@ class Tls13CompatTest : public TlsConnectStreamTls13 {
private:
struct Recorders {
- Recorders()
- : records_(new TlsRecordRecorder()),
- hello_(new TlsInspectorRecordHandshakeMessage(std::set<uint8_t>(
- {kTlsHandshakeClientHello, kTlsHandshakeServerHello}))) {}
+ Recorders() : records_(nullptr), hello_(nullptr) {}
uint8_t session_id_length() const {
// session_id is always after version (2) and random (32).
@@ -91,12 +88,22 @@ class Tls13CompatTest : public TlsConnectStreamTls13 {
}
void Install(std::shared_ptr<TlsAgent>& agent) {
- agent->SetPacketFilter(std::make_shared<ChainedPacketFilter>(
+ if (records_ && records_->agent() == agent) {
+ // Avoid replacing the filters if they are already installed on this
+ // agent. This ensures that InstallFilters() can be used after
+ // MakeNewServer() without losing state on the client filters.
+ return;
+ }
+ records_.reset(new TlsRecordRecorder(agent));
+ hello_.reset(new TlsHandshakeRecorder(
+ agent, std::set<uint8_t>(
+ {kTlsHandshakeClientHello, kTlsHandshakeServerHello})));
+ agent->SetFilter(std::make_shared<ChainedPacketFilter>(
ChainedPacketFilterInit({records_, hello_})));
}
std::shared_ptr<TlsRecordRecorder> records_;
- std::shared_ptr<TlsInspectorRecordHandshakeMessage> hello_;
+ std::shared_ptr<TlsHandshakeRecorder> hello_;
};
void CheckRecordsAreTls12(const std::string& agent,
@@ -171,16 +178,20 @@ TEST_F(Tls13CompatTest, EnabledStatelessHrr) {
server_->StartConnect();
client_->Handshake();
server_->Handshake();
+
+ // The server should send CCS before HRR.
CheckForCCS(false, true);
- // A new server should just work, but not send another CCS.
+ // A new server should complete the handshake, and not send CCS.
MakeNewServer();
InstallFilters();
server_->ConfigNamedGroups({ssl_grp_ec_secp384r1});
Handshake();
CheckConnected();
- CheckForCompatHandshake();
+ CheckRecordVersions();
+ CheckHelloVersions();
+ CheckForCCS(true, false);
}
TEST_F(Tls13CompatTest, EnabledHrrZeroRtt) {
@@ -262,10 +273,10 @@ TEST_F(TlsConnectStreamTls13, ChangeCipherSpecBeforeClientHello12) {
TEST_F(TlsConnectDatagram13, CompatModeDtlsClient) {
EnsureTlsSetup();
client_->SetOption(SSL_ENABLE_TLS13_COMPAT_MODE, PR_TRUE);
- auto client_records = std::make_shared<TlsRecordRecorder>();
- client_->SetPacketFilter(client_records);
- auto server_records = std::make_shared<TlsRecordRecorder>();
- server_->SetPacketFilter(server_records);
+ auto client_records = std::make_shared<TlsRecordRecorder>(client_);
+ client_->SetFilter(client_records);
+ auto server_records = std::make_shared<TlsRecordRecorder>(server_);
+ server_->SetFilter(server_records);
Connect();
ASSERT_EQ(2U, client_records->count()); // CH, Fin
@@ -283,7 +294,8 @@ TEST_F(TlsConnectDatagram13, CompatModeDtlsClient) {
class AddSessionIdFilter : public TlsHandshakeFilter {
public:
- AddSessionIdFilter() : TlsHandshakeFilter({ssl_hs_client_hello}) {}
+ AddSessionIdFilter(const std::shared_ptr<TlsAgent>& client)
+ : TlsHandshakeFilter(client, {ssl_hs_client_hello}) {}
protected:
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
@@ -303,14 +315,14 @@ class AddSessionIdFilter : public TlsHandshakeFilter {
// mode. It should be ignored instead.
TEST_F(TlsConnectDatagram13, CompatModeDtlsServer) {
EnsureTlsSetup();
- auto client_records = std::make_shared<TlsRecordRecorder>();
- client_->SetPacketFilter(
+ auto client_records = std::make_shared<TlsRecordRecorder>(client_);
+ client_->SetFilter(
std::make_shared<ChainedPacketFilter>(ChainedPacketFilterInit(
- {client_records, std::make_shared<AddSessionIdFilter>()})));
- auto server_hello = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeServerHello);
- auto server_records = std::make_shared<TlsRecordRecorder>();
- server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(
+ {client_records, std::make_shared<AddSessionIdFilter>(client_)})));
+ auto server_hello =
+ std::make_shared<TlsHandshakeRecorder>(server_, kTlsHandshakeServerHello);
+ auto server_records = std::make_shared<TlsRecordRecorder>(server_);
+ server_->SetFilter(std::make_shared<ChainedPacketFilter>(
ChainedPacketFilterInit({server_records, server_hello})));
StartConnect();
client_->Handshake();
@@ -334,4 +346,4 @@ TEST_F(TlsConnectDatagram13, CompatModeDtlsServer) {
EXPECT_EQ(0U, session_id_len);
}
-} // nss_test
+} // namespace nss_test
diff --git a/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc b/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc
index 2f8ddd6fe..a590ee0ed 100644
--- a/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc
+++ b/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc
@@ -147,10 +147,10 @@ class SSLv2ClientHelloTestF : public TlsConnectTestBase {
SSLv2ClientHelloTestF(SSLProtocolVariant variant, uint16_t version)
: TlsConnectTestBase(variant, version), filter_(nullptr) {}
- void SetUp() {
+ void SetUp() override {
TlsConnectTestBase::SetUp();
filter_ = std::make_shared<SSLv2ClientHelloFilter>(client_, version_);
- client_->SetPacketFilter(filter_);
+ client_->SetFilter(filter_);
}
void SetExpectedVersion(uint16_t version) {
diff --git a/gtests/ssl_gtest/ssl_version_unittest.cc b/gtests/ssl_gtest/ssl_version_unittest.cc
index 9db293b07..747f019cf 100644
--- a/gtests/ssl_gtest/ssl_version_unittest.cc
+++ b/gtests/ssl_gtest/ssl_version_unittest.cc
@@ -56,18 +56,16 @@ TEST_P(TlsConnectGeneric, ServerNegotiateTls12) {
// two validate that we can also detect fallback using the
// SSL_SetDowngradeCheckVersion() API.
TEST_F(TlsConnectTest, TestDowngradeDetectionToTls11) {
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionSetter>(
- SSL_LIBRARY_VERSION_TLS_1_1));
+ client_->SetFilter(std::make_shared<TlsClientHelloVersionSetter>(
+ client_, SSL_LIBRARY_VERSION_TLS_1_1));
ConnectExpectFail();
ASSERT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code());
}
/* Attempt to negotiate the bogus DTLS 1.1 version. */
TEST_F(DtlsConnectTest, TestDtlsVersion11) {
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionSetter>(
- ((~0x0101) & 0xffff)));
+ client_->SetFilter(std::make_shared<TlsClientHelloVersionSetter>(
+ client_, ((~0x0101) & 0xffff)));
ConnectExpectFail();
// It's kind of surprising that SSL_ERROR_NO_CYPHER_OVERLAP is
// what is returned here, but this is deliberate in ssl3_HandleAlert().
@@ -78,9 +76,8 @@ TEST_F(DtlsConnectTest, TestDtlsVersion11) {
// Disabled as long as we have draft version.
TEST_F(TlsConnectTest, TestDowngradeDetectionToTls12) {
EnsureTlsSetup();
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionSetter>(
- SSL_LIBRARY_VERSION_TLS_1_2));
+ client_->SetFilter(std::make_shared<TlsClientHelloVersionSetter>(
+ client_, SSL_LIBRARY_VERSION_TLS_1_2));
client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
SSL_LIBRARY_VERSION_TLS_1_3);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
@@ -92,9 +89,8 @@ TEST_F(TlsConnectTest, TestDowngradeDetectionToTls12) {
// TLS 1.1 clients do not check the random values, so we should
// instead get a handshake failure alert from the server.
TEST_F(TlsConnectTest, TestDowngradeDetectionToTls10) {
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionSetter>(
- SSL_LIBRARY_VERSION_TLS_1_0));
+ client_->SetFilter(std::make_shared<TlsClientHelloVersionSetter>(
+ client_, SSL_LIBRARY_VERSION_TLS_1_0));
client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
SSL_LIBRARY_VERSION_TLS_1_1);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
@@ -177,12 +173,11 @@ class Tls13NoSupportedVersions : public TlsConnectStreamTls12 {
client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
SSL_LIBRARY_VERSION_TLS_1_2);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, max_server_version);
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionSetter>(
- overwritten_client_version));
- auto capture = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeServerHello);
- server_->SetPacketFilter(capture);
+ client_->SetFilter(std::make_shared<TlsClientHelloVersionSetter>(
+ client_, overwritten_client_version));
+ auto capture = std::make_shared<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerHello);
+ server_->SetFilter(capture);
ConnectExpectAlert(server_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
@@ -214,12 +209,11 @@ TEST_F(Tls13NoSupportedVersions,
// Offer 1.3 but with ClientHello.legacy_version == TLS 1.4. This
// causes a bad MAC error when we read EncryptedExtensions.
TEST_F(TlsConnectStreamTls13, Tls14ClientHelloWithSupportedVersions) {
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionSetter>(
- SSL_LIBRARY_VERSION_TLS_1_3 + 1));
- auto capture =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_supported_versions_xtn);
- server_->SetPacketFilter(capture);
+ client_->SetFilter(std::make_shared<TlsClientHelloVersionSetter>(
+ client_, SSL_LIBRARY_VERSION_TLS_1_3 + 1));
+ auto capture = std::make_shared<TlsExtensionCapture>(
+ server_, ssl_tls13_supported_versions_xtn);
+ server_->SetFilter(capture);
client_->ExpectSendAlert(kTlsAlertBadRecordMac);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
diff --git a/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc b/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc
index eda96831c..7f3c4a896 100644
--- a/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc
+++ b/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc
@@ -189,12 +189,12 @@ class TestPolicyVersionRange
}
}
- void SetUp() {
- SetPolicy(policy_.range());
+ void SetUp() override {
TlsConnectTestBase::SetUp();
+ SetPolicy(policy_.range());
}
- void TearDown() {
+ void TearDown() override {
TlsConnectTestBase::TearDown();
saved_version_policy_.RestoreOriginalPolicy();
}
diff --git a/gtests/ssl_gtest/test_io.cc b/gtests/ssl_gtest/test_io.cc
index adcdbfbaf..728217851 100644
--- a/gtests/ssl_gtest/test_io.cc
+++ b/gtests/ssl_gtest/test_io.cc
@@ -25,10 +25,6 @@ namespace nss_test {
if (g_ssl_gtest_verbose) LOG(a); \
} while (false)
-void DummyPrSocket::SetPacketFilter(std::shared_ptr<PacketFilter> filter) {
- filter_ = filter;
-}
-
ScopedPRFileDesc DummyPrSocket::CreateFD() {
static PRDescIdentity test_fd_identity =
PR_GetUniqueIdentity("testtransportadapter");
diff --git a/gtests/ssl_gtest/test_io.h b/gtests/ssl_gtest/test_io.h
index 469d90a7c..dbeb6b9d4 100644
--- a/gtests/ssl_gtest/test_io.h
+++ b/gtests/ssl_gtest/test_io.h
@@ -74,7 +74,9 @@ class DummyPrSocket : public DummyIOLayerMethods {
std::weak_ptr<DummyPrSocket>& peer() { return peer_; }
void SetPeer(const std::shared_ptr<DummyPrSocket>& peer) { peer_ = peer; }
- void SetPacketFilter(std::shared_ptr<PacketFilter> filter);
+ void SetPacketFilter(const std::shared_ptr<PacketFilter>& filter) {
+ filter_ = filter;
+ }
// Drops peer, packet filter and any outstanding packets.
void Reset();
@@ -176,6 +178,6 @@ class Poller {
timers_;
};
-} // end of namespace
+} // namespace nss_test
#endif
diff --git a/gtests/ssl_gtest/tls_agent.h b/gtests/ssl_gtest/tls_agent.h
index 9bde5dfda..941fb649e 100644
--- a/gtests/ssl_gtest/tls_agent.h
+++ b/gtests/ssl_gtest/tls_agent.h
@@ -80,18 +80,10 @@ class TlsAgent : public PollTarget {
adapter_->SetPeer(peer->adapter_);
}
- // Set a filter that can access plaintext (TLS 1.3 only).
- void SetTlsRecordFilter(std::shared_ptr<TlsRecordFilter> filter) {
- filter->SetAgent(this);
+ void SetFilter(std::shared_ptr<PacketFilter> filter) {
adapter_->SetPacketFilter(filter);
- filter->EnableDecryption();
}
-
- void SetPacketFilter(std::shared_ptr<PacketFilter> filter) {
- adapter_->SetPacketFilter(filter);
- }
-
- void DeletePacketFilter() { adapter_->SetPacketFilter(nullptr); }
+ void ClearFilter() { adapter_->SetPacketFilter(nullptr); }
void StartConnect(PRFileDesc* model = nullptr);
void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group,
@@ -463,7 +455,7 @@ class TlsAgentTestBase : public ::testing::Test {
void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state,
int32_t error_code = 0);
- std::unique_ptr<TlsAgent> agent_;
+ std::shared_ptr<TlsAgent> agent_;
TlsAgent::Role role_;
SSLProtocolVariant variant_;
uint16_t version_;
diff --git a/gtests/ssl_gtest/tls_connect.cc b/gtests/ssl_gtest/tls_connect.cc
index b1e90d89d..bc146e042 100644
--- a/gtests/ssl_gtest/tls_connect.cc
+++ b/gtests/ssl_gtest/tls_connect.cc
@@ -770,17 +770,17 @@ TlsConnectGenericResumptionToken::TlsConnectGenericResumptionToken()
void TlsKeyExchangeTest::EnsureKeyShareSetup() {
EnsureTlsSetup();
groups_capture_ =
- std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn);
shares_capture_ =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
- shares_capture2_ =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn, true);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn);
+ shares_capture2_ = std::make_shared<TlsExtensionCapture>(
+ client_, ssl_tls13_key_share_xtn, true);
std::vector<std::shared_ptr<PacketFilter>> captures = {
groups_capture_, shares_capture_, shares_capture2_};
- client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures));
- capture_hrr_ = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeHelloRetryRequest);
- server_->SetPacketFilter(capture_hrr_);
+ client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures));
+ capture_hrr_ = std::make_shared<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeHelloRetryRequest);
+ server_->SetFilter(capture_hrr_);
}
void TlsKeyExchangeTest::ConfigNamedGroups(
diff --git a/gtests/ssl_gtest/tls_connect.h b/gtests/ssl_gtest/tls_connect.h
index 9746f9865..6a35fc78b 100644
--- a/gtests/ssl_gtest/tls_connect.h
+++ b/gtests/ssl_gtest/tls_connect.h
@@ -45,8 +45,8 @@ class TlsConnectTestBase : public ::testing::Test {
TlsConnectTestBase(SSLProtocolVariant variant, uint16_t version);
virtual ~TlsConnectTestBase();
- void SetUp();
- void TearDown();
+ virtual void SetUp();
+ virtual void TearDown();
// Initialize client and server.
void Init();
@@ -320,7 +320,7 @@ class TlsKeyExchangeTest : public TlsConnectGeneric {
std::shared_ptr<TlsExtensionCapture> groups_capture_;
std::shared_ptr<TlsExtensionCapture> shares_capture_;
std::shared_ptr<TlsExtensionCapture> shares_capture2_;
- std::shared_ptr<TlsInspectorRecordHandshakeMessage> capture_hrr_;
+ std::shared_ptr<TlsHandshakeRecorder> capture_hrr_;
void EnsureKeyShareSetup();
void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
diff --git a/gtests/ssl_gtest/tls_filter.cc b/gtests/ssl_gtest/tls_filter.cc
index 89f201295..d34b13bcb 100644
--- a/gtests/ssl_gtest/tls_filter.cc
+++ b/gtests/ssl_gtest/tls_filter.cc
@@ -452,7 +452,7 @@ size_t TlsHandshakeFilter::HandshakeHeader::Write(
return offset;
}
-PacketFilter::Action TlsInspectorRecordHandshakeMessage::FilterHandshake(
+PacketFilter::Action TlsHandshakeRecorder::FilterHandshake(
const HandshakeHeader& header, const DataBuffer& input,
DataBuffer* output) {
// Only do this once.
@@ -763,7 +763,7 @@ PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header,
if (counter_++ == record_) {
DataBuffer buf;
header.Write(&buf, 0, body);
- src_.lock()->SendDirect(buf);
+ agent()->SendDirect(buf);
dest_.lock()->Handshake();
func_();
return DROP;
@@ -772,7 +772,7 @@ PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header,
return KEEP;
}
-PacketFilter::Action TlsInspectorClientHelloVersionChanger::FilterHandshake(
+PacketFilter::Action TlsClientHelloVersionChanger::FilterHandshake(
const HandshakeHeader& header, const DataBuffer& input,
DataBuffer* output) {
EXPECT_EQ(SECSuccess,
@@ -808,7 +808,7 @@ PacketFilter::Action SelectiveRecordDropFilter::FilterRecord(
return pattern;
}
-PacketFilter::Action TlsInspectorClientHelloVersionSetter::FilterHandshake(
+PacketFilter::Action TlsClientHelloVersionSetter::FilterHandshake(
const HandshakeHeader& header, const DataBuffer& input,
DataBuffer* output) {
*output = input;
diff --git a/gtests/ssl_gtest/tls_filter.h b/gtests/ssl_gtest/tls_filter.h
index 1db3b90f6..9485b5eb3 100644
--- a/gtests/ssl_gtest/tls_filter.h
+++ b/gtests/ssl_gtest/tls_filter.h
@@ -74,16 +74,15 @@ struct TlsRecord {
// Abstract filter that operates on entire (D)TLS records.
class TlsRecordFilter : public PacketFilter {
public:
- TlsRecordFilter()
- : agent_(nullptr),
+ TlsRecordFilter(const std::shared_ptr<TlsAgent>& agent)
+ : agent_(agent),
count_(0),
cipher_spec_(),
dropped_record_(false),
in_sequence_number_(0),
out_sequence_number_(0) {}
- void SetAgent(const TlsAgent* agent) { agent_ = agent; }
- const TlsAgent* agent() const { return agent_; }
+ std::shared_ptr<TlsAgent> agent() const { return agent_.lock(); }
// External interface. Overrides PacketFilter.
PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output);
@@ -126,7 +125,7 @@ class TlsRecordFilter : public PacketFilter {
static void CipherSpecChanged(void* arg, PRBool sending,
ssl3CipherSpec* newSpec);
- const TlsAgent* agent_;
+ std::weak_ptr<TlsAgent> agent_;
size_t count_;
std::unique_ptr<TlsCipherSpec> cipher_spec_;
// Whether we dropped a record since the cipher spec changed.
@@ -175,9 +174,13 @@ inline std::ostream& operator<<(std::ostream& stream,
// records and that they don't span records or anything crazy like that.
class TlsHandshakeFilter : public TlsRecordFilter {
public:
- TlsHandshakeFilter() : handshake_types_(), preceding_fragment_() {}
- TlsHandshakeFilter(const std::set<uint8_t>& types)
- : handshake_types_(types), preceding_fragment_() {}
+ TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent), handshake_types_(), preceding_fragment_() {}
+ TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& agent,
+ const std::set<uint8_t>& types)
+ : TlsRecordFilter(agent),
+ handshake_types_(types),
+ preceding_fragment_() {}
// This filter can be set to be selective based on handshake message type. If
// this function isn't used (or the set is empty), then all handshake messages
@@ -229,12 +232,14 @@ class TlsHandshakeFilter : public TlsRecordFilter {
};
// Make a copy of the first instance of a handshake message.
-class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter {
+class TlsHandshakeRecorder : public TlsHandshakeFilter {
public:
- TlsInspectorRecordHandshakeMessage(uint8_t handshake_type)
- : TlsHandshakeFilter({handshake_type}), buffer_() {}
- TlsInspectorRecordHandshakeMessage(const std::set<uint8_t>& handshake_types)
- : TlsHandshakeFilter(handshake_types), buffer_() {}
+ TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& agent,
+ uint8_t handshake_type)
+ : TlsHandshakeFilter(agent, {handshake_type}), buffer_() {}
+ TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& agent,
+ const std::set<uint8_t>& handshake_types)
+ : TlsHandshakeFilter(agent, handshake_types), buffer_() {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -251,9 +256,10 @@ class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter {
// Replace all instances of a handshake message.
class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter {
public:
- TlsInspectorReplaceHandshakeMessage(uint8_t handshake_type,
+ TlsInspectorReplaceHandshakeMessage(const std::shared_ptr<TlsAgent>& agent,
+ uint8_t handshake_type,
const DataBuffer& replacement)
- : TlsHandshakeFilter({handshake_type}), buffer_(replacement) {}
+ : TlsHandshakeFilter(agent, {handshake_type}), buffer_(replacement) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -266,9 +272,11 @@ class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter {
// Make a copy of each record of a given type.
class TlsRecordRecorder : public TlsRecordFilter {
public:
- TlsRecordRecorder(uint8_t ct) : filter_(true), ct_(ct), records_() {}
- TlsRecordRecorder()
- : filter_(false),
+ TlsRecordRecorder(const std::shared_ptr<TlsAgent>& agent, uint8_t ct)
+ : TlsRecordFilter(agent), filter_(true), ct_(ct), records_() {}
+ TlsRecordRecorder(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent),
+ filter_(false),
ct_(content_handshake), // dummy (<optional> is C++14)
records_() {}
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
@@ -289,7 +297,9 @@ class TlsRecordRecorder : public TlsRecordFilter {
// Make a copy of the complete conversation.
class TlsConversationRecorder : public TlsRecordFilter {
public:
- TlsConversationRecorder(DataBuffer& buffer) : buffer_(buffer) {}
+ TlsConversationRecorder(const std::shared_ptr<TlsAgent>& agent,
+ DataBuffer& buffer)
+ : TlsRecordFilter(agent), buffer_(buffer) {}
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
@@ -302,6 +312,8 @@ class TlsConversationRecorder : public TlsRecordFilter {
// Make a copy of the records
class TlsHeaderRecorder : public TlsRecordFilter {
public:
+ TlsHeaderRecorder(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent) {}
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
DataBuffer* output);
@@ -338,13 +350,15 @@ typedef std::function<bool(TlsParser* parser, const TlsVersioned& header)>
class TlsExtensionFilter : public TlsHandshakeFilter {
public:
- TlsExtensionFilter()
- : TlsHandshakeFilter({kTlsHandshakeClientHello, kTlsHandshakeServerHello,
+ TlsExtensionFilter(const std::shared_ptr<TlsAgent>& agent)
+ : TlsHandshakeFilter(agent,
+ {kTlsHandshakeClientHello, kTlsHandshakeServerHello,
kTlsHandshakeHelloRetryRequest,
kTlsHandshakeEncryptedExtensions}) {}
- TlsExtensionFilter(const std::set<uint8_t>& types)
- : TlsHandshakeFilter(types) {}
+ TlsExtensionFilter(const std::shared_ptr<TlsAgent>& agent,
+ const std::set<uint8_t>& types)
+ : TlsHandshakeFilter(agent, types) {}
static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header);
@@ -365,8 +379,13 @@ class TlsExtensionFilter : public TlsHandshakeFilter {
class TlsExtensionCapture : public TlsExtensionFilter {
public:
- TlsExtensionCapture(uint16_t ext, bool last = false)
- : extension_(ext), captured_(false), last_(last), data_() {}
+ TlsExtensionCapture(const std::shared_ptr<TlsAgent>& agent, uint16_t ext,
+ bool last = false)
+ : TlsExtensionFilter(agent),
+ extension_(ext),
+ captured_(false),
+ last_(last),
+ data_() {}
const DataBuffer& extension() const { return data_; }
bool captured() const { return captured_; }
@@ -385,8 +404,9 @@ class TlsExtensionCapture : public TlsExtensionFilter {
class TlsExtensionReplacer : public TlsExtensionFilter {
public:
- TlsExtensionReplacer(uint16_t extension, const DataBuffer& data)
- : extension_(extension), data_(data) {}
+ TlsExtensionReplacer(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t extension, const DataBuffer& data)
+ : TlsExtensionFilter(agent), extension_(extension), data_(data) {}
PacketFilter::Action FilterExtension(uint16_t extension_type,
const DataBuffer& input,
DataBuffer* output) override;
@@ -398,7 +418,9 @@ class TlsExtensionReplacer : public TlsExtensionFilter {
class TlsExtensionDropper : public TlsExtensionFilter {
public:
- TlsExtensionDropper(uint16_t extension) : extension_(extension) {}
+ TlsExtensionDropper(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t extension)
+ : TlsExtensionFilter(agent), extension_(extension) {}
PacketFilter::Action FilterExtension(uint16_t extension_type,
const DataBuffer&, DataBuffer*) override;
@@ -408,8 +430,9 @@ class TlsExtensionDropper : public TlsExtensionFilter {
class TlsExtensionInjector : public TlsHandshakeFilter {
public:
- TlsExtensionInjector(uint16_t ext, const DataBuffer& data)
- : extension_(ext), data_(data) {}
+ TlsExtensionInjector(const std::shared_ptr<TlsAgent>& agent, uint16_t ext,
+ const DataBuffer& data)
+ : TlsHandshakeFilter(agent), extension_(ext), data_(data) {}
protected:
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
@@ -426,16 +449,20 @@ typedef std::function<void(void)> VoidFunction;
class AfterRecordN : public TlsRecordFilter {
public:
- AfterRecordN(std::shared_ptr<TlsAgent>& src, std::shared_ptr<TlsAgent>& dest,
- unsigned int record, VoidFunction func)
- : src_(src), dest_(dest), record_(record), func_(func), counter_(0) {}
+ AfterRecordN(const std::shared_ptr<TlsAgent>& src,
+ const std::shared_ptr<TlsAgent>& dest, unsigned int record,
+ VoidFunction func)
+ : TlsRecordFilter(src),
+ dest_(dest),
+ record_(record),
+ func_(func),
+ counter_(0) {}
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& body,
DataBuffer* out) override;
private:
- std::weak_ptr<TlsAgent> src_;
std::weak_ptr<TlsAgent> dest_;
unsigned int record_;
VoidFunction func_;
@@ -444,10 +471,12 @@ class AfterRecordN : public TlsRecordFilter {
// When we see the ClientKeyExchange from |client|, increment the
// ClientHelloVersion on |server|.
-class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter {
+class TlsClientHelloVersionChanger : public TlsHandshakeFilter {
public:
- TlsInspectorClientHelloVersionChanger(std::shared_ptr<TlsAgent>& server)
- : TlsHandshakeFilter({kTlsHandshakeClientKeyExchange}), server_(server) {}
+ TlsClientHelloVersionChanger(const std::shared_ptr<TlsAgent>& client,
+ const std::shared_ptr<TlsAgent>& server)
+ : TlsHandshakeFilter(client, {kTlsHandshakeClientKeyExchange}),
+ server_(server) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -477,14 +506,16 @@ class SelectiveDropFilter : public PacketFilter {
// datagram, we just drop one.
class SelectiveRecordDropFilter : public TlsRecordFilter {
public:
- SelectiveRecordDropFilter(uint32_t pattern, bool enabled = true)
- : pattern_(pattern), counter_(0) {
+ SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& agent,
+ uint32_t pattern, bool enabled = true)
+ : TlsRecordFilter(agent), pattern_(pattern), counter_(0) {
if (!enabled) {
Disable();
}
}
- SelectiveRecordDropFilter(std::initializer_list<size_t> records)
- : SelectiveRecordDropFilter(ToPattern(records), true) {}
+ SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& agent,
+ std::initializer_list<size_t> records)
+ : SelectiveRecordDropFilter(agent, ToPattern(records), true) {}
void Reset(uint32_t pattern) {
counter_ = 0;
@@ -509,10 +540,12 @@ class SelectiveRecordDropFilter : public TlsRecordFilter {
};
// Set the version number in the ClientHello.
-class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter {
+class TlsClientHelloVersionSetter : public TlsHandshakeFilter {
public:
- TlsInspectorClientHelloVersionSetter(uint16_t version)
- : TlsHandshakeFilter({kTlsHandshakeClientHello}), version_(version) {}
+ TlsClientHelloVersionSetter(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t version)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeClientHello}),
+ version_(version) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -525,7 +558,8 @@ class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter {
// Damages the last byte of a handshake message.
class TlsLastByteDamager : public TlsHandshakeFilter {
public:
- TlsLastByteDamager(uint8_t type) : type_(type) {}
+ TlsLastByteDamager(const std::shared_ptr<TlsAgent>& agent, uint8_t type)
+ : TlsHandshakeFilter(agent), type_(type) {}
PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) override {
@@ -545,8 +579,10 @@ class TlsLastByteDamager : public TlsHandshakeFilter {
class SelectedCipherSuiteReplacer : public TlsHandshakeFilter {
public:
- SelectedCipherSuiteReplacer(uint16_t suite)
- : TlsHandshakeFilter({kTlsHandshakeServerHello}), cipher_suite_(suite) {}
+ SelectedCipherSuiteReplacer(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t suite)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeServerHello}),
+ cipher_suite_(suite) {}
protected:
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,