From 7ce60b59b24a893bb6494089ea4db8e19901f48d Mon Sep 17 00:00:00 2001 From: kauboy26 Date: Wed, 17 May 2023 20:55:33 +0000 Subject: SERVER-72989 Attach stmtIds to bulkWrite request sent by mongos --- src/mongo/db/commands/bulk_write.cpp | 29 +--- src/mongo/db/commands/bulk_write_common.cpp | 12 ++ src/mongo/db/commands/bulk_write_common.h | 6 + src/mongo/s/SConscript | 1 + src/mongo/s/write_ops/bulk_write_exec.cpp | 20 ++- src/mongo/s/write_ops/bulk_write_exec_test.cpp | 212 ++++++++++++++++++++++++- 6 files changed, 250 insertions(+), 30 deletions(-) diff --git a/src/mongo/db/commands/bulk_write.cpp b/src/mongo/db/commands/bulk_write.cpp index c2d989ef9a7..2f283ec5ed0 100644 --- a/src/mongo/db/commands/bulk_write.cpp +++ b/src/mongo/db/commands/bulk_write.cpp @@ -375,24 +375,6 @@ void finishCurOp(OperationContext* opCtx, CurOp* curOp) { } } -int32_t getStatementId(OperationContext* opCtx, - const BulkWriteCommandRequest& req, - const size_t currentOpIdx) { - if (opCtx->isRetryableWrite()) { - auto stmtId = req.getStmtId(); - auto stmtIds = req.getStmtIds(); - - if (stmtIds) { - return stmtIds->at(currentOpIdx); - } - - const int32_t firstStmtId = stmtId ? *stmtId : 0; - return firstStmtId + currentOpIdx; - } - - return kUninitializedStmtId; -} - std::tuple> getRetryResultForDelete( OperationContext* opCtx, const NamespaceString& nsString, @@ -501,7 +483,8 @@ bool handleInsertOp(OperationContext* opCtx, const auto& nsInfo = req.getNsInfo(); auto idx = op->getInsert(); - auto stmtId = getStatementId(opCtx, req, currentOpIdx); + auto stmtId = opCtx->isRetryableWrite() ? bulk_write_common::getStatementId(req, currentOpIdx) + : kUninitializedStmtId; auto txnParticipant = TransactionParticipant::get(opCtx); @@ -582,7 +565,9 @@ bool handleUpdateOp(OperationContext* opCtx, doTransactionValidationForWrites(opCtx, nsString); - auto stmtId = getStatementId(opCtx, req, currentOpIdx); + auto stmtId = opCtx->isRetryableWrite() + ? 
bulk_write_common::getStatementId(req, currentOpIdx) + : kUninitializedStmtId; if (opCtx->isRetryableWrite()) { const auto txnParticipant = TransactionParticipant::get(opCtx); if (auto entry = txnParticipant.checkStatementExecuted(opCtx, stmtId)) { @@ -728,7 +713,9 @@ bool handleDeleteOp(OperationContext* opCtx, doTransactionValidationForWrites(opCtx, nsString); - auto stmtId = getStatementId(opCtx, req, currentOpIdx); + auto stmtId = opCtx->isRetryableWrite() + ? bulk_write_common::getStatementId(req, currentOpIdx) + : kUninitializedStmtId; if (opCtx->isRetryableWrite()) { const auto txnParticipant = TransactionParticipant::get(opCtx); // If 'return' is not specified then we do not need to parse the statement. Since diff --git a/src/mongo/db/commands/bulk_write_common.cpp b/src/mongo/db/commands/bulk_write_common.cpp index 691932dfc17..2b39688dc1b 100644 --- a/src/mongo/db/commands/bulk_write_common.cpp +++ b/src/mongo/db/commands/bulk_write_common.cpp @@ -129,5 +129,17 @@ std::vector getPrivileges(const BulkWriteCommandRequest& req) { return privileges; } +int32_t getStatementId(const BulkWriteCommandRequest& req, size_t currentOpIdx) { + auto stmtId = req.getStmtId(); + auto stmtIds = req.getStmtIds(); + + if (stmtIds) { + return stmtIds->at(currentOpIdx); + } + + int32_t firstStmtId = stmtId ? *stmtId : 0; + return firstStmtId + currentOpIdx; +} + } // namespace bulk_write_common } // namespace mongo diff --git a/src/mongo/db/commands/bulk_write_common.h b/src/mongo/db/commands/bulk_write_common.h index 754a644083f..ec79ea13464 100644 --- a/src/mongo/db/commands/bulk_write_common.h +++ b/src/mongo/db/commands/bulk_write_common.h @@ -49,5 +49,11 @@ void validateRequest(const BulkWriteCommandRequest& req, bool isRetryableWrite); */ std::vector getPrivileges(const BulkWriteCommandRequest& req); +/** + * Get the statement ID for an operation within a bulkWrite command, taking into consideration + * whether the stmtId / stmtIds fields are present on the request. 
+ */ +int32_t getStatementId(const BulkWriteCommandRequest& req, size_t currentOpIdx); + } // namespace bulk_write_common } // namespace mongo diff --git a/src/mongo/s/SConscript b/src/mongo/s/SConscript index 3ca8219b298..eb66d737ec7 100644 --- a/src/mongo/s/SConscript +++ b/src/mongo/s/SConscript @@ -39,6 +39,7 @@ env.Library( 'write_ops/write_without_shard_key_util.cpp', ], LIBDEPS=[ + '$BUILD_DIR/mongo/db/commands/bulk_write_common', '$BUILD_DIR/mongo/db/commands/server_status_core', '$BUILD_DIR/mongo/db/fle_crud', '$BUILD_DIR/mongo/db/not_primary_error_tracker', diff --git a/src/mongo/s/write_ops/bulk_write_exec.cpp b/src/mongo/s/write_ops/bulk_write_exec.cpp index 724b46b34ad..9fc4207c130 100644 --- a/src/mongo/s/write_ops/bulk_write_exec.cpp +++ b/src/mongo/s/write_ops/bulk_write_exec.cpp @@ -32,6 +32,7 @@ #include "mongo/base/error_codes.h" #include "mongo/client/read_preference.h" #include "mongo/client/remote_command_targeter.h" +#include "mongo/db/commands/bulk_write_common.h" #include "mongo/db/commands/bulk_write_gen.h" #include "mongo/db/commands/bulk_write_parser.h" #include "mongo/db/database_name.h" @@ -84,8 +85,6 @@ void executeChildBatches(OperationContext* opCtx, requests.emplace_back(childBatch.first, request); } - bool isRetryableWrite = opCtx->getTxnNumber() && !TransactionRouter::get(opCtx); - // Use MultiStatementTransactionRequestsSender to send any ready sub-batches to targeted // shard endpoints. Requests are sent on construction. MultiStatementTransactionRequestsSender ars( @@ -94,7 +93,7 @@ void executeChildBatches(OperationContext* opCtx, DatabaseName::kAdmin, requests, ReadPreferenceSetting(ReadPreference::PrimaryOnly), - isRetryableWrite ? Shard::RetryPolicy::kIdempotent : Shard::RetryPolicy::kNoRetry); + opCtx->isRetryableWrite() ? Shard::RetryPolicy::kIdempotent : Shard::RetryPolicy::kNoRetry); while (!ars.done()) { // Block until a response is available. 
@@ -303,7 +302,11 @@ BulkWriteCommandRequest BulkWriteOp::buildBulkCommandRequest( ops; std::vector nsInfo = _clientRequest.getNsInfo(); - for (auto&& targetedWrite : targetedBatch.getWrites()) { + std::vector stmtIds; + if (_isRetryableWrite) + stmtIds.reserve(targetedBatch.getNumOps()); + + for (const auto& targetedWrite : targetedBatch.getWrites()) { const WriteOpRef& writeOpRef = targetedWrite->writeOpRef; ops.push_back(_clientRequest.getOps().at(writeOpRef.first)); @@ -325,6 +328,10 @@ BulkWriteCommandRequest BulkWriteOp::buildBulkCommandRequest( nsInfoEntry.setShardVersion(targetedWrite->endpoint.shardVersion); nsInfoEntry.setDatabaseVersion(targetedWrite->endpoint.databaseVersion); + + if (_isRetryableWrite) { + stmtIds.push_back(bulk_write_common::getStatementId(_clientRequest, writeOpRef.first)); + } } request.setOps(ops); @@ -336,8 +343,9 @@ BulkWriteCommandRequest BulkWriteOp::buildBulkCommandRequest( request.setOrdered(_clientRequest.getOrdered()); request.setBypassDocumentValidation(_clientRequest.getBypassDocumentValidation()); - // TODO (SERVER-72989): Attach stmtIds etc. when building support for retryable - // writes on mongos + if (_isRetryableWrite) { + request.setStmtIds(stmtIds); + } request.setDbName(DatabaseName::kAdmin); diff --git a/src/mongo/s/write_ops/bulk_write_exec_test.cpp b/src/mongo/s/write_ops/bulk_write_exec_test.cpp index 5dcd2ff15e7..c1416d7976c 100644 --- a/src/mongo/s/write_ops/bulk_write_exec_test.cpp +++ b/src/mongo/s/write_ops/bulk_write_exec_test.cpp @@ -813,9 +813,8 @@ TEST_F(BulkWriteOpTest, BuildChildRequestFromTargetedWriteBatch) { NamespaceString nss1("sonate.pacifique"); // Two different endpoints targeting the same shard for the two namespaces. 
- ShardEndpoint endpoint0(ShardId("shard"), - ShardVersionFactory::make(ChunkVersion::IGNORED(), boost::none), - boost::none); + ShardEndpoint endpoint0( + shardId, ShardVersionFactory::make(ChunkVersion::IGNORED(), boost::none), boost::none); ShardEndpoint endpoint1( shardId, ShardVersionFactory::make(ChunkVersion({OID::gen(), Timestamp(2)}, {10, 11}), @@ -873,6 +872,213 @@ TEST_F(BulkWriteOpTest, BuildChildRequestFromTargetedWriteBatch) { ASSERT_EQUALS(childRequest.getNsInfo()[1].getNs(), request.getNsInfo()[1].getNs()); } +// Tests that stmtIds are correctly attached to bulkWrite requests when the operations +// are ordered. +TEST_F(BulkWriteOpTest, TestOrderedOpsNoExistingStmtIds) { + NamespaceString nss("mgmt.kids"); + + ShardEndpoint endpointA(ShardId("shardA"), + ShardVersionFactory::make(ChunkVersion::IGNORED(), boost::none), + boost::none); + ShardEndpoint endpointB(ShardId("shardB"), + ShardVersionFactory::make(ChunkVersion::IGNORED(), boost::none), + boost::none); + + std::vector> targeters; + targeters.push_back(initTargeterSplitRange(nss, endpointA, endpointB)); + + // Because the operations are ordered, the bulkWrite operation is broken up by shard + // endpoint. In other words, targeting this request will result in two batches: + // 1) to shard A, and then 2) another to shard B after the first batch is complete. + BulkWriteCommandRequest request({BulkWriteInsertOp(0, BSON("x" << -1)), // stmtId 0, shard A + BulkWriteInsertOp(0, BSON("x" << -2)), // stmtId 1, shard A + BulkWriteInsertOp(0, BSON("x" << 1)), // stmtId 2, shard B + BulkWriteInsertOp(0, BSON("x" << 2))}, // stmtId 3, shard B + {NamespaceInfoEntry(nss)}); + request.setOrdered(true); + + // Setting the txnNumber makes it a retryable write. 
+ _opCtx->setLogicalSessionId(LogicalSessionId()); + _opCtx->setTxnNumber(TxnNumber(0)); + BulkWriteOp bulkWriteOp(_opCtx, request); + + std::map> targeted; + ASSERT_OK(bulkWriteOp.target(targeters, false, targeted)); + + auto* batch = targeted.begin()->second.get(); + auto childRequest = bulkWriteOp.buildBulkCommandRequest(*batch); + auto childStmtIds = childRequest.getStmtIds(); + ASSERT_EQUALS(childStmtIds->size(), 2u); + ASSERT_EQUALS(childStmtIds->at(0), 0); + ASSERT_EQUALS(childStmtIds->at(1), 1); + + // Target again to get a batch for the operations to shard B. + targeted.clear(); + ASSERT_OK(bulkWriteOp.target(targeters, false, targeted)); + + batch = targeted.begin()->second.get(); + childRequest = bulkWriteOp.buildBulkCommandRequest(*batch); + childStmtIds = childRequest.getStmtIds(); + ASSERT_EQUALS(childStmtIds->size(), 2u); + ASSERT_EQUALS(childStmtIds->at(0), 2); + ASSERT_EQUALS(childStmtIds->at(1), 3); +} + +// Tests that stmtIds are correctly attached to bulkWrite requests when the operations +// are unordered. +TEST_F(BulkWriteOpTest, TestUnorderedOpsNoExistingStmtIds) { + NamespaceString nss("zero7.spinning"); + + ShardEndpoint endpointA(ShardId("shardA"), + ShardVersionFactory::make(ChunkVersion::IGNORED(), boost::none), + boost::none); + ShardEndpoint endpointB(ShardId("shardB"), + ShardVersionFactory::make(ChunkVersion::IGNORED(), boost::none), + boost::none); + + std::vector> targeters; + targeters.push_back(initTargeterSplitRange(nss, endpointA, endpointB)); + + // Since the ops aren't ordered, two batches are produced on a single targeting call: + // 1) the ops to shard A (op 0 and op 2) are a batch and 2) the ops to shard B (op 1 + // and op 3) are a batch. Therefore the stmtIds in the bulkWrite request sent to the shards + // will be interleaving. 
+ BulkWriteCommandRequest request({BulkWriteInsertOp(0, BSON("x" << -1)), // stmtId 0, shard A + BulkWriteInsertOp(0, BSON("x" << 1)), // stmtId 1, shard B + BulkWriteInsertOp(0, BSON("x" << -1)), // stmtId 2, shard A + BulkWriteInsertOp(0, BSON("x" << 2))}, // stmtId 3, shard B + {NamespaceInfoEntry(nss)}); + request.setOrdered(false); + + // Setting the txnNumber makes it a retryable write. + _opCtx->setLogicalSessionId(LogicalSessionId()); + _opCtx->setTxnNumber(TxnNumber(0)); + BulkWriteOp bulkWriteOp(_opCtx, request); + + std::map> targeted; + ASSERT_OK(bulkWriteOp.target(targeters, false, targeted)); + + // The batch to shard A contains op 0 and op 2. + auto* batch = targeted[ShardId("shardA")].get(); + auto childRequest = bulkWriteOp.buildBulkCommandRequest(*batch); + auto childStmtIds = childRequest.getStmtIds(); + ASSERT_EQUALS(childStmtIds->size(), 2u); + ASSERT_EQUALS(childStmtIds->at(0), 0); + ASSERT_EQUALS(childStmtIds->at(1), 2); + + // The batch to shard B contains op 1 and op 3. + batch = targeted[ShardId("shardB")].get(); + childRequest = bulkWriteOp.buildBulkCommandRequest(*batch); + childStmtIds = childRequest.getStmtIds(); + ASSERT_EQUALS(childStmtIds->size(), 2u); + ASSERT_EQUALS(childStmtIds->at(0), 1); + ASSERT_EQUALS(childStmtIds->at(1), 3); +} + +// Tests that stmtIds are correctly attached to bulkWrite requests when the operations +// are unordered and stmtIds are attached to the request already. 
+TEST_F(BulkWriteOpTest, TestUnorderedOpsStmtIdsExist) { + NamespaceString nss("zero7.spinning"); + + ShardEndpoint endpointA(ShardId("shardA"), + ShardVersionFactory::make(ChunkVersion::IGNORED(), boost::none), + boost::none); + ShardEndpoint endpointB(ShardId("shardB"), + ShardVersionFactory::make(ChunkVersion::IGNORED(), boost::none), + boost::none); + + std::vector> targeters; + targeters.push_back(initTargeterSplitRange(nss, endpointA, endpointB)); + + // Since the ops aren't ordered, two batches are produced on a single targeting call: + // 1) the ops to shard A (op 0 and op 2) are a batch and 2) the ops to shard B (op 1 + // and op 3) are a batch. Therefore the stmtIds in the bulkWrite request sent to the shards + // will be interleaving. + BulkWriteCommandRequest request({BulkWriteInsertOp(0, BSON("x" << -1)), // stmtId 6, shard A + BulkWriteInsertOp(0, BSON("x" << 1)), // stmtId 7, shard B + BulkWriteInsertOp(0, BSON("x" << -1)), // stmtId 8, shard A + BulkWriteInsertOp(0, BSON("x" << 2))}, // stmtId 9, shard B + {NamespaceInfoEntry(nss)}); + request.setOrdered(false); + request.setStmtIds(std::vector{6, 7, 8, 9}); + + // Setting the txnNumber makes it a retryable write. + _opCtx->setLogicalSessionId(LogicalSessionId()); + _opCtx->setTxnNumber(TxnNumber(0)); + BulkWriteOp bulkWriteOp(_opCtx, request); + + std::map> targeted; + ASSERT_OK(bulkWriteOp.target(targeters, false, targeted)); + + // The batch to shard A contains op 0 and op 2. + auto* batch = targeted[ShardId("shardA")].get(); + auto childRequest = bulkWriteOp.buildBulkCommandRequest(*batch); + auto childStmtIds = childRequest.getStmtIds(); + ASSERT_EQUALS(childStmtIds->size(), 2u); + ASSERT_EQUALS(childStmtIds->at(0), 6); + ASSERT_EQUALS(childStmtIds->at(1), 8); + + // The batch to shard B contains op 1 and op 3. 
+ batch = targeted[ShardId("shardB")].get(); + childRequest = bulkWriteOp.buildBulkCommandRequest(*batch); + childStmtIds = childRequest.getStmtIds(); + ASSERT_EQUALS(childStmtIds->size(), 2u); + ASSERT_EQUALS(childStmtIds->at(0), 7); + ASSERT_EQUALS(childStmtIds->at(1), 9); +} + +// Tests that stmtIds are correctly attached to bulkWrite requests when the operations +// are unordered and the stmtId field exists. +TEST_F(BulkWriteOpTest, TestUnorderedOpsStmtIdFieldExists) { + NamespaceString nss("zero7.spinning"); + + ShardEndpoint endpointA(ShardId("shardA"), + ShardVersionFactory::make(ChunkVersion::IGNORED(), boost::none), + boost::none); + ShardEndpoint endpointB(ShardId("shardB"), + ShardVersionFactory::make(ChunkVersion::IGNORED(), boost::none), + boost::none); + + std::vector> targeters; + targeters.push_back(initTargeterSplitRange(nss, endpointA, endpointB)); + + // Since the ops aren't ordered, two batches are produced on a single targeting call: + // 1) the ops to shard A (op 0 and op 2) are a batch and 2) the ops to shard B (op 1 + // and op 3) are a batch. Therefore the stmtIds in the bulkWrite request sent to the shards + // will be interleaving. + BulkWriteCommandRequest request({BulkWriteInsertOp(0, BSON("x" << -1)), // stmtId 6, shard A + BulkWriteInsertOp(0, BSON("x" << 1)), // stmtId 7, shard B + BulkWriteInsertOp(0, BSON("x" << -1)), // stmtId 8, shard A + BulkWriteInsertOp(0, BSON("x" << 2))}, // stmtId 9, shard B + {NamespaceInfoEntry(nss)}); + request.setOrdered(false); + request.setStmtId(6); // Produces stmtIds 6, 7, 8, 9 + + // Setting the txnNumber makes it a retryable write. + _opCtx->setLogicalSessionId(LogicalSessionId()); + _opCtx->setTxnNumber(TxnNumber(0)); + BulkWriteOp bulkWriteOp(_opCtx, request); + + std::map> targeted; + ASSERT_OK(bulkWriteOp.target(targeters, false, targeted)); + + // The batch to shard A contains op 0 and op 2. 
+ auto* batch = targeted[ShardId("shardA")].get(); + auto childRequest = bulkWriteOp.buildBulkCommandRequest(*batch); + auto childStmtIds = childRequest.getStmtIds(); + ASSERT_EQUALS(childStmtIds->size(), 2u); + ASSERT_EQUALS(childStmtIds->at(0), 6); + ASSERT_EQUALS(childStmtIds->at(1), 8); + + // The batch to shard B contains op 1 and op 3. + batch = targeted[ShardId("shardB")].get(); + childRequest = bulkWriteOp.buildBulkCommandRequest(*batch); + childStmtIds = childRequest.getStmtIds(); + ASSERT_EQUALS(childStmtIds->size(), 2u); + ASSERT_EQUALS(childStmtIds->at(0), 7); + ASSERT_EQUALS(childStmtIds->at(1), 9); +} + // Test BatchItemRef.getLet(). TEST_F(BulkWriteOpTest, BatchItemRefGetLet) { NamespaceString nss("foo.bar"); -- cgit v1.2.1