author     Ben Caimano <ben.caimano@10gen.com>    2018-08-17 17:11:42 -0400
committer  Ben Caimano <ben.caimano@10gen.com>    2018-08-17 17:11:42 -0400
commit     6c33879eaa545136f8e1c747780ffd3ccb493dce
tree       65ac9eab37da6703298405e877dbc73f47f43526
parent     9fc1c55c9648c9c9a49ddbcd7914cf2d56a9bcd4
SERVER-35056 Flush ready callbacks on NetworkInterfaceTL shutdown
Squashed from 3 commits:
SERVER-35684 Remove `promise.getFuture()`
    This API invites subtle race conditions, so remove it and force
    all callers onto a unified API that creates the promise and its
    future at the same time.
(cherry picked from commit 2338f365430d7f395faf73bff6c64def505da1b3)
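A minimal sketch of the surviving pattern (condensed from the src/mongo/util/future.h
and future_bm.cpp changes in the diff below; error handling omitted):

    // Both ends are created together, so no call on a live Promise can
    // race with extracting its Future.
    auto pf = makePromiseFuture<int>();
    std::move(pf.future).getAsync([](StatusWith<int> swValue) {
        // Runs once the promise is fulfilled.
    });
    pf.promise.emplaceValue(42);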
SERVER-35056 Flush ready callbacks on NetworkInterfaceTL shutdown
(cherry picked from commit b49a27b359b17cd1b1560134b89527b78db565cc)
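The shutdown ordering this change establishes can be summarized as follows; this is a
condensed sketch of NetworkInterfaceTL::shutdown() and _run() from the diff below, not
the verbatim code:

    void NetworkInterfaceTL::shutdown() {
        if (_inShutdown.swap(true))  // shutdown runs at most once
            return;
        _reactor->stop();  // unblocks _run()'s _reactor->run()
        _ioThread.join();  // _run() then calls _pool->shutdown() and _reactor->drain(),
                           // flushing every ready callback before the thread exits
    }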
SERVER-36466 Secure shutdown conditions for SpecificPool
(cherry picked from commit 5e0545d3625dc85d16f5f021896f61d3a21e2333)
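The "secure shutdown conditions" hinge on SpecificPool deriving from
std::enable_shared_from_this, so every scheduled callback anchors the pool it touches.
A minimal standalone sketch of that lifetime-guard idea (PoolLike and makeGuarded are
illustrative names, not symbols from the diff):

    #include <functional>
    #include <memory>

    class PoolLike : public std::enable_shared_from_this<PoolLike> {
    public:
        // Wrap a callback so the captured shared_ptr ("anchor") keeps this
        // object alive until the last pending invocation has run.
        auto makeGuarded(std::function<void()> cb) {
            return [cb = std::move(cb), anchor = shared_from_this()] { cb(); };
        }
    };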
 src/mongo/executor/connection_pool.cpp                        | 435
 src/mongo/executor/connection_pool.h                          |  33
 src/mongo/executor/connection_pool_test_fixture.cpp           |  10
 src/mongo/executor/connection_pool_test_fixture.h             |   6
 src/mongo/executor/connection_pool_tl.cpp                     |  93
 src/mongo/executor/connection_pool_tl.h                       |  77
 src/mongo/executor/network_interface_integration_fixture.cpp  |   5
 src/mongo/executor/network_interface_integration_test.cpp     |  23
 src/mongo/executor/network_interface_tl.cpp                   |  61
 src/mongo/executor/network_interface_tl.h                     |  11
 src/mongo/transport/service_executor_test.cpp                 |   8
 src/mongo/transport/transport_layer.h                         |   1
 src/mongo/transport/transport_layer_asio.cpp                  | 123
 src/mongo/util/future.h                                       |  49
 src/mongo/util/future_bm.cpp                                  | 132
 src/mongo/util/future_test.cpp                                | 201
 src/mongo/util/keyed_executor.h                               |   7
17 files changed, 662 insertions(+), 613 deletions(-)
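One structural change worth noting before the diff: ConnectionHandleDeleter becomes a
plain function type, so each checked-out handle carries its own guarded return path. A
hypothetical standalone equivalent (the pool pointer and ConnectionInterface stand in
for the real types):

    using ConnectionHandleDeleter = std::function<void(ConnectionInterface*)>;
    using ConnectionHandle = std::unique_ptr<ConnectionInterface, ConnectionHandleDeleter>;

    // A handle whose deleter returns the connection to its pool on destruction:
    ConnectionHandle handle(connPtr,
                            [pool](ConnectionInterface* c) { pool->returnConnection(c); });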
diff --git a/src/mongo/executor/connection_pool.cpp b/src/mongo/executor/connection_pool.cpp
index 9bfbdd459a3..581ec8e3194 100644
--- a/src/mongo/executor/connection_pool.cpp
+++ b/src/mongo/executor/connection_pool.cpp
@@ -59,12 +59,13 @@ namespace executor {
  * go out of existence after hostTimeout passes without any of their
  * connections being used.
  */
-class ConnectionPool::SpecificPool {
+class ConnectionPool::SpecificPool final
+    : public std::enable_shared_from_this<ConnectionPool::SpecificPool> {
 public:
     /**
-     * These active client methods must be used whenever entering a specific pool outside of the
-     * shutdown background task. The presence of an active client will bump a counter on the
-     * specific pool which will prevent the shutdown thread from deleting it.
+     * Whenever a function enters a specific pool, the function needs to be guarded.
+     * The presence of one of these guards will bump a counter on the specific pool
+     * which will prevent the pool from removing itself from the map of pools.
      *
      * The complexity comes from the need to hold a lock when writing to the
     * _activeClients param on the specific pool. Because the code beneath the client needs to lock
@@ -72,32 +73,28 @@ public:
     * lock acquired, move it into the client, then re-acquire to decrement the counter on the way
     * out.
     *
-    * It's used like:
+    * This callback also (perhaps overly aggressively) binds a shared pointer to the guard.
+    * It is *always* safe to reference the original specific pool in the guarded function object.
     *
-    * pool.runWithActiveClient([](stdx::unique_lock<stdx::mutex> lk){ codeToBeProtected(); });
+    * For a function object of signature:
+    * R riskyBusiness(stdx::unique_lock<stdx::mutex>, ArgTypes...);
+    *
+    * It returns a function object of signature:
+    * R safeCallback(ArgTypes...);
     */
    template <typename Callback>
-    auto runWithActiveClient(Callback&& cb) {
-        return runWithActiveClient(stdx::unique_lock<stdx::mutex>(_parent->_mutex),
-                                   std::forward<Callback>(cb));
-    }
-
-    template <typename Callback>
-    auto runWithActiveClient(stdx::unique_lock<stdx::mutex> lk, Callback&& cb) {
-        invariant(lk.owns_lock());
-
-        _activeClients++;
-
-        const auto guard = MakeGuard([&] {
-            invariant(!lk.owns_lock());
-            stdx::lock_guard<stdx::mutex> lk(_parent->_mutex);
-            _activeClients--;
-        });
+    auto guardCallback(Callback&& cb) {
+        return [ cb = std::forward<Callback>(cb), anchor = shared_from_this() ](auto&&... args) {
+            stdx::unique_lock<stdx::mutex> lk(anchor->_parent->_mutex);
+            ++(anchor->_activeClients);
+
+            ON_BLOCK_EXIT([anchor]() {
+                stdx::unique_lock<stdx::mutex> lk(anchor->_parent->_mutex);
+                --(anchor->_activeClients);
+            });
 
-        {
-            decltype(lk) localLk(std::move(lk));
-            return cb(std::move(localLk));
-        }
+            return cb(std::move(lk), std::forward<decltype(args)>(args)...);
+        };
     }
 
     SpecificPool(ConnectionPool* parent, const HostAndPort& hostAndPort);
@@ -112,6 +109,14 @@ public:
                        stdx::unique_lock<stdx::mutex> lk);
 
     /**
+     * Triggers the shutdown procedure. This function marks the state as kInShutdown
+     * and calls processFailure below with the status provided. This may not immediately
+     * delist or destruct this pool. However, both will happen eventually as ConnectionHandles
+     * are deleted.
+     */
+    void triggerShutdown(const Status& status, stdx::unique_lock<stdx::mutex> lk);
+
+    /**
      * Cascades a failure across existing connections and requests. Invoking
      * this function drops all current connections and fails all current
     * requests with the passed status.
@@ -168,18 +173,6 @@ public:
         _tags = mutateFunc(_tags);
     }
 
-    /**
-     * See runWithActiveClient for what this controls, and be very very careful to manage the
-     * refcount correctly.
-     */
-    void incActiveClients(const stdx::unique_lock<stdx::mutex>& lk) {
-        _activeClients++;
-    }
-
-    void decActiveClients(const stdx::unique_lock<stdx::mutex>& lk) {
-        _activeClients--;
-    }
-
 private:
     using OwnedConnection = std::shared_ptr<ConnectionInterface>;
     using OwnershipPool = stdx::unordered_map<ConnectionInterface*, OwnedConnection>;
@@ -197,8 +190,6 @@ private:
 
     void spawnConnections(stdx::unique_lock<stdx::mutex>& lk);
 
-    void shutdown();
-
     template <typename OwnershipPoolType>
     typename OwnershipPoolType::mapped_type takeFromPool(
         OwnershipPoolType& pool, typename OwnershipPoolType::key_type connPtr);
@@ -219,7 +210,7 @@ private:
 
     std::vector<Request> _requests;
 
-    std::unique_ptr<TimerInterface> _requestTimer;
+    std::shared_ptr<TimerInterface> _requestTimer;
     Date_t _requestTimerExpiration;
     size_t _activeClients;
     size_t _generation;
@@ -266,7 +257,7 @@ constexpr Milliseconds ConnectionPool::kDefaultRefreshTimeout;
 const Status ConnectionPool::kConnectionStateUnknown =
     Status(ErrorCodes::InternalError, "Connection is in an unknown state");
 
-ConnectionPool::ConnectionPool(std::unique_ptr<DependentTypeFactoryInterface> impl,
+ConnectionPool::ConnectionPool(std::shared_ptr<DependentTypeFactoryInterface> impl,
                                std::string name,
                                Options options)
     : _name(std::move(name)),
@@ -289,33 +280,17 @@ ConnectionPool::~ConnectionPool() {
 }
 
 void ConnectionPool::shutdown() {
-    std::vector<SpecificPool*> pools;
-
-    // Ensure we decrement active clients for all pools that we inc on (because we intend to process
-    // failures)
-    const auto guard = MakeGuard([&] {
-        stdx::unique_lock<stdx::mutex> lk(_mutex);
-
-        for (const auto& pool : pools) {
-            pool->decActiveClients(lk);
-        }
-    });
+    _factory->shutdown();
 
     // Grab all current pools (under the lock)
-    {
+    auto pools = [&] {
         stdx::unique_lock<stdx::mutex> lk(_mutex);
+        return _pools;
+    }();
 
-        for (auto& pair : _pools) {
-            pools.push_back(pair.second.get());
-            pair.second->incActiveClients(lk);
-        }
-    }
-
-    // Reacquire the lock per pool and process failures. We'll dec active clients when we're all
-    // through in the guard
-    for (const auto& pool : pools) {
+    for (const auto& pair : pools) {
         stdx::unique_lock<stdx::mutex> lk(_mutex);
-        pool->processFailure(
+        pair.second->triggerShutdown(
             Status(ErrorCodes::ShutdownInProgress, "Shutting down the connection pool"),
             std::move(lk));
     }
@@ -329,42 +304,25 @@ void ConnectionPool::dropConnections(const HostAndPort& hostAndPort) {
     if (iter == _pools.end())
         return;
 
-    iter->second->runWithActiveClient(std::move(lk), [&](decltype(lk) lk) {
-        iter->second->processFailure(
-            Status(ErrorCodes::PooledConnectionsDropped, "Pooled connections dropped"),
-            std::move(lk));
-    });
+    auto pool = iter->second;
+    pool->processFailure(Status(ErrorCodes::PooledConnectionsDropped, "Pooled connections dropped"),
+                         std::move(lk));
 }
 
 void ConnectionPool::dropConnections(transport::Session::TagMask tags) {
-    std::vector<SpecificPool*> pools;
-
-    // Ensure we decrement active clients for all pools that we inc on (because we intend to process
-    // failures)
-    const auto guard = MakeGuard([&] {
+    // Grab all current pools (under the lock)
+    auto pools = [&] {
         stdx::unique_lock<stdx::mutex> lk(_mutex);
+        return _pools;
+    }();
 
-        for (const auto& pool : pools) {
-            pool->decActiveClients(lk);
-        }
-    });
+    for (const auto& pair : pools) {
+        auto& pool = pair.second;
 
-    // Grab all current pools that don't match tags (under the lock)
-    {
         stdx::unique_lock<stdx::mutex> lk(_mutex);
+        if (pool->matchesTags(lk, tags))
+            continue;
 
-        for (auto& pair : _pools) {
-            if (!pair.second->matchesTags(lk, tags)) {
-                pools.push_back(pair.second.get());
-                pair.second->incActiveClients(lk);
-            }
-        }
-    }
-
-    // Reacquire the lock per pool and process failures. We'll dec active clients when we're all
-    // through in the guard
-    for (const auto& pool : pools) {
-        stdx::unique_lock<stdx::mutex> lk(_mutex);
         pool->processFailure(
             Status(ErrorCodes::PooledConnectionsDropped, "Pooled connections dropped"),
             std::move(lk));
@@ -381,7 +339,8 @@ void ConnectionPool::mutateTags(
     if (iter == _pools.end())
         return;
 
-    iter->second->mutateTags(lk, mutateFunc);
+    auto pool = iter->second;
+    pool->mutateTags(lk, mutateFunc);
 }
 
 void ConnectionPool::get(const HostAndPort& hostAndPort,
@@ -392,25 +351,22 @@ void ConnectionPool::get(const HostAndPort& hostAndPort,
 
 Future<ConnectionPool::ConnectionHandle> ConnectionPool::get(const HostAndPort& hostAndPort,
                                                              Milliseconds timeout) {
-    SpecificPool* pool;
+    std::shared_ptr<SpecificPool> pool;
 
     stdx::unique_lock<stdx::mutex> lk(_mutex);
 
     auto iter = _pools.find(hostAndPort);
 
     if (iter == _pools.end()) {
-        auto handle = stdx::make_unique<SpecificPool>(this, hostAndPort);
-        pool = handle.get();
-        _pools[hostAndPort] = std::move(handle);
+        pool = stdx::make_unique<SpecificPool>(this, hostAndPort);
+        _pools[hostAndPort] = pool;
     } else {
-        pool = iter->second.get();
+        pool = iter->second;
     }
 
     invariant(pool);
 
-    return pool->runWithActiveClient(std::move(lk), [&](decltype(lk) lk) {
-        return pool->getConnection(hostAndPort, timeout, std::move(lk));
-    });
+    return pool->getConnection(hostAndPort, timeout, std::move(lk));
 }
 
 void ConnectionPool::appendConnectionStats(ConnectionPoolStats* stats) const {
@@ -447,9 +403,8 @@ void ConnectionPool::returnConnection(ConnectionInterface* conn) {
               str::stream() << "Tried to return connection but no pool found for "
                             << conn->getHostAndPort());
 
-    iter->second->runWithActiveClient(std::move(lk), [&](decltype(lk) lk) {
-        iter->second->returnConnection(conn, std::move(lk));
-    });
+    auto pool = iter->second;
+    pool->returnConnection(conn, std::move(lk));
 }
 
 ConnectionPool::SpecificPool::SpecificPool(ConnectionPool* parent, const HostAndPort& hostAndPort)
@@ -466,6 +421,9 @@ ConnectionPool::SpecificPool::SpecificPool(ConnectionPool* parent, const HostAnd
 
 ConnectionPool::SpecificPool::~SpecificPool() {
     DESTRUCTOR_GUARD(_requestTimer->cancelTimeout();)
+
+    invariant(_requests.empty());
+    invariant(_checkedOutPool.empty());
 }
 
 size_t ConnectionPool::SpecificPool::inUseConnections(const stdx::unique_lock<stdx::mutex>& lk) {
@@ -492,6 +450,8 @@ size_t ConnectionPool::SpecificPool::openConnections(const stdx::unique_lock<std
 
 Future<ConnectionPool::ConnectionHandle> ConnectionPool::SpecificPool::getConnection(
     const HostAndPort& hostAndPort, Milliseconds timeout, stdx::unique_lock<stdx::mutex> lk) {
+    invariant(_state != State::kInShutdown);
+
     if (timeout < Milliseconds(0) || timeout > _parent->_options.refreshTimeout) {
         timeout = _parent->_options.refreshTimeout;
     }
@@ -515,6 +475,7 @@ void ConnectionPool::SpecificPool::returnConnection(ConnectionInterface* connPtr
     auto needsRefreshTP = connPtr->getLastUsed() + _parent->_options.refreshRequirement;
 
     auto conn = takeFromPool(_checkedOutPool, connPtr);
+    invariant(conn);
 
     updateStateInLock();
 
@@ -551,46 +512,43 @@ void ConnectionPool::SpecificPool::returnConnection(ConnectionInterface* connPtr
         // Unlock in case refresh can occur immediately
         lk.unlock();
-        connPtr->refresh(
-            _parent->_options.refreshTimeout, [this](ConnectionInterface* connPtr, Status status) {
-                runWithActiveClient([&](stdx::unique_lock<stdx::mutex> lk) {
-                    auto conn = takeFromProcessingPool(connPtr);
-
-                    // If the host and port were dropped, let this lapse
-                    if (conn->getGeneration() != _generation) {
-                        spawnConnections(lk);
-                        return;
-                    }
-
-                    // If we're in shutdown, we don't need refreshed connections
-                    if (_state == State::kInShutdown)
-                        return;
-
-                    // If the connection refreshed successfully, throw it back in
-                    // the ready pool
-                    if (status.isOK()) {
-                        addToReady(lk, std::move(conn));
-                        spawnConnections(lk);
-                        return;
-                    }
-
-                    // If we've exceeded the time limit, start a new connect,
-                    // rather than failing all operations. We do this because the
-                    // various callers have their own time limit which is unrelated
-                    // to our internal one.
-                    if (status.code() == ErrorCodes::NetworkInterfaceExceededTimeLimit) {
-                        log() << "Pending connection to host " << _hostAndPort
-                              << " did not complete within the connection timeout,"
-                              << " retrying with a new connection;" << openConnections(lk)
-                              << " connections to that host remain open";
-                        spawnConnections(lk);
-                        return;
-                    }
-
-                    // Otherwise pass the failure on through
-                    processFailure(status, std::move(lk));
-                });
-            });
+        connPtr->refresh(_parent->_options.refreshTimeout,
+                         guardCallback([this](stdx::unique_lock<stdx::mutex> lk,
+                                              ConnectionInterface* connPtr,
+                                              Status status) {
+                             auto conn = takeFromProcessingPool(connPtr);
+
+                             // If we're in shutdown, we don't need refreshed connections
+                             if (_state == State::kInShutdown)
+                                 return;
+
+                             // If the connection refreshed successfully, throw it back in
+                             // the ready pool
+                             if (status.isOK()) {
+                                 // If the host and port were dropped, let this lapse
+                                 if (conn->getGeneration() == _generation) {
+                                     addToReady(lk, std::move(conn));
+                                 }
+                                 spawnConnections(lk);
+                                 return;
+                             }
+
+                             // If we've exceeded the time limit, start a new connect,
+                             // rather than failing all operations. We do this because the
+                             // various callers have their own time limit which is unrelated
+                             // to our internal one.
+                             if (status.code() == ErrorCodes::NetworkInterfaceExceededTimeLimit) {
+                                 log() << "Pending connection to host " << _hostAndPort
+                                       << " did not complete within the connection timeout,"
+                                       << " retrying with a new connection;" << openConnections(lk)
+                                       << " connections to that host remain open";
+                                 spawnConnections(lk);
+                                 return;
+                             }
+
+                             // Otherwise pass the failure on through
+                             processFailure(status, std::move(lk));
+                         }));
         lk.lock();
     } else {
         // If it's fine as it is, just put it in the ready queue
@@ -611,33 +569,37 @@ void ConnectionPool::SpecificPool::addToReady(stdx::unique_lock<stdx::mutex>& lk
     // Our strategy for refreshing connections is to check them out and
     // immediately check them back in (which kicks off the refresh logic in
     // returnConnection
-    connPtr->setTimeout(_parent->_options.refreshRequirement, [this, connPtr]() {
-        OwnedConnection conn;
+    connPtr->setTimeout(_parent->_options.refreshRequirement,
+                        guardCallback([this, connPtr](stdx::unique_lock<stdx::mutex> lk) {
+                            auto conn = takeFromPool(_readyPool, connPtr);
 
-        runWithActiveClient([&](stdx::unique_lock<stdx::mutex> lk) {
-            if (!_readyPool.count(connPtr)) {
-                // We've already been checked out. We don't need to refresh
-                // ourselves.
-                return;
-            }
+                            // We've already been checked out. We don't need to refresh
+                            // ourselves.
+                            if (!conn)
+                                return;
 
-            conn = takeFromPool(_readyPool, connPtr);
+                            // If we're in shutdown, we don't need to refresh connections
+                            if (_state == State::kInShutdown)
+                                return;
 
-            // If we're in shutdown, we don't need to refresh connections
-            if (_state == State::kInShutdown)
-                return;
+                            _checkedOutPool[connPtr] = std::move(conn);
 
-            _checkedOutPool[connPtr] = std::move(conn);
+                            connPtr->indicateSuccess();
 
-            connPtr->indicateSuccess();
-
-            returnConnection(connPtr, std::move(lk));
-        });
-    });
+                            returnConnection(connPtr, std::move(lk));
+                        }));
 
     fulfillRequests(lk);
 }
 
+// Sets state to shutdown and kicks off the failure protocol to tank existing connections
+void ConnectionPool::SpecificPool::triggerShutdown(const Status& status,
+                                                   stdx::unique_lock<stdx::mutex> lk) {
+    _state = State::kInShutdown;
+    _droppedProcessingPool.clear();
+    processFailure(status, std::move(lk));
+}
+
 // Drop connections and fail all requests
 void ConnectionPool::SpecificPool::processFailure(const Status& status,
                                                   stdx::unique_lock<stdx::mutex> lk) {
@@ -645,7 +607,11 @@ void ConnectionPool::SpecificPool::processFailure(const Status& status,
     // connections
     _generation++;
 
-    // Drop ready connections
+    // When a connection enters the ready pool, its timer is set to eventually refresh the
+    // connection. This requires a lifetime extension of the specific pool because the connection
+    // timer is tied to the lifetime of the connection, not the pool. That said, we can destruct
+    // all of the connections and thus timers of which we have ownership.
+    // In short, clearing the ready pool helps the SpecificPool drain.
     _readyPool.clear();
 
     // Log something helpful
@@ -653,7 +619,10 @@ void ConnectionPool::SpecificPool::processFailure(const Status& status,
 
     // Migrate processing connections to the dropped pool
     for (auto&& x : _processingPool) {
-        _droppedProcessingPool[x.first] = std::move(x.second);
+        if (_state != State::kInShutdown) {
+            // If we're just dropping the pool, we can reuse them later
+            _droppedProcessingPool[x.first] = std::move(x.second);
+        }
     }
     _processingPool.clear();
 
@@ -729,7 +698,12 @@ void ConnectionPool::SpecificPool::fulfillRequests(stdx::unique_lock<stdx::mutex
         // pass it to the user
         connPtr->resetToUnknown();
         lk.unlock();
-        promise.emplaceValue(ConnectionHandle(connPtr, ConnectionHandleDeleter(_parent)));
+        ConnectionHandle handle(connPtr,
+                                guardCallback([this](stdx::unique_lock<stdx::mutex> localLk,
+                                                     ConnectionPool::ConnectionInterface* conn) {
+                                    returnConnection(conn, std::move(localLk));
+                                }));
+        promise.emplaceValue(std::move(handle));
         lk.lock();
     }
 }
@@ -753,8 +727,10 @@ void ConnectionPool::SpecificPool::spawnConnections(stdx::unique_lock<stdx::mute
     };
 
     // While all of our inflight connections are less than our target
-    while ((_readyPool.size() + _processingPool.size() + _checkedOutPool.size() < target()) &&
+    while ((_state != State::kInShutdown) &&
+           (_readyPool.size() + _processingPool.size() + _checkedOutPool.size() < target()) &&
            (_processingPool.size() < _parent->_options.maxConnecting)) {
+        OwnedConnection handle;
         try {
             // make a new connection and put it in processing
@@ -764,36 +740,38 @@ void ConnectionPool::SpecificPool::spawnConnections(stdx::unique_lock<stdx::mute
             fassertFailed(40336);
         }
 
-        auto connPtr = handle.get();
-        _processingPool[connPtr] = std::move(handle);
+        _processingPool[handle.get()] = handle;
 
         ++_created;
 
         // Run the setup callback
         lk.unlock();
-        connPtr->setup(
-            _parent->_options.refreshTimeout, [this](ConnectionInterface* connPtr, Status status) {
-                runWithActiveClient([&](stdx::unique_lock<stdx::mutex> lk) {
-                    auto conn = takeFromProcessingPool(connPtr);
-
-                    if (conn->getGeneration() != _generation) {
-                        // If the host and port was dropped, let the
-                        // connection lapse
-                        spawnConnections(lk);
-                    } else if (status.isOK()) {
+        handle->setup(
+            _parent->_options.refreshTimeout,
+            guardCallback([this](
+                stdx::unique_lock<stdx::mutex> lk, ConnectionInterface* connPtr, Status status) {
+                auto conn = takeFromProcessingPool(connPtr);
+
+                // If we're in shutdown, we don't need this conn
+                if (_state == State::kInShutdown)
+                    return;
+
+                if (status.isOK()) {
+                    // If the host and port was dropped, let the connection lapse
+                    if (conn->getGeneration() == _generation) {
                         addToReady(lk, std::move(conn));
-                        spawnConnections(lk);
-                    } else if (status.code() == ErrorCodes::NetworkInterfaceExceededTimeLimit) {
-                        // If we've exceeded the time limit, restart the connect, rather than
-                        // failing all operations. We do this because the various callers
-                        // have their own time limit which is unrelated to our internal one.
-                        spawnConnections(lk);
-                    } else {
-                        // If the setup failed, cascade the failure edge
-                        processFailure(status, std::move(lk));
                     }
-                });
-            });
+                    spawnConnections(lk);
+                } else if (status.code() == ErrorCodes::NetworkInterfaceExceededTimeLimit) {
+                    // If we've exceeded the time limit, restart the connect, rather than
+                    // failing all operations. We do this because the various callers
+                    // have their own time limit which is unrelated to our internal one.
+                    spawnConnections(lk);
+                } else {
+                    // If the setup failed, cascade the failure edge
+                    processFailure(status, std::move(lk));
+                }
+            }));
 
         // Note that this assumes that the refreshTimeout is sound for the
         // setupTimeout
@@ -801,50 +779,12 @@ void ConnectionPool::SpecificPool::spawnConnections(stdx::unique_lock<stdx::mute
     }
 }
 
-// Called every second after hostTimeout until all processing connections reap
-void ConnectionPool::SpecificPool::shutdown() {
-    stdx::unique_lock<stdx::mutex> lk(_parent->_mutex);
-
-    // We're racing:
-    //
-    // Thread A (this thread)
-    //   * Fired the shutdown timer
-    //   * Came into shutdown() and blocked
-    //
-    // Thread B (some new consumer)
-    //   * Requested a new connection
-    //   * Beat thread A to the mutex
-    //   * Cancelled timer (but thread A already made it in)
-    //   * Set state to running
-    //   * released the mutex
-    //
-    // So we end up in shutdown, but with kRunning. If we're here we raced and
-    // we should just bail.
-    if (_state == State::kRunning) {
-        return;
-    }
-
-    _state = State::kInShutdown;
-
-    // If we have processing connections, wait for them to finish or timeout
-    // before shutdown
-    if (_processingPool.size() || _droppedProcessingPool.size() || _activeClients) {
-        _requestTimer->setTimeout(Seconds(1), [this]() { shutdown(); });
-
-        return;
-    }
-
-    invariant(_requests.empty());
-    invariant(_checkedOutPool.empty());
-
-    _parent->_pools.erase(_hostAndPort);
-}
-
 template <typename OwnershipPoolType>
 typename OwnershipPoolType::mapped_type ConnectionPool::SpecificPool::takeFromPool(
     OwnershipPoolType& pool, typename OwnershipPoolType::key_type connPtr) {
     auto iter = pool.find(connPtr);
-    invariant(iter != pool.end());
+    if (iter == pool.end())
+        return typename OwnershipPoolType::mapped_type();
 
     auto conn = std::move(iter->second);
     pool.erase(iter);
@@ -853,8 +793,11 @@ typename OwnershipPoolType::mapped_type ConnectionPool::SpecificPool::takeFromPo
 
 ConnectionPool::SpecificPool::OwnedConnection ConnectionPool::SpecificPool::takeFromProcessingPool(
     ConnectionInterface* connPtr) {
-    if (_processingPool.count(connPtr))
-        return takeFromPool(_processingPool, connPtr);
+    auto conn = takeFromPool(_processingPool, connPtr);
+    if (conn) {
+        invariant(_state != State::kInShutdown);
+        return conn;
+    }
 
     return takeFromPool(_droppedProcessingPool, connPtr);
 }
@@ -862,6 +805,16 @@ ConnectionPool::SpecificPool::OwnedConnection ConnectionPool::SpecificPool::take
 
 // Updates our state and manages the request timer
 void ConnectionPool::SpecificPool::updateStateInLock() {
+    if (_state == State::kInShutdown) {
+        // If we're in shutdown, there is nothing to update. Our clients are all gone.
+        if (_processingPool.empty() && !_activeClients) {
+            // If we have no more clients that require access to us, delist from the parent pool
+            LOG(2) << "Delisting connection pool for " << _hostAndPort;
+            _parent->_pools.erase(_hostAndPort);
+        }
+        return;
+    }
+
     if (_requests.size()) {
         // We have some outstanding requests, we're live
 
@@ -880,8 +833,8 @@ void ConnectionPool::SpecificPool::updateStateInLock() {
         // We set a timer for the most recent request, then invoke each timed
         // out request we couldn't service
-        _requestTimer->setTimeout(timeout, [this]() {
-            runWithActiveClient([&](stdx::unique_lock<stdx::mutex> lk) {
+        _requestTimer->setTimeout(
+            timeout, guardCallback([this](stdx::unique_lock<stdx::mutex> lk) {
                 auto now = _parent->_factory->now();
 
                 while (_requests.size()) {
@@ -902,8 +855,7 @@ void ConnectionPool::SpecificPool::updateStateInLock() {
                 }
 
                 updateStateInLock();
-            });
-        });
+            }));
     } else if (_checkedOutPool.size()) {
         // If we have no requests, but someone's using a connection, we just
         // hang around until the next request or a return
@@ -926,8 +878,17 @@ void ConnectionPool::SpecificPool::updateStateInLock() {
 
         auto timeout = _parent->_options.hostTimeout;
 
-        // Set the shutdown timer
-        _requestTimer->setTimeout(timeout, [this]() { shutdown(); });
+        // Set the shutdown timer, this gets reset on any request
+        _requestTimer->setTimeout(timeout, [ this, anchor = shared_from_this() ]() {
+            stdx::unique_lock<stdx::mutex> lk(anchor->_parent->_mutex);
+            if (_state != State::kIdle)
+                return;
+
+            triggerShutdown(
+                Status(ErrorCodes::NetworkInterfaceExceededTimeLimit,
+                       "Connection pool has been idle for longer than the host timeout"),
+                std::move(lk));
+        });
     }
 }
 
diff --git a/src/mongo/executor/connection_pool.h b/src/mongo/executor/connection_pool.h
index 118f734920d..25f98fa63cd 100644
--- a/src/mongo/executor/connection_pool.h
+++ b/src/mongo/executor/connection_pool.h
@@ -63,11 +63,11 @@ class ConnectionPool : public EgressTagCloser {
     class SpecificPool;
 
 public:
-    class ConnectionHandleDeleter;
     class ConnectionInterface;
     class DependentTypeFactoryInterface;
     class TimerInterface;
 
+    using ConnectionHandleDeleter = stdx::function<void(ConnectionInterface* connection)>;
     using ConnectionHandle = std::unique_ptr<ConnectionInterface, ConnectionHandleDeleter>;
 
     using GetConnectionCallback = stdx::function<void(StatusWith<ConnectionHandle>)>;
@@ -131,7 +131,7 @@ public:
         EgressTagCloserManager* egressTagCloserManager = nullptr;
     };
 
-    explicit ConnectionPool(std::unique_ptr<DependentTypeFactoryInterface> impl,
+    explicit ConnectionPool(std::shared_ptr<DependentTypeFactoryInterface> impl,
                             std::string name,
                             Options options = Options{});
 
@@ -163,29 +163,15 @@ private:
     // accessed outside the lock
     const Options _options;
 
-    const std::unique_ptr<DependentTypeFactoryInterface> _factory;
+    const std::shared_ptr<DependentTypeFactoryInterface> _factory;
 
     // The global mutex for specific pool access and the generation counter
     mutable stdx::mutex _mutex;
-    stdx::unordered_map<HostAndPort, std::unique_ptr<SpecificPool>> _pools;
+    stdx::unordered_map<HostAndPort, std::shared_ptr<SpecificPool>> _pools;
 
     EgressTagCloserManager* _manager;
 };
 
-class ConnectionPool::ConnectionHandleDeleter {
-public:
-    ConnectionHandleDeleter() = default;
-    ConnectionHandleDeleter(ConnectionPool* pool) : _pool(pool) {}
-
-    void operator()(ConnectionInterface* connection) const {
-        if (_pool && connection)
-            _pool->returnConnection(connection);
-    }
-
-private:
-    ConnectionPool* _pool = nullptr;
-};
-
 /**
  * Interface for a basic timer
 *
@@ -220,9 +206,7 @@ public:
 * specifically callbacks to set them up (connect + auth + whatever else),
 * refresh them (issue some kind of ping) and manage a timer.
 */
-class ConnectionPool::ConnectionInterface
-    : public TimerInterface,
-      public std::enable_shared_from_this<ConnectionPool::ConnectionInterface> {
+class ConnectionPool::ConnectionInterface : public TimerInterface {
     MONGO_DISALLOW_COPYING(ConnectionInterface);
 
     friend class ConnectionPool;
@@ -336,12 +320,17 @@ public:
     /**
      * Makes a new timer
     */
-    virtual std::unique_ptr<TimerInterface> makeTimer() = 0;
+    virtual std::shared_ptr<TimerInterface> makeTimer() = 0;
 
     /**
      * Returns the current time point
     */
    virtual Date_t now() = 0;
+
+    /**
+     * shutdown
+     */
+    virtual void shutdown() = 0;
 };
 
 }  // namespace executor
diff --git a/src/mongo/executor/connection_pool_test_fixture.cpp b/src/mongo/executor/connection_pool_test_fixture.cpp
index 560f184a025..77629b0fa94 100644
--- a/src/mongo/executor/connection_pool_test_fixture.cpp
+++ b/src/mongo/executor/connection_pool_test_fixture.cpp
@@ -42,6 +42,8 @@ TimerImpl::~TimerImpl() {
 }
 
 void TimerImpl::setTimeout(Milliseconds timeout, TimeoutCallback cb) {
+    _timers.erase(this);
+
     _cb = std::move(cb);
     _expiration = _global->now() + timeout;
 
@@ -50,10 +52,14 @@ void TimerImpl::setTimeout(Milliseconds timeout, TimeoutCallback cb) {
 
 void TimerImpl::cancelTimeout() {
     _timers.erase(this);
+    _cb = TimeoutCallback{};
 }
 
 void TimerImpl::clear() {
-    _timers.clear();
+    while (!_timers.empty()) {
+        auto* timer = *_timers.begin();
+        timer->cancelTimeout();
+    }
 }
 
 void TimerImpl::fireIfNecessary() {
@@ -233,7 +239,7 @@ std::shared_ptr<ConnectionPool::ConnectionInterface> PoolImpl::makeConnection(
     return std::make_shared<ConnectionImpl>(hostAndPort, generation, this);
 }
 
-std::unique_ptr<ConnectionPool::TimerInterface> PoolImpl::makeTimer() {
+std::shared_ptr<ConnectionPool::TimerInterface> PoolImpl::makeTimer() {
     return stdx::make_unique<TimerImpl>(this);
 }
 
diff --git a/src/mongo/executor/connection_pool_test_fixture.h b/src/mongo/executor/connection_pool_test_fixture.h
index f12eccf6c34..a66492d555b 100644
--- a/src/mongo/executor/connection_pool_test_fixture.h
+++ b/src/mongo/executor/connection_pool_test_fixture.h
@@ -150,10 +150,14 @@ public:
     std::shared_ptr<ConnectionPool::ConnectionInterface> makeConnection(
         const HostAndPort& hostAndPort, size_t generation) override;
 
-    std::unique_ptr<ConnectionPool::TimerInterface> makeTimer() override;
+    std::shared_ptr<ConnectionPool::TimerInterface> makeTimer() override;
 
     Date_t now() override;
 
+    void shutdown() override {
+        TimerImpl::clear();
+    };
+
     /**
     * setNow() can be used to fire all timers that have passed a point in time
     */
diff --git a/src/mongo/executor/connection_pool_tl.cpp b/src/mongo/executor/connection_pool_tl.cpp
index bb41a687a6e..5e79798c6aa 100644
--- a/src/mongo/executor/connection_pool_tl.cpp
+++ b/src/mongo/executor/connection_pool_tl.cpp
@@ -44,15 +44,61 @@ const auto kMaxTimerDuration = Milliseconds::max();
 struct TimeoutHandler {
     AtomicBool done;
     Promise<void> promise;
+
+    explicit TimeoutHandler(Promise<void> p) : promise(std::move(p)) {}
 };
 
 }  // namespace
 
+void TLTypeFactory::shutdown() {
+    // Stop any attempt to schedule timers in the future
+    _inShutdown.store(true);
+
+    stdx::lock_guard<stdx::mutex> lk(_mutex);
+
+    log() << "Killing all outstanding egress activity.";
+    for (auto collar : _collars) {
+        collar->kill();
+    }
+}
+
+void TLTypeFactory::fasten(Type* type) {
+    stdx::lock_guard<stdx::mutex> lk(_mutex);
+    _collars.insert(type);
+}
+
+void TLTypeFactory::release(Type* type) {
+    stdx::lock_guard<stdx::mutex> lk(_mutex);
+    _collars.erase(type);
+
+    type->_wasReleased = true;
+}
+
+TLTypeFactory::Type::Type(const std::shared_ptr<TLTypeFactory>& factory) : _factory{factory} {}
+
+TLTypeFactory::Type::~Type() {
+    invariant(_wasReleased);
+}
+
+void TLTypeFactory::Type::release() {
+    _factory->release(this);
+}
+
+bool TLTypeFactory::inShutdown() const {
+    return _inShutdown.load();
+}
+
 void TLTimer::setTimeout(Milliseconds timeoutVal, TimeoutCallback cb) {
+    // We will not wait on a timeout if we are in shutdown.
+    // The clients will be canceled as an inevitable consequence of pools shutting down.
+    if (inShutdown()) {
+        LOG(2) << "Skipping timeout due to impending shutdown.";
+        return;
+    }
+
     _timer->waitFor(timeoutVal).getAsync([cb = std::move(cb)](Status status) {
-        // TODO: verify why we still get broken promises when expliciting call stop and shutting
-        // down NITL's quickly.
-        if (status == ErrorCodes::CallbackCanceled || status == ErrorCodes::BrokenPromise) {
+        // If we get canceled, then we don't worry about the timeout anymore
+        if (status == ErrorCodes::CallbackCanceled) {
             return;
         }
 
@@ -101,19 +147,21 @@ const Status& TLConnection::getStatus() const {
 }
 
 void TLConnection::setTimeout(Milliseconds timeout, TimeoutCallback cb) {
-    _timer.setTimeout(timeout, std::move(cb));
+    auto anchor = shared_from_this();
+    _timer->setTimeout(timeout, [ cb = std::move(cb), anchor = std::move(anchor) ] { cb(); });
 }
 
 void TLConnection::cancelTimeout() {
-    _timer.cancelTimeout();
+    _timer->cancelTimeout();
 }
 
 void TLConnection::setup(Milliseconds timeout, SetupCallback cb) {
     auto anchor = shared_from_this();
 
-    auto handler = std::make_shared<TimeoutHandler>();
-    handler->promise.getFuture().getAsync(
-        [ this, cb = std::move(cb) ](Status status) { cb(this, std::move(status)); });
+    auto pf = makePromiseFuture<void>();
+    auto handler = std::make_shared<TimeoutHandler>(std::move(pf.promise));
+    std::move(pf.future).getAsync(
+        [ this, cb = std::move(cb), anchor ](Status status) { cb(this, std::move(status)); });
 
     log() << "Connecting to " << _peer;
     setTimeout(timeout, [this, handler, timeout] {
@@ -166,6 +214,7 @@ void TLConnection::setup(Milliseconds timeout, SetupCallback cb) {
             handler->promise.setError(status);
         }
     });
+    LOG(2) << "Finished connection setup.";
 }
 
 void TLConnection::resetToUnknown() {
@@ -175,9 +224,10 @@ void TLConnection::resetToUnknown() {
 void TLConnection::refresh(Milliseconds timeout, RefreshCallback cb) {
     auto anchor = shared_from_this();
 
-    auto handler = std::make_shared<TimeoutHandler>();
-    handler->promise.getFuture().getAsync(
-        [ this, cb = std::move(cb) ](Status status) { cb(this, status); });
+    auto pf = makePromiseFuture<void>();
+    auto handler = std::make_shared<TimeoutHandler>(std::move(pf.promise));
+    std::move(pf.future).getAsync(
+        [ this, cb = std::move(cb), anchor ](Status status) { cb(this, status); });
 
     setTimeout(timeout, [this, handler] {
         if (handler->done.swap(true)) {
@@ -216,14 +266,27 @@ size_t TLConnection::getGeneration() const {
     return _generation;
 }
 
+void TLConnection::cancelAsync() {
+    if (_client)
+        _client->cancel();
+}
+
 std::shared_ptr<ConnectionPool::ConnectionInterface> TLTypeFactory::makeConnection(
     const HostAndPort& hostAndPort, size_t generation) {
-    return std::make_shared<TLConnection>(
-        _reactor, getGlobalServiceContext(), hostAndPort, generation, _onConnectHook.get());
+    auto conn = std::make_shared<TLConnection>(shared_from_this(),
+                                               _reactor,
+                                               getGlobalServiceContext(),
+                                               hostAndPort,
+                                               generation,
+                                               _onConnectHook.get());
+    fasten(conn.get());
+    return conn;
 }
 
-std::unique_ptr<ConnectionPool::TimerInterface> TLTypeFactory::makeTimer() {
-    return std::make_unique<TLTimer>(_reactor);
+std::shared_ptr<ConnectionPool::TimerInterface> TLTypeFactory::makeTimer() {
+    auto timer = std::make_shared<TLTimer>(shared_from_this(), _reactor);
+    fasten(timer.get());
+    return timer;
 }
 
 Date_t TLTypeFactory::now() {
diff --git a/src/mongo/executor/connection_pool_tl.h b/src/mongo/executor/connection_pool_tl.h
index 1e9e1c98604..5aff6c80c57 100644
--- a/src/mongo/executor/connection_pool_tl.h
+++ b/src/mongo/executor/connection_pool_tl.h
@@ -34,13 +34,17 @@
 #include "mongo/executor/connection_pool.h"
 #include "mongo/executor/network_connection_hook.h"
 #include "mongo/executor/network_interface.h"
+#include "mongo/util/future.h"
 
 namespace mongo {
 namespace executor {
 namespace connection_pool_tl {
 
-class TLTypeFactory final : public ConnectionPool::DependentTypeFactoryInterface {
+class TLTypeFactory final : public ConnectionPool::DependentTypeFactoryInterface,
+                            public std::enable_shared_from_this<TLTypeFactory> {
 public:
+    class Type;
+
     TLTypeFactory(transport::ReactorHandle reactor,
                   transport::TransportLayer* tl,
                   std::unique_ptr<NetworkConnectionHook> onConnectHook)
@@ -48,42 +52,91 @@ public:
     std::shared_ptr<ConnectionPool::ConnectionInterface> makeConnection(
         const HostAndPort& hostAndPort, size_t generation) override;
-    std::unique_ptr<ConnectionPool::TimerInterface> makeTimer() override;
+    std::shared_ptr<ConnectionPool::TimerInterface> makeTimer() override;
 
     Date_t now() override;
 
+    void shutdown() override;
+    bool inShutdown() const;
+    void fasten(Type* type);
+    void release(Type* type);
+
 private:
     transport::ReactorHandle _reactor;
     transport::TransportLayer* _tl;
     std::unique_ptr<NetworkConnectionHook> _onConnectHook;
+
+    mutable stdx::mutex _mutex;
+    AtomicBool _inShutdown{false};
+    stdx::unordered_set<Type*> _collars;
+};
+
+class TLTypeFactory::Type : public std::enable_shared_from_this<TLTypeFactory::Type> {
+    friend class TLTypeFactory;
+
+    MONGO_DISALLOW_COPYING(Type);
+
+public:
+    explicit Type(const std::shared_ptr<TLTypeFactory>& factory);
+    ~Type();
+
+    void release();
+    bool inShutdown() const {
+        return _factory->inShutdown();
+    }
+
+    virtual void kill() = 0;
+
+private:
+    std::shared_ptr<TLTypeFactory> _factory;
+    bool _wasReleased = false;
 };
 
-class TLTimer final : public ConnectionPool::TimerInterface {
+class TLTimer final : public ConnectionPool::TimerInterface, public TLTypeFactory::Type {
 public:
-    explicit TLTimer(const transport::ReactorHandle& reactor)
-        : _reactor(reactor), _timer(_reactor->makeTimer()) {}
+    explicit TLTimer(const std::shared_ptr<TLTypeFactory>& factory,
+                     const transport::ReactorHandle& reactor)
+        : TLTypeFactory::Type(factory), _reactor(reactor), _timer(_reactor->makeTimer()) {}
+    ~TLTimer() {
+        // Release must be the first expression of this dtor
+        release();
+    }
+
+    void kill() override {
+        cancelTimeout();
+    }
 
     void setTimeout(Milliseconds timeout, TimeoutCallback cb) override;
     void cancelTimeout() override;
 
 private:
     transport::ReactorHandle _reactor;
-    std::unique_ptr<transport::ReactorTimer> _timer;
+    std::shared_ptr<transport::ReactorTimer> _timer;
 };
 
-class TLConnection final : public ConnectionPool::ConnectionInterface {
+class TLConnection final : public ConnectionPool::ConnectionInterface, public TLTypeFactory::Type {
 public:
-    TLConnection(transport::ReactorHandle reactor,
+    TLConnection(const std::shared_ptr<TLTypeFactory>& factory,
+                 transport::ReactorHandle reactor,
                  ServiceContext* serviceContext,
                  HostAndPort peer,
                  size_t generation,
                  NetworkConnectionHook* onConnectHook)
-        : _reactor(reactor),
+        : TLTypeFactory::Type(factory),
+          _reactor(reactor),
          _serviceContext(serviceContext),
-          _timer(_reactor),
+          _timer(factory->makeTimer()),
          _peer(std::move(peer)),
          _generation(generation),
          _onConnectHook(onConnectHook) {}
+    ~TLConnection() {
+        // Release must be the first expression of this dtor
+        release();
+    }
+
+    void kill() override {
+        cancelAsync();
+    }
 
     void indicateSuccess() override;
     void indicateFailure(Status status) override;
@@ -101,13 +154,15 @@ private:
     void setup(Milliseconds timeout, SetupCallback cb) override;
     void resetToUnknown() override;
     void refresh(Milliseconds timeout, RefreshCallback cb) override;
+    void cancelAsync();
 
     size_t getGeneration() const override;
 
 private:
     transport::ReactorHandle _reactor;
     ServiceContext* const _serviceContext;
-    TLTimer _timer;
+    std::shared_ptr<ConnectionPool::TimerInterface> _timer;
+
     HostAndPort _peer;
     size_t _generation;
     NetworkConnectionHook* const _onConnectHook;
diff --git a/src/mongo/executor/network_interface_integration_fixture.cpp b/src/mongo/executor/network_interface_integration_fixture.cpp
index 6a593374fb6..4a694eec7d2 100644
--- a/src/mongo/executor/network_interface_integration_fixture.cpp
+++ b/src/mongo/executor/network_interface_integration_fixture.cpp
@@ -62,9 +62,8 @@ void NetworkInterfaceIntegrationFixture::startNet(
 }
 
 void NetworkInterfaceIntegrationFixture::tearDown() {
-    if (!_net->inShutdown()) {
-        _net->shutdown();
-    }
+    // Network interface will only shutdown once because of an internal shutdown guard
+    _net->shutdown();
 }
 
 NetworkInterface& NetworkInterfaceIntegrationFixture::net() {
diff --git a/src/mongo/executor/network_interface_integration_test.cpp b/src/mongo/executor/network_interface_integration_test.cpp
index 738c7606676..c18f229ba28 100644
--- a/src/mongo/executor/network_interface_integration_test.cpp
+++ b/src/mongo/executor/network_interface_integration_test.cpp
@@ -93,12 +93,12 @@ class HangingHook : public executor::NetworkConnectionHook {
     }
 
     Status handleReply(const HostAndPort& remoteHost, RemoteCommandResponse&& response) final {
-        if (pingCommandMissing(response)) {
-            return {ErrorCodes::NetworkInterfaceExceededTimeLimit,
-                    "No ping command. Simulating timeout"};
+        if (!pingCommandMissing(response)) {
+            ASSERT_EQ(ErrorCodes::CallbackCanceled, response.status);
+            return response.status;
         }
 
-        MONGO_UNREACHABLE;
+        return {ErrorCodes::ExceededTimeLimit, "No ping command. Returning pseudo-timeout."};
     }
 };
 
@@ -107,8 +107,19 @@ TEST_F(NetworkInterfaceIntegrationFixture, HookHangs) {
     startNet(stdx::make_unique<HangingHook>());
 
-    assertCommandFailsOnClient(
-        "admin", BSON("ping" << 1), ErrorCodes::NetworkInterfaceExceededTimeLimit, Seconds(1));
+    /**
+     * Since mongos's have no ping command, we effectively skip this test by returning
+     * ExceededTimeLimit above. (That ErrorCode is used heavily in repl and sharding code.)
+     * If we return NetworkInterfaceExceededTimeLimit, it will make the ConnectionPool
+     * attempt to reform the connection, which can lead to an accepted but unfortunate
+     * race between TLConnection::setup and TLTypeFactory::shutdown.
+     * We assert here that the error code we get is in the error class of timeouts,
+     * which covers both NetworkInterfaceExceededTimeLimit and ExceededTimeLimit.
+     */
+    RemoteCommandRequest request{
+        fixture().getServers()[0], "admin", BSON("ping" << 1), BSONObj(), nullptr, Seconds(1)};
+    auto res = runCommandSync(request);
+    ASSERT(ErrorCodes::isExceededTimeLimitError(res.status.code()));
 }
 
 using ResponseStatus = TaskExecutor::ResponseStatus;
diff --git a/src/mongo/executor/network_interface_tl.cpp b/src/mongo/executor/network_interface_tl.cpp
index 2bb57e38678..923d25f565e 100644
--- a/src/mongo/executor/network_interface_tl.cpp
+++ b/src/mongo/executor/network_interface_tl.cpp
@@ -101,17 +101,37 @@ void NetworkInterfaceTL::startup() {
         std::move(typeFactory), std::string("NetworkInterfaceTL-") + _instanceName, _connPoolOpts);
     _ioThread = stdx::thread([this] {
         setThreadName(_instanceName);
-        LOG(2) << "The NetworkInterfaceTL reactor thread is spinning up";
-        _reactor->run();
+        _run();
     });
 }
 
+void NetworkInterfaceTL::_run() {
+    LOG(2) << "The NetworkInterfaceTL reactor thread is spinning up";
+
+    // This returns when the reactor is stopped in shutdown()
+    _reactor->run();
+
+    // Note that the pool will shutdown again when the ConnectionPool dtor runs
+    // This prevents new timers from being set, calls all cancels via the factory registry, and
+    // destructs all connections for all existing pools.
+    _pool->shutdown();
+
+    // Close out all remaining tasks in the reactor now that they've all been canceled.
+    _reactor->drain();
+
+    LOG(2) << "NetworkInterfaceTL shutdown successfully";
+}
+
 void NetworkInterfaceTL::shutdown() {
-    _inShutdown.store(true);
+    if (_inShutdown.swap(true))
+        return;
+
+    LOG(2) << "Shutting down network interface.";
+
+    // Stop the reactor/thread first so that nothing runs on a partially dtor'd pool.
     _reactor->stop();
+
     _ioThread.join();
-    _pool->shutdown();
-    LOG(2) << "NetworkInterfaceTL shutdown successfully";
 }
 
 bool NetworkInterfaceTL::inShutdown() const {
@@ -169,7 +189,8 @@ Status NetworkInterfaceTL::startCommand(const TaskExecutor::CallbackHandle& cbHa
         request.metadata = newMetadata.obj();
     }
 
-    auto state = std::make_shared<CommandState>(request, cbHandle);
+    auto pf = makePromiseFuture<RemoteCommandResponse>();
+    auto state = std::make_shared<CommandState>(request, cbHandle, std::move(pf.promise));
     {
         stdx::lock_guard<stdx::mutex> lk(_inProgressMutex);
         _inProgress.insert({state->cbHandle, state});
@@ -182,10 +203,9 @@ Status NetworkInterfaceTL::startCommand(const TaskExecutor::CallbackHandle& cbHa
 
     if (MONGO_FAIL_POINT(networkInterfaceDiscardCommandsBeforeAcquireConn)) {
         log() << "Discarding command due to failpoint before acquireConn";
-        std::move(state->mergedFuture)
-            .getAsync([onFinish](StatusWith<RemoteCommandResponse> response) {
-                onFinish(RemoteCommandResponse(response.getStatus(), Milliseconds{0}));
-            });
+        std::move(pf.future).getAsync([onFinish](StatusWith<RemoteCommandResponse> response) {
+            onFinish(RemoteCommandResponse(response.getStatus(), Milliseconds{0}));
+        });
         return Status::OK();
     }
 
@@ -215,10 +235,18 @@ Status NetworkInterfaceTL::startCommand(const TaskExecutor::CallbackHandle& cbHa
         });
     });
 
-    auto remainingWork = [this, state, baton, onFinish](
-        StatusWith<std::shared_ptr<CommandState::ConnHandle>> swConn) {
-        makeReadyFutureWith(
-            [&] { return _onAcquireConn(state, std::move(*uassertStatusOK(swConn)), baton); })
+    auto remainingWork = [
+        this,
+        state,
+        // TODO: once SERVER-35685 is done, stop using a `std::shared_ptr<Future>` here.
+        future = std::make_shared<decltype(pf.future)>(std::move(pf.future)),
+        baton,
+        onFinish
+    ](StatusWith<std::shared_ptr<CommandState::ConnHandle>> swConn) mutable {
+        makeReadyFutureWith([&] {
+            return _onAcquireConn(
+                state, std::move(*future), std::move(*uassertStatusOK(swConn)), baton);
+        })
             .onError([](Status error) -> StatusWith<RemoteCommandResponse> {
                 // The TransportLayer has, for historical reasons returned SocketException for
                 // network errors, but sharding assumes HostUnreachable on network errors.
@@ -267,11 +295,12 @@ Status NetworkInterfaceTL::startCommand(const TaskExecutor::CallbackHandle& cbHa
 // returning a ready Future with a not-OK status.
 Future<RemoteCommandResponse> NetworkInterfaceTL::_onAcquireConn(
     std::shared_ptr<CommandState> state,
+    Future<RemoteCommandResponse> future,
     CommandState::ConnHandle conn,
     const transport::BatonHandle& baton) {
     if (MONGO_FAIL_POINT(networkInterfaceDiscardCommandsAfterAcquireConn)) {
         conn->indicateSuccess();
-        return std::move(state->mergedFuture);
+        return future;
     }
 
     if (state->done.load()) {
@@ -366,7 +395,7 @@ Future<RemoteCommandResponse> NetworkInterfaceTL::_onAcquireConn(
         state->promise.setFromStatusWith(std::move(swr));
     });
 
-    return std::move(state->mergedFuture);
+    return future;
 }
 
 void NetworkInterfaceTL::_eraseInUseConn(const TaskExecutor::CallbackHandle& cbHandle) {
diff --git a/src/mongo/executor/network_interface_tl.h b/src/mongo/executor/network_interface_tl.h
index 068e5d310bb..621336cd6e7 100644
--- a/src/mongo/executor/network_interface_tl.h
+++ b/src/mongo/executor/network_interface_tl.h
@@ -80,8 +80,12 @@ public:
 private:
     struct CommandState {
-        CommandState(RemoteCommandRequest request_, TaskExecutor::CallbackHandle cbHandle_)
-            : request(std::move(request_)), cbHandle(std::move(cbHandle_)) {}
+        CommandState(RemoteCommandRequest request_,
+                     TaskExecutor::CallbackHandle cbHandle_,
+                     Promise<RemoteCommandResponse> promise_)
+            : request(std::move(request_)),
+              cbHandle(std::move(cbHandle_)),
+              promise(std::move(promise_)) {}
 
         RemoteCommandRequest request;
         TaskExecutor::CallbackHandle cbHandle;
@@ -104,11 +108,12 @@ private:
 
         AtomicBool done;
         Promise<RemoteCommandResponse> promise;
-        Future<RemoteCommandResponse> mergedFuture = promise.getFuture();
     };
 
+    void _run();
     void _eraseInUseConn(const TaskExecutor::CallbackHandle& handle);
     Future<RemoteCommandResponse> _onAcquireConn(std::shared_ptr<CommandState> state,
+                                                 Future<RemoteCommandResponse> future,
                                                  CommandState::ConnHandle conn,
                                                  const transport::BatonHandle& baton);
 
diff --git a/src/mongo/transport/service_executor_test.cpp b/src/mongo/transport/service_executor_test.cpp
index 8e749228631..1db186ec0d2 100644
--- a/src/mongo/transport/service_executor_test.cpp
+++ b/src/mongo/transport/service_executor_test.cpp
@@ -104,6 +104,14 @@ public:
         _ioContext.stop();
     }
 
+    void drain() override final {
+        _ioContext.restart();
+        while (_ioContext.poll()) {
+            LOG(1) << "Draining remaining work in reactor.";
+        }
+        _ioContext.stop();
+    }
+
     std::unique_ptr<ReactorTimer> makeTimer() final {
         MONGO_UNREACHABLE;
     }
diff --git a/src/mongo/transport/transport_layer.h b/src/mongo/transport/transport_layer.h
index 3080eceabac..b3ee5aa92a7 100644
--- a/src/mongo/transport/transport_layer.h
+++ b/src/mongo/transport/transport_layer.h
@@ -155,6 +155,7 @@ public:
     virtual void run() noexcept = 0;
     virtual void runFor(Milliseconds time) noexcept = 0;
     virtual void stop() = 0;
+    virtual void drain() = 0;
 
     using Task = stdx::function<void()>;
 
diff --git a/src/mongo/transport/transport_layer_asio.cpp b/src/mongo/transport/transport_layer_asio.cpp
index 4cc72dc99fe..617dced8cac 100644
--- a/src/mongo/transport/transport_layer_asio.cpp
+++ b/src/mongo/transport/transport_layer_asio.cpp
@@ -67,40 +67,30 @@ MONGO_FAIL_POINT_DEFINE(transportLayerASIOasyncConnectTimesOut);
 class ASIOReactorTimer final : public ReactorTimer {
 public:
     explicit ASIOReactorTimer(asio::io_context& ctx)
-        : _timerState(std::make_shared<TimerState>(ctx)) {}
+        : _timer(std::make_shared<asio::system_timer>(ctx)) {}
 
     ~ASIOReactorTimer() {
         // The underlying timer won't get destroyed until the last promise from _asyncWait
-        // has been filled, so cancel the timer so call callbacks get run
+        // has been filled, so cancel the timer so our promises get fulfilled
         cancel();
     }
 
     void cancel(const BatonHandle& baton = nullptr) override {
-        auto promise = [&] {
-            stdx::lock_guard<stdx::mutex> lk(_timerState->mutex);
-            _timerState->generation++;
-            return std::move(_timerState->finalPromise);
-        }();
-
-        if (promise) {
-            // We're worried that setting the error on the promise without unwinding the stack
-            // can lead to a deadlock, so this gets scheduled on the io_context of the timer.
-            _timerState->timer.get_io_context().post([promise = promise->share()]() mutable {
-                promise.setError({ErrorCodes::CallbackCanceled, "Timer was canceled"});
-            });
+        // If we have a baton try to cancel that.
+        if (baton && baton->cancelTimer(*this)) {
+            LOG(2) << "Canceled via baton, skipping asio cancel.";
+            return;
         }
 
-        if (!(baton && baton->cancelTimer(*this))) {
-            _timerState->timer.cancel();
-        }
+        // Otherwise there could be a previous timer that was scheduled normally.
+        _timer->cancel();
     }
 
     Future<void> waitFor(Milliseconds timeout, const BatonHandle& baton = nullptr) override {
         if (baton) {
             return _asyncWait([&] { return baton->waitFor(*this, timeout); }, baton);
         } else {
-            return _asyncWait(
-                [&] { _timerState->timer.expires_after(timeout.toSystemDuration()); });
+            return _asyncWait([&] { _timer->expires_after(timeout.toSystemDuration()); });
         }
     }
 
@@ -108,48 +98,21 @@ public:
         if (baton) {
             return _asyncWait([&] { return baton->waitUntil(*this, expiration); }, baton);
         } else {
-            return _asyncWait(
-                [&] { _timerState->timer.expires_at(expiration.toSystemTimePoint()); });
+            return _asyncWait([&] { _timer->expires_at(expiration.toSystemTimePoint()); });
         }
     }
 
private:
-    std::pair<Future<void>, uint64_t> _getFuture() {
-        stdx::lock_guard<stdx::mutex> lk(_timerState->mutex);
-        auto id = ++_timerState->generation;
-        invariant(!_timerState->finalPromise);
-        auto pf = makePromiseFuture<void>();
-        _timerState->finalPromise = std::make_unique<Promise<void>>(std::move(pf.promise));
-        return std::make_pair(std::move(pf.future), id);
-    }
-
     template <typename ArmTimerCb>
     Future<void> _asyncWait(ArmTimerCb&& armTimer) {
         try {
             cancel();
 
-            Future<void> ret;
-            uint64_t id;
-            std::tie(ret, id) = _getFuture();
-
             armTimer();
-            _timerState->timer.async_wait(
-                [ id, state = _timerState ](const std::error_code& ec) mutable {
-                    stdx::unique_lock<stdx::mutex> lk(state->mutex);
-                    if (id != state->generation) {
-                        return;
-                    }
-                    auto promise = std::move(state->finalPromise);
-                    lk.unlock();
-
-                    if (ec) {
-                        promise->setError(errorCodeToStatus(ec));
-                    } else {
-                        promise->emplaceValue();
-                    }
-                });
-
-            return ret;
+            return _timer->async_wait(UseFuture{}).tapError([timer = _timer](const Status& status) {
+                LOG(2) << "Timer received error: " << status;
+            });
+
         } catch (asio::system_error& ex) {
             return Future<void>::makeReady(errorCodeToStatus(ex.code()));
         }
@@ -159,40 +122,19 @@ private:
     Future<void> _asyncWait(ArmTimerCb&& armTimer, const BatonHandle& baton) {
         cancel(baton);
 
-        Future<void> ret;
-        uint64_t id;
-        std::tie(ret, id) = _getFuture();
-
-        armTimer().getAsync([ id, state = _timerState ](Status status) mutable {
-            stdx::unique_lock<stdx::mutex> lk(state->mutex);
-            if (id != state->generation) {
-                return;
-            }
-            auto promise = std::move(state->finalPromise);
-            lk.unlock();
-
+        auto pf = makePromiseFuture<void>();
+        armTimer().getAsync([sp = pf.promise.share()](Status status) mutable {
             if (status.isOK()) {
-                promise->emplaceValue();
+                sp.emplaceValue();
             } else {
-                promise->setError(status);
+                sp.setError(status);
             }
         });
 
-        return ret;
+        return std::move(pf.future);
     }
 
-    // The timer itself and its state are stored in this struct managed by a shared_ptr so we can
-    // extend the lifetime of the timer until all callbacks to timer.async_wait have run.
-    struct TimerState {
-        explicit TimerState(asio::io_context& ctx) : timer(ctx) {}
-
-        asio::system_timer timer;
-        stdx::mutex mutex;
-        uint64_t generation = 0;
-        std::unique_ptr<Promise<void>> finalPromise;
-    };
-
-    std::shared_ptr<TimerState> _timerState;
+    std::shared_ptr<asio::system_timer> _timer;
 };
 
 class TransportLayerASIO::ASIOReactor final : public Reactor {
@@ -213,7 +155,6 @@ public:
     void runFor(Milliseconds time) noexcept override {
         ThreadIdGuard threadIdGuard(this);
         asio::io_context::work work(_ioContext);
-
         try {
             _ioContext.run_for(time.toSystemDuration());
         } catch (...) {
@@ -226,6 +167,14 @@ public:
         _ioContext.stop();
     }
 
+    void drain() override {
+        _ioContext.restart();
+        while (_ioContext.poll()) {
+            LOG(2) << "Draining remaining work in reactor.";
+        }
+        _ioContext.stop();
+    }
+
     std::unique_ptr<ReactorTimer> makeTimer() override {
         return std::make_unique<ASIOReactorTimer>(_ioContext);
     }
@@ -557,8 +506,14 @@ Future<SessionHandle> TransportLayerASIO::asyncConnect(HostAndPort peer,
                                                        Milliseconds timeout) {
     struct AsyncConnectState {
-        AsyncConnectState(HostAndPort peer, asio::io_context& context)
-            : socket(context), timeoutTimer(context), resolver(context), peer(std::move(peer)) {}
+        AsyncConnectState(HostAndPort peer,
+                          asio::io_context& context,
+                          Promise<SessionHandle> promise_)
+            : promise(std::move(promise_)),
+              socket(context),
+              timeoutTimer(context),
+              resolver(context),
+              peer(std::move(peer)) {}
 
         AtomicBool done{false};
         Promise<SessionHandle> promise;
@@ -573,8 +528,10 @@ Future<SessionHandle> TransportLayerASIO::asyncConnect(HostAndPort peer,
     };
 
     auto reactorImpl = checked_cast<ASIOReactor*>(reactor.get());
-    auto connector = std::make_shared<AsyncConnectState>(std::move(peer), *reactorImpl);
-    Future<SessionHandle> mergedFuture = connector->promise.getFuture();
+    auto pf = makePromiseFuture<SessionHandle>();
+    auto connector =
+        std::make_shared<AsyncConnectState>(std::move(peer), *reactorImpl, std::move(pf.promise));
+    Future<SessionHandle> mergedFuture = std::move(pf.future);
 
     if (connector->peer.host().empty()) {
         return Status{ErrorCodes::HostNotFound, "Hostname or IP address to connect to is empty"};
diff --git a/src/mongo/util/future.h b/src/mongo/util/future.h
index 61003d3221e..e230d522302 100644
--- a/src/mongo/util/future.h
+++ b/src/mongo/util/future.h
@@ -508,9 +508,7 @@ public:
 
     ~Promise() {
         if (MONGO_unlikely(sharedState)) {
-            if (haveExtractedFuture) {
-                sharedState->setError({ErrorCodes::BrokenPromise, "broken promise"});
-            }
+            sharedState->setError({ErrorCodes::BrokenPromise, "broken promise"});
         }
     }
 
@@ -576,25 +574,28 @@ public:
     */
     SharedPromise<T> share() noexcept;
 
-    /**
-     * Prefer using makePromiseFuture<T>() over constructing a promise and calling this method.
-     */
-    Future<T> getFuture() noexcept;
+    static auto makePromiseFutureImpl() {
+        struct PromiseAndFuture {
+            Promise<T> promise;
+            Future<T> future = promise.getFuture();
+        };
+        return PromiseAndFuture();
+    }
 
 private:
+    // This is not public because we found it frequently was involved in races. The
+    // `makePromiseFuture<T>` API avoids those races entirely.
+    Future<T> getFuture() noexcept;
+
     friend class Future<void>;
 
     template <typename Func>
     void setImpl(Func&& doSet) noexcept {
-        invariant(!haveSetValue);
-        haveSetValue = true;
+        invariant(sharedState);
         doSet();
-        if (haveExtractedFuture)
-            sharedState.reset();
+        sharedState.reset();
     }
 
-    bool haveSetValue = false;
-    bool haveExtractedFuture = false;
     boost::intrusive_ptr<SharedState<T>> sharedState = make_intrusive<SharedState<T>>();
 };
 
@@ -1312,11 +1313,7 @@ auto makeReadyFutureWith(Func&& func) {
 */
 template <typename T>
 inline auto makePromiseFuture() {
-    struct PromiseAndFuture {
-        Promise<T> promise;
-        Future<T> future = promise.getFuture();
-    };
-    return PromiseAndFuture();
+    return Promise<T>::makePromiseFutureImpl();
 }
 
 /**
@@ -1351,23 +1348,13 @@ using FutureContinuationResult =
 
 template <typename T>
 inline Future<T> Promise<T>::getFuture() noexcept {
-    invariant(!haveExtractedFuture);
-    haveExtractedFuture = true;
-
-    if (!haveSetValue) {
-        sharedState->threadUnsafeIncRefCountTo(2);
-        return Future<T>(
-            boost::intrusive_ptr<SharedState<T>>(sharedState.get(), /*add ref*/ false));
-    }
-
-    // Let the Future steal our ref-count since we don't need it anymore.
-    return Future<T>(std::move(sharedState));
+    sharedState->threadUnsafeIncRefCountTo(2);
+    return Future<T>(boost::intrusive_ptr<SharedState<T>>(sharedState.get(), /*add ref*/ false));
 }
 
 template <typename T>
 inline SharedPromise<T> Promise<T>::share() noexcept {
-    invariant(haveExtractedFuture);
-    invariant(!haveSetValue);
+    invariant(sharedState);
     return SharedPromise<T>(std::make_shared<Promise<T>>(std::move(*this)));
 }
 
diff --git a/src/mongo/util/future_bm.cpp b/src/mongo/util/future_bm.cpp
index 28d57bfe010..5aad4b6bb48 100644
--- a/src/mongo/util/future_bm.cpp
+++ b/src/mongo/util/future_bm.cpp
@@ -65,9 +65,9 @@ void BM_futureIntReadyThen(benchmark::State& state) {
 
 NOINLINE_DECL Future<int> makeReadyFutWithPromise() {
     benchmark::ClobberMemory();
-    Promise<int> p;
-    p.emplaceValue(1);  // before getFuture().
-    return p.getFuture();
+    auto pf = makePromiseFuture<int>();
+    pf.promise.emplaceValue(1);
+    return std::move(pf.future);
 }
 
 void BM_futureIntReadyWithPromise(benchmark::State& state) {
@@ -83,27 +83,12 @@ void BM_futureIntReadyWithPromiseThen(benchmark::State& state) {
     }
 }
 
-NOINLINE_DECL Future<int> makeReadyFutWithPromise2() {
-    // This is the same as makeReadyFutWithPromise() except that this gets the Future first.
-    benchmark::ClobberMemory();
-    Promise<int> p;
-    auto fut = p.getFuture();
-    p.emplaceValue(1);  // after getFuture().
- return fut; -} - -void BM_futureIntReadyWithPromise2(benchmark::State& state) { - for (auto _ : state) { - benchmark::DoNotOptimize(makeReadyFutWithPromise().then([](int i) { return i + 1; }).get()); - } -} - void BM_futureIntDeferredThen(benchmark::State& state) { for (auto _ : state) { benchmark::ClobberMemory(); - Promise<int> p; - auto fut = p.getFuture().then([](int i) { return i + 1; }); - p.emplaceValue(1); + auto pf = makePromiseFuture<int>(); + auto fut = std::move(pf.future).then([](int i) { return i + 1; }); + pf.promise.emplaceValue(1); benchmark::DoNotOptimize(std::move(fut).get()); } } @@ -111,9 +96,9 @@ void BM_futureIntDeferredThen(benchmark::State& state) { void BM_futureIntDeferredThenImmediate(benchmark::State& state) { for (auto _ : state) { benchmark::ClobberMemory(); - Promise<int> p; - auto fut = p.getFuture().then([](int i) { return Future<int>::makeReady(i + 1); }); - p.emplaceValue(1); + auto pf = makePromiseFuture<int>(); + auto fut = std::move(pf.future).then([](int i) { return Future<int>::makeReady(i + 1); }); + pf.promise.emplaceValue(1); benchmark::DoNotOptimize(std::move(fut).get()); } } @@ -122,9 +107,9 @@ void BM_futureIntDeferredThenImmediate(benchmark::State& state) { void BM_futureIntDeferredThenReady(benchmark::State& state) { for (auto _ : state) { benchmark::ClobberMemory(); - Promise<int> p1; - auto fut = p1.getFuture().then([&](int i) { return makeReadyFutWithPromise(); }); - p1.emplaceValue(1); + auto pf = makePromiseFuture<int>(); + auto fut = std::move(pf.future).then([](int i) { return makeReadyFutWithPromise(); }); + pf.promise.emplaceValue(1); benchmark::DoNotOptimize(std::move(fut).get()); } } @@ -132,11 +117,11 @@ void BM_futureIntDeferredThenReady(benchmark::State& state) { void BM_futureIntDoubleDeferredThen(benchmark::State& state) { for (auto _ : state) { benchmark::ClobberMemory(); - Promise<int> p1; - Promise<int> p2; - auto fut = p1.getFuture().then([&](int i) { return p2.getFuture(); }); - p1.emplaceValue(1); - p2.emplaceValue(1); + auto pf1 = makePromiseFuture<int>(); + auto pf2 = makePromiseFuture<int>(); + auto fut = std::move(pf1.future).then([&](int i) { return std::move(pf2.future); }); + pf1.promise.emplaceValue(1); + pf2.promise.emplaceValue(1); benchmark::DoNotOptimize(std::move(fut).get()); } } @@ -144,14 +129,15 @@ void BM_futureIntDoubleDeferredThen(benchmark::State& state) { void BM_futureInt3xDeferredThenNested(benchmark::State& state) { for (auto _ : state) { benchmark::ClobberMemory(); - Promise<int> p1; - Promise<int> p2; - Promise<int> p3; - auto fut = p1.getFuture().then( - [&](int i) { return p2.getFuture().then([&](int) { return p3.getFuture(); }); }); - p1.emplaceValue(1); - p2.emplaceValue(1); - p3.emplaceValue(1); + auto pf1 = makePromiseFuture<int>(); + auto pf2 = makePromiseFuture<int>(); + auto pf3 = makePromiseFuture<int>(); + auto fut = std::move(pf1.future).then([&](int i) { + return std::move(pf2.future).then([&](int) { return std::move(pf3.future); }); + }); + pf1.promise.emplaceValue(1); + pf2.promise.emplaceValue(1); + pf3.promise.emplaceValue(1); benchmark::DoNotOptimize(std::move(fut).get()); } } @@ -159,15 +145,15 @@ void BM_futureInt3xDeferredThenNested(benchmark::State& state) { void BM_futureInt3xDeferredThenChained(benchmark::State& state) { for (auto _ : state) { benchmark::ClobberMemory(); - Promise<int> p1; - Promise<int> p2; - Promise<int> p3; - auto fut = p1.getFuture().then([&](int i) { return p2.getFuture(); }).then([&](int i) { - return p3.getFuture(); - }); - p1.emplaceValue(1); - 
p2.emplaceValue(1); - p3.emplaceValue(1); + auto pf1 = makePromiseFuture<int>(); + auto pf2 = makePromiseFuture<int>(); + auto pf3 = makePromiseFuture<int>(); + auto fut = std::move(pf1.future) + .then([&](int i) { return std::move(pf2.future); }) + .then([&](int i) { return std::move(pf3.future); }); + pf1.promise.emplaceValue(1); + pf2.promise.emplaceValue(1); + pf3.promise.emplaceValue(1); benchmark::DoNotOptimize(std::move(fut).get()); } } @@ -176,18 +162,19 @@ void BM_futureInt3xDeferredThenChained(benchmark::State& state) { void BM_futureInt4xDeferredThenNested(benchmark::State& state) { for (auto _ : state) { benchmark::ClobberMemory(); - Promise<int> p1; - Promise<int> p2; - Promise<int> p3; - Promise<int> p4; - auto fut = p1.getFuture().then([&](int i) { - return p2.getFuture().then( - [&](int) { return p3.getFuture().then([&](int) { return p4.getFuture(); }); }); + auto pf1 = makePromiseFuture<int>(); + auto pf2 = makePromiseFuture<int>(); + auto pf3 = makePromiseFuture<int>(); + auto pf4 = makePromiseFuture<int>(); + auto fut = std::move(pf1.future).then([&](int i) { + return std::move(pf2.future).then([&](int) { + return std::move(pf3.future).then([&](int) { return std::move(pf4.future); }); + }); }); - p1.emplaceValue(1); - p2.emplaceValue(1); - p3.emplaceValue(1); - p4.emplaceValue(1); + pf1.promise.emplaceValue(1); + pf2.promise.emplaceValue(1); + pf3.promise.emplaceValue(1); + pf4.promise.emplaceValue(1); benchmark::DoNotOptimize(std::move(fut).get()); } } @@ -195,18 +182,18 @@ void BM_futureInt4xDeferredThenNested(benchmark::State& state) { void BM_futureInt4xDeferredThenChained(benchmark::State& state) { for (auto _ : state) { benchmark::ClobberMemory(); - Promise<int> p1; - Promise<int> p2; - Promise<int> p3; - Promise<int> p4; - auto fut = p1.getFuture() // - .then([&](int i) { return p2.getFuture(); }) - .then([&](int i) { return p3.getFuture(); }) - .then([&](int i) { return p4.getFuture(); }); - p1.emplaceValue(1); - p2.emplaceValue(1); - p3.emplaceValue(1); - p4.emplaceValue(1); + auto pf1 = makePromiseFuture<int>(); + auto pf2 = makePromiseFuture<int>(); + auto pf3 = makePromiseFuture<int>(); + auto pf4 = makePromiseFuture<int>(); + auto fut = std::move(pf1.future) // + .then([&](int i) { return std::move(pf2.future); }) + .then([&](int i) { return std::move(pf3.future); }) + .then([&](int i) { return std::move(pf4.future); }); + pf1.promise.emplaceValue(1); + pf2.promise.emplaceValue(1); + pf3.promise.emplaceValue(1); + pf4.promise.emplaceValue(1); benchmark::DoNotOptimize(std::move(fut).get()); } } @@ -217,7 +204,6 @@ BENCHMARK(BM_futureIntReady); BENCHMARK(BM_futureIntReadyThen); BENCHMARK(BM_futureIntReadyWithPromise); BENCHMARK(BM_futureIntReadyWithPromiseThen); -BENCHMARK(BM_futureIntReadyWithPromise2); BENCHMARK(BM_futureIntDeferredThen); BENCHMARK(BM_futureIntDeferredThenImmediate); BENCHMARK(BM_futureIntDeferredThenReady); diff --git a/src/mongo/util/future_test.cpp b/src/mongo/util/future_test.cpp index 2a755460ce3..ffeb101cbad 100644 --- a/src/mongo/util/future_test.cpp +++ b/src/mongo/util/future_test.cpp @@ -110,10 +110,9 @@ void FUTURE_SUCCESS_TEST(const CompletionFunc& completion, const TestFunc& test) test(Future<CompletionType>::makeReady(completion())); } { // ready future from promise - Promise<CompletionType> promise; - auto fut = promise.getFuture(); // before setting value to bypass opt to immediate - promise.emplaceValue(completion()); - test(std::move(fut)); + auto pf = makePromiseFuture<CompletionType>(); + 
pf.promise.emplaceValue(completion()); + test(std::move(pf.future)); } { // async future @@ -132,11 +131,10 @@ void FUTURE_SUCCESS_TEST(const CompletionFunc& completion, const TestFunc& test) test(Future<CompletionType>::makeReady()); } { // ready future from promise - Promise<CompletionType> promise; - auto fut = promise.getFuture(); // before setting value to bypass opt to immediate + auto pf = makePromiseFuture<CompletionType>(); completion(); - promise.emplaceValue(); - test(std::move(fut)); + pf.promise.emplaceValue(); + test(std::move(pf.future)); } { // async future @@ -150,10 +148,9 @@ void FUTURE_FAIL_TEST(const TestFunc& test) { test(Future<CompletionType>::makeReady(failStatus)); } { // ready future from promise - Promise<CompletionType> promise; - auto fut = promise.getFuture(); // before setting value to bypass opt to immediate - promise.setError(failStatus); - test(std::move(fut)); + auto pf = makePromiseFuture<CompletionType>(); + pf.promise.setError(failStatus); + test(std::move(pf.future)); } { // async future @@ -196,13 +193,12 @@ TEST(Future, Success_getAsync) { FUTURE_SUCCESS_TEST( [] { return 1; }, [](Future<int>&& fut) { - auto outside = Promise<int>(); - auto outsideFut = outside.getFuture(); - std::move(fut).getAsync([outside = outside.share()](StatusWith<int> sw) mutable { + auto pf = makePromiseFuture<int>(); + std::move(fut).getAsync([outside = pf.promise.share()](StatusWith<int> sw) mutable { ASSERT_OK(sw); outside.emplaceValue(sw.getValue()); }); - ASSERT_EQ(std::move(outsideFut).get(), 1); + ASSERT_EQ(std::move(pf.future).get(), 1); }); } @@ -234,13 +230,12 @@ TEST(Future, Fail_getNothrowRvalue) { TEST(Future, Fail_getAsync) { FUTURE_FAIL_TEST<int>([](Future<int>&& fut) { - auto outside = Promise<int>(); - auto outsideFut = outside.getFuture(); - std::move(fut).getAsync([outside = outside.share()](StatusWith<int> sw) mutable { + auto pf = makePromiseFuture<int>(); + std::move(fut).getAsync([outside = pf.promise.share()](StatusWith<int> sw) mutable { ASSERT(!sw.isOK()); outside.setError(sw.getStatus()); }); - ASSERT_EQ(std::move(outsideFut).getNoThrow(), failStatus); + ASSERT_EQ(std::move(pf.future).getNoThrow(), failStatus); }); } @@ -357,10 +352,9 @@ TEST(Future, Success_thenFutureReady) { [](Future<int>&& fut) { ASSERT_EQ(std::move(fut) .then([](int i) { - Promise<int> promise; - auto fut = promise.getFuture(); - promise.emplaceValue(i + 2); - return fut; + auto pf = makePromiseFuture<int>(); + pf.promise.emplaceValue(i + 2); + return std::move(pf.future); }) .get(), 3); @@ -490,10 +484,9 @@ TEST(Future, Fail_onErrorFutureReady) { ASSERT_EQ(std::move(fut) .onError([](Status s) { ASSERT_EQ(s, failStatus); - Promise<int> promise; - auto fut = promise.getFuture(); - promise.emplaceValue(3); - return fut; + auto pf = makePromiseFuture<int>(); + pf.promise.emplaceValue(3); + return std::move(pf.future); }) .get(), 3); @@ -744,13 +737,12 @@ TEST(Future_Void, Success_getAsync) { FUTURE_SUCCESS_TEST( [] {}, [](Future<void>&& fut) { - auto outside = Promise<void>(); - auto outsideFut = outside.getFuture(); - std::move(fut).getAsync([outside = outside.share()](Status status) mutable { + auto pf = makePromiseFuture<void>(); + std::move(fut).getAsync([outside = pf.promise.share()](Status status) mutable { ASSERT_OK(status); outside.emplaceValue(); }); - ASSERT_EQ(std::move(outsideFut).getNoThrow(), Status::OK()); + ASSERT_EQ(std::move(pf.future).getNoThrow(), Status::OK()); }); } @@ -783,13 +775,12 @@ TEST(Future_Void, Fail_getNothrowRvalue) { TEST(Future_Void, 
Fail_getAsync) { FUTURE_FAIL_TEST<void>([](Future<void>&& fut) { - auto outside = Promise<void>(); - auto outsideFut = outside.getFuture(); - std::move(fut).getAsync([outside = outside.share()](Status status) mutable { + auto pf = makePromiseFuture<void>(); + std::move(fut).getAsync([outside = pf.promise.share()](Status status) mutable { ASSERT(!status.isOK()); outside.setError(status); }); - ASSERT_EQ(std::move(outsideFut).getNoThrow(), failStatus); + ASSERT_EQ(std::move(pf.future).getNoThrow(), failStatus); }); } @@ -884,10 +875,9 @@ TEST(Future_Void, Success_thenFutureReady) { [](Future<void>&& fut) { ASSERT_EQ(std::move(fut) .then([]() { - Promise<int> promise; - auto fut = promise.getFuture(); - promise.emplaceValue(3); - return fut; + auto pf = makePromiseFuture<int>(); + pf.promise.emplaceValue(3); + return std::move(pf.future); }) .get(), 3); @@ -1000,10 +990,9 @@ TEST(Future_Void, Fail_onErrorFutureReady) { ASSERT_EQ(std::move(fut) .onError([](Status s) { ASSERT_EQ(s, failStatus); - Promise<void> promise; - auto fut = promise.getFuture(); - promise.emplaceValue(); - return fut; + auto pf = makePromiseFuture<void>(); + pf.promise.emplaceValue(); + return std::move(pf.future); }) .then([] { return 3; }) .get(), @@ -1284,13 +1273,12 @@ TEST(Future_MoveOnly, Success_getAsync) { FUTURE_SUCCESS_TEST( [] { return Widget(1); }, [](Future<Widget>&& fut) { - auto outside = Promise<Widget>(); - auto outsideFut = outside.getFuture(); - std::move(fut).getAsync([outside = outside.share()](StatusWith<Widget> sw) mutable { + auto pf = makePromiseFuture<Widget>(); + std::move(fut).getAsync([outside = pf.promise.share()](StatusWith<Widget> sw) mutable { ASSERT_OK(sw); outside.emplaceValue(std::move(sw.getValue())); }); - ASSERT_EQ(std::move(outsideFut).get(), 1); + ASSERT_EQ(std::move(pf.future).get(), 1); }); } @@ -1327,13 +1315,12 @@ TEST(Future_MoveOnly, Fail_getNothrowRvalue) { TEST(Future_MoveOnly, Fail_getAsync) { FUTURE_FAIL_TEST<Widget>([](Future<Widget>&& fut) { - auto outside = Promise<Widget>(); - auto outsideFut = outside.getFuture(); - std::move(fut).getAsync([outside = outside.share()](StatusWith<Widget> sw) mutable { + auto pf = makePromiseFuture<Widget>(); + std::move(fut).getAsync([outside = pf.promise.share()](StatusWith<Widget> sw) mutable { ASSERT(!sw.isOK()); outside.setError(sw.getStatus()); }); - ASSERT_EQ(std::move(outsideFut).getNoThrow(), failStatus); + ASSERT_EQ(std::move(pf.future).getNoThrow(), failStatus); }); } @@ -1413,10 +1400,9 @@ TEST(Future_MoveOnly, Success_thenFutureReady) { [](Future<Widget>&& fut) { ASSERT_EQ(std::move(fut) .then([](Widget i) { - Promise<Widget> promise; - auto fut = promise.getFuture(); - promise.emplaceValue(i + 2); - return fut; + auto pf = makePromiseFuture<Widget>(); + pf.promise.emplaceValue(i + 2); + return std::move(pf.future); }) .get(), 3); @@ -1548,10 +1534,9 @@ TEST(Future_MoveOnly, Fail_onErrorFutureReady) { ASSERT_EQ(std::move(fut) .onError([](Status s) { ASSERT_EQ(s, failStatus); - Promise<Widget> promise; - auto fut = promise.getFuture(); - promise.emplaceValue(3); - return fut; + auto pf = makePromiseFuture<Widget>(); + pf.promise.emplaceValue(3); + return std::move(pf.future); }) .get(), 3); @@ -1760,119 +1745,119 @@ DEATH_TEST(Future_EdgeCases, Success_getAsync_throw, "terminate() called") { TEST(Promise, Success_setFrom) { FUTURE_SUCCESS_TEST([] { return 1; }, [](Future<int>&& fut) { - Promise<int> p; - p.setFrom(std::move(fut)); - ASSERT_EQ(p.getFuture().get(), 1); + auto pf = makePromiseFuture<int>(); + 
pf.promise.setFrom(std::move(fut)); + ASSERT_EQ(std::move(pf.future).get(), 1); }); } TEST(Promise, Fail_setFrom) { FUTURE_FAIL_TEST<int>([](Future<int>&& fut) { - Promise<int> p; - p.setFrom(std::move(fut)); - ASSERT_THROWS_failStatus(p.getFuture().get()); + auto pf = makePromiseFuture<int>(); + pf.promise.setFrom(std::move(fut)); + ASSERT_THROWS_failStatus(std::move(pf.future).get()); }); } TEST(Promise, Success_setWith_value) { - Promise<int> p; - p.setWith([&] { return 1; }); - ASSERT_EQ(p.getFuture().get(), 1); + auto pf = makePromiseFuture<int>(); + pf.promise.setWith([&] { return 1; }); + ASSERT_EQ(std::move(pf.future).get(), 1); } TEST(Promise, Fail_setWith_throw) { - Promise<int> p; - p.setWith([&] { + auto pf = makePromiseFuture<int>(); + pf.promise.setWith([&] { uassertStatusOK(failStatus); return 1; }); - ASSERT_THROWS_failStatus(p.getFuture().get()); + ASSERT_THROWS_failStatus(std::move(pf.future).get()); } TEST(Promise, Success_setWith_StatusWith) { - Promise<int> p; - p.setWith([&] { return StatusWith<int>(1); }); - ASSERT_EQ(p.getFuture().get(), 1); + auto pf = makePromiseFuture<int>(); + pf.promise.setWith([&] { return StatusWith<int>(1); }); + ASSERT_EQ(std::move(pf.future).get(), 1); } TEST(Promise, Fail_setWith_StatusWith) { - Promise<int> p; - p.setWith([&] { return StatusWith<int>(failStatus); }); - ASSERT_THROWS_failStatus(p.getFuture().get()); + auto pf = makePromiseFuture<int>(); + pf.promise.setWith([&] { return StatusWith<int>(failStatus); }); + ASSERT_THROWS_failStatus(std::move(pf.future).get()); } TEST(Promise, Success_setWith_Future) { FUTURE_SUCCESS_TEST([] { return 1; }, [](Future<int>&& fut) { - Promise<int> p; - p.setWith([&] { return std::move(fut); }); - ASSERT_EQ(p.getFuture().get(), 1); + auto pf = makePromiseFuture<int>(); + pf.promise.setWith([&] { return std::move(fut); }); + ASSERT_EQ(std::move(pf.future).get(), 1); }); } TEST(Promise, Fail_setWith_Future) { FUTURE_FAIL_TEST<int>([](Future<int>&& fut) { - Promise<int> p; - p.setWith([&] { return std::move(fut); }); - ASSERT_THROWS_failStatus(p.getFuture().get()); + auto pf = makePromiseFuture<int>(); + pf.promise.setWith([&] { return std::move(fut); }); + ASSERT_THROWS_failStatus(std::move(pf.future).get()); }); } TEST(Promise_void, Success_setFrom) { FUTURE_SUCCESS_TEST([] {}, [](Future<void>&& fut) { - Promise<void> p; - p.setFrom(std::move(fut)); - ASSERT_OK(p.getFuture().getNoThrow()); + auto pf = makePromiseFuture<void>(); + pf.promise.setFrom(std::move(fut)); + ASSERT_OK(std::move(pf.future).getNoThrow()); }); } TEST(Promise_void, Fail_setFrom) { FUTURE_FAIL_TEST<void>([](Future<void>&& fut) { - Promise<void> p; - p.setFrom(std::move(fut)); - ASSERT_THROWS_failStatus(p.getFuture().get()); + auto pf = makePromiseFuture<void>(); + pf.promise.setFrom(std::move(fut)); + ASSERT_THROWS_failStatus(std::move(pf.future).get()); }); } TEST(Promise_void, Success_setWith_value) { - Promise<void> p; - p.setWith([&] {}); - ASSERT_OK(p.getFuture().getNoThrow()); + auto pf = makePromiseFuture<void>(); + pf.promise.setWith([&] {}); + ASSERT_OK(std::move(pf.future).getNoThrow()); } TEST(Promise_void, Fail_setWith_throw) { - Promise<void> p; - p.setWith([&] { uassertStatusOK(failStatus); }); - ASSERT_THROWS_failStatus(p.getFuture().get()); + auto pf = makePromiseFuture<void>(); + pf.promise.setWith([&] { uassertStatusOK(failStatus); }); + ASSERT_THROWS_failStatus(std::move(pf.future).get()); } TEST(Promise_void, Success_setWith_Status) { - Promise<void> p; - p.setWith([&] { return Status::OK(); }); - 
ASSERT_OK(p.getFuture().getNoThrow()); + auto pf = makePromiseFuture<void>(); + pf.promise.setWith([&] { return Status::OK(); }); + ASSERT_OK(std::move(pf.future).getNoThrow()); } TEST(Promise_void, Fail_setWith_Status) { - Promise<void> p; - p.setWith([&] { return failStatus; }); - ASSERT_THROWS_failStatus(p.getFuture().get()); + auto pf = makePromiseFuture<void>(); + pf.promise.setWith([&] { return failStatus; }); + ASSERT_THROWS_failStatus(std::move(pf.future).get()); } TEST(Promise_void, Success_setWith_Future) { FUTURE_SUCCESS_TEST([] {}, [](Future<void>&& fut) { - Promise<void> p; - p.setWith([&] { return std::move(fut); }); - ASSERT_OK(p.getFuture().getNoThrow()); + auto pf = makePromiseFuture<void>(); + pf.promise.setWith([&] { return std::move(fut); }); + ASSERT_OK(std::move(pf.future).getNoThrow()); }); } TEST(Promise_void, Fail_setWith_Future) { FUTURE_FAIL_TEST<void>([](Future<void>&& fut) { - Promise<void> p; - p.setWith([&] { return std::move(fut); }); - ASSERT_THROWS_failStatus(p.getFuture().get()); + auto pf = makePromiseFuture<void>(); + pf.promise.setWith([&] { return std::move(fut); }); + ASSERT_THROWS_failStatus(std::move(pf.future).get()); }); } diff --git a/src/mongo/util/keyed_executor.h b/src/mongo/util/keyed_executor.h index c0d25110d44..ba0fc28ee87 100644 --- a/src/mongo/util/keyed_executor.h +++ b/src/mongo/util/keyed_executor.h @@ -157,6 +157,8 @@ public: promise.emplaceValue(); } + explicit Latch(Promise<void> p) : promise(std::move(p)) {} + Promise<void> promise; }; @@ -167,9 +169,10 @@ public: return Future<void>::makeReady(); } + auto pf = makePromiseFuture<void>(); // We rely on shared_ptr to handle the atomic refcounting before emplacing for us. - auto latch = std::make_shared<Latch>(); - auto future = latch->promise.getFuture(); + auto latch = std::make_shared<Latch>(std::move(pf.promise)); + auto future = std::move(pf.future); for (auto& pair : _map) { _onCleared(lk, pair.second).getAsync([latch](const Status& status) mutable {
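
The timer rewrite in transport_layer_asio.cpp collapses _asyncWait into a single bridge: create the promise and future together, let the timer's completion fulfill the promise, and return the future. Promise is move-only, so the callback captures a copyable SharedPromise obtained from share(). A minimal sketch of that shape, using only the API visible in this diff; forwardCompletion is a hypothetical name, not a function in the tree:

    // Forward one future's completion into a freshly paired promise/future.
    Future<void> forwardCompletion(Future<void>&& inner) {
        auto pf = makePromiseFuture<void>();
        std::move(inner).getAsync([sp = pf.promise.share()](Status status) mutable {
            if (status.isOK()) {
                sp.emplaceValue();  // readies the paired future
            } else {
                sp.setError(status);
            }
        });
        return std::move(pf.future);
    }

Because the pair shares one state from birth, the old TimerState generation counter and mutex are no longer needed to match a callback to its waiter.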
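
The new ASIOReactor::drain() leans on two asio::io_context behaviors: after stop(), poll() and run() return immediately until restart() is called, and poll() executes only handlers that are already ready, returning how many it ran. A standalone sketch of the same loop against a bare io_context, assuming an asio recent enough to spell restart() rather than reset(); drainContext is an illustrative name:

    #include <asio.hpp>

    void drainContext(asio::io_context& ctx) {
        ctx.restart();  // leave the stopped state so poll() does real work
        while (ctx.poll()) {
            // Each pass runs ready handlers without blocking; work they
            // schedule is picked up on the next pass.
        }
        ctx.stop();  // return to the stopped state once the queue is empty
    }

This is the hook that lets NetworkInterfaceTL flush ready callbacks on shutdown rather than abandon them.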
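
With getFuture() made private, makePromiseFuture<T>() is the only way to mint a promise/future pair, which is what lets the haveSetValue/haveExtractedFuture flags disappear from Promise. The resulting producer pattern, sketched with only calls this diff already uses; produceAsync is a hypothetical stand-in:

    Future<int> produceAsync() {
        auto pf = makePromiseFuture<int>();
        // Hand pf.promise to whatever will eventually fulfill it; the
        // sketch fulfills it inline.
        pf.promise.emplaceValue(42);
        return std::move(pf.future);
    }

A related consequence shows in the destructor hunk: a Promise destroyed without being completed now unconditionally reports ErrorCodes::BrokenPromise, since a paired Future is guaranteed to exist.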
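
The keyed_executor.h hunk threads an externally created promise into the shared Latch; the surrounding context ("promise.emplaceValue(); }") suggests the latch completes the promise on destruction, so the future becomes ready exactly when the last callback releases its shared_ptr copy. A condensed sketch of the pattern under that assumption; whenAllSettled is a hypothetical name:

    struct Latch {
        explicit Latch(Promise<void> p) : promise(std::move(p)) {}
        ~Latch() {
            promise.emplaceValue();  // readies the paired future
        }
        Promise<void> promise;
    };

    Future<void> whenAllSettled(std::vector<Future<void>> futures) {
        auto pf = makePromiseFuture<void>();
        auto latch = std::make_shared<Latch>(std::move(pf.promise));
        for (auto& fut : futures) {
            // Each callback keeps the latch alive; the last one to finish
            // destroys it.
            std::move(fut).getAsync([latch](const Status&) mutable {});
        }
        return std::move(pf.future);
    }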