summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Jacobs <kjacobs@mozilla.com>2020-01-06 21:26:20 +0000
committerKevin Jacobs <kjacobs@mozilla.com>2020-01-06 21:26:20 +0000
commit1a9015776d73205f7808c27a96dc47f1637bc3f7 (patch)
tree3af5a6045c9265c3f68da81367375f97978b568c
parent8ee7dfd77a639eb627b61d125ba638ce4252fc6a (diff)
downloadnss-hg-1a9015776d73205f7808c27a96dc47f1637bc3f7.tar.gz
Bug 1599514 - Update DTLS 1.3 support to draft-30 r=mt
This patch updates the DTLS 1.3 implementation to draft version 30, including unified header format and sequence number encryption. Also added are new `SSL_CreateMask` experimental functions. Differential Revision: https://phabricator.services.mozilla.com/D51014
-rw-r--r--cpputil/databuffer.h1
-rw-r--r--cpputil/scoped_ptrs_ssl.h2
-rw-r--r--cpputil/tls_parser.h5
-rw-r--r--gtests/ssl_gtest/manifest.mn3
-rw-r--r--gtests/ssl_gtest/ssl_aead_unittest.cc (renamed from gtests/ssl_gtest/ssl_primitive_unittest.cc)6
-rw-r--r--gtests/ssl_gtest/ssl_ciphersuite_unittest.cc29
-rw-r--r--gtests/ssl_gtest/ssl_drop_unittest.cc71
-rw-r--r--gtests/ssl_gtest/ssl_gtest.gyp3
-rw-r--r--gtests/ssl_gtest/ssl_masking_unittest.cc337
-rw-r--r--gtests/ssl_gtest/ssl_record_unittest.cc33
-rw-r--r--gtests/ssl_gtest/ssl_recordsize_unittest.cc36
-rw-r--r--gtests/ssl_gtest/ssl_tls13compat_unittest.cc15
-rw-r--r--gtests/ssl_gtest/tls_agent.cc37
-rw-r--r--gtests/ssl_gtest/tls_filter.cc189
-rw-r--r--gtests/ssl_gtest/tls_filter.h120
-rw-r--r--gtests/ssl_gtest/tls_protect.cc89
-rw-r--r--gtests/ssl_gtest/tls_protect.h5
-rw-r--r--lib/ssl/dtls13con.c92
-rw-r--r--lib/ssl/dtls13con.h6
-rw-r--r--lib/ssl/dtlscon.c40
-rw-r--r--lib/ssl/dtlscon.h1
-rw-r--r--lib/ssl/ssl3con.c45
-rw-r--r--lib/ssl/ssl3gthr.c25
-rw-r--r--lib/ssl/ssl3prot.h2
-rw-r--r--lib/ssl/sslexp.h50
-rw-r--r--lib/ssl/sslimpl.h26
-rw-r--r--lib/ssl/sslprimitive.c205
-rw-r--r--lib/ssl/sslsock.c3
-rw-r--r--lib/ssl/sslspec.c3
-rw-r--r--lib/ssl/sslspec.h3
-rw-r--r--lib/ssl/tls13con.c40
-rw-r--r--lib/ssl/tls13con.h6
32 files changed, 1260 insertions, 268 deletions
diff --git a/cpputil/databuffer.h b/cpputil/databuffer.h
index e981a7c22..4bedd075d 100644
--- a/cpputil/databuffer.h
+++ b/cpputil/databuffer.h
@@ -23,6 +23,7 @@ class DataBuffer {
DataBuffer(const DataBuffer& other) : data_(nullptr), len_(0) {
Assign(other);
}
+ explicit DataBuffer(size_t l) : data_(nullptr), len_(0) { Allocate(l); }
~DataBuffer() { delete[] data_; }
DataBuffer& operator=(const DataBuffer& other) {
diff --git a/cpputil/scoped_ptrs_ssl.h b/cpputil/scoped_ptrs_ssl.h
index 474187540..682ebab82 100644
--- a/cpputil/scoped_ptrs_ssl.h
+++ b/cpputil/scoped_ptrs_ssl.h
@@ -12,6 +12,7 @@
struct ScopedDeleteSSL {
void operator()(SSLAeadContext* ctx) { SSL_DestroyAead(ctx); }
+ void operator()(SSLMaskingContext* ctx) { SSL_DestroyMaskingContext(ctx); }
void operator()(SSLAntiReplayContext* ctx) {
SSL_ReleaseAntiReplayContext(ctx);
}
@@ -34,6 +35,7 @@ struct ScopedMaybeDeleteSSL {
SCOPED(SSLAeadContext);
SCOPED(SSLAntiReplayContext);
+SCOPED(SSLMaskingContext);
SCOPED(SSLResumptionTokenInfo);
#undef SCOPED
diff --git a/cpputil/tls_parser.h b/cpputil/tls_parser.h
index 05dd99fc8..6636b3c6a 100644
--- a/cpputil/tls_parser.h
+++ b/cpputil/tls_parser.h
@@ -74,6 +74,11 @@ const uint8_t kTlsFakeChangeCipherSpec[] = {
0x01 // Value
};
+const uint8_t kCtDtlsCiphertext = 0x20;
+const uint8_t kCtDtlsCiphertextMask = 0xE0;
+const uint8_t kCtDtlsCiphertext16bSeqno = 0x08;
+const uint8_t kCtDtlsCiphertextLengthPresent = 0x04;
+
static const uint8_t kTls13PskKe = 0;
static const uint8_t kTls13PskDhKe = 1;
static const uint8_t kTls13PskAuth = 0;
diff --git a/gtests/ssl_gtest/manifest.mn b/gtests/ssl_gtest/manifest.mn
index ed1128f7c..d5e96a490 100644
--- a/gtests/ssl_gtest/manifest.mn
+++ b/gtests/ssl_gtest/manifest.mn
@@ -14,6 +14,7 @@ CSRCS = \
CPPSRCS = \
bloomfilter_unittest.cc \
ssl_0rtt_unittest.cc \
+ ssl_aead_unittest.cc \
ssl_agent_unittest.cc \
ssl_auth_unittest.cc \
ssl_cert_ext_unittest.cc \
@@ -35,8 +36,8 @@ CPPSRCS = \
ssl_hrr_unittest.cc \
ssl_keyupdate_unittest.cc \
ssl_loopback_unittest.cc \
+ ssl_masking_unittest.cc \
ssl_misc_unittest.cc \
- ssl_primitive_unittest.cc \
ssl_record_unittest.cc \
ssl_recordsep_unittest.cc \
ssl_recordsize_unittest.cc \
diff --git a/gtests/ssl_gtest/ssl_primitive_unittest.cc b/gtests/ssl_gtest/ssl_aead_unittest.cc
index 66ecdeb12..d94683be3 100644
--- a/gtests/ssl_gtest/ssl_primitive_unittest.cc
+++ b/gtests/ssl_gtest/ssl_aead_unittest.cc
@@ -54,7 +54,7 @@ class AeadTest : public ::testing::Test {
ASSERT_GE(kMaxSize, ciphertext_len);
ASSERT_LT(0U, ciphertext_len);
- uint8_t output[kMaxSize];
+ uint8_t output[kMaxSize] = {0};
unsigned int output_len = 0;
EXPECT_EQ(SECSuccess, SSL_AeadEncrypt(ctx.get(), 0, kAad, sizeof(kAad),
kPlaintext, sizeof(kPlaintext),
@@ -181,7 +181,7 @@ TEST_F(AeadTest, AeadNoPointer) {
}
TEST_F(AeadTest, AeadAes128Gcm) {
- SSLAeadContext *ctxInit;
+ SSLAeadContext *ctxInit = nullptr;
ASSERT_EQ(SECSuccess,
SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_AES_128_GCM_SHA256,
secret_.get(), kLabel, strlen(kLabel), &ctxInit));
@@ -203,7 +203,7 @@ TEST_F(AeadTest, AeadAes256Gcm) {
}
TEST_F(AeadTest, AeadChaCha20Poly1305) {
- SSLAeadContext *ctxInit;
+ SSLAeadContext *ctxInit = nullptr;
ASSERT_EQ(
SECSuccess,
SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_CHACHA20_POLY1305_SHA256,
diff --git a/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc b/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
index 7739fe76f..86cb02d73 100644
--- a/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
+++ b/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
@@ -263,6 +263,7 @@ TEST_P(TlsCipherSuiteTest, ResumeCipherSuite) {
TEST_P(TlsCipherSuiteTest, ReadLimit) {
SetupCertificate();
EnableSingleCipher();
+ TlsSendCipherSpecCapturer capturer(client_);
ConnectAndCheckCipherSuite();
if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
uint64_t last = last_safe_write();
@@ -295,9 +296,31 @@ TEST_P(TlsCipherSuiteTest, ReadLimit) {
} else {
epoch = 0;
}
- TlsAgentTestBase::MakeRecord(variant_, ssl_ct_application_data, version_,
- payload, sizeof(payload), &record,
- (epoch << 48) | record_limit());
+
+ uint64_t seqno = (epoch << 48) | record_limit();
+
+ // DTLS 1.3 masks the sequence number
+ if (variant_ == ssl_variant_datagram &&
+ version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ auto spec = capturer.spec(1);
+ ASSERT_NE(nullptr, spec.get());
+ ASSERT_EQ(3, spec->epoch());
+
+ DataBuffer pt, ct;
+ uint8_t dtls13_ctype = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
+ kCtDtlsCiphertextLengthPresent;
+ TlsRecordHeader hdr(variant_, version_, dtls13_ctype, seqno);
+ pt.Assign(payload, sizeof(payload));
+ TlsRecordHeader out_hdr;
+ spec->Protect(hdr, pt, &ct, &out_hdr);
+
+ auto rv = out_hdr.Write(&record, 0, ct);
+ EXPECT_EQ(out_hdr.header_length() + ct.len(), rv);
+ } else {
+ TlsAgentTestBase::MakeRecord(variant_, ssl_ct_application_data, version_,
+ payload, sizeof(payload), &record, seqno);
+ }
+
client_->SendDirect(record);
server_->ExpectReadWriteError();
server_->ReadBytes();
diff --git a/gtests/ssl_gtest/ssl_drop_unittest.cc b/gtests/ssl_gtest/ssl_drop_unittest.cc
index b441b5c10..05b38e381 100644
--- a/gtests/ssl_gtest/ssl_drop_unittest.cc
+++ b/gtests/ssl_gtest/ssl_drop_unittest.cc
@@ -619,55 +619,6 @@ TEST_P(TlsDropDatagram13, ReorderServerEE) {
// The client sends an out of order non-handshake message
// but with the handshake key.
-class TlsSendCipherSpecCapturer {
- public:
- TlsSendCipherSpecCapturer(const std::shared_ptr<TlsAgent>& agent)
- : agent_(agent), send_cipher_specs_() {
- EXPECT_EQ(SECSuccess,
- SSL_SecretCallback(agent_->ssl_fd(), SecretCallback, this));
- }
-
- std::shared_ptr<TlsCipherSpec> spec(size_t i) {
- if (i >= send_cipher_specs_.size()) {
- return nullptr;
- }
- return send_cipher_specs_[i];
- }
-
- private:
- static void SecretCallback(PRFileDesc* fd, PRUint16 epoch,
- SSLSecretDirection dir, PK11SymKey* secret,
- void* arg) {
- auto self = static_cast<TlsSendCipherSpecCapturer*>(arg);
- std::cerr << self->agent_->role_str() << ": capture " << dir
- << " secret for epoch " << epoch << std::endl;
-
- if (dir == ssl_secret_read) {
- return;
- }
-
- SSLPreliminaryChannelInfo preinfo;
- EXPECT_EQ(SECSuccess,
- SSL_GetPreliminaryChannelInfo(self->agent_->ssl_fd(), &preinfo,
- sizeof(preinfo)));
- EXPECT_EQ(sizeof(preinfo), preinfo.length);
- EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_cipher_suite);
-
- SSLCipherSuiteInfo cipherinfo;
- EXPECT_EQ(SECSuccess,
- SSL_GetCipherSuiteInfo(preinfo.cipherSuite, &cipherinfo,
- sizeof(cipherinfo)));
- EXPECT_EQ(sizeof(cipherinfo), cipherinfo.length);
-
- auto spec = std::make_shared<TlsCipherSpec>(true, epoch);
- EXPECT_TRUE(spec->SetKeys(&cipherinfo, secret));
- self->send_cipher_specs_.push_back(spec);
- }
-
- std::shared_ptr<TlsAgent> agent_;
- std::vector<std::shared_ptr<TlsCipherSpec>> send_cipher_specs_;
-};
-
TEST_F(TlsConnectDatagram13, SendOutOfOrderAppWithHandshakeKey) {
StartConnect();
// Capturing secrets means that we can't use decrypting filters on the client.
@@ -684,8 +635,10 @@ TEST_F(TlsConnectDatagram13, SendOutOfOrderAppWithHandshakeKey) {
auto spec = capturer.spec(0);
ASSERT_NE(nullptr, spec.get());
ASSERT_EQ(2, spec->epoch());
- ASSERT_TRUE(client_->SendEncryptedRecord(spec, 0x0002000000000002,
- ssl_ct_application_data,
+
+ uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
+ kCtDtlsCiphertextLengthPresent;
+ ASSERT_TRUE(client_->SendEncryptedRecord(spec, 0x0002000000000002, dtls13_ct,
DataBuffer(buf, sizeof(buf))));
// Now have the server consume the bogus message.
@@ -844,7 +797,7 @@ static void GetCipherAndLimit(uint16_t version, uint16_t* cipher,
// a reasonable amount of time.
*cipher = TLS_CHACHA20_POLY1305_SHA256;
// Assume that we are starting with an expected sequence number of 0.
- *limit = (1ULL << 29) - 1;
+ *limit = (1ULL << 15) - 1;
}
}
@@ -866,14 +819,14 @@ TEST_P(TlsConnectDatagram, MissLotsOfPackets) {
SendReceive();
}
-// Send a sequence number of 0xfffffffd and it should be interpreted as that
+// Send a sequence number of 0xfffd and it should be interpreted as that
// (and not -3 or UINT64_MAX - 2).
TEST_F(TlsConnectDatagram13, UnderflowSequenceNumber) {
Connect();
// This is only valid if short headers are disabled.
client_->SetOption(SSL_ENABLE_DTLS_SHORT_HEADER, PR_FALSE);
EXPECT_EQ(SECSuccess,
- SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), (1ULL << 30) - 3));
+ SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), (1ULL << 16) - 3));
SendReceive();
}
@@ -918,9 +871,13 @@ class TlsReplaceFirstRecordWithJunk : public TlsRecordFilter {
return KEEP;
}
replaced_ = true;
- TlsRecordHeader out_header(header.variant(), header.version(),
- ssl_ct_application_data,
- header.sequence_number());
+
+ uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
+ kCtDtlsCiphertextLengthPresent;
+ TlsRecordHeader out_header(
+ header.variant(), header.version(),
+ is_dtls13() ? dtls13_ct : ssl_ct_application_data,
+ header.sequence_number());
static const uint8_t junk[] = {1, 2, 3, 4};
*offset = out_header.Write(output, *offset, DataBuffer(junk, sizeof(junk)));
diff --git a/gtests/ssl_gtest/ssl_gtest.gyp b/gtests/ssl_gtest/ssl_gtest.gyp
index 6cff0fc9d..ae79c41fe 100644
--- a/gtests/ssl_gtest/ssl_gtest.gyp
+++ b/gtests/ssl_gtest/ssl_gtest.gyp
@@ -15,6 +15,7 @@
'libssl_internals.c',
'selfencrypt_unittest.cc',
'ssl_0rtt_unittest.cc',
+ 'ssl_aead_unittest.cc',
'ssl_agent_unittest.cc',
'ssl_auth_unittest.cc',
'ssl_cert_ext_unittest.cc',
@@ -36,8 +37,8 @@
'ssl_hrr_unittest.cc',
'ssl_keyupdate_unittest.cc',
'ssl_loopback_unittest.cc',
+ 'ssl_masking_unittest.cc',
'ssl_misc_unittest.cc',
- 'ssl_primitive_unittest.cc',
'ssl_record_unittest.cc',
'ssl_recordsep_unittest.cc',
'ssl_recordsize_unittest.cc',
diff --git a/gtests/ssl_gtest/ssl_masking_unittest.cc b/gtests/ssl_gtest/ssl_masking_unittest.cc
new file mode 100644
index 000000000..5b63b945b
--- /dev/null
+++ b/gtests/ssl_gtest/ssl_masking_unittest.cc
@@ -0,0 +1,337 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <memory>
+
+#include "keyhi.h"
+#include "pk11pub.h"
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslexp.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "nss_scoped_ptrs.h"
+#include "scoped_ptrs_ssl.h"
+#include "tls_connect.h"
+
+namespace nss_test {
+
+// From tls_hkdf_unittest.cc:
+extern size_t GetHashLength(SSLHashType ht);
+
+const std::string kLabel = "sn";
+
+class MaskingTest : public ::testing::Test {
+ public:
+ MaskingTest() : slot_(PK11_GetInternalSlot()) {}
+
+ void InitSecret(SSLHashType hash_type) {
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ PK11SymKey *s = PK11_KeyGen(slot_.get(), CKM_GENERIC_SECRET_KEY_GEN,
+ nullptr, AES_128_KEY_LENGTH, nullptr);
+ ASSERT_NE(nullptr, s);
+ secret_.reset(s);
+ }
+
+ void SetUp() override {
+ InitSecret(ssl_hash_sha256);
+ PORT_SetError(0);
+ }
+
+ protected:
+ ScopedPK11SymKey secret_;
+ ScopedPK11SlotInfo slot_;
+ void CreateMask(PRUint16 ciphersuite, std::string label,
+ const std::vector<uint8_t> &sample,
+ std::vector<uint8_t> *out_mask) {
+ ASSERT_NE(nullptr, out_mask);
+ SSLMaskingContext *ctx_init = nullptr;
+ EXPECT_EQ(SECSuccess,
+ SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite,
+ secret_.get(), label.c_str(),
+ label.size(), &ctx_init));
+ EXPECT_EQ(0, PORT_GetError());
+ ASSERT_NE(nullptr, ctx_init);
+ ScopedSSLMaskingContext ctx(ctx_init);
+
+ EXPECT_EQ(SECSuccess,
+ SSL_CreateMask(ctx.get(), sample.data(), sample.size(),
+ out_mask->data(), out_mask->size()));
+ EXPECT_EQ(0, PORT_GetError());
+ bool all_zeros = std::all_of(out_mask->begin(), out_mask->end(),
+ [](uint8_t v) { return v == 0; });
+
+ // If out_mask is short, |all_zeros| will be (expectedly) true often enough
+ // to fail tests.
+ // In this case, just retry to make sure we're not outputting zeros
+ // continuously.
+ if (all_zeros && out_mask->size() < 3) {
+ unsigned int tries = 2;
+ std::vector<uint8_t> tmp_sample = sample;
+ std::vector<uint8_t> tmp_mask(out_mask->size());
+ while (tries--) {
+ tmp_sample.data()[0]++; // Tweak something to get a new mask.
+ EXPECT_EQ(SECSuccess, SSL_CreateMask(ctx.get(), tmp_sample.data(),
+ tmp_sample.size(), tmp_mask.data(),
+ tmp_mask.size()));
+ EXPECT_EQ(0, PORT_GetError());
+ bool retry_zero = std::all_of(tmp_mask.begin(), tmp_mask.end(),
+ [](uint8_t v) { return v == 0; });
+ if (!retry_zero) {
+ all_zeros = false;
+ break;
+ }
+ }
+ }
+ EXPECT_FALSE(all_zeros);
+ }
+};
+
+TEST_F(MaskingTest, MaskContextNoLabel) {
+ std::vector<uint8_t> sample(AES_BLOCK_SIZE);
+ std::vector<uint8_t> mask(AES_BLOCK_SIZE);
+ CreateMask(TLS_AES_128_GCM_SHA256, std::string(""), sample, &mask);
+}
+
+TEST_F(MaskingTest, MaskContextUnsupportedMech) {
+ std::vector<uint8_t> sample(AES_BLOCK_SIZE);
+ std::vector<uint8_t> mask(AES_BLOCK_SIZE);
+ SSLMaskingContext *ctx_init = nullptr;
+ EXPECT_EQ(SECFailure,
+ SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3,
+ TLS_RSA_WITH_AES_128_CBC_SHA256,
+ secret_.get(), nullptr, 0, &ctx_init));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+ EXPECT_EQ(nullptr, ctx_init);
+}
+
+TEST_F(MaskingTest, MaskNullSample) {
+ std::vector<uint8_t> mask(AES_BLOCK_SIZE);
+ SSLMaskingContext *ctx_init = nullptr;
+ EXPECT_EQ(SECSuccess,
+ SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3,
+ TLS_AES_128_GCM_SHA256, secret_.get(),
+ kLabel.c_str(), kLabel.size(), &ctx_init));
+ EXPECT_EQ(0, PORT_GetError());
+ ASSERT_NE(nullptr, ctx_init);
+ ScopedSSLMaskingContext ctx(ctx_init);
+
+ EXPECT_EQ(SECFailure,
+ SSL_CreateMask(ctx.get(), nullptr, 0, mask.data(), mask.size()));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+
+ EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), nullptr, mask.size(),
+ mask.data(), mask.size()));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+}
+
+TEST_F(MaskingTest, MaskContextUnsupportedVersion) {
+ std::vector<uint8_t> sample(AES_BLOCK_SIZE);
+ std::vector<uint8_t> mask(AES_BLOCK_SIZE);
+ SSLMaskingContext *ctx_init = nullptr;
+ EXPECT_EQ(SECFailure, SSL_CreateMaskingContext(
+ SSL_LIBRARY_VERSION_TLS_1_2, TLS_AES_128_GCM_SHA256,
+ secret_.get(), nullptr, 0, &ctx_init));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+ EXPECT_EQ(nullptr, ctx_init);
+}
+
+TEST_F(MaskingTest, MaskTooMuchOutput) {
+ // Max internally-supported length for AES
+ std::vector<uint8_t> sample(AES_BLOCK_SIZE);
+ std::vector<uint8_t> mask(AES_BLOCK_SIZE + 1);
+ SSLMaskingContext *ctx_init = nullptr;
+ EXPECT_EQ(SECSuccess,
+ SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3,
+ TLS_AES_128_GCM_SHA256, secret_.get(),
+ kLabel.c_str(), kLabel.size(), &ctx_init));
+ EXPECT_EQ(0, PORT_GetError());
+ ASSERT_NE(nullptr, ctx_init);
+ ScopedSSLMaskingContext ctx(ctx_init);
+
+ EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(),
+ mask.data(), mask.size()));
+ EXPECT_EQ(SEC_ERROR_OUTPUT_LEN, PORT_GetError());
+}
+
+TEST_F(MaskingTest, MaskShortOutput) {
+ std::vector<uint8_t> sample(16);
+ std::vector<uint8_t> mask(16); // Don't pass a null
+
+ SSLMaskingContext *ctx_init = nullptr;
+ EXPECT_EQ(SECSuccess,
+ SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3,
+ TLS_AES_128_GCM_SHA256, secret_.get(),
+ kLabel.c_str(), kLabel.size(), &ctx_init));
+ EXPECT_EQ(0, PORT_GetError());
+ ASSERT_NE(nullptr, ctx_init);
+ ScopedSSLMaskingContext ctx(ctx_init);
+ EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(),
+ mask.data(), 0));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+}
+
+TEST_F(MaskingTest, MaskRotateLabel) {
+ std::vector<uint8_t> sample(AES_BLOCK_SIZE);
+ std::vector<uint8_t> mask1(AES_BLOCK_SIZE);
+ std::vector<uint8_t> mask2(AES_BLOCK_SIZE);
+ EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(),
+ sample.size()));
+
+ CreateMask(TLS_AES_128_GCM_SHA256, kLabel, sample, &mask1);
+ CreateMask(TLS_AES_128_GCM_SHA256, std::string("sn1"), sample, &mask2);
+ EXPECT_FALSE(mask1 == mask2);
+}
+
+TEST_F(MaskingTest, MaskRotateSample) {
+ std::vector<uint8_t> sample(AES_BLOCK_SIZE);
+ std::vector<uint8_t> mask1(AES_BLOCK_SIZE);
+ std::vector<uint8_t> mask2(AES_BLOCK_SIZE);
+
+ EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(),
+ sample.size()));
+ CreateMask(TLS_AES_128_GCM_SHA256, kLabel, sample, &mask1);
+
+ EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(),
+ sample.size()));
+ CreateMask(TLS_AES_128_GCM_SHA256, kLabel, sample, &mask2);
+ EXPECT_FALSE(mask1 == mask2);
+}
+
+TEST_F(MaskingTest, MaskAesRederive) {
+ std::vector<uint8_t> sample(AES_BLOCK_SIZE);
+ std::vector<uint8_t> mask1(AES_BLOCK_SIZE);
+ std::vector<uint8_t> mask2(AES_BLOCK_SIZE);
+
+ SECStatus rv =
+ PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), sample.size());
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Check that re-using inputs with a new context produces the same mask.
+ CreateMask(TLS_AES_128_GCM_SHA256, kLabel, sample, &mask1);
+ CreateMask(TLS_AES_128_GCM_SHA256, kLabel, sample, &mask2);
+ EXPECT_TRUE(mask1 == mask2);
+}
+
+TEST_F(MaskingTest, MaskAesTooLong) {
+ std::vector<uint8_t> sample(AES_BLOCK_SIZE + 1);
+ std::vector<uint8_t> mask(AES_BLOCK_SIZE + 1);
+ SSLMaskingContext *ctx_init = nullptr;
+ EXPECT_EQ(SECSuccess,
+ SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3,
+ TLS_AES_128_GCM_SHA256, secret_.get(),
+ kLabel.c_str(), kLabel.size(), &ctx_init));
+ EXPECT_EQ(0, PORT_GetError());
+ ASSERT_NE(nullptr, ctx_init);
+ ScopedSSLMaskingContext ctx(ctx_init);
+ EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(),
+ mask.data(), mask.size()));
+ EXPECT_EQ(SEC_ERROR_OUTPUT_LEN, PORT_GetError());
+}
+
+TEST_F(MaskingTest, MaskAesShortSample) {
+ std::vector<uint8_t> sample(AES_BLOCK_SIZE - 1);
+ std::vector<uint8_t> mask(AES_BLOCK_SIZE - 1);
+ SSLMaskingContext *ctx_init = nullptr;
+ EXPECT_EQ(SECSuccess,
+ SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3,
+ TLS_AES_128_GCM_SHA256, secret_.get(),
+ kLabel.c_str(), kLabel.size(), &ctx_init));
+ EXPECT_EQ(0, PORT_GetError());
+ ASSERT_NE(nullptr, ctx_init);
+ ScopedSSLMaskingContext ctx(ctx_init);
+
+ EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(),
+ mask.data(), mask.size()));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+}
+
+TEST_F(MaskingTest, MaskAesShortValid) {
+ std::vector<uint8_t> sample(AES_BLOCK_SIZE);
+ std::vector<uint8_t> mask(1);
+ EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(),
+ sample.size()));
+ CreateMask(TLS_AES_128_GCM_SHA256, kLabel, sample, &mask);
+}
+
+TEST_F(MaskingTest, MaskChaChaRederive) {
+ // Block-aligned.
+ std::vector<uint8_t> sample(32);
+ std::vector<uint8_t> mask1(32);
+ std::vector<uint8_t> mask2(32);
+ EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(),
+ sample.size()));
+ CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask1);
+ CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask2);
+ EXPECT_TRUE(mask1 == mask2);
+}
+
+TEST_F(MaskingTest, MaskChaChaRederiveOddSizes) {
+ // Non-block-aligned.
+ std::vector<uint8_t> sample(27);
+ std::vector<uint8_t> mask1(26);
+ std::vector<uint8_t> mask2(25);
+ EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(),
+ sample.size()));
+ CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask1);
+ CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask2);
+ mask1.pop_back();
+ EXPECT_TRUE(mask1 == mask2);
+}
+
+TEST_F(MaskingTest, MaskChaChaLongValid) {
+ // Max internally-supported length for ChaCha
+ std::vector<uint8_t> sample(128);
+ std::vector<uint8_t> mask(128);
+ EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(),
+ sample.size()));
+ CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask);
+}
+
+TEST_F(MaskingTest, MaskChaChaTooLong) {
+ // Max internally-supported length for ChaCha
+ std::vector<uint8_t> sample(128 + 1);
+ std::vector<uint8_t> mask(128 + 1);
+ SSLMaskingContext *ctx_init = nullptr;
+ EXPECT_EQ(SECSuccess, SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3,
+ TLS_CHACHA20_POLY1305_SHA256,
+ secret_.get(), kLabel.c_str(),
+ kLabel.size(), &ctx_init));
+ EXPECT_EQ(0, PORT_GetError());
+ ASSERT_NE(nullptr, ctx_init);
+ ScopedSSLMaskingContext ctx(ctx_init);
+ EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(),
+ mask.data(), mask.size()));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+}
+
+TEST_F(MaskingTest, MaskChaChaShortSample) {
+ std::vector<uint8_t> sample(15); // Should have 4B ctr, 12B nonce.
+ std::vector<uint8_t> mask(15);
+ SSLMaskingContext *ctx_init = nullptr;
+ EXPECT_EQ(SECSuccess, SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3,
+ TLS_CHACHA20_POLY1305_SHA256,
+ secret_.get(), kLabel.c_str(),
+ kLabel.size(), &ctx_init));
+ EXPECT_EQ(0, PORT_GetError());
+ ASSERT_NE(nullptr, ctx_init);
+ ScopedSSLMaskingContext ctx(ctx_init);
+ EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(),
+ mask.data(), mask.size()));
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+}
+
+TEST_F(MaskingTest, MaskChaChaShortValid) {
+ std::vector<uint8_t> sample(16);
+ std::vector<uint8_t> mask(1);
+ EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(),
+ sample.size()));
+ CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask);
+}
+
+} // namespace nss_test
diff --git a/gtests/ssl_gtest/ssl_record_unittest.cc b/gtests/ssl_gtest/ssl_record_unittest.cc
index 86783b86e..ca4fc96f8 100644
--- a/gtests/ssl_gtest/ssl_record_unittest.cc
+++ b/gtests/ssl_gtest/ssl_record_unittest.cc
@@ -185,8 +185,8 @@ TEST_F(TlsConnectStreamTls13, TooLargeRecord) {
class ShortHeaderChecker : public PacketFilter {
public:
PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output) {
- // The first octet should be 0b001xxxxx.
- EXPECT_EQ(1, input.data()[0] >> 5);
+ // The first octet should be 0b001000xx.
+ EXPECT_EQ(kCtDtlsCiphertext, (input.data()[0] & ~0x3));
return KEEP;
}
};
@@ -205,6 +205,35 @@ TEST_F(TlsConnectDatagram13, ShortHeadersServer) {
SendReceive();
}
+// Send a DTLSCiphertext header with a 2B sequence number, and no length.
+TEST_F(TlsConnectDatagram13, DtlsAlternateShortHeader) {
+ StartConnect();
+ TlsSendCipherSpecCapturer capturer(client_);
+ Connect();
+ SendReceive(50);
+
+ uint8_t buf[] = {0x32, 0x33, 0x34};
+ auto spec = capturer.spec(1);
+ ASSERT_NE(nullptr, spec.get());
+ ASSERT_EQ(3, spec->epoch());
+
+ uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno;
+ TlsRecordHeader header(variant_, SSL_LIBRARY_VERSION_TLS_1_3, dtls13_ct,
+ 0x0003000000000001);
+ TlsRecordHeader out_header(header);
+ DataBuffer msg(buf, sizeof(buf));
+ msg.Write(msg.len(), ssl_ct_application_data, 1);
+ DataBuffer ciphertext;
+ EXPECT_TRUE(spec->Protect(header, msg, &ciphertext, &out_header));
+
+ DataBuffer record;
+ auto rv = out_header.Write(&record, 0, ciphertext);
+ EXPECT_EQ(out_header.header_length() + ciphertext.len(), rv);
+ client_->SendDirect(record);
+
+ server_->ReadBytes(3);
+}
+
TEST_F(TlsConnectStreamTls13, UnencryptedFinishedMessage) {
StartConnect();
client_->Handshake(); // Send ClientHello
diff --git a/gtests/ssl_gtest/ssl_recordsize_unittest.cc b/gtests/ssl_gtest/ssl_recordsize_unittest.cc
index f2003a358..8926b5551 100644
--- a/gtests/ssl_gtest/ssl_recordsize_unittest.cc
+++ b/gtests/ssl_gtest/ssl_recordsize_unittest.cc
@@ -19,7 +19,8 @@ namespace nss_test {
// This class tracks the maximum size of record that was sent, both cleartext
// and plain. It only tracks records that have an outer type of
-// application_data. In TLS 1.3, this includes handshake messages.
+// application_data or DTLSCiphertext. In TLS 1.3, this includes handshake
+// messages.
class TlsRecordMaximum : public TlsRecordFilter {
public:
TlsRecordMaximum(const std::shared_ptr<TlsAgent>& a)
@@ -34,7 +35,7 @@ class TlsRecordMaximum : public TlsRecordFilter {
DataBuffer* output) override {
std::cerr << "max: " << record << std::endl;
// Ignore unprotected packets.
- if (header.content_type() != ssl_ct_application_data) {
+ if (!header.is_protected()) {
return KEEP;
}
@@ -195,9 +196,23 @@ class TlsRecordExpander : public TlsRecordFilter {
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& data,
DataBuffer* changed) {
- if (header.content_type() != ssl_ct_application_data) {
- return KEEP;
+ if (!header.is_protected()) {
+ // We're targeting application_data records. If the record is
+ // |!is_protected()|, we have two possibilities:
+ if (!decrypting()) {
+ // 1) We're not decrypting, in which this case this is truly an
+ // unencrypted record (Keep).
+ return KEEP;
+ }
+ if (header.content_type() != ssl_ct_application_data) {
+ // 2) We are decrypting, so is_protected() read the internal
+ // content_type. If the internal ct IS NOT application_data, then
+ // it's not our target (Keep).
+ return KEEP;
+ }
+ // Otherwise, the the internal ct IS application_data (Change).
}
+
changed->Allocate(data.len() + expansion_);
changed->Write(0, data.data(), data.len());
return CHANGE;
@@ -261,30 +276,31 @@ class TlsRecordPadder : public TlsRecordFilter {
PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& record, size_t* offset,
DataBuffer* output) override {
- if (header.content_type() != ssl_ct_application_data) {
+ if (!header.is_protected()) {
return KEEP;
}
uint16_t protection_epoch;
uint8_t inner_content_type;
DataBuffer plaintext;
+ TlsRecordHeader out_header;
if (!Unprotect(header, record, &protection_epoch, &inner_content_type,
- &plaintext)) {
+ &plaintext, &out_header)) {
return KEEP;
}
- if (inner_content_type != ssl_ct_application_data) {
+ if (decrypting() && inner_content_type != ssl_ct_application_data) {
return KEEP;
}
DataBuffer ciphertext;
- bool ok = Protect(spec(protection_epoch), header, inner_content_type,
- plaintext, &ciphertext, padding_);
+ bool ok = Protect(spec(protection_epoch), out_header, inner_content_type,
+ plaintext, &ciphertext, &out_header, padding_);
EXPECT_TRUE(ok);
if (!ok) {
return KEEP;
}
- *offset = header.Write(output, *offset, ciphertext);
+ *offset = out_header.Write(output, *offset, ciphertext);
return CHANGE;
}
diff --git a/gtests/ssl_gtest/ssl_tls13compat_unittest.cc b/gtests/ssl_gtest/ssl_tls13compat_unittest.cc
index ecb63d476..6905ed0c0 100644
--- a/gtests/ssl_gtest/ssl_tls13compat_unittest.cc
+++ b/gtests/ssl_gtest/ssl_tls13compat_unittest.cc
@@ -384,14 +384,16 @@ TEST_F(TlsConnectDatagram13, CompatModeDtlsClient) {
ASSERT_EQ(2U, client_records->count()); // CH, Fin
EXPECT_EQ(ssl_ct_handshake, client_records->record(0).header.content_type());
- EXPECT_EQ(ssl_ct_application_data,
- client_records->record(1).header.content_type());
+ EXPECT_EQ(kCtDtlsCiphertext,
+ (client_records->record(1).header.content_type() &
+ kCtDtlsCiphertextMask));
ASSERT_EQ(6U, server_records->count()); // SH, EE, CT, CV, Fin, Ack
EXPECT_EQ(ssl_ct_handshake, server_records->record(0).header.content_type());
for (size_t i = 1; i < server_records->count(); ++i) {
- EXPECT_EQ(ssl_ct_application_data,
- server_records->record(i).header.content_type());
+ EXPECT_EQ(kCtDtlsCiphertext,
+ (server_records->record(i).header.content_type() &
+ kCtDtlsCiphertextMask));
}
}
@@ -440,8 +442,9 @@ TEST_F(TlsConnectDatagram13, CompatModeDtlsServer) {
ASSERT_EQ(5U, server_records->count()); // SH, EE, CT, CV, Fin
EXPECT_EQ(ssl_ct_handshake, server_records->record(0).header.content_type());
for (size_t i = 1; i < server_records->count(); ++i) {
- EXPECT_EQ(ssl_ct_application_data,
- server_records->record(i).header.content_type());
+ EXPECT_EQ(kCtDtlsCiphertext,
+ (server_records->record(i).header.content_type() &
+ kCtDtlsCiphertextMask));
}
uint32_t session_id_len = 0;
diff --git a/gtests/ssl_gtest/tls_agent.cc b/gtests/ssl_gtest/tls_agent.cc
index 88640481e..b52306961 100644
--- a/gtests/ssl_gtest/tls_agent.cc
+++ b/gtests/ssl_gtest/tls_agent.cc
@@ -1064,21 +1064,28 @@ void TlsAgent::SendBuffer(const DataBuffer& buf) {
bool TlsAgent::SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec,
uint64_t seq, uint8_t ct,
const DataBuffer& buf) {
- LOGV("Encrypting " << buf.len() << " bytes");
// Ensure that we are doing TLS 1.3.
EXPECT_GE(expected_version_, SSL_LIBRARY_VERSION_TLS_1_3);
- TlsRecordHeader header(variant_, expected_version_, ssl_ct_application_data,
- seq);
+ if (variant_ != ssl_variant_datagram) {
+ ADD_FAILURE();
+ return false;
+ }
+
+ LOGV("Encrypting " << buf.len() << " bytes");
+ uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
+ kCtDtlsCiphertextLengthPresent;
+ TlsRecordHeader header(variant_, expected_version_, dtls13_ct, seq);
+ TlsRecordHeader out_header(header);
DataBuffer padded = buf;
padded.Write(padded.len(), ct, 1);
DataBuffer ciphertext;
- if (!spec->Protect(header, padded, &ciphertext)) {
+ if (!spec->Protect(header, padded, &ciphertext, &out_header)) {
return false;
}
DataBuffer record;
- auto rv = header.Write(&record, 0, ciphertext);
- EXPECT_EQ(header.header_length() + ciphertext.len(), rv);
+ auto rv = out_header.Write(&record, 0, ciphertext);
+ EXPECT_EQ(out_header.header_length() + ciphertext.len(), rv);
SendDirect(record);
return true;
}
@@ -1202,16 +1209,26 @@ void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type,
uint16_t version, const uint8_t* buf,
size_t len, DataBuffer* out,
uint64_t sequence_number) {
+ // Fixup the content type for DTLSCiphertext
+ if (variant == ssl_variant_datagram &&
+ version >= SSL_LIBRARY_VERSION_TLS_1_3 &&
+ type == ssl_ct_application_data) {
+ type = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno |
+ kCtDtlsCiphertextLengthPresent;
+ }
+
size_t index = 0;
- index = out->Write(index, type, 1);
if (variant == ssl_variant_stream) {
+ index = out->Write(index, type, 1);
index = out->Write(index, version, 2);
} else if (version >= SSL_LIBRARY_VERSION_TLS_1_3 &&
- type == ssl_ct_application_data) {
+ (type & kCtDtlsCiphertextMask) == kCtDtlsCiphertext) {
uint32_t epoch = (sequence_number >> 48) & 0x3;
- uint32_t seqno = sequence_number & ((1ULL << 30) - 1);
- index = out->Write(index, (epoch << 30) | seqno, 4);
+ index = out->Write(index, type | epoch, 1);
+ uint32_t seqno = sequence_number & ((1ULL << 16) - 1);
+ index = out->Write(index, seqno, 2);
} else {
+ index = out->Write(index, type, 1);
index = out->Write(index, TlsVersionToDtlsVersion(version), 2);
index = out->Write(index, sequence_number >> 32, 4);
index = out->Write(index, sequence_number & PR_UINT32_MAX, 4);
diff --git a/gtests/ssl_gtest/tls_filter.cc b/gtests/ssl_gtest/tls_filter.cc
index b2917274b..d47ee71ab 100644
--- a/gtests/ssl_gtest/tls_filter.cc
+++ b/gtests/ssl_gtest/tls_filter.cc
@@ -120,6 +120,10 @@ bool TlsRecordFilter::is_dtls13() const {
info.canSendEarlyData;
}
+bool TlsRecordFilter::is_dtls13_ciphertext(uint8_t ct) const {
+ return is_dtls13() && (ct & kCtDtlsCiphertextMask) == kCtDtlsCiphertext;
+}
+
// Gets the cipher spec that matches the specified epoch.
TlsCipherSpec& TlsRecordFilter::spec(uint16_t write_epoch) {
for (auto& sp : cipher_specs_) {
@@ -196,23 +200,24 @@ PacketFilter::Action TlsRecordFilter::FilterRecord(
uint8_t inner_content_type;
DataBuffer plaintext;
uint16_t protection_epoch = 0;
+ TlsRecordHeader out_header(header);
if (!Unprotect(header, record, &protection_epoch, &inner_content_type,
- &plaintext)) {
+ &plaintext, &out_header)) {
std::cerr << agent()->role_str() << ": unprotect failed: " << header << ":"
<< record << std::endl;
return KEEP;
}
auto& protection_spec = spec(protection_epoch);
- TlsRecordHeader real_header(header.variant(), header.version(),
- inner_content_type, header.sequence_number());
+ TlsRecordHeader real_header(out_header.variant(), out_header.version(),
+ inner_content_type, out_header.sequence_number());
PacketFilter::Action action = FilterRecord(real_header, plaintext, &filtered);
// In stream mode, even if something doesn't change we need to re-encrypt if
// previous packets were dropped.
if (action == KEEP) {
- if (header.is_dtls() || !protection_spec.record_dropped()) {
+ if (out_header.is_dtls() || !protection_spec.record_dropped()) {
// Count every outgoing packet.
protection_spec.RecordProtected();
return KEEP;
@@ -221,7 +226,7 @@ PacketFilter::Action TlsRecordFilter::FilterRecord(
}
if (action == DROP) {
- std::cerr << "record drop: " << header << ":" << record << std::endl;
+ std::cerr << "record drop: " << out_header << ":" << record << std::endl;
protection_spec.RecordDropped();
return DROP;
}
@@ -233,17 +238,15 @@ PacketFilter::Action TlsRecordFilter::FilterRecord(
}
uint64_t seq_num = protection_spec.next_out_seqno();
- if (!decrypting_ && header.is_dtls()) {
+ if (!decrypting_ && out_header.is_dtls()) {
// Copy over the epoch, which isn't tracked when not decrypting.
- seq_num |= header.sequence_number() & (0xffffULL << 48);
+ seq_num |= out_header.sequence_number() & (0xffffULL << 48);
}
-
- TlsRecordHeader out_header(header.variant(), header.version(),
- header.content_type(), seq_num);
+ out_header.sequence_number(seq_num);
DataBuffer ciphertext;
bool rv = Protect(protection_spec, out_header, inner_content_type, filtered,
- &ciphertext);
+ &ciphertext, &out_header);
if (!rv) {
return KEEP;
}
@@ -262,19 +265,67 @@ size_t TlsRecordHeader::header_length() const {
return WriteHeader(&buf, 0, 0);
}
-uint64_t TlsRecordHeader::RecoverSequenceNumber(uint64_t expected,
+bool TlsRecordHeader::MaskSequenceNumber() {
+ return MaskSequenceNumber(sn_mask());
+}
+
+bool TlsRecordHeader::MaskSequenceNumber(const DataBuffer& mask) {
+ if (mask.empty()) {
+ return false;
+ }
+
+ if (is_dtls13_ciphertext()) {
+ uint64_t seqno = sequence_number();
+ uint8_t len = content_type() & kCtDtlsCiphertext16bSeqno ? 2 : 1;
+ uint16_t seqno_bitmask = (1 << len * 8) - 1;
+ DataBuffer val;
+ if (val.Write(0, seqno & seqno_bitmask, len) != len) {
+ return false;
+ }
+
+ val.data()[0] ^= mask.data()[0];
+ if (len == 2 && mask.len() > 1) {
+ val.data()[1] ^= mask.data()[1];
+ }
+
+ uint32_t tmp;
+ if (!val.Read(0, len, &tmp)) {
+ return false;
+ }
+
+ seqno = (seqno & ~seqno_bitmask) | tmp;
+ seqno_is_masked_ = !seqno_is_masked_;
+ if (!seqno_is_masked_) {
+ seqno = ParseSequenceNumber(guess_seqno_, seqno, len * 8, 2);
+ }
+ sequence_number_ = seqno;
+
+ // Now update the header bytes
+ if (header_.len() > 1) {
+ header_.data()[1] ^= mask.data()[0];
+ if ((content_type() & kCtDtlsCiphertext16bSeqno) && header().len() > 2) {
+ header_.data()[2] ^= mask.data()[1];
+ }
+ }
+ }
+
+ sn_mask_ = mask;
+ return true;
+}
+
+uint64_t TlsRecordHeader::RecoverSequenceNumber(uint64_t guess_seqno,
uint32_t partial,
size_t partial_bits) {
EXPECT_GE(32U, partial_bits);
uint64_t mask = (1ULL << partial_bits) - 1;
// First we determine the highest possible value. This is half the
- // expressible range above the expected value, less 1.
+ // expressible range above the expected value (|guess_seqno|), less 1.
//
// We subtract the extra 1 from the cap so that when given a choice between
// the equidistant expected+N and expected-N we want to chose the lower. With
// 0-RTT, we sometimes have to recover an epoch of 1 when we expect an epoch
// of 3 and with 2 partial bits, the alternative result of 5 is wrong.
- uint64_t cap = expected + (1ULL << (partial_bits - 1)) - 1;
+ uint64_t cap = guess_seqno + (1ULL << (partial_bits - 1)) - 1;
// Add the partial piece in. e.g., xxxx789a and 1234 becomes xxxx1234.
uint64_t seq_no = (cap & ~mask) | partial;
// If the partial value is higher than the same partial piece from the cap,
@@ -286,15 +337,18 @@ uint64_t TlsRecordHeader::RecoverSequenceNumber(uint64_t expected,
}
// Determine the full epoch and sequence number from an expected and raw value.
-// The expected and output values are packed as they are in DTLS 1.2 and
-// earlier: with 16 bits of epoch and 48 bits of sequence number.
-uint64_t TlsRecordHeader::ParseSequenceNumber(uint64_t expected, uint32_t raw,
+// The expected, raw, and output values are packed as they are in DTLS 1.2 and
+// earlier: with 16 bits of epoch and 48 bits of sequence number. The raw value
+// is packed this way (even before recovery) so that we don't need to track a
+// moving value between two calls (one to recover the epoch, and one after
+// unmasking to recover the sequence number).
+uint64_t TlsRecordHeader::ParseSequenceNumber(uint64_t expected, uint64_t raw,
size_t seq_no_bits,
size_t epoch_bits) {
uint64_t epoch_mask = (1ULL << epoch_bits) - 1;
- uint64_t epoch = RecoverSequenceNumber(
- expected >> 48, (raw >> seq_no_bits) & epoch_mask, epoch_bits);
- if (epoch > (expected >> 48)) {
+ uint64_t ep = RecoverSequenceNumber(expected >> 48, (raw >> 48) & epoch_mask,
+ epoch_bits);
+ if (ep > (expected >> 48)) {
// If the epoch has changed, reset the expected sequence number.
expected = 0;
} else {
@@ -302,9 +356,12 @@ uint64_t TlsRecordHeader::ParseSequenceNumber(uint64_t expected, uint32_t raw,
expected &= (1ULL << 48) - 1;
}
uint64_t seq_no_mask = (1ULL << seq_no_bits) - 1;
- uint64_t seq_no =
- RecoverSequenceNumber(expected, raw & seq_no_mask, seq_no_bits);
- return (epoch << 48) | seq_no;
+ uint64_t seq_no = (raw & seq_no_mask);
+ if (!seqno_is_masked_) {
+ seq_no = RecoverSequenceNumber(expected, seq_no, seq_no_bits);
+ }
+
+ return (ep << 48) | seq_no;
}
bool TlsRecordHeader::Parse(bool is_dtls13, uint64_t seqno, TlsParser* parser,
@@ -320,38 +377,47 @@ bool TlsRecordHeader::Parse(bool is_dtls13, uint64_t seqno, TlsParser* parser,
version_ = SSL_LIBRARY_VERSION_TLS_1_3;
#ifndef UNSAFE_FUZZER_MODE
- // Deal with the 7 octet header.
- if (content_type_ == ssl_ct_application_data) {
+ // Deal with the DTLSCipherText header.
+ if (is_dtls13_ciphertext()) {
+ uint8_t seq_no_bytes =
+ (content_type_ & kCtDtlsCiphertext16bSeqno) ? 2 : 1;
uint32_t tmp;
- if (!parser->Read(&tmp, 4)) {
- return false;
- }
- sequence_number_ = ParseSequenceNumber(seqno, tmp, 30, 2);
- if (!parser->ReadFromMark(&header_, parser->consumed() + 2 - mark,
- mark)) {
+
+ if (!parser->Read(&tmp, seq_no_bytes)) {
return false;
}
- return parser->ReadVariable(body, 2);
- }
- // The short, 2 octet header.
- if ((content_type_ & 0xe0) == 0x20) {
- uint32_t tmp;
- if (!parser->Read(&tmp, 1)) {
- return false;
+ // Store the guess if masked. If and when seqno_bytesenceNumber is called,
+ // the value will be unmasked and recovered. This assumes we only call
+ // Parse() on headers containing masked values.
+ seqno_is_masked_ = true;
+ guess_seqno_ = seqno;
+ uint64_t ep = content_type_ & 0x03;
+ sequence_number_ = (ep << 48) | tmp;
+
+ // Recover the full epoch. Note the sequence number portion holds the
+ // masked value until a call to Mask() reveals it (as indicated by
+ // |seqno_is_masked_|).
+ sequence_number_ =
+ ParseSequenceNumber(seqno, sequence_number_, seq_no_bytes * 8, 2);
+
+ uint32_t len_bytes =
+ (content_type_ & kCtDtlsCiphertextLengthPresent) ? 2 : 0;
+ if (len_bytes) {
+ if (!parser->Read(&tmp, 2)) {
+ return false;
+ }
}
- // Need to use the low 5 bits of the first octet too.
- tmp |= (content_type_ & 0x1f) << 8;
- content_type_ = ssl_ct_application_data;
- sequence_number_ = ParseSequenceNumber(seqno, tmp, 12, 1);
if (!parser->ReadFromMark(&header_, parser->consumed() - mark, mark)) {
return false;
}
- return parser->Read(body, parser->remaining());
+
+ return len_bytes ? parser->Read(body, tmp)
+ : parser->Read(body, parser->remaining());
}
- // The full 13 octet header can only be used for a few types.
+ // The full DTLSPlainText header can only be used for a few types.
EXPECT_TRUE(content_type_ == ssl_ct_alert ||
content_type_ == ssl_ct_handshake ||
content_type_ == ssl_ct_ack);
@@ -389,15 +455,20 @@ bool TlsRecordHeader::Parse(bool is_dtls13, uint64_t seqno, TlsParser* parser,
size_t TlsRecordHeader::WriteHeader(DataBuffer* buffer, size_t offset,
size_t body_len) const {
- offset = buffer->Write(offset, content_type_, 1);
- if (is_dtls() && version_ >= SSL_LIBRARY_VERSION_TLS_1_3 &&
- content_type() == ssl_ct_application_data) {
+ if (is_dtls13_ciphertext()) {
+ uint8_t seq_no_bytes = (content_type_ & kCtDtlsCiphertext16bSeqno) ? 2 : 1;
// application_data records in TLS 1.3 have a different header format.
- // Always use the long header here for simplicity.
uint32_t e = (sequence_number_ >> 48) & 0x3;
- uint32_t seqno = sequence_number_ & ((1ULL << 30) - 1);
- offset = buffer->Write(offset, (e << 30) | seqno, 4);
+ uint32_t seqno = sequence_number_ & ((1ULL << seq_no_bytes * 8) - 1);
+ uint8_t new_content_type_ = content_type_ | e;
+ offset = buffer->Write(offset, new_content_type_, 1);
+ offset = buffer->Write(offset, seqno, seq_no_bytes);
+
+ if (content_type_ & kCtDtlsCiphertextLengthPresent) {
+ offset = buffer->Write(offset, body_len, 2);
+ }
} else {
+ offset = buffer->Write(offset, content_type_, 1);
uint16_t v = is_dtls() ? TlsVersionToDtlsVersion(version_) : version_;
offset = buffer->Write(offset, v, 2);
if (is_dtls()) {
@@ -405,8 +476,9 @@ size_t TlsRecordHeader::WriteHeader(DataBuffer* buffer, size_t offset,
offset = buffer->Write(offset, sequence_number_ >> 32, 4);
offset = buffer->Write(offset, sequence_number_ & 0xffffffff, 4);
}
+ offset = buffer->Write(offset, body_len, 2);
}
- offset = buffer->Write(offset, body_len, 2);
+
return offset;
}
@@ -421,8 +493,9 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header,
const DataBuffer& ciphertext,
uint16_t* protection_epoch,
uint8_t* inner_content_type,
- DataBuffer* plaintext) {
- if (!decrypting_ || header.content_type() != ssl_ct_application_data) {
+ DataBuffer* plaintext,
+ TlsRecordHeader* out_header) {
+ if (!decrypting_ || !header.is_protected()) {
// Maintain the epoch and sequence number for plaintext records.
uint16_t ep = 0;
if (agent()->variant() == ssl_variant_datagram) {
@@ -438,7 +511,7 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header,
uint16_t ep = 0;
if (agent()->variant() == ssl_variant_datagram) {
ep = static_cast<uint16_t>(header.sequence_number() >> 48);
- if (!spec(ep).Unprotect(header, ciphertext, plaintext)) {
+ if (!spec(ep).Unprotect(header, ciphertext, plaintext, out_header)) {
return false;
}
} else {
@@ -446,7 +519,8 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header,
// can't just use the newest keys because the same flight of messages can
// contain multiple epochs. So... trial decrypt!
for (size_t i = cipher_specs_.size() - 1; i > 0; --i) {
- if (cipher_specs_[i].Unprotect(header, ciphertext, plaintext)) {
+ if (cipher_specs_[i].Unprotect(header, ciphertext, plaintext,
+ out_header)) {
ep = cipher_specs_[i].epoch();
break;
}
@@ -481,7 +555,8 @@ bool TlsRecordFilter::Protect(TlsCipherSpec& protection_spec,
const TlsRecordHeader& header,
uint8_t inner_content_type,
const DataBuffer& plaintext,
- DataBuffer* ciphertext, size_t padding) {
+ DataBuffer* ciphertext,
+ TlsRecordHeader* out_header, size_t padding) {
if (!protection_spec.is_protected()) {
// Not protected, just keep the sequence numbers updated.
protection_spec.RecordProtected();
@@ -494,7 +569,7 @@ bool TlsRecordFilter::Protect(TlsCipherSpec& protection_spec,
size_t offset = padded.Write(0, plaintext.data(), plaintext.len());
padded.Write(offset, inner_content_type, 1);
- bool ok = protection_spec.Protect(header, padded, ciphertext);
+ bool ok = protection_spec.Protect(header, padded, ciphertext, out_header);
if (!ok) {
ADD_FAILURE() << "protect fail";
} else if (g_ssl_gtest_verbose) {
diff --git a/gtests/ssl_gtest/tls_filter.h b/gtests/ssl_gtest/tls_filter.h
index 64ee71c89..8cf558f9c 100644
--- a/gtests/ssl_gtest/tls_filter.h
+++ b/gtests/ssl_gtest/tls_filter.h
@@ -12,6 +12,7 @@
#include <set>
#include <vector>
#include "sslt.h"
+#include "sslproto.h"
#include "test_io.h"
#include "tls_agent.h"
#include "tls_parser.h"
@@ -25,6 +26,59 @@ namespace nss_test {
class TlsCipherSpec;
+class TlsSendCipherSpecCapturer {
+ public:
+ TlsSendCipherSpecCapturer(const std::shared_ptr<TlsAgent>& agent)
+ : agent_(agent), send_cipher_specs_() {
+ EXPECT_EQ(SECSuccess,
+ SSL_SecretCallback(agent_->ssl_fd(), SecretCallback, this));
+ }
+
+ std::shared_ptr<TlsCipherSpec> spec(size_t i) {
+ if (i >= send_cipher_specs_.size()) {
+ return nullptr;
+ }
+ return send_cipher_specs_[i];
+ }
+
+ private:
+ static void SecretCallback(PRFileDesc* fd, PRUint16 epoch,
+ SSLSecretDirection dir, PK11SymKey* secret,
+ void* arg) {
+ auto self = static_cast<TlsSendCipherSpecCapturer*>(arg);
+ std::cerr << self->agent_->role_str() << ": capture " << dir
+ << " secret for epoch " << epoch << std::endl;
+
+ if (dir == ssl_secret_read) {
+ return;
+ }
+
+ SSLPreliminaryChannelInfo preinfo;
+ EXPECT_EQ(SECSuccess,
+ SSL_GetPreliminaryChannelInfo(self->agent_->ssl_fd(), &preinfo,
+ sizeof(preinfo)));
+ EXPECT_EQ(sizeof(preinfo), preinfo.length);
+ EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_cipher_suite);
+
+ // Check the version:
+ EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_version);
+ ASSERT_GE(SSL_LIBRARY_VERSION_TLS_1_3, preinfo.protocolVersion);
+
+ SSLCipherSuiteInfo cipherinfo;
+ EXPECT_EQ(SECSuccess,
+ SSL_GetCipherSuiteInfo(preinfo.cipherSuite, &cipherinfo,
+ sizeof(cipherinfo)));
+ EXPECT_EQ(sizeof(cipherinfo), cipherinfo.length);
+
+ auto spec = std::make_shared<TlsCipherSpec>(true, epoch);
+ EXPECT_TRUE(spec->SetKeys(&cipherinfo, secret));
+ self->send_cipher_specs_.push_back(spec);
+ }
+
+ std::shared_ptr<TlsAgent> agent_;
+ std::vector<std::shared_ptr<TlsCipherSpec>> send_cipher_specs_;
+};
+
class TlsVersioned {
public:
TlsVersioned() : variant_(ssl_variant_stream), version_(0) {}
@@ -45,22 +99,57 @@ class TlsVersioned {
class TlsRecordHeader : public TlsVersioned {
public:
TlsRecordHeader()
- : TlsVersioned(), content_type_(0), sequence_number_(0), header_() {}
+ : TlsVersioned(),
+ content_type_(0),
+ guess_seqno_(0),
+ seqno_is_masked_(false),
+ sequence_number_(0),
+ header_() {}
TlsRecordHeader(SSLProtocolVariant var, uint16_t ver, uint8_t ct,
uint64_t seqno)
: TlsVersioned(var, ver),
content_type_(ct),
+ guess_seqno_(0),
+ seqno_is_masked_(false),
sequence_number_(seqno),
- header_() {}
+ header_(),
+ sn_mask_() {}
+
+ bool is_protected() const {
+ // *TLS < 1.3
+ if (version() < SSL_LIBRARY_VERSION_TLS_1_3 &&
+ content_type() == ssl_ct_application_data) {
+ return true;
+ }
+
+ // TLS 1.3
+ if (!is_dtls() && version() >= SSL_LIBRARY_VERSION_TLS_1_3 &&
+ content_type() == ssl_ct_application_data) {
+ return true;
+ }
+
+ // DTLS 1.3
+ return is_dtls13_ciphertext();
+ }
uint8_t content_type() const { return content_type_; }
- uint64_t sequence_number() const { return sequence_number_; }
uint16_t epoch() const {
return static_cast<uint16_t>(sequence_number_ >> 48);
}
+ uint64_t sequence_number() const { return sequence_number_; }
+ void sequence_number(uint64_t seqno) { sequence_number_ = seqno; }
+ const DataBuffer& sn_mask() const { return sn_mask_; }
+ bool is_dtls13_ciphertext() const {
+ return is_dtls() && (version() >= SSL_LIBRARY_VERSION_TLS_1_3) &&
+ (content_type() & kCtDtlsCiphertextMask) == kCtDtlsCiphertext;
+ }
+
size_t header_length() const;
const DataBuffer& header() const { return header_; }
+ bool MaskSequenceNumber();
+ bool MaskSequenceNumber(const DataBuffer& mask);
+
// Parse the header; return true if successful; body in an outparam if OK.
bool Parse(bool is_dtls13, uint64_t sequence_number, TlsParser* parser,
DataBuffer* body);
@@ -70,14 +159,17 @@ class TlsRecordHeader : public TlsVersioned {
size_t WriteHeader(DataBuffer* buffer, size_t offset, size_t body_len) const;
private:
- static uint64_t RecoverSequenceNumber(uint64_t expected, uint32_t partial,
+ static uint64_t RecoverSequenceNumber(uint64_t guess_seqno, uint32_t partial,
size_t partial_bits);
- static uint64_t ParseSequenceNumber(uint64_t expected, uint32_t raw,
- size_t seq_no_bits, size_t epoch_bits);
+ uint64_t ParseSequenceNumber(uint64_t expected, uint64_t raw,
+ size_t seq_no_bits, size_t epoch_bits);
uint8_t content_type_;
+ uint64_t guess_seqno_;
+ bool seqno_is_masked_;
uint64_t sequence_number_;
DataBuffer header_;
+ DataBuffer sn_mask_;
};
struct TlsRecord {
@@ -111,12 +203,14 @@ class TlsRecordFilter : public PacketFilter {
// Enabling it for lower version tests will cause undefined
// behavior.
void EnableDecryption();
+ bool decrypting() const { return decrypting_; };
bool Unprotect(const TlsRecordHeader& header, const DataBuffer& cipherText,
uint16_t* protection_epoch, uint8_t* inner_content_type,
- DataBuffer* plaintext);
+ DataBuffer* plaintext, TlsRecordHeader* out_header);
bool Protect(TlsCipherSpec& protection_spec, const TlsRecordHeader& header,
uint8_t inner_content_type, const DataBuffer& plaintext,
- DataBuffer* ciphertext, size_t padding = 0);
+ DataBuffer* ciphertext, TlsRecordHeader* out_header,
+ size_t padding = 0);
protected:
// There are two filter functions which can be overriden. Both are
@@ -141,6 +235,7 @@ class TlsRecordFilter : public PacketFilter {
}
bool is_dtls13() const;
+ bool is_dtls13_ciphertext(uint8_t ct) const;
TlsCipherSpec& spec(uint16_t epoch);
private:
@@ -471,8 +566,9 @@ class TlsEncryptedHandshakeMessageReplacer : public TlsRecordFilter {
uint16_t protection_epoch = 0;
uint8_t inner_content_type;
DataBuffer plaintext;
+ TlsRecordHeader out_header;
if (!Unprotect(header, record, &protection_epoch, &inner_content_type,
- &plaintext) ||
+ &plaintext, &out_header) ||
!plaintext.len()) {
return KEEP;
}
@@ -501,12 +597,12 @@ class TlsEncryptedHandshakeMessageReplacer : public TlsRecordFilter {
}
DataBuffer ciphertext;
- bool ok = Protect(spec(protection_epoch), header, inner_content_type,
- plaintext, &ciphertext, 0);
+ bool ok = Protect(spec(protection_epoch), out_header, inner_content_type,
+ plaintext, &ciphertext, &out_header);
if (!ok) {
return KEEP;
}
- *offset = header.Write(output, *offset, ciphertext);
+ *offset = out_header.Write(output, *offset, ciphertext);
return CHANGE;
}
diff --git a/gtests/ssl_gtest/tls_protect.cc b/gtests/ssl_gtest/tls_protect.cc
index de91982f7..7737fe5ea 100644
--- a/gtests/ssl_gtest/tls_protect.cc
+++ b/gtests/ssl_gtest/tls_protect.cc
@@ -25,39 +25,66 @@ TlsCipherSpec::TlsCipherSpec(bool dtls, uint16_t epoc)
bool TlsCipherSpec::SetKeys(SSLCipherSuiteInfo* cipherinfo,
PK11SymKey* secret) {
- SSLAeadContext* ctx;
+ SSLAeadContext* aead_ctx;
SECStatus rv = SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3,
cipherinfo->cipherSuite, secret, "",
0, // Use the default labels.
- &ctx);
+ &aead_ctx);
if (rv != SECSuccess) {
return false;
}
- aead_.reset(ctx);
+ aead_.reset(aead_ctx);
+
+ SSLMaskingContext* mask_ctx;
+ const char kHkdfPurposeSn[] = "sn";
+ rv = SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3,
+ cipherinfo->cipherSuite, secret, kHkdfPurposeSn,
+ strlen(kHkdfPurposeSn), &mask_ctx);
+ if (rv != SECSuccess) {
+ return false;
+ }
+ mask_.reset(mask_ctx);
return true;
}
bool TlsCipherSpec::Unprotect(const TlsRecordHeader& header,
const DataBuffer& ciphertext,
- DataBuffer* plaintext) {
- if (aead_ == nullptr) {
+ DataBuffer* plaintext,
+ TlsRecordHeader* out_header) {
+ if (!aead_ || !out_header) {
return false;
}
+ *out_header = header;
+
// Make space.
plaintext->Allocate(ciphertext.len());
- auto header_bytes = header.header();
unsigned int len;
- uint64_t seqno;
- if (dtls_) {
- seqno = header.sequence_number();
- } else {
- seqno = in_seqno_;
+ uint64_t seqno = dtls_ ? header.sequence_number() : in_seqno_;
+ SECStatus rv;
+
+ if (header.is_dtls13_ciphertext()) {
+ if (!mask_ || !out_header) {
+ return false;
+ }
+ PORT_Assert(ciphertext.len() >= 16);
+ DataBuffer mask(2);
+ rv = SSL_CreateMask(mask_.get(), ciphertext.data(), ciphertext.len(),
+ mask.data(), mask.len());
+ if (rv != SECSuccess) {
+ return false;
+ }
+
+ if (!out_header->MaskSequenceNumber(mask)) {
+ return false;
+ }
+ seqno = out_header->sequence_number();
}
- SECStatus rv =
- SSL_AeadDecrypt(aead_.get(), seqno, header_bytes.data(),
- header_bytes.len(), ciphertext.data(), ciphertext.len(),
- plaintext->data(), &len, plaintext->len());
+
+ auto header_bytes = out_header->header();
+ rv = SSL_AeadDecrypt(aead_.get(), seqno, header_bytes.data(),
+ header_bytes.len(), ciphertext.data(), ciphertext.len(),
+ plaintext->data(), &len, plaintext->len());
if (rv != SECSuccess) {
return false;
}
@@ -69,11 +96,14 @@ bool TlsCipherSpec::Unprotect(const TlsRecordHeader& header,
}
bool TlsCipherSpec::Protect(const TlsRecordHeader& header,
- const DataBuffer& plaintext,
- DataBuffer* ciphertext) {
- if (aead_ == nullptr) {
+ const DataBuffer& plaintext, DataBuffer* ciphertext,
+ TlsRecordHeader* out_header) {
+ if (!aead_ || !out_header) {
return false;
}
+
+ *out_header = header;
+
// Make a padded buffer.
ciphertext->Allocate(plaintext.len() +
32); // Room for any plausible auth tag
@@ -81,12 +111,7 @@ bool TlsCipherSpec::Protect(const TlsRecordHeader& header,
DataBuffer header_bytes;
(void)header.WriteHeader(&header_bytes, 0, plaintext.len() + 16);
- uint64_t seqno;
- if (dtls_) {
- seqno = header.sequence_number();
- } else {
- seqno = out_seqno_;
- }
+ uint64_t seqno = dtls_ ? header.sequence_number() : out_seqno_;
SECStatus rv =
SSL_AeadEncrypt(aead_.get(), seqno, header_bytes.data(),
@@ -96,6 +121,22 @@ bool TlsCipherSpec::Protect(const TlsRecordHeader& header,
return false;
}
+ if (header.is_dtls13_ciphertext()) {
+ if (!mask_ || !out_header) {
+ return false;
+ }
+ PORT_Assert(ciphertext->len() >= 16);
+ DataBuffer mask(2);
+ rv = SSL_CreateMask(mask_.get(), ciphertext->data(), ciphertext->len(),
+ mask.data(), mask.len());
+ if (rv != SECSuccess) {
+ return false;
+ }
+ if (!out_header->MaskSequenceNumber(mask)) {
+ return false;
+ }
+ }
+
RecordProtected();
ciphertext->Truncate(len);
diff --git a/gtests/ssl_gtest/tls_protect.h b/gtests/ssl_gtest/tls_protect.h
index b1febf887..d7ea2aa12 100644
--- a/gtests/ssl_gtest/tls_protect.h
+++ b/gtests/ssl_gtest/tls_protect.h
@@ -27,9 +27,9 @@ class TlsCipherSpec {
bool SetKeys(SSLCipherSuiteInfo* cipherinfo, PK11SymKey* secret);
bool Protect(const TlsRecordHeader& header, const DataBuffer& plaintext,
- DataBuffer* ciphertext);
+ DataBuffer* ciphertext, TlsRecordHeader* out_header);
bool Unprotect(const TlsRecordHeader& header, const DataBuffer& ciphertext,
- DataBuffer* plaintext);
+ DataBuffer* plaintext, TlsRecordHeader* out_header);
uint16_t epoch() const { return epoch_; }
uint64_t next_in_seqno() const { return in_seqno_; }
@@ -52,6 +52,7 @@ class TlsCipherSpec {
uint64_t out_seqno_;
bool record_dropped_ = false;
ScopedSSLAeadContext aead_;
+ ScopedSSLMaskingContext mask_;
};
} // namespace nss_test
diff --git a/lib/ssl/dtls13con.c b/lib/ssl/dtls13con.c
index 0c4fc7fcd..c87e0907a 100644
--- a/lib/ssl/dtls13con.c
+++ b/lib/ssl/dtls13con.c
@@ -10,38 +10,52 @@
#include "ssl.h"
#include "sslimpl.h"
#include "sslproto.h"
+#include "keyhi.h"
+#include "pk11func.h"
+/*
+ * 0 1 2 3 4 5 6 7
+ * +-+-+-+-+-+-+-+-+
+ * |0|0|1|C|S|L|E E|
+ * +-+-+-+-+-+-+-+-+
+ * | Connection ID | Legend:
+ * | (if any, |
+ * / length as / C - CID present
+ * | negotiated) | S - Sequence number length
+ * +-+-+-+-+-+-+-+-+ L - Length present
+ * | 8 or 16 bit | E - Epoch
+ * |Sequence Number|
+ * +-+-+-+-+-+-+-+-+
+ * | 16 bit Length |
+ * | (if present) |
+ * +-+-+-+-+-+-+-+-+
+ */
SECStatus
-dtls13_InsertCipherTextHeader(const sslSocket *ss, ssl3CipherSpec *cwSpec,
+dtls13_InsertCipherTextHeader(const sslSocket *ss, const ssl3CipherSpec *cwSpec,
sslBuffer *wrBuf, PRBool *needsLength)
{
- PRUint32 seq;
- SECStatus rv;
-
/* Avoid using short records for the handshake. We pack multiple records
* into the one datagram for the handshake. */
if (ss->opt.enableDtlsShortHeader &&
- cwSpec->epoch != TrafficKeyHandshake) {
+ cwSpec->epoch > TrafficKeyHandshake) {
*needsLength = PR_FALSE;
/* The short header is comprised of two octets in the form
- * 0b001essssssssssss where 'e' is the low bit of the epoch and 's' is
- * the low 12 bits of the sequence number. */
- seq = 0x2000 |
- (((uint64_t)cwSpec->epoch & 1) << 12) |
- (cwSpec->nextSeqNum & 0xfff);
- return sslBuffer_AppendNumber(wrBuf, seq, 2);
+ * 0b001000eessssssss where 'e' is the low two bits of the
+ * epoch and 's' is the low 8 bits of the sequence number. */
+ PRUint8 ct = 0x20 | ((uint64_t)cwSpec->epoch & 0x3);
+ if (sslBuffer_AppendNumber(wrBuf, ct, 1) != SECSuccess) {
+ return SECFailure;
+ }
+ PRUint8 seq = cwSpec->nextSeqNum & 0xff;
+ return sslBuffer_AppendNumber(wrBuf, seq, 1);
}
- rv = sslBuffer_AppendNumber(wrBuf, ssl_ct_application_data, 1);
- if (rv != SECSuccess) {
+ PRUint8 ct = 0x2c | ((PRUint8)cwSpec->epoch & 0x3);
+ if (sslBuffer_AppendNumber(wrBuf, ct, 1) != SECSuccess) {
return SECFailure;
}
-
- /* The epoch and sequence number are encoded on 4 octets, with the epoch
- * consuming the first two bits. */
- seq = (((uint64_t)cwSpec->epoch & 3) << 30) | (cwSpec->nextSeqNum & 0x3fffffff);
- rv = sslBuffer_AppendNumber(wrBuf, seq, 4);
- if (rv != SECSuccess) {
+ if (sslBuffer_AppendNumber(wrBuf,
+ (cwSpec->nextSeqNum & 0xffff), 2) != SECSuccess) {
return SECFailure;
}
*needsLength = PR_TRUE;
@@ -512,3 +526,43 @@ dtls13_HolddownTimerCb(sslSocket *ss)
ssl_CipherSpecReleaseByEpoch(ss, ssl_secret_read, TrafficKeyHandshake);
ssl_ClearPRCList(&ss->ssl3.hs.dtlsRcvdHandshake, NULL);
}
+
+SECStatus
+dtls13_MaskSequenceNumber(sslSocket *ss, ssl3CipherSpec *spec,
+ PRUint8 *hdr, PRUint8 *cipherText, PRUint32 cipherTextLen)
+{
+ PORT_Assert(IS_DTLS(ss));
+ if (spec->version < SSL_LIBRARY_VERSION_TLS_1_3) {
+ return SECSuccess;
+ }
+
+ if (spec->maskContext) {
+ PRUint8 mask[2];
+ SECStatus rv = ssl_CreateMaskInner(spec->maskContext, cipherText, cipherTextLen, mask, sizeof(mask));
+
+ if (rv != SECSuccess) {
+ return SECFailure;
+ }
+
+ hdr[1] ^= mask[0];
+ if (hdr[0] & 0x08) {
+ hdr[2] ^= mask[1];
+ }
+ }
+
+ return SECSuccess;
+}
+
+CK_MECHANISM_TYPE
+tls13_SequenceNumberEncryptionMechanism(SSLCipherAlgorithm bulkAlgorithm)
+{
+ switch (bulkAlgorithm) {
+ case ssl_calg_aes_gcm:
+ return CKM_AES_ECB;
+ case ssl_calg_chacha20:
+ return CKM_NSS_CHACHA20_CTR;
+ default:
+ PORT_Assert(PR_FALSE);
+ }
+ return CKM_INVALID_MECHANISM;
+}
diff --git a/lib/ssl/dtls13con.h b/lib/ssl/dtls13con.h
index ce92a8a55..057d63efb 100644
--- a/lib/ssl/dtls13con.h
+++ b/lib/ssl/dtls13con.h
@@ -10,7 +10,7 @@
#define __dtls13con_h_
SECStatus dtls13_InsertCipherTextHeader(const sslSocket *ss,
- ssl3CipherSpec *cwSpec,
+ const ssl3CipherSpec *cwSpec,
sslBuffer *wrBuf,
PRBool *needsLength);
SECStatus dtls13_RememberFragment(sslSocket *ss, PRCList *list,
@@ -29,5 +29,9 @@ SECStatus dtls13_SendAck(sslSocket *ss);
void dtls13_SendAckCb(sslSocket *ss);
void dtls13_HolddownTimerCb(sslSocket *ss);
void dtls_ReceivedFirstMessageInFlight(sslSocket *ss);
+SECStatus dtls13_MaskSequenceNumber(sslSocket *ss, ssl3CipherSpec *spec,
+ PRUint8 *hdr, PRUint8 *cipherText, PRUint32 cipherTextLen);
+
+CK_MECHANISM_TYPE tls13_SequenceNumberEncryptionMechanism(SSLCipherAlgorithm bulkAlgorithm);
#endif
diff --git a/lib/ssl/dtlscon.c b/lib/ssl/dtlscon.c
index 9417063f1..ae84b81d9 100644
--- a/lib/ssl/dtlscon.c
+++ b/lib/ssl/dtlscon.c
@@ -1335,6 +1335,14 @@ dtls_IsLongHeader(SSL3ProtocolVersion version, PRUint8 firstOctet)
#endif
}
+PRBool
+dtls_IsDtls13Ciphertext(SSL3ProtocolVersion version, PRUint8 firstOctet)
+{
+ // Allow no version in case we haven't negotiated one yet.
+ return (version == 0 || version >= SSL_LIBRARY_VERSION_TLS_1_3) &&
+ (firstOctet & 0xe0) == 0x20;
+}
+
DTLSEpoch
dtls_ReadEpoch(const ssl3CipherSpec *crSpec, const PRUint8 *hdr)
{
@@ -1349,13 +1357,12 @@ dtls_ReadEpoch(const ssl3CipherSpec *crSpec, const PRUint8 *hdr)
/* A lot of how we recover the epoch here will depend on how we plan to
* manage KeyUpdate. In the case that we decide to install a new read spec
* as a KeyUpdate is handled, crSpec will always be the highest epoch we can
- * possibly receive. That makes this easier to manage. */
- if ((hdr[0] & 0xe0) == 0x20) {
+ * possibly receive. That makes this easier to manage.
+ */
+ if (dtls_IsDtls13Ciphertext(crSpec->version, hdr[0])) {
+ /* TODO(ekr@rtfm.com: do something with the two-bit epoch. */
/* Use crSpec->epoch, or crSpec->epoch - 1 if the last bit differs. */
- if (((hdr[0] >> 4) & 1) == (crSpec->epoch & 1)) {
- return crSpec->epoch;
- }
- return crSpec->epoch - 1;
+ return crSpec->epoch - ((hdr[0] ^ crSpec->epoch) & 0x3);
}
/* dtls_GatherData should ensure that this works. */
@@ -1398,20 +1405,15 @@ dtls_ReadSequenceNumber(const ssl3CipherSpec *spec, const PRUint8 *hdr)
* sequence number is replaced. If that causes the value to exceed the
* maximum, subtract an entire range.
*/
- if ((hdr[0] & 0xe0) == 0x20) {
- /* A 12-bit sequence number. */
- cap = spec->nextSeqNum + (1ULL << 11);
- partial = (((sslSequenceNumber)hdr[0] & 0xf) << 8) |
- (sslSequenceNumber)hdr[1];
- mask = (1ULL << 12) - 1;
+ if (hdr[0] & 0x08) {
+ cap = spec->nextSeqNum + (1ULL << 15);
+ partial = (((sslSequenceNumber)hdr[1]) << 8) |
+ (sslSequenceNumber)hdr[2];
+ mask = (1ULL << 16) - 1;
} else {
- /* A 30-bit sequence number. */
- cap = spec->nextSeqNum + (1ULL << 29);
- partial = (((sslSequenceNumber)hdr[1] & 0x3f) << 24) |
- ((sslSequenceNumber)hdr[2] << 16) |
- ((sslSequenceNumber)hdr[3] << 8) |
- (sslSequenceNumber)hdr[4];
- mask = (1ULL << 30) - 1;
+ cap = spec->nextSeqNum + (1ULL << 7);
+ partial = (sslSequenceNumber)hdr[1];
+ mask = (1ULL << 8) - 1;
}
seqNum = (cap & ~mask) | partial;
/* The second check prevents the value from underflowing if we get a large
diff --git a/lib/ssl/dtlscon.h b/lib/ssl/dtlscon.h
index 4ede3c2ca..9d10aa248 100644
--- a/lib/ssl/dtlscon.h
+++ b/lib/ssl/dtlscon.h
@@ -47,4 +47,5 @@ extern PRBool dtls_IsRelevant(sslSocket *ss, const ssl3CipherSpec *spec,
sslSequenceNumber *seqNum);
void dtls_ReceivedFirstMessageInFlight(sslSocket *ss);
PRBool dtls_IsLongHeader(SSL3ProtocolVersion version, PRUint8 firstOctet);
+PRBool dtls_IsDtls13Ciphertext(SSL3ProtocolVersion version, PRUint8 firstOctet);
#endif
diff --git a/lib/ssl/ssl3con.c b/lib/ssl/ssl3con.c
index 60b247fd7..e8ea99d82 100644
--- a/lib/ssl/ssl3con.c
+++ b/lib/ssl/ssl3con.c
@@ -2406,7 +2406,6 @@ ssl_ProtectRecord(sslSocket *ss, ssl3CipherSpec *cwSpec, SSLContentType ct,
PORT_Assert(cwSpec->cipherDef->max_records <= RECORD_SEQ_MAX);
if (cwSpec->nextSeqNum >= cwSpec->cipherDef->max_records) {
- /* We should have automatically updated before here in TLS 1.3. */
PORT_Assert(cwSpec->version < SSL_LIBRARY_VERSION_TLS_1_3);
SSL_TRC(3, ("%d: SSL[-]: write sequence number at limit 0x%0llx",
SSL_GETPID(), cwSpec->nextSeqNum));
@@ -2438,7 +2437,28 @@ ssl_ProtectRecord(sslSocket *ss, ssl3CipherSpec *cwSpec, SSLContentType ct,
}
#else
if (cwSpec->version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ PRUint8 *cipherText = SSL_BUFFER_NEXT(wrBuf);
+ unsigned int bufLen = SSL_BUFFER_LEN(wrBuf);
rv = tls13_ProtectRecord(ss, cwSpec, ct, pIn, contentLen, wrBuf);
+ if (rv != SECSuccess) {
+ return SECFailure;
+ }
+ if (IS_DTLS(ss)) {
+ bufLen = SSL_BUFFER_LEN(wrBuf) - bufLen;
+#ifdef UNSAFE_FUZZER_MODE
+ /* The null cipher doesn't add a tag. Make sure the "ciphertext"
+ * is long enough for mask creation. */
+ unsigned char tmpCt[AES_BLOCK_SIZE] = { 0 };
+ if (bufLen < 16) {
+ memcpy(tmpCt, cipherText, bufLen);
+ bufLen = sizeof(tmpCt);
+ cipherText = tmpCt;
+ }
+#endif
+ rv = dtls13_MaskSequenceNumber(ss, cwSpec,
+ SSL_BUFFER_BASE(wrBuf),
+ cipherText, bufLen);
+ }
} else {
rv = ssl3_MACEncryptRecord(cwSpec, ss->sec.isServer, IS_DTLS(ss), ct,
pIn, contentLen, wrBuf);
@@ -12899,6 +12919,24 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText)
}
isTLS = (PRBool)(spec->version > SSL_LIBRARY_VERSION_3_0);
if (IS_DTLS(ss)) {
+ unsigned int bufLen = SSL_BUFFER_LEN(cText->buf);
+ unsigned char *cipherText = SSL_BUFFER_BASE(cText->buf);
+#ifdef UNSAFE_FUZZER_MODE
+ /* The null cipher doesn't add a tag. Make sure the "ciphertext"
+ * is long enough for mask creation. */
+ unsigned char tmpCt[AES_BLOCK_SIZE] = { 0 };
+ if (bufLen < 16) {
+ memcpy(tmpCt, cipherText, bufLen);
+ bufLen = sizeof(tmpCt);
+ cipherText = tmpCt;
+ }
+#endif
+ if (dtls13_MaskSequenceNumber(ss, spec, cText->hdr,
+ cipherText, bufLen) != SECSuccess) {
+ ssl_ReleaseSpecReadLock(ss); /*****************************/
+ PORT_SetError(SSL_ERROR_DECRYPTION_FAILURE);
+ return SECFailure;
+ }
if (!dtls_IsRelevant(ss, spec, cText, &cText->seqNum)) {
ssl_ReleaseSpecReadLock(ss); /*****************************/
return SECSuccess;
@@ -12940,7 +12978,10 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText)
/* Encrypted application data records could arrive before the handshake
* completes in DTLS 1.3. These can look like valid TLS 1.2 application_data
* records in epoch 0, which is never valid. Pretend they didn't decrypt. */
- if (spec->epoch == 0 && rType == ssl_ct_application_data) {
+
+ if (spec->epoch == 0 && ((IS_DTLS(ss) &&
+ dtls_IsDtls13Ciphertext(0, rType)) ||
+ rType == ssl_ct_application_data)) {
PORT_SetError(SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA);
alert = unexpected_message;
rv = SECFailure;
diff --git a/lib/ssl/ssl3gthr.c b/lib/ssl/ssl3gthr.c
index f9c741746..3bc6e8edc 100644
--- a/lib/ssl/ssl3gthr.c
+++ b/lib/ssl/ssl3gthr.c
@@ -268,6 +268,7 @@ dtls_GatherData(sslSocket *ss, sslGather *gs, int flags)
PRUint8 contentType;
unsigned int headerLen;
SECStatus rv;
+ PRBool dtlsLengthPresent = PR_TRUE;
SSL_TRC(30, ("dtls_GatherData"));
@@ -316,8 +317,20 @@ dtls_GatherData(sslSocket *ss, sslGather *gs, int flags)
headerLen = 13;
} else if (contentType == ssl_ct_application_data) {
headerLen = 7;
- } else if ((contentType & 0xe0) == 0x20) {
- headerLen = 2;
+ } else if (dtls_IsDtls13Ciphertext(ss->version, contentType)) {
+ /* We don't support CIDs. */
+ if (contentType & 0x10) {
+ PORT_Assert(PR_FALSE);
+ PORT_SetError(SSL_ERROR_RX_UNKNOWN_RECORD_TYPE);
+ gs->dtlsPacketOffset = 0;
+ gs->dtlsPacket.len = 0;
+ return -1;
+ }
+
+ dtlsLengthPresent = (contentType & 0x04) == 0x04;
+ PRUint8 dtlsSeqNoSize = (contentType & 0x08) ? 2 : 1;
+ PRUint8 dtlsLengthBytes = dtlsLengthPresent ? 2 : 0;
+ headerLen = 1 + dtlsSeqNoSize + dtlsLengthBytes;
} else {
SSL_DBG(("%d: SSL3[%d]: invalid first octet (%d) for DTLS",
SSL_GETPID(), ss->fd, contentType));
@@ -345,12 +358,10 @@ dtls_GatherData(sslSocket *ss, sslGather *gs, int flags)
gs->dtlsPacketOffset += headerLen;
/* Have received SSL3 record header in gs->hdr. */
- if (headerLen == 13) {
- gs->remainder = (gs->hdr[11] << 8) | gs->hdr[12];
- } else if (headerLen == 7) {
- gs->remainder = (gs->hdr[5] << 8) | gs->hdr[6];
+ if (dtlsLengthPresent) {
+ gs->remainder = (gs->hdr[headerLen - 2] << 8) |
+ gs->hdr[headerLen - 1];
} else {
- PORT_Assert(headerLen == 2);
gs->remainder = gs->dtlsPacket.len - gs->dtlsPacketOffset;
}
diff --git a/lib/ssl/ssl3prot.h b/lib/ssl/ssl3prot.h
index ffe837301..b180931e9 100644
--- a/lib/ssl/ssl3prot.h
+++ b/lib/ssl/ssl3prot.h
@@ -14,7 +14,7 @@ typedef PRUint16 SSL3ProtocolVersion;
/* version numbers are defined in sslproto.h */
/* DTLS 1.3 is still a draft. */
-#define DTLS_1_3_DRAFT_VERSION 28
+#define DTLS_1_3_DRAFT_VERSION 30
typedef PRUint16 ssl3CipherSuite;
/* The cipher suites are defined in sslproto.h */
diff --git a/lib/ssl/sslexp.h b/lib/ssl/sslexp.h
index b734d86ca..61b1fc088 100644
--- a/lib/ssl/sslexp.h
+++ b/lib/ssl/sslexp.h
@@ -826,6 +826,56 @@ typedef PRTime(PR_CALLBACK *SSLTimeFunc)(void *arg);
PRUint16 _numCiphers), \
(fd, cipherOrder, numCiphers))
+/*
+ * The following functions expose a masking primitive that uses ciphersuite and
+ * version information to set paramaters for the masking key and mask generation
+ * logic. This is only supported for TLS 1.3.
+ *
+ * The key and IV are generated using the TLS KDF with a custom label. That is
+ * HKDF-Expand-Label(secret, label, "", L), where |label| is an input to
+ * SSL_CreateMaskingContext.
+ *
+ * The mask generation logic in SSL_CreateMask is determined by the underlying
+ * symmetric cipher:
+ * - For AES-ECB, mask = AES-ECB(mask_key, sample). |len| must be <= 16 as
+ * the output is limited to a single block.
+ * - For CHACHA20, mask = ChaCha20(mask_key, sample[0..3], sample[4..15], {0}.len)
+ * That is, the low 4 bytes of |sample| used as the counter, the remaining 12 bytes
+ * the nonce. We encrypt |len| bytes of zeros, returning the raw key stream.
+ *
+ * The caller must pre-allocate at least |len| bytes for output. If the underlying
+ * cipher cannot produce the requested amount of data, SECFailure is returned.
+ */
+
+typedef struct SSLMaskingContextStr {
+ CK_MECHANISM_TYPE mech;
+ PRUint16 version;
+ PRUint16 cipherSuite;
+ PK11SymKey *secret;
+} SSLMaskingContext;
+
+#define SSL_CreateMaskingContext(version, cipherSuite, secret, \
+ label, labelLen, ctx) \
+ SSL_EXPERIMENTAL_API("SSL_CreateMaskingContext", \
+ (PRUint16 _version, PRUint16 _cipherSuite, \
+ PK11SymKey * _secret, \
+ const char *_label, \
+ unsigned int _labelLen, \
+ SSLMaskingContext **_ctx), \
+ (version, cipherSuite, secret, label, labelLen, ctx))
+
+#define SSL_DestroyMaskingContext(ctx) \
+ SSL_EXPERIMENTAL_API("SSL_DestroyMaskingContext", \
+ (SSLMaskingContext * _ctx), \
+ (ctx))
+
+#define SSL_CreateMask(ctx, sample, sampleLen, mask, maskLen) \
+ SSL_EXPERIMENTAL_API("SSL_CreateMask", \
+ (SSLMaskingContext * _ctx, const PRUint8 *_sample, \
+ unsigned int _sampleLen, PRUint8 *_mask, \
+ unsigned int _maskLen), \
+ (ctx, sample, sampleLen, mask, maskLen))
+
/* Deprecated experimental APIs */
#define SSL_UseAltServerHelloType(fd, enable) SSL_DEPRECATED_EXPERIMENTAL_API
#define SSL_SetupAntiReplay(a, b, c) SSL_DEPRECATED_EXPERIMENTAL_API
diff --git a/lib/ssl/sslimpl.h b/lib/ssl/sslimpl.h
index 4a393b281..af789c73e 100644
--- a/lib/ssl/sslimpl.h
+++ b/lib/ssl/sslimpl.h
@@ -810,7 +810,7 @@ typedef struct {
/* |seqNum| eventually contains the reconstructed sequence number. */
sslSequenceNumber seqNum;
/* The header of the cipherText. */
- const PRUint8 *hdr;
+ PRUint8 *hdr;
unsigned int hdrLen;
/* |buf| is the payload of the ciphertext. */
@@ -1849,6 +1849,30 @@ SSLExp_HkdfExpandLabelWithMech(PRUint16 version, PRUint16 cipherSuite, PK11SymKe
SECStatus SSLExp_SetTimeFunc(PRFileDesc *fd, SSLTimeFunc f, void *arg);
+extern SECStatus ssl_CreateMaskingContextInner(PRUint16 version, PRUint16 cipherSuite,
+ PK11SymKey *secret,
+ const char *label,
+ unsigned int labelLen,
+ SSLMaskingContext **ctx);
+
+extern SECStatus ssl_CreateMaskInner(SSLMaskingContext *ctx, const PRUint8 *sample,
+ unsigned int sampleLen, PRUint8 *outMask,
+ unsigned int maskLen);
+
+extern SECStatus ssl_DestroyMaskingContextInner(SSLMaskingContext *ctx);
+
+SECStatus SSLExp_CreateMaskingContext(PRUint16 version, PRUint16 cipherSuite,
+ PK11SymKey *secret,
+ const char *label,
+ unsigned int labelLen,
+ SSLMaskingContext **ctx);
+
+SECStatus SSLExp_CreateMask(SSLMaskingContext *ctx, const PRUint8 *sample,
+ unsigned int sampleLen, PRUint8 *mask,
+ unsigned int len);
+
+SECStatus SSLExp_DestroyMaskingContext(SSLMaskingContext *ctx);
+
SEC_END_PROTOS
#if defined(XP_UNIX) || defined(XP_OS2) || defined(XP_BEOS)
diff --git a/lib/ssl/sslprimitive.c b/lib/ssl/sslprimitive.c
index 540c17840..5522f96fd 100644
--- a/lib/ssl/sslprimitive.c
+++ b/lib/ssl/sslprimitive.c
@@ -6,6 +6,7 @@
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+#include "blapit.h"
#include "keyhi.h"
#include "pk11pub.h"
#include "sechash.h"
@@ -23,34 +24,6 @@ struct SSLAeadContextStr {
ssl3KeyMaterial keys;
};
-static SECStatus
-tls13_GetHashAndCipher(PRUint16 version, PRUint16 cipherSuite,
- SSLHashType *hash, const ssl3BulkCipherDef **cipher)
-{
- if (version < SSL_LIBRARY_VERSION_TLS_1_3) {
- PORT_SetError(SEC_ERROR_INVALID_ARGS);
- return SECFailure;
- }
-
- // Lookup and check the suite.
- SSLVersionRange vrange = { version, version };
- if (!ssl3_CipherSuiteAllowedForVersionRange(cipherSuite, &vrange)) {
- PORT_SetError(SEC_ERROR_INVALID_ARGS);
- return SECFailure;
- }
- const ssl3CipherSuiteDef *suiteDef = ssl_LookupCipherSuiteDef(cipherSuite);
- const ssl3BulkCipherDef *cipherDef = ssl_GetBulkCipherDef(suiteDef);
- if (cipherDef->type != type_aead) {
- PORT_SetError(SEC_ERROR_INVALID_ARGS);
- return SECFailure;
- }
- *hash = suiteDef->prf_hash;
- if (cipher != NULL) {
- *cipher = cipherDef;
- }
- return SECSuccess;
-}
-
SECStatus
SSLExp_MakeAead(PRUint16 version, PRUint16 cipherSuite, PK11SymKey *secret,
const char *labelPrefix, unsigned int labelPrefixLen,
@@ -272,3 +245,179 @@ SSLExp_HkdfExpandLabelWithMech(PRUint16 version, PRUint16 cipherSuite, PK11SymKe
return tls13_HkdfExpandLabel(prk, hash, hsHash, hsHashLen, label, labelLen,
mech, keySize, keyp);
}
+
+SECStatus
+ssl_CreateMaskingContextInner(PRUint16 version, PRUint16 cipherSuite,
+ PK11SymKey *secret,
+ const char *label,
+ unsigned int labelLen,
+ SSLMaskingContext **ctx)
+{
+ if (!secret || !ctx || (!label && labelLen)) {
+ PORT_SetError(SEC_ERROR_INVALID_ARGS);
+ return SECFailure;
+ }
+
+ SSLMaskingContext *out = PORT_ZNew(SSLMaskingContext);
+ if (out == NULL) {
+ goto loser;
+ }
+
+ SSLHashType hash;
+ const ssl3BulkCipherDef *cipher;
+ SECStatus rv = tls13_GetHashAndCipher(version, cipherSuite,
+ &hash, &cipher);
+ if (rv != SECSuccess) {
+ PORT_SetError(SEC_ERROR_INVALID_ARGS);
+ goto loser; /* Code already set. */
+ }
+
+ out->mech = tls13_SequenceNumberEncryptionMechanism(cipher->calg);
+ if (out->mech == CKM_INVALID_MECHANISM) {
+ PORT_SetError(SEC_ERROR_INVALID_ARGS);
+ goto loser;
+ }
+
+ // Derive the masking key
+ rv = tls13_HkdfExpandLabel(secret, hash,
+ NULL, 0, // Handshake hash.
+ label, labelLen,
+ out->mech,
+ cipher->key_size, &out->secret);
+ if (rv != SECSuccess) {
+ goto loser;
+ }
+
+ out->version = version;
+ out->cipherSuite = cipherSuite;
+
+ *ctx = out;
+ return SECSuccess;
+loser:
+ SSLExp_DestroyMaskingContext(out);
+ return SECFailure;
+}
+
+SECStatus
+ssl_CreateMaskInner(SSLMaskingContext *ctx, const PRUint8 *sample,
+ unsigned int sampleLen, PRUint8 *outMask,
+ unsigned int maskLen)
+{
+ if (!ctx || !sample || !sampleLen || !outMask || !maskLen) {
+ PORT_SetError(SEC_ERROR_INVALID_ARGS);
+ return SECFailure;
+ }
+
+ if (ctx->secret == NULL) {
+ PORT_SetError(SEC_ERROR_NO_KEY);
+ return SECFailure;
+ }
+
+ SECStatus rv = SECFailure;
+ unsigned int outMaskLen = 0;
+
+ /* Internal output len/buf, for use if the caller allocated and requested
+ * less than one block of output. |oneBlock| should have size equal to the
+ * largest block size supported below. */
+ PRUint8 oneBlock[AES_BLOCK_SIZE];
+ PRUint8 *outMask_ = outMask;
+ unsigned int maskLen_ = maskLen;
+
+ switch (ctx->mech) {
+ case CKM_AES_ECB:
+ if (sampleLen < AES_BLOCK_SIZE) {
+ PORT_SetError(SEC_ERROR_INVALID_ARGS);
+ return SECFailure;
+ }
+ if (maskLen_ < AES_BLOCK_SIZE) {
+ outMask_ = oneBlock;
+ maskLen_ = sizeof(oneBlock);
+ }
+ rv = PK11_Encrypt(ctx->secret,
+ ctx->mech,
+ NULL,
+ outMask_, &outMaskLen, maskLen_,
+ sample, AES_BLOCK_SIZE);
+ if (rv == SECSuccess &&
+ maskLen < AES_BLOCK_SIZE) {
+ memcpy(outMask, outMask_, maskLen);
+ }
+ break;
+ case CKM_NSS_CHACHA20_CTR:
+ if (sampleLen < 16) {
+ PORT_SetError(SEC_ERROR_INVALID_ARGS);
+ return SECFailure;
+ }
+
+ SECItem param;
+ param.type = siBuffer;
+ param.len = 16;
+ param.data = (PRUint8 *)sample; // const-cast :(
+ unsigned char zeros[128] = { 0 };
+
+ if (maskLen > sizeof(zeros)) {
+ PORT_SetError(SEC_ERROR_INVALID_ARGS);
+ return SECFailure;
+ }
+
+ rv = PK11_Encrypt(ctx->secret,
+ ctx->mech,
+ &param,
+ outMask, &outMaskLen,
+ maskLen,
+ zeros, maskLen);
+ break;
+ default:
+ PORT_SetError(SEC_ERROR_INVALID_ARGS);
+ return SECFailure;
+ }
+
+ if (rv != SECSuccess) {
+ PORT_SetError(SEC_ERROR_PKCS11_FUNCTION_FAILED);
+ return SECFailure;
+ }
+
+ // Ensure we produced at least as much material as requested.
+ if (outMaskLen < maskLen) {
+ PORT_SetError(SEC_ERROR_OUTPUT_LEN);
+ return SECFailure;
+ }
+
+ return SECSuccess;
+}
+
+SECStatus
+ssl_DestroyMaskingContextInner(SSLMaskingContext *ctx)
+{
+ if (!ctx) {
+ return SECSuccess;
+ }
+
+ PK11_FreeSymKey(ctx->secret);
+ PORT_ZFree(ctx, sizeof(*ctx));
+ return SECSuccess;
+}
+
+SECStatus
+SSLExp_CreateMask(SSLMaskingContext *ctx, const PRUint8 *sample,
+ unsigned int sampleLen, PRUint8 *outMask,
+ unsigned int maskLen)
+{
+ return ssl_CreateMaskInner(ctx, sample, sampleLen, outMask, maskLen);
+}
+
+SECStatus
+SSLExp_CreateMaskingContext(PRUint16 version, PRUint16 cipherSuite,
+ PK11SymKey *secret,
+ const char *label,
+ unsigned int labelLen,
+ SSLMaskingContext **ctx)
+{
+ return ssl_CreateMaskingContextInner(version, cipherSuite, secret, label, labelLen, ctx);
+}
+
+SECStatus
+SSLExp_DestroyMaskingContext(SSLMaskingContext *ctx)
+{
+ return ssl_DestroyMaskingContextInner(ctx);
+}
diff --git a/lib/ssl/sslsock.c b/lib/ssl/sslsock.c
index aa0e76e3c..581f0c467 100644
--- a/lib/ssl/sslsock.c
+++ b/lib/ssl/sslsock.c
@@ -4220,8 +4220,11 @@ struct {
EXP(CipherSuiteOrderGet),
EXP(CipherSuiteOrderSet),
EXP(CreateAntiReplayContext),
+ EXP(CreateMask),
+ EXP(CreateMaskingContext),
EXP(DelegateCredential),
EXP(DestroyAead),
+ EXP(DestroyMaskingContext),
EXP(DestroyResumptionTokenInfo),
EXP(EnableESNI),
EXP(EncodeESNIKeys),
diff --git a/lib/ssl/sslspec.c b/lib/ssl/sslspec.c
index def3c6750..c5bedad7a 100644
--- a/lib/ssl/sslspec.c
+++ b/lib/ssl/sslspec.c
@@ -7,6 +7,8 @@
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
#include "ssl.h"
+#include "sslexp.h"
+#include "sslimpl.h"
#include "sslproto.h"
#include "pk11func.h"
#include "secitem.h"
@@ -227,6 +229,7 @@ ssl_FreeCipherSpec(ssl3CipherSpec *spec)
}
PK11_FreeSymKey(spec->masterSecret);
ssl_DestroyKeyMaterial(&spec->keyMaterial);
+ ssl_DestroyMaskingContextInner(spec->maskContext);
PORT_ZFree(spec, sizeof(*spec));
}
diff --git a/lib/ssl/sslspec.h b/lib/ssl/sslspec.h
index ca9ef540f..d00b20d76 100644
--- a/lib/ssl/sslspec.h
+++ b/lib/ssl/sslspec.h
@@ -169,6 +169,9 @@ struct ssl3CipherSpecStr {
* negotiated value for TLS 1.3; it is reduced by one to account for the
* content type octet. */
PRUint16 recordSizeLimit;
+
+ /* Masking context used for DTLS 1.3 */
+ SSLMaskingContext *maskContext;
};
typedef void (*sslCipherSpecChangedFunc)(void *arg,
diff --git a/lib/ssl/tls13con.c b/lib/ssl/tls13con.c
index c3528a52f..97c191872 100644
--- a/lib/ssl/tls13con.c
+++ b/lib/ssl/tls13con.c
@@ -131,6 +131,7 @@ const char kHkdfLabelExporterMasterSecret[] = "exp master";
const char kHkdfLabelResumption[] = "resumption";
const char kHkdfLabelTrafficUpdate[] = "traffic upd";
const char kHkdfPurposeKey[] = "key";
+const char kHkdfPurposeSn[] = "sn";
const char kHkdfPurposeIv[] = "iv";
const char keylogLabelClientEarlyTrafficSecret[] = "CLIENT_EARLY_TRAFFIC_SECRET";
@@ -286,6 +287,34 @@ tls13_GetHash(const sslSocket *ss)
return ss->ssl3.hs.suite_def->prf_hash;
}
+SECStatus
+tls13_GetHashAndCipher(PRUint16 version, PRUint16 cipherSuite,
+ SSLHashType *hash, const ssl3BulkCipherDef **cipher)
+{
+ if (version < SSL_LIBRARY_VERSION_TLS_1_3) {
+ PORT_SetError(SEC_ERROR_INVALID_ARGS);
+ return SECFailure;
+ }
+
+ // Lookup and check the suite.
+ SSLVersionRange vrange = { version, version };
+ if (!ssl3_CipherSuiteAllowedForVersionRange(cipherSuite, &vrange)) {
+ PORT_SetError(SEC_ERROR_INVALID_ARGS);
+ return SECFailure;
+ }
+ const ssl3CipherSuiteDef *suiteDef = ssl_LookupCipherSuiteDef(cipherSuite);
+ const ssl3BulkCipherDef *cipherDef = ssl_GetBulkCipherDef(suiteDef);
+ if (cipherDef->type != type_aead) {
+ PORT_SetError(SEC_ERROR_INVALID_ARGS);
+ return SECFailure;
+ }
+ *hash = suiteDef->prf_hash;
+ if (cipher != NULL) {
+ *cipher = cipherDef;
+ }
+ return SECSuccess;
+}
+
unsigned int
tls13_GetHashSizeForHash(SSLHashType hash)
{
@@ -3474,6 +3503,17 @@ tls13_DeriveTrafficKeys(sslSocket *ss, ssl3CipherSpec *spec,
goto loser;
}
+ if (IS_DTLS(ss) && spec->epoch > 0) {
+ rv = ssl_CreateMaskingContextInner(spec->version,
+ ss->ssl3.hs.cipher_suite, prk, kHkdfPurposeSn,
+ strlen(kHkdfPurposeSn), &spec->maskContext);
+ if (rv != SECSuccess) {
+ LOG_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE);
+ PORT_Assert(0);
+ goto loser;
+ }
+ }
+
rv = tls13_HkdfExpandLabelRaw(prk, tls13_GetHash(ss),
NULL, 0,
kHkdfPurposeIv, strlen(kHkdfPurposeIv),
diff --git a/lib/ssl/tls13con.h b/lib/ssl/tls13con.h
index bd309419f..0160740a4 100644
--- a/lib/ssl/tls13con.h
+++ b/lib/ssl/tls13con.h
@@ -44,10 +44,12 @@ PRBool tls13_InHsState(sslSocket *ss, ...);
PRBool tls13_IsPostHandshake(const sslSocket *ss);
-SSLHashType tls13_GetHashForCipherSuite(ssl3CipherSuite suite);
SSLHashType tls13_GetHash(const sslSocket *ss);
-unsigned int tls13_GetHashSizeForHash(SSLHashType hash);
+SECStatus tls13_GetHashAndCipher(PRUint16 version, PRUint16 cipherSuite,
+ SSLHashType *hash, const ssl3BulkCipherDef **cipher);
+SSLHashType tls13_GetHashForCipherSuite(ssl3CipherSuite suite);
unsigned int tls13_GetHashSize(const sslSocket *ss);
+unsigned int tls13_GetHashSizeForHash(SSLHashType hash);
CK_MECHANISM_TYPE tls13_GetHkdfMechanism(sslSocket *ss);
CK_MECHANISM_TYPE tls13_GetHkdfMechanismForHash(SSLHashType hash);
SECStatus tls13_ComputeHash(sslSocket *ss, SSL3Hashes *hashes,