summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHana Pearlman <hana.pearlman@mongodb.com>2022-03-28 21:32:48 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2022-03-28 22:52:34 +0000
commit6eb0a1b696278d4473365a169ae08f51c8a12388 (patch)
tree1fdb2b9705f131f9ae513ded5eecebcf46f8fc58
parent51918a385ab8b3fc1f922bede326ff84e846a47b (diff)
downloadmongo-6eb0a1b696278d4473365a169ae08f51c8a12388.tar.gz
SERVER-64359: Implement FLE server-side rewrite for agg command on mongos
-rw-r--r--src/mongo/db/commands/current_op.cpp2
-rw-r--r--src/mongo/db/commands/current_op_common.cpp4
-rw-r--r--src/mongo/db/commands/current_op_common.h4
-rw-r--r--src/mongo/db/query/fle/server_rewrite.cpp158
-rw-r--r--src/mongo/db/query/fle/server_rewrite.h11
-rw-r--r--src/mongo/s/commands/cluster_current_op.cpp2
-rw-r--r--src/mongo/s/commands/cluster_pipeline_cmd.cpp4
-rw-r--r--src/mongo/s/query/SConscript1
-rw-r--r--src/mongo/s/query/cluster_aggregate.cpp26
-rw-r--r--src/mongo/s/query/cluster_aggregate.h6
-rw-r--r--src/mongo/s/transaction_router.cpp8
11 files changed, 159 insertions, 67 deletions
diff --git a/src/mongo/db/commands/current_op.cpp b/src/mongo/db/commands/current_op.cpp
index 1fd9200f9da..78fec805202 100644
--- a/src/mongo/db/commands/current_op.cpp
+++ b/src/mongo/db/commands/current_op.cpp
@@ -67,7 +67,7 @@ public:
}
virtual StatusWith<CursorResponse> runAggregation(
- OperationContext* opCtx, const AggregateCommandRequest& request) const final {
+ OperationContext* opCtx, AggregateCommandRequest& request) const final {
auto aggCmdObj = aggregation_request_helper::serializeToCommandObj(request);
rpc::OpMsgReplyBuilder replyBuilder;
diff --git a/src/mongo/db/commands/current_op_common.cpp b/src/mongo/db/commands/current_op_common.cpp
index d44d5c066d2..6b816a79057 100644
--- a/src/mongo/db/commands/current_op_common.cpp
+++ b/src/mongo/db/commands/current_op_common.cpp
@@ -110,8 +110,8 @@ bool CurrentOpCommandBase::run(OperationContext* opCtx,
pipeline.push_back(groupBuilder.obj());
// Pipeline is complete; create an AggregateCommandRequest for $currentOp.
- const AggregateCommandRequest request(NamespaceString::makeCollectionlessAggregateNSS("admin"),
- std::move(pipeline));
+ AggregateCommandRequest request(NamespaceString::makeCollectionlessAggregateNSS("admin"),
+ std::move(pipeline));
// Run the pipeline and obtain a CursorResponse.
auto aggResults = uassertStatusOK(runAggregation(opCtx, request));
diff --git a/src/mongo/db/commands/current_op_common.h b/src/mongo/db/commands/current_op_common.h
index 31ff95ac764..789ece31e0a 100644
--- a/src/mongo/db/commands/current_op_common.h
+++ b/src/mongo/db/commands/current_op_common.h
@@ -73,8 +73,8 @@ private:
* Runs the aggregation specified by the supplied AggregateCommandRequest, returning a
* CursorResponse if successful or a Status containing the error otherwise.
*/
- virtual StatusWith<CursorResponse> runAggregation(
- OperationContext* opCtx, const AggregateCommandRequest& request) const = 0;
+ virtual StatusWith<CursorResponse> runAggregation(OperationContext* opCtx,
+ AggregateCommandRequest& request) const = 0;
/**
* Allows overriders to optionally write additional data to the response object before the final
diff --git a/src/mongo/db/query/fle/server_rewrite.cpp b/src/mongo/db/query/fle/server_rewrite.cpp
index b83b087a04a..792aed343c8 100644
--- a/src/mongo/db/query/fle/server_rewrite.cpp
+++ b/src/mongo/db/query/fle/server_rewrite.cpp
@@ -34,6 +34,9 @@
#include "mongo/crypto/fle_crypto.h"
#include "mongo/crypto/fle_tags.h"
#include "mongo/db/fle_crud.h"
+#include "mongo/db/pipeline/document_source_geo_near.h"
+#include "mongo/db/pipeline/document_source_graph_lookup.h"
+#include "mongo/db/pipeline/document_source_match.h"
#include "mongo/db/query/collation/collator_factory_interface.h"
#include "mongo/db/transaction_api.h"
#include "mongo/s/grid.h"
@@ -60,51 +63,83 @@ boost::intrusive_ptr<ExpressionContext> makeExpCtx(OperationContext* opCtx,
findCommand->getLegacyRuntimeConstants(),
findCommand->getLet());
}
-} // namespace
-BSONObj rewriteEncryptedFilter(boost::intrusive_ptr<ExpressionContext> expCtx,
- const FLEStateCollectionReader& escReader,
- const FLEStateCollectionReader& eccReader,
- BSONObj filter) {
- return MatchExpressionRewrite(expCtx, escReader, eccReader, filter).get();
-}
+class RewriteBase {
+public:
+ RewriteBase(boost::intrusive_ptr<ExpressionContext> expCtx,
+ const NamespaceString& nss,
+ const EncryptionInformation& encryptInfo)
+ : expCtx(expCtx), db(nss.db()) {
+ auto efc = EncryptionInformationHelpers::getAndValidateSchema(nss, encryptInfo);
+ esc = efc.getEscCollection()->toString();
+ ecc = efc.getEccCollection()->toString();
+ }
+ virtual ~RewriteBase(){};
+ virtual void doRewrite(FLEStateCollectionReader& escReader,
+ FLEStateCollectionReader& eccReader){};
+
+ boost::intrusive_ptr<ExpressionContext> expCtx;
+ std::string esc;
+ std::string ecc;
+ std::string db;
+};
+
+// This class handles rewriting of an entire pipeline.
+class PipelineRewrite : public RewriteBase {
+public:
+ PipelineRewrite(const NamespaceString& nss,
+ const EncryptionInformation& encryptInfo,
+ std::unique_ptr<Pipeline, PipelineDeleter> toRewrite)
+ : RewriteBase(toRewrite->getContext(), nss, encryptInfo), pipeline(std::move(toRewrite)) {}
+
+ ~PipelineRewrite(){};
+ void doRewrite(FLEStateCollectionReader& escReader, FLEStateCollectionReader& eccReader) final {
+ for (auto&& source : pipeline->getSources()) {
+ if (auto match = dynamic_cast<DocumentSourceMatch*>(source.get())) {
+ match->rebuild(
+ rewriteEncryptedFilter(expCtx, escReader, eccReader, match->getQuery()));
+ } else if (auto geoNear = dynamic_cast<DocumentSourceGeoNear*>(source.get())) {
+ geoNear->setQuery(
+ rewriteEncryptedFilter(expCtx, escReader, eccReader, geoNear->getQuery()));
+ } else if (auto graphLookup = dynamic_cast<DocumentSourceGraphLookUp*>(source.get());
+ graphLookup && graphLookup->getAdditionalFilter()) {
+ graphLookup->setAdditionalFilter(rewriteEncryptedFilter(
+ expCtx, escReader, eccReader, graphLookup->getAdditionalFilter().get()));
+ }
+ }
+ }
-void processFindCommand(OperationContext* opCtx,
- NamespaceString nss,
- FindCommandRequest* findCommand) {
- invariant(findCommand->getEncryptionInformation());
+ std::unique_ptr<Pipeline, PipelineDeleter> getPipeline() {
+ return std::move(pipeline);
+ }
+
+private:
+ std::unique_ptr<Pipeline, PipelineDeleter> pipeline;
+};
+
+// This class handles rewriting of a single match expression, represented as a BSONObj.
+class FilterRewrite : public RewriteBase {
+public:
+ FilterRewrite(boost::intrusive_ptr<ExpressionContext> expCtx,
+ const NamespaceString& nss,
+ const EncryptionInformation& encryptInfo,
+ const BSONObj toRewrite)
+ : RewriteBase(expCtx, nss, encryptInfo), userFilter(toRewrite) {}
+
+ ~FilterRewrite(){};
+ void doRewrite(FLEStateCollectionReader& escReader, FLEStateCollectionReader& eccReader) final {
+ rewrittenFilter = rewriteEncryptedFilter(expCtx, escReader, eccReader, userFilter);
+ }
- auto efc = EncryptionInformationHelpers::getAndValidateSchema(
- nss, findCommand->getEncryptionInformation().get());
-
- // The transaction runs in a separate executor, and so we can't pass data by
- // reference into the lambda. This struct holds all the data we need inside the
- // lambda, and is passed in a more threadsafe shared_ptr.
- struct SharedBlock {
- SharedBlock(NamespaceString nss,
- std::string esc,
- std::string ecc,
- const BSONObj userFilter,
- boost::intrusive_ptr<ExpressionContext> expCtx)
- : esc(std::move(esc)),
- ecc(std::move(ecc)),
- userFilter(userFilter),
- db(nss.db()),
- expCtx(expCtx) {}
- std::string esc;
- std::string ecc;
- const BSONObj userFilter;
- BSONObj rewrittenFilter;
- std::string db;
- boost::intrusive_ptr<ExpressionContext> expCtx;
- };
-
- auto sharedBlock = std::make_shared<SharedBlock>(nss,
- efc.getEscCollection().get().toString(),
- efc.getEccCollection().get().toString(),
- findCommand->getFilter().getOwned(),
- makeExpCtx(opCtx, findCommand));
+ const BSONObj userFilter;
+ BSONObj rewrittenFilter;
+};
+// This helper executes the rewrite(s) inside a transaction. The transaction runs in a separate
+// executor, and so we can't pass data by reference into the lambda. The provided rewriter should
+// hold all the data we need to do the rewriting inside the lambda, and is passed in a more
+// threadsafe shared_ptr. The result of applying the rewrites can be accessed in the RewriteBase.
+void doFLERewriteInTxn(OperationContext* opCtx, std::shared_ptr<RewriteBase> sharedBlock) {
auto txn = std::make_shared<txn_api::TransactionWithRetries>(
opCtx,
Grid::get(opCtx)->getExecutorPool()->getFixedExecutor(),
@@ -125,8 +160,7 @@ void processFindCommand(OperationContext* opCtx,
auto eccReader = makeCollectionReader(&queryInterface, sharedBlock->ecc);
// Rewrite the MatchExpression.
- sharedBlock->rewrittenFilter = rewriteEncryptedFilter(
- sharedBlock->expCtx, escReader, eccReader, sharedBlock->userFilter);
+ sharedBlock->doRewrite(escReader, eccReader);
return SemiFuture<void>::makeReady();
});
@@ -134,17 +168,43 @@ void processFindCommand(OperationContext* opCtx,
uassertStatusOK(swCommitResult);
uassertStatusOK(swCommitResult.getValue().cmdStatus);
uassertStatusOK(swCommitResult.getValue().getEffectiveStatus());
+}
+} // namespace
+
+BSONObj rewriteEncryptedFilter(boost::intrusive_ptr<ExpressionContext> expCtx,
+ const FLEStateCollectionReader& escReader,
+ const FLEStateCollectionReader& eccReader,
+ BSONObj filter) {
+ return MatchExpressionRewrite(expCtx, escReader, eccReader, filter).get();
+}
+
+void processFindCommand(OperationContext* opCtx,
+ NamespaceString nss,
+ FindCommandRequest* findCommand) {
+ invariant(findCommand->getEncryptionInformation());
+
+ auto sharedBlock =
+ std::make_shared<FilterRewrite>(makeExpCtx(opCtx, findCommand),
+ nss,
+ findCommand->getEncryptionInformation().get(),
+ findCommand->getFilter().getOwned());
+ doFLERewriteInTxn(opCtx, sharedBlock);
auto rewrittenFilter = sharedBlock->rewrittenFilter.getOwned();
findCommand->setFilter(std::move(rewrittenFilter));
findCommand->setEncryptionInformation(boost::none);
+}
- // If we are in a multi-document transaction, then the transaction API has taken
- // care of setting the readConcern on the transaction, and the find command
- // shouldn't provide its own readConcern.
- if (opCtx->inMultiDocumentTransaction()) {
- findCommand->setReadConcern(boost::none);
- }
+std::unique_ptr<Pipeline, PipelineDeleter> processPipeline(
+ OperationContext* opCtx,
+ NamespaceString nss,
+ const EncryptionInformation& encryptInfo,
+ std::unique_ptr<Pipeline, PipelineDeleter> toRewrite) {
+
+ auto sharedBlock = std::make_shared<PipelineRewrite>(nss, encryptInfo, std::move(toRewrite));
+ doFLERewriteInTxn(opCtx, sharedBlock);
+
+ return sharedBlock->getPipeline();
}
std::unique_ptr<MatchExpression> MatchExpressionRewrite::_rewriteMatchExpression(
diff --git a/src/mongo/db/query/fle/server_rewrite.h b/src/mongo/db/query/fle/server_rewrite.h
index d914d60e01d..3ea8fb5592b 100644
--- a/src/mongo/db/query/fle/server_rewrite.h
+++ b/src/mongo/db/query/fle/server_rewrite.h
@@ -58,6 +58,17 @@ void processFindCommand(OperationContext* opCtx,
FindCommandRequest* findCommand);
/**
+ * Process a pipeline with encryptionInformation by rewriting the pipeline to query against the
+ * underlying tags array, where appropriate. After this rewriting is complete, there is no more FLE
+ * work to be done. The encryption info does not need to be kept around (e.g. on a command object).
+ */
+std::unique_ptr<Pipeline, PipelineDeleter> processPipeline(
+ OperationContext* opCtx,
+ NamespaceString nss,
+ const EncryptionInformation& encryptInfo,
+ std::unique_ptr<Pipeline, PipelineDeleter> toRewrite);
+
+/**
* Class which handles rewriting filter MatchExpressions for FLE2. The functionality is encapsulated
* as a class rather than just a namespace so that the collection readers don't have to be passed
* around as extra arguments to every function.
diff --git a/src/mongo/s/commands/cluster_current_op.cpp b/src/mongo/s/commands/cluster_current_op.cpp
index 3c0e73f30d3..5f4f2187b49 100644
--- a/src/mongo/s/commands/cluster_current_op.cpp
+++ b/src/mongo/s/commands/cluster_current_op.cpp
@@ -71,7 +71,7 @@ private:
}
virtual StatusWith<CursorResponse> runAggregation(
- OperationContext* opCtx, const AggregateCommandRequest& request) const final {
+ OperationContext* opCtx, AggregateCommandRequest& request) const final {
auto nss = request.getNamespace();
BSONObjBuilder responseBuilder;
diff --git a/src/mongo/s/commands/cluster_pipeline_cmd.cpp b/src/mongo/s/commands/cluster_pipeline_cmd.cpp
index 87dbb0e05c1..aec0641e905 100644
--- a/src/mongo/s/commands/cluster_pipeline_cmd.cpp
+++ b/src/mongo/s/commands/cluster_pipeline_cmd.cpp
@@ -94,7 +94,7 @@ public:
public:
Invocation(Command* cmd,
const OpMsgRequest& request,
- const AggregateCommandRequest aggregationRequest,
+ AggregateCommandRequest aggregationRequest,
PrivilegeVector privileges)
: CommandInvocation(cmd),
_request(request),
@@ -171,7 +171,7 @@ public:
const OpMsgRequest& _request;
const std::string _dbName;
- const AggregateCommandRequest _aggregationRequest;
+ AggregateCommandRequest _aggregationRequest;
const LiteParsedPipeline _liteParsedPipeline;
const PrivilegeVector _privileges;
};
diff --git a/src/mongo/s/query/SConscript b/src/mongo/s/query/SConscript
index 1ebf88e5164..f6f1ac53b05 100644
--- a/src/mongo/s/query/SConscript
+++ b/src/mongo/s/query/SConscript
@@ -33,6 +33,7 @@ env.Library(
'cluster_aggregation_planner.cpp',
],
LIBDEPS=[
+ '$BUILD_DIR/mongo/db/fle_crud',
'$BUILD_DIR/mongo/db/pipeline/pipeline',
'$BUILD_DIR/mongo/db/pipeline/process_interface/mongos_process_interface',
'$BUILD_DIR/mongo/db/pipeline/sharded_agg_helpers',
diff --git a/src/mongo/s/query/cluster_aggregate.cpp b/src/mongo/s/query/cluster_aggregate.cpp
index 966ca7c970c..3c754850628 100644
--- a/src/mongo/s/query/cluster_aggregate.cpp
+++ b/src/mongo/s/query/cluster_aggregate.cpp
@@ -53,6 +53,7 @@
#include "mongo/db/query/cursor_response.h"
#include "mongo/db/query/explain_common.h"
#include "mongo/db/query/find_common.h"
+#include "mongo/db/query/fle/server_rewrite.h"
#include "mongo/db/timeseries/timeseries_options.h"
#include "mongo/db/views/resolved_view.h"
#include "mongo/db/views/view.h"
@@ -251,7 +252,7 @@ std::vector<BSONObj> rebuildPipelineWithTimeSeriesGranularity(const std::vector<
Status ClusterAggregate::runAggregate(OperationContext* opCtx,
const Namespaces& namespaces,
- const AggregateCommandRequest& request,
+ AggregateCommandRequest& request,
const PrivilegeVector& privileges,
BSONObjBuilder* result) {
return runAggregate(opCtx, namespaces, request, {request}, privileges, result);
@@ -259,7 +260,7 @@ Status ClusterAggregate::runAggregate(OperationContext* opCtx,
Status ClusterAggregate::runAggregate(OperationContext* opCtx,
const Namespaces& namespaces,
- const AggregateCommandRequest& request,
+ AggregateCommandRequest& request,
const LiteParsedPipeline& liteParsedPipeline,
const PrivilegeVector& privileges,
BSONObjBuilder* result) {
@@ -269,7 +270,7 @@ Status ClusterAggregate::runAggregate(OperationContext* opCtx,
Status ClusterAggregate::runAggregate(OperationContext* opCtx,
const Namespaces& namespaces,
- const AggregateCommandRequest& request,
+ AggregateCommandRequest& request,
const LiteParsedPipeline& liteParsedPipeline,
const PrivilegeVector& privileges,
boost::optional<ChunkManager> cm,
@@ -298,6 +299,7 @@ Status ClusterAggregate::runAggregate(OperationContext* opCtx,
opCtx, isSharded, request.getExplain(), serverGlobalParams.enableMajorityReadConcern);
auto hasChangeStream = liteParsedPipeline.hasChangeStream();
auto involvedNamespaces = liteParsedPipeline.getInvolvedNamespaces();
+ auto shouldDoFLERewrite = fle::shouldRewrite(&request);
uassert(6256300,
str::stream() << "On mongos, " << AggregateCommandRequest::kCollectionUUIDFieldName
@@ -363,10 +365,26 @@ Status ClusterAggregate::runAggregate(OperationContext* opCtx,
// Parse and optimize the full pipeline.
auto pipeline = Pipeline::parse(request.getPipeline(), expCtx);
+
+ // If the aggregate command supports encrypted collections, do rewrites of the pipeline to
+ // support querying against encrypted fields.
+ if (shouldDoFLERewrite) {
+ // After this rewriting, the encryption info does not need to be kept around.
+ pipeline = fle::processPipeline(opCtx,
+ namespaces.executionNss,
+ request.getEncryptionInformation().get(),
+ std::move(pipeline));
+ request.setEncryptionInformation(boost::none);
+ }
+
pipeline->optimizePipeline();
return pipeline;
};
+ // The pipeline is not allowed to passthrough if any stage is not allowed to passthrough or if
+ // the pipeline needs to undergo FLE rewriting first.
+ auto allowedToPassthrough =
+ liteParsedPipeline.allowedToPassthroughFromMongos() && !shouldDoFLERewrite;
auto targeter = cluster_aggregation_planner::AggregationTargeter::make(
opCtx,
namespaces.executionNss,
@@ -374,7 +392,7 @@ Status ClusterAggregate::runAggregate(OperationContext* opCtx,
cm,
involvedNamespaces,
hasChangeStream,
- liteParsedPipeline.allowedToPassthroughFromMongos(),
+ allowedToPassthrough,
request.getPassthroughToShard().has_value());
if (!expCtx) {
diff --git a/src/mongo/s/query/cluster_aggregate.h b/src/mongo/s/query/cluster_aggregate.h
index f16beabb648..5b3530d0860 100644
--- a/src/mongo/s/query/cluster_aggregate.h
+++ b/src/mongo/s/query/cluster_aggregate.h
@@ -83,7 +83,7 @@ public:
*/
static Status runAggregate(OperationContext* opCtx,
const Namespaces& namespaces,
- const AggregateCommandRequest& request,
+ AggregateCommandRequest& request,
const LiteParsedPipeline& liteParsedPipeline,
const PrivilegeVector& privileges,
boost::optional<ChunkManager> cm,
@@ -94,7 +94,7 @@ public:
*/
static Status runAggregate(OperationContext* opCtx,
const Namespaces& namespaces,
- const AggregateCommandRequest& request,
+ AggregateCommandRequest& request,
const LiteParsedPipeline& liteParsedPipeline,
const PrivilegeVector& privileges,
BSONObjBuilder* result);
@@ -105,7 +105,7 @@ public:
*/
static Status runAggregate(OperationContext* opCtx,
const Namespaces& namespaces,
- const AggregateCommandRequest& request,
+ AggregateCommandRequest& request,
const PrivilegeVector& privileges,
BSONObjBuilder* result);
diff --git a/src/mongo/s/transaction_router.cpp b/src/mongo/s/transaction_router.cpp
index d2e8a55b3a8..e2739a36778 100644
--- a/src/mongo/s/transaction_router.cpp
+++ b/src/mongo/s/transaction_router.cpp
@@ -442,11 +442,13 @@ BSONObj TransactionRouter::Participant::attachTxnFieldsIfNeeded(
auto cmdName = cmd.firstElement().fieldNameStringData();
bool mustStartTransaction = isFirstStatementInThisParticipant && !isTransactionCommand(cmdName);
+ // Strip the command of its read concern if it should not have one.
if (!mustStartTransaction) {
auto readConcernFieldName = repl::ReadConcernArgs::kReadConcernFieldName;
- dassert(!cmd.hasField(readConcernFieldName) ||
- cmd.getObjectField(readConcernFieldName).isEmpty() ||
- sharedOptions.isInternalTransactionForRetryableWrite);
+ if (cmd.hasField(readConcernFieldName) &&
+ !sharedOptions.isInternalTransactionForRetryableWrite) {
+ cmd = cmd.removeField(readConcernFieldName);
+ }
}
BSONObjBuilder newCmd = mustStartTransaction