From f95d45c36e7c7131747259956821d844e8952e5d Mon Sep 17 00:00:00 2001 From: Lorry Tar Creator Date: Thu, 8 Jun 2017 10:53:01 +0000 Subject: nss-3.31 --- nss/gtests/ssl_gtest/Makefile | 5 +- nss/gtests/ssl_gtest/databuffer.h | 191 ------- nss/gtests/ssl_gtest/gtest_utils.h | 2 +- nss/gtests/ssl_gtest/libssl_internals.c | 99 +++- nss/gtests/ssl_gtest/libssl_internals.h | 13 + nss/gtests/ssl_gtest/manifest.mn | 14 +- nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc | 199 ++++++++ nss/gtests/ssl_gtest/ssl_agent_unittest.cc | 28 +- nss/gtests/ssl_gtest/ssl_auth_unittest.cc | 195 +++++-- nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc | 83 ++- nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc | 45 +- nss/gtests/ssl_gtest/ssl_damage_unittest.cc | 64 ++- nss/gtests/ssl_gtest/ssl_dhe_unittest.cc | 100 ++-- nss/gtests/ssl_gtest/ssl_drop_unittest.cc | 16 +- nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc | 74 ++- nss/gtests/ssl_gtest/ssl_ems_unittest.cc | 6 +- nss/gtests/ssl_gtest/ssl_exporter_unittest.cc | 34 +- nss/gtests/ssl_gtest/ssl_extension_unittest.cc | 568 ++++++++++++++------- nss/gtests/ssl_gtest/ssl_fragment_unittest.cc | 157 ++++++ nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc | 192 ++++--- nss/gtests/ssl_gtest/ssl_gather_unittest.cc | 143 ++++++ nss/gtests/ssl_gtest/ssl_gtest.cc | 14 +- nss/gtests/ssl_gtest/ssl_gtest.gyp | 31 +- nss/gtests/ssl_gtest/ssl_hrr_unittest.cc | 117 ++++- nss/gtests/ssl_gtest/ssl_loopback_unittest.cc | 163 ++++-- nss/gtests/ssl_gtest/ssl_resumption_unittest.cc | 161 +++++- nss/gtests/ssl_gtest/ssl_skip_unittest.cc | 143 ++++-- nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc | 50 +- .../ssl_gtest/ssl_v2_client_hello_unittest.cc | 58 ++- nss/gtests/ssl_gtest/ssl_version_unittest.cc | 59 ++- nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc | 394 ++++++++++++++ nss/gtests/ssl_gtest/test_io.cc | 386 ++------------ nss/gtests/ssl_gtest/test_io.h | 97 ++-- nss/gtests/ssl_gtest/tls_agent.cc | 371 +++++++++----- nss/gtests/ssl_gtest/tls_agent.h | 135 +++-- nss/gtests/ssl_gtest/tls_connect.cc | 175 ++++--- nss/gtests/ssl_gtest/tls_connect.h | 86 ++-- nss/gtests/ssl_gtest/tls_filter.cc | 326 ++++++++---- nss/gtests/ssl_gtest/tls_filter.h | 231 ++++++--- nss/gtests/ssl_gtest/tls_parser.cc | 73 --- nss/gtests/ssl_gtest/tls_parser.h | 131 ----- nss/gtests/ssl_gtest/tls_protect.cc | 145 ++++++ nss/gtests/ssl_gtest/tls_protect.h | 76 +++ 43 files changed, 3815 insertions(+), 1835 deletions(-) delete mode 100644 nss/gtests/ssl_gtest/databuffer.h create mode 100644 nss/gtests/ssl_gtest/ssl_fragment_unittest.cc create mode 100644 nss/gtests/ssl_gtest/ssl_gather_unittest.cc create mode 100644 nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc delete mode 100644 nss/gtests/ssl_gtest/tls_parser.cc delete mode 100644 nss/gtests/ssl_gtest/tls_parser.h create mode 100644 nss/gtests/ssl_gtest/tls_protect.cc create mode 100644 nss/gtests/ssl_gtest/tls_protect.h (limited to 'nss/gtests/ssl_gtest') diff --git a/nss/gtests/ssl_gtest/Makefile b/nss/gtests/ssl_gtest/Makefile index dfb8df9..a9a9290 100644 --- a/nss/gtests/ssl_gtest/Makefile +++ b/nss/gtests/ssl_gtest/Makefile @@ -33,11 +33,8 @@ ifdef NSS_SSL_ENABLE_ZLIB include $(CORE_DEPTH)/coreconf/zlib.mk endif -ifndef NSS_ENABLE_TLS_1_3 -NSS_DISABLE_TLS_1_3=1 -endif - ifdef NSS_DISABLE_TLS_1_3 +NSS_DISABLE_TLS_1_3=1 # Run parameterized tests only, for which we can easily exclude TLS 1.3 CPPSRCS := $(filter-out $(shell grep -l '^TEST_F' $(CPPSRCS)), $(CPPSRCS)) CFLAGS += -DNSS_DISABLE_TLS_1_3 diff --git a/nss/gtests/ssl_gtest/databuffer.h b/nss/gtests/ssl_gtest/databuffer.h deleted file mode 100644 index e7236d4..0000000 --- a/nss/gtests/ssl_gtest/databuffer.h +++ /dev/null @@ -1,191 +0,0 @@ -/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ -/* vim: set ts=2 et sw=2 tw=80: */ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at http://mozilla.org/MPL/2.0/. */ - -#ifndef databuffer_h__ -#define databuffer_h__ - -#include -#include -#include -#include -#include -#if defined(WIN32) || defined(WIN64) -#include -#else -#include -#endif - -extern bool g_ssl_gtest_verbose; - -namespace nss_test { - -class DataBuffer { - public: - DataBuffer() : data_(nullptr), len_(0) {} - DataBuffer(const uint8_t* data, size_t len) : data_(nullptr), len_(0) { - Assign(data, len); - } - DataBuffer(const DataBuffer& other) : data_(nullptr), len_(0) { - Assign(other); - } - ~DataBuffer() { delete[] data_; } - - DataBuffer& operator=(const DataBuffer& other) { - if (&other != this) { - Assign(other); - } - return *this; - } - - void Allocate(size_t len) { - delete[] data_; - data_ = new uint8_t[len ? len : 1]; // Don't depend on new [0]. - len_ = len; - } - - void Truncate(size_t len) { len_ = std::min(len_, len); } - - void Assign(const DataBuffer& other) { Assign(other.data(), other.len()); } - - void Assign(const uint8_t* data, size_t len) { - if (data) { - Allocate(len); - memcpy(static_cast(data_), static_cast(data), len); - } else { - assert(len == 0); - data_ = nullptr; - len_ = 0; - } - } - - // Write will do a new allocation and expand the size of the buffer if needed. - // Returns the offset of the end of the write. - size_t Write(size_t index, const uint8_t* val, size_t count) { - assert(val); - if (index + count > len_) { - size_t newlen = index + count; - uint8_t* tmp = new uint8_t[newlen]; // Always > 0. - if (data_) { - memcpy(static_cast(tmp), static_cast(data_), len_); - } - if (index > len_) { - memset(static_cast(tmp + len_), 0, index - len_); - } - delete[] data_; - data_ = tmp; - len_ = newlen; - } - if (data_) { - memcpy(static_cast(data_ + index), static_cast(val), - count); - } - return index + count; - } - - size_t Write(size_t index, const DataBuffer& buf) { - return Write(index, buf.data(), buf.len()); - } - - // Write an integer, also performing host-to-network order conversion. - // Returns the offset of the end of the write. - size_t Write(size_t index, uint32_t val, size_t count) { - assert(count <= sizeof(uint32_t)); - uint32_t nvalue = htonl(val); - auto* addr = reinterpret_cast(&nvalue); - return Write(index, addr + sizeof(uint32_t) - count, count); - } - - // This can't use the same trick as Write(), since we might be reading from a - // smaller data source. - bool Read(size_t index, size_t count, uint32_t* val) const { - assert(count < sizeof(uint32_t)); - assert(val); - if ((index > len()) || (count > (len() - index))) { - return false; - } - *val = 0; - for (size_t i = 0; i < count; ++i) { - *val = (*val << 8) | data()[index + i]; - } - return true; - } - - // Starting at |index|, remove |remove| bytes and replace them with the - // contents of |buf|. - void Splice(const DataBuffer& buf, size_t index, size_t remove = 0) { - Splice(buf.data(), buf.len(), index, remove); - } - - void Splice(const uint8_t* ins, size_t ins_len, size_t index, - size_t remove = 0) { - assert(ins); - uint8_t* old_value = data_; - size_t old_len = len_; - - // The amount of stuff remaining from the tail of the old. - size_t tail_len = old_len - std::min(old_len, index + remove); - // The new length: the head of the old, the new, and the tail of the old. - len_ = index + ins_len + tail_len; - data_ = new uint8_t[len_ ? len_ : 1]; - - // The head of the old. - if (old_value) { - Write(0, old_value, std::min(old_len, index)); - } - // Maybe a gap. - if (old_value && index > old_len) { - memset(old_value + index, 0, index - old_len); - } - // The new. - Write(index, ins, ins_len); - // The tail of the old. - if (tail_len > 0) { - Write(index + ins_len, old_value + index + remove, tail_len); - } - - delete[] old_value; - } - - void Append(const DataBuffer& buf) { Splice(buf, len_); } - - const uint8_t* data() const { return data_; } - uint8_t* data() { return data_; } - size_t len() const { return len_; } - bool empty() const { return len_ == 0; } - - private: - uint8_t* data_; - size_t len_; -}; - -static const size_t kMaxBufferPrint = 32; - -inline std::ostream& operator<<(std::ostream& stream, const DataBuffer& buf) { - stream << "[" << buf.len() << "] "; - for (size_t i = 0; i < buf.len(); ++i) { - if (!g_ssl_gtest_verbose && i >= kMaxBufferPrint) { - stream << "..."; - break; - } - stream << std::hex << std::setfill('0') << std::setw(2) - << static_cast(buf.data()[i]); - } - stream << std::dec; - return stream; -} - -inline bool operator==(const DataBuffer& a, const DataBuffer& b) { - return (a.empty() && b.empty()) || - (a.len() == b.len() && 0 == memcmp(a.data(), b.data(), a.len())); -} - -inline bool operator!=(const DataBuffer& a, const DataBuffer& b) { - return !(a == b); -} - -} // namespace nss_test - -#endif diff --git a/nss/gtests/ssl_gtest/gtest_utils.h b/nss/gtests/ssl_gtest/gtest_utils.h index 3ecd96c..2344c3c 100644 --- a/nss/gtests/ssl_gtest/gtest_utils.h +++ b/nss/gtests/ssl_gtest/gtest_utils.h @@ -34,7 +34,7 @@ class Timeout : public PollTarget { bool timed_out() const { return !handle_; } private: - Poller::Timer* handle_; + std::shared_ptr handle_; }; } // namespace nss_test diff --git a/nss/gtests/ssl_gtest/libssl_internals.c b/nss/gtests/ssl_gtest/libssl_internals.c index 5136ee8..32ffcb6 100644 --- a/nss/gtests/ssl_gtest/libssl_internals.c +++ b/nss/gtests/ssl_gtest/libssl_internals.c @@ -10,8 +10,6 @@ #include "nss.h" #include "pk11pub.h" #include "seccomon.h" -#include "ssl.h" -#include "sslimpl.h" SECStatus SSLInt_IncrementClientHandshakeVersion(PRFileDesc *fd) { sslSocket *ss = ssl_FindSocket(fd); @@ -35,15 +33,8 @@ SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd, return SECFailure; } - SECStatus rv = ssl3_InitState(ss); - if (rv != SECSuccess) { - return rv; - } - - rv = ssl3_RestartHandshakeHashes(ss); - if (rv != SECSuccess) { - return rv; - } + ssl3_InitState(ss); + ssl3_RestartHandshakeHashes(ss); // Ensure we don't overrun hs.client_random. rnd_len = PR_MIN(SSL3_RANDOM_LENGTH, rnd_len); @@ -64,18 +55,15 @@ PRBool SSLInt_ExtensionNegotiated(PRFileDesc *fd, PRUint16 ext) { return (PRBool)(ss && ssl3_ExtensionNegotiated(ss, ext)); } -void SSLInt_ClearSessionTicketKey() { - ssl3_SessionTicketShutdown(NULL, NULL); - NSS_UnregisterShutdown(ssl3_SessionTicketShutdown, NULL); -} +void SSLInt_ClearSessionTicketKey() { ssl_ResetSessionTicketKeys(); } SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu) { sslSocket *ss = ssl_FindSocket(fd); - if (ss) { - ss->ssl3.mtu = mtu; - return SECSuccess; + if (!ss) { + return SECFailure; } - return SECFailure; + ss->ssl3.mtu = mtu; + return SECSuccess; } PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd) { @@ -199,7 +187,9 @@ SECStatus SSLInt_Set0RttAlpn(PRFileDesc *fd, PRUint8 *data, unsigned int len) { if (ss->xtnData.nextProto.data) { SECITEM_FreeItem(&ss->xtnData.nextProto, PR_FALSE); } - if (!SECITEM_AllocItem(NULL, &ss->xtnData.nextProto, len)) return SECFailure; + if (!SECITEM_AllocItem(NULL, &ss->xtnData.nextProto, len)) { + return SECFailure; + } PORT_Memcpy(ss->xtnData.nextProto.data, data, len); return SECSuccess; @@ -211,7 +201,7 @@ PRBool SSLInt_HasCertWithAuthType(PRFileDesc *fd, SSLAuthType authType) { return PR_FALSE; } - return (PRBool)(!!ssl_FindServerCertByAuthType(ss, authType)); + return (PRBool)(!!ssl_FindServerCert(ss, authType, NULL)); } PRBool SSLInt_SendAlert(PRFileDesc *fd, uint8_t level, uint8_t type) { @@ -256,6 +246,7 @@ SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) { return SECFailure; } if (to >= (1ULL << 48)) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); return SECFailure; } ssl_GetSpecWriteLock(ss); @@ -267,6 +258,7 @@ SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) { * scrub the entire structure on the assumption that the new sequence number * is far enough past the last received sequence number. */ if (to <= spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); return SECFailure; } dtls_RecordSetRecvd(&spec->recvdRecords, to); @@ -284,6 +276,7 @@ SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to) { return SECFailure; } if (to >= (1ULL << 48)) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); return SECFailure; } ssl_GetSpecWriteLock(ss); @@ -314,6 +307,40 @@ SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group) { return groupDef->keaType; } +SECStatus SSLInt_SetCipherSpecChangeFunc(PRFileDesc *fd, + sslCipherSpecChangedFunc func, + void *arg) { + sslSocket *ss; + + ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + ss->ssl3.changedCipherSpecFunc = func; + ss->ssl3.changedCipherSpecArg = arg; + + return SECSuccess; +} + +static ssl3KeyMaterial *GetKeyingMaterial(PRBool isServer, + ssl3CipherSpec *spec) { + return isServer ? &spec->server : &spec->client; +} + +PK11SymKey *SSLInt_CipherSpecToKey(PRBool isServer, ssl3CipherSpec *spec) { + return GetKeyingMaterial(isServer, spec)->write_key; +} + +SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(PRBool isServer, + ssl3CipherSpec *spec) { + return spec->cipher_def->calg; +} + +unsigned char *SSLInt_CipherSpecToIv(PRBool isServer, ssl3CipherSpec *spec) { + return GetKeyingMaterial(isServer, spec)->write_iv; +} + SECStatus SSLInt_EnableShortHeaders(PRFileDesc *fd) { sslSocket *ss; @@ -335,6 +362,36 @@ SECStatus SSLInt_UsingShortHeaders(PRFileDesc *fd, PRBool *result) { } *result = ss->ssl3.hs.shortHeaders; + return SECSuccess; +} + +void SSLInt_SetTicketLifetime(uint32_t lifetime) { + ssl_ticket_lifetime = lifetime; +} + +void SSLInt_SetMaxEarlyDataSize(uint32_t size) { + ssl_max_early_data_size = size; +} + +SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size) { + sslSocket *ss; + + ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + /* This only works when resuming. */ + if (!ss->statelessResume) { + PORT_SetError(SEC_INTERNAL_ONLY); + return SECFailure; + } + + /* Modifying both specs allows this to be used on either peer. */ + ssl_GetSpecWriteLock(ss); + ss->ssl3.crSpec->earlyDataRemaining = size; + ss->ssl3.cwSpec->earlyDataRemaining = size; + ssl_ReleaseSpecWriteLock(ss); return SECSuccess; } diff --git a/nss/gtests/ssl_gtest/libssl_internals.h b/nss/gtests/ssl_gtest/libssl_internals.h index 6ea66db..531c31f 100644 --- a/nss/gtests/ssl_gtest/libssl_internals.h +++ b/nss/gtests/ssl_gtest/libssl_internals.h @@ -11,6 +11,8 @@ #include "prio.h" #include "seccomon.h" +#include "ssl.h" +#include "sslimpl.h" #include "sslt.h" SECStatus SSLInt_IncrementClientHandshakeVersion(PRFileDesc *fd); @@ -37,7 +39,18 @@ SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to); SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to); SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra); SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group); + +SECStatus SSLInt_SetCipherSpecChangeFunc(PRFileDesc *fd, + sslCipherSpecChangedFunc func, + void *arg); +PK11SymKey *SSLInt_CipherSpecToKey(PRBool isServer, ssl3CipherSpec *spec); +SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(PRBool isServer, + ssl3CipherSpec *spec); +unsigned char *SSLInt_CipherSpecToIv(PRBool isServer, ssl3CipherSpec *spec); SECStatus SSLInt_EnableShortHeaders(PRFileDesc *fd); SECStatus SSLInt_UsingShortHeaders(PRFileDesc *fd, PRBool *result); +void SSLInt_SetTicketLifetime(uint32_t lifetime); +void SSLInt_SetMaxEarlyDataSize(uint32_t size); +SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size); #endif // ndef libssl_internals_h_ diff --git a/nss/gtests/ssl_gtest/manifest.mn b/nss/gtests/ssl_gtest/manifest.mn index 391db81..e3775cc 100644 --- a/nss/gtests/ssl_gtest/manifest.mn +++ b/nss/gtests/ssl_gtest/manifest.mn @@ -12,6 +12,9 @@ CSRCS = \ $(NULL) CPPSRCS = \ + $(CORE_DEPTH)/cpputil/dummy_io.cc \ + $(CORE_DEPTH)/cpputil/dummy_io_fwd.cc \ + $(CORE_DEPTH)/cpputil/tls_parser.cc \ ssl_0rtt_unittest.cc \ ssl_agent_unittest.cc \ ssl_auth_unittest.cc \ @@ -24,7 +27,9 @@ CPPSRCS = \ ssl_ems_unittest.cc \ ssl_exporter_unittest.cc \ ssl_extension_unittest.cc \ + ssl_fragment_unittest.cc \ ssl_fuzz_unittest.cc \ + ssl_gather_unittest.cc \ ssl_gtest.cc \ ssl_hrr_unittest.cc \ ssl_loopback_unittest.cc \ @@ -34,21 +39,22 @@ CPPSRCS = \ ssl_staticrsa_unittest.cc \ ssl_v2_client_hello_unittest.cc \ ssl_version_unittest.cc \ + ssl_versionpolicy_unittest.cc \ test_io.cc \ tls_agent.cc \ tls_connect.cc \ tls_hkdf_unittest.cc \ tls_filter.cc \ - tls_parser.cc \ + tls_protect.cc \ $(NULL) INCLUDES += -I$(CORE_DEPTH)/gtests/google_test/gtest/include \ - -I$(CORE_DEPTH)/gtests/common + -I$(CORE_DEPTH)/gtests/common \ + -I$(CORE_DEPTH)/cpputil REQUIRES = nspr nss libdbm gtest PROGRAM = ssl_gtest -EXTRA_LIBS = $(DIST)/lib/$(LIB_PREFIX)gtest.$(LIB_SUFFIX) \ - $(DIST)/lib/$(LIB_PREFIX)softokn.$(LIB_SUFFIX) +EXTRA_LIBS = $(DIST)/lib/$(LIB_PREFIX)gtest.$(LIB_SUFFIX) USE_STATIC_LIBS = 1 diff --git a/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc b/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc index cf5a27f..85b7011 100644 --- a/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc +++ b/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc @@ -155,6 +155,7 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttNoAlpnServer) { client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "a"); EXPECT_EQ(SECSuccess, SSLInt_Set0RttAlpn(client_->ssl_fd(), b, sizeof(b))); client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "b"); + ExpectAlert(client_, kTlsAlertIllegalParameter); return true; }); Handshake(); @@ -174,6 +175,7 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttNoAlpnClient) { PRUint8 b[] = {'b'}; EXPECT_EQ(SECSuccess, SSLInt_Set0RttAlpn(client_->ssl_fd(), b, 1)); client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "b"); + ExpectAlert(client_, kTlsAlertIllegalParameter); return true; }); Handshake(); @@ -200,4 +202,201 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpnChangeBoth) { CheckAlpn("b"); } +// The client should abort the connection when sending a 0-rtt handshake but +// the servers responds with a TLS 1.2 ServerHello. (no app data sent) +TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngrade) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + server_->Set0RttEnabled(true); // set ticket_allow_early_data + Connect(); + + SendReceive(); // Need to read so that we absorb the session tickets. + CheckKeys(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + client_->StartConnect(); + server_->StartConnect(); + + // We will send the early data xtn without sending actual early data. Thus + // a 1.2 server shouldn't fail until the client sends an alert because the + // client sends end_of_early_data only after reading the server's flight. + client_->Set0RttEnabled(true); + + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + if (variant_ == ssl_variant_stream) { + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + } + client_->Handshake(); + server_->Handshake(); + ASSERT_TRUE_WAIT( + (client_->error_code() == SSL_ERROR_DOWNGRADE_WITH_EARLY_DATA), 2000); + + // DTLS will timeout as we bump the epoch when installing the early app data + // cipher suite. Thus the encrypted alert will be ignored. + if (variant_ == ssl_variant_stream) { + // The client sends an encrypted alert message. + ASSERT_TRUE_WAIT( + (server_->error_code() == SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA), + 2000); + } +} + +// The client should abort the connection when sending a 0-rtt handshake but +// the servers responds with a TLS 1.2 ServerHello. (with app data) +TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngradeEarlyData) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + server_->Set0RttEnabled(true); // set ticket_allow_early_data + Connect(); + + SendReceive(); // Need to read so that we absorb the session tickets. + CheckKeys(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + client_->StartConnect(); + server_->StartConnect(); + + // Send the early data xtn in the CH, followed by early app data. The server + // will fail right after sending its flight, when receiving the early data. + client_->Set0RttEnabled(true); + ZeroRttSendReceive(true, false, [this]() { + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + if (variant_ == ssl_variant_stream) { + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + } + return true; + }); + + client_->Handshake(); + server_->Handshake(); + ASSERT_TRUE_WAIT( + (client_->error_code() == SSL_ERROR_DOWNGRADE_WITH_EARLY_DATA), 2000); + + // DTLS will timeout as we bump the epoch when installing the early app data + // cipher suite. Thus the encrypted alert will be ignored. + if (variant_ == ssl_variant_stream) { + // The server sends an alert when receiving the early app data record. + ASSERT_TRUE_WAIT( + (server_->error_code() == SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA), + 2000); + } +} + +static void CheckEarlyDataLimit(const std::shared_ptr& agent, + size_t expected_size) { + SSLPreliminaryChannelInfo preinfo; + SECStatus rv = + SSL_GetPreliminaryChannelInfo(agent->ssl_fd(), &preinfo, sizeof(preinfo)); + EXPECT_EQ(SECSuccess, rv); + EXPECT_EQ(expected_size, static_cast(preinfo.maxEarlyDataSize)); +} + +TEST_P(TlsConnectTls13, SendTooMuchEarlyData) { + const char* big_message = "0123456789abcdef"; + const size_t short_size = strlen(big_message) - 1; + const PRInt32 short_length = static_cast(short_size); + SSLInt_SetMaxEarlyDataSize(static_cast(short_size)); + SetupForZeroRtt(); + + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + + ExpectAlert(client_, kTlsAlertEndOfEarlyData); + client_->Handshake(); + CheckEarlyDataLimit(client_, short_size); + + PRInt32 sent; + // Writing more than the limit will succeed in TLS, but fail in DTLS. + if (variant_ == ssl_variant_stream) { + sent = PR_Write(client_->ssl_fd(), big_message, + static_cast(strlen(big_message))); + } else { + sent = PR_Write(client_->ssl_fd(), big_message, + static_cast(strlen(big_message))); + EXPECT_GE(0, sent); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + // Try an exact-sized write now. + sent = PR_Write(client_->ssl_fd(), big_message, short_length); + } + EXPECT_EQ(short_length, sent); + + // Even a single octet write should now fail. + sent = PR_Write(client_->ssl_fd(), big_message, 1); + EXPECT_GE(0, sent); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + // Process the ClientHello and read 0-RTT. + server_->Handshake(); + CheckEarlyDataLimit(server_, short_size); + + std::vector buf(short_size + 1); + PRInt32 read = PR_Read(server_->ssl_fd(), buf.data(), buf.capacity()); + EXPECT_EQ(short_length, read); + EXPECT_EQ(0, memcmp(big_message, buf.data(), short_size)); + + // Second read fails. + read = PR_Read(server_->ssl_fd(), buf.data(), buf.capacity()); + EXPECT_EQ(SECFailure, read); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); +} + +TEST_P(TlsConnectTls13, ReceiveTooMuchEarlyData) { + const size_t limit = 5; + SSLInt_SetMaxEarlyDataSize(limit); + SetupForZeroRtt(); + + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + + client_->ExpectSendAlert(kTlsAlertEndOfEarlyData); + client_->Handshake(); // Send ClientHello + CheckEarlyDataLimit(client_, limit); + + // Lift the limit on the client. + EXPECT_EQ(SECSuccess, + SSLInt_SetSocketMaxEarlyDataSize(client_->ssl_fd(), 1000)); + + // Send message + const char* message = "0123456789abcdef"; + const PRInt32 message_len = static_cast(strlen(message)); + EXPECT_EQ(message_len, PR_Write(client_->ssl_fd(), message, message_len)); + + if (variant_ == ssl_variant_stream) { + // This error isn't fatal for DTLS. + ExpectAlert(server_, kTlsAlertUnexpectedMessage); + } + server_->Handshake(); // Process ClientHello, send server flight. + server_->Handshake(); // Just to make sure that we don't read ahead. + CheckEarlyDataLimit(server_, limit); + + // Attempt to read early data. + std::vector buf(strlen(message) + 1); + EXPECT_GT(0, PR_Read(server_->ssl_fd(), buf.data(), buf.capacity())); + if (variant_ == ssl_variant_stream) { + server_->CheckErrorCode(SSL_ERROR_TOO_MUCH_EARLY_DATA); + } + + client_->Handshake(); // Process the handshake. + client_->Handshake(); // Process the alert. + if (variant_ == ssl_variant_stream) { + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); + } +} + } // namespace nss_test diff --git a/nss/gtests/ssl_gtest/ssl_agent_unittest.cc b/nss/gtests/ssl_gtest/ssl_agent_unittest.cc index 0e6ddae..5035a33 100644 --- a/nss/gtests/ssl_gtest/ssl_agent_unittest.cc +++ b/nss/gtests/ssl_gtest/ssl_agent_unittest.cc @@ -56,6 +56,7 @@ static const char *k0RttData = "ABCDEF"; TEST_P(TlsAgentTest, EarlyFinished) { DataBuffer buffer; MakeTrivialHandshakeRecord(kTlsHandshakeFinished, 0, &buffer); + ExpectAlert(kTlsAlertUnexpectedMessage); ProcessMessage(buffer, TlsAgent::STATE_ERROR, SSL_ERROR_RX_UNEXPECTED_FINISHED); } @@ -63,15 +64,14 @@ TEST_P(TlsAgentTest, EarlyFinished) { TEST_P(TlsAgentTest, EarlyCertificateVerify) { DataBuffer buffer; MakeTrivialHandshakeRecord(kTlsHandshakeCertificateVerify, 0, &buffer); + ExpectAlert(kTlsAlertUnexpectedMessage); ProcessMessage(buffer, TlsAgent::STATE_ERROR, SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); } -TEST_P(TlsAgentTestClient, CannedHello) { +TEST_P(TlsAgentTestClient13, CannedHello) { DataBuffer buffer; EnsureInit(); - agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3, - SSL_LIBRARY_VERSION_TLS_1_3); DataBuffer server_hello; MakeHandshakeMessage(kTlsHandshakeServerHello, kCannedTls13ServerHello, sizeof(kCannedTls13ServerHello), &server_hello); @@ -80,7 +80,7 @@ TEST_P(TlsAgentTestClient, CannedHello) { ProcessMessage(buffer, TlsAgent::STATE_CONNECTING); } -TEST_P(TlsAgentTestClient, EncryptedExtensionsInClear) { +TEST_P(TlsAgentTestClient13, EncryptedExtensionsInClear) { DataBuffer server_hello; MakeHandshakeMessage(kTlsHandshakeServerHello, kCannedTls13ServerHello, sizeof(kCannedTls13ServerHello), &server_hello); @@ -92,8 +92,7 @@ TEST_P(TlsAgentTestClient, EncryptedExtensionsInClear) { MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3, server_hello.data(), server_hello.len(), &buffer); EnsureInit(); - agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3, - SSL_LIBRARY_VERSION_TLS_1_3); + ExpectAlert(kTlsAlertUnexpectedMessage); ProcessMessage(buffer, TlsAgent::STATE_ERROR, SSL_ERROR_RX_UNEXPECTED_HANDSHAKE); } @@ -118,6 +117,7 @@ TEST_F(TlsAgentStreamTestClient, EncryptedExtensionsInClearTwoPieces) { agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3, SSL_LIBRARY_VERSION_TLS_1_3); ProcessMessage(buffer, TlsAgent::STATE_CONNECTING); + ExpectAlert(kTlsAlertUnexpectedMessage); ProcessMessage(buffer2, TlsAgent::STATE_ERROR, SSL_ERROR_RX_UNEXPECTED_HANDSHAKE); } @@ -148,6 +148,7 @@ TEST_F(TlsAgentDgramTestClient, EncryptedExtensionsInClearTwoPieces) { agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3, SSL_LIBRARY_VERSION_TLS_1_3); ProcessMessage(buffer, TlsAgent::STATE_CONNECTING); + ExpectAlert(kTlsAlertUnexpectedMessage); ProcessMessage(buffer2, TlsAgent::STATE_ERROR, SSL_ERROR_RX_UNEXPECTED_HANDSHAKE); } @@ -158,8 +159,8 @@ TEST_F(TlsAgentStreamTestClient, Set0RttOptionThenWrite) { SSL_LIBRARY_VERSION_TLS_1_3); agent_->StartConnect(); agent_->Set0RttEnabled(true); - auto filter = - new TlsInspectorRecordHandshakeMessage(kTlsHandshakeClientHello); + auto filter = std::make_shared( + kTlsHandshakeClientHello); agent_->SetPacketFilter(filter); PRInt32 rv = PR_Write(agent_->ssl_fd(), k0RttData, strlen(k0RttData)); EXPECT_EQ(-1, rv); @@ -178,6 +179,7 @@ TEST_F(TlsAgentStreamTestClient, Set0RttOptionThenRead) { MakeRecord(kTlsApplicationDataType, SSL_LIBRARY_VERSION_TLS_1_3, reinterpret_cast(k0RttData), strlen(k0RttData), &buffer); + ExpectAlert(kTlsAlertUnexpectedMessage); ProcessMessage(buffer, TlsAgent::STATE_ERROR, SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA); } @@ -198,13 +200,19 @@ TEST_F(TlsAgentStreamTestServer, Set0RttOptionClientHelloThenRead) { MakeRecord(kTlsApplicationDataType, SSL_LIBRARY_VERSION_TLS_1_3, reinterpret_cast(k0RttData), strlen(k0RttData), &buffer); + ExpectAlert(kTlsAlertBadRecordMac); ProcessMessage(buffer, TlsAgent::STATE_ERROR, SSL_ERROR_BAD_MAC_READ); } INSTANTIATE_TEST_CASE_P( AgentTests, TlsAgentTest, ::testing::Combine(TlsAgentTestBase::kTlsRolesAll, - TlsConnectTestBase::kTlsModesStream)); + TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll)); INSTANTIATE_TEST_CASE_P(ClientTests, TlsAgentTestClient, - TlsConnectTestBase::kTlsModesAll); + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsVAll)); +INSTANTIATE_TEST_CASE_P(ClientTests13, TlsAgentTestClient13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV13)); } // namespace nss_test diff --git a/nss/gtests/ssl_gtest/ssl_auth_unittest.cc b/nss/gtests/ssl_gtest/ssl_auth_unittest.cc index e407d55..dbcbc9a 100644 --- a/nss/gtests/ssl_gtest/ssl_auth_unittest.cc +++ b/nss/gtests/ssl_gtest/ssl_auth_unittest.cc @@ -77,9 +77,10 @@ TEST_P(TlsConnectGeneric, ClientAuthBigRsa) { } // Offset is the position in the captured buffer where the signature sits. -static void CheckSigScheme(TlsInspectorRecordHandshakeMessage* capture, - size_t offset, TlsAgent* peer, - uint16_t expected_scheme, size_t expected_size) { +static void CheckSigScheme( + std::shared_ptr& capture, size_t offset, + std::shared_ptr& peer, uint16_t expected_scheme, + size_t expected_size) { EXPECT_LT(offset + 2U, capture->buffer().len()); uint32_t scheme = 0; @@ -95,8 +96,8 @@ static void CheckSigScheme(TlsInspectorRecordHandshakeMessage* capture, // in the default certificate. TEST_P(TlsConnectTls12, ServerAuthCheckSigAlg) { EnsureTlsSetup(); - auto capture_ske = - new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); + auto capture_ske = std::make_shared( + kTlsHandshakeServerKeyExchange); server_->SetPacketFilter(capture_ske); Connect(); CheckKeys(); @@ -114,7 +115,8 @@ TEST_P(TlsConnectTls12, ServerAuthCheckSigAlg) { TEST_P(TlsConnectTls12, ClientAuthCheckSigAlg) { EnsureTlsSetup(); auto capture_cert_verify = - new TlsInspectorRecordHandshakeMessage(kTlsHandshakeCertificateVerify); + std::make_shared( + kTlsHandshakeCertificateVerify); client_->SetPacketFilter(capture_cert_verify); client_->SetupClientAuth(); server_->RequestClientAuth(true); @@ -127,7 +129,8 @@ TEST_P(TlsConnectTls12, ClientAuthCheckSigAlg) { TEST_P(TlsConnectTls12, ClientAuthBigRsaCheckSigAlg) { Reset(TlsAgent::kServerRsa, TlsAgent::kRsa2048); auto capture_cert_verify = - new TlsInspectorRecordHandshakeMessage(kTlsHandshakeCertificateVerify); + std::make_shared( + kTlsHandshakeCertificateVerify); client_->SetPacketFilter(capture_cert_verify); client_->SetupClientAuth(); server_->RequestClientAuth(true); @@ -136,6 +139,76 @@ TEST_P(TlsConnectTls12, ClientAuthBigRsaCheckSigAlg) { CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pss_sha256, 2048); } +class TlsZeroCertificateRequestSigAlgsFilter : public TlsHandshakeFilter { + public: + virtual PacketFilter::Action FilterHandshake( + const TlsHandshakeFilter::HandshakeHeader& header, + const DataBuffer& input, DataBuffer* output) { + if (header.handshake_type() != kTlsHandshakeCertificateRequest) { + return KEEP; + } + + TlsParser parser(input); + std::cerr << "Zeroing CertReq.supported_signature_algorithms" << std::endl; + + DataBuffer cert_types; + if (!parser.ReadVariable(&cert_types, 1)) { + ADD_FAILURE(); + return KEEP; + } + + if (!parser.SkipVariable(2)) { + ADD_FAILURE(); + return KEEP; + } + + DataBuffer cas; + if (!parser.ReadVariable(&cas, 2)) { + ADD_FAILURE(); + return KEEP; + } + + size_t idx = 0; + + // Write certificate types. + idx = output->Write(idx, cert_types.len(), 1); + idx = output->Write(idx, cert_types); + + // Write zero signature algorithms. + idx = output->Write(idx, 0U, 2); + + // Write certificate authorities. + idx = output->Write(idx, cas.len(), 2); + idx = output->Write(idx, cas); + + return CHANGE; + } +}; + +// Check that we fall back to SHA-1 when the server doesn't provide any +// supported_signature_algorithms in the CertificateRequest message. +TEST_P(TlsConnectTls12, ClientAuthNoSigAlgsFallback) { + EnsureTlsSetup(); + auto filter = std::make_shared(); + server_->SetPacketFilter(filter); + auto capture_cert_verify = + std::make_shared( + kTlsHandshakeCertificateVerify); + client_->SetPacketFilter(capture_cert_verify); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + + ConnectExpectAlert(server_, kTlsAlertDecryptError); + + // We're expecting a bad signature here because we tampered with a handshake + // message (CertReq). Previously, without the SHA-1 fallback, we would've + // seen a malformed record alert. + server_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + + CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pkcs1_sha1, 1024); +} + static const SSLSignatureScheme SignatureSchemeEcdsaSha384[] = { ssl_sig_ecdsa_secp384r1_sha384}; static const SSLSignatureScheme SignatureSchemeEcdsaSha256[] = { @@ -198,20 +271,38 @@ TEST_P(TlsConnectGeneric, SignatureAlgorithmServerOnly) { ssl_sig_ecdsa_secp384r1_sha384); } -TEST_P(TlsConnectTls12Plus, SignatureSchemeCurveMismatch) { +// In TLS 1.2, curve and hash aren't bound together. +TEST_P(TlsConnectTls12, SignatureSchemeCurveMismatch) { Reset(TlsAgent::kServerEcdsa256); client_->SetSignatureSchemes(SignatureSchemeEcdsaSha384, PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384)); - ConnectExpectFail(); + Connect(); +} + +// In TLS 1.3, curve and hash are coupled. +TEST_P(TlsConnectTls13, SignatureSchemeCurveMismatch) { + Reset(TlsAgent::kServerEcdsa256); + client_->SetSignatureSchemes(SignatureSchemeEcdsaSha384, + PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384)); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM); client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); } -TEST_P(TlsConnectTls12Plus, SignatureSchemeBadConfig) { +// Configuring a P-256 cert with only SHA-384 signatures is OK in TLS 1.2. +TEST_P(TlsConnectTls12, SignatureSchemeBadConfig) { Reset(TlsAgent::kServerEcdsa256); // P-256 cert can't be used. server_->SetSignatureSchemes(SignatureSchemeEcdsaSha384, PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384)); - ConnectExpectFail(); + Connect(); +} + +// A P-256 certificate in TLS 1.3 needs a SHA-256 signature scheme. +TEST_P(TlsConnectTls13, SignatureSchemeBadConfig) { + Reset(TlsAgent::kServerEcdsa256); // P-256 cert can't be used. + server_->SetSignatureSchemes(SignatureSchemeEcdsaSha384, + PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384)); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM); client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); } @@ -234,7 +325,7 @@ TEST_P(TlsConnectTls12Plus, SignatureAlgorithmNoOverlapEcdsa) { PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384)); server_->SetSignatureSchemes(SignatureSchemeEcdsaSha256, PR_ARRAY_SIZE(SignatureSchemeEcdsaSha256)); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM); } @@ -252,8 +343,8 @@ TEST_P(TlsConnectPre12, SignatureAlgorithmNoOverlapEcdsa) { // The signature_algorithms extension is mandatory in TLS 1.3. TEST_P(TlsConnectTls13, SignatureAlgorithmDrop) { client_->SetPacketFilter( - new TlsExtensionDropper(ssl_signature_algorithms_xtn)); - ConnectExpectFail(); + std::make_shared(ssl_signature_algorithms_xtn)); + ConnectExpectAlert(server_, kTlsAlertMissingExtension); client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT); server_->CheckErrorCode(SSL_ERROR_MISSING_SIGNATURE_ALGORITHMS_EXTENSION); } @@ -262,8 +353,8 @@ TEST_P(TlsConnectTls13, SignatureAlgorithmDrop) { // only fails when the Finished is checked. TEST_P(TlsConnectTls12, SignatureAlgorithmDrop) { client_->SetPacketFilter( - new TlsExtensionDropper(ssl_signature_algorithms_xtn)); - ConnectExpectFail(); + std::make_shared(ssl_signature_algorithms_xtn)); + ConnectExpectAlert(server_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); } @@ -280,7 +371,8 @@ class BeforeFinished : public TlsRecordFilter { enum HandshakeState { BEFORE_CCS, AFTER_CCS, DONE }; public: - BeforeFinished(TlsAgent* client, TlsAgent* server, VoidFunction before_ccs, + BeforeFinished(std::shared_ptr& client, + std::shared_ptr& server, VoidFunction before_ccs, VoidFunction before_finished) : client_(client), server_(server), @@ -289,7 +381,7 @@ class BeforeFinished : public TlsRecordFilter { state_(BEFORE_CCS) {} protected: - virtual PacketFilter::Action FilterRecord(const RecordHeader& header, + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& body, DataBuffer* out) { switch (state_) { @@ -303,8 +395,8 @@ class BeforeFinished : public TlsRecordFilter { // but that means that they both get processed together. DataBuffer ccs; header.Write(&ccs, 0, body); - server_->SendDirect(ccs); - client_->Handshake(); + server_.lock()->SendDirect(ccs); + client_.lock()->Handshake(); state_ = AFTER_CCS; // Request that the original record be dropped by the filter. return DROP; @@ -327,8 +419,8 @@ class BeforeFinished : public TlsRecordFilter { } private: - TlsAgent* client_; - TlsAgent* server_; + std::weak_ptr client_; + std::weak_ptr server_; VoidFunction before_ccs_; VoidFunction before_finished_; HandshakeState state_; @@ -353,7 +445,8 @@ class BeforeFinished13 : public PacketFilter { }; public: - BeforeFinished13(TlsAgent* client, TlsAgent* server, + BeforeFinished13(std::shared_ptr& client, + std::shared_ptr& server, VoidFunction before_finished) : client_(client), server_(server), @@ -367,7 +460,7 @@ class BeforeFinished13 : public PacketFilter { case 1: // Packet 1 is the server's entire first flight. Drop it. EXPECT_EQ(SECSuccess, - SSLInt_SetMTU(server_->ssl_fd(), input.len() - 1)); + SSLInt_SetMTU(server_.lock()->ssl_fd(), input.len() - 1)); return DROP; // Packet 2 is the first part of the server's retransmitted first @@ -377,7 +470,7 @@ class BeforeFinished13 : public PacketFilter { // Packet 3 is the second part of the server's retransmitted first // flight. Before passing that on, make sure that the client processes // packet 2, then call the before_finished_() callback. - client_->Handshake(); + client_.lock()->Handshake(); before_finished_(); break; @@ -388,8 +481,8 @@ class BeforeFinished13 : public PacketFilter { } private: - TlsAgent* client_; - TlsAgent* server_; + std::weak_ptr client_; + std::weak_ptr server_; VoidFunction before_finished_; size_t records_; }; @@ -403,9 +496,11 @@ static SECStatus AuthCompleteBlock(TlsAgent*, PRBool, PRBool) { // processed by the client, SSL_AuthCertificateComplete() is called. TEST_F(TlsConnectDatagram13, AuthCompleteBeforeFinished) { client_->SetAuthCertificateCallback(AuthCompleteBlock); - server_->SetPacketFilter(new BeforeFinished13(client_, server_, [this]() { - EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); - })); + server_->SetPacketFilter( + std::make_shared(client_, server_, [this]() { + EXPECT_EQ(SECSuccess, + SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); + })); Connect(); } @@ -422,9 +517,9 @@ static void TriggerAuthComplete(PollTarget* target, Event event) { TEST_F(TlsConnectDatagram13, AuthCompleteAfterFinished) { client_->SetAuthCertificateCallback( [this](TlsAgent*, PRBool, PRBool) -> SECStatus { - Poller::Timer* timer_handle; + std::shared_ptr timer_handle; // This is really just to unroll the stack. - Poller::Instance()->SetTimer(1U, client_, TriggerAuthComplete, + Poller::Instance()->SetTimer(1U, client_.get(), TriggerAuthComplete, &timer_handle); return SECWouldBlock; }); @@ -433,7 +528,7 @@ TEST_F(TlsConnectDatagram13, AuthCompleteAfterFinished) { TEST_P(TlsConnectGenericPre13, ClientWriteBetweenCCSAndFinishedWithFalseStart) { client_->EnableFalseStart(); - server_->SetPacketFilter(new BeforeFinished( + server_->SetPacketFilter(std::make_shared( client_, server_, [this]() { EXPECT_TRUE(client_->can_falsestart_hook_called()); }, [this]() { @@ -449,7 +544,7 @@ TEST_P(TlsConnectGenericPre13, ClientWriteBetweenCCSAndFinishedWithFalseStart) { TEST_P(TlsConnectGenericPre13, AuthCompleteBeforeFinishedWithFalseStart) { client_->EnableFalseStart(); client_->SetAuthCertificateCallback(AuthCompleteBlock); - server_->SetPacketFilter(new BeforeFinished( + server_->SetPacketFilter(std::make_shared( client_, server_, []() { // Do nothing before CCS @@ -496,7 +591,7 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) { EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); // The client should send nothing from here on. - client_->SetPacketFilter(new EnforceNoActivity()); + client_->SetPacketFilter(std::make_shared()); client_->Handshake(); EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); @@ -507,7 +602,7 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) { EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); // Remove this before closing or the close_notify alert will trigger it. - client_->SetPacketFilter(nullptr); + client_->DeletePacketFilter(); } // TLS 1.3 handles a delayed AuthComplete callback differently since the @@ -523,12 +618,12 @@ TEST_P(TlsConnectTls13, AuthCompleteDelayed) { EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); // The client will send nothing until AuthCertificateComplete is called. - client_->SetPacketFilter(new EnforceNoActivity()); + client_->SetPacketFilter(std::make_shared()); client_->Handshake(); EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); // This should allow the handshake to complete now. - client_->SetPacketFilter(nullptr); + client_->DeletePacketFilter(); EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); client_->Handshake(); // Send Finished server_->Handshake(); // Transition to connected and send NewSessionTicket @@ -621,8 +716,8 @@ TEST_F(TlsAgentStreamTestServer, ConfigureCertRsaPss) { &ServerCertDataRsaPss)); } -// mode, version, certificate, auth type, signature scheme -typedef std::tuple SignatureSchemeProfile; @@ -637,7 +732,7 @@ class TlsSignatureSchemeConfiguration signature_scheme_(std::get<4>(GetParam())) {} protected: - void TestSignatureSchemeConfig(TlsAgent* configPeer) { + void TestSignatureSchemeConfig(std::shared_ptr& configPeer) { EnsureTlsSetup(); configPeer->SetSignatureSchemes(&signature_scheme_, 1); Connect(); @@ -657,8 +752,8 @@ TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigServer) { TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigClient) { Reset(certificate_); - TlsExtensionCapture* capture = - new TlsExtensionCapture(ssl_signature_algorithms_xtn); + auto capture = + std::make_shared(ssl_signature_algorithms_xtn); client_->SetPacketFilter(capture); TestSignatureSchemeConfig(client_); @@ -683,7 +778,7 @@ TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigBoth) { INSTANTIATE_TEST_CASE_P( SignatureSchemeRsa, TlsSignatureSchemeConfiguration, ::testing::Combine( - TlsConnectTestBase::kTlsModesAll, TlsConnectTestBase::kTlsV12Plus, + TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12Plus, ::testing::Values(TlsAgent::kServerRsaSign), ::testing::Values(ssl_auth_rsa_sign), ::testing::Values(ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pkcs1_sha384, @@ -692,42 +787,42 @@ INSTANTIATE_TEST_CASE_P( // PSS with SHA-512 needs a bigger key to work. INSTANTIATE_TEST_CASE_P( SignatureSchemeBigRsa, TlsSignatureSchemeConfiguration, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12Plus, ::testing::Values(TlsAgent::kRsa2048), ::testing::Values(ssl_auth_rsa_sign), ::testing::Values(ssl_sig_rsa_pss_sha512))); INSTANTIATE_TEST_CASE_P( SignatureSchemeRsaSha1, TlsSignatureSchemeConfiguration, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12, ::testing::Values(TlsAgent::kServerRsa), ::testing::Values(ssl_auth_rsa_sign), ::testing::Values(ssl_sig_rsa_pkcs1_sha1))); INSTANTIATE_TEST_CASE_P( SignatureSchemeEcdsaP256, TlsSignatureSchemeConfiguration, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12Plus, ::testing::Values(TlsAgent::kServerEcdsa256), ::testing::Values(ssl_auth_ecdsa), ::testing::Values(ssl_sig_ecdsa_secp256r1_sha256))); INSTANTIATE_TEST_CASE_P( SignatureSchemeEcdsaP384, TlsSignatureSchemeConfiguration, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12Plus, ::testing::Values(TlsAgent::kServerEcdsa384), ::testing::Values(ssl_auth_ecdsa), ::testing::Values(ssl_sig_ecdsa_secp384r1_sha384))); INSTANTIATE_TEST_CASE_P( SignatureSchemeEcdsaP521, TlsSignatureSchemeConfiguration, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12Plus, ::testing::Values(TlsAgent::kServerEcdsa521), ::testing::Values(ssl_auth_ecdsa), ::testing::Values(ssl_sig_ecdsa_secp521r1_sha512))); INSTANTIATE_TEST_CASE_P( SignatureSchemeEcdsaSha1, TlsSignatureSchemeConfiguration, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12, ::testing::Values(TlsAgent::kServerEcdsa256, TlsAgent::kServerEcdsa384), diff --git a/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc b/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc index 876c368..3463782 100644 --- a/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc +++ b/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc @@ -23,9 +23,10 @@ namespace nss_test { // by the relevant callbacks on the client. class SignedCertificateTimestampsExtractor { public: - SignedCertificateTimestampsExtractor(TlsAgent* client) : client_(client) { - client_->SetAuthCertificateCallback( - [&](TlsAgent* agent, bool checksig, bool isServer) -> SECStatus { + SignedCertificateTimestampsExtractor(std::shared_ptr& client) + : client_(client) { + client->SetAuthCertificateCallback( + [this](TlsAgent* agent, bool checksig, bool isServer) -> SECStatus { const SECItem* scts = SSL_PeerSignedCertTimestamps(agent->ssl_fd()); EXPECT_TRUE(scts); if (!scts) { @@ -34,7 +35,7 @@ class SignedCertificateTimestampsExtractor { auth_timestamps_.reset(new DataBuffer(scts->data, scts->len)); return SECSuccess; }); - client_->SetHandshakeCallback([&](TlsAgent* agent) { + client->SetHandshakeCallback([this](TlsAgent* agent) { const SECItem* scts = SSL_PeerSignedCertTimestamps(agent->ssl_fd()); ASSERT_TRUE(scts); handshake_timestamps_.reset(new DataBuffer(scts->data, scts->len)); @@ -48,12 +49,13 @@ class SignedCertificateTimestampsExtractor { EXPECT_TRUE(handshake_timestamps_); EXPECT_EQ(timestamps, *handshake_timestamps_); - const SECItem* current = SSL_PeerSignedCertTimestamps(client_->ssl_fd()); + const SECItem* current = + SSL_PeerSignedCertTimestamps(client_.lock()->ssl_fd()); EXPECT_EQ(timestamps, DataBuffer(current->data, current->len)); } private: - TlsAgent* client_; + std::weak_ptr client_; std::unique_ptr auth_timestamps_; std::unique_ptr handshake_timestamps_; }; @@ -62,10 +64,22 @@ static const uint8_t kSctValue[] = {0x01, 0x23, 0x45, 0x67, 0x89}; static const SECItem kSctItem = {siBuffer, const_cast(kSctValue), sizeof(kSctValue)}; static const DataBuffer kSctBuffer(kSctValue, sizeof(kSctValue)); +static const SSLExtraServerCertData kExtraSctData = {ssl_auth_null, nullptr, + nullptr, &kSctItem}; // Test timestamps extraction during a successful handshake. -TEST_P(TlsConnectGeneric, SignedCertificateTimestampsHandshake) { +TEST_P(TlsConnectGenericPre13, SignedCertificateTimestampsLegacy) { EnsureTlsSetup(); + + // We have to use the legacy API consistently here for configuring certs. + // Also, this doesn't work in TLS 1.3 because this only configures the SCT for + // RSA decrypt and PKCS#1 signing, not PSS. + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey priv; + ASSERT_TRUE(TlsAgent::LoadCertificate(TlsAgent::kServerRsa, &cert, &priv)); + EXPECT_EQ(SECSuccess, SSL_ConfigSecureServerWithCertChain( + server_->ssl_fd(), cert.get(), nullptr, priv.get(), + ssl_kea_rsa)); EXPECT_EQ(SECSuccess, SSL_SetSignedCertTimestamps(server_->ssl_fd(), &kSctItem, ssl_kea_rsa)); EXPECT_EQ(SECSuccess, @@ -78,13 +92,10 @@ TEST_P(TlsConnectGeneric, SignedCertificateTimestampsHandshake) { timestamps_extractor.assertTimestamps(kSctBuffer); } -TEST_P(TlsConnectGeneric, SignedCertificateTimestampsConfig) { - static const SSLExtraServerCertData kExtraData = {ssl_auth_rsa_sign, nullptr, - nullptr, &kSctItem}; - +TEST_P(TlsConnectGeneric, SignedCertificateTimestampsSuccess) { EnsureTlsSetup(); EXPECT_TRUE( - server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kExtraData)); + server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kExtraSctData)); EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE)); @@ -99,8 +110,8 @@ TEST_P(TlsConnectGeneric, SignedCertificateTimestampsConfig) { // when the client / the server / both have not enabled the feature. TEST_P(TlsConnectGeneric, SignedCertificateTimestampsInactiveClient) { EnsureTlsSetup(); - EXPECT_EQ(SECSuccess, SSL_SetSignedCertTimestamps(server_->ssl_fd(), - &kSctItem, ssl_kea_rsa)); + EXPECT_TRUE( + server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kExtraSctData)); SignedCertificateTimestampsExtractor timestamps_extractor(client_); Connect(); @@ -141,8 +152,8 @@ static const SECItem kOcspItems[] = { {siBuffer, const_cast(kOcspValue2), sizeof(kOcspValue2)}}; static const SECItemArray kOcspResponses = {const_cast(kOcspItems), PR_ARRAY_SIZE(kOcspItems)}; -const static SSLExtraServerCertData kOcspExtraData = { - ssl_auth_rsa_sign, nullptr, &kOcspResponses, nullptr}; +const static SSLExtraServerCertData kOcspExtraData = {ssl_auth_null, nullptr, + &kOcspResponses, nullptr}; TEST_P(TlsConnectGeneric, NoOcsp) { EnsureTlsSetup(); @@ -176,10 +187,10 @@ TEST_P(TlsConnectGenericPre13, OcspMangled) { server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kOcspExtraData)); static const uint8_t val[] = {1}; - auto replacer = new TlsExtensionReplacer(ssl_cert_status_xtn, - DataBuffer(val, sizeof(val))); + auto replacer = std::make_shared( + ssl_cert_status_xtn, DataBuffer(val, sizeof(val))); server_->SetPacketFilter(replacer); - ConnectExpectFail(); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); } @@ -188,7 +199,8 @@ TEST_P(TlsConnectGeneric, OcspSuccess) { EnsureTlsSetup(); EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_OCSP_STAPLING, PR_TRUE)); - auto capture_ocsp = new TlsExtensionCapture(ssl_cert_status_xtn); + auto capture_ocsp = + std::make_shared(ssl_cert_status_xtn); server_->SetPacketFilter(capture_ocsp); // The value should be available during the AuthCertificateCallback @@ -211,4 +223,35 @@ TEST_P(TlsConnectGeneric, OcspSuccess) { EXPECT_EQ(0U, capture_ocsp->extension().len()); } +TEST_P(TlsConnectGeneric, OcspHugeSuccess) { + EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_OCSP_STAPLING, PR_TRUE)); + + uint8_t hugeOcspValue[16385]; + memset(hugeOcspValue, 0xa1, sizeof(hugeOcspValue)); + const SECItem hugeOcspItems[] = { + {siBuffer, const_cast(hugeOcspValue), sizeof(hugeOcspValue)}}; + const SECItemArray hugeOcspResponses = {const_cast(hugeOcspItems), + PR_ARRAY_SIZE(hugeOcspItems)}; + const SSLExtraServerCertData hugeOcspExtraData = { + ssl_auth_null, nullptr, &hugeOcspResponses, nullptr}; + + // The value should be available during the AuthCertificateCallback + client_->SetAuthCertificateCallback([&](TlsAgent* agent, bool checksig, + bool isServer) -> SECStatus { + const SECItemArray* ocsp = SSL_PeerStapledOCSPResponses(agent->ssl_fd()); + if (!ocsp) { + return SECFailure; + } + EXPECT_EQ(1U, ocsp->len) << "We only provide the first item"; + EXPECT_EQ(0, SECITEM_CompareItem(&hugeOcspItems[0], &ocsp->items[0])); + return SECSuccess; + }); + EXPECT_TRUE(server_->ConfigServerCert(TlsAgent::kServerRsa, true, + &hugeOcspExtraData)); + + Connect(); +} + } // namespace nspr_test diff --git a/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc b/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc index ab10a84..85c30b2 100644 --- a/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc +++ b/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc @@ -22,17 +22,17 @@ extern "C" { namespace nss_test { -// mode, version, cipher suite -typedef std::tuple CipherSuiteProfile; class TlsCipherSuiteTestBase : public TlsConnectTestBase { public: - TlsCipherSuiteTestBase(const std::string &mode, uint16_t version, + TlsCipherSuiteTestBase(SSLProtocolVariant variant, uint16_t version, uint16_t cipher_suite, SSLNamedGroup group, SSLSignatureScheme signature_scheme) - : TlsConnectTestBase(mode, version), + : TlsConnectTestBase(variant, version), cipher_suite_(cipher_suite), group_(group), signature_scheme_(signature_scheme), @@ -128,16 +128,22 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase { Connect(); SendReceive(); - // Check that we used the right cipher suite. + // Check that we used the right cipher suite, auth type and kea type. uint16_t actual; - EXPECT_TRUE(client_->cipher_suite(&actual) && actual == cipher_suite_); - EXPECT_TRUE(server_->cipher_suite(&actual) && actual == cipher_suite_); + EXPECT_TRUE(client_->cipher_suite(&actual)); + EXPECT_EQ(cipher_suite_, actual); + EXPECT_TRUE(server_->cipher_suite(&actual)); + EXPECT_EQ(cipher_suite_, actual); SSLAuthType auth; - EXPECT_TRUE(client_->auth_type(&auth) && auth == auth_type_); - EXPECT_TRUE(server_->auth_type(&auth) && auth == auth_type_); + EXPECT_TRUE(client_->auth_type(&auth)); + EXPECT_EQ(auth_type_, auth); + EXPECT_TRUE(server_->auth_type(&auth)); + EXPECT_EQ(auth_type_, auth); SSLKEAType kea; - EXPECT_TRUE(client_->kea_type(&kea) && kea == kea_type_); - EXPECT_TRUE(server_->kea_type(&kea) && kea == kea_type_); + EXPECT_TRUE(client_->kea_type(&kea)); + EXPECT_EQ(kea_type_, kea); + EXPECT_TRUE(server_->kea_type(&kea)); + EXPECT_EQ(kea_type_, kea); } // Get the expected limit on the number of records that can be sent for the @@ -252,14 +258,17 @@ TEST_P(TlsCipherSuiteTest, ReadLimit) { // authentication tag. static const uint8_t payload[18] = {6}; DataBuffer record; - uint64_t epoch = 0; - if (mode_ == DGRAM) { - epoch++; + uint64_t epoch; + if (variant_ == ssl_variant_datagram) { if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) { - epoch++; + epoch = 3; // Application traffic keys. + } else { + epoch = 1; } + } else { + epoch = 0; } - TlsAgentTestBase::MakeRecord(mode_, kTlsApplicationDataType, version_, + TlsAgentTestBase::MakeRecord(variant_, kTlsApplicationDataType, version_, payload, sizeof(payload), &record, (epoch << 48) | record_limit()); server_->adapter()->PacketReceived(record); @@ -287,7 +296,7 @@ TEST_P(TlsCipherSuiteTest, WriteLimit) { k##name##Ciphers = ::testing::ValuesIn(k##name##CiphersArr); \ INSTANTIATE_TEST_CASE_P( \ CipherSuite##name, TlsCipherSuiteTest, \ - ::testing::Combine(TlsConnectTestBase::kTlsModes##modes, \ + ::testing::Combine(TlsConnectTestBase::kTlsVariants##modes, \ TlsConnectTestBase::kTls##versions, k##name##Ciphers, \ groups, sigalgs)); @@ -396,7 +405,7 @@ class SecurityStatusTest public ::testing::WithParamInterface { public: SecurityStatusTest() - : TlsCipherSuiteTestBase("TLS", GetParam().version, + : TlsCipherSuiteTestBase(ssl_variant_stream, GetParam().version, GetParam().cipher_suite, ssl_grp_none, ssl_sig_none) {} }; diff --git a/nss/gtests/ssl_gtest/ssl_damage_unittest.cc b/nss/gtests/ssl_gtest/ssl_damage_unittest.cc index 9dadcbd..69fd003 100644 --- a/nss/gtests/ssl_gtest/ssl_damage_unittest.cc +++ b/nss/gtests/ssl_gtest/ssl_damage_unittest.cc @@ -33,12 +33,14 @@ TEST_F(TlsConnectTest, DamageSecretHandleClientFinished) { client_->StartConnect(); client_->Handshake(); server_->Handshake(); - std::cerr << "Damaging HS secret\n"; + std::cerr << "Damaging HS secret" << std::endl; SSLInt_DamageClientHsTrafficSecret(server_->ssl_fd()); client_->Handshake(); - server_->Handshake(); // The client thinks it has connected. EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + + ExpectAlert(server_, kTlsAlertDecryptError); + server_->Handshake(); server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); client_->Handshake(); client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); @@ -49,7 +51,10 @@ TEST_F(TlsConnectTest, DamageSecretHandleServerFinished) { SSL_LIBRARY_VERSION_TLS_1_3); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, SSL_LIBRARY_VERSION_TLS_1_3); - server_->SetPacketFilter(new AfterRecordN( + client_->ExpectSendAlert(kTlsAlertDecryptError); + // The server can't read the client's alert, so it also sends an alert. + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + server_->SetPacketFilter(std::make_shared( server_, client_, 0, // ServerHello. [this]() { SSLInt_DamageServerHsTrafficSecret(client_->ssl_fd()); })); @@ -58,4 +63,57 @@ TEST_F(TlsConnectTest, DamageSecretHandleServerFinished) { server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); } +TEST_P(TlsConnectGenericPre13, DamageServerSignature) { + EnsureTlsSetup(); + auto filter = + std::make_shared(kTlsHandshakeServerKeyExchange); + server_->SetTlsRecordFilter(filter); + ExpectAlert(client_, kTlsAlertDecryptError); + ConnectExpectFail(); + client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); + server_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); +} + +TEST_P(TlsConnectTls13, DamageServerSignature) { + EnsureTlsSetup(); + auto filter = + std::make_shared(kTlsHandshakeCertificateVerify); + server_->SetTlsRecordFilter(filter); + filter->EnableDecryption(); + client_->ExpectSendAlert(kTlsAlertDecryptError); + // The server can't read the client's alert, so it also sends an alert. + if (variant_ == ssl_variant_stream) { + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + ConnectExpectFail(); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + } else { + ConnectExpectFailOneSide(TlsAgent::CLIENT); + } + client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); +} + +TEST_P(TlsConnectGeneric, DamageClientSignature) { + EnsureTlsSetup(); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + auto filter = + std::make_shared(kTlsHandshakeCertificateVerify); + client_->SetTlsRecordFilter(filter); + server_->ExpectSendAlert(kTlsAlertDecryptError); + filter->EnableDecryption(); + // Do these handshakes by hand to avoid race condition on + // the client processing the server's alert. + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + client_->Handshake(); + server_->Handshake(); + EXPECT_EQ(version_ >= SSL_LIBRARY_VERSION_TLS_1_3 + ? TlsAgent::STATE_CONNECTED + : TlsAgent::STATE_CONNECTING, + client_->state()); + server_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); +} + } // namespace nspr_test diff --git a/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc b/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc index 82d5558..9794330 100644 --- a/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc +++ b/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc @@ -31,12 +31,13 @@ TEST_P(TlsConnectTls13, SharesForBothEcdheAndDhe) { EnsureTlsSetup(); client_->ConfigNamedGroups(kAllDHEGroups); - auto groups_capture = new TlsExtensionCapture(ssl_supported_groups_xtn); - auto shares_capture = new TlsExtensionCapture(ssl_tls13_key_share_xtn); - std::vector captures; - captures.push_back(groups_capture); - captures.push_back(shares_capture); - client_->SetPacketFilter(new ChainedPacketFilter(captures)); + auto groups_capture = + std::make_shared(ssl_supported_groups_xtn); + auto shares_capture = + std::make_shared(ssl_tls13_key_share_xtn); + std::vector> captures = {groups_capture, + shares_capture}; + client_->SetPacketFilter(std::make_shared(captures)); Connect(); @@ -60,12 +61,13 @@ TEST_P(TlsConnectGeneric, ConnectFfdheClient) { EnableOnlyDheCiphers(); EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); - auto groups_capture = new TlsExtensionCapture(ssl_supported_groups_xtn); - auto shares_capture = new TlsExtensionCapture(ssl_tls13_key_share_xtn); - std::vector captures; - captures.push_back(groups_capture); - captures.push_back(shares_capture); - client_->SetPacketFilter(new ChainedPacketFilter(captures)); + auto groups_capture = + std::make_shared(ssl_supported_groups_xtn); + auto shares_capture = + std::make_shared(ssl_tls13_key_share_xtn); + std::vector> captures = {groups_capture, + shares_capture}; + client_->SetPacketFilter(std::make_shared(captures)); Connect(); @@ -95,7 +97,7 @@ TEST_P(TlsConnectGenericPre13, ConnectFfdheServer) { Connect(); CheckKeys(ssl_kea_dh, ssl_auth_rsa_sign); } else { - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); } @@ -126,9 +128,9 @@ TEST_P(TlsConnectGenericPre13, DamageServerKeyShare) { EnableOnlyDheCiphers(); EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); - server_->SetPacketFilter(new TlsDheServerKeyExchangeDamager()); + server_->SetPacketFilter(std::make_shared()); - ConnectExpectFail(); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_WEAK_SERVER_EPHEMERAL_DH_KEY); server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); @@ -249,8 +251,9 @@ class TlsDheSkeChangeYServer : public TlsDheSkeChangeY { class TlsDheSkeChangeYClient : public TlsDheSkeChangeY { public: - TlsDheSkeChangeYClient(ChangeYTo change, - const TlsDheSkeChangeYServer* server_filter) + TlsDheSkeChangeYClient( + ChangeYTo change, + std::shared_ptr server_filter) : TlsDheSkeChangeY(change), server_filter_(server_filter) {} protected: @@ -266,13 +269,14 @@ class TlsDheSkeChangeYClient : public TlsDheSkeChangeY { } private: - const TlsDheSkeChangeYServer* server_filter_; + std::shared_ptr server_filter_; }; -/* This matrix includes: mode (stream/datagram), TLS version, what change to +/* This matrix includes: variant (stream/datagram), TLS version, what change to * make to dh_Ys, whether the client will be configured to require DH named * groups. Test all combinations. */ -typedef std::tuple +typedef std::tuple DamageDHYProfile; class TlsDamageDHYTest : public TlsConnectTestBase, @@ -289,8 +293,14 @@ TEST_P(TlsDamageDHYTest, DamageServerY) { SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); } TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam()); - server_->SetPacketFilter(new TlsDheSkeChangeYServer(change, true)); + server_->SetPacketFilter( + std::make_shared(change, true)); + if (change == TlsDheSkeChangeY::kYZeroPad) { + ExpectAlert(client_, kTlsAlertDecryptError); + } else { + ExpectAlert(client_, kTlsAlertIllegalParameter); + } ConnectExpectFail(); if (change == TlsDheSkeChangeY::kYZeroPad) { // Zero padding Y only manifests in a signature failure. @@ -314,14 +324,20 @@ TEST_P(TlsDamageDHYTest, DamageClientY) { SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); } // The filter on the server is required to capture the prime. - TlsDheSkeChangeYServer* server_filter = - new TlsDheSkeChangeYServer(TlsDheSkeChangeY::kYZero, false); + auto server_filter = + std::make_shared(TlsDheSkeChangeY::kYZero, false); server_->SetPacketFilter(server_filter); // The client filter does the damage. TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam()); - client_->SetPacketFilter(new TlsDheSkeChangeYClient(change, server_filter)); + client_->SetPacketFilter( + std::make_shared(change, server_filter)); + if (change == TlsDheSkeChangeY::kYZeroPad) { + ExpectAlert(server_, kTlsAlertDecryptError); + } else { + ExpectAlert(server_, kTlsAlertHandshakeFailure); + } ConnectExpectFail(); if (change == TlsDheSkeChangeY::kYZeroPad) { // Zero padding Y only manifests in a finished error. @@ -343,13 +359,13 @@ static const bool kTrueFalseArr[] = {true, false}; static ::testing::internal::ParamGenerator kTrueFalse = ::testing::ValuesIn(kTrueFalseArr); -INSTANTIATE_TEST_CASE_P(DamageYStream, TlsDamageDHYTest, - ::testing::Combine(TlsConnectTestBase::kTlsModesStream, - TlsConnectTestBase::kTlsV10ToV12, - kAllY, kTrueFalse)); +INSTANTIATE_TEST_CASE_P( + DamageYStream, TlsDamageDHYTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10ToV12, kAllY, kTrueFalse)); INSTANTIATE_TEST_CASE_P( DamageYDatagram, TlsDamageDHYTest, - ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, TlsConnectTestBase::kTlsV11V12, kAllY, kTrueFalse)); class TlsDheSkeMakePEven : public TlsHandshakeFilter { @@ -378,9 +394,9 @@ class TlsDheSkeMakePEven : public TlsHandshakeFilter { // Even without requiring named groups, an even value for p is bad news. TEST_P(TlsConnectGenericPre13, MakeDhePEven) { EnableOnlyDheCiphers(); - server_->SetPacketFilter(new TlsDheSkeMakePEven()); + server_->SetPacketFilter(std::make_shared()); - ConnectExpectFail(); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_DHE_KEY_SHARE); server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); @@ -409,9 +425,9 @@ class TlsDheSkeZeroPadP : public TlsHandshakeFilter { // Zero padding only causes signature failure. TEST_P(TlsConnectGenericPre13, PadDheP) { EnableOnlyDheCiphers(); - server_->SetPacketFilter(new TlsDheSkeZeroPadP()); + server_->SetPacketFilter(std::make_shared()); - ConnectExpectFail(); + ConnectExpectAlert(client_, kTlsAlertDecryptError); // In TLS 1.0 and 1.1, the client reports a device error. if (version_ < SSL_LIBRARY_VERSION_TLS_1_2) { @@ -470,7 +486,7 @@ TEST_P(TlsConnectTls13, NamedGroupMismatch13) { server_->ConfigNamedGroups(server_groups); client_->ConfigNamedGroups(client_groups); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); } @@ -488,7 +504,7 @@ TEST_P(TlsConnectGenericPre13, RequireNamedGroupsMismatchPre13) { server_->ConfigNamedGroups(server_groups); client_->ConfigNamedGroups(client_groups); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); } @@ -518,7 +534,7 @@ TEST_P(TlsConnectGenericPre13, MismatchDHE) { EXPECT_EQ(SECSuccess, SSL_DHEGroupPrefSet(client_->ssl_fd(), clientGroups, PR_ARRAY_SIZE(clientGroups))); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); } @@ -533,11 +549,11 @@ TEST_P(TlsConnectTls13, ResumeFfdhe) { Reset(); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); EnableOnlyDheCiphers(); - TlsExtensionCapture* clientCapture = - new TlsExtensionCapture(ssl_tls13_pre_shared_key_xtn); + auto clientCapture = + std::make_shared(ssl_tls13_pre_shared_key_xtn); client_->SetPacketFilter(clientCapture); - TlsExtensionCapture* serverCapture = - new TlsExtensionCapture(ssl_tls13_pre_shared_key_xtn); + auto serverCapture = + std::make_shared(ssl_tls13_pre_shared_key_xtn); server_->SetPacketFilter(serverCapture); ExpectResumption(RESUME_TICKET); Connect(); @@ -599,10 +615,10 @@ TEST_P(TlsConnectGenericPre13, InvalidDERSignatureFfdhe) { const std::vector client_groups = {ssl_grp_ffdhe_2048}; client_->ConfigNamedGroups(client_groups); - server_->SetPacketFilter(new TlsDheSkeChangeSignature( + server_->SetPacketFilter(std::make_shared( version_, kBogusDheSignature, sizeof(kBogusDheSignature))); - ConnectExpectFail(); + ConnectExpectAlert(client_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); } diff --git a/nss/gtests/ssl_gtest/ssl_drop_unittest.cc b/nss/gtests/ssl_gtest/ssl_drop_unittest.cc index 89ca28e..3cc3b0e 100644 --- a/nss/gtests/ssl_gtest/ssl_drop_unittest.cc +++ b/nss/gtests/ssl_gtest/ssl_drop_unittest.cc @@ -21,13 +21,13 @@ extern "C" { namespace nss_test { TEST_P(TlsConnectDatagram, DropClientFirstFlightOnce) { - client_->SetPacketFilter(new SelectiveDropFilter(0x1)); + client_->SetPacketFilter(std::make_shared(0x1)); Connect(); SendReceive(); } TEST_P(TlsConnectDatagram, DropServerFirstFlightOnce) { - server_->SetPacketFilter(new SelectiveDropFilter(0x1)); + server_->SetPacketFilter(std::make_shared(0x1)); Connect(); SendReceive(); } @@ -36,32 +36,32 @@ TEST_P(TlsConnectDatagram, DropServerFirstFlightOnce) { // flights that they send. Note: In DTLS 1.3, the shorter handshake means that // this will also drop some application data, so we can't call SendReceive(). TEST_P(TlsConnectDatagram, DropAllFirstTransmissions) { - client_->SetPacketFilter(new SelectiveDropFilter(0x15)); - server_->SetPacketFilter(new SelectiveDropFilter(0x5)); + client_->SetPacketFilter(std::make_shared(0x15)); + server_->SetPacketFilter(std::make_shared(0x5)); Connect(); } // This drops the server's first flight three times. TEST_P(TlsConnectDatagram, DropServerFirstFlightThrice) { - server_->SetPacketFilter(new SelectiveDropFilter(0x7)); + server_->SetPacketFilter(std::make_shared(0x7)); Connect(); } // This drops the client's second flight once TEST_P(TlsConnectDatagram, DropClientSecondFlightOnce) { - client_->SetPacketFilter(new SelectiveDropFilter(0x2)); + client_->SetPacketFilter(std::make_shared(0x2)); Connect(); } // This drops the client's second flight three times. TEST_P(TlsConnectDatagram, DropClientSecondFlightThrice) { - client_->SetPacketFilter(new SelectiveDropFilter(0xe)); + client_->SetPacketFilter(std::make_shared(0xe)); Connect(); } // This drops the server's second flight three times. TEST_P(TlsConnectDatagram, DropServerSecondFlightThrice) { - server_->SetPacketFilter(new SelectiveDropFilter(0xe)); + server_->SetPacketFilter(std::make_shared(0xe)); Connect(); } diff --git a/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc b/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc index 43dfcba..1e406b6 100644 --- a/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc +++ b/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc @@ -58,7 +58,7 @@ TEST_P(TlsConnectTls12, ConnectEcdheP384) { Reset(TlsAgent::kServerEcdsa384); ConnectWithCipherSuite(TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_ecdsa, - ssl_sig_ecdsa_secp384r1_sha384); + ssl_sig_ecdsa_secp256r1_sha256); } TEST_P(TlsConnectGeneric, ConnectEcdheP384Client) { @@ -75,8 +75,8 @@ TEST_P(TlsConnectGeneric, ConnectEcdheP384Client) { // This causes a HelloRetryRequest in TLS 1.3. Earlier versions don't care. TEST_P(TlsConnectGeneric, ConnectEcdheP384Server) { EnsureTlsSetup(); - auto hrr_capture = - new TlsInspectorRecordHandshakeMessage(kTlsHandshakeHelloRetryRequest); + auto hrr_capture = std::make_shared( + kTlsHandshakeHelloRetryRequest); server_->SetPacketFilter(hrr_capture); const std::vector groups = {ssl_grp_ec_secp384r1}; server_->ConfigNamedGroups(groups); @@ -191,6 +191,60 @@ TEST_P(TlsConnectGenericPre13, P384PriorityFromModelSocket) { ssl_sig_rsa_pss_sha256); } +class TlsKeyExchangeGroupCapture : public TlsHandshakeFilter { + public: + TlsKeyExchangeGroupCapture() : group_(ssl_grp_none) {} + + SSLNamedGroup group() const { return group_; } + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header, + const DataBuffer &input, + DataBuffer *output) { + if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { + return KEEP; + } + + uint32_t value = 0; + EXPECT_TRUE(input.Read(0, 1, &value)); + EXPECT_EQ(3U, value) << "curve type has to be 3"; + + EXPECT_TRUE(input.Read(1, 2, &value)); + group_ = static_cast(value); + + return KEEP; + } + + private: + SSLNamedGroup group_; +}; + +// If we strip the client's supported groups extension, the server should assume +// P-256 is supported by the client (<= 1.2 only). +TEST_P(TlsConnectGenericPre13, DropSupportedGroupExtensionP256) { + EnsureTlsSetup(); + client_->SetPacketFilter( + std::make_shared(ssl_supported_groups_xtn)); + auto group_capture = std::make_shared(); + server_->SetPacketFilter(group_capture); + + ConnectExpectAlert(server_, kTlsAlertDecryptError); + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); + + EXPECT_EQ(ssl_grp_ec_secp256r1, group_capture->group()); +} + +// Supported groups is mandatory in TLS 1.3. +TEST_P(TlsConnectTls13, DropSupportedGroupExtension) { + EnsureTlsSetup(); + client_->SetPacketFilter( + std::make_shared(ssl_supported_groups_xtn)); + ConnectExpectAlert(server_, kTlsAlertMissingExtension); + client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_MISSING_SUPPORTED_GROUPS_EXTENSION); +} + // If we only have a lame group, we fall back to static RSA. TEST_P(TlsConnectGenericPre13, UseLameGroup) { const std::vector groups = {ssl_grp_ec_secp192r1}; @@ -431,7 +485,7 @@ TEST_P(TlsConnectGeneric, P256ClientAndCurve25519Server) { client_->ConfigNamedGroups(client_groups); server_->ConfigNamedGroups(server_groups); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); } @@ -507,25 +561,25 @@ class ECCServerKEXFilter : public TlsHandshakeFilter { TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyServerPoint) { // add packet filter - server_->SetPacketFilter(new ECCServerKEXFilter()); - ConnectExpectFail(); + server_->SetPacketFilter(std::make_shared()); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_KEY_EXCH); } TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyClientPoint) { // add packet filter - client_->SetPacketFilter(new ECCClientKEXFilter()); - ConnectExpectFail(); + client_->SetPacketFilter(std::make_shared()); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_KEY_EXCH); } INSTANTIATE_TEST_CASE_P(KeyExchangeTest, TlsKeyExchangeTest, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV11Plus)); #ifndef NSS_DISABLE_TLS_1_3 INSTANTIATE_TEST_CASE_P(KeyExchangeTest, TlsKeyExchangeTest13, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV13)); #endif diff --git a/nss/gtests/ssl_gtest/ssl_ems_unittest.cc b/nss/gtests/ssl_gtest/ssl_ems_unittest.cc index b9c725b..dad6ca0 100644 --- a/nss/gtests/ssl_gtest/ssl_ems_unittest.cc +++ b/nss/gtests/ssl_gtest/ssl_ems_unittest.cc @@ -79,11 +79,7 @@ TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretResumeWithout) { Reset(); server_->EnableExtendedMasterSecret(); - auto alert_recorder = new TlsAlertRecorder(); - server_->SetPacketFilter(alert_recorder); - ConnectExpectFail(); - EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); - EXPECT_EQ(kTlsAlertHandshakeFailure, alert_recorder->description()); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); } TEST_P(TlsConnectGenericPre13, ConnectNormalResumeWithExtendedMasterSecret) { diff --git a/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc b/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc index 0a0d9f2..be407b4 100644 --- a/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc +++ b/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc @@ -14,7 +14,8 @@ namespace nss_test { static const char* kExporterLabel = "EXPORTER-duck"; static const uint8_t kExporterContext[] = {0x12, 0x34, 0x56}; -static void ExportAndCompare(TlsAgent* client, TlsAgent* server, bool context) { +static void ExportAndCompare(std::shared_ptr& client, + std::shared_ptr& server, bool context) { static const size_t exporter_len = 10; uint8_t client_value[exporter_len] = {0}; EXPECT_EQ(SECSuccess, @@ -76,6 +77,33 @@ TEST_P(TlsConnectTls13, ExporterContextEmptyIsSameAsNone) { ExportAndCompare(client_, server_, false); } +TEST_P(TlsConnectGenericPre13, ExporterContextLengthTooLong) { + static const uint8_t kExporterContextTooLong[PR_UINT16_MAX] = { + 0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xFF}; + + EnsureTlsSetup(); + Connect(); + CheckKeys(); + + static const size_t exporter_len = 10; + uint8_t client_value[exporter_len] = {0}; + EXPECT_EQ(SECFailure, + SSL_ExportKeyingMaterial(client_->ssl_fd(), kExporterLabel, + strlen(kExporterLabel), PR_TRUE, + kExporterContextTooLong, + sizeof(kExporterContextTooLong), + client_value, sizeof(client_value))); + EXPECT_EQ(PORT_GetError(), SEC_ERROR_INVALID_ARGS); + uint8_t server_value[exporter_len] = {0xff}; + EXPECT_EQ(SECFailure, + SSL_ExportKeyingMaterial(server_->ssl_fd(), kExporterLabel, + strlen(kExporterLabel), PR_TRUE, + kExporterContextTooLong, + sizeof(kExporterContextTooLong), + server_value, sizeof(server_value))); + EXPECT_EQ(PORT_GetError(), SEC_ERROR_INVALID_ARGS); +} + // This has a weird signature so that it can be passed to the SNI callback. int32_t RegularExporterShouldFail(TlsAgent* agent, const SECItem* srvNameArr, PRUint32 srvNameArrSize) { @@ -90,13 +118,15 @@ int32_t RegularExporterShouldFail(TlsAgent* agent, const SECItem* srvNameArr, TEST_P(TlsConnectTls13, EarlyExporter) { SetupForZeroRtt(); + ExpectAlert(client_, kTlsAlertEndOfEarlyData); client_->Set0RttEnabled(true); server_->Set0RttEnabled(true); ExpectResumption(RESUME_TICKET); client_->Handshake(); // Send ClientHello. uint8_t client_value[10] = {0}; - RegularExporterShouldFail(client_, nullptr, 0); + RegularExporterShouldFail(client_.get(), nullptr, 0); + EXPECT_EQ(SECSuccess, SSL_ExportEarlyKeyingMaterial( client_->ssl_fd(), kExporterLabel, strlen(kExporterLabel), diff --git a/nss/gtests/ssl_gtest/ssl_extension_unittest.cc b/nss/gtests/ssl_gtest/ssl_extension_unittest.cc index 9200e72..d151394 100644 --- a/nss/gtests/ssl_gtest/ssl_extension_unittest.cc +++ b/nss/gtests/ssl_gtest/ssl_extension_unittest.cc @@ -69,22 +69,11 @@ class TlsExtensionInjector : public TlsHandshakeFilter { virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - size_t offset; - if (header.handshake_type() == kTlsHandshakeClientHello) { - TlsParser parser(input); - if (!TlsExtensionFilter::FindClientHelloExtensions(&parser, header)) { - return KEEP; - } - offset = parser.consumed(); - } else if (header.handshake_type() == kTlsHandshakeServerHello) { - TlsParser parser(input); - if (!TlsExtensionFilter::FindServerHelloExtensions(&parser)) { - return KEEP; - } - offset = parser.consumed(); - } else { + TlsParser parser(input); + if (!TlsExtensionFilter::FindExtensions(&parser, header)) { return KEEP; } + size_t offset = parser.consumed(); *output = input; @@ -116,38 +105,41 @@ class TlsExtensionInjector : public TlsHandshakeFilter { class TlsExtensionAppender : public TlsHandshakeFilter { public: - TlsExtensionAppender(uint16_t ext, DataBuffer& data) - : extension_(ext), data_(data) {} + TlsExtensionAppender(uint8_t handshake_type, uint16_t ext, DataBuffer& data) + : handshake_type_(handshake_type), extension_(ext), data_(data) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - size_t offset; + if (header.handshake_type() != handshake_type_) { + return KEEP; + } + TlsParser parser(input); - if (header.handshake_type() == kTlsHandshakeClientHello) { - if (!TlsExtensionFilter::FindClientHelloExtensions(&parser, header)) { - return KEEP; - } - } else if (header.handshake_type() == kTlsHandshakeServerHello) { - if (!TlsExtensionFilter::FindServerHelloExtensions(&parser)) { - return KEEP; - } - } else { + if (!TlsExtensionFilter::FindExtensions(&parser, header)) { return KEEP; } - offset = parser.consumed(); *output = input; - uint32_t ext_len; - if (!parser.Read(&ext_len, 2)) { - ADD_FAILURE(); + // Increase the length of the extensions block. + if (!UpdateLength(output, parser.consumed(), 2)) { return KEEP; } - ext_len += 4 + data_.len(); - output->Write(offset, ext_len, 2); + // Extensions in Certificate are nested twice. Increase the size of the + // certificate list. + if (header.handshake_type() == kTlsHandshakeCertificate) { + TlsParser p2(input); + if (!p2.SkipVariable(1)) { + ADD_FAILURE(); + return KEEP; + } + if (!UpdateLength(output, p2.consumed(), 3)) { + return KEEP; + } + } - offset = output->len(); + size_t offset = output->len(); offset = output->Write(offset, extension_, 2); WriteVariable(output, offset, data_, 2); @@ -155,39 +147,38 @@ class TlsExtensionAppender : public TlsHandshakeFilter { } private: + bool UpdateLength(DataBuffer* output, size_t offset, size_t size) { + uint32_t len; + if (!output->Read(offset, size, &len)) { + ADD_FAILURE(); + return false; + } + + len += 4 + data_.len(); + output->Write(offset, len, size); + return true; + } + + const uint8_t handshake_type_; const uint16_t extension_; const DataBuffer data_; }; class TlsExtensionTestBase : public TlsConnectTestBase { protected: - TlsExtensionTestBase(Mode mode, uint16_t version) - : TlsConnectTestBase(mode, version) {} - TlsExtensionTestBase(const std::string& mode, uint16_t version) - : TlsConnectTestBase(mode, version) {} - - void ClientHelloErrorTest(PacketFilter* filter, - uint8_t alert = kTlsAlertDecodeError) { - auto alert_recorder = new TlsAlertRecorder(); - server_->SetPacketFilter(alert_recorder); - if (filter) { - client_->SetPacketFilter(filter); - } - ConnectExpectFail(); - EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); - EXPECT_EQ(alert, alert_recorder->description()); + TlsExtensionTestBase(SSLProtocolVariant variant, uint16_t version) + : TlsConnectTestBase(variant, version) {} + + void ClientHelloErrorTest(std::shared_ptr filter, + uint8_t desc = kTlsAlertDecodeError) { + client_->SetPacketFilter(filter); + ConnectExpectAlert(server_, desc); } - void ServerHelloErrorTest(PacketFilter* filter, - uint8_t alert = kTlsAlertDecodeError) { - auto alert_recorder = new TlsAlertRecorder(); - client_->SetPacketFilter(alert_recorder); - if (filter) { - server_->SetPacketFilter(filter); - } - ConnectExpectFail(); - EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); - EXPECT_EQ(alert, alert_recorder->description()); + void ServerHelloErrorTest(std::shared_ptr filter, + uint8_t desc = kTlsAlertDecodeError) { + server_->SetPacketFilter(filter); + ConnectExpectAlert(client_, desc); } static void InitSimpleSni(DataBuffer* extension) { @@ -213,7 +204,7 @@ class TlsExtensionTestBase : public TlsConnectTestBase { server_->StartConnect(); client_->Handshake(); // Send ClientHello server_->Handshake(); // Send HRR. - client_->SetPacketFilter(new TlsExtensionDropper(type)); + client_->SetPacketFilter(std::make_shared(type)); Handshake(); client_->CheckErrorCode(client_error); server_->CheckErrorCode(server_error); @@ -223,38 +214,40 @@ class TlsExtensionTestBase : public TlsConnectTestBase { class TlsExtensionTestDtls : public TlsExtensionTestBase, public ::testing::WithParamInterface { public: - TlsExtensionTestDtls() : TlsExtensionTestBase(DGRAM, GetParam()) {} + TlsExtensionTestDtls() + : TlsExtensionTestBase(ssl_variant_datagram, GetParam()) {} }; -class TlsExtensionTest12Plus - : public TlsExtensionTestBase, - public ::testing::WithParamInterface> { +class TlsExtensionTest12Plus : public TlsExtensionTestBase, + public ::testing::WithParamInterface< + std::tuple> { public: TlsExtensionTest12Plus() : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) { } }; -class TlsExtensionTest12 - : public TlsExtensionTestBase, - public ::testing::WithParamInterface> { +class TlsExtensionTest12 : public TlsExtensionTestBase, + public ::testing::WithParamInterface< + std::tuple> { public: TlsExtensionTest12() : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) { } }; -class TlsExtensionTest13 : public TlsExtensionTestBase, - public ::testing::WithParamInterface { +class TlsExtensionTest13 + : public TlsExtensionTestBase, + public ::testing::WithParamInterface { public: TlsExtensionTest13() : TlsExtensionTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} void ConnectWithBogusVersionList(const uint8_t* buf, size_t len) { DataBuffer versions_buf(buf, len); - client_->SetPacketFilter(new TlsExtensionReplacer( + client_->SetPacketFilter(std::make_shared( ssl_tls13_supported_versions_xtn, versions_buf)); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); } @@ -264,7 +257,7 @@ class TlsExtensionTest13 : public TlsExtensionTestBase, size_t index = versions_buf.Write(0, 2, 1); versions_buf.Write(index, version, 2); - client_->SetPacketFilter(new TlsExtensionReplacer( + client_->SetPacketFilter(std::make_shared( ssl_tls13_supported_versions_xtn, versions_buf)); ConnectExpectFail(); } @@ -273,21 +266,21 @@ class TlsExtensionTest13 : public TlsExtensionTestBase, class TlsExtensionTest13Stream : public TlsExtensionTestBase { public: TlsExtensionTest13Stream() - : TlsExtensionTestBase(STREAM, SSL_LIBRARY_VERSION_TLS_1_3) {} + : TlsExtensionTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_3) {} }; -class TlsExtensionTestGeneric - : public TlsExtensionTestBase, - public ::testing::WithParamInterface> { +class TlsExtensionTestGeneric : public TlsExtensionTestBase, + public ::testing::WithParamInterface< + std::tuple> { public: TlsExtensionTestGeneric() : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) { } }; -class TlsExtensionTestPre13 - : public TlsExtensionTestBase, - public ::testing::WithParamInterface> { +class TlsExtensionTestPre13 : public TlsExtensionTestBase, + public ::testing::WithParamInterface< + std::tuple> { public: TlsExtensionTestPre13() : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) { @@ -295,23 +288,27 @@ class TlsExtensionTestPre13 }; TEST_P(TlsExtensionTestGeneric, DamageSniLength) { - ClientHelloErrorTest(new TlsExtensionDamager(ssl_server_name_xtn, 1)); + ClientHelloErrorTest( + std::make_shared(ssl_server_name_xtn, 1)); } TEST_P(TlsExtensionTestGeneric, DamageSniHostLength) { - ClientHelloErrorTest(new TlsExtensionDamager(ssl_server_name_xtn, 4)); + ClientHelloErrorTest( + std::make_shared(ssl_server_name_xtn, 4)); } TEST_P(TlsExtensionTestGeneric, TruncateSni) { - ClientHelloErrorTest(new TlsExtensionTruncator(ssl_server_name_xtn, 7)); + ClientHelloErrorTest( + std::make_shared(ssl_server_name_xtn, 7)); } // A valid extension that appears twice will be reported as unsupported. TEST_P(TlsExtensionTestGeneric, RepeatSni) { DataBuffer extension; InitSimpleSni(&extension); - ClientHelloErrorTest(new TlsExtensionInjector(ssl_server_name_xtn, extension), - kTlsAlertIllegalParameter); + ClientHelloErrorTest( + std::make_shared(ssl_server_name_xtn, extension), + kTlsAlertIllegalParameter); } // An SNI entry with zero length is considered invalid (strangely, not if it is @@ -324,7 +321,7 @@ TEST_P(TlsExtensionTestGeneric, BadSni) { extension.Write(0, static_cast(0), 3); extension.Write(3, simple); ClientHelloErrorTest( - new TlsExtensionReplacer(ssl_server_name_xtn, extension)); + std::make_shared(ssl_server_name_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, EmptySni) { @@ -332,15 +329,15 @@ TEST_P(TlsExtensionTestGeneric, EmptySni) { extension.Allocate(2); extension.Write(0, static_cast(0), 2); ClientHelloErrorTest( - new TlsExtensionReplacer(ssl_server_name_xtn, extension)); + std::make_shared(ssl_server_name_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, EmptyAlpnExtension) { EnableAlpn(); DataBuffer extension; - ClientHelloErrorTest( - new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension), - kTlsAlertIllegalParameter); + ClientHelloErrorTest(std::make_shared( + ssl_app_layer_protocol_xtn, extension), + kTlsAlertIllegalParameter); } // An empty ALPN isn't considered bad, though it does lead to there being no @@ -349,30 +346,30 @@ TEST_P(TlsExtensionTestGeneric, EmptyAlpnList) { EnableAlpn(); const uint8_t val[] = {0x00, 0x00}; DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest( - new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension), - kTlsAlertNoApplicationProtocol); + ClientHelloErrorTest(std::make_shared( + ssl_app_layer_protocol_xtn, extension), + kTlsAlertNoApplicationProtocol); } TEST_P(TlsExtensionTestGeneric, OneByteAlpn) { EnableAlpn(); ClientHelloErrorTest( - new TlsExtensionTruncator(ssl_app_layer_protocol_xtn, 1)); + std::make_shared(ssl_app_layer_protocol_xtn, 1)); } TEST_P(TlsExtensionTestGeneric, AlpnMissingValue) { EnableAlpn(); // This will leave the length of the second entry, but no value. ClientHelloErrorTest( - new TlsExtensionTruncator(ssl_app_layer_protocol_xtn, 5)); + std::make_shared(ssl_app_layer_protocol_xtn, 5)); } TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) { EnableAlpn(); const uint8_t val[] = {0x01, 0x61, 0x00}; DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest( - new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); + ClientHelloErrorTest(std::make_shared( + ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, AlpnMismatch) { @@ -390,158 +387,169 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyList) { EnableAlpn(); const uint8_t val[] = {0x00, 0x00}; DataBuffer extension(val, sizeof(val)); - ServerHelloErrorTest( - new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); + ServerHelloErrorTest(std::make_shared( + ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) { EnableAlpn(); const uint8_t val[] = {0x00, 0x01, 0x00}; DataBuffer extension(val, sizeof(val)); - ServerHelloErrorTest( - new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); + ServerHelloErrorTest(std::make_shared( + ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) { EnableAlpn(); const uint8_t val[] = {0x00, 0x02, 0x01, 0x61, 0x00}; DataBuffer extension(val, sizeof(val)); - ServerHelloErrorTest( - new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); + ServerHelloErrorTest(std::make_shared( + ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) { EnableAlpn(); const uint8_t val[] = {0x00, 0x04, 0x01, 0x61, 0x01, 0x62}; DataBuffer extension(val, sizeof(val)); - ServerHelloErrorTest( - new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); + ServerHelloErrorTest(std::make_shared( + ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) { EnableAlpn(); const uint8_t val[] = {0x00, 0x99, 0x01, 0x61, 0x00}; DataBuffer extension(val, sizeof(val)); - ServerHelloErrorTest( - new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); + ServerHelloErrorTest(std::make_shared( + ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) { EnableAlpn(); const uint8_t val[] = {0x00, 0x02, 0x99, 0x61}; DataBuffer extension(val, sizeof(val)); - ServerHelloErrorTest( - new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); + ServerHelloErrorTest(std::make_shared( + ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestPre13, AlpnReturnedUnknownName) { + EnableAlpn(); + const uint8_t val[] = {0x00, 0x02, 0x01, 0x67}; + DataBuffer extension(val, sizeof(val)); + ServerHelloErrorTest(std::make_shared( + ssl_app_layer_protocol_xtn, extension), + kTlsAlertIllegalParameter); } TEST_P(TlsExtensionTestDtls, SrtpShort) { EnableSrtp(); - ClientHelloErrorTest(new TlsExtensionTruncator(ssl_use_srtp_xtn, 3)); + ClientHelloErrorTest( + std::make_shared(ssl_use_srtp_xtn, 3)); } TEST_P(TlsExtensionTestDtls, SrtpOdd) { EnableSrtp(); const uint8_t val[] = {0x00, 0x01, 0xff, 0x00}; DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest(new TlsExtensionReplacer(ssl_use_srtp_xtn, extension)); + ClientHelloErrorTest( + std::make_shared(ssl_use_srtp_xtn, extension)); } TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsBadLength) { const uint8_t val[] = {0x00}; DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest( - new TlsExtensionReplacer(ssl_signature_algorithms_xtn, extension)); + ClientHelloErrorTest(std::make_shared( + ssl_signature_algorithms_xtn, extension)); } TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsTrailingData) { const uint8_t val[] = {0x00, 0x02, 0x04, 0x01, 0x00}; // sha-256, rsa DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest( - new TlsExtensionReplacer(ssl_signature_algorithms_xtn, extension)); + ClientHelloErrorTest(std::make_shared( + ssl_signature_algorithms_xtn, extension)); } TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsEmpty) { const uint8_t val[] = {0x00, 0x00}; DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest( - new TlsExtensionReplacer(ssl_signature_algorithms_xtn, extension)); + ClientHelloErrorTest(std::make_shared( + ssl_signature_algorithms_xtn, extension)); } TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsOddLength) { const uint8_t val[] = {0x00, 0x01, 0x04}; DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest( - new TlsExtensionReplacer(ssl_signature_algorithms_xtn, extension)); + ClientHelloErrorTest(std::make_shared( + ssl_signature_algorithms_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, NoSupportedGroups) { - ClientHelloErrorTest(new TlsExtensionDropper(ssl_supported_groups_xtn), - version_ < SSL_LIBRARY_VERSION_TLS_1_3 - ? kTlsAlertDecryptError - : kTlsAlertMissingExtension); + ClientHelloErrorTest( + std::make_shared(ssl_supported_groups_xtn), + version_ < SSL_LIBRARY_VERSION_TLS_1_3 ? kTlsAlertDecryptError + : kTlsAlertMissingExtension); } TEST_P(TlsExtensionTestGeneric, SupportedCurvesShort) { const uint8_t val[] = {0x00, 0x01, 0x00}; DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest( - new TlsExtensionReplacer(ssl_elliptic_curves_xtn, extension)); + ClientHelloErrorTest(std::make_shared( + ssl_elliptic_curves_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, SupportedCurvesBadLength) { const uint8_t val[] = {0x09, 0x99, 0x00, 0x00}; DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest( - new TlsExtensionReplacer(ssl_elliptic_curves_xtn, extension)); + ClientHelloErrorTest(std::make_shared( + ssl_elliptic_curves_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, SupportedCurvesTrailingData) { const uint8_t val[] = {0x00, 0x02, 0x00, 0x00, 0x00}; DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest( - new TlsExtensionReplacer(ssl_elliptic_curves_xtn, extension)); + ClientHelloErrorTest(std::make_shared( + ssl_elliptic_curves_xtn, extension)); } TEST_P(TlsExtensionTestPre13, SupportedPointsEmpty) { const uint8_t val[] = {0x00}; DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest( - new TlsExtensionReplacer(ssl_ec_point_formats_xtn, extension)); + ClientHelloErrorTest(std::make_shared( + ssl_ec_point_formats_xtn, extension)); } TEST_P(TlsExtensionTestPre13, SupportedPointsBadLength) { const uint8_t val[] = {0x99, 0x00, 0x00}; DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest( - new TlsExtensionReplacer(ssl_ec_point_formats_xtn, extension)); + ClientHelloErrorTest(std::make_shared( + ssl_ec_point_formats_xtn, extension)); } TEST_P(TlsExtensionTestPre13, SupportedPointsTrailingData) { const uint8_t val[] = {0x01, 0x00, 0x00}; DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest( - new TlsExtensionReplacer(ssl_ec_point_formats_xtn, extension)); + ClientHelloErrorTest(std::make_shared( + ssl_ec_point_formats_xtn, extension)); } TEST_P(TlsExtensionTestPre13, RenegotiationInfoBadLength) { const uint8_t val[] = {0x99}; DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest( - new TlsExtensionReplacer(ssl_renegotiation_info_xtn, extension)); + ClientHelloErrorTest(std::make_shared( + ssl_renegotiation_info_xtn, extension)); } TEST_P(TlsExtensionTestPre13, RenegotiationInfoMismatch) { const uint8_t val[] = {0x01, 0x00}; DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest( - new TlsExtensionReplacer(ssl_renegotiation_info_xtn, extension)); + ClientHelloErrorTest(std::make_shared( + ssl_renegotiation_info_xtn, extension)); } // The extension has to contain a length. TEST_P(TlsExtensionTestPre13, RenegotiationInfoExtensionEmpty) { DataBuffer extension; - ClientHelloErrorTest( - new TlsExtensionReplacer(ssl_renegotiation_info_xtn, extension)); + ClientHelloErrorTest(std::make_shared( + ssl_renegotiation_info_xtn, extension)); } // This only works on TLS 1.2, since it relies on static RSA; otherwise libssl @@ -550,8 +558,8 @@ TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) { const SSLSignatureScheme schemes[] = {ssl_sig_rsa_pss_sha512, ssl_sig_rsa_pss_sha384}; - TlsExtensionCapture* capture = - new TlsExtensionCapture(ssl_signature_algorithms_xtn); + auto capture = + std::make_shared(ssl_signature_algorithms_xtn); client_->SetSignatureSchemes(schemes, PR_ARRAY_SIZE(schemes)); client_->SetPacketFilter(capture); EnableOnlyStaticRsaCiphers(); @@ -571,8 +579,9 @@ TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) { // Temporary test to verify that we choke on an empty ClientKeyShare. // This test will fail when we implement HelloRetryRequest. TEST_P(TlsExtensionTest13, EmptyClientKeyShare) { - ClientHelloErrorTest(new TlsExtensionTruncator(ssl_tls13_key_share_xtn, 2), - kTlsAlertHandshakeFailure); + ClientHelloErrorTest( + std::make_shared(ssl_tls13_key_share_xtn, 2), + kTlsAlertHandshakeFailure); } // These tests only work in stream mode because the client sends a @@ -581,7 +590,10 @@ TEST_P(TlsExtensionTest13, EmptyClientKeyShare) { // packet gets dropped. TEST_F(TlsExtensionTest13Stream, DropServerKeyShare) { EnsureTlsSetup(); - server_->SetPacketFilter(new TlsExtensionDropper(ssl_tls13_key_share_xtn)); + server_->SetPacketFilter( + std::make_shared(ssl_tls13_key_share_xtn)); + client_->ExpectSendAlert(kTlsAlertMissingExtension); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); EXPECT_EQ(SSL_ERROR_MISSING_KEY_SHARE, client_->error_code()); EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code()); @@ -600,7 +612,9 @@ TEST_F(TlsExtensionTest13Stream, WrongServerKeyShare) { DataBuffer buf(key_share, sizeof(key_share)); EnsureTlsSetup(); server_->SetPacketFilter( - new TlsExtensionReplacer(ssl_tls13_key_share_xtn, buf)); + std::make_shared(ssl_tls13_key_share_xtn, buf)); + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); EXPECT_EQ(SSL_ERROR_RX_MALFORMED_KEY_SHARE, client_->error_code()); EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code()); @@ -620,7 +634,9 @@ TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) { DataBuffer buf(key_share, sizeof(key_share)); EnsureTlsSetup(); server_->SetPacketFilter( - new TlsExtensionReplacer(ssl_tls13_key_share_xtn, buf)); + std::make_shared(ssl_tls13_key_share_xtn, buf)); + client_->ExpectSendAlert(kTlsAlertMissingExtension); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); EXPECT_EQ(SSL_ERROR_MISSING_KEY_SHARE, client_->error_code()); EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code()); @@ -629,8 +645,10 @@ TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) { TEST_F(TlsExtensionTest13Stream, AddServerSignatureAlgorithmsOnResumption) { SetupForResume(); DataBuffer empty; - server_->SetPacketFilter( - new TlsExtensionInjector(ssl_signature_algorithms_xtn, empty)); + server_->SetPacketFilter(std::make_shared( + ssl_signature_algorithms_xtn, empty)); + client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); EXPECT_EQ(SSL_ERROR_EXTENSION_DISALLOWED_FOR_VERSION, client_->error_code()); EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code()); @@ -763,9 +781,9 @@ class TlsPreSharedKeyReplacer : public TlsExtensionFilter { TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) { SetupForResume(); - client_->SetPacketFilter(new TlsPreSharedKeyReplacer([]( + client_->SetPacketFilter(std::make_shared([]( TlsPreSharedKeyReplacer* r) { r->identities_[0].identity.Truncate(0); })); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); } @@ -775,10 +793,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) { SetupForResume(); client_->SetPacketFilter( - new TlsPreSharedKeyReplacer([](TlsPreSharedKeyReplacer* r) { + std::make_shared([](TlsPreSharedKeyReplacer* r) { r->binders_[0].Write(0, r->binders_[0].data()[0] ^ 0xff, 1); })); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); } @@ -788,10 +806,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) { SetupForResume(); client_->SetPacketFilter( - new TlsPreSharedKeyReplacer([](TlsPreSharedKeyReplacer* r) { + std::make_shared([](TlsPreSharedKeyReplacer* r) { r->binders_[0].Write(r->binders_[0].len(), 0xff, 1); })); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); } @@ -800,9 +818,9 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) { TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) { SetupForResume(); - client_->SetPacketFilter(new TlsPreSharedKeyReplacer( + client_->SetPacketFilter(std::make_shared( [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Truncate(31); })); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); } @@ -813,11 +831,11 @@ TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) { SetupForResume(); client_->SetPacketFilter( - new TlsPreSharedKeyReplacer([](TlsPreSharedKeyReplacer* r) { + std::make_shared([](TlsPreSharedKeyReplacer* r) { r->identities_.push_back(r->identities_[0]); r->binders_.push_back(r->binders_[0]); })); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); } @@ -828,10 +846,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) { SetupForResume(); client_->SetPacketFilter( - new TlsPreSharedKeyReplacer([](TlsPreSharedKeyReplacer* r) { + std::make_shared([](TlsPreSharedKeyReplacer* r) { r->identities_.push_back(r->identities_[0]); })); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); } @@ -839,9 +857,9 @@ TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) { TEST_F(TlsExtensionTest13Stream, ResumeOneIdentityTwoBinders) { SetupForResume(); - client_->SetPacketFilter(new TlsPreSharedKeyReplacer([]( + client_->SetPacketFilter(std::make_shared([]( TlsPreSharedKeyReplacer* r) { r->binders_.push_back(r->binders_[0]); })); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); } @@ -851,10 +869,10 @@ TEST_F(TlsExtensionTest13Stream, ResumePskExtensionNotLast) { const uint8_t empty_buf[] = {0}; DataBuffer empty(empty_buf, 0); - client_->SetPacketFilter( - // Inject an unused extension. - new TlsExtensionAppender(0xffff, empty)); - ConnectExpectFail(); + // Inject an unused extension after the PSK extension. + client_->SetPacketFilter(std::make_shared( + kTlsHandshakeClientHello, 0xffff, empty)); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); } @@ -863,9 +881,9 @@ TEST_F(TlsExtensionTest13Stream, ResumeNoKeModes) { SetupForResume(); DataBuffer empty; - client_->SetPacketFilter( - new TlsExtensionDropper(ssl_tls13_psk_key_exchange_modes_xtn)); - ConnectExpectFail(); + client_->SetPacketFilter(std::make_shared( + ssl_tls13_psk_key_exchange_modes_xtn)); + ConnectExpectAlert(server_, kTlsAlertMissingExtension); client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT); server_->CheckErrorCode(SSL_ERROR_MISSING_PSK_KEY_EXCHANGE_MODES); } @@ -879,8 +897,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) { kTls13PskKe}; DataBuffer modes(ke_modes, sizeof(ke_modes)); - client_->SetPacketFilter( - new TlsExtensionReplacer(ssl_tls13_psk_key_exchange_modes_xtn, modes)); + client_->SetPacketFilter(std::make_shared( + ssl_tls13_psk_key_exchange_modes_xtn, modes)); + client_->ExpectSendAlert(kTlsAlertBadRecordMac); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); @@ -888,7 +908,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) { TEST_P(TlsExtensionTest13, NoKeModesIfResumptionOff) { ConfigureSessionCache(RESUME_NONE, RESUME_NONE); - auto capture = new TlsExtensionCapture(ssl_tls13_psk_key_exchange_modes_xtn); + auto capture = std::make_shared( + ssl_tls13_psk_key_exchange_modes_xtn); client_->SetPacketFilter(capture); Connect(); EXPECT_FALSE(capture->captured()); @@ -899,6 +920,7 @@ TEST_P(TlsExtensionTest13, NoKeModesIfResumptionOff) { // 1. Both sides only support TLS 1.3, so we get a cipher version // error. TEST_P(TlsExtensionTest13, RemoveTls13FromVersionList) { + ExpectAlert(server_, kTlsAlertProtocolVersion); ConnectWithReplacementVersionList(SSL_LIBRARY_VERSION_TLS_1_2); client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT); server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); @@ -909,6 +931,7 @@ TEST_P(TlsExtensionTest13, RemoveTls13FromVersionList) { TEST_P(TlsExtensionTest13, RemoveTls13FromVersionListServerV12) { server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_3); + ExpectAlert(server_, kTlsAlertHandshakeFailure); ConnectWithReplacementVersionList(SSL_LIBRARY_VERSION_TLS_1_2); client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); @@ -921,6 +944,11 @@ TEST_P(TlsExtensionTest13, RemoveTls13FromVersionListBothV12) { SSL_LIBRARY_VERSION_TLS_1_3); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_3); +#ifndef TLS_1_3_DRAFT_VERSION + ExpectAlert(server_, kTlsAlertIllegalParameter); +#else + ExpectAlert(server_, kTlsAlertDecryptError); +#endif ConnectWithReplacementVersionList(SSL_LIBRARY_VERSION_TLS_1_2); #ifndef TLS_1_3_DRAFT_VERSION client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); @@ -932,18 +960,21 @@ TEST_P(TlsExtensionTest13, RemoveTls13FromVersionListBothV12) { } TEST_P(TlsExtensionTest13, HrrThenRemoveSignatureAlgorithms) { + ExpectAlert(server_, kTlsAlertMissingExtension); HrrThenRemoveExtensionsTest(ssl_signature_algorithms_xtn, SSL_ERROR_MISSING_EXTENSION_ALERT, SSL_ERROR_MISSING_SIGNATURE_ALGORITHMS_EXTENSION); } TEST_P(TlsExtensionTest13, HrrThenRemoveKeyShare) { + ExpectAlert(server_, kTlsAlertIllegalParameter); HrrThenRemoveExtensionsTest(ssl_tls13_key_share_xtn, SSL_ERROR_ILLEGAL_PARAMETER_ALERT, SSL_ERROR_BAD_2ND_CLIENT_HELLO); } TEST_P(TlsExtensionTest13, HrrThenRemoveSupportedGroups) { + ExpectAlert(server_, kTlsAlertMissingExtension); HrrThenRemoveExtensionsTest(ssl_supported_groups_xtn, SSL_ERROR_MISSING_EXTENSION_ALERT, SSL_ERROR_MISSING_SUPPORTED_GROUPS_EXTENSION); @@ -959,27 +990,192 @@ TEST_P(TlsExtensionTest13, OddVersionList) { ConnectWithBogusVersionList(ext, sizeof(ext)); } -INSTANTIATE_TEST_CASE_P(ExtensionStream, TlsExtensionTestGeneric, - ::testing::Combine(TlsConnectTestBase::kTlsModesStream, - TlsConnectTestBase::kTlsVAll)); -INSTANTIATE_TEST_CASE_P(ExtensionDatagram, TlsExtensionTestGeneric, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, - TlsConnectTestBase::kTlsV11Plus)); +// TODO: this only tests extensions in server messages. The client can extend +// Certificate messages, which is not checked here. +class TlsBogusExtensionTest : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple> { + public: + TlsBogusExtensionTest() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} + + protected: + virtual void ConnectAndFail(uint8_t message) = 0; + + void AddFilter(uint8_t message, uint16_t extension) { + static uint8_t empty_buf[1] = {0}; + DataBuffer empty(empty_buf, 0); + auto filter = + std::make_shared(message, extension, empty); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + server_->SetTlsRecordFilter(filter); + filter->EnableDecryption(); + } else { + server_->SetPacketFilter(filter); + } + } + + void Run(uint8_t message, uint16_t extension = 0xff) { + EnsureTlsSetup(); + AddFilter(message, extension); + ConnectAndFail(message); + } +}; + +class TlsBogusExtensionTestPre13 : public TlsBogusExtensionTest { + protected: + void ConnectAndFail(uint8_t) override { + ConnectExpectAlert(client_, kTlsAlertUnsupportedExtension); + } +}; + +class TlsBogusExtensionTest13 : public TlsBogusExtensionTest { + protected: + void ConnectAndFail(uint8_t message) override { + if (message == kTlsHandshakeHelloRetryRequest) { + ConnectExpectAlert(client_, kTlsAlertUnsupportedExtension); + return; + } + + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); // ClientHello + server_->Handshake(); // ServerHello + + client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); + client_->Handshake(); + if (variant_ == ssl_variant_stream) { + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + } + server_->Handshake(); + } +}; + +TEST_P(TlsBogusExtensionTestPre13, AddBogusExtensionServerHello) { + Run(kTlsHandshakeServerHello); +} + +TEST_P(TlsBogusExtensionTest13, AddBogusExtensionServerHello) { + Run(kTlsHandshakeServerHello); +} + +TEST_P(TlsBogusExtensionTest13, AddBogusExtensionEncryptedExtensions) { + Run(kTlsHandshakeEncryptedExtensions); +} + +TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificate) { + Run(kTlsHandshakeCertificate); +} + +TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificateRequest) { + server_->RequestClientAuth(false); + Run(kTlsHandshakeCertificateRequest); +} + +TEST_P(TlsBogusExtensionTest13, AddBogusExtensionHelloRetryRequest) { + static const std::vector groups = {ssl_grp_ec_secp384r1}; + server_->ConfigNamedGroups(groups); + + Run(kTlsHandshakeHelloRetryRequest); +} + +TEST_P(TlsBogusExtensionTest13, AddVersionExtensionServerHello) { + Run(kTlsHandshakeServerHello, ssl_tls13_supported_versions_xtn); +} + +TEST_P(TlsBogusExtensionTest13, AddVersionExtensionEncryptedExtensions) { + Run(kTlsHandshakeEncryptedExtensions, ssl_tls13_supported_versions_xtn); +} + +TEST_P(TlsBogusExtensionTest13, AddVersionExtensionCertificate) { + Run(kTlsHandshakeCertificate, ssl_tls13_supported_versions_xtn); +} + +TEST_P(TlsBogusExtensionTest13, AddVersionExtensionCertificateRequest) { + server_->RequestClientAuth(false); + Run(kTlsHandshakeCertificateRequest, ssl_tls13_supported_versions_xtn); +} + +TEST_P(TlsBogusExtensionTest13, AddVersionExtensionHelloRetryRequest) { + static const std::vector groups = {ssl_grp_ec_secp384r1}; + server_->ConfigNamedGroups(groups); + + Run(kTlsHandshakeHelloRetryRequest, ssl_tls13_supported_versions_xtn); +} + +// NewSessionTicket allows unknown extensions AND it isn't protected by the +// Finished. So adding an unknown extension doesn't cause an error. +TEST_P(TlsBogusExtensionTest13, AddBogusExtensionNewSessionTicket) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + + AddFilter(kTlsHandshakeNewSessionTicket, 0xff); + Connect(); + SendReceive(); + CheckKeys(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectStream, IncludePadding) { + EnsureTlsSetup(); + + // This needs to be long enough to push a TLS 1.0 ClientHello over 255, but + // short enough not to push a TLS 1.3 ClientHello over 511. + static const char* long_name = + "chickenchickenchickenchickenchickenchickenchickenchicken." + "chickenchickenchickenchickenchickenchickenchickenchicken." + "chickenchickenchickenchickenchicken."; + SECStatus rv = SSL_SetURL(client_->ssl_fd(), long_name); + EXPECT_EQ(SECSuccess, rv); + + auto capture = std::make_shared(ssl_padding_xtn); + client_->SetPacketFilter(capture); + client_->StartConnect(); + client_->Handshake(); + EXPECT_TRUE(capture->captured()); +} + +INSTANTIATE_TEST_CASE_P( + ExtensionStream, TlsExtensionTestGeneric, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll)); +INSTANTIATE_TEST_CASE_P( + ExtensionDatagram, TlsExtensionTestGeneric, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11Plus)); INSTANTIATE_TEST_CASE_P(ExtensionDatagramOnly, TlsExtensionTestDtls, TlsConnectTestBase::kTlsV11Plus); INSTANTIATE_TEST_CASE_P(ExtensionTls12Plus, TlsExtensionTest12Plus, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12Plus)); -INSTANTIATE_TEST_CASE_P(ExtensionPre13Stream, TlsExtensionTestPre13, - ::testing::Combine(TlsConnectTestBase::kTlsModesStream, - TlsConnectTestBase::kTlsV10ToV12)); +INSTANTIATE_TEST_CASE_P( + ExtensionPre13Stream, TlsExtensionTestPre13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10ToV12)); INSTANTIATE_TEST_CASE_P(ExtensionPre13Datagram, TlsExtensionTestPre13, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV11V12)); INSTANTIATE_TEST_CASE_P(ExtensionTls13, TlsExtensionTest13, - TlsConnectTestBase::kTlsModesAll); - -} // namespace nspr_test + TlsConnectTestBase::kTlsVariantsAll); + +INSTANTIATE_TEST_CASE_P( + BogusExtensionStream, TlsBogusExtensionTestPre13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10ToV12)); +INSTANTIATE_TEST_CASE_P( + BogusExtensionDatagram, TlsBogusExtensionTestPre13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11V12)); + +INSTANTIATE_TEST_CASE_P(BogusExtension13, TlsBogusExtensionTest13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV13)); + +} // namespace nss_test diff --git a/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc b/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc new file mode 100644 index 0000000..44cacce --- /dev/null +++ b/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc @@ -0,0 +1,157 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +// This class cuts every unencrypted handshake record into two parts. +class RecordFragmenter : public PacketFilter { + public: + RecordFragmenter() : sequence_number_(0), splitting_(true) {} + + private: + class HandshakeSplitter { + public: + HandshakeSplitter(const DataBuffer& input, DataBuffer* output, + uint64_t* sequence_number) + : input_(input), + output_(output), + cursor_(0), + sequence_number_(sequence_number) {} + + private: + void WriteRecord(TlsRecordHeader& record_header, + DataBuffer& record_fragment) { + TlsRecordHeader fragment_header(record_header.version(), + record_header.content_type(), + *sequence_number_); + ++*sequence_number_; + if (::g_ssl_gtest_verbose) { + std::cerr << "Fragment: " << fragment_header << ' ' << record_fragment + << std::endl; + } + cursor_ = fragment_header.Write(output_, cursor_, record_fragment); + } + + bool SplitRecord(TlsRecordHeader& record_header, DataBuffer& record) { + TlsParser parser(record); + while (parser.remaining()) { + TlsHandshakeFilter::HandshakeHeader handshake_header; + DataBuffer handshake_body; + if (!handshake_header.Parse(&parser, record_header, &handshake_body)) { + ADD_FAILURE() << "couldn't parse handshake header"; + return false; + } + + DataBuffer record_fragment; + // We can't fragment handshake records that are too small. + if (handshake_body.len() < 2) { + handshake_header.Write(&record_fragment, 0U, handshake_body); + WriteRecord(record_header, record_fragment); + continue; + } + + size_t cut = handshake_body.len() / 2; + handshake_header.WriteFragment(&record_fragment, 0U, handshake_body, 0U, + cut); + WriteRecord(record_header, record_fragment); + + handshake_header.WriteFragment(&record_fragment, 0U, handshake_body, + cut, handshake_body.len() - cut); + WriteRecord(record_header, record_fragment); + } + return true; + } + + public: + bool Split() { + TlsParser parser(input_); + while (parser.remaining()) { + TlsRecordHeader header; + DataBuffer record; + if (!header.Parse(&parser, &record)) { + ADD_FAILURE() << "bad record header"; + return false; + } + + if (::g_ssl_gtest_verbose) { + std::cerr << "Record: " << header << ' ' << record << std::endl; + } + + // Don't touch packets from a non-zero epoch. Leave these unmodified. + if ((header.sequence_number() >> 48) != 0ULL) { + cursor_ = header.Write(output_, cursor_, record); + continue; + } + + // Just rewrite the sequence number (CCS only). + if (header.content_type() != kTlsHandshakeType) { + EXPECT_EQ(kTlsChangeCipherSpecType, header.content_type()); + WriteRecord(header, record); + continue; + } + + if (!SplitRecord(header, record)) { + return false; + } + } + return true; + } + + private: + const DataBuffer& input_; + DataBuffer* output_; + size_t cursor_; + uint64_t* sequence_number_; + }; + + protected: + virtual PacketFilter::Action Filter(const DataBuffer& input, + DataBuffer* output) override { + if (!splitting_) { + return KEEP; + } + + output->Allocate(input.len()); + HandshakeSplitter splitter(input, output, &sequence_number_); + if (!splitter.Split()) { + // If splitting fails, we obviously reached encrypted packets. + // Stop splitting from that point onward. + splitting_ = false; + return KEEP; + } + + return CHANGE; + } + + private: + uint64_t sequence_number_; + bool splitting_; +}; + +TEST_P(TlsConnectDatagram, FragmentClientPackets) { + client_->SetPacketFilter(std::make_shared()); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectDatagram, FragmentServerPackets) { + server_->SetPacketFilter(std::make_shared()); + Connect(); + SendReceive(); +} + +} // namespace nss_test diff --git a/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc b/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc index d144cd7..d08a0b6 100644 --- a/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc +++ b/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc @@ -12,6 +12,12 @@ namespace nss_test { #ifdef UNSAFE_FUZZER_MODE +#define FUZZ_F(c, f) TEST_F(c, Fuzz_##f) +#define FUZZ_P(c, f) TEST_P(c, Fuzz_##f) +#else +#define FUZZ_F(c, f) TEST_F(c, DISABLED_Fuzz_##f) +#define FUZZ_P(c, f) TEST_P(c, DISABLED_Fuzz_##f) +#endif const uint8_t kShortEmptyFinished[8] = {0}; const uint8_t kLongEmptyFinished[128] = {0}; @@ -23,7 +29,7 @@ class TlsApplicationDataRecorder : public TlsRecordFilter { public: TlsApplicationDataRecorder() : buffer_() {} - virtual PacketFilter::Action FilterRecord(const RecordHeader& header, + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, DataBuffer* output) { if (header.content_type() == kTlsApplicationDataType) { @@ -39,56 +45,28 @@ class TlsApplicationDataRecorder : public TlsRecordFilter { DataBuffer buffer_; }; -// Damages an SKE or CV signature. -class TlsSignatureDamager : public TlsHandshakeFilter { - public: - TlsSignatureDamager(uint8_t type) : type_(type) {} - virtual PacketFilter::Action FilterHandshake( - const TlsHandshakeFilter::HandshakeHeader& header, - const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() != type_) { - return KEEP; - } - - *output = input; - - // Modify the last byte of the signature. - output->data()[output->len() - 1]++; - return CHANGE; - } - - private: - uint8_t type_; -}; - -void ResetState() { - // Clear the list of RSA blinding params. - BL_Cleanup(); - - // Reinit the list of RSA blinding params. - EXPECT_EQ(SECSuccess, BL_Init()); - - // Reset the RNG state. - EXPECT_EQ(SECSuccess, RNG_ResetForFuzzing()); -} - // Ensure that ssl_Time() returns a constant value. -TEST_F(TlsFuzzTest, Fuzz_SSL_Time_Constant) { - PRInt32 now = ssl_Time(); +FUZZ_F(TlsFuzzTest, SSL_Time_Constant) { + PRUint32 now = ssl_Time(); PR_Sleep(PR_SecondsToInterval(2)); EXPECT_EQ(ssl_Time(), now); } // Check that due to the deterministic PRNG we derive // the same master secret in two consecutive TLS sessions. -TEST_P(TlsConnectGeneric, Fuzz_DeterministicExporter) { +FUZZ_P(TlsConnectGeneric, DeterministicExporter) { const char kLabel[] = "label"; std::vector out1(32), out2(32); + // Make sure we have RSA blinding params. + Connect(); + + Reset(); ConfigureSessionCache(RESUME_NONE, RESUME_NONE); DisableECDHEServerKeyReuse(); - ResetState(); + // Reset the RNG state. + EXPECT_EQ(SECSuccess, RNG_RandomUpdate(NULL, 0)); Connect(); // Export a key derived from the MS and nonces. @@ -101,7 +79,8 @@ TEST_P(TlsConnectGeneric, Fuzz_DeterministicExporter) { ConfigureSessionCache(RESUME_NONE, RESUME_NONE); DisableECDHEServerKeyReuse(); - ResetState(); + // Reset the RNG state. + EXPECT_EQ(SECSuccess, RNG_RandomUpdate(NULL, 0)); Connect(); // Export another key derived from the MS and nonces. @@ -115,7 +94,10 @@ TEST_P(TlsConnectGeneric, Fuzz_DeterministicExporter) { // Check that due to the deterministic RNG two consecutive // TLS sessions will have the exact same transcript. -TEST_P(TlsConnectGeneric, Fuzz_DeterministicTranscript) { +FUZZ_P(TlsConnectGeneric, DeterministicTranscript) { + // Make sure we have RSA blinding params. + Connect(); + // Connect a few times and compare the transcripts byte-by-byte. DataBuffer last; for (size_t i = 0; i < 5; i++) { @@ -124,15 +106,16 @@ TEST_P(TlsConnectGeneric, Fuzz_DeterministicTranscript) { DisableECDHEServerKeyReuse(); DataBuffer buffer; - client_->SetPacketFilter(new TlsConversationRecorder(buffer)); - server_->SetPacketFilter(new TlsConversationRecorder(buffer)); + client_->SetPacketFilter(std::make_shared(buffer)); + server_->SetPacketFilter(std::make_shared(buffer)); - ResetState(); + // Reset the RNG state. + EXPECT_EQ(SECSuccess, RNG_RandomUpdate(NULL, 0)); Connect(); // Ensure the filters go away before |buffer| does. - client_->SetPacketFilter(nullptr); - server_->SetPacketFilter(nullptr); + client_->DeletePacketFilter(); + server_->DeletePacketFilter(); if (last.len() > 0) { EXPECT_EQ(last, buffer); @@ -146,13 +129,13 @@ TEST_P(TlsConnectGeneric, Fuzz_DeterministicTranscript) { // with all supported TLS versions, STREAM and DGRAM. // Check that records are NOT encrypted. // Check that records don't have a MAC. -TEST_P(TlsConnectGeneric, Fuzz_ConnectSendReceive_NullCipher) { +FUZZ_P(TlsConnectGeneric, ConnectSendReceive_NullCipher) { EnsureTlsSetup(); // Set up app data filters. - auto client_recorder = new TlsApplicationDataRecorder(); + auto client_recorder = std::make_shared(); client_->SetPacketFilter(client_recorder); - auto server_recorder = new TlsApplicationDataRecorder(); + auto server_recorder = std::make_shared(); server_->SetPacketFilter(server_recorder); Connect(); @@ -175,10 +158,10 @@ TEST_P(TlsConnectGeneric, Fuzz_ConnectSendReceive_NullCipher) { } // Check that an invalid Finished message doesn't abort the connection. -TEST_P(TlsConnectGeneric, Fuzz_BogusClientFinished) { +FUZZ_P(TlsConnectGeneric, BogusClientFinished) { EnsureTlsSetup(); - auto i1 = new TlsInspectorReplaceHandshakeMessage( + auto i1 = std::make_shared( kTlsHandshakeFinished, DataBuffer(kShortEmptyFinished, sizeof(kShortEmptyFinished))); client_->SetPacketFilter(i1); @@ -187,10 +170,10 @@ TEST_P(TlsConnectGeneric, Fuzz_BogusClientFinished) { } // Check that an invalid Finished message doesn't abort the connection. -TEST_P(TlsConnectGeneric, Fuzz_BogusServerFinished) { +FUZZ_P(TlsConnectGeneric, BogusServerFinished) { EnsureTlsSetup(); - auto i1 = new TlsInspectorReplaceHandshakeMessage( + auto i1 = std::make_shared( kTlsHandshakeFinished, DataBuffer(kLongEmptyFinished, sizeof(kLongEmptyFinished))); server_->SetPacketFilter(i1); @@ -199,25 +182,120 @@ TEST_P(TlsConnectGeneric, Fuzz_BogusServerFinished) { } // Check that an invalid server auth signature doesn't abort the connection. -TEST_P(TlsConnectGeneric, Fuzz_BogusServerAuthSignature) { +FUZZ_P(TlsConnectGeneric, BogusServerAuthSignature) { EnsureTlsSetup(); uint8_t msg_type = version_ == SSL_LIBRARY_VERSION_TLS_1_3 ? kTlsHandshakeCertificateVerify : kTlsHandshakeServerKeyExchange; - server_->SetPacketFilter(new TlsSignatureDamager(msg_type)); + server_->SetPacketFilter(std::make_shared(msg_type)); Connect(); SendReceive(); } // Check that an invalid client auth signature doesn't abort the connection. -TEST_P(TlsConnectGeneric, Fuzz_BogusClientAuthSignature) { +FUZZ_P(TlsConnectGeneric, BogusClientAuthSignature) { EnsureTlsSetup(); client_->SetupClientAuth(); server_->RequestClientAuth(true); client_->SetPacketFilter( - new TlsSignatureDamager(kTlsHandshakeCertificateVerify)); + std::make_shared(kTlsHandshakeCertificateVerify)); Connect(); } -#endif +// Check that session ticket resumption works. +FUZZ_P(TlsConnectGeneric, SessionTicketResumption) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + Connect(); + SendReceive(); +} + +class TlsSessionTicketMacDamager : public TlsExtensionFilter { + public: + TlsSessionTicketMacDamager() {} + virtual PacketFilter::Action FilterExtension(uint16_t extension_type, + const DataBuffer& input, + DataBuffer* output) { + if (extension_type != ssl_session_ticket_xtn && + extension_type != ssl_tls13_pre_shared_key_xtn) { + return KEEP; + } + + *output = input; + + // Handle everything before TLS 1.3. + if (extension_type == ssl_session_ticket_xtn) { + // Modify the last byte of the MAC. + output->data()[output->len() - 1] ^= 0xff; + } + + // Handle TLS 1.3. + if (extension_type == ssl_tls13_pre_shared_key_xtn) { + TlsParser parser(input); + + uint32_t ids_len; + EXPECT_TRUE(parser.Read(&ids_len, 2) && ids_len > 0); + + uint32_t ticket_len; + EXPECT_TRUE(parser.Read(&ticket_len, 2) && ticket_len > 0); + + // Modify the last byte of the MAC. + output->data()[2 + 2 + ticket_len - 1] ^= 0xff; + } + + return CHANGE; + } +}; + +// Check that session ticket resumption works with a bad MAC. +FUZZ_P(TlsConnectGeneric, SessionTicketResumptionBadMac) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + + client_->SetPacketFilter(std::make_shared()); + Connect(); + SendReceive(); +} + +// Check that session tickets are not encrypted. +FUZZ_P(TlsConnectGeneric, UnencryptedSessionTickets) { + ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET); + + auto i1 = std::make_shared( + kTlsHandshakeNewSessionTicket); + server_->SetPacketFilter(i1); + Connect(); + + size_t offset = 4; /* lifetime */ + if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) { + offset += 1 + 1 + /* ke_modes */ + 1 + 1; /* auth_modes */ + } + + offset += 2 + /* ticket length */ + 16 + /* SESS_TICKET_KEY_NAME_LEN */ + 16 + /* AES-128 IV */ + 2 + /* ciphertext length */ + 2; /* TLS_EX_SESS_TICKET_VERSION */ + + // Check the protocol version number. + uint32_t tls_version = 0; + EXPECT_TRUE(i1->buffer().Read(offset, sizeof(version_), &tls_version)); + EXPECT_EQ(version_, static_cast(tls_version)); + + // Check the cipher suite. + uint32_t suite = 0; + EXPECT_TRUE(i1->buffer().Read(offset + sizeof(version_), 2, &suite)); + client_->CheckCipherSuite(static_cast(suite)); +} } diff --git a/nss/gtests/ssl_gtest/ssl_gather_unittest.cc b/nss/gtests/ssl_gtest/ssl_gather_unittest.cc new file mode 100644 index 0000000..f47b2f4 --- /dev/null +++ b/nss/gtests/ssl_gtest/ssl_gather_unittest.cc @@ -0,0 +1,143 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "gtest_utils.h" +#include "tls_connect.h" + +namespace nss_test { + +class GatherV2ClientHelloTest : public TlsConnectTestBase { + public: + GatherV2ClientHelloTest() : TlsConnectTestBase(ssl_variant_stream, 0) {} + + void ConnectExpectMalformedClientHello(const DataBuffer &data) { + EnsureTlsSetup(); + server_->ExpectSendAlert(kTlsAlertIllegalParameter); + client_->SendDirect(data); + server_->StartConnect(); + server_->Handshake(); + ASSERT_TRUE_WAIT( + (server_->error_code() == SSL_ERROR_RX_MALFORMED_CLIENT_HELLO), 2000); + } +}; + +// Gather a 5-byte v3 record, with a zero fragment length. The empty handshake +// message should be ignored, and the connection will succeed afterwards. +TEST_F(TlsConnectTest, GatherEmptyV3Record) { + DataBuffer buffer; + + size_t idx = 0; + idx = buffer.Write(idx, 0x16, 1); // handshake + idx = buffer.Write(idx, 0x0301, 2); // record_version + (void)buffer.Write(idx, 0U, 2); // length=0 + + EnsureTlsSetup(); + client_->SendDirect(buffer); + Connect(); +} + +// Gather a 5-byte v3 record, with a fragment length exceeding the maximum. +TEST_F(TlsConnectTest, GatherExcessiveV3Record) { + DataBuffer buffer; + + size_t idx = 0; + idx = buffer.Write(idx, 0x16, 1); // handshake + idx = buffer.Write(idx, 0x0301, 2); // record_version + (void)buffer.Write(idx, MAX_FRAGMENT_LENGTH + 2048 + 1, 2); // length=max+1 + + EnsureTlsSetup(); + server_->ExpectSendAlert(kTlsAlertRecordOverflow); + client_->SendDirect(buffer); + server_->StartConnect(); + server_->Handshake(); + ASSERT_TRUE_WAIT((server_->error_code() == SSL_ERROR_RX_RECORD_TOO_LONG), + 2000); +} + +// Gather a 3-byte v2 header, with a fragment length of 2. +TEST_F(GatherV2ClientHelloTest, GatherV2RecordLongHeader) { + DataBuffer buffer; + + size_t idx = 0; + idx = buffer.Write(idx, 0x0002, 2); // length=2 (long header) + idx = buffer.Write(idx, 0U, 1); // padding=0 + (void)buffer.Write(idx, 0U, 2); // data + + ConnectExpectMalformedClientHello(buffer); +} + +// Gather a 3-byte v2 header, with a fragment length of 1. +TEST_F(GatherV2ClientHelloTest, GatherV2RecordLongHeader2) { + DataBuffer buffer; + + size_t idx = 0; + idx = buffer.Write(idx, 0x0001, 2); // length=1 (long header) + idx = buffer.Write(idx, 0U, 1); // padding=0 + idx = buffer.Write(idx, 0U, 1); // data + (void)buffer.Write(idx, 0U, 1); // surplus (need 5 bytes total) + + ConnectExpectMalformedClientHello(buffer); +} + +// Gather a 3-byte v2 header, with a zero fragment length. +TEST_F(GatherV2ClientHelloTest, GatherEmptyV2RecordLongHeader) { + DataBuffer buffer; + + size_t idx = 0; + idx = buffer.Write(idx, 0U, 2); // length=0 (long header) + idx = buffer.Write(idx, 0U, 1); // padding=0 + (void)buffer.Write(idx, 0U, 2); // surplus (need 5 bytes total) + + ConnectExpectMalformedClientHello(buffer); +} + +// Gather a 2-byte v2 header, with a fragment length of 3. +TEST_F(GatherV2ClientHelloTest, GatherV2RecordShortHeader) { + DataBuffer buffer; + + size_t idx = 0; + idx = buffer.Write(idx, 0x8003, 2); // length=3 (short header) + (void)buffer.Write(idx, 0U, 3); // data + + ConnectExpectMalformedClientHello(buffer); +} + +// Gather a 2-byte v2 header, with a fragment length of 2. +TEST_F(GatherV2ClientHelloTest, GatherEmptyV2RecordShortHeader2) { + DataBuffer buffer; + + size_t idx = 0; + idx = buffer.Write(idx, 0x8002, 2); // length=2 (short header) + idx = buffer.Write(idx, 0U, 2); // data + (void)buffer.Write(idx, 0U, 1); // surplus (need 5 bytes total) + + ConnectExpectMalformedClientHello(buffer); +} + +// Gather a 2-byte v2 header, with a fragment length of 1. +TEST_F(GatherV2ClientHelloTest, GatherEmptyV2RecordShortHeader3) { + DataBuffer buffer; + + size_t idx = 0; + idx = buffer.Write(idx, 0x8001, 2); // length=1 (short header) + idx = buffer.Write(idx, 0U, 1); // data + (void)buffer.Write(idx, 0U, 2); // surplus (need 5 bytes total) + + ConnectExpectMalformedClientHello(buffer); +} + +// Gather a 2-byte v2 header, with a zero fragment length. +TEST_F(GatherV2ClientHelloTest, GatherEmptyV2RecordShortHeader) { + DataBuffer buffer; + + size_t idx = 0; + idx = buffer.Write(idx, 0x8000, 2); // length=0 (short header) + (void)buffer.Write(idx, 0U, 3); // surplus (need 5 bytes total) + + ConnectExpectMalformedClientHello(buffer); +} + +} // namespace nss_test diff --git a/nss/gtests/ssl_gtest/ssl_gtest.cc b/nss/gtests/ssl_gtest/ssl_gtest.cc index 2d08dd8..cd10076 100644 --- a/nss/gtests/ssl_gtest/ssl_gtest.cc +++ b/nss/gtests/ssl_gtest/ssl_gtest.cc @@ -31,12 +31,18 @@ int main(int argc, char** argv) { } } - NSS_Initialize(g_working_dir_path.c_str(), "", "", SECMOD_DB, - NSS_INIT_READONLY); - NSS_SetDomesticPolicy(); + if (NSS_Initialize(g_working_dir_path.c_str(), "", "", SECMOD_DB, + NSS_INIT_READONLY) != SECSuccess) { + return 1; + } + if (NSS_SetDomesticPolicy() != SECSuccess) { + return 1; + } int rv = RUN_ALL_TESTS(); - NSS_Shutdown(); + if (NSS_Shutdown() != SECSuccess) { + return 1; + } nss_test::Poller::Shutdown(); diff --git a/nss/gtests/ssl_gtest/ssl_gtest.gyp b/nss/gtests/ssl_gtest/ssl_gtest.gyp index e232a8b..f0b96d6 100644 --- a/nss/gtests/ssl_gtest/ssl_gtest.gyp +++ b/nss/gtests/ssl_gtest/ssl_gtest.gyp @@ -25,6 +25,8 @@ 'ssl_exporter_unittest.cc', 'ssl_extension_unittest.cc', 'ssl_fuzz_unittest.cc', + 'ssl_fragment_unittest.cc', + 'ssl_gather_unittest.cc', 'ssl_gtest.cc', 'ssl_hrr_unittest.cc', 'ssl_loopback_unittest.cc', @@ -34,37 +36,45 @@ 'ssl_staticrsa_unittest.cc', 'ssl_v2_client_hello_unittest.cc', 'ssl_version_unittest.cc', + 'ssl_versionpolicy_unittest.cc', 'test_io.cc', 'tls_agent.cc', 'tls_connect.cc', 'tls_filter.cc', 'tls_hkdf_unittest.cc', - 'tls_parser.cc' + 'tls_protect.cc' ], 'dependencies': [ '<(DEPTH)/exports.gyp:nss_exports', '<(DEPTH)/lib/util/util.gyp:nssutil3', - '<(DEPTH)/lib/sqlite/sqlite.gyp:sqlite3', '<(DEPTH)/gtests/google_test/google_test.gyp:gtest', - '<(DEPTH)/lib/softoken/softoken.gyp:softokn', '<(DEPTH)/lib/smime/smime.gyp:smime', '<(DEPTH)/lib/ssl/ssl.gyp:ssl', '<(DEPTH)/lib/nss/nss.gyp:nss_static', - '<(DEPTH)/cmd/lib/lib.gyp:sectool', '<(DEPTH)/lib/pkcs12/pkcs12.gyp:pkcs12', '<(DEPTH)/lib/pkcs7/pkcs7.gyp:pkcs7', '<(DEPTH)/lib/certhigh/certhigh.gyp:certhi', '<(DEPTH)/lib/cryptohi/cryptohi.gyp:cryptohi', - '<(DEPTH)/lib/pk11wrap/pk11wrap.gyp:pk11wrap', - '<(DEPTH)/lib/softoken/softoken.gyp:softokn', '<(DEPTH)/lib/certdb/certdb.gyp:certdb', '<(DEPTH)/lib/pki/pki.gyp:nsspki', '<(DEPTH)/lib/dev/dev.gyp:nssdev', '<(DEPTH)/lib/base/base.gyp:nssb', - '<(DEPTH)/lib/freebl/freebl.gyp:<(freebl_name)', - '<(DEPTH)/lib/zlib/zlib.gyp:nss_zlib' + '<(DEPTH)/lib/zlib/zlib.gyp:nss_zlib', + '<(DEPTH)/cpputil/cpputil.gyp:cpputil', ], 'conditions': [ + [ 'test_build==1', { + 'dependencies': [ + '<(DEPTH)/lib/pk11wrap/pk11wrap.gyp:pk11wrap_static', + ], + }, { + 'dependencies': [ + '<(DEPTH)/lib/sqlite/sqlite.gyp:sqlite3', + '<(DEPTH)/lib/pk11wrap/pk11wrap.gyp:pk11wrap', + '<(DEPTH)/lib/softoken/softoken.gyp:softokn', + '<(DEPTH)/lib/freebl/freebl.gyp:freebl', + ], + }], [ 'disable_dbm==0', { 'dependencies': [ '<(DEPTH)/lib/dbm/src/src.gyp:dbm', @@ -90,10 +100,11 @@ ], 'target_defaults': { 'include_dirs': [ - '../../gtests/google_test/gtest/include', - '../../gtests/common', '../../lib/ssl' ], + 'defines': [ + 'NSS_USE_STATIC_LIBS' + ], }, 'variables': { 'module': 'nss', diff --git a/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc b/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc index 5d670fa..39055f6 100644 --- a/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc +++ b/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc @@ -34,7 +34,8 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) { ExpectResumption(RESUME_TICKET); // Send first ClientHello and send 0-RTT data - auto capture_early_data = new TlsExtensionCapture(ssl_tls13_early_data_xtn); + auto capture_early_data = + std::make_shared(ssl_tls13_early_data_xtn); client_->SetPacketFilter(capture_early_data); client_->Handshake(); EXPECT_EQ(k0RttDataLen, PR_Write(client_->ssl_fd(), k0RttData, @@ -42,8 +43,8 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) { EXPECT_TRUE(capture_early_data->captured()); // Send the HelloRetryRequest - auto hrr_capture = - new TlsInspectorRecordHandshakeMessage(kTlsHandshakeHelloRetryRequest); + auto hrr_capture = std::make_shared( + kTlsHandshakeHelloRetryRequest); server_->SetPacketFilter(hrr_capture); server_->Handshake(); EXPECT_LT(0U, hrr_capture->buffer().len()); @@ -54,7 +55,8 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) { EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); // Make a new capture for the early data. - capture_early_data = new TlsExtensionCapture(ssl_tls13_early_data_xtn); + capture_early_data = + std::make_shared(ssl_tls13_early_data_xtn); client_->SetPacketFilter(capture_early_data); // Complete the handshake successfully @@ -65,6 +67,88 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) { EXPECT_FALSE(capture_early_data->captured()); } +// This filter only works for DTLS 1.3 where there is exactly one handshake +// packet. If the record is split into two packets, or there are multiple +// handshake packets, this will break. +class CorrectMessageSeqAfterHrrFilter : public TlsRecordFilter { + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& record, size_t* offset, + DataBuffer* output) { + if (filtered_packets() > 0 || header.content_type() != content_handshake) { + return KEEP; + } + + DataBuffer buffer(record); + TlsRecordHeader new_header = {header.version(), header.content_type(), + header.sequence_number() + 1}; + + // Correct message_seq. + buffer.Write(4, 1U, 2); + + *offset = new_header.Write(output, *offset, buffer); + return CHANGE; + } +}; + +TEST_P(TlsConnectTls13, SecondClientHelloRejectEarlyDataXtn) { + static const std::vector groups = {ssl_grp_ec_secp384r1, + ssl_grp_ec_secp521r1}; + + SetupForZeroRtt(); + ExpectResumption(RESUME_TICKET); + + client_->ConfigNamedGroups(groups); + server_->ConfigNamedGroups(groups); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + + // A new client that tries to resume with 0-RTT but doesn't send the + // correct key share(s). The server will respond with an HRR. + auto orig_client = + std::make_shared(client_->name(), TlsAgent::CLIENT, variant_); + client_.swap(orig_client); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + client_->ConfigureSessionCache(RESUME_BOTH); + client_->Set0RttEnabled(true); + client_->StartConnect(); + + // Swap in the new client. + client_->SetPeer(server_); + server_->SetPeer(client_); + + // Send the ClientHello. + client_->Handshake(); + // Process the CH, send an HRR. + server_->Handshake(); + + // Swap the client we created manually with the one that successfully + // received a PSK, and try to resume with 0-RTT. The client doesn't know + // about the HRR so it will send the early_data xtn as well as 0-RTT data. + client_.swap(orig_client); + orig_client.reset(); + + // Correct the DTLS message sequence number after an HRR. + if (variant_ == ssl_variant_datagram) { + client_->SetPacketFilter( + std::make_shared()); + } + + server_->SetPeer(client_); + client_->Handshake(); + + // Send 0-RTT data. + const char* k0RttData = "ABCDEF"; + const PRInt32 k0RttDataLen = static_cast(strlen(k0RttData)); + PRInt32 rv = PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); + EXPECT_EQ(k0RttDataLen, rv); + + ExpectAlert(server_, kTlsAlertUnsupportedExtension); + Handshake(); + client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_EXTENSION_ALERT); +} + class KeyShareReplayer : public TlsExtensionFilter { public: KeyShareReplayer() {} @@ -94,11 +178,11 @@ class KeyShareReplayer : public TlsExtensionFilter { // server should reject this. TEST_P(TlsConnectTls13, RetryWithSameKeyShare) { EnsureTlsSetup(); - client_->SetPacketFilter(new KeyShareReplayer()); + client_->SetPacketFilter(std::make_shared()); static const std::vector groups = {ssl_grp_ec_secp384r1, ssl_grp_ec_secp521r1}; server_->ConfigNamedGroups(groups); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); EXPECT_EQ(SSL_ERROR_BAD_2ND_CLIENT_HELLO, server_->error_code()); EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, client_->error_code()); } @@ -109,7 +193,7 @@ TEST_F(TlsConnectDatagram13, DropClientSecondFlightWithHelloRetry) { static const std::vector groups = {ssl_grp_ec_secp384r1, ssl_grp_ec_secp521r1}; server_->ConfigNamedGroups(groups); - server_->SetPacketFilter(new SelectiveDropFilter(0x2)); + server_->SetPacketFilter(std::make_shared(0x2)); Connect(); } @@ -169,16 +253,13 @@ TEST_F(TlsConnectTest, Select12AfterHelloRetryRequest) { // Here we replace the TLS server with one that does TLS 1.2 only. // This will happily send the client a TLS 1.2 ServerHello. - TlsAgent* replacement_server = - new TlsAgent(server_->name(), TlsAgent::SERVER, mode_); - delete server_; - server_ = replacement_server; - server_->Init(); + server_.reset(new TlsAgent(server_->name(), TlsAgent::SERVER, variant_)); client_->SetPeer(server_); server_->SetPeer(client_); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_2); server_->StartConnect(); + ExpectAlert(client_, kTlsAlertIllegalParameter); Handshake(); EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, server_->error_code()); EXPECT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code()); @@ -189,8 +270,6 @@ class HelloRetryRequestAgentTest : public TlsAgentTestClient { void SetUp() override { TlsAgentTestClient::SetUp(); EnsureInit(); - agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3, - SSL_LIBRARY_VERSION_TLS_1_3); agent_->StartConnect(); } @@ -232,6 +311,7 @@ TEST_P(HelloRetryRequestAgentTest, SendSecondHelloRetryRequest) { MakeGroupHrr(ssl_grp_ec_secp384r1, &hrr, 0); ProcessMessage(hrr, TlsAgent::STATE_CONNECTING); MakeGroupHrr(ssl_grp_ec_secp521r1, &hrr, 1); + ExpectAlert(kTlsAlertUnexpectedMessage); ProcessMessage(hrr, TlsAgent::STATE_ERROR, SSL_ERROR_RX_UNEXPECTED_HELLO_RETRY_REQUEST); } @@ -241,6 +321,7 @@ TEST_P(HelloRetryRequestAgentTest, SendSecondHelloRetryRequest) { TEST_P(HelloRetryRequestAgentTest, HandleBogusHelloRetryRequest) { DataBuffer hrr; MakeGroupHrr(ssl_grp_ec_curve25519, &hrr); + ExpectAlert(kTlsAlertIllegalParameter); ProcessMessage(hrr, TlsAgent::STATE_ERROR, SSL_ERROR_RX_MALFORMED_HELLO_RETRY_REQUEST); } @@ -248,6 +329,7 @@ TEST_P(HelloRetryRequestAgentTest, HandleBogusHelloRetryRequest) { TEST_P(HelloRetryRequestAgentTest, HandleNoopHelloRetryRequest) { DataBuffer hrr; MakeCannedHrr(nullptr, 0U, &hrr); + ExpectAlert(kTlsAlertDecodeError); ProcessMessage(hrr, TlsAgent::STATE_ERROR, SSL_ERROR_RX_MALFORMED_HELLO_RETRY_REQUEST); } @@ -265,7 +347,7 @@ TEST_P(HelloRetryRequestAgentTest, HandleHelloRetryRequestCookie) { 0x13}; DataBuffer hrr; MakeCannedHrr(canned_cookie_hrr, sizeof(canned_cookie_hrr), &hrr); - TlsExtensionCapture* capture = new TlsExtensionCapture(ssl_tls13_cookie_xtn); + auto capture = std::make_shared(ssl_tls13_cookie_xtn); agent_->SetPacketFilter(capture); ProcessMessage(hrr, TlsAgent::STATE_CONNECTING); const size_t cookie_pos = 2 + 2; // cookie_xtn, extension len @@ -275,10 +357,11 @@ TEST_P(HelloRetryRequestAgentTest, HandleHelloRetryRequestCookie) { } INSTANTIATE_TEST_CASE_P(HelloRetryRequestAgentTests, HelloRetryRequestAgentTest, - TlsConnectTestBase::kTlsModesAll); + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV13)); #ifndef NSS_DISABLE_TLS_1_3 INSTANTIATE_TEST_CASE_P(HelloRetryRequestKeyExchangeTests, TlsKeyExchange13, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV13)); #endif diff --git a/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc b/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc index 65c0ca1..fd05754 100644 --- a/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc +++ b/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc @@ -39,7 +39,7 @@ TEST_P(TlsConnectGeneric, ConnectEcdsa) { CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa); } -TEST_P(TlsConnectGenericPre13, CipherSuiteMismatch) { +TEST_P(TlsConnectGeneric, CipherSuiteMismatch) { EnsureTlsSetup(); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { client_->EnableSingleCipher(TLS_AES_128_GCM_SHA256); @@ -48,11 +48,97 @@ TEST_P(TlsConnectGenericPre13, CipherSuiteMismatch) { client_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA); server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA); } - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); } +class TlsAlertRecorder : public TlsRecordFilter { + public: + TlsAlertRecorder() : level_(255), description_(255) {} + + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& input, + DataBuffer* output) override { + if (level_ != 255) { // Already captured. + return KEEP; + } + if (header.content_type() != kTlsAlertType) { + return KEEP; + } + + std::cerr << "Alert: " << input << std::endl; + + TlsParser parser(input); + EXPECT_TRUE(parser.Read(&level_)); + EXPECT_TRUE(parser.Read(&description_)); + return KEEP; + } + + uint8_t level() const { return level_; } + uint8_t description() const { return description_; } + + private: + uint8_t level_; + uint8_t description_; +}; + +class HelloTruncator : public TlsHandshakeFilter { + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override { + if (header.handshake_type() != kTlsHandshakeClientHello && + header.handshake_type() != kTlsHandshakeServerHello) { + return KEEP; + } + output->Assign(input.data(), input.len() - 1); + return CHANGE; + } +}; + +// Verify that when NSS reports that an alert is sent, it is actually sent. +TEST_P(TlsConnectGeneric, CaptureAlertServer) { + client_->SetPacketFilter(std::make_shared()); + auto alert_recorder = std::make_shared(); + server_->SetPacketFilter(alert_recorder); + + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); + EXPECT_EQ(kTlsAlertIllegalParameter, alert_recorder->description()); +} + +TEST_P(TlsConnectGenericPre13, CaptureAlertClient) { + server_->SetPacketFilter(std::make_shared()); + auto alert_recorder = std::make_shared(); + client_->SetPacketFilter(alert_recorder); + + ConnectExpectAlert(client_, kTlsAlertDecodeError); + EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); + EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description()); +} + +// In TLS 1.3, the server can't read the client alert. +TEST_P(TlsConnectTls13, CaptureAlertClient) { + server_->SetPacketFilter(std::make_shared()); + auto alert_recorder = std::make_shared(); + client_->SetPacketFilter(alert_recorder); + + server_->StartConnect(); + client_->StartConnect(); + + client_->Handshake(); + client_->ExpectSendAlert(kTlsAlertDecodeError); + server_->Handshake(); + client_->Handshake(); + if (variant_ == ssl_variant_stream) { + // DTLS just drops the alert it can't decrypt. + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + } + server_->Handshake(); + EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); + EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description()); +} + TEST_P(TlsConnectGenericPre13, ConnectFalseStart) { client_->EnableFalseStart(); Connect(); @@ -141,7 +227,8 @@ TEST_P(TlsConnectGeneric, ConnectWithCompressionMaybe) { client_->EnableCompression(); server_->EnableCompression(); Connect(); - EXPECT_EQ(client_->version() < SSL_LIBRARY_VERSION_TLS_1_3 && mode_ != DGRAM, + EXPECT_EQ(client_->version() < SSL_LIBRARY_VERSION_TLS_1_3 && + variant_ != ssl_variant_datagram, client_->is_compressed()); SendReceive(); } @@ -161,16 +248,15 @@ TEST_P(TlsConnectDatagram, TestDtlsHolddownExpiry) { class TlsPreCCSHeaderInjector : public TlsRecordFilter { public: TlsPreCCSHeaderInjector() {} - virtual PacketFilter::Action FilterRecord(const RecordHeader& record_header, - const DataBuffer& input, - size_t* offset, - DataBuffer* output) override { + virtual PacketFilter::Action FilterRecord( + const TlsRecordHeader& record_header, const DataBuffer& input, + size_t* offset, DataBuffer* output) override { if (record_header.content_type() != kTlsChangeCipherSpecType) return KEEP; std::cerr << "Injecting Finished header before CCS\n"; const uint8_t hhdr[] = {kTlsHandshakeFinished, 0x00, 0x00, 0x0c}; DataBuffer hhdr_buf(hhdr, sizeof(hhdr)); - RecordHeader nhdr(record_header.version(), kTlsHandshakeType, 0); + TlsRecordHeader nhdr(record_header.version(), kTlsHandshakeType, 0); *offset = nhdr.Write(output, *offset, hhdr_buf); *offset = record_header.Write(output, *offset, input); return CHANGE; @@ -178,24 +264,28 @@ class TlsPreCCSHeaderInjector : public TlsRecordFilter { }; TEST_P(TlsConnectStreamPre13, ClientFinishedHeaderBeforeCCS) { - client_->SetPacketFilter(new TlsPreCCSHeaderInjector()); - ConnectExpectFail(); + client_->SetPacketFilter(std::make_shared()); + ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage); client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); } TEST_P(TlsConnectStreamPre13, ServerFinishedHeaderBeforeCCS) { - server_->SetPacketFilter(new TlsPreCCSHeaderInjector()); + server_->SetPacketFilter(std::make_shared()); client_->StartConnect(); server_->StartConnect(); + ExpectAlert(client_, kTlsAlertUnexpectedMessage); Handshake(); EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + server_->Handshake(); // Make sure alert is consumed. } TEST_P(TlsConnectTls13, UnknownAlert) { Connect(); + server_->ExpectSendAlert(0xff, kTlsAlertWarning); + client_->ExpectReceiveAlert(0xff, kTlsAlertWarning); SSLInt_SendAlert(server_->ssl_fd(), kTlsAlertWarning, 0xff); // Unknown value. client_->ExpectReadWriteError(); @@ -204,20 +294,14 @@ TEST_P(TlsConnectTls13, UnknownAlert) { TEST_P(TlsConnectTls13, AlertWrongLevel) { Connect(); + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage, kTlsAlertWarning); + client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage, kTlsAlertWarning); SSLInt_SendAlert(server_->ssl_fd(), kTlsAlertWarning, kTlsAlertUnexpectedMessage); client_->ExpectReadWriteError(); client_->WaitForErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT, 2000); } -TEST_F(TlsConnectStreamTls13, NegotiateShortHeaders) { - client_->SetShortHeadersEnabled(); - server_->SetShortHeadersEnabled(); - client_->ExpectShortHeaders(); - server_->ExpectShortHeaders(); - Connect(); -} - TEST_F(TlsConnectStreamTls13, Tls13FailedWriteSecondFlight) { EnsureTlsSetup(); client_->StartConnect(); @@ -229,12 +313,21 @@ TEST_F(TlsConnectStreamTls13, Tls13FailedWriteSecondFlight) { client_->CheckErrorCode(SSL_ERROR_SOCKET_WRITE_FAILURE); } -INSTANTIATE_TEST_CASE_P(GenericStream, TlsConnectGeneric, - ::testing::Combine(TlsConnectTestBase::kTlsModesStream, - TlsConnectTestBase::kTlsVAll)); +TEST_F(TlsConnectStreamTls13, NegotiateShortHeaders) { + client_->SetShortHeadersEnabled(); + server_->SetShortHeadersEnabled(); + client_->ExpectShortHeaders(); + server_->ExpectShortHeaders(); + Connect(); +} + +INSTANTIATE_TEST_CASE_P( + GenericStream, TlsConnectGeneric, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll)); INSTANTIATE_TEST_CASE_P( GenericDatagram, TlsConnectGeneric, - ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, TlsConnectTestBase::kTlsV11Plus)); INSTANTIATE_TEST_CASE_P(StreamOnly, TlsConnectStream, @@ -242,33 +335,35 @@ INSTANTIATE_TEST_CASE_P(StreamOnly, TlsConnectStream, INSTANTIATE_TEST_CASE_P(DatagramOnly, TlsConnectDatagram, TlsConnectTestBase::kTlsV11Plus); -INSTANTIATE_TEST_CASE_P(Pre12Stream, TlsConnectPre12, - ::testing::Combine(TlsConnectTestBase::kTlsModesStream, - TlsConnectTestBase::kTlsV10V11)); +INSTANTIATE_TEST_CASE_P( + Pre12Stream, TlsConnectPre12, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10V11)); INSTANTIATE_TEST_CASE_P( Pre12Datagram, TlsConnectPre12, - ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, TlsConnectTestBase::kTlsV11)); INSTANTIATE_TEST_CASE_P(Version12Only, TlsConnectTls12, - TlsConnectTestBase::kTlsModesAll); + TlsConnectTestBase::kTlsVariantsAll); #ifndef NSS_DISABLE_TLS_1_3 INSTANTIATE_TEST_CASE_P(Version13Only, TlsConnectTls13, - TlsConnectTestBase::kTlsModesAll); + TlsConnectTestBase::kTlsVariantsAll); #endif -INSTANTIATE_TEST_CASE_P(Pre13Stream, TlsConnectGenericPre13, - ::testing::Combine(TlsConnectTestBase::kTlsModesStream, - TlsConnectTestBase::kTlsV10ToV12)); +INSTANTIATE_TEST_CASE_P( + Pre13Stream, TlsConnectGenericPre13, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10ToV12)); INSTANTIATE_TEST_CASE_P( Pre13Datagram, TlsConnectGenericPre13, - ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, TlsConnectTestBase::kTlsV11V12)); INSTANTIATE_TEST_CASE_P(Pre13StreamOnly, TlsConnectStreamPre13, TlsConnectTestBase::kTlsV10ToV12); INSTANTIATE_TEST_CASE_P(Version12Plus, TlsConnectTls12Plus, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12Plus)); } // namespace nspr_test diff --git a/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc b/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc index cfe42cb..7b43870 100644 --- a/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc +++ b/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc @@ -21,6 +21,7 @@ extern "C" { #include "tls_connect.h" #include "tls_filter.h" #include "tls_parser.h" +#include "tls_protect.h" namespace nss_test { @@ -200,6 +201,63 @@ TEST_P(TlsConnectGeneric, ConnectResumeClientBothTicketServerTicketForget) { SendReceive(); } +TEST_P(TlsConnectGeneric, ConnectWithExpiredTicketAtClient) { + SSLInt_SetTicketLifetime(1); // one second + // This causes a ticket resumption. + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); + + WAIT_(false, 1000); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_NONE); + + // TLS 1.3 uses the pre-shared key extension instead. + SSLExtensionType xtn = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) + ? ssl_tls13_pre_shared_key_xtn + : ssl_session_ticket_xtn; + auto capture = std::make_shared(xtn); + client_->SetPacketFilter(capture); + Connect(); + + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + EXPECT_FALSE(capture->captured()); + } else { + EXPECT_TRUE(capture->captured()); + EXPECT_EQ(0U, capture->extension().len()); + } +} + +TEST_P(TlsConnectGeneric, ConnectWithExpiredTicketAtServer) { + SSLInt_SetTicketLifetime(1); // one second + // This causes a ticket resumption. + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_NONE); + + SSLExtensionType xtn = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) + ? ssl_tls13_pre_shared_key_xtn + : ssl_session_ticket_xtn; + auto capture = std::make_shared(xtn); + client_->SetPacketFilter(capture); + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + EXPECT_TRUE(capture->captured()); + EXPECT_LT(0U, capture->extension().len()); + + WAIT_(false, 1000); // Let the ticket expire on the server. + + Handshake(); + CheckConnected(); +} + // This callback switches out the "server" cert used on the server with // the "client" certificate, which should be the same type. static int32_t SwitchCertificates(TlsAgent* agent, const SECItem* srvNameArr, @@ -245,8 +303,8 @@ TEST_P(TlsConnectGeneric, ServerSNICertTypeSwitch) { // Prior to TLS 1.3, we were not fully ephemeral; though 1.3 fixes that TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceReuseKey) { - TlsInspectorRecordHandshakeMessage* i1 = - new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); + auto i1 = std::make_shared( + kTlsHandshakeServerKeyExchange); server_->SetPacketFilter(i1); Connect(); CheckKeys(); @@ -255,8 +313,8 @@ TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceReuseKey) { // Restart Reset(); - TlsInspectorRecordHandshakeMessage* i2 = - new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); + auto i2 = std::make_shared( + kTlsHandshakeServerKeyExchange); server_->SetPacketFilter(i2); ConfigureSessionCache(RESUME_NONE, RESUME_NONE); Connect(); @@ -277,8 +335,8 @@ TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceNewKey) { SECStatus rv = SSL_OptionSet(server_->ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); EXPECT_EQ(SECSuccess, rv); - TlsInspectorRecordHandshakeMessage* i1 = - new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); + auto i1 = std::make_shared( + kTlsHandshakeServerKeyExchange); server_->SetPacketFilter(i1); Connect(); CheckKeys(); @@ -290,8 +348,8 @@ TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceNewKey) { server_->EnsureTlsSetup(); rv = SSL_OptionSet(server_->ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); EXPECT_EQ(SECSuccess, rv); - TlsInspectorRecordHandshakeMessage* i2 = - new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); + auto i2 = std::make_shared( + kTlsHandshakeServerKeyExchange); server_->SetPacketFilter(i2); ConfigureSessionCache(RESUME_NONE, RESUME_NONE); Connect(); @@ -356,7 +414,7 @@ TEST_P(TlsConnectGeneric, TestResumeClientDifferentCipher) { } else { ticket_extension = ssl_session_ticket_xtn; } - auto ticket_capture = new TlsExtensionCapture(ticket_extension); + auto ticket_capture = std::make_shared(ticket_extension); client_->SetPacketFilter(ticket_capture); Connect(); CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); @@ -420,9 +478,15 @@ TEST_P(TlsConnectStream, TestResumptionOverrideCipher) { Reset(); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); - server_->SetPacketFilter( - new SelectedCipherSuiteReplacer(ChooseAnotherCipher(version_))); + server_->SetPacketFilter(std::make_shared( + ChooseAnotherCipher(version_))); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + } else { + ExpectAlert(client_, kTlsAlertHandshakeFailure); + } ConnectExpectFail(); client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { @@ -459,7 +523,7 @@ class SelectedVersionReplacer : public TlsHandshakeFilter { // lower version number on resumption. TEST_P(TlsConnectGenericPre13, TestResumptionOverrideVersion) { uint16_t override_version = 0; - if (mode_ == STREAM) { + if (variant_ == ssl_variant_stream) { switch (version_) { case SSL_LIBRARY_VERSION_TLS_1_0: return; // Skip the test. @@ -492,9 +556,10 @@ TEST_P(TlsConnectGenericPre13, TestResumptionOverrideVersion) { // Enable the lower version on the client. client_->SetVersionRange(version_ - 1, version_); server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA); - server_->SetPacketFilter(new SelectedVersionReplacer(override_version)); + server_->SetPacketFilter( + std::make_shared(override_version)); - ConnectExpectFail(); + ConnectExpectAlert(client_, kTlsAlertHandshakeFailure); client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT); } @@ -515,8 +580,7 @@ TEST_F(TlsConnectTest, TestTls13ResumptionTwice) { ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); ExpectResumption(RESUME_TICKET); - TlsExtensionCapture* c1 = - new TlsExtensionCapture(ssl_tls13_pre_shared_key_xtn); + auto c1 = std::make_shared(ssl_tls13_pre_shared_key_xtn); client_->SetPacketFilter(c1); Connect(); SendReceive(); @@ -533,8 +597,7 @@ TEST_F(TlsConnectTest, TestTls13ResumptionTwice) { ClearStats(); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); - TlsExtensionCapture* c2 = - new TlsExtensionCapture(ssl_tls13_pre_shared_key_xtn); + auto c2 = std::make_shared(ssl_tls13_pre_shared_key_xtn); client_->SetPacketFilter(c2); ExpectResumption(RESUME_TICKET); Connect(); @@ -579,4 +642,66 @@ TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNST) { SendReceive(); } +TEST_F(TlsConnectTest, TestTls13ResumptionDowngrade) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + + SendReceive(); // Need to read so that we absorb the session tickets. + CheckKeys(); + + // Try resuming the connection. This will fail resuming the 1.3 session + // from before, but will successfully establish a 1.2 connection. + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + Connect(); + + // Renegotiate to ensure we don't carryover any state + // from the 1.3 resumption attempt. + client_->SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_2); + client_->PrepareForRenegotiate(); + server_->StartRenegotiate(); + Handshake(); + + SendReceive(); + CheckKeys(); +} + +TEST_F(TlsConnectTest, TestTls13ResumptionForcedDowngrade) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + + SendReceive(); // Need to read so that we absorb the session tickets. + CheckKeys(); + + // Try resuming the connection. + Reset(); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + // Enable the lower version on the client. + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + + // Add filters that set downgrade SH.version to 1.2 and the cipher suite + // to one that works with 1.2, so that we don't run into early sanity checks. + // We will eventually fail the (sid.version == SH.version) check. + std::vector> filters; + filters.push_back(std::make_shared( + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256)); + filters.push_back( + std::make_shared(SSL_LIBRARY_VERSION_TLS_1_2)); + server_->SetPacketFilter(std::make_shared(filters)); + + client_->ExpectSendAlert(kTlsAlertDecodeError); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); // Server can't read + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); +} + } // namespace nss_test diff --git a/nss/gtests/ssl_gtest/ssl_skip_unittest.cc b/nss/gtests/ssl_gtest/ssl_skip_unittest.cc index 523a374..a130ef7 100644 --- a/nss/gtests/ssl_gtest/ssl_skip_unittest.cc +++ b/nss/gtests/ssl_gtest/ssl_skip_unittest.cc @@ -28,9 +28,9 @@ class TlsHandshakeSkipFilter : public TlsRecordFilter { protected: // Takes a record; if it is a handshake record, it removes the first handshake // message that is of handshake_type_ type. - virtual PacketFilter::Action FilterRecord(const RecordHeader& record_header, - const DataBuffer& input, - DataBuffer* output) { + virtual PacketFilter::Action FilterRecord( + const TlsRecordHeader& record_header, const DataBuffer& input, + DataBuffer* output) { if (record_header.content_type() != kTlsHandshakeType) { return KEEP; } @@ -78,81 +78,162 @@ class TlsHandshakeSkipFilter : public TlsRecordFilter { bool skipped_; }; -class TlsSkipTest - : public TlsConnectTestBase, - public ::testing::WithParamInterface> { +class TlsSkipTest : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple> { protected: TlsSkipTest() : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} - void ServerSkipTest(PacketFilter* filter, + void ServerSkipTest(std::shared_ptr filter, uint8_t alert = kTlsAlertUnexpectedMessage) { - auto alert_recorder = new TlsAlertRecorder(); - client_->SetPacketFilter(alert_recorder); - if (filter) { - server_->SetPacketFilter(filter); + server_->SetPacketFilter(filter); + ConnectExpectAlert(client_, alert); + } +}; + +class Tls13SkipTest : public TlsConnectTestBase, + public ::testing::WithParamInterface { + protected: + Tls13SkipTest() + : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} + + void ServerSkipTest(std::shared_ptr filter, int32_t error) { + EnsureTlsSetup(); + server_->SetTlsRecordFilter(filter); + filter->EnableDecryption(); + client_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + if (variant_ == ssl_variant_stream) { + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + ConnectExpectFail(); + } else { + ConnectExpectFailOneSide(TlsAgent::CLIENT); + } + client_->CheckErrorCode(error); + if (variant_ == ssl_variant_stream) { + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + } else { + ASSERT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); } - ConnectExpectFail(); - EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); - EXPECT_EQ(alert, alert_recorder->description()); + } + + void ClientSkipTest(std::shared_ptr filter, int32_t error) { + EnsureTlsSetup(); + client_->SetTlsRecordFilter(filter); + filter->EnableDecryption(); + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + ConnectExpectFailOneSide(TlsAgent::SERVER); + + server_->CheckErrorCode(error); + ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + + client_->Handshake(); // Make sure to consume the alert the server sends. } }; TEST_P(TlsSkipTest, SkipCertificateRsa) { EnableOnlyStaticRsaCiphers(); - ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); + ServerSkipTest( + std::make_shared(kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(TlsSkipTest, SkipCertificateDhe) { - ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); + ServerSkipTest( + std::make_shared(kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); } TEST_P(TlsSkipTest, SkipCertificateEcdhe) { - ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); + ServerSkipTest( + std::make_shared(kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); } TEST_P(TlsSkipTest, SkipCertificateEcdsa) { Reset(TlsAgent::kServerEcdsa256); - ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); + ServerSkipTest( + std::make_shared(kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); } TEST_P(TlsSkipTest, SkipServerKeyExchange) { - ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange)); + ServerSkipTest( + std::make_shared(kTlsHandshakeServerKeyExchange)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) { Reset(TlsAgent::kServerEcdsa256); - ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange)); + ServerSkipTest( + std::make_shared(kTlsHandshakeServerKeyExchange)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(TlsSkipTest, SkipCertAndKeyExch) { - auto chain = new ChainedPacketFilter(); - chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); - chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange)); + auto chain = std::make_shared(); + chain->Add( + std::make_shared(kTlsHandshakeCertificate)); + chain->Add( + std::make_shared(kTlsHandshakeServerKeyExchange)); ServerSkipTest(chain); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(TlsSkipTest, SkipCertAndKeyExchEcdsa) { Reset(TlsAgent::kServerEcdsa256); - auto chain = new ChainedPacketFilter(); - chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); - chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange)); + auto chain = std::make_shared(); + chain->Add( + std::make_shared(kTlsHandshakeCertificate)); + chain->Add( + std::make_shared(kTlsHandshakeServerKeyExchange)); ServerSkipTest(chain); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } -INSTANTIATE_TEST_CASE_P(SkipTls10, TlsSkipTest, - ::testing::Combine(TlsConnectTestBase::kTlsModesStream, - TlsConnectTestBase::kTlsV10)); +TEST_P(Tls13SkipTest, SkipEncryptedExtensions) { + ServerSkipTest(std::make_shared( + kTlsHandshakeEncryptedExtensions), + SSL_ERROR_RX_UNEXPECTED_CERTIFICATE); +} + +TEST_P(Tls13SkipTest, SkipServerCertificate) { + ServerSkipTest( + std::make_shared(kTlsHandshakeCertificate), + SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); +} + +TEST_P(Tls13SkipTest, SkipServerCertificateVerify) { + ServerSkipTest( + std::make_shared(kTlsHandshakeCertificateVerify), + SSL_ERROR_RX_UNEXPECTED_FINISHED); +} + +TEST_P(Tls13SkipTest, SkipClientCertificate) { + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage); + ClientSkipTest( + std::make_shared(kTlsHandshakeCertificate), + SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); +} + +TEST_P(Tls13SkipTest, SkipClientCertificateVerify) { + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage); + ClientSkipTest( + std::make_shared(kTlsHandshakeCertificateVerify), + SSL_ERROR_RX_UNEXPECTED_FINISHED); +} + +INSTANTIATE_TEST_CASE_P( + SkipTls10, TlsSkipTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10)); INSTANTIATE_TEST_CASE_P(SkipVariants, TlsSkipTest, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV11V12)); - +INSTANTIATE_TEST_CASE_P(Skip13Variants, Tls13SkipTest, + TlsConnectTestBase::kTlsVariantsAll); } // namespace nss_test diff --git a/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc b/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc index baf24ed..8db1f30 100644 --- a/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc +++ b/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc @@ -48,28 +48,20 @@ TEST_P(TlsConnectGenericPre13, ConnectStaticRSA) { // This test is stream so we can catch the bad_record_mac alert. TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusCKE) { EnableOnlyStaticRsaCiphers(); - TlsInspectorReplaceHandshakeMessage* i1 = - new TlsInspectorReplaceHandshakeMessage( - kTlsHandshakeClientKeyExchange, - DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange))); + auto i1 = std::make_shared( + kTlsHandshakeClientKeyExchange, + DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange))); client_->SetPacketFilter(i1); - auto alert_recorder = new TlsAlertRecorder(); - server_->SetPacketFilter(alert_recorder); - ConnectExpectFail(); - EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); - EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description()); + ConnectExpectAlert(server_, kTlsAlertBadRecordMac); } // Test that a PMS with a bogus version number is handled correctly. // This test is stream so we can catch the bad_record_mac alert. TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusPMSVersionDetect) { EnableOnlyStaticRsaCiphers(); - client_->SetPacketFilter(new TlsInspectorClientHelloVersionChanger(server_)); - auto alert_recorder = new TlsAlertRecorder(); - server_->SetPacketFilter(alert_recorder); - ConnectExpectFail(); - EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); - EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description()); + client_->SetPacketFilter( + std::make_shared(server_)); + ConnectExpectAlert(server_, kTlsAlertBadRecordMac); } // Test that a PMS with a bogus version number is ignored when @@ -77,7 +69,8 @@ TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusPMSVersionDetect) { // ConnectStaticRSABogusPMSVersionDetect. TEST_P(TlsConnectGenericPre13, ConnectStaticRSABogusPMSVersionIgnore) { EnableOnlyStaticRsaCiphers(); - client_->SetPacketFilter(new TlsInspectorClientHelloVersionChanger(server_)); + client_->SetPacketFilter( + std::make_shared(server_)); server_->DisableRollbackDetection(); Connect(); } @@ -86,16 +79,11 @@ TEST_P(TlsConnectGenericPre13, ConnectStaticRSABogusPMSVersionIgnore) { TEST_P(TlsConnectStreamPre13, ConnectExtendedMasterSecretStaticRSABogusCKE) { EnableOnlyStaticRsaCiphers(); EnableExtendedMasterSecret(); - TlsInspectorReplaceHandshakeMessage* inspect = - new TlsInspectorReplaceHandshakeMessage( - kTlsHandshakeClientKeyExchange, - DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange))); + auto inspect = std::make_shared( + kTlsHandshakeClientKeyExchange, + DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange))); client_->SetPacketFilter(inspect); - auto alert_recorder = new TlsAlertRecorder(); - server_->SetPacketFilter(alert_recorder); - ConnectExpectFail(); - EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); - EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description()); + ConnectExpectAlert(server_, kTlsAlertBadRecordMac); } // This test is stream so we can catch the bad_record_mac alert. @@ -103,19 +91,17 @@ TEST_P(TlsConnectStreamPre13, ConnectExtendedMasterSecretStaticRSABogusPMSVersionDetect) { EnableOnlyStaticRsaCiphers(); EnableExtendedMasterSecret(); - client_->SetPacketFilter(new TlsInspectorClientHelloVersionChanger(server_)); - auto alert_recorder = new TlsAlertRecorder(); - server_->SetPacketFilter(alert_recorder); - ConnectExpectFail(); - EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); - EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description()); + client_->SetPacketFilter( + std::make_shared(server_)); + ConnectExpectAlert(server_, kTlsAlertBadRecordMac); } TEST_P(TlsConnectStreamPre13, ConnectExtendedMasterSecretStaticRSABogusPMSVersionIgnore) { EnableOnlyStaticRsaCiphers(); EnableExtendedMasterSecret(); - client_->SetPacketFilter(new TlsInspectorClientHelloVersionChanger(server_)); + client_->SetPacketFilter( + std::make_shared(server_)); server_->DisableRollbackDetection(); Connect(); } diff --git a/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc b/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc index 8b586be..110e3e0 100644 --- a/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc +++ b/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc @@ -23,7 +23,7 @@ namespace nss_test { // Replaces the client hello with an SSLv2 version once. class SSLv2ClientHelloFilter : public PacketFilter { public: - SSLv2ClientHelloFilter(TlsAgent* client, uint16_t version) + SSLv2ClientHelloFilter(std::shared_ptr& client, uint16_t version) : replaced_(false), client_(client), version_(version), @@ -121,7 +121,7 @@ class SSLv2ClientHelloFilter : public PacketFilter { // Update the client random so that the handshake succeeds. SECStatus rv = SSLInt_UpdateSSLv2ClientRandom( - client_->ssl_fd(), challenge.data(), challenge.size(), + client_.lock()->ssl_fd(), challenge.data(), challenge.size(), output->data() + hdr_len, output->len() - hdr_len); EXPECT_EQ(SECSuccess, rv); @@ -130,7 +130,7 @@ class SSLv2ClientHelloFilter : public PacketFilter { private: bool replaced_; - TlsAgent* client_; + std::weak_ptr client_; uint16_t version_; uint8_t pad_len_; uint8_t reported_pad_len_; @@ -141,14 +141,15 @@ class SSLv2ClientHelloFilter : public PacketFilter { class SSLv2ClientHelloTestF : public TlsConnectTestBase { public: - SSLv2ClientHelloTestF() : TlsConnectTestBase(STREAM, 0), filter_(nullptr) {} + SSLv2ClientHelloTestF() + : TlsConnectTestBase(ssl_variant_stream, 0), filter_(nullptr) {} - SSLv2ClientHelloTestF(Mode mode, uint16_t version) - : TlsConnectTestBase(mode, version), filter_(nullptr) {} + SSLv2ClientHelloTestF(SSLProtocolVariant variant, uint16_t version) + : TlsConnectTestBase(variant, version), filter_(nullptr) {} void SetUp() { TlsConnectTestBase::SetUp(); - filter_ = new SSLv2ClientHelloFilter(client_, version_); + filter_ = std::make_shared(client_, version_); client_->SetPacketFilter(filter_); } @@ -185,7 +186,7 @@ class SSLv2ClientHelloTestF : public TlsConnectTestBase { void SetSendEscape(bool send_escape) { filter_->SetSendEscape(send_escape); } private: - SSLv2ClientHelloFilter* filter_; + std::shared_ptr filter_; }; // Parameterized version of SSLv2ClientHelloTestF we can @@ -193,7 +194,8 @@ class SSLv2ClientHelloTestF : public TlsConnectTestBase { class SSLv2ClientHelloTest : public SSLv2ClientHelloTestF, public ::testing::WithParamInterface { public: - SSLv2ClientHelloTest() : SSLv2ClientHelloTestF(STREAM, GetParam()) {} + SSLv2ClientHelloTest() + : SSLv2ClientHelloTestF(ssl_variant_stream, GetParam()) {} }; // Test negotiating TLS 1.0 - 1.2. @@ -202,6 +204,28 @@ TEST_P(SSLv2ClientHelloTest, Connect) { Connect(); } +// Sending a v2 ClientHello after a no-op v3 record must fail. +TEST_P(SSLv2ClientHelloTest, ConnectAfterEmptyV3Record) { + DataBuffer buffer; + + size_t idx = 0; + idx = buffer.Write(idx, 0x16, 1); // handshake + idx = buffer.Write(idx, 0x0301, 2); // record_version + (void)buffer.Write(idx, 0U, 2); // length=0 + + SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + EnsureTlsSetup(); + client_->SendDirect(buffer); + + // Need padding so the connection doesn't just time out. With a v2 + // ClientHello parsed as a v3 record we will use the record version + // as the record length. + SetPadding(255); + + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + EXPECT_EQ(SSL_ERROR_BAD_CLIENT, server_->error_code()); +} + // Test negotiating TLS 1.3. TEST_F(SSLv2ClientHelloTestF, Connect13) { EnsureTlsSetup(); @@ -211,7 +235,7 @@ TEST_F(SSLv2ClientHelloTestF, Connect13) { std::vector cipher_suites = {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}; SetAvailableCipherSuites(cipher_suites); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code()); } @@ -238,7 +262,7 @@ TEST_P(SSLv2ClientHelloTest, SendSecurityEscape) { // Set a big padding so that the server fails instead of timing out. SetPadding(255); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); } // Invalid SSLv2 client hello padding must fail the handshake. @@ -248,7 +272,7 @@ TEST_P(SSLv2ClientHelloTest, AddErroneousPadding) { // Append 5 bytes of padding but say it's only 4. SetPadding(5, 4); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code()); } @@ -259,7 +283,7 @@ TEST_P(SSLv2ClientHelloTest, AddErroneousPadding2) { // Append 5 bytes of padding but say it's 6. SetPadding(5, 6); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code()); } @@ -270,7 +294,7 @@ TEST_P(SSLv2ClientHelloTest, SmallClientRandom) { // Send a ClientRandom that's too small. SetClientRandomLength(15); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code()); } @@ -288,7 +312,7 @@ TEST_P(SSLv2ClientHelloTest, BigClientRandom) { // Send a ClientRandom that's too big. SetClientRandomLength(33); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code()); } @@ -297,7 +321,7 @@ TEST_P(SSLv2ClientHelloTest, BigClientRandom) { TEST_P(SSLv2ClientHelloTest, RequireSafeRenegotiation) { RequireSafeRenegotiation(); SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); EXPECT_EQ(SSL_ERROR_UNSAFE_NEGOTIATION, server_->error_code()); } @@ -339,7 +363,7 @@ TEST_F(SSLv2ClientHelloTestF, InappropriateFallbackSCSV) { TLS_FALLBACK_SCSV}; SetAvailableCipherSuites(cipher_suites); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertInappropriateFallback); EXPECT_EQ(SSL_ERROR_INAPPROPRIATE_FALLBACK_ALERT, server_->error_code()); } diff --git a/nss/gtests/ssl_gtest/ssl_version_unittest.cc b/nss/gtests/ssl_gtest/ssl_version_unittest.cc index b353849..379a67e 100644 --- a/nss/gtests/ssl_gtest/ssl_version_unittest.cc +++ b/nss/gtests/ssl_gtest/ssl_version_unittest.cc @@ -57,7 +57,8 @@ TEST_P(TlsConnectGeneric, ServerNegotiateTls12) { // SSL_SetDowngradeCheckVersion() API. TEST_F(TlsConnectTest, TestDowngradeDetectionToTls11) { client_->SetPacketFilter( - new TlsInspectorClientHelloVersionSetter(SSL_LIBRARY_VERSION_TLS_1_1)); + std::make_shared( + SSL_LIBRARY_VERSION_TLS_1_1)); ConnectExpectFail(); ASSERT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code()); } @@ -65,7 +66,8 @@ TEST_F(TlsConnectTest, TestDowngradeDetectionToTls11) { /* Attempt to negotiate the bogus DTLS 1.1 version. */ TEST_F(DtlsConnectTest, TestDtlsVersion11) { client_->SetPacketFilter( - new TlsInspectorClientHelloVersionSetter(((~0x0101) & 0xffff))); + std::make_shared( + ((~0x0101) & 0xffff))); ConnectExpectFail(); // It's kind of surprising that SSL_ERROR_NO_CYPHER_OVERLAP is // what is returned here, but this is deliberate in ssl3_HandleAlert(). @@ -77,7 +79,8 @@ TEST_F(DtlsConnectTest, TestDtlsVersion11) { TEST_F(TlsConnectTest, TestDowngradeDetectionToTls12) { EnsureTlsSetup(); client_->SetPacketFilter( - new TlsInspectorClientHelloVersionSetter(SSL_LIBRARY_VERSION_TLS_1_2)); + std::make_shared( + SSL_LIBRARY_VERSION_TLS_1_2)); client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_3); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, @@ -90,7 +93,8 @@ TEST_F(TlsConnectTest, TestDowngradeDetectionToTls12) { // instead get a handshake failure alert from the server. TEST_F(TlsConnectTest, TestDowngradeDetectionToTls10) { client_->SetPacketFilter( - new TlsInspectorClientHelloVersionSetter(SSL_LIBRARY_VERSION_TLS_1_0)); + std::make_shared( + SSL_LIBRARY_VERSION_TLS_1_0)); client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, SSL_LIBRARY_VERSION_TLS_1_1); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, @@ -123,6 +127,18 @@ TEST_F(TlsConnectTest, TestFallbackFromTls13) { } #endif +TEST_P(TlsConnectGeneric, TestFallbackSCSVVersionMatch) { + client_->SetFallbackSCSVEnabled(true); + Connect(); +} + +TEST_P(TlsConnectGenericPre13, TestFallbackSCSVVersionMismatch) { + client_->SetFallbackSCSVEnabled(true); + server_->SetVersionRange(version_, version_ + 1); + ConnectExpectAlert(server_, kTlsAlertInappropriateFallback); + client_->CheckErrorCode(SSL_ERROR_INAPPROPRIATE_FALLBACK_ALERT); +} + // The TLS v1.3 spec section C.4 states that 'Implementations MUST NOT send or // accept any records with a version less than { 3, 0 }'. Thus we will not // allow version ranges including both SSL v3 and TLS v1.3. @@ -161,6 +177,13 @@ TEST_P(TlsConnectStream, ConnectTls10AndServerRenegotiateHigher) { // doesn't fail. server_->ResetPreliminaryInfo(); server_->StartRenegotiate(); + + if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { + ExpectAlert(server_, kTlsAlertUnexpectedMessage); + } else { + ExpectAlert(client_, kTlsAlertIllegalParameter); + } + Handshake(); if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { // In TLS 1.3, the server detects this problem. @@ -194,6 +217,11 @@ TEST_P(TlsConnectStream, ConnectTls10AndClientRenegotiateHigher) { // doesn't fail. server_->ResetPreliminaryInfo(); client_->StartRenegotiate(); + if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { + ExpectAlert(server_, kTlsAlertUnexpectedMessage); + } else { + ExpectAlert(client_, kTlsAlertIllegalParameter); + } Handshake(); if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { // In TLS 1.3, the server detects this problem. @@ -225,13 +253,14 @@ TEST_F(TlsConnectTest, Tls13RejectsRehandshakeServer) { TEST_P(TlsConnectGeneric, AlertBeforeServerHello) { EnsureTlsSetup(); + client_->ExpectReceiveAlert(kTlsAlertUnrecognizedName, kTlsAlertWarning); client_->StartConnect(); server_->StartConnect(); client_->Handshake(); // Send ClientHello. static const uint8_t kWarningAlert[] = {kTlsAlertWarning, kTlsAlertUnrecognizedName}; DataBuffer alert; - TlsAgentTestBase::MakeRecord(mode_, kTlsAlertType, + TlsAgentTestBase::MakeRecord(variant_, kTlsAlertType, SSL_LIBRARY_VERSION_TLS_1_0, kWarningAlert, PR_ARRAY_SIZE(kWarningAlert), &alert); client_->adapter()->PacketReceived(alert); @@ -246,11 +275,12 @@ class Tls13NoSupportedVersions : public TlsConnectStreamTls12 { SSL_LIBRARY_VERSION_TLS_1_2); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, max_server_version); client_->SetPacketFilter( - new TlsInspectorClientHelloVersionSetter(overwritten_client_version)); - auto capture = - new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerHello); + std::make_shared( + overwritten_client_version)); + auto capture = std::make_shared( + kTlsHandshakeServerHello); server_->SetPacketFilter(capture); - ConnectExpectFail(); + ConnectExpectAlert(server_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); const DataBuffer& server_hello = capture->buffer(); @@ -281,11 +311,14 @@ TEST_F(Tls13NoSupportedVersions, // Offer 1.3 but with ClientHello.legacy_version == TLS 1.4. This // causes a bad MAC error when we read EncryptedExtensions. TEST_F(TlsConnectStreamTls13, Tls14ClientHelloWithSupportedVersions) { - client_->SetPacketFilter(new TlsInspectorClientHelloVersionSetter( - SSL_LIBRARY_VERSION_TLS_1_3 + 1)); - auto capture = - new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerHello); + client_->SetPacketFilter( + std::make_shared( + SSL_LIBRARY_VERSION_TLS_1_3 + 1)); + auto capture = std::make_shared( + kTlsHandshakeServerHello); server_->SetPacketFilter(capture); + client_->ExpectSendAlert(kTlsAlertBadRecordMac); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); diff --git a/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc b/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc new file mode 100644 index 0000000..eda9683 --- /dev/null +++ b/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc @@ -0,0 +1,394 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "nss.h" +#include "secerr.h" +#include "ssl.h" +#include "ssl3prot.h" +#include "sslerr.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +#include + +namespace nss_test { + +std::string GetSSLVersionString(uint16_t v) { + switch (v) { + case SSL_LIBRARY_VERSION_3_0: + return "ssl3"; + case SSL_LIBRARY_VERSION_TLS_1_0: + return "tls1.0"; + case SSL_LIBRARY_VERSION_TLS_1_1: + return "tls1.1"; + case SSL_LIBRARY_VERSION_TLS_1_2: + return "tls1.2"; + case SSL_LIBRARY_VERSION_TLS_1_3: + return "tls1.3"; + case SSL_LIBRARY_VERSION_NONE: + return "NONE"; + } + if (v < SSL_LIBRARY_VERSION_3_0) { + return "undefined-too-low"; + } + return "undefined-too-high"; +} + +inline std::ostream& operator<<(std::ostream& stream, + const SSLVersionRange& vr) { + return stream << GetSSLVersionString(vr.min) << "," + << GetSSLVersionString(vr.max); +} + +class VersionRangeWithLabel { + public: + VersionRangeWithLabel(const std::string& label, const SSLVersionRange& vr) + : label_(label), vr_(vr) {} + VersionRangeWithLabel(const std::string& label, uint16_t min, uint16_t max) + : label_(label) { + vr_.min = min; + vr_.max = max; + } + VersionRangeWithLabel(const std::string& label) : label_(label) { + vr_.min = vr_.max = SSL_LIBRARY_VERSION_NONE; + } + + void WriteStream(std::ostream& stream) const { + stream << " " << label_ << ": " << vr_; + } + + uint16_t min() const { return vr_.min; } + uint16_t max() const { return vr_.max; } + SSLVersionRange range() const { return vr_; } + + private: + std::string label_; + SSLVersionRange vr_; +}; + +inline std::ostream& operator<<(std::ostream& stream, + const VersionRangeWithLabel& vrwl) { + vrwl.WriteStream(stream); + return stream; +} + +typedef std::tuple // input max + PolicyVersionRangeInput; + +class TestPolicyVersionRange + : public TlsConnectTestBase, + public ::testing::WithParamInterface { + public: + TestPolicyVersionRange() + : TlsConnectTestBase(std::get<0>(GetParam()), 0), + variant_(std::get<0>(GetParam())), + policy_("policy", std::get<1>(GetParam()), std::get<2>(GetParam())), + input_("input", std::get<3>(GetParam()), std::get<4>(GetParam())), + library_("supported-by-library", + ((variant_ == ssl_variant_stream) + ? SSL_LIBRARY_VERSION_MIN_SUPPORTED_STREAM + : SSL_LIBRARY_VERSION_MIN_SUPPORTED_DATAGRAM), + SSL_LIBRARY_VERSION_MAX_SUPPORTED) { + TlsConnectTestBase::SkipVersionChecks(); + } + + void SetPolicy(const SSLVersionRange& policy) { + NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, 0); + + SECStatus rv; + rv = NSS_OptionSet(NSS_TLS_VERSION_MIN_POLICY, policy.min); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_OptionSet(NSS_TLS_VERSION_MAX_POLICY, policy.max); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_OptionSet(NSS_DTLS_VERSION_MIN_POLICY, policy.min); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_OptionSet(NSS_DTLS_VERSION_MAX_POLICY, policy.max); + ASSERT_EQ(SECSuccess, rv); + } + + void CreateDummySocket(std::shared_ptr* dummy_socket, + ScopedPRFileDesc* ssl_fd) { + (*dummy_socket).reset(new DummyPrSocket("dummy", variant_)); + *ssl_fd = (*dummy_socket)->CreateFD(); + if (variant_ == ssl_variant_stream) { + SSL_ImportFD(nullptr, ssl_fd->get()); + } else { + DTLS_ImportFD(nullptr, ssl_fd->get()); + } + } + + bool GetOverlap(const SSLVersionRange& r1, const SSLVersionRange& r2, + SSLVersionRange* overlap) { + if (r1.min == SSL_LIBRARY_VERSION_NONE || + r1.max == SSL_LIBRARY_VERSION_NONE || + r2.min == SSL_LIBRARY_VERSION_NONE || + r2.max == SSL_LIBRARY_VERSION_NONE) { + return false; + } + + SSLVersionRange temp; + temp.min = PR_MAX(r1.min, r2.min); + temp.max = PR_MIN(r1.max, r2.max); + + if (temp.min > temp.max) { + return false; + } + + *overlap = temp; + return true; + } + + bool IsValidInputForVersionRangeSet(SSLVersionRange* expectedEffectiveRange) { + if (input_.min() <= SSL_LIBRARY_VERSION_3_0 && + input_.max() >= SSL_LIBRARY_VERSION_TLS_1_3) { + // This is always invalid input, independent of policy + return false; + } + + if (input_.min() < library_.min() || input_.max() > library_.max() || + input_.min() > input_.max()) { + // Asking for unsupported ranges is invalid input for VersionRangeSet + // APIs, regardless of overlap. + return false; + } + + SSLVersionRange overlap_with_library; + if (!GetOverlap(input_.range(), library_.range(), &overlap_with_library)) { + return false; + } + + SSLVersionRange overlap_with_library_and_policy; + if (!GetOverlap(overlap_with_library, policy_.range(), + &overlap_with_library_and_policy)) { + return false; + } + + RemoveConflictingVersions(variant_, &overlap_with_library_and_policy); + *expectedEffectiveRange = overlap_with_library_and_policy; + return true; + } + + void RemoveConflictingVersions(SSLProtocolVariant variant, + SSLVersionRange* r) { + ASSERT_TRUE(r != nullptr); + if (r->max >= SSL_LIBRARY_VERSION_TLS_1_3 && + r->min < SSL_LIBRARY_VERSION_TLS_1_0) { + r->min = SSL_LIBRARY_VERSION_TLS_1_0; + } + } + + void SetUp() { + SetPolicy(policy_.range()); + TlsConnectTestBase::SetUp(); + } + + void TearDown() { + TlsConnectTestBase::TearDown(); + saved_version_policy_.RestoreOriginalPolicy(); + } + + protected: + class VersionPolicy { + public: + VersionPolicy() { SaveOriginalPolicy(); } + + void RestoreOriginalPolicy() { + SECStatus rv; + rv = NSS_OptionSet(NSS_TLS_VERSION_MIN_POLICY, saved_min_tls_); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_OptionSet(NSS_TLS_VERSION_MAX_POLICY, saved_max_tls_); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_OptionSet(NSS_DTLS_VERSION_MIN_POLICY, saved_min_dtls_); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_OptionSet(NSS_DTLS_VERSION_MAX_POLICY, saved_max_dtls_); + ASSERT_EQ(SECSuccess, rv); + // If it wasn't set initially, clear the bit that we set. + if (!(saved_algorithm_policy_ & NSS_USE_POLICY_IN_SSL)) { + rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, 0, + NSS_USE_POLICY_IN_SSL); + ASSERT_EQ(SECSuccess, rv); + } + } + + private: + void SaveOriginalPolicy() { + SECStatus rv; + rv = NSS_OptionGet(NSS_TLS_VERSION_MIN_POLICY, &saved_min_tls_); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_OptionGet(NSS_TLS_VERSION_MAX_POLICY, &saved_max_tls_); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_OptionGet(NSS_DTLS_VERSION_MIN_POLICY, &saved_min_dtls_); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_OptionGet(NSS_DTLS_VERSION_MAX_POLICY, &saved_max_dtls_); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_GetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, + &saved_algorithm_policy_); + ASSERT_EQ(SECSuccess, rv); + } + + int32_t saved_min_tls_; + int32_t saved_max_tls_; + int32_t saved_min_dtls_; + int32_t saved_max_dtls_; + uint32_t saved_algorithm_policy_; + }; + + VersionPolicy saved_version_policy_; + + SSLProtocolVariant variant_; + const VersionRangeWithLabel policy_; + const VersionRangeWithLabel input_; + const VersionRangeWithLabel library_; +}; + +static const uint16_t kExpandedVersionsArr[] = { + /* clang-format off */ + SSL_LIBRARY_VERSION_3_0 - 1, + SSL_LIBRARY_VERSION_3_0, + SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_2, +#ifndef NSS_DISABLE_TLS_1_3 + SSL_LIBRARY_VERSION_TLS_1_3, +#endif + SSL_LIBRARY_VERSION_MAX_SUPPORTED + 1 + /* clang-format on */ +}; +static ::testing::internal::ParamGenerator kExpandedVersions = + ::testing::ValuesIn(kExpandedVersionsArr); + +TEST_P(TestPolicyVersionRange, TestAllTLSVersionsAndPolicyCombinations) { + ASSERT_TRUE(variant_ == ssl_variant_stream || + variant_ == ssl_variant_datagram) + << "testing unsupported ssl variant"; + + std::cerr << "testing: " << variant_ << policy_ << input_ << library_ + << std::endl; + + SSLVersionRange supported_range; + SECStatus rv = SSL_VersionRangeGetSupported(variant_, &supported_range); + VersionRangeWithLabel supported("SSL_VersionRangeGetSupported", + supported_range); + + std::cerr << supported << std::endl; + + std::shared_ptr dummy_socket; + ScopedPRFileDesc ssl_fd; + CreateDummySocket(&dummy_socket, &ssl_fd); + + SECStatus rv_socket; + SSLVersionRange overlap_policy_and_lib; + if (!GetOverlap(policy_.range(), library_.range(), &overlap_policy_and_lib)) { + EXPECT_EQ(SECFailure, rv) + << "expected SSL_VersionRangeGetSupported to fail with invalid policy"; + + SSLVersionRange enabled_range; + rv = SSL_VersionRangeGetDefault(variant_, &enabled_range); + EXPECT_EQ(SECFailure, rv) + << "expected SSL_VersionRangeGetDefault to fail with invalid policy"; + + SSLVersionRange enabled_range_on_socket; + rv_socket = SSL_VersionRangeGet(ssl_fd.get(), &enabled_range_on_socket); + EXPECT_EQ(SECFailure, rv_socket) + << "expected SSL_VersionRangeGet to fail with invalid policy"; + + ConnectExpectFail(); + return; + } + + EXPECT_EQ(SECSuccess, rv) + << "expected SSL_VersionRangeGetSupported to succeed with valid policy"; + + EXPECT_TRUE(supported_range.min != SSL_LIBRARY_VERSION_NONE && + supported_range.max != SSL_LIBRARY_VERSION_NONE) + << "expected SSL_VersionRangeGetSupported to return real values with " + "valid policy"; + + RemoveConflictingVersions(variant_, &overlap_policy_and_lib); + VersionRangeWithLabel overlap_info("overlap", overlap_policy_and_lib); + + EXPECT_TRUE(supported_range == overlap_policy_and_lib) + << "expected range from GetSupported to be identical with calculated " + "overlap " + << overlap_info; + + // We don't know which versions are "enabled by default" by the library, + // therefore we don't know if there's overlap between the default + // and the policy, and therefore, we don't if TLS connections should + // be successful or fail in this combination. + // Therefore we don't test if we can connect, without having configured a + // version range explicitly. + + // Now start testing with supplied input. + + SSLVersionRange expected_effective_range; + bool is_valid_input = + IsValidInputForVersionRangeSet(&expected_effective_range); + + SSLVersionRange temp_input = input_.range(); + rv = SSL_VersionRangeSetDefault(variant_, &temp_input); + rv_socket = SSL_VersionRangeSet(ssl_fd.get(), &temp_input); + + if (!is_valid_input) { + EXPECT_EQ(SECFailure, rv) + << "expected failure return from SSL_VersionRangeSetDefault"; + + EXPECT_EQ(SECFailure, rv_socket) + << "expected failure return from SSL_VersionRangeSet"; + return; + } + + EXPECT_EQ(SECSuccess, rv) + << "expected successful return from SSL_VersionRangeSetDefault"; + + EXPECT_EQ(SECSuccess, rv_socket) + << "expected successful return from SSL_VersionRangeSet"; + + SSLVersionRange effective; + SSLVersionRange effective_socket; + + rv = SSL_VersionRangeGetDefault(variant_, &effective); + EXPECT_EQ(SECSuccess, rv) + << "expected successful return from SSL_VersionRangeGetDefault"; + + rv_socket = SSL_VersionRangeGet(ssl_fd.get(), &effective_socket); + EXPECT_EQ(SECSuccess, rv_socket) + << "expected successful return from SSL_VersionRangeGet"; + + VersionRangeWithLabel expected_info("expectation", expected_effective_range); + VersionRangeWithLabel effective_info("effectively-enabled", effective); + + EXPECT_TRUE(expected_effective_range == effective) + << "range returned by SSL_VersionRangeGetDefault doesn't match " + "expectation: " + << expected_info << effective_info; + + EXPECT_TRUE(expected_effective_range == effective_socket) + << "range returned by SSL_VersionRangeGet doesn't match " + "expectation: " + << expected_info << effective_info; + + // Because we found overlap between policy and supported versions, + // and because we have used SetDefault to enable at least one version, + // it should be possible to execute an SSL/TLS connection. + Connect(); +} + +INSTANTIATE_TEST_CASE_P(TLSVersionRanges, TestPolicyVersionRange, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + kExpandedVersions, kExpandedVersions, + kExpandedVersions, + kExpandedVersions)); +} // namespace nss_test diff --git a/nss/gtests/ssl_gtest/test_io.cc b/nss/gtests/ssl_gtest/test_io.cc index f3fd0b2..b9f0c67 100644 --- a/nss/gtests/ssl_gtest/test_io.cc +++ b/nss/gtests/ssl_gtest/test_io.cc @@ -15,314 +15,33 @@ #include "prlog.h" #include "prthread.h" -#include "databuffer.h" - extern bool g_ssl_gtest_verbose; namespace nss_test { -static PRDescIdentity test_fd_identity = PR_INVALID_IO_LAYER; - -#define UNIMPLEMENTED() \ - std::cerr << "Call to unimplemented function " << __FUNCTION__ << std::endl; \ - PR_ASSERT(PR_FALSE); \ - PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0) - #define LOG(a) std::cerr << name_ << ": " << a << std::endl #define LOGV(a) \ do { \ if (g_ssl_gtest_verbose) LOG(a); \ } while (false) -class Packet : public DataBuffer { - public: - Packet(const DataBuffer &buf) : DataBuffer(buf), offset_(0) {} - - void Advance(size_t delta) { - PR_ASSERT(offset_ + delta <= len()); - offset_ = std::min(len(), offset_ + delta); - } - - size_t offset() const { return offset_; } - size_t remaining() const { return len() - offset_; } - - private: - size_t offset_; -}; - -// Implementation of NSPR methods -static PRStatus DummyClose(PRFileDesc *f) { - DummyPrSocket *io = reinterpret_cast(f->secret); - f->secret = nullptr; - f->dtor(f); - delete io; - return PR_SUCCESS; -} - -static int32_t DummyRead(PRFileDesc *f, void *buf, int32_t length) { - DummyPrSocket *io = reinterpret_cast(f->secret); - return io->Read(buf, length); -} - -static int32_t DummyWrite(PRFileDesc *f, const void *buf, int32_t length) { - DummyPrSocket *io = reinterpret_cast(f->secret); - return io->Write(buf, length); -} - -static int32_t DummyAvailable(PRFileDesc *f) { - UNIMPLEMENTED(); - return -1; -} - -int64_t DummyAvailable64(PRFileDesc *f) { - UNIMPLEMENTED(); - return -1; -} - -static PRStatus DummySync(PRFileDesc *f) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static int32_t DummySeek(PRFileDesc *f, int32_t offset, PRSeekWhence how) { - UNIMPLEMENTED(); - return -1; -} - -static int64_t DummySeek64(PRFileDesc *f, int64_t offset, PRSeekWhence how) { - UNIMPLEMENTED(); - return -1; -} - -static PRStatus DummyFileInfo(PRFileDesc *f, PRFileInfo *info) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static PRStatus DummyFileInfo64(PRFileDesc *f, PRFileInfo64 *info) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static int32_t DummyWritev(PRFileDesc *f, const PRIOVec *iov, int32_t iov_size, - PRIntervalTime to) { - UNIMPLEMENTED(); - return -1; -} - -static PRStatus DummyConnect(PRFileDesc *f, const PRNetAddr *addr, - PRIntervalTime to) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static PRFileDesc *DummyAccept(PRFileDesc *sd, PRNetAddr *addr, - PRIntervalTime to) { - UNIMPLEMENTED(); - return nullptr; -} - -static PRStatus DummyBind(PRFileDesc *f, const PRNetAddr *addr) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static PRStatus DummyListen(PRFileDesc *f, int32_t depth) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static PRStatus DummyShutdown(PRFileDesc *f, int32_t how) { - DummyPrSocket *io = reinterpret_cast(f->secret); - io->Reset(); - return PR_SUCCESS; -} - -// This function does not support peek. -static int32_t DummyRecv(PRFileDesc *f, void *buf, int32_t buflen, - int32_t flags, PRIntervalTime to) { - PR_ASSERT(flags == 0); - if (flags != 0) { - PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0); - return -1; - } - - DummyPrSocket *io = reinterpret_cast(f->secret); - - if (io->mode() == DGRAM) { - return io->Recv(buf, buflen); - } else { - return io->Read(buf, buflen); - } -} - -// Note: this is always nonblocking and assumes a zero timeout. -static int32_t DummySend(PRFileDesc *f, const void *buf, int32_t amount, - int32_t flags, PRIntervalTime to) { - int32_t written = DummyWrite(f, buf, amount); - return written; -} - -static int32_t DummyRecvfrom(PRFileDesc *f, void *buf, int32_t amount, - int32_t flags, PRNetAddr *addr, - PRIntervalTime to) { - UNIMPLEMENTED(); - return -1; -} - -static int32_t DummySendto(PRFileDesc *f, const void *buf, int32_t amount, - int32_t flags, const PRNetAddr *addr, - PRIntervalTime to) { - UNIMPLEMENTED(); - return -1; -} - -static int16_t DummyPoll(PRFileDesc *f, int16_t in_flags, int16_t *out_flags) { - UNIMPLEMENTED(); - return -1; -} - -static int32_t DummyAcceptRead(PRFileDesc *sd, PRFileDesc **nd, - PRNetAddr **raddr, void *buf, int32_t amount, - PRIntervalTime t) { - UNIMPLEMENTED(); - return -1; -} - -static int32_t DummyTransmitFile(PRFileDesc *sd, PRFileDesc *f, - const void *headers, int32_t hlen, - PRTransmitFileFlags flags, PRIntervalTime t) { - UNIMPLEMENTED(); - return -1; -} - -static PRStatus DummyGetpeername(PRFileDesc *f, PRNetAddr *addr) { - // TODO: Modify to return unique names for each channel - // somehow, as opposed to always the same static address. The current - // implementation messes up the session cache, which is why it's off - // elsewhere - addr->inet.family = PR_AF_INET; - addr->inet.port = 0; - addr->inet.ip = 0; - - return PR_SUCCESS; -} - -static PRStatus DummyGetsockname(PRFileDesc *f, PRNetAddr *addr) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static PRStatus DummyGetsockoption(PRFileDesc *f, PRSocketOptionData *opt) { - switch (opt->option) { - case PR_SockOpt_Nonblocking: - opt->value.non_blocking = PR_TRUE; - return PR_SUCCESS; - default: - UNIMPLEMENTED(); - break; - } - - return PR_FAILURE; -} - -// Imitate setting socket options. These are mostly noops. -static PRStatus DummySetsockoption(PRFileDesc *f, - const PRSocketOptionData *opt) { - switch (opt->option) { - case PR_SockOpt_Nonblocking: - return PR_SUCCESS; - case PR_SockOpt_NoDelay: - return PR_SUCCESS; - default: - UNIMPLEMENTED(); - break; - } - - return PR_FAILURE; -} - -static int32_t DummySendfile(PRFileDesc *out, PRSendFileData *in, - PRTransmitFileFlags flags, PRIntervalTime to) { - UNIMPLEMENTED(); - return -1; -} - -static PRStatus DummyConnectContinue(PRFileDesc *f, int16_t flags) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static int32_t DummyReserved(PRFileDesc *f) { - UNIMPLEMENTED(); - return -1; -} - -DummyPrSocket::~DummyPrSocket() { Reset(); } - -void DummyPrSocket::SetPacketFilter(PacketFilter *filter) { - if (filter_) { - delete filter_; - } +void DummyPrSocket::SetPacketFilter(std::shared_ptr filter) { filter_ = filter; } -void DummyPrSocket::Reset() { - delete filter_; - if (peer_) { - peer_->SetPeer(nullptr); - peer_ = nullptr; - } - while (!input_.empty()) { - Packet *front = input_.front(); - input_.pop(); - delete front; - } -} - -static const struct PRIOMethods DummyMethods = { - PR_DESC_LAYERED, DummyClose, - DummyRead, DummyWrite, - DummyAvailable, DummyAvailable64, - DummySync, DummySeek, - DummySeek64, DummyFileInfo, - DummyFileInfo64, DummyWritev, - DummyConnect, DummyAccept, - DummyBind, DummyListen, - DummyShutdown, DummyRecv, - DummySend, DummyRecvfrom, - DummySendto, DummyPoll, - DummyAcceptRead, DummyTransmitFile, - DummyGetsockname, DummyGetpeername, - DummyReserved, DummyReserved, - DummyGetsockoption, DummySetsockoption, - DummySendfile, DummyConnectContinue, - DummyReserved, DummyReserved, - DummyReserved, DummyReserved}; - -PRFileDesc *DummyPrSocket::CreateFD(const std::string &name, Mode mode) { - if (test_fd_identity == PR_INVALID_IO_LAYER) { - test_fd_identity = PR_GetUniqueIdentity("testtransportadapter"); - } - - PRFileDesc *fd = (PR_CreateIOLayerStub(test_fd_identity, &DummyMethods)); - fd->secret = reinterpret_cast(new DummyPrSocket(name, mode)); - - return fd; -} - -DummyPrSocket *DummyPrSocket::GetAdapter(PRFileDesc *fd) { - return reinterpret_cast(fd->secret); +ScopedPRFileDesc DummyPrSocket::CreateFD() { + static PRDescIdentity test_fd_identity = + PR_GetUniqueIdentity("testtransportadapter"); + return DummyIOLayerMethods::CreateFD(test_fd_identity, this); } void DummyPrSocket::PacketReceived(const DataBuffer &packet) { - input_.push(new Packet(packet)); + input_.push(Packet(packet)); } -int32_t DummyPrSocket::Read(void *data, int32_t len) { - PR_ASSERT(mode_ == STREAM); - - if (mode_ != STREAM) { +int32_t DummyPrSocket::Read(PRFileDesc *f, void *data, int32_t len) { + PR_ASSERT(variant_ == ssl_variant_stream); + if (variant_ != ssl_variant_stream) { PR_SetError(PR_INVALID_METHOD_ERROR, 0); return -1; } @@ -333,45 +52,54 @@ int32_t DummyPrSocket::Read(void *data, int32_t len) { return -1; } - Packet *front = input_.front(); + auto &front = input_.front(); size_t to_read = - std::min(static_cast(len), front->len() - front->offset()); - memcpy(data, static_cast(front->data() + front->offset()), + std::min(static_cast(len), front.len() - front.offset()); + memcpy(data, static_cast(front.data() + front.offset()), to_read); - front->Advance(to_read); + front.Advance(to_read); - if (!front->remaining()) { + if (!front.remaining()) { input_.pop(); - delete front; } return static_cast(to_read); } -int32_t DummyPrSocket::Recv(void *buf, int32_t buflen) { +int32_t DummyPrSocket::Recv(PRFileDesc *f, void *buf, int32_t buflen, + int32_t flags, PRIntervalTime to) { + PR_ASSERT(flags == 0); + if (flags != 0) { + PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0); + return -1; + } + + if (variant() != ssl_variant_datagram) { + return Read(f, buf, buflen); + } + if (input_.empty()) { PR_SetError(PR_WOULD_BLOCK_ERROR, 0); return -1; } - Packet *front = input_.front(); - if (static_cast(buflen) < front->len()) { + auto &front = input_.front(); + if (static_cast(buflen) < front.len()) { PR_ASSERT(false); PR_SetError(PR_BUFFER_OVERFLOW_ERROR, 0); return -1; } - size_t count = front->len(); - memcpy(buf, front->data(), count); + size_t count = front.len(); + memcpy(buf, front.data(), count); input_.pop(); - delete front; - return static_cast(count); } -int32_t DummyPrSocket::Write(const void *buf, int32_t length) { - if (!peer_ || !writeable_) { +int32_t DummyPrSocket::Write(PRFileDesc *f, const void *buf, int32_t length) { + auto peer = peer_.lock(); + if (!peer || !writeable_) { PR_SetError(PR_IO_ERROR, 0); return -1; } @@ -387,14 +115,14 @@ int32_t DummyPrSocket::Write(const void *buf, int32_t length) { case PacketFilter::CHANGE: LOG("Original packet: " << packet); LOG("Filtered packet: " << filtered); - peer_->PacketReceived(filtered); + peer->PacketReceived(filtered); break; case PacketFilter::DROP: LOG("Droppped packet: " << packet); break; case PacketFilter::KEEP: LOGV("Packet: " << packet); - peer_->PacketReceived(packet); + peer->PacketReceived(packet); break; } // libssl can't handle it if this reports something other than the length @@ -415,43 +143,31 @@ void Poller::Shutdown() { instance = nullptr; } -Poller::~Poller() { - while (!timers_.empty()) { - Timer *timer = timers_.top(); - timers_.pop(); - delete timer; - } -} +void Poller::Wait(Event event, std::shared_ptr &adapter, + PollTarget *target, PollCallback cb) { + assert(event < TIMER_EVENT); + if (event >= TIMER_EVENT) return; -void Poller::Wait(Event event, DummyPrSocket *adapter, PollTarget *target, - PollCallback cb) { + std::unique_ptr waiter; auto it = waiters_.find(adapter); - Waiter *waiter; - if (it == waiters_.end()) { - waiter = new Waiter(adapter); + waiter.reset(new Waiter(adapter)); } else { - waiter = it->second; + waiter = std::move(it->second); } - assert(event < TIMER_EVENT); - if (event >= TIMER_EVENT) return; - waiter->targets_[event] = target; waiter->callbacks_[event] = cb; - waiters_[adapter] = waiter; + waiters_[adapter] = std::move(waiter); } -void Poller::Cancel(Event event, DummyPrSocket *adapter) { +void Poller::Cancel(Event event, std::shared_ptr &adapter) { auto it = waiters_.find(adapter); - Waiter *waiter; - if (it == waiters_.end()) { return; } - waiter = it->second; - + auto &waiter = it->second; waiter->targets_[event] = nullptr; waiter->callbacks_[event] = nullptr; @@ -460,13 +176,12 @@ void Poller::Cancel(Event event, DummyPrSocket *adapter) { if (waiter->callbacks_[i]) return; } - delete waiter; waiters_.erase(adapter); } void Poller::SetTimer(uint32_t timer_ms, PollTarget *target, PollCallback cb, - Timer **timer) { - Timer *t = new Timer(PR_Now() + timer_ms * 1000, target, cb); + std::shared_ptr *timer) { + auto t = std::make_shared(PR_Now() + timer_ms * 1000, target, cb); timers_.push(t); if (timer) *timer = t; } @@ -482,7 +197,7 @@ bool Poller::Poll() { // Figure out the timer for the select. if (!timers_.empty()) { - Timer *first_timer = timers_.top(); + auto first_timer = timers_.top(); if (now >= first_timer->deadline_) { // Timer expired. timeout = PR_INTERVAL_NO_WAIT; @@ -493,7 +208,7 @@ bool Poller::Poll() { } for (auto it = waiters_.begin(); it != waiters_.end(); ++it) { - Waiter *waiter = it->second; + auto &waiter = it->second; if (waiter->callbacks_[READABLE_EVENT]) { if (waiter->io_->readable()) { @@ -522,12 +237,11 @@ bool Poller::Poll() { while (!timers_.empty()) { if (now < timers_.top()->deadline_) break; - Timer *timer = timers_.top(); + auto timer = timers_.top(); timers_.pop(); if (timer->callback_) { timer->callback_(timer->target_, TIMER_EVENT); } - delete timer; } return true; diff --git a/nss/gtests/ssl_gtest/test_io.h b/nss/gtests/ssl_gtest/test_io.h index b78db0d..ac24972 100644 --- a/nss/gtests/ssl_gtest/test_io.h +++ b/nss/gtests/ssl_gtest/test_io.h @@ -14,12 +14,15 @@ #include #include +#include "databuffer.h" +#include "dummy_io.h" #include "prio.h" +#include "scoped_ptrs.h" +#include "sslt.h" namespace nss_test { class DataBuffer; -class Packet; class DummyPrSocket; // Fwd decl. // Allow us to inspect a packet before it is written. @@ -42,49 +45,59 @@ class PacketFilter { virtual Action Filter(const DataBuffer& input, DataBuffer* output) = 0; }; -enum Mode { STREAM, DGRAM }; - -inline std::ostream& operator<<(std::ostream& os, Mode m) { - return os << ((m == STREAM) ? "TLS" : "DTLS"); -} - -class DummyPrSocket { +class DummyPrSocket : public DummyIOLayerMethods { public: - ~DummyPrSocket(); + DummyPrSocket(const std::string& name, SSLProtocolVariant variant) + : name_(name), + variant_(variant), + peer_(), + input_(), + filter_(nullptr), + writeable_(true) {} + virtual ~DummyPrSocket() {} - static PRFileDesc* CreateFD(const std::string& name, - Mode mode); // Returns an FD. - static DummyPrSocket* GetAdapter(PRFileDesc* fd); + // Create a file descriptor that will reference this object. The fd must not + // live longer than this adapter; call PR_Close() before. + ScopedPRFileDesc CreateFD(); - DummyPrSocket* peer() const { return peer_; } - void SetPeer(DummyPrSocket* peer) { peer_ = peer; } - void SetPacketFilter(PacketFilter* filter); + std::weak_ptr& peer() { return peer_; } + void SetPeer(const std::shared_ptr& peer) { peer_ = peer; } + void SetPacketFilter(std::shared_ptr filter); // Drops peer, packet filter and any outstanding packets. void Reset(); void PacketReceived(const DataBuffer& data); - int32_t Read(void* data, int32_t len); - int32_t Recv(void* buf, int32_t buflen); - int32_t Write(const void* buf, int32_t length); + int32_t Read(PRFileDesc* f, void* data, int32_t len) override; + int32_t Recv(PRFileDesc* f, void* buf, int32_t buflen, int32_t flags, + PRIntervalTime to) override; + int32_t Write(PRFileDesc* f, const void* buf, int32_t length) override; void CloseWrites() { writeable_ = false; } - Mode mode() const { return mode_; } + SSLProtocolVariant variant() const { return variant_; } bool readable() const { return !input_.empty(); } private: - DummyPrSocket(const std::string& name, Mode mode) - : name_(name), - mode_(mode), - peer_(nullptr), - input_(), - filter_(nullptr), - writeable_(true) {} + class Packet : public DataBuffer { + public: + Packet(const DataBuffer& buf) : DataBuffer(buf), offset_(0) {} + + void Advance(size_t delta) { + PR_ASSERT(offset_ + delta <= len()); + offset_ = std::min(len(), offset_ + delta); + } + + size_t offset() const { return offset_; } + size_t remaining() const { return len() - offset_; } + + private: + size_t offset_; + }; const std::string name_; - Mode mode_; - DummyPrSocket* peer_; - std::queue input_; - PacketFilter* filter_; + SSLProtocolVariant variant_; + std::weak_ptr peer_; + std::queue input_; + std::shared_ptr filter_; bool writeable_; }; @@ -111,40 +124,44 @@ class Poller { PollCallback callback_; }; - void Wait(Event event, DummyPrSocket* adapter, PollTarget* target, - PollCallback cb); - void Cancel(Event event, DummyPrSocket* adapter); + void Wait(Event event, std::shared_ptr& adapter, + PollTarget* target, PollCallback cb); + void Cancel(Event event, std::shared_ptr& adapter); void SetTimer(uint32_t timer_ms, PollTarget* target, PollCallback cb, - Timer** handle); + std::shared_ptr* handle); bool Poll(); private: Poller() : waiters_(), timers_() {} - ~Poller(); + ~Poller() {} class Waiter { public: - Waiter(DummyPrSocket* io) : io_(io) { + Waiter(std::shared_ptr io) : io_(io) { + memset(&targets_[0], 0, sizeof(targets_)); memset(&callbacks_[0], 0, sizeof(callbacks_)); } void WaitFor(Event event, PollCallback callback); - DummyPrSocket* io_; + std::shared_ptr io_; PollTarget* targets_[TIMER_EVENT]; PollCallback callbacks_[TIMER_EVENT]; }; class TimerComparator { public: - bool operator()(const Timer* lhs, const Timer* rhs) { + bool operator()(const std::shared_ptr lhs, + const std::shared_ptr rhs) { return lhs->deadline_ > rhs->deadline_; } }; static Poller* instance; - std::map waiters_; - std::priority_queue, TimerComparator> timers_; + std::map, std::unique_ptr> waiters_; + std::priority_queue, + std::vector>, TimerComparator> + timers_; }; } // end of namespace diff --git a/nss/gtests/ssl_gtest/tls_agent.cc b/nss/gtests/ssl_gtest/tls_agent.cc index b75bba5..a53cf88 100644 --- a/nss/gtests/ssl_gtest/tls_agent.cc +++ b/nss/gtests/ssl_gtest/tls_agent.cc @@ -43,14 +43,14 @@ const std::string TlsAgent::kServerEcdhRsa = "ecdh_rsa"; const std::string TlsAgent::kServerEcdhEcdsa = "ecdh_ecdsa"; const std::string TlsAgent::kServerDsa = "dsa"; -TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode) +TlsAgent::TlsAgent(const std::string& name, Role role, + SSLProtocolVariant variant) : name_(name), - mode_(mode), + variant_(variant), + role_(role), server_key_bits_(0), - pr_fd_(nullptr), - adapter_(nullptr), + adapter_(new DummyPrSocket(role_str(), variant)), ssl_fd_(nullptr), - role_(role), state_(STATE_INIT), timer_handle_(nullptr), falsestart_enabled_(false), @@ -61,6 +61,10 @@ TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode) can_falsestart_hook_called_(false), sni_hook_called_(false), auth_certificate_hook_called_(false), + expected_received_alert_(kTlsAlertCloseNotify), + expected_received_alert_level_(kTlsAlertWarning), + expected_sent_alert_(kTlsAlertCloseNotify), + expected_sent_alert_level_(kTlsAlertWarning), handshake_callback_called_(false), error_code_(0), send_ctr_(0), @@ -69,29 +73,31 @@ TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode) handshake_callback_(), auth_certificate_callback_(), sni_callback_(), - expect_short_headers_(false) { + expect_short_headers_(false), + skip_version_checks_(false) { memset(&info_, 0, sizeof(info_)); memset(&csinfo_, 0, sizeof(csinfo_)); - SECStatus rv = SSL_VersionRangeGetDefault( - mode_ == STREAM ? ssl_variant_stream : ssl_variant_datagram, &vrange_); + SECStatus rv = SSL_VersionRangeGetDefault(variant_, &vrange_); EXPECT_EQ(SECSuccess, rv); } TlsAgent::~TlsAgent() { - if (adapter_) { - Poller::Instance()->Cancel(READABLE_EVENT, adapter_); - // The adapter is closed when the FD closes. - } if (timer_handle_) { timer_handle_->Cancel(); } - if (pr_fd_) { - PR_Close(pr_fd_); + if (adapter_) { + Poller::Instance()->Cancel(READABLE_EVENT, adapter_); } - if (ssl_fd_) { - PR_Close(ssl_fd_); + // Add failures manually, if any, so we don't throw in a destructor. + if (expected_received_alert_ != kTlsAlertCloseNotify || + expected_received_alert_level_ != kTlsAlertWarning) { + ADD_FAILURE() << "Wrong expected_received_alert status"; + } + if (expected_sent_alert_ != kTlsAlertCloseNotify || + expected_sent_alert_level_ != kTlsAlertWarning) { + ADD_FAILURE() << "Wrong expected_sent_alert status"; } } @@ -102,27 +108,39 @@ void TlsAgent::SetState(State state) { state_ = state; } +/*static*/ bool TlsAgent::LoadCertificate(const std::string& name, + ScopedCERTCertificate* cert, + ScopedSECKEYPrivateKey* priv) { + cert->reset(PK11_FindCertFromNickname(name.c_str(), nullptr)); + EXPECT_NE(nullptr, cert->get()); + if (!cert->get()) return false; + + priv->reset(PK11_FindKeyByAnyCert(cert->get(), nullptr)); + EXPECT_NE(nullptr, priv->get()); + if (!priv->get()) return false; + + return true; +} + bool TlsAgent::ConfigServerCert(const std::string& name, bool updateKeyBits, const SSLExtraServerCertData* serverCertData) { - ScopedCERTCertificate cert(PK11_FindCertFromNickname(name.c_str(), nullptr)); - EXPECT_NE(nullptr, cert.get()); - if (!cert.get()) return false; + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey priv; + if (!TlsAgent::LoadCertificate(name, &cert, &priv)) { + return false; + } - ScopedSECKEYPublicKey pub(CERT_ExtractPublicKey(cert.get())); - EXPECT_NE(nullptr, pub.get()); - if (!pub.get()) return false; if (updateKeyBits) { + ScopedSECKEYPublicKey pub(CERT_ExtractPublicKey(cert.get())); + EXPECT_NE(nullptr, pub.get()); + if (!pub.get()) return false; server_key_bits_ = SECKEY_PublicKeyStrengthInBits(pub.get()); } - ScopedSECKEYPrivateKey priv(PK11_FindKeyByAnyCert(cert.get(), nullptr)); - EXPECT_NE(nullptr, priv.get()); - if (!priv.get()) return false; - SECStatus rv = - SSL_ConfigSecureServer(ssl_fd_, nullptr, nullptr, ssl_kea_null); + SSL_ConfigSecureServer(ssl_fd(), nullptr, nullptr, ssl_kea_null); EXPECT_EQ(SECFailure, rv); - rv = SSL_ConfigServerCert(ssl_fd_, cert.get(), priv.get(), serverCertData, + rv = SSL_ConfigServerCert(ssl_fd(), cert.get(), priv.get(), serverCertData, serverCertData ? sizeof(*serverCertData) : 0); return rv == SECSuccess; } @@ -131,41 +149,59 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) { // Don't set up twice if (ssl_fd_) return true; - if (adapter_->mode() == STREAM) { - ssl_fd_ = SSL_ImportFD(modelSocket, pr_fd_); + ScopedPRFileDesc dummy_fd(adapter_->CreateFD()); + EXPECT_NE(nullptr, dummy_fd); + if (!dummy_fd) { + return false; + } + if (adapter_->variant() == ssl_variant_stream) { + ssl_fd_.reset(SSL_ImportFD(modelSocket, dummy_fd.get())); } else { - ssl_fd_ = DTLS_ImportFD(modelSocket, pr_fd_); + ssl_fd_.reset(DTLS_ImportFD(modelSocket, dummy_fd.get())); } EXPECT_NE(nullptr, ssl_fd_); - if (!ssl_fd_) return false; - pr_fd_ = nullptr; + if (!ssl_fd_) { + return false; + } + dummy_fd.release(); // Now subsumed by ssl_fd_. - SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_); - EXPECT_EQ(SECSuccess, rv); - if (rv != SECSuccess) return false; + SECStatus rv; + if (!skip_version_checks_) { + rv = SSL_VersionRangeSet(ssl_fd(), &vrange_); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + } if (role_ == SERVER) { EXPECT_TRUE(ConfigServerCert(name_, true)); - rv = SSL_SNISocketConfigHook(ssl_fd_, SniHook, this); + rv = SSL_SNISocketConfigHook(ssl_fd(), SniHook, this); EXPECT_EQ(SECSuccess, rv); if (rv != SECSuccess) return false; ScopedCERTCertList anchors(CERT_NewCertList()); - rv = SSL_SetTrustAnchors(ssl_fd_, anchors.get()); + rv = SSL_SetTrustAnchors(ssl_fd(), anchors.get()); if (rv != SECSuccess) return false; } else { - rv = SSL_SetURL(ssl_fd_, "server"); + rv = SSL_SetURL(ssl_fd(), "server"); EXPECT_EQ(SECSuccess, rv); if (rv != SECSuccess) return false; } - rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook, this); + rv = SSL_AuthCertificateHook(ssl_fd(), AuthCertificateHook, this); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + + rv = SSL_AlertReceivedCallback(ssl_fd(), AlertReceivedCallback, this); EXPECT_EQ(SECSuccess, rv); if (rv != SECSuccess) return false; - rv = SSL_HandshakeCallback(ssl_fd_, HandshakeCallback, this); + rv = SSL_AlertSentCallback(ssl_fd(), AlertSentCallback, this); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + + rv = SSL_HandshakeCallback(ssl_fd(), HandshakeCallback, this); EXPECT_EQ(SECSuccess, rv); if (rv != SECSuccess) return false; @@ -177,38 +213,31 @@ void TlsAgent::SetupClientAuth() { ASSERT_EQ(CLIENT, role_); EXPECT_EQ(SECSuccess, - SSL_GetClientAuthDataHook(ssl_fd_, GetClientAuthDataHook, + SSL_GetClientAuthDataHook(ssl_fd(), GetClientAuthDataHook, reinterpret_cast(this))); } -bool TlsAgent::GetClientAuthCredentials(CERTCertificate** cert, - SECKEYPrivateKey** priv) const { - *cert = PK11_FindCertFromNickname(name_.c_str(), nullptr); - EXPECT_NE(nullptr, *cert); - if (!*cert) return false; - - *priv = PK11_FindKeyByAnyCert(*cert, nullptr); - EXPECT_NE(nullptr, *priv); - if (!*priv) return false; // Leak cert. - - return true; -} - SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd, CERTDistNames* caNames, - CERTCertificate** cert, - SECKEYPrivateKey** privKey) { + CERTCertificate** clientCert, + SECKEYPrivateKey** clientKey) { TlsAgent* agent = reinterpret_cast(self); ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd())); EXPECT_TRUE(peerCert) << "Client should be able to see the server cert"; - if (agent->GetClientAuthCredentials(cert, privKey)) { - return SECSuccess; + + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey priv; + if (!TlsAgent::LoadCertificate(agent->name(), &cert, &priv)) { + return SECFailure; } - return SECFailure; + + *clientCert = cert.release(); + *clientKey = priv.release(); + return SECSuccess; } bool TlsAgent::GetPeerChainLength(size_t* count) { - CERTCertList* chain = SSL_PeerCertificateChain(ssl_fd_); + CERTCertList* chain = SSL_PeerCertificateChain(ssl_fd()); if (!chain) return false; *count = 0; @@ -224,17 +253,21 @@ bool TlsAgent::GetPeerChainLength(size_t* count) { return true; } +void TlsAgent::CheckCipherSuite(uint16_t cipher_suite) { + EXPECT_EQ(csinfo_.cipherSuite, cipher_suite); +} + void TlsAgent::RequestClientAuth(bool requireAuth) { EXPECT_TRUE(EnsureTlsSetup()); ASSERT_EQ(SERVER, role_); EXPECT_EQ(SECSuccess, - SSL_OptionSet(ssl_fd_, SSL_REQUEST_CERTIFICATE, PR_TRUE)); - EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_REQUIRE_CERTIFICATE, + SSL_OptionSet(ssl_fd(), SSL_REQUEST_CERTIFICATE, PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), SSL_REQUIRE_CERTIFICATE, requireAuth ? PR_TRUE : PR_FALSE)); EXPECT_EQ(SECSuccess, SSL_AuthCertificateHook( - ssl_fd_, &TlsAgent::ClientAuthenticated, this)); + ssl_fd(), &TlsAgent::ClientAuthenticated, this)); expect_client_auth_ = true; } @@ -242,7 +275,7 @@ void TlsAgent::StartConnect(PRFileDesc* model) { EXPECT_TRUE(EnsureTlsSetup(model)); SECStatus rv; - rv = SSL_ResetHandshake(ssl_fd_, role_ == SERVER ? PR_TRUE : PR_FALSE); + rv = SSL_ResetHandshake(ssl_fd(), role_ == SERVER ? PR_TRUE : PR_FALSE); EXPECT_EQ(SECSuccess, rv); SetState(STATE_CONNECTING); } @@ -250,7 +283,7 @@ void TlsAgent::StartConnect(PRFileDesc* model) { void TlsAgent::DisableAllCiphers() { for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) { SECStatus rv = - SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_FALSE); + SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_FALSE); EXPECT_EQ(SECSuccess, rv); } } @@ -287,7 +320,7 @@ void TlsAgent::EnableCiphersByKeyExchange(SSLKEAType kea) { EXPECT_EQ(sizeof(csinfo), csinfo.length); if ((csinfo.keaType == kea) || (csinfo.keaType == ssl_kea_tls13_any)) { - rv = SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_TRUE); + rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE); EXPECT_EQ(SECSuccess, rv); } } @@ -325,7 +358,7 @@ void TlsAgent::EnableCiphersByAuthType(SSLAuthType authType) { if ((csinfo.authType == authType) || (csinfo.keaType == ssl_kea_tls13_any)) { - rv = SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_TRUE); + rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE); EXPECT_EQ(SECSuccess, rv); } } @@ -333,20 +366,20 @@ void TlsAgent::EnableCiphersByAuthType(SSLAuthType authType) { void TlsAgent::EnableSingleCipher(uint16_t cipher) { DisableAllCiphers(); - SECStatus rv = SSL_CipherPrefSet(ssl_fd_, cipher, PR_TRUE); + SECStatus rv = SSL_CipherPrefSet(ssl_fd(), cipher, PR_TRUE); EXPECT_EQ(SECSuccess, rv); } void TlsAgent::ConfigNamedGroups(const std::vector& groups) { EXPECT_TRUE(EnsureTlsSetup()); - SECStatus rv = SSL_NamedGroupConfig(ssl_fd_, &groups[0], groups.size()); + SECStatus rv = SSL_NamedGroupConfig(ssl_fd(), &groups[0], groups.size()); EXPECT_EQ(SECSuccess, rv); } void TlsAgent::SetSessionTicketsEnabled(bool en) { EXPECT_TRUE(EnsureTlsSetup()); - SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS, + SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_SESSION_TICKETS, en ? PR_TRUE : PR_FALSE); EXPECT_EQ(SECSuccess, rv); } @@ -354,7 +387,7 @@ void TlsAgent::SetSessionTicketsEnabled(bool en) { void TlsAgent::SetSessionCacheEnabled(bool en) { EXPECT_TRUE(EnsureTlsSetup()); - SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE, en ? PR_FALSE : PR_TRUE); + SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_NO_CACHE, en ? PR_FALSE : PR_TRUE); EXPECT_EQ(SECSuccess, rv); } @@ -362,14 +395,22 @@ void TlsAgent::Set0RttEnabled(bool en) { EXPECT_TRUE(EnsureTlsSetup()); SECStatus rv = - SSL_OptionSet(ssl_fd_, SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE); + SSL_OptionSet(ssl_fd(), SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE); + EXPECT_EQ(SECSuccess, rv); +} + +void TlsAgent::SetFallbackSCSVEnabled(bool en) { + EXPECT_TRUE(role_ == CLIENT && EnsureTlsSetup()); + + SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_FALLBACK_SCSV, + en ? PR_TRUE : PR_FALSE); EXPECT_EQ(SECSuccess, rv); } void TlsAgent::SetShortHeadersEnabled() { EXPECT_TRUE(EnsureTlsSetup()); - SECStatus rv = SSLInt_EnableShortHeaders(ssl_fd_); + SECStatus rv = SSLInt_EnableShortHeaders(ssl_fd()); EXPECT_EQ(SECSuccess, rv); } @@ -377,8 +418,8 @@ void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) { vrange_.min = minver; vrange_.max = maxver; - if (ssl_fd_) { - SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_); + if (ssl_fd()) { + SECStatus rv = SSL_VersionRangeSet(ssl_fd(), &vrange_); EXPECT_EQ(SECSuccess, rv); } } @@ -398,32 +439,34 @@ void TlsAgent::ExpectReadWriteError() { expect_readwrite_error_ = true; } void TlsAgent::ExpectShortHeaders() { expect_short_headers_ = true; } +void TlsAgent::SkipVersionChecks() { skip_version_checks_ = true; } + void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count) { EXPECT_TRUE(EnsureTlsSetup()); EXPECT_LE(count, SSL_SignatureMaxCount()); EXPECT_EQ(SECSuccess, - SSL_SignatureSchemePrefSet(ssl_fd_, schemes, + SSL_SignatureSchemePrefSet(ssl_fd(), schemes, static_cast(count))); - EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefSet(ssl_fd_, schemes, 0)) + EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefSet(ssl_fd(), schemes, 0)) << "setting no schemes should fail and do nothing"; std::vector configuredSchemes(count); unsigned int configuredCount; EXPECT_EQ(SECFailure, - SSL_SignatureSchemePrefGet(ssl_fd_, nullptr, &configuredCount, 1)) + SSL_SignatureSchemePrefGet(ssl_fd(), nullptr, &configuredCount, 1)) << "get schemes, schemes is nullptr"; EXPECT_EQ(SECFailure, - SSL_SignatureSchemePrefGet(ssl_fd_, &configuredSchemes[0], + SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0], &configuredCount, 0)) << "get schemes, too little space"; EXPECT_EQ(SECFailure, - SSL_SignatureSchemePrefGet(ssl_fd_, &configuredSchemes[0], nullptr, + SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0], nullptr, configuredSchemes.size())) << "get schemes, countOut is nullptr"; EXPECT_EQ(SECSuccess, SSL_SignatureSchemePrefGet( - ssl_fd_, &configuredSchemes[0], &configuredCount, + ssl_fd(), &configuredSchemes[0], &configuredCount, configuredSchemes.size())); // SignatureSchemePrefSet drops unsupported algorithms silently, so the // number that are configured might be fewer. @@ -524,10 +567,10 @@ void TlsAgent::EnableFalseStart() { EXPECT_TRUE(EnsureTlsSetup()); falsestart_enabled_ = true; + EXPECT_EQ(SECSuccess, SSL_SetCanFalseStartCallback( + ssl_fd(), CanFalseStartCallback, this)); EXPECT_EQ(SECSuccess, - SSL_SetCanFalseStartCallback(ssl_fd_, CanFalseStartCallback, this)); - EXPECT_EQ(SECSuccess, - SSL_OptionSet(ssl_fd_, SSL_ENABLE_FALSE_START, PR_TRUE)); + SSL_OptionSet(ssl_fd(), SSL_ENABLE_FALSE_START, PR_TRUE)); } void TlsAgent::ExpectResumption() { expect_resumption_ = true; } @@ -535,8 +578,8 @@ void TlsAgent::ExpectResumption() { expect_resumption_ = true; } void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) { EXPECT_TRUE(EnsureTlsSetup()); - EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_ENABLE_ALPN, PR_TRUE)); - EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd_, val, len)); + EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), SSL_ENABLE_ALPN, PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd(), val, len)); } void TlsAgent::CheckAlpn(SSLNextProtoState expected_state, @@ -544,7 +587,7 @@ void TlsAgent::CheckAlpn(SSLNextProtoState expected_state, SSLNextProtoState state; char chosen[10]; unsigned int chosen_len; - SECStatus rv = SSL_GetNextProto(ssl_fd_, &state, + SECStatus rv = SSL_GetNextProto(ssl_fd(), &state, reinterpret_cast(chosen), &chosen_len, sizeof(chosen)); EXPECT_EQ(SECSuccess, rv); @@ -562,12 +605,12 @@ void TlsAgent::EnableSrtp() { const uint16_t ciphers[] = {SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32}; EXPECT_EQ(SECSuccess, - SSL_SetSRTPCiphers(ssl_fd_, ciphers, PR_ARRAY_SIZE(ciphers))); + SSL_SetSRTPCiphers(ssl_fd(), ciphers, PR_ARRAY_SIZE(ciphers))); } void TlsAgent::CheckSrtp() const { uint16_t actual; - EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd_, &actual)); + EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd(), &actual)); EXPECT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual); } @@ -578,6 +621,55 @@ void TlsAgent::CheckErrorCode(int32_t expected) const { << PORT_ErrorToName(expected) << std::endl; } +static uint8_t GetExpectedAlertLevel(uint8_t alert) { + switch (alert) { + case kTlsAlertCloseNotify: + case kTlsAlertEndOfEarlyData: + return kTlsAlertWarning; + default: + break; + } + return kTlsAlertFatal; +} + +void TlsAgent::ExpectReceiveAlert(uint8_t alert, uint8_t level) { + expected_received_alert_ = alert; + if (level == 0) { + expected_received_alert_level_ = GetExpectedAlertLevel(alert); + } else { + expected_received_alert_level_ = level; + } +} + +void TlsAgent::ExpectSendAlert(uint8_t alert, uint8_t level) { + expected_sent_alert_ = alert; + if (level == 0) { + expected_sent_alert_level_ = GetExpectedAlertLevel(alert); + } else { + expected_sent_alert_level_ = level; + } +} + +void TlsAgent::CheckAlert(bool sent, const SSLAlert* alert) { + LOG(((alert->level == kTlsAlertWarning) ? "Warning" : "Fatal") + << " alert " << (sent ? "sent" : "received") << ": " + << static_cast(alert->description)); + + auto& expected = sent ? expected_sent_alert_ : expected_received_alert_; + auto& expected_level = + sent ? expected_sent_alert_level_ : expected_received_alert_level_; + /* Silently pass close_notify in case the test has already ended. */ + if (expected == kTlsAlertCloseNotify && expected_level == kTlsAlertWarning && + alert->description == expected && alert->level == expected_level) { + return; + } + + EXPECT_EQ(expected, alert->description); + EXPECT_EQ(expected_level, alert->level); + expected = kTlsAlertCloseNotify; + expected_level = kTlsAlertWarning; +} + void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const { ASSERT_EQ(0, error_code_); WAIT_(error_code_ != 0, delay); @@ -589,7 +681,7 @@ void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const { void TlsAgent::CheckPreliminaryInfo() { SSLPreliminaryChannelInfo info; EXPECT_EQ(SECSuccess, - SSL_GetPreliminaryChannelInfo(ssl_fd_, &info, sizeof(info))); + SSL_GetPreliminaryChannelInfo(ssl_fd(), &info, sizeof(info))); EXPECT_EQ(sizeof(info), info.length); EXPECT_TRUE(info.valuesSet & ssl_preinfo_version); EXPECT_TRUE(info.valuesSet & ssl_preinfo_cipher_suite); @@ -619,7 +711,7 @@ void TlsAgent::CheckCallbacks() const { // These callbacks shouldn't fire if we are resuming, except on TLS 1.3. if (role_ == SERVER) { - PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd_, ssl_server_name_xtn); + PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd(), ssl_server_name_xtn); EXPECT_EQ(((!expect_resumption_ && have_sni) || expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3), sni_hook_called_); @@ -639,11 +731,15 @@ void TlsAgent::ResetPreliminaryInfo() { } void TlsAgent::Connected() { + if (state_ == STATE_CONNECTED) { + return; + } + LOG("Handshake success"); CheckPreliminaryInfo(); CheckCallbacks(); - SECStatus rv = SSL_GetChannelInfo(ssl_fd_, &info_, sizeof(info_)); + SECStatus rv = SSL_GetChannelInfo(ssl_fd(), &info_, sizeof(info_)); EXPECT_EQ(SECSuccess, rv); EXPECT_EQ(sizeof(info_), info_.length); @@ -658,18 +754,19 @@ void TlsAgent::Connected() { EXPECT_EQ(sizeof(csinfo_), csinfo_.length); if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { - PRInt32 cipherSuites = SSLInt_CountTls13CipherSpecs(ssl_fd_); + PRInt32 cipherSuites = SSLInt_CountTls13CipherSpecs(ssl_fd()); // We use one ciphersuite in each direction, plus one that's kept around // by DTLS for retransmission. - PRInt32 expected = ((mode_ == DGRAM) && (role_ == CLIENT)) ? 3 : 2; + PRInt32 expected = + ((variant_ == ssl_variant_datagram) && (role_ == CLIENT)) ? 3 : 2; EXPECT_EQ(expected, cipherSuites); if (expected != cipherSuites) { - SSLInt_PrintTls13CipherSpecs(ssl_fd_); + SSLInt_PrintTls13CipherSpecs(ssl_fd()); } } PRBool short_headers; - rv = SSLInt_UsingShortHeaders(ssl_fd_, &short_headers); + rv = SSLInt_UsingShortHeaders(ssl_fd(), &short_headers); EXPECT_EQ(SECSuccess, rv); EXPECT_EQ((PRBool)expect_short_headers_, short_headers); SetState(STATE_CONNECTED); @@ -679,7 +776,7 @@ void TlsAgent::EnableExtendedMasterSecret() { ASSERT_TRUE(EnsureTlsSetup()); SECStatus rv = - SSL_OptionSet(ssl_fd_, SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE); + SSL_OptionSet(ssl_fd(), SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE); ASSERT_EQ(SECSuccess, rv); } @@ -701,13 +798,13 @@ void TlsAgent::CheckEarlyDataAccepted(bool expected) { } void TlsAgent::CheckSecretsDestroyed() { - ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd_)); + ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd())); } void TlsAgent::DisableRollbackDetection() { ASSERT_TRUE(EnsureTlsSetup()); - SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ROLLBACK_DETECTION, PR_FALSE); + SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ROLLBACK_DETECTION, PR_FALSE); ASSERT_EQ(SECSuccess, rv); } @@ -715,23 +812,22 @@ void TlsAgent::DisableRollbackDetection() { void TlsAgent::EnableCompression() { ASSERT_TRUE(EnsureTlsSetup()); - SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_DEFLATE, PR_TRUE); + SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_DEFLATE, PR_TRUE); ASSERT_EQ(SECSuccess, rv); } void TlsAgent::SetDowngradeCheckVersion(uint16_t version) { ASSERT_TRUE(EnsureTlsSetup()); - SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd_, version); + SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd(), version); ASSERT_EQ(SECSuccess, rv); } void TlsAgent::Handshake() { LOGV("Handshake"); - SECStatus rv = SSL_ForceHandshake(ssl_fd_); + SECStatus rv = SSL_ForceHandshake(ssl_fd()); if (rv == SECSuccess) { Connected(); - Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, &TlsAgent::ReadableCallback); return; @@ -740,14 +836,14 @@ void TlsAgent::Handshake() { int32_t err = PR_GetError(); if (err == PR_WOULD_BLOCK_ERROR) { LOGV("Would have blocked"); - if (mode_ == DGRAM) { + if (variant_ == ssl_variant_datagram) { if (timer_handle_) { timer_handle_->Cancel(); timer_handle_ = nullptr; } PRIntervalTime timeout; - rv = DTLS_GetHandshakeTimeout(ssl_fd_, &timeout); + rv = DTLS_GetHandshakeTimeout(ssl_fd(), &timeout); if (rv == SECSuccess) { Poller::Instance()->SetTimer( timeout + 1, this, &TlsAgent::ReadableCallback, &timer_handle_); @@ -773,13 +869,18 @@ void TlsAgent::PrepareForRenegotiate() { void TlsAgent::StartRenegotiate() { PrepareForRenegotiate(); - SECStatus rv = SSL_ReHandshake(ssl_fd_, PR_TRUE); + SECStatus rv = SSL_ReHandshake(ssl_fd(), PR_TRUE); EXPECT_EQ(SECSuccess, rv); } void TlsAgent::SendDirect(const DataBuffer& buf) { LOG("Send Direct " << buf); - adapter_->peer()->PacketReceived(buf); + auto peer = adapter_->peer().lock(); + if (peer) { + peer->PacketReceived(buf); + } else { + LOG("Send Direct peer absent"); + } } static bool ErrorIsNonFatal(PRErrorCode code) { @@ -806,7 +907,7 @@ void TlsAgent::SendData(size_t bytes, size_t blocksize) { void TlsAgent::SendBuffer(const DataBuffer& buf) { LOGV("Writing " << buf.len() << " bytes"); - int32_t rv = PR_Write(ssl_fd_, buf.data(), buf.len()); + int32_t rv = PR_Write(ssl_fd(), buf.data(), buf.len()); if (expect_readwrite_error_) { EXPECT_GT(0, rv); EXPECT_NE(PR_WOULD_BLOCK_ERROR, error_code_); @@ -820,7 +921,7 @@ void TlsAgent::SendBuffer(const DataBuffer& buf) { void TlsAgent::ReadBytes() { uint8_t block[1024]; - int32_t rv = PR_Read(ssl_fd_, block, sizeof(block)); + int32_t rv = PR_Read(ssl_fd(), block, sizeof(block)); LOGV("ReadBytes " << rv); int32_t err; @@ -853,18 +954,19 @@ void TlsAgent::ResetSentBytes() { send_ctr_ = 0; } void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) { EXPECT_TRUE(EnsureTlsSetup()); - SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE, + SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_NO_CACHE, mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE); EXPECT_EQ(SECSuccess, rv); - rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS, + rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_SESSION_TICKETS, mode & RESUME_TICKET ? PR_TRUE : PR_FALSE); EXPECT_EQ(SECSuccess, rv); } void TlsAgent::DisableECDHEServerKeyReuse() { + ASSERT_TRUE(EnsureTlsSetup()); ASSERT_EQ(TlsAgent::SERVER, role_); - SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); + SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); EXPECT_EQ(SECSuccess, rv); } @@ -877,29 +979,25 @@ void TlsAgentTestBase::SetUp() { } void TlsAgentTestBase::TearDown() { - delete agent_; + agent_ = nullptr; SSL_ClearSessionCache(); SSL_ShutdownServerSessionIDCache(); } void TlsAgentTestBase::Reset(const std::string& server_name) { - delete agent_; - Init(server_name); -} - -void TlsAgentTestBase::Init(const std::string& server_name) { - agent_ = + agent_.reset( new TlsAgent(role_ == TlsAgent::CLIENT ? TlsAgent::kClient : server_name, - role_, mode_); - agent_->Init(); - fd_ = DummyPrSocket::CreateFD(agent_->role_str(), mode_); - agent_->adapter()->SetPeer(DummyPrSocket::GetAdapter(fd_)); + role_, variant_)); + if (version_) { + agent_->SetVersionRange(version_, version_); + } + agent_->adapter()->SetPeer(sink_adapter_); agent_->StartConnect(); } void TlsAgentTestBase::EnsureInit() { if (!agent_) { - Init(); + Reset(); } const std::vector groups = { ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, @@ -907,6 +1005,11 @@ void TlsAgentTestBase::EnsureInit() { agent_->ConfigNamedGroups(groups); } +void TlsAgentTestBase::ExpectAlert(uint8_t alert) { + EnsureInit(); + agent_->ExpectSendAlert(alert); +} + void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state, int32_t error_code) { @@ -922,14 +1025,16 @@ void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer, } } -void TlsAgentTestBase::MakeRecord(Mode mode, uint8_t type, uint16_t version, - const uint8_t* buf, size_t len, - DataBuffer* out, uint64_t seq_num) { +void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type, + uint16_t version, const uint8_t* buf, + size_t len, DataBuffer* out, + uint64_t seq_num) { size_t index = 0; index = out->Write(index, type, 1); - index = out->Write( - index, mode == STREAM ? version : TlsVersionToDtlsVersion(version), 2); - if (mode == DGRAM) { + if (variant == ssl_variant_stream) { + index = out->Write(index, version, 2); + } else { + index = out->Write(index, TlsVersionToDtlsVersion(version), 2); index = out->Write(index, seq_num >> 32, 4); index = out->Write(index, seq_num & PR_UINT32_MAX, 4); } @@ -940,7 +1045,7 @@ void TlsAgentTestBase::MakeRecord(Mode mode, uint8_t type, uint16_t version, void TlsAgentTestBase::MakeRecord(uint8_t type, uint16_t version, const uint8_t* buf, size_t len, DataBuffer* out, uint64_t seq_num) const { - MakeRecord(mode_, type, version, buf, len, out, seq_num); + MakeRecord(variant_, type, version, buf, len, out, seq_num); } void TlsAgentTestBase::MakeHandshakeMessage(uint8_t hs_type, @@ -959,7 +1064,7 @@ void TlsAgentTestBase::MakeHandshakeMessageFragment( if (!fragment_length) fragment_length = hs_len; index = out->Write(index, hs_type, 1); // Handshake record type. index = out->Write(index, hs_len, 3); // Handshake length - if (mode_ == DGRAM) { + if (variant_ == ssl_variant_datagram) { index = out->Write(index, seq_num, 2); index = out->Write(index, fragment_offset, 3); index = out->Write(index, fragment_length, 3); diff --git a/nss/gtests/ssl_gtest/tls_agent.h b/nss/gtests/ssl_gtest/tls_agent.h index 78923c9..32f6175 100644 --- a/nss/gtests/ssl_gtest/tls_agent.h +++ b/nss/gtests/ssl_gtest/tls_agent.h @@ -14,9 +14,11 @@ #include #include "test_io.h" +#include "tls_filter.h" #define GTEST_HAS_RTTI 0 #include "gtest/gtest.h" +#include "scoped_ptrs.h" extern bool g_ssl_gtest_verbose; @@ -42,6 +44,8 @@ const extern std::vector kECDHEGroups; const extern std::vector kFFDHEGroups; const extern std::vector kFasterDHEGroups; +// These functions are called from callbacks. They use bare pointers because +// TlsAgent sets up the callback and it doesn't know who owns it. typedef std::function AuthCertificateCallbackFunction; @@ -70,25 +74,24 @@ class TlsAgent : public PollTarget { static const std::string kServerEcdhRsa; static const std::string kServerDsa; - TlsAgent(const std::string& name, Role role, Mode mode); + TlsAgent(const std::string& name, Role role, SSLProtocolVariant variant); virtual ~TlsAgent(); - bool Init() { - pr_fd_ = DummyPrSocket::CreateFD(role_str(), mode_); - if (!pr_fd_) return false; - - adapter_ = DummyPrSocket::GetAdapter(pr_fd_); - if (!adapter_) return false; - - return true; + void SetPeer(std::shared_ptr& peer) { + adapter_->SetPeer(peer->adapter_); } - void SetPeer(TlsAgent* peer) { adapter_->SetPeer(peer->adapter_); } + void SetTlsRecordFilter(std::shared_ptr filter) { + filter->SetAgent(this); + adapter_->SetPacketFilter(filter); + } - void SetPacketFilter(PacketFilter* filter) { + void SetPacketFilter(std::shared_ptr filter) { adapter_->SetPacketFilter(filter); } + void DeletePacketFilter() { adapter_->SetPacketFilter(nullptr); } + void StartConnect(PRFileDesc* model = nullptr); void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group, size_t kea_size = 0) const; @@ -107,6 +110,9 @@ class TlsAgent : public PollTarget { void PrepareForRenegotiate(); // Prepares for renegotiation, then actually triggers it. void StartRenegotiate(); + static bool LoadCertificate(const std::string& name, + ScopedCERTCertificate* cert, + ScopedSECKEYPrivateKey* priv); bool ConfigServerCert(const std::string& name, bool updateKeyBits = false, const SSLExtraServerCertData* serverCertData = nullptr); bool ConfigServerCertWithChain(const std::string& name); @@ -114,13 +120,12 @@ class TlsAgent : public PollTarget { void SetupClientAuth(); void RequestClientAuth(bool requireAuth); - bool GetClientAuthCredentials(CERTCertificate** cert, - SECKEYPrivateKey** priv) const; void ConfigureSessionCache(SessionResumptionMode mode); void SetSessionTicketsEnabled(bool en); void SetSessionCacheEnabled(bool en); void Set0RttEnabled(bool en); + void SetFallbackSCSVEnabled(bool en); void SetShortHeadersEnabled(); void SetVersionRange(uint16_t minver, uint16_t maxver); void GetVersionRange(uint16_t* minver, uint16_t* maxver); @@ -132,6 +137,7 @@ class TlsAgent : public PollTarget { void EnableFalseStart(); void ExpectResumption(); void ExpectShortHeaders(); + void SkipVersionChecks(); void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count); void EnableAlpn(const uint8_t* val, size_t len); void CheckAlpn(SSLNextProtoState expected_state, @@ -157,6 +163,7 @@ class TlsAgent : public PollTarget { void ConfigNamedGroups(const std::vector& groups); void DisableECDHEServerKeyReuse(); bool GetPeerChainLength(size_t* count); + void CheckCipherSuite(uint16_t cipher_suite); const std::string& name() const { return name_; } @@ -166,15 +173,15 @@ class TlsAgent : public PollTarget { State state() const { return state_; } const CERTCertificate* peer_cert() const { - return SSL_PeerCertificate(ssl_fd_); + return SSL_PeerCertificate(ssl_fd_.get()); } const char* state_str() const { return state_str(state()); } static const char* state_str(State state) { return states[state]; } - PRFileDesc* ssl_fd() { return ssl_fd_; } - DummyPrSocket* adapter() { return adapter_; } + PRFileDesc* ssl_fd() const { return ssl_fd_.get(); } + std::shared_ptr& adapter() { return adapter_; } bool is_compressed() const { return info_.compressionMethod != ssl_compression_null; @@ -239,6 +246,9 @@ class TlsAgent : public PollTarget { sni_callback_ = sni_callback; } + void ExpectReceiveAlert(uint8_t alert, uint8_t level = 0); + void ExpectSendAlert(uint8_t alert, uint8_t level = 0); + private: const static char* states[]; @@ -320,6 +330,18 @@ class TlsAgent : public PollTarget { return SECSuccess; } + void CheckAlert(bool sent, const SSLAlert* alert); + + static void AlertReceivedCallback(const PRFileDesc* fd, void* arg, + const SSLAlert* alert) { + reinterpret_cast(arg)->CheckAlert(false, alert); + } + + static void AlertSentCallback(const PRFileDesc* fd, void* arg, + const SSLAlert* alert) { + reinterpret_cast(arg)->CheckAlert(true, alert); + } + static void HandshakeCallback(PRFileDesc* fd, void* arg) { TlsAgent* agent = reinterpret_cast(arg); agent->handshake_callback_called_ = true; @@ -336,14 +358,13 @@ class TlsAgent : public PollTarget { void Connected(); const std::string name_; - Mode mode_; - uint16_t server_key_bits_; - PRFileDesc* pr_fd_; - DummyPrSocket* adapter_; - PRFileDesc* ssl_fd_; + SSLProtocolVariant variant_; Role role_; + uint16_t server_key_bits_; + std::shared_ptr adapter_; + ScopedPRFileDesc ssl_fd_; State state_; - Poller::Timer* timer_handle_; + std::shared_ptr timer_handle_; bool falsestart_enabled_; uint16_t expected_version_; uint16_t expected_cipher_suite_; @@ -352,6 +373,10 @@ class TlsAgent : public PollTarget { bool can_falsestart_hook_called_; bool sni_hook_called_; bool auth_certificate_hook_called_; + uint8_t expected_received_alert_; + uint8_t expected_received_alert_level_; + uint8_t expected_sent_alert_; + uint8_t expected_sent_alert_level_; bool handshake_callback_called_; SSLChannelInfo info_; SSLCipherSuiteInfo csinfo_; @@ -364,6 +389,7 @@ class TlsAgent : public PollTarget { AuthCertificateCallbackFunction auth_certificate_callback_; SniCallbackFunction sni_callback_; bool expect_short_headers_; + bool skip_version_checks_; }; inline std::ostream& operator<<(std::ostream& stream, @@ -375,20 +401,23 @@ class TlsAgentTestBase : public ::testing::Test { public: static ::testing::internal::ParamGenerator kTlsRolesAll; - TlsAgentTestBase(TlsAgent::Role role, Mode mode) - : agent_(nullptr), fd_(nullptr), role_(role), mode_(mode) {} - ~TlsAgentTestBase() { - if (fd_) { - PR_Close(fd_); - } - } + TlsAgentTestBase(TlsAgent::Role role, SSLProtocolVariant variant, + uint16_t version = 0) + : agent_(nullptr), + role_(role), + variant_(variant), + version_(version), + sink_adapter_(new DummyPrSocket("sink", variant)) {} + virtual ~TlsAgentTestBase() {} void SetUp(); void TearDown(); - static void MakeRecord(Mode mode, uint8_t type, uint16_t version, - const uint8_t* buf, size_t len, DataBuffer* out, - uint64_t seq_num = 0); + void ExpectAlert(uint8_t alert); + + static void MakeRecord(SSLProtocolVariant variant, uint8_t type, + uint16_t version, const uint8_t* buf, size_t len, + DataBuffer* out, uint64_t seq_num = 0); void MakeRecord(uint8_t type, uint16_t version, const uint8_t* buf, size_t len, DataBuffer* out, uint64_t seq_num = 0) const; void MakeHandshakeMessage(uint8_t hs_type, const uint8_t* data, size_t hs_len, @@ -403,10 +432,6 @@ class TlsAgentTestBase : public ::testing::Test { return str == "CLIENT" ? TlsAgent::CLIENT : TlsAgent::SERVER; } - static inline Mode ToMode(const std::string& str) { - return str == "TLS" ? STREAM : DGRAM; - } - void Init(const std::string& server_name = TlsAgent::kServerRsa); void Reset(const std::string& server_name = TlsAgent::kServerRsa); @@ -415,43 +440,57 @@ class TlsAgentTestBase : public ::testing::Test { void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state, int32_t error_code = 0); - TlsAgent* agent_; - PRFileDesc* fd_; + std::unique_ptr agent_; TlsAgent::Role role_; - Mode mode_; + SSLProtocolVariant variant_; + uint16_t version_; + // This adapter is here just to accept packets from this agent. + std::shared_ptr sink_adapter_; }; -class TlsAgentTest : public TlsAgentTestBase, - public ::testing::WithParamInterface< - std::tuple> { +class TlsAgentTest + : public TlsAgentTestBase, + public ::testing::WithParamInterface< + std::tuple> { public: TlsAgentTest() : TlsAgentTestBase(ToRole(std::get<0>(GetParam())), - ToMode(std::get<1>(GetParam()))) {} + std::get<1>(GetParam()), std::get<2>(GetParam())) {} }; class TlsAgentTestClient : public TlsAgentTestBase, - public ::testing::WithParamInterface { + public ::testing::WithParamInterface< + std::tuple> { public: TlsAgentTestClient() - : TlsAgentTestBase(TlsAgent::CLIENT, ToMode(GetParam())) {} + : TlsAgentTestBase(TlsAgent::CLIENT, std::get<0>(GetParam()), + std::get<1>(GetParam())) {} }; +class TlsAgentTestClient13 : public TlsAgentTestClient {}; + class TlsAgentStreamTestClient : public TlsAgentTestBase { public: - TlsAgentStreamTestClient() : TlsAgentTestBase(TlsAgent::CLIENT, STREAM) {} + TlsAgentStreamTestClient() + : TlsAgentTestBase(TlsAgent::CLIENT, ssl_variant_stream) {} }; class TlsAgentStreamTestServer : public TlsAgentTestBase { public: - TlsAgentStreamTestServer() : TlsAgentTestBase(TlsAgent::SERVER, STREAM) {} + TlsAgentStreamTestServer() + : TlsAgentTestBase(TlsAgent::SERVER, ssl_variant_stream) {} }; class TlsAgentDgramTestClient : public TlsAgentTestBase { public: - TlsAgentDgramTestClient() : TlsAgentTestBase(TlsAgent::CLIENT, DGRAM) {} + TlsAgentDgramTestClient() + : TlsAgentTestBase(TlsAgent::CLIENT, ssl_variant_datagram) {} }; +inline bool operator==(const SSLVersionRange& vr1, const SSLVersionRange& vr2) { + return vr1.min == vr2.min && vr1.max == vr2.max; +} + } // namespace nss_test #endif diff --git a/nss/gtests/ssl_gtest/tls_connect.cc b/nss/gtests/ssl_gtest/tls_connect.cc index d025499..861d162 100644 --- a/nss/gtests/ssl_gtest/tls_connect.cc +++ b/nss/gtests/ssl_gtest/tls_connect.cc @@ -13,23 +13,27 @@ extern "C" { #include "databuffer.h" #include "gtest_utils.h" +#include "scoped_ptrs.h" #include "sslproto.h" extern std::string g_working_dir_path; namespace nss_test { -static const std::string kTlsModesStreamArr[] = {"TLS"}; -::testing::internal::ParamGenerator - TlsConnectTestBase::kTlsModesStream = - ::testing::ValuesIn(kTlsModesStreamArr); -static const std::string kTlsModesDatagramArr[] = {"DTLS"}; -::testing::internal::ParamGenerator - TlsConnectTestBase::kTlsModesDatagram = - ::testing::ValuesIn(kTlsModesDatagramArr); -static const std::string kTlsModesAllArr[] = {"TLS", "DTLS"}; -::testing::internal::ParamGenerator - TlsConnectTestBase::kTlsModesAll = ::testing::ValuesIn(kTlsModesAllArr); +static const SSLProtocolVariant kTlsVariantsStreamArr[] = {ssl_variant_stream}; +::testing::internal::ParamGenerator + TlsConnectTestBase::kTlsVariantsStream = + ::testing::ValuesIn(kTlsVariantsStreamArr); +static const SSLProtocolVariant kTlsVariantsDatagramArr[] = { + ssl_variant_datagram}; +::testing::internal::ParamGenerator + TlsConnectTestBase::kTlsVariantsDatagram = + ::testing::ValuesIn(kTlsVariantsDatagramArr); +static const SSLProtocolVariant kTlsVariantsAllArr[] = {ssl_variant_stream, + ssl_variant_datagram}; +::testing::internal::ParamGenerator + TlsConnectTestBase::kTlsVariantsAll = + ::testing::ValuesIn(kTlsVariantsAllArr); static const uint16_t kTlsV10Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0}; ::testing::internal::ParamGenerator TlsConnectTestBase::kTlsV10 = @@ -99,30 +103,29 @@ std::string VersionString(uint16_t version) { } } -TlsConnectTestBase::TlsConnectTestBase(Mode mode, uint16_t version) - : mode_(mode), - client_(new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, mode_)), - server_(new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, mode_)), +TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant, + uint16_t version) + : variant_(variant), + client_(new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_)), + server_(new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_)), client_model_(nullptr), server_model_(nullptr), version_(version), expected_resumption_mode_(RESUME_NONE), session_ids_(), expect_extended_master_secret_(false), - expect_early_data_accepted_(false) { + expect_early_data_accepted_(false), + skip_version_checks_(false) { std::string v; - if (mode_ == DGRAM && version_ == SSL_LIBRARY_VERSION_TLS_1_1) { + if (variant_ == ssl_variant_datagram && + version_ == SSL_LIBRARY_VERSION_TLS_1_1) { v = "1.0"; } else { v = VersionString(version_); } - std::cerr << "Version: " << mode_ << " " << v << std::endl; + std::cerr << "Version: " << variant_ << " " << v << std::endl; } -TlsConnectTestBase::TlsConnectTestBase(const std::string& mode, - uint16_t version) - : TlsConnectTestBase(TlsConnectTestBase::ToMode(mode), version) {} - TlsConnectTestBase::~TlsConnectTestBase() {} // Check the group of each of the supported groups @@ -173,18 +176,15 @@ void TlsConnectTestBase::ClearServerCache() { void TlsConnectTestBase::SetUp() { SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); SSLInt_ClearSessionTicketKey(); + SSLInt_SetTicketLifetime(30); + SSLInt_SetMaxEarlyDataSize(1024); ClearStats(); Init(); } void TlsConnectTestBase::TearDown() { - delete client_; - delete server_; - if (client_model_) { - ASSERT_NE(server_model_, nullptr); - delete client_model_; - delete server_model_; - } + client_ = nullptr; + server_ = nullptr; SSL_ClearSessionCache(); SSLInt_ClearSessionTicketKey(); @@ -192,9 +192,6 @@ void TlsConnectTestBase::TearDown() { } void TlsConnectTestBase::Init() { - EXPECT_TRUE(client_->Init()); - EXPECT_TRUE(server_->Init()); - client_->SetPeer(server_); server_->SetPeer(client_); @@ -212,11 +209,12 @@ void TlsConnectTestBase::Reset() { void TlsConnectTestBase::Reset(const std::string& server_name, const std::string& client_name) { - delete client_; - delete server_; - - client_ = new TlsAgent(client_name, TlsAgent::CLIENT, mode_); - server_ = new TlsAgent(server_name, TlsAgent::SERVER, mode_); + client_.reset(new TlsAgent(client_name, TlsAgent::CLIENT, variant_)); + server_.reset(new TlsAgent(server_name, TlsAgent::SERVER, variant_)); + if (skip_version_checks_) { + client_->SkipVersionChecks(); + server_->SkipVersionChecks(); + } Init(); } @@ -276,10 +274,12 @@ void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) { } void TlsConnectTestBase::CheckConnected() { - // Check the version is as expected EXPECT_EQ(client_->version(), server_->version()); - EXPECT_EQ(std::min(client_->max_version(), server_->max_version()), - client_->version()); + if (!skip_version_checks_) { + // Check the version is as expected + EXPECT_EQ(std::min(client_->max_version(), server_->max_version()), + client_->version()); + } EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); @@ -345,6 +345,13 @@ void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, scheme = ssl_sig_none; break; case ssl_auth_rsa_sign: + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_2) { + scheme = ssl_sig_rsa_pss_sha256; + } else { + scheme = ssl_sig_rsa_pkcs1_sha256; + } + break; + case ssl_auth_rsa_pss: scheme = ssl_sig_rsa_pss_sha256; break; case ssl_auth_ecdsa: @@ -373,7 +380,36 @@ void TlsConnectTestBase::ConnectExpectFail() { ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state()); } +void TlsConnectTestBase::ExpectAlert(std::shared_ptr& sender, + uint8_t alert) { + EnsureTlsSetup(); + auto receiver = (sender == client_) ? server_ : client_; + sender->ExpectSendAlert(alert); + receiver->ExpectReceiveAlert(alert); +} + +void TlsConnectTestBase::ConnectExpectAlert(std::shared_ptr& sender, + uint8_t alert) { + ExpectAlert(sender, alert); + ConnectExpectFail(); +} + +void TlsConnectTestBase::ConnectExpectFailOneSide(TlsAgent::Role failing_side) { + server_->StartConnect(); + client_->StartConnect(); + client_->SetServerKeyBits(server_->server_key_bits()); + client_->Handshake(); + server_->Handshake(); + + auto failing_agent = server_; + if (failing_side == TlsAgent::CLIENT) { + failing_agent = client_; + } + ASSERT_TRUE_WAIT(failing_agent->state() == TlsAgent::STATE_ERROR, 5000); +} + void TlsConnectTestBase::ConfigureVersion(uint16_t version) { + version_ = version; client_->SetVersionRange(version, version); server_->SetVersionRange(version, version); } @@ -424,10 +460,16 @@ void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client, client_->ConfigureSessionCache(client); server_->ConfigureSessionCache(server); if ((server & RESUME_TICKET) != 0) { - // This is an abomination. NSS encrypts session tickets with the server's - // RSA public key. That means we need the server to have an RSA certificate - // even if it won't be used for the connection. - server_->ConfigServerCert(TlsAgent::kServerRsaDecrypt); + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey privKey; + ASSERT_TRUE(TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, + &privKey)); + + ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get())); + ASSERT_TRUE(pubKey); + + EXPECT_EQ(SECSuccess, + SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get())); } } @@ -472,13 +514,15 @@ void TlsConnectTestBase::EnsureModelSockets() { // Make sure models agents are available. if (!client_model_) { ASSERT_EQ(server_model_, nullptr); - client_model_ = new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, mode_); - server_model_ = new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, mode_); + client_model_.reset( + new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_)); + server_model_.reset( + new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_)); + if (skip_version_checks_) { + client_model_->SkipVersionChecks(); + server_model_->SkipVersionChecks(); + } } - - // Initialise agents. - ASSERT_TRUE(client_model_->Init()); - ASSERT_TRUE(server_model_->Init()); } void TlsConnectTestBase::CheckAlpn(const std::string& val) { @@ -540,6 +584,10 @@ void TlsConnectTestBase::ZeroRttSendReceive( const char* k0RttData = "ABCDEF"; const PRInt32 k0RttDataLen = static_cast(strlen(k0RttData)); + if (expect_writable && expect_readable) { + ExpectAlert(client_, kTlsAlertEndOfEarlyData); + } + client_->Handshake(); // Send ClientHello. if (post_clienthello_check) { if (!post_clienthello_check()) return; @@ -599,6 +647,12 @@ void TlsConnectTestBase::DisableECDHEServerKeyReuse() { server_->DisableECDHEServerKeyReuse(); } +void TlsConnectTestBase::SkipVersionChecks() { + skip_version_checks_ = true; + client_->SkipVersionChecks(); + server_->SkipVersionChecks(); +} + TlsConnectGeneric::TlsConnectGeneric() : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} @@ -616,16 +670,17 @@ TlsConnectTls13::TlsConnectTls13() void TlsKeyExchangeTest::EnsureKeyShareSetup() { EnsureTlsSetup(); - groups_capture_ = new TlsExtensionCapture(ssl_supported_groups_xtn); - shares_capture_ = new TlsExtensionCapture(ssl_tls13_key_share_xtn); - shares_capture2_ = new TlsExtensionCapture(ssl_tls13_key_share_xtn, true); - std::vector captures; - captures.push_back(groups_capture_); - captures.push_back(shares_capture_); - captures.push_back(shares_capture2_); - client_->SetPacketFilter(new ChainedPacketFilter(captures)); - capture_hrr_ = - new TlsInspectorRecordHandshakeMessage(kTlsHandshakeHelloRetryRequest); + groups_capture_ = + std::make_shared(ssl_supported_groups_xtn); + shares_capture_ = + std::make_shared(ssl_tls13_key_share_xtn); + shares_capture2_ = + std::make_shared(ssl_tls13_key_share_xtn, true); + std::vector> captures = { + groups_capture_, shares_capture_, shares_capture2_}; + client_->SetPacketFilter(std::make_shared(captures)); + capture_hrr_ = std::make_shared( + kTlsHandshakeHelloRetryRequest); server_->SetPacketFilter(capture_hrr_); } diff --git a/nss/gtests/ssl_gtest/tls_connect.h b/nss/gtests/ssl_gtest/tls_connect.h index aa4a32d..73e8dc8 100644 --- a/nss/gtests/ssl_gtest/tls_connect.h +++ b/nss/gtests/ssl_gtest/tls_connect.h @@ -25,9 +25,12 @@ extern std::string VersionString(uint16_t version); // A generic TLS connection test base. class TlsConnectTestBase : public ::testing::Test { public: - static ::testing::internal::ParamGenerator kTlsModesStream; - static ::testing::internal::ParamGenerator kTlsModesDatagram; - static ::testing::internal::ParamGenerator kTlsModesAll; + static ::testing::internal::ParamGenerator + kTlsVariantsStream; + static ::testing::internal::ParamGenerator + kTlsVariantsDatagram; + static ::testing::internal::ParamGenerator + kTlsVariantsAll; static ::testing::internal::ParamGenerator kTlsV10; static ::testing::internal::ParamGenerator kTlsV11; static ::testing::internal::ParamGenerator kTlsV12; @@ -39,8 +42,7 @@ class TlsConnectTestBase : public ::testing::Test { static ::testing::internal::ParamGenerator kTlsV12Plus; static ::testing::internal::ParamGenerator kTlsVAll; - TlsConnectTestBase(Mode mode, uint16_t version); - TlsConnectTestBase(const std::string& mode, uint16_t version); + TlsConnectTestBase(SSLProtocolVariant variant, uint16_t version); virtual ~TlsConnectTestBase(); void SetUp(); @@ -68,6 +70,9 @@ class TlsConnectTestBase : public ::testing::Test { void CheckConnected(); // Connect and expect it to fail. void ConnectExpectFail(); + void ExpectAlert(std::shared_ptr& sender, uint8_t alert); + void ConnectExpectAlert(std::shared_ptr& sender, uint8_t alert); + void ConnectExpectFailOneSide(TlsAgent::Role failingSide); void ConnectWithCipherSuite(uint16_t cipher_suite); // Check that the keys used in the handshake match expectations. void CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group, @@ -108,13 +113,14 @@ class TlsConnectTestBase : public ::testing::Test { void ExpectExtendedMasterSecret(bool expected); void ExpectEarlyDataAccepted(bool expected); void DisableECDHEServerKeyReuse(); + void SkipVersionChecks(); protected: - Mode mode_; - TlsAgent* client_; - TlsAgent* server_; - TlsAgent* client_model_; - TlsAgent* server_model_; + SSLProtocolVariant variant_; + std::shared_ptr client_; + std::shared_ptr server_; + std::unique_ptr client_model_; + std::unique_ptr server_model_; uint16_t version_; SessionResumptionMode expected_resumption_mode_; std::vector> session_ids_; @@ -126,16 +132,13 @@ class TlsConnectTestBase : public ::testing::Test { const uint8_t alpn_dummy_val_[4] = {0x01, 0x62, 0x01, 0x61}; private: - static inline Mode ToMode(const std::string& str) { - return str == "TLS" ? STREAM : DGRAM; - } - void CheckResumption(SessionResumptionMode expected); void CheckExtendedMasterSecret(); void CheckEarlyDataAccepted(); bool expect_extended_master_secret_; bool expect_early_data_accepted_; + bool skip_version_checks_; // Track groups and make sure that there are no duplicates. class DuplicateGroupChecker { @@ -154,20 +157,20 @@ class TlsConnectTestBase : public ::testing::Test { // A non-parametrized TLS test base. class TlsConnectTest : public TlsConnectTestBase { public: - TlsConnectTest() : TlsConnectTestBase(STREAM, 0) {} + TlsConnectTest() : TlsConnectTestBase(ssl_variant_stream, 0) {} }; // A non-parametrized DTLS-only test base. class DtlsConnectTest : public TlsConnectTestBase { public: - DtlsConnectTest() : TlsConnectTestBase(DGRAM, 0) {} + DtlsConnectTest() : TlsConnectTestBase(ssl_variant_datagram, 0) {} }; // A TLS-only test base. class TlsConnectStream : public TlsConnectTestBase, public ::testing::WithParamInterface { public: - TlsConnectStream() : TlsConnectTestBase(STREAM, GetParam()) {} + TlsConnectStream() : TlsConnectTestBase(ssl_variant_stream, GetParam()) {} }; // A TLS-only test base for tests before 1.3 @@ -177,30 +180,30 @@ class TlsConnectStreamPre13 : public TlsConnectStream {}; class TlsConnectDatagram : public TlsConnectTestBase, public ::testing::WithParamInterface { public: - TlsConnectDatagram() : TlsConnectTestBase(DGRAM, GetParam()) {} + TlsConnectDatagram() : TlsConnectTestBase(ssl_variant_datagram, GetParam()) {} }; -// A generic test class that can be either STREAM or DGRAM and a single version -// of TLS. This is configured in ssl_loopback_unittest.cc. All uses of this -// should use TEST_P(). -class TlsConnectGeneric - : public TlsConnectTestBase, - public ::testing::WithParamInterface> { +// A generic test class that can be either stream or datagram and a single +// version of TLS. This is configured in ssl_loopback_unittest.cc. +class TlsConnectGeneric : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple> { public: TlsConnectGeneric(); }; // A Pre TLS 1.2 generic test. -class TlsConnectPre12 - : public TlsConnectTestBase, - public ::testing::WithParamInterface> { +class TlsConnectPre12 : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple> { public: TlsConnectPre12(); }; // A TLS 1.2 only generic test. -class TlsConnectTls12 : public TlsConnectTestBase, - public ::testing::WithParamInterface { +class TlsConnectTls12 + : public TlsConnectTestBase, + public ::testing::WithParamInterface { public: TlsConnectTls12(); }; @@ -209,20 +212,21 @@ class TlsConnectTls12 : public TlsConnectTestBase, class TlsConnectStreamTls12 : public TlsConnectTestBase { public: TlsConnectStreamTls12() - : TlsConnectTestBase(STREAM, SSL_LIBRARY_VERSION_TLS_1_2) {} + : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_2) {} }; // A TLS 1.2+ generic test. -class TlsConnectTls12Plus - : public TlsConnectTestBase, - public ::testing::WithParamInterface> { +class TlsConnectTls12Plus : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple> { public: TlsConnectTls12Plus(); }; // A TLS 1.3 only generic test. -class TlsConnectTls13 : public TlsConnectTestBase, - public ::testing::WithParamInterface { +class TlsConnectTls13 + : public TlsConnectTestBase, + public ::testing::WithParamInterface { public: TlsConnectTls13(); }; @@ -231,13 +235,13 @@ class TlsConnectTls13 : public TlsConnectTestBase, class TlsConnectStreamTls13 : public TlsConnectTestBase { public: TlsConnectStreamTls13() - : TlsConnectTestBase(STREAM, SSL_LIBRARY_VERSION_TLS_1_3) {} + : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_3) {} }; class TlsConnectDatagram13 : public TlsConnectTestBase { public: TlsConnectDatagram13() - : TlsConnectTestBase(DGRAM, SSL_LIBRARY_VERSION_TLS_1_3) {} + : TlsConnectTestBase(ssl_variant_datagram, SSL_LIBRARY_VERSION_TLS_1_3) {} }; // A variant that is used only with Pre13. @@ -245,10 +249,10 @@ class TlsConnectGenericPre13 : public TlsConnectGeneric {}; class TlsKeyExchangeTest : public TlsConnectGeneric { protected: - TlsExtensionCapture* groups_capture_; - TlsExtensionCapture* shares_capture_; - TlsExtensionCapture* shares_capture2_; - TlsInspectorRecordHandshakeMessage* capture_hrr_; + std::shared_ptr groups_capture_; + std::shared_ptr shares_capture_; + std::shared_ptr shares_capture2_; + std::shared_ptr capture_hrr_; void EnsureKeyShareSetup(); void ConfigNamedGroups(const std::vector& groups); diff --git a/nss/gtests/ssl_gtest/tls_filter.cc b/nss/gtests/ssl_gtest/tls_filter.cc index 4f7d195..76d9aaa 100644 --- a/nss/gtests/ssl_gtest/tls_filter.cc +++ b/nss/gtests/ssl_gtest/tls_filter.cc @@ -15,9 +15,62 @@ extern "C" { #include #include "gtest_utils.h" #include "tls_agent.h" +#include "tls_filter.h" +#include "tls_protect.h" namespace nss_test { +void TlsVersioned::WriteStream(std::ostream& stream) const { + stream << (is_dtls() ? "DTLS " : "TLS "); + switch (version()) { + case 0: + stream << "(no version)"; + break; + case SSL_LIBRARY_VERSION_TLS_1_0: + stream << "1.0"; + break; + case SSL_LIBRARY_VERSION_DTLS_1_0_WIRE: + case SSL_LIBRARY_VERSION_TLS_1_1: + stream << (is_dtls() ? "1.0" : "1.1"); + break; + case SSL_LIBRARY_VERSION_DTLS_1_2_WIRE: + case SSL_LIBRARY_VERSION_TLS_1_2: + stream << "1.2"; + break; + case SSL_LIBRARY_VERSION_TLS_1_3: + stream << "1.3"; + break; + default: + stream << "Invalid version: " << version(); + break; + } +} + +void TlsRecordFilter::EnableDecryption() { + SSLInt_SetCipherSpecChangeFunc(agent()->ssl_fd(), CipherSpecChanged, + (void*)this); +} + +void TlsRecordFilter::CipherSpecChanged(void* arg, PRBool sending, + ssl3CipherSpec* newSpec) { + TlsRecordFilter* self = static_cast(arg); + PRBool isServer = self->agent()->role() == TlsAgent::SERVER; + + if (g_ssl_gtest_verbose) { + std::cerr << "Cipher spec changed. Role=" + << (isServer ? "server" : "client") + << " direction=" << (sending ? "send" : "receive") << std::endl; + } + if (!sending) return; + + self->cipher_spec_.reset(new TlsCipherSpec()); + bool ret = + self->cipher_spec_->Init(SSLInt_CipherSpecToAlgorithm(isServer, newSpec), + SSLInt_CipherSpecToKey(isServer, newSpec), + SSLInt_CipherSpecToIv(isServer, newSpec)); + EXPECT_EQ(true, ret); +} + PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input, DataBuffer* output) { bool changed = false; @@ -25,10 +78,13 @@ PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input, output->Allocate(input.len()); TlsParser parser(input); + while (parser.remaining()) { - RecordHeader header; + TlsRecordHeader header; DataBuffer record; + if (!header.Parse(&parser, &record)) { + ADD_FAILURE() << "not a valid record"; return KEEP; } @@ -49,12 +105,21 @@ PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input, return KEEP; } -PacketFilter::Action TlsRecordFilter::FilterRecord(const RecordHeader& header, - const DataBuffer& record, - size_t* offset, - DataBuffer* output) { +PacketFilter::Action TlsRecordFilter::FilterRecord( + const TlsRecordHeader& header, const DataBuffer& record, size_t* offset, + DataBuffer* output) { DataBuffer filtered; - PacketFilter::Action action = FilterRecord(header, record, &filtered); + uint8_t inner_content_type; + DataBuffer plaintext; + + if (!Unprotect(header, record, &inner_content_type, &plaintext)) { + return KEEP; + } + + TlsRecordHeader real_header = {header.version(), inner_content_type, + header.sequence_number()}; + + PacketFilter::Action action = FilterRecord(real_header, plaintext, &filtered); if (action == KEEP) { return KEEP; } @@ -64,19 +129,21 @@ PacketFilter::Action TlsRecordFilter::FilterRecord(const RecordHeader& header, return DROP; } - const DataBuffer* source = &record; - if (action == CHANGE) { - EXPECT_GT(0x10000U, filtered.len()); - std::cerr << "record old: " << record << std::endl; - std::cerr << "record new: " << filtered << std::endl; - source = &filtered; - } + EXPECT_GT(0x10000U, filtered.len()); + std::cerr << "record old: " << plaintext << std::endl; + std::cerr << "record new: " << filtered << std::endl; - *offset = header.Write(output, *offset, *source); + DataBuffer ciphertext; + bool rv = Protect(header, inner_content_type, filtered, &ciphertext); + EXPECT_TRUE(rv); + if (!rv) { + return KEEP; + } + *offset = header.Write(output, *offset, ciphertext); return CHANGE; } -bool TlsRecordFilter::RecordHeader::Parse(TlsParser* parser, DataBuffer* body) { +bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) { if (!parser->Read(&content_type_)) { return false; } @@ -102,8 +169,8 @@ bool TlsRecordFilter::RecordHeader::Parse(TlsParser* parser, DataBuffer* body) { return parser->ReadVariable(body, 2); } -size_t TlsRecordFilter::RecordHeader::Write(DataBuffer* buffer, size_t offset, - const DataBuffer& body) const { +size_t TlsRecordHeader::Write(DataBuffer* buffer, size_t offset, + const DataBuffer& body) const { offset = buffer->Write(offset, content_type_, 1); offset = buffer->Write(offset, version_, 2); if (is_dtls()) { @@ -116,8 +183,48 @@ size_t TlsRecordFilter::RecordHeader::Write(DataBuffer* buffer, size_t offset, return offset; } +bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header, + const DataBuffer& ciphertext, + uint8_t* inner_content_type, + DataBuffer* plaintext) { + if (!cipher_spec_ || header.content_type() != kTlsApplicationDataType) { + *inner_content_type = header.content_type(); + *plaintext = ciphertext; + return true; + } + + if (!cipher_spec_->Unprotect(header, ciphertext, plaintext)) return false; + + size_t len = plaintext->len(); + while (len > 0 && !plaintext->data()[len - 1]) { + --len; + } + if (!len) { + // Bogus padding. + return false; + } + + *inner_content_type = plaintext->data()[len - 1]; + plaintext->Truncate(len - 1); + + return true; +} + +bool TlsRecordFilter::Protect(const TlsRecordHeader& header, + uint8_t inner_content_type, + const DataBuffer& plaintext, + DataBuffer* ciphertext) { + if (!cipher_spec_ || header.content_type() != kTlsApplicationDataType) { + *ciphertext = plaintext; + return true; + } + DataBuffer padded = plaintext; + padded.Write(padded.len(), inner_content_type, 1); + return cipher_spec_->Protect(header, padded, ciphertext); +} + PacketFilter::Action TlsHandshakeFilter::FilterRecord( - const RecordHeader& record_header, const DataBuffer& input, + const TlsRecordHeader& record_header, const DataBuffer& input, DataBuffer* output) { // Check that the first byte is as requested. if (record_header.content_type() != kTlsHandshakeType) { @@ -159,9 +266,8 @@ PacketFilter::Action TlsHandshakeFilter::FilterRecord( return changed ? (offset ? CHANGE : DROP) : KEEP; } -bool TlsHandshakeFilter::HandshakeHeader::ReadLength(TlsParser* parser, - const RecordHeader& header, - uint32_t* length) { +bool TlsHandshakeFilter::HandshakeHeader::ReadLength( + TlsParser* parser, const TlsRecordHeader& header, uint32_t* length) { if (!parser->Read(length, 3)) { return false; // malformed } @@ -192,7 +298,7 @@ bool TlsHandshakeFilter::HandshakeHeader::ReadLength(TlsParser* parser, } bool TlsHandshakeFilter::HandshakeHeader::Parse( - TlsParser* parser, const RecordHeader& record_header, DataBuffer* body) { + TlsParser* parser, const TlsRecordHeader& record_header, DataBuffer* body) { version_ = record_header.version(); if (!parser->Read(&handshake_type_)) { return false; // malformed @@ -205,15 +311,28 @@ bool TlsHandshakeFilter::HandshakeHeader::Parse( return parser->Read(body, length); } -size_t TlsHandshakeFilter::HandshakeHeader::Write( - DataBuffer* buffer, size_t offset, const DataBuffer& body) const { +size_t TlsHandshakeFilter::HandshakeHeader::WriteFragment( + DataBuffer* buffer, size_t offset, const DataBuffer& body, + size_t fragment_offset, size_t fragment_length) const { + EXPECT_TRUE(is_dtls()); + EXPECT_GE(body.len(), fragment_offset + fragment_length); offset = buffer->Write(offset, handshake_type(), 1); offset = buffer->Write(offset, body.len(), 3); + offset = buffer->Write(offset, message_seq_, 2); + offset = buffer->Write(offset, fragment_offset, 3); + offset = buffer->Write(offset, fragment_length, 3); + offset = + buffer->Write(offset, body.data() + fragment_offset, fragment_length); + return offset; +} + +size_t TlsHandshakeFilter::HandshakeHeader::Write( + DataBuffer* buffer, size_t offset, const DataBuffer& body) const { if (is_dtls()) { - offset = buffer->Write(offset, message_seq_, 2); - offset = buffer->Write(offset, 0U, 3); // fragment_offset - offset = buffer->Write(offset, body.len(), 3); + return WriteFragment(buffer, offset, body, 0U, body.len()); } + offset = buffer->Write(offset, handshake_type(), 1); + offset = buffer->Write(offset, body.len(), 3); offset = buffer->Write(offset, body); return offset; } @@ -244,42 +363,12 @@ PacketFilter::Action TlsInspectorReplaceHandshakeMessage::FilterHandshake( } PacketFilter::Action TlsConversationRecorder::FilterRecord( - const RecordHeader& header, const DataBuffer& input, DataBuffer* output) { + const TlsRecordHeader& header, const DataBuffer& input, + DataBuffer* output) { buffer_.Append(input); return KEEP; } -PacketFilter::Action TlsAlertRecorder::FilterRecord(const RecordHeader& header, - const DataBuffer& input, - DataBuffer* output) { - if (level_ == kTlsAlertFatal) { // already fatal - return KEEP; - } - if (header.content_type() != kTlsAlertType) { - return KEEP; - } - - std::cerr << "Alert: " << input << std::endl; - - TlsParser parser(input); - uint8_t lvl; - if (!parser.Read(&lvl)) { - return KEEP; - } - if (lvl == kTlsAlertWarning) { // not strong enough - return KEEP; - } - level_ = lvl; - (void)parser.Read(&description_); - return KEEP; -} - -ChainedPacketFilter::~ChainedPacketFilter() { - for (auto it = filters_.begin(); it != filters_.end(); ++it) { - delete *it; - } -} - PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input, DataBuffer* output) { DataBuffer in(input); @@ -297,28 +386,7 @@ PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input, return changed ? CHANGE : KEEP; } -PacketFilter::Action TlsExtensionFilter::FilterHandshake( - const HandshakeHeader& header, const DataBuffer& input, - DataBuffer* output) { - if (header.handshake_type() == kTlsHandshakeClientHello) { - TlsParser parser(input); - if (!FindClientHelloExtensions(&parser, header)) { - return KEEP; - } - return FilterExtensions(&parser, input, output); - } - if (header.handshake_type() == kTlsHandshakeServerHello) { - TlsParser parser(input); - if (!FindServerHelloExtensions(&parser)) { - return KEEP; - } - return FilterExtensions(&parser, input, output); - } - return KEEP; -} - -bool TlsExtensionFilter::FindClientHelloExtensions(TlsParser* parser, - const Versioned& header) { +bool FindClientHelloExtensions(TlsParser* parser, const TlsVersioned& header) { if (!parser->Skip(2 + 32)) { // version + random return false; } @@ -337,7 +405,7 @@ bool TlsExtensionFilter::FindClientHelloExtensions(TlsParser* parser, return true; } -bool TlsExtensionFilter::FindServerHelloExtensions(TlsParser* parser) { +bool FindServerHelloExtensions(TlsParser* parser, const TlsVersioned& header) { uint32_t vtmp; if (!parser->Read(&vtmp, 2)) { return false; @@ -362,6 +430,92 @@ bool TlsExtensionFilter::FindServerHelloExtensions(TlsParser* parser) { return true; } +static bool FindHelloRetryExtensions(TlsParser* parser, + const TlsVersioned& header) { + // TODO for -19 add cipher suite + if (!parser->Skip(2)) { // version + return false; + } + return true; +} + +bool FindEncryptedExtensions(TlsParser* parser, const TlsVersioned& header) { + return true; +} + +static bool FindCertReqExtensions(TlsParser* parser, + const TlsVersioned& header) { + if (!parser->SkipVariable(1)) { // request context + return false; + } + // TODO remove the next two for -19 + if (!parser->SkipVariable(2)) { // signature_algorithms + return false; + } + if (!parser->SkipVariable(2)) { // certificate_authorities + return false; + } + return true; +} + +// Only look at the EE cert for this one. +static bool FindCertificateExtensions(TlsParser* parser, + const TlsVersioned& header) { + if (!parser->SkipVariable(1)) { // request context + return false; + } + if (!parser->Skip(3)) { // length of certificate list + return false; + } + if (!parser->SkipVariable(3)) { // ASN1Cert + return false; + } + return true; +} + +static bool FindNewSessionTicketExtensions(TlsParser* parser, + const TlsVersioned& header) { + if (!parser->Skip(8)) { // lifetime, age add + return false; + } + if (!parser->SkipVariable(2)) { // ticket + return false; + } + return true; +} + +static const std::map kExtensionFinders = { + {kTlsHandshakeClientHello, FindClientHelloExtensions}, + {kTlsHandshakeServerHello, FindServerHelloExtensions}, + {kTlsHandshakeHelloRetryRequest, FindHelloRetryExtensions}, + {kTlsHandshakeEncryptedExtensions, FindEncryptedExtensions}, + {kTlsHandshakeCertificateRequest, FindCertReqExtensions}, + {kTlsHandshakeCertificate, FindCertificateExtensions}, + {kTlsHandshakeNewSessionTicket, FindNewSessionTicketExtensions}}; + +bool TlsExtensionFilter::FindExtensions(TlsParser* parser, + const HandshakeHeader& header) { + auto it = kExtensionFinders.find(header.handshake_type()); + if (it == kExtensionFinders.end()) { + return false; + } + return (it->second)(parser, header); +} + +PacketFilter::Action TlsExtensionFilter::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + if (handshake_types_.count(header.handshake_type()) == 0) { + return KEEP; + } + + TlsParser parser(input); + if (!FindExtensions(&parser, header)) { + return KEEP; + } + return FilterExtensions(&parser, input, output); +} + PacketFilter::Action TlsExtensionFilter::FilterExtensions( TlsParser* parser, const DataBuffer& input, DataBuffer* output) { size_t length_offset = parser->consumed(); @@ -456,14 +610,14 @@ PacketFilter::Action TlsExtensionDropper::FilterExtension( return KEEP; } -PacketFilter::Action AfterRecordN::FilterRecord(const RecordHeader& header, +PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header, const DataBuffer& body, DataBuffer* out) { if (counter_++ == record_) { DataBuffer buf; header.Write(&buf, 0, body); - src_->SendDirect(buf); - dest_->Handshake(); + src_.lock()->SendDirect(buf); + dest_.lock()->Handshake(); func_(); return DROP; } @@ -476,7 +630,7 @@ PacketFilter::Action TlsInspectorClientHelloVersionChanger::FilterHandshake( DataBuffer* output) { if (header.handshake_type() == kTlsHandshakeClientKeyExchange) { EXPECT_EQ(SECSuccess, - SSLInt_IncrementClientHandshakeVersion(server_->ssl_fd())); + SSLInt_IncrementClientHandshakeVersion(server_.lock()->ssl_fd())); } return KEEP; } diff --git a/nss/gtests/ssl_gtest/tls_filter.h b/nss/gtests/ssl_gtest/tls_filter.h index fa2e387..e4030e2 100644 --- a/nss/gtests/ssl_gtest/tls_filter.h +++ b/nss/gtests/ssl_gtest/tls_filter.h @@ -9,17 +9,67 @@ #include #include +#include #include #include "test_io.h" #include "tls_parser.h" +#include "tls_protect.h" + +extern "C" { +#include "libssl_internals.h" +} namespace nss_test { +class TlsCipherSpec; +class TlsAgent; + +class TlsVersioned { + public: + TlsVersioned() : version_(0) {} + explicit TlsVersioned(uint16_t version) : version_(version) {} + + bool is_dtls() const { return IsDtls(version_); } + uint16_t version() const { return version_; } + + void WriteStream(std::ostream& stream) const; + + protected: + uint16_t version_; +}; + +class TlsRecordHeader : public TlsVersioned { + public: + TlsRecordHeader() : TlsVersioned(), content_type_(0), sequence_number_(0) {} + TlsRecordHeader(uint16_t version, uint8_t content_type, + uint64_t sequence_number) + : TlsVersioned(version), + content_type_(content_type), + sequence_number_(sequence_number) {} + + uint8_t content_type() const { return content_type_; } + uint64_t sequence_number() const { return sequence_number_; } + size_t header_length() const { return is_dtls() ? 11 : 3; } + + // Parse the header; return true if successful; body in an outparam if OK. + bool Parse(TlsParser* parser, DataBuffer* body); + // Write the header and body to a buffer at the given offset. + // Return the offset of the end of the write. + size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const; + + private: + uint8_t content_type_; + uint64_t sequence_number_; +}; + // Abstract filter that operates on entire (D)TLS records. class TlsRecordFilter : public PacketFilter { public: - TlsRecordFilter() : count_(0) {} + TlsRecordFilter() : agent_(nullptr), count_(0), cipher_spec_() {} + + void SetAgent(const TlsAgent* agent) { agent_ = agent; } + const TlsAgent* agent() const { return agent_; } // External interface. Overrides PacketFilter. PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output); @@ -27,42 +77,14 @@ class TlsRecordFilter : public PacketFilter { // Report how many packets were altered by the filter. size_t filtered_packets() const { return count_; } - class Versioned { - public: - Versioned() : version_(0) {} - explicit Versioned(uint16_t version) : version_(version) {} - - bool is_dtls() const { return IsDtls(version_); } - uint16_t version() const { return version_; } - - protected: - uint16_t version_; - }; - - class RecordHeader : public Versioned { - public: - RecordHeader() : Versioned(), content_type_(0), sequence_number_(0) {} - RecordHeader(uint16_t version, uint8_t content_type, - uint64_t sequence_number) - : Versioned(version), - content_type_(content_type), - sequence_number_(sequence_number) {} - - uint8_t content_type() const { return content_type_; } - uint64_t sequence_number() const { return sequence_number_; } - size_t header_length() const { return is_dtls() ? 11 : 3; } - - // Parse the header; return true if successful; body in an outparam if OK. - bool Parse(TlsParser* parser, DataBuffer* body); - // Write the header and body to a buffer at the given offset. - // Return the offset of the end of the write. - size_t Write(DataBuffer* buffer, size_t offset, - const DataBuffer& body) const; - - private: - uint8_t content_type_; - uint64_t sequence_number_; - }; + // Enable decryption. This only works properly for TLS 1.3 and above. + // Enabling it for lower version tests will cause undefined + // behavior. + void EnableDecryption(); + bool Unprotect(const TlsRecordHeader& header, const DataBuffer& cipherText, + uint8_t* inner_content_type, DataBuffer* plaintext); + bool Protect(const TlsRecordHeader& header, uint8_t inner_content_type, + const DataBuffer& plaintext, DataBuffer* ciphertext); protected: // There are two filter functions which can be overriden. Both are @@ -72,7 +94,7 @@ class TlsRecordFilter : public PacketFilter { // just lets you change the record contents. By default, the // outer one calls the inner one, so if you override the outer // one, the inner one is never called unless you call it yourself. - virtual PacketFilter::Action FilterRecord(const RecordHeader& header, + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& record, size_t* offset, DataBuffer* output); @@ -80,16 +102,49 @@ class TlsRecordFilter : public PacketFilter { // sequence number (which is zero for TLS), plus the existing record payload. // It returns an action (KEEP, CHANGE, DROP). It writes to the `changed` // outparam with the new record contents if it chooses to CHANGE the record. - virtual PacketFilter::Action FilterRecord(const RecordHeader& header, + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& data, DataBuffer* changed) { return KEEP; } private: + static void CipherSpecChanged(void* arg, PRBool sending, + ssl3CipherSpec* newSpec); + + const TlsAgent* agent_; size_t count_; + std::unique_ptr cipher_spec_; }; +inline std::ostream& operator<<(std::ostream& stream, TlsVersioned v) { + v.WriteStream(stream); + return stream; +} + +inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) { + hdr.WriteStream(stream); + stream << ' '; + switch (hdr.content_type()) { + case kTlsChangeCipherSpecType: + stream << "CCS"; + break; + case kTlsAlertType: + stream << "Alert"; + break; + case kTlsHandshakeType: + stream << "Handshake"; + break; + case kTlsApplicationDataType: + stream << "Data"; + break; + default: + stream << '<' << hdr.content_type() << '>'; + break; + } + return stream << ' ' << std::hex << hdr.sequence_number() << std::dec; +} + // Abstract filter that operates on handshake messages rather than records. // This assumes that the handshake messages are written in a block as entire // records and that they don't span records or anything crazy like that. @@ -97,20 +152,23 @@ class TlsHandshakeFilter : public TlsRecordFilter { public: TlsHandshakeFilter() {} - class HandshakeHeader : public Versioned { + class HandshakeHeader : public TlsVersioned { public: - HandshakeHeader() : Versioned(), handshake_type_(0), message_seq_(0) {} + HandshakeHeader() : TlsVersioned(), handshake_type_(0), message_seq_(0) {} uint8_t handshake_type() const { return handshake_type_; } - bool Parse(TlsParser* parser, const RecordHeader& record_header, + bool Parse(TlsParser* parser, const TlsRecordHeader& record_header, DataBuffer* body); size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const; + size_t WriteFragment(DataBuffer* buffer, size_t offset, + const DataBuffer& body, size_t fragment_offset, + size_t fragment_length) const; private: // Reads the length from the record header. // This also reads the DTLS fragment information and checks it. - bool ReadLength(TlsParser* parser, const RecordHeader& header, + bool ReadLength(TlsParser* parser, const TlsRecordHeader& header, uint32_t* length); uint8_t handshake_type_; @@ -119,7 +177,7 @@ class TlsHandshakeFilter : public TlsRecordFilter { }; protected: - virtual PacketFilter::Action FilterRecord(const RecordHeader& header, + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, DataBuffer* output); virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, @@ -167,7 +225,7 @@ class TlsConversationRecorder : public TlsRecordFilter { public: TlsConversationRecorder(DataBuffer& buffer) : buffer_(buffer) {} - virtual PacketFilter::Action FilterRecord(const RecordHeader& header, + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, DataBuffer* output); @@ -175,43 +233,39 @@ class TlsConversationRecorder : public TlsRecordFilter { DataBuffer& buffer_; }; -// Records an alert. If an alert has already been recorded, it won't save the -// new alert unless the old alert is a warning and the new one is fatal. -class TlsAlertRecorder : public TlsRecordFilter { - public: - TlsAlertRecorder() : level_(255), description_(255) {} - - virtual PacketFilter::Action FilterRecord(const RecordHeader& header, - const DataBuffer& input, - DataBuffer* output); - - uint8_t level() const { return level_; } - uint8_t description() const { return description_; } - - private: - uint8_t level_; - uint8_t description_; -}; - // Runs multiple packet filters in series. class ChainedPacketFilter : public PacketFilter { public: ChainedPacketFilter() {} - ChainedPacketFilter(const std::vector filters) + ChainedPacketFilter(const std::vector> filters) : filters_(filters.begin(), filters.end()) {} - virtual ~ChainedPacketFilter(); + virtual ~ChainedPacketFilter() {} virtual PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output); // Takes ownership of the filter. - void Add(PacketFilter* filter) { filters_.push_back(filter); } + void Add(std::shared_ptr filter) { filters_.push_back(filter); } private: - std::vector filters_; + std::vector> filters_; }; +typedef std::function + TlsExtensionFinder; + class TlsExtensionFilter : public TlsHandshakeFilter { + public: + TlsExtensionFilter() : handshake_types_() { + handshake_types_.insert(kTlsHandshakeClientHello); + handshake_types_.insert(kTlsHandshakeServerHello); + } + + TlsExtensionFilter(const std::set& types) + : handshake_types_(types) {} + + static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header); + protected: PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -221,15 +275,12 @@ class TlsExtensionFilter : public TlsHandshakeFilter { const DataBuffer& input, DataBuffer* output) = 0; - public: - static bool FindClientHelloExtensions(TlsParser* parser, - const Versioned& header); - static bool FindServerHelloExtensions(TlsParser* parser); - private: PacketFilter::Action FilterExtensions(TlsParser* parser, const DataBuffer& input, DataBuffer* output); + + std::set handshake_types_; }; class TlsExtensionCapture : public TlsExtensionFilter { @@ -280,17 +331,17 @@ typedef std::function VoidFunction; class AfterRecordN : public TlsRecordFilter { public: - AfterRecordN(TlsAgent* src, TlsAgent* dest, unsigned int record, - VoidFunction func) + AfterRecordN(std::shared_ptr& src, std::shared_ptr& dest, + unsigned int record, VoidFunction func) : src_(src), dest_(dest), record_(record), func_(func), counter_(0) {} - virtual PacketFilter::Action FilterRecord(const RecordHeader& header, + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& body, DataBuffer* out) override; private: - TlsAgent* src_; - TlsAgent* dest_; + std::weak_ptr src_; + std::weak_ptr dest_; unsigned int record_; VoidFunction func_; unsigned int counter_; @@ -300,14 +351,15 @@ class AfterRecordN : public TlsRecordFilter { // ClientHelloVersion on |server|. class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter { public: - TlsInspectorClientHelloVersionChanger(TlsAgent* server) : server_(server) {} + TlsInspectorClientHelloVersionChanger(std::shared_ptr& server) + : server_(server) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output); private: - TlsAgent* server_; + std::weak_ptr server_; }; // This class selectively drops complete writes. This relies on the fact that @@ -338,6 +390,27 @@ class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter { uint16_t version_; }; +// Damages the last byte of a handshake message. +class TlsLastByteDamager : public TlsHandshakeFilter { + public: + TlsLastByteDamager(uint8_t type) : type_(type) {} + PacketFilter::Action FilterHandshake( + const TlsHandshakeFilter::HandshakeHeader& header, + const DataBuffer& input, DataBuffer* output) override { + if (header.handshake_type() != type_) { + return KEEP; + } + + *output = input; + + output->data()[output->len() - 1]++; + return CHANGE; + } + + private: + uint8_t type_; +}; + } // namespace nss_test #endif diff --git a/nss/gtests/ssl_gtest/tls_parser.cc b/nss/gtests/ssl_gtest/tls_parser.cc deleted file mode 100644 index e4c06aa..0000000 --- a/nss/gtests/ssl_gtest/tls_parser.cc +++ /dev/null @@ -1,73 +0,0 @@ -/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ -/* vim: set ts=2 et sw=2 tw=80: */ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at http://mozilla.org/MPL/2.0/. */ - -#include "tls_parser.h" - -namespace nss_test { - -bool TlsParser::Read(uint8_t* val) { - if (remaining() < 1) { - return false; - } - *val = *ptr(); - consume(1); - return true; -} - -bool TlsParser::Read(uint32_t* val, size_t size) { - if (size > sizeof(uint32_t)) { - return false; - } - - uint32_t v = 0; - for (size_t i = 0; i < size; ++i) { - uint8_t tmp; - if (!Read(&tmp)) { - return false; - } - - v = (v << 8) | tmp; - } - - *val = v; - return true; -} - -bool TlsParser::Read(DataBuffer* val, size_t len) { - if (remaining() < len) { - return false; - } - - val->Assign(ptr(), len); - consume(len); - return true; -} - -bool TlsParser::ReadVariable(DataBuffer* val, size_t len_size) { - uint32_t len; - if (!Read(&len, len_size)) { - return false; - } - return Read(val, len); -} - -bool TlsParser::Skip(size_t len) { - if (len > remaining()) { - return false; - } - consume(len); - return true; -} - -bool TlsParser::SkipVariable(size_t len_size) { - uint32_t len; - if (!Read(&len, len_size)) { - return false; - } - return Skip(len); -} - -} // namespace nss_test diff --git a/nss/gtests/ssl_gtest/tls_parser.h b/nss/gtests/ssl_gtest/tls_parser.h deleted file mode 100644 index c79d45a..0000000 --- a/nss/gtests/ssl_gtest/tls_parser.h +++ /dev/null @@ -1,131 +0,0 @@ -/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ -/* vim: set ts=2 et sw=2 tw=80: */ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at http://mozilla.org/MPL/2.0/. */ - -#ifndef tls_parser_h_ -#define tls_parser_h_ - -#include -#include -#include -#if defined(WIN32) || defined(WIN64) -#include -#else -#include -#endif -#include "databuffer.h" - -namespace nss_test { - -const uint8_t kTlsChangeCipherSpecType = 20; -const uint8_t kTlsAlertType = 21; -const uint8_t kTlsHandshakeType = 22; -const uint8_t kTlsApplicationDataType = 23; - -const uint8_t kTlsHandshakeClientHello = 1; -const uint8_t kTlsHandshakeServerHello = 2; -const uint8_t kTlsHandshakeHelloRetryRequest = 6; -const uint8_t kTlsHandshakeEncryptedExtensions = 8; -const uint8_t kTlsHandshakeCertificate = 11; -const uint8_t kTlsHandshakeServerKeyExchange = 12; -const uint8_t kTlsHandshakeCertificateVerify = 15; -const uint8_t kTlsHandshakeClientKeyExchange = 16; -const uint8_t kTlsHandshakeFinished = 20; - -const uint8_t kTlsAlertWarning = 1; -const uint8_t kTlsAlertFatal = 2; - -const uint8_t kTlsAlertUnexpectedMessage = 10; -const uint8_t kTlsAlertBadRecordMac = 20; -const uint8_t kTlsAlertHandshakeFailure = 40; -const uint8_t kTlsAlertIllegalParameter = 47; -const uint8_t kTlsAlertDecodeError = 50; -const uint8_t kTlsAlertDecryptError = 51; -const uint8_t kTlsAlertMissingExtension = 109; -const uint8_t kTlsAlertUnsupportedExtension = 110; -const uint8_t kTlsAlertUnrecognizedName = 112; -const uint8_t kTlsAlertNoApplicationProtocol = 120; - -const uint8_t kTlsFakeChangeCipherSpec[] = { - kTlsChangeCipherSpecType, // Type - 0xfe, - 0xff, // Version - 0x00, - 0x00, - 0x00, - 0x00, - 0x00, - 0x00, - 0x00, - 0x10, // Fictitious sequence # - 0x00, - 0x01, // Length - 0x01 // Value -}; - -static const uint8_t kTls13PskKe = 0; -static const uint8_t kTls13PskDhKe = 1; -static const uint8_t kTls13PskAuth = 0; -static const uint8_t kTls13PskSignAuth = 1; - -inline bool IsDtls(uint16_t version) { return (version & 0x8000) == 0x8000; } - -inline uint16_t NormalizeTlsVersion(uint16_t version) { - if (version == 0xfeff) { - return 0x0302; // special: DTLS 1.0 == TLS 1.1 - } - if (IsDtls(version)) { - return (version ^ 0xffff) + 0x0201; - } - return version; -} - -inline uint16_t TlsVersionToDtlsVersion(uint16_t version) { - if (version == 0x0302) { - return 0xfeff; - } - if (version == 0x0304) { - return version; - } - return 0xffff - version + 0x0201; -} - -inline size_t WriteVariable(DataBuffer* target, size_t index, - const DataBuffer& buf, size_t len_size) { - index = target->Write(index, static_cast(buf.len()), len_size); - return target->Write(index, buf.data(), buf.len()); -} - -class TlsParser { - public: - TlsParser(const uint8_t* data, size_t len) : buffer_(data, len), offset_(0) {} - explicit TlsParser(const DataBuffer& buf) : buffer_(buf), offset_(0) {} - - bool Read(uint8_t* val); - // Read an integral type of specified width. - bool Read(uint32_t* val, size_t size); - // Reads len bytes into dest buffer, overwriting it. - bool Read(DataBuffer* dest, size_t len); - // Reads bytes into dest buffer, overwriting it. The number of bytes is - // determined by reading from len_size bytes from the stream first. - bool ReadVariable(DataBuffer* dest, size_t len_size); - - bool Skip(size_t len); - bool SkipVariable(size_t len_size); - - size_t consumed() const { return offset_; } - size_t remaining() const { return buffer_.len() - offset_; } - - private: - void consume(size_t len) { offset_ += len; } - const uint8_t* ptr() const { return buffer_.data() + offset_; } - - DataBuffer buffer_; - size_t offset_; -}; - -} // namespace nss_test - -#endif diff --git a/nss/gtests/ssl_gtest/tls_protect.cc b/nss/gtests/ssl_gtest/tls_protect.cc new file mode 100644 index 0000000..efcd89e --- /dev/null +++ b/nss/gtests/ssl_gtest/tls_protect.cc @@ -0,0 +1,145 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "tls_protect.h" +#include "tls_filter.h" + +namespace nss_test { + +AeadCipher::~AeadCipher() { + if (key_) { + PK11_FreeSymKey(key_); + } +} + +bool AeadCipher::Init(PK11SymKey *key, const uint8_t *iv) { + key_ = PK11_ReferenceSymKey(key); + if (!key_) return false; + + memcpy(iv_, iv, sizeof(iv_)); + return true; +} + +void AeadCipher::FormatNonce(uint64_t seq, uint8_t *nonce) { + memcpy(nonce, iv_, 12); + + for (size_t i = 0; i < 8; ++i) { + nonce[12 - (i + 1)] ^= seq & 0xff; + seq >>= 8; + } + + DataBuffer d(nonce, 12); + std::cerr << "Nonce " << d << std::endl; +} + +bool AeadCipher::AeadInner(bool decrypt, void *params, size_t param_length, + const uint8_t *in, size_t inlen, uint8_t *out, + size_t *outlen, size_t maxlen) { + SECStatus rv; + unsigned int uoutlen = 0; + SECItem param = { + siBuffer, static_cast(params), + static_cast(param_length), + }; + + if (decrypt) { + rv = PK11_Decrypt(key_, mech_, ¶m, out, &uoutlen, maxlen, in, inlen); + } else { + rv = PK11_Encrypt(key_, mech_, ¶m, out, &uoutlen, maxlen, in, inlen); + } + *outlen = (int)uoutlen; + + return rv == SECSuccess; +} + +bool AeadCipherAesGcm::Aead(bool decrypt, uint64_t seq, const uint8_t *in, + size_t inlen, uint8_t *out, size_t *outlen, + size_t maxlen) { + CK_GCM_PARAMS aeadParams; + unsigned char nonce[12]; + + memset(&aeadParams, 0, sizeof(aeadParams)); + aeadParams.pIv = nonce; + aeadParams.ulIvLen = sizeof(nonce); + aeadParams.pAAD = NULL; + aeadParams.ulAADLen = 0; + aeadParams.ulTagBits = 128; + + FormatNonce(seq, nonce); + return AeadInner(decrypt, (unsigned char *)&aeadParams, sizeof(aeadParams), + in, inlen, out, outlen, maxlen); +} + +bool AeadCipherChacha20Poly1305::Aead(bool decrypt, uint64_t seq, + const uint8_t *in, size_t inlen, + uint8_t *out, size_t *outlen, + size_t maxlen) { + CK_NSS_AEAD_PARAMS aeadParams; + unsigned char nonce[12]; + + memset(&aeadParams, 0, sizeof(aeadParams)); + aeadParams.pNonce = nonce; + aeadParams.ulNonceLen = sizeof(nonce); + aeadParams.pAAD = NULL; + aeadParams.ulAADLen = 0; + aeadParams.ulTagLen = 16; + + FormatNonce(seq, nonce); + return AeadInner(decrypt, (unsigned char *)&aeadParams, sizeof(aeadParams), + in, inlen, out, outlen, maxlen); +} + +bool TlsCipherSpec::Init(SSLCipherAlgorithm cipher, PK11SymKey *key, + const uint8_t *iv) { + switch (cipher) { + case ssl_calg_aes_gcm: + aead_.reset(new AeadCipherAesGcm()); + break; + case ssl_calg_chacha20: + aead_.reset(new AeadCipherChacha20Poly1305()); + break; + default: + return false; + } + + return aead_->Init(key, iv); +} + +bool TlsCipherSpec::Unprotect(const TlsRecordHeader &header, + const DataBuffer &ciphertext, + DataBuffer *plaintext) { + // Make space. + plaintext->Allocate(ciphertext.len()); + + size_t len; + bool ret = + aead_->Aead(true, header.sequence_number(), ciphertext.data(), + ciphertext.len(), plaintext->data(), &len, plaintext->len()); + if (!ret) return false; + + plaintext->Truncate(len); + + return true; +} + +bool TlsCipherSpec::Protect(const TlsRecordHeader &header, + const DataBuffer &plaintext, + DataBuffer *ciphertext) { + // Make a padded buffer. + + ciphertext->Allocate(plaintext.len() + + 32); // Room for any plausible auth tag + size_t len; + bool ret = + aead_->Aead(false, header.sequence_number(), plaintext.data(), + plaintext.len(), ciphertext->data(), &len, ciphertext->len()); + if (!ret) return false; + ciphertext->Truncate(len); + + return true; +} + +} // namespace nss_test diff --git a/nss/gtests/ssl_gtest/tls_protect.h b/nss/gtests/ssl_gtest/tls_protect.h new file mode 100644 index 0000000..4efbd6e --- /dev/null +++ b/nss/gtests/ssl_gtest/tls_protect.h @@ -0,0 +1,76 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef tls_protection_h_ +#define tls_protection_h_ + +#include +#include + +#include "databuffer.h" +#include "pk11pub.h" +#include "sslt.h" + +namespace nss_test { +class TlsRecordHeader; + +class AeadCipher { + public: + AeadCipher(CK_MECHANISM_TYPE mech) : mech_(mech), key_(nullptr) {} + ~AeadCipher(); + + bool Init(PK11SymKey *key, const uint8_t *iv); + virtual bool Aead(bool decrypt, uint64_t seq, const uint8_t *in, size_t inlen, + uint8_t *out, size_t *outlen, size_t maxlen) = 0; + + protected: + void FormatNonce(uint64_t seq, uint8_t *nonce); + bool AeadInner(bool decrypt, void *params, size_t param_length, + const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen, + size_t maxlen); + + CK_MECHANISM_TYPE mech_; + PK11SymKey *key_; + uint8_t iv_[12]; +}; + +class AeadCipherChacha20Poly1305 : public AeadCipher { + public: + AeadCipherChacha20Poly1305() : AeadCipher(CKM_NSS_CHACHA20_POLY1305) {} + + protected: + bool Aead(bool decrypt, uint64_t seq, const uint8_t *in, size_t inlen, + uint8_t *out, size_t *outlen, size_t maxlen); +}; + +class AeadCipherAesGcm : public AeadCipher { + public: + AeadCipherAesGcm() : AeadCipher(CKM_AES_GCM) {} + + protected: + bool Aead(bool decrypt, uint64_t seq, const uint8_t *in, size_t inlen, + uint8_t *out, size_t *outlen, size_t maxlen); +}; + +// Our analog of ssl3CipherSpec +class TlsCipherSpec { + public: + TlsCipherSpec() : aead_() {} + + bool Init(SSLCipherAlgorithm cipher, PK11SymKey *key, const uint8_t *iv); + + bool Protect(const TlsRecordHeader &header, const DataBuffer &plaintext, + DataBuffer *ciphertext); + bool Unprotect(const TlsRecordHeader &header, const DataBuffer &ciphertext, + DataBuffer *plaintext); + + private: + std::unique_ptr aead_; +}; + +} // namespace nss_test + +#endif -- cgit v1.2.1