diff options
author | Andras Becsi <andras.becsi@digia.com> | 2013-12-11 21:33:03 +0100 |
---|---|---|
committer | Andras Becsi <andras.becsi@digia.com> | 2013-12-13 12:34:07 +0100 |
commit | f2a33ff9cbc6d19943f1c7fbddd1f23d23975577 (patch) | |
tree | 0586a32aa390ade8557dfd6b4897f43a07449578 /chromium/net/socket | |
parent | 5362912cdb5eea702b68ebe23702468d17c3017a (diff) | |
download | qtwebengine-chromium-f2a33ff9cbc6d19943f1c7fbddd1f23d23975577.tar.gz |
Update Chromium to branch 1650 (31.0.1650.63)
Change-Id: I57d8c832eaec1eb2364e0a8e7352a6dd354db99f
Reviewed-by: Jocelyn Turcotte <jocelyn.turcotte@digia.com>
Diffstat (limited to 'chromium/net/socket')
71 files changed, 4358 insertions, 3668 deletions
diff --git a/chromium/net/socket/buffered_write_stream_socket.cc b/chromium/net/socket/buffered_write_stream_socket.cc index 36b9df715fd..cf13c5e439a 100644 --- a/chromium/net/socket/buffered_write_stream_socket.cc +++ b/chromium/net/socket/buffered_write_stream_socket.cc @@ -23,8 +23,8 @@ void AppendBuffer(GrowableIOBuffer* dst, IOBuffer* src, int src_len) { } // anonymous namespace BufferedWriteStreamSocket::BufferedWriteStreamSocket( - StreamSocket* socket_to_wrap) - : wrapped_socket_(socket_to_wrap), + scoped_ptr<StreamSocket> socket_to_wrap) + : wrapped_socket_(socket_to_wrap.Pass()), io_buffer_(new GrowableIOBuffer()), backup_buffer_(new GrowableIOBuffer()), weak_factory_(this), diff --git a/chromium/net/socket/buffered_write_stream_socket.h b/chromium/net/socket/buffered_write_stream_socket.h index fcb33a81910..aad5736d0b0 100644 --- a/chromium/net/socket/buffered_write_stream_socket.h +++ b/chromium/net/socket/buffered_write_stream_socket.h @@ -5,6 +5,8 @@ #ifndef NET_SOCKET_BUFFERED_WRITE_STREAM_SOCKET_H_ #define NET_SOCKET_BUFFERED_WRITE_STREAM_SOCKET_H_ +#include "base/basictypes.h" +#include "base/memory/scoped_ptr.h" #include "base/memory/weak_ptr.h" #include "net/base/net_log.h" #include "net/socket/stream_socket.h" @@ -33,7 +35,7 @@ class IPEndPoint; // There are no bounds on the local buffer size. Use carefully. class NET_EXPORT_PRIVATE BufferedWriteStreamSocket : public StreamSocket { public: - BufferedWriteStreamSocket(StreamSocket* socket_to_wrap); + explicit BufferedWriteStreamSocket(scoped_ptr<StreamSocket> socket_to_wrap); virtual ~BufferedWriteStreamSocket(); // Socket interface @@ -71,6 +73,8 @@ class NET_EXPORT_PRIVATE BufferedWriteStreamSocket : public StreamSocket { bool callback_pending_; bool wrapped_write_in_progress_; int error_; + + DISALLOW_COPY_AND_ASSIGN(BufferedWriteStreamSocket); }; } // namespace net diff --git a/chromium/net/socket/buffered_write_stream_socket_unittest.cc b/chromium/net/socket/buffered_write_stream_socket_unittest.cc index e579a7f51d2..485295f33f6 100644 --- a/chromium/net/socket/buffered_write_stream_socket_unittest.cc +++ b/chromium/net/socket/buffered_write_stream_socket_unittest.cc @@ -30,10 +30,11 @@ class BufferedWriteStreamSocketTest : public testing::Test { if (writes_count) { data_->StopAfter(writes_count); } - DeterministicMockTCPClientSocket* wrapped_socket = - new DeterministicMockTCPClientSocket(net_log_.net_log(), data_.get()); + scoped_ptr<DeterministicMockTCPClientSocket> wrapped_socket( + new DeterministicMockTCPClientSocket(net_log_.net_log(), data_.get())); data_->set_delegate(wrapped_socket->AsWeakPtr()); - socket_.reset(new BufferedWriteStreamSocket(wrapped_socket)); + socket_.reset(new BufferedWriteStreamSocket( + wrapped_socket.PassAs<StreamSocket>())); socket_->Connect(callback_.callback()); } diff --git a/chromium/net/socket/client_socket_factory.cc b/chromium/net/socket/client_socket_factory.cc index 022988aa6a9..a86688e3333 100644 --- a/chromium/net/socket/client_socket_factory.cc +++ b/chromium/net/socket/client_socket_factory.cc @@ -67,23 +67,25 @@ class DefaultClientSocketFactory : public ClientSocketFactory, ClearSSLSessionCache(); } - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) OVERRIDE { - return new UDPClientSocket(bind_type, rand_int_cb, net_log, source); + return scoped_ptr<DatagramClientSocket>( + new UDPClientSocket(bind_type, rand_int_cb, net_log, source)); } - virtual StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog* net_log, const NetLog::Source& source) OVERRIDE { - return new TCPClientSocket(addresses, net_log, source); + return scoped_ptr<StreamSocket>( + new TCPClientSocket(addresses, net_log, source)); } - virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocketHandle* transport_socket, + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) OVERRIDE { @@ -102,17 +104,19 @@ class DefaultClientSocketFactory : public ClientSocketFactory, nss_task_runner = base::ThreadTaskRunnerHandle::Get(); #if defined(USE_OPENSSL) - return new SSLClientSocketOpenSSL(transport_socket, host_and_port, - ssl_config, context); + return scoped_ptr<SSLClientSocket>( + new SSLClientSocketOpenSSL(transport_socket.Pass(), host_and_port, + ssl_config, context)); #elif defined(USE_NSS) || defined(OS_MACOSX) || defined(OS_WIN) - return new SSLClientSocketNSS(nss_task_runner.get(), - transport_socket, - host_and_port, - ssl_config, - context); + return scoped_ptr<SSLClientSocket>( + new SSLClientSocketNSS(nss_task_runner.get(), + transport_socket.Pass(), + host_and_port, + ssl_config, + context)); #else NOTIMPLEMENTED(); - return NULL; + return scoped_ptr<SSLClientSocket>(); #endif } @@ -130,18 +134,6 @@ static base::LazyInstance<DefaultClientSocketFactory>::Leaky } // namespace -// Deprecated function (http://crbug.com/37810) that takes a StreamSocket. -SSLClientSocket* ClientSocketFactory::CreateSSLClientSocket( - StreamSocket* transport_socket, - const HostPortPair& host_and_port, - const SSLConfig& ssl_config, - const SSLClientSocketContext& context) { - ClientSocketHandle* socket_handle = new ClientSocketHandle(); - socket_handle->set_socket(transport_socket); - return CreateSSLClientSocket(socket_handle, host_and_port, ssl_config, - context); -} - // static ClientSocketFactory* ClientSocketFactory::GetDefaultFactory() { return g_default_client_socket_factory.Pointer(); diff --git a/chromium/net/socket/client_socket_factory.h b/chromium/net/socket/client_socket_factory.h index 65022f29234..6cb5949f0b3 100644 --- a/chromium/net/socket/client_socket_factory.h +++ b/chromium/net/socket/client_socket_factory.h @@ -8,6 +8,7 @@ #include <string> #include "base/basictypes.h" +#include "base/memory/scoped_ptr.h" #include "net/base/net_export.h" #include "net/base/net_log.h" #include "net/base/rand_callback.h" @@ -32,13 +33,13 @@ class NET_EXPORT ClientSocketFactory { // |source| is the NetLog::Source for the entity trying to create the socket, // if it has one. - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) = 0; - virtual StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog* net_log, const NetLog::Source& source) = 0; @@ -46,19 +47,12 @@ class NET_EXPORT ClientSocketFactory { // It is allowed to pass in a |transport_socket| that is not obtained from a // socket pool. The caller could create a ClientSocketHandle directly and call // set_socket() on it to set a valid StreamSocket instance. - virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocketHandle* transport_socket, + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) = 0; - // Deprecated function (http://crbug.com/37810) that takes a StreamSocket. - virtual SSLClientSocket* CreateSSLClientSocket( - StreamSocket* transport_socket, - const HostPortPair& host_and_port, - const SSLConfig& ssl_config, - const SSLClientSocketContext& context); - // Clears cache used for SSL session resumption. virtual void ClearSSLSessionCache() = 0; diff --git a/chromium/net/socket/client_socket_handle.cc b/chromium/net/socket/client_socket_handle.cc index 3894fa7aa0e..e42e9fcada3 100644 --- a/chromium/net/socket/client_socket_handle.cc +++ b/chromium/net/socket/client_socket_handle.cc @@ -18,7 +18,7 @@ namespace net { ClientSocketHandle::ClientSocketHandle() : is_initialized_(false), pool_(NULL), - layered_pool_(NULL), + higher_pool_(NULL), is_reused_(false), callback_(base::Bind(&ClientSocketHandle::OnIOComplete, base::Unretained(this))), @@ -34,29 +34,34 @@ void ClientSocketHandle::Reset() { } void ClientSocketHandle::ResetInternal(bool cancel) { - if (group_name_.empty()) // Was Init called? - return; - if (is_initialized()) { - // Because of http://crbug.com/37810 we may not have a pool, but have - // just a raw socket. - socket_->NetLog().EndEvent(NetLog::TYPE_SOCKET_IN_USE); - if (pool_) - // If we've still got a socket, release it back to the ClientSocketPool so - // it can be deleted or reused. - pool_->ReleaseSocket(group_name_, release_socket(), pool_id_); - } else if (cancel) { - // If we did not get initialized yet, we've got a socket request pending. - // Cancel it. - pool_->CancelRequest(group_name_, this); + // Was Init called? + if (!group_name_.empty()) { + // If so, we must have a pool. + CHECK(pool_); + if (is_initialized()) { + if (socket_) { + socket_->NetLog().EndEvent(NetLog::TYPE_SOCKET_IN_USE); + // Release the socket back to the ClientSocketPool so it can be + // deleted or reused. + pool_->ReleaseSocket(group_name_, socket_.Pass(), pool_id_); + } else { + // If the handle has been initialized, we should still have a + // socket. + NOTREACHED(); + } + } else if (cancel) { + // If we did not get initialized yet and we have a socket + // request pending, cancel it. + pool_->CancelRequest(group_name_, this); + } } is_initialized_ = false; + socket_.reset(); group_name_.clear(); is_reused_ = false; user_callback_.Reset(); - if (layered_pool_) { - pool_->RemoveLayeredPool(layered_pool_); - layered_pool_ = NULL; - } + if (higher_pool_) + RemoveHigherLayeredPool(higher_pool_); pool_ = NULL; idle_time_ = base::TimeDelta(); init_time_ = base::TimeTicks(); @@ -82,24 +87,30 @@ LoadState ClientSocketHandle::GetLoadState() const { } bool ClientSocketHandle::IsPoolStalled() const { + if (!pool_) + return false; return pool_->IsStalled(); } -void ClientSocketHandle::AddLayeredPool(LayeredPool* layered_pool) { - CHECK(layered_pool); - CHECK(!layered_pool_); +void ClientSocketHandle::AddHigherLayeredPool(HigherLayeredPool* higher_pool) { + CHECK(higher_pool); + CHECK(!higher_pool_); + // TODO(mmenke): |pool_| should only be NULL in tests. Maybe stop doing that + // so this be be made into a DCHECK, and the same can be done in + // RemoveHigherLayeredPool? if (pool_) { - pool_->AddLayeredPool(layered_pool); - layered_pool_ = layered_pool; + pool_->AddHigherLayeredPool(higher_pool); + higher_pool_ = higher_pool; } } -void ClientSocketHandle::RemoveLayeredPool(LayeredPool* layered_pool) { - CHECK(layered_pool); - CHECK(layered_pool_); +void ClientSocketHandle::RemoveHigherLayeredPool( + HigherLayeredPool* higher_pool) { + CHECK(higher_pool_); + CHECK_EQ(higher_pool_, higher_pool); if (pool_) { - pool_->RemoveLayeredPool(layered_pool); - layered_pool_ = NULL; + pool_->RemoveHigherLayeredPool(higher_pool); + higher_pool_ = NULL; } } @@ -121,6 +132,10 @@ bool ClientSocketHandle::GetLoadTimingInfo( return true; } +void ClientSocketHandle::SetSocket(scoped_ptr<StreamSocket> s) { + socket_ = s.Pass(); +} + void ClientSocketHandle::OnIOComplete(int result) { CompletionCallback callback = user_callback_; user_callback_.Reset(); @@ -128,6 +143,10 @@ void ClientSocketHandle::OnIOComplete(int result) { callback.Run(result); } +scoped_ptr<StreamSocket> ClientSocketHandle::PassSocket() { + return socket_.Pass(); +} + void ClientSocketHandle::HandleInitCompletion(int result) { CHECK_NE(ERR_IO_PENDING, result); ClientSocketPoolHistograms* histograms = pool_->histograms(); diff --git a/chromium/net/socket/client_socket_handle.h b/chromium/net/socket/client_socket_handle.h index 7d5588a145b..30b7c03e9dc 100644 --- a/chromium/net/socket/client_socket_handle.h +++ b/chromium/net/socket/client_socket_handle.h @@ -70,9 +70,9 @@ class NET_EXPORT ClientSocketHandle { // // Profiling information for the request is saved to |net_log| if non-NULL. // - template <typename SocketParams, typename PoolType> + template <typename PoolType> int Init(const std::string& group_name, - const scoped_refptr<SocketParams>& socket_params, + const scoped_refptr<typename PoolType::SocketParams>& socket_params, RequestPriority priority, const CompletionCallback& callback, PoolType* pool, @@ -94,9 +94,15 @@ class NET_EXPORT ClientSocketHandle { bool IsPoolStalled() const; - void AddLayeredPool(LayeredPool* layered_pool); + // Adds a higher layered pool on top of the socket pool that |socket_| belongs + // to. At most one higher layered pool can be added to a + // ClientSocketHandle at a time. On destruction or reset, automatically + // removes the higher pool if RemoveHigherLayeredPool has not been called. + void AddHigherLayeredPool(HigherLayeredPool* higher_pool); - void RemoveLayeredPool(LayeredPool* layered_pool); + // Removes a higher layered pool from the socket pool that |socket_| belongs + // to. |higher_pool| must have been added by the above function. + void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool); // Returns true when Init() has completed successfully. bool is_initialized() const { return is_initialized_; } @@ -116,8 +122,11 @@ class NET_EXPORT ClientSocketHandle { LoadTimingInfo* load_timing_info) const; // Used by ClientSocketPool to initialize the ClientSocketHandle. + // + // SetSocket() may also be used if this handle is used as simply for + // socket storage (e.g., http://crbug.com/37810). + void SetSocket(scoped_ptr<StreamSocket> s); void set_is_reused(bool is_reused) { is_reused_ = is_reused; } - void set_socket(StreamSocket* s) { socket_.reset(s); } void set_idle_time(base::TimeDelta idle_time) { idle_time_ = idle_time; } void set_pool_id(int id) { pool_id_ = id; } void set_is_ssl_error(bool is_ssl_error) { is_ssl_error_ = is_ssl_error; } @@ -143,11 +152,15 @@ class NET_EXPORT ClientSocketHandle { return pending_http_proxy_connection_.release(); } + StreamSocket* socket() { return socket_.get(); } + + // SetSocket() must be called with a new socket before this handle + // is destroyed if is_initialized() is true. + scoped_ptr<StreamSocket> PassSocket(); + // These may only be used if is_initialized() is true. const std::string& group_name() const { return group_name_; } int id() const { return pool_id_; } - StreamSocket* socket() { return socket_.get(); } - StreamSocket* release_socket() { return socket_.release(); } bool is_reused() const { return is_reused_; } base::TimeDelta idle_time() const { return idle_time_; } SocketReuseType reuse_type() const { @@ -184,7 +197,7 @@ class NET_EXPORT ClientSocketHandle { bool is_initialized_; ClientSocketPool* pool_; - LayeredPool* layered_pool_; + HigherLayeredPool* higher_pool_; scoped_ptr<StreamSocket> socket_; std::string group_name_; bool is_reused_; @@ -207,20 +220,17 @@ class NET_EXPORT ClientSocketHandle { }; // Template function implementation: -template <typename SocketParams, typename PoolType> -int ClientSocketHandle::Init(const std::string& group_name, - const scoped_refptr<SocketParams>& socket_params, - RequestPriority priority, - const CompletionCallback& callback, - PoolType* pool, - const BoundNetLog& net_log) { +template <typename PoolType> +int ClientSocketHandle::Init( + const std::string& group_name, + const scoped_refptr<typename PoolType::SocketParams>& socket_params, + RequestPriority priority, + const CompletionCallback& callback, + PoolType* pool, + const BoundNetLog& net_log) { requesting_source_ = net_log.source(); CHECK(!group_name.empty()); - // Note that this will result in a compile error if the SocketParams has not - // been registered for the PoolType via REGISTER_SOCKET_PARAMS_FOR_POOL - // (defined in client_socket_pool.h). - CheckIsValidSocketParamsForPool<PoolType, SocketParams>(); ResetInternal(true); ResetErrorState(); pool_ = pool; diff --git a/chromium/net/socket/client_socket_pool.h b/chromium/net/socket/client_socket_pool.h index 7cb9a7ebc2e..715cddb94d4 100644 --- a/chromium/net/socket/client_socket_pool.h +++ b/chromium/net/socket/client_socket_pool.h @@ -10,6 +10,7 @@ #include "base/basictypes.h" #include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" #include "base/template_util.h" #include "base/time/time.h" #include "net/base/completion_callback.h" @@ -30,20 +31,43 @@ class StreamSocket; // ClientSocketPools are layered. This defines an interface for lower level // socket pools to communicate with higher layer pools. -class NET_EXPORT LayeredPool { +class NET_EXPORT HigherLayeredPool { public: - virtual ~LayeredPool() {}; + virtual ~HigherLayeredPool() {} - // Instructs the LayeredPool to close an idle connection. Return true if one - // was closed. + // Instructs the HigherLayeredPool to close an idle connection. Return true if + // one was closed. Closing an idle connection will call into the lower layer + // pool it came from, so must be careful of re-entrancy when using this. virtual bool CloseOneIdleConnection() = 0; }; +// ClientSocketPools are layered. This defines an interface for higher level +// socket pools to communicate with lower layer pools. +class NET_EXPORT LowerLayeredPool { + public: + virtual ~LowerLayeredPool() {} + + // Returns true if a there is currently a request blocked on the per-pool + // (not per-host) max socket limit, either in this pool, or one that it is + // layered on top of. + virtual bool IsStalled() const = 0; + + // Called to add or remove a higher layer pool on top of |this|. A higher + // layer pool may be added at most once to |this|, and must be removed prior + // to destruction of |this|. + virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) = 0; + virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) = 0; +}; + // A ClientSocketPool is used to restrict the number of sockets open at a time. // It also maintains a list of idle persistent sockets. // -class NET_EXPORT ClientSocketPool { +class NET_EXPORT ClientSocketPool : public LowerLayeredPool { public: + // Subclasses must also have an inner class SocketParams which is + // the type for the |params| argument in RequestSocket() and + // RequestSockets() below. + // Requests a connected socket for a group_name. // // There are five possible results from calling this function: @@ -111,7 +135,7 @@ class NET_EXPORT ClientSocketPool { // change when it flushes, so it can use this |id| to discard sockets with // mismatched ids. virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) = 0; // This flushes all state from the ClientSocketPool. This means that all @@ -121,10 +145,6 @@ class NET_EXPORT ClientSocketPool { // Does not flush any pools wrapped by |this|. virtual void FlushWithError(int error) = 0; - // Returns true if a there is currently a request blocked on the - // per-pool (not per-host) max socket limit. - virtual bool IsStalled() const = 0; - // Called to close any idle connections held by the connection manager. virtual void CloseIdleSockets() = 0; @@ -138,12 +158,6 @@ class NET_EXPORT ClientSocketPool { virtual LoadState GetLoadState(const std::string& group_name, const ClientSocketHandle* handle) const = 0; - // Adds a LayeredPool on top of |this|. - virtual void AddLayeredPool(LayeredPool* layered_pool) = 0; - - // Removes a LayeredPool from |this|. - virtual void RemoveLayeredPool(LayeredPool* layered_pool) = 0; - // Retrieves information on the current state of the pool as a // DictionaryValue. Caller takes possession of the returned value. // If |include_nested_pools| is true, the states of any nested @@ -177,41 +191,13 @@ class NET_EXPORT ClientSocketPool { DISALLOW_COPY_AND_ASSIGN(ClientSocketPool); }; -// ClientSocketPool subclasses should indicate valid SocketParams via the -// REGISTER_SOCKET_PARAMS_FOR_POOL macro below. By default, any given -// <PoolType,SocketParams> pair will have its SocketParamsTrait inherit from -// base::false_type, but REGISTER_SOCKET_PARAMS_FOR_POOL will specialize that -// pairing to inherit from base::true_type. This provides compile time -// verification that the correct SocketParams type is used with the appropriate -// PoolType. -template <typename PoolType, typename SocketParams> -struct SocketParamTraits : public base::false_type { -}; - -template <typename PoolType, typename SocketParams> -void CheckIsValidSocketParamsForPool() { - COMPILE_ASSERT(!base::is_pointer<scoped_refptr<SocketParams> >::value, - socket_params_cannot_be_pointer); - COMPILE_ASSERT((SocketParamTraits<PoolType, - scoped_refptr<SocketParams> >::value), - invalid_socket_params_for_pool); -} - -// Provides an empty definition for CheckIsValidSocketParamsForPool() which -// should be optimized out by the compiler. -#define REGISTER_SOCKET_PARAMS_FOR_POOL(pool_type, socket_params) \ -template<> \ -struct SocketParamTraits<pool_type, scoped_refptr<socket_params> > \ - : public base::true_type { \ -} - -template <typename PoolType, typename SocketParams> -void RequestSocketsForPool(PoolType* pool, - const std::string& group_name, - const scoped_refptr<SocketParams>& params, - int num_sockets, - const BoundNetLog& net_log) { - CheckIsValidSocketParamsForPool<PoolType, SocketParams>(); +template <typename PoolType> +void RequestSocketsForPool( + PoolType* pool, + const std::string& group_name, + const scoped_refptr<typename PoolType::SocketParams>& params, + int num_sockets, + const BoundNetLog& net_log) { pool->RequestSockets(group_name, ¶ms, num_sockets, net_log); } diff --git a/chromium/net/socket/client_socket_pool_base.cc b/chromium/net/socket/client_socket_pool_base.cc index 3332e04a171..cec7956a0ee 100644 --- a/chromium/net/socket/client_socket_pool_base.cc +++ b/chromium/net/socket/client_socket_pool_base.cc @@ -65,10 +65,12 @@ int CompareEffectiveRequestPriority( ConnectJob::ConnectJob(const std::string& group_name, base::TimeDelta timeout_duration, + RequestPriority priority, Delegate* delegate, const BoundNetLog& net_log) : group_name_(group_name), timeout_duration_(timeout_duration), + priority_(priority), delegate_(delegate), net_log_(net_log), idle_(true) { @@ -82,6 +84,10 @@ ConnectJob::~ConnectJob() { net_log().EndEvent(NetLog::TYPE_SOCKET_POOL_CONNECT_JOB); } +scoped_ptr<StreamSocket> ConnectJob::PassSocket() { + return socket_.Pass(); +} + int ConnectJob::Connect() { if (timeout_duration_ != base::TimeDelta()) timer_.Start(FROM_HERE, timeout_duration_, this, &ConnectJob::OnTimeout); @@ -100,16 +106,16 @@ int ConnectJob::Connect() { return rv; } -void ConnectJob::set_socket(StreamSocket* socket) { +void ConnectJob::SetSocket(scoped_ptr<StreamSocket> socket) { if (socket) { net_log().AddEvent(NetLog::TYPE_CONNECT_JOB_SET_SOCKET, socket->NetLog().source().ToEventParametersCallback()); } - socket_.reset(socket); + socket_ = socket.Pass(); } void ConnectJob::NotifyDelegateOfCompletion(int rv) { - // The delegate will delete |this|. + // The delegate will own |this|. Delegate* delegate = delegate_; delegate_ = NULL; @@ -135,7 +141,7 @@ void ConnectJob::LogConnectCompletion(int net_error) { void ConnectJob::OnTimeout() { // Make sure the socket is NULL before calling into |delegate|. - set_socket(NULL); + SetSocket(scoped_ptr<StreamSocket>()); net_log_.AddEvent(NetLog::TYPE_SOCKET_POOL_CONNECT_JOB_TIMED_OUT); @@ -161,6 +167,7 @@ ClientSocketPoolBaseHelper::Request::Request( ClientSocketPoolBaseHelper::Request::~Request() {} ClientSocketPoolBaseHelper::ClientSocketPoolBaseHelper( + HigherLayeredPool* pool, int max_sockets, int max_sockets_per_group, base::TimeDelta unused_idle_socket_timeout, @@ -177,6 +184,7 @@ ClientSocketPoolBaseHelper::ClientSocketPoolBaseHelper( connect_job_factory_(connect_job_factory), connect_backup_jobs_enabled_(false), pool_generation_number_(0), + pool_(pool), weak_factory_(this) { DCHECK_LE(0, max_sockets_per_group); DCHECK_LE(max_sockets_per_group, max_sockets); @@ -192,9 +200,16 @@ ClientSocketPoolBaseHelper::~ClientSocketPoolBaseHelper() { DCHECK(group_map_.empty()); DCHECK(pending_callback_map_.empty()); DCHECK_EQ(0, connecting_socket_count_); - CHECK(higher_layer_pools_.empty()); + CHECK(higher_pools_.empty()); NetworkChangeNotifier::RemoveIPAddressObserver(this); + + // Remove from lower layer pools. + for (std::set<LowerLayeredPool*>::iterator it = lower_pools_.begin(); + it != lower_pools_.end(); + ++it) { + (*it)->RemoveHigherLayeredPool(pool_); + } } ClientSocketPoolBaseHelper::CallbackResultPair::CallbackResultPair() @@ -209,46 +224,59 @@ ClientSocketPoolBaseHelper::CallbackResultPair::CallbackResultPair( ClientSocketPoolBaseHelper::CallbackResultPair::~CallbackResultPair() {} -// static -void ClientSocketPoolBaseHelper::InsertRequestIntoQueue( - const Request* r, RequestQueue* pending_requests) { - RequestQueue::iterator it = pending_requests->begin(); - // TODO(mmenke): Should the network stack require requests with - // |ignore_limits| have the highest priority? - while (it != pending_requests->end() && - CompareEffectiveRequestPriority(*r, *(*it)) <= 0) { - ++it; +bool ClientSocketPoolBaseHelper::IsStalled() const { + // If a lower layer pool is stalled, consider |this| stalled as well. + for (std::set<LowerLayeredPool*>::const_iterator it = lower_pools_.begin(); + it != lower_pools_.end(); + ++it) { + if ((*it)->IsStalled()) + return true; + } + + // If fewer than |max_sockets_| are in use, then clearly |this| is not + // stalled. + if ((handed_out_socket_count_ + connecting_socket_count_) < max_sockets_) + return false; + // So in order to be stalled, |this| must be using at least |max_sockets_| AND + // |this| must have a request that is actually stalled on the global socket + // limit. To find such a request, look for a group that has more requests + // than jobs AND where the number of sockets is less than + // |max_sockets_per_group_|. (If the number of sockets is equal to + // |max_sockets_per_group_|, then the request is stalled on the group limit, + // which does not count.) + for (GroupMap::const_iterator it = group_map_.begin(); + it != group_map_.end(); ++it) { + if (it->second->IsStalledOnPoolMaxSockets(max_sockets_per_group_)) + return true; } - pending_requests->insert(it, r); + return false; } -// static -const ClientSocketPoolBaseHelper::Request* -ClientSocketPoolBaseHelper::RemoveRequestFromQueue( - const RequestQueue::iterator& it, Group* group) { - const Request* req = *it; - group->mutable_pending_requests()->erase(it); - // If there are no more requests, we kill the backup timer. - if (group->pending_requests().empty()) - group->CleanupBackupJob(); - return req; +void ClientSocketPoolBaseHelper::AddLowerLayeredPool( + LowerLayeredPool* lower_pool) { + DCHECK(pool_); + CHECK(!ContainsKey(lower_pools_, lower_pool)); + lower_pools_.insert(lower_pool); + lower_pool->AddHigherLayeredPool(pool_); } -void ClientSocketPoolBaseHelper::AddLayeredPool(LayeredPool* pool) { - CHECK(pool); - CHECK(!ContainsKey(higher_layer_pools_, pool)); - higher_layer_pools_.insert(pool); +void ClientSocketPoolBaseHelper::AddHigherLayeredPool( + HigherLayeredPool* higher_pool) { + CHECK(higher_pool); + CHECK(!ContainsKey(higher_pools_, higher_pool)); + higher_pools_.insert(higher_pool); } -void ClientSocketPoolBaseHelper::RemoveLayeredPool(LayeredPool* pool) { - CHECK(pool); - CHECK(ContainsKey(higher_layer_pools_, pool)); - higher_layer_pools_.erase(pool); +void ClientSocketPoolBaseHelper::RemoveHigherLayeredPool( + HigherLayeredPool* higher_pool) { + CHECK(higher_pool); + CHECK(ContainsKey(higher_pools_, higher_pool)); + higher_pools_.erase(higher_pool); } int ClientSocketPoolBaseHelper::RequestSocket( const std::string& group_name, - const Request* request) { + scoped_ptr<const Request> request) { CHECK(!request->callback().is_null()); CHECK(request->handle()); @@ -259,13 +287,13 @@ int ClientSocketPoolBaseHelper::RequestSocket( request->net_log().BeginEvent(NetLog::TYPE_SOCKET_POOL); Group* group = GetOrCreateGroup(group_name); - int rv = RequestSocketInternal(group_name, request); + int rv = RequestSocketInternal(group_name, *request); if (rv != ERR_IO_PENDING) { request->net_log().EndEventWithNetErrorCode(NetLog::TYPE_SOCKET_POOL, rv); CHECK(!request->handle()->is_initialized()); - delete request; + request.reset(); } else { - InsertRequestIntoQueue(request, group->mutable_pending_requests()); + group->InsertPendingRequest(request.Pass()); // Have to do this asynchronously, as closing sockets in higher level pools // call back in to |this|, which will cause all sorts of fun and exciting // re-entrancy issues if the socket pool is doing something else at the @@ -309,7 +337,7 @@ void ClientSocketPoolBaseHelper::RequestSockets( for (int num_iterations_left = num_sockets; group->NumActiveSocketSlots() < num_sockets && num_iterations_left > 0 ; num_iterations_left--) { - rv = RequestSocketInternal(group_name, &request); + rv = RequestSocketInternal(group_name, request); if (rv < 0 && rv != ERR_IO_PENDING) { // We're encountering a synchronous error. Give up. if (!ContainsKey(group_map_, group_name)) @@ -336,12 +364,12 @@ void ClientSocketPoolBaseHelper::RequestSockets( int ClientSocketPoolBaseHelper::RequestSocketInternal( const std::string& group_name, - const Request* request) { - ClientSocketHandle* const handle = request->handle(); + const Request& request) { + ClientSocketHandle* const handle = request.handle(); const bool preconnecting = !handle; Group* group = GetOrCreateGroup(group_name); - if (!(request->flags() & NO_IDLE_SOCKETS)) { + if (!(request.flags() & NO_IDLE_SOCKETS)) { // Try to reuse a socket. if (AssignIdleSocketToRequest(request, group)) return OK; @@ -355,17 +383,17 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal( // Can we make another active socket now? if (!group->HasAvailableSocketSlot(max_sockets_per_group_) && - !request->ignore_limits()) { + !request.ignore_limits()) { // TODO(willchan): Consider whether or not we need to close a socket in a // higher layered group. I don't think this makes sense since we would just // reuse that socket then if we needed one and wouldn't make it down to this // layer. - request->net_log().AddEvent( + request.net_log().AddEvent( NetLog::TYPE_SOCKET_POOL_STALLED_MAX_SOCKETS_PER_GROUP); return ERR_IO_PENDING; } - if (ReachedMaxSocketsLimit() && !request->ignore_limits()) { + if (ReachedMaxSocketsLimit() && !request.ignore_limits()) { // NOTE(mmenke): Wonder if we really need different code for each case // here. Only reason for them now seems to be preconnects. if (idle_socket_count() > 0) { @@ -378,7 +406,7 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal( } else { // We could check if we really have a stalled group here, but it requires // a scan of all groups, so just flip a flag here, and do the check later. - request->net_log().AddEvent(NetLog::TYPE_SOCKET_POOL_STALLED_MAX_SOCKETS); + request.net_log().AddEvent(NetLog::TYPE_SOCKET_POOL_STALLED_MAX_SOCKETS); return ERR_IO_PENDING; } } @@ -386,17 +414,17 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal( // We couldn't find a socket to reuse, and there's space to allocate one, // so allocate and connect a new one. scoped_ptr<ConnectJob> connect_job( - connect_job_factory_->NewConnectJob(group_name, *request, this)); + connect_job_factory_->NewConnectJob(group_name, request, this)); int rv = connect_job->Connect(); if (rv == OK) { LogBoundConnectJobToRequest(connect_job->net_log().source(), request); if (!preconnecting) { - HandOutSocket(connect_job->ReleaseSocket(), false /* not reused */, + HandOutSocket(connect_job->PassSocket(), false /* not reused */, connect_job->connect_timing(), handle, base::TimeDelta(), - group, request->net_log()); + group, request.net_log()); } else { - AddIdleSocket(connect_job->ReleaseSocket(), group); + AddIdleSocket(connect_job->PassSocket(), group); } } else if (rv == ERR_IO_PENDING) { // If we don't have any sockets in this group, set a timer for potentially @@ -409,19 +437,19 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal( connecting_socket_count_++; - group->AddJob(connect_job.release(), preconnecting); + group->AddJob(connect_job.Pass(), preconnecting); } else { LogBoundConnectJobToRequest(connect_job->net_log().source(), request); - StreamSocket* error_socket = NULL; + scoped_ptr<StreamSocket> error_socket; if (!preconnecting) { DCHECK(handle); connect_job->GetAdditionalErrorState(handle); - error_socket = connect_job->ReleaseSocket(); + error_socket = connect_job->PassSocket(); } if (error_socket) { - HandOutSocket(error_socket, false /* not reused */, + HandOutSocket(error_socket.Pass(), false /* not reused */, connect_job->connect_timing(), handle, base::TimeDelta(), - group, request->net_log()); + group, request.net_log()); } else if (group->IsEmpty()) { RemoveGroup(group_name); } @@ -431,7 +459,7 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal( } bool ClientSocketPoolBaseHelper::AssignIdleSocketToRequest( - const Request* request, Group* group) { + const Request& request, Group* group) { std::list<IdleSocket>* idle_sockets = group->mutable_idle_sockets(); std::list<IdleSocket>::iterator idle_socket_it = idle_sockets->end(); @@ -469,13 +497,13 @@ bool ClientSocketPoolBaseHelper::AssignIdleSocketToRequest( IdleSocket idle_socket = *idle_socket_it; idle_sockets->erase(idle_socket_it); HandOutSocket( - idle_socket.socket, + scoped_ptr<StreamSocket>(idle_socket.socket), idle_socket.socket->WasEverUsed(), LoadTimingInfo::ConnectTiming(), - request->handle(), + request.handle(), idle_time, group, - request->net_log()); + request.net_log()); return true; } @@ -484,9 +512,9 @@ bool ClientSocketPoolBaseHelper::AssignIdleSocketToRequest( // static void ClientSocketPoolBaseHelper::LogBoundConnectJobToRequest( - const NetLog::Source& connect_job_source, const Request* request) { - request->net_log().AddEvent(NetLog::TYPE_SOCKET_POOL_BOUND_TO_CONNECT_JOB, - connect_job_source.ToEventParametersCallback()); + const NetLog::Source& connect_job_source, const Request& request) { + request.net_log().AddEvent(NetLog::TYPE_SOCKET_POOL_BOUND_TO_CONNECT_JOB, + connect_job_source.ToEventParametersCallback()); } void ClientSocketPoolBaseHelper::CancelRequest( @@ -495,11 +523,11 @@ void ClientSocketPoolBaseHelper::CancelRequest( if (callback_it != pending_callback_map_.end()) { int result = callback_it->second.result; pending_callback_map_.erase(callback_it); - StreamSocket* socket = handle->release_socket(); + scoped_ptr<StreamSocket> socket = handle->PassSocket(); if (socket) { if (result != OK) socket->Disconnect(); - ReleaseSocket(handle->group_name(), socket, handle->id()); + ReleaseSocket(handle->group_name(), socket.Pass(), handle->id()); } return; } @@ -509,21 +537,18 @@ void ClientSocketPoolBaseHelper::CancelRequest( Group* group = GetOrCreateGroup(group_name); // Search pending_requests for matching handle. - RequestQueue::iterator it = group->mutable_pending_requests()->begin(); - for (; it != group->pending_requests().end(); ++it) { - if ((*it)->handle() == handle) { - scoped_ptr<const Request> req(RemoveRequestFromQueue(it, group)); - req->net_log().AddEvent(NetLog::TYPE_CANCELLED); - req->net_log().EndEvent(NetLog::TYPE_SOCKET_POOL); - - // We let the job run, unless we're at the socket limit and there is - // not another request waiting on the job. - if (group->jobs().size() > group->pending_requests().size() && - ReachedMaxSocketsLimit()) { - RemoveConnectJob(*group->jobs().begin(), group); - CheckForStalledSocketGroups(); - } - break; + scoped_ptr<const Request> request = + group->FindAndRemovePendingRequest(handle); + if (request) { + request->net_log().AddEvent(NetLog::TYPE_CANCELLED); + request->net_log().EndEvent(NetLog::TYPE_SOCKET_POOL); + + // We let the job run, unless we're at the socket limit and there is + // not another request waiting on the job. + if (group->jobs().size() > group->pending_request_count() && + ReachedMaxSocketsLimit()) { + RemoveConnectJob(*group->jobs().begin(), group); + CheckForStalledSocketGroups(); } } } @@ -560,16 +585,7 @@ LoadState ClientSocketPoolBaseHelper::GetLoadState( // Can't use operator[] since it is non-const. const Group& group = *group_map_.find(group_name)->second; - // Search the first group.jobs().size() |pending_requests| for |handle|. - // If it's farther back in the deque than that, it doesn't have a - // corresponding ConnectJob. - size_t connect_jobs = group.jobs().size(); - RequestQueue::const_iterator it = group.pending_requests().begin(); - for (size_t i = 0; it != group.pending_requests().end() && i < connect_jobs; - ++it, ++i) { - if ((*it)->handle() != handle) - continue; - + if (group.HasConnectJobForHandle(handle)) { // Just return the state of the farthest along ConnectJob for the first // group.jobs().size() pending requests. LoadState max_state = LOAD_STATE_IDLE; @@ -607,8 +623,8 @@ base::DictionaryValue* ClientSocketPoolBaseHelper::GetInfoAsValue( base::DictionaryValue* group_dict = new base::DictionaryValue(); group_dict->SetInteger("pending_request_count", - group->pending_requests().size()); - if (!group->pending_requests().empty()) { + group->pending_request_count()); + if (group->has_pending_requests()) { group_dict->SetInteger("top_pending_priority", group->TopPendingPriority()); } @@ -756,7 +772,7 @@ void ClientSocketPoolBaseHelper::StartIdleSocketTimer() { } void ClientSocketPoolBaseHelper::ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) { GroupMap::iterator i = group_map_.find(group_name); CHECK(i != group_map_.end()); @@ -773,10 +789,10 @@ void ClientSocketPoolBaseHelper::ReleaseSocket(const std::string& group_name, id == pool_generation_number_; if (can_reuse) { // Add it to the idle list. - AddIdleSocket(socket, group); + AddIdleSocket(socket.Pass(), group); OnAvailableSocketSlot(group_name, group); } else { - delete socket; + socket.reset(); } CheckForStalledSocketGroups(); @@ -786,8 +802,18 @@ void ClientSocketPoolBaseHelper::CheckForStalledSocketGroups() { // If we have idle sockets, see if we can give one to the top-stalled group. std::string top_group_name; Group* top_group = NULL; - if (!FindTopStalledGroup(&top_group, &top_group_name)) + if (!FindTopStalledGroup(&top_group, &top_group_name)) { + // There may still be a stalled group in a lower level pool. + for (std::set<LowerLayeredPool*>::iterator it = lower_pools_.begin(); + it != lower_pools_.end(); + ++it) { + if ((*it)->IsStalled()) { + CloseOneIdleSocket(); + break; + } + } return; + } if (ReachedMaxSocketsLimit()) { if (idle_socket_count() > 0) { @@ -820,8 +846,7 @@ bool ClientSocketPoolBaseHelper::FindTopStalledGroup( for (GroupMap::const_iterator i = group_map_.begin(); i != group_map_.end(); ++i) { Group* curr_group = i->second; - const RequestQueue& queue = curr_group->pending_requests(); - if (queue.empty()) + if (!curr_group->has_pending_requests()) continue; if (curr_group->IsStalledOnPoolMaxSockets(max_sockets_per_group_)) { if (!group) @@ -854,27 +879,29 @@ void ClientSocketPoolBaseHelper::OnConnectJobComplete( CHECK(group_it != group_map_.end()); Group* group = group_it->second; - scoped_ptr<StreamSocket> socket(job->ReleaseSocket()); + scoped_ptr<StreamSocket> socket = job->PassSocket(); // Copies of these are needed because |job| may be deleted before they are // accessed. BoundNetLog job_log = job->net_log(); LoadTimingInfo::ConnectTiming connect_timing = job->connect_timing(); + // RemoveConnectJob(job, _) must be called by all branches below; + // otherwise, |job| will be leaked. + if (result == OK) { DCHECK(socket.get()); RemoveConnectJob(job, group); - if (!group->pending_requests().empty()) { - scoped_ptr<const Request> r(RemoveRequestFromQueue( - group->mutable_pending_requests()->begin(), group)); - LogBoundConnectJobToRequest(job_log.source(), r.get()); + scoped_ptr<const Request> request = group->PopNextPendingRequest(); + if (request) { + LogBoundConnectJobToRequest(job_log.source(), *request); HandOutSocket( - socket.release(), false /* unused socket */, connect_timing, - r->handle(), base::TimeDelta(), group, r->net_log()); - r->net_log().EndEvent(NetLog::TYPE_SOCKET_POOL); - InvokeUserCallbackLater(r->handle(), r->callback(), result); + socket.Pass(), false /* unused socket */, connect_timing, + request->handle(), base::TimeDelta(), group, request->net_log()); + request->net_log().EndEvent(NetLog::TYPE_SOCKET_POOL); + InvokeUserCallbackLater(request->handle(), request->callback(), result); } else { - AddIdleSocket(socket.release(), group); + AddIdleSocket(socket.Pass(), group); OnAvailableSocketSlot(group_name, group); CheckForStalledSocketGroups(); } @@ -882,20 +909,20 @@ void ClientSocketPoolBaseHelper::OnConnectJobComplete( // If we got a socket, it must contain error information so pass that // up so that the caller can retrieve it. bool handed_out_socket = false; - if (!group->pending_requests().empty()) { - scoped_ptr<const Request> r(RemoveRequestFromQueue( - group->mutable_pending_requests()->begin(), group)); - LogBoundConnectJobToRequest(job_log.source(), r.get()); - job->GetAdditionalErrorState(r->handle()); + scoped_ptr<const Request> request = group->PopNextPendingRequest(); + if (request) { + LogBoundConnectJobToRequest(job_log.source(), *request); + job->GetAdditionalErrorState(request->handle()); RemoveConnectJob(job, group); if (socket.get()) { handed_out_socket = true; - HandOutSocket(socket.release(), false /* unused socket */, - connect_timing, r->handle(), base::TimeDelta(), group, - r->net_log()); + HandOutSocket(socket.Pass(), false /* unused socket */, + connect_timing, request->handle(), base::TimeDelta(), + group, request->net_log()); } - r->net_log().EndEventWithNetErrorCode(NetLog::TYPE_SOCKET_POOL, result); - InvokeUserCallbackLater(r->handle(), r->callback(), result); + request->net_log().EndEventWithNetErrorCode( + NetLog::TYPE_SOCKET_POOL, result); + InvokeUserCallbackLater(request->handle(), request->callback(), result); } else { RemoveConnectJob(job, group); } @@ -917,59 +944,38 @@ void ClientSocketPoolBaseHelper::FlushWithError(int error) { CancelAllRequestsWithError(error); } -bool ClientSocketPoolBaseHelper::IsStalled() const { - // If we are not using |max_sockets_|, then clearly we are not stalled - if ((handed_out_socket_count_ + connecting_socket_count_) < max_sockets_) - return false; - // So in order to be stalled we need to be using |max_sockets_| AND - // we need to have a request that is actually stalled on the global - // socket limit. To find such a request, we look for a group that - // a has more requests that jobs AND where the number of jobs is less - // than |max_sockets_per_group_|. (If the number of jobs is equal to - // |max_sockets_per_group_|, then the request is stalled on the group, - // which does not count.) - for (GroupMap::const_iterator it = group_map_.begin(); - it != group_map_.end(); ++it) { - if (it->second->IsStalledOnPoolMaxSockets(max_sockets_per_group_)) - return true; - } - return false; -} - void ClientSocketPoolBaseHelper::RemoveConnectJob(ConnectJob* job, Group* group) { CHECK_GT(connecting_socket_count_, 0); connecting_socket_count_--; DCHECK(group); - DCHECK(ContainsKey(group->jobs(), job)); group->RemoveJob(job); // If we've got no more jobs for this group, then we no longer need a // backup job either. if (group->jobs().empty()) group->CleanupBackupJob(); - - DCHECK(job); - delete job; } void ClientSocketPoolBaseHelper::OnAvailableSocketSlot( const std::string& group_name, Group* group) { DCHECK(ContainsKey(group_map_, group_name)); - if (group->IsEmpty()) + if (group->IsEmpty()) { RemoveGroup(group_name); - else if (!group->pending_requests().empty()) + } else if (group->has_pending_requests()) { ProcessPendingRequest(group_name, group); + } } void ClientSocketPoolBaseHelper::ProcessPendingRequest( const std::string& group_name, Group* group) { - int rv = RequestSocketInternal(group_name, - *group->pending_requests().begin()); + const Request* next_request = group->GetNextPendingRequest(); + DCHECK(next_request); + int rv = RequestSocketInternal(group_name, *next_request); if (rv != ERR_IO_PENDING) { - scoped_ptr<const Request> request(RemoveRequestFromQueue( - group->mutable_pending_requests()->begin(), group)); + scoped_ptr<const Request> request = group->PopNextPendingRequest(); + DCHECK(request); if (group->IsEmpty()) RemoveGroup(group_name); @@ -979,7 +985,7 @@ void ClientSocketPoolBaseHelper::ProcessPendingRequest( } void ClientSocketPoolBaseHelper::HandOutSocket( - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, bool reused, const LoadTimingInfo::ConnectTiming& connect_timing, ClientSocketHandle* handle, @@ -987,7 +993,7 @@ void ClientSocketPoolBaseHelper::HandOutSocket( Group* group, const BoundNetLog& net_log) { DCHECK(socket); - handle->set_socket(socket); + handle->SetSocket(socket.Pass()); handle->set_is_reused(reused); handle->set_idle_time(idle_time); handle->set_pool_id(pool_generation_number_); @@ -1000,18 +1006,20 @@ void ClientSocketPoolBaseHelper::HandOutSocket( "idle_ms", static_cast<int>(idle_time.InMilliseconds()))); } - net_log.AddEvent(NetLog::TYPE_SOCKET_POOL_BOUND_TO_SOCKET, - socket->NetLog().source().ToEventParametersCallback()); + net_log.AddEvent( + NetLog::TYPE_SOCKET_POOL_BOUND_TO_SOCKET, + handle->socket()->NetLog().source().ToEventParametersCallback()); handed_out_socket_count_++; group->IncrementActiveSocketCount(); } void ClientSocketPoolBaseHelper::AddIdleSocket( - StreamSocket* socket, Group* group) { + scoped_ptr<StreamSocket> socket, + Group* group) { DCHECK(socket); IdleSocket idle_socket; - idle_socket.socket = socket; + idle_socket.socket = socket.release(); idle_socket.start_time = base::TimeTicks::Now(); group->mutable_idle_sockets()->push_back(idle_socket); @@ -1041,13 +1049,11 @@ void ClientSocketPoolBaseHelper::CancelAllRequestsWithError(int error) { for (GroupMap::iterator i = group_map_.begin(); i != group_map_.end();) { Group* group = i->second; - RequestQueue pending_requests; - pending_requests.swap(*group->mutable_pending_requests()); - for (RequestQueue::iterator it2 = pending_requests.begin(); - it2 != pending_requests.end(); ++it2) { - scoped_ptr<const Request> request(*it2); - InvokeUserCallbackLater( - request->handle(), request->callback(), error); + while (true) { + scoped_ptr<const Request> request = group->PopNextPendingRequest(); + if (!request) + break; + InvokeUserCallbackLater(request->handle(), request->callback(), error); } // Delete group if no longer needed. @@ -1103,12 +1109,12 @@ bool ClientSocketPoolBaseHelper::CloseOneIdleSocketExceptInGroup( return false; } -bool ClientSocketPoolBaseHelper::CloseOneIdleConnectionInLayeredPool() { +bool ClientSocketPoolBaseHelper::CloseOneIdleConnectionInHigherLayeredPool() { // This pool doesn't have any idle sockets. It's possible that a pool at a // higher layer is holding one of this sockets active, but it's actually idle. // Query the higher layers. - for (std::set<LayeredPool*>::const_iterator it = higher_layer_pools_.begin(); - it != higher_layer_pools_.end(); ++it) { + for (std::set<HigherLayeredPool*>::const_iterator it = higher_pools_.begin(); + it != higher_pools_.end(); ++it) { if ((*it)->CloseOneIdleConnection()) return true; } @@ -1144,7 +1150,7 @@ void ClientSocketPoolBaseHelper::TryToCloseSocketsInLayeredPools() { while (IsStalled()) { // Closing a socket will result in calling back into |this| to use the freed // socket slot, so nothing else is needed. - if (!CloseOneIdleConnectionInLayeredPool()) + if (!CloseOneIdleConnectionInHigherLayeredPool()) return; } } @@ -1182,19 +1188,25 @@ bool ClientSocketPoolBaseHelper::Group::TryToUseUnassignedConnectJob() { return true; } -void ClientSocketPoolBaseHelper::Group::AddJob(ConnectJob* job, +void ClientSocketPoolBaseHelper::Group::AddJob(scoped_ptr<ConnectJob> job, bool is_preconnect) { SanityCheck(); if (is_preconnect) ++unassigned_job_count_; - jobs_.insert(job); + jobs_.insert(job.release()); } void ClientSocketPoolBaseHelper::Group::RemoveJob(ConnectJob* job) { + scoped_ptr<ConnectJob> owned_job(job); SanityCheck(); - jobs_.erase(job); + std::set<ConnectJob*>::iterator it = jobs_.find(job); + if (it != jobs_.end()) { + jobs_.erase(it); + } else { + NOTREACHED(); + } size_t job_count = jobs_.size(); if (job_count < unassigned_job_count_) unassigned_job_count_ = job_count; @@ -1222,15 +1234,17 @@ void ClientSocketPoolBaseHelper::Group::OnBackupSocketTimerFired( if (pending_requests_.empty()) return; - ConnectJob* backup_job = pool->connect_job_factory_->NewConnectJob( - group_name, **pending_requests_.begin(), pool); + scoped_ptr<ConnectJob> backup_job = + pool->connect_job_factory_->NewConnectJob( + group_name, **pending_requests_.begin(), pool); backup_job->net_log().AddEvent(NetLog::TYPE_SOCKET_BACKUP_CREATED); SIMPLE_STATS_COUNTER("socket.backup_created"); int rv = backup_job->Connect(); pool->connecting_socket_count_++; - AddJob(backup_job, false); + ConnectJob* raw_backup_job = backup_job.get(); + AddJob(backup_job.Pass(), false); if (rv != ERR_IO_PENDING) - pool->OnConnectJobComplete(rv, backup_job); + pool->OnConnectJobComplete(rv, raw_backup_job); } void ClientSocketPoolBaseHelper::Group::SanityCheck() { @@ -1248,6 +1262,68 @@ void ClientSocketPoolBaseHelper::Group::RemoveAllJobs() { weak_factory_.InvalidateWeakPtrs(); } +const ClientSocketPoolBaseHelper::Request* +ClientSocketPoolBaseHelper::Group::GetNextPendingRequest() const { + return pending_requests_.empty() ? NULL : *pending_requests_.begin(); +} + +bool ClientSocketPoolBaseHelper::Group::HasConnectJobForHandle( + const ClientSocketHandle* handle) const { + // Search the first |jobs_.size()| pending requests for |handle|. + // If it's farther back in the deque than that, it doesn't have a + // corresponding ConnectJob. + size_t i = 0; + for (RequestQueue::const_iterator it = pending_requests_.begin(); + it != pending_requests_.end() && i < jobs_.size(); ++it, ++i) { + if ((*it)->handle() == handle) + return true; + } + return false; +} + +void ClientSocketPoolBaseHelper::Group::InsertPendingRequest( + scoped_ptr<const Request> r) { + RequestQueue::iterator it = pending_requests_.begin(); + // TODO(mmenke): Should the network stack require requests with + // |ignore_limits| have the highest priority? + while (it != pending_requests_.end() && + CompareEffectiveRequestPriority(*r, *(*it)) <= 0) { + ++it; + } + pending_requests_.insert(it, r.release()); +} + +scoped_ptr<const ClientSocketPoolBaseHelper::Request> +ClientSocketPoolBaseHelper::Group::PopNextPendingRequest() { + if (pending_requests_.empty()) + return scoped_ptr<const ClientSocketPoolBaseHelper::Request>(); + return RemovePendingRequest(pending_requests_.begin()); +} + +scoped_ptr<const ClientSocketPoolBaseHelper::Request> +ClientSocketPoolBaseHelper::Group::FindAndRemovePendingRequest( + ClientSocketHandle* handle) { + for (RequestQueue::iterator it = pending_requests_.begin(); + it != pending_requests_.end(); ++it) { + if ((*it)->handle() == handle) { + scoped_ptr<const Request> request = RemovePendingRequest(it); + return request.Pass(); + } + } + return scoped_ptr<const ClientSocketPoolBaseHelper::Request>(); +} + +scoped_ptr<const ClientSocketPoolBaseHelper::Request> +ClientSocketPoolBaseHelper::Group::RemovePendingRequest( + const RequestQueue::iterator& it) { + scoped_ptr<const Request> request(*it); + pending_requests_.erase(it); + // If there are no more requests, kill the backup timer. + if (pending_requests_.empty()) + CleanupBackupJob(); + return request.Pass(); +} + } // namespace internal } // namespace net diff --git a/chromium/net/socket/client_socket_pool_base.h b/chromium/net/socket/client_socket_pool_base.h index 4bf95d7b04a..31ec9bf7b13 100644 --- a/chromium/net/socket/client_socket_pool_base.h +++ b/chromium/net/socket/client_socket_pool_base.h @@ -61,8 +61,11 @@ class NET_EXPORT_PRIVATE ConnectJob { Delegate() {} virtual ~Delegate() {} - // Alerts the delegate that the connection completed. - virtual void OnConnectJobComplete(int result, ConnectJob* job) = 0; + // Alerts the delegate that the connection completed. |job| must + // be destroyed by the delegate. A scoped_ptr<> isn't used because + // the caller of this function doesn't own |job|. + virtual void OnConnectJobComplete(int result, + ConnectJob* job) = 0; private: DISALLOW_COPY_AND_ASSIGN(Delegate); @@ -71,6 +74,7 @@ class NET_EXPORT_PRIVATE ConnectJob { // A |timeout_duration| of 0 corresponds to no timeout. ConnectJob(const std::string& group_name, base::TimeDelta timeout_duration, + RequestPriority priority, Delegate* delegate, const BoundNetLog& net_log); virtual ~ConnectJob(); @@ -79,9 +83,10 @@ class NET_EXPORT_PRIVATE ConnectJob { const std::string& group_name() const { return group_name_; } const BoundNetLog& net_log() { return net_log_; } - // Releases |socket_| to the client. On connection error, this should return - // NULL. - StreamSocket* ReleaseSocket() { return socket_.release(); } + // Releases ownership of the underlying socket to the caller. + // Returns the released socket, or NULL if there was a connection + // error. + scoped_ptr<StreamSocket> PassSocket(); // Begins connecting the socket. Returns OK on success, ERR_IO_PENDING if it // cannot complete synchronously without blocking, or another net error code @@ -105,7 +110,8 @@ class NET_EXPORT_PRIVATE ConnectJob { const BoundNetLog& net_log() const { return net_log_; } protected: - void set_socket(StreamSocket* socket); + RequestPriority priority() const { return priority_; } + void SetSocket(scoped_ptr<StreamSocket> socket); StreamSocket* socket() { return socket_.get(); } void NotifyDelegateOfCompletion(int rv); void ResetTimer(base::TimeDelta remainingTime); @@ -124,6 +130,8 @@ class NET_EXPORT_PRIVATE ConnectJob { const std::string group_name_; const base::TimeDelta timeout_duration_; + // TODO(akalin): Support reprioritization. + const RequestPriority priority_; // Timer to abort jobs that take too long. base::OneShotTimer<ConnectJob> timer_; Delegate* delegate_; @@ -175,6 +183,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper private: ClientSocketHandle* const handle_; CompletionCallback callback_; + // TODO(akalin): Support reprioritization. const RequestPriority priority_; bool ignore_limits_; const Flags flags_; @@ -188,7 +197,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper ConnectJobFactory() {} virtual ~ConnectJobFactory() {} - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const Request& request, ConnectJob::Delegate* delegate) const = 0; @@ -200,6 +209,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper }; ClientSocketPoolBaseHelper( + HigherLayeredPool* pool, int max_sockets, int max_sockets_per_group, base::TimeDelta unused_idle_socket_timeout, @@ -208,15 +218,21 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper virtual ~ClientSocketPoolBaseHelper(); - // Adds/Removes layered pools. It is expected in the destructor that no - // layered pools remain. - void AddLayeredPool(LayeredPool* pool); - void RemoveLayeredPool(LayeredPool* pool); + // Adds a lower layered pool to |this|, and adds |this| as a higher layered + // pool on top of |lower_pool|. + void AddLowerLayeredPool(LowerLayeredPool* lower_pool); + + // See LowerLayeredPool::IsStalled for documentation on this function. + bool IsStalled() const; + + // See LowerLayeredPool for documentation on these functions. It is expected + // in the destructor that no higher layer pools remain. + void AddHigherLayeredPool(HigherLayeredPool* higher_pool); + void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool); // See ClientSocketPool::RequestSocket for documentation on this function. - // ClientSocketPoolBaseHelper takes ownership of |request|, which must be - // heap allocated. - int RequestSocket(const std::string& group_name, const Request* request); + int RequestSocket(const std::string& group_name, + scoped_ptr<const Request> request); // See ClientSocketPool::RequestSocket for documentation on this function. void RequestSockets(const std::string& group_name, @@ -229,15 +245,12 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper // See ClientSocketPool::ReleaseSocket for documentation on this function. void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id); // See ClientSocketPool::FlushWithError for documentation on this function. void FlushWithError(int error); - // See ClientSocketPool::IsStalled for documentation on this function. - bool IsStalled() const; - // See ClientSocketPool::CloseIdleSockets for documentation on this function. void CloseIdleSockets(); @@ -294,8 +307,8 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper // I'm not sure if we hit this situation often. bool CloseOneIdleSocket(); - // Checks layered pools to see if they can close an idle connection. - bool CloseOneIdleConnectionInLayeredPool(); + // Checks higher layered pools to see if they can close an idle connection. + bool CloseOneIdleConnectionInHigherLayeredPool(); // See ClientSocketPool::GetInfoAsValue for documentation on this function. base::DictionaryValue* GetInfoAsValue(const std::string& name, @@ -386,22 +399,55 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper // Otherwise, returns false. bool TryToUseUnassignedConnectJob(); - void AddJob(ConnectJob* job, bool is_preconnect); + void AddJob(scoped_ptr<ConnectJob> job, bool is_preconnect); + // Remove |job| from this group, which must already own |job|. void RemoveJob(ConnectJob* job); void RemoveAllJobs(); + bool has_pending_requests() const { + return !pending_requests_.empty(); + } + + size_t pending_request_count() const { + return pending_requests_.size(); + } + + // Gets (but does not remove) the next pending request. Returns + // NULL if there are no pending requests. + const Request* GetNextPendingRequest() const; + + // Returns true if there is a connect job for |handle|. + bool HasConnectJobForHandle(const ClientSocketHandle* handle) const; + + // Inserts the request into the queue based on priority + // order. Older requests are prioritized over requests of equal + // priority. + void InsertPendingRequest(scoped_ptr<const Request> r); + + // Gets and removes the next pending request. Returns NULL if + // there are no pending requests. + scoped_ptr<const Request> PopNextPendingRequest(); + + // Finds the pending request for |handle| and removes it. Returns + // the removed pending request, or NULL if there was none. + scoped_ptr<const Request> FindAndRemovePendingRequest( + ClientSocketHandle* handle); + void IncrementActiveSocketCount() { active_socket_count_++; } void DecrementActiveSocketCount() { active_socket_count_--; } int unassigned_job_count() const { return unassigned_job_count_; } const std::set<ConnectJob*>& jobs() const { return jobs_; } const std::list<IdleSocket>& idle_sockets() const { return idle_sockets_; } - const RequestQueue& pending_requests() const { return pending_requests_; } int active_socket_count() const { return active_socket_count_; } - RequestQueue* mutable_pending_requests() { return &pending_requests_; } std::list<IdleSocket>* mutable_idle_sockets() { return &idle_sockets_; } private: + // Returns the iterator's pending request after removing it from + // the queue. + scoped_ptr<const Request> RemovePendingRequest( + const RequestQueue::iterator& it); + // Called when the backup socket timer fires. void OnBackupSocketTimerFired( std::string group_name, @@ -443,15 +489,6 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper typedef std::map<const ClientSocketHandle*, CallbackResultPair> PendingCallbackMap; - // Inserts the request into the queue based on order they will receive - // sockets. Sockets which ignore the socket pool limits are first. Then - // requests are sorted by priority, with higher priorities closer to the - // front. Older requests are prioritized over requests of equal priority. - static void InsertRequestIntoQueue(const Request* r, - RequestQueue* pending_requests); - static const Request* RemoveRequestFromQueue(const RequestQueue::iterator& it, - Group* group); - Group* GetOrCreateGroup(const std::string& group_name); void RemoveGroup(const std::string& group_name); void RemoveGroup(GroupMap::iterator it); @@ -475,7 +512,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper CleanupIdleSockets(false); } - // Removes |job| from |connect_job_set_|. Also updates |group| if non-NULL. + // Removes |job| from |group|, which must already own |job|. void RemoveConnectJob(ConnectJob* job, Group* group); // Tries to see if we can handle any more requests for |group|. @@ -485,7 +522,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper void ProcessPendingRequest(const std::string& group_name, Group* group); // Assigns |socket| to |handle| and updates |group|'s counters appropriately. - void HandOutSocket(StreamSocket* socket, + void HandOutSocket(scoped_ptr<StreamSocket> socket, bool reused, const LoadTimingInfo::ConnectTiming& connect_timing, ClientSocketHandle* handle, @@ -494,7 +531,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper const BoundNetLog& net_log); // Adds |socket| to the list of idle sockets for |group|. - void AddIdleSocket(StreamSocket* socket, Group* group); + void AddIdleSocket(scoped_ptr<StreamSocket> socket, Group* group); // Iterates through |group_map_|, canceling all ConnectJobs and deleting // groups if they are no longer needed. @@ -511,14 +548,14 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper // it does not handle logging into NetLog of the queueing status of // |request|. int RequestSocketInternal(const std::string& group_name, - const Request* request); + const Request& request); // Assigns an idle socket for the group to the request. // Returns |true| if an idle socket is available, false otherwise. - bool AssignIdleSocketToRequest(const Request* request, Group* group); + bool AssignIdleSocketToRequest(const Request& request, Group* group); static void LogBoundConnectJobToRequest( - const NetLog::Source& connect_job_source, const Request* request); + const NetLog::Source& connect_job_source, const Request& request); // Same as CloseOneIdleSocket() except it won't close an idle socket in // |group|. If |group| is NULL, it is ignored. Returns true if it closed a @@ -588,7 +625,18 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper // to the pool, we can make sure that they are discarded rather than reused. int pool_generation_number_; - std::set<LayeredPool*> higher_layer_pools_; + // Used to add |this| as a higher layer pool on top of lower layer pools. May + // be NULL if no lower layer pools will be added. + HigherLayeredPool* pool_; + + // Pools that create connections through |this|. |this| will try to close + // their idle sockets when it stalls. Must be empty on destruction. + std::set<HigherLayeredPool*> higher_pools_; + + // Pools that this goes through. Typically there's only one, but not always. + // |this| will check if they're stalled when it has a new idle socket. |this| + // will remove itself from all lower layered pools on destruction. + std::set<LowerLayeredPool*> lower_pools_; base::WeakPtrFactory<ClientSocketPoolBaseHelper> weak_factory_; @@ -624,7 +672,7 @@ class ClientSocketPoolBase { ConnectJobFactory() {} virtual ~ConnectJobFactory() {} - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const Request& request, ConnectJob::Delegate* delegate) const = 0; @@ -642,6 +690,7 @@ class ClientSocketPoolBase { // |used_idle_socket_timeout| specifies how long to leave a previously used // idle socket open before closing it. ClientSocketPoolBase( + HigherLayeredPool* self, int max_sockets, int max_sockets_per_group, ClientSocketPoolHistograms* histograms, @@ -649,19 +698,23 @@ class ClientSocketPoolBase { base::TimeDelta used_idle_socket_timeout, ConnectJobFactory* connect_job_factory) : histograms_(histograms), - helper_(max_sockets, max_sockets_per_group, + helper_(self, max_sockets, max_sockets_per_group, unused_idle_socket_timeout, used_idle_socket_timeout, new ConnectJobFactoryAdaptor(connect_job_factory)) {} virtual ~ClientSocketPoolBase() {} // These member functions simply forward to ClientSocketPoolBaseHelper. - void AddLayeredPool(LayeredPool* pool) { - helper_.AddLayeredPool(pool); + void AddLowerLayeredPool(LowerLayeredPool* lower_pool) { + helper_.AddLowerLayeredPool(lower_pool); } - void RemoveLayeredPool(LayeredPool* pool) { - helper_.RemoveLayeredPool(pool); + void AddHigherLayeredPool(HigherLayeredPool* higher_pool) { + helper_.AddHigherLayeredPool(higher_pool); + } + + void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) { + helper_.RemoveHigherLayeredPool(higher_pool); } // RequestSocket bundles up the parameters into a Request and then forwards to @@ -672,12 +725,15 @@ class ClientSocketPoolBase { ClientSocketHandle* handle, const CompletionCallback& callback, const BoundNetLog& net_log) { - Request* request = + scoped_ptr<const Request> request( new Request(handle, callback, priority, internal::ClientSocketPoolBaseHelper::NORMAL, params->ignore_limits(), - params, net_log); - return helper_.RequestSocket(group_name, request); + params, net_log)); + return helper_.RequestSocket( + group_name, + request.template PassAs< + const internal::ClientSocketPoolBaseHelper::Request>()); } // RequestSockets bundles up the parameters into a Request and then forwards @@ -702,9 +758,10 @@ class ClientSocketPoolBase { return helper_.CancelRequest(group_name, handle); } - void ReleaseSocket(const std::string& group_name, StreamSocket* socket, + void ReleaseSocket(const std::string& group_name, + scoped_ptr<StreamSocket> socket, int id) { - return helper_.ReleaseSocket(group_name, socket, id); + return helper_.ReleaseSocket(group_name, socket.Pass(), id); } void FlushWithError(int error) { helper_.FlushWithError(error); } @@ -765,8 +822,8 @@ class ClientSocketPoolBase { bool CloseOneIdleSocket() { return helper_.CloseOneIdleSocket(); } - bool CloseOneIdleConnectionInLayeredPool() { - return helper_.CloseOneIdleConnectionInLayeredPool(); + bool CloseOneIdleConnectionInHigherLayeredPool() { + return helper_.CloseOneIdleConnectionInHigherLayeredPool(); } private: @@ -785,13 +842,13 @@ class ClientSocketPoolBase { : connect_job_factory_(connect_job_factory) {} virtual ~ConnectJobFactoryAdaptor() {} - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const internal::ClientSocketPoolBaseHelper::Request& request, - ConnectJob::Delegate* delegate) const { - const Request* casted_request = static_cast<const Request*>(&request); + ConnectJob::Delegate* delegate) const OVERRIDE { + const Request& casted_request = static_cast<const Request&>(request); return connect_job_factory_->NewConnectJob( - group_name, *casted_request, delegate); + group_name, casted_request, delegate); } virtual base::TimeDelta ConnectionTimeout() const { diff --git a/chromium/net/socket/client_socket_pool_base_unittest.cc b/chromium/net/socket/client_socket_pool_base_unittest.cc index 5eeda972cff..bbeca2f3e11 100644 --- a/chromium/net/socket/client_socket_pool_base_unittest.cc +++ b/chromium/net/socket/client_socket_pool_base_unittest.cc @@ -30,7 +30,9 @@ #include "net/socket/client_socket_handle.h" #include "net/socket/client_socket_pool_histograms.h" #include "net/socket/socket_test_util.h" +#include "net/socket/ssl_client_socket.h" #include "net/socket/stream_socket.h" +#include "net/udp/datagram_client_socket.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" @@ -189,30 +191,30 @@ class MockClientSocketFactory : public ClientSocketFactory { public: MockClientSocketFactory() : allocation_count_(0) {} - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) OVERRIDE { NOTREACHED(); - return NULL; + return scoped_ptr<DatagramClientSocket>(); } - virtual StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog* /* net_log */, const NetLog::Source& /*source*/) OVERRIDE { allocation_count_++; - return NULL; + return scoped_ptr<StreamSocket>(); } - virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocketHandle* transport_socket, + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) OVERRIDE { NOTIMPLEMENTED(); - return NULL; + return scoped_ptr<SSLClientSocket>(); } virtual void ClearSSLSessionCache() OVERRIDE { @@ -259,7 +261,7 @@ class TestConnectJob : public ConnectJob { ConnectJob::Delegate* delegate, MockClientSocketFactory* client_socket_factory, NetLog* net_log) - : ConnectJob(group_name, timeout_duration, delegate, + : ConnectJob(group_name, timeout_duration, request.priority(), delegate, BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), job_type_(job_type), client_socket_factory_(client_socket_factory), @@ -294,7 +296,8 @@ class TestConnectJob : public ConnectJob { AddressList ignored; client_socket_factory_->CreateTransportClientSocket( ignored, NULL, net::NetLog::Source()); - set_socket(new MockClientSocket(net_log().net_log())); + SetSocket( + scoped_ptr<StreamSocket>(new MockClientSocket(net_log().net_log()))); switch (job_type_) { case kMockJob: return DoConnect(true /* successful */, false /* sync */, @@ -373,7 +376,7 @@ class TestConnectJob : public ConnectJob { return ERR_IO_PENDING; default: NOTREACHED(); - set_socket(NULL); + SetSocket(scoped_ptr<StreamSocket>()); return ERR_FAILED; } } @@ -386,7 +389,7 @@ class TestConnectJob : public ConnectJob { result = ERR_PROXY_AUTH_REQUESTED; } else { result = ERR_CONNECTION_FAILED; - set_socket(NULL); + SetSocket(scoped_ptr<StreamSocket>()); } if (was_async) @@ -430,7 +433,7 @@ class TestConnectJobFactory // ConnectJobFactory implementation. - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const TestClientSocketPoolBase::Request& request, ConnectJob::Delegate* delegate) const OVERRIDE { @@ -440,13 +443,13 @@ class TestConnectJobFactory job_type = job_types_->front(); job_types_->pop_front(); } - return new TestConnectJob(job_type, - group_name, - request, - timeout_duration_, - delegate, - client_socket_factory_, - net_log_); + return scoped_ptr<ConnectJob>(new TestConnectJob(job_type, + group_name, + request, + timeout_duration_, + delegate, + client_socket_factory_, + net_log_)); } virtual base::TimeDelta ConnectionTimeout() const OVERRIDE { @@ -465,6 +468,8 @@ class TestConnectJobFactory class TestClientSocketPool : public ClientSocketPool { public: + typedef TestSocketParams SocketParams; + TestClientSocketPool( int max_sockets, int max_sockets_per_group, @@ -472,7 +477,7 @@ class TestClientSocketPool : public ClientSocketPool { base::TimeDelta unused_idle_socket_timeout, base::TimeDelta used_idle_socket_timeout, TestClientSocketPoolBase::ConnectJobFactory* connect_job_factory) - : base_(max_sockets, max_sockets_per_group, histograms, + : base_(NULL, max_sockets, max_sockets_per_group, histograms, unused_idle_socket_timeout, used_idle_socket_timeout, connect_job_factory) {} @@ -509,9 +514,9 @@ class TestClientSocketPool : public ClientSocketPool { virtual void ReleaseSocket( const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) OVERRIDE { - base_.ReleaseSocket(group_name, socket, id); + base_.ReleaseSocket(group_name, socket.Pass(), id); } virtual void FlushWithError(int error) OVERRIDE { @@ -541,12 +546,13 @@ class TestClientSocketPool : public ClientSocketPool { return base_.GetLoadState(group_name, handle); } - virtual void AddLayeredPool(LayeredPool* pool) OVERRIDE { - base_.AddLayeredPool(pool); + virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE { + base_.AddHigherLayeredPool(higher_pool); } - virtual void RemoveLayeredPool(LayeredPool* pool) OVERRIDE { - base_.RemoveLayeredPool(pool); + virtual void RemoveHigherLayeredPool( + HigherLayeredPool* higher_pool) OVERRIDE { + base_.RemoveHigherLayeredPool(higher_pool); } virtual base::DictionaryValue* GetInfoAsValue( @@ -586,8 +592,8 @@ class TestClientSocketPool : public ClientSocketPool { void EnableConnectBackupJobs() { base_.EnableConnectBackupJobs(); } - bool CloseOneIdleConnectionInLayeredPool() { - return base_.CloseOneIdleConnectionInLayeredPool(); + bool CloseOneIdleConnectionInHigherLayeredPool() { + return base_.CloseOneIdleConnectionInHigherLayeredPool(); } private: @@ -598,8 +604,6 @@ class TestClientSocketPool : public ClientSocketPool { } // namespace -REGISTER_SOCKET_PARAMS_FOR_POOL(TestClientSocketPool, TestSocketParams); - namespace { void MockClientSocketFactory::SignalJobs() { @@ -630,10 +634,10 @@ class TestConnectJobDelegate : public ConnectJob::Delegate { virtual void OnConnectJobComplete(int result, ConnectJob* job) OVERRIDE { result_ = result; - scoped_ptr<StreamSocket> socket(job->ReleaseSocket()); + scoped_ptr<ConnectJob> owned_job(job); + scoped_ptr<StreamSocket> socket = owned_job->PassSocket(); // socket.get() should be NULL iff result != OK - EXPECT_EQ(socket.get() == NULL, result != OK); - delete job; + EXPECT_EQ(socket == NULL, result != OK); have_result_ = true; if (waiting_for_result_) base::MessageLoop::current()->Quit(); @@ -702,9 +706,8 @@ class ClientSocketPoolBaseTest : public testing::Test { const std::string& group_name, RequestPriority priority, const scoped_refptr<TestSocketParams>& params) { - return test_base_.StartRequestUsingPool< - TestClientSocketPool, TestSocketParams>( - pool_.get(), group_name, priority, params); + return test_base_.StartRequestUsingPool( + pool_.get(), group_name, priority, params); } int StartRequest(const std::string& group_name, RequestPriority priority) { @@ -3716,7 +3719,7 @@ TEST_F(ClientSocketPoolBaseTest, PreconnectWithBackupJob) { EXPECT_EQ(1, pool_->NumActiveSocketsInGroup("a")); } -class MockLayeredPool : public LayeredPool { +class MockLayeredPool : public HigherLayeredPool { public: MockLayeredPool(TestClientSocketPool* pool, const std::string& group_name) @@ -3724,11 +3727,11 @@ class MockLayeredPool : public LayeredPool { params_(new TestSocketParams), group_name_(group_name), can_release_connection_(true) { - pool_->AddLayeredPool(this); + pool_->AddHigherLayeredPool(this); } ~MockLayeredPool() { - pool_->RemoveLayeredPool(this); + pool_->RemoveHigherLayeredPool(this); } int RequestSocket(TestClientSocketPool* pool) { @@ -3774,7 +3777,7 @@ TEST_F(ClientSocketPoolBaseTest, FailToCloseIdleSocketsNotHeldByLayeredPool) { EXPECT_EQ(OK, mock_layered_pool.RequestSocket(pool_.get())); EXPECT_CALL(mock_layered_pool, CloseOneIdleConnection()) .WillOnce(Return(false)); - EXPECT_FALSE(pool_->CloseOneIdleConnectionInLayeredPool()); + EXPECT_FALSE(pool_->CloseOneIdleConnectionInHigherLayeredPool()); } TEST_F(ClientSocketPoolBaseTest, ForciblyCloseIdleSocketsHeldByLayeredPool) { @@ -3786,7 +3789,7 @@ TEST_F(ClientSocketPoolBaseTest, ForciblyCloseIdleSocketsHeldByLayeredPool) { EXPECT_CALL(mock_layered_pool, CloseOneIdleConnection()) .WillOnce(Invoke(&mock_layered_pool, &MockLayeredPool::ReleaseOneConnection)); - EXPECT_TRUE(pool_->CloseOneIdleConnectionInLayeredPool()); + EXPECT_TRUE(pool_->CloseOneIdleConnectionInHigherLayeredPool()); } // Tests the basic case of closing an idle socket in a higher layered pool when diff --git a/chromium/net/socket/client_socket_pool_manager.cc b/chromium/net/socket/client_socket_pool_manager.cc index 71496d28646..b37d2d1949c 100644 --- a/chromium/net/socket/client_socket_pool_manager.cc +++ b/chromium/net/socket/client_socket_pool_manager.cc @@ -158,7 +158,6 @@ int InitSocketPoolHelper(const GURL& request_url, bool ignore_limits = (request_load_flags & LOAD_IGNORE_LIMITS) != 0; if (proxy_info.is_direct()) { tcp_params = new TransportSocketParams(origin_host_port, - request_priority, disable_resolver_cache, ignore_limits, resolution_callback); @@ -167,7 +166,6 @@ int InitSocketPoolHelper(const GURL& request_url, proxy_host_port.reset(new HostPortPair(proxy_server.host_port_pair())); scoped_refptr<TransportSocketParams> proxy_tcp_params( new TransportSocketParams(*proxy_host_port, - request_priority, disable_resolver_cache, ignore_limits, resolution_callback)); @@ -182,7 +180,6 @@ int InitSocketPoolHelper(const GURL& request_url, ssl_params = new SSLSocketParams(proxy_tcp_params, NULL, NULL, - ProxyServer::SCHEME_DIRECT, *proxy_host_port.get(), ssl_config_for_proxy, kPrivacyModeDisabled, @@ -214,8 +211,7 @@ int InitSocketPoolHelper(const GURL& request_url, socks_params = new SOCKSSocketParams(proxy_tcp_params, socks_version == '5', - origin_host_port, - request_priority); + origin_host_port); } } @@ -229,7 +225,6 @@ int InitSocketPoolHelper(const GURL& request_url, new SSLSocketParams(tcp_params, socks_params, http_proxy_params, - proxy_info.proxy_server().scheme(), origin_host_port, ssl_config_for_origin, privacy_mode, diff --git a/chromium/net/socket/deterministic_socket_data_unittest.cc b/chromium/net/socket/deterministic_socket_data_unittest.cc index eba01b5e9cc..c51427e25a7 100644 --- a/chromium/net/socket/deterministic_socket_data_unittest.cc +++ b/chromium/net/socket/deterministic_socket_data_unittest.cc @@ -72,7 +72,6 @@ DeterministicSocketDataTest::DeterministicSocketDataTest() connect_data_(SYNCHRONOUS, OK), endpoint_("www.google.com", 443), tcp_params_(new TransportSocketParams(endpoint_, - LOWEST, false, false, OnHostResolutionCallback())), diff --git a/chromium/net/socket/nss_ssl_util.cc b/chromium/net/socket/nss_ssl_util.cc index be33ac5add0..7e3aee430c4 100644 --- a/chromium/net/socket/nss_ssl_util.cc +++ b/chromium/net/socket/nss_ssl_util.cc @@ -58,12 +58,13 @@ class NSSSSLInitSingleton { enabled = false; // Trim the list of cipher suites in order to keep the size of the - // ClientHello down. DSS, ECDH, CAMELLIA, SEED and ECC+3DES cipher - // suites are disabled. + // ClientHello down. DSS, ECDH, CAMELLIA, SEED, ECC+3DES, and + // HMAC-SHA256 cipher suites are disabled. if (info.symCipher == ssl_calg_camellia || info.symCipher == ssl_calg_seed || (info.symCipher == ssl_calg_3des && info.keaType != ssl_kea_rsa) || info.authAlgorithm == ssl_auth_dsa || + info.macAlgorithm == ssl_hmac_sha256 || info.nonStandard || strcmp(info.keaTypeName, "ECDH") == 0) { enabled = false; @@ -232,6 +233,10 @@ int MapNSSError(PRErrorCode err) { case SEC_ERROR_BAD_DER: case SEC_ERROR_EXTRA_INPUT: return ERR_SSL_BAD_PEER_PUBLIC_KEY; + // During renegotiation, the server presented a different certificate than + // was used earlier. + case SSL_ERROR_WRONG_CERTIFICATE: + return ERR_SSL_SERVER_CERT_CHANGED; default: { if (IS_SSL_ERROR(err)) { diff --git a/chromium/net/socket/socket_descriptor.cc b/chromium/net/socket/socket_descriptor.cc new file mode 100644 index 00000000000..5a2e53cab4d --- /dev/null +++ b/chromium/net/socket/socket_descriptor.cc @@ -0,0 +1,48 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/socket_descriptor.h" + +#if defined(OS_POSIX) +#include <sys/types.h> +#include <sys/socket.h> +#endif + +#include "base/basictypes.h" + +#if defined(OS_WIN) +#include "net/base/winsock_init.h" +#endif + +namespace net { + +PlatformSocketFactory* g_socket_factory = NULL; + +PlatformSocketFactory::PlatformSocketFactory() { +} + +PlatformSocketFactory::~PlatformSocketFactory() { +} + +void PlatformSocketFactory::SetInstance(PlatformSocketFactory* factory) { + g_socket_factory = factory; +} + +SocketDescriptor CreateSocketDefault(int family, int type, int protocol) { +#if defined(OS_WIN) + EnsureWinsockInit(); + return ::WSASocket(family, type, protocol, NULL, 0, WSA_FLAG_OVERLAPPED); +#else // OS_WIN + return ::socket(family, type, protocol); +#endif // OS_WIN +} + +SocketDescriptor CreatePlatformSocket(int family, int type, int protocol) { + if (g_socket_factory) + return g_socket_factory->CreateSocket(family, type, protocol); + else + return CreateSocketDefault(family, type, protocol); +} + +} // namespace net diff --git a/chromium/net/socket/socket_descriptor.h b/chromium/net/socket/socket_descriptor.h new file mode 100644 index 00000000000..b2a22234b80 --- /dev/null +++ b/chromium/net/socket/socket_descriptor.h @@ -0,0 +1,49 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SOCKET_DESCRIPTOR_H_ +#define NET_SOCKET_SOCKET_DESCRIPTOR_H_ + +#include "build/build_config.h" +#include "net/base/net_export.h" + +#if defined(OS_WIN) +#include <winsock2.h> +#endif // OS_WIN + +namespace net { + +#if defined(OS_POSIX) +typedef int SocketDescriptor; +const SocketDescriptor kInvalidSocket = -1; +#elif defined(OS_WIN) +typedef SOCKET SocketDescriptor; +const SocketDescriptor kInvalidSocket = INVALID_SOCKET; +#endif + +// Interface to create native socket. +// Usually such factories are used for testing purposes, which is not true in +// this case. This interface is used to substitute WSASocket/socket to make +// possible execution of some network code in sandbox. +class NET_EXPORT PlatformSocketFactory { + public: + PlatformSocketFactory(); + virtual ~PlatformSocketFactory(); + + // Replace WSASocket/socket with given factory. The factory will be used by + // CreatePlatformSocket. + static void SetInstance(PlatformSocketFactory* factory); + + // Creates socket. See WSASocket/socket documentation of parameters. + virtual SocketDescriptor CreateSocket(int family, int type, int protocol) = 0; +}; + +// Creates socket. See WSASocket/socket documentation of parameters. +SocketDescriptor NET_EXPORT CreatePlatformSocket(int family, + int type, + int protocol); + +} // namespace net + +#endif // NET_SOCKET_SOCKET_DESCRIPTOR_H_ diff --git a/chromium/net/socket/socket_test_util.cc b/chromium/net/socket/socket_test_util.cc index 8b2bdfccba3..78e9e7ce9c4 100644 --- a/chromium/net/socket/socket_test_util.cc +++ b/chromium/net/socket/socket_test_util.cc @@ -657,37 +657,39 @@ void MockClientSocketFactory::ResetNextMockIndexes() { mock_ssl_data_.ResetNextIndex(); } -DatagramClientSocket* MockClientSocketFactory::CreateDatagramClientSocket( +scoped_ptr<DatagramClientSocket> +MockClientSocketFactory::CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, net::NetLog* net_log, const net::NetLog::Source& source) { SocketDataProvider* data_provider = mock_data_.GetNext(); - MockUDPClientSocket* socket = new MockUDPClientSocket(data_provider, net_log); - data_provider->set_socket(socket); - return socket; + scoped_ptr<MockUDPClientSocket> socket( + new MockUDPClientSocket(data_provider, net_log)); + data_provider->set_socket(socket.get()); + return socket.PassAs<DatagramClientSocket>(); } -StreamSocket* MockClientSocketFactory::CreateTransportClientSocket( +scoped_ptr<StreamSocket> MockClientSocketFactory::CreateTransportClientSocket( const AddressList& addresses, net::NetLog* net_log, const net::NetLog::Source& source) { SocketDataProvider* data_provider = mock_data_.GetNext(); - MockTCPClientSocket* socket = - new MockTCPClientSocket(addresses, net_log, data_provider); - data_provider->set_socket(socket); - return socket; + scoped_ptr<MockTCPClientSocket> socket( + new MockTCPClientSocket(addresses, net_log, data_provider)); + data_provider->set_socket(socket.get()); + return socket.PassAs<StreamSocket>(); } -SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket( - ClientSocketHandle* transport_socket, +scoped_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) { - MockSSLClientSocket* socket = - new MockSSLClientSocket(transport_socket, host_and_port, ssl_config, - mock_ssl_data_.GetNext()); - return socket; + return scoped_ptr<SSLClientSocket>( + new MockSSLClientSocket(transport_socket.Pass(), + host_and_port, ssl_config, + mock_ssl_data_.GetNext())); } void MockClientSocketFactory::ClearSSLSessionCache() { @@ -1278,7 +1280,7 @@ void DeterministicMockTCPClientSocket::OnConnectComplete( // static void MockSSLClientSocket::ConnectCallback( - MockSSLClientSocket *ssl_client_socket, + MockSSLClientSocket* ssl_client_socket, const CompletionCallback& callback, int rv) { if (rv == OK) @@ -1287,7 +1289,7 @@ void MockSSLClientSocket::ConnectCallback( } MockSSLClientSocket::MockSSLClientSocket( - ClientSocketHandle* transport_socket, + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_port_pair, const SSLConfig& ssl_config, SSLSocketDataProvider* data) @@ -1295,7 +1297,7 @@ MockSSLClientSocket::MockSSLClientSocket( // Have to use the right BoundNetLog for LoadTimingInfo regression // tests. transport_socket->socket()->NetLog()), - transport_(transport_socket), + transport_(transport_socket.Pass()), data_(data), is_npn_state_set_(false), new_npn_value_(false), @@ -1664,10 +1666,10 @@ void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) { } MockTransportClientSocketPool::MockConnectJob::MockConnectJob( - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, ClientSocketHandle* handle, const CompletionCallback& callback) - : socket_(socket), + : socket_(socket.Pass()), handle_(handle), user_callback_(callback) { } @@ -1698,7 +1700,7 @@ void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) { if (!socket_.get()) return; if (rv == OK) { - handle_->set_socket(socket_.release()); + handle_->SetSocket(socket_.Pass()); // Needed for socket pool tests that layer other sockets on top of mock // sockets. @@ -1730,6 +1732,7 @@ MockTransportClientSocketPool::MockTransportClientSocketPool( : TransportClientSocketPool(max_sockets, max_sockets_per_group, histograms, NULL, NULL, NULL), client_socket_factory_(socket_factory), + last_request_priority_(DEFAULT_PRIORITY), release_count_(0), cancel_count_(0) { } @@ -1740,9 +1743,11 @@ int MockTransportClientSocketPool::RequestSocket( const std::string& group_name, const void* socket_params, RequestPriority priority, ClientSocketHandle* handle, const CompletionCallback& callback, const BoundNetLog& net_log) { - StreamSocket* socket = client_socket_factory_->CreateTransportClientSocket( - AddressList(), net_log.net_log(), net::NetLog::Source()); - MockConnectJob* job = new MockConnectJob(socket, handle, callback); + last_request_priority_ = priority; + scoped_ptr<StreamSocket> socket = + client_socket_factory_->CreateTransportClientSocket( + AddressList(), net_log.net_log(), net::NetLog::Source()); + MockConnectJob* job = new MockConnectJob(socket.Pass(), handle, callback); job_list_.push_back(job); handle->set_pool_id(1); return job->Connect(); @@ -1759,11 +1764,12 @@ void MockTransportClientSocketPool::CancelRequest(const std::string& group_name, } } -void MockTransportClientSocketPool::ReleaseSocket(const std::string& group_name, - StreamSocket* socket, int id) { +void MockTransportClientSocketPool::ReleaseSocket( + const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) { EXPECT_EQ(1, id); release_count_++; - delete socket; } DeterministicMockClientSocketFactory::DeterministicMockClientSocketFactory() {} @@ -1791,42 +1797,45 @@ MockSSLClientSocket* DeterministicMockClientSocketFactory:: return ssl_client_sockets_[index]; } -DatagramClientSocket* +scoped_ptr<DatagramClientSocket> DeterministicMockClientSocketFactory::CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, net::NetLog* net_log, const NetLog::Source& source) { DeterministicSocketData* data_provider = mock_data().GetNext(); - DeterministicMockUDPClientSocket* socket = - new DeterministicMockUDPClientSocket(net_log, data_provider); + scoped_ptr<DeterministicMockUDPClientSocket> socket( + new DeterministicMockUDPClientSocket(net_log, data_provider)); data_provider->set_delegate(socket->AsWeakPtr()); - udp_client_sockets().push_back(socket); - return socket; + udp_client_sockets().push_back(socket.get()); + return socket.PassAs<DatagramClientSocket>(); } -StreamSocket* DeterministicMockClientSocketFactory::CreateTransportClientSocket( +scoped_ptr<StreamSocket> +DeterministicMockClientSocketFactory::CreateTransportClientSocket( const AddressList& addresses, net::NetLog* net_log, const net::NetLog::Source& source) { DeterministicSocketData* data_provider = mock_data().GetNext(); - DeterministicMockTCPClientSocket* socket = - new DeterministicMockTCPClientSocket(net_log, data_provider); + scoped_ptr<DeterministicMockTCPClientSocket> socket( + new DeterministicMockTCPClientSocket(net_log, data_provider)); data_provider->set_delegate(socket->AsWeakPtr()); - tcp_client_sockets().push_back(socket); - return socket; + tcp_client_sockets().push_back(socket.get()); + return socket.PassAs<StreamSocket>(); } -SSLClientSocket* DeterministicMockClientSocketFactory::CreateSSLClientSocket( - ClientSocketHandle* transport_socket, +scoped_ptr<SSLClientSocket> +DeterministicMockClientSocketFactory::CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) { - MockSSLClientSocket* socket = - new MockSSLClientSocket(transport_socket, host_and_port, ssl_config, - mock_ssl_data_.GetNext()); - ssl_client_sockets_.push_back(socket); - return socket; + scoped_ptr<MockSSLClientSocket> socket( + new MockSSLClientSocket(transport_socket.Pass(), + host_and_port, ssl_config, + mock_ssl_data_.GetNext())); + ssl_client_sockets_.push_back(socket.get()); + return socket.PassAs<SSLClientSocket>(); } void DeterministicMockClientSocketFactory::ClearSSLSessionCache() { @@ -1859,8 +1868,9 @@ void MockSOCKSClientSocketPool::CancelRequest( } void MockSOCKSClientSocketPool::ReleaseSocket(const std::string& group_name, - StreamSocket* socket, int id) { - return transport_pool_->ReleaseSocket(group_name, socket, id); + scoped_ptr<StreamSocket> socket, + int id) { + return transport_pool_->ReleaseSocket(group_name, socket.Pass(), id); } const char kSOCKS5GreetRequest[] = { 0x05, 0x01, 0x00 }; diff --git a/chromium/net/socket/socket_test_util.h b/chromium/net/socket/socket_test_util.h index 6afe170299e..e4e56522c92 100644 --- a/chromium/net/socket/socket_test_util.h +++ b/chromium/net/socket/socket_test_util.h @@ -13,6 +13,7 @@ #include "base/basictypes.h" #include "base/callback.h" #include "base/logging.h" +#include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" #include "base/memory/scoped_vector.h" #include "base/memory/weak_ptr.h" @@ -592,17 +593,17 @@ class MockClientSocketFactory : public ClientSocketFactory { } // ClientSocketFactory - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) OVERRIDE; - virtual StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog* net_log, const NetLog::Source& source) OVERRIDE; - virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocketHandle* transport_socket, + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) OVERRIDE; @@ -857,7 +858,7 @@ class DeterministicMockTCPClientSocket class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { public: MockSSLClientSocket( - ClientSocketHandle* transport_socket, + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLSocketDataProvider* socket); @@ -1001,11 +1002,12 @@ class ClientSocketPoolTest { ClientSocketPoolTest(); ~ClientSocketPoolTest(); - template <typename PoolType, typename SocketParams> - int StartRequestUsingPool(PoolType* socket_pool, - const std::string& group_name, - RequestPriority priority, - const scoped_refptr<SocketParams>& socket_params) { + template <typename PoolType> + int StartRequestUsingPool( + PoolType* socket_pool, + const std::string& group_name, + RequestPriority priority, + const scoped_refptr<typename PoolType::SocketParams>& socket_params) { DCHECK(socket_pool); TestSocketRequest* request = new TestSocketRequest(&request_order_, &completion_count_); @@ -1045,11 +1047,20 @@ class ClientSocketPoolTest { size_t completion_count_; }; +class MockTransportSocketParams + : public base::RefCounted<MockTransportSocketParams> { + private: + friend class base::RefCounted<MockTransportSocketParams>; + ~MockTransportSocketParams() {} +}; + class MockTransportClientSocketPool : public TransportClientSocketPool { public: + typedef MockTransportSocketParams SocketParams; + class MockConnectJob { public: - MockConnectJob(StreamSocket* socket, ClientSocketHandle* handle, + MockConnectJob(scoped_ptr<StreamSocket> socket, ClientSocketHandle* handle, const CompletionCallback& callback); ~MockConnectJob(); @@ -1074,6 +1085,9 @@ class MockTransportClientSocketPool : public TransportClientSocketPool { virtual ~MockTransportClientSocketPool(); + RequestPriority last_request_priority() const { + return last_request_priority_; + } int release_count() const { return release_count_; } int cancel_count() const { return cancel_count_; } @@ -1088,11 +1102,13 @@ class MockTransportClientSocketPool : public TransportClientSocketPool { virtual void CancelRequest(const std::string& group_name, ClientSocketHandle* handle) OVERRIDE; virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, int id) OVERRIDE; + scoped_ptr<StreamSocket> socket, + int id) OVERRIDE; private: ClientSocketFactory* client_socket_factory_; ScopedVector<MockConnectJob> job_list_; + RequestPriority last_request_priority_; int release_count_; int cancel_count_; @@ -1123,17 +1139,17 @@ class DeterministicMockClientSocketFactory : public ClientSocketFactory { } // ClientSocketFactory - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) OVERRIDE; - virtual StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog* net_log, const NetLog::Source& source) OVERRIDE; - virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocketHandle* transport_socket, + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) OVERRIDE; @@ -1170,7 +1186,8 @@ class MockSOCKSClientSocketPool : public SOCKSClientSocketPool { virtual void CancelRequest(const std::string& group_name, ClientSocketHandle* handle) OVERRIDE; virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, int id) OVERRIDE; + scoped_ptr<StreamSocket> socket, + int id) OVERRIDE; private: TransportClientSocketPool* const transport_pool_; diff --git a/chromium/net/socket/socks5_client_socket.cc b/chromium/net/socket/socks5_client_socket.cc index c9d25bc3dcb..537b584a932 100644 --- a/chromium/net/socket/socks5_client_socket.cc +++ b/chromium/net/socket/socks5_client_socket.cc @@ -28,34 +28,18 @@ COMPILE_ASSERT(sizeof(struct in_addr) == 4, incorrect_system_size_of_IPv4); COMPILE_ASSERT(sizeof(struct in6_addr) == 16, incorrect_system_size_of_IPv6); SOCKS5ClientSocket::SOCKS5ClientSocket( - ClientSocketHandle* transport_socket, + scoped_ptr<ClientSocketHandle> transport_socket, const HostResolver::RequestInfo& req_info) : io_callback_(base::Bind(&SOCKS5ClientSocket::OnIOComplete, base::Unretained(this))), - transport_(transport_socket), + transport_(transport_socket.Pass()), next_state_(STATE_NONE), completed_handshake_(false), bytes_sent_(0), bytes_received_(0), read_header_size(kReadHeaderSize), host_request_info_(req_info), - net_log_(transport_socket->socket()->NetLog()) { -} - -SOCKS5ClientSocket::SOCKS5ClientSocket( - StreamSocket* transport_socket, - const HostResolver::RequestInfo& req_info) - : io_callback_(base::Bind(&SOCKS5ClientSocket::OnIOComplete, - base::Unretained(this))), - transport_(new ClientSocketHandle()), - next_state_(STATE_NONE), - completed_handshake_(false), - bytes_sent_(0), - bytes_received_(0), - read_header_size(kReadHeaderSize), - host_request_info_(req_info), - net_log_(transport_socket->NetLog()) { - transport_->set_socket(transport_socket); + net_log_(transport_->socket()->NetLog()) { } SOCKS5ClientSocket::~SOCKS5ClientSocket() { diff --git a/chromium/net/socket/socks5_client_socket.h b/chromium/net/socket/socks5_client_socket.h index b955e8f42de..45216244f10 100644 --- a/chromium/net/socket/socks5_client_socket.h +++ b/chromium/net/socket/socks5_client_socket.h @@ -28,20 +28,13 @@ class BoundNetLog; // Currently no SOCKSv5 authentication is supported. class NET_EXPORT_PRIVATE SOCKS5ClientSocket : public StreamSocket { public: - // Takes ownership of the |transport_socket|, which should already be - // connected by the time Connect() is called. - // // |req_info| contains the hostname and port to which the socket above will // communicate to via the SOCKS layer. // // Although SOCKS 5 supports 3 different modes of addressing, we will // always pass it a hostname. This means the DNS resolving is done // proxy side. - SOCKS5ClientSocket(ClientSocketHandle* transport_socket, - const HostResolver::RequestInfo& req_info); - - // Deprecated constructor (http://crbug.com/37810) that takes a StreamSocket. - SOCKS5ClientSocket(StreamSocket* transport_socket, + SOCKS5ClientSocket(scoped_ptr<ClientSocketHandle> transport_socket, const HostResolver::RequestInfo& req_info); // On destruction Disconnect() is called. diff --git a/chromium/net/socket/socks5_client_socket_unittest.cc b/chromium/net/socket/socks5_client_socket_unittest.cc index 717d858eef8..78f2ac433c3 100644 --- a/chromium/net/socket/socks5_client_socket_unittest.cc +++ b/chromium/net/socket/socks5_client_socket_unittest.cc @@ -32,13 +32,13 @@ class SOCKS5ClientSocketTest : public PlatformTest { public: SOCKS5ClientSocketTest(); // Create a SOCKSClientSocket on top of a MockSocket. - SOCKS5ClientSocket* BuildMockSocket(MockRead reads[], - size_t reads_count, - MockWrite writes[], - size_t writes_count, - const std::string& hostname, - int port, - NetLog* net_log); + scoped_ptr<SOCKS5ClientSocket> BuildMockSocket(MockRead reads[], + size_t reads_count, + MockWrite writes[], + size_t writes_count, + const std::string& hostname, + int port, + NetLog* net_log); virtual void SetUp(); @@ -47,6 +47,8 @@ class SOCKS5ClientSocketTest : public PlatformTest { CapturingNetLog net_log_; scoped_ptr<SOCKS5ClientSocket> user_sock_; AddressList address_list_; + // Filled in by BuildMockSocket() and owned by its return value + // (which |user_sock| is set to). StreamSocket* tcp_sock_; TestCompletionCallback callback_; scoped_ptr<MockHostResolver> host_resolver_; @@ -68,14 +70,18 @@ void SOCKS5ClientSocketTest::SetUp() { // Resolve the "localhost" AddressList used by the TCP connection to connect. HostResolver::RequestInfo info(HostPortPair("www.socks-proxy.com", 1080)); TestCompletionCallback callback; - int rv = host_resolver_->Resolve(info, &address_list_, callback.callback(), - NULL, BoundNetLog()); + int rv = host_resolver_->Resolve(info, + DEFAULT_PRIORITY, + &address_list_, + callback.callback(), + NULL, + BoundNetLog()); ASSERT_EQ(ERR_IO_PENDING, rv); rv = callback.WaitForResult(); ASSERT_EQ(OK, rv); } -SOCKS5ClientSocket* SOCKS5ClientSocketTest::BuildMockSocket( +scoped_ptr<SOCKS5ClientSocket> SOCKS5ClientSocketTest::BuildMockSocket( MockRead reads[], size_t reads_count, MockWrite writes[], @@ -94,8 +100,13 @@ SOCKS5ClientSocket* SOCKS5ClientSocketTest::BuildMockSocket( EXPECT_EQ(OK, rv); EXPECT_TRUE(tcp_sock_->IsConnected()); - return new SOCKS5ClientSocket(tcp_sock_, - HostResolver::RequestInfo(HostPortPair(hostname, port))); + scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); + // |connection| takes ownership of |tcp_sock_|, but keep a + // non-owning pointer to it. + connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_)); + return scoped_ptr<SOCKS5ClientSocket>(new SOCKS5ClientSocket( + connection.Pass(), + HostResolver::RequestInfo(HostPortPair(hostname, port)))); } // Tests a complete SOCKS5 handshake and the disconnection. @@ -123,9 +134,9 @@ TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) { MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength), MockRead(ASYNC, payload_read.data(), payload_read.size()) }; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - "localhost", 80, &net_log_)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + "localhost", 80, &net_log_); // At this state the TCP connection is completed but not the SOCKS handshake. EXPECT_TRUE(tcp_sock_->IsConnected()); @@ -195,9 +206,9 @@ TEST_F(SOCKS5ClientSocketTest, ConnectAndDisconnectTwice) { MockRead(SYNCHRONOUS, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - hostname, 80, NULL)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hostname, 80, NULL); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(OK, rv); @@ -217,9 +228,9 @@ TEST_F(SOCKS5ClientSocketTest, LargeHostNameFails) { // Create a SOCKS socket, with mock transport socket. MockWrite data_writes[] = {MockWrite()}; MockRead data_reads[] = {MockRead()}; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - large_host_name, 80, NULL)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + large_host_name, 80, NULL); // Try to connect -- should fail (without having read/written anything to // the transport socket first) because the hostname is too long. @@ -253,9 +264,9 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { MockRead data_reads[] = { MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - hostname, 80, &net_log_)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hostname, 80, &net_log_); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -284,9 +295,9 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { MockRead(ASYNC, partial1, arraysize(partial1)), MockRead(ASYNC, partial2, arraysize(partial2)), MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - hostname, 80, &net_log_)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hostname, 80, &net_log_); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -314,9 +325,9 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { MockRead data_reads[] = { MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - hostname, 80, &net_log_)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hostname, 80, &net_log_); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); CapturingNetLog::CapturedEntryList net_log_entries; @@ -345,9 +356,9 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { kSOCKS5OkResponseLength - kSplitPoint) }; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - hostname, 80, &net_log_)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hostname, 80, &net_log_); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); CapturingNetLog::CapturedEntryList net_log_entries; diff --git a/chromium/net/socket/socks_client_socket.cc b/chromium/net/socket/socks_client_socket.cc index c4bbd28c619..67089589cc5 100644 --- a/chromium/net/socket/socks_client_socket.cc +++ b/chromium/net/socket/socks_client_socket.cc @@ -55,32 +55,20 @@ struct SOCKS4ServerResponse { COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize, socks4_server_response_struct_wrong_size); -SOCKSClientSocket::SOCKSClientSocket(ClientSocketHandle* transport_socket, - const HostResolver::RequestInfo& req_info, - HostResolver* host_resolver) - : transport_(transport_socket), +SOCKSClientSocket::SOCKSClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, + const HostResolver::RequestInfo& req_info, + RequestPriority priority, + HostResolver* host_resolver) + : transport_(transport_socket.Pass()), next_state_(STATE_NONE), completed_handshake_(false), bytes_sent_(0), bytes_received_(0), host_resolver_(host_resolver), host_request_info_(req_info), - net_log_(transport_socket->socket()->NetLog()) { -} - -SOCKSClientSocket::SOCKSClientSocket(StreamSocket* transport_socket, - const HostResolver::RequestInfo& req_info, - HostResolver* host_resolver) - : transport_(new ClientSocketHandle()), - next_state_(STATE_NONE), - completed_handshake_(false), - bytes_sent_(0), - bytes_received_(0), - host_resolver_(host_resolver), - host_request_info_(req_info), - net_log_(transport_socket->NetLog()) { - transport_->set_socket(transport_socket); -} + priority_(priority), + net_log_(transport_->socket()->NetLog()) {} SOCKSClientSocket::~SOCKSClientSocket() { Disconnect(); @@ -283,7 +271,9 @@ int SOCKSClientSocket::DoResolveHost() { // addresses for the target host. host_request_info_.set_address_family(ADDRESS_FAMILY_IPV4); return host_resolver_.Resolve( - host_request_info_, &addresses_, + host_request_info_, + priority_, + &addresses_, base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)), net_log_); } diff --git a/chromium/net/socket/socks_client_socket.h b/chromium/net/socket/socks_client_socket.h index 3d4f9fc2771..d4f058a62b1 100644 --- a/chromium/net/socket/socks_client_socket.h +++ b/chromium/net/socket/socks_client_socket.h @@ -27,18 +27,11 @@ class BoundNetLog; // The SOCKS client socket implementation class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { public: - // Takes ownership of the |transport_socket|, which should already be - // connected by the time Connect() is called. - // // |req_info| contains the hostname and port to which the socket above will // communicate to via the socks layer. For testing the referrer is optional. - SOCKSClientSocket(ClientSocketHandle* transport_socket, - const HostResolver::RequestInfo& req_info, - HostResolver* host_resolver); - - // Deprecated constructor (http://crbug.com/37810) that takes a StreamSocket. - SOCKSClientSocket(StreamSocket* transport_socket, + SOCKSClientSocket(scoped_ptr<ClientSocketHandle> transport_socket, const HostResolver::RequestInfo& req_info, + RequestPriority priority, HostResolver* host_resolver); // On destruction Disconnect() is called. @@ -131,6 +124,7 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { SingleRequestHostResolver host_resolver_; AddressList addresses_; HostResolver::RequestInfo host_request_info_; + RequestPriority priority_; BoundNetLog net_log_; diff --git a/chromium/net/socket/socks_client_socket_pool.cc b/chromium/net/socket/socks_client_socket_pool.cc index d740e5b9a0e..e11b7a48db5 100644 --- a/chromium/net/socket/socks_client_socket_pool.cc +++ b/chromium/net/socket/socks_client_socket_pool.cc @@ -21,8 +21,7 @@ namespace net { SOCKSSocketParams::SOCKSSocketParams( const scoped_refptr<TransportSocketParams>& proxy_server, bool socks_v5, - const HostPortPair& host_port_pair, - RequestPriority priority) + const HostPortPair& host_port_pair) : transport_params_(proxy_server), destination_(host_port_pair), socks_v5_(socks_v5) { @@ -30,7 +29,6 @@ SOCKSSocketParams::SOCKSSocketParams( ignore_limits_ = transport_params_->ignore_limits(); else ignore_limits_ = false; - destination_.set_priority(priority); } SOCKSSocketParams::~SOCKSSocketParams() {} @@ -41,13 +39,14 @@ static const int kSOCKSConnectJobTimeoutInSeconds = 30; SOCKSConnectJob::SOCKSConnectJob( const std::string& group_name, + RequestPriority priority, const scoped_refptr<SOCKSSocketParams>& socks_params, const base::TimeDelta& timeout_duration, TransportClientSocketPool* transport_pool, HostResolver* host_resolver, Delegate* delegate, NetLog* net_log) - : ConnectJob(group_name, timeout_duration, delegate, + : ConnectJob(group_name, timeout_duration, priority, delegate, BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), socks_params_(socks_params), transport_pool_(transport_pool), @@ -117,10 +116,12 @@ int SOCKSConnectJob::DoLoop(int result) { int SOCKSConnectJob::DoTransportConnect() { next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE; transport_socket_handle_.reset(new ClientSocketHandle()); - return transport_socket_handle_->Init( - group_name(), socks_params_->transport_params(), - socks_params_->destination().priority(), callback_, transport_pool_, - net_log()); + return transport_socket_handle_->Init(group_name(), + socks_params_->transport_params(), + priority(), + callback_, + transport_pool_, + net_log()); } int SOCKSConnectJob::DoTransportConnectComplete(int result) { @@ -140,11 +141,12 @@ int SOCKSConnectJob::DoSOCKSConnect() { // Add a SOCKS connection on top of the tcp socket. if (socks_params_->is_socks_v5()) { - socket_.reset(new SOCKS5ClientSocket(transport_socket_handle_.release(), + socket_.reset(new SOCKS5ClientSocket(transport_socket_handle_.Pass(), socks_params_->destination())); } else { - socket_.reset(new SOCKSClientSocket(transport_socket_handle_.release(), + socket_.reset(new SOCKSClientSocket(transport_socket_handle_.Pass(), socks_params_->destination(), + priority(), resolver_)); } return socket_->Connect( @@ -157,7 +159,7 @@ int SOCKSConnectJob::DoSOCKSConnectComplete(int result) { return result; } - set_socket(socket_.release()); + SetSocket(socket_.Pass()); return result; } @@ -166,17 +168,19 @@ int SOCKSConnectJob::ConnectInternal() { return DoLoop(OK); } -ConnectJob* SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob( +scoped_ptr<ConnectJob> +SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const { - return new SOCKSConnectJob(group_name, - request.params(), - ConnectionTimeout(), - transport_pool_, - host_resolver_, - delegate, - net_log_); + return scoped_ptr<ConnectJob>(new SOCKSConnectJob(group_name, + request.priority(), + request.params(), + ConnectionTimeout(), + transport_pool_, + host_resolver_, + delegate, + net_log_)); } base::TimeDelta @@ -193,7 +197,7 @@ SOCKSClientSocketPool::SOCKSClientSocketPool( TransportClientSocketPool* transport_pool, NetLog* net_log) : transport_pool_(transport_pool), - base_(max_sockets, max_sockets_per_group, histograms, + base_(this, max_sockets, max_sockets_per_group, histograms, ClientSocketPool::unused_idle_socket_timeout(), ClientSocketPool::used_idle_socket_timeout(), new SOCKSConnectJobFactory(transport_pool, @@ -201,13 +205,10 @@ SOCKSClientSocketPool::SOCKSClientSocketPool( net_log)) { // We should always have a |transport_pool_| except in unit tests. if (transport_pool_) - transport_pool_->AddLayeredPool(this); + base_.AddLowerLayeredPool(transport_pool_); } SOCKSClientSocketPool::~SOCKSClientSocketPool() { - // We should always have a |transport_pool_| except in unit tests. - if (transport_pool_) - transport_pool_->RemoveLayeredPool(this); } int SOCKSClientSocketPool::RequestSocket( @@ -238,18 +239,15 @@ void SOCKSClientSocketPool::CancelRequest(const std::string& group_name, } void SOCKSClientSocketPool::ReleaseSocket(const std::string& group_name, - StreamSocket* socket, int id) { - base_.ReleaseSocket(group_name, socket, id); + scoped_ptr<StreamSocket> socket, + int id) { + base_.ReleaseSocket(group_name, socket.Pass(), id); } void SOCKSClientSocketPool::FlushWithError(int error) { base_.FlushWithError(error); } -bool SOCKSClientSocketPool::IsStalled() const { - return base_.IsStalled() || transport_pool_->IsStalled(); -} - void SOCKSClientSocketPool::CloseIdleSockets() { base_.CloseIdleSockets(); } @@ -268,14 +266,6 @@ LoadState SOCKSClientSocketPool::GetLoadState( return base_.GetLoadState(group_name, handle); } -void SOCKSClientSocketPool::AddLayeredPool(LayeredPool* layered_pool) { - base_.AddLayeredPool(layered_pool); -} - -void SOCKSClientSocketPool::RemoveLayeredPool(LayeredPool* layered_pool) { - base_.RemoveLayeredPool(layered_pool); -} - base::DictionaryValue* SOCKSClientSocketPool::GetInfoAsValue( const std::string& name, const std::string& type, @@ -299,10 +289,24 @@ ClientSocketPoolHistograms* SOCKSClientSocketPool::histograms() const { return base_.histograms(); }; +bool SOCKSClientSocketPool::IsStalled() const { + return base_.IsStalled(); +} + +void SOCKSClientSocketPool::AddHigherLayeredPool( + HigherLayeredPool* higher_pool) { + base_.AddHigherLayeredPool(higher_pool); +} + +void SOCKSClientSocketPool::RemoveHigherLayeredPool( + HigherLayeredPool* higher_pool) { + base_.RemoveHigherLayeredPool(higher_pool); +} + bool SOCKSClientSocketPool::CloseOneIdleConnection() { if (base_.CloseOneIdleSocket()) return true; - return base_.CloseOneIdleConnectionInLayeredPool(); + return base_.CloseOneIdleConnectionInHigherLayeredPool(); } } // namespace net diff --git a/chromium/net/socket/socks_client_socket_pool.h b/chromium/net/socket/socks_client_socket_pool.h index 86609a1a5a0..c6d5c8d0883 100644 --- a/chromium/net/socket/socks_client_socket_pool.h +++ b/chromium/net/socket/socks_client_socket_pool.h @@ -28,8 +28,7 @@ class NET_EXPORT_PRIVATE SOCKSSocketParams : public base::RefCounted<SOCKSSocketParams> { public: SOCKSSocketParams(const scoped_refptr<TransportSocketParams>& proxy_server, - bool socks_v5, const HostPortPair& host_port_pair, - RequestPriority priority); + bool socks_v5, const HostPortPair& host_port_pair); const scoped_refptr<TransportSocketParams>& transport_params() const { return transport_params_; @@ -57,6 +56,7 @@ class NET_EXPORT_PRIVATE SOCKSSocketParams class SOCKSConnectJob : public ConnectJob { public: SOCKSConnectJob(const std::string& group_name, + RequestPriority priority, const scoped_refptr<SOCKSSocketParams>& params, const base::TimeDelta& timeout_duration, TransportClientSocketPool* transport_pool, @@ -105,8 +105,10 @@ class SOCKSConnectJob : public ConnectJob { }; class NET_EXPORT_PRIVATE SOCKSClientSocketPool - : public ClientSocketPool, public LayeredPool { + : public ClientSocketPool, public HigherLayeredPool { public: + typedef SOCKSSocketParams SocketParams; + SOCKSClientSocketPool( int max_sockets, int max_sockets_per_group, @@ -134,13 +136,11 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool ClientSocketHandle* handle) OVERRIDE; virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) OVERRIDE; virtual void FlushWithError(int error) OVERRIDE; - virtual bool IsStalled() const OVERRIDE; - virtual void CloseIdleSockets() OVERRIDE; virtual int IdleSocketCount() const OVERRIDE; @@ -152,10 +152,6 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool const std::string& group_name, const ClientSocketHandle* handle) const OVERRIDE; - virtual void AddLayeredPool(LayeredPool* layered_pool) OVERRIDE; - - virtual void RemoveLayeredPool(LayeredPool* layered_pool) OVERRIDE; - virtual base::DictionaryValue* GetInfoAsValue( const std::string& name, const std::string& type, @@ -165,7 +161,14 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool virtual ClientSocketPoolHistograms* histograms() const OVERRIDE; - // LayeredPool implementation. + // LowerLayeredPool implementation. + virtual bool IsStalled() const OVERRIDE; + + virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; + + virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; + + // HigherLayeredPool implementation. virtual bool CloseOneIdleConnection() OVERRIDE; private: @@ -183,7 +186,7 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool virtual ~SOCKSConnectJobFactory() {} // ClientSocketPoolBase::ConnectJobFactory methods. - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const OVERRIDE; @@ -204,8 +207,6 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool DISALLOW_COPY_AND_ASSIGN(SOCKSClientSocketPool); }; -REGISTER_SOCKET_PARAMS_FOR_POOL(SOCKSClientSocketPool, SOCKSSocketParams); - } // namespace net #endif // NET_SOCKET_SOCKS_CLIENT_SOCKET_POOL_H_ diff --git a/chromium/net/socket/socks_client_socket_pool_unittest.cc b/chromium/net/socket/socks_client_socket_pool_unittest.cc index 77440d36a19..4463e171f84 100644 --- a/chromium/net/socket/socks_client_socket_pool_unittest.cc +++ b/chromium/net/socket/socks_client_socket_pool_unittest.cc @@ -41,6 +41,25 @@ void TestLoadTimingInfo(const ClientSocketHandle& handle) { ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info); } + +scoped_refptr<TransportSocketParams> CreateProxyHostParams() { + return new TransportSocketParams( + HostPortPair("proxy", 80), false, false, + OnHostResolutionCallback()); +} + +scoped_refptr<SOCKSSocketParams> CreateSOCKSv4Params() { + return new SOCKSSocketParams( + CreateProxyHostParams(), false /* socks_v5 */, + HostPortPair("host", 80)); +} + +scoped_refptr<SOCKSSocketParams> CreateSOCKSv5Params() { + return new SOCKSSocketParams( + CreateProxyHostParams(), true /* socks_v5 */, + HostPortPair("host", 80)); +} + class SOCKSClientSocketPoolTest : public testing::Test { protected: class SOCKS5MockData { @@ -71,30 +90,24 @@ class SOCKSClientSocketPoolTest : public testing::Test { }; SOCKSClientSocketPoolTest() - : ignored_transport_socket_params_(new TransportSocketParams( - HostPortPair("proxy", 80), MEDIUM, false, false, - OnHostResolutionCallback())), - transport_histograms_("MockTCP"), + : transport_histograms_("MockTCP"), transport_socket_pool_( kMaxSockets, kMaxSocketsPerGroup, &transport_histograms_, &transport_client_socket_factory_), - ignored_socket_params_(new SOCKSSocketParams( - ignored_transport_socket_params_, true, HostPortPair("host", 80), - MEDIUM)), socks_histograms_("SOCKSUnitTest"), pool_(kMaxSockets, kMaxSocketsPerGroup, &socks_histograms_, - NULL, + &host_resolver_, &transport_socket_pool_, NULL) { } virtual ~SOCKSClientSocketPoolTest() {} - int StartRequest(const std::string& group_name, RequestPriority priority) { + int StartRequestV5(const std::string& group_name, RequestPriority priority) { return test_base_.StartRequestUsingPool( - &pool_, group_name, priority, ignored_socket_params_); + &pool_, group_name, priority, CreateSOCKSv5Params()); } int GetOrderOfRequest(size_t index) const { @@ -103,13 +116,12 @@ class SOCKSClientSocketPoolTest : public testing::Test { ScopedVector<TestSocketRequest>* requests() { return test_base_.requests(); } - scoped_refptr<TransportSocketParams> ignored_transport_socket_params_; ClientSocketPoolHistograms transport_histograms_; MockClientSocketFactory transport_client_socket_factory_; MockTransportClientSocketPool transport_socket_pool_; - scoped_refptr<SOCKSSocketParams> ignored_socket_params_; ClientSocketPoolHistograms socks_histograms_; + MockHostResolver host_resolver_; SOCKSClientSocketPool pool_; ClientSocketPoolTest test_base_; }; @@ -120,7 +132,7 @@ TEST_F(SOCKSClientSocketPoolTest, Simple) { transport_client_socket_factory_.AddSocketDataProvider(data.data_provider()); ClientSocketHandle handle; - int rv = handle.Init("a", ignored_socket_params_, LOW, CompletionCallback(), + int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, CompletionCallback(), &pool_, BoundNetLog()); EXPECT_EQ(OK, rv); EXPECT_TRUE(handle.is_initialized()); @@ -128,13 +140,52 @@ TEST_F(SOCKSClientSocketPoolTest, Simple) { TestLoadTimingInfo(handle); } +// Make sure that SOCKSConnectJob passes on its priority to its +// socket request on Init. +TEST_F(SOCKSClientSocketPoolTest, SetSocketRequestPriorityOnInit) { + for (int i = MINIMUM_PRIORITY; i < NUM_PRIORITIES; ++i) { + RequestPriority priority = static_cast<RequestPriority>(i); + SOCKS5MockData data(SYNCHRONOUS); + data.data_provider()->set_connect_data(MockConnect(SYNCHRONOUS, OK)); + transport_client_socket_factory_.AddSocketDataProvider( + data.data_provider()); + + ClientSocketHandle handle; + EXPECT_EQ(OK, + handle.Init("a", CreateSOCKSv5Params(), priority, + CompletionCallback(), &pool_, BoundNetLog())); + EXPECT_EQ(priority, transport_socket_pool_.last_request_priority()); + handle.socket()->Disconnect(); + } +} + +// Make sure that SOCKSConnectJob passes on its priority to its +// HostResolver request (for non-SOCKS5) on Init. +TEST_F(SOCKSClientSocketPoolTest, SetResolvePriorityOnInit) { + for (int i = MINIMUM_PRIORITY; i < NUM_PRIORITIES; ++i) { + RequestPriority priority = static_cast<RequestPriority>(i); + SOCKS5MockData data(SYNCHRONOUS); + data.data_provider()->set_connect_data(MockConnect(SYNCHRONOUS, OK)); + transport_client_socket_factory_.AddSocketDataProvider( + data.data_provider()); + + ClientSocketHandle handle; + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", CreateSOCKSv4Params(), priority, + CompletionCallback(), &pool_, BoundNetLog())); + EXPECT_EQ(priority, transport_socket_pool_.last_request_priority()); + EXPECT_EQ(priority, host_resolver_.last_request_priority()); + EXPECT_TRUE(handle.socket() == NULL); + } +} + TEST_F(SOCKSClientSocketPoolTest, Async) { SOCKS5MockData data(ASYNC); transport_client_socket_factory_.AddSocketDataProvider(data.data_provider()); TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", ignored_socket_params_, LOW, callback.callback(), + int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -153,7 +204,7 @@ TEST_F(SOCKSClientSocketPoolTest, TransportConnectError) { transport_client_socket_factory_.AddSocketDataProvider(&socket_data); ClientSocketHandle handle; - int rv = handle.Init("a", ignored_socket_params_, LOW, CompletionCallback(), + int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, CompletionCallback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_PROXY_CONNECTION_FAILED, rv); EXPECT_FALSE(handle.is_initialized()); @@ -167,7 +218,7 @@ TEST_F(SOCKSClientSocketPoolTest, AsyncTransportConnectError) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", ignored_socket_params_, LOW, callback.callback(), + int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -189,7 +240,7 @@ TEST_F(SOCKSClientSocketPoolTest, SOCKSConnectError) { ClientSocketHandle handle; EXPECT_EQ(0, transport_socket_pool_.release_count()); - int rv = handle.Init("a", ignored_socket_params_, LOW, CompletionCallback(), + int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, CompletionCallback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, rv); EXPECT_FALSE(handle.is_initialized()); @@ -209,7 +260,7 @@ TEST_F(SOCKSClientSocketPoolTest, AsyncSOCKSConnectError) { TestCompletionCallback callback; ClientSocketHandle handle; EXPECT_EQ(0, transport_socket_pool_.release_count()); - int rv = handle.Init("a", ignored_socket_params_, LOW, callback.callback(), + int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -230,10 +281,10 @@ TEST_F(SOCKSClientSocketPoolTest, CancelDuringTransportConnect) { transport_client_socket_factory_.AddSocketDataProvider(data2.data_provider()); EXPECT_EQ(0, transport_socket_pool_.cancel_count()); - int rv = StartRequest("a", LOW); + int rv = StartRequestV5("a", LOW); EXPECT_EQ(ERR_IO_PENDING, rv); - rv = StartRequest("a", LOW); + rv = StartRequestV5("a", LOW); EXPECT_EQ(ERR_IO_PENDING, rv); pool_.CancelRequest("a", (*requests())[0]->handle()); @@ -265,10 +316,10 @@ TEST_F(SOCKSClientSocketPoolTest, CancelDuringSOCKSConnect) { EXPECT_EQ(0, transport_socket_pool_.cancel_count()); EXPECT_EQ(0, transport_socket_pool_.release_count()); - int rv = StartRequest("a", LOW); + int rv = StartRequestV5("a", LOW); EXPECT_EQ(ERR_IO_PENDING, rv); - rv = StartRequest("a", LOW); + rv = StartRequestV5("a", LOW); EXPECT_EQ(ERR_IO_PENDING, rv); pool_.CancelRequest("a", (*requests())[0]->handle()); diff --git a/chromium/net/socket/socks_client_socket_unittest.cc b/chromium/net/socket/socks_client_socket_unittest.cc index 7a8faf69856..f361244feff 100644 --- a/chromium/net/socket/socks_client_socket_unittest.cc +++ b/chromium/net/socket/socks_client_socket_unittest.cc @@ -4,6 +4,7 @@ #include "net/socket/socks_client_socket.h" +#include "base/memory/scoped_ptr.h" #include "net/base/address_list.h" #include "net/base/net_log.h" #include "net/base/net_log_unittest.h" @@ -27,16 +28,19 @@ class SOCKSClientSocketTest : public PlatformTest { public: SOCKSClientSocketTest(); // Create a SOCKSClientSocket on top of a MockSocket. - SOCKSClientSocket* BuildMockSocket(MockRead reads[], size_t reads_count, - MockWrite writes[], size_t writes_count, - HostResolver* host_resolver, - const std::string& hostname, int port, - NetLog* net_log); + scoped_ptr<SOCKSClientSocket> BuildMockSocket( + MockRead reads[], size_t reads_count, + MockWrite writes[], size_t writes_count, + HostResolver* host_resolver, + const std::string& hostname, int port, + NetLog* net_log); virtual void SetUp(); protected: scoped_ptr<SOCKSClientSocket> user_sock_; AddressList address_list_; + // Filled in by BuildMockSocket() and owned by its return value + // (which |user_sock| is set to). StreamSocket* tcp_sock_; TestCompletionCallback callback_; scoped_ptr<MockHostResolver> host_resolver_; @@ -52,7 +56,7 @@ void SOCKSClientSocketTest::SetUp() { PlatformTest::SetUp(); } -SOCKSClientSocket* SOCKSClientSocketTest::BuildMockSocket( +scoped_ptr<SOCKSClientSocket> SOCKSClientSocketTest::BuildMockSocket( MockRead reads[], size_t reads_count, MockWrite writes[], @@ -73,9 +77,15 @@ SOCKSClientSocket* SOCKSClientSocketTest::BuildMockSocket( EXPECT_EQ(OK, rv); EXPECT_TRUE(tcp_sock_->IsConnected()); - return new SOCKSClientSocket(tcp_sock_, + scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); + // |connection| takes ownership of |tcp_sock_|, but keep a + // non-owning pointer to it. + connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_)); + return scoped_ptr<SOCKSClientSocket>(new SOCKSClientSocket( + connection.Pass(), HostResolver::RequestInfo(HostPortPair(hostname, port)), - host_resolver); + DEFAULT_PRIORITY, + host_resolver)); } // Implementation of HostResolver that never completes its resolve request. @@ -86,6 +96,7 @@ class HangingHostResolverWithCancel : public HostResolver { HangingHostResolverWithCancel() : outstanding_request_(NULL) {} virtual int Resolve(const RequestInfo& info, + RequestPriority priority, AddressList* addresses, const CompletionCallback& callback, RequestHandle* out_req, @@ -134,11 +145,11 @@ TEST_F(SOCKSClientSocketTest, CompleteHandshake) { MockRead(ASYNC, payload_read.data(), payload_read.size()) }; CapturingNetLog log; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - host_resolver_.get(), - "localhost", 80, - &log)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + host_resolver_.get(), + "localhost", 80, + &log); // At this state the TCP connection is completed but not the SOCKS handshake. EXPECT_TRUE(tcp_sock_->IsConnected()); @@ -210,11 +221,11 @@ TEST_F(SOCKSClientSocketTest, HandshakeFailures) { arraysize(tests[i].fail_reply)) }; CapturingNetLog log; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - host_resolver_.get(), - "localhost", 80, - &log)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + host_resolver_.get(), + "localhost", 80, + &log); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -247,11 +258,11 @@ TEST_F(SOCKSClientSocketTest, PartialServerReads) { MockRead(ASYNC, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) }; CapturingNetLog log; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - host_resolver_.get(), - "localhost", 80, - &log)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + host_resolver_.get(), + "localhost", 80, + &log); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -285,11 +296,11 @@ TEST_F(SOCKSClientSocketTest, PartialClientWrites) { MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)) }; CapturingNetLog log; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - host_resolver_.get(), - "localhost", 80, - &log)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + host_resolver_.get(), + "localhost", 80, + &log); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -317,11 +328,11 @@ TEST_F(SOCKSClientSocketTest, FailedSocketRead) { MockRead(SYNCHRONOUS, 0) }; CapturingNetLog log; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - host_resolver_.get(), - "localhost", 80, - &log)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + host_resolver_.get(), + "localhost", 80, + &log); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -347,11 +358,11 @@ TEST_F(SOCKSClientSocketTest, FailedDNS) { CapturingNetLog log; - user_sock_.reset(BuildMockSocket(NULL, 0, - NULL, 0, - host_resolver_.get(), - hostname, 80, - &log)); + user_sock_ = BuildMockSocket(NULL, 0, + NULL, 0, + host_resolver_.get(), + hostname, 80, + &log); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -378,11 +389,11 @@ TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) { MockWrite data_writes[] = { MockWrite(SYNCHRONOUS, "", 0) }; MockRead data_reads[] = { MockRead(SYNCHRONOUS, "", 0) }; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - hanging_resolver.get(), - "foo", 80, - NULL)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hanging_resolver.get(), + "foo", 80, + NULL); // Start connecting (will get stuck waiting for the host to resolve). int rv = user_sock_->Connect(callback_.callback()); diff --git a/chromium/net/socket/ssl_client_socket_nss.cc b/chromium/net/socket/ssl_client_socket_nss.cc index f374dedcf80..0de7cfb9060 100644 --- a/chromium/net/socket/ssl_client_socket_nss.cc +++ b/chromium/net/socket/ssl_client_socket_nss.cc @@ -380,20 +380,16 @@ void PeerCertificateChain::Reset(PRFileDesc* nss_fd) { if (nss_fd == NULL) return; - unsigned int num_certs = 0; - SECStatus rv = SSL_PeerCertificateChain(nss_fd, NULL, &num_certs, 0); - DCHECK_EQ(SECSuccess, rv); - + CERTCertList* list = SSL_PeerCertificateChain(nss_fd); // The handshake on |nss_fd| may not have completed. - if (num_certs == 0) + if (list == NULL) return; - certs_.resize(num_certs); - const unsigned int expected_num_certs = num_certs; - rv = SSL_PeerCertificateChain(nss_fd, vector_as_array(&certs_), - &num_certs, expected_num_certs); - DCHECK_EQ(SECSuccess, rv); - DCHECK_EQ(expected_num_certs, num_certs); + for (CERTCertListNode* node = CERT_LIST_HEAD(list); + !CERT_LIST_END(node, list); node = CERT_LIST_NEXT(node)) { + certs_.push_back(CERT_DupCertificate(node->cert)); + } + CERT_DestroyCertList(list); } std::vector<base::StringPiece> @@ -1291,6 +1287,19 @@ SECStatus SSLClientSocketNSS::Core::OwnAuthCertHandler( // Start with it. SSL_OptionSet(socket, SSL_ENABLE_FALSE_START, PR_FALSE); } + } else { + // Disallow the server certificate to change in a renegotiation. + CERTCertificate* old_cert = core->nss_handshake_state_.server_cert_chain[0]; + ScopedCERTCertificate new_cert(SSL_PeerCertificate(socket)); + if (new_cert->derCert.len != old_cert->derCert.len || + memcmp(new_cert->derCert.data, old_cert->derCert.data, + new_cert->derCert.len) != 0) { + // NSS doesn't have an error code that indicates the server certificate + // changed. Borrow SSL_ERROR_WRONG_CERTIFICATE (which NSS isn't using) + // for this purpose. + PORT_SetError(SSL_ERROR_WRONG_CERTIFICATE); + return SECFailure; + } } // Tell NSS to not verify the certificate. @@ -2598,7 +2607,7 @@ int SSLClientSocketNSS::Core::DoGetDomainBoundCert(const std::string& host) { weak_net_log_->BeginEvent(NetLog::TYPE_SSL_GET_DOMAIN_BOUND_CERT); - int rv = server_bound_cert_service_->GetDomainBoundCert( + int rv = server_bound_cert_service_->GetOrCreateDomainBoundCert( host, &domain_bound_private_key_, &domain_bound_cert_, @@ -2751,12 +2760,12 @@ void SSLClientSocketNSS::Core::SetChannelIDProvided() { SSLClientSocketNSS::SSLClientSocketNSS( base::SequencedTaskRunner* nss_task_runner, - ClientSocketHandle* transport_socket, + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) : nss_task_runner_(nss_task_runner), - transport_(transport_socket), + transport_(transport_socket.Pass()), host_and_port_(host_and_port), ssl_config_(ssl_config), cert_verifier_(context.cert_verifier), @@ -2765,7 +2774,7 @@ SSLClientSocketNSS::SSLClientSocketNSS( completed_handshake_(false), next_handshake_state_(STATE_NONE), nss_fd_(NULL), - net_log_(transport_socket->socket()->NetLog()), + net_log_(transport_->socket()->NetLog()), transport_security_state_(context.transport_security_state), valid_thread_id_(base::kInvalidThreadId) { EnterFunction(""); @@ -3141,7 +3150,8 @@ int SSLClientSocketNSS::InitializeSSLOptions() { net_log_, "SSL_OptionSet", "SSL_ENABLE_SESSION_TICKETS"); } - rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_FALSE_START, PR_FALSE); + rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_FALSE_START, + ssl_config_.false_start_enabled); if (rv != SECSuccess) LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_ENABLE_FALSE_START"); diff --git a/chromium/net/socket/ssl_client_socket_nss.h b/chromium/net/socket/ssl_client_socket_nss.h index fed8ef706b5..b41d28d74a8 100644 --- a/chromium/net/socket/ssl_client_socket_nss.h +++ b/chromium/net/socket/ssl_client_socket_nss.h @@ -59,7 +59,7 @@ class SSLClientSocketNSS : public SSLClientSocket { // behaviour is desired, for performance or compatibility, the current task // runner should be supplied instead. SSLClientSocketNSS(base::SequencedTaskRunner* nss_task_runner, - ClientSocketHandle* transport_socket, + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context); diff --git a/chromium/net/socket/ssl_client_socket_openssl.cc b/chromium/net/socket/ssl_client_socket_openssl.cc index 1431bc61486..416ab87bc4b 100644 --- a/chromium/net/socket/ssl_client_socket_openssl.cc +++ b/chromium/net/socket/ssl_client_socket_openssl.cc @@ -425,7 +425,7 @@ void SSLClientSocket::ClearSessionCache() { } SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( - ClientSocketHandle* transport_socket, + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) @@ -439,14 +439,14 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( cert_verifier_(context.cert_verifier), ssl_(NULL), transport_bio_(NULL), - transport_(transport_socket), + transport_(transport_socket.Pass()), host_and_port_(host_and_port), ssl_config_(ssl_config), ssl_session_cache_shard_(context.ssl_session_cache_shard), trying_cached_session_(false), next_handshake_state_(STATE_NONE), npn_status_(kNextProtoUnsupported), - net_log_(transport_socket->socket()->NetLog()) { + net_log_(transport_->socket()->NetLog()) { } SSLClientSocketOpenSSL::~SSLClientSocketOpenSSL() { @@ -532,9 +532,11 @@ bool SSLClientSocketOpenSSL::Init() { STACK_OF(SSL_CIPHER)* ciphers = SSL_get_ciphers(ssl_); DCHECK(ciphers); // See SSLConfig::disabled_cipher_suites for description of the suites - // disabled by default. Note that !SHA384 only removes HMAC-SHA384 cipher - // suites, not GCM cipher suites with SHA384 as the handshake hash. - std::string command("DEFAULT:!NULL:!aNULL:!IDEA:!FZA:!SRP:!SHA384:!aECDH"); + // disabled by default. Note that !SHA256 and !SHA384 only remove HMAC-SHA256 + // and HMAC-SHA384 cipher suites, not GCM cipher suites with SHA256 or SHA384 + // as the handshake hash. + std::string command("DEFAULT:!NULL:!aNULL:!IDEA:!FZA:!SRP:!SHA256:!SHA384:" + "!aECDH:!AESGCM+AES256"); // Walk through all the installed ciphers, seeing if any need to be // appended to the cipher removal |command|. for (int i = 0; i < sk_SSL_CIPHER_num(ciphers); ++i) { diff --git a/chromium/net/socket/ssl_client_socket_openssl.h b/chromium/net/socket/ssl_client_socket_openssl.h index 520f432b8bc..f66d95cc69d 100644 --- a/chromium/net/socket/ssl_client_socket_openssl.h +++ b/chromium/net/socket/ssl_client_socket_openssl.h @@ -41,7 +41,7 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { // The given hostname will be compared with the name(s) in the server's // certificate during the SSL handshake. ssl_config specifies the SSL // settings. - SSLClientSocketOpenSSL(ClientSocketHandle* transport_socket, + SSLClientSocketOpenSSL(scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context); diff --git a/chromium/net/socket/ssl_client_socket_openssl_unittest.cc b/chromium/net/socket/ssl_client_socket_openssl_unittest.cc index 7a37cdc1187..24c06059be5 100644 --- a/chromium/net/socket/ssl_client_socket_openssl_unittest.cc +++ b/chromium/net/socket/ssl_client_socket_openssl_unittest.cc @@ -67,7 +67,7 @@ bool LoadPrivateKeyOpenSSL( const base::FilePath& filepath, OpenSSLClientKeyStore::ScopedEVP_PKEY* pkey) { std::string data; - if (!file_util::ReadFileToString(filepath, &data)) { + if (!base::ReadFileToString(filepath, &data)) { LOG(ERROR) << "Could not read private key file: " << filepath.value() << ": " << strerror(errno); return false; @@ -107,11 +107,13 @@ class SSLClientSocketOpenSSLClientAuthTest : public PlatformTest { } protected: - SSLClientSocket* CreateSSLClientSocket( - StreamSocket* transport_socket, + scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<StreamSocket> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config) { - return socket_factory_->CreateSSLClientSocket(transport_socket, + scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); + connection->SetSocket(transport_socket.Pass()); + return socket_factory_->CreateSSLClientSocket(connection.Pass(), host_and_port, ssl_config, context_); @@ -164,9 +166,9 @@ class SSLClientSocketOpenSSLClientAuthTest : public PlatformTest { // itself was a success. bool CreateAndConnectSSLClientSocket(SSLConfig& ssl_config, int* result) { - sock_.reset(CreateSSLClientSocket(transport_.release(), - test_server_->host_port_pair(), - ssl_config)); + sock_ = CreateSSLClientSocket(transport_.Pass(), + test_server_->host_port_pair(), + ssl_config); if (sock_->IsConnected()) { LOG(ERROR) << "SSL Socket prematurely connected"; diff --git a/chromium/net/socket/ssl_client_socket_pool.cc b/chromium/net/socket/ssl_client_socket_pool.cc index fed268d4ee4..5d574b7edda 100644 --- a/chromium/net/socket/ssl_client_socket_pool.cc +++ b/chromium/net/socket/ssl_client_socket_pool.cc @@ -26,20 +26,18 @@ namespace net { SSLSocketParams::SSLSocketParams( - const scoped_refptr<TransportSocketParams>& transport_params, - const scoped_refptr<SOCKSSocketParams>& socks_params, + const scoped_refptr<TransportSocketParams>& direct_params, + const scoped_refptr<SOCKSSocketParams>& socks_proxy_params, const scoped_refptr<HttpProxySocketParams>& http_proxy_params, - ProxyServer::Scheme proxy, const HostPortPair& host_and_port, const SSLConfig& ssl_config, PrivacyMode privacy_mode, int load_flags, bool force_spdy_over_ssl, bool want_spdy_over_npn) - : transport_params_(transport_params), + : direct_params_(direct_params), + socks_proxy_params_(socks_proxy_params), http_proxy_params_(http_proxy_params), - socks_params_(socks_params), - proxy_(proxy), host_and_port_(host_and_port), ssl_config_(ssl_config), privacy_mode_(privacy_mode), @@ -47,39 +45,60 @@ SSLSocketParams::SSLSocketParams( force_spdy_over_ssl_(force_spdy_over_ssl), want_spdy_over_npn_(want_spdy_over_npn), ignore_limits_(false) { - switch (proxy_) { - case ProxyServer::SCHEME_DIRECT: - DCHECK(transport_params_.get() != NULL); - DCHECK(http_proxy_params_.get() == NULL); - DCHECK(socks_params_.get() == NULL); - ignore_limits_ = transport_params_->ignore_limits(); - break; - case ProxyServer::SCHEME_HTTP: - case ProxyServer::SCHEME_HTTPS: - DCHECK(transport_params_.get() == NULL); - DCHECK(http_proxy_params_.get() != NULL); - DCHECK(socks_params_.get() == NULL); - ignore_limits_ = http_proxy_params_->ignore_limits(); - break; - case ProxyServer::SCHEME_SOCKS4: - case ProxyServer::SCHEME_SOCKS5: - DCHECK(transport_params_.get() == NULL); - DCHECK(http_proxy_params_.get() == NULL); - DCHECK(socks_params_.get() != NULL); - ignore_limits_ = socks_params_->ignore_limits(); - break; - default: - LOG(DFATAL) << "unknown proxy type"; - break; + if (direct_params_) { + DCHECK(!socks_proxy_params_); + DCHECK(!http_proxy_params_); + ignore_limits_ = direct_params_->ignore_limits(); + } else if (socks_proxy_params_) { + DCHECK(!http_proxy_params_); + ignore_limits_ = socks_proxy_params_->ignore_limits(); + } else { + DCHECK(http_proxy_params_); + ignore_limits_ = http_proxy_params_->ignore_limits(); } } SSLSocketParams::~SSLSocketParams() {} +SSLSocketParams::ConnectionType SSLSocketParams::GetConnectionType() const { + if (direct_params_) { + DCHECK(!socks_proxy_params_); + DCHECK(!http_proxy_params_); + return DIRECT; + } + + if (socks_proxy_params_) { + DCHECK(!http_proxy_params_); + return SOCKS_PROXY; + } + + DCHECK(http_proxy_params_); + return HTTP_PROXY; +} + +const scoped_refptr<TransportSocketParams>& +SSLSocketParams::GetDirectConnectionParams() const { + DCHECK_EQ(GetConnectionType(), DIRECT); + return direct_params_; +} + +const scoped_refptr<SOCKSSocketParams>& +SSLSocketParams::GetSocksProxyConnectionParams() const { + DCHECK_EQ(GetConnectionType(), SOCKS_PROXY); + return socks_proxy_params_; +} + +const scoped_refptr<HttpProxySocketParams>& +SSLSocketParams::GetHttpProxyConnectionParams() const { + DCHECK_EQ(GetConnectionType(), HTTP_PROXY); + return http_proxy_params_; +} + // Timeout for the SSL handshake portion of the connect. static const int kSSLHandshakeTimeoutInSeconds = 30; SSLConnectJob::SSLConnectJob(const std::string& group_name, + RequestPriority priority, const scoped_refptr<SSLSocketParams>& params, const base::TimeDelta& timeout_duration, TransportClientSocketPool* transport_pool, @@ -92,6 +111,7 @@ SSLConnectJob::SSLConnectJob(const std::string& group_name, NetLog* net_log) : ConnectJob(group_name, timeout_duration, + priority, delegate, BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), params_(params), @@ -201,12 +221,14 @@ int SSLConnectJob::DoTransportConnect() { next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE; transport_socket_handle_.reset(new ClientSocketHandle()); - scoped_refptr<TransportSocketParams> transport_params = - params_->transport_params(); - return transport_socket_handle_->Init( - group_name(), transport_params, - transport_params->destination().priority(), callback_, transport_pool_, - net_log()); + scoped_refptr<TransportSocketParams> direct_params = + params_->GetDirectConnectionParams(); + return transport_socket_handle_->Init(group_name(), + direct_params, + priority(), + callback_, + transport_pool_, + net_log()); } int SSLConnectJob::DoTransportConnectComplete(int result) { @@ -220,10 +242,14 @@ int SSLConnectJob::DoSOCKSConnect() { DCHECK(socks_pool_); next_state_ = STATE_SOCKS_CONNECT_COMPLETE; transport_socket_handle_.reset(new ClientSocketHandle()); - scoped_refptr<SOCKSSocketParams> socks_params = params_->socks_params(); - return transport_socket_handle_->Init( - group_name(), socks_params, socks_params->destination().priority(), - callback_, socks_pool_, net_log()); + scoped_refptr<SOCKSSocketParams> socks_proxy_params = + params_->GetSocksProxyConnectionParams(); + return transport_socket_handle_->Init(group_name(), + socks_proxy_params, + priority(), + callback_, + socks_pool_, + net_log()); } int SSLConnectJob::DoSOCKSConnectComplete(int result) { @@ -239,11 +265,13 @@ int SSLConnectJob::DoTunnelConnect() { transport_socket_handle_.reset(new ClientSocketHandle()); scoped_refptr<HttpProxySocketParams> http_proxy_params = - params_->http_proxy_params(); - return transport_socket_handle_->Init( - group_name(), http_proxy_params, - http_proxy_params->destination().priority(), callback_, http_proxy_pool_, - net_log()); + params_->GetHttpProxyConnectionParams(); + return transport_socket_handle_->Init(group_name(), + http_proxy_params, + priority(), + callback_, + http_proxy_pool_, + net_log()); } int SSLConnectJob::DoTunnelConnectComplete(int result) { @@ -287,11 +315,11 @@ int SSLConnectJob::DoSSLConnect() { connect_timing_.ssl_start = base::TimeTicks::Now(); - ssl_socket_.reset(client_socket_factory_->CreateSSLClientSocket( - transport_socket_handle_.release(), + ssl_socket_ = client_socket_factory_->CreateSSLClientSocket( + transport_socket_handle_.Pass(), params_->host_and_port(), params_->ssl_config(), - context_)); + context_); return ssl_socket_->Connect(callback_); } @@ -410,7 +438,7 @@ int SSLConnectJob::DoSSLConnectComplete(int result) { } if (result == OK || IsCertificateError(result)) { - set_socket(ssl_socket_.release()); + SetSocket(ssl_socket_.PassAs<StreamSocket>()); } else if (result == ERR_SSL_CLIENT_AUTH_CERT_NEEDED) { error_response_info_.cert_request_info = new SSLCertRequestInfo; ssl_socket_->GetSSLCertRequestInfo( @@ -420,23 +448,22 @@ int SSLConnectJob::DoSSLConnectComplete(int result) { return result; } -int SSLConnectJob::ConnectInternal() { - switch (params_->proxy()) { - case ProxyServer::SCHEME_DIRECT: - next_state_ = STATE_TRANSPORT_CONNECT; - break; - case ProxyServer::SCHEME_HTTP: - case ProxyServer::SCHEME_HTTPS: - next_state_ = STATE_TUNNEL_CONNECT; - break; - case ProxyServer::SCHEME_SOCKS4: - case ProxyServer::SCHEME_SOCKS5: - next_state_ = STATE_SOCKS_CONNECT; - break; - default: - NOTREACHED() << "unknown proxy type"; - break; +SSLConnectJob::State SSLConnectJob::GetInitialState( + SSLSocketParams::ConnectionType connection_type) { + switch (connection_type) { + case SSLSocketParams::DIRECT: + return STATE_TRANSPORT_CONNECT; + case SSLSocketParams::HTTP_PROXY: + return STATE_TUNNEL_CONNECT; + case SSLSocketParams::SOCKS_PROXY: + return STATE_SOCKS_CONNECT; } + NOTREACHED(); + return STATE_NONE; +} + +int SSLConnectJob::ConnectInternal() { + next_state_ = GetInitialState(params_->GetConnectionType()); return DoLoop(OK); } @@ -491,7 +518,7 @@ SSLClientSocketPool::SSLClientSocketPool( : transport_pool_(transport_pool), socks_pool_(socks_pool), http_proxy_pool_(http_proxy_pool), - base_(max_sockets, max_sockets_per_group, histograms, + base_(this, max_sockets, max_sockets_per_group, histograms, ClientSocketPool::unused_idle_socket_timeout(), ClientSocketPool::used_idle_socket_timeout(), new SSLConnectJobFactory(transport_pool, @@ -509,32 +536,28 @@ SSLClientSocketPool::SSLClientSocketPool( if (ssl_config_service_.get()) ssl_config_service_->AddObserver(this); if (transport_pool_) - transport_pool_->AddLayeredPool(this); + base_.AddLowerLayeredPool(transport_pool_); if (socks_pool_) - socks_pool_->AddLayeredPool(this); + base_.AddLowerLayeredPool(socks_pool_); if (http_proxy_pool_) - http_proxy_pool_->AddLayeredPool(this); + base_.AddLowerLayeredPool(http_proxy_pool_); } SSLClientSocketPool::~SSLClientSocketPool() { - if (http_proxy_pool_) - http_proxy_pool_->RemoveLayeredPool(this); - if (socks_pool_) - socks_pool_->RemoveLayeredPool(this); - if (transport_pool_) - transport_pool_->RemoveLayeredPool(this); if (ssl_config_service_.get()) ssl_config_service_->RemoveObserver(this); } -ConnectJob* SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob( +scoped_ptr<ConnectJob> +SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const { - return new SSLConnectJob(group_name, request.params(), ConnectionTimeout(), - transport_pool_, socks_pool_, http_proxy_pool_, - client_socket_factory_, host_resolver_, - context_, delegate, net_log_); + return scoped_ptr<ConnectJob>( + new SSLConnectJob(group_name, request.priority(), request.params(), + ConnectionTimeout(), transport_pool_, socks_pool_, + http_proxy_pool_, client_socket_factory_, + host_resolver_, context_, delegate, net_log_)); } base::TimeDelta @@ -572,21 +595,15 @@ void SSLClientSocketPool::CancelRequest(const std::string& group_name, } void SSLClientSocketPool::ReleaseSocket(const std::string& group_name, - StreamSocket* socket, int id) { - base_.ReleaseSocket(group_name, socket, id); + scoped_ptr<StreamSocket> socket, + int id) { + base_.ReleaseSocket(group_name, socket.Pass(), id); } void SSLClientSocketPool::FlushWithError(int error) { base_.FlushWithError(error); } -bool SSLClientSocketPool::IsStalled() const { - return base_.IsStalled() || - (transport_pool_ && transport_pool_->IsStalled()) || - (socks_pool_ && socks_pool_->IsStalled()) || - (http_proxy_pool_ && http_proxy_pool_->IsStalled()); -} - void SSLClientSocketPool::CloseIdleSockets() { base_.CloseIdleSockets(); } @@ -605,14 +622,6 @@ LoadState SSLClientSocketPool::GetLoadState( return base_.GetLoadState(group_name, handle); } -void SSLClientSocketPool::AddLayeredPool(LayeredPool* layered_pool) { - base_.AddLayeredPool(layered_pool); -} - -void SSLClientSocketPool::RemoveLayeredPool(LayeredPool* layered_pool) { - base_.RemoveLayeredPool(layered_pool); -} - base::DictionaryValue* SSLClientSocketPool::GetInfoAsValue( const std::string& name, const std::string& type, @@ -648,14 +657,27 @@ ClientSocketPoolHistograms* SSLClientSocketPool::histograms() const { return base_.histograms(); } -void SSLClientSocketPool::OnSSLConfigChanged() { - FlushWithError(ERR_NETWORK_CHANGED); +bool SSLClientSocketPool::IsStalled() const { + return base_.IsStalled(); +} + +void SSLClientSocketPool::AddHigherLayeredPool(HigherLayeredPool* higher_pool) { + base_.AddHigherLayeredPool(higher_pool); +} + +void SSLClientSocketPool::RemoveHigherLayeredPool( + HigherLayeredPool* higher_pool) { + base_.RemoveHigherLayeredPool(higher_pool); } bool SSLClientSocketPool::CloseOneIdleConnection() { if (base_.CloseOneIdleSocket()) return true; - return base_.CloseOneIdleConnectionInLayeredPool(); + return base_.CloseOneIdleConnectionInHigherLayeredPool(); +} + +void SSLClientSocketPool::OnSSLConfigChanged() { + FlushWithError(ERR_NETWORK_CHANGED); } } // namespace net diff --git a/chromium/net/socket/ssl_client_socket_pool.h b/chromium/net/socket/ssl_client_socket_pool.h index bc54bc92f9a..ec62eb01f46 100644 --- a/chromium/net/socket/ssl_client_socket_pool.h +++ b/chromium/net/socket/ssl_client_socket_pool.h @@ -13,7 +13,6 @@ #include "net/base/privacy_mode.h" #include "net/dns/host_resolver.h" #include "net/http/http_response_info.h" -#include "net/proxy/proxy_server.h" #include "net/socket/client_socket_pool.h" #include "net/socket/client_socket_pool_base.h" #include "net/socket/client_socket_pool_histograms.h" @@ -35,32 +34,39 @@ class TransportClientSocketPool; class TransportSecurityState; class TransportSocketParams; -// SSLSocketParams only needs the socket params for the transport socket -// that will be used (denoted by |proxy|). class NET_EXPORT_PRIVATE SSLSocketParams : public base::RefCounted<SSLSocketParams> { public: - SSLSocketParams(const scoped_refptr<TransportSocketParams>& transport_params, - const scoped_refptr<SOCKSSocketParams>& socks_params, - const scoped_refptr<HttpProxySocketParams>& http_proxy_params, - ProxyServer::Scheme proxy, - const HostPortPair& host_and_port, - const SSLConfig& ssl_config, - PrivacyMode privacy_mode, - int load_flags, - bool force_spdy_over_ssl, - bool want_spdy_over_npn); - - const scoped_refptr<TransportSocketParams>& transport_params() { - return transport_params_; - } - const scoped_refptr<HttpProxySocketParams>& http_proxy_params() { - return http_proxy_params_; - } - const scoped_refptr<SOCKSSocketParams>& socks_params() { - return socks_params_; - } - ProxyServer::Scheme proxy() const { return proxy_; } + enum ConnectionType { DIRECT, SOCKS_PROXY, HTTP_PROXY }; + + // Exactly one of |direct_params|, |socks_proxy_params|, and + // |http_proxy_params| must be non-NULL. + SSLSocketParams( + const scoped_refptr<TransportSocketParams>& direct_params, + const scoped_refptr<SOCKSSocketParams>& socks_proxy_params, + const scoped_refptr<HttpProxySocketParams>& http_proxy_params, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + PrivacyMode privacy_mode, + int load_flags, + bool force_spdy_over_ssl, + bool want_spdy_over_npn); + + // Returns the type of the underlying connection. + ConnectionType GetConnectionType() const; + + // Must be called only when GetConnectionType() returns DIRECT. + const scoped_refptr<TransportSocketParams>& + GetDirectConnectionParams() const; + + // Must be called only when GetConnectionType() returns SOCKS_PROXY. + const scoped_refptr<SOCKSSocketParams>& + GetSocksProxyConnectionParams() const; + + // Must be called only when GetConnectionType() returns HTTP_PROXY. + const scoped_refptr<HttpProxySocketParams>& + GetHttpProxyConnectionParams() const; + const HostPortPair& host_and_port() const { return host_and_port_; } const SSLConfig& ssl_config() const { return ssl_config_; } PrivacyMode privacy_mode() const { return privacy_mode_; } @@ -73,10 +79,9 @@ class NET_EXPORT_PRIVATE SSLSocketParams friend class base::RefCounted<SSLSocketParams>; ~SSLSocketParams(); - const scoped_refptr<TransportSocketParams> transport_params_; + const scoped_refptr<TransportSocketParams> direct_params_; + const scoped_refptr<SOCKSSocketParams> socks_proxy_params_; const scoped_refptr<HttpProxySocketParams> http_proxy_params_; - const scoped_refptr<SOCKSSocketParams> socks_params_; - const ProxyServer::Scheme proxy_; const HostPortPair host_and_port_; const SSLConfig ssl_config_; const PrivacyMode privacy_mode_; @@ -94,6 +99,7 @@ class SSLConnectJob : public ConnectJob { public: SSLConnectJob( const std::string& group_name, + RequestPriority priority, const scoped_refptr<SSLSocketParams>& params, const base::TimeDelta& timeout_duration, TransportClientSocketPool* transport_pool, @@ -138,6 +144,10 @@ class SSLConnectJob : public ConnectJob { int DoSSLConnect(); int DoSSLConnectComplete(int result); + // Returns the initial state for the state machine based on the + // |connection_type|. + static State GetInitialState(SSLSocketParams::ConnectionType connection_type); + // Starts the SSL connection process. Returns OK on success and // ERR_IO_PENDING if it cannot immediately service the request. // Otherwise, it returns a net error code. @@ -164,9 +174,11 @@ class SSLConnectJob : public ConnectJob { class NET_EXPORT_PRIVATE SSLClientSocketPool : public ClientSocketPool, - public LayeredPool, + public HigherLayeredPool, public SSLConfigService::Observer { public: + typedef SSLSocketParams SocketParams; + // Only the pools that will be used are required. i.e. if you never // try to create an SSL over SOCKS socket, |socks_pool| may be NULL. SSLClientSocketPool( @@ -204,13 +216,11 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool ClientSocketHandle* handle) OVERRIDE; virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) OVERRIDE; virtual void FlushWithError(int error) OVERRIDE; - virtual bool IsStalled() const OVERRIDE; - virtual void CloseIdleSockets() OVERRIDE; virtual int IdleSocketCount() const OVERRIDE; @@ -222,10 +232,6 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool const std::string& group_name, const ClientSocketHandle* handle) const OVERRIDE; - virtual void AddLayeredPool(LayeredPool* layered_pool) OVERRIDE; - - virtual void RemoveLayeredPool(LayeredPool* layered_pool) OVERRIDE; - virtual base::DictionaryValue* GetInfoAsValue( const std::string& name, const std::string& type, @@ -235,7 +241,14 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool virtual ClientSocketPoolHistograms* histograms() const OVERRIDE; - // LayeredPool implementation. + // LowerLayeredPool implementation. + virtual bool IsStalled() const OVERRIDE; + + virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; + + virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; + + // HigherLayeredPool implementation. virtual bool CloseOneIdleConnection() OVERRIDE; private: @@ -261,7 +274,7 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool virtual ~SSLConnectJobFactory() {} // ClientSocketPoolBase::ConnectJobFactory methods. - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const OVERRIDE; @@ -290,8 +303,6 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool DISALLOW_COPY_AND_ASSIGN(SSLClientSocketPool); }; -REGISTER_SOCKET_PARAMS_FOR_POOL(SSLClientSocketPool, SSLSocketParams); - } // namespace net #endif // NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_ diff --git a/chromium/net/socket/ssl_client_socket_pool_unittest.cc b/chromium/net/socket/ssl_client_socket_pool_unittest.cc index 280f6e7af1a..8aecb98dd85 100644 --- a/chromium/net/socket/ssl_client_socket_pool_unittest.cc +++ b/chromium/net/socket/ssl_client_socket_pool_unittest.cc @@ -85,7 +85,6 @@ class SSLClientSocketPoolTest session_(CreateNetworkSession()), direct_transport_socket_params_( new TransportSocketParams(HostPortPair("host", 443), - MEDIUM, false, false, OnHostResolutionCallback())), @@ -96,15 +95,13 @@ class SSLClientSocketPoolTest &socket_factory_), proxy_transport_socket_params_( new TransportSocketParams(HostPortPair("proxy", 443), - MEDIUM, false, false, OnHostResolutionCallback())), socks_socket_params_( new SOCKSSocketParams(proxy_transport_socket_params_, true, - HostPortPair("sockshost", 443), - MEDIUM)), + HostPortPair("sockshost", 443))), socks_histograms_("MockSOCKS"), socks_socket_pool_(kMaxSockets, kMaxSocketsPerGroup, @@ -159,7 +156,6 @@ class SSLClientSocketPoolTest : NULL, proxy == ProxyServer::SCHEME_SOCKS5 ? socks_socket_params_ : NULL, proxy == ProxyServer::SCHEME_HTTP ? http_proxy_socket_params_ : NULL, - proxy, HostPortPair("host", 443), ssl_config_, kPrivacyModeDisabled, @@ -294,6 +290,30 @@ TEST_P(SSLClientSocketPoolTest, BasicDirect) { TestLoadTimingInfo(handle); } +// Make sure that SSLConnectJob passes on its priority to its +// socket request on Init (for the DIRECT case). +TEST_P(SSLClientSocketPoolTest, SetSocketRequestPriorityOnInitDirect) { + CreatePool(true /* tcp pool */, false, false); + scoped_refptr<SSLSocketParams> params = + SSLParams(ProxyServer::SCHEME_DIRECT, false); + + for (int i = MINIMUM_PRIORITY; i < NUM_PRIORITIES; ++i) { + RequestPriority priority = static_cast<RequestPriority>(i); + StaticSocketDataProvider data; + data.set_connect_data(MockConnect(SYNCHRONOUS, OK)); + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(SYNCHRONOUS, OK); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(OK, handle.Init("a", params, priority, callback.callback(), + pool_.get(), BoundNetLog())); + EXPECT_EQ(priority, transport_socket_pool_.last_request_priority()); + handle.socket()->Disconnect(); + } +} + TEST_P(SSLClientSocketPoolTest, BasicDirectAsync) { StaticSocketDataProvider data; socket_factory_.AddSocketDataProvider(&data); @@ -547,6 +567,26 @@ TEST_P(SSLClientSocketPoolTest, SOCKSBasic) { TestLoadTimingInfo(handle); } +// Make sure that SSLConnectJob passes on its priority to its +// transport socket on Init (for the SOCKS_PROXY case). +TEST_P(SSLClientSocketPoolTest, SetTransportPriorityOnInitSOCKS) { + StaticSocketDataProvider data; + data.set_connect_data(MockConnect(SYNCHRONOUS, OK)); + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(SYNCHRONOUS, OK); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = + SSLParams(ProxyServer::SCHEME_SOCKS5, false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(OK, handle.Init("a", params, HIGHEST, callback.callback(), + pool_.get(), BoundNetLog())); + EXPECT_EQ(HIGHEST, transport_socket_pool_.last_request_priority()); +} + TEST_P(SSLClientSocketPoolTest, SOCKSBasicAsync) { StaticSocketDataProvider data; socket_factory_.AddSocketDataProvider(&data); @@ -648,6 +688,38 @@ TEST_P(SSLClientSocketPoolTest, HttpProxyBasic) { TestLoadTimingInfoNoDns(handle); } +// Make sure that SSLConnectJob passes on its priority to its +// transport socket on Init (for the HTTP_PROXY case). +TEST_P(SSLClientSocketPoolTest, SetTransportPriorityOnInitHTTP) { + MockWrite writes[] = { + MockWrite(SYNCHRONOUS, + "CONNECT host:80 HTTP/1.1\r\n" + "Host: host\r\n" + "Proxy-Connection: keep-alive\r\n" + "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"), + }; + MockRead reads[] = { + MockRead(SYNCHRONOUS, "HTTP/1.1 200 Connection Established\r\n\r\n"), + }; + StaticSocketDataProvider data(reads, arraysize(reads), writes, + arraysize(writes)); + data.set_connect_data(MockConnect(SYNCHRONOUS, OK)); + socket_factory_.AddSocketDataProvider(&data); + AddAuthToCache(); + SSLSocketDataProvider ssl(SYNCHRONOUS, OK); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = + SSLParams(ProxyServer::SCHEME_HTTP, false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(OK, handle.Init("a", params, HIGHEST, callback.callback(), + pool_.get(), BoundNetLog())); + EXPECT_EQ(HIGHEST, transport_socket_pool_.last_request_priority()); +} + TEST_P(SSLClientSocketPoolTest, HttpProxyBasicAsync) { MockWrite writes[] = { MockWrite("CONNECT host:80 HTTP/1.1\r\n" @@ -746,8 +818,12 @@ TEST_P(SSLClientSocketPoolTest, IPPooling) { // This test requires that the HostResolver cache be populated. Normal // code would have done this already, but we do it manually. HostResolver::RequestInfo info(HostPortPair(test_hosts[i].name, kTestPort)); - host_resolver_.Resolve(info, &test_hosts[i].addresses, CompletionCallback(), - NULL, BoundNetLog()); + host_resolver_.Resolve(info, + DEFAULT_PRIORITY, + &test_hosts[i].addresses, + CompletionCallback(), + NULL, + BoundNetLog()); // Setup a SpdySessionKey test_hosts[i].key = SpdySessionKey( @@ -802,8 +878,12 @@ void SSLClientSocketPoolTest::TestIPPoolingDisabled( // This test requires that the HostResolver cache be populated. Normal // code would have done this already, but we do it manually. HostResolver::RequestInfo info(HostPortPair(test_hosts[i].name, kTestPort)); - rv = host_resolver_.Resolve(info, &test_hosts[i].addresses, - callback.callback(), NULL, BoundNetLog()); + rv = host_resolver_.Resolve(info, + DEFAULT_PRIORITY, + &test_hosts[i].addresses, + callback.callback(), + NULL, + BoundNetLog()); EXPECT_EQ(OK, callback.GetResult(rv)); // Setup a SpdySessionKey diff --git a/chromium/net/socket/ssl_client_socket_unittest.cc b/chromium/net/socket/ssl_client_socket_unittest.cc index f0e7120a135..f791928580f 100644 --- a/chromium/net/socket/ssl_client_socket_unittest.cc +++ b/chromium/net/socket/ssl_client_socket_unittest.cc @@ -30,9 +30,11 @@ //----------------------------------------------------------------------------- +namespace net { + namespace { -const net::SSLConfig kDefaultSSLConfig; +const SSLConfig kDefaultSSLConfig; // WrappedStreamSocket is a base class that wraps an existing StreamSocket, // forwarding the Socket and StreamSocket interfaces to the underlying @@ -40,33 +42,30 @@ const net::SSLConfig kDefaultSSLConfig; // This is to provide a common base class for subclasses to override specific // StreamSocket methods for testing, while still communicating with a 'real' // StreamSocket. -class WrappedStreamSocket : public net::StreamSocket { +class WrappedStreamSocket : public StreamSocket { public: - explicit WrappedStreamSocket(scoped_ptr<net::StreamSocket> transport) - : transport_(transport.Pass()) { - } + explicit WrappedStreamSocket(scoped_ptr<StreamSocket> transport) + : transport_(transport.Pass()) {} virtual ~WrappedStreamSocket() {} // StreamSocket implementation: - virtual int Connect(const net::CompletionCallback& callback) OVERRIDE { + virtual int Connect(const CompletionCallback& callback) OVERRIDE { return transport_->Connect(callback); } - virtual void Disconnect() OVERRIDE { - transport_->Disconnect(); - } + virtual void Disconnect() OVERRIDE { transport_->Disconnect(); } virtual bool IsConnected() const OVERRIDE { return transport_->IsConnected(); } virtual bool IsConnectedAndIdle() const OVERRIDE { return transport_->IsConnectedAndIdle(); } - virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE { + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { return transport_->GetPeerAddress(address); } - virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE { + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { return transport_->GetLocalAddress(address); } - virtual const net::BoundNetLog& NetLog() const OVERRIDE { + virtual const BoundNetLog& NetLog() const OVERRIDE { return transport_->NetLog(); } virtual void SetSubresourceSpeculation() OVERRIDE { @@ -84,20 +83,22 @@ class WrappedStreamSocket : public net::StreamSocket { virtual bool WasNpnNegotiated() const OVERRIDE { return transport_->WasNpnNegotiated(); } - virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE { + virtual NextProto GetNegotiatedProtocol() const OVERRIDE { return transport_->GetNegotiatedProtocol(); } - virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE { + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { return transport_->GetSSLInfo(ssl_info); } // Socket implementation: - virtual int Read(net::IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) OVERRIDE { + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE { return transport_->Read(buf, buf_len, callback); } - virtual int Write(net::IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) OVERRIDE { + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE { return transport_->Write(buf, buf_len, callback); } virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { @@ -108,7 +109,7 @@ class WrappedStreamSocket : public net::StreamSocket { } protected: - scoped_ptr<net::StreamSocket> transport_; + scoped_ptr<StreamSocket> transport_; }; // ReadBufferingStreamSocket is a wrapper for an existing StreamSocket that @@ -119,12 +120,13 @@ class WrappedStreamSocket : public net::StreamSocket { // them from the TestServer. class ReadBufferingStreamSocket : public WrappedStreamSocket { public: - explicit ReadBufferingStreamSocket(scoped_ptr<net::StreamSocket> transport); + explicit ReadBufferingStreamSocket(scoped_ptr<StreamSocket> transport); virtual ~ReadBufferingStreamSocket() {} // Socket implementation: - virtual int Read(net::IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) OVERRIDE; + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; // Sets the internal buffer to |size|. This must not be greater than // the largest value supplied to Read() - that is, it does not handle @@ -148,19 +150,18 @@ class ReadBufferingStreamSocket : public WrappedStreamSocket { void OnReadCompleted(int result); State state_; - scoped_refptr<net::GrowableIOBuffer> read_buffer_; + scoped_refptr<GrowableIOBuffer> read_buffer_; int buffer_size_; - scoped_refptr<net::IOBuffer> user_read_buf_; - net::CompletionCallback user_read_callback_; + scoped_refptr<IOBuffer> user_read_buf_; + CompletionCallback user_read_callback_; }; ReadBufferingStreamSocket::ReadBufferingStreamSocket( - scoped_ptr<net::StreamSocket> transport) + scoped_ptr<StreamSocket> transport) : WrappedStreamSocket(transport.Pass()), - read_buffer_(new net::GrowableIOBuffer()), - buffer_size_(0) { -} + read_buffer_(new GrowableIOBuffer()), + buffer_size_(0) {} void ReadBufferingStreamSocket::SetBufferSize(int size) { DCHECK(!user_read_buf_.get()); @@ -168,19 +169,19 @@ void ReadBufferingStreamSocket::SetBufferSize(int size) { read_buffer_->SetCapacity(size); } -int ReadBufferingStreamSocket::Read(net::IOBuffer* buf, +int ReadBufferingStreamSocket::Read(IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) { + const CompletionCallback& callback) { if (buffer_size_ == 0) return transport_->Read(buf, buf_len, callback); if (buf_len < buffer_size_) - return net::ERR_UNEXPECTED; + return ERR_UNEXPECTED; state_ = STATE_READ; user_read_buf_ = buf; - int result = DoLoop(net::OK); - if (result == net::ERR_IO_PENDING) + int result = DoLoop(OK); + if (result == ERR_IO_PENDING) user_read_callback_ = callback; else user_read_buf_ = NULL; @@ -202,10 +203,10 @@ int ReadBufferingStreamSocket::DoLoop(int result) { case STATE_NONE: default: NOTREACHED() << "Unexpected state: " << current_state; - rv = net::ERR_UNEXPECTED; + rv = ERR_UNEXPECTED; break; } - } while (rv != net::ERR_IO_PENDING && state_ != STATE_NONE); + } while (rv != ERR_IO_PENDING && state_ != STATE_NONE); return rv; } @@ -227,10 +228,11 @@ int ReadBufferingStreamSocket::DoReadComplete(int result) { read_buffer_->set_offset(read_buffer_->offset() + result); if (read_buffer_->RemainingCapacity() > 0) { state_ = STATE_READ; - return net::OK; + return OK; } - memcpy(user_read_buf_->data(), read_buffer_->StartOfBuffer(), + memcpy(user_read_buf_->data(), + read_buffer_->StartOfBuffer(), read_buffer_->capacity()); read_buffer_->set_offset(0); return read_buffer_->capacity(); @@ -238,7 +240,7 @@ int ReadBufferingStreamSocket::DoReadComplete(int result) { void ReadBufferingStreamSocket::OnReadCompleted(int result) { result = DoLoop(result); - if (result == net::ERR_IO_PENDING) + if (result == ERR_IO_PENDING) return; user_read_buf_ = NULL; @@ -252,16 +254,18 @@ class SynchronousErrorStreamSocket : public WrappedStreamSocket { virtual ~SynchronousErrorStreamSocket() {} // Socket implementation: - virtual int Read(net::IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) OVERRIDE; - virtual int Write(net::IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) OVERRIDE; + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; // Sets the next Read() call and all future calls to return |error|. // If there is already a pending asynchronous read, the configured error // will not be returned until that asynchronous read has completed and Read() // is called again. - void SetNextReadError(net::Error error) { + void SetNextReadError(Error error) { DCHECK_GE(0, error); have_read_error_ = true; pending_read_error_ = error; @@ -271,7 +275,7 @@ class SynchronousErrorStreamSocket : public WrappedStreamSocket { // If there is already a pending asynchronous write, the configured error // will not be returned until that asynchronous write has completed and // Write() is called again. - void SetNextWriteError(net::Error error) { + void SetNextWriteError(Error error) { DCHECK_GE(0, error); have_write_error_ = true; pending_write_error_ = error; @@ -291,24 +295,21 @@ SynchronousErrorStreamSocket::SynchronousErrorStreamSocket( scoped_ptr<StreamSocket> transport) : WrappedStreamSocket(transport.Pass()), have_read_error_(false), - pending_read_error_(net::OK), + pending_read_error_(OK), have_write_error_(false), - pending_write_error_(net::OK) { -} + pending_write_error_(OK) {} -int SynchronousErrorStreamSocket::Read( - net::IOBuffer* buf, - int buf_len, - const net::CompletionCallback& callback) { +int SynchronousErrorStreamSocket::Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { if (have_read_error_) return pending_read_error_; return transport_->Read(buf, buf_len, callback); } -int SynchronousErrorStreamSocket::Write( - net::IOBuffer* buf, - int buf_len, - const net::CompletionCallback& callback) { +int SynchronousErrorStreamSocket::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { if (have_write_error_) return pending_write_error_; return transport_->Write(buf, buf_len, callback); @@ -324,12 +325,14 @@ class FakeBlockingStreamSocket : public WrappedStreamSocket { virtual ~FakeBlockingStreamSocket() {} // Socket implementation: - virtual int Read(net::IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) OVERRIDE { + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE { return read_state_.RunWrappedFunction(buf, buf_len, callback); } - virtual int Write(net::IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) OVERRIDE { + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE { return write_state_.RunWrappedFunction(buf, buf_len, callback); } @@ -350,9 +353,8 @@ class FakeBlockingStreamSocket : public WrappedStreamSocket { class BlockingState { public: // Wrapper for the underlying Socket function to call (ie: Read/Write). - typedef base::Callback< - int(net::IOBuffer*, int, - const net::CompletionCallback&)> WrappedSocketFunction; + typedef base::Callback<int(IOBuffer*, int, const CompletionCallback&)> + WrappedSocketFunction; explicit BlockingState(const WrappedSocketFunction& function); ~BlockingState() {} @@ -371,8 +373,9 @@ class FakeBlockingStreamSocket : public WrappedStreamSocket { // Performs the wrapped socket function on the underlying transport. If // configured to block via SetShouldBlock(), then |user_callback| will not // be invoked until Unblock() has been called. - int RunWrappedFunction(net::IOBuffer* buf, int len, - const net::CompletionCallback& user_callback); + int RunWrappedFunction(IOBuffer* buf, + int len, + const CompletionCallback& user_callback); private: // Handles completion from the underlying wrapped socket function. @@ -382,7 +385,7 @@ class FakeBlockingStreamSocket : public WrappedStreamSocket { bool should_block_; bool have_result_; int pending_result_; - net::CompletionCallback user_callback_; + CompletionCallback user_callback_; }; BlockingState read_state_; @@ -397,16 +400,14 @@ FakeBlockingStreamSocket::FakeBlockingStreamSocket( read_state_(base::Bind(&Socket::Read, base::Unretained(transport_.get()))), write_state_(base::Bind(&Socket::Write, - base::Unretained(transport_.get()))) { -} + base::Unretained(transport_.get()))) {} FakeBlockingStreamSocket::BlockingState::BlockingState( const WrappedSocketFunction& function) : wrapped_function_(function), should_block_(false), have_result_(false), - pending_result_(net::OK) { -} + pending_result_(OK) {} void FakeBlockingStreamSocket::BlockingState::SetShouldBlock() { DCHECK(!should_block_); @@ -429,24 +430,24 @@ void FakeBlockingStreamSocket::BlockingState::Unblock() { } int FakeBlockingStreamSocket::BlockingState::RunWrappedFunction( - net::IOBuffer* buf, + IOBuffer* buf, int len, - const net::CompletionCallback& callback) { + const CompletionCallback& callback) { // The callback to be called by the underlying transport. Either forward // directly to the user's callback if not set to block, or intercept it with // OnCompleted so that the user's callback is not invoked until Unblock() is // called. - net::CompletionCallback transport_callback = + CompletionCallback transport_callback = !should_block_ ? callback : base::Bind(&BlockingState::OnCompleted, base::Unretained(this)); int rv = wrapped_function_.Run(buf, len, transport_callback); if (should_block_) { user_callback_ = callback; // May have completed synchronously. - have_result_ = (rv != net::ERR_IO_PENDING); + have_result_ = (rv != ERR_IO_PENDING); pending_result_ = rv; - return net::ERR_IO_PENDING; + return ERR_IO_PENDING; } return rv; @@ -466,64 +467,61 @@ void FakeBlockingStreamSocket::BlockingState::OnCompleted(int result) { base::ResetAndReturn(&user_callback_).Run(result); } -// CompletionCallback that will delete the associated net::StreamSocket when +// CompletionCallback that will delete the associated StreamSocket when // the callback is invoked. -class DeleteSocketCallback : public net::TestCompletionCallbackBase { +class DeleteSocketCallback : public TestCompletionCallbackBase { public: - explicit DeleteSocketCallback(net::StreamSocket* socket) + explicit DeleteSocketCallback(StreamSocket* socket) : socket_(socket), callback_(base::Bind(&DeleteSocketCallback::OnComplete, - base::Unretained(this))) { - } + base::Unretained(this))) {} virtual ~DeleteSocketCallback() {} - const net::CompletionCallback& callback() const { return callback_; } + const CompletionCallback& callback() const { return callback_; } private: void OnComplete(int result) { - if (socket_) { - delete socket_; - socket_ = NULL; - } else { - ADD_FAILURE() << "Deleting socket twice"; - } - SetResult(result); + if (socket_) { + delete socket_; + socket_ = NULL; + } else { + ADD_FAILURE() << "Deleting socket twice"; + } + SetResult(result); } - net::StreamSocket* socket_; - net::CompletionCallback callback_; + StreamSocket* socket_; + CompletionCallback callback_; DISALLOW_COPY_AND_ASSIGN(DeleteSocketCallback); }; -} // namespace - class SSLClientSocketTest : public PlatformTest { public: SSLClientSocketTest() - : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()), - cert_verifier_(new net::MockCertVerifier), - transport_security_state_(new net::TransportSecurityState) { - cert_verifier_->set_default_result(net::OK); + : socket_factory_(ClientSocketFactory::GetDefaultFactory()), + cert_verifier_(new MockCertVerifier), + transport_security_state_(new TransportSecurityState) { + cert_verifier_->set_default_result(OK); context_.cert_verifier = cert_verifier_.get(); context_.transport_security_state = transport_security_state_.get(); } protected: - net::SSLClientSocket* CreateSSLClientSocket( - net::StreamSocket* transport_socket, - const net::HostPortPair& host_and_port, - const net::SSLConfig& ssl_config) { - return socket_factory_->CreateSSLClientSocket(transport_socket, - host_and_port, - ssl_config, - context_); + scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<StreamSocket> transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config) { + scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); + connection->SetSocket(transport_socket.Pass()); + return socket_factory_->CreateSSLClientSocket( + connection.Pass(), host_and_port, ssl_config, context_); } - net::ClientSocketFactory* socket_factory_; - scoped_ptr<net::MockCertVerifier> cert_verifier_; - scoped_ptr<net::TransportSecurityState> transport_security_state_; - net::SSLClientSocketContext context_; + ClientSocketFactory* socket_factory_; + scoped_ptr<MockCertVerifier> cert_verifier_; + scoped_ptr<TransportSecurityState> transport_security_state_; + SSLClientSocketContext context_; }; //----------------------------------------------------------------------------- @@ -536,45 +534,45 @@ class SSLClientSocketTest : public PlatformTest { // timeout. This means that an SSL connect end event may appear as a socket // write. static bool LogContainsSSLConnectEndEvent( - const net::CapturingNetLog::CapturedEntryList& log, int i) { - return net::LogContainsEndEvent(log, i, net::NetLog::TYPE_SSL_CONNECT) || - net::LogContainsEvent(log, i, net::NetLog::TYPE_SOCKET_BYTES_SENT, - net::NetLog::PHASE_NONE); -}; + const CapturingNetLog::CapturedEntryList& log, + int i) { + return LogContainsEndEvent(log, i, NetLog::TYPE_SSL_CONNECT) || + LogContainsEvent( + log, i, NetLog::TYPE_SOCKET_BYTES_SENT, NetLog::PHASE_NONE); +} +; TEST_F(SSLClientSocketTest, Connect) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::CapturingNetLog log; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, &log, net::NetLog::Source()); + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); - net::CapturingNetLog::CapturedEntryList entries; + CapturingNetLog::CapturedEntryList entries; log.GetEntries(&entries); - EXPECT_TRUE(net::LogContainsBeginEvent( - entries, 5, net::NetLog::TYPE_SSL_CONNECT)); - if (rv == net::ERR_IO_PENDING) + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); EXPECT_TRUE(sock->IsConnected()); log.GetEntries(&entries); EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1)); @@ -584,43 +582,40 @@ TEST_F(SSLClientSocketTest, Connect) { } TEST_F(SSLClientSocketTest, ConnectExpired) { - net::SpawnedTestServer::SSLOptions ssl_options( - net::SpawnedTestServer::SSLOptions::CERT_EXPIRED); - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - ssl_options, - base::FilePath()); + SpawnedTestServer::SSLOptions ssl_options( + SpawnedTestServer::SSLOptions::CERT_EXPIRED); + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); ASSERT_TRUE(test_server.Start()); - cert_verifier_->set_default_result(net::ERR_CERT_DATE_INVALID); + cert_verifier_->set_default_result(ERR_CERT_DATE_INVALID); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::CapturingNetLog log; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, &log, net::NetLog::Source()); + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); - net::CapturingNetLog::CapturedEntryList entries; + CapturingNetLog::CapturedEntryList entries; log.GetEntries(&entries); - EXPECT_TRUE(net::LogContainsBeginEvent( - entries, 5, net::NetLog::TYPE_SSL_CONNECT)); - if (rv == net::ERR_IO_PENDING) + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::ERR_CERT_DATE_INVALID, rv); + EXPECT_EQ(ERR_CERT_DATE_INVALID, rv); // Rather than testing whether or not the underlying socket is connected, // test that the handshake has finished. This is because it may be @@ -631,43 +626,40 @@ TEST_F(SSLClientSocketTest, ConnectExpired) { } TEST_F(SSLClientSocketTest, ConnectMismatched) { - net::SpawnedTestServer::SSLOptions ssl_options( - net::SpawnedTestServer::SSLOptions::CERT_MISMATCHED_NAME); - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - ssl_options, - base::FilePath()); + SpawnedTestServer::SSLOptions ssl_options( + SpawnedTestServer::SSLOptions::CERT_MISMATCHED_NAME); + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); ASSERT_TRUE(test_server.Start()); - cert_verifier_->set_default_result(net::ERR_CERT_COMMON_NAME_INVALID); + cert_verifier_->set_default_result(ERR_CERT_COMMON_NAME_INVALID); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::CapturingNetLog log; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, &log, net::NetLog::Source()); + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); - net::CapturingNetLog::CapturedEntryList entries; + CapturingNetLog::CapturedEntryList entries; log.GetEntries(&entries); - EXPECT_TRUE(net::LogContainsBeginEvent( - entries, 5, net::NetLog::TYPE_SSL_CONNECT)); - if (rv == net::ERR_IO_PENDING) + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::ERR_CERT_COMMON_NAME_INVALID, rv); + EXPECT_EQ(ERR_CERT_COMMON_NAME_INVALID, rv); // Rather than testing whether or not the underlying socket is connected, // test that the handshake has finished. This is because it may be @@ -680,38 +672,35 @@ TEST_F(SSLClientSocketTest, ConnectMismatched) { // Attempt to connect to a page which requests a client certificate. It should // return an error code on connect. TEST_F(SSLClientSocketTest, ConnectClientAuthCertRequested) { - net::SpawnedTestServer::SSLOptions ssl_options; + SpawnedTestServer::SSLOptions ssl_options; ssl_options.request_client_certificate = true; - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - ssl_options, - base::FilePath()); + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::CapturingNetLog log; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, &log, net::NetLog::Source()); + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); - net::CapturingNetLog::CapturedEntryList entries; + CapturingNetLog::CapturedEntryList entries; log.GetEntries(&entries); - EXPECT_TRUE(net::LogContainsBeginEvent( - entries, 5, net::NetLog::TYPE_SSL_CONNECT)); - if (rv == net::ERR_IO_PENDING) + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); log.GetEntries(&entries); @@ -731,9 +720,9 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthCertRequested) { // certificate. This test may still be useful as we'll want to close // the socket on a timeout if the user takes a long time to pick a // cert. Related bug: https://bugzilla.mozilla.org/show_bug.cgi?id=542832 - net::ExpectLogContainsSomewhere( - entries, 0, net::NetLog::TYPE_SSL_CONNECT, net::NetLog::PHASE_END); - EXPECT_EQ(net::ERR_SSL_CLIENT_AUTH_CERT_NEEDED, rv); + ExpectLogContainsSomewhere( + entries, 0, NetLog::TYPE_SSL_CONNECT, NetLog::PHASE_END); + EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, rv); EXPECT_FALSE(sock->IsConnected()); } @@ -742,32 +731,30 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthCertRequested) { // // TODO(davidben): Also test providing an actual certificate. TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) { - net::SpawnedTestServer::SSLOptions ssl_options; + SpawnedTestServer::SSLOptions ssl_options; ssl_options.request_client_certificate = true; - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - ssl_options, - base::FilePath()); + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::CapturingNetLog log; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, &log, net::NetLog::Source()); + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - net::SSLConfig ssl_config = kDefaultSSLConfig; + SSLConfig ssl_config = kDefaultSSLConfig; ssl_config.send_client_cert = true; ssl_config.client_cert = NULL; - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - ssl_config)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), ssl_config)); EXPECT_FALSE(sock->IsConnected()); @@ -775,14 +762,13 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) { // TODO(davidben): Add a test which requires them and verify the error. rv = sock->Connect(callback.callback()); - net::CapturingNetLog::CapturedEntryList entries; + CapturingNetLog::CapturedEntryList entries; log.GetEntries(&entries); - EXPECT_TRUE(net::LogContainsBeginEvent( - entries, 5, net::NetLog::TYPE_SSL_CONNECT)); - if (rv == net::ERR_IO_PENDING) + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); EXPECT_TRUE(sock->IsConnected()); log.GetEntries(&entries); EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1)); @@ -790,7 +776,7 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) { // We responded to the server's certificate request with a Certificate // message with no client certificate in it. ssl_info.client_cert_sent // should be false in this case. - net::SSLInfo ssl_info; + SSLInfo ssl_info; sock->GetSSLInfo(&ssl_info); EXPECT_FALSE(ssl_info.client_cert_sent); @@ -804,51 +790,50 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) { // - Server sends data unexpectedly. TEST_F(SSLClientSocketTest, Read) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, NULL, net::NetLog::Source()); + TestCompletionCallback callback; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); rv = sock->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); EXPECT_TRUE(sock->IsConnected()); const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; - scoped_refptr<net::IOBuffer> request_buffer( - new net::IOBuffer(arraysize(request_text) - 1)); + scoped_refptr<IOBuffer> request_buffer( + new IOBuffer(arraysize(request_text) - 1)); memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); rv = sock->Write( request_buffer.get(), arraysize(request_text) - 1, callback.callback()); - EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv); - scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(4096)); + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); for (;;) { rv = sock->Read(buf.get(), 4096, callback.callback()); - EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_GE(rv, 0); @@ -862,39 +847,40 @@ TEST_F(SSLClientSocketTest, Read) { // the socket connection uncleanly. // This is a regression test for http://crbug.com/238536 TEST_F(SSLClientSocketTest, Read_WithSynchronousError) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - scoped_ptr<net::StreamSocket> real_transport(new net::TCPClientSocket( - addr, NULL, net::NetLog::Source())); - SynchronousErrorStreamSocket* transport = new SynchronousErrorStreamSocket( - real_transport.Pass()); + TestCompletionCallback callback; + scoped_ptr<StreamSocket> real_transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + scoped_ptr<SynchronousErrorStreamSocket> transport( + new SynchronousErrorStreamSocket(real_transport.Pass())); int rv = callback.GetResult(transport->Connect(callback.callback())); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); // Disable TLS False Start to avoid handshake non-determinism. - net::SSLConfig ssl_config; + SSLConfig ssl_config; ssl_config.false_start_enabled = false; - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), + SynchronousErrorStreamSocket* raw_transport = transport.get(); + scoped_ptr<SSLClientSocket> sock( + CreateSSLClientSocket(transport.PassAs<StreamSocket>(), + test_server.host_port_pair(), ssl_config)); rv = callback.GetResult(sock->Connect(callback.callback())); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); EXPECT_TRUE(sock->IsConnected()); const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; static const int kRequestTextSize = static_cast<int>(arraysize(request_text) - 1); - scoped_refptr<net::IOBuffer> request_buffer( - new net::IOBuffer(kRequestTextSize)); + scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kRequestTextSize)); memcpy(request_buffer->data(), request_text, kRequestTextSize); rv = callback.GetResult( @@ -902,9 +888,9 @@ TEST_F(SSLClientSocketTest, Read_WithSynchronousError) { EXPECT_EQ(kRequestTextSize, rv); // Simulate an unclean/forcible shutdown. - transport->SetNextReadError(net::ERR_CONNECTION_RESET); + raw_transport->SetNextReadError(ERR_CONNECTION_RESET); - scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(4096)); + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); // Note: This test will hang if this bug has regressed. Simply checking that // rv != ERR_IO_PENDING is insufficient, as ERR_IO_PENDING is a legitimate @@ -913,7 +899,7 @@ TEST_F(SSLClientSocketTest, Read_WithSynchronousError) { #if !defined(USE_OPENSSL) // SSLClientSocketNSS records the error exactly - EXPECT_EQ(net::ERR_CONNECTION_RESET, rv); + EXPECT_EQ(ERR_CONNECTION_RESET, rv); #else // SSLClientSocketOpenSSL treats any errors as a simple EOF. EXPECT_EQ(0, rv); @@ -925,49 +911,51 @@ TEST_F(SSLClientSocketTest, Read_WithSynchronousError) { // intermediary terminates the socket connection uncleanly. // This is a regression test for http://crbug.com/249848 TEST_F(SSLClientSocketTest, Write_WithSynchronousError) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - scoped_ptr<net::StreamSocket> real_transport(new net::TCPClientSocket( - addr, NULL, net::NetLog::Source())); - // Note: |error_socket|'s ownership is handed to |transport|, but the pointer + TestCompletionCallback callback; + scoped_ptr<StreamSocket> real_transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + // Note: |error_socket|'s ownership is handed to |transport|, but a pointer // is retained in order to configure additional errors. - SynchronousErrorStreamSocket* error_socket = new SynchronousErrorStreamSocket( - real_transport.Pass()); - FakeBlockingStreamSocket* transport = new FakeBlockingStreamSocket( - scoped_ptr<net::StreamSocket>(error_socket)); + scoped_ptr<SynchronousErrorStreamSocket> error_socket( + new SynchronousErrorStreamSocket(real_transport.Pass())); + SynchronousErrorStreamSocket* raw_error_socket = error_socket.get(); + scoped_ptr<FakeBlockingStreamSocket> transport( + new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>())); + FakeBlockingStreamSocket* raw_transport = transport.get(); int rv = callback.GetResult(transport->Connect(callback.callback())); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); // Disable TLS False Start to avoid handshake non-determinism. - net::SSLConfig ssl_config; + SSLConfig ssl_config; ssl_config.false_start_enabled = false; - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), + scoped_ptr<SSLClientSocket> sock( + CreateSSLClientSocket(transport.PassAs<StreamSocket>(), + test_server.host_port_pair(), ssl_config)); rv = callback.GetResult(sock->Connect(callback.callback())); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); EXPECT_TRUE(sock->IsConnected()); const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; static const int kRequestTextSize = static_cast<int>(arraysize(request_text) - 1); - scoped_refptr<net::IOBuffer> request_buffer( - new net::IOBuffer(kRequestTextSize)); + scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kRequestTextSize)); memcpy(request_buffer->data(), request_text, kRequestTextSize); // Simulate an unclean/forcible shutdown on the underlying socket. // However, simulate this error asynchronously. - error_socket->SetNextWriteError(net::ERR_CONNECTION_RESET); - transport->SetNextWriteShouldBlock(); + raw_error_socket->SetNextWriteError(ERR_CONNECTION_RESET); + raw_transport->SetNextWriteShouldBlock(); // This write should complete synchronously, because the TLS ciphertext // can be created and placed into the outgoing buffers independent of the @@ -976,14 +964,14 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) { sock->Write(request_buffer.get(), kRequestTextSize, callback.callback())); EXPECT_EQ(kRequestTextSize, rv); - scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(4096)); + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); rv = sock->Read(buf.get(), 4096, callback.callback()); - EXPECT_EQ(net::ERR_IO_PENDING, rv); + EXPECT_EQ(ERR_IO_PENDING, rv); // Now unblock the outgoing request, having it fail with the connection // being reset. - transport->UnblockWrite(); + raw_transport->UnblockWrite(); // Note: This will cause an inifite loop if this bug has regressed. Simply // checking that rv != ERR_IO_PENDING is insufficient, as ERR_IO_PENDING @@ -992,7 +980,7 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) { #if !defined(USE_OPENSSL) // SSLClientSocketNSS records the error exactly - EXPECT_EQ(net::ERR_CONNECTION_RESET, rv); + EXPECT_EQ(ERR_CONNECTION_RESET, rv); #else // SSLClientSocketOpenSSL treats any errors as a simple EOF. EXPECT_EQ(0, rv); @@ -1002,38 +990,37 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) { // Test the full duplex mode, with Read and Write pending at the same time. // This test also serves as a regression test for http://crbug.com/29815. TEST_F(SSLClientSocketTest, Read_FullDuplex) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; // Used for everything except Write. + TestCompletionCallback callback; // Used for everything except Write. - net::StreamSocket* transport = new net::TCPClientSocket( - addr, NULL, net::NetLog::Source()); + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); rv = sock->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); EXPECT_TRUE(sock->IsConnected()); // Issue a "hanging" Read first. - scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(4096)); + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); rv = sock->Read(buf.get(), 4096, callback.callback()); // We haven't written the request, so there should be no response yet. - ASSERT_EQ(net::ERR_IO_PENDING, rv); + ASSERT_EQ(ERR_IO_PENDING, rv); // Write the request. // The request is padded with a User-Agent header to a size that causes the @@ -1043,15 +1030,14 @@ TEST_F(SSLClientSocketTest, Read_FullDuplex) { for (int i = 0; i < 3770; ++i) request_text.push_back('*'); request_text.append("\r\n\r\n"); - scoped_refptr<net::IOBuffer> request_buffer( - new net::StringIOBuffer(request_text)); + scoped_refptr<IOBuffer> request_buffer(new StringIOBuffer(request_text)); - net::TestCompletionCallback callback2; // Used for Write only. + TestCompletionCallback callback2; // Used for Write only. rv = sock->Write( request_buffer.get(), request_text.size(), callback2.callback()); - EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback2.WaitForResult(); EXPECT_EQ(static_cast<int>(request_text.size()), rv); @@ -1067,62 +1053,65 @@ TEST_F(SSLClientSocketTest, Read_FullDuplex) { // callback, the Write() callback should not be invoked. // Regression test for http://crbug.com/232633 TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - scoped_ptr<net::StreamSocket> real_transport(new net::TCPClientSocket( - addr, NULL, net::NetLog::Source())); - // Note: |error_socket|'s ownership is handed to |transport|, but the pointer + TestCompletionCallback callback; + scoped_ptr<StreamSocket> real_transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + // Note: |error_socket|'s ownership is handed to |transport|, but a pointer // is retained in order to configure additional errors. - SynchronousErrorStreamSocket* error_socket = new SynchronousErrorStreamSocket( - real_transport.Pass()); - FakeBlockingStreamSocket* transport = new FakeBlockingStreamSocket( - scoped_ptr<net::StreamSocket>(error_socket)); + scoped_ptr<SynchronousErrorStreamSocket> error_socket( + new SynchronousErrorStreamSocket(real_transport.Pass())); + SynchronousErrorStreamSocket* raw_error_socket = error_socket.get(); + scoped_ptr<FakeBlockingStreamSocket> transport( + new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>())); + FakeBlockingStreamSocket* raw_transport = transport.get(); int rv = callback.GetResult(transport->Connect(callback.callback())); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); // Disable TLS False Start to avoid handshake non-determinism. - net::SSLConfig ssl_config; + SSLConfig ssl_config; ssl_config.false_start_enabled = false; - net::SSLClientSocket* sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - ssl_config)); + scoped_ptr<SSLClientSocket> sock = + CreateSSLClientSocket(transport.PassAs<StreamSocket>(), + test_server.host_port_pair(), + ssl_config); rv = callback.GetResult(sock->Connect(callback.callback())); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); EXPECT_TRUE(sock->IsConnected()); std::string request_text = "GET / HTTP/1.1\r\nUser-Agent: long browser name "; request_text.append(20 * 1024, '*'); request_text.append("\r\n\r\n"); - scoped_refptr<net::DrainableIOBuffer> request_buffer( - new net::DrainableIOBuffer(new net::StringIOBuffer(request_text), - request_text.size())); + scoped_refptr<DrainableIOBuffer> request_buffer(new DrainableIOBuffer( + new StringIOBuffer(request_text), request_text.size())); // Simulate errors being returned from the underlying Read() and Write() ... - error_socket->SetNextReadError(net::ERR_CONNECTION_RESET); - error_socket->SetNextWriteError(net::ERR_CONNECTION_RESET); + raw_error_socket->SetNextReadError(ERR_CONNECTION_RESET); + raw_error_socket->SetNextWriteError(ERR_CONNECTION_RESET); // ... but have those errors returned asynchronously. Because the Write() will // return first, this will trigger the error. - transport->SetNextReadShouldBlock(); - transport->SetNextWriteShouldBlock(); + raw_transport->SetNextReadShouldBlock(); + raw_transport->SetNextWriteShouldBlock(); // Enqueue a Read() before calling Write(), which should "hang" due to // the ERR_IO_PENDING caused by SetReadShouldBlock() and thus return. - DeleteSocketCallback read_callback(sock); - scoped_refptr<net::IOBuffer> read_buf(new net::IOBuffer(4096)); - rv = sock->Read(read_buf.get(), 4096, read_callback.callback()); + SSLClientSocket* raw_sock = sock.get(); + DeleteSocketCallback read_callback(sock.release()); + scoped_refptr<IOBuffer> read_buf(new IOBuffer(4096)); + rv = raw_sock->Read(read_buf.get(), 4096, read_callback.callback()); // Ensure things didn't complete synchronously, otherwise |sock| is invalid. - ASSERT_EQ(net::ERR_IO_PENDING, rv); + ASSERT_EQ(ERR_IO_PENDING, rv); ASSERT_FALSE(read_callback.have_result()); #if !defined(USE_OPENSSL) @@ -1142,9 +1131,9 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) { // SSLClientSocketOpenSSL::Write() will not return until all of // |request_buffer| has been written to the underlying BIO (although not // necessarily the underlying transport). - rv = callback.GetResult(sock->Write(request_buffer.get(), - request_buffer->BytesRemaining(), - callback.callback())); + rv = callback.GetResult(raw_sock->Write(request_buffer.get(), + request_buffer->BytesRemaining(), + callback.callback())); ASSERT_LT(0, rv); request_buffer->DidConsume(rv); @@ -1157,22 +1146,22 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) { // Attempt to write the remaining data. NSS will not be able to consume the // application data because the internal buffers are full, while OpenSSL will // return that its blocked because the underlying transport is blocked. - rv = sock->Write(request_buffer.get(), - request_buffer->BytesRemaining(), - callback.callback()); - ASSERT_EQ(net::ERR_IO_PENDING, rv); + rv = raw_sock->Write(request_buffer.get(), + request_buffer->BytesRemaining(), + callback.callback()); + ASSERT_EQ(ERR_IO_PENDING, rv); ASSERT_FALSE(callback.have_result()); // Now unblock Write(), which will invoke OnSendComplete and (eventually) // call the Read() callback, deleting the socket and thus aborting calling // the Write() callback. - transport->UnblockWrite(); + raw_transport->UnblockWrite(); rv = read_callback.WaitForResult(); #if !defined(USE_OPENSSL) // NSS records the error exactly. - EXPECT_EQ(net::ERR_CONNECTION_RESET, rv); + EXPECT_EQ(ERR_CONNECTION_RESET, rv); #else // OpenSSL treats any errors as a simple EOF. EXPECT_EQ(0, rv); @@ -1183,50 +1172,49 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) { } TEST_F(SSLClientSocketTest, Read_SmallChunks) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, NULL, net::NetLog::Source()); + TestCompletionCallback callback; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); rv = sock->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; - scoped_refptr<net::IOBuffer> request_buffer( - new net::IOBuffer(arraysize(request_text) - 1)); + scoped_refptr<IOBuffer> request_buffer( + new IOBuffer(arraysize(request_text) - 1)); memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); rv = sock->Write( request_buffer.get(), arraysize(request_text) - 1, callback.callback()); - EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv); - scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(1)); + scoped_refptr<IOBuffer> buf(new IOBuffer(1)); for (;;) { rv = sock->Read(buf.get(), 1, callback.callback()); - EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_GE(rv, 0); @@ -1236,34 +1224,36 @@ TEST_F(SSLClientSocketTest, Read_SmallChunks) { } TEST_F(SSLClientSocketTest, Read_ManySmallRecords) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; + TestCompletionCallback callback; - scoped_ptr<net::StreamSocket> real_transport(new net::TCPClientSocket( - addr, NULL, net::NetLog::Source())); - ReadBufferingStreamSocket* transport = new ReadBufferingStreamSocket( - real_transport.Pass()); + scoped_ptr<StreamSocket> real_transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + scoped_ptr<ReadBufferingStreamSocket> transport( + new ReadBufferingStreamSocket(real_transport.Pass())); + ReadBufferingStreamSocket* raw_transport = transport.get(); int rv = callback.GetResult(transport->Connect(callback.callback())); - ASSERT_EQ(net::OK, rv); + ASSERT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), + scoped_ptr<SSLClientSocket> sock( + CreateSSLClientSocket(transport.PassAs<StreamSocket>(), + test_server.host_port_pair(), kDefaultSSLConfig)); rv = callback.GetResult(sock->Connect(callback.callback())); - ASSERT_EQ(net::OK, rv); + ASSERT_EQ(OK, rv); ASSERT_TRUE(sock->IsConnected()); const char request_text[] = "GET /ssl-many-small-records HTTP/1.0\r\n\r\n"; - scoped_refptr<net::IOBuffer> request_buffer( - new net::IOBuffer(arraysize(request_text) - 1)); + scoped_refptr<IOBuffer> request_buffer( + new IOBuffer(arraysize(request_text) - 1)); memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); rv = callback.GetResult(sock->Write( @@ -1280,117 +1270,114 @@ TEST_F(SSLClientSocketTest, Read_ManySmallRecords) { // 15K was chosen because 15K is smaller than the 17K (max) read issued by // the SSLClientSocket implementation, and larger than the minimum amount // of ciphertext necessary to contain the 8K of plaintext requested below. - transport->SetBufferSize(15000); + raw_transport->SetBufferSize(15000); - scoped_refptr<net::IOBuffer> buffer(new net::IOBuffer(8192)); + scoped_refptr<IOBuffer> buffer(new IOBuffer(8192)); rv = callback.GetResult(sock->Read(buffer.get(), 8192, callback.callback())); ASSERT_EQ(rv, 8192); } TEST_F(SSLClientSocketTest, Read_Interrupted) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, NULL, net::NetLog::Source()); + TestCompletionCallback callback; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); rv = sock->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; - scoped_refptr<net::IOBuffer> request_buffer( - new net::IOBuffer(arraysize(request_text) - 1)); + scoped_refptr<IOBuffer> request_buffer( + new IOBuffer(arraysize(request_text) - 1)); memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); rv = sock->Write( request_buffer.get(), arraysize(request_text) - 1, callback.callback()); - EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv); // Do a partial read and then exit. This test should not crash! - scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(512)); + scoped_refptr<IOBuffer> buf(new IOBuffer(512)); rv = sock->Read(buf.get(), 512, callback.callback()); - EXPECT_TRUE(rv > 0 || rv == net::ERR_IO_PENDING); + EXPECT_TRUE(rv > 0 || rv == ERR_IO_PENDING); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_GT(rv, 0); } TEST_F(SSLClientSocketTest, Read_FullLogging) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::CapturingNetLog log; - log.SetLogLevel(net::NetLog::LOG_ALL); - net::StreamSocket* transport = new net::TCPClientSocket( - addr, &log, net::NetLog::Source()); + TestCompletionCallback callback; + CapturingNetLog log; + log.SetLogLevel(NetLog::LOG_ALL); + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); rv = sock->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); EXPECT_TRUE(sock->IsConnected()); const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; - scoped_refptr<net::IOBuffer> request_buffer( - new net::IOBuffer(arraysize(request_text) - 1)); + scoped_refptr<IOBuffer> request_buffer( + new IOBuffer(arraysize(request_text) - 1)); memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); rv = sock->Write( request_buffer.get(), arraysize(request_text) - 1, callback.callback()); - EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv); - net::CapturingNetLog::CapturedEntryList entries; + CapturingNetLog::CapturedEntryList entries; log.GetEntries(&entries); - size_t last_index = net::ExpectLogContainsSomewhereAfter( - entries, 5, net::NetLog::TYPE_SSL_SOCKET_BYTES_SENT, - net::NetLog::PHASE_NONE); + size_t last_index = ExpectLogContainsSomewhereAfter( + entries, 5, NetLog::TYPE_SSL_SOCKET_BYTES_SENT, NetLog::PHASE_NONE); - scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(4096)); + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); for (;;) { rv = sock->Read(buf.get(), 4096, callback.callback()); - EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_GE(rv, 0); @@ -1398,61 +1385,59 @@ TEST_F(SSLClientSocketTest, Read_FullLogging) { break; log.GetEntries(&entries); - last_index = net::ExpectLogContainsSomewhereAfter( - entries, last_index + 1, net::NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED, - net::NetLog::PHASE_NONE); + last_index = + ExpectLogContainsSomewhereAfter(entries, + last_index + 1, + NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED, + NetLog::PHASE_NONE); } } // Regression test for http://crbug.com/42538 TEST_F(SSLClientSocketTest, PrematureApplicationData) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; - net::TestCompletionCallback callback; + AddressList addr; + TestCompletionCallback callback; static const unsigned char application_data[] = { - 0x17, 0x03, 0x01, 0x00, 0x4a, 0x02, 0x00, 0x00, 0x46, 0x03, 0x01, 0x4b, - 0xc2, 0xf8, 0xb2, 0xc1, 0x56, 0x42, 0xb9, 0x57, 0x7f, 0xde, 0x87, 0x46, - 0xf7, 0xa3, 0x52, 0x42, 0x21, 0xf0, 0x13, 0x1c, 0x9c, 0x83, 0x88, 0xd6, - 0x93, 0x0c, 0xf6, 0x36, 0x30, 0x05, 0x7e, 0x20, 0xb5, 0xb5, 0x73, 0x36, - 0x53, 0x83, 0x0a, 0xfc, 0x17, 0x63, 0xbf, 0xa0, 0xe4, 0x42, 0x90, 0x0d, - 0x2f, 0x18, 0x6d, 0x20, 0xd8, 0x36, 0x3f, 0xfc, 0xe6, 0x01, 0xfa, 0x0f, - 0xa5, 0x75, 0x7f, 0x09, 0x00, 0x04, 0x00, 0x16, 0x03, 0x01, 0x11, 0x57, - 0x0b, 0x00, 0x11, 0x53, 0x00, 0x11, 0x50, 0x00, 0x06, 0x22, 0x30, 0x82, - 0x06, 0x1e, 0x30, 0x82, 0x05, 0x06, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, - 0x0a - }; + 0x17, 0x03, 0x01, 0x00, 0x4a, 0x02, 0x00, 0x00, 0x46, 0x03, 0x01, 0x4b, + 0xc2, 0xf8, 0xb2, 0xc1, 0x56, 0x42, 0xb9, 0x57, 0x7f, 0xde, 0x87, 0x46, + 0xf7, 0xa3, 0x52, 0x42, 0x21, 0xf0, 0x13, 0x1c, 0x9c, 0x83, 0x88, 0xd6, + 0x93, 0x0c, 0xf6, 0x36, 0x30, 0x05, 0x7e, 0x20, 0xb5, 0xb5, 0x73, 0x36, + 0x53, 0x83, 0x0a, 0xfc, 0x17, 0x63, 0xbf, 0xa0, 0xe4, 0x42, 0x90, 0x0d, + 0x2f, 0x18, 0x6d, 0x20, 0xd8, 0x36, 0x3f, 0xfc, 0xe6, 0x01, 0xfa, 0x0f, + 0xa5, 0x75, 0x7f, 0x09, 0x00, 0x04, 0x00, 0x16, 0x03, 0x01, 0x11, 0x57, + 0x0b, 0x00, 0x11, 0x53, 0x00, 0x11, 0x50, 0x00, 0x06, 0x22, 0x30, 0x82, + 0x06, 0x1e, 0x30, 0x82, 0x05, 0x06, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, + 0x0a}; // All reads and writes complete synchronously (async=false). - net::MockRead data_reads[] = { - net::MockRead(net::SYNCHRONOUS, - reinterpret_cast<const char*>(application_data), - arraysize(application_data)), - net::MockRead(net::SYNCHRONOUS, net::OK), - }; + MockRead data_reads[] = { + MockRead(SYNCHRONOUS, + reinterpret_cast<const char*>(application_data), + arraysize(application_data)), + MockRead(SYNCHRONOUS, OK), }; - net::StaticSocketDataProvider data(data_reads, arraysize(data_reads), - NULL, 0); + StaticSocketDataProvider data(data_reads, arraysize(data_reads), NULL, 0); - net::StreamSocket* transport = - new net::MockTCPClientSocket(addr, NULL, &data); + scoped_ptr<StreamSocket> transport( + new MockTCPClientSocket(addr, NULL, &data)); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); rv = sock->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::ERR_SSL_PROTOCOL_ERROR, rv); + EXPECT_EQ(ERR_SSL_PROTOCOL_ERROR, rv); } TEST_F(SSLClientSocketTest, CipherSuiteDisables) { @@ -1460,46 +1445,41 @@ TEST_F(SSLClientSocketTest, CipherSuiteDisables) { // http://www.iana.org/assignments/tls-parameters/tls-parameters.xml, // only disabling those cipher suites that the test server actually // implements. - const uint16 kCiphersToDisable[] = { - 0x0005, // TLS_RSA_WITH_RC4_128_SHA + const uint16 kCiphersToDisable[] = {0x0005, // TLS_RSA_WITH_RC4_128_SHA }; - net::SpawnedTestServer::SSLOptions ssl_options; + SpawnedTestServer::SSLOptions ssl_options; // Enable only RC4 on the test server. - ssl_options.bulk_ciphers = - net::SpawnedTestServer::SSLOptions::BULK_CIPHER_RC4; - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - ssl_options, - base::FilePath()); + ssl_options.bulk_ciphers = SpawnedTestServer::SSLOptions::BULK_CIPHER_RC4; + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::CapturingNetLog log; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, &log, net::NetLog::Source()); + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - net::SSLConfig ssl_config; + SSLConfig ssl_config; for (size_t i = 0; i < arraysize(kCiphersToDisable); ++i) ssl_config.disabled_cipher_suites.push_back(kCiphersToDisable[i]); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - ssl_config)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), ssl_config)); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); - net::CapturingNetLog::CapturedEntryList entries; + CapturingNetLog::CapturedEntryList entries; log.GetEntries(&entries); - EXPECT_TRUE(net::LogContainsBeginEvent( - entries, 5, net::NetLog::TYPE_SSL_CONNECT)); + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); // NSS has special handling that maps a handshake_failure alert received // immediately after a client_hello to be a mismatched cipher suite error, @@ -1507,17 +1487,16 @@ TEST_F(SSLClientSocketTest, CipherSuiteDisables) { // Secure Transport (OS X), the handshake_failure is bubbled up without any // interpretation, leading to ERR_SSL_PROTOCOL_ERROR. Either way, a failure // indicates that no cipher suite was negotiated with the test server. - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_TRUE(rv == net::ERR_SSL_VERSION_OR_CIPHER_MISMATCH || - rv == net::ERR_SSL_PROTOCOL_ERROR); + EXPECT_TRUE(rv == ERR_SSL_VERSION_OR_CIPHER_MISMATCH || + rv == ERR_SSL_PROTOCOL_ERROR); // The exact ordering differs between SSLClientSocketNSS (which issues an // extra read) and SSLClientSocketMac (which does not). Just make sure the // error appears somewhere in the log. log.GetEntries(&entries); - net::ExpectLogContainsSomewhere(entries, 0, - net::NetLog::TYPE_SSL_HANDSHAKE_ERROR, - net::NetLog::PHASE_NONE); + ExpectLogContainsSomewhere( + entries, 0, NetLog::TYPE_SSL_HANDSHAKE_ERROR, NetLog::PHASE_NONE); // We cannot test sock->IsConnected(), as the NSS implementation disconnects // the socket when it encounters an error, whereas other implementations @@ -1539,65 +1518,65 @@ TEST_F(SSLClientSocketTest, CipherSuiteDisables) { // Here we verify that such a simple ClientSocketHandle, not associated with any // client socket pool, can be destroyed safely. TEST_F(SSLClientSocketTest, ClientSocketHandleNotFromPool) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, NULL, net::NetLog::Source()); + TestCompletionCallback callback; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - net::ClientSocketHandle* socket_handle = new net::ClientSocketHandle(); - socket_handle->set_socket(transport); + scoped_ptr<ClientSocketHandle> socket_handle(new ClientSocketHandle()); + socket_handle->SetSocket(transport.Pass()); - scoped_ptr<net::SSLClientSocket> sock( - socket_factory_->CreateSSLClientSocket( - socket_handle, test_server.host_port_pair(), kDefaultSSLConfig, - context_)); + scoped_ptr<SSLClientSocket> sock( + socket_factory_->CreateSSLClientSocket(socket_handle.Pass(), + test_server.host_port_pair(), + kDefaultSSLConfig, + context_)); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); } // Verifies that SSLClientSocket::ExportKeyingMaterial return a success // code and different keying label results in different keying material. TEST_F(SSLClientSocketTest, ExportKeyingMaterial) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; + TestCompletionCallback callback; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, NULL, net::NetLog::Source()); + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); rv = sock->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); EXPECT_TRUE(sock->IsConnected()); const int kKeyingMaterialSize = 32; @@ -1605,23 +1584,23 @@ TEST_F(SSLClientSocketTest, ExportKeyingMaterial) { const char* kKeyingContext = ""; unsigned char client_out1[kKeyingMaterialSize]; memset(client_out1, 0, sizeof(client_out1)); - rv = sock->ExportKeyingMaterial(kKeyingLabel1, false, kKeyingContext, - client_out1, sizeof(client_out1)); - EXPECT_EQ(rv, net::OK); + rv = sock->ExportKeyingMaterial( + kKeyingLabel1, false, kKeyingContext, client_out1, sizeof(client_out1)); + EXPECT_EQ(rv, OK); const char* kKeyingLabel2 = "client-socket-test-2"; unsigned char client_out2[kKeyingMaterialSize]; memset(client_out2, 0, sizeof(client_out2)); - rv = sock->ExportKeyingMaterial(kKeyingLabel2, false, kKeyingContext, - client_out2, sizeof(client_out2)); - EXPECT_EQ(rv, net::OK); + rv = sock->ExportKeyingMaterial( + kKeyingLabel2, false, kKeyingContext, client_out2, sizeof(client_out2)); + EXPECT_EQ(rv, OK); EXPECT_NE(memcmp(client_out1, client_out2, kKeyingMaterialSize), 0); } // Verifies that SSLClientSocket::ClearSessionCache can be called without // explicit NSS initialization. TEST(SSLClientSocket, ClearSessionCache) { - net::SSLClientSocket::ClearSessionCache(); + SSLClientSocket::ClearSessionCache(); } // This tests that SSLInfo contains a properly re-constructed certificate @@ -1639,86 +1618,84 @@ TEST(SSLClientSocket, ClearSessionCache) { TEST_F(SSLClientSocketTest, VerifyReturnChainProperlyOrdered) { // By default, cause the CertVerifier to treat all certificates as // expired. - cert_verifier_->set_default_result(net::ERR_CERT_DATE_INVALID); + cert_verifier_->set_default_result(ERR_CERT_DATE_INVALID); // We will expect SSLInfo to ultimately contain this chain. - net::CertificateList certs = CreateCertificateListFromFile( - net::GetTestCertsDirectory(), "redundant-validated-chain.pem", - net::X509Certificate::FORMAT_AUTO); + CertificateList certs = + CreateCertificateListFromFile(GetTestCertsDirectory(), + "redundant-validated-chain.pem", + X509Certificate::FORMAT_AUTO); ASSERT_EQ(3U, certs.size()); - net::X509Certificate::OSCertHandles temp_intermediates; + X509Certificate::OSCertHandles temp_intermediates; temp_intermediates.push_back(certs[1]->os_cert_handle()); temp_intermediates.push_back(certs[2]->os_cert_handle()); - net::CertVerifyResult verify_result; - verify_result.verified_cert = - net::X509Certificate::CreateFromHandle(certs[0]->os_cert_handle(), - temp_intermediates); + CertVerifyResult verify_result; + verify_result.verified_cert = X509Certificate::CreateFromHandle( + certs[0]->os_cert_handle(), temp_intermediates); // Add a rule that maps the server cert (A) to the chain of A->B->C2 // rather than A->B->C. - cert_verifier_->AddResultForCert(certs[0].get(), verify_result, net::OK); + cert_verifier_->AddResultForCert(certs[0].get(), verify_result, OK); // Load and install the root for the validated chain. - scoped_refptr<net::X509Certificate> root_cert = - net::ImportCertFromFile(net::GetTestCertsDirectory(), - "redundant-validated-chain-root.pem"); - ASSERT_NE(static_cast<net::X509Certificate*>(NULL), root_cert); - net::ScopedTestRoot scoped_root(root_cert.get()); + scoped_refptr<X509Certificate> root_cert = ImportCertFromFile( + GetTestCertsDirectory(), "redundant-validated-chain-root.pem"); + ASSERT_NE(static_cast<X509Certificate*>(NULL), root_cert); + ScopedTestRoot scoped_root(root_cert.get()); // Set up a test server with CERT_CHAIN_WRONG_ROOT. - net::SpawnedTestServer::SSLOptions ssl_options( - net::SpawnedTestServer::SSLOptions::CERT_CHAIN_WRONG_ROOT); - net::SpawnedTestServer test_server( - net::SpawnedTestServer::TYPE_HTTPS, ssl_options, + SpawnedTestServer::SSLOptions ssl_options( + SpawnedTestServer::SSLOptions::CERT_CHAIN_WRONG_ROOT); + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, + ssl_options, base::FilePath(FILE_PATH_LITERAL("net/data/ssl"))); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::CapturingNetLog log; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, &log, net::NetLog::Source()); + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); - net::CapturingNetLog::CapturedEntryList entries; + CapturingNetLog::CapturedEntryList entries; log.GetEntries(&entries); - EXPECT_TRUE(net::LogContainsBeginEvent( - entries, 5, net::NetLog::TYPE_SSL_CONNECT)); - if (rv == net::ERR_IO_PENDING) + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); EXPECT_TRUE(sock->IsConnected()); log.GetEntries(&entries); EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1)); - net::SSLInfo ssl_info; + SSLInfo ssl_info; sock->GetSSLInfo(&ssl_info); // Verify that SSLInfo contains the corrected re-constructed chain A -> B // -> C2. - const net::X509Certificate::OSCertHandles& intermediates = + const X509Certificate::OSCertHandles& intermediates = ssl_info.cert->GetIntermediateCertificates(); ASSERT_EQ(2U, intermediates.size()); - EXPECT_TRUE(net::X509Certificate::IsSameOSCert( - ssl_info.cert->os_cert_handle(), certs[0]->os_cert_handle())); - EXPECT_TRUE(net::X509Certificate::IsSameOSCert( - intermediates[0], certs[1]->os_cert_handle())); - EXPECT_TRUE(net::X509Certificate::IsSameOSCert( - intermediates[1], certs[2]->os_cert_handle())); + EXPECT_TRUE(X509Certificate::IsSameOSCert(ssl_info.cert->os_cert_handle(), + certs[0]->os_cert_handle())); + EXPECT_TRUE(X509Certificate::IsSameOSCert(intermediates[0], + certs[1]->os_cert_handle())); + EXPECT_TRUE(X509Certificate::IsSameOSCert(intermediates[1], + certs[2]->os_cert_handle())); sock->Disconnect(); EXPECT_FALSE(sock->IsConnected()); @@ -1729,37 +1706,34 @@ class SSLClientSocketCertRequestInfoTest : public SSLClientSocketTest { protected: // Creates a test server with the given SSLOptions, connects to it and returns // the SSLCertRequestInfo reported by the socket. - scoped_refptr<net::SSLCertRequestInfo> GetCertRequest( - net::SpawnedTestServer::SSLOptions ssl_options) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - ssl_options, - base::FilePath()); + scoped_refptr<SSLCertRequestInfo> GetCertRequest( + SpawnedTestServer::SSLOptions ssl_options) { + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); if (!test_server.Start()) return NULL; - net::AddressList addr; + AddressList addr; if (!test_server.GetAddressList(&addr)) return NULL; - net::TestCompletionCallback callback; - net::CapturingNetLog log; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, &log, net::NetLog::Source()); + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - scoped_refptr<net::SSLCertRequestInfo> request_info = - new net::SSLCertRequestInfo(); + scoped_refptr<SSLCertRequestInfo> request_info = new SSLCertRequestInfo(); sock->GetSSLCertRequestInfo(request_info.get()); sock->Disconnect(); EXPECT_FALSE(sock->IsConnected()); @@ -1769,10 +1743,9 @@ class SSLClientSocketCertRequestInfoTest : public SSLClientSocketTest { }; TEST_F(SSLClientSocketCertRequestInfoTest, NoAuthorities) { - net::SpawnedTestServer::SSLOptions ssl_options; + SpawnedTestServer::SSLOptions ssl_options; ssl_options.request_client_certificate = true; - scoped_refptr<net::SSLCertRequestInfo> request_info = - GetCertRequest(ssl_options); + scoped_refptr<SSLCertRequestInfo> request_info = GetCertRequest(ssl_options); ASSERT_TRUE(request_info.get()); EXPECT_EQ(0u, request_info->cert_authorities.size()); } @@ -1781,39 +1754,36 @@ TEST_F(SSLClientSocketCertRequestInfoTest, TwoAuthorities) { const base::FilePath::CharType kThawteFile[] = FILE_PATH_LITERAL("thawte.single.pem"); const unsigned char kThawteDN[] = { - 0x30, 0x4c, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, - 0x02, 0x5a, 0x41, 0x31, 0x25, 0x30, 0x23, 0x06, 0x03, 0x55, 0x04, 0x0a, - 0x13, 0x1c, 0x54, 0x68, 0x61, 0x77, 0x74, 0x65, 0x20, 0x43, 0x6f, 0x6e, - 0x73, 0x75, 0x6c, 0x74, 0x69, 0x6e, 0x67, 0x20, 0x28, 0x50, 0x74, 0x79, - 0x29, 0x20, 0x4c, 0x74, 0x64, 0x2e, 0x31, 0x16, 0x30, 0x14, 0x06, 0x03, - 0x55, 0x04, 0x03, 0x13, 0x0d, 0x54, 0x68, 0x61, 0x77, 0x74, 0x65, 0x20, - 0x53, 0x47, 0x43, 0x20, 0x43, 0x41 - }; + 0x30, 0x4c, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, + 0x02, 0x5a, 0x41, 0x31, 0x25, 0x30, 0x23, 0x06, 0x03, 0x55, 0x04, 0x0a, + 0x13, 0x1c, 0x54, 0x68, 0x61, 0x77, 0x74, 0x65, 0x20, 0x43, 0x6f, 0x6e, + 0x73, 0x75, 0x6c, 0x74, 0x69, 0x6e, 0x67, 0x20, 0x28, 0x50, 0x74, 0x79, + 0x29, 0x20, 0x4c, 0x74, 0x64, 0x2e, 0x31, 0x16, 0x30, 0x14, 0x06, 0x03, + 0x55, 0x04, 0x03, 0x13, 0x0d, 0x54, 0x68, 0x61, 0x77, 0x74, 0x65, 0x20, + 0x53, 0x47, 0x43, 0x20, 0x43, 0x41}; const size_t kThawteLen = sizeof(kThawteDN); const base::FilePath::CharType kDiginotarFile[] = FILE_PATH_LITERAL("diginotar_root_ca.pem"); const unsigned char kDiginotarDN[] = { - 0x30, 0x5f, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, - 0x02, 0x4e, 0x4c, 0x31, 0x12, 0x30, 0x10, 0x06, 0x03, 0x55, 0x04, 0x0a, - 0x13, 0x09, 0x44, 0x69, 0x67, 0x69, 0x4e, 0x6f, 0x74, 0x61, 0x72, 0x31, - 0x1a, 0x30, 0x18, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x11, 0x44, 0x69, - 0x67, 0x69, 0x4e, 0x6f, 0x74, 0x61, 0x72, 0x20, 0x52, 0x6f, 0x6f, 0x74, - 0x20, 0x43, 0x41, 0x31, 0x20, 0x30, 0x1e, 0x06, 0x09, 0x2a, 0x86, 0x48, - 0x86, 0xf7, 0x0d, 0x01, 0x09, 0x01, 0x16, 0x11, 0x69, 0x6e, 0x66, 0x6f, - 0x40, 0x64, 0x69, 0x67, 0x69, 0x6e, 0x6f, 0x74, 0x61, 0x72, 0x2e, 0x6e, - 0x6c - }; + 0x30, 0x5f, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, + 0x02, 0x4e, 0x4c, 0x31, 0x12, 0x30, 0x10, 0x06, 0x03, 0x55, 0x04, 0x0a, + 0x13, 0x09, 0x44, 0x69, 0x67, 0x69, 0x4e, 0x6f, 0x74, 0x61, 0x72, 0x31, + 0x1a, 0x30, 0x18, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x11, 0x44, 0x69, + 0x67, 0x69, 0x4e, 0x6f, 0x74, 0x61, 0x72, 0x20, 0x52, 0x6f, 0x6f, 0x74, + 0x20, 0x43, 0x41, 0x31, 0x20, 0x30, 0x1e, 0x06, 0x09, 0x2a, 0x86, 0x48, + 0x86, 0xf7, 0x0d, 0x01, 0x09, 0x01, 0x16, 0x11, 0x69, 0x6e, 0x66, 0x6f, + 0x40, 0x64, 0x69, 0x67, 0x69, 0x6e, 0x6f, 0x74, 0x61, 0x72, 0x2e, 0x6e, + 0x6c}; const size_t kDiginotarLen = sizeof(kDiginotarDN); - net::SpawnedTestServer::SSLOptions ssl_options; + SpawnedTestServer::SSLOptions ssl_options; ssl_options.request_client_certificate = true; ssl_options.client_authorities.push_back( - net::GetTestClientCertsDirectory().Append(kThawteFile)); + GetTestClientCertsDirectory().Append(kThawteFile)); ssl_options.client_authorities.push_back( - net::GetTestClientCertsDirectory().Append(kDiginotarFile)); - scoped_refptr<net::SSLCertRequestInfo> request_info = - GetCertRequest(ssl_options); + GetTestClientCertsDirectory().Append(kDiginotarFile)); + scoped_refptr<SSLCertRequestInfo> request_info = GetCertRequest(ssl_options); ASSERT_TRUE(request_info.get()); ASSERT_EQ(2u, request_info->cert_authorities.size()); EXPECT_EQ(std::string(reinterpret_cast<const char*>(kThawteDN), kThawteLen), @@ -1822,3 +1792,7 @@ TEST_F(SSLClientSocketCertRequestInfoTest, TwoAuthorities) { std::string(reinterpret_cast<const char*>(kDiginotarDN), kDiginotarLen), request_info->cert_authorities[1]); } + +} // namespace + +} // namespace net diff --git a/chromium/net/socket/ssl_server_socket.h b/chromium/net/socket/ssl_server_socket.h index 52d53cb19a2..8b607bf80cf 100644 --- a/chromium/net/socket/ssl_server_socket.h +++ b/chromium/net/socket/ssl_server_socket.h @@ -6,6 +6,7 @@ #define NET_SOCKET_SSL_SERVER_SOCKET_H_ #include "base/basictypes.h" +#include "base/memory/scoped_ptr.h" #include "net/base/completion_callback.h" #include "net/base/net_export.h" #include "net/socket/ssl_socket.h" @@ -52,8 +53,8 @@ NET_EXPORT void EnableSSLServerSockets(); // // The caller starts the SSL server handshake by calling Handshake on the // returned socket. -NET_EXPORT SSLServerSocket* CreateSSLServerSocket( - StreamSocket* socket, +NET_EXPORT scoped_ptr<SSLServerSocket> CreateSSLServerSocket( + scoped_ptr<StreamSocket> socket, X509Certificate* certificate, crypto::RSAPrivateKey* key, const SSLConfig& ssl_config); diff --git a/chromium/net/socket/ssl_server_socket_nss.cc b/chromium/net/socket/ssl_server_socket_nss.cc index c2681d3ee14..7e5d70118ac 100644 --- a/chromium/net/socket/ssl_server_socket_nss.cc +++ b/chromium/net/socket/ssl_server_socket_nss.cc @@ -78,19 +78,20 @@ void EnableSSLServerSockets() { g_nss_ssl_server_init_singleton.Get(); } -SSLServerSocket* CreateSSLServerSocket( - StreamSocket* socket, +scoped_ptr<SSLServerSocket> CreateSSLServerSocket( + scoped_ptr<StreamSocket> socket, X509Certificate* cert, crypto::RSAPrivateKey* key, const SSLConfig& ssl_config) { DCHECK(g_nss_server_sockets_init) << "EnableSSLServerSockets() has not been" << "called yet!"; - return new SSLServerSocketNSS(socket, cert, key, ssl_config); + return scoped_ptr<SSLServerSocket>( + new SSLServerSocketNSS(socket.Pass(), cert, key, ssl_config)); } SSLServerSocketNSS::SSLServerSocketNSS( - StreamSocket* transport_socket, + scoped_ptr<StreamSocket> transport_socket, scoped_refptr<X509Certificate> cert, crypto::RSAPrivateKey* key, const SSLConfig& ssl_config) @@ -100,7 +101,7 @@ SSLServerSocketNSS::SSLServerSocketNSS( user_write_buf_len_(0), nss_fd_(NULL), nss_bufs_(NULL), - transport_socket_(transport_socket), + transport_socket_(transport_socket.Pass()), ssl_config_(ssl_config), cert_(cert), next_handshake_state_(STATE_NONE), diff --git a/chromium/net/socket/ssl_server_socket_nss.h b/chromium/net/socket/ssl_server_socket_nss.h index 17a1fc38750..8bbb0e338ac 100644 --- a/chromium/net/socket/ssl_server_socket_nss.h +++ b/chromium/net/socket/ssl_server_socket_nss.h @@ -24,7 +24,7 @@ class SSLServerSocketNSS : public SSLServerSocket { public: // See comments on CreateSSLServerSocket for details of how these // parameters are used. - SSLServerSocketNSS(StreamSocket* socket, + SSLServerSocketNSS(scoped_ptr<StreamSocket> socket, scoped_refptr<X509Certificate> certificate, crypto::RSAPrivateKey* key, const SSLConfig& ssl_config); diff --git a/chromium/net/socket/ssl_server_socket_openssl.cc b/chromium/net/socket/ssl_server_socket_openssl.cc index e0cf8bc0b21..c327f2caf10 100644 --- a/chromium/net/socket/ssl_server_socket_openssl.cc +++ b/chromium/net/socket/ssl_server_socket_openssl.cc @@ -16,13 +16,13 @@ void EnableSSLServerSockets() { NOTIMPLEMENTED(); } -SSLServerSocket* CreateSSLServerSocket(StreamSocket* socket, - X509Certificate* certificate, - crypto::RSAPrivateKey* key, - const SSLConfig& ssl_config) { +scoped_ptr<SSLServerSocket> CreateSSLServerSocket( + scoped_ptr<StreamSocket> socket, + X509Certificate* certificate, + crypto::RSAPrivateKey* key, + const SSLConfig& ssl_config) { NOTIMPLEMENTED(); - delete socket; - return NULL; + return scoped_ptr<SSLServerSocket>(); } } // namespace net diff --git a/chromium/net/socket/ssl_server_socket_unittest.cc b/chromium/net/socket/ssl_server_socket_unittest.cc index f931e2c957e..e1f7f496131 100644 --- a/chromium/net/socket/ssl_server_socket_unittest.cc +++ b/chromium/net/socket/ssl_server_socket_unittest.cc @@ -304,21 +304,24 @@ class SSLServerSocketTest : public PlatformTest { protected: void Initialize() { - FakeSocket* fake_client_socket = new FakeSocket(&channel_1_, &channel_2_); - FakeSocket* fake_server_socket = new FakeSocket(&channel_2_, &channel_1_); + scoped_ptr<ClientSocketHandle> client_connection(new ClientSocketHandle); + client_connection->SetSocket( + scoped_ptr<StreamSocket>(new FakeSocket(&channel_1_, &channel_2_))); + scoped_ptr<StreamSocket> server_socket( + new FakeSocket(&channel_2_, &channel_1_)); base::FilePath certs_dir(GetTestCertsDirectory()); base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der"); std::string cert_der; - ASSERT_TRUE(file_util::ReadFileToString(cert_path, &cert_der)); + ASSERT_TRUE(base::ReadFileToString(cert_path, &cert_der)); scoped_refptr<net::X509Certificate> cert = X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size()); base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin"); std::string key_string; - ASSERT_TRUE(file_util::ReadFileToString(key_path, &key_string)); + ASSERT_TRUE(base::ReadFileToString(key_path, &key_string)); std::vector<uint8> key_vector( reinterpret_cast<const uint8*>(key_string.data()), reinterpret_cast<const uint8*>(key_string.data() + @@ -344,11 +347,12 @@ class SSLServerSocketTest : public PlatformTest { net::SSLClientSocketContext context; context.cert_verifier = cert_verifier_.get(); context.transport_security_state = transport_security_state_.get(); - client_socket_.reset( + client_socket_ = socket_factory_->CreateSSLClientSocket( - fake_client_socket, host_and_pair, ssl_config, context)); - server_socket_.reset(net::CreateSSLServerSocket( - fake_server_socket, cert.get(), private_key.get(), net::SSLConfig())); + client_connection.Pass(), host_and_pair, ssl_config, context); + server_socket_ = net::CreateSSLServerSocket( + server_socket.Pass(), + cert.get(), private_key.get(), net::SSLConfig()); } FakeDataChannel channel_1_; diff --git a/chromium/net/socket/stream_listen_socket.cc b/chromium/net/socket/stream_listen_socket.cc index c85c671800d..1109e7527c3 100644 --- a/chromium/net/socket/stream_listen_socket.cc +++ b/chromium/net/socket/stream_listen_socket.cc @@ -27,6 +27,7 @@ #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" #include "net/base/net_util.h" +#include "net/socket/socket_descriptor.h" using std::string; @@ -43,10 +44,8 @@ const int kReadBufSize = 4096; } // namespace #if defined(OS_WIN) -const SocketDescriptor StreamListenSocket::kInvalidSocket = INVALID_SOCKET; const int StreamListenSocket::kSocketError = SOCKET_ERROR; #elif defined(OS_POSIX) -const SocketDescriptor StreamListenSocket::kInvalidSocket = -1; const int StreamListenSocket::kSocketError = -1; #endif diff --git a/chromium/net/socket/stream_listen_socket.h b/chromium/net/socket/stream_listen_socket.h index 6f03eefaca2..9825a4ef126 100644 --- a/chromium/net/socket/stream_listen_socket.h +++ b/chromium/net/socket/stream_listen_socket.h @@ -30,21 +30,16 @@ #include "base/basictypes.h" #include "base/compiler_specific.h" +#include "base/memory/scoped_ptr.h" #include "net/base/net_export.h" -#include "net/socket/stream_listen_socket.h" - -#if defined(OS_POSIX) -typedef int SocketDescriptor; -#else -typedef SOCKET SocketDescriptor; -#endif +#include "net/socket/socket_descriptor.h" namespace net { class IPEndPoint; class NET_EXPORT StreamListenSocket - : public base::RefCountedThreadSafe<StreamListenSocket>, + : #if defined(OS_WIN) public base::win::ObjectWatcher::Delegate { #elif defined(OS_POSIX) @@ -52,16 +47,17 @@ class NET_EXPORT StreamListenSocket #endif public: + virtual ~StreamListenSocket(); + // TODO(erikkay): this delegate should really be split into two parts // to split up the listener from the connected socket. Perhaps this class // should be split up similarly. class Delegate { public: // |server| is the original listening Socket, connection is the new - // Socket that was created. Ownership of |connection| is transferred - // to the delegate with this call. + // Socket that was created. virtual void DidAccept(StreamListenSocket* server, - StreamListenSocket* connection) = 0; + scoped_ptr<StreamListenSocket> connection) = 0; virtual void DidRead(StreamListenSocket* connection, const char* data, int len) = 0; @@ -78,7 +74,6 @@ class NET_EXPORT StreamListenSocket // Copies the local address to |address|. Returns a network error code. int GetLocalAddress(IPEndPoint* address); - static const SocketDescriptor kInvalidSocket; static const int kSocketError; protected: @@ -89,7 +84,6 @@ class NET_EXPORT StreamListenSocket }; StreamListenSocket(SocketDescriptor s, Delegate* del); - virtual ~StreamListenSocket(); SocketDescriptor AcceptSocket(); virtual void Accept() = 0; @@ -107,7 +101,6 @@ class NET_EXPORT StreamListenSocket Delegate* const socket_delegate_; private: - friend class base::RefCountedThreadSafe<StreamListenSocket>; friend class TransportClientSocketTest; void SendInternal(const char* bytes, int len); @@ -146,7 +139,7 @@ class NET_EXPORT StreamListenSocketFactory { virtual ~StreamListenSocketFactory() {} // Returns a new instance of StreamListenSocket or NULL if an error occurred. - virtual scoped_refptr<StreamListenSocket> CreateAndListen( + virtual scoped_ptr<StreamListenSocket> CreateAndListen( StreamListenSocket::Delegate* delegate) const = 0; }; diff --git a/chromium/net/socket/tcp_client_socket.cc b/chromium/net/socket/tcp_client_socket.cc index dbd21056f39..22aea47778b 100644 --- a/chromium/net/socket/tcp_client_socket.cc +++ b/chromium/net/socket/tcp_client_socket.cc @@ -4,56 +4,317 @@ #include "net/socket/tcp_client_socket.h" -#include "base/file_util.h" -#include "base/files/file_path.h" +#include "base/callback_helpers.h" +#include "base/logging.h" +#include "net/base/io_buffer.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/base/net_util.h" namespace net { -namespace { +TCPClientSocket::TCPClientSocket(const AddressList& addresses, + net::NetLog* net_log, + const net::NetLog::Source& source) + : socket_(new TCPSocket(net_log, source)), + addresses_(addresses), + current_address_index_(-1), + next_connect_state_(CONNECT_STATE_NONE), + previously_disconnected_(false) { +} + +TCPClientSocket::TCPClientSocket(scoped_ptr<TCPSocket> connected_socket, + const IPEndPoint& peer_address) + : socket_(connected_socket.Pass()), + addresses_(AddressList(peer_address)), + current_address_index_(0), + next_connect_state_(CONNECT_STATE_NONE), + previously_disconnected_(false) { + DCHECK(socket_); + + socket_->SetDefaultOptionsForClient(); + use_history_.set_was_ever_connected(); +} + +TCPClientSocket::~TCPClientSocket() { +} + +int TCPClientSocket::Bind(const IPEndPoint& address) { + if (current_address_index_ >= 0 || bind_address_) { + // Cannot bind the socket if we are already connected or connecting. + NOTREACHED(); + return ERR_UNEXPECTED; + } + + int result = OK; + if (!socket_->IsValid()) { + result = OpenSocket(address.GetFamily()); + if (result != OK) + return result; + } + + result = socket_->Bind(address); + if (result != OK) + return result; + + bind_address_.reset(new IPEndPoint(address)); + return OK; +} + +int TCPClientSocket::Connect(const CompletionCallback& callback) { + DCHECK(!callback.is_null()); + + // If connecting or already connected, then just return OK. + if (socket_->IsValid() && current_address_index_ >= 0) + return OK; + + socket_->StartLoggingMultipleConnectAttempts(addresses_); + + // We will try to connect to each address in addresses_. Start with the + // first one in the list. + next_connect_state_ = CONNECT_STATE_CONNECT; + current_address_index_ = 0; + + int rv = DoConnectLoop(OK); + if (rv == ERR_IO_PENDING) { + connect_callback_ = callback; + } else { + socket_->EndLoggingMultipleConnectAttempts(rv); + } + + return rv; +} -#if defined(OS_LINUX) +int TCPClientSocket::DoConnectLoop(int result) { + DCHECK_NE(next_connect_state_, CONNECT_STATE_NONE); -// Checks to see if the system supports TCP FastOpen. Notably, it requires -// kernel support. Additionally, this checks system configuration to ensure that -// it's enabled. -bool SystemSupportsTCPFastOpen() { - static const base::FilePath::CharType kTCPFastOpenProcFilePath[] = - "/proc/sys/net/ipv4/tcp_fastopen"; - std::string system_enabled_tcp_fastopen; - if (!file_util::ReadFileToString( - base::FilePath(kTCPFastOpenProcFilePath), - &system_enabled_tcp_fastopen)) { - return false; + int rv = result; + do { + ConnectState state = next_connect_state_; + next_connect_state_ = CONNECT_STATE_NONE; + switch (state) { + case CONNECT_STATE_CONNECT: + DCHECK_EQ(OK, rv); + rv = DoConnect(); + break; + case CONNECT_STATE_CONNECT_COMPLETE: + rv = DoConnectComplete(rv); + break; + default: + NOTREACHED() << "bad state " << state; + rv = ERR_UNEXPECTED; + break; + } + } while (rv != ERR_IO_PENDING && next_connect_state_ != CONNECT_STATE_NONE); + + return rv; +} + +int TCPClientSocket::DoConnect() { + DCHECK_GE(current_address_index_, 0); + DCHECK_LT(current_address_index_, static_cast<int>(addresses_.size())); + + const IPEndPoint& endpoint = addresses_[current_address_index_]; + + if (previously_disconnected_) { + use_history_.Reset(); + previously_disconnected_ = false; } - // As per http://lxr.linux.no/linux+v3.7.7/include/net/tcp.h#L225 - // TFO_CLIENT_ENABLE is the LSB - if (system_enabled_tcp_fastopen.empty() || - (system_enabled_tcp_fastopen[0] & 0x1) == 0) { - return false; + next_connect_state_ = CONNECT_STATE_CONNECT_COMPLETE; + + if (socket_->IsValid()) { + DCHECK(bind_address_); + } else { + int result = OpenSocket(endpoint.GetFamily()); + if (result != OK) + return result; + + if (bind_address_) { + result = socket_->Bind(*bind_address_); + if (result != OK) { + socket_->Close(); + return result; + } + } + } + + // |socket_| is owned by this class and the callback won't be run once + // |socket_| is gone. Therefore, it is safe to use base::Unretained() here. + return socket_->Connect(endpoint, + base::Bind(&TCPClientSocket::DidCompleteConnect, + base::Unretained(this))); +} + +int TCPClientSocket::DoConnectComplete(int result) { + if (result == OK) { + use_history_.set_was_ever_connected(); + return OK; // Done! + } + + // Close whatever partially connected socket we currently have. + DoDisconnect(); + + // Try to fall back to the next address in the list. + if (current_address_index_ + 1 < static_cast<int>(addresses_.size())) { + next_connect_state_ = CONNECT_STATE_CONNECT; + ++current_address_index_; + return OK; + } + + // Otherwise there is nothing to fall back to, so give up. + return result; +} + +void TCPClientSocket::Disconnect() { + DoDisconnect(); + current_address_index_ = -1; + bind_address_.reset(); +} + +void TCPClientSocket::DoDisconnect() { + // If connecting or already connected, record that the socket has been + // disconnected. + previously_disconnected_ = socket_->IsValid() && current_address_index_ >= 0; + socket_->Close(); +} + +bool TCPClientSocket::IsConnected() const { + return socket_->IsConnected(); +} + +bool TCPClientSocket::IsConnectedAndIdle() const { + return socket_->IsConnectedAndIdle(); +} + +int TCPClientSocket::GetPeerAddress(IPEndPoint* address) const { + return socket_->GetPeerAddress(address); +} + +int TCPClientSocket::GetLocalAddress(IPEndPoint* address) const { + DCHECK(address); + + if (!socket_->IsValid()) { + if (bind_address_) { + *address = *bind_address_; + return OK; + } + return ERR_SOCKET_NOT_CONNECTED; } - return true; + return socket_->GetLocalAddress(address); +} + +const BoundNetLog& TCPClientSocket::NetLog() const { + return socket_->net_log(); +} + +void TCPClientSocket::SetSubresourceSpeculation() { + use_history_.set_subresource_speculation(); +} + +void TCPClientSocket::SetOmniboxSpeculation() { + use_history_.set_omnibox_speculation(); +} + +bool TCPClientSocket::WasEverUsed() const { + return use_history_.was_used_to_convey_data(); +} + +bool TCPClientSocket::UsingTCPFastOpen() const { + return socket_->UsingTCPFastOpen(); +} + +bool TCPClientSocket::WasNpnNegotiated() const { + return false; } -#else +NextProto TCPClientSocket::GetNegotiatedProtocol() const { + return kProtoUnknown; +} -bool SystemSupportsTCPFastOpen() { +bool TCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) { return false; } -#endif +int TCPClientSocket::Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + DCHECK(!callback.is_null()); + + // |socket_| is owned by this class and the callback won't be run once + // |socket_| is gone. Therefore, it is safe to use base::Unretained() here. + CompletionCallback read_callback = base::Bind( + &TCPClientSocket::DidCompleteReadWrite, base::Unretained(this), callback); + int result = socket_->Read(buf, buf_len, read_callback); + if (result > 0) + use_history_.set_was_used_to_convey_data(); + + return result; +} + +int TCPClientSocket::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + DCHECK(!callback.is_null()); + + // |socket_| is owned by this class and the callback won't be run once + // |socket_| is gone. Therefore, it is safe to use base::Unretained() here. + CompletionCallback write_callback = base::Bind( + &TCPClientSocket::DidCompleteReadWrite, base::Unretained(this), callback); + int result = socket_->Write(buf, buf_len, write_callback); + if (result > 0) + use_history_.set_was_used_to_convey_data(); + return result; } -static bool g_tcp_fastopen_enabled = false; +bool TCPClientSocket::SetReceiveBufferSize(int32 size) { + return socket_->SetReceiveBufferSize(size); +} + +bool TCPClientSocket::SetSendBufferSize(int32 size) { + return socket_->SetSendBufferSize(size); +} + +bool TCPClientSocket::SetKeepAlive(bool enable, int delay) { + return socket_->SetKeepAlive(enable, delay); +} -void SetTCPFastOpenEnabled(bool value) { - g_tcp_fastopen_enabled = value && SystemSupportsTCPFastOpen(); +bool TCPClientSocket::SetNoDelay(bool no_delay) { + return socket_->SetNoDelay(no_delay); } -bool IsTCPFastOpenEnabled() { - return g_tcp_fastopen_enabled; +void TCPClientSocket::DidCompleteConnect(int result) { + DCHECK_EQ(next_connect_state_, CONNECT_STATE_CONNECT_COMPLETE); + DCHECK_NE(result, ERR_IO_PENDING); + DCHECK(!connect_callback_.is_null()); + + result = DoConnectLoop(result); + if (result != ERR_IO_PENDING) { + socket_->EndLoggingMultipleConnectAttempts(result); + base::ResetAndReturn(&connect_callback_).Run(result); + } +} + +void TCPClientSocket::DidCompleteReadWrite(const CompletionCallback& callback, + int result) { + if (result > 0) + use_history_.set_was_used_to_convey_data(); + + callback.Run(result); +} + +int TCPClientSocket::OpenSocket(AddressFamily family) { + DCHECK(!socket_->IsValid()); + + int result = socket_->Open(family); + if (result != OK) + return result; + + socket_->SetDefaultOptionsForClient(); + + return OK; } } // namespace net diff --git a/chromium/net/socket/tcp_client_socket.h b/chromium/net/socket/tcp_client_socket.h index 8a2c0cd73f0..fabcbc1b39d 100644 --- a/chromium/net/socket/tcp_client_socket.h +++ b/chromium/net/socket/tcp_client_socket.h @@ -5,30 +5,116 @@ #ifndef NET_SOCKET_TCP_CLIENT_SOCKET_H_ #define NET_SOCKET_TCP_CLIENT_SOCKET_H_ -#include "build/build_config.h" +#include "base/basictypes.h" +#include "base/compiler_specific.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/address_list.h" +#include "net/base/completion_callback.h" #include "net/base/net_export.h" - -#if defined(OS_WIN) -#include "net/socket/tcp_client_socket_win.h" -#elif defined(OS_POSIX) -#include "net/socket/tcp_client_socket_libevent.h" -#endif +#include "net/base/net_log.h" +#include "net/socket/stream_socket.h" +#include "net/socket/tcp_socket.h" namespace net { // A client socket that uses TCP as the transport layer. -#if defined(OS_WIN) -typedef TCPClientSocketWin TCPClientSocket; -#elif defined(OS_POSIX) -typedef TCPClientSocketLibevent TCPClientSocket; -#endif - -// Enable/disable experimental TCP FastOpen option. -// Not thread safe. Must be called during initialization/startup only. -NET_EXPORT void SetTCPFastOpenEnabled(bool value); - -// Check if the TCP FastOpen option is enabled. -bool IsTCPFastOpenEnabled(); +class NET_EXPORT TCPClientSocket : public StreamSocket { + public: + // The IP address(es) and port number to connect to. The TCP socket will try + // each IP address in the list until it succeeds in establishing a + // connection. + TCPClientSocket(const AddressList& addresses, + net::NetLog* net_log, + const net::NetLog::Source& source); + + // Adopts the given, connected socket and then acts as if Connect() had been + // called. This function is used by TCPServerSocket and for testing. + TCPClientSocket(scoped_ptr<TCPSocket> connected_socket, + const IPEndPoint& peer_address); + + virtual ~TCPClientSocket(); + + // Binds the socket to a local IP address and port. + int Bind(const IPEndPoint& address); + + // StreamSocket implementation. + virtual int Connect(const CompletionCallback& callback) OVERRIDE; + virtual void Disconnect() OVERRIDE; + virtual bool IsConnected() const OVERRIDE; + virtual bool IsConnectedAndIdle() const OVERRIDE; + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; + virtual const BoundNetLog& NetLog() const OVERRIDE; + virtual void SetSubresourceSpeculation() OVERRIDE; + virtual void SetOmniboxSpeculation() OVERRIDE; + virtual bool WasEverUsed() const OVERRIDE; + virtual bool UsingTCPFastOpen() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; + virtual NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + + // Socket implementation. + // Multiple outstanding requests are not supported. + // Full duplex mode (reading and writing at the same time) is supported. + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; + virtual bool SetSendBufferSize(int32 size) OVERRIDE; + + virtual bool SetKeepAlive(bool enable, int delay); + virtual bool SetNoDelay(bool no_delay); + + private: + // State machine for connecting the socket. + enum ConnectState { + CONNECT_STATE_CONNECT, + CONNECT_STATE_CONNECT_COMPLETE, + CONNECT_STATE_NONE, + }; + + // State machine used by Connect(). + int DoConnectLoop(int result); + int DoConnect(); + int DoConnectComplete(int result); + + // Helper used by Disconnect(), which disconnects minus resetting + // current_address_index_ and bind_address_. + void DoDisconnect(); + + void DidCompleteConnect(int result); + void DidCompleteReadWrite(const CompletionCallback& callback, int result); + + int OpenSocket(AddressFamily family); + + scoped_ptr<TCPSocket> socket_; + + // Local IP address and port we are bound to. Set to NULL if Bind() + // wasn't called (in that case OS chooses address/port). + scoped_ptr<IPEndPoint> bind_address_; + + // The list of addresses we should try in order to establish a connection. + AddressList addresses_; + + // Where we are in above list. Set to -1 if uninitialized. + int current_address_index_; + + // External callback; called when connect is complete. + CompletionCallback connect_callback_; + + // The next state for the Connect() state machine. + ConnectState next_connect_state_; + + // This socket was previously disconnected and has not been re-connected. + bool previously_disconnected_; + + // Record of connectivity and transmissions, for use in speculative connection + // histograms. + UseHistory use_history_; + + DISALLOW_COPY_AND_ASSIGN(TCPClientSocket); +}; } // namespace net diff --git a/chromium/net/socket/tcp_client_socket_libevent.h b/chromium/net/socket/tcp_client_socket_libevent.h deleted file mode 100644 index e5a0d8deab4..00000000000 --- a/chromium/net/socket/tcp_client_socket_libevent.h +++ /dev/null @@ -1,256 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef NET_SOCKET_TCP_CLIENT_SOCKET_LIBEVENT_H_ -#define NET_SOCKET_TCP_CLIENT_SOCKET_LIBEVENT_H_ - -#include "base/memory/ref_counted.h" -#include "base/memory/scoped_ptr.h" -#include "base/message_loop/message_loop.h" -#include "base/threading/non_thread_safe.h" -#include "net/base/address_list.h" -#include "net/base/completion_callback.h" -#include "net/base/net_log.h" -#include "net/socket/stream_socket.h" - -namespace net { - -class BoundNetLog; - -// A client socket that uses TCP as the transport layer. -class NET_EXPORT_PRIVATE TCPClientSocketLibevent : public StreamSocket, - public base::NonThreadSafe { - public: - // The IP address(es) and port number to connect to. The TCP socket will try - // each IP address in the list until it succeeds in establishing a - // connection. - TCPClientSocketLibevent(const AddressList& addresses, - net::NetLog* net_log, - const net::NetLog::Source& source); - - virtual ~TCPClientSocketLibevent(); - - // AdoptSocket causes the given, connected socket to be adopted as a TCP - // socket. This object must not be connected. This object takes ownership of - // the given socket and then acts as if Connect() had been called. This - // function is used by TCPServerSocket() to adopt accepted connections - // and for testing. - int AdoptSocket(int socket); - - // Binds the socket to a local IP address and port. - int Bind(const IPEndPoint& address); - - // StreamSocket implementation. - virtual int Connect(const CompletionCallback& callback) OVERRIDE; - virtual void Disconnect() OVERRIDE; - virtual bool IsConnected() const OVERRIDE; - virtual bool IsConnectedAndIdle() const OVERRIDE; - virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; - virtual const BoundNetLog& NetLog() const OVERRIDE; - virtual void SetSubresourceSpeculation() OVERRIDE; - virtual void SetOmniboxSpeculation() OVERRIDE; - virtual bool WasEverUsed() const OVERRIDE; - virtual bool UsingTCPFastOpen() const OVERRIDE; - virtual bool WasNpnNegotiated() const OVERRIDE; - virtual NextProto GetNegotiatedProtocol() const OVERRIDE; - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; - - // Socket implementation. - // Multiple outstanding requests are not supported. - // Full duplex mode (reading and writing at the same time) is supported - virtual int Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int Write(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; - virtual bool SetSendBufferSize(int32 size) OVERRIDE; - - virtual bool SetKeepAlive(bool enable, int delay); - virtual bool SetNoDelay(bool no_delay); - - private: - // State machine for connecting the socket. - enum ConnectState { - CONNECT_STATE_CONNECT, - CONNECT_STATE_CONNECT_COMPLETE, - CONNECT_STATE_NONE, - }; - - // States that a fast open socket attempt can result in. - enum FastOpenStatus { - FAST_OPEN_STATUS_UNKNOWN, - - // The initial fast open connect attempted returned synchronously, - // indicating that we had and sent a cookie along with the initial data. - FAST_OPEN_FAST_CONNECT_RETURN, - - // The initial fast open connect attempted returned asynchronously, - // indicating that we did not have a cookie for the server. - FAST_OPEN_SLOW_CONNECT_RETURN, - - // Some other error occurred on connection, so we couldn't tell if - // fast open would have worked. - FAST_OPEN_ERROR, - - // An attempt to do a fast open succeeded immediately - // (FAST_OPEN_FAST_CONNECT_RETURN) and we later confirmed that the server - // had acked the data we sent. - FAST_OPEN_SYN_DATA_ACK, - - // An attempt to do a fast open succeeded immediately - // (FAST_OPEN_FAST_CONNECT_RETURN) and we later confirmed that the server - // had nacked the data we sent. - FAST_OPEN_SYN_DATA_NACK, - - // An attempt to do a fast open succeeded immediately - // (FAST_OPEN_FAST_CONNECT_RETURN) and our probe to determine if the - // socket was using fast open failed. - FAST_OPEN_SYN_DATA_FAILED, - - // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN) - // and we later confirmed that the server had acked initial data. This - // should never happen (we didn't send data, so it shouldn't have - // been acked). - FAST_OPEN_NO_SYN_DATA_ACK, - - // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN) - // and we later discovered that the server had nacked initial data. This - // is the expected case results for FAST_OPEN_SLOW_CONNECT_RETURN. - FAST_OPEN_NO_SYN_DATA_NACK, - - // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN) - // and our later probe for ack/nack state failed. - FAST_OPEN_NO_SYN_DATA_FAILED, - - FAST_OPEN_MAX_VALUE - }; - - class ReadWatcher : public base::MessageLoopForIO::Watcher { - public: - explicit ReadWatcher(TCPClientSocketLibevent* socket) : socket_(socket) {} - - // MessageLoopForIO::Watcher methods - - virtual void OnFileCanReadWithoutBlocking(int /* fd */) OVERRIDE; - - virtual void OnFileCanWriteWithoutBlocking(int /* fd */) OVERRIDE {} - - private: - TCPClientSocketLibevent* const socket_; - - DISALLOW_COPY_AND_ASSIGN(ReadWatcher); - }; - - class WriteWatcher : public base::MessageLoopForIO::Watcher { - public: - explicit WriteWatcher(TCPClientSocketLibevent* socket) : socket_(socket) {} - - // MessageLoopForIO::Watcher implementation. - virtual void OnFileCanReadWithoutBlocking(int /* fd */) OVERRIDE {} - virtual void OnFileCanWriteWithoutBlocking(int /* fd */) OVERRIDE; - - private: - TCPClientSocketLibevent* const socket_; - - DISALLOW_COPY_AND_ASSIGN(WriteWatcher); - }; - - // State machine used by Connect(). - int DoConnectLoop(int result); - int DoConnect(); - int DoConnectComplete(int result); - - // Helper used by Disconnect(), which disconnects minus the logging and - // resetting of current_address_index_. - void DoDisconnect(); - - void DoReadCallback(int rv); - void DoWriteCallback(int rv); - void DidCompleteRead(); - void DidCompleteWrite(); - void DidCompleteConnect(); - - // Returns true if a Connect() is in progress. - bool waiting_connect() const { - return next_connect_state_ != CONNECT_STATE_NONE; - } - - // Helper to add a TCP_CONNECT (end) event to the NetLog. - void LogConnectCompletion(int net_error); - - // Internal function to write to a socket. - int InternalWrite(IOBuffer* buf, int buf_len); - - // Called when the socket is known to be in a connected state. - void RecordFastOpenStatus(); - - int socket_; - - // Local IP address and port we are bound to. Set to NULL if Bind() - // was't called (in that cases OS chooses address/port). - scoped_ptr<IPEndPoint> bind_address_; - - // Stores bound socket between Bind() and Connect() calls. - int bound_socket_; - - // The list of addresses we should try in order to establish a connection. - AddressList addresses_; - - // Where we are in above list. Set to -1 if uninitialized. - int current_address_index_; - - // The socket's libevent wrappers - base::MessageLoopForIO::FileDescriptorWatcher read_socket_watcher_; - base::MessageLoopForIO::FileDescriptorWatcher write_socket_watcher_; - - // The corresponding watchers for reads and writes. - ReadWatcher read_watcher_; - WriteWatcher write_watcher_; - - // The buffer used by OnSocketReady to retry Read requests - scoped_refptr<IOBuffer> read_buf_; - int read_buf_len_; - - // The buffer used by OnSocketReady to retry Write requests - scoped_refptr<IOBuffer> write_buf_; - int write_buf_len_; - - // External callback; called when read is complete. - CompletionCallback read_callback_; - - // External callback; called when write is complete. - CompletionCallback write_callback_; - - // The next state for the Connect() state machine. - ConnectState next_connect_state_; - - // The OS error that CONNECT_STATE_CONNECT last completed with. - int connect_os_error_; - - BoundNetLog net_log_; - - // This socket was previously disconnected and has not been re-connected. - bool previously_disconnected_; - - // Record of connectivity and transmissions, for use in speculative connection - // histograms. - UseHistory use_history_; - - // Enables experimental TCP FastOpen option. - const bool use_tcp_fastopen_; - - // True when TCP FastOpen is in use and we have done the connect. - bool tcp_fastopen_connected_; - - enum FastOpenStatus fast_open_status_; - - DISALLOW_COPY_AND_ASSIGN(TCPClientSocketLibevent); -}; - -} // namespace net - -#endif // NET_SOCKET_TCP_CLIENT_SOCKET_LIBEVENT_H_ diff --git a/chromium/net/socket/tcp_client_socket_win.h b/chromium/net/socket/tcp_client_socket_win.h deleted file mode 100644 index 26c8b9feff2..00000000000 --- a/chromium/net/socket/tcp_client_socket_win.h +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef NET_SOCKET_TCP_CLIENT_SOCKET_WIN_H_ -#define NET_SOCKET_TCP_CLIENT_SOCKET_WIN_H_ - -#include <winsock2.h> - -#include "base/memory/scoped_ptr.h" -#include "base/threading/non_thread_safe.h" -#include "net/base/address_list.h" -#include "net/base/completion_callback.h" -#include "net/base/net_log.h" -#include "net/socket/stream_socket.h" - -namespace net { - -class BoundNetLog; - -class NET_EXPORT TCPClientSocketWin : public StreamSocket, - NON_EXPORTED_BASE(base::NonThreadSafe) { - public: - // The IP address(es) and port number to connect to. The TCP socket will try - // each IP address in the list until it succeeds in establishing a - // connection. - TCPClientSocketWin(const AddressList& addresses, - net::NetLog* net_log, - const net::NetLog::Source& source); - - virtual ~TCPClientSocketWin(); - - // AdoptSocket causes the given, connected socket to be adopted as a TCP - // socket. This object must not be connected. This object takes ownership of - // the given socket and then acts as if Connect() had been called. This - // function is used by TCPServerSocket() to adopt accepted connections - // and for testing. - int AdoptSocket(SOCKET socket); - - // Binds the socket to a local IP address and port. - int Bind(const IPEndPoint& address); - - // StreamSocket implementation. - virtual int Connect(const CompletionCallback& callback); - virtual void Disconnect(); - virtual bool IsConnected() const; - virtual bool IsConnectedAndIdle() const; - virtual int GetPeerAddress(IPEndPoint* address) const; - virtual int GetLocalAddress(IPEndPoint* address) const; - virtual const BoundNetLog& NetLog() const { return net_log_; } - virtual void SetSubresourceSpeculation(); - virtual void SetOmniboxSpeculation(); - virtual bool WasEverUsed() const; - virtual bool UsingTCPFastOpen() const; - virtual bool WasNpnNegotiated() const OVERRIDE; - virtual NextProto GetNegotiatedProtocol() const OVERRIDE; - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; - - // Socket implementation. - // Multiple outstanding requests are not supported. - // Full duplex mode (reading and writing at the same time) is supported - virtual int Read(IOBuffer* buf, int buf_len, - const CompletionCallback& callback); - virtual int Write(IOBuffer* buf, int buf_len, - const CompletionCallback& callback); - - virtual bool SetReceiveBufferSize(int32 size); - virtual bool SetSendBufferSize(int32 size); - - virtual bool SetKeepAlive(bool enable, int delay); - virtual bool SetNoDelay(bool no_delay); - - // Perform reads in non-blocking mode instead of overlapped mode. - // Used for experiments. - static void DisableOverlappedReads(); - - private: - // State machine for connecting the socket. - enum ConnectState { - CONNECT_STATE_CONNECT, - CONNECT_STATE_CONNECT_COMPLETE, - CONNECT_STATE_NONE, - }; - - class Core; - - // State machine used by Connect(). - int DoConnectLoop(int result); - int DoConnect(); - int DoConnectComplete(int result); - - // Helper used by Disconnect(), which disconnects minus the logging and - // resetting of current_address_index_. - void DoDisconnect(); - - // Returns true if a Connect() is in progress. - bool waiting_connect() const { - return next_connect_state_ != CONNECT_STATE_NONE; - } - - // Called after Connect() has completed with |net_error|. - void LogConnectCompletion(int net_error); - - int DoRead(IOBuffer* buf, int buf_len, const CompletionCallback& callback); - void DoReadCallback(int rv); - void DoWriteCallback(int rv); - void DidCompleteConnect(); - void DidCompleteRead(); - void DidCompleteWrite(); - void DidSignalRead(); - - SOCKET socket_; - - // Local IP address and port we are bound to. Set to NULL if Bind() - // was't called (in that cases OS chooses address/port). - scoped_ptr<IPEndPoint> bind_address_; - - // Stores bound socket between Bind() and Connect() calls. - SOCKET bound_socket_; - - // The list of addresses we should try in order to establish a connection. - AddressList addresses_; - - // Where we are in above list. Set to -1 if uninitialized. - int current_address_index_; - - // The various states that the socket could be in. - bool waiting_read_; - bool waiting_write_; - - // The core of the socket that can live longer than the socket itself. We pass - // resources to the Windows async IO functions and we have to make sure that - // they are not destroyed while the OS still references them. - scoped_refptr<Core> core_; - - // External callback; called when connect or read is complete. - CompletionCallback read_callback_; - - // External callback; called when write is complete. - CompletionCallback write_callback_; - - // The next state for the Connect() state machine. - ConnectState next_connect_state_; - - // The OS error that CONNECT_STATE_CONNECT last completed with. - int connect_os_error_; - - BoundNetLog net_log_; - - // This socket was previously disconnected and has not been re-connected. - bool previously_disconnected_; - - // Record of connectivity and transmissions, for use in speculative connection - // histograms. - UseHistory use_history_; - - DISALLOW_COPY_AND_ASSIGN(TCPClientSocketWin); -}; - -} // namespace net - -#endif // NET_SOCKET_TCP_CLIENT_SOCKET_WIN_H_ diff --git a/chromium/net/socket/tcp_listen_socket.cc b/chromium/net/socket/tcp_listen_socket.cc index aab2e45d0e9..223abee2cba 100644 --- a/chromium/net/socket/tcp_listen_socket.cc +++ b/chromium/net/socket/tcp_listen_socket.cc @@ -23,20 +23,21 @@ #include "build/build_config.h" #include "net/base/net_util.h" #include "net/base/winsock_init.h" +#include "net/socket/socket_descriptor.h" using std::string; namespace net { // static -scoped_refptr<TCPListenSocket> TCPListenSocket::CreateAndListen( +scoped_ptr<TCPListenSocket> TCPListenSocket::CreateAndListen( const string& ip, int port, StreamListenSocket::Delegate* del) { SocketDescriptor s = CreateAndBind(ip, port); if (s == kInvalidSocket) - return NULL; - scoped_refptr<TCPListenSocket> sock(new TCPListenSocket(s, del)); + return scoped_ptr<TCPListenSocket>(); + scoped_ptr<TCPListenSocket> sock(new TCPListenSocket(s, del)); sock->Listen(); - return sock; + return sock.Pass(); } TCPListenSocket::TCPListenSocket(SocketDescriptor s, @@ -47,11 +48,7 @@ TCPListenSocket::TCPListenSocket(SocketDescriptor s, TCPListenSocket::~TCPListenSocket() {} SocketDescriptor TCPListenSocket::CreateAndBind(const string& ip, int port) { -#if defined(OS_WIN) - EnsureWinsockInit(); -#endif - - SocketDescriptor s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + SocketDescriptor s = CreatePlatformSocket(AF_INET, SOCK_STREAM, IPPROTO_TCP); if (s != kInvalidSocket) { #if defined(OS_POSIX) // Allow rapid reuse. @@ -104,13 +101,13 @@ void TCPListenSocket::Accept() { SocketDescriptor conn = AcceptSocket(); if (conn == kInvalidSocket) return; - scoped_refptr<TCPListenSocket> sock( + scoped_ptr<TCPListenSocket> sock( new TCPListenSocket(conn, socket_delegate_)); // It's up to the delegate to AddRef if it wants to keep it around. #if defined(OS_POSIX) sock->WatchSocket(WAITING_READ); #endif - socket_delegate_->DidAccept(this, sock.get()); + socket_delegate_->DidAccept(this, sock.PassAs<StreamListenSocket>()); } TCPListenSocketFactory::TCPListenSocketFactory(const string& ip, int port) @@ -120,9 +117,10 @@ TCPListenSocketFactory::TCPListenSocketFactory(const string& ip, int port) TCPListenSocketFactory::~TCPListenSocketFactory() {} -scoped_refptr<StreamListenSocket> TCPListenSocketFactory::CreateAndListen( +scoped_ptr<StreamListenSocket> TCPListenSocketFactory::CreateAndListen( StreamListenSocket::Delegate* delegate) const { - return TCPListenSocket::CreateAndListen(ip_, port_, delegate); + return TCPListenSocket::CreateAndListen(ip_, port_, delegate) + .PassAs<StreamListenSocket>(); } } // namespace net diff --git a/chromium/net/socket/tcp_listen_socket.h b/chromium/net/socket/tcp_listen_socket.h index dbc5347e945..54a91de59bb 100644 --- a/chromium/net/socket/tcp_listen_socket.h +++ b/chromium/net/socket/tcp_listen_socket.h @@ -8,18 +8,19 @@ #include <string> #include "base/basictypes.h" -#include "base/memory/ref_counted.h" #include "net/base/net_export.h" +#include "net/socket/socket_descriptor.h" #include "net/socket/stream_listen_socket.h" namespace net { -// Implements a TCP socket. Note that this is ref counted. +// Implements a TCP socket. class NET_EXPORT TCPListenSocket : public StreamListenSocket { public: + virtual ~TCPListenSocket(); // Listen on port for the specified IP address. Use 127.0.0.1 to only // accept local connections. - static scoped_refptr<TCPListenSocket> CreateAndListen( + static scoped_ptr<TCPListenSocket> CreateAndListen( const std::string& ip, int port, StreamListenSocket::Delegate* del); // Get raw TCP socket descriptor bound to ip:port. @@ -30,10 +31,7 @@ class NET_EXPORT TCPListenSocket : public StreamListenSocket { int* port); protected: - friend class scoped_refptr<TCPListenSocket>; - TCPListenSocket(SocketDescriptor s, StreamListenSocket::Delegate* del); - virtual ~TCPListenSocket(); // Implements StreamListenSocket::Accept. virtual void Accept() OVERRIDE; @@ -49,7 +47,7 @@ class NET_EXPORT TCPListenSocketFactory : public StreamListenSocketFactory { virtual ~TCPListenSocketFactory(); // StreamListenSocketFactory overrides. - virtual scoped_refptr<StreamListenSocket> CreateAndListen( + virtual scoped_ptr<StreamListenSocket> CreateAndListen( StreamListenSocket::Delegate* delegate) const OVERRIDE; private: diff --git a/chromium/net/socket/tcp_listen_socket_unittest.cc b/chromium/net/socket/tcp_listen_socket_unittest.cc index d13b784cbdc..b122c6143d8 100644 --- a/chromium/net/socket/tcp_listen_socket_unittest.cc +++ b/chromium/net/socket/tcp_listen_socket_unittest.cc @@ -10,13 +10,14 @@ #include "base/bind.h" #include "base/posix/eintr_wrapper.h" #include "base/sys_byteorder.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" #include "net/base/net_util.h" +#include "net/socket/socket_descriptor.h" #include "testing/platform_test.h" namespace net { -const int TCPListenSocketTester::kTestPort = 9999; - static const int kReadBufSize = 1024; static const char kHelloWorld[] = "HELLO, WORLD"; static const int kMaxQueueSize = 20; @@ -24,7 +25,9 @@ static const char kLoopback[] = "127.0.0.1"; static const int kDefaultTimeoutMs = 5000; TCPListenSocketTester::TCPListenSocketTester() - : loop_(NULL), server_(NULL), connection_(NULL), cv_(&lock_) {} + : loop_(NULL), + cv_(&lock_), + server_port_(0) {} void TCPListenSocketTester::SetUp() { base::Thread::Options options; @@ -41,13 +44,16 @@ void TCPListenSocketTester::SetUp() { ASSERT_FALSE(server_.get() == NULL); ASSERT_EQ(ACTION_LISTEN, last_action_.type()); + int server_port = GetServerPort(); + ASSERT_GT(server_port, 0); + // verify the connect/accept and setup test_socket_ - test_socket_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); - ASSERT_NE(StreamListenSocket::kInvalidSocket, test_socket_); + test_socket_ = CreatePlatformSocket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + ASSERT_NE(kInvalidSocket, test_socket_); struct sockaddr_in client; client.sin_family = AF_INET; client.sin_addr.s_addr = inet_addr(kLoopback); - client.sin_port = base::HostToNet16(kTestPort); + client.sin_port = base::HostToNet16(server_port); int ret = HANDLE_EINTR( connect(test_socket_, reinterpret_cast<sockaddr*>(&client), sizeof(client))); @@ -113,17 +119,20 @@ int TCPListenSocketTester::ClearTestSocket() { } void TCPListenSocketTester::Shutdown() { - connection_->Release(); - connection_ = NULL; - server_->Release(); - server_ = NULL; + connection_.reset(); + server_.reset(); ReportAction(TCPListenSocketTestAction(ACTION_SHUTDOWN)); } void TCPListenSocketTester::Listen() { server_ = DoListen(); ASSERT_TRUE(server_.get()); - server_->AddRef(); + + // The server's port will be needed to open the client socket. + IPEndPoint local_address; + ASSERT_EQ(OK, server_->GetLocalAddress(&local_address)); + SetServerPort(local_address.port()); + ReportAction(TCPListenSocketTestAction(ACTION_LISTEN)); } @@ -227,10 +236,10 @@ bool TCPListenSocketTester::Send(SocketDescriptor sock, return true; } -void TCPListenSocketTester::DidAccept(StreamListenSocket* server, - StreamListenSocket* connection) { - connection_ = connection; - connection_->AddRef(); +void TCPListenSocketTester::DidAccept( + StreamListenSocket* server, + scoped_ptr<StreamListenSocket> connection) { + connection_ = connection.Pass(); ReportAction(TCPListenSocketTestAction(ACTION_ACCEPT)); } @@ -247,11 +256,22 @@ void TCPListenSocketTester::DidClose(StreamListenSocket* sock) { TCPListenSocketTester::~TCPListenSocketTester() {} -scoped_refptr<TCPListenSocket> TCPListenSocketTester::DoListen() { - return TCPListenSocket::CreateAndListen(kLoopback, kTestPort, this); +scoped_ptr<TCPListenSocket> TCPListenSocketTester::DoListen() { + // Let the OS pick a free port. + return TCPListenSocket::CreateAndListen(kLoopback, 0, this); +} + +int TCPListenSocketTester::GetServerPort() { + base::AutoLock locked(lock_); + return server_port_; +} + +void TCPListenSocketTester::SetServerPort(int server_port) { + base::AutoLock locked(lock_); + server_port_ = server_port; } -class TCPListenSocketTest: public PlatformTest { +class TCPListenSocketTest : public PlatformTest { public: TCPListenSocketTest() { tester_ = NULL; diff --git a/chromium/net/socket/tcp_listen_socket_unittest.h b/chromium/net/socket/tcp_listen_socket_unittest.h index 048a0186705..1bc31a8d1ce 100644 --- a/chromium/net/socket/tcp_listen_socket_unittest.h +++ b/chromium/net/socket/tcp_listen_socket_unittest.h @@ -91,30 +91,37 @@ class TCPListenSocketTester : // StreamListenSocket::Delegate: virtual void DidAccept(StreamListenSocket* server, - StreamListenSocket* connection) OVERRIDE; + scoped_ptr<StreamListenSocket> connection) OVERRIDE; virtual void DidRead(StreamListenSocket* connection, const char* data, int len) OVERRIDE; virtual void DidClose(StreamListenSocket* sock) OVERRIDE; scoped_ptr<base::Thread> thread_; base::MessageLoopForIO* loop_; - scoped_refptr<TCPListenSocket> server_; - StreamListenSocket* connection_; + scoped_ptr<TCPListenSocket> server_; + scoped_ptr<StreamListenSocket> connection_; TCPListenSocketTestAction last_action_; SocketDescriptor test_socket_; - static const int kTestPort; - base::Lock lock_; // protects |queue_| and wraps |cv_| + base::Lock lock_; // Protects |queue_| and |server_port_|. Wraps |cv_|. base::ConditionVariable cv_; std::deque<TCPListenSocketTestAction> queue_; - protected: + private: friend class base::RefCountedThreadSafe<TCPListenSocketTester>; virtual ~TCPListenSocketTester(); - virtual scoped_refptr<TCPListenSocket> DoListen(); + virtual scoped_ptr<TCPListenSocket> DoListen(); + + // Getters/setters for |server_port_|. They use |lock_| for thread safety. + int GetServerPort(); + void SetServerPort(int server_port); + + // Port the server is using. Must have |lock_| to access. Set by Listen() on + // the server's thread. + int server_port_; }; } // namespace net diff --git a/chromium/net/socket/tcp_server_socket.cc b/chromium/net/socket/tcp_server_socket.cc new file mode 100644 index 00000000000..a25f73f6c6f --- /dev/null +++ b/chromium/net/socket/tcp_server_socket.cc @@ -0,0 +1,105 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/tcp_server_socket.h" + +#include "base/bind.h" +#include "base/bind_helpers.h" +#include "base/logging.h" +#include "net/base/net_errors.h" +#include "net/socket/tcp_client_socket.h" + +namespace net { + +TCPServerSocket::TCPServerSocket(NetLog* net_log, const NetLog::Source& source) + : socket_(net_log, source), + pending_accept_(false) { +} + +TCPServerSocket::~TCPServerSocket() { +} + +int TCPServerSocket::Listen(const IPEndPoint& address, int backlog) { + int result = socket_.Open(address.GetFamily()); + if (result != OK) + return result; + + result = socket_.SetDefaultOptionsForServer(); + if (result != OK) { + socket_.Close(); + return result; + } + + result = socket_.Bind(address); + if (result != OK) { + socket_.Close(); + return result; + } + + result = socket_.Listen(backlog); + if (result != OK) { + socket_.Close(); + return result; + } + + return OK; +} + +int TCPServerSocket::GetLocalAddress(IPEndPoint* address) const { + return socket_.GetLocalAddress(address); +} + +int TCPServerSocket::Accept(scoped_ptr<StreamSocket>* socket, + const CompletionCallback& callback) { + DCHECK(socket); + DCHECK(!callback.is_null()); + + if (pending_accept_) { + NOTREACHED(); + return ERR_UNEXPECTED; + } + + // It is safe to use base::Unretained(this). |socket_| is owned by this class, + // and the callback won't be run after |socket_| is destroyed. + CompletionCallback accept_callback = + base::Bind(&TCPServerSocket::OnAcceptCompleted, base::Unretained(this), + socket, callback); + int result = socket_.Accept(&accepted_socket_, &accepted_address_, + accept_callback); + if (result != ERR_IO_PENDING) { + // |accept_callback| won't be called so we need to run + // ConvertAcceptedSocket() ourselves in order to do the conversion from + // |accepted_socket_| to |socket|. + result = ConvertAcceptedSocket(result, socket); + } else { + pending_accept_ = true; + } + + return result; +} + +int TCPServerSocket::ConvertAcceptedSocket( + int result, + scoped_ptr<StreamSocket>* output_accepted_socket) { + // Make sure the TCPSocket object is destroyed in any case. + scoped_ptr<TCPSocket> temp_accepted_socket(accepted_socket_.Pass()); + if (result != OK) + return result; + + output_accepted_socket->reset(new TCPClientSocket( + temp_accepted_socket.Pass(), accepted_address_)); + + return OK; +} + +void TCPServerSocket::OnAcceptCompleted( + scoped_ptr<StreamSocket>* output_accepted_socket, + const CompletionCallback& forward_callback, + int result) { + result = ConvertAcceptedSocket(result, output_accepted_socket); + pending_accept_ = false; + forward_callback.Run(result); +} + +} // namespace net diff --git a/chromium/net/socket/tcp_server_socket.h b/chromium/net/socket/tcp_server_socket.h index 4970a150e8d..faff9ad826a 100644 --- a/chromium/net/socket/tcp_server_socket.h +++ b/chromium/net/socket/tcp_server_socket.h @@ -5,21 +5,48 @@ #ifndef NET_SOCKET_TCP_SERVER_SOCKET_H_ #define NET_SOCKET_TCP_SERVER_SOCKET_H_ -#include "build/build_config.h" - -#if defined(OS_WIN) -#include "net/socket/tcp_server_socket_win.h" -#elif defined(OS_POSIX) -#include "net/socket/tcp_server_socket_libevent.h" -#endif +#include "base/basictypes.h" +#include "base/compiler_specific.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_export.h" +#include "net/base/net_log.h" +#include "net/socket/server_socket.h" +#include "net/socket/tcp_socket.h" namespace net { -#if defined(OS_WIN) -typedef TCPServerSocketWin TCPServerSocket; -#elif defined(OS_POSIX) -typedef TCPServerSocketLibevent TCPServerSocket; -#endif +class NET_EXPORT_PRIVATE TCPServerSocket : public ServerSocket { + public: + TCPServerSocket(NetLog* net_log, const NetLog::Source& source); + virtual ~TCPServerSocket(); + + // net::ServerSocket implementation. + virtual int Listen(const IPEndPoint& address, int backlog) OVERRIDE; + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; + virtual int Accept(scoped_ptr<StreamSocket>* socket, + const CompletionCallback& callback) OVERRIDE; + + private: + // Converts |accepted_socket_| and stores the result in + // |output_accepted_socket|. + // |output_accepted_socket| is untouched on failure. But |accepted_socket_| is + // set to NULL in any case. + int ConvertAcceptedSocket(int result, + scoped_ptr<StreamSocket>* output_accepted_socket); + // Completion callback for calling TCPSocket::Accept(). + void OnAcceptCompleted(scoped_ptr<StreamSocket>* output_accepted_socket, + const CompletionCallback& forward_callback, + int result); + + TCPSocket socket_; + + scoped_ptr<TCPSocket> accepted_socket_; + IPEndPoint accepted_address_; + bool pending_accept_; + + DISALLOW_COPY_AND_ASSIGN(TCPServerSocket); +}; } // namespace net diff --git a/chromium/net/socket/tcp_server_socket_libevent.cc b/chromium/net/socket/tcp_server_socket_libevent.cc deleted file mode 100644 index 38dda962f46..00000000000 --- a/chromium/net/socket/tcp_server_socket_libevent.cc +++ /dev/null @@ -1,223 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "net/socket/tcp_server_socket_libevent.h" - -#include <errno.h> -#include <fcntl.h> -#include <netdb.h> -#include <sys/socket.h> - -#include "build/build_config.h" - -#if defined(OS_POSIX) -#include <netinet/in.h> -#endif - -#include "base/posix/eintr_wrapper.h" -#include "net/base/ip_endpoint.h" -#include "net/base/net_errors.h" -#include "net/base/net_util.h" -#include "net/socket/socket_net_log_params.h" -#include "net/socket/tcp_client_socket.h" - -namespace net { - -namespace { - -const int kInvalidSocket = -1; - -} // namespace - -TCPServerSocketLibevent::TCPServerSocketLibevent( - net::NetLog* net_log, - const net::NetLog::Source& source) - : socket_(kInvalidSocket), - accept_socket_(NULL), - net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) { - net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE, - source.ToEventParametersCallback()); -} - -TCPServerSocketLibevent::~TCPServerSocketLibevent() { - if (socket_ != kInvalidSocket) - Close(); - net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE); -} - -int TCPServerSocketLibevent::Listen(const IPEndPoint& address, int backlog) { - DCHECK(CalledOnValidThread()); - DCHECK_GT(backlog, 0); - DCHECK_EQ(socket_, kInvalidSocket); - - socket_ = socket(address.GetSockAddrFamily(), SOCK_STREAM, IPPROTO_TCP); - if (socket_ < 0) { - PLOG(ERROR) << "socket() returned an error"; - return MapSystemError(errno); - } - - if (SetNonBlocking(socket_)) { - int result = MapSystemError(errno); - Close(); - return result; - } - - int result = SetSocketOptions(); - if (result != OK) { - Close(); - return result; - } - - SockaddrStorage storage; - if (!address.ToSockAddr(storage.addr, &storage.addr_len)) { - Close(); - return ERR_ADDRESS_INVALID; - } - - result = bind(socket_, storage.addr, storage.addr_len); - if (result < 0) { - PLOG(ERROR) << "bind() returned an error"; - result = MapSystemError(errno); - Close(); - return result; - } - - result = listen(socket_, backlog); - if (result < 0) { - PLOG(ERROR) << "listen() returned an error"; - result = MapSystemError(errno); - Close(); - return result; - } - - return OK; -} - -int TCPServerSocketLibevent::GetLocalAddress(IPEndPoint* address) const { - DCHECK(CalledOnValidThread()); - DCHECK(address); - - SockaddrStorage storage; - if (getsockname(socket_, storage.addr, &storage.addr_len) < 0) - return MapSystemError(errno); - if (!address->FromSockAddr(storage.addr, storage.addr_len)) - return ERR_FAILED; - - return OK; -} - -int TCPServerSocketLibevent::Accept( - scoped_ptr<StreamSocket>* socket, const CompletionCallback& callback) { - DCHECK(CalledOnValidThread()); - DCHECK(socket); - DCHECK(!callback.is_null()); - DCHECK(accept_callback_.is_null()); - - net_log_.BeginEvent(NetLog::TYPE_TCP_ACCEPT); - - int result = AcceptInternal(socket); - - if (result == ERR_IO_PENDING) { - if (!base::MessageLoopForIO::current()->WatchFileDescriptor( - socket_, true, base::MessageLoopForIO::WATCH_READ, - &accept_socket_watcher_, this)) { - PLOG(ERROR) << "WatchFileDescriptor failed on read"; - return MapSystemError(errno); - } - - accept_socket_ = socket; - accept_callback_ = callback; - } - - return result; -} - -int TCPServerSocketLibevent::SetSocketOptions() { - // SO_REUSEADDR is useful for server sockets to bind to a recently unbound - // port. When a socket is closed, the end point changes its state to TIME_WAIT - // and wait for 2 MSL (maximum segment lifetime) to ensure the remote peer - // acknowledges its closure. For server sockets, it is usually safe to - // bind to a TIME_WAIT end point immediately, which is a widely adopted - // behavior. - // - // Note that on *nix, SO_REUSEADDR does not enable the TCP socket to bind to - // an end point that is already bound by another socket. To do that one must - // set SO_REUSEPORT instead. This option is not provided on Linux prior - // to 3.9. - // - // SO_REUSEPORT is provided in MacOS X and iOS. - int true_value = 1; - int rv = setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &true_value, - sizeof(true_value)); - if (rv < 0) - return MapSystemError(errno); - return OK; -} - -int TCPServerSocketLibevent::AcceptInternal( - scoped_ptr<StreamSocket>* socket) { - SockaddrStorage storage; - int new_socket = HANDLE_EINTR(accept(socket_, - storage.addr, - &storage.addr_len)); - if (new_socket < 0) { - int net_error = MapSystemError(errno); - if (net_error != ERR_IO_PENDING) - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, net_error); - return net_error; - } - - IPEndPoint address; - if (!address.FromSockAddr(storage.addr, storage.addr_len)) { - NOTREACHED(); - if (HANDLE_EINTR(close(new_socket)) < 0) - PLOG(ERROR) << "close"; - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, ERR_FAILED); - return ERR_FAILED; - } - scoped_ptr<TCPClientSocket> tcp_socket(new TCPClientSocket( - AddressList(address), - net_log_.net_log(), net_log_.source())); - int adopt_result = tcp_socket->AdoptSocket(new_socket); - if (adopt_result != OK) { - if (HANDLE_EINTR(close(new_socket)) < 0) - PLOG(ERROR) << "close"; - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, adopt_result); - return adopt_result; - } - socket->reset(tcp_socket.release()); - net_log_.EndEvent(NetLog::TYPE_TCP_ACCEPT, - CreateNetLogIPEndPointCallback(&address)); - return OK; -} - -void TCPServerSocketLibevent::Close() { - if (socket_ != kInvalidSocket) { - bool ok = accept_socket_watcher_.StopWatchingFileDescriptor(); - DCHECK(ok); - if (HANDLE_EINTR(close(socket_)) < 0) - PLOG(ERROR) << "close"; - socket_ = kInvalidSocket; - } -} - -void TCPServerSocketLibevent::OnFileCanReadWithoutBlocking(int fd) { - DCHECK(CalledOnValidThread()); - - int result = AcceptInternal(accept_socket_); - if (result != ERR_IO_PENDING) { - accept_socket_ = NULL; - bool ok = accept_socket_watcher_.StopWatchingFileDescriptor(); - DCHECK(ok); - CompletionCallback callback = accept_callback_; - accept_callback_.Reset(); - callback.Run(result); - } -} - -void TCPServerSocketLibevent::OnFileCanWriteWithoutBlocking(int fd) { - NOTREACHED(); -} - -} // namespace net diff --git a/chromium/net/socket/tcp_server_socket_libevent.h b/chromium/net/socket/tcp_server_socket_libevent.h deleted file mode 100644 index fe69472a653..00000000000 --- a/chromium/net/socket/tcp_server_socket_libevent.h +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) 2011 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef NET_SOCKET_TCP_SERVER_SOCKET_LIBEVENT_H_ -#define NET_SOCKET_TCP_SERVER_SOCKET_LIBEVENT_H_ - -#include "base/memory/scoped_ptr.h" -#include "base/message_loop/message_loop.h" -#include "base/threading/non_thread_safe.h" -#include "net/base/completion_callback.h" -#include "net/base/net_log.h" -#include "net/socket/server_socket.h" - -namespace net { - -class IPEndPoint; - -class NET_EXPORT_PRIVATE TCPServerSocketLibevent : - public ServerSocket, - public base::NonThreadSafe, - public base::MessageLoopForIO::Watcher { - public: - TCPServerSocketLibevent(net::NetLog* net_log, - const net::NetLog::Source& source); - virtual ~TCPServerSocketLibevent(); - - // net::ServerSocket implementation. - virtual int Listen(const net::IPEndPoint& address, int backlog) OVERRIDE; - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; - virtual int Accept(scoped_ptr<StreamSocket>* socket, - const CompletionCallback& callback) OVERRIDE; - - // MessageLoopForIO::Watcher implementation. - virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE; - virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE; - - private: - int SetSocketOptions(); - int AcceptInternal(scoped_ptr<StreamSocket>* socket); - void Close(); - - int socket_; - - base::MessageLoopForIO::FileDescriptorWatcher accept_socket_watcher_; - - scoped_ptr<StreamSocket>* accept_socket_; - CompletionCallback accept_callback_; - - BoundNetLog net_log_; -}; - -} // namespace net - -#endif // NET_SOCKET_TCP_SERVER_SOCKET_LIBEVENT_H_ diff --git a/chromium/net/socket/tcp_server_socket_win.cc b/chromium/net/socket/tcp_server_socket_win.cc deleted file mode 100644 index 0ac77be5e81..00000000000 --- a/chromium/net/socket/tcp_server_socket_win.cc +++ /dev/null @@ -1,217 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "net/socket/tcp_server_socket_win.h" - -#include <mstcpip.h> - -#include "net/base/ip_endpoint.h" -#include "net/base/net_errors.h" -#include "net/base/net_util.h" -#include "net/base/winsock_init.h" -#include "net/base/winsock_util.h" -#include "net/socket/socket_net_log_params.h" -#include "net/socket/tcp_client_socket.h" - -namespace net { - -TCPServerSocketWin::TCPServerSocketWin(net::NetLog* net_log, - const net::NetLog::Source& source) - : socket_(INVALID_SOCKET), - socket_event_(WSA_INVALID_EVENT), - accept_socket_(NULL), - net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) { - net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE, - source.ToEventParametersCallback()); - EnsureWinsockInit(); -} - -TCPServerSocketWin::~TCPServerSocketWin() { - Close(); - net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE); -} - -int TCPServerSocketWin::Listen(const IPEndPoint& address, int backlog) { - DCHECK(CalledOnValidThread()); - DCHECK_GT(backlog, 0); - DCHECK_EQ(socket_, INVALID_SOCKET); - DCHECK_EQ(socket_event_, WSA_INVALID_EVENT); - - socket_event_ = WSACreateEvent(); - if (socket_event_ == WSA_INVALID_EVENT) { - PLOG(ERROR) << "WSACreateEvent()"; - return ERR_FAILED; - } - - socket_ = socket(address.GetSockAddrFamily(), SOCK_STREAM, IPPROTO_TCP); - if (socket_ == INVALID_SOCKET) { - PLOG(ERROR) << "socket() returned an error"; - return MapSystemError(WSAGetLastError()); - } - - if (SetNonBlocking(socket_)) { - int result = MapSystemError(WSAGetLastError()); - Close(); - return result; - } - - int result = SetSocketOptions(); - if (result != OK) { - Close(); - return result; - } - - SockaddrStorage storage; - if (!address.ToSockAddr(storage.addr, &storage.addr_len)) { - Close(); - return ERR_ADDRESS_INVALID; - } - - result = bind(socket_, storage.addr, storage.addr_len); - if (result < 0) { - PLOG(ERROR) << "bind() returned an error"; - result = MapSystemError(WSAGetLastError()); - Close(); - return result; - } - - result = listen(socket_, backlog); - if (result < 0) { - PLOG(ERROR) << "listen() returned an error"; - result = MapSystemError(WSAGetLastError()); - Close(); - return result; - } - - return OK; -} - -int TCPServerSocketWin::GetLocalAddress(IPEndPoint* address) const { - DCHECK(CalledOnValidThread()); - DCHECK(address); - - SockaddrStorage storage; - if (getsockname(socket_, storage.addr, &storage.addr_len)) - return MapSystemError(WSAGetLastError()); - if (!address->FromSockAddr(storage.addr, storage.addr_len)) - return ERR_FAILED; - - return OK; -} - -int TCPServerSocketWin::Accept( - scoped_ptr<StreamSocket>* socket, const CompletionCallback& callback) { - DCHECK(CalledOnValidThread()); - DCHECK(socket); - DCHECK(!callback.is_null()); - DCHECK(accept_callback_.is_null()); - - net_log_.BeginEvent(NetLog::TYPE_TCP_ACCEPT); - - int result = AcceptInternal(socket); - - if (result == ERR_IO_PENDING) { - // Start watching - WSAEventSelect(socket_, socket_event_, FD_ACCEPT); - accept_watcher_.StartWatching(socket_event_, this); - - accept_socket_ = socket; - accept_callback_ = callback; - } - - return result; -} - -int TCPServerSocketWin::SetSocketOptions() { - // On Windows, a bound end point can be hijacked by another process by - // setting SO_REUSEADDR. Therefore a Windows-only option SO_EXCLUSIVEADDRUSE - // was introduced in Windows NT 4.0 SP4. If the socket that is bound to the - // end point has SO_EXCLUSIVEADDRUSE enabled, it is not possible for another - // socket to forcibly bind to the end point until the end point is unbound. - // It is recommend that all server applications must use SO_EXCLUSIVEADDRUSE. - // MSDN: http://goo.gl/M6fjQ. - // - // Unlike on *nix, on Windows a TCP server socket can always bind to an end - // point in TIME_WAIT state without setting SO_REUSEADDR, therefore it is not - // needed here. - // - // SO_EXCLUSIVEADDRUSE will prevent a TCP client socket from binding to an end - // point in TIME_WAIT status. It does not have this effect for a TCP server - // socket. - - BOOL true_value = 1; - int rv = setsockopt(socket_, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, - reinterpret_cast<const char*>(&true_value), - sizeof(true_value)); - if (rv < 0) - return MapSystemError(errno); - return OK; -} - -int TCPServerSocketWin::AcceptInternal(scoped_ptr<StreamSocket>* socket) { - SockaddrStorage storage; - int new_socket = accept(socket_, storage.addr, &storage.addr_len); - if (new_socket < 0) { - int net_error = MapSystemError(WSAGetLastError()); - if (net_error != ERR_IO_PENDING) - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, net_error); - return net_error; - } - - IPEndPoint address; - if (!address.FromSockAddr(storage.addr, storage.addr_len)) { - NOTREACHED(); - if (closesocket(new_socket) < 0) - PLOG(ERROR) << "closesocket"; - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, ERR_FAILED); - return ERR_FAILED; - } - scoped_ptr<TCPClientSocket> tcp_socket(new TCPClientSocket( - AddressList(address), - net_log_.net_log(), net_log_.source())); - int adopt_result = tcp_socket->AdoptSocket(new_socket); - if (adopt_result != OK) { - if (closesocket(new_socket) < 0) - PLOG(ERROR) << "closesocket"; - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, adopt_result); - return adopt_result; - } - socket->reset(tcp_socket.release()); - net_log_.EndEvent(NetLog::TYPE_TCP_ACCEPT, - CreateNetLogIPEndPointCallback(&address)); - return OK; -} - -void TCPServerSocketWin::Close() { - if (socket_ != INVALID_SOCKET) { - if (closesocket(socket_) < 0) - PLOG(ERROR) << "closesocket"; - socket_ = INVALID_SOCKET; - } - - if (socket_event_) { - WSACloseEvent(socket_event_); - socket_event_ = WSA_INVALID_EVENT; - } -} - -void TCPServerSocketWin::OnObjectSignaled(HANDLE object) { - WSANETWORKEVENTS ev; - if (WSAEnumNetworkEvents(socket_, socket_event_, &ev) == SOCKET_ERROR) { - PLOG(ERROR) << "WSAEnumNetworkEvents()"; - return; - } - - if (ev.lNetworkEvents & FD_ACCEPT) { - int result = AcceptInternal(accept_socket_); - if (result != ERR_IO_PENDING) { - accept_socket_ = NULL; - CompletionCallback callback = accept_callback_; - accept_callback_.Reset(); - callback.Run(result); - } - } -} - -} // namespace net diff --git a/chromium/net/socket/tcp_server_socket_win.h b/chromium/net/socket/tcp_server_socket_win.h deleted file mode 100644 index 5a1d378ad9b..00000000000 --- a/chromium/net/socket/tcp_server_socket_win.h +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2011 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef NET_SOCKET_TCP_SERVER_SOCKET_WIN_H_ -#define NET_SOCKET_TCP_SERVER_SOCKET_WIN_H_ - -#include <winsock2.h> - -#include "base/memory/scoped_ptr.h" -#include "base/message_loop/message_loop.h" -#include "base/threading/non_thread_safe.h" -#include "base/win/object_watcher.h" -#include "net/base/completion_callback.h" -#include "net/base/net_log.h" -#include "net/socket/server_socket.h" - -namespace net { - -class IPEndPoint; - -class NET_EXPORT_PRIVATE TCPServerSocketWin - : public ServerSocket, - NON_EXPORTED_BASE(public base::NonThreadSafe), - public base::win::ObjectWatcher::Delegate { - public: - TCPServerSocketWin(net::NetLog* net_log, - const net::NetLog::Source& source); - ~TCPServerSocketWin(); - - // net::ServerSocket implementation. - virtual int Listen(const net::IPEndPoint& address, int backlog) OVERRIDE; - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; - virtual int Accept(scoped_ptr<StreamSocket>* socket, - const CompletionCallback& callback) OVERRIDE; - - // base::ObjectWatcher::Delegate implementation. - virtual void OnObjectSignaled(HANDLE object); - - private: - int SetSocketOptions(); - int AcceptInternal(scoped_ptr<StreamSocket>* socket); - void Close(); - - SOCKET socket_; - HANDLE socket_event_; - - base::win::ObjectWatcher accept_watcher_; - - scoped_ptr<StreamSocket>* accept_socket_; - CompletionCallback accept_callback_; - - BoundNetLog net_log_; -}; - -} // namespace net - -#endif // NET_SOCKET_TCP_SERVER_SOCKET_WIN_H_ diff --git a/chromium/net/socket/tcp_socket.cc b/chromium/net/socket/tcp_socket.cc new file mode 100644 index 00000000000..fd72f6b4640 --- /dev/null +++ b/chromium/net/socket/tcp_socket.cc @@ -0,0 +1,59 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/tcp_socket.h" + +#include "base/file_util.h" +#include "base/files/file_path.h" + +namespace net { + +namespace { + +#if defined(OS_LINUX) + +// Checks to see if the system supports TCP FastOpen. Notably, it requires +// kernel support. Additionally, this checks system configuration to ensure that +// it's enabled. +bool SystemSupportsTCPFastOpen() { + static const base::FilePath::CharType kTCPFastOpenProcFilePath[] = + "/proc/sys/net/ipv4/tcp_fastopen"; + std::string system_enabled_tcp_fastopen; + if (!base::ReadFileToString( + base::FilePath(kTCPFastOpenProcFilePath), + &system_enabled_tcp_fastopen)) { + return false; + } + + // As per http://lxr.linux.no/linux+v3.7.7/include/net/tcp.h#L225 + // TFO_CLIENT_ENABLE is the LSB + if (system_enabled_tcp_fastopen.empty() || + (system_enabled_tcp_fastopen[0] & 0x1) == 0) { + return false; + } + + return true; +} + +#else + +bool SystemSupportsTCPFastOpen() { + return false; +} + +#endif + +bool g_tcp_fastopen_enabled = false; + +} // namespace + +void SetTCPFastOpenEnabled(bool value) { + g_tcp_fastopen_enabled = value && SystemSupportsTCPFastOpen(); +} + +bool IsTCPFastOpenEnabled() { + return g_tcp_fastopen_enabled; +} + +} // namespace net diff --git a/chromium/net/socket/tcp_socket.h b/chromium/net/socket/tcp_socket.h new file mode 100644 index 00000000000..8b36fade758 --- /dev/null +++ b/chromium/net/socket/tcp_socket.h @@ -0,0 +1,40 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_TCP_SOCKET_H_ +#define NET_SOCKET_TCP_SOCKET_H_ + +#include "build/build_config.h" +#include "net/base/net_export.h" + +#if defined(OS_WIN) +#include "net/socket/tcp_socket_win.h" +#elif defined(OS_POSIX) +#include "net/socket/tcp_socket_libevent.h" +#endif + +namespace net { + +// Enable/disable experimental TCP FastOpen option. +// Not thread safe. Must be called during initialization/startup only. +NET_EXPORT void SetTCPFastOpenEnabled(bool value); + +// Check if the TCP FastOpen option is enabled. +bool IsTCPFastOpenEnabled(); + +// TCPSocket provides a platform-independent interface for TCP sockets. +// +// It is recommended to use TCPClientSocket/TCPServerSocket instead of this +// class, unless a clear separation of client and server socket functionality is +// not suitable for your use case (e.g., a socket needs to be created and bound +// before you know whether it is a client or server socket). +#if defined(OS_WIN) +typedef TCPSocketWin TCPSocket; +#elif defined(OS_POSIX) +typedef TCPSocketLibevent TCPSocket; +#endif + +} // namespace net + +#endif // NET_SOCKET_TCP_SOCKET_H_ diff --git a/chromium/net/socket/tcp_client_socket_libevent.cc b/chromium/net/socket/tcp_socket_libevent.cc index 2f7e4b4b255..66416f70207 100644 --- a/chromium/net/socket/tcp_client_socket_libevent.cc +++ b/chromium/net/socket/tcp_socket_libevent.cc @@ -1,29 +1,27 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "net/socket/tcp_client_socket.h" +#include "net/socket/tcp_socket.h" #include <errno.h> #include <fcntl.h> #include <netdb.h> -#include <sys/socket.h> -#include <netinet/tcp.h> -#if defined(OS_POSIX) #include <netinet/in.h> -#endif +#include <netinet/tcp.h> +#include <sys/socket.h> +#include "base/callback_helpers.h" #include "base/logging.h" -#include "base/message_loop/message_loop.h" #include "base/metrics/histogram.h" #include "base/metrics/stats_counters.h" #include "base/posix/eintr_wrapper.h" -#include "base/strings/string_util.h" +#include "build/build_config.h" +#include "net/base/address_list.h" #include "net/base/connection_type_histograms.h" #include "net/base/io_buffer.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" -#include "net/base/net_log.h" #include "net/base/net_util.h" #include "net/base/network_change_notifier.h" #include "net/socket/socket_net_log_params.h" @@ -37,7 +35,6 @@ namespace net { namespace { -const int kInvalidSocket = -1; const int kTCPKeepAliveSeconds = 45; // SetTCPNoDelay turns on/off buffering in the kernel. By default, TCP sockets @@ -46,13 +43,12 @@ const int kTCPKeepAliveSeconds = 45; // `man 7 tcp`. bool SetTCPNoDelay(int fd, bool no_delay) { int on = no_delay ? 1 : 0; - int error = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &on, - sizeof(on)); + int error = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)); return error == 0; } // SetTCPKeepAlive sets SO_KEEPALIVE. -bool SetTCPKeepAlive(int fd, bool enable, int delay) { +bool SetTCPKeepAlive(int fd, bool enable, int delay) { int on = enable ? 1 : 0; if (setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &on, sizeof(on))) { PLOG(ERROR) << "Failed to set SO_KEEPALIVE on fd: " << fd; @@ -73,36 +69,6 @@ bool SetTCPKeepAlive(int fd, bool enable, int delay) { return true; } -// Sets socket parameters. Returns the OS error code (or 0 on -// success). -int SetupSocket(int socket) { - if (SetNonBlocking(socket)) - return errno; - - // This mirrors the behaviour on Windows. See the comment in - // tcp_client_socket_win.cc after searching for "NODELAY". - SetTCPNoDelay(socket, true); // If SetTCPNoDelay fails, we don't care. - SetTCPKeepAlive(socket, true, kTCPKeepAliveSeconds); - - return 0; -} - -// Creates a new socket and sets default parameters for it. Returns -// the OS error code (or 0 on success). -int CreateSocket(int family, int* socket) { - *socket = ::socket(family, SOCK_STREAM, IPPROTO_TCP); - if (*socket == kInvalidSocket) - return errno; - int error = SetupSocket(*socket); - if (error) { - if (HANDLE_EINTR(close(*socket)) < 0) - PLOG(ERROR) << "close"; - *socket = kInvalidSocket; - return error; - } - return 0; -} - int MapConnectError(int os_error) { switch (os_error) { case EACCES: @@ -128,275 +94,206 @@ int MapConnectError(int os_error) { //----------------------------------------------------------------------------- -TCPClientSocketLibevent::TCPClientSocketLibevent( - const AddressList& addresses, - net::NetLog* net_log, - const net::NetLog::Source& source) +TCPSocketLibevent::Watcher::Watcher( + const base::Closure& read_ready_callback, + const base::Closure& write_ready_callback) + : read_ready_callback_(read_ready_callback), + write_ready_callback_(write_ready_callback) { +} + +TCPSocketLibevent::Watcher::~Watcher() { +} + +void TCPSocketLibevent::Watcher::OnFileCanReadWithoutBlocking(int /* fd */) { + if (!read_ready_callback_.is_null()) + read_ready_callback_.Run(); + else + NOTREACHED(); +} + +void TCPSocketLibevent::Watcher::OnFileCanWriteWithoutBlocking(int /* fd */) { + if (!write_ready_callback_.is_null()) + write_ready_callback_.Run(); + else + NOTREACHED(); +} + +TCPSocketLibevent::TCPSocketLibevent(NetLog* net_log, + const NetLog::Source& source) : socket_(kInvalidSocket), - bound_socket_(kInvalidSocket), - addresses_(addresses), - current_address_index_(-1), - read_watcher_(this), - write_watcher_(this), - next_connect_state_(CONNECT_STATE_NONE), - connect_os_error_(0), - net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), - previously_disconnected_(false), + accept_watcher_(base::Bind(&TCPSocketLibevent::DidCompleteAccept, + base::Unretained(this)), + base::Closure()), + accept_socket_(NULL), + accept_address_(NULL), + read_watcher_(base::Bind(&TCPSocketLibevent::DidCompleteRead, + base::Unretained(this)), + base::Closure()), + write_watcher_(base::Closure(), + base::Bind(&TCPSocketLibevent::DidCompleteConnectOrWrite, + base::Unretained(this))), + read_buf_len_(0), + write_buf_len_(0), use_tcp_fastopen_(IsTCPFastOpenEnabled()), tcp_fastopen_connected_(false), - fast_open_status_(FAST_OPEN_STATUS_UNKNOWN) { + fast_open_status_(FAST_OPEN_STATUS_UNKNOWN), + waiting_connect_(false), + connect_os_error_(0), + logging_multiple_connect_attempts_(false), + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) { net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE, source.ToEventParametersCallback()); } -TCPClientSocketLibevent::~TCPClientSocketLibevent() { - Disconnect(); +TCPSocketLibevent::~TCPSocketLibevent() { net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE); if (tcp_fastopen_connected_) { UMA_HISTOGRAM_ENUMERATION("Net.TcpFastOpenSocketConnection", fast_open_status_, FAST_OPEN_MAX_VALUE); } + Close(); } -int TCPClientSocketLibevent::AdoptSocket(int socket) { +int TCPSocketLibevent::Open(AddressFamily family) { + DCHECK(CalledOnValidThread()); DCHECK_EQ(socket_, kInvalidSocket); - int error = SetupSocket(socket); - if (error) - return MapSystemError(error); - - socket_ = socket; + socket_ = CreatePlatformSocket(ConvertAddressFamily(family), SOCK_STREAM, + IPPROTO_TCP); + if (socket_ < 0) { + PLOG(ERROR) << "CreatePlatformSocket() returned an error"; + return MapSystemError(errno); + } - // This is to make GetPeerAddress() work. It's up to the caller ensure - // that |address_| contains a reasonable address for this - // socket. (i.e. at least match IPv4 vs IPv6!). - current_address_index_ = 0; - use_history_.set_was_ever_connected(); + if (SetNonBlocking(socket_)) { + int result = MapSystemError(errno); + Close(); + return result; + } return OK; } -int TCPClientSocketLibevent::Bind(const IPEndPoint& address) { - if (current_address_index_ >= 0 || bind_address_.get()) { - // Cannot bind the socket if we are already bound connected or - // connecting. - return ERR_UNEXPECTED; - } - - SockaddrStorage storage; - if (!address.ToSockAddr(storage.addr, &storage.addr_len)) - return ERR_INVALID_ARGUMENT; +int TCPSocketLibevent::AdoptConnectedSocket(int socket, + const IPEndPoint& peer_address) { + DCHECK(CalledOnValidThread()); + DCHECK_EQ(socket_, kInvalidSocket); - // Create |bound_socket_| and try to bind it to |address|. - int error = CreateSocket(address.GetSockAddrFamily(), &bound_socket_); - if (error) - return MapSystemError(error); + socket_ = socket; - if (HANDLE_EINTR(bind(bound_socket_, storage.addr, storage.addr_len))) { - error = errno; - if (HANDLE_EINTR(close(bound_socket_)) < 0) - PLOG(ERROR) << "close"; - bound_socket_ = kInvalidSocket; - return MapSystemError(error); + if (SetNonBlocking(socket_)) { + int result = MapSystemError(errno); + Close(); + return result; } - bind_address_.reset(new IPEndPoint(address)); + peer_address_.reset(new IPEndPoint(peer_address)); - return 0; + return OK; } -int TCPClientSocketLibevent::Connect(const CompletionCallback& callback) { +int TCPSocketLibevent::Bind(const IPEndPoint& address) { DCHECK(CalledOnValidThread()); + DCHECK_NE(socket_, kInvalidSocket); - // If already connected, then just return OK. - if (socket_ != kInvalidSocket) - return OK; - - base::StatsCounter connects("tcp.connect"); - connects.Increment(); - - DCHECK(!waiting_connect()); - - net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT, - addresses_.CreateNetLogCallback()); - - // We will try to connect to each address in addresses_. Start with the - // first one in the list. - next_connect_state_ = CONNECT_STATE_CONNECT; - current_address_index_ = 0; + SockaddrStorage storage; + if (!address.ToSockAddr(storage.addr, &storage.addr_len)) + return ERR_ADDRESS_INVALID; - int rv = DoConnectLoop(OK); - if (rv == ERR_IO_PENDING) { - // Synchronous operation not supported. - DCHECK(!callback.is_null()); - write_callback_ = callback; - } else { - LogConnectCompletion(rv); + int result = bind(socket_, storage.addr, storage.addr_len); + if (result < 0) { + PLOG(ERROR) << "bind() returned an error"; + return MapSystemError(errno); } - return rv; -} - -int TCPClientSocketLibevent::DoConnectLoop(int result) { - DCHECK_NE(next_connect_state_, CONNECT_STATE_NONE); - - int rv = result; - do { - ConnectState state = next_connect_state_; - next_connect_state_ = CONNECT_STATE_NONE; - switch (state) { - case CONNECT_STATE_CONNECT: - DCHECK_EQ(OK, rv); - rv = DoConnect(); - break; - case CONNECT_STATE_CONNECT_COMPLETE: - rv = DoConnectComplete(rv); - break; - default: - LOG(DFATAL) << "bad state"; - rv = ERR_UNEXPECTED; - break; - } - } while (rv != ERR_IO_PENDING && next_connect_state_ != CONNECT_STATE_NONE); - - return rv; + return OK; } -int TCPClientSocketLibevent::DoConnect() { - DCHECK_GE(current_address_index_, 0); - DCHECK_LT(current_address_index_, static_cast<int>(addresses_.size())); - DCHECK_EQ(0, connect_os_error_); - - const IPEndPoint& endpoint = addresses_[current_address_index_]; +int TCPSocketLibevent::Listen(int backlog) { + DCHECK(CalledOnValidThread()); + DCHECK_GT(backlog, 0); + DCHECK_NE(socket_, kInvalidSocket); - if (previously_disconnected_) { - use_history_.Reset(); - previously_disconnected_ = false; + int result = listen(socket_, backlog); + if (result < 0) { + PLOG(ERROR) << "listen() returned an error"; + return MapSystemError(errno); } - net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, - CreateNetLogIPEndPointCallback(&endpoint)); + return OK; +} - next_connect_state_ = CONNECT_STATE_CONNECT_COMPLETE; +int TCPSocketLibevent::Accept(scoped_ptr<TCPSocketLibevent>* socket, + IPEndPoint* address, + const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + DCHECK(socket); + DCHECK(address); + DCHECK(!callback.is_null()); + DCHECK(accept_callback_.is_null()); - if (bound_socket_ != kInvalidSocket) { - DCHECK(bind_address_.get()); - socket_ = bound_socket_; - bound_socket_ = kInvalidSocket; - } else { - // Create a non-blocking socket. - connect_os_error_ = CreateSocket(endpoint.GetSockAddrFamily(), &socket_); - if (connect_os_error_) - return MapSystemError(connect_os_error_); - - if (bind_address_.get()) { - SockaddrStorage storage; - if (!bind_address_->ToSockAddr(storage.addr, &storage.addr_len)) - return ERR_INVALID_ARGUMENT; - if (HANDLE_EINTR(bind(socket_, storage.addr, storage.addr_len))) - return MapSystemError(errno); - } - } + net_log_.BeginEvent(NetLog::TYPE_TCP_ACCEPT); - // Connect the socket. - if (!use_tcp_fastopen_) { - SockaddrStorage storage; - if (!endpoint.ToSockAddr(storage.addr, &storage.addr_len)) - return ERR_INVALID_ARGUMENT; + int result = AcceptInternal(socket, address); - if (!HANDLE_EINTR(connect(socket_, storage.addr, storage.addr_len))) { - // Connected without waiting! - return OK; + if (result == ERR_IO_PENDING) { + if (!base::MessageLoopForIO::current()->WatchFileDescriptor( + socket_, true, base::MessageLoopForIO::WATCH_READ, + &accept_socket_watcher_, &accept_watcher_)) { + PLOG(ERROR) << "WatchFileDescriptor failed on read"; + return MapSystemError(errno); } - } else { - // With TCP FastOpen, we pretend that the socket is connected. - DCHECK(!tcp_fastopen_connected_); - return OK; - } - - // Check if the connect() failed synchronously. - connect_os_error_ = errno; - if (connect_os_error_ != EINPROGRESS) - return MapConnectError(connect_os_error_); - // Otherwise the connect() is going to complete asynchronously, so watch - // for its completion. - if (!base::MessageLoopForIO::current()->WatchFileDescriptor( - socket_, true, base::MessageLoopForIO::WATCH_WRITE, - &write_socket_watcher_, &write_watcher_)) { - connect_os_error_ = errno; - DVLOG(1) << "WatchFileDescriptor failed: " << connect_os_error_; - return MapSystemError(connect_os_error_); + accept_socket_ = socket; + accept_address_ = address; + accept_callback_ = callback; } - return ERR_IO_PENDING; -} - -int TCPClientSocketLibevent::DoConnectComplete(int result) { - // Log the end of this attempt (and any OS error it threw). - int os_error = connect_os_error_; - connect_os_error_ = 0; - if (result != OK) { - net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, - NetLog::IntegerCallback("os_error", os_error)); - } else { - net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT); - } - - if (result == OK) { - write_socket_watcher_.StopWatchingFileDescriptor(); - use_history_.set_was_ever_connected(); - return OK; // Done! - } - - // Close whatever partially connected socket we currently have. - DoDisconnect(); - - // Try to fall back to the next address in the list. - if (current_address_index_ + 1 < static_cast<int>(addresses_.size())) { - next_connect_state_ = CONNECT_STATE_CONNECT; - ++current_address_index_; - return OK; - } - - // Otherwise there is nothing to fall back to, so give up. return result; } -void TCPClientSocketLibevent::Disconnect() { +int TCPSocketLibevent::Connect(const IPEndPoint& address, + const CompletionCallback& callback) { DCHECK(CalledOnValidThread()); + DCHECK_NE(socket_, kInvalidSocket); + DCHECK(!waiting_connect_); - DoDisconnect(); - current_address_index_ = -1; - bind_address_.reset(); -} + // |peer_address_| will be non-NULL if Connect() has been called. Unless + // Close() is called to reset the internal state, a second call to Connect() + // is not allowed. + // Please note that we don't allow a second Connect() even if the previous + // Connect() has failed. Connecting the same |socket_| again after a + // connection attempt failed results in unspecified behavior according to + // POSIX. + DCHECK(!peer_address_); -void TCPClientSocketLibevent::DoDisconnect() { - if (socket_ == kInvalidSocket) - return; + if (!logging_multiple_connect_attempts_) + LogConnectBegin(AddressList(address)); - bool ok = read_socket_watcher_.StopWatchingFileDescriptor(); - DCHECK(ok); - ok = write_socket_watcher_.StopWatchingFileDescriptor(); - DCHECK(ok); - if (HANDLE_EINTR(close(socket_)) < 0) - PLOG(ERROR) << "close"; - socket_ = kInvalidSocket; - previously_disconnected_ = true; + peer_address_.reset(new IPEndPoint(address)); + + int rv = DoConnect(); + if (rv == ERR_IO_PENDING) { + // Synchronous operation not supported. + DCHECK(!callback.is_null()); + write_callback_ = callback; + waiting_connect_ = true; + } else { + DoConnectComplete(rv); + } + + return rv; } -bool TCPClientSocketLibevent::IsConnected() const { +bool TCPSocketLibevent::IsConnected() const { DCHECK(CalledOnValidThread()); - if (socket_ == kInvalidSocket || waiting_connect()) + if (socket_ == kInvalidSocket || waiting_connect_) return false; - if (use_tcp_fastopen_ && !tcp_fastopen_connected_) { + if (use_tcp_fastopen_ && !tcp_fastopen_connected_ && peer_address_) { // With TCP FastOpen, we pretend that the socket is connected. - // This allows GetPeerAddress() to return current_ai_ as the peer - // address. Since we don't fail over to the next address if - // sendto() fails, current_ai_ is the only possible peer address. - CHECK_LT(current_address_index_, static_cast<int>(addresses_.size())); + // This allows GetPeerAddress() to return peer_address_. return true; } @@ -411,10 +308,10 @@ bool TCPClientSocketLibevent::IsConnected() const { return true; } -bool TCPClientSocketLibevent::IsConnectedAndIdle() const { +bool TCPSocketLibevent::IsConnectedAndIdle() const { DCHECK(CalledOnValidThread()); - if (socket_ == kInvalidSocket || waiting_connect()) + if (socket_ == kInvalidSocket || waiting_connect_) return false; // TODO(wtc): should we also handle the TCP FastOpen case here, @@ -432,12 +329,12 @@ bool TCPClientSocketLibevent::IsConnectedAndIdle() const { return true; } -int TCPClientSocketLibevent::Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) { +int TCPSocketLibevent::Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { DCHECK(CalledOnValidThread()); DCHECK_NE(kInvalidSocket, socket_); - DCHECK(!waiting_connect()); + DCHECK(!waiting_connect_); DCHECK(read_callback_.is_null()); // Synchronous operation not supported DCHECK(!callback.is_null()); @@ -447,8 +344,6 @@ int TCPClientSocketLibevent::Read(IOBuffer* buf, if (nread >= 0) { base::StatsCounter read_bytes("tcp.read_bytes"); read_bytes.Add(nread); - if (nread > 0) - use_history_.set_was_used_to_convey_data(); net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, nread, buf->data()); RecordFastOpenStatus(); @@ -474,12 +369,12 @@ int TCPClientSocketLibevent::Read(IOBuffer* buf, return ERR_IO_PENDING; } -int TCPClientSocketLibevent::Write(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) { +int TCPSocketLibevent::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { DCHECK(CalledOnValidThread()); DCHECK_NE(kInvalidSocket, socket_); - DCHECK(!waiting_connect()); + DCHECK(!waiting_connect_); DCHECK(write_callback_.is_null()); // Synchronous operation not supported DCHECK(!callback.is_null()); @@ -489,8 +384,6 @@ int TCPClientSocketLibevent::Write(IOBuffer* buf, if (nwrite >= 0) { base::StatsCounter write_bytes("tcp.write_bytes"); write_bytes.Add(nwrite); - if (nwrite > 0) - use_history_.set_was_used_to_convey_data(); net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, nwrite, buf->data()); return nwrite; @@ -515,56 +408,67 @@ int TCPClientSocketLibevent::Write(IOBuffer* buf, return ERR_IO_PENDING; } -int TCPClientSocketLibevent::InternalWrite(IOBuffer* buf, int buf_len) { - int nwrite; - if (use_tcp_fastopen_ && !tcp_fastopen_connected_) { - SockaddrStorage storage; - if (!addresses_[current_address_index_].ToSockAddr(storage.addr, - &storage.addr_len)) { - errno = EINVAL; - return -1; - } +int TCPSocketLibevent::GetLocalAddress(IPEndPoint* address) const { + DCHECK(CalledOnValidThread()); + DCHECK(address); - int flags = 0x20000000; // Magic flag to enable TCP_FASTOPEN. -#if defined(OS_LINUX) - // sendto() will fail with EPIPE when the system doesn't support TCP Fast - // Open. Theoretically that shouldn't happen since the caller should check - // for system support on startup, but users may dynamically disable TCP Fast - // Open via sysctl. - flags |= MSG_NOSIGNAL; -#endif // defined(OS_LINUX) - nwrite = HANDLE_EINTR(sendto(socket_, - buf->data(), - buf_len, - flags, - storage.addr, - storage.addr_len)); - tcp_fastopen_connected_ = true; + SockaddrStorage storage; + if (getsockname(socket_, storage.addr, &storage.addr_len) < 0) + return MapSystemError(errno); + if (!address->FromSockAddr(storage.addr, storage.addr_len)) + return ERR_ADDRESS_INVALID; - if (nwrite < 0) { - DCHECK_NE(EPIPE, errno); + return OK; +} - // If errno == EINPROGRESS, that means the kernel didn't have a cookie - // and would block. The kernel is internally doing a connect() though. - // Remap EINPROGRESS to EAGAIN so we treat this the same as our other - // asynchronous cases. Note that the user buffer has not been copied to - // kernel space. - if (errno == EINPROGRESS) { - errno = EAGAIN; - fast_open_status_ = FAST_OPEN_SLOW_CONNECT_RETURN; - } else { - fast_open_status_ = FAST_OPEN_ERROR; - } - } else { - fast_open_status_ = FAST_OPEN_FAST_CONNECT_RETURN; - } - } else { - nwrite = HANDLE_EINTR(write(socket_, buf->data(), buf_len)); - } - return nwrite; +int TCPSocketLibevent::GetPeerAddress(IPEndPoint* address) const { + DCHECK(CalledOnValidThread()); + DCHECK(address); + if (!IsConnected()) + return ERR_SOCKET_NOT_CONNECTED; + *address = *peer_address_; + return OK; } -bool TCPClientSocketLibevent::SetReceiveBufferSize(int32 size) { +int TCPSocketLibevent::SetDefaultOptionsForServer() { + DCHECK(CalledOnValidThread()); + return SetAddressReuse(true); +} + +void TCPSocketLibevent::SetDefaultOptionsForClient() { + DCHECK(CalledOnValidThread()); + + // This mirrors the behaviour on Windows. See the comment in + // tcp_socket_win.cc after searching for "NODELAY". + SetTCPNoDelay(socket_, true); // If SetTCPNoDelay fails, we don't care. + SetTCPKeepAlive(socket_, true, kTCPKeepAliveSeconds); +} + +int TCPSocketLibevent::SetAddressReuse(bool allow) { + DCHECK(CalledOnValidThread()); + + // SO_REUSEADDR is useful for server sockets to bind to a recently unbound + // port. When a socket is closed, the end point changes its state to TIME_WAIT + // and wait for 2 MSL (maximum segment lifetime) to ensure the remote peer + // acknowledges its closure. For server sockets, it is usually safe to + // bind to a TIME_WAIT end point immediately, which is a widely adopted + // behavior. + // + // Note that on *nix, SO_REUSEADDR does not enable the TCP socket to bind to + // an end point that is already bound by another socket. To do that one must + // set SO_REUSEPORT instead. This option is not provided on Linux prior + // to 3.9. + // + // SO_REUSEPORT is provided in MacOS X and iOS. + int boolean_value = allow ? 1 : 0; + int rv = setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &boolean_value, + sizeof(boolean_value)); + if (rv < 0) + return MapSystemError(errno); + return OK; +} + +bool TCPSocketLibevent::SetReceiveBufferSize(int32 size) { DCHECK(CalledOnValidThread()); int rv = setsockopt(socket_, SOL_SOCKET, SO_RCVBUF, reinterpret_cast<const char*>(&size), @@ -573,7 +477,7 @@ bool TCPClientSocketLibevent::SetReceiveBufferSize(int32 size) { return rv == 0; } -bool TCPClientSocketLibevent::SetSendBufferSize(int32 size) { +bool TCPSocketLibevent::SetSendBufferSize(int32 size) { DCHECK(CalledOnValidThread()); int rv = setsockopt(socket_, SOL_SOCKET, SO_SNDBUF, reinterpret_cast<const char*>(&size), @@ -582,31 +486,180 @@ bool TCPClientSocketLibevent::SetSendBufferSize(int32 size) { return rv == 0; } -bool TCPClientSocketLibevent::SetKeepAlive(bool enable, int delay) { - int socket = socket_ != kInvalidSocket ? socket_ : bound_socket_; - return SetTCPKeepAlive(socket, enable, delay); +bool TCPSocketLibevent::SetKeepAlive(bool enable, int delay) { + DCHECK(CalledOnValidThread()); + return SetTCPKeepAlive(socket_, enable, delay); } -bool TCPClientSocketLibevent::SetNoDelay(bool no_delay) { - int socket = socket_ != kInvalidSocket ? socket_ : bound_socket_; - return SetTCPNoDelay(socket, no_delay); +bool TCPSocketLibevent::SetNoDelay(bool no_delay) { + DCHECK(CalledOnValidThread()); + return SetTCPNoDelay(socket_, no_delay); } -void TCPClientSocketLibevent::ReadWatcher::OnFileCanReadWithoutBlocking(int) { - socket_->RecordFastOpenStatus(); - if (!socket_->read_callback_.is_null()) - socket_->DidCompleteRead(); +void TCPSocketLibevent::Close() { + DCHECK(CalledOnValidThread()); + + bool ok = accept_socket_watcher_.StopWatchingFileDescriptor(); + DCHECK(ok); + ok = read_socket_watcher_.StopWatchingFileDescriptor(); + DCHECK(ok); + ok = write_socket_watcher_.StopWatchingFileDescriptor(); + DCHECK(ok); + + if (socket_ != kInvalidSocket) { + if (HANDLE_EINTR(close(socket_)) < 0) + PLOG(ERROR) << "close"; + socket_ = kInvalidSocket; + } + + if (!accept_callback_.is_null()) { + accept_socket_ = NULL; + accept_address_ = NULL; + accept_callback_.Reset(); + } + + if (!read_callback_.is_null()) { + read_buf_ = NULL; + read_buf_len_ = 0; + read_callback_.Reset(); + } + + if (!write_callback_.is_null()) { + write_buf_ = NULL; + write_buf_len_ = 0; + write_callback_.Reset(); + } + + tcp_fastopen_connected_ = false; + fast_open_status_ = FAST_OPEN_STATUS_UNKNOWN; + waiting_connect_ = false; + peer_address_.reset(); + connect_os_error_ = 0; +} + +bool TCPSocketLibevent::UsingTCPFastOpen() const { + return use_tcp_fastopen_; +} + +void TCPSocketLibevent::StartLoggingMultipleConnectAttempts( + const AddressList& addresses) { + if (!logging_multiple_connect_attempts_) { + logging_multiple_connect_attempts_ = true; + LogConnectBegin(addresses); + } else { + NOTREACHED(); + } +} + +void TCPSocketLibevent::EndLoggingMultipleConnectAttempts(int net_error) { + if (logging_multiple_connect_attempts_) { + LogConnectEnd(net_error); + logging_multiple_connect_attempts_ = false; + } else { + NOTREACHED(); + } +} + +int TCPSocketLibevent::AcceptInternal(scoped_ptr<TCPSocketLibevent>* socket, + IPEndPoint* address) { + SockaddrStorage storage; + int new_socket = HANDLE_EINTR(accept(socket_, + storage.addr, + &storage.addr_len)); + if (new_socket < 0) { + int net_error = MapSystemError(errno); + if (net_error != ERR_IO_PENDING) + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, net_error); + return net_error; + } + + IPEndPoint ip_end_point; + if (!ip_end_point.FromSockAddr(storage.addr, storage.addr_len)) { + NOTREACHED(); + if (HANDLE_EINTR(close(new_socket)) < 0) + PLOG(ERROR) << "close"; + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, + ERR_ADDRESS_INVALID); + return ERR_ADDRESS_INVALID; + } + scoped_ptr<TCPSocketLibevent> tcp_socket(new TCPSocketLibevent( + net_log_.net_log(), net_log_.source())); + int adopt_result = tcp_socket->AdoptConnectedSocket(new_socket, ip_end_point); + if (adopt_result != OK) { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, adopt_result); + return adopt_result; + } + *socket = tcp_socket.Pass(); + *address = ip_end_point; + net_log_.EndEvent(NetLog::TYPE_TCP_ACCEPT, + CreateNetLogIPEndPointCallback(&ip_end_point)); + return OK; +} + +int TCPSocketLibevent::DoConnect() { + DCHECK_EQ(0, connect_os_error_); + + net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, + CreateNetLogIPEndPointCallback(peer_address_.get())); + + // Connect the socket. + if (!use_tcp_fastopen_) { + SockaddrStorage storage; + if (!peer_address_->ToSockAddr(storage.addr, &storage.addr_len)) + return ERR_INVALID_ARGUMENT; + + if (!HANDLE_EINTR(connect(socket_, storage.addr, storage.addr_len))) { + // Connected without waiting! + return OK; + } + } else { + // With TCP FastOpen, we pretend that the socket is connected. + DCHECK(!tcp_fastopen_connected_); + return OK; + } + + // Check if the connect() failed synchronously. + connect_os_error_ = errno; + if (connect_os_error_ != EINPROGRESS) + return MapConnectError(connect_os_error_); + + // Otherwise the connect() is going to complete asynchronously, so watch + // for its completion. + if (!base::MessageLoopForIO::current()->WatchFileDescriptor( + socket_, true, base::MessageLoopForIO::WATCH_WRITE, + &write_socket_watcher_, &write_watcher_)) { + connect_os_error_ = errno; + DVLOG(1) << "WatchFileDescriptor failed: " << connect_os_error_; + return MapSystemError(connect_os_error_); + } + + return ERR_IO_PENDING; } -void TCPClientSocketLibevent::WriteWatcher::OnFileCanWriteWithoutBlocking(int) { - if (socket_->waiting_connect()) { - socket_->DidCompleteConnect(); - } else if (!socket_->write_callback_.is_null()) { - socket_->DidCompleteWrite(); +void TCPSocketLibevent::DoConnectComplete(int result) { + // Log the end of this attempt (and any OS error it threw). + int os_error = connect_os_error_; + connect_os_error_ = 0; + if (result != OK) { + net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, + NetLog::IntegerCallback("os_error", os_error)); + } else { + net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT); } + + if (!logging_multiple_connect_attempts_) + LogConnectEnd(result); } -void TCPClientSocketLibevent::LogConnectCompletion(int net_error) { +void TCPSocketLibevent::LogConnectBegin(const AddressList& addresses) { + base::StatsCounter connects("tcp.connect"); + connects.Increment(); + + net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT, + addresses.CreateNetLogCallback()); +} + +void TCPSocketLibevent::LogConnectEnd(int net_error) { if (net_error == OK) UpdateConnectionTypeHistograms(CONNECTION_ANY); @@ -629,50 +682,11 @@ void TCPClientSocketLibevent::LogConnectCompletion(int net_error) { storage.addr_len)); } -void TCPClientSocketLibevent::DoReadCallback(int rv) { - DCHECK_NE(rv, ERR_IO_PENDING); - DCHECK(!read_callback_.is_null()); - - // since Run may result in Read being called, clear read_callback_ up front. - CompletionCallback c = read_callback_; - read_callback_.Reset(); - c.Run(rv); -} - -void TCPClientSocketLibevent::DoWriteCallback(int rv) { - DCHECK_NE(rv, ERR_IO_PENDING); - DCHECK(!write_callback_.is_null()); - - // since Run may result in Write being called, clear write_callback_ up front. - CompletionCallback c = write_callback_; - write_callback_.Reset(); - c.Run(rv); -} - -void TCPClientSocketLibevent::DidCompleteConnect() { - DCHECK_EQ(next_connect_state_, CONNECT_STATE_CONNECT_COMPLETE); - - // Get the error that connect() completed with. - int os_error = 0; - socklen_t len = sizeof(os_error); - if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &os_error, &len) < 0) - os_error = errno; - - // TODO(eroman): Is this check really necessary? - if (os_error == EINPROGRESS || os_error == EALREADY) { - NOTREACHED(); // This indicates a bug in libevent or our code. +void TCPSocketLibevent::DidCompleteRead() { + RecordFastOpenStatus(); + if (read_callback_.is_null()) return; - } - connect_os_error_ = os_error; - int rv = DoConnectLoop(MapConnectError(os_error)); - if (rv != ERR_IO_PENDING) { - LogConnectCompletion(rv); - DoWriteCallback(rv); - } -} - -void TCPClientSocketLibevent::DidCompleteRead() { int bytes_transferred; bytes_transferred = HANDLE_EINTR(read(socket_, read_buf_->data(), read_buf_len_)); @@ -682,8 +696,6 @@ void TCPClientSocketLibevent::DidCompleteRead() { result = bytes_transferred; base::StatsCounter read_bytes("tcp.read_bytes"); read_bytes.Add(bytes_transferred); - if (bytes_transferred > 0) - use_history_.set_was_used_to_convey_data(); net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, result, read_buf_->data()); } else { @@ -699,11 +711,14 @@ void TCPClientSocketLibevent::DidCompleteRead() { read_buf_len_ = 0; bool ok = read_socket_watcher_.StopWatchingFileDescriptor(); DCHECK(ok); - DoReadCallback(result); + base::ResetAndReturn(&read_callback_).Run(result); } } -void TCPClientSocketLibevent::DidCompleteWrite() { +void TCPSocketLibevent::DidCompleteWrite() { + if (write_callback_.is_null()) + return; + int bytes_transferred; bytes_transferred = HANDLE_EINTR(write(socket_, write_buf_->data(), write_buf_len_)); @@ -713,8 +728,6 @@ void TCPClientSocketLibevent::DidCompleteWrite() { result = bytes_transferred; base::StatsCounter write_bytes("tcp.write_bytes"); write_bytes.Add(bytes_transferred); - if (bytes_transferred > 0) - use_history_.set_was_used_to_convey_data(); net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, result, write_buf_->data()); } else { @@ -729,40 +742,100 @@ void TCPClientSocketLibevent::DidCompleteWrite() { write_buf_ = NULL; write_buf_len_ = 0; write_socket_watcher_.StopWatchingFileDescriptor(); - DoWriteCallback(result); + base::ResetAndReturn(&write_callback_).Run(result); } } -int TCPClientSocketLibevent::GetPeerAddress(IPEndPoint* address) const { - DCHECK(CalledOnValidThread()); - DCHECK(address); - if (!IsConnected()) - return ERR_SOCKET_NOT_CONNECTED; - *address = addresses_[current_address_index_]; - return OK; +void TCPSocketLibevent::DidCompleteConnect() { + DCHECK(waiting_connect_); + + // Get the error that connect() completed with. + int os_error = 0; + socklen_t len = sizeof(os_error); + if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &os_error, &len) < 0) + os_error = errno; + + int result = MapConnectError(os_error); + connect_os_error_ = os_error; + if (result != ERR_IO_PENDING) { + DoConnectComplete(result); + waiting_connect_ = false; + write_socket_watcher_.StopWatchingFileDescriptor(); + base::ResetAndReturn(&write_callback_).Run(result); + } +} + +void TCPSocketLibevent::DidCompleteConnectOrWrite() { + if (waiting_connect_) + DidCompleteConnect(); + else + DidCompleteWrite(); } -int TCPClientSocketLibevent::GetLocalAddress(IPEndPoint* address) const { +void TCPSocketLibevent::DidCompleteAccept() { DCHECK(CalledOnValidThread()); - DCHECK(address); - if (socket_ == kInvalidSocket) { - if (bind_address_.get()) { - *address = *bind_address_; - return OK; - } - return ERR_SOCKET_NOT_CONNECTED; + + int result = AcceptInternal(accept_socket_, accept_address_); + if (result != ERR_IO_PENDING) { + accept_socket_ = NULL; + accept_address_ = NULL; + bool ok = accept_socket_watcher_.StopWatchingFileDescriptor(); + DCHECK(ok); + CompletionCallback callback = accept_callback_; + accept_callback_.Reset(); + callback.Run(result); } +} - SockaddrStorage storage; - if (getsockname(socket_, storage.addr, &storage.addr_len)) - return MapSystemError(errno); - if (!address->FromSockAddr(storage.addr, storage.addr_len)) - return ERR_FAILED; +int TCPSocketLibevent::InternalWrite(IOBuffer* buf, int buf_len) { + int nwrite; + if (use_tcp_fastopen_ && !tcp_fastopen_connected_) { + SockaddrStorage storage; + if (!peer_address_->ToSockAddr(storage.addr, &storage.addr_len)) { + errno = EINVAL; + return -1; + } - return OK; + int flags = 0x20000000; // Magic flag to enable TCP_FASTOPEN. +#if defined(OS_LINUX) + // sendto() will fail with EPIPE when the system doesn't support TCP Fast + // Open. Theoretically that shouldn't happen since the caller should check + // for system support on startup, but users may dynamically disable TCP Fast + // Open via sysctl. + flags |= MSG_NOSIGNAL; +#endif // defined(OS_LINUX) + nwrite = HANDLE_EINTR(sendto(socket_, + buf->data(), + buf_len, + flags, + storage.addr, + storage.addr_len)); + tcp_fastopen_connected_ = true; + + if (nwrite < 0) { + DCHECK_NE(EPIPE, errno); + + // If errno == EINPROGRESS, that means the kernel didn't have a cookie + // and would block. The kernel is internally doing a connect() though. + // Remap EINPROGRESS to EAGAIN so we treat this the same as our other + // asynchronous cases. Note that the user buffer has not been copied to + // kernel space. + if (errno == EINPROGRESS) { + errno = EAGAIN; + fast_open_status_ = FAST_OPEN_SLOW_CONNECT_RETURN; + } else { + fast_open_status_ = FAST_OPEN_ERROR; + } + } else { + fast_open_status_ = FAST_OPEN_FAST_CONNECT_RETURN; + } + } else { + nwrite = HANDLE_EINTR(write(socket_, buf->data(), buf_len)); + } + return nwrite; } -void TCPClientSocketLibevent::RecordFastOpenStatus() { +void TCPSocketLibevent::RecordFastOpenStatus() { if (use_tcp_fastopen_ && (fast_open_status_ == FAST_OPEN_FAST_CONNECT_RETURN || fast_open_status_ == FAST_OPEN_SLOW_CONNECT_RETURN)) { @@ -795,36 +868,4 @@ void TCPClientSocketLibevent::RecordFastOpenStatus() { } } -const BoundNetLog& TCPClientSocketLibevent::NetLog() const { - return net_log_; -} - -void TCPClientSocketLibevent::SetSubresourceSpeculation() { - use_history_.set_subresource_speculation(); -} - -void TCPClientSocketLibevent::SetOmniboxSpeculation() { - use_history_.set_omnibox_speculation(); -} - -bool TCPClientSocketLibevent::WasEverUsed() const { - return use_history_.was_used_to_convey_data(); -} - -bool TCPClientSocketLibevent::UsingTCPFastOpen() const { - return use_tcp_fastopen_; -} - -bool TCPClientSocketLibevent::WasNpnNegotiated() const { - return false; -} - -NextProto TCPClientSocketLibevent::GetNegotiatedProtocol() const { - return kProtoUnknown; -} - -bool TCPClientSocketLibevent::GetSSLInfo(SSLInfo* ssl_info) { - return false; -} - } // namespace net diff --git a/chromium/net/socket/tcp_socket_libevent.h b/chromium/net/socket/tcp_socket_libevent.h new file mode 100644 index 00000000000..a50caf0ad59 --- /dev/null +++ b/chromium/net/socket/tcp_socket_libevent.h @@ -0,0 +1,235 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_TCP_SOCKET_LIBEVENT_H_ +#define NET_SOCKET_TCP_SOCKET_LIBEVENT_H_ + +#include "base/basictypes.h" +#include "base/callback.h" +#include "base/compiler_specific.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/message_loop/message_loop.h" +#include "base/threading/non_thread_safe.h" +#include "net/base/address_family.h" +#include "net/base/completion_callback.h" +#include "net/base/net_export.h" +#include "net/base/net_log.h" +#include "net/socket/socket_descriptor.h" + +namespace net { + +class AddressList; +class IOBuffer; +class IPEndPoint; + +class NET_EXPORT TCPSocketLibevent : public base::NonThreadSafe { + public: + TCPSocketLibevent(NetLog* net_log, const NetLog::Source& source); + virtual ~TCPSocketLibevent(); + + int Open(AddressFamily family); + // Takes ownership of |socket|. + int AdoptConnectedSocket(int socket, const IPEndPoint& peer_address); + + int Bind(const IPEndPoint& address); + + int Listen(int backlog); + int Accept(scoped_ptr<TCPSocketLibevent>* socket, + IPEndPoint* address, + const CompletionCallback& callback); + + int Connect(const IPEndPoint& address, const CompletionCallback& callback); + bool IsConnected() const; + bool IsConnectedAndIdle() const; + + // Multiple outstanding requests are not supported. + // Full duplex mode (reading and writing at the same time) is supported. + int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback); + int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback); + + int GetLocalAddress(IPEndPoint* address) const; + int GetPeerAddress(IPEndPoint* address) const; + + // Sets various socket options. + // The commonly used options for server listening sockets: + // - SetAddressReuse(true). + int SetDefaultOptionsForServer(); + // The commonly used options for client sockets and accepted sockets: + // - SetNoDelay(true); + // - SetKeepAlive(true, 45). + void SetDefaultOptionsForClient(); + int SetAddressReuse(bool allow); + bool SetReceiveBufferSize(int32 size); + bool SetSendBufferSize(int32 size); + bool SetKeepAlive(bool enable, int delay); + bool SetNoDelay(bool no_delay); + + void Close(); + + bool UsingTCPFastOpen() const; + bool IsValid() const { return socket_ != kInvalidSocket; } + + // Marks the start/end of a series of connect attempts for logging purpose. + // + // TCPClientSocket may attempt to connect to multiple addresses until it + // succeeds in establishing a connection. The corresponding log will have + // multiple NetLog::TYPE_TCP_CONNECT_ATTEMPT entries nested within a + // NetLog::TYPE_TCP_CONNECT. These methods set the start/end of + // NetLog::TYPE_TCP_CONNECT. + // + // TODO(yzshen): Change logging format and let TCPClientSocket log the + // start/end of a series of connect attempts itself. + void StartLoggingMultipleConnectAttempts(const AddressList& addresses); + void EndLoggingMultipleConnectAttempts(int net_error); + + const BoundNetLog& net_log() const { return net_log_; } + + private: + // States that a fast open socket attempt can result in. + enum FastOpenStatus { + FAST_OPEN_STATUS_UNKNOWN, + + // The initial fast open connect attempted returned synchronously, + // indicating that we had and sent a cookie along with the initial data. + FAST_OPEN_FAST_CONNECT_RETURN, + + // The initial fast open connect attempted returned asynchronously, + // indicating that we did not have a cookie for the server. + FAST_OPEN_SLOW_CONNECT_RETURN, + + // Some other error occurred on connection, so we couldn't tell if + // fast open would have worked. + FAST_OPEN_ERROR, + + // An attempt to do a fast open succeeded immediately + // (FAST_OPEN_FAST_CONNECT_RETURN) and we later confirmed that the server + // had acked the data we sent. + FAST_OPEN_SYN_DATA_ACK, + + // An attempt to do a fast open succeeded immediately + // (FAST_OPEN_FAST_CONNECT_RETURN) and we later confirmed that the server + // had nacked the data we sent. + FAST_OPEN_SYN_DATA_NACK, + + // An attempt to do a fast open succeeded immediately + // (FAST_OPEN_FAST_CONNECT_RETURN) and our probe to determine if the + // socket was using fast open failed. + FAST_OPEN_SYN_DATA_FAILED, + + // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN) + // and we later confirmed that the server had acked initial data. This + // should never happen (we didn't send data, so it shouldn't have + // been acked). + FAST_OPEN_NO_SYN_DATA_ACK, + + // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN) + // and we later discovered that the server had nacked initial data. This + // is the expected case results for FAST_OPEN_SLOW_CONNECT_RETURN. + FAST_OPEN_NO_SYN_DATA_NACK, + + // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN) + // and our later probe for ack/nack state failed. + FAST_OPEN_NO_SYN_DATA_FAILED, + + FAST_OPEN_MAX_VALUE + }; + + // Watcher simply forwards notifications to Closure objects set via the + // constructor. + class Watcher: public base::MessageLoopForIO::Watcher { + public: + Watcher(const base::Closure& read_ready_callback, + const base::Closure& write_ready_callback); + virtual ~Watcher(); + + // base::MessageLoopForIO::Watcher methods. + virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE; + virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE; + + private: + base::Closure read_ready_callback_; + base::Closure write_ready_callback_; + + DISALLOW_COPY_AND_ASSIGN(Watcher); + }; + + int AcceptInternal(scoped_ptr<TCPSocketLibevent>* socket, + IPEndPoint* address); + + int DoConnect(); + void DoConnectComplete(int result); + + void LogConnectBegin(const AddressList& addresses); + void LogConnectEnd(int net_error); + + void DidCompleteRead(); + void DidCompleteWrite(); + void DidCompleteConnect(); + void DidCompleteConnectOrWrite(); + void DidCompleteAccept(); + + // Internal function to write to a socket. Returns an OS error. + int InternalWrite(IOBuffer* buf, int buf_len); + + // Called when the socket is known to be in a connected state. + void RecordFastOpenStatus(); + + int socket_; + + base::MessageLoopForIO::FileDescriptorWatcher accept_socket_watcher_; + Watcher accept_watcher_; + + scoped_ptr<TCPSocketLibevent>* accept_socket_; + IPEndPoint* accept_address_; + CompletionCallback accept_callback_; + + // The socket's libevent wrappers for reads and writes. + base::MessageLoopForIO::FileDescriptorWatcher read_socket_watcher_; + base::MessageLoopForIO::FileDescriptorWatcher write_socket_watcher_; + + // The corresponding watchers for reads and writes. + Watcher read_watcher_; + Watcher write_watcher_; + + // The buffer used for reads. + scoped_refptr<IOBuffer> read_buf_; + int read_buf_len_; + + // The buffer used for writes. + scoped_refptr<IOBuffer> write_buf_; + int write_buf_len_; + + // External callback; called when read is complete. + CompletionCallback read_callback_; + + // External callback; called when write or connect is complete. + CompletionCallback write_callback_; + + // Enables experimental TCP FastOpen option. + const bool use_tcp_fastopen_; + + // True when TCP FastOpen is in use and we have done the connect. + bool tcp_fastopen_connected_; + + FastOpenStatus fast_open_status_; + + // A connect operation is pending. In this case, |write_callback_| needs to be + // called when connect is complete. + bool waiting_connect_; + + scoped_ptr<IPEndPoint> peer_address_; + // The OS error that a connect attempt last completed with. + int connect_os_error_; + + bool logging_multiple_connect_attempts_; + + BoundNetLog net_log_; + + DISALLOW_COPY_AND_ASSIGN(TCPSocketLibevent); +}; + +} // namespace net + +#endif // NET_SOCKET_TCP_SOCKET_LIBEVENT_H_ diff --git a/chromium/net/socket/tcp_socket_unittest.cc b/chromium/net/socket/tcp_socket_unittest.cc new file mode 100644 index 00000000000..a45fcba016b --- /dev/null +++ b/chromium/net/socket/tcp_socket_unittest.cc @@ -0,0 +1,263 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/tcp_socket.h" + +#include <string.h> + +#include <string> +#include <vector> + +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/address_list.h" +#include "net/base/io_buffer.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/base/test_completion_callback.h" +#include "net/socket/tcp_client_socket.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "testing/platform_test.h" + +namespace net { + +namespace { +const int kListenBacklog = 5; + +class TCPSocketTest : public PlatformTest { + protected: + TCPSocketTest() : socket_(NULL, NetLog::Source()) { + } + + void SetUpListenIPv4() { + IPEndPoint address; + ParseAddress("127.0.0.1", 0, &address); + + ASSERT_EQ(OK, socket_.Open(ADDRESS_FAMILY_IPV4)); + ASSERT_EQ(OK, socket_.Bind(address)); + ASSERT_EQ(OK, socket_.Listen(kListenBacklog)); + ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_)); + } + + void SetUpListenIPv6(bool* success) { + *success = false; + IPEndPoint address; + ParseAddress("::1", 0, &address); + + if (socket_.Open(ADDRESS_FAMILY_IPV6) != OK || + socket_.Bind(address) != OK || + socket_.Listen(kListenBacklog) != OK) { + LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is " + "disabled. Skipping the test"; + return; + } + ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_)); + *success = true; + } + + void ParseAddress(const std::string& ip_str, int port, IPEndPoint* address) { + IPAddressNumber ip_number; + bool rv = ParseIPLiteralToNumber(ip_str, &ip_number); + if (!rv) + return; + *address = IPEndPoint(ip_number, port); + } + + AddressList local_address_list() const { + return AddressList(local_address_); + } + + TCPSocket socket_; + IPEndPoint local_address_; +}; + +// Test listening and accepting with a socket bound to an IPv4 address. +TEST_F(TCPSocketTest, Accept) { + ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4()); + + TestCompletionCallback connect_callback; + // TODO(yzshen): Switch to use TCPSocket when it supports client socket + // operations. + TCPClientSocket connecting_socket(local_address_list(), + NULL, NetLog::Source()); + connecting_socket.Connect(connect_callback.callback()); + + TestCompletionCallback accept_callback; + scoped_ptr<TCPSocket> accepted_socket; + IPEndPoint accepted_address; + int result = socket_.Accept(&accepted_socket, &accepted_address, + accept_callback.callback()); + if (result == ERR_IO_PENDING) + result = accept_callback.WaitForResult(); + ASSERT_EQ(OK, result); + + EXPECT_TRUE(accepted_socket.get()); + + // Both sockets should be on the loopback network interface. + EXPECT_EQ(accepted_address.address(), local_address_.address()); + + EXPECT_EQ(OK, connect_callback.WaitForResult()); +} + +// Test Accept() callback. +TEST_F(TCPSocketTest, AcceptAsync) { + ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4()); + + TestCompletionCallback accept_callback; + scoped_ptr<TCPSocket> accepted_socket; + IPEndPoint accepted_address; + ASSERT_EQ(ERR_IO_PENDING, + socket_.Accept(&accepted_socket, &accepted_address, + accept_callback.callback())); + + TestCompletionCallback connect_callback; + TCPClientSocket connecting_socket(local_address_list(), + NULL, NetLog::Source()); + connecting_socket.Connect(connect_callback.callback()); + + EXPECT_EQ(OK, connect_callback.WaitForResult()); + EXPECT_EQ(OK, accept_callback.WaitForResult()); + + EXPECT_TRUE(accepted_socket.get()); + + // Both sockets should be on the loopback network interface. + EXPECT_EQ(accepted_address.address(), local_address_.address()); +} + +// Accept two connections simultaneously. +TEST_F(TCPSocketTest, Accept2Connections) { + ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4()); + + TestCompletionCallback accept_callback; + scoped_ptr<TCPSocket> accepted_socket; + IPEndPoint accepted_address; + + ASSERT_EQ(ERR_IO_PENDING, + socket_.Accept(&accepted_socket, &accepted_address, + accept_callback.callback())); + + TestCompletionCallback connect_callback; + TCPClientSocket connecting_socket(local_address_list(), + NULL, NetLog::Source()); + connecting_socket.Connect(connect_callback.callback()); + + TestCompletionCallback connect_callback2; + TCPClientSocket connecting_socket2(local_address_list(), + NULL, NetLog::Source()); + connecting_socket2.Connect(connect_callback2.callback()); + + EXPECT_EQ(OK, accept_callback.WaitForResult()); + + TestCompletionCallback accept_callback2; + scoped_ptr<TCPSocket> accepted_socket2; + IPEndPoint accepted_address2; + + int result = socket_.Accept(&accepted_socket2, &accepted_address2, + accept_callback2.callback()); + if (result == ERR_IO_PENDING) + result = accept_callback2.WaitForResult(); + ASSERT_EQ(OK, result); + + EXPECT_EQ(OK, connect_callback.WaitForResult()); + EXPECT_EQ(OK, connect_callback2.WaitForResult()); + + EXPECT_TRUE(accepted_socket.get()); + EXPECT_TRUE(accepted_socket2.get()); + EXPECT_NE(accepted_socket.get(), accepted_socket2.get()); + + EXPECT_EQ(accepted_address.address(), local_address_.address()); + EXPECT_EQ(accepted_address2.address(), local_address_.address()); +} + +// Test listening and accepting with a socket bound to an IPv6 address. +TEST_F(TCPSocketTest, AcceptIPv6) { + bool initialized = false; + ASSERT_NO_FATAL_FAILURE(SetUpListenIPv6(&initialized)); + if (!initialized) + return; + + TestCompletionCallback connect_callback; + TCPClientSocket connecting_socket(local_address_list(), + NULL, NetLog::Source()); + connecting_socket.Connect(connect_callback.callback()); + + TestCompletionCallback accept_callback; + scoped_ptr<TCPSocket> accepted_socket; + IPEndPoint accepted_address; + int result = socket_.Accept(&accepted_socket, &accepted_address, + accept_callback.callback()); + if (result == ERR_IO_PENDING) + result = accept_callback.WaitForResult(); + ASSERT_EQ(OK, result); + + EXPECT_TRUE(accepted_socket.get()); + + // Both sockets should be on the loopback network interface. + EXPECT_EQ(accepted_address.address(), local_address_.address()); + + EXPECT_EQ(OK, connect_callback.WaitForResult()); +} + +TEST_F(TCPSocketTest, ReadWrite) { + ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4()); + + TestCompletionCallback connect_callback; + TCPSocket connecting_socket(NULL, NetLog::Source()); + int result = connecting_socket.Open(ADDRESS_FAMILY_IPV4); + ASSERT_EQ(OK, result); + connecting_socket.Connect(local_address_, connect_callback.callback()); + + TestCompletionCallback accept_callback; + scoped_ptr<TCPSocket> accepted_socket; + IPEndPoint accepted_address; + result = socket_.Accept(&accepted_socket, &accepted_address, + accept_callback.callback()); + ASSERT_EQ(OK, accept_callback.GetResult(result)); + + ASSERT_TRUE(accepted_socket.get()); + + // Both sockets should be on the loopback network interface. + EXPECT_EQ(accepted_address.address(), local_address_.address()); + + EXPECT_EQ(OK, connect_callback.WaitForResult()); + + const std::string message("test message"); + std::vector<char> buffer(message.size()); + + size_t bytes_written = 0; + while (bytes_written < message.size()) { + scoped_refptr<IOBufferWithSize> write_buffer( + new IOBufferWithSize(message.size() - bytes_written)); + memmove(write_buffer->data(), message.data() + bytes_written, + message.size() - bytes_written); + + TestCompletionCallback write_callback; + int write_result = accepted_socket->Write( + write_buffer.get(), write_buffer->size(), write_callback.callback()); + write_result = write_callback.GetResult(write_result); + ASSERT_TRUE(write_result >= 0); + bytes_written += write_result; + ASSERT_TRUE(bytes_written <= message.size()); + } + + size_t bytes_read = 0; + while (bytes_read < message.size()) { + scoped_refptr<IOBufferWithSize> read_buffer( + new IOBufferWithSize(message.size() - bytes_read)); + TestCompletionCallback read_callback; + int read_result = connecting_socket.Read( + read_buffer.get(), read_buffer->size(), read_callback.callback()); + read_result = read_callback.GetResult(read_result); + ASSERT_TRUE(read_result >= 0); + ASSERT_TRUE(bytes_read + read_result <= message.size()); + memmove(&buffer[bytes_read], read_buffer->data(), read_result); + bytes_read += read_result; + } + + std::string received_message(buffer.begin(), buffer.end()); + ASSERT_EQ(message, received_message); +} + +} // namespace +} // namespace net diff --git a/chromium/net/socket/tcp_client_socket_win.cc b/chromium/net/socket/tcp_socket_win.cc index 9b0a5b50bf1..7d76232f962 100644 --- a/chromium/net/socket/tcp_client_socket_win.cc +++ b/chromium/net/socket/tcp_socket_win.cc @@ -1,26 +1,25 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "net/socket/tcp_client_socket_win.h" +#include "net/socket/tcp_socket_win.h" #include <mstcpip.h> -#include "base/basictypes.h" -#include "base/compiler_specific.h" +#include "base/callback_helpers.h" +#include "base/logging.h" #include "base/metrics/stats_counters.h" -#include "base/strings/string_util.h" -#include "base/win/object_watcher.h" #include "base/win/windows_version.h" +#include "net/base/address_list.h" #include "net/base/connection_type_histograms.h" #include "net/base/io_buffer.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" -#include "net/base/net_log.h" #include "net/base/net_util.h" #include "net/base/network_change_notifier.h" #include "net/base/winsock_init.h" #include "net/base/winsock_util.h" +#include "net/socket/socket_descriptor.h" #include "net/socket/socket_net_log_params.h" namespace net { @@ -28,7 +27,6 @@ namespace net { namespace { const int kTCPKeepAliveSeconds = 45; -bool g_disable_overlapped_reads = false; bool SetSocketReceiveBufferSize(SOCKET socket, int32 size) { int rv = setsockopt(socket, SOL_SOCKET, SO_RCVBUF, @@ -86,8 +84,8 @@ bool SetTCPKeepAlive(SOCKET socket, BOOL enable, int delay_secs) { }; DWORD bytes_returned = 0xABAB; int rv = WSAIoctl(socket, SIO_KEEPALIVE_VALS, &keepalive_vals, - sizeof(keepalive_vals), NULL, 0, - &bytes_returned, NULL, NULL); + sizeof(keepalive_vals), NULL, 0, + &bytes_returned, NULL, NULL); DCHECK(!rv) << "Could not enable TCP Keep-Alive for socket: " << socket << " [error: " << WSAGetLastError() << "]."; @@ -95,49 +93,6 @@ bool SetTCPKeepAlive(SOCKET socket, BOOL enable, int delay_secs) { return rv == 0; } -// Sets socket parameters. Returns the OS error code (or 0 on -// success). -int SetupSocket(SOCKET socket) { - // Increase the socket buffer sizes from the default sizes for WinXP. In - // performance testing, there is substantial benefit by increasing from 8KB - // to 64KB. - // See also: - // http://support.microsoft.com/kb/823764/EN-US - // On Vista, if we manually set these sizes, Vista turns off its receive - // window auto-tuning feature. - // http://blogs.msdn.com/wndp/archive/2006/05/05/Winhec-blog-tcpip-2.aspx - // Since Vista's auto-tune is better than any static value we can could set, - // only change these on pre-vista machines. - if (base::win::GetVersion() < base::win::VERSION_VISTA) { - const int32 kSocketBufferSize = 64 * 1024; - SetSocketReceiveBufferSize(socket, kSocketBufferSize); - SetSocketSendBufferSize(socket, kSocketBufferSize); - } - - DisableNagle(socket, true); - SetTCPKeepAlive(socket, true, kTCPKeepAliveSeconds); - return 0; -} - -// Creates a new socket and sets default parameters for it. Returns -// the OS error code (or 0 on success). -int CreateSocket(int family, SOCKET* socket) { - *socket = CreatePlatformSocket(family, SOCK_STREAM, IPPROTO_TCP); - if (*socket == INVALID_SOCKET) { - int os_error = WSAGetLastError(); - LOG(ERROR) << "CreatePlatformSocket failed: " << os_error; - return os_error; - } - int error = SetupSocket(*socket); - if (error) { - if (closesocket(*socket) < 0) - PLOG(ERROR) << "closesocket"; - *socket = INVALID_SOCKET; - return error; - } - return 0; -} - int MapConnectError(int os_error) { switch (os_error) { // connect fails with WSAEACCES when Windows Firewall blocks the @@ -167,31 +122,21 @@ int MapConnectError(int os_error) { //----------------------------------------------------------------------------- // This class encapsulates all the state that has to be preserved as long as -// there is a network IO operation in progress. If the owner TCPClientSocketWin -// is destroyed while an operation is in progress, the Core is detached and it +// there is a network IO operation in progress. If the owner TCPSocketWin is +// destroyed while an operation is in progress, the Core is detached and it // lives until the operation completes and the OS doesn't reference any resource // declared on this class anymore. -class TCPClientSocketWin::Core : public base::RefCounted<Core> { +class TCPSocketWin::Core : public base::RefCounted<Core> { public: - explicit Core(TCPClientSocketWin* socket); + explicit Core(TCPSocketWin* socket); // Start watching for the end of a read or write operation. void WatchForRead(); void WatchForWrite(); - // The TCPClientSocketWin is going away. + // The TCPSocketWin is going away. void Detach() { socket_ = NULL; } - // Throttle the read size based on our current slow start state. - // Returns the throttled read size. - int ThrottleReadSize(int size) { - if (slow_start_throttle_ < kMaxSlowStartThrottle) { - size = std::min(size, slow_start_throttle_); - slow_start_throttle_ *= 2; - } - return size; - } - // The separate OVERLAPPED variables for asynchronous operation. // |read_overlapped_| is used for both Connect() and Read(). // |write_overlapped_| is only used for Write(); @@ -204,9 +149,6 @@ class TCPClientSocketWin::Core : public base::RefCounted<Core> { int read_buffer_length_; int write_buffer_length_; - // Remember the state of g_disable_overlapped_reads for the duration of the - // socket based on what it was when the socket was created. - bool disable_overlapped_reads_; bool non_blocking_reads_initialized_; private: @@ -239,7 +181,7 @@ class TCPClientSocketWin::Core : public base::RefCounted<Core> { ~Core(); // The socket that created this object. - TCPClientSocketWin* socket_; + TCPSocketWin* socket_; // |reader_| handles the signals from |read_watcher_|. ReadDelegate reader_; @@ -251,26 +193,16 @@ class TCPClientSocketWin::Core : public base::RefCounted<Core> { // |write_watcher_| watches for events from Write(); base::win::ObjectWatcher write_watcher_; - // When doing reads from the socket, we try to mirror TCP's slow start. - // We do this because otherwise the async IO subsystem artifically delays - // returning data to the application. - static const int kInitialSlowStartThrottle = 1 * 1024; - static const int kMaxSlowStartThrottle = 32 * kInitialSlowStartThrottle; - int slow_start_throttle_; - DISALLOW_COPY_AND_ASSIGN(Core); }; -TCPClientSocketWin::Core::Core( - TCPClientSocketWin* socket) +TCPSocketWin::Core::Core(TCPSocketWin* socket) : read_buffer_length_(0), write_buffer_length_(0), - disable_overlapped_reads_(g_disable_overlapped_reads), non_blocking_reads_initialized_(false), socket_(socket), reader_(this), - writer_(this), - slow_start_throttle_(kInitialSlowStartThrottle) { + writer_(this) { memset(&read_overlapped_, 0, sizeof(read_overlapped_)); memset(&write_overlapped_, 0, sizeof(write_overlapped_)); @@ -278,7 +210,7 @@ TCPClientSocketWin::Core::Core( write_overlapped_.hEvent = WSACreateEvent(); } -TCPClientSocketWin::Core::~Core() { +TCPSocketWin::Core::~Core() { // Make sure the message loop is not watching this object anymore. read_watcher_.StopWatching(); write_watcher_.StopWatching(); @@ -289,37 +221,33 @@ TCPClientSocketWin::Core::~Core() { memset(&write_overlapped_, 0xaf, sizeof(write_overlapped_)); } -void TCPClientSocketWin::Core::WatchForRead() { +void TCPSocketWin::Core::WatchForRead() { // We grab an extra reference because there is an IO operation in progress. // Balanced in ReadDelegate::OnObjectSignaled(). AddRef(); read_watcher_.StartWatching(read_overlapped_.hEvent, &reader_); } -void TCPClientSocketWin::Core::WatchForWrite() { +void TCPSocketWin::Core::WatchForWrite() { // We grab an extra reference because there is an IO operation in progress. // Balanced in WriteDelegate::OnObjectSignaled(). AddRef(); write_watcher_.StartWatching(write_overlapped_.hEvent, &writer_); } -void TCPClientSocketWin::Core::ReadDelegate::OnObjectSignaled( - HANDLE object) { +void TCPSocketWin::Core::ReadDelegate::OnObjectSignaled(HANDLE object) { DCHECK_EQ(object, core_->read_overlapped_.hEvent); if (core_->socket_) { - if (core_->socket_->waiting_connect()) { + if (core_->socket_->waiting_connect_) core_->socket_->DidCompleteConnect(); - } else if (core_->disable_overlapped_reads_) { + else core_->socket_->DidSignalRead(); - } else { - core_->socket_->DidCompleteRead(); - } } core_->Release(); } -void TCPClientSocketWin::Core::WriteDelegate::OnObjectSignaled( +void TCPSocketWin::Core::WriteDelegate::OnObjectSignaled( HANDLE object) { DCHECK_EQ(object, core_->write_overlapped_.hEvent); if (core_->socket_) @@ -330,281 +258,170 @@ void TCPClientSocketWin::Core::WriteDelegate::OnObjectSignaled( //----------------------------------------------------------------------------- -TCPClientSocketWin::TCPClientSocketWin(const AddressList& addresses, - net::NetLog* net_log, - const net::NetLog::Source& source) +TCPSocketWin::TCPSocketWin(net::NetLog* net_log, + const net::NetLog::Source& source) : socket_(INVALID_SOCKET), - bound_socket_(INVALID_SOCKET), - addresses_(addresses), - current_address_index_(-1), + accept_event_(WSA_INVALID_EVENT), + accept_socket_(NULL), + accept_address_(NULL), + waiting_connect_(false), waiting_read_(false), waiting_write_(false), - next_connect_state_(CONNECT_STATE_NONE), connect_os_error_(0), - net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), - previously_disconnected_(false) { + logging_multiple_connect_attempts_(false), + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) { net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE, source.ToEventParametersCallback()); EnsureWinsockInit(); } -TCPClientSocketWin::~TCPClientSocketWin() { - Disconnect(); +TCPSocketWin::~TCPSocketWin() { + Close(); net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE); } -int TCPClientSocketWin::AdoptSocket(SOCKET socket) { +int TCPSocketWin::Open(AddressFamily family) { + DCHECK(CalledOnValidThread()); DCHECK_EQ(socket_, INVALID_SOCKET); - int error = SetupSocket(socket); - if (error) - return MapSystemError(error); - - socket_ = socket; - SetNonBlocking(socket_); - - core_ = new Core(this); - current_address_index_ = 0; - use_history_.set_was_ever_connected(); - - return OK; -} - -int TCPClientSocketWin::Bind(const IPEndPoint& address) { - if (current_address_index_ >= 0 || bind_address_.get()) { - // Cannot bind the socket if we are already connected or connecting. - return ERR_UNEXPECTED; + socket_ = CreatePlatformSocket(ConvertAddressFamily(family), SOCK_STREAM, + IPPROTO_TCP); + if (socket_ == INVALID_SOCKET) { + PLOG(ERROR) << "CreatePlatformSocket() returned an error"; + return MapSystemError(WSAGetLastError()); } - SockaddrStorage storage; - if (!address.ToSockAddr(storage.addr, &storage.addr_len)) - return ERR_INVALID_ARGUMENT; - - // Create |bound_socket_| and try to bind it to |address|. - int error = CreateSocket(address.GetSockAddrFamily(), &bound_socket_); - if (error) - return MapSystemError(error); - - if (bind(bound_socket_, storage.addr, storage.addr_len)) { - error = errno; - if (closesocket(bound_socket_) < 0) - PLOG(ERROR) << "closesocket"; - bound_socket_ = INVALID_SOCKET; - return MapSystemError(error); + if (SetNonBlocking(socket_)) { + int result = MapSystemError(WSAGetLastError()); + Close(); + return result; } - bind_address_.reset(new IPEndPoint(address)); - - return 0; + return OK; } - -int TCPClientSocketWin::Connect(const CompletionCallback& callback) { +int TCPSocketWin::AdoptConnectedSocket(SOCKET socket, + const IPEndPoint& peer_address) { DCHECK(CalledOnValidThread()); + DCHECK_EQ(socket_, INVALID_SOCKET); + DCHECK(!core_); - // If already connected, then just return OK. - if (socket_ != INVALID_SOCKET) - return OK; - - base::StatsCounter connects("tcp.connect"); - connects.Increment(); - - net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT, - addresses_.CreateNetLogCallback()); - - // We will try to connect to each address in addresses_. Start with the - // first one in the list. - next_connect_state_ = CONNECT_STATE_CONNECT; - current_address_index_ = 0; + socket_ = socket; - int rv = DoConnectLoop(OK); - if (rv == ERR_IO_PENDING) { - // Synchronous operation not supported. - DCHECK(!callback.is_null()); - // TODO(ajwong): Is setting read_callback_ the right thing to do here?? - read_callback_ = callback; - } else { - LogConnectCompletion(rv); + if (SetNonBlocking(socket_)) { + int result = MapSystemError(WSAGetLastError()); + Close(); + return result; } - return rv; -} - -int TCPClientSocketWin::DoConnectLoop(int result) { - DCHECK_NE(next_connect_state_, CONNECT_STATE_NONE); - - int rv = result; - do { - ConnectState state = next_connect_state_; - next_connect_state_ = CONNECT_STATE_NONE; - switch (state) { - case CONNECT_STATE_CONNECT: - DCHECK_EQ(OK, rv); - rv = DoConnect(); - break; - case CONNECT_STATE_CONNECT_COMPLETE: - rv = DoConnectComplete(rv); - break; - default: - LOG(DFATAL) << "bad state " << state; - rv = ERR_UNEXPECTED; - break; - } - } while (rv != ERR_IO_PENDING && next_connect_state_ != CONNECT_STATE_NONE); + core_ = new Core(this); + peer_address_.reset(new IPEndPoint(peer_address)); - return rv; + return OK; } -int TCPClientSocketWin::DoConnect() { - DCHECK_GE(current_address_index_, 0); - DCHECK_LT(current_address_index_, static_cast<int>(addresses_.size())); - DCHECK_EQ(0, connect_os_error_); +int TCPSocketWin::Bind(const IPEndPoint& address) { + DCHECK(CalledOnValidThread()); + DCHECK_NE(socket_, INVALID_SOCKET); - const IPEndPoint& endpoint = addresses_[current_address_index_]; + SockaddrStorage storage; + if (!address.ToSockAddr(storage.addr, &storage.addr_len)) + return ERR_ADDRESS_INVALID; - if (previously_disconnected_) { - use_history_.Reset(); - previously_disconnected_ = false; + int result = bind(socket_, storage.addr, storage.addr_len); + if (result < 0) { + PLOG(ERROR) << "bind() returned an error"; + return MapSystemError(WSAGetLastError()); } - net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, - CreateNetLogIPEndPointCallback(&endpoint)); + return OK; +} - next_connect_state_ = CONNECT_STATE_CONNECT_COMPLETE; +int TCPSocketWin::Listen(int backlog) { + DCHECK(CalledOnValidThread()); + DCHECK_GT(backlog, 0); + DCHECK_NE(socket_, INVALID_SOCKET); + DCHECK_EQ(accept_event_, WSA_INVALID_EVENT); - if (bound_socket_ != INVALID_SOCKET) { - DCHECK(bind_address_.get()); - socket_ = bound_socket_; - bound_socket_ = INVALID_SOCKET; - } else { - connect_os_error_ = CreateSocket(endpoint.GetSockAddrFamily(), &socket_); - if (connect_os_error_ != 0) - return MapSystemError(connect_os_error_); - - if (bind_address_.get()) { - SockaddrStorage storage; - if (!bind_address_->ToSockAddr(storage.addr, &storage.addr_len)) - return ERR_INVALID_ARGUMENT; - if (bind(socket_, storage.addr, storage.addr_len)) - return MapSystemError(errno); - } + accept_event_ = WSACreateEvent(); + if (accept_event_ == WSA_INVALID_EVENT) { + PLOG(ERROR) << "WSACreateEvent()"; + return MapSystemError(WSAGetLastError()); } - DCHECK(!core_); - core_ = new Core(this); - // WSAEventSelect sets the socket to non-blocking mode as a side effect. - // Our connect() and recv() calls require that the socket be non-blocking. - WSAEventSelect(socket_, core_->read_overlapped_.hEvent, FD_CONNECT); - - SockaddrStorage storage; - if (!endpoint.ToSockAddr(storage.addr, &storage.addr_len)) - return ERR_INVALID_ARGUMENT; - if (!connect(socket_, storage.addr, storage.addr_len)) { - // Connected without waiting! - // - // The MSDN page for connect says: - // With a nonblocking socket, the connection attempt cannot be completed - // immediately. In this case, connect will return SOCKET_ERROR, and - // WSAGetLastError will return WSAEWOULDBLOCK. - // which implies that for a nonblocking socket, connect never returns 0. - // It's not documented whether the event object will be signaled or not - // if connect does return 0. So the code below is essentially dead code - // and we don't know if it's correct. - NOTREACHED(); - - if (ResetEventIfSignaled(core_->read_overlapped_.hEvent)) - return OK; - } else { - int os_error = WSAGetLastError(); - if (os_error != WSAEWOULDBLOCK) { - LOG(ERROR) << "connect failed: " << os_error; - connect_os_error_ = os_error; - return MapConnectError(os_error); - } + int result = listen(socket_, backlog); + if (result < 0) { + PLOG(ERROR) << "listen() returned an error"; + return MapSystemError(WSAGetLastError()); } - core_->WatchForRead(); - return ERR_IO_PENDING; + return OK; } -int TCPClientSocketWin::DoConnectComplete(int result) { - // Log the end of this attempt (and any OS error it threw). - int os_error = connect_os_error_; - connect_os_error_ = 0; - if (result != OK) { - net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, - NetLog::IntegerCallback("os_error", os_error)); - } else { - net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT); - } +int TCPSocketWin::Accept(scoped_ptr<TCPSocketWin>* socket, + IPEndPoint* address, + const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + DCHECK(socket); + DCHECK(address); + DCHECK(!callback.is_null()); + DCHECK(accept_callback_.is_null()); - if (result == OK) { - use_history_.set_was_ever_connected(); - return OK; // Done! - } + net_log_.BeginEvent(NetLog::TYPE_TCP_ACCEPT); - // Close whatever partially connected socket we currently have. - DoDisconnect(); + int result = AcceptInternal(socket, address); - // Try to fall back to the next address in the list. - if (current_address_index_ + 1 < static_cast<int>(addresses_.size())) { - next_connect_state_ = CONNECT_STATE_CONNECT; - ++current_address_index_; - return OK; + if (result == ERR_IO_PENDING) { + // Start watching. + WSAEventSelect(socket_, accept_event_, FD_ACCEPT); + accept_watcher_.StartWatching(accept_event_, this); + + accept_socket_ = socket; + accept_address_ = address; + accept_callback_ = callback; } - // Otherwise there is nothing to fall back to, so give up. return result; } -void TCPClientSocketWin::Disconnect() { +int TCPSocketWin::Connect(const IPEndPoint& address, + const CompletionCallback& callback) { DCHECK(CalledOnValidThread()); + DCHECK_NE(socket_, INVALID_SOCKET); + DCHECK(!waiting_connect_); - DoDisconnect(); - current_address_index_ = -1; - bind_address_.reset(); -} + // |peer_address_| and |core_| will be non-NULL if Connect() has been called. + // Unless Close() is called to reset the internal state, a second call to + // Connect() is not allowed. + // Please note that we enforce this even if the previous Connect() has + // completed and failed. Although it is allowed to connect the same |socket_| + // again after a connection attempt failed on Windows, it results in + // unspecified behavior according to POSIX. Therefore, we make it behave in + // the same way as TCPSocketLibevent. + DCHECK(!peer_address_ && !core_); -void TCPClientSocketWin::DoDisconnect() { - DCHECK(CalledOnValidThread()); + if (!logging_multiple_connect_attempts_) + LogConnectBegin(AddressList(address)); - if (socket_ == INVALID_SOCKET) - return; + peer_address_.reset(new IPEndPoint(address)); - // Note: don't use CancelIo to cancel pending IO because it doesn't work - // when there is a Winsock layered service provider. - - // In most socket implementations, closing a socket results in a graceful - // connection shutdown, but in Winsock we have to call shutdown explicitly. - // See the MSDN page "Graceful Shutdown, Linger Options, and Socket Closure" - // at http://msdn.microsoft.com/en-us/library/ms738547.aspx - shutdown(socket_, SD_SEND); - - // This cancels any pending IO. - closesocket(socket_); - socket_ = INVALID_SOCKET; - - if (waiting_connect()) { - // We closed the socket, so this notification will never come. - // From MSDN' WSAEventSelect documentation: - // "Closing a socket with closesocket also cancels the association and - // selection of network events specified in WSAEventSelect for the socket". - core_->Release(); + int rv = DoConnect(); + if (rv == ERR_IO_PENDING) { + // Synchronous operation not supported. + DCHECK(!callback.is_null()); + read_callback_ = callback; + waiting_connect_ = true; + } else { + DoConnectComplete(rv); } - waiting_read_ = false; - waiting_write_ = false; - - core_->Detach(); - core_ = NULL; - - previously_disconnected_ = true; + return rv; } -bool TCPClientSocketWin::IsConnected() const { +bool TCPSocketWin::IsConnected() const { DCHECK(CalledOnValidThread()); - if (socket_ == INVALID_SOCKET || waiting_connect()) + if (socket_ == INVALID_SOCKET || waiting_connect_) return false; if (waiting_read_) @@ -621,10 +438,10 @@ bool TCPClientSocketWin::IsConnected() const { return true; } -bool TCPClientSocketWin::IsConnectedAndIdle() const { +bool TCPSocketWin::IsConnectedAndIdle() const { DCHECK(CalledOnValidThread()); - if (socket_ == INVALID_SOCKET || waiting_connect()) + if (socket_ == INVALID_SOCKET || waiting_connect_) return false; if (waiting_read_) @@ -642,68 +459,9 @@ bool TCPClientSocketWin::IsConnectedAndIdle() const { return true; } -int TCPClientSocketWin::GetPeerAddress(IPEndPoint* address) const { - DCHECK(CalledOnValidThread()); - DCHECK(address); - if (!IsConnected()) - return ERR_SOCKET_NOT_CONNECTED; - *address = addresses_[current_address_index_]; - return OK; -} - -int TCPClientSocketWin::GetLocalAddress(IPEndPoint* address) const { - DCHECK(CalledOnValidThread()); - DCHECK(address); - if (socket_ == INVALID_SOCKET) { - if (bind_address_.get()) { - *address = *bind_address_; - return OK; - } - return ERR_SOCKET_NOT_CONNECTED; - } - - struct sockaddr_storage addr_storage; - socklen_t addr_len = sizeof(addr_storage); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - if (getsockname(socket_, addr, &addr_len)) - return MapSystemError(WSAGetLastError()); - if (!address->FromSockAddr(addr, addr_len)) - return ERR_FAILED; - return OK; -} - -void TCPClientSocketWin::SetSubresourceSpeculation() { - use_history_.set_subresource_speculation(); -} - -void TCPClientSocketWin::SetOmniboxSpeculation() { - use_history_.set_omnibox_speculation(); -} - -bool TCPClientSocketWin::WasEverUsed() const { - return use_history_.was_used_to_convey_data(); -} - -bool TCPClientSocketWin::UsingTCPFastOpen() const { - // Not supported on windows. - return false; -} - -bool TCPClientSocketWin::WasNpnNegotiated() const { - return false; -} - -NextProto TCPClientSocketWin::GetNegotiatedProtocol() const { - return kProtoUnknown; -} - -bool TCPClientSocketWin::GetSSLInfo(SSLInfo* ssl_info) { - return false; -} - -int TCPClientSocketWin::Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) { +int TCPSocketWin::Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { DCHECK(CalledOnValidThread()); DCHECK_NE(socket_, INVALID_SOCKET); DCHECK(!waiting_read_); @@ -713,9 +471,9 @@ int TCPClientSocketWin::Read(IOBuffer* buf, return DoRead(buf, buf_len, callback); } -int TCPClientSocketWin::Write(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) { +int TCPSocketWin::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { DCHECK(CalledOnValidThread()); DCHECK_NE(socket_, INVALID_SOCKET); DCHECK(!waiting_write_); @@ -747,8 +505,6 @@ int TCPClientSocketWin::Write(IOBuffer* buf, } base::StatsCounter write_bytes("tcp.write_bytes"); write_bytes.Add(rv); - if (rv > 0) - use_history_.set_was_used_to_convey_data(); net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, rv, buf->data()); return rv; @@ -770,29 +526,294 @@ int TCPClientSocketWin::Write(IOBuffer* buf, return ERR_IO_PENDING; } -bool TCPClientSocketWin::SetReceiveBufferSize(int32 size) { +int TCPSocketWin::GetLocalAddress(IPEndPoint* address) const { + DCHECK(CalledOnValidThread()); + DCHECK(address); + + SockaddrStorage storage; + if (getsockname(socket_, storage.addr, &storage.addr_len)) + return MapSystemError(WSAGetLastError()); + if (!address->FromSockAddr(storage.addr, storage.addr_len)) + return ERR_ADDRESS_INVALID; + + return OK; +} + +int TCPSocketWin::GetPeerAddress(IPEndPoint* address) const { + DCHECK(CalledOnValidThread()); + DCHECK(address); + if (!IsConnected()) + return ERR_SOCKET_NOT_CONNECTED; + *address = *peer_address_; + return OK; +} + +int TCPSocketWin::SetDefaultOptionsForServer() { + return SetExclusiveAddrUse(); +} + +void TCPSocketWin::SetDefaultOptionsForClient() { + // Increase the socket buffer sizes from the default sizes for WinXP. In + // performance testing, there is substantial benefit by increasing from 8KB + // to 64KB. + // See also: + // http://support.microsoft.com/kb/823764/EN-US + // On Vista, if we manually set these sizes, Vista turns off its receive + // window auto-tuning feature. + // http://blogs.msdn.com/wndp/archive/2006/05/05/Winhec-blog-tcpip-2.aspx + // Since Vista's auto-tune is better than any static value we can could set, + // only change these on pre-vista machines. + if (base::win::GetVersion() < base::win::VERSION_VISTA) { + const int32 kSocketBufferSize = 64 * 1024; + SetSocketReceiveBufferSize(socket_, kSocketBufferSize); + SetSocketSendBufferSize(socket_, kSocketBufferSize); + } + + DisableNagle(socket_, true); + SetTCPKeepAlive(socket_, true, kTCPKeepAliveSeconds); +} + +int TCPSocketWin::SetExclusiveAddrUse() { + // On Windows, a bound end point can be hijacked by another process by + // setting SO_REUSEADDR. Therefore a Windows-only option SO_EXCLUSIVEADDRUSE + // was introduced in Windows NT 4.0 SP4. If the socket that is bound to the + // end point has SO_EXCLUSIVEADDRUSE enabled, it is not possible for another + // socket to forcibly bind to the end point until the end point is unbound. + // It is recommend that all server applications must use SO_EXCLUSIVEADDRUSE. + // MSDN: http://goo.gl/M6fjQ. + // + // Unlike on *nix, on Windows a TCP server socket can always bind to an end + // point in TIME_WAIT state without setting SO_REUSEADDR, therefore it is not + // needed here. + // + // SO_EXCLUSIVEADDRUSE will prevent a TCP client socket from binding to an end + // point in TIME_WAIT status. It does not have this effect for a TCP server + // socket. + + BOOL true_value = 1; + int rv = setsockopt(socket_, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, + reinterpret_cast<const char*>(&true_value), + sizeof(true_value)); + if (rv < 0) + return MapSystemError(errno); + return OK; +} + +bool TCPSocketWin::SetReceiveBufferSize(int32 size) { DCHECK(CalledOnValidThread()); return SetSocketReceiveBufferSize(socket_, size); } -bool TCPClientSocketWin::SetSendBufferSize(int32 size) { +bool TCPSocketWin::SetSendBufferSize(int32 size) { DCHECK(CalledOnValidThread()); return SetSocketSendBufferSize(socket_, size); } -bool TCPClientSocketWin::SetKeepAlive(bool enable, int delay) { +bool TCPSocketWin::SetKeepAlive(bool enable, int delay) { return SetTCPKeepAlive(socket_, enable, delay); } -bool TCPClientSocketWin::SetNoDelay(bool no_delay) { +bool TCPSocketWin::SetNoDelay(bool no_delay) { return DisableNagle(socket_, no_delay); } -void TCPClientSocketWin::DisableOverlappedReads() { - g_disable_overlapped_reads = true; +void TCPSocketWin::Close() { + DCHECK(CalledOnValidThread()); + + if (socket_ != INVALID_SOCKET) { + // Note: don't use CancelIo to cancel pending IO because it doesn't work + // when there is a Winsock layered service provider. + + // In most socket implementations, closing a socket results in a graceful + // connection shutdown, but in Winsock we have to call shutdown explicitly. + // See the MSDN page "Graceful Shutdown, Linger Options, and Socket Closure" + // at http://msdn.microsoft.com/en-us/library/ms738547.aspx + shutdown(socket_, SD_SEND); + + // This cancels any pending IO. + if (closesocket(socket_) < 0) + PLOG(ERROR) << "closesocket"; + socket_ = INVALID_SOCKET; + } + + if (accept_event_) { + WSACloseEvent(accept_event_); + accept_event_ = WSA_INVALID_EVENT; + } + + if (!accept_callback_.is_null()) { + accept_watcher_.StopWatching(); + accept_socket_ = NULL; + accept_address_ = NULL; + accept_callback_.Reset(); + } + + if (core_) { + if (waiting_connect_) { + // We closed the socket, so this notification will never come. + // From MSDN' WSAEventSelect documentation: + // "Closing a socket with closesocket also cancels the association and + // selection of network events specified in WSAEventSelect for the + // socket". + core_->Release(); + } + core_->Detach(); + core_ = NULL; + } + + waiting_connect_ = false; + waiting_read_ = false; + waiting_write_ = false; + + read_callback_.Reset(); + write_callback_.Reset(); + peer_address_.reset(); + connect_os_error_ = 0; } -void TCPClientSocketWin::LogConnectCompletion(int net_error) { +bool TCPSocketWin::UsingTCPFastOpen() const { + // Not supported on windows. + return false; +} + +void TCPSocketWin::StartLoggingMultipleConnectAttempts( + const AddressList& addresses) { + if (!logging_multiple_connect_attempts_) { + logging_multiple_connect_attempts_ = true; + LogConnectBegin(addresses); + } else { + NOTREACHED(); + } +} + +void TCPSocketWin::EndLoggingMultipleConnectAttempts(int net_error) { + if (logging_multiple_connect_attempts_) { + LogConnectEnd(net_error); + logging_multiple_connect_attempts_ = false; + } else { + NOTREACHED(); + } +} + +int TCPSocketWin::AcceptInternal(scoped_ptr<TCPSocketWin>* socket, + IPEndPoint* address) { + SockaddrStorage storage; + int new_socket = accept(socket_, storage.addr, &storage.addr_len); + if (new_socket < 0) { + int net_error = MapSystemError(WSAGetLastError()); + if (net_error != ERR_IO_PENDING) + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, net_error); + return net_error; + } + + IPEndPoint ip_end_point; + if (!ip_end_point.FromSockAddr(storage.addr, storage.addr_len)) { + NOTREACHED(); + if (closesocket(new_socket) < 0) + PLOG(ERROR) << "closesocket"; + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, ERR_FAILED); + return ERR_FAILED; + } + scoped_ptr<TCPSocketWin> tcp_socket(new TCPSocketWin( + net_log_.net_log(), net_log_.source())); + int adopt_result = tcp_socket->AdoptConnectedSocket(new_socket, ip_end_point); + if (adopt_result != OK) { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, adopt_result); + return adopt_result; + } + *socket = tcp_socket.Pass(); + *address = ip_end_point; + net_log_.EndEvent(NetLog::TYPE_TCP_ACCEPT, + CreateNetLogIPEndPointCallback(&ip_end_point)); + return OK; +} + +void TCPSocketWin::OnObjectSignaled(HANDLE object) { + WSANETWORKEVENTS ev; + if (WSAEnumNetworkEvents(socket_, accept_event_, &ev) == SOCKET_ERROR) { + PLOG(ERROR) << "WSAEnumNetworkEvents()"; + return; + } + + if (ev.lNetworkEvents & FD_ACCEPT) { + int result = AcceptInternal(accept_socket_, accept_address_); + if (result != ERR_IO_PENDING) { + accept_socket_ = NULL; + accept_address_ = NULL; + base::ResetAndReturn(&accept_callback_).Run(result); + } + } +} + +int TCPSocketWin::DoConnect() { + DCHECK_EQ(connect_os_error_, 0); + DCHECK(!core_); + + net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, + CreateNetLogIPEndPointCallback(peer_address_.get())); + + core_ = new Core(this); + // WSAEventSelect sets the socket to non-blocking mode as a side effect. + // Our connect() and recv() calls require that the socket be non-blocking. + WSAEventSelect(socket_, core_->read_overlapped_.hEvent, FD_CONNECT); + + SockaddrStorage storage; + if (!peer_address_->ToSockAddr(storage.addr, &storage.addr_len)) + return ERR_INVALID_ARGUMENT; + if (!connect(socket_, storage.addr, storage.addr_len)) { + // Connected without waiting! + // + // The MSDN page for connect says: + // With a nonblocking socket, the connection attempt cannot be completed + // immediately. In this case, connect will return SOCKET_ERROR, and + // WSAGetLastError will return WSAEWOULDBLOCK. + // which implies that for a nonblocking socket, connect never returns 0. + // It's not documented whether the event object will be signaled or not + // if connect does return 0. So the code below is essentially dead code + // and we don't know if it's correct. + NOTREACHED(); + + if (ResetEventIfSignaled(core_->read_overlapped_.hEvent)) + return OK; + } else { + int os_error = WSAGetLastError(); + if (os_error != WSAEWOULDBLOCK) { + LOG(ERROR) << "connect failed: " << os_error; + connect_os_error_ = os_error; + int rv = MapConnectError(os_error); + CHECK_NE(ERR_IO_PENDING, rv); + return rv; + } + } + + core_->WatchForRead(); + return ERR_IO_PENDING; +} + +void TCPSocketWin::DoConnectComplete(int result) { + // Log the end of this attempt (and any OS error it threw). + int os_error = connect_os_error_; + connect_os_error_ = 0; + if (result != OK) { + net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, + NetLog::IntegerCallback("os_error", os_error)); + } else { + net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT); + } + + if (!logging_multiple_connect_attempts_) + LogConnectEnd(result); +} + +void TCPSocketWin::LogConnectBegin(const AddressList& addresses) { + base::StatsCounter connects("tcp.connect"); + connects.Increment(); + + net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT, + addresses.CreateNetLogCallback()); +} + +void TCPSocketWin::LogConnectEnd(int net_error) { if (net_error == OK) UpdateConnectionTypeHistograms(CONNECTION_ANY); @@ -820,66 +841,30 @@ void TCPClientSocketWin::LogConnectCompletion(int net_error) { sizeof(source_address))); } -int TCPClientSocketWin::DoRead(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) { - if (core_->disable_overlapped_reads_) { - if (!core_->non_blocking_reads_initialized_) { - WSAEventSelect(socket_, core_->read_overlapped_.hEvent, - FD_READ | FD_CLOSE); - core_->non_blocking_reads_initialized_ = true; - } - int rv = recv(socket_, buf->data(), buf_len, 0); - if (rv == SOCKET_ERROR) { - int os_error = WSAGetLastError(); - if (os_error != WSAEWOULDBLOCK) { - int net_error = MapSystemError(os_error); - net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR, - CreateNetLogSocketErrorCallback(net_error, os_error)); - return net_error; - } - } else { - base::StatsCounter read_bytes("tcp.read_bytes"); - if (rv > 0) { - use_history_.set_was_used_to_convey_data(); - read_bytes.Add(rv); - } - net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, rv, - buf->data()); - return rv; +int TCPSocketWin::DoRead(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + if (!core_->non_blocking_reads_initialized_) { + WSAEventSelect(socket_, core_->read_overlapped_.hEvent, + FD_READ | FD_CLOSE); + core_->non_blocking_reads_initialized_ = true; + } + int rv = recv(socket_, buf->data(), buf_len, 0); + if (rv == SOCKET_ERROR) { + int os_error = WSAGetLastError(); + if (os_error != WSAEWOULDBLOCK) { + int net_error = MapSystemError(os_error); + net_log_.AddEvent( + NetLog::TYPE_SOCKET_READ_ERROR, + CreateNetLogSocketErrorCallback(net_error, os_error)); + return net_error; } } else { - buf_len = core_->ThrottleReadSize(buf_len); - - WSABUF read_buffer; - read_buffer.len = buf_len; - read_buffer.buf = buf->data(); - - // TODO(wtc): Remove the assertion after enough testing. - AssertEventNotSignaled(core_->read_overlapped_.hEvent); - DWORD num; - DWORD flags = 0; - int rv = WSARecv(socket_, &read_buffer, 1, &num, &flags, - &core_->read_overlapped_, NULL); - if (rv == 0) { - if (ResetEventIfSignaled(core_->read_overlapped_.hEvent)) { - base::StatsCounter read_bytes("tcp.read_bytes"); - if (num > 0) { - use_history_.set_was_used_to_convey_data(); - read_bytes.Add(num); - } - net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, num, - buf->data()); - return static_cast<int>(num); - } - } else { - int os_error = WSAGetLastError(); - if (os_error != WSA_IO_PENDING) { - int net_error = MapSystemError(os_error); - net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR, - CreateNetLogSocketErrorCallback(net_error, os_error)); - return net_error; - } - } + base::StatsCounter read_bytes("tcp.read_bytes"); + if (rv > 0) + read_bytes.Add(rv); + net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, rv, + buf->data()); + return rv; } waiting_read_ = true; @@ -890,28 +875,9 @@ int TCPClientSocketWin::DoRead(IOBuffer* buf, int buf_len, return ERR_IO_PENDING; } -void TCPClientSocketWin::DoReadCallback(int rv) { - DCHECK_NE(rv, ERR_IO_PENDING); +void TCPSocketWin::DidCompleteConnect() { + DCHECK(waiting_connect_); DCHECK(!read_callback_.is_null()); - - // Since Run may result in Read being called, clear read_callback_ up front. - CompletionCallback c = read_callback_; - read_callback_.Reset(); - c.Run(rv); -} - -void TCPClientSocketWin::DoWriteCallback(int rv) { - DCHECK_NE(rv, ERR_IO_PENDING); - DCHECK(!write_callback_.is_null()); - - // since Run may result in Write being called, clear write_callback_ up front. - CompletionCallback c = write_callback_; - write_callback_.Reset(); - c.Run(rv); -} - -void TCPClientSocketWin::DidCompleteConnect() { - DCHECK_EQ(next_connect_state_, CONNECT_STATE_CONNECT_COMPLETE); int result; WSANETWORKEVENTS events; @@ -931,42 +897,16 @@ void TCPClientSocketWin::DidCompleteConnect() { } connect_os_error_ = os_error; - rv = DoConnectLoop(result); - if (rv != ERR_IO_PENDING) { - LogConnectCompletion(rv); - DoReadCallback(rv); - } -} + DoConnectComplete(result); + waiting_connect_ = false; -void TCPClientSocketWin::DidCompleteRead() { - DCHECK(waiting_read_); - DWORD num_bytes, flags; - BOOL ok = WSAGetOverlappedResult(socket_, &core_->read_overlapped_, - &num_bytes, FALSE, &flags); - waiting_read_ = false; - int rv; - if (ok) { - base::StatsCounter read_bytes("tcp.read_bytes"); - read_bytes.Add(num_bytes); - if (num_bytes > 0) - use_history_.set_was_used_to_convey_data(); - net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, - num_bytes, core_->read_iobuffer_->data()); - rv = static_cast<int>(num_bytes); - } else { - int os_error = WSAGetLastError(); - rv = MapSystemError(os_error); - net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR, - CreateNetLogSocketErrorCallback(rv, os_error)); - } - WSAResetEvent(core_->read_overlapped_.hEvent); - core_->read_iobuffer_ = NULL; - core_->read_buffer_length_ = 0; - DoReadCallback(rv); + DCHECK_NE(result, ERR_IO_PENDING); + base::ResetAndReturn(&read_callback_).Run(result); } -void TCPClientSocketWin::DidCompleteWrite() { +void TCPSocketWin::DidCompleteWrite() { DCHECK(waiting_write_); + DCHECK(!write_callback_.is_null()); DWORD num_bytes, flags; BOOL ok = WSAGetOverlappedResult(socket_, &core_->write_overlapped_, @@ -991,18 +931,21 @@ void TCPClientSocketWin::DidCompleteWrite() { } else { base::StatsCounter write_bytes("tcp.write_bytes"); write_bytes.Add(num_bytes); - if (num_bytes > 0) - use_history_.set_was_used_to_convey_data(); net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, num_bytes, core_->write_iobuffer_->data()); } } + core_->write_iobuffer_ = NULL; - DoWriteCallback(rv); + + DCHECK_NE(rv, ERR_IO_PENDING); + base::ResetAndReturn(&write_callback_).Run(rv); } -void TCPClientSocketWin::DidSignalRead() { +void TCPSocketWin::DidSignalRead() { DCHECK(waiting_read_); + DCHECK(!read_callback_.is_null()); + int os_error = 0; WSANETWORKEVENTS network_events; int rv = WSAEnumNetworkEvents(socket_, core_->read_overlapped_.hEvent, @@ -1036,10 +979,14 @@ void TCPClientSocketWin::DidSignalRead() { core_->WatchForRead(); return; } + waiting_read_ = false; core_->read_iobuffer_ = NULL; core_->read_buffer_length_ = 0; - DoReadCallback(rv); + + DCHECK_NE(rv, ERR_IO_PENDING); + base::ResetAndReturn(&read_callback_).Run(rv); } } // namespace net + diff --git a/chromium/net/socket/tcp_socket_win.h b/chromium/net/socket/tcp_socket_win.h new file mode 100644 index 00000000000..df5fbf09aec --- /dev/null +++ b/chromium/net/socket/tcp_socket_win.h @@ -0,0 +1,150 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_TCP_SOCKET_WIN_H_ +#define NET_SOCKET_TCP_SOCKET_WIN_H_ + +#include <winsock2.h> + +#include "base/basictypes.h" +#include "base/compiler_specific.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/threading/non_thread_safe.h" +#include "base/win/object_watcher.h" +#include "net/base/address_family.h" +#include "net/base/completion_callback.h" +#include "net/base/net_export.h" +#include "net/base/net_log.h" + +namespace net { + +class AddressList; +class IOBuffer; +class IPEndPoint; + +class NET_EXPORT TCPSocketWin : NON_EXPORTED_BASE(public base::NonThreadSafe), + public base::win::ObjectWatcher::Delegate { + public: + TCPSocketWin(NetLog* net_log, const NetLog::Source& source); + virtual ~TCPSocketWin(); + + int Open(AddressFamily family); + // Takes ownership of |socket|. + int AdoptConnectedSocket(SOCKET socket, const IPEndPoint& peer_address); + + int Bind(const IPEndPoint& address); + + int Listen(int backlog); + int Accept(scoped_ptr<TCPSocketWin>* socket, + IPEndPoint* address, + const CompletionCallback& callback); + + int Connect(const IPEndPoint& address, const CompletionCallback& callback); + bool IsConnected() const; + bool IsConnectedAndIdle() const; + + // Multiple outstanding requests are not supported. + // Full duplex mode (reading and writing at the same time) is supported. + int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback); + int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback); + + int GetLocalAddress(IPEndPoint* address) const; + int GetPeerAddress(IPEndPoint* address) const; + + // Sets various socket options. + // The commonly used options for server listening sockets: + // - SetExclusiveAddrUse(). + int SetDefaultOptionsForServer(); + // The commonly used options for client sockets and accepted sockets: + // - Increase the socket buffer sizes for WinXP; + // - SetNoDelay(true); + // - SetKeepAlive(true, 45). + void SetDefaultOptionsForClient(); + int SetExclusiveAddrUse(); + bool SetReceiveBufferSize(int32 size); + bool SetSendBufferSize(int32 size); + bool SetKeepAlive(bool enable, int delay); + bool SetNoDelay(bool no_delay); + + void Close(); + + bool UsingTCPFastOpen() const; + bool IsValid() const { return socket_ != INVALID_SOCKET; } + + // Marks the start/end of a series of connect attempts for logging purpose. + // + // TCPClientSocket may attempt to connect to multiple addresses until it + // succeeds in establishing a connection. The corresponding log will have + // multiple NetLog::TYPE_TCP_CONNECT_ATTEMPT entries nested within a + // NetLog::TYPE_TCP_CONNECT. These methods set the start/end of + // NetLog::TYPE_TCP_CONNECT. + // + // TODO(yzshen): Change logging format and let TCPClientSocket log the + // start/end of a series of connect attempts itself. + void StartLoggingMultipleConnectAttempts(const AddressList& addresses); + void EndLoggingMultipleConnectAttempts(int net_error); + + const BoundNetLog& net_log() const { return net_log_; } + + private: + class Core; + + // base::ObjectWatcher::Delegate implementation. + virtual void OnObjectSignaled(HANDLE object) OVERRIDE; + + int AcceptInternal(scoped_ptr<TCPSocketWin>* socket, + IPEndPoint* address); + + int DoConnect(); + void DoConnectComplete(int result); + + void LogConnectBegin(const AddressList& addresses); + void LogConnectEnd(int net_error); + + int DoRead(IOBuffer* buf, int buf_len, const CompletionCallback& callback); + void DidCompleteConnect(); + void DidCompleteWrite(); + void DidSignalRead(); + + SOCKET socket_; + + HANDLE accept_event_; + base::win::ObjectWatcher accept_watcher_; + + scoped_ptr<TCPSocketWin>* accept_socket_; + IPEndPoint* accept_address_; + CompletionCallback accept_callback_; + + // The various states that the socket could be in. + bool waiting_connect_; + bool waiting_read_; + bool waiting_write_; + + // The core of the socket that can live longer than the socket itself. We pass + // resources to the Windows async IO functions and we have to make sure that + // they are not destroyed while the OS still references them. + scoped_refptr<Core> core_; + + // External callback; called when connect or read is complete. + CompletionCallback read_callback_; + + // External callback; called when write is complete. + CompletionCallback write_callback_; + + scoped_ptr<IPEndPoint> peer_address_; + // The OS error that a connect attempt last completed with. + int connect_os_error_; + + bool logging_multiple_connect_attempts_; + + BoundNetLog net_log_; + + DISALLOW_COPY_AND_ASSIGN(TCPSocketWin); +}; + +} // namespace net + +#endif // NET_SOCKET_TCP_SOCKET_WIN_H_ + diff --git a/chromium/net/socket/transport_client_socket_pool.cc b/chromium/net/socket/transport_client_socket_pool.cc index 8255e988fa4..d03e3e651ac 100644 --- a/chromium/net/socket/transport_client_socket_pool.cc +++ b/chromium/net/socket/transport_client_socket_pool.cc @@ -48,25 +48,18 @@ bool AddressListOnlyContainsIPv6(const AddressList& list) { TransportSocketParams::TransportSocketParams( const HostPortPair& host_port_pair, - RequestPriority priority, bool disable_resolver_cache, bool ignore_limits, const OnHostResolutionCallback& host_resolution_callback) : destination_(host_port_pair), ignore_limits_(ignore_limits), host_resolution_callback_(host_resolution_callback) { - Initialize(priority, disable_resolver_cache); -} - -TransportSocketParams::~TransportSocketParams() {} - -void TransportSocketParams::Initialize(RequestPriority priority, - bool disable_resolver_cache) { - destination_.set_priority(priority); if (disable_resolver_cache) destination_.set_allow_cached_response(false); } +TransportSocketParams::~TransportSocketParams() {} + // TransportConnectJobs will time out after this many seconds. Note this is // the total time, including both host resolution and TCP connect() times. // @@ -80,13 +73,14 @@ static const int kTransportConnectJobTimeoutInSeconds = 240; // 4 minutes. TransportConnectJob::TransportConnectJob( const std::string& group_name, + RequestPriority priority, const scoped_refptr<TransportSocketParams>& params, base::TimeDelta timeout_duration, ClientSocketFactory* client_socket_factory, HostResolver* host_resolver, Delegate* delegate, NetLog* net_log) - : ConnectJob(group_name, timeout_duration, delegate, + : ConnectJob(group_name, timeout_duration, priority, delegate, BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), params_(params), client_socket_factory_(client_socket_factory), @@ -107,10 +101,11 @@ LoadState TransportConnectJob::GetLoadState() const { case STATE_TRANSPORT_CONNECT: case STATE_TRANSPORT_CONNECT_COMPLETE: return LOAD_STATE_CONNECTING; - default: - NOTREACHED(); + case STATE_NONE: return LOAD_STATE_IDLE; } + NOTREACHED(); + return LOAD_STATE_IDLE; } // static @@ -166,7 +161,9 @@ int TransportConnectJob::DoResolveHost() { connect_timing_.dns_start = base::TimeTicks::Now(); return resolver_.Resolve( - params_->destination(), &addresses_, + params_->destination(), + priority(), + &addresses_, base::Bind(&TransportConnectJob::OnIOComplete, base::Unretained(this)), net_log()); } @@ -190,8 +187,8 @@ int TransportConnectJob::DoResolveHostComplete(int result) { int TransportConnectJob::DoTransportConnect() { next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE; - transport_socket_.reset(client_socket_factory_->CreateTransportClientSocket( - addresses_, net_log().net_log(), net_log().source())); + transport_socket_ = client_socket_factory_->CreateTransportClientSocket( + addresses_, net_log().net_log(), net_log().source()); int rv = transport_socket_->Connect( base::Bind(&TransportConnectJob::OnIOComplete, base::Unretained(this))); if (rv == ERR_IO_PENDING && @@ -246,7 +243,7 @@ int TransportConnectJob::DoTransportConnectComplete(int result) { 100); } } - set_socket(transport_socket_.release()); + SetSocket(transport_socket_.Pass()); fallback_timer_.Stop(); } else { // Be a bit paranoid and kill off the fallback members to prevent reuse. @@ -270,9 +267,9 @@ void TransportConnectJob::DoIPv6FallbackTransportConnect() { fallback_addresses_.reset(new AddressList(addresses_)); MakeAddressListStartWithIPv4(fallback_addresses_.get()); - fallback_transport_socket_.reset( + fallback_transport_socket_ = client_socket_factory_->CreateTransportClientSocket( - *fallback_addresses_, net_log().net_log(), net_log().source())); + *fallback_addresses_, net_log().net_log(), net_log().source()); fallback_connect_start_time_ = base::TimeTicks::Now(); int rv = fallback_transport_socket_->Connect( base::Bind( @@ -317,7 +314,7 @@ void TransportConnectJob::DoIPv6FallbackTransportConnectComplete(int result) { base::TimeDelta::FromMilliseconds(1), base::TimeDelta::FromMinutes(10), 100); - set_socket(fallback_transport_socket_.release()); + SetSocket(fallback_transport_socket_.Pass()); next_state_ = STATE_NONE; transport_socket_.reset(); } else { @@ -333,18 +330,20 @@ int TransportConnectJob::ConnectInternal() { return DoLoop(OK); } -ConnectJob* +scoped_ptr<ConnectJob> TransportClientSocketPool::TransportConnectJobFactory::NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const { - return new TransportConnectJob(group_name, - request.params(), - ConnectionTimeout(), - client_socket_factory_, - host_resolver_, - delegate, - net_log_); + return scoped_ptr<ConnectJob>( + new TransportConnectJob(group_name, + request.priority(), + request.params(), + ConnectionTimeout(), + client_socket_factory_, + host_resolver_, + delegate, + net_log_)); } base::TimeDelta @@ -360,11 +359,11 @@ TransportClientSocketPool::TransportClientSocketPool( HostResolver* host_resolver, ClientSocketFactory* client_socket_factory, NetLog* net_log) - : base_(max_sockets, max_sockets_per_group, histograms, + : base_(NULL, max_sockets, max_sockets_per_group, histograms, ClientSocketPool::unused_idle_socket_timeout(), ClientSocketPool::used_idle_socket_timeout(), new TransportConnectJobFactory(client_socket_factory, - host_resolver, net_log)) { + host_resolver, net_log)) { base_.EnableConnectBackupJobs(); } @@ -419,19 +418,15 @@ void TransportClientSocketPool::CancelRequest( void TransportClientSocketPool::ReleaseSocket( const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) { - base_.ReleaseSocket(group_name, socket, id); + base_.ReleaseSocket(group_name, socket.Pass(), id); } void TransportClientSocketPool::FlushWithError(int error) { base_.FlushWithError(error); } -bool TransportClientSocketPool::IsStalled() const { - return base_.IsStalled(); -} - void TransportClientSocketPool::CloseIdleSockets() { base_.CloseIdleSockets(); } @@ -450,14 +445,6 @@ LoadState TransportClientSocketPool::GetLoadState( return base_.GetLoadState(group_name, handle); } -void TransportClientSocketPool::AddLayeredPool(LayeredPool* layered_pool) { - base_.AddLayeredPool(layered_pool); -} - -void TransportClientSocketPool::RemoveLayeredPool(LayeredPool* layered_pool) { - base_.RemoveLayeredPool(layered_pool); -} - base::DictionaryValue* TransportClientSocketPool::GetInfoAsValue( const std::string& name, const std::string& type, @@ -473,4 +460,18 @@ ClientSocketPoolHistograms* TransportClientSocketPool::histograms() const { return base_.histograms(); } +bool TransportClientSocketPool::IsStalled() const { + return base_.IsStalled(); +} + +void TransportClientSocketPool::AddHigherLayeredPool( + HigherLayeredPool* higher_pool) { + base_.AddHigherLayeredPool(higher_pool); +} + +void TransportClientSocketPool::RemoveHigherLayeredPool( + HigherLayeredPool* higher_pool) { + base_.RemoveHigherLayeredPool(higher_pool); +} + } // namespace net diff --git a/chromium/net/socket/transport_client_socket_pool.h b/chromium/net/socket/transport_client_socket_pool.h index bb53b3da301..16e421a4550 100644 --- a/chromium/net/socket/transport_client_socket_pool.h +++ b/chromium/net/socket/transport_client_socket_pool.h @@ -34,7 +34,6 @@ class NET_EXPORT_PRIVATE TransportSocketParams // connection will be aborted with that value. TransportSocketParams( const HostPortPair& host_port_pair, - RequestPriority priority, bool disable_resolver_cache, bool ignore_limits, const OnHostResolutionCallback& host_resolution_callback); @@ -49,8 +48,6 @@ class NET_EXPORT_PRIVATE TransportSocketParams friend class base::RefCounted<TransportSocketParams>; ~TransportSocketParams(); - void Initialize(RequestPriority priority, bool disable_resolver_cache); - HostResolver::RequestInfo destination_; bool ignore_limits_; const OnHostResolutionCallback host_resolution_callback_; @@ -69,6 +66,7 @@ class NET_EXPORT_PRIVATE TransportSocketParams class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob { public: TransportConnectJob(const std::string& group_name, + RequestPriority priority, const scoped_refptr<TransportSocketParams>& params, base::TimeDelta timeout_duration, ClientSocketFactory* client_socket_factory, @@ -132,6 +130,8 @@ class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob { class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { public: + typedef TransportSocketParams SocketParams; + TransportClientSocketPool( int max_sockets, int max_sockets_per_group, @@ -156,10 +156,9 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { virtual void CancelRequest(const std::string& group_name, ClientSocketHandle* handle) OVERRIDE; virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) OVERRIDE; virtual void FlushWithError(int error) OVERRIDE; - virtual bool IsStalled() const OVERRIDE; virtual void CloseIdleSockets() OVERRIDE; virtual int IdleSocketCount() const OVERRIDE; virtual int IdleSocketCountInGroup( @@ -167,8 +166,6 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { virtual LoadState GetLoadState( const std::string& group_name, const ClientSocketHandle* handle) const OVERRIDE; - virtual void AddLayeredPool(LayeredPool* layered_pool) OVERRIDE; - virtual void RemoveLayeredPool(LayeredPool* layered_pool) OVERRIDE; virtual base::DictionaryValue* GetInfoAsValue( const std::string& name, const std::string& type, @@ -176,6 +173,11 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { virtual base::TimeDelta ConnectionTimeout() const OVERRIDE; virtual ClientSocketPoolHistograms* histograms() const OVERRIDE; + // HigherLayeredPool implementation. + virtual bool IsStalled() const OVERRIDE; + virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; + virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; + private: typedef ClientSocketPoolBase<TransportSocketParams> PoolBase; @@ -193,7 +195,7 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { // ClientSocketPoolBase::ConnectJobFactory methods. - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const OVERRIDE; @@ -213,9 +215,6 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { DISALLOW_COPY_AND_ASSIGN(TransportClientSocketPool); }; -REGISTER_SOCKET_PARAMS_FOR_POOL(TransportClientSocketPool, - TransportSocketParams); - } // namespace net #endif // NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_ diff --git a/chromium/net/socket/transport_client_socket_pool_unittest.cc b/chromium/net/socket/transport_client_socket_pool_unittest.cc index dfa1151b291..a984ea3b740 100644 --- a/chromium/net/socket/transport_client_socket_pool_unittest.cc +++ b/chromium/net/socket/transport_client_socket_pool_unittest.cc @@ -23,6 +23,7 @@ #include "net/socket/client_socket_handle.h" #include "net/socket/client_socket_pool_histograms.h" #include "net/socket/socket_test_util.h" +#include "net/socket/ssl_client_socket.h" #include "net/socket/stream_socket.h" #include "testing/gtest/include/gtest/gtest.h" @@ -340,16 +341,16 @@ class MockClientSocketFactory : public ClientSocketFactory { delay_(base::TimeDelta::FromMilliseconds( ClientSocketPool::kMaxConnectRetryIntervalMs)) {} - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) OVERRIDE { NOTREACHED(); - return NULL; + return scoped_ptr<DatagramClientSocket>(); } - virtual StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog* /* net_log */, const NetLog::Source& /* source */) OVERRIDE { @@ -363,34 +364,41 @@ class MockClientSocketFactory : public ClientSocketFactory { switch (type) { case MOCK_CLIENT_SOCKET: - return new MockClientSocket(addresses, net_log_); + return scoped_ptr<StreamSocket>( + new MockClientSocket(addresses, net_log_)); case MOCK_FAILING_CLIENT_SOCKET: - return new MockFailingClientSocket(addresses, net_log_); + return scoped_ptr<StreamSocket>( + new MockFailingClientSocket(addresses, net_log_)); case MOCK_PENDING_CLIENT_SOCKET: - return new MockPendingClientSocket( - addresses, true, false, base::TimeDelta(), net_log_); + return scoped_ptr<StreamSocket>( + new MockPendingClientSocket( + addresses, true, false, base::TimeDelta(), net_log_)); case MOCK_PENDING_FAILING_CLIENT_SOCKET: - return new MockPendingClientSocket( - addresses, false, false, base::TimeDelta(), net_log_); + return scoped_ptr<StreamSocket>( + new MockPendingClientSocket( + addresses, false, false, base::TimeDelta(), net_log_)); case MOCK_DELAYED_CLIENT_SOCKET: - return new MockPendingClientSocket( - addresses, true, false, delay_, net_log_); + return scoped_ptr<StreamSocket>( + new MockPendingClientSocket( + addresses, true, false, delay_, net_log_)); case MOCK_STALLED_CLIENT_SOCKET: - return new MockPendingClientSocket( - addresses, true, true, base::TimeDelta(), net_log_); + return scoped_ptr<StreamSocket>( + new MockPendingClientSocket( + addresses, true, true, base::TimeDelta(), net_log_)); default: NOTREACHED(); - return new MockClientSocket(addresses, net_log_); + return scoped_ptr<StreamSocket>( + new MockClientSocket(addresses, net_log_)); } } - virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocketHandle* transport_socket, + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) OVERRIDE { NOTIMPLEMENTED(); - return NULL; + return scoped_ptr<SSLClientSocket>(); } virtual void ClearSSLSessionCache() OVERRIDE { @@ -431,11 +439,7 @@ class TransportClientSocketPoolTest : public testing::Test { ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(true)), params_( new TransportSocketParams(HostPortPair("www.google.com", 80), - kDefaultPriority, false, false, - OnHostResolutionCallback())), - low_params_( - new TransportSocketParams(HostPortPair("www.google.com", 80), - LOW, false, false, + false, false, OnHostResolutionCallback())), histograms_(new ClientSocketPoolHistograms("TCPUnitTest")), host_resolver_(new MockHostResolver), @@ -455,7 +459,7 @@ class TransportClientSocketPoolTest : public testing::Test { int StartRequest(const std::string& group_name, RequestPriority priority) { scoped_refptr<TransportSocketParams> params(new TransportSocketParams( - HostPortPair("www.google.com", 80), MEDIUM, false, false, + HostPortPair("www.google.com", 80), false, false, OnHostResolutionCallback())); return test_base_.StartRequestUsingPool( &pool_, group_name, priority, params); @@ -479,7 +483,6 @@ class TransportClientSocketPoolTest : public testing::Test { bool connect_backup_jobs_enabled_; CapturingNetLog net_log_; scoped_refptr<TransportSocketParams> params_; - scoped_refptr<TransportSocketParams> low_params_; scoped_ptr<ClientSocketPoolHistograms> histograms_; scoped_ptr<MockHostResolver> host_resolver_; MockClientSocketFactory client_socket_factory_; @@ -561,7 +564,7 @@ TEST(TransportConnectJobTest, MakeAddrListStartWithIPv4) { TEST_F(TransportClientSocketPoolTest, Basic) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool_, + int rv = handle.Init("a", params_, LOW, callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -573,13 +576,27 @@ TEST_F(TransportClientSocketPoolTest, Basic) { TestLoadTimingInfoConnectedNotReused(handle); } +// Make sure that TransportConnectJob passes on its priority to its +// HostResolver request on Init. +TEST_F(TransportClientSocketPoolTest, SetResolvePriorityOnInit) { + for (int i = MINIMUM_PRIORITY; i < NUM_PRIORITIES; ++i) { + RequestPriority priority = static_cast<RequestPriority>(i); + TestCompletionCallback callback; + ClientSocketHandle handle; + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", params_, priority, callback.callback(), &pool_, + BoundNetLog())); + EXPECT_EQ(priority, host_resolver_->last_request_priority()); + } +} + TEST_F(TransportClientSocketPoolTest, InitHostResolutionFailure) { host_resolver_->rules()->AddSimulatedFailure("unresolvable.host.name"); TestCompletionCallback callback; ClientSocketHandle handle; HostPortPair host_port_pair("unresolvable.host.name", 80); scoped_refptr<TransportSocketParams> dest(new TransportSocketParams( - host_port_pair, kDefaultPriority, false, false, + host_port_pair, false, false, OnHostResolutionCallback())); EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", dest, kDefaultPriority, callback.callback(), @@ -854,7 +871,7 @@ class RequestSocketCallback : public TestCompletionCallbackBase { } within_callback_ = true; scoped_refptr<TransportSocketParams> dest(new TransportSocketParams( - HostPortPair("www.google.com", 80), LOWEST, false, false, + HostPortPair("www.google.com", 80), false, false, OnHostResolutionCallback())); int rv = handle_->Init("a", dest, LOWEST, callback(), pool_, BoundNetLog()); @@ -874,7 +891,7 @@ TEST_F(TransportClientSocketPoolTest, RequestTwice) { ClientSocketHandle handle; RequestSocketCallback callback(&handle, &pool_); scoped_refptr<TransportSocketParams> dest(new TransportSocketParams( - HostPortPair("www.google.com", 80), LOWEST, false, false, + HostPortPair("www.google.com", 80), false, false, OnHostResolutionCallback())); int rv = handle.Init("a", dest, LOWEST, callback.callback(), &pool_, BoundNetLog()); @@ -939,7 +956,7 @@ TEST_F(TransportClientSocketPoolTest, FailingActiveRequestWithPendingRequests) { TEST_F(TransportClientSocketPoolTest, IdleSocketLoadTiming) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool_, + int rv = handle.Init("a", params_, LOW, callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -957,7 +974,7 @@ TEST_F(TransportClientSocketPoolTest, IdleSocketLoadTiming) { // Now we should have 1 idle socket. EXPECT_EQ(1, pool_.IdleSocketCount()); - rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool_, + rv = handle.Init("a", params_, LOW, callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(OK, rv); EXPECT_EQ(0, pool_.IdleSocketCount()); @@ -967,7 +984,7 @@ TEST_F(TransportClientSocketPoolTest, IdleSocketLoadTiming) { TEST_F(TransportClientSocketPoolTest, ResetIdleSocketsOnIPAddressChange) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool_, + int rv = handle.Init("a", params_, LOW, callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -1023,7 +1040,7 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketConnect) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("b", low_params_, LOW, callback.callback(), &pool_, + int rv = handle.Init("b", params_, LOW, callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -1065,7 +1082,7 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketCancel) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("c", low_params_, LOW, callback.callback(), &pool_, + int rv = handle.Init("c", params_, LOW, callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -1111,7 +1128,7 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterStall) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("b", low_params_, LOW, callback.callback(), &pool_, + int rv = handle.Init("b", params_, LOW, callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -1159,7 +1176,7 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterDelay) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("b", low_params_, LOW, callback.callback(), &pool_, + int rv = handle.Init("b", params_, LOW, callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -1215,7 +1232,7 @@ TEST_F(TransportClientSocketPoolTest, IPv6FallbackSocketIPv4FinishesFirst) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool, + int rv = handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -1260,7 +1277,7 @@ TEST_F(TransportClientSocketPoolTest, IPv6FallbackSocketIPv6FinishesFirst) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool, + int rv = handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -1294,7 +1311,7 @@ TEST_F(TransportClientSocketPoolTest, IPv6NoIPv4AddressesToFallbackTo) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool, + int rv = handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -1327,7 +1344,7 @@ TEST_F(TransportClientSocketPoolTest, IPv4HasNoFallback) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool, + int rv = handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); diff --git a/chromium/net/socket/transport_client_socket_unittest.cc b/chromium/net/socket/transport_client_socket_unittest.cc index 2f75e740067..5548b27b995 100644 --- a/chromium/net/socket/transport_client_socket_unittest.cc +++ b/chromium/net/socket/transport_client_socket_unittest.cc @@ -48,8 +48,9 @@ class TransportClientSocketTest // Implement StreamListenSocket::Delegate methods virtual void DidAccept(StreamListenSocket* server, - StreamListenSocket* connection) OVERRIDE { - connected_sock_ = reinterpret_cast<TCPListenSocket*>(connection); + scoped_ptr<StreamListenSocket> connection) OVERRIDE { + connected_sock_.reset( + static_cast<TCPListenSocket*>(connection.release())); } virtual void DidRead(StreamListenSocket*, const char* str, int len) OVERRIDE { // TODO(dkegel): this might not be long enough to tickle some bugs. @@ -65,7 +66,7 @@ class TransportClientSocketTest void CloseServerSocket() { // delete the connected_sock_, which will close it. - connected_sock_ = NULL; + connected_sock_.reset(); } void PauseServerReads() { @@ -94,8 +95,8 @@ class TransportClientSocketTest scoped_ptr<StreamSocket> sock_; private: - scoped_refptr<TCPListenSocket> listen_sock_; - scoped_refptr<TCPListenSocket> connected_sock_; + scoped_ptr<TCPListenSocket> listen_sock_; + scoped_ptr<TCPListenSocket> connected_sock_; bool close_server_socket_on_next_send_; }; @@ -103,7 +104,7 @@ void TransportClientSocketTest::SetUp() { ::testing::TestWithParam<ClientSocketTestTypes>::SetUp(); // Find a free port to listen on - scoped_refptr<TCPListenSocket> sock; + scoped_ptr<TCPListenSocket> sock; int port; // Range of ports to listen on. Shouldn't need to try many. const int kMinPort = 10100; @@ -117,7 +118,7 @@ void TransportClientSocketTest::SetUp() { break; } ASSERT_TRUE(sock.get() != NULL); - listen_sock_ = sock; + listen_sock_ = sock.Pass(); listen_port_ = port; AddressList addr; @@ -125,15 +126,15 @@ void TransportClientSocketTest::SetUp() { scoped_ptr<HostResolver> resolver(new MockHostResolver()); HostResolver::RequestInfo info(HostPortPair("localhost", listen_port_)); TestCompletionCallback callback; - int rv = resolver->Resolve(info, &addr, callback.callback(), NULL, - BoundNetLog()); + int rv = resolver->Resolve( + info, DEFAULT_PRIORITY, &addr, callback.callback(), NULL, BoundNetLog()); CHECK_EQ(ERR_IO_PENDING, rv); rv = callback.WaitForResult(); CHECK_EQ(rv, OK); - sock_.reset( + sock_ = socket_factory_->CreateTransportClientSocket(addr, &net_log_, - NetLog::Source())); + NetLog::Source()); } int TransportClientSocketTest::DrainClientSocket( diff --git a/chromium/net/socket/unix_domain_socket_posix.cc b/chromium/net/socket/unix_domain_socket_posix.cc index 5b6b2498245..2b781d58b35 100644 --- a/chromium/net/socket/unix_domain_socket_posix.cc +++ b/chromium/net/socket/unix_domain_socket_posix.cc @@ -21,6 +21,7 @@ #include "build/build_config.h" #include "net/base/net_errors.h" #include "net/base/net_util.h" +#include "net/socket/socket_descriptor.h" namespace net { @@ -48,12 +49,12 @@ bool GetPeerIds(int socket, uid_t* user_id, gid_t* group_id) { } // namespace // static -UnixDomainSocket::AuthCallback NoAuthentication() { +UnixDomainSocket::AuthCallback UnixDomainSocket::NoAuthentication() { return base::Bind(NoAuthenticationCallback); } // static -UnixDomainSocket* UnixDomainSocket::CreateAndListenInternal( +scoped_ptr<UnixDomainSocket> UnixDomainSocket::CreateAndListenInternal( const std::string& path, const std::string& fallback_path, StreamListenSocket::Delegate* del, @@ -63,14 +64,15 @@ UnixDomainSocket* UnixDomainSocket::CreateAndListenInternal( if (s == kInvalidSocket && !fallback_path.empty()) s = CreateAndBind(fallback_path, use_abstract_namespace); if (s == kInvalidSocket) - return NULL; - UnixDomainSocket* sock = new UnixDomainSocket(s, del, auth_callback); + return scoped_ptr<UnixDomainSocket>(); + scoped_ptr<UnixDomainSocket> sock( + new UnixDomainSocket(s, del, auth_callback)); sock->Listen(); - return sock; + return sock.Pass(); } // static -scoped_refptr<UnixDomainSocket> UnixDomainSocket::CreateAndListen( +scoped_ptr<UnixDomainSocket> UnixDomainSocket::CreateAndListen( const std::string& path, StreamListenSocket::Delegate* del, const AuthCallback& auth_callback) { @@ -79,14 +81,14 @@ scoped_refptr<UnixDomainSocket> UnixDomainSocket::CreateAndListen( #if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) // static -scoped_refptr<UnixDomainSocket> +scoped_ptr<UnixDomainSocket> UnixDomainSocket::CreateAndListenWithAbstractNamespace( const std::string& path, const std::string& fallback_path, StreamListenSocket::Delegate* del, const AuthCallback& auth_callback) { - return make_scoped_refptr( - CreateAndListenInternal(path, fallback_path, del, auth_callback, true)); + return + CreateAndListenInternal(path, fallback_path, del, auth_callback, true); } #endif @@ -106,7 +108,7 @@ SocketDescriptor UnixDomainSocket::CreateAndBind(const std::string& path, static const size_t kPathMax = sizeof(addr.sun_path); if (use_abstract_namespace + path.size() + 1 /* '\0' */ > kPathMax) return kInvalidSocket; - const SocketDescriptor s = socket(PF_UNIX, SOCK_STREAM, 0); + const SocketDescriptor s = CreatePlatformSocket(PF_UNIX, SOCK_STREAM, 0); if (s == kInvalidSocket) return kInvalidSocket; memset(&addr, 0, sizeof(addr)); @@ -147,11 +149,11 @@ void UnixDomainSocket::Accept() { LOG(ERROR) << "close() error"; return; } - scoped_refptr<UnixDomainSocket> sock( + scoped_ptr<UnixDomainSocket> sock( new UnixDomainSocket(conn, socket_delegate_, auth_callback_)); // It's up to the delegate to AddRef if it wants to keep it around. sock->WatchSocket(WAITING_READ); - socket_delegate_->DidAccept(this, sock.get()); + socket_delegate_->DidAccept(this, sock.PassAs<StreamListenSocket>()); } UnixDomainSocketFactory::UnixDomainSocketFactory( @@ -162,10 +164,10 @@ UnixDomainSocketFactory::UnixDomainSocketFactory( UnixDomainSocketFactory::~UnixDomainSocketFactory() {} -scoped_refptr<StreamListenSocket> UnixDomainSocketFactory::CreateAndListen( +scoped_ptr<StreamListenSocket> UnixDomainSocketFactory::CreateAndListen( StreamListenSocket::Delegate* delegate) const { return UnixDomainSocket::CreateAndListen( - path_, delegate, auth_callback_); + path_, delegate, auth_callback_).PassAs<StreamListenSocket>(); } #if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) @@ -181,11 +183,12 @@ UnixDomainSocketWithAbstractNamespaceFactory( UnixDomainSocketWithAbstractNamespaceFactory:: ~UnixDomainSocketWithAbstractNamespaceFactory() {} -scoped_refptr<StreamListenSocket> +scoped_ptr<StreamListenSocket> UnixDomainSocketWithAbstractNamespaceFactory::CreateAndListen( StreamListenSocket::Delegate* delegate) const { return UnixDomainSocket::CreateAndListenWithAbstractNamespace( - path_, fallback_path_, delegate, auth_callback_); + path_, fallback_path_, delegate, auth_callback_) + .PassAs<StreamListenSocket>(); } #endif diff --git a/chromium/net/socket/unix_domain_socket_posix.h b/chromium/net/socket/unix_domain_socket_posix.h index 2ef06803d24..98d0c11a648 100644 --- a/chromium/net/socket/unix_domain_socket_posix.h +++ b/chromium/net/socket/unix_domain_socket_posix.h @@ -10,7 +10,6 @@ #include "base/basictypes.h" #include "base/callback_forward.h" #include "base/compiler_specific.h" -#include "base/memory/ref_counted.h" #include "build/build_config.h" #include "net/base/net_export.h" #include "net/socket/stream_listen_socket.h" @@ -26,6 +25,8 @@ namespace net { // Unix Domain Socket Implementation. Supports abstract namespaces on Linux. class NET_EXPORT UnixDomainSocket : public StreamListenSocket { public: + virtual ~UnixDomainSocket(); + // Callback that returns whether the already connected client, identified by // its process |user_id| and |group_id|, is allowed to keep the connection // open. Note that the socket is closed immediately in case the callback @@ -38,7 +39,7 @@ class NET_EXPORT UnixDomainSocket : public StreamListenSocket { // Note that the returned UnixDomainSocket instance does not take ownership of // |del|. - static scoped_refptr<UnixDomainSocket> CreateAndListen( + static scoped_ptr<UnixDomainSocket> CreateAndListen( const std::string& path, StreamListenSocket::Delegate* del, const AuthCallback& auth_callback); @@ -47,7 +48,7 @@ class NET_EXPORT UnixDomainSocket : public StreamListenSocket { // Same as above except that the created socket uses the abstract namespace // which is a Linux-only feature. If |fallback_path| is not empty, // make the second attempt with the provided fallback name. - static scoped_refptr<UnixDomainSocket> CreateAndListenWithAbstractNamespace( + static scoped_ptr<UnixDomainSocket> CreateAndListenWithAbstractNamespace( const std::string& path, const std::string& fallback_path, StreamListenSocket::Delegate* del, @@ -58,9 +59,8 @@ class NET_EXPORT UnixDomainSocket : public StreamListenSocket { UnixDomainSocket(SocketDescriptor s, StreamListenSocket::Delegate* del, const AuthCallback& auth_callback); - virtual ~UnixDomainSocket(); - static UnixDomainSocket* CreateAndListenInternal( + static scoped_ptr<UnixDomainSocket> CreateAndListenInternal( const std::string& path, const std::string& fallback_path, StreamListenSocket::Delegate* del, @@ -87,7 +87,7 @@ class NET_EXPORT UnixDomainSocketFactory : public StreamListenSocketFactory { virtual ~UnixDomainSocketFactory(); // StreamListenSocketFactory: - virtual scoped_refptr<StreamListenSocket> CreateAndListen( + virtual scoped_ptr<StreamListenSocket> CreateAndListen( StreamListenSocket::Delegate* delegate) const OVERRIDE; protected: @@ -111,7 +111,7 @@ class NET_EXPORT UnixDomainSocketWithAbstractNamespaceFactory virtual ~UnixDomainSocketWithAbstractNamespaceFactory(); // UnixDomainSocketFactory: - virtual scoped_refptr<StreamListenSocket> CreateAndListen( + virtual scoped_ptr<StreamListenSocket> CreateAndListen( StreamListenSocket::Delegate* delegate) const OVERRIDE; private: diff --git a/chromium/net/socket/unix_domain_socket_posix_unittest.cc b/chromium/net/socket/unix_domain_socket_posix_unittest.cc index 5abe03b4ae3..f062d274205 100644 --- a/chromium/net/socket/unix_domain_socket_posix_unittest.cc +++ b/chromium/net/socket/unix_domain_socket_posix_unittest.cc @@ -29,6 +29,7 @@ #include "base/synchronization/lock.h" #include "base/threading/platform_thread.h" #include "base/threading/thread.h" +#include "net/socket/socket_descriptor.h" #include "net/socket/unix_domain_socket_posix.h" #include "testing/gtest/include/gtest/gtest.h" @@ -102,9 +103,9 @@ class TestListenSocketDelegate : public StreamListenSocket::Delegate { : event_manager_(event_manager) {} virtual void DidAccept(StreamListenSocket* server, - StreamListenSocket* connection) OVERRIDE { + scoped_ptr<StreamListenSocket> connection) OVERRIDE { LOG(ERROR) << __PRETTY_FUNCTION__; - connection_ = connection; + connection_ = connection.Pass(); Notify(EVENT_ACCEPT); } @@ -138,7 +139,7 @@ class TestListenSocketDelegate : public StreamListenSocket::Delegate { } const scoped_refptr<EventManager> event_manager_; - scoped_refptr<StreamListenSocket> connection_; + scoped_ptr<StreamListenSocket> connection_; base::Lock mutex_; string data_; }; @@ -172,7 +173,7 @@ class UnixDomainSocketTestHelper : public testing::Test { virtual void TearDown() OVERRIDE { DeleteSocketFile(); - socket_ = NULL; + socket_.reset(); socket_delegate_.reset(); event_manager_ = NULL; } @@ -187,10 +188,10 @@ class UnixDomainSocketTestHelper : public testing::Test { } SocketDescriptor CreateClientSocket() { - const SocketDescriptor sock = socket(PF_UNIX, SOCK_STREAM, 0); + const SocketDescriptor sock = CreatePlatformSocket(PF_UNIX, SOCK_STREAM, 0); if (sock < 0) { LOG(ERROR) << "socket() error"; - return StreamListenSocket::kInvalidSocket; + return kInvalidSocket; } sockaddr_un addr; memset(&addr, 0, sizeof(addr)); @@ -200,7 +201,7 @@ class UnixDomainSocketTestHelper : public testing::Test { addr_len = sizeof(sockaddr_un); if (connect(sock, reinterpret_cast<sockaddr*>(&addr), addr_len) != 0) { LOG(ERROR) << "connect() error"; - return StreamListenSocket::kInvalidSocket; + return kInvalidSocket; } return sock; } @@ -221,7 +222,7 @@ class UnixDomainSocketTestHelper : public testing::Test { const bool allow_user_; scoped_refptr<EventManager> event_manager_; scoped_ptr<TestListenSocketDelegate> socket_delegate_; - scoped_refptr<UnixDomainSocket> socket_; + scoped_ptr<UnixDomainSocket> socket_; }; class UnixDomainSocketTest : public UnixDomainSocketTestHelper { @@ -264,7 +265,7 @@ TEST_F(UnixDomainSocketTestWithInvalidPath, } TEST_F(UnixDomainSocketTest, TestFallbackName) { - scoped_refptr<UnixDomainSocket> existing_socket = + scoped_ptr<UnixDomainSocket> existing_socket = UnixDomainSocket::CreateAndListenWithAbstractNamespace( file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback()); EXPECT_FALSE(existing_socket.get() == NULL); @@ -280,7 +281,6 @@ TEST_F(UnixDomainSocketTest, TestFallbackName) { socket_delegate_.get(), MakeAuthCallback()); EXPECT_FALSE(socket_.get() == NULL); - existing_socket = NULL; } #endif @@ -291,7 +291,7 @@ TEST_F(UnixDomainSocketTest, TestWithClient) { // Create the client socket. const SocketDescriptor sock = CreateClientSocket(); - ASSERT_NE(StreamListenSocket::kInvalidSocket, sock); + ASSERT_NE(kInvalidSocket, sock); event = event_manager_->WaitForEvent(); ASSERT_EQ(EVENT_AUTH_GRANTED, event); event = event_manager_->WaitForEvent(); @@ -316,7 +316,7 @@ TEST_F(UnixDomainSocketTestWithForbiddenUser, TestWithForbiddenUser) { EventType event = event_manager_->WaitForEvent(); ASSERT_EQ(EVENT_LISTEN, event); const SocketDescriptor sock = CreateClientSocket(); - ASSERT_NE(StreamListenSocket::kInvalidSocket, sock); + ASSERT_NE(kInvalidSocket, sock); event = event_manager_->WaitForEvent(); ASSERT_EQ(EVENT_AUTH_DENIED, event); |