diff options
20 files changed, 2225 insertions, 1769 deletions
diff --git a/src/mongo/crypto/fle_crypto.cpp b/src/mongo/crypto/fle_crypto.cpp index 9bc6dfa7022..a288e77329c 100644 --- a/src/mongo/crypto/fle_crypto.cpp +++ b/src/mongo/crypto/fle_crypto.cpp @@ -3136,13 +3136,13 @@ StringMap<FLEDeleteToken> EncryptionInformationHelpers::getDeleteTokens( return map; } -ParsedFindPayload::ParsedFindPayload(BSONElement fleFindPayload) - : ParsedFindPayload(binDataToCDR(fleFindPayload)){}; +ParsedFindEqualityPayload::ParsedFindEqualityPayload(BSONElement fleFindPayload) + : ParsedFindEqualityPayload(binDataToCDR(fleFindPayload)){}; -ParsedFindPayload::ParsedFindPayload(const Value& fleFindPayload) - : ParsedFindPayload(binDataToCDR(fleFindPayload)){}; +ParsedFindEqualityPayload::ParsedFindEqualityPayload(const Value& fleFindPayload) + : ParsedFindEqualityPayload(binDataToCDR(fleFindPayload)){}; -ParsedFindPayload::ParsedFindPayload(ConstDataRange cdr) { +ParsedFindEqualityPayload::ParsedFindEqualityPayload(ConstDataRange cdr) { auto [encryptedTypeBinding, subCdr] = fromEncryptedConstDataRange(cdr); auto encryptedType = encryptedTypeBinding; diff --git a/src/mongo/crypto/fle_crypto.h b/src/mongo/crypto/fle_crypto.h index 566dd00246c..c65a1c8a0ca 100644 --- a/src/mongo/crypto/fle_crypto.h +++ b/src/mongo/crypto/fle_crypto.h @@ -1253,16 +1253,16 @@ public: */ std::pair<EncryptedBinDataType, ConstDataRange> fromEncryptedConstDataRange(ConstDataRange cdr); -struct ParsedFindPayload { +struct ParsedFindEqualityPayload { ESCDerivedFromDataToken escToken; ECCDerivedFromDataToken eccToken; EDCDerivedFromDataToken edcToken; boost::optional<ServerDataEncryptionLevel1Token> serverToken; boost::optional<std::int64_t> maxCounter; - explicit ParsedFindPayload(BSONElement fleFindPayload); - explicit ParsedFindPayload(const Value& fleFindPayload); - explicit ParsedFindPayload(ConstDataRange cdr); + explicit ParsedFindEqualityPayload(BSONElement fleFindPayload); + explicit ParsedFindEqualityPayload(const Value& fleFindPayload); + explicit 
ParsedFindEqualityPayload(ConstDataRange cdr); }; diff --git a/src/mongo/db/SConscript b/src/mongo/db/SConscript index f109b4b282f..cc88766bd55 100644 --- a/src/mongo/db/SConscript +++ b/src/mongo/db/SConscript @@ -880,6 +880,10 @@ env.Library( target='fle_crud', source=[ 'fle_crud.cpp', + 'query/fle/encrypted_predicate.cpp', + 'query/fle/equality_predicate.cpp', + 'query/fle/query_rewriter.cpp', + 'query/fle/range_predicate.cpp', 'query/fle/server_rewrite.cpp', ], LIBDEPS=[ @@ -2515,7 +2519,9 @@ if wiredtiger: 'operation_id_test.cpp', 'operation_time_tracker_test.cpp', 'persistent_task_store_test.cpp', - 'query/fle/server_rewrite_test.cpp', + 'query/fle/equality_predicate_test.cpp', + 'query/fle/range_predicate_test.cpp', + 'query/fle/query_rewriter_test.cpp', 'range_arithmetic_test.cpp', 'read_write_concern_defaults_test.cpp', 'read_write_concern_provenance_test.cpp', diff --git a/src/mongo/db/query/canonical_query.cpp b/src/mongo/db/query/canonical_query.cpp index bbd7c68a21b..c5bf0a5fbd4 100644 --- a/src/mongo/db/query/canonical_query.cpp +++ b/src/mongo/db/query/canonical_query.cpp @@ -42,7 +42,6 @@ #include "mongo/db/operation_context.h" #include "mongo/db/query/canonical_query_encoder.h" #include "mongo/db/query/collation/collator_factory_interface.h" -#include "mongo/db/query/fle/server_rewrite.h" #include "mongo/db/query/indexability.h" #include "mongo/db/query/projection_parser.h" #include "mongo/db/query/query_planner_common.h" diff --git a/src/mongo/db/query/fle/encrypted_predicate.cpp b/src/mongo/db/query/fle/encrypted_predicate.cpp new file mode 100644 index 00000000000..c23e1a2a5ae --- /dev/null +++ b/src/mongo/db/query/fle/encrypted_predicate.cpp @@ -0,0 +1,65 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#include "encrypted_predicate.h" + +#include "mongo/crypto/fle_crypto.h" +#include "mongo/db/query/fle/query_rewriter_interface.h" +#include "mongo/logv2/log.h" + +#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kQuery + +namespace mongo { +namespace fle { + +ExpressionToRewriteMap aggPredicateRewriteMap{}; +MatchTypeToRewriteMap matchPredicateRewriteMap{}; + +void logTagsExceeded(const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) { + LOGV2_DEBUG( + 6672410, 2, "FLE Max tag limit hit during query rewrite", "__error__"_attr = ex.what()); +} + +BSONArray toBSONArray(std::vector<PrfBlock>&& vec) { + auto bab = BSONArrayBuilder(); + for (auto& elt : vec) { + bab.appendBinData(elt.size(), BinDataType::BinDataGeneral, elt.data()); + } + return bab.arr(); +} + +std::vector<Value> toValues(std::vector<PrfBlock>&& vec) { + std::vector<Value> output; + for (auto& elt : vec) { + output.push_back(Value(BSONBinData(elt.data(), elt.size(), BinDataType::BinDataGeneral))); + } + return output; +} +} // namespace fle +} // namespace mongo diff --git a/src/mongo/db/query/fle/encrypted_predicate.h b/src/mongo/db/query/fle/encrypted_predicate.h new file mode 100644 index 00000000000..b2797fd1eb9 --- /dev/null +++ b/src/mongo/db/query/fle/encrypted_predicate.h @@ -0,0 +1,261 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. 
+ * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <functional> + +#include "mongo/base/init.h" +#include "mongo/bson/bsonobjbuilder.h" +#include "mongo/crypto/fle_crypto.h" +#include "mongo/crypto/fle_field_schema_gen.h" +#include "mongo/db/exec/document_value/document.h" +#include "mongo/db/matcher/expression.h" +#include "mongo/db/matcher/expression_leaf.h" +#include "mongo/db/pipeline/expression.h" +#include "mongo/db/query/fle/query_rewriter_interface.h" + +/** + * This file contains an abstract class that describes rewrites on agg Expressions and + * MatchExpressions for individual encrypted index types. Subclasses of this class represent + * concrete encrypted index types, like Equality and Range. + * + * This class is not responsible for traversing expression trees, but instead takes leaf + * expressions that it may replace. Tree traversal is handled by the QueryRewriter. + */ + +namespace mongo { +namespace fle { + +// Virtual functions can't be templated, so in order to write a function which can take in either a +// BSONElement or a Value&, we need to create a variant type to use in function signatures. 
+// std::reference_wrapper is necessary to avoid copying the Value because references alone cannot be +// included in a variant. BSONElement can be passed by value because it is just a pointer into an +// owning BSONObj. +using BSONValue = stdx::variant<BSONElement, std::reference_wrapper<Value>>; + +/** + * Parse a find payload from either a BSONElement or a Value. All ParsedFindPayload types should + * have constructors for both BSONElements and Values, which will enable this function to work on + * both types. + */ +template <typename T> +T parseFindPayload(BSONValue payload) { + return stdx::visit(OverloadedVisitor{[&](BSONElement payload) { return T(payload); }, + [&](Value payload) { return T(payload); }}, + payload); +} + +/** + * Convert a vector of PrfBlocks to a BSONArray for use in MatchExpression tag generation. + */ +BSONArray toBSONArray(std::vector<PrfBlock>&& vec); + +/** + * Convert a vector of PrfBlocks to a vector of Values for use in Agg tag generation. + */ +std::vector<Value> toValues(std::vector<PrfBlock>&& vec); + +void logTagsExceeded(const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex); +/** + * Interface for implementing a server rewrite for an encrypted index. Each type of predicate + * should have its own subclass that implements the virtual methods in this class. + */ +class EncryptedPredicate { +public: + EncryptedPredicate(const QueryRewriterInterface* rewriter) : _rewriter(rewriter) {} + + /** + * Rewrite a terminal expression for this encrypted predicate. If this function returns + * nullptr, then no rewrite needs to be done. Rewrites generally transform predicates from one + * kind of expression to another, either a $in or an $_internalFle* runtime expression, and so + * this function will allocate a new expression and return a unique_ptr to it. 
+ */ + template <typename T> + std::unique_ptr<T> rewrite(T* expr) const { + auto mode = _rewriter->getEncryptedCollScanMode(); + if (mode != EncryptedCollScanMode::kForceAlways) { + try { + return rewriteToTagDisjunction(expr); + } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) { + // LOGV2 can't be called from a header file, so this call is factored out to a + // function defined in the cpp file. + logTagsExceeded(ex); + if (mode != EncryptedCollScanMode::kUseIfNeeded) { + throw; + } + } + } + return rewriteToRuntimeComparison(expr); + } + +protected: + /** + * Check if the passed-in payload is a FLE2 find payload for the right encrypted index type. + */ + virtual bool isPayload(const BSONElement& elt) const { + if (!elt.isBinData(BinDataType::Encrypt)) { + return false; + } + int dataLen; + auto data = elt.binData(dataLen); + + // Check that the BinData's subtype is 6, and its sub-subtype is equal to this predicate's + // encryptedBinDataType. + return dataLen >= 1 && data[0] == static_cast<uint8_t>(encryptedBinDataType()); + } + + /** + * Check if the passed-in payload is a FLE2 find payload for the right encrypted index type. + */ + virtual bool isPayload(const Value& v) const { + if (v.getType() != BSONType::BinData) { + return false; + } + + auto binData = v.getBinData(); + // Check that the BinData's subtype is 6, and its sub-subtype is equal to this predicate's + // encryptedBinDataType. + return binData.type == BinDataType::Encrypt && binData.length >= 1 && + static_cast<uint8_t>(encryptedBinDataType()) == + static_cast<const uint8_t*>(binData.data)[0]; + } + /** + * Generate tags from a FLE2 Find Payload. This function takes in a variant of BSONElement and + * Value so that it can be used in both the MatchExpression and Aggregation contexts. Virtual + * functions can't also be templated, which is why we need the runtime dispatch on the variant. 
+ */ + virtual std::vector<PrfBlock> generateTags(BSONValue payload) const = 0; + + /** + * Rewrite to a tag disjunction on the __safeContent__ field. + */ + virtual std::unique_ptr<MatchExpression> rewriteToTagDisjunction( + MatchExpression* expr) const = 0; + /** + * Rewrite to a tag disjunction on the __safeContent__ field. + */ + virtual std::unique_ptr<Expression> rewriteToTagDisjunction(Expression* expr) const = 0; + + /** + * Rewrite to an expression which can generate tags at runtime during an encrypted collscan. + */ + virtual std::unique_ptr<MatchExpression> rewriteToRuntimeComparison( + MatchExpression* expr) const = 0; + /** + * Rewrite to an expression which can generate tags at runtime during an encrypted collscan. + */ + virtual std::unique_ptr<Expression> rewriteToRuntimeComparison(Expression* expr) const = 0; + + const QueryRewriterInterface* _rewriter; + +private: + /** + * Sub-subtype associated with the find payload for this encrypted predicate. + */ + virtual EncryptedBinDataType encryptedBinDataType() const = 0; +}; + +/** + * Encrypted predicate rewrites are registered at startup time using MONGO_INITIALIZER blocks. + * MatchExpression rewrites are keyed on the MatchExpressionType enum, and Agg Expression rewrites + * are keyed on the dynamic type for the Expression subclass. + */ + +using ExpressionToRewriteMap = stdx::unordered_map< + std::type_index, + std::function<std::unique_ptr<Expression>(QueryRewriterInterface*, Expression*)>>; + +extern ExpressionToRewriteMap aggPredicateRewriteMap; + +using MatchTypeToRewriteMap = stdx::unordered_map< + MatchExpression::MatchType, + std::function<std::unique_ptr<MatchExpression>(QueryRewriterInterface*, MatchExpression*)>>; + +extern MatchTypeToRewriteMap matchPredicateRewriteMap; + +/** + * Register an agg rewrite if a condition is true at startup time. 
+ */ +#define REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_GUARDED(className, rewriteClass, isEnabledExpr) \ + MONGO_INITIALIZER(encryptedAggPredicateRewriteFor_##className)(InitializerContext*) { \ + \ + invariant(aggPredicateRewriteMap.find(typeid(className)) == aggPredicateRewriteMap.end()); \ + aggPredicateRewriteMap[typeid(className)] = [&](auto* rewriter, auto* expr) { \ + if (isEnabledExpr) { \ + return rewriteClass{rewriter}.rewrite(expr); \ + } else { \ + return std::unique_ptr<Expression>(nullptr); \ + } \ + }; \ + } + +/** + * Register an agg rewrite unconditionally. + */ +#define REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE(matchType, rewriteClass) \ + REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_GUARDED(matchType, rewriteClass, true) + +/** + * Register an agg rewrite behind a feature flag. + */ +#define REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_WITH_FLAG(matchType, rewriteClass, featureFlag) \ + REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_GUARDED( \ + matchType, rewriteClass, featureFlag.isEnabled(serverGlobalParams.featureCompatibility)) + +/** + * Register a MatchExpression rewrite if a condition is true at startup time. + */ +#define REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_GUARDED(matchType, rewriteClass, isEnabledExpr) \ + MONGO_INITIALIZER(encryptedMatchPredicateRewriteFor_##matchType)(InitializerContext*) { \ + \ + invariant(matchPredicateRewriteMap.find(MatchExpression::matchType) == \ + matchPredicateRewriteMap.end()); \ + matchPredicateRewriteMap[MatchExpression::matchType] = [&](auto* rewriter, auto* expr) { \ + if (isEnabledExpr) { \ + return rewriteClass{rewriter}.rewrite(expr); \ + } else { \ + return std::unique_ptr<MatchExpression>(nullptr); \ + } \ + }; \ + }; +/** + * Register a MatchExpression rewrite unconditionally. 
+ */ +#define REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE(matchType, rewriteClass) \ + REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_GUARDED(matchType, rewriteClass, true) + +/** + * Register a MatchExpression rewrite behind a feature flag. + */ +#define REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_WITH_FLAG(matchType, rewriteClass, featureFlag) \ + REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_GUARDED( \ + matchType, rewriteClass, featureFlag.isEnabled(serverGlobalParams.featureCompatibility)) +} // namespace fle +} // namespace mongo diff --git a/src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h b/src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h new file mode 100644 index 00000000000..88a6c85983a --- /dev/null +++ b/src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h @@ -0,0 +1,106 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. 
If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/pipeline/expression_context_for_test.h" +#include "mongo/db/query/fle/encrypted_predicate.h" +#include "mongo/db/query/fle/equality_predicate.h" +#include "mongo/db/query/fle/query_rewriter_interface.h" +#include "mongo/db/query/fle/range_predicate.h" +#include "mongo/unittest/unittest.h" +#include "mongo/util/overloaded_visitor.h" + +namespace mongo::fle { +using TagMap = std::map<std::pair<StringData, int>, std::vector<PrfBlock>>; + +/* + * The MockServerRewrite allows unit testing individual predicate rewrites without going through the + * real server rewrite that traverses full expression trees. 
+ */ +class MockServerRewrite : public QueryRewriterInterface { +public: + MockServerRewrite() : _expCtx((new ExpressionContextForTest())) {} + const FLEStateCollectionReader* getEscReader() const override { + return nullptr; + } + const FLEStateCollectionReader* getEccReader() const override { + return nullptr; + } + EncryptedCollScanMode getEncryptedCollScanMode() const override { + return _mode; + }; + ExpressionContext* getExpressionContext() const { + return _expCtx.get(); + } + + void setForceEncryptedCollScanForTest() { + _mode = EncryptedCollScanMode::kForceAlways; + } + +private: + boost::intrusive_ptr<ExpressionContextForTest> _expCtx; + EncryptedCollScanMode _mode{EncryptedCollScanMode::kUseIfNeeded}; +}; + +class EncryptedPredicateRewriteTest : public unittest::Test { +public: + EncryptedPredicateRewriteTest() {} + + void setUp() override {} + + void tearDown() override {} + + static std::unique_ptr<MatchExpression> makeInExpr(StringData fieldname, + BSONArray disjunctions) { + auto inExpr = std::make_unique<InMatchExpression>(fieldname); + std::vector<BSONElement> elems; + disjunctions.elems(elems); + uassertStatusOK(inExpr->setEqualities(elems)); + inExpr->setBackingBSON(std::move(disjunctions)); + return inExpr; + } + + /* + * Assertion helper for tag disjunction rewrite. + */ + void assertRewriteToTags(const EncryptedPredicate& pred, + MatchExpression* input, + BSONArray expectedTags) { + + auto actual = pred.rewrite(input); + auto expected = makeInExpr(kSafeContent, expectedTags); + ASSERT_BSONOBJ_EQ(actual->serialize(), + static_cast<MatchExpression*>(expected.get())->serialize()); + } + +protected: + MockServerRewrite _mock{}; +}; +} // namespace mongo::fle diff --git a/src/mongo/db/query/fle/equality_predicate.cpp b/src/mongo/db/query/fle/equality_predicate.cpp new file mode 100644 index 00000000000..fc46bcbfe3a --- /dev/null +++ b/src/mongo/db/query/fle/equality_predicate.cpp @@ -0,0 +1,402 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. 
+ * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#include "equality_predicate.h" + +#include "mongo/crypto/fle_crypto.h" +#include "mongo/crypto/fle_tags.h" +#include "mongo/db/matcher/expression_expr.h" +#include "mongo/db/matcher/expression_leaf.h" +#include "mongo/db/pipeline/expression.h" +#include "mongo/db/query/fle/encrypted_predicate.h" +#include "mongo/util/overloaded_visitor.h" + +namespace mongo::fle { + +REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE(EQ, EqualityPredicate); +REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE(MATCH_IN, EqualityPredicate); +REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE(ExpressionCompare, EqualityPredicate); +REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE(ExpressionIn, EqualityPredicate); + +std::vector<PrfBlock> EqualityPredicate::generateTags(BSONValue payload) const { + ParsedFindEqualityPayload tokens = parseFindPayload<ParsedFindEqualityPayload>(payload); + return readTags(*_rewriter->getEscReader(), + *_rewriter->getEccReader(), + tokens.escToken, + tokens.eccToken, + tokens.edcToken, + tokens.maxCounter); +} + +std::unique_ptr<MatchExpression> EqualityPredicate::rewriteToTagDisjunction( + MatchExpression* expr) const { + switch (expr->matchType()) { + case MatchExpression::EQ: { + auto eqExpr = static_cast<EqualityMatchExpression*>(expr); + auto payload = eqExpr->getData(); + if (!isPayload(payload)) { + return nullptr; + } + auto obj = toBSONArray(generateTags(payload)); + + auto tags = std::vector<BSONElement>(); + obj.elems(tags); + + auto inExpr = std::make_unique<InMatchExpression>(kSafeContent); + inExpr->setBackingBSON(std::move(obj)); + auto status = inExpr->setEqualities(std::move(tags)); + uassertStatusOK(status); + return inExpr; + } + case MatchExpression::MATCH_IN: { + auto inExpr = static_cast<InMatchExpression*>(expr); + size_t numFFPs = 0; + for (auto& eq : inExpr->getEqualities()) { + if (isPayload(eq)) { + ++numFFPs; + } + } + if (numFFPs == 0) { + return nullptr; + } + // All elements in an encrypted $in expression should be FFPs. 
+ uassert(6329400, + "If any elements in a $in expression are encrypted, then all elements should " + "be encrypted.", + numFFPs == inExpr->getEqualities().size()); + auto backingBSONBuilder = BSONArrayBuilder(); + + for (auto& eq : inExpr->getEqualities()) { + auto obj = toBSONArray(generateTags(eq)); + for (auto&& elt : obj) { + backingBSONBuilder.append(elt); + } + } + + auto backingBSON = backingBSONBuilder.arr(); + auto allTags = std::vector<BSONElement>(); + backingBSON.elems(allTags); + + auto newExpr = std::make_unique<InMatchExpression>(kSafeContent); + newExpr->setBackingBSON(std::move(backingBSON)); + auto status = newExpr->setEqualities(std::move(allTags)); + uassertStatusOK(status); + + return newExpr; + } + default: + MONGO_UNREACHABLE_TASSERT(6911300); + } + MONGO_UNREACHABLE_TASSERT(6911301); +}; + +namespace { +template <typename PayloadT> +boost::intrusive_ptr<ExpressionInternalFLEEqual> generateFleEqualMatch(StringData path, + const PayloadT& ffp, + ExpressionContext* expCtx) { + // Generate { $_internalFleEq: { field: "$field_name", server: f_3, counter: cm, edc: k_EDC } + // } + auto tokens = ParsedFindEqualityPayload(ffp); + + uassert(6672401, + "Missing required field server encryption token in find payload", + tokens.serverToken.has_value()); + + return make_intrusive<ExpressionInternalFLEEqual>( + expCtx, + ExpressionFieldPath::createPathFromString( + expCtx, path.toString(), expCtx->variablesParseState), + tokens.serverToken.value().data, + tokens.maxCounter.value_or(0LL), + tokens.edcToken.data); +} + + +template <typename PayloadT> +std::unique_ptr<ExpressionInternalFLEEqual> generateFleEqualMatchUnique(StringData path, + const PayloadT& ffp, + ExpressionContext* expCtx) { + // Generate { $_internalFleEq: { field: "$field_name", server: f_3, counter: cm, edc: k_EDC } } + auto tokens = ParsedFindEqualityPayload(ffp); + + uassert(6672419, + "Missing required field server encryption token in find payload", + 
tokens.serverToken.has_value()); + + return std::make_unique<ExpressionInternalFLEEqual>( + expCtx, + ExpressionFieldPath::createPathFromString( + expCtx, path.toString(), expCtx->variablesParseState), + tokens.serverToken.value().data, + tokens.maxCounter.value_or(0LL), + tokens.edcToken.data); +} + +std::unique_ptr<MatchExpression> generateFleEqualMatchAndExpr(StringData path, + const BSONElement ffp, + ExpressionContext* expCtx) { + auto fleEqualMatch = generateFleEqualMatch(path, ffp, expCtx); + + return std::make_unique<ExprMatchExpression>(fleEqualMatch, expCtx); +} +} // namespace + +std::unique_ptr<MatchExpression> EqualityPredicate::rewriteToRuntimeComparison( + MatchExpression* expr) const { + switch (expr->matchType()) { + case MatchExpression::EQ: { + auto eqExpr = static_cast<EqualityMatchExpression*>(expr); + auto payload = eqExpr->getData(); + if (!isPayload(payload)) { + return nullptr; + } + return generateFleEqualMatchAndExpr( + eqExpr->path(), payload, _rewriter->getExpressionContext()); + } + case MatchExpression::MATCH_IN: { + auto inExpr = static_cast<InMatchExpression*>(expr); + size_t numFFPs = 0; + for (auto& eq : inExpr->getEqualities()) { + if (isPayload(eq)) { + ++numFFPs; + } + } + if (numFFPs == 0) { + return nullptr; + } + uassert(6911300, + "If any elements in a $in expression are encrypted, then all elements should " + "be encrypted.", + numFFPs == inExpr->getEqualities().size()); + std::vector<std::unique_ptr<MatchExpression>> matches; + matches.reserve(numFFPs); + + for (auto& eq : inExpr->getEqualities()) { + auto exprMatch = generateFleEqualMatchAndExpr( + expr->path(), eq, _rewriter->getExpressionContext()); + matches.push_back(std::move(exprMatch)); + } + + auto orExpr = std::make_unique<OrMatchExpression>(std::move(matches)); + return orExpr; + } + default: + MONGO_UNREACHABLE; + } + MONGO_UNREACHABLE; +} + +/* + * Helper function for code shared between tag disjunction and runtime evaluation for the equality + * case. 
+ */ +boost::optional<std::pair<ExpressionFieldPath*, ExpressionConstant*>> +EqualityPredicate::extractDetailsFromComparison(ExpressionCompare* expr) const { + auto equalitiesList = expr->getChildren(); + + auto leftConstant = dynamic_cast<ExpressionConstant*>(equalitiesList[0].get()); + auto rightConstant = dynamic_cast<ExpressionConstant*>(equalitiesList[1].get()); + + bool isLeftFFP = leftConstant && isPayload(leftConstant->getValue()); + bool isRightFFP = rightConstant && isPayload(rightConstant->getValue()); + + uassert(6334100, + "Cannot compare two encrypted constants to each other", + !(isLeftFFP && isRightFFP)); + + // No FLE Find Payload + if (!isLeftFFP && !isRightFFP) { + return boost::none; + } + + auto leftFieldPath = dynamic_cast<ExpressionFieldPath*>(equalitiesList[0].get()); + auto rightFieldPath = dynamic_cast<ExpressionFieldPath*>(equalitiesList[1].get()); + + uassert(6672413, + "Queryable Encryption only supports comparisons between a field path and a constant", + leftFieldPath || rightFieldPath); + auto fieldPath = leftFieldPath ? leftFieldPath : rightFieldPath; + auto constChild = isLeftFFP ? leftConstant : rightConstant; + return {{fieldPath, constChild}}; +} + +/** + * Perform validation on $in agg expressions, and return a pre-processed fieldpath expression for + * use by the rewrite. This factors out common validation code for the runtime and tag rewrite + * cases. + */ +boost::optional<const ExpressionFieldPath*> EqualityPredicate::validateIn( + ExpressionIn* inExpr, ExpressionArray* inList) const { + auto leftExpr = inExpr->getOperandList()[0].get(); + auto& equalitiesList = inList->getChildren(); + size_t numFFPs = 0; + + for (auto& equality : equalitiesList) { + // For each expression representing a FleFindPayload... + if (auto constChild = dynamic_cast<ExpressionConstant*>(equality.get())) { + if (!isPayload(constChild->getValue())) { + continue; + } + + numFFPs++; + } + } + + // If no FLE find payloads were found, there is nothing to rewrite. 
+    if (numFFPs == 0) {
+        return boost::none;
+    }
+
+    uassert(6334102,
+            "If any elements in an comparison expression are encrypted, then all elements "
+            "should "
+            "be encrypted.",
+            numFFPs == equalitiesList.size());
+
+    auto leftFieldPath = dynamic_cast<const ExpressionFieldPath*>(leftExpr);
+    uassert(6672417,
+            "$in is only supported with Queryable Encryption when the first argument is a "
+            "field path",
+            leftFieldPath != nullptr);
+    return leftFieldPath;
+}
+
+/**
+ * Index-backed rewrite for aggregation expressions. An encrypted $eq/$ne comparison or an
+ * encrypted $in list becomes an $or of {$in: [tag, "$__safeContent__"]} checks, one per generated
+ * tag; for $ne the $or is wrapped in a $not.
+ *
+ * Returns nullptr when the comparison operator is neither EQ nor NE, when
+ * extractDetailsFromComparison() or validateIn() yields nothing, or when the second operand of
+ * $in is not an array expression. Any other expression type trips tassert 6911303.
+ */
+std::unique_ptr<Expression> EqualityPredicate::rewriteToTagDisjunction(Expression* expr) const {
+    auto* const expCtx = _rewriter->getExpressionContext();
+
+    // Appends {$in: [tag, "$__safeContent__"]} to 'disjuncts' for every tag generated from the
+    // payload stored in 'ffp'.
+    auto appendTagChecks = [&](ExpressionConstant* ffp,
+                               std::vector<boost::intrusive_ptr<Expression>>& disjuncts) {
+        auto payload = ffp->getValue();
+        auto tags = toValues(generateTags(std::ref(payload)));
+        for (auto&& tag : tags) {
+            std::vector<boost::intrusive_ptr<Expression>> operands{
+                ExpressionConstant::create(expCtx, tag),
+                ExpressionFieldPath::createPathFromString(
+                    expCtx, kSafeContent, expCtx->variablesParseState)};
+            disjuncts.push_back(make_intrusive<ExpressionIn>(expCtx, std::move(operands)));
+        }
+    };
+
+    if (auto compareExpr = dynamic_cast<ExpressionCompare*>(expr)) {
+        const auto op = compareExpr->getOp();
+        if (op != ExpressionCompare::EQ && op != ExpressionCompare::NE) {
+            return nullptr;
+        }
+        auto details = extractDetailsFromComparison(compareExpr);
+        if (!details) {
+            return nullptr;
+        }
+
+        std::vector<boost::intrusive_ptr<Expression>> tagChecks;
+        appendTagChecks(details.value().second, tagChecks);
+        auto disjunction = std::make_unique<ExpressionOr>(expCtx, std::move(tagChecks));
+        if (op != ExpressionCompare::NE) {
+            return disjunction;
+        }
+        // $ne is expressed as the negation of the tag disjunction.
+        std::vector<boost::intrusive_ptr<Expression>> negated{disjunction.release()};
+        return std::make_unique<ExpressionNot>(expCtx, std::move(negated));
+    }
+
+    if (auto inExpr = dynamic_cast<ExpressionIn*>(expr)) {
+        auto inList = dynamic_cast<ExpressionArray*>(inExpr->getOperandList()[1].get());
+        if (!inList || !validateIn(inExpr, inList)) {
+            return nullptr;
+        }
+
+        std::vector<boost::intrusive_ptr<Expression>> tagChecks;
+        for (auto& candidate : inList->getChildren()) {
+            // Each constant in the list is treated as a FleFindPayload and expanded to its tags.
+            if (auto ffp = dynamic_cast<ExpressionConstant*>(candidate.get())) {
+                appendTagChecks(ffp, tagChecks);
+            }
+        }
+        return std::make_unique<ExpressionOr>(expCtx, std::move(tagChecks));
+    }
+
+    MONGO_UNREACHABLE_TASSERT(6911303);
+}
+
+/**
+ * Collection-scan rewrite for aggregation expressions. An encrypted $eq/$ne becomes the
+ * expression produced by generateFleEqualMatchUnique() (wrapped in $not for $ne); an encrypted
+ * $in becomes an $or with one generateFleEqualMatch() expression per constant in the list.
+ *
+ * Returns nullptr under the same conditions as rewriteToTagDisjunction(). Any other expression
+ * type trips tassert 6911304.
+ */
+std::unique_ptr<Expression> EqualityPredicate::rewriteToRuntimeComparison(Expression* expr) const {
+    auto* const expCtx = _rewriter->getExpressionContext();
+
+    if (auto compareExpr = dynamic_cast<ExpressionCompare*>(expr)) {
+        const auto op = compareExpr->getOp();
+        if (op != ExpressionCompare::EQ && op != ExpressionCompare::NE) {
+            return nullptr;
+        }
+        auto details = extractDetailsFromComparison(compareExpr);
+        if (!details) {
+            return nullptr;
+        }
+        auto [fieldPath, constChild] = details.value();
+        auto fleEqualExpr = generateFleEqualMatchUnique(
+            fieldPath->getFieldPathWithoutCurrentPrefix().fullPath(),
+            constChild->getValue(),
+            expCtx);
+        if (op != ExpressionCompare::NE) {
+            return fleEqualExpr;
+        }
+        std::vector<boost::intrusive_ptr<Expression>> negated{fleEqualExpr.release()};
+        return std::make_unique<ExpressionNot>(expCtx, std::move(negated));
+    }
+
+    if (auto inExpr = dynamic_cast<ExpressionIn*>(expr)) {
+        auto inList = dynamic_cast<ExpressionArray*>(inExpr->getOperandList()[1].get());
+        if (!inList) {
+            return nullptr;
+        }
+        auto leftFieldPath = validateIn(inExpr, inList);
+        if (!leftFieldPath) {
+            return nullptr;
+        }
+
+        std::vector<boost::intrusive_ptr<Expression>> comparisons;
+        for (auto& candidate : inList->getChildren()) {
+            if (auto ffp = dynamic_cast<ExpressionConstant*>(candidate.get())) {
+                comparisons.push_back(generateFleEqualMatch(
+                    leftFieldPath.value()->getFieldPathWithoutCurrentPrefix().fullPath(),
+                    ffp->getValue(),
+                    expCtx));
+            }
+        }
+        return std::make_unique<ExpressionOr>(expCtx, std::move(comparisons));
+    }
+
+    MONGO_UNREACHABLE_TASSERT(6911304);
+}
+}  // namespace mongo::fle
diff --git a/src/mongo/db/query/fle/equality_predicate.h b/src/mongo/db/query/fle/equality_predicate.h
new file mode 100644
index 00000000000..f7630c1d5aa
--- /dev/null
+++ b/src/mongo/db/query/fle/equality_predicate.h
@@ -0,0 +1,65 @@
+/**
+ *    Copyright (C) 2022-present MongoDB, Inc.
+ *
+ *    This program is free software: you can redistribute it and/or modify
+ *    it under the terms of the Server Side Public License, version 1,
+ *    as published by MongoDB, Inc.
+ *
+ *    This program is distributed in the hope that it will be useful,
+ *    but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ *    Server Side Public License for more details.
+ *
+ *    You should have received a copy of the Server Side Public License
+ *    along with this program. If not, see
+ *    <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ *    As a special exception, the copyright holders give permission to link the
+ *    code of portions of this program with the OpenSSL library under certain
+ *    conditions as described in each individual source file and distribute
+ *    linked combinations including the program with the OpenSSL library. You
+ *    must comply with the Server Side Public License in all respects for
+ *    all of the code used other than as permitted herein. If you modify file(s)
+ *    with this exception, you may extend this exception to your version of the
+ *    file(s), but you are not obligated to do so. If you do not wish to do so,
+ *    delete this exception statement from your version. If you delete this
+ *    exception statement from all source files in the program, then also delete
+ *    it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/crypto/fle_crypto.h"
+#include "mongo/db/query/fle/encrypted_predicate.h"
+
+namespace mongo::fle {
+/**
+ * Server-side rewrite for the encrypted equality index. This rewrite expects either a $eq or $in
+ * expression. The aggregation-expression overloads additionally accept $ne, which is rewritten as
+ * the negation of the corresponding $eq rewrite.
+ */
+class EqualityPredicate : public EncryptedPredicate {
+public:
+    // 'explicit' so that a rewriter pointer is never silently converted into a predicate.
+    explicit EqualityPredicate(const QueryRewriterInterface* rewriter)
+        : EncryptedPredicate(rewriter) {}
+
+protected:
+    /**
+     * Produces the tags for a single find payload.
+     */
+    std::vector<PrfBlock> generateTags(BSONValue payload) const override;
+
+    /**
+     * Rewrites that replace the encrypted predicate with checks against generated tags, for
+     * MatchExpressions and aggregation Expressions respectively.
+     */
+    std::unique_ptr<MatchExpression> rewriteToTagDisjunction(MatchExpression* expr) const override;
+    std::unique_ptr<Expression> rewriteToTagDisjunction(Expression* expr) const override;
+
+    /**
+     * Rewrites that replace the encrypted predicate with a comparison evaluated at runtime, for
+     * MatchExpressions and aggregation Expressions respectively.
+     */
+    std::unique_ptr<MatchExpression> rewriteToRuntimeComparison(
+        MatchExpression* expr) const override;
+    std::unique_ptr<Expression> rewriteToRuntimeComparison(Expression* expr) const override;
+
+private:
+    // The payload type this predicate is responsible for.
+    EncryptedBinDataType encryptedBinDataType() const override {
+        return EncryptedBinDataType::kFLE2FindEqualityPayload;
+    }
+
+    // Returns the field path and constant operand of a comparison, or boost::none if the
+    // comparison should not be rewritten.
+    boost::optional<std::pair<ExpressionFieldPath*, ExpressionConstant*>>
+    extractDetailsFromComparison(ExpressionCompare* eqExpr) const;
+
+    // Validates an $in expression whose second operand is 'inList'. Returns the left-hand field
+    // path when the list should be rewritten and boost::none when it holds no find payloads.
+    boost::optional<const ExpressionFieldPath*> validateIn(ExpressionIn* inExpr,
+                                                           ExpressionArray* inList) const;
+};
+}  // namespace mongo::fle
diff --git a/src/mongo/db/query/fle/equality_predicate_test.cpp b/src/mongo/db/query/fle/equality_predicate_test.cpp
new file mode 100644
index 00000000000..5a8ee9f5e70
--- /dev/null
+++ b/src/mongo/db/query/fle/equality_predicate_test.cpp
@@ -0,0 +1,465 @@
+/**
+ *    Copyright (C) 2022-present MongoDB, Inc.
+ *
+ *    This program is free software: you can redistribute it and/or modify
+ *    it under the terms of the Server Side Public License, version 1,
+ *    as published by MongoDB, Inc.
+ *
+ *    This program is distributed in the hope that it will be useful,
+ *    but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */
+
+#include "mongo/bson/bsonmisc.h"
+#include "mongo/bson/json.h"
+#include "mongo/crypto/fle_crypto.h"
+#include "mongo/db/exec/document_value/value.h"
+#include "mongo/db/matcher/expression_expr.h"
+#include "mongo/db/matcher/expression_leaf.h"
+#include "mongo/db/matcher/expression_tree.h"
+#include "mongo/db/matcher/expression_visitor.h"
+#include "mongo/db/pipeline/expression.h"
+#include "mongo/db/query/fle/encrypted_predicate_test_fixtures.h"
+#include "mongo/db/query/fle/equality_predicate.h"
+#include "mongo/idl/server_parameter_test_util.h"
+#include "mongo/unittest/unittest.h"
+
+namespace mongo::fle {
+namespace {
+/**
+ * EqualityPredicate whose payload detection and tag generation are driven by tables supplied by
+ * the test instead of real find payloads.
+ */
+class MockEqualityPredicate : public EqualityPredicate {
+public:
+    explicit MockEqualityPredicate(const QueryRewriterInterface* rewriter)
+        : EqualityPredicate(rewriter) {}
+    MockEqualityPredicate(const QueryRewriterInterface* rewriter,
+                          TagMap tags,
+                          std::set<StringData> encryptedFields)
+        : EqualityPredicate(rewriter), _tags(tags), _encryptedFields(encryptedFields) {}
+
+    // Registers 'tags' for the (field name, value) pair and marks the field as encrypted.
+    void setEncryptedTags(std::pair<StringData, int> fieldvalue, std::vector<PrfBlock> tags) {
+        _encryptedFields.insert(fieldvalue.first);
+        _tags[fieldvalue] = tags;
+    }
+
+    void addEncryptedField(StringData field) {
+        _encryptedFields.insert(field);
+    }
+
+
+protected:
+    // A BSON element counts as a payload when its field name was registered as encrypted.
+    bool isPayload(const BSONElement& elt) const override {
+        return _encryptedFields.find(elt.fieldNameStringData()) != _encryptedFields.end();
+    }
+
+    // Every Value is treated as a payload.
+    bool isPayload(const Value& v) const override {
+        return true;
+    }
+
+    // Looks up the registered tags for a BSON element; Values always yield no tags.
+    std::vector<PrfBlock> generateTags(BSONValue payload) const override {
+        return stdx::visit(
+            OverloadedVisitor{
+                [&](BSONElement p) {
+                    ASSERT(p.isNumber());  // Only accept numbers as mock FFPs.
+                    ASSERT(_tags.find({p.fieldNameStringData(), p.Int()}) != _tags.end());
+                    return _tags.find({p.fieldNameStringData(), p.Int()})->second;
+                },
+                [&](std::reference_wrapper<Value> v) { return std::vector<PrfBlock>{}; }},
+            payload);
+    }
+
+private:
+    TagMap _tags;
+    std::set<StringData> _encryptedFields;
+};
+
+class EqualityPredicateRewriteTest : public EncryptedPredicateRewriteTest {
+public:
+    EqualityPredicateRewriteTest() : _predicate(&_mock) {}
+
+protected:
+    MockEqualityPredicate _predicate;
+};
+
+TEST_F(EqualityPredicateRewriteTest, Equality_NoFFP) {
+    std::unique_ptr<MatchExpression> input =
+        std::make_unique<EqualityMatchExpression>("ssn", Value(5));
+    auto expected = EqualityMatchExpression("ssn", Value(5));
+
+    auto result = _predicate.rewrite(input.get());
+    ASSERT(result == nullptr);
+    ASSERT(input->equivalent(&expected));
+}
+
+TEST_F(EqualityPredicateRewriteTest, In_NoFFP) {
+    auto input = makeInExpr("name",
+                            BSON_ARRAY("harry"
+                                       << "ron"
+                                       << "hermione"));
+    auto expected = makeInExpr("name",
+                               BSON_ARRAY("harry"
+                                          << "ron"
+                                          << "hermione"));
+
+    auto result = _predicate.rewrite(input.get());
+    ASSERT(result == nullptr);
+    ASSERT(input->equivalent(expected.get()));
+}
+
+TEST_F(EqualityPredicateRewriteTest, Equality_Basic) {
+    auto input = EqualityMatchExpression("ssn", Value(5));
+    std::vector<PrfBlock> tags = {{1}, {2}, {3}};
+
+    _predicate.setEncryptedTags({"ssn", 5}, tags);
+
+    assertRewriteToTags(_predicate, &input, toBSONArray(std::move(tags)));
+}
+
+TEST_F(EqualityPredicateRewriteTest, In_Basic) {
+    auto input = makeInExpr("ssn", BSON_ARRAY(2 << 4 << 6));
+
+    _predicate.setEncryptedTags({"0", 2}, {{1}, {2}});
+    _predicate.setEncryptedTags({"1", 4}, {{5}, {3}});
+    _predicate.setEncryptedTags({"2", 6}, {{99}, {100}});
+
+    assertRewriteToTags(_predicate, input.get(), toBSONArray({{1}, {2}, {3}, {5}, {99}, {100}}));
+}
+
+TEST_F(EqualityPredicateRewriteTest, In_NotAllFFPs) {
+    auto input = makeInExpr("ssn", BSON_ARRAY(2 << 4 << 6));
+
+    _predicate.setEncryptedTags({"0", 2}, {{1}, {2}});
+    _predicate.setEncryptedTags({"1", 4}, {{5}, {3}});
+
+    ASSERT_THROWS_CODE(
+        assertRewriteToTags(_predicate, input.get(), toBSONArray({{1}, {2}, {3}, {5}})),
+        AssertionException,
+        6329400);
+}
+
+// Serializes 't' to BSON and prefixes the bytes with the encrypted bindata type marker.
+template <typename T>
+std::vector<uint8_t> toEncryptedVector(EncryptedBinDataType dt, T t) {
+    BSONObj obj = t.toBSON();
+
+    std::vector<uint8_t> buf(obj.objsize() + 1);
+    buf[0] = static_cast<uint8_t>(dt);
+
+    std::copy(obj.objdata(), obj.objdata() + obj.objsize(), buf.data() + 1);
+
+    return buf;
+}
+
+// Appends 't' to 'builder' under 'field' as BinData of subtype Encrypt.
+template <typename T>
+void toEncryptedBinData(StringData field, EncryptedBinDataType dt, T t, BSONObjBuilder* builder) {
+    auto buf = toEncryptedVector(dt, t);
+
+    builder->appendBinData(field, buf.size(), BinDataType::Encrypt, buf.data());
+}
+
+constexpr auto kIndexKeyId = "12345678-1234-9876-1234-123456789012"_sd;
+constexpr auto kUserKeyId = "ABCDEFAB-1234-9876-1234-123456789012"_sd;
+static UUID indexKeyId = uassertStatusOK(UUID::parse(kIndexKeyId.toString()));
+static UUID userKeyId = uassertStatusOK(UUID::parse(kUserKeyId.toString()));
+
+std::vector<char> testValue = {0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19};
+std::vector<char> testValue2 = {0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29};
+
+const FLEIndexKey& getIndexKey() {
+    static std::string indexVec = hexblob::decode(
+        "7dbfebc619aa68a659f64b8e23ccd21644ac326cb74a26840c3d2420176c40ae088294d00ad6cae9684237b21b754cf503f085c25cd320bf035c3417416e1e6fe3d9219f79586582112740b2add88e1030d91926ae8afc13ee575cfb8bb965b7"_sd);
+    static FLEIndexKey indexKey(KeyMaterial(indexVec.begin(), indexVec.end()));
+    return indexKey;
+}
+
+const FLEUserKey& getUserKey() {
+    static std::string userVec = hexblob::decode(
+        "a7ddbc4c8be00d51f68d9d8e485f351c8edc8d2206b24d8e0e1816d005fbe520e489125047d647b0d8684bfbdbf09c304085ed086aba6c2b2b1677ccc91ced8847a733bf5e5682c84b3ee7969e4a5fe0e0c21e5e3ee190595a55f83147d8de2a"_sd);
+    static FLEUserKey userKey(KeyMaterial(userVec.begin(), userVec.end()));
+    return userKey;
+}
+
+
+// Builds {<path>: BinData(Encrypt, <find equality payload for 'value'>)}.
+BSONObj generateFFP(StringData path, int value) {
+    // Bind by reference: the keys are function-local statics and need not be copied.
+    const auto& indexKey = getIndexKey();
+    FLEIndexKeyAndId indexKeyAndId(indexKey.data, indexKeyId);
+    const auto& userKey = getUserKey();
+    // NOTE(review): the user key is paired with indexKeyId although userKeyId is defined above.
+    // This looks like a copy/paste slip; confirm before changing, since the expected payloads in
+    // the tests below are pinned.
+    FLEUserKeyAndId userKeyAndId(userKey.data, indexKeyId);
+
+    BSONObj doc = BSON("value" << value);
+    auto element = doc.firstElement();
+    auto fpp = FLEClientCrypto::serializeFindPayload(indexKeyAndId, userKeyAndId, element, 0);
+
+    BSONObjBuilder builder;
+    toEncryptedBinData(path, EncryptedBinDataType::kFLE2FindEqualityPayload, fpp, &builder);
+    return builder.obj();
+}
+
+std::unique_ptr<MatchExpression> generateEqualityWithFFP(StringData path, int value) {
+    auto ffp = generateFFP(path, value);
+    return std::make_unique<EqualityMatchExpression>(path, ffp.firstElement());
+}
+
+std::unique_ptr<Expression> generateEqualityWithFFP(ExpressionContext* const expCtx,
+                                                    StringData path,
+                                                    int value) {
+    auto ffp = Value(generateFFP(path, value).firstElement());
+    auto ffpExpr = make_intrusive<ExpressionConstant>(expCtx, ffp);
+    auto fieldpath = ExpressionFieldPath::createPathFromString(
+        expCtx, path.toString(), expCtx->variablesParseState);
+    std::vector<boost::intrusive_ptr<Expression>> children = {std::move(fieldpath),
+                                                              std::move(ffpExpr)};
+    return std::make_unique<ExpressionCompare>(expCtx, ExpressionCompare::EQ, std::move(children));
+}
+
+std::unique_ptr<MatchExpression> generateDisjunctionWithFFP(StringData path,
+                                                            std::initializer_list<int> vals) {
+    BSONArrayBuilder bab;
+    for (const int value : vals) {
+        bab.append(generateFFP(path, value).firstElement());
+    }
+    auto arr = bab.arr();
+    return EncryptedPredicateRewriteTest::makeInExpr(path, arr);
+}
+
+std::unique_ptr<Expression> generateDisjunctionWithFFP(ExpressionContext* const expCtx,
+                                                       StringData path,
+                                                       std::initializer_list<int> values) {
+    std::vector<boost::intrusive_ptr<Expression>> ffps;
+    for (const int value : values) {
+        auto ffp = make_intrusive<ExpressionConstant>(
+            expCtx, Value(generateFFP(path, value).firstElement()));
+        ffps.emplace_back(std::move(ffp));
+    }
+    auto ffpArray = make_intrusive<ExpressionArray>(expCtx, std::move(ffps));
+    auto fieldpath = ExpressionFieldPath::createPathFromString(
+        expCtx, path.toString(), expCtx->variablesParseState);
+    std::vector<boost::intrusive_ptr<Expression>> children{std::move(fieldpath),
+                                                           std::move(ffpArray)};
+    return std::make_unique<ExpressionIn>(expCtx, std::move(children));
+}
+
+// Fixture whose mock rewriter is forced into encrypted collection scan mode.
+class EqualityPredicateCollScanRewriteTest : public EncryptedPredicateRewriteTest {
+public:
+    EqualityPredicateCollScanRewriteTest() : _predicate(&_mock) {
+        _mock.setForceEncryptedCollScanForTest();
+    }
+
+protected:
+    MockEqualityPredicate _predicate;
+};
+
+TEST_F(EqualityPredicateCollScanRewriteTest, Eq_Match) {
+    auto input = generateEqualityWithFFP("ssn", 1);
+    _predicate.addEncryptedField("ssn");
+
+    auto result = _predicate.rewrite(input.get());
+
+    ASSERT(result);
+    ASSERT_EQ(result->matchType(), MatchExpression::EXPRESSION);
+    auto* expr = static_cast<ExprMatchExpression*>(result.get());
+    auto aggExpr = expr->getExpression();
+
+    auto expected = fromjson(R"({
+        "$_internalFleEq": {
+            "field": "$ssn",
+            "edc": {
+                "$binary": {
+                    "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg",
+                    "subType": "6"
+                }
+            },
+            "counter": {
+                "$numberLong": "0"
+            },
+            "server": {
+                "$binary": {
+                    "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
+                    "subType": "6"
+                }
+            }
+        }
+    })");
+    ASSERT_BSONOBJ_EQ(aggExpr->serialize(false).getDocument().toBson(), expected);
+}
+
+TEST_F(EqualityPredicateCollScanRewriteTest, Eq_Expr) {
+    auto expCtx = _mock.getExpressionContext();
+    auto input = generateEqualityWithFFP(expCtx, "ssn", 1);
+    _predicate.addEncryptedField("ssn");
+
+    auto result = _predicate.rewrite(input.get());
+
+    auto expected = fromjson(R"({
+        "$_internalFleEq": {
+            "field": "$ssn",
+            "edc": {
+                "$binary": {
+                    "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg",
+                    "subType": "6"
+                }
+            },
+            "counter": {
+                "$numberLong": "0"
+            },
+            "server": {
+                "$binary": {
+                    "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
+                    "subType": "6"
+                }
+            }
+        }
+    })");
+    ASSERT(result);
+    ASSERT_BSONOBJ_EQ(result->serialize(false).getDocument().toBson(), expected);
+}
+
+TEST_F(EqualityPredicateCollScanRewriteTest, In_Match) {
+    auto input = generateDisjunctionWithFFP("ssn", {1, 2, 3});
+    _predicate.addEncryptedField("0");
+    _predicate.addEncryptedField("1");
+    _predicate.addEncryptedField("2");
+
+    auto result = _predicate.rewrite(input.get());
+
+    ASSERT(result);
+    ASSERT_EQ(result->matchType(), MatchExpression::OR);
+    auto expected = fromjson(R"({
+    "$or": [
+        {
+            "$expr": {
+                "$_internalFleEq": {
+                    "field": "$ssn",
+                    "edc": {
+                        "$binary": {
+                            "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg",
+                            "subType": "6"
+                        }
+                    },
+                    "counter": {
+                        "$numberLong": "0"
+                    },
+                    "server": {
+                        "$binary": {
+                            "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
+                            "subType": "6"
+                        }
+                    }
+                }
+            }
+        },
+        {
+            "$expr": {
+                "$_internalFleEq": {
+                    "field": "$ssn",
+                    "edc": {
+                        "$binary": {
+                            "base64": "CLpCo6rNuYMVT+6n1HCX15MNrVYDNqf6udO46ayo43Sw",
+                            "subType": "6"
+                        }
+                    },
+                    "counter": {
+                        "$numberLong": "0"
+                    },
+                    "server": {
+                        "$binary": {
+                            "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
+                            "subType": "6"
+                        }
+                    }
+                }
+            }
+        },
+        {
+            "$expr": {
+                "$_internalFleEq": {
+                    "field": "$ssn",
+                    "edc": {
+                        "$binary": {
+                            "base64": "CPi44oCQHnNDeRqHsNLzbdCeHt2DK/wCly0g2dxU5fqN",
+                            "subType": "6"
+                        }
+                    },
+                    "counter": {
+                        "$numberLong": "0"
+                    },
+                    "server": {
+                        "$binary": {
+                            "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
+                            "subType": "6"
+                        }
+                    }
+                }
+            }
+        }
+    ]
+})");
+    ASSERT_BSONOBJ_EQ(result->serialize(), expected);
+}
+
+TEST_F(EqualityPredicateCollScanRewriteTest, In_Expr) {
+    auto input = generateDisjunctionWithFFP(_mock.getExpressionContext(), "ssn", {1, 1});
+    _predicate.addEncryptedField("0");
+    _predicate.addEncryptedField("1");
+    _predicate.addEncryptedField("2");
+
+    auto result = _predicate.rewrite(input.get());
+
+    ASSERT(result);
+    auto expected = fromjson(R"({ "$or" : [ {
+                "$_internalFleEq": {
+                    "field": "$ssn",
+                    "edc": {
+                        "$binary": {
+                            "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg",
+                            "subType": "6"
+                        }
+                    },
+                    "counter": {
+                        "$numberLong": "0"
+                    },
+                    "server": {
+                        "$binary": {
+                            "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
+                            "subType": "6"
+                        }
+                    }
+                }},
+                {
+                "$_internalFleEq": {
+                    "field": "$ssn",
+                    "edc": {
+                        "$binary": {
+                            "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg",
+                            "subType": "6"
+                        }
+                    },
+                    "counter": {
+                        "$numberLong": "0"
+                    },
+                    "server": {
+                        "$binary": {
+                            "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps",
+                            "subType": "6"
+                        }
+                    }
+                }}
+            ]})");
+    ASSERT_BSONOBJ_EQ(result->serialize(false).getDocument().toBson(), expected);
+}
+
+}  // namespace
+}  // namespace mongo::fle
diff --git a/src/mongo/db/query/fle/query_rewriter.cpp b/src/mongo/db/query/fle/query_rewriter.cpp
new file mode 100644
index 00000000000..441f436ec00
--- /dev/null
+++ b/src/mongo/db/query/fle/query_rewriter.cpp
@@ -0,0 +1,124 @@
+/**
+ *    Copyright (C) 2022-present MongoDB, Inc.
+ *
+ *    This program is free software: you can redistribute it and/or modify
+ *    it under the terms of the Server Side Public License, version 1,
+ *    as published by MongoDB, Inc.
+ *
+ *    This program is distributed in the hope that it will be useful,
+ *    but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ *    Server Side Public License for more details.
+ *
+ *    You should have received a copy of the Server Side Public License
+ *    along with this program. If not, see
+ *    <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ *    As a special exception, the copyright holders give permission to link the
+ *    code of portions of this program with the OpenSSL library under certain
+ *    conditions as described in each individual source file and distribute
+ *    linked combinations including the program with the OpenSSL library. You
+ *    must comply with the Server Side Public License in all respects for
+ *    all of the code used other than as permitted herein. If you modify file(s)
+ *    with this exception, you may extend this exception to your version of the
+ *    file(s), but you are not obligated to do so. If you do not wish to do so,
+ *    delete this exception statement from your version. If you delete this
+ *    exception statement from all source files in the program, then also delete
+ *    it in the license file.
+ */
+
+#include "query_rewriter.h"
+
+#include "mongo/db/matcher/expression_expr.h"
+#include "mongo/db/matcher/expression_parser.h"
+
+namespace mongo::fle {
+namespace {
+
+/**
+ * Visitor handed to expression_walker::walk(). For each visited expression it looks up a rewrite
+ * keyed by the expression's dynamic type and applies it, remembering whether any rewrite produced
+ * a replacement. It is only used in this file, hence the unnamed namespace.
+ */
+class ExpressionRewriter {
+public:
+    ExpressionRewriter(QueryRewriter* queryRewriter, const ExpressionToRewriteMap& exprRewrites)
+        : queryRewriter(queryRewriter), exprRewrites(exprRewrites) {}
+
+    /**
+     * Returns the replacement for 'exp', or nullptr when no rewrite is registered for its type or
+     * the registered rewrite declined to change it.
+     */
+    std::unique_ptr<Expression> postVisit(Expression* exp) {
+        if (auto rewrite = exprRewrites.find(typeid(*exp)); rewrite != exprRewrites.end()) {
+            auto expr = rewrite->second(queryRewriter, exp);
+            if (expr != nullptr) {
+                didRewrite = true;
+            }
+            return expr;
+        }
+        return nullptr;
+    }
+
+    QueryRewriter* queryRewriter;
+    const ExpressionToRewriteMap& exprRewrites;
+    // Set once any visited expression has been replaced.
+    bool didRewrite = false;
+};
+
+}  // namespace
+
+std::unique_ptr<Expression> QueryRewriter::rewriteExpression(Expression* expression) {
+    tassert(6334104, "Expected an expression to rewrite but found none", expression);
+
+    ExpressionRewriter expressionRewriter{this, this->_exprRewrites};
+    auto res = expression_walker::walk<Expression>(expression, &expressionRewriter);
+    // Overwrites (rather than accumulates) the flag; callers that need the previous value must
+    // save it first, as _rewrite() does.
+    _rewroteLastExpression = expressionRewriter.didRewrite;
+    return res;
+}
+
+boost::optional<BSONObj> QueryRewriter::rewriteMatchExpression(const BSONObj& filter) {
+    auto expr = uassertStatusOK(MatchExpressionParser::parse(filter, _expCtx));
+
+    _rewroteLastExpression = false;
+    if (auto res = _rewrite(expr.get())) {
+        // The rewrite resulted in top-level changes. Serialize the new expression.
+        return res->serialize().getOwned();
+    } else if (_rewroteLastExpression) {
+        // The rewrite had no top-level changes, but nested expressions were rewritten. Serialize
+        // the parsed expression, which has in-place changes.
+        return expr->serialize().getOwned();
+    }
+
+    // No rewrites were done.
+    return boost::none;
+}
+
+std::unique_ptr<MatchExpression> QueryRewriter::_rewrite(MatchExpression* expr) {
+    switch (expr->matchType()) {
+        case MatchExpression::AND:
+        case MatchExpression::OR:
+        case MatchExpression::NOT:
+        case MatchExpression::NOR: {
+            // Logical nodes are never replaced themselves; rewritten children are swapped in
+            // place.
+            for (size_t i = 0; i < expr->numChildren(); i++) {
+                auto child = expr->getChild(i);
+                if (auto newChild = _rewrite(child)) {
+                    expr->resetChild(i, newChild.release());
+                }
+            }
+            return nullptr;
+        }
+        case MatchExpression::EXPRESSION: {
+            // Save the current value of _rewroteLastExpression, since rewriteExpression() may
+            // reset it to false and we may have already done a match expression rewrite.
+            auto didRewrite = _rewroteLastExpression;
+            auto rewritten =
+                rewriteExpression(static_cast<ExprMatchExpression*>(expr)->getExpression().get());
+            _rewroteLastExpression |= didRewrite;
+            if (rewritten) {
+                return std::make_unique<ExprMatchExpression>(rewritten.release(),
+                                                             getExpressionContext());
+            }
+            return nullptr;
+        }
+        default: {
+            // Leaf nodes: apply the rewrite registered for this match type, if any.
+            if (auto rewrite = _matchRewrites.find(expr->matchType());
+                rewrite != _matchRewrites.end()) {
+                auto rewritten = rewrite->second(this, expr);
+                if (rewritten != nullptr) {
+                    _rewroteLastExpression = true;
+                }
+                return rewritten;
+            }
+            return nullptr;
+        }
+    }
+}
+}  // namespace mongo::fle
diff --git a/src/mongo/db/query/fle/query_rewriter.h b/src/mongo/db/query/fle/query_rewriter.h
new file mode 100644
index 00000000000..cf8470a84ac
--- /dev/null
+++ b/src/mongo/db/query/fle/query_rewriter.h
@@ -0,0 +1,156 @@
+/**
+ *    Copyright (C) 2022-present MongoDB, Inc.
+ *
+ *    This program is free software: you can redistribute it and/or modify
+ *    it under the terms of the Server Side Public License, version 1,
+ *    as published by MongoDB, Inc.
+ *
+ *    This program is distributed in the hope that it will be useful,
+ *    but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ *    Server Side Public License for more details.
+ *
+ *    You should have received a copy of the Server Side Public License
+ *    along with this program. If not, see
+ *    <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ *    As a special exception, the copyright holders give permission to link the
+ *    code of portions of this program with the OpenSSL library under certain
+ *    conditions as described in each individual source file and distribute
+ *    linked combinations including the program with the OpenSSL library. You
+ *    must comply with the Server Side Public License in all respects for
+ *    all of the code used other than as permitted herein.
If you modify file(s)
+ *    with this exception, you may extend this exception to your version of the
+ *    file(s), but you are not obligated to do so. If you do not wish to do so,
+ *    delete this exception statement from your version. If you delete this
+ *    exception statement from all source files in the program, then also delete
+ *    it in the license file.
+ */
+
+#pragma once
+
+#include <utility>
+
+#include "mongo/bson/bsonobj.h"
+#include "mongo/crypto/fle_crypto.h"
+#include "mongo/db/pipeline/expression_context.h"
+#include "mongo/db/query/fle/encrypted_predicate.h"
+#include "mongo/db/query/fle/query_rewriter_interface.h"
+
+namespace mongo::fle {
+/**
+ * Class which handles traversing expressions and rewriting predicates for FLE2.
+ *
+ * The QueryRewriter is responsible for traversing Agg Expressions and MatchExpression trees and
+ * calling individual rewrites (subclasses of EncryptedPredicate) that have been registered for each
+ * encrypted index type.
+ *
+ * The actual rewrites performed are stored in references to maps in the class. In non-test
+ * environments, these are global maps that register encrypted predicate rewrites that live in their
+ * own files.
+ */
+class QueryRewriter : public QueryRewriterInterface {
+public:
+    /**
+     * Takes in references to collection readers for the ESC and ECC that are used during tag
+     * computation.
+     *
+     * The collection scan mode starts at kUseIfNeeded, is switched to kForceAlways when the
+     * internalQueryFLEAlwaysUseEncryptedCollScanMode parameter is set, and is finally switched to
+     * kDisallow when 'mode' is kDisallow (so kDisallow takes precedence over the parameter).
+     */
+    QueryRewriter(boost::intrusive_ptr<ExpressionContext> expCtx,
+                  const FLEStateCollectionReader& escReader,
+                  const FLEStateCollectionReader& eccReader,
+                  EncryptedCollScanModeAllowed mode = EncryptedCollScanModeAllowed::kAllow)
+        : _expCtx(std::move(expCtx)),  // Sink parameter: move instead of bumping the refcount.
+          _escReader(&escReader),
+          _eccReader(&eccReader),
+          _exprRewrites(aggPredicateRewriteMap),
+          _matchRewrites(matchPredicateRewriteMap) {
+
+        if (internalQueryFLEAlwaysUseEncryptedCollScanMode.load()) {
+            _mode = EncryptedCollScanMode::kForceAlways;
+        }
+
+        if (mode == EncryptedCollScanModeAllowed::kDisallow) {
+            _mode = EncryptedCollScanMode::kDisallow;
+        }
+
+        // This isn't the "real" query so we don't want to increment Expression
+        // counters here.
+        _expCtx->stopExpressionCounters();
+    }
+
+    /**
+     * Accepts a BSONObj holding a MatchExpression, and returns BSON representing the rewritten
+     * expression. Returns boost::none if no rewriting was done.
+     *
+     * Rewrites the match expression with FLE find payloads into a disjunction on the
+     * __safeContent__ array of tags.
+     *
+     * Will rewrite top-level $eq and $in expressions, as well as recursing through $and, $or, $not
+     * and $nor. Also handles similarly limited rewriting under $expr. All other MatchExpressions,
+     * notably $elemMatch, are ignored.
+     */
+    boost::optional<BSONObj> rewriteMatchExpression(const BSONObj& filter);
+
+    /**
+     * Accepts an expression to be re-written. Will rewrite top-level expressions including $eq and
+     * $in, as well as recursing through other expressions. Returns a new pointer if the top-level
+     * expression must be changed. A nullptr indicates that the modifications happened in-place.
+     */
+    std::unique_ptr<Expression> rewriteExpression(Expression* expression);
+
+    bool isForceEncryptedCollScan() const {
+        return _mode == EncryptedCollScanMode::kForceAlways;
+    }
+
+    void setForceEncryptedCollScanForTest() {
+        _mode = EncryptedCollScanMode::kForceAlways;
+    }
+
+    EncryptedCollScanMode getEncryptedCollScanMode() const override {
+        return _mode;
+    }
+
+    const FLEStateCollectionReader* getEscReader() const override {
+        return _escReader;
+    }
+
+    const FLEStateCollectionReader* getEccReader() const override {
+        return _eccReader;
+    }
+
+    ExpressionContext* getExpressionContext() const override {
+        return _expCtx.get();
+    }
+
+protected:
+    // This constructor should only be used for mocks in testing. It leaves both state collection
+    // readers null.
+    QueryRewriter(boost::intrusive_ptr<ExpressionContext> expCtx,
+                  const ExpressionToRewriteMap& exprRewrites,
+                  const MatchTypeToRewriteMap& matchRewrites)
+        : _expCtx(std::move(expCtx)),
+          _escReader(nullptr),
+          _eccReader(nullptr),
+          _exprRewrites(exprRewrites),
+          _matchRewrites(matchRewrites) {}
+
+private:
+    /**
+     * A single rewrite step, called recursively on child expressions.
+     */
+    std::unique_ptr<MatchExpression> _rewrite(MatchExpression* me);
+
+    boost::intrusive_ptr<ExpressionContext> _expCtx;
+
+    // Holds a pointer so that these can be null for tests, even though the public constructor
+    // takes a const reference.
+    const FLEStateCollectionReader* _escReader;
+    const FLEStateCollectionReader* _eccReader;
+
+    // True if the last Expression or MatchExpression processed by this rewriter was rewritten.
+    bool _rewroteLastExpression = false;
+
+    // Controls how query rewriter rewrites the query
+    EncryptedCollScanMode _mode{EncryptedCollScanMode::kUseIfNeeded};
+
+    const ExpressionToRewriteMap& _exprRewrites;
+    const MatchTypeToRewriteMap& _matchRewrites;
+};
+}  // namespace mongo::fle
diff --git a/src/mongo/db/query/fle/query_rewriter_interface.h b/src/mongo/db/query/fle/query_rewriter_interface.h
new file mode 100644
index 00000000000..60f54defeac
--- /dev/null
+++ b/src/mongo/db/query/fle/query_rewriter_interface.h
@@ -0,0 +1,71 @@
+/**
+ *    Copyright (C) 2022-present MongoDB, Inc.
+ *
+ *    This program is free software: you can redistribute it and/or modify
+ *    it under the terms of the Server Side Public License, version 1,
+ *    as published by MongoDB, Inc.
+ *
+ *    This program is distributed in the hope that it will be useful,
+ *    but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ *    Server Side Public License for more details.
+ *
+ *    You should have received a copy of the Server Side Public License
+ *    along with this program. If not, see
+ *    <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ *    As a special exception, the copyright holders give permission to link the
+ *    code of portions of this program with the OpenSSL library under certain
+ *    conditions as described in each individual source file and distribute
+ *    linked combinations including the program with the OpenSSL library. You
+ *    must comply with the Server Side Public License in all respects for
+ *    all of the code used other than as permitted herein. If you modify file(s)
+ *    with this exception, you may extend this exception to your version of the
+ *    file(s), but you are not obligated to do so. If you do not wish to do so,
+ *    delete this exception statement from your version.
If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/crypto/fle_crypto.h" +#include "mongo/db/pipeline/expression_context.h" + +namespace mongo { +namespace fle { +enum class EncryptedCollScanMode { + // Always use high cardinality filters, used by tests + kForceAlways, + + // Use high cardinality mode if $in rewrites do not fit in the + // internalQueryFLERewriteMemoryLimit memory limit + kUseIfNeeded, + + // Do not rewrite into high cardinality filter, throw exceptions instead + // Some contexts like upsert do not support $expr + kDisallow, +}; + +/** + * Low Selectivity rewrites use $expr which is not supported in all commands such as upserts. + */ +enum class EncryptedCollScanModeAllowed { + kAllow, + kDisallow, +}; + +/** + * Pure virtual class that allows encrypted predicate rewrites to be unit tested independently from + * the actual server rewrite. + */ +class QueryRewriterInterface { +public: + virtual ~QueryRewriterInterface() {} + virtual const FLEStateCollectionReader* getEscReader() const = 0; + virtual const FLEStateCollectionReader* getEccReader() const = 0; + virtual EncryptedCollScanMode getEncryptedCollScanMode() const = 0; + virtual ExpressionContext* getExpressionContext() const = 0; +}; +} // namespace fle +} // namespace mongo diff --git a/src/mongo/db/query/fle/query_rewriter_test.cpp b/src/mongo/db/query/fle/query_rewriter_test.cpp new file mode 100644 index 00000000000..d779f6be1cc --- /dev/null +++ b/src/mongo/db/query/fle/query_rewriter_test.cpp @@ -0,0 +1,209 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + + +#include <memory> + +#include "query_rewriter.h" + +#include "mongo/bson/bsonelement.h" +#include "mongo/bson/bsonmisc.h" +#include "mongo/bson/bsonobjbuilder.h" +#include "mongo/bson/bsontypes.h" +#include "mongo/db/matcher/expression_leaf.h" +#include "mongo/db/pipeline/expression.h" +#include "mongo/db/pipeline/expression_context_for_test.h" +#include "mongo/db/query/fle/encrypted_predicate.h" +#include "mongo/db/query/fle/encrypted_predicate_test_fixtures.h" +#include "mongo/db/query/fle/query_rewriter_interface.h" +#include "mongo/idl/server_parameter_test_util.h" +#include "mongo/unittest/unittest.h" +#include "mongo/util/overloaded_visitor.h" + + +namespace mongo { +namespace { + +/* + * The server rewrite itself is only responsible for traversing agg and MatchExpressions and + * executing whatever rewrites are registered. For unit testing, we will only verify that this + * traversal and rewrite is happening properly using a mock predicate rewriter that rewrites any + * equality with an object with the key `encrypt` to a $gt operator. Unit tests for the actual + * rewrites while mocking out tag generation are located in the test file for each encrypted + * predicate type. Full end-to-end testing happens in jstests. This organization ensures that we + * don't write redundant tests that each index type is properly rewritten under different + * circumstances, when the same exact code is called for each index type. 
+ */ + +class MockPredicateRewriter : public fle::EncryptedPredicate { +public: + MockPredicateRewriter(const fle::QueryRewriterInterface* rewriter) + : EncryptedPredicate(rewriter) {} + +protected: + bool isPayload(const BSONElement& elt) const override { + if (!elt.isABSONObj()) { + return false; + } + return elt.Obj().firstElementFieldNameStringData() == "encrypt"_sd; + } + bool isPayload(const Value& v) const override { + if (!v.isObject()) { + return false; + } + return !v.getDocument().getField("encrypt").missing(); + } + + std::vector<PrfBlock> generateTags(fle::BSONValue payload) const override { + return {}; + }; + + // Encrypted values will be rewritten from $eq to $gt. This is an arbitrary decision just to + // make sure that the rewrite works properly. + std::unique_ptr<MatchExpression> rewriteToTagDisjunction(MatchExpression* expr) const override { + invariant(expr->matchType() == MatchExpression::EQ); + auto eqMatch = static_cast<EqualityMatchExpression*>(expr); + if (!isPayload(eqMatch->getData())) { + return nullptr; + } + return std::make_unique<GTMatchExpression>(eqMatch->path(), + eqMatch->getData().Obj().firstElement()); + }; + + std::unique_ptr<Expression> rewriteToTagDisjunction(Expression* expr) const override { + return nullptr; + } + + std::unique_ptr<MatchExpression> rewriteToRuntimeComparison( + MatchExpression* expr) const override { + return nullptr; + } + + std::unique_ptr<Expression> rewriteToRuntimeComparison(Expression* expr) const override { + return nullptr; + } + +private: + EncryptedBinDataType encryptedBinDataType() const override { + return EncryptedBinDataType::kPlaceholder; // return the 0 type. this isn't used anywhere. 
+ } +}; + +void setMockRewriteMaps(fle::MatchTypeToRewriteMap& match, + fle::ExpressionToRewriteMap& agg, + fle::TagMap& tags, + std::set<StringData>& encryptedFields) { + match[MatchExpression::EQ] = [&](auto* rewriter, auto* expr) { + return MockPredicateRewriter{rewriter}.rewrite(expr); + }; +} + +class MockQueryRewriter : public fle::QueryRewriter { +public: + MockQueryRewriter(fle::ExpressionToRewriteMap* exprRewrites, + fle::MatchTypeToRewriteMap* matchRewrites) + : fle::QueryRewriter(new ExpressionContextForTest(), *exprRewrites, *matchRewrites) { + setMockRewriteMaps(*matchRewrites, *exprRewrites, _tags, _encryptedFields); + } + + BSONObj rewriteMatchExpressionForTest(const BSONObj& obj) { + auto res = rewriteMatchExpression(obj); + return res ? res.value() : obj; + } + +private: + fle::TagMap _tags; + std::set<StringData> _encryptedFields; +}; + +class FLEServerRewriteTest : public unittest::Test { +public: + FLEServerRewriteTest() : _mock(nullptr) {} + + void setUp() override { + _mock = std::make_unique<MockQueryRewriter>(&_agg, &_match); + } + + void tearDown() override {} + +protected: + std::unique_ptr<MockQueryRewriter> _mock; + fle::ExpressionToRewriteMap _agg; + fle::MatchTypeToRewriteMap _match; +}; + +#define ASSERT_MATCH_EXPRESSION_REWRITE(input, expected) \ + auto actual = _mock->rewriteMatchExpressionForTest(fromjson(input)); \ + ASSERT_BSONOBJ_EQ(actual, fromjson(expected)); + +#define TEST_FLE_REWRITE_MATCH(name, input, expected) \ + TEST_F(FLEServerRewriteTest, name##_MatchExpression) { \ + ASSERT_MATCH_EXPRESSION_REWRITE(input, expected); \ + } + +TEST_FLE_REWRITE_MATCH(TopLevel_DottedPath, + "{'user.ssn': {$eq: {encrypt: 2}}}", + "{'user.ssn': {$gt: 2}}"); + +TEST_FLE_REWRITE_MATCH(TopLevel_Conjunction_BothEncrypted, + "{$and: [{ssn: {encrypt: 2}}, {age: {encrypt: 4}}]}", + "{$and: [{ssn: {$gt: 2}}, {age: {$gt: 4}}]}"); + +TEST_FLE_REWRITE_MATCH(TopLevel_Conjunction_PartlyEncrypted, + "{$and: [{ssn: {encrypt: 2}}, {age: {plain: 4}}]}", + 
"{$and: [{ssn: {$gt: 2}}, {age: {$eq: {plain: 4}}}]}"); + +TEST_FLE_REWRITE_MATCH(TopLevel_Conjunction_PartlyEncryptedWithUnregisteredOperator, + "{$and: [{ssn: {encrypt: 2}}, {age: {$lt: {encrypt: 4}}}]}", + "{$and: [{ssn: {$gt: 2}}, {age: {$lt: {encrypt: 4}}}]}"); + +TEST_FLE_REWRITE_MATCH(TopLevel_Encrypted_Nested_Unecrypted, + "{$and: [{ssn: {encrypt: 2}}, {user: {region: 'US'}}]}", + "{$and: [{ssn: {$gt: 2}}, {user: {$eq: {region: 'US'}}}]}"); + +TEST_FLE_REWRITE_MATCH(TopLevel_Not, + "{ssn: {$not: {$eq: {encrypt: 5}}}}", + "{ssn: {$not: {$gt: 5}}}"); + +TEST_FLE_REWRITE_MATCH(TopLevel_Neq, "{ssn: {$ne: {encrypt: 5}}}", "{ssn: {$not: {$gt: 5}}}}"); + +TEST_FLE_REWRITE_MATCH( + NestedConjunction, + "{$and: [{$and: [{ssn: {encrypt: 2}}, {other: 'field'}]}, {otherSsn: {encrypt: 3}}]}", + "{$and: [{$and: [{ssn: {$gt: 2}}, {other: {$eq: 'field'}}]}, {otherSsn: {$gt: 3}}]}"); + +TEST_FLE_REWRITE_MATCH(TopLevel_Nor, + "{$nor: [{ssn: {encrypt: 5}}, {other: {$eq: 'field'}}]}", + "{$nor: [{ssn: {$gt: 5}}, {other: {$eq: 'field'}}]}"); + +TEST_FLE_REWRITE_MATCH(TopLevel_Or, + "{$or: [{ssn: {encrypt: 5}}, {other: {$eq: 'field'}}]}", + "{$or: [{ssn: {$gt: 5}}, {other: {$eq: 'field'}}]}"); +} // namespace +} // namespace mongo diff --git a/src/mongo/db/query/fle/range_predicate.cpp b/src/mongo/db/query/fle/range_predicate.cpp new file mode 100644 index 00000000000..322d3025dc4 --- /dev/null +++ b/src/mongo/db/query/fle/range_predicate.cpp @@ -0,0 +1,84 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. 
+ * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "range_predicate.h" + +#include "mongo/crypto/encryption_fields_gen.h" +#include "mongo/db/matcher/expression_leaf.h" +#include "mongo/db/query/fle/encrypted_predicate.h" + +namespace mongo::fle { + +REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_WITH_FLAG(ENCRYPTED_BETWEEN, + RangePredicate, + gFeatureFlagFLE2Range); + +// TODO: SERVER-67206 Generate tags for range payload. 
+std::vector<PrfBlock> RangePredicate::generateTags(BSONValue payload) const { + return {}; +} + +std::unique_ptr<MatchExpression> RangePredicate::rewriteToTagDisjunction( + MatchExpression* expr) const { + invariant(expr->matchType() == MatchExpression::ENCRYPTED_BETWEEN); + auto betExpr = static_cast<EncryptedBetweenMatchExpression*>(expr); + auto ffp = betExpr->rhs(); + + if (!isPayload(ffp)) { + return nullptr; + } + + auto obj = toBSONArray(generateTags(ffp)); + auto tags = std::vector<BSONElement>(); + obj.elems(tags); + auto inExpr = std::make_unique<InMatchExpression>(kSafeContent); + inExpr->setBackingBSON(std::move(obj)); + auto status = inExpr->setEqualities(std::move(tags)); + uassertStatusOK(status); + return inExpr; +} + +// TODO: SERVER-67209 Server-side rewrite for agg expressions with $encryptedBetween. +std::unique_ptr<Expression> RangePredicate::rewriteToTagDisjunction(Expression* expr) const { + return nullptr; +} + +// TODO: SERVER-67267 Rewrite $encryptedBetween to $_internalFleBetween when number of tags exceeds +// limit. +std::unique_ptr<MatchExpression> RangePredicate::rewriteToRuntimeComparison( + MatchExpression* expr) const { + return nullptr; +} + +// TODO: SERVER-67267 Rewrite $encryptedBetween to $_internalFleBetween when number of tags exceeds +// limit. +std::unique_ptr<Expression> RangePredicate::rewriteToRuntimeComparison(Expression* expr) const { + return nullptr; +} +} // namespace mongo::fle diff --git a/src/mongo/db/query/fle/range_predicate.h b/src/mongo/db/query/fle/range_predicate.h new file mode 100644 index 00000000000..a310cde16e2 --- /dev/null +++ b/src/mongo/db/query/fle/range_predicate.h @@ -0,0 +1,58 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/crypto/encryption_fields_gen.h" +#include "mongo/db/query/fle/encrypted_predicate.h" + +namespace mongo::fle { +/** + * Rewrite for the encrypted range index, which expects a $encryptedBetween expression. 
+ */ +class RangePredicate : public EncryptedPredicate { +public: + RangePredicate(const QueryRewriterInterface* rewriter) : EncryptedPredicate(rewriter) {} + +protected: + std::vector<PrfBlock> generateTags(BSONValue payload) const override; + + std::unique_ptr<MatchExpression> rewriteToTagDisjunction(MatchExpression* expr) const override; + std::unique_ptr<Expression> rewriteToTagDisjunction(Expression* expr) const override; + + std::unique_ptr<MatchExpression> rewriteToRuntimeComparison( + MatchExpression* expr) const override; + std::unique_ptr<Expression> rewriteToRuntimeComparison(Expression* expr) const override; + +private: + EncryptedBinDataType encryptedBinDataType() const override { + return EncryptedBinDataType::kFLE2FindRangePayload; + } +}; +} // namespace mongo::fle diff --git a/src/mongo/db/query/fle/range_predicate_test.cpp b/src/mongo/db/query/fle/range_predicate_test.cpp new file mode 100644 index 00000000000..c240abde6f8 --- /dev/null +++ b/src/mongo/db/query/fle/range_predicate_test.cpp @@ -0,0 +1,131 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. 
+ * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/crypto/fle_crypto.h" +#include "mongo/db/matcher/expression_leaf.h" +#include "mongo/db/query/fle/encrypted_predicate_test_fixtures.h" +#include "mongo/db/query/fle/range_predicate.h" +#include "mongo/idl/server_parameter_test_util.h" +#include "mongo/unittest/unittest.h" + +namespace mongo::fle { +namespace { +class MockRangePredicate : public RangePredicate { +public: + MockRangePredicate(const QueryRewriterInterface* rewriter) : RangePredicate(rewriter) {} + + MockRangePredicate(const QueryRewriterInterface* rewriter, + TagMap tags, + std::set<StringData> encryptedFields) + : RangePredicate(rewriter), _tags(tags), _encryptedFields(encryptedFields) {} + + void setEncryptedTags(std::pair<StringData, int> fieldvalue, std::vector<PrfBlock> tags) { + _encryptedFields.insert(fieldvalue.first); + _tags[fieldvalue] = tags; + } + + +protected: + bool isPayload(const BSONElement& elt) const override { + return true; + } + + bool isPayload(const Value& v) const override { + return true; + } + + std::vector<PrfBlock> generateTags(BSONValue payload) const { + return stdx::visit( + OverloadedVisitor{ + [&](BSONElement p) { + auto parsedPayload = p.Obj().firstElement(); + 
auto fieldName = parsedPayload.fieldNameStringData(); + + std::vector<BSONElement> range; + auto payloadAsArray = parsedPayload.Array(); + for (auto&& elt : payloadAsArray) { + range.push_back(elt); + } + + std::vector<PrfBlock> allTags; + for (auto i = range[0].Number(); i <= range[1].Number(); i++) { + ASSERT(_tags.find({fieldName, i}) != _tags.end()); + auto temp = _tags.find({fieldName, i})->second; + for (auto tag : temp) { + allTags.push_back(tag); + } + } + return allTags; + }, + [&](std::reference_wrapper<Value> v) { return std::vector<PrfBlock>{}; }}, + payload); + } + +private: + TagMap _tags; + std::set<StringData> _encryptedFields; +}; +class RangePredicateRewriteTest : public EncryptedPredicateRewriteTest { +public: + RangePredicateRewriteTest() : _predicate(&_mock) {} + +protected: + MockRangePredicate _predicate; +}; + +TEST_F(RangePredicateRewriteTest, BasicRangeRewrite) { + RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true); + + int start = 1; + int end = 3; + StringData encField = "ssn"; + + std::vector<PrfBlock> tags1 = {{1}, {2}, {3}}; + std::vector<PrfBlock> tags2 = {{4}, {5}, {6}}; + std::vector<PrfBlock> tags3 = {{7}, {8}, {9}}; + + _predicate.setEncryptedTags({encField, 1}, tags1); + _predicate.setEncryptedTags({encField, 2}, tags2); + _predicate.setEncryptedTags({encField, 3}, tags3); + + std::vector<PrfBlock> allTags = {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}}; + + // The field redundancy is so that we can pull out the field + // name in the mock version of generateTags.
+ BSONObj query = + BSON(encField << BSON("$encryptedBetween" << BSON(encField << BSON_ARRAY(start << end)))); + + auto inputExpr = + EncryptedBetweenMatchExpression(encField, query[encField]["$encryptedBetween"], nullptr); + + assertRewriteToTags(_predicate, &inputExpr, toBSONArray(std::move(allTags))); +} + +}; // namespace +} // namespace mongo::fle diff --git a/src/mongo/db/query/fle/server_rewrite.cpp b/src/mongo/db/query/fle/server_rewrite.cpp index be98058a246..6575a2483d2 100644 --- a/src/mongo/db/query/fle/server_rewrite.cpp +++ b/src/mongo/db/query/fle/server_rewrite.cpp @@ -48,6 +48,8 @@ #include "mongo/db/pipeline/document_source_match.h" #include "mongo/db/pipeline/expression.h" #include "mongo/db/query/collation/collator_factory_interface.h" +#include "mongo/db/query/fle/encrypted_predicate.h" +#include "mongo/db/query/fle/query_rewriter.h" #include "mongo/db/service_context.h" #include "mongo/logv2/log.h" #include "mongo/s/grid.h" @@ -73,64 +75,13 @@ std::unique_ptr<CollatorInterface> collatorFromBSON(OperationContext* opCtx, return collator; } namespace { - -template <typename PayloadT> -boost::intrusive_ptr<ExpressionInternalFLEEqual> generateFleEqualMatch(StringData path, - const PayloadT& ffp, - ExpressionContext* expCtx) { - // Generate { $_internalFleEq: { field: "$field_name", server: f_3, counter: cm, edc: k_EDC] } - auto tokens = ParsedFindPayload(ffp); - - uassert(6672401, - "Missing required field server encryption token in find payload", - tokens.serverToken.has_value()); - - return make_intrusive<ExpressionInternalFLEEqual>( - expCtx, - ExpressionFieldPath::createPathFromString( - expCtx, path.toString(), expCtx->variablesParseState), - tokens.serverToken.value().data, - tokens.maxCounter.value_or(0LL), - tokens.edcToken.data); -} - - -template <typename PayloadT> -std::unique_ptr<ExpressionInternalFLEEqual> generateFleEqualMatchUnique(StringData path, - const PayloadT& ffp, - ExpressionContext* expCtx) { - // Generate { $_internalFleEq: 
{ field: "$field_name", server: f_3, counter: cm, edc: k_EDC] } - auto tokens = ParsedFindPayload(ffp); - - uassert(6672419, - "Missing required field server encryption token in find payload", - tokens.serverToken.has_value()); - - return std::make_unique<ExpressionInternalFLEEqual>( - expCtx, - ExpressionFieldPath::createPathFromString( - expCtx, path.toString(), expCtx->variablesParseState), - tokens.serverToken.value().data, - tokens.maxCounter.value_or(0LL), - tokens.edcToken.data); -} - -std::unique_ptr<MatchExpression> generateFleEqualMatchAndExpr(StringData path, - const BSONElement ffp, - ExpressionContext* expCtx) { - auto fleEqualMatch = generateFleEqualMatch(path, ffp, expCtx); - - return std::make_unique<ExprMatchExpression>(fleEqualMatch, expCtx); -} - - /** * This section defines a mapping from DocumentSources to the dispatch function to appropriately * handle FLE rewriting for that stage. This should be kept in line with code on the client-side * that marks constants for encryption: we should handle all places where an implicitly-encrypted * value may be for each stage, otherwise we may return non-sensical results. 
*/ -static stdx::unordered_map<std::type_index, std::function<void(FLEQueryRewriter*, DocumentSource*)>> +static stdx::unordered_map<std::type_index, std::function<void(QueryRewriter*, DocumentSource*)>> stageRewriterMap; #define REGISTER_DOCUMENT_SOURCE_FLE_REWRITER(className, rewriterFunc) \ @@ -142,19 +93,19 @@ static stdx::unordered_map<std::type_index, std::function<void(FLEQueryRewriter* }; \ } -void rewriteMatch(FLEQueryRewriter* rewriter, DocumentSourceMatch* source) { +void rewriteMatch(QueryRewriter* rewriter, DocumentSourceMatch* source) { if (auto rewritten = rewriter->rewriteMatchExpression(source->getQuery())) { source->rebuild(rewritten.value()); } } -void rewriteGeoNear(FLEQueryRewriter* rewriter, DocumentSourceGeoNear* source) { +void rewriteGeoNear(QueryRewriter* rewriter, DocumentSourceGeoNear* source) { if (auto rewritten = rewriter->rewriteMatchExpression(source->getQuery())) { source->setQuery(rewritten.value()); } } -void rewriteGraphLookUp(FLEQueryRewriter* rewriter, DocumentSourceGraphLookUp* source) { +void rewriteGraphLookUp(QueryRewriter* rewriter, DocumentSourceGraphLookUp* source) { if (auto filter = source->getAdditionalFilter()) { if (auto rewritten = rewriter->rewriteMatchExpression(filter.value())) { source->setAdditionalFilter(rewritten.value()); @@ -170,229 +121,6 @@ REGISTER_DOCUMENT_SOURCE_FLE_REWRITER(DocumentSourceMatch, rewriteMatch); REGISTER_DOCUMENT_SOURCE_FLE_REWRITER(DocumentSourceGeoNear, rewriteGeoNear); REGISTER_DOCUMENT_SOURCE_FLE_REWRITER(DocumentSourceGraphLookUp, rewriteGraphLookUp); -class FLEExpressionRewriter { -public: - FLEExpressionRewriter(FLEQueryRewriter* queryRewriter) : queryRewriter(queryRewriter){}; - - /** - * Accepts a vector of expressions to be compared for equality to an encrypted field. For any - * expression representing a constant encrypted value, computes the tags for the expression and - * rewrites the comparison to a disjunction over __safeContent__. 
Returns an OR expression of - * these disjunctions. If no rewrites were done, returns nullptr. Either all of the expressions - * be constant FFPs or none of them should be. - * - * The final output will look like - * {$or: [{$in: [tag0, "$__safeContent__"]}, {$in: [tag1, "$__safeContent__"]}, ...]}. - */ - std::unique_ptr<Expression> rewriteInToEncryptedField( - const Expression* leftExpr, - const std::vector<boost::intrusive_ptr<Expression>>& equalitiesList) { - size_t numFFPs = 0; - std::vector<boost::intrusive_ptr<Expression>> orListElems; - - for (auto& equality : equalitiesList) { - // For each expression representing a FleFindPayload... - if (auto constChild = dynamic_cast<ExpressionConstant*>(equality.get())) { - if (!queryRewriter->isFleFindPayload( - constChild->getValue(), EncryptedBinDataType::kFLE2FindEqualityPayload)) { - continue; - } - - numFFPs++; - } - } - - // Finally, construct an $or of all of the $ins. - if (numFFPs == 0) { - return nullptr; - } - - uassert( - 6334102, - "If any elements in an comparison expression are encrypted, then all elements should " - "be encrypted.", - numFFPs == equalitiesList.size()); - - auto leftFieldPath = dynamic_cast<const ExpressionFieldPath*>(leftExpr); - uassert(6672417, - "$in is only supported with Queryable Encryption when the first argument is a " - "field path", - leftFieldPath != nullptr); - - if (!queryRewriter->isForceEncryptedCollScan()) { - try { - for (auto& equality : equalitiesList) { - // For each expression representing a FleFindPayload... - if (auto constChild = dynamic_cast<ExpressionConstant*>(equality.get())) { - // ... rewrite the payload to a list of tags... - auto tags = - queryRewriter->rewriteEqualityPayloadAsTags(constChild->getValue()); - for (auto&& tagElt : tags) { - // ... and for each tag, construct expression {$in: [tag, - // "$__safeContent__"]}. 
- std::vector<boost::intrusive_ptr<Expression>> inVec{ - ExpressionConstant::create(queryRewriter->expCtx(), tagElt), - ExpressionFieldPath::createPathFromString( - queryRewriter->expCtx(), - kSafeContent, - queryRewriter->expCtx()->variablesParseState)}; - orListElems.push_back(make_intrusive<ExpressionIn>( - queryRewriter->expCtx(), std::move(inVec))); - } - } - } - - didRewrite = true; - - return std::make_unique<ExpressionOr>(queryRewriter->expCtx(), - std::move(orListElems)); - } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) { - LOGV2_DEBUG(6672403, - 2, - "FLE Max tag limit hit during aggregation $in rewrite", - "__error__"_attr = ex.what()); - - if (queryRewriter->getEncryptedCollScanMode() != - FLEQueryRewriter::EncryptedCollScanMode::kUseIfNeeded) { - throw; - } - - // fall through - } - } - - for (auto& equality : equalitiesList) { - if (auto constChild = dynamic_cast<ExpressionConstant*>(equality.get())) { - auto fleEqExpr = generateFleEqualMatch( - leftFieldPath->getFieldPathWithoutCurrentPrefix().fullPath(), - constChild->getValue(), - queryRewriter->expCtx()); - orListElems.push_back(fleEqExpr); - } - } - - didRewrite = true; - return std::make_unique<ExpressionOr>(queryRewriter->expCtx(), std::move(orListElems)); - } - - // Rewrite a [$eq : [$fieldpath, constant]] or [$eq: [constant, $fieldpath]] - // to _internalFleEq: {field: $fieldpath, edc: edcToken, counter: N, server: serverToken} - std::unique_ptr<Expression> rewriteComparisonsToEncryptedField( - const std::vector<boost::intrusive_ptr<Expression>>& equalitiesList) { - - auto leftConstant = dynamic_cast<ExpressionConstant*>(equalitiesList[0].get()); - auto rightConstant = dynamic_cast<ExpressionConstant*>(equalitiesList[1].get()); - - bool isLeftFFP = leftConstant && - queryRewriter->isFleFindPayload(leftConstant->getValue(), - EncryptedBinDataType::kFLE2FindEqualityPayload); - bool isRightFFP = rightConstant && - queryRewriter->isFleFindPayload(rightConstant->getValue(), - 
EncryptedBinDataType::kFLE2FindEqualityPayload); - - uassert(6334100, - "Cannot compare two encrypted constants to each other", - !(isLeftFFP && isRightFFP)); - - // No FLE Find Payload - if (!isLeftFFP && !isRightFFP) { - return nullptr; - } - - auto leftFieldPath = dynamic_cast<ExpressionFieldPath*>(equalitiesList[0].get()); - auto rightFieldPath = dynamic_cast<ExpressionFieldPath*>(equalitiesList[1].get()); - - uassert( - 6672413, - "Queryable Encryption only supports comparisons between a field path and a constant", - leftFieldPath || rightFieldPath); - - auto fieldPath = leftFieldPath ? leftFieldPath : rightFieldPath; - auto constChild = isLeftFFP ? leftConstant : rightConstant; - - if (!queryRewriter->isForceEncryptedCollScan()) { - try { - std::vector<boost::intrusive_ptr<Expression>> orListElems; - - auto tags = queryRewriter->rewriteEqualityPayloadAsTags(constChild->getValue()); - for (auto&& tagElt : tags) { - // ... and for each tag, construct expression {$in: [tag, - // "$__safeContent__"]}. 
- std::vector<boost::intrusive_ptr<Expression>> inVec{ - ExpressionConstant::create(queryRewriter->expCtx(), tagElt), - ExpressionFieldPath::createPathFromString( - queryRewriter->expCtx(), - kSafeContent, - queryRewriter->expCtx()->variablesParseState)}; - orListElems.push_back( - make_intrusive<ExpressionIn>(queryRewriter->expCtx(), std::move(inVec))); - } - - didRewrite = true; - return std::make_unique<ExpressionOr>(queryRewriter->expCtx(), - std::move(orListElems)); - - } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) { - LOGV2_DEBUG(6672409, - 2, - "FLE Max tag limit hit during query $in rewrite", - "__error__"_attr = ex.what()); - - if (queryRewriter->getEncryptedCollScanMode() != - FLEQueryRewriter::EncryptedCollScanMode::kUseIfNeeded) { - throw; - } - - // fall through - } - } - - auto fleEqExpr = - generateFleEqualMatchUnique(fieldPath->getFieldPathWithoutCurrentPrefix().fullPath(), - constChild->getValue(), - queryRewriter->expCtx()); - - didRewrite = true; - return fleEqExpr; - } - - std::unique_ptr<Expression> postVisit(Expression* exp) { - if (auto inExpr = dynamic_cast<ExpressionIn*>(exp)) { - // Rewrite an $in over an encrypted field to an $or. The first child of the $in can be - // ignored when rewrites are done; there is no extra information in that child that - // doesn't exist in the FFPs in the $in list. - if (auto inList = dynamic_cast<ExpressionArray*>(inExpr->getOperandList()[1].get())) { - return rewriteInToEncryptedField(inExpr->getOperandList()[0].get(), - inList->getChildren()); - } - } else if (auto eqExpr = dynamic_cast<ExpressionCompare*>(exp); eqExpr && - (eqExpr->getOp() == ExpressionCompare::EQ || - eqExpr->getOp() == ExpressionCompare::NE)) { - // Rewrite an $eq comparing an encrypted field and an encrypted constant to an $or. - auto newExpr = rewriteComparisonsToEncryptedField(eqExpr->getChildren()); - - // Neither child is an encrypted constant, and no rewriting needs to be done. 
- if (!newExpr) { - return nullptr; - } - - // Exactly one child was an encrypted constant. The other child can be ignored; there is - // no extra information in that child that doesn't exist in the FFP. - if (eqExpr->getOp() == ExpressionCompare::NE) { - std::vector<boost::intrusive_ptr<Expression>> notChild{newExpr.release()}; - return std::make_unique<ExpressionNot>(queryRewriter->expCtx(), - std::move(notChild)); - } - return newExpr; - } - - return nullptr; - } - - FLEQueryRewriter* queryRewriter; - bool didRewrite = false; -}; - BSONObj rewriteEncryptedFilter(const FLEStateCollectionReader& escReader, const FLEStateCollectionReader& eccReader, boost::intrusive_ptr<ExpressionContext> expCtx, @@ -400,7 +128,7 @@ BSONObj rewriteEncryptedFilter(const FLEStateCollectionReader& escReader, EncryptedCollScanModeAllowed mode) { if (auto rewritten = - FLEQueryRewriter(expCtx, escReader, eccReader, mode).rewriteMatchExpression(filter)) { + QueryRewriter(expCtx, escReader, eccReader, mode).rewriteMatchExpression(filter)) { return rewritten.value(); } @@ -437,7 +165,7 @@ public: ~PipelineRewrite(){}; void doRewrite(FLEStateCollectionReader& escReader, FLEStateCollectionReader& eccReader) final { - auto rewriter = FLEQueryRewriter(expCtx, escReader, eccReader); + auto rewriter = QueryRewriter(expCtx, escReader, eccReader); for (auto&& source : pipeline->getSources()) { if (stageRewriterMap.find(typeid(*source)) != stageRewriterMap.end()) { stageRewriterMap[typeid(*source)](&rewriter, source.get()); @@ -481,7 +209,6 @@ void doFLERewriteInTxn(OperationContext* opCtx, std::shared_ptr<RewriteBase> sharedBlock, GetTxnCallback getTxn) { auto txn = getTxn(opCtx); - auto swCommitResult = txn->runNoThrow( opCtx, [sharedBlock](const txn_api::TransactionClient& txnClient, auto txnExec) { auto makeCollectionReader = [sharedBlock](FLEQueryInterface* queryImpl, @@ -602,242 +329,4 @@ std::unique_ptr<Pipeline, PipelineDeleter> processPipeline( return sharedBlock->getPipeline(); } - 
-std::unique_ptr<Expression> FLEQueryRewriter::rewriteExpression(Expression* expression) { - tassert(6334104, "Expected an expression to rewrite but found none", expression); - - FLEExpressionRewriter expressionRewriter{this}; - auto res = expression_walker::walk<Expression>(expression, &expressionRewriter); - _rewroteLastExpression = expressionRewriter.didRewrite; - return res; -} - -boost::optional<BSONObj> FLEQueryRewriter::rewriteMatchExpression(const BSONObj& filter) { - auto expr = uassertStatusOK(MatchExpressionParser::parse(filter, _expCtx)); - - _rewroteLastExpression = false; - if (auto res = _rewrite(expr.get())) { - // The rewrite resulted in top-level changes. Serialize the new expression. - return res->serialize().getOwned(); - } else if (_rewroteLastExpression) { - // The rewrite had no top-level changes, but nested expressions were rewritten. Serialize - // the parsed expression, which has in-place changes. - return expr->serialize().getOwned(); - } - - // No rewrites were done. 
- return boost::none; -} - -std::unique_ptr<MatchExpression> FLEQueryRewriter::_rewrite(MatchExpression* expr) { - switch (expr->matchType()) { - case MatchExpression::EQ: - return rewriteEq(std::move(static_cast<const EqualityMatchExpression*>(expr))); - case MatchExpression::MATCH_IN: - return rewriteIn(std::move(static_cast<const InMatchExpression*>(expr))); - case MatchExpression::AND: - case MatchExpression::OR: - case MatchExpression::NOT: - case MatchExpression::NOR: { - for (size_t i = 0; i < expr->numChildren(); i++) { - auto child = expr->getChild(i); - if (auto newChild = _rewrite(child)) { - expr->resetChild(i, newChild.release()); - } - } - return nullptr; - } - case MatchExpression::ENCRYPTED_BETWEEN: { - if (gFeatureFlagFLE2Range.isEnabled(serverGlobalParams.featureCompatibility)) { - return rewriteRange( - std::move(static_cast<const EncryptedBetweenMatchExpression*>(expr))); - } - return nullptr; - } - case MatchExpression::EXPRESSION: { - // Save the current value of _rewroteLastExpression, since rewriteExpression() may - // reset it to false and we may have already done a match expression rewrite. 
- auto didRewrite = _rewroteLastExpression; - auto rewritten = - rewriteExpression(static_cast<ExprMatchExpression*>(expr)->getExpression().get()); - _rewroteLastExpression |= didRewrite; - if (rewritten) { - return std::make_unique<ExprMatchExpression>(rewritten.release(), expCtx()); - } - [[fallthrough]]; - } - default: - return nullptr; - } -} - -BSONObj FLEQueryRewriter::rewriteEqualityPayloadAsTags(BSONElement fleFindPayload) const { - auto tokens = ParsedFindPayload(fleFindPayload); - auto tags = readTags(*_escReader, - *_eccReader, - tokens.escToken, - tokens.eccToken, - tokens.edcToken, - tokens.maxCounter); - - auto bab = BSONArrayBuilder(); - for (auto tag : tags) { - bab.appendBinData(tag.size(), BinDataType::BinDataGeneral, tag.data()); - } - - return bab.obj().getOwned(); -} - -std::vector<Value> FLEQueryRewriter::rewriteEqualityPayloadAsTags(Value fleFindPayload) const { - auto tokens = ParsedFindPayload(fleFindPayload); - auto tags = readTags(*_escReader, - *_eccReader, - tokens.escToken, - tokens.eccToken, - tokens.edcToken, - tokens.maxCounter); - - std::vector<Value> tagVec; - for (auto tag : tags) { - tagVec.push_back(Value(BSONBinData(tag.data(), tag.size(), BinDataType::BinDataGeneral))); - } - return tagVec; -} - -BSONObj FLEQueryRewriter::rewriteRangePayloadAsTags(BSONElement fleFindPayload) const { - // TODO: SERVER-67206 - return BSONObj(); -} - -std::vector<Value> FLEQueryRewriter::rewriteRangePayloadAsTags(Value fleFindPayload) const { - // TODO: SERVER-67206 - return std::vector({Value(0)}); -} - -std::unique_ptr<MatchExpression> FLEQueryRewriter::rewriteEq(const EqualityMatchExpression* expr) { - auto ffp = expr->getData(); - if (!isFleFindPayload(ffp, EncryptedBinDataType::kFLE2FindEqualityPayload)) { - return nullptr; - } - - if (_mode != EncryptedCollScanMode::kForceAlways) { - try { - auto obj = rewriteEqualityPayloadAsTags(ffp); - - auto tags = std::vector<BSONElement>(); - obj.elems(tags); - - auto inExpr = 
std::make_unique<InMatchExpression>(kSafeContent); - inExpr->setBackingBSON(std::move(obj)); - auto status = inExpr->setEqualities(std::move(tags)); - uassertStatusOK(status); - _rewroteLastExpression = true; - return inExpr; - } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) { - LOGV2_DEBUG(6672410, - 2, - "FLE Max tag limit hit during query $eq rewrite", - "__error__"_attr = ex.what()); - - if (_mode != EncryptedCollScanMode::kUseIfNeeded) { - throw; - } - - // fall through - } - } - - auto exprMatch = generateFleEqualMatchAndExpr(expr->path(), ffp, _expCtx.get()); - _rewroteLastExpression = true; - return exprMatch; -} - -std::unique_ptr<MatchExpression> FLEQueryRewriter::rewriteIn(const InMatchExpression* expr) { - size_t numFFPs = 0; - for (auto& eq : expr->getEqualities()) { - if (isFleFindPayload(eq, EncryptedBinDataType::kFLE2FindEqualityPayload)) { - ++numFFPs; - } - } - - if (numFFPs == 0) { - return nullptr; - } - - // All elements in an encrypted $in expression should be FFPs. 
- uassert( - 6329400, - "If any elements in a $in expression are encrypted, then all elements should be encrypted.", - numFFPs == expr->getEqualities().size()); - - if (_mode != EncryptedCollScanMode::kForceAlways) { - - try { - auto backingBSONBuilder = BSONArrayBuilder(); - - for (auto& eq : expr->getEqualities()) { - auto obj = rewriteEqualityPayloadAsTags(eq); - for (auto&& elt : obj) { - backingBSONBuilder.append(elt); - } - } - - auto backingBSON = backingBSONBuilder.arr(); - auto allTags = std::vector<BSONElement>(); - backingBSON.elems(allTags); - - auto inExpr = std::make_unique<InMatchExpression>(kSafeContent); - inExpr->setBackingBSON(std::move(backingBSON)); - auto status = inExpr->setEqualities(std::move(allTags)); - uassertStatusOK(status); - - _rewroteLastExpression = true; - return inExpr; - - } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) { - LOGV2_DEBUG(6672411, - 2, - "FLE Max tag limit hit during query $in rewrite", - "__error__"_attr = ex.what()); - - if (_mode != EncryptedCollScanMode::kUseIfNeeded) { - throw; - } - - // fall through - } - } - - std::vector<std::unique_ptr<MatchExpression>> matches; - matches.reserve(numFFPs); - - for (auto& eq : expr->getEqualities()) { - auto exprMatch = generateFleEqualMatchAndExpr(expr->path(), eq, _expCtx.get()); - matches.push_back(std::move(exprMatch)); - } - - auto orExpr = std::make_unique<OrMatchExpression>(std::move(matches)); - _rewroteLastExpression = true; - return orExpr; -} - -std::unique_ptr<MatchExpression> FLEQueryRewriter::rewriteRange( - const EncryptedBetweenMatchExpression* expr) { - auto ffp = expr->rhs(); - - if (!isFleFindPayload(ffp, EncryptedBinDataType::kFLE2FindRangePayload)) { - return nullptr; - } - - auto obj = rewriteRangePayloadAsTags(ffp); - auto tags = std::vector<BSONElement>(); - obj.elems(tags); - auto inExpr = std::make_unique<InMatchExpression>(kSafeContent); - inExpr->setBackingBSON(std::move(obj)); - auto status = 
inExpr->setEqualities(std::move(tags)); - uassertStatusOK(status); - _rewroteLastExpression = true; - return inExpr; -} - } // namespace mongo::fle diff --git a/src/mongo/db/query/fle/server_rewrite.h b/src/mongo/db/query/fle/server_rewrite.h index d3d52b607f2..c086548b23d 100644 --- a/src/mongo/db/query/fle/server_rewrite.h +++ b/src/mongo/db/query/fle/server_rewrite.h @@ -40,19 +40,16 @@ #include "mongo/db/namespace_string.h" #include "mongo/db/pipeline/expression_context.h" #include "mongo/db/query/count_command_gen.h" +#include "mongo/db/query/fle/query_rewriter_interface.h" #include "mongo/db/transaction/transaction_api.h" +/** + * This file contains the interface for rewriting filters within CRUD commands for FLE2. + */ namespace mongo { class FLEQueryInterface; namespace fle { -/** - * Low Selectivity rewrites use $expr which is not supported in all commands such as upserts. - */ -enum class EncryptedCollScanModeAllowed { - kAllow, - kDisallow, -}; /** * Make a collator object from its BSON representation. Useful when creating ExpressionContext @@ -116,152 +113,5 @@ BSONObj rewriteEncryptedFilterInsideTxn( boost::intrusive_ptr<ExpressionContext> expCtx, BSONObj filter, EncryptedCollScanModeAllowed mode = EncryptedCollScanModeAllowed::kDisallow); - -/** - * Class which handles rewriting filter MatchExpressions for FLE2. The functionality is encapsulated - * as a class rather than just a namespace so that the collection readers don't have to be passed - * around as extra arguments to every function. - * - * Exposed in the header file for unit testing purposes. External callers should use the - * rewriteEncryptedFilterInsideTxn() helper function defined above. 
- */ -class FLEQueryRewriter { -public: - enum class EncryptedCollScanMode { - // Always use high cardinality filters, used by tests - kForceAlways, - - // Use high cardinality mode if $in rewrites do not fit in the - // internalQueryFLERewriteMemoryLimit memory limit - kUseIfNeeded, - - // Do not rewrite into high cardinality filter, throw exceptions instead - // Some contexts like upsert do not support $expr - kDisallow, - }; - - /** - * Takes in references to collection readers for the ESC and ECC that are used during tag - * computation. - */ - FLEQueryRewriter(boost::intrusive_ptr<ExpressionContext> expCtx, - const FLEStateCollectionReader& escReader, - const FLEStateCollectionReader& eccReader, - EncryptedCollScanModeAllowed mode = EncryptedCollScanModeAllowed::kAllow) - : _expCtx(expCtx), _escReader(&escReader), _eccReader(&eccReader) { - - if (internalQueryFLEAlwaysUseEncryptedCollScanMode.load()) { - _mode = EncryptedCollScanMode::kForceAlways; - } - - if (mode == EncryptedCollScanModeAllowed::kDisallow) { - _mode = EncryptedCollScanMode::kDisallow; - } - - // This isn't the "real" query so we don't want to increment Expression - // counters here. - _expCtx->stopExpressionCounters(); - } - - /** - * Accepts a BSONObj holding a MatchExpression, and returns BSON representing the rewritten - * expression. Returns boost::none if no rewriting was done. - * - * Rewrites the match expression with FLE find payloads into a disjunction on the - * __safeContent__ array of tags. - * - * Will rewrite top-level $eq and $in expressions, as well as recursing through $and, $or, $not - * and $nor. Also handles similarly limited rewriting under $expr. All other MatchExpressions, - * notably $elemMatch, are ignored. - */ - boost::optional<BSONObj> rewriteMatchExpression(const BSONObj& filter); - - /** - * Accepts an expression to be re-written. Will rewrite top-level expressions including $eq and - * $in, as well as recursing through other expressions. 
Returns a new pointer if the top-level - * expression must be changed. A nullptr indicates that the modifications happened in-place. - */ - std::unique_ptr<Expression> rewriteExpression(Expression* expression); - - /** - * Determine whether a given BSONElement is in fact a FLE find payload by checking that it is - * the same type as the given EncryptedBinDataType. Sub-type 6, sub-sub-type determined by - * "type." - */ - virtual bool isFleFindPayload(const BSONElement& elt, EncryptedBinDataType type) const { - if (!elt.isBinData(BinDataType::Encrypt)) { - return false; - } - int dataLen; - auto data = elt.binData(dataLen); - return dataLen >= 1 && data[0] == static_cast<uint8_t>(type); - } - - /** - * Determine whether a given Value is in fact a FLE find payload by checking that it is the same - * type as the given EncryptedBinDataType. Sub-type 6, sub-sub-type determined by "type." - */ - bool isFleFindPayload(const Value& v, EncryptedBinDataType type) const { - if (v.getType() != BSONType::BinData) { - return false; - } - - auto binData = v.getBinData(); - return binData.type == BinDataType::Encrypt && binData.length >= 1 && - static_cast<uint8_t>(type) == static_cast<const uint8_t*>(binData.data)[0]; - } - - std::vector<Value> rewriteEqualityPayloadAsTags(Value fleFindPayload) const; - std::vector<Value> rewriteRangePayloadAsTags(Value fleFindPayload) const; - - ExpressionContext* expCtx() { - return _expCtx.get(); - } - - bool isForceEncryptedCollScan() const { - return _mode == EncryptedCollScanMode::kForceAlways; - } - - void setForceEncryptedCollScanForTest() { - _mode = EncryptedCollScanMode::kForceAlways; - } - - EncryptedCollScanMode getEncryptedCollScanMode() const { - return _mode; - } - -protected: - // This constructor should only be used for mocks in testing. 
- FLEQueryRewriter(boost::intrusive_ptr<ExpressionContext> expCtx) - : _expCtx(expCtx), _escReader(nullptr), _eccReader(nullptr) {} - -private: - /** - * A single rewrite step, called recursively on child expressions. - */ - std::unique_ptr<MatchExpression> _rewrite(MatchExpression* me); - - virtual BSONObj rewriteEqualityPayloadAsTags(BSONElement fleFindPayload) const; - - virtual BSONObj rewriteRangePayloadAsTags(BSONElement fleFindPayload) const; - std::unique_ptr<MatchExpression> rewriteEq(const EqualityMatchExpression* expr); - std::unique_ptr<MatchExpression> rewriteIn(const InMatchExpression* expr); - std::unique_ptr<MatchExpression> rewriteRange(const EncryptedBetweenMatchExpression* expr); - - boost::intrusive_ptr<ExpressionContext> _expCtx; - - // Holds a pointer so that these can be null for tests, even though the public constructor - // takes a const reference. - const FLEStateCollectionReader* _escReader; - const FLEStateCollectionReader* _eccReader; - - // True if the last Expression or MatchExpression processed by this rewriter was rewritten. - bool _rewroteLastExpression = false; - - // Controls how query rewriter rewrites the query - EncryptedCollScanMode _mode{EncryptedCollScanMode::kUseIfNeeded}; -}; - - } // namespace fle } // namespace mongo diff --git a/src/mongo/db/query/fle/server_rewrite_test.cpp b/src/mongo/db/query/fle/server_rewrite_test.cpp deleted file mode 100644 index fd910390eb9..00000000000 --- a/src/mongo/db/query/fle/server_rewrite_test.cpp +++ /dev/null @@ -1,1085 +0,0 @@ -/** - * Copyright (C) 2022-present MongoDB, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the Server Side Public License, version 1, - * as published by MongoDB, Inc. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the - * Server Side Public License for more details. - * - * You should have received a copy of the Server Side Public License - * along with this program. If not, see - * <http://www.mongodb.com/licensing/server-side-public-license>. - * - * As a special exception, the copyright holders give permission to link the - * code of portions of this program with the OpenSSL library under certain - * conditions as described in each individual source file and distribute - * linked combinations including the program with the OpenSSL library. You - * must comply with the Server Side Public License in all respects for - * all of the code used other than as permitted herein. If you modify file(s) - * with this exception, you may extend this exception to your version of the - * file(s), but you are not obligated to do so. If you do not wish to do so, - * delete this exception statement from your version. If you delete this - * exception statement from all source files in the program, then also delete - * it in the license file. - */ - - -#include <memory> - -#include "mongo/bson/bsonelement.h" -#include "mongo/bson/bsonmisc.h" -#include "mongo/bson/bsonobjbuilder.h" -#include "mongo/bson/bsontypes.h" -#include "mongo/db/matcher/expression_leaf.h" -#include "mongo/db/pipeline/expression_context_for_test.h" -#include "mongo/db/query/fle/server_rewrite.h" -#include "mongo/idl/server_parameter_test_util.h" -#include "mongo/unittest/unittest.h" -#include "mongo/util/assert_util.h" - - -namespace mongo { -namespace { - -class BasicMockFLEQueryRewriter : public fle::FLEQueryRewriter { -public: - BasicMockFLEQueryRewriter() : fle::FLEQueryRewriter(new ExpressionContextForTest()) {} - - BSONObj rewriteMatchExpressionForTest(const BSONObj& obj) { - auto res = rewriteMatchExpression(obj); - return res ? res.value() : obj; - } - - /* Given a vector of BSONArrays, concatenate them into one BSONArray. 
- * - * E.g., given vec = [{1, 2, 3}, {4, 5, 6}, {21, 34}] this will return - * {1, 2, 3, 4, 5, 6, 21, 34} */ - BSONArray concatBSONArrays(std::vector<BSONArray> vec) const { - auto backingBSONBuilder = BSONArrayBuilder(); - - for (auto& arr : vec) { - for (auto&& elt : arr) { - backingBSONBuilder.append(elt); - } - } - return backingBSONBuilder.arr(); - } -}; - -class MockFLEQueryRewriter : public BasicMockFLEQueryRewriter { -public: - MockFLEQueryRewriter() : _tags() {} - - bool isFleFindPayload(const BSONElement& fleFindPayload, - EncryptedBinDataType type) const override { - switch (type) { - case EncryptedBinDataType::kFLE2FindEqualityPayload: { - return _encryptedFields.find(fleFindPayload.fieldNameStringData()) != - _encryptedFields.end(); - } - case EncryptedBinDataType::kFLE2FindRangePayload: { - // By definition, $encryptedBetween only ever has an encrypted payload. - return true; - } - default: - return false; - } - } - - void setEncryptedTags(std::pair<StringData, int> fieldvalue, BSONObj tags) { - _encryptedFields.insert(fieldvalue.first); - _tags[fieldvalue] = tags; - } - -private: - BSONObj rewriteEqualityPayloadAsTags(BSONElement fleFindPayload) const override { - ASSERT(fleFindPayload.isNumber()); // Only accept numbers as mock FFPs. 
- ASSERT(_tags.find({fleFindPayload.fieldNameStringData(), fleFindPayload.Int()}) != - _tags.end()); - return _tags.find({fleFindPayload.fieldNameStringData(), fleFindPayload.Int()})->second; - }; - - BSONObj rewriteRangePayloadAsTags(BSONElement fleFindPayload) const override { - auto parsedPayload = fleFindPayload.Obj().firstElement(); - auto fieldName = parsedPayload.fieldNameStringData(); - - std::vector<BSONElement> range; - auto payloadAsArray = parsedPayload.Array(); - for (auto&& elt : payloadAsArray) { - range.push_back(elt); - } - - std::vector<BSONArray> allTags; - for (auto i = range[0].Number(); i <= range[1].Number(); i++) { - ASSERT(_tags.find({fieldName, i}) != _tags.end()); - auto temp = _tags.find({fieldName, i})->second; - allTags.push_back(BSONArray(temp)); - } - return concatBSONArrays(allTags); - }; - - std::map<std::pair<StringData, int>, BSONObj> _tags; - std::set<StringData> _encryptedFields; -}; - -class FLEServerRewriteTest : public unittest::Test { -public: - FLEServerRewriteTest() {} - - void setUp() override {} - - void tearDown() override {} - -protected: - MockFLEQueryRewriter _mock; -}; - -TEST_F(FLEServerRewriteTest, NoFFP_Equality) { - auto match = fromjson("{ssn: '5'}"); - auto expected = fromjson("{ssn: '5'}}"); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, NoFFP_In) { - auto match = fromjson("{ssn: {$in: ['5', '6', '7']}}"); - auto expected = fromjson("{ssn: {$in: ['5', '6', '7']}}"); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, TopLevel_Equality) { - auto match = fromjson("{ssn: 5}"); - auto tags = BSON_ARRAY(1 << 2 << 3); - - _mock.setEncryptedTags({"ssn", 5}, tags); - auto expected = BSON(kSafeContent << BSON("$in" << tags)); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - 
-TEST_F(FLEServerRewriteTest, TopLevel_Equality_DottedPath) { - auto match = fromjson("{'user.ssn': {$eq: 5}}"); - auto tags = BSON_ARRAY(1 << 2 << 3); - - _mock.setEncryptedTags({"user.ssn", 5}, tags); - auto expected = BSON(kSafeContent << BSON("$in" << tags)); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, TopLevel_In) { - auto match = fromjson("{ssn: {$in: [2, 4, 6]}}"); - - // The key/value pairs that the mock functions use to determine the fake FFPs are inside an - // array, and so the keys are the index values and the values are the actual array elements. - _mock.setEncryptedTags({"0", 2}, BSON_ARRAY(1 << 2)); - _mock.setEncryptedTags({"1", 4}, BSON_ARRAY(5 << 3)); - _mock.setEncryptedTags({"2", 6}, BSON_ARRAY(99 << 100)); - - // Order doesn't matter in a disjunction. - auto expected = BSON(kSafeContent << BSON("$in" << BSON_ARRAY(1 << 2 << 3 << 5 << 99 << 100))); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, TopLevel_In_DottedPath) { - auto match = fromjson("{'user.ssn': {$in: [2, 4, 6]}}"); - - // The key/value pairs that the mock functions use to determine the fake FFPs are inside an - // array, and so the keys are the index values and the values are the actual array elements. - _mock.setEncryptedTags({"0", 2}, BSON_ARRAY(1 << 2)); - _mock.setEncryptedTags({"1", 4}, BSON_ARRAY(5 << 3)); - _mock.setEncryptedTags({"2", 6}, BSON_ARRAY(99 << 100)); - - // Order doesn't matter in a disjunction. 
- auto expected = BSON(kSafeContent << BSON("$in" << BSON_ARRAY(1 << 2 << 3 << 5 << 99 << 100))); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, TopLevel_Conjunction_BothEncrypted) { - auto match = fromjson("{$and: [{ssn: 5}, {age: 36}]}"); - auto ssnTags = BSON_ARRAY(1 << 2 << 3); - auto ageTags = BSON_ARRAY(22 << 44 << 66); - - _mock.setEncryptedTags({"ssn", 5}, ssnTags); - _mock.setEncryptedTags({"age", 36}, ageTags); - auto expected = BSON("$and" << BSON_ARRAY(BSON(kSafeContent << BSON("$in" << ssnTags)) - << BSON(kSafeContent << BSON("$in" << ageTags)))); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, TopLevel_Conjunction_PartlyEncrypted) { - auto match = fromjson("{$and: [{ssn: 5}, {notSsn: 6}]}"); - auto tags = BSON_ARRAY(1 << 2 << 3); - - _mock.setEncryptedTags({"ssn", 5}, tags); - auto expected = BSON("$and" << BSON_ARRAY(BSON(kSafeContent << BSON("$in" << tags)) - << BSON("notSsn" << BSON("$eq" << 6)))); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, TopLevel_CompoundEquality_PartlyEncrypted) { - auto match = fromjson("{ssn: 5, notSsn: 6}"); - auto tags = BSON_ARRAY(1 << 2 << 3); - - _mock.setEncryptedTags({"ssn", 5}, tags); - auto expected = BSON("$and" << BSON_ARRAY(BSON(kSafeContent << BSON("$in" << tags)) - << BSON("notSsn" << BSON("$eq" << 6)))); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, TopLevel_Encrypted_Nested_Unencrypted) { - auto match = fromjson("{ssn: 5, user: {region: 'US'}}"); - auto tags = BSON_ARRAY(1 << 2 << 3); - - _mock.setEncryptedTags({"ssn", 5}, tags); - auto expected = BSON("$and" << BSON_ARRAY(BSON(kSafeContent << BSON("$in" << tags)) - << BSON("user" << BSON("$eq" << 
BSON("region" - << "US"))))); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, TopLevel_Not_Equality) { - auto match = fromjson("{ssn: {$not: {$eq: 5}}}"); - auto tags = BSON_ARRAY(1 << 2 << 3); - - _mock.setEncryptedTags({"ssn", 5}, tags); - auto expected = BSON(kSafeContent << BSON("$not" << BSON("$in" << tags))); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, TopLevel_Neq) { - auto match = fromjson("{ssn: {$ne: 5}}"); - auto tags = BSON_ARRAY(1 << 2 << 3); - - _mock.setEncryptedTags({"ssn", 5}, tags); - auto expected = BSON(kSafeContent << BSON("$not" << BSON("$in" << tags))); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - - -TEST_F(FLEServerRewriteTest, TopLevel_And_In) { - auto match = fromjson("{$and: [{ssn: {$in: [2, 4, 6]}}, {region: 'US'}]}"); - - _mock.setEncryptedTags({"0", 2}, BSON_ARRAY(1 << 2)); - _mock.setEncryptedTags({"1", 4}, BSON_ARRAY(5 << 3)); - _mock.setEncryptedTags({"2", 6}, BSON_ARRAY(99 << 100)); - - auto expected = - BSON("$and" << BSON_ARRAY( - BSON(kSafeContent << BSON("$in" << BSON_ARRAY(1 << 2 << 3 << 5 << 99 << 100))) - << BSON("region" << BSON("$eq" - << "US")))); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, NestedConjunction) { - auto match = fromjson("{$and: [{$and: [{ssn: 2}, {other: 3}]}, {otherSsn: 5}]}"); - - _mock.setEncryptedTags({"ssn", 2}, BSON_ARRAY(1 << 2)); - _mock.setEncryptedTags({"otherSsn", 5}, BSON_ARRAY(3 << 4)); - - auto expected = fromjson(R"( - { $and: [ - { $and: [ - { __safeContent__: { $in: [ 1, 2 ] } }, - { other: { $eq: 3 } } - ] }, - { __safeContent__: { $in: [ 3, 4 ] } } - ] })"); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - 
-TEST_F(FLEServerRewriteTest, TopLevel_Nor_Equality) { - auto match = fromjson("{$nor: [{ssn: 5}]}"); - auto tags = BSON_ARRAY(1 << 2 << 3); - - _mock.setEncryptedTags({"ssn", 5}, tags); - auto expected = BSON("$nor" << BSON_ARRAY(BSON(kSafeContent << BSON("$in" << tags)))); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, TopLevel_Nor_Equality_WithUnencrypted) { - auto match = fromjson("{$nor: [{ssn: 5}, {region: 'US'}]}"); - auto tags = BSON_ARRAY(1 << 2 << 3); - - _mock.setEncryptedTags({"ssn", 5}, tags); - auto expected = BSON("$nor" << BSON_ARRAY(BSON(kSafeContent << BSON("$in" << tags)) - << BSON("region" << BSON("$eq" - << "US")))); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, TopLevel_Or_Equality_WithUnencrypted) { - auto match = fromjson("{$or: [{ssn: 5}, {region: 'US'}]}"); - auto tags = BSON_ARRAY(1 << 2 << 3); - - _mock.setEncryptedTags({"ssn", 5}, tags); - auto expected = BSON("$or" << BSON_ARRAY(BSON(kSafeContent << BSON("$in" << tags)) - << BSON("region" << BSON("$eq" - << "US")))); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, TopLevel_Not_In) { - auto match = fromjson("{ssn: {$not: {$in: [2, 4, 6]}}}"); - - _mock.setEncryptedTags({"0", 2}, BSON_ARRAY(1 << 2)); - _mock.setEncryptedTags({"1", 4}, BSON_ARRAY(5 << 3)); - _mock.setEncryptedTags({"2", 6}, BSON_ARRAY(99 << 100)); - - auto expected = BSON( - kSafeContent << BSON("$not" << BSON("$in" << BSON_ARRAY(1 << 2 << 3 << 5 << 99 << 100)))); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, TopLevel_Nin) { - auto match = fromjson("{ssn: {$nin: [2, 4, 6]}}"); - - _mock.setEncryptedTags({"0", 2}, BSON_ARRAY(1 << 2)); - _mock.setEncryptedTags({"1", 4}, 
BSON_ARRAY(5 << 3)); - _mock.setEncryptedTags({"2", 6}, BSON_ARRAY(99 << 100)); - - // Order doesn't matter in a disjunction. - auto expected = BSON( - kSafeContent << BSON("$not" << BSON("$in" << BSON_ARRAY(1 << 2 << 3 << 5 << 99 << 100)))); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, InMixOfEncryptedElementsIsDisallowed) { - auto match = fromjson("{ssn: {$in: [2, 4, 6]}}"); - - _mock.setEncryptedTags({"0", 2}, BSON_ARRAY(1 << 2)); - _mock.setEncryptedTags({"1", 4}, BSON_ARRAY(5 << 3)); - - ASSERT_THROWS_CODE(_mock.rewriteMatchExpressionForTest(match), AssertionException, 6329400); -} - -TEST_F(FLEServerRewriteTest, ComparisonToObjectIgnored) { - // Although such a query should fail in query analysis, it's not realistic for us to catch all - // the ways a FLEFindPayload could be improperly included in an explicitly encrypted query, so - // this test demonstrates the server side behavior. - { - auto match = fromjson("{user: {$eq: {ssn: 5}}}"); - - _mock.setEncryptedTags({"user.ssn", 5}, BSON_ARRAY(1 << 2)); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, match); - } - { - auto match = fromjson("{user: {$in: [{ssn: 5}]}}"); - - _mock.setEncryptedTags({"user.ssn", 5}, BSON_ARRAY(1 << 2)); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, match); - } -} - -TEST_F(FLEServerRewriteTest, EncryptedBetweenBasic) { - RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true); - - int start = 1; - int end = 3; - StringData encField = "ssn"; - - // The field redundancy is so that we can pull out the field - // name in the mock version of rewriteRangePayloadAsTags. 
- BSONObj query = - BSON(encField << BSON("$encryptedBetween" << BSON(encField << BSON_ARRAY(start << end)))); - - auto tags1 = BSON_ARRAY(1 << 2 << 3); - auto tags2 = BSON_ARRAY("A" - << "F" - << "Q"); - auto tags3 = BSON_ARRAY("aHb" - << "jkl" - << "q76"); - - std::vector<BSONArray> allTags = {tags1, tags2, tags3}; - BSONArray tagsConcat = _mock.concatBSONArrays(allTags); - - for (int i = 0; i <= (end - start); i++) { - _mock.setEncryptedTags({encField, (start + i)}, allTags[i]); - } - - auto expected = BSON(kSafeContent << BSON("$in" << tagsConcat)); - auto actual = _mock.rewriteMatchExpressionForTest(query); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, EncryptedBetweenFeatureFlagFalse) { - RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", false); - - int start = 1; - int end = 3; - StringData encField = "ssn"; - - BSONObj query = - BSON(encField << BSON("$encryptedBetween" << BSON(encField << BSON_ARRAY(start << end)))); - - auto tags1 = BSON_ARRAY(1 << 2 << 3); - auto tags2 = BSON_ARRAY(4 << 5 << 6); - auto tags3 = BSON_ARRAY(7 << 8 << 9); - - std::vector<BSONArray> allTags = {tags1, tags2, tags3}; - BSONArray tagsConcat = _mock.concatBSONArrays(allTags); - - for (int i = 0; i <= (end - start); i++) { - _mock.setEncryptedTags({encField, (start + i)}, allTags[i]); - } - - // No rewrite should occur since the feature flag has been set to false. 
- auto actual = _mock.rewriteMatchExpressionForTest(query); - ASSERT_BSONOBJ_EQ(actual, query); -} - -TEST_F(FLEServerRewriteTest, EncryptedBetweenVariableNumberOfTags) { - RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true); - - int start = 1; - int end = 3; - StringData encField = "ssn"; - - BSONObj query = - BSON(encField << BSON("$encryptedBetween" << BSON(encField << BSON_ARRAY(start << end)))); - - auto tags1 = BSON_ARRAY(1); - auto tags2 = BSON_ARRAY("A" - << "F"); - auto tags3 = BSON_ARRAY("aHb" - << "bcdefdfg12243" - << "c" - << "d" - << "e" - << "f" - << "g" - << "hij" - << "kl78h"); - - std::vector<BSONArray> allTags = {tags1, tags2, tags3}; - BSONArray tagsConcat = _mock.concatBSONArrays(allTags); - - for (int i = 0; i <= (end - start); i++) { - _mock.setEncryptedTags({encField, (start + i)}, allTags[i]); - } - - auto expected = BSON(kSafeContent << BSON("$in" << tagsConcat)); - auto actual = _mock.rewriteMatchExpressionForTest(query); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, EncryptedBetweenInsideNot) { - RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true); - - int start = 1; - int end = 3; - StringData encField = "ssn"; - BSONObj query = - BSON(encField << BSON("$not" << BSON("$encryptedBetween" - << BSON(encField << BSON_ARRAY(start << end))))); - - auto tags1 = BSON_ARRAY(1 << 2 << 3); - auto tags2 = BSON_ARRAY("A" - << "F" - << "Q"); - auto tags3 = BSON_ARRAY("aHb" - << "jkl" - << "q76"); - - std::vector<BSONArray> allTags = {tags1, tags2, tags3}; - BSONArray tagsConcat = _mock.concatBSONArrays(allTags); - - for (int i = 0; i <= (end - start); i++) { - _mock.setEncryptedTags({encField, (start + i)}, allTags[i]); - } - - auto expected = BSON(kSafeContent << BSON("$not" << BSON("$in" << tagsConcat))); - auto actual = _mock.rewriteMatchExpressionForTest(query); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, EncryptedBetweenDottedPath) { - 
RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true); - - int start = 1; - int end = 3; - StringData encField = "hello.world"; - BSONObj query = - BSON(encField << BSON("$encryptedBetween" << BSON(encField << BSON_ARRAY(start << end)))); - - auto tags1 = BSON_ARRAY(1 << 2 << 3); - auto tags2 = BSON_ARRAY("A" - << "F" - << "Q"); - auto tags3 = BSON_ARRAY("aHb" - << "jkl" - << "q76"); - - std::vector<BSONArray> allTags = {tags1, tags2, tags3}; - BSONArray tagsConcat = _mock.concatBSONArrays(allTags); - - for (int i = 0; i <= (end - start); i++) { - _mock.setEncryptedTags({encField, (start + i)}, allTags[i]); - } - - auto expected = BSON(kSafeContent << BSON("$in" << tagsConcat)); - auto actual = _mock.rewriteMatchExpressionForTest(query); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, EncryptedBetweenInsideAnd) { - RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true); - - int start = 1; - int end = 3; - StringData encField = "ssn"; - BSONObj query = BSON( - "$and" << BSON_ARRAY(BSON("x" << 5) - << BSON(encField << BSON("$encryptedBetween" << BSON( - encField << BSON_ARRAY(start << end)))))); - - auto tags1 = BSON_ARRAY(1 << 2 << 3); - auto tags2 = BSON_ARRAY("A" - << "F" - << "Q"); - auto tags3 = BSON_ARRAY("aHb" - << "jkl" - << "q76"); - - std::vector<BSONArray> allTags = {tags1, tags2, tags3}; - BSONArray tagsConcat = _mock.concatBSONArrays(allTags); - - for (int i = 0; i <= (end - start); i++) { - _mock.setEncryptedTags({encField, (start + i)}, allTags[i]); - } - - auto expected = BSON("$and" << BSON_ARRAY(BSON("x" << BSON("$eq" << 5)) - << BSON(kSafeContent << BSON("$in" << tagsConcat)))); - auto actual = _mock.rewriteMatchExpressionForTest(query); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, EncryptedBetweenInsideOr) { - RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true); - - int start = 1; - int end = 3; - StringData encField = 
"ssn"; - BSONObj query = BSON( - "$or" << BSON_ARRAY(BSON("x" << 5) - << BSON(encField << BSON("$encryptedBetween" << BSON( - encField << BSON_ARRAY(start << end)))))); - - auto tags1 = BSON_ARRAY(1 << 2 << 3); - auto tags2 = BSON_ARRAY("A" - << "F" - << "Q"); - auto tags3 = BSON_ARRAY("aHb" - << "jkl" - << "q76"); - - std::vector<BSONArray> allTags = {tags1, tags2, tags3}; - BSONArray tagsConcat = _mock.concatBSONArrays(allTags); - - for (int i = 0; i <= (end - start); i++) { - _mock.setEncryptedTags({encField, (start + i)}, allTags[i]); - } - - auto expected = BSON("$or" << BSON_ARRAY(BSON("x" << BSON("$eq" << 5)) - << BSON(kSafeContent << BSON("$in" << tagsConcat)))); - auto actual = _mock.rewriteMatchExpressionForTest(query); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, EncryptedBetweeenAndEncryptedEquality) { - RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true); - - int start = 1; - int end = 3; - StringData encField = "ssn"; - BSONObj query = BSON( - "$and" << BSON_ARRAY(BSON("x" << 21) - << BSON(encField << BSON("$encryptedBetween" << BSON( - encField << BSON_ARRAY(start << end)))))); - - auto tags1 = BSON_ARRAY(1 << 2 << 3); - auto tags2 = BSON_ARRAY("A" - << "F" - << "Q"); - auto tags3 = BSON_ARRAY("aHb" - << "jkl" - << "q76"); - - std::vector<BSONArray> allTags = {tags1, tags2, tags3}; - BSONArray tagsConcat = _mock.concatBSONArrays(allTags); - - for (int i = 0; i <= (end - start); i++) { - _mock.setEncryptedTags({encField, (start + i)}, allTags[i]); - } - - BSONArray equalityTags = BSON_ARRAY(312 << 567 << 897); - _mock.setEncryptedTags({"x", 21}, equalityTags); - - auto expected = BSON("$and" << BSON_ARRAY(BSON(kSafeContent << BSON("$in" << equalityTags)) - << BSON(kSafeContent << BSON("$in" << tagsConcat)))); - auto actual = _mock.rewriteMatchExpressionForTest(query); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, EncryptedBetweenAndEncryptedIn) { - 
RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true); - - int start = 1; - int end = 3; - StringData encBetweenField = "ssn"; - StringData encInField = "age"; - BSONObj query = - BSON("$and" << BSON_ARRAY( - BSON(encBetweenField << BSON("$encryptedBetween" - << BSON(encBetweenField << BSON_ARRAY(start << end)))) - << BSON(encInField << BSON("$in" << BSON_ARRAY(10 << 22 << 34))))); - - auto tags1 = BSON_ARRAY(1 << 2 << 3); - auto tags2 = BSON_ARRAY("A" - << "F" - << "Q"); - auto tags3 = BSON_ARRAY("aHb" - << "jkl" - << "q76"); - - std::vector<BSONArray> allTags = {tags1, tags2, tags3}; - BSONArray tagsConcat = _mock.concatBSONArrays(allTags); - - for (int i = 0; i <= (end - start); i++) { - _mock.setEncryptedTags({encBetweenField, (start + i)}, allTags[i]); - } - - auto inTags1 = BSON_ARRAY(1 << 2); - auto inTags2 = BSON_ARRAY(3 << 5); - auto inTags3 = BSON_ARRAY(97 << 98 << 99 << 100); - std::vector<BSONArray> allInTags = {inTags1, inTags2, inTags3}; - _mock.setEncryptedTags({"0", 10}, inTags1); - _mock.setEncryptedTags({"1", 22}, inTags2); - _mock.setEncryptedTags({"2", 34}, inTags3); - BSONArray inTagsConcat = _mock.concatBSONArrays(allInTags); - - - auto expected = - BSON("$and" << BSON_ARRAY(BSON(kSafeContent << BSON("$in" << tagsConcat)) - << BSON(kSafeContent << BSON("$in" << inTagsConcat)))); - auto actual = _mock.rewriteMatchExpressionForTest(query); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, EncryptedBetweenAndUnencryptedRange) { - RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true); - - int start = 1; - int end = 3; - StringData encField = "ssn"; - BSONObj query = BSON( - "$and" << BSON_ARRAY(BSON(encField << BSON("$encryptedBetween" - << BSON(encField << BSON_ARRAY(start << end)))) - << BSON("$and" << BSON_ARRAY(BSON("x" << BSON("$gt" << 10)) - << BSON("x" << BSON("$lt" << 25)))))); - - auto tags1 = BSON_ARRAY(1 << 2 << 3); - auto tags2 = BSON_ARRAY("A" - << "F" - << "Q"); 
- auto tags3 = BSON_ARRAY("aHb" - << "jkl" - << "q76"); - - std::vector<BSONArray> allTags = {tags1, tags2, tags3}; - BSONArray tagsConcat = _mock.concatBSONArrays(allTags); - - for (int i = 0; i <= (end - start); i++) { - _mock.setEncryptedTags({encField, (start + i)}, allTags[i]); - } - - auto expected = BSON( - "$and" << BSON_ARRAY(BSON(kSafeContent << BSON("$in" << tagsConcat)) - << BSON("$and" << BSON_ARRAY(BSON("x" << BSON("$gt" << 10)) - << BSON("x" << BSON("$lt" << 25)))))); - auto actual = _mock.rewriteMatchExpressionForTest(query); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerRewriteTest, EncryptedBetweenOnTwoDiffFields) { - RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true); - - int start1 = 1; - int end1 = 3; - StringData encField1 = "ssn"; - int start2 = 102; - int end2 = 106; - StringData encField2 = "age"; - BSONObj query = - BSON("$and" << BSON_ARRAY( - BSON(encField1 << BSON("$encryptedBetween" - << BSON(encField1 << BSON_ARRAY(start1 << end1)))) - << BSON(encField2 << BSON("$encryptedBetween" - << BSON(encField2 << BSON_ARRAY(start2 << end2)))))); - - auto tags1 = BSON_ARRAY(1 << 2 << 3); - auto tags2 = BSON_ARRAY("A" - << "F" - << "Q"); - auto tags3 = BSON_ARRAY("aHb" - << "jkl" - << "q76"); - - std::vector<BSONArray> allTags1 = {tags1, tags2, tags3}; - BSONArray tagsConcat1 = _mock.concatBSONArrays(allTags1); - - for (int i = 0; i <= (end1 - start1); i++) { - _mock.setEncryptedTags({encField1, (start1 + i)}, allTags1[i]); - } - - auto tags4 = BSON_ARRAY(1 << 2 << 3); - auto tags5 = BSON_ARRAY(4 << 5 << 6); - auto tags6 = BSON_ARRAY(21 << 25 << 45); - auto tags7 = BSON_ARRAY(112 << 212 << 456); - auto tags8 = BSON_ARRAY(908 << 1234 << 23467813); - - std::vector<BSONArray> allTags2 = {tags4, tags5, tags6, tags7, tags8}; - BSONArray tagsConcat2 = _mock.concatBSONArrays(allTags2); - - for (int i = 0; i <= (end2 - start2); i++) { - _mock.setEncryptedTags({encField2, (start2 + i)}, allTags2[i]); - } - - auto 
expected = BSON("$and" << BSON_ARRAY(BSON(kSafeContent << BSON("$in" << tagsConcat1)) - << BSON(kSafeContent << BSON("$in" << tagsConcat2)))); - auto actual = _mock.rewriteMatchExpressionForTest(query); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -template <typename T> -std::vector<uint8_t> toEncryptedVector(EncryptedBinDataType dt, T t) { - BSONObj obj = t.toBSON(); - - std::vector<uint8_t> buf(obj.objsize() + 1); - buf[0] = static_cast<uint8_t>(dt); - - std::copy(obj.objdata(), obj.objdata() + obj.objsize(), buf.data() + 1); - - return buf; -} - -template <typename T> -void toEncryptedBinData(StringData field, EncryptedBinDataType dt, T t, BSONObjBuilder* builder) { - auto buf = toEncryptedVector(dt, t); - - builder->appendBinData(field, buf.size(), BinDataType::Encrypt, buf.data()); -} - -constexpr auto kIndexKeyId = "12345678-1234-9876-1234-123456789012"_sd; -constexpr auto kUserKeyId = "ABCDEFAB-1234-9876-1234-123456789012"_sd; -static UUID indexKeyId = uassertStatusOK(UUID::parse(kIndexKeyId.toString())); -static UUID userKeyId = uassertStatusOK(UUID::parse(kUserKeyId.toString())); - -std::vector<char> testValue = {0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19}; -std::vector<char> testValue2 = {0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29}; - -const FLEIndexKey& getIndexKey() { - static std::string indexVec = hexblob::decode( - "7dbfebc619aa68a659f64b8e23ccd21644ac326cb74a26840c3d2420176c40ae088294d00ad6cae9684237b21b754cf503f085c25cd320bf035c3417416e1e6fe3d9219f79586582112740b2add88e1030d91926ae8afc13ee575cfb8bb965b7"_sd); - static FLEIndexKey indexKey(KeyMaterial(indexVec.begin(), indexVec.end())); - return indexKey; -} - -const FLEUserKey& getUserKey() { - static std::string userVec = hexblob::decode( - "a7ddbc4c8be00d51f68d9d8e485f351c8edc8d2206b24d8e0e1816d005fbe520e489125047d647b0d8684bfbdbf09c304085ed086aba6c2b2b1677ccc91ced8847a733bf5e5682c84b3ee7969e4a5fe0e0c21e5e3ee190595a55f83147d8de2a"_sd); - static FLEUserKey 
userKey(KeyMaterial(userVec.begin(), userVec.end())); - return userKey; -} - - -BSONObj generateFFP(StringData path, int value) { - auto indexKey = getIndexKey(); - FLEIndexKeyAndId indexKeyAndId(indexKey.data, indexKeyId); - auto userKey = getUserKey(); - FLEUserKeyAndId userKeyAndId(userKey.data, indexKeyId); - - BSONObj doc = BSON("value" << value); - auto element = doc.firstElement(); - auto fpp = FLEClientCrypto::serializeFindPayload(indexKeyAndId, userKeyAndId, element, 0); - - BSONObjBuilder builder; - toEncryptedBinData(path, EncryptedBinDataType::kFLE2FindEqualityPayload, fpp, &builder); - return builder.obj(); -} - -class FLEServerHighCardRewriteTest : public unittest::Test { -public: - FLEServerHighCardRewriteTest() {} - - void setUp() override {} - - void tearDown() override {} - -protected: - BasicMockFLEQueryRewriter _mock; -}; - - -TEST_F(FLEServerHighCardRewriteTest, HighCard_TopLevel_Equality) { - _mock.setForceEncryptedCollScanForTest(); - - auto match = generateFFP("ssn", 1); - auto expected = fromjson(R"({ - "$expr": { - "$_internalFleEq": { - "field": "$ssn", - "edc": { - "$binary": { - "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", - "subType": "6" - } - }, - "counter": { - "$numberLong": "0" - }, - "server": { - "$binary": { - "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", - "subType": "6" - } - } - } - } -})"); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - - -TEST_F(FLEServerHighCardRewriteTest, HighCard_TopLevel_In) { - _mock.setForceEncryptedCollScanForTest(); - - auto ffp1 = generateFFP("ssn", 1); - auto ffp2 = generateFFP("ssn", 2); - auto ffp3 = generateFFP("ssn", 3); - auto expected = fromjson(R"({ - "$or": [ - { - "$expr": { - "$_internalFleEq": { - "field": "$ssn", - "edc": { - "$binary": { - "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", - "subType": "6" - } - }, - "counter": { - "$numberLong": "0" - }, - "server": { - "$binary": { - 
"base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", - "subType": "6" - } - } - } - } - }, - { - "$expr": { - "$_internalFleEq": { - "field": "$ssn", - "edc": { - "$binary": { - "base64": "CLpCo6rNuYMVT+6n1HCX15MNrVYDNqf6udO46ayo43Sw", - "subType": "6" - } - }, - "counter": { - "$numberLong": "0" - }, - "server": { - "$binary": { - "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", - "subType": "6" - } - } - } - } - }, - { - "$expr": { - "$_internalFleEq": { - "field": "$ssn", - "edc": { - "$binary": { - "base64": "CPi44oCQHnNDeRqHsNLzbdCeHt2DK/wCly0g2dxU5fqN", - "subType": "6" - } - }, - "counter": { - "$numberLong": "0" - }, - "server": { - "$binary": { - "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", - "subType": "6" - } - } - } - } - } - ] -})"); - - auto match = - BSON("ssn" << BSON("$in" << BSON_ARRAY(ffp1.firstElement() - << ffp2.firstElement() << ffp3.firstElement()))); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - - -TEST_F(FLEServerHighCardRewriteTest, HighCard_TopLevel_Expr) { - - _mock.setForceEncryptedCollScanForTest(); - - auto ffp = generateFFP("$ssn", 1); - int len; - auto v = ffp.firstElement().binDataClean(len); - auto match = BSON("$expr" << BSON("$eq" << BSON_ARRAY(ffp.firstElement().fieldName() - << BSONBinData(v, len, Encrypt)))); - - auto expected = fromjson(R"({ "$expr": { - "$_internalFleEq": { - "field": "$ssn", - "edc": { - "$binary": { - "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", - "subType": "6" - } - }, - "counter": { - "$numberLong": "0" - }, - "server": { - "$binary": { - "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", - "subType": "6" - } - } - } - } - })"); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -TEST_F(FLEServerHighCardRewriteTest, HighCard_TopLevel_Expr_In) { - - _mock.setForceEncryptedCollScanForTest(); - - auto ffp = generateFFP("$ssn", 1); - int len; - 
auto v = ffp.firstElement().binDataClean(len); - - auto ffp2 = generateFFP("$ssn", 1); - int len2; - auto v2 = ffp2.firstElement().binDataClean(len2); - - auto match = BSON( - "$expr" << BSON("$in" << BSON_ARRAY(ffp.firstElement().fieldName() - << BSON_ARRAY(BSONBinData(v, len, Encrypt) - << BSONBinData(v2, len2, Encrypt))))); - - auto expected = fromjson(R"({ "$expr": { "$or" : [ { - "$_internalFleEq": { - "field": "$ssn", - "edc": { - "$binary": { - "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", - "subType": "6" - } - }, - "counter": { - "$numberLong": "0" - }, - "server": { - "$binary": { - "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", - "subType": "6" - } - } - }}, - { - "$_internalFleEq": { - "field": "$ssn", - "edc": { - "$binary": { - "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", - "subType": "6" - } - }, - "counter": { - "$numberLong": "0" - }, - "server": { - "$binary": { - "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", - "subType": "6" - } - } - }} - ]}})"); - - auto actual = _mock.rewriteMatchExpressionForTest(match); - ASSERT_BSONOBJ_EQ(actual, expected); -} - -} // namespace -} // namespace mongo |