diff options
Diffstat (limited to 'chromium/services/network/websocket.cc')
-rw-r--r-- | chromium/services/network/websocket.cc | 152 |
1 files changed, 100 insertions, 52 deletions
diff --git a/chromium/services/network/websocket.cc b/chromium/services/network/websocket.cc index ee535a0aa2c..dcf9d7ccb77 100644 --- a/chromium/services/network/websocket.cc +++ b/chromium/services/network/websocket.cc @@ -11,9 +11,11 @@ #include "base/bind.h" #include "base/bind_helpers.h" +#include "base/feature_list.h" #include "base/location.h" #include "base/logging.h" #include "base/macros.h" +#include "base/numerics/safe_conversions.h" #include "base/single_thread_task_runner.h" #include "base/strings/strcat.h" #include "base/strings/string_number_conversions.h" @@ -36,11 +38,16 @@ #include "net/websockets/websocket_frame.h" // for WebSocketFrameHeader::OpCode #include "net/websockets/websocket_handshake_request_info.h" #include "net/websockets/websocket_handshake_response_info.h" +#include "services/network/public/cpp/features.h" #include "services/network/websocket_factory.h" namespace network { namespace { +// What is considered a "small message" for the purposes of small message +// reassembly. +constexpr uint64_t kSmallMessageThreshhold = 1 << 16; + // Convert a mojom::WebSocketMessageType to a // net::WebSocketFrameHeader::OpCode net::WebSocketFrameHeader::OpCode MessageTypeToOpCode( @@ -118,14 +125,13 @@ class WebSocket::WebSocketEventHandler final void OnAddChannelResponse( std::unique_ptr<net::WebSocketHandshakeResponseInfo> response, const std::string& selected_subprotocol, - const std::string& extensions, - int64_t send_flow_control_quota) override; + const std::string& extensions) override; void OnDataFrame(bool fin, WebSocketMessageType type, base::span<const char> payload) override; + void OnSendDataFrameDone() override; bool HasPendingDataFrames() override; void OnClosingHandshake() override; - void OnSendFlowControlQuotaAdded(int64_t quota) override; void OnDropChannel(bool was_clean, uint16_t code, const std::string& reason) override; @@ -172,8 +178,7 @@ void WebSocket::WebSocketEventHandler::OnCreateURLRequest( void WebSocket::WebSocketEventHandler::OnAddChannelResponse( std::unique_ptr<net::WebSocketHandshakeResponseInfo> response, const std::string& selected_protocol, - const std::string& extensions, - int64_t send_flow_control_quota) { + const std::string& extensions) { DVLOG(3) << "WebSocketEventHandler::OnAddChannelResponse @" << reinterpret_cast<void*>(this) << " selected_protocol=\"" << selected_protocol << "\"" @@ -206,6 +211,7 @@ void WebSocket::WebSocketEventHandler::OnAddChannelResponse( impl_->Reset(); return; } + impl_->data_pipe_use_tracker_.Activate(); const MojoResult mojo_result = impl_->writable_watcher_.Watch( impl_->writable_.get(), MOJO_HANDLE_SIGNAL_WRITABLE, MOJO_WATCH_CONDITION_SATISFIED, @@ -241,8 +247,6 @@ void WebSocket::WebSocketEventHandler::OnAddChannelResponse( impl_->header_client_.reset(); impl_->client_.set_disconnect_handler(base::BindOnce( &WebSocket::OnConnectionError, base::Unretained(impl_), FROM_HERE)); - - impl_->client_->AddSendFlowControlQuota(send_flow_control_quota); } void WebSocket::WebSocketEventHandler::OnDataFrame( @@ -259,6 +263,11 @@ void WebSocket::WebSocketEventHandler::OnDataFrame( impl_->SendPendingDataFrames(); } +void WebSocket::WebSocketEventHandler::OnSendDataFrameDone() { + impl_->ResumeDataPipeReading(); + return; +} + bool WebSocket::WebSocketEventHandler::HasPendingDataFrames() { return !impl_->pending_data_frames_.empty(); } @@ -270,14 +279,6 @@ void WebSocket::WebSocketEventHandler::OnClosingHandshake() { impl_->client_->OnClosingHandshake(); } -void WebSocket::WebSocketEventHandler::OnSendFlowControlQuotaAdded( - int64_t quota) { - DVLOG(3) << "WebSocketEventHandler::OnSendFlowControlQuotaAdded @" - << reinterpret_cast<void*>(this) << " quota=" << quota; - - impl_->client_->AddSendFlowControlQuota(quota); -} - void WebSocket::WebSocketEventHandler::OnDropChannel( bool was_clean, uint16_t code, @@ -370,6 +371,14 @@ int WebSocket::WebSocketEventHandler::OnAuthRequired( return net::ERR_IO_PENDING; } +struct WebSocket::CloseInfo { + CloseInfo(uint16_t code, const std::string& reason) + : code(code), reason(reason) {} + + const uint16_t code; + const std::string reason; +}; + WebSocket::WebSocket( WebSocketFactory* factory, const GURL& url, @@ -386,6 +395,7 @@ WebSocket::WebSocket( mojo::PendingRemote<mojom::AuthenticationHandler> auth_handler, mojo::PendingRemote<mojom::TrustedHeaderClient> header_client, WebSocketThrottler::PendingConnection pending_connection_tracker, + DataPipeUseTracker data_pipe_use_tracker, base::TimeDelta delay) : factory_(factory), handshake_client_(std::move(handshake_client)), @@ -404,7 +414,10 @@ WebSocket::WebSocket( base::ThreadTaskRunnerHandle::Get()), readable_watcher_(FROM_HERE, mojo::SimpleWatcher::ArmingPolicy::MANUAL, - base::ThreadTaskRunnerHandle::Get()) { + base::ThreadTaskRunnerHandle::Get()), + data_pipe_use_tracker_(std::move(data_pipe_use_tracker)), + reassemble_short_messages_(base::FeatureList::IsEnabled( + network::features::kWebSocketReassembleShortMessages)) { DCHECK(handshake_client_); // If |require_network_isolation_key| is set on the URLRequestContext, // |isolation_info| must not be empty. @@ -447,32 +460,6 @@ WebSocket::~WebSocket() { // static const void* const WebSocket::kUserDataKey = &WebSocket::kUserDataKey; -void WebSocket::SendFrame(bool fin, - mojom::WebSocketMessageType type, - base::span<const uint8_t> data) { - DVLOG(3) << "WebSocket::SendFrame @" << reinterpret_cast<void*>(this) - << " fin=" << fin << " type=" << type << " data is " << data.size() - << " bytes"; - - DCHECK(channel_) - << "WebSocket::SendFrame is called but there is no active channel."; - DCHECK(handshake_succeeded_); - // This is guaranteed by the maximum size enforced on mojo messages. - DCHECK_LE(data.size(), static_cast<size_t>(INT_MAX)); - - // This is guaranteed by mojo. - DCHECK(IsKnownEnumValue(type)); - - // TODO(darin): Avoid this copy. - auto data_to_pass = base::MakeRefCounted<net::IOBuffer>(data.size()); - memcpy(data_to_pass->data(), data.data(), data.size()); - - // It's okay to ignore the result here because we don't access |this| after - // this point. - ignore_result(channel_->SendFrame(fin, MessageTypeToOpCode(type), - std::move(data_to_pass), data.size())); -} - void WebSocket::SendMessage(mojom::WebSocketMessageType type, uint64_t data_length) { DVLOG(3) << "WebSocket::SendMessage @" << reinterpret_cast<void*>(this) @@ -489,11 +476,15 @@ void WebSocket::SendMessage(mojom::WebSocketMessageType type, } DCHECK(IsKnownEnumValue(type)); - pending_send_data_frames_.push(DataFrame(type, data_length)); + const bool do_not_fragment = + reassemble_short_messages_ && data_length <= kSmallMessageThreshhold; + + pending_send_data_frames_.push(DataFrame(type, data_length, do_not_fragment)); // Safe if ReadAndSendFromDataPipe() deletes |this| because this method is // only called from mojo. - ReadAndSendFromDataPipe(); + if (!blocked_on_websocket_channel_) + ReadAndSendFromDataPipe(); } void WebSocket::StartReceiving() { @@ -503,13 +494,19 @@ void WebSocket::StartReceiving() { void WebSocket::StartClosingHandshake(uint16_t code, const std::string& reason) { - DVLOG(3) << "WebSocket::StartClosingHandshake @" - << reinterpret_cast<void*>(this) << " code=" << code << " reason=\"" - << reason << "\""; + DVLOG(3) << "WebSocket::StartClosingHandshake @" << this << " code=" << code + << " reason=\"" << reason << "\""; - DCHECK(channel_) - << "WebSocket::SendFrame is called but there is no active channel."; + DCHECK(channel_) << "WebSocket::StartClosingHandshake is called but there is " + "no active channel."; DCHECK(handshake_succeeded_); + if (!pending_send_data_frames_.empty()) { + // This has only been observed happening on Windows 7, but the Mojo API + // doesn't guarantee that it won't happen on other platforms. + pending_start_closing_handshake_ = + std::make_unique<CloseInfo>(code, reason); + return; + } ignore_result(channel_->StartClosingHandshake(code, reason)); } @@ -628,6 +625,7 @@ void WebSocket::SendPendingDataFrames() { << reinterpret_cast<void*>(this) << ", pending_data_frames_.size=" << pending_data_frames_.size() << ", wait_for_writable_?" << wait_for_writable_; + if (wait_for_writable_) { return; } @@ -714,7 +712,9 @@ void WebSocket::ReadAndSendFromDataPipe() { &buffer, &readable_size, MOJO_READ_DATA_FLAG_NONE); if (begin_result == MOJO_RESULT_SHOULD_WAIT) { wait_for_readable_ = true; - readable_watcher_.ArmOrNotify(); + if (!blocked_on_websocket_channel_) { + readable_watcher_.ArmOrNotify(); + } return; } if (begin_result == MOJO_RESULT_FAILED_PRECONDITION) { @@ -722,18 +722,55 @@ void WebSocket::ReadAndSendFromDataPipe() { } DCHECK_EQ(begin_result, MOJO_RESULT_OK); + if (readable_size < data_frame.data_length && data_frame.do_not_fragment && + !message_under_reassembly_) { + // The cast is needed to unambiguously select a constructor on 32-bit + // platforms. + message_under_reassembly_ = base::MakeRefCounted<net::IOBuffer>( + base::checked_cast<size_t>(data_frame.data_length)); + DCHECK_EQ(bytes_reassembled_, 0u); + } + + if (message_under_reassembly_) { + const size_t bytes_to_copy = + std::min(static_cast<uint64_t>(readable_size), + data_frame.data_length - bytes_reassembled_); + memcpy(message_under_reassembly_->data() + bytes_reassembled_, buffer, + bytes_to_copy); + bytes_reassembled_ += bytes_to_copy; + + const MojoResult end_result = readable_->EndReadData(bytes_to_copy); + DCHECK_EQ(end_result, MOJO_RESULT_OK); + + DCHECK_LE(bytes_reassembled_, data_frame.data_length); + if (bytes_reassembled_ == data_frame.data_length) { + bytes_reassembled_ = 0; + blocked_on_websocket_channel_ = true; + if (channel_->SendFrame( + /*fin=*/true, MessageTypeToOpCode(data_frame.type), + std::move(message_under_reassembly_), data_frame.data_length) == + net::WebSocketChannel::CHANNEL_DELETED) { + // |this| has been deleted. + return; + } + pending_send_data_frames_.pop(); + } + + continue; + } + const size_t size_to_send = std::min(static_cast<uint64_t>(readable_size), data_frame.data_length); auto data_to_pass = base::MakeRefCounted<net::IOBuffer>(size_to_send); const bool is_final = (size_to_send == data_frame.data_length); memcpy(data_to_pass->data(), buffer, size_to_send); + blocked_on_websocket_channel_ = true; if (channel_->SendFrame(is_final, MessageTypeToOpCode(data_frame.type), std::move(data_to_pass), size_to_send) == net::WebSocketChannel::CHANNEL_DELETED) { // |this| has been deleted. return; } - const MojoResult end_result = readable_->EndReadData(size_to_send); DCHECK_EQ(end_result, MOJO_RESULT_OK); @@ -746,7 +783,17 @@ void WebSocket::ReadAndSendFromDataPipe() { data_frame.type = mojom::WebSocketMessageType::CONTINUATION; data_frame.data_length -= size_to_send; } - return; + if (pending_start_closing_handshake_) { + std::unique_ptr<CloseInfo> close_info = + std::move(pending_start_closing_handshake_); + ignore_result( + channel_->StartClosingHandshake(close_info->code, close_info->reason)); + } +} + +void WebSocket::ResumeDataPipeReading() { + blocked_on_websocket_channel_ = false; + readable_watcher_.ArmOrNotify(); } void WebSocket::OnSSLCertificateErrorResponse( @@ -811,6 +858,7 @@ void WebSocket::Reset() { auth_handler_.reset(); header_client_.reset(); receiver_.reset(); + data_pipe_use_tracker_.Reset(); // net::WebSocketChannel requires that we delete it at this point. channel_.reset(); |