22 files changed, 448 insertions, 31 deletions
diff --git a/jstests/libs/global_snapshot_reads_util.js b/jstests/libs/global_snapshot_reads_util.js new file mode 100644 index 00000000000..a83ddfbb340 --- /dev/null +++ b/jstests/libs/global_snapshot_reads_util.js @@ -0,0 +1,30 @@ +/** + * Tests invalid getMore attempts against an established global snapshot cursor on mongos. The + * cursor should still be valid and usable after each failed attempt. + */ +function verifyInvalidGetMoreAttempts(mainDb, sessionDb, collName, cursorId, txnNumber) { + // Reject getMores without a session. + assert.commandFailedWithCode( + mainDb.runCommand({getMore: cursorId, collection: collName, batchSize: 1}), 50800); + + // Subsequent getMore requests without the same session id are rejected. The cursor should + // still be valid and usable after this failed attempt. + assert.commandFailedWithCode(mainDb.runCommand({ + getMore: cursorId, + collection: collName, + batchSize: 1, + txnNumber: NumberLong(txnNumber), + lsid: {id: UUID()} + }), + 50801); + + // Reject getMores without without txnNumber. + assert.commandFailedWithCode( + sessionDb.runCommand({getMore: cursorId, collection: collName, batchSize: 1}), 50803); + + // Reject getMores without without same txnNumber. + assert.commandFailedWithCode( + sessionDb.runCommand( + {getMore: cursorId, collection: collName, batchSize: 1, txnNumber: NumberLong(50)}), + 50804); +} diff --git a/jstests/sharding/kill_pinned_cursor.js b/jstests/sharding/kill_pinned_cursor.js index f636f981082..417ed66817e 100644 --- a/jstests/sharding/kill_pinned_cursor.js +++ b/jstests/sharding/kill_pinned_cursor.js @@ -32,28 +32,29 @@ // string. This means that we can't pass it functions which capture variables. Instead we use // the trick below, by putting the values for the variables we'd like to capture inside the // string. Kudos to Dave Storch for coming up with this idea. - function makeParallelShellFunctionString(cursorId, getMoreErrCode, useSession) { + function makeParallelShellFunctionString(cursorId, getMoreErrCodes, useSession, sessionId) { let code = `const cursorId = ${cursorId.toString()};`; code += `const kDBName = "${kDBName}";`; code += `let collName = "${coll.getName()}";`; - code += `let getMoreErrCode = ${getMoreErrCode};`; code += `const useSession = ${useSession};`; + TestData.getMoreErrCodes = getMoreErrCodes; + if (useSession) { + TestData.sessionId = sessionId; + } + const runGetMore = function() { - let dbToUse = db; - let session = null; + let getMoreCmd = {getMore: cursorId, collection: collName, batchSize: 4}; + if (useSession) { - session = db.getMongo().startSession(); - dbToUse = session.getDatabase(kDBName); + getMoreCmd.lsid = TestData.sessionId; } - let response = - dbToUse.runCommand({getMore: cursorId, collection: collName, batchSize: 4}); // We expect that the operation will get interrupted and fail. - assert.commandFailedWithCode(response, getMoreErrCode); + assert.commandFailedWithCode(db.runCommand(getMoreCmd), TestData.getMoreErrCodes); - if (session) { - session.endSession(); + if (useSession) { + assert.commandWorked(db.adminCommand({endSessions: [TestData.sessionId]})); } }; @@ -67,13 +68,14 @@ // cursor to hang due to getMore commands hanging on each of the shards. Then invokes // 'killFunc', and verifies the cursors on the shards and the mongos cursor get cleaned up. // - // 'getMoreErrCode' is the error code with which we expect the getMore to fail (e.g. a + // 'getMoreErrCodes' are the error codes with which we expect the getMore to fail (e.g. 
a // killCursors command should cause getMore to fail with "CursorKilled", but killOp should cause // a getMore to fail with "Interrupted"). function testShardedKillPinned( - {killFunc: killFunc, getMoreErrCode: getMoreErrCode, useSession: useSession}) { + {killFunc: killFunc, getMoreErrCodes: getMoreErrCodes, useSession: useSession}) { let getMoreJoiner = null; let cursorId; + let sessionId; try { // Set up the mongods to hang on a getMore request. ONLY set the failpoint on the @@ -85,13 +87,21 @@ {configureFailPoint: kFailPointName, mode: "alwaysOn", data: kFailpointOptions})); // Run a find against mongos. This should open cursors on both of the shards. - let cmdRes = mongosDB.runCommand({find: coll.getName(), batchSize: 2}); + let findCmd = {find: coll.getName(), batchSize: 2}; + + if (useSession) { + // Manually start a session so it can be continued from inside a parallel shell. + sessionId = assert.commandWorked(mongosDB.adminCommand({startSession: 1})).id; + findCmd.lsid = sessionId; + } + + let cmdRes = mongosDB.runCommand(findCmd); assert.commandWorked(cmdRes); cursorId = cmdRes.cursor.id; assert.neq(cursorId, NumberLong(0)); const parallelShellFn = - makeParallelShellFunctionString(cursorId, getMoreErrCode, useSession); + makeParallelShellFunctionString(cursorId, getMoreErrCodes, useSession, sessionId); getMoreJoiner = startParallelShell(parallelShellFn, st.s.port); // Sleep until we know the mongod cursors are pinned. @@ -147,7 +157,7 @@ assert.eq(cmdRes.cursorsNotFound, []); assert.eq(cmdRes.cursorsUnknown, []); }, - getMoreErrCode: ErrorCodes.CursorKilled, + getMoreErrCodes: ErrorCodes.CursorKilled, useSession: useSession }); @@ -167,7 +177,7 @@ let killOpResult = shard0DB.killOp(currentGetMore.opid); assert.commandWorked(killOpResult); }, - getMoreErrCode: ErrorCodes.Interrupted, + getMoreErrCodes: ErrorCodes.Interrupted, useSession: useSession }); @@ -193,7 +203,7 @@ assert.eq(cmdRes.cursorsNotFound, []); assert.eq(cmdRes.cursorsUnknown, []); }, - getMoreErrCode: ErrorCodes.CursorKilled, + getMoreErrCodes: ErrorCodes.CursorKilled, useSession: useSession }); } @@ -216,7 +226,10 @@ const sessionUUID = localSessions[0]._id.id; assert.commandWorked(mongosDB.runCommand({killSessions: [{id: sessionUUID}]})); }, - getMoreErrCode: ErrorCodes.Interrupted, + // Killing a session on mongos kills all matching remote cursors (through KillCursors) then + // all matching local operations (through KillOp), so the getMore can fail with either + // CursorKilled or Interrupted depending on which response is returned first. + getMoreErrCodes: [ErrorCodes.CursorKilled, ErrorCodes.Interrupted], useSession: true, }); diff --git a/jstests/sharding/snapshot_aggregate_mongos.js b/jstests/sharding/snapshot_aggregate_mongos.js index 36a2a529c7b..b0480fc534c 100644 --- a/jstests/sharding/snapshot_aggregate_mongos.js +++ b/jstests/sharding/snapshot_aggregate_mongos.js @@ -3,6 +3,8 @@ (function() { "use strict"; + load("jstests/libs/global_snapshot_reads_util.js"); + const dbName = "test"; const shardedCollName = "shardedColl"; const unshardedCollName = "unshardedColl"; @@ -56,6 +58,8 @@ // performed outside of the session. assert.writeOK(mainDb[collName].insert({_id: 10}, {writeConcern: {w: "majority"}})); + verifyInvalidGetMoreAttempts(mainDb, sessionDb, collName, cursorId, txnNumber); + // Fetch the 6th document. This confirms that the transaction stash is preserved across // multiple getMore invocations. 
res = assert.commandWorked(sessionDb.runCommand({ diff --git a/jstests/sharding/snapshot_find_mongos.js b/jstests/sharding/snapshot_find_mongos.js index 8687654ac39..92b5b679730 100644 --- a/jstests/sharding/snapshot_find_mongos.js +++ b/jstests/sharding/snapshot_find_mongos.js @@ -4,6 +4,8 @@ (function() { "use strict"; + load("jstests/libs/global_snapshot_reads_util.js"); + const dbName = "test"; const shardedCollName = "shardedColl"; const unshardedCollName = "unshardedColl"; @@ -57,6 +59,8 @@ // performed outside of the session. assert.writeOK(mainDb[collName].insert({_id: 10}, {writeConcern: {w: "majority"}})); + verifyInvalidGetMoreAttempts(mainDb, sessionDb, collName, cursorId, txnNumber); + // Fetch the 6th document. This confirms that the transaction stash is preserved across // multiple getMore invocations. res = assert.commandWorked(sessionDb.runCommand({ diff --git a/src/mongo/db/operation_context_test.cpp b/src/mongo/db/operation_context_test.cpp index ac14497fa25..751aee986cc 100644 --- a/src/mongo/db/operation_context_test.cpp +++ b/src/mongo/db/operation_context_test.cpp @@ -41,6 +41,7 @@ #include "mongo/stdx/memory.h" #include "mongo/stdx/thread.h" #include "mongo/unittest/barrier.h" +#include "mongo/unittest/death_test.h" #include "mongo/unittest/unittest.h" #include "mongo/util/clock_source_mock.h" #include "mongo/util/tick_source_mock.h" @@ -111,6 +112,34 @@ TEST(OperationContextTest, SessionIdAndTransactionNumber) { ASSERT_EQUALS(5, *opCtx->getTxnNumber()); } +DEATH_TEST(OperationContextTest, SettingSessionIdMoreThanOnceShouldCrash, "invariant") { + auto serviceCtx = stdx::make_unique<ServiceContextNoop>(); + auto client = serviceCtx->makeClient("OperationContextTest"); + auto opCtx = client->makeOperationContext(); + + opCtx->setLogicalSessionId(makeLogicalSessionIdForTest()); + opCtx->setLogicalSessionId(makeLogicalSessionIdForTest()); +} + +DEATH_TEST(OperationContextTest, SettingTransactionNumberMoreThanOnceShouldCrash, "invariant") { + auto serviceCtx = stdx::make_unique<ServiceContextNoop>(); + auto client = serviceCtx->makeClient("OperationContextTest"); + auto opCtx = client->makeOperationContext(); + + opCtx->setLogicalSessionId(makeLogicalSessionIdForTest()); + + opCtx->setTxnNumber(5); + opCtx->setTxnNumber(5); +} + +DEATH_TEST(OperationContextTest, SettingTransactionNumberWithoutSessionIdShouldCrash, "invariant") { + auto serviceCtx = stdx::make_unique<ServiceContextNoop>(); + auto client = serviceCtx->makeClient("OperationContextTest"); + auto opCtx = client->makeOperationContext(); + + opCtx->setTxnNumber(5); +} + TEST(OperationContextTest, OpCtxGroup) { OperationContextGroup group1; ASSERT_TRUE(group1.isEmpty()); diff --git a/src/mongo/s/commands/cluster_aggregate.cpp b/src/mongo/s/commands/cluster_aggregate.cpp index c0643205001..46cb784032a 100644 --- a/src/mongo/s/commands/cluster_aggregate.cpp +++ b/src/mongo/s/commands/cluster_aggregate.cpp @@ -551,6 +551,8 @@ BSONObj establishMergingMongosCursor(OperationContext* opCtx, params.batchSize = request.getBatchSize() == 0 ? 
boost::none : boost::optional<long long>(request.getBatchSize()); + params.lsid = opCtx->getLogicalSessionId(); + params.txnNumber = opCtx->getTxnNumber(); if (liteParsedPipeline.hasChangeStream()) { // For change streams, we need to set up a custom stage to establish cursors on new shards diff --git a/src/mongo/s/query/async_results_merger.cpp b/src/mongo/s/query/async_results_merger.cpp index 45403399e8c..3e41a36c089 100644 --- a/src/mongo/s/query/async_results_merger.cpp +++ b/src/mongo/s/query/async_results_merger.cpp @@ -94,6 +94,10 @@ AsyncResultsMerger::AsyncResultsMerger(OperationContext* opCtx, _mergeQueue(MergingComparator(_remotes, _params.getSort() ? *_params.getSort() : BSONObj(), _params.getCompareWholeSortKey())) { + if (params.getTxnNumber()) { + invariant(params.getSessionId()); + } + size_t remoteIndex = 0; for (const auto& remote : _params.getRemotes()) { _remotes.emplace_back(remote.getHostAndPort(), @@ -352,6 +356,20 @@ Status AsyncResultsMerger::_askForNextBatch(WithLock, size_t remoteIndex) { boost::none) .toBSON(); + if (_params.getSessionId()) { + BSONObjBuilder newCmdBob(std::move(cmdObj)); + + BSONObjBuilder lsidBob(newCmdBob.subobjStart(OperationSessionInfo::kSessionIdFieldName)); + _params.getSessionId()->serialize(&lsidBob); + lsidBob.doneFast(); + + if (_params.getTxnNumber()) { + newCmdBob.append(OperationSessionInfo::kTxnNumberFieldName, *_params.getTxnNumber()); + } + + cmdObj = newCmdBob.obj(); + } + executor::RemoteCommandRequest request( remote.getTargetHost(), _params.getNss().db().toString(), cmdObj, _opCtx); diff --git a/src/mongo/s/query/async_results_merger_params.idl b/src/mongo/s/query/async_results_merger_params.idl index dafc9b53c1c..2d85faf730f 100644 --- a/src/mongo/s/query/async_results_merger_params.idl +++ b/src/mongo/s/query/async_results_merger_params.idl @@ -32,6 +32,7 @@ global: - "mongo/db/query/cursor_response.h" imports: + - "mongo/db/logical_session_id.idl" - "mongo/db/query/tailable_mode.idl" - "mongo/idl/basic_types.idl" - "mongo/util/net/hostandport.idl" @@ -61,6 +62,8 @@ structs: AsyncResultsMergerParams: description: The parameters needed to establish an AsyncResultsMerger. 
+ chained_structs: + OperationSessionInfo : OperationSessionInfo fields: sort: type: object diff --git a/src/mongo/s/query/async_results_merger_test.cpp b/src/mongo/s/query/async_results_merger_test.cpp index 6fd81715e90..586b387ad96 100644 --- a/src/mongo/s/query/async_results_merger_test.cpp +++ b/src/mongo/s/query/async_results_merger_test.cpp @@ -43,6 +43,7 @@ #include "mongo/s/client/shard_registry.h" #include "mongo/s/sharding_router_test_fixture.h" #include "mongo/stdx/memory.h" +#include "mongo/unittest/death_test.h" #include "mongo/unittest/unittest.h" namespace mongo { @@ -64,6 +65,10 @@ const std::vector<HostAndPort> kTestShardHosts = {HostAndPort("FakeShard1Host", const NamespaceString kTestNss("testdb.testcoll"); +LogicalSessionId parseSessionIdFromCmd(BSONObj cmdObj) { + return LogicalSessionId::parse(IDLParserErrorContext("lsid"), cmdObj["lsid"].Obj()); +} + class AsyncResultsMergerTest : public ShardingTestFixture { public: AsyncResultsMergerTest() {} @@ -137,6 +142,11 @@ protected: params.setAllowPartialResults(qr->isAllowPartialResults()); } + OperationSessionInfo sessionInfo; + sessionInfo.setSessionId(operationContext()->getLogicalSessionId()); + sessionInfo.setTxnNumber(operationContext()->getTxnNumber()); + params.setOperationSessionInfo(sessionInfo); + return stdx::make_unique<AsyncResultsMerger>( operationContext(), executor(), std::move(params)); } @@ -2004,5 +2014,117 @@ TEST_F(AsyncResultsMergerTest, ShouldBeAbleToBlockUntilKilled) { arm->blockingKill(operationContext()); } +TEST_F(AsyncResultsMergerTest, GetMoresShouldNotIncludeLSIDOrTxnNumberIfNoneSpecified) { + std::vector<RemoteCursor> cursors; + cursors.emplace_back( + makeRemoteCursor(kTestShardIds[0], kTestShardHosts[0], CursorResponse(kTestNss, 1, {}))); + auto arm = makeARMFromExistingCursors(std::move(cursors)); + + // There should be no lsid txnNumber in the scheduled getMore. + ASSERT_OK(arm->nextEvent().getStatus()); + onCommand([&](const auto& request) { + ASSERT(request.cmdObj["getMore"]); + + ASSERT(request.cmdObj["lsid"].eoo()); + ASSERT(request.cmdObj["txnNumber"].eoo()); + + return CursorResponse(kTestNss, 0LL, {BSON("x" << 1)}) + .toBSON(CursorResponse::ResponseType::SubsequentResponse); + }); +} + +TEST_F(AsyncResultsMergerTest, GetMoresShouldIncludeLSIDIfSpecified) { + auto lsid = makeLogicalSessionIdForTest(); + operationContext()->setLogicalSessionId(lsid); + + std::vector<RemoteCursor> cursors; + cursors.emplace_back( + makeRemoteCursor(kTestShardIds[0], kTestShardHosts[0], CursorResponse(kTestNss, 1, {}))); + auto arm = makeARMFromExistingCursors(std::move(cursors)); + + // There should be an lsid and no txnNumber in the scheduled getMore. + ASSERT_OK(arm->nextEvent().getStatus()); + onCommand([&](const auto& request) { + ASSERT(request.cmdObj["getMore"]); + + ASSERT_EQ(parseSessionIdFromCmd(request.cmdObj), lsid); + ASSERT(request.cmdObj["txnNumber"].eoo()); + + return CursorResponse(kTestNss, 1LL, {BSON("x" << 1)}) + .toBSON(CursorResponse::ResponseType::SubsequentResponse); + }); + + // Subsequent requests still pass the lsid. 
+ ASSERT(arm->ready()); + ASSERT_OK(arm->nextReady().getStatus()); + ASSERT_FALSE(arm->ready()); + + ASSERT_OK(arm->nextEvent().getStatus()); + onCommand([&](const auto& request) { + ASSERT(request.cmdObj["getMore"]); + + ASSERT_EQ(parseSessionIdFromCmd(request.cmdObj), lsid); + ASSERT(request.cmdObj["txnNumber"].eoo()); + + return CursorResponse(kTestNss, 0LL, {BSON("x" << 1)}) + .toBSON(CursorResponse::ResponseType::SubsequentResponse); + }); +} + +TEST_F(AsyncResultsMergerTest, GetMoresShouldIncludeLSIDAndTxnNumIfSpecified) { + auto lsid = makeLogicalSessionIdForTest(); + operationContext()->setLogicalSessionId(lsid); + + const TxnNumber txnNumber = 5; + operationContext()->setTxnNumber(txnNumber); + + std::vector<RemoteCursor> cursors; + cursors.emplace_back( + makeRemoteCursor(kTestShardIds[0], kTestShardHosts[0], CursorResponse(kTestNss, 1, {}))); + auto arm = makeARMFromExistingCursors(std::move(cursors)); + + // The first scheduled getMore should pass the txnNumber the ARM was constructed with. + ASSERT_OK(arm->nextEvent().getStatus()); + onCommand([&](const auto& request) { + ASSERT(request.cmdObj["getMore"]); + + ASSERT_EQ(parseSessionIdFromCmd(request.cmdObj), lsid); + ASSERT_EQ(request.cmdObj["txnNumber"].numberLong(), txnNumber); + + return CursorResponse(kTestNss, 1LL, {BSON("x" << 1)}) + .toBSON(CursorResponse::ResponseType::SubsequentResponse); + }); + + // Subsequent requests still pass the txnNumber. + ASSERT(arm->ready()); + ASSERT_OK(arm->nextReady().getStatus()); + ASSERT_FALSE(arm->ready()); + + // Subsequent getMore requests should include txnNumber. + ASSERT_OK(arm->nextEvent().getStatus()); + onCommand([&](const auto& request) { + ASSERT(request.cmdObj["getMore"]); + + ASSERT_EQ(parseSessionIdFromCmd(request.cmdObj), lsid); + ASSERT_EQ(request.cmdObj["txnNumber"].numberLong(), txnNumber); + + return CursorResponse(kTestNss, 0LL, {BSON("x" << 1)}) + .toBSON(CursorResponse::ResponseType::SubsequentResponse); + }); +} + +DEATH_TEST_F(AsyncResultsMergerTest, + ConstructingARMWithTxnNumAndNoLSIDShouldCrash, + "Invariant failure params.getSessionId()") { + AsyncResultsMergerParams params; + + OperationSessionInfo sessionInfo; + sessionInfo.setTxnNumber(5); + params.setOperationSessionInfo(sessionInfo); + + // This should trigger an invariant. + stdx::make_unique<AsyncResultsMerger>(operationContext(), executor(), std::move(params)); +} + } // namespace } // namespace mongo diff --git a/src/mongo/s/query/cluster_client_cursor.h b/src/mongo/s/query/cluster_client_cursor.h index 1afb5ae1b38..5ad4ec4298f 100644 --- a/src/mongo/s/query/cluster_client_cursor.h +++ b/src/mongo/s/query/cluster_client_cursor.h @@ -151,6 +151,11 @@ public: virtual boost::optional<LogicalSessionId> getLsid() const = 0; /** + * Returns the transaction number for this cursor. + */ + virtual boost::optional<TxnNumber> getTxnNumber() const = 0; + + /** * Returns the readPreference for this cursor. 
*/ virtual boost::optional<ReadPreferenceSetting> getReadPreference() const = 0; diff --git a/src/mongo/s/query/cluster_client_cursor_impl.cpp b/src/mongo/s/query/cluster_client_cursor_impl.cpp index e5348f3d86f..1b3a665df5e 100644 --- a/src/mongo/s/query/cluster_client_cursor_impl.cpp +++ b/src/mongo/s/query/cluster_client_cursor_impl.cpp @@ -175,6 +175,10 @@ boost::optional<LogicalSessionId> ClusterClientCursorImpl::getLsid() const { return _lsid; } +boost::optional<TxnNumber> ClusterClientCursorImpl::getTxnNumber() const { + return _params.txnNumber; +} + boost::optional<ReadPreferenceSetting> ClusterClientCursorImpl::getReadPreference() const { return _params.readPreference; } diff --git a/src/mongo/s/query/cluster_client_cursor_impl.h b/src/mongo/s/query/cluster_client_cursor_impl.h index 34fa7d16c61..36f9d3995c8 100644 --- a/src/mongo/s/query/cluster_client_cursor_impl.h +++ b/src/mongo/s/query/cluster_client_cursor_impl.h @@ -117,6 +117,8 @@ public: boost::optional<LogicalSessionId> getLsid() const final; + boost::optional<TxnNumber> getTxnNumber() const final; + boost::optional<ReadPreferenceSetting> getReadPreference() const final; public: diff --git a/src/mongo/s/query/cluster_client_cursor_impl_test.cpp b/src/mongo/s/query/cluster_client_cursor_impl_test.cpp index 2db563fdd73..c73b33a68e4 100644 --- a/src/mongo/s/query/cluster_client_cursor_impl_test.cpp +++ b/src/mongo/s/query/cluster_client_cursor_impl_test.cpp @@ -222,6 +222,47 @@ TEST_F(ClusterClientCursorImplTest, LogicalSessionIdsOnCursors) { ASSERT(*(cursor2.getLsid()) == lsid); } +TEST_F(ClusterClientCursorImplTest, ShouldStoreLSIDIfSetOnOpCtx) { + { + // Make a cursor with no lsid or txnNumber. + ClusterClientCursorParams params(NamespaceString("test"), {}); + params.lsid = _opCtx->getLogicalSessionId(); + params.txnNumber = _opCtx->getTxnNumber(); + + auto cursor = ClusterClientCursorImpl::make(_opCtx.get(), nullptr, std::move(params)); + ASSERT_FALSE(cursor->getLsid()); + ASSERT_FALSE(cursor->getTxnNumber()); + } + + const auto lsid = makeLogicalSessionIdForTest(); + _opCtx->setLogicalSessionId(lsid); + + { + // Make a cursor with an lsid and no txnNumber. + ClusterClientCursorParams params(NamespaceString("test"), {}); + params.lsid = _opCtx->getLogicalSessionId(); + params.txnNumber = _opCtx->getTxnNumber(); + + auto cursor = ClusterClientCursorImpl::make(_opCtx.get(), nullptr, std::move(params)); + ASSERT_EQ(*cursor->getLsid(), lsid); + ASSERT_FALSE(cursor->getTxnNumber()); + } + + const TxnNumber txnNumber = 5; + _opCtx->setTxnNumber(txnNumber); + + { + // Make a cursor with an lsid and txnNumber. 
+ ClusterClientCursorParams params(NamespaceString("test"), {}); + params.lsid = _opCtx->getLogicalSessionId(); + params.txnNumber = _opCtx->getTxnNumber(); + + auto cursor = ClusterClientCursorImpl::make(_opCtx.get(), nullptr, std::move(params)); + ASSERT_EQ(*cursor->getLsid(), lsid); + ASSERT_EQ(*cursor->getTxnNumber(), txnNumber); + } +} + } // namespace } // namespace mongo diff --git a/src/mongo/s/query/cluster_client_cursor_mock.cpp b/src/mongo/s/query/cluster_client_cursor_mock.cpp index b248b35f77e..616d48d275c 100644 --- a/src/mongo/s/query/cluster_client_cursor_mock.cpp +++ b/src/mongo/s/query/cluster_client_cursor_mock.cpp @@ -37,8 +37,9 @@ namespace mongo { ClusterClientCursorMock::ClusterClientCursorMock(boost::optional<LogicalSessionId> lsid, + boost::optional<TxnNumber> txnNumber, stdx::function<void(void)> killCallback) - : _killCallback(std::move(killCallback)), _lsid(lsid) {} + : _killCallback(std::move(killCallback)), _lsid(lsid), _txnNumber(txnNumber) {} ClusterClientCursorMock::~ClusterClientCursorMock() { invariant((_exhausted && _remotesExhausted) || _killed); @@ -115,6 +116,10 @@ boost::optional<LogicalSessionId> ClusterClientCursorMock::getLsid() const { return _lsid; } +boost::optional<TxnNumber> ClusterClientCursorMock::getTxnNumber() const { + return _txnNumber; +} + boost::optional<ReadPreferenceSetting> ClusterClientCursorMock::getReadPreference() const { return boost::none; } diff --git a/src/mongo/s/query/cluster_client_cursor_mock.h b/src/mongo/s/query/cluster_client_cursor_mock.h index beb49735a8d..f5b1464b94b 100644 --- a/src/mongo/s/query/cluster_client_cursor_mock.h +++ b/src/mongo/s/query/cluster_client_cursor_mock.h @@ -43,6 +43,7 @@ class ClusterClientCursorMock final : public ClusterClientCursor { public: ClusterClientCursorMock(boost::optional<LogicalSessionId> lsid, + boost::optional<TxnNumber> txnNumber, stdx::function<void(void)> killCallback = stdx::function<void(void)>()); ~ClusterClientCursorMock(); @@ -79,6 +80,8 @@ public: boost::optional<LogicalSessionId> getLsid() const final; + boost::optional<TxnNumber> getTxnNumber() const final; + boost::optional<ReadPreferenceSetting> getReadPreference() const final; /** @@ -110,6 +113,8 @@ private: boost::optional<LogicalSessionId> _lsid; + boost::optional<TxnNumber> _txnNumber; + OperationContext* _opCtx = nullptr; }; diff --git a/src/mongo/s/query/cluster_client_cursor_params.h b/src/mongo/s/query/cluster_client_cursor_params.h index 71a7f17c282..c2d300ee19e 100644 --- a/src/mongo/s/query/cluster_client_cursor_params.h +++ b/src/mongo/s/query/cluster_client_cursor_params.h @@ -85,6 +85,12 @@ struct ClusterClientCursorParams { armParams.setBatchSize(batchSize); armParams.setNss(nsString); armParams.setAllowPartialResults(isAllowPartialResults); + + OperationSessionInfo sessionInfo; + sessionInfo.setSessionId(lsid); + sessionInfo.setTxnNumber(txnNumber); + armParams.setOperationSessionInfo(sessionInfo); + return armParams; } @@ -136,6 +142,12 @@ struct ClusterClientCursorParams { // Whether the client indicated that it is willing to receive partial results in the case of an // unreachable host. bool isAllowPartialResults = false; + + // The logical session id of the command that created the cursor. + boost::optional<LogicalSessionId> lsid; + + // The transaction number of the command that created the cursor. 
+ boost::optional<TxnNumber> txnNumber; }; } // mongo diff --git a/src/mongo/s/query/cluster_cursor_manager.cpp b/src/mongo/s/query/cluster_cursor_manager.cpp index 44cd6333f19..99e6d9761c0 100644 --- a/src/mongo/s/query/cluster_cursor_manager.cpp +++ b/src/mongo/s/query/cluster_cursor_manager.cpp @@ -183,6 +183,16 @@ void ClusterCursorManager::PinnedCursor::returnAndKillCursor() { returnCursor(CursorState::Exhausted); } +boost::optional<LogicalSessionId> ClusterCursorManager::PinnedCursor::getLsid() const { + invariant(_cursor); + return _cursor->getLsid(); +} + +boost::optional<TxnNumber> ClusterCursorManager::PinnedCursor::getTxnNumber() const { + invariant(_cursor); + return _cursor->getTxnNumber(); +} + ClusterCursorManager::ClusterCursorManager(ClockSource* clockSource) : _clockSource(clockSource), _pseudoRandom(std::unique_ptr<SecureRandom>(SecureRandom::create())->nextInt64()) { diff --git a/src/mongo/s/query/cluster_cursor_manager.h b/src/mongo/s/query/cluster_cursor_manager.h index e8ab24cfa36..c255a2177e5 100644 --- a/src/mongo/s/query/cluster_cursor_manager.h +++ b/src/mongo/s/query/cluster_cursor_manager.h @@ -236,6 +236,16 @@ public: */ Status setAwaitDataTimeout(Milliseconds awaitDataTimeout); + /** + * Returns the logical session id of the command that created the underlying cursor. + */ + boost::optional<LogicalSessionId> getLsid() const; + + /** + * Returns the transaction number of the command that created the underlying cursor. + */ + boost::optional<TxnNumber> getTxnNumber() const; + Microseconds getLeftoverMaxTimeMicros() const { invariant(_cursor); return _cursor->getLeftoverMaxTimeMicros(); diff --git a/src/mongo/s/query/cluster_cursor_manager_test.cpp b/src/mongo/s/query/cluster_cursor_manager_test.cpp index b9ea435cc99..b5a2832cb1e 100644 --- a/src/mongo/s/query/cluster_cursor_manager_test.cpp +++ b/src/mongo/s/query/cluster_cursor_manager_test.cpp @@ -82,7 +82,8 @@ protected: * Allocates a mock cursor, which can be used with the 'isMockCursorKilled' method below. */ std::unique_ptr<ClusterClientCursorMock> allocateMockCursor( - boost::optional<LogicalSessionId> lsid = boost::none) { + boost::optional<LogicalSessionId> lsid = boost::none, + boost::optional<TxnNumber> txnNumber = boost::none) { // Allocate a new boolean to our list to track when this cursor is killed. _cursorKilledFlags.push_back(false); @@ -91,8 +92,8 @@ protected: // (std::list<>::push_back() does not invalidate references, and our list outlives the // manager). 
bool& killedFlag = _cursorKilledFlags.back(); - return stdx::make_unique<ClusterClientCursorMock>(std::move(lsid), - [&killedFlag]() { killedFlag = true; }); + return stdx::make_unique<ClusterClientCursorMock>( + std::move(lsid), std::move(txnNumber), [&killedFlag]() { killedFlag = true; }); } /** @@ -1268,6 +1269,25 @@ TEST_F(ClusterCursorManagerTest, CheckAuthForKillCursors) { getManager()->checkAuthForKillCursors(_opCtx.get(), nss, cursorId, successAuthChecker)); } +TEST_F(ClusterCursorManagerTest, PinnedCursorReturnsUnderlyingCursorTxnNumber) { + const TxnNumber txnNumber = 5; + auto cursorId = assertGet( + getManager()->registerCursor(_opCtx.get(), + allocateMockCursor(makeLogicalSessionIdForTest(), txnNumber), + nss, + ClusterCursorManager::CursorType::SingleTarget, + ClusterCursorManager::CursorLifetime::Mortal, + UserNameIterator())); + + auto pinnedCursor = + getManager()->checkOutCursor(nss, cursorId, _opCtx.get(), successAuthChecker); + ASSERT_OK(pinnedCursor.getStatus()); + + // The underlying cursor's txnNumber should be returned. + ASSERT(pinnedCursor.getValue().getTxnNumber()); + ASSERT_EQ(txnNumber, *pinnedCursor.getValue().getTxnNumber()); +} + } // namespace } // namespace mongo diff --git a/src/mongo/s/query/cluster_find.cpp b/src/mongo/s/query/cluster_find.cpp index 8c9e654f5bb..b47cd757402 100644 --- a/src/mongo/s/query/cluster_find.cpp +++ b/src/mongo/s/query/cluster_find.cpp @@ -246,6 +246,8 @@ CursorId runQueryWithoutRetrying(OperationContext* opCtx, params.skip = query.getQueryRequest().getSkip(); params.tailableMode = query.getQueryRequest().getTailableMode(); params.isAllowPartialResults = query.getQueryRequest().isAllowPartialResults(); + params.lsid = opCtx->getLogicalSessionId(); + params.txnNumber = opCtx->getTxnNumber(); // This is the batchSize passed to each subsequent getMore command issued by the cursor. We // usually use the batchSize associated with the initial find, but as it is illegal to send a @@ -470,6 +472,86 @@ CursorId ClusterFind::runQuery(OperationContext* opCtx, MONGO_UNREACHABLE } +/** + * Validates that the lsid on the OperationContext matches that on the cursor, returning it to the + * ClusterClusterCursor manager if it does not. + */ +void validateLSID(OperationContext* opCtx, + const GetMoreRequest& request, + ClusterCursorManager::PinnedCursor* cursor) { + if (opCtx->getLogicalSessionId() && !cursor->getLsid()) { + uasserted(50799, + str::stream() << "Cannot run getMore on cursor " << request.cursorid + << ", which was not created in a session, in session " + << *opCtx->getLogicalSessionId()); + } + + if (!opCtx->getLogicalSessionId() && cursor->getLsid()) { + uasserted(50800, + str::stream() << "Cannot run getMore on cursor " << request.cursorid + << ", which was created in session " + << *cursor->getLsid() + << ", without an lsid"); + } + + if (opCtx->getLogicalSessionId() && cursor->getLsid() && + (*opCtx->getLogicalSessionId() != *cursor->getLsid())) { + uasserted(50801, + str::stream() << "Cannot run getMore on cursor " << request.cursorid + << ", which was created in session " + << *cursor->getLsid() + << ", in session " + << *opCtx->getLogicalSessionId()); + } +} + +/** + * Validates that the txnNumber on the OperationContext matches that on the cursor, returning it to + * the ClusterClusterCursor manager if it does not. 
+ */ +void validateTxnNumber(OperationContext* opCtx, + const GetMoreRequest& request, + ClusterCursorManager::PinnedCursor* cursor) { + if (opCtx->getTxnNumber() && !cursor->getTxnNumber()) { + uasserted(50802, + str::stream() << "Cannot run getMore on cursor " << request.cursorid + << ", which was not created in a transaction, in transaction " + << *opCtx->getTxnNumber()); + } + + if (!opCtx->getTxnNumber() && cursor->getTxnNumber()) { + uasserted(50803, + str::stream() << "Cannot run getMore on cursor " << request.cursorid + << ", which was created in transaction " + << *cursor->getTxnNumber() + << ", without a txnNumber"); + } + + if (opCtx->getTxnNumber() && cursor->getTxnNumber() && + (*opCtx->getTxnNumber() != *cursor->getTxnNumber())) { + uasserted(50804, + str::stream() << "Cannot run getMore on cursor " << request.cursorid + << ", which was created in transaction " + << *cursor->getTxnNumber() + << ", in transaction " + << *opCtx->getTxnNumber()); + } +} + +/** + * Validates that the OperationSessionInfo (i.e. txnNumber and lsid) on the OperationContext match + * that stored on the cursor. The cursor is returned to the ClusterCursorManager if it does not. + */ +void validateOperationSessionInfo(OperationContext* opCtx, + const GetMoreRequest& request, + ClusterCursorManager::PinnedCursor* cursor) { + ScopeGuard returnCursorGuard = MakeGuard( + [cursor] { cursor->returnCursor(ClusterCursorManager::CursorState::NotExhausted); }); + validateLSID(opCtx, request, cursor); + validateTxnNumber(opCtx, request, cursor); + returnCursorGuard.Dismiss(); +} + StatusWith<CursorResponse> ClusterFind::runGetMore(OperationContext* opCtx, const GetMoreRequest& request) { auto cursorManager = Grid::get(opCtx)->getCursorManager(); @@ -488,6 +570,8 @@ StatusWith<CursorResponse> ClusterFind::runGetMore(OperationContext* opCtx, } invariant(request.cursorid == pinnedCursor.getValue().getCursorId()); + validateOperationSessionInfo(opCtx, request, &pinnedCursor.getValue()); + // Set the originatingCommand object and the cursorID in CurOp. { CurOp::get(opCtx)->debug().nShards = pinnedCursor.getValue().getNumRemotes(); diff --git a/src/mongo/s/query/store_possible_cursor.cpp b/src/mongo/s/query/store_possible_cursor.cpp index 9a754364412..fa67241d1cd 100644 --- a/src/mongo/s/query/store_possible_cursor.cpp +++ b/src/mongo/s/query/store_possible_cursor.cpp @@ -81,6 +81,8 @@ StatusWith<BSONObj> storePossibleCursor(OperationContext* opCtx, {})); params.originatingCommandObj = CurOp::get(opCtx)->opDescription().getOwned(); params.tailableMode = tailableMode; + params.lsid = opCtx->getLogicalSessionId(); + params.txnNumber = opCtx->getTxnNumber(); auto ccc = ClusterClientCursorImpl::make(opCtx, executor, std::move(params)); diff --git a/src/mongo/s/sharding_task_executor.cpp b/src/mongo/s/sharding_task_executor.cpp index aef915f4782..ee15fdf3a6b 100644 --- a/src/mongo/s/sharding_task_executor.cpp +++ b/src/mongo/s/sharding_task_executor.cpp @@ -35,7 +35,6 @@ #include "mongo/base/disallow_copying.h" #include "mongo/base/status_with.h" #include "mongo/bson/timestamp.h" -#include "mongo/db/commands/test_commands_enabled.h" #include "mongo/db/logical_time.h" #include "mongo/db/operation_time_tracker.h" #include "mongo/executor/thread_pool_task_executor.h" @@ -131,13 +130,6 @@ StatusWith<TaskExecutor::CallbackHandle> ShardingTaskExecutor::scheduleRemoteCom request.opCtx->getLogicalSessionId()->serialize(&subbob); } - // TODO SERVER-33991. 
- if (getTestCommandsEnabled() && request.opCtx->getTxnNumber() && - request.cmdObj.hasField("getMore") && - !request.cmdObj.hasField(OperationSessionInfo::kTxnNumberFieldName)) { - bob.append(OperationSessionInfo::kTxnNumberFieldName, *(request.opCtx->getTxnNumber())); - } - newRequest->cmdObj = bob.obj(); } std::shared_ptr<OperationTimeTracker> timeTracker = OperationTimeTracker::get(request.opCtx);
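
As a rough illustration of what the mongos-side getMore validation in this patch enforces, the shell sketch below mirrors the verifyInvalidGetMoreAttempts helper added in jstests/libs/global_snapshot_reads_util.js. It is not part of the patch: it assumes a connection in `db` to a mongos built with these changes (the development series where global snapshot reads use a bare txnNumber in a session), and a collection `test.shardedColl` holding several documents; the collection name, batch sizes, and variable names are placeholders.

// Illustrative sketch only; assumes `db` is connected to a mongos from this
// development series and that test.shardedColl contains several documents.
const session = db.getMongo().startSession({causalConsistency: false});
const sessionDb = session.getDatabase("test");
const txnNumber = NumberLong(0);

// Open a snapshot cursor inside the session, as the snapshot_*_mongos.js tests do.
let res = assert.commandWorked(sessionDb.runCommand(
    {find: "shardedColl", batchSize: 2, readConcern: {level: "snapshot"}, txnNumber: txnNumber}));
const cursorId = res.cursor.id;

// getMore without any session is rejected (50800); with a different session (50801).
assert.commandFailedWithCode(
    db.runCommand({getMore: cursorId, collection: "shardedColl", batchSize: 1}), 50800);
assert.commandFailedWithCode(
    db.runCommand({getMore: cursorId, collection: "shardedColl", batchSize: 1,
                   txnNumber: txnNumber, lsid: {id: UUID()}}), 50801);

// getMore in the right session but without a txnNumber (50803), or with a
// different txnNumber (50804), is also rejected.
assert.commandFailedWithCode(
    sessionDb.runCommand({getMore: cursorId, collection: "shardedColl", batchSize: 1}), 50803);
assert.commandFailedWithCode(
    sessionDb.runCommand({getMore: cursorId, collection: "shardedColl", batchSize: 1,
                          txnNumber: NumberLong(50)}), 50804);

// The cursor survives each rejected attempt: a getMore that reuses the
// cursor's own session and txnNumber still succeeds.
assert.commandWorked(sessionDb.runCommand(
    {getMore: cursorId, collection: "shardedColl", batchSize: 1, txnNumber: txnNumber}));
session.endSession();

The final successful getMore reflects the property the helper relies on and the tests assert afterwards (fetching the next document), namely that a mismatched lsid or txnNumber only fails the request and does not invalidate the pinned cursor.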