diff options
Diffstat (limited to 'src/mongo/s/query/cluster_aggregate.cpp')
-rw-r--r-- | src/mongo/s/query/cluster_aggregate.cpp | 26 |
1 files changed, 22 insertions, 4 deletions
diff --git a/src/mongo/s/query/cluster_aggregate.cpp b/src/mongo/s/query/cluster_aggregate.cpp index 966ca7c970c..3c754850628 100644 --- a/src/mongo/s/query/cluster_aggregate.cpp +++ b/src/mongo/s/query/cluster_aggregate.cpp @@ -53,6 +53,7 @@ #include "mongo/db/query/cursor_response.h" #include "mongo/db/query/explain_common.h" #include "mongo/db/query/find_common.h" +#include "mongo/db/query/fle/server_rewrite.h" #include "mongo/db/timeseries/timeseries_options.h" #include "mongo/db/views/resolved_view.h" #include "mongo/db/views/view.h" @@ -251,7 +252,7 @@ std::vector<BSONObj> rebuildPipelineWithTimeSeriesGranularity(const std::vector< Status ClusterAggregate::runAggregate(OperationContext* opCtx, const Namespaces& namespaces, - const AggregateCommandRequest& request, + AggregateCommandRequest& request, const PrivilegeVector& privileges, BSONObjBuilder* result) { return runAggregate(opCtx, namespaces, request, {request}, privileges, result); @@ -259,7 +260,7 @@ Status ClusterAggregate::runAggregate(OperationContext* opCtx, Status ClusterAggregate::runAggregate(OperationContext* opCtx, const Namespaces& namespaces, - const AggregateCommandRequest& request, + AggregateCommandRequest& request, const LiteParsedPipeline& liteParsedPipeline, const PrivilegeVector& privileges, BSONObjBuilder* result) { @@ -269,7 +270,7 @@ Status ClusterAggregate::runAggregate(OperationContext* opCtx, Status ClusterAggregate::runAggregate(OperationContext* opCtx, const Namespaces& namespaces, - const AggregateCommandRequest& request, + AggregateCommandRequest& request, const LiteParsedPipeline& liteParsedPipeline, const PrivilegeVector& privileges, boost::optional<ChunkManager> cm, @@ -298,6 +299,7 @@ Status ClusterAggregate::runAggregate(OperationContext* opCtx, opCtx, isSharded, request.getExplain(), serverGlobalParams.enableMajorityReadConcern); auto hasChangeStream = liteParsedPipeline.hasChangeStream(); auto involvedNamespaces = liteParsedPipeline.getInvolvedNamespaces(); + auto shouldDoFLERewrite = fle::shouldRewrite(&request); uassert(6256300, str::stream() << "On mongos, " << AggregateCommandRequest::kCollectionUUIDFieldName @@ -363,10 +365,26 @@ Status ClusterAggregate::runAggregate(OperationContext* opCtx, // Parse and optimize the full pipeline. auto pipeline = Pipeline::parse(request.getPipeline(), expCtx); + + // If the aggregate command supports encrypted collections, do rewrites of the pipeline to + // support querying against encrypted fields. + if (shouldDoFLERewrite) { + // After this rewriting, the encryption info does not need to be kept around. + pipeline = fle::processPipeline(opCtx, + namespaces.executionNss, + request.getEncryptionInformation().get(), + std::move(pipeline)); + request.setEncryptionInformation(boost::none); + } + pipeline->optimizePipeline(); return pipeline; }; + // The pipeline is not allowed to passthrough if any stage is not allowed to passthrough or if + // the pipeline needs to undergo FLE rewriting first. + auto allowedToPassthrough = + liteParsedPipeline.allowedToPassthroughFromMongos() && !shouldDoFLERewrite; auto targeter = cluster_aggregation_planner::AggregationTargeter::make( opCtx, namespaces.executionNss, @@ -374,7 +392,7 @@ Status ClusterAggregate::runAggregate(OperationContext* opCtx, cm, involvedNamespaces, hasChangeStream, - liteParsedPipeline.allowedToPassthroughFromMongos(), + allowedToPassthrough, request.getPassthroughToShard().has_value()); if (!expCtx) { |