Diffstat (limited to 'src')
25 files changed, 581 insertions, 147 deletions
diff --git a/src/mongo/db/commands/write_commands/write_commands.cpp b/src/mongo/db/commands/write_commands/write_commands.cpp index 610f5bb7b2c..c2892e08fd9 100644 --- a/src/mongo/db/commands/write_commands/write_commands.cpp +++ b/src/mongo/db/commands/write_commands/write_commands.cpp @@ -363,6 +363,7 @@ private: UpdateRequest updateRequest(_batch.getNamespace()); updateRequest.setQuery(_batch.getUpdates()[0].getQ()); updateRequest.setUpdateModification(_batch.getUpdates()[0].getU()); + updateRequest.setUpdateConstants(_batch.getUpdates()[0].getC()); updateRequest.setRuntimeConstants( _batch.getRuntimeConstants().value_or(Variables::generateRuntimeConstants(opCtx))); updateRequest.setCollation(write_ops::collationOf(_batch.getUpdates()[0])); diff --git a/src/mongo/db/ops/parsed_update.cpp b/src/mongo/db/ops/parsed_update.cpp index d47d973c94b..79e1d9ed692 100644 --- a/src/mongo/db/ops/parsed_update.cpp +++ b/src/mongo/db/ops/parsed_update.cpp @@ -146,7 +146,10 @@ void ParsedUpdate::parseUpdate() { _driver.setLogOp(true); _driver.setFromOplogApplication(_request->isFromOplogApplication()); - _driver.parse(_request->getUpdateModification(), _arrayFilters, _request->isMulti()); + _driver.parse(_request->getUpdateModification(), + _arrayFilters, + _request->getUpdateConstants(), + _request->isMulti()); } StatusWith<std::map<StringData, std::unique_ptr<ExpressionWithPlaceholder>>> diff --git a/src/mongo/db/ops/update_request.h b/src/mongo/db/ops/update_request.h index 763307a9260..e0f99abf1d4 100644 --- a/src/mongo/db/ops/update_request.h +++ b/src/mongo/db/ops/update_request.h @@ -111,6 +111,14 @@ public: return _updateMod; } + inline void setUpdateConstants(const boost::optional<BSONObj>& updateConstants) { + _updateConstants = updateConstants; + } + + inline const boost::optional<BSONObj>& getUpdateConstants() const { + return _updateConstants; + } + inline void setRuntimeConstants(const RuntimeConstants& runtimeConstants) { _runtimeConstants = runtimeConstants; } @@ -230,6 +238,10 @@ public: } builder << "]"; + if (_updateConstants) { + builder << " updateConstants: " << *_updateConstants; + } + if (_runtimeConstants) { builder << " runtimeConstants: " << _runtimeConstants->toBSON().toString(); } @@ -261,7 +273,14 @@ private: // Contains the modifiers to apply to matched objects, or a replacement document. write_ops::UpdateModification _updateMod; - // Contains any constant values which may be required by the query or update operation. + // User-defined constant values to be used with a pipeline-style update. Those are different + // from the '_runtimeConstants' as they can be specified by the user for each individual + // element of the 'updates' array in the 'update' command. The '_runtimeConstants' contains + // runtime system constant values which remain unchanged for all update statements in the + // 'update' command. + boost::optional<BSONObj> _updateConstants; + + // System-defined constant values which may be required by the query or update operation. boost::optional<RuntimeConstants> _runtimeConstants; // Filters to specify which array elements should be updated. diff --git a/src/mongo/db/ops/write_ops.idl b/src/mongo/db/ops/write_ops.idl index 95ef7f5e71d..a4ddaadd966 100644 --- a/src/mongo/db/ops/write_ops.idl +++ b/src/mongo/db/ops/write_ops.idl @@ -101,6 +101,11 @@ structs: u: description: "Set of modifications to apply." type: update_modification + c: + description: "Specifies constant values that can be referred to in the pipeline + performing a custom update." 
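As a point of reference for the new 'c' field and UpdateRequest::setUpdateConstants() above, here is a minimal sketch (not part of the patch; the namespace, field names, and values are invented) of how a caller might attach constants to a pipeline-style update request:

    // Illustrative sketch only -- mirrors the plumbing added above, with made-up values.
    UpdateRequest request(NamespaceString("test.coll"));  // hypothetical namespace
    request.setQuery(BSON("_id" << 1));
    // Pipeline-style update that refers to the user-defined constant $$target.
    request.setUpdateModification(write_ops::UpdateModification(
        std::vector<BSONObj>{BSON("$set" << BSON("a" << "$$target"))}));
    // The new setter: the per-statement 'c' document; boost::none when not supplied.
    request.setUpdateConstants(BSON("target" << 42));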
+ type: object + optional: true arrayFilters: description: "Specifies which array elements an update modifier should apply to." type: array<object> diff --git a/src/mongo/db/ops/write_ops_exec.cpp b/src/mongo/db/ops/write_ops_exec.cpp index acd90695b4f..79b519779c8 100644 --- a/src/mongo/db/ops/write_ops_exec.cpp +++ b/src/mongo/db/ops/write_ops_exec.cpp @@ -708,6 +708,7 @@ static SingleWriteResult performSingleUpdateOpWithDupKeyRetry(OperationContext* UpdateRequest request(ns); request.setQuery(op.getQ()); request.setUpdateModification(op.getU()); + request.setUpdateConstants(op.getC()); request.setRuntimeConstants(std::move(runtimeConstants)); request.setCollation(write_ops::collationOf(op)); request.setStmtId(stmtId); diff --git a/src/mongo/db/pipeline/document_source_merge.cpp b/src/mongo/db/pipeline/document_source_merge.cpp index d871fb2d7b5..196a1cdb531 100644 --- a/src/mongo/db/pipeline/document_source_merge.cpp +++ b/src/mongo/db/pipeline/document_source_merge.cpp @@ -58,6 +58,7 @@ using MergeStrategyDescriptorsMap = std::map<const MergeMode, const MergeStrateg using WhenMatched = MergeStrategyDescriptor::WhenMatched; using WhenNotMatched = MergeStrategyDescriptor::WhenNotMatched; using BatchTransform = std::function<void(DocumentSourceMerge::BatchedObjects&)>; +using UpdateModification = write_ops::UpdateModification; constexpr auto kStageName = DocumentSourceMerge::kStageName; constexpr auto kDefaultWhenMatched = WhenMatched::kMerge; @@ -91,14 +92,8 @@ MergeStrategy makeUpdateStrategy(bool upsert, BatchTransform transform) { } constexpr auto multi = false; - expCtx->mongoProcessInterface->update(expCtx, - ns, - std::move(batch.uniqueKeys), - std::move(batch.modifications), - wc, - upsert, - multi, - epoch); + expCtx->mongoProcessInterface->update( + expCtx, ns, std::move(batch), wc, upsert, multi, epoch); }; } @@ -119,15 +114,8 @@ MergeStrategy makeStrictUpdateStrategy(bool upsert, BatchTransform transform) { const auto batchSize = batch.size(); constexpr auto multi = false; - auto writeResult = - expCtx->mongoProcessInterface->updateWithResult(expCtx, - ns, - std::move(batch.uniqueKeys), - std::move(batch.modifications), - wc, - upsert, - multi, - epoch); + auto writeResult = expCtx->mongoProcessInterface->updateWithResult( + expCtx, ns, std::move(batch), wc, upsert, multi, epoch); constexpr auto initValue = 0ULL; auto nMatched = std::accumulate(writeResult.results.begin(), @@ -151,10 +139,9 @@ MergeStrategy makeInsertStrategy() { std::vector<BSONObj> objectsToInsert(batch.size()); // The batch stores replacement style updates, but for this "insert" style of $merge we'd // like to just insert the new document without attempting any sort of replacement. 
- std::transform(batch.modifications.begin(), - batch.modifications.end(), - objectsToInsert.begin(), - [](const auto& mod) { return mod.getUpdateClassic(); }); + std::transform(batch.begin(), batch.end(), objectsToInsert.begin(), [](const auto& obj) { + return std::get<UpdateModification>(obj).getUpdateClassic(); + }); expCtx->mongoProcessInterface->insert(expCtx, ns, std::move(objectsToInsert), wc, epoch); }; } @@ -165,11 +152,10 @@ MergeStrategy makeInsertStrategy() { */ BatchTransform makeUpdateTransform(const std::string& updateOp) { return [updateOp](auto& batch) { - std::transform( - batch.modifications.begin(), - batch.modifications.end(), - batch.modifications.begin(), - [updateOp](const auto& mod) { return BSON(updateOp << mod.getUpdateClassic()); }); + for (auto&& obj : batch) { + std::get<UpdateModification>(obj) = + BSON(updateOp << std::get<UpdateModification>(obj).getUpdateClassic()); + } }; } @@ -446,7 +432,8 @@ std::unique_ptr<DocumentSourceMerge::LiteParsed> DocumentSourceMerge::LiteParsed DocumentSourceMerge::DocumentSourceMerge(NamespaceString outputNs, const boost::intrusive_ptr<ExpressionContext>& expCtx, const MergeStrategyDescriptor& descriptor, - boost::optional<std::vector<BSONObj>>&& pipeline, + boost::optional<BSONObj> letVariables, + boost::optional<std::vector<BSONObj>> pipeline, std::set<FieldPath> mergeOnFields, boost::optional<ChunkVersion> targetCollectionVersion, bool serializeAsOutStage) @@ -459,14 +446,28 @@ DocumentSourceMerge::DocumentSourceMerge(NamespaceString outputNs, _pipeline(std::move(pipeline)), _mergeOnFields(std::move(mergeOnFields)), _mergeOnFieldsIncludesId(_mergeOnFields.count("_id") == 1), - _serializeAsOutStage(serializeAsOutStage) {} + _serializeAsOutStage(serializeAsOutStage) { + if (letVariables) { + _letVariables.emplace(); + + for (auto&& varElem : *letVariables) { + const auto varName = varElem.fieldNameStringData(); + Variables::uassertValidNameForUserWrite(varName); + + _letVariables->emplace( + varName.toString(), + Expression::parseOperand(expCtx, varElem, expCtx->variablesParseState)); + } + } +} boost::intrusive_ptr<DocumentSource> DocumentSourceMerge::create( NamespaceString outputNs, const boost::intrusive_ptr<ExpressionContext>& expCtx, WhenMatched whenMatched, WhenNotMatched whenNotMatched, - boost::optional<std::vector<BSONObj>>&& pipeline, + boost::optional<BSONObj> letVariables, + boost::optional<std::vector<BSONObj>> pipeline, std::set<FieldPath> mergeOnFields, boost::optional<ChunkVersion> targetCollectionVersion, bool serializeAsOutStage) { @@ -496,11 +497,27 @@ boost::intrusive_ptr<DocumentSource> DocumentSourceMerge::create( "Cannot {} into special collection: '{}'"_format(kStageName, outputNs.coll()), !outputNs.isSpecial()); + if (whenMatched == WhenMatched::kPipeline) { + if (!letVariables) { + // For custom pipeline-style updates, default the 'let' variables to {new: "$$ROOT"}, + // if the user has omitted the 'let' argument. + letVariables = BSON("new" + << "$$ROOT"); + } + } else { + // Ensure the 'let' argument cannot be used with any other merge modes. 
+ uassert(51199, + "Cannot use 'let' variables with 'whenMatched: {}' mode"_format( + MergeWhenMatchedMode_serializer(whenMatched)), + !letVariables); + } + return new DocumentSourceMerge(outputNs, expCtx, getDescriptors().at({whenMatched, whenNotMatched}), + std::move(letVariables), std::move(pipeline), - mergeOnFields, + std::move(mergeOnFields), targetCollectionVersion, serializeAsOutStage); } @@ -526,6 +543,7 @@ boost::intrusive_ptr<DocumentSource> DocumentSourceMerge::createFromBson( expCtx, whenMatched, whenNotMatched, + mergeSpec.getLet(), std::move(pipeline), std::move(mergeOnFields), targetCollectionVersion, @@ -574,19 +592,19 @@ DocumentSource::GetNextResult DocumentSourceMerge::getNext() { doc = mutableDoc.freeze(); } - // Extract the 'on' fields before converting the document to BSON. auto mergeOnFields = extractMergeOnFieldsFromDoc(doc, _mergeOnFields); - auto mod = _pipeline ? write_ops::UpdateModification(*_pipeline) - : write_ops::UpdateModification(doc.toBson()); + auto mod = makeBatchUpdateModification(doc); + auto vars = resolveLetVariablesIfNeeded(doc); + auto modSize = mod.objsize() + (vars ? vars->objsize() : 0); - bufferedBytes += mod.objsize(); + bufferedBytes += modSize; if (!batch.empty() && (bufferedBytes > BSONObjMaxUserSize || batch.size() >= write_ops::kMaxWriteBatchSize)) { spill(std::move(batch)); batch.clear(); - bufferedBytes = mod.objsize(); + bufferedBytes = modSize; } - batch.emplace(std::move(mod), std::move(mergeOnFields)); + batch.emplace_back(std::move(mergeOnFields), std::move(mod), std::move(vars)); } if (!batch.empty()) { spill(std::move(batch)); @@ -637,6 +655,17 @@ Value DocumentSourceMerge::serialize(boost::optional<ExplainOptions::Verbosity> } else { DocumentSourceMergeSpec spec; spec.setTargetNss(_outputNs); + spec.setLet([&]() -> boost::optional<BSONObj> { + if (!_letVariables) { + return boost::none; + } + + BSONObjBuilder bob; + for (auto && [ name, expr ] : *_letVariables) { + bob << name << expr->serialize(static_cast<bool>(explain)); + } + return bob.obj(); + }()); spec.setWhenMatched(MergeWhenMatchedPolicy{_descriptor.mode.first, _pipeline}); spec.setWhenNotMatched(_descriptor.mode.second); spec.setOn([&]() { diff --git a/src/mongo/db/pipeline/document_source_merge.h b/src/mongo/db/pipeline/document_source_merge.h index b440afac666..3b0d143ad9b 100644 --- a/src/mongo/db/pipeline/document_source_merge.h +++ b/src/mongo/db/pipeline/document_source_merge.h @@ -42,41 +42,9 @@ namespace mongo { */ class DocumentSourceMerge final : public DocumentSource { public: - static constexpr StringData kStageName = "$merge"_sd; - - /** - * Storage for a batch of BSON Objects to be inserted/updated to the write namespace. The - * extracted 'on' field values are also stored in a batch, used by 'MergeStrategy' as the query - * portion of the update or insert. - */ - struct BatchedObjects { - void emplace(write_ops::UpdateModification&& mod, BSONObj&& key) { - modifications.emplace_back(std::move(mod)); - uniqueKeys.emplace_back(std::move(key)); - } - - bool empty() const { - return modifications.empty(); - } + using BatchedObjects = MongoProcessInterface::BatchedObjects; - size_t size() const { - return modifications.size(); - } - - void clear() { - modifications.clear(); - uniqueKeys.clear(); - } - - // For each element in the batch we store an UpdateModification which is either the new - // document we want to upsert or insert into the collection (i.e. a 'classic' replacement - // update), or the pipeline to run to compute the new document. 
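For orientation, a hedged sketch of a $merge spec that exercises the new 'let' argument together with a custom whenMatched pipeline (the collection, variable, and field names here are invented); when 'let' is omitted in this mode, create() above defaults it to {new: "$$ROOT"}:

    // Illustrative spec only -- not taken verbatim from the patch or its tests.
    auto spec = BSON("$merge" << BSON("into"
                                      << "target_collection"
                                      << "let"
                                      << BSON("year" << 2019)
                                      << "whenMatched"
                                      << BSON_ARRAY(BSON("$addFields" << BSON("y"
                                                                              << "$$year")))
                                      << "whenNotMatched"
                                      << "insert"));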
- std::vector<write_ops::UpdateModification> modifications; - - // Store the unique keys as BSON objects instead of Documents for compatibility with the - // batch update command. (e.g. {q: <array of uniqueKeys>, u: <array of objects>}) - std::vector<BSONObj> uniqueKeys; - }; + static constexpr StringData kStageName = "$merge"_sd; // A descriptor for a merge strategy. Holds a merge strategy function and a set of actions // the client should be authorized to perform in order to be able to execute a merge operation @@ -125,7 +93,8 @@ public: DocumentSourceMerge(NamespaceString outputNs, const boost::intrusive_ptr<ExpressionContext>& expCtx, const MergeStrategyDescriptor& descriptor, - boost::optional<std::vector<BSONObj>>&& pipeline, + boost::optional<BSONObj> letVariables, + boost::optional<std::vector<BSONObj>> pipeline, std::set<FieldPath> mergeOnFields, boost::optional<ChunkVersion> targetCollectionVersion, bool serializeAsOutStage); @@ -205,7 +174,8 @@ public: const boost::intrusive_ptr<ExpressionContext>& expCtx, MergeStrategyDescriptor::WhenMatched whenMatched, MergeStrategyDescriptor::WhenNotMatched whenNotMatched, - boost::optional<std::vector<BSONObj>>&& pipeline, + boost::optional<BSONObj> letVariables, + boost::optional<std::vector<BSONObj>> pipeline, std::set<FieldPath> mergeOnFields, boost::optional<ChunkVersion> targetCollectionVersion, bool serializeAsOutStage); @@ -224,18 +194,44 @@ private: OutStageWriteBlock writeBlock(pExpCtx->opCtx); try { - _descriptor.strategy( - pExpCtx, _outputNs, _writeConcern, _targetEpoch(), std::move(batch)); + auto targetEpoch = _targetCollectionVersion + ? boost::optional<OID>(_targetCollectionVersion->epoch()) + : boost::none; + + _descriptor.strategy(pExpCtx, _outputNs, _writeConcern, targetEpoch, std::move(batch)); } catch (const ExceptionFor<ErrorCodes::ImmutableField>& ex) { uassertStatusOKWithContext(ex.toStatus(), "$merge failed to update the matching document, did you " "attempt to modify the _id or the shard key?"); } - }; + } + + /** + * Creates an UpdateModification object from the given 'doc' to be used with the batched update. + */ + auto makeBatchUpdateModification(const Document& doc) { + return _pipeline ? write_ops::UpdateModification(*_pipeline) + : write_ops::UpdateModification(doc.toBson()); + } - boost::optional<OID> _targetEpoch() { - return _targetCollectionVersion ? boost::optional<OID>(_targetCollectionVersion->epoch()) - : boost::none; + /** + * Resolves 'let' defined variables against the 'doc' and stores the results in the returned + * BSON. + */ + boost::optional<BSONObj> resolveLetVariablesIfNeeded(const Document& doc) { + // When we resolve 'let' variables, an empty BSON object or boost::none won't make any + // difference at the end-point (in the PipelineExecutor), as in both cases we will end up + // with the update pipeline ExpressionContext not being populated with any variables, so we + // are not making a distinction between these two cases here. + if (!_letVariables || _letVariables->empty()) { + return boost::none; + } + + BSONObjBuilder bob; + for (auto && [ name, expr ] : *_letVariables) { + bob << name << expr->evaluate(doc); + } + return bob.obj(); } // Stash the writeConcern of the original command as the operation context may change by the @@ -257,6 +253,13 @@ private: // descriptor. const MergeStrategyDescriptor& _descriptor; + // Holds 'let' variables defined in this stage. These variables are propagated to the + // ExpressionContext of the pipeline update for use in the inner pipeline execution. 
The key + // of the map is a variable name as defined in the $merge spec 'let' argument, and the value is + // a parsed Expression, defining how the variable value must be evaluated. + boost::optional<stdx::unordered_map<std::string, boost::intrusive_ptr<Expression>>> + _letVariables; + // A custom pipeline to compute a new version of merging documents. boost::optional<std::vector<BSONObj>> _pipeline; diff --git a/src/mongo/db/pipeline/document_source_merge.idl b/src/mongo/db/pipeline/document_source_merge.idl index fabda43b017..1b79ad9e673 100644 --- a/src/mongo/db/pipeline/document_source_merge.idl +++ b/src/mongo/db/pipeline/document_source_merge.idl @@ -92,6 +92,13 @@ structs: optional: true description: A single field or array of fields that uniquely identify a document. + let: + type: object + optional: true + description: Specifies variables to use in the update pipeline defined in + MergeWhenMatchedPolicy when the 'whenMatched' mode is a custom + pipeline. + whenMatched: type: MergeWhenMatchedPolicy optional: true diff --git a/src/mongo/db/pipeline/document_source_merge_test.cpp b/src/mongo/db/pipeline/document_source_merge_test.cpp index 0fb4686db82..f55857c00e5 100644 --- a/src/mongo/db/pipeline/document_source_merge_test.cpp +++ b/src/mongo/db/pipeline/document_source_merge_test.cpp @@ -632,7 +632,7 @@ TEST_F(DocumentSourceMergeTest, CorrectlyHandlesWhenMatchedAndWhenNotMatchedMode << "fail" << "whenNotMatched" << "fail")); - ASSERT_THROWS_CODE(createMergeStage(spec), DBException, 51189); + ASSERT_THROWS_CODE(createMergeStage(spec), AssertionException, 51189); spec = BSON("$merge" << BSON("into" << "target_collection" @@ -680,7 +680,7 @@ TEST_F(DocumentSourceMergeTest, CorrectlyHandlesWhenMatchedAndWhenNotMatchedMode << "keepExisting" << "whenNotMatched" << "fail")); - ASSERT_THROWS_CODE(createMergeStage(spec), DBException, 51189); + ASSERT_THROWS_CODE(createMergeStage(spec), AssertionException, 51189); spec = BSON("$merge" << BSON("into" << "target_collection" @@ -720,7 +720,7 @@ TEST_F(DocumentSourceMergeTest, CorrectlyHandlesWhenMatchedAndWhenNotMatchedMode << "pipeline" << "whenNotMatched" << "insert")); - ASSERT_THROWS_CODE(createMergeStage(spec), DBException, ErrorCodes::BadValue); + ASSERT_THROWS_CODE(createMergeStage(spec), AssertionException, ErrorCodes::BadValue); spec = BSON("$merge" << BSON("into" << "target_collection" @@ -728,7 +728,268 @@ TEST_F(DocumentSourceMergeTest, CorrectlyHandlesWhenMatchedAndWhenNotMatchedMode << "[{$addFields: {x: 1}}]" << "whenNotMatched" << "insert")); - ASSERT_THROWS_CODE(createMergeStage(spec), DBException, ErrorCodes::BadValue); + ASSERT_THROWS_CODE(createMergeStage(spec), AssertionException, ErrorCodes::BadValue); +} + +TEST_F(DocumentSourceMergeTest, LetVariablesCanOnlyBeUsedWithPipelineMode) { + auto let = BSON("foo" + << "bar"); + auto spec = BSON("$merge" << BSON("into" + << "target_collection" + << "let" + << let + << "whenMatched" + << BSON_ARRAY(BSON("$project" << BSON("x" << 1))) + << "whenNotMatched" + << "insert")); + ASSERT(createMergeStage(spec)); + + spec = BSON("$merge" << BSON("into" + << "target_collection" + << "let" + << let + << "whenMatched" + << BSON_ARRAY(BSON("$project" << BSON("x" << 1))) + << "whenNotMatched" + << "fail")); + ASSERT(createMergeStage(spec)); + + spec = BSON("$merge" << BSON("into" + << "target_collection" + << "let" + << let + << "whenMatched" + << BSON_ARRAY(BSON("$project" << BSON("x" << 1))) + << "whenNotMatched" + << "discard")); + ASSERT(createMergeStage(spec)); + + spec = BSON("$merge" 
<< BSON("into" + << "target_collection" + << "let" + << let + << "whenMatched" + << "replaceWithNew" + << "whenNotMatched" + << "insert")); + ASSERT_THROWS_CODE(createMergeStage(spec), AssertionException, 51199); + + spec = BSON("$merge" << BSON("into" + << "target_collection" + << "let" + << let + << "whenMatched" + << "replaceWithNew" + << "whenNotMatched" + << "fail")); + ASSERT_THROWS_CODE(createMergeStage(spec), AssertionException, 51199); + + spec = BSON("$merge" << BSON("into" + << "target_collection" + << "let" + << let + << "whenMatched" + << "replaceWithNew" + << "whenNotMatched" + << "discard")); + ASSERT_THROWS_CODE(createMergeStage(spec), AssertionException, 51199); + + spec = BSON("$merge" << BSON("into" + << "target_collection" + << "let" + << let + << "whenMatched" + << "merge" + << "whenNotMatched" + << "insert")); + ASSERT_THROWS_CODE(createMergeStage(spec), AssertionException, 51199); + + spec = BSON("$merge" << BSON("into" + << "target_collection" + << "let" + << let + << "whenMatched" + << "merge" + << "whenNotMatched" + << "fail")); + ASSERT_THROWS_CODE(createMergeStage(spec), AssertionException, 51199); + + spec = BSON("$merge" << BSON("into" + << "target_collection" + << "let" + << let + << "whenMatched" + << "merge" + << "whenNotMatched" + << "discard")); + ASSERT_THROWS_CODE(createMergeStage(spec), AssertionException, 51199); + + spec = BSON("$merge" << BSON("into" + << "target_collection" + << "let" + << let + << "whenMatched" + << "keepExisting" + << "whenNotMatched" + << "insert")); + ASSERT_THROWS_CODE(createMergeStage(spec), AssertionException, 51199); + + spec = BSON("$merge" << BSON("into" + << "target_collection" + << "let" + << let + << "whenMatched" + << "fail" + << "whenNotMatched" + << "insert")); + ASSERT_THROWS_CODE(createMergeStage(spec), AssertionException, 51199); +} + +TEST_F(DocumentSourceMergeTest, SerializeDefaultLetVariable) { + auto spec = BSON("$merge" << BSON("into" + << "target_collection" + << "whenMatched" + << BSON_ARRAY(BSON("$project" << BSON("x" << 1))) + << "whenNotMatched" + << "insert")); + auto mergeStage = createMergeStage(spec); + auto serialized = mergeStage->serialize().getDocument(); + ASSERT_VALUE_EQ(serialized["$merge"]["let"], + Value(BSON("new" + << "$$ROOT"))); +} + +TEST_F(DocumentSourceMergeTest, SerializeLetVariables) { + auto pipeline = BSON_ARRAY(BSON("$project" << BSON("x" + << "$$v1" + << "y" + << "$$v2" + << "z" + << "$$v3"))); + auto spec = BSON("$merge" << BSON("into" + << "target_collection" + << "let" + << BSON("v1" << 10 << "v2" + << "foo" + << "v3" + << BSON("x" << 1 << "y" << BSON("z" + << "bar"))) + << "whenMatched" + << pipeline + << "whenNotMatched" + << "insert")); + auto mergeStage = createMergeStage(spec); + ASSERT(mergeStage); + auto serialized = mergeStage->serialize().getDocument(); + ASSERT_VALUE_EQ(serialized["$merge"]["let"]["v1"], Value(BSON("$const" << 10))); + ASSERT_VALUE_EQ(serialized["$merge"]["let"]["v2"], + Value(BSON("$const" + << "foo"))); + ASSERT_VALUE_EQ(serialized["$merge"]["let"]["v3"], + Value(BSON("x" << BSON("$const" << 1) << "y" << BSON("z" << BSON("$const" + << "bar"))))); + ASSERT_VALUE_EQ(serialized["$merge"]["whenMatched"], Value(pipeline)); +} + +TEST_F(DocumentSourceMergeTest, SerializeLetArrayVariable) { + auto pipeline = BSON_ARRAY(BSON("$project" << BSON("x" + << "$$v1"))); + auto spec = + BSON("$merge" << BSON("into" + << "target_collection" + << "let" + << BSON("v1" << BSON_ARRAY(1 << "2" << BSON("x" << 1 << "y" << 2))) + << "whenMatched" + << pipeline + << 
"whenNotMatched" + << "insert")); + auto mergeStage = createMergeStage(spec); + ASSERT(mergeStage); + auto serialized = mergeStage->serialize().getDocument(); + ASSERT_VALUE_EQ(serialized["$merge"]["let"]["v1"], + Value(BSON_ARRAY(BSON("$const" << 1) << BSON("$const" + << "2") + << BSON("x" << BSON("$const" << 1) << "y" + << BSON("$const" << 2))))); + ASSERT_VALUE_EQ(serialized["$merge"]["whenMatched"], Value(pipeline)); +} + +// This test verifies that when the 'let' argument is specified as 'null', the default 'new' +// variable is still available. This is not a desirable behaviour but rather a limitation in the +// IDL parser which cannot differentiate between an optional field specified explicitly as 'null', +// or not specified at all. In both cases it will treat the field like it wasn't specified. So, +// this test ensures that we're aware of this limitation. Once the limitation is addressed in +// SERVER-41272, this test should be updated to accordingly. +TEST_F(DocumentSourceMergeTest, SerializeNullLetVariablesAsDefault) { + auto pipeline = BSON_ARRAY(BSON("$project" << BSON("x" + << "1"))); + auto spec = BSON("$merge" << BSON("into" + << "target_collection" + << "let" + << BSONNULL + << "whenMatched" + << pipeline + << "whenNotMatched" + << "insert")); + auto mergeStage = createMergeStage(spec); + ASSERT(mergeStage); + auto serialized = mergeStage->serialize().getDocument(); + ASSERT_VALUE_EQ(serialized["$merge"]["let"], + Value(BSON("new" + << "$$ROOT"))); + ASSERT_VALUE_EQ(serialized["$merge"]["whenMatched"], Value(pipeline)); +} + +TEST_F(DocumentSourceMergeTest, SerializeEmptyLetVariables) { + auto pipeline = BSON_ARRAY(BSON("$project" << BSON("x" + << "1"))); + auto spec = BSON("$merge" << BSON("into" + << "target_collection" + << "let" + << BSONObj() + << "whenMatched" + << pipeline + << "whenNotMatched" + << "insert")); + auto mergeStage = createMergeStage(spec); + ASSERT(mergeStage); + auto serialized = mergeStage->serialize().getDocument(); + ASSERT_VALUE_EQ(serialized["$merge"]["let"], Value(BSONObj())); + ASSERT_VALUE_EQ(serialized["$merge"]["whenMatched"], Value(pipeline)); +} + +TEST_F(DocumentSourceMergeTest, OnlyObjectCanBeUsedAsLetVariables) { + auto pipeline = BSON_ARRAY(BSON("$project" << BSON("x" + << "1"))); + auto spec = BSON("$merge" << BSON("into" + << "target_collection" + << "let" + << 1 + << "whenMatched" + << pipeline + << "whenNotMatched" + << "insert")); + ASSERT_THROWS_CODE(createMergeStage(spec), AssertionException, ErrorCodes::TypeMismatch); + + spec = BSON("$merge" << BSON("into" + << "target_collection" + << "let" + << "foo" + << "whenMatched" + << pipeline + << "whenNotMatched" + << "insert")); + ASSERT_THROWS_CODE(createMergeStage(spec), AssertionException, ErrorCodes::TypeMismatch); + + spec = BSON("$merge" << BSON("into" + << "target_collection" + << "let" + << BSON_ARRAY(1 << "2") + << "whenMatched" + << pipeline + << "whenNotMatched" + << "insert")); + ASSERT_THROWS_CODE(createMergeStage(spec), AssertionException, ErrorCodes::TypeMismatch); } } // namespace diff --git a/src/mongo/db/pipeline/document_source_out.cpp b/src/mongo/db/pipeline/document_source_out.cpp index 6dd3fda1d5c..4369e1c6a9c 100644 --- a/src/mongo/db/pipeline/document_source_out.cpp +++ b/src/mongo/db/pipeline/document_source_out.cpp @@ -321,6 +321,7 @@ intrusive_ptr<DocumentSource> DocumentSourceOut::create( expCtx, MergeWhenMatchedModeEnum::kFail, MergeWhenNotMatchedModeEnum::kInsert, + boost::none, /* no variables */ boost::none, /* no custom pipeline */ 
std::move(uniqueKey), targetCollectionVersion, @@ -330,7 +331,8 @@ intrusive_ptr<DocumentSource> DocumentSourceOut::create( expCtx, MergeWhenMatchedModeEnum::kReplaceWithNew, MergeWhenNotMatchedModeEnum::kInsert, - boost::none /* no custom pipeline */, + boost::none, /* no variables */ + boost::none, /* no custom pipeline */ std::move(uniqueKey), targetCollectionVersion, true /* serialize as $out stage */); diff --git a/src/mongo/db/pipeline/mongo_process_interface.h b/src/mongo/db/pipeline/mongo_process_interface.h index f36adb1f662..e96298a9fe3 100644 --- a/src/mongo/db/pipeline/mongo_process_interface.h +++ b/src/mongo/db/pipeline/mongo_process_interface.h @@ -68,6 +68,20 @@ class PipelineDeleter; */ class MongoProcessInterface { public: + /** + * Storage for a batch of BSON Objects to be updated in the write namespace. For each element + * in the batch we store a tuple of the following elements: + * 1. BSONObj - specifies the query that identifies a document in the collection to be + * updated. + * 2. write_ops::UpdateModification - either the new document we want to upsert or insert into + * the collection (i.e. a 'classic' replacement update), or the pipeline to run to compute + * the new document. + * 3. boost::optional<BSONObj> - for pipeline-style updates, specifies variables that can be + * referred to in the pipeline performing the custom update. + */ + using BatchedObjects = + std::vector<std::tuple<BSONObj, write_ops::UpdateModification, boost::optional<BSONObj>>>; + enum class CurrentOpConnectionsMode { kIncludeIdle, kExcludeIdle }; enum class CurrentOpUserMode { kIncludeAll, kExcludeOthers }; enum class CurrentOpTruncateMode { kNoTruncation, kTruncateOps }; @@ -135,8 +149,7 @@ public: */ virtual void update(const boost::intrusive_ptr<ExpressionContext>& expCtx, const NamespaceString& ns, - std::vector<BSONObj>&& queries, - std::vector<write_ops::UpdateModification>&& updates, + BatchedObjects&& batch, const WriteConcernOptions& wc, bool upsert, bool multi, @@ -150,8 +163,7 @@ public: */ virtual WriteResult updateWithResult(const boost::intrusive_ptr<ExpressionContext>& expCtx, const NamespaceString& ns, - std::vector<BSONObj>&& queries, - std::vector<write_ops::UpdateModification>&& updates, + BatchedObjects&& batch, const WriteConcernOptions& wc, bool upsert, bool multi, diff --git a/src/mongo/db/pipeline/mongos_process_interface.h b/src/mongo/db/pipeline/mongos_process_interface.h index 0f6feca3356..ae78e9e7f89 100644 --- a/src/mongo/db/pipeline/mongos_process_interface.h +++ b/src/mongo/db/pipeline/mongos_process_interface.h @@ -110,8 +110,7 @@ public: void update(const boost::intrusive_ptr<ExpressionContext>& expCtx, const NamespaceString& ns, - std::vector<BSONObj>&& queries, - std::vector<write_ops::UpdateModification>&& updates, + BatchedObjects&& batch, const WriteConcernOptions& wc, bool upsert, bool multi, @@ -121,8 +120,7 @@ public: WriteResult updateWithResult(const boost::intrusive_ptr<ExpressionContext>& expCtx, const NamespaceString& ns, - std::vector<BSONObj>&& queries, - std::vector<write_ops::UpdateModification>&& updates, + BatchedObjects&& batch, const WriteConcernOptions& wc, bool upsert, bool multi, diff --git a/src/mongo/db/pipeline/process_interface_shardsvr.cpp b/src/mongo/db/pipeline/process_interface_shardsvr.cpp index 5540bc7e57e..5d217019b11 100644 --- a/src/mongo/db/pipeline/process_interface_shardsvr.cpp +++ b/src/mongo/db/pipeline/process_interface_shardsvr.cpp @@ -132,8 +132,7 @@ void MongoInterfaceShardServer::insert(const 
boost::intrusive_ptr<ExpressionCont void MongoInterfaceShardServer::update(const boost::intrusive_ptr<ExpressionContext>& expCtx, const NamespaceString& ns, - std::vector<BSONObj>&& queries, - std::vector<write_ops::UpdateModification>&& updates, + BatchedObjects&& batch, const WriteConcernOptions& wc, bool upsert, bool multi, @@ -141,12 +140,8 @@ void MongoInterfaceShardServer::update(const boost::intrusive_ptr<ExpressionCont BatchedCommandResponse response; BatchWriteExecStats stats; - BatchedCommandRequest updateCommand(buildUpdateOp(ns, - std::move(queries), - std::move(updates), - upsert, - multi, - expCtx->bypassDocumentValidation)); + BatchedCommandRequest updateCommand( + buildUpdateOp(ns, std::move(batch), upsert, multi, expCtx->bypassDocumentValidation)); // If applicable, attach a write concern to the batched command request. attachWriteConcern(&updateCommand, wc); diff --git a/src/mongo/db/pipeline/process_interface_shardsvr.h b/src/mongo/db/pipeline/process_interface_shardsvr.h index c7044848ae5..099a85b1b12 100644 --- a/src/mongo/db/pipeline/process_interface_shardsvr.h +++ b/src/mongo/db/pipeline/process_interface_shardsvr.h @@ -76,8 +76,7 @@ public: */ void update(const boost::intrusive_ptr<ExpressionContext>& expCtx, const NamespaceString& ns, - std::vector<BSONObj>&& queries, - std::vector<write_ops::UpdateModification>&& updates, + BatchedObjects&& batch, const WriteConcernOptions& wc, bool upsert, bool multi, diff --git a/src/mongo/db/pipeline/process_interface_standalone.cpp b/src/mongo/db/pipeline/process_interface_standalone.cpp index ca5fa8e537d..7805d0621b0 100644 --- a/src/mongo/db/pipeline/process_interface_standalone.cpp +++ b/src/mongo/db/pipeline/process_interface_standalone.cpp @@ -192,19 +192,20 @@ Insert MongoInterfaceStandalone::buildInsertOp(const NamespaceString& nss, } Update MongoInterfaceStandalone::buildUpdateOp(const NamespaceString& nss, - std::vector<BSONObj>&& queries, - std::vector<write_ops::UpdateModification>&& updates, + BatchedObjects&& batch, bool upsert, bool multi, bool bypassDocValidation) { Update updateOp(nss); updateOp.setUpdates([&] { std::vector<UpdateOpEntry> updateEntries; - for (size_t index = 0; index < queries.size(); ++index) { + for (auto&& obj : batch) { updateEntries.push_back([&] { UpdateOpEntry entry; - entry.setQ(std::move(queries[index])); - entry.setU(std::move(updates[index])); + auto && [ q, u, c ] = obj; + entry.setQ(std::move(q)); + entry.setU(std::move(u)); + entry.setC(std::move(c)); entry.setUpsert(upsert); entry.setMulti(multi); return entry; @@ -245,19 +246,14 @@ void MongoInterfaceStandalone::insert(const boost::intrusive_ptr<ExpressionConte WriteResult MongoInterfaceStandalone::updateWithResult( const boost::intrusive_ptr<ExpressionContext>& expCtx, const NamespaceString& ns, - std::vector<BSONObj>&& queries, - std::vector<write_ops::UpdateModification>&& updates, + BatchedObjects&& batch, const WriteConcernOptions& wc, bool upsert, bool multi, boost::optional<OID> targetEpoch) { - auto writeResults = performUpdates(expCtx->opCtx, - buildUpdateOp(ns, - std::move(queries), - std::move(updates), - upsert, - multi, - expCtx->bypassDocumentValidation)); + auto writeResults = performUpdates( + expCtx->opCtx, + buildUpdateOp(ns, std::move(batch), upsert, multi, expCtx->bypassDocumentValidation)); // Need to check each result in the batch since the writes are unordered. 
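A brief aside on the BatchedObjects layout documented in mongo_process_interface.h above: each entry is a (query, modification, optional constants) tuple, built by the $merge stage and unpacked with structured bindings in buildUpdateOp(). The values below are invented for illustration only:

    // Illustrative only -- one batch entry for a pipeline-style update with constants.
    MongoProcessInterface::BatchedObjects batch;
    batch.emplace_back(BSON("_id" << 1),                                    // 'q': the merge-on key
                       write_ops::UpdateModification(std::vector<BSONObj>{
                           BSON("$set" << BSON("x"
                                               << "$$newX"))}),             // 'u': update pipeline
                       boost::optional<BSONObj>(BSON("newX" << 2)));        // 'c': let constants
    auto && [ q, u, c ] = batch.front();  // same unpacking style as buildUpdateOp() above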
uassertStatusOKWithContext( [&writeResults]() { @@ -275,14 +271,13 @@ WriteResult MongoInterfaceStandalone::updateWithResult( void MongoInterfaceStandalone::update(const boost::intrusive_ptr<ExpressionContext>& expCtx, const NamespaceString& ns, - std::vector<BSONObj>&& queries, - std::vector<write_ops::UpdateModification>&& updates, + BatchedObjects&& batch, const WriteConcernOptions& wc, bool upsert, bool multi, boost::optional<OID> targetEpoch) { - [[maybe_unused]] auto writeResult = updateWithResult( - expCtx, ns, std::move(queries), std::move(updates), wc, upsert, multi, targetEpoch); + [[maybe_unused]] auto writeResult = + updateWithResult(expCtx, ns, std::move(batch), wc, upsert, multi, targetEpoch); } CollectionIndexUsageMap MongoInterfaceStandalone::getIndexStats(OperationContext* opCtx, diff --git a/src/mongo/db/pipeline/process_interface_standalone.h b/src/mongo/db/pipeline/process_interface_standalone.h index ac32ad69d19..43cdec79d02 100644 --- a/src/mongo/db/pipeline/process_interface_standalone.h +++ b/src/mongo/db/pipeline/process_interface_standalone.h @@ -64,16 +64,14 @@ public: boost::optional<OID> targetEpoch) override; void update(const boost::intrusive_ptr<ExpressionContext>& expCtx, const NamespaceString& ns, - std::vector<BSONObj>&& queries, - std::vector<write_ops::UpdateModification>&& updates, + BatchedObjects&& batch, const WriteConcernOptions& wc, bool upsert, bool multi, boost::optional<OID> targetEpoch) override; WriteResult updateWithResult(const boost::intrusive_ptr<ExpressionContext>& expCtx, const NamespaceString& ns, - std::vector<BSONObj>&& queries, - std::vector<write_ops::UpdateModification>&& updates, + BatchedObjects&& batch, const WriteConcernOptions& wc, bool upsert, bool multi, @@ -162,8 +160,7 @@ protected: * Note that 'queries' and 'updates' must be the same length. */ Update buildUpdateOp(const NamespaceString& nss, - std::vector<BSONObj>&& queries, - std::vector<write_ops::UpdateModification>&& updates, + BatchedObjects&& batch, bool upsert, bool multi, bool bypassDocValidation); diff --git a/src/mongo/db/pipeline/stub_mongo_process_interface.h b/src/mongo/db/pipeline/stub_mongo_process_interface.h index f09aab9c217..abbf1555cb0 100644 --- a/src/mongo/db/pipeline/stub_mongo_process_interface.h +++ b/src/mongo/db/pipeline/stub_mongo_process_interface.h @@ -73,8 +73,7 @@ public: void update(const boost::intrusive_ptr<ExpressionContext>& expCtx, const NamespaceString& ns, - std::vector<BSONObj>&& queries, - std::vector<write_ops::UpdateModification>&& updates, + BatchedObjects&& batch, const WriteConcernOptions& wc, bool upsert, bool multi, @@ -84,8 +83,7 @@ public: WriteResult updateWithResult(const boost::intrusive_ptr<ExpressionContext>& expCtx, const NamespaceString& ns, - std::vector<BSONObj>&& queries, - std::vector<write_ops::UpdateModification>&& updates, + BatchedObjects&& batch, const WriteConcernOptions& wc, bool upsert, bool multi, diff --git a/src/mongo/db/update/pipeline_executor.cpp b/src/mongo/db/update/pipeline_executor.cpp index ff2b581220c..542734e8a39 100644 --- a/src/mongo/db/update/pipeline_executor.cpp +++ b/src/mongo/db/update/pipeline_executor.cpp @@ -44,7 +44,8 @@ constexpr StringData kIdFieldName = "_id"_sd; } // namespace PipelineExecutor::PipelineExecutor(const boost::intrusive_ptr<ExpressionContext>& expCtx, - const std::vector<BSONObj>& pipeline) + const std::vector<BSONObj>& pipeline, + boost::optional<BSONObj> constants) : _expCtx(expCtx) { // "Resolve" involved namespaces into a map. 
We have to populate this map so that any // $lookups, etc. will not fail instantiation. They will not be used for execution as these @@ -55,6 +56,17 @@ PipelineExecutor::PipelineExecutor(const boost::intrusive_ptr<ExpressionContext> for (auto&& nss : liteParsedPipeline.getInvolvedNamespaces()) { resolvedNamespaces.try_emplace(nss.coll(), nss, std::vector<BSONObj>{}); } + + if (constants) { + for (auto&& constElem : *constants) { + const auto constName = constElem.fieldNameStringData(); + Variables::uassertValidNameForUserRead(constName); + + auto varId = _expCtx->variablesParseState.defineVariable(constName); + _expCtx->variables.setConstantValue(varId, Value(constElem)); + } + } + _expCtx->setResolvedNamespaces(resolvedNamespaces); _pipeline = uassertStatusOK(Pipeline::parse(pipeline, _expCtx)); diff --git a/src/mongo/db/update/pipeline_executor.h b/src/mongo/db/update/pipeline_executor.h index 2f10539c7e2..be8d4b6e0f5 100644 --- a/src/mongo/db/update/pipeline_executor.h +++ b/src/mongo/db/update/pipeline_executor.h @@ -51,7 +51,8 @@ public: * Initializes the node with an aggregation pipeline definition. */ explicit PipelineExecutor(const boost::intrusive_ptr<ExpressionContext>& expCtx, - const std::vector<BSONObj>& pipeline); + const std::vector<BSONObj>& pipeline, + boost::optional<BSONObj> constants = boost::none); /** * Replaces the document that 'applyParams.element' belongs to with 'val'. If 'val' does not diff --git a/src/mongo/db/update/pipeline_executor_test.cpp b/src/mongo/db/update/pipeline_executor_test.cpp index d16f502a441..2a10c292532 100644 --- a/src/mongo/db/update/pipeline_executor_test.cpp +++ b/src/mongo/db/update/pipeline_executor_test.cpp @@ -285,5 +285,82 @@ TEST_F(PipelineExecutorTest, SerializeTest) { ASSERT_VALUE_EQ(serialized, Value(BSONArray(doc))); } +TEST_F(PipelineExecutorTest, RejectsInvalidConstantNames) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + const std::vector<BSONObj> pipeline; + + // Empty name. + auto constants = BSON("" << 10); + ASSERT_THROWS_CODE(PipelineExecutor(expCtx, pipeline, constants), AssertionException, 16869); + + // Invalid first character. + constants = BSON("^invalidFirstChar" << 10); + ASSERT_THROWS_CODE(PipelineExecutor(expCtx, pipeline, constants), AssertionException, 16870); + + // Contains invalid character. 
+ constants = BSON("contains*InvalidChar" << 10); + ASSERT_THROWS_CODE(PipelineExecutor(expCtx, pipeline, constants), AssertionException, 16871); +} + +TEST_F(PipelineExecutorTest, CanUseConstants) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + + const std::vector<BSONObj> pipeline{fromjson("{$set: {b: '$$var1', c: '$$var2'}}")}; + const auto constants = BSON("var1" << 10 << "var2" << BSON("x" << 1 << "y" << 2)); + PipelineExecutor exec(expCtx, pipeline, constants); + + mutablebson::Document doc(fromjson("{a: 1}")); + const auto result = exec.applyUpdate(getApplyParams(doc.root())); + ASSERT_FALSE(result.noop); + ASSERT_TRUE(result.indexesAffected); + ASSERT_EQUALS(fromjson("{a: 1, b: 10, c : {x: 1, y: 2}}"), doc); + ASSERT_FALSE(doc.isInPlaceModeEnabled()); + ASSERT_EQUALS(fromjson("{a: 1, b: 10, c : {x: 1, y: 2}}"), getLogDoc()); +} + +TEST_F(PipelineExecutorTest, CanUseConstantsAcrossMultipleUpdates) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + + const std::vector<BSONObj> pipeline{fromjson("{$set: {b: '$$var1'}}")}; + const auto constants = BSON("var1" + << "foo"); + PipelineExecutor exec(expCtx, pipeline, constants); + + // Update first doc. + mutablebson::Document doc1(fromjson("{a: 1}")); + auto result = exec.applyUpdate(getApplyParams(doc1.root())); + ASSERT_FALSE(result.noop); + ASSERT_TRUE(result.indexesAffected); + ASSERT_EQUALS(fromjson("{a: 1, b: 'foo'}"), doc1); + ASSERT_FALSE(doc1.isInPlaceModeEnabled()); + ASSERT_EQUALS(fromjson("{a: 1, b: 'foo'}"), getLogDoc()); + + // Update second doc. + mutablebson::Document doc2(fromjson("{a: 2}")); + resetApplyParams(); + result = exec.applyUpdate(getApplyParams(doc2.root())); + ASSERT_FALSE(result.noop); + ASSERT_TRUE(result.indexesAffected); + ASSERT_EQUALS(fromjson("{a: 2, b: 'foo'}"), doc2); + ASSERT_FALSE(doc2.isInPlaceModeEnabled()); + ASSERT_EQUALS(fromjson("{a: 2, b: 'foo'}"), getLogDoc()); +} + +TEST_F(PipelineExecutorTest, NoopWithConstants) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + + const std::vector<BSONObj> pipeline{fromjson("{$set: {a: '$$var1', b: '$$var2'}}")}; + const auto constants = BSON("var1" << 1 << "var2" << 2); + PipelineExecutor exec(expCtx, pipeline, constants); + + mutablebson::Document doc(fromjson("{a: 1, b: 2}")); + const auto result = exec.applyUpdate(getApplyParams(doc.root())); + ASSERT_TRUE(result.noop); + ASSERT_FALSE(result.indexesAffected); + ASSERT_EQUALS(fromjson("{a: 1, b: 2}"), doc); + ASSERT_TRUE(doc.isInPlaceModeEnabled()); + ASSERT_EQUALS(fromjson("{}"), getLogDoc()); +} + } // namespace } // namespace mongo diff --git a/src/mongo/db/update/update_driver.cpp b/src/mongo/db/update/update_driver.cpp index d66dcc8dab9..3f2024dba1c 100644 --- a/src/mongo/db/update/update_driver.cpp +++ b/src/mongo/db/update/update_driver.cpp @@ -150,6 +150,7 @@ UpdateDriver::UpdateDriver(const boost::intrusive_ptr<ExpressionContext>& expCtx void UpdateDriver::parse( const write_ops::UpdateModification& updateMod, const std::map<StringData, std::unique_ptr<ExpressionWithPlaceholder>>& arrayFilters, + boost::optional<BSONObj> constants, const bool multi) { invariant(!_updateExecutor, "Multiple calls to parse() on same UpdateDriver"); @@ -158,11 +159,13 @@ void UpdateDriver::parse( "arrayFilters may not be specified for pipeline-syle updates", arrayFilters.empty()); _updateExecutor = - stdx::make_unique<PipelineExecutor>(_expCtx, updateMod.getUpdatePipeline()); + 
stdx::make_unique<PipelineExecutor>(_expCtx, updateMod.getUpdatePipeline(), constants); _updateType = UpdateType::kPipeline; return; } + uassert(51198, "Constant values may only be specified for pipeline updates", !constants); + // Check if the update expression is a full object replacement. if (isDocReplacement(updateMod)) { uassert(ErrorCodes::FailedToParse, diff --git a/src/mongo/db/update/update_driver.h b/src/mongo/db/update/update_driver.h index 24c0cd72edc..f4ef497840e 100644 --- a/src/mongo/db/update/update_driver.h +++ b/src/mongo/db/update/update_driver.h @@ -66,6 +66,7 @@ public: */ void parse(const write_ops::UpdateModification& updateExpr, const std::map<StringData, std::unique_ptr<ExpressionWithPlaceholder>>& arrayFilters, + boost::optional<BSONObj> constants = boost::none, const bool multi = false); /** diff --git a/src/mongo/db/update/update_driver_test.cpp b/src/mongo/db/update/update_driver_test.cpp index 5cb065b528f..dcfab0d74fb 100644 --- a/src/mongo/db/update/update_driver_test.cpp +++ b/src/mongo/db/update/update_driver_test.cpp @@ -95,7 +95,6 @@ TEST(Parse, ObjectReplacment) { } TEST(Parse, ParseUpdateWithPipeline) { - setTestCommandsEnabled(true); boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); UpdateDriver driver(expCtx); std::map<StringData, std::unique_ptr<ExpressionWithPlaceholder>> arrayFilters; @@ -104,6 +103,20 @@ TEST(Parse, ParseUpdateWithPipeline) { ASSERT_TRUE(driver.type() == UpdateDriver::UpdateType::kPipeline); } +TEST(Parse, ParseUpdateWithPipelineAndVariables) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + UpdateDriver driver(expCtx); + std::map<StringData, std::unique_ptr<ExpressionWithPlaceholder>> arrayFilters; + const auto variables = BSON("var1" << 1 << "var2" + << "foo"); + auto updateObj = BSON("u" << BSON_ARRAY(BSON("$set" << BSON("a" + << "$$var1" + << "b" + << "$$var2")))); + ASSERT_DOES_NOT_THROW(driver.parse(updateObj["u"], arrayFilters, variables)); + ASSERT_TRUE(driver.type() == UpdateDriver::UpdateType::kPipeline); +} + TEST(Parse, EmptyMod) { boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); UpdateDriver driver(expCtx); @@ -551,6 +564,7 @@ public: boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); UpdateDriver driver(expCtx); std::map<StringData, std::unique_ptr<ExpressionWithPlaceholder>> arrayFilters; + for (const auto& filter : arrayFilterSpec) { auto parsedFilter = assertGet(MatchExpressionParser::parse(filter, expCtx)); auto expr = assertGet(ExpressionWithPlaceholder::make(std::move(parsedFilter))); diff --git a/src/mongo/dbtests/query_stage_update.cpp b/src/mongo/dbtests/query_stage_update.cpp index 91b1669e050..715aa75d2ad 100644 --- a/src/mongo/dbtests/query_stage_update.cpp +++ b/src/mongo/dbtests/query_stage_update.cpp @@ -215,9 +215,10 @@ public: request.setUpdateModification(updates); const std::map<StringData, std::unique_ptr<ExpressionWithPlaceholder>> arrayFilters; + const auto constants = boost::none; - ASSERT_DOES_NOT_THROW( - driver.parse(request.getUpdateModification(), arrayFilters, request.isMulti())); + ASSERT_DOES_NOT_THROW(driver.parse( + request.getUpdateModification(), arrayFilters, constants, request.isMulti())); // Setup update params. 
UpdateStageParams params(&request, &driver, opDebug); @@ -288,9 +289,10 @@ public: request.setUpdateModification(updates); const std::map<StringData, std::unique_ptr<ExpressionWithPlaceholder>> arrayFilters; + const auto constants = boost::none; - ASSERT_DOES_NOT_THROW( - driver.parse(request.getUpdateModification(), arrayFilters, request.isMulti())); + ASSERT_DOES_NOT_THROW(driver.parse( + request.getUpdateModification(), arrayFilters, constants, request.isMulti())); // Configure the scan. CollectionScanParams collScanParams; @@ -399,9 +401,10 @@ public: request.setReturnDocs(UpdateRequest::RETURN_OLD); const std::map<StringData, std::unique_ptr<ExpressionWithPlaceholder>> arrayFilters; + const auto constants = boost::none; - ASSERT_DOES_NOT_THROW( - driver.parse(request.getUpdateModification(), arrayFilters, request.isMulti())); + ASSERT_DOES_NOT_THROW(driver.parse( + request.getUpdateModification(), arrayFilters, constants, request.isMulti())); // Configure a QueuedDataStage to pass the first object in the collection back in a // RID_AND_OBJ state. @@ -490,9 +493,10 @@ public: request.setReturnDocs(UpdateRequest::RETURN_NEW); const std::map<StringData, std::unique_ptr<ExpressionWithPlaceholder>> arrayFilters; + const auto constants = boost::none; - ASSERT_DOES_NOT_THROW( - driver.parse(request.getUpdateModification(), arrayFilters, request.isMulti())); + ASSERT_DOES_NOT_THROW(driver.parse( + request.getUpdateModification(), arrayFilters, constants, request.isMulti())); // Configure a QueuedDataStage to pass the first object in the collection back in a // RID_AND_OBJ state. diff --git a/src/mongo/embedded/stitch_support/stitch_support.cpp b/src/mongo/embedded/stitch_support/stitch_support.cpp index 06ba765eab8..c7f8f34d164 100644 --- a/src/mongo/embedded/stitch_support/stitch_support.cpp +++ b/src/mongo/embedded/stitch_support/stitch_support.cpp @@ -240,10 +240,7 @@ struct stitch_support_v1_update { this->parsedFilters = uassertStatusOK(mongo::ParsedUpdate::parseArrayFilters( arrayFilterVector, this->opCtx.get(), collator ? collator->collator.get() : nullptr)); - // Initializing the update as single-document allows document-replacement updates. - bool multi = false; - - updateDriver.parse(this->updateExpr, parsedFilters, multi); + updateDriver.parse(this->updateExpr, parsedFilters); uassert(51037, "Updates with a positional operator require a matcher object.", |
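Tying the above call sites together, a hedged end-to-end sketch (the constant and field names are invented) of how a pipeline update plus constants flows through UpdateDriver::parse() into PipelineExecutor, which registers each constant as a read-only variable:

    // Illustrative sketch, modeled on the ParseUpdateWithPipelineAndVariables test above.
    boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());
    UpdateDriver driver(expCtx);
    std::map<StringData, std::unique_ptr<ExpressionWithPlaceholder>> arrayFilters;

    const auto constants = BSON("rate" << 0.15);  // hypothetical 'c' document
    auto updateObj = BSON("u" << BSON_ARRAY(BSON("$set" << BSON("tax"
                                                                << "$$rate"))));
    driver.parse(updateObj["u"], arrayFilters, constants);  // selects the pipeline executor
    // Inside PipelineExecutor, each constant name is checked with
    // Variables::uassertValidNameForUserRead() and bound as a constant value, so the
    // update pipeline can refer to it as $$rate.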