diff options
Diffstat (limited to 'src/mongo/transport/baton_asio_linux.h')
-rw-r--r-- | src/mongo/transport/baton_asio_linux.h | 219 |
1 files changed, 137 insertions, 82 deletions
diff --git a/src/mongo/transport/baton_asio_linux.h b/src/mongo/transport/baton_asio_linux.h index 63bfe67da3f..55570768ba7 100644 --- a/src/mongo/transport/baton_asio_linux.h +++ b/src/mongo/transport/baton_asio_linux.h @@ -30,8 +30,8 @@ #pragma once +#include <map> #include <memory> -#include <set> #include <vector> #include <poll.h> @@ -55,7 +55,7 @@ namespace transport { * * We implement our networking reactor on top of poll + eventfd for wakeups */ -class TransportLayerASIO::BatonASIO : public Baton { +class TransportLayerASIO::BatonASIO : public NetworkingBaton { /** * We use this internal reactor timer to exit run_until calls (by forcing an early timeout for * ::poll). @@ -117,6 +117,8 @@ class TransportLayerASIO::BatonASIO : public Baton { } const int fd; + + static const Client::Decoration<EventFDHolder> getForClient; }; public: @@ -129,75 +131,69 @@ public: invariant(_timers.empty()); } - void detach() override { - { - stdx::lock_guard<stdx::mutex> lk(_mutex); - invariant(_sessions.empty()); - invariant(_scheduled.empty()); - invariant(_timers.empty()); - } + void markKillOnClientDisconnect() noexcept override { + if (_opCtx->getClient() && _opCtx->getClient()->session()) { + addSessionImpl(*(_opCtx->getClient()->session()), POLLRDHUP).getAsync([this](Status s) { + if (!s.isOK()) { + return; + } - { - stdx::lock_guard<Client> lk(*_opCtx->getClient()); - invariant(_opCtx->getBaton().get() == this); - _opCtx->setBaton(nullptr); + _opCtx->markKilled(ErrorCodes::ClientDisconnect); + }); } + } - _opCtx = nullptr; + Future<void> addSession(Session& session, Type type) noexcept override { + return addSessionImpl(session, type == Type::In ? POLLIN : POLLOUT); } - Future<void> addSession(Session& session, Type type) override { - auto fd = checked_cast<ASIOSession&>(session).getSocket().native_handle(); + Future<void> waitUntil(const ReactorTimer& timer, Date_t expiration) noexcept override { auto pf = makePromiseFuture<void>(); + auto id = timer.id(); - _safeExecute([ fd, type, promise = std::move(pf.promise), this ]() mutable { - _sessions[fd] = TransportSession{type, std::move(promise)}; - }); + stdx::unique_lock<stdx::mutex> lk(_mutex); - return std::move(pf.future); - } + if (!_opCtx) { + return Status(ErrorCodes::ShutdownInProgress, + "baton is detached, cannot waitUntil on timer"); + } - Future<void> waitUntil(const ReactorTimer& timer, Date_t expiration) override { - auto pf = makePromiseFuture<void>(); - _safeExecute( - [ timerPtr = &timer, expiration, promise = std::move(pf.promise), this ]() mutable { - auto pair = _timers.insert({ - timerPtr, expiration, std::move(promise), - }); - invariant(pair.second); - _timersById[pair.first->id] = pair.first; - }); + _safeExecute(std::move(lk), + [ id, expiration, promise = std::move(pf.promise), this ]() mutable { + auto iter = _timers.emplace(std::piecewise_construct, + std::forward_as_tuple(expiration), + std::forward_as_tuple(id, std::move(promise))); + _timersById[id] = iter; + }); return std::move(pf.future); } - bool cancelSession(Session& session) override { - const auto fd = checked_cast<ASIOSession&>(session).getSocket().native_handle(); + bool cancelSession(Session& session) noexcept override { + const auto id = session.id(); stdx::unique_lock<stdx::mutex> lk(_mutex); - if (_sessions.find(fd) == _sessions.end()) { + if (_sessions.find(id) == _sessions.end()) { return false; } - // TODO: There's an ABA issue here with fds where between previously and before we could - // have removed the fd, then opened and added a new socket with the same fd. We need to - // solve it via using session id's for handles. - _safeExecute(std::move(lk), [fd, this] { _sessions.erase(fd); }); + _safeExecute(std::move(lk), [id, this] { _sessions.erase(id); }); return true; } - bool cancelTimer(const ReactorTimer& timer) override { + bool cancelTimer(const ReactorTimer& timer) noexcept override { + const auto id = timer.id(); + stdx::unique_lock<stdx::mutex> lk(_mutex); - if (_timersById.find(&timer) == _timersById.end()) { + if (_timersById.find(id) == _timersById.end()) { return false; } - // TODO: Same ABA issue as above, but for pointers. - _safeExecute(std::move(lk), [ timerPtr = &timer, this ] { - auto iter = _timersById.find(timerPtr); + _safeExecute(std::move(lk), [id, this] { + auto iter = _timersById.find(id); if (iter != _timersById.end()) { _timers.erase(iter->second); @@ -208,18 +204,24 @@ public: return true; } - void schedule(unique_function<void()> func) override { + void schedule(unique_function<void(OperationContext*)> func) noexcept override { stdx::lock_guard<stdx::mutex> lk(_mutex); + if (!_opCtx) { + func(nullptr); + + return; + } + _scheduled.push_back(std::move(func)); if (_inPoll) { - _efd.notify(); + efd().notify(); } } void notify() noexcept override { - _efd.notify(); + efd().notify(); } /** @@ -263,7 +265,7 @@ public: lk.unlock(); for (auto& job : toRun) { - job(); + job(_opCtx); } lk.lock(); } @@ -280,21 +282,19 @@ public: // If we have a timer, poll no longer than that if (_timers.size()) { - deadline = _timers.begin()->expiration; + deadline = _timers.begin()->first; } std::vector<decltype(_sessions)::iterator> sessions; sessions.reserve(_sessions.size()); std::vector<pollfd> pollSet; + pollSet.reserve(_sessions.size() + 1); - pollSet.push_back(pollfd{_efd.fd, POLLIN, 0}); + pollSet.push_back(pollfd{efd().fd, POLLIN, 0}); for (auto iter = _sessions.begin(); iter != _sessions.end(); ++iter) { - pollSet.push_back( - pollfd{iter->first, - static_cast<short>(iter->second.type == Type::In ? POLLIN : POLLOUT), - 0}); + pollSet.push_back(pollfd{iter->second.fd, iter->second.type, 0}); sessions.push_back(iter); } @@ -330,9 +330,9 @@ public: now = clkSource->now(); // Fire expired timers - for (auto iter = _timers.begin(); iter != _timers.end() && iter->expiration < now;) { - toFulfill.push_back(std::move(iter->promise)); - _timersById.erase(iter->id); + for (auto iter = _timers.begin(); iter != _timers.end() && iter->first < now;) { + toFulfill.push_back(std::move(iter->second.promise)); + _timersById.erase(iter->second.id); iter = _timers.erase(iter); } @@ -343,12 +343,13 @@ public: auto pollIter = pollSet.begin(); if (pollIter->revents) { - _efd.wait(); + efd().wait(); remaining--; } ++pollIter; + for (auto sessionIter = sessions.begin(); sessionIter != sessions.end() && remaining; ++sessionIter, ++pollIter) { if (pollIter->revents) { @@ -366,28 +367,75 @@ public: } private: - struct Timer { - const ReactorTimer* id; - Date_t expiration; - mutable Promise<void> promise; // Needs to be mutable to move from it while in std::set. + Future<void> addSessionImpl(Session& session, short type) noexcept { + auto fd = checked_cast<ASIOSession&>(session).getSocket().native_handle(); + auto id = session.id(); + auto pf = makePromiseFuture<void>(); + + stdx::unique_lock<stdx::mutex> lk(_mutex); + + if (!_opCtx) { + return Status(ErrorCodes::ShutdownInProgress, "baton is detached, cannot addSession"); + } - struct LessThan { - bool operator()(const Timer& lhs, const Timer& rhs) const { - return std::tie(lhs.expiration, lhs.id) < std::tie(rhs.expiration, rhs.id); + _safeExecute(std::move(lk), + [ id, fd, type, promise = std::move(pf.promise), this ]() mutable { + _sessions[id] = TransportSession{fd, type, std::move(promise)}; + }); + + return std::move(pf.future); + } + + void detachImpl() noexcept override { + decltype(_sessions) sessions; + decltype(_scheduled) scheduled; + decltype(_timers) timers; + + { + stdx::lock_guard<stdx::mutex> lk(_mutex); + + { + stdx::lock_guard<Client> lk(*_opCtx->getClient()); + invariant(_opCtx->getBaton().get() == this); + _opCtx->setBaton(nullptr); } - }; + + _opCtx = nullptr; + + using std::swap; + swap(_sessions, sessions); + swap(_scheduled, scheduled); + swap(_timers, timers); + } + + for (auto& job : scheduled) { + job(nullptr); + } + + for (auto& session : sessions) { + session.second.promise.setError(Status(ErrorCodes::ShutdownInProgress, + "baton is detached, cannot wait for socket")); + } + + for (auto& pair : timers) { + pair.second.promise.setError(Status(ErrorCodes::ShutdownInProgress, + "baton is detached, completing timer early")); + } + } + + struct Timer { + Timer(size_t id, Promise<void> promise) : id(id), promise(std::move(promise)) {} + + size_t id; + Promise<void> promise; // Needs to be mutable to move from it while in std::set. }; struct TransportSession { - Type type; + int fd; + short type; Promise<void> promise; }; - template <typename Callback> - void _safeExecute(Callback&& cb) { - return _safeExecute(stdx::unique_lock<stdx::mutex>(_mutex), std::forward<Callback>(cb)); - } - /** * Safely executes method on the reactor. If we're in poll, we schedule a task, then write to * the eventfd. If not, we run inline. @@ -395,37 +443,44 @@ private: template <typename Callback> void _safeExecute(stdx::unique_lock<stdx::mutex> lk, Callback&& cb) { if (_inPoll) { - _scheduled.push_back([ cb = std::forward<Callback>(cb), this ]() mutable { - stdx::lock_guard<stdx::mutex> lk(_mutex); - cb(); - }); + _scheduled.push_back( + [ cb = std::forward<Callback>(cb), this ](OperationContext*) mutable { + stdx::lock_guard<stdx::mutex> lk(_mutex); + cb(); + }); - _efd.notify(); + efd().notify(); } else { cb(); } } + EventFDHolder& efd() { + return EventFDHolder::getForClient(_opCtx->getClient()); + } + stdx::mutex _mutex; OperationContext* _opCtx; bool _inPoll = false; - EventFDHolder _efd; - // This map stores the sessions we need to poll on. We unwind it into a pollset for every // blocking call to run - stdx::unordered_map<int, TransportSession> _sessions; + stdx::unordered_map<SessionId, TransportSession> _sessions; // The set is used to find the next timer which will fire. The unordered_map looks up the // timers so we can remove them in O(1) - std::set<Timer, Timer::LessThan> _timers; - stdx::unordered_map<const ReactorTimer*, decltype(_timers)::const_iterator> _timersById; + std::multimap<Date_t, Timer> _timers; + stdx::unordered_map<size_t, decltype(_timers)::const_iterator> _timersById; // For tasks that come in via schedule. Or that were deferred because we were in poll - std::vector<unique_function<void()>> _scheduled; + std::vector<unique_function<void(OperationContext*)>> _scheduled; }; +const Client::Decoration<TransportLayerASIO::BatonASIO::EventFDHolder> + TransportLayerASIO::BatonASIO::EventFDHolder::getForClient = + Client::declareDecoration<TransportLayerASIO::BatonASIO::EventFDHolder>(); + } // namespace transport } // namespace mongo |