summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavis Haupt <davis.haupt@mongodb.com>2022-04-05 14:24:44 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2022-04-05 15:07:53 +0000
commit15bf8f14132712c419a7de16c12396abf40c51df (patch)
tree7d455f5ad5201ec16ea3a476957e52a2b5f656ee
parentfbee74f86aff6d7c3110a757b951d575b34cc921 (diff)
downloadmongo-15bf8f14132712c419a7de16c12396abf40c51df.tar.gz
SERVER-64360 Server-side rewrite for count command
-rw-r--r--buildscripts/resmokeconfig/suites/fle2.yml4
-rw-r--r--src/mongo/db/SConscript1
-rw-r--r--src/mongo/db/commands/count_cmd.cpp8
-rw-r--r--src/mongo/db/commands/find_cmd.cpp3
-rw-r--r--src/mongo/db/fle_crud.cpp12
-rw-r--r--src/mongo/db/fle_crud.h24
-rw-r--r--src/mongo/db/fle_crud_mongod.cpp12
-rw-r--r--src/mongo/db/query/fle/server_rewrite.cpp60
-rw-r--r--src/mongo/db/query/fle/server_rewrite.h16
-rw-r--r--src/mongo/s/commands/cluster_count_cmd.cpp33
-rw-r--r--src/mongo/s/commands/cluster_find_cmd.h4
11 files changed, 146 insertions, 31 deletions
diff --git a/buildscripts/resmokeconfig/suites/fle2.yml b/buildscripts/resmokeconfig/suites/fle2.yml
index 6847e0121d2..33f2537f3eb 100644
--- a/buildscripts/resmokeconfig/suites/fle2.yml
+++ b/buildscripts/resmokeconfig/suites/fle2.yml
@@ -2,7 +2,9 @@ test_kind: js_test
selector:
roots:
- jstests/fle2/**/*.js
- - src/mongo/db/modules/*/jstests/fle2/**/*.js
+ - src/mongo/db/modules/*/jstests/fle2/*.js
+ - src/mongo/db/modules/*/jstests/fle2/query/*.js
+
executor:
archive:
hooks:
diff --git a/src/mongo/db/SConscript b/src/mongo/db/SConscript
index 3353b74fd91..ef1e749de69 100644
--- a/src/mongo/db/SConscript
+++ b/src/mongo/db/SConscript
@@ -870,6 +870,7 @@ env.Library(
'$BUILD_DIR/mongo/crypto/encrypted_field_config',
'$BUILD_DIR/mongo/db/ops/write_ops_parsers',
'$BUILD_DIR/mongo/db/pipeline/pipeline',
+ '$BUILD_DIR/mongo/db/query/command_request_response',
'transaction_api',
],
LIBDEPS_PRIVATE=[
diff --git a/src/mongo/db/commands/count_cmd.cpp b/src/mongo/db/commands/count_cmd.cpp
index 44d00d4dcb2..c548bf3ef1c 100644
--- a/src/mongo/db/commands/count_cmd.cpp
+++ b/src/mongo/db/commands/count_cmd.cpp
@@ -40,6 +40,7 @@
#include "mongo/db/curop_failpoint_helpers.h"
#include "mongo/db/db_raii.h"
#include "mongo/db/exec/count.h"
+#include "mongo/db/fle_crud.h"
#include "mongo/db/pipeline/aggregation_request_helper.h"
#include "mongo/db/query/collection_query_info.h"
#include "mongo/db/query/count_command_as_aggregation_command.h"
@@ -158,6 +159,10 @@ public:
return exceptionToStatus();
}
+ if (shouldDoFLERewrite(request)) {
+ processFLECountD(opCtx, nss, &request);
+ }
+
if (ctx->getView()) {
// Relinquish locks. The aggregation command will re-acquire them.
ctx.reset();
@@ -231,6 +236,9 @@ public:
&hangBeforeCollectionCount, opCtx, "hangBeforeCollectionCount", []() {}, nss);
auto request = CountCommandRequest::parse(IDLParserErrorContext("count"), cmdObj);
+ if (shouldDoFLERewrite(request)) {
+ processFLECountD(opCtx, nss, &request);
+ }
if (ctx->getView()) {
auto viewAggregation = countCommandAsAggregationCommand(request, nss);
diff --git a/src/mongo/db/commands/find_cmd.cpp b/src/mongo/db/commands/find_cmd.cpp
index 9f596d34a4f..e58562a6b75 100644
--- a/src/mongo/db/commands/find_cmd.cpp
+++ b/src/mongo/db/commands/find_cmd.cpp
@@ -122,7 +122,8 @@ std::unique_ptr<FindCommandRequest> parseCmdObjectToFindCommandRequest(Operation
// Rewrite any FLE find payloads that exist in the query if this is a FLE 2 query.
if (shouldDoFLERewrite(findCommand)) {
- processFLEFindD(opCtx, findCommand.get());
+ invariant(findCommand->getNamespaceOrUUID().nss());
+ processFLEFindD(opCtx, findCommand->getNamespaceOrUUID().nss().get(), findCommand.get());
}
return translateNtoReturnToLimitOrBatchSize(std::move(findCommand));
diff --git a/src/mongo/db/fle_crud.cpp b/src/mongo/db/fle_crud.cpp
index 296e9c76cf9..6d606d93d2a 100644
--- a/src/mongo/db/fle_crud.cpp
+++ b/src/mongo/db/fle_crud.cpp
@@ -1256,8 +1256,16 @@ std::vector<BSONObj> FLEQueryInterfaceImpl::findDocuments(const NamespaceString&
return _txnClient.exhaustiveFind(find).get();
}
-void processFLEFindS(OperationContext* opCtx, FindCommandRequest* findCommand) {
- fle::processFindCommand(opCtx, findCommand, &getTransactionWithRetriesForMongoS);
+void processFLEFindS(OperationContext* opCtx,
+ const NamespaceString& nss,
+ FindCommandRequest* findCommand) {
+ fle::processFindCommand(opCtx, nss, findCommand, &getTransactionWithRetriesForMongoS);
+}
+
+void processFLECountS(OperationContext* opCtx,
+ const NamespaceString& nss,
+ CountCommandRequest* countCommand) {
+ fle::processCountCommand(opCtx, nss, countCommand, &getTransactionWithRetriesForMongoS);
}
std::unique_ptr<Pipeline, PipelineDeleter> processFLEPipelineS(
diff --git a/src/mongo/db/fle_crud.h b/src/mongo/db/fle_crud.h
index 303c8704015..1b36850c213 100644
--- a/src/mongo/db/fle_crud.h
+++ b/src/mongo/db/fle_crud.h
@@ -40,6 +40,7 @@
#include "mongo/db/operation_context.h"
#include "mongo/db/ops/write_ops_gen.h"
#include "mongo/db/pipeline/pipeline.h"
+#include "mongo/db/query/count_command_gen.h"
#include "mongo/db/transaction_api.h"
#include "mongo/s/write_ops/batch_write_exec.h"
#include "mongo/s/write_ops/batched_command_response.h"
@@ -118,12 +119,31 @@ write_ops::FindAndModifyCommandReply processFLEFindAndModify(
/**
* Process a find command from mongos.
*/
-void processFLEFindS(OperationContext* opCtx, FindCommandRequest* findCommand);
+void processFLEFindS(OperationContext* opCtx,
+ const NamespaceString& nss,
+ FindCommandRequest* findCommand);
/**
* Process a find command from a replica set.
*/
-void processFLEFindD(OperationContext* opCtx, FindCommandRequest* findCommand);
+void processFLEFindD(OperationContext* opCtx,
+ const NamespaceString& nss,
+ FindCommandRequest* findCommand);
+
+
+/**
+ * Process a find command from mongos.
+ */
+void processFLECountS(OperationContext* opCtx,
+ const NamespaceString& nss,
+ CountCommandRequest* countCommand);
+
+/**
+ * Process a find command from a replica set.
+ */
+void processFLECountD(OperationContext* opCtx,
+ const NamespaceString& nss,
+ CountCommandRequest* countCommand);
/**
* Process a pipeline from mongos.
diff --git a/src/mongo/db/fle_crud_mongod.cpp b/src/mongo/db/fle_crud_mongod.cpp
index 1308c8cb136..cefa6a9e0bf 100644
--- a/src/mongo/db/fle_crud_mongod.cpp
+++ b/src/mongo/db/fle_crud_mongod.cpp
@@ -233,8 +233,16 @@ write_ops::UpdateCommandReply processFLEUpdate(
return updateReply;
}
-void processFLEFindD(OperationContext* opCtx, FindCommandRequest* findCommand) {
- fle::processFindCommand(opCtx, findCommand, &getTransactionWithRetriesForMongoD);
+void processFLEFindD(OperationContext* opCtx,
+ const NamespaceString& nss,
+ FindCommandRequest* findCommand) {
+ fle::processFindCommand(opCtx, nss, findCommand, &getTransactionWithRetriesForMongoD);
+}
+
+void processFLECountD(OperationContext* opCtx,
+ const NamespaceString& nss,
+ CountCommandRequest* countCommand) {
+ fle::processCountCommand(opCtx, nss, countCommand, &getTransactionWithRetriesForMongoD);
}
std::unique_ptr<Pipeline, PipelineDeleter> processFLEPipelineD(
diff --git a/src/mongo/db/query/fle/server_rewrite.cpp b/src/mongo/db/query/fle/server_rewrite.cpp
index 2c6480b44b5..d08b9173884 100644
--- a/src/mongo/db/query/fle/server_rewrite.cpp
+++ b/src/mongo/db/query/fle/server_rewrite.cpp
@@ -59,8 +59,8 @@ BSONObj rewriteEncryptedFilter(const FLEStateCollectionReader& escReader,
* Make an expression context from a find command.
*/
boost::intrusive_ptr<ExpressionContext> makeExpCtx(OperationContext* opCtx,
+ const NamespaceString& nss,
FindCommandRequest* findCommand) {
- invariant(findCommand->getNamespaceOrUUID().nss());
std::unique_ptr<CollatorInterface> collator;
if (!findCommand->getCollation().isEmpty()) {
@@ -72,13 +72,37 @@ boost::intrusive_ptr<ExpressionContext> makeExpCtx(OperationContext* opCtx,
}
auto expCtx = make_intrusive<ExpressionContext>(opCtx,
std::move(collator),
- findCommand->getNamespaceOrUUID().nss().get(),
+ nss,
findCommand->getLegacyRuntimeConstants(),
findCommand->getLet());
expCtx->stopExpressionCounters();
return expCtx;
}
+boost::intrusive_ptr<ExpressionContext> makeExpCtx(OperationContext* opCtx,
+ const NamespaceString& nss,
+ CountCommandRequest* countCommand) {
+
+ std::unique_ptr<CollatorInterface> collator;
+ if (countCommand->getCollation()) {
+ auto statusWithCollator = CollatorFactoryInterface::get(opCtx->getServiceContext())
+ ->makeFromBSON(countCommand->getCollation().get());
+
+ uassertStatusOK(statusWithCollator.getStatus());
+ collator = std::move(statusWithCollator.getValue());
+ }
+ auto expCtx = make_intrusive<ExpressionContext>(
+ opCtx,
+ std::move(collator),
+ nss,
+ // Count command does not have legacy runtime constants, and does not support user variables
+ // defined in a let expression.
+ boost::none,
+ boost::none);
+ expCtx->stopExpressionCounters();
+ return expCtx;
+}
+
class RewriteBase {
public:
@@ -203,23 +227,49 @@ BSONObj rewriteEncryptedFilterInsideTxn(FLEQueryInterface* queryImpl,
void processFindCommand(OperationContext* opCtx,
+ const NamespaceString& nss,
FindCommandRequest* findCommand,
GetTxnCallback getTransaction) {
- invariant(findCommand->getNamespaceOrUUID().nss());
invariant(findCommand->getEncryptionInformation());
auto sharedBlock =
- std::make_shared<FilterRewrite>(makeExpCtx(opCtx, findCommand),
- findCommand->getNamespaceOrUUID().nss().get(),
+ std::make_shared<FilterRewrite>(makeExpCtx(opCtx, nss, findCommand),
+ nss,
findCommand->getEncryptionInformation().get(),
findCommand->getFilter().getOwned());
doFLERewriteInTxn(opCtx, sharedBlock, getTransaction);
auto rewrittenFilter = sharedBlock->rewrittenFilter.getOwned();
findCommand->setFilter(std::move(rewrittenFilter));
+ // The presence of encryptionInformation is a signal that this is a FLE request that requires
+ // special processing. Once we've rewritten the query, it's no longer a "special" FLE query, but
+ // a normal query that can be executed by the query system like any other, so remove
+ // encryptionInformation.
findCommand->setEncryptionInformation(boost::none);
}
+void processCountCommand(OperationContext* opCtx,
+ const NamespaceString& nss,
+ CountCommandRequest* countCommand,
+ GetTxnCallback getTxn) {
+ invariant(countCommand->getEncryptionInformation());
+
+ auto sharedBlock =
+ std::make_shared<FilterRewrite>(makeExpCtx(opCtx, nss, countCommand),
+ nss,
+ countCommand->getEncryptionInformation().get(),
+ countCommand->getQuery().getOwned());
+ doFLERewriteInTxn(opCtx, sharedBlock, getTxn);
+
+ auto rewrittenFilter = sharedBlock->rewrittenFilter.getOwned();
+ countCommand->setQuery(std::move(rewrittenFilter));
+ // The presence of encryptionInformation is a signal that this is a FLE request that requires
+ // special processing. Once we've rewritten the query, it's no longer a "special" FLE query, but
+ // a normal query that can be executed by the query system like any other, so remove
+ // encryptionInformation.
+ countCommand->setEncryptionInformation(boost::none);
+}
+
std::unique_ptr<Pipeline, PipelineDeleter> processPipeline(
OperationContext* opCtx,
NamespaceString nss,
diff --git a/src/mongo/db/query/fle/server_rewrite.h b/src/mongo/db/query/fle/server_rewrite.h
index 9a94e0a0aa8..e85c6911ea4 100644
--- a/src/mongo/db/query/fle/server_rewrite.h
+++ b/src/mongo/db/query/fle/server_rewrite.h
@@ -39,6 +39,7 @@
#include "mongo/db/matcher/expression_parser.h"
#include "mongo/db/namespace_string.h"
#include "mongo/db/pipeline/expression_context.h"
+#include "mongo/db/query/count_command_gen.h"
#include "mongo/db/transaction_api.h"
namespace mongo {
@@ -50,13 +51,24 @@ namespace fle {
* that any query on an encrypted field will properly query the underlying tags array.
*/
void processFindCommand(OperationContext* opCtx,
+ const NamespaceString& nss,
FindCommandRequest* findCommand,
GetTxnCallback txn);
/**
+ * Process a count command with encryptionInformation in-place, rewriting the filter condition so
+ * that any query on an encrypted field will properly query the underlying tags array.
+ */
+void processCountCommand(OperationContext* opCtx,
+ const NamespaceString& nss,
+ CountCommandRequest* countCommand,
+ GetTxnCallback getTxn);
+
+/**
* Process a pipeline with encryptionInformation by rewriting the pipeline to query against the
- * underlying tags array, where appropriate. After this rewriting is complete, there is no more FLE
- * work to be done. The encryption info does not need to be kept around (e.g. on a command object).
+ * underlying tags array, where appropriate. After this rewriting is complete, there is no more
+ * FLE work to be done. The encryption info does not need to be kept around (e.g. on a command
+ * object).
*/
std::unique_ptr<Pipeline, PipelineDeleter> processPipeline(
OperationContext* opCtx,
diff --git a/src/mongo/s/commands/cluster_count_cmd.cpp b/src/mongo/s/commands/cluster_count_cmd.cpp
index 588a5e41814..48545c7c0d7 100644
--- a/src/mongo/s/commands/cluster_count_cmd.cpp
+++ b/src/mongo/s/commands/cluster_count_cmd.cpp
@@ -33,6 +33,7 @@
#include "mongo/bson/util/bson_extract.h"
#include "mongo/db/commands.h"
+#include "mongo/db/fle_crud.h"
#include "mongo/db/query/count_command_as_aggregation_command.h"
#include "mongo/db/query/count_command_gen.h"
#include "mongo/db/query/view_response_formatter.h"
@@ -100,6 +101,9 @@ public:
std::vector<AsyncRequestsSender::Response> shardResponses;
try {
auto countRequest = CountCommandRequest::parse(IDLParserErrorContext("count"), cmdObj);
+ if (shouldDoFLERewrite(countRequest)) {
+ processFLECountS(opCtx, nss, &countRequest);
+ }
// We only need to factor in the skip value when sending to the shards if we
// have a value for limit, otherwise, we apply it only once we have collected all
@@ -192,29 +196,28 @@ public:
rpc::ReplyBuilderInterface* result) const override {
std::string dbname = request.getDatabase().toString();
const BSONObj& cmdObj = request.body;
+
+ CountCommandRequest countRequest(NamespaceStringOrUUID(NamespaceString{}));
+ try {
+ countRequest = CountCommandRequest::parse(IDLParserErrorContext("count"), request);
+ } catch (...) {
+ return exceptionToStatus();
+ }
+
const NamespaceString nss(parseNs(dbname, cmdObj));
uassert(ErrorCodes::InvalidNamespace,
str::stream() << "Invalid namespace specified '" << nss.ns() << "'",
nss.isValid());
- // Extract the targeting query.
- BSONObj targetingQuery;
- if (Object == cmdObj["query"].type()) {
- targetingQuery = cmdObj["query"].Obj();
+ // If the command has encryptionInformation, rewrite the query as necessary.
+ if (shouldDoFLERewrite(countRequest)) {
+ processFLECountS(opCtx, nss, &countRequest);
}
- // Extract the targeting collation.
- BSONObj targetingCollation;
- BSONElement targetingCollationElement;
- auto status = bsonExtractTypedField(
- cmdObj, "collation", BSONType::Object, &targetingCollationElement);
- if (status.isOK()) {
- targetingCollation = targetingCollationElement.Obj();
- } else if (status != ErrorCodes::NoSuchKey) {
- return status;
- }
+ BSONObj targetingQuery = countRequest.getQuery();
+ BSONObj targetingCollation = countRequest.getCollation().value_or(BSONObj());
- const auto explainCmd = ClusterExplain::wrapAsExplain(cmdObj, verbosity);
+ const auto explainCmd = ClusterExplain::wrapAsExplain(countRequest.toBSON({}), verbosity);
// We will time how long it takes to run the commands on the shards
Timer timer;
diff --git a/src/mongo/s/commands/cluster_find_cmd.h b/src/mongo/s/commands/cluster_find_cmd.h
index 888da770530..efbc0bd4d72 100644
--- a/src/mongo/s/commands/cluster_find_cmd.h
+++ b/src/mongo/s/commands/cluster_find_cmd.h
@@ -289,7 +289,9 @@ public:
findCommand->getNtoreturn() == boost::none);
if (shouldDoFLERewrite(findCommand)) {
- processFLEFindS(opCtx, findCommand.get());
+ invariant(findCommand->getNamespaceOrUUID().nss());
+ processFLEFindS(
+ opCtx, findCommand->getNamespaceOrUUID().nss().get(), findCommand.get());
}
return findCommand;