diff options
author | Martin Thomson <martin.thomson@gmail.com> | 2015-03-17 13:42:19 -0700 |
---|---|---|
committer | Martin Thomson <martin.thomson@gmail.com> | 2015-03-17 13:42:19 -0700 |
commit | 22794c5b6ab3628b2c602913e7b3a6aa118fe31d (patch) | |
tree | bbab30ba56aa68367bec95723f91fe23a92baa17 | |
parent | 307df33d560aa70e12d3e59d45704a1dce6343e3 (diff) | |
download | nss-hg-22794c5b6ab3628b2c602913e7b3a6aa118fe31d.tar.gz |
Bug 753136 - More extensive extension exercising, r=ekrNSS_3_18_1_BETA1
-rw-r--r-- | external_tests/ssl_gtest/manifest.mn | 1 | ||||
-rw-r--r-- | external_tests/ssl_gtest/ssl_extension_unittest.cc | 578 | ||||
-rw-r--r-- | external_tests/ssl_gtest/ssl_loopback_unittest.cc | 43 | ||||
-rw-r--r-- | external_tests/ssl_gtest/test_io.cc | 36 | ||||
-rw-r--r-- | external_tests/ssl_gtest/test_io.h | 6 | ||||
-rw-r--r-- | external_tests/ssl_gtest/tls_agent.cc | 1 | ||||
-rw-r--r-- | external_tests/ssl_gtest/tls_connect.cc | 65 | ||||
-rw-r--r-- | external_tests/ssl_gtest/tls_connect.h | 49 | ||||
-rw-r--r-- | external_tests/ssl_gtest/tls_filter.cc | 2 |
9 files changed, 689 insertions, 92 deletions
diff --git a/external_tests/ssl_gtest/manifest.mn b/external_tests/ssl_gtest/manifest.mn index ee883e9ac..9b8669b3e 100644 --- a/external_tests/ssl_gtest/manifest.mn +++ b/external_tests/ssl_gtest/manifest.mn @@ -8,6 +8,7 @@ MODULE = nss CPPSRCS = \ ssl_loopback_unittest.cc \ + ssl_extension_unittest.cc \ ssl_gtest.cc \ test_io.cc \ tls_agent.cc \ diff --git a/external_tests/ssl_gtest/ssl_extension_unittest.cc b/external_tests/ssl_gtest/ssl_extension_unittest.cc new file mode 100644 index 000000000..a11f90543 --- /dev/null +++ b/external_tests/ssl_gtest/ssl_extension_unittest.cc @@ -0,0 +1,578 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "ssl.h" +#include "sslproto.h" + +#include <memory> + +#include "tls_parser.h" +#include "tls_filter.h" +#include "tls_connect.h" + +namespace nss_test { + +class TlsExtensionFilter : public TlsHandshakeFilter { + protected: + virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type, + const DataBuffer& input, DataBuffer* output) { + if (handshake_type == kTlsHandshakeClientHello) { + TlsParser parser(input); + if (!FindClientHelloExtensions(parser, version)) { + return false; + } + return FilterExtensions(parser, input, output); + } + if (handshake_type == kTlsHandshakeServerHello) { + TlsParser parser(input); + if (!FindServerHelloExtensions(parser, version)) { + return false; + } + return FilterExtensions(parser, input, output); + } + return false; + } + + virtual bool FilterExtension(uint16_t extension_type, + const DataBuffer& input, DataBuffer* output) = 0; + + public: + static bool FindClientHelloExtensions(TlsParser& parser, uint16_t version) { + if (!parser.Skip(2 + 32)) { // version + random + return false; + } + if (!parser.SkipVariable(1)) { // session ID + return false; + } + if (IsDtls(version) && !parser.SkipVariable(1)) { // DTLS cookie + return false; + } + if (!parser.SkipVariable(2)) { // cipher suites + return false; + } + if (!parser.SkipVariable(1)) { // compression methods + return false; + } + return true; + } + + static bool FindServerHelloExtensions(TlsParser& parser, uint16_t version) { + if (!parser.Skip(2 + 32)) { // version + random + return false; + } + if (!parser.SkipVariable(1)) { // session ID + return false; + } + if (!parser.Skip(2)) { // cipher suite + return false; + } + if (NormalizeTlsVersion(version) <= SSL_LIBRARY_VERSION_TLS_1_2) { + if (!parser.Skip(1)) { // compression method + return false; + } + } + return true; + } + + private: + bool FilterExtensions(TlsParser& parser, + const DataBuffer& input, DataBuffer* output) { + size_t length_offset = parser.consumed(); + uint32_t all_extensions; + if (!parser.Read(&all_extensions, 2)) { + return false; // no extensions, odd but OK + } + if (all_extensions != parser.remaining()) { + return false; // malformed + } + + bool changed = false; + + // Write out the start of the message. + output->Allocate(input.len()); + output->Write(0, input.data(), parser.consumed()); + size_t output_offset = parser.consumed(); + + while (parser.remaining()) { + uint32_t extension_type; + if (!parser.Read(&extension_type, 2)) { + return false; // malformed + } + + // Copy extension type. + output->Write(output_offset, extension_type, 2); + + DataBuffer extension; + if (!parser.ReadVariable(&extension, 2)) { + return false; // malformed + } + output_offset = ApplyFilter(static_cast<uint16_t>(extension_type), extension, + output, output_offset + 2, &changed); + } + output->Truncate(output_offset); + + if (changed) { + size_t newlen = output->len() - length_offset - 2; + if (newlen >= 0x10000) { + return false; // bad: size increased too much + } + output->Write(length_offset, newlen, 2); + } + return changed; + } + + size_t ApplyFilter(uint16_t extension_type, const DataBuffer& extension, + DataBuffer* output, size_t offset, bool* changed) { + const DataBuffer* source = &extension; + DataBuffer filtered; + if (FilterExtension(extension_type, extension, &filtered) && + filtered.len() < 0x10000) { + *changed = true; + std::cerr << "extension old: " << extension << std::endl; + std::cerr << "extension new: " << filtered << std::endl; + source = &filtered; + } + + output->Write(offset, source->len(), 2); + output->Write(offset + 2, *source); + return offset + 2 + source->len(); + } +}; + +class TlsExtensionTruncator : public TlsExtensionFilter { + public: + TlsExtensionTruncator(uint16_t extension, size_t length) + : extension_(extension), length_(length) {} + virtual bool FilterExtension(uint16_t extension_type, + const DataBuffer& input, DataBuffer* output) { + if (extension_type != extension_) { + return false; + } + if (input.len() <= length_) { + return false; + } + + output->Assign(input.data(), length_); + return true; + } + private: + uint16_t extension_; + size_t length_; +}; + +class TlsExtensionDamager : public TlsExtensionFilter { + public: + TlsExtensionDamager(uint16_t extension, size_t index) + : extension_(extension), index_(index) {} + virtual bool FilterExtension(uint16_t extension_type, + const DataBuffer& input, DataBuffer* output) { + if (extension_type != extension_) { + return false; + } + + *output = input; + output->data()[index_] += 73; // Increment selected for maximum damage + return true; + } + private: + uint16_t extension_; + size_t index_; +}; + +class TlsExtensionReplacer : public TlsExtensionFilter { + public: + TlsExtensionReplacer(uint16_t extension, const DataBuffer& data) + : extension_(extension), data_(data) {} + virtual bool FilterExtension(uint16_t extension_type, + const DataBuffer& input, DataBuffer* output) { + if (extension_type != extension_) { + return false; + } + + *output = data_; + return true; + } + private: + uint16_t extension_; + DataBuffer data_; +}; + +class TlsExtensionInjector : public TlsHandshakeFilter { + public: + TlsExtensionInjector(uint16_t ext, DataBuffer& data) + : extension_(ext), data_(data) {} + + virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type, + const DataBuffer& input, DataBuffer* output) { + size_t offset; + if (handshake_type == kTlsHandshakeClientHello) { + TlsParser parser(input); + if (!TlsExtensionFilter::FindClientHelloExtensions(parser, version)) { + return false; + } + offset = parser.consumed(); + } else if (handshake_type == kTlsHandshakeServerHello) { + TlsParser parser(input); + if (!TlsExtensionFilter::FindServerHelloExtensions(parser, version)) { + return false; + } + offset = parser.consumed(); + } else { + return false; + } + + *output = input; + + std::cerr << "Pre:" << input << std::endl; + std::cerr << "Lof:" << offset << std::endl; + + // Increase the size of the extensions. + uint16_t* len_addr = reinterpret_cast<uint16_t*>(output->data() + offset); + std::cerr << "L-p:" << ntohs(*len_addr) << std::endl; + *len_addr = htons(ntohs(*len_addr) + data_.len() + 4); + std::cerr << "L-i:" << ntohs(*len_addr) << std::endl; + + + // Insert the extension type and length. + DataBuffer type_length; + type_length.Allocate(4); + type_length.Write(0, extension_, 2); + type_length.Write(2, data_.len(), 2); + output->Splice(type_length, offset + 2); + + // Insert the payload. + output->Splice(data_, offset + 6); + + std::cerr << "Aft:" << *output << std::endl; + return true; + } + + private: + uint16_t extension_; + DataBuffer data_; +}; + +class TlsExtensionTestBase : public TlsConnectTestBase { + protected: + TlsExtensionTestBase(Mode mode, uint16_t version) + : TlsConnectTestBase(mode, version) {} + + void ClientHelloErrorTest(PacketFilter* filter, + uint8_t alert = kTlsAlertDecodeError) { + auto alert_recorder = new TlsAlertRecorder(); + server_->SetPacketFilter(alert_recorder); + if (filter) { + client_->SetPacketFilter(filter); + } + ConnectExpectFail(); + ASSERT_EQ(kTlsAlertFatal, alert_recorder->level()); + ASSERT_EQ(alert, alert_recorder->description()); + } + + void ServerHelloErrorTest(PacketFilter* filter, + uint8_t alert = kTlsAlertDecodeError) { + auto alert_recorder = new TlsAlertRecorder(); + client_->SetPacketFilter(alert_recorder); + if (filter) { + server_->SetPacketFilter(filter); + } + ConnectExpectFail(); + ASSERT_EQ(kTlsAlertFatal, alert_recorder->level()); + ASSERT_EQ(alert, alert_recorder->description()); + } + + static void InitSimpleSni(DataBuffer* extension) { + const char* name = "host.name"; + const size_t namelen = PL_strlen(name); + extension->Allocate(namelen + 5); + extension->Write(0, namelen + 3, 2); + extension->Write(2, static_cast<uint32_t>(0), 1); // 0 == hostname + extension->Write(3, namelen, 2); + extension->Write(5, reinterpret_cast<const uint8_t*>(name), namelen); + } +}; + +class TlsExtensionTestDtls + : public TlsExtensionTestBase, + public ::testing::WithParamInterface<uint16_t> { + public: + TlsExtensionTestDtls() : TlsExtensionTestBase(DGRAM, GetParam()) {} +}; + +class TlsExtensionTest12Plus + : public TlsExtensionTestBase, + public ::testing::WithParamInterface<std::string> { + public: + TlsExtensionTest12Plus() + : TlsExtensionTestBase(TlsConnectTestBase::ToMode(GetParam()), + SSL_LIBRARY_VERSION_TLS_1_2) {} +}; + +class TlsExtensionTestGeneric + : public TlsExtensionTestBase, + public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> { + public: + TlsExtensionTestGeneric() + : TlsExtensionTestBase(TlsConnectTestBase::ToMode((std::get<0>(GetParam()))), + std::get<1>(GetParam())) {} +}; + +TEST_P(TlsExtensionTestGeneric, DamageSniLength) { + ClientHelloErrorTest(new TlsExtensionDamager(ssl_server_name_xtn, 1)); +} + +TEST_P(TlsExtensionTestGeneric, DamageSniHostLength) { + ClientHelloErrorTest(new TlsExtensionDamager(ssl_server_name_xtn, 4)); +} + +TEST_P(TlsExtensionTestGeneric, TruncateSni) { + ClientHelloErrorTest(new TlsExtensionTruncator(ssl_server_name_xtn, 7)); +} + +// A valid extension that appears twice will be reported as unsupported. +TEST_P(TlsExtensionTestGeneric, RepeatSni) { + DataBuffer extension; + InitSimpleSni(&extension); + ClientHelloErrorTest(new TlsExtensionInjector(ssl_server_name_xtn, extension), + kTlsAlertIllegalParameter); +} + +// An SNI entry with zero length is considered invalid (strangely, not if it is +// the last entry, which is probably a bug). +TEST_P(TlsExtensionTestGeneric, BadSni) { + DataBuffer simple; + InitSimpleSni(&simple); + DataBuffer extension; + extension.Allocate(simple.len() + 3); + extension.Write(0, static_cast<uint32_t>(0), 3); + extension.Write(3, simple); + ClientHelloErrorTest(new TlsExtensionReplacer(ssl_server_name_xtn, extension)); +} + +TEST_P(TlsExtensionTestGeneric, EmptyAlpnExtension) { + EnableAlpn(); + DataBuffer extension; + ClientHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension), + kTlsAlertIllegalParameter); +} + +// An empty ALPN isn't considered bad, though it does lead to there being no +// protocol for the server to select. +TEST_P(TlsExtensionTestGeneric, EmptyAlpnList) { + EnableAlpn(); + const uint8_t val[] = { 0x00, 0x00 }; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension), + kTlsAlertNoApplicationProtocol); +} + +TEST_P(TlsExtensionTestGeneric, OneByteAlpn) { + EnableAlpn(); + ClientHelloErrorTest(new TlsExtensionTruncator(ssl_app_layer_protocol_xtn, 1)); +} + +TEST_P(TlsExtensionTestGeneric, AlpnMissingValue) { + EnableAlpn(); + // This will leave the length of the second entry, but no value. + ClientHelloErrorTest(new TlsExtensionTruncator(ssl_app_layer_protocol_xtn, 5)); +} + +TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) { + EnableAlpn(); + const uint8_t val[] = { 0x01, 0x61, 0x00 }; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestGeneric, AlpnMismatch) { + const uint8_t client_alpn[] = { 0x01, 0x61 }; + client_->EnableAlpn(client_alpn, sizeof(client_alpn)); + const uint8_t server_alpn[] = { 0x02, 0x61, 0x62 }; + server_->EnableAlpn(server_alpn, sizeof(server_alpn)); + + ClientHelloErrorTest(nullptr, kTlsAlertNoApplicationProtocol); +} + +TEST_P(TlsExtensionTestGeneric, AlpnReturnedEmptyList) { + EnableAlpn(); + const uint8_t val[] = { 0x00, 0x00 }; + DataBuffer extension(val, sizeof(val)); + ServerHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestGeneric, AlpnReturnedEmptyName) { + EnableAlpn(); + const uint8_t val[] = { 0x00, 0x01, 0x00 }; + DataBuffer extension(val, sizeof(val)); + ServerHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestGeneric, AlpnReturnedListTrailingData) { + EnableAlpn(); + const uint8_t val[] = { 0x00, 0x02, 0x01, 0x61, 0x00 }; + DataBuffer extension(val, sizeof(val)); + ServerHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestGeneric, AlpnReturnedExtraEntry) { + EnableAlpn(); + const uint8_t val[] = { 0x00, 0x04, 0x01, 0x61, 0x01, 0x62 }; + DataBuffer extension(val, sizeof(val)); + ServerHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestGeneric, AlpnReturnedBadListLength) { + EnableAlpn(); + const uint8_t val[] = { 0x00, 0x99, 0x01, 0x61, 0x00 }; + DataBuffer extension(val, sizeof(val)); + ServerHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestGeneric, AlpnReturnedBadNameLength) { + EnableAlpn(); + const uint8_t val[] = { 0x00, 0x02, 0x99, 0x61 }; + DataBuffer extension(val, sizeof(val)); + ServerHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestDtls, SrtpShort) { + EnableSrtp(); + ClientHelloErrorTest(new TlsExtensionTruncator(ssl_use_srtp_xtn, 3)); +} + +TEST_P(TlsExtensionTestDtls, SrtpOdd) { + EnableSrtp(); + const uint8_t val[] = { 0x00, 0x01, 0xff, 0x00 }; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(new TlsExtensionReplacer(ssl_use_srtp_xtn, extension)); +} + +TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsBadLength) { + const uint8_t val[] = { 0x00 }; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(new TlsExtensionReplacer(ssl_signature_algorithms_xtn, + extension)); +} + +TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsTrailingData) { + const uint8_t val[] = { 0x00, 0x02, 0x04, 0x01, 0x00 }; // sha-256, rsa + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(new TlsExtensionReplacer(ssl_signature_algorithms_xtn, + extension)); +} + +TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsEmpty) { + const uint8_t val[] = { 0x00, 0x00 }; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(new TlsExtensionReplacer(ssl_signature_algorithms_xtn, + extension)); +} + +TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsOddLength) { + const uint8_t val[] = { 0x00, 0x01, 0x04 }; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(new TlsExtensionReplacer(ssl_signature_algorithms_xtn, + extension)); +} + +// The extension handling ignores unsupported hashes, so breaking that has no +// effect on success rates. However, ssl3_SendServerKeyExchange catches an +// unsupported signature algorithm. + +// This actually fails with a decryption error (fatal alert 51). That's a bad +// to fail, since any tampering with the handshake will trigger that alert when +// verifying the Finished message. Thus, this test is disabled until this error +// is turned into an alert. +TEST_P(TlsExtensionTest12Plus, DISABLED_SignatureAlgorithmsSigUnsupported) { + const uint8_t val[] = { 0x00, 0x02, 0x04, 0x99 }; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(new TlsExtensionReplacer(ssl_signature_algorithms_xtn, + extension)); +} + +TEST_P(TlsExtensionTestGeneric, SupportedCurvesShort) { + EnableSomeECDHECiphers(); + const uint8_t val[] = { 0x00, 0x01, 0x00 }; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(new TlsExtensionReplacer(ssl_elliptic_curves_xtn, + extension)); +} + +TEST_P(TlsExtensionTestGeneric, SupportedCurvesBadLength) { + EnableSomeECDHECiphers(); + const uint8_t val[] = { 0x09, 0x99, 0x00, 0x00 }; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(new TlsExtensionReplacer(ssl_elliptic_curves_xtn, + extension)); +} + +TEST_P(TlsExtensionTestGeneric, SupportedCurvesTrailingData) { + EnableSomeECDHECiphers(); + const uint8_t val[] = { 0x00, 0x02, 0x00, 0x00, 0x00 }; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(new TlsExtensionReplacer(ssl_elliptic_curves_xtn, + extension)); +} + +TEST_P(TlsExtensionTestGeneric, SupportedPointsEmpty) { + EnableSomeECDHECiphers(); + const uint8_t val[] = { 0x00 }; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(new TlsExtensionReplacer(ssl_ec_point_formats_xtn, + extension)); +} + +TEST_P(TlsExtensionTestGeneric, SupportedPointsBadLength) { + EnableSomeECDHECiphers(); + const uint8_t val[] = { 0x99, 0x00, 0x00 }; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(new TlsExtensionReplacer(ssl_ec_point_formats_xtn, + extension)); +} + +TEST_P(TlsExtensionTestGeneric, SupportedPointsTrailingData) { + EnableSomeECDHECiphers(); + const uint8_t val[] = { 0x01, 0x00, 0x00 }; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(new TlsExtensionReplacer(ssl_ec_point_formats_xtn, + extension)); +} + +TEST_P(TlsExtensionTestGeneric, RenegotiationInfoBadLength) { + const uint8_t val[] = { 0x99 }; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(new TlsExtensionReplacer(ssl_renegotiation_info_xtn, + extension)); +} + +TEST_P(TlsExtensionTestGeneric, RenegotiationInfoMismatch) { + const uint8_t val[] = { 0x01, 0x00 }; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(new TlsExtensionReplacer(ssl_renegotiation_info_xtn, + extension)); +} + +// The extension has to contain a length. +TEST_P(TlsExtensionTestGeneric, RenegotiationInfoExtensionEmpty) { + DataBuffer extension; + ClientHelloErrorTest(new TlsExtensionReplacer(ssl_renegotiation_info_xtn, + extension)); +} + +INSTANTIATE_TEST_CASE_P(ExtensionTls10, TlsExtensionTestGeneric, + ::testing::Combine( + TlsConnectTestBase::kTlsModesStream, + TlsConnectTestBase::kTlsV10)); +INSTANTIATE_TEST_CASE_P(ExtensionVariants, TlsExtensionTestGeneric, + ::testing::Combine( + TlsConnectTestBase::kTlsModesAll, + TlsConnectTestBase::kTlsV11V12)); +INSTANTIATE_TEST_CASE_P(ExtensionTls12Plus, TlsExtensionTest12Plus, + TlsConnectTestBase::kTlsModesAll); +INSTANTIATE_TEST_CASE_P(ExtensionDgram, TlsExtensionTestDtls, + TlsConnectTestBase::kTlsV11V12); + +} // namespace nspr_test diff --git a/external_tests/ssl_gtest/ssl_loopback_unittest.cc b/external_tests/ssl_gtest/ssl_loopback_unittest.cc index a984c2350..b372412f8 100644 --- a/external_tests/ssl_gtest/ssl_loopback_unittest.cc +++ b/external_tests/ssl_gtest/ssl_loopback_unittest.cc @@ -44,13 +44,7 @@ TEST_P(TlsConnectGeneric, SetupOnly) {} TEST_P(TlsConnectGeneric, Connect) { Connect(); - - // Check that we negotiated the expected version. - if (mode_ == STREAM) { - client_->CheckVersion(SSL_LIBRARY_VERSION_TLS_1_0); - } else { - client_->CheckVersion(SSL_LIBRARY_VERSION_TLS_1_1); - } + client_->CheckVersion(SSL_LIBRARY_VERSION_TLS_1_2); } TEST_P(TlsConnectGeneric, ConnectResumed) { @@ -159,19 +153,19 @@ TEST_P(TlsConnectGeneric, ConnectAlpn) { server_->CheckAlpn(SSL_NEXT_PROTO_NEGOTIATED, "a"); } -TEST_F(DtlsConnectTest, ConnectSrtp) { +TEST_P(TlsConnectDatagram, ConnectSrtp) { EnableSrtp(); Connect(); CheckSrtp(); } -TEST_F(TlsConnectTest, ConnectECDHE) { +TEST_P(TlsConnectStream, ConnectECDHE) { EnableSomeECDHECiphers(); Connect(); client_->CheckKEAType(ssl_kea_ecdh); } -TEST_F(TlsConnectTest, ConnectECDHETwiceReuseKey) { +TEST_P(TlsConnectStream, ConnectECDHETwiceReuseKey) { EnableSomeECDHECiphers(); TlsInspectorRecordHandshakeMessage* i1 = new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); @@ -200,7 +194,7 @@ TEST_F(TlsConnectTest, ConnectECDHETwiceReuseKey) { dhe1.public_key_.len())); } -TEST_F(TlsConnectTest, ConnectECDHETwiceNewKey) { +TEST_P(TlsConnectStream, ConnectECDHETwiceNewKey) { EnableSomeECDHECiphers(); SECStatus rv = SSL_OptionSet(server_->ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); @@ -234,24 +228,17 @@ TEST_F(TlsConnectTest, ConnectECDHETwiceNewKey) { dhe1.public_key_.len()))); } -TEST_P(TlsConnectGenericSingleVersion, Connect) { - Connect(); -} - -static const std::string kTls[] = {"TLS"}; -static const std::string kTlsDtls[] = {"TLS", "DTLS"}; -static const uint16_t kTlsV10[] = {SSL_LIBRARY_VERSION_TLS_1_0}; -static const uint16_t kTlsV11V12[] = {SSL_LIBRARY_VERSION_TLS_1_1, - SSL_LIBRARY_VERSION_TLS_1_2}; -INSTANTIATE_TEST_CASE_P(Variants, TlsConnectGeneric, - ::testing::ValuesIn(kTlsDtls)); -INSTANTIATE_TEST_CASE_P(VersionsStream, TlsConnectGenericSingleVersion, +INSTANTIATE_TEST_CASE_P(VariantsStream10, TlsConnectGeneric, ::testing::Combine( - ::testing::ValuesIn(kTls), - ::testing::ValuesIn(kTlsV10))); -INSTANTIATE_TEST_CASE_P(VersionsByVariants, TlsConnectGenericSingleVersion, + TlsConnectTestBase::kTlsModesStream, + TlsConnectTestBase::kTlsV10)); +INSTANTIATE_TEST_CASE_P(VariantsAll, TlsConnectGeneric, ::testing::Combine( - ::testing::ValuesIn(kTlsDtls), - ::testing::ValuesIn(kTlsV11V12))); + TlsConnectTestBase::kTlsModesAll, + TlsConnectTestBase::kTlsV11V12)); +INSTANTIATE_TEST_CASE_P(VersionsDatagram, TlsConnectDatagram, + TlsConnectTestBase::kTlsV11V12); +INSTANTIATE_TEST_CASE_P(VersionsDatagram, TlsConnectStream, + TlsConnectTestBase::kTlsV11V12); } // namespace nspr_test diff --git a/external_tests/ssl_gtest/test_io.cc b/external_tests/ssl_gtest/test_io.cc index 2bfd09178..70b2b1e1b 100644 --- a/external_tests/ssl_gtest/test_io.cc +++ b/external_tests/ssl_gtest/test_io.cc @@ -328,7 +328,7 @@ int32_t DummyPrSocket::Recv(void *buf, int32_t buflen) { } Packet *front = input_.front(); - if (buflen < front->len()) { + if (static_cast<size_t>(buflen) < front->len()) { PR_ASSERT(false); PR_SetError(PR_BUFFER_OVERFLOW_ERROR, 0); return -1; @@ -344,33 +344,23 @@ int32_t DummyPrSocket::Recv(void *buf, int32_t buflen) { } int32_t DummyPrSocket::Write(const void *buf, int32_t length) { - DataBuffer packet(static_cast<const uint8_t*>(buf), - static_cast<size_t>(length)); - if (filter_) { - DataBuffer filtered; - if (filter_->Filter(packet, &filtered)) { - if (WriteDirect(filtered) != filtered.len()) { - PR_SetError(PR_IO_ERROR, 0); - return -1; - } - LOG("Wrote: " << packet); - // libssl can't handle if this reports something other than the length of - // what was passed in (or less, but we're not doing partial writes). - return packet.len(); - } - } - - return WriteDirect(packet); -} - -int32_t DummyPrSocket::WriteDirect(const DataBuffer& packet) { if (!peer_) { PR_SetError(PR_IO_ERROR, 0); return -1; } - peer_->PacketReceived(packet); - return static_cast<int32_t>(packet.len()); // ignore truncation + DataBuffer packet(static_cast<const uint8_t*>(buf), + static_cast<size_t>(length)); + DataBuffer filtered; + if (filter_ && filter_->Filter(packet, &filtered)) { + LOG("Filtered packet: " << filtered); + peer_->PacketReceived(filtered); + } else { + peer_->PacketReceived(packet); + } + // libssl can't handle it if this reports something other than the length + // of what was passed in (or less, but we're not doing partial writes). + return static_cast<int32_t>(packet.len()); } Poller *Poller::instance; diff --git a/external_tests/ssl_gtest/test_io.h b/external_tests/ssl_gtest/test_io.h index d2424c60c..d55d3e4d8 100644 --- a/external_tests/ssl_gtest/test_io.h +++ b/external_tests/ssl_gtest/test_io.h @@ -12,6 +12,7 @@ #include <memory> #include <queue> #include <string> +#include <ostream> #include "prio.h" @@ -36,6 +37,10 @@ class PacketFilter { enum Mode { STREAM, DGRAM }; +inline std::ostream& operator<<(std::ostream& os, Mode m) { + return os << ((m == STREAM) ? "TLS" : "DTLS"); +} + class DummyPrSocket { public: ~DummyPrSocket(); @@ -52,7 +57,6 @@ class DummyPrSocket { int32_t Read(void* data, int32_t len); int32_t Recv(void* buf, int32_t buflen); int32_t Write(const void* buf, int32_t length); - int32_t WriteDirect(const DataBuffer& data); Mode mode() const { return mode_; } bool readable() const { return !input_.empty(); } diff --git a/external_tests/ssl_gtest/tls_agent.cc b/external_tests/ssl_gtest/tls_agent.cc index b7f785d76..76e924574 100644 --- a/external_tests/ssl_gtest/tls_agent.cc +++ b/external_tests/ssl_gtest/tls_agent.cc @@ -163,6 +163,7 @@ void TlsAgent::CheckSrtp() { ASSERT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual); } + void TlsAgent::Handshake() { SECStatus rv = SSL_ForceHandshake(ssl_fd_); if (rv == SECSuccess) { diff --git a/external_tests/ssl_gtest/tls_connect.cc b/external_tests/ssl_gtest/tls_connect.cc index 3cec0cf6f..6101b271a 100644 --- a/external_tests/ssl_gtest/tls_connect.cc +++ b/external_tests/ssl_gtest/tls_connect.cc @@ -8,18 +8,56 @@ #include <iostream> +#include "sslproto.h" #include "gtest_utils.h" extern std::string g_working_dir_path; namespace nss_test { -TlsConnectTestBase::TlsConnectTestBase(Mode mode) +static const std::string kTlsModesStreamArr[] = {"TLS"}; +::testing::internal::ParamGenerator<std::string> + TlsConnectTestBase::kTlsModesStream = ::testing::ValuesIn(kTlsModesStreamArr); +static const std::string kTlsModesAllArr[] = {"TLS", "DTLS"}; +::testing::internal::ParamGenerator<std::string> + TlsConnectTestBase::kTlsModesAll = ::testing::ValuesIn(kTlsModesAllArr); +static const uint16_t kTlsV10Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0}; +::testing::internal::ParamGenerator<uint16_t> + TlsConnectTestBase::kTlsV10 = ::testing::ValuesIn(kTlsV10Arr); +static const uint16_t kTlsV11V12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_2}; +::testing::internal::ParamGenerator<uint16_t> + TlsConnectTestBase::kTlsV11V12 = ::testing::ValuesIn(kTlsV11V12Arr); +// TODO: add TLS 1.3 +static const uint16_t kTlsV12PlusArr[] = {SSL_LIBRARY_VERSION_TLS_1_2}; +::testing::internal::ParamGenerator<uint16_t> + TlsConnectTestBase::kTlsV12Plus = ::testing::ValuesIn(kTlsV12PlusArr); + +static std::string VersionString(uint16_t version) { + switch(version) { + case 0: + return "(no version)"; + case SSL_LIBRARY_VERSION_TLS_1_0: + return "1.0"; + case SSL_LIBRARY_VERSION_TLS_1_1: + return "1.1"; + case SSL_LIBRARY_VERSION_TLS_1_2: + return "1.2"; + default: + std::cerr << "Invalid version: " << version << std::endl; + EXPECT_TRUE(false); + return ""; + } +} + +TlsConnectTestBase::TlsConnectTestBase(Mode mode, uint16_t version) : mode_(mode), client_(new TlsAgent("client", TlsAgent::CLIENT, mode_)), server_(new TlsAgent("server", TlsAgent::SERVER, mode_)), - version_(0), - session_ids_() {} + version_(version), + session_ids_() { + std::cerr << "Version: " << mode_ << " " << VersionString(version_) << std::endl; +} TlsConnectTestBase::~TlsConnectTestBase() { delete client_; @@ -51,6 +89,11 @@ void TlsConnectTestBase::Init() { client_->SetPeer(server_); server_->SetPeer(client_); + + if (version_) { + client_->SetVersionRange(version_, version_); + server_->SetVersionRange(version_, version_); + } } void TlsConnectTestBase::Reset() { @@ -60,11 +103,6 @@ void TlsConnectTestBase::Reset() { client_ = new TlsAgent("client", TlsAgent::CLIENT, mode_); server_ = new TlsAgent("server", TlsAgent::SERVER, mode_); - if (version_) { - client_->SetVersionRange(version_, version_); - server_->SetVersionRange(version_, version_); - } - Init(); } @@ -110,9 +148,9 @@ void TlsConnectTestBase::Connect() { // Check and store session ids. std::vector<uint8_t> sid_c1 = client_->session_id(); - ASSERT_EQ(32, sid_c1.size()); + ASSERT_EQ(32U, sid_c1.size()); std::vector<uint8_t> sid_s1 = server_->session_id(); - ASSERT_EQ(32, sid_s1.size()); + ASSERT_EQ(32U, sid_s1.size()); ASSERT_EQ(sid_c1, sid_s1); session_ids_.push_back(sid_c1); } @@ -151,7 +189,7 @@ void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) { if (resume_ct) { // Check that the last two session ids match. - ASSERT_GE(2, session_ids_.size()); + ASSERT_GE(2U, session_ids_.size()); ASSERT_EQ(session_ids_[session_ids_.size()-1], session_ids_[session_ids_.size()-2]); } @@ -178,8 +216,7 @@ void TlsConnectTestBase::CheckSrtp() { } TlsConnectGeneric::TlsConnectGeneric() - : TlsConnectTestBase((GetParam() == "TLS") ? STREAM : DGRAM) { - std::cerr << "Variant: " << GetParam() << std::endl; -} + : TlsConnectTestBase(TlsConnectTestBase::ToMode(std::get<0>(GetParam())), + std::get<1>(GetParam())) {} } // namespace nss_test diff --git a/external_tests/ssl_gtest/tls_connect.h b/external_tests/ssl_gtest/tls_connect.h index 799a54b03..a981399f6 100644 --- a/external_tests/ssl_gtest/tls_connect.h +++ b/external_tests/ssl_gtest/tls_connect.h @@ -21,7 +21,17 @@ namespace nss_test { // A generic TLS connection test base. class TlsConnectTestBase : public ::testing::Test { public: - TlsConnectTestBase(Mode mode); + static ::testing::internal::ParamGenerator<std::string> kTlsModesStream; + static ::testing::internal::ParamGenerator<std::string> kTlsModesAll; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV10; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV11V12; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV12Plus; + + static inline Mode ToMode(const std::string& str) { + return str == "TLS" ? STREAM : DGRAM; + } + + TlsConnectTestBase(Mode mode, uint16_t version); virtual ~TlsConnectTestBase(); void SetUp(); @@ -58,42 +68,29 @@ class TlsConnectTestBase : public ::testing::Test { }; // A TLS-only test base. -class TlsConnectTest : public TlsConnectTestBase { +class TlsConnectStream : public TlsConnectTestBase, + public ::testing::WithParamInterface<uint16_t> { public: - TlsConnectTest() : TlsConnectTestBase(STREAM) {} + TlsConnectStream() : TlsConnectTestBase(STREAM, GetParam()) {} }; // A DTLS-only test base. -class DtlsConnectTest : public TlsConnectTestBase { +class TlsConnectDatagram : public TlsConnectTestBase, + public ::testing::WithParamInterface<uint16_t> { public: - DtlsConnectTest() : TlsConnectTestBase(DGRAM) {} + TlsConnectDatagram() : TlsConnectTestBase(DGRAM, GetParam()) {} }; -// A generic test class that can be either STREAM or DGRAM. This is configured -// in ssl_loopback_unittest.cc. All uses of this should use TEST_P(). -class TlsConnectGeneric : public TlsConnectTestBase, - public ::testing::WithParamInterface<std::string> { +// A generic test class that can be either STREAM or DGRAM and a single version +// of TLS. This is configured in ssl_loopback_unittest.cc. All uses of this +// should use TEST_P(). +class TlsConnectGeneric + : public TlsConnectTestBase, + public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> { public: TlsConnectGeneric(); }; -// A generic test class that is a single version of TLS. This is configured -// in ssl_loopback_unittest.cc. All uses of this should use TEST_P(). -class TlsConnectGenericSingleVersion : public TlsConnectTestBase, - public ::testing::WithParamInterface< -std::tuple<std::string,uint16_t>> { -public: - TlsConnectGenericSingleVersion() : TlsConnectTestBase( - std::get<0>(GetParam()) == "TLS" ? STREAM : DGRAM) { - uint16_t version = std::get<1>(GetParam()); - - std::cerr << "Version : " << version << std::endl; - client_->SetVersionRange(version, version); - server_->SetVersionRange(version, version); - version_ = version; - } -}; - } // namespace nss_test #endif diff --git a/external_tests/ssl_gtest/tls_filter.cc b/external_tests/ssl_gtest/tls_filter.cc index 3cbe9e5ac..4ed74e4aa 100644 --- a/external_tests/ssl_gtest/tls_filter.cc +++ b/external_tests/ssl_gtest/tls_filter.cc @@ -192,6 +192,8 @@ bool TlsAlertRecorder::FilterRecord(uint8_t content_type, uint16_t version, return false; } + std::cerr << "Alert: " << input << std::endl; + TlsParser parser(input); uint8_t lvl; if (!parser.Read(&lvl)) { |