diff options
author | Martin Thomson <martin.thomson@gmail.com> | 2015-03-03 11:39:56 -0800 |
---|---|---|
committer | Martin Thomson <martin.thomson@gmail.com> | 2015-03-03 11:39:56 -0800 |
commit | d90246d1a11d4cbb8a774df3e2beddd8ec913323 (patch) | |
tree | c308179224b0b8cac7787f151b66dfe8d1bbc68c | |
parent | a32941f568893ef73de007fdd53f7220d5d219ff (diff) | |
download | nss-hg-d90246d1a11d4cbb8a774df3e2beddd8ec913323.tar.gz |
Bug 1139082 - Refactoring ssl_gtest to use filters, r=ekr
-rw-r--r-- | external_tests/ssl_gtest/databuffer.h | 123 | ||||
-rw-r--r-- | external_tests/ssl_gtest/manifest.mn | 5 | ||||
-rw-r--r-- | external_tests/ssl_gtest/ssl_loopback_unittest.cc | 605 | ||||
-rw-r--r-- | external_tests/ssl_gtest/test_io.cc | 96 | ||||
-rw-r--r-- | external_tests/ssl_gtest/test_io.h | 33 | ||||
-rw-r--r-- | external_tests/ssl_gtest/tls_agent.cc | 208 | ||||
-rw-r--r-- | external_tests/ssl_gtest/tls_agent.h | 170 | ||||
-rw-r--r-- | external_tests/ssl_gtest/tls_connect.cc | 170 | ||||
-rw-r--r-- | external_tests/ssl_gtest/tls_connect.h | 79 | ||||
-rw-r--r-- | external_tests/ssl_gtest/tls_filter.cc | 226 | ||||
-rw-r--r-- | external_tests/ssl_gtest/tls_filter.h | 113 | ||||
-rw-r--r-- | external_tests/ssl_gtest/tls_parser.cc | 68 | ||||
-rw-r--r-- | external_tests/ssl_gtest/tls_parser.h | 88 |
13 files changed, 1290 insertions, 694 deletions
diff --git a/external_tests/ssl_gtest/databuffer.h b/external_tests/ssl_gtest/databuffer.h index 316aeb2a2..c3d3bb9be 100644 --- a/external_tests/ssl_gtest/databuffer.h +++ b/external_tests/ssl_gtest/databuffer.h @@ -7,33 +7,142 @@ #ifndef databuffer_h__ #define databuffer_h__ +#include <algorithm> +#include <cassert> +#include <cstring> +#include <iomanip> +#include <iostream> + +namespace nss_test { + class DataBuffer { public: DataBuffer() : data_(nullptr), len_(0) {} DataBuffer(const uint8_t *data, size_t len) : data_(nullptr), len_(0) { Assign(data, len); } + explicit DataBuffer(const DataBuffer& other) : data_(nullptr), len_(0) { + Assign(other.data(), other.len()); + } ~DataBuffer() { delete[] data_; } - void Assign(const uint8_t *data, size_t len) { - Allocate(len); - memcpy(static_cast<void *>(data_), static_cast<const void *>(data), len); + DataBuffer& operator=(const DataBuffer& other) { + if (&other != this) { + Assign(other.data(), other.len()); + } + return *this; } void Allocate(size_t len) { delete[] data_; - data_ = new unsigned char[len ? len : 1]; // Don't depend on new [0]. + data_ = new uint8_t[len ? len : 1]; // Don't depend on new [0]. len_ = len; } + void Truncate(size_t len) { + len_ = std::min(len_, len); + } + + void Assign(const uint8_t* data, size_t len) { + Allocate(len); + memcpy(static_cast<void *>(data_), static_cast<const void *>(data), len); + } + + // Write will do a new allocation and expand the size of the buffer if needed. + void Write(size_t index, const uint8_t* val, size_t count) { + if (index + count > len_) { + size_t newlen = index + count; + uint8_t* tmp = new uint8_t[newlen]; // Always > 0. + memcpy(static_cast<void*>(tmp), + static_cast<const void*>(data_), len_); + if (index > len_) { + memset(static_cast<void*>(tmp + len_), 0, index - len_); + } + delete[] data_; + data_ = tmp; + len_ = newlen; + } + memcpy(static_cast<void*>(data_ + index), + static_cast<const void*>(val), count); + } + + void Write(size_t index, const DataBuffer& buf) { + Write(index, buf.data(), buf.len()); + } + + // Write an integer, also performing host-to-network order conversion. + void Write(size_t index, uint32_t val, size_t count) { + assert(count <= sizeof(uint32_t)); + uint32_t nvalue = htonl(val); + auto* addr = reinterpret_cast<const uint8_t*>(&nvalue); + Write(index, addr + sizeof(uint32_t) - count, count); + } + + // Starting at |index|, remove |remove| bytes and replace them with the + // contents of |buf|. + void Splice(const DataBuffer& buf, size_t index, size_t remove = 0) { + Splice(buf.data(), buf.len(), index, remove); + } + + void Splice(const uint8_t* ins, size_t ins_len, size_t index, size_t remove = 0) { + uint8_t* old_value = data_; + size_t old_len = len_; + + // The amount of stuff remaining from the tail of the old. + size_t tail_len = old_len - std::min(old_len, index + remove); + // The new length: the head of the old, the new, and the tail of the old. + len_ = index + ins_len + tail_len; + data_ = new uint8_t[len_ ? len_ : 1]; + + // The head of the old. + Write(0, old_value, std::min(old_len, index)); + // Maybe a gap. + if (index > old_len) { + memset(old_value + index, 0, index - old_len); + } + // The new. + Write(index, ins, ins_len); + // The tail of the old. + if (tail_len > 0) { + Write(index + ins_len, + old_value + index + remove, tail_len); + } + + delete[] old_value; + } + + void Append(const DataBuffer& buf) { Splice(buf, len_); } + const uint8_t *data() const { return data_; } - uint8_t *data() { return data_; } + uint8_t* data() { return data_; } size_t len() const { return len_; } - const bool empty() const { return len_ != 0; } + bool empty() const { return len_ == 0; } private: - uint8_t *data_; + uint8_t* data_; size_t len_; }; +#ifdef DEBUG +static const size_t kMaxBufferPrint = 10000; +#else +static const size_t kMaxBufferPrint = 32; +#endif + +inline std::ostream& operator<<(std::ostream& stream, const DataBuffer& buf) { + stream << "[" << buf.len() << "] "; + for (size_t i = 0; i < buf.len(); ++i) { + if (i >= kMaxBufferPrint) { + stream << "..."; + break; + } + stream << std::hex << std::setfill('0') << std::setw(2) + << static_cast<unsigned>(buf.data()[i]); + } + stream << std::dec; + return stream; +} + +} // namespace nss_test + #endif diff --git a/external_tests/ssl_gtest/manifest.mn b/external_tests/ssl_gtest/manifest.mn index e66f532b4..ee883e9ac 100644 --- a/external_tests/ssl_gtest/manifest.mn +++ b/external_tests/ssl_gtest/manifest.mn @@ -1,4 +1,4 @@ -# +# # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. @@ -10,6 +10,9 @@ CPPSRCS = \ ssl_loopback_unittest.cc \ ssl_gtest.cc \ test_io.cc \ + tls_agent.cc \ + tls_connect.cc \ + tls_filter.cc \ tls_parser.cc \ $(NULL) diff --git a/external_tests/ssl_gtest/ssl_loopback_unittest.cc b/external_tests/ssl_gtest/ssl_loopback_unittest.cc index 6c01887a7..d70e2ceeb 100644 --- a/external_tests/ssl_gtest/ssl_loopback_unittest.cc +++ b/external_tests/ssl_gtest/ssl_loopback_unittest.cc @@ -1,182 +1,24 @@ -#include "prio.h" -#include "prerror.h" -#include "prlog.h" -#include "pk11func.h" +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + #include "ssl.h" -#include "sslerr.h" #include "sslproto.h" -#include "keyhi.h" #include <memory> -#include "test_io.h" #include "tls_parser.h" - -#define GTEST_HAS_RTTI 0 -#include "gtest/gtest.h" -#include "gtest_utils.h" - -extern std::string g_working_dir_path; +#include "tls_filter.h" +#include "tls_connect.h" namespace nss_test { -enum SessionResumptionMode { - RESUME_NONE = 0, - RESUME_SESSIONID = 1, - RESUME_TICKET = 2, - RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET -}; - -#define LOG(a) std::cerr << name_ << ": " << a << std::endl; - -// Inspector that parses out DTLS records and passes -// them on. -class TlsRecordInspector : public Inspector { - public: - virtual void Inspect(DummyPrSocket* adapter, const void* data, size_t len) { - TlsRecordParser parser(static_cast<const unsigned char*>(data), len); - - uint8_t content_type; - std::auto_ptr<DataBuffer> buf; - while (parser.NextRecord(&content_type, &buf)) { - OnRecord(adapter, content_type, buf->data(), buf->len()); - } - } - - virtual void OnRecord(DummyPrSocket* adapter, uint8_t content_type, - const unsigned char* record, size_t len) = 0; -}; - -// Inspector that injects arbitrary packets based on -// DTLS records of various types. -class TlsInspectorInjector : public TlsRecordInspector { - public: - TlsInspectorInjector(uint8_t packet_type, uint8_t handshake_type, - const unsigned char* data, size_t len) - : packet_type_(packet_type), - handshake_type_(handshake_type), - injected_(false), - data_(data, len) {} - - virtual void OnRecord(DummyPrSocket* adapter, uint8_t content_type, - const unsigned char* data, size_t len) { - // Only inject once. - if (injected_) { - return; - } - - // Check that the first byte is as requested. - if (content_type != packet_type_) { - return; - } - - if (handshake_type_ != 0xff) { - // Check that the packet is plausibly long enough. - if (len < 1) { - return; - } - - // Check that the handshake type is as requested. - if (data[0] != handshake_type_) { - return; - } - } - - adapter->WriteDirect(data_.data(), data_.len()); - } - - private: - uint8_t packet_type_; - uint8_t handshake_type_; - bool injected_; - DataBuffer data_; -}; - -// Make a copy of the first instance of a message. -class TlsInspectorRecordHandshakeMessage : public TlsRecordInspector { - public: - TlsInspectorRecordHandshakeMessage(uint8_t handshake_type) - : handshake_type_(handshake_type), buffer_() {} - - virtual void OnRecord(DummyPrSocket* adapter, uint8_t content_type, - const unsigned char* data, size_t len) { - // Only do this once. - if (buffer_.len()) { - return; - } - - // Check that the first byte is as requested. - if (content_type != kTlsHandshakeType) { - return; - } - - TlsParser parser(data, len); - while (parser.remaining()) { - unsigned char message_type; - // Read the content type. - if (!parser.Read(&message_type)) { - // Malformed. - return; - } - - // Read the record length. - uint32_t length; - if (!parser.Read(&length, 3)) { - // Malformed. - return; - } - - if (adapter->mode() == DGRAM) { - // DTLS - uint32_t message_seq; - if (!parser.Read(&message_seq, 2)) { - return; - } - - uint32_t fragment_offset; - if (!parser.Read(&fragment_offset, 3)) { - return; - } - - uint32_t fragment_length; - if (!parser.Read(&fragment_length, 3)) { - return; - } - - if ((fragment_offset != 0) || (fragment_length != length)) { - // This shouldn't happen because all current tests where we - // are using this code don't fragment. - return; - } - } - - unsigned char* dest = nullptr; - - if (message_type == handshake_type_) { - buffer_.Allocate(length); - dest = buffer_.data(); - } - - if (!parser.Read(dest, length)) { - // Malformed - return; - } - - if (dest) return; - } - } - - const DataBuffer& buffer() { return buffer_; } - - private: - uint8_t handshake_type_; - DataBuffer buffer_; -}; - class TlsServerKeyExchangeECDHE { public: - bool Parse(const unsigned char* data, size_t len) { - TlsParser parser(data, len); + bool Parse(const DataBuffer& buffer) { + TlsParser parser(buffer); uint8_t curve_type; if (!parser.Read(&curve_type)) { @@ -192,408 +34,12 @@ class TlsServerKeyExchangeECDHE { return false; } - uint32_t point_length; - if (!parser.Read(&point_length, 1)) { - return false; - } - - public_key_.Allocate(point_length); - if (!parser.Read(public_key_.data(), point_length)) { - return false; - } - - return true; + return parser.ReadVariable(&public_key_, 1); } DataBuffer public_key_; }; -class TlsAgent : public PollTarget { - public: - enum Role { CLIENT, SERVER }; - enum State { INIT, CONNECTING, CONNECTED, ERROR }; - - TlsAgent(const std::string& name, Role role, Mode mode) - : name_(name), - mode_(mode), - pr_fd_(nullptr), - adapter_(nullptr), - ssl_fd_(nullptr), - role_(role), - state_(INIT) { - memset(&info_, 0, sizeof(info_)); - memset(&csinfo_, 0, sizeof(csinfo_)); - } - - ~TlsAgent() { - if (pr_fd_) { - PR_Close(pr_fd_); - } - - if (ssl_fd_) { - PR_Close(ssl_fd_); - } - } - - bool Init() { - pr_fd_ = DummyPrSocket::CreateFD(name_, mode_); - if (!pr_fd_) return false; - - adapter_ = DummyPrSocket::GetAdapter(pr_fd_); - if (!adapter_) return false; - - return true; - } - - void SetPeer(TlsAgent* peer) { adapter_->SetPeer(peer->adapter_); } - - void SetInspector(Inspector* inspector) { adapter_->SetInspector(inspector); } - - void StartConnect() { - ASSERT_TRUE(EnsureTlsSetup()); - - SECStatus rv; - rv = SSL_ResetHandshake(ssl_fd_, role_ == SERVER ? PR_TRUE : PR_FALSE); - ASSERT_EQ(SECSuccess, rv); - SetState(CONNECTING); - } - - void EnableSomeECDHECiphers() { - ASSERT_TRUE(EnsureTlsSetup()); - - const uint32_t EnabledCiphers[] = {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA}; - - for (size_t i = 0; i < PR_ARRAY_SIZE(EnabledCiphers); ++i) { - SECStatus rv = SSL_CipherPrefSet(ssl_fd_, EnabledCiphers[i], PR_TRUE); - ASSERT_EQ(SECSuccess, rv); - } - } - - bool EnsureTlsSetup() { - // Don't set up twice - if (ssl_fd_) return true; - - if (adapter_->mode() == STREAM) { - ssl_fd_ = SSL_ImportFD(nullptr, pr_fd_); - } else { - ssl_fd_ = DTLS_ImportFD(nullptr, pr_fd_); - } - - EXPECT_NE(nullptr, ssl_fd_); - if (!ssl_fd_) return false; - pr_fd_ = nullptr; - - if (role_ == SERVER) { - CERTCertificate* cert = PK11_FindCertFromNickname(name_.c_str(), nullptr); - EXPECT_NE(nullptr, cert); - if (!cert) return false; - - SECKEYPrivateKey* priv = PK11_FindKeyByAnyCert(cert, nullptr); - EXPECT_NE(nullptr, priv); - if (!priv) return false; // Leak cert. - - SECStatus rv = SSL_ConfigSecureServer(ssl_fd_, cert, priv, kt_rsa); - EXPECT_EQ(SECSuccess, rv); - if (rv != SECSuccess) return false; // Leak cert and key. - - SECKEY_DestroyPrivateKey(priv); - CERT_DestroyCertificate(cert); - } else { - SECStatus rv = SSL_SetURL(ssl_fd_, "server"); - EXPECT_EQ(SECSuccess, rv); - if (rv != SECSuccess) return false; - } - - SECStatus rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook, - reinterpret_cast<void*>(this)); - EXPECT_EQ(SECSuccess, rv); - if (rv != SECSuccess) return false; - - return true; - } - - void SetVersionRange(uint16_t minver, uint16_t maxver) { - SSLVersionRange range = {minver, maxver}; - ASSERT_EQ(SECSuccess, SSL_VersionRangeSet(ssl_fd_, &range)); - } - - State state() const { return state_; } - - const char* state_str() const { return state_str(state()); } - - const char* state_str(State state) const { return states[state]; } - - PRFileDesc* ssl_fd() { return ssl_fd_; } - - bool version(uint16_t* version) const { - if (state_ != CONNECTED) return false; - - *version = info_.protocolVersion; - - return true; - } - - bool cipher_suite(int16_t* cipher_suite) const { - if (state_ != CONNECTED) return false; - - *cipher_suite = info_.cipherSuite; - return true; - } - - std::string cipher_suite_name() const { - if (state_ != CONNECTED) return "UNKNOWN"; - - return csinfo_.cipherSuiteName; - } - - void CheckKEAType(SSLKEAType type) const { - ASSERT_EQ(CONNECTED, state_); - ASSERT_EQ(type, csinfo_.keaType); - } - - void CheckVersion(uint16_t version) const { - ASSERT_EQ(CONNECTED, state_); - ASSERT_EQ(version, info_.protocolVersion); - } - - - void Handshake() { - SECStatus rv = SSL_ForceHandshake(ssl_fd_); - if (rv == SECSuccess) { - LOG("Handshake success"); - SECStatus rv = SSL_GetChannelInfo(ssl_fd_, &info_, sizeof(info_)); - ASSERT_EQ(SECSuccess, rv); - - rv = SSL_GetCipherSuiteInfo(info_.cipherSuite, &csinfo_, sizeof(csinfo_)); - ASSERT_EQ(SECSuccess, rv); - - SetState(CONNECTED); - return; - } - - int32_t err = PR_GetError(); - switch (err) { - case PR_WOULD_BLOCK_ERROR: - LOG("Would have blocked"); - // TODO(ekr@rtfm.com): set DTLS timeouts - Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, - &TlsAgent::ReadableCallback); - return; - break; - - // TODO(ekr@rtfm.com): needs special case for DTLS - case SSL_ERROR_RX_MALFORMED_HANDSHAKE: - default: - LOG("Handshake failed with error " << err); - SetState(ERROR); - return; - } - } - - std::vector<uint8_t> GetSessionId() { - return std::vector<uint8_t>(info_.sessionID, - info_.sessionID + info_.sessionIDLength); - } - - void ConfigureSessionCache(SessionResumptionMode mode) { - ASSERT_TRUE(EnsureTlsSetup()); - - SECStatus rv = SSL_OptionSet(ssl_fd_, - SSL_NO_CACHE, - mode & RESUME_SESSIONID ? - PR_FALSE : PR_TRUE); - ASSERT_EQ(SECSuccess, rv); - - rv = SSL_OptionSet(ssl_fd_, - SSL_ENABLE_SESSION_TICKETS, - mode & RESUME_TICKET ? - PR_TRUE : PR_FALSE); - ASSERT_EQ(SECSuccess, rv); - } - - private: - const static char* states[]; - - void SetState(State state) { - if (state_ == state) return; - - LOG("Changing state from " << state_str(state_) << " to " - << state_str(state)); - state_ = state; - } - - // Dummy auth certificate hook. - static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd, - PRBool checksig, PRBool isServer) { - return SECSuccess; - } - - static void ReadableCallback(PollTarget* self, Event event) { - TlsAgent* agent = static_cast<TlsAgent*>(self); - agent->ReadableCallback_int(event); - } - - void ReadableCallback_int(Event event) { - LOG("Readable"); - Handshake(); - } - - const std::string name_; - Mode mode_; - PRFileDesc* pr_fd_; - DummyPrSocket* adapter_; - PRFileDesc* ssl_fd_; - Role role_; - State state_; - SSLChannelInfo info_; - SSLCipherSuiteInfo csinfo_; -}; - -const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"}; - -class TlsConnectTestBase : public ::testing::Test { - public: - TlsConnectTestBase(Mode mode) - : mode_(mode), - client_(new TlsAgent("client", TlsAgent::CLIENT, mode_)), - server_(new TlsAgent("server", TlsAgent::SERVER, mode_)) {} - - ~TlsConnectTestBase() { - delete client_; - delete server_; - } - - void SetUp() { - // Configure a fresh session cache. - SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); - - // Clear statistics. - SSL3Statistics* stats = SSL_GetStatistics(); - memset(stats, 0, sizeof(*stats)); - - Init(); - } - - void TearDown() { - client_ = nullptr; - server_ = nullptr; - - SSL_ClearSessionCache(); - SSL_ShutdownServerSessionIDCache(); - } - - void Init() { - ASSERT_TRUE(client_->Init()); - ASSERT_TRUE(server_->Init()); - - client_->SetPeer(server_); - server_->SetPeer(client_); - } - - void Reset() { - delete client_; - delete server_; - - client_ = new TlsAgent("client", TlsAgent::CLIENT, mode_); - server_ = new TlsAgent("server", TlsAgent::SERVER, mode_); - - Init(); - } - - void EnsureTlsSetup() { - ASSERT_TRUE(client_->EnsureTlsSetup()); - ASSERT_TRUE(server_->EnsureTlsSetup()); - } - - void Connect() { - server_->StartConnect(); // Server - client_->StartConnect(); // Client - client_->Handshake(); - server_->Handshake(); - - ASSERT_TRUE_WAIT(client_->state() != TlsAgent::CONNECTING && - server_->state() != TlsAgent::CONNECTING, - 5000); - ASSERT_EQ(TlsAgent::CONNECTED, server_->state()); - - int16_t cipher_suite1, cipher_suite2; - bool ret = client_->cipher_suite(&cipher_suite1); - ASSERT_TRUE(ret); - ret = server_->cipher_suite(&cipher_suite2); - ASSERT_TRUE(ret); - ASSERT_EQ(cipher_suite1, cipher_suite2); - - std::cerr << "Connected with cipher suite " << client_->cipher_suite_name() - << std::endl; - - // Check and store session ids. - std::vector<uint8_t> sid_c1 = client_->GetSessionId(); - ASSERT_EQ(32, sid_c1.size()); - std::vector<uint8_t> sid_s1 = server_->GetSessionId(); - ASSERT_EQ(32, sid_s1.size()); - ASSERT_EQ(sid_c1, sid_s1); - session_ids_.push_back(sid_c1); - } - - void EnableSomeECDHECiphers() { - client_->EnableSomeECDHECiphers(); - server_->EnableSomeECDHECiphers(); - } - - void ConfigureSessionCache(SessionResumptionMode client, - SessionResumptionMode server) { - client_->ConfigureSessionCache(client); - server_->ConfigureSessionCache(server); - } - - void CheckResumption(SessionResumptionMode expected) { - ASSERT_NE(RESUME_BOTH, expected); - - int resume_ct = expected != 0; - int stateless_ct = (expected & RESUME_TICKET) ? 1 : 0; - - SSL3Statistics* stats = SSL_GetStatistics(); - ASSERT_EQ(resume_ct, stats->hch_sid_cache_hits); - ASSERT_EQ(resume_ct, stats->hsh_sid_cache_hits); - - ASSERT_EQ(stateless_ct, stats->hch_sid_stateless_resumes); - ASSERT_EQ(stateless_ct, stats->hsh_sid_stateless_resumes); - - if (resume_ct) { - // Check that the last two session ids match. - ASSERT_GE(2, session_ids_.size()); - ASSERT_EQ(session_ids_[session_ids_.size()-1], - session_ids_[session_ids_.size()-2]); - } - } - - protected: - Mode mode_; - TlsAgent* client_; - TlsAgent* server_; - std::vector<std::vector<uint8_t>> session_ids_; -}; - -class TlsConnectTest : public TlsConnectTestBase { - public: - TlsConnectTest() : TlsConnectTestBase(STREAM) {} -}; - -class DtlsConnectTest : public TlsConnectTestBase { - public: - DtlsConnectTest() : TlsConnectTestBase(DGRAM) {} -}; - -class TlsConnectGeneric : public TlsConnectTestBase, - public ::testing::WithParamInterface<std::string> { - public: - TlsConnectGeneric() - : TlsConnectTestBase((GetParam() == "TLS") ? STREAM : DGRAM) { - std::cerr << "Variant: " << GetParam() << std::endl; - } -}; - TEST_P(TlsConnectGeneric, SetupOnly) {} TEST_P(TlsConnectGeneric, Connect) { @@ -729,6 +175,19 @@ TEST_P(TlsConnectGeneric, ConnectTLS_1_2_Only) { client_->CheckVersion(SSL_LIBRARY_VERSION_TLS_1_2); } +TEST_P(TlsConnectGeneric, ConnectAlpn) { + EnableAlpn(); + Connect(); + client_->CheckAlpn(SSL_NEXT_PROTO_SELECTED, "a"); + server_->CheckAlpn(SSL_NEXT_PROTO_NEGOTIATED, "a"); +} + +TEST_F(DtlsConnectTest, ConnectSrtp) { + EnableSrtp(); + Connect(); + CheckSrtp(); +} + TEST_F(TlsConnectTest, ConnectECDHE) { EnableSomeECDHECiphers(); Connect(); @@ -739,24 +198,24 @@ TEST_F(TlsConnectTest, ConnectECDHETwiceReuseKey) { EnableSomeECDHECiphers(); TlsInspectorRecordHandshakeMessage* i1 = new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); - server_->SetInspector(i1); + server_->SetPacketFilter(i1); Connect(); client_->CheckKEAType(ssl_kea_ecdh); TlsServerKeyExchangeECDHE dhe1; - ASSERT_TRUE(dhe1.Parse(i1->buffer().data(), i1->buffer().len())); + ASSERT_TRUE(dhe1.Parse(i1->buffer())); // Restart Reset(); TlsInspectorRecordHandshakeMessage* i2 = new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); - server_->SetInspector(i2); + server_->SetPacketFilter(i2); EnableSomeECDHECiphers(); ConfigureSessionCache(RESUME_NONE, RESUME_NONE); Connect(); client_->CheckKEAType(ssl_kea_ecdh); TlsServerKeyExchangeECDHE dhe2; - ASSERT_TRUE(dhe2.Parse(i2->buffer().data(), i2->buffer().len())); + ASSERT_TRUE(dhe2.Parse(i2->buffer())); // Make sure they are the same. ASSERT_EQ(dhe1.public_key_.len(), dhe2.public_key_.len()); @@ -771,11 +230,11 @@ TEST_F(TlsConnectTest, ConnectECDHETwiceNewKey) { ASSERT_EQ(SECSuccess, rv); TlsInspectorRecordHandshakeMessage* i1 = new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); - server_->SetInspector(i1); + server_->SetPacketFilter(i1); Connect(); client_->CheckKEAType(ssl_kea_ecdh); TlsServerKeyExchangeECDHE dhe1; - ASSERT_TRUE(dhe1.Parse(i1->buffer().data(), i1->buffer().len())); + ASSERT_TRUE(dhe1.Parse(i1->buffer())); // Restart Reset(); @@ -784,13 +243,13 @@ TEST_F(TlsConnectTest, ConnectECDHETwiceNewKey) { ASSERT_EQ(SECSuccess, rv); TlsInspectorRecordHandshakeMessage* i2 = new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); - server_->SetInspector(i2); + server_->SetPacketFilter(i2); ConfigureSessionCache(RESUME_NONE, RESUME_NONE); Connect(); client_->CheckKEAType(ssl_kea_ecdh); TlsServerKeyExchangeECDHE dhe2; - ASSERT_TRUE(dhe2.Parse(i2->buffer().data(), i2->buffer().len())); + ASSERT_TRUE(dhe2.Parse(i2->buffer())); // Make sure they are different. ASSERT_FALSE((dhe1.public_key_.len() == dhe2.public_key_.len()) && diff --git a/external_tests/ssl_gtest/test_io.cc b/external_tests/ssl_gtest/test_io.cc index 701647831..2bfd09178 100644 --- a/external_tests/ssl_gtest/test_io.cc +++ b/external_tests/ssl_gtest/test_io.cc @@ -4,42 +4,45 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this file, * You can obtain one at http://mozilla.org/MPL/2.0/. */ -#include <assert.h> +#include "test_io.h" +#include <algorithm> +#include <cassert> #include <iostream> #include <memory> #include "prerror.h" -#include "prio.h" #include "prlog.h" #include "prthread.h" -#include "test_io.h" +#include "databuffer.h" namespace nss_test { static PRDescIdentity test_fd_identity = PR_INVALID_IO_LAYER; -#define UNIMPLEMENTED() \ - fprintf(stderr, "Call to unimplemented function %s\n", __FUNCTION__); \ - PR_ASSERT(PR_FALSE); \ +#define UNIMPLEMENTED() \ + std::cerr << "Call to unimplemented function " \ + << __FUNCTION__ << std::endl; \ + PR_ASSERT(PR_FALSE); \ PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0) #define LOG(a) std::cerr << name_ << ": " << a << std::endl; -struct Packet { - Packet() : data_(nullptr), len_(0), offset_(0) {} +class Packet : public DataBuffer { + public: + Packet(const DataBuffer& buf) : DataBuffer(buf), offset_(0) {} - void Assign(const void *data, int32_t len) { - data_ = new uint8_t[len]; - memcpy(data_, data, len); - len_ = len; + void Advance(size_t delta) { + PR_ASSERT(offset_ + delta <= len()); + offset_ = std::min(len(), offset_ + delta); } - ~Packet() { delete data_; } - uint8_t *data_; - int32_t len_; - int32_t offset_; + size_t offset() const { return offset_; } + size_t remaining() const { return len() - offset_; } + + private: + size_t offset_; }; // Implementation of NSPR methods @@ -246,6 +249,16 @@ static int32_t DummyReserved(PRFileDesc *f) { return -1; } +DummyPrSocket::~DummyPrSocket() { + delete filter_; + while (!input_.empty()) + { + Packet* front = input_.front(); + input_.pop(); + delete front; + } +} + static const struct PRIOMethods DummyMethods = { PR_DESC_LAYERED, DummyClose, DummyRead, DummyWrite, DummyAvailable, DummyAvailable64, @@ -275,9 +288,8 @@ DummyPrSocket *DummyPrSocket::GetAdapter(PRFileDesc *fd) { return reinterpret_cast<DummyPrSocket *>(fd->secret); } -void DummyPrSocket::PacketReceived(const void *data, int32_t len) { - input_.push(new Packet()); - input_.back()->Assign(data, len); +void DummyPrSocket::PacketReceived(const DataBuffer& packet) { + input_.push(new Packet(packet)); } int32_t DummyPrSocket::Read(void *data, int32_t len) { @@ -295,16 +307,18 @@ int32_t DummyPrSocket::Read(void *data, int32_t len) { } Packet *front = input_.front(); - int32_t to_read = std::min(len, front->len_ - front->offset_); - memcpy(data, front->data_ + front->offset_, to_read); - front->offset_ += to_read; + size_t to_read = std::min(static_cast<size_t>(len), + front->len() - front->offset()); + memcpy(data, static_cast<const void*>(front->data() + front->offset()), + to_read); + front->Advance(to_read); - if (front->offset_ == front->len_) { + if (!front->remaining()) { input_.pop(); delete front; } - return to_read; + return static_cast<int32_t>(to_read); } int32_t DummyPrSocket::Recv(void *buf, int32_t buflen) { @@ -314,39 +328,49 @@ int32_t DummyPrSocket::Recv(void *buf, int32_t buflen) { } Packet *front = input_.front(); - if (buflen < front->len_) { + if (buflen < front->len()) { PR_ASSERT(false); PR_SetError(PR_BUFFER_OVERFLOW_ERROR, 0); return -1; } - int32_t count = front->len_; - memcpy(buf, front->data_, count); + size_t count = front->len(); + memcpy(buf, front->data(), count); input_.pop(); delete front; - return count; + return static_cast<int32_t>(count); } int32_t DummyPrSocket::Write(const void *buf, int32_t length) { - if (inspector_) { - inspector_->Inspect(this, buf, length); + DataBuffer packet(static_cast<const uint8_t*>(buf), + static_cast<size_t>(length)); + if (filter_) { + DataBuffer filtered; + if (filter_->Filter(packet, &filtered)) { + if (WriteDirect(filtered) != filtered.len()) { + PR_SetError(PR_IO_ERROR, 0); + return -1; + } + LOG("Wrote: " << packet); + // libssl can't handle if this reports something other than the length of + // what was passed in (or less, but we're not doing partial writes). + return packet.len(); + } } - return WriteDirect(buf, length); + return WriteDirect(packet); } -int32_t DummyPrSocket::WriteDirect(const void *buf, int32_t length) { +int32_t DummyPrSocket::WriteDirect(const DataBuffer& packet) { if (!peer_) { PR_SetError(PR_IO_ERROR, 0); return -1; } - LOG("Wrote " << length); - - peer_->PacketReceived(buf, length); - return length; + peer_->PacketReceived(packet); + return static_cast<int32_t>(packet.len()); // ignore truncation } Poller *Poller::instance; diff --git a/external_tests/ssl_gtest/test_io.h b/external_tests/ssl_gtest/test_io.h index 64cc4fd5a..d2424c60c 100644 --- a/external_tests/ssl_gtest/test_io.h +++ b/external_tests/ssl_gtest/test_io.h @@ -13,25 +13,32 @@ #include <queue> #include <string> +#include "prio.h" + namespace nss_test { -struct Packet; +class DataBuffer; +class Packet; class DummyPrSocket; // Fwd decl. // Allow us to inspect a packet before it is written. -class Inspector { +class PacketFilter { public: - virtual ~Inspector() {} - - virtual void Inspect(DummyPrSocket* adapter, const void* data, - size_t len) = 0; + virtual ~PacketFilter() {} + + // The packet filter takes input and has the option of mutating it. + // + // A filter that modifies the data places the modified data in *output and + // returns true. A filter that does not modify data returns false, in which + // case the value in *output is ignored. + virtual bool Filter(const DataBuffer& input, DataBuffer* output) = 0; }; enum Mode { STREAM, DGRAM }; class DummyPrSocket { public: - ~DummyPrSocket() { delete inspector_; } + ~DummyPrSocket(); static PRFileDesc* CreateFD(const std::string& name, Mode mode); // Returns an FD. @@ -39,16 +46,16 @@ class DummyPrSocket { void SetPeer(DummyPrSocket* peer) { peer_ = peer; } - void SetInspector(Inspector* inspector) { inspector_ = inspector; } + void SetPacketFilter(PacketFilter* filter) { filter_ = filter; } - void PacketReceived(const void* data, int32_t len); + void PacketReceived(const DataBuffer& data); int32_t Read(void* data, int32_t len); int32_t Recv(void* buf, int32_t buflen); int32_t Write(const void* buf, int32_t length); - int32_t WriteDirect(const void* buf, int32_t length); + int32_t WriteDirect(const DataBuffer& data); Mode mode() const { return mode_; } - bool readable() { return !input_.empty(); } + bool readable() const { return !input_.empty(); } bool writable() { return true; } private: @@ -57,13 +64,13 @@ class DummyPrSocket { mode_(mode), peer_(nullptr), input_(), - inspector_(nullptr) {} + filter_(nullptr) {} const std::string name_; Mode mode_; DummyPrSocket* peer_; std::queue<Packet*> input_; - Inspector* inspector_; + PacketFilter* filter_; }; // Marker interface. diff --git a/external_tests/ssl_gtest/tls_agent.cc b/external_tests/ssl_gtest/tls_agent.cc new file mode 100644 index 000000000..6eeb651f5 --- /dev/null +++ b/external_tests/ssl_gtest/tls_agent.cc @@ -0,0 +1,208 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "tls_agent.h" + +#include "pk11func.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" +#include "keyhi.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" + +namespace nss_test { + +const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"}; + +bool TlsAgent::EnsureTlsSetup() { + // Don't set up twice + if (ssl_fd_) return true; + + if (adapter_->mode() == STREAM) { + ssl_fd_ = SSL_ImportFD(nullptr, pr_fd_); + } else { + ssl_fd_ = DTLS_ImportFD(nullptr, pr_fd_); + } + + EXPECT_NE(nullptr, ssl_fd_); + if (!ssl_fd_) return false; + pr_fd_ = nullptr; + + if (role_ == SERVER) { + CERTCertificate* cert = PK11_FindCertFromNickname(name_.c_str(), nullptr); + EXPECT_NE(nullptr, cert); + if (!cert) return false; + + SECKEYPrivateKey* priv = PK11_FindKeyByAnyCert(cert, nullptr); + EXPECT_NE(nullptr, priv); + if (!priv) return false; // Leak cert. + + SECStatus rv = SSL_ConfigSecureServer(ssl_fd_, cert, priv, kt_rsa); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; // Leak cert and key. + + SECKEY_DestroyPrivateKey(priv); + CERT_DestroyCertificate(cert); + + rv = SSL_SNISocketConfigHook(ssl_fd_, SniHook, + reinterpret_cast<void*>(this)); + EXPECT_EQ(SECSuccess, rv); // don't abort, just fail + } else { + SECStatus rv = SSL_SetURL(ssl_fd_, "server"); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + } + + SECStatus rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook, + reinterpret_cast<void*>(this)); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + + return true; +} + +void TlsAgent::StartConnect() { + ASSERT_TRUE(EnsureTlsSetup()); + + SECStatus rv; + rv = SSL_ResetHandshake(ssl_fd_, role_ == SERVER ? PR_TRUE : PR_FALSE); + ASSERT_EQ(SECSuccess, rv); + SetState(CONNECTING); +} + +void TlsAgent::EnableSomeECDHECiphers() { + ASSERT_TRUE(EnsureTlsSetup()); + + const uint32_t EnabledCiphers[] = {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA}; + + for (size_t i = 0; i < PR_ARRAY_SIZE(EnabledCiphers); ++i) { + SECStatus rv = SSL_CipherPrefSet(ssl_fd_, EnabledCiphers[i], PR_TRUE); + ASSERT_EQ(SECSuccess, rv); + } +} + +void TlsAgent::SetSessionTicketsEnabled(bool en) { + ASSERT_TRUE(EnsureTlsSetup()); + + SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS, + en ? PR_TRUE : PR_FALSE); + ASSERT_EQ(SECSuccess, rv); +} + +void TlsAgent::SetSessionCacheEnabled(bool en) { + ASSERT_TRUE(EnsureTlsSetup()); + + SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE, + en ? PR_FALSE : PR_TRUE); + ASSERT_EQ(SECSuccess, rv); +} + +void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) { + SSLVersionRange range = {minver, maxver}; + ASSERT_EQ(SECSuccess, SSL_VersionRangeSet(ssl_fd_, &range)); +} + +void TlsAgent::CheckKEAType(SSLKEAType type) const { + ASSERT_EQ(CONNECTED, state_); + ASSERT_EQ(type, csinfo_.keaType); +} + +void TlsAgent::CheckVersion(uint16_t version) const { + ASSERT_EQ(CONNECTED, state_); + ASSERT_EQ(version, info_.protocolVersion); +} + +void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) { + ASSERT_TRUE(EnsureTlsSetup()); + + ASSERT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_ENABLE_ALPN, PR_TRUE)); + ASSERT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd_, val, len)); +} + +void TlsAgent::CheckAlpn(SSLNextProtoState expected_state, + const std::string& expected) { + SSLNextProtoState state; + char chosen[10]; + unsigned int chosen_len; + SECStatus rv = SSL_GetNextProto(ssl_fd_, &state, + reinterpret_cast<unsigned char*>(chosen), + &chosen_len, sizeof(chosen)); + ASSERT_EQ(SECSuccess, rv); + ASSERT_EQ(expected_state, state); + ASSERT_EQ(expected, std::string(chosen, chosen_len)); +} + +void TlsAgent::EnableSrtp() { + ASSERT_TRUE(EnsureTlsSetup()); + const uint16_t ciphers[] = { + SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32 + }; + ASSERT_EQ(SECSuccess, SSL_SetSRTPCiphers(ssl_fd_, ciphers, + PR_ARRAY_SIZE(ciphers))); + +} + +void TlsAgent::CheckSrtp() { + uint16_t actual; + ASSERT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd_, &actual)); + ASSERT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual); +} + + +void TlsAgent::Handshake() { + SECStatus rv = SSL_ForceHandshake(ssl_fd_); + if (rv == SECSuccess) { + LOG("Handshake success"); + SECStatus rv = SSL_GetChannelInfo(ssl_fd_, &info_, sizeof(info_)); + ASSERT_EQ(SECSuccess, rv); + + rv = SSL_GetCipherSuiteInfo(info_.cipherSuite, &csinfo_, sizeof(csinfo_)); + ASSERT_EQ(SECSuccess, rv); + + SetState(CONNECTED); + return; + } + + int32_t err = PR_GetError(); + switch (err) { + case PR_WOULD_BLOCK_ERROR: + LOG("Would have blocked"); + // TODO(ekr@rtfm.com): set DTLS timeouts + Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, + &TlsAgent::ReadableCallback); + return; + break; + + // TODO(ekr@rtfm.com): needs special case for DTLS + case SSL_ERROR_RX_MALFORMED_HANDSHAKE: + default: + LOG("Handshake failed with error " << err); + SetState(ERROR); + return; + } +} + +void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) { + ASSERT_TRUE(EnsureTlsSetup()); + + SECStatus rv = SSL_OptionSet(ssl_fd_, + SSL_NO_CACHE, + mode & RESUME_SESSIONID ? + PR_FALSE : PR_TRUE); + ASSERT_EQ(SECSuccess, rv); + + rv = SSL_OptionSet(ssl_fd_, + SSL_ENABLE_SESSION_TICKETS, + mode & RESUME_TICKET ? + PR_TRUE : PR_FALSE); + ASSERT_EQ(SECSuccess, rv); +} + + +} // namespace nss_test diff --git a/external_tests/ssl_gtest/tls_agent.h b/external_tests/ssl_gtest/tls_agent.h new file mode 100644 index 000000000..aee835ea7 --- /dev/null +++ b/external_tests/ssl_gtest/tls_agent.h @@ -0,0 +1,170 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef tls_agent_h_ +#define tls_agent_h_ + +#include "prio.h" +#include "ssl.h" + +#include <iostream> + +#include "test_io.h" + +namespace nss_test { + +#define LOG(msg) std::cerr << name_ << ": " << msg << std::endl + +enum SessionResumptionMode { + RESUME_NONE = 0, + RESUME_SESSIONID = 1, + RESUME_TICKET = 2, + RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET +}; + +class TlsAgent : public PollTarget { + public: + enum Role { CLIENT, SERVER }; + enum State { INIT, CONNECTING, CONNECTED, ERROR }; + + TlsAgent(const std::string& name, Role role, Mode mode) + : name_(name), + mode_(mode), + pr_fd_(nullptr), + adapter_(nullptr), + ssl_fd_(nullptr), + role_(role), + state_(INIT) { + memset(&info_, 0, sizeof(info_)); + memset(&csinfo_, 0, sizeof(csinfo_)); + } + + ~TlsAgent() { + if (pr_fd_) { + PR_Close(pr_fd_); + } + + if (ssl_fd_) { + PR_Close(ssl_fd_); + } + } + + bool Init() { + pr_fd_ = DummyPrSocket::CreateFD(name_, mode_); + if (!pr_fd_) return false; + + adapter_ = DummyPrSocket::GetAdapter(pr_fd_); + if (!adapter_) return false; + + return true; + } + + void SetPeer(TlsAgent* peer) { adapter_->SetPeer(peer->adapter_); } + + void SetPacketFilter(PacketFilter* filter) { + adapter_->SetPacketFilter(filter); + } + + + void StartConnect(); + void CheckKEAType(SSLKEAType type) const; + void CheckVersion(uint16_t version) const; + + void Handshake(); + void EnableSomeECDHECiphers(); + bool EnsureTlsSetup(); + + void ConfigureSessionCache(SessionResumptionMode mode); + void SetSessionTicketsEnabled(bool en); + void SetSessionCacheEnabled(bool en); + void SetVersionRange(uint16_t minver, uint16_t maxver); + void EnableAlpn(const uint8_t* val, size_t len); + void CheckAlpn(SSLNextProtoState expected_state, + const std::string& expected); + void EnableSrtp(); + void CheckSrtp(); + + State state() const { return state_; } + + const char* state_str() const { return state_str(state()); } + + const char* state_str(State state) const { return states[state]; } + + PRFileDesc* ssl_fd() { return ssl_fd_; } + + bool version(uint16_t* version) const { + if (state_ != CONNECTED) return false; + + *version = info_.protocolVersion; + + return true; + } + + bool cipher_suite(int16_t* cipher_suite) const { + if (state_ != CONNECTED) return false; + + *cipher_suite = info_.cipherSuite; + return true; + } + + std::string cipher_suite_name() const { + if (state_ != CONNECTED) return "UNKNOWN"; + + return csinfo_.cipherSuiteName; + } + + std::vector<uint8_t> session_id() const { + return std::vector<uint8_t>(info_.sessionID, + info_.sessionID + info_.sessionIDLength); + } + + private: + const static char* states[]; + + void SetState(State state) { + if (state_ == state) return; + + LOG("Changing state from " << state_str(state_) << " to " + << state_str(state)); + state_ = state; + } + + // Dummy auth certificate hook. + static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd, + PRBool checksig, PRBool isServer) { + return SECSuccess; + } + + static void ReadableCallback(PollTarget* self, Event event) { + TlsAgent* agent = static_cast<TlsAgent*>(self); + agent->ReadableCallback_int(); + } + + void ReadableCallback_int() { + LOG("Readable"); + Handshake(); + } + + static PRInt32 SniHook(PRFileDesc *fd, const SECItem *srvNameArr, + PRUint32 srvNameArrSize, + void *arg) { + return SSL_SNI_CURRENT_CONFIG_IS_USED; + } + + const std::string name_; + Mode mode_; + PRFileDesc* pr_fd_; + DummyPrSocket* adapter_; + PRFileDesc* ssl_fd_; + Role role_; + State state_; + SSLChannelInfo info_; + SSLCipherSuiteInfo csinfo_; +}; + +} // namespace nss_test + +#endif diff --git a/external_tests/ssl_gtest/tls_connect.cc b/external_tests/ssl_gtest/tls_connect.cc new file mode 100644 index 000000000..6c6fd1a53 --- /dev/null +++ b/external_tests/ssl_gtest/tls_connect.cc @@ -0,0 +1,170 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "tls_connect.h" + +#include <iostream> + +#include "gtest_utils.h" + +extern std::string g_working_dir_path; + +namespace nss_test { + +TlsConnectTestBase::TlsConnectTestBase(Mode mode) + : mode_(mode), + client_(new TlsAgent("client", TlsAgent::CLIENT, mode_)), + server_(new TlsAgent("server", TlsAgent::SERVER, mode_)) {} + +TlsConnectTestBase::~TlsConnectTestBase() { + delete client_; + delete server_; +} + +void TlsConnectTestBase::SetUp() { + // Configure a fresh session cache. + SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); + + // Clear statistics. + SSL3Statistics* stats = SSL_GetStatistics(); + memset(stats, 0, sizeof(*stats)); + + Init(); +} + +void TlsConnectTestBase::TearDown() { + client_ = nullptr; + server_ = nullptr; + + SSL_ClearSessionCache(); + SSL_ShutdownServerSessionIDCache(); +} + +void TlsConnectTestBase::Init() { + ASSERT_TRUE(client_->Init()); + ASSERT_TRUE(server_->Init()); + + client_->SetPeer(server_); + server_->SetPeer(client_); +} + +void TlsConnectTestBase::Reset() { + delete client_; + delete server_; + + client_ = new TlsAgent("client", TlsAgent::CLIENT, mode_); + server_ = new TlsAgent("server", TlsAgent::SERVER, mode_); + + Init(); +} + +void TlsConnectTestBase::EnsureTlsSetup() { + ASSERT_TRUE(client_->EnsureTlsSetup()); + ASSERT_TRUE(server_->EnsureTlsSetup()); +} + +void TlsConnectTestBase::Handshake() { + server_->StartConnect(); + client_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + + ASSERT_TRUE_WAIT(client_->state() != TlsAgent::CONNECTING && + server_->state() != TlsAgent::CONNECTING, + 5000); +} + +void TlsConnectTestBase::Connect() { + Handshake(); + + ASSERT_EQ(TlsAgent::CONNECTED, client_->state()); + ASSERT_EQ(TlsAgent::CONNECTED, server_->state()); + + int16_t cipher_suite1, cipher_suite2; + bool ret = client_->cipher_suite(&cipher_suite1); + ASSERT_TRUE(ret); + ret = server_->cipher_suite(&cipher_suite2); + ASSERT_TRUE(ret); + ASSERT_EQ(cipher_suite1, cipher_suite2); + + std::cerr << "Connected with cipher suite " << client_->cipher_suite_name() + << std::endl; + + // Check and store session ids. + std::vector<uint8_t> sid_c1 = client_->session_id(); + ASSERT_EQ(32, sid_c1.size()); + std::vector<uint8_t> sid_s1 = server_->session_id(); + ASSERT_EQ(32, sid_s1.size()); + ASSERT_EQ(sid_c1, sid_s1); + session_ids_.push_back(sid_c1); +} + +void TlsConnectTestBase::ConnectExpectFail() { + Handshake(); + + ASSERT_EQ(TlsAgent::ERROR, client_->state()); + ASSERT_EQ(TlsAgent::ERROR, server_->state()); +} + +void TlsConnectTestBase::EnableSomeECDHECiphers() { + client_->EnableSomeECDHECiphers(); + server_->EnableSomeECDHECiphers(); +} + + +void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client, + SessionResumptionMode server) { + client_->ConfigureSessionCache(client); + server_->ConfigureSessionCache(server); +} + +void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) { + ASSERT_NE(RESUME_BOTH, expected); + + int resume_ct = expected ? 1 : 0; + int stateless_ct = (expected & RESUME_TICKET) ? 1 : 0; + + SSL3Statistics* stats = SSL_GetStatistics(); + ASSERT_EQ(resume_ct, stats->hch_sid_cache_hits); + ASSERT_EQ(resume_ct, stats->hsh_sid_cache_hits); + + ASSERT_EQ(stateless_ct, stats->hch_sid_stateless_resumes); + ASSERT_EQ(stateless_ct, stats->hsh_sid_stateless_resumes); + + if (resume_ct) { + // Check that the last two session ids match. + ASSERT_GE(2, session_ids_.size()); + ASSERT_EQ(session_ids_[session_ids_.size()-1], + session_ids_[session_ids_.size()-2]); + } +} + +void TlsConnectTestBase::EnableAlpn() { + // A simple value of "a", "b". Note that the preferred value of "a" is placed + // at the end, because the NSS API follows the now defunct NPN specification, + // which places the preferred (and default) entry at the end of the list. + // NSS will move this final entry to the front when used with ALPN. + static const uint8_t val[] = { 0x01, 0x62, 0x01, 0x61 }; + client_->EnableAlpn(val, sizeof(val)); + server_->EnableAlpn(val, sizeof(val)); +} + +void TlsConnectTestBase::EnableSrtp() { + client_->EnableSrtp(); + server_->EnableSrtp(); +} + +void TlsConnectTestBase::CheckSrtp() { + client_->CheckSrtp(); + server_->CheckSrtp(); +} + +TlsConnectGeneric::TlsConnectGeneric() + : TlsConnectTestBase((GetParam() == "TLS") ? STREAM : DGRAM) { + std::cerr << "Variant: " << GetParam() << std::endl; +} + +} // namespace nss_test diff --git a/external_tests/ssl_gtest/tls_connect.h b/external_tests/ssl_gtest/tls_connect.h new file mode 100644 index 000000000..c263fe83f --- /dev/null +++ b/external_tests/ssl_gtest/tls_connect.h @@ -0,0 +1,79 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef tls_connect_h_ +#define tls_connect_h_ + +#include "sslt.h" + +#include "tls_agent.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" + +namespace nss_test { + +// A generic TLS connection test base. +class TlsConnectTestBase : public ::testing::Test { + public: + TlsConnectTestBase(Mode mode); + virtual ~TlsConnectTestBase(); + + void SetUp(); + void TearDown(); + + // Initialize client and server. + void Init(); + // Re-initialize client and server. + void Reset(); + // Make sure TLS is configured for a connection. + void EnsureTlsSetup(); + + // Run the handshake. + void Handshake(); + // Connect and check that it works. + void Connect(); + // Connect and expect it to fail. + void ConnectExpectFail(); + + void EnableSomeECDHECiphers(); + void ConfigureSessionCache(SessionResumptionMode client, + SessionResumptionMode server); + void CheckResumption(SessionResumptionMode expected); + void EnableAlpn(); + void EnableSrtp(); + void CheckSrtp(); + + protected: + Mode mode_; + TlsAgent* client_; + TlsAgent* server_; + std::vector<std::vector<uint8_t>> session_ids_; +}; + +// A TLS-only test base. +class TlsConnectTest : public TlsConnectTestBase { + public: + TlsConnectTest() : TlsConnectTestBase(STREAM) {} +}; + +// A DTLS-only test base. +class DtlsConnectTest : public TlsConnectTestBase { + public: + DtlsConnectTest() : TlsConnectTestBase(DGRAM) {} +}; + +// A generic test class that can be either STREAM or DGRAM. This is configured +// in ssl_loopback_unittest.cc. All uses of this should use TEST_P(). +class TlsConnectGeneric : public TlsConnectTestBase, + public ::testing::WithParamInterface<std::string> { + public: + TlsConnectGeneric(); +}; + +} // namespace nss_test + +#endif diff --git a/external_tests/ssl_gtest/tls_filter.cc b/external_tests/ssl_gtest/tls_filter.cc new file mode 100644 index 000000000..3cbe9e5ac --- /dev/null +++ b/external_tests/ssl_gtest/tls_filter.cc @@ -0,0 +1,226 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "tls_filter.h" + +#include <iostream> + +namespace nss_test { + +bool TlsRecordFilter::Filter(const DataBuffer& input, DataBuffer* output) { + bool changed = false; + size_t output_offset = 0U; + output->Allocate(input.len()); + + TlsParser parser(input); + while (parser.remaining()) { + size_t start = parser.consumed(); + uint8_t content_type; + if (!parser.Read(&content_type)) { + return false; + } + uint32_t version; + if (!parser.Read(&version, 2)) { + return false; + } + + if (IsDtls(version)) { + if (!parser.Skip(8)) { + return false; + } + } + size_t header_len = parser.consumed() - start; + output->Write(output_offset, input.data() + start, header_len); + + DataBuffer record; + if (!parser.ReadVariable(&record, 2)) { + return false; + } + + // Move the offset in the output forward. ApplyFilter() returns the index + // of the end of the record it wrote to the output, so we need to skip + // over the content type and version for the value passed to it. + output_offset = ApplyFilter(content_type, version, record, output, + output_offset + header_len, + &changed); + } + output->Truncate(output_offset); + + // Record how many packets we actually touched. + if (changed) { + ++count_; + } + + return changed; +} + +size_t TlsRecordFilter::ApplyFilter(uint8_t content_type, uint16_t version, + const DataBuffer& record, + DataBuffer* output, + size_t offset, bool* changed) { + const DataBuffer* source = &record; + DataBuffer filtered; + if (FilterRecord(content_type, version, record, &filtered) && + filtered.len() < 0x10000) { + *changed = true; + std::cerr << "record old: " << record << std::endl; + std::cerr << "record old: " << filtered << std::endl; + source = &filtered; + } + + output->Write(offset, source->len(), 2); + output->Write(offset + 2, *source); + return offset + 2 + source->len(); +} + +bool TlsHandshakeFilter::FilterRecord(uint8_t content_type, uint16_t version, + const DataBuffer& input, + DataBuffer* output) { + // Check that the first byte is as requested. + if (content_type != kTlsHandshakeType) { + return false; + } + + bool changed = false; + size_t output_offset = 0U; + output->Allocate(input.len()); // Preallocate a little. + + TlsParser parser(input); + while (parser.remaining()) { + size_t start = parser.consumed(); + uint8_t handshake_type; + if (!parser.Read(&handshake_type)) { + return false; // malformed + } + uint32_t length; + if (!parser.Read(&length, 3)) { + return false; // malformed + } + + if (IsDtls(version) && !CheckDtls(parser, length)) { + return false; + } + + size_t header_len = parser.consumed() - start; + output->Write(output_offset, input.data() + start, header_len); + + DataBuffer handshake; + if (!parser.Read(&handshake, length)) { + return false; + } + + // Move the offset in the output forward. ApplyFilter() returns the index + // of the end of the message it wrote to the output, so we need to identify + // offsets from the start of the message for length and the handshake + // message. + output_offset = ApplyFilter(version, handshake_type, handshake, + output, output_offset + 1, + output_offset + header_len, + &changed); + } + output->Truncate(output_offset); + return changed; +} + +bool TlsHandshakeFilter::CheckDtls(TlsParser& parser, size_t length) { + // Read and check DTLS parameters + if (!parser.Skip(2)) { // sequence number + return false; + } + + uint32_t fragment_offset; + if (!parser.Read(&fragment_offset, 3)) { + return false; + } + + uint32_t fragment_length; + if (!parser.Read(&fragment_length, 3)) { + return false; + } + + // All current tests where we are using this code don't fragment. + return (fragment_offset == 0 && fragment_length == length); +} + +size_t TlsHandshakeFilter::ApplyFilter( + uint16_t version, uint8_t handshake_type, const DataBuffer& handshake, + DataBuffer* output, size_t length_offset, size_t value_offset, + bool* changed) { + const DataBuffer* source = &handshake; + DataBuffer filtered; + if (FilterHandshake(version, handshake_type, handshake, &filtered) && + filtered.len() < 0x1000000) { + *changed = true; + std::cerr << "handshake old: " << handshake << std::endl; + std::cerr << "handshake new: " << filtered << std::endl; + source = &filtered; + } + + // Back up and overwrite the (two) length field(s): the handshake message + // length and the DTLS fragment length. + output->Write(length_offset, source->len(), 3); + if (IsDtls(version)) { + output->Write(length_offset + 8, source->len(), 3); + } + output->Write(value_offset, *source); + return value_offset + source->len(); +} + +bool TlsInspectorRecordHandshakeMessage::FilterHandshake( + uint16_t version, uint8_t handshake_type, + const DataBuffer& input, DataBuffer* output) { + // Only do this once. + if (buffer_.len()) { + return false; + } + + if (handshake_type == handshake_type_) { + buffer_ = input; + } + return false; +} + +bool TlsAlertRecorder::FilterRecord(uint8_t content_type, uint16_t version, + const DataBuffer& input, DataBuffer* output) { + if (level_ == kTlsAlertFatal) { // already fatal + return false; + } + if (content_type != kTlsAlertType) { + return false; + } + + TlsParser parser(input); + uint8_t lvl; + if (!parser.Read(&lvl)) { + return false; + } + if (lvl == kTlsAlertWarning) { // not strong enough + return false; + } + level_ = lvl; + (void)parser.Read(&description_); + return false; +} + +ChainedPacketFilter::~ChainedPacketFilter() { + for (auto it = filters_.begin(); it != filters_.end(); ++it) { + delete *it; + } +} + +bool ChainedPacketFilter::Filter(const DataBuffer& input, DataBuffer* output) { + DataBuffer in(input); + bool changed = false; + for (auto it = filters_.begin(); it != filters_.end(); ++it) { + if ((*it)->Filter(in, output)) { + in = *output; + changed = true; + } + } + return changed; +} + +} // namespace nss_test diff --git a/external_tests/ssl_gtest/tls_filter.h b/external_tests/ssl_gtest/tls_filter.h new file mode 100644 index 000000000..7ebd2c482 --- /dev/null +++ b/external_tests/ssl_gtest/tls_filter.h @@ -0,0 +1,113 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef tls_filter_h_ +#define tls_filter_h_ + +#include <memory> +#include <vector> + +#include "test_io.h" +#include "tls_parser.h" + +namespace nss_test { + +// Abstract filter that operates on entire (D)TLS records. +class TlsRecordFilter : public PacketFilter { + public: + TlsRecordFilter() : count_(0) {} + + virtual bool Filter(const DataBuffer& input, DataBuffer* output); + + // Report how many packets were altered by the filter. + size_t filtered_packets() const { return count_; } + + protected: + virtual bool FilterRecord(uint8_t content_type, uint16_t version, + const DataBuffer& data, DataBuffer* changed) = 0; + private: + size_t ApplyFilter(uint8_t content_type, uint16_t version, + const DataBuffer& record, DataBuffer* output, + size_t offset, bool* changed); + + size_t count_; +}; + +// Abstract filter that operates on handshake messages rather than records. +// This assumes that the handshake messages are written in a block as entire +// records and that they don't span records or anything crazy like that. +class TlsHandshakeFilter : public TlsRecordFilter { + public: + TlsHandshakeFilter() {} + + protected: + virtual bool FilterRecord(uint8_t content_type, uint16_t version, + const DataBuffer& input, DataBuffer* output); + virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type, + const DataBuffer& input, DataBuffer* output) = 0; + + private: + bool CheckDtls(TlsParser& parser, size_t length); + size_t ApplyFilter(uint16_t version, uint8_t handshake_type, + const DataBuffer& record, DataBuffer* output, + size_t length_offset, size_t value_offset, bool* changed); +}; + +// Make a copy of the first instance of a handshake message. +class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter { + public: + TlsInspectorRecordHandshakeMessage(uint8_t handshake_type) + : handshake_type_(handshake_type), buffer_() {} + + virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type, + const DataBuffer& input, DataBuffer* output); + + const DataBuffer& buffer() const { return buffer_; } + + private: + uint8_t handshake_type_; + DataBuffer buffer_; +}; + +// Records an alert. If an alert has already been recorded, it won't save the +// new alert unless the old alert is a warning and the new one is fatal. +class TlsAlertRecorder : public TlsRecordFilter { + public: + TlsAlertRecorder() : level_(255), description_(255) {} + + virtual bool FilterRecord(uint8_t content_type, uint16_t version, + const DataBuffer& input, DataBuffer* output); + + uint8_t level() const { return level_; } + uint8_t description() const { return description_; } + + private: + uint8_t level_; + uint8_t description_; +}; + +// Runs multiple packet filters in series. +class ChainedPacketFilter : public PacketFilter { + public: + ChainedPacketFilter() {} + ChainedPacketFilter(const std::vector<PacketFilter*> filters) + : filters_(filters.begin(), filters.end()) {} + virtual ~ChainedPacketFilter(); + + virtual bool Filter(const DataBuffer& input, DataBuffer* output); + + // Takes ownership of the filter. + void Add(PacketFilter* filter) { + filters_.push_back(filter); + } + + private: + std::vector<PacketFilter*> filters_; +}; + +} // namespace nss_test + +#endif diff --git a/external_tests/ssl_gtest/tls_parser.cc b/external_tests/ssl_gtest/tls_parser.cc index cbd4c0239..1d56fffbf 100644 --- a/external_tests/ssl_gtest/tls_parser.cc +++ b/external_tests/ssl_gtest/tls_parser.cc @@ -6,13 +6,9 @@ #include "tls_parser.h" -// Process DTLS Records -#define CHECK_LENGTH(expected) \ - do { \ - if (remaining() < expected) return false; \ - } while (0) +namespace nss_test { -bool TlsParser::Read(unsigned char* val) { +bool TlsParser::Read(uint8_t* val) { if (remaining() < 1) { return false; } @@ -21,37 +17,55 @@ bool TlsParser::Read(unsigned char* val) { return true; } -bool TlsParser::Read(unsigned char* val, size_t len) { - if (remaining() < len) { +bool TlsParser::Read(uint32_t* val, size_t size) { + if (size > sizeof(uint32_t)) { return false; } - if (val) { - memcpy(val, ptr(), len); + uint32_t v = 0; + for (size_t i = 0; i < size; ++i) { + uint8_t tmp; + if (!Read(&tmp)) { + return false; + } + + v = (v << 8) | tmp; } - consume(len); + *val = v; return true; } -bool TlsRecordParser::NextRecord(uint8_t* ct, - std::auto_ptr<DataBuffer>* buffer) { - if (!remaining()) return false; - - CHECK_LENGTH(5U); - const uint8_t* ctp = reinterpret_cast<const uint8_t*>(ptr()); - consume(3); // ct + version - - const uint16_t* tmp = reinterpret_cast<const uint16_t*>(ptr()); - size_t length = ntohs(*tmp); - consume(2); +bool TlsParser::Read(DataBuffer* val, size_t len) { + if (remaining() < len) { + return false; + } - CHECK_LENGTH(length); - DataBuffer* db = new DataBuffer(ptr(), length); - consume(length); + val->Assign(ptr(), len); + consume(len); + return true; +} - *ct = *ctp; - buffer->reset(db); +bool TlsParser::ReadVariable(DataBuffer* val, size_t len_size) { + uint32_t len; + if (!Read(&len, len_size)) { + return false; + } + return Read(val, len); +} +bool TlsParser::Skip(size_t len) { + if (len > remaining()) { return false; } + consume(len); return true; } + +bool TlsParser::SkipVariable(size_t len_size) { + uint32_t len; + if (!Read(&len, len_size)) { + return false; + } + return Skip(len); +} + +} // namespace nss_test diff --git a/external_tests/ssl_gtest/tls_parser.h b/external_tests/ssl_gtest/tls_parser.h index 0276501f0..9ac4bdabe 100644 --- a/external_tests/ssl_gtest/tls_parser.h +++ b/external_tests/ssl_gtest/tls_parser.h @@ -8,17 +8,31 @@ #define tls_parser_h_ #include <memory> -#include <stdint.h> -#include <string.h> +#include <cstdint> +#include <cstring> #include <arpa/inet.h> #include "databuffer.h" +namespace nss_test { + const uint8_t kTlsChangeCipherSpecType = 0x14; +const uint8_t kTlsAlertType = 0x15; const uint8_t kTlsHandshakeType = 0x16; +const uint8_t kTlsHandshakeClientHello = 0x01; +const uint8_t kTlsHandshakeServerHello = 0x02; const uint8_t kTlsHandshakeCertificate = 0x0b; const uint8_t kTlsHandshakeServerKeyExchange = 0x0c; +const uint8_t kTlsAlertWarning = 1; +const uint8_t kTlsAlertFatal = 2; + +const uint8_t kTlsAlertHandshakeFailure = 0x28; +const uint8_t kTlsAlertIllegalParameter = 0x2f; +const uint8_t kTlsAlertDecodeError = 0x32; +const uint8_t kTlsAlertUnsupportedExtension = 0x6e; +const uint8_t kTlsAlertNoApplicationProtocol = 0x78; + const uint8_t kTlsFakeChangeCipherSpec[] = { kTlsChangeCipherSpecType, // Type 0xfe, 0xff, // Version @@ -28,56 +42,56 @@ const uint8_t kTlsFakeChangeCipherSpec[] = { 0x01 // Value }; +inline bool IsDtls(uint16_t version) { + return (version & 0x8000) == 0x8000; +} + +inline uint16_t NormalizeTlsVersion(uint16_t version) { + if (version == 0xfeff) { + return 0x0302; // special: DTLS 1.0 == TLS 1.1 + } + if (IsDtls(version)) { + return (version ^ 0xffff) + 0x0201; + } + return version; +} + +inline void WriteVariable(DataBuffer* target, size_t index, + const DataBuffer& buf, size_t len_size) { + target->Write(index, static_cast<uint32_t>(buf.len()), len_size); + target->Write(index + len_size, buf.data(), buf.len()); +} + class TlsParser { public: - TlsParser(const unsigned char *data, size_t len) + TlsParser(const uint8_t* data, size_t len) : buffer_(data, len), offset_(0) {} + explicit TlsParser(const DataBuffer& buf) + : buffer_(buf), offset_(0) {} - bool Read(unsigned char *val); - + bool Read(uint8_t* val); // Read an integral type of specified width. - bool Read(uint32_t *val, size_t len) { - if (len > sizeof(uint32_t)) return false; - - *val = 0; + bool Read(uint32_t* val, size_t size); + // Reads len bytes into dest buffer, overwriting it. + bool Read(DataBuffer* dest, size_t len); + // Reads bytes into dest buffer, overwriting it. The number of bytes is + // determined by reading from len_size bytes from the stream first. + bool ReadVariable(DataBuffer* dest, size_t len_size); - for (size_t i = 0; i < len; ++i) { - unsigned char tmp; + bool Skip(size_t len); + bool SkipVariable(size_t len_size); - (*val) <<= 8; - if (!Read(&tmp)) return false; - - *val += tmp; - } - - return true; - } - - bool Read(unsigned char *val, size_t len); + size_t consumed() const { return offset_; } size_t remaining() const { return buffer_.len() - offset_; } private: void consume(size_t len) { offset_ += len; } - const uint8_t *ptr() const { return buffer_.data() + offset_; } + const uint8_t* ptr() const { return buffer_.data() + offset_; } DataBuffer buffer_; size_t offset_; }; -class TlsRecordParser { - public: - TlsRecordParser(const unsigned char *data, size_t len) - : buffer_(data, len), offset_(0) {} - - bool NextRecord(uint8_t *ct, std::auto_ptr<DataBuffer> *buffer); - - private: - size_t remaining() const { return buffer_.len() - offset_; } - const uint8_t *ptr() const { return buffer_.data() + offset_; } - void consume(size_t len) { offset_ += len; } - - DataBuffer buffer_; - size_t offset_; -}; +} // namespace nss_test #endif |