diff options
author | Lingzhi Deng <lingzhi.deng@mongodb.com> | 2023-02-10 04:14:01 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2023-02-10 08:00:31 +0000 |
commit | 188549516c293b65c2c3cba5ae05a573b2eab460 (patch) | |
tree | 25dd3aa0996ce90573cf6f10335a5391936ed5bc | |
parent | 164d779fe330e254ddd9dbf15c51fe38d5369fb1 (diff) | |
download | mongo-188549516c293b65c2c3cba5ae05a573b2eab460.tar.gz |
SERVER-72787: Implement sub-batching logic for bulkWrite
-rw-r--r-- | src/mongo/s/SConscript | 1 | ||||
-rw-r--r-- | src/mongo/s/chunk_manager.cpp | 40 | ||||
-rw-r--r-- | src/mongo/s/chunk_manager.h | 7 | ||||
-rw-r--r-- | src/mongo/s/write_ops/batch_write_op.cpp | 222 | ||||
-rw-r--r-- | src/mongo/s/write_ops/batch_write_op.h | 19 | ||||
-rw-r--r-- | src/mongo/s/write_ops/bulk_write_exec.cpp | 79 | ||||
-rw-r--r-- | src/mongo/s/write_ops/bulk_write_exec.h | 27 | ||||
-rw-r--r-- | src/mongo/s/write_ops/bulk_write_exec_test.cpp | 328 |
8 files changed, 596 insertions, 127 deletions
diff --git a/src/mongo/s/SConscript b/src/mongo/s/SConscript index 50dc50966c2..4cfd659e533 100644 --- a/src/mongo/s/SConscript +++ b/src/mongo/s/SConscript @@ -686,6 +686,7 @@ env.CppUnitTest( 'write_ops/batch_write_op_test.cpp', 'write_ops/batched_command_request_test.cpp', 'write_ops/batched_command_response_test.cpp', + 'write_ops/bulk_write_exec_test.cpp', 'write_ops/write_op_test.cpp', 'write_ops/write_without_shard_key_util_test.cpp', ], diff --git a/src/mongo/s/chunk_manager.cpp b/src/mongo/s/chunk_manager.cpp index 0c213a7c5d3..dee2b74ed49 100644 --- a/src/mongo/s/chunk_manager.cpp +++ b/src/mongo/s/chunk_manager.cpp @@ -726,4 +726,44 @@ ShardEndpoint::ShardEndpoint(const ShardId& shardName, invariant(shardName == ShardId::kConfigServerId); } +bool EndpointComp::operator()(const ShardEndpoint* endpointA, + const ShardEndpoint* endpointB) const { + const int shardNameDiff = endpointA->shardName.compare(endpointB->shardName); + if (shardNameDiff) + return shardNameDiff < 0; + + if (endpointA->shardVersion && endpointB->shardVersion) { + const int epochDiff = endpointA->shardVersion->placementVersion().epoch().compare( + endpointB->shardVersion->placementVersion().epoch()); + if (epochDiff) + return epochDiff < 0; + + const int shardVersionDiff = endpointA->shardVersion->placementVersion().toLong() - + endpointB->shardVersion->placementVersion().toLong(); + if (shardVersionDiff) + return shardVersionDiff < 0; + } else if (!endpointA->shardVersion && !endpointB->shardVersion) { + // TODO (SERVER-51070): Can only happen if the destination is the config server + return false; + } else { + // TODO (SERVER-51070): Can only happen if the destination is the config server + return !endpointA->shardVersion && endpointB->shardVersion; + } + + if (endpointA->databaseVersion && endpointB->databaseVersion) { + const int uuidDiff = + endpointA->databaseVersion->getUuid().compare(endpointB->databaseVersion->getUuid()); + if (uuidDiff) + return uuidDiff < 0; + + return 
endpointA->databaseVersion->getLastMod() < endpointB->databaseVersion->getLastMod(); + } else if (!endpointA->databaseVersion && !endpointB->databaseVersion) { + return false; + } else { + return !endpointA->databaseVersion && endpointB->databaseVersion; + } + + MONGO_UNREACHABLE; +} + } // namespace mongo diff --git a/src/mongo/s/chunk_manager.h b/src/mongo/s/chunk_manager.h index e6d1d3fe8a7..002e557a410 100644 --- a/src/mongo/s/chunk_manager.h +++ b/src/mongo/s/chunk_manager.h @@ -487,6 +487,13 @@ struct ShardEndpoint { }; /** + * Compares shard endpoints in a map. + */ +struct EndpointComp { + bool operator()(const ShardEndpoint* endpointA, const ShardEndpoint* endpointB) const; +}; + +/** * Wrapper around a RoutingTableHistory, which pins it to a particular point in time. */ class ChunkManager { diff --git a/src/mongo/s/write_ops/batch_write_op.cpp b/src/mongo/s/write_ops/batch_write_op.cpp index 5b2df548795..ddb0f7f7b55 100644 --- a/src/mongo/s/write_ops/batch_write_op.cpp +++ b/src/mongo/s/write_ops/batch_write_op.cpp @@ -240,23 +240,13 @@ int getEncryptionInformationSize(const BatchedCommandRequest& req) { } // namespace -BatchWriteOp::BatchWriteOp(OperationContext* opCtx, const BatchedCommandRequest& clientRequest) - : _opCtx(opCtx), - _clientRequest(clientRequest), - _batchTxnNum(_opCtx->getTxnNumber()), - _inTransaction(bool(TransactionRouter::get(opCtx))), - _isRetryableWrite(opCtx->isRetryableWrite()) { - _writeOps.reserve(_clientRequest.sizeWriteOps()); - - for (size_t i = 0; i < _clientRequest.sizeWriteOps(); ++i) { - _writeOps.emplace_back(BatchItemRef(&_clientRequest, i), _inTransaction); - } -} - -StatusWith<bool> BatchWriteOp::targetBatch( - const NSTargeter& targeter, - bool recordTargetErrors, - std::map<ShardId, std::unique_ptr<TargetedWriteBatch>>* targetedBatches) { +StatusWith<bool> targetWriteOps(OperationContext* opCtx, + std::vector<WriteOp>& writeOps, + bool ordered, + bool recordTargetErrors, + GetTargeterFn getTargeterFn, + 
GetWriteSizeFn getWriteSizeFn, + TargetedBatchMap& batchMap) { // // Targeting of unordered batches is fairly simple - each remaining write op is targeted, // and each of those targeted writes are grouped into a batch for a particular shard @@ -290,19 +280,13 @@ StatusWith<bool> BatchWriteOp::targetBatch( // [{ skey : y }, { skey : z }] // - const bool ordered = _clientRequest.getWriteCommandRequestBase().getOrdered(); - bool isWriteWithoutShardKeyOrId = false; - TargetedBatchMap batchMap; + // Used to track the set of shardIds (w/o shardVersion) we targeted. std::set<ShardId> targetedShards; - const size_t numWriteOps = _clientRequest.sizeWriteOps(); - - for (size_t i = 0; i < numWriteOps; ++i) { - WriteOp& writeOp = _writeOps[i]; - - // Only target _Ready ops + for (auto& writeOp : writeOps) { + // Only target Ready op. if (writeOp.getWriteState() != WriteOpState_Ready) continue; @@ -312,34 +296,45 @@ StatusWith<bool> BatchWriteOp::targetBatch( break; } - // - // Get TargetedWrites from the targeter for the write operation - // - // TargetedWrites need to be owned once returned + const auto& targeter = getTargeterFn(writeOp); std::vector<std::unique_ptr<TargetedWrite>> writes; - - Status targetStatus = Status::OK(); - - try { - writeOp.targetWrites(_opCtx, targeter, &writes); - } catch (const DBException& ex) { - targetStatus = ex.toStatus(); - } + auto targetStatus = [&] { + try { + writeOp.targetWrites(opCtx, targeter, &writes); + return Status::OK(); + } catch (const DBException& ex) { + return ex.toStatus(); + } + }(); if (!targetStatus.isOK()) { write_ops::WriteError targetError(0, targetStatus); - if (TransactionRouter::get(_opCtx)) { + auto cancelBatches = [&](const write_ops::WriteError& why) { + for (TargetedBatchMap::iterator it = batchMap.begin(); it != batchMap.end();) { + for (auto&& write : it->second->getWrites()) { + // NOTE: We may repeatedly cancel a write op here, but that's fast and we + // want to cancel before erasing the TargetedWrite* 
(which owns the + // cancelled targeting info) for reporting reasons. + writeOps[write->writeOpRef.first].cancelWrites(&why); + } + + it = batchMap.erase(it); + } + dassert(batchMap.empty()); + }; + + if (TransactionRouter::get(opCtx)) { writeOp.setOpError(targetError); // Cleanup all the writes we have targetted in this call so far since we are going // to abort the entire transaction. - _cancelBatches(targetError, std::move(batchMap)); + cancelBatches(targetError); return targetStatus; } else if (!recordTargetErrors) { // Cancel current batch state with an error - _cancelBatches(targetError, std::move(batchMap)); + cancelBatches(targetError); return targetStatus; } else if (!ordered || batchMap.empty()) { // Record an error for this batch @@ -347,7 +342,7 @@ StatusWith<bool> BatchWriteOp::targetBatch( writeOp.setOpError(targetError); if (ordered) - return StatusWith<bool>(isWriteWithoutShardKeyOrId); + return isWriteWithoutShardKeyOrId; continue; } else { @@ -360,11 +355,8 @@ StatusWith<bool> BatchWriteOp::targetBatch( } } - // - // If ordered and we have a previous endpoint, make sure we don't need to send these - // targeted writes to any other endpoints. - // - + // If writes are ordered and we have a targeted endpoint, make sure we don't need to send + // these targeted writes to any other endpoints. if (ordered && !batchMap.empty()) { dassert(batchMap.size() == 1u); if (isNewBatchRequiredOrdered(writes, batchMap)) { @@ -373,33 +365,16 @@ StatusWith<bool> BatchWriteOp::targetBatch( } } - // If retryable writes are used, MongoS needs to send an additional array of stmtId(s) - // corresponding to the statements that got routed to each individual shard, so they need to - // be accounted in the potential request size so it does not exceed the max BSON size. - // - // The constant 4 is chosen as the size of the BSON representation of the stmtId. 
- const int writeSizeBytes = getWriteSizeBytes(writeOp) + - getEncryptionInformationSize(_clientRequest) + - write_ops::kWriteCommandBSONArrayPerElementOverheadBytes + - (_batchTxnNum ? write_ops::kWriteCommandBSONArrayPerElementOverheadBytes + 4 : 0); - - // For unordered writes, the router must return an entry for each failed write. This - // constant is a pessimistic attempt to ensure that if a request to a shard hits - // "retargeting needed" error and has to return number of errors equivalent to the number of - // writes in the batch, the response size will not exceed the max BSON size. - // - // The constant of 272 is chosen as an approximation of the size of the BSON representation - // of the StaleConfigInfo (which contains the shard id) and the adjacent error message. - const int errorResponsePotentialSizeBytes = - ordered ? 0 : write_ops::kWriteCommandBSONArrayPerElementOverheadBytes + 272; + const auto estWriteSizeBytes = getWriteSizeFn(writeOp); - if (wouldMakeBatchesTooBig( - writes, std::max(writeSizeBytes, errorResponsePotentialSizeBytes), batchMap)) { + if (wouldMakeBatchesTooBig(writes, estWriteSizeBytes, batchMap)) { invariant(!batchMap.empty()); writeOp.cancelWrites(nullptr); break; } + // If writes are unordered and we already have targeted endpoints, make sure we don't target + // the same shard with a different shardVersion. if (!ordered && !batchMap.empty() && isNewBatchRequiredUnordered(writes, batchMap, targetedShards)) { writeOp.cancelWrites(nullptr); @@ -430,7 +405,7 @@ StatusWith<bool> BatchWriteOp::targetBatch( if (!isMultiWrite && write_without_shard_key::useTwoPhaseProtocol( - _opCtx, targeter.getNS(), true /* isUpdateOrDelete */, query, collation)) { + opCtx, targeter.getNS(), true /* isUpdateOrDelete */, query, collation)) { // Writes without shard key should be in their own batch. 
if (!batchMap.empty()) { @@ -455,8 +430,7 @@ StatusWith<bool> BatchWriteOp::targetBatch( targetedShards.insert(endpoint->shardName); } - batchIt->second->addWrite(std::move(write), - std::max(writeSizeBytes, errorResponsePotentialSizeBytes)); + batchIt->second->addWrite(std::move(write), estWriteSizeBytes); } // Relinquish ownership of TargetedWrites, now the TargetedBatches own them @@ -471,6 +445,70 @@ StatusWith<bool> BatchWriteOp::targetBatch( break; } + return isWriteWithoutShardKeyOrId; +} + +BatchWriteOp::BatchWriteOp(OperationContext* opCtx, const BatchedCommandRequest& clientRequest) + : _opCtx(opCtx), + _clientRequest(clientRequest), + _batchTxnNum(_opCtx->getTxnNumber()), + _inTransaction(bool(TransactionRouter::get(opCtx))), + _isRetryableWrite(opCtx->isRetryableWrite()) { + _writeOps.reserve(_clientRequest.sizeWriteOps()); + + for (size_t i = 0; i < _clientRequest.sizeWriteOps(); ++i) { + _writeOps.emplace_back(BatchItemRef(&_clientRequest, i), _inTransaction); + } +} + +StatusWith<bool> BatchWriteOp::targetBatch( + const NSTargeter& targeter, + bool recordTargetErrors, + std::map<ShardId, std::unique_ptr<TargetedWriteBatch>>* targetedBatches) { + const bool ordered = _clientRequest.getWriteCommandRequestBase().getOrdered(); + + // Used to track the shard endpoints (w/ shardVersion) we targeted and the batches to each of + // these shard endpoints. + TargetedBatchMap batchMap; + + auto targetStatus = targetWriteOps( + _opCtx, + _writeOps, + ordered, + recordTargetErrors, + // getTargeterFn: + [&](const WriteOp& writeOp) -> const NSTargeter& { return targeter; }, + // getWriteSizeFn: + [&](const WriteOp& writeOp) { + // If retryable writes are used, MongoS needs to send an additional array of stmtId(s) + // corresponding to the statements that got routed to each individual shard, so they + // need to be accounted in the potential request size so it does not exceed the max BSON + // size. 
+ // + // The constant 4 is chosen as the size of the BSON representation of the stmtId. + const int writeSizeBytes = getWriteSizeBytes(writeOp) + + getEncryptionInformationSize(_clientRequest) + + write_ops::kWriteCommandBSONArrayPerElementOverheadBytes + + (_batchTxnNum ? write_ops::kWriteCommandBSONArrayPerElementOverheadBytes + 4 : 0); + + // For unordered writes, the router must return an entry for each failed write. This + // constant is a pessimistic attempt to ensure that if a request to a shard hits + // "retargeting needed" error and has to return number of errors equivalent to the + // number of writes in the batch, the response size will not exceed the max BSON size. + // + // The constant of 272 is chosen as an approximation of the size of the BSON + // representation of the StaleConfigInfo (which contains the shard id) and the adjacent + // error message. + const int errorResponsePotentialSizeBytes = + ordered ? 0 : write_ops::kWriteCommandBSONArrayPerElementOverheadBytes + 272; + return std::max(writeSizeBytes, errorResponsePotentialSizeBytes); + }, + batchMap); + + if (!targetStatus.isOK()) { + return targetStatus; + } + // // Send back our targeted batches // @@ -490,7 +528,7 @@ StatusWith<bool> BatchWriteOp::targetBatch( _nShardsOwningChunks = targeter.getNShardsOwningChunks(); - return StatusWith<bool>(isWriteWithoutShardKeyOrId); + return targetStatus; } BatchedCommandRequest BatchWriteOp::buildBatchRequest(const TargetedWriteBatch& targetedBatch, @@ -957,46 +995,6 @@ void BatchWriteOp::_cancelBatches(const write_ops::WriteError& why, } } -bool EndpointComp::operator()(const ShardEndpoint* endpointA, - const ShardEndpoint* endpointB) const { - const int shardNameDiff = endpointA->shardName.compare(endpointB->shardName); - if (shardNameDiff) - return shardNameDiff < 0; - - if (endpointA->shardVersion && endpointB->shardVersion) { - const int epochDiff = endpointA->shardVersion->placementVersion().epoch().compare( - 
endpointB->shardVersion->placementVersion().epoch()); - if (epochDiff) - return epochDiff < 0; - - const int shardVersionDiff = endpointA->shardVersion->placementVersion().toLong() - - endpointB->shardVersion->placementVersion().toLong(); - if (shardVersionDiff) - return shardVersionDiff < 0; - } else if (!endpointA->shardVersion && !endpointB->shardVersion) { - // TODO (SERVER-51070): Can only happen if the destination is the config server - return false; - } else { - // TODO (SERVER-51070): Can only happen if the destination is the config server - return !endpointA->shardVersion && endpointB->shardVersion; - } - - if (endpointA->databaseVersion && endpointB->databaseVersion) { - const int uuidDiff = - endpointA->databaseVersion->getUuid().compare(endpointB->databaseVersion->getUuid()); - if (uuidDiff) - return uuidDiff < 0; - - return endpointA->databaseVersion->getLastMod() < endpointB->databaseVersion->getLastMod(); - } else if (!endpointA->databaseVersion && !endpointB->databaseVersion) { - return false; - } else { - return !endpointA->databaseVersion && endpointB->databaseVersion; - } - - MONGO_UNREACHABLE; -} - void TrackedErrors::startTracking(int errCode) { dassert(!isTracking(errCode)); _errorMap.emplace(errCode, std::vector<ShardError>()); diff --git a/src/mongo/s/write_ops/batch_write_op.h b/src/mongo/s/write_ops/batch_write_op.h index b828031bacf..2fde01033f9 100644 --- a/src/mongo/s/write_ops/batch_write_op.h +++ b/src/mongo/s/write_ops/batch_write_op.h @@ -80,13 +80,6 @@ struct ShardWCError { WriteConcernErrorDetail error; }; -/** - * Compares endpoints in a map. 
- */ -struct EndpointComp { - bool operator()(const ShardEndpoint* endpointA, const ShardEndpoint* endpointB) const; -}; - using TargetedBatchMap = std::map<const ShardEndpoint*, std::unique_ptr<TargetedWriteBatch>, EndpointComp>; @@ -275,4 +268,16 @@ private: TrackedErrorMap _errorMap; }; +typedef std::function<const NSTargeter&(const WriteOp& writeOp)> GetTargeterFn; +typedef std::function<int(const WriteOp& writeOp)> GetWriteSizeFn; + +// Helper function to target ready writeOps. See BatchWriteOp::targetBatch for details. +StatusWith<bool> targetWriteOps(OperationContext* opCtx, + std::vector<WriteOp>& writeOps, + bool ordered, + bool recordTargetErrors, + GetTargeterFn getTargeterFn, + GetWriteSizeFn getWriteSizeFn, + TargetedBatchMap& batchMap); + } // namespace mongo diff --git a/src/mongo/s/write_ops/bulk_write_exec.cpp b/src/mongo/s/write_ops/bulk_write_exec.cpp index 51016edfae2..47f62bb38f1 100644 --- a/src/mongo/s/write_ops/bulk_write_exec.cpp +++ b/src/mongo/s/write_ops/bulk_write_exec.cpp @@ -36,6 +36,7 @@ #include "mongo/s/client/shard_registry.h" #include "mongo/s/grid.h" #include "mongo/s/transaction_router.h" +#include "mongo/s/write_ops/batch_write_op.h" #include "mongo/s/write_ops/write_without_shard_key_util.h" #define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kSharding @@ -55,11 +56,23 @@ void execute(OperationContext* opCtx, BulkWriteOp bulkWriteOp(opCtx, clientRequest); + bool refreshedTargeter = false; + while (!bulkWriteOp.isFinished()) { // 1: Target remaining ops with the appropriate targeter based on the namespace index and // re-batch ops based on their targeted shard id. 
stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>> childBatches; - auto targetStatus = bulkWriteOp.target(targeters, &childBatches); + + bool recordTargetErrors = refreshedTargeter; + auto targetStatus = bulkWriteOp.target(targeters, recordTargetErrors, childBatches); + if (!targetStatus.isOK()) { + dassert(childBatches.size() == 0u); + // TODO(SERVER-72982): Handle targeting errors. + for (auto& targeter : targeters) { + targeter->noteCouldNotTarget(); + } + refreshedTargeter = true; + } // 2: Use MultiStatementTransactionRequestsSender to send any ready sub-batches to targeted // shard endpoints. @@ -69,6 +82,7 @@ void execute(OperationContext* opCtx, // errors for ordered writes or transactions. // 4: Refresh the targeter(s) if we receive a stale config/db error. + // TODO(SERVER-72982): Handle targeting errors. } // Reassemble the final response based on responses from sub-batches. @@ -94,14 +108,71 @@ BulkWriteOp::BulkWriteOp(OperationContext* opCtx, const BulkWriteCommandRequest& StatusWith<bool> BulkWriteOp::target( const std::vector<std::unique_ptr<NSTargeter>>& targeters, - stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>>* targetedBatches) { - return false; + bool recordTargetErrors, + stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>>& targetedBatches) { + const auto ordered = _clientRequest.getOrdered(); + + // Used to track the shard endpoints (w/ shardVersion) we targeted and the batches to each of + // these shard endpoints. + TargetedBatchMap batchMap; + + // Used to track the set of shardIds (w/o shardVersion) we targeted. + std::set<ShardId> targetedShards; + + auto targetStatus = targetWriteOps(_opCtx, + _writeOps, + ordered, + recordTargetErrors, + // getTargeterFn: + [&](const WriteOp& writeOp) -> const NSTargeter& { + const auto opIdx = writeOp.getWriteItem().getItemIndex(); + // TODO(SERVER-73281): Support bulkWrite update and + // delete. 
+ const auto nsIdx = + _clientRequest.getOps()[opIdx].getInsert(); + return *targeters[nsIdx]; + }, + // getWriteSizeFn: + [&](const WriteOp& writeOp) { + // TODO(SERVER-73536): Account for the size of the + // outgoing request. + return 1; + }, + batchMap); + + if (!targetStatus.isOK()) { + return targetStatus; + } + + // Send back our targeted batches. + for (TargetedBatchMap::iterator it = batchMap.begin(); it != batchMap.end(); ++it) { + auto batch = std::move(it->second); + if (batch->getWrites().empty()) + continue; + + invariant(targetedBatches.find(batch->getEndpoint().shardName) == targetedBatches.end()); + targetedBatches.emplace(batch->getEndpoint().shardName, std::move(batch)); + } + + return targetStatus; } -bool BulkWriteOp::isFinished() { +bool BulkWriteOp::isFinished() const { // TODO: Track ops lifetime. + const bool ordered = _clientRequest.getOrdered(); + for (auto& writeOp : _writeOps) { + if (writeOp.getWriteState() < WriteOpState_Completed) { + return false; + } else if (ordered && writeOp.getWriteState() == WriteOpState_Error) { + return true; + } + } return true; } + +const WriteOp& BulkWriteOp::getWriteOp_forTest(int i) const { + return _writeOps[i]; +} } // namespace bulkWriteExec } // namespace mongo diff --git a/src/mongo/s/write_ops/bulk_write_exec.h b/src/mongo/s/write_ops/bulk_write_exec.h index 1d01745284f..5333153b55d 100644 --- a/src/mongo/s/write_ops/bulk_write_exec.h +++ b/src/mongo/s/write_ops/bulk_write_exec.h @@ -37,7 +37,6 @@ #include "mongo/s/write_ops/write_op.h" namespace mongo { - namespace bulkWriteExec { /** * Executes a client bulkWrite request by sending child batches to several shard endpoints, and @@ -84,15 +83,35 @@ public: BulkWriteOp(OperationContext* opCtx, const BulkWriteCommandRequest& clientRequest); ~BulkWriteOp() = default; - // TODO(SERVER-72787): Finish this. 
+ /** + * Targets one or more of the next write ops in this bulkWrite request using the given + * NSTargeters (targeters[i] corresponds to the targeter of the collection in nsInfo[i]). The + * resulting TargetedWrites are aggregated together in the returned TargetedWriteBatches. + * + * If 'recordTargetErrors' is false, any targeting error will abort all current batches and + * the method will return the targeting error. No targetedBatches will be returned on error. + * + * Otherwise, if 'recordTargetErrors' is true, targeting errors will be recorded for each + * write op that fails to target, and the method will return OK. + * + * (The idea here is that if we are sure our NSTargeters are up-to-date we should record + * targeting errors, but if not we should refresh once first.) + * + * Returned TargetedWriteBatches are owned by the caller. + * If a write without a shard key is detected, return an OK StatusWith that has 'true' as the + * value. + */ StatusWith<bool> target( const std::vector<std::unique_ptr<NSTargeter>>& targeters, - stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>>* targetedBatches); + bool recordTargetErrors, + stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>>& targetedBatches); /** * Returns false if the bulk write op needs more processing. */ - bool isFinished(); + bool isFinished() const; + + const WriteOp& getWriteOp_forTest(int i) const; private: // The OperationContext the client bulkWrite request is run on. diff --git a/src/mongo/s/write_ops/bulk_write_exec_test.cpp b/src/mongo/s/write_ops/bulk_write_exec_test.cpp new file mode 100644 index 00000000000..ad95d1e7cfb --- /dev/null +++ b/src/mongo/s/write_ops/bulk_write_exec_test.cpp @@ -0,0 +1,328 @@ +/** + * Copyright (C) 2023-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#include "mongo/idl/server_parameter_test_util.h" +#include "mongo/s/catalog_cache_test_fixture.h" +#include "mongo/s/concurrency/locker_mongos_client_observer.h" +#include "mongo/s/mock_ns_targeter.h" +#include "mongo/s/session_catalog_router.h" +#include "mongo/s/sharding_router_test_fixture.h" +#include "mongo/s/transaction_router.h" +#include "mongo/s/write_ops/batch_write_op.h" +#include "mongo/s/write_ops/batched_command_request.h" +#include "mongo/s/write_ops/bulk_write_exec.h" +#include "mongo/unittest/unittest.h" + +namespace mongo { +namespace { + +auto initTargeterFullRange(const NamespaceString& nss, const ShardEndpoint& endpoint) { + std::vector<MockRange> range{MockRange(endpoint, BSON("x" << MINKEY), BSON("x" << MAXKEY))}; + return std::make_unique<MockNSTargeter>(nss, std::move(range)); +} + +auto initTargeterSplitRange(const NamespaceString& nss, + const ShardEndpoint& endpointA, + const ShardEndpoint& endpointB) { + std::vector<MockRange> range{MockRange(endpointA, BSON("x" << MINKEY), BSON("x" << 0)), + MockRange(endpointB, BSON("x" << 0), BSON("x" << MAXKEY))}; + return std::make_unique<MockNSTargeter>(nss, std::move(range)); +} + +auto initTargeterHalfRange(const NamespaceString& nss, const ShardEndpoint& endpoint) { + // x >= 0 values are untargetable + std::vector<MockRange> range{MockRange(endpoint, BSON("x" << MINKEY), BSON("x" << 0))}; + return std::make_unique<MockNSTargeter>(nss, std::move(range)); +} + +using namespace bulkWriteExec; + +class BulkWriteOpTest : public ServiceContextTest { +protected: + BulkWriteOpTest() { + auto service = getServiceContext(); + service->registerClientObserver(std::make_unique<LockerMongosClientObserver>()); + _opCtxHolder = makeOperationContext(); + _opCtx = _opCtxHolder.get(); + } + + ServiceContext::UniqueOperationContext _opCtxHolder; + OperationContext* _opCtx; +}; + +// Test targeting a single op in a bulkWrite request. 
+TEST_F(BulkWriteOpTest, TargetSingleOp) { + NamespaceString nss("foo.bar"); + ShardEndpoint endpoint(ShardId("shard"), ShardVersion::IGNORED(), boost::none); + + std::vector<std::unique_ptr<NSTargeter>> targeters; + targeters.push_back(initTargeterFullRange(nss, endpoint)); + + BulkWriteCommandRequest request({BulkWriteInsertOp(0, BSON("x" << 1))}, + {NamespaceInfoEntry(nss)}); + + BulkWriteOp bulkWriteOp(_opCtx, request); + + stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>> targeted; + ASSERT_OK(bulkWriteOp.target(targeters, false, targeted)); + ASSERT_EQUALS(targeted.size(), 1u); + assertEndpointsEqual(targeted.begin()->second->getEndpoint(), endpoint); + ASSERT_EQUALS(targeted.begin()->second->getWrites().size(), 1u); + ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(0).getWriteState(), WriteOpState_Pending); +} + +// Test targeting a single op with target error. +TEST_F(BulkWriteOpTest, TargetSingleOpError) { + NamespaceString nss("foo.bar"); + ShardEndpoint endpoint(ShardId("shard"), ShardVersion::IGNORED(), boost::none); + + std::vector<std::unique_ptr<NSTargeter>> targeters; + // Initialize the targeter so that x >= 0 values are untargetable so target call will encounter + // an error. + targeters.push_back(initTargeterHalfRange(nss, endpoint)); + + BulkWriteCommandRequest request({BulkWriteInsertOp(0, BSON("x" << 1))}, + {NamespaceInfoEntry(nss)}); + + BulkWriteOp bulkWriteOp(_opCtx, request); + + stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>> targeted; + // target should return target error when recordTargetErrors = false. + ASSERT_NOT_OK(bulkWriteOp.target(targeters, false, targeted)); + ASSERT_EQUALS(targeted.size(), 0u); + ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(0).getWriteState(), WriteOpState_Ready); + + // target should transition the writeOp to an error state upon target errors when + // recordTargetErrors = true. 
+ ASSERT_OK(bulkWriteOp.target(targeters, true, targeted)); + ASSERT_EQUALS(targeted.size(), 0u); + ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(0).getWriteState(), WriteOpState_Error); +} + +// Test multiple ordered ops that target the same shard. +TEST_F(BulkWriteOpTest, TargetMultiOpsOrdered_SameShard) { + NamespaceString nss0("foo.bar"); + NamespaceString nss1("bar.foo"); + ShardEndpoint endpoint(ShardId("shard"), ShardVersion::IGNORED(), boost::none); + + std::vector<std::unique_ptr<NSTargeter>> targeters; + targeters.push_back(initTargeterFullRange(nss0, endpoint)); + targeters.push_back(initTargeterFullRange(nss1, endpoint)); + + BulkWriteCommandRequest request( + {BulkWriteInsertOp(1, BSON("x" << 1)), BulkWriteInsertOp(0, BSON("x" << 2))}, + {NamespaceInfoEntry(nss0), NamespaceInfoEntry(nss1)}); + + BulkWriteOp bulkWriteOp(_opCtx, request); + + stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>> targeted; + ASSERT_OK(bulkWriteOp.target(targeters, false, targeted)); + ASSERT_EQUALS(targeted.size(), 1u); + assertEndpointsEqual(targeted.begin()->second->getEndpoint(), endpoint); + ASSERT_EQUALS(targeted.begin()->second->getWrites().size(), 2u); + ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(0).getWriteState(), WriteOpState_Pending); + ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(1).getWriteState(), WriteOpState_Pending); +} + +// Test multiple ordered ops where one of them result in a target error. +TEST_F(BulkWriteOpTest, TargetMultiOpsOrdered_RecordTargetErrors) { + NamespaceString nss0("foo.bar"); + NamespaceString nss1("bar.foo"); + ShardEndpoint endpoint(ShardId("shard"), ShardVersion::IGNORED(), boost::none); + + std::vector<std::unique_ptr<NSTargeter>> targeters; + // Initialize the targeter so that x >= 0 values are untargetable so target call will encounter + // an error. 
+ targeters.push_back(initTargeterHalfRange(nss0, endpoint)); + targeters.push_back(initTargeterFullRange(nss1, endpoint)); + + // Only the second op would get a target error. + BulkWriteCommandRequest request({BulkWriteInsertOp(1, BSON("x" << 1)), + BulkWriteInsertOp(0, BSON("x" << 2)), + BulkWriteInsertOp(0, BSON("x" << -1))}, + {NamespaceInfoEntry(nss0), NamespaceInfoEntry(nss1)}); + + BulkWriteOp bulkWriteOp(_opCtx, request); + + stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>> targeted; + ASSERT_OK(bulkWriteOp.target(targeters, true, targeted)); + + // Only the first op should be targeted as the second op encounters a target error. But this + // won't record the target error since there could be an error in the first op before executing + // the second op. + ASSERT_EQUALS(targeted.size(), 1u); + assertEndpointsEqual(targeted.begin()->second->getEndpoint(), endpoint); + ASSERT_EQUALS(targeted.begin()->second->getWrites().size(), 1u); + ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(0).getWriteState(), WriteOpState_Pending); + ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(1).getWriteState(), WriteOpState_Ready); + ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(2).getWriteState(), WriteOpState_Ready); + + targeted.clear(); + + // Pretending the first op was done successfully, the target error should be recorded in the + // second op. + ASSERT_OK(bulkWriteOp.target(targeters, true, targeted)); + ASSERT_EQUALS(targeted.size(), 0u); + ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(1).getWriteState(), WriteOpState_Error); + ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(2).getWriteState(), WriteOpState_Ready); +} + +// Test multiple ordered ops that target two different shards. 
+TEST_F(BulkWriteOpTest, TargetMultiOpsOrdered_DifferentShard) {
+    // Each call to target() below is expected to yield exactly one batch with a single write,
+    // walking through the ops in request order.
+    NamespaceString nss0("foo.bar");
+    NamespaceString nss1("bar.foo");
+    ShardEndpoint endpointA(ShardId("shardA"), ShardVersion::IGNORED(), boost::none);
+    ShardEndpoint endpointB(ShardId("shardB"), ShardVersion::IGNORED(), boost::none);
+
+    std::vector<std::unique_ptr<NSTargeter>> targeters;
+    targeters.push_back(initTargeterSplitRange(nss0, endpointA, endpointB));
+    targeters.push_back(initTargeterFullRange(nss1, endpointA));
+
+    // ops[0] -> shardA
+    // ops[1] -> shardB
+    // ops[2] -> shardA
+    BulkWriteCommandRequest request({BulkWriteInsertOp(0, BSON("x" << -1)),
+                                     BulkWriteInsertOp(0, BSON("x" << 1)),
+                                     BulkWriteInsertOp(1, BSON("x" << 1))},
+                                    {NamespaceInfoEntry(nss0), NamespaceInfoEntry(nss1)});
+    BulkWriteOp bulkWriteOp(_opCtx, request);
+
+    stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>> targeted;
+
+    // Runs one targeting round from a clean result map and checks that it produced a single
+    // one-write batch for 'expectedEndpoint'. Afterwards the first 'numPending' ops must be
+    // pending and all later ops must still be ready.
+    const auto targetRoundAndCheck = [&](const ShardEndpoint& expectedEndpoint, int numPending) {
+        targeted.clear();
+        ASSERT_OK(bulkWriteOp.target(targeters, false, targeted));
+        ASSERT_EQUALS(targeted.size(), 1u);
+
+        const auto& batch = targeted.begin()->second;
+        assertEndpointsEqual(batch->getEndpoint(), expectedEndpoint);
+        ASSERT_EQUALS(batch->getWrites().size(), 1u);
+
+        for (int opIdx = 0; opIdx < 3; ++opIdx) {
+            const auto expectedState =
+                opIdx < numPending ? WriteOpState_Pending : WriteOpState_Ready;
+            ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(opIdx).getWriteState(), expectedState);
+        }
+    };
+
+    // The resulting batch should be {shardA: [ops[0]]}.
+    targetRoundAndCheck(endpointA, 1);
+
+    // The resulting batch should be {shardB: [ops[1]]}.
+    targetRoundAndCheck(endpointB, 2);
+
+    // The resulting batch should be {shardA: [ops[2]]}.
+    targetRoundAndCheck(endpointA, 3);
+}
+
+// Test multiple unordered ops that target two different shards.
+TEST_F(BulkWriteOpTest, TargetMultiOpsUnordered) {
+    // With ordered=false a single target() call is expected to group the ops per shard, even
+    // though an op for another shard sits between the two shardA ops.
+    NamespaceString nss0("foo.bar");
+    NamespaceString nss1("bar.foo");
+    ShardEndpoint endpointA(ShardId("shardA"), ShardVersion::IGNORED(), boost::none);
+    ShardEndpoint endpointB(ShardId("shardB"), ShardVersion::IGNORED(), boost::none);
+
+    std::vector<std::unique_ptr<NSTargeter>> targeters;
+    targeters.push_back(initTargeterSplitRange(nss0, endpointA, endpointB));
+    targeters.push_back(initTargeterFullRange(nss1, endpointA));
+
+    // ops[0] -> shardA
+    // ops[1] -> shardB
+    // ops[2] -> shardA
+    BulkWriteCommandRequest request({BulkWriteInsertOp(0, BSON("x" << -1)),
+                                     BulkWriteInsertOp(0, BSON("x" << 1)),
+                                     BulkWriteInsertOp(1, BSON("x" << 1))},
+                                    {NamespaceInfoEntry(nss0), NamespaceInfoEntry(nss1)});
+    request.setOrdered(false);
+
+    BulkWriteOp bulkWriteOp(_opCtx, request);
+
+    // The two resulting batches should be:
+    //     {shardA: [ops[0], ops[2]]}
+    //     {shardB: [ops[1]]}
+    stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>> targeted;
+    ASSERT_OK(bulkWriteOp.target(targeters, false, targeted));
+    ASSERT_EQUALS(targeted.size(), 2u);
+
+    // Look the batches up with count()/find() instead of operator[]. For a missing shard id,
+    // operator[] would insert a null batch pointer into the map under test and the following
+    // dereference would crash the test binary rather than report an assertion failure.
+    ASSERT_EQUALS(targeted.count(ShardId("shardA")), 1u);
+    const auto& shardAWrites = targeted.find(ShardId("shardA"))->second->getWrites();
+    ASSERT_EQUALS(shardAWrites.size(), 2u);
+    ASSERT_EQUALS(shardAWrites[0]->writeOpRef.first, 0);
+    ASSERT_EQUALS(shardAWrites[1]->writeOpRef.first, 2);
+
+    ASSERT_EQUALS(targeted.count(ShardId("shardB")), 1u);
+    const auto& shardBWrites = targeted.find(ShardId("shardB"))->second->getWrites();
+    ASSERT_EQUALS(shardBWrites.size(), 1u);
+    ASSERT_EQUALS(shardBWrites[0]->writeOpRef.first, 1);
+
+    // Every op was targeted in this one round, so all of them should be pending.
+    ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(0).getWriteState(), WriteOpState_Pending);
+    ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(1).getWriteState(), WriteOpState_Pending);
+    ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(2).getWriteState(), WriteOpState_Pending);
+}
+
+// Test multiple unordered ops where one of them results in a target error.
+TEST_F(BulkWriteOpTest, TargetMultiOpsUnordered_RecordTargetErrors) { + NamespaceString nss0("foo.bar"); + NamespaceString nss1("bar.foo"); + ShardEndpoint endpoint(ShardId("shard"), ShardVersion::IGNORED(), boost::none); + + std::vector<std::unique_ptr<NSTargeter>> targeters; + // Initialize the targeter so that x >= 0 values are untargetable so target call will encounter + // an error. + targeters.push_back(initTargeterHalfRange(nss0, endpoint)); + targeters.push_back(initTargeterFullRange(nss1, endpoint)); + + // Only the second op would get a target error. + BulkWriteCommandRequest request({BulkWriteInsertOp(1, BSON("x" << 1)), + BulkWriteInsertOp(0, BSON("x" << 2)), + BulkWriteInsertOp(0, BSON("x" << -1))}, + {NamespaceInfoEntry(nss0), NamespaceInfoEntry(nss1)}); + request.setOrdered(false); + + BulkWriteOp bulkWriteOp(_opCtx, request); + + stdx::unordered_map<ShardId, std::unique_ptr<TargetedWriteBatch>> targeted; + ASSERT_OK(bulkWriteOp.target(targeters, true, targeted)); + + // In the unordered case, both the first and the third ops should be targeted successfully + // despite targeting error on the second op. + ASSERT_EQUALS(targeted.size(), 1u); + assertEndpointsEqual(targeted.begin()->second->getEndpoint(), endpoint); + ASSERT_EQUALS(targeted.begin()->second->getWrites().size(), 2u); + ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(0).getWriteState(), WriteOpState_Pending); + ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(1).getWriteState(), WriteOpState_Error); + ASSERT_EQUALS(bulkWriteOp.getWriteOp_forTest(2).getWriteState(), WriteOpState_Pending); +} + +} // namespace + +} // namespace mongo |