summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--gtests/ssl_gtest/libssl_internals.c13
-rw-r--r--gtests/ssl_gtest/ssl_ciphersuite_unittest.cc11
-rw-r--r--gtests/ssl_gtest/ssl_drop_unittest.cc78
-rw-r--r--gtests/ssl_gtest/ssl_fragment_unittest.cc28
-rw-r--r--gtests/ssl_gtest/ssl_hrr_unittest.cc5
-rw-r--r--gtests/ssl_gtest/ssl_loopback_unittest.cc3
-rw-r--r--gtests/ssl_gtest/ssl_record_unittest.cc23
-rw-r--r--gtests/ssl_gtest/tls_agent.cc22
-rw-r--r--gtests/ssl_gtest/tls_agent.h3
-rw-r--r--gtests/ssl_gtest/tls_filter.cc187
-rw-r--r--gtests/ssl_gtest/tls_filter.h29
-rw-r--r--lib/ssl/dtls13con.c37
-rw-r--r--lib/ssl/dtls13con.h4
-rw-r--r--lib/ssl/dtlscon.c105
-rw-r--r--lib/ssl/dtlscon.h2
-rw-r--r--lib/ssl/ssl.h11
-rw-r--r--lib/ssl/ssl3con.c233
-rw-r--r--lib/ssl/ssl3gthr.c106
-rw-r--r--lib/ssl/sslimpl.h16
-rw-r--r--lib/ssl/sslsecur.c2
-rw-r--r--lib/ssl/sslsock.c17
-rw-r--r--lib/ssl/sslspec.h4
-rw-r--r--lib/ssl/tls13con.c95
-rw-r--r--lib/ssl/tls13con.h1
24 files changed, 724 insertions, 311 deletions
diff --git a/gtests/ssl_gtest/libssl_internals.c b/gtests/ssl_gtest/libssl_internals.c
index 17b4ffe49..e43113de4 100644
--- a/gtests/ssl_gtest/libssl_internals.c
+++ b/gtests/ssl_gtest/libssl_internals.c
@@ -237,22 +237,23 @@ SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) {
if (!ss) {
return SECFailure;
}
- if (to >= RECORD_SEQ_MAX) {
+ if (to > RECORD_SEQ_MAX) {
PORT_SetError(SEC_ERROR_INVALID_ARGS);
return SECFailure;
}
ssl_GetSpecWriteLock(ss);
spec = ss->ssl3.crSpec;
- spec->seqNum = to;
+ spec->nextSeqNum = to;
/* For DTLS, we need to fix the record sequence number. For this, we can just
* scrub the entire structure on the assumption that the new sequence number
* is far enough past the last received sequence number. */
- if (spec->seqNum <= spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) {
+ if (spec->nextSeqNum <=
+ spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) {
PORT_SetError(SEC_ERROR_INVALID_ARGS);
return SECFailure;
}
- dtls_RecordSetRecvd(&spec->recvdRecords, spec->seqNum);
+ dtls_RecordSetRecvd(&spec->recvdRecords, spec->nextSeqNum - 1);
ssl_ReleaseSpecWriteLock(ss);
return SECSuccess;
@@ -270,7 +271,7 @@ SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to) {
return SECFailure;
}
ssl_GetSpecWriteLock(ss);
- ss->ssl3.cwSpec->seqNum = to;
+ ss->ssl3.cwSpec->nextSeqNum = to;
ssl_ReleaseSpecWriteLock(ss);
return SECSuccess;
}
@@ -284,7 +285,7 @@ SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra) {
return SECFailure;
}
ssl_GetSpecReadLock(ss);
- to = ss->ssl3.cwSpec->seqNum + DTLS_RECVD_RECORDS_WINDOW + extra;
+ to = ss->ssl3.cwSpec->nextSeqNum + DTLS_RECVD_RECORDS_WINDOW + extra;
ssl_ReleaseSpecReadLock(ss);
return SSLInt_AdvanceWriteSeqNum(fd, to);
}
diff --git a/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc b/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
index fa2238be7..ec289bdd6 100644
--- a/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
+++ b/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
@@ -166,8 +166,8 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase {
case ssl_calg_seed:
break;
}
- EXPECT_TRUE(false) << "No limit for " << csinfo_.cipherSuiteName;
- return 1ULL < 48;
+ ADD_FAILURE() << "No limit for " << csinfo_.cipherSuiteName;
+ return 0;
}
uint64_t last_safe_write() const {
@@ -246,12 +246,13 @@ TEST_P(TlsCipherSuiteTest, ReadLimit) {
client_->SendData(10, 10);
server_->ReadBytes(); // This should be OK.
+ server_->ReadBytes(); // Read twice to flush any 1,N-1 record splitting.
} else {
// In TLS 1.3, reading or writing triggers a KeyUpdate. That would mean
// that the sequence numbers would reset and we wouldn't hit the limit. So
- // we move the sequence number to one less than the limit directly and don't
- // test sending and receiving just before the limit.
- uint64_t last = record_limit() - 1;
+ // move the sequence number to the limit directly and don't test sending and
+ // receiving just before the limit.
+ uint64_t last = record_limit();
EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), last));
}
diff --git a/gtests/ssl_gtest/ssl_drop_unittest.cc b/gtests/ssl_gtest/ssl_drop_unittest.cc
index ee8906deb..a6c25bacf 100644
--- a/gtests/ssl_gtest/ssl_drop_unittest.cc
+++ b/gtests/ssl_gtest/ssl_drop_unittest.cc
@@ -66,7 +66,8 @@ TEST_P(TlsConnectDatagramPre13, DropServerSecondFlightThrice) {
Connect();
}
-class TlsDropDatagram13 : public TlsConnectDatagram13 {
+class TlsDropDatagram13 : public TlsConnectDatagram13,
+ public ::testing::WithParamInterface<bool> {
public:
TlsDropDatagram13()
: client_filters_(),
@@ -77,6 +78,9 @@ class TlsDropDatagram13 : public TlsConnectDatagram13 {
void SetUp() override {
TlsConnectDatagram13::SetUp();
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+ int short_header = GetParam() ? PR_TRUE : PR_FALSE;
+ client_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, short_header);
+ server_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, short_header);
SetFilters();
}
@@ -186,7 +190,7 @@ class TlsDropDatagram13 : public TlsConnectDatagram13 {
// to the client upon receiving the client Finished.
// Dropping complete first and second flights does not produce
// ACKs
-TEST_F(TlsDropDatagram13, DropClientFirstFlightOnce) {
+TEST_P(TlsDropDatagram13, DropClientFirstFlightOnce) {
client_filters_.drop_->Reset({0});
StartConnect();
client_->Handshake();
@@ -195,7 +199,7 @@ TEST_F(TlsDropDatagram13, DropClientFirstFlightOnce) {
CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
}
-TEST_F(TlsDropDatagram13, DropServerFirstFlightOnce) {
+TEST_P(TlsDropDatagram13, DropServerFirstFlightOnce) {
server_filters_.drop_->Reset(0xff);
StartConnect();
client_->Handshake();
@@ -209,7 +213,7 @@ TEST_F(TlsDropDatagram13, DropServerFirstFlightOnce) {
// Dropping the server's first record also does not produce
// an ACK because the next record is ignored.
// TODO(ekr@rtfm.com): We should generate an empty ACK.
-TEST_F(TlsDropDatagram13, DropServerFirstRecordOnce) {
+TEST_P(TlsDropDatagram13, DropServerFirstRecordOnce) {
server_filters_.drop_->Reset({0});
StartConnect();
client_->Handshake();
@@ -221,7 +225,7 @@ TEST_F(TlsDropDatagram13, DropServerFirstRecordOnce) {
// Dropping the second packet of the server's flight should
// produce an ACK.
-TEST_F(TlsDropDatagram13, DropServerSecondRecordOnce) {
+TEST_P(TlsDropDatagram13, DropServerSecondRecordOnce) {
server_filters_.drop_->Reset({1});
StartConnect();
client_->Handshake();
@@ -235,7 +239,7 @@ TEST_F(TlsDropDatagram13, DropServerSecondRecordOnce) {
// Drop the server ACK and verify that the client retransmits
// the ClientHello.
-TEST_F(TlsDropDatagram13, DropServerAckOnce) {
+TEST_P(TlsDropDatagram13, DropServerAckOnce) {
StartConnect();
client_->Handshake();
server_->Handshake();
@@ -263,7 +267,7 @@ TEST_F(TlsDropDatagram13, DropServerAckOnce) {
}
// Drop the client certificate verify.
-TEST_F(TlsDropDatagram13, DropClientCertVerify) {
+TEST_P(TlsDropDatagram13, DropClientCertVerify) {
StartConnect();
client_->SetupClientAuth();
server_->RequestClientAuth(true);
@@ -284,7 +288,7 @@ TEST_F(TlsDropDatagram13, DropClientCertVerify) {
}
// Shrink the MTU down so that certs get split and drop the first piece.
-TEST_F(TlsDropDatagram13, DropFirstHalfOfServerCertificate) {
+TEST_P(TlsDropDatagram13, DropFirstHalfOfServerCertificate) {
server_filters_.drop_->Reset({2});
StartConnect();
ShrinkPostServerHelloMtu();
@@ -311,7 +315,7 @@ TEST_F(TlsDropDatagram13, DropFirstHalfOfServerCertificate) {
}
// Shrink the MTU down so that certs get split and drop the second piece.
-TEST_F(TlsDropDatagram13, DropSecondHalfOfServerCertificate) {
+TEST_P(TlsDropDatagram13, DropSecondHalfOfServerCertificate) {
server_filters_.drop_->Reset({3});
StartConnect();
ShrinkPostServerHelloMtu();
@@ -524,11 +528,11 @@ class TlsFragmentationAndRecoveryTest : public TlsDropDatagram13 {
size_t cert_len_;
};
-TEST_F(TlsFragmentationAndRecoveryTest, DropFirstHalf) { RunTest(0); }
+TEST_P(TlsFragmentationAndRecoveryTest, DropFirstHalf) { RunTest(0); }
-TEST_F(TlsFragmentationAndRecoveryTest, DropSecondHalf) { RunTest(1); }
+TEST_P(TlsFragmentationAndRecoveryTest, DropSecondHalf) { RunTest(1); }
-TEST_F(TlsDropDatagram13, NoDropsDuringZeroRtt) {
+TEST_P(TlsDropDatagram13, NoDropsDuringZeroRtt) {
SetupForZeroRtt();
SetFilters();
std::cerr << "Starting second handshake" << std::endl;
@@ -546,7 +550,7 @@ TEST_F(TlsDropDatagram13, NoDropsDuringZeroRtt) {
0x0002000000000000ULL}); // Finished
}
-TEST_F(TlsDropDatagram13, DropEEDuringZeroRtt) {
+TEST_P(TlsDropDatagram13, DropEEDuringZeroRtt) {
SetupForZeroRtt();
SetFilters();
std::cerr << "Starting second handshake" << std::endl;
@@ -591,7 +595,7 @@ class TlsReorderDatagram13 : public TlsDropDatagram13 {
// Reorder the server records so that EE comes at the end
// of the flight and will still produce an ACK.
-TEST_F(TlsDropDatagram13, ReorderServerEE) {
+TEST_P(TlsDropDatagram13, ReorderServerEE) {
server_filters_.drop_->Reset({1});
StartConnect();
client_->Handshake();
@@ -647,7 +651,7 @@ class TlsSendCipherSpecCapturer {
std::vector<std::shared_ptr<TlsCipherSpec>> send_cipher_specs_;
};
-TEST_F(TlsDropDatagram13, SendOutOfOrderAppWithHandshakeKey) {
+TEST_P(TlsDropDatagram13, SendOutOfOrderAppWithHandshakeKey) {
StartConnect();
TlsSendCipherSpecCapturer capturer(client_);
client_->Handshake();
@@ -662,9 +666,9 @@ TEST_F(TlsDropDatagram13, SendOutOfOrderAppWithHandshakeKey) {
auto spec = capturer.spec(0);
ASSERT_NE(nullptr, spec.get());
ASSERT_EQ(2, spec->epoch());
- ASSERT_TRUE(client_->SendEncryptedRecord(
- spec, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 0x0002000000000002,
- kTlsApplicationDataType, DataBuffer(buf, sizeof(buf))));
+ ASSERT_TRUE(client_->SendEncryptedRecord(spec, 0x0002000000000002,
+ kTlsApplicationDataType,
+ DataBuffer(buf, sizeof(buf))));
// Now have the server consume the bogus message.
server_->ExpectSendAlert(illegal_parameter, kTlsAlertFatal);
@@ -673,7 +677,7 @@ TEST_F(TlsDropDatagram13, SendOutOfOrderAppWithHandshakeKey) {
EXPECT_EQ(SSL_ERROR_RX_UNKNOWN_RECORD_TYPE, PORT_GetError());
}
-TEST_F(TlsDropDatagram13, SendOutOfOrderHsNonsenseWithHandshakeKey) {
+TEST_P(TlsDropDatagram13, SendOutOfOrderHsNonsenseWithHandshakeKey) {
StartConnect();
TlsSendCipherSpecCapturer capturer(client_);
client_->Handshake();
@@ -688,9 +692,9 @@ TEST_F(TlsDropDatagram13, SendOutOfOrderHsNonsenseWithHandshakeKey) {
auto spec = capturer.spec(0);
ASSERT_NE(nullptr, spec.get());
ASSERT_EQ(2, spec->epoch());
- ASSERT_TRUE(client_->SendEncryptedRecord(
- spec, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 0x0002000000000002,
- kTlsHandshakeType, DataBuffer(buf, sizeof(buf))));
+ ASSERT_TRUE(client_->SendEncryptedRecord(spec, 0x0002000000000002,
+ kTlsHandshakeType,
+ DataBuffer(buf, sizeof(buf))));
server_->Handshake();
EXPECT_EQ(2UL, server_filters_.ack_->count());
// The server acknowledges client Finished twice.
@@ -700,7 +704,7 @@ TEST_F(TlsDropDatagram13, SendOutOfOrderHsNonsenseWithHandshakeKey) {
// Shrink the MTU down so that certs get split and then swap the first and
// second pieces of the server certificate.
-TEST_F(TlsReorderDatagram13, ReorderServerCertificate) {
+TEST_P(TlsReorderDatagram13, ReorderServerCertificate) {
StartConnect();
ShrinkPostServerHelloMtu();
client_->Handshake();
@@ -722,7 +726,7 @@ TEST_F(TlsReorderDatagram13, ReorderServerCertificate) {
CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
}
-TEST_F(TlsReorderDatagram13, DataAfterEOEDDuringZeroRtt) {
+TEST_P(TlsReorderDatagram13, DataAfterEOEDDuringZeroRtt) {
SetupForZeroRtt();
SetFilters();
std::cerr << "Starting second handshake" << std::endl;
@@ -761,7 +765,7 @@ TEST_F(TlsReorderDatagram13, DataAfterEOEDDuringZeroRtt) {
EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
}
-TEST_F(TlsReorderDatagram13, DataAfterFinDuringZeroRtt) {
+TEST_P(TlsReorderDatagram13, DataAfterFinDuringZeroRtt) {
SetupForZeroRtt();
SetFilters();
std::cerr << "Starting second handshake" << std::endl;
@@ -812,12 +816,17 @@ static void GetCipherAndLimit(uint16_t version, uint16_t* cipher,
*cipher = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256;
*limit = (1ULL << 48) - 1;
} else {
+ // This test probably isn't especially useful for TLS 1.3, which has a much
+ // shorter sequence number encoding. That space can probably be searched in
+ // a reasonable amount of time.
*cipher = TLS_CHACHA20_POLY1305_SHA256;
- *limit = (1ULL << 48) - 1;
+ // Assume that we are starting with an expected sequence number of 0.
+ *limit = (1ULL << 29) - 1;
}
}
// This simulates a huge number of drops on one side.
+// See Bug 12965514 where a large gap was handled very inefficiently.
TEST_P(TlsConnectDatagram, MissLotsOfPackets) {
uint16_t cipher;
uint64_t limit;
@@ -834,6 +843,17 @@ TEST_P(TlsConnectDatagram, MissLotsOfPackets) {
SendReceive();
}
+// Send a sequence number of 0xfffffffd and it should be interpreted as that
+// (and not -3 or UINT64_MAX - 2).
+TEST_F(TlsConnectDatagram13, UnderflowSequenceNumber) {
+ Connect();
+ // This is only valid if short headers are disabled.
+ client_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, PR_FALSE);
+ EXPECT_EQ(SECSuccess,
+ SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), (1ULL << 30) - 3));
+ SendReceive();
+}
+
class TlsConnectDatagram12Plus : public TlsConnectDatagram {
public:
TlsConnectDatagram12Plus() : TlsConnectDatagram() {}
@@ -865,5 +885,11 @@ INSTANTIATE_TEST_CASE_P(Datagram12Plus, TlsConnectDatagram12Plus,
TlsConnectTestBase::kTlsV12Plus);
INSTANTIATE_TEST_CASE_P(DatagramPre13, TlsConnectDatagramPre13,
TlsConnectTestBase::kTlsV11V12);
+INSTANTIATE_TEST_CASE_P(DatagramDrop13, TlsDropDatagram13,
+ ::testing::Values(true, false));
+INSTANTIATE_TEST_CASE_P(DatagramReorder13, TlsReorderDatagram13,
+ ::testing::Values(true, false));
+INSTANTIATE_TEST_CASE_P(DatagramFragment13, TlsFragmentationAndRecoveryTest,
+ ::testing::Values(true, false));
} // namespace nss_test
diff --git a/gtests/ssl_gtest/ssl_fragment_unittest.cc b/gtests/ssl_gtest/ssl_fragment_unittest.cc
index f4940bf28..92947c2c7 100644
--- a/gtests/ssl_gtest/ssl_fragment_unittest.cc
+++ b/gtests/ssl_gtest/ssl_fragment_unittest.cc
@@ -20,14 +20,16 @@ namespace nss_test {
// This class cuts every unencrypted handshake record into two parts.
class RecordFragmenter : public PacketFilter {
public:
- RecordFragmenter() : sequence_number_(0), splitting_(true) {}
+ RecordFragmenter(bool is_dtls13)
+ : is_dtls13_(is_dtls13), sequence_number_(0), splitting_(true) {}
private:
class HandshakeSplitter {
public:
- HandshakeSplitter(const DataBuffer& input, DataBuffer* output,
- uint64_t* sequence_number)
- : input_(input),
+ HandshakeSplitter(bool is_dtls13, const DataBuffer& input,
+ DataBuffer* output, uint64_t* sequence_number)
+ : is_dtls13_(is_dtls13),
+ input_(input),
output_(output),
cursor_(0),
sequence_number_(sequence_number) {}
@@ -35,9 +37,9 @@ class RecordFragmenter : public PacketFilter {
private:
void WriteRecord(TlsRecordHeader& record_header,
DataBuffer& record_fragment) {
- TlsRecordHeader fragment_header(record_header.version(),
- record_header.content_type(),
- *sequence_number_);
+ TlsRecordHeader fragment_header(
+ record_header.variant(), record_header.version(),
+ record_header.content_type(), *sequence_number_);
++*sequence_number_;
if (::g_ssl_gtest_verbose) {
std::cerr << "Fragment: " << fragment_header << ' ' << record_fragment
@@ -88,7 +90,7 @@ class RecordFragmenter : public PacketFilter {
while (parser.remaining()) {
TlsRecordHeader header;
DataBuffer record;
- if (!header.Parse(0, &parser, &record)) {
+ if (!header.Parse(is_dtls13_, 0, &parser, &record)) {
ADD_FAILURE() << "bad record header";
return false;
}
@@ -118,6 +120,7 @@ class RecordFragmenter : public PacketFilter {
}
private:
+ bool is_dtls13_;
const DataBuffer& input_;
DataBuffer* output_;
size_t cursor_;
@@ -132,7 +135,7 @@ class RecordFragmenter : public PacketFilter {
}
output->Allocate(input.len());
- HandshakeSplitter splitter(input, output, &sequence_number_);
+ HandshakeSplitter splitter(is_dtls13_, input, output, &sequence_number_);
if (!splitter.Split()) {
// If splitting fails, we obviously reached encrypted packets.
// Stop splitting from that point onward.
@@ -144,18 +147,21 @@ class RecordFragmenter : public PacketFilter {
}
private:
+ bool is_dtls13_;
uint64_t sequence_number_;
bool splitting_;
};
TEST_P(TlsConnectDatagram, FragmentClientPackets) {
- client_->SetFilter(std::make_shared<RecordFragmenter>());
+ bool is_dtls13 = version_ >= SSL_LIBRARY_VERSION_TLS_1_3;
+ client_->SetFilter(std::make_shared<RecordFragmenter>(is_dtls13));
Connect();
SendReceive();
}
TEST_P(TlsConnectDatagram, FragmentServerPackets) {
- server_->SetFilter(std::make_shared<RecordFragmenter>());
+ bool is_dtls13 = version_ >= SSL_LIBRARY_VERSION_TLS_1_3;
+ server_->SetFilter(std::make_shared<RecordFragmenter>(is_dtls13));
Connect();
SendReceive();
}
diff --git a/gtests/ssl_gtest/ssl_hrr_unittest.cc b/gtests/ssl_gtest/ssl_hrr_unittest.cc
index ba4cd804d..c78b328d8 100644
--- a/gtests/ssl_gtest/ssl_hrr_unittest.cc
+++ b/gtests/ssl_gtest/ssl_hrr_unittest.cc
@@ -81,8 +81,9 @@ class CorrectMessageSeqAfterHrrFilter : public TlsRecordFilter {
}
DataBuffer buffer(record);
- TlsRecordHeader new_header = {header.version(), header.content_type(),
- header.sequence_number() + 1};
+ TlsRecordHeader new_header(header.variant(), header.version(),
+ header.content_type(),
+ header.sequence_number() + 1);
// Correct message_seq.
buffer.Write(4, 1U, 2);
diff --git a/gtests/ssl_gtest/ssl_loopback_unittest.cc b/gtests/ssl_gtest/ssl_loopback_unittest.cc
index 2c292ae27..b4a74b99e 100644
--- a/gtests/ssl_gtest/ssl_loopback_unittest.cc
+++ b/gtests/ssl_gtest/ssl_loopback_unittest.cc
@@ -383,7 +383,8 @@ class TlsPreCCSHeaderInjector : public TlsRecordFilter {
std::cerr << "Injecting Finished header before CCS\n";
const uint8_t hhdr[] = {kTlsHandshakeFinished, 0x00, 0x00, 0x0c};
DataBuffer hhdr_buf(hhdr, sizeof(hhdr));
- TlsRecordHeader nhdr(record_header.version(), kTlsHandshakeType, 0);
+ TlsRecordHeader nhdr(record_header.variant(), record_header.version(),
+ kTlsHandshakeType, 0);
*offset = nhdr.Write(output, *offset, hhdr_buf);
*offset = record_header.Write(output, *offset, input);
return CHANGE;
diff --git a/gtests/ssl_gtest/ssl_record_unittest.cc b/gtests/ssl_gtest/ssl_record_unittest.cc
index e76dab488..97bbbba01 100644
--- a/gtests/ssl_gtest/ssl_record_unittest.cc
+++ b/gtests/ssl_gtest/ssl_record_unittest.cc
@@ -168,6 +168,29 @@ TEST_F(TlsConnectStreamTls13, TooLargeRecord) {
EXPECT_EQ(SSL_ERROR_RECORD_OVERFLOW_ALERT, PORT_GetError());
}
+class ShortHeaderChecker : public PacketFilter {
+ public:
+ PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output) {
+ // The first octet should be 0b001xxxxx.
+ EXPECT_EQ(1, input.data()[0] >> 5);
+ return KEEP;
+ }
+};
+
+TEST_F(TlsConnectDatagram13, ShortHeadersClient) {
+ Connect();
+ client_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, PR_TRUE);
+ client_->SetFilter(std::make_shared<ShortHeaderChecker>());
+ SendReceive();
+}
+
+TEST_F(TlsConnectDatagram13, ShortHeadersServer) {
+ Connect();
+ server_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, PR_TRUE);
+ server_->SetFilter(std::make_shared<ShortHeaderChecker>());
+ SendReceive();
+}
+
const static size_t kContentSizesArr[] = {
1, kMacSize - 1, kMacSize, 30, 31, 32, 36, 256, 257, 287, 288};
diff --git a/gtests/ssl_gtest/tls_agent.cc b/gtests/ssl_gtest/tls_agent.cc
index dcff0df7d..084da0ab5 100644
--- a/gtests/ssl_gtest/tls_agent.cc
+++ b/gtests/ssl_gtest/tls_agent.cc
@@ -947,12 +947,13 @@ void TlsAgent::SendBuffer(const DataBuffer& buf) {
}
bool TlsAgent::SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec,
- uint16_t wireVersion, uint64_t seq,
- uint8_t ct, const DataBuffer& buf) {
- LOGV("Writing " << buf.len() << " bytes");
- // Ensure we are a TLS 1.3 cipher agent.
+ uint64_t seq, uint8_t ct,
+ const DataBuffer& buf) {
+ LOGV("Encrypting " << buf.len() << " bytes");
+ // Ensure that we are doing TLS 1.3.
EXPECT_GE(expected_version_, SSL_LIBRARY_VERSION_TLS_1_3);
- TlsRecordHeader header(wireVersion, kTlsApplicationDataType, seq);
+ TlsRecordHeader header(variant_, expected_version_, kTlsApplicationDataType,
+ seq);
DataBuffer padded = buf;
padded.Write(padded.len(), ct, 1);
DataBuffer ciphertext;
@@ -1074,15 +1075,20 @@ void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer,
void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type,
uint16_t version, const uint8_t* buf,
size_t len, DataBuffer* out,
- uint64_t seq_num) {
+ uint64_t sequence_number) {
size_t index = 0;
index = out->Write(index, type, 1);
if (variant == ssl_variant_stream) {
index = out->Write(index, version, 2);
+ } else if (version >= SSL_LIBRARY_VERSION_TLS_1_3 &&
+ type == kTlsApplicationDataType) {
+ uint32_t epoch = (sequence_number >> 48) & 0x3;
+ uint32_t seqno = sequence_number & ((1ULL << 30) - 1);
+ index = out->Write(index, (epoch << 30) | seqno, 4);
} else {
index = out->Write(index, TlsVersionToDtlsVersion(version), 2);
- index = out->Write(index, seq_num >> 32, 4);
- index = out->Write(index, seq_num & PR_UINT32_MAX, 4);
+ index = out->Write(index, sequence_number >> 32, 4);
+ index = out->Write(index, sequence_number & PR_UINT32_MAX, 4);
}
index = out->Write(index, len, 2);
out->Write(index, buf, len);
diff --git a/gtests/ssl_gtest/tls_agent.h b/gtests/ssl_gtest/tls_agent.h
index 6719f56e4..5ce5e6280 100644
--- a/gtests/ssl_gtest/tls_agent.h
+++ b/gtests/ssl_gtest/tls_agent.h
@@ -143,8 +143,7 @@ class TlsAgent : public PollTarget {
void SendData(size_t bytes, size_t blocksize = 1024);
void SendBuffer(const DataBuffer& buf);
bool SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec,
- uint16_t wireVersion, uint64_t seq, uint8_t ct,
- const DataBuffer& buf);
+ uint64_t seq, uint8_t ct, const DataBuffer& buf);
// Send data directly to the underlying socket, skipping the TLS layer.
void SendDirect(const DataBuffer& buf);
void SendRecordDirect(const TlsRecord& record);
diff --git a/gtests/ssl_gtest/tls_filter.cc b/gtests/ssl_gtest/tls_filter.cc
index e775ea6fc..10b1e31d0 100644
--- a/gtests/ssl_gtest/tls_filter.cc
+++ b/gtests/ssl_gtest/tls_filter.cc
@@ -30,11 +30,9 @@ void TlsVersioned::WriteStream(std::ostream& stream) const {
case SSL_LIBRARY_VERSION_TLS_1_0:
stream << "1.0";
break;
- case SSL_LIBRARY_VERSION_DTLS_1_0_WIRE:
case SSL_LIBRARY_VERSION_TLS_1_1:
stream << (is_dtls() ? "1.0" : "1.1");
break;
- case SSL_LIBRARY_VERSION_DTLS_1_2_WIRE:
case SSL_LIBRARY_VERSION_TLS_1_2:
stream << "1.2";
break;
@@ -67,8 +65,14 @@ void TlsRecordFilter::CipherSpecChanged(void* arg, PRBool sending,
return;
}
- self->in_sequence_number_ = 0;
- self->out_sequence_number_ = 0;
+ uint64_t seq_no;
+ if (self->agent()->variant() == ssl_variant_datagram) {
+ seq_no = static_cast<uint64_t>(SSLInt_CipherSpecToEpoch(newSpec)) << 48;
+ } else {
+ seq_no = 0;
+ }
+ self->in_sequence_number_ = seq_no;
+ self->out_sequence_number_ = seq_no;
self->dropped_record_ = false;
self->cipher_spec_.reset(new TlsCipherSpec());
bool ret = self->cipher_spec_->Init(
@@ -77,33 +81,59 @@ void TlsRecordFilter::CipherSpecChanged(void* arg, PRBool sending,
EXPECT_EQ(true, ret);
}
+bool TlsRecordFilter::is_dtls13() const {
+ if (agent()->variant() != ssl_variant_datagram) {
+ return false;
+ }
+ if (agent()->state() == TlsAgent::STATE_CONNECTED) {
+ return agent()->version() >= SSL_LIBRARY_VERSION_TLS_1_3;
+ }
+ SSLPreliminaryChannelInfo info;
+ EXPECT_EQ(SECSuccess, SSL_GetPreliminaryChannelInfo(agent()->ssl_fd(), &info,
+ sizeof(info)));
+ return (info.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3) ||
+ info.canSendEarlyData;
+}
+
PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input,
DataBuffer* output) {
+ // Disable during shutdown.
+ if (!agent()) {
+ return KEEP;
+ }
+
bool changed = false;
size_t offset = 0U;
- output->Allocate(input.len());
+ output->Allocate(input.len());
TlsParser parser(input);
while (parser.remaining()) {
TlsRecordHeader header;
DataBuffer record;
- if (!header.Parse(in_sequence_number_, &parser, &record)) {
+ if (!header.Parse(is_dtls13(), in_sequence_number_, &parser, &record)) {
ADD_FAILURE() << "not a valid record";
return KEEP;
}
- // Track the sequence number, which is necessary for stream mode (the
- // sequence number is in the header for datagram).
+ // Track the sequence number, which is necessary for stream mode when
+ // decrypting and for TLS 1.3 datagram to recover the sequence number.
//
- // This isn't perfectly robust. If there is a change from an active cipher
+ // We reset the counter when the cipher spec changes, but that notification
+ // appears before a record is sent. If multiple records are sent with
+ // different cipher specs, this would fail. This filters out cleartext
+ // records, so we don't get confused by handshake messages that are sent at
+ // the same time as encrypted records. Sequence numbers are therefore
+ // likely to be incorrect for cleartext records.
+ //
+ // This isn't perfectly robust: if there is a change from an active cipher
// spec to another active cipher spec (KeyUpdate for instance) AND writes
- // are consolidated across that change AND packets were dropped from the
- // older epoch, we will not correctly re-encrypt records in the old epoch to
- // update their sequence numbers.
- if (cipher_spec_ && header.content_type() == kTlsApplicationDataType) {
- ++in_sequence_number_;
+ // are consolidated across that change, this code could use the wrong
+ // sequence numbers when re-encrypting records with the old keys.
+ if (header.content_type() == kTlsApplicationDataType) {
+ in_sequence_number_ =
+ (std::max)(in_sequence_number_, header.sequence_number() + 1);
}
if (FilterRecord(header, record, &offset, output) != KEEP) {
@@ -131,11 +161,14 @@ PacketFilter::Action TlsRecordFilter::FilterRecord(
DataBuffer plaintext;
if (!Unprotect(header, record, &inner_content_type, &plaintext)) {
+ if (g_ssl_gtest_verbose) {
+ std::cerr << "unprotect failed: " << header << ":" << record << std::endl;
+ }
return KEEP;
}
- TlsRecordHeader real_header = {header.version(), inner_content_type,
- header.sequence_number()};
+ TlsRecordHeader real_header(header.variant(), header.version(),
+ inner_content_type, header.sequence_number());
PacketFilter::Action action = FilterRecord(real_header, plaintext, &filtered);
// In stream mode, even if something doesn't change we need to re-encrypt if
@@ -166,8 +199,8 @@ PacketFilter::Action TlsRecordFilter::FilterRecord(
} else {
seq_num = out_sequence_number_++;
}
- TlsRecordHeader out_header = {header.version(), header.content_type(),
- seq_num};
+ TlsRecordHeader out_header(header.variant(), header.version(),
+ header.content_type(), seq_num);
DataBuffer ciphertext;
bool rv = Protect(out_header, inner_content_type, filtered, &ciphertext);
@@ -179,20 +212,109 @@ PacketFilter::Action TlsRecordFilter::FilterRecord(
return CHANGE;
}
-bool TlsRecordHeader::Parse(uint64_t seqno, TlsParser* parser,
+size_t TlsRecordHeader::header_length() const {
+ if (!is_dtls()) {
+ return 5;
+ }
+ if (version() >= SSL_LIBRARY_VERSION_TLS_1_3 &&
+ content_type_ == kTlsApplicationDataType) {
+ return 7;
+ }
+ return 13;
+}
+
+uint64_t TlsRecordHeader::RecoverSequenceNumber(uint64_t expected,
+ uint32_t partial,
+ size_t partial_bits) {
+ EXPECT_GE(32U, partial_bits);
+ uint64_t mask = (1 << partial_bits) - 1;
+ // First we determine the highest possible value. This is half the
+ // expressible range above the expected value.
+ uint64_t cap = expected + (1ULL << (partial_bits - 1));
+ // Add the partial piece in. e.g., xxxx789a and 1234 becomes xxxx1234.
+ uint64_t seq_no = (cap & ~mask) | partial;
+ // If the partial value is higher than the same partial piece from the cap,
+ // then the real value has to be lower. e.g., xxxx1234 can't become xxxx5678.
+ if (partial > (cap & mask)) {
+ seq_no -= 1ULL << partial_bits;
+ }
+ return seq_no;
+}
+
+// Determine the full epoch and sequence number from an expected and raw value.
+// The expected and output values are packed as they are in DTLS 1.2 and
+// earlier: with 16 bits of epoch and 48 bits of sequence number.
+uint64_t TlsRecordHeader::ParseSequenceNumber(uint64_t expected, uint32_t raw,
+ size_t seq_no_bits,
+ size_t epoch_bits) {
+ uint64_t epoch_mask = (1ULL << epoch_bits) - 1;
+ uint64_t epoch = RecoverSequenceNumber(
+ expected >> 48, (raw >> seq_no_bits) & epoch_mask, epoch_bits);
+ if (epoch > (expected >> 48)) {
+ // If the epoch has changed, reset the expected sequence number.
+ expected = 0;
+ } else {
+ // Otherwise, retain just the sequence number part.
+ expected &= (1ULL << 48) - 1;
+ }
+ uint64_t seq_no_mask = (1ULL << seq_no_bits) - 1;
+ uint64_t seq_no =
+ RecoverSequenceNumber(expected, raw & seq_no_mask, seq_no_bits);
+ return (epoch << 48) | seq_no;
+}
+
+bool TlsRecordHeader::Parse(bool is_dtls13, uint64_t seqno, TlsParser* parser,
DataBuffer* body) {
if (!parser->Read(&content_type_)) {
return false;
}
+ if (is_dtls13) {
+ variant_ = ssl_variant_datagram;
+ version_ = SSL_LIBRARY_VERSION_TLS_1_3;
+
+#ifndef UNSAFE_FUZZER_MODE
+ // Deal with the 7 octet header.
+ if (content_type_ == kTlsApplicationDataType) {
+ uint32_t tmp;
+ if (!parser->Read(&tmp, 4)) {
+ return false;
+ }
+ sequence_number_ = ParseSequenceNumber(seqno, tmp, 30, 2);
+ return parser->ReadVariable(body, 2);
+ }
+
+ // The short, 2 octet header.
+ if ((content_type_ & 0xe0) == 0x20) {
+ uint32_t tmp;
+ if (!parser->Read(&tmp, 1)) {
+ return false;
+ }
+ // Need to use the low 5 bits of the first octet too.
+ tmp |= (content_type_ & 0x1f) << 8;
+ content_type_ = kTlsApplicationDataType;
+ sequence_number_ = ParseSequenceNumber(seqno, tmp, 12, 1);
+ return parser->Read(body, parser->remaining());
+ }
+
+ // The full 13 octet header can only be used for a few types.
+ EXPECT_TRUE(content_type_ == kTlsAlertType ||
+ content_type_ == kTlsHandshakeType ||
+ content_type_ == kTlsAckType);
+#endif
+ }
+
uint32_t ver;
if (!parser->Read(&ver, 2)) {
return false;
}
- version_ = ver;
+ if (!is_dtls13) {
+ variant_ = IsDtls(ver) ? ssl_variant_datagram : ssl_variant_stream;
+ }
+ version_ = NormalizeTlsVersion(ver);
- // If this is DTLS, overwrite the sequence number.
- if (IsDtls(ver)) {
+ if (is_dtls()) {
+ // If this is DTLS, read the sequence number.
uint32_t tmp;
if (!parser->Read(&tmp, 4)) {
return false;
@@ -211,11 +333,21 @@ bool TlsRecordHeader::Parse(uint64_t seqno, TlsParser* parser,
size_t TlsRecordHeader::Write(DataBuffer* buffer, size_t offset,
const DataBuffer& body) const {
offset = buffer->Write(offset, content_type_, 1);
- offset = buffer->Write(offset, version_, 2);
- if (is_dtls()) {
- // write epoch (2 octet), and seqnum (6 octet)
- offset = buffer->Write(offset, sequence_number_ >> 32, 4);
- offset = buffer->Write(offset, sequence_number_ & 0xffffffff, 4);
+ if (is_dtls() && version_ >= SSL_LIBRARY_VERSION_TLS_1_3 &&
+ content_type() == kTlsApplicationDataType) {
+ // application_data records in TLS 1.3 have a different header format.
+ // Always use the long header here for simplicity.
+ uint32_t e = (sequence_number_ >> 48) & 0x3;
+ uint32_t seqno = sequence_number_ & ((1ULL << 30) - 1);
+ offset = buffer->Write(offset, (e << 30) | seqno, 4);
+ } else {
+ uint16_t v = is_dtls() ? TlsVersionToDtlsVersion(version_) : version_;
+ offset = buffer->Write(offset, v, 2);
+ if (is_dtls()) {
+ // write epoch (2 octet), and seqnum (6 octet)
+ offset = buffer->Write(offset, sequence_number_ >> 32, 4);
+ offset = buffer->Write(offset, sequence_number_ & 0xffffffff, 4);
+ }
}
offset = buffer->Write(offset, body.len(), 2);
offset = buffer->Write(offset, body);
@@ -406,6 +538,7 @@ bool TlsHandshakeFilter::HandshakeHeader::Parse(
const DataBuffer& preceding_fragment, DataBuffer* body, bool* complete) {
*complete = false;
+ variant_ = record_header.variant();
version_ = record_header.version();
if (!parser->Read(&handshake_type_)) {
return false; // malformed
diff --git a/gtests/ssl_gtest/tls_filter.h b/gtests/ssl_gtest/tls_filter.h
index 5d1321f8f..296c0a9ee 100644
--- a/gtests/ssl_gtest/tls_filter.h
+++ b/gtests/ssl_gtest/tls_filter.h
@@ -11,7 +11,7 @@
#include <memory>
#include <set>
#include <vector>
-
+#include "sslt.h"
#include "test_io.h"
#include "tls_agent.h"
#include "tls_parser.h"
@@ -27,38 +27,47 @@ class TlsCipherSpec;
class TlsVersioned {
public:
- TlsVersioned() : version_(0) {}
- explicit TlsVersioned(uint16_t v) : version_(v) {}
+ TlsVersioned() : variant_(ssl_variant_stream), version_(0) {}
+ TlsVersioned(SSLProtocolVariant var, uint16_t ver)
+ : variant_(var), version_(ver) {}
- bool is_dtls() const { return IsDtls(version_); }
+ bool is_dtls() const { return variant_ == ssl_variant_datagram; }
+ SSLProtocolVariant variant() const { return variant_; }
uint16_t version() const { return version_; }
void WriteStream(std::ostream& stream) const;
protected:
+ SSLProtocolVariant variant_;
uint16_t version_;
};
class TlsRecordHeader : public TlsVersioned {
public:
TlsRecordHeader() : TlsVersioned(), content_type_(0), sequence_number_(0) {}
- TlsRecordHeader(uint16_t ver, uint8_t ct, uint64_t seqno)
- : TlsVersioned(ver), content_type_(ct), sequence_number_(seqno) {}
+ TlsRecordHeader(SSLProtocolVariant var, uint16_t ver, uint8_t ct,
+ uint64_t seqno)
+ : TlsVersioned(var, ver), content_type_(ct), sequence_number_(seqno) {}
uint8_t content_type() const { return content_type_; }
uint64_t sequence_number() const { return sequence_number_; }
uint16_t epoch() const {
return static_cast<uint16_t>(sequence_number_ >> 48);
}
- size_t header_length() const { return is_dtls() ? 13 : 5; }
-
+ size_t header_length() const;
// Parse the header; return true if successful; body in an outparam if OK.
- bool Parse(uint64_t sequence_number, TlsParser* parser, DataBuffer* body);
+ bool Parse(bool is_dtls13, uint64_t sequence_number, TlsParser* parser,
+ DataBuffer* body);
// Write the header and body to a buffer at the given offset.
// Return the offset of the end of the write.
size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const;
private:
+ static uint64_t RecoverSequenceNumber(uint64_t expected, uint32_t partial,
+ size_t partial_bits);
+ static uint64_t ParseSequenceNumber(uint64_t expected, uint32_t raw,
+ size_t seq_no_bits, size_t epoch_bits);
+
uint8_t content_type_;
uint64_t sequence_number_;
};
@@ -127,6 +136,8 @@ class TlsRecordFilter : public PacketFilter {
return KEEP;
}
+ bool is_dtls13() const;
+
private:
static void CipherSpecChanged(void* arg, PRBool sending,
ssl3CipherSpec* newSpec);
diff --git a/lib/ssl/dtls13con.c b/lib/ssl/dtls13con.c
index aba0f62ab..04cbd373c 100644
--- a/lib/ssl/dtls13con.c
+++ b/lib/ssl/dtls13con.c
@@ -11,6 +11,43 @@
#include "sslimpl.h"
#include "sslproto.h"
+SECStatus
+dtls13_InsertCipherTextHeader(const sslSocket *ss, ssl3CipherSpec *cwSpec,
+ sslBuffer *wrBuf, PRBool *needsLength)
+{
+ PRUint32 seq;
+ SECStatus rv;
+
+ /* Avoid using short records for the handshake. We pack multiple records
+ * into the one datagram for the handshake. */
+ if (ss->opt.enableDtlsShortHeader &&
+ cwSpec->epoch != TrafficKeyHandshake) {
+ *needsLength = PR_FALSE;
+ /* The short header is comprised of two octets in the form
+ * 0b001essssssssssss where 'e' is the low bit of the epoch and 's' is
+ * the low 12 bits of the sequence number. */
+ seq = 0x2000 |
+ (((uint64_t)cwSpec->epoch & 1) << 12) |
+ (cwSpec->nextSeqNum & 0xfff);
+ return sslBuffer_AppendNumber(wrBuf, seq, 2);
+ }
+
+ rv = sslBuffer_AppendNumber(wrBuf, content_application_data, 1);
+ if (rv != SECSuccess) {
+ return SECFailure;
+ }
+
+ /* The epoch and sequence number are encoded on 4 octets, with the epoch
+ * consuming the first two bits. */
+ seq = (((uint64_t)cwSpec->epoch & 3) << 30) | (cwSpec->nextSeqNum & 0x3fffffff);
+ rv = sslBuffer_AppendNumber(wrBuf, seq, 4);
+ if (rv != SECSuccess) {
+ return SECFailure;
+ }
+ *needsLength = PR_TRUE;
+ return SECSuccess;
+}
+
/* DTLS 1.3 Record map for ACK processing.
* This represents a single fragment, so a record which includes
* multiple fragments will have one entry for each fragment on the
diff --git a/lib/ssl/dtls13con.h b/lib/ssl/dtls13con.h
index bf14d3bd2..ca48ef363 100644
--- a/lib/ssl/dtls13con.h
+++ b/lib/ssl/dtls13con.h
@@ -9,6 +9,10 @@
#ifndef __dtls13con_h_
#define __dtls13con_h_
+SECStatus dtls13_InsertCipherTextHeader(const sslSocket *ss,
+ ssl3CipherSpec *cwSpec,
+ sslBuffer *wrBuf,
+ PRBool *needsLength);
SECStatus dtls13_RememberFragment(sslSocket *ss, PRCList *list,
PRUint32 sequence, PRUint32 offset,
PRUint32 length, DTLSEpoch epoch,
diff --git a/lib/ssl/dtlscon.c b/lib/ssl/dtlscon.c
index 2f335f924..6c3d7d24c 100644
--- a/lib/ssl/dtlscon.c
+++ b/lib/ssl/dtlscon.c
@@ -776,7 +776,7 @@ dtls_FragmentHandshake(sslSocket *ss, DTLSQueuedMessage *msg)
rv = dtls13_RememberFragment(ss, &ss->ssl3.hs.dtlsSentHandshake,
msgSeq, fragmentOffset, fragmentLen,
msg->cwSpec->epoch,
- msg->cwSpec->seqNum);
+ msg->cwSpec->nextSeqNum);
if (rv != SECSuccess) {
return SECFailure;
}
@@ -1319,6 +1319,107 @@ DTLS_GetHandshakeTimeout(PRFileDesc *socket, PRIntervalTime *timeout)
return SECSuccess;
}
+PRBool
+dtls_IsLongHeader(SSL3ProtocolVersion version, PRUint8 firstOctet)
+{
+#ifndef UNSAFE_FUZZER_MODE
+ return version < SSL_LIBRARY_VERSION_TLS_1_3 ||
+ firstOctet == content_handshake ||
+ firstOctet == content_ack ||
+ firstOctet == content_alert;
+#else
+ return PR_TRUE;
+#endif
+}
+
+DTLSEpoch
+dtls_ReadEpoch(const ssl3CipherSpec *crSpec, const PRUint8 *hdr)
+{
+ DTLSEpoch epoch;
+ DTLSEpoch maxEpoch;
+ DTLSEpoch partial;
+
+ if (dtls_IsLongHeader(crSpec->version, hdr[0])) {
+ return ((DTLSEpoch)hdr[3] << 8) | hdr[4];
+ }
+
+ /* A lot of how we recover the epoch here will depend on how we plan to
+ * manage KeyUpdate. In the case that we decide to install a new read spec
+ * as a KeyUpdate is handled, crSpec will always be the highest epoch we can
+ * possibly receive. That makes this easier to manage. */
+ if ((hdr[0] & 0xe0) == 0x20) {
+ /* Use crSpec->epoch, or crSpec->epoch - 1 if the last bit differs. */
+ if (((hdr[0] >> 4) & 1) == (crSpec->epoch & 1)) {
+ return crSpec->epoch;
+ }
+ return crSpec->epoch - 1;
+ }
+
+ /* dtls_GatherData should ensure that this works. */
+ PORT_Assert(hdr[0] == content_application_data);
+
+ /* This uses the same method as is used to recover the sequence number in
+ * dtls_ReadSequenceNumber, except that the maximum value is set to the
+ * current epoch. */
+ partial = hdr[1] >> 6;
+ maxEpoch = PR_MAX(crSpec->epoch, 3);
+ epoch = (maxEpoch & 0xfffc) | partial;
+ if (partial > (maxEpoch & 0x03)) {
+ epoch -= 4;
+ }
+ return epoch;
+}
+
+static sslSequenceNumber
+dtls_ReadSequenceNumber(const ssl3CipherSpec *spec, const PRUint8 *hdr)
+{
+ sslSequenceNumber cap;
+ sslSequenceNumber partial;
+ sslSequenceNumber seqNum;
+ sslSequenceNumber mask;
+
+ if (dtls_IsLongHeader(spec->version, hdr[0])) {
+ static const unsigned int seqNumOffset = 5; /* type, version, epoch */
+ static const unsigned int seqNumLength = 6;
+ sslReader r = SSL_READER(hdr + seqNumOffset, seqNumLength);
+ (void)sslRead_ReadNumber(&r, seqNumLength, &seqNum);
+ return seqNum;
+ }
+
+ /* Only the least significant bits of the sequence number is available here.
+ * This recovers the value based on the next expected sequence number.
+ *
+ * This works by determining the maximum possible sequence number, which is
+ * half the range of possible values above the expected next value (the
+ * expected next value is in |spec->seqNum|). Then, the last part of the
+ * sequence number is replaced. If that causes the value to exceed the
+ * maximum, subtract an entire range.
+ */
+ if ((hdr[0] & 0xe0) == 0x20) {
+ /* A 12-bit sequence number. */
+ cap = spec->nextSeqNum + (1ULL << 11);
+ partial = (((sslSequenceNumber)hdr[0] & 0xf) << 8) |
+ (sslSequenceNumber)hdr[1];
+ mask = (1ULL << 12) - 1;
+ } else {
+ /* A 30-bit sequence number. */
+ cap = spec->nextSeqNum + (1ULL << 29);
+ partial = (((sslSequenceNumber)hdr[1] & 0x3f) << 24) |
+ ((sslSequenceNumber)hdr[2] << 16) |
+ ((sslSequenceNumber)hdr[3] << 8) |
+ (sslSequenceNumber)hdr[4];
+ mask = (1ULL << 30) - 1;
+ }
+ seqNum = (cap & ~mask) | partial;
+ /* The second check prevents the value from underflowing if we get a large
+ * gap at the start of a connection, where this subtraction would cause the
+ * sequence number to wrap to near UINT64_MAX. */
+ if ((partial > (cap & mask)) && (seqNum > mask)) {
+ seqNum -= mask + 1;
+ }
+ return seqNum;
+}
+
/*
* DTLS relevance checks:
* Note that this code currently ignores all out-of-epoch packets,
@@ -1336,7 +1437,7 @@ dtls_IsRelevant(sslSocket *ss, const ssl3CipherSpec *spec,
const SSL3Ciphertext *cText,
sslSequenceNumber *seqNumOut)
{
- sslSequenceNumber seqNum = cText->seq_num & RECORD_SEQ_MASK;
+ sslSequenceNumber seqNum = dtls_ReadSequenceNumber(spec, cText->hdr);
if (dtls_RecordGetRecvd(&spec->recvdRecords, seqNum) != 0) {
SSL_TRC(10, ("%d: SSL3[%d]: dtls_IsRelevant, rejecting "
"potentially replayed packet",
diff --git a/lib/ssl/dtlscon.h b/lib/ssl/dtlscon.h
index d094380f8..45fc069b9 100644
--- a/lib/ssl/dtlscon.h
+++ b/lib/ssl/dtlscon.h
@@ -41,8 +41,10 @@ extern SSL3ProtocolVersion
dtls_TLSVersionToDTLSVersion(SSL3ProtocolVersion tlsv);
extern SSL3ProtocolVersion
dtls_DTLSVersionToTLSVersion(SSL3ProtocolVersion dtlsv);
+DTLSEpoch dtls_ReadEpoch(const ssl3CipherSpec *crSpec, const PRUint8 *hdr);
extern PRBool dtls_IsRelevant(sslSocket *ss, const ssl3CipherSpec *spec,
const SSL3Ciphertext *cText,
sslSequenceNumber *seqNum);
void dtls_ReceivedFirstMessageInFlight(sslSocket *ss);
+PRBool dtls_IsLongHeader(SSL3ProtocolVersion version, PRUint8 firstOctet);
#endif
diff --git a/lib/ssl/ssl.h b/lib/ssl/ssl.h
index 25aabbaa2..ad8ec0f8b 100644
--- a/lib/ssl/ssl.h
+++ b/lib/ssl/ssl.h
@@ -254,6 +254,17 @@ SSL_IMPORT PRFileDesc *DTLS_ImportFD(PRFileDesc *model, PRFileDesc *fd);
* no effect for a server. This setting is ignored for DTLS. */
#define SSL_ENABLE_TLS13_COMPAT_MODE 35
+/* Enables the sending of DTLS records using the short (two octet) record
+ * header. Only do this if there are 2^10 or fewer packets in flight at a time;
+ * using this with a larger number of packets in flight could mean that packets
+ * are dropped if there is reordering.
+ *
+ * This applies to TLS 1.3 only. This is not a parameter that is negotiated
+ * during the TLS handshake. Unlike other socket options, this option can be
+ * changed after a handshake is complete.
+ */
+#define SSL_ENABLE_DTLS_SHORT_HEADER 36
+
#ifdef SSL_DEPRECATED_FUNCTION
/* Old deprecated function names */
SSL_IMPORT SECStatus SSL_Enable(PRFileDesc *fd, int option, PRIntn on);
diff --git a/lib/ssl/ssl3con.c b/lib/ssl/ssl3con.c
index df9d8cb6c..22fdaf5b1 100644
--- a/lib/ssl/ssl3con.c
+++ b/lib/ssl/ssl3con.c
@@ -1415,7 +1415,7 @@ ssl3_SetupPendingCipherSpec(sslSocket *ss, CipherSpecDirection direction,
spec->macDef = ssl_GetMacDef(ss, suiteDef);
spec->epoch = prev->epoch + 1;
- spec->seqNum = 0;
+ spec->nextSeqNum = 0;
if (IS_DTLS(ss) && direction == CipherSpecRead) {
dtls_InitRecvdRecords(&spec->recvdRecords);
}
@@ -2004,6 +2004,7 @@ ssl3_MACEncryptRecord(ssl3CipherSpec *cwSpec,
unsigned int ivLen = 0;
unsigned char pseudoHeaderBuf[13];
sslBuffer pseudoHeader = SSL_BUFFER(pseudoHeaderBuf);
+ int len;
if (cwSpec->cipherDef->type == type_block &&
cwSpec->version >= SSL_LIBRARY_VERSION_TLS_1_1) {
@@ -2013,29 +2014,32 @@ ssl3_MACEncryptRecord(ssl3CipherSpec *cwSpec,
* record.
*/
ivLen = cwSpec->cipherDef->iv_size;
- if (ivLen > wrBuf->space) {
+ if (ivLen > SSL_BUFFER_SPACE(wrBuf)) {
PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
return SECFailure;
}
- rv = PK11_GenerateRandom(wrBuf->buf, ivLen);
+ rv = PK11_GenerateRandom(SSL_BUFFER_NEXT(wrBuf), ivLen);
if (rv != SECSuccess) {
ssl_MapLowLevelError(SSL_ERROR_GENERATE_RANDOM_FAILURE);
return rv;
}
rv = cwSpec->cipher(cwSpec->cipherContext,
- wrBuf->buf, /* output */
- (int *)&wrBuf->len, /* outlen */
- ivLen, /* max outlen */
- wrBuf->buf, /* input */
- ivLen); /* input len */
- if (rv != SECSuccess || wrBuf->len != ivLen) {
+ SSL_BUFFER_NEXT(wrBuf), /* output */
+ &len, /* outlen */
+ ivLen, /* max outlen */
+ SSL_BUFFER_NEXT(wrBuf), /* input */
+ ivLen); /* input len */
+ if (rv != SECSuccess || len != ivLen) {
PORT_SetError(SSL_ERROR_ENCRYPTION_FAILURE);
return SECFailure;
}
+
+ rv = sslBuffer_Skip(wrBuf, len, NULL);
+ PORT_Assert(rv == SECSuccess); /* Can't fail. */
}
rv = ssl3_BuildRecordPseudoHeader(
- cwSpec->epoch, cwSpec->seqNum, type,
+ cwSpec->epoch, cwSpec->nextSeqNum, type,
cwSpec->version >= SSL_LIBRARY_VERSION_TLS_1_0, cwSpec->recordVersion,
isDTLS, contentLen, &pseudoHeader);
PORT_Assert(rv == SECSuccess);
@@ -2043,23 +2047,26 @@ ssl3_MACEncryptRecord(ssl3CipherSpec *cwSpec,
const int nonceLen = cwSpec->cipherDef->explicit_nonce_size;
const int tagLen = cwSpec->cipherDef->tag_size;
- if (nonceLen + contentLen + tagLen > wrBuf->space) {
+ if (nonceLen + contentLen + tagLen > SSL_BUFFER_SPACE(wrBuf)) {
PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
return SECFailure;
}
rv = cwSpec->aead(
&cwSpec->keyMaterial,
- PR_FALSE, /* do encrypt */
- wrBuf->buf, /* output */
- (int *)&wrBuf->len, /* out len */
- wrBuf->space, /* max out */
- pIn, contentLen, /* input */
+ PR_FALSE, /* do encrypt */
+ SSL_BUFFER_NEXT(wrBuf), /* output */
+ &len, /* out len */
+ SSL_BUFFER_SPACE(wrBuf), /* max out */
+ pIn, contentLen, /* input */
SSL_BUFFER_BASE(&pseudoHeader), SSL_BUFFER_LEN(&pseudoHeader));
if (rv != SECSuccess) {
PORT_SetError(SSL_ERROR_ENCRYPTION_FAILURE);
return SECFailure;
}
+
+ rv = sslBuffer_Skip(wrBuf, len, NULL);
+ PORT_Assert(rv == SECSuccess); /* Can't fail. */
} else {
int blockSize = cwSpec->cipherDef->block_size;
@@ -2069,7 +2076,7 @@ ssl3_MACEncryptRecord(ssl3CipherSpec *cwSpec,
rv = ssl3_ComputeRecordMAC(cwSpec, SSL_BUFFER_BASE(&pseudoHeader),
SSL_BUFFER_LEN(&pseudoHeader),
pIn, contentLen,
- wrBuf->buf + ivLen + contentLen, &macLen);
+ SSL_BUFFER_NEXT(wrBuf) + contentLen, &macLen);
if (rv != SECSuccess) {
ssl_MapLowLevelError(SSL_ERROR_MAC_COMPUTATION_FAILURE);
return SECFailure;
@@ -2095,7 +2102,7 @@ ssl3_MACEncryptRecord(ssl3CipherSpec *cwSpec,
PORT_Assert((fragLen % blockSize) == 0);
/* Pad according to TLS rules (also acceptable to SSL3). */
- pBuf = &wrBuf->buf[ivLen + fragLen - 1];
+ pBuf = SSL_BUFFER_NEXT(wrBuf) + fragLen - 1;
for (i = padding_length + 1; i > 0; --i) {
*pBuf-- = padding_length;
}
@@ -2112,14 +2119,14 @@ ssl3_MACEncryptRecord(ssl3CipherSpec *cwSpec,
p2Len += oddLen;
PORT_Assert((blockSize < 2) ||
(p2Len % blockSize) == 0);
- memmove(wrBuf->buf + ivLen + p1Len, pIn + p1Len, oddLen);
+ memmove(SSL_BUFFER_NEXT(wrBuf) + p1Len, pIn + p1Len, oddLen);
}
if (p1Len > 0) {
int cipherBytesPart1 = -1;
rv = cwSpec->cipher(cwSpec->cipherContext,
- wrBuf->buf + ivLen, /* output */
- &cipherBytesPart1, /* actual outlen */
- p1Len, /* max outlen */
+ SSL_BUFFER_NEXT(wrBuf), /* output */
+ &cipherBytesPart1, /* actual outlen */
+ p1Len, /* max outlen */
pIn,
p1Len); /* input, and inputlen */
PORT_Assert(rv == SECSuccess && cipherBytesPart1 == (int)p1Len);
@@ -2127,22 +2134,24 @@ ssl3_MACEncryptRecord(ssl3CipherSpec *cwSpec,
PORT_SetError(SSL_ERROR_ENCRYPTION_FAILURE);
return SECFailure;
}
- wrBuf->len += cipherBytesPart1;
+ rv = sslBuffer_Skip(wrBuf, p1Len, NULL);
+ PORT_Assert(rv == SECSuccess);
}
if (p2Len > 0) {
int cipherBytesPart2 = -1;
rv = cwSpec->cipher(cwSpec->cipherContext,
- wrBuf->buf + ivLen + p1Len,
+ SSL_BUFFER_NEXT(wrBuf),
&cipherBytesPart2, /* output and actual outLen */
p2Len, /* max outlen */
- wrBuf->buf + ivLen + p1Len,
+ SSL_BUFFER_NEXT(wrBuf),
p2Len); /* input and inputLen*/
PORT_Assert(rv == SECSuccess && cipherBytesPart2 == (int)p2Len);
if (rv != SECSuccess || cipherBytesPart2 != (int)p2Len) {
PORT_SetError(SSL_ERROR_ENCRYPTION_FAILURE);
return SECFailure;
}
- wrBuf->len += cipherBytesPart2;
+ rv = sslBuffer_Skip(wrBuf, p2Len, NULL);
+ PORT_Assert(rv == SECSuccess);
}
}
@@ -2152,14 +2161,18 @@ ssl3_MACEncryptRecord(ssl3CipherSpec *cwSpec,
/* Note: though this can report failure, it shouldn't. */
static SECStatus
ssl_InsertRecordHeader(const sslSocket *ss, ssl3CipherSpec *cwSpec,
- SSL3ContentType contentType, unsigned int len,
- sslBuffer *wrBuf)
+ SSL3ContentType contentType, sslBuffer *wrBuf,
+ PRBool *needsLength)
{
SECStatus rv;
#ifndef UNSAFE_FUZZER_MODE
if (cwSpec->version >= SSL_LIBRARY_VERSION_TLS_1_3 &&
- cwSpec->cipherDef->calg != ssl_calg_null) {
+ cwSpec->epoch > TrafficKeyClearText) {
+ if (IS_DTLS(ss)) {
+ return dtls13_InsertCipherTextHeader(ss, cwSpec, wrBuf,
+ needsLength);
+ }
contentType = content_application_data;
}
#endif
@@ -2177,16 +2190,12 @@ ssl_InsertRecordHeader(const sslSocket *ss, ssl3CipherSpec *cwSpec,
if (rv != SECSuccess) {
return SECFailure;
}
- rv = sslBuffer_AppendNumber(wrBuf, cwSpec->seqNum, 6);
+ rv = sslBuffer_AppendNumber(wrBuf, cwSpec->nextSeqNum, 6);
if (rv != SECSuccess) {
return SECFailure;
}
}
- rv = sslBuffer_AppendNumber(wrBuf, len, 2);
- if (rv != SECSuccess) {
- return SECFailure;
- }
-
+ *needsLength = PR_TRUE;
return SECSuccess;
}
@@ -2194,66 +2203,67 @@ SECStatus
ssl_ProtectRecord(sslSocket *ss, ssl3CipherSpec *cwSpec, SSL3ContentType type,
const PRUint8 *pIn, PRUint32 contentLen, sslBuffer *wrBuf)
{
- unsigned int headerLen = IS_DTLS(ss) ? DTLS_RECORD_HEADER_LENGTH
- : SSL3_RECORD_HEADER_LENGTH;
- sslBuffer protBuf = SSL_BUFFER_FIXED(SSL_BUFFER_BASE(wrBuf) + headerLen,
- SSL_BUFFER_SPACE(wrBuf) - headerLen);
- PRBool isTLS13;
+ PRBool needsLength;
+ unsigned int lenOffset;
SECStatus rv;
PORT_Assert(cwSpec->direction == CipherSpecWrite);
PORT_Assert(SSL_BUFFER_LEN(wrBuf) == 0);
PORT_Assert(cwSpec->cipherDef->max_records <= RECORD_SEQ_MAX);
- if (cwSpec->seqNum >= cwSpec->cipherDef->max_records) {
+
+ if (cwSpec->nextSeqNum >= cwSpec->cipherDef->max_records) {
/* We should have automatically updated before here in TLS 1.3. */
PORT_Assert(cwSpec->version < SSL_LIBRARY_VERSION_TLS_1_3);
SSL_TRC(3, ("%d: SSL[-]: write sequence number at limit 0x%0llx",
- SSL_GETPID(), cwSpec->seqNum));
+ SSL_GETPID(), cwSpec->nextSeqNum));
PORT_SetError(SSL_ERROR_TOO_MANY_RECORDS);
return SECFailure;
}
- isTLS13 = (PRBool)(cwSpec->version >= SSL_LIBRARY_VERSION_TLS_1_3);
+ rv = ssl_InsertRecordHeader(ss, cwSpec, type, wrBuf, &needsLength);
+ if (rv != SECSuccess) {
+ return SECFailure;
+ }
+ if (needsLength) {
+ rv = sslBuffer_Skip(wrBuf, 2, &lenOffset);
+ if (rv != SECSuccess) {
+ return SECFailure;
+ }
+ }
#ifdef UNSAFE_FUZZER_MODE
{
int len;
- rv = Null_Cipher(NULL, SSL_BUFFER_BASE(&protBuf), &len,
- SSL_BUFFER_SPACE(&protBuf), pIn, contentLen);
+ rv = Null_Cipher(NULL, SSL_BUFFER_NEXT(wrBuf), &len,
+ SSL_BUFFER_SPACE(wrBuf), pIn, contentLen);
if (rv != SECSuccess) {
return SECFailure; /* error was set */
}
- rv = sslBuffer_Skip(&protBuf, len, NULL);
+ rv = sslBuffer_Skip(wrBuf, len, NULL);
PORT_Assert(rv == SECSuccess); /* Can't fail. */
}
#else
- if (isTLS13) {
- rv = tls13_ProtectRecord(ss, cwSpec, type, pIn, contentLen, &protBuf);
+ if (cwSpec->version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ rv = tls13_ProtectRecord(ss, cwSpec, type, pIn, contentLen, wrBuf);
} else {
rv = ssl3_MACEncryptRecord(cwSpec, ss->sec.isServer, IS_DTLS(ss), type,
- pIn, contentLen, &protBuf);
+ pIn, contentLen, wrBuf);
}
#endif
if (rv != SECSuccess) {
return SECFailure; /* error was set */
}
- PORT_Assert(protBuf.len <= MAX_FRAGMENT_LENGTH + (isTLS13 ? 256 : 1024));
-
- rv = ssl_InsertRecordHeader(ss, cwSpec, type, SSL_BUFFER_LEN(&protBuf),
- wrBuf);
- if (rv != SECSuccess) {
- return SECFailure;
- }
-
- PORT_Assert(SSL_BUFFER_LEN(wrBuf) == headerLen);
- rv = sslBuffer_Skip(wrBuf, SSL_BUFFER_LEN(&protBuf), NULL);
- if (rv != SECSuccess) {
- PORT_Assert(0); /* Can't fail. */
- return SECFailure;
+ if (needsLength) {
+ /* Insert the length. */
+ rv = sslBuffer_InsertLength(wrBuf, lenOffset, 2);
+ if (rv != SECSuccess) {
+ PORT_Assert(0); /* Can't fail. */
+ return SECFailure;
+ }
}
- ++cwSpec->seqNum;
+ ++cwSpec->nextSeqNum;
return SECSuccess;
}
@@ -2291,6 +2301,7 @@ ssl_ProtectNextRecord(sslSocket *ss, ssl3CipherSpec *spec, SSL3ContentType type,
*written = contentLen;
return SECSuccess;
}
+
/* Process the plain text before sending it.
* Returns the number of bytes of plaintext that were successfully sent
* plus the number of bytes of plaintext that were copied into the
@@ -2368,7 +2379,7 @@ ssl3_SendRecord(sslSocket *ss,
rv = ssl_ProtectNextRecord(ss, spec, type, pIn, nIn, &written);
ssl_ReleaseSpecReadLock(ss);
if (rv != SECSuccess) {
- return SECFailure;
+ goto loser;
}
PORT_Assert(written > 0);
@@ -11847,6 +11858,7 @@ ssl3_UnprotectRecord(sslSocket *ss,
unsigned int good;
unsigned int ivLen = 0;
SSL3ContentType rType;
+ SSL3ProtocolVersion rVersion;
unsigned int minLength;
unsigned int originalLen = 0;
PRUint8 headerBuf[13];
@@ -11919,7 +11931,9 @@ ssl3_UnprotectRecord(sslSocket *ss,
return SECFailure;
}
- rType = cText->type;
+ rType = (SSL3ContentType)cText->hdr[0];
+ rVersion = ((SSL3ProtocolVersion)cText->hdr[1] << 8) |
+ (SSL3ProtocolVersion)cText->hdr[2];
if (cipher_def->type == type_aead) {
/* XXX For many AEAD ciphers, the plaintext is shorter than the
* ciphertext by a fixed byte count, but it is not true in general.
@@ -11929,8 +11943,8 @@ ssl3_UnprotectRecord(sslSocket *ss,
cText->buf->len - cipher_def->explicit_nonce_size -
cipher_def->tag_size;
rv = ssl3_BuildRecordPseudoHeader(
- spec->epoch, IS_DTLS(ss) ? cText->seq_num : spec->seqNum,
- rType, isTLS, cText->version, IS_DTLS(ss), decryptedLen, &header);
+ spec->epoch, cText->seqNum,
+ rType, isTLS, rVersion, IS_DTLS(ss), decryptedLen, &header);
PORT_Assert(rv == SECSuccess);
rv = spec->aead(&spec->keyMaterial,
PR_TRUE, /* do decrypt */
@@ -11977,8 +11991,8 @@ ssl3_UnprotectRecord(sslSocket *ss,
/* compute the MAC */
rv = ssl3_BuildRecordPseudoHeader(
- spec->epoch, IS_DTLS(ss) ? cText->seq_num : spec->seqNum,
- rType, isTLS, cText->version, IS_DTLS(ss),
+ spec->epoch, cText->seqNum,
+ rType, isTLS, rVersion, IS_DTLS(ss),
plaintext->len - spec->macDef->mac_size, &header);
PORT_Assert(rv == SECSuccess);
if (cipher_def->type == type_block) {
@@ -12028,13 +12042,19 @@ ssl3_UnprotectRecord(sslSocket *ss,
return SECSuccess;
}
-static SECStatus
+SECStatus
ssl3_HandleNonApplicationData(sslSocket *ss, SSL3ContentType rType,
DTLSEpoch epoch, sslSequenceNumber seqNum,
sslBuffer *databuf)
{
SECStatus rv;
+ /* check for Token Presence */
+ if (!ssl3_ClientAuthTokenPresent(ss->sec.ci.sid)) {
+ PORT_SetError(SSL_ERROR_TOKEN_INSERTION_REMOVAL);
+ return SECFailure;
+ }
+
ssl_GetSSL3HandshakeLock(ss);
/* All the functions called in this switch MUST set error code if
@@ -12080,15 +12100,16 @@ ssl3_HandleNonApplicationData(sslSocket *ss, SSL3ContentType rType,
* Returns NULL if no appropriate cipher spec is found.
*/
static ssl3CipherSpec *
-ssl3_GetCipherSpec(sslSocket *ss, sslSequenceNumber seq)
+ssl3_GetCipherSpec(sslSocket *ss, SSL3Ciphertext *cText)
{
ssl3CipherSpec *crSpec = ss->ssl3.crSpec;
ssl3CipherSpec *newSpec = NULL;
- DTLSEpoch epoch = seq >> 48;
+ DTLSEpoch epoch;
if (!IS_DTLS(ss)) {
return crSpec;
}
+ epoch = dtls_ReadEpoch(crSpec, cText->hdr);
if (crSpec->epoch == epoch) {
return crSpec;
}
@@ -12128,16 +12149,15 @@ ssl3_GetCipherSpec(sslSocket *ss, sslSequenceNumber seq)
* Application Data records.
*/
SECStatus
-ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText, sslBuffer *databuf)
+ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText)
{
SECStatus rv;
PRBool isTLS;
DTLSEpoch epoch;
- sslSequenceNumber seqNum = 0;
ssl3CipherSpec *spec = NULL;
PRBool outOfOrderSpec = PR_FALSE;
SSL3ContentType rType;
- sslBuffer *plaintext;
+ sslBuffer *plaintext = &ss->gs.buf;
SSL3AlertDescription alert = internal_error;
PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss));
@@ -12147,27 +12167,15 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText, sslBuffer *databuf)
return SECFailure;
}
- /* cText is NULL when we're called from ssl3_RestartHandshakeAfterXXX().
- * This implies that databuf holds a previously deciphered SSL Handshake
- * message.
- */
- if (cText == NULL) {
- SSL_DBG(("%d: SSL3[%d]: HandleRecord, resuming handshake",
- SSL_GETPID(), ss->fd));
- /* Note that this doesn't pass the epoch and sequence number of the
- * record through, which DTLS 1.3 depends on. DTLS doesn't support
- * asynchronous certificate validation, so that should be OK. */
- PORT_Assert(!IS_DTLS(ss));
- return ssl3_HandleNonApplicationData(ss, content_handshake,
- 0, 0, databuf);
- }
+ /* Clear out the buffer in case this exits early. Any data then won't be
+ * processed twice. */
+ plaintext->len = 0;
ssl_GetSpecReadLock(ss); /******************************************/
- spec = ssl3_GetCipherSpec(ss, cText->seq_num);
+ spec = ssl3_GetCipherSpec(ss, cText);
if (!spec) {
PORT_Assert(IS_DTLS(ss));
ssl_ReleaseSpecReadLock(ss); /*****************************/
- databuf->len = 0; /* Needed to ensure data not left around */
return SECSuccess;
}
if (spec != ss->ssl3.crSpec) {
@@ -12178,36 +12186,30 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText, sslBuffer *databuf)
}
isTLS = (PRBool)(spec->version > SSL_LIBRARY_VERSION_3_0);
if (IS_DTLS(ss)) {
- if (!dtls_IsRelevant(ss, spec, cText, &seqNum)) {
+ if (!dtls_IsRelevant(ss, spec, cText, &cText->seqNum)) {
ssl_ReleaseSpecReadLock(ss); /*****************************/
- databuf->len = 0; /* Needed to ensure data not left around */
-
return SECSuccess;
}
} else {
- seqNum = spec->seqNum + 1;
+ cText->seqNum = spec->nextSeqNum;
}
- if (seqNum >= spec->cipherDef->max_records) {
+ if (cText->seqNum >= spec->cipherDef->max_records) {
ssl_ReleaseSpecReadLock(ss); /*****************************/
SSL_TRC(3, ("%d: SSL[%d]: read sequence number at limit 0x%0llx",
- SSL_GETPID(), ss->fd, seqNum));
+ SSL_GETPID(), ss->fd, cText->seqNum));
PORT_SetError(SSL_ERROR_TOO_MANY_RECORDS);
return SECFailure;
}
- plaintext = databuf;
- plaintext->len = 0; /* filled in by Unprotect call below. */
-
/* We're waiting for another ClientHello, which will appear unencrypted.
* Use the content type to tell whether this is should be discarded.
*
* XXX If we decide to remove the content type from encrypted records, this
* will become much more difficult to manage. */
if (ss->ssl3.hs.zeroRttIgnore == ssl_0rtt_ignore_hrr &&
- cText->type == content_application_data) {
+ cText->hdr[0] == content_application_data) {
ssl_ReleaseSpecReadLock(ss); /*****************************/
PORT_Assert(ss->ssl3.hs.ws == wait_client_hello);
- databuf->len = 0;
return SECSuccess;
}
@@ -12224,6 +12226,7 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText, sslBuffer *databuf)
}
#ifdef UNSAFE_FUZZER_MODE
+ rType = cText->hdr[0];
rv = Null_Cipher(NULL, plaintext->buf, (int *)&plaintext->len,
plaintext->space, cText->buf->buf, cText->buf->len);
#else
@@ -12233,9 +12236,10 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText, sslBuffer *databuf)
if (spec->version < SSL_LIBRARY_VERSION_TLS_1_3 ||
spec->cipherDef->calg == ssl_calg_null) {
/* Unencrypted TLS 1.3 records use the pre-TLS 1.3 format. */
+ rType = cText->hdr[0];
rv = ssl3_UnprotectRecord(ss, spec, cText, plaintext, &alert);
} else {
- rv = tls13_UnprotectRecord(ss, spec, cText, plaintext, &alert);
+ rv = tls13_UnprotectRecord(ss, spec, cText, plaintext, &rType, &alert);
}
#endif
@@ -12245,14 +12249,14 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText, sslBuffer *databuf)
SSL_DBG(("%d: SSL3[%d]: decryption failed", SSL_GETPID(), ss->fd));
/* Ensure that we don't process this data again. */
- databuf->len = 0;
+ plaintext->len = 0;
/* Ignore a CCS if the alternative handshake is negotiated. Note that
* this will fail if the server fails to negotiate the alternative
* handshake type in a 0-RTT session that is resumed from a session that
* did negotiate it. We don't care about that corner case right now. */
if (ss->version >= SSL_LIBRARY_VERSION_TLS_1_3 &&
- cText->type == content_change_cipher_spec &&
+ cText->hdr[0] == content_change_cipher_spec &&
ss->ssl3.hs.ws != idle_handshake &&
cText->buf->len == 1 &&
cText->buf->buf[0] == change_cipher_spec_choice) {
@@ -12275,9 +12279,11 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText, sslBuffer *databuf)
}
/* SECSuccess */
- spec->seqNum = PR_MAX(spec->seqNum, seqNum);
if (IS_DTLS(ss)) {
- dtls_RecordSetRecvd(&spec->recvdRecords, seqNum);
+ dtls_RecordSetRecvd(&spec->recvdRecords, cText->seqNum);
+ spec->nextSeqNum = PR_MAX(spec->nextSeqNum, cText->seqNum + 1);
+ } else {
+ ++spec->nextSeqNum;
}
epoch = spec->epoch;
@@ -12286,19 +12292,18 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText, sslBuffer *databuf)
/*
* The decrypted data is now in plaintext.
*/
- rType = cText->type; /* This must go after decryption because TLS 1.3
- * has encrypted content types. */
/* IMPORTANT: We are in DTLS 1.3 mode and we have processed something
* from the wrong epoch. Divert to a divert processing function to make
* sure we don't accidentally use the data unsafely. */
if (outOfOrderSpec) {
PORT_Assert(IS_DTLS(ss) && ss->version >= SSL_LIBRARY_VERSION_TLS_1_3);
- return dtls13_HandleOutOfEpochRecord(ss, spec, rType, databuf);
+ return dtls13_HandleOutOfEpochRecord(ss, spec, rType, plaintext);
}
/* Check the length of the plaintext. */
- if (isTLS && databuf->len > MAX_FRAGMENT_LENGTH) {
+ if (isTLS && plaintext->len > MAX_FRAGMENT_LENGTH) {
+ plaintext->len = 0;
SSL3_SendAlert(ss, alert_fatal, record_overflow);
PORT_SetError(SSL_ERROR_RX_RECORD_TOO_LONG);
return SECFailure;
@@ -12313,14 +12318,16 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText, sslBuffer *databuf)
if (ss->version >= SSL_LIBRARY_VERSION_TLS_1_3 &&
ss->sec.isServer &&
ss->ssl3.hs.zeroRttState == ssl_0rtt_accepted) {
- return tls13_HandleEarlyApplicationData(ss, databuf);
+ return tls13_HandleEarlyApplicationData(ss, plaintext);
}
+ plaintext->len = 0;
(void)SSL3_SendAlert(ss, alert_fatal, unexpected_message);
PORT_SetError(SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA);
return SECFailure;
}
- return ssl3_HandleNonApplicationData(ss, rType, epoch, seqNum, databuf);
+ return ssl3_HandleNonApplicationData(ss, rType, epoch, cText->seqNum,
+ plaintext);
}
/*
diff --git a/lib/ssl/ssl3gthr.c b/lib/ssl/ssl3gthr.c
index 8b323bb05..b0dd7315f 100644
--- a/lib/ssl/ssl3gthr.c
+++ b/lib/ssl/ssl3gthr.c
@@ -264,8 +264,9 @@ static int
dtls_GatherData(sslSocket *ss, sslGather *gs, int flags)
{
int nb;
- int err;
- int rv = 1;
+ PRUint8 contentType;
+ unsigned int headerLen;
+ SECStatus rv;
SSL_TRC(30, ("dtls_GatherData"));
@@ -285,81 +286,96 @@ dtls_GatherData(sslSocket *ss, sslGather *gs, int flags)
** to 13 (the size of the record header).
*/
if (gs->dtlsPacket.space < MAX_FRAGMENT_LENGTH + 2048 + 13) {
- err = sslBuffer_Grow(&gs->dtlsPacket,
- MAX_FRAGMENT_LENGTH + 2048 + 13);
- if (err) { /* realloc has set error code to no mem. */
- return err;
+ rv = sslBuffer_Grow(&gs->dtlsPacket,
+ MAX_FRAGMENT_LENGTH + 2048 + 13);
+ if (rv != SECSuccess) {
+ return -1; /* Code already set. */
}
}
/* recv() needs to read a full datagram at a time */
nb = ssl_DefRecv(ss, gs->dtlsPacket.buf, gs->dtlsPacket.space, flags);
-
if (nb > 0) {
PRINT_BUF(60, (ss, "raw gather data:", gs->dtlsPacket.buf, nb));
} else if (nb == 0) {
/* EOF */
SSL_TRC(30, ("%d: SSL3[%d]: EOF", SSL_GETPID(), ss->fd));
- rv = 0;
- return rv;
+ return 0;
} else /* if (nb < 0) */ {
SSL_DBG(("%d: SSL3[%d]: recv error %d", SSL_GETPID(), ss->fd,
PR_GetError()));
- rv = SECFailure;
- return rv;
+ return -1;
}
gs->dtlsPacket.len = nb;
}
+ contentType = gs->dtlsPacket.buf[gs->dtlsPacketOffset];
+ if (dtls_IsLongHeader(ss->version, contentType)) {
+ headerLen = 13;
+ } else if (contentType == content_application_data) {
+ headerLen = 7;
+ } else if ((contentType & 0xe0) == 0x20) {
+ headerLen = 2;
+ } else {
+ SSL_DBG(("%d: SSL3[%d]: invalid first octet (%d) for DTLS",
+ SSL_GETPID(), ss->fd, contentType));
+ PORT_SetError(SSL_ERROR_RX_UNKNOWN_RECORD_TYPE);
+ gs->dtlsPacketOffset = 0;
+ gs->dtlsPacket.len = 0;
+ return -1;
+ }
+
/* At this point we should have >=1 complete records lined up in
* dtlsPacket. Read off the header.
*/
- if ((gs->dtlsPacket.len - gs->dtlsPacketOffset) < 13) {
+ if ((gs->dtlsPacket.len - gs->dtlsPacketOffset) < headerLen) {
SSL_DBG(("%d: SSL3[%d]: rest of DTLS packet "
"too short to contain header",
SSL_GETPID(), ss->fd));
- PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
+ PORT_SetError(PR_WOULD_BLOCK_ERROR);
gs->dtlsPacketOffset = 0;
gs->dtlsPacket.len = 0;
- rv = SECFailure;
- return rv;
+ return -1;
}
- memcpy(gs->hdr, gs->dtlsPacket.buf + gs->dtlsPacketOffset, 13);
- gs->dtlsPacketOffset += 13;
+ memcpy(gs->hdr, SSL_BUFFER_BASE(&gs->dtlsPacket) + gs->dtlsPacketOffset,
+ headerLen);
+ gs->dtlsPacketOffset += headerLen;
/* Have received SSL3 record header in gs->hdr. */
- gs->remainder = (gs->hdr[11] << 8) | gs->hdr[12];
+ if (headerLen == 13) {
+ gs->remainder = (gs->hdr[11] << 8) | gs->hdr[12];
+ } else if (headerLen == 7) {
+ gs->remainder = (gs->hdr[5] << 8) | gs->hdr[6];
+ } else {
+ PORT_Assert(headerLen = 2);
+ gs->remainder = gs->dtlsPacket.len - gs->dtlsPacketOffset;
+ }
if ((gs->dtlsPacket.len - gs->dtlsPacketOffset) < gs->remainder) {
SSL_DBG(("%d: SSL3[%d]: rest of DTLS packet too short "
"to contain rest of body",
SSL_GETPID(), ss->fd));
- PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
+ PORT_SetError(PR_WOULD_BLOCK_ERROR);
gs->dtlsPacketOffset = 0;
gs->dtlsPacket.len = 0;
- rv = SECFailure;
- return rv;
+ return -1;
}
/* OK, we have at least one complete packet, copy into inbuf */
- if (gs->remainder > gs->inbuf.space) {
- err = sslBuffer_Grow(&gs->inbuf, gs->remainder);
- if (err) { /* realloc has set error code to no mem. */
- return err;
- }
+ gs->inbuf.len = 0;
+ rv = sslBuffer_Append(&gs->inbuf,
+ SSL_BUFFER_BASE(&gs->dtlsPacket) + gs->dtlsPacketOffset,
+ gs->remainder);
+ if (rv != SECSuccess) {
+ return -1; /* code already set. */
}
-
- SSL_TRC(20, ("%d: SSL3[%d]: dtls gathered record type=%d len=%d",
- SSL_GETPID(), ss->fd, gs->hdr[0], gs->inbuf.len));
-
- memcpy(gs->inbuf.buf, gs->dtlsPacket.buf + gs->dtlsPacketOffset,
- gs->remainder);
- gs->inbuf.len = gs->remainder;
gs->offset = gs->remainder;
gs->dtlsPacketOffset += gs->remainder;
gs->state = GS_INIT;
+ SSL_TRC(20, ("%d: SSL3[%d]: dtls gathered record type=%d len=%d",
+ SSL_GETPID(), ss->fd, contentType, gs->inbuf.len));
return 1;
}
@@ -442,7 +458,11 @@ ssl3_GatherCompleteHandshake(sslSocket *ss, int flags)
* We need to process it now before we overwrite it with the next
* handshake record.
*/
- rv = ssl3_HandleRecord(ss, NULL, &ss->gs.buf);
+ SSL_DBG(("%d: SSL3[%d]: resuming handshake",
+ SSL_GETPID(), ss->fd));
+ PORT_Assert(!IS_DTLS(ss));
+ rv = ssl3_HandleNonApplicationData(ss, content_handshake,
+ 0, 0, &ss->gs.buf);
} else {
/* State for SSLv2 client hello support. */
ssl2Gather ssl2gs = { PR_FALSE, 0 };
@@ -495,20 +515,13 @@ ssl3_GatherCompleteHandshake(sslSocket *ss, int flags)
* If it's application data, ss->gs.buf will not be empty upon return.
* If it's a change cipher spec, alert, or handshake message,
* ss->gs.buf.len will be 0 when ssl3_HandleRecord returns SECSuccess.
+ *
+ * cText only needs to be valid for this next function call, so
+ * it can borrow gs.hdr.
*/
- cText.type = (SSL3ContentType)ss->gs.hdr[0];
- cText.version = (ss->gs.hdr[1] << 8) | ss->gs.hdr[2];
-
- if (IS_DTLS(ss)) {
- sslSequenceNumber seq_num;
-
- /* DTLS sequence number */
- PORT_Memcpy(&seq_num, &ss->gs.hdr[3], sizeof(seq_num));
- cText.seq_num = PR_ntohll(seq_num);
- }
-
+ cText.hdr = ss->gs.hdr;
cText.buf = &ss->gs.inbuf;
- rv = ssl3_HandleRecord(ss, &cText, &ss->gs.buf);
+ rv = ssl3_HandleRecord(ss, &cText);
}
}
if (rv < 0) {
@@ -520,7 +533,6 @@ ssl3_GatherCompleteHandshake(sslSocket *ss, int flags)
* completing any renegotiation handshake we may be doing.
*/
PORT_Assert(ss->firstHsDone);
- PORT_Assert(cText.type == content_application_data);
break;
}
diff --git a/lib/ssl/sslimpl.h b/lib/ssl/sslimpl.h
index 10d0333d9..799049c32 100644
--- a/lib/ssl/sslimpl.h
+++ b/lib/ssl/sslimpl.h
@@ -261,6 +261,7 @@ typedef struct sslOptionsStr {
unsigned int requireDHENamedGroups : 1;
unsigned int enable0RttData : 1;
unsigned int enableTls13CompatMode : 1;
+ unsigned int enableDtlsShortHeader : 1;
} sslOptions;
typedef enum { sslHandshakingUndetermined = 0,
@@ -780,9 +781,11 @@ struct ssl3StateStr {
#define IS_DTLS(ss) (ss->protocolVariant == ssl_variant_datagram)
typedef struct {
- SSL3ContentType type;
- SSL3ProtocolVersion version;
- sslSequenceNumber seq_num; /* DTLS only */
+ /* |seqNum| eventually contains the reconstructed sequence number. */
+ sslSequenceNumber seqNum;
+ /* The header of the cipherText. */
+ const PRUint8 *hdr;
+ /* |buf| is the payload of the ciphertext. */
sslBuffer *buf;
} SSL3Ciphertext;
@@ -1375,8 +1378,11 @@ SECStatus ssl3_SendClientHello(sslSocket *ss, sslClientHelloType type);
/*
* input into the SSL3 machinery from the actualy network reading code
*/
-SECStatus ssl3_HandleRecord(
- sslSocket *ss, SSL3Ciphertext *cipher, sslBuffer *out);
+SECStatus ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cipher);
+SECStatus ssl3_HandleNonApplicationData(sslSocket *ss, SSL3ContentType rType,
+ DTLSEpoch epoch,
+ sslSequenceNumber seqNum,
+ sslBuffer *databuf);
SECStatus ssl_RemoveTLSCBCPadding(sslBuffer *plaintext, unsigned int macSize);
int ssl3_GatherAppDataRecord(sslSocket *ss, int flags);
diff --git a/lib/ssl/sslsecur.c b/lib/ssl/sslsecur.c
index f09ec067c..d3424a7ad 100644
--- a/lib/ssl/sslsecur.c
+++ b/lib/ssl/sslsecur.c
@@ -791,7 +791,7 @@ tls13_CheckKeyUpdate(sslSocket *ss, CipherSpecDirection dir)
spec = ss->ssl3.cwSpec;
margin = spec->cipherDef->max_records / 4;
}
- seqNum = spec->seqNum;
+ seqNum = spec->nextSeqNum;
keyUpdate = seqNum > spec->cipherDef->max_records - margin;
ssl_ReleaseSpecReadLock(ss);
if (!keyUpdate) {
diff --git a/lib/ssl/sslsock.c b/lib/ssl/sslsock.c
index e08d5e232..f5d829d4e 100644
--- a/lib/ssl/sslsock.c
+++ b/lib/ssl/sslsock.c
@@ -81,7 +81,8 @@ static sslOptions ssl_defaults = {
.enableSignedCertTimestamps = PR_FALSE,
.requireDHENamedGroups = PR_FALSE,
.enable0RttData = PR_FALSE,
- .enableTls13CompatMode = PR_FALSE
+ .enableTls13CompatMode = PR_FALSE,
+ .enableDtlsShortHeader = PR_FALSE
};
/*
@@ -807,6 +808,10 @@ SSL_OptionSet(PRFileDesc *fd, PRInt32 which, PRIntn val)
ss->opt.enableTls13CompatMode = val;
break;
+ case SSL_ENABLE_DTLS_SHORT_HEADER:
+ ss->opt.enableDtlsShortHeader = val;
+ break;
+
default:
PORT_SetError(SEC_ERROR_INVALID_ARGS);
rv = SECFailure;
@@ -943,6 +948,9 @@ SSL_OptionGet(PRFileDesc *fd, PRInt32 which, PRIntn *pVal)
case SSL_ENABLE_TLS13_COMPAT_MODE:
val = ss->opt.enableTls13CompatMode;
break;
+ case SSL_ENABLE_DTLS_SHORT_HEADER:
+ val = ss->opt.enableDtlsShortHeader;
+ break;
default:
PORT_SetError(SEC_ERROR_INVALID_ARGS);
rv = SECFailure;
@@ -1063,6 +1071,9 @@ SSL_OptionGetDefault(PRInt32 which, PRIntn *pVal)
case SSL_ENABLE_TLS13_COMPAT_MODE:
val = ssl_defaults.enableTls13CompatMode;
break;
+ case SSL_ENABLE_DTLS_SHORT_HEADER:
+ val = ssl_defaults.enableDtlsShortHeader;
+ break;
default:
PORT_SetError(SEC_ERROR_INVALID_ARGS);
rv = SECFailure;
@@ -1246,6 +1257,10 @@ SSL_OptionSetDefault(PRInt32 which, PRIntn val)
ssl_defaults.enableTls13CompatMode = val;
break;
+ case SSL_ENABLE_DTLS_SHORT_HEADER:
+ ssl_defaults.enableDtlsShortHeader = val;
+ break;
+
default:
PORT_SetError(SEC_ERROR_INVALID_ARGS);
return SECFailure;
diff --git a/lib/ssl/sslspec.h b/lib/ssl/sslspec.h
index 729ac1006..207bd6ef6 100644
--- a/lib/ssl/sslspec.h
+++ b/lib/ssl/sslspec.h
@@ -162,7 +162,9 @@ struct ssl3CipherSpecStr {
DTLSEpoch epoch;
const char *phase;
- sslSequenceNumber seqNum;
+
+ /* The next sequence number to be sent or received. */
+ sslSequenceNumber nextSeqNum;
DTLSRecvdRecords recvdRecords;
/* The number of 0-RTT bytes that can be sent or received in TLS 1.3. This
diff --git a/lib/ssl/tls13con.c b/lib/ssl/tls13con.c
index c06acc83a..ed0bb5ecb 100644
--- a/lib/ssl/tls13con.c
+++ b/lib/ssl/tls13con.c
@@ -792,7 +792,7 @@ tls13_HandleKeyUpdate(sslSocket *ss, PRUint8 *b, unsigned int length)
/* Only send an update if we have sent with the current spec. This
* prevents us from being forced to crank forward pointlessly. */
ssl_GetSpecReadLock(ss);
- sendUpdate = ss->ssl3.cwSpec->seqNum > 0;
+ sendUpdate = ss->ssl3.cwSpec->nextSeqNum > 0;
ssl_ReleaseSpecReadLock(ss);
} else {
sendUpdate = PR_TRUE;
@@ -1620,7 +1620,7 @@ tls13_HandleClientHelloPart2(sslSocket *ss,
ssl_GetSpecWriteLock(ss);
/* Increase the write sequence number. The read sequence number
* will be reset after this to early data or handshake. */
- ss->ssl3.cwSpec->seqNum = 1;
+ ss->ssl3.cwSpec->nextSeqNum = 1;
ssl_ReleaseSpecWriteLock(ss);
}
@@ -2007,7 +2007,7 @@ tls13_SendHelloRetryRequest(sslSocket *ss,
/* We depend on this being exactly one record and one message. */
PORT_Assert(!IS_DTLS(ss) || (ss->ssl3.hs.sendMessageSeq == 1 &&
- ss->ssl3.cwSpec->seqNum == 1));
+ ss->ssl3.cwSpec->nextSeqNum == 1));
ssl_ReleaseXmitBufLock(ss);
ss->ssl3.hs.helloRetry = PR_TRUE;
@@ -3316,7 +3316,7 @@ tls13_SetCipherSpec(sslSocket *ss, PRUint16 epoch,
return SECFailure;
}
spec->epoch = epoch;
- spec->seqNum = 0;
+ spec->nextSeqNum = 0;
if (IS_DTLS(ss)) {
dtls_InitRecvdRecords(&spec->recvdRecords);
}
@@ -4843,43 +4843,48 @@ tls13_ProtectRecord(sslSocket *ss,
PORT_Assert(cwSpec->direction == CipherSpecWrite);
SSL_TRC(3, ("%d: TLS13[%d]: spec=%d epoch=%d (%s) protect 0x%0llx len=%u",
SSL_GETPID(), ss->fd, cwSpec, cwSpec->epoch, cwSpec->phase,
- cwSpec->seqNum, contentLen));
+ cwSpec->nextSeqNum, contentLen));
- if (contentLen + 1 + tagLen > wrBuf->space) {
+ if (contentLen + 1 + tagLen > SSL_BUFFER_SPACE(wrBuf)) {
PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
return SECFailure;
}
/* Copy the data into the wrBuf. We're going to encrypt in-place
* in the AEAD branch anyway */
- PORT_Memcpy(wrBuf->buf, pIn, contentLen);
+ PORT_Memcpy(SSL_BUFFER_NEXT(wrBuf), pIn, contentLen);
if (cipher_def->calg == ssl_calg_null) {
/* Shortcut for plaintext */
- wrBuf->len = contentLen;
+ rv = sslBuffer_Skip(wrBuf, contentLen, NULL);
+ PORT_Assert(rv == SECSuccess);
} else {
PRUint8 aad[8];
+ int len;
PORT_Assert(cipher_def->type == type_aead);
/* Add the content type at the end. */
- wrBuf->buf[contentLen] = type;
+ *(SSL_BUFFER_NEXT(wrBuf) + contentLen) = type;
rv = tls13_FormatAdditionalData(ss, aad, sizeof(aad), cwSpec->epoch,
- cwSpec->seqNum);
+ cwSpec->nextSeqNum);
if (rv != SECSuccess) {
return SECFailure;
}
rv = cwSpec->aead(&cwSpec->keyMaterial,
- PR_FALSE, /* do encrypt */
- wrBuf->buf, /* output */
- (int *)&wrBuf->len, /* out len */
- wrBuf->space, /* max out */
- wrBuf->buf, contentLen + 1, /* input */
+ PR_FALSE, /* do encrypt */
+ SSL_BUFFER_NEXT(wrBuf), /* output */
+ &len, /* out len */
+ SSL_BUFFER_SPACE(wrBuf), /* max out */
+ SSL_BUFFER_NEXT(wrBuf), /* input */
+ contentLen + 1, /* input len */
aad, sizeof(aad));
if (rv != SECSuccess) {
PORT_SetError(SSL_ERROR_ENCRYPTION_FAILURE);
return SECFailure;
}
+ rv = sslBuffer_Skip(wrBuf, len, NULL);
+ PORT_Assert(rv == SECSuccess);
}
return SECSuccess;
@@ -4897,25 +4902,21 @@ tls13_ProtectRecord(sslSocket *ss,
SECStatus
tls13_UnprotectRecord(sslSocket *ss,
ssl3CipherSpec *spec,
- SSL3Ciphertext *cText, sslBuffer *plaintext,
+ SSL3Ciphertext *cText,
+ sslBuffer *plaintext,
+ SSL3ContentType *innerType,
SSL3AlertDescription *alert)
{
const ssl3BulkCipherDef *cipher_def = spec->cipherDef;
- sslSequenceNumber seqNum;
PRUint8 aad[8];
SECStatus rv;
*alert = bad_record_mac; /* Default alert for most issues. */
PORT_Assert(spec->direction == CipherSpecRead);
- if (IS_DTLS(ss)) {
- seqNum = cText->seq_num & RECORD_SEQ_MASK;
- } else {
- seqNum = spec->seqNum;
- }
SSL_TRC(3, ("%d: TLS13[%d]: spec=%d epoch=%d (%s) unprotect 0x%0llx len=%u",
- SSL_GETPID(), ss->fd, spec, spec->epoch, spec->phase, seqNum,
- cText->buf->len));
+ SSL_GETPID(), ss->fd, spec, spec->epoch, spec->phase,
+ cText->seqNum, cText->buf->len));
/* We can perform this test in variable time because the record's total
* length and the ciphersuite are both public knowledge. */
@@ -4927,28 +4928,37 @@ tls13_UnprotectRecord(sslSocket *ss,
return SECFailure;
}
- /* Verify that the content type is right, even though we overwrite it. */
- if (cText->type != content_application_data) {
+ /* Verify that the content type is right, even though we overwrite it.
+ * Also allow the DTLS short header in TLS 1.3. */
+ if (!(cText->hdr[0] == content_application_data ||
+ (IS_DTLS(ss) &&
+ ss->version >= SSL_LIBRARY_VERSION_TLS_1_3 &&
+ (cText->hdr[0] & 0xe0) == 0x20))) {
SSL_TRC(3,
- ("%d: TLS13[%d]: record has invalid exterior content type=%d",
- SSL_GETPID(), ss->fd, cText->type));
+ ("%d: TLS13[%d]: record has invalid exterior type=%2.2x",
+ SSL_GETPID(), ss->fd, cText->hdr[0]));
/* Do we need a better error here? */
PORT_SetError(SSL_ERROR_BAD_MAC_READ);
return SECFailure;
}
- /* Check the version number in the record. */
- if (cText->version != spec->recordVersion) {
- /* Do we need a better error here? */
- SSL_TRC(3,
- ("%d: TLS13[%d]: record has bogus version",
- SSL_GETPID(), ss->fd));
- return SECFailure;
+ /* Check the version number in the record. Stream only. */
+ if (!IS_DTLS(ss)) {
+ SSL3ProtocolVersion version =
+ ((SSL3ProtocolVersion)cText->hdr[1] << 8) |
+ (SSL3ProtocolVersion)cText->hdr[2];
+ if (version != spec->recordVersion) {
+ /* Do we need a better error here? */
+ SSL_TRC(3, ("%d: TLS13[%d]: record has bogus version",
+ SSL_GETPID(), ss->fd));
+ return SECFailure;
+ }
}
/* Decrypt */
PORT_Assert(cipher_def->type == type_aead);
- rv = tls13_FormatAdditionalData(ss, aad, sizeof(aad), spec->epoch, seqNum);
+ rv = tls13_FormatAdditionalData(ss, aad, sizeof(aad), spec->epoch,
+ cText->seqNum);
if (rv != SECSuccess) {
return SECFailure;
}
@@ -4977,9 +4987,7 @@ tls13_UnprotectRecord(sslSocket *ss,
/* Bogus padding. */
if (plaintext->len < 1) {
- SSL_TRC(3,
- ("%d: TLS13[%d]: empty record",
- SSL_GETPID(), ss->fd, cText->type));
+ SSL_TRC(3, ("%d: TLS13[%d]: empty record", SSL_GETPID(), ss->fd));
/* It's safe to report this specifically because it happened
* after the MAC has been verified. */
PORT_SetError(SSL_ERROR_BAD_BLOCK_PADDING);
@@ -4987,12 +4995,12 @@ tls13_UnprotectRecord(sslSocket *ss,
}
/* Record the type. */
- cText->type = plaintext->buf[plaintext->len - 1];
+ *innerType = (SSL3ContentType)plaintext->buf[plaintext->len - 1];
--plaintext->len;
/* Check that we haven't received too much 0-RTT data. */
if (spec->epoch == TrafficKeyEarlyApplicationData &&
- cText->type == content_application_data) {
+ *innerType == content_application_data) {
if (plaintext->len > spec->earlyDataRemaining) {
*alert = unexpected_message;
PORT_SetError(SSL_ERROR_TOO_MUCH_EARLY_DATA);
@@ -5002,9 +5010,8 @@ tls13_UnprotectRecord(sslSocket *ss,
}
SSL_TRC(10,
- ("%d: TLS13[%d]: %s received record of length=%d type=%d",
- SSL_GETPID(), ss->fd, SSL_ROLE(ss),
- plaintext->len, cText->type));
+ ("%d: TLS13[%d]: %s received record of length=%d, type=%d",
+ SSL_GETPID(), ss->fd, SSL_ROLE(ss), plaintext->len, *innerType));
return SECSuccess;
}
diff --git a/lib/ssl/tls13con.h b/lib/ssl/tls13con.h
index 1aaffb651..7ef1fabc2 100644
--- a/lib/ssl/tls13con.h
+++ b/lib/ssl/tls13con.h
@@ -28,6 +28,7 @@ typedef enum {
SECStatus tls13_UnprotectRecord(
sslSocket *ss, ssl3CipherSpec *spec,
SSL3Ciphertext *cText, sslBuffer *plaintext,
+ SSL3ContentType *innerType,
SSL3AlertDescription *alert);
#if defined(WIN32)