diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/mongo/db/commands/write_commands.cpp | 16 | ||||
-rw-r--r-- | src/mongo/db/fle_crud.cpp | 40 | ||||
-rw-r--r-- | src/mongo/db/fle_crud.h | 41 | ||||
-rw-r--r-- | src/mongo/db/fle_crud_mongod.cpp | 12 | ||||
-rw-r--r-- | src/mongo/db/query/fle/server_rewrite.cpp | 110 | ||||
-rw-r--r-- | src/mongo/db/query/fle/server_rewrite.h | 18 | ||||
-rw-r--r-- | src/mongo/s/commands/SConscript | 1 | ||||
-rw-r--r-- | src/mongo/s/commands/cluster_write_cmd.cpp | 22 |
8 files changed, 190 insertions, 70 deletions
diff --git a/src/mongo/db/commands/write_commands.cpp b/src/mongo/db/commands/write_commands.cpp index 2a6279e6388..93e9cea9c64 100644 --- a/src/mongo/db/commands/write_commands.cpp +++ b/src/mongo/db/commands/write_commands.cpp @@ -1537,6 +1537,13 @@ public: UpdateRequest updateRequest(request().getUpdates()[0]); updateRequest.setNamespaceString(request().getNamespace()); + if (shouldDoFLERewrite(request())) { + updateRequest.setQuery( + processFLEWriteExplainD(opCtx, + write_ops::collationOf(request().getUpdates()[0]), + request(), + updateRequest.getQuery())); + } updateRequest.setLegacyRuntimeConstants(request().getLegacyRuntimeConstants().value_or( Variables::generateRuntimeConstants(opCtx))); updateRequest.setLetParameters(request().getLet()); @@ -1677,7 +1684,14 @@ public: deleteRequest.setLegacyRuntimeConstants(request().getLegacyRuntimeConstants().value_or( Variables::generateRuntimeConstants(opCtx))); deleteRequest.setLet(request().getLet()); - deleteRequest.setQuery(request().getDeletes()[0].getQ()); + + BSONObj query = request().getDeletes()[0].getQ(); + if (shouldDoFLERewrite(request())) { + query = processFLEWriteExplainD( + opCtx, write_ops::collationOf(request().getDeletes()[0]), request(), query); + } + deleteRequest.setQuery(std::move(query)); + deleteRequest.setCollation(write_ops::collationOf(request().getDeletes()[0])); deleteRequest.setMulti(request().getDeletes()[0].getMulti()); deleteRequest.setYieldPolicy(PlanYieldPolicy::YieldPolicy::YIELD_AUTO); diff --git a/src/mongo/db/fle_crud.cpp b/src/mongo/db/fle_crud.cpp index 6d606d93d2a..990a1c553a2 100644 --- a/src/mongo/db/fle_crud.cpp +++ b/src/mongo/db/fle_crud.cpp @@ -855,6 +855,46 @@ FLEBatchResult processFLEBatch(OperationContext* opCtx, MONGO_UNREACHABLE; } +std::unique_ptr<BatchedCommandRequest> processFLEBatchExplain( + OperationContext* opCtx, const BatchedCommandRequest& request) { + invariant(request.hasEncryptionInformation()); + auto getExpCtx = [&](const auto& op) { + auto expCtx = make_intrusive<ExpressionContext>( + opCtx, + fle::collatorFromBSON(opCtx, op.getCollation().value_or(BSONObj())), + request.getNS(), + request.getLegacyRuntimeConstants(), + request.getLet()); + expCtx->stopExpressionCounters(); + return expCtx; + }; + + if (request.getBatchType() == BatchedCommandRequest::BatchType_Delete) { + auto deleteRequest = request.getDeleteRequest(); + auto newDeleteOp = deleteRequest.getDeletes()[0]; + newDeleteOp.setQ(fle::rewriteQuery(opCtx, + getExpCtx(newDeleteOp), + request.getNS(), + deleteRequest.getEncryptionInformation().get(), + newDeleteOp.getQ(), + &getTransactionWithRetriesForMongoS)); + deleteRequest.setDeletes({newDeleteOp}); + return std::make_unique<BatchedCommandRequest>(deleteRequest); + } else if (request.getBatchType() == BatchedCommandRequest::BatchType_Update) { + auto updateRequest = request.getUpdateRequest(); + auto newUpdateOp = updateRequest.getUpdates()[0]; + newUpdateOp.setQ(fle::rewriteQuery(opCtx, + getExpCtx(newUpdateOp), + request.getNS(), + updateRequest.getEncryptionInformation().get(), + newUpdateOp.getQ(), + &getTransactionWithRetriesForMongoS)); + updateRequest.setUpdates({newUpdateOp}); + return std::make_unique<BatchedCommandRequest>(updateRequest); + } + MONGO_UNREACHABLE; +} + // See processUpdate for algorithm overview write_ops::FindAndModifyCommandReply processFindAndModify( boost::intrusive_ptr<ExpressionContext> expCtx, diff --git a/src/mongo/db/fle_crud.h b/src/mongo/db/fle_crud.h index 1b36850c213..13e53f3406d 100644 --- a/src/mongo/db/fle_crud.h +++ b/src/mongo/db/fle_crud.h @@ -29,9 +29,10 @@ #pragma once -#include "boost/smart_ptr/intrusive_ptr.hpp" #include <cstdint> +#include "boost/smart_ptr/intrusive_ptr.hpp" + #include "mongo/bson/bsonobj.h" #include "mongo/bson/oid.h" #include "mongo/crypto/fle_crypto.h" @@ -72,6 +73,12 @@ FLEBatchResult processFLEBatch(OperationContext* opCtx, BatchedCommandResponse* response, boost::optional<OID> targetEpoch); +/** + * Rewrite a BatchedCommandRequest for explain commands. + */ +std::unique_ptr<BatchedCommandRequest> processFLEBatchExplain(OperationContext* opCtx, + const BatchedCommandRequest& request); + /** * Initialize the FLE CRUD subsystem on Mongod. @@ -98,6 +105,38 @@ write_ops::DeleteCommandReply processFLEDelete( OperationContext* opCtx, const write_ops::DeleteCommandRequest& deleteRequest); /** + * Rewrite the query within a replica set explain command for delete and update. + * This concrete function is passed all the parameters directly. + */ +BSONObj processFLEWriteExplainD(OperationContext* opCtx, + const BSONObj& collation, + const NamespaceString& nss, + const EncryptionInformation& info, + const boost::optional<LegacyRuntimeConstants>& runtimeConstants, + const boost::optional<BSONObj>& letParameters, + const BSONObj& query); + +/** + * Rewrite the query within a replica set explain command for delete and update. + * This template is passed the request object from the command and delegates + * to the function above. + */ +template <typename T> +BSONObj processFLEWriteExplainD(OperationContext* opCtx, + const BSONObj& collation, + const T& request, + const BSONObj& query) { + + return processFLEWriteExplainD(opCtx, + collation, + request.getNamespace(), + request.getEncryptionInformation().get(), + request.getLegacyRuntimeConstants(), + request.getLet(), + query); +} + +/** * Process a replica set update. */ write_ops::UpdateCommandReply processFLEUpdate( diff --git a/src/mongo/db/fle_crud_mongod.cpp b/src/mongo/db/fle_crud_mongod.cpp index cefa6a9e0bf..7e8eea293fa 100644 --- a/src/mongo/db/fle_crud_mongod.cpp +++ b/src/mongo/db/fle_crud_mongod.cpp @@ -253,4 +253,16 @@ std::unique_ptr<Pipeline, PipelineDeleter> processFLEPipelineD( return fle::processPipeline( opCtx, nss, encryptInfo, std::move(toRewrite), &getTransactionWithRetriesForMongoD); } + +BSONObj processFLEWriteExplainD(OperationContext* opCtx, + const BSONObj& collation, + const NamespaceString& nss, + const EncryptionInformation& info, + const boost::optional<LegacyRuntimeConstants>& runtimeConstants, + const boost::optional<BSONObj>& letParameters, + const BSONObj& query) { + auto expCtx = make_intrusive<ExpressionContext>( + opCtx, fle::collatorFromBSON(opCtx, collation), nss, runtimeConstants, letParameters); + return fle::rewriteQuery(opCtx, expCtx, nss, info, query, &getTransactionWithRetriesForMongoD); +} } // namespace mongo diff --git a/src/mongo/db/query/fle/server_rewrite.cpp b/src/mongo/db/query/fle/server_rewrite.cpp index d08b9173884..876b0c6823d 100644 --- a/src/mongo/db/query/fle/server_rewrite.cpp +++ b/src/mongo/db/query/fle/server_rewrite.cpp @@ -32,12 +32,15 @@ #include <memory> +#include "mongo/bson/bsonobj.h" #include "mongo/bson/bsonobjbuilder.h" #include "mongo/bson/bsontypes.h" #include "mongo/crypto/encryption_fields_gen.h" #include "mongo/crypto/fle_crypto.h" +#include "mongo/crypto/fle_field_schema_gen.h" #include "mongo/crypto/fle_tags.h" #include "mongo/db/fle_crud.h" +#include "mongo/db/operation_context.h" #include "mongo/db/pipeline/document_source_geo_near.h" #include "mongo/db/pipeline/document_source_graph_lookup.h" #include "mongo/db/pipeline/document_source_match.h" @@ -47,63 +50,27 @@ #include "mongo/util/assert_util.h" namespace mongo::fle { -namespace { -BSONObj rewriteEncryptedFilter(const FLEStateCollectionReader& escReader, - const FLEStateCollectionReader& eccReader, - boost::intrusive_ptr<ExpressionContext> expCtx, - BSONObj filter) { - return MatchExpressionRewrite(expCtx, escReader, eccReader, filter).get(); -} - -/** - * Make an expression context from a find command. - */ -boost::intrusive_ptr<ExpressionContext> makeExpCtx(OperationContext* opCtx, - const NamespaceString& nss, - FindCommandRequest* findCommand) { +// TODO: This is a generally useful helper function that should probably go in some other namespace. +std::unique_ptr<CollatorInterface> collatorFromBSON(OperationContext* opCtx, + const BSONObj& collation) { std::unique_ptr<CollatorInterface> collator; - if (!findCommand->getCollation().isEmpty()) { - auto statusWithCollator = CollatorFactoryInterface::get(opCtx->getServiceContext()) - ->makeFromBSON(findCommand->getCollation()); - + if (!collation.isEmpty()) { + auto statusWithCollator = + CollatorFactoryInterface::get(opCtx->getServiceContext())->makeFromBSON(collation); uassertStatusOK(statusWithCollator.getStatus()); collator = std::move(statusWithCollator.getValue()); } - auto expCtx = make_intrusive<ExpressionContext>(opCtx, - std::move(collator), - nss, - findCommand->getLegacyRuntimeConstants(), - findCommand->getLet()); - expCtx->stopExpressionCounters(); - return expCtx; + return collator; } - -boost::intrusive_ptr<ExpressionContext> makeExpCtx(OperationContext* opCtx, - const NamespaceString& nss, - CountCommandRequest* countCommand) { - - std::unique_ptr<CollatorInterface> collator; - if (countCommand->getCollation()) { - auto statusWithCollator = CollatorFactoryInterface::get(opCtx->getServiceContext()) - ->makeFromBSON(countCommand->getCollation().get()); - - uassertStatusOK(statusWithCollator.getStatus()); - collator = std::move(statusWithCollator.getValue()); - } - auto expCtx = make_intrusive<ExpressionContext>( - opCtx, - std::move(collator), - nss, - // Count command does not have legacy runtime constants, and does not support user variables - // defined in a let expression. - boost::none, - boost::none); - expCtx->stopExpressionCounters(); - return expCtx; +namespace { +BSONObj rewriteEncryptedFilter(const FLEStateCollectionReader& escReader, + const FLEStateCollectionReader& eccReader, + boost::intrusive_ptr<ExpressionContext> expCtx, + BSONObj filter) { + return MatchExpressionRewrite(expCtx, escReader, eccReader, filter).get(); } - class RewriteBase { public: RewriteBase(boost::intrusive_ptr<ExpressionContext> expCtx, @@ -225,6 +192,17 @@ BSONObj rewriteEncryptedFilterInsideTxn(FLEQueryInterface* queryImpl, return rewriteEncryptedFilter(escReader, eccReader, expCtx, filter); } +BSONObj rewriteQuery(OperationContext* opCtx, + boost::intrusive_ptr<ExpressionContext> expCtx, + const NamespaceString& nss, + const EncryptionInformation& info, + BSONObj filter, + GetTxnCallback getTransaction) { + auto sharedBlock = std::make_shared<FilterRewrite>(expCtx, nss, info, filter); + doFLERewriteInTxn(opCtx, sharedBlock, getTransaction); + return sharedBlock->rewrittenFilter.getOwned(); +} + void processFindCommand(OperationContext* opCtx, const NamespaceString& nss, @@ -232,15 +210,19 @@ void processFindCommand(OperationContext* opCtx, GetTxnCallback getTransaction) { invariant(findCommand->getEncryptionInformation()); - auto sharedBlock = - std::make_shared<FilterRewrite>(makeExpCtx(opCtx, nss, findCommand), + auto expCtx = + make_intrusive<ExpressionContext>(opCtx, + collatorFromBSON(opCtx, findCommand->getCollation()), + nss, + findCommand->getLegacyRuntimeConstants(), + findCommand->getLet()); + expCtx->stopExpressionCounters(); + findCommand->setFilter(rewriteQuery(opCtx, + expCtx, nss, findCommand->getEncryptionInformation().get(), - findCommand->getFilter().getOwned()); - doFLERewriteInTxn(opCtx, sharedBlock, getTransaction); - - auto rewrittenFilter = sharedBlock->rewrittenFilter.getOwned(); - findCommand->setFilter(std::move(rewrittenFilter)); + findCommand->getFilter().getOwned(), + getTransaction)); // The presence of encryptionInformation is a signal that this is a FLE request that requires // special processing. Once we've rewritten the query, it's no longer a "special" FLE query, but // a normal query that can be executed by the query system like any other, so remove @@ -253,16 +235,18 @@ void processCountCommand(OperationContext* opCtx, CountCommandRequest* countCommand, GetTxnCallback getTxn) { invariant(countCommand->getEncryptionInformation()); + // Count command does not have legacy runtime constants, and does not support user variables + // defined in a let expression. + auto expCtx = make_intrusive<ExpressionContext>( + opCtx, collatorFromBSON(opCtx, countCommand->getCollation().value_or(BSONObj())), nss); + expCtx->stopExpressionCounters(); - auto sharedBlock = - std::make_shared<FilterRewrite>(makeExpCtx(opCtx, nss, countCommand), + countCommand->setQuery(rewriteQuery(opCtx, + expCtx, nss, countCommand->getEncryptionInformation().get(), - countCommand->getQuery().getOwned()); - doFLERewriteInTxn(opCtx, sharedBlock, getTxn); - - auto rewrittenFilter = sharedBlock->rewrittenFilter.getOwned(); - countCommand->setQuery(std::move(rewrittenFilter)); + countCommand->getQuery().getOwned(), + getTxn)); // The presence of encryptionInformation is a signal that this is a FLE request that requires // special processing. Once we've rewritten the query, it's no longer a "special" FLE query, but // a normal query that can be executed by the query system like any other, so remove diff --git a/src/mongo/db/query/fle/server_rewrite.h b/src/mongo/db/query/fle/server_rewrite.h index e85c6911ea4..c653932fe5c 100644 --- a/src/mongo/db/query/fle/server_rewrite.h +++ b/src/mongo/db/query/fle/server_rewrite.h @@ -47,6 +47,24 @@ class FLEQueryInterface; namespace fle { /** + * Make a collator object from its BSON representation. Useful when creating ExpressionContext + * objects for parsing MatchExpressions as part of the server-side rewrite. + */ +std::unique_ptr<CollatorInterface> collatorFromBSON(OperationContext* opCtx, + const BSONObj& collation); + +/** + * Return a rewritten version of the passed-in filter as a BSONObj. Generally used by other + * functions to process MatchExpressions in each command. + */ +BSONObj rewriteQuery(OperationContext* opCtx, + boost::intrusive_ptr<ExpressionContext> expCtx, + const NamespaceString& nss, + const EncryptionInformation& info, + BSONObj filter, + GetTxnCallback getTransaction); + +/** * Process a find command with encryptionInformation in-place, rewriting the filter condition so * that any query on an encrypted field will properly query the underlying tags array. */ diff --git a/src/mongo/s/commands/SConscript b/src/mongo/s/commands/SConscript index 0a1796119aa..5cdade782f9 100644 --- a/src/mongo/s/commands/SConscript +++ b/src/mongo/s/commands/SConscript @@ -172,6 +172,7 @@ env.Library( ], LIBDEPS_PRIVATE=[ '$BUILD_DIR/mongo/db/commands/core', + '$BUILD_DIR/mongo/db/fle_crud', '$BUILD_DIR/mongo/db/initialize_api_parameters', '$BUILD_DIR/mongo/db/internal_transactions_feature_flag', '$BUILD_DIR/mongo/db/read_write_concern_defaults', diff --git a/src/mongo/s/commands/cluster_write_cmd.cpp b/src/mongo/s/commands/cluster_write_cmd.cpp index 47eb18b4b0e..3810f0a4b3d 100644 --- a/src/mongo/s/commands/cluster_write_cmd.cpp +++ b/src/mongo/s/commands/cluster_write_cmd.cpp @@ -35,6 +35,7 @@ #include "mongo/client/remote_command_targeter.h" #include "mongo/db/catalog/document_validation.h" #include "mongo/db/curop.h" +#include "mongo/db/fle_crud.h" #include "mongo/db/internal_transactions_feature_flag_gen.h" #include "mongo/db/pipeline/lite_parsed_pipeline.h" #include "mongo/db/stats/counters.h" @@ -644,22 +645,33 @@ void ClusterWriteCmd::InvocationBase::explain(OperationContext* opCtx, "explained write batches must be of size 1", _batchedRequest.sizeWriteOps() == 1U); - const auto explainCmd = ClusterExplain::wrapAsExplain(_request->body, verbosity); + + std::unique_ptr<BatchedCommandRequest> req; + if (_batchedRequest.hasEncryptionInformation() && + (_batchedRequest.getBatchType() == BatchedCommandRequest::BatchType_Delete || + _batchedRequest.getBatchType() == BatchedCommandRequest::BatchType_Update)) { + req = processFLEBatchExplain(opCtx, _batchedRequest); + } + + auto nss = req ? req->getNS() : _batchedRequest.getNS(); + auto requestBSON = req ? req->toBSON() : _request->body; + auto requestPtr = req ? req.get() : &_batchedRequest; + + const auto explainCmd = ClusterExplain::wrapAsExplain(requestBSON, verbosity); // We will time how long it takes to run the commands on the shards. Timer timer; // Target the command to the shards based on the singleton batch item. - BatchItemRef targetingBatchItem(&_batchedRequest, 0); + BatchItemRef targetingBatchItem(requestPtr, 0); std::vector<AsyncRequestsSender::Response> shardResponses; - _commandOpWrite( - opCtx, _batchedRequest.getNS(), explainCmd, targetingBatchItem, &shardResponses); + _commandOpWrite(opCtx, nss, explainCmd, targetingBatchItem, &shardResponses); auto bodyBuilder = result->getBodyBuilder(); uassertStatusOK(ClusterExplain::buildExplainResult(opCtx, shardResponses, ClusterExplain::kWriteOnShards, timer.millis(), - _request->body, + requestBSON, &bodyBuilder)); } |