diff options
-rw-r--r-- | gtests/ssl_gtest/libssl_internals.c | 13 | ||||
-rw-r--r-- | gtests/ssl_gtest/ssl_ciphersuite_unittest.cc | 11 | ||||
-rw-r--r-- | gtests/ssl_gtest/ssl_drop_unittest.cc | 78 | ||||
-rw-r--r-- | gtests/ssl_gtest/ssl_fragment_unittest.cc | 28 | ||||
-rw-r--r-- | gtests/ssl_gtest/ssl_hrr_unittest.cc | 5 | ||||
-rw-r--r-- | gtests/ssl_gtest/ssl_loopback_unittest.cc | 3 | ||||
-rw-r--r-- | gtests/ssl_gtest/ssl_record_unittest.cc | 23 | ||||
-rw-r--r-- | gtests/ssl_gtest/tls_agent.cc | 22 | ||||
-rw-r--r-- | gtests/ssl_gtest/tls_agent.h | 3 | ||||
-rw-r--r-- | gtests/ssl_gtest/tls_filter.cc | 187 | ||||
-rw-r--r-- | gtests/ssl_gtest/tls_filter.h | 29 | ||||
-rw-r--r-- | lib/ssl/dtls13con.c | 37 | ||||
-rw-r--r-- | lib/ssl/dtls13con.h | 4 | ||||
-rw-r--r-- | lib/ssl/dtlscon.c | 105 | ||||
-rw-r--r-- | lib/ssl/dtlscon.h | 2 | ||||
-rw-r--r-- | lib/ssl/ssl.h | 11 | ||||
-rw-r--r-- | lib/ssl/ssl3con.c | 233 | ||||
-rw-r--r-- | lib/ssl/ssl3gthr.c | 106 | ||||
-rw-r--r-- | lib/ssl/sslimpl.h | 16 | ||||
-rw-r--r-- | lib/ssl/sslsecur.c | 2 | ||||
-rw-r--r-- | lib/ssl/sslsock.c | 17 | ||||
-rw-r--r-- | lib/ssl/sslspec.h | 4 | ||||
-rw-r--r-- | lib/ssl/tls13con.c | 95 | ||||
-rw-r--r-- | lib/ssl/tls13con.h | 1 |
24 files changed, 724 insertions, 311 deletions
diff --git a/gtests/ssl_gtest/libssl_internals.c b/gtests/ssl_gtest/libssl_internals.c index 17b4ffe49..e43113de4 100644 --- a/gtests/ssl_gtest/libssl_internals.c +++ b/gtests/ssl_gtest/libssl_internals.c @@ -237,22 +237,23 @@ SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) { if (!ss) { return SECFailure; } - if (to >= RECORD_SEQ_MAX) { + if (to > RECORD_SEQ_MAX) { PORT_SetError(SEC_ERROR_INVALID_ARGS); return SECFailure; } ssl_GetSpecWriteLock(ss); spec = ss->ssl3.crSpec; - spec->seqNum = to; + spec->nextSeqNum = to; /* For DTLS, we need to fix the record sequence number. For this, we can just * scrub the entire structure on the assumption that the new sequence number * is far enough past the last received sequence number. */ - if (spec->seqNum <= spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) { + if (spec->nextSeqNum <= + spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) { PORT_SetError(SEC_ERROR_INVALID_ARGS); return SECFailure; } - dtls_RecordSetRecvd(&spec->recvdRecords, spec->seqNum); + dtls_RecordSetRecvd(&spec->recvdRecords, spec->nextSeqNum - 1); ssl_ReleaseSpecWriteLock(ss); return SECSuccess; @@ -270,7 +271,7 @@ SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to) { return SECFailure; } ssl_GetSpecWriteLock(ss); - ss->ssl3.cwSpec->seqNum = to; + ss->ssl3.cwSpec->nextSeqNum = to; ssl_ReleaseSpecWriteLock(ss); return SECSuccess; } @@ -284,7 +285,7 @@ SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra) { return SECFailure; } ssl_GetSpecReadLock(ss); - to = ss->ssl3.cwSpec->seqNum + DTLS_RECVD_RECORDS_WINDOW + extra; + to = ss->ssl3.cwSpec->nextSeqNum + DTLS_RECVD_RECORDS_WINDOW + extra; ssl_ReleaseSpecReadLock(ss); return SSLInt_AdvanceWriteSeqNum(fd, to); } diff --git a/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc b/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc index fa2238be7..ec289bdd6 100644 --- a/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc +++ b/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc @@ -166,8 +166,8 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase { case ssl_calg_seed: break; } - EXPECT_TRUE(false) << "No limit for " << csinfo_.cipherSuiteName; - return 1ULL < 48; + ADD_FAILURE() << "No limit for " << csinfo_.cipherSuiteName; + return 0; } uint64_t last_safe_write() const { @@ -246,12 +246,13 @@ TEST_P(TlsCipherSuiteTest, ReadLimit) { client_->SendData(10, 10); server_->ReadBytes(); // This should be OK. + server_->ReadBytes(); // Read twice to flush any 1,N-1 record splitting. } else { // In TLS 1.3, reading or writing triggers a KeyUpdate. That would mean // that the sequence numbers would reset and we wouldn't hit the limit. So - // we move the sequence number to one less than the limit directly and don't - // test sending and receiving just before the limit. - uint64_t last = record_limit() - 1; + // move the sequence number to the limit directly and don't test sending and + // receiving just before the limit. + uint64_t last = record_limit(); EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), last)); } diff --git a/gtests/ssl_gtest/ssl_drop_unittest.cc b/gtests/ssl_gtest/ssl_drop_unittest.cc index ee8906deb..a6c25bacf 100644 --- a/gtests/ssl_gtest/ssl_drop_unittest.cc +++ b/gtests/ssl_gtest/ssl_drop_unittest.cc @@ -66,7 +66,8 @@ TEST_P(TlsConnectDatagramPre13, DropServerSecondFlightThrice) { Connect(); } -class TlsDropDatagram13 : public TlsConnectDatagram13 { +class TlsDropDatagram13 : public TlsConnectDatagram13, + public ::testing::WithParamInterface<bool> { public: TlsDropDatagram13() : client_filters_(), @@ -77,6 +78,9 @@ class TlsDropDatagram13 : public TlsConnectDatagram13 { void SetUp() override { TlsConnectDatagram13::SetUp(); ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + int short_header = GetParam() ? PR_TRUE : PR_FALSE; + client_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, short_header); + server_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, short_header); SetFilters(); } @@ -186,7 +190,7 @@ class TlsDropDatagram13 : public TlsConnectDatagram13 { // to the client upon receiving the client Finished. // Dropping complete first and second flights does not produce // ACKs -TEST_F(TlsDropDatagram13, DropClientFirstFlightOnce) { +TEST_P(TlsDropDatagram13, DropClientFirstFlightOnce) { client_filters_.drop_->Reset({0}); StartConnect(); client_->Handshake(); @@ -195,7 +199,7 @@ TEST_F(TlsDropDatagram13, DropClientFirstFlightOnce) { CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); } -TEST_F(TlsDropDatagram13, DropServerFirstFlightOnce) { +TEST_P(TlsDropDatagram13, DropServerFirstFlightOnce) { server_filters_.drop_->Reset(0xff); StartConnect(); client_->Handshake(); @@ -209,7 +213,7 @@ TEST_F(TlsDropDatagram13, DropServerFirstFlightOnce) { // Dropping the server's first record also does not produce // an ACK because the next record is ignored. // TODO(ekr@rtfm.com): We should generate an empty ACK. -TEST_F(TlsDropDatagram13, DropServerFirstRecordOnce) { +TEST_P(TlsDropDatagram13, DropServerFirstRecordOnce) { server_filters_.drop_->Reset({0}); StartConnect(); client_->Handshake(); @@ -221,7 +225,7 @@ TEST_F(TlsDropDatagram13, DropServerFirstRecordOnce) { // Dropping the second packet of the server's flight should // produce an ACK. -TEST_F(TlsDropDatagram13, DropServerSecondRecordOnce) { +TEST_P(TlsDropDatagram13, DropServerSecondRecordOnce) { server_filters_.drop_->Reset({1}); StartConnect(); client_->Handshake(); @@ -235,7 +239,7 @@ TEST_F(TlsDropDatagram13, DropServerSecondRecordOnce) { // Drop the server ACK and verify that the client retransmits // the ClientHello. -TEST_F(TlsDropDatagram13, DropServerAckOnce) { +TEST_P(TlsDropDatagram13, DropServerAckOnce) { StartConnect(); client_->Handshake(); server_->Handshake(); @@ -263,7 +267,7 @@ TEST_F(TlsDropDatagram13, DropServerAckOnce) { } // Drop the client certificate verify. -TEST_F(TlsDropDatagram13, DropClientCertVerify) { +TEST_P(TlsDropDatagram13, DropClientCertVerify) { StartConnect(); client_->SetupClientAuth(); server_->RequestClientAuth(true); @@ -284,7 +288,7 @@ TEST_F(TlsDropDatagram13, DropClientCertVerify) { } // Shrink the MTU down so that certs get split and drop the first piece. -TEST_F(TlsDropDatagram13, DropFirstHalfOfServerCertificate) { +TEST_P(TlsDropDatagram13, DropFirstHalfOfServerCertificate) { server_filters_.drop_->Reset({2}); StartConnect(); ShrinkPostServerHelloMtu(); @@ -311,7 +315,7 @@ TEST_F(TlsDropDatagram13, DropFirstHalfOfServerCertificate) { } // Shrink the MTU down so that certs get split and drop the second piece. -TEST_F(TlsDropDatagram13, DropSecondHalfOfServerCertificate) { +TEST_P(TlsDropDatagram13, DropSecondHalfOfServerCertificate) { server_filters_.drop_->Reset({3}); StartConnect(); ShrinkPostServerHelloMtu(); @@ -524,11 +528,11 @@ class TlsFragmentationAndRecoveryTest : public TlsDropDatagram13 { size_t cert_len_; }; -TEST_F(TlsFragmentationAndRecoveryTest, DropFirstHalf) { RunTest(0); } +TEST_P(TlsFragmentationAndRecoveryTest, DropFirstHalf) { RunTest(0); } -TEST_F(TlsFragmentationAndRecoveryTest, DropSecondHalf) { RunTest(1); } +TEST_P(TlsFragmentationAndRecoveryTest, DropSecondHalf) { RunTest(1); } -TEST_F(TlsDropDatagram13, NoDropsDuringZeroRtt) { +TEST_P(TlsDropDatagram13, NoDropsDuringZeroRtt) { SetupForZeroRtt(); SetFilters(); std::cerr << "Starting second handshake" << std::endl; @@ -546,7 +550,7 @@ TEST_F(TlsDropDatagram13, NoDropsDuringZeroRtt) { 0x0002000000000000ULL}); // Finished } -TEST_F(TlsDropDatagram13, DropEEDuringZeroRtt) { +TEST_P(TlsDropDatagram13, DropEEDuringZeroRtt) { SetupForZeroRtt(); SetFilters(); std::cerr << "Starting second handshake" << std::endl; @@ -591,7 +595,7 @@ class TlsReorderDatagram13 : public TlsDropDatagram13 { // Reorder the server records so that EE comes at the end // of the flight and will still produce an ACK. -TEST_F(TlsDropDatagram13, ReorderServerEE) { +TEST_P(TlsDropDatagram13, ReorderServerEE) { server_filters_.drop_->Reset({1}); StartConnect(); client_->Handshake(); @@ -647,7 +651,7 @@ class TlsSendCipherSpecCapturer { std::vector<std::shared_ptr<TlsCipherSpec>> send_cipher_specs_; }; -TEST_F(TlsDropDatagram13, SendOutOfOrderAppWithHandshakeKey) { +TEST_P(TlsDropDatagram13, SendOutOfOrderAppWithHandshakeKey) { StartConnect(); TlsSendCipherSpecCapturer capturer(client_); client_->Handshake(); @@ -662,9 +666,9 @@ TEST_F(TlsDropDatagram13, SendOutOfOrderAppWithHandshakeKey) { auto spec = capturer.spec(0); ASSERT_NE(nullptr, spec.get()); ASSERT_EQ(2, spec->epoch()); - ASSERT_TRUE(client_->SendEncryptedRecord( - spec, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 0x0002000000000002, - kTlsApplicationDataType, DataBuffer(buf, sizeof(buf)))); + ASSERT_TRUE(client_->SendEncryptedRecord(spec, 0x0002000000000002, + kTlsApplicationDataType, + DataBuffer(buf, sizeof(buf)))); // Now have the server consume the bogus message. server_->ExpectSendAlert(illegal_parameter, kTlsAlertFatal); @@ -673,7 +677,7 @@ TEST_F(TlsDropDatagram13, SendOutOfOrderAppWithHandshakeKey) { EXPECT_EQ(SSL_ERROR_RX_UNKNOWN_RECORD_TYPE, PORT_GetError()); } -TEST_F(TlsDropDatagram13, SendOutOfOrderHsNonsenseWithHandshakeKey) { +TEST_P(TlsDropDatagram13, SendOutOfOrderHsNonsenseWithHandshakeKey) { StartConnect(); TlsSendCipherSpecCapturer capturer(client_); client_->Handshake(); @@ -688,9 +692,9 @@ TEST_F(TlsDropDatagram13, SendOutOfOrderHsNonsenseWithHandshakeKey) { auto spec = capturer.spec(0); ASSERT_NE(nullptr, spec.get()); ASSERT_EQ(2, spec->epoch()); - ASSERT_TRUE(client_->SendEncryptedRecord( - spec, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 0x0002000000000002, - kTlsHandshakeType, DataBuffer(buf, sizeof(buf)))); + ASSERT_TRUE(client_->SendEncryptedRecord(spec, 0x0002000000000002, + kTlsHandshakeType, + DataBuffer(buf, sizeof(buf)))); server_->Handshake(); EXPECT_EQ(2UL, server_filters_.ack_->count()); // The server acknowledges client Finished twice. @@ -700,7 +704,7 @@ TEST_F(TlsDropDatagram13, SendOutOfOrderHsNonsenseWithHandshakeKey) { // Shrink the MTU down so that certs get split and then swap the first and // second pieces of the server certificate. -TEST_F(TlsReorderDatagram13, ReorderServerCertificate) { +TEST_P(TlsReorderDatagram13, ReorderServerCertificate) { StartConnect(); ShrinkPostServerHelloMtu(); client_->Handshake(); @@ -722,7 +726,7 @@ TEST_F(TlsReorderDatagram13, ReorderServerCertificate) { CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); } -TEST_F(TlsReorderDatagram13, DataAfterEOEDDuringZeroRtt) { +TEST_P(TlsReorderDatagram13, DataAfterEOEDDuringZeroRtt) { SetupForZeroRtt(); SetFilters(); std::cerr << "Starting second handshake" << std::endl; @@ -761,7 +765,7 @@ TEST_F(TlsReorderDatagram13, DataAfterEOEDDuringZeroRtt) { EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); } -TEST_F(TlsReorderDatagram13, DataAfterFinDuringZeroRtt) { +TEST_P(TlsReorderDatagram13, DataAfterFinDuringZeroRtt) { SetupForZeroRtt(); SetFilters(); std::cerr << "Starting second handshake" << std::endl; @@ -812,12 +816,17 @@ static void GetCipherAndLimit(uint16_t version, uint16_t* cipher, *cipher = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256; *limit = (1ULL << 48) - 1; } else { + // This test probably isn't especially useful for TLS 1.3, which has a much + // shorter sequence number encoding. That space can probably be searched in + // a reasonable amount of time. *cipher = TLS_CHACHA20_POLY1305_SHA256; - *limit = (1ULL << 48) - 1; + // Assume that we are starting with an expected sequence number of 0. + *limit = (1ULL << 29) - 1; } } // This simulates a huge number of drops on one side. +// See Bug 12965514 where a large gap was handled very inefficiently. TEST_P(TlsConnectDatagram, MissLotsOfPackets) { uint16_t cipher; uint64_t limit; @@ -834,6 +843,17 @@ TEST_P(TlsConnectDatagram, MissLotsOfPackets) { SendReceive(); } +// Send a sequence number of 0xfffffffd and it should be interpreted as that +// (and not -3 or UINT64_MAX - 2). +TEST_F(TlsConnectDatagram13, UnderflowSequenceNumber) { + Connect(); + // This is only valid if short headers are disabled. + client_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, PR_FALSE); + EXPECT_EQ(SECSuccess, + SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), (1ULL << 30) - 3)); + SendReceive(); +} + class TlsConnectDatagram12Plus : public TlsConnectDatagram { public: TlsConnectDatagram12Plus() : TlsConnectDatagram() {} @@ -865,5 +885,11 @@ INSTANTIATE_TEST_CASE_P(Datagram12Plus, TlsConnectDatagram12Plus, TlsConnectTestBase::kTlsV12Plus); INSTANTIATE_TEST_CASE_P(DatagramPre13, TlsConnectDatagramPre13, TlsConnectTestBase::kTlsV11V12); +INSTANTIATE_TEST_CASE_P(DatagramDrop13, TlsDropDatagram13, + ::testing::Values(true, false)); +INSTANTIATE_TEST_CASE_P(DatagramReorder13, TlsReorderDatagram13, + ::testing::Values(true, false)); +INSTANTIATE_TEST_CASE_P(DatagramFragment13, TlsFragmentationAndRecoveryTest, + ::testing::Values(true, false)); } // namespace nss_test diff --git a/gtests/ssl_gtest/ssl_fragment_unittest.cc b/gtests/ssl_gtest/ssl_fragment_unittest.cc index f4940bf28..92947c2c7 100644 --- a/gtests/ssl_gtest/ssl_fragment_unittest.cc +++ b/gtests/ssl_gtest/ssl_fragment_unittest.cc @@ -20,14 +20,16 @@ namespace nss_test { // This class cuts every unencrypted handshake record into two parts. class RecordFragmenter : public PacketFilter { public: - RecordFragmenter() : sequence_number_(0), splitting_(true) {} + RecordFragmenter(bool is_dtls13) + : is_dtls13_(is_dtls13), sequence_number_(0), splitting_(true) {} private: class HandshakeSplitter { public: - HandshakeSplitter(const DataBuffer& input, DataBuffer* output, - uint64_t* sequence_number) - : input_(input), + HandshakeSplitter(bool is_dtls13, const DataBuffer& input, + DataBuffer* output, uint64_t* sequence_number) + : is_dtls13_(is_dtls13), + input_(input), output_(output), cursor_(0), sequence_number_(sequence_number) {} @@ -35,9 +37,9 @@ class RecordFragmenter : public PacketFilter { private: void WriteRecord(TlsRecordHeader& record_header, DataBuffer& record_fragment) { - TlsRecordHeader fragment_header(record_header.version(), - record_header.content_type(), - *sequence_number_); + TlsRecordHeader fragment_header( + record_header.variant(), record_header.version(), + record_header.content_type(), *sequence_number_); ++*sequence_number_; if (::g_ssl_gtest_verbose) { std::cerr << "Fragment: " << fragment_header << ' ' << record_fragment @@ -88,7 +90,7 @@ class RecordFragmenter : public PacketFilter { while (parser.remaining()) { TlsRecordHeader header; DataBuffer record; - if (!header.Parse(0, &parser, &record)) { + if (!header.Parse(is_dtls13_, 0, &parser, &record)) { ADD_FAILURE() << "bad record header"; return false; } @@ -118,6 +120,7 @@ class RecordFragmenter : public PacketFilter { } private: + bool is_dtls13_; const DataBuffer& input_; DataBuffer* output_; size_t cursor_; @@ -132,7 +135,7 @@ class RecordFragmenter : public PacketFilter { } output->Allocate(input.len()); - HandshakeSplitter splitter(input, output, &sequence_number_); + HandshakeSplitter splitter(is_dtls13_, input, output, &sequence_number_); if (!splitter.Split()) { // If splitting fails, we obviously reached encrypted packets. // Stop splitting from that point onward. @@ -144,18 +147,21 @@ class RecordFragmenter : public PacketFilter { } private: + bool is_dtls13_; uint64_t sequence_number_; bool splitting_; }; TEST_P(TlsConnectDatagram, FragmentClientPackets) { - client_->SetFilter(std::make_shared<RecordFragmenter>()); + bool is_dtls13 = version_ >= SSL_LIBRARY_VERSION_TLS_1_3; + client_->SetFilter(std::make_shared<RecordFragmenter>(is_dtls13)); Connect(); SendReceive(); } TEST_P(TlsConnectDatagram, FragmentServerPackets) { - server_->SetFilter(std::make_shared<RecordFragmenter>()); + bool is_dtls13 = version_ >= SSL_LIBRARY_VERSION_TLS_1_3; + server_->SetFilter(std::make_shared<RecordFragmenter>(is_dtls13)); Connect(); SendReceive(); } diff --git a/gtests/ssl_gtest/ssl_hrr_unittest.cc b/gtests/ssl_gtest/ssl_hrr_unittest.cc index ba4cd804d..c78b328d8 100644 --- a/gtests/ssl_gtest/ssl_hrr_unittest.cc +++ b/gtests/ssl_gtest/ssl_hrr_unittest.cc @@ -81,8 +81,9 @@ class CorrectMessageSeqAfterHrrFilter : public TlsRecordFilter { } DataBuffer buffer(record); - TlsRecordHeader new_header = {header.version(), header.content_type(), - header.sequence_number() + 1}; + TlsRecordHeader new_header(header.variant(), header.version(), + header.content_type(), + header.sequence_number() + 1); // Correct message_seq. buffer.Write(4, 1U, 2); diff --git a/gtests/ssl_gtest/ssl_loopback_unittest.cc b/gtests/ssl_gtest/ssl_loopback_unittest.cc index 2c292ae27..b4a74b99e 100644 --- a/gtests/ssl_gtest/ssl_loopback_unittest.cc +++ b/gtests/ssl_gtest/ssl_loopback_unittest.cc @@ -383,7 +383,8 @@ class TlsPreCCSHeaderInjector : public TlsRecordFilter { std::cerr << "Injecting Finished header before CCS\n"; const uint8_t hhdr[] = {kTlsHandshakeFinished, 0x00, 0x00, 0x0c}; DataBuffer hhdr_buf(hhdr, sizeof(hhdr)); - TlsRecordHeader nhdr(record_header.version(), kTlsHandshakeType, 0); + TlsRecordHeader nhdr(record_header.variant(), record_header.version(), + kTlsHandshakeType, 0); *offset = nhdr.Write(output, *offset, hhdr_buf); *offset = record_header.Write(output, *offset, input); return CHANGE; diff --git a/gtests/ssl_gtest/ssl_record_unittest.cc b/gtests/ssl_gtest/ssl_record_unittest.cc index e76dab488..97bbbba01 100644 --- a/gtests/ssl_gtest/ssl_record_unittest.cc +++ b/gtests/ssl_gtest/ssl_record_unittest.cc @@ -168,6 +168,29 @@ TEST_F(TlsConnectStreamTls13, TooLargeRecord) { EXPECT_EQ(SSL_ERROR_RECORD_OVERFLOW_ALERT, PORT_GetError()); } +class ShortHeaderChecker : public PacketFilter { + public: + PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output) { + // The first octet should be 0b001xxxxx. + EXPECT_EQ(1, input.data()[0] >> 5); + return KEEP; + } +}; + +TEST_F(TlsConnectDatagram13, ShortHeadersClient) { + Connect(); + client_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, PR_TRUE); + client_->SetFilter(std::make_shared<ShortHeaderChecker>()); + SendReceive(); +} + +TEST_F(TlsConnectDatagram13, ShortHeadersServer) { + Connect(); + server_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, PR_TRUE); + server_->SetFilter(std::make_shared<ShortHeaderChecker>()); + SendReceive(); +} + const static size_t kContentSizesArr[] = { 1, kMacSize - 1, kMacSize, 30, 31, 32, 36, 256, 257, 287, 288}; diff --git a/gtests/ssl_gtest/tls_agent.cc b/gtests/ssl_gtest/tls_agent.cc index dcff0df7d..084da0ab5 100644 --- a/gtests/ssl_gtest/tls_agent.cc +++ b/gtests/ssl_gtest/tls_agent.cc @@ -947,12 +947,13 @@ void TlsAgent::SendBuffer(const DataBuffer& buf) { } bool TlsAgent::SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec, - uint16_t wireVersion, uint64_t seq, - uint8_t ct, const DataBuffer& buf) { - LOGV("Writing " << buf.len() << " bytes"); - // Ensure we are a TLS 1.3 cipher agent. + uint64_t seq, uint8_t ct, + const DataBuffer& buf) { + LOGV("Encrypting " << buf.len() << " bytes"); + // Ensure that we are doing TLS 1.3. EXPECT_GE(expected_version_, SSL_LIBRARY_VERSION_TLS_1_3); - TlsRecordHeader header(wireVersion, kTlsApplicationDataType, seq); + TlsRecordHeader header(variant_, expected_version_, kTlsApplicationDataType, + seq); DataBuffer padded = buf; padded.Write(padded.len(), ct, 1); DataBuffer ciphertext; @@ -1074,15 +1075,20 @@ void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer, void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type, uint16_t version, const uint8_t* buf, size_t len, DataBuffer* out, - uint64_t seq_num) { + uint64_t sequence_number) { size_t index = 0; index = out->Write(index, type, 1); if (variant == ssl_variant_stream) { index = out->Write(index, version, 2); + } else if (version >= SSL_LIBRARY_VERSION_TLS_1_3 && + type == kTlsApplicationDataType) { + uint32_t epoch = (sequence_number >> 48) & 0x3; + uint32_t seqno = sequence_number & ((1ULL << 30) - 1); + index = out->Write(index, (epoch << 30) | seqno, 4); } else { index = out->Write(index, TlsVersionToDtlsVersion(version), 2); - index = out->Write(index, seq_num >> 32, 4); - index = out->Write(index, seq_num & PR_UINT32_MAX, 4); + index = out->Write(index, sequence_number >> 32, 4); + index = out->Write(index, sequence_number & PR_UINT32_MAX, 4); } index = out->Write(index, len, 2); out->Write(index, buf, len); diff --git a/gtests/ssl_gtest/tls_agent.h b/gtests/ssl_gtest/tls_agent.h index 6719f56e4..5ce5e6280 100644 --- a/gtests/ssl_gtest/tls_agent.h +++ b/gtests/ssl_gtest/tls_agent.h @@ -143,8 +143,7 @@ class TlsAgent : public PollTarget { void SendData(size_t bytes, size_t blocksize = 1024); void SendBuffer(const DataBuffer& buf); bool SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec, - uint16_t wireVersion, uint64_t seq, uint8_t ct, - const DataBuffer& buf); + uint64_t seq, uint8_t ct, const DataBuffer& buf); // Send data directly to the underlying socket, skipping the TLS layer. void SendDirect(const DataBuffer& buf); void SendRecordDirect(const TlsRecord& record); diff --git a/gtests/ssl_gtest/tls_filter.cc b/gtests/ssl_gtest/tls_filter.cc index e775ea6fc..10b1e31d0 100644 --- a/gtests/ssl_gtest/tls_filter.cc +++ b/gtests/ssl_gtest/tls_filter.cc @@ -30,11 +30,9 @@ void TlsVersioned::WriteStream(std::ostream& stream) const { case SSL_LIBRARY_VERSION_TLS_1_0: stream << "1.0"; break; - case SSL_LIBRARY_VERSION_DTLS_1_0_WIRE: case SSL_LIBRARY_VERSION_TLS_1_1: stream << (is_dtls() ? "1.0" : "1.1"); break; - case SSL_LIBRARY_VERSION_DTLS_1_2_WIRE: case SSL_LIBRARY_VERSION_TLS_1_2: stream << "1.2"; break; @@ -67,8 +65,14 @@ void TlsRecordFilter::CipherSpecChanged(void* arg, PRBool sending, return; } - self->in_sequence_number_ = 0; - self->out_sequence_number_ = 0; + uint64_t seq_no; + if (self->agent()->variant() == ssl_variant_datagram) { + seq_no = static_cast<uint64_t>(SSLInt_CipherSpecToEpoch(newSpec)) << 48; + } else { + seq_no = 0; + } + self->in_sequence_number_ = seq_no; + self->out_sequence_number_ = seq_no; self->dropped_record_ = false; self->cipher_spec_.reset(new TlsCipherSpec()); bool ret = self->cipher_spec_->Init( @@ -77,33 +81,59 @@ void TlsRecordFilter::CipherSpecChanged(void* arg, PRBool sending, EXPECT_EQ(true, ret); } +bool TlsRecordFilter::is_dtls13() const { + if (agent()->variant() != ssl_variant_datagram) { + return false; + } + if (agent()->state() == TlsAgent::STATE_CONNECTED) { + return agent()->version() >= SSL_LIBRARY_VERSION_TLS_1_3; + } + SSLPreliminaryChannelInfo info; + EXPECT_EQ(SECSuccess, SSL_GetPreliminaryChannelInfo(agent()->ssl_fd(), &info, + sizeof(info))); + return (info.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3) || + info.canSendEarlyData; +} + PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input, DataBuffer* output) { + // Disable during shutdown. + if (!agent()) { + return KEEP; + } + bool changed = false; size_t offset = 0U; - output->Allocate(input.len()); + output->Allocate(input.len()); TlsParser parser(input); while (parser.remaining()) { TlsRecordHeader header; DataBuffer record; - if (!header.Parse(in_sequence_number_, &parser, &record)) { + if (!header.Parse(is_dtls13(), in_sequence_number_, &parser, &record)) { ADD_FAILURE() << "not a valid record"; return KEEP; } - // Track the sequence number, which is necessary for stream mode (the - // sequence number is in the header for datagram). + // Track the sequence number, which is necessary for stream mode when + // decrypting and for TLS 1.3 datagram to recover the sequence number. // - // This isn't perfectly robust. If there is a change from an active cipher + // We reset the counter when the cipher spec changes, but that notification + // appears before a record is sent. If multiple records are sent with + // different cipher specs, this would fail. This filters out cleartext + // records, so we don't get confused by handshake messages that are sent at + // the same time as encrypted records. Sequence numbers are therefore + // likely to be incorrect for cleartext records. + // + // This isn't perfectly robust: if there is a change from an active cipher // spec to another active cipher spec (KeyUpdate for instance) AND writes - // are consolidated across that change AND packets were dropped from the - // older epoch, we will not correctly re-encrypt records in the old epoch to - // update their sequence numbers. - if (cipher_spec_ && header.content_type() == kTlsApplicationDataType) { - ++in_sequence_number_; + // are consolidated across that change, this code could use the wrong + // sequence numbers when re-encrypting records with the old keys. + if (header.content_type() == kTlsApplicationDataType) { + in_sequence_number_ = + (std::max)(in_sequence_number_, header.sequence_number() + 1); } if (FilterRecord(header, record, &offset, output) != KEEP) { @@ -131,11 +161,14 @@ PacketFilter::Action TlsRecordFilter::FilterRecord( DataBuffer plaintext; if (!Unprotect(header, record, &inner_content_type, &plaintext)) { + if (g_ssl_gtest_verbose) { + std::cerr << "unprotect failed: " << header << ":" << record << std::endl; + } return KEEP; } - TlsRecordHeader real_header = {header.version(), inner_content_type, - header.sequence_number()}; + TlsRecordHeader real_header(header.variant(), header.version(), + inner_content_type, header.sequence_number()); PacketFilter::Action action = FilterRecord(real_header, plaintext, &filtered); // In stream mode, even if something doesn't change we need to re-encrypt if @@ -166,8 +199,8 @@ PacketFilter::Action TlsRecordFilter::FilterRecord( } else { seq_num = out_sequence_number_++; } - TlsRecordHeader out_header = {header.version(), header.content_type(), - seq_num}; + TlsRecordHeader out_header(header.variant(), header.version(), + header.content_type(), seq_num); DataBuffer ciphertext; bool rv = Protect(out_header, inner_content_type, filtered, &ciphertext); @@ -179,20 +212,109 @@ PacketFilter::Action TlsRecordFilter::FilterRecord( return CHANGE; } -bool TlsRecordHeader::Parse(uint64_t seqno, TlsParser* parser, +size_t TlsRecordHeader::header_length() const { + if (!is_dtls()) { + return 5; + } + if (version() >= SSL_LIBRARY_VERSION_TLS_1_3 && + content_type_ == kTlsApplicationDataType) { + return 7; + } + return 13; +} + +uint64_t TlsRecordHeader::RecoverSequenceNumber(uint64_t expected, + uint32_t partial, + size_t partial_bits) { + EXPECT_GE(32U, partial_bits); + uint64_t mask = (1 << partial_bits) - 1; + // First we determine the highest possible value. This is half the + // expressible range above the expected value. + uint64_t cap = expected + (1ULL << (partial_bits - 1)); + // Add the partial piece in. e.g., xxxx789a and 1234 becomes xxxx1234. + uint64_t seq_no = (cap & ~mask) | partial; + // If the partial value is higher than the same partial piece from the cap, + // then the real value has to be lower. e.g., xxxx1234 can't become xxxx5678. + if (partial > (cap & mask)) { + seq_no -= 1ULL << partial_bits; + } + return seq_no; +} + +// Determine the full epoch and sequence number from an expected and raw value. +// The expected and output values are packed as they are in DTLS 1.2 and +// earlier: with 16 bits of epoch and 48 bits of sequence number. +uint64_t TlsRecordHeader::ParseSequenceNumber(uint64_t expected, uint32_t raw, + size_t seq_no_bits, + size_t epoch_bits) { + uint64_t epoch_mask = (1ULL << epoch_bits) - 1; + uint64_t epoch = RecoverSequenceNumber( + expected >> 48, (raw >> seq_no_bits) & epoch_mask, epoch_bits); + if (epoch > (expected >> 48)) { + // If the epoch has changed, reset the expected sequence number. + expected = 0; + } else { + // Otherwise, retain just the sequence number part. + expected &= (1ULL << 48) - 1; + } + uint64_t seq_no_mask = (1ULL << seq_no_bits) - 1; + uint64_t seq_no = + RecoverSequenceNumber(expected, raw & seq_no_mask, seq_no_bits); + return (epoch << 48) | seq_no; +} + +bool TlsRecordHeader::Parse(bool is_dtls13, uint64_t seqno, TlsParser* parser, DataBuffer* body) { if (!parser->Read(&content_type_)) { return false; } + if (is_dtls13) { + variant_ = ssl_variant_datagram; + version_ = SSL_LIBRARY_VERSION_TLS_1_3; + +#ifndef UNSAFE_FUZZER_MODE + // Deal with the 7 octet header. + if (content_type_ == kTlsApplicationDataType) { + uint32_t tmp; + if (!parser->Read(&tmp, 4)) { + return false; + } + sequence_number_ = ParseSequenceNumber(seqno, tmp, 30, 2); + return parser->ReadVariable(body, 2); + } + + // The short, 2 octet header. + if ((content_type_ & 0xe0) == 0x20) { + uint32_t tmp; + if (!parser->Read(&tmp, 1)) { + return false; + } + // Need to use the low 5 bits of the first octet too. + tmp |= (content_type_ & 0x1f) << 8; + content_type_ = kTlsApplicationDataType; + sequence_number_ = ParseSequenceNumber(seqno, tmp, 12, 1); + return parser->Read(body, parser->remaining()); + } + + // The full 13 octet header can only be used for a few types. + EXPECT_TRUE(content_type_ == kTlsAlertType || + content_type_ == kTlsHandshakeType || + content_type_ == kTlsAckType); +#endif + } + uint32_t ver; if (!parser->Read(&ver, 2)) { return false; } - version_ = ver; + if (!is_dtls13) { + variant_ = IsDtls(ver) ? ssl_variant_datagram : ssl_variant_stream; + } + version_ = NormalizeTlsVersion(ver); - // If this is DTLS, overwrite the sequence number. - if (IsDtls(ver)) { + if (is_dtls()) { + // If this is DTLS, read the sequence number. uint32_t tmp; if (!parser->Read(&tmp, 4)) { return false; @@ -211,11 +333,21 @@ bool TlsRecordHeader::Parse(uint64_t seqno, TlsParser* parser, size_t TlsRecordHeader::Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const { offset = buffer->Write(offset, content_type_, 1); - offset = buffer->Write(offset, version_, 2); - if (is_dtls()) { - // write epoch (2 octet), and seqnum (6 octet) - offset = buffer->Write(offset, sequence_number_ >> 32, 4); - offset = buffer->Write(offset, sequence_number_ & 0xffffffff, 4); + if (is_dtls() && version_ >= SSL_LIBRARY_VERSION_TLS_1_3 && + content_type() == kTlsApplicationDataType) { + // application_data records in TLS 1.3 have a different header format. + // Always use the long header here for simplicity. + uint32_t e = (sequence_number_ >> 48) & 0x3; + uint32_t seqno = sequence_number_ & ((1ULL << 30) - 1); + offset = buffer->Write(offset, (e << 30) | seqno, 4); + } else { + uint16_t v = is_dtls() ? TlsVersionToDtlsVersion(version_) : version_; + offset = buffer->Write(offset, v, 2); + if (is_dtls()) { + // write epoch (2 octet), and seqnum (6 octet) + offset = buffer->Write(offset, sequence_number_ >> 32, 4); + offset = buffer->Write(offset, sequence_number_ & 0xffffffff, 4); + } } offset = buffer->Write(offset, body.len(), 2); offset = buffer->Write(offset, body); @@ -406,6 +538,7 @@ bool TlsHandshakeFilter::HandshakeHeader::Parse( const DataBuffer& preceding_fragment, DataBuffer* body, bool* complete) { *complete = false; + variant_ = record_header.variant(); version_ = record_header.version(); if (!parser->Read(&handshake_type_)) { return false; // malformed diff --git a/gtests/ssl_gtest/tls_filter.h b/gtests/ssl_gtest/tls_filter.h index 5d1321f8f..296c0a9ee 100644 --- a/gtests/ssl_gtest/tls_filter.h +++ b/gtests/ssl_gtest/tls_filter.h @@ -11,7 +11,7 @@ #include <memory> #include <set> #include <vector> - +#include "sslt.h" #include "test_io.h" #include "tls_agent.h" #include "tls_parser.h" @@ -27,38 +27,47 @@ class TlsCipherSpec; class TlsVersioned { public: - TlsVersioned() : version_(0) {} - explicit TlsVersioned(uint16_t v) : version_(v) {} + TlsVersioned() : variant_(ssl_variant_stream), version_(0) {} + TlsVersioned(SSLProtocolVariant var, uint16_t ver) + : variant_(var), version_(ver) {} - bool is_dtls() const { return IsDtls(version_); } + bool is_dtls() const { return variant_ == ssl_variant_datagram; } + SSLProtocolVariant variant() const { return variant_; } uint16_t version() const { return version_; } void WriteStream(std::ostream& stream) const; protected: + SSLProtocolVariant variant_; uint16_t version_; }; class TlsRecordHeader : public TlsVersioned { public: TlsRecordHeader() : TlsVersioned(), content_type_(0), sequence_number_(0) {} - TlsRecordHeader(uint16_t ver, uint8_t ct, uint64_t seqno) - : TlsVersioned(ver), content_type_(ct), sequence_number_(seqno) {} + TlsRecordHeader(SSLProtocolVariant var, uint16_t ver, uint8_t ct, + uint64_t seqno) + : TlsVersioned(var, ver), content_type_(ct), sequence_number_(seqno) {} uint8_t content_type() const { return content_type_; } uint64_t sequence_number() const { return sequence_number_; } uint16_t epoch() const { return static_cast<uint16_t>(sequence_number_ >> 48); } - size_t header_length() const { return is_dtls() ? 13 : 5; } - + size_t header_length() const; // Parse the header; return true if successful; body in an outparam if OK. - bool Parse(uint64_t sequence_number, TlsParser* parser, DataBuffer* body); + bool Parse(bool is_dtls13, uint64_t sequence_number, TlsParser* parser, + DataBuffer* body); // Write the header and body to a buffer at the given offset. // Return the offset of the end of the write. size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const; private: + static uint64_t RecoverSequenceNumber(uint64_t expected, uint32_t partial, + size_t partial_bits); + static uint64_t ParseSequenceNumber(uint64_t expected, uint32_t raw, + size_t seq_no_bits, size_t epoch_bits); + uint8_t content_type_; uint64_t sequence_number_; }; @@ -127,6 +136,8 @@ class TlsRecordFilter : public PacketFilter { return KEEP; } + bool is_dtls13() const; + private: static void CipherSpecChanged(void* arg, PRBool sending, ssl3CipherSpec* newSpec); diff --git a/lib/ssl/dtls13con.c b/lib/ssl/dtls13con.c index aba0f62ab..04cbd373c 100644 --- a/lib/ssl/dtls13con.c +++ b/lib/ssl/dtls13con.c @@ -11,6 +11,43 @@ #include "sslimpl.h" #include "sslproto.h" +SECStatus +dtls13_InsertCipherTextHeader(const sslSocket *ss, ssl3CipherSpec *cwSpec, + sslBuffer *wrBuf, PRBool *needsLength) +{ + PRUint32 seq; + SECStatus rv; + + /* Avoid using short records for the handshake. We pack multiple records + * into the one datagram for the handshake. */ + if (ss->opt.enableDtlsShortHeader && + cwSpec->epoch != TrafficKeyHandshake) { + *needsLength = PR_FALSE; + /* The short header is comprised of two octets in the form + * 0b001essssssssssss where 'e' is the low bit of the epoch and 's' is + * the low 12 bits of the sequence number. */ + seq = 0x2000 | + (((uint64_t)cwSpec->epoch & 1) << 12) | + (cwSpec->nextSeqNum & 0xfff); + return sslBuffer_AppendNumber(wrBuf, seq, 2); + } + + rv = sslBuffer_AppendNumber(wrBuf, content_application_data, 1); + if (rv != SECSuccess) { + return SECFailure; + } + + /* The epoch and sequence number are encoded on 4 octets, with the epoch + * consuming the first two bits. */ + seq = (((uint64_t)cwSpec->epoch & 3) << 30) | (cwSpec->nextSeqNum & 0x3fffffff); + rv = sslBuffer_AppendNumber(wrBuf, seq, 4); + if (rv != SECSuccess) { + return SECFailure; + } + *needsLength = PR_TRUE; + return SECSuccess; +} + /* DTLS 1.3 Record map for ACK processing. * This represents a single fragment, so a record which includes * multiple fragments will have one entry for each fragment on the diff --git a/lib/ssl/dtls13con.h b/lib/ssl/dtls13con.h index bf14d3bd2..ca48ef363 100644 --- a/lib/ssl/dtls13con.h +++ b/lib/ssl/dtls13con.h @@ -9,6 +9,10 @@ #ifndef __dtls13con_h_ #define __dtls13con_h_ +SECStatus dtls13_InsertCipherTextHeader(const sslSocket *ss, + ssl3CipherSpec *cwSpec, + sslBuffer *wrBuf, + PRBool *needsLength); SECStatus dtls13_RememberFragment(sslSocket *ss, PRCList *list, PRUint32 sequence, PRUint32 offset, PRUint32 length, DTLSEpoch epoch, diff --git a/lib/ssl/dtlscon.c b/lib/ssl/dtlscon.c index 2f335f924..6c3d7d24c 100644 --- a/lib/ssl/dtlscon.c +++ b/lib/ssl/dtlscon.c @@ -776,7 +776,7 @@ dtls_FragmentHandshake(sslSocket *ss, DTLSQueuedMessage *msg) rv = dtls13_RememberFragment(ss, &ss->ssl3.hs.dtlsSentHandshake, msgSeq, fragmentOffset, fragmentLen, msg->cwSpec->epoch, - msg->cwSpec->seqNum); + msg->cwSpec->nextSeqNum); if (rv != SECSuccess) { return SECFailure; } @@ -1319,6 +1319,107 @@ DTLS_GetHandshakeTimeout(PRFileDesc *socket, PRIntervalTime *timeout) return SECSuccess; } +PRBool +dtls_IsLongHeader(SSL3ProtocolVersion version, PRUint8 firstOctet) +{ +#ifndef UNSAFE_FUZZER_MODE + return version < SSL_LIBRARY_VERSION_TLS_1_3 || + firstOctet == content_handshake || + firstOctet == content_ack || + firstOctet == content_alert; +#else + return PR_TRUE; +#endif +} + +DTLSEpoch +dtls_ReadEpoch(const ssl3CipherSpec *crSpec, const PRUint8 *hdr) +{ + DTLSEpoch epoch; + DTLSEpoch maxEpoch; + DTLSEpoch partial; + + if (dtls_IsLongHeader(crSpec->version, hdr[0])) { + return ((DTLSEpoch)hdr[3] << 8) | hdr[4]; + } + + /* A lot of how we recover the epoch here will depend on how we plan to + * manage KeyUpdate. In the case that we decide to install a new read spec + * as a KeyUpdate is handled, crSpec will always be the highest epoch we can + * possibly receive. That makes this easier to manage. */ + if ((hdr[0] & 0xe0) == 0x20) { + /* Use crSpec->epoch, or crSpec->epoch - 1 if the last bit differs. */ + if (((hdr[0] >> 4) & 1) == (crSpec->epoch & 1)) { + return crSpec->epoch; + } + return crSpec->epoch - 1; + } + + /* dtls_GatherData should ensure that this works. */ + PORT_Assert(hdr[0] == content_application_data); + + /* This uses the same method as is used to recover the sequence number in + * dtls_ReadSequenceNumber, except that the maximum value is set to the + * current epoch. */ + partial = hdr[1] >> 6; + maxEpoch = PR_MAX(crSpec->epoch, 3); + epoch = (maxEpoch & 0xfffc) | partial; + if (partial > (maxEpoch & 0x03)) { + epoch -= 4; + } + return epoch; +} + +static sslSequenceNumber +dtls_ReadSequenceNumber(const ssl3CipherSpec *spec, const PRUint8 *hdr) +{ + sslSequenceNumber cap; + sslSequenceNumber partial; + sslSequenceNumber seqNum; + sslSequenceNumber mask; + + if (dtls_IsLongHeader(spec->version, hdr[0])) { + static const unsigned int seqNumOffset = 5; /* type, version, epoch */ + static const unsigned int seqNumLength = 6; + sslReader r = SSL_READER(hdr + seqNumOffset, seqNumLength); + (void)sslRead_ReadNumber(&r, seqNumLength, &seqNum); + return seqNum; + } + + /* Only the least significant bits of the sequence number is available here. + * This recovers the value based on the next expected sequence number. + * + * This works by determining the maximum possible sequence number, which is + * half the range of possible values above the expected next value (the + * expected next value is in |spec->seqNum|). Then, the last part of the + * sequence number is replaced. If that causes the value to exceed the + * maximum, subtract an entire range. + */ + if ((hdr[0] & 0xe0) == 0x20) { + /* A 12-bit sequence number. */ + cap = spec->nextSeqNum + (1ULL << 11); + partial = (((sslSequenceNumber)hdr[0] & 0xf) << 8) | + (sslSequenceNumber)hdr[1]; + mask = (1ULL << 12) - 1; + } else { + /* A 30-bit sequence number. */ + cap = spec->nextSeqNum + (1ULL << 29); + partial = (((sslSequenceNumber)hdr[1] & 0x3f) << 24) | + ((sslSequenceNumber)hdr[2] << 16) | + ((sslSequenceNumber)hdr[3] << 8) | + (sslSequenceNumber)hdr[4]; + mask = (1ULL << 30) - 1; + } + seqNum = (cap & ~mask) | partial; + /* The second check prevents the value from underflowing if we get a large + * gap at the start of a connection, where this subtraction would cause the + * sequence number to wrap to near UINT64_MAX. */ + if ((partial > (cap & mask)) && (seqNum > mask)) { + seqNum -= mask + 1; + } + return seqNum; +} + /* * DTLS relevance checks: * Note that this code currently ignores all out-of-epoch packets, @@ -1336,7 +1437,7 @@ dtls_IsRelevant(sslSocket *ss, const ssl3CipherSpec *spec, const SSL3Ciphertext *cText, sslSequenceNumber *seqNumOut) { - sslSequenceNumber seqNum = cText->seq_num & RECORD_SEQ_MASK; + sslSequenceNumber seqNum = dtls_ReadSequenceNumber(spec, cText->hdr); if (dtls_RecordGetRecvd(&spec->recvdRecords, seqNum) != 0) { SSL_TRC(10, ("%d: SSL3[%d]: dtls_IsRelevant, rejecting " "potentially replayed packet", diff --git a/lib/ssl/dtlscon.h b/lib/ssl/dtlscon.h index d094380f8..45fc069b9 100644 --- a/lib/ssl/dtlscon.h +++ b/lib/ssl/dtlscon.h @@ -41,8 +41,10 @@ extern SSL3ProtocolVersion dtls_TLSVersionToDTLSVersion(SSL3ProtocolVersion tlsv); extern SSL3ProtocolVersion dtls_DTLSVersionToTLSVersion(SSL3ProtocolVersion dtlsv); +DTLSEpoch dtls_ReadEpoch(const ssl3CipherSpec *crSpec, const PRUint8 *hdr); extern PRBool dtls_IsRelevant(sslSocket *ss, const ssl3CipherSpec *spec, const SSL3Ciphertext *cText, sslSequenceNumber *seqNum); void dtls_ReceivedFirstMessageInFlight(sslSocket *ss); +PRBool dtls_IsLongHeader(SSL3ProtocolVersion version, PRUint8 firstOctet); #endif diff --git a/lib/ssl/ssl.h b/lib/ssl/ssl.h index 25aabbaa2..ad8ec0f8b 100644 --- a/lib/ssl/ssl.h +++ b/lib/ssl/ssl.h @@ -254,6 +254,17 @@ SSL_IMPORT PRFileDesc *DTLS_ImportFD(PRFileDesc *model, PRFileDesc *fd); * no effect for a server. This setting is ignored for DTLS. */ #define SSL_ENABLE_TLS13_COMPAT_MODE 35 +/* Enables the sending of DTLS records using the short (two octet) record + * header. Only do this if there are 2^10 or fewer packets in flight at a time; + * using this with a larger number of packets in flight could mean that packets + * are dropped if there is reordering. + * + * This applies to TLS 1.3 only. This is not a parameter that is negotiated + * during the TLS handshake. Unlike other socket options, this option can be + * changed after a handshake is complete. + */ +#define SSL_ENABLE_DTLS_SHORT_HEADER 36 + #ifdef SSL_DEPRECATED_FUNCTION /* Old deprecated function names */ SSL_IMPORT SECStatus SSL_Enable(PRFileDesc *fd, int option, PRIntn on); diff --git a/lib/ssl/ssl3con.c b/lib/ssl/ssl3con.c index df9d8cb6c..22fdaf5b1 100644 --- a/lib/ssl/ssl3con.c +++ b/lib/ssl/ssl3con.c @@ -1415,7 +1415,7 @@ ssl3_SetupPendingCipherSpec(sslSocket *ss, CipherSpecDirection direction, spec->macDef = ssl_GetMacDef(ss, suiteDef); spec->epoch = prev->epoch + 1; - spec->seqNum = 0; + spec->nextSeqNum = 0; if (IS_DTLS(ss) && direction == CipherSpecRead) { dtls_InitRecvdRecords(&spec->recvdRecords); } @@ -2004,6 +2004,7 @@ ssl3_MACEncryptRecord(ssl3CipherSpec *cwSpec, unsigned int ivLen = 0; unsigned char pseudoHeaderBuf[13]; sslBuffer pseudoHeader = SSL_BUFFER(pseudoHeaderBuf); + int len; if (cwSpec->cipherDef->type == type_block && cwSpec->version >= SSL_LIBRARY_VERSION_TLS_1_1) { @@ -2013,29 +2014,32 @@ ssl3_MACEncryptRecord(ssl3CipherSpec *cwSpec, * record. */ ivLen = cwSpec->cipherDef->iv_size; - if (ivLen > wrBuf->space) { + if (ivLen > SSL_BUFFER_SPACE(wrBuf)) { PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); return SECFailure; } - rv = PK11_GenerateRandom(wrBuf->buf, ivLen); + rv = PK11_GenerateRandom(SSL_BUFFER_NEXT(wrBuf), ivLen); if (rv != SECSuccess) { ssl_MapLowLevelError(SSL_ERROR_GENERATE_RANDOM_FAILURE); return rv; } rv = cwSpec->cipher(cwSpec->cipherContext, - wrBuf->buf, /* output */ - (int *)&wrBuf->len, /* outlen */ - ivLen, /* max outlen */ - wrBuf->buf, /* input */ - ivLen); /* input len */ - if (rv != SECSuccess || wrBuf->len != ivLen) { + SSL_BUFFER_NEXT(wrBuf), /* output */ + &len, /* outlen */ + ivLen, /* max outlen */ + SSL_BUFFER_NEXT(wrBuf), /* input */ + ivLen); /* input len */ + if (rv != SECSuccess || len != ivLen) { PORT_SetError(SSL_ERROR_ENCRYPTION_FAILURE); return SECFailure; } + + rv = sslBuffer_Skip(wrBuf, len, NULL); + PORT_Assert(rv == SECSuccess); /* Can't fail. */ } rv = ssl3_BuildRecordPseudoHeader( - cwSpec->epoch, cwSpec->seqNum, type, + cwSpec->epoch, cwSpec->nextSeqNum, type, cwSpec->version >= SSL_LIBRARY_VERSION_TLS_1_0, cwSpec->recordVersion, isDTLS, contentLen, &pseudoHeader); PORT_Assert(rv == SECSuccess); @@ -2043,23 +2047,26 @@ ssl3_MACEncryptRecord(ssl3CipherSpec *cwSpec, const int nonceLen = cwSpec->cipherDef->explicit_nonce_size; const int tagLen = cwSpec->cipherDef->tag_size; - if (nonceLen + contentLen + tagLen > wrBuf->space) { + if (nonceLen + contentLen + tagLen > SSL_BUFFER_SPACE(wrBuf)) { PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); return SECFailure; } rv = cwSpec->aead( &cwSpec->keyMaterial, - PR_FALSE, /* do encrypt */ - wrBuf->buf, /* output */ - (int *)&wrBuf->len, /* out len */ - wrBuf->space, /* max out */ - pIn, contentLen, /* input */ + PR_FALSE, /* do encrypt */ + SSL_BUFFER_NEXT(wrBuf), /* output */ + &len, /* out len */ + SSL_BUFFER_SPACE(wrBuf), /* max out */ + pIn, contentLen, /* input */ SSL_BUFFER_BASE(&pseudoHeader), SSL_BUFFER_LEN(&pseudoHeader)); if (rv != SECSuccess) { PORT_SetError(SSL_ERROR_ENCRYPTION_FAILURE); return SECFailure; } + + rv = sslBuffer_Skip(wrBuf, len, NULL); + PORT_Assert(rv == SECSuccess); /* Can't fail. */ } else { int blockSize = cwSpec->cipherDef->block_size; @@ -2069,7 +2076,7 @@ ssl3_MACEncryptRecord(ssl3CipherSpec *cwSpec, rv = ssl3_ComputeRecordMAC(cwSpec, SSL_BUFFER_BASE(&pseudoHeader), SSL_BUFFER_LEN(&pseudoHeader), pIn, contentLen, - wrBuf->buf + ivLen + contentLen, &macLen); + SSL_BUFFER_NEXT(wrBuf) + contentLen, &macLen); if (rv != SECSuccess) { ssl_MapLowLevelError(SSL_ERROR_MAC_COMPUTATION_FAILURE); return SECFailure; @@ -2095,7 +2102,7 @@ ssl3_MACEncryptRecord(ssl3CipherSpec *cwSpec, PORT_Assert((fragLen % blockSize) == 0); /* Pad according to TLS rules (also acceptable to SSL3). */ - pBuf = &wrBuf->buf[ivLen + fragLen - 1]; + pBuf = SSL_BUFFER_NEXT(wrBuf) + fragLen - 1; for (i = padding_length + 1; i > 0; --i) { *pBuf-- = padding_length; } @@ -2112,14 +2119,14 @@ ssl3_MACEncryptRecord(ssl3CipherSpec *cwSpec, p2Len += oddLen; PORT_Assert((blockSize < 2) || (p2Len % blockSize) == 0); - memmove(wrBuf->buf + ivLen + p1Len, pIn + p1Len, oddLen); + memmove(SSL_BUFFER_NEXT(wrBuf) + p1Len, pIn + p1Len, oddLen); } if (p1Len > 0) { int cipherBytesPart1 = -1; rv = cwSpec->cipher(cwSpec->cipherContext, - wrBuf->buf + ivLen, /* output */ - &cipherBytesPart1, /* actual outlen */ - p1Len, /* max outlen */ + SSL_BUFFER_NEXT(wrBuf), /* output */ + &cipherBytesPart1, /* actual outlen */ + p1Len, /* max outlen */ pIn, p1Len); /* input, and inputlen */ PORT_Assert(rv == SECSuccess && cipherBytesPart1 == (int)p1Len); @@ -2127,22 +2134,24 @@ ssl3_MACEncryptRecord(ssl3CipherSpec *cwSpec, PORT_SetError(SSL_ERROR_ENCRYPTION_FAILURE); return SECFailure; } - wrBuf->len += cipherBytesPart1; + rv = sslBuffer_Skip(wrBuf, p1Len, NULL); + PORT_Assert(rv == SECSuccess); } if (p2Len > 0) { int cipherBytesPart2 = -1; rv = cwSpec->cipher(cwSpec->cipherContext, - wrBuf->buf + ivLen + p1Len, + SSL_BUFFER_NEXT(wrBuf), &cipherBytesPart2, /* output and actual outLen */ p2Len, /* max outlen */ - wrBuf->buf + ivLen + p1Len, + SSL_BUFFER_NEXT(wrBuf), p2Len); /* input and inputLen*/ PORT_Assert(rv == SECSuccess && cipherBytesPart2 == (int)p2Len); if (rv != SECSuccess || cipherBytesPart2 != (int)p2Len) { PORT_SetError(SSL_ERROR_ENCRYPTION_FAILURE); return SECFailure; } - wrBuf->len += cipherBytesPart2; + rv = sslBuffer_Skip(wrBuf, p2Len, NULL); + PORT_Assert(rv == SECSuccess); } } @@ -2152,14 +2161,18 @@ ssl3_MACEncryptRecord(ssl3CipherSpec *cwSpec, /* Note: though this can report failure, it shouldn't. */ static SECStatus ssl_InsertRecordHeader(const sslSocket *ss, ssl3CipherSpec *cwSpec, - SSL3ContentType contentType, unsigned int len, - sslBuffer *wrBuf) + SSL3ContentType contentType, sslBuffer *wrBuf, + PRBool *needsLength) { SECStatus rv; #ifndef UNSAFE_FUZZER_MODE if (cwSpec->version >= SSL_LIBRARY_VERSION_TLS_1_3 && - cwSpec->cipherDef->calg != ssl_calg_null) { + cwSpec->epoch > TrafficKeyClearText) { + if (IS_DTLS(ss)) { + return dtls13_InsertCipherTextHeader(ss, cwSpec, wrBuf, + needsLength); + } contentType = content_application_data; } #endif @@ -2177,16 +2190,12 @@ ssl_InsertRecordHeader(const sslSocket *ss, ssl3CipherSpec *cwSpec, if (rv != SECSuccess) { return SECFailure; } - rv = sslBuffer_AppendNumber(wrBuf, cwSpec->seqNum, 6); + rv = sslBuffer_AppendNumber(wrBuf, cwSpec->nextSeqNum, 6); if (rv != SECSuccess) { return SECFailure; } } - rv = sslBuffer_AppendNumber(wrBuf, len, 2); - if (rv != SECSuccess) { - return SECFailure; - } - + *needsLength = PR_TRUE; return SECSuccess; } @@ -2194,66 +2203,67 @@ SECStatus ssl_ProtectRecord(sslSocket *ss, ssl3CipherSpec *cwSpec, SSL3ContentType type, const PRUint8 *pIn, PRUint32 contentLen, sslBuffer *wrBuf) { - unsigned int headerLen = IS_DTLS(ss) ? DTLS_RECORD_HEADER_LENGTH - : SSL3_RECORD_HEADER_LENGTH; - sslBuffer protBuf = SSL_BUFFER_FIXED(SSL_BUFFER_BASE(wrBuf) + headerLen, - SSL_BUFFER_SPACE(wrBuf) - headerLen); - PRBool isTLS13; + PRBool needsLength; + unsigned int lenOffset; SECStatus rv; PORT_Assert(cwSpec->direction == CipherSpecWrite); PORT_Assert(SSL_BUFFER_LEN(wrBuf) == 0); PORT_Assert(cwSpec->cipherDef->max_records <= RECORD_SEQ_MAX); - if (cwSpec->seqNum >= cwSpec->cipherDef->max_records) { + + if (cwSpec->nextSeqNum >= cwSpec->cipherDef->max_records) { /* We should have automatically updated before here in TLS 1.3. */ PORT_Assert(cwSpec->version < SSL_LIBRARY_VERSION_TLS_1_3); SSL_TRC(3, ("%d: SSL[-]: write sequence number at limit 0x%0llx", - SSL_GETPID(), cwSpec->seqNum)); + SSL_GETPID(), cwSpec->nextSeqNum)); PORT_SetError(SSL_ERROR_TOO_MANY_RECORDS); return SECFailure; } - isTLS13 = (PRBool)(cwSpec->version >= SSL_LIBRARY_VERSION_TLS_1_3); + rv = ssl_InsertRecordHeader(ss, cwSpec, type, wrBuf, &needsLength); + if (rv != SECSuccess) { + return SECFailure; + } + if (needsLength) { + rv = sslBuffer_Skip(wrBuf, 2, &lenOffset); + if (rv != SECSuccess) { + return SECFailure; + } + } #ifdef UNSAFE_FUZZER_MODE { int len; - rv = Null_Cipher(NULL, SSL_BUFFER_BASE(&protBuf), &len, - SSL_BUFFER_SPACE(&protBuf), pIn, contentLen); + rv = Null_Cipher(NULL, SSL_BUFFER_NEXT(wrBuf), &len, + SSL_BUFFER_SPACE(wrBuf), pIn, contentLen); if (rv != SECSuccess) { return SECFailure; /* error was set */ } - rv = sslBuffer_Skip(&protBuf, len, NULL); + rv = sslBuffer_Skip(wrBuf, len, NULL); PORT_Assert(rv == SECSuccess); /* Can't fail. */ } #else - if (isTLS13) { - rv = tls13_ProtectRecord(ss, cwSpec, type, pIn, contentLen, &protBuf); + if (cwSpec->version >= SSL_LIBRARY_VERSION_TLS_1_3) { + rv = tls13_ProtectRecord(ss, cwSpec, type, pIn, contentLen, wrBuf); } else { rv = ssl3_MACEncryptRecord(cwSpec, ss->sec.isServer, IS_DTLS(ss), type, - pIn, contentLen, &protBuf); + pIn, contentLen, wrBuf); } #endif if (rv != SECSuccess) { return SECFailure; /* error was set */ } - PORT_Assert(protBuf.len <= MAX_FRAGMENT_LENGTH + (isTLS13 ? 256 : 1024)); - - rv = ssl_InsertRecordHeader(ss, cwSpec, type, SSL_BUFFER_LEN(&protBuf), - wrBuf); - if (rv != SECSuccess) { - return SECFailure; - } - - PORT_Assert(SSL_BUFFER_LEN(wrBuf) == headerLen); - rv = sslBuffer_Skip(wrBuf, SSL_BUFFER_LEN(&protBuf), NULL); - if (rv != SECSuccess) { - PORT_Assert(0); /* Can't fail. */ - return SECFailure; + if (needsLength) { + /* Insert the length. */ + rv = sslBuffer_InsertLength(wrBuf, lenOffset, 2); + if (rv != SECSuccess) { + PORT_Assert(0); /* Can't fail. */ + return SECFailure; + } } - ++cwSpec->seqNum; + ++cwSpec->nextSeqNum; return SECSuccess; } @@ -2291,6 +2301,7 @@ ssl_ProtectNextRecord(sslSocket *ss, ssl3CipherSpec *spec, SSL3ContentType type, *written = contentLen; return SECSuccess; } + /* Process the plain text before sending it. * Returns the number of bytes of plaintext that were successfully sent * plus the number of bytes of plaintext that were copied into the @@ -2368,7 +2379,7 @@ ssl3_SendRecord(sslSocket *ss, rv = ssl_ProtectNextRecord(ss, spec, type, pIn, nIn, &written); ssl_ReleaseSpecReadLock(ss); if (rv != SECSuccess) { - return SECFailure; + goto loser; } PORT_Assert(written > 0); @@ -11847,6 +11858,7 @@ ssl3_UnprotectRecord(sslSocket *ss, unsigned int good; unsigned int ivLen = 0; SSL3ContentType rType; + SSL3ProtocolVersion rVersion; unsigned int minLength; unsigned int originalLen = 0; PRUint8 headerBuf[13]; @@ -11919,7 +11931,9 @@ ssl3_UnprotectRecord(sslSocket *ss, return SECFailure; } - rType = cText->type; + rType = (SSL3ContentType)cText->hdr[0]; + rVersion = ((SSL3ProtocolVersion)cText->hdr[1] << 8) | + (SSL3ProtocolVersion)cText->hdr[2]; if (cipher_def->type == type_aead) { /* XXX For many AEAD ciphers, the plaintext is shorter than the * ciphertext by a fixed byte count, but it is not true in general. @@ -11929,8 +11943,8 @@ ssl3_UnprotectRecord(sslSocket *ss, cText->buf->len - cipher_def->explicit_nonce_size - cipher_def->tag_size; rv = ssl3_BuildRecordPseudoHeader( - spec->epoch, IS_DTLS(ss) ? cText->seq_num : spec->seqNum, - rType, isTLS, cText->version, IS_DTLS(ss), decryptedLen, &header); + spec->epoch, cText->seqNum, + rType, isTLS, rVersion, IS_DTLS(ss), decryptedLen, &header); PORT_Assert(rv == SECSuccess); rv = spec->aead(&spec->keyMaterial, PR_TRUE, /* do decrypt */ @@ -11977,8 +11991,8 @@ ssl3_UnprotectRecord(sslSocket *ss, /* compute the MAC */ rv = ssl3_BuildRecordPseudoHeader( - spec->epoch, IS_DTLS(ss) ? cText->seq_num : spec->seqNum, - rType, isTLS, cText->version, IS_DTLS(ss), + spec->epoch, cText->seqNum, + rType, isTLS, rVersion, IS_DTLS(ss), plaintext->len - spec->macDef->mac_size, &header); PORT_Assert(rv == SECSuccess); if (cipher_def->type == type_block) { @@ -12028,13 +12042,19 @@ ssl3_UnprotectRecord(sslSocket *ss, return SECSuccess; } -static SECStatus +SECStatus ssl3_HandleNonApplicationData(sslSocket *ss, SSL3ContentType rType, DTLSEpoch epoch, sslSequenceNumber seqNum, sslBuffer *databuf) { SECStatus rv; + /* check for Token Presence */ + if (!ssl3_ClientAuthTokenPresent(ss->sec.ci.sid)) { + PORT_SetError(SSL_ERROR_TOKEN_INSERTION_REMOVAL); + return SECFailure; + } + ssl_GetSSL3HandshakeLock(ss); /* All the functions called in this switch MUST set error code if @@ -12080,15 +12100,16 @@ ssl3_HandleNonApplicationData(sslSocket *ss, SSL3ContentType rType, * Returns NULL if no appropriate cipher spec is found. */ static ssl3CipherSpec * -ssl3_GetCipherSpec(sslSocket *ss, sslSequenceNumber seq) +ssl3_GetCipherSpec(sslSocket *ss, SSL3Ciphertext *cText) { ssl3CipherSpec *crSpec = ss->ssl3.crSpec; ssl3CipherSpec *newSpec = NULL; - DTLSEpoch epoch = seq >> 48; + DTLSEpoch epoch; if (!IS_DTLS(ss)) { return crSpec; } + epoch = dtls_ReadEpoch(crSpec, cText->hdr); if (crSpec->epoch == epoch) { return crSpec; } @@ -12128,16 +12149,15 @@ ssl3_GetCipherSpec(sslSocket *ss, sslSequenceNumber seq) * Application Data records. */ SECStatus -ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText, sslBuffer *databuf) +ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText) { SECStatus rv; PRBool isTLS; DTLSEpoch epoch; - sslSequenceNumber seqNum = 0; ssl3CipherSpec *spec = NULL; PRBool outOfOrderSpec = PR_FALSE; SSL3ContentType rType; - sslBuffer *plaintext; + sslBuffer *plaintext = &ss->gs.buf; SSL3AlertDescription alert = internal_error; PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss)); @@ -12147,27 +12167,15 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText, sslBuffer *databuf) return SECFailure; } - /* cText is NULL when we're called from ssl3_RestartHandshakeAfterXXX(). - * This implies that databuf holds a previously deciphered SSL Handshake - * message. - */ - if (cText == NULL) { - SSL_DBG(("%d: SSL3[%d]: HandleRecord, resuming handshake", - SSL_GETPID(), ss->fd)); - /* Note that this doesn't pass the epoch and sequence number of the - * record through, which DTLS 1.3 depends on. DTLS doesn't support - * asynchronous certificate validation, so that should be OK. */ - PORT_Assert(!IS_DTLS(ss)); - return ssl3_HandleNonApplicationData(ss, content_handshake, - 0, 0, databuf); - } + /* Clear out the buffer in case this exits early. Any data then won't be + * processed twice. */ + plaintext->len = 0; ssl_GetSpecReadLock(ss); /******************************************/ - spec = ssl3_GetCipherSpec(ss, cText->seq_num); + spec = ssl3_GetCipherSpec(ss, cText); if (!spec) { PORT_Assert(IS_DTLS(ss)); ssl_ReleaseSpecReadLock(ss); /*****************************/ - databuf->len = 0; /* Needed to ensure data not left around */ return SECSuccess; } if (spec != ss->ssl3.crSpec) { @@ -12178,36 +12186,30 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText, sslBuffer *databuf) } isTLS = (PRBool)(spec->version > SSL_LIBRARY_VERSION_3_0); if (IS_DTLS(ss)) { - if (!dtls_IsRelevant(ss, spec, cText, &seqNum)) { + if (!dtls_IsRelevant(ss, spec, cText, &cText->seqNum)) { ssl_ReleaseSpecReadLock(ss); /*****************************/ - databuf->len = 0; /* Needed to ensure data not left around */ - return SECSuccess; } } else { - seqNum = spec->seqNum + 1; + cText->seqNum = spec->nextSeqNum; } - if (seqNum >= spec->cipherDef->max_records) { + if (cText->seqNum >= spec->cipherDef->max_records) { ssl_ReleaseSpecReadLock(ss); /*****************************/ SSL_TRC(3, ("%d: SSL[%d]: read sequence number at limit 0x%0llx", - SSL_GETPID(), ss->fd, seqNum)); + SSL_GETPID(), ss->fd, cText->seqNum)); PORT_SetError(SSL_ERROR_TOO_MANY_RECORDS); return SECFailure; } - plaintext = databuf; - plaintext->len = 0; /* filled in by Unprotect call below. */ - /* We're waiting for another ClientHello, which will appear unencrypted. * Use the content type to tell whether this is should be discarded. * * XXX If we decide to remove the content type from encrypted records, this * will become much more difficult to manage. */ if (ss->ssl3.hs.zeroRttIgnore == ssl_0rtt_ignore_hrr && - cText->type == content_application_data) { + cText->hdr[0] == content_application_data) { ssl_ReleaseSpecReadLock(ss); /*****************************/ PORT_Assert(ss->ssl3.hs.ws == wait_client_hello); - databuf->len = 0; return SECSuccess; } @@ -12224,6 +12226,7 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText, sslBuffer *databuf) } #ifdef UNSAFE_FUZZER_MODE + rType = cText->hdr[0]; rv = Null_Cipher(NULL, plaintext->buf, (int *)&plaintext->len, plaintext->space, cText->buf->buf, cText->buf->len); #else @@ -12233,9 +12236,10 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText, sslBuffer *databuf) if (spec->version < SSL_LIBRARY_VERSION_TLS_1_3 || spec->cipherDef->calg == ssl_calg_null) { /* Unencrypted TLS 1.3 records use the pre-TLS 1.3 format. */ + rType = cText->hdr[0]; rv = ssl3_UnprotectRecord(ss, spec, cText, plaintext, &alert); } else { - rv = tls13_UnprotectRecord(ss, spec, cText, plaintext, &alert); + rv = tls13_UnprotectRecord(ss, spec, cText, plaintext, &rType, &alert); } #endif @@ -12245,14 +12249,14 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText, sslBuffer *databuf) SSL_DBG(("%d: SSL3[%d]: decryption failed", SSL_GETPID(), ss->fd)); /* Ensure that we don't process this data again. */ - databuf->len = 0; + plaintext->len = 0; /* Ignore a CCS if the alternative handshake is negotiated. Note that * this will fail if the server fails to negotiate the alternative * handshake type in a 0-RTT session that is resumed from a session that * did negotiate it. We don't care about that corner case right now. */ if (ss->version >= SSL_LIBRARY_VERSION_TLS_1_3 && - cText->type == content_change_cipher_spec && + cText->hdr[0] == content_change_cipher_spec && ss->ssl3.hs.ws != idle_handshake && cText->buf->len == 1 && cText->buf->buf[0] == change_cipher_spec_choice) { @@ -12275,9 +12279,11 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText, sslBuffer *databuf) } /* SECSuccess */ - spec->seqNum = PR_MAX(spec->seqNum, seqNum); if (IS_DTLS(ss)) { - dtls_RecordSetRecvd(&spec->recvdRecords, seqNum); + dtls_RecordSetRecvd(&spec->recvdRecords, cText->seqNum); + spec->nextSeqNum = PR_MAX(spec->nextSeqNum, cText->seqNum + 1); + } else { + ++spec->nextSeqNum; } epoch = spec->epoch; @@ -12286,19 +12292,18 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText, sslBuffer *databuf) /* * The decrypted data is now in plaintext. */ - rType = cText->type; /* This must go after decryption because TLS 1.3 - * has encrypted content types. */ /* IMPORTANT: We are in DTLS 1.3 mode and we have processed something * from the wrong epoch. Divert to a divert processing function to make * sure we don't accidentally use the data unsafely. */ if (outOfOrderSpec) { PORT_Assert(IS_DTLS(ss) && ss->version >= SSL_LIBRARY_VERSION_TLS_1_3); - return dtls13_HandleOutOfEpochRecord(ss, spec, rType, databuf); + return dtls13_HandleOutOfEpochRecord(ss, spec, rType, plaintext); } /* Check the length of the plaintext. */ - if (isTLS && databuf->len > MAX_FRAGMENT_LENGTH) { + if (isTLS && plaintext->len > MAX_FRAGMENT_LENGTH) { + plaintext->len = 0; SSL3_SendAlert(ss, alert_fatal, record_overflow); PORT_SetError(SSL_ERROR_RX_RECORD_TOO_LONG); return SECFailure; @@ -12313,14 +12318,16 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText, sslBuffer *databuf) if (ss->version >= SSL_LIBRARY_VERSION_TLS_1_3 && ss->sec.isServer && ss->ssl3.hs.zeroRttState == ssl_0rtt_accepted) { - return tls13_HandleEarlyApplicationData(ss, databuf); + return tls13_HandleEarlyApplicationData(ss, plaintext); } + plaintext->len = 0; (void)SSL3_SendAlert(ss, alert_fatal, unexpected_message); PORT_SetError(SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA); return SECFailure; } - return ssl3_HandleNonApplicationData(ss, rType, epoch, seqNum, databuf); + return ssl3_HandleNonApplicationData(ss, rType, epoch, cText->seqNum, + plaintext); } /* diff --git a/lib/ssl/ssl3gthr.c b/lib/ssl/ssl3gthr.c index 8b323bb05..b0dd7315f 100644 --- a/lib/ssl/ssl3gthr.c +++ b/lib/ssl/ssl3gthr.c @@ -264,8 +264,9 @@ static int dtls_GatherData(sslSocket *ss, sslGather *gs, int flags) { int nb; - int err; - int rv = 1; + PRUint8 contentType; + unsigned int headerLen; + SECStatus rv; SSL_TRC(30, ("dtls_GatherData")); @@ -285,81 +286,96 @@ dtls_GatherData(sslSocket *ss, sslGather *gs, int flags) ** to 13 (the size of the record header). */ if (gs->dtlsPacket.space < MAX_FRAGMENT_LENGTH + 2048 + 13) { - err = sslBuffer_Grow(&gs->dtlsPacket, - MAX_FRAGMENT_LENGTH + 2048 + 13); - if (err) { /* realloc has set error code to no mem. */ - return err; + rv = sslBuffer_Grow(&gs->dtlsPacket, + MAX_FRAGMENT_LENGTH + 2048 + 13); + if (rv != SECSuccess) { + return -1; /* Code already set. */ } } /* recv() needs to read a full datagram at a time */ nb = ssl_DefRecv(ss, gs->dtlsPacket.buf, gs->dtlsPacket.space, flags); - if (nb > 0) { PRINT_BUF(60, (ss, "raw gather data:", gs->dtlsPacket.buf, nb)); } else if (nb == 0) { /* EOF */ SSL_TRC(30, ("%d: SSL3[%d]: EOF", SSL_GETPID(), ss->fd)); - rv = 0; - return rv; + return 0; } else /* if (nb < 0) */ { SSL_DBG(("%d: SSL3[%d]: recv error %d", SSL_GETPID(), ss->fd, PR_GetError())); - rv = SECFailure; - return rv; + return -1; } gs->dtlsPacket.len = nb; } + contentType = gs->dtlsPacket.buf[gs->dtlsPacketOffset]; + if (dtls_IsLongHeader(ss->version, contentType)) { + headerLen = 13; + } else if (contentType == content_application_data) { + headerLen = 7; + } else if ((contentType & 0xe0) == 0x20) { + headerLen = 2; + } else { + SSL_DBG(("%d: SSL3[%d]: invalid first octet (%d) for DTLS", + SSL_GETPID(), ss->fd, contentType)); + PORT_SetError(SSL_ERROR_RX_UNKNOWN_RECORD_TYPE); + gs->dtlsPacketOffset = 0; + gs->dtlsPacket.len = 0; + return -1; + } + /* At this point we should have >=1 complete records lined up in * dtlsPacket. Read off the header. */ - if ((gs->dtlsPacket.len - gs->dtlsPacketOffset) < 13) { + if ((gs->dtlsPacket.len - gs->dtlsPacketOffset) < headerLen) { SSL_DBG(("%d: SSL3[%d]: rest of DTLS packet " "too short to contain header", SSL_GETPID(), ss->fd)); - PR_SetError(PR_WOULD_BLOCK_ERROR, 0); + PORT_SetError(PR_WOULD_BLOCK_ERROR); gs->dtlsPacketOffset = 0; gs->dtlsPacket.len = 0; - rv = SECFailure; - return rv; + return -1; } - memcpy(gs->hdr, gs->dtlsPacket.buf + gs->dtlsPacketOffset, 13); - gs->dtlsPacketOffset += 13; + memcpy(gs->hdr, SSL_BUFFER_BASE(&gs->dtlsPacket) + gs->dtlsPacketOffset, + headerLen); + gs->dtlsPacketOffset += headerLen; /* Have received SSL3 record header in gs->hdr. */ - gs->remainder = (gs->hdr[11] << 8) | gs->hdr[12]; + if (headerLen == 13) { + gs->remainder = (gs->hdr[11] << 8) | gs->hdr[12]; + } else if (headerLen == 7) { + gs->remainder = (gs->hdr[5] << 8) | gs->hdr[6]; + } else { + PORT_Assert(headerLen = 2); + gs->remainder = gs->dtlsPacket.len - gs->dtlsPacketOffset; + } if ((gs->dtlsPacket.len - gs->dtlsPacketOffset) < gs->remainder) { SSL_DBG(("%d: SSL3[%d]: rest of DTLS packet too short " "to contain rest of body", SSL_GETPID(), ss->fd)); - PR_SetError(PR_WOULD_BLOCK_ERROR, 0); + PORT_SetError(PR_WOULD_BLOCK_ERROR); gs->dtlsPacketOffset = 0; gs->dtlsPacket.len = 0; - rv = SECFailure; - return rv; + return -1; } /* OK, we have at least one complete packet, copy into inbuf */ - if (gs->remainder > gs->inbuf.space) { - err = sslBuffer_Grow(&gs->inbuf, gs->remainder); - if (err) { /* realloc has set error code to no mem. */ - return err; - } + gs->inbuf.len = 0; + rv = sslBuffer_Append(&gs->inbuf, + SSL_BUFFER_BASE(&gs->dtlsPacket) + gs->dtlsPacketOffset, + gs->remainder); + if (rv != SECSuccess) { + return -1; /* code already set. */ } - - SSL_TRC(20, ("%d: SSL3[%d]: dtls gathered record type=%d len=%d", - SSL_GETPID(), ss->fd, gs->hdr[0], gs->inbuf.len)); - - memcpy(gs->inbuf.buf, gs->dtlsPacket.buf + gs->dtlsPacketOffset, - gs->remainder); - gs->inbuf.len = gs->remainder; gs->offset = gs->remainder; gs->dtlsPacketOffset += gs->remainder; gs->state = GS_INIT; + SSL_TRC(20, ("%d: SSL3[%d]: dtls gathered record type=%d len=%d", + SSL_GETPID(), ss->fd, contentType, gs->inbuf.len)); return 1; } @@ -442,7 +458,11 @@ ssl3_GatherCompleteHandshake(sslSocket *ss, int flags) * We need to process it now before we overwrite it with the next * handshake record. */ - rv = ssl3_HandleRecord(ss, NULL, &ss->gs.buf); + SSL_DBG(("%d: SSL3[%d]: resuming handshake", + SSL_GETPID(), ss->fd)); + PORT_Assert(!IS_DTLS(ss)); + rv = ssl3_HandleNonApplicationData(ss, content_handshake, + 0, 0, &ss->gs.buf); } else { /* State for SSLv2 client hello support. */ ssl2Gather ssl2gs = { PR_FALSE, 0 }; @@ -495,20 +515,13 @@ ssl3_GatherCompleteHandshake(sslSocket *ss, int flags) * If it's application data, ss->gs.buf will not be empty upon return. * If it's a change cipher spec, alert, or handshake message, * ss->gs.buf.len will be 0 when ssl3_HandleRecord returns SECSuccess. + * + * cText only needs to be valid for this next function call, so + * it can borrow gs.hdr. */ - cText.type = (SSL3ContentType)ss->gs.hdr[0]; - cText.version = (ss->gs.hdr[1] << 8) | ss->gs.hdr[2]; - - if (IS_DTLS(ss)) { - sslSequenceNumber seq_num; - - /* DTLS sequence number */ - PORT_Memcpy(&seq_num, &ss->gs.hdr[3], sizeof(seq_num)); - cText.seq_num = PR_ntohll(seq_num); - } - + cText.hdr = ss->gs.hdr; cText.buf = &ss->gs.inbuf; - rv = ssl3_HandleRecord(ss, &cText, &ss->gs.buf); + rv = ssl3_HandleRecord(ss, &cText); } } if (rv < 0) { @@ -520,7 +533,6 @@ ssl3_GatherCompleteHandshake(sslSocket *ss, int flags) * completing any renegotiation handshake we may be doing. */ PORT_Assert(ss->firstHsDone); - PORT_Assert(cText.type == content_application_data); break; } diff --git a/lib/ssl/sslimpl.h b/lib/ssl/sslimpl.h index 10d0333d9..799049c32 100644 --- a/lib/ssl/sslimpl.h +++ b/lib/ssl/sslimpl.h @@ -261,6 +261,7 @@ typedef struct sslOptionsStr { unsigned int requireDHENamedGroups : 1; unsigned int enable0RttData : 1; unsigned int enableTls13CompatMode : 1; + unsigned int enableDtlsShortHeader : 1; } sslOptions; typedef enum { sslHandshakingUndetermined = 0, @@ -780,9 +781,11 @@ struct ssl3StateStr { #define IS_DTLS(ss) (ss->protocolVariant == ssl_variant_datagram) typedef struct { - SSL3ContentType type; - SSL3ProtocolVersion version; - sslSequenceNumber seq_num; /* DTLS only */ + /* |seqNum| eventually contains the reconstructed sequence number. */ + sslSequenceNumber seqNum; + /* The header of the cipherText. */ + const PRUint8 *hdr; + /* |buf| is the payload of the ciphertext. */ sslBuffer *buf; } SSL3Ciphertext; @@ -1375,8 +1378,11 @@ SECStatus ssl3_SendClientHello(sslSocket *ss, sslClientHelloType type); /* * input into the SSL3 machinery from the actualy network reading code */ -SECStatus ssl3_HandleRecord( - sslSocket *ss, SSL3Ciphertext *cipher, sslBuffer *out); +SECStatus ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cipher); +SECStatus ssl3_HandleNonApplicationData(sslSocket *ss, SSL3ContentType rType, + DTLSEpoch epoch, + sslSequenceNumber seqNum, + sslBuffer *databuf); SECStatus ssl_RemoveTLSCBCPadding(sslBuffer *plaintext, unsigned int macSize); int ssl3_GatherAppDataRecord(sslSocket *ss, int flags); diff --git a/lib/ssl/sslsecur.c b/lib/ssl/sslsecur.c index f09ec067c..d3424a7ad 100644 --- a/lib/ssl/sslsecur.c +++ b/lib/ssl/sslsecur.c @@ -791,7 +791,7 @@ tls13_CheckKeyUpdate(sslSocket *ss, CipherSpecDirection dir) spec = ss->ssl3.cwSpec; margin = spec->cipherDef->max_records / 4; } - seqNum = spec->seqNum; + seqNum = spec->nextSeqNum; keyUpdate = seqNum > spec->cipherDef->max_records - margin; ssl_ReleaseSpecReadLock(ss); if (!keyUpdate) { diff --git a/lib/ssl/sslsock.c b/lib/ssl/sslsock.c index e08d5e232..f5d829d4e 100644 --- a/lib/ssl/sslsock.c +++ b/lib/ssl/sslsock.c @@ -81,7 +81,8 @@ static sslOptions ssl_defaults = { .enableSignedCertTimestamps = PR_FALSE, .requireDHENamedGroups = PR_FALSE, .enable0RttData = PR_FALSE, - .enableTls13CompatMode = PR_FALSE + .enableTls13CompatMode = PR_FALSE, + .enableDtlsShortHeader = PR_FALSE }; /* @@ -807,6 +808,10 @@ SSL_OptionSet(PRFileDesc *fd, PRInt32 which, PRIntn val) ss->opt.enableTls13CompatMode = val; break; + case SSL_ENABLE_DTLS_SHORT_HEADER: + ss->opt.enableDtlsShortHeader = val; + break; + default: PORT_SetError(SEC_ERROR_INVALID_ARGS); rv = SECFailure; @@ -943,6 +948,9 @@ SSL_OptionGet(PRFileDesc *fd, PRInt32 which, PRIntn *pVal) case SSL_ENABLE_TLS13_COMPAT_MODE: val = ss->opt.enableTls13CompatMode; break; + case SSL_ENABLE_DTLS_SHORT_HEADER: + val = ss->opt.enableDtlsShortHeader; + break; default: PORT_SetError(SEC_ERROR_INVALID_ARGS); rv = SECFailure; @@ -1063,6 +1071,9 @@ SSL_OptionGetDefault(PRInt32 which, PRIntn *pVal) case SSL_ENABLE_TLS13_COMPAT_MODE: val = ssl_defaults.enableTls13CompatMode; break; + case SSL_ENABLE_DTLS_SHORT_HEADER: + val = ssl_defaults.enableDtlsShortHeader; + break; default: PORT_SetError(SEC_ERROR_INVALID_ARGS); rv = SECFailure; @@ -1246,6 +1257,10 @@ SSL_OptionSetDefault(PRInt32 which, PRIntn val) ssl_defaults.enableTls13CompatMode = val; break; + case SSL_ENABLE_DTLS_SHORT_HEADER: + ssl_defaults.enableDtlsShortHeader = val; + break; + default: PORT_SetError(SEC_ERROR_INVALID_ARGS); return SECFailure; diff --git a/lib/ssl/sslspec.h b/lib/ssl/sslspec.h index 729ac1006..207bd6ef6 100644 --- a/lib/ssl/sslspec.h +++ b/lib/ssl/sslspec.h @@ -162,7 +162,9 @@ struct ssl3CipherSpecStr { DTLSEpoch epoch; const char *phase; - sslSequenceNumber seqNum; + + /* The next sequence number to be sent or received. */ + sslSequenceNumber nextSeqNum; DTLSRecvdRecords recvdRecords; /* The number of 0-RTT bytes that can be sent or received in TLS 1.3. This diff --git a/lib/ssl/tls13con.c b/lib/ssl/tls13con.c index c06acc83a..ed0bb5ecb 100644 --- a/lib/ssl/tls13con.c +++ b/lib/ssl/tls13con.c @@ -792,7 +792,7 @@ tls13_HandleKeyUpdate(sslSocket *ss, PRUint8 *b, unsigned int length) /* Only send an update if we have sent with the current spec. This * prevents us from being forced to crank forward pointlessly. */ ssl_GetSpecReadLock(ss); - sendUpdate = ss->ssl3.cwSpec->seqNum > 0; + sendUpdate = ss->ssl3.cwSpec->nextSeqNum > 0; ssl_ReleaseSpecReadLock(ss); } else { sendUpdate = PR_TRUE; @@ -1620,7 +1620,7 @@ tls13_HandleClientHelloPart2(sslSocket *ss, ssl_GetSpecWriteLock(ss); /* Increase the write sequence number. The read sequence number * will be reset after this to early data or handshake. */ - ss->ssl3.cwSpec->seqNum = 1; + ss->ssl3.cwSpec->nextSeqNum = 1; ssl_ReleaseSpecWriteLock(ss); } @@ -2007,7 +2007,7 @@ tls13_SendHelloRetryRequest(sslSocket *ss, /* We depend on this being exactly one record and one message. */ PORT_Assert(!IS_DTLS(ss) || (ss->ssl3.hs.sendMessageSeq == 1 && - ss->ssl3.cwSpec->seqNum == 1)); + ss->ssl3.cwSpec->nextSeqNum == 1)); ssl_ReleaseXmitBufLock(ss); ss->ssl3.hs.helloRetry = PR_TRUE; @@ -3316,7 +3316,7 @@ tls13_SetCipherSpec(sslSocket *ss, PRUint16 epoch, return SECFailure; } spec->epoch = epoch; - spec->seqNum = 0; + spec->nextSeqNum = 0; if (IS_DTLS(ss)) { dtls_InitRecvdRecords(&spec->recvdRecords); } @@ -4843,43 +4843,48 @@ tls13_ProtectRecord(sslSocket *ss, PORT_Assert(cwSpec->direction == CipherSpecWrite); SSL_TRC(3, ("%d: TLS13[%d]: spec=%d epoch=%d (%s) protect 0x%0llx len=%u", SSL_GETPID(), ss->fd, cwSpec, cwSpec->epoch, cwSpec->phase, - cwSpec->seqNum, contentLen)); + cwSpec->nextSeqNum, contentLen)); - if (contentLen + 1 + tagLen > wrBuf->space) { + if (contentLen + 1 + tagLen > SSL_BUFFER_SPACE(wrBuf)) { PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); return SECFailure; } /* Copy the data into the wrBuf. We're going to encrypt in-place * in the AEAD branch anyway */ - PORT_Memcpy(wrBuf->buf, pIn, contentLen); + PORT_Memcpy(SSL_BUFFER_NEXT(wrBuf), pIn, contentLen); if (cipher_def->calg == ssl_calg_null) { /* Shortcut for plaintext */ - wrBuf->len = contentLen; + rv = sslBuffer_Skip(wrBuf, contentLen, NULL); + PORT_Assert(rv == SECSuccess); } else { PRUint8 aad[8]; + int len; PORT_Assert(cipher_def->type == type_aead); /* Add the content type at the end. */ - wrBuf->buf[contentLen] = type; + *(SSL_BUFFER_NEXT(wrBuf) + contentLen) = type; rv = tls13_FormatAdditionalData(ss, aad, sizeof(aad), cwSpec->epoch, - cwSpec->seqNum); + cwSpec->nextSeqNum); if (rv != SECSuccess) { return SECFailure; } rv = cwSpec->aead(&cwSpec->keyMaterial, - PR_FALSE, /* do encrypt */ - wrBuf->buf, /* output */ - (int *)&wrBuf->len, /* out len */ - wrBuf->space, /* max out */ - wrBuf->buf, contentLen + 1, /* input */ + PR_FALSE, /* do encrypt */ + SSL_BUFFER_NEXT(wrBuf), /* output */ + &len, /* out len */ + SSL_BUFFER_SPACE(wrBuf), /* max out */ + SSL_BUFFER_NEXT(wrBuf), /* input */ + contentLen + 1, /* input len */ aad, sizeof(aad)); if (rv != SECSuccess) { PORT_SetError(SSL_ERROR_ENCRYPTION_FAILURE); return SECFailure; } + rv = sslBuffer_Skip(wrBuf, len, NULL); + PORT_Assert(rv == SECSuccess); } return SECSuccess; @@ -4897,25 +4902,21 @@ tls13_ProtectRecord(sslSocket *ss, SECStatus tls13_UnprotectRecord(sslSocket *ss, ssl3CipherSpec *spec, - SSL3Ciphertext *cText, sslBuffer *plaintext, + SSL3Ciphertext *cText, + sslBuffer *plaintext, + SSL3ContentType *innerType, SSL3AlertDescription *alert) { const ssl3BulkCipherDef *cipher_def = spec->cipherDef; - sslSequenceNumber seqNum; PRUint8 aad[8]; SECStatus rv; *alert = bad_record_mac; /* Default alert for most issues. */ PORT_Assert(spec->direction == CipherSpecRead); - if (IS_DTLS(ss)) { - seqNum = cText->seq_num & RECORD_SEQ_MASK; - } else { - seqNum = spec->seqNum; - } SSL_TRC(3, ("%d: TLS13[%d]: spec=%d epoch=%d (%s) unprotect 0x%0llx len=%u", - SSL_GETPID(), ss->fd, spec, spec->epoch, spec->phase, seqNum, - cText->buf->len)); + SSL_GETPID(), ss->fd, spec, spec->epoch, spec->phase, + cText->seqNum, cText->buf->len)); /* We can perform this test in variable time because the record's total * length and the ciphersuite are both public knowledge. */ @@ -4927,28 +4928,37 @@ tls13_UnprotectRecord(sslSocket *ss, return SECFailure; } - /* Verify that the content type is right, even though we overwrite it. */ - if (cText->type != content_application_data) { + /* Verify that the content type is right, even though we overwrite it. + * Also allow the DTLS short header in TLS 1.3. */ + if (!(cText->hdr[0] == content_application_data || + (IS_DTLS(ss) && + ss->version >= SSL_LIBRARY_VERSION_TLS_1_3 && + (cText->hdr[0] & 0xe0) == 0x20))) { SSL_TRC(3, - ("%d: TLS13[%d]: record has invalid exterior content type=%d", - SSL_GETPID(), ss->fd, cText->type)); + ("%d: TLS13[%d]: record has invalid exterior type=%2.2x", + SSL_GETPID(), ss->fd, cText->hdr[0])); /* Do we need a better error here? */ PORT_SetError(SSL_ERROR_BAD_MAC_READ); return SECFailure; } - /* Check the version number in the record. */ - if (cText->version != spec->recordVersion) { - /* Do we need a better error here? */ - SSL_TRC(3, - ("%d: TLS13[%d]: record has bogus version", - SSL_GETPID(), ss->fd)); - return SECFailure; + /* Check the version number in the record. Stream only. */ + if (!IS_DTLS(ss)) { + SSL3ProtocolVersion version = + ((SSL3ProtocolVersion)cText->hdr[1] << 8) | + (SSL3ProtocolVersion)cText->hdr[2]; + if (version != spec->recordVersion) { + /* Do we need a better error here? */ + SSL_TRC(3, ("%d: TLS13[%d]: record has bogus version", + SSL_GETPID(), ss->fd)); + return SECFailure; + } } /* Decrypt */ PORT_Assert(cipher_def->type == type_aead); - rv = tls13_FormatAdditionalData(ss, aad, sizeof(aad), spec->epoch, seqNum); + rv = tls13_FormatAdditionalData(ss, aad, sizeof(aad), spec->epoch, + cText->seqNum); if (rv != SECSuccess) { return SECFailure; } @@ -4977,9 +4987,7 @@ tls13_UnprotectRecord(sslSocket *ss, /* Bogus padding. */ if (plaintext->len < 1) { - SSL_TRC(3, - ("%d: TLS13[%d]: empty record", - SSL_GETPID(), ss->fd, cText->type)); + SSL_TRC(3, ("%d: TLS13[%d]: empty record", SSL_GETPID(), ss->fd)); /* It's safe to report this specifically because it happened * after the MAC has been verified. */ PORT_SetError(SSL_ERROR_BAD_BLOCK_PADDING); @@ -4987,12 +4995,12 @@ tls13_UnprotectRecord(sslSocket *ss, } /* Record the type. */ - cText->type = plaintext->buf[plaintext->len - 1]; + *innerType = (SSL3ContentType)plaintext->buf[plaintext->len - 1]; --plaintext->len; /* Check that we haven't received too much 0-RTT data. */ if (spec->epoch == TrafficKeyEarlyApplicationData && - cText->type == content_application_data) { + *innerType == content_application_data) { if (plaintext->len > spec->earlyDataRemaining) { *alert = unexpected_message; PORT_SetError(SSL_ERROR_TOO_MUCH_EARLY_DATA); @@ -5002,9 +5010,8 @@ tls13_UnprotectRecord(sslSocket *ss, } SSL_TRC(10, - ("%d: TLS13[%d]: %s received record of length=%d type=%d", - SSL_GETPID(), ss->fd, SSL_ROLE(ss), - plaintext->len, cText->type)); + ("%d: TLS13[%d]: %s received record of length=%d, type=%d", + SSL_GETPID(), ss->fd, SSL_ROLE(ss), plaintext->len, *innerType)); return SECSuccess; } diff --git a/lib/ssl/tls13con.h b/lib/ssl/tls13con.h index 1aaffb651..7ef1fabc2 100644 --- a/lib/ssl/tls13con.h +++ b/lib/ssl/tls13con.h @@ -28,6 +28,7 @@ typedef enum { SECStatus tls13_UnprotectRecord( sslSocket *ss, ssl3CipherSpec *spec, SSL3Ciphertext *cText, sslBuffer *plaintext, + SSL3ContentType *innerType, SSL3AlertDescription *alert); #if defined(WIN32) |