author     Lingzhi Deng <lingzhi.deng@mongodb.com>    2023-02-10 04:14:01 +0000
committer  Evergreen Agent <no-reply@evergreen.mongodb.com>    2023-02-10 08:00:31 +0000
commit     188549516c293b65c2c3cba5ae05a573b2eab460 (patch)
tree       25dd3aa0996ce90573cf6f10335a5391936ed5bc
parent     164d779fe330e254ddd9dbf15c51fe38d5369fb1 (diff)
SERVER-72787: Implement sub-batching logic for bulkWrite
-rw-r--r--  src/mongo/s/SConscript                          |   1
-rw-r--r--  src/mongo/s/chunk_manager.cpp                   |  40
-rw-r--r--  src/mongo/s/chunk_manager.h                     |   7
-rw-r--r--  src/mongo/s/write_ops/batch_write_op.cpp        | 222
-rw-r--r--  src/mongo/s/write_ops/batch_write_op.h          |  19
-rw-r--r--  src/mongo/s/write_ops/bulk_write_exec.cpp       |  79
-rw-r--r--  src/mongo/s/write_ops/bulk_write_exec.h         |  27
-rw-r--r--  src/mongo/s/write_ops/bulk_write_exec_test.cpp  | 328
8 files changed, 596 insertions(+), 127 deletions(-)
diff --git a/src/mongo/s/SConscript b/src/mongo/s/SConscript
index 50dc50966c2..4cfd659e533 100644
--- a/src/mongo/s/SConscript
+++ b/src/mongo/s/SConscript
@@ -686,6 +686,7 @@ env.CppUnitTest(
'write_ops/batch_write_op_test.cpp',
'write_ops/batched_command_request_test.cpp',
'write_ops/batched_command_response_test.cpp',
+ 'write_ops/bulk_write_exec_test.cpp',
'write_ops/write_op_test.cpp',
'write_ops/write_without_shard_key_util_test.cpp',
],
diff --git a/src/mongo/s/chunk_manager.cpp b/src/mongo/s/chunk_manager.cpp
index 0c213a7c5d3..dee2b74ed49 100644
--- a/src/mongo/s/chunk_manager.cpp
+++ b/src/mongo/s/chunk_manager.cpp
@@ -726,4 +726,44 @@ ShardEndpoint::ShardEndpoint(const ShardId& shardName,
invariant(shardName == ShardId::kConfigServerId);
}
+bool EndpointComp::operator()(const ShardEndpoint* endpointA,
+ const ShardEndpoint* endpointB) const {
+ const int shardNameDiff = endpointA->shardName.compare(endpointB->shardName);
+ if (shardNameDiff)
+ return shardNameDiff < 0;
+
+ if (endpointA->shardVersion && endpointB->shardVersion) {
+ const int epochDiff = endpointA->shardVersion->placementVersion().epoch().compare(
+ endpointB->shardVersion->placementVersion().epoch());
+ if (epochDiff)
+ return epochDiff < 0;
+
+ const int shardVersionDiff = endpointA->shardVersion->placementVersion().toLong() -
+ endpointB->shardVersion->placementVersion().toLong();
+ if (shardVersionDiff)
+ return shardVersionDiff < 0;
+ } else if (!endpointA->shardVersion && !endpointB->shardVersion) {
+ // TODO (SERVER-51070): Can only happen if the destination is the config server
+ return false;
+ } else {
+ // TODO (SERVER-51070): Can only happen if the destination is the config server
+ return !endpointA->shardVersion && endpointB->shardVersion;
+ }
+
+ if (endpointA->databaseVersion && endpointB->databaseVersion) {
+ const int uuidDiff =
+ endpointA->databaseVersion->getUuid().compare(endpointB->databaseVersion->getUuid());
+ if (uuidDiff)
+ return uuidDiff < 0;
+
+ return endpointA->databaseVersion->getLastMod() < endpointB->databaseVersion->getLastMod();
+ } else if (!endpointA->databaseVersion && !endpointB->databaseVersion) {
+ return false;
+ } else {
+ return !endpointA->databaseVersion && endpointB->databaseVersion;
+ }
+
+ MONGO_UNREACHABLE;
+}
+
} // namespace mongo
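
For context, the relocated EndpointComp gives ShardEndpoint pointers a strict weak ordering (shard name first, then placement version, then database version), which is what lets TargetedBatchMap key batches by endpoint in both batch_write_op and the new bulk_write_exec code. Below is a minimal sketch of the comparator in use; the endpoint construction mirrors the tests further down, and the exact header set is an assumption, not part of this patch:

    // Sketch only: assumes mongo/s/chunk_manager.h declares ShardEndpoint and EndpointComp.
    #include <map>
    #include "mongo/s/chunk_manager.h"

    void endpointCompSketch() {
        using namespace mongo;
        // Two endpoints on different shards; shardName is the first comparison key.
        ShardEndpoint a(ShardId("shardA"), ShardVersion::IGNORED(), boost::none);
        ShardEndpoint b(ShardId("shardB"), ShardVersion::IGNORED(), boost::none);

        std::map<const ShardEndpoint*, int, EndpointComp> perEndpoint;
        perEndpoint[&b] = 2;
        perEndpoint[&a] = 1;
        // Iteration visits shardA before shardB: the comparator inspects the
        // endpoints themselves, not the pointer addresses used as keys.
    }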
diff --git a/src/mongo/s/chunk_manager.h b/src/mongo/s/chunk_manager.h
index e6d1d3fe8a7..002e557a410 100644
--- a/src/mongo/s/chunk_manager.h
+++ b/src/mongo/s/chunk_manager.h
@@ -487,6 +487,13 @@ struct ShardEndpoint {
};
/**
+ * Compares shard endpoints in a map.
+ */
+struct EndpointComp {
+ bool operator()(const ShardEndpoint* endpointA, const ShardEndpoint* endpointB) const;
+};
+
+/**
* Wrapper around a RoutingTableHistory, which pins it to a particular point in time.
*/
class ChunkManager {
diff --git a/src/mongo/s/write_ops/batch_write_op.cpp b/src/mongo/s/write_ops/batch_write_op.cpp
index 5b2df548795..ddb0f7f7b55 100644
--- a/src/mongo/s/write_ops/batch_write_op.cpp
+++ b/src/mongo/s/write_ops/batch_write_op.cpp
@@ -240,23 +240,13 @@ int getEncryptionInformationSize(const BatchedCommandRequest& req) {
} // namespace
-BatchWriteOp::BatchWriteOp(OperationContext* opCtx, const BatchedCommandRequest& clientRequest)
- : _opCtx(opCtx),
- _clientRequest(clientRequest),
- _batchTxnNum(_opCtx->getTxnNumber()),
- _inTransaction(bool(TransactionRouter::get(opCtx))),
- _isRetryableWrite(opCtx->isRetryableWrite()) {
- _writeOps.reserve(_clientRequest.sizeWriteOps());
-
- for (size_t i = 0; i < _clientRequest.sizeWriteOps(); ++i) {
- _writeOps.emplace_back(BatchItemRef(&_clientRequest, i), _inTransaction);
- }
-}
-
-StatusWith<bool> BatchWriteOp::targetBatch(
- const NSTargeter& targeter,
- bool recordTargetErrors,
- std::map<ShardId, std::unique_ptr<TargetedWriteBatch>>* targetedBatches) {
+StatusWith<bool> targetWriteOps(OperationContext* opCtx,
+ std::vector<WriteOp>& writeOps,
+ bool ordered,
+ bool recordTargetErrors,
+ GetTargeterFn getTargeterFn,
+ GetWriteSizeFn getWriteSizeFn,
+ TargetedBatchMap& batchMap) {
//
// Targeting of unordered batches is fairly simple - each remaining write op is targeted,
// and each of those targeted writes are grouped into a batch for a particular shard
@@ -290,19 +280,13 @@ StatusWith<bool> BatchWriteOp::targetBatch(
// [{ skey : y }, { skey : z }]
//
- const bool ordered = _clientRequest.getWriteCommandRequestBase().getOrdered();
-
bool isWriteWithoutShardKeyOrId = false;
- TargetedBatchMap batchMap;
+ // Used to track the set of shardIds (w/o shardVersion) we targeted.
std::set<ShardId> targetedShards;
- const size_t numWriteOps = _clientRequest.sizeWriteOps();
-
- for (size_t i = 0; i < numWriteOps; ++i) {
- WriteOp& writeOp = _writeOps[i];
-
- // Only target _Ready ops
+ for (auto& writeOp : writeOps) {
+ // Only target ops in the Ready state.
if (writeOp.getWriteState() != WriteOpState_Ready)
continue;
@@ -312,34 +296,45 @@ StatusWith<bool> BatchWriteOp::targetBatch(
break;
}
- //
- // Get TargetedWrites from the targeter for the write operation
- //
- // TargetedWrites need to be owned once returned
+ const auto& targeter = getTargeterFn(writeOp);
std::vector<std::unique_ptr<TargetedWrite>> writes;
-
- Status targetStatus = Status::OK();
-
- try {
- writeOp.targetWrites(_opCtx, targeter, &writes);
- } catch (const DBException& ex) {
- targetStatus = ex.toStatus();
- }
+ auto targetStatus = [&] {
+ try {
+ writeOp.targetWrites(opCtx, targeter, &writes);
+ return Status::OK();
+ } catch (const DBException& ex) {
+ return ex.toStatus();
+ }
+ }();
if (!targetStatus.isOK()) {
write_ops::WriteError targetError(0, targetStatus);
- if (TransactionRouter::get(_opCtx)) {
+ auto cancelBatches = [&](const write_ops::WriteError& why) {
+ for (TargetedBatchMap::iterator it = batchMap.begin(); it != batchMap.end();) {
+ for (auto&& write : it->second->getWrites()) {
+ // NOTE: We may repeatedly cancel a write op here, but that's fast and we
+ // want to cancel before erasing the TargetedWrite* (which owns the
+ // cancelled targeting info) for reporting reasons.
+ writeOps[write->writeOpRef.first].cancelWrites(&why);
+ }
+
+ it = batchMap.erase(it);
+ }
+ dassert(batchMap.empty());
+ };
+
+ if (TransactionRouter::get(opCtx)) {
writeOp.setOpError(targetError);
// Cleanup all the writes we have targetted in this call so far since we are going
// to abort the entire transaction.
- _cancelBatches(targetError, std::move(batchMap));
+ cancelBatches(targetError);
return targetStatus;
} else if (!recordTargetErrors) {
// Cancel current batch state with an error
- _cancelBatches(targetError, std::move(batchMap));
+ cancelBatches(targetError);
return targetStatus;
} else if (!ordered || batchMap.empty()) {
// Record an error for this batch
@@ -347,7 +342,7 @@ StatusWith<bool> BatchWriteOp::targetBatch(
writeOp.setOpError(targetError);
if (ordered)
- return StatusWith<bool>(isWriteWithoutShardKeyOrId);
+ return isWriteWithoutShardKeyOrId;
continue;
} else {
@@ -360,11 +355,8 @@ StatusWith<bool> BatchWriteOp::targetBatch(
}
}
- //
- // If ordered and we have a previous endpoint, make sure we don't need to send these
- // targeted writes to any other endpoints.
- //
-
+ // If writes are ordered and we have a targeted endpoint, make sure we don't need to send
+ // these targeted writes to any other endpoints.
if (ordered && !batchMap.empty()) {
dassert(batchMap.size() == 1u);
if (isNewBatchRequiredOrdered(writes, batchMap)) {
@@ -373,33 +365,16 @@ StatusWith<bool> BatchWriteOp::targetBatch(
}
}
- // If retryable writes are used, MongoS needs to send an additional array of stmtId(s)
- // corresponding to the statements that got routed to each individual shard, so they need to
- // be accounted in the potential request size so it does not exceed the max BSON size.
- //
- // The constant 4 is chosen as the size of the BSON representation of the stmtId.
- const int writeSizeBytes = getWriteSizeBytes(writeOp) +
- getEncryptionInformationSize(_clientRequest) +
- write_ops::kWriteCommandBSONArrayPerElementOverheadBytes +
- (_batchTxnNum ? write_ops::kWriteCommandBSONArrayPerElementOverheadBytes + 4 : 0);
-
- // For unordered writes, the router must return an entry for each failed write. This
- // constant is a pessimistic attempt to ensure that if a request to a shard hits
- // "retargeting needed" error and has to return number of errors equivalent to the number of
- // writes in the batch, the response size will not exceed the max BSON size.
- //
- // The constant of 272 is chosen as an approximation of the size of the BSON representation
- // of the StaleConfigInfo (which contains the shard id) and the adjacent error message.
- const int errorResponsePotentialSizeBytes =
- ordered ? 0 : write_ops::kWriteCommandBSONArrayPerElementOverheadBytes + 272;
+ const auto estWriteSizeBytes = getWriteSizeFn(writeOp);
- if (wouldMakeBatchesTooBig(
- writes, std::max(writeSizeBytes, errorResponsePotentialSizeBytes), batchMap)) {
+ if (wouldMakeBatchesTooBig(writes, estWriteSizeBytes, batchMap)) {
invariant(!batchMap.empty());
writeOp.cancelWrites(nullptr);
break;
}
+ // If writes are unordered and we already have targeted endpoints, make sure we don't target
+ // the same shard with a different shardVersion.
if (!ordered && !batchMap.empty() &&
isNewBatchRequiredUnordered(writes, batchMap, targetedShards)) {
writeOp.cancelWrites(nullptr);
@@ -430,7 +405,7 @@ StatusWith<bool> BatchWriteOp::targetBatch(
if (!isMultiWrite &&
write_without_shard_key::useTwoPhaseProtocol(
- _opCtx, targeter.getNS(), true /* isUpdateOrDelete */, query, collation)) {
+ opCtx, targeter.getNS(), true /* isUpdateOrDelete */, query, collation)) {
// Writes without shard key should be in their own batch.
if (!batchMap.empty()) {
@@ -455,8 +430,7 @@ StatusWith<bool> BatchWriteOp::targetBatch(
targetedShards.insert(endpoint->shardName);
}
- batchIt->second->addWrite(std::move(write),
- std::max(writeSizeBytes, errorResponsePotentialSizeBytes));
+ batchIt->second->addWrite(std::move(write), estWriteSizeBytes);
}
// Relinquish ownership of TargetedWrites, now the TargetedBatches own them
@@ -471,6 +445,70 @@ StatusWith<bool> BatchWriteOp::targetBatch(
break;
}
+ return isWriteWithoutShardKeyOrId;
+}
+
+BatchWriteOp::BatchWriteOp(OperationContext* opCtx, const BatchedCommandRequest& clientRequest)
+ : _opCtx(opCtx),
+ _clientRequest(clientRequest),
+ _batchTxnNum(_opCtx->getTxnNumber()),
+ _inTransaction(bool(TransactionRouter::get(opCtx))),
+ _isRetryableWrite(opCtx->isRetryableWrite()) {
+ _writeOps.reserve(_clientRequest.sizeWriteOps());
+
+ for (size_t i = 0; i < _clientRequest.sizeWriteOps(); ++i) {
+ _writeOps.emplace_back(BatchItemRef(&_clientRequest, i), _inTransaction);
+ }
+}
+
+StatusWith<bool> BatchWriteOp::targetBatch(
+ const NSTargeter& targeter,
+ bool recordTargetErrors,
+ std::map<ShardId, std::unique_ptr<TargetedWriteBatch>>* targetedBatches) {
+ const bool ordered = _clientRequest.getWriteCommandRequestBase().getOrdered();
+
+ // Used to track the shard endpoints (w/ shardVersion) we targeted and the batches to each of
+ // these shard endpoints.
+ TargetedBatchMap batchMap;
+
+ auto targetStatus = targetWriteOps(
+ _opCtx,
+ _writeOps,
+ ordered,
+ recordTargetErrors,
+ // getTargeterFn:
+ [&](const WriteOp& writeOp) -> const NSTargeter& { return targeter; },
+ // getWriteSizeFn:
+ [&](const WriteOp& writeOp) {
+ // If retryable writes are used, MongoS needs to send an additional array of stmtId(s)
+ // corresponding to the statements that got routed to each individual shard, so they
+ // need to be accounted in the potential request size so it does not exceed the max BSON
+ // size.
+ //
+ // The constant 4 is chosen as the size of the BSON representation of the stmtId.
+ const int writeSizeBytes = getWriteSizeBytes(writeOp) +
+ getEncryptionInformationSize(_clientRequest) +
+ write_ops::kWriteCommandBSONArrayPerElementOverheadBytes +
+ (_batchTxnNum ? write_ops::kWriteCommandBSONArrayPerElementOverheadBytes + 4 : 0);
+
+ // For unordered writes, the router must return an entry for each failed write. This
+ // constant is a pessimistic attempt to ensure that if a request to a shard hits
+ // "retargeting needed" error and has to return number of errors equivalent to the
+ // number of writes in the batch, the response size will not exceed the max BSON size.
+ //
+ // The constant of 272 is chosen as an approximation of the size of the BSON
+ // representation of the StaleConfigInfo (which contains the shard id) and the adjacent
+ // error message.
+ const int errorResponsePotentialSizeBytes =
+ ordered ? 0 : write_ops::kWriteCommandBSONArrayPerElementOverheadBytes + 272;
+ return std::max(writeSizeBytes, errorResponsePotentialSizeBytes);
+ },
+ batchMap);
+
+ if (!targetStatus.isOK()) {
+ return targetStatus;
+ }
+
//
// Send back our targeted batches
//
@@ -490,7 +528,7 @@ StatusWith<bool> BatchWriteOp::targetBatch(
_nShardsOwningChunks = targeter.getNShardsOwningChunks();
- return StatusWith<bool>(isWriteWithoutShardKeyOrId);
+ return targetStatus;
}
BatchedCommandRequest BatchWriteOp::buildBatchRequest(const TargetedWriteBatch& targetedBatch,
@@ -957,46 +995,6 @@ void BatchWriteOp::_cancelBatches(const write_ops::WriteError& why,
}
}
-bool EndpointComp::operator()(const ShardEndpoint* endpointA,
- const ShardEndpoint* endpointB) const {
- const int shardNameDiff = endpointA->shardName.compare(endpointB->shardName);
- if (shardNameDiff)
- return shardNameDiff < 0;
-
- if (endpointA->shardVersion && endpointB->shardVersion) {
- const int epochDiff = endpointA->shardVersion->placementVersion().epoch().compare(
- endpointB->shardVersion->placementVersion().epoch());
- if (epochDiff)
- return epochDiff < 0;
-
- const int shardVersionDiff = endpointA->shardVersion->placementVersion().toLong() -
- endpointB->shardVersion->placementVersion().toLong();
- if (shardVersionDiff)
- return shardVersionDiff < 0;
- } else if (!endpointA->shardVersion && !endpointB->shardVersion) {
- // TODO (SERVER-51070): Can only happen if the destination is the config server
- return false;
- } else {
- // TODO (SERVER-51070): Can only happen if the destination is the config server
- return !endpointA->shardVersion && endpointB->shardVersion;
- }
-
- if (endpointA->databaseVersion && endpointB->databaseVersion) {
- const int uuidDiff =
- endpointA->databaseVersion->getUuid().compare(endpointB->databaseVersion->getUuid());
- if (uuidDiff)
- return uuidDiff < 0;
-
- return endpointA->databaseVersion->getLastMod() < endpointB->databaseVersion->getLastMod();
- } else if (!endpointA->databaseVersion && !endpointB->databaseVersion) {
- return false;
- } else {
- return !endpointA->databaseVersion && endpointB->databaseVersion;
- }
-
- MONGO_UNREACHABLE;
-}
-
void TrackedErrors::startTracking(int errCode) {
dassert(!isTracking(errCode));
_errorMap.emplace(errCode, std::vector<ShardError>());
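
The getWriteSizeFn lambda above preserves the old sizing math: a request-side estimate and a pessimistic error-response estimate are computed per op, and the larger of the two is charged against the batch. A worked sketch of that arithmetic, assuming kWriteCommandBSONArrayPerElementOverheadBytes is 7 (the constant's value and the sample numbers are assumptions for illustration):

    #include <algorithm>

    int estimatedWriteSizeSketch() {
        const int perElementOverhead = 7;  // assumed kWriteCommandBSONArrayPerElementOverheadBytes
        const int docSize = 1024;          // stand-in for getWriteSizeBytes(writeOp)
        const int encryptionInfo = 0;      // stand-in for getEncryptionInformationSize(...)
        const bool retryable = true;       // i.e. _batchTxnNum is set
        const bool ordered = false;

        const int writeSizeBytes = docSize + encryptionInfo + perElementOverhead +
            (retryable ? perElementOverhead + 4 : 0);  // extra 4 bytes for the stmtId entry
        const int errorResponsePotentialSizeBytes =
            ordered ? 0 : perElementOverhead + 272;    // pessimistic StaleConfigInfo entry
        // 1042 vs. 279 here, so the request-side estimate dominates for a 1 KiB doc.
        return std::max(writeSizeBytes, errorResponsePotentialSizeBytes);
    }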
diff --git a/src/mongo/s/write_ops/batch_write_op.h b/src/mongo/s/write_ops/batch_write_op.h
index b828031bacf..2fde01033f9 100644
--- a/src/mongo/s/write_ops/batch_write_op.h
+++ b/src/mongo/s/write_ops/batch_write_op.h
@@ -80,13 +80,6 @@ struct ShardWCError {
WriteConcernErrorDetail error;
};
-/**
- * Compares endpoints in a map.
- */
-struct EndpointComp {
- bool operator()(const ShardEndpoint* endpointA, const ShardEndpoint* endpointB) const;
-};
-
using TargetedBatchMap =
std::map<const ShardEndpoint*, std::unique_ptr<TargetedWriteBatch>, EndpointComp>;
@@ -275,4 +268,16 @@ private:
TrackedErrorMap _errorMap;
};
+typedef std::function<const NSTargeter&(const WriteOp& writeOp)> GetTargeterFn;
+typedef std::function<int(const WriteOp& writeOp)> GetWriteSizeFn;
+
+// Helper function to target ready writeOps. See BatchWriteOp::targetBatch for details.
+StatusWith<bool> targetWriteOps(OperationContext* opCtx,
+ std::vector<WriteOp>& writeOps,
+ bool ordered,
+ bool recordTargetErrors,
+ GetTargeterFn getTargeterFn,
+ GetWriteSizeFn getWriteSizeFn,
+ TargetedBatchMap& batchMap);
+
} // namespace mongo
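
The two std::function hooks are how each caller plugs its own request shape into the shared targeting loop: BatchWriteOp resolves every op to the single targeter it was given, while BulkWriteOp indexes into a per-namespace targeter vector. A condensed sketch of the call shape, assuming opCtx, writeOps, and targeter are in scope as in BatchWriteOp::targetBatch; estimateSizeBytes is a hypothetical sizing helper, not part of this patch:

    TargetedBatchMap batchMap;
    auto swWriteWithoutShardKey = targetWriteOps(
        opCtx,
        writeOps,
        /*ordered=*/true,
        /*recordTargetErrors=*/false,
        [&](const WriteOp&) -> const NSTargeter& { return targeter; },
        [&](const WriteOp& op) { return estimateSizeBytes(op); },  // hypothetical helper
        batchMap);
    // On success, batchMap holds one TargetedWriteBatch per distinct shard endpoint.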
diff --git a/src/mongo/s/write_ops/bulk_write_exec.cpp b/src/mongo/s/write_ops/bulk_write_exec.cpp
index 51016edfae2..47f62bb38f1 100644
--- a/src/mongo/s/write_ops/bulk_write_exec.cpp
+++ b/src/mongo/s/write_ops/bulk_write_exec.cpp
@@ -36,6 +36,7 @@
#include "mongo/s/client/shard_registry.h"
#include "mongo/s/grid.h"
#include "mongo/s/transaction_router.h"
+#include "mongo/s/write_ops/batch_write_op.h"
#include "mongo/s/write_ops/write_without_shard_key_util.h"
#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kSharding
@@ -55,11 +56,23 @@ void execute(OperationContext* opCtx,
BulkWriteOp bulkWriteOp(opCtx, clientRequest);
+ bool refreshedTargeter = false;
+
while (!bulkWriteOp.isFinished()) {
// 1: Target remaining ops with the appropriate targeter based on the namespace index and
// re-batch ops based on their targeted shard id.
stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>> childBatches;
- auto targetStatus = bulkWriteOp.target(targeters, &childBatches);
+
+ bool recordTargetErrors = refreshedTargeter;
+ auto targetStatus = bulkWriteOp.target(targeters, recordTargetErrors, childBatches);
+ if (!targetStatus.isOK()) {
+ dassert(childBatches.size() == 0u);
+ // TODO(SERVER-72982): Handle targeting errors.
+ for (auto& targeter : targeters) {
+ targeter->noteCouldNotTarget();
+ }
+ refreshedTargeter = true;
+ }
// 2: Use MultiStatementTransactionRequestsSender to send any ready sub-batches to targeted
// shard endpoints.
@@ -69,6 +82,7 @@ void execute(OperationContext* opCtx,
// errors for ordered writes or transactions.
// 4: Refresh the targeter(s) if we receive a stale config/db error.
+ // TODO(SERVER-72982): Handle targeting errors.
}
// Reassemble the final response based on responses from sub-batches.
@@ -94,14 +108,71 @@ BulkWriteOp::BulkWriteOp(OperationContext* opCtx, const BulkWriteCommandRequest&
StatusWith<bool> BulkWriteOp::target(
const std::vector<std::unique_ptr<NSTargeter>>& targeters,
- stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>>* targetedBatches) {
- return false;
+ bool recordTargetErrors,
+ stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>>& targetedBatches) {
+ const auto ordered = _clientRequest.getOrdered();
+
+ // Used to track the shard endpoints (w/ shardVersion) we targeted and the batches to each of
+ // these shard endpoints.
+ TargetedBatchMap batchMap;
+
+ // Used to track the set of shardIds (w/o shardVersion) we targeted.
+ std::set<ShardId> targetedShards;
+
+ auto targetStatus = targetWriteOps(_opCtx,
+ _writeOps,
+ ordered,
+ recordTargetErrors,
+ // getTargeterFn:
+ [&](const WriteOp& writeOp) -> const NSTargeter& {
+ const auto opIdx = writeOp.getWriteItem().getItemIndex();
+ // TODO(SERVER-73281): Support bulkWrite update and
+ // delete.
+ const auto nsIdx =
+ _clientRequest.getOps()[opIdx].getInsert();
+ return *targeters[nsIdx];
+ },
+ // getWriteSizeFn:
+ [&](const WriteOp& writeOp) {
+ // TODO(SERVER-73536): Account for the size of the
+ // outgoing request.
+ return 1;
+ },
+ batchMap);
+
+ if (!targetStatus.isOK()) {
+ return targetStatus;
+ }
+
+ // Send back our targeted batches.
+ for (TargetedBatchMap::iterator it = batchMap.begin(); it != batchMap.end(); ++it) {
+ auto batch = std::move(it->second);
+ if (batch->getWrites().empty())
+ continue;
+
+ invariant(targetedBatches.find(batch->getEndpoint().shardName) == targetedBatches.end());
+ targetedBatches.emplace(batch->getEndpoint().shardName, std::move(batch));
+ }
+
+ return targetStatus;
}
-bool BulkWriteOp::isFinished() {
+bool BulkWriteOp::isFinished() const {
// TODO: Track ops lifetime.
+ const bool ordered = _clientRequest.getOrdered();
+ for (auto& writeOp : _writeOps) {
+ if (writeOp.getWriteState() < WriteOpState_Completed) {
+ return false;
+ } else if (ordered && writeOp.getWriteState() == WriteOpState_Error) {
+ return true;
+ }
+ }
return true;
}
+
+const WriteOp& BulkWriteOp::getWriteOp_forTest(int i) const {
+ return _writeOps[i];
+}
} // namespace bulkWriteExec
} // namespace mongo
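
The early return in the new isFinished() encodes the ordered-write contract: once an ordered op has errored, no later Ready op will ever execute, so the bulkWrite is done. A standalone model of the loop; the relative enum ordering Ready < Pending < Completed < Error is an assumption inferred from the '<' comparison in the patch:

    #include <vector>

    enum State { kReady, kPending, kCompleted, kError };  // assumed WriteOpState ordering

    bool isFinishedModel(const std::vector<State>& ops, bool ordered) {
        for (State s : ops) {
            if (s < kCompleted)
                return false;  // a Ready/Pending op still needs processing
            if (ordered && s == kError)
                return true;   // ordered writes stop at the first error
        }
        return true;
    }

    // isFinishedModel({kCompleted, kError, kReady}, /*ordered=*/true)  -> true
    // isFinishedModel({kCompleted, kError, kReady}, /*ordered=*/false) -> false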
diff --git a/src/mongo/s/write_ops/bulk_write_exec.h b/src/mongo/s/write_ops/bulk_write_exec.h
index 1d01745284f..5333153b55d 100644
--- a/src/mongo/s/write_ops/bulk_write_exec.h
+++ b/src/mongo/s/write_ops/bulk_write_exec.h
@@ -37,7 +37,6 @@
#include "mongo/s/write_ops/write_op.h"
namespace mongo {
-
namespace bulkWriteExec {
/**
* Executes a client bulkWrite request by sending child batches to several shard endpoints, and
@@ -84,15 +83,35 @@ public:
BulkWriteOp(OperationContext* opCtx, const BulkWriteCommandRequest& clientRequest);
~BulkWriteOp() = default;
- // TODO(SERVER-72787): Finish this.
+ /**
+ * Targets one or more of the next write ops in this bulkWrite request using the given
+ * NSTargeters (targeters[i] corresponds to the targeter of the collection in nsInfo[i]). The
+ * resulting TargetedWrites are aggregated together in the returned TargetedWriteBatches.
+ *
+ * If 'recordTargetErrors' is false, any targeting error will abort all current batches and
+ * the method will return the targeting error. No targetedBatches will be returned on error.
+ *
+ * Otherwise, if 'recordTargetErrors' is true, targeting errors will be recorded for each
+ * write op that fails to target, and the method will return OK.
+ *
+ * (The idea here is that if we are sure our NSTargeters are up-to-date we should record
+ * targeting errors, but if not we should refresh once first.)
+ *
+ * Returned TargetedWriteBatches are owned by the caller.
+ * If a write without a shard key is detected, the method returns an OK StatusWith whose value
+ * is 'true'.
+ */
StatusWith<bool> target(
const std::vector<std::unique_ptr<NSTargeter>>& targeters,
- stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>>* targetedBatches);
+ bool recordTargetErrors,
+ stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>>& targetedBatches);
/**
* Returns false if the bulk write op needs more processing.
*/
- bool isFinished();
+ bool isFinished() const;
+
+ const WriteOp& getWriteOp_forTest(int i) const;
private:
// The OperationContext the client bulkWrite request is run on.
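
Taken together with execute() above, recordTargetErrors implements a refresh-once-then-record policy: the first targeting failure marks every targeter stale and retries, and only a failure after that refresh is recorded on the write op. A condensed sketch of that driver loop, with response handling elided to match the TODOs in this patch:

    bool refreshedTargeter = false;
    while (!bulkWriteOp.isFinished()) {
        stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>> childBatches;
        auto targetStatus = bulkWriteOp.target(targeters, refreshedTargeter, childBatches);
        if (!targetStatus.isOK()) {
            for (auto& targeter : targeters)
                targeter->noteCouldNotTarget();  // mark stale; next pass records the errors
            refreshedTargeter = true;
            continue;
        }
        // ... send childBatches and process responses (elided; see TODOs above) ...
    }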
diff --git a/src/mongo/s/write_ops/bulk_write_exec_test.cpp b/src/mongo/s/write_ops/bulk_write_exec_test.cpp
new file mode 100644
index 00000000000..ad95d1e7cfb
--- /dev/null
+++ b/src/mongo/s/write_ops/bulk_write_exec_test.cpp
@@ -0,0 +1,328 @@
+/**
+ * Copyright (C) 2023-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/idl/server_parameter_test_util.h"
+#include "mongo/s/catalog_cache_test_fixture.h"
+#include "mongo/s/concurrency/locker_mongos_client_observer.h"
+#include "mongo/s/mock_ns_targeter.h"
+#include "mongo/s/session_catalog_router.h"
+#include "mongo/s/sharding_router_test_fixture.h"
+#include "mongo/s/transaction_router.h"
+#include "mongo/s/write_ops/batch_write_op.h"
+#include "mongo/s/write_ops/batched_command_request.h"
+#include "mongo/s/write_ops/bulk_write_exec.h"
+#include "mongo/unittest/unittest.h"
+
+namespace mongo {
+namespace {
+
+auto initTargeterFullRange(const NamespaceString& nss, const ShardEndpoint& endpoint) {
+ std::vector<MockRange> range{MockRange(endpoint, BSON("x" << MINKEY), BSON("x" << MAXKEY))};
+ return std::make_unique<MockNSTargeter>(nss, std::move(range));
+}
+
+auto initTargeterSplitRange(const NamespaceString& nss,
+ const ShardEndpoint& endpointA,
+ const ShardEndpoint& endpointB) {
+ std::vector<MockRange> range{MockRange(endpointA, BSON("x" << MINKEY), BSON("x" << 0)),
+ MockRange(endpointB, BSON("x" << 0), BSON("x" << MAXKEY))};
+ return std::make_unique<MockNSTargeter>(nss, std::move(range));
+}
+
+auto initTargeterHalfRange(const NamespaceString& nss, const ShardEndpoint& endpoint) {
+ // x >= 0 values are untargetable
+ std::vector<MockRange> range{MockRange(endpoint, BSON("x" << MINKEY), BSON("x" << 0))};
+ return std::make_unique<MockNSTargeter>(nss, std::move(range));
+}
+
+using namespace bulkWriteExec;
+
+class BulkWriteOpTest : public ServiceContextTest {
+protected:
+ BulkWriteOpTest() {
+ auto service = getServiceContext();
+ service->registerClientObserver(std::make_unique<LockerMongosClientObserver>());
+ _opCtxHolder = makeOperationContext();
+ _opCtx = _opCtxHolder.get();
+ }
+
+ ServiceContext::UniqueOperationContext _opCtxHolder;
+ OperationContext* _opCtx;
+};
+
+// Test targeting a single op in a bulkWrite request.
+TEST_F(BulkWriteOpTest, TargetSingleOp) {
+ NamespaceString nss("foo.bar");
+ ShardEndpoint endpoint(ShardId("shard"), ShardVersion::IGNORED(), boost::none);
+
+ std::vector<std::unique_ptr<NSTargeter>> targeters;
+ targeters.push_back(initTargeterFullRange(nss, endpoint));
+
+ BulkWriteCommandRequest request({BulkWriteInsertOp(0, BSON("x" << 1))},
+ {NamespaceInfoEntry(nss)});
+
+ BulkWriteOp bulkWriteOp(_opCtx, request);
+
+ stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>> targeted;
+ ASSERT_OK(bulkWriteOp.target(targeters, false, targeted));
+ ASSERT_EQUALS(targeted.size(), 1u);
+ assertEndpointsEqual(targeted.begin()->second->getEndpoint(), endpoint);
+ ASSERT_EQUALS(targeted.begin()->second->getWrites().size(), 1u);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(0).getWriteState(), WriteOpState_Pending);
+}
+
+// Test targeting a single op that encounters a target error.
+TEST_F(BulkWriteOpTest, TargetSingleOpError) {
+ NamespaceString nss("foo.bar");
+ ShardEndpoint endpoint(ShardId("shard"), ShardVersion::IGNORED(), boost::none);
+
+ std::vector<std::unique_ptr<NSTargeter>> targeters;
+ // Initialize the targeter so that x >= 0 values are untargetable, causing the target call to
+ // encounter an error.
+ targeters.push_back(initTargeterHalfRange(nss, endpoint));
+
+ BulkWriteCommandRequest request({BulkWriteInsertOp(0, BSON("x" << 1))},
+ {NamespaceInfoEntry(nss)});
+
+ BulkWriteOp bulkWriteOp(_opCtx, request);
+
+ stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>> targeted;
+ // target() should return a target error when recordTargetErrors = false.
+ ASSERT_NOT_OK(bulkWriteOp.target(targeters, false, targeted));
+ ASSERT_EQUALS(targeted.size(), 0u);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(0).getWriteState(), WriteOpState_Ready);
+
+ // target() should transition the writeOp to an error state upon a target error when
+ // recordTargetErrors = true.
+ ASSERT_OK(bulkWriteOp.target(targeters, true, targeted));
+ ASSERT_EQUALS(targeted.size(), 0u);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(0).getWriteState(), WriteOpState_Error);
+}
+
+// Test multiple ordered ops that target the same shard.
+TEST_F(BulkWriteOpTest, TargetMultiOpsOrdered_SameShard) {
+ NamespaceString nss0("foo.bar");
+ NamespaceString nss1("bar.foo");
+ ShardEndpoint endpoint(ShardId("shard"), ShardVersion::IGNORED(), boost::none);
+
+ std::vector<std::unique_ptr<NSTargeter>> targeters;
+ targeters.push_back(initTargeterFullRange(nss0, endpoint));
+ targeters.push_back(initTargeterFullRange(nss1, endpoint));
+
+ BulkWriteCommandRequest request(
+ {BulkWriteInsertOp(1, BSON("x" << 1)), BulkWriteInsertOp(0, BSON("x" << 2))},
+ {NamespaceInfoEntry(nss0), NamespaceInfoEntry(nss1)});
+
+ BulkWriteOp bulkWriteOp(_opCtx, request);
+
+ stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>> targeted;
+ ASSERT_OK(bulkWriteOp.target(targeters, false, targeted));
+ ASSERT_EQUALS(targeted.size(), 1u);
+ assertEndpointsEqual(targeted.begin()->second->getEndpoint(), endpoint);
+ ASSERT_EQUALS(targeted.begin()->second->getWrites().size(), 2u);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(0).getWriteState(), WriteOpState_Pending);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(1).getWriteState(), WriteOpState_Pending);
+}
+
+// Test multiple ordered ops where one of them results in a target error.
+TEST_F(BulkWriteOpTest, TargetMultiOpsOrdered_RecordTargetErrors) {
+ NamespaceString nss0("foo.bar");
+ NamespaceString nss1("bar.foo");
+ ShardEndpoint endpoint(ShardId("shard"), ShardVersion::IGNORED(), boost::none);
+
+ std::vector<std::unique_ptr<NSTargeter>> targeters;
+ // Initialize the targeter so that x >= 0 values are untargetable, causing the target call to
+ // encounter an error.
+ targeters.push_back(initTargeterHalfRange(nss0, endpoint));
+ targeters.push_back(initTargeterFullRange(nss1, endpoint));
+
+ // Only the second op would get a target error.
+ BulkWriteCommandRequest request({BulkWriteInsertOp(1, BSON("x" << 1)),
+ BulkWriteInsertOp(0, BSON("x" << 2)),
+ BulkWriteInsertOp(0, BSON("x" << -1))},
+ {NamespaceInfoEntry(nss0), NamespaceInfoEntry(nss1)});
+
+ BulkWriteOp bulkWriteOp(_opCtx, request);
+
+ stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>> targeted;
+ ASSERT_OK(bulkWriteOp.target(targeters, true, targeted));
+
+ // Only the first op should be targeted, as the second op encounters a target error. The error
+ // is not recorded yet, because the first op could still fail during execution before the
+ // second op is ever attempted.
+ ASSERT_EQUALS(targeted.size(), 1u);
+ assertEndpointsEqual(targeted.begin()->second->getEndpoint(), endpoint);
+ ASSERT_EQUALS(targeted.begin()->second->getWrites().size(), 1u);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(0).getWriteState(), WriteOpState_Pending);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(1).getWriteState(), WriteOpState_Ready);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(2).getWriteState(), WriteOpState_Ready);
+
+ targeted.clear();
+
+ // Pretending the first op was done successfully, the target error should be recorded in the
+ // second op.
+ ASSERT_OK(bulkWriteOp.target(targeters, true, targeted));
+ ASSERT_EQUALS(targeted.size(), 0u);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(1).getWriteState(), WriteOpState_Error);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(2).getWriteState(), WriteOpState_Ready);
+}
+
+// Test multiple ordered ops that target two different shards.
+TEST_F(BulkWriteOpTest, TargetMultiOpsOrdered_DifferentShard) {
+ NamespaceString nss0("foo.bar");
+ NamespaceString nss1("bar.foo");
+ ShardEndpoint endpointA(ShardId("shardA"), ShardVersion::IGNORED(), boost::none);
+ ShardEndpoint endpointB(ShardId("shardB"), ShardVersion::IGNORED(), boost::none);
+
+ std::vector<std::unique_ptr<NSTargeter>> targeters;
+ targeters.push_back(initTargeterSplitRange(nss0, endpointA, endpointB));
+ targeters.push_back(initTargeterFullRange(nss1, endpointA));
+
+ // ops[0] -> shardA
+ // ops[1] -> shardB
+ // ops[2] -> shardA
+ BulkWriteCommandRequest request({BulkWriteInsertOp(0, BSON("x" << -1)),
+ BulkWriteInsertOp(0, BSON("x" << 1)),
+ BulkWriteInsertOp(1, BSON("x" << 1))},
+ {NamespaceInfoEntry(nss0), NamespaceInfoEntry(nss1)});
+
+ BulkWriteOp bulkWriteOp(_opCtx, request);
+
+ stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>> targeted;
+
+ // The resulting batch should be {shardA: [ops[0]]}.
+ ASSERT_OK(bulkWriteOp.target(targeters, false, targeted));
+ ASSERT_EQUALS(targeted.size(), 1u);
+ assertEndpointsEqual(targeted.begin()->second->getEndpoint(), endpointA);
+ ASSERT_EQUALS(targeted.begin()->second->getWrites().size(), 1u);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(0).getWriteState(), WriteOpState_Pending);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(1).getWriteState(), WriteOpState_Ready);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(2).getWriteState(), WriteOpState_Ready);
+
+ targeted.clear();
+
+ // The resulting batch should be {shardB: [ops[1]]}.
+ ASSERT_OK(bulkWriteOp.target(targeters, false, targeted));
+ ASSERT_EQUALS(targeted.size(), 1u);
+ assertEndpointsEqual(targeted.begin()->second->getEndpoint(), endpointB);
+ ASSERT_EQUALS(targeted.begin()->second->getWrites().size(), 1u);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(0).getWriteState(), WriteOpState_Pending);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(1).getWriteState(), WriteOpState_Pending);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(2).getWriteState(), WriteOpState_Ready);
+
+ targeted.clear();
+
+ // The resulting batch should be {shardA: [ops[2]]}.
+ ASSERT_OK(bulkWriteOp.target(targeters, false, targeted));
+ ASSERT_EQUALS(targeted.size(), 1u);
+ assertEndpointsEqual(targeted.begin()->second->getEndpoint(), endpointA);
+ ASSERT_EQUALS(targeted.begin()->second->getWrites().size(), 1u);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(0).getWriteState(), WriteOpState_Pending);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(1).getWriteState(), WriteOpState_Pending);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(2).getWriteState(), WriteOpState_Pending);
+}
+
+// Test multiple unordered ops that target two different shards.
+TEST_F(BulkWriteOpTest, TargetMultiOpsUnordered) {
+ NamespaceString nss0("foo.bar");
+ NamespaceString nss1("bar.foo");
+ ShardEndpoint endpointA(ShardId("shardA"), ShardVersion::IGNORED(), boost::none);
+ ShardEndpoint endpointB(ShardId("shardB"), ShardVersion::IGNORED(), boost::none);
+
+ std::vector<std::unique_ptr<NSTargeter>> targeters;
+ targeters.push_back(initTargeterSplitRange(nss0, endpointA, endpointB));
+ targeters.push_back(initTargeterFullRange(nss1, endpointA));
+
+ // ops[0] -> shardA
+ // ops[1] -> shardB
+ // ops[2] -> shardA
+ BulkWriteCommandRequest request({BulkWriteInsertOp(0, BSON("x" << -1)),
+ BulkWriteInsertOp(0, BSON("x" << 1)),
+ BulkWriteInsertOp(1, BSON("x" << 1))},
+ {NamespaceInfoEntry(nss0), NamespaceInfoEntry(nss1)});
+ request.setOrdered(false);
+
+ BulkWriteOp bulkWriteOp(_opCtx, request);
+
+ // The two resulting batches should be:
+ // {shardA: [ops[0], ops[2]]}
+ // {shardB: [ops[1]]}
+ stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>> targeted;
+ ASSERT_OK(bulkWriteOp.target(targeters, false, targeted));
+ ASSERT_EQUALS(targeted.size(), 2u);
+
+ ASSERT_EQUALS(targeted[ShardId("shardA")]->getWrites().size(), 2u);
+ ASSERT_EQUALS(targeted[ShardId("shardA")]->getWrites()[0]->writeOpRef.first, 0);
+ ASSERT_EQUALS(targeted[ShardId("shardA")]->getWrites()[1]->writeOpRef.first, 2);
+
+ ASSERT_EQUALS(targeted[ShardId("shardB")]->getWrites().size(), 1u);
+ ASSERT_EQUALS(targeted[ShardId("shardB")]->getWrites()[0]->writeOpRef.first, 1);
+
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(0).getWriteState(), WriteOpState_Pending);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(1).getWriteState(), WriteOpState_Pending);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(2).getWriteState(), WriteOpState_Pending);
+}
+
+// Test multiple unordered ops where one of them results in a target error.
+TEST_F(BulkWriteOpTest, TargetMultiOpsUnordered_RecordTargetErrors) {
+ NamespaceString nss0("foo.bar");
+ NamespaceString nss1("bar.foo");
+ ShardEndpoint endpoint(ShardId("shard"), ShardVersion::IGNORED(), boost::none);
+
+ std::vector<std::unique_ptr<NSTargeter>> targeters;
+ // Initialize the targeter so that x >= 0 values are untargetable, causing the target call to
+ // encounter an error.
+ targeters.push_back(initTargeterHalfRange(nss0, endpoint));
+ targeters.push_back(initTargeterFullRange(nss1, endpoint));
+
+ // Only the second op would get a target error.
+ BulkWriteCommandRequest request({BulkWriteInsertOp(1, BSON("x" << 1)),
+ BulkWriteInsertOp(0, BSON("x" << 2)),
+ BulkWriteInsertOp(0, BSON("x" << -1))},
+ {NamespaceInfoEntry(nss0), NamespaceInfoEntry(nss1)});
+ request.setOrdered(false);
+
+ BulkWriteOp bulkWriteOp(_opCtx, request);
+
+ stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>> targeted;
+ ASSERT_OK(bulkWriteOp.target(targeters, true, targeted));
+
+ // In the unordered case, both the first and the third ops should be targeted successfully
+ // despite the targeting error on the second op.
+ ASSERT_EQUALS(targeted.size(), 1u);
+ assertEndpointsEqual(targeted.begin()->second->getEndpoint(), endpoint);
+ ASSERT_EQUALS(targeted.begin()->second->getWrites().size(), 2u);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(0).getWriteState(), WriteOpState_Pending);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(1).getWriteState(), WriteOpState_Error);
+ ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(2).getWriteState(), WriteOpState_Pending);
+}
+
+} // namespace
+
+} // namespace mongo