summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/mongo/db/commands/write_commands.cpp16
-rw-r--r--src/mongo/db/fle_crud.cpp40
-rw-r--r--src/mongo/db/fle_crud.h41
-rw-r--r--src/mongo/db/fle_crud_mongod.cpp12
-rw-r--r--src/mongo/db/query/fle/server_rewrite.cpp110
-rw-r--r--src/mongo/db/query/fle/server_rewrite.h18
-rw-r--r--src/mongo/s/commands/SConscript1
-rw-r--r--src/mongo/s/commands/cluster_write_cmd.cpp22
8 files changed, 190 insertions, 70 deletions
diff --git a/src/mongo/db/commands/write_commands.cpp b/src/mongo/db/commands/write_commands.cpp
index 2a6279e6388..93e9cea9c64 100644
--- a/src/mongo/db/commands/write_commands.cpp
+++ b/src/mongo/db/commands/write_commands.cpp
@@ -1537,6 +1537,13 @@ public:
UpdateRequest updateRequest(request().getUpdates()[0]);
updateRequest.setNamespaceString(request().getNamespace());
+ if (shouldDoFLERewrite(request())) {
+ updateRequest.setQuery(
+ processFLEWriteExplainD(opCtx,
+ write_ops::collationOf(request().getUpdates()[0]),
+ request(),
+ updateRequest.getQuery()));
+ }
updateRequest.setLegacyRuntimeConstants(request().getLegacyRuntimeConstants().value_or(
Variables::generateRuntimeConstants(opCtx)));
updateRequest.setLetParameters(request().getLet());
@@ -1677,7 +1684,14 @@ public:
deleteRequest.setLegacyRuntimeConstants(request().getLegacyRuntimeConstants().value_or(
Variables::generateRuntimeConstants(opCtx)));
deleteRequest.setLet(request().getLet());
- deleteRequest.setQuery(request().getDeletes()[0].getQ());
+
+ BSONObj query = request().getDeletes()[0].getQ();
+ if (shouldDoFLERewrite(request())) {
+ query = processFLEWriteExplainD(
+ opCtx, write_ops::collationOf(request().getDeletes()[0]), request(), query);
+ }
+ deleteRequest.setQuery(std::move(query));
+
deleteRequest.setCollation(write_ops::collationOf(request().getDeletes()[0]));
deleteRequest.setMulti(request().getDeletes()[0].getMulti());
deleteRequest.setYieldPolicy(PlanYieldPolicy::YieldPolicy::YIELD_AUTO);
diff --git a/src/mongo/db/fle_crud.cpp b/src/mongo/db/fle_crud.cpp
index 6d606d93d2a..990a1c553a2 100644
--- a/src/mongo/db/fle_crud.cpp
+++ b/src/mongo/db/fle_crud.cpp
@@ -855,6 +855,46 @@ FLEBatchResult processFLEBatch(OperationContext* opCtx,
MONGO_UNREACHABLE;
}
+std::unique_ptr<BatchedCommandRequest> processFLEBatchExplain(
+ OperationContext* opCtx, const BatchedCommandRequest& request) {
+ invariant(request.hasEncryptionInformation());
+ auto getExpCtx = [&](const auto& op) {
+ auto expCtx = make_intrusive<ExpressionContext>(
+ opCtx,
+ fle::collatorFromBSON(opCtx, op.getCollation().value_or(BSONObj())),
+ request.getNS(),
+ request.getLegacyRuntimeConstants(),
+ request.getLet());
+ expCtx->stopExpressionCounters();
+ return expCtx;
+ };
+
+ if (request.getBatchType() == BatchedCommandRequest::BatchType_Delete) {
+ auto deleteRequest = request.getDeleteRequest();
+ auto newDeleteOp = deleteRequest.getDeletes()[0];
+ newDeleteOp.setQ(fle::rewriteQuery(opCtx,
+ getExpCtx(newDeleteOp),
+ request.getNS(),
+ deleteRequest.getEncryptionInformation().get(),
+ newDeleteOp.getQ(),
+ &getTransactionWithRetriesForMongoS));
+ deleteRequest.setDeletes({newDeleteOp});
+ return std::make_unique<BatchedCommandRequest>(deleteRequest);
+ } else if (request.getBatchType() == BatchedCommandRequest::BatchType_Update) {
+ auto updateRequest = request.getUpdateRequest();
+ auto newUpdateOp = updateRequest.getUpdates()[0];
+ newUpdateOp.setQ(fle::rewriteQuery(opCtx,
+ getExpCtx(newUpdateOp),
+ request.getNS(),
+ updateRequest.getEncryptionInformation().get(),
+ newUpdateOp.getQ(),
+ &getTransactionWithRetriesForMongoS));
+ updateRequest.setUpdates({newUpdateOp});
+ return std::make_unique<BatchedCommandRequest>(updateRequest);
+ }
+ MONGO_UNREACHABLE;
+}
+
// See processUpdate for algorithm overview
write_ops::FindAndModifyCommandReply processFindAndModify(
boost::intrusive_ptr<ExpressionContext> expCtx,
diff --git a/src/mongo/db/fle_crud.h b/src/mongo/db/fle_crud.h
index 1b36850c213..13e53f3406d 100644
--- a/src/mongo/db/fle_crud.h
+++ b/src/mongo/db/fle_crud.h
@@ -29,9 +29,10 @@
#pragma once
-#include "boost/smart_ptr/intrusive_ptr.hpp"
#include <cstdint>
+#include "boost/smart_ptr/intrusive_ptr.hpp"
+
#include "mongo/bson/bsonobj.h"
#include "mongo/bson/oid.h"
#include "mongo/crypto/fle_crypto.h"
@@ -72,6 +73,12 @@ FLEBatchResult processFLEBatch(OperationContext* opCtx,
BatchedCommandResponse* response,
boost::optional<OID> targetEpoch);
+/**
+ * Rewrite a BatchedCommandRequest for explain commands.
+ */
+std::unique_ptr<BatchedCommandRequest> processFLEBatchExplain(OperationContext* opCtx,
+ const BatchedCommandRequest& request);
+
/**
* Initialize the FLE CRUD subsystem on Mongod.
@@ -98,6 +105,38 @@ write_ops::DeleteCommandReply processFLEDelete(
OperationContext* opCtx, const write_ops::DeleteCommandRequest& deleteRequest);
/**
+ * Rewrite the query within a replica set explain command for delete and update.
+ * This concrete function is passed all the parameters directly.
+ */
+BSONObj processFLEWriteExplainD(OperationContext* opCtx,
+ const BSONObj& collation,
+ const NamespaceString& nss,
+ const EncryptionInformation& info,
+ const boost::optional<LegacyRuntimeConstants>& runtimeConstants,
+ const boost::optional<BSONObj>& letParameters,
+ const BSONObj& query);
+
+/**
+ * Rewrite the query within a replica set explain command for delete and update.
+ * This template is passed the request object from the command and delegates
+ * to the function above.
+ */
+template <typename T>
+BSONObj processFLEWriteExplainD(OperationContext* opCtx,
+ const BSONObj& collation,
+ const T& request,
+ const BSONObj& query) {
+
+ return processFLEWriteExplainD(opCtx,
+ collation,
+ request.getNamespace(),
+ request.getEncryptionInformation().get(),
+ request.getLegacyRuntimeConstants(),
+ request.getLet(),
+ query);
+}
+
+/**
* Process a replica set update.
*/
write_ops::UpdateCommandReply processFLEUpdate(
diff --git a/src/mongo/db/fle_crud_mongod.cpp b/src/mongo/db/fle_crud_mongod.cpp
index cefa6a9e0bf..7e8eea293fa 100644
--- a/src/mongo/db/fle_crud_mongod.cpp
+++ b/src/mongo/db/fle_crud_mongod.cpp
@@ -253,4 +253,16 @@ std::unique_ptr<Pipeline, PipelineDeleter> processFLEPipelineD(
return fle::processPipeline(
opCtx, nss, encryptInfo, std::move(toRewrite), &getTransactionWithRetriesForMongoD);
}
+
+BSONObj processFLEWriteExplainD(OperationContext* opCtx,
+ const BSONObj& collation,
+ const NamespaceString& nss,
+ const EncryptionInformation& info,
+ const boost::optional<LegacyRuntimeConstants>& runtimeConstants,
+ const boost::optional<BSONObj>& letParameters,
+ const BSONObj& query) {
+ auto expCtx = make_intrusive<ExpressionContext>(
+ opCtx, fle::collatorFromBSON(opCtx, collation), nss, runtimeConstants, letParameters);
+ return fle::rewriteQuery(opCtx, expCtx, nss, info, query, &getTransactionWithRetriesForMongoD);
+}
} // namespace mongo
diff --git a/src/mongo/db/query/fle/server_rewrite.cpp b/src/mongo/db/query/fle/server_rewrite.cpp
index d08b9173884..876b0c6823d 100644
--- a/src/mongo/db/query/fle/server_rewrite.cpp
+++ b/src/mongo/db/query/fle/server_rewrite.cpp
@@ -32,12 +32,15 @@
#include <memory>
+#include "mongo/bson/bsonobj.h"
#include "mongo/bson/bsonobjbuilder.h"
#include "mongo/bson/bsontypes.h"
#include "mongo/crypto/encryption_fields_gen.h"
#include "mongo/crypto/fle_crypto.h"
+#include "mongo/crypto/fle_field_schema_gen.h"
#include "mongo/crypto/fle_tags.h"
#include "mongo/db/fle_crud.h"
+#include "mongo/db/operation_context.h"
#include "mongo/db/pipeline/document_source_geo_near.h"
#include "mongo/db/pipeline/document_source_graph_lookup.h"
#include "mongo/db/pipeline/document_source_match.h"
@@ -47,63 +50,27 @@
#include "mongo/util/assert_util.h"
namespace mongo::fle {
-namespace {
-BSONObj rewriteEncryptedFilter(const FLEStateCollectionReader& escReader,
- const FLEStateCollectionReader& eccReader,
- boost::intrusive_ptr<ExpressionContext> expCtx,
- BSONObj filter) {
- return MatchExpressionRewrite(expCtx, escReader, eccReader, filter).get();
-}
-
-/**
- * Make an expression context from a find command.
- */
-boost::intrusive_ptr<ExpressionContext> makeExpCtx(OperationContext* opCtx,
- const NamespaceString& nss,
- FindCommandRequest* findCommand) {
+// TODO: This is a generally useful helper function that should probably go in some other namespace.
+std::unique_ptr<CollatorInterface> collatorFromBSON(OperationContext* opCtx,
+ const BSONObj& collation) {
std::unique_ptr<CollatorInterface> collator;
- if (!findCommand->getCollation().isEmpty()) {
- auto statusWithCollator = CollatorFactoryInterface::get(opCtx->getServiceContext())
- ->makeFromBSON(findCommand->getCollation());
-
+ if (!collation.isEmpty()) {
+ auto statusWithCollator =
+ CollatorFactoryInterface::get(opCtx->getServiceContext())->makeFromBSON(collation);
uassertStatusOK(statusWithCollator.getStatus());
collator = std::move(statusWithCollator.getValue());
}
- auto expCtx = make_intrusive<ExpressionContext>(opCtx,
- std::move(collator),
- nss,
- findCommand->getLegacyRuntimeConstants(),
- findCommand->getLet());
- expCtx->stopExpressionCounters();
- return expCtx;
+ return collator;
}
-
-boost::intrusive_ptr<ExpressionContext> makeExpCtx(OperationContext* opCtx,
- const NamespaceString& nss,
- CountCommandRequest* countCommand) {
-
- std::unique_ptr<CollatorInterface> collator;
- if (countCommand->getCollation()) {
- auto statusWithCollator = CollatorFactoryInterface::get(opCtx->getServiceContext())
- ->makeFromBSON(countCommand->getCollation().get());
-
- uassertStatusOK(statusWithCollator.getStatus());
- collator = std::move(statusWithCollator.getValue());
- }
- auto expCtx = make_intrusive<ExpressionContext>(
- opCtx,
- std::move(collator),
- nss,
- // Count command does not have legacy runtime constants, and does not support user variables
- // defined in a let expression.
- boost::none,
- boost::none);
- expCtx->stopExpressionCounters();
- return expCtx;
+namespace {
+BSONObj rewriteEncryptedFilter(const FLEStateCollectionReader& escReader,
+ const FLEStateCollectionReader& eccReader,
+ boost::intrusive_ptr<ExpressionContext> expCtx,
+ BSONObj filter) {
+ return MatchExpressionRewrite(expCtx, escReader, eccReader, filter).get();
}
-
class RewriteBase {
public:
RewriteBase(boost::intrusive_ptr<ExpressionContext> expCtx,
@@ -225,6 +192,17 @@ BSONObj rewriteEncryptedFilterInsideTxn(FLEQueryInterface* queryImpl,
return rewriteEncryptedFilter(escReader, eccReader, expCtx, filter);
}
+BSONObj rewriteQuery(OperationContext* opCtx,
+ boost::intrusive_ptr<ExpressionContext> expCtx,
+ const NamespaceString& nss,
+ const EncryptionInformation& info,
+ BSONObj filter,
+ GetTxnCallback getTransaction) {
+ auto sharedBlock = std::make_shared<FilterRewrite>(expCtx, nss, info, filter);
+ doFLERewriteInTxn(opCtx, sharedBlock, getTransaction);
+ return sharedBlock->rewrittenFilter.getOwned();
+}
+
void processFindCommand(OperationContext* opCtx,
const NamespaceString& nss,
@@ -232,15 +210,19 @@ void processFindCommand(OperationContext* opCtx,
GetTxnCallback getTransaction) {
invariant(findCommand->getEncryptionInformation());
- auto sharedBlock =
- std::make_shared<FilterRewrite>(makeExpCtx(opCtx, nss, findCommand),
+ auto expCtx =
+ make_intrusive<ExpressionContext>(opCtx,
+ collatorFromBSON(opCtx, findCommand->getCollation()),
+ nss,
+ findCommand->getLegacyRuntimeConstants(),
+ findCommand->getLet());
+ expCtx->stopExpressionCounters();
+ findCommand->setFilter(rewriteQuery(opCtx,
+ expCtx,
nss,
findCommand->getEncryptionInformation().get(),
- findCommand->getFilter().getOwned());
- doFLERewriteInTxn(opCtx, sharedBlock, getTransaction);
-
- auto rewrittenFilter = sharedBlock->rewrittenFilter.getOwned();
- findCommand->setFilter(std::move(rewrittenFilter));
+ findCommand->getFilter().getOwned(),
+ getTransaction));
// The presence of encryptionInformation is a signal that this is a FLE request that requires
// special processing. Once we've rewritten the query, it's no longer a "special" FLE query, but
// a normal query that can be executed by the query system like any other, so remove
@@ -253,16 +235,18 @@ void processCountCommand(OperationContext* opCtx,
CountCommandRequest* countCommand,
GetTxnCallback getTxn) {
invariant(countCommand->getEncryptionInformation());
+ // Count command does not have legacy runtime constants, and does not support user variables
+ // defined in a let expression.
+ auto expCtx = make_intrusive<ExpressionContext>(
+ opCtx, collatorFromBSON(opCtx, countCommand->getCollation().value_or(BSONObj())), nss);
+ expCtx->stopExpressionCounters();
- auto sharedBlock =
- std::make_shared<FilterRewrite>(makeExpCtx(opCtx, nss, countCommand),
+ countCommand->setQuery(rewriteQuery(opCtx,
+ expCtx,
nss,
countCommand->getEncryptionInformation().get(),
- countCommand->getQuery().getOwned());
- doFLERewriteInTxn(opCtx, sharedBlock, getTxn);
-
- auto rewrittenFilter = sharedBlock->rewrittenFilter.getOwned();
- countCommand->setQuery(std::move(rewrittenFilter));
+ countCommand->getQuery().getOwned(),
+ getTxn));
// The presence of encryptionInformation is a signal that this is a FLE request that requires
// special processing. Once we've rewritten the query, it's no longer a "special" FLE query, but
// a normal query that can be executed by the query system like any other, so remove
diff --git a/src/mongo/db/query/fle/server_rewrite.h b/src/mongo/db/query/fle/server_rewrite.h
index e85c6911ea4..c653932fe5c 100644
--- a/src/mongo/db/query/fle/server_rewrite.h
+++ b/src/mongo/db/query/fle/server_rewrite.h
@@ -47,6 +47,24 @@ class FLEQueryInterface;
namespace fle {
/**
+ * Make a collator object from its BSON representation. Useful when creating ExpressionContext
+ * objects for parsing MatchExpressions as part of the server-side rewrite.
+ */
+std::unique_ptr<CollatorInterface> collatorFromBSON(OperationContext* opCtx,
+ const BSONObj& collation);
+
+/**
+ * Return a rewritten version of the passed-in filter as a BSONObj. Generally used by other
+ * functions to process MatchExpressions in each command.
+ */
+BSONObj rewriteQuery(OperationContext* opCtx,
+ boost::intrusive_ptr<ExpressionContext> expCtx,
+ const NamespaceString& nss,
+ const EncryptionInformation& info,
+ BSONObj filter,
+ GetTxnCallback getTransaction);
+
+/**
* Process a find command with encryptionInformation in-place, rewriting the filter condition so
* that any query on an encrypted field will properly query the underlying tags array.
*/
diff --git a/src/mongo/s/commands/SConscript b/src/mongo/s/commands/SConscript
index 0a1796119aa..5cdade782f9 100644
--- a/src/mongo/s/commands/SConscript
+++ b/src/mongo/s/commands/SConscript
@@ -172,6 +172,7 @@ env.Library(
],
LIBDEPS_PRIVATE=[
'$BUILD_DIR/mongo/db/commands/core',
+ '$BUILD_DIR/mongo/db/fle_crud',
'$BUILD_DIR/mongo/db/initialize_api_parameters',
'$BUILD_DIR/mongo/db/internal_transactions_feature_flag',
'$BUILD_DIR/mongo/db/read_write_concern_defaults',
diff --git a/src/mongo/s/commands/cluster_write_cmd.cpp b/src/mongo/s/commands/cluster_write_cmd.cpp
index 47eb18b4b0e..3810f0a4b3d 100644
--- a/src/mongo/s/commands/cluster_write_cmd.cpp
+++ b/src/mongo/s/commands/cluster_write_cmd.cpp
@@ -35,6 +35,7 @@
#include "mongo/client/remote_command_targeter.h"
#include "mongo/db/catalog/document_validation.h"
#include "mongo/db/curop.h"
+#include "mongo/db/fle_crud.h"
#include "mongo/db/internal_transactions_feature_flag_gen.h"
#include "mongo/db/pipeline/lite_parsed_pipeline.h"
#include "mongo/db/stats/counters.h"
@@ -644,22 +645,33 @@ void ClusterWriteCmd::InvocationBase::explain(OperationContext* opCtx,
"explained write batches must be of size 1",
_batchedRequest.sizeWriteOps() == 1U);
- const auto explainCmd = ClusterExplain::wrapAsExplain(_request->body, verbosity);
+
+ std::unique_ptr<BatchedCommandRequest> req;
+ if (_batchedRequest.hasEncryptionInformation() &&
+ (_batchedRequest.getBatchType() == BatchedCommandRequest::BatchType_Delete ||
+ _batchedRequest.getBatchType() == BatchedCommandRequest::BatchType_Update)) {
+ req = processFLEBatchExplain(opCtx, _batchedRequest);
+ }
+
+ auto nss = req ? req->getNS() : _batchedRequest.getNS();
+ auto requestBSON = req ? req->toBSON() : _request->body;
+ auto requestPtr = req ? req.get() : &_batchedRequest;
+
+ const auto explainCmd = ClusterExplain::wrapAsExplain(requestBSON, verbosity);
// We will time how long it takes to run the commands on the shards.
Timer timer;
// Target the command to the shards based on the singleton batch item.
- BatchItemRef targetingBatchItem(&_batchedRequest, 0);
+ BatchItemRef targetingBatchItem(requestPtr, 0);
std::vector<AsyncRequestsSender::Response> shardResponses;
- _commandOpWrite(
- opCtx, _batchedRequest.getNS(), explainCmd, targetingBatchItem, &shardResponses);
+ _commandOpWrite(opCtx, nss, explainCmd, targetingBatchItem, &shardResponses);
auto bodyBuilder = result->getBodyBuilder();
uassertStatusOK(ClusterExplain::buildExplainResult(opCtx,
shardResponses,
ClusterExplain::kWriteOnShards,
timer.millis(),
- _request->body,
+ requestBSON,
&bodyBuilder));
}