summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMartin Thomson <martin.thomson@gmail.com>2015-03-17 13:42:19 -0700
committerMartin Thomson <martin.thomson@gmail.com>2015-03-17 13:42:19 -0700
commit22794c5b6ab3628b2c602913e7b3a6aa118fe31d (patch)
treebbab30ba56aa68367bec95723f91fe23a92baa17
parent307df33d560aa70e12d3e59d45704a1dce6343e3 (diff)
downloadnss-hg-NSS_3_18_1_BETA1.tar.gz
Bug 753136 - More extensive extension exercising, r=ekrNSS_3_18_1_BETA1
-rw-r--r--external_tests/ssl_gtest/manifest.mn1
-rw-r--r--external_tests/ssl_gtest/ssl_extension_unittest.cc578
-rw-r--r--external_tests/ssl_gtest/ssl_loopback_unittest.cc43
-rw-r--r--external_tests/ssl_gtest/test_io.cc36
-rw-r--r--external_tests/ssl_gtest/test_io.h6
-rw-r--r--external_tests/ssl_gtest/tls_agent.cc1
-rw-r--r--external_tests/ssl_gtest/tls_connect.cc65
-rw-r--r--external_tests/ssl_gtest/tls_connect.h49
-rw-r--r--external_tests/ssl_gtest/tls_filter.cc2
9 files changed, 689 insertions, 92 deletions
diff --git a/external_tests/ssl_gtest/manifest.mn b/external_tests/ssl_gtest/manifest.mn
index ee883e9ac..9b8669b3e 100644
--- a/external_tests/ssl_gtest/manifest.mn
+++ b/external_tests/ssl_gtest/manifest.mn
@@ -8,6 +8,7 @@ MODULE = nss
CPPSRCS = \
ssl_loopback_unittest.cc \
+ ssl_extension_unittest.cc \
ssl_gtest.cc \
test_io.cc \
tls_agent.cc \
diff --git a/external_tests/ssl_gtest/ssl_extension_unittest.cc b/external_tests/ssl_gtest/ssl_extension_unittest.cc
new file mode 100644
index 000000000..a11f90543
--- /dev/null
+++ b/external_tests/ssl_gtest/ssl_extension_unittest.cc
@@ -0,0 +1,578 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "ssl.h"
+#include "sslproto.h"
+
+#include <memory>
+
+#include "tls_parser.h"
+#include "tls_filter.h"
+#include "tls_connect.h"
+
+namespace nss_test {
+
+class TlsExtensionFilter : public TlsHandshakeFilter {
+ protected:
+ virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type,
+ const DataBuffer& input, DataBuffer* output) {
+ if (handshake_type == kTlsHandshakeClientHello) {
+ TlsParser parser(input);
+ if (!FindClientHelloExtensions(parser, version)) {
+ return false;
+ }
+ return FilterExtensions(parser, input, output);
+ }
+ if (handshake_type == kTlsHandshakeServerHello) {
+ TlsParser parser(input);
+ if (!FindServerHelloExtensions(parser, version)) {
+ return false;
+ }
+ return FilterExtensions(parser, input, output);
+ }
+ return false;
+ }
+
+ virtual bool FilterExtension(uint16_t extension_type,
+ const DataBuffer& input, DataBuffer* output) = 0;
+
+ public:
+ static bool FindClientHelloExtensions(TlsParser& parser, uint16_t version) {
+ if (!parser.Skip(2 + 32)) { // version + random
+ return false;
+ }
+ if (!parser.SkipVariable(1)) { // session ID
+ return false;
+ }
+ if (IsDtls(version) && !parser.SkipVariable(1)) { // DTLS cookie
+ return false;
+ }
+ if (!parser.SkipVariable(2)) { // cipher suites
+ return false;
+ }
+ if (!parser.SkipVariable(1)) { // compression methods
+ return false;
+ }
+ return true;
+ }
+
+ static bool FindServerHelloExtensions(TlsParser& parser, uint16_t version) {
+ if (!parser.Skip(2 + 32)) { // version + random
+ return false;
+ }
+ if (!parser.SkipVariable(1)) { // session ID
+ return false;
+ }
+ if (!parser.Skip(2)) { // cipher suite
+ return false;
+ }
+ if (NormalizeTlsVersion(version) <= SSL_LIBRARY_VERSION_TLS_1_2) {
+ if (!parser.Skip(1)) { // compression method
+ return false;
+ }
+ }
+ return true;
+ }
+
+ private:
+ bool FilterExtensions(TlsParser& parser,
+ const DataBuffer& input, DataBuffer* output) {
+ size_t length_offset = parser.consumed();
+ uint32_t all_extensions;
+ if (!parser.Read(&all_extensions, 2)) {
+ return false; // no extensions, odd but OK
+ }
+ if (all_extensions != parser.remaining()) {
+ return false; // malformed
+ }
+
+ bool changed = false;
+
+ // Write out the start of the message.
+ output->Allocate(input.len());
+ output->Write(0, input.data(), parser.consumed());
+ size_t output_offset = parser.consumed();
+
+ while (parser.remaining()) {
+ uint32_t extension_type;
+ if (!parser.Read(&extension_type, 2)) {
+ return false; // malformed
+ }
+
+ // Copy extension type.
+ output->Write(output_offset, extension_type, 2);
+
+ DataBuffer extension;
+ if (!parser.ReadVariable(&extension, 2)) {
+ return false; // malformed
+ }
+ output_offset = ApplyFilter(static_cast<uint16_t>(extension_type), extension,
+ output, output_offset + 2, &changed);
+ }
+ output->Truncate(output_offset);
+
+ if (changed) {
+ size_t newlen = output->len() - length_offset - 2;
+ if (newlen >= 0x10000) {
+ return false; // bad: size increased too much
+ }
+ output->Write(length_offset, newlen, 2);
+ }
+ return changed;
+ }
+
+ size_t ApplyFilter(uint16_t extension_type, const DataBuffer& extension,
+ DataBuffer* output, size_t offset, bool* changed) {
+ const DataBuffer* source = &extension;
+ DataBuffer filtered;
+ if (FilterExtension(extension_type, extension, &filtered) &&
+ filtered.len() < 0x10000) {
+ *changed = true;
+ std::cerr << "extension old: " << extension << std::endl;
+ std::cerr << "extension new: " << filtered << std::endl;
+ source = &filtered;
+ }
+
+ output->Write(offset, source->len(), 2);
+ output->Write(offset + 2, *source);
+ return offset + 2 + source->len();
+ }
+};
+
+class TlsExtensionTruncator : public TlsExtensionFilter {
+ public:
+ TlsExtensionTruncator(uint16_t extension, size_t length)
+ : extension_(extension), length_(length) {}
+ virtual bool FilterExtension(uint16_t extension_type,
+ const DataBuffer& input, DataBuffer* output) {
+ if (extension_type != extension_) {
+ return false;
+ }
+ if (input.len() <= length_) {
+ return false;
+ }
+
+ output->Assign(input.data(), length_);
+ return true;
+ }
+ private:
+ uint16_t extension_;
+ size_t length_;
+};
+
+class TlsExtensionDamager : public TlsExtensionFilter {
+ public:
+ TlsExtensionDamager(uint16_t extension, size_t index)
+ : extension_(extension), index_(index) {}
+ virtual bool FilterExtension(uint16_t extension_type,
+ const DataBuffer& input, DataBuffer* output) {
+ if (extension_type != extension_) {
+ return false;
+ }
+
+ *output = input;
+ output->data()[index_] += 73; // Increment selected for maximum damage
+ return true;
+ }
+ private:
+ uint16_t extension_;
+ size_t index_;
+};
+
+class TlsExtensionReplacer : public TlsExtensionFilter {
+ public:
+ TlsExtensionReplacer(uint16_t extension, const DataBuffer& data)
+ : extension_(extension), data_(data) {}
+ virtual bool FilterExtension(uint16_t extension_type,
+ const DataBuffer& input, DataBuffer* output) {
+ if (extension_type != extension_) {
+ return false;
+ }
+
+ *output = data_;
+ return true;
+ }
+ private:
+ uint16_t extension_;
+ DataBuffer data_;
+};
+
+class TlsExtensionInjector : public TlsHandshakeFilter {
+ public:
+ TlsExtensionInjector(uint16_t ext, DataBuffer& data)
+ : extension_(ext), data_(data) {}
+
+ virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type,
+ const DataBuffer& input, DataBuffer* output) {
+ size_t offset;
+ if (handshake_type == kTlsHandshakeClientHello) {
+ TlsParser parser(input);
+ if (!TlsExtensionFilter::FindClientHelloExtensions(parser, version)) {
+ return false;
+ }
+ offset = parser.consumed();
+ } else if (handshake_type == kTlsHandshakeServerHello) {
+ TlsParser parser(input);
+ if (!TlsExtensionFilter::FindServerHelloExtensions(parser, version)) {
+ return false;
+ }
+ offset = parser.consumed();
+ } else {
+ return false;
+ }
+
+ *output = input;
+
+ std::cerr << "Pre:" << input << std::endl;
+ std::cerr << "Lof:" << offset << std::endl;
+
+ // Increase the size of the extensions.
+ uint16_t* len_addr = reinterpret_cast<uint16_t*>(output->data() + offset);
+ std::cerr << "L-p:" << ntohs(*len_addr) << std::endl;
+ *len_addr = htons(ntohs(*len_addr) + data_.len() + 4);
+ std::cerr << "L-i:" << ntohs(*len_addr) << std::endl;
+
+
+ // Insert the extension type and length.
+ DataBuffer type_length;
+ type_length.Allocate(4);
+ type_length.Write(0, extension_, 2);
+ type_length.Write(2, data_.len(), 2);
+ output->Splice(type_length, offset + 2);
+
+ // Insert the payload.
+ output->Splice(data_, offset + 6);
+
+ std::cerr << "Aft:" << *output << std::endl;
+ return true;
+ }
+
+ private:
+ uint16_t extension_;
+ DataBuffer data_;
+};
+
+class TlsExtensionTestBase : public TlsConnectTestBase {
+ protected:
+ TlsExtensionTestBase(Mode mode, uint16_t version)
+ : TlsConnectTestBase(mode, version) {}
+
+ void ClientHelloErrorTest(PacketFilter* filter,
+ uint8_t alert = kTlsAlertDecodeError) {
+ auto alert_recorder = new TlsAlertRecorder();
+ server_->SetPacketFilter(alert_recorder);
+ if (filter) {
+ client_->SetPacketFilter(filter);
+ }
+ ConnectExpectFail();
+ ASSERT_EQ(kTlsAlertFatal, alert_recorder->level());
+ ASSERT_EQ(alert, alert_recorder->description());
+ }
+
+ void ServerHelloErrorTest(PacketFilter* filter,
+ uint8_t alert = kTlsAlertDecodeError) {
+ auto alert_recorder = new TlsAlertRecorder();
+ client_->SetPacketFilter(alert_recorder);
+ if (filter) {
+ server_->SetPacketFilter(filter);
+ }
+ ConnectExpectFail();
+ ASSERT_EQ(kTlsAlertFatal, alert_recorder->level());
+ ASSERT_EQ(alert, alert_recorder->description());
+ }
+
+ static void InitSimpleSni(DataBuffer* extension) {
+ const char* name = "host.name";
+ const size_t namelen = PL_strlen(name);
+ extension->Allocate(namelen + 5);
+ extension->Write(0, namelen + 3, 2);
+ extension->Write(2, static_cast<uint32_t>(0), 1); // 0 == hostname
+ extension->Write(3, namelen, 2);
+ extension->Write(5, reinterpret_cast<const uint8_t*>(name), namelen);
+ }
+};
+
+class TlsExtensionTestDtls
+ : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<uint16_t> {
+ public:
+ TlsExtensionTestDtls() : TlsExtensionTestBase(DGRAM, GetParam()) {}
+};
+
+class TlsExtensionTest12Plus
+ : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<std::string> {
+ public:
+ TlsExtensionTest12Plus()
+ : TlsExtensionTestBase(TlsConnectTestBase::ToMode(GetParam()),
+ SSL_LIBRARY_VERSION_TLS_1_2) {}
+};
+
+class TlsExtensionTestGeneric
+ : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+ public:
+ TlsExtensionTestGeneric()
+ : TlsExtensionTestBase(TlsConnectTestBase::ToMode((std::get<0>(GetParam()))),
+ std::get<1>(GetParam())) {}
+};
+
+TEST_P(TlsExtensionTestGeneric, DamageSniLength) {
+ ClientHelloErrorTest(new TlsExtensionDamager(ssl_server_name_xtn, 1));
+}
+
+TEST_P(TlsExtensionTestGeneric, DamageSniHostLength) {
+ ClientHelloErrorTest(new TlsExtensionDamager(ssl_server_name_xtn, 4));
+}
+
+TEST_P(TlsExtensionTestGeneric, TruncateSni) {
+ ClientHelloErrorTest(new TlsExtensionTruncator(ssl_server_name_xtn, 7));
+}
+
+// A valid extension that appears twice will be reported as unsupported.
+TEST_P(TlsExtensionTestGeneric, RepeatSni) {
+ DataBuffer extension;
+ InitSimpleSni(&extension);
+ ClientHelloErrorTest(new TlsExtensionInjector(ssl_server_name_xtn, extension),
+ kTlsAlertIllegalParameter);
+}
+
+// An SNI entry with zero length is considered invalid (strangely, not if it is
+// the last entry, which is probably a bug).
+TEST_P(TlsExtensionTestGeneric, BadSni) {
+ DataBuffer simple;
+ InitSimpleSni(&simple);
+ DataBuffer extension;
+ extension.Allocate(simple.len() + 3);
+ extension.Write(0, static_cast<uint32_t>(0), 3);
+ extension.Write(3, simple);
+ ClientHelloErrorTest(new TlsExtensionReplacer(ssl_server_name_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, EmptyAlpnExtension) {
+ EnableAlpn();
+ DataBuffer extension;
+ ClientHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension),
+ kTlsAlertIllegalParameter);
+}
+
+// An empty ALPN isn't considered bad, though it does lead to there being no
+// protocol for the server to select.
+TEST_P(TlsExtensionTestGeneric, EmptyAlpnList) {
+ EnableAlpn();
+ const uint8_t val[] = { 0x00, 0x00 };
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension),
+ kTlsAlertNoApplicationProtocol);
+}
+
+TEST_P(TlsExtensionTestGeneric, OneByteAlpn) {
+ EnableAlpn();
+ ClientHelloErrorTest(new TlsExtensionTruncator(ssl_app_layer_protocol_xtn, 1));
+}
+
+TEST_P(TlsExtensionTestGeneric, AlpnMissingValue) {
+ EnableAlpn();
+ // This will leave the length of the second entry, but no value.
+ ClientHelloErrorTest(new TlsExtensionTruncator(ssl_app_layer_protocol_xtn, 5));
+}
+
+TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) {
+ EnableAlpn();
+ const uint8_t val[] = { 0x01, 0x61, 0x00 };
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, AlpnMismatch) {
+ const uint8_t client_alpn[] = { 0x01, 0x61 };
+ client_->EnableAlpn(client_alpn, sizeof(client_alpn));
+ const uint8_t server_alpn[] = { 0x02, 0x61, 0x62 };
+ server_->EnableAlpn(server_alpn, sizeof(server_alpn));
+
+ ClientHelloErrorTest(nullptr, kTlsAlertNoApplicationProtocol);
+}
+
+TEST_P(TlsExtensionTestGeneric, AlpnReturnedEmptyList) {
+ EnableAlpn();
+ const uint8_t val[] = { 0x00, 0x00 };
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, AlpnReturnedEmptyName) {
+ EnableAlpn();
+ const uint8_t val[] = { 0x00, 0x01, 0x00 };
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, AlpnReturnedListTrailingData) {
+ EnableAlpn();
+ const uint8_t val[] = { 0x00, 0x02, 0x01, 0x61, 0x00 };
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, AlpnReturnedExtraEntry) {
+ EnableAlpn();
+ const uint8_t val[] = { 0x00, 0x04, 0x01, 0x61, 0x01, 0x62 };
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, AlpnReturnedBadListLength) {
+ EnableAlpn();
+ const uint8_t val[] = { 0x00, 0x99, 0x01, 0x61, 0x00 };
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, AlpnReturnedBadNameLength) {
+ EnableAlpn();
+ const uint8_t val[] = { 0x00, 0x02, 0x99, 0x61 };
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestDtls, SrtpShort) {
+ EnableSrtp();
+ ClientHelloErrorTest(new TlsExtensionTruncator(ssl_use_srtp_xtn, 3));
+}
+
+TEST_P(TlsExtensionTestDtls, SrtpOdd) {
+ EnableSrtp();
+ const uint8_t val[] = { 0x00, 0x01, 0xff, 0x00 };
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(new TlsExtensionReplacer(ssl_use_srtp_xtn, extension));
+}
+
+TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsBadLength) {
+ const uint8_t val[] = { 0x00 };
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(new TlsExtensionReplacer(ssl_signature_algorithms_xtn,
+ extension));
+}
+
+TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsTrailingData) {
+ const uint8_t val[] = { 0x00, 0x02, 0x04, 0x01, 0x00 }; // sha-256, rsa
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(new TlsExtensionReplacer(ssl_signature_algorithms_xtn,
+ extension));
+}
+
+TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsEmpty) {
+ const uint8_t val[] = { 0x00, 0x00 };
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(new TlsExtensionReplacer(ssl_signature_algorithms_xtn,
+ extension));
+}
+
+TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsOddLength) {
+ const uint8_t val[] = { 0x00, 0x01, 0x04 };
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(new TlsExtensionReplacer(ssl_signature_algorithms_xtn,
+ extension));
+}
+
+// The extension handling ignores unsupported hashes, so breaking that has no
+// effect on success rates. However, ssl3_SendServerKeyExchange catches an
+// unsupported signature algorithm.
+
+// This actually fails with a decryption error (fatal alert 51). That's a bad
+// to fail, since any tampering with the handshake will trigger that alert when
+// verifying the Finished message. Thus, this test is disabled until this error
+// is turned into an alert.
+TEST_P(TlsExtensionTest12Plus, DISABLED_SignatureAlgorithmsSigUnsupported) {
+ const uint8_t val[] = { 0x00, 0x02, 0x04, 0x99 };
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(new TlsExtensionReplacer(ssl_signature_algorithms_xtn,
+ extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, SupportedCurvesShort) {
+ EnableSomeECDHECiphers();
+ const uint8_t val[] = { 0x00, 0x01, 0x00 };
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(new TlsExtensionReplacer(ssl_elliptic_curves_xtn,
+ extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, SupportedCurvesBadLength) {
+ EnableSomeECDHECiphers();
+ const uint8_t val[] = { 0x09, 0x99, 0x00, 0x00 };
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(new TlsExtensionReplacer(ssl_elliptic_curves_xtn,
+ extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, SupportedCurvesTrailingData) {
+ EnableSomeECDHECiphers();
+ const uint8_t val[] = { 0x00, 0x02, 0x00, 0x00, 0x00 };
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(new TlsExtensionReplacer(ssl_elliptic_curves_xtn,
+ extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, SupportedPointsEmpty) {
+ EnableSomeECDHECiphers();
+ const uint8_t val[] = { 0x00 };
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(new TlsExtensionReplacer(ssl_ec_point_formats_xtn,
+ extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, SupportedPointsBadLength) {
+ EnableSomeECDHECiphers();
+ const uint8_t val[] = { 0x99, 0x00, 0x00 };
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(new TlsExtensionReplacer(ssl_ec_point_formats_xtn,
+ extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, SupportedPointsTrailingData) {
+ EnableSomeECDHECiphers();
+ const uint8_t val[] = { 0x01, 0x00, 0x00 };
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(new TlsExtensionReplacer(ssl_ec_point_formats_xtn,
+ extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, RenegotiationInfoBadLength) {
+ const uint8_t val[] = { 0x99 };
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(new TlsExtensionReplacer(ssl_renegotiation_info_xtn,
+ extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, RenegotiationInfoMismatch) {
+ const uint8_t val[] = { 0x01, 0x00 };
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(new TlsExtensionReplacer(ssl_renegotiation_info_xtn,
+ extension));
+}
+
+// The extension has to contain a length.
+TEST_P(TlsExtensionTestGeneric, RenegotiationInfoExtensionEmpty) {
+ DataBuffer extension;
+ ClientHelloErrorTest(new TlsExtensionReplacer(ssl_renegotiation_info_xtn,
+ extension));
+}
+
+INSTANTIATE_TEST_CASE_P(ExtensionTls10, TlsExtensionTestGeneric,
+ ::testing::Combine(
+ TlsConnectTestBase::kTlsModesStream,
+ TlsConnectTestBase::kTlsV10));
+INSTANTIATE_TEST_CASE_P(ExtensionVariants, TlsExtensionTestGeneric,
+ ::testing::Combine(
+ TlsConnectTestBase::kTlsModesAll,
+ TlsConnectTestBase::kTlsV11V12));
+INSTANTIATE_TEST_CASE_P(ExtensionTls12Plus, TlsExtensionTest12Plus,
+ TlsConnectTestBase::kTlsModesAll);
+INSTANTIATE_TEST_CASE_P(ExtensionDgram, TlsExtensionTestDtls,
+ TlsConnectTestBase::kTlsV11V12);
+
+} // namespace nspr_test
diff --git a/external_tests/ssl_gtest/ssl_loopback_unittest.cc b/external_tests/ssl_gtest/ssl_loopback_unittest.cc
index a984c2350..b372412f8 100644
--- a/external_tests/ssl_gtest/ssl_loopback_unittest.cc
+++ b/external_tests/ssl_gtest/ssl_loopback_unittest.cc
@@ -44,13 +44,7 @@ TEST_P(TlsConnectGeneric, SetupOnly) {}
TEST_P(TlsConnectGeneric, Connect) {
Connect();
-
- // Check that we negotiated the expected version.
- if (mode_ == STREAM) {
- client_->CheckVersion(SSL_LIBRARY_VERSION_TLS_1_0);
- } else {
- client_->CheckVersion(SSL_LIBRARY_VERSION_TLS_1_1);
- }
+ client_->CheckVersion(SSL_LIBRARY_VERSION_TLS_1_2);
}
TEST_P(TlsConnectGeneric, ConnectResumed) {
@@ -159,19 +153,19 @@ TEST_P(TlsConnectGeneric, ConnectAlpn) {
server_->CheckAlpn(SSL_NEXT_PROTO_NEGOTIATED, "a");
}
-TEST_F(DtlsConnectTest, ConnectSrtp) {
+TEST_P(TlsConnectDatagram, ConnectSrtp) {
EnableSrtp();
Connect();
CheckSrtp();
}
-TEST_F(TlsConnectTest, ConnectECDHE) {
+TEST_P(TlsConnectStream, ConnectECDHE) {
EnableSomeECDHECiphers();
Connect();
client_->CheckKEAType(ssl_kea_ecdh);
}
-TEST_F(TlsConnectTest, ConnectECDHETwiceReuseKey) {
+TEST_P(TlsConnectStream, ConnectECDHETwiceReuseKey) {
EnableSomeECDHECiphers();
TlsInspectorRecordHandshakeMessage* i1 =
new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
@@ -200,7 +194,7 @@ TEST_F(TlsConnectTest, ConnectECDHETwiceReuseKey) {
dhe1.public_key_.len()));
}
-TEST_F(TlsConnectTest, ConnectECDHETwiceNewKey) {
+TEST_P(TlsConnectStream, ConnectECDHETwiceNewKey) {
EnableSomeECDHECiphers();
SECStatus rv =
SSL_OptionSet(server_->ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
@@ -234,24 +228,17 @@ TEST_F(TlsConnectTest, ConnectECDHETwiceNewKey) {
dhe1.public_key_.len())));
}
-TEST_P(TlsConnectGenericSingleVersion, Connect) {
- Connect();
-}
-
-static const std::string kTls[] = {"TLS"};
-static const std::string kTlsDtls[] = {"TLS", "DTLS"};
-static const uint16_t kTlsV10[] = {SSL_LIBRARY_VERSION_TLS_1_0};
-static const uint16_t kTlsV11V12[] = {SSL_LIBRARY_VERSION_TLS_1_1,
- SSL_LIBRARY_VERSION_TLS_1_2};
-INSTANTIATE_TEST_CASE_P(Variants, TlsConnectGeneric,
- ::testing::ValuesIn(kTlsDtls));
-INSTANTIATE_TEST_CASE_P(VersionsStream, TlsConnectGenericSingleVersion,
+INSTANTIATE_TEST_CASE_P(VariantsStream10, TlsConnectGeneric,
::testing::Combine(
- ::testing::ValuesIn(kTls),
- ::testing::ValuesIn(kTlsV10)));
-INSTANTIATE_TEST_CASE_P(VersionsByVariants, TlsConnectGenericSingleVersion,
+ TlsConnectTestBase::kTlsModesStream,
+ TlsConnectTestBase::kTlsV10));
+INSTANTIATE_TEST_CASE_P(VariantsAll, TlsConnectGeneric,
::testing::Combine(
- ::testing::ValuesIn(kTlsDtls),
- ::testing::ValuesIn(kTlsV11V12)));
+ TlsConnectTestBase::kTlsModesAll,
+ TlsConnectTestBase::kTlsV11V12));
+INSTANTIATE_TEST_CASE_P(VersionsDatagram, TlsConnectDatagram,
+ TlsConnectTestBase::kTlsV11V12);
+INSTANTIATE_TEST_CASE_P(VersionsDatagram, TlsConnectStream,
+ TlsConnectTestBase::kTlsV11V12);
} // namespace nspr_test
diff --git a/external_tests/ssl_gtest/test_io.cc b/external_tests/ssl_gtest/test_io.cc
index 2bfd09178..70b2b1e1b 100644
--- a/external_tests/ssl_gtest/test_io.cc
+++ b/external_tests/ssl_gtest/test_io.cc
@@ -328,7 +328,7 @@ int32_t DummyPrSocket::Recv(void *buf, int32_t buflen) {
}
Packet *front = input_.front();
- if (buflen < front->len()) {
+ if (static_cast<size_t>(buflen) < front->len()) {
PR_ASSERT(false);
PR_SetError(PR_BUFFER_OVERFLOW_ERROR, 0);
return -1;
@@ -344,33 +344,23 @@ int32_t DummyPrSocket::Recv(void *buf, int32_t buflen) {
}
int32_t DummyPrSocket::Write(const void *buf, int32_t length) {
- DataBuffer packet(static_cast<const uint8_t*>(buf),
- static_cast<size_t>(length));
- if (filter_) {
- DataBuffer filtered;
- if (filter_->Filter(packet, &filtered)) {
- if (WriteDirect(filtered) != filtered.len()) {
- PR_SetError(PR_IO_ERROR, 0);
- return -1;
- }
- LOG("Wrote: " << packet);
- // libssl can't handle if this reports something other than the length of
- // what was passed in (or less, but we're not doing partial writes).
- return packet.len();
- }
- }
-
- return WriteDirect(packet);
-}
-
-int32_t DummyPrSocket::WriteDirect(const DataBuffer& packet) {
if (!peer_) {
PR_SetError(PR_IO_ERROR, 0);
return -1;
}
- peer_->PacketReceived(packet);
- return static_cast<int32_t>(packet.len()); // ignore truncation
+ DataBuffer packet(static_cast<const uint8_t*>(buf),
+ static_cast<size_t>(length));
+ DataBuffer filtered;
+ if (filter_ && filter_->Filter(packet, &filtered)) {
+ LOG("Filtered packet: " << filtered);
+ peer_->PacketReceived(filtered);
+ } else {
+ peer_->PacketReceived(packet);
+ }
+ // libssl can't handle it if this reports something other than the length
+ // of what was passed in (or less, but we're not doing partial writes).
+ return static_cast<int32_t>(packet.len());
}
Poller *Poller::instance;
diff --git a/external_tests/ssl_gtest/test_io.h b/external_tests/ssl_gtest/test_io.h
index d2424c60c..d55d3e4d8 100644
--- a/external_tests/ssl_gtest/test_io.h
+++ b/external_tests/ssl_gtest/test_io.h
@@ -12,6 +12,7 @@
#include <memory>
#include <queue>
#include <string>
+#include <ostream>
#include "prio.h"
@@ -36,6 +37,10 @@ class PacketFilter {
enum Mode { STREAM, DGRAM };
+inline std::ostream& operator<<(std::ostream& os, Mode m) {
+ return os << ((m == STREAM) ? "TLS" : "DTLS");
+}
+
class DummyPrSocket {
public:
~DummyPrSocket();
@@ -52,7 +57,6 @@ class DummyPrSocket {
int32_t Read(void* data, int32_t len);
int32_t Recv(void* buf, int32_t buflen);
int32_t Write(const void* buf, int32_t length);
- int32_t WriteDirect(const DataBuffer& data);
Mode mode() const { return mode_; }
bool readable() const { return !input_.empty(); }
diff --git a/external_tests/ssl_gtest/tls_agent.cc b/external_tests/ssl_gtest/tls_agent.cc
index b7f785d76..76e924574 100644
--- a/external_tests/ssl_gtest/tls_agent.cc
+++ b/external_tests/ssl_gtest/tls_agent.cc
@@ -163,6 +163,7 @@ void TlsAgent::CheckSrtp() {
ASSERT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual);
}
+
void TlsAgent::Handshake() {
SECStatus rv = SSL_ForceHandshake(ssl_fd_);
if (rv == SECSuccess) {
diff --git a/external_tests/ssl_gtest/tls_connect.cc b/external_tests/ssl_gtest/tls_connect.cc
index 3cec0cf6f..6101b271a 100644
--- a/external_tests/ssl_gtest/tls_connect.cc
+++ b/external_tests/ssl_gtest/tls_connect.cc
@@ -8,18 +8,56 @@
#include <iostream>
+#include "sslproto.h"
#include "gtest_utils.h"
extern std::string g_working_dir_path;
namespace nss_test {
-TlsConnectTestBase::TlsConnectTestBase(Mode mode)
+static const std::string kTlsModesStreamArr[] = {"TLS"};
+::testing::internal::ParamGenerator<std::string>
+ TlsConnectTestBase::kTlsModesStream = ::testing::ValuesIn(kTlsModesStreamArr);
+static const std::string kTlsModesAllArr[] = {"TLS", "DTLS"};
+::testing::internal::ParamGenerator<std::string>
+ TlsConnectTestBase::kTlsModesAll = ::testing::ValuesIn(kTlsModesAllArr);
+static const uint16_t kTlsV10Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0};
+::testing::internal::ParamGenerator<uint16_t>
+ TlsConnectTestBase::kTlsV10 = ::testing::ValuesIn(kTlsV10Arr);
+static const uint16_t kTlsV11V12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_2};
+::testing::internal::ParamGenerator<uint16_t>
+ TlsConnectTestBase::kTlsV11V12 = ::testing::ValuesIn(kTlsV11V12Arr);
+// TODO: add TLS 1.3
+static const uint16_t kTlsV12PlusArr[] = {SSL_LIBRARY_VERSION_TLS_1_2};
+::testing::internal::ParamGenerator<uint16_t>
+ TlsConnectTestBase::kTlsV12Plus = ::testing::ValuesIn(kTlsV12PlusArr);
+
+static std::string VersionString(uint16_t version) {
+ switch(version) {
+ case 0:
+ return "(no version)";
+ case SSL_LIBRARY_VERSION_TLS_1_0:
+ return "1.0";
+ case SSL_LIBRARY_VERSION_TLS_1_1:
+ return "1.1";
+ case SSL_LIBRARY_VERSION_TLS_1_2:
+ return "1.2";
+ default:
+ std::cerr << "Invalid version: " << version << std::endl;
+ EXPECT_TRUE(false);
+ return "";
+ }
+}
+
+TlsConnectTestBase::TlsConnectTestBase(Mode mode, uint16_t version)
: mode_(mode),
client_(new TlsAgent("client", TlsAgent::CLIENT, mode_)),
server_(new TlsAgent("server", TlsAgent::SERVER, mode_)),
- version_(0),
- session_ids_() {}
+ version_(version),
+ session_ids_() {
+ std::cerr << "Version: " << mode_ << " " << VersionString(version_) << std::endl;
+}
TlsConnectTestBase::~TlsConnectTestBase() {
delete client_;
@@ -51,6 +89,11 @@ void TlsConnectTestBase::Init() {
client_->SetPeer(server_);
server_->SetPeer(client_);
+
+ if (version_) {
+ client_->SetVersionRange(version_, version_);
+ server_->SetVersionRange(version_, version_);
+ }
}
void TlsConnectTestBase::Reset() {
@@ -60,11 +103,6 @@ void TlsConnectTestBase::Reset() {
client_ = new TlsAgent("client", TlsAgent::CLIENT, mode_);
server_ = new TlsAgent("server", TlsAgent::SERVER, mode_);
- if (version_) {
- client_->SetVersionRange(version_, version_);
- server_->SetVersionRange(version_, version_);
- }
-
Init();
}
@@ -110,9 +148,9 @@ void TlsConnectTestBase::Connect() {
// Check and store session ids.
std::vector<uint8_t> sid_c1 = client_->session_id();
- ASSERT_EQ(32, sid_c1.size());
+ ASSERT_EQ(32U, sid_c1.size());
std::vector<uint8_t> sid_s1 = server_->session_id();
- ASSERT_EQ(32, sid_s1.size());
+ ASSERT_EQ(32U, sid_s1.size());
ASSERT_EQ(sid_c1, sid_s1);
session_ids_.push_back(sid_c1);
}
@@ -151,7 +189,7 @@ void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) {
if (resume_ct) {
// Check that the last two session ids match.
- ASSERT_GE(2, session_ids_.size());
+ ASSERT_GE(2U, session_ids_.size());
ASSERT_EQ(session_ids_[session_ids_.size()-1],
session_ids_[session_ids_.size()-2]);
}
@@ -178,8 +216,7 @@ void TlsConnectTestBase::CheckSrtp() {
}
TlsConnectGeneric::TlsConnectGeneric()
- : TlsConnectTestBase((GetParam() == "TLS") ? STREAM : DGRAM) {
- std::cerr << "Variant: " << GetParam() << std::endl;
-}
+ : TlsConnectTestBase(TlsConnectTestBase::ToMode(std::get<0>(GetParam())),
+ std::get<1>(GetParam())) {}
} // namespace nss_test
diff --git a/external_tests/ssl_gtest/tls_connect.h b/external_tests/ssl_gtest/tls_connect.h
index 799a54b03..a981399f6 100644
--- a/external_tests/ssl_gtest/tls_connect.h
+++ b/external_tests/ssl_gtest/tls_connect.h
@@ -21,7 +21,17 @@ namespace nss_test {
// A generic TLS connection test base.
class TlsConnectTestBase : public ::testing::Test {
public:
- TlsConnectTestBase(Mode mode);
+ static ::testing::internal::ParamGenerator<std::string> kTlsModesStream;
+ static ::testing::internal::ParamGenerator<std::string> kTlsModesAll;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV10;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV11V12;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV12Plus;
+
+ static inline Mode ToMode(const std::string& str) {
+ return str == "TLS" ? STREAM : DGRAM;
+ }
+
+ TlsConnectTestBase(Mode mode, uint16_t version);
virtual ~TlsConnectTestBase();
void SetUp();
@@ -58,42 +68,29 @@ class TlsConnectTestBase : public ::testing::Test {
};
// A TLS-only test base.
-class TlsConnectTest : public TlsConnectTestBase {
+class TlsConnectStream : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<uint16_t> {
public:
- TlsConnectTest() : TlsConnectTestBase(STREAM) {}
+ TlsConnectStream() : TlsConnectTestBase(STREAM, GetParam()) {}
};
// A DTLS-only test base.
-class DtlsConnectTest : public TlsConnectTestBase {
+class TlsConnectDatagram : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<uint16_t> {
public:
- DtlsConnectTest() : TlsConnectTestBase(DGRAM) {}
+ TlsConnectDatagram() : TlsConnectTestBase(DGRAM, GetParam()) {}
};
-// A generic test class that can be either STREAM or DGRAM. This is configured
-// in ssl_loopback_unittest.cc. All uses of this should use TEST_P().
-class TlsConnectGeneric : public TlsConnectTestBase,
- public ::testing::WithParamInterface<std::string> {
+// A generic test class that can be either STREAM or DGRAM and a single version
+// of TLS. This is configured in ssl_loopback_unittest.cc. All uses of this
+// should use TEST_P().
+class TlsConnectGeneric
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
public:
TlsConnectGeneric();
};
-// A generic test class that is a single version of TLS. This is configured
-// in ssl_loopback_unittest.cc. All uses of this should use TEST_P().
-class TlsConnectGenericSingleVersion : public TlsConnectTestBase,
- public ::testing::WithParamInterface<
-std::tuple<std::string,uint16_t>> {
-public:
- TlsConnectGenericSingleVersion() : TlsConnectTestBase(
- std::get<0>(GetParam()) == "TLS" ? STREAM : DGRAM) {
- uint16_t version = std::get<1>(GetParam());
-
- std::cerr << "Version : " << version << std::endl;
- client_->SetVersionRange(version, version);
- server_->SetVersionRange(version, version);
- version_ = version;
- }
-};
-
} // namespace nss_test
#endif
diff --git a/external_tests/ssl_gtest/tls_filter.cc b/external_tests/ssl_gtest/tls_filter.cc
index 3cbe9e5ac..4ed74e4aa 100644
--- a/external_tests/ssl_gtest/tls_filter.cc
+++ b/external_tests/ssl_gtest/tls_filter.cc
@@ -192,6 +192,8 @@ bool TlsAlertRecorder::FilterRecord(uint8_t content_type, uint16_t version,
return false;
}
+ std::cerr << "Alert: " << input << std::endl;
+
TlsParser parser(input);
uint8_t lvl;
if (!parser.Read(&lvl)) {