summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBen Caimano <ben.caimano@10gen.com>2020-09-23 04:03:43 +0000
committerBen Caimano <ben.caimano@10gen.com>2020-10-19 20:25:05 +0000
commitfa29e47f37da2353b49ee71c907026b769fdc607 (patch)
tree1a9137ca238a39cc1926e3876f3bca38e12fa39d
parentb03c93d55baefc8a70a6ee790f1d497fd1ff8b70 (diff)
downloadmongo-fa29e47f37da2353b49ee71c907026b769fdc607.tar.gz
SERVER-51278 Introduced ClientStrand
-rw-r--r--src/mongo/db/SConscript2
-rw-r--r--src/mongo/db/client.cpp2
-rw-r--r--src/mongo/db/client_strand.cpp95
-rw-r--r--src/mongo/db/client_strand.h214
-rw-r--r--src/mongo/db/client_strand_test.cpp381
-rw-r--r--src/mongo/transport/service_state_machine.cpp295
6 files changed, 733 insertions, 256 deletions
diff --git a/src/mongo/db/SConscript b/src/mongo/db/SConscript
index 14706fa55ce..e62270e4f62 100644
--- a/src/mongo/db/SConscript
+++ b/src/mongo/db/SConscript
@@ -450,6 +450,7 @@ env.Library(
source=[
'baton.cpp',
'client.cpp',
+ 'client_strand.cpp',
'default_baton.cpp',
'operation_context.cpp',
'operation_context_group.cpp',
@@ -2212,6 +2213,7 @@ envWithAsio.CppUnitTest(
target='db_unittests',
source=[
'catalog_raii_test.cpp',
+ 'client_strand_test.cpp',
'collection_index_usage_tracker_test.cpp',
'commands_test.cpp',
'curop_test.cpp',
diff --git a/src/mongo/db/client.cpp b/src/mongo/db/client.cpp
index a9678bee962..99175dd618f 100644
--- a/src/mongo/db/client.cpp
+++ b/src/mongo/db/client.cpp
@@ -155,7 +155,7 @@ bool haveClient() {
}
ServiceContext::UniqueClient Client::releaseCurrent() {
- invariant(haveClient());
+ invariant(haveClient(), "No client to release");
return std::move(currentClient);
}
diff --git a/src/mongo/db/client_strand.cpp b/src/mongo/db/client_strand.cpp
new file mode 100644
index 00000000000..470ae6d3133
--- /dev/null
+++ b/src/mongo/db/client_strand.cpp
@@ -0,0 +1,95 @@
+/**
+ * Copyright (C) 2020-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kDefault
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/db/client_strand.h"
+
+#include "mongo/logv2/log.h"
+#include "mongo/util/concurrency/thread_name.h"
+
+namespace mongo {
+namespace {
+struct ClientStrandData {
+ ClientStrand* strand = nullptr;
+};
+
+auto getClientStrandData = Client::declareDecoration<ClientStrandData>();
+} // namespace
+
+boost::intrusive_ptr<ClientStrand> ClientStrand::make(ServiceContext::UniqueClient client) {
+ auto strand = make_intrusive<ClientStrand>(std::move(client));
+ getClientStrandData(strand->getClientPointer()).strand = strand.get();
+ return strand;
+}
+
+boost::intrusive_ptr<ClientStrand> ClientStrand::get(Client* client) {
+ return getClientStrandData(client).strand;
+}
+
+void ClientStrand::_setCurrent() noexcept {
+ invariant(_isBound.load());
+ invariant(_client);
+
+ LOGV2_DEBUG(
+ 4910701, kDiagnosticLogLevel, "Setting the Client", "client"_attr = _client->desc());
+
+ // Set the Client for this thread so calls to Client::getCurrent() works as expected.
+ Client::setCurrent(std::move(_client));
+
+ // Set up the thread name.
+ auto oldThreadName = getThreadName();
+ StringData threadName = _clientPtr->desc();
+ if (oldThreadName != threadName) {
+ _oldThreadName = oldThreadName.toString();
+ setThreadName(threadName);
+ LOGV2_DEBUG(4910701, kDiagnosticLogLevel, "Set thread name", "name"_attr = threadName);
+ }
+}
+
+void ClientStrand::_releaseCurrent() noexcept {
+ invariant(_isBound.load());
+ invariant(!_client);
+
+ // Reclaim the client.
+ _client = Client::releaseCurrent();
+ invariant(_client.get() == _clientPtr, kUnableToRecoverClient);
+
+ if (!_oldThreadName.empty()) {
+ // Reset the old thread name.
+ setThreadName(_oldThreadName);
+ }
+
+ LOGV2_DEBUG(
+ 4910702, kDiagnosticLogLevel, "Released the Client", "client"_attr = _client->desc());
+}
+
+} // namespace mongo
diff --git a/src/mongo/db/client_strand.h b/src/mongo/db/client_strand.h
new file mode 100644
index 00000000000..20b9d940d27
--- /dev/null
+++ b/src/mongo/db/client_strand.h
@@ -0,0 +1,214 @@
+/**
+ * Copyright (C) 2020-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include <string>
+
+#include "mongo/db/client.h"
+#include "mongo/db/service_context.h"
+#include "mongo/platform/atomic_word.h"
+#include "mongo/stdx/mutex.h"
+#include "mongo/util/intrusive_counter.h"
+#include "mongo/util/out_of_line_executor.h"
+
+namespace mongo {
+
+/**
+ * ClientStrand is a reference counted type for loaning Clients to threads.
+ *
+ * ClientStrand maintains the lifetime of its wrapped Client object and provides functionality to
+ * "bind" that Client to one and only one thread at a time. Its functions are synchronized.
+ */
+class ClientStrand final : public RefCountable {
+ static constexpr auto kDiagnosticLogLevel = 3;
+
+public:
+ static constexpr auto kUnableToRecoverClient = "Unable to recover Client for ClientStrand";
+
+ /**
+ * A simple RAII guard to set and release Clients.
+ */
+ class Guard {
+ public:
+ Guard() = default;
+ Guard(Guard&&) = default;
+ Guard& operator=(Guard&&) = default;
+
+ Guard(const Guard&) = delete;
+ Guard& operator=(const Guard&) = delete;
+
+ Guard(ClientStrand* strand) : _strand(strand) {
+ // Hold the lock for as long as the Guard is around. This forces other consumers to
+ // queue behind the Guard.
+ _strand->_mutex.lock();
+ _strand->_isBound.store(true);
+
+ _strand->_setCurrent();
+ }
+
+ ~Guard() {
+ dismiss();
+ }
+
+ void dismiss() noexcept {
+ auto strand = std::exchange(_strand, {});
+ if (!strand) {
+ return;
+ }
+
+ strand->_releaseCurrent();
+ strand->_isBound.store(false);
+ strand->_mutex.unlock();
+ }
+
+ Client* get() noexcept {
+ return _strand->getClientPointer();
+ }
+
+ Client* operator->() noexcept {
+ return get();
+ }
+
+ Client& operator*() noexcept {
+ return *get();
+ }
+
+ private:
+ boost::intrusive_ptr<ClientStrand> _strand;
+ };
+
+ /**
+ * A simple wrapping executor to run tasks while a Client is bound.
+ */
+ class Executor final : public OutOfLineExecutor {
+ public:
+ Executor(ClientStrand* strand, ExecutorPtr exec)
+ : _strand(strand), _exec(std::move(exec)) {}
+ void schedule(Task task) override;
+
+ private:
+ boost::intrusive_ptr<ClientStrand> _strand;
+ ExecutorPtr _exec;
+ };
+
+ /**
+ * Make a new ClientStrand from a UniqueClient.
+ */
+ static boost::intrusive_ptr<ClientStrand> make(ServiceContext::UniqueClient client);
+
+ /**
+ * Acquire an owning ClientStrand given a client.
+ *
+ * This will return nullptr if the Client does not belong to a ClientStrand.
+ */
+ static boost::intrusive_ptr<ClientStrand> get(Client* client);
+
+ ClientStrand(ServiceContext::UniqueClient client)
+ : _clientPtr(client.get()), _client(std::move(client)) {}
+
+ /**
+ * Get a pointer to the underlying Client.
+ */
+ Client* getClientPointer() noexcept {
+ return _clientPtr;
+ }
+
+ /**
+ * Set the current Client for this thread and return a RAII guard to release it eventually.
+ *
+ * If the Client is currently bound, this function will block until the Client is available.
+ */
+ auto bind() {
+ return Guard(this);
+ }
+
+ /**
+ * Run a Task with the Client bound to the current thread.
+ *
+ * This function runs the task inline and assumes that the Client is not already bound to the
+ * current thread. If the Client is currently bound, this function will block until it is
+ * released.
+ */
+ template <typename Task, typename... Args>
+ void run(Task task, Args&&... args) {
+ auto guard = bind();
+
+ return task(std::forward<Args>(args)...);
+ }
+
+ /**
+ * Make a wrapped executor around another.
+ */
+ ExecutorPtr makeExecutor(ExecutorPtr exec) {
+ return std::make_shared<Executor>(this, std::move(exec));
+ }
+
+ /**
+ * Return if the strand is currently bound to a Client.
+ */
+ bool isBound() const noexcept {
+ return _isBound.load();
+ }
+
+private:
+ /**
+ * Bind the Client to the current thread.
+ *
+ * This is only valid to call if no other thread has the Client bound.
+ */
+ void _setCurrent() noexcept;
+
+ /**
+ * Release the Client from the current thread.
+ *
+ * This is valid to call multiple times on the same thread. It is not valid to mix this with
+ * Client::releaseCurrent().
+ */
+ void _releaseCurrent() noexcept;
+
+ Client* const _clientPtr;
+
+ stdx::mutex _mutex; // NOLINT
+
+ // Once we have stdx::atomic::wait(), we can get rid of the mutex in favor of this variable.
+ AtomicWord<bool> _isBound{false};
+
+ ServiceContext::UniqueClient _client;
+
+ std::string _oldThreadName;
+};
+
+inline void ClientStrand::Executor::schedule(Task task) {
+ _exec->schedule([task = std::forward<Task>(task), strand = _strand](Status status) mutable {
+ strand->run(std::move(task), std::move(status));
+ });
+}
+
+using ClientStrandPtr = boost::intrusive_ptr<ClientStrand>;
+
+} // namespace mongo
diff --git a/src/mongo/db/client_strand_test.cpp b/src/mongo/db/client_strand_test.cpp
new file mode 100644
index 00000000000..68c055f53c0
--- /dev/null
+++ b/src/mongo/db/client_strand_test.cpp
@@ -0,0 +1,381 @@
+/**
+ * Copyright (C) 2020-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include <memory>
+
+#include "mongo/db/client_strand.h"
+#include "mongo/db/service_context_test_fixture.h"
+#include "mongo/unittest/barrier.h"
+#include "mongo/unittest/death_test.h"
+#include "mongo/unittest/unittest.h"
+#include "mongo/util/assert_util.h"
+#include "mongo/util/concurrency/thread_name.h"
+#include "mongo/util/executor_test_util.h"
+
+namespace mongo {
+namespace {
+
+class ClientStrandTest : public unittest::Test, public ScopedGlobalServiceContextForTest {
+public:
+ constexpr static auto kClientName1 = "foo";
+ constexpr static auto kClientName2 = "bar";
+
+ void assertStrandNotBound(const ClientStrandPtr& strand) {
+ ASSERT_FALSE(haveClient());
+ ASSERT_FALSE(strand->isBound());
+ }
+
+ void assertStrandBound(const ClientStrandPtr& strand) {
+ // We have a Client.
+ ASSERT_TRUE(haveClient());
+ ASSERT_TRUE(strand->isBound());
+
+ // The current Client and Thread have the correct name.
+ auto client = strand->getClientPointer();
+ ASSERT_EQ(client, Client::getCurrent());
+ ASSERT_EQ(client->desc(), getThreadName());
+ }
+};
+
+TEST_F(ClientStrandTest, CreateOnly) {
+ auto strand = ClientStrand::make(getServiceContext()->makeClient(kClientName1));
+
+ // We have no bound Client.
+ assertStrandNotBound(strand);
+
+ // The Client should exist.
+ ASSERT_TRUE(strand->getClientPointer());
+
+ // The Client should reference its ClientStrand.
+ ASSERT_EQ(ClientStrand::get(strand->getClientPointer()), strand);
+}
+
+TEST_F(ClientStrandTest, BindOnce) {
+ auto strand = ClientStrand::make(getServiceContext()->makeClient(kClientName1));
+
+ // We have no bound Client.
+ assertStrandNotBound(strand);
+
+ {
+ // Bind a single client
+ auto guard = strand->bind();
+ assertStrandBound(strand);
+
+ // The guard allows us to get the Client.
+ ASSERT_EQ(guard.get(), strand->getClientPointer());
+ }
+
+ // We have no bound Client again.
+ assertStrandNotBound(strand);
+}
+
+TEST_F(ClientStrandTest, BindMultipleTimes) {
+ auto strand = ClientStrand::make(getServiceContext()->makeClient(kClientName1));
+
+ // We have no bound Client.
+ assertStrandNotBound(strand);
+
+ for (auto i = 0; i < 100; ++i) {
+ // Bind a bunch of times.
+
+ {
+ auto guard = strand->bind();
+ assertStrandBound(strand);
+ }
+
+ // We have no bound Client again.
+ assertStrandNotBound(strand);
+ }
+}
+
+TEST_F(ClientStrandTest, BindMultipleTimesAndDismiss) {
+ auto strand = ClientStrand::make(getServiceContext()->makeClient(kClientName1));
+
+ // We have no bound Client.
+ assertStrandNotBound(strand);
+
+ auto guard = strand->bind();
+ for (auto i = 0; i < 100; ++i) {
+ assertStrandBound(strand);
+
+ // Dismiss the current guard.
+ guard.dismiss();
+ assertStrandNotBound(strand);
+
+ // Assign a new guard.
+ guard = strand->bind();
+ }
+
+ // At the end we have a strand bound.
+ assertStrandBound(strand);
+}
+
+TEST_F(ClientStrandTest, BindLocalBeforeWorkerThread) {
+ auto strand = ClientStrand::make(getServiceContext()->makeClient(kClientName1));
+ auto barrier = std::make_shared<unittest::Barrier>(2);
+
+ // Set our state to an initial value. It is unsynchronized, but ClientStrand does synchronize,
+ // thus it should pass TSAN.
+ enum State {
+ kStarted,
+ kLocalThread,
+ kWorkerThread,
+ };
+ State state = kStarted;
+
+ assertStrandNotBound(strand);
+
+ auto thread = stdx::thread([&, barrier] {
+ // Wait for local thread to bind the strand.
+ barrier->countDownAndWait();
+
+ auto guard = strand->bind();
+ assertStrandBound(strand);
+
+ // We've acquired the strand after the local thread.
+ ASSERT_EQ(state, kLocalThread);
+ state = kWorkerThread;
+ });
+
+ {
+ auto guard = strand->bind();
+ assertStrandBound(strand);
+
+ // Wait for the worker thread.
+ barrier->countDownAndWait();
+
+ // We've acquired the strand first.
+ ASSERT_EQ(state, kStarted);
+ state = kLocalThread;
+ }
+
+ thread.join();
+
+ assertStrandNotBound(strand);
+
+ // Bind one last time to synchronize the state.
+ auto guard = strand->bind();
+
+ // The worker thread acquired the strand last.
+ ASSERT_EQ(state, kWorkerThread);
+}
+
+TEST_F(ClientStrandTest, BindLocalAfterWorkerThread) {
+ auto strand = ClientStrand::make(getServiceContext()->makeClient(kClientName1));
+ auto barrier = std::make_shared<unittest::Barrier>(2);
+
+ // Set our state to an initial value. It is unsynchronized, but ClientStrand does synchronize,
+ // thus it should pass TSAN.
+ enum State {
+ kStarted,
+ kLocalThread,
+ kWorkerThread,
+ };
+ State state = kStarted;
+
+ assertStrandNotBound(strand);
+
+ auto thread = stdx::thread([&, barrier] {
+ auto guard = strand->bind();
+ assertStrandBound(strand);
+
+ // Wait for local thread.
+ barrier->countDownAndWait();
+
+ // We've acquired the strand after the local thread.
+ ASSERT_EQ(state, kStarted);
+ state = kWorkerThread;
+ });
+
+ {
+ // Wait for the worker thread to bind the strand.
+ barrier->countDownAndWait();
+
+ auto guard = strand->bind();
+ assertStrandBound(strand);
+
+ // We've acquired the strand first.
+ ASSERT_EQ(state, kWorkerThread);
+ state = kLocalThread;
+ }
+
+ thread.join();
+
+ assertStrandNotBound(strand);
+
+ // Bind one last time to synchronize the state.
+ auto guard = strand->bind();
+ assertStrandBound(strand);
+
+ // The local thread acquired the strand last.
+ ASSERT_EQ(state, kLocalThread);
+}
+
+TEST_F(ClientStrandTest, BindManyWorkerThreads) {
+ auto strand = ClientStrand::make(getServiceContext()->makeClient(kClientName1));
+
+ constexpr size_t kCount = 100;
+ auto barrier = std::make_shared<unittest::Barrier>(kCount);
+
+ size_t threadsBound = 0;
+
+ assertStrandNotBound(strand);
+
+ std::vector<stdx::thread> threads;
+ for (size_t i = 0; i < kCount; ++i) {
+ threads.emplace_back([&, barrier] {
+ // Wait for the herd.
+ barrier->countDownAndWait();
+
+ auto guard = strand->bind();
+ assertStrandBound(strand);
+
+ // This is technically atomic on x86 but TSAN should complain if it isn't synchronized.
+ ++threadsBound;
+ });
+ }
+
+ for (auto& thread : threads) {
+ thread.join();
+ }
+
+ assertStrandNotBound(strand);
+
+ // Bind one last time to access the count.
+ auto guard = strand->bind();
+ assertStrandBound(strand);
+
+ // We've been bound to the amount of threads we expected.
+ ASSERT_EQ(threadsBound, kCount);
+}
+
+TEST_F(ClientStrandTest, SwapStrands) {
+ auto strand1 = ClientStrand::make(getServiceContext()->makeClient(kClientName1));
+ auto strand2 = ClientStrand::make(getServiceContext()->makeClient(kClientName2));
+
+ assertStrandNotBound(strand1);
+ assertStrandNotBound(strand2);
+
+ for (size_t i = 0; i < 100; ++i) {
+ // Alternate between binding strand1 and strand2.
+ auto& strand = (i % 2 == 0) ? strand1 : strand2;
+ auto guard = strand->bind();
+
+ assertStrandBound(strand);
+ }
+
+ assertStrandNotBound(strand1);
+ assertStrandNotBound(strand2);
+}
+
+TEST_F(ClientStrandTest, Executor) {
+ constexpr size_t kCount = 100;
+
+ auto strand = ClientStrand::make(getServiceContext()->makeClient(kClientName1));
+
+ assertStrandNotBound(strand);
+
+ auto exec = strand->makeExecutor(InlineQueuedCountingExecutor::make());
+
+ // Schedule a series of tasks onto the wrapped executor. Note that while this is running on the
+ // local thread, this is not true recursive execution which would deadlock.
+ size_t i = 0;
+ unique_function<void(void)> reschedule;
+ reschedule = [&] {
+ exec->schedule([&](Status status) {
+ invariant(status);
+ assertStrandBound(strand);
+
+ if (++i >= kCount) {
+ // We've rescheduled enough.
+ return;
+ }
+
+ reschedule();
+ });
+ };
+
+ reschedule();
+ assertStrandNotBound(strand);
+
+ // Confirm we scheduled as many times as we expected.
+ ASSERT_EQ(i, kCount);
+}
+
+DEATH_TEST_F(ClientStrandTest, ReplaceCurrentAfterBind, ClientStrand::kUnableToRecoverClient) {
+ auto strand = ClientStrand::make(getServiceContext()->makeClient(kClientName1));
+
+ assertStrandNotBound(strand);
+
+ auto guard = strand->bind();
+ assertStrandBound(strand);
+
+ // We need to capture the UniqueClient to avoid ABA pointer comparison issues with tcmalloc. In
+ // practice, this failure mode is most likely if someone is using an AlternativeClientRegion,
+ // which has its own issues.
+ auto stolenClient = Client::releaseCurrent();
+ Client::setCurrent(getServiceContext()->makeClient(kClientName2));
+
+ // Dismiss the guard for an explicit failure point.
+ guard.dismiss();
+}
+
+DEATH_TEST_F(ClientStrandTest, ReleaseCurrentAfterBind, "No client to release") {
+ auto strand = ClientStrand::make(getServiceContext()->makeClient(kClientName1));
+
+ assertStrandNotBound(strand);
+
+ auto guard = strand->bind();
+ assertStrandBound(strand);
+
+ Client::releaseCurrent();
+
+ // Dismiss the guard for an explicit failure point.
+ guard.dismiss();
+}
+
+DEATH_TEST_F(ClientStrandTest, BindAfterBind, "Already have client on this thread") {
+ auto strand1 = ClientStrand::make(getServiceContext()->makeClient(kClientName1));
+ auto strand2 = ClientStrand::make(getServiceContext()->makeClient(kClientName2));
+
+ assertStrandNotBound(strand1);
+ assertStrandNotBound(strand2);
+
+ // Bind our first strand.
+ auto guard1 = strand1->bind();
+ assertStrandBound(strand1);
+
+ // Bind our second strand...and fail hard.
+ auto guard2 = strand2->bind();
+}
+
+} // namespace
+} // namespace mongo
diff --git a/src/mongo/transport/service_state_machine.cpp b/src/mongo/transport/service_state_machine.cpp
index 0cb587426e9..acc60bc3fcb 100644
--- a/src/mongo/transport/service_state_machine.cpp
+++ b/src/mongo/transport/service_state_machine.cpp
@@ -38,6 +38,7 @@
#include "mongo/base/status.h"
#include "mongo/config.h"
#include "mongo/db/client.h"
+#include "mongo/db/client_strand.h"
#include "mongo/db/dbmessage.h"
#include "mongo/db/stats/counters.h"
#include "mongo/db/traffic_recorder.h"
@@ -56,7 +57,6 @@
#include "mongo/transport/transport_layer.h"
#include "mongo/util/assert_util.h"
#include "mongo/util/concurrency/idle_thread_block.h"
-#include "mongo/util/concurrency/thread_name.h"
#include "mongo/util/debug_util.h"
#include "mongo/util/exit.h"
#include "mongo/util/fail_point.h"
@@ -204,19 +204,11 @@ public:
*/
enum class Ownership { kUnowned, kOwned, kStatic };
- /*
- * A class that wraps up lifetime management of the _client and _threadName for each step in
- * runOnce();
- */
- class ThreadGuard;
- class ThreadGuardedExecutor;
-
Impl(ServiceContext::UniqueClient client)
: _state{State::Created},
_serviceContext{client->getServiceContext()},
_sep{_serviceContext->getServiceEntryPoint()},
- _client{std::move(client)},
- _clientPtr{_client.get()} {}
+ _clientStrand{ClientStrand::make(std::move(client))} {}
void start(ServiceExecutorContext seCtx);
@@ -257,8 +249,7 @@ public:
void sinkCallback(Status status);
/*
- * Source/Sink message from the TransportLayer. These will invalidate the ThreadGuard just
- * before waiting on the TL.
+ * Source/Sink message from the TransportLayer.
*/
Future<void> sourceMessage();
Future<void> sinkMessage();
@@ -290,14 +281,14 @@ public:
* Gets the transport::Session associated with this connection
*/
const transport::SessionHandle& session() {
- return _clientPtr->session();
+ return _clientStrand->getClientPointer()->session();
}
/*
* Gets the transport::ServiceExecutor associated with this connection.
*/
ServiceExecutor* executor() {
- return ServiceExecutorContext::get(_clientPtr)->getServiceExecutor();
+ return ServiceExecutorContext::get(_clientStrand->getClientPointer())->getServiceExecutor();
}
private:
@@ -307,8 +298,7 @@ private:
ServiceEntryPoint* const _sep;
transport::SessionHandle _sessionHandle;
- ServiceContext::UniqueClient _client;
- Client* _clientPtr;
+ ClientStrandPtr _clientStrand;
std::function<void()> _cleanupHook;
bool _inExhaust = false;
@@ -323,214 +313,6 @@ private:
ServiceContext::UniqueOperationContext _opCtx;
};
-/*
- * This class wraps up the logic for swapping/unswapping the Client when transitioning between
- * states.
- *
- * In debug builds this also ensures that only one thread is working on the SSM at once.
- */
-class ServiceStateMachine::Impl::ThreadGuard {
- ThreadGuard(ThreadGuard&) = delete;
- ThreadGuard& operator=(ThreadGuard&) = delete;
-
-public:
- explicit ThreadGuard(ServiceStateMachine::Impl* ssm) : _ssm{ssm} {
- invariant(_ssm);
-
- if (_ssm->_clientPtr == Client::getCurrent()) {
- // We're not the first on this thread, nothing more to do.
- return;
- }
-
- auto& client = _ssm->_client;
- invariant(client);
-
- // Set up the thread name
- auto oldThreadName = getThreadName();
- const auto& threadName = client->desc();
- if (oldThreadName != threadName) {
- _oldThreadName = oldThreadName.toString();
- setThreadName(threadName);
- }
-
- // Swap the current Client so calls to cc() work as expected
- Client::setCurrent(std::move(client));
- _haveTakenOwnership = true;
- }
-
- // Constructing from a moved ThreadGuard invalidates the other thread guard.
- ThreadGuard(ThreadGuard&& other)
- : _ssm{std::exchange(other._ssm, nullptr)},
- _haveTakenOwnership{std::exchange(_haveTakenOwnership, false)} {}
-
- ThreadGuard& operator=(ThreadGuard&& other) {
- _ssm = std::exchange(other._ssm, nullptr);
- _haveTakenOwnership = std::exchange(other._haveTakenOwnership, false);
- return *this;
- };
-
- ThreadGuard() = delete;
-
- ~ThreadGuard() {
- release();
- }
-
- explicit operator bool() const {
- return _ssm;
- }
-
- void release() {
- if (!_ssm) {
- // We've been released or moved from.
- return;
- }
-
- // If we have a ServiceStateMachine pointer, then it should control the current Client.
- invariant(_ssm->_clientPtr == Client::getCurrent());
-
- if (auto haveTakenOwnership = std::exchange(_haveTakenOwnership, false);
- !haveTakenOwnership) {
- // Reset our pointer so that we cannot release again.
- _ssm = nullptr;
-
- // We are not the original owner, nothing more to do.
- return;
- }
-
- // Reclaim the client.
- _ssm->_client = Client::releaseCurrent();
-
- // Reset our pointer so that we cannot release again.
- _ssm = nullptr;
-
- if (!_oldThreadName.empty()) {
- // Reset the old thread name.
- setThreadName(_oldThreadName);
- }
- }
-
-private:
- ServiceStateMachine::Impl* _ssm = nullptr;
-
- bool _haveTakenOwnership = false;
- std::string _oldThreadName;
-};
-
-auto getThreadGuardedExecutor =
- Client::declareDecoration<std::shared_ptr<ServiceStateMachine::Impl::ThreadGuardedExecutor>>();
-
-/*
- * A ThreadGuardedExecutor is a client decoration that can wrap any OutOfLineExecutor to allow
- * processing tasks while having the client object (i.e., `Client`) attached to the executor thread.
- * In particular, scheduling tasks through a ThreadGuardedExecutor ensures that the corresponding
- * client object is reachable through `Client::getCurrent()` and prevents concurrent accesses to the
- * client object. A ThreadGuardedExecutor is only valid in the context of ServiceStateMachine.
- * Also, any task scheduled through ThreadGuardedExecutor captures a reference (i.e., shared
- * pointer) to both ServiceStateMachine and ThreadGuardedExecutor, thus accessing these objects
- * through raw pointers (e.g., `this`) is considered safe.
- */
-class ServiceStateMachine::Impl::ThreadGuardedExecutor
- : public std::enable_shared_from_this<ServiceStateMachine::Impl::ThreadGuardedExecutor> {
-public:
- // Wraps an instance of OutOfLineExecutor and delegates scheduling to ThreadGuardedExecutor
- class WrappedExecutor : public OutOfLineExecutor,
- public std::enable_shared_from_this<WrappedExecutor> {
- public:
- WrappedExecutor(const WrappedExecutor&) = delete;
- WrappedExecutor(WrappedExecutor&&) = delete;
-
- WrappedExecutor(std::shared_ptr<ThreadGuardedExecutor> parent, OutOfLineExecutor* executor)
- : _parent(std::move(parent)), _executor(executor) {}
-
- void schedule(OutOfLineExecutor::Task task) override {
- _parent->schedule(_executor, std::move(task));
- }
-
- private:
- std::shared_ptr<ThreadGuardedExecutor> const _parent;
- OutOfLineExecutor* const _executor;
- };
-
- ThreadGuardedExecutor() = delete;
- ThreadGuardedExecutor(const ThreadGuardedExecutor&) = delete;
-
- explicit ThreadGuardedExecutor(std::weak_ptr<ServiceStateMachine::Impl> ssm)
- : _ssm(std::move(ssm)) {}
-
- ThreadGuardedExecutor(ThreadGuardedExecutor&& other)
- : _isBusy{other._isBusy.load()}, _ssm(other._ssm), _guard(std::move(other._guard)) {}
-
- ~ThreadGuardedExecutor() {
- invariant(!_isBusy.load());
- invariant(!_guard.has_value());
- }
-
- static void set(Client* client, ThreadGuardedExecutor instance) {
- auto& clientThreadGuardedExecutor = getThreadGuardedExecutor(client);
- invariant(!clientThreadGuardedExecutor);
- clientThreadGuardedExecutor = std::make_shared<decltype(instance)>(std::move(instance));
- }
-
- static std::shared_ptr<ThreadGuardedExecutor> get(const Client* client) {
- return getThreadGuardedExecutor(client);
- }
-
- auto wrapExecutor(OutOfLineExecutor* executor) {
- return std::make_shared<WrappedExecutor>(shared_from_this(), executor);
- }
-
- void schedule(OutOfLineExecutor* executor, OutOfLineExecutor::Task task) {
- // Since `ThreadGuardedExecutor` is a client decoration, and SSM owns the client object,
- // ServiceStateMachine must be present when tasks are scheduled here.
- auto ssm = _ssm.lock();
- invariant(ssm);
-
- executor->schedule([this,
- executor, // Valid as the executor must be present to run the task
- task = std::move(task),
- ssm = std::move(ssm),
- anchor = shared_from_this()](Status status) mutable {
- if (auto wasBusy = _isBusy.swap(true); wasBusy) {
- // Reschedule if another executor thread is running a task for the client.
- LOGV2_DEBUG(4910704, kDiagnosticLogLevel, "Rescheduling thread-guarded task");
- schedule(executor, std::move(task));
- return;
- }
-
- invariant(!_guard.has_value());
- _guard.emplace(ssm.get());
- LOGV2_DEBUG(4910701, kDiagnosticLogLevel, "Acquired ThreadGuard in scheduled task");
-
- ON_BLOCK_EXIT([&] {
- if (_guard.has_value()) {
- releaseThreadGuard();
- }
- });
-
- LOGV2_DEBUG(
- 4910703, kDiagnosticLogLevel, "Started running task in a thread-guarded context");
- task(status);
- });
- }
-
- // Must only be called on the thread that owns the guard
- void releaseThreadGuard() {
- invariant(_guard.has_value() && _isBusy.load());
- LOGV2_DEBUG(4910702, kDiagnosticLogLevel, "Releasing the ThreadGuard");
- _guard = boost::none;
- _isBusy.store(false);
- }
-
-private:
- static constexpr auto kDiagnosticLogLevel = 3;
-
- // Set to `true` if the executor is busy running a task on behalf of the corresponding client.
- AtomicWord<bool> _isBusy{false};
-
- std::weak_ptr<ServiceStateMachine::Impl> _ssm;
- boost::optional<ThreadGuard> _guard;
-};
-
Future<void> ServiceStateMachine::Impl::sourceMessage() {
invariant(_inMessage.empty());
invariant(_state.load() == State::Source);
@@ -583,8 +365,6 @@ Future<void> ServiceStateMachine::Impl::sinkMessage() {
}
void ServiceStateMachine::Impl::sourceCallback(Status status) {
- auto guard = ThreadGuard(this);
-
invariant(state() == State::SourceWait);
auto remote = session()->remote();
@@ -626,8 +406,6 @@ void ServiceStateMachine::Impl::sourceCallback(Status status) {
}
void ServiceStateMachine::Impl::sinkCallback(Status status) {
- auto guard = ThreadGuard(this);
-
invariant(state() == State::SinkWait);
// If there was an error sinking the message to the client, then we should print an error and
@@ -736,25 +514,25 @@ Future<void> ServiceStateMachine::Impl::processMessage() {
void ServiceStateMachine::Impl::start(ServiceExecutorContext seCtx) {
{
- stdx::lock_guard lk(*_clientPtr);
-
- ServiceExecutorContext::set(_clientPtr, std::move(seCtx));
+ auto client = _clientStrand->getClientPointer();
+ stdx::lock_guard lk(*client);
+ ServiceExecutorContext::set(client, std::move(seCtx));
}
- // Set the executor decoration here to ensure `shared_from_this()` returns a valid pointer
- ThreadGuardedExecutor::set(_clientPtr, ThreadGuardedExecutor(shared_from_this()));
+ invariant(_state.swap(State::Source) == State::Created);
- ThreadGuardedExecutor::get(_clientPtr)
- ->schedule(
- executor(),
- GuaranteedExecutor::enforceRunOnce([this, anchor = shared_from_this()](Status status) {
- // If this is the first run of the SSM, then update its state to Source
- if (state() == State::Created) {
- _state.store(State::Source);
- }
+ auto cb = [this, anchor = shared_from_this()](Status status) {
+ _clientStrand->run([&] {
+ if (ErrorCodes::isCancelationError(status)) {
+ cleanupSession();
+ return;
+ }
+ invariant(status);
- runOnce();
- }));
+ runOnce();
+ });
+ };
+ executor()->schedule(std::move(cb));
}
void ServiceStateMachine::Impl::runOnce() {
@@ -791,19 +569,25 @@ void ServiceStateMachine::Impl::runOnce() {
"error"_attr = status);
terminate();
- ThreadGuardedExecutor::get(_clientPtr)
- ->schedule(executor(),
- GuaranteedExecutor::enforceRunOnce(
- [this, anchor = shared_from_this()](Status status) {
- cleanupSession();
- }));
+ auto cb = [this, anchor = shared_from_this()](Status status) {
+ _clientStrand->run([&] { cleanupSession(); });
+ };
+ executor()->schedule(std::move(cb));
return;
}
- ThreadGuardedExecutor::get(_clientPtr)
- ->schedule(executor(), GuaranteedExecutor::enforceRunOnce([this](Status status) {
- runOnce();
- }));
+ auto cb = [this, anchor = shared_from_this()](Status status) {
+ _clientStrand->run([&] {
+ if (ErrorCodes::isCancelationError(status)) {
+ cleanupSession();
+ return;
+ }
+ invariant(status);
+
+ runOnce();
+ });
+ };
+ executor()->schedule(std::move(cb));
});
}
@@ -878,8 +662,9 @@ void ServiceStateMachine::Impl::cleanupSession() {
cleanupExhaustResources();
{
- stdx::lock_guard lk(*_clientPtr);
- transport::ServiceExecutorContext::reset(_clientPtr);
+ auto client = _clientStrand->getClientPointer();
+ stdx::lock_guard lk(*client);
+ transport::ServiceExecutorContext::reset(client);
}
if (auto cleanupHook = std::exchange(_cleanupHook, {})) {