diff options
-rw-r--r-- | cpputil/tls_parser.h | 26 | ||||
-rw-r--r-- | gtests/ssl_gtest/libssl_internals.c | 12 | ||||
-rw-r--r-- | gtests/ssl_gtest/libssl_internals.h | 1 | ||||
-rw-r--r-- | gtests/ssl_gtest/ssl_recordsep_unittest.cc | 362 | ||||
-rw-r--r-- | gtests/ssl_gtest/test_io.cc | 9 | ||||
-rw-r--r-- | gtests/ssl_gtest/test_io.h | 2 | ||||
-rw-r--r-- | gtests/ssl_gtest/tls_agent.cc | 9 | ||||
-rw-r--r-- | gtests/ssl_gtest/tls_agent.h | 1 | ||||
-rw-r--r-- | gtests/ssl_gtest/tls_connect.cc | 14 | ||||
-rw-r--r-- | lib/ssl/dtlscon.c | 1 | ||||
-rw-r--r-- | lib/ssl/ssl3con.c | 115 | ||||
-rw-r--r-- | lib/ssl/ssl3gthr.c | 228 | ||||
-rw-r--r-- | lib/ssl/sslcon.c | 16 | ||||
-rw-r--r-- | lib/ssl/ssldef.c | 2 | ||||
-rw-r--r-- | lib/ssl/sslexp.h | 57 | ||||
-rw-r--r-- | lib/ssl/sslimpl.h | 15 | ||||
-rw-r--r-- | lib/ssl/sslsecur.c | 82 | ||||
-rw-r--r-- | lib/ssl/sslsock.c | 2 | ||||
-rw-r--r-- | lib/ssl/tls13con.c | 3 |
19 files changed, 733 insertions, 224 deletions
diff --git a/cpputil/tls_parser.h b/cpputil/tls_parser.h index cd9e28fc3..881c5268e 100644 --- a/cpputil/tls_parser.h +++ b/cpputil/tls_parser.h @@ -80,6 +80,32 @@ inline std::ostream& operator<<(std::ostream& os, SSLProtocolVariant v) { return os << ((v == ssl_variant_stream) ? "TLS" : "DTLS"); } +inline std::ostream& operator<<(std::ostream& os, SSLContentType v) { + switch (v) { + case ssl_ct_change_cipher_spec: + return os << "CCS"; + case ssl_ct_alert: + return os << "alert"; + case ssl_ct_handshake: + return os << "handshake"; + case ssl_ct_application_data: + return os << "application data"; + case ssl_ct_ack: + return os << "ack"; + } + return os << "UNKNOWN content type " << static_cast<int>(v); +} + +inline std::ostream& operator<<(std::ostream& os, SSLSecretDirection v) { + switch (v) { + case ssl_secret_read: + return os << "read"; + case ssl_secret_write: + return os << "write"; + } + return os << "UNKNOWN secret direction " << static_cast<int>(v); +} + inline bool IsDtls(uint16_t version) { return (version & 0x8000) == 0x8000; } inline uint16_t NormalizeTlsVersion(uint16_t version) { diff --git a/gtests/ssl_gtest/libssl_internals.c b/gtests/ssl_gtest/libssl_internals.c index e43113de4..81ba7349e 100644 --- a/gtests/ssl_gtest/libssl_internals.c +++ b/gtests/ssl_gtest/libssl_internals.c @@ -373,3 +373,15 @@ SECStatus SSLInt_GetEpochs(PRFileDesc *fd, PRUint16 *readEpoch, ssl_ReleaseSpecReadLock(ss); return SECSuccess; } + +SECStatus SSLInt_HasPendingHandshakeData(PRFileDesc *fd, PRBool *pending) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + ssl_GetSSL3HandshakeLock(ss); + *pending = ss->ssl3.hs.msg_body.len > 0; + ssl_ReleaseSSL3HandshakeLock(ss); + return SECSuccess; +} diff --git a/gtests/ssl_gtest/libssl_internals.h b/gtests/ssl_gtest/libssl_internals.h index 3efb362c2..e61d7c942 100644 --- a/gtests/ssl_gtest/libssl_internals.h +++ b/gtests/ssl_gtest/libssl_internals.h @@ -41,6 +41,7 @@ SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra); SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group); SECStatus SSLInt_GetEpochs(PRFileDesc *fd, PRUint16 *readEpoch, PRUint16 *writeEpoch); +SECStatus SSLInt_HasPendingHandshakeData(PRFileDesc *fd, PRBool *pending); SECStatus SSLInt_SetCipherSpecChangeFunc(PRFileDesc *fd, sslCipherSpecChangedFunc func, diff --git a/gtests/ssl_gtest/ssl_recordsep_unittest.cc b/gtests/ssl_gtest/ssl_recordsep_unittest.cc index 6a1565f07..cb3cb02a4 100644 --- a/gtests/ssl_gtest/ssl_recordsep_unittest.cc +++ b/gtests/ssl_gtest/ssl_recordsep_unittest.cc @@ -14,6 +14,7 @@ extern "C" { #include "libssl_internals.h" } +#include <queue> #include "gtest_utils.h" #include "nss_scoped_ptrs.h" #include "tls_connect.h" @@ -50,8 +51,7 @@ class HandshakeSecretTracker { PK11SymKey* secret) { if (g_ssl_gtest_verbose) { std::cerr << agent_->role_str() << ": secret callback for " - << (dir == ssl_secret_read ? "read" : "write") << " epoch " - << epoch << std::endl; + << dir << " epoch " << epoch << std::endl; } EXPECT_TRUE(secret); @@ -162,4 +162,362 @@ TEST_F(TlsConnectTest, KeyUpdateSecrets) { s.CheckCalled(); } +// BadPrSocket is an instance of a PR IO layer that crashes the test if it is +// ever used for reading or writing. It does that by failing to overwrite any +// of the DummyIOLayerMethods, which all crash when invoked. +class BadPrSocket : public DummyIOLayerMethods { + public: + BadPrSocket(std::shared_ptr<TlsAgent>& agent) : DummyIOLayerMethods() { + static PRDescIdentity bad_identity = PR_GetUniqueIdentity("bad NSPR id"); + fd_ = DummyIOLayerMethods::CreateFD(bad_identity, this); + + // This is terrible, but NSPR doesn't provide an easy way to replace the + // bottom layer of an IO stack. Take the DummyPrSocket and replace its + // NSPR method vtable with the ones from this object. + dummy_layer_ = + PR_GetIdentitiesLayer(agent->ssl_fd(), DummyPrSocket::LayerId()); + original_methods_ = dummy_layer_->methods; + original_secret_ = dummy_layer_->secret; + dummy_layer_->methods = fd_->methods; + dummy_layer_->secret = reinterpret_cast<PRFilePrivate*>(this); + } + + // This will be destroyed before the agent, so we need to restore the state + // before we tampered with it. + virtual ~BadPrSocket() { + dummy_layer_->methods = original_methods_; + dummy_layer_->secret = original_secret_; + } + + private: + ScopedPRFileDesc fd_; + PRFileDesc* dummy_layer_; + const PRIOMethods* original_methods_; + PRFilePrivate* original_secret_; +}; + +class StagedRecords { + public: + StagedRecords(std::shared_ptr<TlsAgent>& agent) : agent_(agent), records_() { + EXPECT_EQ(SECSuccess, + SSL_RecordLayerWriteCallback( + agent_->ssl_fd(), StagedRecords::StageRecordData, this)); + } + + virtual ~StagedRecords() { + // Uninstall so that the callback doesn't fire during cleanup. + EXPECT_EQ(SECSuccess, + SSL_RecordLayerWriteCallback(agent_->ssl_fd(), nullptr, nullptr)); + } + + bool empty() const { return records_.empty(); } + + void ForwardAll(std::shared_ptr<TlsAgent>& peer) { + EXPECT_NE(agent_, peer) << "can't forward to self"; + for (auto r : records_) { + r.Forward(peer); + } + records_.clear(); + } + + // This forwards all saved data and checks the resulting state. + void ForwardAll(std::shared_ptr<TlsAgent>& peer, + TlsAgent::State expected_state) { + ForwardAll(peer); + peer->Handshake(); + EXPECT_EQ(expected_state, peer->state()); + } + + void ForwardPartial(std::shared_ptr<TlsAgent>& peer) { + if (records_.empty()) { + ADD_FAILURE() << "No records to slice"; + return; + } + auto& last = records_.back(); + auto tail = last.SliceTail(); + ForwardAll(peer, TlsAgent::STATE_CONNECTING); + records_.push_back(tail); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, peer->state()); + } + + private: + // A single record. + class StagedRecord { + public: + StagedRecord(const std::string role, uint16_t epoch, SSLContentType ct, + const uint8_t* data, size_t len) + : role_(role), epoch_(epoch), content_type_(ct), data_(data, len) { + if (g_ssl_gtest_verbose) { + std::cerr << role_ << ": staged epoch " << epoch_ << " " + << content_type_ << ": " << data_ << std::endl; + } + } + + // This forwards staged data to the identified agent. + void Forward(std::shared_ptr<TlsAgent>& peer) { + // Now there should be staged data. + EXPECT_FALSE(data_.empty()); + if (g_ssl_gtest_verbose) { + std::cerr << role_ << ": forward " << data_ << std::endl; + } + SECStatus rv = SSL_RecordLayerData( + peer->ssl_fd(), epoch_, content_type_, data_.data(), + static_cast<unsigned int>(data_.len())); + if (rv != SECSuccess) { + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + } + } + + // Slices the tail off this record and returns it. + StagedRecord SliceTail() { + size_t slice = 1; + if (data_.len() <= slice) { + ADD_FAILURE() << "record too small to slice in two"; + slice = 0; + } + size_t keep = data_.len() - slice; + StagedRecord tail(role_, epoch_, content_type_, data_.data() + keep, + slice); + data_.Truncate(keep); + return tail; + } + + private: + std::string role_; + uint16_t epoch_; + SSLContentType content_type_; + DataBuffer data_; + }; + + // This is an SSLRecordWriteCallback that stages data. + static SECStatus StageRecordData(PRFileDesc* fd, PRUint16 epoch, + SSLContentType content_type, + const PRUint8* data, unsigned int len, + void* arg) { + auto stage = reinterpret_cast<StagedRecords*>(arg); + stage->records_.push_back(StagedRecord(stage->agent_->role_str(), epoch, + content_type, data, + static_cast<size_t>(len))); + return SECSuccess; + } + + std::shared_ptr<TlsAgent>& agent_; + std::deque<StagedRecord> records_; +}; + +// Attempting to feed application data in before the handshake is complete +// should be caught. +static void RefuseApplicationData(std::shared_ptr<TlsAgent>& peer, + uint16_t epoch) { + static const uint8_t d[] = {1, 2, 3}; + EXPECT_EQ(SECFailure, + SSL_RecordLayerData(peer->ssl_fd(), epoch, ssl_ct_application_data, + d, static_cast<unsigned int>(sizeof(d)))); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +static void SendForwardReceive(std::shared_ptr<TlsAgent>& sender, + StagedRecords& sender_stage, + std::shared_ptr<TlsAgent>& receiver) { + const size_t count = 10; + sender->SendData(count, count); + sender_stage.ForwardAll(receiver); + receiver->ReadBytes(count); +} + +TEST_P(TlsConnectStream, ReplaceRecordLayer) { + StartConnect(); + client_->SetServerKeyBits(server_->server_key_bits()); + + // BadPrSocket installs an IO layer that crashes when the SSL layer attempts + // to read or write. + BadPrSocket bad_layer_client(client_); + BadPrSocket bad_layer_server(server_); + + // StagedRecords installs a handler for unprotected data from the socket, and + // captures that data. + StagedRecords client_stage(client_); + StagedRecords server_stage(server_); + + // Both peers should refuse application data from epoch 0. + RefuseApplicationData(client_, 0); + RefuseApplicationData(server_, 0); + + // This first call forwards nothing, but it causes the client to handshake, + // which starts things off. This stages the ClientHello as a result. + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING); + // This processes the ClientHello and stages the first server flight. + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING); + RefuseApplicationData(server_, 1); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + // Process the server flight and the client is done. + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED); + } else { + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING); + RefuseApplicationData(client_, 1); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED); + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED); + } + CheckKeys(); + + // Reading and writing application data should work. + SendForwardReceive(client_, client_stage, server_); + SendForwardReceive(server_, server_stage, client_); +} + +static SECStatus AuthCompleteBlock(TlsAgent*, PRBool, PRBool) { + return SECWouldBlock; +} + +TEST_P(TlsConnectStream, ReplaceRecordLayerAsyncLateAuth) { + StartConnect(); + client_->SetServerKeyBits(server_->server_key_bits()); + + BadPrSocket bad_layer_client(client_); + BadPrSocket bad_layer_server(server_); + StagedRecords client_stage(client_); + StagedRecords server_stage(server_); + + client_->SetAuthCertificateCallback(AuthCompleteBlock); + + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING); + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING); + + // Prior to TLS 1.3, the client sends its second flight immediately. But in + // TLS 1.3, a client won't send a Finished until it is happy with the server + // certificate. So blocking certificate validation causes the client to send + // nothing. + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + ASSERT_TRUE(client_stage.empty()); + + // Client should have stopped reading when it saw the Certificate message, + // so it will be reading handshake epoch, and writing cleartext. + client_->CheckEpochs(2, 0); + // Server should be reading handshake, and writing application data. + server_->CheckEpochs(2, 3); + + // Handshake again and the client will read the remainder of the server's + // flight, but it will remain blocked. + client_->Handshake(); + ASSERT_TRUE(client_stage.empty()); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + } else { + // In prior versions, the client's second flight is always sent. + ASSERT_FALSE(client_stage.empty()); + } + + // Now declare the certificate good. + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); + client_->Handshake(); + ASSERT_FALSE(client_stage.empty()); + + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED); + } else { + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED); + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED); + } + CheckKeys(); + + // Reading and writing application data should work. + SendForwardReceive(client_, client_stage, server_); +} + +// This test ensures that data is correctly forwarded when the handshake is +// resumed after asynchronous server certificate authentication, when +// SSL_AuthCertificateComplete() is called. The logic for resuming the +// handshake involves a different code path than the usual one, so this test +// exercises that code fully. +TEST_F(TlsConnectStreamTls13, ReplaceRecordLayerAsyncEarlyAuth) { + StartConnect(); + client_->SetServerKeyBits(server_->server_key_bits()); + + BadPrSocket bad_layer_client(client_); + BadPrSocket bad_layer_server(server_); + StagedRecords client_stage(client_); + StagedRecords server_stage(server_); + + client_->SetAuthCertificateCallback(AuthCompleteBlock); + + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING); + + // Send a partial flight on to the client. + // This includes enough to trigger the certificate callback. + server_stage.ForwardPartial(client_); + EXPECT_TRUE(client_stage.empty()); + + // Declare the certificate good. + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); + client_->Handshake(); + EXPECT_TRUE(client_stage.empty()); + + // Send the remainder of the server flight. + PRBool pending = PR_FALSE; + EXPECT_EQ(SECSuccess, + SSLInt_HasPendingHandshakeData(client_->ssl_fd(), &pending)); + EXPECT_EQ(PR_TRUE, pending); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED); + CheckKeys(); + + SendForwardReceive(server_, server_stage, client_); +} + +TEST_P(TlsConnectStream, ForwardDataFromWrongEpoch) { + const uint8_t data[] = {1}; + Connect(); + uint16_t next_epoch; + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + EXPECT_EQ(SECFailure, + SSL_RecordLayerData(client_->ssl_fd(), 2, ssl_ct_application_data, + data, sizeof(data))); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()) + << "Passing data from an old epoch is rejected"; + next_epoch = 4; + } else { + // Prior to TLS 1.3, the epoch is only updated once during the handshake. + next_epoch = 2; + } + EXPECT_EQ(SECFailure, + SSL_RecordLayerData(client_->ssl_fd(), next_epoch, + ssl_ct_application_data, data, sizeof(data))); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()) + << "Passing data from a future epoch blocks"; +} + +TEST_F(TlsConnectStreamTls13, ForwardInvalidData) { + const uint8_t data[1] = {0}; + + // NULL fd. + EXPECT_EQ(SECFailure, + SSL_RecordLayerData(nullptr, 0, + ssl_ct_application_data, nullptr, 1)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // Zero-length data. + EXPECT_EQ(SECFailure, + SSL_RecordLayerData(client_->ssl_fd(), 0, + ssl_ct_application_data, data, 0)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // NULL data. + EXPECT_EQ(SECFailure, + SSL_RecordLayerData(client_->ssl_fd(), 0, + ssl_ct_application_data, nullptr, 1)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_F(TlsConnectDatagram13, ForwardDataDtls) { + const uint8_t data[1] = {0}; + EXPECT_EQ(SECFailure, + SSL_RecordLayerData(client_->ssl_fd(), 0, + ssl_ct_application_data, data, sizeof(data))); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + } // namespace nss_test diff --git a/gtests/ssl_gtest/test_io.cc b/gtests/ssl_gtest/test_io.cc index 6d792c520..1ac752bfd 100644 --- a/gtests/ssl_gtest/test_io.cc +++ b/gtests/ssl_gtest/test_io.cc @@ -25,10 +25,13 @@ namespace nss_test { if (g_ssl_gtest_verbose) LOG(a); \ } while (false) +PRDescIdentity DummyPrSocket::LayerId() { + static PRDescIdentity id = PR_GetUniqueIdentity("dummysocket"); + return id; +} + ScopedPRFileDesc DummyPrSocket::CreateFD() { - static PRDescIdentity test_fd_identity = - PR_GetUniqueIdentity("testtransportadapter"); - return DummyIOLayerMethods::CreateFD(test_fd_identity, this); + return DummyIOLayerMethods::CreateFD(DummyPrSocket::LayerId(), this); } void DummyPrSocket::Reset() { diff --git a/gtests/ssl_gtest/test_io.h b/gtests/ssl_gtest/test_io.h index 062ae86c8..8efd2305b 100644 --- a/gtests/ssl_gtest/test_io.h +++ b/gtests/ssl_gtest/test_io.h @@ -68,6 +68,8 @@ class DummyPrSocket : public DummyIOLayerMethods { write_error_(0) {} virtual ~DummyPrSocket() {} + static PRDescIdentity LayerId(); + // Create a file descriptor that will reference this object. The fd must not // live longer than this adapter; call PR_Close() before. ScopedPRFileDesc CreateFD(); diff --git a/gtests/ssl_gtest/tls_agent.cc b/gtests/ssl_gtest/tls_agent.cc index fb66196b5..8b6b36fa7 100644 --- a/gtests/ssl_gtest/tls_agent.cc +++ b/gtests/ssl_gtest/tls_agent.cc @@ -640,6 +640,15 @@ void TlsAgent::CheckAlpn(SSLNextProtoState expected_state, } } +void TlsAgent::CheckEpochs(uint16_t expected_read, + uint16_t expected_write) const { + uint16_t read_epoch = 0; + uint16_t write_epoch = 0; + EXPECT_EQ(SECSuccess, SSLInt_GetEpochs(ssl_fd(), &read_epoch, &write_epoch)); + EXPECT_EQ(expected_read, read_epoch) << role_str() << " read epoch"; + EXPECT_EQ(expected_write, write_epoch) << role_str() << " write epoch"; +} + void TlsAgent::EnableSrtp() { EXPECT_TRUE(EnsureTlsSetup()); const uint16_t ciphers[] = {SRTP_AES128_CM_HMAC_SHA1_80, diff --git a/gtests/ssl_gtest/tls_agent.h b/gtests/ssl_gtest/tls_agent.h index 020221868..a2ffe7711 100644 --- a/gtests/ssl_gtest/tls_agent.h +++ b/gtests/ssl_gtest/tls_agent.h @@ -139,6 +139,7 @@ class TlsAgent : public PollTarget { const std::string& expected = "") const; void EnableSrtp(); void CheckSrtp() const; + void CheckEpochs(uint16_t expected_read, uint16_t expected_write) const; void CheckErrorCode(int32_t expected) const; void WaitForErrorCode(int32_t expected, uint32_t delay) const; // Send data on the socket, encrypting it. diff --git a/gtests/ssl_gtest/tls_connect.cc b/gtests/ssl_gtest/tls_connect.cc index c48ae38ec..d992041df 100644 --- a/gtests/ssl_gtest/tls_connect.cc +++ b/gtests/ssl_gtest/tls_connect.cc @@ -167,18 +167,8 @@ void TlsConnectTestBase::CheckShares( void TlsConnectTestBase::CheckEpochs(uint16_t client_epoch, uint16_t server_epoch) const { - uint16_t read_epoch = 0; - uint16_t write_epoch = 0; - - EXPECT_EQ(SECSuccess, - SSLInt_GetEpochs(client_->ssl_fd(), &read_epoch, &write_epoch)); - EXPECT_EQ(server_epoch, read_epoch) << "client read epoch"; - EXPECT_EQ(client_epoch, write_epoch) << "client write epoch"; - - EXPECT_EQ(SECSuccess, - SSLInt_GetEpochs(server_->ssl_fd(), &read_epoch, &write_epoch)); - EXPECT_EQ(client_epoch, read_epoch) << "server read epoch"; - EXPECT_EQ(server_epoch, write_epoch) << "server write epoch"; + client_->CheckEpochs(server_epoch, client_epoch); + server_->CheckEpochs(client_epoch, server_epoch); } void TlsConnectTestBase::ClearStats() { diff --git a/lib/ssl/dtlscon.c b/lib/ssl/dtlscon.c index a5c604bca..bbd2f6d79 100644 --- a/lib/ssl/dtlscon.c +++ b/lib/ssl/dtlscon.c @@ -542,7 +542,6 @@ dtls_QueueMessage(sslSocket *ss, SSLContentType ct, /* Add DTLS handshake message to the pending queue * Empty the sendBuf buffer. - * This function returns SECSuccess or SECFailure, never SECWouldBlock. * Always set sendBuf.len to 0, even when returning SECFailure. * * Called from: diff --git a/lib/ssl/ssl3con.c b/lib/ssl/ssl3con.c index 7bb830525..5ce729876 100644 --- a/lib/ssl/ssl3con.c +++ b/lib/ssl/ssl3con.c @@ -2314,8 +2314,8 @@ ssl_ProtectNextRecord(sslSocket *ss, ssl3CipherSpec *spec, SSLContentType ct, * Returns the number of bytes of plaintext that were successfully sent * plus the number of bytes of plaintext that were copied into the * output (write) buffer. - * Returns SECFailure on a hard IO error, memory error, or crypto error. - * Does NOT return SECWouldBlock. + * Returns -1 on an error. PR_WOULD_BLOCK_ERROR is set if the error is blocking + * and not terminal. * * Notes on the use of the private ssl flags: * (no private SSL flags) @@ -2360,13 +2360,26 @@ ssl3_SendRecord(sslSocket *ss, * error, so don't overwrite. */ PORT_SetError(SSL_ERROR_HANDSHAKE_FAILED); } - return SECFailure; + return -1; } /* check for Token Presence */ if (!ssl3_ClientAuthTokenPresent(ss->sec.ci.sid)) { PORT_SetError(SSL_ERROR_TOKEN_INSERTION_REMOVAL); - return SECFailure; + return -1; + } + + if (ss->recordWriteCallback) { + PRUint16 epoch; + ssl_GetSpecReadLock(ss); + epoch = ss->ssl3.cwSpec->epoch; + ssl_ReleaseSpecReadLock(ss); + rv = ss->recordWriteCallback(ss->fd, epoch, ct, pIn, nIn, + ss->recordWriteCallbackArg); + if (rv != SECSuccess) { + return -1; + } + return nIn; } if (cwSpec) { @@ -2470,7 +2483,7 @@ loser: #define SSL3_PENDING_HIGH_WATER 1024 /* Attempt to send the content of "in" in an SSL application_data record. - * Returns "len" or SECFailure, never SECWouldBlock, nor SECSuccess. + * Returns "len" or -1 on failure. */ int ssl3_SendApplicationData(sslSocket *ss, const unsigned char *in, @@ -2485,21 +2498,21 @@ ssl3_SendApplicationData(sslSocket *ss, const unsigned char *in, PORT_Assert(!(flags & ssl_SEND_FLAG_NO_RETRANSMIT)); if (len < 0 || !in) { PORT_SetError(PR_INVALID_ARGUMENT_ERROR); - return SECFailure; + return -1; } if (ss->pendingBuf.len > SSL3_PENDING_HIGH_WATER && !ssl_SocketIsBlocking(ss)) { PORT_Assert(!ssl_SocketIsBlocking(ss)); PORT_SetError(PR_WOULD_BLOCK_ERROR); - return SECFailure; + return -1; } if (ss->appDataBuffered && len) { PORT_Assert(in[0] == (unsigned char)(ss->appDataBuffered)); if (in[0] != (unsigned char)(ss->appDataBuffered)) { PORT_SetError(PR_INVALID_ARGUMENT_ERROR); - return SECFailure; + return -1; } in++; len--; @@ -2548,7 +2561,7 @@ ssl3_SendApplicationData(sslSocket *ss, const unsigned char *in, PORT_Assert(ss->lastWriteBlocked); break; } - return SECFailure; /* error code set by ssl3_SendRecord */ + return -1; /* error code set by ssl3_SendRecord */ } totalSent += sent; if (ss->pendingBuf.len) { @@ -2577,7 +2590,6 @@ ssl3_SendApplicationData(sslSocket *ss, const unsigned char *in, } /* Attempt to send buffered handshake messages. - * This function returns SECSuccess or SECFailure, never SECWouldBlock. * Always set sendBuf.len to 0, even when returning SECFailure. * * Depending on whether we are doing DTLS or not, this either calls @@ -2600,7 +2612,6 @@ ssl3_FlushHandshake(sslSocket *ss, PRInt32 flags) } /* Attempt to send the content of sendBuf buffer in an SSL handshake record. - * This function returns SECSuccess or SECFailure, never SECWouldBlock. * Always set sendBuf.len to 0, even when returning SECFailure. * * Called from ssl3_FlushHandshake @@ -7603,7 +7614,8 @@ ssl3_SendClientSecondRound(sslSocket *ss) " certificate authentication is still pending.", SSL_GETPID(), ss->fd)); ss->ssl3.hs.restartTarget = ssl3_SendClientSecondRound; - return SECWouldBlock; + PORT_SetError(PR_WOULD_BLOCK_ERROR); + return SECFailure; } ssl_GetXmitBufLock(ss); /*******************************/ @@ -10898,13 +10910,6 @@ ssl3_AuthCertificateComplete(sslSocket *ss, PRErrorCode error) } rv = target(ss); - /* Even if we blocked here, we have accomplished enough to claim - * success. Any remaining work will be taken care of by subsequent - * calls to SSL_ForceHandshake/PR_Send/PR_Read/etc. - */ - if (rv == SECWouldBlock) { - rv = SECSuccess; - } } else { SSL_TRC(3, ("%d: SSL3[%p]: certificate authentication won the race with" " peer's finished message", @@ -11445,7 +11450,8 @@ xmit_loser: } ss->ssl3.hs.restartTarget = ssl3_FinishHandshake; - return SECWouldBlock; + PORT_SetError(PR_WOULD_BLOCK_ERROR); + return SECFailure; } rv = ssl3_FinishHandshake(ss); @@ -11649,9 +11655,10 @@ ssl3_HandleHandshakeMessage(sslSocket *ss, PRUint8 *b, PRUint32 length, * authenticate the certificate in ssl3_HandleCertificateStatus. */ rv = ssl3_AuthCertificate(ss); /* sets ss->ssl3.hs.ws */ - PORT_Assert(rv != SECWouldBlock); if (rv != SECSuccess) { - return rv; + /* This can't block. */ + PORT_Assert(PORT_GetError() != PR_WOULD_BLOCK_ERROR); + return SECFailure; } } @@ -11809,28 +11816,17 @@ ssl3_HandlePostHelloHandshakeMessage(sslSocket *ss, PRUint8 *b, static SECStatus ssl3_HandleHandshake(sslSocket *ss, sslBuffer *origBuf) { - /* - * There may be a partial handshake message already in the handshake - * state. The incoming buffer may contain another portion, or a - * complete message or several messages followed by another portion. - * - * Each message is made contiguous before being passed to the actual - * message parser. - */ - sslBuffer *buf = &ss->ssl3.hs.msgState; /* do not lose the original buffer pointer */ + sslBuffer buf = *origBuf; /* Work from a copy. */ SECStatus rv; PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss)); PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss)); - if (buf->buf == NULL) { - *buf = *origBuf; - } - while (buf->len > 0) { + while (buf.len > 0) { if (ss->ssl3.hs.header_bytes < 4) { PRUint8 t; - t = *(buf->buf++); - buf->len--; + t = *(buf.buf++); + buf.len--; if (ss->ssl3.hs.header_bytes++ == 0) ss->ssl3.hs.msg_type = (SSLHandshakeType)t; else @@ -11847,7 +11843,7 @@ ssl3_HandleHandshake(sslSocket *ss, sslBuffer *origBuf) #undef MAX_HANDSHAKE_MSG_LEN /* If msg_len is zero, be sure we fall through, - ** even if buf->len is zero. + ** even if buf.len is zero. */ if (ss->ssl3.hs.msg_len > 0) continue; @@ -11858,22 +11854,15 @@ ssl3_HandleHandshake(sslSocket *ss, sslBuffer *origBuf) * data available for this message. If it can be done right out * of the original buffer, then use it from there. */ - if (ss->ssl3.hs.msg_body.len == 0 && buf->len >= ss->ssl3.hs.msg_len) { + if (ss->ssl3.hs.msg_body.len == 0 && buf.len >= ss->ssl3.hs.msg_len) { /* handle it from input buffer */ - rv = ssl3_HandleHandshakeMessage(ss, buf->buf, ss->ssl3.hs.msg_len, - buf->len == ss->ssl3.hs.msg_len); - if (rv == SECFailure) { - /* This test wants to fall through on either - * SECSuccess or SECWouldBlock. - * ssl3_HandleHandshakeMessage MUST set the error code. - */ - return rv; - } - buf->buf += ss->ssl3.hs.msg_len; - buf->len -= ss->ssl3.hs.msg_len; + rv = ssl3_HandleHandshakeMessage(ss, buf.buf, ss->ssl3.hs.msg_len, + buf.len == ss->ssl3.hs.msg_len); + buf.buf += ss->ssl3.hs.msg_len; + buf.len -= ss->ssl3.hs.msg_len; ss->ssl3.hs.msg_len = 0; ss->ssl3.hs.header_bytes = 0; - if (rv != SECSuccess) { /* return if SECWouldBlock. */ + if (rv != SECSuccess) { return rv; } } else { @@ -11881,7 +11870,7 @@ ssl3_HandleHandshake(sslSocket *ss, sslBuffer *origBuf) unsigned int bytes; PORT_Assert(ss->ssl3.hs.msg_body.len < ss->ssl3.hs.msg_len); - bytes = PR_MIN(buf->len, ss->ssl3.hs.msg_len - ss->ssl3.hs.msg_body.len); + bytes = PR_MIN(buf.len, ss->ssl3.hs.msg_len - ss->ssl3.hs.msg_body.len); /* Grow the buffer if needed */ rv = sslBuffer_Grow(&ss->ssl3.hs.msg_body, ss->ssl3.hs.msg_len); @@ -11891,10 +11880,10 @@ ssl3_HandleHandshake(sslSocket *ss, sslBuffer *origBuf) } PORT_Memcpy(ss->ssl3.hs.msg_body.buf + ss->ssl3.hs.msg_body.len, - buf->buf, bytes); + buf.buf, bytes); ss->ssl3.hs.msg_body.len += bytes; - buf->buf += bytes; - buf->len -= bytes; + buf.buf += bytes; + buf.len -= bytes; PORT_Assert(ss->ssl3.hs.msg_body.len <= ss->ssl3.hs.msg_len); @@ -11902,29 +11891,21 @@ ssl3_HandleHandshake(sslSocket *ss, sslBuffer *origBuf) if (ss->ssl3.hs.msg_body.len == ss->ssl3.hs.msg_len) { rv = ssl3_HandleHandshakeMessage( ss, ss->ssl3.hs.msg_body.buf, ss->ssl3.hs.msg_len, - buf->len == 0); - if (rv == SECFailure) { - /* This test wants to fall through on either - * SECSuccess or SECWouldBlock. - * ssl3_HandleHandshakeMessage MUST set error code. - */ - return rv; - } + buf.len == 0); ss->ssl3.hs.msg_body.len = 0; ss->ssl3.hs.msg_len = 0; ss->ssl3.hs.header_bytes = 0; - if (rv != SECSuccess) { /* return if SECWouldBlock. */ + if (rv != SECSuccess) { return rv; } } else { - PORT_Assert(buf->len == 0); + PORT_Assert(buf.len == 0); break; } } } /* end loop */ origBuf->len = 0; /* So ssl3_GatherAppDataRecord will keep looping. */ - buf->buf = NULL; /* not a leak. */ return SECSuccess; } @@ -12372,7 +12353,7 @@ ssl3_HandleNonApplicationData(sslSocket *ss, SSLContentType rType, ssl_GetSSL3HandshakeLock(ss); /* All the functions called in this switch MUST set error code if - ** they return SECFailure or SECWouldBlock. + ** they return SECFailure. */ switch (rType) { case ssl_ct_change_cipher_spec: diff --git a/lib/ssl/ssl3gthr.c b/lib/ssl/ssl3gthr.c index 64a1878f7..e7ca65a36 100644 --- a/lib/ssl/ssl3gthr.c +++ b/lib/ssl/ssl3gthr.c @@ -389,7 +389,6 @@ dtls_GatherData(sslSocket *ss, sslGather *gs, int flags) * application data is available. * Returns 0 if ssl3_GatherData hits EOF. * Returns -1 on read error, or PR_WOULD_BLOCK_ERROR, or handleRecord error. - * Returns -2 on SECWouldBlock return from ssl3_HandleRecord. * * Called from ssl_GatherRecord1stHandshake in sslcon.c, * and from SSL_ForceHandshake in sslsecur.c @@ -422,13 +421,19 @@ ssl3_GatherCompleteHandshake(sslSocket *ss, int flags) PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss)); do { - PRBool handleRecordNow = PR_FALSE; PRBool processingEarlyData; ssl_GetSSL3HandshakeLock(ss); processingEarlyData = ss->ssl3.hs.zeroRttState == ssl_0rtt_accepted; + /* If we have a detached record layer, don't ever gather. */ + if (ss->recordWriteCallback) { + ssl_ReleaseSSL3HandshakeLock(ss); + PORT_SetError(PR_WOULD_BLOCK_ERROR); + return (int)SECFailure; + } + /* Without this, we may end up wrongly reporting * SSL_ERROR_RX_UNEXPECTED_* errors if we receive any records from the * peer while we are waiting to be restarted. @@ -439,93 +444,68 @@ ssl3_GatherCompleteHandshake(sslSocket *ss, int flags) return (int)SECFailure; } - /* Treat an empty msgState like a NULL msgState. (Most of the time - * when ssl3_HandleHandshake returns SECWouldBlock, it leaves - * behind a non-NULL but zero-length msgState). - * Test: async_cert_restart_server_sends_hello_request_first_in_separate_record - */ - if (ss->ssl3.hs.msgState.buf) { - if (ss->ssl3.hs.msgState.len == 0) { - ss->ssl3.hs.msgState.buf = NULL; - } else { - handleRecordNow = PR_TRUE; - } - } - ssl_ReleaseSSL3HandshakeLock(ss); - if (handleRecordNow) { - /* ssl3_HandleHandshake previously returned SECWouldBlock and the - * as-yet-unprocessed plaintext of that previous handshake record. - * We need to process it now before we overwrite it with the next - * handshake record. - */ - SSL_DBG(("%d: SSL3[%d]: resuming handshake", - SSL_GETPID(), ss->fd)); - PORT_Assert(!IS_DTLS(ss)); - rv = ssl3_HandleNonApplicationData(ss, ssl_ct_handshake, - 0, 0, &ss->gs.buf); - } else { - /* State for SSLv2 client hello support. */ - ssl2Gather ssl2gs = { PR_FALSE, 0 }; - ssl2Gather *ssl2gs_ptr = NULL; + /* State for SSLv2 client hello support. */ + ssl2Gather ssl2gs = { PR_FALSE, 0 }; + ssl2Gather *ssl2gs_ptr = NULL; - if (ss->sec.isServer && ss->opt.enableV2CompatibleHello && - ss->ssl3.hs.ws == wait_client_hello) { - ssl2gs_ptr = &ssl2gs; - } + /* If we're a server and waiting for a client hello, accept v2. */ + if (ss->sec.isServer && ss->opt.enableV2CompatibleHello && + ss->ssl3.hs.ws == wait_client_hello) { + ssl2gs_ptr = &ssl2gs; + } - /* bring in the next sslv3 record. */ - if (ss->recvdCloseNotify) { - /* RFC 5246 Section 7.2.1: - * Any data received after a closure alert is ignored. - */ - return 0; - } + /* bring in the next sslv3 record. */ + if (ss->recvdCloseNotify) { + /* RFC 5246 Section 7.2.1: + * Any data received after a closure alert is ignored. + */ + return 0; + } - if (!IS_DTLS(ss)) { - /* Passing a non-NULL ssl2gs here enables detection of - * SSLv2-compatible ClientHello messages. */ - rv = ssl3_GatherData(ss, &ss->gs, flags, ssl2gs_ptr); - } else { - rv = dtls_GatherData(ss, &ss->gs, flags); - - /* If we got a would block error, that means that no data was - * available, so we check the timer to see if it's time to - * retransmit */ - if (rv == SECFailure && - (PORT_GetError() == PR_WOULD_BLOCK_ERROR)) { - dtls_CheckTimer(ss); - /* Restore the error in case something succeeded */ - PORT_SetError(PR_WOULD_BLOCK_ERROR); - } + if (!IS_DTLS(ss)) { + /* If we're a server waiting for a ClientHello then pass + * ssl2gs to support SSLv2 ClientHello messages. */ + rv = ssl3_GatherData(ss, &ss->gs, flags, ssl2gs_ptr); + } else { + rv = dtls_GatherData(ss, &ss->gs, flags); + + /* If we got a would block error, that means that no data was + * available, so we check the timer to see if it's time to + * retransmit */ + if (rv == SECFailure && + (PORT_GetError() == PR_WOULD_BLOCK_ERROR)) { + dtls_CheckTimer(ss); + /* Restore the error in case something succeeded */ + PORT_SetError(PR_WOULD_BLOCK_ERROR); } + } - if (rv <= 0) { - return rv; - } + if (rv <= 0) { + return rv; + } - if (ssl2gs.isV2) { - rv = ssl3_HandleV2ClientHello(ss, ss->gs.inbuf.buf, - ss->gs.inbuf.len, - ssl2gs.padding); - if (rv < 0) { - return rv; - } - } else { - /* decipher it, and handle it if it's a handshake. - * If it's application data, ss->gs.buf will not be empty upon return. - * If it's a change cipher spec, alert, or handshake message, - * ss->gs.buf.len will be 0 when ssl3_HandleRecord returns SECSuccess. - * - * cText only needs to be valid for this next function call, so - * it can borrow gs.hdr. - */ - cText.hdr = ss->gs.hdr; - cText.hdrLen = ss->gs.hdrLen; - cText.buf = &ss->gs.inbuf; - rv = ssl3_HandleRecord(ss, &cText); + if (ssl2gs.isV2) { + rv = ssl3_HandleV2ClientHello(ss, ss->gs.inbuf.buf, + ss->gs.inbuf.len, + ssl2gs.padding); + if (rv < 0) { + return rv; } + } else { + /* decipher it, and handle it if it's a handshake. + * If it's application data, ss->gs.buf will not be empty upon return. + * If it's a change cipher spec, alert, or handshake message, + * ss->gs.buf.len will be 0 when ssl3_HandleRecord returns SECSuccess. + * + * cText only needs to be valid for this next function call, so + * it can borrow gs.hdr. + */ + cText.hdr = ss->gs.hdr; + cText.hdrLen = ss->gs.hdrLen; + cText.buf = &ss->gs.inbuf; + rv = ssl3_HandleRecord(ss, &cText); } if (rv < 0) { return ss->recvdCloseNotify ? 0 : rv; @@ -575,7 +555,7 @@ ssl3_GatherCompleteHandshake(sslSocket *ss, int flags) * delivered to the application before the handshake completes. */ ssl_ReleaseSSL3HandshakeLock(ss); PORT_SetError(PR_WOULD_BLOCK_ERROR); - return SECWouldBlock; + return -1; } ssl_ReleaseSSL3HandshakeLock(ss); } while (keepGoing); @@ -596,7 +576,6 @@ ssl3_GatherCompleteHandshake(sslSocket *ss, int flags) * Returns 1 when application data is available. * Returns 0 if ssl3_GatherData hits EOF. * Returns -1 on read error, or PR_WOULD_BLOCK_ERROR, or handleRecord error. - * Returns -2 on SECWouldBlock return from ssl3_HandleRecord. * * Called from DoRecv in sslsecur.c * Caller must hold the recv buf lock. @@ -616,3 +595,88 @@ ssl3_GatherAppDataRecord(sslSocket *ss, int flags) return rv; } + +SECStatus +SSLExp_RecordLayerData(PRFileDesc *fd, PRUint16 epoch, + SSLContentType contentType, + const PRUint8 *data, unsigned int len) +{ + SECStatus rv; + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + if (IS_DTLS(ss) || data == NULL || len == 0) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; + } + + /* Run any handshake function. If SSL_RecordLayerData is the only way that + * the handshake is driven, then this is necessary to ensure that + * ssl_BeginClientHandshake or ssl_BeginServerHandshake is called. Note that + * the other function that might be set to ss->handshake, + * ssl3_GatherCompleteHandshake, does nothing when this function is used. */ + ssl_Get1stHandshakeLock(ss); + rv = ssl_Do1stHandshake(ss); + if (rv != SECSuccess && PORT_GetError() != PR_WOULD_BLOCK_ERROR) { + goto early_loser; /* Rely on the existing code. */ + } + + /* Don't allow application data before handshake completion. */ + if (contentType == ssl_ct_application_data && !ss->firstHsDone) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + goto early_loser; + } + + /* Then we can validate the epoch. */ + PRErrorCode epochError; + ssl_GetSpecReadLock(ss); + if (epoch < ss->ssl3.crSpec->epoch) { + epochError = SEC_ERROR_INVALID_ARGS; /* Too c/old. */ + } else if (epoch > ss->ssl3.crSpec->epoch) { + epochError = PR_WOULD_BLOCK_ERROR; /* Too warm/new. */ + } else { + epochError = 0; /* Just right. */ + } + ssl_ReleaseSpecReadLock(ss); + if (epochError) { + PORT_SetError(epochError); + goto early_loser; + } + + /* If the handshake is still running, we need to run that. */ + ssl_Get1stHandshakeLock(ss); + rv = ssl_Do1stHandshake(ss); + if (rv != SECSuccess && PORT_GetError() != PR_WOULD_BLOCK_ERROR) { + ssl_Release1stHandshakeLock(ss); + return SECFailure; + } + + /* Finally, save the data... */ + ssl_GetRecvBufLock(ss); + rv = sslBuffer_Append(&ss->gs.buf, data, len); + if (rv != SECSuccess) { + goto loser; + } + + /* ...and process it. Just saving application data is enough for it to be + * available to PR_Read(). */ + if (contentType != ssl_ct_application_data) { + rv = ssl3_HandleNonApplicationData(ss, contentType, 0, 0, &ss->gs.buf); + if (rv != SECSuccess) { + goto loser; + } + } + + ssl_ReleaseRecvBufLock(ss); + ssl_Release1stHandshakeLock(ss); + return SECSuccess; + +loser: + /* Make sure that any data is not used again. */ + ss->gs.buf.len = 0; + ssl_ReleaseRecvBufLock(ss); +early_loser: + ssl_Release1stHandshakeLock(ss); + return SECFailure; +} diff --git a/lib/ssl/sslcon.c b/lib/ssl/sslcon.c index bc63e1537..603da3e15 100644 --- a/lib/ssl/sslcon.c +++ b/lib/ssl/sslcon.c @@ -44,17 +44,13 @@ const char *ssl_version = "SECURITY_VERSION:" * This function acquires and releases the RecvBufLock. * * returns SECSuccess for success. - * returns SECWouldBlock when that value is returned by - * ssl3_GatherCompleteHandshake(). - * returns SECFailure on all other errors. + * returns SECFailure on error, setting PR_WOULD_BLOCK_ERROR if only blocked. * * The gather functions called by ssl_GatherRecord1stHandshake are expected * to return values interpreted as follows: * 1 : the function completed without error. * 0 : the function read EOF. * -1 : read error, or PR_WOULD_BLOCK_ERROR, or handleRecord error. - * -2 : the function wants ssl_GatherRecord1stHandshake to be called again - * immediately, by ssl_Do1stHandshake. * * This code is similar to, and easily confused with, DoRecv() in sslsecur.c * @@ -82,16 +78,14 @@ ssl_GatherRecord1stHandshake(sslSocket *ss) ssl_ReleaseRecvBufLock(ss); if (rv <= 0) { - if (rv == SECWouldBlock) { - /* Progress is blocked waiting for callback completion. */ - SSL_TRC(10, ("%d: SSL[%d]: handshake blocked (need %d)", - SSL_GETPID(), ss->fd, ss->gs.remainder)); - return SECWouldBlock; - } if (rv == 0) { /* EOF. Loser */ PORT_SetError(PR_END_OF_FILE_ERROR); } + if (PORT_GetError() == PR_WOULD_BLOCK_ERROR) { + SSL_TRC(10, ("%d: SSL[%d]: handshake blocked (need %d)", + SSL_GETPID(), ss->fd, ss->gs.remainder)); + } return SECFailure; /* rv is < 0 here. */ } diff --git a/lib/ssl/ssldef.c b/lib/ssl/ssldef.c index be5bcb269..3ed197950 100644 --- a/lib/ssl/ssldef.c +++ b/lib/ssl/ssldef.c @@ -84,7 +84,7 @@ ssl_DefRecv(sslSocket *ss, unsigned char *buf, int len, int flags) * For blocking sockets, always returns len or SECFailure, no short writes. * For non-blocking sockets: * Returns positive count if any data was written, else returns SECFailure. - * Short writes may occur. Does not return SECWouldBlock. + * Short writes may occur. */ int ssl_DefSend(sslSocket *ss, const unsigned char *buf, int len, int flags) diff --git a/lib/ssl/sslexp.h b/lib/ssl/sslexp.h index d3b2e854d..8965f5606 100644 --- a/lib/ssl/sslexp.h +++ b/lib/ssl/sslexp.h @@ -536,6 +536,63 @@ typedef void(PR_CALLBACK *SSLSecretCallback)( (PRFileDesc * _fd, SSLSecretCallback _cb, void *_arg), \ (fd, cb, arg)) +/* SSL_RecordLayerWriteCallback() is used to replace the TLS record layer. This + * function installs a callback that TLS calls when it would otherwise encrypt + * and write a record to the underlying NSPR IO layer. The application is + * responsible for ensuring that these records are encrypted and written. + * + * Calling this API also disables reads from the underlying NSPR layer. The + * application is expected to push data when it is available using + * SSL_RecordLayerData(). + * + * When data would be written, the provided SSLRecordWriteCallback with the + * epoch, TLS content type, and the data. The data provided to the callback is + * not split into record-sized writes. If the callback returns SECFailure, the + * write will be considered to have failed; in particular, PR_WOULD_BLOCK_ERROR + * is not handled specially. + * + * If TLS 1.3 is in use, the epoch indicates the expected level of protection + * that the record would receive, this matches that used in DTLS 1.3: + * + * - epoch 0 corresponds to no record protection + * - epoch 1 corresponds to 0-RTT + * - epoch 2 corresponds to TLS handshake + * - epoch 3 and higher are application data + * + * Prior versions of TLS use epoch 1 and higher for application data. + * + * This API is not supported for DTLS sockets. + */ +typedef SECStatus(PR_CALLBACK *SSLRecordWriteCallback)( + PRFileDesc *fd, PRUint16 epoch, SSLContentType contentType, + const PRUint8 *data, unsigned int len, void *arg); + +#define SSL_RecordLayerWriteCallback(fd, writeCb, arg) \ + SSL_EXPERIMENTAL_API("SSL_RecordLayerWriteCallback", \ + (PRFileDesc * _fd, SSLRecordWriteCallback _wCb, \ + void *_arg), \ + (fd, writeCb, arg)) + +/* SSL_RecordLayerData() is used to provide new data to TLS. The application + * indicates the epoch (see the description of SSL_RecordLayerWriteCallback()), + * content type, and the data that was received. The application is responsible + * for removing any encryption or other protection before passing data to this + * function. + * + * This returns SECSuccess if the data was successfully processed. If this + * function is used to drive the handshake and the caller needs to know when the + * handshake is complete, a call to SSL_ForceHandshake will return SECSuccess + * when the handshake is complete. + * + * This API is not supported for DTLS sockets. + */ +#define SSL_RecordLayerData(fd, epoch, ct, data, len) \ + SSL_EXPERIMENTAL_API("SSL_RecordLayerData", \ + (PRFileDesc * _fd, PRUint16 _epoch, \ + SSLContentType _contentType, \ + const PRUint8 *_data, unsigned int _len), \ + (fd, epoch, ct, data, len)) + /* Deprecated experimental APIs */ #define SSL_UseAltServerHelloType(fd, enable) SSL_DEPRECATED_EXPERIMENTAL_API diff --git a/lib/ssl/sslimpl.h b/lib/ssl/sslimpl.h index 5f7743c75..f89b2fe83 100644 --- a/lib/ssl/sslimpl.h +++ b/lib/ssl/sslimpl.h @@ -622,8 +622,6 @@ typedef struct SSL3HandshakeStateStr { unsigned long msg_len; PRBool isResuming; /* we are resuming (not used in TLS 1.3) */ PRBool sendingSCSV; /* instead of empty RI */ - sslBuffer msgState; /* current state for handshake messages*/ - /* protected by recvBufLock */ /* The session ticket received in a NewSessionTicket message is temporarily * stored in newSessionTicket until the handshake is finished; then it is @@ -996,6 +994,8 @@ struct sslSocketStr { void *resumptionTokenContext; SSLSecretCallback secretCallback; void *secretCallbackArg; + SSLRecordWriteCallback recordWriteCallback; + void *recordWriteCallbackArg; PRIntervalTime rTimeout; /* timeout for NSPR I/O */ PRIntervalTime wTimeout; /* timeout for NSPR I/O */ @@ -1176,7 +1176,7 @@ extern SECStatus ssl_SaveWriteData(sslSocket *ss, const void *p, unsigned int l); extern SECStatus ssl_BeginClientHandshake(sslSocket *ss); extern SECStatus ssl_BeginServerHandshake(sslSocket *ss); -extern int ssl_Do1stHandshake(sslSocket *ss); +extern SECStatus ssl_Do1stHandshake(sslSocket *ss); extern SECStatus ssl3_InitPendingCipherSpecs(sslSocket *ss, PK11SymKey *secret, PRBool derive); @@ -1744,7 +1744,14 @@ SECStatus SSLExp_GetResumptionTokenInfo(const PRUint8 *tokenData, unsigned int t SECStatus SSLExp_DestroyResumptionTokenInfo(SSLResumptionTokenInfo *token); -SECStatus SSLExp_SecretCallback(PRFileDesc *fd, SSLSecretCallback cb, void *arg); +SECStatus SSLExp_SecretCallback(PRFileDesc *fd, SSLSecretCallback cb, + void *arg); +SECStatus SSLExp_RecordLayerWriteCallback(PRFileDesc *fd, + SSLRecordWriteCallback write, + void *arg); +SECStatus SSLExp_RecordLayerData(PRFileDesc *fd, PRUint16 epoch, + SSLContentType contentType, + const PRUint8 *data, unsigned int len); #define SSLResumptionTokenVersion 2 diff --git a/lib/ssl/sslsecur.c b/lib/ssl/sslsecur.c index b6c2c3523..79634e403 100644 --- a/lib/ssl/sslsecur.c +++ b/lib/ssl/sslsecur.c @@ -16,32 +16,7 @@ #include "nss.h" /* for NSS_RegisterShutdown */ #include "prinit.h" /* for PR_CallOnceWithArg */ -/* Returns a SECStatus: SECSuccess or SECFailure, NOT SECWouldBlock. - * - * Currently, the list of functions called through ss->handshake is: - * - * In sslsocks.c: - * SocksGatherRecord - * SocksHandleReply - * SocksStartGather - * - * In sslcon.c: - * ssl_GatherRecord1stHandshake - * ssl_BeginClientHandshake - * ssl_BeginServerHandshake - * - * The ss->handshake function returns SECWouldBlock if it was returned by - * one of the callback functions, via one of these paths: - * - * - ssl_GatherRecord1stHandshake() -> ssl3_GatherCompleteHandshake() -> - * ssl3_HandleRecord() -> ssl3_HandleHandshake() -> - * ssl3_HandleHandshakeMessage() -> ssl3_HandleCertificate() -> - * ss->handleBadCert() - * - * - ssl_GatherRecord1stHandshake() -> ssl3_GatherCompleteHandshake() -> - * ssl3_HandleRecord() -> ssl3_HandleHandshake() -> - * ssl3_HandleHandshakeMessage() -> ssl3_HandleCertificateRequest() -> - * ss->getClientAuthData() +/* Step through the handshake functions. * * Called from: SSL_ForceHandshake (below), * ssl_SecureRecv (below) and @@ -52,10 +27,10 @@ * * Caller must hold the (write) handshakeLock. */ -int +SECStatus ssl_Do1stHandshake(sslSocket *ss) { - int rv = SECSuccess; + SECStatus rv = SECSuccess; while (ss->handshake && rv == SECSuccess) { PORT_Assert(ss->opt.noLocks || ssl_Have1stHandshakeLock(ss)); @@ -70,10 +45,6 @@ ssl_Do1stHandshake(sslSocket *ss) PORT_Assert(ss->opt.noLocks || !ssl_HaveXmitBufLock(ss)); PORT_Assert(ss->opt.noLocks || !ssl_HaveSSL3HandshakeLock(ss)); - if (rv == SECWouldBlock) { - PORT_SetError(PR_WOULD_BLOCK_ERROR); - rv = SECFailure; - } return rv; } @@ -106,8 +77,8 @@ ssl_FinishHandshake(sslSocket *ss) static SECStatus ssl3_AlwaysBlock(sslSocket *ss) { - PORT_SetError(PR_WOULD_BLOCK_ERROR); /* perhaps redundant. */ - return SECWouldBlock; + PORT_SetError(PR_WOULD_BLOCK_ERROR); + return SECFailure; } /* @@ -400,10 +371,13 @@ SSL_ForceHandshake(PRFileDesc *fd) ssl_ReleaseRecvBufLock(ss); if (gatherResult > 0) { rv = SECSuccess; - } else if (gatherResult == 0) { - PORT_SetError(PR_END_OF_FILE_ERROR); - } else if (gatherResult == SECWouldBlock) { - PORT_SetError(PR_WOULD_BLOCK_ERROR); + } else { + if (gatherResult == 0) { + PORT_SetError(PR_END_OF_FILE_ERROR); + } + /* We can rely on ssl3_GatherCompleteHandshake to set + * PR_WOULD_BLOCK_ERROR as needed here. */ + rv = SECFailure; } } else { PORT_Assert(!ss->firstHsDone); @@ -515,8 +489,7 @@ DoRecv(sslSocket *ss, unsigned char *out, int len, int flags) SSL_GETPID(), ss->fd)); goto done; } - if ((rv != SECWouldBlock) && - (PR_GetError() != PR_WOULD_BLOCK_ERROR)) { + if (PR_GetError() != PR_WOULD_BLOCK_ERROR) { /* Some random error */ goto done; } @@ -1011,6 +984,35 @@ ssl_SecureWrite(sslSocket *ss, const unsigned char *buf, int len) } SECStatus +SSLExp_RecordLayerWriteCallback(PRFileDesc *fd, SSLRecordWriteCallback cb, + void *arg) +{ + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + SSL_DBG(("%d: SSL[%d]: invalid socket for SSL_RecordLayerWriteCallback", + SSL_GETPID(), fd)); + return SECFailure; + } + if (IS_DTLS(ss)) { + SSL_DBG(("%d: SSL[%d]: DTLS socket for SSL_RecordLayerWriteCallback", + SSL_GETPID(), fd)); + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; + } + + /* This needs both HS and Xmit locks because this value is checked under + * both locks. HS to disable reading from the underlying IO layer; Xmit to + * prevent writing. */ + ssl_GetSSL3HandshakeLock(ss); + ssl_GetXmitBufLock(ss); + ss->recordWriteCallback = cb; + ss->recordWriteCallbackArg = arg; + ssl_ReleaseXmitBufLock(ss); + ssl_ReleaseSSL3HandshakeLock(ss); + return SECSuccess; +} + +SECStatus SSL_AlertReceivedCallback(PRFileDesc *fd, SSLAlertCallback cb, void *arg) { sslSocket *ss; diff --git a/lib/ssl/sslsock.c b/lib/ssl/sslsock.c index ad9b62fa9..4d4b7e1e6 100644 --- a/lib/ssl/sslsock.c +++ b/lib/ssl/sslsock.c @@ -4034,6 +4034,8 @@ struct { EXP(HelloRetryRequestCallback), EXP(InstallExtensionHooks), EXP(KeyUpdate), + EXP(RecordLayerData), + EXP(RecordLayerWriteCallback), EXP(SecretCallback), EXP(SendSessionTicket), EXP(SetESNIKeyPair), diff --git a/lib/ssl/tls13con.c b/lib/ssl/tls13con.c index 5737e6948..cbf149372 100644 --- a/lib/ssl/tls13con.c +++ b/lib/ssl/tls13con.c @@ -4459,7 +4459,8 @@ tls13_SendClientSecondRound(sslSocket *ss) " certificate authentication is still pending.", SSL_GETPID(), ss->fd)); ss->ssl3.hs.restartTarget = tls13_SendClientSecondRound; - return SECWouldBlock; + PORT_SetError(PR_WOULD_BLOCK_ERROR); + return SECFailure; } rv = tls13_ComputeApplicationSecrets(ss); |