summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavis Haupt <davis.haupt@mongodb.com>2022-09-22 13:20:24 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2022-09-22 14:37:20 +0000
commit7868a6387999a37bff0690fcbd7e764428d61880 (patch)
tree2ea72a073c3325cce63d6fb82eba054a15d14730
parent7b52219b95967bce103e85f7a26bfaf70f5146c6 (diff)
downloadmongo-7868a6387999a37bff0690fcbd7e764428d61880.tar.gz
SERVER-69113 Refactor server-side rewrite to enable easy addition of new encrypted index types
-rw-r--r--src/mongo/crypto/fle_crypto.cpp10
-rw-r--r--src/mongo/crypto/fle_crypto.h8
-rw-r--r--src/mongo/db/SConscript8
-rw-r--r--src/mongo/db/query/canonical_query.cpp1
-rw-r--r--src/mongo/db/query/fle/encrypted_predicate.cpp65
-rw-r--r--src/mongo/db/query/fle/encrypted_predicate.h261
-rw-r--r--src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h106
-rw-r--r--src/mongo/db/query/fle/equality_predicate.cpp402
-rw-r--r--src/mongo/db/query/fle/equality_predicate.h65
-rw-r--r--src/mongo/db/query/fle/equality_predicate_test.cpp465
-rw-r--r--src/mongo/db/query/fle/query_rewriter.cpp124
-rw-r--r--src/mongo/db/query/fle/query_rewriter.h156
-rw-r--r--src/mongo/db/query/fle/query_rewriter_interface.h71
-rw-r--r--src/mongo/db/query/fle/query_rewriter_test.cpp209
-rw-r--r--src/mongo/db/query/fle/range_predicate.cpp84
-rw-r--r--src/mongo/db/query/fle/range_predicate.h58
-rw-r--r--src/mongo/db/query/fle/range_predicate_test.cpp131
-rw-r--r--src/mongo/db/query/fle/server_rewrite.cpp527
-rw-r--r--src/mongo/db/query/fle/server_rewrite.h158
-rw-r--r--src/mongo/db/query/fle/server_rewrite_test.cpp1085
20 files changed, 2225 insertions, 1769 deletions
diff --git a/src/mongo/crypto/fle_crypto.cpp b/src/mongo/crypto/fle_crypto.cpp
index 9bc6dfa7022..a288e77329c 100644
--- a/src/mongo/crypto/fle_crypto.cpp
+++ b/src/mongo/crypto/fle_crypto.cpp
@@ -3136,13 +3136,13 @@ StringMap<FLEDeleteToken> EncryptionInformationHelpers::getDeleteTokens(
return map;
}
-ParsedFindPayload::ParsedFindPayload(BSONElement fleFindPayload)
- : ParsedFindPayload(binDataToCDR(fleFindPayload)){};
+ParsedFindEqualityPayload::ParsedFindEqualityPayload(BSONElement fleFindPayload)
+ : ParsedFindEqualityPayload(binDataToCDR(fleFindPayload)){};
-ParsedFindPayload::ParsedFindPayload(const Value& fleFindPayload)
- : ParsedFindPayload(binDataToCDR(fleFindPayload)){};
+ParsedFindEqualityPayload::ParsedFindEqualityPayload(const Value& fleFindPayload)
+ : ParsedFindEqualityPayload(binDataToCDR(fleFindPayload)){};
-ParsedFindPayload::ParsedFindPayload(ConstDataRange cdr) {
+ParsedFindEqualityPayload::ParsedFindEqualityPayload(ConstDataRange cdr) {
auto [encryptedTypeBinding, subCdr] = fromEncryptedConstDataRange(cdr);
auto encryptedType = encryptedTypeBinding;
diff --git a/src/mongo/crypto/fle_crypto.h b/src/mongo/crypto/fle_crypto.h
index 566dd00246c..c65a1c8a0ca 100644
--- a/src/mongo/crypto/fle_crypto.h
+++ b/src/mongo/crypto/fle_crypto.h
@@ -1253,16 +1253,16 @@ public:
*/
std::pair<EncryptedBinDataType, ConstDataRange> fromEncryptedConstDataRange(ConstDataRange cdr);
-struct ParsedFindPayload {
+struct ParsedFindEqualityPayload {
ESCDerivedFromDataToken escToken;
ECCDerivedFromDataToken eccToken;
EDCDerivedFromDataToken edcToken;
boost::optional<ServerDataEncryptionLevel1Token> serverToken;
boost::optional<std::int64_t> maxCounter;
- explicit ParsedFindPayload(BSONElement fleFindPayload);
- explicit ParsedFindPayload(const Value& fleFindPayload);
- explicit ParsedFindPayload(ConstDataRange cdr);
+ explicit ParsedFindEqualityPayload(BSONElement fleFindPayload);
+ explicit ParsedFindEqualityPayload(const Value& fleFindPayload);
+ explicit ParsedFindEqualityPayload(ConstDataRange cdr);
};
diff --git a/src/mongo/db/SConscript b/src/mongo/db/SConscript
index f109b4b282f..cc88766bd55 100644
--- a/src/mongo/db/SConscript
+++ b/src/mongo/db/SConscript
@@ -880,6 +880,10 @@ env.Library(
target='fle_crud',
source=[
'fle_crud.cpp',
+ 'query/fle/encrypted_predicate.cpp',
+ 'query/fle/equality_predicate.cpp',
+ 'query/fle/query_rewriter.cpp',
+ 'query/fle/range_predicate.cpp',
'query/fle/server_rewrite.cpp',
],
LIBDEPS=[
@@ -2515,7 +2519,9 @@ if wiredtiger:
'operation_id_test.cpp',
'operation_time_tracker_test.cpp',
'persistent_task_store_test.cpp',
- 'query/fle/server_rewrite_test.cpp',
+ 'query/fle/equality_predicate_test.cpp',
+ 'query/fle/range_predicate_test.cpp',
+ 'query/fle/query_rewriter_test.cpp',
'range_arithmetic_test.cpp',
'read_write_concern_defaults_test.cpp',
'read_write_concern_provenance_test.cpp',
diff --git a/src/mongo/db/query/canonical_query.cpp b/src/mongo/db/query/canonical_query.cpp
index bbd7c68a21b..c5bf0a5fbd4 100644
--- a/src/mongo/db/query/canonical_query.cpp
+++ b/src/mongo/db/query/canonical_query.cpp
@@ -42,7 +42,6 @@
#include "mongo/db/operation_context.h"
#include "mongo/db/query/canonical_query_encoder.h"
#include "mongo/db/query/collation/collator_factory_interface.h"
-#include "mongo/db/query/fle/server_rewrite.h"
#include "mongo/db/query/indexability.h"
#include "mongo/db/query/projection_parser.h"
#include "mongo/db/query/query_planner_common.h"
diff --git a/src/mongo/db/query/fle/encrypted_predicate.cpp b/src/mongo/db/query/fle/encrypted_predicate.cpp
new file mode 100644
index 00000000000..c23e1a2a5ae
--- /dev/null
+++ b/src/mongo/db/query/fle/encrypted_predicate.cpp
@@ -0,0 +1,65 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "encrypted_predicate.h"
+
+#include "mongo/crypto/fle_crypto.h"
+#include "mongo/db/query/fle/query_rewriter_interface.h"
+#include "mongo/logv2/log.h"
+
+#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kQuery
+
+namespace mongo {
+namespace fle {
+
+ExpressionToRewriteMap aggPredicateRewriteMap{};
+MatchTypeToRewriteMap matchPredicateRewriteMap{};
+
+void logTagsExceeded(const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) {
+ LOGV2_DEBUG(
+ 6672410, 2, "FLE Max tag limit hit during query rewrite", "__error__"_attr = ex.what());
+}
+
+BSONArray toBSONArray(std::vector<PrfBlock>&& vec) {
+ auto bab = BSONArrayBuilder();
+ for (auto& elt : vec) {
+ bab.appendBinData(elt.size(), BinDataType::BinDataGeneral, elt.data());
+ }
+ return bab.arr();
+}
+
+std::vector<Value> toValues(std::vector<PrfBlock>&& vec) {
+ std::vector<Value> output;
+ for (auto& elt : vec) {
+ output.push_back(Value(BSONBinData(elt.data(), elt.size(), BinDataType::BinDataGeneral)));
+ }
+ return output;
+}
+} // namespace fle
+} // namespace mongo
diff --git a/src/mongo/db/query/fle/encrypted_predicate.h b/src/mongo/db/query/fle/encrypted_predicate.h
new file mode 100644
index 00000000000..b2797fd1eb9
--- /dev/null
+++ b/src/mongo/db/query/fle/encrypted_predicate.h
@@ -0,0 +1,261 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <functional>
+
+#include "mongo/base/init.h"
+#include "mongo/bson/bsonobjbuilder.h"
+#include "mongo/crypto/fle_crypto.h"
+#include "mongo/crypto/fle_field_schema_gen.h"
+#include "mongo/db/exec/document_value/document.h"
+#include "mongo/db/matcher/expression.h"
+#include "mongo/db/matcher/expression_leaf.h"
+#include "mongo/db/pipeline/expression.h"
+#include "mongo/db/query/fle/query_rewriter_interface.h"
+
+/**
+ * This file contains an abstract class that describes rewrites on agg Expressions and
+ * MatchExpressions for individual encrypted index types. Subclasses of this class represent
+ * concrete encrypted index types, like Equality and Range.
+ *
+ * This class is not responsible for traversing expression trees, but instead takes leaf
+ * expressions that it may replace. Tree traversal is handled by the QueryRewriter.
+ */
+
+namespace mongo {
+namespace fle {
+
+// Virtual functions can't be templated, so in order to write a function which can take in either a
+// BSONElement or a Value&, we need to create a variant type to use in function signatures.
+// std::reference_wrapper is necessary to avoid copying the Value because references alone cannot be
+// included in a variant. BSONElement can be passed by value because it is just a pointer into an
+// owning BSONObj.
+using BSONValue = stdx::variant<BSONElement, std::reference_wrapper<Value>>;
+
+/**
+ * Parse a find payload from either a BSONElement or a Value. All ParsedFindPayload types should
+ * have constructors for both BSONElements and Values, which will enable this function to work on
+ * both types.
+ */
+template <typename T>
+T parseFindPayload(BSONValue payload) {
+ return stdx::visit(OverloadedVisitor{[&](BSONElement payload) { return T(payload); },
+ [&](Value payload) { return T(payload); }},
+ payload);
+}
+
+/**
+ * Convert a vector of PrfBlocks to a BSONArray for use in MatchExpression tag generation.
+ */
+BSONArray toBSONArray(std::vector<PrfBlock>&& vec);
+
+/**
+ * Convert a vector of PrfBlocks to a vector of Values for use in Agg tag generation.
+ */
+std::vector<Value> toValues(std::vector<PrfBlock>&& vec);
+
+void logTagsExceeded(const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex);
+/**
+ * Interface for implementing a server rewrite for an encrypted index. Each type of predicate
+ * should have its own subclass that implements the virtual methods in this class.
+ */
+class EncryptedPredicate {
+public:
+ EncryptedPredicate(const QueryRewriterInterface* rewriter) : _rewriter(rewriter) {}
+
+ /**
+ * Rewrite a terminal expression for this encrypted predicate. If this function returns
+ * nullptr, then no rewrite needs to be done. Rewrites generally transform predicates from one
+ * kind of expression to another, either a $in or an $_internalFle* runtime expression, and so
+ * this function will allocate a new expression and return a unique_ptr to it.
+ */
+ template <typename T>
+ std::unique_ptr<T> rewrite(T* expr) const {
+ auto mode = _rewriter->getEncryptedCollScanMode();
+ if (mode != EncryptedCollScanMode::kForceAlways) {
+ try {
+ return rewriteToTagDisjunction(expr);
+ } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) {
+ // LOGV2 can't be called from a header file, so this call is factored out to a
+ // function defined in the cpp file.
+ logTagsExceeded(ex);
+ if (mode != EncryptedCollScanMode::kUseIfNeeded) {
+ throw;
+ }
+ }
+ }
+ return rewriteToRuntimeComparison(expr);
+ }
+
+protected:
+ /**
+ * Check if the passed-in payload is a FLE2 find payload for the right encrypted index type.
+ */
+ virtual bool isPayload(const BSONElement& elt) const {
+ if (!elt.isBinData(BinDataType::Encrypt)) {
+ return false;
+ }
+ int dataLen;
+ auto data = elt.binData(dataLen);
+
+ // Check that the BinData's subtype is 6, and its sub-subtype is equal to this predicate's
+ // encryptedBinDataType.
+ return dataLen >= 1 && data[0] == static_cast<uint8_t>(encryptedBinDataType());
+ }
+
+ /**
+ * Check if the passed-in payload is a FLE2 find payload for the right encrypted index type.
+ */
+ virtual bool isPayload(const Value& v) const {
+ if (v.getType() != BSONType::BinData) {
+ return false;
+ }
+
+ auto binData = v.getBinData();
+ // Check that the BinData's subtype is 6, and its sub-subtype is equal to this predicate's
+ // encryptedBinDataType.
+ return binData.type == BinDataType::Encrypt && binData.length >= 1 &&
+ static_cast<uint8_t>(encryptedBinDataType()) ==
+ static_cast<const uint8_t*>(binData.data)[0];
+ }
+ /**
+ * Generate tags from a FLE2 Find Payload. This function takes in a variant of BSONElement and
+ * Value so that it can be used in both the MatchExpression and Aggregation contexts. Virtual
+ * functions can't also be templated, which is why we need the runtime dispatch on the variant.
+ */
+ virtual std::vector<PrfBlock> generateTags(BSONValue payload) const = 0;
+
+ /**
+ * Rewrite to a tag disjunction on the __safeContent__ field.
+ */
+ virtual std::unique_ptr<MatchExpression> rewriteToTagDisjunction(
+ MatchExpression* expr) const = 0;
+ /**
+ * Rewrite to a tag disjunction on the __safeContent__ field.
+ */
+ virtual std::unique_ptr<Expression> rewriteToTagDisjunction(Expression* expr) const = 0;
+
+ /**
+ * Rewrite to an expression which can generate tags at runtime during an encrypted collscan.
+ */
+ virtual std::unique_ptr<MatchExpression> rewriteToRuntimeComparison(
+ MatchExpression* expr) const = 0;
+ /**
+ * Rewrite to an expression which can generate tags at runtime during an encrypted collscan.
+ */
+ virtual std::unique_ptr<Expression> rewriteToRuntimeComparison(Expression* expr) const = 0;
+
+ const QueryRewriterInterface* _rewriter;
+
+private:
+ /**
+ * Sub-subtype associated with the find payload for this encrypted predicate.
+ */
+ virtual EncryptedBinDataType encryptedBinDataType() const = 0;
+};
+
+/**
+ * Encrypted predicate rewrites are registered at startup time using MONGO_INITIALIZER blocks.
+ * MatchExpression rewrites are keyed on the MatchExpressionType enum, and Agg Expression rewrites
+ * are keyed on the dynamic type for the Expression subclass.
+ */
+
+using ExpressionToRewriteMap = stdx::unordered_map<
+ std::type_index,
+ std::function<std::unique_ptr<Expression>(QueryRewriterInterface*, Expression*)>>;
+
+extern ExpressionToRewriteMap aggPredicateRewriteMap;
+
+using MatchTypeToRewriteMap = stdx::unordered_map<
+ MatchExpression::MatchType,
+ std::function<std::unique_ptr<MatchExpression>(QueryRewriterInterface*, MatchExpression*)>>;
+
+extern MatchTypeToRewriteMap matchPredicateRewriteMap;
+
+/**
+ * Register an agg rewrite if a condition is true at startup time.
+ */
+#define REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_GUARDED(className, rewriteClass, isEnabledExpr) \
+ MONGO_INITIALIZER(encryptedAggPredicateRewriteFor_##className)(InitializerContext*) { \
+ \
+ invariant(aggPredicateRewriteMap.find(typeid(className)) == aggPredicateRewriteMap.end()); \
+ aggPredicateRewriteMap[typeid(className)] = [&](auto* rewriter, auto* expr) { \
+ if (isEnabledExpr) { \
+ return rewriteClass{rewriter}.rewrite(expr); \
+ } else { \
+ return std::unique_ptr<Expression>(nullptr); \
+ } \
+ }; \
+ }
+
+/**
+ * Register an agg rewrite unconditionally.
+ */
+#define REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE(matchType, rewriteClass) \
+ REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_GUARDED(matchType, rewriteClass, true)
+
+/**
+ * Register an agg rewrite behind a feature flag.
+ */
+#define REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_WITH_FLAG(matchType, rewriteClass, featureFlag) \
+ REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_GUARDED( \
+ matchType, rewriteClass, featureFlag.isEnabled(serverGlobalParams.featureCompatibility))
+
+/**
+ * Register a MatchExpression rewrite if a condition is true at startup time.
+ */
+#define REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_GUARDED(matchType, rewriteClass, isEnabledExpr) \
+ MONGO_INITIALIZER(encryptedMatchPredicateRewriteFor_##matchType)(InitializerContext*) { \
+ \
+ invariant(matchPredicateRewriteMap.find(MatchExpression::matchType) == \
+ matchPredicateRewriteMap.end()); \
+ matchPredicateRewriteMap[MatchExpression::matchType] = [&](auto* rewriter, auto* expr) { \
+ if (isEnabledExpr) { \
+ return rewriteClass{rewriter}.rewrite(expr); \
+ } else { \
+ return std::unique_ptr<MatchExpression>(nullptr); \
+ } \
+ }; \
+ };
+/**
+ * Register a MatchExpression rewrite unconditionally.
+ */
+#define REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE(matchType, rewriteClass) \
+ REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_GUARDED(matchType, rewriteClass, true)
+
+/**
+ * Register a MatchExpression rewrite behind a feature flag.
+ */
+#define REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_WITH_FLAG(matchType, rewriteClass, featureFlag) \
+ REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_GUARDED( \
+ matchType, rewriteClass, featureFlag.isEnabled(serverGlobalParams.featureCompatibility))
+} // namespace fle
+} // namespace mongo
diff --git a/src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h b/src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h
new file mode 100644
index 00000000000..88a6c85983a
--- /dev/null
+++ b/src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h
@@ -0,0 +1,106 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/pipeline/expression_context_for_test.h"
+#include "mongo/db/query/fle/encrypted_predicate.h"
+#include "mongo/db/query/fle/equality_predicate.h"
+#include "mongo/db/query/fle/query_rewriter_interface.h"
+#include "mongo/db/query/fle/range_predicate.h"
+#include "mongo/unittest/unittest.h"
+#include "mongo/util/overloaded_visitor.h"
+
+namespace mongo::fle {
+using TagMap = std::map<std::pair<StringData, int>, std::vector<PrfBlock>>;
+
+/*
+ * The MockServerRewrite allows unit testing individual predicate rewrites without going through the
+ * real server rewrite that traverses full expression trees.
+ */
+class MockServerRewrite : public QueryRewriterInterface {
+public:
+ MockServerRewrite() : _expCtx((new ExpressionContextForTest())) {}
+ const FLEStateCollectionReader* getEscReader() const override {
+ return nullptr;
+ }
+ const FLEStateCollectionReader* getEccReader() const override {
+ return nullptr;
+ }
+ EncryptedCollScanMode getEncryptedCollScanMode() const override {
+ return _mode;
+ };
+ ExpressionContext* getExpressionContext() const {
+ return _expCtx.get();
+ }
+
+ void setForceEncryptedCollScanForTest() {
+ _mode = EncryptedCollScanMode::kForceAlways;
+ }
+
+private:
+ boost::intrusive_ptr<ExpressionContextForTest> _expCtx;
+ EncryptedCollScanMode _mode{EncryptedCollScanMode::kUseIfNeeded};
+};
+
+class EncryptedPredicateRewriteTest : public unittest::Test {
+public:
+ EncryptedPredicateRewriteTest() {}
+
+ void setUp() override {}
+
+ void tearDown() override {}
+
+ static std::unique_ptr<MatchExpression> makeInExpr(StringData fieldname,
+ BSONArray disjunctions) {
+ auto inExpr = std::make_unique<InMatchExpression>(fieldname);
+ std::vector<BSONElement> elems;
+ disjunctions.elems(elems);
+ uassertStatusOK(inExpr->setEqualities(elems));
+ inExpr->setBackingBSON(std::move(disjunctions));
+ return inExpr;
+ }
+
+ /*
+ * Assertion helper for tag disjunction rewrite.
+ */
+ void assertRewriteToTags(const EncryptedPredicate& pred,
+ MatchExpression* input,
+ BSONArray expectedTags) {
+
+ auto actual = pred.rewrite(input);
+ auto expected = makeInExpr(kSafeContent, expectedTags);
+ ASSERT_BSONOBJ_EQ(actual->serialize(),
+ static_cast<MatchExpression*>(expected.get())->serialize());
+ }
+
+protected:
+ MockServerRewrite _mock{};
+};
+} // namespace mongo::fle
diff --git a/src/mongo/db/query/fle/equality_predicate.cpp b/src/mongo/db/query/fle/equality_predicate.cpp
new file mode 100644
index 00000000000..fc46bcbfe3a
--- /dev/null
+++ b/src/mongo/db/query/fle/equality_predicate.cpp
@@ -0,0 +1,402 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "equality_predicate.h"
+
+#include "mongo/crypto/fle_crypto.h"
+#include "mongo/crypto/fle_tags.h"
+#include "mongo/db/matcher/expression_expr.h"
+#include "mongo/db/matcher/expression_leaf.h"
+#include "mongo/db/pipeline/expression.h"
+#include "mongo/db/query/fle/encrypted_predicate.h"
+#include "mongo/util/overloaded_visitor.h"
+
+namespace mongo::fle {
+
+REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE(EQ, EqualityPredicate);
+REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE(MATCH_IN, EqualityPredicate);
+REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE(ExpressionCompare, EqualityPredicate);
+REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE(ExpressionIn, EqualityPredicate);
+
+std::vector<PrfBlock> EqualityPredicate::generateTags(BSONValue payload) const {
+ ParsedFindEqualityPayload tokens = parseFindPayload<ParsedFindEqualityPayload>(payload);
+ return readTags(*_rewriter->getEscReader(),
+ *_rewriter->getEccReader(),
+ tokens.escToken,
+ tokens.eccToken,
+ tokens.edcToken,
+ tokens.maxCounter);
+}
+
+std::unique_ptr<MatchExpression> EqualityPredicate::rewriteToTagDisjunction(
+ MatchExpression* expr) const {
+ switch (expr->matchType()) {
+ case MatchExpression::EQ: {
+ auto eqExpr = static_cast<EqualityMatchExpression*>(expr);
+ auto payload = eqExpr->getData();
+ if (!isPayload(payload)) {
+ return nullptr;
+ }
+ auto obj = toBSONArray(generateTags(payload));
+
+ auto tags = std::vector<BSONElement>();
+ obj.elems(tags);
+
+ auto inExpr = std::make_unique<InMatchExpression>(kSafeContent);
+ inExpr->setBackingBSON(std::move(obj));
+ auto status = inExpr->setEqualities(std::move(tags));
+ uassertStatusOK(status);
+ return inExpr;
+ }
+ case MatchExpression::MATCH_IN: {
+ auto inExpr = static_cast<InMatchExpression*>(expr);
+ size_t numFFPs = 0;
+ for (auto& eq : inExpr->getEqualities()) {
+ if (isPayload(eq)) {
+ ++numFFPs;
+ }
+ }
+ if (numFFPs == 0) {
+ return nullptr;
+ }
+ // All elements in an encrypted $in expression should be FFPs.
+ uassert(6329400,
+ "If any elements in a $in expression are encrypted, then all elements should "
+ "be encrypted.",
+ numFFPs == inExpr->getEqualities().size());
+ auto backingBSONBuilder = BSONArrayBuilder();
+
+ for (auto& eq : inExpr->getEqualities()) {
+ auto obj = toBSONArray(generateTags(eq));
+ for (auto&& elt : obj) {
+ backingBSONBuilder.append(elt);
+ }
+ }
+
+ auto backingBSON = backingBSONBuilder.arr();
+ auto allTags = std::vector<BSONElement>();
+ backingBSON.elems(allTags);
+
+ auto newExpr = std::make_unique<InMatchExpression>(kSafeContent);
+ newExpr->setBackingBSON(std::move(backingBSON));
+ auto status = newExpr->setEqualities(std::move(allTags));
+ uassertStatusOK(status);
+
+ return newExpr;
+ }
+ default:
+ MONGO_UNREACHABLE_TASSERT(6911300);
+ }
+ MONGO_UNREACHABLE_TASSERT(6911301);
+};
+
+namespace {
+template <typename PayloadT>
+boost::intrusive_ptr<ExpressionInternalFLEEqual> generateFleEqualMatch(StringData path,
+ const PayloadT& ffp,
+ ExpressionContext* expCtx) {
+ // Generate { $_internalFleEq: { field: "$field_name", server: f_3, counter: cm, edc: k_EDC] }
+ // }
+ auto tokens = ParsedFindEqualityPayload(ffp);
+
+ uassert(6672401,
+ "Missing required field server encryption token in find payload",
+ tokens.serverToken.has_value());
+
+ return make_intrusive<ExpressionInternalFLEEqual>(
+ expCtx,
+ ExpressionFieldPath::createPathFromString(
+ expCtx, path.toString(), expCtx->variablesParseState),
+ tokens.serverToken.value().data,
+ tokens.maxCounter.value_or(0LL),
+ tokens.edcToken.data);
+}
+
+
+template <typename PayloadT>
+std::unique_ptr<ExpressionInternalFLEEqual> generateFleEqualMatchUnique(StringData path,
+ const PayloadT& ffp,
+ ExpressionContext* expCtx) {
+ // Generate { $_internalFleEq: { field: "$field_name", server: f_3, counter: cm, edc: k_EDC] } }
+ auto tokens = ParsedFindEqualityPayload(ffp);
+
+ uassert(6672419,
+ "Missing required field server encryption token in find payload",
+ tokens.serverToken.has_value());
+
+ return std::make_unique<ExpressionInternalFLEEqual>(
+ expCtx,
+ ExpressionFieldPath::createPathFromString(
+ expCtx, path.toString(), expCtx->variablesParseState),
+ tokens.serverToken.value().data,
+ tokens.maxCounter.value_or(0LL),
+ tokens.edcToken.data);
+}
+
+std::unique_ptr<MatchExpression> generateFleEqualMatchAndExpr(StringData path,
+ const BSONElement ffp,
+ ExpressionContext* expCtx) {
+ auto fleEqualMatch = generateFleEqualMatch(path, ffp, expCtx);
+
+ return std::make_unique<ExprMatchExpression>(fleEqualMatch, expCtx);
+}
+} // namespace
+
+std::unique_ptr<MatchExpression> EqualityPredicate::rewriteToRuntimeComparison(
+ MatchExpression* expr) const {
+ switch (expr->matchType()) {
+ case MatchExpression::EQ: {
+ auto eqExpr = static_cast<EqualityMatchExpression*>(expr);
+ auto payload = eqExpr->getData();
+ if (!isPayload(payload)) {
+ return nullptr;
+ }
+ return generateFleEqualMatchAndExpr(
+ eqExpr->path(), payload, _rewriter->getExpressionContext());
+ }
+ case MatchExpression::MATCH_IN: {
+ auto inExpr = static_cast<InMatchExpression*>(expr);
+ size_t numFFPs = 0;
+ for (auto& eq : inExpr->getEqualities()) {
+ if (isPayload(eq)) {
+ ++numFFPs;
+ }
+ }
+ if (numFFPs == 0) {
+ return nullptr;
+ }
+ uassert(6911300,
+ "If any elements in a $in expression are encrypted, then all elements should "
+ "be encrypted.",
+ numFFPs == inExpr->getEqualities().size());
+ std::vector<std::unique_ptr<MatchExpression>> matches;
+ matches.reserve(numFFPs);
+
+ for (auto& eq : inExpr->getEqualities()) {
+ auto exprMatch = generateFleEqualMatchAndExpr(
+ expr->path(), eq, _rewriter->getExpressionContext());
+ matches.push_back(std::move(exprMatch));
+ }
+
+ auto orExpr = std::make_unique<OrMatchExpression>(std::move(matches));
+ return orExpr;
+ }
+ default:
+ MONGO_UNREACHABLE;
+ }
+ MONGO_UNREACHABLE;
+}
+
+/*
+ * Helper function for code shared between tag disjunction and runtime evaluation for the equality
+ * case.
+ */
+boost::optional<std::pair<ExpressionFieldPath*, ExpressionConstant*>>
+EqualityPredicate::extractDetailsFromComparison(ExpressionCompare* expr) const {
+ auto equalitiesList = expr->getChildren();
+
+ auto leftConstant = dynamic_cast<ExpressionConstant*>(equalitiesList[0].get());
+ auto rightConstant = dynamic_cast<ExpressionConstant*>(equalitiesList[1].get());
+
+ bool isLeftFFP = leftConstant && isPayload(leftConstant->getValue());
+ bool isRightFFP = rightConstant && isPayload(rightConstant->getValue());
+
+ uassert(6334100,
+ "Cannot compare two encrypted constants to each other",
+ !(isLeftFFP && isRightFFP));
+
+ // No FLE Find Payload
+ if (!isLeftFFP && !isRightFFP) {
+ return boost::none;
+ }
+
+ auto leftFieldPath = dynamic_cast<ExpressionFieldPath*>(equalitiesList[0].get());
+ auto rightFieldPath = dynamic_cast<ExpressionFieldPath*>(equalitiesList[1].get());
+
+ uassert(6672413,
+ "Queryable Encryption only supports comparisons between a field path and a constant",
+ leftFieldPath || rightFieldPath);
+ auto fieldPath = leftFieldPath ? leftFieldPath : rightFieldPath;
+ auto constChild = isLeftFFP ? leftConstant : rightConstant;
+ return {{fieldPath, constChild}};
+}
+
+/**
+ * Perform validation on $in add expressions, and return a pre-processed fieldpath expression for
+ * use by the rewrite. This factors out common validation code for the runtime and tag rewrite
+ * cases.
+ */
+boost::optional<const ExpressionFieldPath*> EqualityPredicate::validateIn(
+ ExpressionIn* inExpr, ExpressionArray* inList) const {
+ auto leftExpr = inExpr->getOperandList()[0].get();
+ auto& equalitiesList = inList->getChildren();
+ size_t numFFPs = 0;
+
+ for (auto& equality : equalitiesList) {
+ // For each expression representing a FleFindPayload...
+ if (auto constChild = dynamic_cast<ExpressionConstant*>(equality.get())) {
+ if (!isPayload(constChild->getValue())) {
+ continue;
+ }
+
+ numFFPs++;
+ }
+ }
+
+ // Finally, construct an $or of all of the $ins.
+ if (numFFPs == 0) {
+ return boost::none;
+ }
+
+ uassert(6334102,
+ "If any elements in an comparison expression are encrypted, then all elements "
+ "should "
+ "be encrypted.",
+ numFFPs == equalitiesList.size());
+
+ auto leftFieldPath = dynamic_cast<const ExpressionFieldPath*>(leftExpr);
+ uassert(6672417,
+ "$in is only supported with Queryable Encryption when the first argument is a "
+ "field path",
+ leftFieldPath != nullptr);
+ return leftFieldPath;
+}
+
+std::unique_ptr<Expression> EqualityPredicate::rewriteToTagDisjunction(Expression* expr) const {
+ if (auto eqExpr = dynamic_cast<ExpressionCompare*>(expr); eqExpr) {
+ if (eqExpr->getOp() != ExpressionCompare::EQ && eqExpr->getOp() != ExpressionCompare::NE) {
+ return nullptr;
+ }
+ auto details = extractDetailsFromComparison(eqExpr);
+ if (!details) {
+ return nullptr;
+ }
+ auto [_, constChild] = details.value();
+
+ std::vector<boost::intrusive_ptr<Expression>> orListElems;
+ auto payload = constChild->getValue();
+ auto tags = toValues(generateTags(std::ref(payload)));
+ for (auto&& tagElt : tags) {
+ // ... and for each tag, construct expression {$in: [tag,
+ // "$__safeContent__"]}.
+ std::vector<boost::intrusive_ptr<Expression>> inVec{
+ ExpressionConstant::create(_rewriter->getExpressionContext(), tagElt),
+ ExpressionFieldPath::createPathFromString(
+ _rewriter->getExpressionContext(),
+ kSafeContent,
+ _rewriter->getExpressionContext()->variablesParseState)};
+ orListElems.push_back(
+ make_intrusive<ExpressionIn>(_rewriter->getExpressionContext(), std::move(inVec)));
+ }
+ auto disjunction = std::make_unique<ExpressionOr>(_rewriter->getExpressionContext(),
+ std::move(orListElems));
+ if (eqExpr->getOp() == ExpressionCompare::NE) {
+ std::vector<boost::intrusive_ptr<Expression>> notChild{disjunction.release()};
+ return std::make_unique<ExpressionNot>(_rewriter->getExpressionContext(),
+ std::move(notChild));
+ }
+ return disjunction;
+ } else if (auto inExpr = dynamic_cast<ExpressionIn*>(expr)) {
+ if (auto inList = dynamic_cast<ExpressionArray*>(inExpr->getOperandList()[1].get())) {
+ if (!validateIn(inExpr, inList)) {
+ return nullptr;
+ }
+ auto& equalitiesList = inList->getChildren();
+ std::vector<boost::intrusive_ptr<Expression>> orListElems;
+ auto expCtx = _rewriter->getExpressionContext();
+ for (auto& equality : equalitiesList) {
+ // For each expression representing a FleFindPayload...
+ if (auto constChild = dynamic_cast<ExpressionConstant*>(equality.get())) {
+ // ... rewrite the payload to a list of tags...
+ auto payload = constChild->getValue();
+ auto tags = toValues(generateTags(std::ref(payload)));
+ for (auto&& tagElt : tags) {
+ // ... and for each tag, construct expression {$in: [tag,
+ // "$__safeContent__"]}.
+ std::vector<boost::intrusive_ptr<Expression>> inVec{
+ ExpressionConstant::create(expCtx, tagElt),
+ ExpressionFieldPath::createPathFromString(
+ expCtx, kSafeContent, expCtx->variablesParseState)};
+ orListElems.push_back(
+ make_intrusive<ExpressionIn>(expCtx, std::move(inVec)));
+ }
+ }
+ }
+ return std::make_unique<ExpressionOr>(expCtx, std::move(orListElems));
+ }
+ return nullptr;
+ }
+ MONGO_UNREACHABLE_TASSERT(6911303);
+}
+
+std::unique_ptr<Expression> EqualityPredicate::rewriteToRuntimeComparison(Expression* expr) const {
+ if (auto eqExpr = dynamic_cast<ExpressionCompare*>(expr); eqExpr) {
+ if (eqExpr->getOp() != ExpressionCompare::EQ && eqExpr->getOp() != ExpressionCompare::NE) {
+ return nullptr;
+ }
+ auto details = extractDetailsFromComparison(eqExpr);
+ if (!details) {
+ return nullptr;
+ }
+ auto [fieldPath, constChild] = details.value();
+ auto fleEqualExpr =
+ generateFleEqualMatchUnique(fieldPath->getFieldPathWithoutCurrentPrefix().fullPath(),
+ constChild->getValue(),
+ _rewriter->getExpressionContext());
+ if (eqExpr->getOp() == ExpressionCompare::NE) {
+ std::vector<boost::intrusive_ptr<Expression>> notChild{fleEqualExpr.release()};
+ return std::make_unique<ExpressionNot>(_rewriter->getExpressionContext(),
+ std::move(notChild));
+ }
+ return fleEqualExpr;
+ } else if (auto inExpr = dynamic_cast<ExpressionIn*>(expr)) {
+ if (auto inList = dynamic_cast<ExpressionArray*>(inExpr->getOperandList()[1].get())) {
+ auto leftFieldPath = validateIn(inExpr, inList);
+ if (!leftFieldPath) {
+ return nullptr;
+ }
+ auto& equalitiesList = inList->getChildren();
+ std::vector<boost::intrusive_ptr<Expression>> orListElems;
+ for (auto& equality : equalitiesList) {
+ if (auto constChild = dynamic_cast<ExpressionConstant*>(equality.get())) {
+ auto fleEqExpr = generateFleEqualMatch(
+ leftFieldPath.value()->getFieldPathWithoutCurrentPrefix().fullPath(),
+ constChild->getValue(),
+ _rewriter->getExpressionContext());
+ orListElems.push_back(fleEqExpr);
+ }
+ }
+ return std::make_unique<ExpressionOr>(_rewriter->getExpressionContext(),
+ std::move(orListElems));
+ }
+ return nullptr;
+ }
+ MONGO_UNREACHABLE_TASSERT(6911304);
+}
+} // namespace mongo::fle
diff --git a/src/mongo/db/query/fle/equality_predicate.h b/src/mongo/db/query/fle/equality_predicate.h
new file mode 100644
index 00000000000..f7630c1d5aa
--- /dev/null
+++ b/src/mongo/db/query/fle/equality_predicate.h
@@ -0,0 +1,65 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/crypto/fle_crypto.h"
+#include "mongo/db/query/fle/encrypted_predicate.h"
+
+namespace mongo::fle {
+/**
+ * Server-side rewrite for the encrypted equality index. This rewrite expects either a $eq or $in
+ * expression.
+ */
+class EqualityPredicate : public EncryptedPredicate {
+public:
+ EqualityPredicate(const QueryRewriterInterface* rewriter) : EncryptedPredicate(rewriter) {}
+
+protected:
+ std::vector<PrfBlock> generateTags(BSONValue payload) const override;
+
+ std::unique_ptr<MatchExpression> rewriteToTagDisjunction(MatchExpression* expr) const override;
+ std::unique_ptr<Expression> rewriteToTagDisjunction(Expression* expr) const override;
+
+ std::unique_ptr<MatchExpression> rewriteToRuntimeComparison(
+ MatchExpression* expr) const override;
+ std::unique_ptr<Expression> rewriteToRuntimeComparison(Expression* expr) const override;
+
+private:
+ EncryptedBinDataType encryptedBinDataType() const override {
+ return EncryptedBinDataType::kFLE2FindEqualityPayload;
+ }
+
+ boost::optional<std::pair<ExpressionFieldPath*, ExpressionConstant*>>
+ extractDetailsFromComparison(ExpressionCompare* eqExpr) const;
+
+ boost::optional<const ExpressionFieldPath*> validateIn(ExpressionIn* inExpr,
+ ExpressionArray* inList) const;
+};
+} // namespace mongo::fle
diff --git a/src/mongo/db/query/fle/equality_predicate_test.cpp b/src/mongo/db/query/fle/equality_predicate_test.cpp
new file mode 100644
index 00000000000..5a8ee9f5e70
--- /dev/null
+++ b/src/mongo/db/query/fle/equality_predicate_test.cpp
@@ -0,0 +1,465 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/bson/bsonmisc.h"
+#include "mongo/bson/json.h"
+#include "mongo/crypto/fle_crypto.h"
+#include "mongo/db/exec/document_value/value.h"
+#include "mongo/db/matcher/expression_expr.h"
+#include "mongo/db/matcher/expression_leaf.h"
+#include "mongo/db/matcher/expression_tree.h"
+#include "mongo/db/matcher/expression_visitor.h"
+#include "mongo/db/pipeline/expression.h"
+#include "mongo/db/query/fle/encrypted_predicate_test_fixtures.h"
+#include "mongo/db/query/fle/equality_predicate.h"
+#include "mongo/idl/server_parameter_test_util.h"
+#include "mongo/unittest/unittest.h"
+
+namespace mongo::fle {
+namespace {
+class MockEqualityPredicate : public EqualityPredicate {
+public:
+ MockEqualityPredicate(const QueryRewriterInterface* rewriter) : EqualityPredicate(rewriter) {}
+ MockEqualityPredicate(const QueryRewriterInterface* rewriter,
+ TagMap tags,
+ std::set<StringData> encryptedFields)
+ : EqualityPredicate(rewriter), _tags(tags), _encryptedFields(encryptedFields) {}
+
+ void setEncryptedTags(std::pair<StringData, int> fieldvalue, std::vector<PrfBlock> tags) {
+ _encryptedFields.insert(fieldvalue.first);
+ _tags[fieldvalue] = tags;
+ }
+
+ void addEncryptedField(StringData field) {
+ _encryptedFields.insert(field);
+ }
+
+
+protected:
+ bool isPayload(const BSONElement& elt) const override {
+ return _encryptedFields.find(elt.fieldNameStringData()) != _encryptedFields.end();
+ }
+
+ bool isPayload(const Value& v) const override {
+ return true;
+ }
+
+ std::vector<PrfBlock> generateTags(BSONValue payload) const {
+ return stdx::visit(
+ OverloadedVisitor{
+ [&](BSONElement p) {
+ ASSERT(p.isNumber()); // Only accept numbers as mock FFPs.
+ ASSERT(_tags.find({p.fieldNameStringData(), p.Int()}) != _tags.end());
+ return _tags.find({p.fieldNameStringData(), p.Int()})->second;
+ },
+ [&](std::reference_wrapper<Value> v) { return std::vector<PrfBlock>{}; }},
+ payload);
+ }
+
+private:
+ TagMap _tags;
+ std::set<StringData> _encryptedFields;
+};
+
+class EqualityPredicateRewriteTest : public EncryptedPredicateRewriteTest {
+public:
+ EqualityPredicateRewriteTest() : _predicate(&_mock) {}
+
+protected:
+ MockEqualityPredicate _predicate;
+};
+
+TEST_F(EqualityPredicateRewriteTest, Equality_NoFFP) {
+ std::unique_ptr<MatchExpression> input =
+ std::make_unique<EqualityMatchExpression>("ssn", Value(5));
+ auto expected = EqualityMatchExpression("ssn", Value(5));
+
+ auto result = _predicate.rewrite(input.get());
+ ASSERT(result == nullptr);
+ ASSERT(input->equivalent(&expected));
+}
+
+TEST_F(EqualityPredicateRewriteTest, In_NoFFP) {
+ auto input = makeInExpr("name",
+ BSON_ARRAY("harry"
+ << "ron"
+ << "hermione"));
+ auto expected = makeInExpr("name",
+ BSON_ARRAY("harry"
+ << "ron"
+ << "hermione"));
+
+ auto result = _predicate.rewrite(input.get());
+ ASSERT(result == nullptr);
+ ASSERT(input->equivalent(expected.get()));
+}
+
+TEST_F(EqualityPredicateRewriteTest, Equality_Basic) {
+ auto input = EqualityMatchExpression("ssn", Value(5));
+ std::vector<PrfBlock> tags = {{1}, {2}, {3}};
+
+ _predicate.setEncryptedTags({"ssn", 5}, tags);
+
+ assertRewriteToTags(_predicate, &input, toBSONArray(std::move(tags)));
+}
+
+TEST_F(EqualityPredicateRewriteTest, In_Basic) {
+ auto input = makeInExpr("ssn", BSON_ARRAY(2 << 4 << 6));
+
+ _predicate.setEncryptedTags({"0", 2}, {{1}, {2}});
+ _predicate.setEncryptedTags({"1", 4}, {{5}, {3}});
+ _predicate.setEncryptedTags({"2", 6}, {{99}, {100}});
+
+ assertRewriteToTags(_predicate, input.get(), toBSONArray({{1}, {2}, {3}, {5}, {99}, {100}}));
+}
+
+TEST_F(EqualityPredicateRewriteTest, In_NotAllFFPs) {
+ auto input = makeInExpr("ssn", BSON_ARRAY(2 << 4 << 6));
+
+ _predicate.setEncryptedTags({"0", 2}, {{1}, {2}});
+ _predicate.setEncryptedTags({"1", 4}, {{5}, {3}});
+
+ ASSERT_THROWS_CODE(
+ assertRewriteToTags(_predicate, input.get(), toBSONArray({{1}, {2}, {3}, {5}})),
+ AssertionException,
+ 6329400);
+}
+
+template <typename T>
+std::vector<uint8_t> toEncryptedVector(EncryptedBinDataType dt, T t) {
+ BSONObj obj = t.toBSON();
+
+ std::vector<uint8_t> buf(obj.objsize() + 1);
+ buf[0] = static_cast<uint8_t>(dt);
+
+ std::copy(obj.objdata(), obj.objdata() + obj.objsize(), buf.data() + 1);
+
+ return buf;
+}
+
+template <typename T>
+void toEncryptedBinData(StringData field, EncryptedBinDataType dt, T t, BSONObjBuilder* builder) {
+ auto buf = toEncryptedVector(dt, t);
+
+ builder->appendBinData(field, buf.size(), BinDataType::Encrypt, buf.data());
+}
+
+constexpr auto kIndexKeyId = "12345678-1234-9876-1234-123456789012"_sd;
+constexpr auto kUserKeyId = "ABCDEFAB-1234-9876-1234-123456789012"_sd;
+static UUID indexKeyId = uassertStatusOK(UUID::parse(kIndexKeyId.toString()));
+static UUID userKeyId = uassertStatusOK(UUID::parse(kUserKeyId.toString()));
+
+std::vector<char> testValue = {0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19};
+std::vector<char> testValue2 = {0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29};
+
+const FLEIndexKey& getIndexKey() {
+ static std::string indexVec = hexblob::decode(
+ "7dbfebc619aa68a659f64b8e23ccd21644ac326cb74a26840c3d2420176c40ae088294d00ad6cae9684237b21b754cf503f085c25cd320bf035c3417416e1e6fe3d9219f79586582112740b2add88e1030d91926ae8afc13ee575cfb8bb965b7"_sd);
+ static FLEIndexKey indexKey(KeyMaterial(indexVec.begin(), indexVec.end()));
+ return indexKey;
+}
+
+const FLEUserKey& getUserKey() {
+ static std::string userVec = hexblob::decode(
+ "a7ddbc4c8be00d51f68d9d8e485f351c8edc8d2206b24d8e0e1816d005fbe520e489125047d647b0d8684bfbdbf09c304085ed086aba6c2b2b1677ccc91ced8847a733bf5e5682c84b3ee7969e4a5fe0e0c21e5e3ee190595a55f83147d8de2a"_sd);
+ static FLEUserKey userKey(KeyMaterial(userVec.begin(), userVec.end()));
+ return userKey;
+}
+
+
+BSONObj generateFFP(StringData path, int value) {
+ auto indexKey = getIndexKey();
+ FLEIndexKeyAndId indexKeyAndId(indexKey.data, indexKeyId);
+ auto userKey = getUserKey();
+ FLEUserKeyAndId userKeyAndId(userKey.data, indexKeyId);
+
+ BSONObj doc = BSON("value" << value);
+ auto element = doc.firstElement();
+ auto fpp = FLEClientCrypto::serializeFindPayload(indexKeyAndId, userKeyAndId, element, 0);
+
+ BSONObjBuilder builder;
+ toEncryptedBinData(path, EncryptedBinDataType::kFLE2FindEqualityPayload, fpp, &builder);
+ return builder.obj();
+}
+
+std::unique_ptr<MatchExpression> generateEqualityWithFFP(StringData path, int value) {
+ auto ffp = generateFFP(path, value);
+ return std::make_unique<EqualityMatchExpression>(path, ffp.firstElement());
+}
+
+std::unique_ptr<Expression> generateEqualityWithFFP(ExpressionContext* const expCtx,
+ StringData path,
+ int value) {
+ auto ffp = Value(generateFFP(path, value).firstElement());
+ auto ffpExpr = make_intrusive<ExpressionConstant>(expCtx, ffp);
+ auto fieldpath = ExpressionFieldPath::createPathFromString(
+ expCtx, path.toString(), expCtx->variablesParseState);
+ std::vector<boost::intrusive_ptr<Expression>> children = {std::move(fieldpath),
+ std::move(ffpExpr)};
+ return std::make_unique<ExpressionCompare>(expCtx, ExpressionCompare::EQ, std::move(children));
+}
+
+std::unique_ptr<MatchExpression> generateDisjunctionWithFFP(StringData path,
+ std::initializer_list<int> vals) {
+ BSONArrayBuilder bab;
+ for (auto& value : vals) {
+ bab.append(generateFFP(path, value).firstElement());
+ }
+ auto arr = bab.arr();
+ return EncryptedPredicateRewriteTest::makeInExpr(path, arr);
+}
+
+std::unique_ptr<Expression> generateDisjunctionWithFFP(ExpressionContext* const expCtx,
+ StringData path,
+ std::initializer_list<int> values) {
+ std::vector<boost::intrusive_ptr<Expression>> ffps;
+ for (auto& value : values) {
+ auto ffp = make_intrusive<ExpressionConstant>(
+ expCtx, Value(generateFFP(path, value).firstElement()));
+ ffps.emplace_back(std::move(ffp));
+ }
+ auto ffpArray = make_intrusive<ExpressionArray>(expCtx, std::move(ffps));
+ auto fieldpath = ExpressionFieldPath::createPathFromString(
+ expCtx, path.toString(), expCtx->variablesParseState);
+ std::vector<boost::intrusive_ptr<Expression>> children{std::move(fieldpath),
+ std::move(ffpArray)};
+ return std::make_unique<ExpressionIn>(expCtx, std::move(children));
+}
+
+class EqualityPredicateCollScanRewriteTest : public EncryptedPredicateRewriteTest {
+public:
+ EqualityPredicateCollScanRewriteTest() : _predicate(&_mock) {
+ _mock.setForceEncryptedCollScanForTest();
+ }
+
+protected:
+ MockEqualityPredicate _predicate;
+};
+
+TEST_F(EqualityPredicateCollScanRewriteTest, Eq_Match) {
+ auto input = generateEqualityWithFFP("ssn", 1);
+ _predicate.addEncryptedField("ssn");
+
+ auto result = _predicate.rewrite(input.get());
+
+ ASSERT(result);
+ ASSERT_EQ(result->matchType(), MatchExpression::EXPRESSION);
+ auto* expr = static_cast<ExprMatchExpression*>(result.get());
+ auto aggExpr = expr->getExpression();
+
+ auto expected = fromjson(R"({
+ "$_internalFleEq": {
+ "field": "$ssn",
+ "edc": {
+ "$binary": {
+ "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg",
+ "subType": "6"
+ }
+ },
+ "counter": {
+ "$numberLong": "0"
+ },
+ "server": {
+ "$binary": {
+ "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
+ "subType": "6"
+ }
+ }
+ }
+ })");
+ ASSERT_BSONOBJ_EQ(aggExpr->serialize(false).getDocument().toBson(), expected);
+}
+
+TEST_F(EqualityPredicateCollScanRewriteTest, Eq_Expr) {
+ auto expCtx = _mock.getExpressionContext();
+ auto input = generateEqualityWithFFP(expCtx, "ssn", 1);
+ _predicate.addEncryptedField("ssn");
+
+ auto result = _predicate.rewrite(input.get());
+
+ auto expected = fromjson(R"({
+ "$_internalFleEq": {
+ "field": "$ssn",
+ "edc": {
+ "$binary": {
+ "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg",
+ "subType": "6"
+ }
+ },
+ "counter": {
+ "$numberLong": "0"
+ },
+ "server": {
+ "$binary": {
+ "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
+ "subType": "6"
+ }
+ }
+ }
+ })");
+ ASSERT(result);
+ ASSERT_BSONOBJ_EQ(result->serialize(false).getDocument().toBson(), expected);
+}
+
+TEST_F(EqualityPredicateCollScanRewriteTest, In_Match) {
+ auto input = generateDisjunctionWithFFP("ssn", {1, 2, 3});
+ _predicate.addEncryptedField("0");
+ _predicate.addEncryptedField("1");
+ _predicate.addEncryptedField("2");
+
+ auto result = _predicate.rewrite(input.get());
+
+ ASSERT(result);
+ ASSERT_EQ(result->matchType(), MatchExpression::OR);
+ auto expected = fromjson(R"({
+ "$or": [
+ {
+ "$expr": {
+ "$_internalFleEq": {
+ "field": "$ssn",
+ "edc": {
+ "$binary": {
+ "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg",
+ "subType": "6"
+ }
+ },
+ "counter": {
+ "$numberLong": "0"
+ },
+ "server": {
+ "$binary": {
+ "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
+ "subType": "6"
+ }
+ }
+ }
+ }
+ },
+ {
+ "$expr": {
+ "$_internalFleEq": {
+ "field": "$ssn",
+ "edc": {
+ "$binary": {
+ "base64": "CLpCo6rNuYMVT+6n1HCX15MNrVYDNqf6udO46ayo43Sw",
+ "subType": "6"
+ }
+ },
+ "counter": {
+ "$numberLong": "0"
+ },
+ "server": {
+ "$binary": {
+ "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
+ "subType": "6"
+ }
+ }
+ }
+ }
+ },
+ {
+ "$expr": {
+ "$_internalFleEq": {
+ "field": "$ssn",
+ "edc": {
+ "$binary": {
+ "base64": "CPi44oCQHnNDeRqHsNLzbdCeHt2DK/wCly0g2dxU5fqN",
+ "subType": "6"
+ }
+ },
+ "counter": {
+ "$numberLong": "0"
+ },
+ "server": {
+ "$binary": {
+ "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
+ "subType": "6"
+ }
+ }
+ }
+ }
+ }
+ ]
+})");
+ ASSERT_BSONOBJ_EQ(result->serialize(), expected);
+}
+
+TEST_F(EqualityPredicateCollScanRewriteTest, In_Expr) {
+ auto input = generateDisjunctionWithFFP(_mock.getExpressionContext(), "ssn", {1, 1});
+ _predicate.addEncryptedField("0");
+ _predicate.addEncryptedField("1");
+ _predicate.addEncryptedField("2");
+
+ auto result = _predicate.rewrite(input.get());
+
+ ASSERT(result);
+ auto expected = fromjson(R"({ "$or" : [ {
+ "$_internalFleEq": {
+ "field": "$ssn",
+ "edc": {
+ "$binary": {
+ "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg",
+ "subType": "6"
+ }
+ },
+ "counter": {
+ "$numberLong": "0"
+ },
+ "server": {
+ "$binary": {
+ "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
+ "subType": "6"
+ }
+ }
+ }},
+ {
+ "$_internalFleEq": {
+ "field": "$ssn",
+ "edc": {
+ "$binary": {
+ "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg",
+ "subType": "6"
+ }
+ },
+ "counter": {
+ "$numberLong": "0"
+ },
+ "server": {
+ "$binary": {
+ "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
+ "subType": "6"
+ }
+ }
+ }}
+ ]})");
+ ASSERT_BSONOBJ_EQ(result->serialize(false).getDocument().toBson(), expected);
+}
+
+} // namespace
+} // namespace mongo::fle
diff --git a/src/mongo/db/query/fle/query_rewriter.cpp b/src/mongo/db/query/fle/query_rewriter.cpp
new file mode 100644
index 00000000000..441f436ec00
--- /dev/null
+++ b/src/mongo/db/query/fle/query_rewriter.cpp
@@ -0,0 +1,124 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "query_rewriter.h"
+
+#include "mongo/db/matcher/expression_expr.h"
+#include "mongo/db/matcher/expression_parser.h"
+
+namespace mongo::fle {
+
+class ExpressionRewriter {
+public:
+ ExpressionRewriter(QueryRewriter* queryRewriter, const ExpressionToRewriteMap& exprRewrites)
+ : queryRewriter(queryRewriter), exprRewrites(exprRewrites){};
+
+ std::unique_ptr<Expression> postVisit(Expression* exp) {
+ if (auto rewrite = exprRewrites.find(typeid(*exp)); rewrite != exprRewrites.end()) {
+ auto expr = rewrite->second(queryRewriter, exp);
+ if (expr != nullptr) {
+ didRewrite = true;
+ }
+ return expr;
+ }
+ return nullptr;
+ }
+
+ QueryRewriter* queryRewriter;
+ const ExpressionToRewriteMap& exprRewrites;
+ bool didRewrite = false;
+};
+
+std::unique_ptr<Expression> QueryRewriter::rewriteExpression(Expression* expression) {
+ tassert(6334104, "Expected an expression to rewrite but found none", expression);
+
+ ExpressionRewriter expressionRewriter{this, this->_exprRewrites};
+ auto res = expression_walker::walk<Expression>(expression, &expressionRewriter);
+ _rewroteLastExpression = expressionRewriter.didRewrite;
+ return res;
+}
+
+boost::optional<BSONObj> QueryRewriter::rewriteMatchExpression(const BSONObj& filter) {
+ auto expr = uassertStatusOK(MatchExpressionParser::parse(filter, _expCtx));
+
+ _rewroteLastExpression = false;
+ if (auto res = _rewrite(expr.get())) {
+ // The rewrite resulted in top-level changes. Serialize the new expression.
+ return res->serialize().getOwned();
+ } else if (_rewroteLastExpression) {
+ // The rewrite had no top-level changes, but nested expressions were rewritten. Serialize
+ // the parsed expression, which has in-place changes.
+ return expr->serialize().getOwned();
+ }
+
+ // No rewrites were done.
+ return boost::none;
+}
+
+std::unique_ptr<MatchExpression> QueryRewriter::_rewrite(MatchExpression* expr) {
+ switch (expr->matchType()) {
+ case MatchExpression::AND:
+ case MatchExpression::OR:
+ case MatchExpression::NOT:
+ case MatchExpression::NOR: {
+ for (size_t i = 0; i < expr->numChildren(); i++) {
+ auto child = expr->getChild(i);
+ if (auto newChild = _rewrite(child)) {
+ expr->resetChild(i, newChild.release());
+ }
+ }
+ return nullptr;
+ }
+ case MatchExpression::EXPRESSION: {
+ // Save the current value of _rewroteLastExpression, since rewriteExpression() may
+ // reset it to false and we may have already done a match expression rewrite.
+ auto didRewrite = _rewroteLastExpression;
+ auto rewritten =
+ rewriteExpression(static_cast<ExprMatchExpression*>(expr)->getExpression().get());
+ _rewroteLastExpression |= didRewrite;
+ if (rewritten) {
+ return std::make_unique<ExprMatchExpression>(rewritten.release(),
+ getExpressionContext());
+ }
+ return nullptr;
+ }
+ default: {
+ if (auto rewrite = _matchRewrites.find(expr->matchType());
+ rewrite != _matchRewrites.end()) {
+ auto rewritten = rewrite->second(this, expr);
+ if (rewritten != nullptr) {
+ _rewroteLastExpression = true;
+ }
+ return rewritten;
+ }
+ return nullptr;
+ }
+ }
+}
+} // namespace mongo::fle
diff --git a/src/mongo/db/query/fle/query_rewriter.h b/src/mongo/db/query/fle/query_rewriter.h
new file mode 100644
index 00000000000..cf8470a84ac
--- /dev/null
+++ b/src/mongo/db/query/fle/query_rewriter.h
@@ -0,0 +1,156 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/bson/bsonobj.h"
+#include "mongo/crypto/fle_crypto.h"
+#include "mongo/db/pipeline/expression_context.h"
+#include "mongo/db/query/fle/encrypted_predicate.h"
+#include "mongo/db/query/fle/query_rewriter_interface.h"
+
+namespace mongo::fle {
+/**
+ * Class which handles traversing expressions and rewriting predicates for FLE2.
+ *
+ * The QueryRewriter is responsible for traversing Agg Expressions and MatchExpression trees and
+ * calling individual rewrites (subclasses of EncryptedPredicate) that have been registered for each
+ * encrypted index type.
+ *
+ * The actual rewrites performed are stored in references to maps in the class. In non-test
+ * environments, these are global maps that register encrypted predicate rewrites that live in their
+ * own files.
+ */
+class QueryRewriter : public QueryRewriterInterface {
+public:
+ /**
+ * Takes in references to collection readers for the ESC and ECC that are used during tag
+ * computation.
+ */
+ QueryRewriter(boost::intrusive_ptr<ExpressionContext> expCtx,
+ const FLEStateCollectionReader& escReader,
+ const FLEStateCollectionReader& eccReader,
+ EncryptedCollScanModeAllowed mode = EncryptedCollScanModeAllowed::kAllow)
+ : _expCtx(expCtx),
+ _escReader(&escReader),
+ _eccReader(&eccReader),
+ _exprRewrites(aggPredicateRewriteMap),
+ _matchRewrites(matchPredicateRewriteMap) {
+
+ if (internalQueryFLEAlwaysUseEncryptedCollScanMode.load()) {
+ _mode = EncryptedCollScanMode::kForceAlways;
+ }
+
+ if (mode == EncryptedCollScanModeAllowed::kDisallow) {
+ _mode = EncryptedCollScanMode::kDisallow;
+ }
+
+ // This isn't the "real" query so we don't want to increment Expression
+ // counters here.
+ _expCtx->stopExpressionCounters();
+ }
+
+ /**
+ * Accepts a BSONObj holding a MatchExpression, and returns BSON representing the rewritten
+ * expression. Returns boost::none if no rewriting was done.
+ *
+ * Rewrites the match expression with FLE find payloads into a disjunction on the
+ * __safeContent__ array of tags.
+ *
+ * Will rewrite top-level $eq and $in expressions, as well as recursing through $and, $or, $not
+ * and $nor. Also handles similarly limited rewriting under $expr. All other MatchExpressions,
+ * notably $elemMatch, are ignored.
+ */
+ boost::optional<BSONObj> rewriteMatchExpression(const BSONObj& filter);
+
+ /**
+ * Accepts an expression to be re-written. Will rewrite top-level expressions including $eq and
+ * $in, as well as recursing through other expressions. Returns a new pointer if the top-level
+ * expression must be changed. A nullptr indicates that the modifications happened in-place.
+ */
+ std::unique_ptr<Expression> rewriteExpression(Expression* expression);
+
+ bool isForceEncryptedCollScan() const {
+ return _mode == EncryptedCollScanMode::kForceAlways;
+ }
+
+ void setForceEncryptedCollScanForTest() {
+ _mode = EncryptedCollScanMode::kForceAlways;
+ }
+
+ EncryptedCollScanMode getEncryptedCollScanMode() const override {
+ return _mode;
+ }
+
+ const FLEStateCollectionReader* getEscReader() const override {
+ return _escReader;
+ }
+
+ const FLEStateCollectionReader* getEccReader() const override {
+ return _eccReader;
+ }
+
+ ExpressionContext* getExpressionContext() const override {
+ return _expCtx.get();
+ }
+
+protected:
+ // This constructor should only be used for mocks in testing.
+ QueryRewriter(boost::intrusive_ptr<ExpressionContext> expCtx,
+ const ExpressionToRewriteMap& exprRewrites,
+ const MatchTypeToRewriteMap& matchRewrites)
+ : _expCtx(expCtx),
+ _escReader(nullptr),
+ _eccReader(nullptr),
+ _exprRewrites(exprRewrites),
+ _matchRewrites(matchRewrites) {}
+
+private:
+ /**
+ * A single rewrite step, called recursively on child expressions.
+ */
+ std::unique_ptr<MatchExpression> _rewrite(MatchExpression* me);
+
+ boost::intrusive_ptr<ExpressionContext> _expCtx;
+
+ // Holds a pointer so that these can be null for tests, even though the public constructor
+ // takes a const reference.
+ const FLEStateCollectionReader* _escReader;
+ const FLEStateCollectionReader* _eccReader;
+
+ // True if the last Expression or MatchExpression processed by this rewriter was rewritten.
+ bool _rewroteLastExpression = false;
+
+ // Controls how query rewriter rewrites the query
+ EncryptedCollScanMode _mode{EncryptedCollScanMode::kUseIfNeeded};
+
+ const ExpressionToRewriteMap& _exprRewrites;
+ const MatchTypeToRewriteMap& _matchRewrites;
+};
+} // namespace mongo::fle
diff --git a/src/mongo/db/query/fle/query_rewriter_interface.h b/src/mongo/db/query/fle/query_rewriter_interface.h
new file mode 100644
index 00000000000..60f54defeac
--- /dev/null
+++ b/src/mongo/db/query/fle/query_rewriter_interface.h
@@ -0,0 +1,71 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/crypto/fle_crypto.h"
+#include "mongo/db/pipeline/expression_context.h"
+
+namespace mongo {
+namespace fle {
+enum class EncryptedCollScanMode {
+ // Always use high cardinality filters, used by tests
+ kForceAlways,
+
+ // Use high cardinality mode if $in rewrites do not fit in the
+ // internalQueryFLERewriteMemoryLimit memory limit
+ kUseIfNeeded,
+
+ // Do not rewrite into high cardinality filter, throw exceptions instead
+ // Some contexts like upsert do not support $expr
+ kDisallow,
+};
+
+/**
+ * Low Selectivity rewrites use $expr which is not supported in all commands such as upserts.
+ */
+enum class EncryptedCollScanModeAllowed {
+ kAllow,
+ kDisallow,
+};
+
+/**
+ * Pure virtual class that allows encrypted predicate rewrites to be unit tested independently from
+ * the actual server rewrite.
+ */
+class QueryRewriterInterface {
+public:
+ virtual ~QueryRewriterInterface() {}
+ virtual const FLEStateCollectionReader* getEscReader() const = 0;
+ virtual const FLEStateCollectionReader* getEccReader() const = 0;
+ virtual EncryptedCollScanMode getEncryptedCollScanMode() const = 0;
+ virtual ExpressionContext* getExpressionContext() const = 0;
+};
+} // namespace fle
+} // namespace mongo
diff --git a/src/mongo/db/query/fle/query_rewriter_test.cpp b/src/mongo/db/query/fle/query_rewriter_test.cpp
new file mode 100644
index 00000000000..d779f6be1cc
--- /dev/null
+++ b/src/mongo/db/query/fle/query_rewriter_test.cpp
@@ -0,0 +1,209 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+
+#include <memory>
+
+#include "query_rewriter.h"
+
+#include "mongo/bson/bsonelement.h"
+#include "mongo/bson/bsonmisc.h"
+#include "mongo/bson/bsonobjbuilder.h"
+#include "mongo/bson/bsontypes.h"
+#include "mongo/db/matcher/expression_leaf.h"
+#include "mongo/db/pipeline/expression.h"
+#include "mongo/db/pipeline/expression_context_for_test.h"
+#include "mongo/db/query/fle/encrypted_predicate.h"
+#include "mongo/db/query/fle/encrypted_predicate_test_fixtures.h"
+#include "mongo/db/query/fle/query_rewriter_interface.h"
+#include "mongo/idl/server_parameter_test_util.h"
+#include "mongo/unittest/unittest.h"
+#include "mongo/util/overloaded_visitor.h"
+
+
+namespace mongo {
+namespace {
+
+/*
+ * The server rewrite itself is only responsible for traversing agg and MatchExpressions and
+ * executing whatever rewrites are registered. For unit testing, we will only verify that this
+ * traversal and rewrite is happening properly using a mock predicate rewriter that rewrites any
+ * equality with an object with the key `encrypt` to a $gt operator. Unit tests for the actual
+ * rewrites while mocking out tag generation are located in the test file for each encrypted
+ * predicate type. Full end-to-end testing happens in jstests. This organization ensures that we
+ * don't write redundant tests that each index type is properly rewritten under different
+ * circumstances, when the same exact code is called for each index type.
+ */
+
+class MockPredicateRewriter : public fle::EncryptedPredicate {
+public:
+ MockPredicateRewriter(const fle::QueryRewriterInterface* rewriter)
+ : EncryptedPredicate(rewriter) {}
+
+protected:
+ bool isPayload(const BSONElement& elt) const override {
+ if (!elt.isABSONObj()) {
+ return false;
+ }
+ return elt.Obj().firstElementFieldNameStringData() == "encrypt"_sd;
+ }
+ bool isPayload(const Value& v) const override {
+ if (!v.isObject()) {
+ return false;
+ }
+ return !v.getDocument().getField("encrypt").missing();
+ }
+
+ std::vector<PrfBlock> generateTags(fle::BSONValue payload) const override {
+ return {};
+ };
+
+ // Encrypted values will be rewritten from $eq to $gt. This is an arbitrary decision just to
+ // make sure that the rewrite works properly.
+ std::unique_ptr<MatchExpression> rewriteToTagDisjunction(MatchExpression* expr) const override {
+ invariant(expr->matchType() == MatchExpression::EQ);
+ auto eqMatch = static_cast<EqualityMatchExpression*>(expr);
+ if (!isPayload(eqMatch->getData())) {
+ return nullptr;
+ }
+ return std::make_unique<GTMatchExpression>(eqMatch->path(),
+ eqMatch->getData().Obj().firstElement());
+ };
+
+ std::unique_ptr<Expression> rewriteToTagDisjunction(Expression* expr) const override {
+ return nullptr;
+ }
+
+ std::unique_ptr<MatchExpression> rewriteToRuntimeComparison(
+ MatchExpression* expr) const override {
+ return nullptr;
+ }
+
+ std::unique_ptr<Expression> rewriteToRuntimeComparison(Expression* expr) const override {
+ return nullptr;
+ }
+
+private:
+ EncryptedBinDataType encryptedBinDataType() const override {
+ return EncryptedBinDataType::kPlaceholder; // return the 0 type. this isn't used anywhere.
+ }
+};
+
+void setMockRewriteMaps(fle::MatchTypeToRewriteMap& match,
+ fle::ExpressionToRewriteMap& agg,
+ fle::TagMap& tags,
+ std::set<StringData>& encryptedFields) {
+ match[MatchExpression::EQ] = [&](auto* rewriter, auto* expr) {
+ return MockPredicateRewriter{rewriter}.rewrite(expr);
+ };
+}
+
+class MockQueryRewriter : public fle::QueryRewriter {
+public:
+ MockQueryRewriter(fle::ExpressionToRewriteMap* exprRewrites,
+ fle::MatchTypeToRewriteMap* matchRewrites)
+ : fle::QueryRewriter(new ExpressionContextForTest(), *exprRewrites, *matchRewrites) {
+ setMockRewriteMaps(*matchRewrites, *exprRewrites, _tags, _encryptedFields);
+ }
+
+ BSONObj rewriteMatchExpressionForTest(const BSONObj& obj) {
+ auto res = rewriteMatchExpression(obj);
+ return res ? res.value() : obj;
+ }
+
+private:
+ fle::TagMap _tags;
+ std::set<StringData> _encryptedFields;
+};
+
+class FLEServerRewriteTest : public unittest::Test {
+public:
+ FLEServerRewriteTest() : _mock(nullptr) {}
+
+ void setUp() override {
+ _mock = std::make_unique<MockQueryRewriter>(&_agg, &_match);
+ }
+
+ void tearDown() override {}
+
+protected:
+ std::unique_ptr<MockQueryRewriter> _mock;
+ fle::ExpressionToRewriteMap _agg;
+ fle::MatchTypeToRewriteMap _match;
+};
+
+#define ASSERT_MATCH_EXPRESSION_REWRITE(input, expected) \
+ auto actual = _mock->rewriteMatchExpressionForTest(fromjson(input)); \
+ ASSERT_BSONOBJ_EQ(actual, fromjson(expected));
+
+#define TEST_FLE_REWRITE_MATCH(name, input, expected) \
+ TEST_F(FLEServerRewriteTest, name##_MatchExpression) { \
+ ASSERT_MATCH_EXPRESSION_REWRITE(input, expected); \
+ }
+
+TEST_FLE_REWRITE_MATCH(TopLevel_DottedPath,
+ "{'user.ssn': {$eq: {encrypt: 2}}}",
+ "{'user.ssn': {$gt: 2}}");
+
+TEST_FLE_REWRITE_MATCH(TopLevel_Conjunction_BothEncrypted,
+ "{$and: [{ssn: {encrypt: 2}}, {age: {encrypt: 4}}]}",
+ "{$and: [{ssn: {$gt: 2}}, {age: {$gt: 4}}]}");
+
+TEST_FLE_REWRITE_MATCH(TopLevel_Conjunction_PartlyEncrypted,
+ "{$and: [{ssn: {encrypt: 2}}, {age: {plain: 4}}]}",
+ "{$and: [{ssn: {$gt: 2}}, {age: {$eq: {plain: 4}}}]}");
+
+TEST_FLE_REWRITE_MATCH(TopLevel_Conjunction_PartlyEncryptedWithUnregisteredOperator,
+ "{$and: [{ssn: {encrypt: 2}}, {age: {$lt: {encrypt: 4}}}]}",
+ "{$and: [{ssn: {$gt: 2}}, {age: {$lt: {encrypt: 4}}}]}");
+
+TEST_FLE_REWRITE_MATCH(TopLevel_Encrypted_Nested_Unecrypted,
+ "{$and: [{ssn: {encrypt: 2}}, {user: {region: 'US'}}]}",
+ "{$and: [{ssn: {$gt: 2}}, {user: {$eq: {region: 'US'}}}]}");
+
+TEST_FLE_REWRITE_MATCH(TopLevel_Not,
+ "{ssn: {$not: {$eq: {encrypt: 5}}}}",
+ "{ssn: {$not: {$gt: 5}}}");
+
+TEST_FLE_REWRITE_MATCH(TopLevel_Neq, "{ssn: {$ne: {encrypt: 5}}}", "{ssn: {$not: {$gt: 5}}}}");
+
+TEST_FLE_REWRITE_MATCH(
+ NestedConjunction,
+ "{$and: [{$and: [{ssn: {encrypt: 2}}, {other: 'field'}]}, {otherSsn: {encrypt: 3}}]}",
+ "{$and: [{$and: [{ssn: {$gt: 2}}, {other: {$eq: 'field'}}]}, {otherSsn: {$gt: 3}}]}");
+
+TEST_FLE_REWRITE_MATCH(TopLevel_Nor,
+ "{$nor: [{ssn: {encrypt: 5}}, {other: {$eq: 'field'}}]}",
+ "{$nor: [{ssn: {$gt: 5}}, {other: {$eq: 'field'}}]}");
+
+TEST_FLE_REWRITE_MATCH(TopLevel_Or,
+ "{$or: [{ssn: {encrypt: 5}}, {other: {$eq: 'field'}}]}",
+ "{$or: [{ssn: {$gt: 5}}, {other: {$eq: 'field'}}]}");
+} // namespace
+} // namespace mongo
diff --git a/src/mongo/db/query/fle/range_predicate.cpp b/src/mongo/db/query/fle/range_predicate.cpp
new file mode 100644
index 00000000000..322d3025dc4
--- /dev/null
+++ b/src/mongo/db/query/fle/range_predicate.cpp
@@ -0,0 +1,84 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "range_predicate.h"
+
+#include "mongo/crypto/encryption_fields_gen.h"
+#include "mongo/db/matcher/expression_leaf.h"
+#include "mongo/db/query/fle/encrypted_predicate.h"
+
+namespace mongo::fle {
+
+REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_WITH_FLAG(ENCRYPTED_BETWEEN,
+ RangePredicate,
+ gFeatureFlagFLE2Range);
+
+// TODO: SERVER-67206 Generate tags for range payload.
+std::vector<PrfBlock> RangePredicate::generateTags(BSONValue payload) const {
+ return {};
+}
+
+std::unique_ptr<MatchExpression> RangePredicate::rewriteToTagDisjunction(
+ MatchExpression* expr) const {
+ invariant(expr->matchType() == MatchExpression::ENCRYPTED_BETWEEN);
+ auto betExpr = static_cast<EncryptedBetweenMatchExpression*>(expr);
+ auto ffp = betExpr->rhs();
+
+ if (!isPayload(ffp)) {
+ return nullptr;
+ }
+
+ auto obj = toBSONArray(generateTags(ffp));
+ auto tags = std::vector<BSONElement>();
+ obj.elems(tags);
+ auto inExpr = std::make_unique<InMatchExpression>(kSafeContent);
+ inExpr->setBackingBSON(std::move(obj));
+ auto status = inExpr->setEqualities(std::move(tags));
+ uassertStatusOK(status);
+ return inExpr;
+}
+
+// TODO: SERVER-67209 Server-side rewrite for agg expressions with $encryptedBetween.
+std::unique_ptr<Expression> RangePredicate::rewriteToTagDisjunction(Expression* expr) const {
+ return nullptr;
+}
+
+// TODO: SERVER-67267 Rewrite $encryptedBetween to $_internalFleBetween when number of tags exceeds
+// limit.
+std::unique_ptr<MatchExpression> RangePredicate::rewriteToRuntimeComparison(
+ MatchExpression* expr) const {
+ return nullptr;
+}
+
+// TODO: SERVER-67267 Rewrite $encryptedBetween to $_internalFleBetween when number of tags exceeds
+// limit.
+std::unique_ptr<Expression> RangePredicate::rewriteToRuntimeComparison(Expression* expr) const {
+ return nullptr;
+}
+} // namespace mongo::fle
diff --git a/src/mongo/db/query/fle/range_predicate.h b/src/mongo/db/query/fle/range_predicate.h
new file mode 100644
index 00000000000..a310cde16e2
--- /dev/null
+++ b/src/mongo/db/query/fle/range_predicate.h
@@ -0,0 +1,58 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/crypto/encryption_fields_gen.h"
+#include "mongo/db/query/fle/encrypted_predicate.h"
+
+namespace mongo::fle {
+/**
+ * Rewrite for the encrypted range index, which expects a $encryptedBetween expression.
+ */
+class RangePredicate : public EncryptedPredicate {
+public:
+ RangePredicate(const QueryRewriterInterface* rewriter) : EncryptedPredicate(rewriter) {}
+
+protected:
+ std::vector<PrfBlock> generateTags(BSONValue payload) const override;
+
+ std::unique_ptr<MatchExpression> rewriteToTagDisjunction(MatchExpression* expr) const override;
+ std::unique_ptr<Expression> rewriteToTagDisjunction(Expression* expr) const override;
+
+ std::unique_ptr<MatchExpression> rewriteToRuntimeComparison(
+ MatchExpression* expr) const override;
+ std::unique_ptr<Expression> rewriteToRuntimeComparison(Expression* expr) const override;
+
+private:
+ EncryptedBinDataType encryptedBinDataType() const override {
+ return EncryptedBinDataType::kFLE2FindRangePayload;
+ }
+};
+} // namespace mongo::fle
diff --git a/src/mongo/db/query/fle/range_predicate_test.cpp b/src/mongo/db/query/fle/range_predicate_test.cpp
new file mode 100644
index 00000000000..c240abde6f8
--- /dev/null
+++ b/src/mongo/db/query/fle/range_predicate_test.cpp
@@ -0,0 +1,131 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/crypto/fle_crypto.h"
+#include "mongo/db/matcher/expression_leaf.h"
+#include "mongo/db/query/fle/encrypted_predicate_test_fixtures.h"
+#include "mongo/db/query/fle/range_predicate.h"
+#include "mongo/idl/server_parameter_test_util.h"
+#include "mongo/unittest/unittest.h"
+
+namespace mongo::fle {
+namespace {
+class MockRangePredicate : public RangePredicate {
+public:
+ MockRangePredicate(const QueryRewriterInterface* rewriter) : RangePredicate(rewriter) {}
+
+ MockRangePredicate(const QueryRewriterInterface* rewriter,
+ TagMap tags,
+ std::set<StringData> encryptedFields)
+ : RangePredicate(rewriter), _tags(tags), _encryptedFields(encryptedFields) {}
+
+ void setEncryptedTags(std::pair<StringData, int> fieldvalue, std::vector<PrfBlock> tags) {
+ _encryptedFields.insert(fieldvalue.first);
+ _tags[fieldvalue] = tags;
+ }
+
+
+protected:
+ bool isPayload(const BSONElement& elt) const override {
+ return true;
+ }
+
+ bool isPayload(const Value& v) const override {
+ return true;
+ }
+
+ std::vector<PrfBlock> generateTags(BSONValue payload) const {
+ return stdx::visit(
+ OverloadedVisitor{
+ [&](BSONElement p) {
+ auto parsedPayload = p.Obj().firstElement();
+ auto fieldName = parsedPayload.fieldNameStringData();
+
+ std::vector<BSONElement> range;
+ auto payloadAsArray = parsedPayload.Array();
+ for (auto&& elt : payloadAsArray) {
+ range.push_back(elt);
+ }
+
+ std::vector<PrfBlock> allTags;
+ for (auto i = range[0].Number(); i <= range[1].Number(); i++) {
+ ASSERT(_tags.find({fieldName, i}) != _tags.end());
+ auto temp = _tags.find({fieldName, i})->second;
+ for (auto tag : temp) {
+ allTags.push_back(tag);
+ }
+ }
+ return allTags;
+ },
+ [&](std::reference_wrapper<Value> v) { return std::vector<PrfBlock>{}; }},
+ payload);
+ }
+
+private:
+ TagMap _tags;
+ std::set<StringData> _encryptedFields;
+};
+class RangePredicateRewriteTest : public EncryptedPredicateRewriteTest {
+public:
+ RangePredicateRewriteTest() : _predicate(&_mock) {}
+
+protected:
+ MockRangePredicate _predicate;
+};
+
+TEST_F(RangePredicateRewriteTest, BasicRangeRewrite) {
+ RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true);
+
+ int start = 1;
+ int end = 3;
+ StringData encField = "ssn";
+
+ std::vector<PrfBlock> tags1 = {{1}, {2}, {3}};
+ std::vector<PrfBlock> tags2 = {{4}, {5}, {6}};
+ std::vector<PrfBlock> tags3 = {{7}, {8}, {9}};
+
+ _predicate.setEncryptedTags({encField, 1}, tags1);
+ _predicate.setEncryptedTags({encField, 2}, tags2);
+ _predicate.setEncryptedTags({encField, 3}, tags3);
+
+ std::vector<PrfBlock> allTags = {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}};
+
+ // The field redundancy is so that we can pull out the field
+ // name in the mock version of rewriteRangePayloadAsTags.
+ BSONObj query =
+ BSON(encField << BSON("$encryptedBetween" << BSON(encField << BSON_ARRAY(start << end))));
+
+ auto inputExpr =
+ EncryptedBetweenMatchExpression(encField, query[encField]["$encryptedBetween"], nullptr);
+
+ assertRewriteToTags(_predicate, &inputExpr, toBSONArray(std::move(allTags)));
+}
+
+}; // namespace
+} // namespace mongo::fle
diff --git a/src/mongo/db/query/fle/server_rewrite.cpp b/src/mongo/db/query/fle/server_rewrite.cpp
index be98058a246..6575a2483d2 100644
--- a/src/mongo/db/query/fle/server_rewrite.cpp
+++ b/src/mongo/db/query/fle/server_rewrite.cpp
@@ -48,6 +48,8 @@
#include "mongo/db/pipeline/document_source_match.h"
#include "mongo/db/pipeline/expression.h"
#include "mongo/db/query/collation/collator_factory_interface.h"
+#include "mongo/db/query/fle/encrypted_predicate.h"
+#include "mongo/db/query/fle/query_rewriter.h"
#include "mongo/db/service_context.h"
#include "mongo/logv2/log.h"
#include "mongo/s/grid.h"
@@ -73,64 +75,13 @@ std::unique_ptr<CollatorInterface> collatorFromBSON(OperationContext* opCtx,
return collator;
}
namespace {
-
-template <typename PayloadT>
-boost::intrusive_ptr<ExpressionInternalFLEEqual> generateFleEqualMatch(StringData path,
- const PayloadT& ffp,
- ExpressionContext* expCtx) {
- // Generate { $_internalFleEq: { field: "$field_name", server: f_3, counter: cm, edc: k_EDC] }
- auto tokens = ParsedFindPayload(ffp);
-
- uassert(6672401,
- "Missing required field server encryption token in find payload",
- tokens.serverToken.has_value());
-
- return make_intrusive<ExpressionInternalFLEEqual>(
- expCtx,
- ExpressionFieldPath::createPathFromString(
- expCtx, path.toString(), expCtx->variablesParseState),
- tokens.serverToken.value().data,
- tokens.maxCounter.value_or(0LL),
- tokens.edcToken.data);
-}
-
-
-template <typename PayloadT>
-std::unique_ptr<ExpressionInternalFLEEqual> generateFleEqualMatchUnique(StringData path,
- const PayloadT& ffp,
- ExpressionContext* expCtx) {
- // Generate { $_internalFleEq: { field: "$field_name", server: f_3, counter: cm, edc: k_EDC] }
- auto tokens = ParsedFindPayload(ffp);
-
- uassert(6672419,
- "Missing required field server encryption token in find payload",
- tokens.serverToken.has_value());
-
- return std::make_unique<ExpressionInternalFLEEqual>(
- expCtx,
- ExpressionFieldPath::createPathFromString(
- expCtx, path.toString(), expCtx->variablesParseState),
- tokens.serverToken.value().data,
- tokens.maxCounter.value_or(0LL),
- tokens.edcToken.data);
-}
-
-std::unique_ptr<MatchExpression> generateFleEqualMatchAndExpr(StringData path,
- const BSONElement ffp,
- ExpressionContext* expCtx) {
- auto fleEqualMatch = generateFleEqualMatch(path, ffp, expCtx);
-
- return std::make_unique<ExprMatchExpression>(fleEqualMatch, expCtx);
-}
-
-
/**
* This section defines a mapping from DocumentSources to the dispatch function to appropriately
* handle FLE rewriting for that stage. This should be kept in line with code on the client-side
* that marks constants for encryption: we should handle all places where an implicitly-encrypted
* value may be for each stage, otherwise we may return non-sensical results.
*/
-static stdx::unordered_map<std::type_index, std::function<void(FLEQueryRewriter*, DocumentSource*)>>
+static stdx::unordered_map<std::type_index, std::function<void(QueryRewriter*, DocumentSource*)>>
stageRewriterMap;
#define REGISTER_DOCUMENT_SOURCE_FLE_REWRITER(className, rewriterFunc) \
@@ -142,19 +93,19 @@ static stdx::unordered_map<std::type_index, std::function<void(FLEQueryRewriter*
}; \
}
-void rewriteMatch(FLEQueryRewriter* rewriter, DocumentSourceMatch* source) {
+void rewriteMatch(QueryRewriter* rewriter, DocumentSourceMatch* source) {
if (auto rewritten = rewriter->rewriteMatchExpression(source->getQuery())) {
source->rebuild(rewritten.value());
}
}
-void rewriteGeoNear(FLEQueryRewriter* rewriter, DocumentSourceGeoNear* source) {
+void rewriteGeoNear(QueryRewriter* rewriter, DocumentSourceGeoNear* source) {
if (auto rewritten = rewriter->rewriteMatchExpression(source->getQuery())) {
source->setQuery(rewritten.value());
}
}
-void rewriteGraphLookUp(FLEQueryRewriter* rewriter, DocumentSourceGraphLookUp* source) {
+void rewriteGraphLookUp(QueryRewriter* rewriter, DocumentSourceGraphLookUp* source) {
if (auto filter = source->getAdditionalFilter()) {
if (auto rewritten = rewriter->rewriteMatchExpression(filter.value())) {
source->setAdditionalFilter(rewritten.value());
@@ -170,229 +121,6 @@ REGISTER_DOCUMENT_SOURCE_FLE_REWRITER(DocumentSourceMatch, rewriteMatch);
REGISTER_DOCUMENT_SOURCE_FLE_REWRITER(DocumentSourceGeoNear, rewriteGeoNear);
REGISTER_DOCUMENT_SOURCE_FLE_REWRITER(DocumentSourceGraphLookUp, rewriteGraphLookUp);
-class FLEExpressionRewriter {
-public:
- FLEExpressionRewriter(FLEQueryRewriter* queryRewriter) : queryRewriter(queryRewriter){};
-
- /**
- * Accepts a vector of expressions to be compared for equality to an encrypted field. For any
- * expression representing a constant encrypted value, computes the tags for the expression and
- * rewrites the comparison to a disjunction over __safeContent__. Returns an OR expression of
- * these disjunctions. If no rewrites were done, returns nullptr. Either all of the expressions
- * be constant FFPs or none of them should be.
- *
- * The final output will look like
- * {$or: [{$in: [tag0, "$__safeContent__"]}, {$in: [tag1, "$__safeContent__"]}, ...]}.
- */
- std::unique_ptr<Expression> rewriteInToEncryptedField(
- const Expression* leftExpr,
- const std::vector<boost::intrusive_ptr<Expression>>& equalitiesList) {
- size_t numFFPs = 0;
- std::vector<boost::intrusive_ptr<Expression>> orListElems;
-
- for (auto& equality : equalitiesList) {
- // For each expression representing a FleFindPayload...
- if (auto constChild = dynamic_cast<ExpressionConstant*>(equality.get())) {
- if (!queryRewriter->isFleFindPayload(
- constChild->getValue(), EncryptedBinDataType::kFLE2FindEqualityPayload)) {
- continue;
- }
-
- numFFPs++;
- }
- }
-
- // Finally, construct an $or of all of the $ins.
- if (numFFPs == 0) {
- return nullptr;
- }
-
- uassert(
- 6334102,
- "If any elements in an comparison expression are encrypted, then all elements should "
- "be encrypted.",
- numFFPs == equalitiesList.size());
-
- auto leftFieldPath = dynamic_cast<const ExpressionFieldPath*>(leftExpr);
- uassert(6672417,
- "$in is only supported with Queryable Encryption when the first argument is a "
- "field path",
- leftFieldPath != nullptr);
-
- if (!queryRewriter->isForceEncryptedCollScan()) {
- try {
- for (auto& equality : equalitiesList) {
- // For each expression representing a FleFindPayload...
- if (auto constChild = dynamic_cast<ExpressionConstant*>(equality.get())) {
- // ... rewrite the payload to a list of tags...
- auto tags =
- queryRewriter->rewriteEqualityPayloadAsTags(constChild->getValue());
- for (auto&& tagElt : tags) {
- // ... and for each tag, construct expression {$in: [tag,
- // "$__safeContent__"]}.
- std::vector<boost::intrusive_ptr<Expression>> inVec{
- ExpressionConstant::create(queryRewriter->expCtx(), tagElt),
- ExpressionFieldPath::createPathFromString(
- queryRewriter->expCtx(),
- kSafeContent,
- queryRewriter->expCtx()->variablesParseState)};
- orListElems.push_back(make_intrusive<ExpressionIn>(
- queryRewriter->expCtx(), std::move(inVec)));
- }
- }
- }
-
- didRewrite = true;
-
- return std::make_unique<ExpressionOr>(queryRewriter->expCtx(),
- std::move(orListElems));
- } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) {
- LOGV2_DEBUG(6672403,
- 2,
- "FLE Max tag limit hit during aggregation $in rewrite",
- "__error__"_attr = ex.what());
-
- if (queryRewriter->getEncryptedCollScanMode() !=
- FLEQueryRewriter::EncryptedCollScanMode::kUseIfNeeded) {
- throw;
- }
-
- // fall through
- }
- }
-
- for (auto& equality : equalitiesList) {
- if (auto constChild = dynamic_cast<ExpressionConstant*>(equality.get())) {
- auto fleEqExpr = generateFleEqualMatch(
- leftFieldPath->getFieldPathWithoutCurrentPrefix().fullPath(),
- constChild->getValue(),
- queryRewriter->expCtx());
- orListElems.push_back(fleEqExpr);
- }
- }
-
- didRewrite = true;
- return std::make_unique<ExpressionOr>(queryRewriter->expCtx(), std::move(orListElems));
- }
-
- // Rewrite a [$eq : [$fieldpath, constant]] or [$eq: [constant, $fieldpath]]
- // to _internalFleEq: {field: $fieldpath, edc: edcToken, counter: N, server: serverToken}
- std::unique_ptr<Expression> rewriteComparisonsToEncryptedField(
- const std::vector<boost::intrusive_ptr<Expression>>& equalitiesList) {
-
- auto leftConstant = dynamic_cast<ExpressionConstant*>(equalitiesList[0].get());
- auto rightConstant = dynamic_cast<ExpressionConstant*>(equalitiesList[1].get());
-
- bool isLeftFFP = leftConstant &&
- queryRewriter->isFleFindPayload(leftConstant->getValue(),
- EncryptedBinDataType::kFLE2FindEqualityPayload);
- bool isRightFFP = rightConstant &&
- queryRewriter->isFleFindPayload(rightConstant->getValue(),
- EncryptedBinDataType::kFLE2FindEqualityPayload);
-
- uassert(6334100,
- "Cannot compare two encrypted constants to each other",
- !(isLeftFFP && isRightFFP));
-
- // No FLE Find Payload
- if (!isLeftFFP && !isRightFFP) {
- return nullptr;
- }
-
- auto leftFieldPath = dynamic_cast<ExpressionFieldPath*>(equalitiesList[0].get());
- auto rightFieldPath = dynamic_cast<ExpressionFieldPath*>(equalitiesList[1].get());
-
- uassert(
- 6672413,
- "Queryable Encryption only supports comparisons between a field path and a constant",
- leftFieldPath || rightFieldPath);
-
- auto fieldPath = leftFieldPath ? leftFieldPath : rightFieldPath;
- auto constChild = isLeftFFP ? leftConstant : rightConstant;
-
- if (!queryRewriter->isForceEncryptedCollScan()) {
- try {
- std::vector<boost::intrusive_ptr<Expression>> orListElems;
-
- auto tags = queryRewriter->rewriteEqualityPayloadAsTags(constChild->getValue());
- for (auto&& tagElt : tags) {
- // ... and for each tag, construct expression {$in: [tag,
- // "$__safeContent__"]}.
- std::vector<boost::intrusive_ptr<Expression>> inVec{
- ExpressionConstant::create(queryRewriter->expCtx(), tagElt),
- ExpressionFieldPath::createPathFromString(
- queryRewriter->expCtx(),
- kSafeContent,
- queryRewriter->expCtx()->variablesParseState)};
- orListElems.push_back(
- make_intrusive<ExpressionIn>(queryRewriter->expCtx(), std::move(inVec)));
- }
-
- didRewrite = true;
- return std::make_unique<ExpressionOr>(queryRewriter->expCtx(),
- std::move(orListElems));
-
- } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) {
- LOGV2_DEBUG(6672409,
- 2,
- "FLE Max tag limit hit during query $in rewrite",
- "__error__"_attr = ex.what());
-
- if (queryRewriter->getEncryptedCollScanMode() !=
- FLEQueryRewriter::EncryptedCollScanMode::kUseIfNeeded) {
- throw;
- }
-
- // fall through
- }
- }
-
- auto fleEqExpr =
- generateFleEqualMatchUnique(fieldPath->getFieldPathWithoutCurrentPrefix().fullPath(),
- constChild->getValue(),
- queryRewriter->expCtx());
-
- didRewrite = true;
- return fleEqExpr;
- }
-
- std::unique_ptr<Expression> postVisit(Expression* exp) {
- if (auto inExpr = dynamic_cast<ExpressionIn*>(exp)) {
- // Rewrite an $in over an encrypted field to an $or. The first child of the $in can be
- // ignored when rewrites are done; there is no extra information in that child that
- // doesn't exist in the FFPs in the $in list.
- if (auto inList = dynamic_cast<ExpressionArray*>(inExpr->getOperandList()[1].get())) {
- return rewriteInToEncryptedField(inExpr->getOperandList()[0].get(),
- inList->getChildren());
- }
- } else if (auto eqExpr = dynamic_cast<ExpressionCompare*>(exp); eqExpr &&
- (eqExpr->getOp() == ExpressionCompare::EQ ||
- eqExpr->getOp() == ExpressionCompare::NE)) {
- // Rewrite an $eq comparing an encrypted field and an encrypted constant to an $or.
- auto newExpr = rewriteComparisonsToEncryptedField(eqExpr->getChildren());
-
- // Neither child is an encrypted constant, and no rewriting needs to be done.
- if (!newExpr) {
- return nullptr;
- }
-
- // Exactly one child was an encrypted constant. The other child can be ignored; there is
- // no extra information in that child that doesn't exist in the FFP.
- if (eqExpr->getOp() == ExpressionCompare::NE) {
- std::vector<boost::intrusive_ptr<Expression>> notChild{newExpr.release()};
- return std::make_unique<ExpressionNot>(queryRewriter->expCtx(),
- std::move(notChild));
- }
- return newExpr;
- }
-
- return nullptr;
- }
-
- FLEQueryRewriter* queryRewriter;
- bool didRewrite = false;
-};
-
BSONObj rewriteEncryptedFilter(const FLEStateCollectionReader& escReader,
const FLEStateCollectionReader& eccReader,
boost::intrusive_ptr<ExpressionContext> expCtx,
@@ -400,7 +128,7 @@ BSONObj rewriteEncryptedFilter(const FLEStateCollectionReader& escReader,
EncryptedCollScanModeAllowed mode) {
if (auto rewritten =
- FLEQueryRewriter(expCtx, escReader, eccReader, mode).rewriteMatchExpression(filter)) {
+ QueryRewriter(expCtx, escReader, eccReader, mode).rewriteMatchExpression(filter)) {
return rewritten.value();
}
@@ -437,7 +165,7 @@ public:
~PipelineRewrite(){};
void doRewrite(FLEStateCollectionReader& escReader, FLEStateCollectionReader& eccReader) final {
- auto rewriter = FLEQueryRewriter(expCtx, escReader, eccReader);
+ auto rewriter = QueryRewriter(expCtx, escReader, eccReader);
for (auto&& source : pipeline->getSources()) {
if (stageRewriterMap.find(typeid(*source)) != stageRewriterMap.end()) {
stageRewriterMap[typeid(*source)](&rewriter, source.get());
@@ -481,7 +209,6 @@ void doFLERewriteInTxn(OperationContext* opCtx,
std::shared_ptr<RewriteBase> sharedBlock,
GetTxnCallback getTxn) {
auto txn = getTxn(opCtx);
-
auto swCommitResult = txn->runNoThrow(
opCtx, [sharedBlock](const txn_api::TransactionClient& txnClient, auto txnExec) {
auto makeCollectionReader = [sharedBlock](FLEQueryInterface* queryImpl,
@@ -602,242 +329,4 @@ std::unique_ptr<Pipeline, PipelineDeleter> processPipeline(
return sharedBlock->getPipeline();
}
-
-std::unique_ptr<Expression> FLEQueryRewriter::rewriteExpression(Expression* expression) {
- tassert(6334104, "Expected an expression to rewrite but found none", expression);
-
- FLEExpressionRewriter expressionRewriter{this};
- auto res = expression_walker::walk<Expression>(expression, &expressionRewriter);
- _rewroteLastExpression = expressionRewriter.didRewrite;
- return res;
-}
-
-boost::optional<BSONObj> FLEQueryRewriter::rewriteMatchExpression(const BSONObj& filter) {
- auto expr = uassertStatusOK(MatchExpressionParser::parse(filter, _expCtx));
-
- _rewroteLastExpression = false;
- if (auto res = _rewrite(expr.get())) {
- // The rewrite resulted in top-level changes. Serialize the new expression.
- return res->serialize().getOwned();
- } else if (_rewroteLastExpression) {
- // The rewrite had no top-level changes, but nested expressions were rewritten. Serialize
- // the parsed expression, which has in-place changes.
- return expr->serialize().getOwned();
- }
-
- // No rewrites were done.
- return boost::none;
-}
-
-std::unique_ptr<MatchExpression> FLEQueryRewriter::_rewrite(MatchExpression* expr) {
- switch (expr->matchType()) {
- case MatchExpression::EQ:
- return rewriteEq(std::move(static_cast<const EqualityMatchExpression*>(expr)));
- case MatchExpression::MATCH_IN:
- return rewriteIn(std::move(static_cast<const InMatchExpression*>(expr)));
- case MatchExpression::AND:
- case MatchExpression::OR:
- case MatchExpression::NOT:
- case MatchExpression::NOR: {
- for (size_t i = 0; i < expr->numChildren(); i++) {
- auto child = expr->getChild(i);
- if (auto newChild = _rewrite(child)) {
- expr->resetChild(i, newChild.release());
- }
- }
- return nullptr;
- }
- case MatchExpression::ENCRYPTED_BETWEEN: {
- if (gFeatureFlagFLE2Range.isEnabled(serverGlobalParams.featureCompatibility)) {
- return rewriteRange(
- std::move(static_cast<const EncryptedBetweenMatchExpression*>(expr)));
- }
- return nullptr;
- }
- case MatchExpression::EXPRESSION: {
- // Save the current value of _rewroteLastExpression, since rewriteExpression() may
- // reset it to false and we may have already done a match expression rewrite.
- auto didRewrite = _rewroteLastExpression;
- auto rewritten =
- rewriteExpression(static_cast<ExprMatchExpression*>(expr)->getExpression().get());
- _rewroteLastExpression |= didRewrite;
- if (rewritten) {
- return std::make_unique<ExprMatchExpression>(rewritten.release(), expCtx());
- }
- [[fallthrough]];
- }
- default:
- return nullptr;
- }
-}
-
-BSONObj FLEQueryRewriter::rewriteEqualityPayloadAsTags(BSONElement fleFindPayload) const {
- auto tokens = ParsedFindPayload(fleFindPayload);
- auto tags = readTags(*_escReader,
- *_eccReader,
- tokens.escToken,
- tokens.eccToken,
- tokens.edcToken,
- tokens.maxCounter);
-
- auto bab = BSONArrayBuilder();
- for (auto tag : tags) {
- bab.appendBinData(tag.size(), BinDataType::BinDataGeneral, tag.data());
- }
-
- return bab.obj().getOwned();
-}
-
-std::vector<Value> FLEQueryRewriter::rewriteEqualityPayloadAsTags(Value fleFindPayload) const {
- auto tokens = ParsedFindPayload(fleFindPayload);
- auto tags = readTags(*_escReader,
- *_eccReader,
- tokens.escToken,
- tokens.eccToken,
- tokens.edcToken,
- tokens.maxCounter);
-
- std::vector<Value> tagVec;
- for (auto tag : tags) {
- tagVec.push_back(Value(BSONBinData(tag.data(), tag.size(), BinDataType::BinDataGeneral)));
- }
- return tagVec;
-}
-
-BSONObj FLEQueryRewriter::rewriteRangePayloadAsTags(BSONElement fleFindPayload) const {
- // TODO: SERVER-67206
- return BSONObj();
-}
-
-std::vector<Value> FLEQueryRewriter::rewriteRangePayloadAsTags(Value fleFindPayload) const {
- // TODO: SERVER-67206
- return std::vector({Value(0)});
-}
-
-std::unique_ptr<MatchExpression> FLEQueryRewriter::rewriteEq(const EqualityMatchExpression* expr) {
- auto ffp = expr->getData();
- if (!isFleFindPayload(ffp, EncryptedBinDataType::kFLE2FindEqualityPayload)) {
- return nullptr;
- }
-
- if (_mode != EncryptedCollScanMode::kForceAlways) {
- try {
- auto obj = rewriteEqualityPayloadAsTags(ffp);
-
- auto tags = std::vector<BSONElement>();
- obj.elems(tags);
-
- auto inExpr = std::make_unique<InMatchExpression>(kSafeContent);
- inExpr->setBackingBSON(std::move(obj));
- auto status = inExpr->setEqualities(std::move(tags));
- uassertStatusOK(status);
- _rewroteLastExpression = true;
- return inExpr;
- } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) {
- LOGV2_DEBUG(6672410,
- 2,
- "FLE Max tag limit hit during query $eq rewrite",
- "__error__"_attr = ex.what());
-
- if (_mode != EncryptedCollScanMode::kUseIfNeeded) {
- throw;
- }
-
- // fall through
- }
- }
-
- auto exprMatch = generateFleEqualMatchAndExpr(expr->path(), ffp, _expCtx.get());
- _rewroteLastExpression = true;
- return exprMatch;
-}
-
-std::unique_ptr<MatchExpression> FLEQueryRewriter::rewriteIn(const InMatchExpression* expr) {
- size_t numFFPs = 0;
- for (auto& eq : expr->getEqualities()) {
- if (isFleFindPayload(eq, EncryptedBinDataType::kFLE2FindEqualityPayload)) {
- ++numFFPs;
- }
- }
-
- if (numFFPs == 0) {
- return nullptr;
- }
-
- // All elements in an encrypted $in expression should be FFPs.
- uassert(
- 6329400,
- "If any elements in a $in expression are encrypted, then all elements should be encrypted.",
- numFFPs == expr->getEqualities().size());
-
- if (_mode != EncryptedCollScanMode::kForceAlways) {
-
- try {
- auto backingBSONBuilder = BSONArrayBuilder();
-
- for (auto& eq : expr->getEqualities()) {
- auto obj = rewriteEqualityPayloadAsTags(eq);
- for (auto&& elt : obj) {
- backingBSONBuilder.append(elt);
- }
- }
-
- auto backingBSON = backingBSONBuilder.arr();
- auto allTags = std::vector<BSONElement>();
- backingBSON.elems(allTags);
-
- auto inExpr = std::make_unique<InMatchExpression>(kSafeContent);
- inExpr->setBackingBSON(std::move(backingBSON));
- auto status = inExpr->setEqualities(std::move(allTags));
- uassertStatusOK(status);
-
- _rewroteLastExpression = true;
- return inExpr;
-
- } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) {
- LOGV2_DEBUG(6672411,
- 2,
- "FLE Max tag limit hit during query $in rewrite",
- "__error__"_attr = ex.what());
-
- if (_mode != EncryptedCollScanMode::kUseIfNeeded) {
- throw;
- }
-
- // fall through
- }
- }
-
- std::vector<std::unique_ptr<MatchExpression>> matches;
- matches.reserve(numFFPs);
-
- for (auto& eq : expr->getEqualities()) {
- auto exprMatch = generateFleEqualMatchAndExpr(expr->path(), eq, _expCtx.get());
- matches.push_back(std::move(exprMatch));
- }
-
- auto orExpr = std::make_unique<OrMatchExpression>(std::move(matches));
- _rewroteLastExpression = true;
- return orExpr;
-}
-
-std::unique_ptr<MatchExpression> FLEQueryRewriter::rewriteRange(
- const EncryptedBetweenMatchExpression* expr) {
- auto ffp = expr->rhs();
-
- if (!isFleFindPayload(ffp, EncryptedBinDataType::kFLE2FindRangePayload)) {
- return nullptr;
- }
-
- auto obj = rewriteRangePayloadAsTags(ffp);
- auto tags = std::vector<BSONElement>();
- obj.elems(tags);
- auto inExpr = std::make_unique<InMatchExpression>(kSafeContent);
- inExpr->setBackingBSON(std::move(obj));
- auto status = inExpr->setEqualities(std::move(tags));
- uassertStatusOK(status);
- _rewroteLastExpression = true;
- return inExpr;
-}
-
} // namespace mongo::fle
diff --git a/src/mongo/db/query/fle/server_rewrite.h b/src/mongo/db/query/fle/server_rewrite.h
index d3d52b607f2..c086548b23d 100644
--- a/src/mongo/db/query/fle/server_rewrite.h
+++ b/src/mongo/db/query/fle/server_rewrite.h
@@ -40,19 +40,16 @@
#include "mongo/db/namespace_string.h"
#include "mongo/db/pipeline/expression_context.h"
#include "mongo/db/query/count_command_gen.h"
+#include "mongo/db/query/fle/query_rewriter_interface.h"
#include "mongo/db/transaction/transaction_api.h"
+/**
+ * This file contains the interface for rewriting filters within CRUD commands for FLE2.
+ */
namespace mongo {
class FLEQueryInterface;
namespace fle {
-/**
- * Low Selectivity rewrites use $expr which is not supported in all commands such as upserts.
- */
-enum class EncryptedCollScanModeAllowed {
- kAllow,
- kDisallow,
-};
/**
* Make a collator object from its BSON representation. Useful when creating ExpressionContext
@@ -116,152 +113,5 @@ BSONObj rewriteEncryptedFilterInsideTxn(
boost::intrusive_ptr<ExpressionContext> expCtx,
BSONObj filter,
EncryptedCollScanModeAllowed mode = EncryptedCollScanModeAllowed::kDisallow);
-
-/**
- * Class which handles rewriting filter MatchExpressions for FLE2. The functionality is encapsulated
- * as a class rather than just a namespace so that the collection readers don't have to be passed
- * around as extra arguments to every function.
- *
- * Exposed in the header file for unit testing purposes. External callers should use the
- * rewriteEncryptedFilterInsideTxn() helper function defined above.
- */
-class FLEQueryRewriter {
-public:
- enum class EncryptedCollScanMode {
- // Always use high cardinality filters, used by tests
- kForceAlways,
-
- // Use high cardinality mode if $in rewrites do not fit in the
- // internalQueryFLERewriteMemoryLimit memory limit
- kUseIfNeeded,
-
- // Do not rewrite into high cardinality filter, throw exceptions instead
- // Some contexts like upsert do not support $expr
- kDisallow,
- };
-
- /**
- * Takes in references to collection readers for the ESC and ECC that are used during tag
- * computation.
- */
- FLEQueryRewriter(boost::intrusive_ptr<ExpressionContext> expCtx,
- const FLEStateCollectionReader& escReader,
- const FLEStateCollectionReader& eccReader,
- EncryptedCollScanModeAllowed mode = EncryptedCollScanModeAllowed::kAllow)
- : _expCtx(expCtx), _escReader(&escReader), _eccReader(&eccReader) {
-
- if (internalQueryFLEAlwaysUseEncryptedCollScanMode.load()) {
- _mode = EncryptedCollScanMode::kForceAlways;
- }
-
- if (mode == EncryptedCollScanModeAllowed::kDisallow) {
- _mode = EncryptedCollScanMode::kDisallow;
- }
-
- // This isn't the "real" query so we don't want to increment Expression
- // counters here.
- _expCtx->stopExpressionCounters();
- }
-
- /**
- * Accepts a BSONObj holding a MatchExpression, and returns BSON representing the rewritten
- * expression. Returns boost::none if no rewriting was done.
- *
- * Rewrites the match expression with FLE find payloads into a disjunction on the
- * __safeContent__ array of tags.
- *
- * Will rewrite top-level $eq and $in expressions, as well as recursing through $and, $or, $not
- * and $nor. Also handles similarly limited rewriting under $expr. All other MatchExpressions,
- * notably $elemMatch, are ignored.
- */
- boost::optional<BSONObj> rewriteMatchExpression(const BSONObj& filter);
-
- /**
- * Accepts an expression to be re-written. Will rewrite top-level expressions including $eq and
- * $in, as well as recursing through other expressions. Returns a new pointer if the top-level
- * expression must be changed. A nullptr indicates that the modifications happened in-place.
- */
- std::unique_ptr<Expression> rewriteExpression(Expression* expression);
-
- /**
- * Determine whether a given BSONElement is in fact a FLE find payload by checking that it is
- * the same type as the given EncryptedBinDataType. Sub-type 6, sub-sub-type determined by
- * "type."
- */
- virtual bool isFleFindPayload(const BSONElement& elt, EncryptedBinDataType type) const {
- if (!elt.isBinData(BinDataType::Encrypt)) {
- return false;
- }
- int dataLen;
- auto data = elt.binData(dataLen);
- return dataLen >= 1 && data[0] == static_cast<uint8_t>(type);
- }
-
- /**
- * Determine whether a given Value is in fact a FLE find payload by checking that it is the same
- * type as the given EncryptedBinDataType. Sub-type 6, sub-sub-type determined by "type."
- */
- bool isFleFindPayload(const Value& v, EncryptedBinDataType type) const {
- if (v.getType() != BSONType::BinData) {
- return false;
- }
-
- auto binData = v.getBinData();
- return binData.type == BinDataType::Encrypt && binData.length >= 1 &&
- static_cast<uint8_t>(type) == static_cast<const uint8_t*>(binData.data)[0];
- }
-
- std::vector<Value> rewriteEqualityPayloadAsTags(Value fleFindPayload) const;
- std::vector<Value> rewriteRangePayloadAsTags(Value fleFindPayload) const;
-
- ExpressionContext* expCtx() {
- return _expCtx.get();
- }
-
- bool isForceEncryptedCollScan() const {
- return _mode == EncryptedCollScanMode::kForceAlways;
- }
-
- void setForceEncryptedCollScanForTest() {
- _mode = EncryptedCollScanMode::kForceAlways;
- }
-
- EncryptedCollScanMode getEncryptedCollScanMode() const {
- return _mode;
- }
-
-protected:
- // This constructor should only be used for mocks in testing.
- FLEQueryRewriter(boost::intrusive_ptr<ExpressionContext> expCtx)
- : _expCtx(expCtx), _escReader(nullptr), _eccReader(nullptr) {}
-
-private:
- /**
- * A single rewrite step, called recursively on child expressions.
- */
- std::unique_ptr<MatchExpression> _rewrite(MatchExpression* me);
-
- virtual BSONObj rewriteEqualityPayloadAsTags(BSONElement fleFindPayload) const;
-
- virtual BSONObj rewriteRangePayloadAsTags(BSONElement fleFindPayload) const;
- std::unique_ptr<MatchExpression> rewriteEq(const EqualityMatchExpression* expr);
- std::unique_ptr<MatchExpression> rewriteIn(const InMatchExpression* expr);
- std::unique_ptr<MatchExpression> rewriteRange(const EncryptedBetweenMatchExpression* expr);
-
- boost::intrusive_ptr<ExpressionContext> _expCtx;
-
- // Holds a pointer so that these can be null for tests, even though the public constructor
- // takes a const reference.
- const FLEStateCollectionReader* _escReader;
- const FLEStateCollectionReader* _eccReader;
-
- // True if the last Expression or MatchExpression processed by this rewriter was rewritten.
- bool _rewroteLastExpression = false;
-
- // Controls how query rewriter rewrites the query
- EncryptedCollScanMode _mode{EncryptedCollScanMode::kUseIfNeeded};
-};
-
-
} // namespace fle
} // namespace mongo
diff --git a/src/mongo/db/query/fle/server_rewrite_test.cpp b/src/mongo/db/query/fle/server_rewrite_test.cpp
deleted file mode 100644
index fd910390eb9..00000000000
--- a/src/mongo/db/query/fle/server_rewrite_test.cpp
+++ /dev/null
@@ -1,1085 +0,0 @@
-/**
- * Copyright (C) 2022-present MongoDB, Inc.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the Server Side Public License, version 1,
- * as published by MongoDB, Inc.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * Server Side Public License for more details.
- *
- * You should have received a copy of the Server Side Public License
- * along with this program. If not, see
- * <http://www.mongodb.com/licensing/server-side-public-license>.
- *
- * As a special exception, the copyright holders give permission to link the
- * code of portions of this program with the OpenSSL library under certain
- * conditions as described in each individual source file and distribute
- * linked combinations including the program with the OpenSSL library. You
- * must comply with the Server Side Public License in all respects for
- * all of the code used other than as permitted herein. If you modify file(s)
- * with this exception, you may extend this exception to your version of the
- * file(s), but you are not obligated to do so. If you do not wish to do so,
- * delete this exception statement from your version. If you delete this
- * exception statement from all source files in the program, then also delete
- * it in the license file.
- */
-
-
-#include <memory>
-
-#include "mongo/bson/bsonelement.h"
-#include "mongo/bson/bsonmisc.h"
-#include "mongo/bson/bsonobjbuilder.h"
-#include "mongo/bson/bsontypes.h"
-#include "mongo/db/matcher/expression_leaf.h"
-#include "mongo/db/pipeline/expression_context_for_test.h"
-#include "mongo/db/query/fle/server_rewrite.h"
-#include "mongo/idl/server_parameter_test_util.h"
-#include "mongo/unittest/unittest.h"
-#include "mongo/util/assert_util.h"
-
-
-namespace mongo {
-namespace {
-
-class BasicMockFLEQueryRewriter : public fle::FLEQueryRewriter {
-public:
- BasicMockFLEQueryRewriter() : fle::FLEQueryRewriter(new ExpressionContextForTest()) {}
-
- BSONObj rewriteMatchExpressionForTest(const BSONObj& obj) {
- auto res = rewriteMatchExpression(obj);
- return res ? res.value() : obj;
- }
-
- /* Given a vector of BSONArrays, concatenate them into one BSONArray.
- *
- * E.g., given vec = [{1, 2, 3}, {4, 5, 6}, {21, 34}] this will return
- * {1, 2, 3, 4, 5, 6, 21, 34} */
- BSONArray concatBSONArrays(std::vector<BSONArray> vec) const {
- auto backingBSONBuilder = BSONArrayBuilder();
-
- for (auto& arr : vec) {
- for (auto&& elt : arr) {
- backingBSONBuilder.append(elt);
- }
- }
- return backingBSONBuilder.arr();
- }
-};
-
-class MockFLEQueryRewriter : public BasicMockFLEQueryRewriter {
-public:
- MockFLEQueryRewriter() : _tags() {}
-
- bool isFleFindPayload(const BSONElement& fleFindPayload,
- EncryptedBinDataType type) const override {
- switch (type) {
- case EncryptedBinDataType::kFLE2FindEqualityPayload: {
- return _encryptedFields.find(fleFindPayload.fieldNameStringData()) !=
- _encryptedFields.end();
- }
- case EncryptedBinDataType::kFLE2FindRangePayload: {
- // By definition, $encryptedBetween only ever has an encrypted payload.
- return true;
- }
- default:
- return false;
- }
- }
-
- void setEncryptedTags(std::pair<StringData, int> fieldvalue, BSONObj tags) {
- _encryptedFields.insert(fieldvalue.first);
- _tags[fieldvalue] = tags;
- }
-
-private:
- BSONObj rewriteEqualityPayloadAsTags(BSONElement fleFindPayload) const override {
- ASSERT(fleFindPayload.isNumber()); // Only accept numbers as mock FFPs.
- ASSERT(_tags.find({fleFindPayload.fieldNameStringData(), fleFindPayload.Int()}) !=
- _tags.end());
- return _tags.find({fleFindPayload.fieldNameStringData(), fleFindPayload.Int()})->second;
- };
-
- BSONObj rewriteRangePayloadAsTags(BSONElement fleFindPayload) const override {
- auto parsedPayload = fleFindPayload.Obj().firstElement();
- auto fieldName = parsedPayload.fieldNameStringData();
-
- std::vector<BSONElement> range;
- auto payloadAsArray = parsedPayload.Array();
- for (auto&& elt : payloadAsArray) {
- range.push_back(elt);
- }
-
- std::vector<BSONArray> allTags;
- for (auto i = range[0].Number(); i <= range[1].Number(); i++) {
- ASSERT(_tags.find({fieldName, i}) != _tags.end());
- auto temp = _tags.find({fieldName, i})->second;
- allTags.push_back(BSONArray(temp));
- }
- return concatBSONArrays(allTags);
- };
-
- std::map<std::pair<StringData, int>, BSONObj> _tags;
- std::set<StringData> _encryptedFields;
-};
-
-class FLEServerRewriteTest : public unittest::Test {
-public:
- FLEServerRewriteTest() {}
-
- void setUp() override {}
-
- void tearDown() override {}
-
-protected:
- MockFLEQueryRewriter _mock;
-};
-
-TEST_F(FLEServerRewriteTest, NoFFP_Equality) {
- auto match = fromjson("{ssn: '5'}");
- auto expected = fromjson("{ssn: '5'}}");
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, NoFFP_In) {
- auto match = fromjson("{ssn: {$in: ['5', '6', '7']}}");
- auto expected = fromjson("{ssn: {$in: ['5', '6', '7']}}");
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, TopLevel_Equality) {
- auto match = fromjson("{ssn: 5}");
- auto tags = BSON_ARRAY(1 << 2 << 3);
-
- _mock.setEncryptedTags({"ssn", 5}, tags);
- auto expected = BSON(kSafeContent << BSON("$in" << tags));
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, TopLevel_Equality_DottedPath) {
- auto match = fromjson("{'user.ssn': {$eq: 5}}");
- auto tags = BSON_ARRAY(1 << 2 << 3);
-
- _mock.setEncryptedTags({"user.ssn", 5}, tags);
- auto expected = BSON(kSafeContent << BSON("$in" << tags));
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, TopLevel_In) {
- auto match = fromjson("{ssn: {$in: [2, 4, 6]}}");
-
- // The key/value pairs that the mock functions use to determine the fake FFPs are inside an
- // array, and so the keys are the index values and the values are the actual array elements.
- _mock.setEncryptedTags({"0", 2}, BSON_ARRAY(1 << 2));
- _mock.setEncryptedTags({"1", 4}, BSON_ARRAY(5 << 3));
- _mock.setEncryptedTags({"2", 6}, BSON_ARRAY(99 << 100));
-
- // Order doesn't matter in a disjunction.
- auto expected = BSON(kSafeContent << BSON("$in" << BSON_ARRAY(1 << 2 << 3 << 5 << 99 << 100)));
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, TopLevel_In_DottedPath) {
- auto match = fromjson("{'user.ssn': {$in: [2, 4, 6]}}");
-
- // The key/value pairs that the mock functions use to determine the fake FFPs are inside an
- // array, and so the keys are the index values and the values are the actual array elements.
- _mock.setEncryptedTags({"0", 2}, BSON_ARRAY(1 << 2));
- _mock.setEncryptedTags({"1", 4}, BSON_ARRAY(5 << 3));
- _mock.setEncryptedTags({"2", 6}, BSON_ARRAY(99 << 100));
-
- // Order doesn't matter in a disjunction.
- auto expected = BSON(kSafeContent << BSON("$in" << BSON_ARRAY(1 << 2 << 3 << 5 << 99 << 100)));
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, TopLevel_Conjunction_BothEncrypted) {
- auto match = fromjson("{$and: [{ssn: 5}, {age: 36}]}");
- auto ssnTags = BSON_ARRAY(1 << 2 << 3);
- auto ageTags = BSON_ARRAY(22 << 44 << 66);
-
- _mock.setEncryptedTags({"ssn", 5}, ssnTags);
- _mock.setEncryptedTags({"age", 36}, ageTags);
- auto expected = BSON("$and" << BSON_ARRAY(BSON(kSafeContent << BSON("$in" << ssnTags))
- << BSON(kSafeContent << BSON("$in" << ageTags))));
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, TopLevel_Conjunction_PartlyEncrypted) {
- auto match = fromjson("{$and: [{ssn: 5}, {notSsn: 6}]}");
- auto tags = BSON_ARRAY(1 << 2 << 3);
-
- _mock.setEncryptedTags({"ssn", 5}, tags);
- auto expected = BSON("$and" << BSON_ARRAY(BSON(kSafeContent << BSON("$in" << tags))
- << BSON("notSsn" << BSON("$eq" << 6))));
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, TopLevel_CompoundEquality_PartlyEncrypted) {
- auto match = fromjson("{ssn: 5, notSsn: 6}");
- auto tags = BSON_ARRAY(1 << 2 << 3);
-
- _mock.setEncryptedTags({"ssn", 5}, tags);
- auto expected = BSON("$and" << BSON_ARRAY(BSON(kSafeContent << BSON("$in" << tags))
- << BSON("notSsn" << BSON("$eq" << 6))));
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, TopLevel_Encrypted_Nested_Unencrypted) {
- auto match = fromjson("{ssn: 5, user: {region: 'US'}}");
- auto tags = BSON_ARRAY(1 << 2 << 3);
-
- _mock.setEncryptedTags({"ssn", 5}, tags);
- auto expected = BSON("$and" << BSON_ARRAY(BSON(kSafeContent << BSON("$in" << tags))
- << BSON("user" << BSON("$eq" << BSON("region"
- << "US")))));
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, TopLevel_Not_Equality) {
- auto match = fromjson("{ssn: {$not: {$eq: 5}}}");
- auto tags = BSON_ARRAY(1 << 2 << 3);
-
- _mock.setEncryptedTags({"ssn", 5}, tags);
- auto expected = BSON(kSafeContent << BSON("$not" << BSON("$in" << tags)));
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, TopLevel_Neq) {
- auto match = fromjson("{ssn: {$ne: 5}}");
- auto tags = BSON_ARRAY(1 << 2 << 3);
-
- _mock.setEncryptedTags({"ssn", 5}, tags);
- auto expected = BSON(kSafeContent << BSON("$not" << BSON("$in" << tags)));
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-
-TEST_F(FLEServerRewriteTest, TopLevel_And_In) {
- auto match = fromjson("{$and: [{ssn: {$in: [2, 4, 6]}}, {region: 'US'}]}");
-
- _mock.setEncryptedTags({"0", 2}, BSON_ARRAY(1 << 2));
- _mock.setEncryptedTags({"1", 4}, BSON_ARRAY(5 << 3));
- _mock.setEncryptedTags({"2", 6}, BSON_ARRAY(99 << 100));
-
- auto expected =
- BSON("$and" << BSON_ARRAY(
- BSON(kSafeContent << BSON("$in" << BSON_ARRAY(1 << 2 << 3 << 5 << 99 << 100)))
- << BSON("region" << BSON("$eq"
- << "US"))));
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, NestedConjunction) {
- auto match = fromjson("{$and: [{$and: [{ssn: 2}, {other: 3}]}, {otherSsn: 5}]}");
-
- _mock.setEncryptedTags({"ssn", 2}, BSON_ARRAY(1 << 2));
- _mock.setEncryptedTags({"otherSsn", 5}, BSON_ARRAY(3 << 4));
-
- auto expected = fromjson(R"(
- { $and: [
- { $and: [
- { __safeContent__: { $in: [ 1, 2 ] } },
- { other: { $eq: 3 } }
- ] },
- { __safeContent__: { $in: [ 3, 4 ] } }
- ] })");
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, TopLevel_Nor_Equality) {
- auto match = fromjson("{$nor: [{ssn: 5}]}");
- auto tags = BSON_ARRAY(1 << 2 << 3);
-
- _mock.setEncryptedTags({"ssn", 5}, tags);
- auto expected = BSON("$nor" << BSON_ARRAY(BSON(kSafeContent << BSON("$in" << tags))));
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, TopLevel_Nor_Equality_WithUnencrypted) {
- auto match = fromjson("{$nor: [{ssn: 5}, {region: 'US'}]}");
- auto tags = BSON_ARRAY(1 << 2 << 3);
-
- _mock.setEncryptedTags({"ssn", 5}, tags);
- auto expected = BSON("$nor" << BSON_ARRAY(BSON(kSafeContent << BSON("$in" << tags))
- << BSON("region" << BSON("$eq"
- << "US"))));
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, TopLevel_Or_Equality_WithUnencrypted) {
- auto match = fromjson("{$or: [{ssn: 5}, {region: 'US'}]}");
- auto tags = BSON_ARRAY(1 << 2 << 3);
-
- _mock.setEncryptedTags({"ssn", 5}, tags);
- auto expected = BSON("$or" << BSON_ARRAY(BSON(kSafeContent << BSON("$in" << tags))
- << BSON("region" << BSON("$eq"
- << "US"))));
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, TopLevel_Not_In) {
- auto match = fromjson("{ssn: {$not: {$in: [2, 4, 6]}}}");
-
- _mock.setEncryptedTags({"0", 2}, BSON_ARRAY(1 << 2));
- _mock.setEncryptedTags({"1", 4}, BSON_ARRAY(5 << 3));
- _mock.setEncryptedTags({"2", 6}, BSON_ARRAY(99 << 100));
-
- auto expected = BSON(
- kSafeContent << BSON("$not" << BSON("$in" << BSON_ARRAY(1 << 2 << 3 << 5 << 99 << 100))));
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, TopLevel_Nin) {
- auto match = fromjson("{ssn: {$nin: [2, 4, 6]}}");
-
- _mock.setEncryptedTags({"0", 2}, BSON_ARRAY(1 << 2));
- _mock.setEncryptedTags({"1", 4}, BSON_ARRAY(5 << 3));
- _mock.setEncryptedTags({"2", 6}, BSON_ARRAY(99 << 100));
-
- // Order doesn't matter in a disjunction.
- auto expected = BSON(
- kSafeContent << BSON("$not" << BSON("$in" << BSON_ARRAY(1 << 2 << 3 << 5 << 99 << 100))));
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, InMixOfEncryptedElementsIsDisallowed) {
- auto match = fromjson("{ssn: {$in: [2, 4, 6]}}");
-
- _mock.setEncryptedTags({"0", 2}, BSON_ARRAY(1 << 2));
- _mock.setEncryptedTags({"1", 4}, BSON_ARRAY(5 << 3));
-
- ASSERT_THROWS_CODE(_mock.rewriteMatchExpressionForTest(match), AssertionException, 6329400);
-}
-
-TEST_F(FLEServerRewriteTest, ComparisonToObjectIgnored) {
- // Although such a query should fail in query analysis, it's not realistic for us to catch all
- // the ways a FLEFindPayload could be improperly included in an explicitly encrypted query, so
- // this test demonstrates the server side behavior.
- {
- auto match = fromjson("{user: {$eq: {ssn: 5}}}");
-
- _mock.setEncryptedTags({"user.ssn", 5}, BSON_ARRAY(1 << 2));
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, match);
- }
- {
- auto match = fromjson("{user: {$in: [{ssn: 5}]}}");
-
- _mock.setEncryptedTags({"user.ssn", 5}, BSON_ARRAY(1 << 2));
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, match);
- }
-}
-
-TEST_F(FLEServerRewriteTest, EncryptedBetweenBasic) {
- RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true);
-
- int start = 1;
- int end = 3;
- StringData encField = "ssn";
-
- // The field redundancy is so that we can pull out the field
- // name in the mock version of rewriteRangePayloadAsTags.
- BSONObj query =
- BSON(encField << BSON("$encryptedBetween" << BSON(encField << BSON_ARRAY(start << end))));
-
- auto tags1 = BSON_ARRAY(1 << 2 << 3);
- auto tags2 = BSON_ARRAY("A"
- << "F"
- << "Q");
- auto tags3 = BSON_ARRAY("aHb"
- << "jkl"
- << "q76");
-
- std::vector<BSONArray> allTags = {tags1, tags2, tags3};
- BSONArray tagsConcat = _mock.concatBSONArrays(allTags);
-
- for (int i = 0; i <= (end - start); i++) {
- _mock.setEncryptedTags({encField, (start + i)}, allTags[i]);
- }
-
- auto expected = BSON(kSafeContent << BSON("$in" << tagsConcat));
- auto actual = _mock.rewriteMatchExpressionForTest(query);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, EncryptedBetweenFeatureFlagFalse) {
- RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", false);
-
- int start = 1;
- int end = 3;
- StringData encField = "ssn";
-
- BSONObj query =
- BSON(encField << BSON("$encryptedBetween" << BSON(encField << BSON_ARRAY(start << end))));
-
- auto tags1 = BSON_ARRAY(1 << 2 << 3);
- auto tags2 = BSON_ARRAY(4 << 5 << 6);
- auto tags3 = BSON_ARRAY(7 << 8 << 9);
-
- std::vector<BSONArray> allTags = {tags1, tags2, tags3};
- BSONArray tagsConcat = _mock.concatBSONArrays(allTags);
-
- for (int i = 0; i <= (end - start); i++) {
- _mock.setEncryptedTags({encField, (start + i)}, allTags[i]);
- }
-
- // No rewrite should occur since the feature flag has been set to false.
- auto actual = _mock.rewriteMatchExpressionForTest(query);
- ASSERT_BSONOBJ_EQ(actual, query);
-}
-
-TEST_F(FLEServerRewriteTest, EncryptedBetweenVariableNumberOfTags) {
- RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true);
-
- int start = 1;
- int end = 3;
- StringData encField = "ssn";
-
- BSONObj query =
- BSON(encField << BSON("$encryptedBetween" << BSON(encField << BSON_ARRAY(start << end))));
-
- auto tags1 = BSON_ARRAY(1);
- auto tags2 = BSON_ARRAY("A"
- << "F");
- auto tags3 = BSON_ARRAY("aHb"
- << "bcdefdfg12243"
- << "c"
- << "d"
- << "e"
- << "f"
- << "g"
- << "hij"
- << "kl78h");
-
- std::vector<BSONArray> allTags = {tags1, tags2, tags3};
- BSONArray tagsConcat = _mock.concatBSONArrays(allTags);
-
- for (int i = 0; i <= (end - start); i++) {
- _mock.setEncryptedTags({encField, (start + i)}, allTags[i]);
- }
-
- auto expected = BSON(kSafeContent << BSON("$in" << tagsConcat));
- auto actual = _mock.rewriteMatchExpressionForTest(query);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, EncryptedBetweenInsideNot) {
- RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true);
-
- int start = 1;
- int end = 3;
- StringData encField = "ssn";
- BSONObj query =
- BSON(encField << BSON("$not" << BSON("$encryptedBetween"
- << BSON(encField << BSON_ARRAY(start << end)))));
-
- auto tags1 = BSON_ARRAY(1 << 2 << 3);
- auto tags2 = BSON_ARRAY("A"
- << "F"
- << "Q");
- auto tags3 = BSON_ARRAY("aHb"
- << "jkl"
- << "q76");
-
- std::vector<BSONArray> allTags = {tags1, tags2, tags3};
- BSONArray tagsConcat = _mock.concatBSONArrays(allTags);
-
- for (int i = 0; i <= (end - start); i++) {
- _mock.setEncryptedTags({encField, (start + i)}, allTags[i]);
- }
-
- auto expected = BSON(kSafeContent << BSON("$not" << BSON("$in" << tagsConcat)));
- auto actual = _mock.rewriteMatchExpressionForTest(query);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, EncryptedBetweenDottedPath) {
- RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true);
-
- int start = 1;
- int end = 3;
- StringData encField = "hello.world";
- BSONObj query =
- BSON(encField << BSON("$encryptedBetween" << BSON(encField << BSON_ARRAY(start << end))));
-
- auto tags1 = BSON_ARRAY(1 << 2 << 3);
- auto tags2 = BSON_ARRAY("A"
- << "F"
- << "Q");
- auto tags3 = BSON_ARRAY("aHb"
- << "jkl"
- << "q76");
-
- std::vector<BSONArray> allTags = {tags1, tags2, tags3};
- BSONArray tagsConcat = _mock.concatBSONArrays(allTags);
-
- for (int i = 0; i <= (end - start); i++) {
- _mock.setEncryptedTags({encField, (start + i)}, allTags[i]);
- }
-
- auto expected = BSON(kSafeContent << BSON("$in" << tagsConcat));
- auto actual = _mock.rewriteMatchExpressionForTest(query);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, EncryptedBetweenInsideAnd) {
- RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true);
-
- int start = 1;
- int end = 3;
- StringData encField = "ssn";
- BSONObj query = BSON(
- "$and" << BSON_ARRAY(BSON("x" << 5)
- << BSON(encField << BSON("$encryptedBetween" << BSON(
- encField << BSON_ARRAY(start << end))))));
-
- auto tags1 = BSON_ARRAY(1 << 2 << 3);
- auto tags2 = BSON_ARRAY("A"
- << "F"
- << "Q");
- auto tags3 = BSON_ARRAY("aHb"
- << "jkl"
- << "q76");
-
- std::vector<BSONArray> allTags = {tags1, tags2, tags3};
- BSONArray tagsConcat = _mock.concatBSONArrays(allTags);
-
- for (int i = 0; i <= (end - start); i++) {
- _mock.setEncryptedTags({encField, (start + i)}, allTags[i]);
- }
-
- auto expected = BSON("$and" << BSON_ARRAY(BSON("x" << BSON("$eq" << 5))
- << BSON(kSafeContent << BSON("$in" << tagsConcat))));
- auto actual = _mock.rewriteMatchExpressionForTest(query);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, EncryptedBetweenInsideOr) {
- RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true);
-
- int start = 1;
- int end = 3;
- StringData encField = "ssn";
- BSONObj query = BSON(
- "$or" << BSON_ARRAY(BSON("x" << 5)
- << BSON(encField << BSON("$encryptedBetween" << BSON(
- encField << BSON_ARRAY(start << end))))));
-
- auto tags1 = BSON_ARRAY(1 << 2 << 3);
- auto tags2 = BSON_ARRAY("A"
- << "F"
- << "Q");
- auto tags3 = BSON_ARRAY("aHb"
- << "jkl"
- << "q76");
-
- std::vector<BSONArray> allTags = {tags1, tags2, tags3};
- BSONArray tagsConcat = _mock.concatBSONArrays(allTags);
-
- for (int i = 0; i <= (end - start); i++) {
- _mock.setEncryptedTags({encField, (start + i)}, allTags[i]);
- }
-
- auto expected = BSON("$or" << BSON_ARRAY(BSON("x" << BSON("$eq" << 5))
- << BSON(kSafeContent << BSON("$in" << tagsConcat))));
- auto actual = _mock.rewriteMatchExpressionForTest(query);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, EncryptedBetweeenAndEncryptedEquality) {
- RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true);
-
- int start = 1;
- int end = 3;
- StringData encField = "ssn";
- BSONObj query = BSON(
- "$and" << BSON_ARRAY(BSON("x" << 21)
- << BSON(encField << BSON("$encryptedBetween" << BSON(
- encField << BSON_ARRAY(start << end))))));
-
- auto tags1 = BSON_ARRAY(1 << 2 << 3);
- auto tags2 = BSON_ARRAY("A"
- << "F"
- << "Q");
- auto tags3 = BSON_ARRAY("aHb"
- << "jkl"
- << "q76");
-
- std::vector<BSONArray> allTags = {tags1, tags2, tags3};
- BSONArray tagsConcat = _mock.concatBSONArrays(allTags);
-
- for (int i = 0; i <= (end - start); i++) {
- _mock.setEncryptedTags({encField, (start + i)}, allTags[i]);
- }
-
- BSONArray equalityTags = BSON_ARRAY(312 << 567 << 897);
- _mock.setEncryptedTags({"x", 21}, equalityTags);
-
- auto expected = BSON("$and" << BSON_ARRAY(BSON(kSafeContent << BSON("$in" << equalityTags))
- << BSON(kSafeContent << BSON("$in" << tagsConcat))));
- auto actual = _mock.rewriteMatchExpressionForTest(query);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, EncryptedBetweenAndEncryptedIn) {
- RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true);
-
- int start = 1;
- int end = 3;
- StringData encBetweenField = "ssn";
- StringData encInField = "age";
- BSONObj query =
- BSON("$and" << BSON_ARRAY(
- BSON(encBetweenField << BSON("$encryptedBetween"
- << BSON(encBetweenField << BSON_ARRAY(start << end))))
- << BSON(encInField << BSON("$in" << BSON_ARRAY(10 << 22 << 34)))));
-
- auto tags1 = BSON_ARRAY(1 << 2 << 3);
- auto tags2 = BSON_ARRAY("A"
- << "F"
- << "Q");
- auto tags3 = BSON_ARRAY("aHb"
- << "jkl"
- << "q76");
-
- std::vector<BSONArray> allTags = {tags1, tags2, tags3};
- BSONArray tagsConcat = _mock.concatBSONArrays(allTags);
-
- for (int i = 0; i <= (end - start); i++) {
- _mock.setEncryptedTags({encBetweenField, (start + i)}, allTags[i]);
- }
-
- auto inTags1 = BSON_ARRAY(1 << 2);
- auto inTags2 = BSON_ARRAY(3 << 5);
- auto inTags3 = BSON_ARRAY(97 << 98 << 99 << 100);
- std::vector<BSONArray> allInTags = {inTags1, inTags2, inTags3};
- _mock.setEncryptedTags({"0", 10}, inTags1);
- _mock.setEncryptedTags({"1", 22}, inTags2);
- _mock.setEncryptedTags({"2", 34}, inTags3);
- BSONArray inTagsConcat = _mock.concatBSONArrays(allInTags);
-
-
- auto expected =
- BSON("$and" << BSON_ARRAY(BSON(kSafeContent << BSON("$in" << tagsConcat))
- << BSON(kSafeContent << BSON("$in" << inTagsConcat))));
- auto actual = _mock.rewriteMatchExpressionForTest(query);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, EncryptedBetweenAndUnencryptedRange) {
- RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true);
-
- int start = 1;
- int end = 3;
- StringData encField = "ssn";
- BSONObj query = BSON(
- "$and" << BSON_ARRAY(BSON(encField << BSON("$encryptedBetween"
- << BSON(encField << BSON_ARRAY(start << end))))
- << BSON("$and" << BSON_ARRAY(BSON("x" << BSON("$gt" << 10))
- << BSON("x" << BSON("$lt" << 25))))));
-
- auto tags1 = BSON_ARRAY(1 << 2 << 3);
- auto tags2 = BSON_ARRAY("A"
- << "F"
- << "Q");
- auto tags3 = BSON_ARRAY("aHb"
- << "jkl"
- << "q76");
-
- std::vector<BSONArray> allTags = {tags1, tags2, tags3};
- BSONArray tagsConcat = _mock.concatBSONArrays(allTags);
-
- for (int i = 0; i <= (end - start); i++) {
- _mock.setEncryptedTags({encField, (start + i)}, allTags[i]);
- }
-
- auto expected = BSON(
- "$and" << BSON_ARRAY(BSON(kSafeContent << BSON("$in" << tagsConcat))
- << BSON("$and" << BSON_ARRAY(BSON("x" << BSON("$gt" << 10))
- << BSON("x" << BSON("$lt" << 25))))));
- auto actual = _mock.rewriteMatchExpressionForTest(query);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerRewriteTest, EncryptedBetweenOnTwoDiffFields) {
- RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true);
-
- int start1 = 1;
- int end1 = 3;
- StringData encField1 = "ssn";
- int start2 = 102;
- int end2 = 106;
- StringData encField2 = "age";
- BSONObj query =
- BSON("$and" << BSON_ARRAY(
- BSON(encField1 << BSON("$encryptedBetween"
- << BSON(encField1 << BSON_ARRAY(start1 << end1))))
- << BSON(encField2 << BSON("$encryptedBetween"
- << BSON(encField2 << BSON_ARRAY(start2 << end2))))));
-
- auto tags1 = BSON_ARRAY(1 << 2 << 3);
- auto tags2 = BSON_ARRAY("A"
- << "F"
- << "Q");
- auto tags3 = BSON_ARRAY("aHb"
- << "jkl"
- << "q76");
-
- std::vector<BSONArray> allTags1 = {tags1, tags2, tags3};
- BSONArray tagsConcat1 = _mock.concatBSONArrays(allTags1);
-
- for (int i = 0; i <= (end1 - start1); i++) {
- _mock.setEncryptedTags({encField1, (start1 + i)}, allTags1[i]);
- }
-
- auto tags4 = BSON_ARRAY(1 << 2 << 3);
- auto tags5 = BSON_ARRAY(4 << 5 << 6);
- auto tags6 = BSON_ARRAY(21 << 25 << 45);
- auto tags7 = BSON_ARRAY(112 << 212 << 456);
- auto tags8 = BSON_ARRAY(908 << 1234 << 23467813);
-
- std::vector<BSONArray> allTags2 = {tags4, tags5, tags6, tags7, tags8};
- BSONArray tagsConcat2 = _mock.concatBSONArrays(allTags2);
-
- for (int i = 0; i <= (end2 - start2); i++) {
- _mock.setEncryptedTags({encField2, (start2 + i)}, allTags2[i]);
- }
-
- auto expected = BSON("$and" << BSON_ARRAY(BSON(kSafeContent << BSON("$in" << tagsConcat1))
- << BSON(kSafeContent << BSON("$in" << tagsConcat2))));
- auto actual = _mock.rewriteMatchExpressionForTest(query);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-template <typename T>
-std::vector<uint8_t> toEncryptedVector(EncryptedBinDataType dt, T t) {
- BSONObj obj = t.toBSON();
-
- std::vector<uint8_t> buf(obj.objsize() + 1);
- buf[0] = static_cast<uint8_t>(dt);
-
- std::copy(obj.objdata(), obj.objdata() + obj.objsize(), buf.data() + 1);
-
- return buf;
-}
-
-template <typename T>
-void toEncryptedBinData(StringData field, EncryptedBinDataType dt, T t, BSONObjBuilder* builder) {
- auto buf = toEncryptedVector(dt, t);
-
- builder->appendBinData(field, buf.size(), BinDataType::Encrypt, buf.data());
-}
-
-constexpr auto kIndexKeyId = "12345678-1234-9876-1234-123456789012"_sd;
-constexpr auto kUserKeyId = "ABCDEFAB-1234-9876-1234-123456789012"_sd;
-static UUID indexKeyId = uassertStatusOK(UUID::parse(kIndexKeyId.toString()));
-static UUID userKeyId = uassertStatusOK(UUID::parse(kUserKeyId.toString()));
-
-std::vector<char> testValue = {0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19};
-std::vector<char> testValue2 = {0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29};
-
-const FLEIndexKey& getIndexKey() {
- static std::string indexVec = hexblob::decode(
- "7dbfebc619aa68a659f64b8e23ccd21644ac326cb74a26840c3d2420176c40ae088294d00ad6cae9684237b21b754cf503f085c25cd320bf035c3417416e1e6fe3d9219f79586582112740b2add88e1030d91926ae8afc13ee575cfb8bb965b7"_sd);
- static FLEIndexKey indexKey(KeyMaterial(indexVec.begin(), indexVec.end()));
- return indexKey;
-}
-
-const FLEUserKey& getUserKey() {
- static std::string userVec = hexblob::decode(
- "a7ddbc4c8be00d51f68d9d8e485f351c8edc8d2206b24d8e0e1816d005fbe520e489125047d647b0d8684bfbdbf09c304085ed086aba6c2b2b1677ccc91ced8847a733bf5e5682c84b3ee7969e4a5fe0e0c21e5e3ee190595a55f83147d8de2a"_sd);
- static FLEUserKey userKey(KeyMaterial(userVec.begin(), userVec.end()));
- return userKey;
-}
-
-
-BSONObj generateFFP(StringData path, int value) {
- auto indexKey = getIndexKey();
- FLEIndexKeyAndId indexKeyAndId(indexKey.data, indexKeyId);
- auto userKey = getUserKey();
- FLEUserKeyAndId userKeyAndId(userKey.data, indexKeyId);
-
- BSONObj doc = BSON("value" << value);
- auto element = doc.firstElement();
- auto fpp = FLEClientCrypto::serializeFindPayload(indexKeyAndId, userKeyAndId, element, 0);
-
- BSONObjBuilder builder;
- toEncryptedBinData(path, EncryptedBinDataType::kFLE2FindEqualityPayload, fpp, &builder);
- return builder.obj();
-}
-
-class FLEServerHighCardRewriteTest : public unittest::Test {
-public:
- FLEServerHighCardRewriteTest() {}
-
- void setUp() override {}
-
- void tearDown() override {}
-
-protected:
- BasicMockFLEQueryRewriter _mock;
-};
-
-
-TEST_F(FLEServerHighCardRewriteTest, HighCard_TopLevel_Equality) {
- _mock.setForceEncryptedCollScanForTest();
-
- auto match = generateFFP("ssn", 1);
- auto expected = fromjson(R"({
- "$expr": {
- "$_internalFleEq": {
- "field": "$ssn",
- "edc": {
- "$binary": {
- "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg",
- "subType": "6"
- }
- },
- "counter": {
- "$numberLong": "0"
- },
- "server": {
- "$binary": {
- "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
- "subType": "6"
- }
- }
- }
- }
-})");
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-
-TEST_F(FLEServerHighCardRewriteTest, HighCard_TopLevel_In) {
- _mock.setForceEncryptedCollScanForTest();
-
- auto ffp1 = generateFFP("ssn", 1);
- auto ffp2 = generateFFP("ssn", 2);
- auto ffp3 = generateFFP("ssn", 3);
- auto expected = fromjson(R"({
- "$or": [
- {
- "$expr": {
- "$_internalFleEq": {
- "field": "$ssn",
- "edc": {
- "$binary": {
- "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg",
- "subType": "6"
- }
- },
- "counter": {
- "$numberLong": "0"
- },
- "server": {
- "$binary": {
- "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
- "subType": "6"
- }
- }
- }
- }
- },
- {
- "$expr": {
- "$_internalFleEq": {
- "field": "$ssn",
- "edc": {
- "$binary": {
- "base64": "CLpCo6rNuYMVT+6n1HCX15MNrVYDNqf6udO46ayo43Sw",
- "subType": "6"
- }
- },
- "counter": {
- "$numberLong": "0"
- },
- "server": {
- "$binary": {
- "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
- "subType": "6"
- }
- }
- }
- }
- },
- {
- "$expr": {
- "$_internalFleEq": {
- "field": "$ssn",
- "edc": {
- "$binary": {
- "base64": "CPi44oCQHnNDeRqHsNLzbdCeHt2DK/wCly0g2dxU5fqN",
- "subType": "6"
- }
- },
- "counter": {
- "$numberLong": "0"
- },
- "server": {
- "$binary": {
- "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
- "subType": "6"
- }
- }
- }
- }
- }
- ]
-})");
-
- auto match =
- BSON("ssn" << BSON("$in" << BSON_ARRAY(ffp1.firstElement()
- << ffp2.firstElement() << ffp3.firstElement())));
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-
-TEST_F(FLEServerHighCardRewriteTest, HighCard_TopLevel_Expr) {
-
- _mock.setForceEncryptedCollScanForTest();
-
- auto ffp = generateFFP("$ssn", 1);
- int len;
- auto v = ffp.firstElement().binDataClean(len);
- auto match = BSON("$expr" << BSON("$eq" << BSON_ARRAY(ffp.firstElement().fieldName()
- << BSONBinData(v, len, Encrypt))));
-
- auto expected = fromjson(R"({ "$expr": {
- "$_internalFleEq": {
- "field": "$ssn",
- "edc": {
- "$binary": {
- "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg",
- "subType": "6"
- }
- },
- "counter": {
- "$numberLong": "0"
- },
- "server": {
- "$binary": {
- "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
- "subType": "6"
- }
- }
- }
- }
- })");
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-TEST_F(FLEServerHighCardRewriteTest, HighCard_TopLevel_Expr_In) {
-
- _mock.setForceEncryptedCollScanForTest();
-
- auto ffp = generateFFP("$ssn", 1);
- int len;
- auto v = ffp.firstElement().binDataClean(len);
-
- auto ffp2 = generateFFP("$ssn", 1);
- int len2;
- auto v2 = ffp2.firstElement().binDataClean(len2);
-
- auto match = BSON(
- "$expr" << BSON("$in" << BSON_ARRAY(ffp.firstElement().fieldName()
- << BSON_ARRAY(BSONBinData(v, len, Encrypt)
- << BSONBinData(v2, len2, Encrypt)))));
-
- auto expected = fromjson(R"({ "$expr": { "$or" : [ {
- "$_internalFleEq": {
- "field": "$ssn",
- "edc": {
- "$binary": {
- "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg",
- "subType": "6"
- }
- },
- "counter": {
- "$numberLong": "0"
- },
- "server": {
- "$binary": {
- "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
- "subType": "6"
- }
- }
- }},
- {
- "$_internalFleEq": {
- "field": "$ssn",
- "edc": {
- "$binary": {
- "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg",
- "subType": "6"
- }
- },
- "counter": {
- "$numberLong": "0"
- },
- "server": {
- "$binary": {
- "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
- "subType": "6"
- }
- }
- }}
- ]}})");
-
- auto actual = _mock.rewriteMatchExpressionForTest(match);
- ASSERT_BSONOBJ_EQ(actual, expected);
-}
-
-} // namespace
-} // namespace mongo