diff options
-rw-r--r-- | src/mongo/transport/session_workflow_test.cpp | 1231 |
1 files changed, 516 insertions, 715 deletions
diff --git a/src/mongo/transport/session_workflow_test.cpp b/src/mongo/transport/session_workflow_test.cpp index f85b3e4d120..9cbcc5fb2db 100644 --- a/src/mongo/transport/session_workflow_test.cpp +++ b/src/mongo/transport/session_workflow_test.cpp @@ -30,6 +30,9 @@ #include "mongo/platform/basic.h" +#include <array> +#include <deque> +#include <initializer_list> #include <memory> #include <type_traits> #include <vector> @@ -39,7 +42,6 @@ #include "mongo/bson/bsonobj.h" #include "mongo/bson/bsonobjbuilder.h" #include "mongo/db/client.h" -#include "mongo/db/client_strand.h" #include "mongo/db/concurrency/locker_noop_service_context_test_fixture.h" #include "mongo/db/dbmessage.h" #include "mongo/db/service_context.h" @@ -51,22 +53,16 @@ #include "mongo/transport/mock_session.h" #include "mongo/transport/service_entry_point.h" #include "mongo/transport/service_entry_point_impl.h" -#include "mongo/transport/service_executor.h" -#include "mongo/transport/service_executor_utils.h" #include "mongo/transport/session_workflow.h" #include "mongo/transport/transport_layer_mock.h" #include "mongo/unittest/unittest.h" #include "mongo/util/assert_util.h" -#include "mongo/util/clock_source_mock.h" #include "mongo/util/concurrency/thread_pool.h" -#include "mongo/util/producer_consumer_queue.h" -#include "mongo/util/tick_source_mock.h" #define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kTest -namespace mongo { -namespace transport { +namespace mongo::transport { namespace { /** Scope guard to set and restore an object value. */ @@ -84,181 +80,29 @@ private: T _saved; }; -const Status kClosedSessionError{ErrorCodes::SocketException, "Session is closed"}; -const Status kNetworkError{ErrorCodes::HostUnreachable, "Someone is unreachable"}; -const Status kShutdownError{ErrorCodes::ShutdownInProgress, "Something is shutting down"}; -const Status kArbitraryError{ErrorCodes::InternalError, "Something happened"}; - -/** - * FailureCondition represents a set of the ways any state in the SessionWorkflow can fail. - */ -enum class FailureCondition { - kNone, - kTerminate, // External termination via the ServiceEntryPoint. - kDisconnect, // Socket disconnection by peer. - kNetwork, // Unspecified network failure (ala host unreachable). - kShutdown, // System shutdown. - kArbitrary, // An arbitrary error that does not fall under the other conditions. -}; - -constexpr StringData toString(FailureCondition fail) { - switch (fail) { - case FailureCondition::kNone: - return "None"_sd; - case FailureCondition::kTerminate: - return "Terminate"_sd; - case FailureCondition::kDisconnect: - return "Disconnect"_sd; - case FailureCondition::kNetwork: - return "Network"_sd; - case FailureCondition::kShutdown: - return "Shutdown"_sd; - case FailureCondition::kArbitrary: - return "Arbitrary"_sd; - }; - - return "Unknown"_sd; -} - -std::ostream& operator<<(std::ostream& os, FailureCondition fail) { - return os << toString(fail); +template <typename T> +StringData findEnumName(const std::vector<std::pair<T, StringData>>& arr, T k) { + return std::find_if(arr.begin(), arr.end(), [&](auto&& e) { return e.first == k; })->second; } -/** - * SessionState represents the externally observable state of the SessionWorkflow. These - * states map relatively closely to the internals of the SessionWorkflow::Impl. That said, - * this enum represents the SessionWorkflowTest's external understanding of the internal - * state. - */ -enum class SessionState { - kStart, - kPoll, - kSource, - kProcess, - kSink, - kEnd, -}; - -constexpr StringData toString(SessionState state) { - switch (state) { - case SessionState::kStart: - return "Start"_sd; - case SessionState::kPoll: - return "Poll"_sd; - case SessionState::kSource: - return "Source"_sd; - case SessionState::kProcess: - return "Process"_sd; - case SessionState::kSink: - return "Sink"_sd; - case SessionState::kEnd: - return "End"_sd; - }; - - return "Unknown"_sd; +template <typename T> +static std::string typeName() { + return demangleName(typeid(T)); } - -std::ostream& operator<<(std::ostream& os, SessionState state) { - return os << toString(state); +template <typename T> +static std::ostream& stream(std::ostream& os, const T&) { + return os << "[{}]"_format(typeName<T>()); } - -/** - * RequestKind represents the type of operation of the SessionWorkflow. Depending on various - * message flags and conditions, the SessionWorkflow will transition between states - * differently. - */ -enum class RequestKind { - kDefault, - kExhaust, - kMoreToCome, -}; - -constexpr StringData toString(RequestKind kind) { - switch (kind) { - case RequestKind::kDefault: - return "Default"_sd; - case RequestKind::kExhaust: - return "Exhaust"_sd; - case RequestKind::kMoreToCome: - return "MoreToCome"_sd; - }; - - return "Unknown"_sd; +static std::ostream& stream(std::ostream& os, const Status& v) { + return os << v; } - -std::ostream& operator<<(std::ostream& os, RequestKind kind) { - return os << toString(kind); +template <typename T> +static std::ostream& stream(std::ostream& os, const StatusWith<T>& v) { + if (!v.isOK()) + return stream(os, v.getStatus()); + return stream(os, v.getValue()); } -class ResultValue { -public: - ResultValue() = default; - explicit ResultValue(Status status) : _value(std::move(status)) {} - explicit ResultValue(StatusWith<Message> message) : _value(std::move(message)) {} - explicit ResultValue(StatusWith<DbResponse> response) : _value(std::move(response)) {} - - void setResponse(StatusWith<DbResponse> response) { - _value = response; - } - - StatusWith<DbResponse> getResponse() const { - return _convertTo<StatusWith<DbResponse>>(); - } - - void setMessage(StatusWith<Message> message) { - _value = message; - } - - StatusWith<Message> getMessage() const { - return _convertTo<StatusWith<Message>>(); - } - - void setStatus(Status status) { - _value = status; - } - - Status getStatus() const { - return _convertTo<Status>(); - } - - bool empty() const { - return _value.index() == 0; - } - - explicit operator bool() const { - return !empty(); - } - -private: - template <typename Target> - Target _convertTo() const { - return stdx::visit( - [](auto alt) -> Target { - if constexpr (std::is_convertible<decltype(alt), Target>()) - return alt; - invariant(false, "ResultValue not convertible to target type"); - MONGO_COMPILER_UNREACHABLE; - }, - _value); - } - - stdx::variant<stdx::monostate, Status, StatusWith<Message>, StatusWith<DbResponse>> _value; -}; - -/** - * This class stores and synchronizes the shared state result between the test - * fixture and its various wrappers. - */ -struct StateResult { - Mutex mutex = MONGO_MAKE_LATCH("StateResult::_mutex"); - stdx::condition_variable cv; - - AtomicWord<bool> isConnected{true}; - - ResultValue result; - SessionState state; -}; - class CallbackMockSession : public MockSessionBase { public: TransportLayer* getTransportLayer() const override { @@ -297,13 +141,13 @@ public: return asyncSinkMessageCb(message, handle); } - std::function<TransportLayer*(void)> getTransportLayerCb; - std::function<void(void)> endCb; - std::function<bool(void)> isConnectedCb; - std::function<Status(void)> waitForDataCb; - std::function<StatusWith<Message>(void)> sourceMessageCb; + std::function<TransportLayer*()> getTransportLayerCb; + std::function<void()> endCb; + std::function<bool()> isConnectedCb; + std::function<Status()> waitForDataCb; + std::function<StatusWith<Message>()> sourceMessageCb; std::function<Status(Message)> sinkMessageCb; - std::function<Future<void>(void)> asyncWaitForDataCb; + std::function<Future<void>()> asyncWaitForDataCb; std::function<Future<Message>(const BatonHandle&)> asyncSourceMessageCb; std::function<Future<void>(Message, const BatonHandle&)> asyncSinkMessageCb; }; @@ -317,7 +161,7 @@ public: return handleRequestCb(opCtx, request); } - void onEndSession(const transport::SessionHandle& handle) override { + void onEndSession(const SessionHandle& handle) override { onEndSessionCb(handle); } @@ -326,390 +170,370 @@ public: } std::function<Future<DbResponse>(OperationContext*, const Message&)> handleRequestCb; - std::function<void(const transport::SessionHandle)> onEndSessionCb; + std::function<void(const SessionHandle)> onEndSessionCb; std::function<void(Client*)> derivedOnClientDisconnectCb; }; /** - * The SessionWorkflowTest is a fixture that mocks the external inputs into the - * SessionWorkflow so as to provide a deterministic way to evaluate the session workflow - * implemenation. + * Events generated by SessionWorkflow, mostly by virtual function calls to mock + * objects connected to SessionWorkflow. */ -class SessionWorkflowTest : public LockerNoopServiceContextTest { -public: - /** - * Make a generic thread pool to deliver external inputs out of line (mocking the network or - * database workers). - */ - static std::shared_ptr<ThreadPool> makeThreadPool() { - auto options = ThreadPool::Options{}; - options.poolName = "SessionWorkflowTest"; +enum class Event { + kStart, + kWaitForData, + kSource, + kProcess, + kSink, + kEnd, +}; - return std::make_shared<ThreadPool>(std::move(options)); - } +StringData toString(Event e) { + return findEnumName( + { + {Event::kStart, "Start"_sd}, + {Event::kWaitForData, "WaitForData"_sd}, + {Event::kSource, "Source"_sd}, + {Event::kProcess, "Process"_sd}, + {Event::kSink, "Sink"_sd}, + {Event::kEnd, "End"_sd}, + }, + e); +} - void setUp() override; - void tearDown() override; +std::ostream& operator<<(std::ostream& os, Event e) { + return os << toString(e); +} - /** - * This function blocks until the SessionWorkflowTest observes a state change. - */ - SessionState popSessionState() { - return _stateQueue.pop(); +class Result { + using Variant = + stdx::variant<stdx::monostate, Status, StatusWith<Message>, StatusWith<DbResponse>>; + +public: + Result() = default; + + template <typename T, std::enable_if_t<std::is_constructible_v<Variant, T&&>, int> = 0> + explicit Result(T&& v) : _value{std::forward<T>(v)} {} + + template <typename T> + T consumeAs() && { + return stdx::visit( + [](auto&& alt) -> T { + using A = decltype(alt); + if constexpr (std::is_convertible<A, T>()) + return std::forward<A>(alt); + invariant(0, "{} => {}"_format(typeName<A>(), typeName<T>())); + MONGO_UNREACHABLE; + }, + std::exchange(_value, {})); } - /** - * This function asserts that the SessionWorkflowTest has not yet observed a state change. - * - * Note that this function does not guarantee that it will not observe a state change in the - * future. - */ - void assertNoSessionState() { - if (auto maybeState = _stateQueue.tryPop()) { - FAIL("The queue is not empty, state: ") << *maybeState; - } + explicit operator bool() const { + return _value.index() != 0; } - /** - * This function stores an external response to be delivered out of line to the - * SessionWorkflow. - */ - void setResult(SessionState state, ResultValue result) { - auto lk = stdx::lock_guard(_stateResult->mutex); - invariant(state == SessionState::kPoll || state == SessionState::kSource || - state == SessionState::kProcess || state == SessionState::kSink); - _stateResult->result = std::move(result); - _stateResult->state = state; - _stateResult->cv.notify_one(); +private: + friend std::string toString(const Result& r) { + return stdx::visit( + [](auto&& alt) -> std::string { + using A = std::decay_t<decltype(alt)>; + std::ostringstream os; + stream(os << "[{}]"_format(typeName<A>()), alt); + return os.str(); + }, + r._value); } + Variant _value; +}; + +class SessionEventTracker { +public: /** - * This function makes a generic result appropriate for a successful state change given - * SessionState and RequestKind. + * Prepare a response for the `event`. Called by the test to inject + * a behavior for the next mock object that calls `consumeExpectation` with + * the same `event`. Only one response can be prepared at a time. */ - ResultValue makeGenericResult(SessionState state, RequestKind kind) { - ResultValue result; - switch (state) { - case SessionState::kPoll: - case SessionState::kSink: - result.setStatus(Status::OK()); - break; - case SessionState::kSource: { - Message message = _makeIndexedBson(); - switch (kind) { - case RequestKind::kExhaust: - OpMsg::setFlag(&message, OpMsg::kExhaustSupported); - break; - case RequestKind::kDefault: - case RequestKind::kMoreToCome: - break; - } - result.setMessage(StatusWith<Message>(message)); - } break; - case SessionState::kProcess: { - DbResponse response; - switch (kind) { - case RequestKind::kDefault: - response.response = _makeIndexedBson(); - break; - case RequestKind::kExhaust: - response.response = _makeIndexedBson(); - response.shouldRunAgainForExhaust = true; - break; - case RequestKind::kMoreToCome: - break; - } - result.setResponse(response); - } break; + void prepareResponse(Event event, Result v) { + switch (event) { + case Event::kWaitForData: + case Event::kSource: + case Event::kProcess: + case Event::kSink: { + stdx::lock_guard lk{_mutex}; + invariant(!_expect); + _expect = std::unique_ptr<EventAndResult>{new EventAndResult{event, std::move(v)}}; + LOGV2(6742612, + "SessionEventTracker::set", + "event"_attr = toString(_expect->event), + "value"_attr = toString(_expect->result)); + _cv.notify_one(); + return; + } default: - invariant( - false, - "Unable to make generic result for this state: {}"_format(toString(state))); + invariant(0, "SessionEventTracker::set for bad event={}"_format(toString(event))); + MONGO_UNREACHABLE; } - return result; } - /** - * Initialize a new Session. - */ - void initNewSession(); - - /** - * Launch a SessionWorkflow for the current session. - */ - void startSession(); - - /** - * Wait for the current Session and SessionWorkflow to end. - */ - void joinSession(); - - /** - * Mark the session as no longer connected. - */ void endSession() { - auto lk = stdx::lock_guard(_stateResult->mutex); - if (_stateResult->isConnected.swap(false)) { + if (setConnected(false)) LOGV2(5014101, "Ending session"); - _stateResult->cv.notify_one(); - } - } - - /** - * Start a brand new session, run the given function, and then join the session. - */ - template <typename F> - void runWithNewSession(F&& func) { - initNewSession(); - startSession(); - - auto firstState = popSessionState(); - ASSERT(firstState == SessionState::kSource || firstState == SessionState::kPoll) - << "State was instead: " << toString(firstState); - - std::forward<F>(func)(); - - joinSession(); } - void terminateViaServiceEntryPoint(); - bool isConnected() const { - return _stateResult->isConnected.load(); + stdx::unique_lock lk{_mutex}; + return _isConnected; } - int onClientDisconnectCalledTimes() const { - return _onClientDisconnectCalled; + bool setConnected(bool b) { + stdx::unique_lock lk{_mutex}; + bool old = _isConnected; + if (old != b) { + _isConnected = b; + _cv.notify_all(); + } + return old; } -private: /** - * Generate a resonably generic BSON with an id for use in debugging. + * Called by mock objects to inject behavior into them. The mock function + * call is identified by an `event`. Waits if necessary until a response + * has been prepared for that event. */ - static Message _makeIndexedBson() { - auto bob = BSONObjBuilder(); - static auto nextId = AtomicWord<int>{0}; - bob.append("id", nextId.fetchAndAdd(1)); - - auto omb = OpMsgBuilder{}; - omb.setBody(bob.obj()); - return omb.finish(); + Result consumeExpectation(Event event) { + stdx::unique_lock lk{_mutex}; + _cv.wait(lk, [&] { return (_expect && _expect->event == event) || !_isConnected; }); + if (!(_expect && _expect->event == event)) + return Result(Status{ErrorCodes::SocketException, "Session is closed"}); + invariant(_expect); + invariant(_expect->event == event); + invariant(_expect->result); + return std::exchange(_expect, {})->result; } - /** - * Use an external result to mock handling a request. - */ - StatusWith<DbResponse> _processRequest(OperationContext* opCtx, const Message&) noexcept { - _stateQueue.push(SessionState::kProcess); - - auto result = [&]() -> StatusWith<DbResponse> { - auto lk = stdx::unique_lock(_stateResult->mutex); - _stateResult->cv.wait(lk, [this] { - return (_stateResult->result && _stateResult->state == SessionState::kProcess) || - !isConnected(); - }); - - if (!isConnected()) { - return kClosedSessionError; - } - - invariant(_stateResult->result); - return std::exchange(_stateResult->result, {}).getResponse(); - }(); +private: + struct EventAndResult { + Event event; + Result result; + }; - LOGV2(5014100, "Handled request", "error"_attr = result.getStatus()); + mutable Mutex _mutex; + stdx::condition_variable _cv; + bool _isConnected = true; + std::unique_ptr<EventAndResult> _expect; +}; - return result; +/** Fixture to mock interactions with the SessionWorkflow. */ +class SessionWorkflowTest : public LockerNoopServiceContextTest { +public: + void setUp() override { + ServiceContextTest::setUp(); + auto sc = getServiceContext(); + sc->setServiceEntryPoint(_makeServiceEntryPoint(sc)); + _makeSession(); + invariant(sep()->start()); + _threadPool->startup(); } - Future<DbResponse> _handleRequest(OperationContext* opCtx, const Message& request) noexcept { - auto [p, f] = makePromiseFuture<DbResponse>(); - ExecutorFuture<void>(_threadPool) - .then([this, opCtx, &request, p = std::move(p)]() mutable { - auto strand = ClientStrand::get(opCtx->getClient()); - strand->run([&] { p.setWith([&] { return _processRequest(opCtx, request); }); }); - }) - .getAsync([](auto&&) {}); - return std::move(f); + void tearDown() override { + ScopeGuard guard = [&] { ServiceContextTest::tearDown(); }; + endSession(); + // Normal shutdown is a noop outside of ASAN. + invariant(sep()->shutdownAndWait(Seconds{10})); + _threadPool->shutdown(); + _threadPool->join(); } - /** - * Use an external result to mock polling for data and observe the state. - */ - Status _waitForData() { - _stateQueue.push(SessionState::kPoll); - - auto result = [&]() -> Status { - auto lk = stdx::unique_lock(_stateResult->mutex); - _stateResult->cv.wait(lk, [this] { - return (_stateResult->result && _stateResult->state == SessionState::kPoll) || - !isConnected(); - }); - - if (!isConnected()) { - return kClosedSessionError; - } - - invariant(_stateResult->result); - return std::exchange(_stateResult->result, {}).getStatus(); - }(); - - LOGV2(5014102, "Finished waiting for data", "error"_attr = result); - return result; + /** Blocks until the SessionWorkflowTest observes an event. */ + Event awaitEvent() { + return _eventSlot.wait(); } - /** - * Use an external result to mock reading data and observe the state. - */ - StatusWith<Message> _sourceMessage() { - _stateQueue.push(SessionState::kSource); - - auto result = [&]() -> StatusWith<Message> { - auto lk = stdx::unique_lock(_stateResult->mutex); - _stateResult->cv.wait(lk, [this] { - return (_stateResult->result && _stateResult->state == SessionState::kSource) || - !isConnected(); - }); - - if (!isConnected()) { - return kClosedSessionError; - } + /** Stores an event response to be consumed by a mock. */ + void prepareResponse(Event event, Result result) { + _sessionEventTracker.prepareResponse(event, std::move(result)); + } - invariant(_stateResult->result); - return std::exchange(_stateResult->result, {}).getMessage(); - }(); + /** Waits for the current Session and SessionWorkflow to end. */ + void joinSession() { + ASSERT(sep()->waitForNoSessions(Seconds{1})); + ASSERT_FALSE(_eventSlot) << "An unconsumed expectation is an error in the test"; + } - LOGV2(5014103, "Sourced message", "error"_attr = result.getStatus()); + /** Launches a SessionWorkflow for the current session. */ + void startSession() { + LOGV2(6742613, "Starting session"); + _sessionEventTracker.setConnected(true); + sep()->startSession(_session); + } - return result; + void endSession() { + _sessionEventTracker.endSession(); } - /** - * Use an external result to mock writing data and observe the state. - */ - Status _sinkMessage(Message message) { - _stateQueue.push(SessionState::kSink); - - auto result = [&]() -> Status { - auto lk = stdx::unique_lock(_stateResult->mutex); - _stateResult->cv.wait(lk, [this] { - return (_stateResult->result && _stateResult->state == SessionState::kSink) || - !isConnected(); - }); - - if (!isConnected()) { - return kClosedSessionError; - } + void terminateViaServiceEntryPoint() { + sep()->endAllSessionsNoTagMask(); + } - invariant(_stateResult->result); - return std::exchange(_stateResult->result, {}).getStatus(); - }(); + MockServiceEntryPoint* sep() { + return checked_cast<MockServiceEntryPoint*>(getServiceContext()->getServiceEntryPoint()); + } - LOGV2(5014104, "Sunk message", "error"_attr = result); - return result; +private: + std::shared_ptr<ThreadPool> _makeThreadPool() { + ThreadPool::Options options{}; + options.poolName = "SessionWorkflowTest"; + return std::make_shared<ThreadPool>(std::move(options)); } - Future<void> _asyncWaitForData() noexcept { - return ExecutorFuture<void>(_threadPool) - .then([this] { return _waitForData(); }) - .unsafeToInlineFuture(); + std::unique_ptr<MockServiceEntryPoint> _makeServiceEntryPoint(ServiceContext* sc) { + auto sep = std::make_unique<MockServiceEntryPoint>(sc); + sep->handleRequestCb = [=](OperationContext*, const Message&) { + return _onMockEvent<StatusWith<DbResponse>>(Event::kProcess); + }; + sep->onEndSessionCb = [=](const SessionHandle&) { _onMockEvent<void>(Event::kEnd); }; + sep->derivedOnClientDisconnectCb = [&](Client*) {}; + return sep; + } + + /** Create and initialize a mock Session. */ + void _makeSession() { + auto s = std::make_shared<CallbackMockSession>(); + s->endCb = [=] { endSession(); }; + s->isConnectedCb = [=] { return _isConnected(); }; + s->waitForDataCb = [=] { return _onMockEvent<Status>(Event::kWaitForData); }; + s->sourceMessageCb = [=] { return _onMockEvent<StatusWith<Message>>(Event::kSource); }; + s->sinkMessageCb = [=](Message) { return _onMockEvent<Status>(Event::kSink); }; + // The async variants will just run the same callback on `_threadPool`. + auto async = [this](auto cb) { + return ExecutorFuture<void>(_threadPool).then(cb).unsafeToInlineFuture(); + }; + s->asyncWaitForDataCb = [=, cb = s->waitForDataCb] { return async([cb] { return cb(); }); }; + s->asyncSourceMessageCb = [=, cb = s->sourceMessageCb](const BatonHandle&) { + return async([cb] { return cb(); }); + }; + s->asyncSinkMessageCb = [=, cb = s->sinkMessageCb](Message m, const BatonHandle&) { + return async([cb, m = std::move(m)]() mutable { return cb(std::move(m)); }); + }; + _session = std::move(s); } - Future<Message> _asyncSourceMessage() noexcept { - return ExecutorFuture<void>(_threadPool) - .then([this] { return _sourceMessage(); }) - .unsafeToInlineFuture(); + bool _isConnected() const { + return _sessionEventTracker.isConnected(); } - Future<void> _asyncSinkMessage(Message message) noexcept { - return ExecutorFuture<void>(_threadPool) - .then([this, message = std::move(message)]() mutable { - return _sinkMessage(std::move(message)); - }) - .unsafeToInlineFuture(); + /** Called by all mock functions to notify test thread and get a value with which to respond. */ + template <typename Target> + Target _onMockEvent(Event event) { + LOGV2(6742616, "Arrived", "event"_attr = toString(event)); + _eventSlot.signal(std::move(event)); + if constexpr (std::is_same_v<Target, void>) { + return; + } else { + auto r = _sessionEventTracker.consumeExpectation(event); + LOGV2(6742618, "Waited for Event", "event"_attr = event, "result"_attr = toString(r)); + return std::move(r).consumeAs<Target>(); + } } - MockServiceEntryPoint* _sep; + /** An awaitable event slot. */ + class SyncMockEventSlot { + public: + void signal(Event e) { + stdx::unique_lock lk(_mu); + invariant(!_event, "All events must be consumed in order"); + _event = e; + _arrival.notify_one(); + } - const std::shared_ptr<ThreadPool> _threadPool = makeThreadPool(); + Event wait() { + stdx::unique_lock lk(_mu); + _arrival.wait(lk, [&] { return _event; }); + return *std::exchange(_event, {}); + } - std::unique_ptr<StateResult> _stateResult; + explicit operator bool() const { + stdx::unique_lock lk(_mu); + return !!_event; + } - std::shared_ptr<CallbackMockSession> _session; - SingleProducerSingleConsumerQueue<SessionState> _stateQueue; + private: + mutable Mutex _mu; + stdx::condition_variable _arrival; + boost::optional<Event> _event; + }; - int _onClientDisconnectCalled{0}; + std::shared_ptr<Session> _session; + std::shared_ptr<ThreadPool> _threadPool = _makeThreadPool(); + SessionEventTracker _sessionEventTracker; + SyncMockEventSlot _eventSlot; }; -/** - * This class iterates over the potential methods of failure for a set of steps. - */ -class StepRunner { - /** - * This is a simple data structure describing the external response for one state in the - * session workflow. - */ - struct Step { - SessionState state; - RequestKind kind; - }; - using StepList = std::vector<Step>; - -public: - StepRunner(SessionWorkflowTest* fixture) : _fixture{fixture} {} - ~StepRunner() { - invariant(_runCount > 0, "StepRunner expects to be run at least once"); - } +TEST_F(SessionWorkflowTest, StartThenEndSession) { + startSession(); + ASSERT_EQ(awaitEvent(), Event::kSource); + endSession(); + joinSession(); +} - /** - * Given a FailureCondition, cause an external result to be delivered that is appropriate for - * the given state and request kind. - */ - SessionState doGenericStep(const Step& step, FailureCondition fail) { - switch (fail) { - case FailureCondition::kNone: { - _fixture->setResult(step.state, _fixture->makeGenericResult(step.state, step.kind)); - } break; - case FailureCondition::kTerminate: { - _fixture->terminateViaServiceEntryPoint(); - // We expect that the session will be disconnected via the SEP, no need to set any - // result. - } break; - case FailureCondition::kDisconnect: { - _fixture->endSession(); - // We expect that the session will be disconnected directly, no need to set any - // result. - } break; - case FailureCondition::kNetwork: { - _fixture->setResult(step.state, ResultValue(kNetworkError)); - } break; - case FailureCondition::kShutdown: { - _fixture->setResult(step.state, ResultValue(kShutdownError)); - } break; - case FailureCondition::kArbitrary: { - _fixture->setResult(step.state, ResultValue(kArbitraryError)); - } break; - }; +TEST_F(SessionWorkflowTest, EndBeforeStartSession) { + endSession(); + startSession(); + ASSERT_EQ(awaitEvent(), Event::kSource); + endSession(); + ASSERT_EQ(awaitEvent(), Event::kEnd); + joinSession(); +} - return _fixture->popSessionState(); - } +TEST_F(SessionWorkflowTest, OnClientDisconnectCalledOnCleanup) { + int disconnects = 0; + sep()->derivedOnClientDisconnectCb = [&](Client*) { ++disconnects; }; + startSession(); + ASSERT_EQ(awaitEvent(), Event::kSource); + ASSERT_EQ(disconnects, 0); + endSession(); + ASSERT_EQ(awaitEvent(), Event::kEnd); + joinSession(); + ASSERT_EQ(disconnects, 1); +} - /** - * Mark an additional expected state in the session workflow. - */ - void expectNextState(SessionState state, RequestKind kind) { - auto step = Step{}; - step.state = state; - step.kind = kind; - _steps.emplace_back(std::move(step)); - } +class StepRunner { +public: + enum class Action { + kDefault, + kExhaust, + kMoreToCome, + + kErrTerminate, // External termination via the ServiceEntryPoint. + kErrDisconnect, // Socket disconnection by peer. + kErrNetwork, // Unspecified network failure (ala host unreachable). + kErrShutdown, // System shutdown. + kErrArbitrary, // An arbitrary error that does not fall under the other conditions. + }; + friend StringData toString(Action k) { + return findEnumName({{Action::kDefault, "Default"_sd}, + {Action::kExhaust, "Exhaust"_sd}, + {Action::kMoreToCome, "MoreToCome"_sd}, + {Action::kErrTerminate, "ErrTerminate"_sd}, + {Action::kErrDisconnect, "ErrDisconnect"_sd}, + {Action::kErrNetwork, "ErrNetwork"_sd}, + {Action::kErrShutdown, "ErrShutdown"_sd}, + {Action::kErrArbitrary, "ErrArbitrary"_sd}}, + k); + } + + /** Encodes a response to `event` by taking `action`. */ + struct Step { + Event event; + Action action = Action::kDefault; + }; - /** - * Mark the final expected state in the session workflow. - */ - void expectFinalState(SessionState state) { - _finalState = state; - } + // The final step is assumed to have `kErrDisconnect` as an action, + // yielding an implied `kEnd` step. + StepRunner(SessionWorkflowTest* fixture, std::deque<Step> steps) + : _fixture{fixture}, _steps{[&, at = steps.size() - 1] { + return _appendTermination(std::move(steps), at, Action::kErrDisconnect); + }()} {} /** * Run a set of variations on the steps, failing further and further along the way. @@ -725,251 +549,228 @@ public: * [(Source, None), (Sink, Terminate), (End)] */ void run() { - invariant(_finalState); + const auto baseline = std::deque<Step>(_steps.begin(), _steps.end()); + LOGV2(5014106, "Running one entirely clean run"); + _runSteps(baseline); + + static constexpr std::array fails{Action::kErrTerminate, + Action::kErrDisconnect, + Action::kErrNetwork, + Action::kErrShutdown, + Action::kErrArbitrary}; + + // Incrementally push forward the step where we fail. + for (size_t failAt = 0; failAt < baseline.size(); ++failAt) { + LOGV2(6742614, "Injecting failures", "failAt"_attr = failAt); + for (auto fail : fails) + _runSteps(_appendTermination(baseline, failAt, fail)); + } + } - auto getExpectedPostState = [&](auto iter) { - auto nextIter = ++iter; - if (nextIter == _steps.end()) { - return *_finalState; - } - return nextIter->state; - }; +private: + /** + * Returns a new steps sequence, formed by copying the specified `q`, and + * modifying the copy to be terminated with a `fail` at the `failAt` index. + */ + std::deque<Step> _appendTermination(std::deque<Step> q, size_t failAt, Action fail) const { + LOGV2(6742617, "appendTermination", "fail"_attr = toString(fail), "failAt"_attr = failAt); + invariant(failAt < q.size()); + q.erase(q.begin() + failAt + 1, q.end()); + q.back().action = fail; + q.push_back({Event::kEnd}); + return q; + } + + template <typename T> + void _dumpTransitions(const T& q) { + BSONArrayBuilder bab; + for (auto&& t : q) { + BSONObjBuilder{bab.subobjStart()} + .append("event", toString(t.event)) + .append("action", toString(t.action)); + } + LOGV2(6742615, "Run transitions", "transitions"_attr = bab.arr()); + } - // Do one entirely clean run. - LOGV2(5014106, "Running success case"); - _fixture->runWithNewSession([&] { - for (auto iter = _steps.begin(); iter != _steps.end(); ++iter) { - ASSERT_EQ(doGenericStep(*iter, FailureCondition::kNone), - getExpectedPostState(iter)); - } + /** Generates an OpMsg containing a BSON with a unique 'id' field. */ + Message _makeOpMsg() { + static auto nextId = AtomicWord<int>{0}; + auto omb = OpMsgBuilder{}; + omb.setBody(BSONObjBuilder{}.append("id", nextId.fetchAndAdd(1)).obj()); + return omb.finish(); + } - _fixture->endSession(); - ASSERT_EQ(_fixture->popSessionState(), SessionState::kEnd); - }); - - const auto failList = std::vector<FailureCondition>{FailureCondition::kTerminate, - FailureCondition::kDisconnect, - FailureCondition::kNetwork, - FailureCondition::kShutdown, - FailureCondition::kArbitrary}; - - for (auto failIter = _steps.begin(); failIter != _steps.end(); ++failIter) { - // Incrementally push forward the step where we fail. - for (auto fail : failList) { - LOGV2(5014105, - "Running failure case", - "failureCase"_attr = fail, - "sessionState"_attr = failIter->state, - "requestKind"_attr = failIter->kind); - - _fixture->runWithNewSession([&] { - auto iter = _steps.begin(); - for (; iter != failIter; ++iter) { - // Run through each step until our point of failure with - // FailureCondition::kNone. - ASSERT_EQ(doGenericStep(*iter, FailureCondition::kNone), - getExpectedPostState(iter)) - << "Current state: (" << iter->state << ", " << iter->kind << ")"; - } - - // Finally fail on a given step. - ASSERT_EQ(doGenericStep(*iter, fail), SessionState::kEnd); - }); + /** Makes a result for a successful event. */ + Result _successResult(Event event, Action action) { + switch (event) { + case Event::kWaitForData: + case Event::kSink: + return Result{Status::OK()}; + case Event::kSource: { + Message m = _makeOpMsg(); + switch (action) { + case Action::kExhaust: + OpMsg::setFlag(&m, OpMsg::kExhaustSupported); + break; + default: + break; + } + return Result{StatusWith{std::move(m)}}; + } + case Event::kProcess: { + DbResponse response{}; + switch (action) { + case Action::kDefault: + response.response = _makeOpMsg(); + break; + case Action::kExhaust: + response.response = _makeOpMsg(); + response.shouldRunAgainForExhaust = true; + break; + default: + break; + } + return Result{StatusWith{std::move(response)}}; } + default: + invariant(0, "Bad event: {}"_format(toString(event))); } - - _runCount += 1; + MONGO_UNREACHABLE; } -private: - SessionWorkflowTest* const _fixture; + void _injectStep(const Step& t) { + switch (t.action) { + case Action::kErrTerminate: + _fixture->terminateViaServiceEntryPoint(); + break; + case Action::kErrDisconnect: + _fixture->endSession(); + break; + case Action::kErrNetwork: + _fixture->prepareResponse(t.event, Result(Status{ErrorCodes::HostUnreachable, ""})); + break; + case Action::kErrShutdown: + _fixture->prepareResponse(t.event, + Result(Status{ErrorCodes::ShutdownInProgress, ""})); + break; + case Action::kErrArbitrary: + _fixture->prepareResponse(t.event, Result(Status{ErrorCodes::InternalError, ""})); + break; + default: + _fixture->prepareResponse(t.event, _successResult(t.event, t.action)); + break; + } + } - boost::optional<SessionState> _finalState; - StepList _steps; + /** Start a new session, run the `steps` sequence, and join the session. */ + void _runSteps(std::deque<Step> q) { + _dumpTransitions(q); + _fixture->startSession(); + for (; !q.empty(); q.pop_front()) { + auto&& t = q.front(); + auto event = _fixture->awaitEvent(); + ASSERT_EQ(event, t.event); + LOGV2(6742610, + "Test main thread received an event and taking action", + "event"_attr = toString(t.event), + "action"_attr = toString(t.action)); + if (t.event == Event::kEnd) + break; + _injectStep(t); + } + _fixture->joinSession(); + } - // This variable is currently used as a post-condition to make sure that the StepRunner has been - // run. In the current form, it could be a boolean. That said, if you need to stress test the - // SessionWorkflow, you will want to check this variable to make sure you have run as many - // times as you expect. - size_t _runCount = 0; + SessionWorkflowTest* _fixture; + std::deque<Step> _steps; }; -void SessionWorkflowTest::initNewSession() { - assertNoSessionState(); - - _session = std::make_shared<CallbackMockSession>(); - _session->endCb = [&] { endSession(); }; - _session->isConnectedCb = [&] { return isConnected(); }; - _session->waitForDataCb = [&] { return _waitForData(); }; - _session->sourceMessageCb = [&] { return _sourceMessage(); }; - _session->sinkMessageCb = [&](Message message) { return _sinkMessage(std::move(message)); }; - _session->asyncWaitForDataCb = [&] { return _asyncWaitForData(); }; - _session->asyncSourceMessageCb = [&](const BatonHandle&) { return _asyncSourceMessage(); }; - _session->asyncSinkMessageCb = [&](Message message, const BatonHandle&) { - return _asyncSinkMessage(std::move(message)); - }; - _stateResult->isConnected.store(true); -} - -void SessionWorkflowTest::joinSession() { - ASSERT(_sep->waitForNoSessions(Seconds{1})); - - assertNoSessionState(); -} - -void SessionWorkflowTest::startSession() { - _sep->startSession(_session); -} - -void SessionWorkflowTest::terminateViaServiceEntryPoint() { - _sep->endAllSessionsNoTagMask(); -} - -void SessionWorkflowTest::setUp() { - ServiceContextTest::setUp(); - - auto sep = std::make_unique<MockServiceEntryPoint>(getServiceContext()); - sep->handleRequestCb = [&](OperationContext* opCtx, const Message& request) { - return _handleRequest(opCtx, request); - }; - sep->onEndSessionCb = [&](const transport::SessionHandle& session) { - invariant(session == _session); - _stateQueue.push(SessionState::kEnd); - }; - sep->derivedOnClientDisconnectCb = [&](Client*) { ++_onClientDisconnectCalled; }; - _sep = sep.get(); - getServiceContext()->setServiceEntryPoint(std::move(sep)); - invariant(_sep->start()); - - _threadPool->startup(); - - _stateResult = std::make_unique<StateResult>(); -} +template <bool useDedicatedThread> +class StepRunnerSessionWorkflowTest : public SessionWorkflowTest { +public: + using Action = StepRunner::Action; -void SessionWorkflowTest::tearDown() { - ON_BLOCK_EXIT([&] { ServiceContextTest::tearDown(); }); + void runSteps(std::deque<StepRunner::Step> steps) { + StepRunner{this, steps}.run(); + } - endSession(); + std::deque<StepRunner::Step> defaultLoop() const { + return { + {Event::kSource}, + {Event::kProcess}, + {Event::kSink}, + {Event::kSource}, + }; + } - // Normal shutdown is a noop outside of ASAN. - invariant(_sep->shutdownAndWait(Seconds{10})); + std::deque<StepRunner::Step> exhaustLoop() const { + return { + {Event::kSource, Action::kExhaust}, + {Event::kProcess, Action::kExhaust}, + {Event::kSink}, + {Event::kProcess}, + {Event::kSink}, + {Event::kSource}, + }; + } - _threadPool->shutdown(); - _threadPool->join(); -} + std::deque<StepRunner::Step> moreToComeLoop() const { + return { + {Event::kSource, Action::kMoreToCome}, + {Event::kProcess, Action::kMoreToCome}, + {Event::kSource}, + {Event::kProcess}, + {Event::kSink}, + {Event::kSource}, + }; + } -template <bool useDedicatedThread> -class DedicatedThreadOverrideTest : public SessionWorkflowTest { +private: ScopedValueOverride<bool> _svo{gInitialUseDedicatedThread, useDedicatedThread}; }; -using SessionWorkflowWithBorrowedThreadsTest = DedicatedThreadOverrideTest<false>; -using SessionWorkflowWithDedicatedThreadsTest = DedicatedThreadOverrideTest<true>; - -TEST_F(SessionWorkflowTest, StartThenEndSession) { - initNewSession(); - startSession(); - - ASSERT_EQ(popSessionState(), SessionState::kSource); - - endSession(); -} - -TEST_F(SessionWorkflowTest, EndBeforeStartSession) { - initNewSession(); - endSession(); - startSession(); -} - -TEST_F(SessionWorkflowTest, OnClientDisconnectCalledOnCleanup) { - initNewSession(); - startSession(); - ASSERT_EQ(popSessionState(), SessionState::kSource); - ASSERT_EQ(onClientDisconnectCalledTimes(), 0); - endSession(); - ASSERT_EQ(popSessionState(), SessionState::kEnd); - joinSession(); - ASSERT_EQ(onClientDisconnectCalledTimes(), 1); -} - -TEST_F(SessionWorkflowWithDedicatedThreadsTest, DefaultLoop) { - auto runner = StepRunner(this); +class DedicatedThreadSessionWorkflowTest : public StepRunnerSessionWorkflowTest<true> {}; - runner.expectNextState(SessionState::kSource, RequestKind::kDefault); - runner.expectNextState(SessionState::kProcess, RequestKind::kDefault); - runner.expectNextState(SessionState::kSink, RequestKind::kDefault); - runner.expectFinalState(SessionState::kSource); - - runner.run(); +TEST_F(DedicatedThreadSessionWorkflowTest, DefaultLoop) { + runSteps(defaultLoop()); } -TEST_F(SessionWorkflowWithDedicatedThreadsTest, ExhaustLoop) { - auto runner = StepRunner(this); - - runner.expectNextState(SessionState::kSource, RequestKind::kExhaust); - runner.expectNextState(SessionState::kProcess, RequestKind::kExhaust); - runner.expectNextState(SessionState::kSink, RequestKind::kExhaust); - runner.expectNextState(SessionState::kProcess, RequestKind::kDefault); - runner.expectNextState(SessionState::kSink, RequestKind::kDefault); - runner.expectFinalState(SessionState::kSource); - - runner.run(); +TEST_F(DedicatedThreadSessionWorkflowTest, ExhaustLoop) { + runSteps(exhaustLoop()); } -TEST_F(SessionWorkflowWithDedicatedThreadsTest, MoreToComeLoop) { - auto runner = StepRunner(this); - - runner.expectNextState(SessionState::kSource, RequestKind::kMoreToCome); - runner.expectNextState(SessionState::kProcess, RequestKind::kMoreToCome); - runner.expectNextState(SessionState::kSource, RequestKind::kDefault); - runner.expectNextState(SessionState::kProcess, RequestKind::kDefault); - runner.expectNextState(SessionState::kSink, RequestKind::kDefault); - runner.expectFinalState(SessionState::kSource); - - runner.run(); +TEST_F(DedicatedThreadSessionWorkflowTest, MoreToComeLoop) { + runSteps(moreToComeLoop()); } -TEST_F(SessionWorkflowWithBorrowedThreadsTest, DefaultLoop) { - auto runner = StepRunner(this); - - runner.expectNextState(SessionState::kPoll, RequestKind::kDefault); - runner.expectNextState(SessionState::kSource, RequestKind::kDefault); - runner.expectNextState(SessionState::kProcess, RequestKind::kDefault); - runner.expectNextState(SessionState::kSink, RequestKind::kDefault); - runner.expectFinalState(SessionState::kPoll); +class BorrowedThreadSessionWorkflowTest : public StepRunnerSessionWorkflowTest<false> { +public: + /** + * Under the borrowed thread model, the steps are the same as for dedicated thread model, + * except that kSource events are preceded by kWaitForData events. + */ + std::deque<StepRunner::Step> borrowedSteps(std::deque<StepRunner::Step> q) { + for (auto iter = q.begin(); iter != q.end(); ++iter) { + if (iter->event == Event::kSource) { + iter = q.insert(iter, {Event::kWaitForData}); + ++iter; + } + } + return q; + } +}; - runner.run(); +TEST_F(BorrowedThreadSessionWorkflowTest, DefaultLoop) { + runSteps(borrowedSteps(defaultLoop())); } -TEST_F(SessionWorkflowWithBorrowedThreadsTest, ExhaustLoop) { - auto runner = StepRunner(this); - - runner.expectNextState(SessionState::kPoll, RequestKind::kExhaust); - runner.expectNextState(SessionState::kSource, RequestKind::kExhaust); - runner.expectNextState(SessionState::kProcess, RequestKind::kExhaust); - runner.expectNextState(SessionState::kSink, RequestKind::kExhaust); - runner.expectNextState(SessionState::kProcess, RequestKind::kDefault); - runner.expectNextState(SessionState::kSink, RequestKind::kDefault); - runner.expectFinalState(SessionState::kPoll); - - runner.run(); +TEST_F(BorrowedThreadSessionWorkflowTest, ExhaustLoop) { + runSteps(borrowedSteps(exhaustLoop())); } -TEST_F(SessionWorkflowWithBorrowedThreadsTest, MoreToComeLoop) { - auto runner = StepRunner(this); - - runner.expectNextState(SessionState::kPoll, RequestKind::kMoreToCome); - runner.expectNextState(SessionState::kSource, RequestKind::kMoreToCome); - runner.expectNextState(SessionState::kProcess, RequestKind::kMoreToCome); - runner.expectNextState(SessionState::kPoll, RequestKind::kDefault); - runner.expectNextState(SessionState::kSource, RequestKind::kDefault); - runner.expectNextState(SessionState::kProcess, RequestKind::kDefault); - runner.expectNextState(SessionState::kSink, RequestKind::kDefault); - runner.expectFinalState(SessionState::kPoll); - - runner.run(); +TEST_F(BorrowedThreadSessionWorkflowTest, MoreToComeLoop) { + runSteps(borrowedSteps(moreToComeLoop())); } } // namespace -} // namespace transport -} // namespace mongo +} // namespace mongo::transport |