From b0a0fdd8f4f847d0942e9a7f8464a7fb8ae94921 Mon Sep 17 00:00:00 2001 From: Kevin Jacobs Date: Thu, 27 Feb 2020 02:39:11 +0000 Subject: Bug 1608892 - Update DTLS 1.3 to draft-34 r=mt This patch updates the DTLS 1.3 implementation to draft-34. Notable changes: 1) Key separation via `ssl_protocol_variant`. 2) No longer apply sequence number masking when in `UNSAFE_FUZZER_MODE`. This allowed removal of workarounds for unpadded (<16B) ciphertexts being used as input to `SSL_CreateMask`. 3) Compile ssl_gtests in `UNSAFE_FUZZER_MODE` iff `--fuzz=tls` was specified. Currently all gtests are compiled this way if `--fuzz`, but lib/ssl only if `--fuzz=tls`. (See above, we can't have ssl_gtests in fuzzer mode, but not lib/ssl, since the masking mismatch will break filters). 4) Parameterize masking tests, as appropriate. 5) Reject non-empty legacy_cookie, and test. 6) Reject ciphertexts <16B in length in `dtls13_MaskSequenceNumber` (if not `UNSAFE_FUZZER_MODE`). Differential Revision: https://phabricator.services.mozilla.com/D62488 --- gtests/ssl_gtest/ssl_extension_unittest.cc | 47 ++++ gtests/ssl_gtest/ssl_gtest.gyp | 12 ++ gtests/ssl_gtest/ssl_masking_unittest.cc | 335 +++++++++++++++-------------- gtests/ssl_gtest/tls_filter.cc | 12 +- gtests/ssl_gtest/tls_filter.h | 2 +- gtests/ssl_gtest/tls_hkdf_unittest.cc | 6 +- gtests/ssl_gtest/tls_protect.cc | 16 +- lib/ssl/dtls13con.c | 15 +- lib/ssl/ssl3con.c | 43 ++-- lib/ssl/ssl3prot.h | 2 +- lib/ssl/sslexp.h | 79 ++++++- lib/ssl/sslimpl.h | 22 ++ lib/ssl/sslinfo.c | 2 +- lib/ssl/sslprimitive.c | 70 ++++-- lib/ssl/sslsock.c | 4 + lib/ssl/tls13con.c | 28 ++- lib/ssl/tls13esni.c | 4 +- lib/ssl/tls13hkdf.c | 15 +- lib/ssl/tls13hkdf.h | 5 +- lib/ssl/tls13replay.c | 2 +- 20 files changed, 476 insertions(+), 245 deletions(-) diff --git a/gtests/ssl_gtest/ssl_extension_unittest.cc b/gtests/ssl_gtest/ssl_extension_unittest.cc index 837b6c9c4..b85568f43 100644 --- a/gtests/ssl_gtest/ssl_extension_unittest.cc +++ b/gtests/ssl_gtest/ssl_extension_unittest.cc @@ -20,6 +20,45 @@ namespace nss_test { +class Dtls13LegacyCookieInjector : public TlsHandshakeFilter { + public: + Dtls13LegacyCookieInjector(const std::shared_ptr& a) + : TlsHandshakeFilter(a, {kTlsHandshakeClientHello}) {} + + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) { + const uint8_t cookie_bytes[] = {0x03, 0x0A, 0x0B, 0x0C}; + uint32_t offset = 2 /* version */ + 32 /* random */; + + if (agent()->variant() != ssl_variant_datagram) { + ADD_FAILURE(); + return KEEP; + } + + if (header.handshake_type() != ssl_hs_client_hello) { + return KEEP; + } + + DataBuffer cookie(cookie_bytes, sizeof(cookie_bytes)); + *output = input; + + // Add the SID length (if any) to locate the cookie. + uint32_t sid_len = 0; + if (!output->Read(offset, 1, &sid_len)) { + ADD_FAILURE(); + return KEEP; + } + offset += 1 + sid_len; + output->Splice(cookie, offset, 1); + + return CHANGE; + } + + private: + DataBuffer cookie_; +}; + class TlsExtensionTruncator : public TlsExtensionFilter { public: TlsExtensionTruncator(const std::shared_ptr& a, uint16_t extension, @@ -1246,6 +1285,14 @@ TEST_P(TlsConnectStream, IncludePadding) { EXPECT_TRUE(capture->captured()); } +TEST_F(TlsConnectDatagram13, Dtls13RejectLegacyCookie) { + EnsureTlsSetup(); + MakeTlsFilter(client_); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + INSTANTIATE_TEST_CASE_P( ExtensionStream, TlsExtensionTestGeneric, ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, diff --git a/gtests/ssl_gtest/ssl_gtest.gyp b/gtests/ssl_gtest/ssl_gtest.gyp index ae79c41fe..c44af7ed1 100644 --- a/gtests/ssl_gtest/ssl_gtest.gyp +++ b/gtests/ssl_gtest/ssl_gtest.gyp @@ -104,6 +104,18 @@ 'NSS_ALLOW_SSLKEYLOGFILE', ], }], + # ssl_gtest fuzz defines should only be determined by the 'fuzz_tls' + # flag (so as to match lib/ssl). If gtest.gypi added the define due + # to '--fuzz' only, remove it. + ['fuzz_tls==1', { + 'defines': [ + 'UNSAFE_FUZZER_MODE', + ], + }, { + 'defines!': [ + 'UNSAFE_FUZZER_MODE', + ], + }], ], } ], diff --git a/gtests/ssl_gtest/ssl_masking_unittest.cc b/gtests/ssl_gtest/ssl_masking_unittest.cc index 5b63b945b..cf0553cbb 100644 --- a/gtests/ssl_gtest/ssl_masking_unittest.cc +++ b/gtests/ssl_gtest/ssl_masking_unittest.cc @@ -46,23 +46,25 @@ class MaskingTest : public ::testing::Test { protected: ScopedPK11SymKey secret_; ScopedPK11SlotInfo slot_; - void CreateMask(PRUint16 ciphersuite, std::string label, - const std::vector &sample, + // Should have 4B ctr, 12B nonce for ChaCha, or >=16B ciphertext for AES. + // Use the same default size for mask output. + static const int kSampleSize = 16; + static const int kMaskSize = 16; + void CreateMask(PRUint16 ciphersuite, SSLProtocolVariant variant, + std::string label, const std::vector &sample, std::vector *out_mask) { ASSERT_NE(nullptr, out_mask); SSLMaskingContext *ctx_init = nullptr; EXPECT_EQ(SECSuccess, - SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite, - secret_.get(), label.c_str(), - label.size(), &ctx_init)); - EXPECT_EQ(0, PORT_GetError()); + SSL_CreateVariantMaskingContext( + SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite, variant, + secret_.get(), label.c_str(), label.size(), &ctx_init)); ASSERT_NE(nullptr, ctx_init); ScopedSSLMaskingContext ctx(ctx_init); EXPECT_EQ(SECSuccess, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), out_mask->data(), out_mask->size())); - EXPECT_EQ(0, PORT_GetError()); bool all_zeros = std::all_of(out_mask->begin(), out_mask->end(), [](uint8_t v) { return v == 0; }); @@ -79,7 +81,6 @@ class MaskingTest : public ::testing::Test { EXPECT_EQ(SECSuccess, SSL_CreateMask(ctx.get(), tmp_sample.data(), tmp_sample.size(), tmp_mask.data(), tmp_mask.size())); - EXPECT_EQ(0, PORT_GetError()); bool retry_zero = std::all_of(tmp_mask.begin(), tmp_mask.end(), [](uint8_t v) { return v == 0; }); if (!retry_zero) { @@ -92,186 +93,233 @@ class MaskingTest : public ::testing::Test { } }; -TEST_F(MaskingTest, MaskContextNoLabel) { - std::vector sample(AES_BLOCK_SIZE); - std::vector mask(AES_BLOCK_SIZE); - CreateMask(TLS_AES_128_GCM_SHA256, std::string(""), sample, &mask); +class SuiteTest : public MaskingTest, + public ::testing::WithParamInterface { + public: + SuiteTest() : ciphersuite_(GetParam()) {} + void CreateMask(std::string label, const std::vector &sample, + std::vector *out_mask) { + MaskingTest::CreateMask(ciphersuite_, ssl_variant_datagram, label, sample, + out_mask); + } + + protected: + const uint16_t ciphersuite_; +}; + +class VariantTest : public MaskingTest, + public ::testing::WithParamInterface { + public: + VariantTest() : variant_(GetParam()) {} + void CreateMask(uint16_t ciphersuite, std::string label, + const std::vector &sample, + std::vector *out_mask) { + MaskingTest::CreateMask(ciphersuite, variant_, label, sample, out_mask); + } + + protected: + const SSLProtocolVariant variant_; +}; + +class VariantSuiteTest : public MaskingTest, + public ::testing::WithParamInterface< + std::tuple> { + public: + VariantSuiteTest() + : variant_(std::get<0>(GetParam())), + ciphersuite_(std::get<1>(GetParam())) {} + void CreateMask(std::string label, const std::vector &sample, + std::vector *out_mask) { + MaskingTest::CreateMask(ciphersuite_, variant_, label, sample, out_mask); + } + + protected: + const SSLProtocolVariant variant_; + const uint16_t ciphersuite_; +}; + +TEST_P(VariantSuiteTest, MaskContextNoLabel) { + std::vector sample(kSampleSize); + std::vector mask(kMaskSize); + CreateMask(std::string(""), sample, &mask); } -TEST_F(MaskingTest, MaskContextUnsupportedMech) { - std::vector sample(AES_BLOCK_SIZE); - std::vector mask(AES_BLOCK_SIZE); +TEST_P(VariantSuiteTest, MaskNoSample) { + std::vector mask(kMaskSize); SSLMaskingContext *ctx_init = nullptr; + EXPECT_EQ(SECSuccess, + SSL_CreateVariantMaskingContext( + SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, variant_, + secret_.get(), kLabel.c_str(), kLabel.size(), &ctx_init)); + ASSERT_NE(nullptr, ctx_init); + ScopedSSLMaskingContext ctx(ctx_init); + EXPECT_EQ(SECFailure, - SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3, - TLS_RSA_WITH_AES_128_CBC_SHA256, - secret_.get(), nullptr, 0, &ctx_init)); + SSL_CreateMask(ctx.get(), nullptr, 0, mask.data(), mask.size())); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), nullptr, mask.size(), + mask.data(), mask.size())); EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); - EXPECT_EQ(nullptr, ctx_init); } -TEST_F(MaskingTest, MaskNullSample) { - std::vector mask(AES_BLOCK_SIZE); +TEST_P(VariantSuiteTest, MaskShortSample) { + std::vector sample(kSampleSize); + std::vector mask(kMaskSize); SSLMaskingContext *ctx_init = nullptr; EXPECT_EQ(SECSuccess, - SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3, - TLS_AES_128_GCM_SHA256, secret_.get(), - kLabel.c_str(), kLabel.size(), &ctx_init)); - EXPECT_EQ(0, PORT_GetError()); + SSL_CreateVariantMaskingContext( + SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, variant_, + secret_.get(), kLabel.c_str(), kLabel.size(), &ctx_init)); ASSERT_NE(nullptr, ctx_init); ScopedSSLMaskingContext ctx(ctx_init); EXPECT_EQ(SECFailure, - SSL_CreateMask(ctx.get(), nullptr, 0, mask.data(), mask.size())); + SSL_CreateMask(ctx.get(), sample.data(), sample.size() - 1, + mask.data(), mask.size())); EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} - EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), nullptr, mask.size(), - mask.data(), mask.size())); +TEST_P(VariantSuiteTest, MaskContextUnsupportedMech) { + std::vector sample(kSampleSize); + std::vector mask(kMaskSize); + SSLMaskingContext *ctx_init = nullptr; + EXPECT_EQ(SECFailure, + SSL_CreateVariantMaskingContext( + SSL_LIBRARY_VERSION_TLS_1_3, TLS_RSA_WITH_AES_128_CBC_SHA256, + variant_, secret_.get(), nullptr, 0, &ctx_init)); EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + EXPECT_EQ(nullptr, ctx_init); } -TEST_F(MaskingTest, MaskContextUnsupportedVersion) { - std::vector sample(AES_BLOCK_SIZE); - std::vector mask(AES_BLOCK_SIZE); +TEST_P(VariantSuiteTest, MaskContextUnsupportedVersion) { + std::vector sample(kSampleSize); + std::vector mask(kMaskSize); SSLMaskingContext *ctx_init = nullptr; - EXPECT_EQ(SECFailure, SSL_CreateMaskingContext( - SSL_LIBRARY_VERSION_TLS_1_2, TLS_AES_128_GCM_SHA256, + EXPECT_EQ(SECFailure, SSL_CreateVariantMaskingContext( + SSL_LIBRARY_VERSION_TLS_1_2, ciphersuite_, variant_, secret_.get(), nullptr, 0, &ctx_init)); EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); EXPECT_EQ(nullptr, ctx_init); } -TEST_F(MaskingTest, MaskTooMuchOutput) { - // Max internally-supported length for AES - std::vector sample(AES_BLOCK_SIZE); - std::vector mask(AES_BLOCK_SIZE + 1); +TEST_P(VariantSuiteTest, MaskMaxLength) { + uint32_t max_mask_len = kMaskSize; + if (ciphersuite_ == TLS_CHACHA20_POLY1305_SHA256) { + // Internal limitation for ChaCha20 masks. + max_mask_len = 128; + } + + std::vector sample(kSampleSize); + std::vector mask(max_mask_len + 1); SSLMaskingContext *ctx_init = nullptr; EXPECT_EQ(SECSuccess, - SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3, - TLS_AES_128_GCM_SHA256, secret_.get(), - kLabel.c_str(), kLabel.size(), &ctx_init)); - EXPECT_EQ(0, PORT_GetError()); + SSL_CreateVariantMaskingContext( + SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, variant_, + secret_.get(), kLabel.c_str(), kLabel.size(), &ctx_init)); ASSERT_NE(nullptr, ctx_init); ScopedSSLMaskingContext ctx(ctx_init); + EXPECT_EQ(SECSuccess, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), + mask.data(), mask.size() - 1)); EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), mask.data(), mask.size())); EXPECT_EQ(SEC_ERROR_OUTPUT_LEN, PORT_GetError()); } -TEST_F(MaskingTest, MaskShortOutput) { - std::vector sample(16); - std::vector mask(16); // Don't pass a null +TEST_P(VariantSuiteTest, MaskMinLength) { + std::vector sample(kSampleSize); + std::vector mask(1); // Don't pass a null SSLMaskingContext *ctx_init = nullptr; EXPECT_EQ(SECSuccess, - SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3, - TLS_AES_128_GCM_SHA256, secret_.get(), - kLabel.c_str(), kLabel.size(), &ctx_init)); - EXPECT_EQ(0, PORT_GetError()); + SSL_CreateVariantMaskingContext( + SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, variant_, + secret_.get(), kLabel.c_str(), kLabel.size(), &ctx_init)); ASSERT_NE(nullptr, ctx_init); ScopedSSLMaskingContext ctx(ctx_init); EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), mask.data(), 0)); EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + EXPECT_EQ(SECSuccess, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), + mask.data(), 1)); } -TEST_F(MaskingTest, MaskRotateLabel) { - std::vector sample(AES_BLOCK_SIZE); - std::vector mask1(AES_BLOCK_SIZE); - std::vector mask2(AES_BLOCK_SIZE); +TEST_P(VariantSuiteTest, MaskRotateLabel) { + std::vector sample(kSampleSize); + std::vector mask1(kMaskSize); + std::vector mask2(kMaskSize); EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), sample.size())); - CreateMask(TLS_AES_128_GCM_SHA256, kLabel, sample, &mask1); - CreateMask(TLS_AES_128_GCM_SHA256, std::string("sn1"), sample, &mask2); + CreateMask(kLabel, sample, &mask1); + CreateMask(std::string("sn1"), sample, &mask2); EXPECT_FALSE(mask1 == mask2); } -TEST_F(MaskingTest, MaskRotateSample) { - std::vector sample(AES_BLOCK_SIZE); - std::vector mask1(AES_BLOCK_SIZE); - std::vector mask2(AES_BLOCK_SIZE); +TEST_P(VariantSuiteTest, MaskRotateSample) { + std::vector sample(kSampleSize); + std::vector mask1(kMaskSize); + std::vector mask2(kMaskSize); EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), sample.size())); - CreateMask(TLS_AES_128_GCM_SHA256, kLabel, sample, &mask1); + CreateMask(kLabel, sample, &mask1); EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), sample.size())); - CreateMask(TLS_AES_128_GCM_SHA256, kLabel, sample, &mask2); + CreateMask(kLabel, sample, &mask2); EXPECT_FALSE(mask1 == mask2); } -TEST_F(MaskingTest, MaskAesRederive) { - std::vector sample(AES_BLOCK_SIZE); - std::vector mask1(AES_BLOCK_SIZE); - std::vector mask2(AES_BLOCK_SIZE); +TEST_P(VariantSuiteTest, MaskRederive) { + std::vector sample(kSampleSize); + std::vector mask1(kMaskSize); + std::vector mask2(kMaskSize); SECStatus rv = PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), sample.size()); EXPECT_EQ(SECSuccess, rv); // Check that re-using inputs with a new context produces the same mask. - CreateMask(TLS_AES_128_GCM_SHA256, kLabel, sample, &mask1); - CreateMask(TLS_AES_128_GCM_SHA256, kLabel, sample, &mask2); + CreateMask(kLabel, sample, &mask1); + CreateMask(kLabel, sample, &mask2); EXPECT_TRUE(mask1 == mask2); } -TEST_F(MaskingTest, MaskAesTooLong) { - std::vector sample(AES_BLOCK_SIZE + 1); - std::vector mask(AES_BLOCK_SIZE + 1); - SSLMaskingContext *ctx_init = nullptr; +TEST_P(SuiteTest, MaskTlsVariantKeySeparation) { + std::vector sample(kSampleSize); + std::vector tls_mask(kMaskSize); + std::vector dtls_mask(kMaskSize); + SSLMaskingContext *stream_ctx_init = nullptr; + SSLMaskingContext *datagram_ctx_init = nullptr; + + // Init + EXPECT_EQ(SECSuccess, SSL_CreateVariantMaskingContext( + SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, + ssl_variant_stream, secret_.get(), kLabel.c_str(), + kLabel.size(), &stream_ctx_init)); + ASSERT_NE(nullptr, stream_ctx_init); + EXPECT_EQ(SECSuccess, SSL_CreateVariantMaskingContext( + SSL_LIBRARY_VERSION_TLS_1_3, ciphersuite_, + ssl_variant_datagram, secret_.get(), kLabel.c_str(), + kLabel.size(), &datagram_ctx_init)); + ASSERT_NE(nullptr, datagram_ctx_init); + ScopedSSLMaskingContext tls_ctx(stream_ctx_init); + ScopedSSLMaskingContext dtls_ctx(datagram_ctx_init); + + // Derive EXPECT_EQ(SECSuccess, - SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3, - TLS_AES_128_GCM_SHA256, secret_.get(), - kLabel.c_str(), kLabel.size(), &ctx_init)); - EXPECT_EQ(0, PORT_GetError()); - ASSERT_NE(nullptr, ctx_init); - ScopedSSLMaskingContext ctx(ctx_init); - EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), - mask.data(), mask.size())); - EXPECT_EQ(SEC_ERROR_OUTPUT_LEN, PORT_GetError()); -} + SSL_CreateMask(tls_ctx.get(), sample.data(), sample.size(), + tls_mask.data(), tls_mask.size())); -TEST_F(MaskingTest, MaskAesShortSample) { - std::vector sample(AES_BLOCK_SIZE - 1); - std::vector mask(AES_BLOCK_SIZE - 1); - SSLMaskingContext *ctx_init = nullptr; EXPECT_EQ(SECSuccess, - SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3, - TLS_AES_128_GCM_SHA256, secret_.get(), - kLabel.c_str(), kLabel.size(), &ctx_init)); - EXPECT_EQ(0, PORT_GetError()); - ASSERT_NE(nullptr, ctx_init); - ScopedSSLMaskingContext ctx(ctx_init); - - EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), - mask.data(), mask.size())); - EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); -} - -TEST_F(MaskingTest, MaskAesShortValid) { - std::vector sample(AES_BLOCK_SIZE); - std::vector mask(1); - EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), - sample.size())); - CreateMask(TLS_AES_128_GCM_SHA256, kLabel, sample, &mask); + SSL_CreateMask(dtls_ctx.get(), sample.data(), sample.size(), + dtls_mask.data(), dtls_mask.size())); + EXPECT_NE(tls_mask, dtls_mask); } -TEST_F(MaskingTest, MaskChaChaRederive) { - // Block-aligned. - std::vector sample(32); - std::vector mask1(32); - std::vector mask2(32); - EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), - sample.size())); - CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask1); - CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask2); - EXPECT_TRUE(mask1 == mask2); -} - -TEST_F(MaskingTest, MaskChaChaRederiveOddSizes) { +TEST_P(VariantTest, MaskChaChaRederiveOddSizes) { // Non-block-aligned. std::vector sample(27); std::vector mask1(26); @@ -284,54 +332,19 @@ TEST_F(MaskingTest, MaskChaChaRederiveOddSizes) { EXPECT_TRUE(mask1 == mask2); } -TEST_F(MaskingTest, MaskChaChaLongValid) { - // Max internally-supported length for ChaCha - std::vector sample(128); - std::vector mask(128); - EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), - sample.size())); - CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask); -} +static const uint16_t kMaskingCiphersuites[] = {TLS_CHACHA20_POLY1305_SHA256, + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384}; +::testing::internal::ParamGenerator kMaskingCiphersuiteParams = + ::testing::ValuesIn(kMaskingCiphersuites); -TEST_F(MaskingTest, MaskChaChaTooLong) { - // Max internally-supported length for ChaCha - std::vector sample(128 + 1); - std::vector mask(128 + 1); - SSLMaskingContext *ctx_init = nullptr; - EXPECT_EQ(SECSuccess, SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3, - TLS_CHACHA20_POLY1305_SHA256, - secret_.get(), kLabel.c_str(), - kLabel.size(), &ctx_init)); - EXPECT_EQ(0, PORT_GetError()); - ASSERT_NE(nullptr, ctx_init); - ScopedSSLMaskingContext ctx(ctx_init); - EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), - mask.data(), mask.size())); - EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); -} +INSTANTIATE_TEST_CASE_P(GenericMasking, SuiteTest, kMaskingCiphersuiteParams); -TEST_F(MaskingTest, MaskChaChaShortSample) { - std::vector sample(15); // Should have 4B ctr, 12B nonce. - std::vector mask(15); - SSLMaskingContext *ctx_init = nullptr; - EXPECT_EQ(SECSuccess, SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3, - TLS_CHACHA20_POLY1305_SHA256, - secret_.get(), kLabel.c_str(), - kLabel.size(), &ctx_init)); - EXPECT_EQ(0, PORT_GetError()); - ASSERT_NE(nullptr, ctx_init); - ScopedSSLMaskingContext ctx(ctx_init); - EXPECT_EQ(SECFailure, SSL_CreateMask(ctx.get(), sample.data(), sample.size(), - mask.data(), mask.size())); - EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); -} +INSTANTIATE_TEST_CASE_P(GenericMasking, VariantTest, + TlsConnectTestBase::kTlsVariantsAll); -TEST_F(MaskingTest, MaskChaChaShortValid) { - std::vector sample(16); - std::vector mask(1); - EXPECT_EQ(SECSuccess, PK11_GenerateRandomOnSlot(slot_.get(), sample.data(), - sample.size())); - CreateMask(TLS_CHACHA20_POLY1305_SHA256, kLabel, sample, &mask); -} +INSTANTIATE_TEST_CASE_P(GenericMasking, VariantSuiteTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + kMaskingCiphersuiteParams)); } // namespace nss_test diff --git a/gtests/ssl_gtest/tls_filter.cc b/gtests/ssl_gtest/tls_filter.cc index d47ee71ab..ef08c7e30 100644 --- a/gtests/ssl_gtest/tls_filter.cc +++ b/gtests/ssl_gtest/tls_filter.cc @@ -269,11 +269,12 @@ bool TlsRecordHeader::MaskSequenceNumber() { return MaskSequenceNumber(sn_mask()); } -bool TlsRecordHeader::MaskSequenceNumber(const DataBuffer& mask) { - if (mask.empty()) { +bool TlsRecordHeader::MaskSequenceNumber(const DataBuffer& mask_buf) { + if (mask_buf.empty()) { return false; } + DataBuffer mask; if (is_dtls13_ciphertext()) { uint64_t seqno = sequence_number(); uint8_t len = content_type() & kCtDtlsCiphertext16bSeqno ? 2 : 1; @@ -283,11 +284,15 @@ bool TlsRecordHeader::MaskSequenceNumber(const DataBuffer& mask) { return false; } +#ifdef UNSAFE_FUZZER_MODE + // Use a null mask. + mask.Allocate(mask_buf.len()); +#endif + mask.Append(mask_buf); val.data()[0] ^= mask.data()[0]; if (len == 2 && mask.len() > 1) { val.data()[1] ^= mask.data()[1]; } - uint32_t tmp; if (!val.Read(0, len, &tmp)) { return false; @@ -1152,5 +1157,4 @@ PacketFilter::Action SelectedCipherSuiteReplacer::FilterHandshake( output->Write(pos, static_cast(cipher_suite_), 2); return CHANGE; } - } // namespace nss_test diff --git a/gtests/ssl_gtest/tls_filter.h b/gtests/ssl_gtest/tls_filter.h index 8cf558f9c..5300075ea 100644 --- a/gtests/ssl_gtest/tls_filter.h +++ b/gtests/ssl_gtest/tls_filter.h @@ -148,7 +148,7 @@ class TlsRecordHeader : public TlsVersioned { const DataBuffer& header() const { return header_; } bool MaskSequenceNumber(); - bool MaskSequenceNumber(const DataBuffer& mask); + bool MaskSequenceNumber(const DataBuffer& mask_buf); // Parse the header; return true if successful; body in an outparam if OK. bool Parse(bool is_dtls13, uint64_t sequence_number, TlsParser* parser, diff --git a/gtests/ssl_gtest/tls_hkdf_unittest.cc b/gtests/ssl_gtest/tls_hkdf_unittest.cc index e1ad9e9f0..7dea7e68b 100644 --- a/gtests/ssl_gtest/tls_hkdf_unittest.cc +++ b/gtests/ssl_gtest/tls_hkdf_unittest.cc @@ -192,9 +192,9 @@ class TlsHkdfTest : public ::testing::Test, std::vector output(expected.len()); - SECStatus rv = tls13_HkdfExpandLabelRaw(prk->get(), base_hash, session_hash, - session_hash_len, label, label_len, - &output[0], output.size()); + SECStatus rv = tls13_HkdfExpandLabelRaw( + prk->get(), base_hash, session_hash, session_hash_len, label, label_len, + ssl_variant_stream, &output[0], output.size()); ASSERT_EQ(SECSuccess, rv); DumpData("Output", &output[0], output.size()); EXPECT_EQ(0, memcmp(expected.data(), &output[0], expected.len())); diff --git a/gtests/ssl_gtest/tls_protect.cc b/gtests/ssl_gtest/tls_protect.cc index 7737fe5ea..6187660a5 100644 --- a/gtests/ssl_gtest/tls_protect.cc +++ b/gtests/ssl_gtest/tls_protect.cc @@ -26,10 +26,12 @@ TlsCipherSpec::TlsCipherSpec(bool dtls, uint16_t epoc) bool TlsCipherSpec::SetKeys(SSLCipherSuiteInfo* cipherinfo, PK11SymKey* secret) { SSLAeadContext* aead_ctx; - SECStatus rv = SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, - cipherinfo->cipherSuite, secret, "", - 0, // Use the default labels. - &aead_ctx); + SSLProtocolVariant variant = + dtls_ ? ssl_variant_datagram : ssl_variant_stream; + SECStatus rv = + SSL_MakeVariantAead(SSL_LIBRARY_VERSION_TLS_1_3, cipherinfo->cipherSuite, + variant, secret, "", 0, // Use the default labels. + &aead_ctx); if (rv != SECSuccess) { return false; } @@ -37,9 +39,9 @@ bool TlsCipherSpec::SetKeys(SSLCipherSuiteInfo* cipherinfo, SSLMaskingContext* mask_ctx; const char kHkdfPurposeSn[] = "sn"; - rv = SSL_CreateMaskingContext(SSL_LIBRARY_VERSION_TLS_1_3, - cipherinfo->cipherSuite, secret, kHkdfPurposeSn, - strlen(kHkdfPurposeSn), &mask_ctx); + rv = SSL_CreateVariantMaskingContext( + SSL_LIBRARY_VERSION_TLS_1_3, cipherinfo->cipherSuite, variant, secret, + kHkdfPurposeSn, strlen(kHkdfPurposeSn), &mask_ctx); if (rv != SECSuccess) { return false; } diff --git a/lib/ssl/dtls13con.c b/lib/ssl/dtls13con.c index c87e0907a..daa5e2c7c 100644 --- a/lib/ssl/dtls13con.c +++ b/lib/ssl/dtls13con.c @@ -537,19 +537,32 @@ dtls13_MaskSequenceNumber(sslSocket *ss, ssl3CipherSpec *spec, } if (spec->maskContext) { +#ifdef UNSAFE_FUZZER_MODE + /* Use a null mask. */ + PRUint8 mask[2] = { 0 }; +#else + /* "This procedure requires the ciphertext length be at least 16 bytes. + * Receivers MUST reject shorter records as if they had failed + * deprotection, as described in Section 4.5.2." */ + if (cipherTextLen < 16) { + PORT_SetError(SSL_ERROR_BAD_MAC_READ); + return SECFailure; + } + PRUint8 mask[2]; SECStatus rv = ssl_CreateMaskInner(spec->maskContext, cipherText, cipherTextLen, mask, sizeof(mask)); if (rv != SECSuccess) { + PORT_SetError(SSL_ERROR_BAD_MAC_READ); return SECFailure; } +#endif hdr[1] ^= mask[0]; if (hdr[0] & 0x08) { hdr[2] ^= mask[1]; } } - return SECSuccess; } diff --git a/lib/ssl/ssl3con.c b/lib/ssl/ssl3con.c index e8ea99d82..b0db2acb1 100644 --- a/lib/ssl/ssl3con.c +++ b/lib/ssl/ssl3con.c @@ -2445,16 +2445,6 @@ ssl_ProtectRecord(sslSocket *ss, ssl3CipherSpec *cwSpec, SSLContentType ct, } if (IS_DTLS(ss)) { bufLen = SSL_BUFFER_LEN(wrBuf) - bufLen; -#ifdef UNSAFE_FUZZER_MODE - /* The null cipher doesn't add a tag. Make sure the "ciphertext" - * is long enough for mask creation. */ - unsigned char tmpCt[AES_BLOCK_SIZE] = { 0 }; - if (bufLen < 16) { - memcpy(tmpCt, cipherText, bufLen); - bufLen = sizeof(tmpCt); - cipherText = tmpCt; - } -#endif rv = dtls13_MaskSequenceNumber(ss, cwSpec, SSL_BUFFER_BASE(wrBuf), cipherText, bufLen); @@ -8631,15 +8621,12 @@ ssl3_HandleClientHello(sslSocket *ss, PRUint8 *b, PRUint32 length) goto loser; /* malformed */ } - /* Grab the client's cookie, if present. */ + /* Grab the client's cookie, if present. It is checked after version negotiation. */ if (IS_DTLS(ss)) { rv = ssl3_ConsumeHandshakeVariable(ss, &cookieBytes, 1, &b, &length); if (rv != SECSuccess) { goto loser; /* malformed */ } - if (cookieBytes.len != 0) { - goto loser; /* We never send cookies in DTLS 1.2. */ - } } /* Grab the list of cipher suites. */ @@ -8745,6 +8732,13 @@ ssl3_HandleClientHello(sslSocket *ss, PRUint8 *b, PRUint32 length) errCode = SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER; goto alert_loser; } + + /* A DTLS 1.3-only client MUST set the legacy_cookie field to zero length. + * If a DTLS 1.3 ClientHello is received with any other value in this field, + * the server MUST abort the handshake with an "illegal_parameter" alert. */ + if (IS_DTLS(ss) && cookieBytes.len != 0) { + goto alert_loser; + } } else { /* HRR is TLS1.3-only. We ignore the Cookie extension here. */ if (ss->ssl3.hs.helloRetry) { @@ -8765,6 +8759,11 @@ ssl3_HandleClientHello(sslSocket *ss, PRUint8 *b, PRUint32 length) !memchr(comps.data, ssl_compression_null, comps.len)) { goto alert_loser; } + + /* We never send cookies in DTLS 1.2. */ + if (IS_DTLS(ss) && cookieBytes.len != 0) { + goto loser; + } } /* Now parse the rest of the extensions. */ @@ -12919,22 +12918,10 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText) } isTLS = (PRBool)(spec->version > SSL_LIBRARY_VERSION_3_0); if (IS_DTLS(ss)) { - unsigned int bufLen = SSL_BUFFER_LEN(cText->buf); - unsigned char *cipherText = SSL_BUFFER_BASE(cText->buf); -#ifdef UNSAFE_FUZZER_MODE - /* The null cipher doesn't add a tag. Make sure the "ciphertext" - * is long enough for mask creation. */ - unsigned char tmpCt[AES_BLOCK_SIZE] = { 0 }; - if (bufLen < 16) { - memcpy(tmpCt, cipherText, bufLen); - bufLen = sizeof(tmpCt); - cipherText = tmpCt; - } -#endif if (dtls13_MaskSequenceNumber(ss, spec, cText->hdr, - cipherText, bufLen) != SECSuccess) { + SSL_BUFFER_BASE(cText->buf), SSL_BUFFER_LEN(cText->buf)) != SECSuccess) { ssl_ReleaseSpecReadLock(ss); /*****************************/ - PORT_SetError(SSL_ERROR_DECRYPTION_FAILURE); + /* code already set. */ return SECFailure; } if (!dtls_IsRelevant(ss, spec, cText, &cText->seqNum)) { diff --git a/lib/ssl/ssl3prot.h b/lib/ssl/ssl3prot.h index b180931e9..d7375551f 100644 --- a/lib/ssl/ssl3prot.h +++ b/lib/ssl/ssl3prot.h @@ -14,7 +14,7 @@ typedef PRUint16 SSL3ProtocolVersion; /* version numbers are defined in sslproto.h */ /* DTLS 1.3 is still a draft. */ -#define DTLS_1_3_DRAFT_VERSION 30 +#define DTLS_1_3_DRAFT_VERSION 34 typedef PRUint16 ssl3CipherSuite; /* The cipher suites are defined in sslproto.h */ diff --git a/lib/ssl/sslexp.h b/lib/ssl/sslexp.h index 61b1fc088..d23ad9411 100644 --- a/lib/ssl/sslexp.h +++ b/lib/ssl/sslexp.h @@ -662,7 +662,11 @@ typedef SECStatus(PR_CALLBACK *SSLRecordWriteCallback)( * used in TLS. The lower bits of the IV are XORed with the 64-bit counter to * produce the nonce. Otherwise, this is an AEAD interface similar to that * described in RFC 5116. - */ + * + * Note: SSL_MakeAead internally calls SSL_MakeVariantAead with a variant of + * "stream", behaving as noted above. If "datagram" variant is passed instead, + * the Label prefix used in HKDF-Expand is "dtls13" instead of "tls13 ". See + * 7.1 of RFC 8446 and draft-ietf-tls-dtls13-34. */ typedef struct SSLAeadContextStr SSLAeadContext; #define SSL_MakeAead(version, cipherSuite, secret, \ @@ -676,6 +680,18 @@ typedef struct SSLAeadContextStr SSLAeadContext; (version, cipherSuite, secret, \ labelPrefix, labelPrefixLen, ctx)) +#define SSL_MakeVariantAead(version, cipherSuite, variant, secret, \ + labelPrefix, labelPrefixLen, ctx) \ + SSL_EXPERIMENTAL_API("SSL_MakeVariantAead", \ + (PRUint16 _version, PRUint16 _cipherSuite, \ + SSLProtocolVariant _variant, \ + PK11SymKey * _secret, \ + const char *_labelPrefix, \ + unsigned int _labelPrefixLen, \ + SSLAeadContext **_ctx), \ + (version, cipherSuite, variant, secret, \ + labelPrefix, labelPrefixLen, ctx)) + #define SSL_AeadEncrypt(ctx, counter, aad, aadLen, in, inLen, \ output, outputLen, maxOutputLen) \ SSL_EXPERIMENTAL_API("SSL_AeadEncrypt", \ @@ -716,8 +732,13 @@ typedef struct SSLAeadContextStr SSLAeadContext; PK11SymKey * *_keyp), \ (version, cipherSuite, salt, ikm, keyp)) -/* SSL_HkdfExpandLabel produces a key with a mechanism that is suitable for - * input to SSL_HkdfExpandLabel or SSL_MakeAead. */ +/* SSL_HkdfExpandLabel and SSL_HkdfVariantExpandLabel produce a key with a + * mechanism that is suitable for input to SSL_HkdfExpandLabel or SSL_MakeAead. + * + * Note: SSL_HkdfVariantExpandLabel internally calls SSL_HkdfExpandLabel with + * a default "stream" variant. If "datagram" variant is passed instead, the + * Label prefix used in HKDF-Expand is "dtls13" instead of "tls13 ". See 7.1 of + * RFC 8446 and draft-ietf-tls-dtls13-34. */ #define SSL_HkdfExpandLabel(version, cipherSuite, prk, \ hsHash, hsHashLen, label, labelLen, keyp) \ SSL_EXPERIMENTAL_API("SSL_HkdfExpandLabel", \ @@ -729,9 +750,28 @@ typedef struct SSLAeadContextStr SSLAeadContext; (version, cipherSuite, prk, \ hsHash, hsHashLen, label, labelLen, keyp)) -/* SSL_HkdfExpandLabelWithMech uses the KDF from the selected TLS version and - * cipher suite, as with the other calls, but the provided mechanism and key - * size. This allows the key to be used more widely. */ +#define SSL_HkdfVariantExpandLabel(version, cipherSuite, prk, \ + hsHash, hsHashLen, label, labelLen, variant, \ + keyp) \ + SSL_EXPERIMENTAL_API("SSL_HkdfVariantExpandLabel", \ + (PRUint16 _version, PRUint16 _cipherSuite, \ + PK11SymKey * _prk, \ + const PRUint8 *_hsHash, unsigned int _hsHashLen, \ + const char *_label, unsigned int _labelLen, \ + SSLProtocolVariant _variant, \ + PK11SymKey **_keyp), \ + (version, cipherSuite, prk, \ + hsHash, hsHashLen, label, labelLen, variant, \ + keyp)) + +/* SSL_HkdfExpandLabelWithMech and SSL_HkdfVariantExpandLabelWithMech use the KDF + * from the selected TLS version and cipher suite, as with the other calls, but + * the provided mechanism and key size. This allows the key to be used more widely. + * + * Note: SSL_HkdfExpandLabelWithMech internally calls SSL_HkdfVariantExpandLabelWithMech + * with a default "stream" variant. If "datagram" variant is passed instead, the + * Label prefix used in HKDF-Expand is "dtls13" instead of "tls13 ". See 7.1 of + * RFC 8446 and draft-ietf-tls-dtls13-34. */ #define SSL_HkdfExpandLabelWithMech(version, cipherSuite, prk, \ hsHash, hsHashLen, label, labelLen, \ mech, keySize, keyp) \ @@ -746,6 +786,21 @@ typedef struct SSLAeadContextStr SSLAeadContext; hsHash, hsHashLen, label, labelLen, \ mech, keySize, keyp)) +#define SSL_HkdfVariantExpandLabelWithMech(version, cipherSuite, prk, \ + hsHash, hsHashLen, label, labelLen, \ + mech, keySize, variant, keyp) \ + SSL_EXPERIMENTAL_API("SSL_HkdfVariantExpandLabelWithMech", \ + (PRUint16 _version, PRUint16 _cipherSuite, \ + PK11SymKey * _prk, \ + const PRUint8 *_hsHash, unsigned int _hsHashLen, \ + const char *_label, unsigned int _labelLen, \ + CK_MECHANISM_TYPE _mech, unsigned int _keySize, \ + SSLProtocolVariant _variant, \ + PK11SymKey **_keyp), \ + (version, cipherSuite, prk, \ + hsHash, hsHashLen, label, labelLen, \ + mech, keySize, variant, keyp)) + /* SSL_SetTimeFunc overrides the default time function (PR_Now()) and provides * an alternative source of time for the socket. This is used in testing, and in * applications that need better control over how the clock is accessed. Set the @@ -864,6 +919,18 @@ typedef struct SSLMaskingContextStr { SSLMaskingContext **_ctx), \ (version, cipherSuite, secret, label, labelLen, ctx)) +#define SSL_CreateVariantMaskingContext(version, cipherSuite, variant, \ + secret, label, labelLen, ctx) \ + SSL_EXPERIMENTAL_API("SSL_CreateVariantMaskingContext", \ + (PRUint16 _version, PRUint16 _cipherSuite, \ + SSLProtocolVariant _variant, \ + PK11SymKey * _secret, \ + const char *_label, \ + unsigned int _labelLen, \ + SSLMaskingContext **_ctx), \ + (version, cipherSuite, variant, secret, \ + label, labelLen, ctx)) + #define SSL_DestroyMaskingContext(ctx) \ SSL_EXPERIMENTAL_API("SSL_DestroyMaskingContext", \ (SSLMaskingContext * _ctx), \ diff --git a/lib/ssl/sslimpl.h b/lib/ssl/sslimpl.h index af789c73e..2ca945562 100644 --- a/lib/ssl/sslimpl.h +++ b/lib/ssl/sslimpl.h @@ -1824,6 +1824,10 @@ SECStatus SSLExp_GetCurrentEpoch(PRFileDesc *fd, PRUint16 *readEpoch, SECStatus SSLExp_MakeAead(PRUint16 version, PRUint16 cipherSuite, PK11SymKey *secret, const char *labelPrefix, unsigned int labelPrefixLen, SSLAeadContext **ctx); + +SECStatus SSLExp_MakeVariantAead(PRUint16 version, PRUint16 cipherSuite, SSLProtocolVariant variant, + PK11SymKey *secret, const char *labelPrefix, + unsigned int labelPrefixLen, SSLAeadContext **ctx); SECStatus SSLExp_DestroyAead(SSLAeadContext *ctx); SECStatus SSLExp_AeadEncrypt(const SSLAeadContext *ctx, PRUint64 counter, const PRUint8 *aad, unsigned int aadLen, @@ -1840,16 +1844,27 @@ SECStatus SSLExp_HkdfExpandLabel(PRUint16 version, PRUint16 cipherSuite, PK11Sym const PRUint8 *hsHash, unsigned int hsHashLen, const char *label, unsigned int labelLen, PK11SymKey **key); +SECStatus SSLExp_HkdfVariantExpandLabel(PRUint16 version, PRUint16 cipherSuite, PK11SymKey *prk, + const PRUint8 *hsHash, unsigned int hsHashLen, + const char *label, unsigned int labelLen, + SSLProtocolVariant variant, PK11SymKey **key); SECStatus SSLExp_HkdfExpandLabelWithMech(PRUint16 version, PRUint16 cipherSuite, PK11SymKey *prk, const PRUint8 *hsHash, unsigned int hsHashLen, const char *label, unsigned int labelLen, CK_MECHANISM_TYPE mech, unsigned int keySize, PK11SymKey **keyp); +SECStatus +SSLExp_HkdfVariantExpandLabelWithMech(PRUint16 version, PRUint16 cipherSuite, PK11SymKey *prk, + const PRUint8 *hsHash, unsigned int hsHashLen, + const char *label, unsigned int labelLen, + CK_MECHANISM_TYPE mech, unsigned int keySize, + SSLProtocolVariant variant, PK11SymKey **keyp); SECStatus SSLExp_SetTimeFunc(PRFileDesc *fd, SSLTimeFunc f, void *arg); extern SECStatus ssl_CreateMaskingContextInner(PRUint16 version, PRUint16 cipherSuite, + SSLProtocolVariant variant, PK11SymKey *secret, const char *label, unsigned int labelLen, @@ -1867,6 +1882,13 @@ SECStatus SSLExp_CreateMaskingContext(PRUint16 version, PRUint16 cipherSuite, unsigned int labelLen, SSLMaskingContext **ctx); +SECStatus SSLExp_CreateVariantMaskingContext(PRUint16 version, PRUint16 cipherSuite, + SSLProtocolVariant variant, + PK11SymKey *secret, + const char *label, + unsigned int labelLen, + SSLMaskingContext **ctx); + SECStatus SSLExp_CreateMask(SSLMaskingContext *ctx, const PRUint8 *sample, unsigned int sampleLen, PRUint8 *mask, unsigned int len); diff --git a/lib/ssl/sslinfo.c b/lib/ssl/sslinfo.c index b069888e2..115c38dc1 100644 --- a/lib/ssl/sslinfo.c +++ b/lib/ssl/sslinfo.c @@ -432,7 +432,7 @@ tls13_Exporter(sslSocket *ss, PK11SymKey *secret, contextHash.u.raw, contextHash.len, kExporterInnerLabel, strlen(kExporterInnerLabel), - out, outLen); + ss->protocolVariant, out, outLen); PK11_FreeSymKey(innerSecret); return rv; } diff --git a/lib/ssl/sslprimitive.c b/lib/ssl/sslprimitive.c index 5522f96fd..7cff599ad 100644 --- a/lib/ssl/sslprimitive.c +++ b/lib/ssl/sslprimitive.c @@ -25,9 +25,9 @@ struct SSLAeadContextStr { }; SECStatus -SSLExp_MakeAead(PRUint16 version, PRUint16 cipherSuite, PK11SymKey *secret, - const char *labelPrefix, unsigned int labelPrefixLen, - SSLAeadContext **ctx) +SSLExp_MakeVariantAead(PRUint16 version, PRUint16 cipherSuite, SSLProtocolVariant variant, + PK11SymKey *secret, const char *labelPrefix, + unsigned int labelPrefixLen, SSLAeadContext **ctx) { SSLAeadContext *out = NULL; char label[255]; // Maximum length label. @@ -62,7 +62,7 @@ SSLExp_MakeAead(PRUint16 version, PRUint16 cipherSuite, PK11SymKey *secret, unsigned int ivLen = cipher->iv_size + cipher->explicit_nonce_size; rv = tls13_HkdfExpandLabelRaw(secret, hash, NULL, 0, // Handshake hash. - label, labelLen, + label, labelLen, variant, out->keys.iv, ivLen); if (rv != SECSuccess) { goto loser; @@ -72,8 +72,8 @@ SSLExp_MakeAead(PRUint16 version, PRUint16 cipherSuite, PK11SymKey *secret, labelLen = labelPrefixLen + strlen(keySuffix); rv = tls13_HkdfExpandLabel(secret, hash, NULL, 0, // Handshake hash. - label, labelLen, - out->mech, cipher->key_size, &out->keys.key); + label, labelLen, out->mech, cipher->key_size, + variant, &out->keys.key); if (rv != SECSuccess) { goto loser; } @@ -86,6 +86,14 @@ loser: return SECFailure; } +SECStatus +SSLExp_MakeAead(PRUint16 version, PRUint16 cipherSuite, PK11SymKey *secret, + const char *labelPrefix, unsigned int labelPrefixLen, SSLAeadContext **ctx) +{ + return SSLExp_MakeVariantAead(version, cipherSuite, ssl_variant_stream, secret, + labelPrefix, labelPrefixLen, ctx); +} + SECStatus SSLExp_DestroyAead(SSLAeadContext *ctx) { @@ -202,8 +210,17 @@ SSLExp_HkdfExtract(PRUint16 version, PRUint16 cipherSuite, SECStatus SSLExp_HkdfExpandLabel(PRUint16 version, PRUint16 cipherSuite, PK11SymKey *prk, const PRUint8 *hsHash, unsigned int hsHashLen, - const char *label, unsigned int labelLen, - PK11SymKey **keyp) + const char *label, unsigned int labelLen, PK11SymKey **keyp) +{ + return SSLExp_HkdfVariantExpandLabel(version, cipherSuite, prk, hsHash, hsHashLen, + label, labelLen, ssl_variant_stream, keyp); +} + +SECStatus +SSLExp_HkdfVariantExpandLabel(PRUint16 version, PRUint16 cipherSuite, PK11SymKey *prk, + const PRUint8 *hsHash, unsigned int hsHashLen, + const char *label, unsigned int labelLen, + SSLProtocolVariant variant, PK11SymKey **keyp) { if (prk == NULL || keyp == NULL || label == NULL || labelLen == 0) { @@ -219,7 +236,7 @@ SSLExp_HkdfExpandLabel(PRUint16 version, PRUint16 cipherSuite, PK11SymKey *prk, } return tls13_HkdfExpandLabel(prk, hash, hsHash, hsHashLen, label, labelLen, tls13_GetHkdfMechanismForHash(hash), - tls13_GetHashSizeForHash(hash), keyp); + tls13_GetHashSizeForHash(hash), variant, keyp); } SECStatus @@ -228,6 +245,18 @@ SSLExp_HkdfExpandLabelWithMech(PRUint16 version, PRUint16 cipherSuite, PK11SymKe const char *label, unsigned int labelLen, CK_MECHANISM_TYPE mech, unsigned int keySize, PK11SymKey **keyp) +{ + return SSLExp_HkdfVariantExpandLabelWithMech(version, cipherSuite, prk, hsHash, hsHashLen, + label, labelLen, mech, keySize, + ssl_variant_stream, keyp); +} + +SECStatus +SSLExp_HkdfVariantExpandLabelWithMech(PRUint16 version, PRUint16 cipherSuite, PK11SymKey *prk, + const PRUint8 *hsHash, unsigned int hsHashLen, + const char *label, unsigned int labelLen, + CK_MECHANISM_TYPE mech, unsigned int keySize, + SSLProtocolVariant variant, PK11SymKey **keyp) { if (prk == NULL || keyp == NULL || label == NULL || labelLen == 0 || @@ -243,11 +272,12 @@ SSLExp_HkdfExpandLabelWithMech(PRUint16 version, PRUint16 cipherSuite, PK11SymKe return SECFailure; /* Code already set. */ } return tls13_HkdfExpandLabel(prk, hash, hsHash, hsHashLen, label, labelLen, - mech, keySize, keyp); + mech, keySize, variant, keyp); } SECStatus ssl_CreateMaskingContextInner(PRUint16 version, PRUint16 cipherSuite, + SSLProtocolVariant variant, PK11SymKey *secret, const char *label, unsigned int labelLen, @@ -283,7 +313,8 @@ ssl_CreateMaskingContextInner(PRUint16 version, PRUint16 cipherSuite, NULL, 0, // Handshake hash. label, labelLen, out->mech, - cipher->key_size, &out->secret); + cipher->key_size, variant, + &out->secret); if (rv != SECSuccess) { goto loser; } @@ -356,7 +387,7 @@ ssl_CreateMaskInner(SSLMaskingContext *ctx, const PRUint8 *sample, unsigned char zeros[128] = { 0 }; if (maskLen > sizeof(zeros)) { - PORT_SetError(SEC_ERROR_INVALID_ARGS); + PORT_SetError(SEC_ERROR_OUTPUT_LEN); return SECFailure; } @@ -413,7 +444,20 @@ SSLExp_CreateMaskingContext(PRUint16 version, PRUint16 cipherSuite, unsigned int labelLen, SSLMaskingContext **ctx) { - return ssl_CreateMaskingContextInner(version, cipherSuite, secret, label, labelLen, ctx); + return ssl_CreateMaskingContextInner(version, cipherSuite, ssl_variant_stream, secret, + label, labelLen, ctx); +} + +SECStatus +SSLExp_CreateVariantMaskingContext(PRUint16 version, PRUint16 cipherSuite, + SSLProtocolVariant variant, + PK11SymKey *secret, + const char *label, + unsigned int labelLen, + SSLMaskingContext **ctx) +{ + return ssl_CreateMaskingContextInner(version, cipherSuite, variant, secret, + label, labelLen, ctx); } SECStatus diff --git a/lib/ssl/sslsock.c b/lib/ssl/sslsock.c index 581f0c467..cf77c187b 100644 --- a/lib/ssl/sslsock.c +++ b/lib/ssl/sslsock.c @@ -4222,6 +4222,7 @@ struct { EXP(CreateAntiReplayContext), EXP(CreateMask), EXP(CreateMaskingContext), + EXP(CreateVariantMaskingContext), EXP(DelegateCredential), EXP(DestroyAead), EXP(DestroyMaskingContext), @@ -4236,8 +4237,11 @@ struct { EXP(HkdfExtract), EXP(HkdfExpandLabel), EXP(HkdfExpandLabelWithMech), + EXP(HkdfVariantExpandLabel), + EXP(HkdfVariantExpandLabelWithMech), EXP(KeyUpdate), EXP(MakeAead), + EXP(MakeVariantAead), EXP(RecordLayerData), EXP(RecordLayerWriteCallback), EXP(ReleaseAntiReplayContext), diff --git a/lib/ssl/tls13con.c b/lib/ssl/tls13con.c index 1d43d8c73..261f9b301 100644 --- a/lib/ssl/tls13con.c +++ b/lib/ssl/tls13con.c @@ -670,6 +670,7 @@ tls13_UpdateTrafficKeys(sslSocket *ss, SSLSecretDirection direction) strlen(kHkdfLabelTrafficUpdate), tls13_GetHmacMechanism(ss), tls13_GetHashSize(ss), + ss->protocolVariant, &updatedSecret); if (rv != SECSuccess) { return SECFailure; @@ -3347,7 +3348,8 @@ tls13_DeriveSecret(sslSocket *ss, PK11SymKey *key, hashes->u.raw, hashes->len, label, labelLen, tls13_GetHkdfMechanism(ss), - tls13_GetHashSize(ss), dest); + tls13_GetHashSize(ss), + ss->protocolVariant, dest); if (rv != SECSuccess) { LOG_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE); return SECFailure; @@ -3496,6 +3498,7 @@ tls13_DeriveTrafficKeys(sslSocket *ss, ssl3CipherSpec *spec, NULL, 0, kHkdfPurposeKey, strlen(kHkdfPurposeKey), bulkAlgorithm, keySize, + ss->protocolVariant, &spec->keyMaterial.key); if (rv != SECSuccess) { LOG_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE); @@ -3504,8 +3507,8 @@ tls13_DeriveTrafficKeys(sslSocket *ss, ssl3CipherSpec *spec, } if (IS_DTLS(ss) && spec->epoch > 0) { - rv = ssl_CreateMaskingContextInner(spec->version, - ss->ssl3.hs.cipher_suite, prk, kHkdfPurposeSn, + rv = ssl_CreateMaskingContextInner(spec->version, ss->ssl3.hs.cipher_suite, + ss->protocolVariant, prk, kHkdfPurposeSn, strlen(kHkdfPurposeSn), &spec->maskContext); if (rv != SECSuccess) { LOG_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE); @@ -3517,6 +3520,7 @@ tls13_DeriveTrafficKeys(sslSocket *ss, ssl3CipherSpec *spec, rv = tls13_HkdfExpandLabelRaw(prk, tls13_GetHash(ss), NULL, 0, kHkdfPurposeIv, strlen(kHkdfPurposeIv), + ss->protocolVariant, spec->keyMaterial.iv, ivSize); if (rv != SECSuccess) { LOG_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE); @@ -4432,7 +4436,8 @@ tls13_ComputeFinished(sslSocket *ss, PK11SymKey *baseKey, NULL, 0, label, strlen(label), tls13_GetHmacMechanism(ss), - tls13_GetHashSize(ss), &secret); + tls13_GetHashSize(ss), + ss->protocolVariant, &secret); if (rv != SECSuccess) { goto abort; } @@ -4970,7 +4975,8 @@ tls13_SendNewSessionTicket(sslSocket *ss, const PRUint8 *appToken, kHkdfLabelResumption, strlen(kHkdfLabelResumption), tls13_GetHkdfMechanism(ss), - tls13_GetHashSize(ss), &secret); + tls13_GetHashSize(ss), + ss->protocolVariant, &secret); if (rv != SECSuccess) { goto loser; } @@ -5204,7 +5210,8 @@ tls13_HandleNewSessionTicket(sslSocket *ss, PRUint8 *b, PRUint32 length) kHkdfLabelResumption, strlen(kHkdfLabelResumption), tls13_GetHkdfMechanism(ss), - tls13_GetHashSize(ss), &secret); + tls13_GetHashSize(ss), + ss->protocolVariant, &secret); if (rv != SECSuccess) { return SECFailure; } @@ -5393,6 +5400,11 @@ tls13_ProtectRecord(sslSocket *ss, PORT_Assert(cipher_def->type == type_aead); + /* If the following condition holds, we can skip the padding logic for + * DTLS 1.3 (4.2.3). This will be the case until we support a cipher + * with tag length < 15B. */ + PORT_Assert(tagLen + 1 /* cType */ >= 16); + /* Add the content type at the end. */ *(SSL_BUFFER_NEXT(wrBuf) + contentLen) = type; @@ -5403,9 +5415,7 @@ tls13_ProtectRecord(sslSocket *ss, return SECFailure; } if (needsLength) { - rv = sslBuffer_AppendNumber(&buf, contentLen + 1 + - cwSpec->cipherDef->tag_size, - 2); + rv = sslBuffer_AppendNumber(&buf, contentLen + 1 + tagLen, 2); if (rv != SECSuccess) { return SECFailure; } diff --git a/lib/ssl/tls13esni.c b/lib/ssl/tls13esni.c index f2f8d0a9c..4562b86a1 100644 --- a/lib/ssl/tls13esni.c +++ b/lib/ssl/tls13esni.c @@ -550,7 +550,7 @@ tls13_ComputeESNIKeys(const sslSocket *ss, hash, hashSize, kHkdfPurposeEsniKey, strlen(kHkdfPurposeEsniKey), ssl3_Alg2Mech(cipherDef->calg), - keySize, + keySize, ss->protocolVariant, &keyMat->key); if (rv != SECSuccess) { goto loser; @@ -558,7 +558,7 @@ tls13_ComputeESNIKeys(const sslSocket *ss, rv = tls13_HkdfExpandLabelRaw(Zx, suite->prf_hash, hash, hashSize, kHkdfPurposeEsniIv, strlen(kHkdfPurposeEsniIv), - keyMat->iv, ivSize); + ss->protocolVariant, keyMat->iv, ivSize); if (rv != SECSuccess) { goto loser; } diff --git a/lib/ssl/tls13hkdf.c b/lib/ssl/tls13hkdf.c index ab546e06f..afb322db4 100644 --- a/lib/ssl/tls13hkdf.c +++ b/lib/ssl/tls13hkdf.c @@ -126,7 +126,7 @@ tls13_HkdfExpandLabel(PK11SymKey *prk, SSLHashType baseHash, const PRUint8 *handshakeHash, unsigned int handshakeHashLen, const char *label, unsigned int labelLen, CK_MECHANISM_TYPE algorithm, unsigned int keySize, - PK11SymKey **keyp) + SSLProtocolVariant variant, PK11SymKey **keyp) { CK_NSS_HKDFParams params; SECItem paramsi = { siBuffer, NULL, 0 }; @@ -137,8 +137,12 @@ tls13_HkdfExpandLabel(PK11SymKey *prk, SSLHashType baseHash, sslBuffer infoBuf = SSL_BUFFER(info); PK11SymKey *derived; SECStatus rv; - const char *kLabelPrefix = "tls13 "; - const unsigned int kLabelPrefixLen = strlen(kLabelPrefix); + const char *kLabelPrefixTls = "tls13 "; + const char *kLabelPrefixDtls = "dtls13"; + const unsigned int kLabelPrefixLen = + (variant == ssl_variant_stream) ? strlen(kLabelPrefixTls) : strlen(kLabelPrefixDtls); + const char *kLabelPrefix = + (variant == ssl_variant_stream) ? kLabelPrefixTls : kLabelPrefixDtls; PORT_Assert(prk); PORT_Assert(keyp); @@ -229,7 +233,8 @@ SECStatus tls13_HkdfExpandLabelRaw(PK11SymKey *prk, SSLHashType baseHash, const PRUint8 *handshakeHash, unsigned int handshakeHashLen, const char *label, unsigned int labelLen, - unsigned char *output, unsigned int outputLen) + SSLProtocolVariant variant, unsigned char *output, + unsigned int outputLen) { PK11SymKey *derived = NULL; SECItem *rawkey; @@ -238,7 +243,7 @@ tls13_HkdfExpandLabelRaw(PK11SymKey *prk, SSLHashType baseHash, rv = tls13_HkdfExpandLabel(prk, baseHash, handshakeHash, handshakeHashLen, label, labelLen, kTlsHkdfInfo[baseHash].pkcs11Mech, outputLen, - &derived); + variant, &derived); if (rv != SECSuccess || !derived) { goto abort; } diff --git a/lib/ssl/tls13hkdf.h b/lib/ssl/tls13hkdf.h index 78347a11d..00e5ff1dd 100644 --- a/lib/ssl/tls13hkdf.h +++ b/lib/ssl/tls13hkdf.h @@ -23,13 +23,14 @@ SECStatus tls13_HkdfExpandLabelRaw( PK11SymKey *prk, SSLHashType baseHash, const PRUint8 *handshakeHash, unsigned int handshakeHashLen, const char *label, unsigned int labelLen, - unsigned char *output, unsigned int outputLen); + SSLProtocolVariant variant, unsigned char *output, + unsigned int outputLen); SECStatus tls13_HkdfExpandLabel( PK11SymKey *prk, SSLHashType baseHash, const PRUint8 *handshakeHash, unsigned int handshakeHashLen, const char *label, unsigned int labelLen, CK_MECHANISM_TYPE algorithm, unsigned int keySize, - PK11SymKey **keyp); + SSLProtocolVariant variant, PK11SymKey **keyp); #ifdef __cplusplus } diff --git a/lib/ssl/tls13replay.c b/lib/ssl/tls13replay.c index 628011144..224d103dd 100644 --- a/lib/ssl/tls13replay.c +++ b/lib/ssl/tls13replay.c @@ -269,7 +269,7 @@ tls13_IsReplay(const sslSocket *ss, const sslSessionID *sid) ss->xtnData.pskBinder.data, ss->xtnData.pskBinder.len, label, strlen(label), - buf, size); + ss->protocolVariant, buf, size); if (rv != SECSuccess) { return PR_TRUE; } -- cgit v1.2.1