summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Jacobs <kjacobs@mozilla.com>2020-06-02 15:12:15 +0000
committerKevin Jacobs <kjacobs@mozilla.com>2020-06-02 15:12:15 +0000
commitfc41a7aec9ed1d3a3b52cb4ec408f23f5c8c334c (patch)
treeda5d8e9d11411ab88e3f5f48e1426636818afb91
parent5dbc90bfa2b55123d918ead803b1ac5690ab67a3 (diff)
downloadnss-hg-fc41a7aec9ed1d3a3b52cb4ec408f23f5c8c334c.tar.gz
Bug 1603042 - TLS 1.3 out-of-band PSK support r=mt
This patch adds support for External (out-of-band) PSKs in TLS 1.3. An External PSK (EPSK) can be set by calling `SSL_AddExternalPsk`, and removed with `SSL_RemoveExternalPsk`. `SSL_AddExternalPsk0Rtt` can be used to add a PSK while also specifying a suite and max_early_data_size for use with 0-RTT. As part of handling PSKs more generically, the patch also changes how resumption PSKs are handled internally, so as to rely on the same mechanisms where possible. A socket is currently limited to only one External PSK at a time. If the server doesn't find the same identity for the configured EPSK, it will fall back to certificate authentication. Differential Revision: https://phabricator.services.mozilla.com/D56687
-rw-r--r--automation/abi-check/expected-report-libssl3.so.txt10
-rw-r--r--gtests/ssl_gtest/libssl_internals.c18
-rw-r--r--gtests/ssl_gtest/libssl_internals.h1
-rw-r--r--gtests/ssl_gtest/manifest.mn1
-rw-r--r--gtests/ssl_gtest/ssl_0rtt_unittest.cc55
-rw-r--r--gtests/ssl_gtest/ssl_extension_unittest.cc20
-rw-r--r--gtests/ssl_gtest/ssl_gtest.gyp1
-rw-r--r--gtests/ssl_gtest/tls_agent.cc45
-rw-r--r--gtests/ssl_gtest/tls_agent.h9
-rw-r--r--gtests/ssl_gtest/tls_connect.cc17
-rw-r--r--gtests/ssl_gtest/tls_connect.h5
-rw-r--r--gtests/ssl_gtest/tls_psk_unittest.cc514
-rw-r--r--lib/ssl/manifest.mn1
-rw-r--r--lib/ssl/ssl.gyp1
-rw-r--r--lib/ssl/ssl3con.c79
-rw-r--r--lib/ssl/ssl3ext.c2
-rw-r--r--lib/ssl/ssl3ext.h4
-rw-r--r--lib/ssl/sslerr.h1
-rw-r--r--lib/ssl/sslexp.h52
-rw-r--r--lib/ssl/sslimpl.h8
-rw-r--r--lib/ssl/sslinfo.c25
-rw-r--r--lib/ssl/sslsecur.c2
-rw-r--r--lib/ssl/sslsock.c20
-rw-r--r--lib/ssl/sslt.h10
-rw-r--r--lib/ssl/tls13con.c556
-rw-r--r--lib/ssl/tls13con.h6
-rw-r--r--lib/ssl/tls13exthandle.c181
-rw-r--r--lib/ssl/tls13psk.c219
-rw-r--r--lib/ssl/tls13psk.h58
-rw-r--r--lib/ssl/tls13replay.c5
30 files changed, 1608 insertions, 318 deletions
diff --git a/automation/abi-check/expected-report-libssl3.so.txt b/automation/abi-check/expected-report-libssl3.so.txt
index e69de29bb..6e410a597 100644
--- a/automation/abi-check/expected-report-libssl3.so.txt
+++ b/automation/abi-check/expected-report-libssl3.so.txt
@@ -0,0 +1,10 @@
+
+1 function with some indirect sub-type change:
+
+ [C] 'function SECStatus SSL_GetChannelInfo(PRFileDesc*, SSLChannelInfo*, PRUintn)' at sslinfo.c:14:1 has some indirect sub-type changes:
+ parameter 2 of type 'SSLChannelInfo*' has sub-type changes:
+ in pointed to type 'typedef SSLChannelInfo' at sslt.h:373:1:
+ underlying type 'struct SSLChannelInfoStr' at sslt.h:293:1 changed:
+ type size changed from 960 to 1024 (in bits)
+ 1 data member insertion:
+ 'SSLPskType SSLChannelInfoStr::pskType', at offset 960 (in bits) at sslt.h:369:1
diff --git a/gtests/ssl_gtest/libssl_internals.c b/gtests/ssl_gtest/libssl_internals.c
index fff310b97..9018f4df8 100644
--- a/gtests/ssl_gtest/libssl_internals.c
+++ b/gtests/ssl_gtest/libssl_internals.c
@@ -15,6 +15,24 @@
#include "secmodti.h"
#include "sslproto.h"
+SECStatus SSLInt_RemoveServerCertificates(PRFileDesc *fd) {
+ if (!fd) {
+ return SECFailure;
+ }
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ PRCList *cursor;
+ while (!PR_CLIST_IS_EMPTY(&ss->serverCerts)) {
+ cursor = PR_LIST_TAIL(&ss->serverCerts);
+ PR_REMOVE_LINK(cursor);
+ ssl_FreeServerCert((sslServerCert *)cursor);
+ }
+ return SECSuccess;
+}
+
SECStatus SSLInt_SetDCAdvertisedSigSchemes(PRFileDesc *fd,
const SSLSignatureScheme *schemes,
uint32_t num_sig_schemes) {
diff --git a/gtests/ssl_gtest/libssl_internals.h b/gtests/ssl_gtest/libssl_internals.h
index 2f26a4d3f..ff31f89ee 100644
--- a/gtests/ssl_gtest/libssl_internals.h
+++ b/gtests/ssl_gtest/libssl_internals.h
@@ -47,5 +47,6 @@ SECStatus SSLInt_TweakChannelInfoForDC(PRFileDesc *fd, PRBool changeAuthKeyBits,
SECStatus SSLInt_SetDCAdvertisedSigSchemes(PRFileDesc *fd,
const SSLSignatureScheme *schemes,
uint32_t num_sig_schemes);
+SECStatus SSLInt_RemoveServerCertificates(PRFileDesc *fd);
#endif // ndef libssl_internals_h_
diff --git a/gtests/ssl_gtest/manifest.mn b/gtests/ssl_gtest/manifest.mn
index d5e96a490..2cfa7cdd2 100644
--- a/gtests/ssl_gtest/manifest.mn
+++ b/gtests/ssl_gtest/manifest.mn
@@ -56,6 +56,7 @@ CPPSRCS = \
tls_hkdf_unittest.cc \
tls_filter.cc \
tls_protect.cc \
+ tls_psk_unittest.cc \
tls_subcerts_unittest.cc \
tls_esni_unittest.cc \
$(SSLKEYLOGFILE_FILES) \
diff --git a/gtests/ssl_gtest/ssl_0rtt_unittest.cc b/gtests/ssl_gtest/ssl_0rtt_unittest.cc
index 88d0fdc51..42966c19c 100644
--- a/gtests/ssl_gtest/ssl_0rtt_unittest.cc
+++ b/gtests/ssl_gtest/ssl_0rtt_unittest.cc
@@ -15,6 +15,7 @@ extern "C" {
#include "libssl_internals.h"
}
+#include "cpputil.h"
#include "gtest_utils.h"
#include "nss_scoped_ptrs.h"
#include "tls_connect.h"
@@ -117,16 +118,12 @@ class TlsZeroRttReplayTest : public TlsConnectTls13 {
};
protected:
- void RunTest(bool rollover) {
- // Run the initial handshake
- SetupForZeroRtt();
-
+ void RunTest(bool rollover, const ScopedPK11SymKey& epsk) {
// Now run a true 0-RTT handshake, but capture the first packet.
auto first_packet = std::make_shared<SaveFirstPacket>();
client_->SetFilter(first_packet);
client_->Set0RttEnabled(true);
server_->Set0RttEnabled(true);
- ExpectResumption(RESUME_TICKET);
ZeroRttSendReceive(true, true);
Handshake();
EXPECT_LT(0U, first_packet->packet().len());
@@ -142,6 +139,11 @@ class TlsZeroRttReplayTest : public TlsConnectTls13 {
Reset();
server_->StartConnect();
server_->Set0RttEnabled(true);
+ server_->SetAntiReplayContext(anti_replay_);
+ if (epsk) {
+ AddPsk(epsk, std::string("foo"), ssl_hash_sha256,
+ TLS_CHACHA20_POLY1305_SHA256);
+ }
// Capture the early_data extension, which should not appear.
auto early_data_ext =
@@ -154,11 +156,41 @@ class TlsZeroRttReplayTest : public TlsConnectTls13 {
server_->Handshake();
EXPECT_FALSE(early_data_ext->captured());
}
+
+ void RunResPskTest(bool rollover) {
+ // Run the initial handshake
+ SetupForZeroRtt();
+ ExpectResumption(RESUME_TICKET);
+ RunTest(rollover, ScopedPK11SymKey(nullptr));
+ }
+
+ void RunExtPskTest(bool rollover) {
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ ASSERT_NE(nullptr, slot);
+
+ const std::vector<uint8_t> kPskDummyVal(16, 0xFF);
+ SECItem psk_item = {siBuffer, toUcharPtr(kPskDummyVal.data()),
+ static_cast<unsigned int>(kPskDummyVal.size())};
+ PK11SymKey* key =
+ PK11_ImportSymKey(slot.get(), CKM_HKDF_KEY_GEN, PK11_OriginUnwrap,
+ CKA_DERIVE, &psk_item, NULL);
+ ASSERT_NE(nullptr, key);
+ ScopedPK11SymKey scoped_psk(key);
+ RolloverAntiReplay();
+ AddPsk(scoped_psk, std::string("foo"), ssl_hash_sha256,
+ TLS_CHACHA20_POLY1305_SHA256);
+ StartConnect();
+ RunTest(rollover, scoped_psk);
+ }
};
-TEST_P(TlsZeroRttReplayTest, ZeroRttReplay) { RunTest(false); }
+TEST_P(TlsZeroRttReplayTest, ResPskZeroRttReplay) { RunResPskTest(false); }
-TEST_P(TlsZeroRttReplayTest, ZeroRttReplayAfterRollover) { RunTest(true); }
+TEST_P(TlsZeroRttReplayTest, ExtPskZeroRttReplay) { RunExtPskTest(false); }
+
+TEST_P(TlsZeroRttReplayTest, ZeroRttReplayAfterRollover) {
+ RunResPskTest(true);
+}
// Test that we don't try to send 0-RTT data when the server sent
// us a ticket without the 0-RTT flags.
@@ -477,15 +509,6 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngradeEarlyData) {
client_->CheckErrorCode(SSL_ERROR_DOWNGRADE_WITH_EARLY_DATA);
}
-static void CheckEarlyDataLimit(const std::shared_ptr<TlsAgent>& agent,
- size_t expected_size) {
- SSLPreliminaryChannelInfo preinfo;
- SECStatus rv =
- SSL_GetPreliminaryChannelInfo(agent->ssl_fd(), &preinfo, sizeof(preinfo));
- EXPECT_EQ(SECSuccess, rv);
- EXPECT_EQ(expected_size, static_cast<size_t>(preinfo.maxEarlyDataSize));
-}
-
TEST_P(TlsConnectTls13, SendTooMuchEarlyData) {
EnsureTlsSetup();
const char* big_message = "0123456789abcdef";
diff --git a/gtests/ssl_gtest/ssl_extension_unittest.cc b/gtests/ssl_gtest/ssl_extension_unittest.cc
index b85568f43..fb995953f 100644
--- a/gtests/ssl_gtest/ssl_extension_unittest.cc
+++ b/gtests/ssl_gtest/ssl_extension_unittest.cc
@@ -946,6 +946,26 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) {
server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
}
+// Do the same with an External PSK.
+TEST_P(TlsConnectTls13, TestTls13PskInvalidBinderValue) {
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ ASSERT_TRUE(!!slot);
+ ScopedPK11SymKey key(
+ PK11_KeyGen(slot.get(), CKM_HKDF_KEY_GEN, nullptr, 16, nullptr));
+ ASSERT_TRUE(!!key);
+ AddPsk(key, std::string("foo"), ssl_hash_sha256);
+ StartConnect();
+ ASSERT_TRUE(client_->MaybeSetResumptionToken());
+
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
+ r->binders_[0].Write(0, r->binders_[0].data()[0] ^ 0xff, 1);
+ });
+ ConnectExpectAlert(server_, kTlsAlertDecryptError);
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+}
+
// Extend the binder by one.
TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) {
SetupForResume();
diff --git a/gtests/ssl_gtest/ssl_gtest.gyp b/gtests/ssl_gtest/ssl_gtest.gyp
index c44af7ed1..5491a0725 100644
--- a/gtests/ssl_gtest/ssl_gtest.gyp
+++ b/gtests/ssl_gtest/ssl_gtest.gyp
@@ -57,6 +57,7 @@
'tls_hkdf_unittest.cc',
'tls_esni_unittest.cc',
'tls_protect.cc',
+ 'tls_psk_unittest.cc',
'tls_subcerts_unittest.cc'
],
'dependencies': [
diff --git a/gtests/ssl_gtest/tls_agent.cc b/gtests/ssl_gtest/tls_agent.cc
index 06312f70b..fa02b0b88 100644
--- a/gtests/ssl_gtest/tls_agent.cc
+++ b/gtests/ssl_gtest/tls_agent.cc
@@ -73,8 +73,8 @@ TlsAgent::TlsAgent(const std::string& nm, Role rl, SSLProtocolVariant var)
falsestart_enabled_(false),
expected_version_(0),
expected_cipher_suite_(0),
- expect_resumption_(false),
expect_client_auth_(false),
+ expect_psk_(ssl_psk_none),
can_falsestart_hook_called_(false),
sni_hook_called_(false),
auth_certificate_hook_called_(false),
@@ -301,7 +301,7 @@ bool TlsAgent::MaybeSetResumptionToken() {
// rv is SECFailure with error set to SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR
// if the resumption token was bad (expired/malformed/etc.).
- if (expect_resumption_) {
+ if (expect_psk_ == ssl_psk_resume) {
// Only in case we expect resumption this has to be successful. We might
// not expect resumption due to some reason but the token is totally fine.
EXPECT_EQ(SECSuccess, rv);
@@ -309,8 +309,8 @@ bool TlsAgent::MaybeSetResumptionToken() {
if (rv != SECSuccess) {
EXPECT_EQ(SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR, PORT_GetError());
resumption_token_.clear();
- EXPECT_FALSE(expect_resumption_);
- if (expect_resumption_) return false;
+ EXPECT_FALSE(expect_psk_ == ssl_psk_resume);
+ if (expect_psk_ == ssl_psk_resume) return false;
}
}
@@ -634,7 +634,9 @@ void TlsAgent::CheckAuthType(SSLAuthType auth,
SSLSignatureScheme sig_scheme) const {
EXPECT_EQ(STATE_CONNECTED, state_);
EXPECT_EQ(auth, info_.authType);
- EXPECT_EQ(server_key_bits_, info_.authKeyBits);
+ if (auth != ssl_auth_psk) {
+ EXPECT_EQ(server_key_bits_, info_.authKeyBits);
+ }
if (expected_version_ < SSL_LIBRARY_VERSION_TLS_1_2) {
switch (auth) {
case ssl_auth_rsa_sign:
@@ -685,13 +687,31 @@ void TlsAgent::EnableFalseStart() {
SetOption(SSL_ENABLE_FALSE_START, PR_TRUE);
}
-void TlsAgent::ExpectResumption() { expect_resumption_ = true; }
+void TlsAgent::ExpectPsk() { expect_psk_ = ssl_psk_external; }
+
+void TlsAgent::ExpectResumption() { expect_psk_ = ssl_psk_resume; }
void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) {
EXPECT_TRUE(EnsureTlsSetup());
EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd(), val, len));
}
+void TlsAgent::AddPsk(const ScopedPK11SymKey& psk, std::string label,
+ SSLHashType hash, uint16_t zeroRttSuite) {
+ EXPECT_TRUE(EnsureTlsSetup());
+ EXPECT_EQ(SECSuccess, SSL_AddExternalPsk0Rtt(
+ ssl_fd(), psk.get(),
+ reinterpret_cast<const uint8_t*>(label.data()),
+ label.length(), hash, zeroRttSuite, 1000));
+}
+
+void TlsAgent::RemovePsk(std::string label) {
+ EXPECT_EQ(SECSuccess,
+ SSL_RemoveExternalPsk(
+ ssl_fd(), reinterpret_cast<const uint8_t*>(label.data()),
+ label.length()));
+}
+
void TlsAgent::CheckAlpn(SSLNextProtoState expected_state,
const std::string& expected) const {
SSLNextProtoState alpn_state;
@@ -821,22 +841,22 @@ void TlsAgent::CheckPreliminaryInfo() {
void TlsAgent::CheckCallbacks() const {
// If false start happens, the handshake is reported as being complete at the
// point that false start happens.
- if (expect_resumption_ || !falsestart_enabled_) {
+ if (expect_psk_ == ssl_psk_resume || !falsestart_enabled_) {
EXPECT_TRUE(handshake_callback_called_);
}
// These callbacks shouldn't fire if we are resuming, except on TLS 1.3.
if (role_ == SERVER) {
PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd(), ssl_server_name_xtn);
- EXPECT_EQ(((!expect_resumption_ && have_sni) ||
+ EXPECT_EQ(((expect_psk_ != ssl_psk_resume && have_sni) ||
expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3),
sni_hook_called_);
} else {
- EXPECT_EQ(!expect_resumption_, auth_certificate_hook_called_);
+ EXPECT_EQ(expect_psk_ == ssl_psk_none, auth_certificate_hook_called_);
// Note that this isn't unconditionally called, even with false start on.
// But the callback is only skipped if a cipher that is ridiculously weak
// (80 bits) is chosen. Don't test that: plan to remove bad ciphers.
- EXPECT_EQ(falsestart_enabled_ && !expect_resumption_,
+ EXPECT_EQ(falsestart_enabled_ && expect_psk_ != ssl_psk_resume,
can_falsestart_hook_called_);
}
}
@@ -872,7 +892,7 @@ void TlsAgent::ValidateCipherSpecs() {
} else {
// For DTLS 1.1 and 1.2, the last endpoint to send maintains a cipher spec
// until the holddown timer runs down.
- if (expect_resumption_) {
+ if (expect_psk_ == ssl_psk_resume) {
if (role_ == CLIENT) {
expected = 3;
}
@@ -910,7 +930,8 @@ void TlsAgent::Connected() {
EXPECT_EQ(SECSuccess, rv);
EXPECT_EQ(sizeof(info_), info_.length);
- EXPECT_EQ(expect_resumption_, info_.resumed == PR_TRUE);
+ EXPECT_EQ(expect_psk_ == ssl_psk_resume, info_.resumed == PR_TRUE);
+ EXPECT_EQ(expect_psk_, info_.pskType);
// Preliminary values are exposed through callbacks during the handshake.
// If either expected values were set or the callbacks were called, check
diff --git a/gtests/ssl_gtest/tls_agent.h b/gtests/ssl_gtest/tls_agent.h
index 19c84a82e..e43b68b2b 100644
--- a/gtests/ssl_gtest/tls_agent.h
+++ b/gtests/ssl_gtest/tls_agent.h
@@ -9,6 +9,7 @@
#include "prio.h"
#include "ssl.h"
+#include "sslproto.h"
#include <functional>
#include <iostream>
@@ -157,6 +158,7 @@ class TlsAgent : public PollTarget {
void SetServerKeyBits(uint16_t bits);
void ExpectReadWriteError();
void EnableFalseStart();
+ void ExpectPsk();
void ExpectResumption();
void SkipVersionChecks();
void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count);
@@ -176,6 +178,9 @@ class TlsAgent : public PollTarget {
// Send data directly to the underlying socket, skipping the TLS layer.
void SendDirect(const DataBuffer& buf);
void SendRecordDirect(const TlsRecord& record);
+ void AddPsk(const ScopedPK11SymKey& psk, std::string label, SSLHashType hash,
+ uint16_t zeroRttSuite = TLS_NULL_WITH_NULL_NULL);
+ void RemovePsk(std::string label);
void ReadBytes(size_t max = 16384U);
void ResetSentBytes(); // Hack to test drops.
void EnableExtendedMasterSecret();
@@ -248,6 +253,8 @@ class TlsAgent : public PollTarget {
return true;
}
+ void expected_cipher_suite(uint16_t suite) { expected_cipher_suite_ = suite; }
+
std::string cipher_suite_name() const {
if (state_ != STATE_CONNECTED) return "UNKNOWN";
@@ -418,8 +425,8 @@ class TlsAgent : public PollTarget {
bool falsestart_enabled_;
uint16_t expected_version_;
uint16_t expected_cipher_suite_;
- bool expect_resumption_;
bool expect_client_auth_;
+ SSLPskType expect_psk_;
bool can_falsestart_hook_called_;
bool sni_hook_called_;
bool auth_certificate_hook_called_;
diff --git a/gtests/ssl_gtest/tls_connect.cc b/gtests/ssl_gtest/tls_connect.cc
index 6d0cf89cf..915f3705a 100644
--- a/gtests/ssl_gtest/tls_connect.cc
+++ b/gtests/ssl_gtest/tls_connect.cc
@@ -401,6 +401,15 @@ void TlsConnectTestBase::CheckConnected() {
server_->CheckSecretsDestroyed();
}
+void TlsConnectTestBase::CheckEarlyDataLimit(
+ const std::shared_ptr<TlsAgent>& agent, size_t expected_size) {
+ SSLPreliminaryChannelInfo preinfo;
+ SECStatus rv =
+ SSL_GetPreliminaryChannelInfo(agent->ssl_fd(), &preinfo, sizeof(preinfo));
+ EXPECT_EQ(SECSuccess, rv);
+ EXPECT_EQ(expected_size, static_cast<size_t>(preinfo.maxEarlyDataSize));
+}
+
void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group,
SSLAuthType auth_type,
SSLSignatureScheme sig_scheme) const {
@@ -520,6 +529,14 @@ void TlsConnectTestBase::SetExpectedVersion(uint16_t version) {
server_->SetExpectedVersion(version);
}
+void TlsConnectTestBase::AddPsk(const ScopedPK11SymKey& psk, std::string label,
+ SSLHashType hash, uint16_t zeroRttSuite) {
+ client_->AddPsk(psk, label, hash, zeroRttSuite);
+ server_->AddPsk(psk, label, hash, zeroRttSuite);
+ client_->ExpectPsk();
+ server_->ExpectPsk();
+}
+
void TlsConnectTestBase::DisableAllCiphers() {
EnsureTlsSetup();
client_->DisableAllCiphers();
diff --git a/gtests/ssl_gtest/tls_connect.h b/gtests/ssl_gtest/tls_connect.h
index 23c60bf4f..3a43d6bca 100644
--- a/gtests/ssl_gtest/tls_connect.h
+++ b/gtests/ssl_gtest/tls_connect.h
@@ -80,6 +80,8 @@ class TlsConnectTestBase : public ::testing::Test {
void ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert);
void ConnectExpectFailOneSide(TlsAgent::Role failingSide);
void ConnectWithCipherSuite(uint16_t cipher_suite);
+ void CheckEarlyDataLimit(const std::shared_ptr<TlsAgent>& agent,
+ size_t expected_size);
// Check that the keys used in the handshake match expectations.
void CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group,
SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const;
@@ -120,6 +122,9 @@ class TlsConnectTestBase : public ::testing::Test {
void EnableSrtp();
void CheckSrtp() const;
void SendReceive(size_t total = 50);
+ void AddPsk(const ScopedPK11SymKey& psk, std::string label, SSLHashType hash,
+ uint16_t zeroRttSuite = TLS_NULL_WITH_NULL_NULL);
+ void RemovePsk(std::string label);
void SetupForZeroRtt();
void SetupForResume();
void ZeroRttSendReceive(
diff --git a/gtests/ssl_gtest/tls_psk_unittest.cc b/gtests/ssl_gtest/tls_psk_unittest.cc
new file mode 100644
index 000000000..c75297bc8
--- /dev/null
+++ b/gtests/ssl_gtest/tls_psk_unittest.cc
@@ -0,0 +1,514 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <functional>
+#include <memory>
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "tls_connect.h"
+
+namespace nss_test {
+
+class Tls13PskTest : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
+ public:
+ Tls13PskTest()
+ : TlsConnectTestBase(std::get<0>(GetParam()),
+ SSL_LIBRARY_VERSION_TLS_1_3),
+ suite_(std::get<1>(GetParam())) {}
+
+ void SetUp() override {
+ TlsConnectTestBase::SetUp();
+ scoped_psk_.reset(GetPsk());
+ ASSERT_TRUE(!!scoped_psk_);
+ }
+
+ private:
+ PK11SymKey* GetPsk() {
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ if (!slot) {
+ ADD_FAILURE();
+ return nullptr;
+ }
+
+ SECItem psk_item;
+ psk_item.type = siBuffer;
+ psk_item.len = sizeof(kPskDummyVal_);
+ psk_item.data = const_cast<uint8_t*>(kPskDummyVal_);
+
+ PK11SymKey* key =
+ PK11_ImportSymKey(slot.get(), CKM_HKDF_KEY_GEN, PK11_OriginUnwrap,
+ CKA_DERIVE, &psk_item, NULL);
+ if (!key) {
+ ADD_FAILURE();
+ }
+ return key;
+ }
+
+ protected:
+ ScopedPK11SymKey scoped_psk_;
+ const uint16_t suite_;
+ const uint8_t kPskDummyVal_[16] = {0x01, 0x02, 0x03, 0x04, 0x05,
+ 0x06, 0x07, 0x08, 0x09, 0x0a,
+ 0x0b, 0x0c, 0x0d, 0x0e, 0x0f};
+ const std::string kPskDummyLabel_ = "NSS PSK GTEST label";
+ const SSLHashType kPskHash_ = ssl_hash_sha384;
+};
+
+// TLS 1.3 PSK connection test.
+TEST_P(Tls13PskTest, NormalExternal) {
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none);
+ client_->RemovePsk(kPskDummyLabel_);
+ server_->RemovePsk(kPskDummyLabel_);
+
+ // Removing it again should fail.
+ EXPECT_EQ(SECFailure, SSL_RemoveExternalPsk(client_->ssl_fd(),
+ reinterpret_cast<const uint8_t*>(
+ kPskDummyLabel_.data()),
+ kPskDummyLabel_.length()));
+ EXPECT_EQ(SECFailure, SSL_RemoveExternalPsk(server_->ssl_fd(),
+ reinterpret_cast<const uint8_t*>(
+ kPskDummyLabel_.data()),
+ kPskDummyLabel_.length()));
+}
+
+TEST_P(Tls13PskTest, KeyTooLarge) {
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ ASSERT_TRUE(!!slot);
+ ScopedPK11SymKey scoped_psk(PK11_KeyGen(
+ slot.get(), CKM_GENERIC_SECRET_KEY_GEN, nullptr, 128, nullptr));
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none);
+}
+
+// Attempt to use a PSK with the wrong PRF hash.
+// "Clients MUST verify that...the server selected a cipher suite
+// indicating a Hash associated with the PSK"
+TEST_P(Tls13PskTest, ClientVerifyHashType) {
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ MakeTlsFilter<SelectedCipherSuiteReplacer>(server_,
+ TLS_CHACHA20_POLY1305_SHA256);
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ if (variant_ == ssl_variant_stream) {
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE, server_->error_code());
+ } else {
+ ConnectExpectFailOneSide(TlsAgent::CLIENT);
+ }
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code());
+}
+
+// Different EPSKs (by label) on each endpoint. Expect cert auth.
+TEST_P(Tls13PskTest, LabelMismatch) {
+ client_->AddPsk(scoped_psk_, std::string("foo"), kPskHash_);
+ server_->AddPsk(scoped_psk_, std::string("bar"), kPskHash_);
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
+}
+
+SSLHelloRetryRequestAction RetryFirstHello(
+ PRBool firstHello, const PRUint8* clientToken, unsigned int clientTokenLen,
+ PRUint8* appToken, unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) {
+ auto* called = reinterpret_cast<size_t*>(arg);
+ ++*called;
+ EXPECT_EQ(0U, clientTokenLen);
+ EXPECT_EQ(*called, firstHello ? 1U : 2U);
+ return firstHello ? ssl_hello_retry_request : ssl_hello_retry_accept;
+}
+
+// Test resumption PSK with HRR.
+TEST_P(Tls13PskTest, ResPskRetryStateless) {
+ ConfigureSelfEncrypt();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys();
+
+ Reset();
+ StartConnect();
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), RetryFirstHello, &cb_called));
+ ExpectResumption(RESUME_TICKET);
+ Handshake();
+ CheckConnected();
+ EXPECT_EQ(2U, cb_called);
+ CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
+ SendReceive();
+}
+
+// Test external PSK with HRR.
+TEST_P(Tls13PskTest, ExtPskRetryStateless) {
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), RetryFirstHello, &cb_called));
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ EXPECT_EQ(1U, cb_called);
+ auto replacement = std::make_shared<TlsAgent>(
+ server_->name(), TlsAgent::SERVER, server_->variant());
+ server_ = replacement;
+ server_->SetVersionRange(version_, version_);
+ client_->SetPeer(server_);
+ server_->SetPeer(client_);
+ server_->AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ server_->ExpectPsk();
+ server_->StartConnect();
+ Handshake();
+ CheckConnected();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none);
+}
+
+// Server not configured with PSK and sends a certificate instead of
+// a selected_identity. Client should attempt certificate authentication.
+TEST_P(Tls13PskTest, ClientOnly) {
+ client_->AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
+}
+
+// Set a PSK, remove psk_key_exchange_modes.
+TEST_P(Tls13PskTest, DropKexModes) {
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ StartConnect();
+ MakeTlsFilter<TlsExtensionDropper>(client_,
+ ssl_tls13_psk_key_exchange_modes_xtn);
+ ConnectExpectAlert(server_, kTlsAlertMissingExtension);
+ client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_MISSING_PSK_KEY_EXCHANGE_MODES);
+}
+
+// "Clients MUST verify that...a server "key_share" extension is present
+// if required by the ClientHello "psk_key_exchange_modes" extension."
+// As we don't support PSK without DH, it is always required.
+TEST_P(Tls13PskTest, DropRequiredKeyShare) {
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ StartConnect();
+ MakeTlsFilter<TlsExtensionDropper>(server_, ssl_tls13_key_share_xtn);
+ client_->ExpectSendAlert(kTlsAlertMissingExtension);
+ if (variant_ == ssl_variant_stream) {
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+ } else {
+ ConnectExpectFailOneSide(TlsAgent::CLIENT);
+ }
+ client_->CheckErrorCode(SSL_ERROR_MISSING_KEY_SHARE);
+}
+
+// "Clients MUST verify that...the server's selected_identity is
+// within the range supplied by the client". We send one OfferedPsk.
+TEST_P(Tls13PskTest, InvalidSelectedIdentity) {
+ static const uint8_t selected_identity[] = {0x00, 0x01};
+ DataBuffer buf(selected_identity, sizeof(selected_identity));
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ StartConnect();
+ MakeTlsFilter<TlsExtensionReplacer>(server_, ssl_tls13_pre_shared_key_xtn,
+ buf);
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ if (variant_ == ssl_variant_stream) {
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+ } else {
+ ConnectExpectFailOneSide(TlsAgent::CLIENT);
+ }
+ client_->CheckErrorCode(SSL_ERROR_MALFORMED_PRE_SHARED_KEY);
+}
+
+// Resume-eligible reconnect with an EPSK configured.
+// Expect the EPSK to be used.
+TEST_P(Tls13PskTest, PreferEpsk) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys();
+
+ Reset();
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ ExpectResumption(RESUME_NONE);
+ StartConnect();
+ Handshake();
+ CheckConnected();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none);
+}
+
+// Enable resumption, but connect (initially) with an EPSK.
+// Expect no session ticket.
+TEST_P(Tls13PskTest, SuppressNewSessionTicket) {
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ auto nst_capture =
+ MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_new_session_ticket);
+ nst_capture->EnableDecryption();
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none);
+ EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), nullptr, 0));
+ EXPECT_EQ(0U, nst_capture->buffer().len());
+ if (variant_ == ssl_variant_stream) {
+ EXPECT_EQ(SSL_ERROR_FEATURE_DISABLED, PORT_GetError());
+ } else {
+ EXPECT_EQ(SSL_ERROR_FEATURE_NOT_SUPPORTED_FOR_VERSION, PORT_GetError());
+ }
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ ExpectResumption(RESUME_NONE);
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none);
+}
+
+TEST_P(Tls13PskTest, BadConfigValues) {
+ EXPECT_TRUE(client_->EnsureTlsSetup());
+ std::vector<uint8_t> label{'L', 'A', 'B', 'E', 'L'};
+ EXPECT_EQ(SECFailure,
+ SSL_AddExternalPsk(client_->ssl_fd(), nullptr, label.data(),
+ label.size(), kPskHash_));
+ EXPECT_EQ(SECFailure, SSL_AddExternalPsk(client_->ssl_fd(), scoped_psk_.get(),
+ nullptr, label.size(), kPskHash_));
+
+ EXPECT_EQ(SECFailure, SSL_AddExternalPsk(client_->ssl_fd(), scoped_psk_.get(),
+ label.data(), 0, kPskHash_));
+ EXPECT_EQ(SECSuccess,
+ SSL_AddExternalPsk(client_->ssl_fd(), scoped_psk_.get(),
+ label.data(), label.size(), ssl_hash_sha256));
+
+ EXPECT_EQ(SECFailure,
+ SSL_RemoveExternalPsk(client_->ssl_fd(), nullptr, label.size()));
+
+ EXPECT_EQ(SECFailure,
+ SSL_RemoveExternalPsk(client_->ssl_fd(), label.data(), 0));
+
+ EXPECT_EQ(SECSuccess, SSL_RemoveExternalPsk(client_->ssl_fd(), label.data(),
+ label.size()));
+}
+
+// If the server has an EPSK configured with a ciphersuite not supported
+// by the client, it should use certificate authentication.
+TEST_P(Tls13PskTest, FallbackUnsupportedCiphersuite) {
+ client_->AddPsk(scoped_psk_, kPskDummyLabel_, ssl_hash_sha256,
+ TLS_AES_128_GCM_SHA256);
+ server_->AddPsk(scoped_psk_, kPskDummyLabel_, ssl_hash_sha256,
+ TLS_CHACHA20_POLY1305_SHA256);
+
+ client_->EnableSingleCipher(TLS_AES_128_GCM_SHA256);
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
+}
+
+// That fallback should not occur if there is no cipher overlap.
+TEST_P(Tls13PskTest, ExplicitSuiteNoOverlap) {
+ client_->AddPsk(scoped_psk_, kPskDummyLabel_, ssl_hash_sha256,
+ TLS_AES_128_GCM_SHA256);
+ server_->AddPsk(scoped_psk_, kPskDummyLabel_, ssl_hash_sha256,
+ TLS_CHACHA20_POLY1305_SHA256);
+
+ client_->EnableSingleCipher(TLS_AES_128_GCM_SHA256);
+ server_->EnableSingleCipher(TLS_CHACHA20_POLY1305_SHA256);
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+TEST_P(Tls13PskTest, SuppressHandshakeCertReq) {
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ server_->SetOption(SSL_REQUEST_CERTIFICATE, PR_TRUE);
+ server_->SetOption(SSL_REQUIRE_CERTIFICATE, PR_TRUE);
+ const std::set<uint8_t> hs_types = {ssl_hs_certificate,
+ ssl_hs_certificate_request};
+ auto cr_cert_capture = MakeTlsFilter<TlsHandshakeRecorder>(server_, hs_types);
+ cr_cert_capture->EnableDecryption();
+
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none);
+ EXPECT_EQ(0U, cr_cert_capture->buffer().len());
+}
+
+TEST_P(Tls13PskTest, DisallowClientConfigWithoutServerCert) {
+ AddPsk(scoped_psk_, kPskDummyLabel_, kPskHash_);
+ server_->SetOption(SSL_REQUEST_CERTIFICATE, PR_TRUE);
+ server_->SetOption(SSL_REQUIRE_CERTIFICATE, PR_TRUE);
+ const std::set<uint8_t> hs_types = {ssl_hs_certificate,
+ ssl_hs_certificate_request};
+ auto cr_cert_capture = MakeTlsFilter<TlsHandshakeRecorder>(server_, hs_types);
+ cr_cert_capture->EnableDecryption();
+
+ EXPECT_EQ(SECSuccess, SSLInt_RemoveServerCertificates(server_->ssl_fd()));
+
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ server_->CheckErrorCode(SSL_ERROR_NO_CERTIFICATE);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ EXPECT_EQ(0U, cr_cert_capture->buffer().len());
+}
+
+TEST_F(TlsConnectStreamTls13, ClientRejectHandshakeCertReq) {
+ // Stream only, as the filter doesn't support DTLS 1.3 yet.
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ ASSERT_TRUE(!!slot);
+ ScopedPK11SymKey scoped_psk(PK11_KeyGen(
+ slot.get(), CKM_GENERIC_SECRET_KEY_GEN, nullptr, 32, nullptr));
+ AddPsk(scoped_psk, std::string("foo"), ssl_hash_sha256);
+ // Inject a CR after EE. This would be legal if not for ssl_auth_psk.
+ auto filter = MakeTlsFilter<TlsEncryptedHandshakeMessageReplacer>(
+ server_, kTlsHandshakeFinished, kTlsHandshakeCertificateRequest);
+ filter->EnableDecryption();
+
+ ExpectAlert(client_, kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CERT_REQUEST);
+ server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+}
+
+TEST_F(TlsConnectStreamTls13, RejectPha) {
+ // Stream only, as the filter doesn't support DTLS 1.3 yet.
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ ASSERT_TRUE(!!slot);
+ ScopedPK11SymKey scoped_psk(PK11_KeyGen(
+ slot.get(), CKM_GENERIC_SECRET_KEY_GEN, nullptr, 32, nullptr));
+ AddPsk(scoped_psk, std::string("foo"), ssl_hash_sha256);
+ server_->SetOption(SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE);
+ auto kuToCr = MakeTlsFilter<TlsEncryptedHandshakeMessageReplacer>(
+ server_, kTlsHandshakeKeyUpdate, kTlsHandshakeCertificateRequest);
+ kuToCr->EnableDecryption();
+ Connect();
+
+ // Make sure the direct path is blocked.
+ EXPECT_EQ(SECFailure, SSL_SendCertificateRequest(server_->ssl_fd()));
+ EXPECT_EQ(SSL_ERROR_FEATURE_DISABLED, PORT_GetError());
+
+ // Inject a PHA CR. Since this is not allowed, send KeyUpdate
+ // and change the message type.
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ ExpectAlert(client_, kTlsAlertUnexpectedMessage);
+ client_->Handshake(); // Eat the CR.
+ server_->Handshake();
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CERT_REQUEST);
+ server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+}
+
+class Tls13PskTestWithCiphers : public Tls13PskTest {};
+
+TEST_P(Tls13PskTestWithCiphers, 0RttCiphers) {
+ RolloverAntiReplay();
+ AddPsk(scoped_psk_, kPskDummyLabel_, tls13_GetHashForCipherSuite(suite_),
+ suite_);
+ StartConnect();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ZeroRttSendReceive(true, true);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_psk, ssl_sig_none);
+}
+
+TEST_P(Tls13PskTestWithCiphers, 0RttMaxEarlyData) {
+ EnsureTlsSetup();
+ RolloverAntiReplay();
+ const char* big_message = "0123456789abcdef";
+ const size_t short_size = strlen(big_message) - 1;
+ const PRInt32 short_length = static_cast<PRInt32>(short_size);
+
+ // Set up the PSK
+ EXPECT_EQ(SECSuccess,
+ SSL_AddExternalPsk0Rtt(
+ client_->ssl_fd(), scoped_psk_.get(),
+ reinterpret_cast<const uint8_t*>(kPskDummyLabel_.data()),
+ kPskDummyLabel_.length(), tls13_GetHashForCipherSuite(suite_),
+ suite_, short_length));
+ EXPECT_EQ(SECSuccess,
+ SSL_AddExternalPsk0Rtt(
+ server_->ssl_fd(), scoped_psk_.get(),
+ reinterpret_cast<const uint8_t*>(kPskDummyLabel_.data()),
+ kPskDummyLabel_.length(), tls13_GetHashForCipherSuite(suite_),
+ suite_, short_length));
+ client_->ExpectPsk();
+ server_->ExpectPsk();
+ client_->expected_cipher_suite(suite_);
+ server_->expected_cipher_suite(suite_);
+ StartConnect();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ client_->Handshake();
+ CheckEarlyDataLimit(client_, short_size);
+
+ PRInt32 sent;
+ // Writing more than the limit will succeed in TLS, but fail in DTLS.
+ if (variant_ == ssl_variant_stream) {
+ sent = PR_Write(client_->ssl_fd(), big_message,
+ static_cast<PRInt32>(strlen(big_message)));
+ } else {
+ sent = PR_Write(client_->ssl_fd(), big_message,
+ static_cast<PRInt32>(strlen(big_message)));
+ EXPECT_GE(0, sent);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ // Try an exact-sized write now.
+ sent = PR_Write(client_->ssl_fd(), big_message, short_length);
+ }
+ EXPECT_EQ(short_length, sent);
+
+ // Even a single octet write should now fail.
+ sent = PR_Write(client_->ssl_fd(), big_message, 1);
+ EXPECT_GE(0, sent);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ // Process the ClientHello and read 0-RTT.
+ server_->Handshake();
+ CheckEarlyDataLimit(server_, short_size);
+
+ std::vector<uint8_t> buf(short_size + 1);
+ PRInt32 read = PR_Read(server_->ssl_fd(), buf.data(), buf.capacity());
+ EXPECT_EQ(short_length, read);
+ EXPECT_EQ(0, memcmp(big_message, buf.data(), short_size));
+
+ // Second read fails.
+ read = PR_Read(server_->ssl_fd(), buf.data(), buf.capacity());
+ EXPECT_EQ(SECFailure, read);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+}
+
+static const uint16_t k0RttCipherDefs[] = {TLS_CHACHA20_POLY1305_SHA256,
+ TLS_AES_128_GCM_SHA256,
+ TLS_AES_256_GCM_SHA384};
+
+static const uint16_t kDefaultSuite[] = {TLS_CHACHA20_POLY1305_SHA256};
+
+INSTANTIATE_TEST_CASE_P(Tls13PskTest, Tls13PskTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ ::testing::ValuesIn(kDefaultSuite)));
+
+INSTANTIATE_TEST_CASE_P(
+ Tls13PskTestWithCiphers, Tls13PskTestWithCiphers,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ ::testing::ValuesIn(k0RttCipherDefs)));
+
+} // namespace nss_test
diff --git a/lib/ssl/manifest.mn b/lib/ssl/manifest.mn
index 39e013641..5b3584ba9 100644
--- a/lib/ssl/manifest.mn
+++ b/lib/ssl/manifest.mn
@@ -55,6 +55,7 @@ CSRCS = \
tls13exthandle.c \
tls13hashstate.c \
tls13hkdf.c \
+ tls13psk.c \
tls13replay.c \
sslcert.c \
sslgrp.c \
diff --git a/lib/ssl/ssl.gyp b/lib/ssl/ssl.gyp
index 3e1b5531a..5c84a1f03 100644
--- a/lib/ssl/ssl.gyp
+++ b/lib/ssl/ssl.gyp
@@ -48,6 +48,7 @@
'tls13exthandle.c',
'tls13hashstate.c',
'tls13hkdf.c',
+ 'tls13psk.c',
'tls13replay.c',
'tls13subcerts.c',
],
diff --git a/lib/ssl/ssl3con.c b/lib/ssl/ssl3con.c
index 7f581e792..930635850 100644
--- a/lib/ssl/ssl3con.c
+++ b/lib/ssl/ssl3con.c
@@ -21,6 +21,7 @@
#include "sslerr.h"
#include "ssl3ext.h"
#include "ssl3exthandle.h"
+#include "tls13psk.h"
#include "tls13subcerts.h"
#include "prtime.h"
#include "prinrval.h"
@@ -912,6 +913,13 @@ ssl3_config_match_init(sslSocket *ss)
if (SSL_ALL_VERSIONS_DISABLED(&ss->vrange)) {
return 0;
}
+ if (ss->sec.isServer && ss->psk &&
+ PR_CLIST_IS_EMPTY(&ss->serverCerts) &&
+ (ss->opt.requestCertificate || ss->opt.requireCertificate)) {
+ /* PSK and certificate auth cannot be combined. */
+ PORT_SetError(SSL_ERROR_NO_CERTIFICATE);
+ return 0;
+ }
if (ssl_CheckSignatureSchemes(ss) != SECSuccess) {
return 0; /* Code already set. */
}
@@ -1009,6 +1017,16 @@ ssl3_config_match(const ssl3CipherSuiteCfg *suite, PRUint8 policy,
return PR_FALSE;
}
+ /* If a PSK is selected, disable suites that use a different hash than
+ * the PSK. We advertise non-PSK-compatible suites in the CH, as we could
+ * fallback to certificate auth. The client handler will check hash
+ * compatibility before committing to use the PSK. */
+ if (ss->xtnData.selectedPsk) {
+ if (ss->xtnData.selectedPsk->hash != cipher_def->prf_hash) {
+ return PR_FALSE;
+ }
+ }
+
return ssl3_CipherSuiteAllowedForVersionRange(suite->cipher_suite, vrange);
}
@@ -5333,10 +5351,11 @@ ssl3_SendClientHello(sslSocket *ss, sslClientHelloType type)
}
if (extensionBuf.len) {
- /* If we are sending a PSK binder, replace the dummy value. Note that
- * we only set statelessResume on the client in TLS 1.3. */
- if (ss->statelessResume &&
- ss->xtnData.sentSessionTicketInClientHello) {
+ /* If we are sending a PSK binder, replace the dummy value. */
+ if (ssl3_ExtensionAdvertised(ss, ssl_tls13_pre_shared_key_xtn)) {
+ PORT_Assert(ss->psk ||
+ (ss->statelessResume && ss->xtnData.sentSessionTicketInClientHello));
+ PORT_Assert(!PR_CLIST_IS_EMPTY(&ss->ssl3.hs.psks));
rv = tls13_WriteExtensionsWithBinder(ss, &extensionBuf);
} else {
rv = ssl3_AppendBufferToHandshakeVariable(ss, &extensionBuf, 2);
@@ -8105,26 +8124,53 @@ ssl3_KEASupportsTickets(const ssl3KEADef *kea_def)
return PR_TRUE;
}
+static PRBool
+ssl3_PeerSupportsCipherSuite(const SECItem *peerSuites, uint16_t suite)
+{
+ for (unsigned int i = 0; i + 1 < peerSuites->len; i += 2) {
+ PRUint16 suite_i = (peerSuites->data[i] << 8) | peerSuites->data[i + 1];
+ if (suite_i == suite) {
+ return PR_TRUE;
+ }
+ }
+ return PR_FALSE;
+}
+
SECStatus
ssl3_NegotiateCipherSuiteInner(sslSocket *ss, const SECItem *suites,
PRUint16 version, PRUint16 *suitep)
{
- unsigned int j;
unsigned int i;
+ SSLVersionRange vrange = { version, version };
- for (j = 0; j < ssl_V3_SUITES_IMPLEMENTED; j++) {
- ssl3CipherSuiteCfg *suite = &ss->cipherSuites[j];
- SSLVersionRange vrange = { version, version };
+ /* If we negotiated an External PSK and that PSK has a ciphersuite
+ * configured, we need to constrain our choice. If the client does
+ * not support it, negotiate a certificate auth suite and fall back.
+ */
+ if (ss->version >= SSL_LIBRARY_VERSION_TLS_1_3 &&
+ ss->xtnData.selectedPsk &&
+ ss->xtnData.selectedPsk->type == ssl_psk_external &&
+ ss->xtnData.selectedPsk->zeroRttSuite != TLS_NULL_WITH_NULL_NULL) {
+ PRUint16 pskSuite = ss->xtnData.selectedPsk->zeroRttSuite;
+ ssl3CipherSuiteCfg *pskSuiteCfg = ssl_LookupCipherSuiteCfgMutable(pskSuite,
+ ss->cipherSuites);
+ if (ssl3_config_match(pskSuiteCfg, ss->ssl3.policy, &vrange, ss) &&
+ ssl3_PeerSupportsCipherSuite(suites, pskSuite)) {
+ *suitep = pskSuite;
+ return SECSuccess;
+ }
+ }
+
+ for (i = 0; i < ssl_V3_SUITES_IMPLEMENTED; i++) {
+ ssl3CipherSuiteCfg *suite = &ss->cipherSuites[i];
if (!ssl3_config_match(suite, ss->ssl3.policy, &vrange, ss)) {
continue;
}
- for (i = 0; i + 1 < suites->len; i += 2) {
- PRUint16 suite_i = (suites->data[i] << 8) | suites->data[i + 1];
- if (suite_i == suite->cipher_suite) {
- *suitep = suite_i;
- return SECSuccess;
- }
+ if (!ssl3_PeerSupportsCipherSuite(suites, suite->cipher_suite)) {
+ continue;
}
+ *suitep = suite->cipher_suite;
+ return SECSuccess;
}
PORT_SetError(SSL_ERROR_NO_CYPHER_OVERLAP);
return SECFailure;
@@ -13102,7 +13148,6 @@ ssl3_InitState(sslSocket *ss)
ss->ssl3.hs.currentSecret = NULL;
ss->ssl3.hs.resumptionMasterSecret = NULL;
ss->ssl3.hs.dheSecret = NULL;
- ss->ssl3.hs.pskBinderKey = NULL;
ss->ssl3.hs.clientEarlyTrafficSecret = NULL;
ss->ssl3.hs.clientHsTrafficSecret = NULL;
ss->ssl3.hs.serverHsTrafficSecret = NULL;
@@ -13476,8 +13521,6 @@ ssl3_DestroySSL3Info(sslSocket *ss)
PK11_FreeSymKey(ss->ssl3.hs.resumptionMasterSecret);
if (ss->ssl3.hs.dheSecret)
PK11_FreeSymKey(ss->ssl3.hs.dheSecret);
- if (ss->ssl3.hs.pskBinderKey)
- PK11_FreeSymKey(ss->ssl3.hs.pskBinderKey);
if (ss->ssl3.hs.clientEarlyTrafficSecret)
PK11_FreeSymKey(ss->ssl3.hs.clientEarlyTrafficSecret);
if (ss->ssl3.hs.clientHsTrafficSecret)
@@ -13496,6 +13539,8 @@ ssl3_DestroySSL3Info(sslSocket *ss)
ss->ssl3.hs.zeroRttState = ssl_0rtt_none;
/* Destroy TLS 1.3 buffered early data. */
tls13_DestroyEarlyData(&ss->ssl3.hs.bufferedEarlyData);
+ /* Destroy TLS 1.3 PSKs */
+ tls13_DestroyPskList(&ss->ssl3.hs.psks);
}
#define MAP_NULL(x) (((x) != 0) ? (x) : SEC_OID_NULL_CIPHER)
diff --git a/lib/ssl/ssl3ext.c b/lib/ssl/ssl3ext.c
index 1cad98a7f..65a69450d 100644
--- a/lib/ssl/ssl3ext.c
+++ b/lib/ssl/ssl3ext.c
@@ -10,6 +10,7 @@
#include "nssrenam.h"
#include "nss.h"
+#include "pk11pub.h"
#include "ssl.h"
#include "sslimpl.h"
#include "sslproto.h"
@@ -962,6 +963,7 @@ ssl3_InitExtensionData(TLSExtensionData *xtnData, const sslSocket *ss)
xtnData->peerDelegCred = NULL;
xtnData->peerRequestedDelegCred = PR_FALSE;
xtnData->sendingDelegCredToPeer = PR_FALSE;
+ xtnData->selectedPsk = NULL;
}
void
diff --git a/lib/ssl/ssl3ext.h b/lib/ssl/ssl3ext.h
index 7f09e5fd7..ff2f7c211 100644
--- a/lib/ssl/ssl3ext.h
+++ b/lib/ssl/ssl3ext.h
@@ -134,6 +134,10 @@ struct TLSExtensionDataStr {
* |tls13_MaybeSetDelegatedCredential|.
*/
PRBool sendingDelegCredToPeer;
+
+ /* A non-owning reference to the selected PSKs. MUST NOT be freed directly,
+ * rather through tls13_DestoryPskList(). */
+ sslPsk *selectedPsk;
};
typedef struct TLSExtensionStr {
diff --git a/lib/ssl/sslerr.h b/lib/ssl/sslerr.h
index bc2785f9a..eb8f7c2da 100644
--- a/lib/ssl/sslerr.h
+++ b/lib/ssl/sslerr.h
@@ -275,6 +275,7 @@ typedef enum {
SSL_ERROR_DC_INVALID_KEY_USAGE = (SSL_ERROR_BASE + 184),
SSL_ERROR_DC_EXPIRED = (SSL_ERROR_BASE + 185),
SSL_ERROR_DC_INAPPROPRIATE_VALIDITY_PERIOD = (SSL_ERROR_BASE + 186),
+ SSL_ERROR_FEATURE_DISABLED = (SSL_ERROR_BASE + 187),
SSL_ERROR_END_OF_LIST /* let the c compiler determine the value of this. */
} SSLErrorCodes;
#endif /* NO_SECURITY_ERROR_ENUM */
diff --git a/lib/ssl/sslexp.h b/lib/ssl/sslexp.h
index fb3d612c1..8a92a39ad 100644
--- a/lib/ssl/sslexp.h
+++ b/lib/ssl/sslexp.h
@@ -254,7 +254,8 @@ typedef struct SSLAntiReplayContextStr SSLAntiReplayContext;
*
* This function will fail unless the socket has an active TLS 1.3 session.
* Earlier versions of TLS do not support the spontaneous sending of the
- * NewSessionTicket message.
+ * NewSessionTicket message. It will also fail when external PSK
+ * authentication has been negotiated.
*/
#define SSL_SendSessionTicket(fd, appToken, appTokenLen) \
SSL_EXPERIMENTAL_API("SSL_SendSessionTicket", \
@@ -380,6 +381,10 @@ typedef SSLHelloRetryRequestAction(PR_CALLBACK *SSLHelloRetryRequestCallback)(
* a server. This can be called once at a time, and is not allowed
* until an answer is received.
*
+ * This function is not allowed for use with DTLS or when external
+ * PSK authentication has been negotiated. SECFailure is returned
+ * in both cases.
+ *
* The AuthCertificateCallback is called when the answer is received.
* If the answer is accepted by the server, the value returned by
* SSL_PeerCertificate() is replaced. If you need to remember all the
@@ -947,6 +952,51 @@ typedef struct SSLMaskingContextStr {
SSL_EXPERIMENTAL_API("SSL_SetDtls13VersionWorkaround", \
(PRFileDesc * _fd, PRBool _enabled), (fd, enabled))
+/* SSL_AddExternalPsk() and SSL_AddExternalPsk0Rtt() can be used to
+ * set an external PSK on a socket. If successful, this PSK will
+ * be used in all subsequent connection attempts for this socket.
+ * This has no effect if the maximum TLS version is < 1.3.
+ *
+ * This API currently only accepts a single PSK, so multiple calls to
+ * either function will fail. An EPSK can be replaced by calling
+ * SSL_RemoveExternalPsk followed by SSL_AddExternalPsk.
+ * For both functions, the label is expected to be a unique identifier
+ * for the external PSK. Should en external PSK have the same label
+ * as a configured resumption PSK identity, the external PSK will
+ * take precedence.
+ *
+ * If you want to enable early data, you need to also provide a
+ * cipher suite for 0-RTT and a limit for the early data using
+ * SSL_AddExternalPsk0Rtt(). If you want to explicitly disallow
+ * certificate authentication, use SSL_AuthCertificateHook to set
+ * a callback that rejects all certificate chains.
+ */
+#define SSL_AddExternalPsk(fd, psk, identity, identityLen, hash) \
+ SSL_EXPERIMENTAL_API("SSL_AddExternalPsk", \
+ (PRFileDesc * _fd, PK11SymKey * _psk, \
+ const PRUint8 *_identity, unsigned int _identityLen, \
+ SSLHashType _hash), \
+ (fd, psk, identity, identityLen, hash))
+
+#define SSL_AddExternalPsk0Rtt(fd, psk, identity, identityLen, hash, \
+ zeroRttSuite, maxEarlyData) \
+ SSL_EXPERIMENTAL_API("SSL_AddExternalPsk0Rtt", \
+ (PRFileDesc * _fd, PK11SymKey * _psk, \
+ const PRUint8 *_identity, unsigned int _identityLen, \
+ SSLHashType _hash, PRUint16 _zeroRttSuite, \
+ PRUint32 _maxEarlyData), \
+ (fd, psk, identity, identityLen, hash, \
+ zeroRttSuite, maxEarlyData))
+
+/* SSLExp_RemoveExternalPsk() removes an external PSK from socket
+ * configuration. Returns SECSuccess if the PSK was removed
+ * successfully, and SECFailure otherwise. */
+#define SSL_RemoveExternalPsk(fd, identity, identityLen) \
+ SSL_EXPERIMENTAL_API("SSL_RemoveExternalPsk", \
+ (PRFileDesc * _fd, const PRUint8 *_identity, \
+ unsigned int _identityLen), \
+ (fd, identity, identityLen))
+
/* Deprecated experimental APIs */
#define SSL_UseAltServerHelloType(fd, enable) SSL_DEPRECATED_EXPERIMENTAL_API
#define SSL_SetupAntiReplay(a, b, c) SSL_DEPRECATED_EXPERIMENTAL_API
diff --git a/lib/ssl/sslimpl.h b/lib/ssl/sslimpl.h
index a3f3fe2e7..4a1d61739 100644
--- a/lib/ssl/sslimpl.h
+++ b/lib/ssl/sslimpl.h
@@ -37,6 +37,7 @@
typedef struct sslSocketStr sslSocket;
typedef struct sslNamedGroupDefStr sslNamedGroupDef;
typedef struct sslEsniKeysStr sslEsniKeys;
+typedef struct sslPskStr sslPsk;
typedef struct sslDelegatedCredentialStr sslDelegatedCredential;
typedef struct sslEphemeralKeyPairStr sslEphemeralKeyPair;
typedef struct TLS13KeyShareEntryStr TLS13KeyShareEntry;
@@ -691,9 +692,8 @@ typedef struct SSL3HandshakeStateStr {
/* This group of values is used for TLS 1.3 and above */
PK11SymKey *currentSecret; /* The secret down the "left hand side"
* of the TLS 1.3 key schedule. */
- PK11SymKey *resumptionMasterSecret; /* The resumption PSK. */
+ PK11SymKey *resumptionMasterSecret; /* The resumption_master_secret. */
PK11SymKey *dheSecret; /* The (EC)DHE shared secret. */
- PK11SymKey *pskBinderKey; /* Used to compute the PSK binder. */
PK11SymKey *clientEarlyTrafficSecret; /* The secret we use for 0-RTT. */
PK11SymKey *clientHsTrafficSecret; /* The source keys for handshake */
PK11SymKey *serverHsTrafficSecret; /* traffic keys. */
@@ -724,6 +724,7 @@ typedef struct SSL3HandshakeStateStr {
PRCList dtlsSentHandshake; /* Used to map records to handshake fragments. */
PRCList dtlsRcvdHandshake; /* Handshake records we have received
* used to generate ACKs. */
+ PRCList psks; /* A list of PSKs, resumption and/or external. */
} SSL3HandshakeState;
#define SSL_ASSERT_HASHES_EMPTY(ss) \
@@ -1101,6 +1102,9 @@ struct sslSocketStr {
/* Anti-replay for TLS 1.3 0-RTT. */
SSLAntiReplayContext *antiReplay;
+
+ /* An out-of-band PSK. */
+ sslPsk *psk;
};
struct sslSelfEncryptKeysStr {
diff --git a/lib/ssl/sslinfo.c b/lib/ssl/sslinfo.c
index 115c38dc1..18a03949b 100644
--- a/lib/ssl/sslinfo.c
+++ b/lib/ssl/sslinfo.c
@@ -7,6 +7,7 @@
#include "sslimpl.h"
#include "sslproto.h"
#include "tls13hkdf.h"
+#include "tls13psk.h"
#include "tls13subcerts.h"
SECStatus
@@ -80,6 +81,13 @@ SSL_GetChannelInfo(PRFileDesc *fd, SSLChannelInfo *info, PRUintn len)
inf.signatureScheme = sid->sigScheme;
}
inf.resumed = ss->statelessResume || ss->ssl3.hs.isResuming;
+ if (inf.resumed) {
+ inf.pskType = ssl_psk_resume;
+ } else if (inf.authType == ssl_auth_psk) {
+ inf.pskType = ssl_psk_external;
+ } else {
+ inf.pskType = ssl_psk_none;
+ }
inf.peerDelegCred = tls13_IsVerifyingWithDelegatedCredential(ss);
if (sid) {
@@ -147,8 +155,14 @@ SSL_GetPreliminaryChannelInfo(PRFileDesc *fd,
if (ss->sec.ci.sid &&
(ss->ssl3.hs.zeroRttState == ssl_0rtt_sent ||
ss->ssl3.hs.zeroRttState == ssl_0rtt_accepted)) {
- inf.maxEarlyDataSize =
- ss->sec.ci.sid->u.ssl3.locked.sessionTicket.max_early_data_size;
+ if (ss->statelessResume) {
+ inf.maxEarlyDataSize =
+ ss->sec.ci.sid->u.ssl3.locked.sessionTicket.max_early_data_size;
+ } else if (ss->psk) {
+ /* We may have cleared the handshake list, so check the socket.
+ * This is permissable since we only support one EPSK at a time. */
+ inf.maxEarlyDataSize = ss->psk->maxEarlyData;
+ }
} else {
inf.maxEarlyDataSize = 0;
}
@@ -416,19 +430,20 @@ tls13_Exporter(sslSocket *ss, PK11SymKey *secret,
}
/* Pre-hash the context. */
- rv = tls13_ComputeHash(ss, &contextHash, context, contextLen);
+ SSLHashType hashAlg = tls13_GetHash(ss);
+ rv = tls13_ComputeHash(ss, &contextHash, context, contextLen, hashAlg);
if (rv != SECSuccess) {
return rv;
}
rv = tls13_DeriveSecretNullHash(ss, secret, label, labelLen,
- &innerSecret);
+ &innerSecret, hashAlg);
if (rv != SECSuccess) {
return rv;
}
rv = tls13_HkdfExpandLabelRaw(innerSecret,
- tls13_GetHash(ss),
+ hashAlg,
contextHash.u.raw, contextHash.len,
kExporterInnerLabel,
strlen(kExporterInnerLabel),
diff --git a/lib/ssl/sslsecur.c b/lib/ssl/sslsecur.c
index 14320fa19..ef978c90a 100644
--- a/lib/ssl/sslsecur.c
+++ b/lib/ssl/sslsecur.c
@@ -15,6 +15,7 @@
#include "pk11func.h" /* for PK11_GenerateRandom */
#include "nss.h" /* for NSS_RegisterShutdown */
#include "prinit.h" /* for PR_CallOnceWithArg */
+#include "tls13psk.h"
/* Step through the handshake functions.
*
@@ -173,6 +174,7 @@ SSL_ResetHandshake(PRFileDesc *s, PRBool asServer)
ssl3_DestroyRemoteExtensions(&ss->ssl3.hs.remoteExtensions);
ssl3_ResetExtensionData(&ss->xtnData, ss);
+ tls13_ResetHandshakePsks(ss, &ss->ssl3.hs.psks);
if (!ss->TCPconnected)
ss->TCPconnected = (PR_SUCCESS == ssl_DefGetpeername(ss, &addr));
diff --git a/lib/ssl/sslsock.c b/lib/ssl/sslsock.c
index 9602919e3..83372104e 100644
--- a/lib/ssl/sslsock.c
+++ b/lib/ssl/sslsock.c
@@ -20,6 +20,7 @@
#include "pk11pqg.h"
#include "pk11pub.h"
#include "tls13esni.h"
+#include "tls13psk.h"
#include "tls13subcerts.h"
static const sslSocketOps ssl_default_ops = { /* No SSL. */
@@ -383,6 +384,12 @@ ssl_DupSocket(sslSocket *os)
goto loser;
}
}
+ if (os->psk) {
+ ss->psk = tls13_CopyPsk(os->psk);
+ if (!ss->psk) {
+ goto loser;
+ }
+ }
/* Create security data */
rv = ssl_CopySecurityInfo(ss, os);
@@ -469,9 +476,15 @@ ssl_DestroySocketContents(sslSocket *ss)
ssl_ClearPRCList(&ss->ssl3.hs.dtlsSentHandshake, NULL);
ssl_ClearPRCList(&ss->ssl3.hs.dtlsRcvdHandshake, NULL);
+ tls13_DestroyPskList(&ss->ssl3.hs.psks);
tls13_DestroyESNIKeys(ss->esniKeys);
tls13_ReleaseAntiReplayContext(ss->antiReplay);
+
+ if (ss->psk) {
+ tls13_DestroyPsk(ss->psk);
+ ss->psk = NULL;
+ }
}
/*
@@ -2468,6 +2481,8 @@ SSL_ReconfigFD(PRFileDesc *model, PRFileDesc *fd)
}
}
+ tls13_ResetHandshakePsks(sm, &ss->ssl3.hs.psks);
+
if (sm->authCertificate)
ss->authCertificate = sm->authCertificate;
if (sm->authCertificateArg)
@@ -4161,10 +4176,12 @@ ssl_NewSocket(PRBool makeLocks, SSLProtocolVariant protocolVariant)
ssl3_InitExtensionData(&ss->xtnData, ss);
PR_INIT_CLIST(&ss->ssl3.hs.dtlsSentHandshake);
PR_INIT_CLIST(&ss->ssl3.hs.dtlsRcvdHandshake);
+ PR_INIT_CLIST(&ss->ssl3.hs.psks);
dtls_InitTimers(ss);
ss->esniKeys = NULL;
ss->antiReplay = NULL;
+ ss->psk = NULL;
if (makeLocks) {
rv = ssl_MakeLocks(ss);
@@ -4231,6 +4248,8 @@ struct {
void *function;
} ssl_experimental_functions[] = {
#ifndef SSL_DISABLE_EXPERIMENTAL_API
+ EXP(AddExternalPsk),
+ EXP(AddExternalPsk0Rtt),
EXP(AeadDecrypt),
EXP(AeadEncrypt),
EXP(CipherSuiteOrderGet),
@@ -4261,6 +4280,7 @@ struct {
EXP(RecordLayerData),
EXP(RecordLayerWriteCallback),
EXP(ReleaseAntiReplayContext),
+ EXP(RemoveExternalPsk),
EXP(SecretCallback),
EXP(SendCertificateRequest),
EXP(SendSessionTicket),
diff --git a/lib/ssl/sslt.h b/lib/ssl/sslt.h
index 63dc06c8a..fdee19f89 100644
--- a/lib/ssl/sslt.h
+++ b/lib/ssl/sslt.h
@@ -184,6 +184,12 @@ typedef enum {
ssl_auth_size /* number of authentication types */
} SSLAuthType;
+typedef enum {
+ ssl_psk_none = 0,
+ ssl_psk_resume = 1,
+ ssl_psk_external = 2,
+} SSLPskType;
+
/* This is defined for backward compatibility reasons */
#define ssl_auth_rsa ssl_auth_rsa_decrypt
@@ -358,6 +364,10 @@ typedef struct SSLChannelInfoStr {
*/
PRBool peerDelegCred;
+ /* The following fields were added in NSS 3.54. */
+ /* Indicates what type of PSK, if any, was used in a handshake. */
+ SSLPskType pskType;
+
/* When adding new fields to this structure, please document the
* NSS version in which they were added. */
} SSLChannelInfo;
diff --git a/lib/ssl/tls13con.c b/lib/ssl/tls13con.c
index 6abbc0500..4ea30b639 100644
--- a/lib/ssl/tls13con.c
+++ b/lib/ssl/tls13con.c
@@ -25,6 +25,7 @@
#include "tls13exthandle.h"
#include "tls13hashstate.h"
#include "tls13subcerts.h"
+#include "tls13psk.h"
static SECStatus tls13_SetCipherSpec(sslSocket *ss, PRUint16 epoch,
SSLSecretDirection install,
@@ -65,14 +66,15 @@ tls13_DeriveSecret(sslSocket *ss, PK11SymKey *key,
const char *label,
unsigned int labelLen,
const SSL3Hashes *hashes,
- PK11SymKey **dest);
+ PK11SymKey **dest,
+ SSLHashType hash);
static SECStatus tls13_SendEndOfEarlyData(sslSocket *ss);
static SECStatus tls13_HandleEndOfEarlyData(sslSocket *ss, const PRUint8 *b,
PRUint32 length);
static SECStatus tls13_MaybeHandleSuppressedEndOfEarlyData(sslSocket *ss);
static SECStatus tls13_SendFinished(sslSocket *ss, PK11SymKey *baseKey);
static SECStatus tls13_ComputePskBinderHash(sslSocket *ss, unsigned int prefix,
- SSL3Hashes *hashes);
+ SSL3Hashes *hashes, SSLHashType type);
static SECStatus tls13_VerifyFinished(sslSocket *ss, SSLHandshakeType message,
PK11SymKey *secret,
PRUint8 *b, PRUint32 length,
@@ -86,14 +88,14 @@ static SECStatus tls13_SendNewSessionTicket(sslSocket *ss,
unsigned int appTokenLen);
static SECStatus tls13_HandleNewSessionTicket(sslSocket *ss, PRUint8 *b,
PRUint32 length);
-static SECStatus tls13_ComputeEarlySecrets(sslSocket *ss);
+static SECStatus tls13_ComputeEarlySecretsWithPsk(sslSocket *ss);
static SECStatus tls13_ComputeHandshakeSecrets(sslSocket *ss);
static SECStatus tls13_ComputeApplicationSecrets(sslSocket *ss);
static SECStatus tls13_ComputeFinalSecrets(sslSocket *ss);
static SECStatus tls13_ComputeFinished(
- sslSocket *ss, PK11SymKey *baseKey, const SSL3Hashes *hashes,
- PRBool sending, PRUint8 *output, unsigned int *outputLen,
- unsigned int maxOutputLen);
+ sslSocket *ss, PK11SymKey *baseKey, SSLHashType hashType,
+ const SSL3Hashes *hashes, PRBool sending, PRUint8 *output,
+ unsigned int *outputLen, unsigned int maxOutputLen);
static SECStatus tls13_SendClientSecondRound(sslSocket *ss);
static SECStatus tls13_SendClientSecondFlight(sslSocket *ss,
PRBool sendClientCert,
@@ -103,7 +105,8 @@ static SECStatus tls13_FinishHandshake(sslSocket *ss);
const char kHkdfLabelClient[] = "c";
const char kHkdfLabelServer[] = "s";
const char kHkdfLabelDerivedSecret[] = "derived";
-const char kHkdfLabelPskBinderKey[] = "res binder";
+const char kHkdfLabelResPskBinderKey[] = "res binder";
+const char kHkdfLabelExtPskBinderKey[] = "ext binder";
const char kHkdfLabelEarlyTrafficSecret[] = "e traffic";
const char kHkdfLabelEarlyExporterSecret[] = "e exp master";
const char kHkdfLabelHandshakeTrafficSecret[] = "hs traffic";
@@ -265,6 +268,10 @@ tls13_GetHashForCipherSuite(ssl3CipherSuite suite)
SSLHashType
tls13_GetHash(const sslSocket *ss)
{
+ /* suite_def may not be set yet when doing EPSK 0-Rtt. */
+ if (!ss->ssl3.hs.suite_def && ss->xtnData.selectedPsk) {
+ return ss->xtnData.selectedPsk->hash;
+ }
/* All TLS 1.3 cipher suites must have an explict PRF hash. */
PORT_Assert(ss->ssl3.hs.suite_def->prf_hash != ssl_hash_none);
return ss->ssl3.hs.suite_def->prf_hash;
@@ -319,9 +326,9 @@ tls13_GetHashSize(const sslSocket *ss)
}
static CK_MECHANISM_TYPE
-tls13_GetHmacMechanism(sslSocket *ss)
+tls13_GetHmacMechanismFromHash(SSLHashType hashType)
{
- switch (tls13_GetHash(ss)) {
+ switch (hashType) {
case ssl_hash_sha256:
return CKM_SHA256_HMAC;
case ssl_hash_sha384:
@@ -332,19 +339,25 @@ tls13_GetHmacMechanism(sslSocket *ss)
return CKM_SHA256_HMAC;
}
+static CK_MECHANISM_TYPE
+tls13_GetHmacMechanism(const sslSocket *ss)
+{
+ return tls13_GetHmacMechanismFromHash(tls13_GetHash(ss));
+}
+
SECStatus
tls13_ComputeHash(sslSocket *ss, SSL3Hashes *hashes,
- const PRUint8 *buf, unsigned int len)
+ const PRUint8 *buf, unsigned int len,
+ SSLHashType hash)
{
SECStatus rv;
- rv = PK11_HashBuf(ssl3_HashTypeToOID(tls13_GetHash(ss)),
- hashes->u.raw, buf, len);
+ rv = PK11_HashBuf(ssl3_HashTypeToOID(hash), hashes->u.raw, buf, len);
if (rv != SECSuccess) {
FATAL_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE, internal_error);
return SECFailure;
}
- hashes->len = tls13_GetHashSize(ss);
+ hashes->len = tls13_GetHashSizeForHash(hash);
return SECSuccess;
}
@@ -462,40 +475,50 @@ tls13_SetupClientHello(sslSocket *ss, sslClientHelloType chType)
return SECFailure;
}
- /* Below here checks if we can do stateless resumption. */
- if (sid->cached == never_cached ||
- sid->version < SSL_LIBRARY_VERSION_TLS_1_3) {
- return SECSuccess;
- }
+ /* Try to do stateless resumption, if we can. */
+ if (sid->cached != never_cached &&
+ sid->version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ /* The caller must be holding sid->u.ssl3.lock for reading. */
+ session_ticket = &sid->u.ssl3.locked.sessionTicket;
+ PORT_Assert(session_ticket && session_ticket->ticket.data);
- /* The caller must be holding sid->u.ssl3.lock for reading. */
- session_ticket = &sid->u.ssl3.locked.sessionTicket;
- PORT_Assert(session_ticket && session_ticket->ticket.data);
+ if (ssl_TicketTimeValid(ss, session_ticket)) {
+ ss->statelessResume = PR_TRUE;
+ }
- if (ssl_TicketTimeValid(ss, session_ticket)) {
- ss->statelessResume = PR_TRUE;
- }
+ if (ss->statelessResume) {
+ PORT_Assert(ss->sec.ci.sid);
+ rv = tls13_RecoverWrappedSharedSecret(ss, ss->sec.ci.sid);
+ if (rv != SECSuccess) {
+ FATAL_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE, internal_error);
+ SSL_AtomicIncrementLong(&ssl3stats->sch_sid_cache_not_ok);
+ ssl_UncacheSessionID(ss);
+ ssl_FreeSID(ss->sec.ci.sid);
+ ss->sec.ci.sid = NULL;
+ return SECFailure;
+ }
- if (ss->statelessResume) {
- PORT_Assert(ss->sec.ci.sid);
- rv = tls13_RecoverWrappedSharedSecret(ss, ss->sec.ci.sid);
- if (rv != SECSuccess) {
- FATAL_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE, internal_error);
- SSL_AtomicIncrementLong(&ssl3stats->sch_sid_cache_not_ok);
- ssl_UncacheSessionID(ss);
- ssl_FreeSID(ss->sec.ci.sid);
- ss->sec.ci.sid = NULL;
- return SECFailure;
+ ss->ssl3.hs.cipher_suite = ss->sec.ci.sid->u.ssl3.cipherSuite;
+ rv = ssl3_SetupCipherSuite(ss, PR_FALSE);
+ if (rv != SECSuccess) {
+ FATAL_ERROR(ss, PORT_GetError(), internal_error);
+ return SECFailure;
+ }
+ PORT_Assert(!PR_CLIST_IS_EMPTY(&ss->ssl3.hs.psks));
}
+ }
- ss->ssl3.hs.cipher_suite = ss->sec.ci.sid->u.ssl3.cipherSuite;
- rv = ssl3_SetupCipherSuite(ss, PR_FALSE);
- if (rv != SECSuccess) {
- FATAL_ERROR(ss, PORT_GetError(), internal_error);
- return SECFailure;
+ /* Derive the binder keys if any PSKs. */
+ if (!PR_CLIST_IS_EMPTY(&ss->ssl3.hs.psks)) {
+ /* If an External PSK specified a suite, use that. */
+ sslPsk *psk = (sslPsk *)PR_LIST_HEAD(&ss->ssl3.hs.psks);
+ if (!ss->statelessResume &&
+ psk->type == ssl_psk_external &&
+ psk->zeroRttSuite != TLS_NULL_WITH_NULL_NULL) {
+ ss->ssl3.hs.cipher_suite = psk->zeroRttSuite;
}
- rv = tls13_ComputeEarlySecrets(ss);
+ rv = tls13_ComputeEarlySecretsWithPsk(ss);
if (rv != SECSuccess) {
FATAL_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE, internal_error);
return SECFailure;
@@ -877,6 +900,11 @@ SSLExp_SendCertificateRequest(PRFileDesc *fd)
return SECFailure;
}
+ if (ss->ssl3.hs.kea_def->authKeyType == ssl_auth_psk) {
+ PORT_SetError(SSL_ERROR_FEATURE_DISABLED);
+ return SECFailure;
+ }
+
rv = TLS13_CHECK_HS_STATE(ss, SEC_ERROR_INVALID_ARGS,
idle_handshake);
if (rv != SECSuccess) {
@@ -990,20 +1018,34 @@ tls13_RecoverWrappedSharedSecret(sslSocket *ss, sslSessionID *sid)
wrappedMS.data = sid->u.ssl3.keys.wrapped_master_secret;
wrappedMS.len = sid->u.ssl3.keys.wrapped_master_secret_len;
- /* unwrap the "master secret" which is actually RMS. */
- ss->ssl3.hs.resumptionMasterSecret = ssl_unwrapSymKey(
- wrapKey, sid->u.ssl3.masterWrapMech,
- NULL, &wrappedMS,
- CKM_SSL3_MASTER_KEY_DERIVE,
- CKA_DERIVE,
- tls13_GetHashSizeForHash(hashType),
- CKF_SIGN | CKF_VERIFY, ss->pkcs11PinArg);
+ PK11SymKey *unwrappedPsk = ssl_unwrapSymKey(wrapKey, sid->u.ssl3.masterWrapMech,
+ NULL, &wrappedMS, CKM_SSL3_MASTER_KEY_DERIVE,
+ CKA_DERIVE, tls13_GetHashSizeForHash(hashType),
+ CKF_SIGN | CKF_VERIFY, ss->pkcs11PinArg);
PK11_FreeSymKey(wrapKey);
- if (!ss->ssl3.hs.resumptionMasterSecret) {
+ if (!unwrappedPsk) {
return SECFailure;
}
+ sslPsk *rpsk = tls13_MakePsk(unwrappedPsk, ssl_psk_resume, hashType, NULL);
+ if (!rpsk) {
+ PK11_FreeSymKey(unwrappedPsk);
+ return SECFailure;
+ }
+ if (sid->u.ssl3.locked.sessionTicket.flags & ticket_allow_early_data) {
+ rpsk->maxEarlyData = sid->u.ssl3.locked.sessionTicket.max_early_data_size;
+ rpsk->zeroRttSuite = sid->u.ssl3.cipherSuite;
+ }
+ PRINT_KEY(50, (ss, "Recovered RMS", rpsk->key));
+ PORT_Assert(PR_CLIST_IS_EMPTY(&ss->ssl3.hs.psks) ||
+ ((sslPsk *)PR_LIST_HEAD(&ss->ssl3.hs.psks))->type != ssl_psk_resume);
- PRINT_KEY(50, (ss, "Recovered RMS", ss->ssl3.hs.resumptionMasterSecret));
+ if (ss->sec.isServer) {
+ /* In server, we couldn't select the RPSK in the extension handler
+ * since it was not unwrapped yet. We're committed now, so select
+ * it and add it to the list (to ensure it is freed). */
+ ss->xtnData.selectedPsk = rpsk;
+ }
+ PR_APPEND_LINK(&rpsk->link, &ss->ssl3.hs.psks);
return SECSuccess;
}
@@ -1061,38 +1103,45 @@ tls13_RecoverWrappedSharedSecret(sslSocket *ss, sslSessionID *sid)
* = resumption_master_secret
*
*/
-
static SECStatus
-tls13_ComputeEarlySecrets(sslSocket *ss)
+tls13_ComputeEarlySecretsWithPsk(sslSocket *ss)
{
- SECStatus rv = SECSuccess;
+ SECStatus rv;
SSL_TRC(5, ("%d: TLS13[%d]: compute early secrets (%s)",
SSL_GETPID(), ss->fd, SSL_ROLE(ss)));
- /* Extract off the resumptionMasterSecret (if present), else pass the NULL
- * resumptionMasterSecret which will be internally translated to zeroes. */
PORT_Assert(!ss->ssl3.hs.currentSecret);
- rv = tls13_HkdfExtract(NULL, ss->ssl3.hs.resumptionMasterSecret,
- tls13_GetHash(ss), &ss->ssl3.hs.currentSecret);
+ sslPsk *psk = NULL;
+
+ if (ss->sec.isServer) {
+ psk = ss->xtnData.selectedPsk;
+ } else {
+ /* Client to use the first PSK for early secrets. */
+ PORT_Assert(!PR_CLIST_IS_EMPTY(&ss->ssl3.hs.psks));
+ psk = (sslPsk *)PR_LIST_HEAD(&ss->ssl3.hs.psks);
+ }
+ PORT_Assert(psk && psk->key);
+ PORT_Assert(psk->hash != ssl_hash_none);
+
+ PK11SymKey *earlySecret = NULL;
+ rv = tls13_HkdfExtract(NULL, psk->key, psk->hash, &earlySecret);
if (rv != SECSuccess) {
return SECFailure;
}
- PORT_Assert(ss->statelessResume == (ss->ssl3.hs.resumptionMasterSecret != NULL));
- if (ss->statelessResume) {
- PK11_FreeSymKey(ss->ssl3.hs.resumptionMasterSecret);
- ss->ssl3.hs.resumptionMasterSecret = NULL;
-
- rv = tls13_DeriveSecretNullHash(ss, ss->ssl3.hs.currentSecret,
- kHkdfLabelPskBinderKey,
- strlen(kHkdfLabelPskBinderKey),
- &ss->ssl3.hs.pskBinderKey);
- if (rv != SECSuccess) {
- return SECFailure;
- }
+ /* No longer need the raw input key */
+ PK11_FreeSymKey(psk->key);
+ psk->key = NULL;
+ const char *label = (psk->type == ssl_psk_resume) ? kHkdfLabelResPskBinderKey : kHkdfLabelExtPskBinderKey;
+ rv = tls13_DeriveSecretNullHash(ss, earlySecret,
+ label, strlen(label),
+ &psk->binderKey, psk->hash);
+ if (rv != SECSuccess) {
+ PK11_FreeSymKey(earlySecret);
+ return SECFailure;
}
- PORT_Assert(!ss->ssl3.hs.resumptionMasterSecret);
+ ss->ssl3.hs.currentSecret = earlySecret;
return SECSuccess;
}
@@ -1102,7 +1151,7 @@ static SECStatus
tls13_DeriveEarlySecrets(sslSocket *ss)
{
SECStatus rv;
-
+ PORT_Assert(ss->ssl3.hs.currentSecret);
rv = tls13_DeriveSecretWrap(ss, ss->ssl3.hs.currentSecret,
kHkdfLabelClient,
kHkdfLabelEarlyTrafficSecret,
@@ -1140,7 +1189,15 @@ tls13_ComputeHandshakeSecrets(sslSocket *ss)
SSL_TRC(5, ("%d: TLS13[%d]: compute handshake secrets (%s)",
SSL_GETPID(), ss->fd, SSL_ROLE(ss)));
- /* First update |currentSecret| to add |dheSecret|, if any. */
+ /* If no PSK, generate the default early secret. */
+ if (!ss->ssl3.hs.currentSecret) {
+ PORT_Assert(!ss->xtnData.selectedPsk);
+ rv = tls13_HkdfExtract(NULL, NULL,
+ tls13_GetHash(ss), &ss->ssl3.hs.currentSecret);
+ if (rv != SECSuccess) {
+ return SECFailure;
+ }
+ }
PORT_Assert(ss->ssl3.hs.currentSecret);
PORT_Assert(ss->ssl3.hs.dheSecret);
@@ -1148,7 +1205,7 @@ tls13_ComputeHandshakeSecrets(sslSocket *ss)
rv = tls13_DeriveSecretNullHash(ss, ss->ssl3.hs.currentSecret,
kHkdfLabelDerivedSecret,
strlen(kHkdfLabelDerivedSecret),
- &derivedSecret);
+ &derivedSecret, tls13_GetHash(ss));
if (rv != SECSuccess) {
LOG_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE);
return rv;
@@ -1207,7 +1264,7 @@ tls13_ComputeHandshakeSecrets(sslSocket *ss)
rv = tls13_DeriveSecretNullHash(ss, ss->ssl3.hs.currentSecret,
kHkdfLabelDerivedSecret,
strlen(kHkdfLabelDerivedSecret),
- &derivedSecret);
+ &derivedSecret, tls13_GetHash(ss));
if (rv != SECSuccess) {
LOG_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE);
return rv;
@@ -1279,7 +1336,7 @@ tls13_ComputeFinalSecrets(sslSocket *ss)
PORT_Assert(!ss->ssl3.crSpec->masterSecret);
PORT_Assert(!ss->ssl3.cwSpec->masterSecret);
-
+ PORT_Assert(ss->ssl3.hs.currentSecret);
rv = tls13_DeriveSecretWrap(ss, ss->ssl3.hs.currentSecret,
NULL, kHkdfLabelResumptionMasterSecret,
NULL,
@@ -1340,21 +1397,40 @@ static PRBool
tls13_CanNegotiateZeroRtt(sslSocket *ss, const sslSessionID *sid)
{
PORT_Assert(ss->ssl3.hs.zeroRttState == ssl_0rtt_sent);
+ sslPsk *psk = ss->xtnData.selectedPsk;
- if (!sid)
+ if (!ss->opt.enable0RttData) {
return PR_FALSE;
- PORT_Assert(ss->statelessResume);
- if (!ss->statelessResume)
+ }
+ if (!psk) {
return PR_FALSE;
- if (ss->ssl3.hs.cipher_suite != sid->u.ssl3.cipherSuite)
+ }
+ if (psk->zeroRttSuite == TLS_NULL_WITH_NULL_NULL) {
return PR_FALSE;
- if (!ss->opt.enable0RttData)
+ }
+ if (!psk->maxEarlyData) {
return PR_FALSE;
- if (!(sid->u.ssl3.locked.sessionTicket.flags & ticket_allow_early_data))
+ }
+ if (ss->ssl3.hs.cipher_suite != psk->zeroRttSuite) {
return PR_FALSE;
- if (SECITEM_CompareItem(&ss->xtnData.nextProto,
- &sid->u.ssl3.alpnSelection) != 0)
+ }
+ if (psk->type == ssl_psk_resume) {
+ if (!sid) {
+ return PR_FALSE;
+ }
+ PORT_Assert(sid->u.ssl3.locked.sessionTicket.flags & ticket_allow_early_data);
+ PORT_Assert(ss->statelessResume);
+ if (!ss->statelessResume) {
+ return PR_FALSE;
+ }
+ if (SECITEM_CompareItem(&ss->xtnData.nextProto,
+ &sid->u.ssl3.alpnSelection) != 0) {
+ return PR_FALSE;
+ }
+ } else if (psk->type != ssl_psk_external) {
+ PORT_Assert(0);
return PR_FALSE;
+ }
if (tls13_IsReplay(ss, sid)) {
return PR_FALSE;
@@ -1407,7 +1483,7 @@ tls13_NegotiateZeroRtt(sslSocket *ss, const sslSessionID *sid)
}
SSL_TRC(3, ("%d: TLS13[%d]: enable 0-RTT", SSL_GETPID(), ss->fd));
- PORT_Assert(ss->statelessResume);
+ PORT_Assert(ss->xtnData.selectedPsk);
ss->ssl3.hs.zeroRttState = ssl_0rtt_accepted;
ss->ssl3.hs.zeroRttIgnore = ssl_0rtt_ignore_none;
ss->ssl3.hs.zeroRttSuite = ss->ssl3.hs.cipher_suite;
@@ -1459,7 +1535,7 @@ tls13_NegotiateKeyExchange(sslSocket *ss,
const sslNamedGroupDef *preferredGroup = NULL;
/* We insist on DHE. */
- if (ss->statelessResume) {
+ if (ssl3_ExtensionNegotiated(ss, ssl_tls13_pre_shared_key_xtn)) {
if (!ssl3_ExtensionNegotiated(ss, ssl_tls13_psk_key_exchange_modes_xtn)) {
FATAL_ERROR(ss, SSL_ERROR_MISSING_PSK_KEY_EXCHANGE_MODES,
missing_extension);
@@ -1483,8 +1559,8 @@ tls13_NegotiateKeyExchange(sslSocket *ss,
return SECFailure;
}
- SSL_TRC(3, ("%d: TLS13[%d]: selected KE = %s",
- SSL_GETPID(), ss->fd, ss->statelessResume ? "PSK + (EC)DHE" : "(EC)DHE"));
+ SSL_TRC(3, ("%d: TLS13[%d]: selected KE = %s", SSL_GETPID(),
+ ss->fd, ss->statelessResume || ss->xtnData.selectedPsk ? "PSK + (EC)DHE" : "(EC)DHE"));
/* Find the preferred group and an according client key share available. */
for (index = 0; index < SSL_NAMED_GROUP_COUNT; ++index) {
@@ -1672,26 +1748,42 @@ tls13_MaybeSendHelloRetry(sslSocket *ss, const sslNamedGroupDef *requestedGroup,
static SECStatus
tls13_NegotiateAuthentication(sslSocket *ss)
{
- SECStatus rv;
-
if (ss->statelessResume) {
- SSL_TRC(3, ("%d: TLS13[%d]: selected PSK authentication",
+ SSL_TRC(3, ("%d: TLS13[%d]: selected resumption PSK authentication",
SSL_GETPID(), ss->fd));
ss->ssl3.hs.signatureScheme = ssl_sig_none;
ss->ssl3.hs.kea_def_mutable.authKeyType = ssl_auth_psk;
+ /* Overwritten by tls13_RestoreCipherInfo. */
+ ss->sec.authType = ssl_auth_psk;
return SECSuccess;
+ } else if (ss->xtnData.selectedPsk) {
+ /* If the EPSK doesn't specify a suite, use what was negotiated.
+ * Else, only use the EPSK if we negotiated that suite. */
+ if (ss->xtnData.selectedPsk->zeroRttSuite == TLS_NULL_WITH_NULL_NULL ||
+ ss->ssl3.hs.cipher_suite == ss->xtnData.selectedPsk->zeroRttSuite) {
+ SSL_TRC(3, ("%d: TLS13[%d]: selected external PSK authentication",
+ SSL_GETPID(), ss->fd));
+ ss->ssl3.hs.signatureScheme = ssl_sig_none;
+ ss->ssl3.hs.kea_def_mutable.authKeyType = ssl_auth_psk;
+ ss->sec.authType = ssl_auth_psk;
+ return SECSuccess;
+ }
+ }
+
+ /* If there were PSKs, they are no longer needed. */
+ if (ss->xtnData.selectedPsk) {
+ tls13_DestroyPskList(&ss->ssl3.hs.psks);
+ ss->xtnData.selectedPsk = NULL;
}
SSL_TRC(3, ("%d: TLS13[%d]: selected certificate authentication",
SSL_GETPID(), ss->fd));
- /* We've now established that we need to sign.... */
- rv = tls13_SelectServerCert(ss);
+ SECStatus rv = tls13_SelectServerCert(ss);
if (rv != SECSuccess) {
return SECFailure;
}
return SECSuccess;
}
-
/* Called from ssl3_HandleClientHello after we have parsed the
* ClientHello and are sure that we are going to do TLS 1.3
* or fail. */
@@ -1855,40 +1947,51 @@ tls13_HandleClientHelloPart2(sslSocket *ss,
goto loser;
}
- if (ss->statelessResume) {
- /* We are now committed to trying to resume. */
- PORT_Assert(sid);
-
- /* Check that the negotiated SNI and the cached SNI match. */
- if (SECITEM_CompareItem(&sid->u.ssl3.srvName,
- &ss->ssl3.hs.srvVirtName) != SECEqual) {
- FATAL_ERROR(ss, SSL_ERROR_RX_MALFORMED_CLIENT_HELLO,
- handshake_failure);
- goto loser;
- }
+ if (ss->sec.authType == ssl_auth_psk) {
+ if (ss->statelessResume) {
+ /* We are now committed to trying to resume. */
+ PORT_Assert(sid);
+ /* Check that the negotiated SNI and the cached SNI match. */
+ if (SECITEM_CompareItem(&sid->u.ssl3.srvName,
+ &ss->ssl3.hs.srvVirtName) != SECEqual) {
+ FATAL_ERROR(ss, SSL_ERROR_RX_MALFORMED_CLIENT_HELLO,
+ handshake_failure);
+ goto loser;
+ }
- ss->sec.serverCert = ssl_FindServerCert(ss, sid->authType,
- sid->namedCurve);
- PORT_Assert(ss->sec.serverCert);
+ ss->sec.serverCert = ssl_FindServerCert(ss, sid->authType,
+ sid->namedCurve);
+ PORT_Assert(ss->sec.serverCert);
- rv = tls13_RecoverWrappedSharedSecret(ss, sid);
- if (rv != SECSuccess) {
- SSL_AtomicIncrementLong(&ssl3stats->hch_sid_cache_not_ok);
- FATAL_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE, internal_error);
- goto loser;
- }
- tls13_RestoreCipherInfo(ss, sid);
+ rv = tls13_RecoverWrappedSharedSecret(ss, sid);
+ if (rv != SECSuccess) {
+ SSL_AtomicIncrementLong(&ssl3stats->hch_sid_cache_not_ok);
+ FATAL_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE, internal_error);
+ goto loser;
+ }
+ tls13_RestoreCipherInfo(ss, sid);
- ss->sec.localCert = CERT_DupCertificate(ss->sec.serverCert->serverCert);
- if (sid->peerCert != NULL) {
- ss->sec.peerCert = CERT_DupCertificate(sid->peerCert);
+ ss->sec.localCert = CERT_DupCertificate(ss->sec.serverCert->serverCert);
+ if (sid->peerCert != NULL) {
+ ss->sec.peerCert = CERT_DupCertificate(sid->peerCert);
+ }
+ } else if (sid) {
+ /* We should never have a SID in the non-resumption case. */
+ PORT_Assert(0);
+ ssl_UncacheSessionID(ss);
+ ssl_FreeSID(sid);
+ sid = NULL;
}
-
ssl3_RegisterExtensionSender(
ss, &ss->xtnData,
ssl_tls13_pre_shared_key_xtn, tls13_ServerSendPreSharedKeyXtn);
-
tls13_NegotiateZeroRtt(ss, sid);
+
+ rv = tls13_ComputeEarlySecretsWithPsk(ss);
+ if (rv != SECSuccess) {
+ FATAL_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE, internal_error);
+ return SECFailure;
+ }
} else {
if (sid) { /* we had a sid, but it's no longer valid, free it */
SSL_AtomicIncrementLong(&ssl3stats->hch_sid_cache_not_ok);
@@ -1899,35 +2002,34 @@ tls13_HandleClientHelloPart2(sslSocket *ss,
tls13_NegotiateZeroRtt(ss, NULL);
}
- /* Need to compute early secrets. */
- rv = tls13_ComputeEarlySecrets(ss);
- if (rv != SECSuccess) {
- FATAL_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE, internal_error);
- return SECFailure;
+ if (ss->statelessResume) {
+ PORT_Assert(ss->xtnData.selectedPsk);
+ PORT_Assert(ss->ssl3.hs.kea_def_mutable.authKeyType = ssl_auth_psk);
}
- /* Now that we have the binder key check the binder. */
- if (ss->statelessResume) {
+ /* Now that we have the binder key, check the binder. */
+ if (ss->xtnData.selectedPsk) {
SSL3Hashes hashes;
-
PORT_Assert(ss->ssl3.hs.messages.len > ss->xtnData.pskBindersLen);
rv = tls13_ComputePskBinderHash(
ss,
ss->ssl3.hs.messages.len - ss->xtnData.pskBindersLen,
- &hashes);
+ &hashes, tls13_GetHash(ss));
if (rv != SECSuccess) {
FATAL_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE, internal_error);
goto loser;
}
+ PORT_Assert(ss->xtnData.selectedPsk->hash == tls13_GetHash(ss));
+ PORT_Assert(ss->ssl3.hs.suite_def);
rv = tls13_VerifyFinished(ss, ssl_hs_client_hello,
- ss->ssl3.hs.pskBinderKey,
+ ss->xtnData.selectedPsk->binderKey,
ss->xtnData.pskBinder.data,
ss->xtnData.pskBinder.len,
&hashes);
- if (rv != SECSuccess) {
- goto loser;
- }
+ }
+ if (rv != SECSuccess) {
+ goto loser;
}
/* This needs to go after we verify the psk binder. */
@@ -1954,7 +2056,7 @@ tls13_HandleClientHelloPart2(sslSocket *ss,
SSL_AtomicIncrementLong(&ssl3stats->hch_sid_cache_not_ok);
ssl_UncacheSessionID(ss);
ssl_FreeSID(sid);
- } else {
+ } else if (!ss->xtnData.selectedPsk) {
SSL_AtomicIncrementLong(&ssl3stats->hch_sid_cache_misses);
}
@@ -1984,6 +2086,10 @@ tls13_HandleClientHelloPart2(sslSocket *ss,
return SECFailure;
}
+ /* We're done with PSKs */
+ tls13_DestroyPskList(&ss->ssl3.hs.psks);
+ ss->xtnData.selectedPsk = NULL;
+
return SECSuccess;
loser:
@@ -2315,7 +2421,8 @@ tls13_ReinjectHandshakeTranscript(sslSocket *ss)
// First compute the hash.
rv = tls13_ComputeHash(ss, &hashes,
ss->ssl3.hs.messages.buf,
- ss->ssl3.hs.messages.len);
+ ss->ssl3.hs.messages.len,
+ tls13_GetHash(ss));
if (rv != SECSuccess) {
return SECFailure;
}
@@ -2332,7 +2439,6 @@ tls13_ReinjectHandshakeTranscript(sslSocket *ss)
return SECSuccess;
}
-
static unsigned int
ssl_ListCount(PRCList *list)
{
@@ -2455,6 +2561,12 @@ tls13_HandleCertificateRequest(sslSocket *ss, PRUint8 *b, PRUint32 length)
return SECFailure;
}
+ /* MUST NOT combine external PSKs with certificate authentication. */
+ if (ss->ssl3.hs.kea_def->authKeyType == ssl_auth_psk) {
+ FATAL_ERROR(ss, SSL_ERROR_RX_UNEXPECTED_CERT_REQUEST, unexpected_message);
+ return SECFailure;
+ }
+
if (tls13_IsPostHandshake(ss)) {
PORT_Assert(ss->ssl3.hs.shaPostHandshake == NULL);
ss->ssl3.hs.shaPostHandshake = PK11_CloneContext(ss->ssl3.hs.sha);
@@ -2750,14 +2862,22 @@ tls13_HandleServerHelloPart2(sslSocket *ss)
SSL3Statistics *ssl3stats = SSL_GetStatistics();
if (ssl3_ExtensionNegotiated(ss, ssl_tls13_pre_shared_key_xtn)) {
- PORT_Assert(ss->statelessResume);
+ PORT_Assert(!PR_CLIST_IS_EMPTY(&ss->ssl3.hs.psks));
+ PORT_Assert(ss->xtnData.selectedPsk);
+
+ if (ss->xtnData.selectedPsk->type != ssl_psk_resume) {
+ ss->statelessResume = PR_FALSE;
+ }
} else {
+ /* We may have offered a PSK. If the server didn't negotiate
+ * it, clear this state to re-extract the Early Secret. */
if (ss->ssl3.hs.currentSecret) {
- PORT_Assert(ss->statelessResume);
+ PORT_Assert(ssl3_ExtensionAdvertised(ss, ssl_tls13_pre_shared_key_xtn));
PK11_FreeSymKey(ss->ssl3.hs.currentSecret);
ss->ssl3.hs.currentSecret = NULL;
}
ss->statelessResume = PR_FALSE;
+ ss->xtnData.selectedPsk = NULL;
}
if (ss->statelessResume) {
@@ -2774,19 +2894,22 @@ tls13_HandleServerHelloPart2(sslSocket *ss)
ss->ssl3.hs.kea_def_mutable = *ss->ssl3.hs.kea_def;
ss->ssl3.hs.kea_def = &ss->ssl3.hs.kea_def_mutable;
- if (ss->statelessResume) {
- /* PSK */
+ if (ss->xtnData.selectedPsk) {
ss->ssl3.hs.kea_def_mutable.authKeyType = ssl_auth_psk;
- tls13_RestoreCipherInfo(ss, sid);
- if (sid->peerCert) {
- ss->sec.peerCert = CERT_DupCertificate(sid->peerCert);
- }
+ if (ss->statelessResume) {
+ tls13_RestoreCipherInfo(ss, sid);
+ if (sid->peerCert) {
+ ss->sec.peerCert = CERT_DupCertificate(sid->peerCert);
+ }
- SSL_AtomicIncrementLong(&ssl3stats->hsh_sid_cache_hits);
- SSL_AtomicIncrementLong(&ssl3stats->hsh_sid_stateless_resumes);
+ SSL_AtomicIncrementLong(&ssl3stats->hsh_sid_cache_hits);
+ SSL_AtomicIncrementLong(&ssl3stats->hsh_sid_stateless_resumes);
+ } else {
+ ss->sec.authType = ssl_auth_psk;
+ }
} else {
- /* !PSK */
- if (ssl3_ExtensionAdvertised(ss, ssl_tls13_pre_shared_key_xtn)) {
+ if (ss->statelessResume &&
+ ssl3_ExtensionAdvertised(ss, ssl_tls13_pre_shared_key_xtn)) {
SSL_AtomicIncrementLong(&ssl3stats->hsh_sid_cache_misses);
}
if (sid->cached == in_client_cache) {
@@ -2795,18 +2918,6 @@ tls13_HandleServerHelloPart2(sslSocket *ss)
}
}
- if (!ss->ssl3.hs.currentSecret) {
- PORT_Assert(!ss->statelessResume);
-
- /* If we don't already have the Early Secret we need to make it
- * now. */
- rv = tls13_ComputeEarlySecrets(ss);
- if (rv != SECSuccess) {
- FATAL_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE, internal_error);
- return SECFailure;
- }
- }
-
/* Discard current SID and make a new one, though it may eventually
* end up looking a lot like the old one.
*/
@@ -3310,15 +3421,14 @@ tls13_DeriveSecret(sslSocket *ss, PK11SymKey *key,
const char *label,
unsigned int labelLen,
const SSL3Hashes *hashes,
- PK11SymKey **dest)
+ PK11SymKey **dest,
+ SSLHashType hash)
{
SECStatus rv;
- rv = tls13_HkdfExpandLabel(key, tls13_GetHash(ss),
- hashes->u.raw, hashes->len,
- label, labelLen,
- CKM_HKDF_DERIVE,
- tls13_GetHashSize(ss),
+ rv = tls13_HkdfExpandLabel(key, hash, hashes->u.raw, hashes->len,
+ label, labelLen, CKM_HKDF_DERIVE,
+ tls13_GetHashSizeForHash(hash),
ss->protocolVariant, dest);
if (rv != SECSuccess) {
LOG_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE);
@@ -3332,18 +3442,19 @@ SECStatus
tls13_DeriveSecretNullHash(sslSocket *ss, PK11SymKey *key,
const char *label,
unsigned int labelLen,
- PK11SymKey **dest)
+ PK11SymKey **dest,
+ SSLHashType hash)
{
SSL3Hashes hashes;
SECStatus rv;
PRUint8 buf[] = { 0 };
- rv = tls13_ComputeHash(ss, &hashes, buf, 0);
+ rv = tls13_ComputeHash(ss, &hashes, buf, 0, hash);
if (rv != SECSuccess) {
return SECFailure;
}
- return tls13_DeriveSecret(ss, key, label, labelLen, &hashes, dest);
+ return tls13_DeriveSecret(ss, key, label, labelLen, &hashes, dest, hash);
}
/* Convenience wrapper that lets us supply a separate prefix and suffix. */
@@ -3382,7 +3493,7 @@ tls13_DeriveSecretWrap(sslSocket *ss, PK11SymKey *key,
}
rv = tls13_DeriveSecret(ss, key, label, strlen(label),
- &hashes, dest);
+ &hashes, dest, tls13_GetHash(ss));
if (rv != SECSuccess) {
return SECFailure;
}
@@ -3546,8 +3657,10 @@ tls13_SetupPendingCipherSpec(sslSocket *ss, ssl3CipherSpec *spec)
spec->cipherDef = ssl_GetBulkCipherDef(ssl_LookupCipherSuiteDef(suite));
if (spec->epoch == TrafficKeyEarlyApplicationData) {
- spec->earlyDataRemaining =
- ss->sec.ci.sid->u.ssl3.locked.sessionTicket.max_early_data_size;
+ if (ss->xtnData.selectedPsk &&
+ ss->xtnData.selectedPsk->zeroRttSuite != TLS_NULL_WITH_NULL_NULL) {
+ spec->earlyDataRemaining = ss->xtnData.selectedPsk->maxEarlyData;
+ }
}
tls13_SetSpecRecordVersion(ss, spec);
@@ -4013,7 +4126,7 @@ tls13_HandleEncryptedExtensions(sslSocket *ss, PRUint8 *b, PRUint32 length)
/* We can only get here if we offered 0-RTT. */
if (ssl3_ExtensionNegotiated(ss, ssl_tls13_early_data_xtn)) {
PORT_Assert(ss->ssl3.hs.zeroRttState == ssl_0rtt_sent);
- if (!ss->statelessResume) {
+ if (!ss->xtnData.selectedPsk) {
/* Illegal to accept 0-RTT without also accepting PSK. */
FATAL_ERROR(ss, SSL_ERROR_RX_MALFORMED_ENCRYPTED_EXTENSIONS,
illegal_parameter);
@@ -4051,6 +4164,10 @@ tls13_HandleEncryptedExtensions(sslSocket *ss, PRUint8 *b, PRUint32 length)
TLS13_SET_HS_STATE(ss, wait_cert_request);
}
+ /* Client is done with any PSKs */
+ tls13_DestroyPskList(&ss->ssl3.hs.psks);
+ ss->xtnData.selectedPsk = NULL;
+
return SECSuccess;
}
@@ -4330,7 +4447,7 @@ loser:
static SECStatus
tls13_ComputePskBinderHash(sslSocket *ss, unsigned int prefixLength,
- SSL3Hashes *hashes)
+ SSL3Hashes *hashes, SSLHashType hashType)
{
SECStatus rv;
@@ -4340,14 +4457,14 @@ tls13_ComputePskBinderHash(sslSocket *ss, unsigned int prefixLength,
PRINT_BUF(10, (NULL, "Handshake hash computed over ClientHello prefix",
ss->ssl3.hs.messages.buf, prefixLength));
- rv = PK11_HashBuf(ssl3_HashTypeToOID(tls13_GetHash(ss)),
+ rv = PK11_HashBuf(ssl3_HashTypeToOID(hashType),
hashes->u.raw, ss->ssl3.hs.messages.buf, prefixLength);
if (rv != SECSuccess) {
ssl_MapLowLevelError(SSL_ERROR_SHA_DIGEST_FAILURE);
return SECFailure;
}
- hashes->len = tls13_GetHashSize(ss);
+ hashes->len = tls13_GetHashSizeForHash(hashType);
PRINT_BUF(10, (NULL, "PSK Binder hash", hashes->u.raw, hashes->len));
return SECSuccess;
@@ -4365,7 +4482,10 @@ tls13_WriteExtensionsWithBinder(sslSocket *ss, sslBuffer *extensions)
{
SSL3Hashes hashes;
SECStatus rv;
- unsigned int size = tls13_GetHashSize(ss);
+
+ PORT_Assert(!PR_CLIST_IS_EMPTY(&ss->ssl3.hs.psks));
+ sslPsk *psk = (sslPsk *)PR_LIST_HEAD(&ss->ssl3.hs.psks);
+ unsigned int size = tls13_GetHashSizeForHash(psk->hash);
unsigned int prefixLen = extensions->len - size - 3;
unsigned int finishedLen;
@@ -4386,15 +4506,18 @@ tls13_WriteExtensionsWithBinder(sslSocket *ss, sslBuffer *extensions)
}
/* Calculate the binder based on what has been written out. */
- rv = tls13_ComputePskBinderHash(ss, ss->ssl3.hs.messages.len, &hashes);
+ rv = tls13_ComputePskBinderHash(ss, ss->ssl3.hs.messages.len,
+ &hashes, psk->hash);
if (rv != SECSuccess) {
return SECFailure;
}
/* Write the binder into the extensions buffer, over the zeros we reserved
- * previously. This avoids an allocation and means that we don't need a
+ * previously. This avoids an allocation and means that we don't need a
* separate write for the extra bits that precede the binder. */
- rv = tls13_ComputeFinished(ss, ss->ssl3.hs.pskBinderKey, &hashes, PR_TRUE,
+ PORT_Assert(psk->binderKey);
+ rv = tls13_ComputeFinished(ss, psk->binderKey,
+ psk->hash, &hashes, PR_TRUE,
extensions->buf + extensions->len - size,
&finishedLen, size);
if (rv != SECSuccess) {
@@ -4414,13 +4537,13 @@ tls13_WriteExtensionsWithBinder(sslSocket *ss, sslBuffer *extensions)
static SECStatus
tls13_ComputeFinished(sslSocket *ss, PK11SymKey *baseKey,
- const SSL3Hashes *hashes,
+ SSLHashType hashType, const SSL3Hashes *hashes,
PRBool sending, PRUint8 *output, unsigned int *outputLen,
unsigned int maxOutputLen)
{
SECStatus rv;
PK11Context *hmacCtx = NULL;
- CK_MECHANISM_TYPE macAlg = tls13_GetHmacMechanism(ss);
+ CK_MECHANISM_TYPE macAlg = tls13_GetHmacMechanismFromHash(hashType);
SECItem param = { siBuffer, NULL, 0 };
unsigned int outputLenUint;
const char *label = kHkdfLabelFinishedSecret;
@@ -4432,18 +4555,16 @@ tls13_ComputeFinished(sslSocket *ss, PK11SymKey *baseKey,
PRINT_BUF(50, (ss, "Handshake hash", hashes->u.raw, hashes->len));
/* Now derive the appropriate finished secret from the base secret. */
- rv = tls13_HkdfExpandLabel(baseKey,
- tls13_GetHash(ss),
- NULL, 0,
- label, strlen(label),
- tls13_GetHmacMechanism(ss),
- tls13_GetHashSize(ss),
+ rv = tls13_HkdfExpandLabel(baseKey, hashType,
+ NULL, 0, label, strlen(label),
+ tls13_GetHmacMechanismFromHash(hashType),
+ tls13_GetHashSizeForHash(hashType),
ss->protocolVariant, &secret);
if (rv != SECSuccess) {
goto abort;
}
- PORT_Assert(hashes->len == tls13_GetHashSize(ss));
+ PORT_Assert(hashes->len == tls13_GetHashSizeForHash(hashType));
hmacCtx = PK11_CreateContextBySymKey(macAlg, CKA_SIGN,
secret, &param);
if (!hmacCtx) {
@@ -4458,7 +4579,7 @@ tls13_ComputeFinished(sslSocket *ss, PK11SymKey *baseKey,
if (rv != SECSuccess)
goto abort;
- PORT_Assert(maxOutputLen >= tls13_GetHashSize(ss));
+ PORT_Assert(maxOutputLen >= tls13_GetHashSizeForHash(hashType));
rv = PK11_DigestFinal(hmacCtx, output, &outputLenUint, maxOutputLen);
if (rv != SECSuccess)
goto abort;
@@ -4502,7 +4623,7 @@ tls13_SendFinished(sslSocket *ss, PK11SymKey *baseKey)
}
ssl_GetSpecReadLock(ss);
- rv = tls13_ComputeFinished(ss, baseKey, &hashes, PR_TRUE,
+ rv = tls13_ComputeFinished(ss, baseKey, tls13_GetHash(ss), &hashes, PR_TRUE,
finishedBuf, &finishedLen, sizeof(finishedBuf));
ssl_ReleaseSpecReadLock(ss);
if (rv != SECSuccess) {
@@ -4539,7 +4660,7 @@ tls13_VerifyFinished(sslSocket *ss, SSLHandshakeType message,
return SECFailure;
}
- rv = tls13_ComputeFinished(ss, secret, hashes, PR_FALSE,
+ rv = tls13_ComputeFinished(ss, secret, tls13_GetHash(ss), hashes, PR_FALSE,
finishedBuf, &finishedLen, sizeof(finishedBuf));
if (rv != SECSuccess) {
FATAL_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE, internal_error);
@@ -4701,7 +4822,8 @@ tls13_ServerHandleFinished(sslSocket *ss, PRUint8 *b, PRUint32 length)
}
ssl_GetXmitBufLock(ss);
- if (ss->opt.enableSessionTickets) {
+ /* If resumption, authType is the original value and not ssl_auth_psk. */
+ if (ss->opt.enableSessionTickets && ss->sec.authType != ssl_auth_psk) {
rv = tls13_SendNewSessionTicket(ss, NULL, 0);
if (rv != SECSuccess) {
goto loser;
@@ -5084,6 +5206,11 @@ SSLExp_SendSessionTicket(PRFileDesc *fd, const PRUint8 *token,
return SECFailure;
}
+ if (ss->ssl3.hs.kea_def->authKeyType == ssl_auth_psk) {
+ PORT_SetError(SSL_ERROR_FEATURE_DISABLED);
+ return SECFailure;
+ }
+
ssl_GetSSL3HandshakeLock(ss);
ssl_GetXmitBufLock(ss);
rv = tls13_SendNewSessionTicket(ss, token, tokenLen);
@@ -5618,9 +5745,10 @@ tls13_UnprotectRecord(sslSocket *ss,
* 1. We are doing TLS 1.3
* 2. This isn't a second ClientHello (in response to HelloRetryRequest)
* 3. The 0-RTT option is set.
- * 4. We have a valid ticket.
- * 5. The server is willing to accept 0-RTT.
- * 6. We have not changed our ALPN settings to disallow the ALPN tag
+ * 4. We have a valid ticket or an External PSK.
+ * 5. If resuming:
+ * 5a. The server is willing to accept 0-RTT.
+ * 5b. We have not changed our ALPN settings to disallow the ALPN tag
* in the ticket.
*
* Called from tls13_ClientSendEarlyDataXtn().
@@ -5630,17 +5758,39 @@ tls13_ClientAllow0Rtt(const sslSocket *ss, const sslSessionID *sid)
{
/* We checked that the cipher suite was still allowed back in
* ssl3_SendClientHello. */
- if (sid->version < SSL_LIBRARY_VERSION_TLS_1_3)
+ if (sid->version < SSL_LIBRARY_VERSION_TLS_1_3) {
+ return PR_FALSE;
+ }
+ if (ss->ssl3.hs.helloRetry) {
return PR_FALSE;
- if (ss->ssl3.hs.helloRetry)
+ }
+ if (!ss->opt.enable0RttData) {
return PR_FALSE;
- if (!ss->opt.enable0RttData)
+ }
+ if (PR_CLIST_IS_EMPTY(&ss->ssl3.hs.psks)) {
return PR_FALSE;
- if (!ss->statelessResume)
+ }
+ sslPsk *psk = (sslPsk *)PR_LIST_HEAD(&ss->ssl3.hs.psks);
+
+ if (psk->zeroRttSuite == TLS_NULL_WITH_NULL_NULL) {
return PR_FALSE;
- if ((sid->u.ssl3.locked.sessionTicket.flags & ticket_allow_early_data) == 0)
+ }
+ if (!psk->maxEarlyData) {
return PR_FALSE;
- return ssl_AlpnTagAllowed(ss, &sid->u.ssl3.alpnSelection);
+ }
+
+ if (psk->type == ssl_psk_external) {
+ return psk->hash == tls13_GetHashForCipherSuite(psk->zeroRttSuite);
+ }
+ if (psk->type == ssl_psk_resume) {
+ if (!ss->statelessResume)
+ return PR_FALSE;
+ if ((sid->u.ssl3.locked.sessionTicket.flags & ticket_allow_early_data) == 0)
+ return PR_FALSE;
+ return ssl_AlpnTagAllowed(ss, &sid->u.ssl3.alpnSelection);
+ }
+ PORT_Assert(0);
+ return PR_FALSE;
}
SECStatus
@@ -5687,6 +5837,9 @@ tls13_MaybeDo0RTTHandshake(sslSocket *ss)
}
}
+ /* If we're trying 0-RTT, derive from the first PSK */
+ PORT_Assert(!PR_CLIST_IS_EMPTY(&ss->ssl3.hs.psks) && !ss->xtnData.selectedPsk);
+ ss->xtnData.selectedPsk = (sslPsk *)PR_LIST_HEAD(&ss->ssl3.hs.psks);
rv = tls13_DeriveEarlySecrets(ss);
if (rv != SECSuccess) {
return SECFailure;
@@ -5698,6 +5851,7 @@ tls13_MaybeDo0RTTHandshake(sslSocket *ss)
rv = tls13_SetCipherSpec(ss, TrafficKeyEarlyApplicationData,
ssl_secret_write, PR_TRUE);
+ ss->xtnData.selectedPsk = NULL;
if (rv != SECSuccess) {
return SECFailure;
}
diff --git a/lib/ssl/tls13con.h b/lib/ssl/tls13con.h
index dd693b377..9a3cd14c1 100644
--- a/lib/ssl/tls13con.h
+++ b/lib/ssl/tls13con.h
@@ -51,13 +51,15 @@ SSLHashType tls13_GetHashForCipherSuite(ssl3CipherSuite suite);
unsigned int tls13_GetHashSize(const sslSocket *ss);
unsigned int tls13_GetHashSizeForHash(SSLHashType hash);
SECStatus tls13_ComputeHash(sslSocket *ss, SSL3Hashes *hashes,
- const PRUint8 *buf, unsigned int len);
+ const PRUint8 *buf, unsigned int len,
+ SSLHashType hash);
SECStatus tls13_ComputeHandshakeHashes(sslSocket *ss,
SSL3Hashes *hashes);
SECStatus tls13_DeriveSecretNullHash(sslSocket *ss, PK11SymKey *key,
const char *label,
unsigned int labelLen,
- PK11SymKey **dest);
+ PK11SymKey **dest,
+ SSLHashType hash);
void tls13_FatalError(sslSocket *ss, PRErrorCode prError,
SSL3AlertDescription desc);
SECStatus tls13_SetupClientHello(sslSocket *ss, sslClientHelloType chType);
diff --git a/lib/ssl/tls13exthandle.c b/lib/ssl/tls13exthandle.c
index 5768fbce5..ee3d309ec 100644
--- a/lib/ssl/tls13exthandle.c
+++ b/lib/ssl/tls13exthandle.c
@@ -14,6 +14,7 @@
#include "ssl3exthandle.h"
#include "tls13esni.h"
#include "tls13exthandle.h"
+#include "tls13psk.h"
#include "tls13subcerts.h"
SECStatus
@@ -408,69 +409,92 @@ tls13_ServerSendKeyShareXtn(const sslSocket *ss, TLSExtensionData *xtnData,
* };
*
* } PreSharedKeyExtension;
-
- * Presently the only way to get a PSK is by resumption, so this is
- * really a ticket label and there will be at most one.
*/
SECStatus
tls13_ClientSendPreSharedKeyXtn(const sslSocket *ss, TLSExtensionData *xtnData,
sslBuffer *buf, PRBool *added)
{
- NewSessionTicket *session_ticket;
- PRTime age;
const static PRUint8 binder[TLS13_MAX_FINISHED_SIZE] = { 0 };
unsigned int binderLen;
+ unsigned int identityLen = 0;
+ const PRUint8 *identity = NULL;
+ PRTime age;
SECStatus rv;
- /* We only set statelessResume on the client in TLS 1.3 code. */
- if (!ss->statelessResume) {
+ /* Exit early if no PSKs or max version < 1.3. */
+ if (PR_CLIST_IS_EMPTY(&ss->ssl3.hs.psks) ||
+ ss->vrange.max < SSL_LIBRARY_VERSION_TLS_1_3) {
+ return SECSuccess;
+ }
+
+ /* ...or if PSK type is resumption, but we're not resuming. */
+ sslPsk *psk = (sslPsk *)PR_LIST_HEAD(&ss->ssl3.hs.psks);
+ if (psk->type == ssl_psk_resume && !ss->statelessResume) {
return SECSuccess;
}
/* Save where this extension starts so that if we have to add padding, it
- * can be inserted before this extension. */
+ * can be inserted before this extension. */
PORT_Assert(buf->len >= 4);
xtnData->lastXtnOffset = buf->len - 4;
+ PORT_Assert(psk->type == ssl_psk_resume || psk->type == ssl_psk_external);
+ binderLen = tls13_GetHashSizeForHash(psk->hash);
+ if (psk->type == ssl_psk_resume) {
+ /* Send a single ticket identity. */
+ NewSessionTicket *session_ticket = &ss->sec.ci.sid->u.ssl3.locked.sessionTicket;
+ identityLen = session_ticket->ticket.len;
+ identity = session_ticket->ticket.data;
+
+ /* Obfuscated age. */
+ age = ssl_Time(ss) - session_ticket->received_timestamp;
+ age /= PR_USEC_PER_MSEC;
+ age += session_ticket->ticket_age_add;
+ PRINT_BUF(50, (ss, "Sending Resumption PSK with identity", identity, identityLen));
+ } else if (psk->type == ssl_psk_external) {
+ identityLen = psk->label.len;
+ identity = psk->label.data;
+ age = 0;
+ PRINT_BUF(50, (ss, "Sending External PSK with label", identity, identityLen));
+ } else {
+ PORT_Assert(0);
+ PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
+ return SECFailure;
+ }
- PORT_Assert(ss->vrange.max >= SSL_LIBRARY_VERSION_TLS_1_3);
- PORT_Assert(ss->sec.ci.sid->version >= SSL_LIBRARY_VERSION_TLS_1_3);
-
- /* Send a single ticket identity. */
- session_ticket = &ss->sec.ci.sid->u.ssl3.locked.sessionTicket;
- rv = sslBuffer_AppendNumber(buf, 2 + /* identity length */
- session_ticket->ticket.len + /* ticket */
- 4 /* obfuscated_ticket_age */,
- 2);
- if (rv != SECSuccess)
+ /* Length is len(identityLen) + identityLen + len(age) */
+ rv = sslBuffer_AppendNumber(buf, 2 + identityLen + 4, 2);
+ if (rv != SECSuccess) {
goto loser;
- rv = sslBuffer_AppendVariable(buf, session_ticket->ticket.data,
- session_ticket->ticket.len, 2);
- if (rv != SECSuccess)
+ }
+
+ rv = sslBuffer_AppendVariable(buf, identity,
+ identityLen, 2);
+ if (rv != SECSuccess) {
goto loser;
+ }
- /* Obfuscated age. */
- age = ssl_Time(ss) - session_ticket->received_timestamp;
- age /= PR_USEC_PER_MSEC;
- age += session_ticket->ticket_age_add;
rv = sslBuffer_AppendNumber(buf, age, 4);
- if (rv != SECSuccess)
+ if (rv != SECSuccess) {
goto loser;
+ }
/* Write out the binder list length. */
- binderLen = tls13_GetHashSize(ss);
rv = sslBuffer_AppendNumber(buf, binderLen + 1, 2);
- if (rv != SECSuccess)
+ if (rv != SECSuccess) {
goto loser;
- /* Write zeroes for the binder for the moment. */
+ }
+
+ /* Write zeroes for the binder for the moment. These
+ * are overwritten in tls13_WriteExtensionsWithBinder. */
rv = sslBuffer_AppendVariable(buf, binder, binderLen, 1);
- if (rv != SECSuccess)
+ if (rv != SECSuccess) {
goto loser;
+ }
- PRINT_BUF(50, (ss, "Sending PreSharedKey value",
- session_ticket->ticket.data,
- session_ticket->ticket.len));
+ if (psk->type == ssl_psk_resume) {
+ xtnData->sentSessionTicketInClientHello = PR_TRUE;
+ }
- xtnData->sentSessionTicketInClientHello = PR_TRUE;
*added = PR_TRUE;
return SECSuccess;
@@ -479,8 +503,7 @@ loser:
return SECFailure;
}
-/* Handle a TLS 1.3 PreSharedKey Extension. We only accept PSKs
- * that contain session tickets. */
+/* Handle a TLS 1.3 PreSharedKey Extension. */
SECStatus
tls13_ServerHandlePreSharedKeyXtn(const sslSocket *ss, TLSExtensionData *xtnData,
SECItem *data)
@@ -534,28 +557,52 @@ tls13_ServerHandlePreSharedKeyXtn(const sslSocket *ss, TLSExtensionData *xtnData
return rv;
if (!numIdentities) {
- PRINT_BUF(50, (ss, "Handling PreSharedKey value",
- label.data, label.len));
- rv = ssl3_ProcessSessionTicketCommon(
- CONST_CAST(sslSocket, ss), &label, appToken);
- /* This only happens if we have an internal error, not
- * a malformed ticket. Bogus tickets just don't resume
- * and return SECSuccess. */
- if (rv != SECSuccess)
- return SECFailure;
+ /* Check any configured external PSK for a matching label.
+ * If none exists, try to parse it as a ticket. */
+ PORT_Assert(!xtnData->selectedPsk);
+ for (PRCList *cur_p = PR_LIST_HEAD(&ss->ssl3.hs.psks);
+ cur_p != &ss->ssl3.hs.psks;
+ cur_p = PR_NEXT_LINK(cur_p)) {
+ sslPsk *psk = (sslPsk *)cur_p;
+ if (psk->type != ssl_psk_external ||
+ SECITEM_CompareItem(&psk->label, &label) != SECEqual) {
+ continue;
+ }
+ PRINT_BUF(50, (ss, "Using External PSK with label",
+ psk->label.data, psk->label.len));
+ xtnData->selectedPsk = psk;
+ }
- if (ss->sec.ci.sid) {
- /* xtnData->ticketAge contains the baseline we use for
- * calculating the ticket age (i.e., our RTT estimate less the
- * value of ticket_age_add).
- *
- * Add that to the obfuscated ticket age to recover the client's
- * view of the ticket age plus the estimated RTT.
- *
- * See ssl3_EncodeSessionTicket() for details. */
- xtnData->ticketAge += obfuscatedAge;
+ if (!xtnData->selectedPsk) {
+ PRINT_BUF(50, (ss, "Handling PreSharedKey value",
+ label.data, label.len));
+ rv = ssl3_ProcessSessionTicketCommon(
+ CONST_CAST(sslSocket, ss), &label, appToken);
+ /* This only happens if we have an internal error, not
+ * a malformed ticket. Bogus tickets just don't resume
+ * and return SECSuccess. */
+ if (rv != SECSuccess) {
+ return SECFailure;
+ }
+
+ if (ss->sec.ci.sid) {
+ /* xtnData->ticketAge contains the baseline we use for
+ * calculating the ticket age (i.e., our RTT estimate less the
+ * value of ticket_age_add).
+ *
+ * Add that to the obfuscated ticket age to recover the client's
+ * view of the ticket age plus the estimated RTT.
+ *
+ * See ssl3_EncodeSessionTicket() for details. */
+ xtnData->ticketAge += obfuscatedAge;
+
+ /* We are not committed to resumption until after unwrapping the
+ * RMS in tls13_HandleClientHelloPart2. The RPSK will be stored
+ * in ss->xtnData.selectedPsk at that point, so continue. */
+ }
}
}
+
++numIdentities;
}
@@ -589,10 +636,14 @@ tls13_ServerHandlePreSharedKeyXtn(const sslSocket *ss, TLSExtensionData *xtnData
if (numBinders != numIdentities)
goto alert_loser;
- /* Keep track of negotiated extensions. Note that this does not
- * mean we are resuming. */
- xtnData->negotiated[xtnData->numNegotiated++] = ssl_tls13_pre_shared_key_xtn;
+ if (ss->statelessResume) {
+ PORT_Assert(!ss->xtnData.selectedPsk);
+ } else if (!xtnData->selectedPsk) {
+ /* No matching EPSK. */
+ return SECSuccess;
+ }
+ xtnData->negotiated[xtnData->numNegotiated++] = ssl_tls13_pre_shared_key_xtn;
return SECSuccess;
alert_loser:
@@ -618,8 +669,7 @@ tls13_ServerSendPreSharedKeyXtn(const sslSocket *ss, TLSExtensionData *xtnData,
return SECSuccess;
}
-/* Handle a TLS 1.3 PreSharedKey Extension. We only accept PSKs
- * that contain session tickets. */
+/* Handle a TLS 1.3 PreSharedKey Extension. */
SECStatus
tls13_ClientHandlePreSharedKeyXtn(const sslSocket *ss, TLSExtensionData *xtnData,
SECItem *data)
@@ -648,12 +698,23 @@ tls13_ClientHandlePreSharedKeyXtn(const sslSocket *ss, TLSExtensionData *xtnData
/* We only sent one PSK label so index must be equal to 0 */
if (index) {
+ ssl3_ExtSendAlert(ss, alert_fatal, illegal_parameter);
PORT_SetError(SSL_ERROR_MALFORMED_PRE_SHARED_KEY);
return SECFailure;
}
+ PORT_Assert(!PR_CLIST_IS_EMPTY(&ss->ssl3.hs.psks));
+ sslPsk *candidate = (sslPsk *)PR_LIST_HEAD(&ss->ssl3.hs.psks);
+
+ /* Check that the server-selected ciphersuite hash and PSK hash match. */
+ if (candidate->hash != tls13_GetHashForCipherSuite(ss->ssl3.hs.cipher_suite)) {
+ ssl3_ExtSendAlert(ss, alert_fatal, illegal_parameter);
+ return SECFailure;
+ }
+
/* Keep track of negotiated extensions. */
xtnData->negotiated[xtnData->numNegotiated++] = ssl_tls13_pre_shared_key_xtn;
+ xtnData->selectedPsk = candidate;
return SECSuccess;
}
diff --git a/lib/ssl/tls13psk.c b/lib/ssl/tls13psk.c
new file mode 100644
index 000000000..cc1d14106
--- /dev/null
+++ b/lib/ssl/tls13psk.c
@@ -0,0 +1,219 @@
+/* -*- Mode: C; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 4 -*- */
+/*
+ * This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "nss.h"
+#include "pk11func.h"
+#include "ssl.h"
+#include "sslproto.h"
+#include "sslimpl.h"
+#include "ssl3exthandle.h"
+#include "tls13exthandle.h"
+#include "tls13hkdf.h"
+#include "tls13psk.h"
+
+SECStatus
+SSLExp_AddExternalPsk0Rtt(PRFileDesc *fd, PK11SymKey *key, const PRUint8 *identity,
+ unsigned int identityLen, SSLHashType hash,
+ PRUint16 zeroRttSuite, PRUint32 maxEarlyData)
+{
+
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ SSL_DBG(("%d: SSL[%d]: bad socket in SSLExp_SetExternalPsk",
+ SSL_GETPID(), fd));
+ return SECFailure;
+ }
+
+ if (!key || !identity || !identityLen || identityLen > 0xFFFF ||
+ (hash != ssl_hash_sha256 && hash != ssl_hash_sha384)) {
+ PORT_SetError(SEC_ERROR_INVALID_ARGS);
+ return SECFailure;
+ }
+
+ SECItem label = { siBuffer, CONST_CAST(unsigned char, identity), identityLen };
+ sslPsk *psk = tls13_MakePsk(PK11_ReferenceSymKey(key), ssl_psk_external,
+ hash, &label);
+ if (!psk) {
+ PORT_SetError(SEC_ERROR_NO_MEMORY);
+ return SECFailure;
+ }
+ psk->zeroRttSuite = zeroRttSuite;
+ psk->maxEarlyData = maxEarlyData;
+ SECStatus rv = SECFailure;
+
+ ssl_Get1stHandshakeLock(ss);
+ ssl_GetSSL3HandshakeLock(ss);
+
+ if (ss->psk) {
+ PORT_SetError(SEC_ERROR_INVALID_ARGS);
+ tls13_DestroyPsk(psk);
+ } else {
+ ss->psk = psk;
+ rv = SECSuccess;
+ tls13_ResetHandshakePsks(ss, &ss->ssl3.hs.psks);
+ }
+
+ ssl_ReleaseSSL3HandshakeLock(ss);
+ ssl_Release1stHandshakeLock(ss);
+
+ return rv;
+}
+
+SECStatus
+SSLExp_AddExternalPsk(PRFileDesc *fd, PK11SymKey *key, const PRUint8 *identity,
+ unsigned int identityLen, SSLHashType hash)
+{
+ return SSLExp_AddExternalPsk0Rtt(fd, key, identity, identityLen,
+ hash, TLS_NULL_WITH_NULL_NULL, 0);
+}
+
+SECStatus
+SSLExp_RemoveExternalPsk(PRFileDesc *fd, const PRUint8 *identity, unsigned int identityLen)
+{
+ if (!identity || !identityLen) {
+ PORT_SetError(SEC_ERROR_INVALID_ARGS);
+ return SECFailure;
+ }
+
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ SSL_DBG(("%d: SSL[%d]: bad socket in SSL_SetPSK",
+ SSL_GETPID(), fd));
+ return SECFailure;
+ }
+
+ SECItem removeIdentity = { siBuffer,
+ (unsigned char *)identity,
+ identityLen };
+
+ SECStatus rv;
+ ssl_Get1stHandshakeLock(ss);
+ ssl_GetSSL3HandshakeLock(ss);
+
+ if (!ss->psk || SECITEM_CompareItem(&ss->psk->label, &removeIdentity) != SECEqual) {
+ PORT_SetError(SEC_ERROR_NO_KEY);
+ rv = SECFailure;
+ } else {
+ tls13_DestroyPsk(ss->psk);
+ ss->psk = NULL;
+ tls13_ResetHandshakePsks(ss, &ss->ssl3.hs.psks);
+ rv = SECSuccess;
+ }
+
+ ssl_ReleaseSSL3HandshakeLock(ss);
+ ssl_Release1stHandshakeLock(ss);
+
+ return rv;
+}
+
+sslPsk *
+tls13_CopyPsk(sslPsk *opsk)
+{
+ if (!opsk || !opsk->key) {
+ return NULL;
+ }
+
+ sslPsk *psk = PORT_ZNew(sslPsk);
+ if (!psk) {
+ return NULL;
+ }
+
+ SECStatus rv = SECITEM_CopyItem(NULL, &psk->label, &opsk->label);
+ if (rv != SECSuccess) {
+ PORT_Free(psk);
+ return NULL;
+ }
+ /* We should only have the initial key. Binder keys
+ * are derived during the handshake. */
+ PORT_Assert(opsk->type == ssl_psk_external);
+ PORT_Assert(opsk->key);
+ PORT_Assert(opsk->binderKey);
+ psk->hash = opsk->hash;
+ psk->type = opsk->type;
+ psk->key = opsk->key ? PK11_ReferenceSymKey(opsk->key) : NULL;
+ psk->binderKey = opsk->binderKey ? PK11_ReferenceSymKey(opsk->binderKey) : NULL;
+ return psk;
+}
+
+void
+tls13_DestroyPsk(sslPsk *psk)
+{
+ if (!psk) {
+ return;
+ }
+ if (psk->key) {
+ PK11_FreeSymKey(psk->key);
+ psk->key = NULL;
+ }
+ if (psk->binderKey) {
+ PK11_FreeSymKey(psk->binderKey);
+ psk->binderKey = NULL;
+ }
+ SECITEM_ZfreeItem(&psk->label, PR_FALSE);
+ PORT_ZFree(psk, sizeof(*psk));
+}
+
+void
+tls13_DestroyPskList(PRCList *list)
+{
+ PRCList *cur_p;
+ while (!PR_CLIST_IS_EMPTY(list)) {
+ cur_p = PR_LIST_TAIL(list);
+ PR_REMOVE_LINK(cur_p);
+ tls13_DestroyPsk((sslPsk *)cur_p);
+ }
+}
+
+sslPsk *
+tls13_MakePsk(PK11SymKey *key, SSLPskType pskType, SSLHashType hashType, const SECItem *label)
+{
+ sslPsk *psk = PORT_ZNew(sslPsk);
+ if (!psk) {
+ PORT_SetError(SEC_ERROR_NO_MEMORY);
+ return NULL;
+ }
+ psk->type = pskType;
+ psk->hash = hashType;
+ psk->key = key;
+
+ /* Label is NULL in the resumption case. */
+ if (label) {
+ PORT_Assert(psk->type != ssl_psk_resume);
+ SECStatus rv = SECITEM_CopyItem(NULL, &psk->label, label);
+ if (rv != SECSuccess) {
+ PORT_SetError(SEC_ERROR_NO_MEMORY);
+ tls13_DestroyPsk(psk);
+ return NULL;
+ }
+ }
+
+ return psk;
+}
+
+/* Destroy any existing PSKs in |list| then copy
+ * in the configured |ss->psk|, if any.*/
+SECStatus
+tls13_ResetHandshakePsks(sslSocket *ss, PRCList *list)
+{
+ tls13_DestroyPskList(list);
+ PORT_Assert(!ss->xtnData.selectedPsk);
+ ss->xtnData.selectedPsk = NULL;
+ if (ss->psk) {
+ PORT_Assert(ss->psk->type == ssl_psk_external);
+ PORT_Assert(ss->psk->key);
+ PORT_Assert(!ss->psk->binderKey);
+
+ sslPsk *epsk = tls13_MakePsk(PK11_ReferenceSymKey(ss->psk->key),
+ ss->psk->type, ss->psk->hash, &ss->psk->label);
+ if (!epsk) {
+ return SECFailure;
+ }
+ epsk->zeroRttSuite = ss->psk->zeroRttSuite;
+ epsk->maxEarlyData = ss->psk->maxEarlyData;
+ PR_APPEND_LINK(&epsk->link, list);
+ }
+ return SECSuccess;
+} \ No newline at end of file
diff --git a/lib/ssl/tls13psk.h b/lib/ssl/tls13psk.h
new file mode 100644
index 000000000..73013fb9b
--- /dev/null
+++ b/lib/ssl/tls13psk.h
@@ -0,0 +1,58 @@
+/* -*- Mode: C; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 4 -*- */
+/*
+ * This file is PRIVATE to SSL.
+ *
+ * This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef __tls13psk_h_
+#define __tls13psk_h_
+
+/*
+ * Internally, we have track sslPsk pointers in three locations:
+ * 1) An external PSK can be configured to the socket, in which case ss->psk will hold an owned reference.
+ * For now, this only holds one external PSK. The value will persist across handshake restarts.
+ * 2) When a handshake begins, the ss->psk value is deep-copied into ss->ssl3.hs.psks, which may also hold
+ * a resumption PSK. This is essentially a priority-sorted list (where a resumption PSK has higher
+ * priority than external), and we currently only send one PskIdentity and binder.
+ * 3) During negotiation, ss->xtnData.selectedPsk will either be NULL or it will hold a non-owning refernce
+ * to the PSK that has been (or is being) negotiated.
+ */
+
+/* Note: When holding a resumption PSK:
+ * 1. |hash| comes from the original connection.
+ * 2. |label| is ignored: The identity sent in the pre_shared_key_xtn
+ * comes from ss->sec.ci.sid->u.ssl3.locked.sessionTicket.
+ */
+struct sslPskStr {
+ PRCList link;
+ PK11SymKey *key; /* A raw PSK. */
+ PK11SymKey *binderKey; /* The binder key derived from |key|. |key| is NULL after derivation. */
+ SSLPskType type; /* none, resumption, or external. */
+ SECItem label; /* Label (identity) for an external PSK. */
+ SSLHashType hash; /* A hash algorithm associated with a PSK. */
+ ssl3CipherSuite zeroRttSuite; /* For EPSKs, an explicitly-configured ciphersuite for 0-Rtt. */
+ PRUint32 maxEarlyData; /* For EPSKs, a limit on early data. Must be > 0 for 0-Rtt. */
+};
+
+SECStatus SSLExp_AddExternalPsk(PRFileDesc *fd, PK11SymKey *psk, const PRUint8 *identity,
+ unsigned int identitylen, SSLHashType hash);
+
+SECStatus SSLExp_AddExternalPsk0Rtt(PRFileDesc *fd, PK11SymKey *psk, const PRUint8 *identity,
+ unsigned int identitylen, SSLHashType hash,
+ PRUint16 zeroRttSuite, PRUint32 maxEarlyData);
+
+SECStatus SSLExp_RemoveExternalPsk(PRFileDesc *fd, const PRUint8 *identity, unsigned int identitylen);
+
+sslPsk *tls13_CopyPsk(sslPsk *opsk);
+
+void tls13_DestroyPsk(sslPsk *psk);
+
+void tls13_DestroyPskList(PRCList *list);
+
+sslPsk *tls13_MakePsk(PK11SymKey *key, SSLPskType pskType, SSLHashType hashType, const SECItem *label);
+
+SECStatus tls13_ResetHandshakePsks(sslSocket *ss, PRCList *list);
+
+#endif
diff --git a/lib/ssl/tls13replay.c b/lib/ssl/tls13replay.c
index b6d1416f3..8deb596d2 100644
--- a/lib/ssl/tls13replay.c
+++ b/lib/ssl/tls13replay.c
@@ -16,6 +16,7 @@
#include "sslbloom.h"
#include "sslimpl.h"
#include "tls13hkdf.h"
+#include "tls13psk.h"
struct SSLAntiReplayContextStr {
/* The number of outstanding references to this context. */
@@ -250,7 +251,9 @@ tls13_IsReplay(const sslSocket *ss, const sslSessionID *sid)
return PR_TRUE;
}
- if (!tls13_InWindow(ss, sid)) {
+ if (!sid) {
+ PORT_Assert(ss->xtnData.selectedPsk->type == ssl_psk_external);
+ } else if (!tls13_InWindow(ss, sid)) {
return PR_TRUE;
}