summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMartin Thomson <martin.thomson@gmail.com>2018-01-03 15:36:18 +1100
committerMartin Thomson <martin.thomson@gmail.com>2018-01-03 15:36:18 +1100
commite81e6990bbab90fba9eeb78d2885e3d826c393c3 (patch)
treeacc4d7b1f9f38b10d2a027d0b9c6c3a2f71c0f70
parent2046692aa9c9af51c7b8510cdd4bb6f9f949e54d (diff)
downloadnss-hg-e81e6990bbab90fba9eeb78d2885e3d826c393c3.tar.gz
Bug 1427675 - Add TlsAgent argument to TlsRecordFilter, r=ekr
This is a fairly disruptive change, but mostly just mechanical. There are a few extra changes: - I have renamed the TlsInspector* filters for consistency. This was purely mechanical. - I renamed the SetPacketFilter function to just SetFilter. Also mechanical. - TlsRecordFilter maintains a weak pointer reference to the TlsAgent now rather than using a bare pointer. This meant that I had to change TlsAgentTestBase to use shared_ptr rather than unique_ptr to support of use of filters with those tests. - I removed the helper function that enables decryption. Enabling decryption is now more explicit. - I ran a newer clang-format version and it fixed a few extra things, like the comments on the end of namespace {} blocks, some of which were wrong. - I discovered a bug in some of the drop tests: in the 0-RTT tests, the filters were being installed on the client and server right at the start, which meant that they were capturing the first handshake and not the second one. This was clearly against intent, but the tests were mostly right still, it was only the expected ACKs that were wrong. We were expecting just one record to be ACKed by a server (Finished), but the record with EndOfEarlyData should have been acknowledged as well. - In TlsSkipTest and Tls13SkipTest, I had to override SetUp() so that client_ and server_ are initialized prior to constructing filters. In doing so, I noticed that we weren't being consistent about overriding SetUp properly, so I fixed the small number of instances of that by adding an override label to each and marking the base method virtual. - The stateless HRR test for TLS 1.3 compat mode was replacing the server, but expecting to retain the same filters. That wasn't a problem in that case, but I didn't want to have any places where the filter was set on a different agent from the one that was passed to it.
-rw-r--r--gtests/ssl_gtest/bloomfilter_unittest.cc2
-rw-r--r--gtests/ssl_gtest/ssl_0rtt_unittest.cc10
-rw-r--r--gtests/ssl_gtest/ssl_agent_unittest.cc6
-rw-r--r--gtests/ssl_gtest/ssl_auth_unittest.cc100
-rw-r--r--gtests/ssl_gtest/ssl_cert_ext_unittest.cc10
-rw-r--r--gtests/ssl_gtest/ssl_ciphersuite_unittest.cc2
-rw-r--r--gtests/ssl_gtest/ssl_custext_unittest.cc31
-rw-r--r--gtests/ssl_gtest/ssl_damage_unittest.cc25
-rw-r--r--gtests/ssl_gtest/ssl_dhe_unittest.cc79
-rw-r--r--gtests/ssl_gtest/ssl_drop_unittest.cc78
-rw-r--r--gtests/ssl_gtest/ssl_ecdh_unittest.cc34
-rw-r--r--gtests/ssl_gtest/ssl_extension_unittest.cc212
-rw-r--r--gtests/ssl_gtest/ssl_fragment_unittest.cc4
-rw-r--r--gtests/ssl_gtest/ssl_fuzz_unittest.cc53
-rw-r--r--gtests/ssl_gtest/ssl_hrr_unittest.cc106
-rw-r--r--gtests/ssl_gtest/ssl_keylog_unittest.cc4
-rw-r--r--gtests/ssl_gtest/ssl_loopback_unittest.cc60
-rw-r--r--gtests/ssl_gtest/ssl_record_unittest.cc16
-rw-r--r--gtests/ssl_gtest/ssl_resumption_unittest.cc113
-rw-r--r--gtests/ssl_gtest/ssl_skip_unittest.cc94
-rw-r--r--gtests/ssl_gtest/ssl_staticrsa_unittest.cc32
-rw-r--r--gtests/ssl_gtest/ssl_tls13compat_unittest.cc54
-rw-r--r--gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc4
-rw-r--r--gtests/ssl_gtest/ssl_version_unittest.cc42
-rw-r--r--gtests/ssl_gtest/ssl_versionpolicy_unittest.cc6
-rw-r--r--gtests/ssl_gtest/test_io.cc4
-rw-r--r--gtests/ssl_gtest/test_io.h6
-rw-r--r--gtests/ssl_gtest/tls_agent.h14
-rw-r--r--gtests/ssl_gtest/tls_connect.cc16
-rw-r--r--gtests/ssl_gtest/tls_connect.h6
-rw-r--r--gtests/ssl_gtest/tls_filter.cc8
-rw-r--r--gtests/ssl_gtest/tls_filter.h130
32 files changed, 733 insertions, 628 deletions
diff --git a/gtests/ssl_gtest/bloomfilter_unittest.cc b/gtests/ssl_gtest/bloomfilter_unittest.cc
index 110cfa13a..6efe06ec7 100644
--- a/gtests/ssl_gtest/bloomfilter_unittest.cc
+++ b/gtests/ssl_gtest/bloomfilter_unittest.cc
@@ -105,4 +105,4 @@ static const BloomFilterConfig kBloomFilterConfigurations[] = {
INSTANTIATE_TEST_CASE_P(BloomFilterConfigurations, BloomFilterTest,
::testing::ValuesIn(kBloomFilterConfigurations));
-} // namespace nspr_test
+} // namespace nss_test
diff --git a/gtests/ssl_gtest/ssl_0rtt_unittest.cc b/gtests/ssl_gtest/ssl_0rtt_unittest.cc
index 7d6120dc1..ded388f57 100644
--- a/gtests/ssl_gtest/ssl_0rtt_unittest.cc
+++ b/gtests/ssl_gtest/ssl_0rtt_unittest.cc
@@ -94,7 +94,7 @@ class TlsZeroRttReplayTest : public TlsConnectTls13 {
// Now run a true 0-RTT handshake, but capture the first packet.
auto first_packet = std::make_shared<SaveFirstPacket>();
- client_->SetPacketFilter(first_packet);
+ client_->SetFilter(first_packet);
client_->Set0RttEnabled(true);
server_->Set0RttEnabled(true);
ExpectResumption(RESUME_TICKET);
@@ -115,9 +115,9 @@ class TlsZeroRttReplayTest : public TlsConnectTls13 {
server_->Set0RttEnabled(true);
// Capture the early_data extension, which should not appear.
- auto early_data_ext =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_early_data_xtn);
- server_->SetPacketFilter(early_data_ext);
+ auto early_data_ext = std::make_shared<TlsExtensionCapture>(
+ server_, ssl_tls13_early_data_xtn);
+ server_->SetFilter(early_data_ext);
// Finally, replay the ClientHello and force the server to consume it. Stop
// after the server sends its first flight; the client will not be able to
@@ -607,7 +607,7 @@ TEST_P(TlsConnectTls13, ZeroRttOrdering) {
// Now, coalesce the next three things from the client: early data, second
// flight and 1-RTT data.
auto coalesce = std::make_shared<PacketCoalesceFilter>();
- client_->SetPacketFilter(coalesce);
+ client_->SetFilter(coalesce);
// Send (and hold) early data.
static const std::vector<uint8_t> early_data = {3, 2, 1};
diff --git a/gtests/ssl_gtest/ssl_agent_unittest.cc b/gtests/ssl_gtest/ssl_agent_unittest.cc
index 0aa9a4c78..c7a841c68 100644
--- a/gtests/ssl_gtest/ssl_agent_unittest.cc
+++ b/gtests/ssl_gtest/ssl_agent_unittest.cc
@@ -160,9 +160,9 @@ TEST_F(TlsAgentStreamTestClient, Set0RttOptionThenWrite) {
SSL_LIBRARY_VERSION_TLS_1_3);
agent_->StartConnect();
agent_->Set0RttEnabled(true);
- auto filter = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeClientHello);
- agent_->SetPacketFilter(filter);
+ auto filter =
+ std::make_shared<TlsHandshakeRecorder>(agent_, kTlsHandshakeClientHello);
+ agent_->SetFilter(filter);
PRInt32 rv = PR_Write(agent_->ssl_fd(), k0RttData, strlen(k0RttData));
EXPECT_EQ(-1, rv);
int32_t err = PORT_GetError();
diff --git a/gtests/ssl_gtest/ssl_auth_unittest.cc b/gtests/ssl_gtest/ssl_auth_unittest.cc
index c44a18161..d9e31f78f 100644
--- a/gtests/ssl_gtest/ssl_auth_unittest.cc
+++ b/gtests/ssl_gtest/ssl_auth_unittest.cc
@@ -95,10 +95,9 @@ TEST_P(TlsConnectGeneric, ClientAuthBigRsa) {
}
// Offset is the position in the captured buffer where the signature sits.
-static void CheckSigScheme(
- std::shared_ptr<TlsInspectorRecordHandshakeMessage>& capture, size_t offset,
- std::shared_ptr<TlsAgent>& peer, uint16_t expected_scheme,
- size_t expected_size) {
+static void CheckSigScheme(std::shared_ptr<TlsHandshakeRecorder>& capture,
+ size_t offset, std::shared_ptr<TlsAgent>& peer,
+ uint16_t expected_scheme, size_t expected_size) {
EXPECT_LT(offset + 2U, capture->buffer().len());
uint32_t scheme = 0;
@@ -114,9 +113,9 @@ static void CheckSigScheme(
// in the default certificate.
TEST_P(TlsConnectTls12, ServerAuthCheckSigAlg) {
EnsureTlsSetup();
- auto capture_ske = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeServerKeyExchange);
- server_->SetPacketFilter(capture_ske);
+ auto capture_ske = std::make_shared<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
+ server_->SetFilter(capture_ske);
Connect();
CheckKeys();
@@ -133,10 +132,9 @@ TEST_P(TlsConnectTls12, ServerAuthCheckSigAlg) {
TEST_P(TlsConnectTls12, ClientAuthCheckSigAlg) {
EnsureTlsSetup();
- auto capture_cert_verify =
- std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeCertificateVerify);
- client_->SetPacketFilter(capture_cert_verify);
+ auto capture_cert_verify = std::make_shared<TlsHandshakeRecorder>(
+ client_, kTlsHandshakeCertificateVerify);
+ client_->SetFilter(capture_cert_verify);
client_->SetupClientAuth();
server_->RequestClientAuth(true);
Connect();
@@ -147,10 +145,9 @@ TEST_P(TlsConnectTls12, ClientAuthCheckSigAlg) {
TEST_P(TlsConnectTls12, ClientAuthBigRsaCheckSigAlg) {
Reset(TlsAgent::kServerRsa, TlsAgent::kRsa2048);
- auto capture_cert_verify =
- std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeCertificateVerify);
- client_->SetPacketFilter(capture_cert_verify);
+ auto capture_cert_verify = std::make_shared<TlsHandshakeRecorder>(
+ client_, kTlsHandshakeCertificateVerify);
+ client_->SetFilter(capture_cert_verify);
client_->SetupClientAuth();
server_->RequestClientAuth(true);
Connect();
@@ -161,8 +158,8 @@ TEST_P(TlsConnectTls12, ClientAuthBigRsaCheckSigAlg) {
class TlsZeroCertificateRequestSigAlgsFilter : public TlsHandshakeFilter {
public:
- TlsZeroCertificateRequestSigAlgsFilter()
- : TlsHandshakeFilter({kTlsHandshakeCertificateRequest}) {}
+ TlsZeroCertificateRequestSigAlgsFilter(const std::shared_ptr<TlsAgent>& agent)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeCertificateRequest}) {}
virtual PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) {
@@ -207,12 +204,12 @@ class TlsZeroCertificateRequestSigAlgsFilter : public TlsHandshakeFilter {
// supported_signature_algorithms in the CertificateRequest message.
TEST_P(TlsConnectTls12, ClientAuthNoSigAlgsFallback) {
EnsureTlsSetup();
- auto filter = std::make_shared<TlsZeroCertificateRequestSigAlgsFilter>();
- server_->SetPacketFilter(filter);
- auto capture_cert_verify =
- std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeCertificateVerify);
- client_->SetPacketFilter(capture_cert_verify);
+ auto filter =
+ std::make_shared<TlsZeroCertificateRequestSigAlgsFilter>(server_);
+ server_->SetFilter(filter);
+ auto capture_cert_verify = std::make_shared<TlsHandshakeRecorder>(
+ client_, kTlsHandshakeCertificateVerify);
+ client_->SetFilter(capture_cert_verify);
client_->SetupClientAuth();
server_->RequestClientAuth(true);
@@ -360,8 +357,8 @@ TEST_P(TlsConnectPre12, SignatureAlgorithmNoOverlapEcdsa) {
// The signature_algorithms extension is mandatory in TLS 1.3.
TEST_P(TlsConnectTls13, SignatureAlgorithmDrop) {
- client_->SetPacketFilter(
- std::make_shared<TlsExtensionDropper>(ssl_signature_algorithms_xtn));
+ client_->SetFilter(std::make_shared<TlsExtensionDropper>(
+ client_, ssl_signature_algorithms_xtn));
ConnectExpectAlert(server_, kTlsAlertMissingExtension);
client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
server_->CheckErrorCode(SSL_ERROR_MISSING_SIGNATURE_ALGORITHMS_EXTENSION);
@@ -370,8 +367,8 @@ TEST_P(TlsConnectTls13, SignatureAlgorithmDrop) {
// TLS 1.2 has trouble detecting this sort of modification: it uses SHA1 and
// only fails when the Finished is checked.
TEST_P(TlsConnectTls12, SignatureAlgorithmDrop) {
- client_->SetPacketFilter(
- std::make_shared<TlsExtensionDropper>(ssl_signature_algorithms_xtn));
+ client_->SetFilter(std::make_shared<TlsExtensionDropper>(
+ client_, ssl_signature_algorithms_xtn));
ConnectExpectAlert(server_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
@@ -389,11 +386,11 @@ class BeforeFinished : public TlsRecordFilter {
enum HandshakeState { BEFORE_CCS, AFTER_CCS, DONE };
public:
- BeforeFinished(std::shared_ptr<TlsAgent>& client,
- std::shared_ptr<TlsAgent>& server, VoidFunction before_ccs,
- VoidFunction before_finished)
- : client_(client),
- server_(server),
+ BeforeFinished(const std::shared_ptr<TlsAgent>& client,
+ const std::shared_ptr<TlsAgent>& server,
+ VoidFunction before_ccs, VoidFunction before_finished)
+ : TlsRecordFilter(server),
+ client_(client),
before_ccs_(before_ccs),
before_finished_(before_finished),
state_(BEFORE_CCS) {}
@@ -413,7 +410,7 @@ class BeforeFinished : public TlsRecordFilter {
// but that means that they both get processed together.
DataBuffer ccs;
header.Write(&ccs, 0, body);
- server_.lock()->SendDirect(ccs);
+ agent()->SendDirect(ccs);
client_.lock()->Handshake();
state_ = AFTER_CCS;
// Request that the original record be dropped by the filter.
@@ -438,7 +435,6 @@ class BeforeFinished : public TlsRecordFilter {
private:
std::weak_ptr<TlsAgent> client_;
- std::weak_ptr<TlsAgent> server_;
VoidFunction before_ccs_;
VoidFunction before_finished_;
HandshakeState state_;
@@ -463,8 +459,8 @@ class BeforeFinished13 : public PacketFilter {
};
public:
- BeforeFinished13(std::shared_ptr<TlsAgent>& client,
- std::shared_ptr<TlsAgent>& server,
+ BeforeFinished13(const std::shared_ptr<TlsAgent>& client,
+ const std::shared_ptr<TlsAgent>& server,
VoidFunction before_finished)
: client_(client),
server_(server),
@@ -514,7 +510,7 @@ static SECStatus AuthCompleteBlock(TlsAgent*, PRBool, PRBool) {
// processed by the client, SSL_AuthCertificateComplete() is called.
TEST_F(TlsConnectDatagram13, AuthCompleteBeforeFinished) {
client_->SetAuthCertificateCallback(AuthCompleteBlock);
- server_->SetPacketFilter(
+ server_->SetFilter(
std::make_shared<BeforeFinished13>(client_, server_, [this]() {
EXPECT_EQ(SECSuccess,
SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
@@ -546,7 +542,7 @@ TEST_F(TlsConnectDatagram13, AuthCompleteAfterFinished) {
TEST_P(TlsConnectGenericPre13, ClientWriteBetweenCCSAndFinishedWithFalseStart) {
client_->EnableFalseStart();
- server_->SetPacketFilter(std::make_shared<BeforeFinished>(
+ server_->SetFilter(std::make_shared<BeforeFinished>(
client_, server_,
[this]() { EXPECT_TRUE(client_->can_falsestart_hook_called()); },
[this]() {
@@ -562,7 +558,7 @@ TEST_P(TlsConnectGenericPre13, ClientWriteBetweenCCSAndFinishedWithFalseStart) {
TEST_P(TlsConnectGenericPre13, AuthCompleteBeforeFinishedWithFalseStart) {
client_->EnableFalseStart();
client_->SetAuthCertificateCallback(AuthCompleteBlock);
- server_->SetPacketFilter(std::make_shared<BeforeFinished>(
+ server_->SetFilter(std::make_shared<BeforeFinished>(
client_, server_,
[]() {
// Do nothing before CCS
@@ -608,7 +604,7 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) {
EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
// The client should send nothing from here on.
- client_->SetPacketFilter(std::make_shared<EnforceNoActivity>());
+ client_->SetFilter(std::make_shared<EnforceNoActivity>());
client_->Handshake();
EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
@@ -618,8 +614,8 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) {
EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
- // Remove this before closing or the close_notify alert will trigger it.
- client_->DeletePacketFilter();
+ // Remove filter before closing or the close_notify alert will trigger it.
+ client_->ClearFilter();
}
TEST_P(TlsConnectGenericPre13, AuthCompleteFailDelayed) {
@@ -634,12 +630,12 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteFailDelayed) {
EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
// The client should send nothing from here on.
- client_->SetPacketFilter(std::make_shared<EnforceNoActivity>());
+ client_->SetFilter(std::make_shared<EnforceNoActivity>());
client_->Handshake();
EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
// Report failure.
- client_->DeletePacketFilter();
+ client_->ClearFilter();
client_->ExpectSendAlert(kTlsAlertBadCertificate);
EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(),
SSL_ERROR_BAD_CERTIFICATE));
@@ -659,12 +655,12 @@ TEST_P(TlsConnectTls13, AuthCompleteDelayed) {
EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
// The client will send nothing until AuthCertificateComplete is called.
- client_->SetPacketFilter(std::make_shared<EnforceNoActivity>());
+ client_->SetFilter(std::make_shared<EnforceNoActivity>());
client_->Handshake();
EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
// This should allow the handshake to complete now.
- client_->DeletePacketFilter();
+ client_->ClearFilter();
EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
client_->Handshake(); // Send Finished
server_->Handshake(); // Transition to connected and send NewSessionTicket
@@ -682,12 +678,12 @@ TEST_P(TlsConnectTls13, AuthCompleteFailDelayed) {
EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
// The client will send nothing until AuthCertificateComplete is called.
- client_->SetPacketFilter(std::make_shared<EnforceNoActivity>());
+ client_->SetFilter(std::make_shared<EnforceNoActivity>());
client_->Handshake();
EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
// Report failure.
- client_->DeletePacketFilter();
+ client_->ClearFilter();
ExpectAlert(client_, kTlsAlertBadCertificate);
EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(),
SSL_ERROR_BAD_CERTIFICATE));
@@ -831,9 +827,9 @@ TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigServer) {
TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigClient) {
Reset(certificate_);
- auto capture =
- std::make_shared<TlsExtensionCapture>(ssl_signature_algorithms_xtn);
- client_->SetPacketFilter(capture);
+ auto capture = std::make_shared<TlsExtensionCapture>(
+ client_, ssl_signature_algorithms_xtn);
+ client_->SetFilter(capture);
TestSignatureSchemeConfig(client_);
const DataBuffer& ext = capture->extension();
@@ -907,4 +903,4 @@ INSTANTIATE_TEST_CASE_P(
TlsAgent::kServerEcdsa384),
::testing::Values(ssl_auth_ecdsa),
::testing::Values(ssl_sig_ecdsa_sha1)));
-}
+} // namespace nss_test
diff --git a/gtests/ssl_gtest/ssl_cert_ext_unittest.cc b/gtests/ssl_gtest/ssl_cert_ext_unittest.cc
index 36ee104af..00b55a8c9 100644
--- a/gtests/ssl_gtest/ssl_cert_ext_unittest.cc
+++ b/gtests/ssl_gtest/ssl_cert_ext_unittest.cc
@@ -181,8 +181,8 @@ TEST_P(TlsConnectGenericPre13, OcspMangled) {
static const uint8_t val[] = {1};
auto replacer = std::make_shared<TlsExtensionReplacer>(
- ssl_cert_status_xtn, DataBuffer(val, sizeof(val)));
- server_->SetPacketFilter(replacer);
+ server_, ssl_cert_status_xtn, DataBuffer(val, sizeof(val)));
+ server_->SetFilter(replacer);
ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
@@ -192,8 +192,8 @@ TEST_P(TlsConnectGeneric, OcspSuccess) {
EnsureTlsSetup();
client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE);
auto capture_ocsp =
- std::make_shared<TlsExtensionCapture>(ssl_cert_status_xtn);
- server_->SetPacketFilter(capture_ocsp);
+ std::make_shared<TlsExtensionCapture>(server_, ssl_cert_status_xtn);
+ server_->SetFilter(capture_ocsp);
// The value should be available during the AuthCertificateCallback
client_->SetAuthCertificateCallback([](TlsAgent* agent, bool checksig,
@@ -245,4 +245,4 @@ TEST_P(TlsConnectGeneric, OcspHugeSuccess) {
Connect();
}
-} // namespace nspr_test
+} // namespace nss_test
diff --git a/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc b/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
index 206ee1961..fa2238be7 100644
--- a/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
+++ b/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
@@ -466,4 +466,4 @@ static const SecStatusParams kSecStatusTestValuesArr[] = {
INSTANTIATE_TEST_CASE_P(TestSecurityStatus, SecurityStatusTest,
::testing::ValuesIn(kSecStatusTestValuesArr));
-} // namespace nspr_test
+} // namespace nss_test
diff --git a/gtests/ssl_gtest/ssl_custext_unittest.cc b/gtests/ssl_gtest/ssl_custext_unittest.cc
index dad944a1f..7233b2218 100644
--- a/gtests/ssl_gtest/ssl_custext_unittest.cc
+++ b/gtests/ssl_gtest/ssl_custext_unittest.cc
@@ -150,9 +150,9 @@ TEST_F(TlsConnectStreamTls13, CustomExtensionWriterDisable) {
client_->ssl_fd(), ssl_signed_cert_timestamp_xtn, NoopExtensionWriter,
nullptr, NoopExtensionHandler, nullptr);
EXPECT_EQ(SECSuccess, rv);
- auto capture =
- std::make_shared<TlsExtensionCapture>(ssl_signed_cert_timestamp_xtn);
- client_->SetPacketFilter(capture);
+ auto capture = std::make_shared<TlsExtensionCapture>(
+ client_, ssl_signed_cert_timestamp_xtn);
+ client_->SetFilter(capture);
Connect();
// So nothing will be sent.
@@ -204,9 +204,9 @@ TEST_F(TlsConnectStreamTls13, CustomExtensionOverride) {
EXPECT_EQ(SECSuccess, rv);
// Capture it to see what we got.
- auto capture =
- std::make_shared<TlsExtensionCapture>(ssl_signed_cert_timestamp_xtn);
- client_->SetPacketFilter(capture);
+ auto capture = std::make_shared<TlsExtensionCapture>(
+ client_, ssl_signed_cert_timestamp_xtn);
+ client_->SetFilter(capture);
ConnectExpectAlert(server_, kTlsAlertDecodeError);
@@ -246,8 +246,8 @@ TEST_F(TlsConnectStreamTls13, CustomExtensionClientToServer) {
EXPECT_EQ(SECSuccess, rv);
// Capture it to see what we got.
- auto capture = std::make_shared<TlsExtensionCapture>(extension_code);
- client_->SetPacketFilter(capture);
+ auto capture = std::make_shared<TlsExtensionCapture>(client_, extension_code);
+ client_->SetFilter(capture);
// Handle it so that the handshake completes.
rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
@@ -290,9 +290,9 @@ TEST_F(TlsConnectStreamTls13, CustomExtensionServerToClientSH) {
EXPECT_EQ(SECSuccess, rv);
// Capture the extension from the ServerHello only and check it.
- auto capture = std::make_shared<TlsExtensionCapture>(extension_code);
+ auto capture = std::make_shared<TlsExtensionCapture>(server_, extension_code);
capture->SetHandshakeTypes({kTlsHandshakeServerHello});
- server_->SetPacketFilter(capture);
+ server_->SetFilter(capture);
Connect();
@@ -329,9 +329,10 @@ TEST_F(TlsConnectStreamTls13, CustomExtensionServerToClientEE) {
EXPECT_EQ(SECSuccess, rv);
// Capture the extension from the EncryptedExtensions only and check it.
- auto capture = std::make_shared<TlsExtensionCapture>(extension_code);
+ auto capture = std::make_shared<TlsExtensionCapture>(server_, extension_code);
capture->SetHandshakeTypes({kTlsHandshakeEncryptedExtensions});
- server_->SetTlsRecordFilter(capture);
+ capture->EnableDecryption();
+ server_->SetFilter(capture);
Connect();
@@ -350,8 +351,8 @@ TEST_F(TlsConnectStreamTls13, CustomExtensionUnsolicitedServer) {
EXPECT_EQ(SECSuccess, rv);
// Capture it to see what we got.
- auto capture = std::make_shared<TlsExtensionCapture>(extension_code);
- server_->SetPacketFilter(capture);
+ auto capture = std::make_shared<TlsExtensionCapture>(server_, extension_code);
+ server_->SetFilter(capture);
client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
@@ -500,4 +501,4 @@ TEST_F(TlsConnectStreamTls13, CustomExtensionOverrunBuffer) {
client_->CheckErrorCode(SEC_ERROR_APPLICATION_CALLBACK_ERROR);
}
-} // namespace "nss_test"
+} // namespace nss_test
diff --git a/gtests/ssl_gtest/ssl_damage_unittest.cc b/gtests/ssl_gtest/ssl_damage_unittest.cc
index d1668b823..e9e625a03 100644
--- a/gtests/ssl_gtest/ssl_damage_unittest.cc
+++ b/gtests/ssl_gtest/ssl_damage_unittest.cc
@@ -50,7 +50,7 @@ TEST_F(TlsConnectTest, DamageSecretHandleServerFinished) {
SSL_LIBRARY_VERSION_TLS_1_3);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
SSL_LIBRARY_VERSION_TLS_1_3);
- server_->SetPacketFilter(std::make_shared<AfterRecordN>(
+ server_->SetFilter(std::make_shared<AfterRecordN>(
server_, client_,
0, // ServerHello.
[this]() { SSLInt_DamageServerHsTrafficSecret(client_->ssl_fd()); }));
@@ -60,9 +60,10 @@ TEST_F(TlsConnectTest, DamageSecretHandleServerFinished) {
TEST_P(TlsConnectGenericPre13, DamageServerSignature) {
EnsureTlsSetup();
- auto filter =
- std::make_shared<TlsLastByteDamager>(kTlsHandshakeServerKeyExchange);
- server_->SetTlsRecordFilter(filter);
+ auto filter = std::make_shared<TlsLastByteDamager>(
+ server_, kTlsHandshakeServerKeyExchange);
+ filter->EnableDecryption();
+ server_->SetFilter(filter);
ExpectAlert(client_, kTlsAlertDecryptError);
ConnectExpectFail();
client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
@@ -71,9 +72,10 @@ TEST_P(TlsConnectGenericPre13, DamageServerSignature) {
TEST_P(TlsConnectTls13, DamageServerSignature) {
EnsureTlsSetup();
- auto filter =
- std::make_shared<TlsLastByteDamager>(kTlsHandshakeCertificateVerify);
- server_->SetTlsRecordFilter(filter);
+ auto filter = std::make_shared<TlsLastByteDamager>(
+ server_, kTlsHandshakeCertificateVerify);
+ filter->EnableDecryption();
+ server_->SetFilter(filter);
ConnectExpectAlert(client_, kTlsAlertDecryptError);
client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
}
@@ -82,9 +84,10 @@ TEST_P(TlsConnectGeneric, DamageClientSignature) {
EnsureTlsSetup();
client_->SetupClientAuth();
server_->RequestClientAuth(true);
- auto filter =
- std::make_shared<TlsLastByteDamager>(kTlsHandshakeCertificateVerify);
- client_->SetTlsRecordFilter(filter);
+ auto filter = std::make_shared<TlsLastByteDamager>(
+ client_, kTlsHandshakeCertificateVerify);
+ filter->EnableDecryption();
+ client_->SetFilter(filter);
server_->ExpectSendAlert(kTlsAlertDecryptError);
// Do these handshakes by hand to avoid race condition on
// the client processing the server's alert.
@@ -100,4 +103,4 @@ TEST_P(TlsConnectGeneric, DamageClientSignature) {
server_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
}
-} // namespace nspr_test
+} // namespace nss_test
diff --git a/gtests/ssl_gtest/ssl_dhe_unittest.cc b/gtests/ssl_gtest/ssl_dhe_unittest.cc
index 899720607..b61728203 100644
--- a/gtests/ssl_gtest/ssl_dhe_unittest.cc
+++ b/gtests/ssl_gtest/ssl_dhe_unittest.cc
@@ -32,12 +32,12 @@ TEST_P(TlsConnectTls13, SharesForBothEcdheAndDhe) {
client_->ConfigNamedGroups(kAllDHEGroups);
auto groups_capture =
- std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn);
auto shares_capture =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn);
std::vector<std::shared_ptr<PacketFilter>> captures = {groups_capture,
shares_capture};
- client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures));
+ client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures));
Connect();
@@ -61,12 +61,12 @@ TEST_P(TlsConnectGeneric, ConnectFfdheClient) {
EnableOnlyDheCiphers();
client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
auto groups_capture =
- std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn);
auto shares_capture =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn);
std::vector<std::shared_ptr<PacketFilter>> captures = {groups_capture,
shares_capture};
- client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures));
+ client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures));
Connect();
@@ -103,8 +103,8 @@ TEST_P(TlsConnectGenericPre13, ConnectFfdheServer) {
class TlsDheServerKeyExchangeDamager : public TlsHandshakeFilter {
public:
- TlsDheServerKeyExchangeDamager()
- : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}) {}
+ TlsDheServerKeyExchangeDamager(const std::shared_ptr<TlsAgent>& agent)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}) {}
virtual PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) {
@@ -122,7 +122,7 @@ class TlsDheServerKeyExchangeDamager : public TlsHandshakeFilter {
TEST_P(TlsConnectGenericPre13, DamageServerKeyShare) {
EnableOnlyDheCiphers();
client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
- server_->SetPacketFilter(std::make_shared<TlsDheServerKeyExchangeDamager>());
+ server_->SetFilter(std::make_shared<TlsDheServerKeyExchangeDamager>(server_));
ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
@@ -141,8 +141,9 @@ class TlsDheSkeChangeY : public TlsHandshakeFilter {
kYZeroPad
};
- TlsDheSkeChangeY(uint8_t handshake_type, ChangeYTo change)
- : TlsHandshakeFilter({handshake_type}), change_Y_(change) {}
+ TlsDheSkeChangeY(const std::shared_ptr<TlsAgent>& agent,
+ uint8_t handshake_type, ChangeYTo change)
+ : TlsHandshakeFilter(agent, {handshake_type}), change_Y_(change) {}
protected:
void ChangeY(const DataBuffer& input, DataBuffer* output, size_t offset,
@@ -207,8 +208,9 @@ class TlsDheSkeChangeY : public TlsHandshakeFilter {
class TlsDheSkeChangeYServer : public TlsDheSkeChangeY {
public:
- TlsDheSkeChangeYServer(ChangeYTo change, bool modify)
- : TlsDheSkeChangeY(kTlsHandshakeServerKeyExchange, change),
+ TlsDheSkeChangeYServer(const std::shared_ptr<TlsAgent>& agent,
+ ChangeYTo change, bool modify)
+ : TlsDheSkeChangeY(agent, kTlsHandshakeServerKeyExchange, change),
modify_(modify),
p_() {}
@@ -245,9 +247,9 @@ class TlsDheSkeChangeYServer : public TlsDheSkeChangeY {
class TlsDheSkeChangeYClient : public TlsDheSkeChangeY {
public:
TlsDheSkeChangeYClient(
- ChangeYTo change,
+ const std::shared_ptr<TlsAgent>& agent, ChangeYTo change,
std::shared_ptr<const TlsDheSkeChangeYServer> server_filter)
- : TlsDheSkeChangeY(kTlsHandshakeClientKeyExchange, change),
+ : TlsDheSkeChangeY(agent, kTlsHandshakeClientKeyExchange, change),
server_filter_(server_filter) {}
protected:
@@ -282,8 +284,8 @@ TEST_P(TlsDamageDHYTest, DamageServerY) {
client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
}
TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam());
- server_->SetPacketFilter(
- std::make_shared<TlsDheSkeChangeYServer>(change, true));
+ server_->SetFilter(
+ std::make_shared<TlsDheSkeChangeYServer>(server_, change, true));
if (change == TlsDheSkeChangeY::kYZeroPad) {
ExpectAlert(client_, kTlsAlertDecryptError);
@@ -312,14 +314,14 @@ TEST_P(TlsDamageDHYTest, DamageClientY) {
client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
}
// The filter on the server is required to capture the prime.
- auto server_filter =
- std::make_shared<TlsDheSkeChangeYServer>(TlsDheSkeChangeY::kYZero, false);
- server_->SetPacketFilter(server_filter);
+ auto server_filter = std::make_shared<TlsDheSkeChangeYServer>(
+ server_, TlsDheSkeChangeY::kYZero, false);
+ server_->SetFilter(server_filter);
// The client filter does the damage.
TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam());
- client_->SetPacketFilter(
- std::make_shared<TlsDheSkeChangeYClient>(change, server_filter));
+ client_->SetFilter(
+ std::make_shared<TlsDheSkeChangeYClient>(client_, change, server_filter));
if (change == TlsDheSkeChangeY::kYZeroPad) {
ExpectAlert(server_, kTlsAlertDecryptError);
@@ -358,7 +360,9 @@ INSTANTIATE_TEST_CASE_P(
class TlsDheSkeMakePEven : public TlsHandshakeFilter {
public:
- TlsDheSkeMakePEven() : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}) {}
+ TlsDheSkeMakePEven(const std::shared_ptr<TlsAgent>& agent)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}) {}
+
virtual PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) {
@@ -379,7 +383,7 @@ class TlsDheSkeMakePEven : public TlsHandshakeFilter {
// Even without requiring named groups, an even value for p is bad news.
TEST_P(TlsConnectGenericPre13, MakeDhePEven) {
EnableOnlyDheCiphers();
- server_->SetPacketFilter(std::make_shared<TlsDheSkeMakePEven>());
+ server_->SetFilter(std::make_shared<TlsDheSkeMakePEven>(server_));
ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
@@ -389,7 +393,9 @@ TEST_P(TlsConnectGenericPre13, MakeDhePEven) {
class TlsDheSkeZeroPadP : public TlsHandshakeFilter {
public:
- TlsDheSkeZeroPadP() : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}) {}
+ TlsDheSkeZeroPadP(const std::shared_ptr<TlsAgent>& agent)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}) {}
+
virtual PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) {
@@ -407,7 +413,7 @@ class TlsDheSkeZeroPadP : public TlsHandshakeFilter {
// Zero padding only causes signature failure.
TEST_P(TlsConnectGenericPre13, PadDheP) {
EnableOnlyDheCiphers();
- server_->SetPacketFilter(std::make_shared<TlsDheSkeZeroPadP>());
+ server_->SetFilter(std::make_shared<TlsDheSkeZeroPadP>(server_));
ConnectExpectAlert(client_, kTlsAlertDecryptError);
@@ -529,12 +535,12 @@ TEST_P(TlsConnectTls13, ResumeFfdhe) {
Reset();
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
EnableOnlyDheCiphers();
- auto clientCapture =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
- client_->SetPacketFilter(clientCapture);
- auto serverCapture =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
- server_->SetPacketFilter(serverCapture);
+ auto clientCapture = std::make_shared<TlsExtensionCapture>(
+ client_, ssl_tls13_pre_shared_key_xtn);
+ client_->SetFilter(clientCapture);
+ auto serverCapture = std::make_shared<TlsExtensionCapture>(
+ server_, ssl_tls13_pre_shared_key_xtn);
+ server_->SetFilter(serverCapture);
ExpectResumption(RESUME_TICKET);
Connect();
CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign,
@@ -545,8 +551,9 @@ TEST_P(TlsConnectTls13, ResumeFfdhe) {
class TlsDheSkeChangeSignature : public TlsHandshakeFilter {
public:
- TlsDheSkeChangeSignature(uint16_t version, const uint8_t* data, size_t len)
- : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}),
+ TlsDheSkeChangeSignature(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t version, const uint8_t* data, size_t len)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}),
version_(version),
data_(data),
len_(len) {}
@@ -595,8 +602,8 @@ TEST_P(TlsConnectGenericPre13, InvalidDERSignatureFfdhe) {
const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ffdhe_2048};
client_->ConfigNamedGroups(client_groups);
- server_->SetPacketFilter(std::make_shared<TlsDheSkeChangeSignature>(
- version_, kBogusDheSignature, sizeof(kBogusDheSignature)));
+ server_->SetFilter(std::make_shared<TlsDheSkeChangeSignature>(
+ server_, version_, kBogusDheSignature, sizeof(kBogusDheSignature)));
ConnectExpectAlert(client_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
diff --git a/gtests/ssl_gtest/ssl_drop_unittest.cc b/gtests/ssl_gtest/ssl_drop_unittest.cc
index c059e9938..ee8906deb 100644
--- a/gtests/ssl_gtest/ssl_drop_unittest.cc
+++ b/gtests/ssl_gtest/ssl_drop_unittest.cc
@@ -22,13 +22,13 @@ extern "C" {
namespace nss_test {
TEST_P(TlsConnectDatagramPre13, DropClientFirstFlightOnce) {
- client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x1));
+ client_->SetFilter(std::make_shared<SelectiveDropFilter>(0x1));
Connect();
SendReceive();
}
TEST_P(TlsConnectDatagramPre13, DropServerFirstFlightOnce) {
- server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x1));
+ server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x1));
Connect();
SendReceive();
}
@@ -37,32 +37,32 @@ TEST_P(TlsConnectDatagramPre13, DropServerFirstFlightOnce) {
// flights that they send. Note: In DTLS 1.3, the shorter handshake means that
// this will also drop some application data, so we can't call SendReceive().
TEST_P(TlsConnectDatagramPre13, DropAllFirstTransmissions) {
- client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x15));
- server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x5));
+ client_->SetFilter(std::make_shared<SelectiveDropFilter>(0x15));
+ server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x5));
Connect();
}
// This drops the server's first flight three times.
TEST_P(TlsConnectDatagramPre13, DropServerFirstFlightThrice) {
- server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x7));
+ server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x7));
Connect();
}
// This drops the client's second flight once
TEST_P(TlsConnectDatagramPre13, DropClientSecondFlightOnce) {
- client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x2));
+ client_->SetFilter(std::make_shared<SelectiveDropFilter>(0x2));
Connect();
}
// This drops the client's second flight three times.
TEST_P(TlsConnectDatagramPre13, DropClientSecondFlightThrice) {
- client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0xe));
+ client_->SetFilter(std::make_shared<SelectiveDropFilter>(0xe));
Connect();
}
// This drops the server's second flight three times.
TEST_P(TlsConnectDatagramPre13, DropServerSecondFlightThrice) {
- server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0xe));
+ server_->SetFilter(std::make_shared<SelectiveDropFilter>(0xe));
Connect();
}
@@ -74,7 +74,7 @@ class TlsDropDatagram13 : public TlsConnectDatagram13 {
expected_client_acks_(0),
expected_server_acks_(1) {}
- void SetUp() {
+ void SetUp() override {
TlsConnectDatagram13::SetUp();
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
SetFilters();
@@ -82,12 +82,8 @@ class TlsDropDatagram13 : public TlsConnectDatagram13 {
void SetFilters() {
EnsureTlsSetup();
- client_->SetPacketFilter(client_filters_.chain_);
- client_filters_.ack_->SetAgent(client_.get());
- client_filters_.ack_->EnableDecryption();
- server_->SetPacketFilter(server_filters_.chain_);
- server_filters_.ack_->SetAgent(server_.get());
- server_filters_.ack_->EnableDecryption();
+ client_filters_.Init(client_);
+ server_filters_.Init(server_);
}
void HandshakeAndAck(const std::shared_ptr<TlsAgent>& agent) {
@@ -119,11 +115,17 @@ class TlsDropDatagram13 : public TlsConnectDatagram13 {
class DropAckChain {
public:
DropAckChain()
- : records_(std::make_shared<TlsRecordRecorder>()),
- ack_(std::make_shared<TlsRecordRecorder>(content_ack)),
- drop_(std::make_shared<SelectiveRecordDropFilter>(0, false)),
- chain_(std::make_shared<ChainedPacketFilter>(
- ChainedPacketFilterInit({records_, ack_, drop_}))) {}
+ : records_(nullptr), ack_(nullptr), drop_(nullptr), chain_(nullptr) {}
+
+ void Init(const std::shared_ptr<TlsAgent>& agent) {
+ records_ = std::make_shared<TlsRecordRecorder>(agent);
+ ack_ = std::make_shared<TlsRecordRecorder>(agent, content_ack);
+ ack_->EnableDecryption();
+ drop_ = std::make_shared<SelectiveRecordDropFilter>(agent, 0, false);
+ chain_ = std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit({records_, ack_, drop_}));
+ agent->SetFilter(chain_);
+ }
const TlsRecord& record(size_t i) const { return records_->record(i); }
@@ -227,7 +229,7 @@ TEST_F(TlsDropDatagram13, DropServerSecondRecordOnce) {
HandshakeAndAck(client_);
expected_client_acks_ = 1;
CheckedHandshakeSendReceive();
- CheckAcks(client_filters_, 0, {0});
+ CheckAcks(client_filters_, 0, {0}); // ServerHello
CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
}
@@ -257,7 +259,7 @@ TEST_F(TlsDropDatagram13, DropServerAckOnce) {
CheckPostHandshake();
// There should be two copies of the finished ACK
CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
- CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+ CheckAcks(server_filters_, 1, {0x0002000000000000ULL});
}
// Drop the client certificate verify.
@@ -276,10 +278,9 @@ TEST_F(TlsDropDatagram13, DropClientCertVerify) {
// Ack of the whole client handshake.
CheckAcks(
server_filters_, 1,
- {0x0002000000000000ULL, // CH (we drop everything after this on client)
- 0x0002000000000003ULL, // CT (2)
- 0x0002000000000004ULL} // FIN (2)
- );
+ {0x0002000000000000ULL, // CH (we drop everything after this on client)
+ 0x0002000000000003ULL, // CT (2)
+ 0x0002000000000004ULL}); // FIN (2)
}
// Shrink the MTU down so that certs get split and drop the first piece.
@@ -303,10 +304,9 @@ TEST_F(TlsDropDatagram13, DropFirstHalfOfServerCertificate) {
EXPECT_EQ(ct1_size, server_filters_.record(0).buffer.len());
CheckedHandshakeSendReceive();
CheckAcks(client_filters_, 0,
- {0, // SH
- 0x0002000000000000ULL, // EE
- 0x0002000000000002ULL} // CT2
- );
+ {0, // SH
+ 0x0002000000000000ULL, // EE
+ 0x0002000000000002ULL}); // CT2
CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
}
@@ -540,7 +540,10 @@ TEST_F(TlsDropDatagram13, NoDropsDuringZeroRtt) {
ExpectEarlyDataAccepted(true);
CheckConnected();
SendReceive();
- CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+ EXPECT_EQ(0U, client_filters_.ack_->count());
+ CheckAcks(server_filters_, 0,
+ {0x0001000000000001ULL, // EOED
+ 0x0002000000000000ULL}); // Finished
}
TEST_F(TlsDropDatagram13, DropEEDuringZeroRtt) {
@@ -558,7 +561,9 @@ TEST_F(TlsDropDatagram13, DropEEDuringZeroRtt) {
CheckConnected();
SendReceive();
CheckAcks(client_filters_, 0, {0});
- CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+ CheckAcks(server_filters_, 0,
+ {0x0001000000000002ULL, // EOED
+ 0x0002000000000000ULL}); // Finished
}
class TlsReorderDatagram13 : public TlsDropDatagram13 {
@@ -688,6 +693,7 @@ TEST_F(TlsDropDatagram13, SendOutOfOrderHsNonsenseWithHandshakeKey) {
kTlsHandshakeType, DataBuffer(buf, sizeof(buf))));
server_->Handshake();
EXPECT_EQ(2UL, server_filters_.ack_->count());
+ // The server acknowledges client Finished twice.
CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
CheckAcks(server_filters_, 1, {0x0002000000000000ULL});
}
@@ -746,7 +752,9 @@ TEST_F(TlsReorderDatagram13, DataAfterEOEDDuringZeroRtt) {
ReSend(TlsAgent::CLIENT, std::vector<size_t>({1, 0, 2}));
server_->Handshake();
CheckConnected();
- CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+ EXPECT_EQ(0U, client_filters_.ack_->count());
+ // Acknowledgements for EOED and Finished.
+ CheckAcks(server_filters_, 0, {0x0001000000000002ULL, 0x0002000000000000ULL});
uint8_t buf[8];
rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf));
EXPECT_EQ(-1, rv);
@@ -783,7 +791,9 @@ TEST_F(TlsReorderDatagram13, DataAfterFinDuringZeroRtt) {
ReSend(TlsAgent::CLIENT, std::vector<size_t>({1, 2, 0}));
server_->Handshake();
CheckConnected();
- CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+ EXPECT_EQ(0U, client_filters_.ack_->count());
+ // Acknowledgements for EOED and Finished.
+ CheckAcks(server_filters_, 0, {0x0001000000000002ULL, 0x0002000000000000ULL});
uint8_t buf[8];
rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf));
EXPECT_EQ(-1, rv);
diff --git a/gtests/ssl_gtest/ssl_ecdh_unittest.cc b/gtests/ssl_gtest/ssl_ecdh_unittest.cc
index 8caa7bf71..f0b499dc4 100644
--- a/gtests/ssl_gtest/ssl_ecdh_unittest.cc
+++ b/gtests/ssl_gtest/ssl_ecdh_unittest.cc
@@ -75,9 +75,9 @@ TEST_P(TlsConnectGeneric, ConnectEcdheP384Client) {
// This causes a HelloRetryRequest in TLS 1.3. Earlier versions don't care.
TEST_P(TlsConnectGeneric, ConnectEcdheP384Server) {
EnsureTlsSetup();
- auto hrr_capture = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeHelloRetryRequest);
- server_->SetPacketFilter(hrr_capture);
+ auto hrr_capture = std::make_shared<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeHelloRetryRequest);
+ server_->SetFilter(hrr_capture);
const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
server_->ConfigNamedGroups(groups);
Connect();
@@ -193,8 +193,8 @@ TEST_P(TlsConnectGenericPre13, P384PriorityFromModelSocket) {
class TlsKeyExchangeGroupCapture : public TlsHandshakeFilter {
public:
- TlsKeyExchangeGroupCapture()
- : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}),
+ TlsKeyExchangeGroupCapture(const std::shared_ptr<TlsAgent> &agent)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}),
group_(ssl_grp_none) {}
SSLNamedGroup group() const { return group_; }
@@ -221,10 +221,10 @@ class TlsKeyExchangeGroupCapture : public TlsHandshakeFilter {
// P-256 is supported by the client (<= 1.2 only).
TEST_P(TlsConnectGenericPre13, DropSupportedGroupExtensionP256) {
EnsureTlsSetup();
- client_->SetPacketFilter(
- std::make_shared<TlsExtensionDropper>(ssl_supported_groups_xtn));
- auto group_capture = std::make_shared<TlsKeyExchangeGroupCapture>();
- server_->SetPacketFilter(group_capture);
+ client_->SetFilter(
+ std::make_shared<TlsExtensionDropper>(client_, ssl_supported_groups_xtn));
+ auto group_capture = std::make_shared<TlsKeyExchangeGroupCapture>(server_);
+ server_->SetFilter(group_capture);
ConnectExpectAlert(server_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
@@ -236,8 +236,8 @@ TEST_P(TlsConnectGenericPre13, DropSupportedGroupExtensionP256) {
// Supported groups is mandatory in TLS 1.3.
TEST_P(TlsConnectTls13, DropSupportedGroupExtension) {
EnsureTlsSetup();
- client_->SetPacketFilter(
- std::make_shared<TlsExtensionDropper>(ssl_supported_groups_xtn));
+ client_->SetFilter(
+ std::make_shared<TlsExtensionDropper>(client_, ssl_supported_groups_xtn));
ConnectExpectAlert(server_, kTlsAlertMissingExtension);
client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
server_->CheckErrorCode(SSL_ERROR_MISSING_SUPPORTED_GROUPS_EXTENSION);
@@ -516,7 +516,8 @@ TEST_P(TlsKeyExchangeTest13, MultipleClientShares) {
// Replace the point in the client key exchange message with an empty one
class ECCClientKEXFilter : public TlsHandshakeFilter {
public:
- ECCClientKEXFilter() : TlsHandshakeFilter({kTlsHandshakeClientKeyExchange}) {}
+ ECCClientKEXFilter(const std::shared_ptr<TlsAgent> &client)
+ : TlsHandshakeFilter(client, {kTlsHandshakeClientKeyExchange}) {}
protected:
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header,
@@ -532,7 +533,8 @@ class ECCClientKEXFilter : public TlsHandshakeFilter {
// Replace the point in the server key exchange message with an empty one
class ECCServerKEXFilter : public TlsHandshakeFilter {
public:
- ECCServerKEXFilter() : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}) {}
+ ECCServerKEXFilter(const std::shared_ptr<TlsAgent> &server)
+ : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}) {}
protected:
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header,
@@ -550,15 +552,13 @@ class ECCServerKEXFilter : public TlsHandshakeFilter {
};
TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyServerPoint) {
- // add packet filter
- server_->SetPacketFilter(std::make_shared<ECCServerKEXFilter>());
+ server_->SetFilter(std::make_shared<ECCServerKEXFilter>(server_));
ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_KEY_EXCH);
}
TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyClientPoint) {
- // add packet filter
- client_->SetPacketFilter(std::make_shared<ECCClientKEXFilter>());
+ client_->SetFilter(std::make_shared<ECCClientKEXFilter>(client_));
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_KEY_EXCH);
}
diff --git a/gtests/ssl_gtest/ssl_extension_unittest.cc b/gtests/ssl_gtest/ssl_extension_unittest.cc
index a9daeed82..295090cdd 100644
--- a/gtests/ssl_gtest/ssl_extension_unittest.cc
+++ b/gtests/ssl_gtest/ssl_extension_unittest.cc
@@ -19,8 +19,9 @@ namespace nss_test {
class TlsExtensionTruncator : public TlsExtensionFilter {
public:
- TlsExtensionTruncator(uint16_t extension, size_t length)
- : extension_(extension), length_(length) {}
+ TlsExtensionTruncator(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t extension, size_t length)
+ : TlsExtensionFilter(agent), extension_(extension), length_(length) {}
virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
const DataBuffer& input,
DataBuffer* output) {
@@ -42,8 +43,9 @@ class TlsExtensionTruncator : public TlsExtensionFilter {
class TlsExtensionDamager : public TlsExtensionFilter {
public:
- TlsExtensionDamager(uint16_t extension, size_t index)
- : extension_(extension), index_(index) {}
+ TlsExtensionDamager(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t extension, size_t index)
+ : TlsExtensionFilter(agent), extension_(extension), index_(index) {}
virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
const DataBuffer& input,
DataBuffer* output) {
@@ -63,8 +65,11 @@ class TlsExtensionDamager : public TlsExtensionFilter {
class TlsExtensionAppender : public TlsHandshakeFilter {
public:
- TlsExtensionAppender(uint8_t handshake_type, uint16_t ext, DataBuffer& data)
- : TlsHandshakeFilter({handshake_type}), extension_(ext), data_(data) {}
+ TlsExtensionAppender(const std::shared_ptr<TlsAgent>& agent,
+ uint8_t handshake_type, uint16_t ext, DataBuffer& data)
+ : TlsHandshakeFilter(agent, {handshake_type}),
+ extension_(ext),
+ data_(data) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -124,13 +129,13 @@ class TlsExtensionTestBase : public TlsConnectTestBase {
void ClientHelloErrorTest(std::shared_ptr<PacketFilter> filter,
uint8_t desc = kTlsAlertDecodeError) {
- client_->SetPacketFilter(filter);
+ client_->SetFilter(filter);
ConnectExpectAlert(server_, desc);
}
void ServerHelloErrorTest(std::shared_ptr<PacketFilter> filter,
uint8_t desc = kTlsAlertDecodeError) {
- server_->SetPacketFilter(filter);
+ server_->SetFilter(filter);
ConnectExpectAlert(client_, desc);
}
@@ -156,7 +161,7 @@ class TlsExtensionTestBase : public TlsConnectTestBase {
StartConnect();
client_->Handshake(); // Send ClientHello
server_->Handshake(); // Send HRR.
- client_->SetPacketFilter(std::make_shared<TlsExtensionDropper>(type));
+ client_->SetFilter(std::make_shared<TlsExtensionDropper>(client_, type));
Handshake();
client_->CheckErrorCode(client_error);
server_->CheckErrorCode(server_error);
@@ -197,8 +202,8 @@ class TlsExtensionTest13
void ConnectWithBogusVersionList(const uint8_t* buf, size_t len) {
DataBuffer versions_buf(buf, len);
- client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>(
- ssl_tls13_supported_versions_xtn, versions_buf));
+ client_->SetFilter(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_tls13_supported_versions_xtn, versions_buf));
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
@@ -209,8 +214,8 @@ class TlsExtensionTest13
size_t index = versions_buf.Write(0, 2, 1);
versions_buf.Write(index, version, 2);
- client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>(
- ssl_tls13_supported_versions_xtn, versions_buf));
+ client_->SetFilter(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_tls13_supported_versions_xtn, versions_buf));
ConnectExpectFail();
}
};
@@ -241,26 +246,26 @@ class TlsExtensionTestPre13 : public TlsExtensionTestBase,
TEST_P(TlsExtensionTestGeneric, DamageSniLength) {
ClientHelloErrorTest(
- std::make_shared<TlsExtensionDamager>(ssl_server_name_xtn, 1));
+ std::make_shared<TlsExtensionDamager>(client_, ssl_server_name_xtn, 1));
}
TEST_P(TlsExtensionTestGeneric, DamageSniHostLength) {
ClientHelloErrorTest(
- std::make_shared<TlsExtensionDamager>(ssl_server_name_xtn, 4));
+ std::make_shared<TlsExtensionDamager>(client_, ssl_server_name_xtn, 4));
}
TEST_P(TlsExtensionTestGeneric, TruncateSni) {
ClientHelloErrorTest(
- std::make_shared<TlsExtensionTruncator>(ssl_server_name_xtn, 7));
+ std::make_shared<TlsExtensionTruncator>(client_, ssl_server_name_xtn, 7));
}
// A valid extension that appears twice will be reported as unsupported.
TEST_P(TlsExtensionTestGeneric, RepeatSni) {
DataBuffer extension;
InitSimpleSni(&extension);
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionInjector>(ssl_server_name_xtn, extension),
- kTlsAlertIllegalParameter);
+ ClientHelloErrorTest(std::make_shared<TlsExtensionInjector>(
+ client_, ssl_server_name_xtn, extension),
+ kTlsAlertIllegalParameter);
}
// An SNI entry with zero length is considered invalid (strangely, not if it is
@@ -272,23 +277,23 @@ TEST_P(TlsExtensionTestGeneric, BadSni) {
extension.Allocate(simple.len() + 3);
extension.Write(0, static_cast<uint32_t>(0), 3);
extension.Write(3, simple);
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionReplacer>(ssl_server_name_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_server_name_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, EmptySni) {
DataBuffer extension;
extension.Allocate(2);
extension.Write(0, static_cast<uint32_t>(0), 2);
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionReplacer>(ssl_server_name_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_server_name_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, EmptyAlpnExtension) {
EnableAlpn();
DataBuffer extension;
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension),
+ client_, ssl_app_layer_protocol_xtn, extension),
kTlsAlertIllegalParameter);
}
@@ -299,21 +304,21 @@ TEST_P(TlsExtensionTestGeneric, EmptyAlpnList) {
const uint8_t val[] = {0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension),
+ client_, ssl_app_layer_protocol_xtn, extension),
kTlsAlertNoApplicationProtocol);
}
TEST_P(TlsExtensionTestGeneric, OneByteAlpn) {
EnableAlpn();
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionTruncator>(ssl_app_layer_protocol_xtn, 1));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>(
+ client_, ssl_app_layer_protocol_xtn, 1));
}
TEST_P(TlsExtensionTestGeneric, AlpnMissingValue) {
EnableAlpn();
// This will leave the length of the second entry, but no value.
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionTruncator>(ssl_app_layer_protocol_xtn, 5));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>(
+ client_, ssl_app_layer_protocol_xtn, 5));
}
TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) {
@@ -321,7 +326,7 @@ TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) {
const uint8_t val[] = {0x01, 0x61, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ client_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, AlpnMismatch) {
@@ -340,7 +345,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyList) {
const uint8_t val[] = {0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) {
@@ -348,7 +353,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) {
const uint8_t val[] = {0x00, 0x01, 0x00};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) {
@@ -356,7 +361,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) {
const uint8_t val[] = {0x00, 0x02, 0x01, 0x61, 0x00};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) {
@@ -364,7 +369,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) {
const uint8_t val[] = {0x00, 0x04, 0x01, 0x61, 0x01, 0x62};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) {
@@ -372,7 +377,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) {
const uint8_t val[] = {0x00, 0x99, 0x01, 0x61, 0x00};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) {
@@ -380,7 +385,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) {
const uint8_t val[] = {0x00, 0x02, 0x99, 0x61};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedUnknownName) {
@@ -388,43 +393,43 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedUnknownName) {
const uint8_t val[] = {0x00, 0x02, 0x01, 0x67};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension),
+ server_, ssl_app_layer_protocol_xtn, extension),
kTlsAlertIllegalParameter);
}
TEST_P(TlsExtensionTestDtls, SrtpShort) {
EnableSrtp();
ClientHelloErrorTest(
- std::make_shared<TlsExtensionTruncator>(ssl_use_srtp_xtn, 3));
+ std::make_shared<TlsExtensionTruncator>(client_, ssl_use_srtp_xtn, 3));
}
TEST_P(TlsExtensionTestDtls, SrtpOdd) {
EnableSrtp();
const uint8_t val[] = {0x00, 0x01, 0xff, 0x00};
DataBuffer extension(val, sizeof(val));
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionReplacer>(ssl_use_srtp_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_use_srtp_xtn, extension));
}
TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsBadLength) {
const uint8_t val[] = {0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_signature_algorithms_xtn, extension));
+ client_, ssl_signature_algorithms_xtn, extension));
}
TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsTrailingData) {
const uint8_t val[] = {0x00, 0x02, 0x04, 0x01, 0x00}; // sha-256, rsa
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_signature_algorithms_xtn, extension));
+ client_, ssl_signature_algorithms_xtn, extension));
}
TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsEmpty) {
const uint8_t val[] = {0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_signature_algorithms_xtn, extension),
+ client_, ssl_signature_algorithms_xtn, extension),
kTlsAlertHandshakeFailure);
}
@@ -432,7 +437,7 @@ TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsNoOverlap) {
const uint8_t val[] = {0x00, 0x02, 0xff, 0xff};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_signature_algorithms_xtn, extension),
+ client_, ssl_signature_algorithms_xtn, extension),
kTlsAlertHandshakeFailure);
}
@@ -440,12 +445,12 @@ TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsOddLength) {
const uint8_t val[] = {0x00, 0x01, 0x04};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_signature_algorithms_xtn, extension));
+ client_, ssl_signature_algorithms_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, NoSupportedGroups) {
ClientHelloErrorTest(
- std::make_shared<TlsExtensionDropper>(ssl_supported_groups_xtn),
+ std::make_shared<TlsExtensionDropper>(client_, ssl_supported_groups_xtn),
version_ < SSL_LIBRARY_VERSION_TLS_1_3 ? kTlsAlertDecryptError
: kTlsAlertMissingExtension);
}
@@ -454,63 +459,63 @@ TEST_P(TlsExtensionTestGeneric, SupportedCurvesShort) {
const uint8_t val[] = {0x00, 0x01, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_elliptic_curves_xtn, extension));
+ client_, ssl_elliptic_curves_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, SupportedCurvesBadLength) {
const uint8_t val[] = {0x09, 0x99, 0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_elliptic_curves_xtn, extension));
+ client_, ssl_elliptic_curves_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, SupportedCurvesTrailingData) {
const uint8_t val[] = {0x00, 0x02, 0x00, 0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_elliptic_curves_xtn, extension));
+ client_, ssl_elliptic_curves_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, SupportedPointsEmpty) {
const uint8_t val[] = {0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_ec_point_formats_xtn, extension));
+ client_, ssl_ec_point_formats_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, SupportedPointsBadLength) {
const uint8_t val[] = {0x99, 0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_ec_point_formats_xtn, extension));
+ client_, ssl_ec_point_formats_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, SupportedPointsTrailingData) {
const uint8_t val[] = {0x01, 0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_ec_point_formats_xtn, extension));
+ client_, ssl_ec_point_formats_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, RenegotiationInfoBadLength) {
const uint8_t val[] = {0x99};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_renegotiation_info_xtn, extension));
+ client_, ssl_renegotiation_info_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, RenegotiationInfoMismatch) {
const uint8_t val[] = {0x01, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_renegotiation_info_xtn, extension));
+ client_, ssl_renegotiation_info_xtn, extension));
}
// The extension has to contain a length.
TEST_P(TlsExtensionTestPre13, RenegotiationInfoExtensionEmpty) {
DataBuffer extension;
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_renegotiation_info_xtn, extension));
+ client_, ssl_renegotiation_info_xtn, extension));
}
// This only works on TLS 1.2, since it relies on static RSA; otherwise libssl
@@ -519,10 +524,10 @@ TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) {
const SSLSignatureScheme schemes[] = {ssl_sig_rsa_pss_rsae_sha512,
ssl_sig_rsa_pss_rsae_sha384};
- auto capture =
- std::make_shared<TlsExtensionCapture>(ssl_signature_algorithms_xtn);
+ auto capture = std::make_shared<TlsExtensionCapture>(
+ client_, ssl_signature_algorithms_xtn);
client_->SetSignatureSchemes(schemes, PR_ARRAY_SIZE(schemes));
- client_->SetPacketFilter(capture);
+ client_->SetFilter(capture);
EnableOnlyStaticRsaCiphers();
Connect();
@@ -540,9 +545,9 @@ TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) {
// Temporary test to verify that we choke on an empty ClientKeyShare.
// This test will fail when we implement HelloRetryRequest.
TEST_P(TlsExtensionTest13, EmptyClientKeyShare) {
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionTruncator>(ssl_tls13_key_share_xtn, 2),
- kTlsAlertHandshakeFailure);
+ ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>(
+ client_, ssl_tls13_key_share_xtn, 2),
+ kTlsAlertHandshakeFailure);
}
// These tests only work in stream mode because the client sends a
@@ -551,8 +556,8 @@ TEST_P(TlsExtensionTest13, EmptyClientKeyShare) {
// packet gets dropped.
TEST_F(TlsExtensionTest13Stream, DropServerKeyShare) {
EnsureTlsSetup();
- server_->SetPacketFilter(
- std::make_shared<TlsExtensionDropper>(ssl_tls13_key_share_xtn));
+ server_->SetFilter(
+ std::make_shared<TlsExtensionDropper>(server_, ssl_tls13_key_share_xtn));
client_->ExpectSendAlert(kTlsAlertMissingExtension);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
@@ -572,8 +577,8 @@ TEST_F(TlsExtensionTest13Stream, WrongServerKeyShare) {
0x02};
DataBuffer buf(key_share, sizeof(key_share));
EnsureTlsSetup();
- server_->SetPacketFilter(
- std::make_shared<TlsExtensionReplacer>(ssl_tls13_key_share_xtn, buf));
+ server_->SetFilter(std::make_shared<TlsExtensionReplacer>(
+ server_, ssl_tls13_key_share_xtn, buf));
client_->ExpectSendAlert(kTlsAlertIllegalParameter);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
@@ -594,8 +599,8 @@ TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) {
0x02};
DataBuffer buf(key_share, sizeof(key_share));
EnsureTlsSetup();
- server_->SetPacketFilter(
- std::make_shared<TlsExtensionReplacer>(ssl_tls13_key_share_xtn, buf));
+ server_->SetFilter(std::make_shared<TlsExtensionReplacer>(
+ server_, ssl_tls13_key_share_xtn, buf));
client_->ExpectSendAlert(kTlsAlertMissingExtension);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
@@ -606,8 +611,8 @@ TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) {
TEST_F(TlsExtensionTest13Stream, AddServerSignatureAlgorithmsOnResumption) {
SetupForResume();
DataBuffer empty;
- server_->SetPacketFilter(std::make_shared<TlsExtensionInjector>(
- ssl_signature_algorithms_xtn, empty));
+ server_->SetFilter(std::make_shared<TlsExtensionInjector>(
+ server_, ssl_signature_algorithms_xtn, empty));
client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
@@ -627,8 +632,12 @@ typedef std::function<void(TlsPreSharedKeyReplacer*)>
class TlsPreSharedKeyReplacer : public TlsExtensionFilter {
public:
- TlsPreSharedKeyReplacer(TlsPreSharedKeyReplacerFunc function)
- : identities_(), binders_(), function_(function) {}
+ TlsPreSharedKeyReplacer(const std::shared_ptr<TlsAgent>& agent,
+ TlsPreSharedKeyReplacerFunc function)
+ : TlsExtensionFilter(agent),
+ identities_(),
+ binders_(),
+ function_(function) {}
static size_t CopyAndMaybeReplace(TlsParser* parser, size_t size,
const std::unique_ptr<DataBuffer>& replace,
@@ -742,8 +751,10 @@ class TlsPreSharedKeyReplacer : public TlsExtensionFilter {
TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) {
SetupForResume();
- client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>([](
- TlsPreSharedKeyReplacer* r) { r->identities_[0].identity.Truncate(0); }));
+ client_->SetFilter(std::make_shared<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
+ r->identities_[0].identity.Truncate(0);
+ }));
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
@@ -753,8 +764,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) {
TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) {
SetupForResume();
- client_->SetPacketFilter(
- std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) {
+ client_->SetFilter(std::make_shared<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
r->binders_[0].Write(0, r->binders_[0].data()[0] ^ 0xff, 1);
}));
ConnectExpectAlert(server_, kTlsAlertDecryptError);
@@ -766,8 +777,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) {
TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) {
SetupForResume();
- client_->SetPacketFilter(
- std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) {
+ client_->SetFilter(std::make_shared<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
r->binders_[0].Write(r->binders_[0].len(), 0xff, 1);
}));
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
@@ -779,7 +790,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) {
TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) {
SetupForResume();
- client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>(
+ client_->SetFilter(std::make_shared<TlsPreSharedKeyReplacer>(
+ client_,
[](TlsPreSharedKeyReplacer* r) { r->binders_[0].Truncate(31); }));
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
@@ -791,8 +803,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) {
TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) {
SetupForResume();
- client_->SetPacketFilter(
- std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) {
+ client_->SetFilter(std::make_shared<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
r->identities_.push_back(r->identities_[0]);
r->binders_.push_back(r->binders_[0]);
}));
@@ -806,8 +818,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) {
TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) {
SetupForResume();
- client_->SetPacketFilter(
- std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) {
+ client_->SetFilter(std::make_shared<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
r->identities_.push_back(r->identities_[0]);
}));
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
@@ -818,8 +830,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) {
TEST_F(TlsExtensionTest13Stream, ResumeOneIdentityTwoBinders) {
SetupForResume();
- client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>([](
- TlsPreSharedKeyReplacer* r) { r->binders_.push_back(r->binders_[0]); }));
+ client_->SetFilter(std::make_shared<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
+ r->binders_.push_back(r->binders_[0]);
+ }));
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
@@ -831,8 +845,8 @@ TEST_F(TlsExtensionTest13Stream, ResumePskExtensionNotLast) {
const uint8_t empty_buf[] = {0};
DataBuffer empty(empty_buf, 0);
// Inject an unused extension after the PSK extension.
- client_->SetPacketFilter(std::make_shared<TlsExtensionAppender>(
- kTlsHandshakeClientHello, 0xffff, empty));
+ client_->SetFilter(std::make_shared<TlsExtensionAppender>(
+ client_, kTlsHandshakeClientHello, 0xffff, empty));
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
@@ -842,8 +856,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeNoKeModes) {
SetupForResume();
DataBuffer empty;
- client_->SetPacketFilter(std::make_shared<TlsExtensionDropper>(
- ssl_tls13_psk_key_exchange_modes_xtn));
+ client_->SetFilter(std::make_shared<TlsExtensionDropper>(
+ client_, ssl_tls13_psk_key_exchange_modes_xtn));
ConnectExpectAlert(server_, kTlsAlertMissingExtension);
client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
server_->CheckErrorCode(SSL_ERROR_MISSING_PSK_KEY_EXCHANGE_MODES);
@@ -858,8 +872,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) {
kTls13PskKe};
DataBuffer modes(ke_modes, sizeof(ke_modes));
- client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>(
- ssl_tls13_psk_key_exchange_modes_xtn, modes));
+ client_->SetFilter(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_tls13_psk_key_exchange_modes_xtn, modes));
client_->ExpectSendAlert(kTlsAlertBadRecordMac);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
@@ -870,8 +884,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) {
TEST_P(TlsExtensionTest13, NoKeModesIfResumptionOff) {
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
auto capture = std::make_shared<TlsExtensionCapture>(
- ssl_tls13_psk_key_exchange_modes_xtn);
- client_->SetPacketFilter(capture);
+ client_, ssl_tls13_psk_key_exchange_modes_xtn);
+ client_->SetFilter(capture);
Connect();
EXPECT_FALSE(capture->captured());
}
@@ -966,12 +980,11 @@ class TlsBogusExtensionTest : public TlsConnectTestBase,
void AddFilter(uint8_t message, uint16_t extension) {
static uint8_t empty_buf[1] = {0};
DataBuffer empty(empty_buf, 0);
- auto filter =
- std::make_shared<TlsExtensionAppender>(message, extension, empty);
+ auto filter = std::make_shared<TlsExtensionAppender>(server_, message,
+ extension, empty);
+ server_->SetFilter(filter);
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
- server_->SetTlsRecordFilter(filter);
- } else {
- server_->SetPacketFilter(filter);
+ filter->EnableDecryption();
}
}
@@ -1087,8 +1100,9 @@ TEST_P(TlsConnectStream, IncludePadding) {
SECStatus rv = SSL_SetURL(client_->ssl_fd(), long_name);
EXPECT_EQ(SECSuccess, rv);
- auto capture = std::make_shared<TlsExtensionCapture>(ssl_padding_xtn);
- client_->SetPacketFilter(capture);
+ auto capture =
+ std::make_shared<TlsExtensionCapture>(client_, ssl_padding_xtn);
+ client_->SetFilter(capture);
client_->StartConnect();
client_->Handshake();
EXPECT_TRUE(capture->captured());
diff --git a/gtests/ssl_gtest/ssl_fragment_unittest.cc b/gtests/ssl_gtest/ssl_fragment_unittest.cc
index 64b824786..f4940bf28 100644
--- a/gtests/ssl_gtest/ssl_fragment_unittest.cc
+++ b/gtests/ssl_gtest/ssl_fragment_unittest.cc
@@ -149,13 +149,13 @@ class RecordFragmenter : public PacketFilter {
};
TEST_P(TlsConnectDatagram, FragmentClientPackets) {
- client_->SetPacketFilter(std::make_shared<RecordFragmenter>());
+ client_->SetFilter(std::make_shared<RecordFragmenter>());
Connect();
SendReceive();
}
TEST_P(TlsConnectDatagram, FragmentServerPackets) {
- server_->SetPacketFilter(std::make_shared<RecordFragmenter>());
+ server_->SetFilter(std::make_shared<RecordFragmenter>());
Connect();
SendReceive();
}
diff --git a/gtests/ssl_gtest/ssl_fuzz_unittest.cc b/gtests/ssl_gtest/ssl_fuzz_unittest.cc
index ab4c0eab7..03c98d245 100644
--- a/gtests/ssl_gtest/ssl_fuzz_unittest.cc
+++ b/gtests/ssl_gtest/ssl_fuzz_unittest.cc
@@ -27,7 +27,8 @@ class TlsFuzzTest : public ::testing::Test {};
// Record the application data stream.
class TlsApplicationDataRecorder : public TlsRecordFilter {
public:
- TlsApplicationDataRecorder() : buffer_() {}
+ TlsApplicationDataRecorder(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent), buffer_() {}
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
@@ -106,16 +107,18 @@ FUZZ_P(TlsConnectGeneric, DeterministicTranscript) {
DisableECDHEServerKeyReuse();
DataBuffer buffer;
- client_->SetPacketFilter(std::make_shared<TlsConversationRecorder>(buffer));
- server_->SetPacketFilter(std::make_shared<TlsConversationRecorder>(buffer));
+ client_->SetFilter(
+ std::make_shared<TlsConversationRecorder>(client_, buffer));
+ server_->SetFilter(
+ std::make_shared<TlsConversationRecorder>(server_, buffer));
// Reset the RNG state.
EXPECT_EQ(SECSuccess, RNG_RandomUpdate(NULL, 0));
Connect();
// Ensure the filters go away before |buffer| does.
- client_->DeletePacketFilter();
- server_->DeletePacketFilter();
+ client_->ClearFilter();
+ server_->ClearFilter();
if (last.len() > 0) {
EXPECT_EQ(last, buffer);
@@ -133,10 +136,10 @@ FUZZ_P(TlsConnectGeneric, ConnectSendReceive_NullCipher) {
EnsureTlsSetup();
// Set up app data filters.
- auto client_recorder = std::make_shared<TlsApplicationDataRecorder>();
- client_->SetPacketFilter(client_recorder);
- auto server_recorder = std::make_shared<TlsApplicationDataRecorder>();
- server_->SetPacketFilter(server_recorder);
+ auto client_recorder = std::make_shared<TlsApplicationDataRecorder>(client_);
+ client_->SetFilter(client_recorder);
+ auto server_recorder = std::make_shared<TlsApplicationDataRecorder>(server_);
+ server_->SetFilter(server_recorder);
Connect();
@@ -161,10 +164,10 @@ FUZZ_P(TlsConnectGeneric, ConnectSendReceive_NullCipher) {
FUZZ_P(TlsConnectGeneric, BogusClientFinished) {
EnsureTlsSetup();
- auto i1 = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
- kTlsHandshakeFinished,
+ auto filter = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
+ client_, kTlsHandshakeFinished,
DataBuffer(kShortEmptyFinished, sizeof(kShortEmptyFinished)));
- client_->SetPacketFilter(i1);
+ client_->SetFilter(filter);
Connect();
SendReceive();
}
@@ -173,10 +176,10 @@ FUZZ_P(TlsConnectGeneric, BogusClientFinished) {
FUZZ_P(TlsConnectGeneric, BogusServerFinished) {
EnsureTlsSetup();
- auto i1 = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
- kTlsHandshakeFinished,
+ auto filter = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
+ server_, kTlsHandshakeFinished,
DataBuffer(kLongEmptyFinished, sizeof(kLongEmptyFinished)));
- server_->SetPacketFilter(i1);
+ server_->SetFilter(filter);
Connect();
SendReceive();
}
@@ -187,7 +190,7 @@ FUZZ_P(TlsConnectGeneric, BogusServerAuthSignature) {
uint8_t msg_type = version_ == SSL_LIBRARY_VERSION_TLS_1_3
? kTlsHandshakeCertificateVerify
: kTlsHandshakeServerKeyExchange;
- server_->SetPacketFilter(std::make_shared<TlsLastByteDamager>(msg_type));
+ server_->SetFilter(std::make_shared<TlsLastByteDamager>(server_, msg_type));
Connect();
SendReceive();
}
@@ -197,8 +200,8 @@ FUZZ_P(TlsConnectGeneric, BogusClientAuthSignature) {
EnsureTlsSetup();
client_->SetupClientAuth();
server_->RequestClientAuth(true);
- client_->SetPacketFilter(
- std::make_shared<TlsLastByteDamager>(kTlsHandshakeCertificateVerify));
+ client_->SetFilter(std::make_shared<TlsLastByteDamager>(
+ client_, kTlsHandshakeCertificateVerify));
Connect();
}
@@ -219,29 +222,29 @@ FUZZ_P(TlsConnectGeneric, SessionTicketResumption) {
FUZZ_P(TlsConnectGeneric, UnencryptedSessionTickets) {
ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET);
- auto i1 = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeNewSessionTicket);
- server_->SetPacketFilter(i1);
+ auto filter = std::make_shared<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeNewSessionTicket);
+ server_->SetFilter(filter);
Connect();
- std::cerr << "ticket" << i1->buffer() << std::endl;
+ std::cerr << "ticket" << filter->buffer() << std::endl;
size_t offset = 4; /* lifetime */
if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) {
offset += 4; /* ticket_age_add */
uint32_t nonce_len = 0;
- EXPECT_TRUE(i1->buffer().Read(offset, 1, &nonce_len));
+ EXPECT_TRUE(filter->buffer().Read(offset, 1, &nonce_len));
offset += 1 + nonce_len;
}
offset += 2 + /* ticket length */
2; /* TLS_EX_SESS_TICKET_VERSION */
// Check the protocol version number.
uint32_t tls_version = 0;
- EXPECT_TRUE(i1->buffer().Read(offset, sizeof(version_), &tls_version));
+ EXPECT_TRUE(filter->buffer().Read(offset, sizeof(version_), &tls_version));
EXPECT_EQ(version_, static_cast<decltype(version_)>(tls_version));
// Check the cipher suite.
uint32_t suite = 0;
- EXPECT_TRUE(i1->buffer().Read(offset + sizeof(version_), 2, &suite));
+ EXPECT_TRUE(filter->buffer().Read(offset + sizeof(version_), 2, &suite));
client_->CheckCipherSuite(static_cast<uint16_t>(suite));
}
}
diff --git a/gtests/ssl_gtest/ssl_hrr_unittest.cc b/gtests/ssl_gtest/ssl_hrr_unittest.cc
index 93e19a720..55b6d6d90 100644
--- a/gtests/ssl_gtest/ssl_hrr_unittest.cc
+++ b/gtests/ssl_gtest/ssl_hrr_unittest.cc
@@ -35,17 +35,17 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) {
// Send first ClientHello and send 0-RTT data
auto capture_early_data =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_early_data_xtn);
- client_->SetPacketFilter(capture_early_data);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_early_data_xtn);
+ client_->SetFilter(capture_early_data);
client_->Handshake();
EXPECT_EQ(k0RttDataLen, PR_Write(client_->ssl_fd(), k0RttData,
k0RttDataLen)); // 0-RTT write.
EXPECT_TRUE(capture_early_data->captured());
// Send the HelloRetryRequest
- auto hrr_capture = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeHelloRetryRequest);
- server_->SetPacketFilter(hrr_capture);
+ auto hrr_capture = std::make_shared<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeHelloRetryRequest);
+ server_->SetFilter(hrr_capture);
server_->Handshake();
EXPECT_LT(0U, hrr_capture->buffer().len());
@@ -56,8 +56,8 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) {
// Make a new capture for the early data.
capture_early_data =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_early_data_xtn);
- client_->SetPacketFilter(capture_early_data);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_early_data_xtn);
+ client_->SetFilter(capture_early_data);
// Complete the handshake successfully
Handshake();
@@ -71,6 +71,10 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) {
// packet. If the record is split into two packets, or there are multiple
// handshake packets, this will break.
class CorrectMessageSeqAfterHrrFilter : public TlsRecordFilter {
+ public:
+ CorrectMessageSeqAfterHrrFilter(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent) {}
+
protected:
PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& record, size_t* offset,
@@ -131,8 +135,8 @@ TEST_P(TlsConnectTls13, SecondClientHelloRejectEarlyDataXtn) {
// Correct the DTLS message sequence number after an HRR.
if (variant_ == ssl_variant_datagram) {
- client_->SetPacketFilter(
- std::make_shared<CorrectMessageSeqAfterHrrFilter>());
+ client_->SetFilter(
+ std::make_shared<CorrectMessageSeqAfterHrrFilter>(client_));
}
server_->SetPeer(client_);
@@ -151,7 +155,8 @@ TEST_P(TlsConnectTls13, SecondClientHelloRejectEarlyDataXtn) {
class KeyShareReplayer : public TlsExtensionFilter {
public:
- KeyShareReplayer() {}
+ KeyShareReplayer(const std::shared_ptr<TlsAgent>& agent)
+ : TlsExtensionFilter(agent) {}
virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
const DataBuffer& input,
@@ -178,7 +183,7 @@ class KeyShareReplayer : public TlsExtensionFilter {
// server should reject this.
TEST_P(TlsConnectTls13, RetryWithSameKeyShare) {
EnsureTlsSetup();
- client_->SetPacketFilter(std::make_shared<KeyShareReplayer>());
+ client_->SetFilter(std::make_shared<KeyShareReplayer>(client_));
static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
ssl_grp_ec_secp521r1};
server_->ConfigNamedGroups(groups);
@@ -192,7 +197,7 @@ TEST_P(TlsConnectTls13, RetryWithSameKeyShare) {
TEST_P(TlsConnectTls13, RetryWithTwoShares) {
EnsureTlsSetup();
EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
- client_->SetPacketFilter(std::make_shared<KeyShareReplayer>());
+ client_->SetFilter(std::make_shared<KeyShareReplayer>(client_));
static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
ssl_grp_ec_secp521r1};
@@ -238,9 +243,10 @@ TEST_P(TlsConnectTls13, RetryCallbackAcceptGroupMismatch) {
return ssl_hello_retry_accept;
};
- auto capture = std::make_shared<TlsExtensionCapture>(ssl_tls13_cookie_xtn);
+ auto capture =
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_cookie_xtn);
capture->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
- server_->SetPacketFilter(capture);
+ server_->SetFilter(capture);
static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
server_->ConfigNamedGroups(groups);
@@ -359,14 +365,14 @@ SSLHelloRetryRequestAction RetryHello(PRBool firstHello,
TEST_P(TlsConnectTls13, RetryCallbackRetry) {
EnsureTlsSetup();
- auto capture_hrr = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- ssl_hs_hello_retry_request);
+ auto capture_hrr = std::make_shared<TlsHandshakeRecorder>(
+ server_, ssl_hs_hello_retry_request);
auto capture_key_share =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
std::vector<std::shared_ptr<PacketFilter>> chain = {capture_hrr,
capture_key_share};
- server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(chain));
+ server_->SetFilter(std::make_shared<ChainedPacketFilter>(chain));
size_t cb_called = 0;
EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
@@ -383,8 +389,8 @@ TEST_P(TlsConnectTls13, RetryCallbackRetry) {
<< "no key_share extension expected";
auto capture_cookie =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_cookie_xtn);
- client_->SetPacketFilter(capture_cookie);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_cookie_xtn);
+ client_->SetFilter(capture_cookie);
Handshake();
CheckConnected();
@@ -413,9 +419,9 @@ TEST_P(TlsConnectTls13, RetryCallbackRetryWithAdditionalShares) {
EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
auto capture_server =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
capture_server->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
- server_->SetPacketFilter(capture_server);
+ server_->SetFilter(capture_server);
size_t cb_called = 0;
EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
@@ -431,8 +437,8 @@ TEST_P(TlsConnectTls13, RetryCallbackRetryWithAdditionalShares) {
<< "no key_share extension expected from server";
auto capture_client_2nd =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
- client_->SetPacketFilter(capture_client_2nd);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn);
+ client_->SetFilter(capture_client_2nd);
Handshake();
CheckConnected();
@@ -449,12 +455,12 @@ TEST_P(TlsConnectTls13, RetryCallbackRetryWithGroupMismatch) {
EnsureTlsSetup();
auto capture_cookie =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_cookie_xtn);
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_cookie_xtn);
capture_cookie->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
auto capture_key_share =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
- server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(
+ server_->SetFilter(std::make_shared<ChainedPacketFilter>(
ChainedPacketFilterInit{capture_cookie, capture_key_share}));
static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
@@ -493,9 +499,9 @@ TEST_P(TlsConnectTls13, RetryCallbackRetryWithToken) {
EnsureTlsSetup();
auto capture_key_share =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
- server_->SetPacketFilter(capture_key_share);
+ server_->SetFilter(capture_key_share);
size_t cb_called = 0;
EXPECT_EQ(SECSuccess,
@@ -513,9 +519,9 @@ TEST_P(TlsConnectTls13, RetryCallbackRetryWithTokenAndGroupMismatch) {
server_->ConfigNamedGroups(groups);
auto capture_key_share =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
- server_->SetPacketFilter(capture_key_share);
+ server_->SetFilter(capture_key_share);
size_t cb_called = 0;
EXPECT_EQ(SECSuccess,
@@ -589,8 +595,8 @@ TEST_P(TlsConnectTls13, RetryStatefulDropCookie) {
EnsureTlsSetup();
TriggerHelloRetryRequest(client_, server_);
- client_->SetPacketFilter(
- std::make_shared<TlsExtensionDropper>(ssl_tls13_cookie_xtn));
+ client_->SetFilter(
+ std::make_shared<TlsExtensionDropper>(client_, ssl_tls13_cookie_xtn));
ExpectAlert(server_, kTlsAlertMissingExtension);
Handshake();
@@ -603,8 +609,9 @@ TEST_F(TlsConnectStreamTls13, RetryStatelessDamageFirstClientHello) {
ConfigureSelfEncrypt();
EnsureTlsSetup();
- auto damage_ch = std::make_shared<TlsExtensionInjector>(0xfff3, DataBuffer());
- client_->SetPacketFilter(damage_ch);
+ auto damage_ch =
+ std::make_shared<TlsExtensionInjector>(client_, 0xfff3, DataBuffer());
+ client_->SetFilter(damage_ch);
TriggerHelloRetryRequest(client_, server_);
MakeNewServer();
@@ -625,8 +632,9 @@ TEST_F(TlsConnectStreamTls13, RetryStatelessDamageSecondClientHello) {
TriggerHelloRetryRequest(client_, server_);
MakeNewServer();
- auto damage_ch = std::make_shared<TlsExtensionInjector>(0xfff3, DataBuffer());
- client_->SetPacketFilter(damage_ch);
+ auto damage_ch =
+ std::make_shared<TlsExtensionInjector>(client_, 0xfff3, DataBuffer());
+ client_->SetFilter(damage_ch);
// Key exchange fails when the handshake continues because client and server
// disagree about the transcript.
@@ -640,7 +648,7 @@ TEST_F(TlsConnectStreamTls13, RetryStatelessDamageSecondClientHello) {
// Read the cipher suite from the HRR and disable it on the identified agent.
static void DisableSuiteFromHrr(
std::shared_ptr<TlsAgent>& agent,
- std::shared_ptr<TlsInspectorRecordHandshakeMessage>& capture_hrr) {
+ std::shared_ptr<TlsHandshakeRecorder>& capture_hrr) {
uint32_t tmp;
size_t offset = 2 + 32; // skip version + server_random
ASSERT_TRUE(
@@ -657,9 +665,9 @@ TEST_P(TlsConnectTls13, RetryStatelessDisableSuiteClient) {
ConfigureSelfEncrypt();
EnsureTlsSetup();
- auto capture_hrr = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- ssl_hs_hello_retry_request);
- server_->SetPacketFilter(capture_hrr);
+ auto capture_hrr = std::make_shared<TlsHandshakeRecorder>(
+ server_, ssl_hs_hello_retry_request);
+ server_->SetFilter(capture_hrr);
TriggerHelloRetryRequest(client_, server_);
MakeNewServer();
@@ -678,9 +686,9 @@ TEST_P(TlsConnectTls13, RetryStatelessDisableSuiteServer) {
ConfigureSelfEncrypt();
EnsureTlsSetup();
- auto capture_hrr = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- ssl_hs_hello_retry_request);
- server_->SetPacketFilter(capture_hrr);
+ auto capture_hrr = std::make_shared<TlsHandshakeRecorder>(
+ server_, ssl_hs_hello_retry_request);
+ server_->SetFilter(capture_hrr);
TriggerHelloRetryRequest(client_, server_);
MakeNewServer();
@@ -761,8 +769,8 @@ TEST_F(TlsConnectStreamTls13, RetryWithDifferentCipherSuite) {
static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
server_->ConfigNamedGroups(groups);
// Then switch out the default suite (TLS_AES_128_GCM_SHA256).
- server_->SetPacketFilter(std::make_shared<SelectedCipherSuiteReplacer>(
- TLS_CHACHA20_POLY1305_SHA256));
+ server_->SetFilter(std::make_shared<SelectedCipherSuiteReplacer>(
+ server_, TLS_CHACHA20_POLY1305_SHA256));
client_->ExpectSendAlert(kTlsAlertIllegalParameter);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
@@ -777,7 +785,7 @@ TEST_F(TlsConnectDatagram13, DropClientSecondFlightWithHelloRetry) {
static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
ssl_grp_ec_secp521r1};
server_->ConfigNamedGroups(groups);
- server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x2));
+ server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x2));
Connect();
}
@@ -833,9 +841,9 @@ TEST_P(TlsKeyExchange13,
EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
auto capture_server =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
capture_server->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
- server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(
+ server_->SetFilter(std::make_shared<ChainedPacketFilter>(
ChainedPacketFilterInit{capture_hrr_, capture_server}));
size_t cb_called = 0;
diff --git a/gtests/ssl_gtest/ssl_keylog_unittest.cc b/gtests/ssl_gtest/ssl_keylog_unittest.cc
index 8ed342305..322b64837 100644
--- a/gtests/ssl_gtest/ssl_keylog_unittest.cc
+++ b/gtests/ssl_gtest/ssl_keylog_unittest.cc
@@ -20,8 +20,8 @@ static const std::string keylog_env = "SSLKEYLOGFILE=" + keylog_file_path;
class KeyLogFileTest : public TlsConnectGeneric {
public:
- void SetUp() {
- TlsConnectTestBase::SetUp();
+ void SetUp() override {
+ TlsConnectGeneric::SetUp();
// Remove previous results (if any).
(void)remove(keylog_file_path.c_str());
PR_SetEnv(keylog_env.c_str());
diff --git a/gtests/ssl_gtest/ssl_loopback_unittest.cc b/gtests/ssl_gtest/ssl_loopback_unittest.cc
index f1a789367..227c06234 100644
--- a/gtests/ssl_gtest/ssl_loopback_unittest.cc
+++ b/gtests/ssl_gtest/ssl_loopback_unittest.cc
@@ -56,7 +56,8 @@ TEST_P(TlsConnectGeneric, CipherSuiteMismatch) {
class TlsAlertRecorder : public TlsRecordFilter {
public:
- TlsAlertRecorder() : level_(255), description_(255) {}
+ TlsAlertRecorder(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent), level_(255), description_(255) {}
PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
@@ -86,9 +87,9 @@ class TlsAlertRecorder : public TlsRecordFilter {
class HelloTruncator : public TlsHandshakeFilter {
public:
- HelloTruncator()
+ HelloTruncator(const std::shared_ptr<TlsAgent>& agent)
: TlsHandshakeFilter(
- {kTlsHandshakeClientHello, kTlsHandshakeServerHello}) {}
+ agent, {kTlsHandshakeClientHello, kTlsHandshakeServerHello}) {}
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output) override {
@@ -99,9 +100,9 @@ class HelloTruncator : public TlsHandshakeFilter {
// Verify that when NSS reports that an alert is sent, it is actually sent.
TEST_P(TlsConnectGeneric, CaptureAlertServer) {
- client_->SetPacketFilter(std::make_shared<HelloTruncator>());
- auto alert_recorder = std::make_shared<TlsAlertRecorder>();
- server_->SetPacketFilter(alert_recorder);
+ client_->SetFilter(std::make_shared<HelloTruncator>(client_));
+ auto alert_recorder = std::make_shared<TlsAlertRecorder>(server_);
+ server_->SetFilter(alert_recorder);
ConnectExpectAlert(server_, kTlsAlertDecodeError);
EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
@@ -109,9 +110,9 @@ TEST_P(TlsConnectGeneric, CaptureAlertServer) {
}
TEST_P(TlsConnectGenericPre13, CaptureAlertClient) {
- server_->SetPacketFilter(std::make_shared<HelloTruncator>());
- auto alert_recorder = std::make_shared<TlsAlertRecorder>();
- client_->SetPacketFilter(alert_recorder);
+ server_->SetFilter(std::make_shared<HelloTruncator>(server_));
+ auto alert_recorder = std::make_shared<TlsAlertRecorder>(client_);
+ client_->SetFilter(alert_recorder);
ConnectExpectAlert(client_, kTlsAlertDecodeError);
EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
@@ -120,9 +121,9 @@ TEST_P(TlsConnectGenericPre13, CaptureAlertClient) {
// In TLS 1.3, the server can't read the client alert.
TEST_P(TlsConnectTls13, CaptureAlertClient) {
- server_->SetPacketFilter(std::make_shared<HelloTruncator>());
- auto alert_recorder = std::make_shared<TlsAlertRecorder>();
- client_->SetPacketFilter(alert_recorder);
+ server_->SetFilter(std::make_shared<HelloTruncator>(server_));
+ auto alert_recorder = std::make_shared<TlsAlertRecorder>(client_);
+ client_->SetFilter(alert_recorder);
StartConnect();
@@ -173,7 +174,8 @@ TEST_P(TlsConnectGeneric, ConnectSendReceive) {
class SaveTlsRecord : public TlsRecordFilter {
public:
- SaveTlsRecord(size_t index) : index_(index), count_(0), contents_() {}
+ SaveTlsRecord(const std::shared_ptr<TlsAgent>& agent, size_t index)
+ : TlsRecordFilter(agent), index_(index), count_(0), contents_() {}
const DataBuffer& contents() const { return contents_; }
@@ -198,8 +200,9 @@ class SaveTlsRecord : public TlsRecordFilter {
TEST_F(TlsConnectStreamTls13, DecryptRecordClient) {
EnsureTlsSetup();
// 0 = ClientHello, 1 = Finished, 2 = SendReceive, 3 = SendBuffer
- auto saved = std::make_shared<SaveTlsRecord>(3);
- client_->SetTlsRecordFilter(saved);
+ auto saved = std::make_shared<SaveTlsRecord>(client_, 3);
+ saved->EnableDecryption();
+ client_->SetFilter(saved);
Connect();
SendReceive();
@@ -215,8 +218,9 @@ TEST_F(TlsConnectStreamTls13, DecryptRecordServer) {
EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(),
SSL_ENABLE_SESSION_TICKETS, PR_FALSE));
// 0 = ServerHello, 1 = other handshake, 2 = SendReceive, 3 = SendBuffer
- auto saved = std::make_shared<SaveTlsRecord>(3);
- server_->SetTlsRecordFilter(saved);
+ auto saved = std::make_shared<SaveTlsRecord>(server_, 3);
+ saved->EnableDecryption();
+ server_->SetFilter(saved);
Connect();
SendReceive();
@@ -228,7 +232,8 @@ TEST_F(TlsConnectStreamTls13, DecryptRecordServer) {
class DropTlsRecord : public TlsRecordFilter {
public:
- DropTlsRecord(size_t index) : index_(index), count_(0) {}
+ DropTlsRecord(const std::shared_ptr<TlsAgent>& agent, size_t index)
+ : TlsRecordFilter(agent), index_(index), count_(0) {}
protected:
PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
@@ -253,7 +258,9 @@ TEST_F(TlsConnectStreamTls13, DropRecordServer) {
SSL_ENABLE_SESSION_TICKETS, PR_FALSE));
// 0 = ServerHello, 1 = other handshake, 2 = first write
- server_->SetTlsRecordFilter(std::make_shared<DropTlsRecord>(2));
+ auto filter = std::make_shared<DropTlsRecord>(server_, 2);
+ filter->EnableDecryption();
+ server_->SetFilter(filter);
Connect();
server_->SendData(23, 23); // This should be dropped, so it won't be counted.
server_->ResetSentBytes();
@@ -263,7 +270,9 @@ TEST_F(TlsConnectStreamTls13, DropRecordServer) {
TEST_F(TlsConnectStreamTls13, DropRecordClient) {
EnsureTlsSetup();
// 0 = ClientHello, 1 = Finished, 2 = first write
- client_->SetTlsRecordFilter(std::make_shared<DropTlsRecord>(2));
+ auto filter = std::make_shared<DropTlsRecord>(client_, 2);
+ filter->EnableDecryption();
+ client_->SetFilter(filter);
Connect();
client_->SendData(26, 26); // This should be dropped, so it won't be counted.
client_->ResetSentBytes();
@@ -371,7 +380,8 @@ TEST_P(TlsHolddownTest, TestDtlsHolddownExpiryResumption) {
class TlsPreCCSHeaderInjector : public TlsRecordFilter {
public:
- TlsPreCCSHeaderInjector() {}
+ TlsPreCCSHeaderInjector(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent) {}
virtual PacketFilter::Action FilterRecord(
const TlsRecordHeader& record_header, const DataBuffer& input,
size_t* offset, DataBuffer* output) override {
@@ -388,14 +398,14 @@ class TlsPreCCSHeaderInjector : public TlsRecordFilter {
};
TEST_P(TlsConnectStreamPre13, ClientFinishedHeaderBeforeCCS) {
- client_->SetPacketFilter(std::make_shared<TlsPreCCSHeaderInjector>());
+ client_->SetFilter(std::make_shared<TlsPreCCSHeaderInjector>(client_));
ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage);
client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER);
}
TEST_P(TlsConnectStreamPre13, ServerFinishedHeaderBeforeCCS) {
- server_->SetPacketFilter(std::make_shared<TlsPreCCSHeaderInjector>());
+ server_->SetFilter(std::make_shared<TlsPreCCSHeaderInjector>(server_));
StartConnect();
ExpectAlert(client_, kTlsAlertUnexpectedMessage);
Handshake();
@@ -476,8 +486,8 @@ TEST_F(TlsConnectTest, OneNRecordSplitting) {
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_0);
EnsureTlsSetup();
ConnectWithCipherSuite(TLS_RSA_WITH_AES_128_CBC_SHA);
- auto records = std::make_shared<TlsRecordRecorder>();
- server_->SetPacketFilter(records);
+ auto records = std::make_shared<TlsRecordRecorder>(server_);
+ server_->SetFilter(records);
// This should be split into 1, 16384 and 20.
DataBuffer big_buffer;
big_buffer.Allocate(1 + 16384 + 20);
diff --git a/gtests/ssl_gtest/ssl_record_unittest.cc b/gtests/ssl_gtest/ssl_record_unittest.cc
index d1d496f49..5aab1b352 100644
--- a/gtests/ssl_gtest/ssl_record_unittest.cc
+++ b/gtests/ssl_gtest/ssl_record_unittest.cc
@@ -103,8 +103,8 @@ TEST_P(TlsPaddingTest, LastByteOfPadWrong) {
class RecordReplacer : public TlsRecordFilter {
public:
- RecordReplacer(size_t size)
- : TlsRecordFilter(), enabled_(false), size_(size) {}
+ RecordReplacer(const std::shared_ptr<TlsAgent>& agent, size_t size)
+ : TlsRecordFilter(agent), enabled_(false), size_(size) {}
PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& data,
@@ -135,8 +135,9 @@ TEST_F(TlsConnectStreamTls13, LargeRecord) {
EnsureTlsSetup();
const size_t record_limit = 16384;
- auto replacer = std::make_shared<RecordReplacer>(record_limit);
- client_->SetTlsRecordFilter(replacer);
+ auto replacer = std::make_shared<RecordReplacer>(client_, record_limit);
+ replacer->EnableDecryption();
+ client_->SetFilter(replacer);
Connect();
replacer->Enable();
@@ -149,8 +150,9 @@ TEST_F(TlsConnectStreamTls13, TooLargeRecord) {
EnsureTlsSetup();
const size_t record_limit = 16384;
- auto replacer = std::make_shared<RecordReplacer>(record_limit + 1);
- client_->SetTlsRecordFilter(replacer);
+ auto replacer = std::make_shared<RecordReplacer>(client_, record_limit + 1);
+ replacer->EnableDecryption();
+ client_->SetFilter(replacer);
Connect();
replacer->Enable();
@@ -177,4 +179,4 @@ auto kTrueFalse = ::testing::ValuesIn(kTrueFalseArr);
INSTANTIATE_TEST_CASE_P(TlsPadding, TlsPaddingTest,
::testing::Combine(kContentSizes, kTrueFalse));
-} // namespace nspr_test
+} // namespace nss_test
diff --git a/gtests/ssl_gtest/ssl_resumption_unittest.cc b/gtests/ssl_gtest/ssl_resumption_unittest.cc
index 4f3a98ad4..9a80f00e6 100644
--- a/gtests/ssl_gtest/ssl_resumption_unittest.cc
+++ b/gtests/ssl_gtest/ssl_resumption_unittest.cc
@@ -219,8 +219,8 @@ TEST_P(TlsConnectGenericResumption, ConnectWithExpiredTicketAtClient) {
SSLExtensionType xtn = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3)
? ssl_tls13_pre_shared_key_xtn
: ssl_session_ticket_xtn;
- auto capture = std::make_shared<TlsExtensionCapture>(xtn);
- client_->SetPacketFilter(capture);
+ auto capture = std::make_shared<TlsExtensionCapture>(client_, xtn);
+ client_->SetFilter(capture);
Connect();
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
@@ -245,8 +245,8 @@ TEST_P(TlsConnectGeneric, ConnectWithExpiredTicketAtServer) {
SSLExtensionType xtn = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3)
? ssl_tls13_pre_shared_key_xtn
: ssl_session_ticket_xtn;
- auto capture = std::make_shared<TlsExtensionCapture>(xtn);
- client_->SetPacketFilter(capture);
+ auto capture = std::make_shared<TlsExtensionCapture>(client_, xtn);
+ client_->SetFilter(capture);
StartConnect();
client_->Handshake();
EXPECT_TRUE(capture->captured());
@@ -327,25 +327,25 @@ TEST_P(TlsConnectGeneric, ServerSNICertTypeSwitch) {
// Prior to TLS 1.3, we were not fully ephemeral; though 1.3 fixes that
TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceReuseKey) {
- auto i1 = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeServerKeyExchange);
- server_->SetPacketFilter(i1);
+ auto filter = std::make_shared<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
+ server_->SetFilter(filter);
Connect();
CheckKeys();
TlsServerKeyExchangeEcdhe dhe1;
- EXPECT_TRUE(dhe1.Parse(i1->buffer()));
+ EXPECT_TRUE(dhe1.Parse(filter->buffer()));
// Restart
Reset();
- auto i2 = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeServerKeyExchange);
- server_->SetPacketFilter(i2);
+ auto filter2 = std::make_shared<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
+ server_->SetFilter(filter2);
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
Connect();
CheckKeys();
TlsServerKeyExchangeEcdhe dhe2;
- EXPECT_TRUE(dhe2.Parse(i2->buffer()));
+ EXPECT_TRUE(dhe2.Parse(filter2->buffer()));
// Make sure they are the same.
EXPECT_EQ(dhe1.public_key_.len(), dhe2.public_key_.len());
@@ -356,26 +356,26 @@ TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceReuseKey) {
// This test parses the ServerKeyExchange, which isn't in 1.3
TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceNewKey) {
server_->SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
- auto i1 = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeServerKeyExchange);
- server_->SetPacketFilter(i1);
+ auto filter = std::make_shared<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
+ server_->SetFilter(filter);
Connect();
CheckKeys();
TlsServerKeyExchangeEcdhe dhe1;
- EXPECT_TRUE(dhe1.Parse(i1->buffer()));
+ EXPECT_TRUE(dhe1.Parse(filter->buffer()));
// Restart
Reset();
server_->SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
- auto i2 = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeServerKeyExchange);
- server_->SetPacketFilter(i2);
+ auto filter2 = std::make_shared<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
+ server_->SetFilter(filter2);
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
Connect();
CheckKeys();
TlsServerKeyExchangeEcdhe dhe2;
- EXPECT_TRUE(dhe2.Parse(i2->buffer()));
+ EXPECT_TRUE(dhe2.Parse(filter2->buffer()));
// Make sure they are different.
EXPECT_FALSE((dhe1.public_key_.len() == dhe2.public_key_.len()) &&
@@ -434,8 +434,9 @@ TEST_P(TlsConnectGenericResumption, TestResumeClientDifferentCipher) {
} else {
ticket_extension = ssl_session_ticket_xtn;
}
- auto ticket_capture = std::make_shared<TlsExtensionCapture>(ticket_extension);
- client_->SetPacketFilter(ticket_capture);
+ auto ticket_capture =
+ std::make_shared<TlsExtensionCapture>(client_, ticket_extension);
+ client_->SetFilter(ticket_capture);
Connect();
CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
EXPECT_EQ(0U, ticket_capture->extension().len());
@@ -468,8 +469,8 @@ TEST_P(TlsConnectStream, TestResumptionOverrideCipher) {
Reset();
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
- server_->SetPacketFilter(std::make_shared<SelectedCipherSuiteReplacer>(
- ChooseAnotherCipher(version_)));
+ server_->SetFilter(std::make_shared<SelectedCipherSuiteReplacer>(
+ server_, ChooseAnotherCipher(version_)));
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
client_->ExpectSendAlert(kTlsAlertIllegalParameter);
@@ -490,8 +491,10 @@ TEST_P(TlsConnectStream, TestResumptionOverrideCipher) {
class SelectedVersionReplacer : public TlsHandshakeFilter {
public:
- SelectedVersionReplacer(uint16_t version)
- : TlsHandshakeFilter({kTlsHandshakeServerHello}), version_(version) {}
+ SelectedVersionReplacer(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t version)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeServerHello}),
+ version_(version) {}
protected:
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
@@ -543,8 +546,8 @@ TEST_P(TlsConnectGenericPre13, TestResumptionOverrideVersion) {
// Enable the lower version on the client.
client_->SetVersionRange(version_ - 1, version_);
server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA);
- server_->SetPacketFilter(
- std::make_shared<SelectedVersionReplacer>(override_version));
+ server_->SetFilter(
+ std::make_shared<SelectedVersionReplacer>(server_, override_version));
ConnectExpectAlert(client_, kTlsAlertHandshakeFailure);
client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
@@ -567,8 +570,9 @@ TEST_F(TlsConnectTest, TestTls13ResumptionTwice) {
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
ExpectResumption(RESUME_TICKET);
- auto c1 = std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
- client_->SetPacketFilter(c1);
+ auto c1 = std::make_shared<TlsExtensionCapture>(client_,
+ ssl_tls13_pre_shared_key_xtn);
+ client_->SetFilter(c1);
Connect();
SendReceive();
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
@@ -584,8 +588,9 @@ TEST_F(TlsConnectTest, TestTls13ResumptionTwice) {
ClearStats();
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
- auto c2 = std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
- client_->SetPacketFilter(c2);
+ auto c2 = std::make_shared<TlsExtensionCapture>(client_,
+ ssl_tls13_pre_shared_key_xtn);
+ client_->SetFilter(c2);
ExpectResumption(RESUME_TICKET);
Connect();
SendReceive();
@@ -656,9 +661,10 @@ TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNSTWithToken) {
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
- auto nst_capture = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- ssl_hs_new_session_ticket);
- server_->SetTlsRecordFilter(nst_capture);
+ auto nst_capture = std::make_shared<TlsHandshakeRecorder>(
+ server_, ssl_hs_new_session_ticket);
+ nst_capture->EnableDecryption();
+ server_->SetFilter(nst_capture);
Connect();
// Clear the session ticket keys to invalidate the old ticket.
@@ -678,9 +684,9 @@ TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNSTWithToken) {
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
ExpectResumption(RESUME_TICKET);
- auto psk_capture =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
- client_->SetPacketFilter(psk_capture);
+ auto psk_capture = std::make_shared<TlsExtensionCapture>(
+ client_, ssl_tls13_pre_shared_key_xtn);
+ client_->SetFilter(psk_capture);
Connect();
SendReceive();
@@ -696,9 +702,10 @@ TEST_F(TlsConnectTest, SendSessionTicketWithTicketsDisabled) {
EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(),
SSL_ENABLE_SESSION_TICKETS, PR_FALSE));
- auto nst_capture = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- ssl_hs_new_session_ticket);
- server_->SetTlsRecordFilter(nst_capture);
+ auto nst_capture = std::make_shared<TlsHandshakeRecorder>(
+ server_, ssl_hs_new_session_ticket);
+ nst_capture->EnableDecryption();
+ server_->SetFilter(nst_capture);
Connect();
EXPECT_EQ(0U, nst_capture->buffer().len()) << "expect nothing captured yet";
@@ -714,9 +721,9 @@ TEST_F(TlsConnectTest, SendSessionTicketWithTicketsDisabled) {
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
ExpectResumption(RESUME_TICKET);
- auto psk_capture =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
- client_->SetPacketFilter(psk_capture);
+ auto psk_capture = std::make_shared<TlsExtensionCapture>(
+ client_, ssl_tls13_pre_shared_key_xtn);
+ client_->SetFilter(psk_capture);
Connect();
SendReceive();
@@ -819,20 +826,20 @@ TEST_F(TlsConnectTest, TestTls13ResumptionForcedDowngrade) {
// We will eventually fail the (sid.version == SH.version) check.
std::vector<std::shared_ptr<PacketFilter>> filters;
filters.push_back(std::make_shared<SelectedCipherSuiteReplacer>(
- TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256));
- filters.push_back(
- std::make_shared<SelectedVersionReplacer>(SSL_LIBRARY_VERSION_TLS_1_2));
+ server_, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256));
+ filters.push_back(std::make_shared<SelectedVersionReplacer>(
+ server_, SSL_LIBRARY_VERSION_TLS_1_2));
// Drop a bunch of extensions so that we get past the SH processing. The
// version extension says TLS 1.3, which is counter to our goal, the others
// are not permitted in TLS 1.2 handshakes.
+ filters.push_back(std::make_shared<TlsExtensionDropper>(
+ server_, ssl_tls13_supported_versions_xtn));
filters.push_back(
- std::make_shared<TlsExtensionDropper>(ssl_tls13_supported_versions_xtn));
- filters.push_back(
- std::make_shared<TlsExtensionDropper>(ssl_tls13_key_share_xtn));
- filters.push_back(
- std::make_shared<TlsExtensionDropper>(ssl_tls13_pre_shared_key_xtn));
- server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(filters));
+ std::make_shared<TlsExtensionDropper>(server_, ssl_tls13_key_share_xtn));
+ filters.push_back(std::make_shared<TlsExtensionDropper>(
+ server_, ssl_tls13_pre_shared_key_xtn));
+ server_->SetFilter(std::make_shared<ChainedPacketFilter>(filters));
// The client here generates an unexpected_message alert when it receives an
// encrypted handshake message from the server (EncryptedExtension). The
diff --git a/gtests/ssl_gtest/ssl_skip_unittest.cc b/gtests/ssl_gtest/ssl_skip_unittest.cc
index 335bfecfa..e4a9e5aed 100644
--- a/gtests/ssl_gtest/ssl_skip_unittest.cc
+++ b/gtests/ssl_gtest/ssl_skip_unittest.cc
@@ -22,8 +22,11 @@ namespace nss_test {
class TlsHandshakeSkipFilter : public TlsRecordFilter {
public:
// A TLS record filter that skips handshake messages of the identified type.
- TlsHandshakeSkipFilter(uint8_t handshake_type)
- : handshake_type_(handshake_type), skipped_(false) {}
+ TlsHandshakeSkipFilter(const std::shared_ptr<TlsAgent>& agent,
+ uint8_t handshake_type)
+ : TlsRecordFilter(agent),
+ handshake_type_(handshake_type),
+ skipped_(false) {}
protected:
// Takes a record; if it is a handshake record, it removes the first handshake
@@ -92,9 +95,14 @@ class TlsSkipTest : public TlsConnectTestBase,
TlsSkipTest()
: TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+ void SetUp() override {
+ TlsConnectTestBase::SetUp();
+ EnsureTlsSetup();
+ }
+
void ServerSkipTest(std::shared_ptr<PacketFilter> filter,
uint8_t alert = kTlsAlertUnexpectedMessage) {
- server_->SetPacketFilter(filter);
+ server_->SetFilter(filter);
ConnectExpectAlert(client_, alert);
}
};
@@ -105,9 +113,14 @@ class Tls13SkipTest : public TlsConnectTestBase,
Tls13SkipTest()
: TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}
- void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
+ void SetUp() override {
+ TlsConnectTestBase::SetUp();
EnsureTlsSetup();
- server_->SetTlsRecordFilter(filter);
+ }
+
+ void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
+ filter->EnableDecryption();
+ server_->SetFilter(filter);
ExpectAlert(client_, kTlsAlertUnexpectedMessage);
ConnectExpectFail();
client_->CheckErrorCode(error);
@@ -115,8 +128,8 @@ class Tls13SkipTest : public TlsConnectTestBase,
}
void ClientSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
- EnsureTlsSetup();
- client_->SetTlsRecordFilter(filter);
+ filter->EnableDecryption();
+ client_->SetFilter(filter);
server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
ConnectExpectFailOneSide(TlsAgent::SERVER);
@@ -129,48 +142,49 @@ class Tls13SkipTest : public TlsConnectTestBase,
TEST_P(TlsSkipTest, SkipCertificateRsa) {
EnableOnlyStaticRsaCiphers();
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipCertificateDhe) {
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
}
TEST_P(TlsSkipTest, SkipCertificateEcdhe) {
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
}
TEST_P(TlsSkipTest, SkipCertificateEcdsa) {
Reset(TlsAgent::kServerEcdsa256);
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
}
TEST_P(TlsSkipTest, SkipServerKeyExchange) {
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) {
Reset(TlsAgent::kServerEcdsa256);
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipCertAndKeyExch) {
- auto chain = std::make_shared<ChainedPacketFilter>(ChainedPacketFilterInit{
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate),
- std::make_shared<TlsHandshakeSkipFilter>(
- kTlsHandshakeServerKeyExchange)});
+ auto chain = std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit{std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate),
+ std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange)});
ServerSkipTest(chain);
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
@@ -178,48 +192,48 @@ TEST_P(TlsSkipTest, SkipCertAndKeyExch) {
TEST_P(TlsSkipTest, SkipCertAndKeyExchEcdsa) {
Reset(TlsAgent::kServerEcdsa256);
auto chain = std::make_shared<ChainedPacketFilter>();
- chain->Add(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
- chain->Add(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange));
+ chain->Add(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
+ chain->Add(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange));
ServerSkipTest(chain);
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(Tls13SkipTest, SkipEncryptedExtensions) {
ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
- kTlsHandshakeEncryptedExtensions),
+ server_, kTlsHandshakeEncryptedExtensions),
SSL_ERROR_RX_UNEXPECTED_CERTIFICATE);
}
TEST_P(Tls13SkipTest, SkipServerCertificate) {
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate),
- SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate),
+ SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
}
TEST_P(Tls13SkipTest, SkipServerCertificateVerify) {
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify),
- SSL_ERROR_RX_UNEXPECTED_FINISHED);
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificateVerify),
+ SSL_ERROR_RX_UNEXPECTED_FINISHED);
}
TEST_P(Tls13SkipTest, SkipClientCertificate) {
client_->SetupClientAuth();
server_->RequestClientAuth(true);
client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
- ClientSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate),
- SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
+ ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ client_, kTlsHandshakeCertificate),
+ SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
}
TEST_P(Tls13SkipTest, SkipClientCertificateVerify) {
client_->SetupClientAuth();
server_->RequestClientAuth(true);
client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
- ClientSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify),
- SSL_ERROR_RX_UNEXPECTED_FINISHED);
+ ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ client_, kTlsHandshakeCertificateVerify),
+ SSL_ERROR_RX_UNEXPECTED_FINISHED);
}
INSTANTIATE_TEST_CASE_P(
diff --git a/gtests/ssl_gtest/ssl_staticrsa_unittest.cc b/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
index e7fe44d92..c614af99f 100644
--- a/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
+++ b/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
@@ -48,10 +48,9 @@ TEST_P(TlsConnectGenericPre13, ConnectStaticRSA) {
// This test is stream so we can catch the bad_record_mac alert.
TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusCKE) {
EnableOnlyStaticRsaCiphers();
- auto i1 = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
- kTlsHandshakeClientKeyExchange,
- DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange)));
- client_->SetPacketFilter(i1);
+ client_->SetFilter(std::make_shared<TlsInspectorReplaceHandshakeMessage>(
+ client_, kTlsHandshakeClientKeyExchange,
+ DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange))));
ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
}
@@ -59,8 +58,8 @@ TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusCKE) {
// This test is stream so we can catch the bad_record_mac alert.
TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusPMSVersionDetect) {
EnableOnlyStaticRsaCiphers();
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
+ client_->SetFilter(
+ std::make_shared<TlsClientHelloVersionChanger>(client_, server_));
ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
}
@@ -69,8 +68,8 @@ TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusPMSVersionDetect) {
// ConnectStaticRSABogusPMSVersionDetect.
TEST_P(TlsConnectGenericPre13, ConnectStaticRSABogusPMSVersionIgnore) {
EnableOnlyStaticRsaCiphers();
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
+ client_->SetFilter(
+ std::make_shared<TlsClientHelloVersionChanger>(client_, server_));
server_->SetOption(SSL_ROLLBACK_DETECTION, PR_FALSE);
Connect();
}
@@ -79,10 +78,9 @@ TEST_P(TlsConnectGenericPre13, ConnectStaticRSABogusPMSVersionIgnore) {
TEST_P(TlsConnectStreamPre13, ConnectExtendedMasterSecretStaticRSABogusCKE) {
EnableOnlyStaticRsaCiphers();
EnableExtendedMasterSecret();
- auto inspect = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
- kTlsHandshakeClientKeyExchange,
- DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange)));
- client_->SetPacketFilter(inspect);
+ client_->SetFilter(std::make_shared<TlsInspectorReplaceHandshakeMessage>(
+ client_, kTlsHandshakeClientKeyExchange,
+ DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange))));
ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
}
@@ -91,8 +89,8 @@ TEST_P(TlsConnectStreamPre13,
ConnectExtendedMasterSecretStaticRSABogusPMSVersionDetect) {
EnableOnlyStaticRsaCiphers();
EnableExtendedMasterSecret();
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
+ client_->SetFilter(
+ std::make_shared<TlsClientHelloVersionChanger>(client_, server_));
ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
}
@@ -100,10 +98,10 @@ TEST_P(TlsConnectStreamPre13,
ConnectExtendedMasterSecretStaticRSABogusPMSVersionIgnore) {
EnableOnlyStaticRsaCiphers();
EnableExtendedMasterSecret();
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
+ client_->SetFilter(
+ std::make_shared<TlsClientHelloVersionChanger>(client_, server_));
server_->SetOption(SSL_ROLLBACK_DETECTION, PR_FALSE);
Connect();
}
-} // namespace nspr_test
+} // namespace nss_test
diff --git a/gtests/ssl_gtest/ssl_tls13compat_unittest.cc b/gtests/ssl_gtest/ssl_tls13compat_unittest.cc
index 75cee52fc..43f502fae 100644
--- a/gtests/ssl_gtest/ssl_tls13compat_unittest.cc
+++ b/gtests/ssl_gtest/ssl_tls13compat_unittest.cc
@@ -67,10 +67,7 @@ class Tls13CompatTest : public TlsConnectStreamTls13 {
private:
struct Recorders {
- Recorders()
- : records_(new TlsRecordRecorder()),
- hello_(new TlsInspectorRecordHandshakeMessage(std::set<uint8_t>(
- {kTlsHandshakeClientHello, kTlsHandshakeServerHello}))) {}
+ Recorders() : records_(nullptr), hello_(nullptr) {}
uint8_t session_id_length() const {
// session_id is always after version (2) and random (32).
@@ -91,12 +88,22 @@ class Tls13CompatTest : public TlsConnectStreamTls13 {
}
void Install(std::shared_ptr<TlsAgent>& agent) {
- agent->SetPacketFilter(std::make_shared<ChainedPacketFilter>(
+ if (records_ && records_->agent() == agent) {
+ // Avoid replacing the filters if they are already installed on this
+ // agent. This ensures that InstallFilters() can be used after
+ // MakeNewServer() without losing state on the client filters.
+ return;
+ }
+ records_.reset(new TlsRecordRecorder(agent));
+ hello_.reset(new TlsHandshakeRecorder(
+ agent, std::set<uint8_t>(
+ {kTlsHandshakeClientHello, kTlsHandshakeServerHello})));
+ agent->SetFilter(std::make_shared<ChainedPacketFilter>(
ChainedPacketFilterInit({records_, hello_})));
}
std::shared_ptr<TlsRecordRecorder> records_;
- std::shared_ptr<TlsInspectorRecordHandshakeMessage> hello_;
+ std::shared_ptr<TlsHandshakeRecorder> hello_;
};
void CheckRecordsAreTls12(const std::string& agent,
@@ -171,16 +178,20 @@ TEST_F(Tls13CompatTest, EnabledStatelessHrr) {
server_->StartConnect();
client_->Handshake();
server_->Handshake();
+
+ // The server should send CCS before HRR.
CheckForCCS(false, true);
- // A new server should just work, but not send another CCS.
+ // A new server should complete the handshake, and not send CCS.
MakeNewServer();
InstallFilters();
server_->ConfigNamedGroups({ssl_grp_ec_secp384r1});
Handshake();
CheckConnected();
- CheckForCompatHandshake();
+ CheckRecordVersions();
+ CheckHelloVersions();
+ CheckForCCS(true, false);
}
TEST_F(Tls13CompatTest, EnabledHrrZeroRtt) {
@@ -262,10 +273,10 @@ TEST_F(TlsConnectStreamTls13, ChangeCipherSpecBeforeClientHello12) {
TEST_F(TlsConnectDatagram13, CompatModeDtlsClient) {
EnsureTlsSetup();
client_->SetOption(SSL_ENABLE_TLS13_COMPAT_MODE, PR_TRUE);
- auto client_records = std::make_shared<TlsRecordRecorder>();
- client_->SetPacketFilter(client_records);
- auto server_records = std::make_shared<TlsRecordRecorder>();
- server_->SetPacketFilter(server_records);
+ auto client_records = std::make_shared<TlsRecordRecorder>(client_);
+ client_->SetFilter(client_records);
+ auto server_records = std::make_shared<TlsRecordRecorder>(server_);
+ server_->SetFilter(server_records);
Connect();
ASSERT_EQ(2U, client_records->count()); // CH, Fin
@@ -283,7 +294,8 @@ TEST_F(TlsConnectDatagram13, CompatModeDtlsClient) {
class AddSessionIdFilter : public TlsHandshakeFilter {
public:
- AddSessionIdFilter() : TlsHandshakeFilter({ssl_hs_client_hello}) {}
+ AddSessionIdFilter(const std::shared_ptr<TlsAgent>& client)
+ : TlsHandshakeFilter(client, {ssl_hs_client_hello}) {}
protected:
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
@@ -303,14 +315,14 @@ class AddSessionIdFilter : public TlsHandshakeFilter {
// mode. It should be ignored instead.
TEST_F(TlsConnectDatagram13, CompatModeDtlsServer) {
EnsureTlsSetup();
- auto client_records = std::make_shared<TlsRecordRecorder>();
- client_->SetPacketFilter(
+ auto client_records = std::make_shared<TlsRecordRecorder>(client_);
+ client_->SetFilter(
std::make_shared<ChainedPacketFilter>(ChainedPacketFilterInit(
- {client_records, std::make_shared<AddSessionIdFilter>()})));
- auto server_hello = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeServerHello);
- auto server_records = std::make_shared<TlsRecordRecorder>();
- server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(
+ {client_records, std::make_shared<AddSessionIdFilter>(client_)})));
+ auto server_hello =
+ std::make_shared<TlsHandshakeRecorder>(server_, kTlsHandshakeServerHello);
+ auto server_records = std::make_shared<TlsRecordRecorder>(server_);
+ server_->SetFilter(std::make_shared<ChainedPacketFilter>(
ChainedPacketFilterInit({server_records, server_hello})));
StartConnect();
client_->Handshake();
@@ -334,4 +346,4 @@ TEST_F(TlsConnectDatagram13, CompatModeDtlsServer) {
EXPECT_EQ(0U, session_id_len);
}
-} // nss_test
+} // namespace nss_test
diff --git a/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc b/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc
index 2f8ddd6fe..a590ee0ed 100644
--- a/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc
+++ b/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc
@@ -147,10 +147,10 @@ class SSLv2ClientHelloTestF : public TlsConnectTestBase {
SSLv2ClientHelloTestF(SSLProtocolVariant variant, uint16_t version)
: TlsConnectTestBase(variant, version), filter_(nullptr) {}
- void SetUp() {
+ void SetUp() override {
TlsConnectTestBase::SetUp();
filter_ = std::make_shared<SSLv2ClientHelloFilter>(client_, version_);
- client_->SetPacketFilter(filter_);
+ client_->SetFilter(filter_);
}
void SetExpectedVersion(uint16_t version) {
diff --git a/gtests/ssl_gtest/ssl_version_unittest.cc b/gtests/ssl_gtest/ssl_version_unittest.cc
index 9db293b07..747f019cf 100644
--- a/gtests/ssl_gtest/ssl_version_unittest.cc
+++ b/gtests/ssl_gtest/ssl_version_unittest.cc
@@ -56,18 +56,16 @@ TEST_P(TlsConnectGeneric, ServerNegotiateTls12) {
// two validate that we can also detect fallback using the
// SSL_SetDowngradeCheckVersion() API.
TEST_F(TlsConnectTest, TestDowngradeDetectionToTls11) {
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionSetter>(
- SSL_LIBRARY_VERSION_TLS_1_1));
+ client_->SetFilter(std::make_shared<TlsClientHelloVersionSetter>(
+ client_, SSL_LIBRARY_VERSION_TLS_1_1));
ConnectExpectFail();
ASSERT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code());
}
/* Attempt to negotiate the bogus DTLS 1.1 version. */
TEST_F(DtlsConnectTest, TestDtlsVersion11) {
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionSetter>(
- ((~0x0101) & 0xffff)));
+ client_->SetFilter(std::make_shared<TlsClientHelloVersionSetter>(
+ client_, ((~0x0101) & 0xffff)));
ConnectExpectFail();
// It's kind of surprising that SSL_ERROR_NO_CYPHER_OVERLAP is
// what is returned here, but this is deliberate in ssl3_HandleAlert().
@@ -78,9 +76,8 @@ TEST_F(DtlsConnectTest, TestDtlsVersion11) {
// Disabled as long as we have draft version.
TEST_F(TlsConnectTest, TestDowngradeDetectionToTls12) {
EnsureTlsSetup();
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionSetter>(
- SSL_LIBRARY_VERSION_TLS_1_2));
+ client_->SetFilter(std::make_shared<TlsClientHelloVersionSetter>(
+ client_, SSL_LIBRARY_VERSION_TLS_1_2));
client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
SSL_LIBRARY_VERSION_TLS_1_3);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
@@ -92,9 +89,8 @@ TEST_F(TlsConnectTest, TestDowngradeDetectionToTls12) {
// TLS 1.1 clients do not check the random values, so we should
// instead get a handshake failure alert from the server.
TEST_F(TlsConnectTest, TestDowngradeDetectionToTls10) {
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionSetter>(
- SSL_LIBRARY_VERSION_TLS_1_0));
+ client_->SetFilter(std::make_shared<TlsClientHelloVersionSetter>(
+ client_, SSL_LIBRARY_VERSION_TLS_1_0));
client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
SSL_LIBRARY_VERSION_TLS_1_1);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
@@ -177,12 +173,11 @@ class Tls13NoSupportedVersions : public TlsConnectStreamTls12 {
client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
SSL_LIBRARY_VERSION_TLS_1_2);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, max_server_version);
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionSetter>(
- overwritten_client_version));
- auto capture = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeServerHello);
- server_->SetPacketFilter(capture);
+ client_->SetFilter(std::make_shared<TlsClientHelloVersionSetter>(
+ client_, overwritten_client_version));
+ auto capture = std::make_shared<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerHello);
+ server_->SetFilter(capture);
ConnectExpectAlert(server_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
@@ -214,12 +209,11 @@ TEST_F(Tls13NoSupportedVersions,
// Offer 1.3 but with ClientHello.legacy_version == TLS 1.4. This
// causes a bad MAC error when we read EncryptedExtensions.
TEST_F(TlsConnectStreamTls13, Tls14ClientHelloWithSupportedVersions) {
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionSetter>(
- SSL_LIBRARY_VERSION_TLS_1_3 + 1));
- auto capture =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_supported_versions_xtn);
- server_->SetPacketFilter(capture);
+ client_->SetFilter(std::make_shared<TlsClientHelloVersionSetter>(
+ client_, SSL_LIBRARY_VERSION_TLS_1_3 + 1));
+ auto capture = std::make_shared<TlsExtensionCapture>(
+ server_, ssl_tls13_supported_versions_xtn);
+ server_->SetFilter(capture);
client_->ExpectSendAlert(kTlsAlertBadRecordMac);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
diff --git a/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc b/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc
index eda96831c..7f3c4a896 100644
--- a/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc
+++ b/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc
@@ -189,12 +189,12 @@ class TestPolicyVersionRange
}
}
- void SetUp() {
- SetPolicy(policy_.range());
+ void SetUp() override {
TlsConnectTestBase::SetUp();
+ SetPolicy(policy_.range());
}
- void TearDown() {
+ void TearDown() override {
TlsConnectTestBase::TearDown();
saved_version_policy_.RestoreOriginalPolicy();
}
diff --git a/gtests/ssl_gtest/test_io.cc b/gtests/ssl_gtest/test_io.cc
index adcdbfbaf..728217851 100644
--- a/gtests/ssl_gtest/test_io.cc
+++ b/gtests/ssl_gtest/test_io.cc
@@ -25,10 +25,6 @@ namespace nss_test {
if (g_ssl_gtest_verbose) LOG(a); \
} while (false)
-void DummyPrSocket::SetPacketFilter(std::shared_ptr<PacketFilter> filter) {
- filter_ = filter;
-}
-
ScopedPRFileDesc DummyPrSocket::CreateFD() {
static PRDescIdentity test_fd_identity =
PR_GetUniqueIdentity("testtransportadapter");
diff --git a/gtests/ssl_gtest/test_io.h b/gtests/ssl_gtest/test_io.h
index 469d90a7c..dbeb6b9d4 100644
--- a/gtests/ssl_gtest/test_io.h
+++ b/gtests/ssl_gtest/test_io.h
@@ -74,7 +74,9 @@ class DummyPrSocket : public DummyIOLayerMethods {
std::weak_ptr<DummyPrSocket>& peer() { return peer_; }
void SetPeer(const std::shared_ptr<DummyPrSocket>& peer) { peer_ = peer; }
- void SetPacketFilter(std::shared_ptr<PacketFilter> filter);
+ void SetPacketFilter(const std::shared_ptr<PacketFilter>& filter) {
+ filter_ = filter;
+ }
// Drops peer, packet filter and any outstanding packets.
void Reset();
@@ -176,6 +178,6 @@ class Poller {
timers_;
};
-} // end of namespace
+} // namespace nss_test
#endif
diff --git a/gtests/ssl_gtest/tls_agent.h b/gtests/ssl_gtest/tls_agent.h
index 9bde5dfda..941fb649e 100644
--- a/gtests/ssl_gtest/tls_agent.h
+++ b/gtests/ssl_gtest/tls_agent.h
@@ -80,18 +80,10 @@ class TlsAgent : public PollTarget {
adapter_->SetPeer(peer->adapter_);
}
- // Set a filter that can access plaintext (TLS 1.3 only).
- void SetTlsRecordFilter(std::shared_ptr<TlsRecordFilter> filter) {
- filter->SetAgent(this);
+ void SetFilter(std::shared_ptr<PacketFilter> filter) {
adapter_->SetPacketFilter(filter);
- filter->EnableDecryption();
}
-
- void SetPacketFilter(std::shared_ptr<PacketFilter> filter) {
- adapter_->SetPacketFilter(filter);
- }
-
- void DeletePacketFilter() { adapter_->SetPacketFilter(nullptr); }
+ void ClearFilter() { adapter_->SetPacketFilter(nullptr); }
void StartConnect(PRFileDesc* model = nullptr);
void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group,
@@ -463,7 +455,7 @@ class TlsAgentTestBase : public ::testing::Test {
void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state,
int32_t error_code = 0);
- std::unique_ptr<TlsAgent> agent_;
+ std::shared_ptr<TlsAgent> agent_;
TlsAgent::Role role_;
SSLProtocolVariant variant_;
uint16_t version_;
diff --git a/gtests/ssl_gtest/tls_connect.cc b/gtests/ssl_gtest/tls_connect.cc
index b1e90d89d..bc146e042 100644
--- a/gtests/ssl_gtest/tls_connect.cc
+++ b/gtests/ssl_gtest/tls_connect.cc
@@ -770,17 +770,17 @@ TlsConnectGenericResumptionToken::TlsConnectGenericResumptionToken()
void TlsKeyExchangeTest::EnsureKeyShareSetup() {
EnsureTlsSetup();
groups_capture_ =
- std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn);
shares_capture_ =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
- shares_capture2_ =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn, true);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn);
+ shares_capture2_ = std::make_shared<TlsExtensionCapture>(
+ client_, ssl_tls13_key_share_xtn, true);
std::vector<std::shared_ptr<PacketFilter>> captures = {
groups_capture_, shares_capture_, shares_capture2_};
- client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures));
- capture_hrr_ = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeHelloRetryRequest);
- server_->SetPacketFilter(capture_hrr_);
+ client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures));
+ capture_hrr_ = std::make_shared<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeHelloRetryRequest);
+ server_->SetFilter(capture_hrr_);
}
void TlsKeyExchangeTest::ConfigNamedGroups(
diff --git a/gtests/ssl_gtest/tls_connect.h b/gtests/ssl_gtest/tls_connect.h
index 9746f9865..6a35fc78b 100644
--- a/gtests/ssl_gtest/tls_connect.h
+++ b/gtests/ssl_gtest/tls_connect.h
@@ -45,8 +45,8 @@ class TlsConnectTestBase : public ::testing::Test {
TlsConnectTestBase(SSLProtocolVariant variant, uint16_t version);
virtual ~TlsConnectTestBase();
- void SetUp();
- void TearDown();
+ virtual void SetUp();
+ virtual void TearDown();
// Initialize client and server.
void Init();
@@ -320,7 +320,7 @@ class TlsKeyExchangeTest : public TlsConnectGeneric {
std::shared_ptr<TlsExtensionCapture> groups_capture_;
std::shared_ptr<TlsExtensionCapture> shares_capture_;
std::shared_ptr<TlsExtensionCapture> shares_capture2_;
- std::shared_ptr<TlsInspectorRecordHandshakeMessage> capture_hrr_;
+ std::shared_ptr<TlsHandshakeRecorder> capture_hrr_;
void EnsureKeyShareSetup();
void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
diff --git a/gtests/ssl_gtest/tls_filter.cc b/gtests/ssl_gtest/tls_filter.cc
index 89f201295..d34b13bcb 100644
--- a/gtests/ssl_gtest/tls_filter.cc
+++ b/gtests/ssl_gtest/tls_filter.cc
@@ -452,7 +452,7 @@ size_t TlsHandshakeFilter::HandshakeHeader::Write(
return offset;
}
-PacketFilter::Action TlsInspectorRecordHandshakeMessage::FilterHandshake(
+PacketFilter::Action TlsHandshakeRecorder::FilterHandshake(
const HandshakeHeader& header, const DataBuffer& input,
DataBuffer* output) {
// Only do this once.
@@ -763,7 +763,7 @@ PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header,
if (counter_++ == record_) {
DataBuffer buf;
header.Write(&buf, 0, body);
- src_.lock()->SendDirect(buf);
+ agent()->SendDirect(buf);
dest_.lock()->Handshake();
func_();
return DROP;
@@ -772,7 +772,7 @@ PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header,
return KEEP;
}
-PacketFilter::Action TlsInspectorClientHelloVersionChanger::FilterHandshake(
+PacketFilter::Action TlsClientHelloVersionChanger::FilterHandshake(
const HandshakeHeader& header, const DataBuffer& input,
DataBuffer* output) {
EXPECT_EQ(SECSuccess,
@@ -808,7 +808,7 @@ PacketFilter::Action SelectiveRecordDropFilter::FilterRecord(
return pattern;
}
-PacketFilter::Action TlsInspectorClientHelloVersionSetter::FilterHandshake(
+PacketFilter::Action TlsClientHelloVersionSetter::FilterHandshake(
const HandshakeHeader& header, const DataBuffer& input,
DataBuffer* output) {
*output = input;
diff --git a/gtests/ssl_gtest/tls_filter.h b/gtests/ssl_gtest/tls_filter.h
index 1db3b90f6..9485b5eb3 100644
--- a/gtests/ssl_gtest/tls_filter.h
+++ b/gtests/ssl_gtest/tls_filter.h
@@ -74,16 +74,15 @@ struct TlsRecord {
// Abstract filter that operates on entire (D)TLS records.
class TlsRecordFilter : public PacketFilter {
public:
- TlsRecordFilter()
- : agent_(nullptr),
+ TlsRecordFilter(const std::shared_ptr<TlsAgent>& agent)
+ : agent_(agent),
count_(0),
cipher_spec_(),
dropped_record_(false),
in_sequence_number_(0),
out_sequence_number_(0) {}
- void SetAgent(const TlsAgent* agent) { agent_ = agent; }
- const TlsAgent* agent() const { return agent_; }
+ std::shared_ptr<TlsAgent> agent() const { return agent_.lock(); }
// External interface. Overrides PacketFilter.
PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output);
@@ -126,7 +125,7 @@ class TlsRecordFilter : public PacketFilter {
static void CipherSpecChanged(void* arg, PRBool sending,
ssl3CipherSpec* newSpec);
- const TlsAgent* agent_;
+ std::weak_ptr<TlsAgent> agent_;
size_t count_;
std::unique_ptr<TlsCipherSpec> cipher_spec_;
// Whether we dropped a record since the cipher spec changed.
@@ -175,9 +174,13 @@ inline std::ostream& operator<<(std::ostream& stream,
// records and that they don't span records or anything crazy like that.
class TlsHandshakeFilter : public TlsRecordFilter {
public:
- TlsHandshakeFilter() : handshake_types_(), preceding_fragment_() {}
- TlsHandshakeFilter(const std::set<uint8_t>& types)
- : handshake_types_(types), preceding_fragment_() {}
+ TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent), handshake_types_(), preceding_fragment_() {}
+ TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& agent,
+ const std::set<uint8_t>& types)
+ : TlsRecordFilter(agent),
+ handshake_types_(types),
+ preceding_fragment_() {}
// This filter can be set to be selective based on handshake message type. If
// this function isn't used (or the set is empty), then all handshake messages
@@ -229,12 +232,14 @@ class TlsHandshakeFilter : public TlsRecordFilter {
};
// Make a copy of the first instance of a handshake message.
-class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter {
+class TlsHandshakeRecorder : public TlsHandshakeFilter {
public:
- TlsInspectorRecordHandshakeMessage(uint8_t handshake_type)
- : TlsHandshakeFilter({handshake_type}), buffer_() {}
- TlsInspectorRecordHandshakeMessage(const std::set<uint8_t>& handshake_types)
- : TlsHandshakeFilter(handshake_types), buffer_() {}
+ TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& agent,
+ uint8_t handshake_type)
+ : TlsHandshakeFilter(agent, {handshake_type}), buffer_() {}
+ TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& agent,
+ const std::set<uint8_t>& handshake_types)
+ : TlsHandshakeFilter(agent, handshake_types), buffer_() {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -251,9 +256,10 @@ class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter {
// Replace all instances of a handshake message.
class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter {
public:
- TlsInspectorReplaceHandshakeMessage(uint8_t handshake_type,
+ TlsInspectorReplaceHandshakeMessage(const std::shared_ptr<TlsAgent>& agent,
+ uint8_t handshake_type,
const DataBuffer& replacement)
- : TlsHandshakeFilter({handshake_type}), buffer_(replacement) {}
+ : TlsHandshakeFilter(agent, {handshake_type}), buffer_(replacement) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -266,9 +272,11 @@ class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter {
// Make a copy of each record of a given type.
class TlsRecordRecorder : public TlsRecordFilter {
public:
- TlsRecordRecorder(uint8_t ct) : filter_(true), ct_(ct), records_() {}
- TlsRecordRecorder()
- : filter_(false),
+ TlsRecordRecorder(const std::shared_ptr<TlsAgent>& agent, uint8_t ct)
+ : TlsRecordFilter(agent), filter_(true), ct_(ct), records_() {}
+ TlsRecordRecorder(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent),
+ filter_(false),
ct_(content_handshake), // dummy (<optional> is C++14)
records_() {}
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
@@ -289,7 +297,9 @@ class TlsRecordRecorder : public TlsRecordFilter {
// Make a copy of the complete conversation.
class TlsConversationRecorder : public TlsRecordFilter {
public:
- TlsConversationRecorder(DataBuffer& buffer) : buffer_(buffer) {}
+ TlsConversationRecorder(const std::shared_ptr<TlsAgent>& agent,
+ DataBuffer& buffer)
+ : TlsRecordFilter(agent), buffer_(buffer) {}
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
@@ -302,6 +312,8 @@ class TlsConversationRecorder : public TlsRecordFilter {
// Make a copy of the records
class TlsHeaderRecorder : public TlsRecordFilter {
public:
+ TlsHeaderRecorder(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent) {}
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
DataBuffer* output);
@@ -338,13 +350,15 @@ typedef std::function<bool(TlsParser* parser, const TlsVersioned& header)>
class TlsExtensionFilter : public TlsHandshakeFilter {
public:
- TlsExtensionFilter()
- : TlsHandshakeFilter({kTlsHandshakeClientHello, kTlsHandshakeServerHello,
+ TlsExtensionFilter(const std::shared_ptr<TlsAgent>& agent)
+ : TlsHandshakeFilter(agent,
+ {kTlsHandshakeClientHello, kTlsHandshakeServerHello,
kTlsHandshakeHelloRetryRequest,
kTlsHandshakeEncryptedExtensions}) {}
- TlsExtensionFilter(const std::set<uint8_t>& types)
- : TlsHandshakeFilter(types) {}
+ TlsExtensionFilter(const std::shared_ptr<TlsAgent>& agent,
+ const std::set<uint8_t>& types)
+ : TlsHandshakeFilter(agent, types) {}
static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header);
@@ -365,8 +379,13 @@ class TlsExtensionFilter : public TlsHandshakeFilter {
class TlsExtensionCapture : public TlsExtensionFilter {
public:
- TlsExtensionCapture(uint16_t ext, bool last = false)
- : extension_(ext), captured_(false), last_(last), data_() {}
+ TlsExtensionCapture(const std::shared_ptr<TlsAgent>& agent, uint16_t ext,
+ bool last = false)
+ : TlsExtensionFilter(agent),
+ extension_(ext),
+ captured_(false),
+ last_(last),
+ data_() {}
const DataBuffer& extension() const { return data_; }
bool captured() const { return captured_; }
@@ -385,8 +404,9 @@ class TlsExtensionCapture : public TlsExtensionFilter {
class TlsExtensionReplacer : public TlsExtensionFilter {
public:
- TlsExtensionReplacer(uint16_t extension, const DataBuffer& data)
- : extension_(extension), data_(data) {}
+ TlsExtensionReplacer(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t extension, const DataBuffer& data)
+ : TlsExtensionFilter(agent), extension_(extension), data_(data) {}
PacketFilter::Action FilterExtension(uint16_t extension_type,
const DataBuffer& input,
DataBuffer* output) override;
@@ -398,7 +418,9 @@ class TlsExtensionReplacer : public TlsExtensionFilter {
class TlsExtensionDropper : public TlsExtensionFilter {
public:
- TlsExtensionDropper(uint16_t extension) : extension_(extension) {}
+ TlsExtensionDropper(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t extension)
+ : TlsExtensionFilter(agent), extension_(extension) {}
PacketFilter::Action FilterExtension(uint16_t extension_type,
const DataBuffer&, DataBuffer*) override;
@@ -408,8 +430,9 @@ class TlsExtensionDropper : public TlsExtensionFilter {
class TlsExtensionInjector : public TlsHandshakeFilter {
public:
- TlsExtensionInjector(uint16_t ext, const DataBuffer& data)
- : extension_(ext), data_(data) {}
+ TlsExtensionInjector(const std::shared_ptr<TlsAgent>& agent, uint16_t ext,
+ const DataBuffer& data)
+ : TlsHandshakeFilter(agent), extension_(ext), data_(data) {}
protected:
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
@@ -426,16 +449,20 @@ typedef std::function<void(void)> VoidFunction;
class AfterRecordN : public TlsRecordFilter {
public:
- AfterRecordN(std::shared_ptr<TlsAgent>& src, std::shared_ptr<TlsAgent>& dest,
- unsigned int record, VoidFunction func)
- : src_(src), dest_(dest), record_(record), func_(func), counter_(0) {}
+ AfterRecordN(const std::shared_ptr<TlsAgent>& src,
+ const std::shared_ptr<TlsAgent>& dest, unsigned int record,
+ VoidFunction func)
+ : TlsRecordFilter(src),
+ dest_(dest),
+ record_(record),
+ func_(func),
+ counter_(0) {}
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& body,
DataBuffer* out) override;
private:
- std::weak_ptr<TlsAgent> src_;
std::weak_ptr<TlsAgent> dest_;
unsigned int record_;
VoidFunction func_;
@@ -444,10 +471,12 @@ class AfterRecordN : public TlsRecordFilter {
// When we see the ClientKeyExchange from |client|, increment the
// ClientHelloVersion on |server|.
-class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter {
+class TlsClientHelloVersionChanger : public TlsHandshakeFilter {
public:
- TlsInspectorClientHelloVersionChanger(std::shared_ptr<TlsAgent>& server)
- : TlsHandshakeFilter({kTlsHandshakeClientKeyExchange}), server_(server) {}
+ TlsClientHelloVersionChanger(const std::shared_ptr<TlsAgent>& client,
+ const std::shared_ptr<TlsAgent>& server)
+ : TlsHandshakeFilter(client, {kTlsHandshakeClientKeyExchange}),
+ server_(server) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -477,14 +506,16 @@ class SelectiveDropFilter : public PacketFilter {
// datagram, we just drop one.
class SelectiveRecordDropFilter : public TlsRecordFilter {
public:
- SelectiveRecordDropFilter(uint32_t pattern, bool enabled = true)
- : pattern_(pattern), counter_(0) {
+ SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& agent,
+ uint32_t pattern, bool enabled = true)
+ : TlsRecordFilter(agent), pattern_(pattern), counter_(0) {
if (!enabled) {
Disable();
}
}
- SelectiveRecordDropFilter(std::initializer_list<size_t> records)
- : SelectiveRecordDropFilter(ToPattern(records), true) {}
+ SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& agent,
+ std::initializer_list<size_t> records)
+ : SelectiveRecordDropFilter(agent, ToPattern(records), true) {}
void Reset(uint32_t pattern) {
counter_ = 0;
@@ -509,10 +540,12 @@ class SelectiveRecordDropFilter : public TlsRecordFilter {
};
// Set the version number in the ClientHello.
-class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter {
+class TlsClientHelloVersionSetter : public TlsHandshakeFilter {
public:
- TlsInspectorClientHelloVersionSetter(uint16_t version)
- : TlsHandshakeFilter({kTlsHandshakeClientHello}), version_(version) {}
+ TlsClientHelloVersionSetter(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t version)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeClientHello}),
+ version_(version) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -525,7 +558,8 @@ class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter {
// Damages the last byte of a handshake message.
class TlsLastByteDamager : public TlsHandshakeFilter {
public:
- TlsLastByteDamager(uint8_t type) : type_(type) {}
+ TlsLastByteDamager(const std::shared_ptr<TlsAgent>& agent, uint8_t type)
+ : TlsHandshakeFilter(agent), type_(type) {}
PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) override {
@@ -545,8 +579,10 @@ class TlsLastByteDamager : public TlsHandshakeFilter {
class SelectedCipherSuiteReplacer : public TlsHandshakeFilter {
public:
- SelectedCipherSuiteReplacer(uint16_t suite)
- : TlsHandshakeFilter({kTlsHandshakeServerHello}), cipher_suite_(suite) {}
+ SelectedCipherSuiteReplacer(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t suite)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeServerHello}),
+ cipher_suite_(suite) {}
protected:
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,