diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/mongo/db/commands/run_aggregate.cpp | 22 | ||||
-rw-r--r-- | src/mongo/db/pipeline/aggregate_command.idl | 12 | ||||
-rw-r--r-- | src/mongo/db/pipeline/aggregation_request_helper.cpp | 8 | ||||
-rw-r--r-- | src/mongo/db/pipeline/aggregation_request_helper.h | 5 | ||||
-rw-r--r-- | src/mongo/db/pipeline/aggregation_request_test.cpp | 7 | ||||
-rw-r--r-- | src/mongo/db/query/cursor_response.cpp | 28 | ||||
-rw-r--r-- | src/mongo/db/query/cursor_response.h | 13 | ||||
-rw-r--r-- | src/mongo/db/query/cursor_response.idl | 9 | ||||
-rw-r--r-- | src/mongo/s/commands/cluster_query_without_shard_key_cmd.cpp | 72 | ||||
-rw-r--r-- | src/mongo/s/query/cluster_aggregate.cpp | 5 |
10 files changed, 161 insertions, 20 deletions
diff --git a/src/mongo/db/commands/run_aggregate.cpp b/src/mongo/db/commands/run_aggregate.cpp index b2a81220828..8db9c787f7f 100644 --- a/src/mongo/db/commands/run_aggregate.cpp +++ b/src/mongo/db/commands/run_aggregate.cpp @@ -87,6 +87,7 @@ #include "mongo/db/service_context.h" #include "mongo/db/stats/resource_consumption_metrics.h" #include "mongo/db/storage/storage_options.h" +#include "mongo/db/transaction/transaction_participant.h" #include "mongo/db/views/view.h" #include "mongo/db/views/view_catalog_helpers.h" #include "mongo/logv2/log.h" @@ -571,17 +572,20 @@ std::vector<std::unique_ptr<Pipeline, PipelineDeleter>> createAdditionalPipeline } /** - * Performs validations related to API versioning and time-series stages. + * Performs validations related to API versioning, time-series stages, and general command + * validation. * Throws UserAssertion if any of the validations fails * - validation of API versioning on each stage on the pipeline * - validation of API versioning on 'AggregateCommandRequest' request * - validation of time-series related stages + * - validation of command parameters */ void performValidationChecks(const OperationContext* opCtx, const AggregateCommandRequest& request, const LiteParsedPipeline& liteParsedPipeline) { liteParsedPipeline.validate(opCtx); aggregation_request_helper::validateRequestForAPIVersion(opCtx, request); + aggregation_request_helper::validateRequestFromClusterQueryWithoutShardKey(request); } std::vector<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> createLegacyExecutor( @@ -671,6 +675,22 @@ Status runAggregate(OperationContext* opCtx, // aggregation command. performValidationChecks(opCtx, request, liteParsedPipeline); + // If we are running a retryable write without shard key, check if the write was applied on this + // shard, and if so, return early with an empty cursor with $_wasStatementExecuted + // set to true. 
+ auto isClusterQueryWithoutShardKeyCmd = request.getIsClusterQueryWithoutShardKeyCmd(); + auto stmtId = request.getStmtId(); + if (isClusterQueryWithoutShardKeyCmd && stmtId) { + if (TransactionParticipant::get(opCtx).checkStatementExecuted(opCtx, *stmtId)) { + CursorResponseBuilder::Options options; + options.isInitialResponse = true; + CursorResponseBuilder responseBuilder(result, options); + responseBuilder.setWasStatementExecuted(true); + responseBuilder.done(0LL, origNss); + return Status::OK(); + } + } + // For operations on views, this will be the underlying namespace. NamespaceString nss = request.getNamespace(); diff --git a/src/mongo/db/pipeline/aggregate_command.idl b/src/mongo/db/pipeline/aggregate_command.idl index eccd8b9fb45..4c6b11eef44 100644 --- a/src/mongo/db/pipeline/aggregate_command.idl +++ b/src/mongo/db/pipeline/aggregate_command.idl @@ -296,3 +296,15 @@ commands: type: uuid optional: true stability: unstable + stmtId: + description: "The statement id of the write in the original write batch for a write + without shard key." + type: int + optional: true + stability: internal + $_isClusterQueryWithoutShardKeyCmd: + description: "True if a _clusterQueryWithoutShardKey command is running a broadcast + aggregate." 
+ type: optionalBool + cpp_name: isClusterQueryWithoutShardKeyCmd + stability: internal diff --git a/src/mongo/db/pipeline/aggregation_request_helper.cpp b/src/mongo/db/pipeline/aggregation_request_helper.cpp index 60f19b9cf2e..29a5fe90cd8 100644 --- a/src/mongo/db/pipeline/aggregation_request_helper.cpp +++ b/src/mongo/db/pipeline/aggregation_request_helper.cpp @@ -222,6 +222,14 @@ void validateRequestForAPIVersion(const OperationContext* opCtx, } } +void validateRequestFromClusterQueryWithoutShardKey(const AggregateCommandRequest& request) { + if (request.getIsClusterQueryWithoutShardKeyCmd()) { + uassert(ErrorCodes::InvalidOptions, + "Only mongos can set the isClusterQueryWithoutShardKeyCmd field", + request.getFromMongos()); + } +} + PlanExecutorPipeline::ResumableScanType getResumableScanType(const AggregateCommandRequest& request, bool isChangeStream) { // $changeStream cannot be run on the oplog, and $_requestReshardingResumeToken can only be run diff --git a/src/mongo/db/pipeline/aggregation_request_helper.h b/src/mongo/db/pipeline/aggregation_request_helper.h index dff875ff7f5..63d77287db4 100644 --- a/src/mongo/db/pipeline/aggregation_request_helper.h +++ b/src/mongo/db/pipeline/aggregation_request_helper.h @@ -121,6 +121,11 @@ BSONObj serializeToCommandObj(const AggregateCommandRequest& request); */ void validateRequestForAPIVersion(const OperationContext* opCtx, const AggregateCommandRequest& request); +/** + * Validates if 'AggregateCommandRequest' sets the "isClusterQueryWithoutShardKeyCmd" field then the + * request must have been fromMongos. + */ +void validateRequestFromClusterQueryWithoutShardKey(const AggregateCommandRequest& request); /** * Returns the type of resumable scan required by this aggregation, if applicable. 
Otherwise returns diff --git a/src/mongo/db/pipeline/aggregation_request_test.cpp b/src/mongo/db/pipeline/aggregation_request_test.cpp index ba75b6f28ca..a300aa15128 100644 --- a/src/mongo/db/pipeline/aggregation_request_test.cpp +++ b/src/mongo/db/pipeline/aggregation_request_test.cpp @@ -66,7 +66,7 @@ TEST(AggregationRequestTest, ShouldParseAllKnownOptions) { "collation: {locale: 'en_US'}, cursor: {batchSize: 10}, hint: {a: 1}, maxTimeMS: 100, " "readConcern: {level: 'linearizable'}, $queryOptions: {$readPreference: 'nearest'}, " "exchange: {policy: 'roundrobin', consumers:NumberInt(2)}, isMapReduceCommand: true, $db: " - "'local'}"); + "'local', $_isClusterQueryWithoutShardKeyCmd: true}"); auto uuid = UUID::gen(); BSONObjBuilder uuidBob; uuid.appendToBuilder(&uuidBob, AggregateCommandRequest::kCollectionUUIDFieldName); @@ -80,6 +80,7 @@ TEST(AggregationRequestTest, ShouldParseAllKnownOptions) { ASSERT_TRUE(request.getNeedsMerge()); ASSERT_TRUE(request.getBypassDocumentValidation().value_or(false)); ASSERT_TRUE(request.getRequestReshardingResumeToken()); + ASSERT_TRUE(request.getIsClusterQueryWithoutShardKeyCmd()); ASSERT_EQ( request.getCursor().getBatchSize().value_or(aggregation_request_helper::kDefaultBatchSize), 10); @@ -215,6 +216,7 @@ TEST(AggregationRequestTest, ShouldSerializeOptionalValuesIfSet) { request.setLet(letParamsObj); auto uuid = UUID::gen(); request.setCollectionUUID(uuid); + request.setIsClusterQueryWithoutShardKeyCmd(true); auto expectedSerialization = Document{ {AggregateCommandRequest::kCommandName, nss.coll()}, @@ -232,7 +234,8 @@ TEST(AggregationRequestTest, ShouldSerializeOptionalValuesIfSet) { {query_request_helper::kUnwrappedReadPrefField, readPrefObj}, {AggregateCommandRequest::kRequestReshardingResumeTokenFieldName, true}, {AggregateCommandRequest::kIsMapReduceCommandFieldName, true}, - {AggregateCommandRequest::kCollectionUUIDFieldName, uuid}}; + {AggregateCommandRequest::kCollectionUUIDFieldName, uuid}, + 
{AggregateCommandRequest::kIsClusterQueryWithoutShardKeyCmdFieldName, true}}; ASSERT_DOCUMENT_EQ(aggregation_request_helper::serializeToCommandDoc(request), expectedSerialization); } diff --git a/src/mongo/db/query/cursor_response.cpp b/src/mongo/db/query/cursor_response.cpp index d53a24c7c5f..1940e3def4d 100644 --- a/src/mongo/db/query/cursor_response.cpp +++ b/src/mongo/db/query/cursor_response.cpp @@ -57,6 +57,7 @@ const char kBatchDocSequenceFieldInitial[] = "cursor.firstBatch"; const char kPostBatchResumeTokenField[] = "postBatchResumeToken"; const char kPartialResultsReturnedField[] = "partialResultsReturned"; const char kInvalidatedField[] = "invalidated"; +const char kWasStatementExecuted[] = "$_wasStatementExecuted"; } // namespace @@ -84,6 +85,10 @@ void CursorResponseBuilder::done(CursorId cursorId, const NamespaceString& curso _cursorObject->append(kInvalidatedField, _invalidated); } + if (_wasStatementExecuted) { + _cursorObject->append(kWasStatementExecuted, _wasStatementExecuted); + } + _cursorObject->append(kIdField, cursorId); _cursorObject->append(kNsField, NamespaceStringUtil::serialize(cursorNamespace)); if (_options.atClusterTime) { @@ -140,7 +145,8 @@ CursorResponse::CursorResponse(NamespaceString nss, boost::optional<BSONObj> varsField, boost::optional<std::string> cursorType, bool partialResultsReturned, - bool invalidated) + bool invalidated, + bool wasStatementExecuted) : _nss(std::move(nss)), _cursorId(cursorId), _batch(std::move(batch)), @@ -150,7 +156,8 @@ CursorResponse::CursorResponse(NamespaceString nss, _varsField(std::move(varsField)), _cursorType(std::move(cursorType)), _partialResultsReturned(partialResultsReturned), - _invalidated(invalidated) {} + _invalidated(invalidated), + _wasStatementExecuted(wasStatementExecuted) {} std::vector<StatusWith<CursorResponse>> CursorResponse::parseFromBSONMany( const BSONObj& cmdResponse) { @@ -299,6 +306,16 @@ StatusWith<CursorResponse> CursorResponse::parseFromBSON(const BSONObj& cmdRespo } 
} + auto wasStatementExecuted = cursorObj[kWasStatementExecuted]; + if (wasStatementExecuted) { + if (wasStatementExecuted.type() != BSONType::Bool) { + return {ErrorCodes::BadValue, + str::stream() << kWasStatementExecuted + << " format is invalid; expected Bool, but found: " + << wasStatementExecuted.type()}; + } + } + auto writeConcernError = cmdResponse["writeConcernError"]; if (writeConcernError && writeConcernError.type() != BSONType::Object) { @@ -317,7 +334,8 @@ StatusWith<CursorResponse> CursorResponse::parseFromBSON(const BSONObj& cmdRespo varsElt ? varsElt.Obj().getOwned() : boost::optional<BSONObj>{}, typeElt ? boost::make_optional<std::string>(typeElt.String()) : boost::none, partialResultsReturned.trueValue(), - invalidatedElem.trueValue()}}; + invalidatedElem.trueValue(), + wasStatementExecuted.trueValue()}}; } void CursorResponse::addToBSON(CursorResponse::ResponseType responseType, @@ -351,6 +369,10 @@ void CursorResponse::addToBSON(CursorResponse::ResponseType responseType, cursorBuilder.append(kInvalidatedField, _invalidated); } + if (_wasStatementExecuted) { + cursorBuilder.append(kWasStatementExecuted, _wasStatementExecuted); + } + cursorBuilder.doneFast(); builder->append("ok", 1.0); diff --git a/src/mongo/db/query/cursor_response.h b/src/mongo/db/query/cursor_response.h index 7b33b71e1d5..87f5f22c263 100644 --- a/src/mongo/db/query/cursor_response.h +++ b/src/mongo/db/query/cursor_response.h @@ -99,6 +99,10 @@ public: _invalidated = true; } + void setWasStatementExecuted(bool wasStatementExecuted) { + _wasStatementExecuted = wasStatementExecuted; + } + long long numDocs() const { return _numDocs; } @@ -135,6 +139,7 @@ private: BSONObj _postBatchResumeToken; bool _partialResultsReturned = false; bool _invalidated = false; + bool _wasStatementExecuted = false; }; /** @@ -216,7 +221,8 @@ public: boost::optional<BSONObj> varsField = boost::none, boost::optional<std::string> cursorType = boost::none, bool partialResultsReturned = false, - bool invalidated = 
false); + bool invalidated = false, + bool wasStatementExecuted = false); CursorResponse(CursorResponse&& other) = default; CursorResponse& operator=(CursorResponse&& other) = default; @@ -269,6 +275,10 @@ public: return _invalidated; } + bool getWasStatementExecuted() const { + return _wasStatementExecuted; + } + /** * Converts this response to its raw BSON representation. */ @@ -289,6 +299,7 @@ private: boost::optional<std::string> _cursorType; bool _partialResultsReturned = false; bool _invalidated = false; + bool _wasStatementExecuted = false; }; } // namespace mongo diff --git a/src/mongo/db/query/cursor_response.idl b/src/mongo/db/query/cursor_response.idl index af264710874..391beb38c35 100644 --- a/src/mongo/db/query/cursor_response.idl +++ b/src/mongo/db/query/cursor_response.idl @@ -76,6 +76,15 @@ structs: description: "Boolean represents if the cursor has been invalidated." type: optionalBool stability: stable + $_wasStatementExecuted: + description: "An optional field set to true if a write without shard key had already + applied the write. To provide some context, this internal field is + used by the two phase write without shard key protocol introduced in + PM-1632 to support retryable writes for updateOne without shard key, + deleteOne without shard key, and findAndModify without shard key." + type: optionalBool + cpp_name: wasStatementExecuted + stability: internal InitialResponseCursor: description: "A struct representing an initial response cursor." 
diff --git a/src/mongo/s/commands/cluster_query_without_shard_key_cmd.cpp b/src/mongo/s/commands/cluster_query_without_shard_key_cmd.cpp index 2511a722fe1..4d1f8ed17ec 100644 --- a/src/mongo/s/commands/cluster_query_without_shard_key_cmd.cpp +++ b/src/mongo/s/commands/cluster_query_without_shard_key_cmd.cpp @@ -47,8 +47,18 @@ namespace { struct ParsedCommandInfo { BSONObj query; BSONObj collation; + int stmtId; - ParsedCommandInfo(BSONObj query, BSONObj collation) : query(query), collation(collation) {} + ParsedCommandInfo(BSONObj query, BSONObj collation, int stmtId) + : query(query), collation(collation), stmtId(stmtId) {} +}; + +struct AsyncRequestSenderResponseData { + ShardId shardId; + CursorResponse cursorResponse; + + AsyncRequestSenderResponseData(ShardId shardId, CursorResponse cursorResponse) + : shardId(shardId), cursorResponse(std::move(cursorResponse)) {} }; std::set<ShardId> getShardsToTarget(OperationContext* opCtx, @@ -85,6 +95,12 @@ BSONObj createAggregateCmdObj(OperationContext* opCtx, BSON("$limit" << 1), BSON("$project" << BSON("_id" << 1))}); aggregate.setCollation(parsedInfo.collation); + aggregate.setIsClusterQueryWithoutShardKeyCmd(true); + aggregate.setFromMongos(true); + + if (parsedInfo.stmtId != kUninitializedStmtId) { + aggregate.setStmtId(parsedInfo.stmtId); + } return aggregate.toBSON({}); } @@ -110,17 +126,29 @@ public: const NamespaceString nss( CommandHelpers::parseNsCollectionRequired(ns().dbName(), request().getWriteCmd())); const auto cri = uassertStatusOK(getCollectionRoutingInfoForTxnCmd(opCtx, nss)); + auto parsedInfoFromRequest = [&] { const auto commandName = request().getWriteCmd().firstElementFieldNameStringData(); + BSONObjBuilder bob(request().getWriteCmd()); bob.appendElementsUnique(BSON("$db" << ns().dbName().toString())); auto writeCmdObj = bob.obj(); + BSONObj query; BSONObj collation; + int stmtId = kUninitializedStmtId; + if (commandName == "update") { auto updateRequest = write_ops::UpdateCommandRequest::parse( 
IDLParserContext("_clusterQueryWithoutShardKey"), writeCmdObj); query = updateRequest.getUpdates().front().getQ(); + + // In the batch write path, when the request is reconstructed to be passed to + // the two phase write protocol, only the stmtIds field is used. + if (auto stmtIds = updateRequest.getStmtIds()) { + stmtId = stmtIds->front(); + } + if (auto parsedCollation = updateRequest.getUpdates().front().getCollation()) { collation = *parsedCollation; } @@ -128,6 +156,13 @@ public: auto deleteRequest = write_ops::DeleteCommandRequest::parse( IDLParserContext("_clusterQueryWithoutShardKey"), writeCmdObj); query = deleteRequest.getDeletes().front().getQ(); + + // In the batch write path, when the request is reconstructed to be passed to + // the two phase write protocol, only the stmtIds field is used. + if (auto stmtIds = deleteRequest.getStmtIds()) { + stmtId = stmtIds->front(); + } + if (auto parsedCollation = deleteRequest.getDeletes().front().getCollation()) { collation = *parsedCollation; } @@ -135,13 +170,15 @@ public: auto findAndModifyRequest = write_ops::FindAndModifyCommandRequest::parse( IDLParserContext("_clusterQueryWithoutShardKey"), writeCmdObj); query = findAndModifyRequest.getQuery(); + stmtId = findAndModifyRequest.getStmtId().value_or(kUninitializedStmtId); + if (auto parsedCollation = findAndModifyRequest.getCollation()) { collation = *parsedCollation; } } else { uasserted(ErrorCodes::InvalidOptions, "Not a supported batch write command"); } - return ParsedCommandInfo(query.getOwned(), collation.getOwned()); + return ParsedCommandInfo(query.getOwned(), collation.getOwned(), stmtId); }(); auto allShardsContainingChunksForNs = @@ -162,24 +199,35 @@ public: ReadPreferenceSetting(ReadPreference::PrimaryOnly), Shard::RetryPolicy::kNoRetry); - std::vector<AsyncRequestsSender::Response> results; + BSONObj targetDoc; + Response res; + std::vector<AsyncRequestSenderResponseData> responses; while (!ars.done()) { auto response = ars.next(); 
uassertStatusOK(response.swResponse); - results.push_back(response); + responses.emplace_back( + AsyncRequestSenderResponseData(response.shardId, + uassertStatusOK(CursorResponse::parseFromBSON( + response.swResponse.getValue().data)))); } - BSONObj targetDoc; - Response res; - for (const auto& arsRes : results) { - auto cursorResponse = uassertStatusOK( - CursorResponse::parseFromBSON(arsRes.swResponse.getValue().data)); + for (auto& responseData : responses) { + auto shardId = responseData.shardId; + auto cursorResponse = std::move(responseData.cursorResponse); + + // Return the first target doc/shard id pair that has already applied the write + // for a retryable write. + if (cursorResponse.getWasStatementExecuted()) { + // Since the retryable write history check happens before a write is executed, + // we can just use an empty BSONObj for the target doc. + res.setTargetDoc(BSONObj::kEmptyObject); + res.setShardId(boost::optional<mongo::StringData>(shardId)); + break; + } - // Return the first response that contains a matching document. if (cursorResponse.getBatch().size() > 0) { res.setTargetDoc(cursorResponse.releaseBatch().front().getOwned()); - res.setShardId(boost::optional<mongo::StringData>(arsRes.shardId)); - break; + res.setShardId(boost::optional<mongo::StringData>(shardId)); } } return res; diff --git a/src/mongo/s/query/cluster_aggregate.cpp b/src/mongo/s/query/cluster_aggregate.cpp index 3bc19bdef59..5550b177258 100644 --- a/src/mongo/s/query/cluster_aggregate.cpp +++ b/src/mongo/s/query/cluster_aggregate.cpp @@ -218,17 +218,20 @@ void updateHostsTargetedMetrics(OperationContext* opCtx, } /** - * Performs validations related to API versioning and time-series stages. + * Performs validations related to API versioning, time-series stages, and general command + * validation. 
* Throws UserAssertion if any of the validations fails * - validation of API versioning on each stage on the pipeline * - validation of API versioning on 'AggregateCommandRequest' request * - validation of time-series related stages + * - validation of command parameters */ void performValidationChecks(const OperationContext* opCtx, const AggregateCommandRequest& request, const LiteParsedPipeline& liteParsedPipeline) { liteParsedPipeline.validate(opCtx); aggregation_request_helper::validateRequestForAPIVersion(opCtx, request); + aggregation_request_helper::validateRequestFromClusterQueryWithoutShardKey(request); } /** |