summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Li <alex.li@mongodb.com>2022-09-12 14:23:07 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2022-09-12 15:36:02 +0000
commit3ec1e222a0a6144d94888578e09c256b14bddc00 (patch)
tree39b58e12c158dd88c1913cbf4f45a2ba7e7b974c
parentec4c0fe845891168ab5c09b2ffc9c03eaa230418 (diff)
downloadmongo-3ec1e222a0a6144d94888578e09c256b14bddc00.tar.gz
SERVER-67829 Benchmark for ServiceStateMachine
-rw-r--r--src/mongo/transport/SConscript13
-rw-r--r--src/mongo/transport/mock_service_executor.h74
-rw-r--r--src/mongo/transport/service_entry_point_impl.cpp18
-rw-r--r--src/mongo/transport/service_entry_point_impl.h5
-rw-r--r--src/mongo/transport/service_executor.cpp4
-rw-r--r--src/mongo/transport/service_executor.h6
-rw-r--r--src/mongo/transport/session_workflow_bm.cpp310
-rw-r--r--src/mongo/transport/session_workflow_test.cpp87
-rw-r--r--src/mongo/transport/session_workflow_test_util.h141
9 files changed, 564 insertions, 94 deletions
diff --git a/src/mongo/transport/SConscript b/src/mongo/transport/SConscript
index 457ea23fabc..632e32347e0 100644
--- a/src/mongo/transport/SConscript
+++ b/src/mongo/transport/SConscript
@@ -246,3 +246,16 @@ tlEnvTest.CppIntegrationTest(
'transport_layer_egress_init',
],
)
+
+env.Benchmark(
+ target='session_workflow_bm',
+ source=[
+ 'session_workflow_bm.cpp',
+ ],
+ LIBDEPS=[
+ '$BUILD_DIR/mongo/db/service_context_test_fixture',
+ 'service_entry_point',
+ 'service_executor',
+ 'transport_layer_mock',
+ ],
+)
diff --git a/src/mongo/transport/mock_service_executor.h b/src/mongo/transport/mock_service_executor.h
new file mode 100644
index 00000000000..c6f514e3bb9
--- /dev/null
+++ b/src/mongo/transport/mock_service_executor.h
@@ -0,0 +1,74 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <functional>
+
+#include "mongo/transport/service_executor.h"
+#include "mongo/util/out_of_line_executor.h"
+
+namespace mongo::transport {
+
+class MockServiceExecutor : public ServiceExecutor {
+public:
+ Status start() override {
+ return startCb();
+ }
+
+ Status scheduleTask(Task task, ScheduleFlags flags) override {
+ return scheduleTaskCb(std::move(task), std::move(flags));
+ }
+
+ void runOnDataAvailable(const SessionHandle& session,
+ OutOfLineExecutor::Task onCompletionCallback) override {
+ runOnDataAvailableCb(session, std::move(onCompletionCallback));
+ }
+
+ Status shutdown(Milliseconds timeout) override {
+ return shutdownCb(std::move(timeout));
+ }
+
+ size_t getRunningThreads() const override {
+ return getRunningThreadsCb();
+ }
+
+ void appendStats(BSONObjBuilder* bob) const override {
+ appendStatsCb(bob);
+ }
+
+ std::function<Status()> startCb;
+ std::function<Status(Task, ScheduleFlags)> scheduleTaskCb;
+ std::function<void(const SessionHandle&, OutOfLineExecutor::Task)> runOnDataAvailableCb;
+ std::function<Status(Milliseconds)> shutdownCb;
+ std::function<size_t()> getRunningThreadsCb;
+ std::function<void(BSONObjBuilder*)> appendStatsCb;
+};
+
+} // namespace mongo::transport
diff --git a/src/mongo/transport/service_entry_point_impl.cpp b/src/mongo/transport/service_entry_point_impl.cpp
index be61885ee7a..35bdb872a12 100644
--- a/src/mongo/transport/service_entry_point_impl.cpp
+++ b/src/mongo/transport/service_entry_point_impl.cpp
@@ -281,6 +281,15 @@ Status ServiceEntryPointImpl::start() {
return Status::OK();
}
+void ServiceEntryPointImpl::configureServiceExecutorContext(ServiceContext::UniqueClient& client,
+ bool isPrivilegedSession) {
+ auto seCtx = std::make_unique<transport::ServiceExecutorContext>();
+ seCtx->setUseDedicatedThread(transport::gInitialUseDedicatedThread);
+ seCtx->setCanUseReserved(isPrivilegedSession);
+ stdx::lock_guard lk(*client);
+ transport::ServiceExecutorContext::set(&*client, std::move(seCtx));
+}
+
void ServiceEntryPointImpl::startSession(transport::SessionHandle session) {
invariant(session);
// Setup the restriction environment on the Session, if the Session has local/remote Sockaddrs
@@ -311,14 +320,7 @@ void ServiceEntryPointImpl::startSession(transport::SessionHandle session) {
return;
}
- // Imbue the new Client with a ServiceExecutorContext.
- {
- auto seCtx = std::make_unique<transport::ServiceExecutorContext>();
- seCtx->setUseDedicatedThread(transport::gInitialUseDedicatedThread);
- seCtx->setCanUseReserved(isPrivilegedSession);
- stdx::lock_guard lk(*client);
- transport::ServiceExecutorContext::set(&*client, std::move(seCtx));
- }
+ configureServiceExecutorContext(client, isPrivilegedSession);
workflow = transport::SessionWorkflow::make(std::move(client));
auto iter = sync.insert(workflow);
diff --git a/src/mongo/transport/service_entry_point_impl.h b/src/mongo/transport/service_entry_point_impl.h
index 91d5ea1e740..683ad266fa3 100644
--- a/src/mongo/transport/service_entry_point_impl.h
+++ b/src/mongo/transport/service_entry_point_impl.h
@@ -81,6 +81,11 @@ public:
/** `onClientDisconnect` calls this before doing anything else. */
virtual void derivedOnClientDisconnect(Client* client) {}
+protected:
+ /** Imbue the new Client with a ServiceExecutorContext. */
+ virtual void configureServiceExecutorContext(ServiceContext::UniqueClient& client,
+ bool isPrivilegedSession);
+
private:
class Sessions;
diff --git a/src/mongo/transport/service_executor.cpp b/src/mongo/transport/service_executor.cpp
index 747ea46e477..49a6ef4f67c 100644
--- a/src/mongo/transport/service_executor.cpp
+++ b/src/mongo/transport/service_executor.cpp
@@ -145,6 +145,10 @@ void ServiceExecutorContext::setCanUseReserved(bool canUseReserved) noexcept {
ServiceExecutor* ServiceExecutorContext::getServiceExecutor() noexcept {
invariant(_client);
+
+ if (_getServiceExecutorForTest)
+ return _getServiceExecutorForTest();
+
if (!_useDedicatedThread)
return ServiceExecutorFixed::get(_client->getServiceContext());
diff --git a/src/mongo/transport/service_executor.h b/src/mongo/transport/service_executor.h
index 458f152bd89..c8de15851eb 100644
--- a/src/mongo/transport/service_executor.h
+++ b/src/mongo/transport/service_executor.h
@@ -152,6 +152,9 @@ public:
static void reset(Client* client) noexcept;
ServiceExecutorContext() = default;
+ /** Test only */
+ explicit ServiceExecutorContext(std::function<ServiceExecutor*()> getServiceExecutorForTest)
+ : _getServiceExecutorForTest(getServiceExecutorForTest) {}
ServiceExecutorContext(const ServiceExecutorContext&) = delete;
ServiceExecutorContext& operator=(const ServiceExecutorContext&) = delete;
ServiceExecutorContext(ServiceExecutorContext&&) = delete;
@@ -195,6 +198,9 @@ private:
bool _useDedicatedThread = true;
bool _canUseReserved = false;
bool _hasUsedSynchronous = false;
+
+ /** For tests to override the behavior of `getServiceExecutor()`. */
+ std::function<ServiceExecutor*()> _getServiceExecutorForTest;
};
/**
diff --git a/src/mongo/transport/session_workflow_bm.cpp b/src/mongo/transport/session_workflow_bm.cpp
new file mode 100644
index 00000000000..9fe69da9c51
--- /dev/null
+++ b/src/mongo/transport/session_workflow_bm.cpp
@@ -0,0 +1,310 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include <chrono>
+#include <memory>
+
+#include <benchmark/benchmark.h>
+
+#include "mongo/bson/bsonelement.h"
+#include "mongo/db/concurrency/locker_noop_client_observer.h"
+#include "mongo/db/dbmessage.h"
+#include "mongo/db/operation_context.h"
+#include "mongo/db/service_context.h"
+#include "mongo/rpc/op_msg.h"
+#include "mongo/transport/mock_service_executor.h"
+#include "mongo/transport/service_entry_point_impl.h"
+#include "mongo/transport/service_executor.h"
+#include "mongo/transport/session.h"
+#include "mongo/transport/session_workflow_test_util.h"
+#include "mongo/transport/transport_layer_mock.h"
+#include "mongo/util/assert_util_core.h"
+#include "mongo/util/out_of_line_executor.h"
+#include "mongo/util/processinfo.h"
+
+namespace mongo::transport {
+namespace {
+
+Status makeClosedSessionError() {
+ return Status{ErrorCodes::SocketException, "Session is closed"};
+}
+
+class NoopReactor : public Reactor {
+public:
+ void run() noexcept override {}
+ void stop() override {}
+
+ void runFor(Milliseconds time) noexcept override {
+ MONGO_UNREACHABLE;
+ }
+
+ void drain() override {
+ MONGO_UNREACHABLE;
+ }
+
+ void schedule(Task) override {
+ MONGO_UNREACHABLE;
+ }
+
+ void dispatch(Task) override {
+ MONGO_UNREACHABLE;
+ }
+
+ bool onReactorThread() const override {
+ MONGO_UNREACHABLE;
+ }
+
+ std::unique_ptr<ReactorTimer> makeTimer() override {
+ MONGO_UNREACHABLE;
+ }
+
+ Date_t now() override {
+ MONGO_UNREACHABLE;
+ }
+
+ void appendStats(BSONObjBuilder&) const {
+ MONGO_UNREACHABLE;
+ }
+};
+
+class TransportLayerMockWithReactor : public TransportLayerMock {
+public:
+ ReactorHandle getReactor(WhichReactor) override {
+ return _mockReactor;
+ }
+
+private:
+ ReactorHandle _mockReactor = std::make_unique<NoopReactor>();
+};
+
+Message makeMessageWithBenchmarkRunNumber(int runNumber) {
+ OpMsgBuilder builder;
+ builder.setBody(BSON("ping" << 1 << "benchmarkRunNumber" << runNumber));
+ Message request = builder.finish();
+ OpMsg::setFlag(&request, OpMsg::kExhaustSupported);
+ return request;
+}
+
+std::shared_ptr<CallbackMockSession> makeSession(Message message) {
+ auto session = std::make_shared<CallbackMockSession>();
+ session = std::make_shared<CallbackMockSession>();
+ session->endCb = [] {};
+ session->waitForDataCb = [&] { return Status::OK(); };
+ session->sourceMessageCb = [&, message] { return StatusWith<Message>(message); };
+ session->sinkMessageCb = [&](Message message) {
+ if (OpMsg::parse(message).body["benchmarkRunNumber"].numberInt() > 0)
+ return Status::OK();
+ return makeClosedSessionError();
+ };
+ session->asyncWaitForDataCb = [&] { return Future<void>::makeReady(); };
+ return session;
+}
+
+class SessionWorkflowFixture : public benchmark::Fixture {
+public:
+ Future<DbResponse> handleRequest(const Message& request) {
+ DbResponse response;
+ response.response = request;
+
+ BSONObj obj = OpMsg::parse(request).body;
+
+ // Check "benchmarkRunNumber" field for how many times to run in exhaust
+ if (obj["benchmarkRunNumber"].numberInt() > 0) {
+ BSONObjBuilder bsonBuilder;
+ bsonBuilder.append(obj.firstElement());
+ bsonBuilder.append("benchmarkRunNumber", obj["benchmarkRunNumber"].numberInt() - 1);
+ BSONObj newObj = bsonBuilder.obj();
+
+ response.response = request;
+ response.nextInvocation = newObj;
+ response.shouldRunAgainForExhaust = true;
+ }
+
+ return Future<DbResponse>::makeReady(StatusWith<DbResponse>(response));
+ }
+
+ void commonSetUp() {
+ serviceCtx = [] {
+ auto serviceContext = ServiceContext::make();
+ auto serviceContextPtr = serviceContext.get();
+ setGlobalServiceContext(std::move(serviceContext));
+ return serviceContextPtr;
+ }();
+ invariant(serviceCtx);
+ serviceCtx->registerClientObserver(std::make_unique<LockerNoopClientObserver>());
+
+ auto uniqueSep = std::make_unique<MockServiceEntryPoint>(serviceCtx);
+ uniqueSep->handleRequestCb = [&](OperationContext*, const Message& request) {
+ return handleRequest(request);
+ };
+ uniqueSep->onEndSessionCb = [&](const SessionHandle&) {};
+ uniqueSep->derivedOnClientDisconnectCb = [&](Client*) {};
+ sep = uniqueSep.get();
+
+ serviceCtx->setServiceEntryPoint(std::move(uniqueSep));
+ serviceCtx->setTransportLayer(std::make_unique<TransportLayerMockWithReactor>());
+ }
+
+ void SetUp(benchmark::State& state) override {
+ // Call SetUp on only one thread
+ if (state.thread_index != 0)
+ return;
+ commonSetUp();
+ makeSessionCb = makeSession;
+ invariant(sep->start());
+ }
+
+ void TearDown(benchmark::State& state) override {
+ if (state.thread_index != 0)
+ return;
+ invariant(sep->shutdownAndWait(Seconds{10}));
+ setGlobalServiceContext({});
+ }
+
+ void benchmarkScheduleNewLoop(benchmark::State& state) {
+ int64_t totalExhaustRounds = state.range(0);
+
+ for (auto _ : state) {
+ auto session = makeSessionCb(makeMessageWithBenchmarkRunNumber(totalExhaustRounds));
+ sep->startSession(std::move(session));
+
+ invariant(sep->waitForNoSessions(Seconds{1}));
+ }
+ }
+
+ ServiceContext* serviceCtx;
+ MockServiceEntryPoint* sep;
+ std::function<std::shared_ptr<CallbackMockSession>(Message)> makeSessionCb;
+};
+
+template <bool useDedicatedThread>
+class DedicatedThreadOverrideFixture : public SessionWorkflowFixture {
+ ScopedValueOverride<bool> _svo{gInitialUseDedicatedThread, useDedicatedThread};
+};
+
+using SessionWorkflowWithBorrowedThreads = DedicatedThreadOverrideFixture<false>;
+using SessionWorkflowWithDedicatedThreads = DedicatedThreadOverrideFixture<true>;
+
+enum class StageToStop {
+ kDefault,
+ kSource,
+ kProcess,
+ kSink,
+};
+
+class SingleThreadSessionWorkflow : public SessionWorkflowFixture {
+public:
+ void initializeMockExecutor(benchmark::State& state) {
+ serviceExecutor = std::make_unique<MockServiceExecutor>();
+ serviceExecutor->runOnDataAvailableCb = [&](const SessionHandle& session,
+ OutOfLineExecutor::Task callback) {
+ serviceExecutor->schedule(
+ [callback = std::move(callback)](Status status) { callback(status); });
+ };
+ serviceExecutor->getRunningThreadsCb = [&] { return 0; };
+ serviceExecutor->scheduleTaskCb = [&](ServiceExecutor::Task task,
+ MockServiceExecutor::ScheduleFlags) {
+ task();
+ return Status::OK();
+ };
+ }
+
+ void SetUp(benchmark::State& state) override {
+ invariant(state.threads == 1, "Environment must be single threaded");
+ auto stopAt = static_cast<StageToStop>(state.range(1));
+
+ commonSetUp();
+
+ // Configure SEP to use mock executor
+ sep->configureServiceExecutorContextCb = [&](ServiceContext::UniqueClient& client, bool) {
+ auto seCtx =
+ std::make_unique<ServiceExecutorContext>([&] { return serviceExecutor.get(); });
+ stdx::lock_guard lk(*client);
+ ServiceExecutorContext::set(&*client, std::move(seCtx));
+ };
+ initializeMockExecutor(state);
+
+ // Change callbacks so that the benchmark stops at the right stage
+ if (stopAt == StageToStop::kProcess)
+ sep->handleRequestCb = [&, stopAt](OperationContext* opCtx, const Message& request) {
+ if (OpMsg::parse(request).body["benchmarkRunNumber"].numberInt() > 0)
+ return handleRequest(request);
+ return Future<DbResponse>::makeReady(makeClosedSessionError());
+ };
+ makeSessionCb = [stopAt](Message message) {
+ auto session = makeSession(message);
+ if (stopAt == StageToStop::kSource)
+ session->sourceMessageCb = [&, message] {
+ return StatusWith<Message>(makeClosedSessionError());
+ };
+ return session;
+ };
+ }
+
+ std::unique_ptr<MockServiceExecutor> serviceExecutor;
+};
+
+const int64_t benchmarkThreadMax = ProcessInfo::getNumAvailableCores() * 2;
+
+BENCHMARK_DEFINE_F(SessionWorkflowWithDedicatedThreads, MultiThreadScheduleNewLoop)
+(benchmark::State& state) {
+ benchmarkScheduleNewLoop(state);
+}
+BENCHMARK_REGISTER_F(SessionWorkflowWithDedicatedThreads, MultiThreadScheduleNewLoop)
+ ->ArgNames({"Exhaust"})
+ ->Arg(0)
+ ->Arg(1)
+ ->ThreadRange(1, benchmarkThreadMax);
+
+BENCHMARK_DEFINE_F(SessionWorkflowWithBorrowedThreads, MultiThreadScheduleNewLoop)
+(benchmark::State& state) {
+ benchmarkScheduleNewLoop(state);
+}
+BENCHMARK_REGISTER_F(SessionWorkflowWithBorrowedThreads, MultiThreadScheduleNewLoop)
+ ->ArgNames({"Exhaust"})
+ ->Arg(0)
+ ->Arg(1)
+ ->ThreadRange(1, benchmarkThreadMax);
+
+template <typename... E>
+auto enumToArgs(E... e) {
+ return std::vector<int64_t>{static_cast<int64_t>(e)...};
+}
+
+BENCHMARK_DEFINE_F(SingleThreadSessionWorkflow, SingleThreadScheduleNewLoop)
+(benchmark::State& state) {
+ benchmarkScheduleNewLoop(state);
+}
+BENCHMARK_REGISTER_F(SingleThreadSessionWorkflow, SingleThreadScheduleNewLoop)
+ ->ArgNames({"Exhaust", "Stage to stop"})
+ ->ArgsProduct({{0},
+ enumToArgs(StageToStop::kSource, StageToStop::kProcess, StageToStop::kSink)});
+
+} // namespace
+} // namespace mongo::transport
diff --git a/src/mongo/transport/session_workflow_test.cpp b/src/mongo/transport/session_workflow_test.cpp
index f85b3e4d120..df264e372a2 100644
--- a/src/mongo/transport/session_workflow_test.cpp
+++ b/src/mongo/transport/session_workflow_test.cpp
@@ -54,6 +54,7 @@
#include "mongo/transport/service_executor.h"
#include "mongo/transport/service_executor_utils.h"
#include "mongo/transport/session_workflow.h"
+#include "mongo/transport/session_workflow_test_util.h"
#include "mongo/transport/transport_layer_mock.h"
#include "mongo/unittest/unittest.h"
#include "mongo/util/assert_util.h"
@@ -69,21 +70,6 @@ namespace mongo {
namespace transport {
namespace {
-/** Scope guard to set and restore an object value. */
-template <typename T>
-class ScopedValueOverride {
-public:
- ScopedValueOverride(T& target, T v)
- : _target{target}, _saved{std::exchange(_target, std::move(v))} {}
- ~ScopedValueOverride() {
- _target = std::move(_saved);
- }
-
-private:
- T& _target;
- T _saved;
-};
-
const Status kClosedSessionError{ErrorCodes::SocketException, "Session is closed"};
const Status kNetworkError{ErrorCodes::HostUnreachable, "Someone is unreachable"};
const Status kShutdownError{ErrorCodes::ShutdownInProgress, "Something is shutting down"};
@@ -259,77 +245,6 @@ struct StateResult {
SessionState state;
};
-class CallbackMockSession : public MockSessionBase {
-public:
- TransportLayer* getTransportLayer() const override {
- return getTransportLayerCb();
- }
-
- void end() override {
- endCb();
- }
-
- bool isConnected() override {
- return isConnectedCb();
- }
-
- Status waitForData() noexcept override {
- return waitForDataCb();
- }
-
- StatusWith<Message> sourceMessage() noexcept override {
- return sourceMessageCb();
- }
-
- Status sinkMessage(Message message) noexcept override {
- return sinkMessageCb(message);
- }
-
- Future<void> asyncWaitForData() noexcept override {
- return asyncWaitForDataCb();
- }
-
- Future<Message> asyncSourceMessage(const BatonHandle& handle) noexcept override {
- return asyncSourceMessageCb(handle);
- }
-
- Future<void> asyncSinkMessage(Message message, const BatonHandle& handle) noexcept override {
- return asyncSinkMessageCb(message, handle);
- }
-
- std::function<TransportLayer*(void)> getTransportLayerCb;
- std::function<void(void)> endCb;
- std::function<bool(void)> isConnectedCb;
- std::function<Status(void)> waitForDataCb;
- std::function<StatusWith<Message>(void)> sourceMessageCb;
- std::function<Status(Message)> sinkMessageCb;
- std::function<Future<void>(void)> asyncWaitForDataCb;
- std::function<Future<Message>(const BatonHandle&)> asyncSourceMessageCb;
- std::function<Future<void>(Message, const BatonHandle&)> asyncSinkMessageCb;
-};
-
-class MockServiceEntryPoint : public ServiceEntryPointImpl {
-public:
- explicit MockServiceEntryPoint(ServiceContext* svcCtx) : ServiceEntryPointImpl(svcCtx) {}
-
- Future<DbResponse> handleRequest(OperationContext* opCtx,
- const Message& request) noexcept override {
- return handleRequestCb(opCtx, request);
- }
-
- void onEndSession(const transport::SessionHandle& handle) override {
- onEndSessionCb(handle);
- }
-
- void derivedOnClientDisconnect(Client* client) override {
- derivedOnClientDisconnectCb(client);
- }
-
- std::function<Future<DbResponse>(OperationContext*, const Message&)> handleRequestCb;
- std::function<void(const transport::SessionHandle)> onEndSessionCb;
- std::function<void(Client*)> derivedOnClientDisconnectCb;
-};
-
/**
* The SessionWorkflowTest is a fixture that mocks the external inputs into the
* SessionWorkflow so as to provide a deterministic way to evaluate the session workflow
diff --git a/src/mongo/transport/session_workflow_test_util.h b/src/mongo/transport/session_workflow_test_util.h
new file mode 100644
index 00000000000..552ad9cecce
--- /dev/null
+++ b/src/mongo/transport/session_workflow_test_util.h
@@ -0,0 +1,141 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+
+#include <functional>
+
+#include "mongo/base/status.h"
+#include "mongo/db/dbmessage.h"
+#include "mongo/transport/mock_session.h"
+#include "mongo/transport/service_entry_point_impl.h"
+#include "mongo/transport/transport_layer.h"
+#include "mongo/transport/transport_layer_mock.h"
+
+namespace mongo {
+namespace transport {
+
+using ReactorHandle = std::shared_ptr<Reactor>;
+
+/** Scope guard to set and restore an object value. */
+template <typename T>
+class ScopedValueOverride {
+public:
+ ScopedValueOverride(T& target, T v)
+ : _target{target}, _saved{std::exchange(_target, std::move(v))} {}
+ ~ScopedValueOverride() {
+ _target = std::move(_saved);
+ }
+
+private:
+ T& _target;
+ T _saved;
+};
+
+class CallbackMockSession : public MockSessionBase {
+public:
+ TransportLayer* getTransportLayer() const override {
+ return getTransportLayerCb();
+ }
+
+ void end() override {
+ endCb();
+ }
+
+ bool isConnected() override {
+ return isConnectedCb();
+ }
+
+ Status waitForData() noexcept override {
+ return waitForDataCb();
+ }
+
+ StatusWith<Message> sourceMessage() noexcept override {
+ return sourceMessageCb();
+ }
+
+ Status sinkMessage(Message message) noexcept override {
+ return sinkMessageCb(std::move(message));
+ }
+
+ Future<void> asyncWaitForData() noexcept override {
+ return asyncWaitForDataCb();
+ }
+
+ Future<Message> asyncSourceMessage(const BatonHandle& handle) noexcept override {
+ return asyncSourceMessageCb(handle);
+ }
+
+ Future<void> asyncSinkMessage(Message message, const BatonHandle& handle) noexcept override {
+ return asyncSinkMessageCb(std::move(message), handle);
+ }
+
+ std::function<TransportLayer*()> getTransportLayerCb;
+ std::function<void()> endCb;
+ std::function<bool()> isConnectedCb;
+ std::function<Status()> waitForDataCb;
+ std::function<StatusWith<Message>()> sourceMessageCb;
+ std::function<Status(Message)> sinkMessageCb;
+ std::function<Future<void>()> asyncWaitForDataCb;
+ std::function<Future<Message>(const BatonHandle&)> asyncSourceMessageCb;
+ std::function<Future<void>(Message, const BatonHandle&)> asyncSinkMessageCb;
+};
+
+class MockServiceEntryPoint : public ServiceEntryPointImpl {
+public:
+ explicit MockServiceEntryPoint(ServiceContext* svcCtx) : ServiceEntryPointImpl(svcCtx) {}
+
+ Future<DbResponse> handleRequest(OperationContext* opCtx,
+ const Message& request) noexcept override {
+ return handleRequestCb(opCtx, request);
+ }
+
+ void onEndSession(const SessionHandle& handle) override {
+ onEndSessionCb(handle);
+ }
+
+ void derivedOnClientDisconnect(Client* client) override {
+ derivedOnClientDisconnectCb(client);
+ }
+
+ void configureServiceExecutorContext(ServiceContext::UniqueClient& client,
+ bool isPrivilegedSession) override {
+ if (configureServiceExecutorContextCb)
+ configureServiceExecutorContextCb(client, isPrivilegedSession);
+ else
+ ServiceEntryPointImpl::configureServiceExecutorContext(client, isPrivilegedSession);
+ }
+
+ std::function<Future<DbResponse>(OperationContext*, const Message&)> handleRequestCb;
+ std::function<void(const SessionHandle)> onEndSessionCb;
+ std::function<void(Client*)> derivedOnClientDisconnectCb;
+ std::function<void(ServiceContext::UniqueClient&, bool)> configureServiceExecutorContextCb;
+};
+
+} // namespace transport
+} // namespace mongo