diff options
author | Alex Li <alex.li@mongodb.com> | 2022-09-12 14:23:07 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2022-09-12 15:36:02 +0000 |
commit | 3ec1e222a0a6144d94888578e09c256b14bddc00 (patch) | |
tree | 39b58e12c158dd88c1913cbf4f45a2ba7e7b974c | |
parent | ec4c0fe845891168ab5c09b2ffc9c03eaa230418 (diff) | |
download | mongo-3ec1e222a0a6144d94888578e09c256b14bddc00.tar.gz |
SERVER-67829 Benchmark for ServiceStateMachine
-rw-r--r-- | src/mongo/transport/SConscript | 13 | ||||
-rw-r--r-- | src/mongo/transport/mock_service_executor.h | 74 | ||||
-rw-r--r-- | src/mongo/transport/service_entry_point_impl.cpp | 18 | ||||
-rw-r--r-- | src/mongo/transport/service_entry_point_impl.h | 5 | ||||
-rw-r--r-- | src/mongo/transport/service_executor.cpp | 4 | ||||
-rw-r--r-- | src/mongo/transport/service_executor.h | 6 | ||||
-rw-r--r-- | src/mongo/transport/session_workflow_bm.cpp | 310 | ||||
-rw-r--r-- | src/mongo/transport/session_workflow_test.cpp | 87 | ||||
-rw-r--r-- | src/mongo/transport/session_workflow_test_util.h | 141 |
9 files changed, 564 insertions, 94 deletions
diff --git a/src/mongo/transport/SConscript b/src/mongo/transport/SConscript index 457ea23fabc..632e32347e0 100644 --- a/src/mongo/transport/SConscript +++ b/src/mongo/transport/SConscript @@ -246,3 +246,16 @@ tlEnvTest.CppIntegrationTest( 'transport_layer_egress_init', ], ) + +env.Benchmark( + target='session_workflow_bm', + source=[ + 'session_workflow_bm.cpp', + ], + LIBDEPS=[ + '$BUILD_DIR/mongo/db/service_context_test_fixture', + 'service_entry_point', + 'service_executor', + 'transport_layer_mock', + ], +) diff --git a/src/mongo/transport/mock_service_executor.h b/src/mongo/transport/mock_service_executor.h new file mode 100644 index 00000000000..c6f514e3bb9 --- /dev/null +++ b/src/mongo/transport/mock_service_executor.h @@ -0,0 +1,74 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <functional> + +#include "mongo/transport/service_executor.h" +#include "mongo/util/out_of_line_executor.h" + +namespace mongo::transport { + +class MockServiceExecutor : public ServiceExecutor { +public: + Status start() override { + return startCb(); + } + + Status scheduleTask(Task task, ScheduleFlags flags) override { + return scheduleTaskCb(std::move(task), std::move(flags)); + } + + void runOnDataAvailable(const SessionHandle& session, + OutOfLineExecutor::Task onCompletionCallback) override { + runOnDataAvailableCb(session, std::move(onCompletionCallback)); + } + + Status shutdown(Milliseconds timeout) override { + return shutdownCb(std::move(timeout)); + } + + size_t getRunningThreads() const override { + return getRunningThreadsCb(); + } + + void appendStats(BSONObjBuilder* bob) const override { + appendStatsCb(bob); + } + + std::function<Status()> startCb; + std::function<Status(Task, ScheduleFlags)> scheduleTaskCb; + std::function<void(const SessionHandle&, OutOfLineExecutor::Task)> runOnDataAvailableCb; + std::function<Status(Milliseconds)> shutdownCb; + std::function<size_t()> getRunningThreadsCb; + std::function<void(BSONObjBuilder*)> appendStatsCb; +}; + +} // namespace mongo::transport diff --git a/src/mongo/transport/service_entry_point_impl.cpp b/src/mongo/transport/service_entry_point_impl.cpp index be61885ee7a..35bdb872a12 100644 --- a/src/mongo/transport/service_entry_point_impl.cpp +++ b/src/mongo/transport/service_entry_point_impl.cpp @@ -281,6 +281,15 @@ Status ServiceEntryPointImpl::start() { return Status::OK(); } +void ServiceEntryPointImpl::configureServiceExecutorContext(ServiceContext::UniqueClient& client, + bool isPrivilegedSession) { + auto seCtx = std::make_unique<transport::ServiceExecutorContext>(); + seCtx->setUseDedicatedThread(transport::gInitialUseDedicatedThread); + seCtx->setCanUseReserved(isPrivilegedSession); + stdx::lock_guard lk(*client); + transport::ServiceExecutorContext::set(&*client, std::move(seCtx)); +} + void ServiceEntryPointImpl::startSession(transport::SessionHandle session) { invariant(session); // Setup the restriction environment on the Session, if the Session has local/remote Sockaddrs @@ -311,14 +320,7 @@ void ServiceEntryPointImpl::startSession(transport::SessionHandle session) { return; } - // Imbue the new Client with a ServiceExecutorContext. - { - auto seCtx = std::make_unique<transport::ServiceExecutorContext>(); - seCtx->setUseDedicatedThread(transport::gInitialUseDedicatedThread); - seCtx->setCanUseReserved(isPrivilegedSession); - stdx::lock_guard lk(*client); - transport::ServiceExecutorContext::set(&*client, std::move(seCtx)); - } + configureServiceExecutorContext(client, isPrivilegedSession); workflow = transport::SessionWorkflow::make(std::move(client)); auto iter = sync.insert(workflow); diff --git a/src/mongo/transport/service_entry_point_impl.h b/src/mongo/transport/service_entry_point_impl.h index 91d5ea1e740..683ad266fa3 100644 --- a/src/mongo/transport/service_entry_point_impl.h +++ b/src/mongo/transport/service_entry_point_impl.h @@ -81,6 +81,11 @@ public: /** `onClientDisconnect` calls this before doing anything else. */ virtual void derivedOnClientDisconnect(Client* client) {} +protected: + /** Imbue the new Client with a ServiceExecutorContext. */ + virtual void configureServiceExecutorContext(ServiceContext::UniqueClient& client, + bool isPrivilegedSession); + private: class Sessions; diff --git a/src/mongo/transport/service_executor.cpp b/src/mongo/transport/service_executor.cpp index 747ea46e477..49a6ef4f67c 100644 --- a/src/mongo/transport/service_executor.cpp +++ b/src/mongo/transport/service_executor.cpp @@ -145,6 +145,10 @@ void ServiceExecutorContext::setCanUseReserved(bool canUseReserved) noexcept { ServiceExecutor* ServiceExecutorContext::getServiceExecutor() noexcept { invariant(_client); + + if (_getServiceExecutorForTest) + return _getServiceExecutorForTest(); + if (!_useDedicatedThread) return ServiceExecutorFixed::get(_client->getServiceContext()); diff --git a/src/mongo/transport/service_executor.h b/src/mongo/transport/service_executor.h index 458f152bd89..c8de15851eb 100644 --- a/src/mongo/transport/service_executor.h +++ b/src/mongo/transport/service_executor.h @@ -152,6 +152,9 @@ public: static void reset(Client* client) noexcept; ServiceExecutorContext() = default; + /** Test only */ + explicit ServiceExecutorContext(std::function<ServiceExecutor*()> getServiceExecutorForTest) + : _getServiceExecutorForTest(getServiceExecutorForTest) {} ServiceExecutorContext(const ServiceExecutorContext&) = delete; ServiceExecutorContext& operator=(const ServiceExecutorContext&) = delete; ServiceExecutorContext(ServiceExecutorContext&&) = delete; @@ -195,6 +198,9 @@ private: bool _useDedicatedThread = true; bool _canUseReserved = false; bool _hasUsedSynchronous = false; + + /** For tests to override the behavior of `getServiceExecutor()`. */ + std::function<ServiceExecutor*()> _getServiceExecutorForTest; }; /** diff --git a/src/mongo/transport/session_workflow_bm.cpp b/src/mongo/transport/session_workflow_bm.cpp new file mode 100644 index 00000000000..9fe69da9c51 --- /dev/null +++ b/src/mongo/transport/session_workflow_bm.cpp @@ -0,0 +1,310 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include <chrono> +#include <memory> + +#include <benchmark/benchmark.h> + +#include "mongo/bson/bsonelement.h" +#include "mongo/db/concurrency/locker_noop_client_observer.h" +#include "mongo/db/dbmessage.h" +#include "mongo/db/operation_context.h" +#include "mongo/db/service_context.h" +#include "mongo/rpc/op_msg.h" +#include "mongo/transport/mock_service_executor.h" +#include "mongo/transport/service_entry_point_impl.h" +#include "mongo/transport/service_executor.h" +#include "mongo/transport/session.h" +#include "mongo/transport/session_workflow_test_util.h" +#include "mongo/transport/transport_layer_mock.h" +#include "mongo/util/assert_util_core.h" +#include "mongo/util/out_of_line_executor.h" +#include "mongo/util/processinfo.h" + +namespace mongo::transport { +namespace { + +Status makeClosedSessionError() { + return Status{ErrorCodes::SocketException, "Session is closed"}; +} + +class NoopReactor : public Reactor { +public: + void run() noexcept override {} + void stop() override {} + + void runFor(Milliseconds time) noexcept override { + MONGO_UNREACHABLE; + } + + void drain() override { + MONGO_UNREACHABLE; + } + + void schedule(Task) override { + MONGO_UNREACHABLE; + } + + void dispatch(Task) override { + MONGO_UNREACHABLE; + } + + bool onReactorThread() const override { + MONGO_UNREACHABLE; + } + + std::unique_ptr<ReactorTimer> makeTimer() override { + MONGO_UNREACHABLE; + } + + Date_t now() override { + MONGO_UNREACHABLE; + } + + void appendStats(BSONObjBuilder&) const { + MONGO_UNREACHABLE; + } +}; + +class TransportLayerMockWithReactor : public TransportLayerMock { +public: + ReactorHandle getReactor(WhichReactor) override { + return _mockReactor; + } + +private: + ReactorHandle _mockReactor = std::make_unique<NoopReactor>(); +}; + +Message makeMessageWithBenchmarkRunNumber(int runNumber) { + OpMsgBuilder builder; + builder.setBody(BSON("ping" << 1 << "benchmarkRunNumber" << runNumber)); + Message request = builder.finish(); + OpMsg::setFlag(&request, OpMsg::kExhaustSupported); + return request; +} + +std::shared_ptr<CallbackMockSession> makeSession(Message message) { + auto session = std::make_shared<CallbackMockSession>(); + session = std::make_shared<CallbackMockSession>(); + session->endCb = [] {}; + session->waitForDataCb = [&] { return Status::OK(); }; + session->sourceMessageCb = [&, message] { return StatusWith<Message>(message); }; + session->sinkMessageCb = [&](Message message) { + if (OpMsg::parse(message).body["benchmarkRunNumber"].numberInt() > 0) + return Status::OK(); + return makeClosedSessionError(); + }; + session->asyncWaitForDataCb = [&] { return Future<void>::makeReady(); }; + return session; +} + +class SessionWorkflowFixture : public benchmark::Fixture { +public: + Future<DbResponse> handleRequest(const Message& request) { + DbResponse response; + response.response = request; + + BSONObj obj = OpMsg::parse(request).body; + + // Check "benchmarkRunNumber" field for how many times to run in exhaust + if (obj["benchmarkRunNumber"].numberInt() > 0) { + BSONObjBuilder bsonBuilder; + bsonBuilder.append(obj.firstElement()); + bsonBuilder.append("benchmarkRunNumber", obj["benchmarkRunNumber"].numberInt() - 1); + BSONObj newObj = bsonBuilder.obj(); + + response.response = request; + response.nextInvocation = newObj; + response.shouldRunAgainForExhaust = true; + } + + return Future<DbResponse>::makeReady(StatusWith<DbResponse>(response)); + } + + void commonSetUp() { + serviceCtx = [] { + auto serviceContext = ServiceContext::make(); + auto serviceContextPtr = serviceContext.get(); + setGlobalServiceContext(std::move(serviceContext)); + return serviceContextPtr; + }(); + invariant(serviceCtx); + serviceCtx->registerClientObserver(std::make_unique<LockerNoopClientObserver>()); + + auto uniqueSep = std::make_unique<MockServiceEntryPoint>(serviceCtx); + uniqueSep->handleRequestCb = [&](OperationContext*, const Message& request) { + return handleRequest(request); + }; + uniqueSep->onEndSessionCb = [&](const SessionHandle&) {}; + uniqueSep->derivedOnClientDisconnectCb = [&](Client*) {}; + sep = uniqueSep.get(); + + serviceCtx->setServiceEntryPoint(std::move(uniqueSep)); + serviceCtx->setTransportLayer(std::make_unique<TransportLayerMockWithReactor>()); + } + + void SetUp(benchmark::State& state) override { + // Call SetUp on only one thread + if (state.thread_index != 0) + return; + commonSetUp(); + makeSessionCb = makeSession; + invariant(sep->start()); + } + + void TearDown(benchmark::State& state) override { + if (state.thread_index != 0) + return; + invariant(sep->shutdownAndWait(Seconds{10})); + setGlobalServiceContext({}); + } + + void benchmarkScheduleNewLoop(benchmark::State& state) { + int64_t totalExhaustRounds = state.range(0); + + for (auto _ : state) { + auto session = makeSessionCb(makeMessageWithBenchmarkRunNumber(totalExhaustRounds)); + sep->startSession(std::move(session)); + + invariant(sep->waitForNoSessions(Seconds{1})); + } + } + + ServiceContext* serviceCtx; + MockServiceEntryPoint* sep; + std::function<std::shared_ptr<CallbackMockSession>(Message)> makeSessionCb; +}; + +template <bool useDedicatedThread> +class DedicatedThreadOverrideFixture : public SessionWorkflowFixture { + ScopedValueOverride<bool> _svo{gInitialUseDedicatedThread, useDedicatedThread}; +}; + +using SessionWorkflowWithBorrowedThreads = DedicatedThreadOverrideFixture<false>; +using SessionWorkflowWithDedicatedThreads = DedicatedThreadOverrideFixture<true>; + +enum class StageToStop { + kDefault, + kSource, + kProcess, + kSink, +}; + +class SingleThreadSessionWorkflow : public SessionWorkflowFixture { +public: + void initializeMockExecutor(benchmark::State& state) { + serviceExecutor = std::make_unique<MockServiceExecutor>(); + serviceExecutor->runOnDataAvailableCb = [&](const SessionHandle& session, + OutOfLineExecutor::Task callback) { + serviceExecutor->schedule( + [callback = std::move(callback)](Status status) { callback(status); }); + }; + serviceExecutor->getRunningThreadsCb = [&] { return 0; }; + serviceExecutor->scheduleTaskCb = [&](ServiceExecutor::Task task, + MockServiceExecutor::ScheduleFlags) { + task(); + return Status::OK(); + }; + } + + void SetUp(benchmark::State& state) override { + invariant(state.threads == 1, "Environment must be single threaded"); + auto stopAt = static_cast<StageToStop>(state.range(1)); + + commonSetUp(); + + // Configure SEP to use mock executor + sep->configureServiceExecutorContextCb = [&](ServiceContext::UniqueClient& client, bool) { + auto seCtx = + std::make_unique<ServiceExecutorContext>([&] { return serviceExecutor.get(); }); + stdx::lock_guard lk(*client); + ServiceExecutorContext::set(&*client, std::move(seCtx)); + }; + initializeMockExecutor(state); + + // Change callbacks so that the benchmark stops at the right stage + if (stopAt == StageToStop::kProcess) + sep->handleRequestCb = [&, stopAt](OperationContext* opCtx, const Message& request) { + if (OpMsg::parse(request).body["benchmarkRunNumber"].numberInt() > 0) + return handleRequest(request); + return Future<DbResponse>::makeReady(makeClosedSessionError()); + }; + makeSessionCb = [stopAt](Message message) { + auto session = makeSession(message); + if (stopAt == StageToStop::kSource) + session->sourceMessageCb = [&, message] { + return StatusWith<Message>(makeClosedSessionError()); + }; + return session; + }; + } + + std::unique_ptr<MockServiceExecutor> serviceExecutor; +}; + +const int64_t benchmarkThreadMax = ProcessInfo::getNumAvailableCores() * 2; + +BENCHMARK_DEFINE_F(SessionWorkflowWithDedicatedThreads, MultiThreadScheduleNewLoop) +(benchmark::State& state) { + benchmarkScheduleNewLoop(state); +} +BENCHMARK_REGISTER_F(SessionWorkflowWithDedicatedThreads, MultiThreadScheduleNewLoop) + ->ArgNames({"Exhaust"}) + ->Arg(0) + ->Arg(1) + ->ThreadRange(1, benchmarkThreadMax); + +BENCHMARK_DEFINE_F(SessionWorkflowWithBorrowedThreads, MultiThreadScheduleNewLoop) +(benchmark::State& state) { + benchmarkScheduleNewLoop(state); +} +BENCHMARK_REGISTER_F(SessionWorkflowWithBorrowedThreads, MultiThreadScheduleNewLoop) + ->ArgNames({"Exhaust"}) + ->Arg(0) + ->Arg(1) + ->ThreadRange(1, benchmarkThreadMax); + +template <typename... E> +auto enumToArgs(E... e) { + return std::vector<int64_t>{static_cast<int64_t>(e)...}; +} + +BENCHMARK_DEFINE_F(SingleThreadSessionWorkflow, SingleThreadScheduleNewLoop) +(benchmark::State& state) { + benchmarkScheduleNewLoop(state); +} +BENCHMARK_REGISTER_F(SingleThreadSessionWorkflow, SingleThreadScheduleNewLoop) + ->ArgNames({"Exhaust", "Stage to stop"}) + ->ArgsProduct({{0}, + enumToArgs(StageToStop::kSource, StageToStop::kProcess, StageToStop::kSink)}); + +} // namespace +} // namespace mongo::transport diff --git a/src/mongo/transport/session_workflow_test.cpp b/src/mongo/transport/session_workflow_test.cpp index f85b3e4d120..df264e372a2 100644 --- a/src/mongo/transport/session_workflow_test.cpp +++ b/src/mongo/transport/session_workflow_test.cpp @@ -54,6 +54,7 @@ #include "mongo/transport/service_executor.h" #include "mongo/transport/service_executor_utils.h" #include "mongo/transport/session_workflow.h" +#include "mongo/transport/session_workflow_test_util.h" #include "mongo/transport/transport_layer_mock.h" #include "mongo/unittest/unittest.h" #include "mongo/util/assert_util.h" @@ -69,21 +70,6 @@ namespace mongo { namespace transport { namespace { -/** Scope guard to set and restore an object value. */ -template <typename T> -class ScopedValueOverride { -public: - ScopedValueOverride(T& target, T v) - : _target{target}, _saved{std::exchange(_target, std::move(v))} {} - ~ScopedValueOverride() { - _target = std::move(_saved); - } - -private: - T& _target; - T _saved; -}; - const Status kClosedSessionError{ErrorCodes::SocketException, "Session is closed"}; const Status kNetworkError{ErrorCodes::HostUnreachable, "Someone is unreachable"}; const Status kShutdownError{ErrorCodes::ShutdownInProgress, "Something is shutting down"}; @@ -259,77 +245,6 @@ struct StateResult { SessionState state; }; -class CallbackMockSession : public MockSessionBase { -public: - TransportLayer* getTransportLayer() const override { - return getTransportLayerCb(); - } - - void end() override { - endCb(); - } - - bool isConnected() override { - return isConnectedCb(); - } - - Status waitForData() noexcept override { - return waitForDataCb(); - } - - StatusWith<Message> sourceMessage() noexcept override { - return sourceMessageCb(); - } - - Status sinkMessage(Message message) noexcept override { - return sinkMessageCb(message); - } - - Future<void> asyncWaitForData() noexcept override { - return asyncWaitForDataCb(); - } - - Future<Message> asyncSourceMessage(const BatonHandle& handle) noexcept override { - return asyncSourceMessageCb(handle); - } - - Future<void> asyncSinkMessage(Message message, const BatonHandle& handle) noexcept override { - return asyncSinkMessageCb(message, handle); - } - - std::function<TransportLayer*(void)> getTransportLayerCb; - std::function<void(void)> endCb; - std::function<bool(void)> isConnectedCb; - std::function<Status(void)> waitForDataCb; - std::function<StatusWith<Message>(void)> sourceMessageCb; - std::function<Status(Message)> sinkMessageCb; - std::function<Future<void>(void)> asyncWaitForDataCb; - std::function<Future<Message>(const BatonHandle&)> asyncSourceMessageCb; - std::function<Future<void>(Message, const BatonHandle&)> asyncSinkMessageCb; -}; - -class MockServiceEntryPoint : public ServiceEntryPointImpl { -public: - explicit MockServiceEntryPoint(ServiceContext* svcCtx) : ServiceEntryPointImpl(svcCtx) {} - - Future<DbResponse> handleRequest(OperationContext* opCtx, - const Message& request) noexcept override { - return handleRequestCb(opCtx, request); - } - - void onEndSession(const transport::SessionHandle& handle) override { - onEndSessionCb(handle); - } - - void derivedOnClientDisconnect(Client* client) override { - derivedOnClientDisconnectCb(client); - } - - std::function<Future<DbResponse>(OperationContext*, const Message&)> handleRequestCb; - std::function<void(const transport::SessionHandle)> onEndSessionCb; - std::function<void(Client*)> derivedOnClientDisconnectCb; -}; - /** * The SessionWorkflowTest is a fixture that mocks the external inputs into the * SessionWorkflow so as to provide a deterministic way to evaluate the session workflow diff --git a/src/mongo/transport/session_workflow_test_util.h b/src/mongo/transport/session_workflow_test_util.h new file mode 100644 index 00000000000..552ad9cecce --- /dev/null +++ b/src/mongo/transport/session_workflow_test_util.h @@ -0,0 +1,141 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + + +#include <functional> + +#include "mongo/base/status.h" +#include "mongo/db/dbmessage.h" +#include "mongo/transport/mock_session.h" +#include "mongo/transport/service_entry_point_impl.h" +#include "mongo/transport/transport_layer.h" +#include "mongo/transport/transport_layer_mock.h" + +namespace mongo { +namespace transport { + +using ReactorHandle = std::shared_ptr<Reactor>; + +/** Scope guard to set and restore an object value. */ +template <typename T> +class ScopedValueOverride { +public: + ScopedValueOverride(T& target, T v) + : _target{target}, _saved{std::exchange(_target, std::move(v))} {} + ~ScopedValueOverride() { + _target = std::move(_saved); + } + +private: + T& _target; + T _saved; +}; + +class CallbackMockSession : public MockSessionBase { +public: + TransportLayer* getTransportLayer() const override { + return getTransportLayerCb(); + } + + void end() override { + endCb(); + } + + bool isConnected() override { + return isConnectedCb(); + } + + Status waitForData() noexcept override { + return waitForDataCb(); + } + + StatusWith<Message> sourceMessage() noexcept override { + return sourceMessageCb(); + } + + Status sinkMessage(Message message) noexcept override { + return sinkMessageCb(std::move(message)); + } + + Future<void> asyncWaitForData() noexcept override { + return asyncWaitForDataCb(); + } + + Future<Message> asyncSourceMessage(const BatonHandle& handle) noexcept override { + return asyncSourceMessageCb(handle); + } + + Future<void> asyncSinkMessage(Message message, const BatonHandle& handle) noexcept override { + return asyncSinkMessageCb(std::move(message), handle); + } + + std::function<TransportLayer*()> getTransportLayerCb; + std::function<void()> endCb; + std::function<bool()> isConnectedCb; + std::function<Status()> waitForDataCb; + std::function<StatusWith<Message>()> sourceMessageCb; + std::function<Status(Message)> sinkMessageCb; + std::function<Future<void>()> asyncWaitForDataCb; + std::function<Future<Message>(const BatonHandle&)> asyncSourceMessageCb; + std::function<Future<void>(Message, const BatonHandle&)> asyncSinkMessageCb; +}; + +class MockServiceEntryPoint : public ServiceEntryPointImpl { +public: + explicit MockServiceEntryPoint(ServiceContext* svcCtx) : ServiceEntryPointImpl(svcCtx) {} + + Future<DbResponse> handleRequest(OperationContext* opCtx, + const Message& request) noexcept override { + return handleRequestCb(opCtx, request); + } + + void onEndSession(const SessionHandle& handle) override { + onEndSessionCb(handle); + } + + void derivedOnClientDisconnect(Client* client) override { + derivedOnClientDisconnectCb(client); + } + + void configureServiceExecutorContext(ServiceContext::UniqueClient& client, + bool isPrivilegedSession) override { + if (configureServiceExecutorContextCb) + configureServiceExecutorContextCb(client, isPrivilegedSession); + else + ServiceEntryPointImpl::configureServiceExecutorContext(client, isPrivilegedSession); + } + + std::function<Future<DbResponse>(OperationContext*, const Message&)> handleRequestCb; + std::function<void(const SessionHandle)> onEndSessionCb; + std::function<void(Client*)> derivedOnClientDisconnectCb; + std::function<void(ServiceContext::UniqueClient&, bool)> configureServiceExecutorContextCb; +}; + +} // namespace transport +} // namespace mongo |