diff options
Diffstat (limited to 'src/mongo/transport')
-rw-r--r-- | src/mongo/transport/service_executor_test.cpp | 8 | ||||
-rw-r--r-- | src/mongo/transport/transport_layer.h | 1 | ||||
-rw-r--r-- | src/mongo/transport/transport_layer_asio.cpp | 107 |
3 files changed, 37 insertions, 79 deletions
diff --git a/src/mongo/transport/service_executor_test.cpp b/src/mongo/transport/service_executor_test.cpp index 86e7187a37f..a3fafd88ef4 100644 --- a/src/mongo/transport/service_executor_test.cpp +++ b/src/mongo/transport/service_executor_test.cpp @@ -104,6 +104,14 @@ public: _ioContext.stop(); } + void drain() override final { + _ioContext.restart(); + while (_ioContext.poll()) { + LOG(1) << "Draining remaining work in reactor."; + } + _ioContext.stop(); + } + std::unique_ptr<ReactorTimer> makeTimer() final { MONGO_UNREACHABLE; } diff --git a/src/mongo/transport/transport_layer.h b/src/mongo/transport/transport_layer.h index 3080eceabac..b3ee5aa92a7 100644 --- a/src/mongo/transport/transport_layer.h +++ b/src/mongo/transport/transport_layer.h @@ -155,6 +155,7 @@ public: virtual void run() noexcept = 0; virtual void runFor(Milliseconds time) noexcept = 0; virtual void stop() = 0; + virtual void drain() = 0; using Task = stdx::function<void()>; diff --git a/src/mongo/transport/transport_layer_asio.cpp b/src/mongo/transport/transport_layer_asio.cpp index cbcde155864..617dced8cac 100644 --- a/src/mongo/transport/transport_layer_asio.cpp +++ b/src/mongo/transport/transport_layer_asio.cpp @@ -67,40 +67,30 @@ MONGO_FAIL_POINT_DEFINE(transportLayerASIOasyncConnectTimesOut); class ASIOReactorTimer final : public ReactorTimer { public: explicit ASIOReactorTimer(asio::io_context& ctx) - : _timerState(std::make_shared<TimerState>(ctx)) {} + : _timer(std::make_shared<asio::system_timer>(ctx)) {} ~ASIOReactorTimer() { // The underlying timer won't get destroyed until the last promise from _asyncWait - // has been filled, so cancel the timer so call callbacks get run + // has been filled, so cancel the timer so our promises get fulfilled cancel(); } void cancel(const BatonHandle& baton = nullptr) override { - auto promise = [&] { - stdx::lock_guard<stdx::mutex> lk(_timerState->mutex); - _timerState->generation++; - return std::move(_timerState->finalPromise); - }(); - - if (promise) { - // We're worried that setting the error on the promise without unwinding the stack - // can lead to a deadlock, so this gets scheduled on the io_context of the timer. - _timerState->timer.get_io_context().post([promise = promise->share()]() mutable { - promise.setError({ErrorCodes::CallbackCanceled, "Timer was canceled"}); - }); + // If we have a baton try to cancel that. + if (baton && baton->cancelTimer(*this)) { + LOG(2) << "Canceled via baton, skipping asio cancel."; + return; } - if (!(baton && baton->cancelTimer(*this))) { - _timerState->timer.cancel(); - } + // Otherwise there could be a previous timer that was scheduled normally. + _timer->cancel(); } Future<void> waitFor(Milliseconds timeout, const BatonHandle& baton = nullptr) override { if (baton) { return _asyncWait([&] { return baton->waitFor(*this, timeout); }, baton); } else { - return _asyncWait( - [&] { _timerState->timer.expires_after(timeout.toSystemDuration()); }); + return _asyncWait([&] { _timer->expires_after(timeout.toSystemDuration()); }); } } @@ -108,48 +98,21 @@ public: if (baton) { return _asyncWait([&] { return baton->waitUntil(*this, expiration); }, baton); } else { - return _asyncWait( - [&] { _timerState->timer.expires_at(expiration.toSystemTimePoint()); }); + return _asyncWait([&] { _timer->expires_at(expiration.toSystemTimePoint()); }); } } private: - std::pair<Future<void>, uint64_t> _getFuture() { - stdx::lock_guard<stdx::mutex> lk(_timerState->mutex); - auto id = ++_timerState->generation; - invariant(!_timerState->finalPromise); - auto pf = makePromiseFuture<void>(); - _timerState->finalPromise = std::make_unique<Promise<void>>(std::move(pf.promise)); - return std::make_pair(std::move(pf.future), id); - } - template <typename ArmTimerCb> Future<void> _asyncWait(ArmTimerCb&& armTimer) { try { cancel(); - Future<void> ret; - uint64_t id; - std::tie(ret, id) = _getFuture(); - armTimer(); - _timerState->timer.async_wait( - [ id, state = _timerState ](const std::error_code& ec) mutable { - stdx::unique_lock<stdx::mutex> lk(state->mutex); - if (id != state->generation) { - return; - } - auto promise = std::move(state->finalPromise); - lk.unlock(); - - if (ec) { - promise->setError(errorCodeToStatus(ec)); - } else { - promise->emplaceValue(); - } - }); - - return ret; + return _timer->async_wait(UseFuture{}).tapError([timer = _timer](const Status& status) { + LOG(2) << "Timer received error: " << status; + }); + } catch (asio::system_error& ex) { return Future<void>::makeReady(errorCodeToStatus(ex.code())); } @@ -159,40 +122,19 @@ private: Future<void> _asyncWait(ArmTimerCb&& armTimer, const BatonHandle& baton) { cancel(baton); - Future<void> ret; - uint64_t id; - std::tie(ret, id) = _getFuture(); - - armTimer().getAsync([ id, state = _timerState ](Status status) mutable { - stdx::unique_lock<stdx::mutex> lk(state->mutex); - if (id != state->generation) { - return; - } - auto promise = std::move(state->finalPromise); - lk.unlock(); - + auto pf = makePromiseFuture<void>(); + armTimer().getAsync([sp = pf.promise.share()](Status status) mutable { if (status.isOK()) { - promise->emplaceValue(); + sp.emplaceValue(); } else { - promise->setError(status); + sp.setError(status); } }); - return ret; + return std::move(pf.future); } - // The timer itself and its state are stored in this struct managed by a shared_ptr so we can - // extend the lifetime of the timer until all callbacks to timer.async_wait have run. - struct TimerState { - explicit TimerState(asio::io_context& ctx) : timer(ctx) {} - - asio::system_timer timer; - stdx::mutex mutex; - uint64_t generation = 0; - std::unique_ptr<Promise<void>> finalPromise; - }; - - std::shared_ptr<TimerState> _timerState; + std::shared_ptr<asio::system_timer> _timer; }; class TransportLayerASIO::ASIOReactor final : public Reactor { @@ -213,7 +155,6 @@ public: void runFor(Milliseconds time) noexcept override { ThreadIdGuard threadIdGuard(this); asio::io_context::work work(_ioContext); - try { _ioContext.run_for(time.toSystemDuration()); } catch (...) { @@ -226,6 +167,14 @@ public: _ioContext.stop(); } + void drain() override { + _ioContext.restart(); + while (_ioContext.poll()) { + LOG(2) << "Draining remaining work in reactor."; + } + _ioContext.stop(); + } + std::unique_ptr<ReactorTimer> makeTimer() override { return std::make_unique<ASIOReactorTimer>(_ioContext); } |