From 15bf8f14132712c419a7de16c12396abf40c51df Mon Sep 17 00:00:00 2001 From: Davis Haupt Date: Tue, 5 Apr 2022 14:24:44 +0000 Subject: SERVER-64360 Server-side rewrite for count command --- buildscripts/resmokeconfig/suites/fle2.yml | 4 +- src/mongo/db/SConscript | 1 + src/mongo/db/commands/count_cmd.cpp | 8 ++++ src/mongo/db/commands/find_cmd.cpp | 3 +- src/mongo/db/fle_crud.cpp | 12 +++++- src/mongo/db/fle_crud.h | 24 +++++++++++- src/mongo/db/fle_crud_mongod.cpp | 12 +++++- src/mongo/db/query/fle/server_rewrite.cpp | 60 +++++++++++++++++++++++++++--- src/mongo/db/query/fle/server_rewrite.h | 16 +++++++- src/mongo/s/commands/cluster_count_cmd.cpp | 33 ++++++++-------- src/mongo/s/commands/cluster_find_cmd.h | 4 +- 11 files changed, 146 insertions(+), 31 deletions(-) diff --git a/buildscripts/resmokeconfig/suites/fle2.yml b/buildscripts/resmokeconfig/suites/fle2.yml index 6847e0121d2..33f2537f3eb 100644 --- a/buildscripts/resmokeconfig/suites/fle2.yml +++ b/buildscripts/resmokeconfig/suites/fle2.yml @@ -2,7 +2,9 @@ test_kind: js_test selector: roots: - jstests/fle2/**/*.js - - src/mongo/db/modules/*/jstests/fle2/**/*.js + - src/mongo/db/modules/*/jstests/fle2/*.js + - src/mongo/db/modules/*/jstests/fle2/query/*.js + executor: archive: hooks: diff --git a/src/mongo/db/SConscript b/src/mongo/db/SConscript index 3353b74fd91..ef1e749de69 100644 --- a/src/mongo/db/SConscript +++ b/src/mongo/db/SConscript @@ -870,6 +870,7 @@ env.Library( '$BUILD_DIR/mongo/crypto/encrypted_field_config', '$BUILD_DIR/mongo/db/ops/write_ops_parsers', '$BUILD_DIR/mongo/db/pipeline/pipeline', + '$BUILD_DIR/mongo/db/query/command_request_response', 'transaction_api', ], LIBDEPS_PRIVATE=[ diff --git a/src/mongo/db/commands/count_cmd.cpp b/src/mongo/db/commands/count_cmd.cpp index 44d00d4dcb2..c548bf3ef1c 100644 --- a/src/mongo/db/commands/count_cmd.cpp +++ b/src/mongo/db/commands/count_cmd.cpp @@ -40,6 +40,7 @@ #include "mongo/db/curop_failpoint_helpers.h" #include "mongo/db/db_raii.h" 
#include "mongo/db/exec/count.h" +#include "mongo/db/fle_crud.h" #include "mongo/db/pipeline/aggregation_request_helper.h" #include "mongo/db/query/collection_query_info.h" #include "mongo/db/query/count_command_as_aggregation_command.h" @@ -158,6 +159,10 @@ public: return exceptionToStatus(); } + if (shouldDoFLERewrite(request)) { + processFLECountD(opCtx, nss, &request); + } + if (ctx->getView()) { // Relinquish locks. The aggregation command will re-acquire them. ctx.reset(); @@ -231,6 +236,9 @@ public: &hangBeforeCollectionCount, opCtx, "hangBeforeCollectionCount", []() {}, nss); auto request = CountCommandRequest::parse(IDLParserErrorContext("count"), cmdObj); + if (shouldDoFLERewrite(request)) { + processFLECountD(opCtx, nss, &request); + } if (ctx->getView()) { auto viewAggregation = countCommandAsAggregationCommand(request, nss); diff --git a/src/mongo/db/commands/find_cmd.cpp b/src/mongo/db/commands/find_cmd.cpp index 9f596d34a4f..e58562a6b75 100644 --- a/src/mongo/db/commands/find_cmd.cpp +++ b/src/mongo/db/commands/find_cmd.cpp @@ -122,7 +122,8 @@ std::unique_ptr parseCmdObjectToFindCommandRequest(Operation // Rewrite any FLE find payloads that exist in the query if this is a FLE 2 query. 
if (shouldDoFLERewrite(findCommand)) { - processFLEFindD(opCtx, findCommand.get()); + invariant(findCommand->getNamespaceOrUUID().nss()); + processFLEFindD(opCtx, findCommand->getNamespaceOrUUID().nss().get(), findCommand.get()); } return translateNtoReturnToLimitOrBatchSize(std::move(findCommand)); diff --git a/src/mongo/db/fle_crud.cpp b/src/mongo/db/fle_crud.cpp index 296e9c76cf9..6d606d93d2a 100644 --- a/src/mongo/db/fle_crud.cpp +++ b/src/mongo/db/fle_crud.cpp @@ -1256,8 +1256,16 @@ std::vector FLEQueryInterfaceImpl::findDocuments(const NamespaceString& return _txnClient.exhaustiveFind(find).get(); } -void processFLEFindS(OperationContext* opCtx, FindCommandRequest* findCommand) { - fle::processFindCommand(opCtx, findCommand, &getTransactionWithRetriesForMongoS); +void processFLEFindS(OperationContext* opCtx, + const NamespaceString& nss, + FindCommandRequest* findCommand) { + fle::processFindCommand(opCtx, nss, findCommand, &getTransactionWithRetriesForMongoS); +} + +void processFLECountS(OperationContext* opCtx, + const NamespaceString& nss, + CountCommandRequest* countCommand) { + fle::processCountCommand(opCtx, nss, countCommand, &getTransactionWithRetriesForMongoS); } std::unique_ptr processFLEPipelineS( diff --git a/src/mongo/db/fle_crud.h b/src/mongo/db/fle_crud.h index 303c8704015..1b36850c213 100644 --- a/src/mongo/db/fle_crud.h +++ b/src/mongo/db/fle_crud.h @@ -40,6 +40,7 @@ #include "mongo/db/operation_context.h" #include "mongo/db/ops/write_ops_gen.h" #include "mongo/db/pipeline/pipeline.h" +#include "mongo/db/query/count_command_gen.h" #include "mongo/db/transaction_api.h" #include "mongo/s/write_ops/batch_write_exec.h" #include "mongo/s/write_ops/batched_command_response.h" @@ -118,12 +119,31 @@ write_ops::FindAndModifyCommandReply processFLEFindAndModify( /** * Process a find command from mongos. 
*/ -void processFLEFindS(OperationContext* opCtx, FindCommandRequest* findCommand); +void processFLEFindS(OperationContext* opCtx, + const NamespaceString& nss, + FindCommandRequest* findCommand); /** * Process a find command from a replica set. */ -void processFLEFindD(OperationContext* opCtx, FindCommandRequest* findCommand); +void processFLEFindD(OperationContext* opCtx, + const NamespaceString& nss, + FindCommandRequest* findCommand); + + +/** + * Process a count command from mongos. + */ +void processFLECountS(OperationContext* opCtx, + const NamespaceString& nss, + CountCommandRequest* countCommand); + +/** + * Process a count command from a replica set. + */ +void processFLECountD(OperationContext* opCtx, + const NamespaceString& nss, + CountCommandRequest* countCommand); /** * Process a pipeline from mongos. diff --git a/src/mongo/db/fle_crud_mongod.cpp b/src/mongo/db/fle_crud_mongod.cpp index 1308c8cb136..cefa6a9e0bf 100644 --- a/src/mongo/db/fle_crud_mongod.cpp +++ b/src/mongo/db/fle_crud_mongod.cpp @@ -233,8 +233,16 @@ write_ops::UpdateCommandReply processFLEUpdate( return updateReply; } -void processFLEFindD(OperationContext* opCtx, FindCommandRequest* findCommand) { - fle::processFindCommand(opCtx, findCommand, &getTransactionWithRetriesForMongoD); +void processFLEFindD(OperationContext* opCtx, + const NamespaceString& nss, + FindCommandRequest* findCommand) { + fle::processFindCommand(opCtx, nss, findCommand, &getTransactionWithRetriesForMongoD); +} + +void processFLECountD(OperationContext* opCtx, + const NamespaceString& nss, + CountCommandRequest* countCommand) { + fle::processCountCommand(opCtx, nss, countCommand, &getTransactionWithRetriesForMongoD); } std::unique_ptr processFLEPipelineD( diff --git a/src/mongo/db/query/fle/server_rewrite.cpp b/src/mongo/db/query/fle/server_rewrite.cpp index 2c6480b44b5..d08b9173884 100644 --- a/src/mongo/db/query/fle/server_rewrite.cpp +++ b/src/mongo/db/query/fle/server_rewrite.cpp @@ -59,8 +59,8 @@ BSONObj 
rewriteEncryptedFilter(const FLEStateCollectionReader& escReader, * Make an expression context from a find command. */ boost::intrusive_ptr makeExpCtx(OperationContext* opCtx, + const NamespaceString& nss, FindCommandRequest* findCommand) { - invariant(findCommand->getNamespaceOrUUID().nss()); std::unique_ptr collator; if (!findCommand->getCollation().isEmpty()) { @@ -72,13 +72,37 @@ boost::intrusive_ptr makeExpCtx(OperationContext* opCtx, } auto expCtx = make_intrusive(opCtx, std::move(collator), - findCommand->getNamespaceOrUUID().nss().get(), + nss, findCommand->getLegacyRuntimeConstants(), findCommand->getLet()); expCtx->stopExpressionCounters(); return expCtx; } +boost::intrusive_ptr makeExpCtx(OperationContext* opCtx, + const NamespaceString& nss, + CountCommandRequest* countCommand) { + + std::unique_ptr collator; + if (countCommand->getCollation()) { + auto statusWithCollator = CollatorFactoryInterface::get(opCtx->getServiceContext()) + ->makeFromBSON(countCommand->getCollation().get()); + + uassertStatusOK(statusWithCollator.getStatus()); + collator = std::move(statusWithCollator.getValue()); + } + auto expCtx = make_intrusive( + opCtx, + std::move(collator), + nss, + // Count command does not have legacy runtime constants, and does not support user variables + // defined in a let expression. 
+ boost::none, + boost::none); + expCtx->stopExpressionCounters(); + return expCtx; +} + class RewriteBase { public: @@ -203,23 +227,49 @@ BSONObj rewriteEncryptedFilterInsideTxn(FLEQueryInterface* queryImpl, void processFindCommand(OperationContext* opCtx, + const NamespaceString& nss, FindCommandRequest* findCommand, GetTxnCallback getTransaction) { - invariant(findCommand->getNamespaceOrUUID().nss()); invariant(findCommand->getEncryptionInformation()); auto sharedBlock = - std::make_shared(makeExpCtx(opCtx, findCommand), - findCommand->getNamespaceOrUUID().nss().get(), + std::make_shared(makeExpCtx(opCtx, nss, findCommand), + nss, findCommand->getEncryptionInformation().get(), findCommand->getFilter().getOwned()); doFLERewriteInTxn(opCtx, sharedBlock, getTransaction); auto rewrittenFilter = sharedBlock->rewrittenFilter.getOwned(); findCommand->setFilter(std::move(rewrittenFilter)); + // The presence of encryptionInformation is a signal that this is a FLE request that requires + // special processing. Once we've rewritten the query, it's no longer a "special" FLE query, but + // a normal query that can be executed by the query system like any other, so remove + // encryptionInformation. findCommand->setEncryptionInformation(boost::none); } +void processCountCommand(OperationContext* opCtx, + const NamespaceString& nss, + CountCommandRequest* countCommand, + GetTxnCallback getTxn) { + invariant(countCommand->getEncryptionInformation()); + + auto sharedBlock = + std::make_shared(makeExpCtx(opCtx, nss, countCommand), + nss, + countCommand->getEncryptionInformation().get(), + countCommand->getQuery().getOwned()); + doFLERewriteInTxn(opCtx, sharedBlock, getTxn); + + auto rewrittenFilter = sharedBlock->rewrittenFilter.getOwned(); + countCommand->setQuery(std::move(rewrittenFilter)); + // The presence of encryptionInformation is a signal that this is a FLE request that requires + // special processing. 
Once we've rewritten the query, it's no longer a "special" FLE query, but + // a normal query that can be executed by the query system like any other, so remove + // encryptionInformation. + countCommand->setEncryptionInformation(boost::none); +} + std::unique_ptr processPipeline( OperationContext* opCtx, NamespaceString nss, diff --git a/src/mongo/db/query/fle/server_rewrite.h b/src/mongo/db/query/fle/server_rewrite.h index 9a94e0a0aa8..e85c6911ea4 100644 --- a/src/mongo/db/query/fle/server_rewrite.h +++ b/src/mongo/db/query/fle/server_rewrite.h @@ -39,6 +39,7 @@ #include "mongo/db/matcher/expression_parser.h" #include "mongo/db/namespace_string.h" #include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/query/count_command_gen.h" #include "mongo/db/transaction_api.h" namespace mongo { @@ -50,13 +51,24 @@ namespace fle { * that any query on an encrypted field will properly query the underlying tags array. */ void processFindCommand(OperationContext* opCtx, + const NamespaceString& nss, FindCommandRequest* findCommand, GetTxnCallback txn); +/** + * Process a count command with encryptionInformation in-place, rewriting the filter condition so + * that any query on an encrypted field will properly query the underlying tags array. + */ +void processCountCommand(OperationContext* opCtx, + const NamespaceString& nss, + CountCommandRequest* countCommand, + GetTxnCallback getTxn); + /** * Process a pipeline with encryptionInformation by rewriting the pipeline to query against the - * underlying tags array, where appropriate. After this rewriting is complete, there is no more FLE - * work to be done. The encryption info does not need to be kept around (e.g. on a command object). + * underlying tags array, where appropriate. After this rewriting is complete, there is no more + * FLE work to be done. The encryption info does not need to be kept around (e.g. on a command + * object). 
*/ std::unique_ptr processPipeline( OperationContext* opCtx, diff --git a/src/mongo/s/commands/cluster_count_cmd.cpp b/src/mongo/s/commands/cluster_count_cmd.cpp index 588a5e41814..48545c7c0d7 100644 --- a/src/mongo/s/commands/cluster_count_cmd.cpp +++ b/src/mongo/s/commands/cluster_count_cmd.cpp @@ -33,6 +33,7 @@ #include "mongo/bson/util/bson_extract.h" #include "mongo/db/commands.h" +#include "mongo/db/fle_crud.h" #include "mongo/db/query/count_command_as_aggregation_command.h" #include "mongo/db/query/count_command_gen.h" #include "mongo/db/query/view_response_formatter.h" @@ -100,6 +101,9 @@ public: std::vector shardResponses; try { auto countRequest = CountCommandRequest::parse(IDLParserErrorContext("count"), cmdObj); + if (shouldDoFLERewrite(countRequest)) { + processFLECountS(opCtx, nss, &countRequest); + } // We only need to factor in the skip value when sending to the shards if we // have a value for limit, otherwise, we apply it only once we have collected all @@ -192,29 +196,28 @@ public: rpc::ReplyBuilderInterface* result) const override { std::string dbname = request.getDatabase().toString(); const BSONObj& cmdObj = request.body; + + CountCommandRequest countRequest(NamespaceStringOrUUID(NamespaceString{})); + try { + countRequest = CountCommandRequest::parse(IDLParserErrorContext("count"), request); + } catch (...) { + return exceptionToStatus(); + } + const NamespaceString nss(parseNs(dbname, cmdObj)); uassert(ErrorCodes::InvalidNamespace, str::stream() << "Invalid namespace specified '" << nss.ns() << "'", nss.isValid()); - // Extract the targeting query. - BSONObj targetingQuery; - if (Object == cmdObj["query"].type()) { - targetingQuery = cmdObj["query"].Obj(); + // If the command has encryptionInformation, rewrite the query as necessary. + if (shouldDoFLERewrite(countRequest)) { + processFLECountS(opCtx, nss, &countRequest); } - // Extract the targeting collation. 
- BSONObj targetingCollation; - BSONElement targetingCollationElement; - auto status = bsonExtractTypedField( - cmdObj, "collation", BSONType::Object, &targetingCollationElement); - if (status.isOK()) { - targetingCollation = targetingCollationElement.Obj(); - } else if (status != ErrorCodes::NoSuchKey) { - return status; - } + BSONObj targetingQuery = countRequest.getQuery(); + BSONObj targetingCollation = countRequest.getCollation().value_or(BSONObj()); - const auto explainCmd = ClusterExplain::wrapAsExplain(cmdObj, verbosity); + const auto explainCmd = ClusterExplain::wrapAsExplain(countRequest.toBSON({}), verbosity); // We will time how long it takes to run the commands on the shards Timer timer; diff --git a/src/mongo/s/commands/cluster_find_cmd.h b/src/mongo/s/commands/cluster_find_cmd.h index 888da770530..efbc0bd4d72 100644 --- a/src/mongo/s/commands/cluster_find_cmd.h +++ b/src/mongo/s/commands/cluster_find_cmd.h @@ -289,7 +289,9 @@ public: findCommand->getNtoreturn() == boost::none); if (shouldDoFLERewrite(findCommand)) { - processFLEFindS(opCtx, findCommand.get()); + invariant(findCommand->getNamespaceOrUUID().nss()); + processFLEFindS( + opCtx, findCommand->getNamespaceOrUUID().nss().get(), findCommand.get()); } return findCommand; -- cgit v1.2.1