author     Ben Caimano <ben.caimano@10gen.com>    2019-05-23 13:18:38 -0400
committer  Ben Caimano <ben.caimano@10gen.com>    2019-05-29 16:23:44 -0400
commit     60008cd952996d5558b9fb4c5f66d1e2d1af0d4d (patch)
tree       12bdd4ab9b3a2b31ea17dd619640a0ff668729c2 /src/mongo/executor
parent     ffd64883d70c9139d7b56d076e249f3fef77e54e (diff)
SERVER-41318 Return SemiFutures for ConnectionPool
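
With this change, ConnectionPool::get() returns SemiFuture<ConnectionHandle> rather than Future<ConnectionHandle>: continuations can no longer run inline on the pool's own threads, and every caller has to choose an executor explicitly via .thenRunOn() (see get_forTest() and NetworkInterfaceTL::startCommand() below). A minimal consumer-side sketch (pool, executor, and the timeout value are illustrative stand-ins, not names from this patch):

    // Acquire a connection; the continuation runs on `executor`, never
    // inline inside the pool.
    pool->get(hostAndPort, transport::kGlobalSSLMode, Milliseconds{5000})
        .thenRunOn(executor)
        .getAsync([](StatusWith<ConnectionPool::ConnectionHandle> swConn) {
            if (!swConn.isOK()) {
                return;  // acquisition failed or timed out
            }
            auto conn = std::move(swConn.getValue());
            // ... use the connection, then mark it before the handle's
            // deleter hands it back to the pool
            conn->indicateSuccess();
        });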
Diffstat (limited to 'src/mongo/executor')

-rw-r--r--  src/mongo/executor/connection_pool.cpp               | 299
-rw-r--r--  src/mongo/executor/connection_pool.h                 |  16
-rw-r--r--  src/mongo/executor/connection_pool_test.cpp          |  55
-rw-r--r--  src/mongo/executor/connection_pool_test_fixture.cpp  |  87
-rw-r--r--  src/mongo/executor/connection_pool_test_fixture.h    |   9
-rw-r--r--  src/mongo/executor/connection_pool_tl.cpp             |  23
-rw-r--r--  src/mongo/executor/network_interface.h               |   3
-rw-r--r--  src/mongo/executor/network_interface_tl.cpp           | 172
-rw-r--r--  src/mongo/executor/network_interface_tl.h             |  30

9 files changed, 337 insertions(+), 357 deletions(-)
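
The recurring refactor in connection_pool.cpp below removes the stdx::unique_lock<stdx::mutex> parameters that SpecificPool methods used to sink. Locking now lives in guardCallback, which acquires the parent mutex itself, runs the wrapped callback under it, and ends with a single updateState() call (the rename of updateStateInLock()), as the patch's own version shows:

    template <typename Callback>
    auto guardCallback(Callback&& cb) {
        return [this, cb = std::forward<Callback>(cb),
                anchor = shared_from_this()](auto&&... args) {
            stdx::lock_guard lk(_parent->_mutex);
            cb(std::forward<decltype(args)>(args)...);
            updateState();
        };
    }

Code paths that call SpecificPool methods directly rather than through a wrapped callback (ConnectionPool::get(), shutdown(), dropConnections()) follow the same discipline by hand: take the lock, call the method, then call updateState() on the pool before the lock is released.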
diff --git a/src/mongo/executor/connection_pool.cpp b/src/mongo/executor/connection_pool.cpp index ac43bbaed6f..d8a06ad48f3 100644 --- a/src/mongo/executor/connection_pool.cpp +++ b/src/mongo/executor/connection_pool.cpp @@ -93,23 +93,24 @@ class ConnectionPool::SpecificPool final : public std::enable_shared_from_this<ConnectionPool::SpecificPool> { public: /** - * Whenever a function enters a specific pool, - * the function needs to be guarded by the pool lock. + * Whenever a function enters a specific pool, the function needs to be guarded by the lock. * * This callback also (perhaps overly aggressively) binds a shared pointer to the guard. * It is *always* safe to reference the original specific pool in the guarded function object. * * For a function object of signature: - * R riskyBusiness(stdx::unique_lock<stdx::mutex>, ArgTypes...); + * void riskyBusiness(ArgTypes...); * * It returns a function object of signature: - * R safeCallback(ArgTypes...); + * void safeCallback(ArgTypes...); */ template <typename Callback> auto guardCallback(Callback&& cb) { - return [ cb = std::forward<Callback>(cb), anchor = shared_from_this() ](auto&&... args) { - return cb(stdx::unique_lock(anchor->_parent->_mutex), - std::forward<decltype(args)>(args)...); + return + [ this, cb = std::forward<Callback>(cb), anchor = shared_from_this() ](auto&&... args) { + stdx::lock_guard lk(_parent->_mutex); + cb(std::forward<decltype(args)>(args)...); + updateState(); }; } @@ -122,7 +123,7 @@ public: * Gets a connection from the specific pool. Sinks a unique_lock from the * parent to preserve the lock on _mutex */ - Future<ConnectionHandle> getConnection(Milliseconds timeout, stdx::unique_lock<stdx::mutex> lk); + Future<ConnectionHandle> getConnection(Milliseconds timeout); /** * Triggers the shutdown procedure. This function marks the state as kInShutdown @@ -130,61 +131,59 @@ public: * delist or destruct this pool. However, both will happen eventually as ConnectionHandles * are deleted. */ - void triggerShutdown(const Status& status, stdx::unique_lock<stdx::mutex> lk); + void triggerShutdown(const Status& status); /** * Cascades a failure across existing connections and requests. Invoking * this function drops all current connections and fails all current * requests with the passed status. */ - void processFailure(const Status& status, stdx::unique_lock<stdx::mutex> lk); + void processFailure(const Status& status); /** * Returns a connection to a specific pool. Sinks a unique_lock from the * parent to preserve the lock on _mutex */ - void returnConnection(ConnectionInterface* connection, stdx::unique_lock<stdx::mutex> lk); + void returnConnection(ConnectionInterface* connection); /** * Returns the number of connections currently checked out of the pool. */ - size_t inUseConnections(const stdx::unique_lock<stdx::mutex>& lk); + size_t inUseConnections(); /** * Returns the number of available connections in the pool. */ - size_t availableConnections(const stdx::unique_lock<stdx::mutex>& lk); + size_t availableConnections(); /** * Returns the number of in progress connections in the pool. */ - size_t refreshingConnections(const stdx::unique_lock<stdx::mutex>& lk); + size_t refreshingConnections(); /** * Returns the total number of connections ever created in this pool. */ - size_t createdConnections(const stdx::unique_lock<stdx::mutex>& lk); + size_t createdConnections(); /** * Returns the total number of connections currently open that belong to * this pool. 
This is the sum of refreshingConnections, availableConnections, * and inUseConnections. */ - size_t openConnections(const stdx::unique_lock<stdx::mutex>& lk); + size_t openConnections(); /** * Return true if the tags on the specific pool match the passed in tags */ - bool matchesTags(const stdx::unique_lock<stdx::mutex>& lk, - transport::Session::TagMask tags) const { + bool matchesTags(transport::Session::TagMask tags) const { return !!(_tags & tags); } /** * Atomically manipulate the tags in the pool */ - void mutateTags(const stdx::unique_lock<stdx::mutex>& lk, - const stdx::function<transport::Session::TagMask(transport::Session::TagMask)>& + void mutateTags(const stdx::function<transport::Session::TagMask(transport::Session::TagMask)>& mutateFunc) { _tags = mutateFunc(_tags); } @@ -196,6 +195,20 @@ public: } } + void spawnConnections(); + + template <typename CallableT> + void runOnExecutor(CallableT&& cb) { + ExecutorFuture(ExecutorPtr(_parent->_factory->getExecutor())) // + .getAsync([ anchor = shared_from_this(), + cb = std::forward<CallableT>(cb) ](Status && status) mutable { + invariant(status); + cb(); + }); + } + + void updateState(); + private: using OwnedConnection = std::shared_ptr<ConnectionInterface>; using OwnershipPool = stdx::unordered_map<ConnectionInterface*, OwnedConnection>; @@ -209,19 +222,15 @@ private: ConnectionHandle makeHandle(ConnectionInterface* connection); - void finishRefresh(stdx::unique_lock<stdx::mutex> lk, - ConnectionInterface* connPtr, - Status status); + void finishRefresh(ConnectionInterface* connPtr, Status status); - void addToReady(stdx::unique_lock<stdx::mutex>& lk, OwnedConnection conn); + void addToReady(OwnedConnection conn); - void fulfillRequests(stdx::unique_lock<stdx::mutex>& lk); - - void spawnConnections(stdx::unique_lock<stdx::mutex>& lk); + void fulfillRequests(); // This internal helper is used both by get and by fulfillRequests and differs in that it // skips some bookkeeping that the other callers do on their own - ConnectionHandle tryGetConnection(const stdx::unique_lock<stdx::mutex>& lk); + ConnectionHandle tryGetConnection(); template <typename OwnershipPoolType> typename OwnershipPoolType::mapped_type takeFromPool( @@ -229,8 +238,6 @@ private: OwnedConnection takeFromProcessingPool(ConnectionInterface* connection); - void updateStateInLock(); - private: const std::shared_ptr<ConnectionPool> _parent; @@ -280,13 +287,6 @@ private: State _state; }; -constexpr Milliseconds ConnectionPool::kDefaultHostTimeout; -size_t const ConnectionPool::kDefaultMaxConns = std::numeric_limits<size_t>::max(); -size_t const ConnectionPool::kDefaultMinConns = 1; -size_t const ConnectionPool::kDefaultMaxConnecting = std::numeric_limits<size_t>::max(); -constexpr Milliseconds ConnectionPool::kDefaultRefreshRequirement; -constexpr Milliseconds ConnectionPool::kDefaultRefreshTimeout; - const Status ConnectionPool::kConnectionStateUnknown = Status(ErrorCodes::InternalError, "Connection is in an unknown state"); @@ -317,20 +317,20 @@ void ConnectionPool::shutdown() { // Grab all current pools (under the lock) auto pools = [&] { - stdx::unique_lock<stdx::mutex> lk(_mutex); + stdx::lock_guard lk(_mutex); return _pools; }(); for (const auto& pair : pools) { - stdx::unique_lock<stdx::mutex> lk(_mutex); + stdx::lock_guard lk(_mutex); pair.second->triggerShutdown( - Status(ErrorCodes::ShutdownInProgress, "Shutting down the connection pool"), - std::move(lk)); + Status(ErrorCodes::ShutdownInProgress, "Shutting down the connection pool")); + 
pair.second->updateState(); } } void ConnectionPool::dropConnections(const HostAndPort& hostAndPort) { - stdx::unique_lock<stdx::mutex> lk(_mutex); + stdx::lock_guard lk(_mutex); auto iter = _pools.find(hostAndPort); @@ -338,34 +338,35 @@ void ConnectionPool::dropConnections(const HostAndPort& hostAndPort) { return; auto pool = iter->second; - pool->processFailure(Status(ErrorCodes::PooledConnectionsDropped, "Pooled connections dropped"), - std::move(lk)); + pool->processFailure( + Status(ErrorCodes::PooledConnectionsDropped, "Pooled connections dropped")); + pool->updateState(); } void ConnectionPool::dropConnections(transport::Session::TagMask tags) { // Grab all current pools (under the lock) auto pools = [&] { - stdx::unique_lock<stdx::mutex> lk(_mutex); + stdx::lock_guard lk(_mutex); return _pools; }(); for (const auto& pair : pools) { auto& pool = pair.second; - stdx::unique_lock<stdx::mutex> lk(_mutex); - if (pool->matchesTags(lk, tags)) + stdx::lock_guard lk(_mutex); + if (pool->matchesTags(tags)) continue; pool->processFailure( - Status(ErrorCodes::PooledConnectionsDropped, "Pooled connections dropped"), - std::move(lk)); + Status(ErrorCodes::PooledConnectionsDropped, "Pooled connections dropped")); + pool->updateState(); } } void ConnectionPool::mutateTags( const HostAndPort& hostAndPort, const stdx::function<transport::Session::TagMask(transport::Session::TagMask)>& mutateFunc) { - stdx::unique_lock<stdx::mutex> lk(_mutex); + stdx::lock_guard lk(_mutex); auto iter = _pools.find(hostAndPort); @@ -373,57 +374,60 @@ void ConnectionPool::mutateTags( return; auto pool = iter->second; - pool->mutateTags(lk, mutateFunc); + pool->mutateTags(mutateFunc); } void ConnectionPool::get_forTest(const HostAndPort& hostAndPort, Milliseconds timeout, GetConnectionCallback cb) { - return get(hostAndPort, transport::kGlobalSSLMode, timeout).getAsync(std::move(cb)); + // We kick ourselves onto the executor queue to prevent us from deadlocking with our own thread + auto getConnectionFunc = [ this, hostAndPort, timeout, cb = std::move(cb) ](Status &&) mutable { + get(hostAndPort, transport::kGlobalSSLMode, timeout) + .thenRunOn(_factory->getExecutor()) + .getAsync(std::move(cb)); + }; + _factory->getExecutor()->schedule(std::move(getConnectionFunc)); } -Future<ConnectionPool::ConnectionHandle> ConnectionPool::get(const HostAndPort& hostAndPort, - transport::ConnectSSLMode sslMode, - Milliseconds timeout) { - std::shared_ptr<SpecificPool> pool; - - stdx::unique_lock<stdx::mutex> lk(_mutex); - - auto iter = _pools.find(hostAndPort); - - if (iter == _pools.end()) { +SemiFuture<ConnectionPool::ConnectionHandle> ConnectionPool::get(const HostAndPort& hostAndPort, + transport::ConnectSSLMode sslMode, + Milliseconds timeout) { + stdx::lock_guard lk(_mutex); + auto& pool = _pools[hostAndPort]; + if (!pool) { pool = std::make_shared<SpecificPool>(shared_from_this(), hostAndPort, sslMode); - _pools[hostAndPort] = pool; } else { - pool = iter->second; pool->fassertSSLModeIs(sslMode); } invariant(pool); - return pool->getConnection(timeout, std::move(lk)); + auto connFuture = pool->getConnection(timeout); + pool->updateState(); + + return std::move(connFuture).semi(); } void ConnectionPool::appendConnectionStats(ConnectionPoolStats* stats) const { - stdx::unique_lock<stdx::mutex> lk(_mutex); + stdx::lock_guard lk(_mutex); for (const auto& kv : _pools) { HostAndPort host = kv.first; auto& pool = kv.second; - ConnectionStatsPer hostStats{pool->inUseConnections(lk), - pool->availableConnections(lk), - 
pool->createdConnections(lk), - pool->refreshingConnections(lk)}; + ConnectionStatsPer hostStats{pool->inUseConnections(), + pool->availableConnections(), + pool->createdConnections(), + pool->refreshingConnections()}; stats->updateStatsForHost(_name, host, hostStats); } } size_t ConnectionPool::getNumConnectionsPerHost(const HostAndPort& hostAndPort) const { - stdx::unique_lock<stdx::mutex> lk(_mutex); + stdx::lock_guard lk(_mutex); auto iter = _pools.find(hostAndPort); if (iter != _pools.end()) { - return iter->second->openConnections(lk); + return iter->second->openConnections(); } return 0; @@ -452,35 +456,31 @@ ConnectionPool::SpecificPool::~SpecificPool() { invariant(_checkedOutPool.empty()); } -size_t ConnectionPool::SpecificPool::inUseConnections(const stdx::unique_lock<stdx::mutex>& lk) { +size_t ConnectionPool::SpecificPool::inUseConnections() { return _checkedOutPool.size(); } -size_t ConnectionPool::SpecificPool::availableConnections( - const stdx::unique_lock<stdx::mutex>& lk) { +size_t ConnectionPool::SpecificPool::availableConnections() { return _readyPool.size(); } -size_t ConnectionPool::SpecificPool::refreshingConnections( - const stdx::unique_lock<stdx::mutex>& lk) { +size_t ConnectionPool::SpecificPool::refreshingConnections() { return _processingPool.size(); } -size_t ConnectionPool::SpecificPool::createdConnections(const stdx::unique_lock<stdx::mutex>& lk) { +size_t ConnectionPool::SpecificPool::createdConnections() { return _created; } -size_t ConnectionPool::SpecificPool::openConnections(const stdx::unique_lock<stdx::mutex>& lk) { +size_t ConnectionPool::SpecificPool::openConnections() { return _checkedOutPool.size() + _readyPool.size() + _processingPool.size(); } Future<ConnectionPool::ConnectionHandle> ConnectionPool::SpecificPool::getConnection( - Milliseconds timeout, stdx::unique_lock<stdx::mutex> lk) { + Milliseconds timeout) { invariant(_state != State::kInShutdown); - auto conn = tryGetConnection(lk); - - updateStateInLock(); + auto conn = tryGetConnection(); if (conn) { return Future<ConnectionPool::ConnectionHandle>::makeReady(std::move(conn)); @@ -496,29 +496,23 @@ Future<ConnectionPool::ConnectionHandle> ConnectionPool::SpecificPool::getConnec _requests.push_back(make_pair(expiration, std::move(pf.promise))); std::push_heap(begin(_requests), end(_requests), RequestComparator{}); - updateStateInLock(); - - lk.unlock(); - _parent->_factory->getExecutor()->schedule(guardCallback([this](auto lk, auto schedStatus) { - fassert(51169, schedStatus); - - spawnConnections(lk); - })); + spawnConnections(); return std::move(pf.future); } auto ConnectionPool::SpecificPool::makeHandle(ConnectionInterface* connection) -> ConnectionHandle { - auto fun = guardCallback( - [this](auto lk, auto connection) { returnConnection(connection, std::move(lk)); }); - - auto handle = ConnectionHandle(connection, fun); - return handle; + auto deleter = [ this, anchor = shared_from_this() ](ConnectionInterface * connection) { + runOnExecutor([this, connection]() { + stdx::lock_guard lk(_parent->_mutex); + returnConnection(connection); + updateState(); + }); + }; + return ConnectionHandle(connection, std::move(deleter)); } -ConnectionPool::ConnectionHandle ConnectionPool::SpecificPool::tryGetConnection( - const stdx::unique_lock<stdx::mutex>&) { - +ConnectionPool::ConnectionHandle ConnectionPool::SpecificPool::tryGetConnection() { while (_readyPool.size()) { // _readyPool is an LRUCache, so its begin() object is the MRU item. 
auto iter = _readyPool.begin(); @@ -549,9 +543,7 @@ ConnectionPool::ConnectionHandle ConnectionPool::SpecificPool::tryGetConnection( return {}; } -void ConnectionPool::SpecificPool::finishRefresh(stdx::unique_lock<stdx::mutex> lk, - ConnectionInterface* connPtr, - Status status) { +void ConnectionPool::SpecificPool::finishRefresh(ConnectionInterface* connPtr, Status status) { auto conn = takeFromProcessingPool(connPtr); // If we're in shutdown, we don't need refreshed connections @@ -565,43 +557,36 @@ void ConnectionPool::SpecificPool::finishRefresh(stdx::unique_lock<stdx::mutex> if (status.code() == ErrorCodes::NetworkInterfaceExceededTimeLimit) { LOG(0) << "Pending connection to host " << _hostAndPort << " did not complete within the connection timeout," - << " retrying with a new connection;" << openConnections(lk) + << " retrying with a new connection;" << openConnections() << " connections to that host remain open"; - spawnConnections(lk); + spawnConnections(); return; } // Pass a failure on through if (!status.isOK()) { - processFailure(status, std::move(lk)); + processFailure(status); return; } // If the host and port were dropped, let this lapse and spawn new connections if (conn->getGeneration() != _generation) { - spawnConnections(lk); + spawnConnections(); return; } // If the connection refreshed successfully, throw it back in the ready pool - addToReady(lk, std::move(conn)); + addToReady(std::move(conn)); - lk.unlock(); - _parent->_factory->getExecutor()->schedule(guardCallback([this](auto lk, auto schedStatus) { - fassert(51170, schedStatus); - fulfillRequests(lk); - })); + fulfillRequests(); } -void ConnectionPool::SpecificPool::returnConnection(ConnectionInterface* connPtr, - stdx::unique_lock<stdx::mutex> lk) { +void ConnectionPool::SpecificPool::returnConnection(ConnectionInterface* connPtr) { auto needsRefreshTP = connPtr->getLastUsed() + _parent->_options.refreshRequirement; auto conn = takeFromPool(_checkedOutPool, connPtr); invariant(conn); - updateStateInLock(); - if (conn->getGeneration() != _generation) { // If the connection is from an older generation, just return. 
return; @@ -610,7 +595,7 @@ void ConnectionPool::SpecificPool::returnConnection(ConnectionInterface* connPtr if (!conn->getStatus().isOK()) { // TODO: alert via some callback if the host is bad log() << "Ending connection to host " << _hostAndPort << " due to bad connection status; " - << openConnections(lk) << " connections to that host remain open"; + << openConnections() << " connections to that host remain open"; return; } @@ -622,37 +607,29 @@ void ConnectionPool::SpecificPool::returnConnection(ConnectionInterface* connPtr _parent->_options.minConnections) { // If we already have minConnections, just let the connection lapse log() << "Ending idle connection to host " << _hostAndPort - << " because the pool meets constraints; " << openConnections(lk) + << " because the pool meets constraints; " << openConnections() << " connections to that host remain open"; return; } _processingPool[connPtr] = std::move(conn); - // Unlock in case refresh can occur immediately - lk.unlock(); connPtr->refresh(_parent->_options.refreshTimeout, - guardCallback([this](auto lk, auto conn, auto status) { - finishRefresh(std::move(lk), conn, status); + guardCallback([this](auto conn, auto status) { + finishRefresh(std::move(conn), std::move(status)); })); - lk.lock(); - } else { - // If it's fine as it is, just put it in the ready queue - addToReady(lk, std::move(conn)); - - lk.unlock(); - _parent->_factory->getExecutor()->schedule(guardCallback([this](auto lk, auto schedStatus) { - fassert(51171, schedStatus); - fulfillRequests(lk); - })); + + return; } - updateStateInLock(); + // If it's fine as it is, just put it in the ready queue + addToReady(std::move(conn)); + + fulfillRequests(); } // Adds a live connection to the ready pool -void ConnectionPool::SpecificPool::addToReady(stdx::unique_lock<stdx::mutex>& lk, - OwnedConnection conn) { +void ConnectionPool::SpecificPool::addToReady(OwnedConnection conn) { auto connPtr = conn.get(); // This makes the connection the new most-recently-used connection. @@ -661,8 +638,7 @@ void ConnectionPool::SpecificPool::addToReady(stdx::unique_lock<stdx::mutex>& lk // Our strategy for refreshing connections is to check them out and // immediately check them back in (which kicks off the refresh logic in // returnConnection - connPtr->setTimeout(_parent->_options.refreshRequirement, - guardCallback([this, connPtr](stdx::unique_lock<stdx::mutex> lk) { + connPtr->setTimeout(_parent->_options.refreshRequirement, guardCallback([this, connPtr]() { auto conn = takeFromPool(_readyPool, connPtr); // We've already been checked out. 
We don't need to refresh @@ -678,21 +654,19 @@ void ConnectionPool::SpecificPool::addToReady(stdx::unique_lock<stdx::mutex>& lk connPtr->indicateSuccess(); - returnConnection(connPtr, std::move(lk)); + returnConnection(connPtr); })); } // Sets state to shutdown and kicks off the failure protocol to tank existing connections -void ConnectionPool::SpecificPool::triggerShutdown(const Status& status, - stdx::unique_lock<stdx::mutex> lk) { +void ConnectionPool::SpecificPool::triggerShutdown(const Status& status) { _state = State::kInShutdown; _droppedProcessingPool.clear(); - processFailure(status, std::move(lk)); + processFailure(status); } // Drop connections and fail all requests -void ConnectionPool::SpecificPool::processFailure(const Status& status, - stdx::unique_lock<stdx::mutex> lk) { +void ConnectionPool::SpecificPool::processFailure(const Status& status) { // Bump the generation so we don't reuse any pending or checked out // connections _generation++; @@ -727,20 +701,13 @@ void ConnectionPool::SpecificPool::processFailure(const Status& status, swap(requestsToFail, _requests); } - // Update state to reflect the lack of requests - updateStateInLock(); - - // Drop the lock and process all of the requests - // with the same failed status - lk.unlock(); - for (auto& request : requestsToFail) { request.second.setError(status); } } // fulfills as many outstanding requests as possible -void ConnectionPool::SpecificPool::fulfillRequests(stdx::unique_lock<stdx::mutex>& lk) { +void ConnectionPool::SpecificPool::fulfillRequests() { // If some other thread (possibly this thread) is fulfilling requests, // don't keep padding the callstack. if (_inFulfillRequests) @@ -755,7 +722,7 @@ void ConnectionPool::SpecificPool::fulfillRequests(stdx::unique_lock<stdx::mutex // deadlock). // // None of the heap manipulation code throws, but it's something to keep in mind. - auto conn = tryGetConnection(lk); + auto conn = tryGetConnection(); if (!conn) { break; @@ -766,19 +733,15 @@ void ConnectionPool::SpecificPool::fulfillRequests(stdx::unique_lock<stdx::mutex std::pop_heap(begin(_requests), end(_requests), RequestComparator{}); _requests.pop_back(); - lk.unlock(); promise.emplaceValue(std::move(conn)); - lk.lock(); - - updateStateInLock(); } - spawnConnections(lk); + spawnConnections(); } // spawn enough connections to satisfy open requests and minpool, while // honoring maxpool -void ConnectionPool::SpecificPool::spawnConnections(stdx::unique_lock<stdx::mutex>& lk) { +void ConnectionPool::SpecificPool::spawnConnections() { // If some other thread (possibly this thread) is spawning connections, // don't keep padding the callstack. if (_inSpawnConnections) @@ -817,16 +780,13 @@ void ConnectionPool::SpecificPool::spawnConnections(stdx::unique_lock<stdx::mute ++_created; // Run the setup callback - lk.unlock(); handle->setup(_parent->_options.refreshTimeout, - guardCallback([this](auto lk, auto conn, auto status) { - finishRefresh(std::move(lk), conn, status); + guardCallback([this](auto conn, auto status) { + finishRefresh(std::move(conn), std::move(status)); })); // Note that this assumes that the refreshTimeout is sound for the // setupTimeout - - lk.lock(); } } @@ -855,7 +815,7 @@ ConnectionPool::SpecificPool::OwnedConnection ConnectionPool::SpecificPool::take // Updates our state and manages the request timer -void ConnectionPool::SpecificPool::updateStateInLock() { +void ConnectionPool::SpecificPool::updateState() { if (_state == State::kInShutdown) { // If we're in shutdown, there is nothing to update. 
Our clients are all gone. if (_processingPool.empty()) { @@ -885,7 +845,7 @@ void ConnectionPool::SpecificPool::updateStateInLock() { // We set a timer for the most recent request, then invoke each timed // out request we couldn't service _requestTimer->setTimeout( - timeout, guardCallback([this](stdx::unique_lock<stdx::mutex> lk) { + timeout, guardCallback([this]() { auto now = _parent->_factory->now(); while (_requests.size()) { @@ -896,16 +856,12 @@ void ConnectionPool::SpecificPool::updateStateInLock() { std::pop_heap(begin(_requests), end(_requests), RequestComparator{}); _requests.pop_back(); - lk.unlock(); promise.setError(Status(ErrorCodes::NetworkInterfaceExceededTimeLimit, "Couldn't get a connection within the time limit")); - lk.lock(); } else { break; } } - - updateStateInLock(); })); } else if (_checkedOutPool.size()) { // If we have no requests, but someone's using a connection, we just @@ -931,14 +887,13 @@ void ConnectionPool::SpecificPool::updateStateInLock() { // Set the shutdown timer, this gets reset on any request _requestTimer->setTimeout( - timeout, guardCallback([this](auto lk) { + timeout, guardCallback([this]() { if (_state != State::kIdle) return; triggerShutdown( Status(ErrorCodes::NetworkInterfaceExceededTimeLimit, - "Connection pool has been idle for longer than the host timeout"), - std::move(lk)); + "Connection pool has been idle for longer than the host timeout")); })); } } diff --git a/src/mongo/executor/connection_pool.h b/src/mongo/executor/connection_pool.h index 5944e7a527e..4c7da28ef27 100644 --- a/src/mongo/executor/connection_pool.h +++ b/src/mongo/executor/connection_pool.h @@ -76,9 +76,9 @@ public: using GetConnectionCallback = unique_function<void(StatusWith<ConnectionHandle>)>; static constexpr Milliseconds kDefaultHostTimeout = Milliseconds(300000); // 5mins - static const size_t kDefaultMaxConns; - static const size_t kDefaultMinConns; - static const size_t kDefaultMaxConnecting; + static constexpr size_t kDefaultMaxConns = std::numeric_limits<size_t>::max(); + static constexpr size_t kDefaultMinConns = 1; + static constexpr size_t kDefaultMaxConnecting = 2; static constexpr Milliseconds kDefaultRefreshRequirement = Milliseconds(60000); // 1min static constexpr Milliseconds kDefaultRefreshTimeout = Milliseconds(20000); // 20secs @@ -155,9 +155,9 @@ public: const stdx::function<transport::Session::TagMask(transport::Session::TagMask)>& mutateFunc) override; - Future<ConnectionHandle> get(const HostAndPort& hostAndPort, - transport::ConnectSSLMode sslMode, - Milliseconds timeout); + SemiFuture<ConnectionHandle> get(const HostAndPort& hostAndPort, + transport::ConnectSSLMode sslMode, + Milliseconds timeout); void get_forTest(const HostAndPort& hostAndPort, Milliseconds timeout, GetConnectionCallback cb); @@ -294,8 +294,8 @@ protected: * Making these protected makes the definitions available to override in * children. */ - using SetupCallback = stdx::function<void(ConnectionInterface*, Status)>; - using RefreshCallback = stdx::function<void(ConnectionInterface*, Status)>; + using SetupCallback = unique_function<void(ConnectionInterface*, Status)>; + using RefreshCallback = unique_function<void(ConnectionInterface*, Status)>; /** * Sets up the connection. 
This should include connection + auth + any diff --git a/src/mongo/executor/connection_pool_test.cpp b/src/mongo/executor/connection_pool_test.cpp index 0bb27501fcf..c2220cd95f6 100644 --- a/src/mongo/executor/connection_pool_test.cpp +++ b/src/mongo/executor/connection_pool_test.cpp @@ -76,10 +76,13 @@ void doneWith(const ConnectionPool::ConnectionHandle& conn) { using StatusWithConn = StatusWith<ConnectionPool::ConnectionHandle>; +auto getId(const ConnectionPool::ConnectionHandle& conn) { + return dynamic_cast<ConnectionImpl*>(conn.get())->id(); +} auto verifyAndGetId(StatusWithConn& swConn) { ASSERT(swConn.isOK()); auto& conn = swConn.getValue(); - return dynamic_cast<ConnectionImpl*>(conn.get())->id(); + return getId(conn); } /** @@ -1271,7 +1274,7 @@ TEST_F(ConnectionPoolTest, SetupTimeoutsDontTimeoutUnrelatedRequests) { ASSERT(conn1); ASSERT(!conn1->isOK()); - ASSERT(conn1->getStatus().code() == ErrorCodes::NetworkInterfaceExceededTimeLimit); + ASSERT_EQ(conn1->getStatus(), ErrorCodes::NetworkInterfaceExceededTimeLimit); } /** @@ -1324,7 +1327,7 @@ TEST_F(ConnectionPoolTest, RefreshTimeoutsDontTimeoutRequests) { ASSERT(conn1); ASSERT(!conn1->isOK()); - ASSERT(conn1->getStatus().code() == ErrorCodes::NetworkInterfaceExceededTimeLimit); + ASSERT_EQ(conn1->getStatus(), ErrorCodes::NetworkInterfaceExceededTimeLimit); } template <typename Ptr> @@ -1438,10 +1441,9 @@ TEST_F(ConnectionPoolTest, AsyncGet) { // Future should be ready now ASSERT_TRUE(connFuture.isReady()); - std::move(connFuture).getAsync([&](StatusWithConn swConn) mutable { - connId = verifyAndGetId(swConn); - doneWith(swConn.getValue()); - }); + auto conn = std::move(connFuture).get(); + connId = getId(conn); + doneWith(conn); ASSERT(connId); } @@ -1457,27 +1459,21 @@ TEST_F(ConnectionPoolTest, AsyncGet) { auto connFuture1 = pool->get(HostAndPort(), transport::kGlobalSSLMode, Seconds{1}); auto connFuture2 = pool->get(HostAndPort(), transport::kGlobalSSLMode, Seconds{10}); - // Queue up the second future to resolve as soon as it is ready - std::move(connFuture2).getAsync([&](StatusWithConn swConn) mutable { - connId2 = verifyAndGetId(swConn); - doneWith(swConn.getValue()); - }); - // The first future should be immediately ready. The second should be in the queue. ASSERT_TRUE(connFuture1.isReady()); ASSERT_FALSE(connFuture2.isReady()); // Resolve the first future to return the connection and continue on to the second. decltype(connFuture1) connFuture3; - std::move(connFuture1).getAsync([&](StatusWithConn swConn) mutable { - // Grab our third future while our first one is being fulfilled - connFuture3 = pool->get(HostAndPort(), transport::kGlobalSSLMode, Seconds{1}); + auto conn1 = std::move(connFuture1).get(); - connId1 = verifyAndGetId(swConn); - doneWith(swConn.getValue()); - }); + // Grab our third future while our first one is being fulfilled + connFuture3 = pool->get(HostAndPort(), transport::kGlobalSSLMode, Seconds{1}); + + connId1 = getId(conn1); + doneWith(conn1); + conn1.reset(); ASSERT(connId1); - ASSERT_FALSE(connId2); // Since the third future has a smaller timeout than the second, // it should take priority over the second @@ -1485,13 +1481,20 @@ TEST_F(ConnectionPoolTest, AsyncGet) { ASSERT_FALSE(connFuture2.isReady()); // Resolve the third future. 
This should trigger the second future - std::move(connFuture3).getAsync([&](StatusWithConn swConn) mutable { - // We run before the second future - ASSERT_FALSE(connId2); + auto conn3 = std::move(connFuture3).get(); - connId3 = verifyAndGetId(swConn); - doneWith(swConn.getValue()); - }); + // We've run before the second future + ASSERT_FALSE(connFuture2.isReady()); + + connId3 = getId(conn3); + doneWith(conn3); + conn3.reset(); + + // The second future is now finally ready + ASSERT_TRUE(connFuture2.isReady()); + auto conn2 = std::move(connFuture2).get(); + connId2 = getId(conn2); + doneWith(conn2); ASSERT_EQ(connId1, connId2); ASSERT_EQ(connId2, connId3); diff --git a/src/mongo/executor/connection_pool_test_fixture.cpp b/src/mongo/executor/connection_pool_test_fixture.cpp index 6c0321fa619..1f6217c6ef3 100644 --- a/src/mongo/executor/connection_pool_test_fixture.cpp +++ b/src/mongo/executor/connection_pool_test_fixture.cpp @@ -44,8 +44,6 @@ TimerImpl::~TimerImpl() { } void TimerImpl::setTimeout(Milliseconds timeout, TimeoutCallback cb) { - _timers.erase(this); - _cb = std::move(cb); _expiration = _global->now() + timeout; @@ -53,9 +51,10 @@ void TimerImpl::setTimeout(Milliseconds timeout, TimeoutCallback cb) { } void TimerImpl::cancelTimeout() { - _timers.erase(this); TimeoutCallback cb; _cb.swap(cb); + + _timers.erase(this); } void TimerImpl::clear() { @@ -72,7 +71,12 @@ void TimerImpl::fireIfNecessary() { for (auto&& x : timers) { if (_timers.count(x) && (x->_expiration <= now)) { - x->_cb(); + auto execCB = [cb = std::move(x->_cb)](auto&&) mutable { + std::move(cb)(); + }; + auto global = x->_global; + _timers.erase(x); + global->_executor->schedule(std::move(execCB)); } } } @@ -113,18 +117,24 @@ void ConnectionImpl::clear() { _pushRefreshQueue.clear(); } -void ConnectionImpl::pushSetup(PushSetupCallback status) { - _pushSetupQueue.push_back(status); - - if (_setupQueue.size()) { - auto connPtr = _setupQueue.front(); - auto callback = _pushSetupQueue.front(); - _setupQueue.pop_front(); - _pushSetupQueue.pop_front(); +void ConnectionImpl::processSetup() { + auto connPtr = _setupQueue.front(); + auto callback = std::move(_pushSetupQueue.front()); + _setupQueue.pop_front(); + _pushSetupQueue.pop_front(); - auto cb = connPtr->_setupCallback; + connPtr->_global->_executor->schedule([ connPtr, callback = std::move(callback) ](auto&&) { + auto cb = std::move(connPtr->_setupCallback); connPtr->indicateUsed(); cb(connPtr, callback()); + }); +} + +void ConnectionImpl::pushSetup(PushSetupCallback status) { + _pushSetupQueue.push_back(std::move(status)); + + if (_setupQueue.size()) { + processSetup(); } } @@ -136,19 +146,25 @@ size_t ConnectionImpl::setupQueueDepth() { return _setupQueue.size(); } -void ConnectionImpl::pushRefresh(PushRefreshCallback status) { - _pushRefreshQueue.push_back(status); - - if (_refreshQueue.size()) { - auto connPtr = _refreshQueue.front(); - auto callback = _pushRefreshQueue.front(); +void ConnectionImpl::processRefresh() { + auto connPtr = _refreshQueue.front(); + auto callback = std::move(_pushRefreshQueue.front()); - _refreshQueue.pop_front(); - _pushRefreshQueue.pop_front(); + _refreshQueue.pop_front(); + _pushRefreshQueue.pop_front(); - auto cb = connPtr->_refreshCallback; + connPtr->_global->_executor->schedule([ connPtr, callback = std::move(callback) ](auto&&) { + auto cb = std::move(connPtr->_refreshCallback); connPtr->indicateUsed(); cb(connPtr, callback()); + }); +} + +void ConnectionImpl::pushRefresh(PushRefreshCallback status) { + 
_pushRefreshQueue.push_back(std::move(status)); + + if (_refreshQueue.size()) { + processRefresh(); } } @@ -161,7 +177,7 @@ size_t ConnectionImpl::refreshQueueDepth() { } void ConnectionImpl::setTimeout(Milliseconds timeout, TimeoutCallback cb) { - _timer.setTimeout(timeout, cb); + _timer.setTimeout(timeout, std::move(cb)); } void ConnectionImpl::cancelTimeout() { @@ -172,20 +188,14 @@ void ConnectionImpl::setup(Milliseconds timeout, SetupCallback cb) { _setupCallback = std::move(cb); _timer.setTimeout(timeout, [this] { - _setupCallback(this, Status(ErrorCodes::NetworkInterfaceExceededTimeLimit, "timeout")); + auto setupCb = std::move(_setupCallback); + setupCb(this, Status(ErrorCodes::NetworkInterfaceExceededTimeLimit, "timeout")); }); _setupQueue.push_back(this); if (_pushSetupQueue.size()) { - auto connPtr = _setupQueue.front(); - auto callback = _pushSetupQueue.front(); - _setupQueue.pop_front(); - _pushSetupQueue.pop_front(); - - auto refreshCb = connPtr->_setupCallback; - connPtr->indicateUsed(); - refreshCb(connPtr, callback()); + processSetup(); } } @@ -193,21 +203,14 @@ void ConnectionImpl::refresh(Milliseconds timeout, RefreshCallback cb) { _refreshCallback = std::move(cb); _timer.setTimeout(timeout, [this] { - _refreshCallback(this, Status(ErrorCodes::NetworkInterfaceExceededTimeLimit, "timeout")); + auto refreshCb = std::move(_refreshCallback); + refreshCb(this, Status(ErrorCodes::NetworkInterfaceExceededTimeLimit, "timeout")); }); _refreshQueue.push_back(this); if (_pushRefreshQueue.size()) { - auto connPtr = _refreshQueue.front(); - auto callback = _pushRefreshQueue.front(); - - _refreshQueue.pop_front(); - _pushRefreshQueue.pop_front(); - - auto refreshCb = connPtr->_refreshCallback; - connPtr->indicateUsed(); - refreshCb(connPtr, callback()); + processRefresh(); } } diff --git a/src/mongo/executor/connection_pool_test_fixture.h b/src/mongo/executor/connection_pool_test_fixture.h index cbd6657a81a..89cd65e48ec 100644 --- a/src/mongo/executor/connection_pool_test_fixture.h +++ b/src/mongo/executor/connection_pool_test_fixture.h @@ -32,6 +32,7 @@ #include <set> #include "mongo/executor/connection_pool.h" +#include "mongo/util/functional.h" namespace mongo { namespace executor { @@ -77,8 +78,8 @@ private: */ class ConnectionImpl final : public ConnectionPool::ConnectionInterface { public: - using PushSetupCallback = stdx::function<Status()>; - using PushRefreshCallback = stdx::function<Status()>; + using PushSetupCallback = unique_function<Status()>; + using PushRefreshCallback = unique_function<Status()>; ConnectionImpl(const HostAndPort& hostAndPort, size_t generation, PoolImpl* global); @@ -115,6 +116,9 @@ private: void refresh(Milliseconds timeout, RefreshCallback cb) override; + static void processSetup(); + static void processRefresh(); + HostAndPort _hostAndPort; SetupCallback _setupCallback; RefreshCallback _refreshCallback; @@ -167,6 +171,7 @@ public: */ class PoolImpl final : public ConnectionPool::DependentTypeFactoryInterface { friend class ConnectionImpl; + friend class TimerImpl; public: PoolImpl() = default; diff --git a/src/mongo/executor/connection_pool_tl.cpp b/src/mongo/executor/connection_pool_tl.cpp index 44ca5e60f79..55ea5a161c1 100644 --- a/src/mongo/executor/connection_pool_tl.cpp +++ b/src/mongo/executor/connection_pool_tl.cpp @@ -98,16 +98,19 @@ void TLTimer::setTimeout(Milliseconds timeoutVal, TimeoutCallback cb) { return; } - _timer->waitUntil(_reactor->now() + timeoutVal).getAsync([cb = std::move(cb)](Status status) { - // If we get canceled, then 
we don't worry about the timeout anymore - if (status == ErrorCodes::CallbackCanceled) { - return; - } + // Wait until our timeoutVal then run on the reactor + _timer->waitUntil(_reactor->now() + timeoutVal) + .thenRunOn(_reactor) + .getAsync([cb = std::move(cb)](Status status) { + // If we get canceled, then we don't worry about the timeout anymore + if (status == ErrorCodes::CallbackCanceled) { + return; + } - fassert(50475, status); + fassert(50475, status); - cb(); - }); + cb(); + }); } void TLTimer::cancelTimeout() { @@ -209,7 +212,7 @@ void TLConnection::setup(Milliseconds timeout, SetupCallback cb) { auto pf = makePromiseFuture<void>(); auto handler = std::make_shared<TimeoutHandler>(std::move(pf.promise)); - std::move(pf.future).getAsync( + std::move(pf.future).thenRunOn(_reactor).getAsync( [ this, cb = std::move(cb), anchor ](Status status) { cb(this, std::move(status)); }); setTimeout(timeout, [this, handler, timeout] { @@ -281,7 +284,7 @@ void TLConnection::refresh(Milliseconds timeout, RefreshCallback cb) { auto pf = makePromiseFuture<void>(); auto handler = std::make_shared<TimeoutHandler>(std::move(pf.promise)); - std::move(pf.future).getAsync( + std::move(pf.future).thenRunOn(_reactor).getAsync( [ this, cb = std::move(cb), anchor ](Status status) { cb(this, status); }); setTimeout(timeout, [this, handler] { diff --git a/src/mongo/executor/network_interface.h b/src/mongo/executor/network_interface.h index 0cdd2175118..033dfb00d30 100644 --- a/src/mongo/executor/network_interface.h +++ b/src/mongo/executor/network_interface.h @@ -139,6 +139,9 @@ public: * Returns ErrorCodes::ShutdownInProgress if NetworkInterface::shutdown has already started * and Status::OK() otherwise. If it returns Status::OK(), then the onFinish argument will be * executed by NetworkInterface eventually; otherwise, it will not. + * + * Note that if you pass a baton to startCommand and that baton refuses work, then your onFinish + * function will not run. 
*/ virtual Status startCommand(const TaskExecutor::CallbackHandle& cbHandle, RemoteCommandRequest& request, diff --git a/src/mongo/executor/network_interface_tl.cpp b/src/mongo/executor/network_interface_tl.cpp index 1bfb9a1321e..80b91244fea 100644 --- a/src/mongo/executor/network_interface_tl.cpp +++ b/src/mongo/executor/network_interface_tl.cpp @@ -172,6 +172,39 @@ Date_t NetworkInterfaceTL::now() { return _reactor->now(); } +NetworkInterfaceTL::CommandState::CommandState(NetworkInterfaceTL* interface_, + RemoteCommandRequest request_, + const TaskExecutor::CallbackHandle& cbHandle_, + Promise<RemoteCommandResponse> promise_) + : interface(interface_), + request(std::move(request_)), + cbHandle(cbHandle_), + promise(std::move(promise_)) {} + + +auto NetworkInterfaceTL::CommandState::make(NetworkInterfaceTL* interface, + RemoteCommandRequest request, + const TaskExecutor::CallbackHandle& cbHandle, + Promise<RemoteCommandResponse> promise) { + auto state = + std::make_shared<CommandState>(interface, std::move(request), cbHandle, std::move(promise)); + + { + stdx::lock_guard lk(interface->_inProgressMutex); + interface->_inProgress.insert({cbHandle, state}); + } + + return state; +} + +NetworkInterfaceTL::CommandState::~CommandState() { + invariant(interface); + + { + stdx::lock_guard lk(interface->_inProgressMutex); + interface->_inProgress.erase(cbHandle); + } +} Status NetworkInterfaceTL::startCommand(const TaskExecutor::CallbackHandle& cbHandle, RemoteCommandRequest& request, @@ -195,99 +228,71 @@ Status NetworkInterfaceTL::startCommand(const TaskExecutor::CallbackHandle& cbHa } auto pf = makePromiseFuture<RemoteCommandResponse>(); - auto state = std::make_shared<CommandState>(request, cbHandle, std::move(pf.promise)); - { - stdx::lock_guard<stdx::mutex> lk(_inProgressMutex); - _inProgress.insert({state->cbHandle, state}); - } + auto state = CommandState::make(this, request, cbHandle, std::move(pf.promise)); state->start = now(); if (state->request.timeout != state->request.kNoTimeout) { state->deadline = state->start + state->request.timeout; } + auto executor = baton ? ExecutorPtr(baton) : ExecutorPtr(_reactor); + std::move(pf.future) + .thenRunOn(executor) + .onError([requestId = state->request.id](auto error)->StatusWith<RemoteCommandResponse> { + LOG(2) << "Failed to get connection from pool for request " << requestId << ": " + << redact(error); + + // The TransportLayer has, for historical reasons returned SocketException for + // network errors, but sharding assumes HostUnreachable on network errors. + if (error == ErrorCodes::SocketException) { + error = Status(ErrorCodes::HostUnreachable, error.reason()); + } + return error; + }) + .getAsync([ this, state, onFinish = std::move(onFinish) ]( + StatusWith<RemoteCommandResponse> response) { + auto duration = now() - state->start; + if (!response.isOK()) { + onFinish(RemoteCommandResponse(response.getStatus(), duration)); + } else { + const auto& rs = response.getValue(); + LOG(2) << "Request " << state->request.id << " finished with response: " + << redact(rs.isOK() ? 
rs.data.toString() : rs.status.toString()); + onFinish(rs); + } + }); + if (MONGO_FAIL_POINT(networkInterfaceDiscardCommandsBeforeAcquireConn)) { log() << "Discarding command due to failpoint before acquireConn"; - std::move(pf.future).getAsync([onFinish = std::move(onFinish)]( - StatusWith<RemoteCommandResponse> response) mutable { - onFinish(RemoteCommandResponse(response.getStatus(), Milliseconds{0})); - }); return Status::OK(); } - auto connFuture = _pool->get(request.target, request.sslMode, request.timeout) - .tapError([state](Status error) { - LOG(2) << "Failed to get connection from pool for request " - << state->request.id << ": " << error; - }); - - auto remainingWork = - [ this, state, future = std::move(pf.future), baton, onFinish = std::move(onFinish) ]( - StatusWith<ConnectionPool::ConnectionHandle> swConn) mutable { - makeReadyFutureWith([&] { - auto conn = uassertStatusOK(std::move(swConn)); - return _onAcquireConn(state, std::move(future), std::move(conn), baton); + _pool->get(request.target, request.sslMode, request.timeout) + .thenRunOn(executor) + .then([this, state, baton](auto conn) { + if (MONGO_FAIL_POINT(networkInterfaceDiscardCommandsAfterAcquireConn)) { + log() << "Discarding command due to failpoint after acquireConn"; + return; + } + + _onAcquireConn(state, std::move(conn), baton); }) - .onError([](Status error) -> StatusWith<RemoteCommandResponse> { - // The TransportLayer has, for historical reasons returned SocketException for - // network errors, but sharding assumes HostUnreachable on network errors. - if (error == ErrorCodes::SocketException) { - error = Status(ErrorCodes::HostUnreachable, error.reason()); - } - return error; - }) - .getAsync([ this, state, onFinish = std::move(onFinish) ]( - StatusWith<RemoteCommandResponse> response) { - auto duration = now() - state->start; - if (!response.isOK()) { - onFinish(RemoteCommandResponse(response.getStatus(), duration)); - } else { - const auto& rs = response.getValue(); - LOG(2) << "Request " << state->request.id << " finished with response: " - << redact(rs.isOK() ? rs.data.toString() : rs.status.toString()); - onFinish(rs); - } - }); - }; - - if (baton) { - // If we have a baton, we want to get back to the baton thread immediately after we get a - // connection - std::move(connFuture) - .getAsync([ baton, reactor = _reactor.get(), rw = std::move(remainingWork) ]( - StatusWith<ConnectionPool::ConnectionHandle> swConn) mutable { - baton->schedule( - [ rw = std::move(rw), swConn = std::move(swConn) ](Status status) mutable { - if (status.isOK()) { - std::move(rw)(std::move(swConn)); - } else { - std::move(rw)(std::move(status)); - } - }); - }); - } else { - // otherwise we're happy to run inline - std::move(connFuture) - .getAsync([rw = std::move(remainingWork)]( - StatusWith<ConnectionPool::ConnectionHandle> swConn) mutable { - std::move(rw)(std::move(swConn)); - }); - } + .getAsync([this, state](auto status) { + // If we couldn't get a connection or _onAcquireConn threw, then we should clean up + // here. + if (!status.isOK() && !state->done.swap(true)) { + state->promise.setError(status); + } + }); return Status::OK(); } // This is only called from within a then() callback on a future, so throwing is equivalent to // returning a ready Future with a not-OK status. 
-Future<RemoteCommandResponse> NetworkInterfaceTL::_onAcquireConn( - std::shared_ptr<CommandState> state, - Future<RemoteCommandResponse> future, - ConnectionPool::ConnectionHandle conn, - const BatonHandle& baton) { - if (MONGO_FAIL_POINT(networkInterfaceDiscardCommandsAfterAcquireConn)) { - conn->indicateSuccess(); - return future; - } +void NetworkInterfaceTL::_onAcquireConn(std::shared_ptr<CommandState> state, + ConnectionPool::ConnectionHandle conn, + const BatonHandle& baton) { if (state->done.load()) { conn->indicateSuccess(); @@ -355,7 +360,6 @@ Future<RemoteCommandResponse> NetworkInterfaceTL::_onAcquireConn( return RemoteCommandResponse(std::move(response)); }) .getAsync([this, state, baton](StatusWith<RemoteCommandResponse> swr) { - _eraseInUseConn(state->cbHandle); if (!swr.isOK()) { state->conn->indicateFailure(swr.getStatus()); } else if (!swr.getValue().isOK()) { @@ -365,8 +369,9 @@ Future<RemoteCommandResponse> NetworkInterfaceTL::_onAcquireConn( state->conn->indicateSuccess(); } - if (state->done.swap(true)) + if (state->done.swap(true)) { return; + } if (getTestCommandsEnabled()) { stdx::lock_guard<stdx::mutex> lk(_mutex); @@ -383,13 +388,6 @@ Future<RemoteCommandResponse> NetworkInterfaceTL::_onAcquireConn( state->promise.setFromStatusWith(std::move(swr)); }); - - return future; -} - -void NetworkInterfaceTL::_eraseInUseConn(const TaskExecutor::CallbackHandle& cbHandle) { - stdx::lock_guard<stdx::mutex> lk(_inProgressMutex); - _inProgress.erase(cbHandle); } void NetworkInterfaceTL::cancelCommand(const TaskExecutor::CallbackHandle& cbHandle, @@ -399,7 +397,11 @@ void NetworkInterfaceTL::cancelCommand(const TaskExecutor::CallbackHandle& cbHan if (it == _inProgress.end()) { return; } - auto state = it->second; + auto state = it->second.lock(); + if (!state) { + return; + } + _inProgress.erase(it); lk.unlock(); diff --git a/src/mongo/executor/network_interface_tl.h b/src/mongo/executor/network_interface_tl.h index 1c7ab495411..7e7b24b8c71 100644 --- a/src/mongo/executor/network_interface_tl.h +++ b/src/mongo/executor/network_interface_tl.h @@ -85,12 +85,20 @@ public: private: struct CommandState { - CommandState(RemoteCommandRequest request_, - TaskExecutor::CallbackHandle cbHandle_, - Promise<RemoteCommandResponse> promise_) - : request(std::move(request_)), - cbHandle(std::move(cbHandle_)), - promise(std::move(promise_)) {} + CommandState(NetworkInterfaceTL* interface_, + RemoteCommandRequest request_, + const TaskExecutor::CallbackHandle& cbHandle_, + Promise<RemoteCommandResponse> promise_); + ~CommandState(); + + // Create a new CommandState in a shared_ptr + // Prefer this over raw construction + static auto make(NetworkInterfaceTL* interface, + RemoteCommandRequest request, + const TaskExecutor::CallbackHandle& cbHandle, + Promise<RemoteCommandResponse> promise); + + NetworkInterfaceTL* interface; RemoteCommandRequest request; TaskExecutor::CallbackHandle cbHandle; @@ -125,11 +133,9 @@ private: void _answerAlarm(Status status, std::shared_ptr<AlarmState> state); void _run(); - void _eraseInUseConn(const TaskExecutor::CallbackHandle& handle); - Future<RemoteCommandResponse> _onAcquireConn(std::shared_ptr<CommandState> state, - Future<RemoteCommandResponse> future, - ConnectionPool::ConnectionHandle conn, - const BatonHandle& baton); + void _onAcquireConn(std::shared_ptr<CommandState> state, + ConnectionPool::ConnectionHandle conn, + const BatonHandle& baton); std::string _instanceName; ServiceContext* _svcCtx; @@ -157,7 +163,7 @@ private: stdx::thread _ioThread; 
     stdx::mutex _inProgressMutex;
-    stdx::unordered_map<TaskExecutor::CallbackHandle, std::shared_ptr<CommandState>> _inProgress;
+    stdx::unordered_map<TaskExecutor::CallbackHandle, std::weak_ptr<CommandState>> _inProgress;
     stdx::unordered_map<TaskExecutor::CallbackHandle, std::shared_ptr<AlarmState>>
         _inProgressAlarms;
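
One lifetime pattern in the NetworkInterfaceTL portion of this diff is worth calling out: _inProgress now maps callback handles to std::weak_ptr<CommandState>, registration moves into the CommandState::make() factory, and ~CommandState() erases the map entry, so the table can never hold a dangling pointer and cancelCommand() has to lock() the weak_ptr before touching the state. A self-contained sketch of the same pattern (Registry, State, and the int keys are illustrative stand-ins, not the patch's types):

    #include <memory>
    #include <mutex>
    #include <unordered_map>

    struct Registry;

    struct State {
        State(Registry* r, int key) : registry(r), key(key) {}
        ~State();  // deregisters itself; defined after Registry

        Registry* registry;
        int key;
    };

    struct Registry {
        std::mutex mutex;
        std::unordered_map<int, std::weak_ptr<State>> inProgress;

        // Prefer this factory over raw construction so every live State
        // is registered exactly once.
        std::shared_ptr<State> make(int key) {
            auto state = std::make_shared<State>(this, key);
            std::lock_guard<std::mutex> lk(mutex);
            inProgress.insert({key, state});
            return state;
        }

        // A canceller must lock() the weak_ptr: the State may already
        // have completed and destroyed itself on another thread.
        std::shared_ptr<State> find(int key) {
            std::lock_guard<std::mutex> lk(mutex);
            auto it = inProgress.find(key);
            if (it == inProgress.end()) {
                return nullptr;
            }
            return it->second.lock();
        }
    };

    State::~State() {
        std::lock_guard<std::mutex> lk(registry->mutex);
        registry->inProgress.erase(key);
    }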