// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "content/browser/websockets/websocket_impl.h" #include #include #include "base/bind.h" #include "base/bind_helpers.h" #include "base/location.h" #include "base/logging.h" #include "base/macros.h" #include "base/memory/ptr_util.h" #include "base/single_thread_task_runner.h" #include "base/strings/string_util.h" #include "base/strings/stringprintf.h" #include "base/threading/thread_task_runner_handle.h" #include "content/browser/bad_message.h" #include "content/browser/child_process_security_policy_impl.h" #include "content/browser/ssl/ssl_error_handler.h" #include "content/browser/ssl/ssl_manager.h" #include "content/browser/websockets/websocket_handshake_request_info_impl.h" #include "content/public/browser/storage_partition.h" #include "ipc/ipc_message.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" #include "net/http/http_request_headers.h" #include "net/http/http_response_headers.h" #include "net/http/http_util.h" #include "net/ssl/ssl_info.h" #include "net/url_request/url_request_context_getter.h" #include "net/websockets/websocket_channel.h" #include "net/websockets/websocket_errors.h" #include "net/websockets/websocket_event_interface.h" #include "net/websockets/websocket_frame.h" // for WebSocketFrameHeader::OpCode #include "net/websockets/websocket_handshake_request_info.h" #include "net/websockets/websocket_handshake_response_info.h" #include "url/origin.h" namespace content { namespace { typedef net::WebSocketEventInterface::ChannelState ChannelState; // Convert a blink::mojom::WebSocketMessageType to a // net::WebSocketFrameHeader::OpCode net::WebSocketFrameHeader::OpCode MessageTypeToOpCode( blink::mojom::WebSocketMessageType type) { DCHECK(type == blink::mojom::WebSocketMessageType::CONTINUATION || type == blink::mojom::WebSocketMessageType::TEXT || type == blink::mojom::WebSocketMessageType::BINARY); typedef net::WebSocketFrameHeader::OpCode OpCode; // These compile asserts verify that the same underlying values are used for // both types, so we can simply cast between them. static_assert( static_cast(blink::mojom::WebSocketMessageType::CONTINUATION) == net::WebSocketFrameHeader::kOpCodeContinuation, "enum values must match for opcode continuation"); static_assert( static_cast(blink::mojom::WebSocketMessageType::TEXT) == net::WebSocketFrameHeader::kOpCodeText, "enum values must match for opcode text"); static_assert( static_cast(blink::mojom::WebSocketMessageType::BINARY) == net::WebSocketFrameHeader::kOpCodeBinary, "enum values must match for opcode binary"); return static_cast(type); } blink::mojom::WebSocketMessageType OpCodeToMessageType( net::WebSocketFrameHeader::OpCode opCode) { DCHECK(opCode == net::WebSocketFrameHeader::kOpCodeContinuation || opCode == net::WebSocketFrameHeader::kOpCodeText || opCode == net::WebSocketFrameHeader::kOpCodeBinary); // This cast is guaranteed valid by the static_assert() statements above. return static_cast(opCode); } } // namespace // Implementation of net::WebSocketEventInterface. Receives events from our // WebSocketChannel object. class WebSocketImpl::WebSocketEventHandler final : public net::WebSocketEventInterface { public: explicit WebSocketEventHandler(WebSocketImpl* impl); ~WebSocketEventHandler() override; // net::WebSocketEventInterface implementation void OnCreateURLRequest(net::URLRequest* url_request) override; ChannelState OnAddChannelResponse(const std::string& selected_subprotocol, const std::string& extensions) override; ChannelState OnDataFrame(bool fin, WebSocketMessageType type, scoped_refptr buffer, size_t buffer_size) override; ChannelState OnClosingHandshake() override; ChannelState OnFlowControl(int64_t quota) override; ChannelState OnDropChannel(bool was_clean, uint16_t code, const std::string& reason) override; ChannelState OnFailChannel(const std::string& message) override; ChannelState OnStartOpeningHandshake( std::unique_ptr request) override; ChannelState OnFinishOpeningHandshake( std::unique_ptr response) override; ChannelState OnSSLCertificateError( std::unique_ptr callbacks, const GURL& url, const net::SSLInfo& ssl_info, bool fatal) override; private: class SSLErrorHandlerDelegate final : public SSLErrorHandler::Delegate { public: SSLErrorHandlerDelegate( std::unique_ptr callbacks); ~SSLErrorHandlerDelegate() override; base::WeakPtr GetWeakPtr(); // SSLErrorHandler::Delegate methods void CancelSSLRequest(int error, const net::SSLInfo* ssl_info) override; void ContinueSSLRequest() override; private: std::unique_ptr callbacks_; base::WeakPtrFactory weak_ptr_factory_; DISALLOW_COPY_AND_ASSIGN(SSLErrorHandlerDelegate); }; WebSocketImpl* const impl_; std::unique_ptr ssl_error_handler_delegate_; DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler); }; WebSocketImpl::WebSocketEventHandler::WebSocketEventHandler(WebSocketImpl* impl) : impl_(impl) { DVLOG(1) << "WebSocketEventHandler created @" << reinterpret_cast(this); } WebSocketImpl::WebSocketEventHandler::~WebSocketEventHandler() { DVLOG(1) << "WebSocketEventHandler destroyed @" << reinterpret_cast(this); } void WebSocketImpl::WebSocketEventHandler::OnCreateURLRequest( net::URLRequest* url_request) { WebSocketHandshakeRequestInfoImpl::CreateInfoAndAssociateWithRequest( impl_->child_id_, impl_->frame_id_, url_request); } ChannelState WebSocketImpl::WebSocketEventHandler::OnAddChannelResponse( const std::string& selected_protocol, const std::string& extensions) { DVLOG(3) << "WebSocketEventHandler::OnAddChannelResponse @" << reinterpret_cast(this) << " selected_protocol=\"" << selected_protocol << "\"" << " extensions=\"" << extensions << "\""; impl_->delegate_->OnReceivedResponseFromServer(impl_); impl_->client_->OnAddChannelResponse(selected_protocol, extensions); return net::WebSocketEventInterface::CHANNEL_ALIVE; } ChannelState WebSocketImpl::WebSocketEventHandler::OnDataFrame( bool fin, net::WebSocketFrameHeader::OpCode type, scoped_refptr buffer, size_t buffer_size) { DVLOG(3) << "WebSocketEventHandler::OnDataFrame @" << reinterpret_cast(this) << " fin=" << fin << " type=" << type << " data is " << buffer_size << " bytes"; // TODO(darin): Avoid this copy. std::vector data_to_pass(buffer_size); if (buffer_size > 0) { std::copy(buffer->data(), buffer->data() + buffer_size, data_to_pass.begin()); } impl_->client_->OnDataFrame(fin, OpCodeToMessageType(type), data_to_pass); return net::WebSocketEventInterface::CHANNEL_ALIVE; } ChannelState WebSocketImpl::WebSocketEventHandler::OnClosingHandshake() { DVLOG(3) << "WebSocketEventHandler::OnClosingHandshake @" << reinterpret_cast(this); impl_->client_->OnClosingHandshake(); return net::WebSocketEventInterface::CHANNEL_ALIVE; } ChannelState WebSocketImpl::WebSocketEventHandler::OnFlowControl( int64_t quota) { DVLOG(3) << "WebSocketEventHandler::OnFlowControl @" << reinterpret_cast(this) << " quota=" << quota; impl_->client_->OnFlowControl(quota); return net::WebSocketEventInterface::CHANNEL_ALIVE; } ChannelState WebSocketImpl::WebSocketEventHandler::OnDropChannel( bool was_clean, uint16_t code, const std::string& reason) { DVLOG(3) << "WebSocketEventHandler::OnDropChannel @" << reinterpret_cast(this) << " was_clean=" << was_clean << " code=" << code << " reason=\"" << reason << "\""; impl_->client_->OnDropChannel(was_clean, code, reason); // net::WebSocketChannel requires that we delete it at this point. impl_->channel_.reset(); return net::WebSocketEventInterface::CHANNEL_DELETED; } ChannelState WebSocketImpl::WebSocketEventHandler::OnFailChannel( const std::string& message) { DVLOG(3) << "WebSocketEventHandler::OnFailChannel @" << reinterpret_cast(this) << " message=\"" << message << "\""; impl_->client_->OnFailChannel(message); // net::WebSocketChannel requires that we delete it at this point. impl_->channel_.reset(); return net::WebSocketEventInterface::CHANNEL_DELETED; } ChannelState WebSocketImpl::WebSocketEventHandler::OnStartOpeningHandshake( std::unique_ptr request) { bool should_send = ChildProcessSecurityPolicyImpl::GetInstance()->CanReadRawCookies( impl_->delegate_->GetClientProcessId()); DVLOG(3) << "WebSocketEventHandler::OnStartOpeningHandshake @" << reinterpret_cast(this) << " should_send=" << should_send; if (!should_send) return WebSocketEventInterface::CHANNEL_ALIVE; blink::mojom::WebSocketHandshakeRequestPtr request_to_pass( blink::mojom::WebSocketHandshakeRequest::New()); request_to_pass->url.Swap(&request->url); net::HttpRequestHeaders::Iterator it(request->headers); while (it.GetNext()) { blink::mojom::HttpHeaderPtr header(blink::mojom::HttpHeader::New()); header->name = it.name(); header->value = it.value(); request_to_pass->headers.push_back(std::move(header)); } request_to_pass->headers_text = base::StringPrintf("GET %s HTTP/1.1\r\n", request_to_pass->url.spec().c_str()) + request->headers.ToString(); impl_->client_->OnStartOpeningHandshake(std::move(request_to_pass)); return WebSocketEventInterface::CHANNEL_ALIVE; } ChannelState WebSocketImpl::WebSocketEventHandler::OnFinishOpeningHandshake( std::unique_ptr response) { bool should_send = ChildProcessSecurityPolicyImpl::GetInstance()->CanReadRawCookies( impl_->delegate_->GetClientProcessId()); DVLOG(3) << "WebSocketEventHandler::OnFinishOpeningHandshake " << reinterpret_cast(this) << " should_send=" << should_send; if (!should_send) return WebSocketEventInterface::CHANNEL_ALIVE; blink::mojom::WebSocketHandshakeResponsePtr response_to_pass( blink::mojom::WebSocketHandshakeResponse::New()); response_to_pass->url.Swap(&response->url); response_to_pass->status_code = response->status_code; response_to_pass->status_text = response->status_text; size_t iter = 0; std::string name, value; while (response->headers->EnumerateHeaderLines(&iter, &name, &value)) { blink::mojom::HttpHeaderPtr header(blink::mojom::HttpHeader::New()); header->name = name; header->value = value; response_to_pass->headers.push_back(std::move(header)); } response_to_pass->headers_text = net::HttpUtil::ConvertHeadersBackToHTTPResponse( response->headers->raw_headers()); impl_->client_->OnFinishOpeningHandshake(std::move(response_to_pass)); return WebSocketEventInterface::CHANNEL_ALIVE; } ChannelState WebSocketImpl::WebSocketEventHandler::OnSSLCertificateError( std::unique_ptr callbacks, const GURL& url, const net::SSLInfo& ssl_info, bool fatal) { DVLOG(3) << "WebSocketEventHandler::OnSSLCertificateError" << reinterpret_cast(this) << " url=" << url.spec() << " cert_status=" << ssl_info.cert_status << " fatal=" << fatal; ssl_error_handler_delegate_.reset( new SSLErrorHandlerDelegate(std::move(callbacks))); SSLManager::OnSSLCertificateSubresourceError( ssl_error_handler_delegate_->GetWeakPtr(), url, impl_->delegate_->GetClientProcessId(), impl_->frame_id_, ssl_info, fatal); // The above method is always asynchronous. return WebSocketEventInterface::CHANNEL_ALIVE; } WebSocketImpl::WebSocketEventHandler::SSLErrorHandlerDelegate:: SSLErrorHandlerDelegate( std::unique_ptr callbacks) : callbacks_(std::move(callbacks)), weak_ptr_factory_(this) {} WebSocketImpl::WebSocketEventHandler::SSLErrorHandlerDelegate:: ~SSLErrorHandlerDelegate() {} base::WeakPtr WebSocketImpl::WebSocketEventHandler::SSLErrorHandlerDelegate::GetWeakPtr() { return weak_ptr_factory_.GetWeakPtr(); } void WebSocketImpl::WebSocketEventHandler::SSLErrorHandlerDelegate:: CancelSSLRequest(int error, const net::SSLInfo* ssl_info) { DVLOG(3) << "SSLErrorHandlerDelegate::CancelSSLRequest" << " error=" << error << " cert_status=" << (ssl_info ? ssl_info->cert_status : static_cast(-1)); callbacks_->CancelSSLRequest(error, ssl_info); } void WebSocketImpl::WebSocketEventHandler::SSLErrorHandlerDelegate:: ContinueSSLRequest() { DVLOG(3) << "SSLErrorHandlerDelegate::ContinueSSLRequest"; callbacks_->ContinueSSLRequest(); } WebSocketImpl::WebSocketImpl(Delegate* delegate, blink::mojom::WebSocketRequest request, int child_id, int frame_id, base::TimeDelta delay) : delegate_(delegate), binding_(this, std::move(request)), delay_(delay), pending_flow_control_quota_(0), child_id_(child_id), frame_id_(frame_id), handshake_succeeded_(false), weak_ptr_factory_(this) { binding_.set_connection_error_handler( base::Bind(&WebSocketImpl::OnConnectionError, base::Unretained(this))); } WebSocketImpl::~WebSocketImpl() {} void WebSocketImpl::GoAway() { StartClosingHandshake(static_cast(net::kWebSocketErrorGoingAway), ""); } void WebSocketImpl::AddChannelRequest( const GURL& socket_url, const std::vector& requested_protocols, const url::Origin& origin, const GURL& first_party_for_cookies, const std::string& user_agent_override, blink::mojom::WebSocketClientPtr client) { DVLOG(3) << "WebSocketImpl::AddChannelRequest @" << reinterpret_cast(this) << " socket_url=\"" << socket_url << "\" requested_protocols=\"" << base::JoinString(requested_protocols, ", ") << "\" origin=\"" << origin << "\" first_party_for_cookies=\"" << first_party_for_cookies << "\" user_agent_override=\"" << user_agent_override << "\""; if (client_ || !client) { bad_message::ReceivedBadMessage( delegate_->GetClientProcessId(), bad_message::WSI_UNEXPECTED_ADD_CHANNEL_REQUEST); return; } client_ = std::move(client); DCHECK(!channel_); if (delay_ > base::TimeDelta()) { base::ThreadTaskRunnerHandle::Get()->PostDelayedTask( FROM_HERE, base::Bind(&WebSocketImpl::AddChannel, weak_ptr_factory_.GetWeakPtr(), socket_url, requested_protocols, origin, first_party_for_cookies, user_agent_override), delay_); } else { AddChannel(socket_url, requested_protocols, origin, first_party_for_cookies, user_agent_override); } } void WebSocketImpl::SendFrame(bool fin, blink::mojom::WebSocketMessageType type, const std::vector& data) { DVLOG(3) << "WebSocketImpl::SendFrame @" << reinterpret_cast(this) << " fin=" << fin << " type=" << type << " data is " << data.size() << " bytes"; if (!channel_) { // The client should not be sending us frames until after we've informed // it that the channel has been opened (OnAddChannelResponse). if (handshake_succeeded_) { DVLOG(1) << "Dropping frame sent to closed websocket"; } else { bad_message::ReceivedBadMessage( delegate_->GetClientProcessId(), bad_message::WSI_UNEXPECTED_SEND_FRAME); } return; } // TODO(darin): Avoid this copy. scoped_refptr data_to_pass(new net::IOBuffer(data.size())); std::copy(data.begin(), data.end(), data_to_pass->data()); channel_->SendFrame(fin, MessageTypeToOpCode(type), std::move(data_to_pass), data.size()); } void WebSocketImpl::SendFlowControl(int64_t quota) { DVLOG(3) << "WebSocketImpl::OnFlowControl @" << reinterpret_cast(this) << " quota=" << quota; if (!channel_) { // WebSocketChannel is not yet created due to the delay introduced by // per-renderer WebSocket throttling. // SendFlowControl() is called after WebSocketChannel is created. pending_flow_control_quota_ += quota; return; } ignore_result(channel_->SendFlowControl(quota)); } void WebSocketImpl::StartClosingHandshake(uint16_t code, const std::string& reason) { DVLOG(3) << "WebSocketImpl::StartClosingHandshake @" << reinterpret_cast(this) << " code=" << code << " reason=\"" << reason << "\""; if (!channel_) { // WebSocketChannel is not yet created due to the delay introduced by // per-renderer WebSocket throttling. if (client_) client_->OnDropChannel(false, net::kWebSocketErrorAbnormalClosure, ""); return; } ignore_result(channel_->StartClosingHandshake(code, reason)); } void WebSocketImpl::OnConnectionError() { DVLOG(3) << "WebSocketImpl::OnConnectionError @" << reinterpret_cast(this); delegate_->OnLostConnectionToClient(this); } void WebSocketImpl::AddChannel( const GURL& socket_url, const std::vector& requested_protocols, const url::Origin& origin, const GURL& first_party_for_cookies, const std::string& user_agent_override) { DVLOG(3) << "WebSocketImpl::AddChannel @" << reinterpret_cast(this) << " socket_url=\"" << socket_url << "\" requested_protocols=\"" << base::JoinString(requested_protocols, ", ") << "\" origin=\"" << origin << "\" first_party_for_cookies=\"" << first_party_for_cookies << "\" user_agent_override=\"" << user_agent_override << "\""; DCHECK(!channel_); StoragePartition* partition = delegate_->GetStoragePartition(); std::unique_ptr event_interface( new WebSocketEventHandler(this)); channel_.reset( new net::WebSocketChannel( std::move(event_interface), partition->GetURLRequestContext()->GetURLRequestContext())); int64_t quota = pending_flow_control_quota_; pending_flow_control_quota_ = 0; std::string additional_headers; if (!user_agent_override.empty()) { if (!net::HttpUtil::IsValidHeaderValue(user_agent_override)) { bad_message::ReceivedBadMessage( delegate_->GetClientProcessId(), bad_message::WSI_INVALID_HEADER_VALUE); return; } additional_headers = base::StringPrintf("%s:%s", net::HttpRequestHeaders::kUserAgent, user_agent_override.c_str()); } channel_->SendAddChannelRequest(socket_url, requested_protocols, origin, first_party_for_cookies, additional_headers); if (quota > 0) SendFlowControl(quota); } } // namespace content