diff options
-rw-r--r-- | gtests/ssl_gtest/ssl_fragment_unittest.cc | 158 | ||||
-rw-r--r-- | lib/ssl/dtlscon.c | 58 |
2 files changed, 184 insertions, 32 deletions
diff --git a/gtests/ssl_gtest/ssl_fragment_unittest.cc b/gtests/ssl_gtest/ssl_fragment_unittest.cc new file mode 100644 index 000000000..fce46f5ae --- /dev/null +++ b/gtests/ssl_gtest/ssl_fragment_unittest.cc @@ -0,0 +1,158 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +// This class cuts every unencrypted handshake record into two parts. +class RecordFragmenter : public PacketFilter { + public: + RecordFragmenter() : sequence_number_(0), splitting_(true) {} + + private: + class HandshakeSplitter { + public: + HandshakeSplitter(const DataBuffer& input, DataBuffer* output, + uint64_t* sequence_number) + : input_(input), + output_(output), + cursor_(0), + sequence_number_(sequence_number) {} + + private: + void WriteRecord(TlsRecordFilter::RecordHeader& record_header, + DataBuffer& record_fragment) { + TlsRecordFilter::RecordHeader fragment_header( + record_header.version(), record_header.content_type(), + *sequence_number_); + ++*sequence_number_; + if (::g_ssl_gtest_verbose) { + std::cerr << "Fragment: " << fragment_header << ' ' << record_fragment + << std::endl; + } + cursor_ = fragment_header.Write(output_, cursor_, record_fragment); + } + + bool SplitRecord(TlsRecordFilter::RecordHeader& record_header, + DataBuffer& record) { + TlsParser parser(record); + while (parser.remaining()) { + TlsHandshakeFilter::HandshakeHeader handshake_header; + DataBuffer handshake_body; + if (!handshake_header.Parse(&parser, record_header, &handshake_body)) { + ADD_FAILURE() << "couldn't parse handshake header"; + return false; + } + + DataBuffer record_fragment; + // We can't fragment handshake records that are too small. + if (handshake_body.len() < 2) { + handshake_header.Write(&record_fragment, 0U, handshake_body); + WriteRecord(record_header, record_fragment); + continue; + } + + size_t cut = handshake_body.len() / 2; + handshake_header.WriteFragment(&record_fragment, 0U, handshake_body, 0U, + cut); + WriteRecord(record_header, record_fragment); + + handshake_header.WriteFragment(&record_fragment, 0U, handshake_body, + cut, handshake_body.len() - cut); + WriteRecord(record_header, record_fragment); + } + return true; + } + + public: + bool Split() { + TlsParser parser(input_); + while (parser.remaining()) { + TlsRecordFilter::RecordHeader header; + DataBuffer record; + if (!header.Parse(&parser, &record)) { + ADD_FAILURE() << "bad record header"; + return false; + } + + if (::g_ssl_gtest_verbose) { + std::cerr << "Record: " << header << ' ' << record << std::endl; + } + + // Don't touch packets from a non-zero epoch. Leave these unmodified. + if ((header.sequence_number() >> 48) != 0ULL) { + cursor_ = header.Write(output_, cursor_, record); + continue; + } + + // Just rewrite the sequence number (CCS only). + if (header.content_type() != kTlsHandshakeType) { + EXPECT_EQ(kTlsChangeCipherSpecType, header.content_type()); + WriteRecord(header, record); + continue; + } + + if (!SplitRecord(header, record)) { + return false; + } + } + return true; + } + + private: + const DataBuffer& input_; + DataBuffer* output_; + size_t cursor_; + uint64_t* sequence_number_; + }; + + protected: + virtual PacketFilter::Action Filter(const DataBuffer& input, + DataBuffer* output) override { + if (!splitting_) { + return KEEP; + } + + output->Allocate(input.len()); + HandshakeSplitter splitter(input, output, &sequence_number_); + if (!splitter.Split()) { + // If splitting fails, we obviously reached encrypted packets. + // Stop splitting from that point onward. + splitting_ = false; + return KEEP; + } + + return CHANGE; + } + + private: + uint64_t sequence_number_; + bool splitting_; +}; + +TEST_P(TlsConnectDatagram, FragmentClientPackets) { + client_->SetPacketFilter(new RecordFragmenter()); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectDatagram, FragmentServerPackets) { + server_->SetPacketFilter(new RecordFragmenter()); + Connect(); + SendReceive(); +} + +} // namespace nss_test diff --git a/lib/ssl/dtlscon.c b/lib/ssl/dtlscon.c index 8aebda389..f73c9aeed 100644 --- a/lib/ssl/dtlscon.c +++ b/lib/ssl/dtlscon.c @@ -237,6 +237,26 @@ dtls_RetransmitDetected(sslSocket *ss) return rv; } +static SECStatus +dtls_HandleHandshakeMessage(sslSocket *ss, SSL3Opaque *data, PRBool last) +{ + + /* At this point we are advancing our state machine, so we can free our last + * flight of messages. */ + dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight); + ss->ssl3.hs.recvdHighWater = -1; + + /* Reset the timer to the initial value if the retry counter + * is 0, per Sec. 4.2.4.1 */ + dtls_CancelTimer(ss); + if (ss->ssl3.hs.rtRetries == 0) { + ss->ssl3.hs.rtTimeoutMs = DTLS_RETRANSMIT_INITIAL_MS; + } + + return ssl3_HandleHandshakeMessage(ss, data, ss->ssl3.hs.msg_len, + last); +} + /* Called only from ssl3_HandleRecord, for each (deciphered) DTLS record. * origBuf is the decrypted ssl record content and is expected to contain * complete handshake records @@ -331,23 +351,10 @@ dtls_HandleHandshake(sslSocket *ss, sslBuffer *origBuf) ss->ssl3.hs.msg_type = (SSL3HandshakeType)type; ss->ssl3.hs.msg_len = message_length; - /* At this point we are advancing our state machine, so - * we can free our last flight of messages */ - dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight); - ss->ssl3.hs.recvdHighWater = -1; - dtls_CancelTimer(ss); - - /* Reset the timer to the initial value if the retry counter - * is 0, per Sec. 4.2.4.1 */ - if (ss->ssl3.hs.rtRetries == 0) { - ss->ssl3.hs.rtTimeoutMs = DTLS_RETRANSMIT_INITIAL_MS; - } - - rv = ssl3_HandleHandshakeMessage(ss, buf.buf, ss->ssl3.hs.msg_len, + rv = dtls_HandleHandshakeMessage(ss, buf.buf, buf.len == fragment_length); if (rv == SECFailure) { - /* Do not attempt to process rest of messages in this record */ - break; + break; /* Discard the remainder of the record. */ } } else { if (message_seq < ss->ssl3.hs.recvMessageSeq) { @@ -448,24 +455,11 @@ dtls_HandleHandshake(sslSocket *ss, sslBuffer *origBuf) /* If we have all the bytes, then we are good to go */ if (ss->ssl3.hs.recvdHighWater == ss->ssl3.hs.msg_len) { - ss->ssl3.hs.recvdHighWater = -1; + rv = dtls_HandleHandshakeMessage(ss, ss->ssl3.hs.msg_body.buf, + buf.len == fragment_length); - rv = ssl3_HandleHandshakeMessage( - ss, - ss->ssl3.hs.msg_body.buf, ss->ssl3.hs.msg_len, - buf.len == fragment_length); - if (rv == SECFailure) - break; /* Skip rest of record */ - - /* At this point we are advancing our state machine, so - * we can free our last flight of messages */ - dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight); - dtls_CancelTimer(ss); - - /* If there have been no retries this time, reset the - * timer value to the default per Section 4.2.4.1 */ - if (ss->ssl3.hs.rtRetries == 0) { - ss->ssl3.hs.rtTimeoutMs = DTLS_RETRANSMIT_INITIAL_MS; + if (rv == SECFailure) { + break; /* Discard the rest of the record. */ } } } |