summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/mongo/db/commands/run_aggregate.cpp22
-rw-r--r--src/mongo/db/pipeline/aggregate_command.idl12
-rw-r--r--src/mongo/db/pipeline/aggregation_request_helper.cpp8
-rw-r--r--src/mongo/db/pipeline/aggregation_request_helper.h5
-rw-r--r--src/mongo/db/pipeline/aggregation_request_test.cpp7
-rw-r--r--src/mongo/db/query/cursor_response.cpp28
-rw-r--r--src/mongo/db/query/cursor_response.h13
-rw-r--r--src/mongo/db/query/cursor_response.idl9
-rw-r--r--src/mongo/s/commands/cluster_query_without_shard_key_cmd.cpp72
-rw-r--r--src/mongo/s/query/cluster_aggregate.cpp5
10 files changed, 161 insertions, 20 deletions
diff --git a/src/mongo/db/commands/run_aggregate.cpp b/src/mongo/db/commands/run_aggregate.cpp
index b2a81220828..8db9c787f7f 100644
--- a/src/mongo/db/commands/run_aggregate.cpp
+++ b/src/mongo/db/commands/run_aggregate.cpp
@@ -87,6 +87,7 @@
#include "mongo/db/service_context.h"
#include "mongo/db/stats/resource_consumption_metrics.h"
#include "mongo/db/storage/storage_options.h"
+#include "mongo/db/transaction/transaction_participant.h"
#include "mongo/db/views/view.h"
#include "mongo/db/views/view_catalog_helpers.h"
#include "mongo/logv2/log.h"
@@ -571,17 +572,20 @@ std::vector<std::unique_ptr<Pipeline, PipelineDeleter>> createAdditionalPipeline
}
/**
- * Performs validations related to API versioning and time-series stages.
+ * Performs validations related to API versioning, time-series stages, and general command
+ * validation.
* Throws UserAssertion if any of the validations fails
* - validation of API versioning on each stage on the pipeline
* - validation of API versioning on 'AggregateCommandRequest' request
* - validation of time-series related stages
+ * - validation of command parameters
*/
void performValidationChecks(const OperationContext* opCtx,
const AggregateCommandRequest& request,
const LiteParsedPipeline& liteParsedPipeline) {
liteParsedPipeline.validate(opCtx);
aggregation_request_helper::validateRequestForAPIVersion(opCtx, request);
+ aggregation_request_helper::validateRequestFromClusterQueryWithoutShardKey(request);  // uasserts InvalidOptions if the flag is set without fromMongos
}
std::vector<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> createLegacyExecutor(
@@ -671,6 +675,22 @@ Status runAggregate(OperationContext* opCtx,
// aggregation command.
performValidationChecks(opCtx, request, liteParsedPipeline);
+ // If we are running a retryable write without shard key, check if the write was applied on this
+ // shard, and if so, return early with an empty cursor with $_wasStatementExecuted
+ // set to true.
+ auto isClusterQueryWithoutShardKeyCmd = request.getIsClusterQueryWithoutShardKeyCmd();
+ auto stmtId = request.getStmtId();
+ if (isClusterQueryWithoutShardKeyCmd && stmtId) {
+ if (TransactionParticipant::get(opCtx).checkStatementExecuted(opCtx, *stmtId)) {
+ CursorResponseBuilder::Options options;
+ options.isInitialResponse = true;
+ CursorResponseBuilder responseBuilder(result, options);
+ responseBuilder.setWasStatementExecuted(true);
+ responseBuilder.done(0LL, origNss);
+ return Status::OK();
+ }
+ }
+
// For operations on views, this will be the underlying namespace.
NamespaceString nss = request.getNamespace();
diff --git a/src/mongo/db/pipeline/aggregate_command.idl b/src/mongo/db/pipeline/aggregate_command.idl
index eccd8b9fb45..4c6b11eef44 100644
--- a/src/mongo/db/pipeline/aggregate_command.idl
+++ b/src/mongo/db/pipeline/aggregate_command.idl
@@ -296,3 +296,15 @@ commands:
type: uuid
optional: true
stability: unstable
+ stmtId:
+ description: "The statement id of the write in the original write batch for a write
+ without shard key."
+ type: int
+ optional: true
+ stability: internal
+ $_isClusterQueryWithoutShardKeyCmd:
+ description: "True if a _clusterQueryWithoutShardKey command is running a broadcast
+ aggregate."
+ type: optionalBool
+ cpp_name: isClusterQueryWithoutShardKeyCmd
+ stability: internal
diff --git a/src/mongo/db/pipeline/aggregation_request_helper.cpp b/src/mongo/db/pipeline/aggregation_request_helper.cpp
index 60f19b9cf2e..29a5fe90cd8 100644
--- a/src/mongo/db/pipeline/aggregation_request_helper.cpp
+++ b/src/mongo/db/pipeline/aggregation_request_helper.cpp
@@ -222,6 +222,14 @@ void validateRequestForAPIVersion(const OperationContext* opCtx,
}
}
+void validateRequestFromClusterQueryWithoutShardKey(const AggregateCommandRequest& request) {
+ // Requests that do not set the flag need no further checking.
+ if (!request.getIsClusterQueryWithoutShardKeyCmd()) {
+ return;
+ }
+
+ // When the flag is set, the request must also be marked as coming from mongos.
+ uassert(ErrorCodes::InvalidOptions,
+ "Only mongos can set the isClusterQueryWithoutShardKeyCmd field",
+ request.getFromMongos());
+}
+
PlanExecutorPipeline::ResumableScanType getResumableScanType(const AggregateCommandRequest& request,
bool isChangeStream) {
// $changeStream cannot be run on the oplog, and $_requestReshardingResumeToken can only be run
diff --git a/src/mongo/db/pipeline/aggregation_request_helper.h b/src/mongo/db/pipeline/aggregation_request_helper.h
index dff875ff7f5..63d77287db4 100644
--- a/src/mongo/db/pipeline/aggregation_request_helper.h
+++ b/src/mongo/db/pipeline/aggregation_request_helper.h
@@ -121,6 +121,11 @@ BSONObj serializeToCommandObj(const AggregateCommandRequest& request);
*/
void validateRequestForAPIVersion(const OperationContext* opCtx,
const AggregateCommandRequest& request);
+/**
+ * Validates that if the 'AggregateCommandRequest' sets the "isClusterQueryWithoutShardKeyCmd"
+ * field, then the request also has 'fromMongos' set. Throws a UserAssertion otherwise.
+ */
+void validateRequestFromClusterQueryWithoutShardKey(const AggregateCommandRequest& request);
/**
* Returns the type of resumable scan required by this aggregation, if applicable. Otherwise returns
diff --git a/src/mongo/db/pipeline/aggregation_request_test.cpp b/src/mongo/db/pipeline/aggregation_request_test.cpp
index ba75b6f28ca..a300aa15128 100644
--- a/src/mongo/db/pipeline/aggregation_request_test.cpp
+++ b/src/mongo/db/pipeline/aggregation_request_test.cpp
@@ -66,7 +66,7 @@ TEST(AggregationRequestTest, ShouldParseAllKnownOptions) {
"collation: {locale: 'en_US'}, cursor: {batchSize: 10}, hint: {a: 1}, maxTimeMS: 100, "
"readConcern: {level: 'linearizable'}, $queryOptions: {$readPreference: 'nearest'}, "
"exchange: {policy: 'roundrobin', consumers:NumberInt(2)}, isMapReduceCommand: true, $db: "
- "'local'}");
+ "'local', $_isClusterQueryWithoutShardKeyCmd: true}");
auto uuid = UUID::gen();
BSONObjBuilder uuidBob;
uuid.appendToBuilder(&uuidBob, AggregateCommandRequest::kCollectionUUIDFieldName);
@@ -80,6 +80,7 @@ TEST(AggregationRequestTest, ShouldParseAllKnownOptions) {
ASSERT_TRUE(request.getNeedsMerge());
ASSERT_TRUE(request.getBypassDocumentValidation().value_or(false));
ASSERT_TRUE(request.getRequestReshardingResumeToken());
+ ASSERT_TRUE(request.getIsClusterQueryWithoutShardKeyCmd());
ASSERT_EQ(
request.getCursor().getBatchSize().value_or(aggregation_request_helper::kDefaultBatchSize),
10);
@@ -215,6 +216,7 @@ TEST(AggregationRequestTest, ShouldSerializeOptionalValuesIfSet) {
request.setLet(letParamsObj);
auto uuid = UUID::gen();
request.setCollectionUUID(uuid);
+ request.setIsClusterQueryWithoutShardKeyCmd(true);
auto expectedSerialization = Document{
{AggregateCommandRequest::kCommandName, nss.coll()},
@@ -232,7 +234,8 @@ TEST(AggregationRequestTest, ShouldSerializeOptionalValuesIfSet) {
{query_request_helper::kUnwrappedReadPrefField, readPrefObj},
{AggregateCommandRequest::kRequestReshardingResumeTokenFieldName, true},
{AggregateCommandRequest::kIsMapReduceCommandFieldName, true},
- {AggregateCommandRequest::kCollectionUUIDFieldName, uuid}};
+ {AggregateCommandRequest::kCollectionUUIDFieldName, uuid},
+ {AggregateCommandRequest::kIsClusterQueryWithoutShardKeyCmdFieldName, true}};
ASSERT_DOCUMENT_EQ(aggregation_request_helper::serializeToCommandDoc(request),
expectedSerialization);
}
diff --git a/src/mongo/db/query/cursor_response.cpp b/src/mongo/db/query/cursor_response.cpp
index d53a24c7c5f..1940e3def4d 100644
--- a/src/mongo/db/query/cursor_response.cpp
+++ b/src/mongo/db/query/cursor_response.cpp
@@ -57,6 +57,7 @@ const char kBatchDocSequenceFieldInitial[] = "cursor.firstBatch";
const char kPostBatchResumeTokenField[] = "postBatchResumeToken";
const char kPartialResultsReturnedField[] = "partialResultsReturned";
const char kInvalidatedField[] = "invalidated";
+const char kWasStatementExecuted[] = "$_wasStatementExecuted";
} // namespace
@@ -84,6 +85,10 @@ void CursorResponseBuilder::done(CursorId cursorId, const NamespaceString& curso
_cursorObject->append(kInvalidatedField, _invalidated);
}
+ if (_wasStatementExecuted) {
+ _cursorObject->append(kWasStatementExecuted, _wasStatementExecuted);
+ }
+
_cursorObject->append(kIdField, cursorId);
_cursorObject->append(kNsField, NamespaceStringUtil::serialize(cursorNamespace));
if (_options.atClusterTime) {
@@ -140,7 +145,8 @@ CursorResponse::CursorResponse(NamespaceString nss,
boost::optional<BSONObj> varsField,
boost::optional<std::string> cursorType,
bool partialResultsReturned,
- bool invalidated)
+ bool invalidated,
+ bool wasStatementExecuted)
: _nss(std::move(nss)),
_cursorId(cursorId),
_batch(std::move(batch)),
@@ -150,7 +156,8 @@ CursorResponse::CursorResponse(NamespaceString nss,
_varsField(std::move(varsField)),
_cursorType(std::move(cursorType)),
_partialResultsReturned(partialResultsReturned),
- _invalidated(invalidated) {}
+ _invalidated(invalidated),
+ _wasStatementExecuted(wasStatementExecuted) {}
std::vector<StatusWith<CursorResponse>> CursorResponse::parseFromBSONMany(
const BSONObj& cmdResponse) {
@@ -299,6 +306,16 @@ StatusWith<CursorResponse> CursorResponse::parseFromBSON(const BSONObj& cmdRespo
}
}
+ auto wasStatementExecuted = cursorObj[kWasStatementExecuted];
+ if (wasStatementExecuted) {
+ if (wasStatementExecuted.type() != BSONType::Bool) {
+ return {ErrorCodes::BadValue,
+ str::stream() << kWasStatementExecuted
+ << " format is invalid; expected Bool, but found: "
+ << wasStatementExecuted.type()};
+ }
+ }
+
auto writeConcernError = cmdResponse["writeConcernError"];
if (writeConcernError && writeConcernError.type() != BSONType::Object) {
@@ -317,7 +334,8 @@ StatusWith<CursorResponse> CursorResponse::parseFromBSON(const BSONObj& cmdRespo
varsElt ? varsElt.Obj().getOwned() : boost::optional<BSONObj>{},
typeElt ? boost::make_optional<std::string>(typeElt.String()) : boost::none,
partialResultsReturned.trueValue(),
- invalidatedElem.trueValue()}};
+ invalidatedElem.trueValue(),
+ wasStatementExecuted.trueValue()}};
}
void CursorResponse::addToBSON(CursorResponse::ResponseType responseType,
@@ -351,6 +369,10 @@ void CursorResponse::addToBSON(CursorResponse::ResponseType responseType,
cursorBuilder.append(kInvalidatedField, _invalidated);
}
+ if (_wasStatementExecuted) {
+ cursorBuilder.append(kWasStatementExecuted, _wasStatementExecuted);
+ }
+
cursorBuilder.doneFast();
builder->append("ok", 1.0);
diff --git a/src/mongo/db/query/cursor_response.h b/src/mongo/db/query/cursor_response.h
index 7b33b71e1d5..87f5f22c263 100644
--- a/src/mongo/db/query/cursor_response.h
+++ b/src/mongo/db/query/cursor_response.h
@@ -99,6 +99,10 @@ public:
_invalidated = true;
}
+ // Records whether the statement was already executed; when true, done() appends the
+ // $_wasStatementExecuted field to the cursor object.
+ void setWasStatementExecuted(bool wasStatementExecuted) {
+ // Store the caller's value rather than unconditionally setting the flag to true.
+ _wasStatementExecuted = wasStatementExecuted;
+ }
+
long long numDocs() const {
return _numDocs;
}
@@ -135,6 +139,7 @@ private:
BSONObj _postBatchResumeToken;
bool _partialResultsReturned = false;
bool _invalidated = false;
+ bool _wasStatementExecuted = false;
};
/**
@@ -216,7 +221,8 @@ public:
boost::optional<BSONObj> varsField = boost::none,
boost::optional<std::string> cursorType = boost::none,
bool partialResultsReturned = false,
- bool invalidated = false);
+ bool invalidated = false,
+ bool wasStatementExecuted = false);
CursorResponse(CursorResponse&& other) = default;
CursorResponse& operator=(CursorResponse&& other) = default;
@@ -269,6 +275,10 @@ public:
return _invalidated;
}
+ bool getWasStatementExecuted() const {
+ return _wasStatementExecuted;  // taken from the constructor argument; defaults to false
+ }
+
/**
* Converts this response to its raw BSON representation.
*/
@@ -289,6 +299,7 @@ private:
boost::optional<std::string> _cursorType;
bool _partialResultsReturned = false;
bool _invalidated = false;
+ bool _wasStatementExecuted = false;
};
} // namespace mongo
diff --git a/src/mongo/db/query/cursor_response.idl b/src/mongo/db/query/cursor_response.idl
index af264710874..391beb38c35 100644
--- a/src/mongo/db/query/cursor_response.idl
+++ b/src/mongo/db/query/cursor_response.idl
@@ -76,6 +76,15 @@ structs:
description: "Boolean represents if the cursor has been invalidated."
type: optionalBool
stability: stable
+ $_wasStatementExecuted:
+ description: "An optional field set to true if a write without shard key had already
+ applied the write. To provide some context, this internal field is
+ used by the two phase write without shard key protocol introduced in
+ PM-1632 to support retryable writes for updateOne without shard key,
+ deleteOne without shard key, and findAndModify without shard key."
+ type: optionalBool
+ cpp_name: wasStatementExecuted
+ stability: internal
InitialResponseCursor:
description: "A struct representing an initial response cursor."
diff --git a/src/mongo/s/commands/cluster_query_without_shard_key_cmd.cpp b/src/mongo/s/commands/cluster_query_without_shard_key_cmd.cpp
index 2511a722fe1..4d1f8ed17ec 100644
--- a/src/mongo/s/commands/cluster_query_without_shard_key_cmd.cpp
+++ b/src/mongo/s/commands/cluster_query_without_shard_key_cmd.cpp
@@ -47,8 +47,18 @@ namespace {
struct ParsedCommandInfo {
BSONObj query;
BSONObj collation;
+ // Statement id taken from the wrapped write command; kUninitializedStmtId when none was given.
+ int stmtId;
- ParsedCommandInfo(BSONObj query, BSONObj collation) : query(query), collation(collation) {}
+ // 'query' and 'collation' are by-value sink parameters, so move them into the members
+ // instead of copying them a second time.
+ ParsedCommandInfo(BSONObj query, BSONObj collation, int stmtId)
+ : query(std::move(query)), collation(std::move(collation)), stmtId(stmtId) {}
+};
+
+// Pairs the id of a shard with the cursor response parsed from that shard's reply.
+struct AsyncRequestSenderResponseData {
+ ShardId shardId;
+ CursorResponse cursorResponse;
+
+ // Both arguments are by-value sink parameters and are moved into the members.
+ AsyncRequestSenderResponseData(ShardId shardId, CursorResponse cursorResponse)
+ : shardId(std::move(shardId)), cursorResponse(std::move(cursorResponse)) {}
};
std::set<ShardId> getShardsToTarget(OperationContext* opCtx,
@@ -85,6 +95,12 @@ BSONObj createAggregateCmdObj(OperationContext* opCtx,
BSON("$limit" << 1),
BSON("$project" << BSON("_id" << 1))});
aggregate.setCollation(parsedInfo.collation);
+ aggregate.setIsClusterQueryWithoutShardKeyCmd(true);
+ aggregate.setFromMongos(true);
+
+ if (parsedInfo.stmtId != kUninitializedStmtId) {
+ aggregate.setStmtId(parsedInfo.stmtId);
+ }
return aggregate.toBSON({});
}
@@ -110,17 +126,29 @@ public:
const NamespaceString nss(
CommandHelpers::parseNsCollectionRequired(ns().dbName(), request().getWriteCmd()));
const auto cri = uassertStatusOK(getCollectionRoutingInfoForTxnCmd(opCtx, nss));
+
auto parsedInfoFromRequest = [&] {
const auto commandName = request().getWriteCmd().firstElementFieldNameStringData();
+
BSONObjBuilder bob(request().getWriteCmd());
bob.appendElementsUnique(BSON("$db" << ns().dbName().toString()));
auto writeCmdObj = bob.obj();
+
BSONObj query;
BSONObj collation;
+ int stmtId = kUninitializedStmtId;
+
if (commandName == "update") {
auto updateRequest = write_ops::UpdateCommandRequest::parse(
IDLParserContext("_clusterQueryWithoutShardKey"), writeCmdObj);
query = updateRequest.getUpdates().front().getQ();
+
+ // In the batch write path, when the request is reconstructed to be passed to
+ // the two phase write protocol, only the stmtIds field is used.
+ if (auto stmtIds = updateRequest.getStmtIds()) {
+ stmtId = stmtIds->front();
+ }
+
if (auto parsedCollation = updateRequest.getUpdates().front().getCollation()) {
collation = *parsedCollation;
}
@@ -128,6 +156,13 @@ public:
auto deleteRequest = write_ops::DeleteCommandRequest::parse(
IDLParserContext("_clusterQueryWithoutShardKey"), writeCmdObj);
query = deleteRequest.getDeletes().front().getQ();
+
+ // In the batch write path, when the request is reconstructed to be passed to
+ // the two phase write protocol, only the stmtIds field is used.
+ if (auto stmtIds = deleteRequest.getStmtIds()) {
+ stmtId = stmtIds->front();
+ }
+
if (auto parsedCollation = deleteRequest.getDeletes().front().getCollation()) {
collation = *parsedCollation;
}
@@ -135,13 +170,15 @@ public:
auto findAndModifyRequest = write_ops::FindAndModifyCommandRequest::parse(
IDLParserContext("_clusterQueryWithoutShardKey"), writeCmdObj);
query = findAndModifyRequest.getQuery();
+ stmtId = findAndModifyRequest.getStmtId().value_or(kUninitializedStmtId);
+
if (auto parsedCollation = findAndModifyRequest.getCollation()) {
collation = *parsedCollation;
}
} else {
uasserted(ErrorCodes::InvalidOptions, "Not a supported batch write command");
}
- return ParsedCommandInfo(query.getOwned(), collation.getOwned());
+ return ParsedCommandInfo(query.getOwned(), collation.getOwned(), stmtId);
}();
auto allShardsContainingChunksForNs =
@@ -162,24 +199,35 @@ public:
ReadPreferenceSetting(ReadPreference::PrimaryOnly),
Shard::RetryPolicy::kNoRetry);
- std::vector<AsyncRequestsSender::Response> results;
+ BSONObj targetDoc;
+ Response res;
+ std::vector<AsyncRequestSenderResponseData> responses;
while (!ars.done()) {
auto response = ars.next();
uassertStatusOK(response.swResponse);
- results.push_back(response);
+ responses.emplace_back(
+ AsyncRequestSenderResponseData(response.shardId,
+ uassertStatusOK(CursorResponse::parseFromBSON(
+ response.swResponse.getValue().data))));
}
- BSONObj targetDoc;
- Response res;
- for (const auto& arsRes : results) {
- auto cursorResponse = uassertStatusOK(
- CursorResponse::parseFromBSON(arsRes.swResponse.getValue().data));
+ for (auto& responseData : responses) {
+ auto shardId = responseData.shardId;
+ auto cursorResponse = std::move(responseData.cursorResponse);
+
+ // Return the first target doc/shard id pair that has already applied the write
+ // for a retryable write.
+ if (cursorResponse.getWasStatementExecuted()) {
+ // Since the retryable write history check happens before a write is executed,
+ // we can just use an empty BSONObj for the target doc.
+ res.setTargetDoc(BSONObj::kEmptyObject);
+ res.setShardId(boost::optional<mongo::StringData>(shardId));
+ break;
+ }
- // Return the first response that contains a matching document.
if (cursorResponse.getBatch().size() > 0) {
res.setTargetDoc(cursorResponse.releaseBatch().front().getOwned());
- res.setShardId(boost::optional<mongo::StringData>(arsRes.shardId));
- break;
+ res.setShardId(boost::optional<mongo::StringData>(shardId));
}
}
return res;
diff --git a/src/mongo/s/query/cluster_aggregate.cpp b/src/mongo/s/query/cluster_aggregate.cpp
index 3bc19bdef59..5550b177258 100644
--- a/src/mongo/s/query/cluster_aggregate.cpp
+++ b/src/mongo/s/query/cluster_aggregate.cpp
@@ -218,17 +218,20 @@ void updateHostsTargetedMetrics(OperationContext* opCtx,
}
/**
- * Performs validations related to API versioning and time-series stages.
+ * Performs validations related to API versioning, time-series stages, and general command
+ * validation.
* Throws UserAssertion if any of the validations fails
* - validation of API versioning on each stage on the pipeline
* - validation of API versioning on 'AggregateCommandRequest' request
* - validation of time-series related stages
+ * - validation of command parameters
*/
void performValidationChecks(const OperationContext* opCtx,
const AggregateCommandRequest& request,
const LiteParsedPipeline& liteParsedPipeline) {
liteParsedPipeline.validate(opCtx);
aggregation_request_helper::validateRequestForAPIVersion(opCtx, request);
+ aggregation_request_helper::validateRequestFromClusterQueryWithoutShardKey(request);  // uasserts InvalidOptions if the flag is set without fromMongos
}
/**