-rw-r--r--   src/mongo/executor/connection_pool.cpp                        |  57
-rw-r--r--   src/mongo/executor/connection_pool.h                          |  17
-rw-r--r--   src/mongo/executor/connection_pool_test_fixture.cpp           |  10
-rw-r--r--   src/mongo/executor/connection_pool_test_fixture.h             |   6
-rw-r--r--   src/mongo/executor/connection_pool_tl.cpp                     |  81
-rw-r--r--   src/mongo/executor/connection_pool_tl.h                       |  77
-rw-r--r--   src/mongo/executor/network_interface_integration_fixture.cpp |   5
-rw-r--r--   src/mongo/executor/network_interface_integration_test.cpp    |   9
-rw-r--r--   src/mongo/executor/network_interface_tl.cpp                   |  30
-rw-r--r--   src/mongo/executor/network_interface_tl.h                     |   1
-rw-r--r--   src/mongo/transport/service_executor_test.cpp                 |   8
-rw-r--r--   src/mongo/transport/transport_layer.h                         |   1
-rw-r--r--   src/mongo/transport/transport_layer_asio.cpp                  | 107
13 files changed, 259 insertions, 150 deletions
diff --git a/src/mongo/executor/connection_pool.cpp b/src/mongo/executor/connection_pool.cpp
index 9bfbdd459a3..337e0b52c49 100644
--- a/src/mongo/executor/connection_pool.cpp
+++ b/src/mongo/executor/connection_pool.cpp
@@ -59,7 +59,8 @@ namespace executor {
  * go out of existence after hostTimeout passes without any of their
  * connections being used.
  */
-class ConnectionPool::SpecificPool {
+class ConnectionPool::SpecificPool final
+    : public std::enable_shared_from_this<ConnectionPool::SpecificPool> {
 public:
     /**
      * These active client methods must be used whenever entering a specific pool outside of the
@@ -219,7 +220,7 @@ private:

     std::vector<Request> _requests;

-    std::unique_ptr<TimerInterface> _requestTimer;
+    std::shared_ptr<TimerInterface> _requestTimer;
     Date_t _requestTimerExpiration;
     size_t _activeClients;
     size_t _generation;
@@ -266,7 +267,7 @@ constexpr Milliseconds ConnectionPool::kDefaultRefreshTimeout;
 const Status ConnectionPool::kConnectionStateUnknown =
     Status(ErrorCodes::InternalError, "Connection is in an unknown state");

-ConnectionPool::ConnectionPool(std::unique_ptr<DependentTypeFactoryInterface> impl,
+ConnectionPool::ConnectionPool(std::shared_ptr<DependentTypeFactoryInterface> impl,
                                std::string name,
                                Options options)
     : _name(std::move(name)),
@@ -289,6 +290,8 @@ ConnectionPool::~ConnectionPool() {
 }

 void ConnectionPool::shutdown() {
+    _factory->shutdown();
+
     std::vector<SpecificPool*> pools;

     // Ensure we decrement active clients for all pools that we inc on (because we intend to process
@@ -515,6 +518,7 @@ void ConnectionPool::SpecificPool::returnConnection(ConnectionInterface* connPtr
     auto needsRefreshTP = connPtr->getLastUsed() + _parent->_options.refreshRequirement;

     auto conn = takeFromPool(_checkedOutPool, connPtr);
+    invariant(conn);

     updateStateInLock();

@@ -611,29 +615,27 @@ void ConnectionPool::SpecificPool::addToReady(stdx::unique_lock<stdx::mutex>& lk
     // Our strategy for refreshing connections is to check them out and
     // immediately check them back in (which kicks off the refresh logic in
     // returnConnection
-    connPtr->setTimeout(_parent->_options.refreshRequirement, [this, connPtr]() {
-        OwnedConnection conn;
-
-        runWithActiveClient([&](stdx::unique_lock<stdx::mutex> lk) {
-            if (!_readyPool.count(connPtr)) {
-                // We've already been checked out. We don't need to refresh
-                // ourselves.
-                return;
-            }
+    connPtr->setTimeout(_parent->_options.refreshRequirement,
+                        [ this, connPtr, anchor = shared_from_this() ]() {
+                            runWithActiveClient([&](stdx::unique_lock<stdx::mutex> lk) {
+                                auto conn = takeFromPool(_readyPool, connPtr);

-            conn = takeFromPool(_readyPool, connPtr);
+                                // We've already been checked out. We don't need to refresh
+                                // ourselves.
+                                if (!conn)
+                                    return;

-            // If we're in shutdown, we don't need to refresh connections
-            if (_state == State::kInShutdown)
-                return;
+                                // If we're in shutdown, we don't need to refresh connections
+                                if (_state == State::kInShutdown)
+                                    return;

-            _checkedOutPool[connPtr] = std::move(conn);
+                                _checkedOutPool[connPtr] = std::move(conn);

-            connPtr->indicateSuccess();
+                                connPtr->indicateSuccess();

-            returnConnection(connPtr, std::move(lk));
-        });
-    });
+                                returnConnection(connPtr, std::move(lk));
+                            });
+                        });

     fulfillRequests(lk);
 }
@@ -764,14 +766,13 @@ void ConnectionPool::SpecificPool::spawnConnections(stdx::unique_lock<stdx::mute
         fassertFailed(40336);
     }

-    auto connPtr = handle.get();
-    _processingPool[connPtr] = std::move(handle);
+    _processingPool[handle.get()] = handle;

     ++_created;

     // Run the setup callback
     lk.unlock();
-    connPtr->setup(
+    handle->setup(
         _parent->_options.refreshTimeout, [this](ConnectionInterface* connPtr, Status status) {
             runWithActiveClient([&](stdx::unique_lock<stdx::mutex> lk) {
                 auto conn = takeFromProcessingPool(connPtr);
@@ -844,7 +845,8 @@ template <typename OwnershipPoolType>
 typename OwnershipPoolType::mapped_type ConnectionPool::SpecificPool::takeFromPool(
     OwnershipPoolType& pool, typename OwnershipPoolType::key_type connPtr) {
     auto iter = pool.find(connPtr);
-    invariant(iter != pool.end());
+    if (iter == pool.end())
+        return typename OwnershipPoolType::mapped_type();

     auto conn = std::move(iter->second);
     pool.erase(iter);
@@ -853,8 +855,9 @@ typename OwnershipPoolType::mapped_type ConnectionPool::SpecificPool::takeFromPo

 ConnectionPool::SpecificPool::OwnedConnection ConnectionPool::SpecificPool::takeFromProcessingPool(
     ConnectionInterface* connPtr) {
-    if (_processingPool.count(connPtr))
-        return takeFromPool(_processingPool, connPtr);
+    auto conn = takeFromPool(_processingPool, connPtr);
+    if (conn)
+        return conn;

     return takeFromPool(_droppedProcessingPool, connPtr);
 }
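The connection_pool.cpp changes above lean on one idiom: a deferred callback captures a shared_ptr to the object that scheduled it (the anchor = shared_from_this() in the refresh timeout), so the pool cannot be destroyed while the callback is still pending; that is also why SpecificPool is now held by shared_ptr in the header diff below. A minimal self-contained sketch of the idiom, assuming only the standard library; Pool, scheduleAfter, and refresh are illustrative names, not the ConnectionPool API.

#include <chrono>
#include <functional>
#include <iostream>
#include <memory>
#include <thread>

// Stand-in for a reactor that runs deferred work on another thread.
void scheduleAfter(std::chrono::milliseconds delay, std::function<void()> task) {
    std::thread([delay, task = std::move(task)] {
        std::this_thread::sleep_for(delay);
        task();  // May run after the caller has dropped its last reference.
    }).detach();
}

class Pool : public std::enable_shared_from_this<Pool> {
public:
    void scheduleRefresh() {
        // Capturing `anchor` keeps this Pool alive until the callback has run,
        // even if every external shared_ptr is released in the meantime.
        scheduleAfter(std::chrono::milliseconds(50),
                      [this, anchor = shared_from_this()] { refresh(); });
    }

    ~Pool() {
        std::cout << "Pool destroyed\n";
    }

private:
    void refresh() {
        std::cout << "refresh ran safely\n";
    }
};

int main() {
    std::make_shared<Pool>()->scheduleRefresh();  // the caller's reference dies immediately
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
}

The detached thread only stands in for the reactor; the point is that the anchor keeps the object alive until the last pending callback has run, which is the guarantee the refresh timeout relies on.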
diff --git a/src/mongo/executor/connection_pool.h b/src/mongo/executor/connection_pool.h
index 118f734920d..a6539cf8beb 100644
--- a/src/mongo/executor/connection_pool.h
+++ b/src/mongo/executor/connection_pool.h
@@ -131,7 +131,7 @@ public:
         EgressTagCloserManager* egressTagCloserManager = nullptr;
     };

-    explicit ConnectionPool(std::unique_ptr<DependentTypeFactoryInterface> impl,
+    explicit ConnectionPool(std::shared_ptr<DependentTypeFactoryInterface> impl,
                             std::string name,
                             Options options = Options{});

@@ -163,11 +163,11 @@ private:
     // accessed outside the lock
     const Options _options;

-    const std::unique_ptr<DependentTypeFactoryInterface> _factory;
+    const std::shared_ptr<DependentTypeFactoryInterface> _factory;

     // The global mutex for specific pool access and the generation counter
     mutable stdx::mutex _mutex;
-    stdx::unordered_map<HostAndPort, std::unique_ptr<SpecificPool>> _pools;
+    stdx::unordered_map<HostAndPort, std::shared_ptr<SpecificPool>> _pools;

     EgressTagCloserManager* _manager;
 };
@@ -220,9 +220,7 @@ public:
  * specifically callbacks to set them up (connect + auth + whatever else),
  * refresh them (issue some kind of ping) and manage a timer.
  */
-class ConnectionPool::ConnectionInterface
-    : public TimerInterface,
-      public std::enable_shared_from_this<ConnectionPool::ConnectionInterface> {
+class ConnectionPool::ConnectionInterface : public TimerInterface {
     MONGO_DISALLOW_COPYING(ConnectionInterface);

     friend class ConnectionPool;
@@ -336,12 +334,17 @@ public:
     /**
      * Makes a new timer
      */
-    virtual std::unique_ptr<TimerInterface> makeTimer() = 0;
+    virtual std::shared_ptr<TimerInterface> makeTimer() = 0;

     /**
      * Returns the current time point
      */
     virtual Date_t now() = 0;
+
+    /**
+     * shutdown
+     */
+    virtual void shutdown() = 0;
 };

 }  // namespace executor
diff --git a/src/mongo/executor/connection_pool_test_fixture.cpp b/src/mongo/executor/connection_pool_test_fixture.cpp
index 560f184a025..77629b0fa94 100644
--- a/src/mongo/executor/connection_pool_test_fixture.cpp
+++ b/src/mongo/executor/connection_pool_test_fixture.cpp
@@ -42,6 +42,8 @@ TimerImpl::~TimerImpl() {
 }

 void TimerImpl::setTimeout(Milliseconds timeout, TimeoutCallback cb) {
+    _timers.erase(this);
+
     _cb = std::move(cb);
     _expiration = _global->now() + timeout;

@@ -50,10 +52,14 @@ void TimerImpl::setTimeout(Milliseconds timeout, TimeoutCallback cb) {

 void TimerImpl::cancelTimeout() {
     _timers.erase(this);
+    _cb = TimeoutCallback{};
 }

 void TimerImpl::clear() {
-    _timers.clear();
+    while (!_timers.empty()) {
+        auto* timer = *_timers.begin();
+        timer->cancelTimeout();
+    }
 }

 void TimerImpl::fireIfNecessary() {
@@ -233,7 +239,7 @@ std::shared_ptr<ConnectionPool::ConnectionInterface> PoolImpl::makeConnection(
     return std::make_shared<ConnectionImpl>(hostAndPort, generation, this);
 }

-std::unique_ptr<ConnectionPool::TimerInterface> PoolImpl::makeTimer() {
+std::shared_ptr<ConnectionPool::TimerInterface> PoolImpl::makeTimer() {
     return stdx::make_unique<TimerImpl>(this);
 }

diff --git a/src/mongo/executor/connection_pool_test_fixture.h b/src/mongo/executor/connection_pool_test_fixture.h
index f12eccf6c34..a66492d555b 100644
--- a/src/mongo/executor/connection_pool_test_fixture.h
+++ b/src/mongo/executor/connection_pool_test_fixture.h
@@ -150,10 +150,14 @@ public:
     std::shared_ptr<ConnectionPool::ConnectionInterface> makeConnection(
         const HostAndPort& hostAndPort, size_t generation) override;

-    std::unique_ptr<ConnectionPool::TimerInterface> makeTimer() override;
+    std::shared_ptr<ConnectionPool::TimerInterface> makeTimer() override;

     Date_t now() override;

+    void shutdown() override {
+        TimerImpl::clear();
+    };
+
     /**
      * setNow() can be used to fire all timers that have passed a point in time
      */
diff --git a/src/mongo/executor/connection_pool_tl.cpp b/src/mongo/executor/connection_pool_tl.cpp
index 0c9cf2e0f76..5e79798c6aa 100644
--- a/src/mongo/executor/connection_pool_tl.cpp
+++ b/src/mongo/executor/connection_pool_tl.cpp
@@ -50,11 +50,55 @@ struct TimeoutHandler {

 }  // namespace

+void TLTypeFactory::shutdown() {
+    // Stop any attempt to schedule timers in the future
+    _inShutdown.store(true);
+
+    stdx::lock_guard<stdx::mutex> lk(_mutex);
+
+    log() << "Killing all outstanding egress activity.";
+    for (auto collar : _collars) {
+        collar->kill();
+    }
+}
+
+void TLTypeFactory::fasten(Type* type) {
+    stdx::lock_guard<stdx::mutex> lk(_mutex);
+    _collars.insert(type);
+}
+
+void TLTypeFactory::release(Type* type) {
+    stdx::lock_guard<stdx::mutex> lk(_mutex);
+    _collars.erase(type);
+
+    type->_wasReleased = true;
+}
+
+TLTypeFactory::Type::Type(const std::shared_ptr<TLTypeFactory>& factory) : _factory{factory} {}
+
+TLTypeFactory::Type::~Type() {
+    invariant(_wasReleased);
+}
+
+void TLTypeFactory::Type::release() {
+    _factory->release(this);
+}
+
+bool TLTypeFactory::inShutdown() const {
+    return _inShutdown.load();
+}
+
 void TLTimer::setTimeout(Milliseconds timeoutVal, TimeoutCallback cb) {
+    // We will not wait on a timeout if we are in shutdown.
+    // The clients will be canceled as an inevitable consequence of pools shutting down.
+    if (inShutdown()) {
+        LOG(2) << "Skipping timeout due to impending shutdown.";
+        return;
+    }
+
     _timer->waitFor(timeoutVal).getAsync([cb = std::move(cb)](Status status) {
-        // TODO: verify why we still get broken promises when expliciting call stop and shutting
-        // down NITL's quickly.
-        if (status == ErrorCodes::CallbackCanceled || status == ErrorCodes::BrokenPromise) {
+        // If we get canceled, then we don't worry about the timeout anymore
+        if (status == ErrorCodes::CallbackCanceled) {
             return;
         }
@@ -103,11 +147,12 @@ const Status& TLConnection::getStatus() const {
 }

 void TLConnection::setTimeout(Milliseconds timeout, TimeoutCallback cb) {
-    _timer.setTimeout(timeout, std::move(cb));
+    auto anchor = shared_from_this();
+    _timer->setTimeout(timeout, [ cb = std::move(cb), anchor = std::move(anchor) ] { cb(); });
 }

 void TLConnection::cancelTimeout() {
-    _timer.cancelTimeout();
+    _timer->cancelTimeout();
 }

@@ -116,7 +161,7 @@ void TLConnection::setup(Milliseconds timeout, SetupCallback cb) {
     auto pf = makePromiseFuture<void>();
     auto handler = std::make_shared<TimeoutHandler>(std::move(pf.promise));
     std::move(pf.future).getAsync(
-        [ this, cb = std::move(cb) ](Status status) { cb(this, std::move(status)); });
+        [ this, cb = std::move(cb), anchor ](Status status) { cb(this, std::move(status)); });

     log() << "Connecting to " << _peer;
     setTimeout(timeout, [this, handler, timeout] {
@@ -169,6 +214,7 @@ void TLConnection::setup(Milliseconds timeout, SetupCallback cb) {
             handler->promise.setError(status);
         }
     });
+    LOG(2) << "Finished connection setup.";
 }

 void TLConnection::resetToUnknown() {
@@ -181,7 +227,7 @@ void TLConnection::refresh(Milliseconds timeout, RefreshCallback cb) {
     auto pf = makePromiseFuture<void>();
     auto handler = std::make_shared<TimeoutHandler>(std::move(pf.promise));
     std::move(pf.future).getAsync(
-        [ this, cb = std::move(cb) ](Status status) { cb(this, status); });
+        [ this, cb = std::move(cb), anchor ](Status status) { cb(this, status); });

     setTimeout(timeout, [this, handler] {
         if (handler->done.swap(true)) {
@@ -220,14 +266,27 @@ size_t TLConnection::getGeneration() const {
     return _generation;
 }

+void TLConnection::cancelAsync() {
+    if (_client)
+        _client->cancel();
+}
+
 std::shared_ptr<ConnectionPool::ConnectionInterface> TLTypeFactory::makeConnection(
     const HostAndPort& hostAndPort, size_t generation) {
-    return std::make_shared<TLConnection>(
-        _reactor, getGlobalServiceContext(), hostAndPort, generation, _onConnectHook.get());
+    auto conn = std::make_shared<TLConnection>(shared_from_this(),
+                                               _reactor,
+                                               getGlobalServiceContext(),
+                                               hostAndPort,
+                                               generation,
+                                               _onConnectHook.get());
+    fasten(conn.get());
+    return conn;
 }

-std::unique_ptr<ConnectionPool::TimerInterface> TLTypeFactory::makeTimer() {
-    return std::make_unique<TLTimer>(_reactor);
+std::shared_ptr<ConnectionPool::TimerInterface> TLTypeFactory::makeTimer() {
+    auto timer = std::make_shared<TLTimer>(shared_from_this(), _reactor);
+    fasten(timer.get());
+    return timer;
 }

 Date_t TLTypeFactory::now() {
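The new TLTypeFactory code above implements a small lifetime registry: every timer and connection is a TLTypeFactory::Type that is fastened to the factory when created, releases itself when destroyed, and has its kill() called in bulk by shutdown(). A rough standalone sketch of that registry pattern, assuming only the standard library and using illustrative names (Factory, Resource, Timer); unlike the patch, which fastens objects from makeConnection()/makeTimer(), the sketch registers in the base-class constructor to stay short.

#include <cassert>
#include <iostream>
#include <memory>
#include <mutex>
#include <unordered_set>

class Factory;

// Base class for anything whose pending work must be killable at shutdown.
class Resource {
public:
    explicit Resource(std::shared_ptr<Factory> factory);
    virtual ~Resource();

    virtual void kill() = 0;

protected:
    void release();

private:
    friend class Factory;
    std::shared_ptr<Factory> _factory;
    bool _wasReleased = false;
};

class Factory {
public:
    void fasten(Resource* r) {
        std::lock_guard<std::mutex> lk(_mutex);
        _collars.insert(r);
    }

    void release(Resource* r) {
        std::lock_guard<std::mutex> lk(_mutex);
        _collars.erase(r);
        r->_wasReleased = true;
    }

    void shutdown() {
        std::lock_guard<std::mutex> lk(_mutex);
        for (auto* r : _collars) {
            r->kill();  // cancel pending work; the owners destroy the objects later
        }
    }

private:
    std::mutex _mutex;
    std::unordered_set<Resource*> _collars;
};

Resource::Resource(std::shared_ptr<Factory> factory) : _factory(std::move(factory)) {
    _factory->fasten(this);
}

Resource::~Resource() {
    assert(_wasReleased);  // mirrors invariant(_wasReleased) above
}

void Resource::release() {
    _factory->release(this);
}

// A concrete resource: kill() just reports that its work was canceled.
class Timer : public Resource {
public:
    using Resource::Resource;
    ~Timer() override {
        release();  // must run before the rest of the destructor
    }
    void kill() override {
        std::cout << "timer canceled\n";
    }
};

int main() {
    auto factory = std::make_shared<Factory>();
    Timer t{factory};
    factory->shutdown();  // cancels everything still fastened
}

The assert mirrors the invariant(_wasReleased) above: a derived destructor must call release() before its own members go away, so shutdown() can never kill() a half-destroyed object.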
diff --git a/src/mongo/executor/connection_pool_tl.h b/src/mongo/executor/connection_pool_tl.h
index 1e9e1c98604..5aff6c80c57 100644
--- a/src/mongo/executor/connection_pool_tl.h
+++ b/src/mongo/executor/connection_pool_tl.h
@@ -34,13 +34,17 @@
 #include "mongo/executor/connection_pool.h"
 #include "mongo/executor/network_connection_hook.h"
 #include "mongo/executor/network_interface.h"
+#include "mongo/util/future.h"

 namespace mongo {
 namespace executor {
 namespace connection_pool_tl {

-class TLTypeFactory final : public ConnectionPool::DependentTypeFactoryInterface {
+class TLTypeFactory final : public ConnectionPool::DependentTypeFactoryInterface,
+                            public std::enable_shared_from_this<TLTypeFactory> {
 public:
+    class Type;
+
     TLTypeFactory(transport::ReactorHandle reactor,
                   transport::TransportLayer* tl,
                   std::unique_ptr<NetworkConnectionHook> onConnectHook)
@@ -48,42 +52,91 @@ public:
     std::shared_ptr<ConnectionPool::ConnectionInterface> makeConnection(
         const HostAndPort& hostAndPort, size_t generation) override;
-    std::unique_ptr<ConnectionPool::TimerInterface> makeTimer() override;
+    std::shared_ptr<ConnectionPool::TimerInterface> makeTimer() override;

     Date_t now() override;

+    void shutdown() override;
+    bool inShutdown() const;
+    void fasten(Type* type);
+    void release(Type* type);
+
 private:
     transport::ReactorHandle _reactor;
     transport::TransportLayer* _tl;
     std::unique_ptr<NetworkConnectionHook> _onConnectHook;
+
+    mutable stdx::mutex _mutex;
+    AtomicBool _inShutdown{false};
+    stdx::unordered_set<Type*> _collars;
+};
+
+class TLTypeFactory::Type : public std::enable_shared_from_this<TLTypeFactory::Type> {
+    friend class TLTypeFactory;
+
+    MONGO_DISALLOW_COPYING(Type);
+
+public:
+    explicit Type(const std::shared_ptr<TLTypeFactory>& factory);
+    ~Type();
+
+    void release();
+    bool inShutdown() const {
+        return _factory->inShutdown();
+    }
+
+    virtual void kill() = 0;
+
+private:
+    std::shared_ptr<TLTypeFactory> _factory;
+    bool _wasReleased = false;
 };

-class TLTimer final : public ConnectionPool::TimerInterface {
+class TLTimer final : public ConnectionPool::TimerInterface, public TLTypeFactory::Type {
 public:
-    explicit TLTimer(const transport::ReactorHandle& reactor)
-        : _reactor(reactor), _timer(_reactor->makeTimer()) {}
+    explicit TLTimer(const std::shared_ptr<TLTypeFactory>& factory,
+                     const transport::ReactorHandle& reactor)
+        : TLTypeFactory::Type(factory), _reactor(reactor), _timer(_reactor->makeTimer()) {}
+    ~TLTimer() {
+        // Release must be the first expression of this dtor
+        release();
+    }
+
+    void kill() override {
+        cancelTimeout();
+    }

     void setTimeout(Milliseconds timeout, TimeoutCallback cb) override;
     void cancelTimeout() override;

 private:
     transport::ReactorHandle _reactor;
-    std::unique_ptr<transport::ReactorTimer> _timer;
+    std::shared_ptr<transport::ReactorTimer> _timer;
 };

-class TLConnection final : public ConnectionPool::ConnectionInterface {
+class TLConnection final : public ConnectionPool::ConnectionInterface, public TLTypeFactory::Type {
 public:
-    TLConnection(transport::ReactorHandle reactor,
+    TLConnection(const std::shared_ptr<TLTypeFactory>& factory,
+                 transport::ReactorHandle reactor,
                  ServiceContext* serviceContext,
                  HostAndPort peer,
                  size_t generation,
                  NetworkConnectionHook* onConnectHook)
-        : _reactor(reactor),
+        : TLTypeFactory::Type(factory),
+          _reactor(reactor),
           _serviceContext(serviceContext),
-          _timer(_reactor),
+          _timer(factory->makeTimer()),
           _peer(std::move(peer)),
           _generation(generation),
           _onConnectHook(onConnectHook) {}
+    ~TLConnection() {
+        // Release must be the first expression of this dtor
+        release();
+    }
+
+    void kill() override {
+        cancelAsync();
+    }

     void indicateSuccess() override;
     void indicateFailure(Status status) override;
@@ -101,13 +154,15 @@ private:
     void setup(Milliseconds timeout, SetupCallback cb) override;
     void resetToUnknown() override;
     void refresh(Milliseconds timeout, RefreshCallback cb) override;
+    void cancelAsync();

     size_t getGeneration() const override;

 private:
     transport::ReactorHandle _reactor;
     ServiceContext* const _serviceContext;
-    TLTimer _timer;
+    std::shared_ptr<ConnectionPool::TimerInterface> _timer;
+
     HostAndPort _peer;
     size_t _generation;
     NetworkConnectionHook* const _onConnectHook;
diff --git a/src/mongo/executor/network_interface_integration_fixture.cpp b/src/mongo/executor/network_interface_integration_fixture.cpp
index 6a593374fb6..4a694eec7d2 100644
--- a/src/mongo/executor/network_interface_integration_fixture.cpp
+++ b/src/mongo/executor/network_interface_integration_fixture.cpp
@@ -62,9 +62,8 @@ void NetworkInterfaceIntegrationFixture::startNet(
 }

 void NetworkInterfaceIntegrationFixture::tearDown() {
-    if (!_net->inShutdown()) {
-        _net->shutdown();
-    }
+    // Network interface will only shutdown once because of an internal shutdown guard
+    _net->shutdown();
 }

 NetworkInterface& NetworkInterfaceIntegrationFixture::net() {
diff --git a/src/mongo/executor/network_interface_integration_test.cpp b/src/mongo/executor/network_interface_integration_test.cpp
index 738c7606676..75f5811baca 100644
--- a/src/mongo/executor/network_interface_integration_test.cpp
+++ b/src/mongo/executor/network_interface_integration_test.cpp
@@ -93,12 +93,13 @@ class HangingHook : public executor::NetworkConnectionHook {
     }

     Status handleReply(const HostAndPort& remoteHost, RemoteCommandResponse&& response) final {
-        if (pingCommandMissing(response)) {
-            return {ErrorCodes::NetworkInterfaceExceededTimeLimit,
-                    "No ping command. Simulating timeout"};
+        if (!pingCommandMissing(response)) {
+            ASSERT_EQ(ErrorCodes::CallbackCanceled, response.status);
+            return response.status;
         }

-        MONGO_UNREACHABLE;
+        return {ErrorCodes::NetworkInterfaceExceededTimeLimit,
+                "No ping command. Simulating timeout"};
     }
 };

diff --git a/src/mongo/executor/network_interface_tl.cpp b/src/mongo/executor/network_interface_tl.cpp
index cffe9a74596..923d25f565e 100644
--- a/src/mongo/executor/network_interface_tl.cpp
+++ b/src/mongo/executor/network_interface_tl.cpp
@@ -101,17 +101,37 @@ void NetworkInterfaceTL::startup() {
         std::move(typeFactory), std::string("NetworkInterfaceTL-") + _instanceName, _connPoolOpts);
     _ioThread = stdx::thread([this] {
         setThreadName(_instanceName);
-        LOG(2) << "The NetworkInterfaceTL reactor thread is spinning up";
-        _reactor->run();
+        _run();
     });
 }

+void NetworkInterfaceTL::_run() {
+    LOG(2) << "The NetworkInterfaceTL reactor thread is spinning up";
+
+    // This returns when the reactor is stopped in shutdown()
+    _reactor->run();
+
+    // Note that the pool will shutdown again when the ConnectionPool dtor runs
+    // This prevents new timers from being set, calls all cancels via the factory registry, and
+    // destructs all connections for all existing pools.
+    _pool->shutdown();
+
+    // Close out all remaining tasks in the reactor now that they've all been canceled.
+    _reactor->drain();
+
+    LOG(2) << "NetworkInterfaceTL shutdown successfully";
+}
+
 void NetworkInterfaceTL::shutdown() {
-    _inShutdown.store(true);
+    if (_inShutdown.swap(true))
+        return;
+
+    LOG(2) << "Shutting down network interface.";
+
+    // Stop the reactor/thread first so that nothing runs on a partially dtor'd pool.
     _reactor->stop();
+
     _ioThread.join();
-    _pool->shutdown();
-    LOG(2) << "NetworkInterfaceTL shutdown successfully";
 }

 bool NetworkInterfaceTL::inShutdown() const {
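The reordered NetworkInterfaceTL shutdown above follows a fixed sequence: an atomic swap makes shutdown() idempotent, the reactor is stopped and the I/O thread joined, and the pool shutdown plus reactor drain happen at the tail of _run() on the I/O thread itself. A compressed sketch of that sequencing with standard-library stand-ins; Reactor and NetworkInterface here are illustrative, not the MongoDB classes.

#include <atomic>
#include <iostream>
#include <thread>

// Minimal stand-in for the reactor: run() blocks until stop(), drain() flushes leftovers.
struct Reactor {
    std::atomic<bool> stopped{false};
    void run() {
        while (!stopped.load()) {
            std::this_thread::yield();
        }
    }
    void stop() {
        stopped.store(true);
    }
    void drain() {
        std::cout << "drained remaining reactor work\n";
    }
};

class NetworkInterface {
public:
    void startup() {
        _ioThread = std::thread([this] { _run(); });
    }

    void shutdown() {
        if (_inShutdown.exchange(true))  // idempotent: only the first caller proceeds
            return;
        _reactor.stop();   // unblocks _run() on the I/O thread
        _ioThread.join();  // nothing can now touch a partially destroyed pool
    }

private:
    void _run() {
        _reactor.run();
        std::cout << "pool shutdown: cancel timers, kill connections\n";
        _reactor.drain();  // finish callbacks scheduled by those cancellations
    }

    Reactor _reactor;
    std::thread _ioThread;
    std::atomic<bool> _inShutdown{false};
};

int main() {
    NetworkInterface net;
    net.startup();
    net.shutdown();
    net.shutdown();  // second call is a no-op thanks to the exchange() guard
}

Running the pool shutdown on the I/O thread, after run() returns but before drain(), is what lets the canceled operations still find a live reactor to post their completions to.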
diff --git a/src/mongo/executor/network_interface_tl.h b/src/mongo/executor/network_interface_tl.h
index 603eb59ef8d..621336cd6e7 100644
--- a/src/mongo/executor/network_interface_tl.h
+++ b/src/mongo/executor/network_interface_tl.h
@@ -110,6 +110,7 @@ private:
         Promise<RemoteCommandResponse> promise;
     };

+    void _run();
     void _eraseInUseConn(const TaskExecutor::CallbackHandle& handle);
     Future<RemoteCommandResponse> _onAcquireConn(std::shared_ptr<CommandState> state,
                                                  Future<RemoteCommandResponse> future,
diff --git a/src/mongo/transport/service_executor_test.cpp b/src/mongo/transport/service_executor_test.cpp
index 86e7187a37f..a3fafd88ef4 100644
--- a/src/mongo/transport/service_executor_test.cpp
+++ b/src/mongo/transport/service_executor_test.cpp
@@ -104,6 +104,14 @@ public:
         _ioContext.stop();
     }

+    void drain() override final {
+        _ioContext.restart();
+        while (_ioContext.poll()) {
+            LOG(1) << "Draining remaining work in reactor.";
+        }
+        _ioContext.stop();
+    }
+
     std::unique_ptr<ReactorTimer> makeTimer() final {
         MONGO_UNREACHABLE;
     }
diff --git a/src/mongo/transport/transport_layer.h b/src/mongo/transport/transport_layer.h
index 3080eceabac..b3ee5aa92a7 100644
--- a/src/mongo/transport/transport_layer.h
+++ b/src/mongo/transport/transport_layer.h
@@ -155,6 +155,7 @@ public:
     virtual void run() noexcept = 0;
     virtual void runFor(Milliseconds time) noexcept = 0;
     virtual void stop() = 0;
+    virtual void drain() = 0;

     using Task = stdx::function<void()>;
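Both drain() implementations in this changeset (the test reactor above and TransportLayerASIO::ASIOReactor in the next file) use the same asio recipe: after stop(), restart() clears the stopped flag and poll() runs whatever handlers are still queued without blocking. A tiny sketch of that stop-then-drain sequence, assuming standalone Asio rather than MongoDB's vendored copy:

#include <asio.hpp>
#include <iostream>

int main() {
    asio::io_context ioContext;

    // Queue some work, then stop the context before it can run.
    asio::post(ioContext, [] { std::cout << "handler 1\n"; });
    asio::post(ioContext, [] { std::cout << "handler 2\n"; });
    ioContext.stop();

    // run() returns immediately once stopped; the handlers above are still queued.
    ioContext.run();

    // drain(): restart() clears the stopped flag, poll() runs ready handlers without blocking.
    ioContext.restart();
    while (ioContext.poll()) {
        std::cout << "drained a batch of handlers\n";
    }
    ioContext.stop();
}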
diff --git a/src/mongo/transport/transport_layer_asio.cpp b/src/mongo/transport/transport_layer_asio.cpp
index cbcde155864..617dced8cac 100644
--- a/src/mongo/transport/transport_layer_asio.cpp
+++ b/src/mongo/transport/transport_layer_asio.cpp
@@ -67,40 +67,30 @@ MONGO_FAIL_POINT_DEFINE(transportLayerASIOasyncConnectTimesOut);
 class ASIOReactorTimer final : public ReactorTimer {
 public:
     explicit ASIOReactorTimer(asio::io_context& ctx)
-        : _timerState(std::make_shared<TimerState>(ctx)) {}
+        : _timer(std::make_shared<asio::system_timer>(ctx)) {}

     ~ASIOReactorTimer() {
         // The underlying timer won't get destroyed until the last promise from _asyncWait
-        // has been filled, so cancel the timer so call callbacks get run
+        // has been filled, so cancel the timer so our promises get fulfilled
         cancel();
     }

     void cancel(const BatonHandle& baton = nullptr) override {
-        auto promise = [&] {
-            stdx::lock_guard<stdx::mutex> lk(_timerState->mutex);
-            _timerState->generation++;
-            return std::move(_timerState->finalPromise);
-        }();
-
-        if (promise) {
-            // We're worried that setting the error on the promise without unwinding the stack
-            // can lead to a deadlock, so this gets scheduled on the io_context of the timer.
-            _timerState->timer.get_io_context().post([promise = promise->share()]() mutable {
-                promise.setError({ErrorCodes::CallbackCanceled, "Timer was canceled"});
-            });
+        // If we have a baton try to cancel that.
+        if (baton && baton->cancelTimer(*this)) {
+            LOG(2) << "Canceled via baton, skipping asio cancel.";
+            return;
         }

-        if (!(baton && baton->cancelTimer(*this))) {
-            _timerState->timer.cancel();
-        }
+        // Otherwise there could be a previous timer that was scheduled normally.
+        _timer->cancel();
     }

     Future<void> waitFor(Milliseconds timeout, const BatonHandle& baton = nullptr) override {
         if (baton) {
             return _asyncWait([&] { return baton->waitFor(*this, timeout); }, baton);
         } else {
-            return _asyncWait(
-                [&] { _timerState->timer.expires_after(timeout.toSystemDuration()); });
+            return _asyncWait([&] { _timer->expires_after(timeout.toSystemDuration()); });
         }
     }
@@ -108,48 +98,21 @@ public:
         if (baton) {
             return _asyncWait([&] { return baton->waitUntil(*this, expiration); }, baton);
         } else {
-            return _asyncWait(
-                [&] { _timerState->timer.expires_at(expiration.toSystemTimePoint()); });
+            return _asyncWait([&] { _timer->expires_at(expiration.toSystemTimePoint()); });
         }
     }

 private:
-    std::pair<Future<void>, uint64_t> _getFuture() {
-        stdx::lock_guard<stdx::mutex> lk(_timerState->mutex);
-        auto id = ++_timerState->generation;
-        invariant(!_timerState->finalPromise);
-        auto pf = makePromiseFuture<void>();
-        _timerState->finalPromise = std::make_unique<Promise<void>>(std::move(pf.promise));
-        return std::make_pair(std::move(pf.future), id);
-    }
-
     template <typename ArmTimerCb>
     Future<void> _asyncWait(ArmTimerCb&& armTimer) {
         try {
             cancel();

-            Future<void> ret;
-            uint64_t id;
-            std::tie(ret, id) = _getFuture();
-
             armTimer();
-            _timerState->timer.async_wait(
-                [ id, state = _timerState ](const std::error_code& ec) mutable {
-                    stdx::unique_lock<stdx::mutex> lk(state->mutex);
-                    if (id != state->generation) {
-                        return;
-                    }
-                    auto promise = std::move(state->finalPromise);
-                    lk.unlock();
-
-                    if (ec) {
-                        promise->setError(errorCodeToStatus(ec));
-                    } else {
-                        promise->emplaceValue();
-                    }
-                });
-
-            return ret;
+            return _timer->async_wait(UseFuture{}).tapError([timer = _timer](const Status& status) {
+                LOG(2) << "Timer received error: " << status;
+            });
+
         } catch (asio::system_error& ex) {
             return Future<void>::makeReady(errorCodeToStatus(ex.code()));
         }
@@ -159,40 +122,19 @@ private:
     Future<void> _asyncWait(ArmTimerCb&& armTimer, const BatonHandle& baton) {
         cancel(baton);

-        Future<void> ret;
-        uint64_t id;
-        std::tie(ret, id) = _getFuture();
-
-        armTimer().getAsync([ id, state = _timerState ](Status status) mutable {
-            stdx::unique_lock<stdx::mutex> lk(state->mutex);
-            if (id != state->generation) {
-                return;
-            }
-            auto promise = std::move(state->finalPromise);
-            lk.unlock();
-
+        auto pf = makePromiseFuture<void>();
+        armTimer().getAsync([sp = pf.promise.share()](Status status) mutable {
             if (status.isOK()) {
-                promise->emplaceValue();
+                sp.emplaceValue();
             } else {
-                promise->setError(status);
+                sp.setError(status);
             }
         });

-        return ret;
+        return std::move(pf.future);
     }

-    // The timer itself and its state are stored in this struct managed by a shared_ptr so we can
-    // extend the lifetime of the timer until all callbacks to timer.async_wait have run.
-    struct TimerState {
-        explicit TimerState(asio::io_context& ctx) : timer(ctx) {}
-
-        asio::system_timer timer;
-        stdx::mutex mutex;
-        uint64_t generation = 0;
-        std::unique_ptr<Promise<void>> finalPromise;
-    };
-
-    std::shared_ptr<TimerState> _timerState;
+    std::shared_ptr<asio::system_timer> _timer;
 };

 class TransportLayerASIO::ASIOReactor final : public Reactor {
@@ -213,7 +155,6 @@ public:
     void runFor(Milliseconds time) noexcept override {
         ThreadIdGuard threadIdGuard(this);
         asio::io_context::work work(_ioContext);
-
         try {
             _ioContext.run_for(time.toSystemDuration());
         } catch (...) {
@@ -226,6 +167,14 @@ public:
         _ioContext.stop();
     }

+    void drain() override {
+        _ioContext.restart();
+        while (_ioContext.poll()) {
+            LOG(2) << "Draining remaining work in reactor.";
+        }
+        _ioContext.stop();
+    }
+
     std::unique_ptr<ReactorTimer> makeTimer() override {
         return std::make_unique<ASIOReactorTimer>(_ioContext);
     }
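The rewritten ASIOReactorTimer above replaces the generation-counted TimerState with a plain shared asio::system_timer whose lifetime is extended by capturing the shared_ptr in the completion handler. A condensed sketch of that ownership pattern, assuming standalone Asio and a plain callback in place of MongoDB's Future/UseFuture plumbing; SimpleTimer is an illustrative name.

#include <asio.hpp>
#include <chrono>
#include <functional>
#include <iostream>
#include <memory>
#include <system_error>

class SimpleTimer {
public:
    explicit SimpleTimer(asio::io_context& ctx)
        : _timer(std::make_shared<asio::system_timer>(ctx)) {}

    ~SimpleTimer() {
        // Cancel on destruction so any pending wait completes (with operation_aborted).
        _timer->cancel();
    }

    void waitFor(std::chrono::milliseconds timeout, std::function<void(std::error_code)> cb) {
        _timer->expires_after(timeout);
        // Capturing the shared_ptr keeps the underlying asio timer alive until the
        // completion handler runs, even if this wrapper is destroyed first.
        _timer->async_wait([timer = _timer, cb = std::move(cb)](const std::error_code& ec) {
            cb(ec);
        });
    }

private:
    std::shared_ptr<asio::system_timer> _timer;
};

int main() {
    asio::io_context ctx;
    {
        SimpleTimer t(ctx);
        t.waitFor(std::chrono::milliseconds(10), [](std::error_code ec) {
            std::cout << (ec ? "canceled: " + ec.message() : std::string("fired")) << "\n";
        });
    }  // wrapper destroyed; the wait is canceled but the handler still runs safely
    ctx.run();
}

Because the handler owns a reference to the timer, cancellation can never race with a completion that touches a dead timer, which is what the old generation counter was guarding against.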