diff options
author | Ben Caimano <ben.caimano@10gen.com> | 2020-09-23 04:03:43 +0000 |
---|---|---|
committer | Ben Caimano <ben.caimano@10gen.com> | 2020-10-19 20:25:05 +0000 |
commit | fa29e47f37da2353b49ee71c907026b769fdc607 (patch) | |
tree | 1a9137ca238a39cc1926e3876f3bca38e12fa39d | |
parent | b03c93d55baefc8a70a6ee790f1d497fd1ff8b70 (diff) | |
download | mongo-fa29e47f37da2353b49ee71c907026b769fdc607.tar.gz |
SERVER-51278 Introduced ClientStrand
-rw-r--r-- | src/mongo/db/SConscript | 2 | ||||
-rw-r--r-- | src/mongo/db/client.cpp | 2 | ||||
-rw-r--r-- | src/mongo/db/client_strand.cpp | 95 | ||||
-rw-r--r-- | src/mongo/db/client_strand.h | 214 | ||||
-rw-r--r-- | src/mongo/db/client_strand_test.cpp | 381 | ||||
-rw-r--r-- | src/mongo/transport/service_state_machine.cpp | 295 |
6 files changed, 733 insertions, 256 deletions
diff --git a/src/mongo/db/SConscript b/src/mongo/db/SConscript index 14706fa55ce..e62270e4f62 100644 --- a/src/mongo/db/SConscript +++ b/src/mongo/db/SConscript @@ -450,6 +450,7 @@ env.Library( source=[ 'baton.cpp', 'client.cpp', + 'client_strand.cpp', 'default_baton.cpp', 'operation_context.cpp', 'operation_context_group.cpp', @@ -2212,6 +2213,7 @@ envWithAsio.CppUnitTest( target='db_unittests', source=[ 'catalog_raii_test.cpp', + 'client_strand_test.cpp', 'collection_index_usage_tracker_test.cpp', 'commands_test.cpp', 'curop_test.cpp', diff --git a/src/mongo/db/client.cpp b/src/mongo/db/client.cpp index a9678bee962..99175dd618f 100644 --- a/src/mongo/db/client.cpp +++ b/src/mongo/db/client.cpp @@ -155,7 +155,7 @@ bool haveClient() { } ServiceContext::UniqueClient Client::releaseCurrent() { - invariant(haveClient()); + invariant(haveClient(), "No client to release"); return std::move(currentClient); } diff --git a/src/mongo/db/client_strand.cpp b/src/mongo/db/client_strand.cpp new file mode 100644 index 00000000000..470ae6d3133 --- /dev/null +++ b/src/mongo/db/client_strand.cpp @@ -0,0 +1,95 @@ +/** + * Copyright (C) 2020-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. 
+ * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kDefault + +#include "mongo/platform/basic.h" + +#include "mongo/db/client_strand.h" + +#include "mongo/logv2/log.h" +#include "mongo/util/concurrency/thread_name.h" + +namespace mongo { +namespace { +struct ClientStrandData { + ClientStrand* strand = nullptr; +}; + +auto getClientStrandData = Client::declareDecoration<ClientStrandData>(); +} // namespace + +boost::intrusive_ptr<ClientStrand> ClientStrand::make(ServiceContext::UniqueClient client) { + auto strand = make_intrusive<ClientStrand>(std::move(client)); + getClientStrandData(strand->getClientPointer()).strand = strand.get(); + return strand; +} + +boost::intrusive_ptr<ClientStrand> ClientStrand::get(Client* client) { + return getClientStrandData(client).strand; +} + +void ClientStrand::_setCurrent() noexcept { + invariant(_isBound.load()); + invariant(_client); + + LOGV2_DEBUG( + 4910701, kDiagnosticLogLevel, "Setting the Client", "client"_attr = _client->desc()); + + // Set the Client for this thread so calls to Client::getCurrent() work as expected. + Client::setCurrent(std::move(_client)); + + // Set up the thread name.
+ auto oldThreadName = getThreadName(); + StringData threadName = _clientPtr->desc(); + if (oldThreadName != threadName) { + _oldThreadName = oldThreadName.toString(); + setThreadName(threadName); + LOGV2_DEBUG(5127801, kDiagnosticLogLevel, "Set thread name", "name"_attr = threadName); + } +} + +void ClientStrand::_releaseCurrent() noexcept { + invariant(_isBound.load()); + invariant(!_client); + + // Reclaim the client. + _client = Client::releaseCurrent(); + invariant(_client.get() == _clientPtr, kUnableToRecoverClient); + + if (!_oldThreadName.empty()) { + // Restore the old name and clear it so it cannot leak into a later bind. + setThreadName(std::exchange(_oldThreadName, {})); + } + + LOGV2_DEBUG( + 4910702, kDiagnosticLogLevel, "Released the Client", "client"_attr = _client->desc()); +} + +} // namespace mongo diff --git a/src/mongo/db/client_strand.h b/src/mongo/db/client_strand.h new file mode 100644 index 00000000000..20b9d940d27 --- /dev/null +++ b/src/mongo/db/client_strand.h @@ -0,0 +1,214 @@ +/** + * Copyright (C) 2020-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library.
You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <string> + +#include "mongo/db/client.h" +#include "mongo/db/service_context.h" +#include "mongo/platform/atomic_word.h" +#include "mongo/stdx/mutex.h" +#include "mongo/util/intrusive_counter.h" +#include "mongo/util/out_of_line_executor.h" + +namespace mongo { + +/** + * ClientStrand is a reference counted type for loaning Clients to threads. + * + * ClientStrand maintains the lifetime of its wrapped Client object and provides functionality to + * "bind" that Client to one and only one thread at a time. Its functions are synchronized. + */ +class ClientStrand final : public RefCountable { + static constexpr auto kDiagnosticLogLevel = 3; + +public: + static constexpr auto kUnableToRecoverClient = "Unable to recover Client for ClientStrand"; + + /** + * A simple RAII guard to set and release Clients. + */ + class Guard { + public: + Guard() = default; + Guard(Guard&&) = default; + + /** + * Move-assign from another Guard. Any Client this Guard currently has bound is released + * first, so the strand's mutex is never left locked without a Guard to unlock it. + */ + Guard& operator=(Guard&& other) noexcept { + if (this != &other) { + dismiss(); + _strand = std::move(other._strand); + } + return *this; + } + + Guard(const Guard&) = delete; + Guard& operator=(const Guard&) = delete; + + Guard(ClientStrand* strand) : _strand(strand) { + // Hold the lock for as long as the Guard is around. This forces other consumers to + // queue behind the Guard.
+ _strand->_mutex.lock(); + _strand->_isBound.store(true); + + _strand->_setCurrent(); + } + + ~Guard() { + dismiss(); + } + + void dismiss() noexcept { + auto strand = std::exchange(_strand, {}); + if (!strand) { + return; + } + + strand->_releaseCurrent(); + strand->_isBound.store(false); + strand->_mutex.unlock(); + } + + Client* get() noexcept { + return _strand->getClientPointer(); + } + + Client* operator->() noexcept { + return get(); + } + + Client& operator*() noexcept { + return *get(); + } + + private: + boost::intrusive_ptr<ClientStrand> _strand; + }; + + /** + * A simple wrapping executor to run tasks while a Client is bound. + */ + class Executor final : public OutOfLineExecutor { + public: + Executor(ClientStrand* strand, ExecutorPtr exec) + : _strand(strand), _exec(std::move(exec)) {} + void schedule(Task task) override; + + private: + boost::intrusive_ptr<ClientStrand> _strand; + ExecutorPtr _exec; + }; + + /** + * Make a new ClientStrand from a UniqueClient. + */ + static boost::intrusive_ptr<ClientStrand> make(ServiceContext::UniqueClient client); + + /** + * Acquire an owning ClientStrand given a client. + * + * This will return nullptr if the Client does not belong to a ClientStrand. + */ + static boost::intrusive_ptr<ClientStrand> get(Client* client); + + ClientStrand(ServiceContext::UniqueClient client) + : _clientPtr(client.get()), _client(std::move(client)) {} + + /** + * Get a pointer to the underlying Client. + */ + Client* getClientPointer() noexcept { + return _clientPtr; + } + + /** + * Set the current Client for this thread and return a RAII guard to release it eventually. + * + * If the Client is currently bound, this function will block until the Client is available. + */ + auto bind() { + return Guard(this); + } + + /** + * Run a Task with the Client bound to the current thread. + * + * This function runs the task inline and assumes that the Client is not already bound to the + * current thread.
If the Client is currently bound, this function will block until it is + * released. + */ + template <typename Task, typename... Args> + void run(Task task, Args&&... args) { + auto guard = bind(); + + return task(std::forward<Args>(args)...); + } + + /** + * Make a wrapped executor around another. + */ + ExecutorPtr makeExecutor(ExecutorPtr exec) { + return std::make_shared<Executor>(this, std::move(exec)); + } + + /** + * Return whether the Client is currently bound to a thread. + */ + bool isBound() const noexcept { + return _isBound.load(); + } + +private: + /** + * Bind the Client to the current thread. + * + * This is only valid to call if no other thread has the Client bound. + */ + void _setCurrent() noexcept; + + /** + * Release the Client from the current thread. + * + * This must be called exactly once per _setCurrent(), on the thread that bound the Client. + * It is not valid to mix this with Client::releaseCurrent(). + */ + void _releaseCurrent() noexcept; + + Client* const _clientPtr; + + stdx::mutex _mutex; // NOLINT + + // Once we have stdx::atomic::wait(), we can get rid of the mutex in favor of this variable. + AtomicWord<bool> _isBound{false}; + + ServiceContext::UniqueClient _client; + + std::string _oldThreadName; +}; + +inline void ClientStrand::Executor::schedule(Task task) { + _exec->schedule([task = std::move(task), strand = _strand](Status status) mutable { + strand->run(std::move(task), std::move(status)); + }); +} + +using ClientStrandPtr = boost::intrusive_ptr<ClientStrand>; + +} // namespace mongo diff --git a/src/mongo/db/client_strand_test.cpp b/src/mongo/db/client_strand_test.cpp new file mode 100644 index 00000000000..68c055f53c0 --- /dev/null +++ b/src/mongo/db/client_strand_test.cpp @@ -0,0 +1,381 @@ +/** + * Copyright (C) 2020-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc.
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#include "mongo/platform/basic.h" + +#include <memory> + +#include "mongo/db/client_strand.h" +#include "mongo/db/service_context_test_fixture.h" +#include "mongo/unittest/barrier.h" +#include "mongo/unittest/death_test.h" +#include "mongo/unittest/unittest.h" +#include "mongo/util/assert_util.h" +#include "mongo/util/concurrency/thread_name.h" +#include "mongo/util/executor_test_util.h" + +namespace mongo { +namespace { + +class ClientStrandTest : public unittest::Test, public ScopedGlobalServiceContextForTest { +public: + constexpr static auto kClientName1 = "foo"; + constexpr static auto kClientName2 = "bar"; + + void assertStrandNotBound(const ClientStrandPtr& strand) { + ASSERT_FALSE(haveClient()); + ASSERT_FALSE(strand->isBound()); + } + + void assertStrandBound(const ClientStrandPtr& strand) { + // We have a Client. + ASSERT_TRUE(haveClient()); + ASSERT_TRUE(strand->isBound()); + + // The current Client and Thread have the correct name. + auto client = strand->getClientPointer(); + ASSERT_EQ(client, Client::getCurrent()); + ASSERT_EQ(client->desc(), getThreadName()); + } +}; + +TEST_F(ClientStrandTest, CreateOnly) { + auto strand = ClientStrand::make(getServiceContext()->makeClient(kClientName1)); + + // We have no bound Client. + assertStrandNotBound(strand); + + // The Client should exist. + ASSERT_TRUE(strand->getClientPointer()); + + // The Client should reference its ClientStrand. + ASSERT_EQ(ClientStrand::get(strand->getClientPointer()), strand); +} + +TEST_F(ClientStrandTest, BindOnce) { + auto strand = ClientStrand::make(getServiceContext()->makeClient(kClientName1)); + + // We have no bound Client. + assertStrandNotBound(strand); + + { + // Bind a single client + auto guard = strand->bind(); + assertStrandBound(strand); + + // The guard allows us to get the Client. + ASSERT_EQ(guard.get(), strand->getClientPointer()); + } + + // We have no bound Client again. 
+ assertStrandNotBound(strand); +} + +TEST_F(ClientStrandTest, BindMultipleTimes) { + auto strand = ClientStrand::make(getServiceContext()->makeClient(kClientName1)); + + // We have no bound Client. + assertStrandNotBound(strand); + + for (auto i = 0; i < 100; ++i) { + // Bind a bunch of times. + + { + auto guard = strand->bind(); + assertStrandBound(strand); + } + + // We have no bound Client again. + assertStrandNotBound(strand); + } +} + +TEST_F(ClientStrandTest, BindMultipleTimesAndDismiss) { + auto strand = ClientStrand::make(getServiceContext()->makeClient(kClientName1)); + + // We have no bound Client. + assertStrandNotBound(strand); + + auto guard = strand->bind(); + for (auto i = 0; i < 100; ++i) { + assertStrandBound(strand); + + // Dismiss the current guard. + guard.dismiss(); + assertStrandNotBound(strand); + + // Assign a new guard. + guard = strand->bind(); + } + + // At the end we have a strand bound. + assertStrandBound(strand); +} + +TEST_F(ClientStrandTest, BindLocalBeforeWorkerThread) { + auto strand = ClientStrand::make(getServiceContext()->makeClient(kClientName1)); + auto barrier = std::make_shared<unittest::Barrier>(2); + + // Set our state to an initial value. It is unsynchronized, but ClientStrand does synchronize, + // thus it should pass TSAN. + enum State { + kStarted, + kLocalThread, + kWorkerThread, + }; + State state = kStarted; + + assertStrandNotBound(strand); + + auto thread = stdx::thread([&, barrier] { + // Wait for local thread to bind the strand. + barrier->countDownAndWait(); + + auto guard = strand->bind(); + assertStrandBound(strand); + + // We've acquired the strand after the local thread. + ASSERT_EQ(state, kLocalThread); + state = kWorkerThread; + }); + + { + auto guard = strand->bind(); + assertStrandBound(strand); + + // Wait for the worker thread. + barrier->countDownAndWait(); + + // We've acquired the strand first. 
+ ASSERT_EQ(state, kStarted); + state = kLocalThread; + } + + thread.join(); + + assertStrandNotBound(strand); + + // Bind one last time to synchronize the state. + auto guard = strand->bind(); + + // The worker thread acquired the strand last. + ASSERT_EQ(state, kWorkerThread); +} + +TEST_F(ClientStrandTest, BindLocalAfterWorkerThread) { + auto strand = ClientStrand::make(getServiceContext()->makeClient(kClientName1)); + auto barrier = std::make_shared<unittest::Barrier>(2); + + // Set our state to an initial value. It is unsynchronized, but ClientStrand does synchronize, + // thus it should pass TSAN. + enum State { + kStarted, + kLocalThread, + kWorkerThread, + }; + State state = kStarted; + + assertStrandNotBound(strand); + + auto thread = stdx::thread([&, barrier] { + auto guard = strand->bind(); + assertStrandBound(strand); + + // Wait for local thread. + barrier->countDownAndWait(); + + // We've acquired the strand before the local thread. + ASSERT_EQ(state, kStarted); + state = kWorkerThread; + }); + + { + // Wait for the worker thread to bind the strand. + barrier->countDownAndWait(); + + auto guard = strand->bind(); + assertStrandBound(strand); + + // We've acquired the strand after the worker thread. + ASSERT_EQ(state, kWorkerThread); + state = kLocalThread; + } + + thread.join(); + + assertStrandNotBound(strand); + + // Bind one last time to synchronize the state. + auto guard = strand->bind(); + assertStrandBound(strand); + + // The local thread acquired the strand last. + ASSERT_EQ(state, kLocalThread); +} + +TEST_F(ClientStrandTest, BindManyWorkerThreads) { + auto strand = ClientStrand::make(getServiceContext()->makeClient(kClientName1)); + + constexpr size_t kCount = 100; + auto barrier = std::make_shared<unittest::Barrier>(kCount); + + size_t threadsBound = 0; + + assertStrandNotBound(strand); + + std::vector<stdx::thread> threads; + for (size_t i = 0; i < kCount; ++i) { + threads.emplace_back([&, barrier] { + // Wait for the herd.
+ barrier->countDownAndWait(); + + auto guard = strand->bind(); + assertStrandBound(strand); + + // This is technically atomic on x86 but TSAN should complain if it isn't synchronized. + ++threadsBound; + }); + } + + for (auto& thread : threads) { + thread.join(); + } + + assertStrandNotBound(strand); + + // Bind one last time to access the count. + auto guard = strand->bind(); + assertStrandBound(strand); + + // We've been bound to the amount of threads we expected. + ASSERT_EQ(threadsBound, kCount); +} + +TEST_F(ClientStrandTest, SwapStrands) { + auto strand1 = ClientStrand::make(getServiceContext()->makeClient(kClientName1)); + auto strand2 = ClientStrand::make(getServiceContext()->makeClient(kClientName2)); + + assertStrandNotBound(strand1); + assertStrandNotBound(strand2); + + for (size_t i = 0; i < 100; ++i) { + // Alternate between binding strand1 and strand2. + auto& strand = (i % 2 == 0) ? strand1 : strand2; + auto guard = strand->bind(); + + assertStrandBound(strand); + } + + assertStrandNotBound(strand1); + assertStrandNotBound(strand2); +} + +TEST_F(ClientStrandTest, Executor) { + constexpr size_t kCount = 100; + + auto strand = ClientStrand::make(getServiceContext()->makeClient(kClientName1)); + + assertStrandNotBound(strand); + + auto exec = strand->makeExecutor(InlineQueuedCountingExecutor::make()); + + // Schedule a series of tasks onto the wrapped executor. Note that while this is running on the + // local thread, this is not true recursive execution which would deadlock. + size_t i = 0; + unique_function<void(void)> reschedule; + reschedule = [&] { + exec->schedule([&](Status status) { + invariant(status); + assertStrandBound(strand); + + if (++i >= kCount) { + // We've rescheduled enough. + return; + } + + reschedule(); + }); + }; + + reschedule(); + assertStrandNotBound(strand); + + // Confirm we scheduled as many times as we expected. 
+ ASSERT_EQ(i, kCount); +} + +DEATH_TEST_F(ClientStrandTest, ReplaceCurrentAfterBind, ClientStrand::kUnableToRecoverClient) { + auto strand = ClientStrand::make(getServiceContext()->makeClient(kClientName1)); + + assertStrandNotBound(strand); + + auto guard = strand->bind(); + assertStrandBound(strand); + + // We need to capture the UniqueClient to avoid ABA pointer comparison issues with tcmalloc. In + // practice, this failure mode is most likely if someone is using an AlternativeClientRegion, + // which has its own issues. + auto stolenClient = Client::releaseCurrent(); + Client::setCurrent(getServiceContext()->makeClient(kClientName2)); + + // Dismiss the guard for an explicit failure point. + guard.dismiss(); +} + +DEATH_TEST_F(ClientStrandTest, ReleaseCurrentAfterBind, "No client to release") { + auto strand = ClientStrand::make(getServiceContext()->makeClient(kClientName1)); + + assertStrandNotBound(strand); + + auto guard = strand->bind(); + assertStrandBound(strand); + + Client::releaseCurrent(); + + // Dismiss the guard for an explicit failure point. + guard.dismiss(); +} + +DEATH_TEST_F(ClientStrandTest, BindAfterBind, "Already have client on this thread") { + auto strand1 = ClientStrand::make(getServiceContext()->makeClient(kClientName1)); + auto strand2 = ClientStrand::make(getServiceContext()->makeClient(kClientName2)); + + assertStrandNotBound(strand1); + assertStrandNotBound(strand2); + + // Bind our first strand. + auto guard1 = strand1->bind(); + assertStrandBound(strand1); + + // Bind our second strand...and fail hard. 
+ auto guard2 = strand2->bind(); +} + +} // namespace +} // namespace mongo diff --git a/src/mongo/transport/service_state_machine.cpp b/src/mongo/transport/service_state_machine.cpp index 0cb587426e9..acc60bc3fcb 100644 --- a/src/mongo/transport/service_state_machine.cpp +++ b/src/mongo/transport/service_state_machine.cpp @@ -38,6 +38,7 @@ #include "mongo/base/status.h" #include "mongo/config.h" #include "mongo/db/client.h" +#include "mongo/db/client_strand.h" #include "mongo/db/dbmessage.h" #include "mongo/db/stats/counters.h" #include "mongo/db/traffic_recorder.h" @@ -56,7 +57,6 @@ #include "mongo/transport/transport_layer.h" #include "mongo/util/assert_util.h" #include "mongo/util/concurrency/idle_thread_block.h" -#include "mongo/util/concurrency/thread_name.h" #include "mongo/util/debug_util.h" #include "mongo/util/exit.h" #include "mongo/util/fail_point.h" @@ -204,19 +204,11 @@ public: */ enum class Ownership { kUnowned, kOwned, kStatic }; - /* - * A class that wraps up lifetime management of the _client and _threadName for each step in - * runOnce(); - */ - class ThreadGuard; - class ThreadGuardedExecutor; - Impl(ServiceContext::UniqueClient client) : _state{State::Created}, _serviceContext{client->getServiceContext()}, _sep{_serviceContext->getServiceEntryPoint()}, - _client{std::move(client)}, - _clientPtr{_client.get()} {} + _clientStrand{ClientStrand::make(std::move(client))} {} void start(ServiceExecutorContext seCtx); @@ -257,8 +249,7 @@ public: void sinkCallback(Status status); /* - * Source/Sink message from the TransportLayer. These will invalidate the ThreadGuard just - * before waiting on the TL. + * Source/Sink message from the TransportLayer. 
*/ Future<void> sourceMessage(); Future<void> sinkMessage(); @@ -290,14 +281,14 @@ public: * Gets the transport::Session associated with this connection */ const transport::SessionHandle& session() { - return _clientPtr->session(); + return _clientStrand->getClientPointer()->session(); } /* * Gets the transport::ServiceExecutor associated with this connection. */ ServiceExecutor* executor() { - return ServiceExecutorContext::get(_clientPtr)->getServiceExecutor(); + return ServiceExecutorContext::get(_clientStrand->getClientPointer())->getServiceExecutor(); } private: @@ -307,8 +298,7 @@ private: ServiceEntryPoint* const _sep; transport::SessionHandle _sessionHandle; - ServiceContext::UniqueClient _client; - Client* _clientPtr; + ClientStrandPtr _clientStrand; std::function<void()> _cleanupHook; bool _inExhaust = false; @@ -323,214 +313,6 @@ private: ServiceContext::UniqueOperationContext _opCtx; }; -/* - * This class wraps up the logic for swapping/unswapping the Client when transitioning between - * states. - * - * In debug builds this also ensures that only one thread is working on the SSM at once. - */ -class ServiceStateMachine::Impl::ThreadGuard { - ThreadGuard(ThreadGuard&) = delete; - ThreadGuard& operator=(ThreadGuard&) = delete; - -public: - explicit ThreadGuard(ServiceStateMachine::Impl* ssm) : _ssm{ssm} { - invariant(_ssm); - - if (_ssm->_clientPtr == Client::getCurrent()) { - // We're not the first on this thread, nothing more to do. - return; - } - - auto& client = _ssm->_client; - invariant(client); - - // Set up the thread name - auto oldThreadName = getThreadName(); - const auto& threadName = client->desc(); - if (oldThreadName != threadName) { - _oldThreadName = oldThreadName.toString(); - setThreadName(threadName); - } - - // Swap the current Client so calls to cc() work as expected - Client::setCurrent(std::move(client)); - _haveTakenOwnership = true; - } - - // Constructing from a moved ThreadGuard invalidates the other thread guard. 
- ThreadGuard(ThreadGuard&& other) - : _ssm{std::exchange(other._ssm, nullptr)}, - _haveTakenOwnership{std::exchange(_haveTakenOwnership, false)} {} - - ThreadGuard& operator=(ThreadGuard&& other) { - _ssm = std::exchange(other._ssm, nullptr); - _haveTakenOwnership = std::exchange(other._haveTakenOwnership, false); - return *this; - }; - - ThreadGuard() = delete; - - ~ThreadGuard() { - release(); - } - - explicit operator bool() const { - return _ssm; - } - - void release() { - if (!_ssm) { - // We've been released or moved from. - return; - } - - // If we have a ServiceStateMachine pointer, then it should control the current Client. - invariant(_ssm->_clientPtr == Client::getCurrent()); - - if (auto haveTakenOwnership = std::exchange(_haveTakenOwnership, false); - !haveTakenOwnership) { - // Reset our pointer so that we cannot release again. - _ssm = nullptr; - - // We are not the original owner, nothing more to do. - return; - } - - // Reclaim the client. - _ssm->_client = Client::releaseCurrent(); - - // Reset our pointer so that we cannot release again. - _ssm = nullptr; - - if (!_oldThreadName.empty()) { - // Reset the old thread name. - setThreadName(_oldThreadName); - } - } - -private: - ServiceStateMachine::Impl* _ssm = nullptr; - - bool _haveTakenOwnership = false; - std::string _oldThreadName; -}; - -auto getThreadGuardedExecutor = - Client::declareDecoration<std::shared_ptr<ServiceStateMachine::Impl::ThreadGuardedExecutor>>(); - -/* - * A ThreadGuardedExecutor is a client decoration that can wrap any OutOfLineExecutor to allow - * processing tasks while having the client object (i.e., `Client`) attached to the executor thread. - * In particular, scheduling tasks through a ThreadGuardedExecutor ensures that the corresponding - * client object is reachable through `Client::getCurrent()` and prevents concurrent accesses to the - * client object. A ThreadGuardedExecutor is only valid in the context of ServiceStateMachine. 
- * Also, any task scheduled through ThreadGuardedExecutor captures a reference (i.e., shared - * pointer) to both ServiceStateMachine and ThreadGuardedExecutor, thus accessing these objects - * through raw pointers (e.g., `this`) is considered safe. - */ -class ServiceStateMachine::Impl::ThreadGuardedExecutor - : public std::enable_shared_from_this<ServiceStateMachine::Impl::ThreadGuardedExecutor> { -public: - // Wraps an instance of OutOfLineExecutor and delegates scheduling to ThreadGuardedExecutor - class WrappedExecutor : public OutOfLineExecutor, - public std::enable_shared_from_this<WrappedExecutor> { - public: - WrappedExecutor(const WrappedExecutor&) = delete; - WrappedExecutor(WrappedExecutor&&) = delete; - - WrappedExecutor(std::shared_ptr<ThreadGuardedExecutor> parent, OutOfLineExecutor* executor) - : _parent(std::move(parent)), _executor(executor) {} - - void schedule(OutOfLineExecutor::Task task) override { - _parent->schedule(_executor, std::move(task)); - } - - private: - std::shared_ptr<ThreadGuardedExecutor> const _parent; - OutOfLineExecutor* const _executor; - }; - - ThreadGuardedExecutor() = delete; - ThreadGuardedExecutor(const ThreadGuardedExecutor&) = delete; - - explicit ThreadGuardedExecutor(std::weak_ptr<ServiceStateMachine::Impl> ssm) - : _ssm(std::move(ssm)) {} - - ThreadGuardedExecutor(ThreadGuardedExecutor&& other) - : _isBusy{other._isBusy.load()}, _ssm(other._ssm), _guard(std::move(other._guard)) {} - - ~ThreadGuardedExecutor() { - invariant(!_isBusy.load()); - invariant(!_guard.has_value()); - } - - static void set(Client* client, ThreadGuardedExecutor instance) { - auto& clientThreadGuardedExecutor = getThreadGuardedExecutor(client); - invariant(!clientThreadGuardedExecutor); - clientThreadGuardedExecutor = std::make_shared<decltype(instance)>(std::move(instance)); - } - - static std::shared_ptr<ThreadGuardedExecutor> get(const Client* client) { - return getThreadGuardedExecutor(client); - } - - auto 
wrapExecutor(OutOfLineExecutor* executor) { - return std::make_shared<WrappedExecutor>(shared_from_this(), executor); - } - - void schedule(OutOfLineExecutor* executor, OutOfLineExecutor::Task task) { - // Since `ThreadGuardedExecutor` is a client decoration, and SSM owns the client object, - // ServiceStateMachine must be present when tasks are scheduled here. - auto ssm = _ssm.lock(); - invariant(ssm); - - executor->schedule([this, - executor, // Valid as the executor must be present to run the task - task = std::move(task), - ssm = std::move(ssm), - anchor = shared_from_this()](Status status) mutable { - if (auto wasBusy = _isBusy.swap(true); wasBusy) { - // Reschedule if another executor thread is running a task for the client. - LOGV2_DEBUG(4910704, kDiagnosticLogLevel, "Rescheduling thread-guarded task"); - schedule(executor, std::move(task)); - return; - } - - invariant(!_guard.has_value()); - _guard.emplace(ssm.get()); - LOGV2_DEBUG(4910701, kDiagnosticLogLevel, "Acquired ThreadGuard in scheduled task"); - - ON_BLOCK_EXIT([&] { - if (_guard.has_value()) { - releaseThreadGuard(); - } - }); - - LOGV2_DEBUG( - 4910703, kDiagnosticLogLevel, "Started running task in a thread-guarded context"); - task(status); - }); - } - - // Must only be called on the thread that owns the guard - void releaseThreadGuard() { - invariant(_guard.has_value() && _isBusy.load()); - LOGV2_DEBUG(4910702, kDiagnosticLogLevel, "Releasing the ThreadGuard"); - _guard = boost::none; - _isBusy.store(false); - } - -private: - static constexpr auto kDiagnosticLogLevel = 3; - - // Set to `true` if the executor is busy running a task on behalf of the corresponding client. 
- AtomicWord<bool> _isBusy{false}; - - std::weak_ptr<ServiceStateMachine::Impl> _ssm; - boost::optional<ThreadGuard> _guard; -}; - Future<void> ServiceStateMachine::Impl::sourceMessage() { invariant(_inMessage.empty()); invariant(_state.load() == State::Source); @@ -583,8 +365,6 @@ Future<void> ServiceStateMachine::Impl::sinkMessage() { } void ServiceStateMachine::Impl::sourceCallback(Status status) { - auto guard = ThreadGuard(this); - invariant(state() == State::SourceWait); auto remote = session()->remote(); @@ -626,8 +406,6 @@ void ServiceStateMachine::Impl::sourceCallback(Status status) { } void ServiceStateMachine::Impl::sinkCallback(Status status) { - auto guard = ThreadGuard(this); - invariant(state() == State::SinkWait); // If there was an error sinking the message to the client, then we should print an error and @@ -736,25 +514,25 @@ Future<void> ServiceStateMachine::Impl::processMessage() { void ServiceStateMachine::Impl::start(ServiceExecutorContext seCtx) { { - stdx::lock_guard lk(*_clientPtr); - - ServiceExecutorContext::set(_clientPtr, std::move(seCtx)); + auto client = _clientStrand->getClientPointer(); + stdx::lock_guard lk(*client); + ServiceExecutorContext::set(client, std::move(seCtx)); } - // Set the executor decoration here to ensure `shared_from_this()` returns a valid pointer - ThreadGuardedExecutor::set(_clientPtr, ThreadGuardedExecutor(shared_from_this())); + invariant(_state.swap(State::Source) == State::Created); - ThreadGuardedExecutor::get(_clientPtr) - ->schedule( - executor(), - GuaranteedExecutor::enforceRunOnce([this, anchor = shared_from_this()](Status status) { - // If this is the first run of the SSM, then update its state to Source - if (state() == State::Created) { - _state.store(State::Source); - } + auto cb = [this, anchor = shared_from_this()](Status status) { + _clientStrand->run([&] { + if (ErrorCodes::isCancelationError(status)) { + cleanupSession(); + return; + } + invariant(status); - runOnce(); - })); + runOnce(); + 
}); + }; + executor()->schedule(std::move(cb)); } void ServiceStateMachine::Impl::runOnce() { @@ -791,19 +569,25 @@ void ServiceStateMachine::Impl::runOnce() { "error"_attr = status); terminate(); - ThreadGuardedExecutor::get(_clientPtr) - ->schedule(executor(), - GuaranteedExecutor::enforceRunOnce( - [this, anchor = shared_from_this()](Status status) { - cleanupSession(); - })); + auto cb = [this, anchor = shared_from_this()](Status status) { + _clientStrand->run([&] { cleanupSession(); }); + }; + executor()->schedule(std::move(cb)); return; } - ThreadGuardedExecutor::get(_clientPtr) - ->schedule(executor(), GuaranteedExecutor::enforceRunOnce([this](Status status) { - runOnce(); - })); + auto cb = [this, anchor = shared_from_this()](Status status) { + _clientStrand->run([&] { + if (ErrorCodes::isCancelationError(status)) { + cleanupSession(); + return; + } + invariant(status); + + runOnce(); + }); + }; + executor()->schedule(std::move(cb)); }); } @@ -878,8 +662,9 @@ void ServiceStateMachine::Impl::cleanupSession() { cleanupExhaustResources(); { - stdx::lock_guard lk(*_clientPtr); - transport::ServiceExecutorContext::reset(_clientPtr); + auto client = _clientStrand->getClientPointer(); + stdx::lock_guard lk(*client); + transport::ServiceExecutorContext::reset(client); } if (auto cleanupHook = std::exchange(_cleanupHook, {})) { |