summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMartin Thomson <martin.thomson@gmail.com>2015-03-03 11:39:56 -0800
committerMartin Thomson <martin.thomson@gmail.com>2015-03-03 11:39:56 -0800
commitd90246d1a11d4cbb8a774df3e2beddd8ec913323 (patch)
treec308179224b0b8cac7787f151b66dfe8d1bbc68c
parenta32941f568893ef73de007fdd53f7220d5d219ff (diff)
downloadnss-hg-d90246d1a11d4cbb8a774df3e2beddd8ec913323.tar.gz
Bug 1139082 - Refactoring ssl_gtest to use filters, r=ekr
-rw-r--r--external_tests/ssl_gtest/databuffer.h123
-rw-r--r--external_tests/ssl_gtest/manifest.mn5
-rw-r--r--external_tests/ssl_gtest/ssl_loopback_unittest.cc605
-rw-r--r--external_tests/ssl_gtest/test_io.cc96
-rw-r--r--external_tests/ssl_gtest/test_io.h33
-rw-r--r--external_tests/ssl_gtest/tls_agent.cc208
-rw-r--r--external_tests/ssl_gtest/tls_agent.h170
-rw-r--r--external_tests/ssl_gtest/tls_connect.cc170
-rw-r--r--external_tests/ssl_gtest/tls_connect.h79
-rw-r--r--external_tests/ssl_gtest/tls_filter.cc226
-rw-r--r--external_tests/ssl_gtest/tls_filter.h113
-rw-r--r--external_tests/ssl_gtest/tls_parser.cc68
-rw-r--r--external_tests/ssl_gtest/tls_parser.h88
13 files changed, 1290 insertions, 694 deletions
diff --git a/external_tests/ssl_gtest/databuffer.h b/external_tests/ssl_gtest/databuffer.h
index 316aeb2a2..c3d3bb9be 100644
--- a/external_tests/ssl_gtest/databuffer.h
+++ b/external_tests/ssl_gtest/databuffer.h
@@ -7,33 +7,142 @@
#ifndef databuffer_h__
#define databuffer_h__
+#include <algorithm>
+#include <cassert>
+#include <cstring>
+#include <iomanip>
+#include <iostream>
+
+namespace nss_test {
+
class DataBuffer {
public:
DataBuffer() : data_(nullptr), len_(0) {}
DataBuffer(const uint8_t *data, size_t len) : data_(nullptr), len_(0) {
Assign(data, len);
}
+ explicit DataBuffer(const DataBuffer& other) : data_(nullptr), len_(0) {
+ Assign(other.data(), other.len());
+ }
~DataBuffer() { delete[] data_; }
- void Assign(const uint8_t *data, size_t len) {
- Allocate(len);
- memcpy(static_cast<void *>(data_), static_cast<const void *>(data), len);
+ DataBuffer& operator=(const DataBuffer& other) {
+ if (&other != this) {
+ Assign(other.data(), other.len());
+ }
+ return *this;
}
void Allocate(size_t len) {
delete[] data_;
- data_ = new unsigned char[len ? len : 1]; // Don't depend on new [0].
+ data_ = new uint8_t[len ? len : 1]; // Don't depend on new [0].
len_ = len;
}
+ void Truncate(size_t len) {
+ len_ = std::min(len_, len);
+ }
+
+ void Assign(const uint8_t* data, size_t len) {
+ Allocate(len);
+ memcpy(static_cast<void *>(data_), static_cast<const void *>(data), len);
+ }
+
+ // Write will do a new allocation and expand the size of the buffer if needed.
+ void Write(size_t index, const uint8_t* val, size_t count) {
+ if (index + count > len_) {
+ size_t newlen = index + count;
+ uint8_t* tmp = new uint8_t[newlen]; // Always > 0.
+ memcpy(static_cast<void*>(tmp),
+ static_cast<const void*>(data_), len_);
+ if (index > len_) {
+ memset(static_cast<void*>(tmp + len_), 0, index - len_);
+ }
+ delete[] data_;
+ data_ = tmp;
+ len_ = newlen;
+ }
+ memcpy(static_cast<void*>(data_ + index),
+ static_cast<const void*>(val), count);
+ }
+
+ void Write(size_t index, const DataBuffer& buf) {
+ Write(index, buf.data(), buf.len());
+ }
+
+ // Write an integer, also performing host-to-network order conversion.
+ void Write(size_t index, uint32_t val, size_t count) {
+ assert(count <= sizeof(uint32_t));
+ uint32_t nvalue = htonl(val);
+ auto* addr = reinterpret_cast<const uint8_t*>(&nvalue);
+ Write(index, addr + sizeof(uint32_t) - count, count);
+ }
+
+ // Starting at |index|, remove |remove| bytes and replace them with the
+ // contents of |buf|.
+ void Splice(const DataBuffer& buf, size_t index, size_t remove = 0) {
+ Splice(buf.data(), buf.len(), index, remove);
+ }
+
+ void Splice(const uint8_t* ins, size_t ins_len, size_t index, size_t remove = 0) {
+ uint8_t* old_value = data_;
+ size_t old_len = len_;
+
+ // The amount of stuff remaining from the tail of the old.
+ size_t tail_len = old_len - std::min(old_len, index + remove);
+ // The new length: the head of the old, the new, and the tail of the old.
+ len_ = index + ins_len + tail_len;
+ data_ = new uint8_t[len_ ? len_ : 1];
+
+ // The head of the old.
+ Write(0, old_value, std::min(old_len, index));
+ // Maybe a gap.
+ if (index > old_len) {
+ memset(old_value + index, 0, index - old_len);
+ }
+ // The new.
+ Write(index, ins, ins_len);
+ // The tail of the old.
+ if (tail_len > 0) {
+ Write(index + ins_len,
+ old_value + index + remove, tail_len);
+ }
+
+ delete[] old_value;
+ }
+
+ void Append(const DataBuffer& buf) { Splice(buf, len_); }
+
const uint8_t *data() const { return data_; }
- uint8_t *data() { return data_; }
+ uint8_t* data() { return data_; }
size_t len() const { return len_; }
- const bool empty() const { return len_ != 0; }
+ bool empty() const { return len_ == 0; }
private:
- uint8_t *data_;
+ uint8_t* data_;
size_t len_;
};
+#ifdef DEBUG
+static const size_t kMaxBufferPrint = 10000;
+#else
+static const size_t kMaxBufferPrint = 32;
+#endif
+
+inline std::ostream& operator<<(std::ostream& stream, const DataBuffer& buf) {
+ stream << "[" << buf.len() << "] ";
+ for (size_t i = 0; i < buf.len(); ++i) {
+ if (i >= kMaxBufferPrint) {
+ stream << "...";
+ break;
+ }
+ stream << std::hex << std::setfill('0') << std::setw(2)
+ << static_cast<unsigned>(buf.data()[i]);
+ }
+ stream << std::dec;
+ return stream;
+}
+
+} // namespace nss_test
+
#endif
diff --git a/external_tests/ssl_gtest/manifest.mn b/external_tests/ssl_gtest/manifest.mn
index e66f532b4..ee883e9ac 100644
--- a/external_tests/ssl_gtest/manifest.mn
+++ b/external_tests/ssl_gtest/manifest.mn
@@ -1,4 +1,4 @@
-#
+#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
@@ -10,6 +10,9 @@ CPPSRCS = \
ssl_loopback_unittest.cc \
ssl_gtest.cc \
test_io.cc \
+ tls_agent.cc \
+ tls_connect.cc \
+ tls_filter.cc \
tls_parser.cc \
$(NULL)
diff --git a/external_tests/ssl_gtest/ssl_loopback_unittest.cc b/external_tests/ssl_gtest/ssl_loopback_unittest.cc
index 6c01887a7..d70e2ceeb 100644
--- a/external_tests/ssl_gtest/ssl_loopback_unittest.cc
+++ b/external_tests/ssl_gtest/ssl_loopback_unittest.cc
@@ -1,182 +1,24 @@
-#include "prio.h"
-#include "prerror.h"
-#include "prlog.h"
-#include "pk11func.h"
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
#include "ssl.h"
-#include "sslerr.h"
#include "sslproto.h"
-#include "keyhi.h"
#include <memory>
-#include "test_io.h"
#include "tls_parser.h"
-
-#define GTEST_HAS_RTTI 0
-#include "gtest/gtest.h"
-#include "gtest_utils.h"
-
-extern std::string g_working_dir_path;
+#include "tls_filter.h"
+#include "tls_connect.h"
namespace nss_test {
-enum SessionResumptionMode {
- RESUME_NONE = 0,
- RESUME_SESSIONID = 1,
- RESUME_TICKET = 2,
- RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET
-};
-
-#define LOG(a) std::cerr << name_ << ": " << a << std::endl;
-
-// Inspector that parses out DTLS records and passes
-// them on.
-class TlsRecordInspector : public Inspector {
- public:
- virtual void Inspect(DummyPrSocket* adapter, const void* data, size_t len) {
- TlsRecordParser parser(static_cast<const unsigned char*>(data), len);
-
- uint8_t content_type;
- std::auto_ptr<DataBuffer> buf;
- while (parser.NextRecord(&content_type, &buf)) {
- OnRecord(adapter, content_type, buf->data(), buf->len());
- }
- }
-
- virtual void OnRecord(DummyPrSocket* adapter, uint8_t content_type,
- const unsigned char* record, size_t len) = 0;
-};
-
-// Inspector that injects arbitrary packets based on
-// DTLS records of various types.
-class TlsInspectorInjector : public TlsRecordInspector {
- public:
- TlsInspectorInjector(uint8_t packet_type, uint8_t handshake_type,
- const unsigned char* data, size_t len)
- : packet_type_(packet_type),
- handshake_type_(handshake_type),
- injected_(false),
- data_(data, len) {}
-
- virtual void OnRecord(DummyPrSocket* adapter, uint8_t content_type,
- const unsigned char* data, size_t len) {
- // Only inject once.
- if (injected_) {
- return;
- }
-
- // Check that the first byte is as requested.
- if (content_type != packet_type_) {
- return;
- }
-
- if (handshake_type_ != 0xff) {
- // Check that the packet is plausibly long enough.
- if (len < 1) {
- return;
- }
-
- // Check that the handshake type is as requested.
- if (data[0] != handshake_type_) {
- return;
- }
- }
-
- adapter->WriteDirect(data_.data(), data_.len());
- }
-
- private:
- uint8_t packet_type_;
- uint8_t handshake_type_;
- bool injected_;
- DataBuffer data_;
-};
-
-// Make a copy of the first instance of a message.
-class TlsInspectorRecordHandshakeMessage : public TlsRecordInspector {
- public:
- TlsInspectorRecordHandshakeMessage(uint8_t handshake_type)
- : handshake_type_(handshake_type), buffer_() {}
-
- virtual void OnRecord(DummyPrSocket* adapter, uint8_t content_type,
- const unsigned char* data, size_t len) {
- // Only do this once.
- if (buffer_.len()) {
- return;
- }
-
- // Check that the first byte is as requested.
- if (content_type != kTlsHandshakeType) {
- return;
- }
-
- TlsParser parser(data, len);
- while (parser.remaining()) {
- unsigned char message_type;
- // Read the content type.
- if (!parser.Read(&message_type)) {
- // Malformed.
- return;
- }
-
- // Read the record length.
- uint32_t length;
- if (!parser.Read(&length, 3)) {
- // Malformed.
- return;
- }
-
- if (adapter->mode() == DGRAM) {
- // DTLS
- uint32_t message_seq;
- if (!parser.Read(&message_seq, 2)) {
- return;
- }
-
- uint32_t fragment_offset;
- if (!parser.Read(&fragment_offset, 3)) {
- return;
- }
-
- uint32_t fragment_length;
- if (!parser.Read(&fragment_length, 3)) {
- return;
- }
-
- if ((fragment_offset != 0) || (fragment_length != length)) {
- // This shouldn't happen because all current tests where we
- // are using this code don't fragment.
- return;
- }
- }
-
- unsigned char* dest = nullptr;
-
- if (message_type == handshake_type_) {
- buffer_.Allocate(length);
- dest = buffer_.data();
- }
-
- if (!parser.Read(dest, length)) {
- // Malformed
- return;
- }
-
- if (dest) return;
- }
- }
-
- const DataBuffer& buffer() { return buffer_; }
-
- private:
- uint8_t handshake_type_;
- DataBuffer buffer_;
-};
-
class TlsServerKeyExchangeECDHE {
public:
- bool Parse(const unsigned char* data, size_t len) {
- TlsParser parser(data, len);
+ bool Parse(const DataBuffer& buffer) {
+ TlsParser parser(buffer);
uint8_t curve_type;
if (!parser.Read(&curve_type)) {
@@ -192,408 +34,12 @@ class TlsServerKeyExchangeECDHE {
return false;
}
- uint32_t point_length;
- if (!parser.Read(&point_length, 1)) {
- return false;
- }
-
- public_key_.Allocate(point_length);
- if (!parser.Read(public_key_.data(), point_length)) {
- return false;
- }
-
- return true;
+ return parser.ReadVariable(&public_key_, 1);
}
DataBuffer public_key_;
};
-class TlsAgent : public PollTarget {
- public:
- enum Role { CLIENT, SERVER };
- enum State { INIT, CONNECTING, CONNECTED, ERROR };
-
- TlsAgent(const std::string& name, Role role, Mode mode)
- : name_(name),
- mode_(mode),
- pr_fd_(nullptr),
- adapter_(nullptr),
- ssl_fd_(nullptr),
- role_(role),
- state_(INIT) {
- memset(&info_, 0, sizeof(info_));
- memset(&csinfo_, 0, sizeof(csinfo_));
- }
-
- ~TlsAgent() {
- if (pr_fd_) {
- PR_Close(pr_fd_);
- }
-
- if (ssl_fd_) {
- PR_Close(ssl_fd_);
- }
- }
-
- bool Init() {
- pr_fd_ = DummyPrSocket::CreateFD(name_, mode_);
- if (!pr_fd_) return false;
-
- adapter_ = DummyPrSocket::GetAdapter(pr_fd_);
- if (!adapter_) return false;
-
- return true;
- }
-
- void SetPeer(TlsAgent* peer) { adapter_->SetPeer(peer->adapter_); }
-
- void SetInspector(Inspector* inspector) { adapter_->SetInspector(inspector); }
-
- void StartConnect() {
- ASSERT_TRUE(EnsureTlsSetup());
-
- SECStatus rv;
- rv = SSL_ResetHandshake(ssl_fd_, role_ == SERVER ? PR_TRUE : PR_FALSE);
- ASSERT_EQ(SECSuccess, rv);
- SetState(CONNECTING);
- }
-
- void EnableSomeECDHECiphers() {
- ASSERT_TRUE(EnsureTlsSetup());
-
- const uint32_t EnabledCiphers[] = {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
- TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA};
-
- for (size_t i = 0; i < PR_ARRAY_SIZE(EnabledCiphers); ++i) {
- SECStatus rv = SSL_CipherPrefSet(ssl_fd_, EnabledCiphers[i], PR_TRUE);
- ASSERT_EQ(SECSuccess, rv);
- }
- }
-
- bool EnsureTlsSetup() {
- // Don't set up twice
- if (ssl_fd_) return true;
-
- if (adapter_->mode() == STREAM) {
- ssl_fd_ = SSL_ImportFD(nullptr, pr_fd_);
- } else {
- ssl_fd_ = DTLS_ImportFD(nullptr, pr_fd_);
- }
-
- EXPECT_NE(nullptr, ssl_fd_);
- if (!ssl_fd_) return false;
- pr_fd_ = nullptr;
-
- if (role_ == SERVER) {
- CERTCertificate* cert = PK11_FindCertFromNickname(name_.c_str(), nullptr);
- EXPECT_NE(nullptr, cert);
- if (!cert) return false;
-
- SECKEYPrivateKey* priv = PK11_FindKeyByAnyCert(cert, nullptr);
- EXPECT_NE(nullptr, priv);
- if (!priv) return false; // Leak cert.
-
- SECStatus rv = SSL_ConfigSecureServer(ssl_fd_, cert, priv, kt_rsa);
- EXPECT_EQ(SECSuccess, rv);
- if (rv != SECSuccess) return false; // Leak cert and key.
-
- SECKEY_DestroyPrivateKey(priv);
- CERT_DestroyCertificate(cert);
- } else {
- SECStatus rv = SSL_SetURL(ssl_fd_, "server");
- EXPECT_EQ(SECSuccess, rv);
- if (rv != SECSuccess) return false;
- }
-
- SECStatus rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook,
- reinterpret_cast<void*>(this));
- EXPECT_EQ(SECSuccess, rv);
- if (rv != SECSuccess) return false;
-
- return true;
- }
-
- void SetVersionRange(uint16_t minver, uint16_t maxver) {
- SSLVersionRange range = {minver, maxver};
- ASSERT_EQ(SECSuccess, SSL_VersionRangeSet(ssl_fd_, &range));
- }
-
- State state() const { return state_; }
-
- const char* state_str() const { return state_str(state()); }
-
- const char* state_str(State state) const { return states[state]; }
-
- PRFileDesc* ssl_fd() { return ssl_fd_; }
-
- bool version(uint16_t* version) const {
- if (state_ != CONNECTED) return false;
-
- *version = info_.protocolVersion;
-
- return true;
- }
-
- bool cipher_suite(int16_t* cipher_suite) const {
- if (state_ != CONNECTED) return false;
-
- *cipher_suite = info_.cipherSuite;
- return true;
- }
-
- std::string cipher_suite_name() const {
- if (state_ != CONNECTED) return "UNKNOWN";
-
- return csinfo_.cipherSuiteName;
- }
-
- void CheckKEAType(SSLKEAType type) const {
- ASSERT_EQ(CONNECTED, state_);
- ASSERT_EQ(type, csinfo_.keaType);
- }
-
- void CheckVersion(uint16_t version) const {
- ASSERT_EQ(CONNECTED, state_);
- ASSERT_EQ(version, info_.protocolVersion);
- }
-
-
- void Handshake() {
- SECStatus rv = SSL_ForceHandshake(ssl_fd_);
- if (rv == SECSuccess) {
- LOG("Handshake success");
- SECStatus rv = SSL_GetChannelInfo(ssl_fd_, &info_, sizeof(info_));
- ASSERT_EQ(SECSuccess, rv);
-
- rv = SSL_GetCipherSuiteInfo(info_.cipherSuite, &csinfo_, sizeof(csinfo_));
- ASSERT_EQ(SECSuccess, rv);
-
- SetState(CONNECTED);
- return;
- }
-
- int32_t err = PR_GetError();
- switch (err) {
- case PR_WOULD_BLOCK_ERROR:
- LOG("Would have blocked");
- // TODO(ekr@rtfm.com): set DTLS timeouts
- Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
- &TlsAgent::ReadableCallback);
- return;
- break;
-
- // TODO(ekr@rtfm.com): needs special case for DTLS
- case SSL_ERROR_RX_MALFORMED_HANDSHAKE:
- default:
- LOG("Handshake failed with error " << err);
- SetState(ERROR);
- return;
- }
- }
-
- std::vector<uint8_t> GetSessionId() {
- return std::vector<uint8_t>(info_.sessionID,
- info_.sessionID + info_.sessionIDLength);
- }
-
- void ConfigureSessionCache(SessionResumptionMode mode) {
- ASSERT_TRUE(EnsureTlsSetup());
-
- SECStatus rv = SSL_OptionSet(ssl_fd_,
- SSL_NO_CACHE,
- mode & RESUME_SESSIONID ?
- PR_FALSE : PR_TRUE);
- ASSERT_EQ(SECSuccess, rv);
-
- rv = SSL_OptionSet(ssl_fd_,
- SSL_ENABLE_SESSION_TICKETS,
- mode & RESUME_TICKET ?
- PR_TRUE : PR_FALSE);
- ASSERT_EQ(SECSuccess, rv);
- }
-
- private:
- const static char* states[];
-
- void SetState(State state) {
- if (state_ == state) return;
-
- LOG("Changing state from " << state_str(state_) << " to "
- << state_str(state));
- state_ = state;
- }
-
- // Dummy auth certificate hook.
- static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd,
- PRBool checksig, PRBool isServer) {
- return SECSuccess;
- }
-
- static void ReadableCallback(PollTarget* self, Event event) {
- TlsAgent* agent = static_cast<TlsAgent*>(self);
- agent->ReadableCallback_int(event);
- }
-
- void ReadableCallback_int(Event event) {
- LOG("Readable");
- Handshake();
- }
-
- const std::string name_;
- Mode mode_;
- PRFileDesc* pr_fd_;
- DummyPrSocket* adapter_;
- PRFileDesc* ssl_fd_;
- Role role_;
- State state_;
- SSLChannelInfo info_;
- SSLCipherSuiteInfo csinfo_;
-};
-
-const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"};
-
-class TlsConnectTestBase : public ::testing::Test {
- public:
- TlsConnectTestBase(Mode mode)
- : mode_(mode),
- client_(new TlsAgent("client", TlsAgent::CLIENT, mode_)),
- server_(new TlsAgent("server", TlsAgent::SERVER, mode_)) {}
-
- ~TlsConnectTestBase() {
- delete client_;
- delete server_;
- }
-
- void SetUp() {
- // Configure a fresh session cache.
- SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
-
- // Clear statistics.
- SSL3Statistics* stats = SSL_GetStatistics();
- memset(stats, 0, sizeof(*stats));
-
- Init();
- }
-
- void TearDown() {
- client_ = nullptr;
- server_ = nullptr;
-
- SSL_ClearSessionCache();
- SSL_ShutdownServerSessionIDCache();
- }
-
- void Init() {
- ASSERT_TRUE(client_->Init());
- ASSERT_TRUE(server_->Init());
-
- client_->SetPeer(server_);
- server_->SetPeer(client_);
- }
-
- void Reset() {
- delete client_;
- delete server_;
-
- client_ = new TlsAgent("client", TlsAgent::CLIENT, mode_);
- server_ = new TlsAgent("server", TlsAgent::SERVER, mode_);
-
- Init();
- }
-
- void EnsureTlsSetup() {
- ASSERT_TRUE(client_->EnsureTlsSetup());
- ASSERT_TRUE(server_->EnsureTlsSetup());
- }
-
- void Connect() {
- server_->StartConnect(); // Server
- client_->StartConnect(); // Client
- client_->Handshake();
- server_->Handshake();
-
- ASSERT_TRUE_WAIT(client_->state() != TlsAgent::CONNECTING &&
- server_->state() != TlsAgent::CONNECTING,
- 5000);
- ASSERT_EQ(TlsAgent::CONNECTED, server_->state());
-
- int16_t cipher_suite1, cipher_suite2;
- bool ret = client_->cipher_suite(&cipher_suite1);
- ASSERT_TRUE(ret);
- ret = server_->cipher_suite(&cipher_suite2);
- ASSERT_TRUE(ret);
- ASSERT_EQ(cipher_suite1, cipher_suite2);
-
- std::cerr << "Connected with cipher suite " << client_->cipher_suite_name()
- << std::endl;
-
- // Check and store session ids.
- std::vector<uint8_t> sid_c1 = client_->GetSessionId();
- ASSERT_EQ(32, sid_c1.size());
- std::vector<uint8_t> sid_s1 = server_->GetSessionId();
- ASSERT_EQ(32, sid_s1.size());
- ASSERT_EQ(sid_c1, sid_s1);
- session_ids_.push_back(sid_c1);
- }
-
- void EnableSomeECDHECiphers() {
- client_->EnableSomeECDHECiphers();
- server_->EnableSomeECDHECiphers();
- }
-
- void ConfigureSessionCache(SessionResumptionMode client,
- SessionResumptionMode server) {
- client_->ConfigureSessionCache(client);
- server_->ConfigureSessionCache(server);
- }
-
- void CheckResumption(SessionResumptionMode expected) {
- ASSERT_NE(RESUME_BOTH, expected);
-
- int resume_ct = expected != 0;
- int stateless_ct = (expected & RESUME_TICKET) ? 1 : 0;
-
- SSL3Statistics* stats = SSL_GetStatistics();
- ASSERT_EQ(resume_ct, stats->hch_sid_cache_hits);
- ASSERT_EQ(resume_ct, stats->hsh_sid_cache_hits);
-
- ASSERT_EQ(stateless_ct, stats->hch_sid_stateless_resumes);
- ASSERT_EQ(stateless_ct, stats->hsh_sid_stateless_resumes);
-
- if (resume_ct) {
- // Check that the last two session ids match.
- ASSERT_GE(2, session_ids_.size());
- ASSERT_EQ(session_ids_[session_ids_.size()-1],
- session_ids_[session_ids_.size()-2]);
- }
- }
-
- protected:
- Mode mode_;
- TlsAgent* client_;
- TlsAgent* server_;
- std::vector<std::vector<uint8_t>> session_ids_;
-};
-
-class TlsConnectTest : public TlsConnectTestBase {
- public:
- TlsConnectTest() : TlsConnectTestBase(STREAM) {}
-};
-
-class DtlsConnectTest : public TlsConnectTestBase {
- public:
- DtlsConnectTest() : TlsConnectTestBase(DGRAM) {}
-};
-
-class TlsConnectGeneric : public TlsConnectTestBase,
- public ::testing::WithParamInterface<std::string> {
- public:
- TlsConnectGeneric()
- : TlsConnectTestBase((GetParam() == "TLS") ? STREAM : DGRAM) {
- std::cerr << "Variant: " << GetParam() << std::endl;
- }
-};
-
TEST_P(TlsConnectGeneric, SetupOnly) {}
TEST_P(TlsConnectGeneric, Connect) {
@@ -729,6 +175,19 @@ TEST_P(TlsConnectGeneric, ConnectTLS_1_2_Only) {
client_->CheckVersion(SSL_LIBRARY_VERSION_TLS_1_2);
}
+TEST_P(TlsConnectGeneric, ConnectAlpn) {
+ EnableAlpn();
+ Connect();
+ client_->CheckAlpn(SSL_NEXT_PROTO_SELECTED, "a");
+ server_->CheckAlpn(SSL_NEXT_PROTO_NEGOTIATED, "a");
+}
+
+TEST_F(DtlsConnectTest, ConnectSrtp) {
+ EnableSrtp();
+ Connect();
+ CheckSrtp();
+}
+
TEST_F(TlsConnectTest, ConnectECDHE) {
EnableSomeECDHECiphers();
Connect();
@@ -739,24 +198,24 @@ TEST_F(TlsConnectTest, ConnectECDHETwiceReuseKey) {
EnableSomeECDHECiphers();
TlsInspectorRecordHandshakeMessage* i1 =
new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
- server_->SetInspector(i1);
+ server_->SetPacketFilter(i1);
Connect();
client_->CheckKEAType(ssl_kea_ecdh);
TlsServerKeyExchangeECDHE dhe1;
- ASSERT_TRUE(dhe1.Parse(i1->buffer().data(), i1->buffer().len()));
+ ASSERT_TRUE(dhe1.Parse(i1->buffer()));
// Restart
Reset();
TlsInspectorRecordHandshakeMessage* i2 =
new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
- server_->SetInspector(i2);
+ server_->SetPacketFilter(i2);
EnableSomeECDHECiphers();
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
Connect();
client_->CheckKEAType(ssl_kea_ecdh);
TlsServerKeyExchangeECDHE dhe2;
- ASSERT_TRUE(dhe2.Parse(i2->buffer().data(), i2->buffer().len()));
+ ASSERT_TRUE(dhe2.Parse(i2->buffer()));
// Make sure they are the same.
ASSERT_EQ(dhe1.public_key_.len(), dhe2.public_key_.len());
@@ -771,11 +230,11 @@ TEST_F(TlsConnectTest, ConnectECDHETwiceNewKey) {
ASSERT_EQ(SECSuccess, rv);
TlsInspectorRecordHandshakeMessage* i1 =
new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
- server_->SetInspector(i1);
+ server_->SetPacketFilter(i1);
Connect();
client_->CheckKEAType(ssl_kea_ecdh);
TlsServerKeyExchangeECDHE dhe1;
- ASSERT_TRUE(dhe1.Parse(i1->buffer().data(), i1->buffer().len()));
+ ASSERT_TRUE(dhe1.Parse(i1->buffer()));
// Restart
Reset();
@@ -784,13 +243,13 @@ TEST_F(TlsConnectTest, ConnectECDHETwiceNewKey) {
ASSERT_EQ(SECSuccess, rv);
TlsInspectorRecordHandshakeMessage* i2 =
new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
- server_->SetInspector(i2);
+ server_->SetPacketFilter(i2);
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
Connect();
client_->CheckKEAType(ssl_kea_ecdh);
TlsServerKeyExchangeECDHE dhe2;
- ASSERT_TRUE(dhe2.Parse(i2->buffer().data(), i2->buffer().len()));
+ ASSERT_TRUE(dhe2.Parse(i2->buffer()));
// Make sure they are different.
ASSERT_FALSE((dhe1.public_key_.len() == dhe2.public_key_.len()) &&
diff --git a/external_tests/ssl_gtest/test_io.cc b/external_tests/ssl_gtest/test_io.cc
index 701647831..2bfd09178 100644
--- a/external_tests/ssl_gtest/test_io.cc
+++ b/external_tests/ssl_gtest/test_io.cc
@@ -4,42 +4,45 @@
* License, v. 2.0. If a copy of the MPL was not distributed with this file,
* You can obtain one at http://mozilla.org/MPL/2.0/. */
-#include <assert.h>
+#include "test_io.h"
+#include <algorithm>
+#include <cassert>
#include <iostream>
#include <memory>
#include "prerror.h"
-#include "prio.h"
#include "prlog.h"
#include "prthread.h"
-#include "test_io.h"
+#include "databuffer.h"
namespace nss_test {
static PRDescIdentity test_fd_identity = PR_INVALID_IO_LAYER;
-#define UNIMPLEMENTED() \
- fprintf(stderr, "Call to unimplemented function %s\n", __FUNCTION__); \
- PR_ASSERT(PR_FALSE); \
+#define UNIMPLEMENTED() \
+ std::cerr << "Call to unimplemented function " \
+ << __FUNCTION__ << std::endl; \
+ PR_ASSERT(PR_FALSE); \
PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0)
#define LOG(a) std::cerr << name_ << ": " << a << std::endl;
-struct Packet {
- Packet() : data_(nullptr), len_(0), offset_(0) {}
+class Packet : public DataBuffer {
+ public:
+ Packet(const DataBuffer& buf) : DataBuffer(buf), offset_(0) {}
- void Assign(const void *data, int32_t len) {
- data_ = new uint8_t[len];
- memcpy(data_, data, len);
- len_ = len;
+ void Advance(size_t delta) {
+ PR_ASSERT(offset_ + delta <= len());
+ offset_ = std::min(len(), offset_ + delta);
}
- ~Packet() { delete data_; }
- uint8_t *data_;
- int32_t len_;
- int32_t offset_;
+ size_t offset() const { return offset_; }
+ size_t remaining() const { return len() - offset_; }
+
+ private:
+ size_t offset_;
};
// Implementation of NSPR methods
@@ -246,6 +249,16 @@ static int32_t DummyReserved(PRFileDesc *f) {
return -1;
}
+DummyPrSocket::~DummyPrSocket() {
+ delete filter_;
+ while (!input_.empty())
+ {
+ Packet* front = input_.front();
+ input_.pop();
+ delete front;
+ }
+}
+
static const struct PRIOMethods DummyMethods = {
PR_DESC_LAYERED, DummyClose, DummyRead,
DummyWrite, DummyAvailable, DummyAvailable64,
@@ -275,9 +288,8 @@ DummyPrSocket *DummyPrSocket::GetAdapter(PRFileDesc *fd) {
return reinterpret_cast<DummyPrSocket *>(fd->secret);
}
-void DummyPrSocket::PacketReceived(const void *data, int32_t len) {
- input_.push(new Packet());
- input_.back()->Assign(data, len);
+void DummyPrSocket::PacketReceived(const DataBuffer& packet) {
+ input_.push(new Packet(packet));
}
int32_t DummyPrSocket::Read(void *data, int32_t len) {
@@ -295,16 +307,18 @@ int32_t DummyPrSocket::Read(void *data, int32_t len) {
}
Packet *front = input_.front();
- int32_t to_read = std::min(len, front->len_ - front->offset_);
- memcpy(data, front->data_ + front->offset_, to_read);
- front->offset_ += to_read;
+ size_t to_read = std::min(static_cast<size_t>(len),
+ front->len() - front->offset());
+ memcpy(data, static_cast<const void*>(front->data() + front->offset()),
+ to_read);
+ front->Advance(to_read);
- if (front->offset_ == front->len_) {
+ if (!front->remaining()) {
input_.pop();
delete front;
}
- return to_read;
+ return static_cast<int32_t>(to_read);
}
int32_t DummyPrSocket::Recv(void *buf, int32_t buflen) {
@@ -314,39 +328,49 @@ int32_t DummyPrSocket::Recv(void *buf, int32_t buflen) {
}
Packet *front = input_.front();
- if (buflen < front->len_) {
+ if (buflen < front->len()) {
PR_ASSERT(false);
PR_SetError(PR_BUFFER_OVERFLOW_ERROR, 0);
return -1;
}
- int32_t count = front->len_;
- memcpy(buf, front->data_, count);
+ size_t count = front->len();
+ memcpy(buf, front->data(), count);
input_.pop();
delete front;
- return count;
+ return static_cast<int32_t>(count);
}
int32_t DummyPrSocket::Write(const void *buf, int32_t length) {
- if (inspector_) {
- inspector_->Inspect(this, buf, length);
+ DataBuffer packet(static_cast<const uint8_t*>(buf),
+ static_cast<size_t>(length));
+ if (filter_) {
+ DataBuffer filtered;
+ if (filter_->Filter(packet, &filtered)) {
+ if (WriteDirect(filtered) != filtered.len()) {
+ PR_SetError(PR_IO_ERROR, 0);
+ return -1;
+ }
+ LOG("Wrote: " << packet);
+ // libssl can't handle if this reports something other than the length of
+ // what was passed in (or less, but we're not doing partial writes).
+ return packet.len();
+ }
}
- return WriteDirect(buf, length);
+ return WriteDirect(packet);
}
-int32_t DummyPrSocket::WriteDirect(const void *buf, int32_t length) {
+int32_t DummyPrSocket::WriteDirect(const DataBuffer& packet) {
if (!peer_) {
PR_SetError(PR_IO_ERROR, 0);
return -1;
}
- LOG("Wrote " << length);
-
- peer_->PacketReceived(buf, length);
- return length;
+ peer_->PacketReceived(packet);
+ return static_cast<int32_t>(packet.len()); // ignore truncation
}
Poller *Poller::instance;
diff --git a/external_tests/ssl_gtest/test_io.h b/external_tests/ssl_gtest/test_io.h
index 64cc4fd5a..d2424c60c 100644
--- a/external_tests/ssl_gtest/test_io.h
+++ b/external_tests/ssl_gtest/test_io.h
@@ -13,25 +13,32 @@
#include <queue>
#include <string>
+#include "prio.h"
+
namespace nss_test {
-struct Packet;
+class DataBuffer;
+class Packet;
class DummyPrSocket; // Fwd decl.
// Allow us to inspect a packet before it is written.
-class Inspector {
+class PacketFilter {
public:
- virtual ~Inspector() {}
-
- virtual void Inspect(DummyPrSocket* adapter, const void* data,
- size_t len) = 0;
+ virtual ~PacketFilter() {}
+
+ // The packet filter takes input and has the option of mutating it.
+ //
+ // A filter that modifies the data places the modified data in *output and
+ // returns true. A filter that does not modify data returns false, in which
+ // case the value in *output is ignored.
+ virtual bool Filter(const DataBuffer& input, DataBuffer* output) = 0;
};
enum Mode { STREAM, DGRAM };
class DummyPrSocket {
public:
- ~DummyPrSocket() { delete inspector_; }
+ ~DummyPrSocket();
static PRFileDesc* CreateFD(const std::string& name,
Mode mode); // Returns an FD.
@@ -39,16 +46,16 @@ class DummyPrSocket {
void SetPeer(DummyPrSocket* peer) { peer_ = peer; }
- void SetInspector(Inspector* inspector) { inspector_ = inspector; }
+ void SetPacketFilter(PacketFilter* filter) { filter_ = filter; }
- void PacketReceived(const void* data, int32_t len);
+ void PacketReceived(const DataBuffer& data);
int32_t Read(void* data, int32_t len);
int32_t Recv(void* buf, int32_t buflen);
int32_t Write(const void* buf, int32_t length);
- int32_t WriteDirect(const void* buf, int32_t length);
+ int32_t WriteDirect(const DataBuffer& data);
Mode mode() const { return mode_; }
- bool readable() { return !input_.empty(); }
+ bool readable() const { return !input_.empty(); }
bool writable() { return true; }
private:
@@ -57,13 +64,13 @@ class DummyPrSocket {
mode_(mode),
peer_(nullptr),
input_(),
- inspector_(nullptr) {}
+ filter_(nullptr) {}
const std::string name_;
Mode mode_;
DummyPrSocket* peer_;
std::queue<Packet*> input_;
- Inspector* inspector_;
+ PacketFilter* filter_;
};
// Marker interface.
diff --git a/external_tests/ssl_gtest/tls_agent.cc b/external_tests/ssl_gtest/tls_agent.cc
new file mode 100644
index 000000000..6eeb651f5
--- /dev/null
+++ b/external_tests/ssl_gtest/tls_agent.cc
@@ -0,0 +1,208 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "tls_agent.h"
+
+#include "pk11func.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+#include "keyhi.h"
+
+#define GTEST_HAS_RTTI 0
+#include "gtest/gtest.h"
+
+namespace nss_test {
+
+const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"};
+
+bool TlsAgent::EnsureTlsSetup() {
+ // Don't set up twice
+ if (ssl_fd_) return true;
+
+ if (adapter_->mode() == STREAM) {
+ ssl_fd_ = SSL_ImportFD(nullptr, pr_fd_);
+ } else {
+ ssl_fd_ = DTLS_ImportFD(nullptr, pr_fd_);
+ }
+
+ EXPECT_NE(nullptr, ssl_fd_);
+ if (!ssl_fd_) return false;
+ pr_fd_ = nullptr;
+
+ if (role_ == SERVER) {
+ CERTCertificate* cert = PK11_FindCertFromNickname(name_.c_str(), nullptr);
+ EXPECT_NE(nullptr, cert);
+ if (!cert) return false;
+
+ SECKEYPrivateKey* priv = PK11_FindKeyByAnyCert(cert, nullptr);
+ EXPECT_NE(nullptr, priv);
+ if (!priv) return false; // Leak cert.
+
+ SECStatus rv = SSL_ConfigSecureServer(ssl_fd_, cert, priv, kt_rsa);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false; // Leak cert and key.
+
+ SECKEY_DestroyPrivateKey(priv);
+ CERT_DestroyCertificate(cert);
+
+ rv = SSL_SNISocketConfigHook(ssl_fd_, SniHook,
+ reinterpret_cast<void*>(this));
+ EXPECT_EQ(SECSuccess, rv); // don't abort, just fail
+ } else {
+ SECStatus rv = SSL_SetURL(ssl_fd_, "server");
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+ }
+
+ SECStatus rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook,
+ reinterpret_cast<void*>(this));
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ return true;
+}
+
+void TlsAgent::StartConnect() {
+ ASSERT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv;
+ rv = SSL_ResetHandshake(ssl_fd_, role_ == SERVER ? PR_TRUE : PR_FALSE);
+ ASSERT_EQ(SECSuccess, rv);
+ SetState(CONNECTING);
+}
+
+void TlsAgent::EnableSomeECDHECiphers() {
+ ASSERT_TRUE(EnsureTlsSetup());
+
+ const uint32_t EnabledCiphers[] = {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+ TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA};
+
+ for (size_t i = 0; i < PR_ARRAY_SIZE(EnabledCiphers); ++i) {
+ SECStatus rv = SSL_CipherPrefSet(ssl_fd_, EnabledCiphers[i], PR_TRUE);
+ ASSERT_EQ(SECSuccess, rv);
+ }
+}
+
+void TlsAgent::SetSessionTicketsEnabled(bool en) {
+ ASSERT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS,
+ en ? PR_TRUE : PR_FALSE);
+ ASSERT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::SetSessionCacheEnabled(bool en) {
+ ASSERT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE,
+ en ? PR_FALSE : PR_TRUE);
+ ASSERT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) {
+ SSLVersionRange range = {minver, maxver};
+ ASSERT_EQ(SECSuccess, SSL_VersionRangeSet(ssl_fd_, &range));
+}
+
+void TlsAgent::CheckKEAType(SSLKEAType type) const {
+ ASSERT_EQ(CONNECTED, state_);
+ ASSERT_EQ(type, csinfo_.keaType);
+}
+
+void TlsAgent::CheckVersion(uint16_t version) const {
+ ASSERT_EQ(CONNECTED, state_);
+ ASSERT_EQ(version, info_.protocolVersion);
+}
+
+void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) {
+ ASSERT_TRUE(EnsureTlsSetup());
+
+ ASSERT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_ENABLE_ALPN, PR_TRUE));
+ ASSERT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd_, val, len));
+}
+
+void TlsAgent::CheckAlpn(SSLNextProtoState expected_state,
+ const std::string& expected) {
+ SSLNextProtoState state;
+ char chosen[10];
+ unsigned int chosen_len;
+ SECStatus rv = SSL_GetNextProto(ssl_fd_, &state,
+ reinterpret_cast<unsigned char*>(chosen),
+ &chosen_len, sizeof(chosen));
+ ASSERT_EQ(SECSuccess, rv);
+ ASSERT_EQ(expected_state, state);
+ ASSERT_EQ(expected, std::string(chosen, chosen_len));
+}
+
+void TlsAgent::EnableSrtp() {
+ ASSERT_TRUE(EnsureTlsSetup());
+ const uint16_t ciphers[] = {
+ SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32
+ };
+ ASSERT_EQ(SECSuccess, SSL_SetSRTPCiphers(ssl_fd_, ciphers,
+ PR_ARRAY_SIZE(ciphers)));
+
+}
+
+void TlsAgent::CheckSrtp() {
+ uint16_t actual;
+ ASSERT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd_, &actual));
+ ASSERT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual);
+}
+
+
+void TlsAgent::Handshake() {
+ SECStatus rv = SSL_ForceHandshake(ssl_fd_);
+ if (rv == SECSuccess) {
+ LOG("Handshake success");
+ SECStatus rv = SSL_GetChannelInfo(ssl_fd_, &info_, sizeof(info_));
+ ASSERT_EQ(SECSuccess, rv);
+
+ rv = SSL_GetCipherSuiteInfo(info_.cipherSuite, &csinfo_, sizeof(csinfo_));
+ ASSERT_EQ(SECSuccess, rv);
+
+ SetState(CONNECTED);
+ return;
+ }
+
+ int32_t err = PR_GetError();
+ switch (err) {
+ case PR_WOULD_BLOCK_ERROR:
+ LOG("Would have blocked");
+ // TODO(ekr@rtfm.com): set DTLS timeouts
+ Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
+ &TlsAgent::ReadableCallback);
+ return;
+ break;
+
+ // TODO(ekr@rtfm.com): needs special case for DTLS
+ case SSL_ERROR_RX_MALFORMED_HANDSHAKE:
+ default:
+ LOG("Handshake failed with error " << err);
+ SetState(ERROR);
+ return;
+ }
+}
+
+void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) {
+ ASSERT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv = SSL_OptionSet(ssl_fd_,
+ SSL_NO_CACHE,
+ mode & RESUME_SESSIONID ?
+ PR_FALSE : PR_TRUE);
+ ASSERT_EQ(SECSuccess, rv);
+
+ rv = SSL_OptionSet(ssl_fd_,
+ SSL_ENABLE_SESSION_TICKETS,
+ mode & RESUME_TICKET ?
+ PR_TRUE : PR_FALSE);
+ ASSERT_EQ(SECSuccess, rv);
+}
+
+
+} // namespace nss_test
diff --git a/external_tests/ssl_gtest/tls_agent.h b/external_tests/ssl_gtest/tls_agent.h
new file mode 100644
index 000000000..aee835ea7
--- /dev/null
+++ b/external_tests/ssl_gtest/tls_agent.h
@@ -0,0 +1,170 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef tls_agent_h_
+#define tls_agent_h_
+
+#include "prio.h"
+#include "ssl.h"
+
+#include <iostream>
+
+#include "test_io.h"
+
+namespace nss_test {
+
+#define LOG(msg) std::cerr << name_ << ": " << msg << std::endl
+
+enum SessionResumptionMode {
+ RESUME_NONE = 0,
+ RESUME_SESSIONID = 1,
+ RESUME_TICKET = 2,
+ RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET
+};
+
+class TlsAgent : public PollTarget {
+ public:
+ enum Role { CLIENT, SERVER };
+ enum State { INIT, CONNECTING, CONNECTED, ERROR };
+
+ TlsAgent(const std::string& name, Role role, Mode mode)
+ : name_(name),
+ mode_(mode),
+ pr_fd_(nullptr),
+ adapter_(nullptr),
+ ssl_fd_(nullptr),
+ role_(role),
+ state_(INIT) {
+ memset(&info_, 0, sizeof(info_));
+ memset(&csinfo_, 0, sizeof(csinfo_));
+ }
+
+ ~TlsAgent() {
+ if (pr_fd_) {
+ PR_Close(pr_fd_);
+ }
+
+ if (ssl_fd_) {
+ PR_Close(ssl_fd_);
+ }
+ }
+
+ bool Init() {
+ pr_fd_ = DummyPrSocket::CreateFD(name_, mode_);
+ if (!pr_fd_) return false;
+
+ adapter_ = DummyPrSocket::GetAdapter(pr_fd_);
+ if (!adapter_) return false;
+
+ return true;
+ }
+
+ void SetPeer(TlsAgent* peer) { adapter_->SetPeer(peer->adapter_); }
+
+ void SetPacketFilter(PacketFilter* filter) {
+ adapter_->SetPacketFilter(filter);
+ }
+
+
+ void StartConnect();
+ void CheckKEAType(SSLKEAType type) const;
+ void CheckVersion(uint16_t version) const;
+
+ void Handshake();
+ void EnableSomeECDHECiphers();
+ bool EnsureTlsSetup();
+
+ void ConfigureSessionCache(SessionResumptionMode mode);
+ void SetSessionTicketsEnabled(bool en);
+ void SetSessionCacheEnabled(bool en);
+ void SetVersionRange(uint16_t minver, uint16_t maxver);
+ void EnableAlpn(const uint8_t* val, size_t len);
+ void CheckAlpn(SSLNextProtoState expected_state,
+ const std::string& expected);
+ void EnableSrtp();
+ void CheckSrtp();
+
+ State state() const { return state_; }
+
+ const char* state_str() const { return state_str(state()); }
+
+ const char* state_str(State state) const { return states[state]; }
+
+ PRFileDesc* ssl_fd() { return ssl_fd_; }
+
+ bool version(uint16_t* version) const {
+ if (state_ != CONNECTED) return false;
+
+ *version = info_.protocolVersion;
+
+ return true;
+ }
+
+ bool cipher_suite(int16_t* cipher_suite) const {
+ if (state_ != CONNECTED) return false;
+
+ *cipher_suite = info_.cipherSuite;
+ return true;
+ }
+
+ std::string cipher_suite_name() const {
+ if (state_ != CONNECTED) return "UNKNOWN";
+
+ return csinfo_.cipherSuiteName;
+ }
+
+ std::vector<uint8_t> session_id() const {
+ return std::vector<uint8_t>(info_.sessionID,
+ info_.sessionID + info_.sessionIDLength);
+ }
+
+ private:
+ const static char* states[];
+
+ void SetState(State state) {
+ if (state_ == state) return;
+
+ LOG("Changing state from " << state_str(state_) << " to "
+ << state_str(state));
+ state_ = state;
+ }
+
+ // Dummy auth certificate hook.
+ static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd,
+ PRBool checksig, PRBool isServer) {
+ return SECSuccess;
+ }
+
+ static void ReadableCallback(PollTarget* self, Event event) {
+ TlsAgent* agent = static_cast<TlsAgent*>(self);
+ agent->ReadableCallback_int();
+ }
+
+ void ReadableCallback_int() {
+ LOG("Readable");
+ Handshake();
+ }
+
+ static PRInt32 SniHook(PRFileDesc *fd, const SECItem *srvNameArr,
+ PRUint32 srvNameArrSize,
+ void *arg) {
+ return SSL_SNI_CURRENT_CONFIG_IS_USED;
+ }
+
+ const std::string name_;
+ Mode mode_;
+ PRFileDesc* pr_fd_;
+ DummyPrSocket* adapter_;
+ PRFileDesc* ssl_fd_;
+ Role role_;
+ State state_;
+ SSLChannelInfo info_;
+ SSLCipherSuiteInfo csinfo_;
+};
+
+} // namespace nss_test
+
+#endif
diff --git a/external_tests/ssl_gtest/tls_connect.cc b/external_tests/ssl_gtest/tls_connect.cc
new file mode 100644
index 000000000..6c6fd1a53
--- /dev/null
+++ b/external_tests/ssl_gtest/tls_connect.cc
@@ -0,0 +1,170 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "tls_connect.h"
+
+#include <iostream>
+
+#include "gtest_utils.h"
+
+extern std::string g_working_dir_path;
+
+namespace nss_test {
+
+TlsConnectTestBase::TlsConnectTestBase(Mode mode)
+ : mode_(mode),
+ client_(new TlsAgent("client", TlsAgent::CLIENT, mode_)),
+ server_(new TlsAgent("server", TlsAgent::SERVER, mode_)) {}
+
+TlsConnectTestBase::~TlsConnectTestBase() {
+ delete client_;
+ delete server_;
+}
+
+void TlsConnectTestBase::SetUp() {
+ // Configure a fresh session cache.
+ SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
+
+ // Clear statistics.
+ SSL3Statistics* stats = SSL_GetStatistics();
+ memset(stats, 0, sizeof(*stats));
+
+ Init();
+}
+
+void TlsConnectTestBase::TearDown() {
+ client_ = nullptr;
+ server_ = nullptr;
+
+ SSL_ClearSessionCache();
+ SSL_ShutdownServerSessionIDCache();
+}
+
+void TlsConnectTestBase::Init() {
+ ASSERT_TRUE(client_->Init());
+ ASSERT_TRUE(server_->Init());
+
+ client_->SetPeer(server_);
+ server_->SetPeer(client_);
+}
+
+void TlsConnectTestBase::Reset() {
+ delete client_;
+ delete server_;
+
+ client_ = new TlsAgent("client", TlsAgent::CLIENT, mode_);
+ server_ = new TlsAgent("server", TlsAgent::SERVER, mode_);
+
+ Init();
+}
+
+void TlsConnectTestBase::EnsureTlsSetup() {
+ ASSERT_TRUE(client_->EnsureTlsSetup());
+ ASSERT_TRUE(server_->EnsureTlsSetup());
+}
+
+void TlsConnectTestBase::Handshake() {
+ server_->StartConnect();
+ client_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+
+ ASSERT_TRUE_WAIT(client_->state() != TlsAgent::CONNECTING &&
+ server_->state() != TlsAgent::CONNECTING,
+ 5000);
+}
+
+void TlsConnectTestBase::Connect() {
+ Handshake();
+
+ ASSERT_EQ(TlsAgent::CONNECTED, client_->state());
+ ASSERT_EQ(TlsAgent::CONNECTED, server_->state());
+
+ int16_t cipher_suite1, cipher_suite2;
+ bool ret = client_->cipher_suite(&cipher_suite1);
+ ASSERT_TRUE(ret);
+ ret = server_->cipher_suite(&cipher_suite2);
+ ASSERT_TRUE(ret);
+ ASSERT_EQ(cipher_suite1, cipher_suite2);
+
+ std::cerr << "Connected with cipher suite " << client_->cipher_suite_name()
+ << std::endl;
+
+ // Check and store session ids.
+ std::vector<uint8_t> sid_c1 = client_->session_id();
+ ASSERT_EQ(32, sid_c1.size());
+ std::vector<uint8_t> sid_s1 = server_->session_id();
+ ASSERT_EQ(32, sid_s1.size());
+ ASSERT_EQ(sid_c1, sid_s1);
+ session_ids_.push_back(sid_c1);
+}
+
+void TlsConnectTestBase::ConnectExpectFail() {
+ Handshake();
+
+ ASSERT_EQ(TlsAgent::ERROR, client_->state());
+ ASSERT_EQ(TlsAgent::ERROR, server_->state());
+}
+
+void TlsConnectTestBase::EnableSomeECDHECiphers() {
+ client_->EnableSomeECDHECiphers();
+ server_->EnableSomeECDHECiphers();
+}
+
+
+void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client,
+ SessionResumptionMode server) {
+ client_->ConfigureSessionCache(client);
+ server_->ConfigureSessionCache(server);
+}
+
+void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) {
+ ASSERT_NE(RESUME_BOTH, expected);
+
+ int resume_ct = expected ? 1 : 0;
+ int stateless_ct = (expected & RESUME_TICKET) ? 1 : 0;
+
+ SSL3Statistics* stats = SSL_GetStatistics();
+ ASSERT_EQ(resume_ct, stats->hch_sid_cache_hits);
+ ASSERT_EQ(resume_ct, stats->hsh_sid_cache_hits);
+
+ ASSERT_EQ(stateless_ct, stats->hch_sid_stateless_resumes);
+ ASSERT_EQ(stateless_ct, stats->hsh_sid_stateless_resumes);
+
+ if (resume_ct) {
+ // Check that the last two session ids match.
+ ASSERT_GE(2, session_ids_.size());
+ ASSERT_EQ(session_ids_[session_ids_.size()-1],
+ session_ids_[session_ids_.size()-2]);
+ }
+}
+
+void TlsConnectTestBase::EnableAlpn() {
+ // A simple value of "a", "b". Note that the preferred value of "a" is placed
+ // at the end, because the NSS API follows the now defunct NPN specification,
+ // which places the preferred (and default) entry at the end of the list.
+ // NSS will move this final entry to the front when used with ALPN.
+ static const uint8_t val[] = { 0x01, 0x62, 0x01, 0x61 };
+ client_->EnableAlpn(val, sizeof(val));
+ server_->EnableAlpn(val, sizeof(val));
+}
+
+void TlsConnectTestBase::EnableSrtp() {
+ client_->EnableSrtp();
+ server_->EnableSrtp();
+}
+
+void TlsConnectTestBase::CheckSrtp() {
+ client_->CheckSrtp();
+ server_->CheckSrtp();
+}
+
+TlsConnectGeneric::TlsConnectGeneric()
+ : TlsConnectTestBase((GetParam() == "TLS") ? STREAM : DGRAM) {
+ std::cerr << "Variant: " << GetParam() << std::endl;
+}
+
+} // namespace nss_test
diff --git a/external_tests/ssl_gtest/tls_connect.h b/external_tests/ssl_gtest/tls_connect.h
new file mode 100644
index 000000000..c263fe83f
--- /dev/null
+++ b/external_tests/ssl_gtest/tls_connect.h
@@ -0,0 +1,79 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef tls_connect_h_
+#define tls_connect_h_
+
+#include "sslt.h"
+
+#include "tls_agent.h"
+
+#define GTEST_HAS_RTTI 0
+#include "gtest/gtest.h"
+
+namespace nss_test {
+
+// A generic TLS connection test base.
+class TlsConnectTestBase : public ::testing::Test {
+ public:
+ TlsConnectTestBase(Mode mode);
+ virtual ~TlsConnectTestBase();
+
+ void SetUp();
+ void TearDown();
+
+ // Initialize client and server.
+ void Init();
+ // Re-initialize client and server.
+ void Reset();
+ // Make sure TLS is configured for a connection.
+ void EnsureTlsSetup();
+
+ // Run the handshake.
+ void Handshake();
+ // Connect and check that it works.
+ void Connect();
+ // Connect and expect it to fail.
+ void ConnectExpectFail();
+
+ void EnableSomeECDHECiphers();
+ void ConfigureSessionCache(SessionResumptionMode client,
+ SessionResumptionMode server);
+ void CheckResumption(SessionResumptionMode expected);
+ void EnableAlpn();
+ void EnableSrtp();
+ void CheckSrtp();
+
+ protected:
+ Mode mode_;
+ TlsAgent* client_;
+ TlsAgent* server_;
+ std::vector<std::vector<uint8_t>> session_ids_;
+};
+
+// A TLS-only test base.
+class TlsConnectTest : public TlsConnectTestBase {
+ public:
+ TlsConnectTest() : TlsConnectTestBase(STREAM) {}
+};
+
+// A DTLS-only test base.
+class DtlsConnectTest : public TlsConnectTestBase {
+ public:
+ DtlsConnectTest() : TlsConnectTestBase(DGRAM) {}
+};
+
+// A generic test class that can be either STREAM or DGRAM. This is configured
+// in ssl_loopback_unittest.cc. All uses of this should use TEST_P().
+class TlsConnectGeneric : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<std::string> {
+ public:
+ TlsConnectGeneric();
+};
+
+} // namespace nss_test
+
+#endif
diff --git a/external_tests/ssl_gtest/tls_filter.cc b/external_tests/ssl_gtest/tls_filter.cc
new file mode 100644
index 000000000..3cbe9e5ac
--- /dev/null
+++ b/external_tests/ssl_gtest/tls_filter.cc
@@ -0,0 +1,226 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "tls_filter.h"
+
+#include <iostream>
+
+namespace nss_test {
+
+bool TlsRecordFilter::Filter(const DataBuffer& input, DataBuffer* output) {
+ bool changed = false;
+ size_t output_offset = 0U;
+ output->Allocate(input.len());
+
+ TlsParser parser(input);
+ while (parser.remaining()) {
+ size_t start = parser.consumed();
+ uint8_t content_type;
+ if (!parser.Read(&content_type)) {
+ return false;
+ }
+ uint32_t version;
+ if (!parser.Read(&version, 2)) {
+ return false;
+ }
+
+ if (IsDtls(version)) {
+ if (!parser.Skip(8)) {
+ return false;
+ }
+ }
+ size_t header_len = parser.consumed() - start;
+ output->Write(output_offset, input.data() + start, header_len);
+
+ DataBuffer record;
+ if (!parser.ReadVariable(&record, 2)) {
+ return false;
+ }
+
+ // Move the offset in the output forward. ApplyFilter() returns the index
+ // of the end of the record it wrote to the output, so we need to skip
+ // over the content type and version for the value passed to it.
+ output_offset = ApplyFilter(content_type, version, record, output,
+ output_offset + header_len,
+ &changed);
+ }
+ output->Truncate(output_offset);
+
+ // Record how many packets we actually touched.
+ if (changed) {
+ ++count_;
+ }
+
+ return changed;
+}
+
+size_t TlsRecordFilter::ApplyFilter(uint8_t content_type, uint16_t version,
+ const DataBuffer& record,
+ DataBuffer* output,
+ size_t offset, bool* changed) {
+ const DataBuffer* source = &record;
+ DataBuffer filtered;
+ if (FilterRecord(content_type, version, record, &filtered) &&
+ filtered.len() < 0x10000) {
+ *changed = true;
+ std::cerr << "record old: " << record << std::endl;
+ std::cerr << "record old: " << filtered << std::endl;
+ source = &filtered;
+ }
+
+ output->Write(offset, source->len(), 2);
+ output->Write(offset + 2, *source);
+ return offset + 2 + source->len();
+}
+
+bool TlsHandshakeFilter::FilterRecord(uint8_t content_type, uint16_t version,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ // Check that the first byte is as requested.
+ if (content_type != kTlsHandshakeType) {
+ return false;
+ }
+
+ bool changed = false;
+ size_t output_offset = 0U;
+ output->Allocate(input.len()); // Preallocate a little.
+
+ TlsParser parser(input);
+ while (parser.remaining()) {
+ size_t start = parser.consumed();
+ uint8_t handshake_type;
+ if (!parser.Read(&handshake_type)) {
+ return false; // malformed
+ }
+ uint32_t length;
+ if (!parser.Read(&length, 3)) {
+ return false; // malformed
+ }
+
+ if (IsDtls(version) && !CheckDtls(parser, length)) {
+ return false;
+ }
+
+ size_t header_len = parser.consumed() - start;
+ output->Write(output_offset, input.data() + start, header_len);
+
+ DataBuffer handshake;
+ if (!parser.Read(&handshake, length)) {
+ return false;
+ }
+
+ // Move the offset in the output forward. ApplyFilter() returns the index
+ // of the end of the message it wrote to the output, so we need to identify
+ // offsets from the start of the message for length and the handshake
+ // message.
+ output_offset = ApplyFilter(version, handshake_type, handshake,
+ output, output_offset + 1,
+ output_offset + header_len,
+ &changed);
+ }
+ output->Truncate(output_offset);
+ return changed;
+}
+
+bool TlsHandshakeFilter::CheckDtls(TlsParser& parser, size_t length) {
+ // Read and check DTLS parameters
+ if (!parser.Skip(2)) { // sequence number
+ return false;
+ }
+
+ uint32_t fragment_offset;
+ if (!parser.Read(&fragment_offset, 3)) {
+ return false;
+ }
+
+ uint32_t fragment_length;
+ if (!parser.Read(&fragment_length, 3)) {
+ return false;
+ }
+
+ // All current tests where we are using this code don't fragment.
+ return (fragment_offset == 0 && fragment_length == length);
+}
+
+size_t TlsHandshakeFilter::ApplyFilter(
+ uint16_t version, uint8_t handshake_type, const DataBuffer& handshake,
+ DataBuffer* output, size_t length_offset, size_t value_offset,
+ bool* changed) {
+ const DataBuffer* source = &handshake;
+ DataBuffer filtered;
+ if (FilterHandshake(version, handshake_type, handshake, &filtered) &&
+ filtered.len() < 0x1000000) {
+ *changed = true;
+ std::cerr << "handshake old: " << handshake << std::endl;
+ std::cerr << "handshake new: " << filtered << std::endl;
+ source = &filtered;
+ }
+
+ // Back up and overwrite the (two) length field(s): the handshake message
+ // length and the DTLS fragment length.
+ output->Write(length_offset, source->len(), 3);
+ if (IsDtls(version)) {
+ output->Write(length_offset + 8, source->len(), 3);
+ }
+ output->Write(value_offset, *source);
+ return value_offset + source->len();
+}
+
+bool TlsInspectorRecordHandshakeMessage::FilterHandshake(
+ uint16_t version, uint8_t handshake_type,
+ const DataBuffer& input, DataBuffer* output) {
+ // Only do this once.
+ if (buffer_.len()) {
+ return false;
+ }
+
+ if (handshake_type == handshake_type_) {
+ buffer_ = input;
+ }
+ return false;
+}
+
+bool TlsAlertRecorder::FilterRecord(uint8_t content_type, uint16_t version,
+ const DataBuffer& input, DataBuffer* output) {
+ if (level_ == kTlsAlertFatal) { // already fatal
+ return false;
+ }
+ if (content_type != kTlsAlertType) {
+ return false;
+ }
+
+ TlsParser parser(input);
+ uint8_t lvl;
+ if (!parser.Read(&lvl)) {
+ return false;
+ }
+ if (lvl == kTlsAlertWarning) { // not strong enough
+ return false;
+ }
+ level_ = lvl;
+ (void)parser.Read(&description_);
+ return false;
+}
+
+ChainedPacketFilter::~ChainedPacketFilter() {
+ for (auto it = filters_.begin(); it != filters_.end(); ++it) {
+ delete *it;
+ }
+}
+
+bool ChainedPacketFilter::Filter(const DataBuffer& input, DataBuffer* output) {
+ DataBuffer in(input);
+ bool changed = false;
+ for (auto it = filters_.begin(); it != filters_.end(); ++it) {
+ if ((*it)->Filter(in, output)) {
+ in = *output;
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+} // namespace nss_test
diff --git a/external_tests/ssl_gtest/tls_filter.h b/external_tests/ssl_gtest/tls_filter.h
new file mode 100644
index 000000000..7ebd2c482
--- /dev/null
+++ b/external_tests/ssl_gtest/tls_filter.h
@@ -0,0 +1,113 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef tls_filter_h_
+#define tls_filter_h_
+
+#include <memory>
+#include <vector>
+
+#include "test_io.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+// Abstract filter that operates on entire (D)TLS records.
+class TlsRecordFilter : public PacketFilter {
+ public:
+ TlsRecordFilter() : count_(0) {}
+
+ virtual bool Filter(const DataBuffer& input, DataBuffer* output);
+
+ // Report how many packets were altered by the filter.
+ size_t filtered_packets() const { return count_; }
+
+ protected:
+ virtual bool FilterRecord(uint8_t content_type, uint16_t version,
+ const DataBuffer& data, DataBuffer* changed) = 0;
+ private:
+ size_t ApplyFilter(uint8_t content_type, uint16_t version,
+ const DataBuffer& record, DataBuffer* output,
+ size_t offset, bool* changed);
+
+ size_t count_;
+};
+
+// Abstract filter that operates on handshake messages rather than records.
+// This assumes that the handshake messages are written in a block as entire
+// records and that they don't span records or anything crazy like that.
+class TlsHandshakeFilter : public TlsRecordFilter {
+ public:
+ TlsHandshakeFilter() {}
+
+ protected:
+ virtual bool FilterRecord(uint8_t content_type, uint16_t version,
+ const DataBuffer& input, DataBuffer* output);
+ virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type,
+ const DataBuffer& input, DataBuffer* output) = 0;
+
+ private:
+ bool CheckDtls(TlsParser& parser, size_t length);
+ size_t ApplyFilter(uint16_t version, uint8_t handshake_type,
+ const DataBuffer& record, DataBuffer* output,
+ size_t length_offset, size_t value_offset, bool* changed);
+};
+
+// Make a copy of the first instance of a handshake message.
+class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter {
+ public:
+ TlsInspectorRecordHandshakeMessage(uint8_t handshake_type)
+ : handshake_type_(handshake_type), buffer_() {}
+
+ virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type,
+ const DataBuffer& input, DataBuffer* output);
+
+ const DataBuffer& buffer() const { return buffer_; }
+
+ private:
+ uint8_t handshake_type_;
+ DataBuffer buffer_;
+};
+
+// Records an alert. If an alert has already been recorded, it won't save the
+// new alert unless the old alert is a warning and the new one is fatal.
+class TlsAlertRecorder : public TlsRecordFilter {
+ public:
+ TlsAlertRecorder() : level_(255), description_(255) {}
+
+ virtual bool FilterRecord(uint8_t content_type, uint16_t version,
+ const DataBuffer& input, DataBuffer* output);
+
+ uint8_t level() const { return level_; }
+ uint8_t description() const { return description_; }
+
+ private:
+ uint8_t level_;
+ uint8_t description_;
+};
+
+// Runs multiple packet filters in series.
+class ChainedPacketFilter : public PacketFilter {
+ public:
+ ChainedPacketFilter() {}
+ ChainedPacketFilter(const std::vector<PacketFilter*> filters)
+ : filters_(filters.begin(), filters.end()) {}
+ virtual ~ChainedPacketFilter();
+
+ virtual bool Filter(const DataBuffer& input, DataBuffer* output);
+
+ // Takes ownership of the filter.
+ void Add(PacketFilter* filter) {
+ filters_.push_back(filter);
+ }
+
+ private:
+ std::vector<PacketFilter*> filters_;
+};
+
+} // namespace nss_test
+
+#endif
diff --git a/external_tests/ssl_gtest/tls_parser.cc b/external_tests/ssl_gtest/tls_parser.cc
index cbd4c0239..1d56fffbf 100644
--- a/external_tests/ssl_gtest/tls_parser.cc
+++ b/external_tests/ssl_gtest/tls_parser.cc
@@ -6,13 +6,9 @@
#include "tls_parser.h"
-// Process DTLS Records
-#define CHECK_LENGTH(expected) \
- do { \
- if (remaining() < expected) return false; \
- } while (0)
+namespace nss_test {
-bool TlsParser::Read(unsigned char* val) {
+bool TlsParser::Read(uint8_t* val) {
if (remaining() < 1) {
return false;
}
@@ -21,37 +17,55 @@ bool TlsParser::Read(unsigned char* val) {
return true;
}
-bool TlsParser::Read(unsigned char* val, size_t len) {
- if (remaining() < len) {
+bool TlsParser::Read(uint32_t* val, size_t size) {
+ if (size > sizeof(uint32_t)) {
return false;
}
- if (val) {
- memcpy(val, ptr(), len);
+ uint32_t v = 0;
+ for (size_t i = 0; i < size; ++i) {
+ uint8_t tmp;
+ if (!Read(&tmp)) {
+ return false;
+ }
+
+ v = (v << 8) | tmp;
}
- consume(len);
+ *val = v;
return true;
}
-bool TlsRecordParser::NextRecord(uint8_t* ct,
- std::auto_ptr<DataBuffer>* buffer) {
- if (!remaining()) return false;
-
- CHECK_LENGTH(5U);
- const uint8_t* ctp = reinterpret_cast<const uint8_t*>(ptr());
- consume(3); // ct + version
-
- const uint16_t* tmp = reinterpret_cast<const uint16_t*>(ptr());
- size_t length = ntohs(*tmp);
- consume(2);
+bool TlsParser::Read(DataBuffer* val, size_t len) {
+ if (remaining() < len) {
+ return false;
+ }
- CHECK_LENGTH(length);
- DataBuffer* db = new DataBuffer(ptr(), length);
- consume(length);
+ val->Assign(ptr(), len);
+ consume(len);
+ return true;
+}
- *ct = *ctp;
- buffer->reset(db);
+bool TlsParser::ReadVariable(DataBuffer* val, size_t len_size) {
+ uint32_t len;
+ if (!Read(&len, len_size)) {
+ return false;
+ }
+ return Read(val, len);
+}
+bool TlsParser::Skip(size_t len) {
+ if (len > remaining()) { return false; }
+ consume(len);
return true;
}
+
+bool TlsParser::SkipVariable(size_t len_size) {
+ uint32_t len;
+ if (!Read(&len, len_size)) {
+ return false;
+ }
+ return Skip(len);
+}
+
+} // namespace nss_test
diff --git a/external_tests/ssl_gtest/tls_parser.h b/external_tests/ssl_gtest/tls_parser.h
index 0276501f0..9ac4bdabe 100644
--- a/external_tests/ssl_gtest/tls_parser.h
+++ b/external_tests/ssl_gtest/tls_parser.h
@@ -8,17 +8,31 @@
#define tls_parser_h_
#include <memory>
-#include <stdint.h>
-#include <string.h>
+#include <cstdint>
+#include <cstring>
#include <arpa/inet.h>
#include "databuffer.h"
+namespace nss_test {
+
const uint8_t kTlsChangeCipherSpecType = 0x14;
+const uint8_t kTlsAlertType = 0x15;
const uint8_t kTlsHandshakeType = 0x16;
+const uint8_t kTlsHandshakeClientHello = 0x01;
+const uint8_t kTlsHandshakeServerHello = 0x02;
const uint8_t kTlsHandshakeCertificate = 0x0b;
const uint8_t kTlsHandshakeServerKeyExchange = 0x0c;
+const uint8_t kTlsAlertWarning = 1;
+const uint8_t kTlsAlertFatal = 2;
+
+const uint8_t kTlsAlertHandshakeFailure = 0x28;
+const uint8_t kTlsAlertIllegalParameter = 0x2f;
+const uint8_t kTlsAlertDecodeError = 0x32;
+const uint8_t kTlsAlertUnsupportedExtension = 0x6e;
+const uint8_t kTlsAlertNoApplicationProtocol = 0x78;
+
const uint8_t kTlsFakeChangeCipherSpec[] = {
kTlsChangeCipherSpecType, // Type
0xfe, 0xff, // Version
@@ -28,56 +42,56 @@ const uint8_t kTlsFakeChangeCipherSpec[] = {
0x01 // Value
};
+inline bool IsDtls(uint16_t version) {
+ return (version & 0x8000) == 0x8000;
+}
+
+inline uint16_t NormalizeTlsVersion(uint16_t version) {
+ if (version == 0xfeff) {
+ return 0x0302; // special: DTLS 1.0 == TLS 1.1
+ }
+ if (IsDtls(version)) {
+ return (version ^ 0xffff) + 0x0201;
+ }
+ return version;
+}
+
+inline void WriteVariable(DataBuffer* target, size_t index,
+ const DataBuffer& buf, size_t len_size) {
+ target->Write(index, static_cast<uint32_t>(buf.len()), len_size);
+ target->Write(index + len_size, buf.data(), buf.len());
+}
+
class TlsParser {
public:
- TlsParser(const unsigned char *data, size_t len)
+ TlsParser(const uint8_t* data, size_t len)
: buffer_(data, len), offset_(0) {}
+ explicit TlsParser(const DataBuffer& buf)
+ : buffer_(buf), offset_(0) {}
- bool Read(unsigned char *val);
-
+ bool Read(uint8_t* val);
// Read an integral type of specified width.
- bool Read(uint32_t *val, size_t len) {
- if (len > sizeof(uint32_t)) return false;
-
- *val = 0;
+ bool Read(uint32_t* val, size_t size);
+ // Reads len bytes into dest buffer, overwriting it.
+ bool Read(DataBuffer* dest, size_t len);
+ // Reads bytes into dest buffer, overwriting it. The number of bytes is
+ // determined by reading from len_size bytes from the stream first.
+ bool ReadVariable(DataBuffer* dest, size_t len_size);
- for (size_t i = 0; i < len; ++i) {
- unsigned char tmp;
+ bool Skip(size_t len);
+ bool SkipVariable(size_t len_size);
- (*val) <<= 8;
- if (!Read(&tmp)) return false;
-
- *val += tmp;
- }
-
- return true;
- }
-
- bool Read(unsigned char *val, size_t len);
+ size_t consumed() const { return offset_; }
size_t remaining() const { return buffer_.len() - offset_; }
private:
void consume(size_t len) { offset_ += len; }
- const uint8_t *ptr() const { return buffer_.data() + offset_; }
+ const uint8_t* ptr() const { return buffer_.data() + offset_; }
DataBuffer buffer_;
size_t offset_;
};
-class TlsRecordParser {
- public:
- TlsRecordParser(const unsigned char *data, size_t len)
- : buffer_(data, len), offset_(0) {}
-
- bool NextRecord(uint8_t *ct, std::auto_ptr<DataBuffer> *buffer);
-
- private:
- size_t remaining() const { return buffer_.len() - offset_; }
- const uint8_t *ptr() const { return buffer_.data() + offset_; }
- void consume(size_t len) { offset_ += len; }
-
- DataBuffer buffer_;
- size_t offset_;
-};
+} // namespace nss_test
#endif