summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/mongo/transport/session_workflow_test.cpp1231
1 files changed, 516 insertions, 715 deletions
diff --git a/src/mongo/transport/session_workflow_test.cpp b/src/mongo/transport/session_workflow_test.cpp
index f85b3e4d120..9cbcc5fb2db 100644
--- a/src/mongo/transport/session_workflow_test.cpp
+++ b/src/mongo/transport/session_workflow_test.cpp
@@ -30,6 +30,9 @@
#include "mongo/platform/basic.h"
+#include <array>
+#include <deque>
+#include <initializer_list>
#include <memory>
#include <type_traits>
#include <vector>
@@ -39,7 +42,6 @@
#include "mongo/bson/bsonobj.h"
#include "mongo/bson/bsonobjbuilder.h"
#include "mongo/db/client.h"
-#include "mongo/db/client_strand.h"
#include "mongo/db/concurrency/locker_noop_service_context_test_fixture.h"
#include "mongo/db/dbmessage.h"
#include "mongo/db/service_context.h"
@@ -51,22 +53,16 @@
#include "mongo/transport/mock_session.h"
#include "mongo/transport/service_entry_point.h"
#include "mongo/transport/service_entry_point_impl.h"
-#include "mongo/transport/service_executor.h"
-#include "mongo/transport/service_executor_utils.h"
#include "mongo/transport/session_workflow.h"
#include "mongo/transport/transport_layer_mock.h"
#include "mongo/unittest/unittest.h"
#include "mongo/util/assert_util.h"
-#include "mongo/util/clock_source_mock.h"
#include "mongo/util/concurrency/thread_pool.h"
-#include "mongo/util/producer_consumer_queue.h"
-#include "mongo/util/tick_source_mock.h"
#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kTest
-namespace mongo {
-namespace transport {
+namespace mongo::transport {
namespace {
/** Scope guard to set and restore an object value. */
@@ -84,181 +80,29 @@ private:
T _saved;
};
-const Status kClosedSessionError{ErrorCodes::SocketException, "Session is closed"};
-const Status kNetworkError{ErrorCodes::HostUnreachable, "Someone is unreachable"};
-const Status kShutdownError{ErrorCodes::ShutdownInProgress, "Something is shutting down"};
-const Status kArbitraryError{ErrorCodes::InternalError, "Something happened"};
-
-/**
- * FailureCondition represents a set of the ways any state in the SessionWorkflow can fail.
- */
-enum class FailureCondition {
- kNone,
- kTerminate, // External termination via the ServiceEntryPoint.
- kDisconnect, // Socket disconnection by peer.
- kNetwork, // Unspecified network failure (ala host unreachable).
- kShutdown, // System shutdown.
- kArbitrary, // An arbitrary error that does not fall under the other conditions.
-};
-
-constexpr StringData toString(FailureCondition fail) {
- switch (fail) {
- case FailureCondition::kNone:
- return "None"_sd;
- case FailureCondition::kTerminate:
- return "Terminate"_sd;
- case FailureCondition::kDisconnect:
- return "Disconnect"_sd;
- case FailureCondition::kNetwork:
- return "Network"_sd;
- case FailureCondition::kShutdown:
- return "Shutdown"_sd;
- case FailureCondition::kArbitrary:
- return "Arbitrary"_sd;
- };
-
- return "Unknown"_sd;
-}
-
-std::ostream& operator<<(std::ostream& os, FailureCondition fail) {
- return os << toString(fail);
+template <typename T>
+StringData findEnumName(const std::vector<std::pair<T, StringData>>& arr, T k) {
+ return std::find_if(arr.begin(), arr.end(), [&](auto&& e) { return e.first == k; })->second;
}
-/**
- * SessionState represents the externally observable state of the SessionWorkflow. These
- * states map relatively closely to the internals of the SessionWorkflow::Impl. That said,
- * this enum represents the SessionWorkflowTest's external understanding of the internal
- * state.
- */
-enum class SessionState {
- kStart,
- kPoll,
- kSource,
- kProcess,
- kSink,
- kEnd,
-};
-
-constexpr StringData toString(SessionState state) {
- switch (state) {
- case SessionState::kStart:
- return "Start"_sd;
- case SessionState::kPoll:
- return "Poll"_sd;
- case SessionState::kSource:
- return "Source"_sd;
- case SessionState::kProcess:
- return "Process"_sd;
- case SessionState::kSink:
- return "Sink"_sd;
- case SessionState::kEnd:
- return "End"_sd;
- };
-
- return "Unknown"_sd;
+template <typename T>
+static std::string typeName() {
+ return demangleName(typeid(T));
}
-
-std::ostream& operator<<(std::ostream& os, SessionState state) {
- return os << toString(state);
+template <typename T>
+static std::ostream& stream(std::ostream& os, const T&) {
+ return os << "[{}]"_format(typeName<T>());
}
-
-/**
- * RequestKind represents the type of operation of the SessionWorkflow. Depending on various
- * message flags and conditions, the SessionWorkflow will transition between states
- * differently.
- */
-enum class RequestKind {
- kDefault,
- kExhaust,
- kMoreToCome,
-};
-
-constexpr StringData toString(RequestKind kind) {
- switch (kind) {
- case RequestKind::kDefault:
- return "Default"_sd;
- case RequestKind::kExhaust:
- return "Exhaust"_sd;
- case RequestKind::kMoreToCome:
- return "MoreToCome"_sd;
- };
-
- return "Unknown"_sd;
+static std::ostream& stream(std::ostream& os, const Status& v) {
+ return os << v;
}
-
-std::ostream& operator<<(std::ostream& os, RequestKind kind) {
- return os << toString(kind);
+template <typename T>
+static std::ostream& stream(std::ostream& os, const StatusWith<T>& v) {
+ if (!v.isOK())
+ return stream(os, v.getStatus());
+ return stream(os, v.getValue());
}
-class ResultValue {
-public:
- ResultValue() = default;
- explicit ResultValue(Status status) : _value(std::move(status)) {}
- explicit ResultValue(StatusWith<Message> message) : _value(std::move(message)) {}
- explicit ResultValue(StatusWith<DbResponse> response) : _value(std::move(response)) {}
-
- void setResponse(StatusWith<DbResponse> response) {
- _value = response;
- }
-
- StatusWith<DbResponse> getResponse() const {
- return _convertTo<StatusWith<DbResponse>>();
- }
-
- void setMessage(StatusWith<Message> message) {
- _value = message;
- }
-
- StatusWith<Message> getMessage() const {
- return _convertTo<StatusWith<Message>>();
- }
-
- void setStatus(Status status) {
- _value = status;
- }
-
- Status getStatus() const {
- return _convertTo<Status>();
- }
-
- bool empty() const {
- return _value.index() == 0;
- }
-
- explicit operator bool() const {
- return !empty();
- }
-
-private:
- template <typename Target>
- Target _convertTo() const {
- return stdx::visit(
- [](auto alt) -> Target {
- if constexpr (std::is_convertible<decltype(alt), Target>())
- return alt;
- invariant(false, "ResultValue not convertible to target type");
- MONGO_COMPILER_UNREACHABLE;
- },
- _value);
- }
-
- stdx::variant<stdx::monostate, Status, StatusWith<Message>, StatusWith<DbResponse>> _value;
-};
-
-/**
- * This class stores and synchronizes the shared state result between the test
- * fixture and its various wrappers.
- */
-struct StateResult {
- Mutex mutex = MONGO_MAKE_LATCH("StateResult::_mutex");
- stdx::condition_variable cv;
-
- AtomicWord<bool> isConnected{true};
-
- ResultValue result;
- SessionState state;
-};
-
class CallbackMockSession : public MockSessionBase {
public:
TransportLayer* getTransportLayer() const override {
@@ -297,13 +141,13 @@ public:
return asyncSinkMessageCb(message, handle);
}
- std::function<TransportLayer*(void)> getTransportLayerCb;
- std::function<void(void)> endCb;
- std::function<bool(void)> isConnectedCb;
- std::function<Status(void)> waitForDataCb;
- std::function<StatusWith<Message>(void)> sourceMessageCb;
+ std::function<TransportLayer*()> getTransportLayerCb;
+ std::function<void()> endCb;
+ std::function<bool()> isConnectedCb;
+ std::function<Status()> waitForDataCb;
+ std::function<StatusWith<Message>()> sourceMessageCb;
std::function<Status(Message)> sinkMessageCb;
- std::function<Future<void>(void)> asyncWaitForDataCb;
+ std::function<Future<void>()> asyncWaitForDataCb;
std::function<Future<Message>(const BatonHandle&)> asyncSourceMessageCb;
std::function<Future<void>(Message, const BatonHandle&)> asyncSinkMessageCb;
};
@@ -317,7 +161,7 @@ public:
return handleRequestCb(opCtx, request);
}
- void onEndSession(const transport::SessionHandle& handle) override {
+ void onEndSession(const SessionHandle& handle) override {
onEndSessionCb(handle);
}
@@ -326,390 +170,370 @@ public:
}
std::function<Future<DbResponse>(OperationContext*, const Message&)> handleRequestCb;
- std::function<void(const transport::SessionHandle)> onEndSessionCb;
+ std::function<void(const SessionHandle)> onEndSessionCb;
std::function<void(Client*)> derivedOnClientDisconnectCb;
};
/**
- * The SessionWorkflowTest is a fixture that mocks the external inputs into the
- * SessionWorkflow so as to provide a deterministic way to evaluate the session workflow
- * implemenation.
+ * Events generated by SessionWorkflow, mostly by virtual function calls to mock
+ * objects connected to SessionWorkflow.
*/
-class SessionWorkflowTest : public LockerNoopServiceContextTest {
-public:
- /**
- * Make a generic thread pool to deliver external inputs out of line (mocking the network or
- * database workers).
- */
- static std::shared_ptr<ThreadPool> makeThreadPool() {
- auto options = ThreadPool::Options{};
- options.poolName = "SessionWorkflowTest";
+enum class Event {
+ kStart,
+ kWaitForData,
+ kSource,
+ kProcess,
+ kSink,
+ kEnd,
+};
- return std::make_shared<ThreadPool>(std::move(options));
- }
+StringData toString(Event e) {
+ return findEnumName(
+ {
+ {Event::kStart, "Start"_sd},
+ {Event::kWaitForData, "WaitForData"_sd},
+ {Event::kSource, "Source"_sd},
+ {Event::kProcess, "Process"_sd},
+ {Event::kSink, "Sink"_sd},
+ {Event::kEnd, "End"_sd},
+ },
+ e);
+}
- void setUp() override;
- void tearDown() override;
+std::ostream& operator<<(std::ostream& os, Event e) {
+ return os << toString(e);
+}
- /**
- * This function blocks until the SessionWorkflowTest observes a state change.
- */
- SessionState popSessionState() {
- return _stateQueue.pop();
+class Result {
+ using Variant =
+ stdx::variant<stdx::monostate, Status, StatusWith<Message>, StatusWith<DbResponse>>;
+
+public:
+ Result() = default;
+
+ template <typename T, std::enable_if_t<std::is_constructible_v<Variant, T&&>, int> = 0>
+ explicit Result(T&& v) : _value{std::forward<T>(v)} {}
+
+ template <typename T>
+ T consumeAs() && {
+ return stdx::visit(
+ [](auto&& alt) -> T {
+ using A = decltype(alt);
+ if constexpr (std::is_convertible<A, T>())
+ return std::forward<A>(alt);
+ invariant(0, "{} => {}"_format(typeName<A>(), typeName<T>()));
+ MONGO_UNREACHABLE;
+ },
+ std::exchange(_value, {}));
}
- /**
- * This function asserts that the SessionWorkflowTest has not yet observed a state change.
- *
- * Note that this function does not guarantee that it will not observe a state change in the
- * future.
- */
- void assertNoSessionState() {
- if (auto maybeState = _stateQueue.tryPop()) {
- FAIL("The queue is not empty, state: ") << *maybeState;
- }
+ explicit operator bool() const {
+ return _value.index() != 0;
}
- /**
- * This function stores an external response to be delivered out of line to the
- * SessionWorkflow.
- */
- void setResult(SessionState state, ResultValue result) {
- auto lk = stdx::lock_guard(_stateResult->mutex);
- invariant(state == SessionState::kPoll || state == SessionState::kSource ||
- state == SessionState::kProcess || state == SessionState::kSink);
- _stateResult->result = std::move(result);
- _stateResult->state = state;
- _stateResult->cv.notify_one();
+private:
+ friend std::string toString(const Result& r) {
+ return stdx::visit(
+ [](auto&& alt) -> std::string {
+ using A = std::decay_t<decltype(alt)>;
+ std::ostringstream os;
+ stream(os << "[{}]"_format(typeName<A>()), alt);
+ return os.str();
+ },
+ r._value);
}
+ Variant _value;
+};
+
+class SessionEventTracker {
+public:
/**
- * This function makes a generic result appropriate for a successful state change given
- * SessionState and RequestKind.
+ * Prepare a response for the `event`. Called by the test to inject
+ * a behavior for the next mock object that calls `consumeExpectation` with
+ * the same `event`. Only one response can be prepared at a time.
*/
- ResultValue makeGenericResult(SessionState state, RequestKind kind) {
- ResultValue result;
- switch (state) {
- case SessionState::kPoll:
- case SessionState::kSink:
- result.setStatus(Status::OK());
- break;
- case SessionState::kSource: {
- Message message = _makeIndexedBson();
- switch (kind) {
- case RequestKind::kExhaust:
- OpMsg::setFlag(&message, OpMsg::kExhaustSupported);
- break;
- case RequestKind::kDefault:
- case RequestKind::kMoreToCome:
- break;
- }
- result.setMessage(StatusWith<Message>(message));
- } break;
- case SessionState::kProcess: {
- DbResponse response;
- switch (kind) {
- case RequestKind::kDefault:
- response.response = _makeIndexedBson();
- break;
- case RequestKind::kExhaust:
- response.response = _makeIndexedBson();
- response.shouldRunAgainForExhaust = true;
- break;
- case RequestKind::kMoreToCome:
- break;
- }
- result.setResponse(response);
- } break;
+ void prepareResponse(Event event, Result v) {
+ switch (event) {
+ case Event::kWaitForData:
+ case Event::kSource:
+ case Event::kProcess:
+ case Event::kSink: {
+ stdx::lock_guard lk{_mutex};
+ invariant(!_expect);
+ _expect = std::unique_ptr<EventAndResult>{new EventAndResult{event, std::move(v)}};
+ LOGV2(6742612,
+ "SessionEventTracker::set",
+ "event"_attr = toString(_expect->event),
+ "value"_attr = toString(_expect->result));
+ _cv.notify_one();
+ return;
+ }
default:
- invariant(
- false,
- "Unable to make generic result for this state: {}"_format(toString(state)));
+ invariant(0, "SessionEventTracker::set for bad event={}"_format(toString(event)));
+ MONGO_UNREACHABLE;
}
- return result;
}
- /**
- * Initialize a new Session.
- */
- void initNewSession();
-
- /**
- * Launch a SessionWorkflow for the current session.
- */
- void startSession();
-
- /**
- * Wait for the current Session and SessionWorkflow to end.
- */
- void joinSession();
-
- /**
- * Mark the session as no longer connected.
- */
void endSession() {
- auto lk = stdx::lock_guard(_stateResult->mutex);
- if (_stateResult->isConnected.swap(false)) {
+ if (setConnected(false))
LOGV2(5014101, "Ending session");
- _stateResult->cv.notify_one();
- }
- }
-
- /**
- * Start a brand new session, run the given function, and then join the session.
- */
- template <typename F>
- void runWithNewSession(F&& func) {
- initNewSession();
- startSession();
-
- auto firstState = popSessionState();
- ASSERT(firstState == SessionState::kSource || firstState == SessionState::kPoll)
- << "State was instead: " << toString(firstState);
-
- std::forward<F>(func)();
-
- joinSession();
}
- void terminateViaServiceEntryPoint();
-
bool isConnected() const {
- return _stateResult->isConnected.load();
+ stdx::unique_lock lk{_mutex};
+ return _isConnected;
}
- int onClientDisconnectCalledTimes() const {
- return _onClientDisconnectCalled;
+ bool setConnected(bool b) {
+ stdx::unique_lock lk{_mutex};
+ bool old = _isConnected;
+ if (old != b) {
+ _isConnected = b;
+ _cv.notify_all();
+ }
+ return old;
}
-private:
/**
- * Generate a resonably generic BSON with an id for use in debugging.
+ * Called by mock objects to inject behavior into them. The mock function
+ * call is identified by an `event`. Waits if necessary until a response
+ * has been prepared for that event.
*/
- static Message _makeIndexedBson() {
- auto bob = BSONObjBuilder();
- static auto nextId = AtomicWord<int>{0};
- bob.append("id", nextId.fetchAndAdd(1));
-
- auto omb = OpMsgBuilder{};
- omb.setBody(bob.obj());
- return omb.finish();
+ Result consumeExpectation(Event event) {
+ stdx::unique_lock lk{_mutex};
+ _cv.wait(lk, [&] { return (_expect && _expect->event == event) || !_isConnected; });
+ if (!(_expect && _expect->event == event))
+ return Result(Status{ErrorCodes::SocketException, "Session is closed"});
+ invariant(_expect);
+ invariant(_expect->event == event);
+ invariant(_expect->result);
+ return std::exchange(_expect, {})->result;
}
- /**
- * Use an external result to mock handling a request.
- */
- StatusWith<DbResponse> _processRequest(OperationContext* opCtx, const Message&) noexcept {
- _stateQueue.push(SessionState::kProcess);
-
- auto result = [&]() -> StatusWith<DbResponse> {
- auto lk = stdx::unique_lock(_stateResult->mutex);
- _stateResult->cv.wait(lk, [this] {
- return (_stateResult->result && _stateResult->state == SessionState::kProcess) ||
- !isConnected();
- });
-
- if (!isConnected()) {
- return kClosedSessionError;
- }
-
- invariant(_stateResult->result);
- return std::exchange(_stateResult->result, {}).getResponse();
- }();
+private:
+ struct EventAndResult {
+ Event event;
+ Result result;
+ };
- LOGV2(5014100, "Handled request", "error"_attr = result.getStatus());
+ mutable Mutex _mutex;
+ stdx::condition_variable _cv;
+ bool _isConnected = true;
+ std::unique_ptr<EventAndResult> _expect;
+};
- return result;
+/** Fixture to mock interactions with the SessionWorkflow. */
+class SessionWorkflowTest : public LockerNoopServiceContextTest {
+public:
+ void setUp() override {
+ ServiceContextTest::setUp();
+ auto sc = getServiceContext();
+ sc->setServiceEntryPoint(_makeServiceEntryPoint(sc));
+ _makeSession();
+ invariant(sep()->start());
+ _threadPool->startup();
}
- Future<DbResponse> _handleRequest(OperationContext* opCtx, const Message& request) noexcept {
- auto [p, f] = makePromiseFuture<DbResponse>();
- ExecutorFuture<void>(_threadPool)
- .then([this, opCtx, &request, p = std::move(p)]() mutable {
- auto strand = ClientStrand::get(opCtx->getClient());
- strand->run([&] { p.setWith([&] { return _processRequest(opCtx, request); }); });
- })
- .getAsync([](auto&&) {});
- return std::move(f);
+ void tearDown() override {
+ ScopeGuard guard = [&] { ServiceContextTest::tearDown(); };
+ endSession();
+ // Normal shutdown is a noop outside of ASAN.
+ invariant(sep()->shutdownAndWait(Seconds{10}));
+ _threadPool->shutdown();
+ _threadPool->join();
}
- /**
- * Use an external result to mock polling for data and observe the state.
- */
- Status _waitForData() {
- _stateQueue.push(SessionState::kPoll);
-
- auto result = [&]() -> Status {
- auto lk = stdx::unique_lock(_stateResult->mutex);
- _stateResult->cv.wait(lk, [this] {
- return (_stateResult->result && _stateResult->state == SessionState::kPoll) ||
- !isConnected();
- });
-
- if (!isConnected()) {
- return kClosedSessionError;
- }
-
- invariant(_stateResult->result);
- return std::exchange(_stateResult->result, {}).getStatus();
- }();
-
- LOGV2(5014102, "Finished waiting for data", "error"_attr = result);
- return result;
+ /** Blocks until the SessionWorkflowTest observes an event. */
+ Event awaitEvent() {
+ return _eventSlot.wait();
}
- /**
- * Use an external result to mock reading data and observe the state.
- */
- StatusWith<Message> _sourceMessage() {
- _stateQueue.push(SessionState::kSource);
-
- auto result = [&]() -> StatusWith<Message> {
- auto lk = stdx::unique_lock(_stateResult->mutex);
- _stateResult->cv.wait(lk, [this] {
- return (_stateResult->result && _stateResult->state == SessionState::kSource) ||
- !isConnected();
- });
-
- if (!isConnected()) {
- return kClosedSessionError;
- }
+ /** Stores an event response to be consumed by a mock. */
+ void prepareResponse(Event event, Result result) {
+ _sessionEventTracker.prepareResponse(event, std::move(result));
+ }
- invariant(_stateResult->result);
- return std::exchange(_stateResult->result, {}).getMessage();
- }();
+ /** Waits for the current Session and SessionWorkflow to end. */
+ void joinSession() {
+ ASSERT(sep()->waitForNoSessions(Seconds{1}));
+ ASSERT_FALSE(_eventSlot) << "An unconsumed expectation is an error in the test";
+ }
- LOGV2(5014103, "Sourced message", "error"_attr = result.getStatus());
+ /** Launches a SessionWorkflow for the current session. */
+ void startSession() {
+ LOGV2(6742613, "Starting session");
+ _sessionEventTracker.setConnected(true);
+ sep()->startSession(_session);
+ }
- return result;
+ void endSession() {
+ _sessionEventTracker.endSession();
}
- /**
- * Use an external result to mock writing data and observe the state.
- */
- Status _sinkMessage(Message message) {
- _stateQueue.push(SessionState::kSink);
-
- auto result = [&]() -> Status {
- auto lk = stdx::unique_lock(_stateResult->mutex);
- _stateResult->cv.wait(lk, [this] {
- return (_stateResult->result && _stateResult->state == SessionState::kSink) ||
- !isConnected();
- });
-
- if (!isConnected()) {
- return kClosedSessionError;
- }
+ void terminateViaServiceEntryPoint() {
+ sep()->endAllSessionsNoTagMask();
+ }
- invariant(_stateResult->result);
- return std::exchange(_stateResult->result, {}).getStatus();
- }();
+ MockServiceEntryPoint* sep() {
+ return checked_cast<MockServiceEntryPoint*>(getServiceContext()->getServiceEntryPoint());
+ }
- LOGV2(5014104, "Sunk message", "error"_attr = result);
- return result;
+private:
+ std::shared_ptr<ThreadPool> _makeThreadPool() {
+ ThreadPool::Options options{};
+ options.poolName = "SessionWorkflowTest";
+ return std::make_shared<ThreadPool>(std::move(options));
}
- Future<void> _asyncWaitForData() noexcept {
- return ExecutorFuture<void>(_threadPool)
- .then([this] { return _waitForData(); })
- .unsafeToInlineFuture();
+ std::unique_ptr<MockServiceEntryPoint> _makeServiceEntryPoint(ServiceContext* sc) {
+ auto sep = std::make_unique<MockServiceEntryPoint>(sc);
+ sep->handleRequestCb = [=](OperationContext*, const Message&) {
+ return _onMockEvent<StatusWith<DbResponse>>(Event::kProcess);
+ };
+ sep->onEndSessionCb = [=](const SessionHandle&) { _onMockEvent<void>(Event::kEnd); };
+ sep->derivedOnClientDisconnectCb = [&](Client*) {};
+ return sep;
+ }
+
+ /** Create and initialize a mock Session. */
+ void _makeSession() {
+ auto s = std::make_shared<CallbackMockSession>();
+ s->endCb = [=] { endSession(); };
+ s->isConnectedCb = [=] { return _isConnected(); };
+ s->waitForDataCb = [=] { return _onMockEvent<Status>(Event::kWaitForData); };
+ s->sourceMessageCb = [=] { return _onMockEvent<StatusWith<Message>>(Event::kSource); };
+ s->sinkMessageCb = [=](Message) { return _onMockEvent<Status>(Event::kSink); };
+ // The async variants will just run the same callback on `_threadPool`.
+ auto async = [this](auto cb) {
+ return ExecutorFuture<void>(_threadPool).then(cb).unsafeToInlineFuture();
+ };
+ s->asyncWaitForDataCb = [=, cb = s->waitForDataCb] { return async([cb] { return cb(); }); };
+ s->asyncSourceMessageCb = [=, cb = s->sourceMessageCb](const BatonHandle&) {
+ return async([cb] { return cb(); });
+ };
+ s->asyncSinkMessageCb = [=, cb = s->sinkMessageCb](Message m, const BatonHandle&) {
+ return async([cb, m = std::move(m)]() mutable { return cb(std::move(m)); });
+ };
+ _session = std::move(s);
}
- Future<Message> _asyncSourceMessage() noexcept {
- return ExecutorFuture<void>(_threadPool)
- .then([this] { return _sourceMessage(); })
- .unsafeToInlineFuture();
+ bool _isConnected() const {
+ return _sessionEventTracker.isConnected();
}
- Future<void> _asyncSinkMessage(Message message) noexcept {
- return ExecutorFuture<void>(_threadPool)
- .then([this, message = std::move(message)]() mutable {
- return _sinkMessage(std::move(message));
- })
- .unsafeToInlineFuture();
+ /** Called by all mock functions to notify test thread and get a value with which to respond. */
+ template <typename Target>
+ Target _onMockEvent(Event event) {
+ LOGV2(6742616, "Arrived", "event"_attr = toString(event));
+ _eventSlot.signal(std::move(event));
+ if constexpr (std::is_same_v<Target, void>) {
+ return;
+ } else {
+ auto r = _sessionEventTracker.consumeExpectation(event);
+ LOGV2(6742618, "Waited for Event", "event"_attr = event, "result"_attr = toString(r));
+ return std::move(r).consumeAs<Target>();
+ }
}
- MockServiceEntryPoint* _sep;
+ /** An awaitable event slot. */
+ class SyncMockEventSlot {
+ public:
+ void signal(Event e) {
+ stdx::unique_lock lk(_mu);
+ invariant(!_event, "All events must be consumed in order");
+ _event = e;
+ _arrival.notify_one();
+ }
- const std::shared_ptr<ThreadPool> _threadPool = makeThreadPool();
+ Event wait() {
+ stdx::unique_lock lk(_mu);
+ _arrival.wait(lk, [&] { return _event; });
+ return *std::exchange(_event, {});
+ }
- std::unique_ptr<StateResult> _stateResult;
+ explicit operator bool() const {
+ stdx::unique_lock lk(_mu);
+ return !!_event;
+ }
- std::shared_ptr<CallbackMockSession> _session;
- SingleProducerSingleConsumerQueue<SessionState> _stateQueue;
+ private:
+ mutable Mutex _mu;
+ stdx::condition_variable _arrival;
+ boost::optional<Event> _event;
+ };
- int _onClientDisconnectCalled{0};
+ std::shared_ptr<Session> _session;
+ std::shared_ptr<ThreadPool> _threadPool = _makeThreadPool();
+ SessionEventTracker _sessionEventTracker;
+ SyncMockEventSlot _eventSlot;
};
-/**
- * This class iterates over the potential methods of failure for a set of steps.
- */
-class StepRunner {
- /**
- * This is a simple data structure describing the external response for one state in the
- * session workflow.
- */
- struct Step {
- SessionState state;
- RequestKind kind;
- };
- using StepList = std::vector<Step>;
-
-public:
- StepRunner(SessionWorkflowTest* fixture) : _fixture{fixture} {}
- ~StepRunner() {
- invariant(_runCount > 0, "StepRunner expects to be run at least once");
- }
+TEST_F(SessionWorkflowTest, StartThenEndSession) {
+ startSession();
+ ASSERT_EQ(awaitEvent(), Event::kSource);
+ endSession();
+ joinSession();
+}
- /**
- * Given a FailureCondition, cause an external result to be delivered that is appropriate for
- * the given state and request kind.
- */
- SessionState doGenericStep(const Step& step, FailureCondition fail) {
- switch (fail) {
- case FailureCondition::kNone: {
- _fixture->setResult(step.state, _fixture->makeGenericResult(step.state, step.kind));
- } break;
- case FailureCondition::kTerminate: {
- _fixture->terminateViaServiceEntryPoint();
- // We expect that the session will be disconnected via the SEP, no need to set any
- // result.
- } break;
- case FailureCondition::kDisconnect: {
- _fixture->endSession();
- // We expect that the session will be disconnected directly, no need to set any
- // result.
- } break;
- case FailureCondition::kNetwork: {
- _fixture->setResult(step.state, ResultValue(kNetworkError));
- } break;
- case FailureCondition::kShutdown: {
- _fixture->setResult(step.state, ResultValue(kShutdownError));
- } break;
- case FailureCondition::kArbitrary: {
- _fixture->setResult(step.state, ResultValue(kArbitraryError));
- } break;
- };
+TEST_F(SessionWorkflowTest, EndBeforeStartSession) {
+ endSession();
+ startSession();
+ ASSERT_EQ(awaitEvent(), Event::kSource);
+ endSession();
+ ASSERT_EQ(awaitEvent(), Event::kEnd);
+ joinSession();
+}
- return _fixture->popSessionState();
- }
+TEST_F(SessionWorkflowTest, OnClientDisconnectCalledOnCleanup) {
+ int disconnects = 0;
+ sep()->derivedOnClientDisconnectCb = [&](Client*) { ++disconnects; };
+ startSession();
+ ASSERT_EQ(awaitEvent(), Event::kSource);
+ ASSERT_EQ(disconnects, 0);
+ endSession();
+ ASSERT_EQ(awaitEvent(), Event::kEnd);
+ joinSession();
+ ASSERT_EQ(disconnects, 1);
+}
- /**
- * Mark an additional expected state in the session workflow.
- */
- void expectNextState(SessionState state, RequestKind kind) {
- auto step = Step{};
- step.state = state;
- step.kind = kind;
- _steps.emplace_back(std::move(step));
- }
+class StepRunner {
+public:
+ enum class Action {
+ kDefault,
+ kExhaust,
+ kMoreToCome,
+
+ kErrTerminate, // External termination via the ServiceEntryPoint.
+ kErrDisconnect, // Socket disconnection by peer.
+ kErrNetwork, // Unspecified network failure (ala host unreachable).
+ kErrShutdown, // System shutdown.
+ kErrArbitrary, // An arbitrary error that does not fall under the other conditions.
+ };
+ friend StringData toString(Action k) {
+ return findEnumName({{Action::kDefault, "Default"_sd},
+ {Action::kExhaust, "Exhaust"_sd},
+ {Action::kMoreToCome, "MoreToCome"_sd},
+ {Action::kErrTerminate, "ErrTerminate"_sd},
+ {Action::kErrDisconnect, "ErrDisconnect"_sd},
+ {Action::kErrNetwork, "ErrNetwork"_sd},
+ {Action::kErrShutdown, "ErrShutdown"_sd},
+ {Action::kErrArbitrary, "ErrArbitrary"_sd}},
+ k);
+ }
+
+ /** Encodes a response to `event` by taking `action`. */
+ struct Step {
+ Event event;
+ Action action = Action::kDefault;
+ };
- /**
- * Mark the final expected state in the session workflow.
- */
- void expectFinalState(SessionState state) {
- _finalState = state;
- }
+ // The final step is assumed to have `kErrDisconnect` as an action,
+ // yielding an implied `kEnd` step.
+ StepRunner(SessionWorkflowTest* fixture, std::deque<Step> steps)
+ : _fixture{fixture}, _steps{[&, at = steps.size() - 1] {
+ return _appendTermination(std::move(steps), at, Action::kErrDisconnect);
+ }()} {}
/**
* Run a set of variations on the steps, failing further and further along the way.
@@ -725,251 +549,228 @@ public:
* [(Source, None), (Sink, Terminate), (End)]
*/
void run() {
- invariant(_finalState);
+ const auto baseline = std::deque<Step>(_steps.begin(), _steps.end());
+ LOGV2(5014106, "Running one entirely clean run");
+ _runSteps(baseline);
+
+ static constexpr std::array fails{Action::kErrTerminate,
+ Action::kErrDisconnect,
+ Action::kErrNetwork,
+ Action::kErrShutdown,
+ Action::kErrArbitrary};
+
+ // Incrementally push forward the step where we fail.
+ for (size_t failAt = 0; failAt < baseline.size(); ++failAt) {
+ LOGV2(6742614, "Injecting failures", "failAt"_attr = failAt);
+ for (auto fail : fails)
+ _runSteps(_appendTermination(baseline, failAt, fail));
+ }
+ }
- auto getExpectedPostState = [&](auto iter) {
- auto nextIter = ++iter;
- if (nextIter == _steps.end()) {
- return *_finalState;
- }
- return nextIter->state;
- };
+private:
+ /**
+ * Returns a new steps sequence, formed by copying the specified `q`, and
+ * modifying the copy to be terminated with a `fail` at the `failAt` index.
+ */
+ std::deque<Step> _appendTermination(std::deque<Step> q, size_t failAt, Action fail) const {
+ LOGV2(6742617, "appendTermination", "fail"_attr = toString(fail), "failAt"_attr = failAt);
+ invariant(failAt < q.size());
+ q.erase(q.begin() + failAt + 1, q.end());
+ q.back().action = fail;
+ q.push_back({Event::kEnd});
+ return q;
+ }
+
+ template <typename T>
+ void _dumpTransitions(const T& q) {
+ BSONArrayBuilder bab;
+ for (auto&& t : q) {
+ BSONObjBuilder{bab.subobjStart()}
+ .append("event", toString(t.event))
+ .append("action", toString(t.action));
+ }
+ LOGV2(6742615, "Run transitions", "transitions"_attr = bab.arr());
+ }
- // Do one entirely clean run.
- LOGV2(5014106, "Running success case");
- _fixture->runWithNewSession([&] {
- for (auto iter = _steps.begin(); iter != _steps.end(); ++iter) {
- ASSERT_EQ(doGenericStep(*iter, FailureCondition::kNone),
- getExpectedPostState(iter));
- }
+ /** Generates an OpMsg containing a BSON with a unique 'id' field. */
+ Message _makeOpMsg() {
+ static auto nextId = AtomicWord<int>{0};
+ auto omb = OpMsgBuilder{};
+ omb.setBody(BSONObjBuilder{}.append("id", nextId.fetchAndAdd(1)).obj());
+ return omb.finish();
+ }
- _fixture->endSession();
- ASSERT_EQ(_fixture->popSessionState(), SessionState::kEnd);
- });
-
- const auto failList = std::vector<FailureCondition>{FailureCondition::kTerminate,
- FailureCondition::kDisconnect,
- FailureCondition::kNetwork,
- FailureCondition::kShutdown,
- FailureCondition::kArbitrary};
-
- for (auto failIter = _steps.begin(); failIter != _steps.end(); ++failIter) {
- // Incrementally push forward the step where we fail.
- for (auto fail : failList) {
- LOGV2(5014105,
- "Running failure case",
- "failureCase"_attr = fail,
- "sessionState"_attr = failIter->state,
- "requestKind"_attr = failIter->kind);
-
- _fixture->runWithNewSession([&] {
- auto iter = _steps.begin();
- for (; iter != failIter; ++iter) {
- // Run through each step until our point of failure with
- // FailureCondition::kNone.
- ASSERT_EQ(doGenericStep(*iter, FailureCondition::kNone),
- getExpectedPostState(iter))
- << "Current state: (" << iter->state << ", " << iter->kind << ")";
- }
-
- // Finally fail on a given step.
- ASSERT_EQ(doGenericStep(*iter, fail), SessionState::kEnd);
- });
+ /** Makes a result for a successful event. */
+ Result _successResult(Event event, Action action) {
+ switch (event) {
+ case Event::kWaitForData:
+ case Event::kSink:
+ return Result{Status::OK()};
+ case Event::kSource: {
+ Message m = _makeOpMsg();
+ switch (action) {
+ case Action::kExhaust:
+ OpMsg::setFlag(&m, OpMsg::kExhaustSupported);
+ break;
+ default:
+ break;
+ }
+ return Result{StatusWith{std::move(m)}};
+ }
+ case Event::kProcess: {
+ DbResponse response{};
+ switch (action) {
+ case Action::kDefault:
+ response.response = _makeOpMsg();
+ break;
+ case Action::kExhaust:
+ response.response = _makeOpMsg();
+ response.shouldRunAgainForExhaust = true;
+ break;
+ default:
+ break;
+ }
+ return Result{StatusWith{std::move(response)}};
}
+ default:
+ invariant(0, "Bad event: {}"_format(toString(event)));
}
-
- _runCount += 1;
+ MONGO_UNREACHABLE;
}
-private:
- SessionWorkflowTest* const _fixture;
+ void _injectStep(const Step& t) {
+ switch (t.action) {
+ case Action::kErrTerminate:
+ _fixture->terminateViaServiceEntryPoint();
+ break;
+ case Action::kErrDisconnect:
+ _fixture->endSession();
+ break;
+ case Action::kErrNetwork:
+ _fixture->prepareResponse(t.event, Result(Status{ErrorCodes::HostUnreachable, ""}));
+ break;
+ case Action::kErrShutdown:
+ _fixture->prepareResponse(t.event,
+ Result(Status{ErrorCodes::ShutdownInProgress, ""}));
+ break;
+ case Action::kErrArbitrary:
+ _fixture->prepareResponse(t.event, Result(Status{ErrorCodes::InternalError, ""}));
+ break;
+ default:
+ _fixture->prepareResponse(t.event, _successResult(t.event, t.action));
+ break;
+ }
+ }
- boost::optional<SessionState> _finalState;
- StepList _steps;
+ /** Start a new session, run the `steps` sequence, and join the session. */
+ void _runSteps(std::deque<Step> q) {
+ _dumpTransitions(q);
+ _fixture->startSession();
+ for (; !q.empty(); q.pop_front()) {
+ auto&& t = q.front();
+ auto event = _fixture->awaitEvent();
+ ASSERT_EQ(event, t.event);
+ LOGV2(6742610,
+ "Test main thread received an event and taking action",
+ "event"_attr = toString(t.event),
+ "action"_attr = toString(t.action));
+ if (t.event == Event::kEnd)
+ break;
+ _injectStep(t);
+ }
+ _fixture->joinSession();
+ }
- // This variable is currently used as a post-condition to make sure that the StepRunner has been
- // run. In the current form, it could be a boolean. That said, if you need to stress test the
- // SessionWorkflow, you will want to check this variable to make sure you have run as many
- // times as you expect.
- size_t _runCount = 0;
+ SessionWorkflowTest* _fixture;
+ std::deque<Step> _steps;
};
-void SessionWorkflowTest::initNewSession() {
- assertNoSessionState();
-
- _session = std::make_shared<CallbackMockSession>();
- _session->endCb = [&] { endSession(); };
- _session->isConnectedCb = [&] { return isConnected(); };
- _session->waitForDataCb = [&] { return _waitForData(); };
- _session->sourceMessageCb = [&] { return _sourceMessage(); };
- _session->sinkMessageCb = [&](Message message) { return _sinkMessage(std::move(message)); };
- _session->asyncWaitForDataCb = [&] { return _asyncWaitForData(); };
- _session->asyncSourceMessageCb = [&](const BatonHandle&) { return _asyncSourceMessage(); };
- _session->asyncSinkMessageCb = [&](Message message, const BatonHandle&) {
- return _asyncSinkMessage(std::move(message));
- };
- _stateResult->isConnected.store(true);
-}
-
-void SessionWorkflowTest::joinSession() {
- ASSERT(_sep->waitForNoSessions(Seconds{1}));
-
- assertNoSessionState();
-}
-
-void SessionWorkflowTest::startSession() {
- _sep->startSession(_session);
-}
-
-void SessionWorkflowTest::terminateViaServiceEntryPoint() {
- _sep->endAllSessionsNoTagMask();
-}
-
-void SessionWorkflowTest::setUp() {
- ServiceContextTest::setUp();
-
- auto sep = std::make_unique<MockServiceEntryPoint>(getServiceContext());
- sep->handleRequestCb = [&](OperationContext* opCtx, const Message& request) {
- return _handleRequest(opCtx, request);
- };
- sep->onEndSessionCb = [&](const transport::SessionHandle& session) {
- invariant(session == _session);
- _stateQueue.push(SessionState::kEnd);
- };
- sep->derivedOnClientDisconnectCb = [&](Client*) { ++_onClientDisconnectCalled; };
- _sep = sep.get();
- getServiceContext()->setServiceEntryPoint(std::move(sep));
- invariant(_sep->start());
-
- _threadPool->startup();
-
- _stateResult = std::make_unique<StateResult>();
-}
+template <bool useDedicatedThread>
+class StepRunnerSessionWorkflowTest : public SessionWorkflowTest {
+public:
+ using Action = StepRunner::Action;
-void SessionWorkflowTest::tearDown() {
- ON_BLOCK_EXIT([&] { ServiceContextTest::tearDown(); });
+ void runSteps(std::deque<StepRunner::Step> steps) {
+ StepRunner{this, steps}.run();
+ }
- endSession();
+ std::deque<StepRunner::Step> defaultLoop() const {
+ return {
+ {Event::kSource},
+ {Event::kProcess},
+ {Event::kSink},
+ {Event::kSource},
+ };
+ }
- // Normal shutdown is a noop outside of ASAN.
- invariant(_sep->shutdownAndWait(Seconds{10}));
+ std::deque<StepRunner::Step> exhaustLoop() const {
+ return {
+ {Event::kSource, Action::kExhaust},
+ {Event::kProcess, Action::kExhaust},
+ {Event::kSink},
+ {Event::kProcess},
+ {Event::kSink},
+ {Event::kSource},
+ };
+ }
- _threadPool->shutdown();
- _threadPool->join();
-}
+ std::deque<StepRunner::Step> moreToComeLoop() const {
+ return {
+ {Event::kSource, Action::kMoreToCome},
+ {Event::kProcess, Action::kMoreToCome},
+ {Event::kSource},
+ {Event::kProcess},
+ {Event::kSink},
+ {Event::kSource},
+ };
+ }
-template <bool useDedicatedThread>
-class DedicatedThreadOverrideTest : public SessionWorkflowTest {
+private:
ScopedValueOverride<bool> _svo{gInitialUseDedicatedThread, useDedicatedThread};
};
-using SessionWorkflowWithBorrowedThreadsTest = DedicatedThreadOverrideTest<false>;
-using SessionWorkflowWithDedicatedThreadsTest = DedicatedThreadOverrideTest<true>;
-
-TEST_F(SessionWorkflowTest, StartThenEndSession) {
- initNewSession();
- startSession();
-
- ASSERT_EQ(popSessionState(), SessionState::kSource);
-
- endSession();
-}
-
-TEST_F(SessionWorkflowTest, EndBeforeStartSession) {
- initNewSession();
- endSession();
- startSession();
-}
-
-TEST_F(SessionWorkflowTest, OnClientDisconnectCalledOnCleanup) {
- initNewSession();
- startSession();
- ASSERT_EQ(popSessionState(), SessionState::kSource);
- ASSERT_EQ(onClientDisconnectCalledTimes(), 0);
- endSession();
- ASSERT_EQ(popSessionState(), SessionState::kEnd);
- joinSession();
- ASSERT_EQ(onClientDisconnectCalledTimes(), 1);
-}
-
-TEST_F(SessionWorkflowWithDedicatedThreadsTest, DefaultLoop) {
- auto runner = StepRunner(this);
+class DedicatedThreadSessionWorkflowTest : public StepRunnerSessionWorkflowTest<true> {};
- runner.expectNextState(SessionState::kSource, RequestKind::kDefault);
- runner.expectNextState(SessionState::kProcess, RequestKind::kDefault);
- runner.expectNextState(SessionState::kSink, RequestKind::kDefault);
- runner.expectFinalState(SessionState::kSource);
-
- runner.run();
+TEST_F(DedicatedThreadSessionWorkflowTest, DefaultLoop) {
+ runSteps(defaultLoop());
}
-TEST_F(SessionWorkflowWithDedicatedThreadsTest, ExhaustLoop) {
- auto runner = StepRunner(this);
-
- runner.expectNextState(SessionState::kSource, RequestKind::kExhaust);
- runner.expectNextState(SessionState::kProcess, RequestKind::kExhaust);
- runner.expectNextState(SessionState::kSink, RequestKind::kExhaust);
- runner.expectNextState(SessionState::kProcess, RequestKind::kDefault);
- runner.expectNextState(SessionState::kSink, RequestKind::kDefault);
- runner.expectFinalState(SessionState::kSource);
-
- runner.run();
+TEST_F(DedicatedThreadSessionWorkflowTest, ExhaustLoop) {
+ runSteps(exhaustLoop());
}
-TEST_F(SessionWorkflowWithDedicatedThreadsTest, MoreToComeLoop) {
- auto runner = StepRunner(this);
-
- runner.expectNextState(SessionState::kSource, RequestKind::kMoreToCome);
- runner.expectNextState(SessionState::kProcess, RequestKind::kMoreToCome);
- runner.expectNextState(SessionState::kSource, RequestKind::kDefault);
- runner.expectNextState(SessionState::kProcess, RequestKind::kDefault);
- runner.expectNextState(SessionState::kSink, RequestKind::kDefault);
- runner.expectFinalState(SessionState::kSource);
-
- runner.run();
+TEST_F(DedicatedThreadSessionWorkflowTest, MoreToComeLoop) {
+ runSteps(moreToComeLoop());
}
-TEST_F(SessionWorkflowWithBorrowedThreadsTest, DefaultLoop) {
- auto runner = StepRunner(this);
-
- runner.expectNextState(SessionState::kPoll, RequestKind::kDefault);
- runner.expectNextState(SessionState::kSource, RequestKind::kDefault);
- runner.expectNextState(SessionState::kProcess, RequestKind::kDefault);
- runner.expectNextState(SessionState::kSink, RequestKind::kDefault);
- runner.expectFinalState(SessionState::kPoll);
+class BorrowedThreadSessionWorkflowTest : public StepRunnerSessionWorkflowTest<false> {
+public:
+ /**
+ * Under the borrowed thread model, the steps are the same as for dedicated thread model,
+ * except that kSource events are preceded by kWaitForData events.
+ */
+ std::deque<StepRunner::Step> borrowedSteps(std::deque<StepRunner::Step> q) {
+ for (auto iter = q.begin(); iter != q.end(); ++iter) {
+ if (iter->event == Event::kSource) {
+ iter = q.insert(iter, {Event::kWaitForData});
+ ++iter;
+ }
+ }
+ return q;
+ }
+};
- runner.run();
+TEST_F(BorrowedThreadSessionWorkflowTest, DefaultLoop) {
+ runSteps(borrowedSteps(defaultLoop()));
}
-TEST_F(SessionWorkflowWithBorrowedThreadsTest, ExhaustLoop) {
- auto runner = StepRunner(this);
-
- runner.expectNextState(SessionState::kPoll, RequestKind::kExhaust);
- runner.expectNextState(SessionState::kSource, RequestKind::kExhaust);
- runner.expectNextState(SessionState::kProcess, RequestKind::kExhaust);
- runner.expectNextState(SessionState::kSink, RequestKind::kExhaust);
- runner.expectNextState(SessionState::kProcess, RequestKind::kDefault);
- runner.expectNextState(SessionState::kSink, RequestKind::kDefault);
- runner.expectFinalState(SessionState::kPoll);
-
- runner.run();
+TEST_F(BorrowedThreadSessionWorkflowTest, ExhaustLoop) {
+ runSteps(borrowedSteps(exhaustLoop()));
}
-TEST_F(SessionWorkflowWithBorrowedThreadsTest, MoreToComeLoop) {
- auto runner = StepRunner(this);
-
- runner.expectNextState(SessionState::kPoll, RequestKind::kMoreToCome);
- runner.expectNextState(SessionState::kSource, RequestKind::kMoreToCome);
- runner.expectNextState(SessionState::kProcess, RequestKind::kMoreToCome);
- runner.expectNextState(SessionState::kPoll, RequestKind::kDefault);
- runner.expectNextState(SessionState::kSource, RequestKind::kDefault);
- runner.expectNextState(SessionState::kProcess, RequestKind::kDefault);
- runner.expectNextState(SessionState::kSink, RequestKind::kDefault);
- runner.expectFinalState(SessionState::kPoll);
-
- runner.run();
+TEST_F(BorrowedThreadSessionWorkflowTest, MoreToComeLoop) {
+ runSteps(borrowedSteps(moreToComeLoop()));
}
} // namespace
-} // namespace transport
-} // namespace mongo
+} // namespace mongo::transport