summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMartin Thomson <martin.thomson@gmail.com>2016-11-22 10:43:09 +1100
committerMartin Thomson <martin.thomson@gmail.com>2016-11-22 10:43:09 +1100
commit35ec5710364f95d871681320c005bd01278787e6 (patch)
tree0369a05f48a4ee696dbb0bfc5b9bd811e75fdc5c
parent8cc4458cccc1ba92906ffeaeef1a0fcbc6d3694b (diff)
downloadnss-hg-35ec5710364f95d871681320c005bd01278787e6.tar.gz
Bug 1315405 - Refactor DTLS fragment handling, r=ttaubert
-rw-r--r--gtests/ssl_gtest/ssl_fragment_unittest.cc158
-rw-r--r--lib/ssl/dtlscon.c58
2 files changed, 184 insertions, 32 deletions
diff --git a/gtests/ssl_gtest/ssl_fragment_unittest.cc b/gtests/ssl_gtest/ssl_fragment_unittest.cc
new file mode 100644
index 000000000..fce46f5ae
--- /dev/null
+++ b/gtests/ssl_gtest/ssl_fragment_unittest.cc
@@ -0,0 +1,158 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+// This class cuts every unencrypted handshake record into two parts.
+class RecordFragmenter : public PacketFilter {
+ public:
+ RecordFragmenter() : sequence_number_(0), splitting_(true) {}
+
+ private:
+ class HandshakeSplitter {
+ public:
+ HandshakeSplitter(const DataBuffer& input, DataBuffer* output,
+ uint64_t* sequence_number)
+ : input_(input),
+ output_(output),
+ cursor_(0),
+ sequence_number_(sequence_number) {}
+
+ private:
+ void WriteRecord(TlsRecordFilter::RecordHeader& record_header,
+ DataBuffer& record_fragment) {
+ TlsRecordFilter::RecordHeader fragment_header(
+ record_header.version(), record_header.content_type(),
+ *sequence_number_);
+ ++*sequence_number_;
+ if (::g_ssl_gtest_verbose) {
+ std::cerr << "Fragment: " << fragment_header << ' ' << record_fragment
+ << std::endl;
+ }
+ cursor_ = fragment_header.Write(output_, cursor_, record_fragment);
+ }
+
+ bool SplitRecord(TlsRecordFilter::RecordHeader& record_header,
+ DataBuffer& record) {
+ TlsParser parser(record);
+ while (parser.remaining()) {
+ TlsHandshakeFilter::HandshakeHeader handshake_header;
+ DataBuffer handshake_body;
+ if (!handshake_header.Parse(&parser, record_header, &handshake_body)) {
+ ADD_FAILURE() << "couldn't parse handshake header";
+ return false;
+ }
+
+ DataBuffer record_fragment;
+ // We can't fragment handshake records that are too small.
+ if (handshake_body.len() < 2) {
+ handshake_header.Write(&record_fragment, 0U, handshake_body);
+ WriteRecord(record_header, record_fragment);
+ continue;
+ }
+
+ size_t cut = handshake_body.len() / 2;
+ handshake_header.WriteFragment(&record_fragment, 0U, handshake_body, 0U,
+ cut);
+ WriteRecord(record_header, record_fragment);
+
+ handshake_header.WriteFragment(&record_fragment, 0U, handshake_body,
+ cut, handshake_body.len() - cut);
+ WriteRecord(record_header, record_fragment);
+ }
+ return true;
+ }
+
+ public:
+ bool Split() {
+ TlsParser parser(input_);
+ while (parser.remaining()) {
+ TlsRecordFilter::RecordHeader header;
+ DataBuffer record;
+ if (!header.Parse(&parser, &record)) {
+ ADD_FAILURE() << "bad record header";
+ return false;
+ }
+
+ if (::g_ssl_gtest_verbose) {
+ std::cerr << "Record: " << header << ' ' << record << std::endl;
+ }
+
+ // Don't touch packets from a non-zero epoch. Leave these unmodified.
+ if ((header.sequence_number() >> 48) != 0ULL) {
+ cursor_ = header.Write(output_, cursor_, record);
+ continue;
+ }
+
+ // Just rewrite the sequence number (CCS only).
+ if (header.content_type() != kTlsHandshakeType) {
+ EXPECT_EQ(kTlsChangeCipherSpecType, header.content_type());
+ WriteRecord(header, record);
+ continue;
+ }
+
+ if (!SplitRecord(header, record)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ private:
+ const DataBuffer& input_;
+ DataBuffer* output_;
+ size_t cursor_;
+ uint64_t* sequence_number_;
+ };
+
+ protected:
+ virtual PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output) override {
+ if (!splitting_) {
+ return KEEP;
+ }
+
+ output->Allocate(input.len());
+ HandshakeSplitter splitter(input, output, &sequence_number_);
+ if (!splitter.Split()) {
+ // If splitting fails, we obviously reached encrypted packets.
+ // Stop splitting from that point onward.
+ splitting_ = false;
+ return KEEP;
+ }
+
+ return CHANGE;
+ }
+
+ private:
+ uint64_t sequence_number_;
+ bool splitting_;
+};
+
+TEST_P(TlsConnectDatagram, FragmentClientPackets) {
+ client_->SetPacketFilter(new RecordFragmenter());
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectDatagram, FragmentServerPackets) {
+ server_->SetPacketFilter(new RecordFragmenter());
+ Connect();
+ SendReceive();
+}
+
+} // namespace nss_test
diff --git a/lib/ssl/dtlscon.c b/lib/ssl/dtlscon.c
index 8aebda389..f73c9aeed 100644
--- a/lib/ssl/dtlscon.c
+++ b/lib/ssl/dtlscon.c
@@ -237,6 +237,26 @@ dtls_RetransmitDetected(sslSocket *ss)
return rv;
}
+static SECStatus
+dtls_HandleHandshakeMessage(sslSocket *ss, SSL3Opaque *data, PRBool last)
+{
+
+ /* At this point we are advancing our state machine, so we can free our last
+ * flight of messages. */
+ dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight);
+ ss->ssl3.hs.recvdHighWater = -1;
+
+ /* Reset the timer to the initial value if the retry counter
+ * is 0, per Sec. 4.2.4.1 */
+ dtls_CancelTimer(ss);
+ if (ss->ssl3.hs.rtRetries == 0) {
+ ss->ssl3.hs.rtTimeoutMs = DTLS_RETRANSMIT_INITIAL_MS;
+ }
+
+ return ssl3_HandleHandshakeMessage(ss, data, ss->ssl3.hs.msg_len,
+ last);
+}
+
/* Called only from ssl3_HandleRecord, for each (deciphered) DTLS record.
* origBuf is the decrypted ssl record content and is expected to contain
* complete handshake records
@@ -331,23 +351,10 @@ dtls_HandleHandshake(sslSocket *ss, sslBuffer *origBuf)
ss->ssl3.hs.msg_type = (SSL3HandshakeType)type;
ss->ssl3.hs.msg_len = message_length;
- /* At this point we are advancing our state machine, so
- * we can free our last flight of messages */
- dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight);
- ss->ssl3.hs.recvdHighWater = -1;
- dtls_CancelTimer(ss);
-
- /* Reset the timer to the initial value if the retry counter
- * is 0, per Sec. 4.2.4.1 */
- if (ss->ssl3.hs.rtRetries == 0) {
- ss->ssl3.hs.rtTimeoutMs = DTLS_RETRANSMIT_INITIAL_MS;
- }
-
- rv = ssl3_HandleHandshakeMessage(ss, buf.buf, ss->ssl3.hs.msg_len,
+ rv = dtls_HandleHandshakeMessage(ss, buf.buf,
buf.len == fragment_length);
if (rv == SECFailure) {
- /* Do not attempt to process rest of messages in this record */
- break;
+ break; /* Discard the remainder of the record. */
}
} else {
if (message_seq < ss->ssl3.hs.recvMessageSeq) {
@@ -448,24 +455,11 @@ dtls_HandleHandshake(sslSocket *ss, sslBuffer *origBuf)
/* If we have all the bytes, then we are good to go */
if (ss->ssl3.hs.recvdHighWater == ss->ssl3.hs.msg_len) {
- ss->ssl3.hs.recvdHighWater = -1;
+ rv = dtls_HandleHandshakeMessage(ss, ss->ssl3.hs.msg_body.buf,
+ buf.len == fragment_length);
- rv = ssl3_HandleHandshakeMessage(
- ss,
- ss->ssl3.hs.msg_body.buf, ss->ssl3.hs.msg_len,
- buf.len == fragment_length);
- if (rv == SECFailure)
- break; /* Skip rest of record */
-
- /* At this point we are advancing our state machine, so
- * we can free our last flight of messages */
- dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight);
- dtls_CancelTimer(ss);
-
- /* If there have been no retries this time, reset the
- * timer value to the default per Section 4.2.4.1 */
- if (ss->ssl3.hs.rtRetries == 0) {
- ss->ssl3.hs.rtTimeoutMs = DTLS_RETRANSMIT_INITIAL_MS;
+ if (rv == SECFailure) {
+ break; /* Discard the rest of the record. */
}
}
}