summaryrefslogtreecommitdiff
path: root/chromium/net/socket
diff options
context:
space:
mode:
authorAndras Becsi <andras.becsi@digia.com>2013-12-11 21:33:03 +0100
committerAndras Becsi <andras.becsi@digia.com>2013-12-13 12:34:07 +0100
commitf2a33ff9cbc6d19943f1c7fbddd1f23d23975577 (patch)
tree0586a32aa390ade8557dfd6b4897f43a07449578 /chromium/net/socket
parent5362912cdb5eea702b68ebe23702468d17c3017a (diff)
downloadqtwebengine-chromium-f2a33ff9cbc6d19943f1c7fbddd1f23d23975577.tar.gz
Update Chromium to branch 1650 (31.0.1650.63)
Change-Id: I57d8c832eaec1eb2364e0a8e7352a6dd354db99f Reviewed-by: Jocelyn Turcotte <jocelyn.turcotte@digia.com>
Diffstat (limited to 'chromium/net/socket')
-rw-r--r--chromium/net/socket/buffered_write_stream_socket.cc4
-rw-r--r--chromium/net/socket/buffered_write_stream_socket.h6
-rw-r--r--chromium/net/socket/buffered_write_stream_socket_unittest.cc7
-rw-r--r--chromium/net/socket/client_socket_factory.cc44
-rw-r--r--chromium/net/socket/client_socket_factory.h16
-rw-r--r--chromium/net/socket/client_socket_handle.cc77
-rw-r--r--chromium/net/socket/client_socket_handle.h48
-rw-r--r--chromium/net/socket/client_socket_pool.h88
-rw-r--r--chromium/net/socket/client_socket_pool_base.cc414
-rw-r--r--chromium/net/socket/client_socket_pool_base.h169
-rw-r--r--chromium/net/socket/client_socket_pool_base_unittest.cc85
-rw-r--r--chromium/net/socket/client_socket_pool_manager.cc7
-rw-r--r--chromium/net/socket/deterministic_socket_data_unittest.cc1
-rw-r--r--chromium/net/socket/nss_ssl_util.cc9
-rw-r--r--chromium/net/socket/socket_descriptor.cc48
-rw-r--r--chromium/net/socket/socket_descriptor.h49
-rw-r--r--chromium/net/socket/socket_test_util.cc102
-rw-r--r--chromium/net/socket/socket_test_util.h51
-rw-r--r--chromium/net/socket/socks5_client_socket.cc22
-rw-r--r--chromium/net/socket/socks5_client_socket.h9
-rw-r--r--chromium/net/socket/socks5_client_socket_unittest.cc77
-rw-r--r--chromium/net/socket/socks_client_socket.cc32
-rw-r--r--chromium/net/socket/socks_client_socket.h12
-rw-r--r--chromium/net/socket/socks_client_socket_pool.cc82
-rw-r--r--chromium/net/socket/socks_client_socket_pool.h29
-rw-r--r--chromium/net/socket/socks_client_socket_pool_unittest.cc95
-rw-r--r--chromium/net/socket/socks_client_socket_unittest.cc97
-rw-r--r--chromium/net/socket/ssl_client_socket_nss.cc42
-rw-r--r--chromium/net/socket/ssl_client_socket_nss.h2
-rw-r--r--chromium/net/socket/ssl_client_socket_openssl.cc14
-rw-r--r--chromium/net/socket/ssl_client_socket_openssl.h2
-rw-r--r--chromium/net/socket/ssl_client_socket_openssl_unittest.cc16
-rw-r--r--chromium/net/socket/ssl_client_socket_pool.cc222
-rw-r--r--chromium/net/socket/ssl_client_socket_pool.h89
-rw-r--r--chromium/net/socket/ssl_client_socket_pool_unittest.cc98
-rw-r--r--chromium/net/socket/ssl_client_socket_unittest.cc1200
-rw-r--r--chromium/net/socket/ssl_server_socket.h5
-rw-r--r--chromium/net/socket/ssl_server_socket_nss.cc11
-rw-r--r--chromium/net/socket/ssl_server_socket_nss.h2
-rw-r--r--chromium/net/socket/ssl_server_socket_openssl.cc12
-rw-r--r--chromium/net/socket/ssl_server_socket_unittest.cc20
-rw-r--r--chromium/net/socket/stream_listen_socket.cc3
-rw-r--r--chromium/net/socket/stream_listen_socket.h23
-rw-r--r--chromium/net/socket/tcp_client_socket.cc319
-rw-r--r--chromium/net/socket/tcp_client_socket.h124
-rw-r--r--chromium/net/socket/tcp_client_socket_libevent.h256
-rw-r--r--chromium/net/socket/tcp_client_socket_win.h162
-rw-r--r--chromium/net/socket/tcp_listen_socket.cc24
-rw-r--r--chromium/net/socket/tcp_listen_socket.h12
-rw-r--r--chromium/net/socket/tcp_listen_socket_unittest.cc56
-rw-r--r--chromium/net/socket/tcp_listen_socket_unittest.h21
-rw-r--r--chromium/net/socket/tcp_server_socket.cc105
-rw-r--r--chromium/net/socket/tcp_server_socket.h51
-rw-r--r--chromium/net/socket/tcp_server_socket_libevent.cc223
-rw-r--r--chromium/net/socket/tcp_server_socket_libevent.h55
-rw-r--r--chromium/net/socket/tcp_server_socket_win.cc217
-rw-r--r--chromium/net/socket/tcp_server_socket_win.h58
-rw-r--r--chromium/net/socket/tcp_socket.cc59
-rw-r--r--chromium/net/socket/tcp_socket.h40
-rw-r--r--chromium/net/socket/tcp_socket_libevent.cc (renamed from chromium/net/socket/tcp_client_socket_libevent.cc)903
-rw-r--r--chromium/net/socket/tcp_socket_libevent.h235
-rw-r--r--chromium/net/socket/tcp_socket_unittest.cc263
-rw-r--r--chromium/net/socket/tcp_socket_win.cc (renamed from chromium/net/socket/tcp_client_socket_win.cc)953
-rw-r--r--chromium/net/socket/tcp_socket_win.h150
-rw-r--r--chromium/net/socket/transport_client_socket_pool.cc87
-rw-r--r--chromium/net/socket/transport_client_socket_pool.h21
-rw-r--r--chromium/net/socket/transport_client_socket_pool_unittest.cc95
-rw-r--r--chromium/net/socket/transport_client_socket_unittest.cc23
-rw-r--r--chromium/net/socket/unix_domain_socket_posix.cc35
-rw-r--r--chromium/net/socket/unix_domain_socket_posix.h14
-rw-r--r--chromium/net/socket/unix_domain_socket_posix_unittest.cc24
71 files changed, 4358 insertions, 3668 deletions
diff --git a/chromium/net/socket/buffered_write_stream_socket.cc b/chromium/net/socket/buffered_write_stream_socket.cc
index 36b9df715fd..cf13c5e439a 100644
--- a/chromium/net/socket/buffered_write_stream_socket.cc
+++ b/chromium/net/socket/buffered_write_stream_socket.cc
@@ -23,8 +23,8 @@ void AppendBuffer(GrowableIOBuffer* dst, IOBuffer* src, int src_len) {
} // anonymous namespace
BufferedWriteStreamSocket::BufferedWriteStreamSocket(
- StreamSocket* socket_to_wrap)
- : wrapped_socket_(socket_to_wrap),
+ scoped_ptr<StreamSocket> socket_to_wrap)
+ : wrapped_socket_(socket_to_wrap.Pass()),
io_buffer_(new GrowableIOBuffer()),
backup_buffer_(new GrowableIOBuffer()),
weak_factory_(this),
diff --git a/chromium/net/socket/buffered_write_stream_socket.h b/chromium/net/socket/buffered_write_stream_socket.h
index fcb33a81910..aad5736d0b0 100644
--- a/chromium/net/socket/buffered_write_stream_socket.h
+++ b/chromium/net/socket/buffered_write_stream_socket.h
@@ -5,6 +5,8 @@
#ifndef NET_SOCKET_BUFFERED_WRITE_STREAM_SOCKET_H_
#define NET_SOCKET_BUFFERED_WRITE_STREAM_SOCKET_H_
+#include "base/basictypes.h"
+#include "base/memory/scoped_ptr.h"
#include "base/memory/weak_ptr.h"
#include "net/base/net_log.h"
#include "net/socket/stream_socket.h"
@@ -33,7 +35,7 @@ class IPEndPoint;
// There are no bounds on the local buffer size. Use carefully.
class NET_EXPORT_PRIVATE BufferedWriteStreamSocket : public StreamSocket {
public:
- BufferedWriteStreamSocket(StreamSocket* socket_to_wrap);
+ explicit BufferedWriteStreamSocket(scoped_ptr<StreamSocket> socket_to_wrap);
virtual ~BufferedWriteStreamSocket();
// Socket interface
@@ -71,6 +73,8 @@ class NET_EXPORT_PRIVATE BufferedWriteStreamSocket : public StreamSocket {
bool callback_pending_;
bool wrapped_write_in_progress_;
int error_;
+
+ DISALLOW_COPY_AND_ASSIGN(BufferedWriteStreamSocket);
};
} // namespace net
diff --git a/chromium/net/socket/buffered_write_stream_socket_unittest.cc b/chromium/net/socket/buffered_write_stream_socket_unittest.cc
index e579a7f51d2..485295f33f6 100644
--- a/chromium/net/socket/buffered_write_stream_socket_unittest.cc
+++ b/chromium/net/socket/buffered_write_stream_socket_unittest.cc
@@ -30,10 +30,11 @@ class BufferedWriteStreamSocketTest : public testing::Test {
if (writes_count) {
data_->StopAfter(writes_count);
}
- DeterministicMockTCPClientSocket* wrapped_socket =
- new DeterministicMockTCPClientSocket(net_log_.net_log(), data_.get());
+ scoped_ptr<DeterministicMockTCPClientSocket> wrapped_socket(
+ new DeterministicMockTCPClientSocket(net_log_.net_log(), data_.get()));
data_->set_delegate(wrapped_socket->AsWeakPtr());
- socket_.reset(new BufferedWriteStreamSocket(wrapped_socket));
+ socket_.reset(new BufferedWriteStreamSocket(
+ wrapped_socket.PassAs<StreamSocket>()));
socket_->Connect(callback_.callback());
}
diff --git a/chromium/net/socket/client_socket_factory.cc b/chromium/net/socket/client_socket_factory.cc
index 022988aa6a9..a86688e3333 100644
--- a/chromium/net/socket/client_socket_factory.cc
+++ b/chromium/net/socket/client_socket_factory.cc
@@ -67,23 +67,25 @@ class DefaultClientSocketFactory : public ClientSocketFactory,
ClearSSLSessionCache();
}
- virtual DatagramClientSocket* CreateDatagramClientSocket(
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
NetLog* net_log,
const NetLog::Source& source) OVERRIDE {
- return new UDPClientSocket(bind_type, rand_int_cb, net_log, source);
+ return scoped_ptr<DatagramClientSocket>(
+ new UDPClientSocket(bind_type, rand_int_cb, net_log, source));
}
- virtual StreamSocket* CreateTransportClientSocket(
+ virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
const AddressList& addresses,
NetLog* net_log,
const NetLog::Source& source) OVERRIDE {
- return new TCPClientSocket(addresses, net_log, source);
+ return scoped_ptr<StreamSocket>(
+ new TCPClientSocket(addresses, net_log, source));
}
- virtual SSLClientSocket* CreateSSLClientSocket(
- ClientSocketHandle* transport_socket,
+ virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context) OVERRIDE {
@@ -102,17 +104,19 @@ class DefaultClientSocketFactory : public ClientSocketFactory,
nss_task_runner = base::ThreadTaskRunnerHandle::Get();
#if defined(USE_OPENSSL)
- return new SSLClientSocketOpenSSL(transport_socket, host_and_port,
- ssl_config, context);
+ return scoped_ptr<SSLClientSocket>(
+ new SSLClientSocketOpenSSL(transport_socket.Pass(), host_and_port,
+ ssl_config, context));
#elif defined(USE_NSS) || defined(OS_MACOSX) || defined(OS_WIN)
- return new SSLClientSocketNSS(nss_task_runner.get(),
- transport_socket,
- host_and_port,
- ssl_config,
- context);
+ return scoped_ptr<SSLClientSocket>(
+ new SSLClientSocketNSS(nss_task_runner.get(),
+ transport_socket.Pass(),
+ host_and_port,
+ ssl_config,
+ context));
#else
NOTIMPLEMENTED();
- return NULL;
+ return scoped_ptr<SSLClientSocket>();
#endif
}
@@ -130,18 +134,6 @@ static base::LazyInstance<DefaultClientSocketFactory>::Leaky
} // namespace
-// Deprecated function (http://crbug.com/37810) that takes a StreamSocket.
-SSLClientSocket* ClientSocketFactory::CreateSSLClientSocket(
- StreamSocket* transport_socket,
- const HostPortPair& host_and_port,
- const SSLConfig& ssl_config,
- const SSLClientSocketContext& context) {
- ClientSocketHandle* socket_handle = new ClientSocketHandle();
- socket_handle->set_socket(transport_socket);
- return CreateSSLClientSocket(socket_handle, host_and_port, ssl_config,
- context);
-}
-
// static
ClientSocketFactory* ClientSocketFactory::GetDefaultFactory() {
return g_default_client_socket_factory.Pointer();
diff --git a/chromium/net/socket/client_socket_factory.h b/chromium/net/socket/client_socket_factory.h
index 65022f29234..6cb5949f0b3 100644
--- a/chromium/net/socket/client_socket_factory.h
+++ b/chromium/net/socket/client_socket_factory.h
@@ -8,6 +8,7 @@
#include <string>
#include "base/basictypes.h"
+#include "base/memory/scoped_ptr.h"
#include "net/base/net_export.h"
#include "net/base/net_log.h"
#include "net/base/rand_callback.h"
@@ -32,13 +33,13 @@ class NET_EXPORT ClientSocketFactory {
// |source| is the NetLog::Source for the entity trying to create the socket,
// if it has one.
- virtual DatagramClientSocket* CreateDatagramClientSocket(
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
NetLog* net_log,
const NetLog::Source& source) = 0;
- virtual StreamSocket* CreateTransportClientSocket(
+ virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
const AddressList& addresses,
NetLog* net_log,
const NetLog::Source& source) = 0;
@@ -46,19 +47,12 @@ class NET_EXPORT ClientSocketFactory {
// It is allowed to pass in a |transport_socket| that is not obtained from a
// socket pool. The caller could create a ClientSocketHandle directly and call
// set_socket() on it to set a valid StreamSocket instance.
- virtual SSLClientSocket* CreateSSLClientSocket(
- ClientSocketHandle* transport_socket,
+ virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context) = 0;
- // Deprecated function (http://crbug.com/37810) that takes a StreamSocket.
- virtual SSLClientSocket* CreateSSLClientSocket(
- StreamSocket* transport_socket,
- const HostPortPair& host_and_port,
- const SSLConfig& ssl_config,
- const SSLClientSocketContext& context);
-
// Clears cache used for SSL session resumption.
virtual void ClearSSLSessionCache() = 0;
diff --git a/chromium/net/socket/client_socket_handle.cc b/chromium/net/socket/client_socket_handle.cc
index 3894fa7aa0e..e42e9fcada3 100644
--- a/chromium/net/socket/client_socket_handle.cc
+++ b/chromium/net/socket/client_socket_handle.cc
@@ -18,7 +18,7 @@ namespace net {
ClientSocketHandle::ClientSocketHandle()
: is_initialized_(false),
pool_(NULL),
- layered_pool_(NULL),
+ higher_pool_(NULL),
is_reused_(false),
callback_(base::Bind(&ClientSocketHandle::OnIOComplete,
base::Unretained(this))),
@@ -34,29 +34,34 @@ void ClientSocketHandle::Reset() {
}
void ClientSocketHandle::ResetInternal(bool cancel) {
- if (group_name_.empty()) // Was Init called?
- return;
- if (is_initialized()) {
- // Because of http://crbug.com/37810 we may not have a pool, but have
- // just a raw socket.
- socket_->NetLog().EndEvent(NetLog::TYPE_SOCKET_IN_USE);
- if (pool_)
- // If we've still got a socket, release it back to the ClientSocketPool so
- // it can be deleted or reused.
- pool_->ReleaseSocket(group_name_, release_socket(), pool_id_);
- } else if (cancel) {
- // If we did not get initialized yet, we've got a socket request pending.
- // Cancel it.
- pool_->CancelRequest(group_name_, this);
+ // Was Init called?
+ if (!group_name_.empty()) {
+ // If so, we must have a pool.
+ CHECK(pool_);
+ if (is_initialized()) {
+ if (socket_) {
+ socket_->NetLog().EndEvent(NetLog::TYPE_SOCKET_IN_USE);
+ // Release the socket back to the ClientSocketPool so it can be
+ // deleted or reused.
+ pool_->ReleaseSocket(group_name_, socket_.Pass(), pool_id_);
+ } else {
+ // If the handle has been initialized, we should still have a
+ // socket.
+ NOTREACHED();
+ }
+ } else if (cancel) {
+ // If we did not get initialized yet and we have a socket
+ // request pending, cancel it.
+ pool_->CancelRequest(group_name_, this);
+ }
}
is_initialized_ = false;
+ socket_.reset();
group_name_.clear();
is_reused_ = false;
user_callback_.Reset();
- if (layered_pool_) {
- pool_->RemoveLayeredPool(layered_pool_);
- layered_pool_ = NULL;
- }
+ if (higher_pool_)
+ RemoveHigherLayeredPool(higher_pool_);
pool_ = NULL;
idle_time_ = base::TimeDelta();
init_time_ = base::TimeTicks();
@@ -82,24 +87,30 @@ LoadState ClientSocketHandle::GetLoadState() const {
}
bool ClientSocketHandle::IsPoolStalled() const {
+ if (!pool_)
+ return false;
return pool_->IsStalled();
}
-void ClientSocketHandle::AddLayeredPool(LayeredPool* layered_pool) {
- CHECK(layered_pool);
- CHECK(!layered_pool_);
+void ClientSocketHandle::AddHigherLayeredPool(HigherLayeredPool* higher_pool) {
+ CHECK(higher_pool);
+ CHECK(!higher_pool_);
+ // TODO(mmenke): |pool_| should only be NULL in tests. Maybe stop doing that
+ // so this be be made into a DCHECK, and the same can be done in
+ // RemoveHigherLayeredPool?
if (pool_) {
- pool_->AddLayeredPool(layered_pool);
- layered_pool_ = layered_pool;
+ pool_->AddHigherLayeredPool(higher_pool);
+ higher_pool_ = higher_pool;
}
}
-void ClientSocketHandle::RemoveLayeredPool(LayeredPool* layered_pool) {
- CHECK(layered_pool);
- CHECK(layered_pool_);
+void ClientSocketHandle::RemoveHigherLayeredPool(
+ HigherLayeredPool* higher_pool) {
+ CHECK(higher_pool_);
+ CHECK_EQ(higher_pool_, higher_pool);
if (pool_) {
- pool_->RemoveLayeredPool(layered_pool);
- layered_pool_ = NULL;
+ pool_->RemoveHigherLayeredPool(higher_pool);
+ higher_pool_ = NULL;
}
}
@@ -121,6 +132,10 @@ bool ClientSocketHandle::GetLoadTimingInfo(
return true;
}
+void ClientSocketHandle::SetSocket(scoped_ptr<StreamSocket> s) {
+ socket_ = s.Pass();
+}
+
void ClientSocketHandle::OnIOComplete(int result) {
CompletionCallback callback = user_callback_;
user_callback_.Reset();
@@ -128,6 +143,10 @@ void ClientSocketHandle::OnIOComplete(int result) {
callback.Run(result);
}
+scoped_ptr<StreamSocket> ClientSocketHandle::PassSocket() {
+ return socket_.Pass();
+}
+
void ClientSocketHandle::HandleInitCompletion(int result) {
CHECK_NE(ERR_IO_PENDING, result);
ClientSocketPoolHistograms* histograms = pool_->histograms();
diff --git a/chromium/net/socket/client_socket_handle.h b/chromium/net/socket/client_socket_handle.h
index 7d5588a145b..30b7c03e9dc 100644
--- a/chromium/net/socket/client_socket_handle.h
+++ b/chromium/net/socket/client_socket_handle.h
@@ -70,9 +70,9 @@ class NET_EXPORT ClientSocketHandle {
//
// Profiling information for the request is saved to |net_log| if non-NULL.
//
- template <typename SocketParams, typename PoolType>
+ template <typename PoolType>
int Init(const std::string& group_name,
- const scoped_refptr<SocketParams>& socket_params,
+ const scoped_refptr<typename PoolType::SocketParams>& socket_params,
RequestPriority priority,
const CompletionCallback& callback,
PoolType* pool,
@@ -94,9 +94,15 @@ class NET_EXPORT ClientSocketHandle {
bool IsPoolStalled() const;
- void AddLayeredPool(LayeredPool* layered_pool);
+ // Adds a higher layered pool on top of the socket pool that |socket_| belongs
+ // to. At most one higher layered pool can be added to a
+ // ClientSocketHandle at a time. On destruction or reset, automatically
+ // removes the higher pool if RemoveHigherLayeredPool has not been called.
+ void AddHigherLayeredPool(HigherLayeredPool* higher_pool);
- void RemoveLayeredPool(LayeredPool* layered_pool);
+ // Removes a higher layered pool from the socket pool that |socket_| belongs
+ // to. |higher_pool| must have been added by the above function.
+ void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool);
// Returns true when Init() has completed successfully.
bool is_initialized() const { return is_initialized_; }
@@ -116,8 +122,11 @@ class NET_EXPORT ClientSocketHandle {
LoadTimingInfo* load_timing_info) const;
// Used by ClientSocketPool to initialize the ClientSocketHandle.
+ //
+ // SetSocket() may also be used if this handle is used as simply for
+ // socket storage (e.g., http://crbug.com/37810).
+ void SetSocket(scoped_ptr<StreamSocket> s);
void set_is_reused(bool is_reused) { is_reused_ = is_reused; }
- void set_socket(StreamSocket* s) { socket_.reset(s); }
void set_idle_time(base::TimeDelta idle_time) { idle_time_ = idle_time; }
void set_pool_id(int id) { pool_id_ = id; }
void set_is_ssl_error(bool is_ssl_error) { is_ssl_error_ = is_ssl_error; }
@@ -143,11 +152,15 @@ class NET_EXPORT ClientSocketHandle {
return pending_http_proxy_connection_.release();
}
+ StreamSocket* socket() { return socket_.get(); }
+
+ // SetSocket() must be called with a new socket before this handle
+ // is destroyed if is_initialized() is true.
+ scoped_ptr<StreamSocket> PassSocket();
+
// These may only be used if is_initialized() is true.
const std::string& group_name() const { return group_name_; }
int id() const { return pool_id_; }
- StreamSocket* socket() { return socket_.get(); }
- StreamSocket* release_socket() { return socket_.release(); }
bool is_reused() const { return is_reused_; }
base::TimeDelta idle_time() const { return idle_time_; }
SocketReuseType reuse_type() const {
@@ -184,7 +197,7 @@ class NET_EXPORT ClientSocketHandle {
bool is_initialized_;
ClientSocketPool* pool_;
- LayeredPool* layered_pool_;
+ HigherLayeredPool* higher_pool_;
scoped_ptr<StreamSocket> socket_;
std::string group_name_;
bool is_reused_;
@@ -207,20 +220,17 @@ class NET_EXPORT ClientSocketHandle {
};
// Template function implementation:
-template <typename SocketParams, typename PoolType>
-int ClientSocketHandle::Init(const std::string& group_name,
- const scoped_refptr<SocketParams>& socket_params,
- RequestPriority priority,
- const CompletionCallback& callback,
- PoolType* pool,
- const BoundNetLog& net_log) {
+template <typename PoolType>
+int ClientSocketHandle::Init(
+ const std::string& group_name,
+ const scoped_refptr<typename PoolType::SocketParams>& socket_params,
+ RequestPriority priority,
+ const CompletionCallback& callback,
+ PoolType* pool,
+ const BoundNetLog& net_log) {
requesting_source_ = net_log.source();
CHECK(!group_name.empty());
- // Note that this will result in a compile error if the SocketParams has not
- // been registered for the PoolType via REGISTER_SOCKET_PARAMS_FOR_POOL
- // (defined in client_socket_pool.h).
- CheckIsValidSocketParamsForPool<PoolType, SocketParams>();
ResetInternal(true);
ResetErrorState();
pool_ = pool;
diff --git a/chromium/net/socket/client_socket_pool.h b/chromium/net/socket/client_socket_pool.h
index 7cb9a7ebc2e..715cddb94d4 100644
--- a/chromium/net/socket/client_socket_pool.h
+++ b/chromium/net/socket/client_socket_pool.h
@@ -10,6 +10,7 @@
#include "base/basictypes.h"
#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
#include "base/template_util.h"
#include "base/time/time.h"
#include "net/base/completion_callback.h"
@@ -30,20 +31,43 @@ class StreamSocket;
// ClientSocketPools are layered. This defines an interface for lower level
// socket pools to communicate with higher layer pools.
-class NET_EXPORT LayeredPool {
+class NET_EXPORT HigherLayeredPool {
public:
- virtual ~LayeredPool() {};
+ virtual ~HigherLayeredPool() {}
- // Instructs the LayeredPool to close an idle connection. Return true if one
- // was closed.
+ // Instructs the HigherLayeredPool to close an idle connection. Return true if
+ // one was closed. Closing an idle connection will call into the lower layer
+ // pool it came from, so must be careful of re-entrancy when using this.
virtual bool CloseOneIdleConnection() = 0;
};
+// ClientSocketPools are layered. This defines an interface for higher level
+// socket pools to communicate with lower layer pools.
+class NET_EXPORT LowerLayeredPool {
+ public:
+ virtual ~LowerLayeredPool() {}
+
+ // Returns true if a there is currently a request blocked on the per-pool
+ // (not per-host) max socket limit, either in this pool, or one that it is
+ // layered on top of.
+ virtual bool IsStalled() const = 0;
+
+ // Called to add or remove a higher layer pool on top of |this|. A higher
+ // layer pool may be added at most once to |this|, and must be removed prior
+ // to destruction of |this|.
+ virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) = 0;
+ virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) = 0;
+};
+
// A ClientSocketPool is used to restrict the number of sockets open at a time.
// It also maintains a list of idle persistent sockets.
//
-class NET_EXPORT ClientSocketPool {
+class NET_EXPORT ClientSocketPool : public LowerLayeredPool {
public:
+ // Subclasses must also have an inner class SocketParams which is
+ // the type for the |params| argument in RequestSocket() and
+ // RequestSockets() below.
+
// Requests a connected socket for a group_name.
//
// There are five possible results from calling this function:
@@ -111,7 +135,7 @@ class NET_EXPORT ClientSocketPool {
// change when it flushes, so it can use this |id| to discard sockets with
// mismatched ids.
virtual void ReleaseSocket(const std::string& group_name,
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
int id) = 0;
// This flushes all state from the ClientSocketPool. This means that all
@@ -121,10 +145,6 @@ class NET_EXPORT ClientSocketPool {
// Does not flush any pools wrapped by |this|.
virtual void FlushWithError(int error) = 0;
- // Returns true if a there is currently a request blocked on the
- // per-pool (not per-host) max socket limit.
- virtual bool IsStalled() const = 0;
-
// Called to close any idle connections held by the connection manager.
virtual void CloseIdleSockets() = 0;
@@ -138,12 +158,6 @@ class NET_EXPORT ClientSocketPool {
virtual LoadState GetLoadState(const std::string& group_name,
const ClientSocketHandle* handle) const = 0;
- // Adds a LayeredPool on top of |this|.
- virtual void AddLayeredPool(LayeredPool* layered_pool) = 0;
-
- // Removes a LayeredPool from |this|.
- virtual void RemoveLayeredPool(LayeredPool* layered_pool) = 0;
-
// Retrieves information on the current state of the pool as a
// DictionaryValue. Caller takes possession of the returned value.
// If |include_nested_pools| is true, the states of any nested
@@ -177,41 +191,13 @@ class NET_EXPORT ClientSocketPool {
DISALLOW_COPY_AND_ASSIGN(ClientSocketPool);
};
-// ClientSocketPool subclasses should indicate valid SocketParams via the
-// REGISTER_SOCKET_PARAMS_FOR_POOL macro below. By default, any given
-// <PoolType,SocketParams> pair will have its SocketParamsTrait inherit from
-// base::false_type, but REGISTER_SOCKET_PARAMS_FOR_POOL will specialize that
-// pairing to inherit from base::true_type. This provides compile time
-// verification that the correct SocketParams type is used with the appropriate
-// PoolType.
-template <typename PoolType, typename SocketParams>
-struct SocketParamTraits : public base::false_type {
-};
-
-template <typename PoolType, typename SocketParams>
-void CheckIsValidSocketParamsForPool() {
- COMPILE_ASSERT(!base::is_pointer<scoped_refptr<SocketParams> >::value,
- socket_params_cannot_be_pointer);
- COMPILE_ASSERT((SocketParamTraits<PoolType,
- scoped_refptr<SocketParams> >::value),
- invalid_socket_params_for_pool);
-}
-
-// Provides an empty definition for CheckIsValidSocketParamsForPool() which
-// should be optimized out by the compiler.
-#define REGISTER_SOCKET_PARAMS_FOR_POOL(pool_type, socket_params) \
-template<> \
-struct SocketParamTraits<pool_type, scoped_refptr<socket_params> > \
- : public base::true_type { \
-}
-
-template <typename PoolType, typename SocketParams>
-void RequestSocketsForPool(PoolType* pool,
- const std::string& group_name,
- const scoped_refptr<SocketParams>& params,
- int num_sockets,
- const BoundNetLog& net_log) {
- CheckIsValidSocketParamsForPool<PoolType, SocketParams>();
+template <typename PoolType>
+void RequestSocketsForPool(
+ PoolType* pool,
+ const std::string& group_name,
+ const scoped_refptr<typename PoolType::SocketParams>& params,
+ int num_sockets,
+ const BoundNetLog& net_log) {
pool->RequestSockets(group_name, &params, num_sockets, net_log);
}
diff --git a/chromium/net/socket/client_socket_pool_base.cc b/chromium/net/socket/client_socket_pool_base.cc
index 3332e04a171..cec7956a0ee 100644
--- a/chromium/net/socket/client_socket_pool_base.cc
+++ b/chromium/net/socket/client_socket_pool_base.cc
@@ -65,10 +65,12 @@ int CompareEffectiveRequestPriority(
ConnectJob::ConnectJob(const std::string& group_name,
base::TimeDelta timeout_duration,
+ RequestPriority priority,
Delegate* delegate,
const BoundNetLog& net_log)
: group_name_(group_name),
timeout_duration_(timeout_duration),
+ priority_(priority),
delegate_(delegate),
net_log_(net_log),
idle_(true) {
@@ -82,6 +84,10 @@ ConnectJob::~ConnectJob() {
net_log().EndEvent(NetLog::TYPE_SOCKET_POOL_CONNECT_JOB);
}
+scoped_ptr<StreamSocket> ConnectJob::PassSocket() {
+ return socket_.Pass();
+}
+
int ConnectJob::Connect() {
if (timeout_duration_ != base::TimeDelta())
timer_.Start(FROM_HERE, timeout_duration_, this, &ConnectJob::OnTimeout);
@@ -100,16 +106,16 @@ int ConnectJob::Connect() {
return rv;
}
-void ConnectJob::set_socket(StreamSocket* socket) {
+void ConnectJob::SetSocket(scoped_ptr<StreamSocket> socket) {
if (socket) {
net_log().AddEvent(NetLog::TYPE_CONNECT_JOB_SET_SOCKET,
socket->NetLog().source().ToEventParametersCallback());
}
- socket_.reset(socket);
+ socket_ = socket.Pass();
}
void ConnectJob::NotifyDelegateOfCompletion(int rv) {
- // The delegate will delete |this|.
+ // The delegate will own |this|.
Delegate* delegate = delegate_;
delegate_ = NULL;
@@ -135,7 +141,7 @@ void ConnectJob::LogConnectCompletion(int net_error) {
void ConnectJob::OnTimeout() {
// Make sure the socket is NULL before calling into |delegate|.
- set_socket(NULL);
+ SetSocket(scoped_ptr<StreamSocket>());
net_log_.AddEvent(NetLog::TYPE_SOCKET_POOL_CONNECT_JOB_TIMED_OUT);
@@ -161,6 +167,7 @@ ClientSocketPoolBaseHelper::Request::Request(
ClientSocketPoolBaseHelper::Request::~Request() {}
ClientSocketPoolBaseHelper::ClientSocketPoolBaseHelper(
+ HigherLayeredPool* pool,
int max_sockets,
int max_sockets_per_group,
base::TimeDelta unused_idle_socket_timeout,
@@ -177,6 +184,7 @@ ClientSocketPoolBaseHelper::ClientSocketPoolBaseHelper(
connect_job_factory_(connect_job_factory),
connect_backup_jobs_enabled_(false),
pool_generation_number_(0),
+ pool_(pool),
weak_factory_(this) {
DCHECK_LE(0, max_sockets_per_group);
DCHECK_LE(max_sockets_per_group, max_sockets);
@@ -192,9 +200,16 @@ ClientSocketPoolBaseHelper::~ClientSocketPoolBaseHelper() {
DCHECK(group_map_.empty());
DCHECK(pending_callback_map_.empty());
DCHECK_EQ(0, connecting_socket_count_);
- CHECK(higher_layer_pools_.empty());
+ CHECK(higher_pools_.empty());
NetworkChangeNotifier::RemoveIPAddressObserver(this);
+
+ // Remove from lower layer pools.
+ for (std::set<LowerLayeredPool*>::iterator it = lower_pools_.begin();
+ it != lower_pools_.end();
+ ++it) {
+ (*it)->RemoveHigherLayeredPool(pool_);
+ }
}
ClientSocketPoolBaseHelper::CallbackResultPair::CallbackResultPair()
@@ -209,46 +224,59 @@ ClientSocketPoolBaseHelper::CallbackResultPair::CallbackResultPair(
ClientSocketPoolBaseHelper::CallbackResultPair::~CallbackResultPair() {}
-// static
-void ClientSocketPoolBaseHelper::InsertRequestIntoQueue(
- const Request* r, RequestQueue* pending_requests) {
- RequestQueue::iterator it = pending_requests->begin();
- // TODO(mmenke): Should the network stack require requests with
- // |ignore_limits| have the highest priority?
- while (it != pending_requests->end() &&
- CompareEffectiveRequestPriority(*r, *(*it)) <= 0) {
- ++it;
+bool ClientSocketPoolBaseHelper::IsStalled() const {
+ // If a lower layer pool is stalled, consider |this| stalled as well.
+ for (std::set<LowerLayeredPool*>::const_iterator it = lower_pools_.begin();
+ it != lower_pools_.end();
+ ++it) {
+ if ((*it)->IsStalled())
+ return true;
+ }
+
+ // If fewer than |max_sockets_| are in use, then clearly |this| is not
+ // stalled.
+ if ((handed_out_socket_count_ + connecting_socket_count_) < max_sockets_)
+ return false;
+ // So in order to be stalled, |this| must be using at least |max_sockets_| AND
+ // |this| must have a request that is actually stalled on the global socket
+ // limit. To find such a request, look for a group that has more requests
+ // than jobs AND where the number of sockets is less than
+ // |max_sockets_per_group_|. (If the number of sockets is equal to
+ // |max_sockets_per_group_|, then the request is stalled on the group limit,
+ // which does not count.)
+ for (GroupMap::const_iterator it = group_map_.begin();
+ it != group_map_.end(); ++it) {
+ if (it->second->IsStalledOnPoolMaxSockets(max_sockets_per_group_))
+ return true;
}
- pending_requests->insert(it, r);
+ return false;
}
-// static
-const ClientSocketPoolBaseHelper::Request*
-ClientSocketPoolBaseHelper::RemoveRequestFromQueue(
- const RequestQueue::iterator& it, Group* group) {
- const Request* req = *it;
- group->mutable_pending_requests()->erase(it);
- // If there are no more requests, we kill the backup timer.
- if (group->pending_requests().empty())
- group->CleanupBackupJob();
- return req;
+void ClientSocketPoolBaseHelper::AddLowerLayeredPool(
+ LowerLayeredPool* lower_pool) {
+ DCHECK(pool_);
+ CHECK(!ContainsKey(lower_pools_, lower_pool));
+ lower_pools_.insert(lower_pool);
+ lower_pool->AddHigherLayeredPool(pool_);
}
-void ClientSocketPoolBaseHelper::AddLayeredPool(LayeredPool* pool) {
- CHECK(pool);
- CHECK(!ContainsKey(higher_layer_pools_, pool));
- higher_layer_pools_.insert(pool);
+void ClientSocketPoolBaseHelper::AddHigherLayeredPool(
+ HigherLayeredPool* higher_pool) {
+ CHECK(higher_pool);
+ CHECK(!ContainsKey(higher_pools_, higher_pool));
+ higher_pools_.insert(higher_pool);
}
-void ClientSocketPoolBaseHelper::RemoveLayeredPool(LayeredPool* pool) {
- CHECK(pool);
- CHECK(ContainsKey(higher_layer_pools_, pool));
- higher_layer_pools_.erase(pool);
+void ClientSocketPoolBaseHelper::RemoveHigherLayeredPool(
+ HigherLayeredPool* higher_pool) {
+ CHECK(higher_pool);
+ CHECK(ContainsKey(higher_pools_, higher_pool));
+ higher_pools_.erase(higher_pool);
}
int ClientSocketPoolBaseHelper::RequestSocket(
const std::string& group_name,
- const Request* request) {
+ scoped_ptr<const Request> request) {
CHECK(!request->callback().is_null());
CHECK(request->handle());
@@ -259,13 +287,13 @@ int ClientSocketPoolBaseHelper::RequestSocket(
request->net_log().BeginEvent(NetLog::TYPE_SOCKET_POOL);
Group* group = GetOrCreateGroup(group_name);
- int rv = RequestSocketInternal(group_name, request);
+ int rv = RequestSocketInternal(group_name, *request);
if (rv != ERR_IO_PENDING) {
request->net_log().EndEventWithNetErrorCode(NetLog::TYPE_SOCKET_POOL, rv);
CHECK(!request->handle()->is_initialized());
- delete request;
+ request.reset();
} else {
- InsertRequestIntoQueue(request, group->mutable_pending_requests());
+ group->InsertPendingRequest(request.Pass());
// Have to do this asynchronously, as closing sockets in higher level pools
// call back in to |this|, which will cause all sorts of fun and exciting
// re-entrancy issues if the socket pool is doing something else at the
@@ -309,7 +337,7 @@ void ClientSocketPoolBaseHelper::RequestSockets(
for (int num_iterations_left = num_sockets;
group->NumActiveSocketSlots() < num_sockets &&
num_iterations_left > 0 ; num_iterations_left--) {
- rv = RequestSocketInternal(group_name, &request);
+ rv = RequestSocketInternal(group_name, request);
if (rv < 0 && rv != ERR_IO_PENDING) {
// We're encountering a synchronous error. Give up.
if (!ContainsKey(group_map_, group_name))
@@ -336,12 +364,12 @@ void ClientSocketPoolBaseHelper::RequestSockets(
int ClientSocketPoolBaseHelper::RequestSocketInternal(
const std::string& group_name,
- const Request* request) {
- ClientSocketHandle* const handle = request->handle();
+ const Request& request) {
+ ClientSocketHandle* const handle = request.handle();
const bool preconnecting = !handle;
Group* group = GetOrCreateGroup(group_name);
- if (!(request->flags() & NO_IDLE_SOCKETS)) {
+ if (!(request.flags() & NO_IDLE_SOCKETS)) {
// Try to reuse a socket.
if (AssignIdleSocketToRequest(request, group))
return OK;
@@ -355,17 +383,17 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal(
// Can we make another active socket now?
if (!group->HasAvailableSocketSlot(max_sockets_per_group_) &&
- !request->ignore_limits()) {
+ !request.ignore_limits()) {
// TODO(willchan): Consider whether or not we need to close a socket in a
// higher layered group. I don't think this makes sense since we would just
// reuse that socket then if we needed one and wouldn't make it down to this
// layer.
- request->net_log().AddEvent(
+ request.net_log().AddEvent(
NetLog::TYPE_SOCKET_POOL_STALLED_MAX_SOCKETS_PER_GROUP);
return ERR_IO_PENDING;
}
- if (ReachedMaxSocketsLimit() && !request->ignore_limits()) {
+ if (ReachedMaxSocketsLimit() && !request.ignore_limits()) {
// NOTE(mmenke): Wonder if we really need different code for each case
// here. Only reason for them now seems to be preconnects.
if (idle_socket_count() > 0) {
@@ -378,7 +406,7 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal(
} else {
// We could check if we really have a stalled group here, but it requires
// a scan of all groups, so just flip a flag here, and do the check later.
- request->net_log().AddEvent(NetLog::TYPE_SOCKET_POOL_STALLED_MAX_SOCKETS);
+ request.net_log().AddEvent(NetLog::TYPE_SOCKET_POOL_STALLED_MAX_SOCKETS);
return ERR_IO_PENDING;
}
}
@@ -386,17 +414,17 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal(
// We couldn't find a socket to reuse, and there's space to allocate one,
// so allocate and connect a new one.
scoped_ptr<ConnectJob> connect_job(
- connect_job_factory_->NewConnectJob(group_name, *request, this));
+ connect_job_factory_->NewConnectJob(group_name, request, this));
int rv = connect_job->Connect();
if (rv == OK) {
LogBoundConnectJobToRequest(connect_job->net_log().source(), request);
if (!preconnecting) {
- HandOutSocket(connect_job->ReleaseSocket(), false /* not reused */,
+ HandOutSocket(connect_job->PassSocket(), false /* not reused */,
connect_job->connect_timing(), handle, base::TimeDelta(),
- group, request->net_log());
+ group, request.net_log());
} else {
- AddIdleSocket(connect_job->ReleaseSocket(), group);
+ AddIdleSocket(connect_job->PassSocket(), group);
}
} else if (rv == ERR_IO_PENDING) {
// If we don't have any sockets in this group, set a timer for potentially
@@ -409,19 +437,19 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal(
connecting_socket_count_++;
- group->AddJob(connect_job.release(), preconnecting);
+ group->AddJob(connect_job.Pass(), preconnecting);
} else {
LogBoundConnectJobToRequest(connect_job->net_log().source(), request);
- StreamSocket* error_socket = NULL;
+ scoped_ptr<StreamSocket> error_socket;
if (!preconnecting) {
DCHECK(handle);
connect_job->GetAdditionalErrorState(handle);
- error_socket = connect_job->ReleaseSocket();
+ error_socket = connect_job->PassSocket();
}
if (error_socket) {
- HandOutSocket(error_socket, false /* not reused */,
+ HandOutSocket(error_socket.Pass(), false /* not reused */,
connect_job->connect_timing(), handle, base::TimeDelta(),
- group, request->net_log());
+ group, request.net_log());
} else if (group->IsEmpty()) {
RemoveGroup(group_name);
}
@@ -431,7 +459,7 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal(
}
bool ClientSocketPoolBaseHelper::AssignIdleSocketToRequest(
- const Request* request, Group* group) {
+ const Request& request, Group* group) {
std::list<IdleSocket>* idle_sockets = group->mutable_idle_sockets();
std::list<IdleSocket>::iterator idle_socket_it = idle_sockets->end();
@@ -469,13 +497,13 @@ bool ClientSocketPoolBaseHelper::AssignIdleSocketToRequest(
IdleSocket idle_socket = *idle_socket_it;
idle_sockets->erase(idle_socket_it);
HandOutSocket(
- idle_socket.socket,
+ scoped_ptr<StreamSocket>(idle_socket.socket),
idle_socket.socket->WasEverUsed(),
LoadTimingInfo::ConnectTiming(),
- request->handle(),
+ request.handle(),
idle_time,
group,
- request->net_log());
+ request.net_log());
return true;
}
@@ -484,9 +512,9 @@ bool ClientSocketPoolBaseHelper::AssignIdleSocketToRequest(
// static
void ClientSocketPoolBaseHelper::LogBoundConnectJobToRequest(
- const NetLog::Source& connect_job_source, const Request* request) {
- request->net_log().AddEvent(NetLog::TYPE_SOCKET_POOL_BOUND_TO_CONNECT_JOB,
- connect_job_source.ToEventParametersCallback());
+ const NetLog::Source& connect_job_source, const Request& request) {
+ request.net_log().AddEvent(NetLog::TYPE_SOCKET_POOL_BOUND_TO_CONNECT_JOB,
+ connect_job_source.ToEventParametersCallback());
}
void ClientSocketPoolBaseHelper::CancelRequest(
@@ -495,11 +523,11 @@ void ClientSocketPoolBaseHelper::CancelRequest(
if (callback_it != pending_callback_map_.end()) {
int result = callback_it->second.result;
pending_callback_map_.erase(callback_it);
- StreamSocket* socket = handle->release_socket();
+ scoped_ptr<StreamSocket> socket = handle->PassSocket();
if (socket) {
if (result != OK)
socket->Disconnect();
- ReleaseSocket(handle->group_name(), socket, handle->id());
+ ReleaseSocket(handle->group_name(), socket.Pass(), handle->id());
}
return;
}
@@ -509,21 +537,18 @@ void ClientSocketPoolBaseHelper::CancelRequest(
Group* group = GetOrCreateGroup(group_name);
// Search pending_requests for matching handle.
- RequestQueue::iterator it = group->mutable_pending_requests()->begin();
- for (; it != group->pending_requests().end(); ++it) {
- if ((*it)->handle() == handle) {
- scoped_ptr<const Request> req(RemoveRequestFromQueue(it, group));
- req->net_log().AddEvent(NetLog::TYPE_CANCELLED);
- req->net_log().EndEvent(NetLog::TYPE_SOCKET_POOL);
-
- // We let the job run, unless we're at the socket limit and there is
- // not another request waiting on the job.
- if (group->jobs().size() > group->pending_requests().size() &&
- ReachedMaxSocketsLimit()) {
- RemoveConnectJob(*group->jobs().begin(), group);
- CheckForStalledSocketGroups();
- }
- break;
+ scoped_ptr<const Request> request =
+ group->FindAndRemovePendingRequest(handle);
+ if (request) {
+ request->net_log().AddEvent(NetLog::TYPE_CANCELLED);
+ request->net_log().EndEvent(NetLog::TYPE_SOCKET_POOL);
+
+ // We let the job run, unless we're at the socket limit and there is
+ // not another request waiting on the job.
+ if (group->jobs().size() > group->pending_request_count() &&
+ ReachedMaxSocketsLimit()) {
+ RemoveConnectJob(*group->jobs().begin(), group);
+ CheckForStalledSocketGroups();
}
}
}
@@ -560,16 +585,7 @@ LoadState ClientSocketPoolBaseHelper::GetLoadState(
// Can't use operator[] since it is non-const.
const Group& group = *group_map_.find(group_name)->second;
- // Search the first group.jobs().size() |pending_requests| for |handle|.
- // If it's farther back in the deque than that, it doesn't have a
- // corresponding ConnectJob.
- size_t connect_jobs = group.jobs().size();
- RequestQueue::const_iterator it = group.pending_requests().begin();
- for (size_t i = 0; it != group.pending_requests().end() && i < connect_jobs;
- ++it, ++i) {
- if ((*it)->handle() != handle)
- continue;
-
+ if (group.HasConnectJobForHandle(handle)) {
// Just return the state of the farthest along ConnectJob for the first
// group.jobs().size() pending requests.
LoadState max_state = LOAD_STATE_IDLE;
@@ -607,8 +623,8 @@ base::DictionaryValue* ClientSocketPoolBaseHelper::GetInfoAsValue(
base::DictionaryValue* group_dict = new base::DictionaryValue();
group_dict->SetInteger("pending_request_count",
- group->pending_requests().size());
- if (!group->pending_requests().empty()) {
+ group->pending_request_count());
+ if (group->has_pending_requests()) {
group_dict->SetInteger("top_pending_priority",
group->TopPendingPriority());
}
@@ -756,7 +772,7 @@ void ClientSocketPoolBaseHelper::StartIdleSocketTimer() {
}
void ClientSocketPoolBaseHelper::ReleaseSocket(const std::string& group_name,
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
int id) {
GroupMap::iterator i = group_map_.find(group_name);
CHECK(i != group_map_.end());
@@ -773,10 +789,10 @@ void ClientSocketPoolBaseHelper::ReleaseSocket(const std::string& group_name,
id == pool_generation_number_;
if (can_reuse) {
// Add it to the idle list.
- AddIdleSocket(socket, group);
+ AddIdleSocket(socket.Pass(), group);
OnAvailableSocketSlot(group_name, group);
} else {
- delete socket;
+ socket.reset();
}
CheckForStalledSocketGroups();
@@ -786,8 +802,18 @@ void ClientSocketPoolBaseHelper::CheckForStalledSocketGroups() {
// If we have idle sockets, see if we can give one to the top-stalled group.
std::string top_group_name;
Group* top_group = NULL;
- if (!FindTopStalledGroup(&top_group, &top_group_name))
+ if (!FindTopStalledGroup(&top_group, &top_group_name)) {
+ // There may still be a stalled group in a lower level pool.
+ for (std::set<LowerLayeredPool*>::iterator it = lower_pools_.begin();
+ it != lower_pools_.end();
+ ++it) {
+ if ((*it)->IsStalled()) {
+ CloseOneIdleSocket();
+ break;
+ }
+ }
return;
+ }
if (ReachedMaxSocketsLimit()) {
if (idle_socket_count() > 0) {
@@ -820,8 +846,7 @@ bool ClientSocketPoolBaseHelper::FindTopStalledGroup(
for (GroupMap::const_iterator i = group_map_.begin();
i != group_map_.end(); ++i) {
Group* curr_group = i->second;
- const RequestQueue& queue = curr_group->pending_requests();
- if (queue.empty())
+ if (!curr_group->has_pending_requests())
continue;
if (curr_group->IsStalledOnPoolMaxSockets(max_sockets_per_group_)) {
if (!group)
@@ -854,27 +879,29 @@ void ClientSocketPoolBaseHelper::OnConnectJobComplete(
CHECK(group_it != group_map_.end());
Group* group = group_it->second;
- scoped_ptr<StreamSocket> socket(job->ReleaseSocket());
+ scoped_ptr<StreamSocket> socket = job->PassSocket();
// Copies of these are needed because |job| may be deleted before they are
// accessed.
BoundNetLog job_log = job->net_log();
LoadTimingInfo::ConnectTiming connect_timing = job->connect_timing();
+ // RemoveConnectJob(job, _) must be called by all branches below;
+ // otherwise, |job| will be leaked.
+
if (result == OK) {
DCHECK(socket.get());
RemoveConnectJob(job, group);
- if (!group->pending_requests().empty()) {
- scoped_ptr<const Request> r(RemoveRequestFromQueue(
- group->mutable_pending_requests()->begin(), group));
- LogBoundConnectJobToRequest(job_log.source(), r.get());
+ scoped_ptr<const Request> request = group->PopNextPendingRequest();
+ if (request) {
+ LogBoundConnectJobToRequest(job_log.source(), *request);
HandOutSocket(
- socket.release(), false /* unused socket */, connect_timing,
- r->handle(), base::TimeDelta(), group, r->net_log());
- r->net_log().EndEvent(NetLog::TYPE_SOCKET_POOL);
- InvokeUserCallbackLater(r->handle(), r->callback(), result);
+ socket.Pass(), false /* unused socket */, connect_timing,
+ request->handle(), base::TimeDelta(), group, request->net_log());
+ request->net_log().EndEvent(NetLog::TYPE_SOCKET_POOL);
+ InvokeUserCallbackLater(request->handle(), request->callback(), result);
} else {
- AddIdleSocket(socket.release(), group);
+ AddIdleSocket(socket.Pass(), group);
OnAvailableSocketSlot(group_name, group);
CheckForStalledSocketGroups();
}
@@ -882,20 +909,20 @@ void ClientSocketPoolBaseHelper::OnConnectJobComplete(
// If we got a socket, it must contain error information so pass that
// up so that the caller can retrieve it.
bool handed_out_socket = false;
- if (!group->pending_requests().empty()) {
- scoped_ptr<const Request> r(RemoveRequestFromQueue(
- group->mutable_pending_requests()->begin(), group));
- LogBoundConnectJobToRequest(job_log.source(), r.get());
- job->GetAdditionalErrorState(r->handle());
+ scoped_ptr<const Request> request = group->PopNextPendingRequest();
+ if (request) {
+ LogBoundConnectJobToRequest(job_log.source(), *request);
+ job->GetAdditionalErrorState(request->handle());
RemoveConnectJob(job, group);
if (socket.get()) {
handed_out_socket = true;
- HandOutSocket(socket.release(), false /* unused socket */,
- connect_timing, r->handle(), base::TimeDelta(), group,
- r->net_log());
+ HandOutSocket(socket.Pass(), false /* unused socket */,
+ connect_timing, request->handle(), base::TimeDelta(),
+ group, request->net_log());
}
- r->net_log().EndEventWithNetErrorCode(NetLog::TYPE_SOCKET_POOL, result);
- InvokeUserCallbackLater(r->handle(), r->callback(), result);
+ request->net_log().EndEventWithNetErrorCode(
+ NetLog::TYPE_SOCKET_POOL, result);
+ InvokeUserCallbackLater(request->handle(), request->callback(), result);
} else {
RemoveConnectJob(job, group);
}
@@ -917,59 +944,38 @@ void ClientSocketPoolBaseHelper::FlushWithError(int error) {
CancelAllRequestsWithError(error);
}
-bool ClientSocketPoolBaseHelper::IsStalled() const {
- // If we are not using |max_sockets_|, then clearly we are not stalled
- if ((handed_out_socket_count_ + connecting_socket_count_) < max_sockets_)
- return false;
- // So in order to be stalled we need to be using |max_sockets_| AND
- // we need to have a request that is actually stalled on the global
- // socket limit. To find such a request, we look for a group that
- // a has more requests that jobs AND where the number of jobs is less
- // than |max_sockets_per_group_|. (If the number of jobs is equal to
- // |max_sockets_per_group_|, then the request is stalled on the group,
- // which does not count.)
- for (GroupMap::const_iterator it = group_map_.begin();
- it != group_map_.end(); ++it) {
- if (it->second->IsStalledOnPoolMaxSockets(max_sockets_per_group_))
- return true;
- }
- return false;
-}
-
void ClientSocketPoolBaseHelper::RemoveConnectJob(ConnectJob* job,
Group* group) {
CHECK_GT(connecting_socket_count_, 0);
connecting_socket_count_--;
DCHECK(group);
- DCHECK(ContainsKey(group->jobs(), job));
group->RemoveJob(job);
// If we've got no more jobs for this group, then we no longer need a
// backup job either.
if (group->jobs().empty())
group->CleanupBackupJob();
-
- DCHECK(job);
- delete job;
}
void ClientSocketPoolBaseHelper::OnAvailableSocketSlot(
const std::string& group_name, Group* group) {
DCHECK(ContainsKey(group_map_, group_name));
- if (group->IsEmpty())
+ if (group->IsEmpty()) {
RemoveGroup(group_name);
- else if (!group->pending_requests().empty())
+ } else if (group->has_pending_requests()) {
ProcessPendingRequest(group_name, group);
+ }
}
void ClientSocketPoolBaseHelper::ProcessPendingRequest(
const std::string& group_name, Group* group) {
- int rv = RequestSocketInternal(group_name,
- *group->pending_requests().begin());
+ const Request* next_request = group->GetNextPendingRequest();
+ DCHECK(next_request);
+ int rv = RequestSocketInternal(group_name, *next_request);
if (rv != ERR_IO_PENDING) {
- scoped_ptr<const Request> request(RemoveRequestFromQueue(
- group->mutable_pending_requests()->begin(), group));
+ scoped_ptr<const Request> request = group->PopNextPendingRequest();
+ DCHECK(request);
if (group->IsEmpty())
RemoveGroup(group_name);
@@ -979,7 +985,7 @@ void ClientSocketPoolBaseHelper::ProcessPendingRequest(
}
void ClientSocketPoolBaseHelper::HandOutSocket(
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
bool reused,
const LoadTimingInfo::ConnectTiming& connect_timing,
ClientSocketHandle* handle,
@@ -987,7 +993,7 @@ void ClientSocketPoolBaseHelper::HandOutSocket(
Group* group,
const BoundNetLog& net_log) {
DCHECK(socket);
- handle->set_socket(socket);
+ handle->SetSocket(socket.Pass());
handle->set_is_reused(reused);
handle->set_idle_time(idle_time);
handle->set_pool_id(pool_generation_number_);
@@ -1000,18 +1006,20 @@ void ClientSocketPoolBaseHelper::HandOutSocket(
"idle_ms", static_cast<int>(idle_time.InMilliseconds())));
}
- net_log.AddEvent(NetLog::TYPE_SOCKET_POOL_BOUND_TO_SOCKET,
- socket->NetLog().source().ToEventParametersCallback());
+ net_log.AddEvent(
+ NetLog::TYPE_SOCKET_POOL_BOUND_TO_SOCKET,
+ handle->socket()->NetLog().source().ToEventParametersCallback());
handed_out_socket_count_++;
group->IncrementActiveSocketCount();
}
void ClientSocketPoolBaseHelper::AddIdleSocket(
- StreamSocket* socket, Group* group) {
+ scoped_ptr<StreamSocket> socket,
+ Group* group) {
DCHECK(socket);
IdleSocket idle_socket;
- idle_socket.socket = socket;
+ idle_socket.socket = socket.release();
idle_socket.start_time = base::TimeTicks::Now();
group->mutable_idle_sockets()->push_back(idle_socket);
@@ -1041,13 +1049,11 @@ void ClientSocketPoolBaseHelper::CancelAllRequestsWithError(int error) {
for (GroupMap::iterator i = group_map_.begin(); i != group_map_.end();) {
Group* group = i->second;
- RequestQueue pending_requests;
- pending_requests.swap(*group->mutable_pending_requests());
- for (RequestQueue::iterator it2 = pending_requests.begin();
- it2 != pending_requests.end(); ++it2) {
- scoped_ptr<const Request> request(*it2);
- InvokeUserCallbackLater(
- request->handle(), request->callback(), error);
+ while (true) {
+ scoped_ptr<const Request> request = group->PopNextPendingRequest();
+ if (!request)
+ break;
+ InvokeUserCallbackLater(request->handle(), request->callback(), error);
}
// Delete group if no longer needed.
@@ -1103,12 +1109,12 @@ bool ClientSocketPoolBaseHelper::CloseOneIdleSocketExceptInGroup(
return false;
}
-bool ClientSocketPoolBaseHelper::CloseOneIdleConnectionInLayeredPool() {
+bool ClientSocketPoolBaseHelper::CloseOneIdleConnectionInHigherLayeredPool() {
// This pool doesn't have any idle sockets. It's possible that a pool at a
// higher layer is holding one of this sockets active, but it's actually idle.
// Query the higher layers.
- for (std::set<LayeredPool*>::const_iterator it = higher_layer_pools_.begin();
- it != higher_layer_pools_.end(); ++it) {
+ for (std::set<HigherLayeredPool*>::const_iterator it = higher_pools_.begin();
+ it != higher_pools_.end(); ++it) {
if ((*it)->CloseOneIdleConnection())
return true;
}
@@ -1144,7 +1150,7 @@ void ClientSocketPoolBaseHelper::TryToCloseSocketsInLayeredPools() {
while (IsStalled()) {
// Closing a socket will result in calling back into |this| to use the freed
// socket slot, so nothing else is needed.
- if (!CloseOneIdleConnectionInLayeredPool())
+ if (!CloseOneIdleConnectionInHigherLayeredPool())
return;
}
}
@@ -1182,19 +1188,25 @@ bool ClientSocketPoolBaseHelper::Group::TryToUseUnassignedConnectJob() {
return true;
}
-void ClientSocketPoolBaseHelper::Group::AddJob(ConnectJob* job,
+void ClientSocketPoolBaseHelper::Group::AddJob(scoped_ptr<ConnectJob> job,
bool is_preconnect) {
SanityCheck();
if (is_preconnect)
++unassigned_job_count_;
- jobs_.insert(job);
+ jobs_.insert(job.release());
}
void ClientSocketPoolBaseHelper::Group::RemoveJob(ConnectJob* job) {
+ scoped_ptr<ConnectJob> owned_job(job);
SanityCheck();
- jobs_.erase(job);
+ std::set<ConnectJob*>::iterator it = jobs_.find(job);
+ if (it != jobs_.end()) {
+ jobs_.erase(it);
+ } else {
+ NOTREACHED();
+ }
size_t job_count = jobs_.size();
if (job_count < unassigned_job_count_)
unassigned_job_count_ = job_count;
@@ -1222,15 +1234,17 @@ void ClientSocketPoolBaseHelper::Group::OnBackupSocketTimerFired(
if (pending_requests_.empty())
return;
- ConnectJob* backup_job = pool->connect_job_factory_->NewConnectJob(
- group_name, **pending_requests_.begin(), pool);
+ scoped_ptr<ConnectJob> backup_job =
+ pool->connect_job_factory_->NewConnectJob(
+ group_name, **pending_requests_.begin(), pool);
backup_job->net_log().AddEvent(NetLog::TYPE_SOCKET_BACKUP_CREATED);
SIMPLE_STATS_COUNTER("socket.backup_created");
int rv = backup_job->Connect();
pool->connecting_socket_count_++;
- AddJob(backup_job, false);
+ ConnectJob* raw_backup_job = backup_job.get();
+ AddJob(backup_job.Pass(), false);
if (rv != ERR_IO_PENDING)
- pool->OnConnectJobComplete(rv, backup_job);
+ pool->OnConnectJobComplete(rv, raw_backup_job);
}
void ClientSocketPoolBaseHelper::Group::SanityCheck() {
@@ -1248,6 +1262,68 @@ void ClientSocketPoolBaseHelper::Group::RemoveAllJobs() {
weak_factory_.InvalidateWeakPtrs();
}
+const ClientSocketPoolBaseHelper::Request*
+ClientSocketPoolBaseHelper::Group::GetNextPendingRequest() const {
+ return pending_requests_.empty() ? NULL : *pending_requests_.begin();
+}
+
+bool ClientSocketPoolBaseHelper::Group::HasConnectJobForHandle(
+ const ClientSocketHandle* handle) const {
+ // Search the first |jobs_.size()| pending requests for |handle|.
+ // If it's farther back in the deque than that, it doesn't have a
+ // corresponding ConnectJob.
+ size_t i = 0;
+ for (RequestQueue::const_iterator it = pending_requests_.begin();
+ it != pending_requests_.end() && i < jobs_.size(); ++it, ++i) {
+ if ((*it)->handle() == handle)
+ return true;
+ }
+ return false;
+}
+
+void ClientSocketPoolBaseHelper::Group::InsertPendingRequest(
+ scoped_ptr<const Request> r) {
+ RequestQueue::iterator it = pending_requests_.begin();
+ // TODO(mmenke): Should the network stack require requests with
+ // |ignore_limits| have the highest priority?
+ while (it != pending_requests_.end() &&
+ CompareEffectiveRequestPriority(*r, *(*it)) <= 0) {
+ ++it;
+ }
+ pending_requests_.insert(it, r.release());
+}
+
+scoped_ptr<const ClientSocketPoolBaseHelper::Request>
+ClientSocketPoolBaseHelper::Group::PopNextPendingRequest() {
+ if (pending_requests_.empty())
+ return scoped_ptr<const ClientSocketPoolBaseHelper::Request>();
+ return RemovePendingRequest(pending_requests_.begin());
+}
+
+scoped_ptr<const ClientSocketPoolBaseHelper::Request>
+ClientSocketPoolBaseHelper::Group::FindAndRemovePendingRequest(
+ ClientSocketHandle* handle) {
+ for (RequestQueue::iterator it = pending_requests_.begin();
+ it != pending_requests_.end(); ++it) {
+ if ((*it)->handle() == handle) {
+ scoped_ptr<const Request> request = RemovePendingRequest(it);
+ return request.Pass();
+ }
+ }
+ return scoped_ptr<const ClientSocketPoolBaseHelper::Request>();
+}
+
+scoped_ptr<const ClientSocketPoolBaseHelper::Request>
+ClientSocketPoolBaseHelper::Group::RemovePendingRequest(
+ const RequestQueue::iterator& it) {
+ scoped_ptr<const Request> request(*it);
+ pending_requests_.erase(it);
+ // If there are no more requests, kill the backup timer.
+ if (pending_requests_.empty())
+ CleanupBackupJob();
+ return request.Pass();
+}
+
} // namespace internal
} // namespace net
diff --git a/chromium/net/socket/client_socket_pool_base.h b/chromium/net/socket/client_socket_pool_base.h
index 4bf95d7b04a..31ec9bf7b13 100644
--- a/chromium/net/socket/client_socket_pool_base.h
+++ b/chromium/net/socket/client_socket_pool_base.h
@@ -61,8 +61,11 @@ class NET_EXPORT_PRIVATE ConnectJob {
Delegate() {}
virtual ~Delegate() {}
- // Alerts the delegate that the connection completed.
- virtual void OnConnectJobComplete(int result, ConnectJob* job) = 0;
+ // Alerts the delegate that the connection completed. |job| must
+ // be destroyed by the delegate. A scoped_ptr<> isn't used because
+ // the caller of this function doesn't own |job|.
+ virtual void OnConnectJobComplete(int result,
+ ConnectJob* job) = 0;
private:
DISALLOW_COPY_AND_ASSIGN(Delegate);
@@ -71,6 +74,7 @@ class NET_EXPORT_PRIVATE ConnectJob {
// A |timeout_duration| of 0 corresponds to no timeout.
ConnectJob(const std::string& group_name,
base::TimeDelta timeout_duration,
+ RequestPriority priority,
Delegate* delegate,
const BoundNetLog& net_log);
virtual ~ConnectJob();
@@ -79,9 +83,10 @@ class NET_EXPORT_PRIVATE ConnectJob {
const std::string& group_name() const { return group_name_; }
const BoundNetLog& net_log() { return net_log_; }
- // Releases |socket_| to the client. On connection error, this should return
- // NULL.
- StreamSocket* ReleaseSocket() { return socket_.release(); }
+ // Releases ownership of the underlying socket to the caller.
+ // Returns the released socket, or NULL if there was a connection
+ // error.
+ scoped_ptr<StreamSocket> PassSocket();
// Begins connecting the socket. Returns OK on success, ERR_IO_PENDING if it
// cannot complete synchronously without blocking, or another net error code
@@ -105,7 +110,8 @@ class NET_EXPORT_PRIVATE ConnectJob {
const BoundNetLog& net_log() const { return net_log_; }
protected:
- void set_socket(StreamSocket* socket);
+ RequestPriority priority() const { return priority_; }
+ void SetSocket(scoped_ptr<StreamSocket> socket);
StreamSocket* socket() { return socket_.get(); }
void NotifyDelegateOfCompletion(int rv);
void ResetTimer(base::TimeDelta remainingTime);
@@ -124,6 +130,8 @@ class NET_EXPORT_PRIVATE ConnectJob {
const std::string group_name_;
const base::TimeDelta timeout_duration_;
+ // TODO(akalin): Support reprioritization.
+ const RequestPriority priority_;
// Timer to abort jobs that take too long.
base::OneShotTimer<ConnectJob> timer_;
Delegate* delegate_;
@@ -175,6 +183,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper
private:
ClientSocketHandle* const handle_;
CompletionCallback callback_;
+ // TODO(akalin): Support reprioritization.
const RequestPriority priority_;
bool ignore_limits_;
const Flags flags_;
@@ -188,7 +197,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper
ConnectJobFactory() {}
virtual ~ConnectJobFactory() {}
- virtual ConnectJob* NewConnectJob(
+ virtual scoped_ptr<ConnectJob> NewConnectJob(
const std::string& group_name,
const Request& request,
ConnectJob::Delegate* delegate) const = 0;
@@ -200,6 +209,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper
};
ClientSocketPoolBaseHelper(
+ HigherLayeredPool* pool,
int max_sockets,
int max_sockets_per_group,
base::TimeDelta unused_idle_socket_timeout,
@@ -208,15 +218,21 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper
virtual ~ClientSocketPoolBaseHelper();
- // Adds/Removes layered pools. It is expected in the destructor that no
- // layered pools remain.
- void AddLayeredPool(LayeredPool* pool);
- void RemoveLayeredPool(LayeredPool* pool);
+ // Adds a lower layered pool to |this|, and adds |this| as a higher layered
+ // pool on top of |lower_pool|.
+ void AddLowerLayeredPool(LowerLayeredPool* lower_pool);
+
+ // See LowerLayeredPool::IsStalled for documentation on this function.
+ bool IsStalled() const;
+
+ // See LowerLayeredPool for documentation on these functions. It is expected
+ // in the destructor that no higher layer pools remain.
+ void AddHigherLayeredPool(HigherLayeredPool* higher_pool);
+ void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool);
// See ClientSocketPool::RequestSocket for documentation on this function.
- // ClientSocketPoolBaseHelper takes ownership of |request|, which must be
- // heap allocated.
- int RequestSocket(const std::string& group_name, const Request* request);
+ int RequestSocket(const std::string& group_name,
+ scoped_ptr<const Request> request);
// See ClientSocketPool::RequestSocket for documentation on this function.
void RequestSockets(const std::string& group_name,
@@ -229,15 +245,12 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper
// See ClientSocketPool::ReleaseSocket for documentation on this function.
void ReleaseSocket(const std::string& group_name,
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
int id);
// See ClientSocketPool::FlushWithError for documentation on this function.
void FlushWithError(int error);
- // See ClientSocketPool::IsStalled for documentation on this function.
- bool IsStalled() const;
-
// See ClientSocketPool::CloseIdleSockets for documentation on this function.
void CloseIdleSockets();
@@ -294,8 +307,8 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper
// I'm not sure if we hit this situation often.
bool CloseOneIdleSocket();
- // Checks layered pools to see if they can close an idle connection.
- bool CloseOneIdleConnectionInLayeredPool();
+ // Checks higher layered pools to see if they can close an idle connection.
+ bool CloseOneIdleConnectionInHigherLayeredPool();
// See ClientSocketPool::GetInfoAsValue for documentation on this function.
base::DictionaryValue* GetInfoAsValue(const std::string& name,
@@ -386,22 +399,55 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper
// Otherwise, returns false.
bool TryToUseUnassignedConnectJob();
- void AddJob(ConnectJob* job, bool is_preconnect);
+ void AddJob(scoped_ptr<ConnectJob> job, bool is_preconnect);
+ // Remove |job| from this group, which must already own |job|.
void RemoveJob(ConnectJob* job);
void RemoveAllJobs();
+ bool has_pending_requests() const {
+ return !pending_requests_.empty();
+ }
+
+ size_t pending_request_count() const {
+ return pending_requests_.size();
+ }
+
+ // Gets (but does not remove) the next pending request. Returns
+ // NULL if there are no pending requests.
+ const Request* GetNextPendingRequest() const;
+
+ // Returns true if there is a connect job for |handle|.
+ bool HasConnectJobForHandle(const ClientSocketHandle* handle) const;
+
+ // Inserts the request into the queue based on priority
+ // order. Older requests are prioritized over requests of equal
+ // priority.
+ void InsertPendingRequest(scoped_ptr<const Request> r);
+
+ // Gets and removes the next pending request. Returns NULL if
+ // there are no pending requests.
+ scoped_ptr<const Request> PopNextPendingRequest();
+
+ // Finds the pending request for |handle| and removes it. Returns
+ // the removed pending request, or NULL if there was none.
+ scoped_ptr<const Request> FindAndRemovePendingRequest(
+ ClientSocketHandle* handle);
+
void IncrementActiveSocketCount() { active_socket_count_++; }
void DecrementActiveSocketCount() { active_socket_count_--; }
int unassigned_job_count() const { return unassigned_job_count_; }
const std::set<ConnectJob*>& jobs() const { return jobs_; }
const std::list<IdleSocket>& idle_sockets() const { return idle_sockets_; }
- const RequestQueue& pending_requests() const { return pending_requests_; }
int active_socket_count() const { return active_socket_count_; }
- RequestQueue* mutable_pending_requests() { return &pending_requests_; }
std::list<IdleSocket>* mutable_idle_sockets() { return &idle_sockets_; }
private:
+ // Returns the iterator's pending request after removing it from
+ // the queue.
+ scoped_ptr<const Request> RemovePendingRequest(
+ const RequestQueue::iterator& it);
+
// Called when the backup socket timer fires.
void OnBackupSocketTimerFired(
std::string group_name,
@@ -443,15 +489,6 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper
typedef std::map<const ClientSocketHandle*, CallbackResultPair>
PendingCallbackMap;
- // Inserts the request into the queue based on order they will receive
- // sockets. Sockets which ignore the socket pool limits are first. Then
- // requests are sorted by priority, with higher priorities closer to the
- // front. Older requests are prioritized over requests of equal priority.
- static void InsertRequestIntoQueue(const Request* r,
- RequestQueue* pending_requests);
- static const Request* RemoveRequestFromQueue(const RequestQueue::iterator& it,
- Group* group);
-
Group* GetOrCreateGroup(const std::string& group_name);
void RemoveGroup(const std::string& group_name);
void RemoveGroup(GroupMap::iterator it);
@@ -475,7 +512,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper
CleanupIdleSockets(false);
}
- // Removes |job| from |connect_job_set_|. Also updates |group| if non-NULL.
+ // Removes |job| from |group|, which must already own |job|.
void RemoveConnectJob(ConnectJob* job, Group* group);
// Tries to see if we can handle any more requests for |group|.
@@ -485,7 +522,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper
void ProcessPendingRequest(const std::string& group_name, Group* group);
// Assigns |socket| to |handle| and updates |group|'s counters appropriately.
- void HandOutSocket(StreamSocket* socket,
+ void HandOutSocket(scoped_ptr<StreamSocket> socket,
bool reused,
const LoadTimingInfo::ConnectTiming& connect_timing,
ClientSocketHandle* handle,
@@ -494,7 +531,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper
const BoundNetLog& net_log);
// Adds |socket| to the list of idle sockets for |group|.
- void AddIdleSocket(StreamSocket* socket, Group* group);
+ void AddIdleSocket(scoped_ptr<StreamSocket> socket, Group* group);
// Iterates through |group_map_|, canceling all ConnectJobs and deleting
// groups if they are no longer needed.
@@ -511,14 +548,14 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper
// it does not handle logging into NetLog of the queueing status of
// |request|.
int RequestSocketInternal(const std::string& group_name,
- const Request* request);
+ const Request& request);
// Assigns an idle socket for the group to the request.
// Returns |true| if an idle socket is available, false otherwise.
- bool AssignIdleSocketToRequest(const Request* request, Group* group);
+ bool AssignIdleSocketToRequest(const Request& request, Group* group);
static void LogBoundConnectJobToRequest(
- const NetLog::Source& connect_job_source, const Request* request);
+ const NetLog::Source& connect_job_source, const Request& request);
// Same as CloseOneIdleSocket() except it won't close an idle socket in
// |group|. If |group| is NULL, it is ignored. Returns true if it closed a
@@ -588,7 +625,18 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper
// to the pool, we can make sure that they are discarded rather than reused.
int pool_generation_number_;
- std::set<LayeredPool*> higher_layer_pools_;
+ // Used to add |this| as a higher layer pool on top of lower layer pools. May
+ // be NULL if no lower layer pools will be added.
+ HigherLayeredPool* pool_;
+
+ // Pools that create connections through |this|. |this| will try to close
+ // their idle sockets when it stalls. Must be empty on destruction.
+ std::set<HigherLayeredPool*> higher_pools_;
+
+ // Pools that this goes through. Typically there's only one, but not always.
+ // |this| will check if they're stalled when it has a new idle socket. |this|
+ // will remove itself from all lower layered pools on destruction.
+ std::set<LowerLayeredPool*> lower_pools_;
base::WeakPtrFactory<ClientSocketPoolBaseHelper> weak_factory_;
@@ -624,7 +672,7 @@ class ClientSocketPoolBase {
ConnectJobFactory() {}
virtual ~ConnectJobFactory() {}
- virtual ConnectJob* NewConnectJob(
+ virtual scoped_ptr<ConnectJob> NewConnectJob(
const std::string& group_name,
const Request& request,
ConnectJob::Delegate* delegate) const = 0;
@@ -642,6 +690,7 @@ class ClientSocketPoolBase {
// |used_idle_socket_timeout| specifies how long to leave a previously used
// idle socket open before closing it.
ClientSocketPoolBase(
+ HigherLayeredPool* self,
int max_sockets,
int max_sockets_per_group,
ClientSocketPoolHistograms* histograms,
@@ -649,19 +698,23 @@ class ClientSocketPoolBase {
base::TimeDelta used_idle_socket_timeout,
ConnectJobFactory* connect_job_factory)
: histograms_(histograms),
- helper_(max_sockets, max_sockets_per_group,
+ helper_(self, max_sockets, max_sockets_per_group,
unused_idle_socket_timeout, used_idle_socket_timeout,
new ConnectJobFactoryAdaptor(connect_job_factory)) {}
virtual ~ClientSocketPoolBase() {}
// These member functions simply forward to ClientSocketPoolBaseHelper.
- void AddLayeredPool(LayeredPool* pool) {
- helper_.AddLayeredPool(pool);
+ void AddLowerLayeredPool(LowerLayeredPool* lower_pool) {
+ helper_.AddLowerLayeredPool(lower_pool);
}
- void RemoveLayeredPool(LayeredPool* pool) {
- helper_.RemoveLayeredPool(pool);
+ void AddHigherLayeredPool(HigherLayeredPool* higher_pool) {
+ helper_.AddHigherLayeredPool(higher_pool);
+ }
+
+ void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) {
+ helper_.RemoveHigherLayeredPool(higher_pool);
}
// RequestSocket bundles up the parameters into a Request and then forwards to
@@ -672,12 +725,15 @@ class ClientSocketPoolBase {
ClientSocketHandle* handle,
const CompletionCallback& callback,
const BoundNetLog& net_log) {
- Request* request =
+ scoped_ptr<const Request> request(
new Request(handle, callback, priority,
internal::ClientSocketPoolBaseHelper::NORMAL,
params->ignore_limits(),
- params, net_log);
- return helper_.RequestSocket(group_name, request);
+ params, net_log));
+ return helper_.RequestSocket(
+ group_name,
+ request.template PassAs<
+ const internal::ClientSocketPoolBaseHelper::Request>());
}
// RequestSockets bundles up the parameters into a Request and then forwards
@@ -702,9 +758,10 @@ class ClientSocketPoolBase {
return helper_.CancelRequest(group_name, handle);
}
- void ReleaseSocket(const std::string& group_name, StreamSocket* socket,
+ void ReleaseSocket(const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
int id) {
- return helper_.ReleaseSocket(group_name, socket, id);
+ return helper_.ReleaseSocket(group_name, socket.Pass(), id);
}
void FlushWithError(int error) { helper_.FlushWithError(error); }
@@ -765,8 +822,8 @@ class ClientSocketPoolBase {
bool CloseOneIdleSocket() { return helper_.CloseOneIdleSocket(); }
- bool CloseOneIdleConnectionInLayeredPool() {
- return helper_.CloseOneIdleConnectionInLayeredPool();
+ bool CloseOneIdleConnectionInHigherLayeredPool() {
+ return helper_.CloseOneIdleConnectionInHigherLayeredPool();
}
private:
@@ -785,13 +842,13 @@ class ClientSocketPoolBase {
: connect_job_factory_(connect_job_factory) {}
virtual ~ConnectJobFactoryAdaptor() {}
- virtual ConnectJob* NewConnectJob(
+ virtual scoped_ptr<ConnectJob> NewConnectJob(
const std::string& group_name,
const internal::ClientSocketPoolBaseHelper::Request& request,
- ConnectJob::Delegate* delegate) const {
- const Request* casted_request = static_cast<const Request*>(&request);
+ ConnectJob::Delegate* delegate) const OVERRIDE {
+ const Request& casted_request = static_cast<const Request&>(request);
return connect_job_factory_->NewConnectJob(
- group_name, *casted_request, delegate);
+ group_name, casted_request, delegate);
}
virtual base::TimeDelta ConnectionTimeout() const {
diff --git a/chromium/net/socket/client_socket_pool_base_unittest.cc b/chromium/net/socket/client_socket_pool_base_unittest.cc
index 5eeda972cff..bbeca2f3e11 100644
--- a/chromium/net/socket/client_socket_pool_base_unittest.cc
+++ b/chromium/net/socket/client_socket_pool_base_unittest.cc
@@ -30,7 +30,9 @@
#include "net/socket/client_socket_handle.h"
#include "net/socket/client_socket_pool_histograms.h"
#include "net/socket/socket_test_util.h"
+#include "net/socket/ssl_client_socket.h"
#include "net/socket/stream_socket.h"
+#include "net/udp/datagram_client_socket.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
@@ -189,30 +191,30 @@ class MockClientSocketFactory : public ClientSocketFactory {
public:
MockClientSocketFactory() : allocation_count_(0) {}
- virtual DatagramClientSocket* CreateDatagramClientSocket(
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
NetLog* net_log,
const NetLog::Source& source) OVERRIDE {
NOTREACHED();
- return NULL;
+ return scoped_ptr<DatagramClientSocket>();
}
- virtual StreamSocket* CreateTransportClientSocket(
+ virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
const AddressList& addresses,
NetLog* /* net_log */,
const NetLog::Source& /*source*/) OVERRIDE {
allocation_count_++;
- return NULL;
+ return scoped_ptr<StreamSocket>();
}
- virtual SSLClientSocket* CreateSSLClientSocket(
- ClientSocketHandle* transport_socket,
+ virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context) OVERRIDE {
NOTIMPLEMENTED();
- return NULL;
+ return scoped_ptr<SSLClientSocket>();
}
virtual void ClearSSLSessionCache() OVERRIDE {
@@ -259,7 +261,7 @@ class TestConnectJob : public ConnectJob {
ConnectJob::Delegate* delegate,
MockClientSocketFactory* client_socket_factory,
NetLog* net_log)
- : ConnectJob(group_name, timeout_duration, delegate,
+ : ConnectJob(group_name, timeout_duration, request.priority(), delegate,
BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)),
job_type_(job_type),
client_socket_factory_(client_socket_factory),
@@ -294,7 +296,8 @@ class TestConnectJob : public ConnectJob {
AddressList ignored;
client_socket_factory_->CreateTransportClientSocket(
ignored, NULL, net::NetLog::Source());
- set_socket(new MockClientSocket(net_log().net_log()));
+ SetSocket(
+ scoped_ptr<StreamSocket>(new MockClientSocket(net_log().net_log())));
switch (job_type_) {
case kMockJob:
return DoConnect(true /* successful */, false /* sync */,
@@ -373,7 +376,7 @@ class TestConnectJob : public ConnectJob {
return ERR_IO_PENDING;
default:
NOTREACHED();
- set_socket(NULL);
+ SetSocket(scoped_ptr<StreamSocket>());
return ERR_FAILED;
}
}
@@ -386,7 +389,7 @@ class TestConnectJob : public ConnectJob {
result = ERR_PROXY_AUTH_REQUESTED;
} else {
result = ERR_CONNECTION_FAILED;
- set_socket(NULL);
+ SetSocket(scoped_ptr<StreamSocket>());
}
if (was_async)
@@ -430,7 +433,7 @@ class TestConnectJobFactory
// ConnectJobFactory implementation.
- virtual ConnectJob* NewConnectJob(
+ virtual scoped_ptr<ConnectJob> NewConnectJob(
const std::string& group_name,
const TestClientSocketPoolBase::Request& request,
ConnectJob::Delegate* delegate) const OVERRIDE {
@@ -440,13 +443,13 @@ class TestConnectJobFactory
job_type = job_types_->front();
job_types_->pop_front();
}
- return new TestConnectJob(job_type,
- group_name,
- request,
- timeout_duration_,
- delegate,
- client_socket_factory_,
- net_log_);
+ return scoped_ptr<ConnectJob>(new TestConnectJob(job_type,
+ group_name,
+ request,
+ timeout_duration_,
+ delegate,
+ client_socket_factory_,
+ net_log_));
}
virtual base::TimeDelta ConnectionTimeout() const OVERRIDE {
@@ -465,6 +468,8 @@ class TestConnectJobFactory
class TestClientSocketPool : public ClientSocketPool {
public:
+ typedef TestSocketParams SocketParams;
+
TestClientSocketPool(
int max_sockets,
int max_sockets_per_group,
@@ -472,7 +477,7 @@ class TestClientSocketPool : public ClientSocketPool {
base::TimeDelta unused_idle_socket_timeout,
base::TimeDelta used_idle_socket_timeout,
TestClientSocketPoolBase::ConnectJobFactory* connect_job_factory)
- : base_(max_sockets, max_sockets_per_group, histograms,
+ : base_(NULL, max_sockets, max_sockets_per_group, histograms,
unused_idle_socket_timeout, used_idle_socket_timeout,
connect_job_factory) {}
@@ -509,9 +514,9 @@ class TestClientSocketPool : public ClientSocketPool {
virtual void ReleaseSocket(
const std::string& group_name,
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
int id) OVERRIDE {
- base_.ReleaseSocket(group_name, socket, id);
+ base_.ReleaseSocket(group_name, socket.Pass(), id);
}
virtual void FlushWithError(int error) OVERRIDE {
@@ -541,12 +546,13 @@ class TestClientSocketPool : public ClientSocketPool {
return base_.GetLoadState(group_name, handle);
}
- virtual void AddLayeredPool(LayeredPool* pool) OVERRIDE {
- base_.AddLayeredPool(pool);
+ virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE {
+ base_.AddHigherLayeredPool(higher_pool);
}
- virtual void RemoveLayeredPool(LayeredPool* pool) OVERRIDE {
- base_.RemoveLayeredPool(pool);
+ virtual void RemoveHigherLayeredPool(
+ HigherLayeredPool* higher_pool) OVERRIDE {
+ base_.RemoveHigherLayeredPool(higher_pool);
}
virtual base::DictionaryValue* GetInfoAsValue(
@@ -586,8 +592,8 @@ class TestClientSocketPool : public ClientSocketPool {
void EnableConnectBackupJobs() { base_.EnableConnectBackupJobs(); }
- bool CloseOneIdleConnectionInLayeredPool() {
- return base_.CloseOneIdleConnectionInLayeredPool();
+ bool CloseOneIdleConnectionInHigherLayeredPool() {
+ return base_.CloseOneIdleConnectionInHigherLayeredPool();
}
private:
@@ -598,8 +604,6 @@ class TestClientSocketPool : public ClientSocketPool {
} // namespace
-REGISTER_SOCKET_PARAMS_FOR_POOL(TestClientSocketPool, TestSocketParams);
-
namespace {
void MockClientSocketFactory::SignalJobs() {
@@ -630,10 +634,10 @@ class TestConnectJobDelegate : public ConnectJob::Delegate {
virtual void OnConnectJobComplete(int result, ConnectJob* job) OVERRIDE {
result_ = result;
- scoped_ptr<StreamSocket> socket(job->ReleaseSocket());
+ scoped_ptr<ConnectJob> owned_job(job);
+ scoped_ptr<StreamSocket> socket = owned_job->PassSocket();
// socket.get() should be NULL iff result != OK
- EXPECT_EQ(socket.get() == NULL, result != OK);
- delete job;
+ EXPECT_EQ(socket == NULL, result != OK);
have_result_ = true;
if (waiting_for_result_)
base::MessageLoop::current()->Quit();
@@ -702,9 +706,8 @@ class ClientSocketPoolBaseTest : public testing::Test {
const std::string& group_name,
RequestPriority priority,
const scoped_refptr<TestSocketParams>& params) {
- return test_base_.StartRequestUsingPool<
- TestClientSocketPool, TestSocketParams>(
- pool_.get(), group_name, priority, params);
+ return test_base_.StartRequestUsingPool(
+ pool_.get(), group_name, priority, params);
}
int StartRequest(const std::string& group_name, RequestPriority priority) {
@@ -3716,7 +3719,7 @@ TEST_F(ClientSocketPoolBaseTest, PreconnectWithBackupJob) {
EXPECT_EQ(1, pool_->NumActiveSocketsInGroup("a"));
}
-class MockLayeredPool : public LayeredPool {
+class MockLayeredPool : public HigherLayeredPool {
public:
MockLayeredPool(TestClientSocketPool* pool,
const std::string& group_name)
@@ -3724,11 +3727,11 @@ class MockLayeredPool : public LayeredPool {
params_(new TestSocketParams),
group_name_(group_name),
can_release_connection_(true) {
- pool_->AddLayeredPool(this);
+ pool_->AddHigherLayeredPool(this);
}
~MockLayeredPool() {
- pool_->RemoveLayeredPool(this);
+ pool_->RemoveHigherLayeredPool(this);
}
int RequestSocket(TestClientSocketPool* pool) {
@@ -3774,7 +3777,7 @@ TEST_F(ClientSocketPoolBaseTest, FailToCloseIdleSocketsNotHeldByLayeredPool) {
EXPECT_EQ(OK, mock_layered_pool.RequestSocket(pool_.get()));
EXPECT_CALL(mock_layered_pool, CloseOneIdleConnection())
.WillOnce(Return(false));
- EXPECT_FALSE(pool_->CloseOneIdleConnectionInLayeredPool());
+ EXPECT_FALSE(pool_->CloseOneIdleConnectionInHigherLayeredPool());
}
TEST_F(ClientSocketPoolBaseTest, ForciblyCloseIdleSocketsHeldByLayeredPool) {
@@ -3786,7 +3789,7 @@ TEST_F(ClientSocketPoolBaseTest, ForciblyCloseIdleSocketsHeldByLayeredPool) {
EXPECT_CALL(mock_layered_pool, CloseOneIdleConnection())
.WillOnce(Invoke(&mock_layered_pool,
&MockLayeredPool::ReleaseOneConnection));
- EXPECT_TRUE(pool_->CloseOneIdleConnectionInLayeredPool());
+ EXPECT_TRUE(pool_->CloseOneIdleConnectionInHigherLayeredPool());
}
// Tests the basic case of closing an idle socket in a higher layered pool when
diff --git a/chromium/net/socket/client_socket_pool_manager.cc b/chromium/net/socket/client_socket_pool_manager.cc
index 71496d28646..b37d2d1949c 100644
--- a/chromium/net/socket/client_socket_pool_manager.cc
+++ b/chromium/net/socket/client_socket_pool_manager.cc
@@ -158,7 +158,6 @@ int InitSocketPoolHelper(const GURL& request_url,
bool ignore_limits = (request_load_flags & LOAD_IGNORE_LIMITS) != 0;
if (proxy_info.is_direct()) {
tcp_params = new TransportSocketParams(origin_host_port,
- request_priority,
disable_resolver_cache,
ignore_limits,
resolution_callback);
@@ -167,7 +166,6 @@ int InitSocketPoolHelper(const GURL& request_url,
proxy_host_port.reset(new HostPortPair(proxy_server.host_port_pair()));
scoped_refptr<TransportSocketParams> proxy_tcp_params(
new TransportSocketParams(*proxy_host_port,
- request_priority,
disable_resolver_cache,
ignore_limits,
resolution_callback));
@@ -182,7 +180,6 @@ int InitSocketPoolHelper(const GURL& request_url,
ssl_params = new SSLSocketParams(proxy_tcp_params,
NULL,
NULL,
- ProxyServer::SCHEME_DIRECT,
*proxy_host_port.get(),
ssl_config_for_proxy,
kPrivacyModeDisabled,
@@ -214,8 +211,7 @@ int InitSocketPoolHelper(const GURL& request_url,
socks_params = new SOCKSSocketParams(proxy_tcp_params,
socks_version == '5',
- origin_host_port,
- request_priority);
+ origin_host_port);
}
}
@@ -229,7 +225,6 @@ int InitSocketPoolHelper(const GURL& request_url,
new SSLSocketParams(tcp_params,
socks_params,
http_proxy_params,
- proxy_info.proxy_server().scheme(),
origin_host_port,
ssl_config_for_origin,
privacy_mode,
diff --git a/chromium/net/socket/deterministic_socket_data_unittest.cc b/chromium/net/socket/deterministic_socket_data_unittest.cc
index eba01b5e9cc..c51427e25a7 100644
--- a/chromium/net/socket/deterministic_socket_data_unittest.cc
+++ b/chromium/net/socket/deterministic_socket_data_unittest.cc
@@ -72,7 +72,6 @@ DeterministicSocketDataTest::DeterministicSocketDataTest()
connect_data_(SYNCHRONOUS, OK),
endpoint_("www.google.com", 443),
tcp_params_(new TransportSocketParams(endpoint_,
- LOWEST,
false,
false,
OnHostResolutionCallback())),
diff --git a/chromium/net/socket/nss_ssl_util.cc b/chromium/net/socket/nss_ssl_util.cc
index be33ac5add0..7e3aee430c4 100644
--- a/chromium/net/socket/nss_ssl_util.cc
+++ b/chromium/net/socket/nss_ssl_util.cc
@@ -58,12 +58,13 @@ class NSSSSLInitSingleton {
enabled = false;
// Trim the list of cipher suites in order to keep the size of the
- // ClientHello down. DSS, ECDH, CAMELLIA, SEED and ECC+3DES cipher
- // suites are disabled.
+ // ClientHello down. DSS, ECDH, CAMELLIA, SEED, ECC+3DES, and
+ // HMAC-SHA256 cipher suites are disabled.
if (info.symCipher == ssl_calg_camellia ||
info.symCipher == ssl_calg_seed ||
(info.symCipher == ssl_calg_3des && info.keaType != ssl_kea_rsa) ||
info.authAlgorithm == ssl_auth_dsa ||
+ info.macAlgorithm == ssl_hmac_sha256 ||
info.nonStandard ||
strcmp(info.keaTypeName, "ECDH") == 0) {
enabled = false;
@@ -232,6 +233,10 @@ int MapNSSError(PRErrorCode err) {
case SEC_ERROR_BAD_DER:
case SEC_ERROR_EXTRA_INPUT:
return ERR_SSL_BAD_PEER_PUBLIC_KEY;
+ // During renegotiation, the server presented a different certificate than
+ // was used earlier.
+ case SSL_ERROR_WRONG_CERTIFICATE:
+ return ERR_SSL_SERVER_CERT_CHANGED;
default: {
if (IS_SSL_ERROR(err)) {
diff --git a/chromium/net/socket/socket_descriptor.cc b/chromium/net/socket/socket_descriptor.cc
new file mode 100644
index 00000000000..5a2e53cab4d
--- /dev/null
+++ b/chromium/net/socket/socket_descriptor.cc
@@ -0,0 +1,48 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/socket_descriptor.h"
+
+#if defined(OS_POSIX)
+#include <sys/types.h>
+#include <sys/socket.h>
+#endif
+
+#include "base/basictypes.h"
+
+#if defined(OS_WIN)
+#include "net/base/winsock_init.h"
+#endif
+
+namespace net {
+
+PlatformSocketFactory* g_socket_factory = NULL;
+
+PlatformSocketFactory::PlatformSocketFactory() {
+}
+
+PlatformSocketFactory::~PlatformSocketFactory() {
+}
+
+void PlatformSocketFactory::SetInstance(PlatformSocketFactory* factory) {
+ g_socket_factory = factory;
+}
+
+SocketDescriptor CreateSocketDefault(int family, int type, int protocol) {
+#if defined(OS_WIN)
+ EnsureWinsockInit();
+ return ::WSASocket(family, type, protocol, NULL, 0, WSA_FLAG_OVERLAPPED);
+#else // OS_WIN
+ return ::socket(family, type, protocol);
+#endif // OS_WIN
+}
+
+SocketDescriptor CreatePlatformSocket(int family, int type, int protocol) {
+ if (g_socket_factory)
+ return g_socket_factory->CreateSocket(family, type, protocol);
+ else
+ return CreateSocketDefault(family, type, protocol);
+}
+
+} // namespace net
diff --git a/chromium/net/socket/socket_descriptor.h b/chromium/net/socket/socket_descriptor.h
new file mode 100644
index 00000000000..b2a22234b80
--- /dev/null
+++ b/chromium/net/socket/socket_descriptor.h
@@ -0,0 +1,49 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_SOCKET_DESCRIPTOR_H_
+#define NET_SOCKET_SOCKET_DESCRIPTOR_H_
+
+#include "build/build_config.h"
+#include "net/base/net_export.h"
+
+#if defined(OS_WIN)
+#include <winsock2.h>
+#endif // OS_WIN
+
+namespace net {
+
+#if defined(OS_POSIX)
+typedef int SocketDescriptor;
+const SocketDescriptor kInvalidSocket = -1;
+#elif defined(OS_WIN)
+typedef SOCKET SocketDescriptor;
+const SocketDescriptor kInvalidSocket = INVALID_SOCKET;
+#endif
+
+// Interface to create native socket.
+// Usually such factories are used for testing purposes, which is not true in
+// this case. This interface is used to substitute WSASocket/socket to make
+// possible execution of some network code in sandbox.
+class NET_EXPORT PlatformSocketFactory {
+ public:
+ PlatformSocketFactory();
+ virtual ~PlatformSocketFactory();
+
+ // Replace WSASocket/socket with given factory. The factory will be used by
+ // CreatePlatformSocket.
+ static void SetInstance(PlatformSocketFactory* factory);
+
+ // Creates socket. See WSASocket/socket documentation of parameters.
+ virtual SocketDescriptor CreateSocket(int family, int type, int protocol) = 0;
+};
+
+// Creates socket. See WSASocket/socket documentation of parameters.
+SocketDescriptor NET_EXPORT CreatePlatformSocket(int family,
+ int type,
+ int protocol);
+
+} // namespace net
+
+#endif // NET_SOCKET_SOCKET_DESCRIPTOR_H_
diff --git a/chromium/net/socket/socket_test_util.cc b/chromium/net/socket/socket_test_util.cc
index 8b2bdfccba3..78e9e7ce9c4 100644
--- a/chromium/net/socket/socket_test_util.cc
+++ b/chromium/net/socket/socket_test_util.cc
@@ -657,37 +657,39 @@ void MockClientSocketFactory::ResetNextMockIndexes() {
mock_ssl_data_.ResetNextIndex();
}
-DatagramClientSocket* MockClientSocketFactory::CreateDatagramClientSocket(
+scoped_ptr<DatagramClientSocket>
+MockClientSocketFactory::CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
net::NetLog* net_log,
const net::NetLog::Source& source) {
SocketDataProvider* data_provider = mock_data_.GetNext();
- MockUDPClientSocket* socket = new MockUDPClientSocket(data_provider, net_log);
- data_provider->set_socket(socket);
- return socket;
+ scoped_ptr<MockUDPClientSocket> socket(
+ new MockUDPClientSocket(data_provider, net_log));
+ data_provider->set_socket(socket.get());
+ return socket.PassAs<DatagramClientSocket>();
}
-StreamSocket* MockClientSocketFactory::CreateTransportClientSocket(
+scoped_ptr<StreamSocket> MockClientSocketFactory::CreateTransportClientSocket(
const AddressList& addresses,
net::NetLog* net_log,
const net::NetLog::Source& source) {
SocketDataProvider* data_provider = mock_data_.GetNext();
- MockTCPClientSocket* socket =
- new MockTCPClientSocket(addresses, net_log, data_provider);
- data_provider->set_socket(socket);
- return socket;
+ scoped_ptr<MockTCPClientSocket> socket(
+ new MockTCPClientSocket(addresses, net_log, data_provider));
+ data_provider->set_socket(socket.get());
+ return socket.PassAs<StreamSocket>();
}
-SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket(
- ClientSocketHandle* transport_socket,
+scoped_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context) {
- MockSSLClientSocket* socket =
- new MockSSLClientSocket(transport_socket, host_and_port, ssl_config,
- mock_ssl_data_.GetNext());
- return socket;
+ return scoped_ptr<SSLClientSocket>(
+ new MockSSLClientSocket(transport_socket.Pass(),
+ host_and_port, ssl_config,
+ mock_ssl_data_.GetNext()));
}
void MockClientSocketFactory::ClearSSLSessionCache() {
@@ -1278,7 +1280,7 @@ void DeterministicMockTCPClientSocket::OnConnectComplete(
// static
void MockSSLClientSocket::ConnectCallback(
- MockSSLClientSocket *ssl_client_socket,
+ MockSSLClientSocket* ssl_client_socket,
const CompletionCallback& callback,
int rv) {
if (rv == OK)
@@ -1287,7 +1289,7 @@ void MockSSLClientSocket::ConnectCallback(
}
MockSSLClientSocket::MockSSLClientSocket(
- ClientSocketHandle* transport_socket,
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_port_pair,
const SSLConfig& ssl_config,
SSLSocketDataProvider* data)
@@ -1295,7 +1297,7 @@ MockSSLClientSocket::MockSSLClientSocket(
// Have to use the right BoundNetLog for LoadTimingInfo regression
// tests.
transport_socket->socket()->NetLog()),
- transport_(transport_socket),
+ transport_(transport_socket.Pass()),
data_(data),
is_npn_state_set_(false),
new_npn_value_(false),
@@ -1664,10 +1666,10 @@ void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) {
}
MockTransportClientSocketPool::MockConnectJob::MockConnectJob(
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
ClientSocketHandle* handle,
const CompletionCallback& callback)
- : socket_(socket),
+ : socket_(socket.Pass()),
handle_(handle),
user_callback_(callback) {
}
@@ -1698,7 +1700,7 @@ void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) {
if (!socket_.get())
return;
if (rv == OK) {
- handle_->set_socket(socket_.release());
+ handle_->SetSocket(socket_.Pass());
// Needed for socket pool tests that layer other sockets on top of mock
// sockets.
@@ -1730,6 +1732,7 @@ MockTransportClientSocketPool::MockTransportClientSocketPool(
: TransportClientSocketPool(max_sockets, max_sockets_per_group, histograms,
NULL, NULL, NULL),
client_socket_factory_(socket_factory),
+ last_request_priority_(DEFAULT_PRIORITY),
release_count_(0),
cancel_count_(0) {
}
@@ -1740,9 +1743,11 @@ int MockTransportClientSocketPool::RequestSocket(
const std::string& group_name, const void* socket_params,
RequestPriority priority, ClientSocketHandle* handle,
const CompletionCallback& callback, const BoundNetLog& net_log) {
- StreamSocket* socket = client_socket_factory_->CreateTransportClientSocket(
- AddressList(), net_log.net_log(), net::NetLog::Source());
- MockConnectJob* job = new MockConnectJob(socket, handle, callback);
+ last_request_priority_ = priority;
+ scoped_ptr<StreamSocket> socket =
+ client_socket_factory_->CreateTransportClientSocket(
+ AddressList(), net_log.net_log(), net::NetLog::Source());
+ MockConnectJob* job = new MockConnectJob(socket.Pass(), handle, callback);
job_list_.push_back(job);
handle->set_pool_id(1);
return job->Connect();
@@ -1759,11 +1764,12 @@ void MockTransportClientSocketPool::CancelRequest(const std::string& group_name,
}
}
-void MockTransportClientSocketPool::ReleaseSocket(const std::string& group_name,
- StreamSocket* socket, int id) {
+void MockTransportClientSocketPool::ReleaseSocket(
+ const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) {
EXPECT_EQ(1, id);
release_count_++;
- delete socket;
}
DeterministicMockClientSocketFactory::DeterministicMockClientSocketFactory() {}
@@ -1791,42 +1797,45 @@ MockSSLClientSocket* DeterministicMockClientSocketFactory::
return ssl_client_sockets_[index];
}
-DatagramClientSocket*
+scoped_ptr<DatagramClientSocket>
DeterministicMockClientSocketFactory::CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
net::NetLog* net_log,
const NetLog::Source& source) {
DeterministicSocketData* data_provider = mock_data().GetNext();
- DeterministicMockUDPClientSocket* socket =
- new DeterministicMockUDPClientSocket(net_log, data_provider);
+ scoped_ptr<DeterministicMockUDPClientSocket> socket(
+ new DeterministicMockUDPClientSocket(net_log, data_provider));
data_provider->set_delegate(socket->AsWeakPtr());
- udp_client_sockets().push_back(socket);
- return socket;
+ udp_client_sockets().push_back(socket.get());
+ return socket.PassAs<DatagramClientSocket>();
}
-StreamSocket* DeterministicMockClientSocketFactory::CreateTransportClientSocket(
+scoped_ptr<StreamSocket>
+DeterministicMockClientSocketFactory::CreateTransportClientSocket(
const AddressList& addresses,
net::NetLog* net_log,
const net::NetLog::Source& source) {
DeterministicSocketData* data_provider = mock_data().GetNext();
- DeterministicMockTCPClientSocket* socket =
- new DeterministicMockTCPClientSocket(net_log, data_provider);
+ scoped_ptr<DeterministicMockTCPClientSocket> socket(
+ new DeterministicMockTCPClientSocket(net_log, data_provider));
data_provider->set_delegate(socket->AsWeakPtr());
- tcp_client_sockets().push_back(socket);
- return socket;
+ tcp_client_sockets().push_back(socket.get());
+ return socket.PassAs<StreamSocket>();
}
-SSLClientSocket* DeterministicMockClientSocketFactory::CreateSSLClientSocket(
- ClientSocketHandle* transport_socket,
+scoped_ptr<SSLClientSocket>
+DeterministicMockClientSocketFactory::CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context) {
- MockSSLClientSocket* socket =
- new MockSSLClientSocket(transport_socket, host_and_port, ssl_config,
- mock_ssl_data_.GetNext());
- ssl_client_sockets_.push_back(socket);
- return socket;
+ scoped_ptr<MockSSLClientSocket> socket(
+ new MockSSLClientSocket(transport_socket.Pass(),
+ host_and_port, ssl_config,
+ mock_ssl_data_.GetNext()));
+ ssl_client_sockets_.push_back(socket.get());
+ return socket.PassAs<SSLClientSocket>();
}
void DeterministicMockClientSocketFactory::ClearSSLSessionCache() {
@@ -1859,8 +1868,9 @@ void MockSOCKSClientSocketPool::CancelRequest(
}
void MockSOCKSClientSocketPool::ReleaseSocket(const std::string& group_name,
- StreamSocket* socket, int id) {
- return transport_pool_->ReleaseSocket(group_name, socket, id);
+ scoped_ptr<StreamSocket> socket,
+ int id) {
+ return transport_pool_->ReleaseSocket(group_name, socket.Pass(), id);
}
const char kSOCKS5GreetRequest[] = { 0x05, 0x01, 0x00 };
diff --git a/chromium/net/socket/socket_test_util.h b/chromium/net/socket/socket_test_util.h
index 6afe170299e..e4e56522c92 100644
--- a/chromium/net/socket/socket_test_util.h
+++ b/chromium/net/socket/socket_test_util.h
@@ -13,6 +13,7 @@
#include "base/basictypes.h"
#include "base/callback.h"
#include "base/logging.h"
+#include "base/memory/ref_counted.h"
#include "base/memory/scoped_ptr.h"
#include "base/memory/scoped_vector.h"
#include "base/memory/weak_ptr.h"
@@ -592,17 +593,17 @@ class MockClientSocketFactory : public ClientSocketFactory {
}
// ClientSocketFactory
- virtual DatagramClientSocket* CreateDatagramClientSocket(
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
NetLog* net_log,
const NetLog::Source& source) OVERRIDE;
- virtual StreamSocket* CreateTransportClientSocket(
+ virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
const AddressList& addresses,
NetLog* net_log,
const NetLog::Source& source) OVERRIDE;
- virtual SSLClientSocket* CreateSSLClientSocket(
- ClientSocketHandle* transport_socket,
+ virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context) OVERRIDE;
@@ -857,7 +858,7 @@ class DeterministicMockTCPClientSocket
class MockSSLClientSocket : public MockClientSocket, public AsyncSocket {
public:
MockSSLClientSocket(
- ClientSocketHandle* transport_socket,
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
SSLSocketDataProvider* socket);
@@ -1001,11 +1002,12 @@ class ClientSocketPoolTest {
ClientSocketPoolTest();
~ClientSocketPoolTest();
- template <typename PoolType, typename SocketParams>
- int StartRequestUsingPool(PoolType* socket_pool,
- const std::string& group_name,
- RequestPriority priority,
- const scoped_refptr<SocketParams>& socket_params) {
+ template <typename PoolType>
+ int StartRequestUsingPool(
+ PoolType* socket_pool,
+ const std::string& group_name,
+ RequestPriority priority,
+ const scoped_refptr<typename PoolType::SocketParams>& socket_params) {
DCHECK(socket_pool);
TestSocketRequest* request = new TestSocketRequest(&request_order_,
&completion_count_);
@@ -1045,11 +1047,20 @@ class ClientSocketPoolTest {
size_t completion_count_;
};
+class MockTransportSocketParams
+ : public base::RefCounted<MockTransportSocketParams> {
+ private:
+ friend class base::RefCounted<MockTransportSocketParams>;
+ ~MockTransportSocketParams() {}
+};
+
class MockTransportClientSocketPool : public TransportClientSocketPool {
public:
+ typedef MockTransportSocketParams SocketParams;
+
class MockConnectJob {
public:
- MockConnectJob(StreamSocket* socket, ClientSocketHandle* handle,
+ MockConnectJob(scoped_ptr<StreamSocket> socket, ClientSocketHandle* handle,
const CompletionCallback& callback);
~MockConnectJob();
@@ -1074,6 +1085,9 @@ class MockTransportClientSocketPool : public TransportClientSocketPool {
virtual ~MockTransportClientSocketPool();
+ RequestPriority last_request_priority() const {
+ return last_request_priority_;
+ }
int release_count() const { return release_count_; }
int cancel_count() const { return cancel_count_; }
@@ -1088,11 +1102,13 @@ class MockTransportClientSocketPool : public TransportClientSocketPool {
virtual void CancelRequest(const std::string& group_name,
ClientSocketHandle* handle) OVERRIDE;
virtual void ReleaseSocket(const std::string& group_name,
- StreamSocket* socket, int id) OVERRIDE;
+ scoped_ptr<StreamSocket> socket,
+ int id) OVERRIDE;
private:
ClientSocketFactory* client_socket_factory_;
ScopedVector<MockConnectJob> job_list_;
+ RequestPriority last_request_priority_;
int release_count_;
int cancel_count_;
@@ -1123,17 +1139,17 @@ class DeterministicMockClientSocketFactory : public ClientSocketFactory {
}
// ClientSocketFactory
- virtual DatagramClientSocket* CreateDatagramClientSocket(
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
NetLog* net_log,
const NetLog::Source& source) OVERRIDE;
- virtual StreamSocket* CreateTransportClientSocket(
+ virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
const AddressList& addresses,
NetLog* net_log,
const NetLog::Source& source) OVERRIDE;
- virtual SSLClientSocket* CreateSSLClientSocket(
- ClientSocketHandle* transport_socket,
+ virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context) OVERRIDE;
@@ -1170,7 +1186,8 @@ class MockSOCKSClientSocketPool : public SOCKSClientSocketPool {
virtual void CancelRequest(const std::string& group_name,
ClientSocketHandle* handle) OVERRIDE;
virtual void ReleaseSocket(const std::string& group_name,
- StreamSocket* socket, int id) OVERRIDE;
+ scoped_ptr<StreamSocket> socket,
+ int id) OVERRIDE;
private:
TransportClientSocketPool* const transport_pool_;
diff --git a/chromium/net/socket/socks5_client_socket.cc b/chromium/net/socket/socks5_client_socket.cc
index c9d25bc3dcb..537b584a932 100644
--- a/chromium/net/socket/socks5_client_socket.cc
+++ b/chromium/net/socket/socks5_client_socket.cc
@@ -28,34 +28,18 @@ COMPILE_ASSERT(sizeof(struct in_addr) == 4, incorrect_system_size_of_IPv4);
COMPILE_ASSERT(sizeof(struct in6_addr) == 16, incorrect_system_size_of_IPv6);
SOCKS5ClientSocket::SOCKS5ClientSocket(
- ClientSocketHandle* transport_socket,
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostResolver::RequestInfo& req_info)
: io_callback_(base::Bind(&SOCKS5ClientSocket::OnIOComplete,
base::Unretained(this))),
- transport_(transport_socket),
+ transport_(transport_socket.Pass()),
next_state_(STATE_NONE),
completed_handshake_(false),
bytes_sent_(0),
bytes_received_(0),
read_header_size(kReadHeaderSize),
host_request_info_(req_info),
- net_log_(transport_socket->socket()->NetLog()) {
-}
-
-SOCKS5ClientSocket::SOCKS5ClientSocket(
- StreamSocket* transport_socket,
- const HostResolver::RequestInfo& req_info)
- : io_callback_(base::Bind(&SOCKS5ClientSocket::OnIOComplete,
- base::Unretained(this))),
- transport_(new ClientSocketHandle()),
- next_state_(STATE_NONE),
- completed_handshake_(false),
- bytes_sent_(0),
- bytes_received_(0),
- read_header_size(kReadHeaderSize),
- host_request_info_(req_info),
- net_log_(transport_socket->NetLog()) {
- transport_->set_socket(transport_socket);
+ net_log_(transport_->socket()->NetLog()) {
}
SOCKS5ClientSocket::~SOCKS5ClientSocket() {
diff --git a/chromium/net/socket/socks5_client_socket.h b/chromium/net/socket/socks5_client_socket.h
index b955e8f42de..45216244f10 100644
--- a/chromium/net/socket/socks5_client_socket.h
+++ b/chromium/net/socket/socks5_client_socket.h
@@ -28,20 +28,13 @@ class BoundNetLog;
// Currently no SOCKSv5 authentication is supported.
class NET_EXPORT_PRIVATE SOCKS5ClientSocket : public StreamSocket {
public:
- // Takes ownership of the |transport_socket|, which should already be
- // connected by the time Connect() is called.
- //
// |req_info| contains the hostname and port to which the socket above will
// communicate to via the SOCKS layer.
//
// Although SOCKS 5 supports 3 different modes of addressing, we will
// always pass it a hostname. This means the DNS resolving is done
// proxy side.
- SOCKS5ClientSocket(ClientSocketHandle* transport_socket,
- const HostResolver::RequestInfo& req_info);
-
- // Deprecated constructor (http://crbug.com/37810) that takes a StreamSocket.
- SOCKS5ClientSocket(StreamSocket* transport_socket,
+ SOCKS5ClientSocket(scoped_ptr<ClientSocketHandle> transport_socket,
const HostResolver::RequestInfo& req_info);
// On destruction Disconnect() is called.
diff --git a/chromium/net/socket/socks5_client_socket_unittest.cc b/chromium/net/socket/socks5_client_socket_unittest.cc
index 717d858eef8..78f2ac433c3 100644
--- a/chromium/net/socket/socks5_client_socket_unittest.cc
+++ b/chromium/net/socket/socks5_client_socket_unittest.cc
@@ -32,13 +32,13 @@ class SOCKS5ClientSocketTest : public PlatformTest {
public:
SOCKS5ClientSocketTest();
// Create a SOCKSClientSocket on top of a MockSocket.
- SOCKS5ClientSocket* BuildMockSocket(MockRead reads[],
- size_t reads_count,
- MockWrite writes[],
- size_t writes_count,
- const std::string& hostname,
- int port,
- NetLog* net_log);
+ scoped_ptr<SOCKS5ClientSocket> BuildMockSocket(MockRead reads[],
+ size_t reads_count,
+ MockWrite writes[],
+ size_t writes_count,
+ const std::string& hostname,
+ int port,
+ NetLog* net_log);
virtual void SetUp();
@@ -47,6 +47,8 @@ class SOCKS5ClientSocketTest : public PlatformTest {
CapturingNetLog net_log_;
scoped_ptr<SOCKS5ClientSocket> user_sock_;
AddressList address_list_;
+ // Filled in by BuildMockSocket() and owned by its return value
+ // (which |user_sock| is set to).
StreamSocket* tcp_sock_;
TestCompletionCallback callback_;
scoped_ptr<MockHostResolver> host_resolver_;
@@ -68,14 +70,18 @@ void SOCKS5ClientSocketTest::SetUp() {
// Resolve the "localhost" AddressList used by the TCP connection to connect.
HostResolver::RequestInfo info(HostPortPair("www.socks-proxy.com", 1080));
TestCompletionCallback callback;
- int rv = host_resolver_->Resolve(info, &address_list_, callback.callback(),
- NULL, BoundNetLog());
+ int rv = host_resolver_->Resolve(info,
+ DEFAULT_PRIORITY,
+ &address_list_,
+ callback.callback(),
+ NULL,
+ BoundNetLog());
ASSERT_EQ(ERR_IO_PENDING, rv);
rv = callback.WaitForResult();
ASSERT_EQ(OK, rv);
}
-SOCKS5ClientSocket* SOCKS5ClientSocketTest::BuildMockSocket(
+scoped_ptr<SOCKS5ClientSocket> SOCKS5ClientSocketTest::BuildMockSocket(
MockRead reads[],
size_t reads_count,
MockWrite writes[],
@@ -94,8 +100,13 @@ SOCKS5ClientSocket* SOCKS5ClientSocketTest::BuildMockSocket(
EXPECT_EQ(OK, rv);
EXPECT_TRUE(tcp_sock_->IsConnected());
- return new SOCKS5ClientSocket(tcp_sock_,
- HostResolver::RequestInfo(HostPortPair(hostname, port)));
+ scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
+ // |connection| takes ownership of |tcp_sock_|, but keep a
+ // non-owning pointer to it.
+ connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_));
+ return scoped_ptr<SOCKS5ClientSocket>(new SOCKS5ClientSocket(
+ connection.Pass(),
+ HostResolver::RequestInfo(HostPortPair(hostname, port))));
}
// Tests a complete SOCKS5 handshake and the disconnection.
@@ -123,9 +134,9 @@ TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) {
MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength),
MockRead(ASYNC, payload_read.data(), payload_read.size()) };
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- "localhost", 80, &net_log_));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ "localhost", 80, &net_log_);
// At this state the TCP connection is completed but not the SOCKS handshake.
EXPECT_TRUE(tcp_sock_->IsConnected());
@@ -195,9 +206,9 @@ TEST_F(SOCKS5ClientSocketTest, ConnectAndDisconnectTwice) {
MockRead(SYNCHRONOUS, kSOCKS5OkResponse, kSOCKS5OkResponseLength)
};
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- hostname, 80, NULL));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ hostname, 80, NULL);
int rv = user_sock_->Connect(callback_.callback());
EXPECT_EQ(OK, rv);
@@ -217,9 +228,9 @@ TEST_F(SOCKS5ClientSocketTest, LargeHostNameFails) {
// Create a SOCKS socket, with mock transport socket.
MockWrite data_writes[] = {MockWrite()};
MockRead data_reads[] = {MockRead()};
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- large_host_name, 80, NULL));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ large_host_name, 80, NULL);
// Try to connect -- should fail (without having read/written anything to
// the transport socket first) because the hostname is too long.
@@ -253,9 +264,9 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) {
MockRead data_reads[] = {
MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- hostname, 80, &net_log_));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ hostname, 80, &net_log_);
int rv = user_sock_->Connect(callback_.callback());
EXPECT_EQ(ERR_IO_PENDING, rv);
@@ -284,9 +295,9 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) {
MockRead(ASYNC, partial1, arraysize(partial1)),
MockRead(ASYNC, partial2, arraysize(partial2)),
MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- hostname, 80, &net_log_));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ hostname, 80, &net_log_);
int rv = user_sock_->Connect(callback_.callback());
EXPECT_EQ(ERR_IO_PENDING, rv);
@@ -314,9 +325,9 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) {
MockRead data_reads[] = {
MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- hostname, 80, &net_log_));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ hostname, 80, &net_log_);
int rv = user_sock_->Connect(callback_.callback());
EXPECT_EQ(ERR_IO_PENDING, rv);
CapturingNetLog::CapturedEntryList net_log_entries;
@@ -345,9 +356,9 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) {
kSOCKS5OkResponseLength - kSplitPoint)
};
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- hostname, 80, &net_log_));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ hostname, 80, &net_log_);
int rv = user_sock_->Connect(callback_.callback());
EXPECT_EQ(ERR_IO_PENDING, rv);
CapturingNetLog::CapturedEntryList net_log_entries;
diff --git a/chromium/net/socket/socks_client_socket.cc b/chromium/net/socket/socks_client_socket.cc
index c4bbd28c619..67089589cc5 100644
--- a/chromium/net/socket/socks_client_socket.cc
+++ b/chromium/net/socket/socks_client_socket.cc
@@ -55,32 +55,20 @@ struct SOCKS4ServerResponse {
COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize,
socks4_server_response_struct_wrong_size);
-SOCKSClientSocket::SOCKSClientSocket(ClientSocketHandle* transport_socket,
- const HostResolver::RequestInfo& req_info,
- HostResolver* host_resolver)
- : transport_(transport_socket),
+SOCKSClientSocket::SOCKSClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostResolver::RequestInfo& req_info,
+ RequestPriority priority,
+ HostResolver* host_resolver)
+ : transport_(transport_socket.Pass()),
next_state_(STATE_NONE),
completed_handshake_(false),
bytes_sent_(0),
bytes_received_(0),
host_resolver_(host_resolver),
host_request_info_(req_info),
- net_log_(transport_socket->socket()->NetLog()) {
-}
-
-SOCKSClientSocket::SOCKSClientSocket(StreamSocket* transport_socket,
- const HostResolver::RequestInfo& req_info,
- HostResolver* host_resolver)
- : transport_(new ClientSocketHandle()),
- next_state_(STATE_NONE),
- completed_handshake_(false),
- bytes_sent_(0),
- bytes_received_(0),
- host_resolver_(host_resolver),
- host_request_info_(req_info),
- net_log_(transport_socket->NetLog()) {
- transport_->set_socket(transport_socket);
-}
+ priority_(priority),
+ net_log_(transport_->socket()->NetLog()) {}
SOCKSClientSocket::~SOCKSClientSocket() {
Disconnect();
@@ -283,7 +271,9 @@ int SOCKSClientSocket::DoResolveHost() {
// addresses for the target host.
host_request_info_.set_address_family(ADDRESS_FAMILY_IPV4);
return host_resolver_.Resolve(
- host_request_info_, &addresses_,
+ host_request_info_,
+ priority_,
+ &addresses_,
base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)),
net_log_);
}
diff --git a/chromium/net/socket/socks_client_socket.h b/chromium/net/socket/socks_client_socket.h
index 3d4f9fc2771..d4f058a62b1 100644
--- a/chromium/net/socket/socks_client_socket.h
+++ b/chromium/net/socket/socks_client_socket.h
@@ -27,18 +27,11 @@ class BoundNetLog;
// The SOCKS client socket implementation
class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket {
public:
- // Takes ownership of the |transport_socket|, which should already be
- // connected by the time Connect() is called.
- //
// |req_info| contains the hostname and port to which the socket above will
// communicate to via the socks layer. For testing the referrer is optional.
- SOCKSClientSocket(ClientSocketHandle* transport_socket,
- const HostResolver::RequestInfo& req_info,
- HostResolver* host_resolver);
-
- // Deprecated constructor (http://crbug.com/37810) that takes a StreamSocket.
- SOCKSClientSocket(StreamSocket* transport_socket,
+ SOCKSClientSocket(scoped_ptr<ClientSocketHandle> transport_socket,
const HostResolver::RequestInfo& req_info,
+ RequestPriority priority,
HostResolver* host_resolver);
// On destruction Disconnect() is called.
@@ -131,6 +124,7 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket {
SingleRequestHostResolver host_resolver_;
AddressList addresses_;
HostResolver::RequestInfo host_request_info_;
+ RequestPriority priority_;
BoundNetLog net_log_;
diff --git a/chromium/net/socket/socks_client_socket_pool.cc b/chromium/net/socket/socks_client_socket_pool.cc
index d740e5b9a0e..e11b7a48db5 100644
--- a/chromium/net/socket/socks_client_socket_pool.cc
+++ b/chromium/net/socket/socks_client_socket_pool.cc
@@ -21,8 +21,7 @@ namespace net {
SOCKSSocketParams::SOCKSSocketParams(
const scoped_refptr<TransportSocketParams>& proxy_server,
bool socks_v5,
- const HostPortPair& host_port_pair,
- RequestPriority priority)
+ const HostPortPair& host_port_pair)
: transport_params_(proxy_server),
destination_(host_port_pair),
socks_v5_(socks_v5) {
@@ -30,7 +29,6 @@ SOCKSSocketParams::SOCKSSocketParams(
ignore_limits_ = transport_params_->ignore_limits();
else
ignore_limits_ = false;
- destination_.set_priority(priority);
}
SOCKSSocketParams::~SOCKSSocketParams() {}
@@ -41,13 +39,14 @@ static const int kSOCKSConnectJobTimeoutInSeconds = 30;
SOCKSConnectJob::SOCKSConnectJob(
const std::string& group_name,
+ RequestPriority priority,
const scoped_refptr<SOCKSSocketParams>& socks_params,
const base::TimeDelta& timeout_duration,
TransportClientSocketPool* transport_pool,
HostResolver* host_resolver,
Delegate* delegate,
NetLog* net_log)
- : ConnectJob(group_name, timeout_duration, delegate,
+ : ConnectJob(group_name, timeout_duration, priority, delegate,
BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)),
socks_params_(socks_params),
transport_pool_(transport_pool),
@@ -117,10 +116,12 @@ int SOCKSConnectJob::DoLoop(int result) {
int SOCKSConnectJob::DoTransportConnect() {
next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE;
transport_socket_handle_.reset(new ClientSocketHandle());
- return transport_socket_handle_->Init(
- group_name(), socks_params_->transport_params(),
- socks_params_->destination().priority(), callback_, transport_pool_,
- net_log());
+ return transport_socket_handle_->Init(group_name(),
+ socks_params_->transport_params(),
+ priority(),
+ callback_,
+ transport_pool_,
+ net_log());
}
int SOCKSConnectJob::DoTransportConnectComplete(int result) {
@@ -140,11 +141,12 @@ int SOCKSConnectJob::DoSOCKSConnect() {
// Add a SOCKS connection on top of the tcp socket.
if (socks_params_->is_socks_v5()) {
- socket_.reset(new SOCKS5ClientSocket(transport_socket_handle_.release(),
+ socket_.reset(new SOCKS5ClientSocket(transport_socket_handle_.Pass(),
socks_params_->destination()));
} else {
- socket_.reset(new SOCKSClientSocket(transport_socket_handle_.release(),
+ socket_.reset(new SOCKSClientSocket(transport_socket_handle_.Pass(),
socks_params_->destination(),
+ priority(),
resolver_));
}
return socket_->Connect(
@@ -157,7 +159,7 @@ int SOCKSConnectJob::DoSOCKSConnectComplete(int result) {
return result;
}
- set_socket(socket_.release());
+ SetSocket(socket_.Pass());
return result;
}
@@ -166,17 +168,19 @@ int SOCKSConnectJob::ConnectInternal() {
return DoLoop(OK);
}
-ConnectJob* SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob(
+scoped_ptr<ConnectJob>
+SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob(
const std::string& group_name,
const PoolBase::Request& request,
ConnectJob::Delegate* delegate) const {
- return new SOCKSConnectJob(group_name,
- request.params(),
- ConnectionTimeout(),
- transport_pool_,
- host_resolver_,
- delegate,
- net_log_);
+ return scoped_ptr<ConnectJob>(new SOCKSConnectJob(group_name,
+ request.priority(),
+ request.params(),
+ ConnectionTimeout(),
+ transport_pool_,
+ host_resolver_,
+ delegate,
+ net_log_));
}
base::TimeDelta
@@ -193,7 +197,7 @@ SOCKSClientSocketPool::SOCKSClientSocketPool(
TransportClientSocketPool* transport_pool,
NetLog* net_log)
: transport_pool_(transport_pool),
- base_(max_sockets, max_sockets_per_group, histograms,
+ base_(this, max_sockets, max_sockets_per_group, histograms,
ClientSocketPool::unused_idle_socket_timeout(),
ClientSocketPool::used_idle_socket_timeout(),
new SOCKSConnectJobFactory(transport_pool,
@@ -201,13 +205,10 @@ SOCKSClientSocketPool::SOCKSClientSocketPool(
net_log)) {
// We should always have a |transport_pool_| except in unit tests.
if (transport_pool_)
- transport_pool_->AddLayeredPool(this);
+ base_.AddLowerLayeredPool(transport_pool_);
}
SOCKSClientSocketPool::~SOCKSClientSocketPool() {
- // We should always have a |transport_pool_| except in unit tests.
- if (transport_pool_)
- transport_pool_->RemoveLayeredPool(this);
}
int SOCKSClientSocketPool::RequestSocket(
@@ -238,18 +239,15 @@ void SOCKSClientSocketPool::CancelRequest(const std::string& group_name,
}
void SOCKSClientSocketPool::ReleaseSocket(const std::string& group_name,
- StreamSocket* socket, int id) {
- base_.ReleaseSocket(group_name, socket, id);
+ scoped_ptr<StreamSocket> socket,
+ int id) {
+ base_.ReleaseSocket(group_name, socket.Pass(), id);
}
void SOCKSClientSocketPool::FlushWithError(int error) {
base_.FlushWithError(error);
}
-bool SOCKSClientSocketPool::IsStalled() const {
- return base_.IsStalled() || transport_pool_->IsStalled();
-}
-
void SOCKSClientSocketPool::CloseIdleSockets() {
base_.CloseIdleSockets();
}
@@ -268,14 +266,6 @@ LoadState SOCKSClientSocketPool::GetLoadState(
return base_.GetLoadState(group_name, handle);
}
-void SOCKSClientSocketPool::AddLayeredPool(LayeredPool* layered_pool) {
- base_.AddLayeredPool(layered_pool);
-}
-
-void SOCKSClientSocketPool::RemoveLayeredPool(LayeredPool* layered_pool) {
- base_.RemoveLayeredPool(layered_pool);
-}
-
base::DictionaryValue* SOCKSClientSocketPool::GetInfoAsValue(
const std::string& name,
const std::string& type,
@@ -299,10 +289,24 @@ ClientSocketPoolHistograms* SOCKSClientSocketPool::histograms() const {
return base_.histograms();
};
+bool SOCKSClientSocketPool::IsStalled() const {
+ return base_.IsStalled();
+}
+
+void SOCKSClientSocketPool::AddHigherLayeredPool(
+ HigherLayeredPool* higher_pool) {
+ base_.AddHigherLayeredPool(higher_pool);
+}
+
+void SOCKSClientSocketPool::RemoveHigherLayeredPool(
+ HigherLayeredPool* higher_pool) {
+ base_.RemoveHigherLayeredPool(higher_pool);
+}
+
bool SOCKSClientSocketPool::CloseOneIdleConnection() {
if (base_.CloseOneIdleSocket())
return true;
- return base_.CloseOneIdleConnectionInLayeredPool();
+ return base_.CloseOneIdleConnectionInHigherLayeredPool();
}
} // namespace net
diff --git a/chromium/net/socket/socks_client_socket_pool.h b/chromium/net/socket/socks_client_socket_pool.h
index 86609a1a5a0..c6d5c8d0883 100644
--- a/chromium/net/socket/socks_client_socket_pool.h
+++ b/chromium/net/socket/socks_client_socket_pool.h
@@ -28,8 +28,7 @@ class NET_EXPORT_PRIVATE SOCKSSocketParams
: public base::RefCounted<SOCKSSocketParams> {
public:
SOCKSSocketParams(const scoped_refptr<TransportSocketParams>& proxy_server,
- bool socks_v5, const HostPortPair& host_port_pair,
- RequestPriority priority);
+ bool socks_v5, const HostPortPair& host_port_pair);
const scoped_refptr<TransportSocketParams>& transport_params() const {
return transport_params_;
@@ -57,6 +56,7 @@ class NET_EXPORT_PRIVATE SOCKSSocketParams
class SOCKSConnectJob : public ConnectJob {
public:
SOCKSConnectJob(const std::string& group_name,
+ RequestPriority priority,
const scoped_refptr<SOCKSSocketParams>& params,
const base::TimeDelta& timeout_duration,
TransportClientSocketPool* transport_pool,
@@ -105,8 +105,10 @@ class SOCKSConnectJob : public ConnectJob {
};
class NET_EXPORT_PRIVATE SOCKSClientSocketPool
- : public ClientSocketPool, public LayeredPool {
+ : public ClientSocketPool, public HigherLayeredPool {
public:
+ typedef SOCKSSocketParams SocketParams;
+
SOCKSClientSocketPool(
int max_sockets,
int max_sockets_per_group,
@@ -134,13 +136,11 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool
ClientSocketHandle* handle) OVERRIDE;
virtual void ReleaseSocket(const std::string& group_name,
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
int id) OVERRIDE;
virtual void FlushWithError(int error) OVERRIDE;
- virtual bool IsStalled() const OVERRIDE;
-
virtual void CloseIdleSockets() OVERRIDE;
virtual int IdleSocketCount() const OVERRIDE;
@@ -152,10 +152,6 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool
const std::string& group_name,
const ClientSocketHandle* handle) const OVERRIDE;
- virtual void AddLayeredPool(LayeredPool* layered_pool) OVERRIDE;
-
- virtual void RemoveLayeredPool(LayeredPool* layered_pool) OVERRIDE;
-
virtual base::DictionaryValue* GetInfoAsValue(
const std::string& name,
const std::string& type,
@@ -165,7 +161,14 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool
virtual ClientSocketPoolHistograms* histograms() const OVERRIDE;
- // LayeredPool implementation.
+ // LowerLayeredPool implementation.
+ virtual bool IsStalled() const OVERRIDE;
+
+ virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE;
+
+ virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE;
+
+ // HigherLayeredPool implementation.
virtual bool CloseOneIdleConnection() OVERRIDE;
private:
@@ -183,7 +186,7 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool
virtual ~SOCKSConnectJobFactory() {}
// ClientSocketPoolBase::ConnectJobFactory methods.
- virtual ConnectJob* NewConnectJob(
+ virtual scoped_ptr<ConnectJob> NewConnectJob(
const std::string& group_name,
const PoolBase::Request& request,
ConnectJob::Delegate* delegate) const OVERRIDE;
@@ -204,8 +207,6 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool
DISALLOW_COPY_AND_ASSIGN(SOCKSClientSocketPool);
};
-REGISTER_SOCKET_PARAMS_FOR_POOL(SOCKSClientSocketPool, SOCKSSocketParams);
-
} // namespace net
#endif // NET_SOCKET_SOCKS_CLIENT_SOCKET_POOL_H_
diff --git a/chromium/net/socket/socks_client_socket_pool_unittest.cc b/chromium/net/socket/socks_client_socket_pool_unittest.cc
index 77440d36a19..4463e171f84 100644
--- a/chromium/net/socket/socks_client_socket_pool_unittest.cc
+++ b/chromium/net/socket/socks_client_socket_pool_unittest.cc
@@ -41,6 +41,25 @@ void TestLoadTimingInfo(const ClientSocketHandle& handle) {
ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
}
+
+scoped_refptr<TransportSocketParams> CreateProxyHostParams() {
+ return new TransportSocketParams(
+ HostPortPair("proxy", 80), false, false,
+ OnHostResolutionCallback());
+}
+
+scoped_refptr<SOCKSSocketParams> CreateSOCKSv4Params() {
+ return new SOCKSSocketParams(
+ CreateProxyHostParams(), false /* socks_v5 */,
+ HostPortPair("host", 80));
+}
+
+scoped_refptr<SOCKSSocketParams> CreateSOCKSv5Params() {
+ return new SOCKSSocketParams(
+ CreateProxyHostParams(), true /* socks_v5 */,
+ HostPortPair("host", 80));
+}
+
class SOCKSClientSocketPoolTest : public testing::Test {
protected:
class SOCKS5MockData {
@@ -71,30 +90,24 @@ class SOCKSClientSocketPoolTest : public testing::Test {
};
SOCKSClientSocketPoolTest()
- : ignored_transport_socket_params_(new TransportSocketParams(
- HostPortPair("proxy", 80), MEDIUM, false, false,
- OnHostResolutionCallback())),
- transport_histograms_("MockTCP"),
+ : transport_histograms_("MockTCP"),
transport_socket_pool_(
kMaxSockets, kMaxSocketsPerGroup,
&transport_histograms_,
&transport_client_socket_factory_),
- ignored_socket_params_(new SOCKSSocketParams(
- ignored_transport_socket_params_, true, HostPortPair("host", 80),
- MEDIUM)),
socks_histograms_("SOCKSUnitTest"),
pool_(kMaxSockets, kMaxSocketsPerGroup,
&socks_histograms_,
- NULL,
+ &host_resolver_,
&transport_socket_pool_,
NULL) {
}
virtual ~SOCKSClientSocketPoolTest() {}
- int StartRequest(const std::string& group_name, RequestPriority priority) {
+ int StartRequestV5(const std::string& group_name, RequestPriority priority) {
return test_base_.StartRequestUsingPool(
- &pool_, group_name, priority, ignored_socket_params_);
+ &pool_, group_name, priority, CreateSOCKSv5Params());
}
int GetOrderOfRequest(size_t index) const {
@@ -103,13 +116,12 @@ class SOCKSClientSocketPoolTest : public testing::Test {
ScopedVector<TestSocketRequest>* requests() { return test_base_.requests(); }
- scoped_refptr<TransportSocketParams> ignored_transport_socket_params_;
ClientSocketPoolHistograms transport_histograms_;
MockClientSocketFactory transport_client_socket_factory_;
MockTransportClientSocketPool transport_socket_pool_;
- scoped_refptr<SOCKSSocketParams> ignored_socket_params_;
ClientSocketPoolHistograms socks_histograms_;
+ MockHostResolver host_resolver_;
SOCKSClientSocketPool pool_;
ClientSocketPoolTest test_base_;
};
@@ -120,7 +132,7 @@ TEST_F(SOCKSClientSocketPoolTest, Simple) {
transport_client_socket_factory_.AddSocketDataProvider(data.data_provider());
ClientSocketHandle handle;
- int rv = handle.Init("a", ignored_socket_params_, LOW, CompletionCallback(),
+ int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, CompletionCallback(),
&pool_, BoundNetLog());
EXPECT_EQ(OK, rv);
EXPECT_TRUE(handle.is_initialized());
@@ -128,13 +140,52 @@ TEST_F(SOCKSClientSocketPoolTest, Simple) {
TestLoadTimingInfo(handle);
}
+// Make sure that SOCKSConnectJob passes on its priority to its
+// socket request on Init.
+TEST_F(SOCKSClientSocketPoolTest, SetSocketRequestPriorityOnInit) {
+ for (int i = MINIMUM_PRIORITY; i < NUM_PRIORITIES; ++i) {
+ RequestPriority priority = static_cast<RequestPriority>(i);
+ SOCKS5MockData data(SYNCHRONOUS);
+ data.data_provider()->set_connect_data(MockConnect(SYNCHRONOUS, OK));
+ transport_client_socket_factory_.AddSocketDataProvider(
+ data.data_provider());
+
+ ClientSocketHandle handle;
+ EXPECT_EQ(OK,
+ handle.Init("a", CreateSOCKSv5Params(), priority,
+ CompletionCallback(), &pool_, BoundNetLog()));
+ EXPECT_EQ(priority, transport_socket_pool_.last_request_priority());
+ handle.socket()->Disconnect();
+ }
+}
+
+// Make sure that SOCKSConnectJob passes on its priority to its
+// HostResolver request (for non-SOCKS5) on Init.
+TEST_F(SOCKSClientSocketPoolTest, SetResolvePriorityOnInit) {
+ for (int i = MINIMUM_PRIORITY; i < NUM_PRIORITIES; ++i) {
+ RequestPriority priority = static_cast<RequestPriority>(i);
+ SOCKS5MockData data(SYNCHRONOUS);
+ data.data_provider()->set_connect_data(MockConnect(SYNCHRONOUS, OK));
+ transport_client_socket_factory_.AddSocketDataProvider(
+ data.data_provider());
+
+ ClientSocketHandle handle;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle.Init("a", CreateSOCKSv4Params(), priority,
+ CompletionCallback(), &pool_, BoundNetLog()));
+ EXPECT_EQ(priority, transport_socket_pool_.last_request_priority());
+ EXPECT_EQ(priority, host_resolver_.last_request_priority());
+ EXPECT_TRUE(handle.socket() == NULL);
+ }
+}
+
TEST_F(SOCKSClientSocketPoolTest, Async) {
SOCKS5MockData data(ASYNC);
transport_client_socket_factory_.AddSocketDataProvider(data.data_provider());
TestCompletionCallback callback;
ClientSocketHandle handle;
- int rv = handle.Init("a", ignored_socket_params_, LOW, callback.callback(),
+ int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, callback.callback(),
&pool_, BoundNetLog());
EXPECT_EQ(ERR_IO_PENDING, rv);
EXPECT_FALSE(handle.is_initialized());
@@ -153,7 +204,7 @@ TEST_F(SOCKSClientSocketPoolTest, TransportConnectError) {
transport_client_socket_factory_.AddSocketDataProvider(&socket_data);
ClientSocketHandle handle;
- int rv = handle.Init("a", ignored_socket_params_, LOW, CompletionCallback(),
+ int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, CompletionCallback(),
&pool_, BoundNetLog());
EXPECT_EQ(ERR_PROXY_CONNECTION_FAILED, rv);
EXPECT_FALSE(handle.is_initialized());
@@ -167,7 +218,7 @@ TEST_F(SOCKSClientSocketPoolTest, AsyncTransportConnectError) {
TestCompletionCallback callback;
ClientSocketHandle handle;
- int rv = handle.Init("a", ignored_socket_params_, LOW, callback.callback(),
+ int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, callback.callback(),
&pool_, BoundNetLog());
EXPECT_EQ(ERR_IO_PENDING, rv);
EXPECT_FALSE(handle.is_initialized());
@@ -189,7 +240,7 @@ TEST_F(SOCKSClientSocketPoolTest, SOCKSConnectError) {
ClientSocketHandle handle;
EXPECT_EQ(0, transport_socket_pool_.release_count());
- int rv = handle.Init("a", ignored_socket_params_, LOW, CompletionCallback(),
+ int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, CompletionCallback(),
&pool_, BoundNetLog());
EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, rv);
EXPECT_FALSE(handle.is_initialized());
@@ -209,7 +260,7 @@ TEST_F(SOCKSClientSocketPoolTest, AsyncSOCKSConnectError) {
TestCompletionCallback callback;
ClientSocketHandle handle;
EXPECT_EQ(0, transport_socket_pool_.release_count());
- int rv = handle.Init("a", ignored_socket_params_, LOW, callback.callback(),
+ int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, callback.callback(),
&pool_, BoundNetLog());
EXPECT_EQ(ERR_IO_PENDING, rv);
EXPECT_FALSE(handle.is_initialized());
@@ -230,10 +281,10 @@ TEST_F(SOCKSClientSocketPoolTest, CancelDuringTransportConnect) {
transport_client_socket_factory_.AddSocketDataProvider(data2.data_provider());
EXPECT_EQ(0, transport_socket_pool_.cancel_count());
- int rv = StartRequest("a", LOW);
+ int rv = StartRequestV5("a", LOW);
EXPECT_EQ(ERR_IO_PENDING, rv);
- rv = StartRequest("a", LOW);
+ rv = StartRequestV5("a", LOW);
EXPECT_EQ(ERR_IO_PENDING, rv);
pool_.CancelRequest("a", (*requests())[0]->handle());
@@ -265,10 +316,10 @@ TEST_F(SOCKSClientSocketPoolTest, CancelDuringSOCKSConnect) {
EXPECT_EQ(0, transport_socket_pool_.cancel_count());
EXPECT_EQ(0, transport_socket_pool_.release_count());
- int rv = StartRequest("a", LOW);
+ int rv = StartRequestV5("a", LOW);
EXPECT_EQ(ERR_IO_PENDING, rv);
- rv = StartRequest("a", LOW);
+ rv = StartRequestV5("a", LOW);
EXPECT_EQ(ERR_IO_PENDING, rv);
pool_.CancelRequest("a", (*requests())[0]->handle());
diff --git a/chromium/net/socket/socks_client_socket_unittest.cc b/chromium/net/socket/socks_client_socket_unittest.cc
index 7a8faf69856..f361244feff 100644
--- a/chromium/net/socket/socks_client_socket_unittest.cc
+++ b/chromium/net/socket/socks_client_socket_unittest.cc
@@ -4,6 +4,7 @@
#include "net/socket/socks_client_socket.h"
+#include "base/memory/scoped_ptr.h"
#include "net/base/address_list.h"
#include "net/base/net_log.h"
#include "net/base/net_log_unittest.h"
@@ -27,16 +28,19 @@ class SOCKSClientSocketTest : public PlatformTest {
public:
SOCKSClientSocketTest();
// Create a SOCKSClientSocket on top of a MockSocket.
- SOCKSClientSocket* BuildMockSocket(MockRead reads[], size_t reads_count,
- MockWrite writes[], size_t writes_count,
- HostResolver* host_resolver,
- const std::string& hostname, int port,
- NetLog* net_log);
+ scoped_ptr<SOCKSClientSocket> BuildMockSocket(
+ MockRead reads[], size_t reads_count,
+ MockWrite writes[], size_t writes_count,
+ HostResolver* host_resolver,
+ const std::string& hostname, int port,
+ NetLog* net_log);
virtual void SetUp();
protected:
scoped_ptr<SOCKSClientSocket> user_sock_;
AddressList address_list_;
+ // Filled in by BuildMockSocket() and owned by its return value
+ // (which |user_sock| is set to).
StreamSocket* tcp_sock_;
TestCompletionCallback callback_;
scoped_ptr<MockHostResolver> host_resolver_;
@@ -52,7 +56,7 @@ void SOCKSClientSocketTest::SetUp() {
PlatformTest::SetUp();
}
-SOCKSClientSocket* SOCKSClientSocketTest::BuildMockSocket(
+scoped_ptr<SOCKSClientSocket> SOCKSClientSocketTest::BuildMockSocket(
MockRead reads[],
size_t reads_count,
MockWrite writes[],
@@ -73,9 +77,15 @@ SOCKSClientSocket* SOCKSClientSocketTest::BuildMockSocket(
EXPECT_EQ(OK, rv);
EXPECT_TRUE(tcp_sock_->IsConnected());
- return new SOCKSClientSocket(tcp_sock_,
+ scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
+ // |connection| takes ownership of |tcp_sock_|, but keep a
+ // non-owning pointer to it.
+ connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_));
+ return scoped_ptr<SOCKSClientSocket>(new SOCKSClientSocket(
+ connection.Pass(),
HostResolver::RequestInfo(HostPortPair(hostname, port)),
- host_resolver);
+ DEFAULT_PRIORITY,
+ host_resolver));
}
// Implementation of HostResolver that never completes its resolve request.
@@ -86,6 +96,7 @@ class HangingHostResolverWithCancel : public HostResolver {
HangingHostResolverWithCancel() : outstanding_request_(NULL) {}
virtual int Resolve(const RequestInfo& info,
+ RequestPriority priority,
AddressList* addresses,
const CompletionCallback& callback,
RequestHandle* out_req,
@@ -134,11 +145,11 @@ TEST_F(SOCKSClientSocketTest, CompleteHandshake) {
MockRead(ASYNC, payload_read.data(), payload_read.size()) };
CapturingNetLog log;
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- host_resolver_.get(),
- "localhost", 80,
- &log));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ host_resolver_.get(),
+ "localhost", 80,
+ &log);
// At this state the TCP connection is completed but not the SOCKS handshake.
EXPECT_TRUE(tcp_sock_->IsConnected());
@@ -210,11 +221,11 @@ TEST_F(SOCKSClientSocketTest, HandshakeFailures) {
arraysize(tests[i].fail_reply)) };
CapturingNetLog log;
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- host_resolver_.get(),
- "localhost", 80,
- &log));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ host_resolver_.get(),
+ "localhost", 80,
+ &log);
int rv = user_sock_->Connect(callback_.callback());
EXPECT_EQ(ERR_IO_PENDING, rv);
@@ -247,11 +258,11 @@ TEST_F(SOCKSClientSocketTest, PartialServerReads) {
MockRead(ASYNC, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) };
CapturingNetLog log;
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- host_resolver_.get(),
- "localhost", 80,
- &log));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ host_resolver_.get(),
+ "localhost", 80,
+ &log);
int rv = user_sock_->Connect(callback_.callback());
EXPECT_EQ(ERR_IO_PENDING, rv);
@@ -285,11 +296,11 @@ TEST_F(SOCKSClientSocketTest, PartialClientWrites) {
MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)) };
CapturingNetLog log;
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- host_resolver_.get(),
- "localhost", 80,
- &log));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ host_resolver_.get(),
+ "localhost", 80,
+ &log);
int rv = user_sock_->Connect(callback_.callback());
EXPECT_EQ(ERR_IO_PENDING, rv);
@@ -317,11 +328,11 @@ TEST_F(SOCKSClientSocketTest, FailedSocketRead) {
MockRead(SYNCHRONOUS, 0) };
CapturingNetLog log;
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- host_resolver_.get(),
- "localhost", 80,
- &log));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ host_resolver_.get(),
+ "localhost", 80,
+ &log);
int rv = user_sock_->Connect(callback_.callback());
EXPECT_EQ(ERR_IO_PENDING, rv);
@@ -347,11 +358,11 @@ TEST_F(SOCKSClientSocketTest, FailedDNS) {
CapturingNetLog log;
- user_sock_.reset(BuildMockSocket(NULL, 0,
- NULL, 0,
- host_resolver_.get(),
- hostname, 80,
- &log));
+ user_sock_ = BuildMockSocket(NULL, 0,
+ NULL, 0,
+ host_resolver_.get(),
+ hostname, 80,
+ &log);
int rv = user_sock_->Connect(callback_.callback());
EXPECT_EQ(ERR_IO_PENDING, rv);
@@ -378,11 +389,11 @@ TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) {
MockWrite data_writes[] = { MockWrite(SYNCHRONOUS, "", 0) };
MockRead data_reads[] = { MockRead(SYNCHRONOUS, "", 0) };
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- hanging_resolver.get(),
- "foo", 80,
- NULL));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ hanging_resolver.get(),
+ "foo", 80,
+ NULL);
// Start connecting (will get stuck waiting for the host to resolve).
int rv = user_sock_->Connect(callback_.callback());
diff --git a/chromium/net/socket/ssl_client_socket_nss.cc b/chromium/net/socket/ssl_client_socket_nss.cc
index f374dedcf80..0de7cfb9060 100644
--- a/chromium/net/socket/ssl_client_socket_nss.cc
+++ b/chromium/net/socket/ssl_client_socket_nss.cc
@@ -380,20 +380,16 @@ void PeerCertificateChain::Reset(PRFileDesc* nss_fd) {
if (nss_fd == NULL)
return;
- unsigned int num_certs = 0;
- SECStatus rv = SSL_PeerCertificateChain(nss_fd, NULL, &num_certs, 0);
- DCHECK_EQ(SECSuccess, rv);
-
+ CERTCertList* list = SSL_PeerCertificateChain(nss_fd);
// The handshake on |nss_fd| may not have completed.
- if (num_certs == 0)
+ if (list == NULL)
return;
- certs_.resize(num_certs);
- const unsigned int expected_num_certs = num_certs;
- rv = SSL_PeerCertificateChain(nss_fd, vector_as_array(&certs_),
- &num_certs, expected_num_certs);
- DCHECK_EQ(SECSuccess, rv);
- DCHECK_EQ(expected_num_certs, num_certs);
+ for (CERTCertListNode* node = CERT_LIST_HEAD(list);
+ !CERT_LIST_END(node, list); node = CERT_LIST_NEXT(node)) {
+ certs_.push_back(CERT_DupCertificate(node->cert));
+ }
+ CERT_DestroyCertList(list);
}
std::vector<base::StringPiece>
@@ -1291,6 +1287,19 @@ SECStatus SSLClientSocketNSS::Core::OwnAuthCertHandler(
// Start with it.
SSL_OptionSet(socket, SSL_ENABLE_FALSE_START, PR_FALSE);
}
+ } else {
+ // Disallow the server certificate to change in a renegotiation.
+ CERTCertificate* old_cert = core->nss_handshake_state_.server_cert_chain[0];
+ ScopedCERTCertificate new_cert(SSL_PeerCertificate(socket));
+ if (new_cert->derCert.len != old_cert->derCert.len ||
+ memcmp(new_cert->derCert.data, old_cert->derCert.data,
+ new_cert->derCert.len) != 0) {
+ // NSS doesn't have an error code that indicates the server certificate
+ // changed. Borrow SSL_ERROR_WRONG_CERTIFICATE (which NSS isn't using)
+ // for this purpose.
+ PORT_SetError(SSL_ERROR_WRONG_CERTIFICATE);
+ return SECFailure;
+ }
}
// Tell NSS to not verify the certificate.
@@ -2598,7 +2607,7 @@ int SSLClientSocketNSS::Core::DoGetDomainBoundCert(const std::string& host) {
weak_net_log_->BeginEvent(NetLog::TYPE_SSL_GET_DOMAIN_BOUND_CERT);
- int rv = server_bound_cert_service_->GetDomainBoundCert(
+ int rv = server_bound_cert_service_->GetOrCreateDomainBoundCert(
host,
&domain_bound_private_key_,
&domain_bound_cert_,
@@ -2751,12 +2760,12 @@ void SSLClientSocketNSS::Core::SetChannelIDProvided() {
SSLClientSocketNSS::SSLClientSocketNSS(
base::SequencedTaskRunner* nss_task_runner,
- ClientSocketHandle* transport_socket,
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context)
: nss_task_runner_(nss_task_runner),
- transport_(transport_socket),
+ transport_(transport_socket.Pass()),
host_and_port_(host_and_port),
ssl_config_(ssl_config),
cert_verifier_(context.cert_verifier),
@@ -2765,7 +2774,7 @@ SSLClientSocketNSS::SSLClientSocketNSS(
completed_handshake_(false),
next_handshake_state_(STATE_NONE),
nss_fd_(NULL),
- net_log_(transport_socket->socket()->NetLog()),
+ net_log_(transport_->socket()->NetLog()),
transport_security_state_(context.transport_security_state),
valid_thread_id_(base::kInvalidThreadId) {
EnterFunction("");
@@ -3141,7 +3150,8 @@ int SSLClientSocketNSS::InitializeSSLOptions() {
net_log_, "SSL_OptionSet", "SSL_ENABLE_SESSION_TICKETS");
}
- rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_FALSE_START, PR_FALSE);
+ rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_FALSE_START,
+ ssl_config_.false_start_enabled);
if (rv != SECSuccess)
LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_ENABLE_FALSE_START");
diff --git a/chromium/net/socket/ssl_client_socket_nss.h b/chromium/net/socket/ssl_client_socket_nss.h
index fed8ef706b5..b41d28d74a8 100644
--- a/chromium/net/socket/ssl_client_socket_nss.h
+++ b/chromium/net/socket/ssl_client_socket_nss.h
@@ -59,7 +59,7 @@ class SSLClientSocketNSS : public SSLClientSocket {
// behaviour is desired, for performance or compatibility, the current task
// runner should be supplied instead.
SSLClientSocketNSS(base::SequencedTaskRunner* nss_task_runner,
- ClientSocketHandle* transport_socket,
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context);
diff --git a/chromium/net/socket/ssl_client_socket_openssl.cc b/chromium/net/socket/ssl_client_socket_openssl.cc
index 1431bc61486..416ab87bc4b 100644
--- a/chromium/net/socket/ssl_client_socket_openssl.cc
+++ b/chromium/net/socket/ssl_client_socket_openssl.cc
@@ -425,7 +425,7 @@ void SSLClientSocket::ClearSessionCache() {
}
SSLClientSocketOpenSSL::SSLClientSocketOpenSSL(
- ClientSocketHandle* transport_socket,
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context)
@@ -439,14 +439,14 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL(
cert_verifier_(context.cert_verifier),
ssl_(NULL),
transport_bio_(NULL),
- transport_(transport_socket),
+ transport_(transport_socket.Pass()),
host_and_port_(host_and_port),
ssl_config_(ssl_config),
ssl_session_cache_shard_(context.ssl_session_cache_shard),
trying_cached_session_(false),
next_handshake_state_(STATE_NONE),
npn_status_(kNextProtoUnsupported),
- net_log_(transport_socket->socket()->NetLog()) {
+ net_log_(transport_->socket()->NetLog()) {
}
SSLClientSocketOpenSSL::~SSLClientSocketOpenSSL() {
@@ -532,9 +532,11 @@ bool SSLClientSocketOpenSSL::Init() {
STACK_OF(SSL_CIPHER)* ciphers = SSL_get_ciphers(ssl_);
DCHECK(ciphers);
// See SSLConfig::disabled_cipher_suites for description of the suites
- // disabled by default. Note that !SHA384 only removes HMAC-SHA384 cipher
- // suites, not GCM cipher suites with SHA384 as the handshake hash.
- std::string command("DEFAULT:!NULL:!aNULL:!IDEA:!FZA:!SRP:!SHA384:!aECDH");
+ // disabled by default. Note that !SHA256 and !SHA384 only remove HMAC-SHA256
+ // and HMAC-SHA384 cipher suites, not GCM cipher suites with SHA256 or SHA384
+ // as the handshake hash.
+ std::string command("DEFAULT:!NULL:!aNULL:!IDEA:!FZA:!SRP:!SHA256:!SHA384:"
+ "!aECDH:!AESGCM+AES256");
// Walk through all the installed ciphers, seeing if any need to be
// appended to the cipher removal |command|.
for (int i = 0; i < sk_SSL_CIPHER_num(ciphers); ++i) {
diff --git a/chromium/net/socket/ssl_client_socket_openssl.h b/chromium/net/socket/ssl_client_socket_openssl.h
index 520f432b8bc..f66d95cc69d 100644
--- a/chromium/net/socket/ssl_client_socket_openssl.h
+++ b/chromium/net/socket/ssl_client_socket_openssl.h
@@ -41,7 +41,7 @@ class SSLClientSocketOpenSSL : public SSLClientSocket {
// The given hostname will be compared with the name(s) in the server's
// certificate during the SSL handshake. ssl_config specifies the SSL
// settings.
- SSLClientSocketOpenSSL(ClientSocketHandle* transport_socket,
+ SSLClientSocketOpenSSL(scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context);
diff --git a/chromium/net/socket/ssl_client_socket_openssl_unittest.cc b/chromium/net/socket/ssl_client_socket_openssl_unittest.cc
index 7a37cdc1187..24c06059be5 100644
--- a/chromium/net/socket/ssl_client_socket_openssl_unittest.cc
+++ b/chromium/net/socket/ssl_client_socket_openssl_unittest.cc
@@ -67,7 +67,7 @@ bool LoadPrivateKeyOpenSSL(
const base::FilePath& filepath,
OpenSSLClientKeyStore::ScopedEVP_PKEY* pkey) {
std::string data;
- if (!file_util::ReadFileToString(filepath, &data)) {
+ if (!base::ReadFileToString(filepath, &data)) {
LOG(ERROR) << "Could not read private key file: "
<< filepath.value() << ": " << strerror(errno);
return false;
@@ -107,11 +107,13 @@ class SSLClientSocketOpenSSLClientAuthTest : public PlatformTest {
}
protected:
- SSLClientSocket* CreateSSLClientSocket(
- StreamSocket* transport_socket,
+ scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<StreamSocket> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config) {
- return socket_factory_->CreateSSLClientSocket(transport_socket,
+ scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
+ connection->SetSocket(transport_socket.Pass());
+ return socket_factory_->CreateSSLClientSocket(connection.Pass(),
host_and_port,
ssl_config,
context_);
@@ -164,9 +166,9 @@ class SSLClientSocketOpenSSLClientAuthTest : public PlatformTest {
// itself was a success.
bool CreateAndConnectSSLClientSocket(SSLConfig& ssl_config,
int* result) {
- sock_.reset(CreateSSLClientSocket(transport_.release(),
- test_server_->host_port_pair(),
- ssl_config));
+ sock_ = CreateSSLClientSocket(transport_.Pass(),
+ test_server_->host_port_pair(),
+ ssl_config);
if (sock_->IsConnected()) {
LOG(ERROR) << "SSL Socket prematurely connected";
diff --git a/chromium/net/socket/ssl_client_socket_pool.cc b/chromium/net/socket/ssl_client_socket_pool.cc
index fed268d4ee4..5d574b7edda 100644
--- a/chromium/net/socket/ssl_client_socket_pool.cc
+++ b/chromium/net/socket/ssl_client_socket_pool.cc
@@ -26,20 +26,18 @@
namespace net {
SSLSocketParams::SSLSocketParams(
- const scoped_refptr<TransportSocketParams>& transport_params,
- const scoped_refptr<SOCKSSocketParams>& socks_params,
+ const scoped_refptr<TransportSocketParams>& direct_params,
+ const scoped_refptr<SOCKSSocketParams>& socks_proxy_params,
const scoped_refptr<HttpProxySocketParams>& http_proxy_params,
- ProxyServer::Scheme proxy,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
PrivacyMode privacy_mode,
int load_flags,
bool force_spdy_over_ssl,
bool want_spdy_over_npn)
- : transport_params_(transport_params),
+ : direct_params_(direct_params),
+ socks_proxy_params_(socks_proxy_params),
http_proxy_params_(http_proxy_params),
- socks_params_(socks_params),
- proxy_(proxy),
host_and_port_(host_and_port),
ssl_config_(ssl_config),
privacy_mode_(privacy_mode),
@@ -47,39 +45,60 @@ SSLSocketParams::SSLSocketParams(
force_spdy_over_ssl_(force_spdy_over_ssl),
want_spdy_over_npn_(want_spdy_over_npn),
ignore_limits_(false) {
- switch (proxy_) {
- case ProxyServer::SCHEME_DIRECT:
- DCHECK(transport_params_.get() != NULL);
- DCHECK(http_proxy_params_.get() == NULL);
- DCHECK(socks_params_.get() == NULL);
- ignore_limits_ = transport_params_->ignore_limits();
- break;
- case ProxyServer::SCHEME_HTTP:
- case ProxyServer::SCHEME_HTTPS:
- DCHECK(transport_params_.get() == NULL);
- DCHECK(http_proxy_params_.get() != NULL);
- DCHECK(socks_params_.get() == NULL);
- ignore_limits_ = http_proxy_params_->ignore_limits();
- break;
- case ProxyServer::SCHEME_SOCKS4:
- case ProxyServer::SCHEME_SOCKS5:
- DCHECK(transport_params_.get() == NULL);
- DCHECK(http_proxy_params_.get() == NULL);
- DCHECK(socks_params_.get() != NULL);
- ignore_limits_ = socks_params_->ignore_limits();
- break;
- default:
- LOG(DFATAL) << "unknown proxy type";
- break;
+ if (direct_params_) {
+ DCHECK(!socks_proxy_params_);
+ DCHECK(!http_proxy_params_);
+ ignore_limits_ = direct_params_->ignore_limits();
+ } else if (socks_proxy_params_) {
+ DCHECK(!http_proxy_params_);
+ ignore_limits_ = socks_proxy_params_->ignore_limits();
+ } else {
+ DCHECK(http_proxy_params_);
+ ignore_limits_ = http_proxy_params_->ignore_limits();
}
}
SSLSocketParams::~SSLSocketParams() {}
+SSLSocketParams::ConnectionType SSLSocketParams::GetConnectionType() const {
+ if (direct_params_) {
+ DCHECK(!socks_proxy_params_);
+ DCHECK(!http_proxy_params_);
+ return DIRECT;
+ }
+
+ if (socks_proxy_params_) {
+ DCHECK(!http_proxy_params_);
+ return SOCKS_PROXY;
+ }
+
+ DCHECK(http_proxy_params_);
+ return HTTP_PROXY;
+}
+
+const scoped_refptr<TransportSocketParams>&
+SSLSocketParams::GetDirectConnectionParams() const {
+ DCHECK_EQ(GetConnectionType(), DIRECT);
+ return direct_params_;
+}
+
+const scoped_refptr<SOCKSSocketParams>&
+SSLSocketParams::GetSocksProxyConnectionParams() const {
+ DCHECK_EQ(GetConnectionType(), SOCKS_PROXY);
+ return socks_proxy_params_;
+}
+
+const scoped_refptr<HttpProxySocketParams>&
+SSLSocketParams::GetHttpProxyConnectionParams() const {
+ DCHECK_EQ(GetConnectionType(), HTTP_PROXY);
+ return http_proxy_params_;
+}
+
// Timeout for the SSL handshake portion of the connect.
static const int kSSLHandshakeTimeoutInSeconds = 30;
SSLConnectJob::SSLConnectJob(const std::string& group_name,
+ RequestPriority priority,
const scoped_refptr<SSLSocketParams>& params,
const base::TimeDelta& timeout_duration,
TransportClientSocketPool* transport_pool,
@@ -92,6 +111,7 @@ SSLConnectJob::SSLConnectJob(const std::string& group_name,
NetLog* net_log)
: ConnectJob(group_name,
timeout_duration,
+ priority,
delegate,
BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)),
params_(params),
@@ -201,12 +221,14 @@ int SSLConnectJob::DoTransportConnect() {
next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE;
transport_socket_handle_.reset(new ClientSocketHandle());
- scoped_refptr<TransportSocketParams> transport_params =
- params_->transport_params();
- return transport_socket_handle_->Init(
- group_name(), transport_params,
- transport_params->destination().priority(), callback_, transport_pool_,
- net_log());
+ scoped_refptr<TransportSocketParams> direct_params =
+ params_->GetDirectConnectionParams();
+ return transport_socket_handle_->Init(group_name(),
+ direct_params,
+ priority(),
+ callback_,
+ transport_pool_,
+ net_log());
}
int SSLConnectJob::DoTransportConnectComplete(int result) {
@@ -220,10 +242,14 @@ int SSLConnectJob::DoSOCKSConnect() {
DCHECK(socks_pool_);
next_state_ = STATE_SOCKS_CONNECT_COMPLETE;
transport_socket_handle_.reset(new ClientSocketHandle());
- scoped_refptr<SOCKSSocketParams> socks_params = params_->socks_params();
- return transport_socket_handle_->Init(
- group_name(), socks_params, socks_params->destination().priority(),
- callback_, socks_pool_, net_log());
+ scoped_refptr<SOCKSSocketParams> socks_proxy_params =
+ params_->GetSocksProxyConnectionParams();
+ return transport_socket_handle_->Init(group_name(),
+ socks_proxy_params,
+ priority(),
+ callback_,
+ socks_pool_,
+ net_log());
}
int SSLConnectJob::DoSOCKSConnectComplete(int result) {
@@ -239,11 +265,13 @@ int SSLConnectJob::DoTunnelConnect() {
transport_socket_handle_.reset(new ClientSocketHandle());
scoped_refptr<HttpProxySocketParams> http_proxy_params =
- params_->http_proxy_params();
- return transport_socket_handle_->Init(
- group_name(), http_proxy_params,
- http_proxy_params->destination().priority(), callback_, http_proxy_pool_,
- net_log());
+ params_->GetHttpProxyConnectionParams();
+ return transport_socket_handle_->Init(group_name(),
+ http_proxy_params,
+ priority(),
+ callback_,
+ http_proxy_pool_,
+ net_log());
}
int SSLConnectJob::DoTunnelConnectComplete(int result) {
@@ -287,11 +315,11 @@ int SSLConnectJob::DoSSLConnect() {
connect_timing_.ssl_start = base::TimeTicks::Now();
- ssl_socket_.reset(client_socket_factory_->CreateSSLClientSocket(
- transport_socket_handle_.release(),
+ ssl_socket_ = client_socket_factory_->CreateSSLClientSocket(
+ transport_socket_handle_.Pass(),
params_->host_and_port(),
params_->ssl_config(),
- context_));
+ context_);
return ssl_socket_->Connect(callback_);
}
@@ -410,7 +438,7 @@ int SSLConnectJob::DoSSLConnectComplete(int result) {
}
if (result == OK || IsCertificateError(result)) {
- set_socket(ssl_socket_.release());
+ SetSocket(ssl_socket_.PassAs<StreamSocket>());
} else if (result == ERR_SSL_CLIENT_AUTH_CERT_NEEDED) {
error_response_info_.cert_request_info = new SSLCertRequestInfo;
ssl_socket_->GetSSLCertRequestInfo(
@@ -420,23 +448,22 @@ int SSLConnectJob::DoSSLConnectComplete(int result) {
return result;
}
-int SSLConnectJob::ConnectInternal() {
- switch (params_->proxy()) {
- case ProxyServer::SCHEME_DIRECT:
- next_state_ = STATE_TRANSPORT_CONNECT;
- break;
- case ProxyServer::SCHEME_HTTP:
- case ProxyServer::SCHEME_HTTPS:
- next_state_ = STATE_TUNNEL_CONNECT;
- break;
- case ProxyServer::SCHEME_SOCKS4:
- case ProxyServer::SCHEME_SOCKS5:
- next_state_ = STATE_SOCKS_CONNECT;
- break;
- default:
- NOTREACHED() << "unknown proxy type";
- break;
+SSLConnectJob::State SSLConnectJob::GetInitialState(
+ SSLSocketParams::ConnectionType connection_type) {
+ switch (connection_type) {
+ case SSLSocketParams::DIRECT:
+ return STATE_TRANSPORT_CONNECT;
+ case SSLSocketParams::HTTP_PROXY:
+ return STATE_TUNNEL_CONNECT;
+ case SSLSocketParams::SOCKS_PROXY:
+ return STATE_SOCKS_CONNECT;
}
+ NOTREACHED();
+ return STATE_NONE;
+}
+
+int SSLConnectJob::ConnectInternal() {
+ next_state_ = GetInitialState(params_->GetConnectionType());
return DoLoop(OK);
}
@@ -491,7 +518,7 @@ SSLClientSocketPool::SSLClientSocketPool(
: transport_pool_(transport_pool),
socks_pool_(socks_pool),
http_proxy_pool_(http_proxy_pool),
- base_(max_sockets, max_sockets_per_group, histograms,
+ base_(this, max_sockets, max_sockets_per_group, histograms,
ClientSocketPool::unused_idle_socket_timeout(),
ClientSocketPool::used_idle_socket_timeout(),
new SSLConnectJobFactory(transport_pool,
@@ -509,32 +536,28 @@ SSLClientSocketPool::SSLClientSocketPool(
if (ssl_config_service_.get())
ssl_config_service_->AddObserver(this);
if (transport_pool_)
- transport_pool_->AddLayeredPool(this);
+ base_.AddLowerLayeredPool(transport_pool_);
if (socks_pool_)
- socks_pool_->AddLayeredPool(this);
+ base_.AddLowerLayeredPool(socks_pool_);
if (http_proxy_pool_)
- http_proxy_pool_->AddLayeredPool(this);
+ base_.AddLowerLayeredPool(http_proxy_pool_);
}
SSLClientSocketPool::~SSLClientSocketPool() {
- if (http_proxy_pool_)
- http_proxy_pool_->RemoveLayeredPool(this);
- if (socks_pool_)
- socks_pool_->RemoveLayeredPool(this);
- if (transport_pool_)
- transport_pool_->RemoveLayeredPool(this);
if (ssl_config_service_.get())
ssl_config_service_->RemoveObserver(this);
}
-ConnectJob* SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob(
+scoped_ptr<ConnectJob>
+SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob(
const std::string& group_name,
const PoolBase::Request& request,
ConnectJob::Delegate* delegate) const {
- return new SSLConnectJob(group_name, request.params(), ConnectionTimeout(),
- transport_pool_, socks_pool_, http_proxy_pool_,
- client_socket_factory_, host_resolver_,
- context_, delegate, net_log_);
+ return scoped_ptr<ConnectJob>(
+ new SSLConnectJob(group_name, request.priority(), request.params(),
+ ConnectionTimeout(), transport_pool_, socks_pool_,
+ http_proxy_pool_, client_socket_factory_,
+ host_resolver_, context_, delegate, net_log_));
}
base::TimeDelta
@@ -572,21 +595,15 @@ void SSLClientSocketPool::CancelRequest(const std::string& group_name,
}
void SSLClientSocketPool::ReleaseSocket(const std::string& group_name,
- StreamSocket* socket, int id) {
- base_.ReleaseSocket(group_name, socket, id);
+ scoped_ptr<StreamSocket> socket,
+ int id) {
+ base_.ReleaseSocket(group_name, socket.Pass(), id);
}
void SSLClientSocketPool::FlushWithError(int error) {
base_.FlushWithError(error);
}
-bool SSLClientSocketPool::IsStalled() const {
- return base_.IsStalled() ||
- (transport_pool_ && transport_pool_->IsStalled()) ||
- (socks_pool_ && socks_pool_->IsStalled()) ||
- (http_proxy_pool_ && http_proxy_pool_->IsStalled());
-}
-
void SSLClientSocketPool::CloseIdleSockets() {
base_.CloseIdleSockets();
}
@@ -605,14 +622,6 @@ LoadState SSLClientSocketPool::GetLoadState(
return base_.GetLoadState(group_name, handle);
}
-void SSLClientSocketPool::AddLayeredPool(LayeredPool* layered_pool) {
- base_.AddLayeredPool(layered_pool);
-}
-
-void SSLClientSocketPool::RemoveLayeredPool(LayeredPool* layered_pool) {
- base_.RemoveLayeredPool(layered_pool);
-}
-
base::DictionaryValue* SSLClientSocketPool::GetInfoAsValue(
const std::string& name,
const std::string& type,
@@ -648,14 +657,27 @@ ClientSocketPoolHistograms* SSLClientSocketPool::histograms() const {
return base_.histograms();
}
-void SSLClientSocketPool::OnSSLConfigChanged() {
- FlushWithError(ERR_NETWORK_CHANGED);
+bool SSLClientSocketPool::IsStalled() const {
+ return base_.IsStalled();
+}
+
+void SSLClientSocketPool::AddHigherLayeredPool(HigherLayeredPool* higher_pool) {
+ base_.AddHigherLayeredPool(higher_pool);
+}
+
+void SSLClientSocketPool::RemoveHigherLayeredPool(
+ HigherLayeredPool* higher_pool) {
+ base_.RemoveHigherLayeredPool(higher_pool);
}
bool SSLClientSocketPool::CloseOneIdleConnection() {
if (base_.CloseOneIdleSocket())
return true;
- return base_.CloseOneIdleConnectionInLayeredPool();
+ return base_.CloseOneIdleConnectionInHigherLayeredPool();
+}
+
+void SSLClientSocketPool::OnSSLConfigChanged() {
+ FlushWithError(ERR_NETWORK_CHANGED);
}
} // namespace net
diff --git a/chromium/net/socket/ssl_client_socket_pool.h b/chromium/net/socket/ssl_client_socket_pool.h
index bc54bc92f9a..ec62eb01f46 100644
--- a/chromium/net/socket/ssl_client_socket_pool.h
+++ b/chromium/net/socket/ssl_client_socket_pool.h
@@ -13,7 +13,6 @@
#include "net/base/privacy_mode.h"
#include "net/dns/host_resolver.h"
#include "net/http/http_response_info.h"
-#include "net/proxy/proxy_server.h"
#include "net/socket/client_socket_pool.h"
#include "net/socket/client_socket_pool_base.h"
#include "net/socket/client_socket_pool_histograms.h"
@@ -35,32 +34,39 @@ class TransportClientSocketPool;
class TransportSecurityState;
class TransportSocketParams;
-// SSLSocketParams only needs the socket params for the transport socket
-// that will be used (denoted by |proxy|).
class NET_EXPORT_PRIVATE SSLSocketParams
: public base::RefCounted<SSLSocketParams> {
public:
- SSLSocketParams(const scoped_refptr<TransportSocketParams>& transport_params,
- const scoped_refptr<SOCKSSocketParams>& socks_params,
- const scoped_refptr<HttpProxySocketParams>& http_proxy_params,
- ProxyServer::Scheme proxy,
- const HostPortPair& host_and_port,
- const SSLConfig& ssl_config,
- PrivacyMode privacy_mode,
- int load_flags,
- bool force_spdy_over_ssl,
- bool want_spdy_over_npn);
-
- const scoped_refptr<TransportSocketParams>& transport_params() {
- return transport_params_;
- }
- const scoped_refptr<HttpProxySocketParams>& http_proxy_params() {
- return http_proxy_params_;
- }
- const scoped_refptr<SOCKSSocketParams>& socks_params() {
- return socks_params_;
- }
- ProxyServer::Scheme proxy() const { return proxy_; }
+ enum ConnectionType { DIRECT, SOCKS_PROXY, HTTP_PROXY };
+
+ // Exactly one of |direct_params|, |socks_proxy_params|, and
+ // |http_proxy_params| must be non-NULL.
+ SSLSocketParams(
+ const scoped_refptr<TransportSocketParams>& direct_params,
+ const scoped_refptr<SOCKSSocketParams>& socks_proxy_params,
+ const scoped_refptr<HttpProxySocketParams>& http_proxy_params,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config,
+ PrivacyMode privacy_mode,
+ int load_flags,
+ bool force_spdy_over_ssl,
+ bool want_spdy_over_npn);
+
+ // Returns the type of the underlying connection.
+ ConnectionType GetConnectionType() const;
+
+ // Must be called only when GetConnectionType() returns DIRECT.
+ const scoped_refptr<TransportSocketParams>&
+ GetDirectConnectionParams() const;
+
+ // Must be called only when GetConnectionType() returns SOCKS_PROXY.
+ const scoped_refptr<SOCKSSocketParams>&
+ GetSocksProxyConnectionParams() const;
+
+ // Must be called only when GetConnectionType() returns HTTP_PROXY.
+ const scoped_refptr<HttpProxySocketParams>&
+ GetHttpProxyConnectionParams() const;
+
const HostPortPair& host_and_port() const { return host_and_port_; }
const SSLConfig& ssl_config() const { return ssl_config_; }
PrivacyMode privacy_mode() const { return privacy_mode_; }
@@ -73,10 +79,9 @@ class NET_EXPORT_PRIVATE SSLSocketParams
friend class base::RefCounted<SSLSocketParams>;
~SSLSocketParams();
- const scoped_refptr<TransportSocketParams> transport_params_;
+ const scoped_refptr<TransportSocketParams> direct_params_;
+ const scoped_refptr<SOCKSSocketParams> socks_proxy_params_;
const scoped_refptr<HttpProxySocketParams> http_proxy_params_;
- const scoped_refptr<SOCKSSocketParams> socks_params_;
- const ProxyServer::Scheme proxy_;
const HostPortPair host_and_port_;
const SSLConfig ssl_config_;
const PrivacyMode privacy_mode_;
@@ -94,6 +99,7 @@ class SSLConnectJob : public ConnectJob {
public:
SSLConnectJob(
const std::string& group_name,
+ RequestPriority priority,
const scoped_refptr<SSLSocketParams>& params,
const base::TimeDelta& timeout_duration,
TransportClientSocketPool* transport_pool,
@@ -138,6 +144,10 @@ class SSLConnectJob : public ConnectJob {
int DoSSLConnect();
int DoSSLConnectComplete(int result);
+ // Returns the initial state for the state machine based on the
+ // |connection_type|.
+ static State GetInitialState(SSLSocketParams::ConnectionType connection_type);
+
// Starts the SSL connection process. Returns OK on success and
// ERR_IO_PENDING if it cannot immediately service the request.
// Otherwise, it returns a net error code.
@@ -164,9 +174,11 @@ class SSLConnectJob : public ConnectJob {
class NET_EXPORT_PRIVATE SSLClientSocketPool
: public ClientSocketPool,
- public LayeredPool,
+ public HigherLayeredPool,
public SSLConfigService::Observer {
public:
+ typedef SSLSocketParams SocketParams;
+
// Only the pools that will be used are required. i.e. if you never
// try to create an SSL over SOCKS socket, |socks_pool| may be NULL.
SSLClientSocketPool(
@@ -204,13 +216,11 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool
ClientSocketHandle* handle) OVERRIDE;
virtual void ReleaseSocket(const std::string& group_name,
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
int id) OVERRIDE;
virtual void FlushWithError(int error) OVERRIDE;
- virtual bool IsStalled() const OVERRIDE;
-
virtual void CloseIdleSockets() OVERRIDE;
virtual int IdleSocketCount() const OVERRIDE;
@@ -222,10 +232,6 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool
const std::string& group_name,
const ClientSocketHandle* handle) const OVERRIDE;
- virtual void AddLayeredPool(LayeredPool* layered_pool) OVERRIDE;
-
- virtual void RemoveLayeredPool(LayeredPool* layered_pool) OVERRIDE;
-
virtual base::DictionaryValue* GetInfoAsValue(
const std::string& name,
const std::string& type,
@@ -235,7 +241,14 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool
virtual ClientSocketPoolHistograms* histograms() const OVERRIDE;
- // LayeredPool implementation.
+ // LowerLayeredPool implementation.
+ virtual bool IsStalled() const OVERRIDE;
+
+ virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE;
+
+ virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE;
+
+ // HigherLayeredPool implementation.
virtual bool CloseOneIdleConnection() OVERRIDE;
private:
@@ -261,7 +274,7 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool
virtual ~SSLConnectJobFactory() {}
// ClientSocketPoolBase::ConnectJobFactory methods.
- virtual ConnectJob* NewConnectJob(
+ virtual scoped_ptr<ConnectJob> NewConnectJob(
const std::string& group_name,
const PoolBase::Request& request,
ConnectJob::Delegate* delegate) const OVERRIDE;
@@ -290,8 +303,6 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool
DISALLOW_COPY_AND_ASSIGN(SSLClientSocketPool);
};
-REGISTER_SOCKET_PARAMS_FOR_POOL(SSLClientSocketPool, SSLSocketParams);
-
} // namespace net
#endif // NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_
diff --git a/chromium/net/socket/ssl_client_socket_pool_unittest.cc b/chromium/net/socket/ssl_client_socket_pool_unittest.cc
index 280f6e7af1a..8aecb98dd85 100644
--- a/chromium/net/socket/ssl_client_socket_pool_unittest.cc
+++ b/chromium/net/socket/ssl_client_socket_pool_unittest.cc
@@ -85,7 +85,6 @@ class SSLClientSocketPoolTest
session_(CreateNetworkSession()),
direct_transport_socket_params_(
new TransportSocketParams(HostPortPair("host", 443),
- MEDIUM,
false,
false,
OnHostResolutionCallback())),
@@ -96,15 +95,13 @@ class SSLClientSocketPoolTest
&socket_factory_),
proxy_transport_socket_params_(
new TransportSocketParams(HostPortPair("proxy", 443),
- MEDIUM,
false,
false,
OnHostResolutionCallback())),
socks_socket_params_(
new SOCKSSocketParams(proxy_transport_socket_params_,
true,
- HostPortPair("sockshost", 443),
- MEDIUM)),
+ HostPortPair("sockshost", 443))),
socks_histograms_("MockSOCKS"),
socks_socket_pool_(kMaxSockets,
kMaxSocketsPerGroup,
@@ -159,7 +156,6 @@ class SSLClientSocketPoolTest
: NULL,
proxy == ProxyServer::SCHEME_SOCKS5 ? socks_socket_params_ : NULL,
proxy == ProxyServer::SCHEME_HTTP ? http_proxy_socket_params_ : NULL,
- proxy,
HostPortPair("host", 443),
ssl_config_,
kPrivacyModeDisabled,
@@ -294,6 +290,30 @@ TEST_P(SSLClientSocketPoolTest, BasicDirect) {
TestLoadTimingInfo(handle);
}
+// Make sure that SSLConnectJob passes on its priority to its
+// socket request on Init (for the DIRECT case).
+TEST_P(SSLClientSocketPoolTest, SetSocketRequestPriorityOnInitDirect) {
+ CreatePool(true /* tcp pool */, false, false);
+ scoped_refptr<SSLSocketParams> params =
+ SSLParams(ProxyServer::SCHEME_DIRECT, false);
+
+ for (int i = MINIMUM_PRIORITY; i < NUM_PRIORITIES; ++i) {
+ RequestPriority priority = static_cast<RequestPriority>(i);
+ StaticSocketDataProvider data;
+ data.set_connect_data(MockConnect(SYNCHRONOUS, OK));
+ socket_factory_.AddSocketDataProvider(&data);
+ SSLSocketDataProvider ssl(SYNCHRONOUS, OK);
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(OK, handle.Init("a", params, priority, callback.callback(),
+ pool_.get(), BoundNetLog()));
+ EXPECT_EQ(priority, transport_socket_pool_.last_request_priority());
+ handle.socket()->Disconnect();
+ }
+}
+
TEST_P(SSLClientSocketPoolTest, BasicDirectAsync) {
StaticSocketDataProvider data;
socket_factory_.AddSocketDataProvider(&data);
@@ -547,6 +567,26 @@ TEST_P(SSLClientSocketPoolTest, SOCKSBasic) {
TestLoadTimingInfo(handle);
}
+// Make sure that SSLConnectJob passes on its priority to its
+// transport socket on Init (for the SOCKS_PROXY case).
+TEST_P(SSLClientSocketPoolTest, SetTransportPriorityOnInitSOCKS) {
+ StaticSocketDataProvider data;
+ data.set_connect_data(MockConnect(SYNCHRONOUS, OK));
+ socket_factory_.AddSocketDataProvider(&data);
+ SSLSocketDataProvider ssl(SYNCHRONOUS, OK);
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+
+ CreatePool(false, true /* http proxy pool */, true /* socks pool */);
+ scoped_refptr<SSLSocketParams> params =
+ SSLParams(ProxyServer::SCHEME_SOCKS5, false);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(OK, handle.Init("a", params, HIGHEST, callback.callback(),
+ pool_.get(), BoundNetLog()));
+ EXPECT_EQ(HIGHEST, transport_socket_pool_.last_request_priority());
+}
+
TEST_P(SSLClientSocketPoolTest, SOCKSBasicAsync) {
StaticSocketDataProvider data;
socket_factory_.AddSocketDataProvider(&data);
@@ -648,6 +688,38 @@ TEST_P(SSLClientSocketPoolTest, HttpProxyBasic) {
TestLoadTimingInfoNoDns(handle);
}
+// Make sure that SSLConnectJob passes on its priority to its
+// transport socket on Init (for the HTTP_PROXY case).
+TEST_P(SSLClientSocketPoolTest, SetTransportPriorityOnInitHTTP) {
+ MockWrite writes[] = {
+ MockWrite(SYNCHRONOUS,
+ "CONNECT host:80 HTTP/1.1\r\n"
+ "Host: host\r\n"
+ "Proxy-Connection: keep-alive\r\n"
+ "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"),
+ };
+ MockRead reads[] = {
+ MockRead(SYNCHRONOUS, "HTTP/1.1 200 Connection Established\r\n\r\n"),
+ };
+ StaticSocketDataProvider data(reads, arraysize(reads), writes,
+ arraysize(writes));
+ data.set_connect_data(MockConnect(SYNCHRONOUS, OK));
+ socket_factory_.AddSocketDataProvider(&data);
+ AddAuthToCache();
+ SSLSocketDataProvider ssl(SYNCHRONOUS, OK);
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+
+ CreatePool(false, true /* http proxy pool */, true /* socks pool */);
+ scoped_refptr<SSLSocketParams> params =
+ SSLParams(ProxyServer::SCHEME_HTTP, false);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(OK, handle.Init("a", params, HIGHEST, callback.callback(),
+ pool_.get(), BoundNetLog()));
+ EXPECT_EQ(HIGHEST, transport_socket_pool_.last_request_priority());
+}
+
TEST_P(SSLClientSocketPoolTest, HttpProxyBasicAsync) {
MockWrite writes[] = {
MockWrite("CONNECT host:80 HTTP/1.1\r\n"
@@ -746,8 +818,12 @@ TEST_P(SSLClientSocketPoolTest, IPPooling) {
// This test requires that the HostResolver cache be populated. Normal
// code would have done this already, but we do it manually.
HostResolver::RequestInfo info(HostPortPair(test_hosts[i].name, kTestPort));
- host_resolver_.Resolve(info, &test_hosts[i].addresses, CompletionCallback(),
- NULL, BoundNetLog());
+ host_resolver_.Resolve(info,
+ DEFAULT_PRIORITY,
+ &test_hosts[i].addresses,
+ CompletionCallback(),
+ NULL,
+ BoundNetLog());
// Setup a SpdySessionKey
test_hosts[i].key = SpdySessionKey(
@@ -802,8 +878,12 @@ void SSLClientSocketPoolTest::TestIPPoolingDisabled(
// This test requires that the HostResolver cache be populated. Normal
// code would have done this already, but we do it manually.
HostResolver::RequestInfo info(HostPortPair(test_hosts[i].name, kTestPort));
- rv = host_resolver_.Resolve(info, &test_hosts[i].addresses,
- callback.callback(), NULL, BoundNetLog());
+ rv = host_resolver_.Resolve(info,
+ DEFAULT_PRIORITY,
+ &test_hosts[i].addresses,
+ callback.callback(),
+ NULL,
+ BoundNetLog());
EXPECT_EQ(OK, callback.GetResult(rv));
// Setup a SpdySessionKey
diff --git a/chromium/net/socket/ssl_client_socket_unittest.cc b/chromium/net/socket/ssl_client_socket_unittest.cc
index f0e7120a135..f791928580f 100644
--- a/chromium/net/socket/ssl_client_socket_unittest.cc
+++ b/chromium/net/socket/ssl_client_socket_unittest.cc
@@ -30,9 +30,11 @@
//-----------------------------------------------------------------------------
+namespace net {
+
namespace {
-const net::SSLConfig kDefaultSSLConfig;
+const SSLConfig kDefaultSSLConfig;
// WrappedStreamSocket is a base class that wraps an existing StreamSocket,
// forwarding the Socket and StreamSocket interfaces to the underlying
@@ -40,33 +42,30 @@ const net::SSLConfig kDefaultSSLConfig;
// This is to provide a common base class for subclasses to override specific
// StreamSocket methods for testing, while still communicating with a 'real'
// StreamSocket.
-class WrappedStreamSocket : public net::StreamSocket {
+class WrappedStreamSocket : public StreamSocket {
public:
- explicit WrappedStreamSocket(scoped_ptr<net::StreamSocket> transport)
- : transport_(transport.Pass()) {
- }
+ explicit WrappedStreamSocket(scoped_ptr<StreamSocket> transport)
+ : transport_(transport.Pass()) {}
virtual ~WrappedStreamSocket() {}
// StreamSocket implementation:
- virtual int Connect(const net::CompletionCallback& callback) OVERRIDE {
+ virtual int Connect(const CompletionCallback& callback) OVERRIDE {
return transport_->Connect(callback);
}
- virtual void Disconnect() OVERRIDE {
- transport_->Disconnect();
- }
+ virtual void Disconnect() OVERRIDE { transport_->Disconnect(); }
virtual bool IsConnected() const OVERRIDE {
return transport_->IsConnected();
}
virtual bool IsConnectedAndIdle() const OVERRIDE {
return transport_->IsConnectedAndIdle();
}
- virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE {
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
return transport_->GetPeerAddress(address);
}
- virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE {
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
return transport_->GetLocalAddress(address);
}
- virtual const net::BoundNetLog& NetLog() const OVERRIDE {
+ virtual const BoundNetLog& NetLog() const OVERRIDE {
return transport_->NetLog();
}
virtual void SetSubresourceSpeculation() OVERRIDE {
@@ -84,20 +83,22 @@ class WrappedStreamSocket : public net::StreamSocket {
virtual bool WasNpnNegotiated() const OVERRIDE {
return transport_->WasNpnNegotiated();
}
- virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE {
+ virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
return transport_->GetNegotiatedProtocol();
}
- virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE {
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE {
return transport_->GetSSLInfo(ssl_info);
}
// Socket implementation:
- virtual int Read(net::IOBuffer* buf, int buf_len,
- const net::CompletionCallback& callback) OVERRIDE {
+ virtual int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) OVERRIDE {
return transport_->Read(buf, buf_len, callback);
}
- virtual int Write(net::IOBuffer* buf, int buf_len,
- const net::CompletionCallback& callback) OVERRIDE {
+ virtual int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) OVERRIDE {
return transport_->Write(buf, buf_len, callback);
}
virtual bool SetReceiveBufferSize(int32 size) OVERRIDE {
@@ -108,7 +109,7 @@ class WrappedStreamSocket : public net::StreamSocket {
}
protected:
- scoped_ptr<net::StreamSocket> transport_;
+ scoped_ptr<StreamSocket> transport_;
};
// ReadBufferingStreamSocket is a wrapper for an existing StreamSocket that
@@ -119,12 +120,13 @@ class WrappedStreamSocket : public net::StreamSocket {
// them from the TestServer.
class ReadBufferingStreamSocket : public WrappedStreamSocket {
public:
- explicit ReadBufferingStreamSocket(scoped_ptr<net::StreamSocket> transport);
+ explicit ReadBufferingStreamSocket(scoped_ptr<StreamSocket> transport);
virtual ~ReadBufferingStreamSocket() {}
// Socket implementation:
- virtual int Read(net::IOBuffer* buf, int buf_len,
- const net::CompletionCallback& callback) OVERRIDE;
+ virtual int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
// Sets the internal buffer to |size|. This must not be greater than
// the largest value supplied to Read() - that is, it does not handle
@@ -148,19 +150,18 @@ class ReadBufferingStreamSocket : public WrappedStreamSocket {
void OnReadCompleted(int result);
State state_;
- scoped_refptr<net::GrowableIOBuffer> read_buffer_;
+ scoped_refptr<GrowableIOBuffer> read_buffer_;
int buffer_size_;
- scoped_refptr<net::IOBuffer> user_read_buf_;
- net::CompletionCallback user_read_callback_;
+ scoped_refptr<IOBuffer> user_read_buf_;
+ CompletionCallback user_read_callback_;
};
ReadBufferingStreamSocket::ReadBufferingStreamSocket(
- scoped_ptr<net::StreamSocket> transport)
+ scoped_ptr<StreamSocket> transport)
: WrappedStreamSocket(transport.Pass()),
- read_buffer_(new net::GrowableIOBuffer()),
- buffer_size_(0) {
-}
+ read_buffer_(new GrowableIOBuffer()),
+ buffer_size_(0) {}
void ReadBufferingStreamSocket::SetBufferSize(int size) {
DCHECK(!user_read_buf_.get());
@@ -168,19 +169,19 @@ void ReadBufferingStreamSocket::SetBufferSize(int size) {
read_buffer_->SetCapacity(size);
}
-int ReadBufferingStreamSocket::Read(net::IOBuffer* buf,
+int ReadBufferingStreamSocket::Read(IOBuffer* buf,
int buf_len,
- const net::CompletionCallback& callback) {
+ const CompletionCallback& callback) {
if (buffer_size_ == 0)
return transport_->Read(buf, buf_len, callback);
if (buf_len < buffer_size_)
- return net::ERR_UNEXPECTED;
+ return ERR_UNEXPECTED;
state_ = STATE_READ;
user_read_buf_ = buf;
- int result = DoLoop(net::OK);
- if (result == net::ERR_IO_PENDING)
+ int result = DoLoop(OK);
+ if (result == ERR_IO_PENDING)
user_read_callback_ = callback;
else
user_read_buf_ = NULL;
@@ -202,10 +203,10 @@ int ReadBufferingStreamSocket::DoLoop(int result) {
case STATE_NONE:
default:
NOTREACHED() << "Unexpected state: " << current_state;
- rv = net::ERR_UNEXPECTED;
+ rv = ERR_UNEXPECTED;
break;
}
- } while (rv != net::ERR_IO_PENDING && state_ != STATE_NONE);
+ } while (rv != ERR_IO_PENDING && state_ != STATE_NONE);
return rv;
}
@@ -227,10 +228,11 @@ int ReadBufferingStreamSocket::DoReadComplete(int result) {
read_buffer_->set_offset(read_buffer_->offset() + result);
if (read_buffer_->RemainingCapacity() > 0) {
state_ = STATE_READ;
- return net::OK;
+ return OK;
}
- memcpy(user_read_buf_->data(), read_buffer_->StartOfBuffer(),
+ memcpy(user_read_buf_->data(),
+ read_buffer_->StartOfBuffer(),
read_buffer_->capacity());
read_buffer_->set_offset(0);
return read_buffer_->capacity();
@@ -238,7 +240,7 @@ int ReadBufferingStreamSocket::DoReadComplete(int result) {
void ReadBufferingStreamSocket::OnReadCompleted(int result) {
result = DoLoop(result);
- if (result == net::ERR_IO_PENDING)
+ if (result == ERR_IO_PENDING)
return;
user_read_buf_ = NULL;
@@ -252,16 +254,18 @@ class SynchronousErrorStreamSocket : public WrappedStreamSocket {
virtual ~SynchronousErrorStreamSocket() {}
// Socket implementation:
- virtual int Read(net::IOBuffer* buf, int buf_len,
- const net::CompletionCallback& callback) OVERRIDE;
- virtual int Write(net::IOBuffer* buf, int buf_len,
- const net::CompletionCallback& callback) OVERRIDE;
+ virtual int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
// Sets the next Read() call and all future calls to return |error|.
// If there is already a pending asynchronous read, the configured error
// will not be returned until that asynchronous read has completed and Read()
// is called again.
- void SetNextReadError(net::Error error) {
+ void SetNextReadError(Error error) {
DCHECK_GE(0, error);
have_read_error_ = true;
pending_read_error_ = error;
@@ -271,7 +275,7 @@ class SynchronousErrorStreamSocket : public WrappedStreamSocket {
// If there is already a pending asynchronous write, the configured error
// will not be returned until that asynchronous write has completed and
// Write() is called again.
- void SetNextWriteError(net::Error error) {
+ void SetNextWriteError(Error error) {
DCHECK_GE(0, error);
have_write_error_ = true;
pending_write_error_ = error;
@@ -291,24 +295,21 @@ SynchronousErrorStreamSocket::SynchronousErrorStreamSocket(
scoped_ptr<StreamSocket> transport)
: WrappedStreamSocket(transport.Pass()),
have_read_error_(false),
- pending_read_error_(net::OK),
+ pending_read_error_(OK),
have_write_error_(false),
- pending_write_error_(net::OK) {
-}
+ pending_write_error_(OK) {}
-int SynchronousErrorStreamSocket::Read(
- net::IOBuffer* buf,
- int buf_len,
- const net::CompletionCallback& callback) {
+int SynchronousErrorStreamSocket::Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
if (have_read_error_)
return pending_read_error_;
return transport_->Read(buf, buf_len, callback);
}
-int SynchronousErrorStreamSocket::Write(
- net::IOBuffer* buf,
- int buf_len,
- const net::CompletionCallback& callback) {
+int SynchronousErrorStreamSocket::Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
if (have_write_error_)
return pending_write_error_;
return transport_->Write(buf, buf_len, callback);
@@ -324,12 +325,14 @@ class FakeBlockingStreamSocket : public WrappedStreamSocket {
virtual ~FakeBlockingStreamSocket() {}
// Socket implementation:
- virtual int Read(net::IOBuffer* buf, int buf_len,
- const net::CompletionCallback& callback) OVERRIDE {
+ virtual int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) OVERRIDE {
return read_state_.RunWrappedFunction(buf, buf_len, callback);
}
- virtual int Write(net::IOBuffer* buf, int buf_len,
- const net::CompletionCallback& callback) OVERRIDE {
+ virtual int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) OVERRIDE {
return write_state_.RunWrappedFunction(buf, buf_len, callback);
}
@@ -350,9 +353,8 @@ class FakeBlockingStreamSocket : public WrappedStreamSocket {
class BlockingState {
public:
// Wrapper for the underlying Socket function to call (ie: Read/Write).
- typedef base::Callback<
- int(net::IOBuffer*, int,
- const net::CompletionCallback&)> WrappedSocketFunction;
+ typedef base::Callback<int(IOBuffer*, int, const CompletionCallback&)>
+ WrappedSocketFunction;
explicit BlockingState(const WrappedSocketFunction& function);
~BlockingState() {}
@@ -371,8 +373,9 @@ class FakeBlockingStreamSocket : public WrappedStreamSocket {
// Performs the wrapped socket function on the underlying transport. If
// configured to block via SetShouldBlock(), then |user_callback| will not
// be invoked until Unblock() has been called.
- int RunWrappedFunction(net::IOBuffer* buf, int len,
- const net::CompletionCallback& user_callback);
+ int RunWrappedFunction(IOBuffer* buf,
+ int len,
+ const CompletionCallback& user_callback);
private:
// Handles completion from the underlying wrapped socket function.
@@ -382,7 +385,7 @@ class FakeBlockingStreamSocket : public WrappedStreamSocket {
bool should_block_;
bool have_result_;
int pending_result_;
- net::CompletionCallback user_callback_;
+ CompletionCallback user_callback_;
};
BlockingState read_state_;
@@ -397,16 +400,14 @@ FakeBlockingStreamSocket::FakeBlockingStreamSocket(
read_state_(base::Bind(&Socket::Read,
base::Unretained(transport_.get()))),
write_state_(base::Bind(&Socket::Write,
- base::Unretained(transport_.get()))) {
-}
+ base::Unretained(transport_.get()))) {}
FakeBlockingStreamSocket::BlockingState::BlockingState(
const WrappedSocketFunction& function)
: wrapped_function_(function),
should_block_(false),
have_result_(false),
- pending_result_(net::OK) {
-}
+ pending_result_(OK) {}
void FakeBlockingStreamSocket::BlockingState::SetShouldBlock() {
DCHECK(!should_block_);
@@ -429,24 +430,24 @@ void FakeBlockingStreamSocket::BlockingState::Unblock() {
}
int FakeBlockingStreamSocket::BlockingState::RunWrappedFunction(
- net::IOBuffer* buf,
+ IOBuffer* buf,
int len,
- const net::CompletionCallback& callback) {
+ const CompletionCallback& callback) {
// The callback to be called by the underlying transport. Either forward
// directly to the user's callback if not set to block, or intercept it with
// OnCompleted so that the user's callback is not invoked until Unblock() is
// called.
- net::CompletionCallback transport_callback =
+ CompletionCallback transport_callback =
!should_block_ ? callback : base::Bind(&BlockingState::OnCompleted,
base::Unretained(this));
int rv = wrapped_function_.Run(buf, len, transport_callback);
if (should_block_) {
user_callback_ = callback;
// May have completed synchronously.
- have_result_ = (rv != net::ERR_IO_PENDING);
+ have_result_ = (rv != ERR_IO_PENDING);
pending_result_ = rv;
- return net::ERR_IO_PENDING;
+ return ERR_IO_PENDING;
}
return rv;
@@ -466,64 +467,61 @@ void FakeBlockingStreamSocket::BlockingState::OnCompleted(int result) {
base::ResetAndReturn(&user_callback_).Run(result);
}
-// CompletionCallback that will delete the associated net::StreamSocket when
+// CompletionCallback that will delete the associated StreamSocket when
// the callback is invoked.
-class DeleteSocketCallback : public net::TestCompletionCallbackBase {
+class DeleteSocketCallback : public TestCompletionCallbackBase {
public:
- explicit DeleteSocketCallback(net::StreamSocket* socket)
+ explicit DeleteSocketCallback(StreamSocket* socket)
: socket_(socket),
callback_(base::Bind(&DeleteSocketCallback::OnComplete,
- base::Unretained(this))) {
- }
+ base::Unretained(this))) {}
virtual ~DeleteSocketCallback() {}
- const net::CompletionCallback& callback() const { return callback_; }
+ const CompletionCallback& callback() const { return callback_; }
private:
void OnComplete(int result) {
- if (socket_) {
- delete socket_;
- socket_ = NULL;
- } else {
- ADD_FAILURE() << "Deleting socket twice";
- }
- SetResult(result);
+ if (socket_) {
+ delete socket_;
+ socket_ = NULL;
+ } else {
+ ADD_FAILURE() << "Deleting socket twice";
+ }
+ SetResult(result);
}
- net::StreamSocket* socket_;
- net::CompletionCallback callback_;
+ StreamSocket* socket_;
+ CompletionCallback callback_;
DISALLOW_COPY_AND_ASSIGN(DeleteSocketCallback);
};
-} // namespace
-
class SSLClientSocketTest : public PlatformTest {
public:
SSLClientSocketTest()
- : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()),
- cert_verifier_(new net::MockCertVerifier),
- transport_security_state_(new net::TransportSecurityState) {
- cert_verifier_->set_default_result(net::OK);
+ : socket_factory_(ClientSocketFactory::GetDefaultFactory()),
+ cert_verifier_(new MockCertVerifier),
+ transport_security_state_(new TransportSecurityState) {
+ cert_verifier_->set_default_result(OK);
context_.cert_verifier = cert_verifier_.get();
context_.transport_security_state = transport_security_state_.get();
}
protected:
- net::SSLClientSocket* CreateSSLClientSocket(
- net::StreamSocket* transport_socket,
- const net::HostPortPair& host_and_port,
- const net::SSLConfig& ssl_config) {
- return socket_factory_->CreateSSLClientSocket(transport_socket,
- host_and_port,
- ssl_config,
- context_);
+ scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<StreamSocket> transport_socket,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config) {
+ scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
+ connection->SetSocket(transport_socket.Pass());
+ return socket_factory_->CreateSSLClientSocket(
+ connection.Pass(), host_and_port, ssl_config, context_);
}
- net::ClientSocketFactory* socket_factory_;
- scoped_ptr<net::MockCertVerifier> cert_verifier_;
- scoped_ptr<net::TransportSecurityState> transport_security_state_;
- net::SSLClientSocketContext context_;
+ ClientSocketFactory* socket_factory_;
+ scoped_ptr<MockCertVerifier> cert_verifier_;
+ scoped_ptr<TransportSecurityState> transport_security_state_;
+ SSLClientSocketContext context_;
};
//-----------------------------------------------------------------------------
@@ -536,45 +534,45 @@ class SSLClientSocketTest : public PlatformTest {
// timeout. This means that an SSL connect end event may appear as a socket
// write.
static bool LogContainsSSLConnectEndEvent(
- const net::CapturingNetLog::CapturedEntryList& log, int i) {
- return net::LogContainsEndEvent(log, i, net::NetLog::TYPE_SSL_CONNECT) ||
- net::LogContainsEvent(log, i, net::NetLog::TYPE_SOCKET_BYTES_SENT,
- net::NetLog::PHASE_NONE);
-};
+ const CapturingNetLog::CapturedEntryList& log,
+ int i) {
+ return LogContainsEndEvent(log, i, NetLog::TYPE_SSL_CONNECT) ||
+ LogContainsEvent(
+ log, i, NetLog::TYPE_SOCKET_BYTES_SENT, NetLog::PHASE_NONE);
+}
+;
TEST_F(SSLClientSocketTest, Connect) {
- net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS,
- net::SpawnedTestServer::kLocalhost,
- base::FilePath());
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
ASSERT_TRUE(test_server.Start());
- net::AddressList addr;
+ AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
- net::TestCompletionCallback callback;
- net::CapturingNetLog log;
- net::StreamSocket* transport = new net::TCPClientSocket(
- addr, &log, net::NetLog::Source());
+ TestCompletionCallback callback;
+ CapturingNetLog log;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
- scoped_ptr<net::SSLClientSocket> sock(
- CreateSSLClientSocket(transport, test_server.host_port_pair(),
- kDefaultSSLConfig));
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
EXPECT_FALSE(sock->IsConnected());
rv = sock->Connect(callback.callback());
- net::CapturingNetLog::CapturedEntryList entries;
+ CapturingNetLog::CapturedEntryList entries;
log.GetEntries(&entries);
- EXPECT_TRUE(net::LogContainsBeginEvent(
- entries, 5, net::NetLog::TYPE_SSL_CONNECT));
- if (rv == net::ERR_IO_PENDING)
+ EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
EXPECT_TRUE(sock->IsConnected());
log.GetEntries(&entries);
EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
@@ -584,43 +582,40 @@ TEST_F(SSLClientSocketTest, Connect) {
}
TEST_F(SSLClientSocketTest, ConnectExpired) {
- net::SpawnedTestServer::SSLOptions ssl_options(
- net::SpawnedTestServer::SSLOptions::CERT_EXPIRED);
- net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS,
- ssl_options,
- base::FilePath());
+ SpawnedTestServer::SSLOptions ssl_options(
+ SpawnedTestServer::SSLOptions::CERT_EXPIRED);
+ SpawnedTestServer test_server(
+ SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
ASSERT_TRUE(test_server.Start());
- cert_verifier_->set_default_result(net::ERR_CERT_DATE_INVALID);
+ cert_verifier_->set_default_result(ERR_CERT_DATE_INVALID);
- net::AddressList addr;
+ AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
- net::TestCompletionCallback callback;
- net::CapturingNetLog log;
- net::StreamSocket* transport = new net::TCPClientSocket(
- addr, &log, net::NetLog::Source());
+ TestCompletionCallback callback;
+ CapturingNetLog log;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
- scoped_ptr<net::SSLClientSocket> sock(
- CreateSSLClientSocket(transport, test_server.host_port_pair(),
- kDefaultSSLConfig));
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
EXPECT_FALSE(sock->IsConnected());
rv = sock->Connect(callback.callback());
- net::CapturingNetLog::CapturedEntryList entries;
+ CapturingNetLog::CapturedEntryList entries;
log.GetEntries(&entries);
- EXPECT_TRUE(net::LogContainsBeginEvent(
- entries, 5, net::NetLog::TYPE_SSL_CONNECT));
- if (rv == net::ERR_IO_PENDING)
+ EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::ERR_CERT_DATE_INVALID, rv);
+ EXPECT_EQ(ERR_CERT_DATE_INVALID, rv);
// Rather than testing whether or not the underlying socket is connected,
// test that the handshake has finished. This is because it may be
@@ -631,43 +626,40 @@ TEST_F(SSLClientSocketTest, ConnectExpired) {
}
TEST_F(SSLClientSocketTest, ConnectMismatched) {
- net::SpawnedTestServer::SSLOptions ssl_options(
- net::SpawnedTestServer::SSLOptions::CERT_MISMATCHED_NAME);
- net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS,
- ssl_options,
- base::FilePath());
+ SpawnedTestServer::SSLOptions ssl_options(
+ SpawnedTestServer::SSLOptions::CERT_MISMATCHED_NAME);
+ SpawnedTestServer test_server(
+ SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
ASSERT_TRUE(test_server.Start());
- cert_verifier_->set_default_result(net::ERR_CERT_COMMON_NAME_INVALID);
+ cert_verifier_->set_default_result(ERR_CERT_COMMON_NAME_INVALID);
- net::AddressList addr;
+ AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
- net::TestCompletionCallback callback;
- net::CapturingNetLog log;
- net::StreamSocket* transport = new net::TCPClientSocket(
- addr, &log, net::NetLog::Source());
+ TestCompletionCallback callback;
+ CapturingNetLog log;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
- scoped_ptr<net::SSLClientSocket> sock(
- CreateSSLClientSocket(transport, test_server.host_port_pair(),
- kDefaultSSLConfig));
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
EXPECT_FALSE(sock->IsConnected());
rv = sock->Connect(callback.callback());
- net::CapturingNetLog::CapturedEntryList entries;
+ CapturingNetLog::CapturedEntryList entries;
log.GetEntries(&entries);
- EXPECT_TRUE(net::LogContainsBeginEvent(
- entries, 5, net::NetLog::TYPE_SSL_CONNECT));
- if (rv == net::ERR_IO_PENDING)
+ EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::ERR_CERT_COMMON_NAME_INVALID, rv);
+ EXPECT_EQ(ERR_CERT_COMMON_NAME_INVALID, rv);
// Rather than testing whether or not the underlying socket is connected,
// test that the handshake has finished. This is because it may be
@@ -680,38 +672,35 @@ TEST_F(SSLClientSocketTest, ConnectMismatched) {
// Attempt to connect to a page which requests a client certificate. It should
// return an error code on connect.
TEST_F(SSLClientSocketTest, ConnectClientAuthCertRequested) {
- net::SpawnedTestServer::SSLOptions ssl_options;
+ SpawnedTestServer::SSLOptions ssl_options;
ssl_options.request_client_certificate = true;
- net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS,
- ssl_options,
- base::FilePath());
+ SpawnedTestServer test_server(
+ SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
ASSERT_TRUE(test_server.Start());
- net::AddressList addr;
+ AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
- net::TestCompletionCallback callback;
- net::CapturingNetLog log;
- net::StreamSocket* transport = new net::TCPClientSocket(
- addr, &log, net::NetLog::Source());
+ TestCompletionCallback callback;
+ CapturingNetLog log;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
- scoped_ptr<net::SSLClientSocket> sock(
- CreateSSLClientSocket(transport, test_server.host_port_pair(),
- kDefaultSSLConfig));
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
EXPECT_FALSE(sock->IsConnected());
rv = sock->Connect(callback.callback());
- net::CapturingNetLog::CapturedEntryList entries;
+ CapturingNetLog::CapturedEntryList entries;
log.GetEntries(&entries);
- EXPECT_TRUE(net::LogContainsBeginEvent(
- entries, 5, net::NetLog::TYPE_SSL_CONNECT));
- if (rv == net::ERR_IO_PENDING)
+ EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
log.GetEntries(&entries);
@@ -731,9 +720,9 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthCertRequested) {
// certificate. This test may still be useful as we'll want to close
// the socket on a timeout if the user takes a long time to pick a
// cert. Related bug: https://bugzilla.mozilla.org/show_bug.cgi?id=542832
- net::ExpectLogContainsSomewhere(
- entries, 0, net::NetLog::TYPE_SSL_CONNECT, net::NetLog::PHASE_END);
- EXPECT_EQ(net::ERR_SSL_CLIENT_AUTH_CERT_NEEDED, rv);
+ ExpectLogContainsSomewhere(
+ entries, 0, NetLog::TYPE_SSL_CONNECT, NetLog::PHASE_END);
+ EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, rv);
EXPECT_FALSE(sock->IsConnected());
}
@@ -742,32 +731,30 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthCertRequested) {
//
// TODO(davidben): Also test providing an actual certificate.
TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) {
- net::SpawnedTestServer::SSLOptions ssl_options;
+ SpawnedTestServer::SSLOptions ssl_options;
ssl_options.request_client_certificate = true;
- net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS,
- ssl_options,
- base::FilePath());
+ SpawnedTestServer test_server(
+ SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
ASSERT_TRUE(test_server.Start());
- net::AddressList addr;
+ AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
- net::TestCompletionCallback callback;
- net::CapturingNetLog log;
- net::StreamSocket* transport = new net::TCPClientSocket(
- addr, &log, net::NetLog::Source());
+ TestCompletionCallback callback;
+ CapturingNetLog log;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
- net::SSLConfig ssl_config = kDefaultSSLConfig;
+ SSLConfig ssl_config = kDefaultSSLConfig;
ssl_config.send_client_cert = true;
ssl_config.client_cert = NULL;
- scoped_ptr<net::SSLClientSocket> sock(
- CreateSSLClientSocket(transport, test_server.host_port_pair(),
- ssl_config));
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), ssl_config));
EXPECT_FALSE(sock->IsConnected());
@@ -775,14 +762,13 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) {
// TODO(davidben): Add a test which requires them and verify the error.
rv = sock->Connect(callback.callback());
- net::CapturingNetLog::CapturedEntryList entries;
+ CapturingNetLog::CapturedEntryList entries;
log.GetEntries(&entries);
- EXPECT_TRUE(net::LogContainsBeginEvent(
- entries, 5, net::NetLog::TYPE_SSL_CONNECT));
- if (rv == net::ERR_IO_PENDING)
+ EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
EXPECT_TRUE(sock->IsConnected());
log.GetEntries(&entries);
EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
@@ -790,7 +776,7 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) {
// We responded to the server's certificate request with a Certificate
// message with no client certificate in it. ssl_info.client_cert_sent
// should be false in this case.
- net::SSLInfo ssl_info;
+ SSLInfo ssl_info;
sock->GetSSLInfo(&ssl_info);
EXPECT_FALSE(ssl_info.client_cert_sent);
@@ -804,51 +790,50 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) {
// - Server sends data unexpectedly.
TEST_F(SSLClientSocketTest, Read) {
- net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS,
- net::SpawnedTestServer::kLocalhost,
- base::FilePath());
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
ASSERT_TRUE(test_server.Start());
- net::AddressList addr;
+ AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
- net::TestCompletionCallback callback;
- net::StreamSocket* transport = new net::TCPClientSocket(
- addr, NULL, net::NetLog::Source());
+ TestCompletionCallback callback;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
int rv = transport->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
- scoped_ptr<net::SSLClientSocket> sock(
- CreateSSLClientSocket(transport, test_server.host_port_pair(),
- kDefaultSSLConfig));
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
rv = sock->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
EXPECT_TRUE(sock->IsConnected());
const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
- scoped_refptr<net::IOBuffer> request_buffer(
- new net::IOBuffer(arraysize(request_text) - 1));
+ scoped_refptr<IOBuffer> request_buffer(
+ new IOBuffer(arraysize(request_text) - 1));
memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
rv = sock->Write(
request_buffer.get(), arraysize(request_text) - 1, callback.callback());
- EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING);
+ EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv);
- scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(4096));
+ scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
for (;;) {
rv = sock->Read(buf.get(), 4096, callback.callback());
- EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING);
+ EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_GE(rv, 0);
@@ -862,39 +847,40 @@ TEST_F(SSLClientSocketTest, Read) {
// the socket connection uncleanly.
// This is a regression test for http://crbug.com/238536
TEST_F(SSLClientSocketTest, Read_WithSynchronousError) {
- net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS,
- net::SpawnedTestServer::kLocalhost,
- base::FilePath());
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
ASSERT_TRUE(test_server.Start());
- net::AddressList addr;
+ AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
- net::TestCompletionCallback callback;
- scoped_ptr<net::StreamSocket> real_transport(new net::TCPClientSocket(
- addr, NULL, net::NetLog::Source()));
- SynchronousErrorStreamSocket* transport = new SynchronousErrorStreamSocket(
- real_transport.Pass());
+ TestCompletionCallback callback;
+ scoped_ptr<StreamSocket> real_transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
+ scoped_ptr<SynchronousErrorStreamSocket> transport(
+ new SynchronousErrorStreamSocket(real_transport.Pass()));
int rv = callback.GetResult(transport->Connect(callback.callback()));
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
// Disable TLS False Start to avoid handshake non-determinism.
- net::SSLConfig ssl_config;
+ SSLConfig ssl_config;
ssl_config.false_start_enabled = false;
- scoped_ptr<net::SSLClientSocket> sock(
- CreateSSLClientSocket(transport, test_server.host_port_pair(),
+ SynchronousErrorStreamSocket* raw_transport = transport.get();
+ scoped_ptr<SSLClientSocket> sock(
+ CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
+ test_server.host_port_pair(),
ssl_config));
rv = callback.GetResult(sock->Connect(callback.callback()));
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
EXPECT_TRUE(sock->IsConnected());
const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
static const int kRequestTextSize =
static_cast<int>(arraysize(request_text) - 1);
- scoped_refptr<net::IOBuffer> request_buffer(
- new net::IOBuffer(kRequestTextSize));
+ scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kRequestTextSize));
memcpy(request_buffer->data(), request_text, kRequestTextSize);
rv = callback.GetResult(
@@ -902,9 +888,9 @@ TEST_F(SSLClientSocketTest, Read_WithSynchronousError) {
EXPECT_EQ(kRequestTextSize, rv);
// Simulate an unclean/forcible shutdown.
- transport->SetNextReadError(net::ERR_CONNECTION_RESET);
+ raw_transport->SetNextReadError(ERR_CONNECTION_RESET);
- scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(4096));
+ scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
// Note: This test will hang if this bug has regressed. Simply checking that
// rv != ERR_IO_PENDING is insufficient, as ERR_IO_PENDING is a legitimate
@@ -913,7 +899,7 @@ TEST_F(SSLClientSocketTest, Read_WithSynchronousError) {
#if !defined(USE_OPENSSL)
// SSLClientSocketNSS records the error exactly
- EXPECT_EQ(net::ERR_CONNECTION_RESET, rv);
+ EXPECT_EQ(ERR_CONNECTION_RESET, rv);
#else
// SSLClientSocketOpenSSL treats any errors as a simple EOF.
EXPECT_EQ(0, rv);
@@ -925,49 +911,51 @@ TEST_F(SSLClientSocketTest, Read_WithSynchronousError) {
// intermediary terminates the socket connection uncleanly.
// This is a regression test for http://crbug.com/249848
TEST_F(SSLClientSocketTest, Write_WithSynchronousError) {
- net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS,
- net::SpawnedTestServer::kLocalhost,
- base::FilePath());
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
ASSERT_TRUE(test_server.Start());
- net::AddressList addr;
+ AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
- net::TestCompletionCallback callback;
- scoped_ptr<net::StreamSocket> real_transport(new net::TCPClientSocket(
- addr, NULL, net::NetLog::Source()));
- // Note: |error_socket|'s ownership is handed to |transport|, but the pointer
+ TestCompletionCallback callback;
+ scoped_ptr<StreamSocket> real_transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
+ // Note: |error_socket|'s ownership is handed to |transport|, but a pointer
// is retained in order to configure additional errors.
- SynchronousErrorStreamSocket* error_socket = new SynchronousErrorStreamSocket(
- real_transport.Pass());
- FakeBlockingStreamSocket* transport = new FakeBlockingStreamSocket(
- scoped_ptr<net::StreamSocket>(error_socket));
+ scoped_ptr<SynchronousErrorStreamSocket> error_socket(
+ new SynchronousErrorStreamSocket(real_transport.Pass()));
+ SynchronousErrorStreamSocket* raw_error_socket = error_socket.get();
+ scoped_ptr<FakeBlockingStreamSocket> transport(
+ new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>()));
+ FakeBlockingStreamSocket* raw_transport = transport.get();
int rv = callback.GetResult(transport->Connect(callback.callback()));
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
// Disable TLS False Start to avoid handshake non-determinism.
- net::SSLConfig ssl_config;
+ SSLConfig ssl_config;
ssl_config.false_start_enabled = false;
- scoped_ptr<net::SSLClientSocket> sock(
- CreateSSLClientSocket(transport, test_server.host_port_pair(),
+ scoped_ptr<SSLClientSocket> sock(
+ CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
+ test_server.host_port_pair(),
ssl_config));
rv = callback.GetResult(sock->Connect(callback.callback()));
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
EXPECT_TRUE(sock->IsConnected());
const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
static const int kRequestTextSize =
static_cast<int>(arraysize(request_text) - 1);
- scoped_refptr<net::IOBuffer> request_buffer(
- new net::IOBuffer(kRequestTextSize));
+ scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kRequestTextSize));
memcpy(request_buffer->data(), request_text, kRequestTextSize);
// Simulate an unclean/forcible shutdown on the underlying socket.
// However, simulate this error asynchronously.
- error_socket->SetNextWriteError(net::ERR_CONNECTION_RESET);
- transport->SetNextWriteShouldBlock();
+ raw_error_socket->SetNextWriteError(ERR_CONNECTION_RESET);
+ raw_transport->SetNextWriteShouldBlock();
// This write should complete synchronously, because the TLS ciphertext
// can be created and placed into the outgoing buffers independent of the
@@ -976,14 +964,14 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) {
sock->Write(request_buffer.get(), kRequestTextSize, callback.callback()));
EXPECT_EQ(kRequestTextSize, rv);
- scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(4096));
+ scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
rv = sock->Read(buf.get(), 4096, callback.callback());
- EXPECT_EQ(net::ERR_IO_PENDING, rv);
+ EXPECT_EQ(ERR_IO_PENDING, rv);
// Now unblock the outgoing request, having it fail with the connection
// being reset.
- transport->UnblockWrite();
+ raw_transport->UnblockWrite();
// Note: This will cause an inifite loop if this bug has regressed. Simply
// checking that rv != ERR_IO_PENDING is insufficient, as ERR_IO_PENDING
@@ -992,7 +980,7 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) {
#if !defined(USE_OPENSSL)
// SSLClientSocketNSS records the error exactly
- EXPECT_EQ(net::ERR_CONNECTION_RESET, rv);
+ EXPECT_EQ(ERR_CONNECTION_RESET, rv);
#else
// SSLClientSocketOpenSSL treats any errors as a simple EOF.
EXPECT_EQ(0, rv);
@@ -1002,38 +990,37 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) {
// Test the full duplex mode, with Read and Write pending at the same time.
// This test also serves as a regression test for http://crbug.com/29815.
TEST_F(SSLClientSocketTest, Read_FullDuplex) {
- net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS,
- net::SpawnedTestServer::kLocalhost,
- base::FilePath());
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
ASSERT_TRUE(test_server.Start());
- net::AddressList addr;
+ AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
- net::TestCompletionCallback callback; // Used for everything except Write.
+ TestCompletionCallback callback; // Used for everything except Write.
- net::StreamSocket* transport = new net::TCPClientSocket(
- addr, NULL, net::NetLog::Source());
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
int rv = transport->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
- scoped_ptr<net::SSLClientSocket> sock(
- CreateSSLClientSocket(transport, test_server.host_port_pair(),
- kDefaultSSLConfig));
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
rv = sock->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
EXPECT_TRUE(sock->IsConnected());
// Issue a "hanging" Read first.
- scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(4096));
+ scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
rv = sock->Read(buf.get(), 4096, callback.callback());
// We haven't written the request, so there should be no response yet.
- ASSERT_EQ(net::ERR_IO_PENDING, rv);
+ ASSERT_EQ(ERR_IO_PENDING, rv);
// Write the request.
// The request is padded with a User-Agent header to a size that causes the
@@ -1043,15 +1030,14 @@ TEST_F(SSLClientSocketTest, Read_FullDuplex) {
for (int i = 0; i < 3770; ++i)
request_text.push_back('*');
request_text.append("\r\n\r\n");
- scoped_refptr<net::IOBuffer> request_buffer(
- new net::StringIOBuffer(request_text));
+ scoped_refptr<IOBuffer> request_buffer(new StringIOBuffer(request_text));
- net::TestCompletionCallback callback2; // Used for Write only.
+ TestCompletionCallback callback2; // Used for Write only.
rv = sock->Write(
request_buffer.get(), request_text.size(), callback2.callback());
- EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING);
+ EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback2.WaitForResult();
EXPECT_EQ(static_cast<int>(request_text.size()), rv);
@@ -1067,62 +1053,65 @@ TEST_F(SSLClientSocketTest, Read_FullDuplex) {
// callback, the Write() callback should not be invoked.
// Regression test for http://crbug.com/232633
TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) {
- net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS,
- net::SpawnedTestServer::kLocalhost,
- base::FilePath());
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
ASSERT_TRUE(test_server.Start());
- net::AddressList addr;
+ AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
- net::TestCompletionCallback callback;
- scoped_ptr<net::StreamSocket> real_transport(new net::TCPClientSocket(
- addr, NULL, net::NetLog::Source()));
- // Note: |error_socket|'s ownership is handed to |transport|, but the pointer
+ TestCompletionCallback callback;
+ scoped_ptr<StreamSocket> real_transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
+ // Note: |error_socket|'s ownership is handed to |transport|, but a pointer
// is retained in order to configure additional errors.
- SynchronousErrorStreamSocket* error_socket = new SynchronousErrorStreamSocket(
- real_transport.Pass());
- FakeBlockingStreamSocket* transport = new FakeBlockingStreamSocket(
- scoped_ptr<net::StreamSocket>(error_socket));
+ scoped_ptr<SynchronousErrorStreamSocket> error_socket(
+ new SynchronousErrorStreamSocket(real_transport.Pass()));
+ SynchronousErrorStreamSocket* raw_error_socket = error_socket.get();
+ scoped_ptr<FakeBlockingStreamSocket> transport(
+ new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>()));
+ FakeBlockingStreamSocket* raw_transport = transport.get();
int rv = callback.GetResult(transport->Connect(callback.callback()));
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
// Disable TLS False Start to avoid handshake non-determinism.
- net::SSLConfig ssl_config;
+ SSLConfig ssl_config;
ssl_config.false_start_enabled = false;
- net::SSLClientSocket* sock(
- CreateSSLClientSocket(transport, test_server.host_port_pair(),
- ssl_config));
+ scoped_ptr<SSLClientSocket> sock =
+ CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
+ test_server.host_port_pair(),
+ ssl_config);
rv = callback.GetResult(sock->Connect(callback.callback()));
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
EXPECT_TRUE(sock->IsConnected());
std::string request_text = "GET / HTTP/1.1\r\nUser-Agent: long browser name ";
request_text.append(20 * 1024, '*');
request_text.append("\r\n\r\n");
- scoped_refptr<net::DrainableIOBuffer> request_buffer(
- new net::DrainableIOBuffer(new net::StringIOBuffer(request_text),
- request_text.size()));
+ scoped_refptr<DrainableIOBuffer> request_buffer(new DrainableIOBuffer(
+ new StringIOBuffer(request_text), request_text.size()));
// Simulate errors being returned from the underlying Read() and Write() ...
- error_socket->SetNextReadError(net::ERR_CONNECTION_RESET);
- error_socket->SetNextWriteError(net::ERR_CONNECTION_RESET);
+ raw_error_socket->SetNextReadError(ERR_CONNECTION_RESET);
+ raw_error_socket->SetNextWriteError(ERR_CONNECTION_RESET);
// ... but have those errors returned asynchronously. Because the Write() will
// return first, this will trigger the error.
- transport->SetNextReadShouldBlock();
- transport->SetNextWriteShouldBlock();
+ raw_transport->SetNextReadShouldBlock();
+ raw_transport->SetNextWriteShouldBlock();
// Enqueue a Read() before calling Write(), which should "hang" due to
// the ERR_IO_PENDING caused by SetReadShouldBlock() and thus return.
- DeleteSocketCallback read_callback(sock);
- scoped_refptr<net::IOBuffer> read_buf(new net::IOBuffer(4096));
- rv = sock->Read(read_buf.get(), 4096, read_callback.callback());
+ SSLClientSocket* raw_sock = sock.get();
+ DeleteSocketCallback read_callback(sock.release());
+ scoped_refptr<IOBuffer> read_buf(new IOBuffer(4096));
+ rv = raw_sock->Read(read_buf.get(), 4096, read_callback.callback());
// Ensure things didn't complete synchronously, otherwise |sock| is invalid.
- ASSERT_EQ(net::ERR_IO_PENDING, rv);
+ ASSERT_EQ(ERR_IO_PENDING, rv);
ASSERT_FALSE(read_callback.have_result());
#if !defined(USE_OPENSSL)
@@ -1142,9 +1131,9 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) {
// SSLClientSocketOpenSSL::Write() will not return until all of
// |request_buffer| has been written to the underlying BIO (although not
// necessarily the underlying transport).
- rv = callback.GetResult(sock->Write(request_buffer.get(),
- request_buffer->BytesRemaining(),
- callback.callback()));
+ rv = callback.GetResult(raw_sock->Write(request_buffer.get(),
+ request_buffer->BytesRemaining(),
+ callback.callback()));
ASSERT_LT(0, rv);
request_buffer->DidConsume(rv);
@@ -1157,22 +1146,22 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) {
// Attempt to write the remaining data. NSS will not be able to consume the
// application data because the internal buffers are full, while OpenSSL will
// return that its blocked because the underlying transport is blocked.
- rv = sock->Write(request_buffer.get(),
- request_buffer->BytesRemaining(),
- callback.callback());
- ASSERT_EQ(net::ERR_IO_PENDING, rv);
+ rv = raw_sock->Write(request_buffer.get(),
+ request_buffer->BytesRemaining(),
+ callback.callback());
+ ASSERT_EQ(ERR_IO_PENDING, rv);
ASSERT_FALSE(callback.have_result());
// Now unblock Write(), which will invoke OnSendComplete and (eventually)
// call the Read() callback, deleting the socket and thus aborting calling
// the Write() callback.
- transport->UnblockWrite();
+ raw_transport->UnblockWrite();
rv = read_callback.WaitForResult();
#if !defined(USE_OPENSSL)
// NSS records the error exactly.
- EXPECT_EQ(net::ERR_CONNECTION_RESET, rv);
+ EXPECT_EQ(ERR_CONNECTION_RESET, rv);
#else
// OpenSSL treats any errors as a simple EOF.
EXPECT_EQ(0, rv);
@@ -1183,50 +1172,49 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) {
}
TEST_F(SSLClientSocketTest, Read_SmallChunks) {
- net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS,
- net::SpawnedTestServer::kLocalhost,
- base::FilePath());
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
ASSERT_TRUE(test_server.Start());
- net::AddressList addr;
+ AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
- net::TestCompletionCallback callback;
- net::StreamSocket* transport = new net::TCPClientSocket(
- addr, NULL, net::NetLog::Source());
+ TestCompletionCallback callback;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
int rv = transport->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
- scoped_ptr<net::SSLClientSocket> sock(
- CreateSSLClientSocket(transport, test_server.host_port_pair(),
- kDefaultSSLConfig));
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
rv = sock->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
- scoped_refptr<net::IOBuffer> request_buffer(
- new net::IOBuffer(arraysize(request_text) - 1));
+ scoped_refptr<IOBuffer> request_buffer(
+ new IOBuffer(arraysize(request_text) - 1));
memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
rv = sock->Write(
request_buffer.get(), arraysize(request_text) - 1, callback.callback());
- EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING);
+ EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv);
- scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(1));
+ scoped_refptr<IOBuffer> buf(new IOBuffer(1));
for (;;) {
rv = sock->Read(buf.get(), 1, callback.callback());
- EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING);
+ EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_GE(rv, 0);
@@ -1236,34 +1224,36 @@ TEST_F(SSLClientSocketTest, Read_SmallChunks) {
}
TEST_F(SSLClientSocketTest, Read_ManySmallRecords) {
- net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS,
- net::SpawnedTestServer::kLocalhost,
- base::FilePath());
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
ASSERT_TRUE(test_server.Start());
- net::AddressList addr;
+ AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
- net::TestCompletionCallback callback;
+ TestCompletionCallback callback;
- scoped_ptr<net::StreamSocket> real_transport(new net::TCPClientSocket(
- addr, NULL, net::NetLog::Source()));
- ReadBufferingStreamSocket* transport = new ReadBufferingStreamSocket(
- real_transport.Pass());
+ scoped_ptr<StreamSocket> real_transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
+ scoped_ptr<ReadBufferingStreamSocket> transport(
+ new ReadBufferingStreamSocket(real_transport.Pass()));
+ ReadBufferingStreamSocket* raw_transport = transport.get();
int rv = callback.GetResult(transport->Connect(callback.callback()));
- ASSERT_EQ(net::OK, rv);
+ ASSERT_EQ(OK, rv);
- scoped_ptr<net::SSLClientSocket> sock(
- CreateSSLClientSocket(transport, test_server.host_port_pair(),
+ scoped_ptr<SSLClientSocket> sock(
+ CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
+ test_server.host_port_pair(),
kDefaultSSLConfig));
rv = callback.GetResult(sock->Connect(callback.callback()));
- ASSERT_EQ(net::OK, rv);
+ ASSERT_EQ(OK, rv);
ASSERT_TRUE(sock->IsConnected());
const char request_text[] = "GET /ssl-many-small-records HTTP/1.0\r\n\r\n";
- scoped_refptr<net::IOBuffer> request_buffer(
- new net::IOBuffer(arraysize(request_text) - 1));
+ scoped_refptr<IOBuffer> request_buffer(
+ new IOBuffer(arraysize(request_text) - 1));
memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
rv = callback.GetResult(sock->Write(
@@ -1280,117 +1270,114 @@ TEST_F(SSLClientSocketTest, Read_ManySmallRecords) {
// 15K was chosen because 15K is smaller than the 17K (max) read issued by
// the SSLClientSocket implementation, and larger than the minimum amount
// of ciphertext necessary to contain the 8K of plaintext requested below.
- transport->SetBufferSize(15000);
+ raw_transport->SetBufferSize(15000);
- scoped_refptr<net::IOBuffer> buffer(new net::IOBuffer(8192));
+ scoped_refptr<IOBuffer> buffer(new IOBuffer(8192));
rv = callback.GetResult(sock->Read(buffer.get(), 8192, callback.callback()));
ASSERT_EQ(rv, 8192);
}
TEST_F(SSLClientSocketTest, Read_Interrupted) {
- net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS,
- net::SpawnedTestServer::kLocalhost,
- base::FilePath());
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
ASSERT_TRUE(test_server.Start());
- net::AddressList addr;
+ AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
- net::TestCompletionCallback callback;
- net::StreamSocket* transport = new net::TCPClientSocket(
- addr, NULL, net::NetLog::Source());
+ TestCompletionCallback callback;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
int rv = transport->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
- scoped_ptr<net::SSLClientSocket> sock(
- CreateSSLClientSocket(transport, test_server.host_port_pair(),
- kDefaultSSLConfig));
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
rv = sock->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
- scoped_refptr<net::IOBuffer> request_buffer(
- new net::IOBuffer(arraysize(request_text) - 1));
+ scoped_refptr<IOBuffer> request_buffer(
+ new IOBuffer(arraysize(request_text) - 1));
memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
rv = sock->Write(
request_buffer.get(), arraysize(request_text) - 1, callback.callback());
- EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING);
+ EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv);
// Do a partial read and then exit. This test should not crash!
- scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(512));
+ scoped_refptr<IOBuffer> buf(new IOBuffer(512));
rv = sock->Read(buf.get(), 512, callback.callback());
- EXPECT_TRUE(rv > 0 || rv == net::ERR_IO_PENDING);
+ EXPECT_TRUE(rv > 0 || rv == ERR_IO_PENDING);
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_GT(rv, 0);
}
TEST_F(SSLClientSocketTest, Read_FullLogging) {
- net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS,
- net::SpawnedTestServer::kLocalhost,
- base::FilePath());
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
ASSERT_TRUE(test_server.Start());
- net::AddressList addr;
+ AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
- net::TestCompletionCallback callback;
- net::CapturingNetLog log;
- log.SetLogLevel(net::NetLog::LOG_ALL);
- net::StreamSocket* transport = new net::TCPClientSocket(
- addr, &log, net::NetLog::Source());
+ TestCompletionCallback callback;
+ CapturingNetLog log;
+ log.SetLogLevel(NetLog::LOG_ALL);
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
- scoped_ptr<net::SSLClientSocket> sock(
- CreateSSLClientSocket(transport, test_server.host_port_pair(),
- kDefaultSSLConfig));
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
rv = sock->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
EXPECT_TRUE(sock->IsConnected());
const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
- scoped_refptr<net::IOBuffer> request_buffer(
- new net::IOBuffer(arraysize(request_text) - 1));
+ scoped_refptr<IOBuffer> request_buffer(
+ new IOBuffer(arraysize(request_text) - 1));
memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
rv = sock->Write(
request_buffer.get(), arraysize(request_text) - 1, callback.callback());
- EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING);
+ EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv);
- net::CapturingNetLog::CapturedEntryList entries;
+ CapturingNetLog::CapturedEntryList entries;
log.GetEntries(&entries);
- size_t last_index = net::ExpectLogContainsSomewhereAfter(
- entries, 5, net::NetLog::TYPE_SSL_SOCKET_BYTES_SENT,
- net::NetLog::PHASE_NONE);
+ size_t last_index = ExpectLogContainsSomewhereAfter(
+ entries, 5, NetLog::TYPE_SSL_SOCKET_BYTES_SENT, NetLog::PHASE_NONE);
- scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(4096));
+ scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
for (;;) {
rv = sock->Read(buf.get(), 4096, callback.callback());
- EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING);
+ EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_GE(rv, 0);
@@ -1398,61 +1385,59 @@ TEST_F(SSLClientSocketTest, Read_FullLogging) {
break;
log.GetEntries(&entries);
- last_index = net::ExpectLogContainsSomewhereAfter(
- entries, last_index + 1, net::NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED,
- net::NetLog::PHASE_NONE);
+ last_index =
+ ExpectLogContainsSomewhereAfter(entries,
+ last_index + 1,
+ NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED,
+ NetLog::PHASE_NONE);
}
}
// Regression test for http://crbug.com/42538
TEST_F(SSLClientSocketTest, PrematureApplicationData) {
- net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS,
- net::SpawnedTestServer::kLocalhost,
- base::FilePath());
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
ASSERT_TRUE(test_server.Start());
- net::AddressList addr;
- net::TestCompletionCallback callback;
+ AddressList addr;
+ TestCompletionCallback callback;
static const unsigned char application_data[] = {
- 0x17, 0x03, 0x01, 0x00, 0x4a, 0x02, 0x00, 0x00, 0x46, 0x03, 0x01, 0x4b,
- 0xc2, 0xf8, 0xb2, 0xc1, 0x56, 0x42, 0xb9, 0x57, 0x7f, 0xde, 0x87, 0x46,
- 0xf7, 0xa3, 0x52, 0x42, 0x21, 0xf0, 0x13, 0x1c, 0x9c, 0x83, 0x88, 0xd6,
- 0x93, 0x0c, 0xf6, 0x36, 0x30, 0x05, 0x7e, 0x20, 0xb5, 0xb5, 0x73, 0x36,
- 0x53, 0x83, 0x0a, 0xfc, 0x17, 0x63, 0xbf, 0xa0, 0xe4, 0x42, 0x90, 0x0d,
- 0x2f, 0x18, 0x6d, 0x20, 0xd8, 0x36, 0x3f, 0xfc, 0xe6, 0x01, 0xfa, 0x0f,
- 0xa5, 0x75, 0x7f, 0x09, 0x00, 0x04, 0x00, 0x16, 0x03, 0x01, 0x11, 0x57,
- 0x0b, 0x00, 0x11, 0x53, 0x00, 0x11, 0x50, 0x00, 0x06, 0x22, 0x30, 0x82,
- 0x06, 0x1e, 0x30, 0x82, 0x05, 0x06, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02,
- 0x0a
- };
+ 0x17, 0x03, 0x01, 0x00, 0x4a, 0x02, 0x00, 0x00, 0x46, 0x03, 0x01, 0x4b,
+ 0xc2, 0xf8, 0xb2, 0xc1, 0x56, 0x42, 0xb9, 0x57, 0x7f, 0xde, 0x87, 0x46,
+ 0xf7, 0xa3, 0x52, 0x42, 0x21, 0xf0, 0x13, 0x1c, 0x9c, 0x83, 0x88, 0xd6,
+ 0x93, 0x0c, 0xf6, 0x36, 0x30, 0x05, 0x7e, 0x20, 0xb5, 0xb5, 0x73, 0x36,
+ 0x53, 0x83, 0x0a, 0xfc, 0x17, 0x63, 0xbf, 0xa0, 0xe4, 0x42, 0x90, 0x0d,
+ 0x2f, 0x18, 0x6d, 0x20, 0xd8, 0x36, 0x3f, 0xfc, 0xe6, 0x01, 0xfa, 0x0f,
+ 0xa5, 0x75, 0x7f, 0x09, 0x00, 0x04, 0x00, 0x16, 0x03, 0x01, 0x11, 0x57,
+ 0x0b, 0x00, 0x11, 0x53, 0x00, 0x11, 0x50, 0x00, 0x06, 0x22, 0x30, 0x82,
+ 0x06, 0x1e, 0x30, 0x82, 0x05, 0x06, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02,
+ 0x0a};
// All reads and writes complete synchronously (async=false).
- net::MockRead data_reads[] = {
- net::MockRead(net::SYNCHRONOUS,
- reinterpret_cast<const char*>(application_data),
- arraysize(application_data)),
- net::MockRead(net::SYNCHRONOUS, net::OK),
- };
+ MockRead data_reads[] = {
+ MockRead(SYNCHRONOUS,
+ reinterpret_cast<const char*>(application_data),
+ arraysize(application_data)),
+ MockRead(SYNCHRONOUS, OK), };
- net::StaticSocketDataProvider data(data_reads, arraysize(data_reads),
- NULL, 0);
+ StaticSocketDataProvider data(data_reads, arraysize(data_reads), NULL, 0);
- net::StreamSocket* transport =
- new net::MockTCPClientSocket(addr, NULL, &data);
+ scoped_ptr<StreamSocket> transport(
+ new MockTCPClientSocket(addr, NULL, &data));
int rv = transport->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
- scoped_ptr<net::SSLClientSocket> sock(
- CreateSSLClientSocket(transport, test_server.host_port_pair(),
- kDefaultSSLConfig));
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
rv = sock->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::ERR_SSL_PROTOCOL_ERROR, rv);
+ EXPECT_EQ(ERR_SSL_PROTOCOL_ERROR, rv);
}
TEST_F(SSLClientSocketTest, CipherSuiteDisables) {
@@ -1460,46 +1445,41 @@ TEST_F(SSLClientSocketTest, CipherSuiteDisables) {
// http://www.iana.org/assignments/tls-parameters/tls-parameters.xml,
// only disabling those cipher suites that the test server actually
// implements.
- const uint16 kCiphersToDisable[] = {
- 0x0005, // TLS_RSA_WITH_RC4_128_SHA
+ const uint16 kCiphersToDisable[] = {0x0005, // TLS_RSA_WITH_RC4_128_SHA
};
- net::SpawnedTestServer::SSLOptions ssl_options;
+ SpawnedTestServer::SSLOptions ssl_options;
// Enable only RC4 on the test server.
- ssl_options.bulk_ciphers =
- net::SpawnedTestServer::SSLOptions::BULK_CIPHER_RC4;
- net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS,
- ssl_options,
- base::FilePath());
+ ssl_options.bulk_ciphers = SpawnedTestServer::SSLOptions::BULK_CIPHER_RC4;
+ SpawnedTestServer test_server(
+ SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
ASSERT_TRUE(test_server.Start());
- net::AddressList addr;
+ AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
- net::TestCompletionCallback callback;
- net::CapturingNetLog log;
- net::StreamSocket* transport = new net::TCPClientSocket(
- addr, &log, net::NetLog::Source());
+ TestCompletionCallback callback;
+ CapturingNetLog log;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
- net::SSLConfig ssl_config;
+ SSLConfig ssl_config;
for (size_t i = 0; i < arraysize(kCiphersToDisable); ++i)
ssl_config.disabled_cipher_suites.push_back(kCiphersToDisable[i]);
- scoped_ptr<net::SSLClientSocket> sock(
- CreateSSLClientSocket(transport, test_server.host_port_pair(),
- ssl_config));
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), ssl_config));
EXPECT_FALSE(sock->IsConnected());
rv = sock->Connect(callback.callback());
- net::CapturingNetLog::CapturedEntryList entries;
+ CapturingNetLog::CapturedEntryList entries;
log.GetEntries(&entries);
- EXPECT_TRUE(net::LogContainsBeginEvent(
- entries, 5, net::NetLog::TYPE_SSL_CONNECT));
+ EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
// NSS has special handling that maps a handshake_failure alert received
// immediately after a client_hello to be a mismatched cipher suite error,
@@ -1507,17 +1487,16 @@ TEST_F(SSLClientSocketTest, CipherSuiteDisables) {
// Secure Transport (OS X), the handshake_failure is bubbled up without any
// interpretation, leading to ERR_SSL_PROTOCOL_ERROR. Either way, a failure
// indicates that no cipher suite was negotiated with the test server.
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_TRUE(rv == net::ERR_SSL_VERSION_OR_CIPHER_MISMATCH ||
- rv == net::ERR_SSL_PROTOCOL_ERROR);
+ EXPECT_TRUE(rv == ERR_SSL_VERSION_OR_CIPHER_MISMATCH ||
+ rv == ERR_SSL_PROTOCOL_ERROR);
// The exact ordering differs between SSLClientSocketNSS (which issues an
// extra read) and SSLClientSocketMac (which does not). Just make sure the
// error appears somewhere in the log.
log.GetEntries(&entries);
- net::ExpectLogContainsSomewhere(entries, 0,
- net::NetLog::TYPE_SSL_HANDSHAKE_ERROR,
- net::NetLog::PHASE_NONE);
+ ExpectLogContainsSomewhere(
+ entries, 0, NetLog::TYPE_SSL_HANDSHAKE_ERROR, NetLog::PHASE_NONE);
// We cannot test sock->IsConnected(), as the NSS implementation disconnects
// the socket when it encounters an error, whereas other implementations
@@ -1539,65 +1518,65 @@ TEST_F(SSLClientSocketTest, CipherSuiteDisables) {
// Here we verify that such a simple ClientSocketHandle, not associated with any
// client socket pool, can be destroyed safely.
TEST_F(SSLClientSocketTest, ClientSocketHandleNotFromPool) {
- net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS,
- net::SpawnedTestServer::kLocalhost,
- base::FilePath());
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
ASSERT_TRUE(test_server.Start());
- net::AddressList addr;
+ AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
- net::TestCompletionCallback callback;
- net::StreamSocket* transport = new net::TCPClientSocket(
- addr, NULL, net::NetLog::Source());
+ TestCompletionCallback callback;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
int rv = transport->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
- net::ClientSocketHandle* socket_handle = new net::ClientSocketHandle();
- socket_handle->set_socket(transport);
+ scoped_ptr<ClientSocketHandle> socket_handle(new ClientSocketHandle());
+ socket_handle->SetSocket(transport.Pass());
- scoped_ptr<net::SSLClientSocket> sock(
- socket_factory_->CreateSSLClientSocket(
- socket_handle, test_server.host_port_pair(), kDefaultSSLConfig,
- context_));
+ scoped_ptr<SSLClientSocket> sock(
+ socket_factory_->CreateSSLClientSocket(socket_handle.Pass(),
+ test_server.host_port_pair(),
+ kDefaultSSLConfig,
+ context_));
EXPECT_FALSE(sock->IsConnected());
rv = sock->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
}
// Verifies that SSLClientSocket::ExportKeyingMaterial return a success
// code and different keying label results in different keying material.
TEST_F(SSLClientSocketTest, ExportKeyingMaterial) {
- net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS,
- net::SpawnedTestServer::kLocalhost,
- base::FilePath());
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
ASSERT_TRUE(test_server.Start());
- net::AddressList addr;
+ AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
- net::TestCompletionCallback callback;
+ TestCompletionCallback callback;
- net::StreamSocket* transport = new net::TCPClientSocket(
- addr, NULL, net::NetLog::Source());
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
int rv = transport->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
- scoped_ptr<net::SSLClientSocket> sock(
- CreateSSLClientSocket(transport, test_server.host_port_pair(),
- kDefaultSSLConfig));
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
rv = sock->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
EXPECT_TRUE(sock->IsConnected());
const int kKeyingMaterialSize = 32;
@@ -1605,23 +1584,23 @@ TEST_F(SSLClientSocketTest, ExportKeyingMaterial) {
const char* kKeyingContext = "";
unsigned char client_out1[kKeyingMaterialSize];
memset(client_out1, 0, sizeof(client_out1));
- rv = sock->ExportKeyingMaterial(kKeyingLabel1, false, kKeyingContext,
- client_out1, sizeof(client_out1));
- EXPECT_EQ(rv, net::OK);
+ rv = sock->ExportKeyingMaterial(
+ kKeyingLabel1, false, kKeyingContext, client_out1, sizeof(client_out1));
+ EXPECT_EQ(rv, OK);
const char* kKeyingLabel2 = "client-socket-test-2";
unsigned char client_out2[kKeyingMaterialSize];
memset(client_out2, 0, sizeof(client_out2));
- rv = sock->ExportKeyingMaterial(kKeyingLabel2, false, kKeyingContext,
- client_out2, sizeof(client_out2));
- EXPECT_EQ(rv, net::OK);
+ rv = sock->ExportKeyingMaterial(
+ kKeyingLabel2, false, kKeyingContext, client_out2, sizeof(client_out2));
+ EXPECT_EQ(rv, OK);
EXPECT_NE(memcmp(client_out1, client_out2, kKeyingMaterialSize), 0);
}
// Verifies that SSLClientSocket::ClearSessionCache can be called without
// explicit NSS initialization.
TEST(SSLClientSocket, ClearSessionCache) {
- net::SSLClientSocket::ClearSessionCache();
+ SSLClientSocket::ClearSessionCache();
}
// This tests that SSLInfo contains a properly re-constructed certificate
@@ -1639,86 +1618,84 @@ TEST(SSLClientSocket, ClearSessionCache) {
TEST_F(SSLClientSocketTest, VerifyReturnChainProperlyOrdered) {
// By default, cause the CertVerifier to treat all certificates as
// expired.
- cert_verifier_->set_default_result(net::ERR_CERT_DATE_INVALID);
+ cert_verifier_->set_default_result(ERR_CERT_DATE_INVALID);
// We will expect SSLInfo to ultimately contain this chain.
- net::CertificateList certs = CreateCertificateListFromFile(
- net::GetTestCertsDirectory(), "redundant-validated-chain.pem",
- net::X509Certificate::FORMAT_AUTO);
+ CertificateList certs =
+ CreateCertificateListFromFile(GetTestCertsDirectory(),
+ "redundant-validated-chain.pem",
+ X509Certificate::FORMAT_AUTO);
ASSERT_EQ(3U, certs.size());
- net::X509Certificate::OSCertHandles temp_intermediates;
+ X509Certificate::OSCertHandles temp_intermediates;
temp_intermediates.push_back(certs[1]->os_cert_handle());
temp_intermediates.push_back(certs[2]->os_cert_handle());
- net::CertVerifyResult verify_result;
- verify_result.verified_cert =
- net::X509Certificate::CreateFromHandle(certs[0]->os_cert_handle(),
- temp_intermediates);
+ CertVerifyResult verify_result;
+ verify_result.verified_cert = X509Certificate::CreateFromHandle(
+ certs[0]->os_cert_handle(), temp_intermediates);
// Add a rule that maps the server cert (A) to the chain of A->B->C2
// rather than A->B->C.
- cert_verifier_->AddResultForCert(certs[0].get(), verify_result, net::OK);
+ cert_verifier_->AddResultForCert(certs[0].get(), verify_result, OK);
// Load and install the root for the validated chain.
- scoped_refptr<net::X509Certificate> root_cert =
- net::ImportCertFromFile(net::GetTestCertsDirectory(),
- "redundant-validated-chain-root.pem");
- ASSERT_NE(static_cast<net::X509Certificate*>(NULL), root_cert);
- net::ScopedTestRoot scoped_root(root_cert.get());
+ scoped_refptr<X509Certificate> root_cert = ImportCertFromFile(
+ GetTestCertsDirectory(), "redundant-validated-chain-root.pem");
+ ASSERT_NE(static_cast<X509Certificate*>(NULL), root_cert);
+ ScopedTestRoot scoped_root(root_cert.get());
// Set up a test server with CERT_CHAIN_WRONG_ROOT.
- net::SpawnedTestServer::SSLOptions ssl_options(
- net::SpawnedTestServer::SSLOptions::CERT_CHAIN_WRONG_ROOT);
- net::SpawnedTestServer test_server(
- net::SpawnedTestServer::TYPE_HTTPS, ssl_options,
+ SpawnedTestServer::SSLOptions ssl_options(
+ SpawnedTestServer::SSLOptions::CERT_CHAIN_WRONG_ROOT);
+ SpawnedTestServer test_server(
+ SpawnedTestServer::TYPE_HTTPS,
+ ssl_options,
base::FilePath(FILE_PATH_LITERAL("net/data/ssl")));
ASSERT_TRUE(test_server.Start());
- net::AddressList addr;
+ AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
- net::TestCompletionCallback callback;
- net::CapturingNetLog log;
- net::StreamSocket* transport = new net::TCPClientSocket(
- addr, &log, net::NetLog::Source());
+ TestCompletionCallback callback;
+ CapturingNetLog log;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
- scoped_ptr<net::SSLClientSocket> sock(
- CreateSSLClientSocket(transport, test_server.host_port_pair(),
- kDefaultSSLConfig));
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
EXPECT_FALSE(sock->IsConnected());
rv = sock->Connect(callback.callback());
- net::CapturingNetLog::CapturedEntryList entries;
+ CapturingNetLog::CapturedEntryList entries;
log.GetEntries(&entries);
- EXPECT_TRUE(net::LogContainsBeginEvent(
- entries, 5, net::NetLog::TYPE_SSL_CONNECT));
- if (rv == net::ERR_IO_PENDING)
+ EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
EXPECT_TRUE(sock->IsConnected());
log.GetEntries(&entries);
EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
- net::SSLInfo ssl_info;
+ SSLInfo ssl_info;
sock->GetSSLInfo(&ssl_info);
// Verify that SSLInfo contains the corrected re-constructed chain A -> B
// -> C2.
- const net::X509Certificate::OSCertHandles& intermediates =
+ const X509Certificate::OSCertHandles& intermediates =
ssl_info.cert->GetIntermediateCertificates();
ASSERT_EQ(2U, intermediates.size());
- EXPECT_TRUE(net::X509Certificate::IsSameOSCert(
- ssl_info.cert->os_cert_handle(), certs[0]->os_cert_handle()));
- EXPECT_TRUE(net::X509Certificate::IsSameOSCert(
- intermediates[0], certs[1]->os_cert_handle()));
- EXPECT_TRUE(net::X509Certificate::IsSameOSCert(
- intermediates[1], certs[2]->os_cert_handle()));
+ EXPECT_TRUE(X509Certificate::IsSameOSCert(ssl_info.cert->os_cert_handle(),
+ certs[0]->os_cert_handle()));
+ EXPECT_TRUE(X509Certificate::IsSameOSCert(intermediates[0],
+ certs[1]->os_cert_handle()));
+ EXPECT_TRUE(X509Certificate::IsSameOSCert(intermediates[1],
+ certs[2]->os_cert_handle()));
sock->Disconnect();
EXPECT_FALSE(sock->IsConnected());
@@ -1729,37 +1706,34 @@ class SSLClientSocketCertRequestInfoTest : public SSLClientSocketTest {
protected:
// Creates a test server with the given SSLOptions, connects to it and returns
// the SSLCertRequestInfo reported by the socket.
- scoped_refptr<net::SSLCertRequestInfo> GetCertRequest(
- net::SpawnedTestServer::SSLOptions ssl_options) {
- net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS,
- ssl_options,
- base::FilePath());
+ scoped_refptr<SSLCertRequestInfo> GetCertRequest(
+ SpawnedTestServer::SSLOptions ssl_options) {
+ SpawnedTestServer test_server(
+ SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
if (!test_server.Start())
return NULL;
- net::AddressList addr;
+ AddressList addr;
if (!test_server.GetAddressList(&addr))
return NULL;
- net::TestCompletionCallback callback;
- net::CapturingNetLog log;
- net::StreamSocket* transport = new net::TCPClientSocket(
- addr, &log, net::NetLog::Source());
+ TestCompletionCallback callback;
+ CapturingNetLog log;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- EXPECT_EQ(net::OK, rv);
+ EXPECT_EQ(OK, rv);
- scoped_ptr<net::SSLClientSocket> sock(
- CreateSSLClientSocket(transport, test_server.host_port_pair(),
- kDefaultSSLConfig));
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
EXPECT_FALSE(sock->IsConnected());
rv = sock->Connect(callback.callback());
- if (rv == net::ERR_IO_PENDING)
+ if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
- scoped_refptr<net::SSLCertRequestInfo> request_info =
- new net::SSLCertRequestInfo();
+ scoped_refptr<SSLCertRequestInfo> request_info = new SSLCertRequestInfo();
sock->GetSSLCertRequestInfo(request_info.get());
sock->Disconnect();
EXPECT_FALSE(sock->IsConnected());
@@ -1769,10 +1743,9 @@ class SSLClientSocketCertRequestInfoTest : public SSLClientSocketTest {
};
TEST_F(SSLClientSocketCertRequestInfoTest, NoAuthorities) {
- net::SpawnedTestServer::SSLOptions ssl_options;
+ SpawnedTestServer::SSLOptions ssl_options;
ssl_options.request_client_certificate = true;
- scoped_refptr<net::SSLCertRequestInfo> request_info =
- GetCertRequest(ssl_options);
+ scoped_refptr<SSLCertRequestInfo> request_info = GetCertRequest(ssl_options);
ASSERT_TRUE(request_info.get());
EXPECT_EQ(0u, request_info->cert_authorities.size());
}
@@ -1781,39 +1754,36 @@ TEST_F(SSLClientSocketCertRequestInfoTest, TwoAuthorities) {
const base::FilePath::CharType kThawteFile[] =
FILE_PATH_LITERAL("thawte.single.pem");
const unsigned char kThawteDN[] = {
- 0x30, 0x4c, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
- 0x02, 0x5a, 0x41, 0x31, 0x25, 0x30, 0x23, 0x06, 0x03, 0x55, 0x04, 0x0a,
- 0x13, 0x1c, 0x54, 0x68, 0x61, 0x77, 0x74, 0x65, 0x20, 0x43, 0x6f, 0x6e,
- 0x73, 0x75, 0x6c, 0x74, 0x69, 0x6e, 0x67, 0x20, 0x28, 0x50, 0x74, 0x79,
- 0x29, 0x20, 0x4c, 0x74, 0x64, 0x2e, 0x31, 0x16, 0x30, 0x14, 0x06, 0x03,
- 0x55, 0x04, 0x03, 0x13, 0x0d, 0x54, 0x68, 0x61, 0x77, 0x74, 0x65, 0x20,
- 0x53, 0x47, 0x43, 0x20, 0x43, 0x41
- };
+ 0x30, 0x4c, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
+ 0x02, 0x5a, 0x41, 0x31, 0x25, 0x30, 0x23, 0x06, 0x03, 0x55, 0x04, 0x0a,
+ 0x13, 0x1c, 0x54, 0x68, 0x61, 0x77, 0x74, 0x65, 0x20, 0x43, 0x6f, 0x6e,
+ 0x73, 0x75, 0x6c, 0x74, 0x69, 0x6e, 0x67, 0x20, 0x28, 0x50, 0x74, 0x79,
+ 0x29, 0x20, 0x4c, 0x74, 0x64, 0x2e, 0x31, 0x16, 0x30, 0x14, 0x06, 0x03,
+ 0x55, 0x04, 0x03, 0x13, 0x0d, 0x54, 0x68, 0x61, 0x77, 0x74, 0x65, 0x20,
+ 0x53, 0x47, 0x43, 0x20, 0x43, 0x41};
const size_t kThawteLen = sizeof(kThawteDN);
const base::FilePath::CharType kDiginotarFile[] =
FILE_PATH_LITERAL("diginotar_root_ca.pem");
const unsigned char kDiginotarDN[] = {
- 0x30, 0x5f, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
- 0x02, 0x4e, 0x4c, 0x31, 0x12, 0x30, 0x10, 0x06, 0x03, 0x55, 0x04, 0x0a,
- 0x13, 0x09, 0x44, 0x69, 0x67, 0x69, 0x4e, 0x6f, 0x74, 0x61, 0x72, 0x31,
- 0x1a, 0x30, 0x18, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x11, 0x44, 0x69,
- 0x67, 0x69, 0x4e, 0x6f, 0x74, 0x61, 0x72, 0x20, 0x52, 0x6f, 0x6f, 0x74,
- 0x20, 0x43, 0x41, 0x31, 0x20, 0x30, 0x1e, 0x06, 0x09, 0x2a, 0x86, 0x48,
- 0x86, 0xf7, 0x0d, 0x01, 0x09, 0x01, 0x16, 0x11, 0x69, 0x6e, 0x66, 0x6f,
- 0x40, 0x64, 0x69, 0x67, 0x69, 0x6e, 0x6f, 0x74, 0x61, 0x72, 0x2e, 0x6e,
- 0x6c
- };
+ 0x30, 0x5f, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
+ 0x02, 0x4e, 0x4c, 0x31, 0x12, 0x30, 0x10, 0x06, 0x03, 0x55, 0x04, 0x0a,
+ 0x13, 0x09, 0x44, 0x69, 0x67, 0x69, 0x4e, 0x6f, 0x74, 0x61, 0x72, 0x31,
+ 0x1a, 0x30, 0x18, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x11, 0x44, 0x69,
+ 0x67, 0x69, 0x4e, 0x6f, 0x74, 0x61, 0x72, 0x20, 0x52, 0x6f, 0x6f, 0x74,
+ 0x20, 0x43, 0x41, 0x31, 0x20, 0x30, 0x1e, 0x06, 0x09, 0x2a, 0x86, 0x48,
+ 0x86, 0xf7, 0x0d, 0x01, 0x09, 0x01, 0x16, 0x11, 0x69, 0x6e, 0x66, 0x6f,
+ 0x40, 0x64, 0x69, 0x67, 0x69, 0x6e, 0x6f, 0x74, 0x61, 0x72, 0x2e, 0x6e,
+ 0x6c};
const size_t kDiginotarLen = sizeof(kDiginotarDN);
- net::SpawnedTestServer::SSLOptions ssl_options;
+ SpawnedTestServer::SSLOptions ssl_options;
ssl_options.request_client_certificate = true;
ssl_options.client_authorities.push_back(
- net::GetTestClientCertsDirectory().Append(kThawteFile));
+ GetTestClientCertsDirectory().Append(kThawteFile));
ssl_options.client_authorities.push_back(
- net::GetTestClientCertsDirectory().Append(kDiginotarFile));
- scoped_refptr<net::SSLCertRequestInfo> request_info =
- GetCertRequest(ssl_options);
+ GetTestClientCertsDirectory().Append(kDiginotarFile));
+ scoped_refptr<SSLCertRequestInfo> request_info = GetCertRequest(ssl_options);
ASSERT_TRUE(request_info.get());
ASSERT_EQ(2u, request_info->cert_authorities.size());
EXPECT_EQ(std::string(reinterpret_cast<const char*>(kThawteDN), kThawteLen),
@@ -1822,3 +1792,7 @@ TEST_F(SSLClientSocketCertRequestInfoTest, TwoAuthorities) {
std::string(reinterpret_cast<const char*>(kDiginotarDN), kDiginotarLen),
request_info->cert_authorities[1]);
}
+
+} // namespace
+
+} // namespace net
diff --git a/chromium/net/socket/ssl_server_socket.h b/chromium/net/socket/ssl_server_socket.h
index 52d53cb19a2..8b607bf80cf 100644
--- a/chromium/net/socket/ssl_server_socket.h
+++ b/chromium/net/socket/ssl_server_socket.h
@@ -6,6 +6,7 @@
#define NET_SOCKET_SSL_SERVER_SOCKET_H_
#include "base/basictypes.h"
+#include "base/memory/scoped_ptr.h"
#include "net/base/completion_callback.h"
#include "net/base/net_export.h"
#include "net/socket/ssl_socket.h"
@@ -52,8 +53,8 @@ NET_EXPORT void EnableSSLServerSockets();
//
// The caller starts the SSL server handshake by calling Handshake on the
// returned socket.
-NET_EXPORT SSLServerSocket* CreateSSLServerSocket(
- StreamSocket* socket,
+NET_EXPORT scoped_ptr<SSLServerSocket> CreateSSLServerSocket(
+ scoped_ptr<StreamSocket> socket,
X509Certificate* certificate,
crypto::RSAPrivateKey* key,
const SSLConfig& ssl_config);
diff --git a/chromium/net/socket/ssl_server_socket_nss.cc b/chromium/net/socket/ssl_server_socket_nss.cc
index c2681d3ee14..7e5d70118ac 100644
--- a/chromium/net/socket/ssl_server_socket_nss.cc
+++ b/chromium/net/socket/ssl_server_socket_nss.cc
@@ -78,19 +78,20 @@ void EnableSSLServerSockets() {
g_nss_ssl_server_init_singleton.Get();
}
-SSLServerSocket* CreateSSLServerSocket(
- StreamSocket* socket,
+scoped_ptr<SSLServerSocket> CreateSSLServerSocket(
+ scoped_ptr<StreamSocket> socket,
X509Certificate* cert,
crypto::RSAPrivateKey* key,
const SSLConfig& ssl_config) {
DCHECK(g_nss_server_sockets_init) << "EnableSSLServerSockets() has not been"
<< "called yet!";
- return new SSLServerSocketNSS(socket, cert, key, ssl_config);
+ return scoped_ptr<SSLServerSocket>(
+ new SSLServerSocketNSS(socket.Pass(), cert, key, ssl_config));
}
SSLServerSocketNSS::SSLServerSocketNSS(
- StreamSocket* transport_socket,
+ scoped_ptr<StreamSocket> transport_socket,
scoped_refptr<X509Certificate> cert,
crypto::RSAPrivateKey* key,
const SSLConfig& ssl_config)
@@ -100,7 +101,7 @@ SSLServerSocketNSS::SSLServerSocketNSS(
user_write_buf_len_(0),
nss_fd_(NULL),
nss_bufs_(NULL),
- transport_socket_(transport_socket),
+ transport_socket_(transport_socket.Pass()),
ssl_config_(ssl_config),
cert_(cert),
next_handshake_state_(STATE_NONE),
diff --git a/chromium/net/socket/ssl_server_socket_nss.h b/chromium/net/socket/ssl_server_socket_nss.h
index 17a1fc38750..8bbb0e338ac 100644
--- a/chromium/net/socket/ssl_server_socket_nss.h
+++ b/chromium/net/socket/ssl_server_socket_nss.h
@@ -24,7 +24,7 @@ class SSLServerSocketNSS : public SSLServerSocket {
public:
// See comments on CreateSSLServerSocket for details of how these
// parameters are used.
- SSLServerSocketNSS(StreamSocket* socket,
+ SSLServerSocketNSS(scoped_ptr<StreamSocket> socket,
scoped_refptr<X509Certificate> certificate,
crypto::RSAPrivateKey* key,
const SSLConfig& ssl_config);
diff --git a/chromium/net/socket/ssl_server_socket_openssl.cc b/chromium/net/socket/ssl_server_socket_openssl.cc
index e0cf8bc0b21..c327f2caf10 100644
--- a/chromium/net/socket/ssl_server_socket_openssl.cc
+++ b/chromium/net/socket/ssl_server_socket_openssl.cc
@@ -16,13 +16,13 @@ void EnableSSLServerSockets() {
NOTIMPLEMENTED();
}
-SSLServerSocket* CreateSSLServerSocket(StreamSocket* socket,
- X509Certificate* certificate,
- crypto::RSAPrivateKey* key,
- const SSLConfig& ssl_config) {
+scoped_ptr<SSLServerSocket> CreateSSLServerSocket(
+ scoped_ptr<StreamSocket> socket,
+ X509Certificate* certificate,
+ crypto::RSAPrivateKey* key,
+ const SSLConfig& ssl_config) {
NOTIMPLEMENTED();
- delete socket;
- return NULL;
+ return scoped_ptr<SSLServerSocket>();
}
} // namespace net
diff --git a/chromium/net/socket/ssl_server_socket_unittest.cc b/chromium/net/socket/ssl_server_socket_unittest.cc
index f931e2c957e..e1f7f496131 100644
--- a/chromium/net/socket/ssl_server_socket_unittest.cc
+++ b/chromium/net/socket/ssl_server_socket_unittest.cc
@@ -304,21 +304,24 @@ class SSLServerSocketTest : public PlatformTest {
protected:
void Initialize() {
- FakeSocket* fake_client_socket = new FakeSocket(&channel_1_, &channel_2_);
- FakeSocket* fake_server_socket = new FakeSocket(&channel_2_, &channel_1_);
+ scoped_ptr<ClientSocketHandle> client_connection(new ClientSocketHandle);
+ client_connection->SetSocket(
+ scoped_ptr<StreamSocket>(new FakeSocket(&channel_1_, &channel_2_)));
+ scoped_ptr<StreamSocket> server_socket(
+ new FakeSocket(&channel_2_, &channel_1_));
base::FilePath certs_dir(GetTestCertsDirectory());
base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der");
std::string cert_der;
- ASSERT_TRUE(file_util::ReadFileToString(cert_path, &cert_der));
+ ASSERT_TRUE(base::ReadFileToString(cert_path, &cert_der));
scoped_refptr<net::X509Certificate> cert =
X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size());
base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin");
std::string key_string;
- ASSERT_TRUE(file_util::ReadFileToString(key_path, &key_string));
+ ASSERT_TRUE(base::ReadFileToString(key_path, &key_string));
std::vector<uint8> key_vector(
reinterpret_cast<const uint8*>(key_string.data()),
reinterpret_cast<const uint8*>(key_string.data() +
@@ -344,11 +347,12 @@ class SSLServerSocketTest : public PlatformTest {
net::SSLClientSocketContext context;
context.cert_verifier = cert_verifier_.get();
context.transport_security_state = transport_security_state_.get();
- client_socket_.reset(
+ client_socket_ =
socket_factory_->CreateSSLClientSocket(
- fake_client_socket, host_and_pair, ssl_config, context));
- server_socket_.reset(net::CreateSSLServerSocket(
- fake_server_socket, cert.get(), private_key.get(), net::SSLConfig()));
+ client_connection.Pass(), host_and_pair, ssl_config, context);
+ server_socket_ = net::CreateSSLServerSocket(
+ server_socket.Pass(),
+ cert.get(), private_key.get(), net::SSLConfig());
}
FakeDataChannel channel_1_;
diff --git a/chromium/net/socket/stream_listen_socket.cc b/chromium/net/socket/stream_listen_socket.cc
index c85c671800d..1109e7527c3 100644
--- a/chromium/net/socket/stream_listen_socket.cc
+++ b/chromium/net/socket/stream_listen_socket.cc
@@ -27,6 +27,7 @@
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
#include "net/base/net_util.h"
+#include "net/socket/socket_descriptor.h"
using std::string;
@@ -43,10 +44,8 @@ const int kReadBufSize = 4096;
} // namespace
#if defined(OS_WIN)
-const SocketDescriptor StreamListenSocket::kInvalidSocket = INVALID_SOCKET;
const int StreamListenSocket::kSocketError = SOCKET_ERROR;
#elif defined(OS_POSIX)
-const SocketDescriptor StreamListenSocket::kInvalidSocket = -1;
const int StreamListenSocket::kSocketError = -1;
#endif
diff --git a/chromium/net/socket/stream_listen_socket.h b/chromium/net/socket/stream_listen_socket.h
index 6f03eefaca2..9825a4ef126 100644
--- a/chromium/net/socket/stream_listen_socket.h
+++ b/chromium/net/socket/stream_listen_socket.h
@@ -30,21 +30,16 @@
#include "base/basictypes.h"
#include "base/compiler_specific.h"
+#include "base/memory/scoped_ptr.h"
#include "net/base/net_export.h"
-#include "net/socket/stream_listen_socket.h"
-
-#if defined(OS_POSIX)
-typedef int SocketDescriptor;
-#else
-typedef SOCKET SocketDescriptor;
-#endif
+#include "net/socket/socket_descriptor.h"
namespace net {
class IPEndPoint;
class NET_EXPORT StreamListenSocket
- : public base::RefCountedThreadSafe<StreamListenSocket>,
+ :
#if defined(OS_WIN)
public base::win::ObjectWatcher::Delegate {
#elif defined(OS_POSIX)
@@ -52,16 +47,17 @@ class NET_EXPORT StreamListenSocket
#endif
public:
+ virtual ~StreamListenSocket();
+
// TODO(erikkay): this delegate should really be split into two parts
// to split up the listener from the connected socket. Perhaps this class
// should be split up similarly.
class Delegate {
public:
// |server| is the original listening Socket, connection is the new
- // Socket that was created. Ownership of |connection| is transferred
- // to the delegate with this call.
+ // Socket that was created.
virtual void DidAccept(StreamListenSocket* server,
- StreamListenSocket* connection) = 0;
+ scoped_ptr<StreamListenSocket> connection) = 0;
virtual void DidRead(StreamListenSocket* connection,
const char* data,
int len) = 0;
@@ -78,7 +74,6 @@ class NET_EXPORT StreamListenSocket
// Copies the local address to |address|. Returns a network error code.
int GetLocalAddress(IPEndPoint* address);
- static const SocketDescriptor kInvalidSocket;
static const int kSocketError;
protected:
@@ -89,7 +84,6 @@ class NET_EXPORT StreamListenSocket
};
StreamListenSocket(SocketDescriptor s, Delegate* del);
- virtual ~StreamListenSocket();
SocketDescriptor AcceptSocket();
virtual void Accept() = 0;
@@ -107,7 +101,6 @@ class NET_EXPORT StreamListenSocket
Delegate* const socket_delegate_;
private:
- friend class base::RefCountedThreadSafe<StreamListenSocket>;
friend class TransportClientSocketTest;
void SendInternal(const char* bytes, int len);
@@ -146,7 +139,7 @@ class NET_EXPORT StreamListenSocketFactory {
virtual ~StreamListenSocketFactory() {}
// Returns a new instance of StreamListenSocket or NULL if an error occurred.
- virtual scoped_refptr<StreamListenSocket> CreateAndListen(
+ virtual scoped_ptr<StreamListenSocket> CreateAndListen(
StreamListenSocket::Delegate* delegate) const = 0;
};
diff --git a/chromium/net/socket/tcp_client_socket.cc b/chromium/net/socket/tcp_client_socket.cc
index dbd21056f39..22aea47778b 100644
--- a/chromium/net/socket/tcp_client_socket.cc
+++ b/chromium/net/socket/tcp_client_socket.cc
@@ -4,56 +4,317 @@
#include "net/socket/tcp_client_socket.h"
-#include "base/file_util.h"
-#include "base/files/file_path.h"
+#include "base/callback_helpers.h"
+#include "base/logging.h"
+#include "net/base/io_buffer.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_util.h"
namespace net {
-namespace {
+TCPClientSocket::TCPClientSocket(const AddressList& addresses,
+ net::NetLog* net_log,
+ const net::NetLog::Source& source)
+ : socket_(new TCPSocket(net_log, source)),
+ addresses_(addresses),
+ current_address_index_(-1),
+ next_connect_state_(CONNECT_STATE_NONE),
+ previously_disconnected_(false) {
+}
+
+TCPClientSocket::TCPClientSocket(scoped_ptr<TCPSocket> connected_socket,
+ const IPEndPoint& peer_address)
+ : socket_(connected_socket.Pass()),
+ addresses_(AddressList(peer_address)),
+ current_address_index_(0),
+ next_connect_state_(CONNECT_STATE_NONE),
+ previously_disconnected_(false) {
+ DCHECK(socket_);
+
+ socket_->SetDefaultOptionsForClient();
+ use_history_.set_was_ever_connected();
+}
+
+TCPClientSocket::~TCPClientSocket() {
+}
+
+int TCPClientSocket::Bind(const IPEndPoint& address) {
+ if (current_address_index_ >= 0 || bind_address_) {
+ // Cannot bind the socket if we are already connected or connecting.
+ NOTREACHED();
+ return ERR_UNEXPECTED;
+ }
+
+ int result = OK;
+ if (!socket_->IsValid()) {
+ result = OpenSocket(address.GetFamily());
+ if (result != OK)
+ return result;
+ }
+
+ result = socket_->Bind(address);
+ if (result != OK)
+ return result;
+
+ bind_address_.reset(new IPEndPoint(address));
+ return OK;
+}
+
+int TCPClientSocket::Connect(const CompletionCallback& callback) {
+ DCHECK(!callback.is_null());
+
+ // If connecting or already connected, then just return OK.
+ if (socket_->IsValid() && current_address_index_ >= 0)
+ return OK;
+
+ socket_->StartLoggingMultipleConnectAttempts(addresses_);
+
+ // We will try to connect to each address in addresses_. Start with the
+ // first one in the list.
+ next_connect_state_ = CONNECT_STATE_CONNECT;
+ current_address_index_ = 0;
+
+ int rv = DoConnectLoop(OK);
+ if (rv == ERR_IO_PENDING) {
+ connect_callback_ = callback;
+ } else {
+ socket_->EndLoggingMultipleConnectAttempts(rv);
+ }
+
+ return rv;
+}
-#if defined(OS_LINUX)
+int TCPClientSocket::DoConnectLoop(int result) {
+ DCHECK_NE(next_connect_state_, CONNECT_STATE_NONE);
-// Checks to see if the system supports TCP FastOpen. Notably, it requires
-// kernel support. Additionally, this checks system configuration to ensure that
-// it's enabled.
-bool SystemSupportsTCPFastOpen() {
- static const base::FilePath::CharType kTCPFastOpenProcFilePath[] =
- "/proc/sys/net/ipv4/tcp_fastopen";
- std::string system_enabled_tcp_fastopen;
- if (!file_util::ReadFileToString(
- base::FilePath(kTCPFastOpenProcFilePath),
- &system_enabled_tcp_fastopen)) {
- return false;
+ int rv = result;
+ do {
+ ConnectState state = next_connect_state_;
+ next_connect_state_ = CONNECT_STATE_NONE;
+ switch (state) {
+ case CONNECT_STATE_CONNECT:
+ DCHECK_EQ(OK, rv);
+ rv = DoConnect();
+ break;
+ case CONNECT_STATE_CONNECT_COMPLETE:
+ rv = DoConnectComplete(rv);
+ break;
+ default:
+ NOTREACHED() << "bad state " << state;
+ rv = ERR_UNEXPECTED;
+ break;
+ }
+ } while (rv != ERR_IO_PENDING && next_connect_state_ != CONNECT_STATE_NONE);
+
+ return rv;
+}
+
+int TCPClientSocket::DoConnect() {
+ DCHECK_GE(current_address_index_, 0);
+ DCHECK_LT(current_address_index_, static_cast<int>(addresses_.size()));
+
+ const IPEndPoint& endpoint = addresses_[current_address_index_];
+
+ if (previously_disconnected_) {
+ use_history_.Reset();
+ previously_disconnected_ = false;
}
- // As per http://lxr.linux.no/linux+v3.7.7/include/net/tcp.h#L225
- // TFO_CLIENT_ENABLE is the LSB
- if (system_enabled_tcp_fastopen.empty() ||
- (system_enabled_tcp_fastopen[0] & 0x1) == 0) {
- return false;
+ next_connect_state_ = CONNECT_STATE_CONNECT_COMPLETE;
+
+ if (socket_->IsValid()) {
+ DCHECK(bind_address_);
+ } else {
+ int result = OpenSocket(endpoint.GetFamily());
+ if (result != OK)
+ return result;
+
+ if (bind_address_) {
+ result = socket_->Bind(*bind_address_);
+ if (result != OK) {
+ socket_->Close();
+ return result;
+ }
+ }
+ }
+
+ // |socket_| is owned by this class and the callback won't be run once
+ // |socket_| is gone. Therefore, it is safe to use base::Unretained() here.
+ return socket_->Connect(endpoint,
+ base::Bind(&TCPClientSocket::DidCompleteConnect,
+ base::Unretained(this)));
+}
+
+int TCPClientSocket::DoConnectComplete(int result) {
+ if (result == OK) {
+ use_history_.set_was_ever_connected();
+ return OK; // Done!
+ }
+
+ // Close whatever partially connected socket we currently have.
+ DoDisconnect();
+
+ // Try to fall back to the next address in the list.
+ if (current_address_index_ + 1 < static_cast<int>(addresses_.size())) {
+ next_connect_state_ = CONNECT_STATE_CONNECT;
+ ++current_address_index_;
+ return OK;
+ }
+
+ // Otherwise there is nothing to fall back to, so give up.
+ return result;
+}
+
+void TCPClientSocket::Disconnect() {
+ DoDisconnect();
+ current_address_index_ = -1;
+ bind_address_.reset();
+}
+
+void TCPClientSocket::DoDisconnect() {
+ // If connecting or already connected, record that the socket has been
+ // disconnected.
+ previously_disconnected_ = socket_->IsValid() && current_address_index_ >= 0;
+ socket_->Close();
+}
+
+bool TCPClientSocket::IsConnected() const {
+ return socket_->IsConnected();
+}
+
+bool TCPClientSocket::IsConnectedAndIdle() const {
+ return socket_->IsConnectedAndIdle();
+}
+
+int TCPClientSocket::GetPeerAddress(IPEndPoint* address) const {
+ return socket_->GetPeerAddress(address);
+}
+
+int TCPClientSocket::GetLocalAddress(IPEndPoint* address) const {
+ DCHECK(address);
+
+ if (!socket_->IsValid()) {
+ if (bind_address_) {
+ *address = *bind_address_;
+ return OK;
+ }
+ return ERR_SOCKET_NOT_CONNECTED;
}
- return true;
+ return socket_->GetLocalAddress(address);
+}
+
+const BoundNetLog& TCPClientSocket::NetLog() const {
+ return socket_->net_log();
+}
+
+void TCPClientSocket::SetSubresourceSpeculation() {
+ use_history_.set_subresource_speculation();
+}
+
+void TCPClientSocket::SetOmniboxSpeculation() {
+ use_history_.set_omnibox_speculation();
+}
+
+bool TCPClientSocket::WasEverUsed() const {
+ return use_history_.was_used_to_convey_data();
+}
+
+bool TCPClientSocket::UsingTCPFastOpen() const {
+ return socket_->UsingTCPFastOpen();
+}
+
+bool TCPClientSocket::WasNpnNegotiated() const {
+ return false;
}
-#else
+NextProto TCPClientSocket::GetNegotiatedProtocol() const {
+ return kProtoUnknown;
+}
-bool SystemSupportsTCPFastOpen() {
+bool TCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
return false;
}
-#endif
+int TCPClientSocket::Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
+ DCHECK(!callback.is_null());
+
+ // |socket_| is owned by this class and the callback won't be run once
+ // |socket_| is gone. Therefore, it is safe to use base::Unretained() here.
+ CompletionCallback read_callback = base::Bind(
+ &TCPClientSocket::DidCompleteReadWrite, base::Unretained(this), callback);
+ int result = socket_->Read(buf, buf_len, read_callback);
+ if (result > 0)
+ use_history_.set_was_used_to_convey_data();
+
+ return result;
+}
+
+int TCPClientSocket::Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
+ DCHECK(!callback.is_null());
+
+ // |socket_| is owned by this class and the callback won't be run once
+ // |socket_| is gone. Therefore, it is safe to use base::Unretained() here.
+ CompletionCallback write_callback = base::Bind(
+ &TCPClientSocket::DidCompleteReadWrite, base::Unretained(this), callback);
+ int result = socket_->Write(buf, buf_len, write_callback);
+ if (result > 0)
+ use_history_.set_was_used_to_convey_data();
+ return result;
}
-static bool g_tcp_fastopen_enabled = false;
+bool TCPClientSocket::SetReceiveBufferSize(int32 size) {
+ return socket_->SetReceiveBufferSize(size);
+}
+
+bool TCPClientSocket::SetSendBufferSize(int32 size) {
+ return socket_->SetSendBufferSize(size);
+}
+
+bool TCPClientSocket::SetKeepAlive(bool enable, int delay) {
+ return socket_->SetKeepAlive(enable, delay);
+}
-void SetTCPFastOpenEnabled(bool value) {
- g_tcp_fastopen_enabled = value && SystemSupportsTCPFastOpen();
+bool TCPClientSocket::SetNoDelay(bool no_delay) {
+ return socket_->SetNoDelay(no_delay);
}
-bool IsTCPFastOpenEnabled() {
- return g_tcp_fastopen_enabled;
+void TCPClientSocket::DidCompleteConnect(int result) {
+ DCHECK_EQ(next_connect_state_, CONNECT_STATE_CONNECT_COMPLETE);
+ DCHECK_NE(result, ERR_IO_PENDING);
+ DCHECK(!connect_callback_.is_null());
+
+ result = DoConnectLoop(result);
+ if (result != ERR_IO_PENDING) {
+ socket_->EndLoggingMultipleConnectAttempts(result);
+ base::ResetAndReturn(&connect_callback_).Run(result);
+ }
+}
+
+void TCPClientSocket::DidCompleteReadWrite(const CompletionCallback& callback,
+ int result) {
+ if (result > 0)
+ use_history_.set_was_used_to_convey_data();
+
+ callback.Run(result);
+}
+
+int TCPClientSocket::OpenSocket(AddressFamily family) {
+ DCHECK(!socket_->IsValid());
+
+ int result = socket_->Open(family);
+ if (result != OK)
+ return result;
+
+ socket_->SetDefaultOptionsForClient();
+
+ return OK;
}
} // namespace net
diff --git a/chromium/net/socket/tcp_client_socket.h b/chromium/net/socket/tcp_client_socket.h
index 8a2c0cd73f0..fabcbc1b39d 100644
--- a/chromium/net/socket/tcp_client_socket.h
+++ b/chromium/net/socket/tcp_client_socket.h
@@ -5,30 +5,116 @@
#ifndef NET_SOCKET_TCP_CLIENT_SOCKET_H_
#define NET_SOCKET_TCP_CLIENT_SOCKET_H_
-#include "build/build_config.h"
+#include "base/basictypes.h"
+#include "base/compiler_specific.h"
+#include "base/memory/scoped_ptr.h"
+#include "net/base/address_list.h"
+#include "net/base/completion_callback.h"
#include "net/base/net_export.h"
-
-#if defined(OS_WIN)
-#include "net/socket/tcp_client_socket_win.h"
-#elif defined(OS_POSIX)
-#include "net/socket/tcp_client_socket_libevent.h"
-#endif
+#include "net/base/net_log.h"
+#include "net/socket/stream_socket.h"
+#include "net/socket/tcp_socket.h"
namespace net {
// A client socket that uses TCP as the transport layer.
-#if defined(OS_WIN)
-typedef TCPClientSocketWin TCPClientSocket;
-#elif defined(OS_POSIX)
-typedef TCPClientSocketLibevent TCPClientSocket;
-#endif
-
-// Enable/disable experimental TCP FastOpen option.
-// Not thread safe. Must be called during initialization/startup only.
-NET_EXPORT void SetTCPFastOpenEnabled(bool value);
-
-// Check if the TCP FastOpen option is enabled.
-bool IsTCPFastOpenEnabled();
+class NET_EXPORT TCPClientSocket : public StreamSocket {
+ public:
+ // The IP address(es) and port number to connect to. The TCP socket will try
+ // each IP address in the list until it succeeds in establishing a
+ // connection.
+ TCPClientSocket(const AddressList& addresses,
+ net::NetLog* net_log,
+ const net::NetLog::Source& source);
+
+ // Adopts the given, connected socket and then acts as if Connect() had been
+ // called. This function is used by TCPServerSocket and for testing.
+ TCPClientSocket(scoped_ptr<TCPSocket> connected_socket,
+ const IPEndPoint& peer_address);
+
+ virtual ~TCPClientSocket();
+
+ // Binds the socket to a local IP address and port.
+ int Bind(const IPEndPoint& address);
+
+ // StreamSocket implementation.
+ virtual int Connect(const CompletionCallback& callback) OVERRIDE;
+ virtual void Disconnect() OVERRIDE;
+ virtual bool IsConnected() const OVERRIDE;
+ virtual bool IsConnectedAndIdle() const OVERRIDE;
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
+ virtual const BoundNetLog& NetLog() const OVERRIDE;
+ virtual void SetSubresourceSpeculation() OVERRIDE;
+ virtual void SetOmniboxSpeculation() OVERRIDE;
+ virtual bool WasEverUsed() const OVERRIDE;
+ virtual bool UsingTCPFastOpen() const OVERRIDE;
+ virtual bool WasNpnNegotiated() const OVERRIDE;
+ virtual NextProto GetNegotiatedProtocol() const OVERRIDE;
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
+
+ // Socket implementation.
+ // Multiple outstanding requests are not supported.
+ // Full duplex mode (reading and writing at the same time) is supported.
+ virtual int Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual int Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual bool SetReceiveBufferSize(int32 size) OVERRIDE;
+ virtual bool SetSendBufferSize(int32 size) OVERRIDE;
+
+ virtual bool SetKeepAlive(bool enable, int delay);
+ virtual bool SetNoDelay(bool no_delay);
+
+ private:
+ // State machine for connecting the socket.
+ enum ConnectState {
+ CONNECT_STATE_CONNECT,
+ CONNECT_STATE_CONNECT_COMPLETE,
+ CONNECT_STATE_NONE,
+ };
+
+ // State machine used by Connect().
+ int DoConnectLoop(int result);
+ int DoConnect();
+ int DoConnectComplete(int result);
+
+ // Helper used by Disconnect(), which disconnects minus resetting
+ // current_address_index_ and bind_address_.
+ void DoDisconnect();
+
+ void DidCompleteConnect(int result);
+ void DidCompleteReadWrite(const CompletionCallback& callback, int result);
+
+ int OpenSocket(AddressFamily family);
+
+ scoped_ptr<TCPSocket> socket_;
+
+ // Local IP address and port we are bound to. Set to NULL if Bind()
+ // wasn't called (in that case OS chooses address/port).
+ scoped_ptr<IPEndPoint> bind_address_;
+
+ // The list of addresses we should try in order to establish a connection.
+ AddressList addresses_;
+
+ // Where we are in above list. Set to -1 if uninitialized.
+ int current_address_index_;
+
+ // External callback; called when connect is complete.
+ CompletionCallback connect_callback_;
+
+ // The next state for the Connect() state machine.
+ ConnectState next_connect_state_;
+
+ // This socket was previously disconnected and has not been re-connected.
+ bool previously_disconnected_;
+
+ // Record of connectivity and transmissions, for use in speculative connection
+ // histograms.
+ UseHistory use_history_;
+
+ DISALLOW_COPY_AND_ASSIGN(TCPClientSocket);
+};
} // namespace net
diff --git a/chromium/net/socket/tcp_client_socket_libevent.h b/chromium/net/socket/tcp_client_socket_libevent.h
deleted file mode 100644
index e5a0d8deab4..00000000000
--- a/chromium/net/socket/tcp_client_socket_libevent.h
+++ /dev/null
@@ -1,256 +0,0 @@
-// Copyright (c) 2012 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#ifndef NET_SOCKET_TCP_CLIENT_SOCKET_LIBEVENT_H_
-#define NET_SOCKET_TCP_CLIENT_SOCKET_LIBEVENT_H_
-
-#include "base/memory/ref_counted.h"
-#include "base/memory/scoped_ptr.h"
-#include "base/message_loop/message_loop.h"
-#include "base/threading/non_thread_safe.h"
-#include "net/base/address_list.h"
-#include "net/base/completion_callback.h"
-#include "net/base/net_log.h"
-#include "net/socket/stream_socket.h"
-
-namespace net {
-
-class BoundNetLog;
-
-// A client socket that uses TCP as the transport layer.
-class NET_EXPORT_PRIVATE TCPClientSocketLibevent : public StreamSocket,
- public base::NonThreadSafe {
- public:
- // The IP address(es) and port number to connect to. The TCP socket will try
- // each IP address in the list until it succeeds in establishing a
- // connection.
- TCPClientSocketLibevent(const AddressList& addresses,
- net::NetLog* net_log,
- const net::NetLog::Source& source);
-
- virtual ~TCPClientSocketLibevent();
-
- // AdoptSocket causes the given, connected socket to be adopted as a TCP
- // socket. This object must not be connected. This object takes ownership of
- // the given socket and then acts as if Connect() had been called. This
- // function is used by TCPServerSocket() to adopt accepted connections
- // and for testing.
- int AdoptSocket(int socket);
-
- // Binds the socket to a local IP address and port.
- int Bind(const IPEndPoint& address);
-
- // StreamSocket implementation.
- virtual int Connect(const CompletionCallback& callback) OVERRIDE;
- virtual void Disconnect() OVERRIDE;
- virtual bool IsConnected() const OVERRIDE;
- virtual bool IsConnectedAndIdle() const OVERRIDE;
- virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
- virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
- virtual const BoundNetLog& NetLog() const OVERRIDE;
- virtual void SetSubresourceSpeculation() OVERRIDE;
- virtual void SetOmniboxSpeculation() OVERRIDE;
- virtual bool WasEverUsed() const OVERRIDE;
- virtual bool UsingTCPFastOpen() const OVERRIDE;
- virtual bool WasNpnNegotiated() const OVERRIDE;
- virtual NextProto GetNegotiatedProtocol() const OVERRIDE;
- virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
-
- // Socket implementation.
- // Multiple outstanding requests are not supported.
- // Full duplex mode (reading and writing at the same time) is supported
- virtual int Read(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual int Write(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual bool SetReceiveBufferSize(int32 size) OVERRIDE;
- virtual bool SetSendBufferSize(int32 size) OVERRIDE;
-
- virtual bool SetKeepAlive(bool enable, int delay);
- virtual bool SetNoDelay(bool no_delay);
-
- private:
- // State machine for connecting the socket.
- enum ConnectState {
- CONNECT_STATE_CONNECT,
- CONNECT_STATE_CONNECT_COMPLETE,
- CONNECT_STATE_NONE,
- };
-
- // States that a fast open socket attempt can result in.
- enum FastOpenStatus {
- FAST_OPEN_STATUS_UNKNOWN,
-
- // The initial fast open connect attempted returned synchronously,
- // indicating that we had and sent a cookie along with the initial data.
- FAST_OPEN_FAST_CONNECT_RETURN,
-
- // The initial fast open connect attempted returned asynchronously,
- // indicating that we did not have a cookie for the server.
- FAST_OPEN_SLOW_CONNECT_RETURN,
-
- // Some other error occurred on connection, so we couldn't tell if
- // fast open would have worked.
- FAST_OPEN_ERROR,
-
- // An attempt to do a fast open succeeded immediately
- // (FAST_OPEN_FAST_CONNECT_RETURN) and we later confirmed that the server
- // had acked the data we sent.
- FAST_OPEN_SYN_DATA_ACK,
-
- // An attempt to do a fast open succeeded immediately
- // (FAST_OPEN_FAST_CONNECT_RETURN) and we later confirmed that the server
- // had nacked the data we sent.
- FAST_OPEN_SYN_DATA_NACK,
-
- // An attempt to do a fast open succeeded immediately
- // (FAST_OPEN_FAST_CONNECT_RETURN) and our probe to determine if the
- // socket was using fast open failed.
- FAST_OPEN_SYN_DATA_FAILED,
-
- // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN)
- // and we later confirmed that the server had acked initial data. This
- // should never happen (we didn't send data, so it shouldn't have
- // been acked).
- FAST_OPEN_NO_SYN_DATA_ACK,
-
- // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN)
- // and we later discovered that the server had nacked initial data. This
- // is the expected case results for FAST_OPEN_SLOW_CONNECT_RETURN.
- FAST_OPEN_NO_SYN_DATA_NACK,
-
- // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN)
- // and our later probe for ack/nack state failed.
- FAST_OPEN_NO_SYN_DATA_FAILED,
-
- FAST_OPEN_MAX_VALUE
- };
-
- class ReadWatcher : public base::MessageLoopForIO::Watcher {
- public:
- explicit ReadWatcher(TCPClientSocketLibevent* socket) : socket_(socket) {}
-
- // MessageLoopForIO::Watcher methods
-
- virtual void OnFileCanReadWithoutBlocking(int /* fd */) OVERRIDE;
-
- virtual void OnFileCanWriteWithoutBlocking(int /* fd */) OVERRIDE {}
-
- private:
- TCPClientSocketLibevent* const socket_;
-
- DISALLOW_COPY_AND_ASSIGN(ReadWatcher);
- };
-
- class WriteWatcher : public base::MessageLoopForIO::Watcher {
- public:
- explicit WriteWatcher(TCPClientSocketLibevent* socket) : socket_(socket) {}
-
- // MessageLoopForIO::Watcher implementation.
- virtual void OnFileCanReadWithoutBlocking(int /* fd */) OVERRIDE {}
- virtual void OnFileCanWriteWithoutBlocking(int /* fd */) OVERRIDE;
-
- private:
- TCPClientSocketLibevent* const socket_;
-
- DISALLOW_COPY_AND_ASSIGN(WriteWatcher);
- };
-
- // State machine used by Connect().
- int DoConnectLoop(int result);
- int DoConnect();
- int DoConnectComplete(int result);
-
- // Helper used by Disconnect(), which disconnects minus the logging and
- // resetting of current_address_index_.
- void DoDisconnect();
-
- void DoReadCallback(int rv);
- void DoWriteCallback(int rv);
- void DidCompleteRead();
- void DidCompleteWrite();
- void DidCompleteConnect();
-
- // Returns true if a Connect() is in progress.
- bool waiting_connect() const {
- return next_connect_state_ != CONNECT_STATE_NONE;
- }
-
- // Helper to add a TCP_CONNECT (end) event to the NetLog.
- void LogConnectCompletion(int net_error);
-
- // Internal function to write to a socket.
- int InternalWrite(IOBuffer* buf, int buf_len);
-
- // Called when the socket is known to be in a connected state.
- void RecordFastOpenStatus();
-
- int socket_;
-
- // Local IP address and port we are bound to. Set to NULL if Bind()
- // was't called (in that cases OS chooses address/port).
- scoped_ptr<IPEndPoint> bind_address_;
-
- // Stores bound socket between Bind() and Connect() calls.
- int bound_socket_;
-
- // The list of addresses we should try in order to establish a connection.
- AddressList addresses_;
-
- // Where we are in above list. Set to -1 if uninitialized.
- int current_address_index_;
-
- // The socket's libevent wrappers
- base::MessageLoopForIO::FileDescriptorWatcher read_socket_watcher_;
- base::MessageLoopForIO::FileDescriptorWatcher write_socket_watcher_;
-
- // The corresponding watchers for reads and writes.
- ReadWatcher read_watcher_;
- WriteWatcher write_watcher_;
-
- // The buffer used by OnSocketReady to retry Read requests
- scoped_refptr<IOBuffer> read_buf_;
- int read_buf_len_;
-
- // The buffer used by OnSocketReady to retry Write requests
- scoped_refptr<IOBuffer> write_buf_;
- int write_buf_len_;
-
- // External callback; called when read is complete.
- CompletionCallback read_callback_;
-
- // External callback; called when write is complete.
- CompletionCallback write_callback_;
-
- // The next state for the Connect() state machine.
- ConnectState next_connect_state_;
-
- // The OS error that CONNECT_STATE_CONNECT last completed with.
- int connect_os_error_;
-
- BoundNetLog net_log_;
-
- // This socket was previously disconnected and has not been re-connected.
- bool previously_disconnected_;
-
- // Record of connectivity and transmissions, for use in speculative connection
- // histograms.
- UseHistory use_history_;
-
- // Enables experimental TCP FastOpen option.
- const bool use_tcp_fastopen_;
-
- // True when TCP FastOpen is in use and we have done the connect.
- bool tcp_fastopen_connected_;
-
- enum FastOpenStatus fast_open_status_;
-
- DISALLOW_COPY_AND_ASSIGN(TCPClientSocketLibevent);
-};
-
-} // namespace net
-
-#endif // NET_SOCKET_TCP_CLIENT_SOCKET_LIBEVENT_H_
diff --git a/chromium/net/socket/tcp_client_socket_win.h b/chromium/net/socket/tcp_client_socket_win.h
deleted file mode 100644
index 26c8b9feff2..00000000000
--- a/chromium/net/socket/tcp_client_socket_win.h
+++ /dev/null
@@ -1,162 +0,0 @@
-// Copyright (c) 2012 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#ifndef NET_SOCKET_TCP_CLIENT_SOCKET_WIN_H_
-#define NET_SOCKET_TCP_CLIENT_SOCKET_WIN_H_
-
-#include <winsock2.h>
-
-#include "base/memory/scoped_ptr.h"
-#include "base/threading/non_thread_safe.h"
-#include "net/base/address_list.h"
-#include "net/base/completion_callback.h"
-#include "net/base/net_log.h"
-#include "net/socket/stream_socket.h"
-
-namespace net {
-
-class BoundNetLog;
-
-class NET_EXPORT TCPClientSocketWin : public StreamSocket,
- NON_EXPORTED_BASE(base::NonThreadSafe) {
- public:
- // The IP address(es) and port number to connect to. The TCP socket will try
- // each IP address in the list until it succeeds in establishing a
- // connection.
- TCPClientSocketWin(const AddressList& addresses,
- net::NetLog* net_log,
- const net::NetLog::Source& source);
-
- virtual ~TCPClientSocketWin();
-
- // AdoptSocket causes the given, connected socket to be adopted as a TCP
- // socket. This object must not be connected. This object takes ownership of
- // the given socket and then acts as if Connect() had been called. This
- // function is used by TCPServerSocket() to adopt accepted connections
- // and for testing.
- int AdoptSocket(SOCKET socket);
-
- // Binds the socket to a local IP address and port.
- int Bind(const IPEndPoint& address);
-
- // StreamSocket implementation.
- virtual int Connect(const CompletionCallback& callback);
- virtual void Disconnect();
- virtual bool IsConnected() const;
- virtual bool IsConnectedAndIdle() const;
- virtual int GetPeerAddress(IPEndPoint* address) const;
- virtual int GetLocalAddress(IPEndPoint* address) const;
- virtual const BoundNetLog& NetLog() const { return net_log_; }
- virtual void SetSubresourceSpeculation();
- virtual void SetOmniboxSpeculation();
- virtual bool WasEverUsed() const;
- virtual bool UsingTCPFastOpen() const;
- virtual bool WasNpnNegotiated() const OVERRIDE;
- virtual NextProto GetNegotiatedProtocol() const OVERRIDE;
- virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
-
- // Socket implementation.
- // Multiple outstanding requests are not supported.
- // Full duplex mode (reading and writing at the same time) is supported
- virtual int Read(IOBuffer* buf, int buf_len,
- const CompletionCallback& callback);
- virtual int Write(IOBuffer* buf, int buf_len,
- const CompletionCallback& callback);
-
- virtual bool SetReceiveBufferSize(int32 size);
- virtual bool SetSendBufferSize(int32 size);
-
- virtual bool SetKeepAlive(bool enable, int delay);
- virtual bool SetNoDelay(bool no_delay);
-
- // Perform reads in non-blocking mode instead of overlapped mode.
- // Used for experiments.
- static void DisableOverlappedReads();
-
- private:
- // State machine for connecting the socket.
- enum ConnectState {
- CONNECT_STATE_CONNECT,
- CONNECT_STATE_CONNECT_COMPLETE,
- CONNECT_STATE_NONE,
- };
-
- class Core;
-
- // State machine used by Connect().
- int DoConnectLoop(int result);
- int DoConnect();
- int DoConnectComplete(int result);
-
- // Helper used by Disconnect(), which disconnects minus the logging and
- // resetting of current_address_index_.
- void DoDisconnect();
-
- // Returns true if a Connect() is in progress.
- bool waiting_connect() const {
- return next_connect_state_ != CONNECT_STATE_NONE;
- }
-
- // Called after Connect() has completed with |net_error|.
- void LogConnectCompletion(int net_error);
-
- int DoRead(IOBuffer* buf, int buf_len, const CompletionCallback& callback);
- void DoReadCallback(int rv);
- void DoWriteCallback(int rv);
- void DidCompleteConnect();
- void DidCompleteRead();
- void DidCompleteWrite();
- void DidSignalRead();
-
- SOCKET socket_;
-
- // Local IP address and port we are bound to. Set to NULL if Bind()
- // was't called (in that cases OS chooses address/port).
- scoped_ptr<IPEndPoint> bind_address_;
-
- // Stores bound socket between Bind() and Connect() calls.
- SOCKET bound_socket_;
-
- // The list of addresses we should try in order to establish a connection.
- AddressList addresses_;
-
- // Where we are in above list. Set to -1 if uninitialized.
- int current_address_index_;
-
- // The various states that the socket could be in.
- bool waiting_read_;
- bool waiting_write_;
-
- // The core of the socket that can live longer than the socket itself. We pass
- // resources to the Windows async IO functions and we have to make sure that
- // they are not destroyed while the OS still references them.
- scoped_refptr<Core> core_;
-
- // External callback; called when connect or read is complete.
- CompletionCallback read_callback_;
-
- // External callback; called when write is complete.
- CompletionCallback write_callback_;
-
- // The next state for the Connect() state machine.
- ConnectState next_connect_state_;
-
- // The OS error that CONNECT_STATE_CONNECT last completed with.
- int connect_os_error_;
-
- BoundNetLog net_log_;
-
- // This socket was previously disconnected and has not been re-connected.
- bool previously_disconnected_;
-
- // Record of connectivity and transmissions, for use in speculative connection
- // histograms.
- UseHistory use_history_;
-
- DISALLOW_COPY_AND_ASSIGN(TCPClientSocketWin);
-};
-
-} // namespace net
-
-#endif // NET_SOCKET_TCP_CLIENT_SOCKET_WIN_H_
diff --git a/chromium/net/socket/tcp_listen_socket.cc b/chromium/net/socket/tcp_listen_socket.cc
index aab2e45d0e9..223abee2cba 100644
--- a/chromium/net/socket/tcp_listen_socket.cc
+++ b/chromium/net/socket/tcp_listen_socket.cc
@@ -23,20 +23,21 @@
#include "build/build_config.h"
#include "net/base/net_util.h"
#include "net/base/winsock_init.h"
+#include "net/socket/socket_descriptor.h"
using std::string;
namespace net {
// static
-scoped_refptr<TCPListenSocket> TCPListenSocket::CreateAndListen(
+scoped_ptr<TCPListenSocket> TCPListenSocket::CreateAndListen(
const string& ip, int port, StreamListenSocket::Delegate* del) {
SocketDescriptor s = CreateAndBind(ip, port);
if (s == kInvalidSocket)
- return NULL;
- scoped_refptr<TCPListenSocket> sock(new TCPListenSocket(s, del));
+ return scoped_ptr<TCPListenSocket>();
+ scoped_ptr<TCPListenSocket> sock(new TCPListenSocket(s, del));
sock->Listen();
- return sock;
+ return sock.Pass();
}
TCPListenSocket::TCPListenSocket(SocketDescriptor s,
@@ -47,11 +48,7 @@ TCPListenSocket::TCPListenSocket(SocketDescriptor s,
TCPListenSocket::~TCPListenSocket() {}
SocketDescriptor TCPListenSocket::CreateAndBind(const string& ip, int port) {
-#if defined(OS_WIN)
- EnsureWinsockInit();
-#endif
-
- SocketDescriptor s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+ SocketDescriptor s = CreatePlatformSocket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
if (s != kInvalidSocket) {
#if defined(OS_POSIX)
// Allow rapid reuse.
@@ -104,13 +101,13 @@ void TCPListenSocket::Accept() {
SocketDescriptor conn = AcceptSocket();
if (conn == kInvalidSocket)
return;
- scoped_refptr<TCPListenSocket> sock(
+ scoped_ptr<TCPListenSocket> sock(
new TCPListenSocket(conn, socket_delegate_));
// It's up to the delegate to AddRef if it wants to keep it around.
#if defined(OS_POSIX)
sock->WatchSocket(WAITING_READ);
#endif
- socket_delegate_->DidAccept(this, sock.get());
+ socket_delegate_->DidAccept(this, sock.PassAs<StreamListenSocket>());
}
TCPListenSocketFactory::TCPListenSocketFactory(const string& ip, int port)
@@ -120,9 +117,10 @@ TCPListenSocketFactory::TCPListenSocketFactory(const string& ip, int port)
TCPListenSocketFactory::~TCPListenSocketFactory() {}
-scoped_refptr<StreamListenSocket> TCPListenSocketFactory::CreateAndListen(
+scoped_ptr<StreamListenSocket> TCPListenSocketFactory::CreateAndListen(
StreamListenSocket::Delegate* delegate) const {
- return TCPListenSocket::CreateAndListen(ip_, port_, delegate);
+ return TCPListenSocket::CreateAndListen(ip_, port_, delegate)
+ .PassAs<StreamListenSocket>();
}
} // namespace net
diff --git a/chromium/net/socket/tcp_listen_socket.h b/chromium/net/socket/tcp_listen_socket.h
index dbc5347e945..54a91de59bb 100644
--- a/chromium/net/socket/tcp_listen_socket.h
+++ b/chromium/net/socket/tcp_listen_socket.h
@@ -8,18 +8,19 @@
#include <string>
#include "base/basictypes.h"
-#include "base/memory/ref_counted.h"
#include "net/base/net_export.h"
+#include "net/socket/socket_descriptor.h"
#include "net/socket/stream_listen_socket.h"
namespace net {
-// Implements a TCP socket. Note that this is ref counted.
+// Implements a TCP socket.
class NET_EXPORT TCPListenSocket : public StreamListenSocket {
public:
+ virtual ~TCPListenSocket();
// Listen on port for the specified IP address. Use 127.0.0.1 to only
// accept local connections.
- static scoped_refptr<TCPListenSocket> CreateAndListen(
+ static scoped_ptr<TCPListenSocket> CreateAndListen(
const std::string& ip, int port, StreamListenSocket::Delegate* del);
// Get raw TCP socket descriptor bound to ip:port.
@@ -30,10 +31,7 @@ class NET_EXPORT TCPListenSocket : public StreamListenSocket {
int* port);
protected:
- friend class scoped_refptr<TCPListenSocket>;
-
TCPListenSocket(SocketDescriptor s, StreamListenSocket::Delegate* del);
- virtual ~TCPListenSocket();
// Implements StreamListenSocket::Accept.
virtual void Accept() OVERRIDE;
@@ -49,7 +47,7 @@ class NET_EXPORT TCPListenSocketFactory : public StreamListenSocketFactory {
virtual ~TCPListenSocketFactory();
// StreamListenSocketFactory overrides.
- virtual scoped_refptr<StreamListenSocket> CreateAndListen(
+ virtual scoped_ptr<StreamListenSocket> CreateAndListen(
StreamListenSocket::Delegate* delegate) const OVERRIDE;
private:
diff --git a/chromium/net/socket/tcp_listen_socket_unittest.cc b/chromium/net/socket/tcp_listen_socket_unittest.cc
index d13b784cbdc..b122c6143d8 100644
--- a/chromium/net/socket/tcp_listen_socket_unittest.cc
+++ b/chromium/net/socket/tcp_listen_socket_unittest.cc
@@ -10,13 +10,14 @@
#include "base/bind.h"
#include "base/posix/eintr_wrapper.h"
#include "base/sys_byteorder.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/net_errors.h"
#include "net/base/net_util.h"
+#include "net/socket/socket_descriptor.h"
#include "testing/platform_test.h"
namespace net {
-const int TCPListenSocketTester::kTestPort = 9999;
-
static const int kReadBufSize = 1024;
static const char kHelloWorld[] = "HELLO, WORLD";
static const int kMaxQueueSize = 20;
@@ -24,7 +25,9 @@ static const char kLoopback[] = "127.0.0.1";
static const int kDefaultTimeoutMs = 5000;
TCPListenSocketTester::TCPListenSocketTester()
- : loop_(NULL), server_(NULL), connection_(NULL), cv_(&lock_) {}
+ : loop_(NULL),
+ cv_(&lock_),
+ server_port_(0) {}
void TCPListenSocketTester::SetUp() {
base::Thread::Options options;
@@ -41,13 +44,16 @@ void TCPListenSocketTester::SetUp() {
ASSERT_FALSE(server_.get() == NULL);
ASSERT_EQ(ACTION_LISTEN, last_action_.type());
+ int server_port = GetServerPort();
+ ASSERT_GT(server_port, 0);
+
// verify the connect/accept and setup test_socket_
- test_socket_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
- ASSERT_NE(StreamListenSocket::kInvalidSocket, test_socket_);
+ test_socket_ = CreatePlatformSocket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+ ASSERT_NE(kInvalidSocket, test_socket_);
struct sockaddr_in client;
client.sin_family = AF_INET;
client.sin_addr.s_addr = inet_addr(kLoopback);
- client.sin_port = base::HostToNet16(kTestPort);
+ client.sin_port = base::HostToNet16(server_port);
int ret = HANDLE_EINTR(
connect(test_socket_, reinterpret_cast<sockaddr*>(&client),
sizeof(client)));
@@ -113,17 +119,20 @@ int TCPListenSocketTester::ClearTestSocket() {
}
void TCPListenSocketTester::Shutdown() {
- connection_->Release();
- connection_ = NULL;
- server_->Release();
- server_ = NULL;
+ connection_.reset();
+ server_.reset();
ReportAction(TCPListenSocketTestAction(ACTION_SHUTDOWN));
}
void TCPListenSocketTester::Listen() {
server_ = DoListen();
ASSERT_TRUE(server_.get());
- server_->AddRef();
+
+ // The server's port will be needed to open the client socket.
+ IPEndPoint local_address;
+ ASSERT_EQ(OK, server_->GetLocalAddress(&local_address));
+ SetServerPort(local_address.port());
+
ReportAction(TCPListenSocketTestAction(ACTION_LISTEN));
}
@@ -227,10 +236,10 @@ bool TCPListenSocketTester::Send(SocketDescriptor sock,
return true;
}
-void TCPListenSocketTester::DidAccept(StreamListenSocket* server,
- StreamListenSocket* connection) {
- connection_ = connection;
- connection_->AddRef();
+void TCPListenSocketTester::DidAccept(
+ StreamListenSocket* server,
+ scoped_ptr<StreamListenSocket> connection) {
+ connection_ = connection.Pass();
ReportAction(TCPListenSocketTestAction(ACTION_ACCEPT));
}
@@ -247,11 +256,22 @@ void TCPListenSocketTester::DidClose(StreamListenSocket* sock) {
TCPListenSocketTester::~TCPListenSocketTester() {}
-scoped_refptr<TCPListenSocket> TCPListenSocketTester::DoListen() {
- return TCPListenSocket::CreateAndListen(kLoopback, kTestPort, this);
+scoped_ptr<TCPListenSocket> TCPListenSocketTester::DoListen() {
+ // Let the OS pick a free port.
+ return TCPListenSocket::CreateAndListen(kLoopback, 0, this);
+}
+
+int TCPListenSocketTester::GetServerPort() {
+ base::AutoLock locked(lock_);
+ return server_port_;
+}
+
+void TCPListenSocketTester::SetServerPort(int server_port) {
+ base::AutoLock locked(lock_);
+ server_port_ = server_port;
}
-class TCPListenSocketTest: public PlatformTest {
+class TCPListenSocketTest : public PlatformTest {
public:
TCPListenSocketTest() {
tester_ = NULL;
diff --git a/chromium/net/socket/tcp_listen_socket_unittest.h b/chromium/net/socket/tcp_listen_socket_unittest.h
index 048a0186705..1bc31a8d1ce 100644
--- a/chromium/net/socket/tcp_listen_socket_unittest.h
+++ b/chromium/net/socket/tcp_listen_socket_unittest.h
@@ -91,30 +91,37 @@ class TCPListenSocketTester :
// StreamListenSocket::Delegate:
virtual void DidAccept(StreamListenSocket* server,
- StreamListenSocket* connection) OVERRIDE;
+ scoped_ptr<StreamListenSocket> connection) OVERRIDE;
virtual void DidRead(StreamListenSocket* connection, const char* data,
int len) OVERRIDE;
virtual void DidClose(StreamListenSocket* sock) OVERRIDE;
scoped_ptr<base::Thread> thread_;
base::MessageLoopForIO* loop_;
- scoped_refptr<TCPListenSocket> server_;
- StreamListenSocket* connection_;
+ scoped_ptr<TCPListenSocket> server_;
+ scoped_ptr<StreamListenSocket> connection_;
TCPListenSocketTestAction last_action_;
SocketDescriptor test_socket_;
- static const int kTestPort;
- base::Lock lock_; // protects |queue_| and wraps |cv_|
+ base::Lock lock_; // Protects |queue_| and |server_port_|. Wraps |cv_|.
base::ConditionVariable cv_;
std::deque<TCPListenSocketTestAction> queue_;
- protected:
+ private:
friend class base::RefCountedThreadSafe<TCPListenSocketTester>;
virtual ~TCPListenSocketTester();
- virtual scoped_refptr<TCPListenSocket> DoListen();
+ virtual scoped_ptr<TCPListenSocket> DoListen();
+
+ // Getters/setters for |server_port_|. They use |lock_| for thread safety.
+ int GetServerPort();
+ void SetServerPort(int server_port);
+
+ // Port the server is using. Must have |lock_| to access. Set by Listen() on
+ // the server's thread.
+ int server_port_;
};
} // namespace net
diff --git a/chromium/net/socket/tcp_server_socket.cc b/chromium/net/socket/tcp_server_socket.cc
new file mode 100644
index 00000000000..a25f73f6c6f
--- /dev/null
+++ b/chromium/net/socket/tcp_server_socket.cc
@@ -0,0 +1,105 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/tcp_server_socket.h"
+
+#include "base/bind.h"
+#include "base/bind_helpers.h"
+#include "base/logging.h"
+#include "net/base/net_errors.h"
+#include "net/socket/tcp_client_socket.h"
+
+namespace net {
+
+TCPServerSocket::TCPServerSocket(NetLog* net_log, const NetLog::Source& source)
+ : socket_(net_log, source),
+ pending_accept_(false) {
+}
+
+TCPServerSocket::~TCPServerSocket() {
+}
+
+int TCPServerSocket::Listen(const IPEndPoint& address, int backlog) {
+ int result = socket_.Open(address.GetFamily());
+ if (result != OK)
+ return result;
+
+ result = socket_.SetDefaultOptionsForServer();
+ if (result != OK) {
+ socket_.Close();
+ return result;
+ }
+
+ result = socket_.Bind(address);
+ if (result != OK) {
+ socket_.Close();
+ return result;
+ }
+
+ result = socket_.Listen(backlog);
+ if (result != OK) {
+ socket_.Close();
+ return result;
+ }
+
+ return OK;
+}
+
+int TCPServerSocket::GetLocalAddress(IPEndPoint* address) const {
+ return socket_.GetLocalAddress(address);
+}
+
+int TCPServerSocket::Accept(scoped_ptr<StreamSocket>* socket,
+ const CompletionCallback& callback) {
+ DCHECK(socket);
+ DCHECK(!callback.is_null());
+
+ if (pending_accept_) {
+ NOTREACHED();
+ return ERR_UNEXPECTED;
+ }
+
+ // It is safe to use base::Unretained(this). |socket_| is owned by this class,
+ // and the callback won't be run after |socket_| is destroyed.
+ CompletionCallback accept_callback =
+ base::Bind(&TCPServerSocket::OnAcceptCompleted, base::Unretained(this),
+ socket, callback);
+ int result = socket_.Accept(&accepted_socket_, &accepted_address_,
+ accept_callback);
+ if (result != ERR_IO_PENDING) {
+ // |accept_callback| won't be called so we need to run
+ // ConvertAcceptedSocket() ourselves in order to do the conversion from
+ // |accepted_socket_| to |socket|.
+ result = ConvertAcceptedSocket(result, socket);
+ } else {
+ pending_accept_ = true;
+ }
+
+ return result;
+}
+
+int TCPServerSocket::ConvertAcceptedSocket(
+ int result,
+ scoped_ptr<StreamSocket>* output_accepted_socket) {
+ // Make sure the TCPSocket object is destroyed in any case.
+ scoped_ptr<TCPSocket> temp_accepted_socket(accepted_socket_.Pass());
+ if (result != OK)
+ return result;
+
+ output_accepted_socket->reset(new TCPClientSocket(
+ temp_accepted_socket.Pass(), accepted_address_));
+
+ return OK;
+}
+
+void TCPServerSocket::OnAcceptCompleted(
+ scoped_ptr<StreamSocket>* output_accepted_socket,
+ const CompletionCallback& forward_callback,
+ int result) {
+ result = ConvertAcceptedSocket(result, output_accepted_socket);
+ pending_accept_ = false;
+ forward_callback.Run(result);
+}
+
+} // namespace net
diff --git a/chromium/net/socket/tcp_server_socket.h b/chromium/net/socket/tcp_server_socket.h
index 4970a150e8d..faff9ad826a 100644
--- a/chromium/net/socket/tcp_server_socket.h
+++ b/chromium/net/socket/tcp_server_socket.h
@@ -5,21 +5,48 @@
#ifndef NET_SOCKET_TCP_SERVER_SOCKET_H_
#define NET_SOCKET_TCP_SERVER_SOCKET_H_
-#include "build/build_config.h"
-
-#if defined(OS_WIN)
-#include "net/socket/tcp_server_socket_win.h"
-#elif defined(OS_POSIX)
-#include "net/socket/tcp_server_socket_libevent.h"
-#endif
+#include "base/basictypes.h"
+#include "base/compiler_specific.h"
+#include "base/memory/scoped_ptr.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/net_export.h"
+#include "net/base/net_log.h"
+#include "net/socket/server_socket.h"
+#include "net/socket/tcp_socket.h"
namespace net {
-#if defined(OS_WIN)
-typedef TCPServerSocketWin TCPServerSocket;
-#elif defined(OS_POSIX)
-typedef TCPServerSocketLibevent TCPServerSocket;
-#endif
+class NET_EXPORT_PRIVATE TCPServerSocket : public ServerSocket {
+ public:
+ TCPServerSocket(NetLog* net_log, const NetLog::Source& source);
+ virtual ~TCPServerSocket();
+
+ // net::ServerSocket implementation.
+ virtual int Listen(const IPEndPoint& address, int backlog) OVERRIDE;
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
+ virtual int Accept(scoped_ptr<StreamSocket>* socket,
+ const CompletionCallback& callback) OVERRIDE;
+
+ private:
+ // Converts |accepted_socket_| and stores the result in
+ // |output_accepted_socket|.
+ // |output_accepted_socket| is untouched on failure. But |accepted_socket_| is
+ // set to NULL in any case.
+ int ConvertAcceptedSocket(int result,
+ scoped_ptr<StreamSocket>* output_accepted_socket);
+ // Completion callback for calling TCPSocket::Accept().
+ void OnAcceptCompleted(scoped_ptr<StreamSocket>* output_accepted_socket,
+ const CompletionCallback& forward_callback,
+ int result);
+
+ TCPSocket socket_;
+
+ scoped_ptr<TCPSocket> accepted_socket_;
+ IPEndPoint accepted_address_;
+ bool pending_accept_;
+
+ DISALLOW_COPY_AND_ASSIGN(TCPServerSocket);
+};
} // namespace net
diff --git a/chromium/net/socket/tcp_server_socket_libevent.cc b/chromium/net/socket/tcp_server_socket_libevent.cc
deleted file mode 100644
index 38dda962f46..00000000000
--- a/chromium/net/socket/tcp_server_socket_libevent.cc
+++ /dev/null
@@ -1,223 +0,0 @@
-// Copyright (c) 2012 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include "net/socket/tcp_server_socket_libevent.h"
-
-#include <errno.h>
-#include <fcntl.h>
-#include <netdb.h>
-#include <sys/socket.h>
-
-#include "build/build_config.h"
-
-#if defined(OS_POSIX)
-#include <netinet/in.h>
-#endif
-
-#include "base/posix/eintr_wrapper.h"
-#include "net/base/ip_endpoint.h"
-#include "net/base/net_errors.h"
-#include "net/base/net_util.h"
-#include "net/socket/socket_net_log_params.h"
-#include "net/socket/tcp_client_socket.h"
-
-namespace net {
-
-namespace {
-
-const int kInvalidSocket = -1;
-
-} // namespace
-
-TCPServerSocketLibevent::TCPServerSocketLibevent(
- net::NetLog* net_log,
- const net::NetLog::Source& source)
- : socket_(kInvalidSocket),
- accept_socket_(NULL),
- net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) {
- net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE,
- source.ToEventParametersCallback());
-}
-
-TCPServerSocketLibevent::~TCPServerSocketLibevent() {
- if (socket_ != kInvalidSocket)
- Close();
- net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE);
-}
-
-int TCPServerSocketLibevent::Listen(const IPEndPoint& address, int backlog) {
- DCHECK(CalledOnValidThread());
- DCHECK_GT(backlog, 0);
- DCHECK_EQ(socket_, kInvalidSocket);
-
- socket_ = socket(address.GetSockAddrFamily(), SOCK_STREAM, IPPROTO_TCP);
- if (socket_ < 0) {
- PLOG(ERROR) << "socket() returned an error";
- return MapSystemError(errno);
- }
-
- if (SetNonBlocking(socket_)) {
- int result = MapSystemError(errno);
- Close();
- return result;
- }
-
- int result = SetSocketOptions();
- if (result != OK) {
- Close();
- return result;
- }
-
- SockaddrStorage storage;
- if (!address.ToSockAddr(storage.addr, &storage.addr_len)) {
- Close();
- return ERR_ADDRESS_INVALID;
- }
-
- result = bind(socket_, storage.addr, storage.addr_len);
- if (result < 0) {
- PLOG(ERROR) << "bind() returned an error";
- result = MapSystemError(errno);
- Close();
- return result;
- }
-
- result = listen(socket_, backlog);
- if (result < 0) {
- PLOG(ERROR) << "listen() returned an error";
- result = MapSystemError(errno);
- Close();
- return result;
- }
-
- return OK;
-}
-
-int TCPServerSocketLibevent::GetLocalAddress(IPEndPoint* address) const {
- DCHECK(CalledOnValidThread());
- DCHECK(address);
-
- SockaddrStorage storage;
- if (getsockname(socket_, storage.addr, &storage.addr_len) < 0)
- return MapSystemError(errno);
- if (!address->FromSockAddr(storage.addr, storage.addr_len))
- return ERR_FAILED;
-
- return OK;
-}
-
-int TCPServerSocketLibevent::Accept(
- scoped_ptr<StreamSocket>* socket, const CompletionCallback& callback) {
- DCHECK(CalledOnValidThread());
- DCHECK(socket);
- DCHECK(!callback.is_null());
- DCHECK(accept_callback_.is_null());
-
- net_log_.BeginEvent(NetLog::TYPE_TCP_ACCEPT);
-
- int result = AcceptInternal(socket);
-
- if (result == ERR_IO_PENDING) {
- if (!base::MessageLoopForIO::current()->WatchFileDescriptor(
- socket_, true, base::MessageLoopForIO::WATCH_READ,
- &accept_socket_watcher_, this)) {
- PLOG(ERROR) << "WatchFileDescriptor failed on read";
- return MapSystemError(errno);
- }
-
- accept_socket_ = socket;
- accept_callback_ = callback;
- }
-
- return result;
-}
-
-int TCPServerSocketLibevent::SetSocketOptions() {
- // SO_REUSEADDR is useful for server sockets to bind to a recently unbound
- // port. When a socket is closed, the end point changes its state to TIME_WAIT
- // and wait for 2 MSL (maximum segment lifetime) to ensure the remote peer
- // acknowledges its closure. For server sockets, it is usually safe to
- // bind to a TIME_WAIT end point immediately, which is a widely adopted
- // behavior.
- //
- // Note that on *nix, SO_REUSEADDR does not enable the TCP socket to bind to
- // an end point that is already bound by another socket. To do that one must
- // set SO_REUSEPORT instead. This option is not provided on Linux prior
- // to 3.9.
- //
- // SO_REUSEPORT is provided in MacOS X and iOS.
- int true_value = 1;
- int rv = setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &true_value,
- sizeof(true_value));
- if (rv < 0)
- return MapSystemError(errno);
- return OK;
-}
-
-int TCPServerSocketLibevent::AcceptInternal(
- scoped_ptr<StreamSocket>* socket) {
- SockaddrStorage storage;
- int new_socket = HANDLE_EINTR(accept(socket_,
- storage.addr,
- &storage.addr_len));
- if (new_socket < 0) {
- int net_error = MapSystemError(errno);
- if (net_error != ERR_IO_PENDING)
- net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, net_error);
- return net_error;
- }
-
- IPEndPoint address;
- if (!address.FromSockAddr(storage.addr, storage.addr_len)) {
- NOTREACHED();
- if (HANDLE_EINTR(close(new_socket)) < 0)
- PLOG(ERROR) << "close";
- net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, ERR_FAILED);
- return ERR_FAILED;
- }
- scoped_ptr<TCPClientSocket> tcp_socket(new TCPClientSocket(
- AddressList(address),
- net_log_.net_log(), net_log_.source()));
- int adopt_result = tcp_socket->AdoptSocket(new_socket);
- if (adopt_result != OK) {
- if (HANDLE_EINTR(close(new_socket)) < 0)
- PLOG(ERROR) << "close";
- net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, adopt_result);
- return adopt_result;
- }
- socket->reset(tcp_socket.release());
- net_log_.EndEvent(NetLog::TYPE_TCP_ACCEPT,
- CreateNetLogIPEndPointCallback(&address));
- return OK;
-}
-
-void TCPServerSocketLibevent::Close() {
- if (socket_ != kInvalidSocket) {
- bool ok = accept_socket_watcher_.StopWatchingFileDescriptor();
- DCHECK(ok);
- if (HANDLE_EINTR(close(socket_)) < 0)
- PLOG(ERROR) << "close";
- socket_ = kInvalidSocket;
- }
-}
-
-void TCPServerSocketLibevent::OnFileCanReadWithoutBlocking(int fd) {
- DCHECK(CalledOnValidThread());
-
- int result = AcceptInternal(accept_socket_);
- if (result != ERR_IO_PENDING) {
- accept_socket_ = NULL;
- bool ok = accept_socket_watcher_.StopWatchingFileDescriptor();
- DCHECK(ok);
- CompletionCallback callback = accept_callback_;
- accept_callback_.Reset();
- callback.Run(result);
- }
-}
-
-void TCPServerSocketLibevent::OnFileCanWriteWithoutBlocking(int fd) {
- NOTREACHED();
-}
-
-} // namespace net
diff --git a/chromium/net/socket/tcp_server_socket_libevent.h b/chromium/net/socket/tcp_server_socket_libevent.h
deleted file mode 100644
index fe69472a653..00000000000
--- a/chromium/net/socket/tcp_server_socket_libevent.h
+++ /dev/null
@@ -1,55 +0,0 @@
-// Copyright (c) 2011 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#ifndef NET_SOCKET_TCP_SERVER_SOCKET_LIBEVENT_H_
-#define NET_SOCKET_TCP_SERVER_SOCKET_LIBEVENT_H_
-
-#include "base/memory/scoped_ptr.h"
-#include "base/message_loop/message_loop.h"
-#include "base/threading/non_thread_safe.h"
-#include "net/base/completion_callback.h"
-#include "net/base/net_log.h"
-#include "net/socket/server_socket.h"
-
-namespace net {
-
-class IPEndPoint;
-
-class NET_EXPORT_PRIVATE TCPServerSocketLibevent :
- public ServerSocket,
- public base::NonThreadSafe,
- public base::MessageLoopForIO::Watcher {
- public:
- TCPServerSocketLibevent(net::NetLog* net_log,
- const net::NetLog::Source& source);
- virtual ~TCPServerSocketLibevent();
-
- // net::ServerSocket implementation.
- virtual int Listen(const net::IPEndPoint& address, int backlog) OVERRIDE;
- virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
- virtual int Accept(scoped_ptr<StreamSocket>* socket,
- const CompletionCallback& callback) OVERRIDE;
-
- // MessageLoopForIO::Watcher implementation.
- virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE;
- virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE;
-
- private:
- int SetSocketOptions();
- int AcceptInternal(scoped_ptr<StreamSocket>* socket);
- void Close();
-
- int socket_;
-
- base::MessageLoopForIO::FileDescriptorWatcher accept_socket_watcher_;
-
- scoped_ptr<StreamSocket>* accept_socket_;
- CompletionCallback accept_callback_;
-
- BoundNetLog net_log_;
-};
-
-} // namespace net
-
-#endif // NET_SOCKET_TCP_SERVER_SOCKET_LIBEVENT_H_
diff --git a/chromium/net/socket/tcp_server_socket_win.cc b/chromium/net/socket/tcp_server_socket_win.cc
deleted file mode 100644
index 0ac77be5e81..00000000000
--- a/chromium/net/socket/tcp_server_socket_win.cc
+++ /dev/null
@@ -1,217 +0,0 @@
-// Copyright (c) 2012 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include "net/socket/tcp_server_socket_win.h"
-
-#include <mstcpip.h>
-
-#include "net/base/ip_endpoint.h"
-#include "net/base/net_errors.h"
-#include "net/base/net_util.h"
-#include "net/base/winsock_init.h"
-#include "net/base/winsock_util.h"
-#include "net/socket/socket_net_log_params.h"
-#include "net/socket/tcp_client_socket.h"
-
-namespace net {
-
-TCPServerSocketWin::TCPServerSocketWin(net::NetLog* net_log,
- const net::NetLog::Source& source)
- : socket_(INVALID_SOCKET),
- socket_event_(WSA_INVALID_EVENT),
- accept_socket_(NULL),
- net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) {
- net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE,
- source.ToEventParametersCallback());
- EnsureWinsockInit();
-}
-
-TCPServerSocketWin::~TCPServerSocketWin() {
- Close();
- net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE);
-}
-
-int TCPServerSocketWin::Listen(const IPEndPoint& address, int backlog) {
- DCHECK(CalledOnValidThread());
- DCHECK_GT(backlog, 0);
- DCHECK_EQ(socket_, INVALID_SOCKET);
- DCHECK_EQ(socket_event_, WSA_INVALID_EVENT);
-
- socket_event_ = WSACreateEvent();
- if (socket_event_ == WSA_INVALID_EVENT) {
- PLOG(ERROR) << "WSACreateEvent()";
- return ERR_FAILED;
- }
-
- socket_ = socket(address.GetSockAddrFamily(), SOCK_STREAM, IPPROTO_TCP);
- if (socket_ == INVALID_SOCKET) {
- PLOG(ERROR) << "socket() returned an error";
- return MapSystemError(WSAGetLastError());
- }
-
- if (SetNonBlocking(socket_)) {
- int result = MapSystemError(WSAGetLastError());
- Close();
- return result;
- }
-
- int result = SetSocketOptions();
- if (result != OK) {
- Close();
- return result;
- }
-
- SockaddrStorage storage;
- if (!address.ToSockAddr(storage.addr, &storage.addr_len)) {
- Close();
- return ERR_ADDRESS_INVALID;
- }
-
- result = bind(socket_, storage.addr, storage.addr_len);
- if (result < 0) {
- PLOG(ERROR) << "bind() returned an error";
- result = MapSystemError(WSAGetLastError());
- Close();
- return result;
- }
-
- result = listen(socket_, backlog);
- if (result < 0) {
- PLOG(ERROR) << "listen() returned an error";
- result = MapSystemError(WSAGetLastError());
- Close();
- return result;
- }
-
- return OK;
-}
-
-int TCPServerSocketWin::GetLocalAddress(IPEndPoint* address) const {
- DCHECK(CalledOnValidThread());
- DCHECK(address);
-
- SockaddrStorage storage;
- if (getsockname(socket_, storage.addr, &storage.addr_len))
- return MapSystemError(WSAGetLastError());
- if (!address->FromSockAddr(storage.addr, storage.addr_len))
- return ERR_FAILED;
-
- return OK;
-}
-
-int TCPServerSocketWin::Accept(
- scoped_ptr<StreamSocket>* socket, const CompletionCallback& callback) {
- DCHECK(CalledOnValidThread());
- DCHECK(socket);
- DCHECK(!callback.is_null());
- DCHECK(accept_callback_.is_null());
-
- net_log_.BeginEvent(NetLog::TYPE_TCP_ACCEPT);
-
- int result = AcceptInternal(socket);
-
- if (result == ERR_IO_PENDING) {
- // Start watching
- WSAEventSelect(socket_, socket_event_, FD_ACCEPT);
- accept_watcher_.StartWatching(socket_event_, this);
-
- accept_socket_ = socket;
- accept_callback_ = callback;
- }
-
- return result;
-}
-
-int TCPServerSocketWin::SetSocketOptions() {
- // On Windows, a bound end point can be hijacked by another process by
- // setting SO_REUSEADDR. Therefore a Windows-only option SO_EXCLUSIVEADDRUSE
- // was introduced in Windows NT 4.0 SP4. If the socket that is bound to the
- // end point has SO_EXCLUSIVEADDRUSE enabled, it is not possible for another
- // socket to forcibly bind to the end point until the end point is unbound.
- // It is recommend that all server applications must use SO_EXCLUSIVEADDRUSE.
- // MSDN: http://goo.gl/M6fjQ.
- //
- // Unlike on *nix, on Windows a TCP server socket can always bind to an end
- // point in TIME_WAIT state without setting SO_REUSEADDR, therefore it is not
- // needed here.
- //
- // SO_EXCLUSIVEADDRUSE will prevent a TCP client socket from binding to an end
- // point in TIME_WAIT status. It does not have this effect for a TCP server
- // socket.
-
- BOOL true_value = 1;
- int rv = setsockopt(socket_, SOL_SOCKET, SO_EXCLUSIVEADDRUSE,
- reinterpret_cast<const char*>(&true_value),
- sizeof(true_value));
- if (rv < 0)
- return MapSystemError(errno);
- return OK;
-}
-
-int TCPServerSocketWin::AcceptInternal(scoped_ptr<StreamSocket>* socket) {
- SockaddrStorage storage;
- int new_socket = accept(socket_, storage.addr, &storage.addr_len);
- if (new_socket < 0) {
- int net_error = MapSystemError(WSAGetLastError());
- if (net_error != ERR_IO_PENDING)
- net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, net_error);
- return net_error;
- }
-
- IPEndPoint address;
- if (!address.FromSockAddr(storage.addr, storage.addr_len)) {
- NOTREACHED();
- if (closesocket(new_socket) < 0)
- PLOG(ERROR) << "closesocket";
- net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, ERR_FAILED);
- return ERR_FAILED;
- }
- scoped_ptr<TCPClientSocket> tcp_socket(new TCPClientSocket(
- AddressList(address),
- net_log_.net_log(), net_log_.source()));
- int adopt_result = tcp_socket->AdoptSocket(new_socket);
- if (adopt_result != OK) {
- if (closesocket(new_socket) < 0)
- PLOG(ERROR) << "closesocket";
- net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, adopt_result);
- return adopt_result;
- }
- socket->reset(tcp_socket.release());
- net_log_.EndEvent(NetLog::TYPE_TCP_ACCEPT,
- CreateNetLogIPEndPointCallback(&address));
- return OK;
-}
-
-void TCPServerSocketWin::Close() {
- if (socket_ != INVALID_SOCKET) {
- if (closesocket(socket_) < 0)
- PLOG(ERROR) << "closesocket";
- socket_ = INVALID_SOCKET;
- }
-
- if (socket_event_) {
- WSACloseEvent(socket_event_);
- socket_event_ = WSA_INVALID_EVENT;
- }
-}
-
-void TCPServerSocketWin::OnObjectSignaled(HANDLE object) {
- WSANETWORKEVENTS ev;
- if (WSAEnumNetworkEvents(socket_, socket_event_, &ev) == SOCKET_ERROR) {
- PLOG(ERROR) << "WSAEnumNetworkEvents()";
- return;
- }
-
- if (ev.lNetworkEvents & FD_ACCEPT) {
- int result = AcceptInternal(accept_socket_);
- if (result != ERR_IO_PENDING) {
- accept_socket_ = NULL;
- CompletionCallback callback = accept_callback_;
- accept_callback_.Reset();
- callback.Run(result);
- }
- }
-}
-
-} // namespace net
diff --git a/chromium/net/socket/tcp_server_socket_win.h b/chromium/net/socket/tcp_server_socket_win.h
deleted file mode 100644
index 5a1d378ad9b..00000000000
--- a/chromium/net/socket/tcp_server_socket_win.h
+++ /dev/null
@@ -1,58 +0,0 @@
-// Copyright (c) 2011 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#ifndef NET_SOCKET_TCP_SERVER_SOCKET_WIN_H_
-#define NET_SOCKET_TCP_SERVER_SOCKET_WIN_H_
-
-#include <winsock2.h>
-
-#include "base/memory/scoped_ptr.h"
-#include "base/message_loop/message_loop.h"
-#include "base/threading/non_thread_safe.h"
-#include "base/win/object_watcher.h"
-#include "net/base/completion_callback.h"
-#include "net/base/net_log.h"
-#include "net/socket/server_socket.h"
-
-namespace net {
-
-class IPEndPoint;
-
-class NET_EXPORT_PRIVATE TCPServerSocketWin
- : public ServerSocket,
- NON_EXPORTED_BASE(public base::NonThreadSafe),
- public base::win::ObjectWatcher::Delegate {
- public:
- TCPServerSocketWin(net::NetLog* net_log,
- const net::NetLog::Source& source);
- ~TCPServerSocketWin();
-
- // net::ServerSocket implementation.
- virtual int Listen(const net::IPEndPoint& address, int backlog) OVERRIDE;
- virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
- virtual int Accept(scoped_ptr<StreamSocket>* socket,
- const CompletionCallback& callback) OVERRIDE;
-
- // base::ObjectWatcher::Delegate implementation.
- virtual void OnObjectSignaled(HANDLE object);
-
- private:
- int SetSocketOptions();
- int AcceptInternal(scoped_ptr<StreamSocket>* socket);
- void Close();
-
- SOCKET socket_;
- HANDLE socket_event_;
-
- base::win::ObjectWatcher accept_watcher_;
-
- scoped_ptr<StreamSocket>* accept_socket_;
- CompletionCallback accept_callback_;
-
- BoundNetLog net_log_;
-};
-
-} // namespace net
-
-#endif // NET_SOCKET_TCP_SERVER_SOCKET_WIN_H_
diff --git a/chromium/net/socket/tcp_socket.cc b/chromium/net/socket/tcp_socket.cc
new file mode 100644
index 00000000000..fd72f6b4640
--- /dev/null
+++ b/chromium/net/socket/tcp_socket.cc
@@ -0,0 +1,59 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/tcp_socket.h"
+
+#include "base/file_util.h"
+#include "base/files/file_path.h"
+
+namespace net {
+
+namespace {
+
+#if defined(OS_LINUX)
+
+// Checks to see if the system supports TCP FastOpen. Notably, it requires
+// kernel support. Additionally, this checks system configuration to ensure that
+// it's enabled.
+bool SystemSupportsTCPFastOpen() {
+ static const base::FilePath::CharType kTCPFastOpenProcFilePath[] =
+ "/proc/sys/net/ipv4/tcp_fastopen";
+ std::string system_enabled_tcp_fastopen;
+ if (!base::ReadFileToString(
+ base::FilePath(kTCPFastOpenProcFilePath),
+ &system_enabled_tcp_fastopen)) {
+ return false;
+ }
+
+ // As per http://lxr.linux.no/linux+v3.7.7/include/net/tcp.h#L225
+ // TFO_CLIENT_ENABLE is the LSB
+ if (system_enabled_tcp_fastopen.empty() ||
+ (system_enabled_tcp_fastopen[0] & 0x1) == 0) {
+ return false;
+ }
+
+ return true;
+}
+
+#else
+
+bool SystemSupportsTCPFastOpen() {
+ return false;
+}
+
+#endif
+
+bool g_tcp_fastopen_enabled = false;
+
+} // namespace
+
+void SetTCPFastOpenEnabled(bool value) {
+ g_tcp_fastopen_enabled = value && SystemSupportsTCPFastOpen();
+}
+
+bool IsTCPFastOpenEnabled() {
+ return g_tcp_fastopen_enabled;
+}
+
+} // namespace net
diff --git a/chromium/net/socket/tcp_socket.h b/chromium/net/socket/tcp_socket.h
new file mode 100644
index 00000000000..8b36fade758
--- /dev/null
+++ b/chromium/net/socket/tcp_socket.h
@@ -0,0 +1,40 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_TCP_SOCKET_H_
+#define NET_SOCKET_TCP_SOCKET_H_
+
+#include "build/build_config.h"
+#include "net/base/net_export.h"
+
+#if defined(OS_WIN)
+#include "net/socket/tcp_socket_win.h"
+#elif defined(OS_POSIX)
+#include "net/socket/tcp_socket_libevent.h"
+#endif
+
+namespace net {
+
+// Enable/disable experimental TCP FastOpen option.
+// Not thread safe. Must be called during initialization/startup only.
+NET_EXPORT void SetTCPFastOpenEnabled(bool value);
+
+// Check if the TCP FastOpen option is enabled.
+bool IsTCPFastOpenEnabled();
+
+// TCPSocket provides a platform-independent interface for TCP sockets.
+//
+// It is recommended to use TCPClientSocket/TCPServerSocket instead of this
+// class, unless a clear separation of client and server socket functionality is
+// not suitable for your use case (e.g., a socket needs to be created and bound
+// before you know whether it is a client or server socket).
+#if defined(OS_WIN)
+typedef TCPSocketWin TCPSocket;
+#elif defined(OS_POSIX)
+typedef TCPSocketLibevent TCPSocket;
+#endif
+
+} // namespace net
+
+#endif // NET_SOCKET_TCP_SOCKET_H_
diff --git a/chromium/net/socket/tcp_client_socket_libevent.cc b/chromium/net/socket/tcp_socket_libevent.cc
index 2f7e4b4b255..66416f70207 100644
--- a/chromium/net/socket/tcp_client_socket_libevent.cc
+++ b/chromium/net/socket/tcp_socket_libevent.cc
@@ -1,29 +1,27 @@
-// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-#include "net/socket/tcp_client_socket.h"
+#include "net/socket/tcp_socket.h"
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
-#include <sys/socket.h>
-#include <netinet/tcp.h>
-#if defined(OS_POSIX)
#include <netinet/in.h>
-#endif
+#include <netinet/tcp.h>
+#include <sys/socket.h>
+#include "base/callback_helpers.h"
#include "base/logging.h"
-#include "base/message_loop/message_loop.h"
#include "base/metrics/histogram.h"
#include "base/metrics/stats_counters.h"
#include "base/posix/eintr_wrapper.h"
-#include "base/strings/string_util.h"
+#include "build/build_config.h"
+#include "net/base/address_list.h"
#include "net/base/connection_type_histograms.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
-#include "net/base/net_log.h"
#include "net/base/net_util.h"
#include "net/base/network_change_notifier.h"
#include "net/socket/socket_net_log_params.h"
@@ -37,7 +35,6 @@ namespace net {
namespace {
-const int kInvalidSocket = -1;
const int kTCPKeepAliveSeconds = 45;
// SetTCPNoDelay turns on/off buffering in the kernel. By default, TCP sockets
@@ -46,13 +43,12 @@ const int kTCPKeepAliveSeconds = 45;
// `man 7 tcp`.
bool SetTCPNoDelay(int fd, bool no_delay) {
int on = no_delay ? 1 : 0;
- int error = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &on,
- sizeof(on));
+ int error = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on));
return error == 0;
}
// SetTCPKeepAlive sets SO_KEEPALIVE.
-bool SetTCPKeepAlive(int fd, bool enable, int delay) {
+bool SetTCPKeepAlive(int fd, bool enable, int delay) {
int on = enable ? 1 : 0;
if (setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &on, sizeof(on))) {
PLOG(ERROR) << "Failed to set SO_KEEPALIVE on fd: " << fd;
@@ -73,36 +69,6 @@ bool SetTCPKeepAlive(int fd, bool enable, int delay) {
return true;
}
-// Sets socket parameters. Returns the OS error code (or 0 on
-// success).
-int SetupSocket(int socket) {
- if (SetNonBlocking(socket))
- return errno;
-
- // This mirrors the behaviour on Windows. See the comment in
- // tcp_client_socket_win.cc after searching for "NODELAY".
- SetTCPNoDelay(socket, true); // If SetTCPNoDelay fails, we don't care.
- SetTCPKeepAlive(socket, true, kTCPKeepAliveSeconds);
-
- return 0;
-}
-
-// Creates a new socket and sets default parameters for it. Returns
-// the OS error code (or 0 on success).
-int CreateSocket(int family, int* socket) {
- *socket = ::socket(family, SOCK_STREAM, IPPROTO_TCP);
- if (*socket == kInvalidSocket)
- return errno;
- int error = SetupSocket(*socket);
- if (error) {
- if (HANDLE_EINTR(close(*socket)) < 0)
- PLOG(ERROR) << "close";
- *socket = kInvalidSocket;
- return error;
- }
- return 0;
-}
-
int MapConnectError(int os_error) {
switch (os_error) {
case EACCES:
@@ -128,275 +94,206 @@ int MapConnectError(int os_error) {
//-----------------------------------------------------------------------------
-TCPClientSocketLibevent::TCPClientSocketLibevent(
- const AddressList& addresses,
- net::NetLog* net_log,
- const net::NetLog::Source& source)
+TCPSocketLibevent::Watcher::Watcher(
+ const base::Closure& read_ready_callback,
+ const base::Closure& write_ready_callback)
+ : read_ready_callback_(read_ready_callback),
+ write_ready_callback_(write_ready_callback) {
+}
+
+TCPSocketLibevent::Watcher::~Watcher() {
+}
+
+void TCPSocketLibevent::Watcher::OnFileCanReadWithoutBlocking(int /* fd */) {
+ if (!read_ready_callback_.is_null())
+ read_ready_callback_.Run();
+ else
+ NOTREACHED();
+}
+
+void TCPSocketLibevent::Watcher::OnFileCanWriteWithoutBlocking(int /* fd */) {
+ if (!write_ready_callback_.is_null())
+ write_ready_callback_.Run();
+ else
+ NOTREACHED();
+}
+
+TCPSocketLibevent::TCPSocketLibevent(NetLog* net_log,
+ const NetLog::Source& source)
: socket_(kInvalidSocket),
- bound_socket_(kInvalidSocket),
- addresses_(addresses),
- current_address_index_(-1),
- read_watcher_(this),
- write_watcher_(this),
- next_connect_state_(CONNECT_STATE_NONE),
- connect_os_error_(0),
- net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)),
- previously_disconnected_(false),
+ accept_watcher_(base::Bind(&TCPSocketLibevent::DidCompleteAccept,
+ base::Unretained(this)),
+ base::Closure()),
+ accept_socket_(NULL),
+ accept_address_(NULL),
+ read_watcher_(base::Bind(&TCPSocketLibevent::DidCompleteRead,
+ base::Unretained(this)),
+ base::Closure()),
+ write_watcher_(base::Closure(),
+ base::Bind(&TCPSocketLibevent::DidCompleteConnectOrWrite,
+ base::Unretained(this))),
+ read_buf_len_(0),
+ write_buf_len_(0),
use_tcp_fastopen_(IsTCPFastOpenEnabled()),
tcp_fastopen_connected_(false),
- fast_open_status_(FAST_OPEN_STATUS_UNKNOWN) {
+ fast_open_status_(FAST_OPEN_STATUS_UNKNOWN),
+ waiting_connect_(false),
+ connect_os_error_(0),
+ logging_multiple_connect_attempts_(false),
+ net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) {
net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE,
source.ToEventParametersCallback());
}
-TCPClientSocketLibevent::~TCPClientSocketLibevent() {
- Disconnect();
+TCPSocketLibevent::~TCPSocketLibevent() {
net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE);
if (tcp_fastopen_connected_) {
UMA_HISTOGRAM_ENUMERATION("Net.TcpFastOpenSocketConnection",
fast_open_status_, FAST_OPEN_MAX_VALUE);
}
+ Close();
}
-int TCPClientSocketLibevent::AdoptSocket(int socket) {
+int TCPSocketLibevent::Open(AddressFamily family) {
+ DCHECK(CalledOnValidThread());
DCHECK_EQ(socket_, kInvalidSocket);
- int error = SetupSocket(socket);
- if (error)
- return MapSystemError(error);
-
- socket_ = socket;
+ socket_ = CreatePlatformSocket(ConvertAddressFamily(family), SOCK_STREAM,
+ IPPROTO_TCP);
+ if (socket_ < 0) {
+ PLOG(ERROR) << "CreatePlatformSocket() returned an error";
+ return MapSystemError(errno);
+ }
- // This is to make GetPeerAddress() work. It's up to the caller ensure
- // that |address_| contains a reasonable address for this
- // socket. (i.e. at least match IPv4 vs IPv6!).
- current_address_index_ = 0;
- use_history_.set_was_ever_connected();
+ if (SetNonBlocking(socket_)) {
+ int result = MapSystemError(errno);
+ Close();
+ return result;
+ }
return OK;
}
-int TCPClientSocketLibevent::Bind(const IPEndPoint& address) {
- if (current_address_index_ >= 0 || bind_address_.get()) {
- // Cannot bind the socket if we are already bound connected or
- // connecting.
- return ERR_UNEXPECTED;
- }
-
- SockaddrStorage storage;
- if (!address.ToSockAddr(storage.addr, &storage.addr_len))
- return ERR_INVALID_ARGUMENT;
+int TCPSocketLibevent::AdoptConnectedSocket(int socket,
+ const IPEndPoint& peer_address) {
+ DCHECK(CalledOnValidThread());
+ DCHECK_EQ(socket_, kInvalidSocket);
- // Create |bound_socket_| and try to bind it to |address|.
- int error = CreateSocket(address.GetSockAddrFamily(), &bound_socket_);
- if (error)
- return MapSystemError(error);
+ socket_ = socket;
- if (HANDLE_EINTR(bind(bound_socket_, storage.addr, storage.addr_len))) {
- error = errno;
- if (HANDLE_EINTR(close(bound_socket_)) < 0)
- PLOG(ERROR) << "close";
- bound_socket_ = kInvalidSocket;
- return MapSystemError(error);
+ if (SetNonBlocking(socket_)) {
+ int result = MapSystemError(errno);
+ Close();
+ return result;
}
- bind_address_.reset(new IPEndPoint(address));
+ peer_address_.reset(new IPEndPoint(peer_address));
- return 0;
+ return OK;
}
-int TCPClientSocketLibevent::Connect(const CompletionCallback& callback) {
+int TCPSocketLibevent::Bind(const IPEndPoint& address) {
DCHECK(CalledOnValidThread());
+ DCHECK_NE(socket_, kInvalidSocket);
- // If already connected, then just return OK.
- if (socket_ != kInvalidSocket)
- return OK;
-
- base::StatsCounter connects("tcp.connect");
- connects.Increment();
-
- DCHECK(!waiting_connect());
-
- net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT,
- addresses_.CreateNetLogCallback());
-
- // We will try to connect to each address in addresses_. Start with the
- // first one in the list.
- next_connect_state_ = CONNECT_STATE_CONNECT;
- current_address_index_ = 0;
+ SockaddrStorage storage;
+ if (!address.ToSockAddr(storage.addr, &storage.addr_len))
+ return ERR_ADDRESS_INVALID;
- int rv = DoConnectLoop(OK);
- if (rv == ERR_IO_PENDING) {
- // Synchronous operation not supported.
- DCHECK(!callback.is_null());
- write_callback_ = callback;
- } else {
- LogConnectCompletion(rv);
+ int result = bind(socket_, storage.addr, storage.addr_len);
+ if (result < 0) {
+ PLOG(ERROR) << "bind() returned an error";
+ return MapSystemError(errno);
}
- return rv;
-}
-
-int TCPClientSocketLibevent::DoConnectLoop(int result) {
- DCHECK_NE(next_connect_state_, CONNECT_STATE_NONE);
-
- int rv = result;
- do {
- ConnectState state = next_connect_state_;
- next_connect_state_ = CONNECT_STATE_NONE;
- switch (state) {
- case CONNECT_STATE_CONNECT:
- DCHECK_EQ(OK, rv);
- rv = DoConnect();
- break;
- case CONNECT_STATE_CONNECT_COMPLETE:
- rv = DoConnectComplete(rv);
- break;
- default:
- LOG(DFATAL) << "bad state";
- rv = ERR_UNEXPECTED;
- break;
- }
- } while (rv != ERR_IO_PENDING && next_connect_state_ != CONNECT_STATE_NONE);
-
- return rv;
+ return OK;
}
-int TCPClientSocketLibevent::DoConnect() {
- DCHECK_GE(current_address_index_, 0);
- DCHECK_LT(current_address_index_, static_cast<int>(addresses_.size()));
- DCHECK_EQ(0, connect_os_error_);
-
- const IPEndPoint& endpoint = addresses_[current_address_index_];
+int TCPSocketLibevent::Listen(int backlog) {
+ DCHECK(CalledOnValidThread());
+ DCHECK_GT(backlog, 0);
+ DCHECK_NE(socket_, kInvalidSocket);
- if (previously_disconnected_) {
- use_history_.Reset();
- previously_disconnected_ = false;
+ int result = listen(socket_, backlog);
+ if (result < 0) {
+ PLOG(ERROR) << "listen() returned an error";
+ return MapSystemError(errno);
}
- net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT,
- CreateNetLogIPEndPointCallback(&endpoint));
+ return OK;
+}
- next_connect_state_ = CONNECT_STATE_CONNECT_COMPLETE;
+int TCPSocketLibevent::Accept(scoped_ptr<TCPSocketLibevent>* socket,
+ IPEndPoint* address,
+ const CompletionCallback& callback) {
+ DCHECK(CalledOnValidThread());
+ DCHECK(socket);
+ DCHECK(address);
+ DCHECK(!callback.is_null());
+ DCHECK(accept_callback_.is_null());
- if (bound_socket_ != kInvalidSocket) {
- DCHECK(bind_address_.get());
- socket_ = bound_socket_;
- bound_socket_ = kInvalidSocket;
- } else {
- // Create a non-blocking socket.
- connect_os_error_ = CreateSocket(endpoint.GetSockAddrFamily(), &socket_);
- if (connect_os_error_)
- return MapSystemError(connect_os_error_);
-
- if (bind_address_.get()) {
- SockaddrStorage storage;
- if (!bind_address_->ToSockAddr(storage.addr, &storage.addr_len))
- return ERR_INVALID_ARGUMENT;
- if (HANDLE_EINTR(bind(socket_, storage.addr, storage.addr_len)))
- return MapSystemError(errno);
- }
- }
+ net_log_.BeginEvent(NetLog::TYPE_TCP_ACCEPT);
- // Connect the socket.
- if (!use_tcp_fastopen_) {
- SockaddrStorage storage;
- if (!endpoint.ToSockAddr(storage.addr, &storage.addr_len))
- return ERR_INVALID_ARGUMENT;
+ int result = AcceptInternal(socket, address);
- if (!HANDLE_EINTR(connect(socket_, storage.addr, storage.addr_len))) {
- // Connected without waiting!
- return OK;
+ if (result == ERR_IO_PENDING) {
+ if (!base::MessageLoopForIO::current()->WatchFileDescriptor(
+ socket_, true, base::MessageLoopForIO::WATCH_READ,
+ &accept_socket_watcher_, &accept_watcher_)) {
+ PLOG(ERROR) << "WatchFileDescriptor failed on read";
+ return MapSystemError(errno);
}
- } else {
- // With TCP FastOpen, we pretend that the socket is connected.
- DCHECK(!tcp_fastopen_connected_);
- return OK;
- }
-
- // Check if the connect() failed synchronously.
- connect_os_error_ = errno;
- if (connect_os_error_ != EINPROGRESS)
- return MapConnectError(connect_os_error_);
- // Otherwise the connect() is going to complete asynchronously, so watch
- // for its completion.
- if (!base::MessageLoopForIO::current()->WatchFileDescriptor(
- socket_, true, base::MessageLoopForIO::WATCH_WRITE,
- &write_socket_watcher_, &write_watcher_)) {
- connect_os_error_ = errno;
- DVLOG(1) << "WatchFileDescriptor failed: " << connect_os_error_;
- return MapSystemError(connect_os_error_);
+ accept_socket_ = socket;
+ accept_address_ = address;
+ accept_callback_ = callback;
}
- return ERR_IO_PENDING;
-}
-
-int TCPClientSocketLibevent::DoConnectComplete(int result) {
- // Log the end of this attempt (and any OS error it threw).
- int os_error = connect_os_error_;
- connect_os_error_ = 0;
- if (result != OK) {
- net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT,
- NetLog::IntegerCallback("os_error", os_error));
- } else {
- net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT);
- }
-
- if (result == OK) {
- write_socket_watcher_.StopWatchingFileDescriptor();
- use_history_.set_was_ever_connected();
- return OK; // Done!
- }
-
- // Close whatever partially connected socket we currently have.
- DoDisconnect();
-
- // Try to fall back to the next address in the list.
- if (current_address_index_ + 1 < static_cast<int>(addresses_.size())) {
- next_connect_state_ = CONNECT_STATE_CONNECT;
- ++current_address_index_;
- return OK;
- }
-
- // Otherwise there is nothing to fall back to, so give up.
return result;
}
-void TCPClientSocketLibevent::Disconnect() {
+int TCPSocketLibevent::Connect(const IPEndPoint& address,
+ const CompletionCallback& callback) {
DCHECK(CalledOnValidThread());
+ DCHECK_NE(socket_, kInvalidSocket);
+ DCHECK(!waiting_connect_);
- DoDisconnect();
- current_address_index_ = -1;
- bind_address_.reset();
-}
+ // |peer_address_| will be non-NULL if Connect() has been called. Unless
+ // Close() is called to reset the internal state, a second call to Connect()
+ // is not allowed.
+ // Please note that we don't allow a second Connect() even if the previous
+ // Connect() has failed. Connecting the same |socket_| again after a
+ // connection attempt failed results in unspecified behavior according to
+ // POSIX.
+ DCHECK(!peer_address_);
-void TCPClientSocketLibevent::DoDisconnect() {
- if (socket_ == kInvalidSocket)
- return;
+ if (!logging_multiple_connect_attempts_)
+ LogConnectBegin(AddressList(address));
- bool ok = read_socket_watcher_.StopWatchingFileDescriptor();
- DCHECK(ok);
- ok = write_socket_watcher_.StopWatchingFileDescriptor();
- DCHECK(ok);
- if (HANDLE_EINTR(close(socket_)) < 0)
- PLOG(ERROR) << "close";
- socket_ = kInvalidSocket;
- previously_disconnected_ = true;
+ peer_address_.reset(new IPEndPoint(address));
+
+ int rv = DoConnect();
+ if (rv == ERR_IO_PENDING) {
+ // Synchronous operation not supported.
+ DCHECK(!callback.is_null());
+ write_callback_ = callback;
+ waiting_connect_ = true;
+ } else {
+ DoConnectComplete(rv);
+ }
+
+ return rv;
}
-bool TCPClientSocketLibevent::IsConnected() const {
+bool TCPSocketLibevent::IsConnected() const {
DCHECK(CalledOnValidThread());
- if (socket_ == kInvalidSocket || waiting_connect())
+ if (socket_ == kInvalidSocket || waiting_connect_)
return false;
- if (use_tcp_fastopen_ && !tcp_fastopen_connected_) {
+ if (use_tcp_fastopen_ && !tcp_fastopen_connected_ && peer_address_) {
// With TCP FastOpen, we pretend that the socket is connected.
- // This allows GetPeerAddress() to return current_ai_ as the peer
- // address. Since we don't fail over to the next address if
- // sendto() fails, current_ai_ is the only possible peer address.
- CHECK_LT(current_address_index_, static_cast<int>(addresses_.size()));
+ // This allows GetPeerAddress() to return peer_address_.
return true;
}
@@ -411,10 +308,10 @@ bool TCPClientSocketLibevent::IsConnected() const {
return true;
}
-bool TCPClientSocketLibevent::IsConnectedAndIdle() const {
+bool TCPSocketLibevent::IsConnectedAndIdle() const {
DCHECK(CalledOnValidThread());
- if (socket_ == kInvalidSocket || waiting_connect())
+ if (socket_ == kInvalidSocket || waiting_connect_)
return false;
// TODO(wtc): should we also handle the TCP FastOpen case here,
@@ -432,12 +329,12 @@ bool TCPClientSocketLibevent::IsConnectedAndIdle() const {
return true;
}
-int TCPClientSocketLibevent::Read(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) {
+int TCPSocketLibevent::Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
DCHECK(CalledOnValidThread());
DCHECK_NE(kInvalidSocket, socket_);
- DCHECK(!waiting_connect());
+ DCHECK(!waiting_connect_);
DCHECK(read_callback_.is_null());
// Synchronous operation not supported
DCHECK(!callback.is_null());
@@ -447,8 +344,6 @@ int TCPClientSocketLibevent::Read(IOBuffer* buf,
if (nread >= 0) {
base::StatsCounter read_bytes("tcp.read_bytes");
read_bytes.Add(nread);
- if (nread > 0)
- use_history_.set_was_used_to_convey_data();
net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, nread,
buf->data());
RecordFastOpenStatus();
@@ -474,12 +369,12 @@ int TCPClientSocketLibevent::Read(IOBuffer* buf,
return ERR_IO_PENDING;
}
-int TCPClientSocketLibevent::Write(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) {
+int TCPSocketLibevent::Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
DCHECK(CalledOnValidThread());
DCHECK_NE(kInvalidSocket, socket_);
- DCHECK(!waiting_connect());
+ DCHECK(!waiting_connect_);
DCHECK(write_callback_.is_null());
// Synchronous operation not supported
DCHECK(!callback.is_null());
@@ -489,8 +384,6 @@ int TCPClientSocketLibevent::Write(IOBuffer* buf,
if (nwrite >= 0) {
base::StatsCounter write_bytes("tcp.write_bytes");
write_bytes.Add(nwrite);
- if (nwrite > 0)
- use_history_.set_was_used_to_convey_data();
net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, nwrite,
buf->data());
return nwrite;
@@ -515,56 +408,67 @@ int TCPClientSocketLibevent::Write(IOBuffer* buf,
return ERR_IO_PENDING;
}
-int TCPClientSocketLibevent::InternalWrite(IOBuffer* buf, int buf_len) {
- int nwrite;
- if (use_tcp_fastopen_ && !tcp_fastopen_connected_) {
- SockaddrStorage storage;
- if (!addresses_[current_address_index_].ToSockAddr(storage.addr,
- &storage.addr_len)) {
- errno = EINVAL;
- return -1;
- }
+int TCPSocketLibevent::GetLocalAddress(IPEndPoint* address) const {
+ DCHECK(CalledOnValidThread());
+ DCHECK(address);
- int flags = 0x20000000; // Magic flag to enable TCP_FASTOPEN.
-#if defined(OS_LINUX)
- // sendto() will fail with EPIPE when the system doesn't support TCP Fast
- // Open. Theoretically that shouldn't happen since the caller should check
- // for system support on startup, but users may dynamically disable TCP Fast
- // Open via sysctl.
- flags |= MSG_NOSIGNAL;
-#endif // defined(OS_LINUX)
- nwrite = HANDLE_EINTR(sendto(socket_,
- buf->data(),
- buf_len,
- flags,
- storage.addr,
- storage.addr_len));
- tcp_fastopen_connected_ = true;
+ SockaddrStorage storage;
+ if (getsockname(socket_, storage.addr, &storage.addr_len) < 0)
+ return MapSystemError(errno);
+ if (!address->FromSockAddr(storage.addr, storage.addr_len))
+ return ERR_ADDRESS_INVALID;
- if (nwrite < 0) {
- DCHECK_NE(EPIPE, errno);
+ return OK;
+}
- // If errno == EINPROGRESS, that means the kernel didn't have a cookie
- // and would block. The kernel is internally doing a connect() though.
- // Remap EINPROGRESS to EAGAIN so we treat this the same as our other
- // asynchronous cases. Note that the user buffer has not been copied to
- // kernel space.
- if (errno == EINPROGRESS) {
- errno = EAGAIN;
- fast_open_status_ = FAST_OPEN_SLOW_CONNECT_RETURN;
- } else {
- fast_open_status_ = FAST_OPEN_ERROR;
- }
- } else {
- fast_open_status_ = FAST_OPEN_FAST_CONNECT_RETURN;
- }
- } else {
- nwrite = HANDLE_EINTR(write(socket_, buf->data(), buf_len));
- }
- return nwrite;
+int TCPSocketLibevent::GetPeerAddress(IPEndPoint* address) const {
+ DCHECK(CalledOnValidThread());
+ DCHECK(address);
+ if (!IsConnected())
+ return ERR_SOCKET_NOT_CONNECTED;
+ *address = *peer_address_;
+ return OK;
}
-bool TCPClientSocketLibevent::SetReceiveBufferSize(int32 size) {
+int TCPSocketLibevent::SetDefaultOptionsForServer() {
+ DCHECK(CalledOnValidThread());
+ return SetAddressReuse(true);
+}
+
+void TCPSocketLibevent::SetDefaultOptionsForClient() {
+ DCHECK(CalledOnValidThread());
+
+ // This mirrors the behaviour on Windows. See the comment in
+ // tcp_socket_win.cc after searching for "NODELAY".
+ SetTCPNoDelay(socket_, true); // If SetTCPNoDelay fails, we don't care.
+ SetTCPKeepAlive(socket_, true, kTCPKeepAliveSeconds);
+}
+
+int TCPSocketLibevent::SetAddressReuse(bool allow) {
+ DCHECK(CalledOnValidThread());
+
+ // SO_REUSEADDR is useful for server sockets to bind to a recently unbound
+ // port. When a socket is closed, the end point changes its state to TIME_WAIT
+ // and wait for 2 MSL (maximum segment lifetime) to ensure the remote peer
+ // acknowledges its closure. For server sockets, it is usually safe to
+ // bind to a TIME_WAIT end point immediately, which is a widely adopted
+ // behavior.
+ //
+ // Note that on *nix, SO_REUSEADDR does not enable the TCP socket to bind to
+ // an end point that is already bound by another socket. To do that one must
+ // set SO_REUSEPORT instead. This option is not provided on Linux prior
+ // to 3.9.
+ //
+ // SO_REUSEPORT is provided in MacOS X and iOS.
+ int boolean_value = allow ? 1 : 0;
+ int rv = setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &boolean_value,
+ sizeof(boolean_value));
+ if (rv < 0)
+ return MapSystemError(errno);
+ return OK;
+}
+
+bool TCPSocketLibevent::SetReceiveBufferSize(int32 size) {
DCHECK(CalledOnValidThread());
int rv = setsockopt(socket_, SOL_SOCKET, SO_RCVBUF,
reinterpret_cast<const char*>(&size),
@@ -573,7 +477,7 @@ bool TCPClientSocketLibevent::SetReceiveBufferSize(int32 size) {
return rv == 0;
}
-bool TCPClientSocketLibevent::SetSendBufferSize(int32 size) {
+bool TCPSocketLibevent::SetSendBufferSize(int32 size) {
DCHECK(CalledOnValidThread());
int rv = setsockopt(socket_, SOL_SOCKET, SO_SNDBUF,
reinterpret_cast<const char*>(&size),
@@ -582,31 +486,180 @@ bool TCPClientSocketLibevent::SetSendBufferSize(int32 size) {
return rv == 0;
}
-bool TCPClientSocketLibevent::SetKeepAlive(bool enable, int delay) {
- int socket = socket_ != kInvalidSocket ? socket_ : bound_socket_;
- return SetTCPKeepAlive(socket, enable, delay);
+bool TCPSocketLibevent::SetKeepAlive(bool enable, int delay) {
+ DCHECK(CalledOnValidThread());
+ return SetTCPKeepAlive(socket_, enable, delay);
}
-bool TCPClientSocketLibevent::SetNoDelay(bool no_delay) {
- int socket = socket_ != kInvalidSocket ? socket_ : bound_socket_;
- return SetTCPNoDelay(socket, no_delay);
+bool TCPSocketLibevent::SetNoDelay(bool no_delay) {
+ DCHECK(CalledOnValidThread());
+ return SetTCPNoDelay(socket_, no_delay);
}
-void TCPClientSocketLibevent::ReadWatcher::OnFileCanReadWithoutBlocking(int) {
- socket_->RecordFastOpenStatus();
- if (!socket_->read_callback_.is_null())
- socket_->DidCompleteRead();
+void TCPSocketLibevent::Close() {
+ DCHECK(CalledOnValidThread());
+
+ bool ok = accept_socket_watcher_.StopWatchingFileDescriptor();
+ DCHECK(ok);
+ ok = read_socket_watcher_.StopWatchingFileDescriptor();
+ DCHECK(ok);
+ ok = write_socket_watcher_.StopWatchingFileDescriptor();
+ DCHECK(ok);
+
+ if (socket_ != kInvalidSocket) {
+ if (HANDLE_EINTR(close(socket_)) < 0)
+ PLOG(ERROR) << "close";
+ socket_ = kInvalidSocket;
+ }
+
+ if (!accept_callback_.is_null()) {
+ accept_socket_ = NULL;
+ accept_address_ = NULL;
+ accept_callback_.Reset();
+ }
+
+ if (!read_callback_.is_null()) {
+ read_buf_ = NULL;
+ read_buf_len_ = 0;
+ read_callback_.Reset();
+ }
+
+ if (!write_callback_.is_null()) {
+ write_buf_ = NULL;
+ write_buf_len_ = 0;
+ write_callback_.Reset();
+ }
+
+ tcp_fastopen_connected_ = false;
+ fast_open_status_ = FAST_OPEN_STATUS_UNKNOWN;
+ waiting_connect_ = false;
+ peer_address_.reset();
+ connect_os_error_ = 0;
+}
+
+bool TCPSocketLibevent::UsingTCPFastOpen() const {
+ return use_tcp_fastopen_;
+}
+
+void TCPSocketLibevent::StartLoggingMultipleConnectAttempts(
+ const AddressList& addresses) {
+ if (!logging_multiple_connect_attempts_) {
+ logging_multiple_connect_attempts_ = true;
+ LogConnectBegin(addresses);
+ } else {
+ NOTREACHED();
+ }
+}
+
+void TCPSocketLibevent::EndLoggingMultipleConnectAttempts(int net_error) {
+ if (logging_multiple_connect_attempts_) {
+ LogConnectEnd(net_error);
+ logging_multiple_connect_attempts_ = false;
+ } else {
+ NOTREACHED();
+ }
+}
+
+int TCPSocketLibevent::AcceptInternal(scoped_ptr<TCPSocketLibevent>* socket,
+ IPEndPoint* address) {
+ SockaddrStorage storage;
+ int new_socket = HANDLE_EINTR(accept(socket_,
+ storage.addr,
+ &storage.addr_len));
+ if (new_socket < 0) {
+ int net_error = MapSystemError(errno);
+ if (net_error != ERR_IO_PENDING)
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, net_error);
+ return net_error;
+ }
+
+ IPEndPoint ip_end_point;
+ if (!ip_end_point.FromSockAddr(storage.addr, storage.addr_len)) {
+ NOTREACHED();
+ if (HANDLE_EINTR(close(new_socket)) < 0)
+ PLOG(ERROR) << "close";
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT,
+ ERR_ADDRESS_INVALID);
+ return ERR_ADDRESS_INVALID;
+ }
+ scoped_ptr<TCPSocketLibevent> tcp_socket(new TCPSocketLibevent(
+ net_log_.net_log(), net_log_.source()));
+ int adopt_result = tcp_socket->AdoptConnectedSocket(new_socket, ip_end_point);
+ if (adopt_result != OK) {
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, adopt_result);
+ return adopt_result;
+ }
+ *socket = tcp_socket.Pass();
+ *address = ip_end_point;
+ net_log_.EndEvent(NetLog::TYPE_TCP_ACCEPT,
+ CreateNetLogIPEndPointCallback(&ip_end_point));
+ return OK;
+}
+
+int TCPSocketLibevent::DoConnect() {
+ DCHECK_EQ(0, connect_os_error_);
+
+ net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT,
+ CreateNetLogIPEndPointCallback(peer_address_.get()));
+
+ // Connect the socket.
+ if (!use_tcp_fastopen_) {
+ SockaddrStorage storage;
+ if (!peer_address_->ToSockAddr(storage.addr, &storage.addr_len))
+ return ERR_INVALID_ARGUMENT;
+
+ if (!HANDLE_EINTR(connect(socket_, storage.addr, storage.addr_len))) {
+ // Connected without waiting!
+ return OK;
+ }
+ } else {
+ // With TCP FastOpen, we pretend that the socket is connected.
+ DCHECK(!tcp_fastopen_connected_);
+ return OK;
+ }
+
+ // Check if the connect() failed synchronously.
+ connect_os_error_ = errno;
+ if (connect_os_error_ != EINPROGRESS)
+ return MapConnectError(connect_os_error_);
+
+ // Otherwise the connect() is going to complete asynchronously, so watch
+ // for its completion.
+ if (!base::MessageLoopForIO::current()->WatchFileDescriptor(
+ socket_, true, base::MessageLoopForIO::WATCH_WRITE,
+ &write_socket_watcher_, &write_watcher_)) {
+ connect_os_error_ = errno;
+ DVLOG(1) << "WatchFileDescriptor failed: " << connect_os_error_;
+ return MapSystemError(connect_os_error_);
+ }
+
+ return ERR_IO_PENDING;
}
-void TCPClientSocketLibevent::WriteWatcher::OnFileCanWriteWithoutBlocking(int) {
- if (socket_->waiting_connect()) {
- socket_->DidCompleteConnect();
- } else if (!socket_->write_callback_.is_null()) {
- socket_->DidCompleteWrite();
+void TCPSocketLibevent::DoConnectComplete(int result) {
+ // Log the end of this attempt (and any OS error it threw).
+ int os_error = connect_os_error_;
+ connect_os_error_ = 0;
+ if (result != OK) {
+ net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT,
+ NetLog::IntegerCallback("os_error", os_error));
+ } else {
+ net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT);
}
+
+ if (!logging_multiple_connect_attempts_)
+ LogConnectEnd(result);
}
-void TCPClientSocketLibevent::LogConnectCompletion(int net_error) {
+void TCPSocketLibevent::LogConnectBegin(const AddressList& addresses) {
+ base::StatsCounter connects("tcp.connect");
+ connects.Increment();
+
+ net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT,
+ addresses.CreateNetLogCallback());
+}
+
+void TCPSocketLibevent::LogConnectEnd(int net_error) {
if (net_error == OK)
UpdateConnectionTypeHistograms(CONNECTION_ANY);
@@ -629,50 +682,11 @@ void TCPClientSocketLibevent::LogConnectCompletion(int net_error) {
storage.addr_len));
}
-void TCPClientSocketLibevent::DoReadCallback(int rv) {
- DCHECK_NE(rv, ERR_IO_PENDING);
- DCHECK(!read_callback_.is_null());
-
- // since Run may result in Read being called, clear read_callback_ up front.
- CompletionCallback c = read_callback_;
- read_callback_.Reset();
- c.Run(rv);
-}
-
-void TCPClientSocketLibevent::DoWriteCallback(int rv) {
- DCHECK_NE(rv, ERR_IO_PENDING);
- DCHECK(!write_callback_.is_null());
-
- // since Run may result in Write being called, clear write_callback_ up front.
- CompletionCallback c = write_callback_;
- write_callback_.Reset();
- c.Run(rv);
-}
-
-void TCPClientSocketLibevent::DidCompleteConnect() {
- DCHECK_EQ(next_connect_state_, CONNECT_STATE_CONNECT_COMPLETE);
-
- // Get the error that connect() completed with.
- int os_error = 0;
- socklen_t len = sizeof(os_error);
- if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &os_error, &len) < 0)
- os_error = errno;
-
- // TODO(eroman): Is this check really necessary?
- if (os_error == EINPROGRESS || os_error == EALREADY) {
- NOTREACHED(); // This indicates a bug in libevent or our code.
+void TCPSocketLibevent::DidCompleteRead() {
+ RecordFastOpenStatus();
+ if (read_callback_.is_null())
return;
- }
- connect_os_error_ = os_error;
- int rv = DoConnectLoop(MapConnectError(os_error));
- if (rv != ERR_IO_PENDING) {
- LogConnectCompletion(rv);
- DoWriteCallback(rv);
- }
-}
-
-void TCPClientSocketLibevent::DidCompleteRead() {
int bytes_transferred;
bytes_transferred = HANDLE_EINTR(read(socket_, read_buf_->data(),
read_buf_len_));
@@ -682,8 +696,6 @@ void TCPClientSocketLibevent::DidCompleteRead() {
result = bytes_transferred;
base::StatsCounter read_bytes("tcp.read_bytes");
read_bytes.Add(bytes_transferred);
- if (bytes_transferred > 0)
- use_history_.set_was_used_to_convey_data();
net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, result,
read_buf_->data());
} else {
@@ -699,11 +711,14 @@ void TCPClientSocketLibevent::DidCompleteRead() {
read_buf_len_ = 0;
bool ok = read_socket_watcher_.StopWatchingFileDescriptor();
DCHECK(ok);
- DoReadCallback(result);
+ base::ResetAndReturn(&read_callback_).Run(result);
}
}
-void TCPClientSocketLibevent::DidCompleteWrite() {
+void TCPSocketLibevent::DidCompleteWrite() {
+ if (write_callback_.is_null())
+ return;
+
int bytes_transferred;
bytes_transferred = HANDLE_EINTR(write(socket_, write_buf_->data(),
write_buf_len_));
@@ -713,8 +728,6 @@ void TCPClientSocketLibevent::DidCompleteWrite() {
result = bytes_transferred;
base::StatsCounter write_bytes("tcp.write_bytes");
write_bytes.Add(bytes_transferred);
- if (bytes_transferred > 0)
- use_history_.set_was_used_to_convey_data();
net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, result,
write_buf_->data());
} else {
@@ -729,40 +742,100 @@ void TCPClientSocketLibevent::DidCompleteWrite() {
write_buf_ = NULL;
write_buf_len_ = 0;
write_socket_watcher_.StopWatchingFileDescriptor();
- DoWriteCallback(result);
+ base::ResetAndReturn(&write_callback_).Run(result);
}
}
-int TCPClientSocketLibevent::GetPeerAddress(IPEndPoint* address) const {
- DCHECK(CalledOnValidThread());
- DCHECK(address);
- if (!IsConnected())
- return ERR_SOCKET_NOT_CONNECTED;
- *address = addresses_[current_address_index_];
- return OK;
+void TCPSocketLibevent::DidCompleteConnect() {
+ DCHECK(waiting_connect_);
+
+ // Get the error that connect() completed with.
+ int os_error = 0;
+ socklen_t len = sizeof(os_error);
+ if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &os_error, &len) < 0)
+ os_error = errno;
+
+ int result = MapConnectError(os_error);
+ connect_os_error_ = os_error;
+ if (result != ERR_IO_PENDING) {
+ DoConnectComplete(result);
+ waiting_connect_ = false;
+ write_socket_watcher_.StopWatchingFileDescriptor();
+ base::ResetAndReturn(&write_callback_).Run(result);
+ }
+}
+
+void TCPSocketLibevent::DidCompleteConnectOrWrite() {
+ if (waiting_connect_)
+ DidCompleteConnect();
+ else
+ DidCompleteWrite();
}
-int TCPClientSocketLibevent::GetLocalAddress(IPEndPoint* address) const {
+void TCPSocketLibevent::DidCompleteAccept() {
DCHECK(CalledOnValidThread());
- DCHECK(address);
- if (socket_ == kInvalidSocket) {
- if (bind_address_.get()) {
- *address = *bind_address_;
- return OK;
- }
- return ERR_SOCKET_NOT_CONNECTED;
+
+ int result = AcceptInternal(accept_socket_, accept_address_);
+ if (result != ERR_IO_PENDING) {
+ accept_socket_ = NULL;
+ accept_address_ = NULL;
+ bool ok = accept_socket_watcher_.StopWatchingFileDescriptor();
+ DCHECK(ok);
+ CompletionCallback callback = accept_callback_;
+ accept_callback_.Reset();
+ callback.Run(result);
}
+}
- SockaddrStorage storage;
- if (getsockname(socket_, storage.addr, &storage.addr_len))
- return MapSystemError(errno);
- if (!address->FromSockAddr(storage.addr, storage.addr_len))
- return ERR_FAILED;
+int TCPSocketLibevent::InternalWrite(IOBuffer* buf, int buf_len) {
+ int nwrite;
+ if (use_tcp_fastopen_ && !tcp_fastopen_connected_) {
+ SockaddrStorage storage;
+ if (!peer_address_->ToSockAddr(storage.addr, &storage.addr_len)) {
+ errno = EINVAL;
+ return -1;
+ }
- return OK;
+ int flags = 0x20000000; // Magic flag to enable TCP_FASTOPEN.
+#if defined(OS_LINUX)
+ // sendto() will fail with EPIPE when the system doesn't support TCP Fast
+ // Open. Theoretically that shouldn't happen since the caller should check
+ // for system support on startup, but users may dynamically disable TCP Fast
+ // Open via sysctl.
+ flags |= MSG_NOSIGNAL;
+#endif // defined(OS_LINUX)
+ nwrite = HANDLE_EINTR(sendto(socket_,
+ buf->data(),
+ buf_len,
+ flags,
+ storage.addr,
+ storage.addr_len));
+ tcp_fastopen_connected_ = true;
+
+ if (nwrite < 0) {
+ DCHECK_NE(EPIPE, errno);
+
+ // If errno == EINPROGRESS, that means the kernel didn't have a cookie
+ // and would block. The kernel is internally doing a connect() though.
+ // Remap EINPROGRESS to EAGAIN so we treat this the same as our other
+ // asynchronous cases. Note that the user buffer has not been copied to
+ // kernel space.
+ if (errno == EINPROGRESS) {
+ errno = EAGAIN;
+ fast_open_status_ = FAST_OPEN_SLOW_CONNECT_RETURN;
+ } else {
+ fast_open_status_ = FAST_OPEN_ERROR;
+ }
+ } else {
+ fast_open_status_ = FAST_OPEN_FAST_CONNECT_RETURN;
+ }
+ } else {
+ nwrite = HANDLE_EINTR(write(socket_, buf->data(), buf_len));
+ }
+ return nwrite;
}
-void TCPClientSocketLibevent::RecordFastOpenStatus() {
+void TCPSocketLibevent::RecordFastOpenStatus() {
if (use_tcp_fastopen_ &&
(fast_open_status_ == FAST_OPEN_FAST_CONNECT_RETURN ||
fast_open_status_ == FAST_OPEN_SLOW_CONNECT_RETURN)) {
@@ -795,36 +868,4 @@ void TCPClientSocketLibevent::RecordFastOpenStatus() {
}
}
-const BoundNetLog& TCPClientSocketLibevent::NetLog() const {
- return net_log_;
-}
-
-void TCPClientSocketLibevent::SetSubresourceSpeculation() {
- use_history_.set_subresource_speculation();
-}
-
-void TCPClientSocketLibevent::SetOmniboxSpeculation() {
- use_history_.set_omnibox_speculation();
-}
-
-bool TCPClientSocketLibevent::WasEverUsed() const {
- return use_history_.was_used_to_convey_data();
-}
-
-bool TCPClientSocketLibevent::UsingTCPFastOpen() const {
- return use_tcp_fastopen_;
-}
-
-bool TCPClientSocketLibevent::WasNpnNegotiated() const {
- return false;
-}
-
-NextProto TCPClientSocketLibevent::GetNegotiatedProtocol() const {
- return kProtoUnknown;
-}
-
-bool TCPClientSocketLibevent::GetSSLInfo(SSLInfo* ssl_info) {
- return false;
-}
-
} // namespace net
diff --git a/chromium/net/socket/tcp_socket_libevent.h b/chromium/net/socket/tcp_socket_libevent.h
new file mode 100644
index 00000000000..a50caf0ad59
--- /dev/null
+++ b/chromium/net/socket/tcp_socket_libevent.h
@@ -0,0 +1,235 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_TCP_SOCKET_LIBEVENT_H_
+#define NET_SOCKET_TCP_SOCKET_LIBEVENT_H_
+
+#include "base/basictypes.h"
+#include "base/callback.h"
+#include "base/compiler_specific.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/message_loop/message_loop.h"
+#include "base/threading/non_thread_safe.h"
+#include "net/base/address_family.h"
+#include "net/base/completion_callback.h"
+#include "net/base/net_export.h"
+#include "net/base/net_log.h"
+#include "net/socket/socket_descriptor.h"
+
+namespace net {
+
+class AddressList;
+class IOBuffer;
+class IPEndPoint;
+
+class NET_EXPORT TCPSocketLibevent : public base::NonThreadSafe {
+ public:
+ TCPSocketLibevent(NetLog* net_log, const NetLog::Source& source);
+ virtual ~TCPSocketLibevent();
+
+ int Open(AddressFamily family);
+ // Takes ownership of |socket|.
+ int AdoptConnectedSocket(int socket, const IPEndPoint& peer_address);
+
+ int Bind(const IPEndPoint& address);
+
+ int Listen(int backlog);
+ int Accept(scoped_ptr<TCPSocketLibevent>* socket,
+ IPEndPoint* address,
+ const CompletionCallback& callback);
+
+ int Connect(const IPEndPoint& address, const CompletionCallback& callback);
+ bool IsConnected() const;
+ bool IsConnectedAndIdle() const;
+
+ // Multiple outstanding requests are not supported.
+ // Full duplex mode (reading and writing at the same time) is supported.
+ int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback);
+ int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback);
+
+ int GetLocalAddress(IPEndPoint* address) const;
+ int GetPeerAddress(IPEndPoint* address) const;
+
+ // Sets various socket options.
+ // The commonly used options for server listening sockets:
+ // - SetAddressReuse(true).
+ int SetDefaultOptionsForServer();
+ // The commonly used options for client sockets and accepted sockets:
+ // - SetNoDelay(true);
+ // - SetKeepAlive(true, 45).
+ void SetDefaultOptionsForClient();
+ int SetAddressReuse(bool allow);
+ bool SetReceiveBufferSize(int32 size);
+ bool SetSendBufferSize(int32 size);
+ bool SetKeepAlive(bool enable, int delay);
+ bool SetNoDelay(bool no_delay);
+
+ void Close();
+
+ bool UsingTCPFastOpen() const;
+ bool IsValid() const { return socket_ != kInvalidSocket; }
+
+ // Marks the start/end of a series of connect attempts for logging purpose.
+ //
+ // TCPClientSocket may attempt to connect to multiple addresses until it
+ // succeeds in establishing a connection. The corresponding log will have
+ // multiple NetLog::TYPE_TCP_CONNECT_ATTEMPT entries nested within a
+ // NetLog::TYPE_TCP_CONNECT. These methods set the start/end of
+ // NetLog::TYPE_TCP_CONNECT.
+ //
+ // TODO(yzshen): Change logging format and let TCPClientSocket log the
+ // start/end of a series of connect attempts itself.
+ void StartLoggingMultipleConnectAttempts(const AddressList& addresses);
+ void EndLoggingMultipleConnectAttempts(int net_error);
+
+ const BoundNetLog& net_log() const { return net_log_; }
+
+ private:
+ // States that a fast open socket attempt can result in.
+ enum FastOpenStatus {
+ FAST_OPEN_STATUS_UNKNOWN,
+
+ // The initial fast open connect attempted returned synchronously,
+ // indicating that we had and sent a cookie along with the initial data.
+ FAST_OPEN_FAST_CONNECT_RETURN,
+
+ // The initial fast open connect attempted returned asynchronously,
+ // indicating that we did not have a cookie for the server.
+ FAST_OPEN_SLOW_CONNECT_RETURN,
+
+ // Some other error occurred on connection, so we couldn't tell if
+ // fast open would have worked.
+ FAST_OPEN_ERROR,
+
+ // An attempt to do a fast open succeeded immediately
+ // (FAST_OPEN_FAST_CONNECT_RETURN) and we later confirmed that the server
+ // had acked the data we sent.
+ FAST_OPEN_SYN_DATA_ACK,
+
+ // An attempt to do a fast open succeeded immediately
+ // (FAST_OPEN_FAST_CONNECT_RETURN) and we later confirmed that the server
+ // had nacked the data we sent.
+ FAST_OPEN_SYN_DATA_NACK,
+
+ // An attempt to do a fast open succeeded immediately
+ // (FAST_OPEN_FAST_CONNECT_RETURN) and our probe to determine if the
+ // socket was using fast open failed.
+ FAST_OPEN_SYN_DATA_FAILED,
+
+ // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN)
+ // and we later confirmed that the server had acked initial data. This
+ // should never happen (we didn't send data, so it shouldn't have
+ // been acked).
+ FAST_OPEN_NO_SYN_DATA_ACK,
+
+ // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN)
+ // and we later discovered that the server had nacked initial data. This
+ // is the expected case results for FAST_OPEN_SLOW_CONNECT_RETURN.
+ FAST_OPEN_NO_SYN_DATA_NACK,
+
+ // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN)
+ // and our later probe for ack/nack state failed.
+ FAST_OPEN_NO_SYN_DATA_FAILED,
+
+ FAST_OPEN_MAX_VALUE
+ };
+
+ // Watcher simply forwards notifications to Closure objects set via the
+ // constructor.
+ class Watcher: public base::MessageLoopForIO::Watcher {
+ public:
+ Watcher(const base::Closure& read_ready_callback,
+ const base::Closure& write_ready_callback);
+ virtual ~Watcher();
+
+ // base::MessageLoopForIO::Watcher methods.
+ virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE;
+ virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE;
+
+ private:
+ base::Closure read_ready_callback_;
+ base::Closure write_ready_callback_;
+
+ DISALLOW_COPY_AND_ASSIGN(Watcher);
+ };
+
+ int AcceptInternal(scoped_ptr<TCPSocketLibevent>* socket,
+ IPEndPoint* address);
+
+ int DoConnect();
+ void DoConnectComplete(int result);
+
+ void LogConnectBegin(const AddressList& addresses);
+ void LogConnectEnd(int net_error);
+
+ void DidCompleteRead();
+ void DidCompleteWrite();
+ void DidCompleteConnect();
+ void DidCompleteConnectOrWrite();
+ void DidCompleteAccept();
+
+ // Internal function to write to a socket. Returns an OS error.
+ int InternalWrite(IOBuffer* buf, int buf_len);
+
+ // Called when the socket is known to be in a connected state.
+ void RecordFastOpenStatus();
+
+ int socket_;
+
+ base::MessageLoopForIO::FileDescriptorWatcher accept_socket_watcher_;
+ Watcher accept_watcher_;
+
+ scoped_ptr<TCPSocketLibevent>* accept_socket_;
+ IPEndPoint* accept_address_;
+ CompletionCallback accept_callback_;
+
+ // The socket's libevent wrappers for reads and writes.
+ base::MessageLoopForIO::FileDescriptorWatcher read_socket_watcher_;
+ base::MessageLoopForIO::FileDescriptorWatcher write_socket_watcher_;
+
+ // The corresponding watchers for reads and writes.
+ Watcher read_watcher_;
+ Watcher write_watcher_;
+
+ // The buffer used for reads.
+ scoped_refptr<IOBuffer> read_buf_;
+ int read_buf_len_;
+
+ // The buffer used for writes.
+ scoped_refptr<IOBuffer> write_buf_;
+ int write_buf_len_;
+
+ // External callback; called when read is complete.
+ CompletionCallback read_callback_;
+
+ // External callback; called when write or connect is complete.
+ CompletionCallback write_callback_;
+
+ // Enables experimental TCP FastOpen option.
+ const bool use_tcp_fastopen_;
+
+ // True when TCP FastOpen is in use and we have done the connect.
+ bool tcp_fastopen_connected_;
+
+ FastOpenStatus fast_open_status_;
+
+ // A connect operation is pending. In this case, |write_callback_| needs to be
+ // called when connect is complete.
+ bool waiting_connect_;
+
+ scoped_ptr<IPEndPoint> peer_address_;
+ // The OS error that a connect attempt last completed with.
+ int connect_os_error_;
+
+ bool logging_multiple_connect_attempts_;
+
+ BoundNetLog net_log_;
+
+ DISALLOW_COPY_AND_ASSIGN(TCPSocketLibevent);
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_TCP_SOCKET_LIBEVENT_H_
diff --git a/chromium/net/socket/tcp_socket_unittest.cc b/chromium/net/socket/tcp_socket_unittest.cc
new file mode 100644
index 00000000000..a45fcba016b
--- /dev/null
+++ b/chromium/net/socket/tcp_socket_unittest.cc
@@ -0,0 +1,263 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/tcp_socket.h"
+
+#include <string.h>
+
+#include <string>
+#include <vector>
+
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "net/base/address_list.h"
+#include "net/base/io_buffer.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/net_errors.h"
+#include "net/base/test_completion_callback.h"
+#include "net/socket/tcp_client_socket.h"
+#include "testing/gtest/include/gtest/gtest.h"
+#include "testing/platform_test.h"
+
+namespace net {
+
+namespace {
+const int kListenBacklog = 5;
+
+class TCPSocketTest : public PlatformTest {
+ protected:
+ TCPSocketTest() : socket_(NULL, NetLog::Source()) {
+ }
+
+ void SetUpListenIPv4() {
+ IPEndPoint address;
+ ParseAddress("127.0.0.1", 0, &address);
+
+ ASSERT_EQ(OK, socket_.Open(ADDRESS_FAMILY_IPV4));
+ ASSERT_EQ(OK, socket_.Bind(address));
+ ASSERT_EQ(OK, socket_.Listen(kListenBacklog));
+ ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_));
+ }
+
+ void SetUpListenIPv6(bool* success) {
+ *success = false;
+ IPEndPoint address;
+ ParseAddress("::1", 0, &address);
+
+ if (socket_.Open(ADDRESS_FAMILY_IPV6) != OK ||
+ socket_.Bind(address) != OK ||
+ socket_.Listen(kListenBacklog) != OK) {
+ LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is "
+ "disabled. Skipping the test";
+ return;
+ }
+ ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_));
+ *success = true;
+ }
+
+ void ParseAddress(const std::string& ip_str, int port, IPEndPoint* address) {
+ IPAddressNumber ip_number;
+ bool rv = ParseIPLiteralToNumber(ip_str, &ip_number);
+ if (!rv)
+ return;
+ *address = IPEndPoint(ip_number, port);
+ }
+
+ AddressList local_address_list() const {
+ return AddressList(local_address_);
+ }
+
+ TCPSocket socket_;
+ IPEndPoint local_address_;
+};
+
+// Test listening and accepting with a socket bound to an IPv4 address.
+TEST_F(TCPSocketTest, Accept) {
+ ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
+
+ TestCompletionCallback connect_callback;
+ // TODO(yzshen): Switch to use TCPSocket when it supports client socket
+ // operations.
+ TCPClientSocket connecting_socket(local_address_list(),
+ NULL, NetLog::Source());
+ connecting_socket.Connect(connect_callback.callback());
+
+ TestCompletionCallback accept_callback;
+ scoped_ptr<TCPSocket> accepted_socket;
+ IPEndPoint accepted_address;
+ int result = socket_.Accept(&accepted_socket, &accepted_address,
+ accept_callback.callback());
+ if (result == ERR_IO_PENDING)
+ result = accept_callback.WaitForResult();
+ ASSERT_EQ(OK, result);
+
+ EXPECT_TRUE(accepted_socket.get());
+
+ // Both sockets should be on the loopback network interface.
+ EXPECT_EQ(accepted_address.address(), local_address_.address());
+
+ EXPECT_EQ(OK, connect_callback.WaitForResult());
+}
+
+// Test Accept() callback.
+TEST_F(TCPSocketTest, AcceptAsync) {
+ ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
+
+ TestCompletionCallback accept_callback;
+ scoped_ptr<TCPSocket> accepted_socket;
+ IPEndPoint accepted_address;
+ ASSERT_EQ(ERR_IO_PENDING,
+ socket_.Accept(&accepted_socket, &accepted_address,
+ accept_callback.callback()));
+
+ TestCompletionCallback connect_callback;
+ TCPClientSocket connecting_socket(local_address_list(),
+ NULL, NetLog::Source());
+ connecting_socket.Connect(connect_callback.callback());
+
+ EXPECT_EQ(OK, connect_callback.WaitForResult());
+ EXPECT_EQ(OK, accept_callback.WaitForResult());
+
+ EXPECT_TRUE(accepted_socket.get());
+
+ // Both sockets should be on the loopback network interface.
+ EXPECT_EQ(accepted_address.address(), local_address_.address());
+}
+
+// Accept two connections simultaneously.
+TEST_F(TCPSocketTest, Accept2Connections) {
+ ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
+
+ TestCompletionCallback accept_callback;
+ scoped_ptr<TCPSocket> accepted_socket;
+ IPEndPoint accepted_address;
+
+ ASSERT_EQ(ERR_IO_PENDING,
+ socket_.Accept(&accepted_socket, &accepted_address,
+ accept_callback.callback()));
+
+ TestCompletionCallback connect_callback;
+ TCPClientSocket connecting_socket(local_address_list(),
+ NULL, NetLog::Source());
+ connecting_socket.Connect(connect_callback.callback());
+
+ TestCompletionCallback connect_callback2;
+ TCPClientSocket connecting_socket2(local_address_list(),
+ NULL, NetLog::Source());
+ connecting_socket2.Connect(connect_callback2.callback());
+
+ EXPECT_EQ(OK, accept_callback.WaitForResult());
+
+ TestCompletionCallback accept_callback2;
+ scoped_ptr<TCPSocket> accepted_socket2;
+ IPEndPoint accepted_address2;
+
+ int result = socket_.Accept(&accepted_socket2, &accepted_address2,
+ accept_callback2.callback());
+ if (result == ERR_IO_PENDING)
+ result = accept_callback2.WaitForResult();
+ ASSERT_EQ(OK, result);
+
+ EXPECT_EQ(OK, connect_callback.WaitForResult());
+ EXPECT_EQ(OK, connect_callback2.WaitForResult());
+
+ EXPECT_TRUE(accepted_socket.get());
+ EXPECT_TRUE(accepted_socket2.get());
+ EXPECT_NE(accepted_socket.get(), accepted_socket2.get());
+
+ EXPECT_EQ(accepted_address.address(), local_address_.address());
+ EXPECT_EQ(accepted_address2.address(), local_address_.address());
+}
+
+// Test listening and accepting with a socket bound to an IPv6 address.
+TEST_F(TCPSocketTest, AcceptIPv6) {
+ bool initialized = false;
+ ASSERT_NO_FATAL_FAILURE(SetUpListenIPv6(&initialized));
+ if (!initialized)
+ return;
+
+ TestCompletionCallback connect_callback;
+ TCPClientSocket connecting_socket(local_address_list(),
+ NULL, NetLog::Source());
+ connecting_socket.Connect(connect_callback.callback());
+
+ TestCompletionCallback accept_callback;
+ scoped_ptr<TCPSocket> accepted_socket;
+ IPEndPoint accepted_address;
+ int result = socket_.Accept(&accepted_socket, &accepted_address,
+ accept_callback.callback());
+ if (result == ERR_IO_PENDING)
+ result = accept_callback.WaitForResult();
+ ASSERT_EQ(OK, result);
+
+ EXPECT_TRUE(accepted_socket.get());
+
+ // Both sockets should be on the loopback network interface.
+ EXPECT_EQ(accepted_address.address(), local_address_.address());
+
+ EXPECT_EQ(OK, connect_callback.WaitForResult());
+}
+
+TEST_F(TCPSocketTest, ReadWrite) {
+ ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
+
+ TestCompletionCallback connect_callback;
+ TCPSocket connecting_socket(NULL, NetLog::Source());
+ int result = connecting_socket.Open(ADDRESS_FAMILY_IPV4);
+ ASSERT_EQ(OK, result);
+ connecting_socket.Connect(local_address_, connect_callback.callback());
+
+ TestCompletionCallback accept_callback;
+ scoped_ptr<TCPSocket> accepted_socket;
+ IPEndPoint accepted_address;
+ result = socket_.Accept(&accepted_socket, &accepted_address,
+ accept_callback.callback());
+ ASSERT_EQ(OK, accept_callback.GetResult(result));
+
+ ASSERT_TRUE(accepted_socket.get());
+
+ // Both sockets should be on the loopback network interface.
+ EXPECT_EQ(accepted_address.address(), local_address_.address());
+
+ EXPECT_EQ(OK, connect_callback.WaitForResult());
+
+ const std::string message("test message");
+ std::vector<char> buffer(message.size());
+
+ size_t bytes_written = 0;
+ while (bytes_written < message.size()) {
+ scoped_refptr<IOBufferWithSize> write_buffer(
+ new IOBufferWithSize(message.size() - bytes_written));
+ memmove(write_buffer->data(), message.data() + bytes_written,
+ message.size() - bytes_written);
+
+ TestCompletionCallback write_callback;
+ int write_result = accepted_socket->Write(
+ write_buffer.get(), write_buffer->size(), write_callback.callback());
+ write_result = write_callback.GetResult(write_result);
+ ASSERT_TRUE(write_result >= 0);
+ bytes_written += write_result;
+ ASSERT_TRUE(bytes_written <= message.size());
+ }
+
+ size_t bytes_read = 0;
+ while (bytes_read < message.size()) {
+ scoped_refptr<IOBufferWithSize> read_buffer(
+ new IOBufferWithSize(message.size() - bytes_read));
+ TestCompletionCallback read_callback;
+ int read_result = connecting_socket.Read(
+ read_buffer.get(), read_buffer->size(), read_callback.callback());
+ read_result = read_callback.GetResult(read_result);
+ ASSERT_TRUE(read_result >= 0);
+ ASSERT_TRUE(bytes_read + read_result <= message.size());
+ memmove(&buffer[bytes_read], read_buffer->data(), read_result);
+ bytes_read += read_result;
+ }
+
+ std::string received_message(buffer.begin(), buffer.end());
+ ASSERT_EQ(message, received_message);
+}
+
+} // namespace
+} // namespace net
diff --git a/chromium/net/socket/tcp_client_socket_win.cc b/chromium/net/socket/tcp_socket_win.cc
index 9b0a5b50bf1..7d76232f962 100644
--- a/chromium/net/socket/tcp_client_socket_win.cc
+++ b/chromium/net/socket/tcp_socket_win.cc
@@ -1,26 +1,25 @@
-// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-#include "net/socket/tcp_client_socket_win.h"
+#include "net/socket/tcp_socket_win.h"
#include <mstcpip.h>
-#include "base/basictypes.h"
-#include "base/compiler_specific.h"
+#include "base/callback_helpers.h"
+#include "base/logging.h"
#include "base/metrics/stats_counters.h"
-#include "base/strings/string_util.h"
-#include "base/win/object_watcher.h"
#include "base/win/windows_version.h"
+#include "net/base/address_list.h"
#include "net/base/connection_type_histograms.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
-#include "net/base/net_log.h"
#include "net/base/net_util.h"
#include "net/base/network_change_notifier.h"
#include "net/base/winsock_init.h"
#include "net/base/winsock_util.h"
+#include "net/socket/socket_descriptor.h"
#include "net/socket/socket_net_log_params.h"
namespace net {
@@ -28,7 +27,6 @@ namespace net {
namespace {
const int kTCPKeepAliveSeconds = 45;
-bool g_disable_overlapped_reads = false;
bool SetSocketReceiveBufferSize(SOCKET socket, int32 size) {
int rv = setsockopt(socket, SOL_SOCKET, SO_RCVBUF,
@@ -86,8 +84,8 @@ bool SetTCPKeepAlive(SOCKET socket, BOOL enable, int delay_secs) {
};
DWORD bytes_returned = 0xABAB;
int rv = WSAIoctl(socket, SIO_KEEPALIVE_VALS, &keepalive_vals,
- sizeof(keepalive_vals), NULL, 0,
- &bytes_returned, NULL, NULL);
+ sizeof(keepalive_vals), NULL, 0,
+ &bytes_returned, NULL, NULL);
DCHECK(!rv) << "Could not enable TCP Keep-Alive for socket: " << socket
<< " [error: " << WSAGetLastError() << "].";
@@ -95,49 +93,6 @@ bool SetTCPKeepAlive(SOCKET socket, BOOL enable, int delay_secs) {
return rv == 0;
}
-// Sets socket parameters. Returns the OS error code (or 0 on
-// success).
-int SetupSocket(SOCKET socket) {
- // Increase the socket buffer sizes from the default sizes for WinXP. In
- // performance testing, there is substantial benefit by increasing from 8KB
- // to 64KB.
- // See also:
- // http://support.microsoft.com/kb/823764/EN-US
- // On Vista, if we manually set these sizes, Vista turns off its receive
- // window auto-tuning feature.
- // http://blogs.msdn.com/wndp/archive/2006/05/05/Winhec-blog-tcpip-2.aspx
- // Since Vista's auto-tune is better than any static value we can could set,
- // only change these on pre-vista machines.
- if (base::win::GetVersion() < base::win::VERSION_VISTA) {
- const int32 kSocketBufferSize = 64 * 1024;
- SetSocketReceiveBufferSize(socket, kSocketBufferSize);
- SetSocketSendBufferSize(socket, kSocketBufferSize);
- }
-
- DisableNagle(socket, true);
- SetTCPKeepAlive(socket, true, kTCPKeepAliveSeconds);
- return 0;
-}
-
-// Creates a new socket and sets default parameters for it. Returns
-// the OS error code (or 0 on success).
-int CreateSocket(int family, SOCKET* socket) {
- *socket = CreatePlatformSocket(family, SOCK_STREAM, IPPROTO_TCP);
- if (*socket == INVALID_SOCKET) {
- int os_error = WSAGetLastError();
- LOG(ERROR) << "CreatePlatformSocket failed: " << os_error;
- return os_error;
- }
- int error = SetupSocket(*socket);
- if (error) {
- if (closesocket(*socket) < 0)
- PLOG(ERROR) << "closesocket";
- *socket = INVALID_SOCKET;
- return error;
- }
- return 0;
-}
-
int MapConnectError(int os_error) {
switch (os_error) {
// connect fails with WSAEACCES when Windows Firewall blocks the
@@ -167,31 +122,21 @@ int MapConnectError(int os_error) {
//-----------------------------------------------------------------------------
// This class encapsulates all the state that has to be preserved as long as
-// there is a network IO operation in progress. If the owner TCPClientSocketWin
-// is destroyed while an operation is in progress, the Core is detached and it
+// there is a network IO operation in progress. If the owner TCPSocketWin is
+// destroyed while an operation is in progress, the Core is detached and it
// lives until the operation completes and the OS doesn't reference any resource
// declared on this class anymore.
-class TCPClientSocketWin::Core : public base::RefCounted<Core> {
+class TCPSocketWin::Core : public base::RefCounted<Core> {
public:
- explicit Core(TCPClientSocketWin* socket);
+ explicit Core(TCPSocketWin* socket);
// Start watching for the end of a read or write operation.
void WatchForRead();
void WatchForWrite();
- // The TCPClientSocketWin is going away.
+ // The TCPSocketWin is going away.
void Detach() { socket_ = NULL; }
- // Throttle the read size based on our current slow start state.
- // Returns the throttled read size.
- int ThrottleReadSize(int size) {
- if (slow_start_throttle_ < kMaxSlowStartThrottle) {
- size = std::min(size, slow_start_throttle_);
- slow_start_throttle_ *= 2;
- }
- return size;
- }
-
// The separate OVERLAPPED variables for asynchronous operation.
// |read_overlapped_| is used for both Connect() and Read().
// |write_overlapped_| is only used for Write();
@@ -204,9 +149,6 @@ class TCPClientSocketWin::Core : public base::RefCounted<Core> {
int read_buffer_length_;
int write_buffer_length_;
- // Remember the state of g_disable_overlapped_reads for the duration of the
- // socket based on what it was when the socket was created.
- bool disable_overlapped_reads_;
bool non_blocking_reads_initialized_;
private:
@@ -239,7 +181,7 @@ class TCPClientSocketWin::Core : public base::RefCounted<Core> {
~Core();
// The socket that created this object.
- TCPClientSocketWin* socket_;
+ TCPSocketWin* socket_;
// |reader_| handles the signals from |read_watcher_|.
ReadDelegate reader_;
@@ -251,26 +193,16 @@ class TCPClientSocketWin::Core : public base::RefCounted<Core> {
// |write_watcher_| watches for events from Write();
base::win::ObjectWatcher write_watcher_;
- // When doing reads from the socket, we try to mirror TCP's slow start.
- // We do this because otherwise the async IO subsystem artifically delays
- // returning data to the application.
- static const int kInitialSlowStartThrottle = 1 * 1024;
- static const int kMaxSlowStartThrottle = 32 * kInitialSlowStartThrottle;
- int slow_start_throttle_;
-
DISALLOW_COPY_AND_ASSIGN(Core);
};
-TCPClientSocketWin::Core::Core(
- TCPClientSocketWin* socket)
+TCPSocketWin::Core::Core(TCPSocketWin* socket)
: read_buffer_length_(0),
write_buffer_length_(0),
- disable_overlapped_reads_(g_disable_overlapped_reads),
non_blocking_reads_initialized_(false),
socket_(socket),
reader_(this),
- writer_(this),
- slow_start_throttle_(kInitialSlowStartThrottle) {
+ writer_(this) {
memset(&read_overlapped_, 0, sizeof(read_overlapped_));
memset(&write_overlapped_, 0, sizeof(write_overlapped_));
@@ -278,7 +210,7 @@ TCPClientSocketWin::Core::Core(
write_overlapped_.hEvent = WSACreateEvent();
}
-TCPClientSocketWin::Core::~Core() {
+TCPSocketWin::Core::~Core() {
// Make sure the message loop is not watching this object anymore.
read_watcher_.StopWatching();
write_watcher_.StopWatching();
@@ -289,37 +221,33 @@ TCPClientSocketWin::Core::~Core() {
memset(&write_overlapped_, 0xaf, sizeof(write_overlapped_));
}
-void TCPClientSocketWin::Core::WatchForRead() {
+void TCPSocketWin::Core::WatchForRead() {
// We grab an extra reference because there is an IO operation in progress.
// Balanced in ReadDelegate::OnObjectSignaled().
AddRef();
read_watcher_.StartWatching(read_overlapped_.hEvent, &reader_);
}
-void TCPClientSocketWin::Core::WatchForWrite() {
+void TCPSocketWin::Core::WatchForWrite() {
// We grab an extra reference because there is an IO operation in progress.
// Balanced in WriteDelegate::OnObjectSignaled().
AddRef();
write_watcher_.StartWatching(write_overlapped_.hEvent, &writer_);
}
-void TCPClientSocketWin::Core::ReadDelegate::OnObjectSignaled(
- HANDLE object) {
+void TCPSocketWin::Core::ReadDelegate::OnObjectSignaled(HANDLE object) {
DCHECK_EQ(object, core_->read_overlapped_.hEvent);
if (core_->socket_) {
- if (core_->socket_->waiting_connect()) {
+ if (core_->socket_->waiting_connect_)
core_->socket_->DidCompleteConnect();
- } else if (core_->disable_overlapped_reads_) {
+ else
core_->socket_->DidSignalRead();
- } else {
- core_->socket_->DidCompleteRead();
- }
}
core_->Release();
}
-void TCPClientSocketWin::Core::WriteDelegate::OnObjectSignaled(
+void TCPSocketWin::Core::WriteDelegate::OnObjectSignaled(
HANDLE object) {
DCHECK_EQ(object, core_->write_overlapped_.hEvent);
if (core_->socket_)
@@ -330,281 +258,170 @@ void TCPClientSocketWin::Core::WriteDelegate::OnObjectSignaled(
//-----------------------------------------------------------------------------
-TCPClientSocketWin::TCPClientSocketWin(const AddressList& addresses,
- net::NetLog* net_log,
- const net::NetLog::Source& source)
+TCPSocketWin::TCPSocketWin(net::NetLog* net_log,
+ const net::NetLog::Source& source)
: socket_(INVALID_SOCKET),
- bound_socket_(INVALID_SOCKET),
- addresses_(addresses),
- current_address_index_(-1),
+ accept_event_(WSA_INVALID_EVENT),
+ accept_socket_(NULL),
+ accept_address_(NULL),
+ waiting_connect_(false),
waiting_read_(false),
waiting_write_(false),
- next_connect_state_(CONNECT_STATE_NONE),
connect_os_error_(0),
- net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)),
- previously_disconnected_(false) {
+ logging_multiple_connect_attempts_(false),
+ net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) {
net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE,
source.ToEventParametersCallback());
EnsureWinsockInit();
}
-TCPClientSocketWin::~TCPClientSocketWin() {
- Disconnect();
+TCPSocketWin::~TCPSocketWin() {
+ Close();
net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE);
}
-int TCPClientSocketWin::AdoptSocket(SOCKET socket) {
+int TCPSocketWin::Open(AddressFamily family) {
+ DCHECK(CalledOnValidThread());
DCHECK_EQ(socket_, INVALID_SOCKET);
- int error = SetupSocket(socket);
- if (error)
- return MapSystemError(error);
-
- socket_ = socket;
- SetNonBlocking(socket_);
-
- core_ = new Core(this);
- current_address_index_ = 0;
- use_history_.set_was_ever_connected();
-
- return OK;
-}
-
-int TCPClientSocketWin::Bind(const IPEndPoint& address) {
- if (current_address_index_ >= 0 || bind_address_.get()) {
- // Cannot bind the socket if we are already connected or connecting.
- return ERR_UNEXPECTED;
+ socket_ = CreatePlatformSocket(ConvertAddressFamily(family), SOCK_STREAM,
+ IPPROTO_TCP);
+ if (socket_ == INVALID_SOCKET) {
+ PLOG(ERROR) << "CreatePlatformSocket() returned an error";
+ return MapSystemError(WSAGetLastError());
}
- SockaddrStorage storage;
- if (!address.ToSockAddr(storage.addr, &storage.addr_len))
- return ERR_INVALID_ARGUMENT;
-
- // Create |bound_socket_| and try to bind it to |address|.
- int error = CreateSocket(address.GetSockAddrFamily(), &bound_socket_);
- if (error)
- return MapSystemError(error);
-
- if (bind(bound_socket_, storage.addr, storage.addr_len)) {
- error = errno;
- if (closesocket(bound_socket_) < 0)
- PLOG(ERROR) << "closesocket";
- bound_socket_ = INVALID_SOCKET;
- return MapSystemError(error);
+ if (SetNonBlocking(socket_)) {
+ int result = MapSystemError(WSAGetLastError());
+ Close();
+ return result;
}
- bind_address_.reset(new IPEndPoint(address));
-
- return 0;
+ return OK;
}
-
-int TCPClientSocketWin::Connect(const CompletionCallback& callback) {
+int TCPSocketWin::AdoptConnectedSocket(SOCKET socket,
+ const IPEndPoint& peer_address) {
DCHECK(CalledOnValidThread());
+ DCHECK_EQ(socket_, INVALID_SOCKET);
+ DCHECK(!core_);
- // If already connected, then just return OK.
- if (socket_ != INVALID_SOCKET)
- return OK;
-
- base::StatsCounter connects("tcp.connect");
- connects.Increment();
-
- net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT,
- addresses_.CreateNetLogCallback());
-
- // We will try to connect to each address in addresses_. Start with the
- // first one in the list.
- next_connect_state_ = CONNECT_STATE_CONNECT;
- current_address_index_ = 0;
+ socket_ = socket;
- int rv = DoConnectLoop(OK);
- if (rv == ERR_IO_PENDING) {
- // Synchronous operation not supported.
- DCHECK(!callback.is_null());
- // TODO(ajwong): Is setting read_callback_ the right thing to do here??
- read_callback_ = callback;
- } else {
- LogConnectCompletion(rv);
+ if (SetNonBlocking(socket_)) {
+ int result = MapSystemError(WSAGetLastError());
+ Close();
+ return result;
}
- return rv;
-}
-
-int TCPClientSocketWin::DoConnectLoop(int result) {
- DCHECK_NE(next_connect_state_, CONNECT_STATE_NONE);
-
- int rv = result;
- do {
- ConnectState state = next_connect_state_;
- next_connect_state_ = CONNECT_STATE_NONE;
- switch (state) {
- case CONNECT_STATE_CONNECT:
- DCHECK_EQ(OK, rv);
- rv = DoConnect();
- break;
- case CONNECT_STATE_CONNECT_COMPLETE:
- rv = DoConnectComplete(rv);
- break;
- default:
- LOG(DFATAL) << "bad state " << state;
- rv = ERR_UNEXPECTED;
- break;
- }
- } while (rv != ERR_IO_PENDING && next_connect_state_ != CONNECT_STATE_NONE);
+ core_ = new Core(this);
+ peer_address_.reset(new IPEndPoint(peer_address));
- return rv;
+ return OK;
}
-int TCPClientSocketWin::DoConnect() {
- DCHECK_GE(current_address_index_, 0);
- DCHECK_LT(current_address_index_, static_cast<int>(addresses_.size()));
- DCHECK_EQ(0, connect_os_error_);
+int TCPSocketWin::Bind(const IPEndPoint& address) {
+ DCHECK(CalledOnValidThread());
+ DCHECK_NE(socket_, INVALID_SOCKET);
- const IPEndPoint& endpoint = addresses_[current_address_index_];
+ SockaddrStorage storage;
+ if (!address.ToSockAddr(storage.addr, &storage.addr_len))
+ return ERR_ADDRESS_INVALID;
- if (previously_disconnected_) {
- use_history_.Reset();
- previously_disconnected_ = false;
+ int result = bind(socket_, storage.addr, storage.addr_len);
+ if (result < 0) {
+ PLOG(ERROR) << "bind() returned an error";
+ return MapSystemError(WSAGetLastError());
}
- net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT,
- CreateNetLogIPEndPointCallback(&endpoint));
+ return OK;
+}
- next_connect_state_ = CONNECT_STATE_CONNECT_COMPLETE;
+int TCPSocketWin::Listen(int backlog) {
+ DCHECK(CalledOnValidThread());
+ DCHECK_GT(backlog, 0);
+ DCHECK_NE(socket_, INVALID_SOCKET);
+ DCHECK_EQ(accept_event_, WSA_INVALID_EVENT);
- if (bound_socket_ != INVALID_SOCKET) {
- DCHECK(bind_address_.get());
- socket_ = bound_socket_;
- bound_socket_ = INVALID_SOCKET;
- } else {
- connect_os_error_ = CreateSocket(endpoint.GetSockAddrFamily(), &socket_);
- if (connect_os_error_ != 0)
- return MapSystemError(connect_os_error_);
-
- if (bind_address_.get()) {
- SockaddrStorage storage;
- if (!bind_address_->ToSockAddr(storage.addr, &storage.addr_len))
- return ERR_INVALID_ARGUMENT;
- if (bind(socket_, storage.addr, storage.addr_len))
- return MapSystemError(errno);
- }
+ accept_event_ = WSACreateEvent();
+ if (accept_event_ == WSA_INVALID_EVENT) {
+ PLOG(ERROR) << "WSACreateEvent()";
+ return MapSystemError(WSAGetLastError());
}
- DCHECK(!core_);
- core_ = new Core(this);
- // WSAEventSelect sets the socket to non-blocking mode as a side effect.
- // Our connect() and recv() calls require that the socket be non-blocking.
- WSAEventSelect(socket_, core_->read_overlapped_.hEvent, FD_CONNECT);
-
- SockaddrStorage storage;
- if (!endpoint.ToSockAddr(storage.addr, &storage.addr_len))
- return ERR_INVALID_ARGUMENT;
- if (!connect(socket_, storage.addr, storage.addr_len)) {
- // Connected without waiting!
- //
- // The MSDN page for connect says:
- // With a nonblocking socket, the connection attempt cannot be completed
- // immediately. In this case, connect will return SOCKET_ERROR, and
- // WSAGetLastError will return WSAEWOULDBLOCK.
- // which implies that for a nonblocking socket, connect never returns 0.
- // It's not documented whether the event object will be signaled or not
- // if connect does return 0. So the code below is essentially dead code
- // and we don't know if it's correct.
- NOTREACHED();
-
- if (ResetEventIfSignaled(core_->read_overlapped_.hEvent))
- return OK;
- } else {
- int os_error = WSAGetLastError();
- if (os_error != WSAEWOULDBLOCK) {
- LOG(ERROR) << "connect failed: " << os_error;
- connect_os_error_ = os_error;
- return MapConnectError(os_error);
- }
+ int result = listen(socket_, backlog);
+ if (result < 0) {
+ PLOG(ERROR) << "listen() returned an error";
+ return MapSystemError(WSAGetLastError());
}
- core_->WatchForRead();
- return ERR_IO_PENDING;
+ return OK;
}
-int TCPClientSocketWin::DoConnectComplete(int result) {
- // Log the end of this attempt (and any OS error it threw).
- int os_error = connect_os_error_;
- connect_os_error_ = 0;
- if (result != OK) {
- net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT,
- NetLog::IntegerCallback("os_error", os_error));
- } else {
- net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT);
- }
+int TCPSocketWin::Accept(scoped_ptr<TCPSocketWin>* socket,
+ IPEndPoint* address,
+ const CompletionCallback& callback) {
+ DCHECK(CalledOnValidThread());
+ DCHECK(socket);
+ DCHECK(address);
+ DCHECK(!callback.is_null());
+ DCHECK(accept_callback_.is_null());
- if (result == OK) {
- use_history_.set_was_ever_connected();
- return OK; // Done!
- }
+ net_log_.BeginEvent(NetLog::TYPE_TCP_ACCEPT);
- // Close whatever partially connected socket we currently have.
- DoDisconnect();
+ int result = AcceptInternal(socket, address);
- // Try to fall back to the next address in the list.
- if (current_address_index_ + 1 < static_cast<int>(addresses_.size())) {
- next_connect_state_ = CONNECT_STATE_CONNECT;
- ++current_address_index_;
- return OK;
+ if (result == ERR_IO_PENDING) {
+ // Start watching.
+ WSAEventSelect(socket_, accept_event_, FD_ACCEPT);
+ accept_watcher_.StartWatching(accept_event_, this);
+
+ accept_socket_ = socket;
+ accept_address_ = address;
+ accept_callback_ = callback;
}
- // Otherwise there is nothing to fall back to, so give up.
return result;
}
-void TCPClientSocketWin::Disconnect() {
+int TCPSocketWin::Connect(const IPEndPoint& address,
+ const CompletionCallback& callback) {
DCHECK(CalledOnValidThread());
+ DCHECK_NE(socket_, INVALID_SOCKET);
+ DCHECK(!waiting_connect_);
- DoDisconnect();
- current_address_index_ = -1;
- bind_address_.reset();
-}
+ // |peer_address_| and |core_| will be non-NULL if Connect() has been called.
+ // Unless Close() is called to reset the internal state, a second call to
+ // Connect() is not allowed.
+ // Please note that we enforce this even if the previous Connect() has
+ // completed and failed. Although it is allowed to connect the same |socket_|
+ // again after a connection attempt failed on Windows, it results in
+ // unspecified behavior according to POSIX. Therefore, we make it behave in
+ // the same way as TCPSocketLibevent.
+ DCHECK(!peer_address_ && !core_);
-void TCPClientSocketWin::DoDisconnect() {
- DCHECK(CalledOnValidThread());
+ if (!logging_multiple_connect_attempts_)
+ LogConnectBegin(AddressList(address));
- if (socket_ == INVALID_SOCKET)
- return;
+ peer_address_.reset(new IPEndPoint(address));
- // Note: don't use CancelIo to cancel pending IO because it doesn't work
- // when there is a Winsock layered service provider.
-
- // In most socket implementations, closing a socket results in a graceful
- // connection shutdown, but in Winsock we have to call shutdown explicitly.
- // See the MSDN page "Graceful Shutdown, Linger Options, and Socket Closure"
- // at http://msdn.microsoft.com/en-us/library/ms738547.aspx
- shutdown(socket_, SD_SEND);
-
- // This cancels any pending IO.
- closesocket(socket_);
- socket_ = INVALID_SOCKET;
-
- if (waiting_connect()) {
- // We closed the socket, so this notification will never come.
- // From MSDN' WSAEventSelect documentation:
- // "Closing a socket with closesocket also cancels the association and
- // selection of network events specified in WSAEventSelect for the socket".
- core_->Release();
+ int rv = DoConnect();
+ if (rv == ERR_IO_PENDING) {
+ // Synchronous operation not supported.
+ DCHECK(!callback.is_null());
+ read_callback_ = callback;
+ waiting_connect_ = true;
+ } else {
+ DoConnectComplete(rv);
}
- waiting_read_ = false;
- waiting_write_ = false;
-
- core_->Detach();
- core_ = NULL;
-
- previously_disconnected_ = true;
+ return rv;
}
-bool TCPClientSocketWin::IsConnected() const {
+bool TCPSocketWin::IsConnected() const {
DCHECK(CalledOnValidThread());
- if (socket_ == INVALID_SOCKET || waiting_connect())
+ if (socket_ == INVALID_SOCKET || waiting_connect_)
return false;
if (waiting_read_)
@@ -621,10 +438,10 @@ bool TCPClientSocketWin::IsConnected() const {
return true;
}
-bool TCPClientSocketWin::IsConnectedAndIdle() const {
+bool TCPSocketWin::IsConnectedAndIdle() const {
DCHECK(CalledOnValidThread());
- if (socket_ == INVALID_SOCKET || waiting_connect())
+ if (socket_ == INVALID_SOCKET || waiting_connect_)
return false;
if (waiting_read_)
@@ -642,68 +459,9 @@ bool TCPClientSocketWin::IsConnectedAndIdle() const {
return true;
}
-int TCPClientSocketWin::GetPeerAddress(IPEndPoint* address) const {
- DCHECK(CalledOnValidThread());
- DCHECK(address);
- if (!IsConnected())
- return ERR_SOCKET_NOT_CONNECTED;
- *address = addresses_[current_address_index_];
- return OK;
-}
-
-int TCPClientSocketWin::GetLocalAddress(IPEndPoint* address) const {
- DCHECK(CalledOnValidThread());
- DCHECK(address);
- if (socket_ == INVALID_SOCKET) {
- if (bind_address_.get()) {
- *address = *bind_address_;
- return OK;
- }
- return ERR_SOCKET_NOT_CONNECTED;
- }
-
- struct sockaddr_storage addr_storage;
- socklen_t addr_len = sizeof(addr_storage);
- struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
- if (getsockname(socket_, addr, &addr_len))
- return MapSystemError(WSAGetLastError());
- if (!address->FromSockAddr(addr, addr_len))
- return ERR_FAILED;
- return OK;
-}
-
-void TCPClientSocketWin::SetSubresourceSpeculation() {
- use_history_.set_subresource_speculation();
-}
-
-void TCPClientSocketWin::SetOmniboxSpeculation() {
- use_history_.set_omnibox_speculation();
-}
-
-bool TCPClientSocketWin::WasEverUsed() const {
- return use_history_.was_used_to_convey_data();
-}
-
-bool TCPClientSocketWin::UsingTCPFastOpen() const {
- // Not supported on windows.
- return false;
-}
-
-bool TCPClientSocketWin::WasNpnNegotiated() const {
- return false;
-}
-
-NextProto TCPClientSocketWin::GetNegotiatedProtocol() const {
- return kProtoUnknown;
-}
-
-bool TCPClientSocketWin::GetSSLInfo(SSLInfo* ssl_info) {
- return false;
-}
-
-int TCPClientSocketWin::Read(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) {
+int TCPSocketWin::Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
DCHECK(CalledOnValidThread());
DCHECK_NE(socket_, INVALID_SOCKET);
DCHECK(!waiting_read_);
@@ -713,9 +471,9 @@ int TCPClientSocketWin::Read(IOBuffer* buf,
return DoRead(buf, buf_len, callback);
}
-int TCPClientSocketWin::Write(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) {
+int TCPSocketWin::Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
DCHECK(CalledOnValidThread());
DCHECK_NE(socket_, INVALID_SOCKET);
DCHECK(!waiting_write_);
@@ -747,8 +505,6 @@ int TCPClientSocketWin::Write(IOBuffer* buf,
}
base::StatsCounter write_bytes("tcp.write_bytes");
write_bytes.Add(rv);
- if (rv > 0)
- use_history_.set_was_used_to_convey_data();
net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, rv,
buf->data());
return rv;
@@ -770,29 +526,294 @@ int TCPClientSocketWin::Write(IOBuffer* buf,
return ERR_IO_PENDING;
}
-bool TCPClientSocketWin::SetReceiveBufferSize(int32 size) {
+int TCPSocketWin::GetLocalAddress(IPEndPoint* address) const {
+ DCHECK(CalledOnValidThread());
+ DCHECK(address);
+
+ SockaddrStorage storage;
+ if (getsockname(socket_, storage.addr, &storage.addr_len))
+ return MapSystemError(WSAGetLastError());
+ if (!address->FromSockAddr(storage.addr, storage.addr_len))
+ return ERR_ADDRESS_INVALID;
+
+ return OK;
+}
+
+int TCPSocketWin::GetPeerAddress(IPEndPoint* address) const {
+ DCHECK(CalledOnValidThread());
+ DCHECK(address);
+ if (!IsConnected())
+ return ERR_SOCKET_NOT_CONNECTED;
+ *address = *peer_address_;
+ return OK;
+}
+
+int TCPSocketWin::SetDefaultOptionsForServer() {
+ return SetExclusiveAddrUse();
+}
+
+void TCPSocketWin::SetDefaultOptionsForClient() {
+ // Increase the socket buffer sizes from the default sizes for WinXP. In
+ // performance testing, there is substantial benefit by increasing from 8KB
+ // to 64KB.
+ // See also:
+ // http://support.microsoft.com/kb/823764/EN-US
+ // On Vista, if we manually set these sizes, Vista turns off its receive
+ // window auto-tuning feature.
+ // http://blogs.msdn.com/wndp/archive/2006/05/05/Winhec-blog-tcpip-2.aspx
+ // Since Vista's auto-tune is better than any static value we can could set,
+ // only change these on pre-vista machines.
+ if (base::win::GetVersion() < base::win::VERSION_VISTA) {
+ const int32 kSocketBufferSize = 64 * 1024;
+ SetSocketReceiveBufferSize(socket_, kSocketBufferSize);
+ SetSocketSendBufferSize(socket_, kSocketBufferSize);
+ }
+
+ DisableNagle(socket_, true);
+ SetTCPKeepAlive(socket_, true, kTCPKeepAliveSeconds);
+}
+
+int TCPSocketWin::SetExclusiveAddrUse() {
+ // On Windows, a bound end point can be hijacked by another process by
+ // setting SO_REUSEADDR. Therefore a Windows-only option SO_EXCLUSIVEADDRUSE
+ // was introduced in Windows NT 4.0 SP4. If the socket that is bound to the
+ // end point has SO_EXCLUSIVEADDRUSE enabled, it is not possible for another
+ // socket to forcibly bind to the end point until the end point is unbound.
+ // It is recommend that all server applications must use SO_EXCLUSIVEADDRUSE.
+ // MSDN: http://goo.gl/M6fjQ.
+ //
+ // Unlike on *nix, on Windows a TCP server socket can always bind to an end
+ // point in TIME_WAIT state without setting SO_REUSEADDR, therefore it is not
+ // needed here.
+ //
+ // SO_EXCLUSIVEADDRUSE will prevent a TCP client socket from binding to an end
+ // point in TIME_WAIT status. It does not have this effect for a TCP server
+ // socket.
+
+ BOOL true_value = 1;
+ int rv = setsockopt(socket_, SOL_SOCKET, SO_EXCLUSIVEADDRUSE,
+ reinterpret_cast<const char*>(&true_value),
+ sizeof(true_value));
+ if (rv < 0)
+ return MapSystemError(errno);
+ return OK;
+}
+
+bool TCPSocketWin::SetReceiveBufferSize(int32 size) {
DCHECK(CalledOnValidThread());
return SetSocketReceiveBufferSize(socket_, size);
}
-bool TCPClientSocketWin::SetSendBufferSize(int32 size) {
+bool TCPSocketWin::SetSendBufferSize(int32 size) {
DCHECK(CalledOnValidThread());
return SetSocketSendBufferSize(socket_, size);
}
-bool TCPClientSocketWin::SetKeepAlive(bool enable, int delay) {
+bool TCPSocketWin::SetKeepAlive(bool enable, int delay) {
return SetTCPKeepAlive(socket_, enable, delay);
}
-bool TCPClientSocketWin::SetNoDelay(bool no_delay) {
+bool TCPSocketWin::SetNoDelay(bool no_delay) {
return DisableNagle(socket_, no_delay);
}
-void TCPClientSocketWin::DisableOverlappedReads() {
- g_disable_overlapped_reads = true;
+void TCPSocketWin::Close() {
+ DCHECK(CalledOnValidThread());
+
+ if (socket_ != INVALID_SOCKET) {
+ // Note: don't use CancelIo to cancel pending IO because it doesn't work
+ // when there is a Winsock layered service provider.
+
+ // In most socket implementations, closing a socket results in a graceful
+ // connection shutdown, but in Winsock we have to call shutdown explicitly.
+ // See the MSDN page "Graceful Shutdown, Linger Options, and Socket Closure"
+ // at http://msdn.microsoft.com/en-us/library/ms738547.aspx
+ shutdown(socket_, SD_SEND);
+
+ // This cancels any pending IO.
+ if (closesocket(socket_) < 0)
+ PLOG(ERROR) << "closesocket";
+ socket_ = INVALID_SOCKET;
+ }
+
+ if (accept_event_) {
+ WSACloseEvent(accept_event_);
+ accept_event_ = WSA_INVALID_EVENT;
+ }
+
+ if (!accept_callback_.is_null()) {
+ accept_watcher_.StopWatching();
+ accept_socket_ = NULL;
+ accept_address_ = NULL;
+ accept_callback_.Reset();
+ }
+
+ if (core_) {
+ if (waiting_connect_) {
+ // We closed the socket, so this notification will never come.
+ // From MSDN' WSAEventSelect documentation:
+ // "Closing a socket with closesocket also cancels the association and
+ // selection of network events specified in WSAEventSelect for the
+ // socket".
+ core_->Release();
+ }
+ core_->Detach();
+ core_ = NULL;
+ }
+
+ waiting_connect_ = false;
+ waiting_read_ = false;
+ waiting_write_ = false;
+
+ read_callback_.Reset();
+ write_callback_.Reset();
+ peer_address_.reset();
+ connect_os_error_ = 0;
}
-void TCPClientSocketWin::LogConnectCompletion(int net_error) {
+bool TCPSocketWin::UsingTCPFastOpen() const {
+ // Not supported on windows.
+ return false;
+}
+
+void TCPSocketWin::StartLoggingMultipleConnectAttempts(
+ const AddressList& addresses) {
+ if (!logging_multiple_connect_attempts_) {
+ logging_multiple_connect_attempts_ = true;
+ LogConnectBegin(addresses);
+ } else {
+ NOTREACHED();
+ }
+}
+
+void TCPSocketWin::EndLoggingMultipleConnectAttempts(int net_error) {
+ if (logging_multiple_connect_attempts_) {
+ LogConnectEnd(net_error);
+ logging_multiple_connect_attempts_ = false;
+ } else {
+ NOTREACHED();
+ }
+}
+
+int TCPSocketWin::AcceptInternal(scoped_ptr<TCPSocketWin>* socket,
+ IPEndPoint* address) {
+ SockaddrStorage storage;
+ int new_socket = accept(socket_, storage.addr, &storage.addr_len);
+ if (new_socket < 0) {
+ int net_error = MapSystemError(WSAGetLastError());
+ if (net_error != ERR_IO_PENDING)
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, net_error);
+ return net_error;
+ }
+
+ IPEndPoint ip_end_point;
+ if (!ip_end_point.FromSockAddr(storage.addr, storage.addr_len)) {
+ NOTREACHED();
+ if (closesocket(new_socket) < 0)
+ PLOG(ERROR) << "closesocket";
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, ERR_FAILED);
+ return ERR_FAILED;
+ }
+ scoped_ptr<TCPSocketWin> tcp_socket(new TCPSocketWin(
+ net_log_.net_log(), net_log_.source()));
+ int adopt_result = tcp_socket->AdoptConnectedSocket(new_socket, ip_end_point);
+ if (adopt_result != OK) {
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, adopt_result);
+ return adopt_result;
+ }
+ *socket = tcp_socket.Pass();
+ *address = ip_end_point;
+ net_log_.EndEvent(NetLog::TYPE_TCP_ACCEPT,
+ CreateNetLogIPEndPointCallback(&ip_end_point));
+ return OK;
+}
+
+void TCPSocketWin::OnObjectSignaled(HANDLE object) {
+ WSANETWORKEVENTS ev;
+ if (WSAEnumNetworkEvents(socket_, accept_event_, &ev) == SOCKET_ERROR) {
+ PLOG(ERROR) << "WSAEnumNetworkEvents()";
+ return;
+ }
+
+ if (ev.lNetworkEvents & FD_ACCEPT) {
+ int result = AcceptInternal(accept_socket_, accept_address_);
+ if (result != ERR_IO_PENDING) {
+ accept_socket_ = NULL;
+ accept_address_ = NULL;
+ base::ResetAndReturn(&accept_callback_).Run(result);
+ }
+ }
+}
+
+int TCPSocketWin::DoConnect() {
+ DCHECK_EQ(connect_os_error_, 0);
+ DCHECK(!core_);
+
+ net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT,
+ CreateNetLogIPEndPointCallback(peer_address_.get()));
+
+ core_ = new Core(this);
+ // WSAEventSelect sets the socket to non-blocking mode as a side effect.
+ // Our connect() and recv() calls require that the socket be non-blocking.
+ WSAEventSelect(socket_, core_->read_overlapped_.hEvent, FD_CONNECT);
+
+ SockaddrStorage storage;
+ if (!peer_address_->ToSockAddr(storage.addr, &storage.addr_len))
+ return ERR_INVALID_ARGUMENT;
+ if (!connect(socket_, storage.addr, storage.addr_len)) {
+ // Connected without waiting!
+ //
+ // The MSDN page for connect says:
+ // With a nonblocking socket, the connection attempt cannot be completed
+ // immediately. In this case, connect will return SOCKET_ERROR, and
+ // WSAGetLastError will return WSAEWOULDBLOCK.
+ // which implies that for a nonblocking socket, connect never returns 0.
+ // It's not documented whether the event object will be signaled or not
+ // if connect does return 0. So the code below is essentially dead code
+ // and we don't know if it's correct.
+ NOTREACHED();
+
+ if (ResetEventIfSignaled(core_->read_overlapped_.hEvent))
+ return OK;
+ } else {
+ int os_error = WSAGetLastError();
+ if (os_error != WSAEWOULDBLOCK) {
+ LOG(ERROR) << "connect failed: " << os_error;
+ connect_os_error_ = os_error;
+ int rv = MapConnectError(os_error);
+ CHECK_NE(ERR_IO_PENDING, rv);
+ return rv;
+ }
+ }
+
+ core_->WatchForRead();
+ return ERR_IO_PENDING;
+}
+
+void TCPSocketWin::DoConnectComplete(int result) {
+ // Log the end of this attempt (and any OS error it threw).
+ int os_error = connect_os_error_;
+ connect_os_error_ = 0;
+ if (result != OK) {
+ net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT,
+ NetLog::IntegerCallback("os_error", os_error));
+ } else {
+ net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT);
+ }
+
+ if (!logging_multiple_connect_attempts_)
+ LogConnectEnd(result);
+}
+
+void TCPSocketWin::LogConnectBegin(const AddressList& addresses) {
+ base::StatsCounter connects("tcp.connect");
+ connects.Increment();
+
+ net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT,
+ addresses.CreateNetLogCallback());
+}
+
+void TCPSocketWin::LogConnectEnd(int net_error) {
if (net_error == OK)
UpdateConnectionTypeHistograms(CONNECTION_ANY);
@@ -820,66 +841,30 @@ void TCPClientSocketWin::LogConnectCompletion(int net_error) {
sizeof(source_address)));
}
-int TCPClientSocketWin::DoRead(IOBuffer* buf, int buf_len,
- const CompletionCallback& callback) {
- if (core_->disable_overlapped_reads_) {
- if (!core_->non_blocking_reads_initialized_) {
- WSAEventSelect(socket_, core_->read_overlapped_.hEvent,
- FD_READ | FD_CLOSE);
- core_->non_blocking_reads_initialized_ = true;
- }
- int rv = recv(socket_, buf->data(), buf_len, 0);
- if (rv == SOCKET_ERROR) {
- int os_error = WSAGetLastError();
- if (os_error != WSAEWOULDBLOCK) {
- int net_error = MapSystemError(os_error);
- net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR,
- CreateNetLogSocketErrorCallback(net_error, os_error));
- return net_error;
- }
- } else {
- base::StatsCounter read_bytes("tcp.read_bytes");
- if (rv > 0) {
- use_history_.set_was_used_to_convey_data();
- read_bytes.Add(rv);
- }
- net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, rv,
- buf->data());
- return rv;
+int TCPSocketWin::DoRead(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) {
+ if (!core_->non_blocking_reads_initialized_) {
+ WSAEventSelect(socket_, core_->read_overlapped_.hEvent,
+ FD_READ | FD_CLOSE);
+ core_->non_blocking_reads_initialized_ = true;
+ }
+ int rv = recv(socket_, buf->data(), buf_len, 0);
+ if (rv == SOCKET_ERROR) {
+ int os_error = WSAGetLastError();
+ if (os_error != WSAEWOULDBLOCK) {
+ int net_error = MapSystemError(os_error);
+ net_log_.AddEvent(
+ NetLog::TYPE_SOCKET_READ_ERROR,
+ CreateNetLogSocketErrorCallback(net_error, os_error));
+ return net_error;
}
} else {
- buf_len = core_->ThrottleReadSize(buf_len);
-
- WSABUF read_buffer;
- read_buffer.len = buf_len;
- read_buffer.buf = buf->data();
-
- // TODO(wtc): Remove the assertion after enough testing.
- AssertEventNotSignaled(core_->read_overlapped_.hEvent);
- DWORD num;
- DWORD flags = 0;
- int rv = WSARecv(socket_, &read_buffer, 1, &num, &flags,
- &core_->read_overlapped_, NULL);
- if (rv == 0) {
- if (ResetEventIfSignaled(core_->read_overlapped_.hEvent)) {
- base::StatsCounter read_bytes("tcp.read_bytes");
- if (num > 0) {
- use_history_.set_was_used_to_convey_data();
- read_bytes.Add(num);
- }
- net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, num,
- buf->data());
- return static_cast<int>(num);
- }
- } else {
- int os_error = WSAGetLastError();
- if (os_error != WSA_IO_PENDING) {
- int net_error = MapSystemError(os_error);
- net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR,
- CreateNetLogSocketErrorCallback(net_error, os_error));
- return net_error;
- }
- }
+ base::StatsCounter read_bytes("tcp.read_bytes");
+ if (rv > 0)
+ read_bytes.Add(rv);
+ net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, rv,
+ buf->data());
+ return rv;
}
waiting_read_ = true;
@@ -890,28 +875,9 @@ int TCPClientSocketWin::DoRead(IOBuffer* buf, int buf_len,
return ERR_IO_PENDING;
}
-void TCPClientSocketWin::DoReadCallback(int rv) {
- DCHECK_NE(rv, ERR_IO_PENDING);
+void TCPSocketWin::DidCompleteConnect() {
+ DCHECK(waiting_connect_);
DCHECK(!read_callback_.is_null());
-
- // Since Run may result in Read being called, clear read_callback_ up front.
- CompletionCallback c = read_callback_;
- read_callback_.Reset();
- c.Run(rv);
-}
-
-void TCPClientSocketWin::DoWriteCallback(int rv) {
- DCHECK_NE(rv, ERR_IO_PENDING);
- DCHECK(!write_callback_.is_null());
-
- // since Run may result in Write being called, clear write_callback_ up front.
- CompletionCallback c = write_callback_;
- write_callback_.Reset();
- c.Run(rv);
-}
-
-void TCPClientSocketWin::DidCompleteConnect() {
- DCHECK_EQ(next_connect_state_, CONNECT_STATE_CONNECT_COMPLETE);
int result;
WSANETWORKEVENTS events;
@@ -931,42 +897,16 @@ void TCPClientSocketWin::DidCompleteConnect() {
}
connect_os_error_ = os_error;
- rv = DoConnectLoop(result);
- if (rv != ERR_IO_PENDING) {
- LogConnectCompletion(rv);
- DoReadCallback(rv);
- }
-}
+ DoConnectComplete(result);
+ waiting_connect_ = false;
-void TCPClientSocketWin::DidCompleteRead() {
- DCHECK(waiting_read_);
- DWORD num_bytes, flags;
- BOOL ok = WSAGetOverlappedResult(socket_, &core_->read_overlapped_,
- &num_bytes, FALSE, &flags);
- waiting_read_ = false;
- int rv;
- if (ok) {
- base::StatsCounter read_bytes("tcp.read_bytes");
- read_bytes.Add(num_bytes);
- if (num_bytes > 0)
- use_history_.set_was_used_to_convey_data();
- net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED,
- num_bytes, core_->read_iobuffer_->data());
- rv = static_cast<int>(num_bytes);
- } else {
- int os_error = WSAGetLastError();
- rv = MapSystemError(os_error);
- net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR,
- CreateNetLogSocketErrorCallback(rv, os_error));
- }
- WSAResetEvent(core_->read_overlapped_.hEvent);
- core_->read_iobuffer_ = NULL;
- core_->read_buffer_length_ = 0;
- DoReadCallback(rv);
+ DCHECK_NE(result, ERR_IO_PENDING);
+ base::ResetAndReturn(&read_callback_).Run(result);
}
-void TCPClientSocketWin::DidCompleteWrite() {
+void TCPSocketWin::DidCompleteWrite() {
DCHECK(waiting_write_);
+ DCHECK(!write_callback_.is_null());
DWORD num_bytes, flags;
BOOL ok = WSAGetOverlappedResult(socket_, &core_->write_overlapped_,
@@ -991,18 +931,21 @@ void TCPClientSocketWin::DidCompleteWrite() {
} else {
base::StatsCounter write_bytes("tcp.write_bytes");
write_bytes.Add(num_bytes);
- if (num_bytes > 0)
- use_history_.set_was_used_to_convey_data();
net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, num_bytes,
core_->write_iobuffer_->data());
}
}
+
core_->write_iobuffer_ = NULL;
- DoWriteCallback(rv);
+
+ DCHECK_NE(rv, ERR_IO_PENDING);
+ base::ResetAndReturn(&write_callback_).Run(rv);
}
-void TCPClientSocketWin::DidSignalRead() {
+void TCPSocketWin::DidSignalRead() {
DCHECK(waiting_read_);
+ DCHECK(!read_callback_.is_null());
+
int os_error = 0;
WSANETWORKEVENTS network_events;
int rv = WSAEnumNetworkEvents(socket_, core_->read_overlapped_.hEvent,
@@ -1036,10 +979,14 @@ void TCPClientSocketWin::DidSignalRead() {
core_->WatchForRead();
return;
}
+
waiting_read_ = false;
core_->read_iobuffer_ = NULL;
core_->read_buffer_length_ = 0;
- DoReadCallback(rv);
+
+ DCHECK_NE(rv, ERR_IO_PENDING);
+ base::ResetAndReturn(&read_callback_).Run(rv);
}
} // namespace net
+
diff --git a/chromium/net/socket/tcp_socket_win.h b/chromium/net/socket/tcp_socket_win.h
new file mode 100644
index 00000000000..df5fbf09aec
--- /dev/null
+++ b/chromium/net/socket/tcp_socket_win.h
@@ -0,0 +1,150 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_TCP_SOCKET_WIN_H_
+#define NET_SOCKET_TCP_SOCKET_WIN_H_
+
+#include <winsock2.h>
+
+#include "base/basictypes.h"
+#include "base/compiler_specific.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/threading/non_thread_safe.h"
+#include "base/win/object_watcher.h"
+#include "net/base/address_family.h"
+#include "net/base/completion_callback.h"
+#include "net/base/net_export.h"
+#include "net/base/net_log.h"
+
+namespace net {
+
+class AddressList;
+class IOBuffer;
+class IPEndPoint;
+
+class NET_EXPORT TCPSocketWin : NON_EXPORTED_BASE(public base::NonThreadSafe),
+ public base::win::ObjectWatcher::Delegate {
+ public:
+ TCPSocketWin(NetLog* net_log, const NetLog::Source& source);
+ virtual ~TCPSocketWin();
+
+ int Open(AddressFamily family);
+ // Takes ownership of |socket|.
+ int AdoptConnectedSocket(SOCKET socket, const IPEndPoint& peer_address);
+
+ int Bind(const IPEndPoint& address);
+
+ int Listen(int backlog);
+ int Accept(scoped_ptr<TCPSocketWin>* socket,
+ IPEndPoint* address,
+ const CompletionCallback& callback);
+
+ int Connect(const IPEndPoint& address, const CompletionCallback& callback);
+ bool IsConnected() const;
+ bool IsConnectedAndIdle() const;
+
+ // Multiple outstanding requests are not supported.
+ // Full duplex mode (reading and writing at the same time) is supported.
+ int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback);
+ int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback);
+
+ int GetLocalAddress(IPEndPoint* address) const;
+ int GetPeerAddress(IPEndPoint* address) const;
+
+ // Sets various socket options.
+ // The commonly used options for server listening sockets:
+ // - SetExclusiveAddrUse().
+ int SetDefaultOptionsForServer();
+ // The commonly used options for client sockets and accepted sockets:
+ // - Increase the socket buffer sizes for WinXP;
+ // - SetNoDelay(true);
+ // - SetKeepAlive(true, 45).
+ void SetDefaultOptionsForClient();
+ int SetExclusiveAddrUse();
+ bool SetReceiveBufferSize(int32 size);
+ bool SetSendBufferSize(int32 size);
+ bool SetKeepAlive(bool enable, int delay);
+ bool SetNoDelay(bool no_delay);
+
+ void Close();
+
+ bool UsingTCPFastOpen() const;
+ bool IsValid() const { return socket_ != INVALID_SOCKET; }
+
+ // Marks the start/end of a series of connect attempts for logging purpose.
+ //
+ // TCPClientSocket may attempt to connect to multiple addresses until it
+ // succeeds in establishing a connection. The corresponding log will have
+ // multiple NetLog::TYPE_TCP_CONNECT_ATTEMPT entries nested within a
+ // NetLog::TYPE_TCP_CONNECT. These methods set the start/end of
+ // NetLog::TYPE_TCP_CONNECT.
+ //
+ // TODO(yzshen): Change logging format and let TCPClientSocket log the
+ // start/end of a series of connect attempts itself.
+ void StartLoggingMultipleConnectAttempts(const AddressList& addresses);
+ void EndLoggingMultipleConnectAttempts(int net_error);
+
+ const BoundNetLog& net_log() const { return net_log_; }
+
+ private:
+ class Core;
+
+ // base::ObjectWatcher::Delegate implementation.
+ virtual void OnObjectSignaled(HANDLE object) OVERRIDE;
+
+ int AcceptInternal(scoped_ptr<TCPSocketWin>* socket,
+ IPEndPoint* address);
+
+ int DoConnect();
+ void DoConnectComplete(int result);
+
+ void LogConnectBegin(const AddressList& addresses);
+ void LogConnectEnd(int net_error);
+
+ int DoRead(IOBuffer* buf, int buf_len, const CompletionCallback& callback);
+ void DidCompleteConnect();
+ void DidCompleteWrite();
+ void DidSignalRead();
+
+ SOCKET socket_;
+
+ HANDLE accept_event_;
+ base::win::ObjectWatcher accept_watcher_;
+
+ scoped_ptr<TCPSocketWin>* accept_socket_;
+ IPEndPoint* accept_address_;
+ CompletionCallback accept_callback_;
+
+ // The various states that the socket could be in.
+ bool waiting_connect_;
+ bool waiting_read_;
+ bool waiting_write_;
+
+ // The core of the socket that can live longer than the socket itself. We pass
+ // resources to the Windows async IO functions and we have to make sure that
+ // they are not destroyed while the OS still references them.
+ scoped_refptr<Core> core_;
+
+ // External callback; called when connect or read is complete.
+ CompletionCallback read_callback_;
+
+ // External callback; called when write is complete.
+ CompletionCallback write_callback_;
+
+ scoped_ptr<IPEndPoint> peer_address_;
+ // The OS error that a connect attempt last completed with.
+ int connect_os_error_;
+
+ bool logging_multiple_connect_attempts_;
+
+ BoundNetLog net_log_;
+
+ DISALLOW_COPY_AND_ASSIGN(TCPSocketWin);
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_TCP_SOCKET_WIN_H_
+
diff --git a/chromium/net/socket/transport_client_socket_pool.cc b/chromium/net/socket/transport_client_socket_pool.cc
index 8255e988fa4..d03e3e651ac 100644
--- a/chromium/net/socket/transport_client_socket_pool.cc
+++ b/chromium/net/socket/transport_client_socket_pool.cc
@@ -48,25 +48,18 @@ bool AddressListOnlyContainsIPv6(const AddressList& list) {
TransportSocketParams::TransportSocketParams(
const HostPortPair& host_port_pair,
- RequestPriority priority,
bool disable_resolver_cache,
bool ignore_limits,
const OnHostResolutionCallback& host_resolution_callback)
: destination_(host_port_pair),
ignore_limits_(ignore_limits),
host_resolution_callback_(host_resolution_callback) {
- Initialize(priority, disable_resolver_cache);
-}
-
-TransportSocketParams::~TransportSocketParams() {}
-
-void TransportSocketParams::Initialize(RequestPriority priority,
- bool disable_resolver_cache) {
- destination_.set_priority(priority);
if (disable_resolver_cache)
destination_.set_allow_cached_response(false);
}
+TransportSocketParams::~TransportSocketParams() {}
+
// TransportConnectJobs will time out after this many seconds. Note this is
// the total time, including both host resolution and TCP connect() times.
//
@@ -80,13 +73,14 @@ static const int kTransportConnectJobTimeoutInSeconds = 240; // 4 minutes.
TransportConnectJob::TransportConnectJob(
const std::string& group_name,
+ RequestPriority priority,
const scoped_refptr<TransportSocketParams>& params,
base::TimeDelta timeout_duration,
ClientSocketFactory* client_socket_factory,
HostResolver* host_resolver,
Delegate* delegate,
NetLog* net_log)
- : ConnectJob(group_name, timeout_duration, delegate,
+ : ConnectJob(group_name, timeout_duration, priority, delegate,
BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)),
params_(params),
client_socket_factory_(client_socket_factory),
@@ -107,10 +101,11 @@ LoadState TransportConnectJob::GetLoadState() const {
case STATE_TRANSPORT_CONNECT:
case STATE_TRANSPORT_CONNECT_COMPLETE:
return LOAD_STATE_CONNECTING;
- default:
- NOTREACHED();
+ case STATE_NONE:
return LOAD_STATE_IDLE;
}
+ NOTREACHED();
+ return LOAD_STATE_IDLE;
}
// static
@@ -166,7 +161,9 @@ int TransportConnectJob::DoResolveHost() {
connect_timing_.dns_start = base::TimeTicks::Now();
return resolver_.Resolve(
- params_->destination(), &addresses_,
+ params_->destination(),
+ priority(),
+ &addresses_,
base::Bind(&TransportConnectJob::OnIOComplete, base::Unretained(this)),
net_log());
}
@@ -190,8 +187,8 @@ int TransportConnectJob::DoResolveHostComplete(int result) {
int TransportConnectJob::DoTransportConnect() {
next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE;
- transport_socket_.reset(client_socket_factory_->CreateTransportClientSocket(
- addresses_, net_log().net_log(), net_log().source()));
+ transport_socket_ = client_socket_factory_->CreateTransportClientSocket(
+ addresses_, net_log().net_log(), net_log().source());
int rv = transport_socket_->Connect(
base::Bind(&TransportConnectJob::OnIOComplete, base::Unretained(this)));
if (rv == ERR_IO_PENDING &&
@@ -246,7 +243,7 @@ int TransportConnectJob::DoTransportConnectComplete(int result) {
100);
}
}
- set_socket(transport_socket_.release());
+ SetSocket(transport_socket_.Pass());
fallback_timer_.Stop();
} else {
// Be a bit paranoid and kill off the fallback members to prevent reuse.
@@ -270,9 +267,9 @@ void TransportConnectJob::DoIPv6FallbackTransportConnect() {
fallback_addresses_.reset(new AddressList(addresses_));
MakeAddressListStartWithIPv4(fallback_addresses_.get());
- fallback_transport_socket_.reset(
+ fallback_transport_socket_ =
client_socket_factory_->CreateTransportClientSocket(
- *fallback_addresses_, net_log().net_log(), net_log().source()));
+ *fallback_addresses_, net_log().net_log(), net_log().source());
fallback_connect_start_time_ = base::TimeTicks::Now();
int rv = fallback_transport_socket_->Connect(
base::Bind(
@@ -317,7 +314,7 @@ void TransportConnectJob::DoIPv6FallbackTransportConnectComplete(int result) {
base::TimeDelta::FromMilliseconds(1),
base::TimeDelta::FromMinutes(10),
100);
- set_socket(fallback_transport_socket_.release());
+ SetSocket(fallback_transport_socket_.Pass());
next_state_ = STATE_NONE;
transport_socket_.reset();
} else {
@@ -333,18 +330,20 @@ int TransportConnectJob::ConnectInternal() {
return DoLoop(OK);
}
-ConnectJob*
+scoped_ptr<ConnectJob>
TransportClientSocketPool::TransportConnectJobFactory::NewConnectJob(
const std::string& group_name,
const PoolBase::Request& request,
ConnectJob::Delegate* delegate) const {
- return new TransportConnectJob(group_name,
- request.params(),
- ConnectionTimeout(),
- client_socket_factory_,
- host_resolver_,
- delegate,
- net_log_);
+ return scoped_ptr<ConnectJob>(
+ new TransportConnectJob(group_name,
+ request.priority(),
+ request.params(),
+ ConnectionTimeout(),
+ client_socket_factory_,
+ host_resolver_,
+ delegate,
+ net_log_));
}
base::TimeDelta
@@ -360,11 +359,11 @@ TransportClientSocketPool::TransportClientSocketPool(
HostResolver* host_resolver,
ClientSocketFactory* client_socket_factory,
NetLog* net_log)
- : base_(max_sockets, max_sockets_per_group, histograms,
+ : base_(NULL, max_sockets, max_sockets_per_group, histograms,
ClientSocketPool::unused_idle_socket_timeout(),
ClientSocketPool::used_idle_socket_timeout(),
new TransportConnectJobFactory(client_socket_factory,
- host_resolver, net_log)) {
+ host_resolver, net_log)) {
base_.EnableConnectBackupJobs();
}
@@ -419,19 +418,15 @@ void TransportClientSocketPool::CancelRequest(
void TransportClientSocketPool::ReleaseSocket(
const std::string& group_name,
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
int id) {
- base_.ReleaseSocket(group_name, socket, id);
+ base_.ReleaseSocket(group_name, socket.Pass(), id);
}
void TransportClientSocketPool::FlushWithError(int error) {
base_.FlushWithError(error);
}
-bool TransportClientSocketPool::IsStalled() const {
- return base_.IsStalled();
-}
-
void TransportClientSocketPool::CloseIdleSockets() {
base_.CloseIdleSockets();
}
@@ -450,14 +445,6 @@ LoadState TransportClientSocketPool::GetLoadState(
return base_.GetLoadState(group_name, handle);
}
-void TransportClientSocketPool::AddLayeredPool(LayeredPool* layered_pool) {
- base_.AddLayeredPool(layered_pool);
-}
-
-void TransportClientSocketPool::RemoveLayeredPool(LayeredPool* layered_pool) {
- base_.RemoveLayeredPool(layered_pool);
-}
-
base::DictionaryValue* TransportClientSocketPool::GetInfoAsValue(
const std::string& name,
const std::string& type,
@@ -473,4 +460,18 @@ ClientSocketPoolHistograms* TransportClientSocketPool::histograms() const {
return base_.histograms();
}
+bool TransportClientSocketPool::IsStalled() const {
+ return base_.IsStalled();
+}
+
+void TransportClientSocketPool::AddHigherLayeredPool(
+ HigherLayeredPool* higher_pool) {
+ base_.AddHigherLayeredPool(higher_pool);
+}
+
+void TransportClientSocketPool::RemoveHigherLayeredPool(
+ HigherLayeredPool* higher_pool) {
+ base_.RemoveHigherLayeredPool(higher_pool);
+}
+
} // namespace net
diff --git a/chromium/net/socket/transport_client_socket_pool.h b/chromium/net/socket/transport_client_socket_pool.h
index bb53b3da301..16e421a4550 100644
--- a/chromium/net/socket/transport_client_socket_pool.h
+++ b/chromium/net/socket/transport_client_socket_pool.h
@@ -34,7 +34,6 @@ class NET_EXPORT_PRIVATE TransportSocketParams
// connection will be aborted with that value.
TransportSocketParams(
const HostPortPair& host_port_pair,
- RequestPriority priority,
bool disable_resolver_cache,
bool ignore_limits,
const OnHostResolutionCallback& host_resolution_callback);
@@ -49,8 +48,6 @@ class NET_EXPORT_PRIVATE TransportSocketParams
friend class base::RefCounted<TransportSocketParams>;
~TransportSocketParams();
- void Initialize(RequestPriority priority, bool disable_resolver_cache);
-
HostResolver::RequestInfo destination_;
bool ignore_limits_;
const OnHostResolutionCallback host_resolution_callback_;
@@ -69,6 +66,7 @@ class NET_EXPORT_PRIVATE TransportSocketParams
class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob {
public:
TransportConnectJob(const std::string& group_name,
+ RequestPriority priority,
const scoped_refptr<TransportSocketParams>& params,
base::TimeDelta timeout_duration,
ClientSocketFactory* client_socket_factory,
@@ -132,6 +130,8 @@ class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob {
class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool {
public:
+ typedef TransportSocketParams SocketParams;
+
TransportClientSocketPool(
int max_sockets,
int max_sockets_per_group,
@@ -156,10 +156,9 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool {
virtual void CancelRequest(const std::string& group_name,
ClientSocketHandle* handle) OVERRIDE;
virtual void ReleaseSocket(const std::string& group_name,
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
int id) OVERRIDE;
virtual void FlushWithError(int error) OVERRIDE;
- virtual bool IsStalled() const OVERRIDE;
virtual void CloseIdleSockets() OVERRIDE;
virtual int IdleSocketCount() const OVERRIDE;
virtual int IdleSocketCountInGroup(
@@ -167,8 +166,6 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool {
virtual LoadState GetLoadState(
const std::string& group_name,
const ClientSocketHandle* handle) const OVERRIDE;
- virtual void AddLayeredPool(LayeredPool* layered_pool) OVERRIDE;
- virtual void RemoveLayeredPool(LayeredPool* layered_pool) OVERRIDE;
virtual base::DictionaryValue* GetInfoAsValue(
const std::string& name,
const std::string& type,
@@ -176,6 +173,11 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool {
virtual base::TimeDelta ConnectionTimeout() const OVERRIDE;
virtual ClientSocketPoolHistograms* histograms() const OVERRIDE;
+ // HigherLayeredPool implementation.
+ virtual bool IsStalled() const OVERRIDE;
+ virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE;
+ virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE;
+
private:
typedef ClientSocketPoolBase<TransportSocketParams> PoolBase;
@@ -193,7 +195,7 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool {
// ClientSocketPoolBase::ConnectJobFactory methods.
- virtual ConnectJob* NewConnectJob(
+ virtual scoped_ptr<ConnectJob> NewConnectJob(
const std::string& group_name,
const PoolBase::Request& request,
ConnectJob::Delegate* delegate) const OVERRIDE;
@@ -213,9 +215,6 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool {
DISALLOW_COPY_AND_ASSIGN(TransportClientSocketPool);
};
-REGISTER_SOCKET_PARAMS_FOR_POOL(TransportClientSocketPool,
- TransportSocketParams);
-
} // namespace net
#endif // NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_
diff --git a/chromium/net/socket/transport_client_socket_pool_unittest.cc b/chromium/net/socket/transport_client_socket_pool_unittest.cc
index dfa1151b291..a984ea3b740 100644
--- a/chromium/net/socket/transport_client_socket_pool_unittest.cc
+++ b/chromium/net/socket/transport_client_socket_pool_unittest.cc
@@ -23,6 +23,7 @@
#include "net/socket/client_socket_handle.h"
#include "net/socket/client_socket_pool_histograms.h"
#include "net/socket/socket_test_util.h"
+#include "net/socket/ssl_client_socket.h"
#include "net/socket/stream_socket.h"
#include "testing/gtest/include/gtest/gtest.h"
@@ -340,16 +341,16 @@ class MockClientSocketFactory : public ClientSocketFactory {
delay_(base::TimeDelta::FromMilliseconds(
ClientSocketPool::kMaxConnectRetryIntervalMs)) {}
- virtual DatagramClientSocket* CreateDatagramClientSocket(
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
NetLog* net_log,
const NetLog::Source& source) OVERRIDE {
NOTREACHED();
- return NULL;
+ return scoped_ptr<DatagramClientSocket>();
}
- virtual StreamSocket* CreateTransportClientSocket(
+ virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
const AddressList& addresses,
NetLog* /* net_log */,
const NetLog::Source& /* source */) OVERRIDE {
@@ -363,34 +364,41 @@ class MockClientSocketFactory : public ClientSocketFactory {
switch (type) {
case MOCK_CLIENT_SOCKET:
- return new MockClientSocket(addresses, net_log_);
+ return scoped_ptr<StreamSocket>(
+ new MockClientSocket(addresses, net_log_));
case MOCK_FAILING_CLIENT_SOCKET:
- return new MockFailingClientSocket(addresses, net_log_);
+ return scoped_ptr<StreamSocket>(
+ new MockFailingClientSocket(addresses, net_log_));
case MOCK_PENDING_CLIENT_SOCKET:
- return new MockPendingClientSocket(
- addresses, true, false, base::TimeDelta(), net_log_);
+ return scoped_ptr<StreamSocket>(
+ new MockPendingClientSocket(
+ addresses, true, false, base::TimeDelta(), net_log_));
case MOCK_PENDING_FAILING_CLIENT_SOCKET:
- return new MockPendingClientSocket(
- addresses, false, false, base::TimeDelta(), net_log_);
+ return scoped_ptr<StreamSocket>(
+ new MockPendingClientSocket(
+ addresses, false, false, base::TimeDelta(), net_log_));
case MOCK_DELAYED_CLIENT_SOCKET:
- return new MockPendingClientSocket(
- addresses, true, false, delay_, net_log_);
+ return scoped_ptr<StreamSocket>(
+ new MockPendingClientSocket(
+ addresses, true, false, delay_, net_log_));
case MOCK_STALLED_CLIENT_SOCKET:
- return new MockPendingClientSocket(
- addresses, true, true, base::TimeDelta(), net_log_);
+ return scoped_ptr<StreamSocket>(
+ new MockPendingClientSocket(
+ addresses, true, true, base::TimeDelta(), net_log_));
default:
NOTREACHED();
- return new MockClientSocket(addresses, net_log_);
+ return scoped_ptr<StreamSocket>(
+ new MockClientSocket(addresses, net_log_));
}
}
- virtual SSLClientSocket* CreateSSLClientSocket(
- ClientSocketHandle* transport_socket,
+ virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context) OVERRIDE {
NOTIMPLEMENTED();
- return NULL;
+ return scoped_ptr<SSLClientSocket>();
}
virtual void ClearSSLSessionCache() OVERRIDE {
@@ -431,11 +439,7 @@ class TransportClientSocketPoolTest : public testing::Test {
ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(true)),
params_(
new TransportSocketParams(HostPortPair("www.google.com", 80),
- kDefaultPriority, false, false,
- OnHostResolutionCallback())),
- low_params_(
- new TransportSocketParams(HostPortPair("www.google.com", 80),
- LOW, false, false,
+ false, false,
OnHostResolutionCallback())),
histograms_(new ClientSocketPoolHistograms("TCPUnitTest")),
host_resolver_(new MockHostResolver),
@@ -455,7 +459,7 @@ class TransportClientSocketPoolTest : public testing::Test {
int StartRequest(const std::string& group_name, RequestPriority priority) {
scoped_refptr<TransportSocketParams> params(new TransportSocketParams(
- HostPortPair("www.google.com", 80), MEDIUM, false, false,
+ HostPortPair("www.google.com", 80), false, false,
OnHostResolutionCallback()));
return test_base_.StartRequestUsingPool(
&pool_, group_name, priority, params);
@@ -479,7 +483,6 @@ class TransportClientSocketPoolTest : public testing::Test {
bool connect_backup_jobs_enabled_;
CapturingNetLog net_log_;
scoped_refptr<TransportSocketParams> params_;
- scoped_refptr<TransportSocketParams> low_params_;
scoped_ptr<ClientSocketPoolHistograms> histograms_;
scoped_ptr<MockHostResolver> host_resolver_;
MockClientSocketFactory client_socket_factory_;
@@ -561,7 +564,7 @@ TEST(TransportConnectJobTest, MakeAddrListStartWithIPv4) {
TEST_F(TransportClientSocketPoolTest, Basic) {
TestCompletionCallback callback;
ClientSocketHandle handle;
- int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool_,
+ int rv = handle.Init("a", params_, LOW, callback.callback(), &pool_,
BoundNetLog());
EXPECT_EQ(ERR_IO_PENDING, rv);
EXPECT_FALSE(handle.is_initialized());
@@ -573,13 +576,27 @@ TEST_F(TransportClientSocketPoolTest, Basic) {
TestLoadTimingInfoConnectedNotReused(handle);
}
+// Make sure that TransportConnectJob passes on its priority to its
+// HostResolver request on Init.
+TEST_F(TransportClientSocketPoolTest, SetResolvePriorityOnInit) {
+ for (int i = MINIMUM_PRIORITY; i < NUM_PRIORITIES; ++i) {
+ RequestPriority priority = static_cast<RequestPriority>(i);
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle.Init("a", params_, priority, callback.callback(), &pool_,
+ BoundNetLog()));
+ EXPECT_EQ(priority, host_resolver_->last_request_priority());
+ }
+}
+
TEST_F(TransportClientSocketPoolTest, InitHostResolutionFailure) {
host_resolver_->rules()->AddSimulatedFailure("unresolvable.host.name");
TestCompletionCallback callback;
ClientSocketHandle handle;
HostPortPair host_port_pair("unresolvable.host.name", 80);
scoped_refptr<TransportSocketParams> dest(new TransportSocketParams(
- host_port_pair, kDefaultPriority, false, false,
+ host_port_pair, false, false,
OnHostResolutionCallback()));
EXPECT_EQ(ERR_IO_PENDING,
handle.Init("a", dest, kDefaultPriority, callback.callback(),
@@ -854,7 +871,7 @@ class RequestSocketCallback : public TestCompletionCallbackBase {
}
within_callback_ = true;
scoped_refptr<TransportSocketParams> dest(new TransportSocketParams(
- HostPortPair("www.google.com", 80), LOWEST, false, false,
+ HostPortPair("www.google.com", 80), false, false,
OnHostResolutionCallback()));
int rv = handle_->Init("a", dest, LOWEST, callback(), pool_,
BoundNetLog());
@@ -874,7 +891,7 @@ TEST_F(TransportClientSocketPoolTest, RequestTwice) {
ClientSocketHandle handle;
RequestSocketCallback callback(&handle, &pool_);
scoped_refptr<TransportSocketParams> dest(new TransportSocketParams(
- HostPortPair("www.google.com", 80), LOWEST, false, false,
+ HostPortPair("www.google.com", 80), false, false,
OnHostResolutionCallback()));
int rv = handle.Init("a", dest, LOWEST, callback.callback(), &pool_,
BoundNetLog());
@@ -939,7 +956,7 @@ TEST_F(TransportClientSocketPoolTest, FailingActiveRequestWithPendingRequests) {
TEST_F(TransportClientSocketPoolTest, IdleSocketLoadTiming) {
TestCompletionCallback callback;
ClientSocketHandle handle;
- int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool_,
+ int rv = handle.Init("a", params_, LOW, callback.callback(), &pool_,
BoundNetLog());
EXPECT_EQ(ERR_IO_PENDING, rv);
EXPECT_FALSE(handle.is_initialized());
@@ -957,7 +974,7 @@ TEST_F(TransportClientSocketPoolTest, IdleSocketLoadTiming) {
// Now we should have 1 idle socket.
EXPECT_EQ(1, pool_.IdleSocketCount());
- rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool_,
+ rv = handle.Init("a", params_, LOW, callback.callback(), &pool_,
BoundNetLog());
EXPECT_EQ(OK, rv);
EXPECT_EQ(0, pool_.IdleSocketCount());
@@ -967,7 +984,7 @@ TEST_F(TransportClientSocketPoolTest, IdleSocketLoadTiming) {
TEST_F(TransportClientSocketPoolTest, ResetIdleSocketsOnIPAddressChange) {
TestCompletionCallback callback;
ClientSocketHandle handle;
- int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool_,
+ int rv = handle.Init("a", params_, LOW, callback.callback(), &pool_,
BoundNetLog());
EXPECT_EQ(ERR_IO_PENDING, rv);
EXPECT_FALSE(handle.is_initialized());
@@ -1023,7 +1040,7 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketConnect) {
TestCompletionCallback callback;
ClientSocketHandle handle;
- int rv = handle.Init("b", low_params_, LOW, callback.callback(), &pool_,
+ int rv = handle.Init("b", params_, LOW, callback.callback(), &pool_,
BoundNetLog());
EXPECT_EQ(ERR_IO_PENDING, rv);
EXPECT_FALSE(handle.is_initialized());
@@ -1065,7 +1082,7 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketCancel) {
TestCompletionCallback callback;
ClientSocketHandle handle;
- int rv = handle.Init("c", low_params_, LOW, callback.callback(), &pool_,
+ int rv = handle.Init("c", params_, LOW, callback.callback(), &pool_,
BoundNetLog());
EXPECT_EQ(ERR_IO_PENDING, rv);
EXPECT_FALSE(handle.is_initialized());
@@ -1111,7 +1128,7 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterStall) {
TestCompletionCallback callback;
ClientSocketHandle handle;
- int rv = handle.Init("b", low_params_, LOW, callback.callback(), &pool_,
+ int rv = handle.Init("b", params_, LOW, callback.callback(), &pool_,
BoundNetLog());
EXPECT_EQ(ERR_IO_PENDING, rv);
EXPECT_FALSE(handle.is_initialized());
@@ -1159,7 +1176,7 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterDelay) {
TestCompletionCallback callback;
ClientSocketHandle handle;
- int rv = handle.Init("b", low_params_, LOW, callback.callback(), &pool_,
+ int rv = handle.Init("b", params_, LOW, callback.callback(), &pool_,
BoundNetLog());
EXPECT_EQ(ERR_IO_PENDING, rv);
EXPECT_FALSE(handle.is_initialized());
@@ -1215,7 +1232,7 @@ TEST_F(TransportClientSocketPoolTest, IPv6FallbackSocketIPv4FinishesFirst) {
TestCompletionCallback callback;
ClientSocketHandle handle;
- int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool,
+ int rv = handle.Init("a", params_, LOW, callback.callback(), &pool,
BoundNetLog());
EXPECT_EQ(ERR_IO_PENDING, rv);
EXPECT_FALSE(handle.is_initialized());
@@ -1260,7 +1277,7 @@ TEST_F(TransportClientSocketPoolTest, IPv6FallbackSocketIPv6FinishesFirst) {
TestCompletionCallback callback;
ClientSocketHandle handle;
- int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool,
+ int rv = handle.Init("a", params_, LOW, callback.callback(), &pool,
BoundNetLog());
EXPECT_EQ(ERR_IO_PENDING, rv);
EXPECT_FALSE(handle.is_initialized());
@@ -1294,7 +1311,7 @@ TEST_F(TransportClientSocketPoolTest, IPv6NoIPv4AddressesToFallbackTo) {
TestCompletionCallback callback;
ClientSocketHandle handle;
- int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool,
+ int rv = handle.Init("a", params_, LOW, callback.callback(), &pool,
BoundNetLog());
EXPECT_EQ(ERR_IO_PENDING, rv);
EXPECT_FALSE(handle.is_initialized());
@@ -1327,7 +1344,7 @@ TEST_F(TransportClientSocketPoolTest, IPv4HasNoFallback) {
TestCompletionCallback callback;
ClientSocketHandle handle;
- int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool,
+ int rv = handle.Init("a", params_, LOW, callback.callback(), &pool,
BoundNetLog());
EXPECT_EQ(ERR_IO_PENDING, rv);
EXPECT_FALSE(handle.is_initialized());
diff --git a/chromium/net/socket/transport_client_socket_unittest.cc b/chromium/net/socket/transport_client_socket_unittest.cc
index 2f75e740067..5548b27b995 100644
--- a/chromium/net/socket/transport_client_socket_unittest.cc
+++ b/chromium/net/socket/transport_client_socket_unittest.cc
@@ -48,8 +48,9 @@ class TransportClientSocketTest
// Implement StreamListenSocket::Delegate methods
virtual void DidAccept(StreamListenSocket* server,
- StreamListenSocket* connection) OVERRIDE {
- connected_sock_ = reinterpret_cast<TCPListenSocket*>(connection);
+ scoped_ptr<StreamListenSocket> connection) OVERRIDE {
+ connected_sock_.reset(
+ static_cast<TCPListenSocket*>(connection.release()));
}
virtual void DidRead(StreamListenSocket*, const char* str, int len) OVERRIDE {
// TODO(dkegel): this might not be long enough to tickle some bugs.
@@ -65,7 +66,7 @@ class TransportClientSocketTest
void CloseServerSocket() {
// delete the connected_sock_, which will close it.
- connected_sock_ = NULL;
+ connected_sock_.reset();
}
void PauseServerReads() {
@@ -94,8 +95,8 @@ class TransportClientSocketTest
scoped_ptr<StreamSocket> sock_;
private:
- scoped_refptr<TCPListenSocket> listen_sock_;
- scoped_refptr<TCPListenSocket> connected_sock_;
+ scoped_ptr<TCPListenSocket> listen_sock_;
+ scoped_ptr<TCPListenSocket> connected_sock_;
bool close_server_socket_on_next_send_;
};
@@ -103,7 +104,7 @@ void TransportClientSocketTest::SetUp() {
::testing::TestWithParam<ClientSocketTestTypes>::SetUp();
// Find a free port to listen on
- scoped_refptr<TCPListenSocket> sock;
+ scoped_ptr<TCPListenSocket> sock;
int port;
// Range of ports to listen on. Shouldn't need to try many.
const int kMinPort = 10100;
@@ -117,7 +118,7 @@ void TransportClientSocketTest::SetUp() {
break;
}
ASSERT_TRUE(sock.get() != NULL);
- listen_sock_ = sock;
+ listen_sock_ = sock.Pass();
listen_port_ = port;
AddressList addr;
@@ -125,15 +126,15 @@ void TransportClientSocketTest::SetUp() {
scoped_ptr<HostResolver> resolver(new MockHostResolver());
HostResolver::RequestInfo info(HostPortPair("localhost", listen_port_));
TestCompletionCallback callback;
- int rv = resolver->Resolve(info, &addr, callback.callback(), NULL,
- BoundNetLog());
+ int rv = resolver->Resolve(
+ info, DEFAULT_PRIORITY, &addr, callback.callback(), NULL, BoundNetLog());
CHECK_EQ(ERR_IO_PENDING, rv);
rv = callback.WaitForResult();
CHECK_EQ(rv, OK);
- sock_.reset(
+ sock_ =
socket_factory_->CreateTransportClientSocket(addr,
&net_log_,
- NetLog::Source()));
+ NetLog::Source());
}
int TransportClientSocketTest::DrainClientSocket(
diff --git a/chromium/net/socket/unix_domain_socket_posix.cc b/chromium/net/socket/unix_domain_socket_posix.cc
index 5b6b2498245..2b781d58b35 100644
--- a/chromium/net/socket/unix_domain_socket_posix.cc
+++ b/chromium/net/socket/unix_domain_socket_posix.cc
@@ -21,6 +21,7 @@
#include "build/build_config.h"
#include "net/base/net_errors.h"
#include "net/base/net_util.h"
+#include "net/socket/socket_descriptor.h"
namespace net {
@@ -48,12 +49,12 @@ bool GetPeerIds(int socket, uid_t* user_id, gid_t* group_id) {
} // namespace
// static
-UnixDomainSocket::AuthCallback NoAuthentication() {
+UnixDomainSocket::AuthCallback UnixDomainSocket::NoAuthentication() {
return base::Bind(NoAuthenticationCallback);
}
// static
-UnixDomainSocket* UnixDomainSocket::CreateAndListenInternal(
+scoped_ptr<UnixDomainSocket> UnixDomainSocket::CreateAndListenInternal(
const std::string& path,
const std::string& fallback_path,
StreamListenSocket::Delegate* del,
@@ -63,14 +64,15 @@ UnixDomainSocket* UnixDomainSocket::CreateAndListenInternal(
if (s == kInvalidSocket && !fallback_path.empty())
s = CreateAndBind(fallback_path, use_abstract_namespace);
if (s == kInvalidSocket)
- return NULL;
- UnixDomainSocket* sock = new UnixDomainSocket(s, del, auth_callback);
+ return scoped_ptr<UnixDomainSocket>();
+ scoped_ptr<UnixDomainSocket> sock(
+ new UnixDomainSocket(s, del, auth_callback));
sock->Listen();
- return sock;
+ return sock.Pass();
}
// static
-scoped_refptr<UnixDomainSocket> UnixDomainSocket::CreateAndListen(
+scoped_ptr<UnixDomainSocket> UnixDomainSocket::CreateAndListen(
const std::string& path,
StreamListenSocket::Delegate* del,
const AuthCallback& auth_callback) {
@@ -79,14 +81,14 @@ scoped_refptr<UnixDomainSocket> UnixDomainSocket::CreateAndListen(
#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
// static
-scoped_refptr<UnixDomainSocket>
+scoped_ptr<UnixDomainSocket>
UnixDomainSocket::CreateAndListenWithAbstractNamespace(
const std::string& path,
const std::string& fallback_path,
StreamListenSocket::Delegate* del,
const AuthCallback& auth_callback) {
- return make_scoped_refptr(
- CreateAndListenInternal(path, fallback_path, del, auth_callback, true));
+ return
+ CreateAndListenInternal(path, fallback_path, del, auth_callback, true);
}
#endif
@@ -106,7 +108,7 @@ SocketDescriptor UnixDomainSocket::CreateAndBind(const std::string& path,
static const size_t kPathMax = sizeof(addr.sun_path);
if (use_abstract_namespace + path.size() + 1 /* '\0' */ > kPathMax)
return kInvalidSocket;
- const SocketDescriptor s = socket(PF_UNIX, SOCK_STREAM, 0);
+ const SocketDescriptor s = CreatePlatformSocket(PF_UNIX, SOCK_STREAM, 0);
if (s == kInvalidSocket)
return kInvalidSocket;
memset(&addr, 0, sizeof(addr));
@@ -147,11 +149,11 @@ void UnixDomainSocket::Accept() {
LOG(ERROR) << "close() error";
return;
}
- scoped_refptr<UnixDomainSocket> sock(
+ scoped_ptr<UnixDomainSocket> sock(
new UnixDomainSocket(conn, socket_delegate_, auth_callback_));
// It's up to the delegate to AddRef if it wants to keep it around.
sock->WatchSocket(WAITING_READ);
- socket_delegate_->DidAccept(this, sock.get());
+ socket_delegate_->DidAccept(this, sock.PassAs<StreamListenSocket>());
}
UnixDomainSocketFactory::UnixDomainSocketFactory(
@@ -162,10 +164,10 @@ UnixDomainSocketFactory::UnixDomainSocketFactory(
UnixDomainSocketFactory::~UnixDomainSocketFactory() {}
-scoped_refptr<StreamListenSocket> UnixDomainSocketFactory::CreateAndListen(
+scoped_ptr<StreamListenSocket> UnixDomainSocketFactory::CreateAndListen(
StreamListenSocket::Delegate* delegate) const {
return UnixDomainSocket::CreateAndListen(
- path_, delegate, auth_callback_);
+ path_, delegate, auth_callback_).PassAs<StreamListenSocket>();
}
#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
@@ -181,11 +183,12 @@ UnixDomainSocketWithAbstractNamespaceFactory(
UnixDomainSocketWithAbstractNamespaceFactory::
~UnixDomainSocketWithAbstractNamespaceFactory() {}
-scoped_refptr<StreamListenSocket>
+scoped_ptr<StreamListenSocket>
UnixDomainSocketWithAbstractNamespaceFactory::CreateAndListen(
StreamListenSocket::Delegate* delegate) const {
return UnixDomainSocket::CreateAndListenWithAbstractNamespace(
- path_, fallback_path_, delegate, auth_callback_);
+ path_, fallback_path_, delegate, auth_callback_)
+ .PassAs<StreamListenSocket>();
}
#endif
diff --git a/chromium/net/socket/unix_domain_socket_posix.h b/chromium/net/socket/unix_domain_socket_posix.h
index 2ef06803d24..98d0c11a648 100644
--- a/chromium/net/socket/unix_domain_socket_posix.h
+++ b/chromium/net/socket/unix_domain_socket_posix.h
@@ -10,7 +10,6 @@
#include "base/basictypes.h"
#include "base/callback_forward.h"
#include "base/compiler_specific.h"
-#include "base/memory/ref_counted.h"
#include "build/build_config.h"
#include "net/base/net_export.h"
#include "net/socket/stream_listen_socket.h"
@@ -26,6 +25,8 @@ namespace net {
// Unix Domain Socket Implementation. Supports abstract namespaces on Linux.
class NET_EXPORT UnixDomainSocket : public StreamListenSocket {
public:
+ virtual ~UnixDomainSocket();
+
// Callback that returns whether the already connected client, identified by
// its process |user_id| and |group_id|, is allowed to keep the connection
// open. Note that the socket is closed immediately in case the callback
@@ -38,7 +39,7 @@ class NET_EXPORT UnixDomainSocket : public StreamListenSocket {
// Note that the returned UnixDomainSocket instance does not take ownership of
// |del|.
- static scoped_refptr<UnixDomainSocket> CreateAndListen(
+ static scoped_ptr<UnixDomainSocket> CreateAndListen(
const std::string& path,
StreamListenSocket::Delegate* del,
const AuthCallback& auth_callback);
@@ -47,7 +48,7 @@ class NET_EXPORT UnixDomainSocket : public StreamListenSocket {
// Same as above except that the created socket uses the abstract namespace
// which is a Linux-only feature. If |fallback_path| is not empty,
// make the second attempt with the provided fallback name.
- static scoped_refptr<UnixDomainSocket> CreateAndListenWithAbstractNamespace(
+ static scoped_ptr<UnixDomainSocket> CreateAndListenWithAbstractNamespace(
const std::string& path,
const std::string& fallback_path,
StreamListenSocket::Delegate* del,
@@ -58,9 +59,8 @@ class NET_EXPORT UnixDomainSocket : public StreamListenSocket {
UnixDomainSocket(SocketDescriptor s,
StreamListenSocket::Delegate* del,
const AuthCallback& auth_callback);
- virtual ~UnixDomainSocket();
- static UnixDomainSocket* CreateAndListenInternal(
+ static scoped_ptr<UnixDomainSocket> CreateAndListenInternal(
const std::string& path,
const std::string& fallback_path,
StreamListenSocket::Delegate* del,
@@ -87,7 +87,7 @@ class NET_EXPORT UnixDomainSocketFactory : public StreamListenSocketFactory {
virtual ~UnixDomainSocketFactory();
// StreamListenSocketFactory:
- virtual scoped_refptr<StreamListenSocket> CreateAndListen(
+ virtual scoped_ptr<StreamListenSocket> CreateAndListen(
StreamListenSocket::Delegate* delegate) const OVERRIDE;
protected:
@@ -111,7 +111,7 @@ class NET_EXPORT UnixDomainSocketWithAbstractNamespaceFactory
virtual ~UnixDomainSocketWithAbstractNamespaceFactory();
// UnixDomainSocketFactory:
- virtual scoped_refptr<StreamListenSocket> CreateAndListen(
+ virtual scoped_ptr<StreamListenSocket> CreateAndListen(
StreamListenSocket::Delegate* delegate) const OVERRIDE;
private:
diff --git a/chromium/net/socket/unix_domain_socket_posix_unittest.cc b/chromium/net/socket/unix_domain_socket_posix_unittest.cc
index 5abe03b4ae3..f062d274205 100644
--- a/chromium/net/socket/unix_domain_socket_posix_unittest.cc
+++ b/chromium/net/socket/unix_domain_socket_posix_unittest.cc
@@ -29,6 +29,7 @@
#include "base/synchronization/lock.h"
#include "base/threading/platform_thread.h"
#include "base/threading/thread.h"
+#include "net/socket/socket_descriptor.h"
#include "net/socket/unix_domain_socket_posix.h"
#include "testing/gtest/include/gtest/gtest.h"
@@ -102,9 +103,9 @@ class TestListenSocketDelegate : public StreamListenSocket::Delegate {
: event_manager_(event_manager) {}
virtual void DidAccept(StreamListenSocket* server,
- StreamListenSocket* connection) OVERRIDE {
+ scoped_ptr<StreamListenSocket> connection) OVERRIDE {
LOG(ERROR) << __PRETTY_FUNCTION__;
- connection_ = connection;
+ connection_ = connection.Pass();
Notify(EVENT_ACCEPT);
}
@@ -138,7 +139,7 @@ class TestListenSocketDelegate : public StreamListenSocket::Delegate {
}
const scoped_refptr<EventManager> event_manager_;
- scoped_refptr<StreamListenSocket> connection_;
+ scoped_ptr<StreamListenSocket> connection_;
base::Lock mutex_;
string data_;
};
@@ -172,7 +173,7 @@ class UnixDomainSocketTestHelper : public testing::Test {
virtual void TearDown() OVERRIDE {
DeleteSocketFile();
- socket_ = NULL;
+ socket_.reset();
socket_delegate_.reset();
event_manager_ = NULL;
}
@@ -187,10 +188,10 @@ class UnixDomainSocketTestHelper : public testing::Test {
}
SocketDescriptor CreateClientSocket() {
- const SocketDescriptor sock = socket(PF_UNIX, SOCK_STREAM, 0);
+ const SocketDescriptor sock = CreatePlatformSocket(PF_UNIX, SOCK_STREAM, 0);
if (sock < 0) {
LOG(ERROR) << "socket() error";
- return StreamListenSocket::kInvalidSocket;
+ return kInvalidSocket;
}
sockaddr_un addr;
memset(&addr, 0, sizeof(addr));
@@ -200,7 +201,7 @@ class UnixDomainSocketTestHelper : public testing::Test {
addr_len = sizeof(sockaddr_un);
if (connect(sock, reinterpret_cast<sockaddr*>(&addr), addr_len) != 0) {
LOG(ERROR) << "connect() error";
- return StreamListenSocket::kInvalidSocket;
+ return kInvalidSocket;
}
return sock;
}
@@ -221,7 +222,7 @@ class UnixDomainSocketTestHelper : public testing::Test {
const bool allow_user_;
scoped_refptr<EventManager> event_manager_;
scoped_ptr<TestListenSocketDelegate> socket_delegate_;
- scoped_refptr<UnixDomainSocket> socket_;
+ scoped_ptr<UnixDomainSocket> socket_;
};
class UnixDomainSocketTest : public UnixDomainSocketTestHelper {
@@ -264,7 +265,7 @@ TEST_F(UnixDomainSocketTestWithInvalidPath,
}
TEST_F(UnixDomainSocketTest, TestFallbackName) {
- scoped_refptr<UnixDomainSocket> existing_socket =
+ scoped_ptr<UnixDomainSocket> existing_socket =
UnixDomainSocket::CreateAndListenWithAbstractNamespace(
file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback());
EXPECT_FALSE(existing_socket.get() == NULL);
@@ -280,7 +281,6 @@ TEST_F(UnixDomainSocketTest, TestFallbackName) {
socket_delegate_.get(),
MakeAuthCallback());
EXPECT_FALSE(socket_.get() == NULL);
- existing_socket = NULL;
}
#endif
@@ -291,7 +291,7 @@ TEST_F(UnixDomainSocketTest, TestWithClient) {
// Create the client socket.
const SocketDescriptor sock = CreateClientSocket();
- ASSERT_NE(StreamListenSocket::kInvalidSocket, sock);
+ ASSERT_NE(kInvalidSocket, sock);
event = event_manager_->WaitForEvent();
ASSERT_EQ(EVENT_AUTH_GRANTED, event);
event = event_manager_->WaitForEvent();
@@ -316,7 +316,7 @@ TEST_F(UnixDomainSocketTestWithForbiddenUser, TestWithForbiddenUser) {
EventType event = event_manager_->WaitForEvent();
ASSERT_EQ(EVENT_LISTEN, event);
const SocketDescriptor sock = CreateClientSocket();
- ASSERT_NE(StreamListenSocket::kInvalidSocket, sock);
+ ASSERT_NE(kInvalidSocket, sock);
event = event_manager_->WaitForEvent();
ASSERT_EQ(EVENT_AUTH_DENIED, event);