author     kauboy26 <vishnu.kaushik@mongodb.com>            2023-05-17 20:55:33 +0000
committer  Evergreen Agent <no-reply@evergreen.mongodb.com> 2023-05-17 23:17:02 +0000
commit     7ce60b59b24a893bb6494089ea4db8e19901f48d
tree       a528735ad4207c8f409bb361f4b120492d5732ca
parent     de55cd2ac227dcc8cae2fd021abc291e86b2abb2
download   mongo-7ce60b59b24a893bb6494089ea4db8e19901f48d.tar.gz
SERVER-72989 Attach stmtIds to bulkWrite request sent by mongos
 src/mongo/db/commands/bulk_write.cpp           |  29
 src/mongo/db/commands/bulk_write_common.cpp    |  12
 src/mongo/db/commands/bulk_write_common.h      |   6
 src/mongo/s/SConscript                         |   1
 src/mongo/s/write_ops/bulk_write_exec.cpp      |  20
 src/mongo/s/write_ops/bulk_write_exec_test.cpp | 212
 6 files changed, 250 insertions(+), 30 deletions(-)
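
Before the diff itself: the core of this change is to centralize bulkWrite's statement-id derivation in bulk_write_common so that mongos can reuse it when forwarding sub-batches of a retryable write to shards. The rule, visible in the bulk_write_common.cpp hunk below: an explicit stmtIds array supplies each op's id directly; otherwise ids run sequentially from stmtId, which defaults to 0. A minimal standalone sketch of that rule, using a hypothetical SketchRequest stand-in rather than the server's IDL-generated BulkWriteCommandRequest:

#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for the stmtId/stmtIds fields on the IDL-generated
// BulkWriteCommandRequest; not the server's actual class.
struct SketchRequest {
    std::optional<int32_t> stmtId;                // starting id for the batch
    std::optional<std::vector<int32_t>> stmtIds;  // explicit per-op ids
};

int32_t getStatementId(const SketchRequest& req, size_t currentOpIdx) {
    if (req.stmtIds) {
        // An explicit stmtIds array wins: each op carries its own id.
        return req.stmtIds->at(currentOpIdx);
    }
    // Otherwise ids run sequentially from stmtId, defaulting to 0.
    const int32_t firstStmtId = req.stmtId ? *req.stmtId : 0;
    return firstStmtId + static_cast<int32_t>(currentOpIdx);
}
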
diff --git a/src/mongo/db/commands/bulk_write.cpp b/src/mongo/db/commands/bulk_write.cpp
index c2d989ef9a7..2f283ec5ed0 100644
--- a/src/mongo/db/commands/bulk_write.cpp
+++ b/src/mongo/db/commands/bulk_write.cpp
@@ -375,24 +375,6 @@ void finishCurOp(OperationContext* opCtx, CurOp* curOp) {
}
}
-int32_t getStatementId(OperationContext* opCtx,
- const BulkWriteCommandRequest& req,
- const size_t currentOpIdx) {
- if (opCtx->isRetryableWrite()) {
- auto stmtId = req.getStmtId();
- auto stmtIds = req.getStmtIds();
-
- if (stmtIds) {
- return stmtIds->at(currentOpIdx);
- }
-
- const int32_t firstStmtId = stmtId ? *stmtId : 0;
- return firstStmtId + currentOpIdx;
- }
-
- return kUninitializedStmtId;
-}
-
std::tuple<long long, boost::optional<BSONObj>> getRetryResultForDelete(
OperationContext* opCtx,
const NamespaceString& nsString,
@@ -501,7 +483,8 @@ bool handleInsertOp(OperationContext* opCtx,
const auto& nsInfo = req.getNsInfo();
auto idx = op->getInsert();
- auto stmtId = getStatementId(opCtx, req, currentOpIdx);
+ auto stmtId = opCtx->isRetryableWrite() ? bulk_write_common::getStatementId(req, currentOpIdx)
+ : kUninitializedStmtId;
auto txnParticipant = TransactionParticipant::get(opCtx);
@@ -582,7 +565,9 @@ bool handleUpdateOp(OperationContext* opCtx,
doTransactionValidationForWrites(opCtx, nsString);
- auto stmtId = getStatementId(opCtx, req, currentOpIdx);
+ auto stmtId = opCtx->isRetryableWrite()
+ ? bulk_write_common::getStatementId(req, currentOpIdx)
+ : kUninitializedStmtId;
if (opCtx->isRetryableWrite()) {
const auto txnParticipant = TransactionParticipant::get(opCtx);
if (auto entry = txnParticipant.checkStatementExecuted(opCtx, stmtId)) {
@@ -728,7 +713,9 @@ bool handleDeleteOp(OperationContext* opCtx,
doTransactionValidationForWrites(opCtx, nsString);
- auto stmtId = getStatementId(opCtx, req, currentOpIdx);
+ auto stmtId = opCtx->isRetryableWrite()
+ ? bulk_write_common::getStatementId(req, currentOpIdx)
+ : kUninitializedStmtId;
if (opCtx->isRetryableWrite()) {
const auto txnParticipant = TransactionParticipant::get(opCtx);
// If 'return' is not specified then we do not need to parse the statement. Since
diff --git a/src/mongo/db/commands/bulk_write_common.cpp b/src/mongo/db/commands/bulk_write_common.cpp
index 691932dfc17..2b39688dc1b 100644
--- a/src/mongo/db/commands/bulk_write_common.cpp
+++ b/src/mongo/db/commands/bulk_write_common.cpp
@@ -129,5 +129,17 @@ std::vector<Privilege> getPrivileges(const BulkWriteCommandRequest& req) {
return privileges;
}
+int32_t getStatementId(const BulkWriteCommandRequest& req, size_t currentOpIdx) {
+ auto stmtId = req.getStmtId();
+ auto stmtIds = req.getStmtIds();
+
+ if (stmtIds) {
+ return stmtIds->at(currentOpIdx);
+ }
+
+ int32_t firstStmtId = stmtId ? *stmtId : 0;
+ return firstStmtId + currentOpIdx;
+}
+
} // namespace bulk_write_common
} // namespace mongo
diff --git a/src/mongo/db/commands/bulk_write_common.h b/src/mongo/db/commands/bulk_write_common.h
index 754a644083f..ec79ea13464 100644
--- a/src/mongo/db/commands/bulk_write_common.h
+++ b/src/mongo/db/commands/bulk_write_common.h
@@ -49,5 +49,11 @@ void validateRequest(const BulkWriteCommandRequest& req, bool isRetryableWrite);
*/
std::vector<Privilege> getPrivileges(const BulkWriteCommandRequest& req);
+/**
+ * Get the statement ID for an operation within a bulkWrite command, taking into consideration
+ * whether the stmtId / stmtIds fields are present on the request.
+ */
+int32_t getStatementId(const BulkWriteCommandRequest& req, size_t currentOpIdx);
+
} // namespace bulk_write_common
} // namespace mongo
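
A quick usage sketch of the new helper, against the hypothetical SketchRequest stand-in above; the values mirror the TestUnorderedOpsStmtIdFieldExists and TestUnorderedOpsStmtIdsExist tests added later in this commit:

SketchRequest req;
req.stmtId = 6;                                  // no stmtIds array present
int32_t a = getStatementId(req, 2);              // a == 8: sequential from 6
req.stmtIds = std::vector<int32_t>{6, 7, 8, 9};  // an explicit array now wins
int32_t b = getStatementId(req, 2);              // b == 8: read from the array
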
diff --git a/src/mongo/s/SConscript b/src/mongo/s/SConscript
index 3ca8219b298..eb66d737ec7 100644
--- a/src/mongo/s/SConscript
+++ b/src/mongo/s/SConscript
@@ -39,6 +39,7 @@ env.Library(
'write_ops/write_without_shard_key_util.cpp',
],
LIBDEPS=[
+ '$BUILD_DIR/mongo/db/commands/bulk_write_common',
'$BUILD_DIR/mongo/db/commands/server_status_core',
'$BUILD_DIR/mongo/db/fle_crud',
'$BUILD_DIR/mongo/db/not_primary_error_tracker',
diff --git a/src/mongo/s/write_ops/bulk_write_exec.cpp b/src/mongo/s/write_ops/bulk_write_exec.cpp
index 724b46b34ad..9fc4207c130 100644
--- a/src/mongo/s/write_ops/bulk_write_exec.cpp
+++ b/src/mongo/s/write_ops/bulk_write_exec.cpp
@@ -32,6 +32,7 @@
#include "mongo/base/error_codes.h"
#include "mongo/client/read_preference.h"
#include "mongo/client/remote_command_targeter.h"
+#include "mongo/db/commands/bulk_write_common.h"
#include "mongo/db/commands/bulk_write_gen.h"
#include "mongo/db/commands/bulk_write_parser.h"
#include "mongo/db/database_name.h"
@@ -84,8 +85,6 @@ void executeChildBatches(OperationContext* opCtx,
requests.emplace_back(childBatch.first, request);
}
- bool isRetryableWrite = opCtx->getTxnNumber() && !TransactionRouter::get(opCtx);
-
// Use MultiStatementTransactionRequestsSender to send any ready sub-batches to targeted
// shard endpoints. Requests are sent on construction.
MultiStatementTransactionRequestsSender ars(
@@ -94,7 +93,7 @@ void executeChildBatches(OperationContext* opCtx,
DatabaseName::kAdmin,
requests,
ReadPreferenceSetting(ReadPreference::PrimaryOnly),
- isRetryableWrite ? Shard::RetryPolicy::kIdempotent : Shard::RetryPolicy::kNoRetry);
+ opCtx->isRetryableWrite() ? Shard::RetryPolicy::kIdempotent : Shard::RetryPolicy::kNoRetry);
while (!ars.done()) {
// Block until a response is available.
@@ -303,7 +302,11 @@ BulkWriteCommandRequest BulkWriteOp::buildBulkCommandRequest(
ops;
std::vector<NamespaceInfoEntry> nsInfo = _clientRequest.getNsInfo();
- for (auto&& targetedWrite : targetedBatch.getWrites()) {
+ std::vector<int> stmtIds;
+ if (_isRetryableWrite)
+ stmtIds.reserve(targetedBatch.getNumOps());
+
+ for (const auto& targetedWrite : targetedBatch.getWrites()) {
const WriteOpRef& writeOpRef = targetedWrite->writeOpRef;
ops.push_back(_clientRequest.getOps().at(writeOpRef.first));
@@ -325,6 +328,10 @@ BulkWriteCommandRequest BulkWriteOp::buildBulkCommandRequest(
nsInfoEntry.setShardVersion(targetedWrite->endpoint.shardVersion);
nsInfoEntry.setDatabaseVersion(targetedWrite->endpoint.databaseVersion);
+
+ if (_isRetryableWrite) {
+ stmtIds.push_back(bulk_write_common::getStatementId(_clientRequest, writeOpRef.first));
+ }
}
request.setOps(ops);
@@ -336,8 +343,9 @@ BulkWriteCommandRequest BulkWriteOp::buildBulkCommandRequest(
request.setOrdered(_clientRequest.getOrdered());
request.setBypassDocumentValidation(_clientRequest.getBypassDocumentValidation());
- // TODO (SERVER-72989): Attach stmtIds etc. when building support for retryable
- // writes on mongos
+ if (_isRetryableWrite) {
+ request.setStmtIds(stmtIds);
+ }
request.setDbName(DatabaseName::kAdmin);
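
The hunk above keys each child batch's stmtIds off writeOpRef.first, the op's index in the original client request, so ids survive being regrouped into per-shard batches. A hedged sketch of that loop, reusing the hypothetical SketchRequest and getStatementId() stand-ins from earlier (TargetedWriteSketch is likewise invented for illustration):

#include <cstdint>
#include <vector>

struct TargetedWriteSketch {
    size_t originalOpIdx;  // stand-in for targetedWrite->writeOpRef.first
};

std::vector<int32_t> buildChildStmtIds(const SketchRequest& clientRequest,
                                       const std::vector<TargetedWriteSketch>& batch,
                                       bool isRetryableWrite) {
    std::vector<int32_t> stmtIds;
    if (!isRetryableWrite) {
        return stmtIds;  // only retryable writes attach stmtIds to child batches
    }
    stmtIds.reserve(batch.size());
    for (const auto& write : batch) {
        // Ids follow the op's position in the *original* client request, not
        // its position in this child batch, which is why the unordered tests
        // below expect interleaved ids like {0, 2} and {1, 3}.
        stmtIds.push_back(getStatementId(clientRequest, write.originalOpIdx));
    }
    return stmtIds;
}
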
diff --git a/src/mongo/s/write_ops/bulk_write_exec_test.cpp b/src/mongo/s/write_ops/bulk_write_exec_test.cpp
index 5dcd2ff15e7..c1416d7976c 100644
--- a/src/mongo/s/write_ops/bulk_write_exec_test.cpp
+++ b/src/mongo/s/write_ops/bulk_write_exec_test.cpp
@@ -813,9 +813,8 @@ TEST_F(BulkWriteOpTest, BuildChildRequestFromTargetedWriteBatch) {
NamespaceString nss1("sonate.pacifique");
// Two different endpoints targeting the same shard for the two namespaces.
- ShardEndpoint endpoint0(ShardId("shard"),
- ShardVersionFactory::make(ChunkVersion::IGNORED(), boost::none),
- boost::none);
+ ShardEndpoint endpoint0(
+ shardId, ShardVersionFactory::make(ChunkVersion::IGNORED(), boost::none), boost::none);
ShardEndpoint endpoint1(
shardId,
ShardVersionFactory::make(ChunkVersion({OID::gen(), Timestamp(2)}, {10, 11}),
@@ -873,6 +872,213 @@ TEST_F(BulkWriteOpTest, BuildChildRequestFromTargetedWriteBatch) {
ASSERT_EQUALS(childRequest.getNsInfo()[1].getNs(), request.getNsInfo()[1].getNs());
}
+// Tests that stmtIds are correctly attached to bulkWrite requests when the operations
+// are ordered.
+TEST_F(BulkWriteOpTest, TestOrderedOpsNoExistingStmtIds) {
+ NamespaceString nss("mgmt.kids");
+
+ ShardEndpoint endpointA(ShardId("shardA"),
+ ShardVersionFactory::make(ChunkVersion::IGNORED(), boost::none),
+ boost::none);
+ ShardEndpoint endpointB(ShardId("shardB"),
+ ShardVersionFactory::make(ChunkVersion::IGNORED(), boost::none),
+ boost::none);
+
+ std::vector<std::unique_ptr<NSTargeter>> targeters;
+ targeters.push_back(initTargeterSplitRange(nss, endpointA, endpointB));
+
+ // Because the operations are ordered, the bulkWrite operation is broken up by shard
+ // endpoint. In other words, targeting this request will result in two batches:
+ // 1) one to shard A, and then 2) another to shard B after the first batch completes.
+ BulkWriteCommandRequest request({BulkWriteInsertOp(0, BSON("x" << -1)), // stmtId 0, shard A
+ BulkWriteInsertOp(0, BSON("x" << -2)), // stmtId 1, shard A
+ BulkWriteInsertOp(0, BSON("x" << 1)), // stmtId 2, shard B
+ BulkWriteInsertOp(0, BSON("x" << 2))}, // stmtId 3, shard B
+ {NamespaceInfoEntry(nss)});
+ request.setOrdered(true);
+
+ // Setting the txnNumber makes it a retryable write.
+ _opCtx->setLogicalSessionId(LogicalSessionId());
+ _opCtx->setTxnNumber(TxnNumber(0));
+ BulkWriteOp bulkWriteOp(_opCtx, request);
+
+ std::map<ShardId, std::unique_ptr<TargetedWriteBatch>> targeted;
+ ASSERT_OK(bulkWriteOp.target(targeters, false, targeted));
+
+ auto* batch = targeted.begin()->second.get();
+ auto childRequest = bulkWriteOp.buildBulkCommandRequest(*batch);
+ auto childStmtIds = childRequest.getStmtIds();
+ ASSERT_EQUALS(childStmtIds->size(), 2u);
+ ASSERT_EQUALS(childStmtIds->at(0), 0);
+ ASSERT_EQUALS(childStmtIds->at(1), 1);
+
+ // Target again to get a batch for the operations to shard B.
+ targeted.clear();
+ ASSERT_OK(bulkWriteOp.target(targeters, false, targeted));
+
+ batch = targeted.begin()->second.get();
+ childRequest = bulkWriteOp.buildBulkCommandRequest(*batch);
+ childStmtIds = childRequest.getStmtIds();
+ ASSERT_EQUALS(childStmtIds->size(), 2u);
+ ASSERT_EQUALS(childStmtIds->at(0), 2);
+ ASSERT_EQUALS(childStmtIds->at(1), 3);
+}
+
+// Tests that stmtIds are correctly attached to bulkWrite requests when the operations
+// are unordered.
+TEST_F(BulkWriteOpTest, TestUnorderedOpsNoExistingStmtIds) {
+ NamespaceString nss("zero7.spinning");
+
+ ShardEndpoint endpointA(ShardId("shardA"),
+ ShardVersionFactory::make(ChunkVersion::IGNORED(), boost::none),
+ boost::none);
+ ShardEndpoint endpointB(ShardId("shardB"),
+ ShardVersionFactory::make(ChunkVersion::IGNORED(), boost::none),
+ boost::none);
+
+ std::vector<std::unique_ptr<NSTargeter>> targeters;
+ targeters.push_back(initTargeterSplitRange(nss, endpointA, endpointB));
+
+ // Since the ops aren't ordered, two batches are produced on a single targeting call:
+ // 1) the ops to shard A (op 0 and op 2) are a batch and 2) the ops to shard B (op 1
+ // and op 3) are a batch. Therefore the stmtIds in the bulkWrite request sent to the shards
+ // will be interleaved.
+ BulkWriteCommandRequest request({BulkWriteInsertOp(0, BSON("x" << -1)), // stmtId 0, shard A
+ BulkWriteInsertOp(0, BSON("x" << 1)), // stmtId 1, shard B
+ BulkWriteInsertOp(0, BSON("x" << -1)), // stmtId 2, shard A
+ BulkWriteInsertOp(0, BSON("x" << 2))}, // stmtId 3, shard B
+ {NamespaceInfoEntry(nss)});
+ request.setOrdered(false);
+
+ // Setting the txnNumber makes it a retryable write.
+ _opCtx->setLogicalSessionId(LogicalSessionId());
+ _opCtx->setTxnNumber(TxnNumber(0));
+ BulkWriteOp bulkWriteOp(_opCtx, request);
+
+ std::map<ShardId, std::unique_ptr<TargetedWriteBatch>> targeted;
+ ASSERT_OK(bulkWriteOp.target(targeters, false, targeted));
+
+ // The batch to shard A contains op 0 and op 2.
+ auto* batch = targeted[ShardId("shardA")].get();
+ auto childRequest = bulkWriteOp.buildBulkCommandRequest(*batch);
+ auto childStmtIds = childRequest.getStmtIds();
+ ASSERT_EQUALS(childStmtIds->size(), 2u);
+ ASSERT_EQUALS(childStmtIds->at(0), 0);
+ ASSERT_EQUALS(childStmtIds->at(1), 2);
+
+ // The batch to shard B contains op 1 and op 3.
+ batch = targeted[ShardId("shardB")].get();
+ childRequest = bulkWriteOp.buildBulkCommandRequest(*batch);
+ childStmtIds = childRequest.getStmtIds();
+ ASSERT_EQUALS(childStmtIds->size(), 2u);
+ ASSERT_EQUALS(childStmtIds->at(0), 1);
+ ASSERT_EQUALS(childStmtIds->at(1), 3);
+}
+
+// Tests that stmtIds are correctly attached to bulkWrite requests when the operations
+// are unordered and stmtIds are attached to the request already.
+TEST_F(BulkWriteOpTest, TestUnorderedOpsStmtIdsExist) {
+ NamespaceString nss("zero7.spinning");
+
+ ShardEndpoint endpointA(ShardId("shardA"),
+ ShardVersionFactory::make(ChunkVersion::IGNORED(), boost::none),
+ boost::none);
+ ShardEndpoint endpointB(ShardId("shardB"),
+ ShardVersionFactory::make(ChunkVersion::IGNORED(), boost::none),
+ boost::none);
+
+ std::vector<std::unique_ptr<NSTargeter>> targeters;
+ targeters.push_back(initTargeterSplitRange(nss, endpointA, endpointB));
+
+ // Since the ops aren't ordered, two batches are produced on a single targeting call:
+ // 1) the ops to shard A (op 0 and op 2) are a batch and 2) the ops to shard B (op 1
+ // and op 3) are a batch. Therefore the stmtIds in the bulkWrite request sent to the shards
+ // will be interleaved.
+ BulkWriteCommandRequest request({BulkWriteInsertOp(0, BSON("x" << -1)), // stmtId 6, shard A
+ BulkWriteInsertOp(0, BSON("x" << 1)), // stmtId 7, shard B
+ BulkWriteInsertOp(0, BSON("x" << -1)), // stmtId 8, shard A
+ BulkWriteInsertOp(0, BSON("x" << 2))}, // stmtId 9, shard B
+ {NamespaceInfoEntry(nss)});
+ request.setOrdered(false);
+ request.setStmtIds(std::vector<int>{6, 7, 8, 9});
+
+ // Setting the txnNumber makes it a retryable write.
+ _opCtx->setLogicalSessionId(LogicalSessionId());
+ _opCtx->setTxnNumber(TxnNumber(0));
+ BulkWriteOp bulkWriteOp(_opCtx, request);
+
+ std::map<ShardId, std::unique_ptr<TargetedWriteBatch>> targeted;
+ ASSERT_OK(bulkWriteOp.target(targeters, false, targeted));
+
+ // The batch to shard A contains op 0 and op 2.
+ auto* batch = targeted[ShardId("shardA")].get();
+ auto childRequest = bulkWriteOp.buildBulkCommandRequest(*batch);
+ auto childStmtIds = childRequest.getStmtIds();
+ ASSERT_EQUALS(childStmtIds->size(), 2u);
+ ASSERT_EQUALS(childStmtIds->at(0), 6);
+ ASSERT_EQUALS(childStmtIds->at(1), 8);
+
+ // The batch to shard B contains op 1 and op 3.
+ batch = targeted[ShardId("shardB")].get();
+ childRequest = bulkWriteOp.buildBulkCommandRequest(*batch);
+ childStmtIds = childRequest.getStmtIds();
+ ASSERT_EQUALS(childStmtIds->size(), 2u);
+ ASSERT_EQUALS(childStmtIds->at(0), 7);
+ ASSERT_EQUALS(childStmtIds->at(1), 9);
+}
+
+// Tests that stmtIds are correctly attached to bulkWrite requests when the operations
+// are unordered and the stmtId field exists.
+TEST_F(BulkWriteOpTest, TestUnorderedOpsStmtIdFieldExists) {
+ NamespaceString nss("zero7.spinning");
+
+ ShardEndpoint endpointA(ShardId("shardA"),
+ ShardVersionFactory::make(ChunkVersion::IGNORED(), boost::none),
+ boost::none);
+ ShardEndpoint endpointB(ShardId("shardB"),
+ ShardVersionFactory::make(ChunkVersion::IGNORED(), boost::none),
+ boost::none);
+
+ std::vector<std::unique_ptr<NSTargeter>> targeters;
+ targeters.push_back(initTargeterSplitRange(nss, endpointA, endpointB));
+
+ // Since the ops aren't ordered, two batches are produced on a single targeting call:
+ // 1) the ops to shard A (op 0 and op 2) are a batch and 2) the ops to shard B (op 1
+ // and op 3) are a batch. Therefore the stmtIds in the bulkWrite request sent to the shards
+ // will be interleaved.
+ BulkWriteCommandRequest request({BulkWriteInsertOp(0, BSON("x" << -1)), // stmtId 6, shard A
+ BulkWriteInsertOp(0, BSON("x" << 1)), // stmtId 7, shard B
+ BulkWriteInsertOp(0, BSON("x" << -1)), // stmtId 8, shard A
+ BulkWriteInsertOp(0, BSON("x" << 2))}, // stmtId 9, shard B
+ {NamespaceInfoEntry(nss)});
+ request.setOrdered(false);
+ request.setStmtId(6); // Produces stmtIds 6, 7, 8, 9
+
+ // Setting the txnNumber makes it a retryable write.
+ _opCtx->setLogicalSessionId(LogicalSessionId());
+ _opCtx->setTxnNumber(TxnNumber(0));
+ BulkWriteOp bulkWriteOp(_opCtx, request);
+
+ std::map<ShardId, std::unique_ptr<TargetedWriteBatch>> targeted;
+ ASSERT_OK(bulkWriteOp.target(targeters, false, targeted));
+
+ // The batch to shard A contains op 0 and op 2.
+ auto* batch = targeted[ShardId("shardA")].get();
+ auto childRequest = bulkWriteOp.buildBulkCommandRequest(*batch);
+ auto childStmtIds = childRequest.getStmtIds();
+ ASSERT_EQUALS(childStmtIds->size(), 2u);
+ ASSERT_EQUALS(childStmtIds->at(0), 6);
+ ASSERT_EQUALS(childStmtIds->at(1), 8);
+
+ // The batch to shard B contains op 1 and op 3.
+ batch = targeted[ShardId("shardB")].get();
+ childRequest = bulkWriteOp.buildBulkCommandRequest(*batch);
+ childStmtIds = childRequest.getStmtIds();
+ ASSERT_EQUALS(childStmtIds->size(), 2u);
+ ASSERT_EQUALS(childStmtIds->at(0), 7);
+ ASSERT_EQUALS(childStmtIds->at(1), 9);
+}
+
// Test BatchItemRef.getLet().
TEST_F(BulkWriteOpTest, BatchItemRefGetLet) {
NamespaceString nss("foo.bar");