diff options
author | Andras Becsi <andras.becsi@digia.com> | 2014-03-18 13:16:26 +0100 |
---|---|---|
committer | Frederik Gladhorn <frederik.gladhorn@digia.com> | 2014-03-20 15:55:39 +0100 |
commit | 3f0f86b0caed75241fa71c95a5d73bc0164348c5 (patch) | |
tree | 92b9fb00f2e9e90b0be2262093876d4f43b6cd13 /chromium/net/socket | |
parent | e90d7c4b152c56919d963987e2503f9909a666d2 (diff) | |
download | qtwebengine-chromium-3f0f86b0caed75241fa71c95a5d73bc0164348c5.tar.gz |
Update to new stable branch 1750
This also includes an updated ninja and chromium dependencies
needed on Windows.
Change-Id: Icd597d80ed3fa4425933c9f1334c3c2e31291c42
Reviewed-by: Zoltan Arvai <zarvai@inf.u-szeged.hu>
Reviewed-by: Zeno Albisser <zeno.albisser@digia.com>
Diffstat (limited to 'chromium/net/socket')
44 files changed, 3311 insertions, 1393 deletions
diff --git a/chromium/net/socket/buffered_write_stream_socket.cc b/chromium/net/socket/buffered_write_stream_socket.cc index cf13c5e439a..87d3c337e24 100644 --- a/chromium/net/socket/buffered_write_stream_socket.cc +++ b/chromium/net/socket/buffered_write_stream_socket.cc @@ -27,10 +27,10 @@ BufferedWriteStreamSocket::BufferedWriteStreamSocket( : wrapped_socket_(socket_to_wrap.Pass()), io_buffer_(new GrowableIOBuffer()), backup_buffer_(new GrowableIOBuffer()), - weak_factory_(this), callback_pending_(false), wrapped_write_in_progress_(false), - error_(0) { + error_(0), + weak_factory_(this) { } BufferedWriteStreamSocket::~BufferedWriteStreamSocket() { diff --git a/chromium/net/socket/buffered_write_stream_socket.h b/chromium/net/socket/buffered_write_stream_socket.h index aad5736d0b0..803018eefe3 100644 --- a/chromium/net/socket/buffered_write_stream_socket.h +++ b/chromium/net/socket/buffered_write_stream_socket.h @@ -69,11 +69,12 @@ class NET_EXPORT_PRIVATE BufferedWriteStreamSocket : public StreamSocket { scoped_ptr<StreamSocket> wrapped_socket_; scoped_refptr<GrowableIOBuffer> io_buffer_; scoped_refptr<GrowableIOBuffer> backup_buffer_; - base::WeakPtrFactory<BufferedWriteStreamSocket> weak_factory_; bool callback_pending_; bool wrapped_write_in_progress_; int error_; + base::WeakPtrFactory<BufferedWriteStreamSocket> weak_factory_; + DISALLOW_COPY_AND_ASSIGN(BufferedWriteStreamSocket); }; diff --git a/chromium/net/socket/client_socket_factory.cc b/chromium/net/socket/client_socket_factory.cc index a86688e3333..953914581fc 100644 --- a/chromium/net/socket/client_socket_factory.cc +++ b/chromium/net/socket/client_socket_factory.cc @@ -60,10 +60,10 @@ class DefaultClientSocketFactory : public ClientSocketFactory, ClearSSLSessionCache(); } - virtual void OnCertTrustChanged(const X509Certificate* cert) OVERRIDE { + virtual void OnCACertChanged(const X509Certificate* cert) OVERRIDE { // Per wtc, we actually only need to flush when trust is reduced. - // Always flush now because OnCertTrustChanged does not tell us this. - // See comments in ClientSocketPoolManager::OnCertTrustChanged. + // Always flush now because OnCACertChanged does not tell us this. + // See comments in ClientSocketPoolManager::OnCACertChanged. ClearSSLSessionCache(); } diff --git a/chromium/net/socket/client_socket_pool_base.cc b/chromium/net/socket/client_socket_pool_base.cc index cec7956a0ee..1c79923a400 100644 --- a/chromium/net/socket/client_socket_pool_base.cc +++ b/chromium/net/socket/client_socket_pool_base.cc @@ -39,28 +39,6 @@ const int kCleanupInterval = 10; // DO NOT INCREASE THIS TIMEOUT. // after a certain timeout has passed without receiving an ACK. bool g_connect_backup_jobs_enabled = true; -// Compares the effective priority of two results, and returns 1 if |request1| -// has greater effective priority than |request2|, 0 if they have the same -// effective priority, and -1 if |request2| has the greater effective priority. -// Requests with |ignore_limits| set have higher effective priority than those -// without. If both requests have |ignore_limits| set/unset, then the request -// with the highest Pririoty has the highest effective priority. Does not take -// into account the fact that Requests are serviced in FIFO order if they would -// otherwise have the same priority. -int CompareEffectiveRequestPriority( - const internal::ClientSocketPoolBaseHelper::Request& request1, - const internal::ClientSocketPoolBaseHelper::Request& request2) { - if (request1.ignore_limits() && !request2.ignore_limits()) - return 1; - if (!request1.ignore_limits() && request2.ignore_limits()) - return -1; - if (request1.priority() > request2.priority()) - return 1; - if (request1.priority() < request2.priority()) - return -1; - return 0; -} - } // namespace ConnectJob::ConnectJob(const std::string& group_name, @@ -162,7 +140,10 @@ ClientSocketPoolBaseHelper::Request::Request( priority_(priority), ignore_limits_(ignore_limits), flags_(flags), - net_log_(net_log) {} + net_log_(net_log) { + if (ignore_limits_) + DCHECK_EQ(priority_, MAXIMUM_PRIORITY); +} ClientSocketPoolBaseHelper::Request::~Request() {} @@ -430,9 +411,8 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal( // If we don't have any sockets in this group, set a timer for potentially // creating a new one. If the SYN is lost, this backup socket may complete // before the slow socket, improving end user latency. - if (connect_backup_jobs_enabled_ && - group->IsEmpty() && !group->HasBackupJob()) { - group->StartBackupSocketTimer(group_name, this); + if (connect_backup_jobs_enabled_ && group->IsEmpty()) { + group->StartBackupJobTimer(group_name, this); } connecting_socket_count_++; @@ -625,8 +605,9 @@ base::DictionaryValue* ClientSocketPoolBaseHelper::GetInfoAsValue( group_dict->SetInteger("pending_request_count", group->pending_request_count()); if (group->has_pending_requests()) { - group_dict->SetInteger("top_pending_priority", - group->TopPendingPriority()); + group_dict->SetString( + "top_pending_priority", + RequestPriorityToString(group->TopPendingPriority())); } group_dict->SetInteger("active_socket_count", group->active_socket_count()); @@ -652,7 +633,8 @@ base::DictionaryValue* ClientSocketPoolBaseHelper::GetInfoAsValue( group_dict->SetBoolean("is_stalled", group->IsStalledOnPoolMaxSockets( max_sockets_per_group_)); - group_dict->SetBoolean("has_backup_job", group->HasBackupJob()); + group_dict->SetBoolean("backup_job_timer_is_running", + group->BackupJobTimerIsRunning()); all_groups_dict->SetWithoutPathExpansion(it->first, group_dict); } @@ -951,11 +933,6 @@ void ClientSocketPoolBaseHelper::RemoveConnectJob(ConnectJob* job, DCHECK(group); group->RemoveJob(job); - - // If we've got no more jobs for this group, then we no longer need a - // backup job either. - if (group->jobs().empty()) - group->CleanupBackupJob(); } void ClientSocketPoolBaseHelper::OnAvailableSocketSlot( @@ -1157,26 +1134,30 @@ void ClientSocketPoolBaseHelper::TryToCloseSocketsInLayeredPools() { ClientSocketPoolBaseHelper::Group::Group() : unassigned_job_count_(0), - active_socket_count_(0), - weak_factory_(this) {} + pending_requests_(NUM_PRIORITIES), + active_socket_count_(0) {} ClientSocketPoolBaseHelper::Group::~Group() { - CleanupBackupJob(); DCHECK_EQ(0u, unassigned_job_count_); } -void ClientSocketPoolBaseHelper::Group::StartBackupSocketTimer( +void ClientSocketPoolBaseHelper::Group::StartBackupJobTimer( const std::string& group_name, ClientSocketPoolBaseHelper* pool) { - // Only allow one timer pending to create a backup socket. - if (weak_factory_.HasWeakPtrs()) + // Only allow one timer to run at a time. + if (BackupJobTimerIsRunning()) return; - base::MessageLoop::current()->PostDelayedTask( - FROM_HERE, - base::Bind(&Group::OnBackupSocketTimerFired, weak_factory_.GetWeakPtr(), - group_name, pool), - pool->ConnectRetryInterval()); + // Unretained here is okay because |backup_job_timer_| is + // automatically cancelled when it's destroyed. + backup_job_timer_.Start( + FROM_HERE, pool->ConnectRetryInterval(), + base::Bind(&Group::OnBackupJobTimerFired, base::Unretained(this), + group_name, pool)); +} + +bool ClientSocketPoolBaseHelper::Group::BackupJobTimerIsRunning() const { + return backup_job_timer_.IsRunning(); } bool ClientSocketPoolBaseHelper::Group::TryToUseUnassignedConnectJob() { @@ -1210,9 +1191,14 @@ void ClientSocketPoolBaseHelper::Group::RemoveJob(ConnectJob* job) { size_t job_count = jobs_.size(); if (job_count < unassigned_job_count_) unassigned_job_count_ = job_count; + + // If we've got no more jobs for this group, then we no longer need a + // backup job either. + if (jobs_.empty()) + backup_job_timer_.Stop(); } -void ClientSocketPoolBaseHelper::Group::OnBackupSocketTimerFired( +void ClientSocketPoolBaseHelper::Group::OnBackupJobTimerFired( std::string group_name, ClientSocketPoolBaseHelper* pool) { // If there are no more jobs pending, there is no work to do. @@ -1227,7 +1213,7 @@ void ClientSocketPoolBaseHelper::Group::OnBackupSocketTimerFired( if (pool->ReachedMaxSocketsLimit() || !HasAvailableSocketSlot(pool->max_sockets_per_group_) || (*jobs_.begin())->GetLoadState() == LOAD_STATE_RESOLVING_HOST) { - StartBackupSocketTimer(group_name, pool); + StartBackupJobTimer(group_name, pool); return; } @@ -1236,8 +1222,8 @@ void ClientSocketPoolBaseHelper::Group::OnBackupSocketTimerFired( scoped_ptr<ConnectJob> backup_job = pool->connect_job_factory_->NewConnectJob( - group_name, **pending_requests_.begin(), pool); - backup_job->net_log().AddEvent(NetLog::TYPE_SOCKET_BACKUP_CREATED); + group_name, *pending_requests_.FirstMax().value(), pool); + backup_job->net_log().AddEvent(NetLog::TYPE_BACKUP_CONNECT_JOB_CREATED); SIMPLE_STATS_COUNTER("socket.backup_created"); int rv = backup_job->Connect(); pool->connecting_socket_count_++; @@ -1258,13 +1244,14 @@ void ClientSocketPoolBaseHelper::Group::RemoveAllJobs() { STLDeleteElements(&jobs_); unassigned_job_count_ = 0; - // Cancel pending backup job. - weak_factory_.InvalidateWeakPtrs(); + // Stop backup job timer. + backup_job_timer_.Stop(); } const ClientSocketPoolBaseHelper::Request* ClientSocketPoolBaseHelper::Group::GetNextPendingRequest() const { - return pending_requests_.empty() ? NULL : *pending_requests_.begin(); + return + pending_requests_.empty() ? NULL : pending_requests_.FirstMax().value(); } bool ClientSocketPoolBaseHelper::Group::HasConnectJobForHandle( @@ -1273,40 +1260,45 @@ bool ClientSocketPoolBaseHelper::Group::HasConnectJobForHandle( // If it's farther back in the deque than that, it doesn't have a // corresponding ConnectJob. size_t i = 0; - for (RequestQueue::const_iterator it = pending_requests_.begin(); - it != pending_requests_.end() && i < jobs_.size(); ++it, ++i) { - if ((*it)->handle() == handle) + for (RequestQueue::Pointer pointer = pending_requests_.FirstMax(); + !pointer.is_null() && i < jobs_.size(); + pointer = pending_requests_.GetNextTowardsLastMin(pointer), ++i) { + if (pointer.value()->handle() == handle) return true; } return false; } void ClientSocketPoolBaseHelper::Group::InsertPendingRequest( - scoped_ptr<const Request> r) { - RequestQueue::iterator it = pending_requests_.begin(); - // TODO(mmenke): Should the network stack require requests with - // |ignore_limits| have the highest priority? - while (it != pending_requests_.end() && - CompareEffectiveRequestPriority(*r, *(*it)) <= 0) { - ++it; + scoped_ptr<const Request> request) { + // This value must be cached before we release |request|. + RequestPriority priority = request->priority(); + if (request->ignore_limits()) { + // Put requests with ignore_limits == true (which should have + // priority == MAXIMUM_PRIORITY) ahead of other requests with + // MAXIMUM_PRIORITY. + DCHECK_EQ(priority, MAXIMUM_PRIORITY); + pending_requests_.InsertAtFront(request.release(), priority); + } else { + pending_requests_.Insert(request.release(), priority); } - pending_requests_.insert(it, r.release()); } scoped_ptr<const ClientSocketPoolBaseHelper::Request> ClientSocketPoolBaseHelper::Group::PopNextPendingRequest() { if (pending_requests_.empty()) return scoped_ptr<const ClientSocketPoolBaseHelper::Request>(); - return RemovePendingRequest(pending_requests_.begin()); + return RemovePendingRequest(pending_requests_.FirstMax()); } scoped_ptr<const ClientSocketPoolBaseHelper::Request> ClientSocketPoolBaseHelper::Group::FindAndRemovePendingRequest( ClientSocketHandle* handle) { - for (RequestQueue::iterator it = pending_requests_.begin(); - it != pending_requests_.end(); ++it) { - if ((*it)->handle() == handle) { - scoped_ptr<const Request> request = RemovePendingRequest(it); + for (RequestQueue::Pointer pointer = pending_requests_.FirstMax(); + !pointer.is_null(); + pointer = pending_requests_.GetNextTowardsLastMin(pointer)) { + if (pointer.value()->handle() == handle) { + scoped_ptr<const Request> request = RemovePendingRequest(pointer); return request.Pass(); } } @@ -1315,12 +1307,12 @@ ClientSocketPoolBaseHelper::Group::FindAndRemovePendingRequest( scoped_ptr<const ClientSocketPoolBaseHelper::Request> ClientSocketPoolBaseHelper::Group::RemovePendingRequest( - const RequestQueue::iterator& it) { - scoped_ptr<const Request> request(*it); - pending_requests_.erase(it); + const RequestQueue::Pointer& pointer) { + scoped_ptr<const Request> request(pointer.value()); + pending_requests_.Erase(pointer); // If there are no more requests, kill the backup timer. if (pending_requests_.empty()) - CleanupBackupJob(); + backup_job_timer_.Stop(); return request.Pass(); } diff --git a/chromium/net/socket/client_socket_pool_base.h b/chromium/net/socket/client_socket_pool_base.h index 31ec9bf7b13..2c2ddb57abc 100644 --- a/chromium/net/socket/client_socket_pool_base.h +++ b/chromium/net/socket/client_socket_pool_base.h @@ -22,6 +22,7 @@ #ifndef NET_SOCKET_CLIENT_SOCKET_POOL_BASE_H_ #define NET_SOCKET_CLIENT_SOCKET_POOL_BASE_H_ +#include <cstddef> #include <deque> #include <list> #include <map> @@ -43,6 +44,7 @@ #include "net/base/net_export.h" #include "net/base/net_log.h" #include "net/base/network_change_notifier.h" +#include "net/base/priority_queue.h" #include "net/base/request_priority.h" #include "net/socket/client_socket_pool.h" #include "net/socket/stream_socket.h" @@ -182,12 +184,12 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper private: ClientSocketHandle* const handle_; - CompletionCallback callback_; + const CompletionCallback callback_; // TODO(akalin): Support reprioritization. const RequestPriority priority_; - bool ignore_limits_; + const bool ignore_limits_; const Flags flags_; - BoundNetLog net_log_; + const BoundNetLog net_log_; DISALLOW_COPY_AND_ASSIGN(Request); }; @@ -350,7 +352,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper base::TimeTicks start_time; }; - typedef std::deque<const Request* > RequestQueue; + typedef PriorityQueue<const Request*> RequestQueue; typedef std::map<const ClientSocketHandle*, const Request*> RequestMap; // A Group is allocated per group_name when there are idle sockets or pending @@ -380,19 +382,22 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper pending_requests_.size() > jobs_.size(); } + // Returns the priority of the top of the pending request queue + // (which may be less than the maximum priority over the entire + // queue, due to how we prioritize requests with |ignore_limits| + // set over others). RequestPriority TopPendingPriority() const { - return pending_requests_.front()->priority(); + // NOTE: FirstMax().value()->priority() is not the same as + // FirstMax().priority()! + return pending_requests_.FirstMax().value()->priority(); } - bool HasBackupJob() const { return weak_factory_.HasWeakPtrs(); } + // Set a timer to create a backup job if it takes too long to + // create one and if a timer isn't already running. + void StartBackupJobTimer(const std::string& group_name, + ClientSocketPoolBaseHelper* pool); - void CleanupBackupJob() { - weak_factory_.InvalidateWeakPtrs(); - } - - // Set a timer to create a backup socket if it takes too long to create one. - void StartBackupSocketTimer(const std::string& group_name, - ClientSocketPoolBaseHelper* pool); + bool BackupJobTimerIsRunning() const; // If there's a ConnectJob that's never been assigned to Request, // decrements |unassigned_job_count_| and returns true. @@ -422,7 +427,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper // Inserts the request into the queue based on priority // order. Older requests are prioritized over requests of equal // priority. - void InsertPendingRequest(scoped_ptr<const Request> r); + void InsertPendingRequest(scoped_ptr<const Request> request); // Gets and removes the next pending request. Returns NULL if // there are no pending requests. @@ -446,10 +451,10 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper // Returns the iterator's pending request after removing it from // the queue. scoped_ptr<const Request> RemovePendingRequest( - const RequestQueue::iterator& it); + const RequestQueue::Pointer& pointer); // Called when the backup socket timer fires. - void OnBackupSocketTimerFired( + void OnBackupJobTimerFired( std::string group_name, ClientSocketPoolBaseHelper* pool); @@ -469,8 +474,8 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper std::set<ConnectJob*> jobs_; RequestQueue pending_requests_; int active_socket_count_; // number of active sockets used by clients - // A factory to pin the backup_job tasks. - base::WeakPtrFactory<Group> weak_factory_; + // A timer for when to start the backup job. + base::OneShotTimer<Group> backup_job_timer_; }; typedef std::map<std::string, Group*> GroupMap; diff --git a/chromium/net/socket/client_socket_pool_base_unittest.cc b/chromium/net/socket/client_socket_pool_base_unittest.cc index bbeca2f3e11..46f4e40c0d4 100644 --- a/chromium/net/socket/client_socket_pool_base_unittest.cc +++ b/chromium/net/socket/client_socket_pool_base_unittest.cc @@ -45,7 +45,6 @@ namespace { const int kDefaultMaxSockets = 4; const int kDefaultMaxSocketsPerGroup = 2; -const net::RequestPriority kDefaultPriority = MEDIUM; // Make sure |handle| sets load times correctly when it has been assigned a // reused socket. @@ -100,18 +99,16 @@ void TestLoadTimingInfoNotConnected(const ClientSocketHandle& handle) { class TestSocketParams : public base::RefCounted<TestSocketParams> { public: - TestSocketParams() : ignore_limits_(false) {} + explicit TestSocketParams(bool ignore_limits) + : ignore_limits_(ignore_limits) {} - void set_ignore_limits(bool ignore_limits) { - ignore_limits_ = ignore_limits; - } bool ignore_limits() { return ignore_limits_; } private: friend class base::RefCounted<TestSocketParams>; ~TestSocketParams() {} - bool ignore_limits_; + const bool ignore_limits_; }; typedef ClientSocketPoolBase<TestSocketParams> TestClientSocketPoolBase; @@ -265,9 +262,10 @@ class TestConnectJob : public ConnectJob { BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), job_type_(job_type), client_socket_factory_(client_socket_factory), - weak_factory_(this), load_state_(LOAD_STATE_IDLE), - store_additional_error_state_(false) {} + store_additional_error_state_(false), + weak_factory_(this) { + } void Signal() { DoConnect(waiting_success_, true /* async */, false /* recoverable */); @@ -400,10 +398,11 @@ class TestConnectJob : public ConnectJob { bool waiting_success_; const JobType job_type_; MockClientSocketFactory* const client_socket_factory_; - base::WeakPtrFactory<TestConnectJob> weak_factory_; LoadState load_state_; bool store_additional_error_state_; + base::WeakPtrFactory<TestConnectJob> weak_factory_; + DISALLOW_COPY_AND_ASSIGN(TestConnectJob); }; @@ -663,7 +662,7 @@ class TestConnectJobDelegate : public ConnectJob::Delegate { class ClientSocketPoolBaseTest : public testing::Test { protected: ClientSocketPoolBaseTest() - : params_(new TestSocketParams()), + : params_(new TestSocketParams(false /* ignore_limits */)), histograms_("ClientSocketPoolTest") { connect_backup_jobs_enabled_ = internal::ClientSocketPoolBaseHelper::connect_backup_jobs_enabled(); @@ -748,7 +747,7 @@ TEST_F(ClientSocketPoolBaseTest, ConnectJob_NoTimeoutOnSynchronousCompletion) { TestConnectJobDelegate delegate; ClientSocketHandle ignored; TestClientSocketPoolBase::Request request( - &ignored, CompletionCallback(), kDefaultPriority, + &ignored, CompletionCallback(), DEFAULT_PRIORITY, internal::ClientSocketPoolBaseHelper::NORMAL, false, params_, BoundNetLog()); scoped_ptr<TestConnectJob> job( @@ -768,7 +767,7 @@ TEST_F(ClientSocketPoolBaseTest, ConnectJob_TimedOut) { CapturingNetLog log; TestClientSocketPoolBase::Request request( - &ignored, CompletionCallback(), kDefaultPriority, + &ignored, CompletionCallback(), DEFAULT_PRIORITY, internal::ClientSocketPoolBaseHelper::NORMAL, false, params_, BoundNetLog()); // Deleted by TestConnectJobDelegate. @@ -815,7 +814,7 @@ TEST_F(ClientSocketPoolBaseTest, BasicSynchronous) { EXPECT_EQ(OK, handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), log.bound())); @@ -858,7 +857,7 @@ TEST_F(ClientSocketPoolBaseTest, InitConnectionFailure) { EXPECT_EQ(ERR_CONNECTION_FAILED, handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), log.bound())); @@ -885,18 +884,18 @@ TEST_F(ClientSocketPoolBaseTest, TotalLimit) { // TODO(eroman): Check that the NetLog contains this event. - EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); - EXPECT_EQ(OK, StartRequest("b", kDefaultPriority)); - EXPECT_EQ(OK, StartRequest("c", kDefaultPriority)); - EXPECT_EQ(OK, StartRequest("d", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", DEFAULT_PRIORITY)); + EXPECT_EQ(OK, StartRequest("b", DEFAULT_PRIORITY)); + EXPECT_EQ(OK, StartRequest("c", DEFAULT_PRIORITY)); + EXPECT_EQ(OK, StartRequest("d", DEFAULT_PRIORITY)); EXPECT_EQ(static_cast<int>(requests_size()), client_socket_factory_.allocation_count()); EXPECT_EQ(requests_size() - kDefaultMaxSockets, completion_count()); - EXPECT_EQ(ERR_IO_PENDING, StartRequest("e", kDefaultPriority)); - EXPECT_EQ(ERR_IO_PENDING, StartRequest("f", kDefaultPriority)); - EXPECT_EQ(ERR_IO_PENDING, StartRequest("g", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("e", DEFAULT_PRIORITY)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("f", DEFAULT_PRIORITY)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("g", DEFAULT_PRIORITY)); ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE); @@ -922,17 +921,17 @@ TEST_F(ClientSocketPoolBaseTest, TotalLimitReachedNewGroup) { // TODO(eroman): Check that the NetLog contains this event. // Reach all limits: max total sockets, and max sockets per group. - EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); - EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); - EXPECT_EQ(OK, StartRequest("b", kDefaultPriority)); - EXPECT_EQ(OK, StartRequest("b", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", DEFAULT_PRIORITY)); + EXPECT_EQ(OK, StartRequest("a", DEFAULT_PRIORITY)); + EXPECT_EQ(OK, StartRequest("b", DEFAULT_PRIORITY)); + EXPECT_EQ(OK, StartRequest("b", DEFAULT_PRIORITY)); EXPECT_EQ(static_cast<int>(requests_size()), client_socket_factory_.allocation_count()); EXPECT_EQ(requests_size() - kDefaultMaxSockets, completion_count()); // Now create a new group and verify that we don't starve it. - EXPECT_EQ(ERR_IO_PENDING, StartRequest("c", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("c", DEFAULT_PRIORITY)); ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE); @@ -1028,13 +1027,13 @@ TEST_F(ClientSocketPoolBaseTest, TotalLimitRespectsGroupLimit) { TEST_F(ClientSocketPoolBaseTest, TotalLimitCountsConnectingSockets) { CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); - EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); - EXPECT_EQ(OK, StartRequest("b", kDefaultPriority)); - EXPECT_EQ(OK, StartRequest("c", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", DEFAULT_PRIORITY)); + EXPECT_EQ(OK, StartRequest("b", DEFAULT_PRIORITY)); + EXPECT_EQ(OK, StartRequest("c", DEFAULT_PRIORITY)); // Create one asynchronous request. connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); - EXPECT_EQ(ERR_IO_PENDING, StartRequest("d", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("d", DEFAULT_PRIORITY)); // We post all of our delayed tasks with a 2ms delay. I.e. they don't // actually become pending until 2ms after they have been created. In order @@ -1045,7 +1044,7 @@ TEST_F(ClientSocketPoolBaseTest, TotalLimitCountsConnectingSockets) { // The next synchronous request should wait for its turn. connect_job_factory_->set_job_type(TestConnectJob::kMockJob); - EXPECT_EQ(ERR_IO_PENDING, StartRequest("e", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("e", DEFAULT_PRIORITY)); ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE); @@ -1066,17 +1065,17 @@ TEST_F(ClientSocketPoolBaseTest, CorrectlyCountStalledGroups) { CreatePool(kDefaultMaxSockets, kDefaultMaxSockets); connect_job_factory_->set_job_type(TestConnectJob::kMockJob); - EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); - EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); - EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); - EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", DEFAULT_PRIORITY)); + EXPECT_EQ(OK, StartRequest("a", DEFAULT_PRIORITY)); + EXPECT_EQ(OK, StartRequest("a", DEFAULT_PRIORITY)); + EXPECT_EQ(OK, StartRequest("a", DEFAULT_PRIORITY)); connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); EXPECT_EQ(kDefaultMaxSockets, client_socket_factory_.allocation_count()); - EXPECT_EQ(ERR_IO_PENDING, StartRequest("b", kDefaultPriority)); - EXPECT_EQ(ERR_IO_PENDING, StartRequest("c", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("b", DEFAULT_PRIORITY)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("c", DEFAULT_PRIORITY)); EXPECT_EQ(kDefaultMaxSockets, client_socket_factory_.allocation_count()); @@ -1098,7 +1097,7 @@ TEST_F(ClientSocketPoolBaseTest, StallAndThenCancelAndTriggerAvailableSocket) { EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -1109,7 +1108,7 @@ TEST_F(ClientSocketPoolBaseTest, StallAndThenCancelAndTriggerAvailableSocket) { EXPECT_EQ(ERR_IO_PENDING, handles[i].Init("b", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -1132,7 +1131,7 @@ TEST_F(ClientSocketPoolBaseTest, CancelStalledSocketAtSocketLimit) { for (int i = 0; i < kDefaultMaxSockets; ++i) { EXPECT_EQ(OK, handles[i].Init(base::IntToString(i), params_, - kDefaultPriority, + DEFAULT_PRIORITY, callbacks[i].callback(), pool_.get(), BoundNetLog())); @@ -1143,7 +1142,7 @@ TEST_F(ClientSocketPoolBaseTest, CancelStalledSocketAtSocketLimit) { TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, stalled_handle.Init("foo", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -1171,7 +1170,7 @@ TEST_F(ClientSocketPoolBaseTest, CancelPendingSocketAtSocketLimit) { TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, handles[i].Init(base::IntToString(i), params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -1183,7 +1182,7 @@ TEST_F(ClientSocketPoolBaseTest, CancelPendingSocketAtSocketLimit) { TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, stalled_handle.Init("foo", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -1228,7 +1227,7 @@ TEST_F(ClientSocketPoolBaseTest, WaitForStalledSocketAtSocketLimit) { EXPECT_EQ(OK, handles[i].Init(base::StringPrintf( "Take 2: %d", i), params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -1241,7 +1240,7 @@ TEST_F(ClientSocketPoolBaseTest, WaitForStalledSocketAtSocketLimit) { // Now we will hit the socket limit. EXPECT_EQ(ERR_IO_PENDING, stalled_handle.Init("foo", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -1269,7 +1268,7 @@ TEST_F(ClientSocketPoolBaseTest, CloseIdleSocketAtSocketLimitDeleteGroup) { TestCompletionCallback callback; EXPECT_EQ(OK, handle.Init(base::IntToString(i), params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -1289,7 +1288,7 @@ TEST_F(ClientSocketPoolBaseTest, CloseIdleSocketAtSocketLimitDeleteGroup) { // close an idle socket though, since we should reuse the idle socket. EXPECT_EQ(OK, handle.Init("0", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -1301,8 +1300,8 @@ TEST_F(ClientSocketPoolBaseTest, CloseIdleSocketAtSocketLimitDeleteGroup) { TEST_F(ClientSocketPoolBaseTest, PendingRequests) { CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); - EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); - EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", DEFAULT_PRIORITY)); + EXPECT_EQ(OK, StartRequest("a", DEFAULT_PRIORITY)); EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", IDLE)); EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST)); EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM)); @@ -1333,8 +1332,8 @@ TEST_F(ClientSocketPoolBaseTest, PendingRequests) { TEST_F(ClientSocketPoolBaseTest, PendingRequests_NoKeepAlive) { CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); - EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); - EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", DEFAULT_PRIORITY)); + EXPECT_EQ(OK, StartRequest("a", DEFAULT_PRIORITY)); EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST)); EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM)); EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", HIGHEST)); @@ -1363,7 +1362,7 @@ TEST_F(ClientSocketPoolBaseTest, CancelRequestClearGroup) { TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -1379,7 +1378,7 @@ TEST_F(ClientSocketPoolBaseTest, ConnectCancelConnect) { EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -1390,7 +1389,7 @@ TEST_F(ClientSocketPoolBaseTest, ConnectCancelConnect) { EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback2.callback(), pool_.get(), BoundNetLog())); @@ -1404,8 +1403,8 @@ TEST_F(ClientSocketPoolBaseTest, ConnectCancelConnect) { TEST_F(ClientSocketPoolBaseTest, CancelRequest) { CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); - EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); - EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", DEFAULT_PRIORITY)); + EXPECT_EQ(OK, StartRequest("a", DEFAULT_PRIORITY)); EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST)); EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM)); EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", HIGHEST)); @@ -1478,10 +1477,11 @@ class RequestSocketCallback : public TestCompletionCallbackBase { } within_callback_ = true; TestCompletionCallback next_job_callback; - scoped_refptr<TestSocketParams> params(new TestSocketParams()); + scoped_refptr<TestSocketParams> params( + new TestSocketParams(false /* ignore_limits */)); int rv = handle_->Init("a", params, - kDefaultPriority, + DEFAULT_PRIORITY, next_job_callback.callback(), pool_, BoundNetLog()); @@ -1530,7 +1530,7 @@ TEST_F(ClientSocketPoolBaseTest, RequestPendingJobTwice) { TestConnectJob::kMockPendingJob); int rv = handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog()); @@ -1548,7 +1548,7 @@ TEST_F(ClientSocketPoolBaseTest, RequestPendingJobThenSynchronous) { &handle, pool_.get(), connect_job_factory_, TestConnectJob::kMockJob); int rv = handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog()); @@ -1564,13 +1564,13 @@ TEST_F(ClientSocketPoolBaseTest, CancelActiveRequestWithPendingRequests) { connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); - EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); - EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); - EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); - EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); - EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); - EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); - EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", DEFAULT_PRIORITY)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", DEFAULT_PRIORITY)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", DEFAULT_PRIORITY)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", DEFAULT_PRIORITY)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", DEFAULT_PRIORITY)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", DEFAULT_PRIORITY)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", DEFAULT_PRIORITY)); // Now, kDefaultMaxSocketsPerGroup requests should be active. // Let's cancel them. @@ -1601,7 +1601,7 @@ TEST_F(ClientSocketPoolBaseTest, FailingActiveRequestWithPendingRequests) { // Queue up all the requests for (size_t i = 0; i < kNumberOfRequests; ++i) - EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", DEFAULT_PRIORITY)); for (size_t i = 0; i < kNumberOfRequests; ++i) EXPECT_EQ(ERR_CONNECTION_FAILED, request(i)->WaitForResult()); @@ -1616,7 +1616,7 @@ TEST_F(ClientSocketPoolBaseTest, CancelActiveRequestThenRequestSocket) { TestCompletionCallback callback; int rv = handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog()); @@ -1627,7 +1627,7 @@ TEST_F(ClientSocketPoolBaseTest, CancelActiveRequestThenRequestSocket) { rv = handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog()); @@ -1647,14 +1647,14 @@ TEST_F(ClientSocketPoolBaseTest, GroupWithPendingRequestsIsNotEmpty) { const RequestPriority kHighPriority = HIGHEST; - EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); - EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", DEFAULT_PRIORITY)); + EXPECT_EQ(OK, StartRequest("a", DEFAULT_PRIORITY)); // This is going to be a pending request in an otherwise empty group. - EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", DEFAULT_PRIORITY)); // Reach the maximum socket limit. - EXPECT_EQ(OK, StartRequest("b", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("b", DEFAULT_PRIORITY)); // Create a stalled group with high priorities. EXPECT_EQ(ERR_IO_PENDING, StartRequest("c", kHighPriority)); @@ -1732,7 +1732,7 @@ TEST_F(ClientSocketPoolBaseTest, handle.set_ssl_error_response_info(info); EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), log.bound())); @@ -1768,7 +1768,7 @@ TEST_F(ClientSocketPoolBaseTest, TwoRequestsCancelOne) { EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -1776,7 +1776,7 @@ TEST_F(ClientSocketPoolBaseTest, TwoRequestsCancelOne) { EXPECT_EQ(ERR_IO_PENDING, handle2.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback2.callback(), pool_.get(), BoundNetLog())); @@ -1828,7 +1828,7 @@ TEST_F(ClientSocketPoolBaseTest, ReleaseSockets) { TestSocketRequest req1(&request_order, &completion_count); int rv = req1.handle()->Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, req1.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -1841,7 +1841,7 @@ TEST_F(ClientSocketPoolBaseTest, ReleaseSockets) { TestSocketRequest req2(&request_order, &completion_count); rv = req2.handle()->Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, req2.callback(), pool_.get(), BoundNetLog()); @@ -1849,7 +1849,7 @@ TEST_F(ClientSocketPoolBaseTest, ReleaseSockets) { TestSocketRequest req3(&request_order, &completion_count); rv = req3.handle()->Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, req3.callback(), pool_.get(), BoundNetLog()); @@ -1888,7 +1888,7 @@ TEST_F(ClientSocketPoolBaseTest, PendingJobCompletionOrder) { TestSocketRequest req1(&request_order, &completion_count); int rv = req1.handle()->Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, req1.callback(), pool_.get(), BoundNetLog()); @@ -1897,7 +1897,7 @@ TEST_F(ClientSocketPoolBaseTest, PendingJobCompletionOrder) { TestSocketRequest req2(&request_order, &completion_count); rv = req2.handle()->Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, req2.callback(), pool_.get(), BoundNetLog()); @@ -1909,7 +1909,7 @@ TEST_F(ClientSocketPoolBaseTest, PendingJobCompletionOrder) { TestSocketRequest req3(&request_order, &completion_count); rv = req3.handle()->Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, req3.callback(), pool_.get(), BoundNetLog()); @@ -1934,7 +1934,7 @@ TEST_F(ClientSocketPoolBaseTest, LoadStateOneRequest) { TestCompletionCallback callback; int rv = handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog()); @@ -1957,7 +1957,7 @@ TEST_F(ClientSocketPoolBaseTest, LoadStateTwoRequests) { TestCompletionCallback callback; int rv = handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog()); @@ -1967,7 +1967,7 @@ TEST_F(ClientSocketPoolBaseTest, LoadStateTwoRequests) { TestCompletionCallback callback2; rv = handle2.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback2.callback(), pool_.get(), BoundNetLog()); @@ -2050,7 +2050,7 @@ TEST_F(ClientSocketPoolBaseTest, LoadStatePoolLimit) { TestCompletionCallback callback; int rv = handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog()); @@ -2061,7 +2061,7 @@ TEST_F(ClientSocketPoolBaseTest, LoadStatePoolLimit) { TestCompletionCallback callback2; rv = handle2.Init("b", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback2.callback(), pool_.get(), BoundNetLog()); @@ -2073,7 +2073,7 @@ TEST_F(ClientSocketPoolBaseTest, LoadStatePoolLimit) { TestCompletionCallback callback3; rv = handle3.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback2.callback(), pool_.get(), BoundNetLog()); @@ -2107,7 +2107,7 @@ TEST_F(ClientSocketPoolBaseTest, Recoverable) { ClientSocketHandle handle; TestCompletionCallback callback; EXPECT_EQ(ERR_PROXY_AUTH_REQUESTED, - handle.Init("a", params_, kDefaultPriority, callback.callback(), + handle.Init("a", params_, DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); EXPECT_TRUE(handle.is_initialized()); EXPECT_TRUE(handle.socket()); @@ -2123,7 +2123,7 @@ TEST_F(ClientSocketPoolBaseTest, AsyncRecoverable) { EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -2143,7 +2143,7 @@ TEST_F(ClientSocketPoolBaseTest, AdditionalErrorStateSynchronous) { EXPECT_EQ(ERR_CONNECTION_FAILED, handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -2163,7 +2163,7 @@ TEST_F(ClientSocketPoolBaseTest, AdditionalErrorStateAsynchronous) { EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -2531,10 +2531,10 @@ TEST_F(ClientSocketPoolBaseTest, connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); - EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); - EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); - EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); - EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", DEFAULT_PRIORITY)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", DEFAULT_PRIORITY)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", DEFAULT_PRIORITY)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", DEFAULT_PRIORITY)); EXPECT_EQ(OK, (*requests())[0]->WaitForResult()); EXPECT_EQ(OK, (*requests())[1]->WaitForResult()); @@ -2581,9 +2581,10 @@ class TestReleasingSocketRequest : public TestCompletionCallbackBase { if (reset_releasing_handle_) handle_.Reset(); - scoped_refptr<TestSocketParams> con_params(new TestSocketParams()); + scoped_refptr<TestSocketParams> con_params( + new TestSocketParams(false /* ignore_limits */)); EXPECT_EQ(expected_result_, - handle2_.Init("a", con_params, kDefaultPriority, + handle2_.Init("a", con_params, DEFAULT_PRIORITY, callback2_.callback(), pool_, BoundNetLog())); } @@ -2600,9 +2601,9 @@ class TestReleasingSocketRequest : public TestCompletionCallbackBase { TEST_F(ClientSocketPoolBaseTest, AdditionalErrorSocketsDontUseSlot) { CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); - EXPECT_EQ(OK, StartRequest("b", kDefaultPriority)); - EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); - EXPECT_EQ(OK, StartRequest("b", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("b", DEFAULT_PRIORITY)); + EXPECT_EQ(OK, StartRequest("a", DEFAULT_PRIORITY)); + EXPECT_EQ(OK, StartRequest("b", DEFAULT_PRIORITY)); EXPECT_EQ(static_cast<int>(requests_size()), client_socket_factory_.allocation_count()); @@ -2611,7 +2612,7 @@ TEST_F(ClientSocketPoolBaseTest, AdditionalErrorSocketsDontUseSlot) { TestConnectJob::kMockPendingAdditionalErrorStateJob); TestReleasingSocketRequest req(pool_.get(), OK, false); EXPECT_EQ(ERR_IO_PENDING, - req.handle()->Init("a", params_, kDefaultPriority, req.callback(), + req.handle()->Init("a", params_, DEFAULT_PRIORITY, req.callback(), pool_.get(), BoundNetLog())); // The next job should complete synchronously connect_job_factory_->set_job_type(TestConnectJob::kMockJob); @@ -2639,7 +2640,7 @@ TEST_F(ClientSocketPoolBaseTest, CallbackThatReleasesPool) { TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -2658,7 +2659,7 @@ TEST_F(ClientSocketPoolBaseTest, DoNotReuseSocketAfterFlush) { TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -2672,7 +2673,7 @@ TEST_F(ClientSocketPoolBaseTest, DoNotReuseSocketAfterFlush) { EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -2707,7 +2708,7 @@ class ConnectWithinCallback : public TestCompletionCallbackBase { EXPECT_EQ(ERR_IO_PENDING, handle_.Init(group_name_, params_, - kDefaultPriority, + DEFAULT_PRIORITY, nested_callback_.callback(), pool_, BoundNetLog())); @@ -2733,7 +2734,7 @@ TEST_F(ClientSocketPoolBaseTest, AbortAllRequestsOnFlush) { ConnectWithinCallback callback("a", params_, pool_.get()); EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -2760,7 +2761,7 @@ TEST_F(ClientSocketPoolBaseTest, BackupSocketCancelAtMaxSockets) { TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, handle.Init("bar", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -2772,7 +2773,7 @@ TEST_F(ClientSocketPoolBaseTest, BackupSocketCancelAtMaxSockets) { TestCompletionCallback callback; EXPECT_EQ(OK, handles[i].Init("bar", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -2802,7 +2803,7 @@ TEST_F(ClientSocketPoolBaseTest, CancelBackupSocketAfterCancelingAllRequests) { TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, handle.Init("bar", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -2832,7 +2833,7 @@ TEST_F(ClientSocketPoolBaseTest, CancelBackupSocketAfterFinishingAllRequests) { TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, handle.Init("bar", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -2841,7 +2842,7 @@ TEST_F(ClientSocketPoolBaseTest, CancelBackupSocketAfterFinishingAllRequests) { TestCompletionCallback callback2; EXPECT_EQ(ERR_IO_PENDING, handle2.Init("bar", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback2.callback(), pool_.get(), BoundNetLog())); @@ -2871,7 +2872,7 @@ TEST_F(ClientSocketPoolBaseTest, DelayedSocketBindingWaitingForConnect) { EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -2887,7 +2888,7 @@ TEST_F(ClientSocketPoolBaseTest, DelayedSocketBindingWaitingForConnect) { EXPECT_EQ(ERR_IO_PENDING, handle2.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -2929,7 +2930,7 @@ TEST_F(ClientSocketPoolBaseTest, DelayedSocketBindingAtGroupCapacity) { EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -2945,7 +2946,7 @@ TEST_F(ClientSocketPoolBaseTest, DelayedSocketBindingAtGroupCapacity) { EXPECT_EQ(ERR_IO_PENDING, handle2.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -2989,7 +2990,7 @@ TEST_F(ClientSocketPoolBaseTest, DelayedSocketBindingAtStall) { EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -3005,7 +3006,7 @@ TEST_F(ClientSocketPoolBaseTest, DelayedSocketBindingAtStall) { EXPECT_EQ(ERR_IO_PENDING, handle2.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -3052,7 +3053,7 @@ TEST_F(ClientSocketPoolBaseTest, SynchronouslyProcessOnePendingRequest) { EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback1.callback(), pool_.get(), BoundNetLog())); @@ -3068,7 +3069,7 @@ TEST_F(ClientSocketPoolBaseTest, SynchronouslyProcessOnePendingRequest) { EXPECT_EQ(ERR_IO_PENDING, handle2.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback2.callback(), pool_.get(), BoundNetLog())); @@ -3089,7 +3090,7 @@ TEST_F(ClientSocketPoolBaseTest, PreferUsedSocketToUnusedSocket) { TestCompletionCallback callback1; EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback1.callback(), pool_.get(), BoundNetLog())); @@ -3098,7 +3099,7 @@ TEST_F(ClientSocketPoolBaseTest, PreferUsedSocketToUnusedSocket) { TestCompletionCallback callback2; EXPECT_EQ(ERR_IO_PENDING, handle2.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback2.callback(), pool_.get(), BoundNetLog())); @@ -3106,7 +3107,7 @@ TEST_F(ClientSocketPoolBaseTest, PreferUsedSocketToUnusedSocket) { TestCompletionCallback callback3; EXPECT_EQ(ERR_IO_PENDING, handle3.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback3.callback(), pool_.get(), BoundNetLog())); @@ -3125,19 +3126,19 @@ TEST_F(ClientSocketPoolBaseTest, PreferUsedSocketToUnusedSocket) { EXPECT_EQ(OK, handle1.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback1.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(OK, handle2.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback2.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(OK, handle3.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback3.callback(), pool_.get(), BoundNetLog())); @@ -3162,7 +3163,7 @@ TEST_F(ClientSocketPoolBaseTest, RequestSockets) { TestCompletionCallback callback1; EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback1.callback(), pool_.get(), BoundNetLog())); @@ -3171,7 +3172,7 @@ TEST_F(ClientSocketPoolBaseTest, RequestSockets) { TestCompletionCallback callback2; EXPECT_EQ(ERR_IO_PENDING, handle2.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback2.callback(), pool_.get(), BoundNetLog())); @@ -3198,7 +3199,7 @@ TEST_F(ClientSocketPoolBaseTest, RequestSocketsWhenAlreadyHaveAConnectJob) { TestCompletionCallback callback1; EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback1.callback(), pool_.get(), BoundNetLog())); @@ -3218,7 +3219,7 @@ TEST_F(ClientSocketPoolBaseTest, RequestSocketsWhenAlreadyHaveAConnectJob) { TestCompletionCallback callback2; EXPECT_EQ(ERR_IO_PENDING, handle2.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback2.callback(), pool_.get(), BoundNetLog())); @@ -3246,7 +3247,7 @@ TEST_F(ClientSocketPoolBaseTest, TestCompletionCallback callback1; EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback1.callback(), pool_.get(), BoundNetLog())); @@ -3255,7 +3256,7 @@ TEST_F(ClientSocketPoolBaseTest, TestCompletionCallback callback2; EXPECT_EQ(ERR_IO_PENDING, handle2.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback2.callback(), pool_.get(), BoundNetLog())); @@ -3264,7 +3265,7 @@ TEST_F(ClientSocketPoolBaseTest, TestCompletionCallback callback3; EXPECT_EQ(ERR_IO_PENDING, handle3.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback3.callback(), pool_.get(), BoundNetLog())); @@ -3346,7 +3347,7 @@ TEST_F(ClientSocketPoolBaseTest, RequestSocketsCountIdleSockets) { TestCompletionCallback callback1; EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback1.callback(), pool_.get(), BoundNetLog())); @@ -3373,7 +3374,7 @@ TEST_F(ClientSocketPoolBaseTest, RequestSocketsCountActiveSockets) { TestCompletionCallback callback1; EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback1.callback(), pool_.get(), BoundNetLog())); @@ -3450,7 +3451,7 @@ TEST_F(ClientSocketPoolBaseTest, RequestSocketsMultipleTimesDoesNothing) { TestCompletionCallback callback1; EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback1.callback(), pool_.get(), BoundNetLog())); @@ -3460,7 +3461,7 @@ TEST_F(ClientSocketPoolBaseTest, RequestSocketsMultipleTimesDoesNothing) { TestCompletionCallback callback2; int rv = handle2.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback2.callback(), pool_.get(), BoundNetLog()); @@ -3529,7 +3530,7 @@ TEST_F(ClientSocketPoolBaseTest, PreconnectJobsTakenByNormalRequests) { TestCompletionCallback callback1; EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback1.callback(), pool_.get(), BoundNetLog())); @@ -3564,7 +3565,7 @@ TEST_F(ClientSocketPoolBaseTest, ConnectedPreconnectJobsHaveNoConnectTimes) { TestCompletionCallback callback; EXPECT_EQ(OK, handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -3592,7 +3593,7 @@ TEST_F(ClientSocketPoolBaseTest, PreconnectClosesIdleSocketRemovesGroup) { TestCompletionCallback callback1; EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback1.callback(), pool_.get(), BoundNetLog())); @@ -3606,13 +3607,13 @@ TEST_F(ClientSocketPoolBaseTest, PreconnectClosesIdleSocketRemovesGroup) { TestCompletionCallback callback2; EXPECT_EQ(ERR_IO_PENDING, handle1.Init("b", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback1.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(ERR_IO_PENDING, handle2.Init("b", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback2.callback(), pool_.get(), BoundNetLog())); @@ -3700,7 +3701,7 @@ TEST_F(ClientSocketPoolBaseTest, PreconnectWithBackupJob) { TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -3724,7 +3725,6 @@ class MockLayeredPool : public HigherLayeredPool { MockLayeredPool(TestClientSocketPool* pool, const std::string& group_name) : pool_(pool), - params_(new TestSocketParams), group_name_(group_name), can_release_connection_(true) { pool_->AddHigherLayeredPool(this); @@ -3735,13 +3735,16 @@ class MockLayeredPool : public HigherLayeredPool { } int RequestSocket(TestClientSocketPool* pool) { - return handle_.Init(group_name_, params_, kDefaultPriority, + scoped_refptr<TestSocketParams> params( + new TestSocketParams(false /* ignore_limits */)); + return handle_.Init(group_name_, params, DEFAULT_PRIORITY, callback_.callback(), pool, BoundNetLog()); } int RequestSocketWithoutLimits(TestClientSocketPool* pool) { - params_->set_ignore_limits(true); - return handle_.Init(group_name_, params_, kDefaultPriority, + scoped_refptr<TestSocketParams> params( + new TestSocketParams(true /* ignore_limits */)); + return handle_.Init(group_name_, params, MAXIMUM_PRIORITY, callback_.callback(), pool, BoundNetLog()); } @@ -3762,7 +3765,6 @@ class MockLayeredPool : public HigherLayeredPool { private: TestClientSocketPool* const pool_; - scoped_refptr<TestSocketParams> params_; ClientSocketHandle handle_; TestCompletionCallback callback_; const std::string group_name_; @@ -3807,7 +3809,7 @@ TEST_F(ClientSocketPoolBaseTest, CloseIdleSocketsHeldByLayeredPoolWhenNeeded) { TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -3830,7 +3832,7 @@ TEST_F(ClientSocketPoolBaseTest, TestCompletionCallback callback1; EXPECT_EQ(OK, handle1.Init("group1", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback1.callback(), pool_.get(), BoundNetLog())); @@ -3844,7 +3846,7 @@ TEST_F(ClientSocketPoolBaseTest, TestCompletionCallback callback2; EXPECT_EQ(ERR_IO_PENDING, handle.Init("group2", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback2.callback(), pool_.get(), BoundNetLog())); @@ -3867,7 +3869,7 @@ TEST_F(ClientSocketPoolBaseTest, TestCompletionCallback callback1; EXPECT_EQ(OK, handle1.Init("group1", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback1.callback(), pool_.get(), BoundNetLog())); @@ -3884,7 +3886,7 @@ TEST_F(ClientSocketPoolBaseTest, TestCompletionCallback callback3; EXPECT_EQ(ERR_IO_PENDING, handle3.Init("group3", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback3.callback(), pool_.get(), BoundNetLog())); @@ -3900,7 +3902,7 @@ TEST_F(ClientSocketPoolBaseTest, TestCompletionCallback callback4; EXPECT_EQ(ERR_IO_PENDING, handle4.Init("group3", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback4.callback(), pool_.get(), BoundNetLog())); @@ -3932,7 +3934,7 @@ TEST_F(ClientSocketPoolBaseTest, TestCompletionCallback callback1; EXPECT_EQ(OK, handle1.Init("group1", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback1.callback(), pool_.get(), BoundNetLog())); @@ -3995,7 +3997,7 @@ TEST_F(ClientSocketPoolBaseTest, TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", params_, - kDefaultPriority, + DEFAULT_PRIORITY, callback.callback(), pool_.get(), BoundNetLog())); @@ -4007,21 +4009,21 @@ TEST_F(ClientSocketPoolBaseTest, // instead of a request with the same priority that was issued earlier, but // that does not have |ignore_limits| set. TEST_F(ClientSocketPoolBaseTest, IgnoreLimits) { - scoped_refptr<TestSocketParams> params_ignore_limits(new TestSocketParams()); - params_ignore_limits->set_ignore_limits(true); + scoped_refptr<TestSocketParams> params_ignore_limits( + new TestSocketParams(true /* ignore_limits */)); CreatePool(1, 1); // Issue a request to reach the socket pool limit. - EXPECT_EQ(OK, StartRequestWithParams("a", kDefaultPriority, params_)); + EXPECT_EQ(OK, StartRequestWithParams("a", MAXIMUM_PRIORITY, params_)); EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); - EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", kDefaultPriority, + EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", MAXIMUM_PRIORITY, params_)); EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); - EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", kDefaultPriority, + EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", MAXIMUM_PRIORITY, params_ignore_limits)); ASSERT_EQ(1, pool_->NumConnectJobsInGroup("a")); @@ -4029,76 +4031,25 @@ TEST_F(ClientSocketPoolBaseTest, IgnoreLimits) { EXPECT_FALSE(request(1)->have_result()); } -// Test that when a socket pool and group are at their limits, a request with -// |ignore_limits| set triggers creation of a new socket, and gets the socket -// instead of a request with a higher priority that was issued earlier, but -// that does not have |ignore_limits| set. -TEST_F(ClientSocketPoolBaseTest, IgnoreLimitsLowPriority) { - scoped_refptr<TestSocketParams> params_ignore_limits(new TestSocketParams()); - params_ignore_limits->set_ignore_limits(true); - CreatePool(1, 1); - - // Issue a request to reach the socket pool limit. - EXPECT_EQ(OK, StartRequestWithParams("a", HIGHEST, params_)); - EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); - - connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); - - EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", HIGHEST, params_)); - EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); - - EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", LOW, - params_ignore_limits)); - ASSERT_EQ(1, pool_->NumConnectJobsInGroup("a")); - - EXPECT_EQ(OK, request(2)->WaitForResult()); - EXPECT_FALSE(request(1)->have_result()); -} - -// Test that when a socket pool and group are at their limits, a request with -// |ignore_limits| set triggers creation of a new socket, and gets the socket -// instead of a request with a higher priority that was issued later and -// does not have |ignore_limits| set. -TEST_F(ClientSocketPoolBaseTest, IgnoreLimitsLowPriority2) { - scoped_refptr<TestSocketParams> params_ignore_limits(new TestSocketParams()); - params_ignore_limits->set_ignore_limits(true); - CreatePool(1, 1); - - // Issue a request to reach the socket pool limit. - EXPECT_EQ(OK, StartRequestWithParams("a", HIGHEST, params_)); - EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); - - connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); - - EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", LOW, - params_ignore_limits)); - ASSERT_EQ(1, pool_->NumConnectJobsInGroup("a")); - - EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", HIGHEST, params_)); - EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); - - EXPECT_EQ(OK, request(1)->WaitForResult()); - EXPECT_FALSE(request(2)->have_result()); -} - // Test that when a socket pool and group are at their limits, a ConnectJob // issued for a request with |ignore_limits| set is not cancelled when a request // without |ignore_limits| issued to the same group is cancelled. TEST_F(ClientSocketPoolBaseTest, IgnoreLimitsCancelOtherJob) { - scoped_refptr<TestSocketParams> params_ignore_limits(new TestSocketParams()); - params_ignore_limits->set_ignore_limits(true); + scoped_refptr<TestSocketParams> params_ignore_limits( + new TestSocketParams(true /* ignore_limits */)); CreatePool(1, 1); // Issue a request to reach the socket pool limit. - EXPECT_EQ(OK, StartRequestWithParams("a", HIGHEST, params_)); + EXPECT_EQ(OK, StartRequestWithParams("a", MAXIMUM_PRIORITY, params_)); EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); - EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", HIGHEST, params_)); + EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", MAXIMUM_PRIORITY, + params_)); EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); - EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", HIGHEST, + EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", MAXIMUM_PRIORITY, params_ignore_limits)); ASSERT_EQ(1, pool_->NumConnectJobsInGroup("a")); @@ -4111,58 +4062,6 @@ TEST_F(ClientSocketPoolBaseTest, IgnoreLimitsCancelOtherJob) { EXPECT_FALSE(request(1)->have_result()); } -// More involved test of ignore limits. Issues a bunch of requests and later -// checks the order in which they receive sockets. -TEST_F(ClientSocketPoolBaseTest, IgnoreLimitsOrder) { - scoped_refptr<TestSocketParams> params_ignore_limits(new TestSocketParams()); - params_ignore_limits->set_ignore_limits(true); - CreatePool(1, 1); - - connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); - - // Requests 0 and 1 do not have ignore_limits set, so they finish last. Since - // the maximum number of sockets per pool is 1, the second requests does not - // trigger a ConnectJob. - EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", HIGHEST, params_)); - EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", HIGHEST, params_)); - - // Requests 2 and 3 have ignore_limits set, but have a low priority, so they - // finish just before the first two. - EXPECT_EQ(ERR_IO_PENDING, - StartRequestWithParams("a", LOW, params_ignore_limits)); - EXPECT_EQ(ERR_IO_PENDING, - StartRequestWithParams("a", LOW, params_ignore_limits)); - - // Request 4 finishes first, since it is high priority and ignores limits. - EXPECT_EQ(ERR_IO_PENDING, - StartRequestWithParams("a", HIGHEST, params_ignore_limits)); - - // Request 5 and 6 are cancelled right after starting. This should result in - // creating two ConnectJobs. Since only one request (Request 1) did not - // result in creating a ConnectJob, only one of the ConnectJobs should be - // cancelled when the requests are. - EXPECT_EQ(ERR_IO_PENDING, - StartRequestWithParams("a", HIGHEST, params_ignore_limits)); - EXPECT_EQ(ERR_IO_PENDING, - StartRequestWithParams("a", HIGHEST, params_ignore_limits)); - EXPECT_EQ(6, pool_->NumConnectJobsInGroup("a")); - request(5)->handle()->Reset(); - EXPECT_EQ(6, pool_->NumConnectJobsInGroup("a")); - request(6)->handle()->Reset(); - ASSERT_EQ(5, pool_->NumConnectJobsInGroup("a")); - - // Wait for the last request to get a socket. - EXPECT_EQ(OK, request(1)->WaitForResult()); - - // Check order in which requests received sockets. - // These are 1-based indices, while request(x) uses 0-based indices. - EXPECT_EQ(1, GetOrderOfRequest(5)); - EXPECT_EQ(2, GetOrderOfRequest(3)); - EXPECT_EQ(3, GetOrderOfRequest(4)); - EXPECT_EQ(4, GetOrderOfRequest(1)); - EXPECT_EQ(5, GetOrderOfRequest(2)); -} - } // namespace } // namespace net diff --git a/chromium/net/socket/client_socket_pool_manager.cc b/chromium/net/socket/client_socket_pool_manager.cc index b37d2d1949c..24d6b70ced5 100644 --- a/chromium/net/socket/client_socket_pool_manager.cc +++ b/chromium/net/socket/client_socket_pool_manager.cc @@ -437,6 +437,31 @@ int InitSocketHandleForRawConnect( callback); } +int InitSocketHandleForTlsConnect( + const HostPortPair& host_port_pair, + HttpNetworkSession* session, + const ProxyInfo& proxy_info, + const SSLConfig& ssl_config_for_origin, + const SSLConfig& ssl_config_for_proxy, + PrivacyMode privacy_mode, + const BoundNetLog& net_log, + ClientSocketHandle* socket_handle, + const CompletionCallback& callback) { + DCHECK(socket_handle); + // Synthesize an HttpRequestInfo. + GURL request_url = GURL("https://" + host_port_pair.ToString()); + HttpRequestHeaders request_extra_headers; + int request_load_flags = 0; + RequestPriority request_priority = MEDIUM; + + return InitSocketPoolHelper( + request_url, request_extra_headers, request_load_flags, request_priority, + session, proxy_info, false, false, ssl_config_for_origin, + ssl_config_for_proxy, true, privacy_mode, net_log, 0, socket_handle, + HttpNetworkSession::NORMAL_SOCKET_POOL, OnHostResolutionCallback(), + callback); +} + int PreconnectSocketsForHttpRequest( const GURL& request_url, const HttpRequestHeaders& request_extra_headers, diff --git a/chromium/net/socket/client_socket_pool_manager.h b/chromium/net/socket/client_socket_pool_manager.h index 1b78324f233..12154809870 100644 --- a/chromium/net/socket/client_socket_pool_manager.h +++ b/chromium/net/socket/client_socket_pool_manager.h @@ -147,6 +147,21 @@ NET_EXPORT int InitSocketHandleForRawConnect( ClientSocketHandle* socket_handle, const CompletionCallback& callback); +// A helper method that uses the passed in proxy information to initialize a +// ClientSocketHandle with the relevant socket pool. Use this method for +// a raw socket connection with TLS negotiation to a host-port pair (that needs +// to tunnel through the proxies). +NET_EXPORT int InitSocketHandleForTlsConnect( + const HostPortPair& host_port_pair, + HttpNetworkSession* session, + const ProxyInfo& proxy_info, + const SSLConfig& ssl_config_for_origin, + const SSLConfig& ssl_config_for_proxy, + PrivacyMode privacy_mode, + const BoundNetLog& net_log, + ClientSocketHandle* socket_handle, + const CompletionCallback& callback); + // Similar to InitSocketHandleForHttpRequest except that it initiates the // desired number of preconnect streams from the relevant socket pool. int PreconnectSocketsForHttpRequest( diff --git a/chromium/net/socket/client_socket_pool_manager_impl.cc b/chromium/net/socket/client_socket_pool_manager_impl.cc index b557874d011..991278d7341 100644 --- a/chromium/net/socket/client_socket_pool_manager_impl.cc +++ b/chromium/net/socket/client_socket_pool_manager_impl.cc @@ -40,6 +40,7 @@ ClientSocketPoolManagerImpl::ClientSocketPoolManagerImpl( CertVerifier* cert_verifier, ServerBoundCertService* server_bound_cert_service, TransportSecurityState* transport_security_state, + CTVerifier* cert_transparency_verifier, const std::string& ssl_session_cache_shard, ProxyService* proxy_service, SSLConfigService* ssl_config_service, @@ -50,6 +51,7 @@ ClientSocketPoolManagerImpl::ClientSocketPoolManagerImpl( cert_verifier_(cert_verifier), server_bound_cert_service_(server_bound_cert_service), transport_security_state_(transport_security_state), + cert_transparency_verifier_(cert_transparency_verifier), ssl_session_cache_shard_(ssl_session_cache_shard), proxy_service_(proxy_service), ssl_config_service_(ssl_config_service), @@ -69,6 +71,7 @@ ClientSocketPoolManagerImpl::ClientSocketPoolManagerImpl( cert_verifier, server_bound_cert_service, transport_security_state, + cert_transparency_verifier, ssl_session_cache_shard, socket_factory, transport_socket_pool_.get(), @@ -286,6 +289,7 @@ ClientSocketPoolManagerImpl::GetSocketPoolForHTTPProxy( cert_verifier_, server_bound_cert_service_, transport_security_state_, + cert_transparency_verifier_, ssl_session_cache_shard_, socket_factory_, tcp_https_ret.first->second /* https proxy */, @@ -326,6 +330,7 @@ SSLClientSocketPool* ClientSocketPoolManagerImpl::GetSocketPoolForSSLWithProxy( cert_verifier_, server_bound_cert_service_, transport_security_state_, + cert_transparency_verifier_, ssl_session_cache_shard_, socket_factory_, NULL, /* no tcp pool, we always go through a proxy */ @@ -374,7 +379,7 @@ void ClientSocketPoolManagerImpl::OnCertAdded(const X509Certificate* cert) { FlushSocketPoolsWithError(ERR_NETWORK_CHANGED); } -void ClientSocketPoolManagerImpl::OnCertTrustChanged( +void ClientSocketPoolManagerImpl::OnCACertChanged( const X509Certificate* cert) { // We should flush the socket pools if we removed trust from a // cert, because a previously trusted server may have become @@ -383,8 +388,8 @@ void ClientSocketPoolManagerImpl::OnCertTrustChanged( // We should not flush the socket pools if we added trust to a // cert. // - // Since the OnCertTrustChanged method doesn't tell us what - // kind of trust change it is, we have to flush the socket + // Since the OnCACertChanged method doesn't tell us what + // kind of change it is, we have to flush the socket // pools to be safe. FlushSocketPoolsWithError(ERR_NETWORK_CHANGED); } diff --git a/chromium/net/socket/client_socket_pool_manager_impl.h b/chromium/net/socket/client_socket_pool_manager_impl.h index 8f6e618d2e1..06d4d244a36 100644 --- a/chromium/net/socket/client_socket_pool_manager_impl.h +++ b/chromium/net/socket/client_socket_pool_manager_impl.h @@ -23,6 +23,7 @@ namespace net { class CertVerifier; class ClientSocketFactory; class ClientSocketPoolHistograms; +class CTVerifier; class HttpProxyClientSocketPool; class HostResolver; class NetLog; @@ -62,6 +63,7 @@ class ClientSocketPoolManagerImpl : public base::NonThreadSafe, CertVerifier* cert_verifier, ServerBoundCertService* server_bound_cert_service, TransportSecurityState* transport_security_state, + CTVerifier* cert_transparency_verifier, const std::string& ssl_session_cache_shard, ProxyService* proxy_service, SSLConfigService* ssl_config_service, @@ -90,7 +92,7 @@ class ClientSocketPoolManagerImpl : public base::NonThreadSafe, // CertDatabase::Observer methods: virtual void OnCertAdded(const X509Certificate* cert) OVERRIDE; - virtual void OnCertTrustChanged(const X509Certificate* cert) OVERRIDE; + virtual void OnCACertChanged(const X509Certificate* cert) OVERRIDE; private: typedef internal::OwnedPoolMap<HostPortPair, TransportClientSocketPool*> @@ -108,6 +110,7 @@ class ClientSocketPoolManagerImpl : public base::NonThreadSafe, CertVerifier* const cert_verifier_; ServerBoundCertService* const server_bound_cert_service_; TransportSecurityState* const transport_security_state_; + CTVerifier* const cert_transparency_verifier_; const std::string ssl_session_cache_shard_; ProxyService* const proxy_service_; const scoped_refptr<SSLConfigService> ssl_config_service_; diff --git a/chromium/net/socket/next_proto.h b/chromium/net/socket/next_proto.h index 0bd307a3778..a2e5ab6902d 100644 --- a/chromium/net/socket/next_proto.h +++ b/chromium/net/socket/next_proto.h @@ -13,23 +13,19 @@ namespace net { // protocols that we recognise. enum NextProto { kProtoUnknown = 0, - kProtoHTTP11 = 1, + kProtoHTTP11, kProtoMinimumVersion = kProtoHTTP11, - // TODO(akalin): Stop advertising SPDY/1 and remove this. - kProtoSPDY1 = 2, - kProtoSPDYMinimumVersion = kProtoSPDY1, - kProtoSPDY2 = 3, - // TODO(akalin): Stop adverising SPDY/2.1, too. - kProtoSPDY21 = 4, - kProtoSPDY3 = 5, - kProtoSPDY31 = 6, - kProtoSPDY4a2 = 7, + kProtoDeprecatedSPDY2, + kProtoSPDYMinimumVersion = kProtoDeprecatedSPDY2, + kProtoSPDY3, + kProtoSPDY31, + kProtoSPDY4a2, // We lump in HTTP/2 with the SPDY protocols for now. - kProtoHTTP2Draft04 = 8, + kProtoHTTP2Draft04, kProtoSPDYMaximumVersion = kProtoHTTP2Draft04, - kProtoQUIC1SPDY3 = 9, + kProtoQUIC1SPDY3, kProtoMaximumVersion = kProtoQUIC1SPDY3, }; diff --git a/chromium/net/socket/nss_ssl_util.cc b/chromium/net/socket/nss_ssl_util.cc index 7e3aee430c4..33e7e6b89ac 100644 --- a/chromium/net/socket/nss_ssl_util.cc +++ b/chromium/net/socket/nss_ssl_util.cc @@ -13,6 +13,7 @@ #include <string> #include "base/bind.h" +#include "base/cpu.h" #include "base/lazy_instance.h" #include "base/logging.h" #include "base/memory/singleton.h" @@ -22,16 +23,67 @@ #include "crypto/nss_util.h" #include "net/base/net_errors.h" #include "net/base/net_log.h" +#include "net/base/nss_memio.h" #if defined(OS_WIN) #include "base/win/windows_version.h" #endif +namespace { + +// CiphersRemove takes a zero-terminated array of cipher suite ids in +// |to_remove| and sets every instance of them in |ciphers| to zero. It returns +// true if it found and removed every element of |to_remove|. It assumes that +// there are no duplicates in |ciphers| nor in |to_remove|. +bool CiphersRemove(const uint16* to_remove, uint16* ciphers, size_t num) { + size_t i, found = 0; + + for (i = 0; ; i++) { + if (to_remove[i] == 0) + break; + + for (size_t j = 0; j < num; j++) { + if (to_remove[i] == ciphers[j]) { + ciphers[j] = 0; + found++; + break; + } + } + } + + return found == i; +} + +// CiphersCompact takes an array of cipher suite ids in |ciphers|, where some +// entries are zero, and moves the entries so that all the non-zero elements +// are compacted at the end of the array. +void CiphersCompact(uint16* ciphers, size_t num) { + size_t j = num - 1; + + for (size_t i = num - 1; i < num; i--) { + if (ciphers[i] == 0) + continue; + ciphers[j--] = ciphers[i]; + } +} + +// CiphersCopy copies the zero-terminated array |in| to |out|. It returns the +// number of cipher suite ids copied. +size_t CiphersCopy(const uint16* in, uint16* out) { + for (size_t i = 0; ; i++) { + if (in[i] == 0) + return i; + out[i] = in[i]; + } +} + +} // anonymous namespace + namespace net { class NSSSSLInitSingleton { public: - NSSSSLInitSingleton() { + NSSSSLInitSingleton() : model_fd_(NULL) { crypto::EnsureNSSInit(); NSS_SetDomesticPolicy(); @@ -81,14 +133,72 @@ class NSSSSLInitSingleton { // Enable SSL. SSL_OptionSetDefault(SSL_SECURITY, PR_TRUE); + // Calculate the order of ciphers that we'll use for NSS sockets. (Note + // that, even if a cipher is specified in the ordering, it must still be + // enabled in order to be included in a ClientHello.) + // + // Our top preference cipher suites are either forward-secret AES-GCM or + // forward-secret ChaCha20-Poly1305. If the local machine has AES-NI then + // we prefer AES-GCM, otherwise ChaCha20. The remainder of the cipher suite + // preference is inheriented from NSS. */ + static const uint16 chacha_ciphers[] = { + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + 0, + }; + static const uint16 aes_gcm_ciphers[] = { + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, + 0, + }; + scoped_ptr<uint16[]> ciphers(new uint16[num_ciphers]); + memcpy(ciphers.get(), ssl_ciphers, sizeof(uint16)*num_ciphers); + + if (CiphersRemove(chacha_ciphers, ciphers.get(), num_ciphers) && + CiphersRemove(aes_gcm_ciphers, ciphers.get(), num_ciphers)) { + CiphersCompact(ciphers.get(), num_ciphers); + + const uint16* preference_ciphers = chacha_ciphers; + const uint16* other_ciphers = aes_gcm_ciphers; + base::CPU cpu; + + if (cpu.has_aesni() && cpu.has_avx()) { + preference_ciphers = aes_gcm_ciphers; + other_ciphers = chacha_ciphers; + } + unsigned i = CiphersCopy(preference_ciphers, ciphers.get()); + CiphersCopy(other_ciphers, &ciphers[i]); + + if ((model_fd_ = memio_CreateIOLayer(1, 1)) == NULL || + SSL_ImportFD(NULL, model_fd_) == NULL || + SECSuccess != + SSL_CipherOrderSet(model_fd_, ciphers.get(), num_ciphers)) { + NOTREACHED(); + if (model_fd_) { + PR_Close(model_fd_); + model_fd_ = NULL; + } + } + } + // All other SSL options are set per-session by SSLClientSocket and // SSLServerSocket. } + PRFileDesc* GetModelSocket() { + return model_fd_; + } + ~NSSSSLInitSingleton() { // Have to clear the cache, or NSS_Shutdown fails with SEC_ERROR_BUSY. SSL_ClearSessionCache(); + if (model_fd_) + PR_Close(model_fd_); } + + private: + PRFileDesc* model_fd_; }; static base::LazyInstance<NSSSSLInitSingleton> g_nss_ssl_init_singleton = @@ -107,6 +217,10 @@ void EnsureNSSSSLInit() { g_nss_ssl_init_singleton.Get(); } +PRFileDesc* GetNSSModelSocket() { + return g_nss_ssl_init_singleton.Get().GetModelSocket(); +} + // Map a Chromium net error code to an NSS error code. // See _MD_unix_map_default_error in the NSS source // tree for inspiration. @@ -237,6 +351,8 @@ int MapNSSError(PRErrorCode err) { // was used earlier. case SSL_ERROR_WRONG_CERTIFICATE: return ERR_SSL_SERVER_CERT_CHANGED; + case SSL_ERROR_INAPPROPRIATE_FALLBACK_ALERT: + return ERR_SSL_INAPPROPRIATE_FALLBACK; default: { if (IS_SSL_ERROR(err)) { diff --git a/chromium/net/socket/nss_ssl_util.h b/chromium/net/socket/nss_ssl_util.h index 09ae3562cd7..3aed7bf6b4a 100644 --- a/chromium/net/socket/nss_ssl_util.h +++ b/chromium/net/socket/nss_ssl_util.h @@ -9,6 +9,7 @@ #define NET_SOCKET_NSS_SSL_UTIL_H_ #include <prerror.h> +#include <prio.h> #include "net/base/net_export.h" @@ -27,6 +28,10 @@ void LogFailedNSSFunction(const BoundNetLog& net_log, // Map network error code to NSS error code. PRErrorCode MapErrorToNSS(int result); +// GetNSSModelSocket returns either NULL, or an NSS socket that can be passed +// to |SSL_ImportFD| in order to inherit some default options. +PRFileDesc* GetNSSModelSocket(); + // Map NSS error code to network error code. int MapNSSError(PRErrorCode err); diff --git a/chromium/net/socket/socket_descriptor.cc b/chromium/net/socket/socket_descriptor.cc index 5a2e53cab4d..81787a29d25 100644 --- a/chromium/net/socket/socket_descriptor.cc +++ b/chromium/net/socket/socket_descriptor.cc @@ -12,6 +12,8 @@ #include "base/basictypes.h" #if defined(OS_WIN) +#include <ws2tcpip.h> +#include "base/win/windows_version.h" #include "net/base/winsock_init.h" #endif @@ -32,7 +34,18 @@ void PlatformSocketFactory::SetInstance(PlatformSocketFactory* factory) { SocketDescriptor CreateSocketDefault(int family, int type, int protocol) { #if defined(OS_WIN) EnsureWinsockInit(); - return ::WSASocket(family, type, protocol, NULL, 0, WSA_FLAG_OVERLAPPED); + SocketDescriptor result = ::WSASocket(family, type, protocol, NULL, 0, + WSA_FLAG_OVERLAPPED); + if (result != kInvalidSocket && family == AF_INET6 && + base::win::OSInfo::GetInstance()->version() >= base::win::VERSION_VISTA) { + DWORD value = 0; + if (setsockopt(result, IPPROTO_IPV6, IPV6_V6ONLY, + reinterpret_cast<const char*>(&value), sizeof(value))) { + closesocket(result); + return kInvalidSocket; + } + } + return result; #else // OS_WIN return ::socket(family, type, protocol); #endif // OS_WIN diff --git a/chromium/net/socket/socket_test_util.cc b/chromium/net/socket/socket_test_util.cc index 78e9e7ce9c4..148b11c5698 100644 --- a/chromium/net/socket/socket_test_util.cc +++ b/chromium/net/socket/socket_test_util.cc @@ -698,9 +698,9 @@ void MockClientSocketFactory::ClearSSLSessionCache() { const char MockClientSocket::kTlsUnique[] = "MOCK_TLSUNIQ"; MockClientSocket::MockClientSocket(const BoundNetLog& net_log) - : weak_factory_(this), - connected_(false), - net_log_(net_log) { + : connected_(false), + net_log_(net_log), + weak_factory_(this) { IPAddressNumber ip; CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip)); peer_addr_ = IPEndPoint(ip, 0); @@ -957,7 +957,7 @@ int MockTCPClientSocket::CompleteRead() { was_used_to_convey_data_ = true; // Save the pending async IO data and reset our |pending_| state. - IOBuffer* buf = pending_buf_; + scoped_refptr<IOBuffer> buf = pending_buf_; int buf_len = pending_buf_len_; CompletionCallback callback = pending_callback_; pending_buf_ = NULL; @@ -1556,7 +1556,7 @@ int MockUDPClientSocket::CompleteRead() { DCHECK(pending_buf_len_ > 0); // Save the pending async IO data and reset our |pending_| state. - IOBuffer* buf = pending_buf_; + scoped_refptr<IOBuffer> buf = pending_buf_; int buf_len = pending_buf_len_; CompletionCallback callback = pending_callback_; pending_buf_ = NULL; diff --git a/chromium/net/socket/socket_test_util.h b/chromium/net/socket/socket_test_util.h index e4e56522c92..8df1d69538a 100644 --- a/chromium/net/socket/socket_test_util.h +++ b/chromium/net/socket/socket_test_util.h @@ -97,43 +97,76 @@ struct MockReadWrite { }; // Default - MockReadWrite() : mode(SYNCHRONOUS), result(0), data(NULL), data_len(0), - sequence_number(0), time_stamp(base::Time::Now()) {} + MockReadWrite() + : mode(SYNCHRONOUS), + result(0), + data(NULL), + data_len(0), + sequence_number(0), + time_stamp(base::Time::Now()) {} // Read/write failure (no data). - MockReadWrite(IoMode io_mode, int result) : mode(io_mode), result(result), - data(NULL), data_len(0), sequence_number(0), - time_stamp(base::Time::Now()) { } + MockReadWrite(IoMode io_mode, int result) + : mode(io_mode), + result(result), + data(NULL), + data_len(0), + sequence_number(0), + time_stamp(base::Time::Now()) {} // Read/write failure (no data), with sequence information. - MockReadWrite(IoMode io_mode, int result, int seq) : mode(io_mode), - result(result), data(NULL), data_len(0), sequence_number(seq), - time_stamp(base::Time::Now()) { } + MockReadWrite(IoMode io_mode, int result, int seq) + : mode(io_mode), + result(result), + data(NULL), + data_len(0), + sequence_number(seq), + time_stamp(base::Time::Now()) {} // Asynchronous read/write success (inferred data length). - explicit MockReadWrite(const char* data) : mode(ASYNC), result(0), - data(data), data_len(strlen(data)), sequence_number(0), - time_stamp(base::Time::Now()) { } + explicit MockReadWrite(const char* data) + : mode(ASYNC), + result(0), + data(data), + data_len(strlen(data)), + sequence_number(0), + time_stamp(base::Time::Now()) {} // Read/write success (inferred data length). - MockReadWrite(IoMode io_mode, const char* data) : mode(io_mode), result(0), - data(data), data_len(strlen(data)), sequence_number(0), - time_stamp(base::Time::Now()) { } + MockReadWrite(IoMode io_mode, const char* data) + : mode(io_mode), + result(0), + data(data), + data_len(strlen(data)), + sequence_number(0), + time_stamp(base::Time::Now()) {} // Read/write success. - MockReadWrite(IoMode io_mode, const char* data, int data_len) : mode(io_mode), - result(0), data(data), data_len(data_len), sequence_number(0), - time_stamp(base::Time::Now()) { } + MockReadWrite(IoMode io_mode, const char* data, int data_len) + : mode(io_mode), + result(0), + data(data), + data_len(data_len), + sequence_number(0), + time_stamp(base::Time::Now()) {} // Read/write success (inferred data length) with sequence information. - MockReadWrite(IoMode io_mode, int seq, const char* data) : mode(io_mode), - result(0), data(data), data_len(strlen(data)), sequence_number(seq), - time_stamp(base::Time::Now()) { } + MockReadWrite(IoMode io_mode, int seq, const char* data) + : mode(io_mode), + result(0), + data(data), + data_len(strlen(data)), + sequence_number(seq), + time_stamp(base::Time::Now()) {} // Read/write success with sequence information. - MockReadWrite(IoMode io_mode, const char* data, int data_len, int seq) : - mode(io_mode), result(0), data(data), data_len(data_len), - sequence_number(seq), time_stamp(base::Time::Now()) { } + MockReadWrite(IoMode io_mode, const char* data, int data_len, int seq) + : mode(io_mode), + result(0), + data(data), + data_len(data_len), + sequence_number(seq), + time_stamp(base::Time::Now()) {} IoMode mode; int result; @@ -143,18 +176,16 @@ struct MockReadWrite { // For OrderedSocketData, which only allows reads to occur in a particular // sequence. If a read occurs before the given |sequence_number| is reached, // an ERR_IO_PENDING is returned. - int sequence_number; // The sequence number at which a read is allowed - // to occur. - base::Time time_stamp; // The time stamp at which the operation occurred. + int sequence_number; // The sequence number at which a read is allowed + // to occur. + base::Time time_stamp; // The time stamp at which the operation occurred. }; typedef MockReadWrite<MOCK_READ> MockRead; typedef MockReadWrite<MOCK_WRITE> MockWrite; struct MockWriteResult { - MockWriteResult(IoMode io_mode, int result) - : mode(io_mode), - result(result) {} + MockWriteResult(IoMode io_mode, int result) : mode(io_mode), result(result) {} IoMode mode; int result; @@ -208,8 +239,10 @@ class AsyncSocket { class StaticSocketDataProvider : public SocketDataProvider { public: StaticSocketDataProvider(); - StaticSocketDataProvider(MockRead* reads, size_t reads_count, - MockWrite* writes, size_t writes_count); + StaticSocketDataProvider(MockRead* reads, + size_t reads_count, + MockWrite* writes, + size_t writes_count); virtual ~StaticSocketDataProvider(); // These functions get access to the next available read and write data. @@ -231,7 +264,7 @@ class StaticSocketDataProvider : public SocketDataProvider { // SocketDataProvider implementation. virtual MockRead GetNextRead() OVERRIDE; virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE; - ; virtual void Reset() OVERRIDE; + virtual void Reset() OVERRIDE; private: MockRead* reads_; @@ -267,9 +300,7 @@ class DynamicSocketDataProvider : public SocketDataProvider { // The next time there is a read from this socket, it will return |data|. // Before calling SimulateRead next time, the previous data must be consumed. void SimulateRead(const char* data, size_t length); - void SimulateRead(const char* data) { - SimulateRead(data, std::strlen(data)); - } + void SimulateRead(const char* data) { SimulateRead(data, std::strlen(data)); } private: std::deque<MockRead> reads_; @@ -316,8 +347,10 @@ class DelayedSocketData : public StaticSocketDataProvider { // Note: For stream sockets, the MockRead list must end with a EOF, e.g., a // MockRead(true, 0, 0); DelayedSocketData(int write_delay, - MockRead* reads, size_t reads_count, - MockWrite* writes, size_t writes_count); + MockRead* reads, + size_t reads_count, + MockWrite* writes, + size_t writes_count); // |connect| the result for the connect phase. // |reads| the list of MockRead completions. @@ -326,9 +359,12 @@ class DelayedSocketData : public StaticSocketDataProvider { // |writes| the list of MockWrite completions. // Note: For stream sockets, the MockRead list must end with a EOF, e.g., a // MockRead(true, 0, 0); - DelayedSocketData(const MockConnect& connect, int write_delay, - MockRead* reads, size_t reads_count, - MockWrite* writes, size_t writes_count); + DelayedSocketData(const MockConnect& connect, + int write_delay, + MockRead* reads, + size_t reads_count, + MockWrite* writes, + size_t writes_count); virtual ~DelayedSocketData(); void ForceNextRead(); @@ -342,7 +378,10 @@ class DelayedSocketData : public StaticSocketDataProvider { private: int write_delay_; bool read_in_progress_; + base::WeakPtrFactory<DelayedSocketData> weak_factory_; + + DISALLOW_COPY_AND_ASSIGN(DelayedSocketData); }; // A DataProvider where the reads are ordered. @@ -363,8 +402,10 @@ class OrderedSocketData : public StaticSocketDataProvider { // Note: All MockReads and MockWrites must be async. // Note: For stream sockets, the MockRead list must end with a EOF, e.g., a // MockRead(true, 0, 0); - OrderedSocketData(MockRead* reads, size_t reads_count, - MockWrite* writes, size_t writes_count); + OrderedSocketData(MockRead* reads, + size_t reads_count, + MockWrite* writes, + size_t writes_count); virtual ~OrderedSocketData(); // |connect| the result for the connect phase. @@ -374,8 +415,10 @@ class OrderedSocketData : public StaticSocketDataProvider { // Note: For stream sockets, the MockRead list must end with a EOF, e.g., a // MockRead(true, 0, 0); OrderedSocketData(const MockConnect& connect, - MockRead* reads, size_t reads_count, - MockWrite* writes, size_t writes_count); + MockRead* reads, + size_t reads_count, + MockWrite* writes, + size_t writes_count); // Posts a quit message to the current message loop, if one is running. void EndLoop(); @@ -390,7 +433,10 @@ class OrderedSocketData : public StaticSocketDataProvider { int sequence_number_; int loop_stop_stage_; bool blocked_; + base::WeakPtrFactory<OrderedSocketData> weak_factory_; + + DISALLOW_COPY_AND_ASSIGN(OrderedSocketData); }; class DeterministicMockTCPClientSocket; @@ -452,8 +498,7 @@ class DeterministicMockTCPClientSocket; // // For examples of how to use this class, see: // deterministic_socket_data_unittests.cc -class DeterministicSocketData - : public StaticSocketDataProvider { +class DeterministicSocketData : public StaticSocketDataProvider { public: // The Delegate is an abstract interface which handles the communication from // the DeterministicSocketData to the Deterministic MockSocket. The @@ -481,8 +526,10 @@ class DeterministicSocketData // |reads| the list of MockRead completions. // |writes| the list of MockWrite completions. - DeterministicSocketData(MockRead* reads, size_t reads_count, - MockWrite* writes, size_t writes_count); + DeterministicSocketData(MockRead* reads, + size_t reads_count, + MockWrite* writes, + size_t writes_count); virtual ~DeterministicSocketData(); // Consume all the data up to the give stop point (via SetStop()). @@ -501,9 +548,7 @@ class DeterministicSocketData MockRead& current_read() { return current_read_; } MockWrite& current_write() { return current_write_; } int sequence_number() const { return sequence_number_; } - void set_delegate(base::WeakPtr<Delegate> delegate) { - delegate_ = delegate; - } + void set_delegate(base::WeakPtr<Delegate> delegate) { delegate_ = delegate; } // StaticSocketDataProvider: @@ -524,8 +569,10 @@ class DeterministicSocketData void NextStep(); - void VerifyCorrectSequenceNumbers(MockRead* reads, size_t reads_count, - MockWrite* writes, size_t writes_count); + void VerifyCorrectSequenceNumbers(MockRead* reads, + size_t reads_count, + MockWrite* writes, + size_t writes_count); int sequence_number_; MockRead current_read_; @@ -540,7 +587,7 @@ class DeterministicSocketData // Holds an array of SocketDataProvider elements. As Mock{TCP,SSL}StreamSocket // objects get instantiated, they take their data from the i'th element of this // array. -template<typename T> +template <typename T> class SocketDataProviderArray { public: SocketDataProviderArray() : next_index_(0) {} @@ -557,9 +604,7 @@ class SocketDataProviderArray { size_t next_index() { return next_index_; } - void ResetNextIndex() { - next_index_ = 0; - } + void ResetNextIndex() { next_index_ = 0; } private: // Index of the next |data_providers_| element to use. Not an iterator @@ -624,9 +669,11 @@ class MockClientSocket : public SSLClientSocket { explicit MockClientSocket(const BoundNetLog& net_log); // Socket implementation. - virtual int Read(IOBuffer* buf, int buf_len, + virtual int Read(IOBuffer* buf, + int buf_len, const CompletionCallback& callback) = 0; - virtual int Write(IOBuffer* buf, int buf_len, + virtual int Write(IOBuffer* buf, + int buf_len, const CompletionCallback& callback) = 0; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; @@ -643,8 +690,8 @@ class MockClientSocket : public SSLClientSocket { virtual void SetOmniboxSpeculation() OVERRIDE {} // SSLClientSocket implementation. - virtual void GetSSLCertRequestInfo( - SSLCertRequestInfo* cert_request_info) OVERRIDE; + virtual void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) + OVERRIDE; virtual int ExportKeyingMaterial(const base::StringPiece& label, bool has_context, const base::StringPiece& context, @@ -660,8 +707,6 @@ class MockClientSocket : public SSLClientSocket { void RunCallbackAsync(const CompletionCallback& callback, int result); void RunCallback(const CompletionCallback& callback, int result); - base::WeakPtrFactory<MockClientSocket> weak_factory_; - // True if Connect completed successfully and Disconnect hasn't been called. bool connected_; @@ -669,20 +714,27 @@ class MockClientSocket : public SSLClientSocket { IPEndPoint peer_addr_; BoundNetLog net_log_; + + base::WeakPtrFactory<MockClientSocket> weak_factory_; + + DISALLOW_COPY_AND_ASSIGN(MockClientSocket); }; class MockTCPClientSocket : public MockClientSocket, public AsyncSocket { public: - MockTCPClientSocket(const AddressList& addresses, net::NetLog* net_log, + MockTCPClientSocket(const AddressList& addresses, + net::NetLog* net_log, SocketDataProvider* socket); virtual ~MockTCPClientSocket(); const AddressList& addresses() const { return addresses_; } // Socket implementation. - virtual int Read(IOBuffer* buf, int buf_len, + virtual int Read(IOBuffer* buf, + int buf_len, const CompletionCallback& callback) OVERRIDE; - virtual int Write(IOBuffer* buf, int buf_len, + virtual int Write(IOBuffer* buf, + int buf_len, const CompletionCallback& callback) OVERRIDE; // StreamSocket implementation. @@ -716,10 +768,12 @@ class MockTCPClientSocket : public MockClientSocket, public AsyncSocket { bool peer_closed_connection_; // While an asynchronous IO is pending, we save our user-buffer state. - IOBuffer* pending_buf_; + scoped_refptr<IOBuffer> pending_buf_; int pending_buf_len_; CompletionCallback pending_callback_; bool was_used_to_convey_data_; + + DISALLOW_COPY_AND_ASSIGN(MockTCPClientSocket); }; // DeterministicSocketHelper is a helper class that can be used @@ -740,10 +794,8 @@ class DeterministicSocketHelper { void CompleteWrite(); int CompleteRead(); - int Write(IOBuffer* buf, int buf_len, - const CompletionCallback& callback); - int Read(IOBuffer* buf, int buf_len, - const CompletionCallback& callback); + int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback); + int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback); const BoundNetLog& net_log() const { return net_log_; } @@ -788,9 +840,11 @@ class DeterministicMockUDPClientSocket virtual int CompleteRead() OVERRIDE; // Socket implementation. - virtual int Read(IOBuffer* buf, int buf_len, + virtual int Read(IOBuffer* buf, + int buf_len, const CompletionCallback& callback) OVERRIDE; - virtual int Write(IOBuffer* buf, int buf_len, + virtual int Write(IOBuffer* buf, + int buf_len, const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; @@ -812,6 +866,8 @@ class DeterministicMockUDPClientSocket bool connected_; IPEndPoint peer_address_; DeterministicSocketHelper helper_; + + DISALLOW_COPY_AND_ASSIGN(DeterministicMockUDPClientSocket); }; // Mock TCP socket to be used in conjunction with DeterministicSocketData. @@ -832,9 +888,11 @@ class DeterministicMockTCPClientSocket virtual int CompleteRead() OVERRIDE; // Socket: - virtual int Write(IOBuffer* buf, int buf_len, + virtual int Write(IOBuffer* buf, + int buf_len, const CompletionCallback& callback) OVERRIDE; - virtual int Read(IOBuffer* buf, int buf_len, + virtual int Read(IOBuffer* buf, + int buf_len, const CompletionCallback& callback) OVERRIDE; // StreamSocket: @@ -853,21 +911,24 @@ class DeterministicMockTCPClientSocket private: DeterministicSocketHelper helper_; + + DISALLOW_COPY_AND_ASSIGN(DeterministicMockTCPClientSocket); }; class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { public: - MockSSLClientSocket( - scoped_ptr<ClientSocketHandle> transport_socket, - const HostPortPair& host_and_port, - const SSLConfig& ssl_config, - SSLSocketDataProvider* socket); + MockSSLClientSocket(scoped_ptr<ClientSocketHandle> transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + SSLSocketDataProvider* socket); virtual ~MockSSLClientSocket(); // Socket implementation. - virtual int Read(IOBuffer* buf, int buf_len, + virtual int Read(IOBuffer* buf, + int buf_len, const CompletionCallback& callback) OVERRIDE; - virtual int Write(IOBuffer* buf, int buf_len, + virtual int Write(IOBuffer* buf, + int buf_len, const CompletionCallback& callback) OVERRIDE; // StreamSocket implementation. @@ -881,13 +942,12 @@ class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; // SSLClientSocket implementation. - virtual void GetSSLCertRequestInfo( - SSLCertRequestInfo* cert_request_info) OVERRIDE; + virtual void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) + OVERRIDE; virtual NextProtoStatus GetNextProto(std::string* proto, std::string* server_protos) OVERRIDE; virtual bool set_was_npn_negotiated(bool negotiated) OVERRIDE; - virtual void set_protocol_negotiated( - NextProto protocol_negotiated) OVERRIDE; + virtual void set_protocol_negotiated(NextProto protocol_negotiated) OVERRIDE; virtual NextProto GetNegotiatedProtocol() const OVERRIDE; // This MockSocket does not implement the manual async IO feature. @@ -899,7 +959,7 @@ class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { virtual ServerBoundCertService* GetServerBoundCertService() const OVERRIDE; private: - static void ConnectCallback(MockSSLClientSocket *ssl_client_socket, + static void ConnectCallback(MockSSLClientSocket* ssl_client_socket, const CompletionCallback& callback, int rv); @@ -909,19 +969,21 @@ class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { bool new_npn_value_; bool is_protocol_negotiated_set_; NextProto protocol_negotiated_; + + DISALLOW_COPY_AND_ASSIGN(MockSSLClientSocket); }; -class MockUDPClientSocket - : public DatagramClientSocket, - public AsyncSocket { +class MockUDPClientSocket : public DatagramClientSocket, public AsyncSocket { public: MockUDPClientSocket(SocketDataProvider* data, net::NetLog* net_log); virtual ~MockUDPClientSocket(); // Socket implementation. - virtual int Read(IOBuffer* buf, int buf_len, + virtual int Read(IOBuffer* buf, + int buf_len, const CompletionCallback& callback) OVERRIDE; - virtual int Write(IOBuffer* buf, int buf_len, + virtual int Write(IOBuffer* buf, + int buf_len, const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; @@ -955,7 +1017,7 @@ class MockUDPClientSocket IPEndPoint peer_addr_; // While an asynchronous IO is pending, we save our user-buffer state. - IOBuffer* pending_buf_; + scoped_refptr<IOBuffer> pending_buf_; int pending_buf_len_; CompletionCallback pending_callback_; @@ -1009,12 +1071,15 @@ class ClientSocketPoolTest { RequestPriority priority, const scoped_refptr<typename PoolType::SocketParams>& socket_params) { DCHECK(socket_pool); - TestSocketRequest* request = new TestSocketRequest(&request_order_, - &completion_count_); + TestSocketRequest* request = + new TestSocketRequest(&request_order_, &completion_count_); requests_.push_back(request); - int rv = request->handle()->Init( - group_name, socket_params, priority, request->callback(), - socket_pool, BoundNetLog()); + int rv = request->handle()->Init(group_name, + socket_params, + priority, + request->callback(), + socket_pool, + BoundNetLog()); if (rv != ERR_IO_PENDING) request_order_.push_back(request); return rv; @@ -1045,6 +1110,8 @@ class ClientSocketPoolTest { ScopedVector<TestSocketRequest> requests_; std::vector<TestSocketRequest*> request_order_; size_t completion_count_; + + DISALLOW_COPY_AND_ASSIGN(ClientSocketPoolTest); }; class MockTransportSocketParams @@ -1052,6 +1119,8 @@ class MockTransportSocketParams private: friend class base::RefCounted<MockTransportSocketParams>; ~MockTransportSocketParams() {} + + DISALLOW_COPY_AND_ASSIGN(MockTransportSocketParams); }; class MockTransportClientSocketPool : public TransportClientSocketPool { @@ -1060,7 +1129,8 @@ class MockTransportClientSocketPool : public TransportClientSocketPool { class MockConnectJob { public: - MockConnectJob(scoped_ptr<StreamSocket> socket, ClientSocketHandle* handle, + MockConnectJob(scoped_ptr<StreamSocket> socket, + ClientSocketHandle* handle, const CompletionCallback& callback); ~MockConnectJob(); @@ -1077,11 +1147,10 @@ class MockTransportClientSocketPool : public TransportClientSocketPool { DISALLOW_COPY_AND_ASSIGN(MockConnectJob); }; - MockTransportClientSocketPool( - int max_sockets, - int max_sockets_per_group, - ClientSocketPoolHistograms* histograms, - ClientSocketFactory* socket_factory); + MockTransportClientSocketPool(int max_sockets, + int max_sockets_per_group, + ClientSocketPoolHistograms* histograms, + ClientSocketFactory* socket_factory); virtual ~MockTransportClientSocketPool(); @@ -1163,15 +1232,16 @@ class DeterministicMockClientSocketFactory : public ClientSocketFactory { std::vector<DeterministicMockTCPClientSocket*> tcp_client_sockets_; std::vector<DeterministicMockUDPClientSocket*> udp_client_sockets_; std::vector<MockSSLClientSocket*> ssl_client_sockets_; + + DISALLOW_COPY_AND_ASSIGN(DeterministicMockClientSocketFactory); }; class MockSOCKSClientSocketPool : public SOCKSClientSocketPool { public: - MockSOCKSClientSocketPool( - int max_sockets, - int max_sockets_per_group, - ClientSocketPoolHistograms* histograms, - TransportClientSocketPool* transport_pool); + MockSOCKSClientSocketPool(int max_sockets, + int max_sockets_per_group, + ClientSocketPoolHistograms* histograms, + TransportClientSocketPool* transport_pool); virtual ~MockSOCKSClientSocketPool(); diff --git a/chromium/net/socket/socks5_client_socket.cc b/chromium/net/socket/socks5_client_socket.cc index 537b584a932..004b67c669f 100644 --- a/chromium/net/socket/socks5_client_socket.cc +++ b/chromium/net/socket/socks5_client_socket.cc @@ -254,7 +254,6 @@ int SOCKS5ClientSocket::DoLoop(int last_io_result) { } const char kSOCKS5GreetWriteData[] = { 0x05, 0x01, 0x00 }; // no authentication -const char kSOCKS5GreetReadData[] = { 0x05, 0x00 }; int SOCKS5ClientSocket::DoGreetWrite() { // Since we only have 1 byte to send the hostname length in, if the diff --git a/chromium/net/socket/socks_client_socket_pool_unittest.cc b/chromium/net/socket/socks_client_socket_pool_unittest.cc index 4463e171f84..b2b8655ee22 100644 --- a/chromium/net/socket/socks_client_socket_pool_unittest.cc +++ b/chromium/net/socket/socks_client_socket_pool_unittest.cc @@ -143,7 +143,7 @@ TEST_F(SOCKSClientSocketPoolTest, Simple) { // Make sure that SOCKSConnectJob passes on its priority to its // socket request on Init. TEST_F(SOCKSClientSocketPoolTest, SetSocketRequestPriorityOnInit) { - for (int i = MINIMUM_PRIORITY; i < NUM_PRIORITIES; ++i) { + for (int i = MINIMUM_PRIORITY; i <= MAXIMUM_PRIORITY; ++i) { RequestPriority priority = static_cast<RequestPriority>(i); SOCKS5MockData data(SYNCHRONOUS); data.data_provider()->set_connect_data(MockConnect(SYNCHRONOUS, OK)); @@ -162,7 +162,7 @@ TEST_F(SOCKSClientSocketPoolTest, SetSocketRequestPriorityOnInit) { // Make sure that SOCKSConnectJob passes on its priority to its // HostResolver request (for non-SOCKS5) on Init. TEST_F(SOCKSClientSocketPoolTest, SetResolvePriorityOnInit) { - for (int i = MINIMUM_PRIORITY; i < NUM_PRIORITIES; ++i) { + for (int i = MINIMUM_PRIORITY; i <= MAXIMUM_PRIORITY; ++i) { RequestPriority priority = static_cast<RequestPriority>(i); SOCKS5MockData data(SYNCHRONOUS); data.data_provider()->set_connect_data(MockConnect(SYNCHRONOUS, OK)); diff --git a/chromium/net/socket/ssl_client_socket.cc b/chromium/net/socket/ssl_client_socket.cc index 54f66a1f681..d96a720e428 100644 --- a/chromium/net/socket/ssl_client_socket.cc +++ b/chromium/net/socket/ssl_client_socket.cc @@ -4,7 +4,11 @@ #include "net/socket/ssl_client_socket.h" +#include "base/metrics/histogram.h" #include "base/strings/string_util.h" +#include "crypto/ec_private_key.h" +#include "net/ssl/server_bound_cert_service.h" +#include "net/ssl/ssl_config_service.h" namespace net { @@ -12,7 +16,9 @@ SSLClientSocket::SSLClientSocket() : was_npn_negotiated_(false), was_spdy_negotiated_(false), protocol_negotiated_(kProtoUnknown), - channel_id_sent_(false) { + channel_id_sent_(false), + signed_cert_timestamps_received_(false), + stapled_ocsp_response_received_(false) { } // static @@ -20,10 +26,8 @@ NextProto SSLClientSocket::NextProtoFromString( const std::string& proto_string) { if (proto_string == "http1.1" || proto_string == "http/1.1") { return kProtoHTTP11; - } else if (proto_string == "spdy/1") { - return kProtoSPDY1; } else if (proto_string == "spdy/2") { - return kProtoSPDY2; + return kProtoDeprecatedSPDY2; } else if (proto_string == "spdy/3") { return kProtoSPDY3; } else if (proto_string == "spdy/3.1") { @@ -44,9 +48,7 @@ const char* SSLClientSocket::NextProtoToString(NextProto next_proto) { switch (next_proto) { case kProtoHTTP11: return "http/1.1"; - case kProtoSPDY1: - return "spdy/1"; - case kProtoSPDY2: + case kProtoDeprecatedSPDY2: return "spdy/2"; case kProtoSPDY3: return "spdy/3"; @@ -58,7 +60,6 @@ const char* SSLClientSocket::NextProtoToString(NextProto next_proto) { return "HTTP-draft-04/2.0"; case kProtoQUIC1SPDY3: return "quic/1+spdy/3"; - case kProtoSPDY21: case kProtoUnknown: break; } @@ -145,4 +146,68 @@ void SSLClientSocket::set_channel_id_sent(bool channel_id_sent) { channel_id_sent_ = channel_id_sent; } +void SSLClientSocket::set_signed_cert_timestamps_received( + bool signed_cert_timestamps_received) { + signed_cert_timestamps_received_ = signed_cert_timestamps_received; +} + +void SSLClientSocket::set_stapled_ocsp_response_received( + bool stapled_ocsp_response_received) { + stapled_ocsp_response_received_ = stapled_ocsp_response_received; +} + +// static +void SSLClientSocket::RecordChannelIDSupport( + ServerBoundCertService* server_bound_cert_service, + bool negotiated_channel_id, + bool channel_id_enabled, + bool supports_ecc) { + // Since this enum is used for a histogram, do not change or re-use values. + enum { + DISABLED = 0, + CLIENT_ONLY = 1, + CLIENT_AND_SERVER = 2, + CLIENT_NO_ECC = 3, + CLIENT_BAD_SYSTEM_TIME = 4, + CLIENT_NO_SERVER_BOUND_CERT_SERVICE = 5, + DOMAIN_BOUND_CERT_USAGE_MAX + } supported = DISABLED; + if (negotiated_channel_id) { + supported = CLIENT_AND_SERVER; + } else if (channel_id_enabled) { + if (!server_bound_cert_service) + supported = CLIENT_NO_SERVER_BOUND_CERT_SERVICE; + else if (!supports_ecc) + supported = CLIENT_NO_ECC; + else if (!server_bound_cert_service->IsSystemTimeValid()) + supported = CLIENT_BAD_SYSTEM_TIME; + else + supported = CLIENT_ONLY; + } + UMA_HISTOGRAM_ENUMERATION("DomainBoundCerts.Support", supported, + DOMAIN_BOUND_CERT_USAGE_MAX); +} + +// static +bool SSLClientSocket::IsChannelIDEnabled( + const SSLConfig& ssl_config, + ServerBoundCertService* server_bound_cert_service) { + if (!ssl_config.channel_id_enabled) + return false; + if (!server_bound_cert_service) { + DVLOG(1) << "NULL server_bound_cert_service_, not enabling channel ID."; + return false; + } + if (!crypto::ECPrivateKey::IsSupported()) { + DVLOG(1) << "Elliptic Curve not supported, not enabling channel ID."; + return false; + } + if (!server_bound_cert_service->IsSystemTimeValid()) { + DVLOG(1) << "System time is not within the supported range for certificate " + "generation, not enabling channel ID."; + return false; + } + return true; +} + } // namespace net diff --git a/chromium/net/socket/ssl_client_socket.h b/chromium/net/socket/ssl_client_socket.h index 41ee0873347..410062dc5a9 100644 --- a/chromium/net/socket/ssl_client_socket.h +++ b/chromium/net/socket/ssl_client_socket.h @@ -7,6 +7,7 @@ #include <string> +#include "base/gtest_prod_util.h" #include "net/base/completion_callback.h" #include "net/base/load_flags.h" #include "net/base/net_errors.h" @@ -16,8 +17,10 @@ namespace net { class CertVerifier; +class CTVerifier; class ServerBoundCertService; class SSLCertRequestInfo; +struct SSLConfig; class SSLInfo; class TransportSecurityState; @@ -27,20 +30,24 @@ struct SSLClientSocketContext { SSLClientSocketContext() : cert_verifier(NULL), server_bound_cert_service(NULL), - transport_security_state(NULL) {} + transport_security_state(NULL), + cert_transparency_verifier(NULL) {} SSLClientSocketContext(CertVerifier* cert_verifier_arg, ServerBoundCertService* server_bound_cert_service_arg, TransportSecurityState* transport_security_state_arg, + CTVerifier* cert_transparency_verifier_arg, const std::string& ssl_session_cache_shard_arg) : cert_verifier(cert_verifier_arg), server_bound_cert_service(server_bound_cert_service_arg), transport_security_state(transport_security_state_arg), + cert_transparency_verifier(cert_transparency_verifier_arg), ssl_session_cache_shard(ssl_session_cache_shard_arg) {} CertVerifier* cert_verifier; ServerBoundCertService* server_bound_cert_service; TransportSecurityState* transport_security_state; + CTVerifier* cert_transparency_verifier; // ssl_session_cache_shard is an opaque string that identifies a shard of the // SSL session cache. SSL sockets with the same ssl_session_cache_shard may // resume each other's SSL sessions but we'll never sessions between shards. @@ -121,11 +128,41 @@ class NET_EXPORT SSLClientSocket : public SSLSocket { // This may be useful for protocols, like SPDY, which allow the same // connection to be shared between multiple domains, each of which need // a channel ID. + // + // Public for ssl_client_socket_openssl_unittest.cc. virtual bool WasChannelIDSent() const; + protected: virtual void set_channel_id_sent(bool channel_id_sent); + virtual void set_signed_cert_timestamps_received( + bool signed_cert_timestamps_received); + + virtual void set_stapled_ocsp_response_received( + bool stapled_ocsp_response_received); + + // Records histograms for channel id support during full handshakes - resumed + // handshakes are ignored. + static void RecordChannelIDSupport( + ServerBoundCertService* server_bound_cert_service, + bool negotiated_channel_id, + bool channel_id_enabled, + bool supports_ecc); + + // Returns whether TLS channel ID is enabled. + static bool IsChannelIDEnabled( + const SSLConfig& ssl_config, + ServerBoundCertService* server_bound_cert_service); + private: + // For signed_cert_timestamps_received_ and stapled_ocsp_response_received_. + FRIEND_TEST_ALL_PREFIXES(SSLClientSocketTest, + ConnectSignedCertTimestampsEnabledTLSExtension); + FRIEND_TEST_ALL_PREFIXES(SSLClientSocketTest, + ConnectSignedCertTimestampsEnabledOCSP); + FRIEND_TEST_ALL_PREFIXES(SSLClientSocketTest, + ConnectSignedCertTimestampsDisabled); + // True if NPN was responded to, independent of selecting SPDY or HTTP. bool was_npn_negotiated_; // True if NPN successfully negotiated SPDY. @@ -134,6 +171,10 @@ class NET_EXPORT SSLClientSocket : public SSLSocket { NextProto protocol_negotiated_; // True if a channel ID was sent. bool channel_id_sent_; + // True if SCTs were received via a TLS extension. + bool signed_cert_timestamps_received_; + // True if a stapled OCSP response was received. + bool stapled_ocsp_response_received_; }; } // namespace net diff --git a/chromium/net/socket/ssl_client_socket_nss.cc b/chromium/net/socket/ssl_client_socket_nss.cc index 0de7cfb9060..0c73ec6c582 100644 --- a/chromium/net/socket/ssl_client_socket_nss.cc +++ b/chromium/net/socket/ssl_client_socket_nss.cc @@ -93,6 +93,11 @@ #include "net/cert/asn1_util.h" #include "net/cert/cert_status_flags.h" #include "net/cert/cert_verifier.h" +#include "net/cert/ct_objects_extractor.h" +#include "net/cert/ct_verifier.h" +#include "net/cert/ct_verify_result.h" +#include "net/cert/scoped_nss_types.h" +#include "net/cert/sct_status_flags.h" #include "net/cert/single_request_cert_verifier.h" #include "net/cert/x509_certificate_net_log_param.h" #include "net/cert/x509_util.h" @@ -221,15 +226,6 @@ bool IsOCSPStaplingSupported() { } #endif -class FreeCERTCertificate { - public: - inline void operator()(CERTCertificate* x) const { - CERT_DestroyCertificate(x); - } -}; -typedef scoped_ptr_malloc<CERTCertificate, FreeCERTCertificate> - ScopedCERTCertificate; - #if defined(OS_WIN) // This callback is intended to be used with CertFindChainInStore. In addition @@ -422,6 +418,8 @@ struct HandshakeState { channel_id_sent = false; server_cert_chain.Reset(NULL); server_cert = NULL; + sct_list_from_tls_extension.clear(); + stapled_ocsp_response.clear(); resumed_handshake = false; ssl_connection_status = 0; } @@ -451,6 +449,10 @@ struct HandshakeState { // always be non-NULL. PeerCertificateChain server_cert_chain; scoped_refptr<X509Certificate> server_cert; + // SignedCertificateTimestampList received via TLS extension (RFC 6962). + std::string sct_list_from_tls_extension; + // Stapled OCSP response received. + std::string stapled_ocsp_response; // True if the current handshake was the result of TLS session resumption. bool resumed_handshake; @@ -650,6 +652,14 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> { bool HasPendingAsyncOperation(); bool HasUnhandledReceivedData(); + // Called on the network task runner. + // Causes the associated SSL/TLS session ID to be added to NSS's session + // cache, but only if the connection has not been False Started. + // + // This should only be called after the server's certificate has been + // verified, and may not be called within an NSS callback. + void CacheSessionIfNecessary(); + private: friend class base::RefCountedThreadSafe<Core>; ~Core(); @@ -707,10 +717,19 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> { SECKEYPrivateKey** result_private_key); #endif + // Called by NSS to determine if we can False Start. + // |arg| contains a pointer to the current SSLClientSocketNSS::Core. + static SECStatus CanFalseStartCallback(PRFileDesc* socket, + void* arg, + PRBool* can_false_start); + // Called by NSS once the handshake has completed. // |arg| contains a pointer to the current SSLClientSocketNSS::Core. static void HandshakeCallback(PRFileDesc* socket, void* arg); + // Called once the handshake has succeeded. + void HandshakeSucceeded(); + // Handles an NSS error generated while handshaking or performing IO. // Returns a network error code mapped from the original NSS error. int HandleNSSError(PRErrorCode error, bool handshake_error); @@ -753,11 +772,18 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> { // Updates the NSS and platform specific certificates. void UpdateServerCert(); + // Update the nss_handshake_state_ with the SignedCertificateTimestampList + // received in the handshake via a TLS extension. + void UpdateSignedCertTimestamps(); + // Update the OCSP response cache with the stapled response received in the + // handshake, and update nss_handshake_state_ with + // the SignedCertificateTimestampList received in the stapled OCSP response. + void UpdateStapledOCSPResponse(); // Updates the nss_handshake_state_ with the negotiated security parameters. void UpdateConnectionStatus(); // Record histograms for channel id support during full handshakes - resumed // handshakes are ignored. - void RecordChannelIDSupport(); + void RecordChannelIDSupportOnNSSTaskRunner(); // UpdateNextProto gets any application-layer protocol that may have been // negotiated by the TLS connection. void UpdateNextProto(); @@ -862,6 +888,8 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> { bool channel_id_needed_; // True if the handshake state machine was interrupted for client auth. bool client_auth_cert_needed_; + // True if NSS has False Started. + bool false_started_; // True if NSS has called HandshakeCallback. bool handshake_callback_called_; @@ -930,6 +958,7 @@ SSLClientSocketNSS::Core::Core( channel_id_xtn_negotiated_(false), channel_id_needed_(false), client_auth_cert_needed_(false), + false_started_(false), handshake_callback_called_(false), transport_recv_busy_(false), transport_recv_eof_(false), @@ -1009,22 +1038,22 @@ bool SSLClientSocketNSS::Core::Init(PRFileDesc* socket, return false; } - if (ssl_config_.channel_id_enabled) { - if (!server_bound_cert_service_) { - DVLOG(1) << "NULL server_bound_cert_service_, not enabling channel ID."; - } else if (!crypto::ECPrivateKey::IsSupported()) { - DVLOG(1) << "Elliptic Curve not supported, not enabling channel ID."; - } else if (!server_bound_cert_service_->IsSystemTimeValid()) { - DVLOG(1) << "System time is weird, not enabling channel ID."; - } else { - rv = SSL_SetClientChannelIDCallback( - nss_fd_, SSLClientSocketNSS::Core::ClientChannelIDHandler, this); - if (rv != SECSuccess) - LogFailedNSSFunction(*weak_net_log_, "SSL_SetClientChannelIDCallback", - ""); + if (IsChannelIDEnabled(ssl_config_, server_bound_cert_service_)) { + rv = SSL_SetClientChannelIDCallback( + nss_fd_, SSLClientSocketNSS::Core::ClientChannelIDHandler, this); + if (rv != SECSuccess) { + LogFailedNSSFunction( + *weak_net_log_, "SSL_SetClientChannelIDCallback", ""); } } + rv = SSL_SetCanFalseStartCallback( + nss_fd_, SSLClientSocketNSS::Core::CanFalseStartCallback, this); + if (rv != SECSuccess) { + LogFailedNSSFunction(*weak_net_log_, "SSL_SetCanFalseStartCallback", ""); + return false; + } + rv = SSL_HandshakeCallback( nss_fd_, SSLClientSocketNSS::Core::HandshakeCallback, this); if (rv != SECSuccess) { @@ -1151,7 +1180,7 @@ int SSLClientSocketNSS::Core::Read(IOBuffer* buf, int buf_len, } DCHECK(OnNSSTaskRunner()); - DCHECK(handshake_callback_called_); + DCHECK(false_started_ || handshake_callback_called_); DCHECK_EQ(STATE_NONE, next_handshake_state_); DCHECK(user_read_callback_.is_null()); DCHECK(user_connect_callback_.is_null()); @@ -1205,7 +1234,7 @@ int SSLClientSocketNSS::Core::Write(IOBuffer* buf, int buf_len, } DCHECK(OnNSSTaskRunner()); - DCHECK(handshake_callback_called_); + DCHECK(false_started_ || handshake_callback_called_); DCHECK_EQ(STATE_NONE, next_handshake_state_); DCHECK(user_write_callback_.is_null()); DCHECK(user_connect_callback_.is_null()); @@ -1253,6 +1282,30 @@ bool SSLClientSocketNSS::Core::HasUnhandledReceivedData() { return unhandled_buffer_size_ != 0; } +void SSLClientSocketNSS::Core::CacheSessionIfNecessary() { + // TODO(rsleevi): This should occur on the NSS task runner, due to the use of + // nss_fd_. However, it happens on the network task runner in order to match + // the buggy behavior of ExportKeyingMaterial. + // + // Once http://crbug.com/330360 is fixed, this should be moved to an + // implementation that exclusively does this work on the NSS TaskRunner. This + // is "safe" because it is only called during the certificate verification + // state machine of the main socket, which is safe because no underlying + // transport IO will be occuring in that state, and NSS will not be blocking + // on any PKCS#11 related locks that might block the Network TaskRunner. + DCHECK(OnNetworkTaskRunner()); + + // Only cache the session if the connection was not False Started, because + // sessions should only be cached *after* the peer's Finished message is + // processed. + // In the case of False Start, the session will be cached once the + // HandshakeCallback is called, which signals the receipt and processing of + // the Finished message, and which will happen during a call to + // PR_Read/PR_Write. + if (!false_started_) + SSL_CacheSession(nss_fd_); +} + bool SSLClientSocketNSS::Core::OnNSSTaskRunner() const { return nss_task_runner_->RunsTasksOnCurrentThread(); } @@ -1268,26 +1321,7 @@ SECStatus SSLClientSocketNSS::Core::OwnAuthCertHandler( PRBool checksig, PRBool is_server) { Core* core = reinterpret_cast<Core*>(arg); - if (!core->handshake_callback_called_) { - // Only need to turn off False Start in the initial handshake. Also, it is - // unsafe to call SSL_OptionSet in a renegotiation because the "first - // handshake" lock isn't already held, which will result in an assertion - // failure in the ssl_Get1stHandshakeLock call in SSL_OptionSet. - PRBool negotiated_extension; - SECStatus rv = SSL_HandshakeNegotiatedExtension(socket, - ssl_app_layer_protocol_xtn, - &negotiated_extension); - if (rv != SECSuccess || !negotiated_extension) { - rv = SSL_HandshakeNegotiatedExtension(socket, - ssl_next_proto_nego_xtn, - &negotiated_extension); - } - if (rv != SECSuccess || !negotiated_extension) { - // If the server doesn't support NPN or ALPN, then we don't do False - // Start with it. - SSL_OptionSet(socket, SSL_ENABLE_FALSE_START, PR_FALSE); - } - } else { + if (core->handshake_callback_called_) { // Disallow the server certificate to change in a renegotiation. CERTCertificate* old_cert = core->nss_handshake_state_.server_cert_chain[0]; ScopedCERTCertificate new_cert(SSL_PeerCertificate(socket)); @@ -1615,6 +1649,30 @@ SECStatus SSLClientSocketNSS::Core::ClientAuthHandler( #endif // NSS_PLATFORM_CLIENT_AUTH // static +SECStatus SSLClientSocketNSS::Core::CanFalseStartCallback( + PRFileDesc* socket, + void* arg, + PRBool* can_false_start) { + // If the server doesn't support NPN or ALPN, then we don't do False + // Start with it. + PRBool negotiated_extension; + SECStatus rv = SSL_HandshakeNegotiatedExtension(socket, + ssl_app_layer_protocol_xtn, + &negotiated_extension); + if (rv != SECSuccess || !negotiated_extension) { + rv = SSL_HandshakeNegotiatedExtension(socket, + ssl_next_proto_nego_xtn, + &negotiated_extension); + } + if (rv != SECSuccess || !negotiated_extension) { + *can_false_start = PR_FALSE; + return SECSuccess; + } + + return SSL_RecommendedCanFalseStart(socket, can_false_start); +} + +// static void SSLClientSocketNSS::Core::HandshakeCallback( PRFileDesc* socket, void* arg) { @@ -1622,27 +1680,47 @@ void SSLClientSocketNSS::Core::HandshakeCallback( DCHECK(core->OnNSSTaskRunner()); core->handshake_callback_called_ = true; + if (core->false_started_) { + core->false_started_ = false; + // If the connection was False Started, then at the time of this callback, + // the peer's certificate will have been verified or the caller will have + // accepted the error. + // This is guaranteed when using False Start because this callback will + // not be invoked until processing the peer's Finished message, which + // will only happen in a PR_Read/PR_Write call, which can only happen + // after the peer's certificate is verified. + SSL_CacheSessionUnlocked(socket); + + // Additionally, when False Starting, DoHandshake() will have already + // called HandshakeSucceeded(), so return now. + return; + } + core->HandshakeSucceeded(); +} - HandshakeState* nss_state = &core->nss_handshake_state_; +void SSLClientSocketNSS::Core::HandshakeSucceeded() { + DCHECK(OnNSSTaskRunner()); PRBool last_handshake_resumed; - SECStatus rv = SSL_HandshakeResumedSession(socket, &last_handshake_resumed); + SECStatus rv = SSL_HandshakeResumedSession(nss_fd_, &last_handshake_resumed); if (rv == SECSuccess && last_handshake_resumed) { - nss_state->resumed_handshake = true; + nss_handshake_state_.resumed_handshake = true; } else { - nss_state->resumed_handshake = false; + nss_handshake_state_.resumed_handshake = false; } - core->RecordChannelIDSupport(); - core->UpdateServerCert(); - core->UpdateConnectionStatus(); - core->UpdateNextProto(); + RecordChannelIDSupportOnNSSTaskRunner(); + UpdateServerCert(); + UpdateSignedCertTimestamps(); + UpdateStapledOCSPResponse(); + UpdateConnectionStatus(); + UpdateNextProto(); // Update the network task runners view of the handshake state whenever // a handshake has completed. - core->PostOrRunCallback( - FROM_HERE, base::Bind(&Core::OnHandshakeStateUpdated, core, - *nss_state)); + PostOrRunCallback( + FROM_HERE, base::Bind(&Core::OnHandshakeStateUpdated, this, + nss_handshake_state_)); } int SSLClientSocketNSS::Core::HandleNSSError(PRErrorCode nss_error, @@ -1717,7 +1795,7 @@ int SSLClientSocketNSS::Core::DoHandshakeLoop(int last_io_result) { int SSLClientSocketNSS::Core::DoReadLoop(int result) { DCHECK(OnNSSTaskRunner()); - DCHECK(handshake_callback_called_); + DCHECK(false_started_ || handshake_callback_called_); DCHECK_EQ(STATE_NONE, next_handshake_state_); if (result < 0) @@ -1746,7 +1824,7 @@ int SSLClientSocketNSS::Core::DoReadLoop(int result) { int SSLClientSocketNSS::Core::DoWriteLoop(int result) { DCHECK(OnNSSTaskRunner()); - DCHECK(handshake_callback_called_); + DCHECK(false_started_ || handshake_callback_called_); DCHECK_EQ(STATE_NONE, next_handshake_state_); if (result < 0) @@ -1803,52 +1881,9 @@ int SSLClientSocketNSS::Core::DoHandshake() { LOG(WARNING) << "Couldn't invalidate SSL session: " << PR_GetError(); } else if (rv == SECSuccess) { if (!handshake_callback_called_) { - // Workaround for https://bugzilla.mozilla.org/show_bug.cgi?id=562434 - - // SSL_ForceHandshake returned SECSuccess prematurely. - rv = SECFailure; - net_error = ERR_SSL_PROTOCOL_ERROR; - PostOrRunCallback( - FROM_HERE, - base::Bind(&AddLogEventWithCallback, weak_net_log_, - NetLog::TYPE_SSL_HANDSHAKE_ERROR, - CreateNetLogSSLErrorCallback(net_error, 0))); - } else { - #if defined(SSL_ENABLE_OCSP_STAPLING) - // TODO(agl): figure out how to plumb an OCSP response into the Mac - // system library and update IsOCSPStaplingSupported for Mac. - if (IsOCSPStaplingSupported()) { - const SECItemArray* ocsp_responses = - SSL_PeerStapledOCSPResponses(nss_fd_); - if (ocsp_responses->len) { - #if defined(OS_WIN) - if (nss_handshake_state_.server_cert) { - CRYPT_DATA_BLOB ocsp_response_blob; - ocsp_response_blob.cbData = ocsp_responses->items[0].len; - ocsp_response_blob.pbData = ocsp_responses->items[0].data; - BOOL ok = CertSetCertificateContextProperty( - nss_handshake_state_.server_cert->os_cert_handle(), - CERT_OCSP_RESPONSE_PROP_ID, - CERT_SET_PROPERTY_IGNORE_PERSIST_ERROR_FLAG, - &ocsp_response_blob); - if (!ok) { - VLOG(1) << "Failed to set OCSP response property: " - << GetLastError(); - } - } - #elif defined(USE_NSS) - CacheOCSPResponseFromSideChannelFunction cache_ocsp_response = - GetCacheOCSPResponseFromSideChannelFunction(); - - cache_ocsp_response( - CERT_GetDefaultCertDB(), - nss_handshake_state_.server_cert_chain[0], PR_Now(), - &ocsp_responses->items[0], NULL); - #endif - } - } - #endif + false_started_ = true; + HandshakeSucceeded(); } - // Done! } else { PRErrorCode prerr = PR_GetError(); net_error = HandleNSSError(prerr, true); @@ -2364,8 +2399,10 @@ int SSLClientSocketNSS::Core::ImportChannelIDKeys(SECKEYPublicKey** public_key, if (cert == NULL) return MapNSSError(PORT_GetError()); + crypto::ScopedPK11Slot slot(PK11_GetInternalSlot()); // Set the private key. if (!crypto::ECPrivateKey::ImportFromEncryptedPrivateKeyInfo( + slot.get(), ServerBoundCertService::kEPKIPassword, reinterpret_cast<const unsigned char*>( domain_bound_private_key_.data()), @@ -2400,6 +2437,58 @@ void SSLClientSocketNSS::Core::UpdateServerCert() { } } +void SSLClientSocketNSS::Core::UpdateSignedCertTimestamps() { + const SECItem* signed_cert_timestamps = + SSL_PeerSignedCertTimestamps(nss_fd_); + + if (!signed_cert_timestamps || !signed_cert_timestamps->len) + return; + + nss_handshake_state_.sct_list_from_tls_extension = std::string( + reinterpret_cast<char*>(signed_cert_timestamps->data), + signed_cert_timestamps->len); +} + +void SSLClientSocketNSS::Core::UpdateStapledOCSPResponse() { + const SECItemArray* ocsp_responses = + SSL_PeerStapledOCSPResponses(nss_fd_); + if (!ocsp_responses || !ocsp_responses->len) + return; + + nss_handshake_state_.stapled_ocsp_response = std::string( + reinterpret_cast<char*>(ocsp_responses->items[0].data), + ocsp_responses->items[0].len); + + // TODO(agl): figure out how to plumb an OCSP response into the Mac + // system library and update IsOCSPStaplingSupported for Mac. + if (IsOCSPStaplingSupported()) { + #if defined(OS_WIN) + if (nss_handshake_state_.server_cert) { + CRYPT_DATA_BLOB ocsp_response_blob; + ocsp_response_blob.cbData = ocsp_responses->items[0].len; + ocsp_response_blob.pbData = ocsp_responses->items[0].data; + BOOL ok = CertSetCertificateContextProperty( + nss_handshake_state_.server_cert->os_cert_handle(), + CERT_OCSP_RESPONSE_PROP_ID, + CERT_SET_PROPERTY_IGNORE_PERSIST_ERROR_FLAG, + &ocsp_response_blob); + if (!ok) { + VLOG(1) << "Failed to set OCSP response property: " + << GetLastError(); + } + } + #elif defined(USE_NSS) + CacheOCSPResponseFromSideChannelFunction cache_ocsp_response = + GetCacheOCSPResponseFromSideChannelFunction(); + + cache_ocsp_response( + CERT_GetDefaultCertDB(), + nss_handshake_state_.server_cert_chain[0], PR_Now(), + &ocsp_responses->items[0], NULL); + #endif + } // IsOCSPStaplingSupported() +} + void SSLClientSocketNSS::Core::UpdateConnectionStatus() { SSLChannelInfo channel_info; SECStatus ok = SSL_GetChannelInfo(nss_fd_, @@ -2506,7 +2595,7 @@ void SSLClientSocketNSS::Core::UpdateNextProto() { } } -void SSLClientSocketNSS::Core::RecordChannelIDSupport() { +void SSLClientSocketNSS::Core::RecordChannelIDSupportOnNSSTaskRunner() { DCHECK(OnNSSTaskRunner()); if (nss_handshake_state_.resumed_handshake) return; @@ -2529,30 +2618,10 @@ void SSLClientSocketNSS::Core::RecordChannelIDSupportOnNetworkTaskRunner( bool supports_ecc) const { DCHECK(OnNetworkTaskRunner()); - // Since this enum is used for a histogram, do not change or re-use values. - enum { - DISABLED = 0, - CLIENT_ONLY = 1, - CLIENT_AND_SERVER = 2, - CLIENT_NO_ECC = 3, - CLIENT_BAD_SYSTEM_TIME = 4, - CLIENT_NO_SERVER_BOUND_CERT_SERVICE = 5, - DOMAIN_BOUND_CERT_USAGE_MAX - } supported = DISABLED; - if (negotiated_channel_id) { - supported = CLIENT_AND_SERVER; - } else if (channel_id_enabled) { - if (!server_bound_cert_service_) - supported = CLIENT_NO_SERVER_BOUND_CERT_SERVICE; - else if (!supports_ecc) - supported = CLIENT_NO_ECC; - else if (!server_bound_cert_service_->IsSystemTimeValid()) - supported = CLIENT_BAD_SYSTEM_TIME; - else - supported = CLIENT_ONLY; - } - UMA_HISTOGRAM_ENUMERATION("DomainBoundCerts.Support", supported, - DOMAIN_BOUND_CERT_USAGE_MAX); + RecordChannelIDSupport(server_bound_cert_service_, + negotiated_channel_id, + channel_id_enabled, + supports_ecc); } int SSLClientSocketNSS::Core::DoBufferRecv(IOBuffer* read_buffer, int len) { @@ -2769,6 +2838,7 @@ SSLClientSocketNSS::SSLClientSocketNSS( host_and_port_(host_and_port), ssl_config_(ssl_config), cert_verifier_(context.cert_verifier), + cert_transparency_verifier_(context.cert_transparency_verifier), server_bound_cert_service_(context.server_bound_cert_service), ssl_session_cache_shard_(context.ssl_session_cache_shard), completed_handshake_(false), @@ -2808,6 +2878,9 @@ bool SSLClientSocketNSS::GetSSLInfo(SSLInfo* ssl_info) { ssl_info->cert_status = server_cert_verify_result_.cert_status; ssl_info->cert = server_cert_verify_result_.verified_cert; + + AddSCTInfoToSSLInfo(ssl_info); + ssl_info->connection_status = core_->state().ssl_connection_status; ssl_info->public_key_hashes = server_cert_verify_result_.public_key_hashes; @@ -3098,9 +3171,10 @@ int SSLClientSocketNSS::InitializeSSLOptions() { /* Create SSL state machine */ /* Push SSL onto our fake I/O socket */ - nss_fd_ = SSL_ImportFD(NULL, nss_fd_); - if (nss_fd_ == NULL) { + if (SSL_ImportFD(GetNSSModelSocket(), nss_fd_) == NULL) { LogFailedNSSFunction(net_log_, "SSL_ImportFD", ""); + PR_Close(nss_fd_); + nss_fd_ = NULL; return ERR_OUT_OF_MEMORY; // TODO(port): map NSPR/NSS error code. } // TODO(port): set more ssl options! Check errors! @@ -3135,6 +3209,14 @@ int SSLClientSocketNSS::InitializeSSLOptions() { return ERR_NO_SSL_VERSIONS_ENABLED; } + if (ssl_config_.version_fallback) { + rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_FALLBACK_SCSV, PR_TRUE); + if (rv != SECSuccess) { + LogFailedNSSFunction( + net_log_, "SSL_OptionSet", "SSL_ENABLE_FALLBACK_SCSV"); + } + } + for (std::vector<uint16>::const_iterator it = ssl_config_.disabled_cipher_suites.begin(); it != ssl_config_.disabled_cipher_suites.end(); ++it) { @@ -3173,15 +3255,24 @@ int SSLClientSocketNSS::InitializeSSLOptions() { // Added in NSS 3.15 #ifdef SSL_ENABLE_OCSP_STAPLING - if (IsOCSPStaplingSupported()) { - rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_OCSP_STAPLING, PR_TRUE); - if (rv != SECSuccess) { - LogFailedNSSFunction(net_log_, "SSL_OptionSet", - "SSL_ENABLE_OCSP_STAPLING"); - } + // Request OCSP stapling even on platforms that don't support it, in + // order to extract Certificate Transparency information. + rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_OCSP_STAPLING, + (IsOCSPStaplingSupported() || + ssl_config_.signed_cert_timestamps_enabled)); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", + "SSL_ENABLE_OCSP_STAPLING"); } #endif + rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, + ssl_config_.signed_cert_timestamps_enabled); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", + "SSL_ENABLE_SIGNED_CERT_TIMESTAMPS"); + } + // Chromium patch to libssl #ifdef SSL_ENABLE_CACHED_INFO rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_CACHED_INFO, @@ -3327,12 +3418,15 @@ int SSLClientSocketNSS::DoHandshakeComplete(int result) { // Done! } set_channel_id_sent(core_->state().channel_id_sent); + set_signed_cert_timestamps_received( + !core_->state().sct_list_from_tls_extension.empty()); + set_stapled_ocsp_response_received( + !core_->state().stapled_ocsp_response.empty()); LeaveFunction(result); return result; } - int SSLClientSocketNSS::DoVerifyCert(int result) { DCHECK(!core_->state().server_cert_chain.empty()); DCHECK(core_->state().server_cert_chain[0]); @@ -3416,8 +3510,6 @@ int SSLClientSocketNSS::DoVerifyCertComplete(int result) { if (result == OK) LogConnectionTypeMetrics(); - completed_handshake_ = true; - #if defined(OFFICIAL_BUILD) && !defined(OS_ANDROID) && !defined(OS_IOS) // Take care of any mandates for public key pinning. // @@ -3465,11 +3557,45 @@ int SSLClientSocketNSS::DoVerifyCertComplete(int result) { } #endif + if (result == OK) { + // Only check Certificate Transparency if there were no other errors with + // the connection. + VerifyCT(); + + // Only cache the session if the certificate verified successfully. + core_->CacheSessionIfNecessary(); + } + + completed_handshake_ = true; + // Exit DoHandshakeLoop and return the result to the caller to Connect. DCHECK_EQ(STATE_NONE, next_handshake_state_); return result; } +void SSLClientSocketNSS::VerifyCT() { + if (!cert_transparency_verifier_) + return; + + // Note that this is a completely synchronous operation: The CT Log Verifier + // gets all the data it needs for SCT verification and does not do any + // external communication. + int result = cert_transparency_verifier_->Verify( + server_cert_verify_result_.verified_cert, + core_->state().stapled_ocsp_response, + core_->state().sct_list_from_tls_extension, + &ct_verify_result_, + net_log_); + // TODO(ekasper): wipe stapled_ocsp_response and sct_list_from_tls_extension + // from the state after verification is complete, to conserve memory. + + VLOG(1) << "CT Verification complete: result " << result + << " Invalid scts: " << ct_verify_result_.invalid_scts.size() + << " Verified scts: " << ct_verify_result_.verified_scts.size() + << " scts from unknown logs: " + << ct_verify_result_.unknown_logs_scts.size(); +} + void SSLClientSocketNSS::LogConnectionTypeMetrics() const { UpdateConnectionTypeHistograms(CONNECTION_SSL); int ssl_version = SSLConnectionStatusToVersion( @@ -3506,6 +3632,28 @@ bool SSLClientSocketNSS::CalledOnValidThread() const { return valid_thread_id_ == base::PlatformThread::CurrentId(); } +void SSLClientSocketNSS::AddSCTInfoToSSLInfo(SSLInfo* ssl_info) const { + for (ct::SCTList::const_iterator iter = + ct_verify_result_.verified_scts.begin(); + iter != ct_verify_result_.verified_scts.end(); ++iter) { + ssl_info->signed_certificate_timestamps.push_back( + SignedCertificateTimestampAndStatus(*iter, ct::SCT_STATUS_OK)); + } + for (ct::SCTList::const_iterator iter = + ct_verify_result_.invalid_scts.begin(); + iter != ct_verify_result_.invalid_scts.end(); ++iter) { + ssl_info->signed_certificate_timestamps.push_back( + SignedCertificateTimestampAndStatus(*iter, ct::SCT_STATUS_INVALID)); + } + for (ct::SCTList::const_iterator iter = + ct_verify_result_.unknown_logs_scts.begin(); + iter != ct_verify_result_.unknown_logs_scts.end(); ++iter) { + ssl_info->signed_certificate_timestamps.push_back( + SignedCertificateTimestampAndStatus(*iter, + ct::SCT_STATUS_LOG_UNKNOWN)); + } +} + ServerBoundCertService* SSLClientSocketNSS::GetServerBoundCertService() const { return server_bound_cert_service_; } diff --git a/chromium/net/socket/ssl_client_socket_nss.h b/chromium/net/socket/ssl_client_socket_nss.h index b41d28d74a8..cc1412fa80b 100644 --- a/chromium/net/socket/ssl_client_socket_nss.h +++ b/chromium/net/socket/ssl_client_socket_nss.h @@ -24,6 +24,7 @@ #include "net/base/net_log.h" #include "net/base/nss_memio.h" #include "net/cert/cert_verify_result.h" +#include "net/cert/ct_verify_result.h" #include "net/cert/x509_certificate.h" #include "net/socket/ssl_client_socket.h" #include "net/ssl/server_bound_cert_service.h" @@ -37,6 +38,7 @@ namespace net { class BoundNetLog; class CertVerifier; +class CTVerifier; class ClientSocketHandle; class ServerBoundCertService; class SingleRequestCertVerifier; @@ -135,6 +137,8 @@ class SSLClientSocketNSS : public SSLClientSocket { int DoVerifyCert(int result); int DoVerifyCertComplete(int result); + void VerifyCT(); + void LogConnectionTypeMetrics() const; // The following methods are for debugging bug 65948. Will remove this code @@ -142,6 +146,13 @@ class SSLClientSocketNSS : public SSLClientSocket { void EnsureThreadIdAssigned() const; bool CalledOnValidThread() const; + // Adds the SignedCertificateTimestamps from ct_verify_result_ to |ssl_info|. + // SCTs are held in three separate vectors in ct_verify_result, each + // vetor representing a particular verification state, this method associates + // each of the SCTs with the corresponding SCTVerifyStatus as it adds it to + // the |ssl_info|.signed_certificate_timestamps list. + void AddSCTInfoToSSLInfo(SSLInfo* ssl_info) const; + // The task runner used to perform NSS operations. scoped_refptr<base::SequencedTaskRunner> nss_task_runner_; scoped_ptr<ClientSocketHandle> transport_; @@ -158,6 +169,10 @@ class SSLClientSocketNSS : public SSLClientSocket { CertVerifier* const cert_verifier_; scoped_ptr<SingleRequestCertVerifier> verifier_; + // Certificate Transparency: Verifier and result holder. + ct::CTVerifyResult ct_verify_result_; + CTVerifier* cert_transparency_verifier_; + // The service for retrieving Channel ID keys. May be NULL. ServerBoundCertService* server_bound_cert_service_; diff --git a/chromium/net/socket/ssl_client_socket_openssl.cc b/chromium/net/socket/ssl_client_socket_openssl.cc index 416ab87bc4b..49bdc8eb68a 100644 --- a/chromium/net/socket/ssl_client_socket_openssl.cc +++ b/chromium/net/socket/ssl_client_socket_openssl.cc @@ -13,15 +13,18 @@ #include "base/bind.h" #include "base/callback_helpers.h" +#include "base/debug/alias.h" #include "base/memory/singleton.h" #include "base/metrics/histogram.h" #include "base/synchronization/lock.h" +#include "crypto/ec_private_key.h" #include "crypto/openssl_util.h" #include "net/base/net_errors.h" #include "net/cert/cert_verifier.h" #include "net/cert/single_request_cert_verifier.h" #include "net/cert/x509_certificate_net_log_param.h" #include "net/socket/ssl_error_params.h" +#include "net/socket/ssl_session_cache_openssl.h" #include "net/ssl/openssl_client_key_store.h" #include "net/ssl/ssl_cert_request_info.h" #include "net/ssl/ssl_connection_status_flags.h" @@ -40,9 +43,6 @@ namespace { #define GotoState(s) next_handshake_state_ = s #endif -const int kSessionCacheTimeoutSeconds = 60 * 60; -const size_t kSessionCacheMaxEntires = 1024; - // This constant can be any non-negative/non-zero value (eg: it does not // overlap with any value of the net::Error range, including net::OK). const int kNoPendingReadResult = 1; @@ -162,6 +162,9 @@ int MapOpenSSLErrorSSL() { case SSL_R_INVALID_TICKET_KEYS_LENGTH: case SSL_R_KEY_ARG_TOO_LONG: case SSL_R_READ_WRONG_PACKET_TYPE: + // SSL_do_handshake reports this error when the server responds to a + // ClientHello with a fatal close_notify alert. + case SSL_AD_REASON_OFFSET + SSL_AD_CLOSE_NOTIFY: case SSL_R_SSLV3_ALERT_UNEXPECTED_MESSAGE: // TODO(joth): SSL_R_SSLV3_ALERT_HANDSHAKE_FAILURE may be returned from the // server after receiving ClientHello if there's no common supported cipher. @@ -214,114 +217,37 @@ int NoOpVerifyCallback(X509_STORE_CTX*, void *) { return 1; } -// OpenSSL manages a cache of SSL_SESSION, this class provides the application -// side policy for that cache about session re-use: we retain one session per -// unique HostPortPair, per shard. -class SSLSessionCache { - public: - SSLSessionCache() {} - - void OnSessionAdded(const HostPortPair& host_and_port, - const std::string& shard, - SSL_SESSION* session) { - // Declare the session cleaner-upper before the lock, so any call into - // OpenSSL to free the session will happen after the lock is released. - crypto::ScopedOpenSSL<SSL_SESSION, SSL_SESSION_free> session_to_free; - base::AutoLock lock(lock_); - - DCHECK_EQ(0U, session_map_.count(session)); - const std::string cache_key = GetCacheKey(host_and_port, shard); - - std::pair<HostPortMap::iterator, bool> res = - host_port_map_.insert(std::make_pair(cache_key, session)); - if (!res.second) { // Already exists: replace old entry. - session_to_free.reset(res.first->second); - session_map_.erase(session_to_free.get()); - res.first->second = session; - } - DVLOG(2) << "Adding session " << session << " => " - << cache_key << ", new entry = " << res.second; - DCHECK(host_port_map_[cache_key] == session); - session_map_[session] = res.first; - DCHECK_EQ(host_port_map_.size(), session_map_.size()); - DCHECK_LE(host_port_map_.size(), kSessionCacheMaxEntires); - } - - void OnSessionRemoved(SSL_SESSION* session) { - // Declare the session cleaner-upper before the lock, so any call into - // OpenSSL to free the session will happen after the lock is released. - crypto::ScopedOpenSSL<SSL_SESSION, SSL_SESSION_free> session_to_free; - base::AutoLock lock(lock_); - - SessionMap::iterator it = session_map_.find(session); - if (it == session_map_.end()) - return; - DVLOG(2) << "Remove session " << session << " => " << it->second->first; - DCHECK(it->second->second == session); - host_port_map_.erase(it->second); - session_map_.erase(it); - session_to_free.reset(session); - DCHECK_EQ(host_port_map_.size(), session_map_.size()); - } - - // Looks up the host:port in the cache, and if a session is found it is added - // to |ssl|, returning true on success. - bool SetSSLSession(SSL* ssl, const HostPortPair& host_and_port, - const std::string& shard) { - base::AutoLock lock(lock_); - const std::string cache_key = GetCacheKey(host_and_port, shard); - HostPortMap::iterator it = host_port_map_.find(cache_key); - if (it == host_port_map_.end()) - return false; - DVLOG(2) << "Lookup session: " << it->second << " => " << cache_key; - SSL_SESSION* session = it->second; - DCHECK(session); - DCHECK(session_map_[session] == it); - // Ideally we'd release |lock_| before calling into OpenSSL here, however - // that opens a small risk |session| will go out of scope before it is used. - // Alternatively we would take a temporary local refcount on |session|, - // except OpenSSL does not provide a public API for adding a ref (c.f. - // SSL_SESSION_free which decrements the ref). - return SSL_set_session(ssl, session) == 1; - } - - // Flush removes all entries from the cache. This is called when a client - // certificate is added. - void Flush() { - for (HostPortMap::iterator i = host_port_map_.begin(); - i != host_port_map_.end(); i++) { - SSL_SESSION_free(i->second); - } - host_port_map_.clear(); - session_map_.clear(); - } - - private: - static std::string GetCacheKey(const HostPortPair& host_and_port, - const std::string& shard) { - return host_and_port.ToString() + "/" + shard; +// Utility to construct the appropriate set & clear masks for use the OpenSSL +// options and mode configuration functions. (SSL_set_options etc) +struct SslSetClearMask { + SslSetClearMask() : set_mask(0), clear_mask(0) {} + void ConfigureFlag(long flag, bool state) { + (state ? set_mask : clear_mask) |= flag; + // Make sure we haven't got any intersection in the set & clear options. + DCHECK_EQ(0, set_mask & clear_mask) << flag << ":" << state; } + long set_mask; + long clear_mask; +}; - // A pair of maps to allow bi-directional lookups between host:port and an - // associated session. - typedef std::map<std::string, SSL_SESSION*> HostPortMap; - typedef std::map<SSL_SESSION*, HostPortMap::iterator> SessionMap; - HostPortMap host_port_map_; - SessionMap session_map_; - - // Protects access to both the above maps. - base::Lock lock_; +// Compute a unique key string for the SSL session cache. |socket| is an +// input socket object. Return a string. +std::string GetSocketSessionCacheKey(const SSLClientSocketOpenSSL& socket) { + std::string result = socket.host_and_port().ToString(); + result.append("/"); + result.append(socket.ssl_session_cache_shard()); + return result; +} - DISALLOW_COPY_AND_ASSIGN(SSLSessionCache); -}; +} // namespace -class SSLContext { +class SSLClientSocketOpenSSL::SSLContext { public: static SSLContext* GetInstance() { return Singleton<SSLContext>::get(); } SSL_CTX* ssl_ctx() { return ssl_ctx_.get(); } - SSLSessionCache* session_cache() { return &session_cache_; } + SSLSessionCacheOpenSSL* session_cache() { return &session_cache_; } - SSLClientSocketOpenSSL* GetClientSocketFromSSL(SSL* ssl) { + SSLClientSocketOpenSSL* GetClientSocketFromSSL(const SSL* ssl) { DCHECK(ssl); SSLClientSocketOpenSSL* socket = static_cast<SSLClientSocketOpenSSL*>( SSL_get_ex_data(ssl, ssl_socket_data_index_)); @@ -341,13 +267,10 @@ class SSLContext { ssl_socket_data_index_ = SSL_get_ex_new_index(0, 0, 0, 0, 0); DCHECK_NE(ssl_socket_data_index_, -1); ssl_ctx_.reset(SSL_CTX_new(SSLv23_client_method())); + session_cache_.Reset(ssl_ctx_.get(), kDefaultSessionCacheConfig); SSL_CTX_set_cert_verify_callback(ssl_ctx_.get(), NoOpVerifyCallback, NULL); - SSL_CTX_set_session_cache_mode(ssl_ctx_.get(), SSL_SESS_CACHE_CLIENT); - SSL_CTX_sess_set_new_cb(ssl_ctx_.get(), NewSessionCallbackStatic); - SSL_CTX_sess_set_remove_cb(ssl_ctx_.get(), RemoveSessionCallbackStatic); - SSL_CTX_set_timeout(ssl_ctx_.get(), kSessionCacheTimeoutSeconds); - SSL_CTX_sess_set_cache_size(ssl_ctx_.get(), kSessionCacheMaxEntires); SSL_CTX_set_client_cert_cb(ssl_ctx_.get(), ClientCertCallback); + SSL_CTX_set_channel_id_cb(ssl_ctx_.get(), ChannelIDCallback); #if defined(OPENSSL_NPN_NEGOTIATED) // TODO(kristianm): Only select this if ssl_config_.next_proto is not empty. // It would be better if the callback were not a global setting, @@ -357,26 +280,13 @@ class SSLContext { #endif } - static int NewSessionCallbackStatic(SSL* ssl, SSL_SESSION* session) { - return GetInstance()->NewSessionCallback(ssl, session); - } - - int NewSessionCallback(SSL* ssl, SSL_SESSION* session) { - SSLClientSocketOpenSSL* socket = GetClientSocketFromSSL(ssl); - session_cache_.OnSessionAdded(socket->host_and_port(), - socket->ssl_session_cache_shard(), - session); - return 1; // 1 => We took ownership of |session|. - } - - static void RemoveSessionCallbackStatic(SSL_CTX* ctx, SSL_SESSION* session) { - return GetInstance()->RemoveSessionCallback(ctx, session); + static std::string GetSessionCacheKey(const SSL* ssl) { + SSLClientSocketOpenSSL* socket = GetInstance()->GetClientSocketFromSSL(ssl); + DCHECK(socket); + return GetSocketSessionCacheKey(*socket); } - void RemoveSessionCallback(SSL_CTX* ctx, SSL_SESSION* session) { - DCHECK(ctx == ssl_ctx()); - session_cache_.OnSessionRemoved(session); - } + static SSLSessionCacheOpenSSL::Config kDefaultSessionCacheConfig; static int ClientCertCallback(SSL* ssl, X509** x509, EVP_PKEY** pkey) { SSLClientSocketOpenSSL* socket = GetInstance()->GetClientSocketFromSSL(ssl); @@ -384,6 +294,12 @@ class SSLContext { return socket->ClientCertRequestCallback(ssl, x509, pkey); } + static void ChannelIDCallback(SSL* ssl, EVP_PKEY** pkey) { + SSLClientSocketOpenSSL* socket = GetInstance()->GetClientSocketFromSSL(ssl); + CHECK(socket); + socket->ChannelIDRequestCallback(ssl, pkey); + } + static int SelectNextProtoCallback(SSL* ssl, unsigned char** out, unsigned char* outlen, const unsigned char* in, @@ -396,31 +312,24 @@ class SSLContext { // SSLClientSocketOpenSSL object from an SSL instance. int ssl_socket_data_index_; - // session_cache_ must appear before |ssl_ctx_| because the destruction of - // |ssl_ctx_| may trigger callbacks into |session_cache_|. Therefore, - // |session_cache_| must be destructed after |ssl_ctx_|. - SSLSessionCache session_cache_; crypto::ScopedOpenSSL<SSL_CTX, SSL_CTX_free> ssl_ctx_; + // |session_cache_| must be destroyed before |ssl_ctx_|. + SSLSessionCacheOpenSSL session_cache_; }; -// Utility to construct the appropriate set & clear masks for use the OpenSSL -// options and mode configuration functions. (SSL_set_options etc) -struct SslSetClearMask { - SslSetClearMask() : set_mask(0), clear_mask(0) {} - void ConfigureFlag(long flag, bool state) { - (state ? set_mask : clear_mask) |= flag; - // Make sure we haven't got any intersection in the set & clear options. - DCHECK_EQ(0, set_mask & clear_mask) << flag << ":" << state; - } - long set_mask; - long clear_mask; +// static +SSLSessionCacheOpenSSL::Config + SSLClientSocketOpenSSL::SSLContext::kDefaultSessionCacheConfig = { + &GetSessionCacheKey, // key_func + 1024, // max_entries + 256, // expiration_check_count + 60 * 60, // timeout_seconds }; -} // namespace - // static void SSLClientSocket::ClearSessionCache() { - SSLContext* context = SSLContext::GetInstance(); + SSLClientSocketOpenSSL::SSLContext* context = + SSLClientSocketOpenSSL::SSLContext::GetInstance(); context->session_cache()->Flush(); } @@ -434,9 +343,11 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( transport_recv_eof_(false), weak_factory_(this), pending_read_error_(kNoPendingReadResult), + transport_write_error_(OK), completed_handshake_(false), client_auth_cert_needed_(false), cert_verifier_(context.cert_verifier), + server_bound_cert_service_(context.server_bound_cert_service), ssl_(NULL), transport_bio_(NULL), transport_(transport_socket.Pass()), @@ -446,6 +357,8 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( trying_cached_session_(false), next_handshake_state_(STATE_NONE), npn_status_(kNextProtoUnsupported), + channel_id_request_return_value_(ERR_UNEXPECTED), + channel_id_xtn_negotiated_(false), net_log_(transport_->socket()->NetLog()) { } @@ -453,6 +366,282 @@ SSLClientSocketOpenSSL::~SSLClientSocketOpenSSL() { Disconnect(); } +void SSLClientSocketOpenSSL::GetSSLCertRequestInfo( + SSLCertRequestInfo* cert_request_info) { + cert_request_info->host_and_port = host_and_port_.ToString(); + cert_request_info->cert_authorities = cert_authorities_; +} + +SSLClientSocket::NextProtoStatus SSLClientSocketOpenSSL::GetNextProto( + std::string* proto, std::string* server_protos) { + *proto = npn_proto_; + *server_protos = server_protos_; + return npn_status_; +} + +ServerBoundCertService* +SSLClientSocketOpenSSL::GetServerBoundCertService() const { + return server_bound_cert_service_; +} + +int SSLClientSocketOpenSSL::ExportKeyingMaterial( + const base::StringPiece& label, + bool has_context, const base::StringPiece& context, + unsigned char* out, unsigned int outlen) { + crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); + + int rv = SSL_export_keying_material( + ssl_, out, outlen, const_cast<char*>(label.data()), + label.size(), + reinterpret_cast<unsigned char*>(const_cast<char*>(context.data())), + context.length(), + context.length() > 0); + + if (rv != 1) { + int ssl_error = SSL_get_error(ssl_, rv); + LOG(ERROR) << "Failed to export keying material;" + << " returned " << rv + << ", SSL error code " << ssl_error; + return MapOpenSSLError(ssl_error, err_tracer); + } + return OK; +} + +int SSLClientSocketOpenSSL::GetTLSUniqueChannelBinding(std::string* out) { + return ERR_NOT_IMPLEMENTED; +} + +int SSLClientSocketOpenSSL::Connect(const CompletionCallback& callback) { + net_log_.BeginEvent(NetLog::TYPE_SSL_CONNECT); + + // Set up new ssl object. + if (!Init()) { + int result = ERR_UNEXPECTED; + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, result); + return result; + } + + // Set SSL to client mode. Handshake happens in the loop below. + SSL_set_connect_state(ssl_); + + GotoState(STATE_HANDSHAKE); + int rv = DoHandshakeLoop(net::OK); + if (rv == ERR_IO_PENDING) { + user_connect_callback_ = callback; + } else { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); + } + + return rv > OK ? OK : rv; +} + +void SSLClientSocketOpenSSL::Disconnect() { + if (ssl_) { + // Calling SSL_shutdown prevents the session from being marked as + // unresumable. + SSL_shutdown(ssl_); + SSL_free(ssl_); + ssl_ = NULL; + } + if (transport_bio_) { + BIO_free_all(transport_bio_); + transport_bio_ = NULL; + } + + // Shut down anything that may call us back. + verifier_.reset(); + transport_->socket()->Disconnect(); + + // Null all callbacks, delete all buffers. + transport_send_busy_ = false; + send_buffer_ = NULL; + transport_recv_busy_ = false; + transport_recv_eof_ = false; + recv_buffer_ = NULL; + + user_connect_callback_.Reset(); + user_read_callback_.Reset(); + user_write_callback_.Reset(); + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + + pending_read_error_ = kNoPendingReadResult; + transport_write_error_ = OK; + + server_cert_verify_result_.Reset(); + completed_handshake_ = false; + + cert_authorities_.clear(); + client_auth_cert_needed_ = false; +} + +bool SSLClientSocketOpenSSL::IsConnected() const { + // If the handshake has not yet completed. + if (!completed_handshake_) + return false; + // If an asynchronous operation is still pending. + if (user_read_buf_.get() || user_write_buf_.get()) + return true; + + return transport_->socket()->IsConnected(); +} + +bool SSLClientSocketOpenSSL::IsConnectedAndIdle() const { + // If the handshake has not yet completed. + if (!completed_handshake_) + return false; + // If an asynchronous operation is still pending. + if (user_read_buf_.get() || user_write_buf_.get()) + return false; + // If there is data waiting to be sent, or data read from the network that + // has not yet been consumed. + if (BIO_ctrl_pending(transport_bio_) > 0 || + BIO_ctrl_wpending(transport_bio_) > 0) { + return false; + } + + return transport_->socket()->IsConnectedAndIdle(); +} + +int SSLClientSocketOpenSSL::GetPeerAddress(IPEndPoint* addressList) const { + return transport_->socket()->GetPeerAddress(addressList); +} + +int SSLClientSocketOpenSSL::GetLocalAddress(IPEndPoint* addressList) const { + return transport_->socket()->GetLocalAddress(addressList); +} + +const BoundNetLog& SSLClientSocketOpenSSL::NetLog() const { + return net_log_; +} + +void SSLClientSocketOpenSSL::SetSubresourceSpeculation() { + if (transport_.get() && transport_->socket()) { + transport_->socket()->SetSubresourceSpeculation(); + } else { + NOTREACHED(); + } +} + +void SSLClientSocketOpenSSL::SetOmniboxSpeculation() { + if (transport_.get() && transport_->socket()) { + transport_->socket()->SetOmniboxSpeculation(); + } else { + NOTREACHED(); + } +} + +bool SSLClientSocketOpenSSL::WasEverUsed() const { + if (transport_.get() && transport_->socket()) + return transport_->socket()->WasEverUsed(); + + NOTREACHED(); + return false; +} + +bool SSLClientSocketOpenSSL::UsingTCPFastOpen() const { + if (transport_.get() && transport_->socket()) + return transport_->socket()->UsingTCPFastOpen(); + + NOTREACHED(); + return false; +} + +bool SSLClientSocketOpenSSL::GetSSLInfo(SSLInfo* ssl_info) { + ssl_info->Reset(); + if (!server_cert_.get()) + return false; + + ssl_info->cert = server_cert_verify_result_.verified_cert; + ssl_info->cert_status = server_cert_verify_result_.cert_status; + ssl_info->is_issued_by_known_root = + server_cert_verify_result_.is_issued_by_known_root; + ssl_info->public_key_hashes = + server_cert_verify_result_.public_key_hashes; + ssl_info->client_cert_sent = + ssl_config_.send_client_cert && ssl_config_.client_cert.get(); + ssl_info->channel_id_sent = WasChannelIDSent(); + + RecordChannelIDSupport(server_bound_cert_service_, + channel_id_xtn_negotiated_, + ssl_config_.channel_id_enabled, + crypto::ECPrivateKey::IsSupported()); + + const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl_); + CHECK(cipher); + ssl_info->security_bits = SSL_CIPHER_get_bits(cipher, NULL); + const COMP_METHOD* compression = SSL_get_current_compression(ssl_); + + ssl_info->connection_status = EncodeSSLConnectionStatus( + SSL_CIPHER_get_id(cipher), + compression ? compression->type : 0, + GetNetSSLVersion(ssl_)); + + bool peer_supports_renego_ext = !!SSL_get_secure_renegotiation_support(ssl_); + if (!peer_supports_renego_ext) + ssl_info->connection_status |= SSL_CONNECTION_NO_RENEGOTIATION_EXTENSION; + UMA_HISTOGRAM_ENUMERATION("Net.RenegotiationExtensionSupported", + implicit_cast<int>(peer_supports_renego_ext), 2); + + if (ssl_config_.version_fallback) + ssl_info->connection_status |= SSL_CONNECTION_VERSION_FALLBACK; + + ssl_info->handshake_type = SSL_session_reused(ssl_) ? + SSLInfo::HANDSHAKE_RESUME : SSLInfo::HANDSHAKE_FULL; + + DVLOG(3) << "Encoded connection status: cipher suite = " + << SSLConnectionStatusToCipherSuite(ssl_info->connection_status) + << " version = " + << SSLConnectionStatusToVersion(ssl_info->connection_status); + return true; +} + +int SSLClientSocketOpenSSL::Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + user_read_buf_ = buf; + user_read_buf_len_ = buf_len; + + int rv = DoReadLoop(OK); + + if (rv == ERR_IO_PENDING) { + user_read_callback_ = callback; + } else { + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + } + + return rv; +} + +int SSLClientSocketOpenSSL::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + user_write_buf_ = buf; + user_write_buf_len_ = buf_len; + + int rv = DoWriteLoop(OK); + + if (rv == ERR_IO_PENDING) { + user_write_callback_ = callback; + } else { + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + } + + return rv; +} + +bool SSLClientSocketOpenSSL::SetReceiveBufferSize(int32 size) { + return transport_->socket()->SetReceiveBufferSize(size); +} + +bool SSLClientSocketOpenSSL::SetSendBufferSize(int32 size) { + return transport_->socket()->SetSendBufferSize(size); +} + bool SSLClientSocketOpenSSL::Init() { DCHECK(!ssl_); DCHECK(!transport_bio_); @@ -467,9 +656,8 @@ bool SSLClientSocketOpenSSL::Init() { if (!SSL_set_tlsext_host_name(ssl_, host_and_port_.host().c_str())) return false; - trying_cached_session_ = - context->session_cache()->SetSSLSession(ssl_, host_and_port_, - ssl_session_cache_shard_); + trying_cached_session_ = context->session_cache()->SetSSLSessionWithKey( + ssl_, GetSocketSessionCacheKey(*this)); BIO* ssl_bio = NULL; // 0 => use default buffer sizes. @@ -566,149 +754,15 @@ bool SSLClientSocketOpenSSL::Init() { // handshake at which point the appropriate error is bubbled up to the client. LOG_IF(WARNING, rv != 1) << "SSL_set_cipher_list('" << command << "') " "returned " << rv; - return true; -} - -int SSLClientSocketOpenSSL::ClientCertRequestCallback(SSL* ssl, - X509** x509, - EVP_PKEY** pkey) { - DVLOG(3) << "OpenSSL ClientCertRequestCallback called"; - DCHECK(ssl == ssl_); - DCHECK(*x509 == NULL); - DCHECK(*pkey == NULL); - - if (!ssl_config_.send_client_cert) { - // First pass: we know that a client certificate is needed, but we do not - // have one at hand. - client_auth_cert_needed_ = true; - STACK_OF(X509_NAME) *authorities = SSL_get_client_CA_list(ssl); - for (int i = 0; i < sk_X509_NAME_num(authorities); i++) { - X509_NAME *ca_name = (X509_NAME *)sk_X509_NAME_value(authorities, i); - unsigned char* str = NULL; - int length = i2d_X509_NAME(ca_name, &str); - cert_authorities_.push_back(std::string( - reinterpret_cast<const char*>(str), - static_cast<size_t>(length))); - OPENSSL_free(str); - } - return -1; // Suspends handshake. + // TLS channel ids. + if (IsChannelIDEnabled(ssl_config_, server_bound_cert_service_)) { + SSL_enable_tls_channel_id(ssl_); } - // Second pass: a client certificate should have been selected. - if (ssl_config_.client_cert.get()) { - // A note about ownership: FetchClientCertPrivateKey() increments - // the reference count of the EVP_PKEY. Ownership of this reference - // is passed directly to OpenSSL, which will release the reference - // using EVP_PKEY_free() when the SSL object is destroyed. - OpenSSLClientKeyStore::ScopedEVP_PKEY privkey; - if (OpenSSLClientKeyStore::GetInstance()->FetchClientCertPrivateKey( - ssl_config_.client_cert.get(), &privkey)) { - // TODO(joth): (copied from NSS) We should wait for server certificate - // verification before sending our credentials. See http://crbug.com/13934 - *x509 = X509Certificate::DupOSCertHandle( - ssl_config_.client_cert->os_cert_handle()); - *pkey = privkey.release(); - return 1; - } - LOG(WARNING) << "Client cert found without private key"; - } - - // Send no client certificate. - return 0; -} - -// SSLClientSocket methods - -bool SSLClientSocketOpenSSL::GetSSLInfo(SSLInfo* ssl_info) { - ssl_info->Reset(); - if (!server_cert_.get()) - return false; - - ssl_info->cert = server_cert_verify_result_.verified_cert; - ssl_info->cert_status = server_cert_verify_result_.cert_status; - ssl_info->is_issued_by_known_root = - server_cert_verify_result_.is_issued_by_known_root; - ssl_info->public_key_hashes = - server_cert_verify_result_.public_key_hashes; - ssl_info->client_cert_sent = - ssl_config_.send_client_cert && ssl_config_.client_cert.get(); - ssl_info->channel_id_sent = WasChannelIDSent(); - - const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl_); - CHECK(cipher); - ssl_info->security_bits = SSL_CIPHER_get_bits(cipher, NULL); - const COMP_METHOD* compression = SSL_get_current_compression(ssl_); - - ssl_info->connection_status = EncodeSSLConnectionStatus( - SSL_CIPHER_get_id(cipher), - compression ? compression->type : 0, - GetNetSSLVersion(ssl_)); - - bool peer_supports_renego_ext = !!SSL_get_secure_renegotiation_support(ssl_); - if (!peer_supports_renego_ext) - ssl_info->connection_status |= SSL_CONNECTION_NO_RENEGOTIATION_EXTENSION; - UMA_HISTOGRAM_ENUMERATION("Net.RenegotiationExtensionSupported", - implicit_cast<int>(peer_supports_renego_ext), 2); - - if (ssl_config_.version_fallback) - ssl_info->connection_status |= SSL_CONNECTION_VERSION_FALLBACK; - - ssl_info->handshake_type = SSL_session_reused(ssl_) ? - SSLInfo::HANDSHAKE_RESUME : SSLInfo::HANDSHAKE_FULL; - - DVLOG(3) << "Encoded connection status: cipher suite = " - << SSLConnectionStatusToCipherSuite(ssl_info->connection_status) - << " version = " - << SSLConnectionStatusToVersion(ssl_info->connection_status); return true; } -void SSLClientSocketOpenSSL::GetSSLCertRequestInfo( - SSLCertRequestInfo* cert_request_info) { - cert_request_info->host_and_port = host_and_port_.ToString(); - cert_request_info->cert_authorities = cert_authorities_; -} - -int SSLClientSocketOpenSSL::ExportKeyingMaterial( - const base::StringPiece& label, - bool has_context, const base::StringPiece& context, - unsigned char* out, unsigned int outlen) { - crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); - - int rv = SSL_export_keying_material( - ssl_, out, outlen, const_cast<char*>(label.data()), - label.size(), - reinterpret_cast<unsigned char*>(const_cast<char*>(context.data())), - context.length(), - context.length() > 0); - - if (rv != 1) { - int ssl_error = SSL_get_error(ssl_, rv); - LOG(ERROR) << "Failed to export keying material;" - << " returned " << rv - << ", SSL error code " << ssl_error; - return MapOpenSSLError(ssl_error, err_tracer); - } - return OK; -} - -int SSLClientSocketOpenSSL::GetTLSUniqueChannelBinding(std::string* out) { - return ERR_NOT_IMPLEMENTED; -} - -SSLClientSocket::NextProtoStatus SSLClientSocketOpenSSL::GetNextProto( - std::string* proto, std::string* server_protos) { - *proto = npn_proto_; - *server_protos = server_protos_; - return npn_status_; -} - -ServerBoundCertService* -SSLClientSocketOpenSSL::GetServerBoundCertService() const { - return NULL; -} - void SSLClientSocketOpenSSL::DoReadCallback(int rv) { // Since Run may result in Read being called, clear |user_read_callback_| // up front. @@ -725,107 +779,19 @@ void SSLClientSocketOpenSSL::DoWriteCallback(int rv) { base::ResetAndReturn(&user_write_callback_).Run(rv); } -// StreamSocket implementation. -int SSLClientSocketOpenSSL::Connect(const CompletionCallback& callback) { - net_log_.BeginEvent(NetLog::TYPE_SSL_CONNECT); - - // Set up new ssl object. - if (!Init()) { - int result = ERR_UNEXPECTED; - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, result); - return result; - } - - // Set SSL to client mode. Handshake happens in the loop below. - SSL_set_connect_state(ssl_); - - GotoState(STATE_HANDSHAKE); - int rv = DoHandshakeLoop(net::OK); - if (rv == ERR_IO_PENDING) { - user_connect_callback_ = callback; - } else { - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); - } - - return rv > OK ? OK : rv; -} - -void SSLClientSocketOpenSSL::Disconnect() { - if (ssl_) { - // Calling SSL_shutdown prevents the session from being marked as - // unresumable. - SSL_shutdown(ssl_); - SSL_free(ssl_); - ssl_ = NULL; - } - if (transport_bio_) { - BIO_free_all(transport_bio_); - transport_bio_ = NULL; - } - - // Shut down anything that may call us back. - verifier_.reset(); - transport_->socket()->Disconnect(); - - // Null all callbacks, delete all buffers. - transport_send_busy_ = false; - send_buffer_ = NULL; - transport_recv_busy_ = false; - transport_recv_eof_ = false; - recv_buffer_ = NULL; - - user_connect_callback_.Reset(); - user_read_callback_.Reset(); - user_write_callback_.Reset(); - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - user_write_buf_ = NULL; - user_write_buf_len_ = 0; - - server_cert_verify_result_.Reset(); - completed_handshake_ = false; - - cert_authorities_.clear(); - client_auth_cert_needed_ = false; -} - -int SSLClientSocketOpenSSL::DoHandshakeLoop(int last_io_result) { - int rv = last_io_result; +bool SSLClientSocketOpenSSL::DoTransportIO() { + bool network_moved = false; + int rv; + // Read and write as much data as possible. The loop is necessary because + // Write() may return synchronously. do { - // Default to STATE_NONE for next state. - // (This is a quirk carried over from the windows - // implementation. It makes reading the logs a bit harder.) - // State handlers can and often do call GotoState just - // to stay in the current state. - State state = next_handshake_state_; - GotoState(STATE_NONE); - switch (state) { - case STATE_HANDSHAKE: - rv = DoHandshake(); - break; - case STATE_VERIFY_CERT: - DCHECK(rv == OK); - rv = DoVerifyCert(rv); - break; - case STATE_VERIFY_CERT_COMPLETE: - rv = DoVerifyCertComplete(rv); - break; - case STATE_NONE: - default: - rv = ERR_UNEXPECTED; - NOTREACHED() << "unexpected state" << state; - break; - } - - bool network_moved = DoTransportIO(); - if (network_moved && next_handshake_state_ == STATE_HANDSHAKE) { - // In general we exit the loop if rv is ERR_IO_PENDING. In this - // special case we keep looping even if rv is ERR_IO_PENDING because - // the transport IO may allow DoHandshake to make progress. - rv = OK; // This causes us to stay in the loop. - } - } while (rv != ERR_IO_PENDING && next_handshake_state_ != STATE_NONE); - return rv; + rv = BufferSend(); + if (rv != ERR_IO_PENDING && rv != 0) + network_moved = true; + } while (rv > 0); + if (!transport_recv_eof_ && BufferRecv() != ERR_IO_PENDING) + network_moved = true; + return network_moved; } int SSLClientSocketOpenSSL::DoHandshake() { @@ -863,7 +829,14 @@ int SSLClientSocketOpenSSL::DoHandshake() { GotoState(STATE_VERIFY_CERT); } else { int ssl_error = SSL_get_error(ssl_, rv); - net_error = MapOpenSSLError(ssl_error, err_tracer); + + if (ssl_error == SSL_ERROR_WANT_CHANNEL_ID_LOOKUP) { + // The server supports TLS channel id and the lookup is asynchronous. + // Retrieve the error from the call to |server_bound_cert_service_|. + net_error = channel_id_request_return_value_; + } else { + net_error = MapOpenSSLError(ssl_error, err_tracer); + } // If not done, stay in this state if (net_error == ERR_IO_PENDING) { @@ -880,58 +853,6 @@ int SSLClientSocketOpenSSL::DoHandshake() { return net_error; } -// SelectNextProtoCallback is called by OpenSSL during the handshake. If the -// server supports NPN, selects a protocol from the list that the server -// provides. According to third_party/openssl/openssl/ssl/ssl_lib.c, the -// callback can assume that |in| is syntactically valid. -int SSLClientSocketOpenSSL::SelectNextProtoCallback(unsigned char** out, - unsigned char* outlen, - const unsigned char* in, - unsigned int inlen) { -#if defined(OPENSSL_NPN_NEGOTIATED) - if (ssl_config_.next_protos.empty()) { - *out = reinterpret_cast<uint8*>( - const_cast<char*>(kDefaultSupportedNPNProtocol)); - *outlen = arraysize(kDefaultSupportedNPNProtocol) - 1; - npn_status_ = kNextProtoUnsupported; - return SSL_TLSEXT_ERR_OK; - } - - // Assume there's no overlap between our protocols and the server's list. - npn_status_ = kNextProtoNoOverlap; - - // For each protocol in server preference order, see if we support it. - for (unsigned int i = 0; i < inlen; i += in[i] + 1) { - for (std::vector<std::string>::const_iterator - j = ssl_config_.next_protos.begin(); - j != ssl_config_.next_protos.end(); ++j) { - if (in[i] == j->size() && - memcmp(&in[i + 1], j->data(), in[i]) == 0) { - // We found a match. - *out = const_cast<unsigned char*>(in) + i + 1; - *outlen = in[i]; - npn_status_ = kNextProtoNegotiated; - break; - } - } - if (npn_status_ == kNextProtoNegotiated) - break; - } - - // If we didn't find a protocol, we select the first one from our list. - if (npn_status_ == kNextProtoNoOverlap) { - *out = reinterpret_cast<uint8*>(const_cast<char*>( - ssl_config_.next_protos[0].data())); - *outlen = ssl_config_.next_protos[0].size(); - } - - npn_proto_.assign(reinterpret_cast<const char*>(*out), *outlen); - server_protos_.assign(reinterpret_cast<const char*>(in), inlen); - DVLOG(2) << "next protocol: '" << npn_proto_ << "' status: " << npn_status_; -#endif - return SSL_TLSEXT_ERR_OK; -} - int SSLClientSocketOpenSSL::DoVerifyCert(int result) { DCHECK(server_cert_.get()); GotoState(STATE_VERIFY_CERT_COMPLETE); @@ -972,6 +893,7 @@ int SSLClientSocketOpenSSL::DoVerifyCertComplete(int result) { if (result == OK) { // TODO(joth): Work out if we need to remember the intermediate CA certs // when the server sends them to us, and do so here. + SSLContext::GetInstance()->session_cache()->MarkSSLSessionAsGood(ssl_); } else { DVLOG(1) << "DoVerifyCertComplete error " << ErrorToString(result) << " (" << result << ")"; @@ -983,6 +905,14 @@ int SSLClientSocketOpenSSL::DoVerifyCertComplete(int result) { return result; } +void SSLClientSocketOpenSSL::DoConnectCallback(int rv) { + if (!user_connect_callback_.is_null()) { + CompletionCallback c = user_connect_callback_; + user_connect_callback_.Reset(); + c.Run(rv > OK ? OK : rv); + } +} + X509Certificate* SSLClientSocketOpenSSL::UpdateServerCert() { if (server_cert_.get()) return server_cert_.get(); @@ -1007,146 +937,6 @@ X509Certificate* SSLClientSocketOpenSSL::UpdateServerCert() { return server_cert_.get(); } -bool SSLClientSocketOpenSSL::DoTransportIO() { - bool network_moved = false; - int rv; - // Read and write as much data as possible. The loop is necessary because - // Write() may return synchronously. - do { - rv = BufferSend(); - if (rv != ERR_IO_PENDING && rv != 0) - network_moved = true; - } while (rv > 0); - if (!transport_recv_eof_ && BufferRecv() != ERR_IO_PENDING) - network_moved = true; - return network_moved; -} - -int SSLClientSocketOpenSSL::BufferSend(void) { - if (transport_send_busy_) - return ERR_IO_PENDING; - - if (!send_buffer_.get()) { - // Get a fresh send buffer out of the send BIO. - size_t max_read = BIO_ctrl_pending(transport_bio_); - if (!max_read) - return 0; // Nothing pending in the OpenSSL write BIO. - send_buffer_ = new DrainableIOBuffer(new IOBuffer(max_read), max_read); - int read_bytes = BIO_read(transport_bio_, send_buffer_->data(), max_read); - DCHECK_GT(read_bytes, 0); - CHECK_EQ(static_cast<int>(max_read), read_bytes); - } - - int rv = transport_->socket()->Write( - send_buffer_.get(), - send_buffer_->BytesRemaining(), - base::Bind(&SSLClientSocketOpenSSL::BufferSendComplete, - base::Unretained(this))); - if (rv == ERR_IO_PENDING) { - transport_send_busy_ = true; - } else { - TransportWriteComplete(rv); - } - return rv; -} - -void SSLClientSocketOpenSSL::BufferSendComplete(int result) { - transport_send_busy_ = false; - TransportWriteComplete(result); - OnSendComplete(result); -} - -void SSLClientSocketOpenSSL::TransportWriteComplete(int result) { - DCHECK(ERR_IO_PENDING != result); - if (result < 0) { - // Got a socket write error; close the BIO to indicate this upward. - DVLOG(1) << "TransportWriteComplete error " << result; - (void)BIO_shutdown_wr(transport_bio_); - BIO_set_mem_eof_return(transport_bio_, 0); - send_buffer_ = NULL; - } else { - DCHECK(send_buffer_.get()); - send_buffer_->DidConsume(result); - DCHECK_GE(send_buffer_->BytesRemaining(), 0); - if (send_buffer_->BytesRemaining() <= 0) - send_buffer_ = NULL; - } -} - -int SSLClientSocketOpenSSL::BufferRecv(void) { - if (transport_recv_busy_) - return ERR_IO_PENDING; - - // Determine how much was requested from |transport_bio_| that was not - // actually available. - size_t requested = BIO_ctrl_get_read_request(transport_bio_); - if (requested == 0) { - // This is not a perfect match of error codes, as no operation is - // actually pending. However, returning 0 would be interpreted as - // a possible sign of EOF, which is also an inappropriate match. - return ERR_IO_PENDING; - } - - // Known Issue: While only reading |requested| data is the more correct - // implementation, it has the downside of resulting in frequent reads: - // One read for the SSL record header (~5 bytes) and one read for the SSL - // record body. Rather than issuing these reads to the underlying socket - // (and constantly allocating new IOBuffers), a single Read() request to - // fill |transport_bio_| is issued. As long as an SSL client socket cannot - // be gracefully shutdown (via SSL close alerts) and re-used for non-SSL - // traffic, this over-subscribed Read()ing will not cause issues. - size_t max_write = BIO_ctrl_get_write_guarantee(transport_bio_); - if (!max_write) - return ERR_IO_PENDING; - - recv_buffer_ = new IOBuffer(max_write); - int rv = transport_->socket()->Read( - recv_buffer_.get(), - max_write, - base::Bind(&SSLClientSocketOpenSSL::BufferRecvComplete, - base::Unretained(this))); - if (rv == ERR_IO_PENDING) { - transport_recv_busy_ = true; - } else { - TransportReadComplete(rv); - } - return rv; -} - -void SSLClientSocketOpenSSL::BufferRecvComplete(int result) { - TransportReadComplete(result); - OnRecvComplete(result); -} - -void SSLClientSocketOpenSSL::TransportReadComplete(int result) { - DCHECK(ERR_IO_PENDING != result); - if (result <= 0) { - DVLOG(1) << "TransportReadComplete result " << result; - // Received 0 (end of file) or an error. Either way, bubble it up to the - // SSL layer via the BIO. TODO(joth): consider stashing the error code, to - // relay up to the SSL socket client (i.e. via DoReadCallback). - if (result == 0) - transport_recv_eof_ = true; - BIO_set_mem_eof_return(transport_bio_, 0); - (void)BIO_shutdown_wr(transport_bio_); - } else { - DCHECK(recv_buffer_.get()); - int ret = BIO_write(transport_bio_, recv_buffer_->data(), result); - // A write into a memory BIO should always succeed. - CHECK_EQ(result, ret); - } - recv_buffer_ = NULL; - transport_recv_busy_ = false; -} - -void SSLClientSocketOpenSSL::DoConnectCallback(int rv) { - if (!user_connect_callback_.is_null()) { - CompletionCallback c = user_connect_callback_; - user_connect_callback_.Reset(); - c.Run(rv > OK ? OK : rv); - } -} - void SSLClientSocketOpenSSL::OnHandshakeIOComplete(int result) { int rv = DoHandshakeLoop(result); if (rv != ERR_IO_PENDING) { @@ -1207,95 +997,42 @@ void SSLClientSocketOpenSSL::OnRecvComplete(int result) { DoReadCallback(rv); } -bool SSLClientSocketOpenSSL::IsConnected() const { - // If the handshake has not yet completed. - if (!completed_handshake_) - return false; - // If an asynchronous operation is still pending. - if (user_read_buf_.get() || user_write_buf_.get()) - return true; - - return transport_->socket()->IsConnected(); -} - -bool SSLClientSocketOpenSSL::IsConnectedAndIdle() const { - // If the handshake has not yet completed. - if (!completed_handshake_) - return false; - // If an asynchronous operation is still pending. - if (user_read_buf_.get() || user_write_buf_.get()) - return false; - // If there is data waiting to be sent, or data read from the network that - // has not yet been consumed. - if (BIO_ctrl_pending(transport_bio_) > 0 || - BIO_ctrl_wpending(transport_bio_) > 0) { - return false; - } - - return transport_->socket()->IsConnectedAndIdle(); -} - -int SSLClientSocketOpenSSL::GetPeerAddress(IPEndPoint* addressList) const { - return transport_->socket()->GetPeerAddress(addressList); -} - -int SSLClientSocketOpenSSL::GetLocalAddress(IPEndPoint* addressList) const { - return transport_->socket()->GetLocalAddress(addressList); -} - -const BoundNetLog& SSLClientSocketOpenSSL::NetLog() const { - return net_log_; -} - -void SSLClientSocketOpenSSL::SetSubresourceSpeculation() { - if (transport_.get() && transport_->socket()) { - transport_->socket()->SetSubresourceSpeculation(); - } else { - NOTREACHED(); - } -} - -void SSLClientSocketOpenSSL::SetOmniboxSpeculation() { - if (transport_.get() && transport_->socket()) { - transport_->socket()->SetOmniboxSpeculation(); - } else { - NOTREACHED(); - } -} - -bool SSLClientSocketOpenSSL::WasEverUsed() const { - if (transport_.get() && transport_->socket()) - return transport_->socket()->WasEverUsed(); - - NOTREACHED(); - return false; -} - -bool SSLClientSocketOpenSSL::UsingTCPFastOpen() const { - if (transport_.get() && transport_->socket()) - return transport_->socket()->UsingTCPFastOpen(); - - NOTREACHED(); - return false; -} - -// Socket methods - -int SSLClientSocketOpenSSL::Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) { - user_read_buf_ = buf; - user_read_buf_len_ = buf_len; - - int rv = DoReadLoop(OK); - - if (rv == ERR_IO_PENDING) { - user_read_callback_ = callback; - } else { - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - } +int SSLClientSocketOpenSSL::DoHandshakeLoop(int last_io_result) { + int rv = last_io_result; + do { + // Default to STATE_NONE for next state. + // (This is a quirk carried over from the windows + // implementation. It makes reading the logs a bit harder.) + // State handlers can and often do call GotoState just + // to stay in the current state. + State state = next_handshake_state_; + GotoState(STATE_NONE); + switch (state) { + case STATE_HANDSHAKE: + rv = DoHandshake(); + break; + case STATE_VERIFY_CERT: + DCHECK(rv == OK); + rv = DoVerifyCert(rv); + break; + case STATE_VERIFY_CERT_COMPLETE: + rv = DoVerifyCertComplete(rv); + break; + case STATE_NONE: + default: + rv = ERR_UNEXPECTED; + NOTREACHED() << "unexpected state" << state; + break; + } + bool network_moved = DoTransportIO(); + if (network_moved && next_handshake_state_ == STATE_HANDSHAKE) { + // In general we exit the loop if rv is ERR_IO_PENDING. In this + // special case we keep looping even if rv is ERR_IO_PENDING because + // the transport IO may allow DoHandshake to make progress. + rv = OK; // This causes us to stay in the loop. + } + } while (rv != ERR_IO_PENDING && next_handshake_state_ != STATE_NONE); return rv; } @@ -1313,24 +1050,6 @@ int SSLClientSocketOpenSSL::DoReadLoop(int result) { return rv; } -int SSLClientSocketOpenSSL::Write(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) { - user_write_buf_ = buf; - user_write_buf_len_ = buf_len; - - int rv = DoWriteLoop(OK); - - if (rv == ERR_IO_PENDING) { - user_write_callback_ = callback; - } else { - user_write_buf_ = NULL; - user_write_buf_len_ = 0; - } - - return rv; -} - int SSLClientSocketOpenSSL::DoWriteLoop(int result) { if (result < 0) return result; @@ -1345,14 +1064,6 @@ int SSLClientSocketOpenSSL::DoWriteLoop(int result) { return rv; } -bool SSLClientSocketOpenSSL::SetReceiveBufferSize(int32 size) { - return transport_->socket()->SetReceiveBufferSize(size); -} - -bool SSLClientSocketOpenSSL::SetSendBufferSize(int32 size) { - return transport_->socket()->SetSendBufferSize(size); -} - int SSLClientSocketOpenSSL::DoPayloadRead() { crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); @@ -1434,4 +1145,278 @@ int SSLClientSocketOpenSSL::DoPayloadWrite() { return MapOpenSSLError(err, err_tracer); } +int SSLClientSocketOpenSSL::BufferSend(void) { + if (transport_send_busy_) + return ERR_IO_PENDING; + + if (!send_buffer_.get()) { + // Get a fresh send buffer out of the send BIO. + size_t max_read = BIO_ctrl_pending(transport_bio_); + if (!max_read) + return 0; // Nothing pending in the OpenSSL write BIO. + send_buffer_ = new DrainableIOBuffer(new IOBuffer(max_read), max_read); + int read_bytes = BIO_read(transport_bio_, send_buffer_->data(), max_read); + DCHECK_GT(read_bytes, 0); + CHECK_EQ(static_cast<int>(max_read), read_bytes); + } + + int rv = transport_->socket()->Write( + send_buffer_.get(), + send_buffer_->BytesRemaining(), + base::Bind(&SSLClientSocketOpenSSL::BufferSendComplete, + base::Unretained(this))); + if (rv == ERR_IO_PENDING) { + transport_send_busy_ = true; + } else { + TransportWriteComplete(rv); + } + return rv; +} + +int SSLClientSocketOpenSSL::BufferRecv(void) { + if (transport_recv_busy_) + return ERR_IO_PENDING; + + // Determine how much was requested from |transport_bio_| that was not + // actually available. + size_t requested = BIO_ctrl_get_read_request(transport_bio_); + if (requested == 0) { + // This is not a perfect match of error codes, as no operation is + // actually pending. However, returning 0 would be interpreted as + // a possible sign of EOF, which is also an inappropriate match. + return ERR_IO_PENDING; + } + + // Known Issue: While only reading |requested| data is the more correct + // implementation, it has the downside of resulting in frequent reads: + // One read for the SSL record header (~5 bytes) and one read for the SSL + // record body. Rather than issuing these reads to the underlying socket + // (and constantly allocating new IOBuffers), a single Read() request to + // fill |transport_bio_| is issued. As long as an SSL client socket cannot + // be gracefully shutdown (via SSL close alerts) and re-used for non-SSL + // traffic, this over-subscribed Read()ing will not cause issues. + size_t max_write = BIO_ctrl_get_write_guarantee(transport_bio_); + if (!max_write) + return ERR_IO_PENDING; + + recv_buffer_ = new IOBuffer(max_write); + int rv = transport_->socket()->Read( + recv_buffer_.get(), + max_write, + base::Bind(&SSLClientSocketOpenSSL::BufferRecvComplete, + base::Unretained(this))); + if (rv == ERR_IO_PENDING) { + transport_recv_busy_ = true; + } else { + rv = TransportReadComplete(rv); + } + return rv; +} + +void SSLClientSocketOpenSSL::BufferSendComplete(int result) { + transport_send_busy_ = false; + TransportWriteComplete(result); + OnSendComplete(result); +} + +void SSLClientSocketOpenSSL::BufferRecvComplete(int result) { + result = TransportReadComplete(result); + OnRecvComplete(result); +} + +void SSLClientSocketOpenSSL::TransportWriteComplete(int result) { + DCHECK(ERR_IO_PENDING != result); + if (result < 0) { + // Got a socket write error; close the BIO to indicate this upward. + // + // TODO(davidben): The value of |result| gets lost. Feed the error back into + // the BIO so it gets (re-)detected in OnSendComplete. Perhaps with + // BIO_set_callback. + DVLOG(1) << "TransportWriteComplete error " << result; + (void)BIO_shutdown_wr(SSL_get_wbio(ssl_)); + + // Match the fix for http://crbug.com/249848 in NSS by erroring future reads + // from the socket after a write error. + // + // TODO(davidben): Avoid having read and write ends interact this way. + transport_write_error_ = result; + (void)BIO_shutdown_wr(transport_bio_); + send_buffer_ = NULL; + } else { + DCHECK(send_buffer_.get()); + send_buffer_->DidConsume(result); + DCHECK_GE(send_buffer_->BytesRemaining(), 0); + if (send_buffer_->BytesRemaining() <= 0) + send_buffer_ = NULL; + } +} + +int SSLClientSocketOpenSSL::TransportReadComplete(int result) { + DCHECK(ERR_IO_PENDING != result); + if (result <= 0) { + DVLOG(1) << "TransportReadComplete result " << result; + // Received 0 (end of file) or an error. Either way, bubble it up to the + // SSL layer via the BIO. TODO(joth): consider stashing the error code, to + // relay up to the SSL socket client (i.e. via DoReadCallback). + if (result == 0) + transport_recv_eof_ = true; + (void)BIO_shutdown_wr(transport_bio_); + } else if (transport_write_error_ < 0) { + // Mirror transport write errors as read failures; transport_bio_ has been + // shut down by TransportWriteComplete, so the BIO_write will fail, failing + // the CHECK. http://crbug.com/335557. + result = transport_write_error_; + } else { + DCHECK(recv_buffer_.get()); + int ret = BIO_write(transport_bio_, recv_buffer_->data(), result); + // A write into a memory BIO should always succeed. + // Force values on the stack for http://crbug.com/335557 + base::debug::Alias(&result); + base::debug::Alias(&ret); + CHECK_EQ(result, ret); + } + recv_buffer_ = NULL; + transport_recv_busy_ = false; + return result; +} + +int SSLClientSocketOpenSSL::ClientCertRequestCallback(SSL* ssl, + X509** x509, + EVP_PKEY** pkey) { + DVLOG(3) << "OpenSSL ClientCertRequestCallback called"; + DCHECK(ssl == ssl_); + DCHECK(*x509 == NULL); + DCHECK(*pkey == NULL); + + if (!ssl_config_.send_client_cert) { + // First pass: we know that a client certificate is needed, but we do not + // have one at hand. + client_auth_cert_needed_ = true; + STACK_OF(X509_NAME) *authorities = SSL_get_client_CA_list(ssl); + for (int i = 0; i < sk_X509_NAME_num(authorities); i++) { + X509_NAME *ca_name = (X509_NAME *)sk_X509_NAME_value(authorities, i); + unsigned char* str = NULL; + int length = i2d_X509_NAME(ca_name, &str); + cert_authorities_.push_back(std::string( + reinterpret_cast<const char*>(str), + static_cast<size_t>(length))); + OPENSSL_free(str); + } + + return -1; // Suspends handshake. + } + + // Second pass: a client certificate should have been selected. + if (ssl_config_.client_cert.get()) { + // A note about ownership: FetchClientCertPrivateKey() increments + // the reference count of the EVP_PKEY. Ownership of this reference + // is passed directly to OpenSSL, which will release the reference + // using EVP_PKEY_free() when the SSL object is destroyed. + OpenSSLClientKeyStore::ScopedEVP_PKEY privkey; + if (OpenSSLClientKeyStore::GetInstance()->FetchClientCertPrivateKey( + ssl_config_.client_cert.get(), &privkey)) { + // TODO(joth): (copied from NSS) We should wait for server certificate + // verification before sending our credentials. See http://crbug.com/13934 + *x509 = X509Certificate::DupOSCertHandle( + ssl_config_.client_cert->os_cert_handle()); + *pkey = privkey.release(); + return 1; + } + LOG(WARNING) << "Client cert found without private key"; + } + + // Send no client certificate. + return 0; +} + +void SSLClientSocketOpenSSL::ChannelIDRequestCallback(SSL* ssl, + EVP_PKEY** pkey) { + DVLOG(3) << "OpenSSL ChannelIDRequestCallback called"; + DCHECK_EQ(ssl, ssl_); + DCHECK(!*pkey); + + channel_id_xtn_negotiated_ = true; + if (!channel_id_private_key_.size()) { + channel_id_request_return_value_ = + server_bound_cert_service_->GetOrCreateDomainBoundCert( + host_and_port_.host(), + &channel_id_private_key_, + &channel_id_cert_, + base::Bind(&SSLClientSocketOpenSSL::OnHandshakeIOComplete, + base::Unretained(this)), + &channel_id_request_handle_); + if (channel_id_request_return_value_ != OK) + return; + } + + // Decode key. + std::vector<uint8> encrypted_private_key_info; + std::vector<uint8> subject_public_key_info; + encrypted_private_key_info.assign( + channel_id_private_key_.data(), + channel_id_private_key_.data() + channel_id_private_key_.size()); + subject_public_key_info.assign( + channel_id_cert_.data(), + channel_id_cert_.data() + channel_id_cert_.size()); + scoped_ptr<crypto::ECPrivateKey> ec_private_key( + crypto::ECPrivateKey::CreateFromEncryptedPrivateKeyInfo( + ServerBoundCertService::kEPKIPassword, + encrypted_private_key_info, + subject_public_key_info)); + set_channel_id_sent(true); + *pkey = EVP_PKEY_dup(ec_private_key->key()); +} + +// SelectNextProtoCallback is called by OpenSSL during the handshake. If the +// server supports NPN, selects a protocol from the list that the server +// provides. According to third_party/openssl/openssl/ssl/ssl_lib.c, the +// callback can assume that |in| is syntactically valid. +int SSLClientSocketOpenSSL::SelectNextProtoCallback(unsigned char** out, + unsigned char* outlen, + const unsigned char* in, + unsigned int inlen) { +#if defined(OPENSSL_NPN_NEGOTIATED) + if (ssl_config_.next_protos.empty()) { + *out = reinterpret_cast<uint8*>( + const_cast<char*>(kDefaultSupportedNPNProtocol)); + *outlen = arraysize(kDefaultSupportedNPNProtocol) - 1; + npn_status_ = kNextProtoUnsupported; + return SSL_TLSEXT_ERR_OK; + } + + // Assume there's no overlap between our protocols and the server's list. + npn_status_ = kNextProtoNoOverlap; + + // For each protocol in server preference order, see if we support it. + for (unsigned int i = 0; i < inlen; i += in[i] + 1) { + for (std::vector<std::string>::const_iterator + j = ssl_config_.next_protos.begin(); + j != ssl_config_.next_protos.end(); ++j) { + if (in[i] == j->size() && + memcmp(&in[i + 1], j->data(), in[i]) == 0) { + // We found a match. + *out = const_cast<unsigned char*>(in) + i + 1; + *outlen = in[i]; + npn_status_ = kNextProtoNegotiated; + break; + } + } + if (npn_status_ == kNextProtoNegotiated) + break; + } + + // If we didn't find a protocol, we select the first one from our list. + if (npn_status_ == kNextProtoNoOverlap) { + *out = reinterpret_cast<uint8*>(const_cast<char*>( + ssl_config_.next_protos[0].data())); + *outlen = ssl_config_.next_protos[0].size(); + } + + npn_proto_.assign(reinterpret_cast<const char*>(*out), *outlen); + server_protos_.assign(reinterpret_cast<const char*>(in), inlen); + DVLOG(2) << "next protocol: '" << npn_proto_ << "' status: " << npn_status_; +#endif + return SSL_TLSEXT_ERR_OK; +} + } // namespace net diff --git a/chromium/net/socket/ssl_client_socket_openssl.h b/chromium/net/socket/ssl_client_socket_openssl.h index f66d95cc69d..5f4800a08de 100644 --- a/chromium/net/socket/ssl_client_socket_openssl.h +++ b/chromium/net/socket/ssl_client_socket_openssl.h @@ -15,6 +15,7 @@ #include "net/cert/cert_verify_result.h" #include "net/socket/client_socket_handle.h" #include "net/socket/ssl_client_socket.h" +#include "net/ssl/server_bound_cert_service.h" #include "net/ssl/ssl_config_service.h" // Avoid including misc OpenSSL headers, i.e.: @@ -52,14 +53,6 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { return ssl_session_cache_shard_; } - // Callback from the SSL layer that indicates the remote server is requesting - // a certificate for this client. - int ClientCertRequestCallback(SSL* ssl, X509** x509, EVP_PKEY** pkey); - - // Callback from the SSL layer to check which NPN protocol we are supporting - int SelectNextProtoCallback(unsigned char** out, unsigned char* outlen, - const unsigned char* in, unsigned int inlen); - // SSLClientSocket implementation. virtual void GetSSLCertRequestInfo( SSLCertRequestInfo* cert_request_info) OVERRIDE; @@ -98,6 +91,10 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { virtual bool SetSendBufferSize(int32 size) OVERRIDE; private: + class SSLContext; + friend class SSLClientSocket; + friend class SSLContext; + bool Init(); void DoReadCallback(int result); void DoWriteCallback(int result); @@ -124,7 +121,19 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { void BufferSendComplete(int result); void BufferRecvComplete(int result); void TransportWriteComplete(int result); - void TransportReadComplete(int result); + int TransportReadComplete(int result); + + // Callback from the SSL layer that indicates the remote server is requesting + // a certificate for this client. + int ClientCertRequestCallback(SSL* ssl, X509** x509, EVP_PKEY** pkey); + + // Callback from the SSL layer that indicates the remote server supports TLS + // Channel IDs. + void ChannelIDRequestCallback(SSL* ssl, EVP_PKEY** pkey); + + // Callback from the SSL layer to check which NPN protocol we are supporting + int SelectNextProtoCallback(unsigned char** out, unsigned char* outlen, + const unsigned char* in, unsigned int inlen); bool transport_send_busy_; bool transport_recv_busy_; @@ -155,6 +164,10 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { // indicates an error. int pending_read_error_; + // Used by TransportWriteComplete() and TransportReadComplete() to signify an + // error writing to the transport socket. A value of OK indicates no error. + int transport_write_error_; + // Set when handshake finishes. scoped_refptr<X509Certificate> server_cert_; CertVerifyResult server_cert_verify_result_; @@ -170,6 +183,9 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { CertVerifier* const cert_verifier_; scoped_ptr<SingleRequestCertVerifier> verifier_; + // The service for retrieving Channel ID keys. May be NULL. + ServerBoundCertService* server_bound_cert_service_; + // OpenSSL stuff SSL* ssl_; BIO* transport_bio_; @@ -195,6 +211,15 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { NextProtoStatus npn_status_; std::string npn_proto_; std::string server_protos_; + // Written by the |server_bound_cert_service_|. + std::string channel_id_private_key_; + std::string channel_id_cert_; + // The return value of the last call to |server_bound_cert_service_|. + int channel_id_request_return_value_; + // True if channel ID extension was negotiated. + bool channel_id_xtn_negotiated_; + // The request handle for |server_bound_cert_service_|. + ServerBoundCertService::RequestHandle channel_id_request_handle_; BoundNetLog net_log_; }; diff --git a/chromium/net/socket/ssl_client_socket_openssl_unittest.cc b/chromium/net/socket/ssl_client_socket_openssl_unittest.cc index 24c06059be5..48a4813f159 100644 --- a/chromium/net/socket/ssl_client_socket_openssl_unittest.cc +++ b/chromium/net/socket/ssl_client_socket_openssl_unittest.cc @@ -17,6 +17,7 @@ #include "base/files/file_path.h" #include "base/memory/ref_counted.h" #include "base/memory/scoped_handle.h" +#include "base/message_loop/message_loop_proxy.h" #include "base/values.h" #include "crypto/openssl_util.h" #include "net/base/address_list.h" @@ -34,7 +35,9 @@ #include "net/socket/client_socket_handle.h" #include "net/socket/socket_test_util.h" #include "net/socket/tcp_client_socket.h" +#include "net/ssl/default_server_bound_cert_store.h" #include "net/ssl/openssl_client_key_store.h" +#include "net/ssl/server_bound_cert_service.h" #include "net/ssl/ssl_cert_request_info.h" #include "net/ssl/ssl_config_service.h" #include "net/test/cert_test_util.h" @@ -59,6 +62,35 @@ typedef crypto::ScopedOpenSSL<BIGNUM, BN_free> ScopedBIGNUM; const SSLConfig kDefaultSSLConfig; +// A ServerBoundCertStore that always returns an error when asked for a +// certificate. +class FailingServerBoundCertStore : public ServerBoundCertStore { + virtual int GetServerBoundCert(const std::string& server_identifier, + base::Time* expiration_time, + std::string* private_key_result, + std::string* cert_result, + const GetCertCallback& callback) OVERRIDE { + return ERR_UNEXPECTED; + } + virtual void SetServerBoundCert(const std::string& server_identifier, + base::Time creation_time, + base::Time expiration_time, + const std::string& private_key, + const std::string& cert) OVERRIDE {} + virtual void DeleteServerBoundCert(const std::string& server_identifier, + const base::Closure& completion_callback) + OVERRIDE {} + virtual void DeleteAllCreatedBetween(base::Time delete_begin, + base::Time delete_end, + const base::Closure& completion_callback) + OVERRIDE {} + virtual void DeleteAll(const base::Closure& completion_callback) OVERRIDE {} + virtual void GetAllServerBoundCerts(const GetCertListCallback& callback) + OVERRIDE {} + virtual int GetCertCount() OVERRIDE { return 0; } + virtual void SetForceKeepSessionState() OVERRIDE {} +}; + // Loads a PEM-encoded private key file into a scoped EVP_PKEY object. // |filepath| is the private key file path. // |*pkey| is reset to the new EVP_PKEY on success, untouched otherwise. @@ -107,6 +139,20 @@ class SSLClientSocketOpenSSLClientAuthTest : public PlatformTest { } protected: + void EnabledChannelID() { + cert_service_.reset( + new ServerBoundCertService(new DefaultServerBoundCertStore(NULL), + base::MessageLoopProxy::current())); + context_.server_bound_cert_service = cert_service_.get(); + } + + void EnabledFailingChannelID() { + cert_service_.reset( + new ServerBoundCertService(new FailingServerBoundCertStore(), + base::MessageLoopProxy::current())); + context_.server_bound_cert_service = cert_service_.get(); + } + scoped_ptr<SSLClientSocket> CreateSSLClientSocket( scoped_ptr<StreamSocket> transport_socket, const HostPortPair& host_and_port, @@ -188,6 +234,7 @@ class SSLClientSocketOpenSSLClientAuthTest : public PlatformTest { return ssl_info.client_cert_sent; } + scoped_ptr<ServerBoundCertService> cert_service_; ClientSocketFactory* socket_factory_; scoped_ptr<MockCertVerifier> cert_verifier_; scoped_ptr<TransportSecurityState> transport_security_state_; @@ -275,5 +322,44 @@ TEST_F(SSLClientSocketOpenSSLClientAuthTest, SendGoodCert) { EXPECT_FALSE(sock_->IsConnected()); } +// Connect to a server using channel id. It should allow the connection. +TEST_F(SSLClientSocketOpenSSLClientAuthTest, SendChannelID) { + SpawnedTestServer::SSLOptions ssl_options; + + ASSERT_TRUE(ConnectToTestServer(ssl_options)); + + EnabledChannelID(); + SSLConfig ssl_config = kDefaultSSLConfig; + ssl_config.channel_id_enabled = true; + + int rv; + ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv)); + + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock_->IsConnected()); + EXPECT_TRUE(sock_->WasChannelIDSent()); + + sock_->Disconnect(); + EXPECT_FALSE(sock_->IsConnected()); +} + +// Connect to a server using channel id but without sending a key. It should +// fail. +TEST_F(SSLClientSocketOpenSSLClientAuthTest, FailingChannelID) { + SpawnedTestServer::SSLOptions ssl_options; + + ASSERT_TRUE(ConnectToTestServer(ssl_options)); + + EnabledFailingChannelID(); + SSLConfig ssl_config = kDefaultSSLConfig; + ssl_config.channel_id_enabled = true; + + int rv; + ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv)); + + EXPECT_EQ(ERR_UNEXPECTED, rv); + EXPECT_FALSE(sock_->IsConnected()); +} + } // namespace } // namespace net diff --git a/chromium/net/socket/ssl_client_socket_pool.cc b/chromium/net/socket/ssl_client_socket_pool.cc index 5d574b7edda..de315fdc01a 100644 --- a/chromium/net/socket/ssl_client_socket_pool.cc +++ b/chromium/net/socket/ssl_client_socket_pool.cc @@ -123,6 +123,7 @@ SSLConnectJob::SSLConnectJob(const std::string& group_name, context_(context.cert_verifier, context.server_bound_cert_service, context.transport_security_state, + context.cert_transparency_verifier, (params->privacy_mode() == kPrivacyModeEnabled ? "pm/" + context.ssl_session_cache_shard : context.ssl_session_cache_shard)), @@ -508,6 +509,7 @@ SSLClientSocketPool::SSLClientSocketPool( CertVerifier* cert_verifier, ServerBoundCertService* server_bound_cert_service, TransportSecurityState* transport_security_state, + CTVerifier* cert_transparency_verifier, const std::string& ssl_session_cache_shard, ClientSocketFactory* client_socket_factory, TransportClientSocketPool* transport_pool, @@ -530,6 +532,7 @@ SSLClientSocketPool::SSLClientSocketPool( cert_verifier, server_bound_cert_service, transport_security_state, + cert_transparency_verifier, ssl_session_cache_shard), net_log)), ssl_config_service_(ssl_config_service) { diff --git a/chromium/net/socket/ssl_client_socket_pool.h b/chromium/net/socket/ssl_client_socket_pool.h index ec62eb01f46..e03b76ade6a 100644 --- a/chromium/net/socket/ssl_client_socket_pool.h +++ b/chromium/net/socket/ssl_client_socket_pool.h @@ -24,6 +24,7 @@ namespace net { class CertVerifier; class ClientSocketFactory; class ConnectJobFactory; +class CTVerifier; class HostPortPair; class HttpProxyClientSocketPool; class HttpProxySocketParams; @@ -189,6 +190,7 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool CertVerifier* cert_verifier, ServerBoundCertService* server_bound_cert_service, TransportSecurityState* transport_security_state, + CTVerifier* cert_transparency_verifier, const std::string& ssl_session_cache_shard, ClientSocketFactory* client_socket_factory, TransportClientSocketPool* transport_pool, diff --git a/chromium/net/socket/ssl_client_socket_pool_unittest.cc b/chromium/net/socket/ssl_client_socket_pool_unittest.cc index 8aecb98dd85..92ad51a4351 100644 --- a/chromium/net/socket/ssl_client_socket_pool_unittest.cc +++ b/chromium/net/socket/ssl_client_socket_pool_unittest.cc @@ -140,6 +140,7 @@ class SSLClientSocketPoolTest NULL /* cert_verifier */, NULL /* server_bound_cert_service */, NULL /* transport_security_state */, + NULL /* cert_transparency_verifier */, std::string() /* ssl_session_cache_shard */, &socket_factory_, transport_pool ? &transport_socket_pool_ : NULL, @@ -225,7 +226,8 @@ class SSLClientSocketPoolTest INSTANTIATE_TEST_CASE_P( NextProto, SSLClientSocketPoolTest, - testing::Values(kProtoSPDY2, kProtoSPDY3, kProtoSPDY31, kProtoSPDY4a2, + testing::Values(kProtoDeprecatedSPDY2, + kProtoSPDY3, kProtoSPDY31, kProtoSPDY4a2, kProtoHTTP2Draft04)); TEST_P(SSLClientSocketPoolTest, TCPFail) { @@ -297,7 +299,7 @@ TEST_P(SSLClientSocketPoolTest, SetSocketRequestPriorityOnInitDirect) { scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT, false); - for (int i = MINIMUM_PRIORITY; i < NUM_PRIORITIES; ++i) { + for (int i = MINIMUM_PRIORITY; i <= MAXIMUM_PRIORITY; ++i) { RequestPriority priority = static_cast<RequestPriority>(i); StaticSocketDataProvider data; data.set_connect_data(MockConnect(SYNCHRONOUS, OK)); diff --git a/chromium/net/socket/ssl_client_socket_unittest.cc b/chromium/net/socket/ssl_client_socket_unittest.cc index f791928580f..14633a958ba 100644 --- a/chromium/net/socket/ssl_client_socket_unittest.cc +++ b/chromium/net/socket/ssl_client_socket_unittest.cc @@ -1171,6 +1171,112 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) { EXPECT_FALSE(callback.have_result()); } +// Tests that the SSLClientSocket does not crash if data is received on the +// transport socket after a failing write. This can occur if we have a Write +// error in a SPDY socket. +// Regression test for http://crbug.com/335557 +TEST_F(SSLClientSocketTest, Read_WithWriteError) { + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + scoped_ptr<StreamSocket> real_transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + // Note: |error_socket|'s ownership is handed to |transport|, but a pointer + // is retained in order to configure additional errors. + scoped_ptr<SynchronousErrorStreamSocket> error_socket( + new SynchronousErrorStreamSocket(real_transport.Pass())); + SynchronousErrorStreamSocket* raw_error_socket = error_socket.get(); + scoped_ptr<FakeBlockingStreamSocket> transport( + new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>())); + FakeBlockingStreamSocket* raw_transport = transport.get(); + + int rv = callback.GetResult(transport->Connect(callback.callback())); + EXPECT_EQ(OK, rv); + + // Disable TLS False Start to avoid handshake non-determinism. + SSLConfig ssl_config; + ssl_config.false_start_enabled = false; + + scoped_ptr<SSLClientSocket> sock( + CreateSSLClientSocket(transport.PassAs<StreamSocket>(), + test_server.host_port_pair(), + ssl_config)); + + rv = callback.GetResult(sock->Connect(callback.callback())); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock->IsConnected()); + + // Send a request so there is something to read from the socket. + const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; + static const int kRequestTextSize = + static_cast<int>(arraysize(request_text) - 1); + scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kRequestTextSize)); + memcpy(request_buffer->data(), request_text, kRequestTextSize); + + rv = callback.GetResult( + sock->Write(request_buffer.get(), kRequestTextSize, callback.callback())); + EXPECT_EQ(kRequestTextSize, rv); + + // Start a hanging read. + TestCompletionCallback read_callback; + raw_transport->SetNextReadShouldBlock(); + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); + rv = sock->Read(buf.get(), 4096, read_callback.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + // Perform another write, but have it fail. Write a request larger than the + // internal socket buffers so that the request hits the underlying transport + // socket and detects the error. + std::string long_request_text = + "GET / HTTP/1.1\r\nUser-Agent: long browser name "; + long_request_text.append(20 * 1024, '*'); + long_request_text.append("\r\n\r\n"); + scoped_refptr<DrainableIOBuffer> long_request_buffer(new DrainableIOBuffer( + new StringIOBuffer(long_request_text), long_request_text.size())); + + raw_error_socket->SetNextWriteError(ERR_CONNECTION_RESET); + + // Write as much data as possible until hitting an error. This is necessary + // for NSS. PR_Write will only consume as much data as it can encode into + // application data records before the internal memio buffer is full, which + // should only fill if writing a large amount of data and the underlying + // transport is blocked. Once this happens, NSS will return (total size of all + // application data records it wrote) - 1, with the caller expected to resume + // with the remaining unsent data. + do { + rv = callback.GetResult(sock->Write(long_request_buffer.get(), + long_request_buffer->BytesRemaining(), + callback.callback())); + if (rv > 0) { + long_request_buffer->DidConsume(rv); + // Abort if the entire buffer is ever consumed. + ASSERT_LT(0, long_request_buffer->BytesRemaining()); + } + } while (rv > 0); + +#if !defined(USE_OPENSSL) + // NSS records the error exactly. + EXPECT_EQ(ERR_CONNECTION_RESET, rv); +#else + // OpenSSL treats the reset as a generic protocol error. + EXPECT_EQ(ERR_SSL_PROTOCOL_ERROR, rv); +#endif + + // Release the read. Some bytes should go through. + raw_transport->UnblockRead(); + rv = read_callback.WaitForResult(); + + // Per the fix for http://crbug.com/249848, write failures currently break + // reads. Change this assertion if they're changed to not collide. + EXPECT_EQ(ERR_CONNECTION_RESET, rv); +} + TEST_F(SSLClientSocketTest, Read_SmallChunks) { SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, SpawnedTestServer::kLocalhost, @@ -1795,4 +1901,162 @@ TEST_F(SSLClientSocketCertRequestInfoTest, TwoAuthorities) { } // namespace +TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsEnabledTLSExtension) { + SpawnedTestServer::SSLOptions ssl_options; + ssl_options.signed_cert_timestamps_tls_ext = "test"; + + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + ssl_options, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); + int rv = transport->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + SSLConfig ssl_config; + ssl_config.signed_cert_timestamps_enabled = true; + + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), ssl_config)); + + EXPECT_FALSE(sock->IsConnected()); + + rv = sock->Connect(callback.callback()); + + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock->IsConnected()); + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1)); + +#if !defined(USE_OPENSSL) + EXPECT_TRUE(sock->signed_cert_timestamps_received_); +#else + // Enabling CT for OpenSSL is currently a noop. + EXPECT_FALSE(sock->signed_cert_timestamps_received_); +#endif + + sock->Disconnect(); + EXPECT_FALSE(sock->IsConnected()); +} + +// Test that enabling Signed Certificate Timestamps enables OCSP stapling. +TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsEnabledOCSP) { + SpawnedTestServer::SSLOptions ssl_options; + ssl_options.staple_ocsp_response = true; + // The test server currently only knows how to generate OCSP responses + // for a freshly minted certificate. + ssl_options.server_certificate = SpawnedTestServer::SSLOptions::CERT_AUTO; + + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + ssl_options, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); + int rv = transport->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + SSLConfig ssl_config; + // Enabling Signed Cert Timestamps ensures we request OCSP stapling for + // Certificate Transparency verification regardless of whether the platform + // is able to process the OCSP status itself. + ssl_config.signed_cert_timestamps_enabled = true; + + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), ssl_config)); + + EXPECT_FALSE(sock->IsConnected()); + + rv = sock->Connect(callback.callback()); + + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock->IsConnected()); + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1)); + +#if !defined(USE_OPENSSL) + EXPECT_TRUE(sock->stapled_ocsp_response_received_); +#else + // OCSP stapling isn't currently supported in the OpenSSL socket. + EXPECT_FALSE(sock->stapled_ocsp_response_received_); +#endif + + sock->Disconnect(); + EXPECT_FALSE(sock->IsConnected()); +} + +TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsDisabled) { + SpawnedTestServer::SSLOptions ssl_options; + ssl_options.signed_cert_timestamps_tls_ext = "test"; + + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + ssl_options, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); + int rv = transport->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + SSLConfig ssl_config; + ssl_config.signed_cert_timestamps_enabled = false; + + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), ssl_config)); + + EXPECT_FALSE(sock->IsConnected()); + + rv = sock->Connect(callback.callback()); + + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock->IsConnected()); + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1)); + + EXPECT_FALSE(sock->signed_cert_timestamps_received_); + + sock->Disconnect(); + EXPECT_FALSE(sock->IsConnected()); +} + } // namespace net diff --git a/chromium/net/socket/ssl_server_socket_nss.cc b/chromium/net/socket/ssl_server_socket_nss.cc index 7e5d70118ac..b95983a5159 100644 --- a/chromium/net/socket/ssl_server_socket_nss.cc +++ b/chromium/net/socket/ssl_server_socket_nss.cc @@ -106,10 +106,6 @@ SSLServerSocketNSS::SSLServerSocketNSS( cert_(cert), next_handshake_state_(STATE_NONE), completed_handshake_(false) { - ssl_config_.false_start_enabled = false; - ssl_config_.version_min = SSL_PROTOCOL_VERSION_SSL3; - ssl_config_.version_max = SSL_PROTOCOL_VERSION_TLS1_1; - // TODO(hclam): Need a better way to clone a key. std::vector<uint8> key_bytes; CHECK(key->ExportPrivateKey(&key_bytes)); @@ -350,6 +346,23 @@ int SSLServerSocketNSS::InitializeSSLOptions() { return ERR_NO_SSL_VERSIONS_ENABLED; } + if (ssl_config_.require_forward_secrecy) { + const PRUint16* const ssl_ciphers = SSL_GetImplementedCiphers(); + const PRUint16 num_ciphers = SSL_GetNumImplementedCiphers(); + + // Require forward security by iterating over the cipher suites and + // disabling all those that don't use ECDHE. + for (unsigned i = 0; i < num_ciphers; i++) { + SSLCipherSuiteInfo info; + if (SSL_GetCipherSuiteInfo(ssl_ciphers[i], &info, sizeof(info)) == + SECSuccess) { + if (strcmp(info.keaTypeName, "ECDHE") != 0) { + SSL_CipherPrefSet(nss_fd_, ssl_ciphers[i], PR_FALSE); + } + } + } + } + for (std::vector<uint16>::const_iterator it = ssl_config_.disabled_cipher_suites.begin(); it != ssl_config_.disabled_cipher_suites.end(); ++it) { diff --git a/chromium/net/socket/ssl_server_socket_unittest.cc b/chromium/net/socket/ssl_server_socket_unittest.cc index e1f7f496131..d5d04b20a90 100644 --- a/chromium/net/socket/ssl_server_socket_unittest.cc +++ b/chromium/net/socket/ssl_server_socket_unittest.cc @@ -56,9 +56,9 @@ class FakeDataChannel { public: FakeDataChannel() : read_buf_len_(0), - weak_factory_(this), closed_(false), - write_called_after_close_(false) { + write_called_after_close_(false), + weak_factory_(this) { } int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { @@ -140,8 +140,6 @@ class FakeDataChannel { std::queue<scoped_refptr<net::DrainableIOBuffer> > data_; - base::WeakPtrFactory<FakeDataChannel> weak_factory_; - // True if Close() has been called. bool closed_; @@ -150,6 +148,8 @@ class FakeDataChannel { // asynchronously. bool write_called_after_close_; + base::WeakPtrFactory<FakeDataChannel> weak_factory_; + DISALLOW_COPY_AND_ASSIGN(FakeDataChannel); }; @@ -334,8 +334,6 @@ class SSLServerSocketTest : public PlatformTest { ssl_config.cached_info_enabled = false; ssl_config.false_start_enabled = false; ssl_config.channel_id_enabled = false; - ssl_config.version_min = SSL_PROTOCOL_VERSION_SSL3; - ssl_config.version_max = SSL_PROTOCOL_VERSION_TLS1_1; // Certificate provided by the host doesn't need authority. net::SSLConfig::CertAndStatus cert_and_status; diff --git a/chromium/net/socket/ssl_session_cache_openssl.cc b/chromium/net/socket/ssl_session_cache_openssl.cc new file mode 100644 index 00000000000..d16bb8d6325 --- /dev/null +++ b/chromium/net/socket/ssl_session_cache_openssl.cc @@ -0,0 +1,508 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/ssl_session_cache_openssl.h" + +#include <list> +#include <map> + +#include <openssl/rand.h> +#include <openssl/ssl.h> + +#include "base/containers/hash_tables.h" +#include "base/lazy_instance.h" +#include "base/logging.h" +#include "base/synchronization/lock.h" + +namespace net { + +namespace { + +// A helper class to lazily create a new EX_DATA index to map SSL_CTX handles +// to their corresponding SSLSessionCacheOpenSSLImpl object. +class SSLContextExIndex { +public: + SSLContextExIndex() { + context_index_ = SSL_CTX_get_ex_new_index(0, NULL, NULL, NULL, NULL); + DCHECK_NE(-1, context_index_); + session_index_ = SSL_SESSION_get_ex_new_index(0, NULL, NULL, NULL, NULL); + DCHECK_NE(-1, session_index_); + } + + int context_index() const { return context_index_; } + int session_index() const { return session_index_; } + + private: + int context_index_; + int session_index_; +}; + +// static +base::LazyInstance<SSLContextExIndex>::Leaky s_ssl_context_ex_instance = + LAZY_INSTANCE_INITIALIZER; + +// Retrieve the global EX_DATA index, created lazily on first call, to +// be used with SSL_CTX_set_ex_data() and SSL_CTX_get_ex_data(). +static int GetSSLContextExIndex() { + return s_ssl_context_ex_instance.Get().context_index(); +} + +// Retrieve the global EX_DATA index, created lazily on first call, to +// be used with SSL_SESSION_set_ex_data() and SSL_SESSION_get_ex_data(). +static int GetSSLSessionExIndex() { + return s_ssl_context_ex_instance.Get().session_index(); +} + +// Helper struct used to store session IDs in a SessionIdIndex container +// (see definition below). To save memory each entry only holds a pointer +// to the session ID buffer, which must outlive the entry itself. On the +// other hand, a hash is included to minimize the number of hashing +// computations during cache operations. +struct SessionId { + SessionId(const unsigned char* a_id, unsigned a_id_len) + : id(a_id), id_len(a_id_len), hash(ComputeHash(a_id, a_id_len)) {} + + explicit SessionId(const SessionId& other) + : id(other.id), id_len(other.id_len), hash(other.hash) {} + + explicit SessionId(SSL_SESSION* session) + : id(session->session_id), + id_len(session->session_id_length), + hash(ComputeHash(session->session_id, session->session_id_length)) {} + + bool operator==(const SessionId& other) const { + return hash == other.hash && id_len == other.id_len && + !memcmp(id, other.id, id_len); + } + + const unsigned char* id; + unsigned id_len; + size_t hash; + + private: + // Session ID are random strings of bytes. This happens to compute the same + // value as std::hash<std::string> without the extra string copy. See + // base/containers/hash_tables.h. Other hashing computations are possible, + // this one is just simple enough to do the job. + size_t ComputeHash(const unsigned char* id, unsigned id_len) { + size_t result = 0; + for (unsigned n = 0; n < id_len; ++n) + result += 131 * id[n]; + return result; + } +}; + +} // namespace + +} // namespace net + +namespace BASE_HASH_NAMESPACE { + +template <> +struct hash<net::SessionId> { + std::size_t operator()(const net::SessionId& entry) const { + return entry.hash; + } +}; + +} // namespace BASE_HASH_NAMESPACE + +namespace net { + +// Implementation of the real SSLSessionCache. +// +// The implementation is inspired by base::MRUCache, except that the deletor +// also needs to remove the entry from other containers. In a nutshell, this +// uses several basic containers: +// +// |ordering_| is a doubly-linked list of SSL_SESSION handles, ordered in +// MRU order. +// +// |key_index_| is a hash table mapping unique cache keys (e.g. host/port +// values) to a single iterator of |ordering_|. It is used to efficiently +// find the cached session associated with a given key. +// +// |id_index_| is a hash table mapping SessionId values to iterators +// of |key_index_|. If is used to efficiently remove sessions from the cache, +// as well as check for the existence of a session ID value in the cache. +// +// SSL_SESSION objects are reference-counted, and owned by the cache. This +// means that their reference count is incremented when they are added, and +// decremented when they are removed. +// +// Assuming an average key size of 100 characters, each node requires the +// following memory usage on 32-bit Android, when linked against STLport: +// +// 12 (ordering_ node, including SSL_SESSION handle) +// 100 (key characters) +// + 24 (std::string header/minimum size) +// + 8 (key_index_ node, excluding the 2 lines above for the key). +// + 20 (id_index_ node) +// -------- +// 164 bytes/node +// +// Hence, 41 KiB for a full cache with a maximum of 1024 entries, excluding +// the size of SSL_SESSION objects and heap fragmentation. +// + +class SSLSessionCacheOpenSSLImpl { + public: + // Construct new instance. This registers various hooks into the SSL_CTX + // context |ctx|. OpenSSL will call back during SSL connection + // operations. |key_func| is used to map a SSL handle to a unique cache + // string, according to the client's preferences. + SSLSessionCacheOpenSSLImpl(SSL_CTX* ctx, + const SSLSessionCacheOpenSSL::Config& config) + : ctx_(ctx), config_(config), expiration_check_(0) { + DCHECK(ctx); + + // NO_INTERNAL_STORE disables OpenSSL's builtin cache, and + // NO_AUTO_CLEAR disables the call to SSL_CTX_flush_sessions + // every 256 connections (this number is hard-coded in the library + // and can't be changed). + SSL_CTX_set_session_cache_mode(ctx_, + SSL_SESS_CACHE_CLIENT | + SSL_SESS_CACHE_NO_INTERNAL_STORE | + SSL_SESS_CACHE_NO_AUTO_CLEAR); + + SSL_CTX_sess_set_new_cb(ctx_, NewSessionCallbackStatic); + SSL_CTX_sess_set_remove_cb(ctx_, RemoveSessionCallbackStatic); + SSL_CTX_set_generate_session_id(ctx_, GenerateSessionIdStatic); + SSL_CTX_set_timeout(ctx_, config_.timeout_seconds); + + SSL_CTX_set_ex_data(ctx_, GetSSLContextExIndex(), this); + } + + // Destroy this instance. Must happen before |ctx_| is destroyed. + ~SSLSessionCacheOpenSSLImpl() { + Flush(); + SSL_CTX_set_ex_data(ctx_, GetSSLContextExIndex(), NULL); + SSL_CTX_sess_set_new_cb(ctx_, NULL); + SSL_CTX_sess_set_remove_cb(ctx_, NULL); + SSL_CTX_set_generate_session_id(ctx_, NULL); + } + + // Return the number of items in this cache. + size_t size() const { return key_index_.size(); } + + // Retrieve the cache key from |ssl| and look for a corresponding + // cached session ID. If one is found, call SSL_set_session() to associate + // it with the |ssl| connection. + // + // Will also check for expired sessions every |expiration_check_count| + // calls. + // + // Return true if a cached session ID was found, false otherwise. + bool SetSSLSession(SSL* ssl) { + std::string cache_key = config_.key_func(ssl); + if (cache_key.empty()) + return false; + + return SetSSLSessionWithKey(ssl, cache_key); + } + + // Variant of SetSSLSession to be used when the client already has computed + // the cache key. Avoid a call to the configuration's |key_func| function. + bool SetSSLSessionWithKey(SSL* ssl, const std::string& cache_key) { + base::AutoLock locked(lock_); + + DCHECK_EQ(config_.key_func(ssl), cache_key); + + if (++expiration_check_ >= config_.expiration_check_count) { + expiration_check_ = 0; + FlushExpiredSessionsLocked(); + } + + KeyIndex::iterator it = key_index_.find(cache_key); + if (it == key_index_.end()) + return false; + + SSL_SESSION* session = *it->second; + DCHECK(session); + + DVLOG(2) << "Lookup session: " << session << " for " << cache_key; + + void* session_is_good = + SSL_SESSION_get_ex_data(session, GetSSLSessionExIndex()); + if (!session_is_good) + return false; // Session has not yet been marked good. Treat as a miss. + + // Move to front of MRU list. + ordering_.push_front(session); + ordering_.erase(it->second); + it->second = ordering_.begin(); + + return SSL_set_session(ssl, session) == 1; + } + + void MarkSSLSessionAsGood(SSL* ssl) { + SSL_SESSION* session = SSL_get_session(ssl); + if (!session) + return; + + // Mark the session as good, allowing it to be used for future connections. + SSL_SESSION_set_ex_data( + session, GetSSLSessionExIndex(), reinterpret_cast<void*>(1)); + } + + // Flush all entries from the cache. + void Flush() { + base::AutoLock lock(lock_); + id_index_.clear(); + key_index_.clear(); + while (!ordering_.empty()) { + SSL_SESSION* session = ordering_.front(); + ordering_.pop_front(); + SSL_SESSION_free(session); + } + } + + private: + // Type for list of SSL_SESSION handles, ordered in MRU order. + typedef std::list<SSL_SESSION*> MRUSessionList; + // Type for a dictionary from unique cache keys to session list nodes. + typedef base::hash_map<std::string, MRUSessionList::iterator> KeyIndex; + // Type for a dictionary from SessionId values to key index nodes. + typedef base::hash_map<SessionId, KeyIndex::iterator> SessionIdIndex; + + // Return the key associated with a given session, or the empty string if + // none exist. This shall only be used for debugging. + std::string SessionKey(SSL_SESSION* session) { + if (!session) + return std::string("<null-session>"); + + if (session->session_id_length == 0) + return std::string("<empty-session-id>"); + + SessionIdIndex::iterator it = id_index_.find(SessionId(session)); + if (it == id_index_.end()) + return std::string("<unknown-session>"); + + return it->second->first; + } + + // Remove a given |session| from the cache. Lock must be held. + void RemoveSessionLocked(SSL_SESSION* session) { + lock_.AssertAcquired(); + DCHECK(session); + DCHECK_GT(session->session_id_length, 0U); + SessionId session_id(session); + SessionIdIndex::iterator id_it = id_index_.find(session_id); + if (id_it == id_index_.end()) { + LOG(ERROR) << "Trying to remove unknown session from cache: " << session; + return; + } + KeyIndex::iterator key_it = id_it->second; + DCHECK(key_it != key_index_.end()); + DCHECK_EQ(session, *key_it->second); + + id_index_.erase(session_id); + ordering_.erase(key_it->second); + key_index_.erase(key_it); + + SSL_SESSION_free(session); + + DCHECK_EQ(key_index_.size(), id_index_.size()); + } + + // Used internally to flush expired sessions. Lock must be held. + void FlushExpiredSessionsLocked() { + lock_.AssertAcquired(); + + // Unfortunately, OpenSSL initializes |session->time| with a time() + // timestamps, which makes mocking / unit testing difficult. + long timeout_secs = static_cast<long>(::time(NULL)); + MRUSessionList::iterator it = ordering_.begin(); + while (it != ordering_.end()) { + SSL_SESSION* session = *it++; + + // Important, use <= instead of < here to allow unit testing to + // work properly. That's because unit tests that check the expiration + // behaviour will use a session timeout of 0 seconds. + if (session->time + session->timeout <= timeout_secs) { + DVLOG(2) << "Expiring session " << session << " for " + << SessionKey(session); + RemoveSessionLocked(session); + } + } + } + + // Retrieve the cache associated with a given SSL context |ctx|. + static SSLSessionCacheOpenSSLImpl* GetCache(SSL_CTX* ctx) { + DCHECK(ctx); + void* result = SSL_CTX_get_ex_data(ctx, GetSSLContextExIndex()); + DCHECK(result); + return reinterpret_cast<SSLSessionCacheOpenSSLImpl*>(result); + } + + // Called by OpenSSL when a new |session| was created and added to a given + // |ssl| connection. Note that the session's reference count was already + // incremented before the function is entered. The function must return 1 + // to indicate that it took ownership of the session, i.e. that the caller + // should not decrement its reference count after completion. + static int NewSessionCallbackStatic(SSL* ssl, SSL_SESSION* session) { + GetCache(ssl->ctx)->OnSessionAdded(ssl, session); + return 1; + } + + // Called by OpenSSL to indicate that a session must be removed from the + // cache. This happens when SSL_CTX is destroyed. + static void RemoveSessionCallbackStatic(SSL_CTX* ctx, SSL_SESSION* session) { + GetCache(ctx)->OnSessionRemoved(session); + } + + // Called by OpenSSL to generate a new session ID. This happens during a + // SSL connection operation, when the SSL object doesn't have a session yet. + // + // A session ID is a random string of bytes used to uniquely identify the + // session between a client and a server. + // + // |ssl| is a SSL connection handle. Ignored here. + // |id| is the target buffer where the ID must be generated. + // |*id_len| is, on input, the size of the desired ID. It will be 16 for + // SSLv2, and 32 for anything else. OpenSSL allows an implementation + // to change it on output, but this will not happen here. + // + // The function must ensure the generated ID is really unique, i.e. that + // another session in the cache doesn't already use the same value. It must + // return 1 to indicate success, or 0 for failure. + static int GenerateSessionIdStatic(const SSL* ssl, + unsigned char* id, + unsigned* id_len) { + if (!GetCache(ssl->ctx)->OnGenerateSessionId(id, *id_len)) + return 0; + + return 1; + } + + // Add |session| to the cache in association with |cache_key|. If a session + // already exists, it is replaced with the new one. This assumes that the + // caller already incremented the session's reference count. + void OnSessionAdded(SSL* ssl, SSL_SESSION* session) { + base::AutoLock locked(lock_); + DCHECK(ssl); + DCHECK_GT(session->session_id_length, 0U); + std::string cache_key = config_.key_func(ssl); + KeyIndex::iterator it = key_index_.find(cache_key); + if (it == key_index_.end()) { + DVLOG(2) << "Add session " << session << " for " << cache_key; + // This is a new session. Add it to the cache. + ordering_.push_front(session); + std::pair<KeyIndex::iterator, bool> ret = + key_index_.insert(std::make_pair(cache_key, ordering_.begin())); + DCHECK(ret.second); + it = ret.first; + DCHECK(it != key_index_.end()); + } else { + // An existing session exists for this key, so replace it if needed. + DVLOG(2) << "Replace session " << *it->second << " with " << session + << " for " << cache_key; + SSL_SESSION* old_session = *it->second; + if (old_session != session) { + id_index_.erase(SessionId(old_session)); + SSL_SESSION_free(old_session); + } + ordering_.erase(it->second); + ordering_.push_front(session); + it->second = ordering_.begin(); + } + + id_index_[SessionId(session)] = it; + + if (key_index_.size() > config_.max_entries) + ShrinkCacheLocked(); + + DCHECK_EQ(key_index_.size(), id_index_.size()); + DCHECK_LE(key_index_.size(), config_.max_entries); + } + + // Shrink the cache to ensure no more than config_.max_entries entries, + // starting with older entries first. Lock must be acquired. + void ShrinkCacheLocked() { + lock_.AssertAcquired(); + DCHECK_EQ(key_index_.size(), ordering_.size()); + DCHECK_EQ(key_index_.size(), id_index_.size()); + + while (key_index_.size() > config_.max_entries) { + MRUSessionList::reverse_iterator it = ordering_.rbegin(); + DCHECK(it != ordering_.rend()); + + SSL_SESSION* session = *it; + DCHECK(session); + DVLOG(2) << "Evicting session " << session << " for " + << SessionKey(session); + RemoveSessionLocked(session); + } + } + + // Remove |session| from the cache. + void OnSessionRemoved(SSL_SESSION* session) { + base::AutoLock locked(lock_); + DVLOG(2) << "Remove session " << session << " for " << SessionKey(session); + RemoveSessionLocked(session); + } + + // See GenerateSessionIdStatic for a description of what this function does. + bool OnGenerateSessionId(unsigned char* id, unsigned id_len) { + base::AutoLock locked(lock_); + // This mimics def_generate_session_id() in openssl/ssl/ssl_sess.cc, + // I.e. try to generate a pseudo-random bit string, and check that no + // other entry in the cache has the same value. + const size_t kMaxTries = 10; + for (size_t tries = 0; tries < kMaxTries; ++tries) { + if (RAND_pseudo_bytes(id, id_len) <= 0) { + DLOG(ERROR) << "Couldn't generate " << id_len + << " pseudo random bytes?"; + return false; + } + if (id_index_.find(SessionId(id, id_len)) == id_index_.end()) + return true; + } + DLOG(ERROR) << "Couldn't generate unique session ID of " << id_len + << "bytes after " << kMaxTries << " tries."; + return false; + } + + SSL_CTX* ctx_; + SSLSessionCacheOpenSSL::Config config_; + + // method to get the index which can later be used with SSL_CTX_get_ex_data() + // or SSL_CTX_set_ex_data(). + base::Lock lock_; // Protects access to containers below. + + MRUSessionList ordering_; + KeyIndex key_index_; + SessionIdIndex id_index_; + + size_t expiration_check_; +}; + +SSLSessionCacheOpenSSL::~SSLSessionCacheOpenSSL() { delete impl_; } + +size_t SSLSessionCacheOpenSSL::size() const { return impl_->size(); } + +void SSLSessionCacheOpenSSL::Reset(SSL_CTX* ctx, const Config& config) { + if (impl_) + delete impl_; + + impl_ = new SSLSessionCacheOpenSSLImpl(ctx, config); +} + +bool SSLSessionCacheOpenSSL::SetSSLSession(SSL* ssl) { + return impl_->SetSSLSession(ssl); +} + +bool SSLSessionCacheOpenSSL::SetSSLSessionWithKey( + SSL* ssl, + const std::string& cache_key) { + return impl_->SetSSLSessionWithKey(ssl, cache_key); +} + +void SSLSessionCacheOpenSSL::MarkSSLSessionAsGood(SSL* ssl) { + return impl_->MarkSSLSessionAsGood(ssl); +} + +void SSLSessionCacheOpenSSL::Flush() { impl_->Flush(); } + +} // namespace net diff --git a/chromium/net/socket/ssl_session_cache_openssl.h b/chromium/net/socket/ssl_session_cache_openssl.h new file mode 100644 index 00000000000..bbd9659641d --- /dev/null +++ b/chromium/net/socket/ssl_session_cache_openssl.h @@ -0,0 +1,141 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SSL_SESSION_CACHE_OPENSSL_H +#define NET_SOCKET_SSL_SESSION_CACHE_OPENSSL_H + +#include <string> + +#include "base/basictypes.h" +#include "net/base/net_export.h" + +// Avoid including OpenSSL headers here. +typedef struct ssl_ctx_st SSL_CTX; +typedef struct ssl_st SSL; + +namespace net { + +class SSLSessionCacheOpenSSLImpl; + +// A class used to implement a custom cache of SSL_SESSION objects. +// Usage is as follows: +// +// - Client creates a new cache instance with appropriate configuration, +// associating it with a given SSL_CTX object. +// +// The configuration must include a pointer to a client-provided function +// that can retrieve a unique cache key from an existing SSL handle. +// +// - When creating a new SSL connection, call SetSSLSession() with the newly +// created SSL handle, and a cache key for the current host/port. If a +// session is already in the cache, it will be added to the connection +// through SSL_set_session(). +// +// - Otherwise, OpenSSL will create a new SSL_SESSION object during the +// connection, and will pass it to the cache's internal functions, +// transparently to the client. +// +// - Each session has a timeout in seconds, which are checked every N-th call +// to SetSSLSession(), where N is the current configuration's +// |check_expiration_count|. Expired sessions are removed automatically +// from the cache. +// +// - Clients can call Flush() to remove all sessions from the cache, this is +// useful when the system's certificate store has changed. +// +// This class is thread-safe. There shouldn't be any issue with multiple +// SSL connections being performed in parallel in multiple threads. +class NET_EXPORT SSLSessionCacheOpenSSL { + public: + // Type of a function that takes a SSL handle and returns a unique cache + // key string to identify it. + typedef std::string GetSessionKeyFunction(const SSL* ssl); + + // A small structure used to configure a cache on creation. + // |key_func| is a function used at runtime to retrieve the unique cache key + // from a given SSL connection handle. + // |max_entries| is the maximum number of entries in the cache. + // |expiration_check_count| is the number of calls to SetSSLSession() that + // will trigger a check for expired sessions. + // |timeout_seconds| is the timeout of new cached sessions in seconds. + struct Config { + GetSessionKeyFunction* key_func; + size_t max_entries; + size_t expiration_check_count; + int timeout_seconds; + }; + + SSLSessionCacheOpenSSL() : impl_(NULL) {} + + // Construct a new cache instance. + // |ctx| is a SSL_CTX context handle that will be associated with this cache. + // |key_func| is a function that will be used at runtime to retrieve the + // unique cache key from a SSL connection handle. + // |max_entries| is the maximum number of entries in the cache. + // |timeout_seconds| is the timeout of new cached sessions in seconds. + // |expiration_check_count| is the number of calls to SetSSLSession() that + // will trigger a check for expired sessions. + SSLSessionCacheOpenSSL(SSL_CTX* ctx, const Config& config) : impl_(NULL) { + Reset(ctx, config); + } + + // Destroy this instance. This must be called before the SSL_CTX handle + // is destroyed. + ~SSLSessionCacheOpenSSL(); + + // Reset the cache configuration. This flushes any existing entries. + void Reset(SSL_CTX* ctx, const Config& config); + + size_t size() const; + + // Lookup the unique cache key associated with |ssl| connection handle, + // and find a cached session for it in the cache. If one is found, associate + // it with the |ssl| connection through SSL_set_session(). Consider using + // SetSSLSessionWithKey() if you already have the key. + // + // Every |check_expiration_count| call to either SetSSLSession() or + // SetSSLSessionWithKey() triggers a check for, and removal of, expired + // sessions. + // + // Return true iff a cached session was associated with the |ssl| connection. + bool SetSSLSession(SSL* ssl); + + // A more efficient variant of SetSSLSession() that can be used if the caller + // already has the cache key for the session of interest. The caller must + // ensure that the value of |cache_key| matches the result of calling the + // configuration's |key_func| function with the |ssl| as parameter. + // + // Every |check_expiration_count| call to either SetSSLSession() or + // SetSSLSessionWithKey() triggers a check for, and removal of, expired + // sessions. + // + // Return true iff a cached session was associated with the |ssl| connection. + bool SetSSLSessionWithKey(SSL* ssl, const std::string& cache_key); + + // Indicates that the SSL session associated with |ssl| is "good" - that is, + // that all associated cryptographic parameters that were negotiated, + // including the peer's certificate, were successfully validated. Because + // OpenSSL does not provide an asynchronous certificate verification + // callback, it's necessary to manually manage the sessions to ensure that + // only validated sessions are resumed. + void MarkSSLSessionAsGood(SSL* ssl); + + // Flush removes all entries from the cache. This is typically called when + // the system's certificate store has changed. + void Flush(); + + // TODO(digit): Move to client code. + static const int kDefaultTimeoutSeconds = 60 * 60; + static const size_t kMaxEntries = 1024; + static const size_t kMaxExpirationChecks = 256; + + private: + DISALLOW_COPY_AND_ASSIGN(SSLSessionCacheOpenSSL); + + SSLSessionCacheOpenSSLImpl* impl_; +}; + +} // namespace net + +#endif // NET_SOCKET_SSL_SESSION_CACHE_OPENSSL_H diff --git a/chromium/net/socket/ssl_session_cache_openssl_unittest.cc b/chromium/net/socket/ssl_session_cache_openssl_unittest.cc new file mode 100644 index 00000000000..22c4fbaeb9c --- /dev/null +++ b/chromium/net/socket/ssl_session_cache_openssl_unittest.cc @@ -0,0 +1,378 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/ssl_session_cache_openssl.h" + +#include <openssl/ssl.h> + +#include "base/lazy_instance.h" +#include "base/logging.h" +#include "base/strings/stringprintf.h" +#include "crypto/openssl_util.h" + +#include "testing/gtest/include/gtest/gtest.h" + +// This is an internal OpenSSL function that can be used to create a new +// session for an existing SSL object. This shall force a call to the +// 'generate_session_id' callback from the SSL's session context. +// |s| is the target SSL connection handle. +// |session| is non-0 to ask for the creation of a new session. If 0, +// this will set an empty session with no ID instead. +extern "C" int ssl_get_new_session(SSL* s, int session); + +// This is an internal OpenSSL function which is used internally to add +// a new session to the cache. It is normally triggered by a succesful +// connection. However, this unit test does not use the network at all. +extern "C" void ssl_update_cache(SSL* s, int mode); + +namespace net { + +namespace { + +typedef crypto::ScopedOpenSSL<SSL, SSL_free> ScopedSSL; + +// Helper class used to associate arbitrary std::string keys with SSL objects. +class SSLKeyHelper { + public: + // Return the string associated with a given SSL handle |ssl|, or the + // empty string if none exists. + static std::string Get(const SSL* ssl) { + return GetInstance()->GetValue(ssl); + } + + // Associate a string with a given SSL handle |ssl|. + static void Set(SSL* ssl, const std::string& value) { + GetInstance()->SetValue(ssl, value); + } + + static SSLKeyHelper* GetInstance() { + static base::LazyInstance<SSLKeyHelper>::Leaky s_instance = + LAZY_INSTANCE_INITIALIZER; + return s_instance.Pointer(); + } + + SSLKeyHelper() { + ex_index_ = SSL_get_ex_new_index(0, NULL, NULL, KeyDup, KeyFree); + CHECK_NE(-1, ex_index_); + } + + std::string GetValue(const SSL* ssl) { + std::string* value = + reinterpret_cast<std::string*>(SSL_get_ex_data(ssl, ex_index_)); + if (!value) + return std::string(); + return *value; + } + + void SetValue(SSL* ssl, const std::string& value) { + int ret = SSL_set_ex_data(ssl, ex_index_, new std::string(value)); + CHECK_EQ(1, ret); + } + + // Called when an SSL object is copied through SSL_dup(). This needs to copy + // the value as well. + static int KeyDup(CRYPTO_EX_DATA* to, + CRYPTO_EX_DATA* from, + void* from_fd, + int idx, + long argl, + void* argp) { + // |from_fd| is really the address of a temporary pointer. On input, it + // points to the value from the original SSL object. The function must + // update it to the address of a copy. + std::string** ptr = reinterpret_cast<std::string**>(from_fd); + std::string* old_string = *ptr; + std::string* new_string = new std::string(*old_string); + *ptr = new_string; + return 0; // Ignored by the implementation. + } + + // Called to destroy the value associated with an SSL object. + static void KeyFree(void* parent, + void* ptr, + CRYPTO_EX_DATA* ad, + int index, + long argl, + void* argp) { + std::string* value = reinterpret_cast<std::string*>(ptr); + delete value; + } + + int ex_index_; +}; + +} // namespace + +class SSLSessionCacheOpenSSLTest : public testing::Test { + public: + SSLSessionCacheOpenSSLTest() { + crypto::EnsureOpenSSLInit(); + ctx_.reset(SSL_CTX_new(SSLv23_client_method())); + cache_.Reset(ctx_.get(), kDefaultConfig); + } + + // Reset cache configuration. + void ResetConfig(const SSLSessionCacheOpenSSL::Config& config) { + cache_.Reset(ctx_.get(), config); + } + + // Helper function to create a new SSL connection object associated with + // a given unique |cache_key|. This does _not_ add the session to the cache. + // Caller must free the object with SSL_free(). + SSL* NewSSL(const std::string& cache_key) { + SSL* ssl = SSL_new(ctx_.get()); + if (!ssl) + return NULL; + + SSLKeyHelper::Set(ssl, cache_key); // associate cache key. + ResetSessionID(ssl); // create new unique session ID. + return ssl; + } + + // Reset the session ID of a given SSL object. This creates a new session + // with a new unique random ID. Does not add it to the cache. + static void ResetSessionID(SSL* ssl) { ssl_get_new_session(ssl, 1); } + + // Add a given SSL object and its session to the cache. + void AddToCache(SSL* ssl) { + ssl_update_cache(ssl, ctx_.get()->session_cache_mode); + } + + static const SSLSessionCacheOpenSSL::Config kDefaultConfig; + + protected: + crypto::ScopedOpenSSL<SSL_CTX, SSL_CTX_free> ctx_; + // |cache_| must be destroyed before |ctx_| and thus appears after it. + SSLSessionCacheOpenSSL cache_; +}; + +// static +const SSLSessionCacheOpenSSL::Config + SSLSessionCacheOpenSSLTest::kDefaultConfig = { + &SSLKeyHelper::Get, // key_func + 1024, // max_entries + 256, // expiration_check_count + 60 * 60, // timeout_seconds +}; + +TEST_F(SSLSessionCacheOpenSSLTest, EmptyCacheCreation) { + EXPECT_EQ(0U, cache_.size()); +} + +TEST_F(SSLSessionCacheOpenSSLTest, CacheOneSession) { + ScopedSSL ssl(NewSSL("hello")); + + EXPECT_EQ(0U, cache_.size()); + AddToCache(ssl.get()); + EXPECT_EQ(1U, cache_.size()); + ssl.reset(NULL); + EXPECT_EQ(1U, cache_.size()); +} + +TEST_F(SSLSessionCacheOpenSSLTest, CacheMultipleSessions) { + const size_t kNumItems = 100; + int local_id = 1; + + // Add kNumItems to the cache. + for (size_t n = 0; n < kNumItems; ++n) { + std::string local_id_string = base::StringPrintf("%d", local_id++); + ScopedSSL ssl(NewSSL(local_id_string)); + AddToCache(ssl.get()); + EXPECT_EQ(n + 1, cache_.size()); + } +} + +TEST_F(SSLSessionCacheOpenSSLTest, Flush) { + const size_t kNumItems = 100; + int local_id = 1; + + // Add kNumItems to the cache. + for (size_t n = 0; n < kNumItems; ++n) { + std::string local_id_string = base::StringPrintf("%d", local_id++); + ScopedSSL ssl(NewSSL(local_id_string)); + AddToCache(ssl.get()); + } + EXPECT_EQ(kNumItems, cache_.size()); + + cache_.Flush(); + EXPECT_EQ(0U, cache_.size()); +} + +TEST_F(SSLSessionCacheOpenSSLTest, SetSSLSession) { + const std::string key("hello"); + ScopedSSL ssl(NewSSL(key)); + + // First call should fail because the session is not in the cache. + EXPECT_FALSE(cache_.SetSSLSession(ssl.get())); + SSL_SESSION* session = ssl.get()->session; + EXPECT_TRUE(session); + EXPECT_EQ(1, session->references); + + AddToCache(ssl.get()); + EXPECT_EQ(2, session->references); + + // Mark the session as good, so that it is re-used for the second connection. + cache_.MarkSSLSessionAsGood(ssl.get()); + + ssl.reset(NULL); + EXPECT_EQ(1, session->references); + + // Second call should find the session ID and associate it with |ssl2|. + ScopedSSL ssl2(NewSSL(key)); + EXPECT_TRUE(cache_.SetSSLSession(ssl2.get())); + + EXPECT_EQ(session, ssl2.get()->session); + EXPECT_EQ(2, session->references); +} + +TEST_F(SSLSessionCacheOpenSSLTest, SetSSLSessionWithKey) { + const std::string key("hello"); + ScopedSSL ssl(NewSSL(key)); + AddToCache(ssl.get()); + cache_.MarkSSLSessionAsGood(ssl.get()); + ssl.reset(NULL); + + ScopedSSL ssl2(NewSSL(key)); + EXPECT_TRUE(cache_.SetSSLSessionWithKey(ssl2.get(), key)); +} + +TEST_F(SSLSessionCacheOpenSSLTest, CheckSessionReplacement) { + // Check that if two SSL connections have the same key, only one + // corresponding session can be stored in the cache. + const std::string common_key("common-key"); + ScopedSSL ssl1(NewSSL(common_key)); + ScopedSSL ssl2(NewSSL(common_key)); + + AddToCache(ssl1.get()); + EXPECT_EQ(1U, cache_.size()); + EXPECT_EQ(2, ssl1.get()->session->references); + + // This ends up calling OnSessionAdded which will discover that there is + // already one session ID associated with the key, and will replace it. + AddToCache(ssl2.get()); + EXPECT_EQ(1U, cache_.size()); + EXPECT_EQ(1, ssl1.get()->session->references); + EXPECT_EQ(2, ssl2.get()->session->references); +} + +// Check that when two connections have the same key, a new session is created +// if the existing session has not yet been marked "good". Further, after the +// first session completes, if the second session has replaced it in the cache, +// new sessions should continue to fail until the currently cached session +// succeeds. +TEST_F(SSLSessionCacheOpenSSLTest, CheckSessionReplacementWhenNotGood) { + const std::string key("hello"); + ScopedSSL ssl(NewSSL(key)); + + // First call should fail because the session is not in the cache. + EXPECT_FALSE(cache_.SetSSLSession(ssl.get())); + SSL_SESSION* session = ssl.get()->session; + ASSERT_TRUE(session); + EXPECT_EQ(1, session->references); + + AddToCache(ssl.get()); + EXPECT_EQ(2, session->references); + + // Second call should find the session ID, but because it is not yet good, + // fail to associate it with |ssl2|. + ScopedSSL ssl2(NewSSL(key)); + EXPECT_FALSE(cache_.SetSSLSession(ssl2.get())); + SSL_SESSION* session2 = ssl2.get()->session; + ASSERT_TRUE(session2); + EXPECT_EQ(1, session2->references); + + EXPECT_NE(session, session2); + + // Add the second connection to the cache. It should replace the first + // session, and the cache should hold on to the second session. + AddToCache(ssl2.get()); + EXPECT_EQ(1, session->references); + EXPECT_EQ(2, session2->references); + + // Mark the first session as good, simulating it completing. + cache_.MarkSSLSessionAsGood(ssl.get()); + + // Third call should find the session ID, but because the second session (the + // current cache entry) is not yet good, fail to associate it with |ssl3|. + ScopedSSL ssl3(NewSSL(key)); + EXPECT_FALSE(cache_.SetSSLSession(ssl3.get())); + EXPECT_NE(session, ssl3.get()->session); + EXPECT_NE(session2, ssl3.get()->session); + EXPECT_EQ(1, ssl3.get()->session->references); +} + +TEST_F(SSLSessionCacheOpenSSLTest, CheckEviction) { + const size_t kMaxItems = 20; + int local_id = 1; + + SSLSessionCacheOpenSSL::Config config = kDefaultConfig; + config.max_entries = kMaxItems; + ResetConfig(config); + + // Add kMaxItems to the cache. + for (size_t n = 0; n < kMaxItems; ++n) { + std::string local_id_string = base::StringPrintf("%d", local_id++); + ScopedSSL ssl(NewSSL(local_id_string)); + + AddToCache(ssl.get()); + EXPECT_EQ(n + 1, cache_.size()); + } + + // Continue adding new items to the cache, check that old ones are + // evicted. + for (size_t n = 0; n < kMaxItems; ++n) { + std::string local_id_string = base::StringPrintf("%d", local_id++); + ScopedSSL ssl(NewSSL(local_id_string)); + + AddToCache(ssl.get()); + EXPECT_EQ(kMaxItems, cache_.size()); + } +} + +// Check that session expiration works properly. +TEST_F(SSLSessionCacheOpenSSLTest, CheckExpiration) { + const size_t kMaxCheckCount = 10; + const size_t kNumEntries = 20; + + SSLSessionCacheOpenSSL::Config config = kDefaultConfig; + config.expiration_check_count = kMaxCheckCount; + config.timeout_seconds = 1000; + ResetConfig(config); + + // Add |kNumItems - 1| session entries with crafted time values. + for (size_t n = 0; n < kNumEntries - 1U; ++n) { + std::string key = base::StringPrintf("%d", static_cast<int>(n)); + ScopedSSL ssl(NewSSL(key)); + // Cheat a little: Force the session |time| value, this guarantees that they + // are expired, given that ::time() will always return a value that is + // past the first 100 seconds after the Unix epoch. + ssl.get()->session->time = static_cast<long>(n); + AddToCache(ssl.get()); + } + EXPECT_EQ(kNumEntries - 1U, cache_.size()); + + // Add nother session which will get the current time, and thus not be + // expirable until 1000 seconds have passed. + ScopedSSL good_ssl(NewSSL("good-key")); + AddToCache(good_ssl.get()); + good_ssl.reset(NULL); + EXPECT_EQ(kNumEntries, cache_.size()); + + // Call SetSSLSession() |kMaxCheckCount - 1| times, this shall not expire + // any session + for (size_t n = 0; n < kMaxCheckCount - 1U; ++n) { + ScopedSSL ssl(NewSSL("unknown-key")); + cache_.SetSSLSession(ssl.get()); + EXPECT_EQ(kNumEntries, cache_.size()); + } + + // Call SetSSLSession another time, this shall expire all sessions except + // the last one. + ScopedSSL bad_ssl(NewSSL("unknown-key")); + cache_.SetSSLSession(bad_ssl.get()); + bad_ssl.reset(NULL); + EXPECT_EQ(1U, cache_.size()); +} + +} // namespace net diff --git a/chromium/net/socket/stream_listen_socket.cc b/chromium/net/socket/stream_listen_socket.cc index 1109e7527c3..960991b7c6d 100644 --- a/chromium/net/socket/stream_listen_socket.cc +++ b/chromium/net/socket/stream_listen_socket.cc @@ -65,13 +65,13 @@ StreamListenSocket::StreamListenSocket(SocketDescriptor s, } StreamListenSocket::~StreamListenSocket() { + CloseSocket(); #if defined(OS_WIN) if (socket_event_) { WSACloseEvent(socket_event_); socket_event_ = WSA_INVALID_EVENT; } #endif - CloseSocket(socket_); } void StreamListenSocket::Send(const char* bytes, int len, @@ -194,13 +194,13 @@ void StreamListenSocket::Close() { socket_delegate_->DidClose(this); } -void StreamListenSocket::CloseSocket(SocketDescriptor s) { - if (s && s != kInvalidSocket) { +void StreamListenSocket::CloseSocket() { + if (socket_ != kInvalidSocket) { UnwatchSocket(); #if defined(OS_WIN) - closesocket(s); + closesocket(socket_); #elif defined(OS_POSIX) - close(s); + close(socket_); #endif } } diff --git a/chromium/net/socket/stream_listen_socket.h b/chromium/net/socket/stream_listen_socket.h index 9825a4ef126..3c9b984ed76 100644 --- a/chromium/net/socket/stream_listen_socket.h +++ b/chromium/net/socket/stream_listen_socket.h @@ -91,7 +91,7 @@ class NET_EXPORT StreamListenSocket void Listen(); void Read(); void Close(); - void CloseSocket(SocketDescriptor s); + void CloseSocket(); // Pass any value in case of Windows, because in Windows // we are not using state. diff --git a/chromium/net/socket/tcp_listen_socket_unittest.cc b/chromium/net/socket/tcp_listen_socket_unittest.cc index b122c6143d8..41c41f81fe7 100644 --- a/chromium/net/socket/tcp_listen_socket_unittest.cc +++ b/chromium/net/socket/tcp_listen_socket_unittest.cc @@ -18,11 +18,9 @@ namespace net { -static const int kReadBufSize = 1024; -static const char kHelloWorld[] = "HELLO, WORLD"; -static const int kMaxQueueSize = 20; -static const char kLoopback[] = "127.0.0.1"; -static const int kDefaultTimeoutMs = 5000; +const int kReadBufSize = 1024; +const char kHelloWorld[] = "HELLO, WORLD"; +const char kLoopback[] = "127.0.0.1"; TCPListenSocketTester::TCPListenSocketTester() : loop_(NULL), @@ -75,7 +73,7 @@ void TCPListenSocketTester::TearDown() { #if defined(OS_WIN) ASSERT_EQ(0, closesocket(test_socket_)); #elif defined(OS_POSIX) - ASSERT_EQ(0, HANDLE_EINTR(close(test_socket_))); + ASSERT_EQ(0, IGNORE_EINTR(close(test_socket_))); #endif NextAction(); ASSERT_EQ(ACTION_CLOSE, last_action_.type()); diff --git a/chromium/net/socket/tcp_socket_libevent.cc b/chromium/net/socket/tcp_socket_libevent.cc index 66416f70207..f4e4fe861af 100644 --- a/chromium/net/socket/tcp_socket_libevent.cc +++ b/chromium/net/socket/tcp_socket_libevent.cc @@ -69,6 +69,20 @@ bool SetTCPKeepAlive(int fd, bool enable, int delay) { return true; } +int MapAcceptError(int os_error) { + switch (os_error) { + // If the client aborts the connection before the server calls accept, + // POSIX specifies accept should fail with ECONNABORTED. The server can + // ignore the error and just call accept again, so we map the error to + // ERR_IO_PENDING. See UNIX Network Programming, Vol. 1, 3rd Ed., Sec. + // 5.11, "Connection Abort before accept Returns". + case ECONNABORTED: + return ERR_IO_PENDING; + default: + return MapSystemError(os_error); + } +} + int MapConnectError(int os_error) { switch (os_error) { case EACCES: @@ -507,7 +521,7 @@ void TCPSocketLibevent::Close() { DCHECK(ok); if (socket_ != kInvalidSocket) { - if (HANDLE_EINTR(close(socket_)) < 0) + if (IGNORE_EINTR(close(socket_)) < 0) PLOG(ERROR) << "close"; socket_ = kInvalidSocket; } @@ -567,7 +581,7 @@ int TCPSocketLibevent::AcceptInternal(scoped_ptr<TCPSocketLibevent>* socket, storage.addr, &storage.addr_len)); if (new_socket < 0) { - int net_error = MapSystemError(errno); + int net_error = MapAcceptError(errno); if (net_error != ERR_IO_PENDING) net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, net_error); return net_error; @@ -576,7 +590,7 @@ int TCPSocketLibevent::AcceptInternal(scoped_ptr<TCPSocketLibevent>* socket, IPEndPoint ip_end_point; if (!ip_end_point.FromSockAddr(storage.addr, storage.addr_len)) { NOTREACHED(); - if (HANDLE_EINTR(close(new_socket)) < 0) + if (IGNORE_EINTR(close(new_socket)) < 0) PLOG(ERROR) << "close"; net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, ERR_ADDRESS_INVALID); diff --git a/chromium/net/socket/tcp_socket_win.cc b/chromium/net/socket/tcp_socket_win.cc index 7d76232f962..5f8cd5c3e86 100644 --- a/chromium/net/socket/tcp_socket_win.cc +++ b/chromium/net/socket/tcp_socket_win.cc @@ -636,11 +636,6 @@ void TCPSocketWin::Close() { socket_ = INVALID_SOCKET; } - if (accept_event_) { - WSACloseEvent(accept_event_); - accept_event_ = WSA_INVALID_EVENT; - } - if (!accept_callback_.is_null()) { accept_watcher_.StopWatching(); accept_socket_ = NULL; @@ -648,6 +643,11 @@ void TCPSocketWin::Close() { accept_callback_.Reset(); } + if (accept_event_) { + WSACloseEvent(accept_event_); + accept_event_ = WSA_INVALID_EVENT; + } + if (core_) { if (waiting_connect_) { // We closed the socket, so this notification will never come. @@ -742,6 +742,14 @@ void TCPSocketWin::OnObjectSignaled(HANDLE object) { accept_address_ = NULL; base::ResetAndReturn(&accept_callback_).Run(result); } + } else { + // This happens when a client opens a connection and closes it before we + // have a chance to accept it. + DCHECK(ev.lNetworkEvents == 0); + + // Start watching the next FD_ACCEPT event. + WSAEventSelect(socket_, accept_event_, FD_ACCEPT); + accept_watcher_.StartWatching(accept_event_, this); } } diff --git a/chromium/net/socket/transport_client_socket_pool.cc b/chromium/net/socket/transport_client_socket_pool.cc index d03e3e651ac..ec56dcba65c 100644 --- a/chromium/net/socket/transport_client_socket_pool.cc +++ b/chromium/net/socket/transport_client_socket_pool.cc @@ -7,10 +7,12 @@ #include <algorithm> #include "base/compiler_specific.h" +#include "base/lazy_instance.h" #include "base/logging.h" #include "base/message_loop/message_loop.h" #include "base/metrics/histogram.h" #include "base/strings/string_util.h" +#include "base/synchronization/lock.h" #include "base/time/time.h" #include "base/values.h" #include "net/base/ip_endpoint.h" @@ -46,6 +48,14 @@ bool AddressListOnlyContainsIPv6(const AddressList& list) { } // namespace +// This lock protects |g_last_connect_time|. +static base::LazyInstance<base::Lock>::Leaky + g_last_connect_time_lock = LAZY_INSTANCE_INITIALIZER; + +// |g_last_connect_time| has the last time a connect() call is made. +static base::LazyInstance<base::TimeTicks>::Leaky + g_last_connect_time = LAZY_INSTANCE_INITIALIZER; + TransportSocketParams::TransportSocketParams( const HostPortPair& host_port_pair, bool disable_resolver_cache, @@ -85,7 +95,8 @@ TransportConnectJob::TransportConnectJob( params_(params), client_socket_factory_(client_socket_factory), resolver_(host_resolver), - next_state_(STATE_NONE) { + next_state_(STATE_NONE), + interval_between_connects_(CONNECT_INTERVAL_GT_20MS) { } TransportConnectJob::~TransportConnectJob() { @@ -186,6 +197,25 @@ int TransportConnectJob::DoResolveHostComplete(int result) { } int TransportConnectJob::DoTransportConnect() { + base::TimeTicks now = base::TimeTicks::Now(); + base::TimeTicks last_connect_time; + { + base::AutoLock lock(g_last_connect_time_lock.Get()); + last_connect_time = g_last_connect_time.Get(); + *g_last_connect_time.Pointer() = now; + } + if (last_connect_time.is_null()) { + interval_between_connects_ = CONNECT_INTERVAL_GT_20MS; + } else { + int64 interval = (now - last_connect_time).InMilliseconds(); + if (interval <= 10) + interval_between_connects_ = CONNECT_INTERVAL_LE_10MS; + else if (interval <= 20) + interval_between_connects_ = CONNECT_INTERVAL_LE_20MS; + else + interval_between_connects_ = CONNECT_INTERVAL_GT_20MS; + } + next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE; transport_socket_ = client_socket_factory_->CreateTransportClientSocket( addresses_, net_log().net_log(), net_log().source()); @@ -222,6 +252,36 @@ int TransportConnectJob::DoTransportConnectComplete(int result) { base::TimeDelta::FromMinutes(10), 100); + switch (interval_between_connects_) { + case CONNECT_INTERVAL_LE_10MS: + UMA_HISTOGRAM_CUSTOM_TIMES( + "Net.TCP_Connection_Latency_Interval_LessThanOrEqual_10ms", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + break; + case CONNECT_INTERVAL_LE_20MS: + UMA_HISTOGRAM_CUSTOM_TIMES( + "Net.TCP_Connection_Latency_Interval_LessThanOrEqual_20ms", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + break; + case CONNECT_INTERVAL_GT_20MS: + UMA_HISTOGRAM_CUSTOM_TIMES( + "Net.TCP_Connection_Latency_Interval_GreaterThan_20ms", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + break; + default: + NOTREACHED(); + break; + } + if (is_ipv4) { UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv4_No_Race", connect_duration, diff --git a/chromium/net/socket/transport_client_socket_pool.h b/chromium/net/socket/transport_client_socket_pool.h index 16e421a4550..1c22bf29ec3 100644 --- a/chromium/net/socket/transport_client_socket_pool.h +++ b/chromium/net/socket/transport_client_socket_pool.h @@ -93,6 +93,12 @@ class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob { STATE_NONE, }; + enum ConnectInterval { + CONNECT_INTERVAL_LE_10MS, + CONNECT_INTERVAL_LE_20MS, + CONNECT_INTERVAL_GT_20MS, + }; + void OnIOComplete(int result); // Runs the state transition loop. @@ -125,6 +131,9 @@ class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob { base::TimeTicks fallback_connect_start_time_; base::OneShotTimer<TransportConnectJob> fallback_timer_; + // Track the interval between this connect and previous connect. + ConnectInterval interval_between_connects_; + DISALLOW_COPY_AND_ASSIGN(TransportConnectJob); }; diff --git a/chromium/net/socket/transport_client_socket_pool_unittest.cc b/chromium/net/socket/transport_client_socket_pool_unittest.cc index a984ea3b740..ff85847979b 100644 --- a/chromium/net/socket/transport_client_socket_pool_unittest.cc +++ b/chromium/net/socket/transport_client_socket_pool_unittest.cc @@ -152,6 +152,8 @@ class MockClientSocket : public StreamSocket { bool connected_; const AddressList addrlist_; BoundNetLog net_log_; + + DISALLOW_COPY_AND_ASSIGN(MockClientSocket); }; class MockFailingClientSocket : public StreamSocket { @@ -214,6 +216,8 @@ class MockFailingClientSocket : public StreamSocket { private: const AddressList addrlist_; BoundNetLog net_log_; + + DISALLOW_COPY_AND_ASSIGN(MockFailingClientSocket); }; class MockPendingClientSocket : public StreamSocket { @@ -228,13 +232,13 @@ class MockPendingClientSocket : public StreamSocket { bool should_stall, base::TimeDelta delay, net::NetLog* net_log) - : weak_factory_(this), - should_connect_(should_connect), + : should_connect_(should_connect), should_stall_(should_stall), delay_(delay), is_connected_(false), addrlist_(addrlist), - net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) { + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), + weak_factory_(this) { } // StreamSocket implementation. @@ -312,13 +316,16 @@ class MockPendingClientSocket : public StreamSocket { } } - base::WeakPtrFactory<MockPendingClientSocket> weak_factory_; bool should_connect_; bool should_stall_; base::TimeDelta delay_; bool is_connected_; const AddressList addrlist_; BoundNetLog net_log_; + + base::WeakPtrFactory<MockPendingClientSocket> weak_factory_; + + DISALLOW_COPY_AND_ASSIGN(MockPendingClientSocket); }; class MockClientSocketFactory : public ClientSocketFactory { @@ -430,6 +437,8 @@ class MockClientSocketFactory : public ClientSocketFactory { int client_socket_index_; int client_socket_index_max_; base::TimeDelta delay_; + + DISALLOW_COPY_AND_ASSIGN(MockClientSocketFactory); }; class TransportClientSocketPoolTest : public testing::Test { @@ -488,6 +497,8 @@ class TransportClientSocketPoolTest : public testing::Test { MockClientSocketFactory client_socket_factory_; TransportClientSocketPool pool_; ClientSocketPoolTest test_base_; + + DISALLOW_COPY_AND_ASSIGN(TransportClientSocketPoolTest); }; TEST(TransportConnectJobTest, MakeAddrListStartWithIPv4) { @@ -579,7 +590,7 @@ TEST_F(TransportClientSocketPoolTest, Basic) { // Make sure that TransportConnectJob passes on its priority to its // HostResolver request on Init. TEST_F(TransportClientSocketPoolTest, SetResolvePriorityOnInit) { - for (int i = MINIMUM_PRIORITY; i < NUM_PRIORITIES; ++i) { + for (int i = MINIMUM_PRIORITY; i <= MAXIMUM_PRIORITY; ++i) { RequestPriority priority = static_cast<RequestPriority>(i); TestCompletionCallback callback; ClientSocketHandle handle; diff --git a/chromium/net/socket/unix_domain_socket_posix.cc b/chromium/net/socket/unix_domain_socket_posix.cc index 2b781d58b35..3141f7166b2 100644 --- a/chromium/net/socket/unix_domain_socket_posix.cc +++ b/chromium/net/socket/unix_domain_socket_posix.cc @@ -130,7 +130,7 @@ SocketDescriptor UnixDomainSocket::CreateAndBind(const std::string& path, LOG(ERROR) << "Could not bind unix domain socket to " << path; if (use_abstract_namespace) LOG(ERROR) << " (with abstract namespace enabled)"; - if (HANDLE_EINTR(close(s)) < 0) + if (IGNORE_EINTR(close(s)) < 0) LOG(ERROR) << "close() error"; return kInvalidSocket; } @@ -145,7 +145,7 @@ void UnixDomainSocket::Accept() { gid_t group_id; if (!GetPeerIds(conn, &user_id, &group_id) || !auth_callback_.Run(user_id, group_id)) { - if (HANDLE_EINTR(close(conn)) < 0) + if (IGNORE_EINTR(close(conn)) < 0) LOG(ERROR) << "close() error"; return; } diff --git a/chromium/net/socket/unix_domain_socket_posix_unittest.cc b/chromium/net/socket/unix_domain_socket_posix_unittest.cc index f062d274205..b1857e62e0e 100644 --- a/chromium/net/socket/unix_domain_socket_posix_unittest.cc +++ b/chromium/net/socket/unix_domain_socket_posix_unittest.cc @@ -40,7 +40,6 @@ namespace net { namespace { const char kSocketFilename[] = "unix_domain_socket_for_testing"; -const char kFallbackSocketName[] = "unix_domain_socket_for_testing_2"; const char kInvalidSocketPath[] = "/invalid/path"; const char kMsg[] = "hello"; @@ -55,7 +54,7 @@ enum EventType { string MakeSocketPath(const string& socket_file_name) { base::FilePath temp_dir; - file_util::GetTempDir(&temp_dir); + base::GetTempDir(&temp_dir); return temp_dir.Append(socket_file_name).value(); } @@ -275,6 +274,7 @@ TEST_F(UnixDomainSocketTest, TestFallbackName) { file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback()); EXPECT_TRUE(socket_.get() == NULL); // Now with a fallback name. + const char kFallbackSocketName[] = "unix_domain_socket_for_testing_2"; socket_ = UnixDomainSocket::CreateAndListenWithAbstractNamespace( file_path_.value(), MakeSocketPath(kFallbackSocketName), @@ -306,7 +306,7 @@ TEST_F(UnixDomainSocketTest, TestWithClient) { ASSERT_EQ(kMsg, socket_delegate_->ReceivedData()); // Close the client socket. - ret = HANDLE_EINTR(close(sock)); + ret = IGNORE_EINTR(close(sock)); event = event_manager_->WaitForEvent(); ASSERT_EQ(EVENT_CLOSE, event); } |