diff options
author | Hana Pearlman <hana.pearlman@mongodb.com> | 2022-03-28 21:32:48 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2022-03-28 22:52:34 +0000 |
commit | 6eb0a1b696278d4473365a169ae08f51c8a12388 (patch) | |
tree | 1fdb2b9705f131f9ae513ded5eecebcf46f8fc58 | |
parent | 51918a385ab8b3fc1f922bede326ff84e846a47b (diff) | |
download | mongo-6eb0a1b696278d4473365a169ae08f51c8a12388.tar.gz |
SERVER-64359: Implement FLE server-side rewrite for agg command on mongos
-rw-r--r-- | src/mongo/db/commands/current_op.cpp | 2 | ||||
-rw-r--r-- | src/mongo/db/commands/current_op_common.cpp | 4 | ||||
-rw-r--r-- | src/mongo/db/commands/current_op_common.h | 4 | ||||
-rw-r--r-- | src/mongo/db/query/fle/server_rewrite.cpp | 158 | ||||
-rw-r--r-- | src/mongo/db/query/fle/server_rewrite.h | 11 | ||||
-rw-r--r-- | src/mongo/s/commands/cluster_current_op.cpp | 2 | ||||
-rw-r--r-- | src/mongo/s/commands/cluster_pipeline_cmd.cpp | 4 | ||||
-rw-r--r-- | src/mongo/s/query/SConscript | 1 | ||||
-rw-r--r-- | src/mongo/s/query/cluster_aggregate.cpp | 26 | ||||
-rw-r--r-- | src/mongo/s/query/cluster_aggregate.h | 6 | ||||
-rw-r--r-- | src/mongo/s/transaction_router.cpp | 8 |
11 files changed, 159 insertions, 67 deletions
diff --git a/src/mongo/db/commands/current_op.cpp b/src/mongo/db/commands/current_op.cpp index 1fd9200f9da..78fec805202 100644 --- a/src/mongo/db/commands/current_op.cpp +++ b/src/mongo/db/commands/current_op.cpp @@ -67,7 +67,7 @@ public: } virtual StatusWith<CursorResponse> runAggregation( - OperationContext* opCtx, const AggregateCommandRequest& request) const final { + OperationContext* opCtx, AggregateCommandRequest& request) const final { auto aggCmdObj = aggregation_request_helper::serializeToCommandObj(request); rpc::OpMsgReplyBuilder replyBuilder; diff --git a/src/mongo/db/commands/current_op_common.cpp b/src/mongo/db/commands/current_op_common.cpp index d44d5c066d2..6b816a79057 100644 --- a/src/mongo/db/commands/current_op_common.cpp +++ b/src/mongo/db/commands/current_op_common.cpp @@ -110,8 +110,8 @@ bool CurrentOpCommandBase::run(OperationContext* opCtx, pipeline.push_back(groupBuilder.obj()); // Pipeline is complete; create an AggregateCommandRequest for $currentOp. - const AggregateCommandRequest request(NamespaceString::makeCollectionlessAggregateNSS("admin"), - std::move(pipeline)); + AggregateCommandRequest request(NamespaceString::makeCollectionlessAggregateNSS("admin"), + std::move(pipeline)); // Run the pipeline and obtain a CursorResponse. auto aggResults = uassertStatusOK(runAggregation(opCtx, request)); diff --git a/src/mongo/db/commands/current_op_common.h b/src/mongo/db/commands/current_op_common.h index 31ff95ac764..789ece31e0a 100644 --- a/src/mongo/db/commands/current_op_common.h +++ b/src/mongo/db/commands/current_op_common.h @@ -73,8 +73,8 @@ private: * Runs the aggregation specified by the supplied AggregateCommandRequest, returning a * CursorResponse if successful or a Status containing the error otherwise. */ - virtual StatusWith<CursorResponse> runAggregation( - OperationContext* opCtx, const AggregateCommandRequest& request) const = 0; + virtual StatusWith<CursorResponse> runAggregation(OperationContext* opCtx, + AggregateCommandRequest& request) const = 0; /** * Allows overriders to optionally write additional data to the response object before the final diff --git a/src/mongo/db/query/fle/server_rewrite.cpp b/src/mongo/db/query/fle/server_rewrite.cpp index b83b087a04a..792aed343c8 100644 --- a/src/mongo/db/query/fle/server_rewrite.cpp +++ b/src/mongo/db/query/fle/server_rewrite.cpp @@ -34,6 +34,9 @@ #include "mongo/crypto/fle_crypto.h" #include "mongo/crypto/fle_tags.h" #include "mongo/db/fle_crud.h" +#include "mongo/db/pipeline/document_source_geo_near.h" +#include "mongo/db/pipeline/document_source_graph_lookup.h" +#include "mongo/db/pipeline/document_source_match.h" #include "mongo/db/query/collation/collator_factory_interface.h" #include "mongo/db/transaction_api.h" #include "mongo/s/grid.h" @@ -60,51 +63,83 @@ boost::intrusive_ptr<ExpressionContext> makeExpCtx(OperationContext* opCtx, findCommand->getLegacyRuntimeConstants(), findCommand->getLet()); } -} // namespace -BSONObj rewriteEncryptedFilter(boost::intrusive_ptr<ExpressionContext> expCtx, - const FLEStateCollectionReader& escReader, - const FLEStateCollectionReader& eccReader, - BSONObj filter) { - return MatchExpressionRewrite(expCtx, escReader, eccReader, filter).get(); -} +class RewriteBase { +public: + RewriteBase(boost::intrusive_ptr<ExpressionContext> expCtx, + const NamespaceString& nss, + const EncryptionInformation& encryptInfo) + : expCtx(expCtx), db(nss.db()) { + auto efc = EncryptionInformationHelpers::getAndValidateSchema(nss, encryptInfo); + esc = efc.getEscCollection()->toString(); + ecc = efc.getEccCollection()->toString(); + } + virtual ~RewriteBase(){}; + virtual void doRewrite(FLEStateCollectionReader& escReader, + FLEStateCollectionReader& eccReader){}; + + boost::intrusive_ptr<ExpressionContext> expCtx; + std::string esc; + std::string ecc; + std::string db; +}; + +// This class handles rewriting of an entire pipeline. +class PipelineRewrite : public RewriteBase { +public: + PipelineRewrite(const NamespaceString& nss, + const EncryptionInformation& encryptInfo, + std::unique_ptr<Pipeline, PipelineDeleter> toRewrite) + : RewriteBase(toRewrite->getContext(), nss, encryptInfo), pipeline(std::move(toRewrite)) {} + + ~PipelineRewrite(){}; + void doRewrite(FLEStateCollectionReader& escReader, FLEStateCollectionReader& eccReader) final { + for (auto&& source : pipeline->getSources()) { + if (auto match = dynamic_cast<DocumentSourceMatch*>(source.get())) { + match->rebuild( + rewriteEncryptedFilter(expCtx, escReader, eccReader, match->getQuery())); + } else if (auto geoNear = dynamic_cast<DocumentSourceGeoNear*>(source.get())) { + geoNear->setQuery( + rewriteEncryptedFilter(expCtx, escReader, eccReader, geoNear->getQuery())); + } else if (auto graphLookup = dynamic_cast<DocumentSourceGraphLookUp*>(source.get()); + graphLookup && graphLookup->getAdditionalFilter()) { + graphLookup->setAdditionalFilter(rewriteEncryptedFilter( + expCtx, escReader, eccReader, graphLookup->getAdditionalFilter().get())); + } + } + } -void processFindCommand(OperationContext* opCtx, - NamespaceString nss, - FindCommandRequest* findCommand) { - invariant(findCommand->getEncryptionInformation()); + std::unique_ptr<Pipeline, PipelineDeleter> getPipeline() { + return std::move(pipeline); + } + +private: + std::unique_ptr<Pipeline, PipelineDeleter> pipeline; +}; + +// This class handles rewriting of a single match expression, represented as a BSONObj. +class FilterRewrite : public RewriteBase { +public: + FilterRewrite(boost::intrusive_ptr<ExpressionContext> expCtx, + const NamespaceString& nss, + const EncryptionInformation& encryptInfo, + const BSONObj toRewrite) + : RewriteBase(expCtx, nss, encryptInfo), userFilter(toRewrite) {} + + ~FilterRewrite(){}; + void doRewrite(FLEStateCollectionReader& escReader, FLEStateCollectionReader& eccReader) final { + rewrittenFilter = rewriteEncryptedFilter(expCtx, escReader, eccReader, userFilter); + } - auto efc = EncryptionInformationHelpers::getAndValidateSchema( - nss, findCommand->getEncryptionInformation().get()); - - // The transaction runs in a separate executor, and so we can't pass data by - // reference into the lambda. This struct holds all the data we need inside the - // lambda, and is passed in a more threadsafe shared_ptr. - struct SharedBlock { - SharedBlock(NamespaceString nss, - std::string esc, - std::string ecc, - const BSONObj userFilter, - boost::intrusive_ptr<ExpressionContext> expCtx) - : esc(std::move(esc)), - ecc(std::move(ecc)), - userFilter(userFilter), - db(nss.db()), - expCtx(expCtx) {} - std::string esc; - std::string ecc; - const BSONObj userFilter; - BSONObj rewrittenFilter; - std::string db; - boost::intrusive_ptr<ExpressionContext> expCtx; - }; - - auto sharedBlock = std::make_shared<SharedBlock>(nss, - efc.getEscCollection().get().toString(), - efc.getEccCollection().get().toString(), - findCommand->getFilter().getOwned(), - makeExpCtx(opCtx, findCommand)); + const BSONObj userFilter; + BSONObj rewrittenFilter; +}; +// This helper executes the rewrite(s) inside a transaction. The transaction runs in a separate +// executor, and so we can't pass data by reference into the lambda. The provided rewriter should +// hold all the data we need to do the rewriting inside the lambda, and is passed in a more +// threadsafe shared_ptr. The result of applying the rewrites can be accessed in the RewriteBase. +void doFLERewriteInTxn(OperationContext* opCtx, std::shared_ptr<RewriteBase> sharedBlock) { auto txn = std::make_shared<txn_api::TransactionWithRetries>( opCtx, Grid::get(opCtx)->getExecutorPool()->getFixedExecutor(), @@ -125,8 +160,7 @@ void processFindCommand(OperationContext* opCtx, auto eccReader = makeCollectionReader(&queryInterface, sharedBlock->ecc); // Rewrite the MatchExpression. - sharedBlock->rewrittenFilter = rewriteEncryptedFilter( - sharedBlock->expCtx, escReader, eccReader, sharedBlock->userFilter); + sharedBlock->doRewrite(escReader, eccReader); return SemiFuture<void>::makeReady(); }); @@ -134,17 +168,43 @@ void processFindCommand(OperationContext* opCtx, uassertStatusOK(swCommitResult); uassertStatusOK(swCommitResult.getValue().cmdStatus); uassertStatusOK(swCommitResult.getValue().getEffectiveStatus()); +} +} // namespace + +BSONObj rewriteEncryptedFilter(boost::intrusive_ptr<ExpressionContext> expCtx, + const FLEStateCollectionReader& escReader, + const FLEStateCollectionReader& eccReader, + BSONObj filter) { + return MatchExpressionRewrite(expCtx, escReader, eccReader, filter).get(); +} + +void processFindCommand(OperationContext* opCtx, + NamespaceString nss, + FindCommandRequest* findCommand) { + invariant(findCommand->getEncryptionInformation()); + + auto sharedBlock = + std::make_shared<FilterRewrite>(makeExpCtx(opCtx, findCommand), + nss, + findCommand->getEncryptionInformation().get(), + findCommand->getFilter().getOwned()); + doFLERewriteInTxn(opCtx, sharedBlock); auto rewrittenFilter = sharedBlock->rewrittenFilter.getOwned(); findCommand->setFilter(std::move(rewrittenFilter)); findCommand->setEncryptionInformation(boost::none); +} - // If we are in a multi-document transaction, then the transaction API has taken - // care of setting the readConcern on the transaction, and the find command - // shouldn't provide its own readConcern. - if (opCtx->inMultiDocumentTransaction()) { - findCommand->setReadConcern(boost::none); - } +std::unique_ptr<Pipeline, PipelineDeleter> processPipeline( + OperationContext* opCtx, + NamespaceString nss, + const EncryptionInformation& encryptInfo, + std::unique_ptr<Pipeline, PipelineDeleter> toRewrite) { + + auto sharedBlock = std::make_shared<PipelineRewrite>(nss, encryptInfo, std::move(toRewrite)); + doFLERewriteInTxn(opCtx, sharedBlock); + + return sharedBlock->getPipeline(); } std::unique_ptr<MatchExpression> MatchExpressionRewrite::_rewriteMatchExpression( diff --git a/src/mongo/db/query/fle/server_rewrite.h b/src/mongo/db/query/fle/server_rewrite.h index d914d60e01d..3ea8fb5592b 100644 --- a/src/mongo/db/query/fle/server_rewrite.h +++ b/src/mongo/db/query/fle/server_rewrite.h @@ -58,6 +58,17 @@ void processFindCommand(OperationContext* opCtx, FindCommandRequest* findCommand); /** + * Process a pipeline with encryptionInformation by rewriting the pipeline to query against the + * underlying tags array, where appropriate. After this rewriting is complete, there is no more FLE + * work to be done. The encryption info does not need to be kept around (e.g. on a command object). + */ +std::unique_ptr<Pipeline, PipelineDeleter> processPipeline( + OperationContext* opCtx, + NamespaceString nss, + const EncryptionInformation& encryptInfo, + std::unique_ptr<Pipeline, PipelineDeleter> toRewrite); + +/** * Class which handles rewriting filter MatchExpressions for FLE2. The functionality is encapsulated * as a class rather than just a namespace so that the collection readers don't have to be passed * around as extra arguments to every function. diff --git a/src/mongo/s/commands/cluster_current_op.cpp b/src/mongo/s/commands/cluster_current_op.cpp index 3c0e73f30d3..5f4f2187b49 100644 --- a/src/mongo/s/commands/cluster_current_op.cpp +++ b/src/mongo/s/commands/cluster_current_op.cpp @@ -71,7 +71,7 @@ private: } virtual StatusWith<CursorResponse> runAggregation( - OperationContext* opCtx, const AggregateCommandRequest& request) const final { + OperationContext* opCtx, AggregateCommandRequest& request) const final { auto nss = request.getNamespace(); BSONObjBuilder responseBuilder; diff --git a/src/mongo/s/commands/cluster_pipeline_cmd.cpp b/src/mongo/s/commands/cluster_pipeline_cmd.cpp index 87dbb0e05c1..aec0641e905 100644 --- a/src/mongo/s/commands/cluster_pipeline_cmd.cpp +++ b/src/mongo/s/commands/cluster_pipeline_cmd.cpp @@ -94,7 +94,7 @@ public: public: Invocation(Command* cmd, const OpMsgRequest& request, - const AggregateCommandRequest aggregationRequest, + AggregateCommandRequest aggregationRequest, PrivilegeVector privileges) : CommandInvocation(cmd), _request(request), @@ -171,7 +171,7 @@ public: const OpMsgRequest& _request; const std::string _dbName; - const AggregateCommandRequest _aggregationRequest; + AggregateCommandRequest _aggregationRequest; const LiteParsedPipeline _liteParsedPipeline; const PrivilegeVector _privileges; }; diff --git a/src/mongo/s/query/SConscript b/src/mongo/s/query/SConscript index 1ebf88e5164..f6f1ac53b05 100644 --- a/src/mongo/s/query/SConscript +++ b/src/mongo/s/query/SConscript @@ -33,6 +33,7 @@ env.Library( 'cluster_aggregation_planner.cpp', ], LIBDEPS=[ + '$BUILD_DIR/mongo/db/fle_crud', '$BUILD_DIR/mongo/db/pipeline/pipeline', '$BUILD_DIR/mongo/db/pipeline/process_interface/mongos_process_interface', '$BUILD_DIR/mongo/db/pipeline/sharded_agg_helpers', diff --git a/src/mongo/s/query/cluster_aggregate.cpp b/src/mongo/s/query/cluster_aggregate.cpp index 966ca7c970c..3c754850628 100644 --- a/src/mongo/s/query/cluster_aggregate.cpp +++ b/src/mongo/s/query/cluster_aggregate.cpp @@ -53,6 +53,7 @@ #include "mongo/db/query/cursor_response.h" #include "mongo/db/query/explain_common.h" #include "mongo/db/query/find_common.h" +#include "mongo/db/query/fle/server_rewrite.h" #include "mongo/db/timeseries/timeseries_options.h" #include "mongo/db/views/resolved_view.h" #include "mongo/db/views/view.h" @@ -251,7 +252,7 @@ std::vector<BSONObj> rebuildPipelineWithTimeSeriesGranularity(const std::vector< Status ClusterAggregate::runAggregate(OperationContext* opCtx, const Namespaces& namespaces, - const AggregateCommandRequest& request, + AggregateCommandRequest& request, const PrivilegeVector& privileges, BSONObjBuilder* result) { return runAggregate(opCtx, namespaces, request, {request}, privileges, result); @@ -259,7 +260,7 @@ Status ClusterAggregate::runAggregate(OperationContext* opCtx, Status ClusterAggregate::runAggregate(OperationContext* opCtx, const Namespaces& namespaces, - const AggregateCommandRequest& request, + AggregateCommandRequest& request, const LiteParsedPipeline& liteParsedPipeline, const PrivilegeVector& privileges, BSONObjBuilder* result) { @@ -269,7 +270,7 @@ Status ClusterAggregate::runAggregate(OperationContext* opCtx, Status ClusterAggregate::runAggregate(OperationContext* opCtx, const Namespaces& namespaces, - const AggregateCommandRequest& request, + AggregateCommandRequest& request, const LiteParsedPipeline& liteParsedPipeline, const PrivilegeVector& privileges, boost::optional<ChunkManager> cm, @@ -298,6 +299,7 @@ Status ClusterAggregate::runAggregate(OperationContext* opCtx, opCtx, isSharded, request.getExplain(), serverGlobalParams.enableMajorityReadConcern); auto hasChangeStream = liteParsedPipeline.hasChangeStream(); auto involvedNamespaces = liteParsedPipeline.getInvolvedNamespaces(); + auto shouldDoFLERewrite = fle::shouldRewrite(&request); uassert(6256300, str::stream() << "On mongos, " << AggregateCommandRequest::kCollectionUUIDFieldName @@ -363,10 +365,26 @@ Status ClusterAggregate::runAggregate(OperationContext* opCtx, // Parse and optimize the full pipeline. auto pipeline = Pipeline::parse(request.getPipeline(), expCtx); + + // If the aggregate command supports encrypted collections, do rewrites of the pipeline to + // support querying against encrypted fields. + if (shouldDoFLERewrite) { + // After this rewriting, the encryption info does not need to be kept around. + pipeline = fle::processPipeline(opCtx, + namespaces.executionNss, + request.getEncryptionInformation().get(), + std::move(pipeline)); + request.setEncryptionInformation(boost::none); + } + pipeline->optimizePipeline(); return pipeline; }; + // The pipeline is not allowed to passthrough if any stage is not allowed to passthrough or if + // the pipeline needs to undergo FLE rewriting first. + auto allowedToPassthrough = + liteParsedPipeline.allowedToPassthroughFromMongos() && !shouldDoFLERewrite; auto targeter = cluster_aggregation_planner::AggregationTargeter::make( opCtx, namespaces.executionNss, @@ -374,7 +392,7 @@ Status ClusterAggregate::runAggregate(OperationContext* opCtx, cm, involvedNamespaces, hasChangeStream, - liteParsedPipeline.allowedToPassthroughFromMongos(), + allowedToPassthrough, request.getPassthroughToShard().has_value()); if (!expCtx) { diff --git a/src/mongo/s/query/cluster_aggregate.h b/src/mongo/s/query/cluster_aggregate.h index f16beabb648..5b3530d0860 100644 --- a/src/mongo/s/query/cluster_aggregate.h +++ b/src/mongo/s/query/cluster_aggregate.h @@ -83,7 +83,7 @@ public: */ static Status runAggregate(OperationContext* opCtx, const Namespaces& namespaces, - const AggregateCommandRequest& request, + AggregateCommandRequest& request, const LiteParsedPipeline& liteParsedPipeline, const PrivilegeVector& privileges, boost::optional<ChunkManager> cm, @@ -94,7 +94,7 @@ public: */ static Status runAggregate(OperationContext* opCtx, const Namespaces& namespaces, - const AggregateCommandRequest& request, + AggregateCommandRequest& request, const LiteParsedPipeline& liteParsedPipeline, const PrivilegeVector& privileges, BSONObjBuilder* result); @@ -105,7 +105,7 @@ public: */ static Status runAggregate(OperationContext* opCtx, const Namespaces& namespaces, - const AggregateCommandRequest& request, + AggregateCommandRequest& request, const PrivilegeVector& privileges, BSONObjBuilder* result); diff --git a/src/mongo/s/transaction_router.cpp b/src/mongo/s/transaction_router.cpp index d2e8a55b3a8..e2739a36778 100644 --- a/src/mongo/s/transaction_router.cpp +++ b/src/mongo/s/transaction_router.cpp @@ -442,11 +442,13 @@ BSONObj TransactionRouter::Participant::attachTxnFieldsIfNeeded( auto cmdName = cmd.firstElement().fieldNameStringData(); bool mustStartTransaction = isFirstStatementInThisParticipant && !isTransactionCommand(cmdName); + // Strip the command of its read concern if it should not have one. if (!mustStartTransaction) { auto readConcernFieldName = repl::ReadConcernArgs::kReadConcernFieldName; - dassert(!cmd.hasField(readConcernFieldName) || - cmd.getObjectField(readConcernFieldName).isEmpty() || - sharedOptions.isInternalTransactionForRetryableWrite); + if (cmd.hasField(readConcernFieldName) && + !sharedOptions.isInternalTransactionForRetryableWrite) { + cmd = cmd.removeField(readConcernFieldName); + } } BSONObjBuilder newCmd = mustStartTransaction |