diff options
author | Mark Benvenuto <mark.benvenuto@mongodb.com> | 2022-06-15 21:20:11 -0400 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2022-06-16 02:21:54 +0000 |
commit | 33d893ff83c2873202628e6002a7e35b5b296db8 (patch) | |
tree | a14be19387a333d89b587c9a0c406d7ea9b71f8c | |
parent | 86c763b7e96f5c8e990387e29dd63a6c702076b4 (diff) | |
download | mongo-33d893ff83c2873202628e6002a7e35b5b296db8.tar.gz |
SERVER-66724 Create FLE 2 Equality Match Expression
27 files changed, 1452 insertions, 94 deletions
diff --git a/buildscripts/idl/idl/bson.py b/buildscripts/idl/idl/bson.py index 3f8e5190f9d..8216b5d743d 100644 --- a/buildscripts/idl/idl/bson.py +++ b/buildscripts/idl/idl/bson.py @@ -72,6 +72,7 @@ _BINDATA_SUBTYPE = { # }, "uuid": {'scalar': True, 'bindata_enum': 'newUUID'}, "md5": {'scalar': True, 'bindata_enum': 'MD5Type'}, + "encrypt": {'scalar': True, 'bindata_enum': 'Encrypt'}, } diff --git a/buildscripts/resmokeconfig/suites/fle2_high_cardinality.yml b/buildscripts/resmokeconfig/suites/fle2_high_cardinality.yml new file mode 100644 index 00000000000..4e6002ae7d2 --- /dev/null +++ b/buildscripts/resmokeconfig/suites/fle2_high_cardinality.yml @@ -0,0 +1,32 @@ +test_kind: js_test +selector: + roots: + - jstests/fle2/**/*.js + - src/mongo/db/modules/*/jstests/fle2/*.js + - src/mongo/db/modules/*/jstests/fle2/query/*.js + exclude_with_any_tags: + # Not compatible with tests the expect fle to always using $in in queries, + # i.e. verify explain output + - requires_fle_in_always + +executor: + archive: + hooks: + - ValidateCollections + config: + shell_options: + eval: "testingReplication = true; testingFLESharding = false;" + hooks: + # We don't execute dbHash or oplog consistency checks since there is only a single replica set + # node. + - class: ValidateCollections + - class: CleanEveryN + n: 20 + fixture: + class: ReplicaSetFixture + mongod_options: + set_parameters: + enableTestCommands: 1 + internalQueryFLEAlwaysUseHighCardinalityMode: 1 + # Use a 2-node replica set. + num_nodes: 2 diff --git a/buildscripts/resmokeconfig/suites/fle2_sharding_high_cardinality.yml b/buildscripts/resmokeconfig/suites/fle2_sharding_high_cardinality.yml new file mode 100644 index 00000000000..33a3d4e5c1a --- /dev/null +++ b/buildscripts/resmokeconfig/suites/fle2_sharding_high_cardinality.yml @@ -0,0 +1,37 @@ +test_kind: js_test +selector: + roots: + - jstests/fle2/*.js + - src/mongo/db/modules/*/jstests/fle2/*.js + - src/mongo/db/modules/*/jstests/fle2/query/*.js + exclude_with_any_tags: + # Not compatible with tests the expect fle to always using $in in queries, + # i.e. verify explain output + - requires_fle_in_always + +executor: + archive: + hooks: + - CheckReplDBHash + - ValidateCollections + config: + shell_options: + eval: "testingReplication = false; testingFLESharding = true;" + hooks: + - class: CheckReplDBHash + - class: ValidateCollections + - class: CleanEveryN + n: 20 + fixture: + class: ShardedClusterFixture + mongos_options: + set_parameters: + enableTestCommands: 1 + internalQueryFLEAlwaysUseHighCardinalityMode: 1 + mongod_options: + set_parameters: + enableTestCommands: 1 + internalQueryFLEAlwaysUseHighCardinalityMode: 1 + num_rs_nodes_per_shard: 2 + enable_sharding: + - test diff --git a/etc/evergreen_yml_components/definitions.yml b/etc/evergreen_yml_components/definitions.yml index 5b55bbd48ed..2f0ebf2b2ce 100644 --- a/etc/evergreen_yml_components/definitions.yml +++ b/etc/evergreen_yml_components/definitions.yml @@ -6094,6 +6094,20 @@ tasks: - func: "run tests" - <<: *task_template + name: fle2_high_cardinality + tags: ["encrypt"] + commands: + - func: "do setup" + - func: "run tests" + +- <<: *task_template + name: fle2_sharding_high_cardinality + tags: ["encrypt"] + commands: + - func: "do setup" + - func: "run tests" + +- <<: *task_template name: ocsp tags: ["ssl", "encrypt", "ocsp", "patch_build"] commands: diff --git a/src/mongo/base/error_codes.yml b/src/mongo/base/error_codes.yml index 9d02993c5bf..1a9b676eae9 100644 --- a/src/mongo/base/error_codes.yml +++ b/src/mongo/base/error_codes.yml @@ -491,6 +491,7 @@ error_codes: - {code: 374, name: TransactionAPIMustRetryTransaction, categories: [InternalOnly]} - {code: 375, name: TransactionAPIMustRetryCommit, categories: [InternalOnly]} - {code: 376, name: ChangeStreamNotEnabled} + - {code: 377, name: FLEMaxTagLimitExceeded } # Error codes 4000-8999 are reserved. diff --git a/src/mongo/crypto/fle_crypto.cpp b/src/mongo/crypto/fle_crypto.cpp index 38800792351..b7f5eebd3e9 100644 --- a/src/mongo/crypto/fle_crypto.cpp +++ b/src/mongo/crypto/fle_crypto.cpp @@ -53,6 +53,7 @@ #include "mongo/base/error_codes.h" #include "mongo/base/status.h" #include "mongo/bson/bson_depth.h" +#include "mongo/bson/bsonmisc.h" #include "mongo/bson/bsonobj.h" #include "mongo/bson/bsonobjbuilder.h" #include "mongo/bson/bsontypes.h" @@ -153,6 +154,8 @@ PrfBlock blockToArray(const SHA256Block& block) { return data; } +} // namespace + PrfBlock PrfBlockfromCDR(ConstDataRange block) { uassert(6373501, "Invalid prf length", block.length() == sizeof(PrfBlock)); @@ -161,6 +164,7 @@ PrfBlock PrfBlockfromCDR(ConstDataRange block) { return ret; } +namespace { ConstDataRange hmacKey(const KeyMaterial& keyMaterial) { static_assert(kHmacKeyOffset + crypto::sym256KeySize <= crypto::kFieldLevelEncryptionKeySize); invariant(crypto::kFieldLevelEncryptionKeySize == keyMaterial->size()); @@ -212,15 +216,18 @@ ConstDataRange binDataToCDR(const BSONElement element) { return ConstDataRange(data, data + len); } -ConstDataRange binDataToCDR(const Value& value) { - uassert(6334103, "Expected binData Value type", value.getType() == BinData); - - auto binData = value.getBinData(); +ConstDataRange binDataToCDR(const BSONBinData binData) { int len = binData.length; const char* data = static_cast<const char*>(binData.data); return ConstDataRange(data, data + len); } +ConstDataRange binDataToCDR(const Value& value) { + uassert(6334103, "Expected binData Value type", value.getType() == BinData); + + return binDataToCDR(value.getBinData()); +} + template <typename T> void toBinData(StringData field, T t, BSONObjBuilder* builder) { BSONObj obj = t.toBSON(); @@ -292,7 +299,7 @@ void toEncryptedBinData(StringData field, std::pair<EncryptedBinDataType, ConstDataRange> fromEncryptedBinData(BSONElement element) { uassert( - 6373502, "Expected binData with subtype Encrypt", element.isBinData(BinDataType::Encrypt)); + 6672414, "Expected binData with subtype Encrypt", element.isBinData(BinDataType::Encrypt)); return fromEncryptedConstDataRange(binDataToCDR(element)); } @@ -1163,6 +1170,28 @@ uint64_t generateRandomContention(uint64_t cm) { } // namespace +std::pair<EncryptedBinDataType, ConstDataRange> fromEncryptedBinData(const Value& value) { + uassert(6672416, "Expected binData with subtype Encrypt", value.getType() == BinData); + + auto binData = value.getBinData(); + + uassert(6672415, "Expected binData with subtype Encrypt", binData.type == BinDataType::Encrypt); + + return fromEncryptedConstDataRange(binDataToCDR(binData)); +} + +BSONBinData toBSONBinData(const std::vector<uint8_t>& buf) { + return BSONBinData(buf.data(), buf.size(), Encrypt); +} + +std::vector<uint8_t> toEncryptedVector(EncryptedBinDataType dt, const PrfBlock& block) { + std::vector<uint8_t> buf(block.size() + 1); + buf[0] = static_cast<uint8_t>(dt); + + std::copy(block.data(), block.data() + block.size(), buf.data() + 1); + + return buf; +} CollectionsLevel1Token FLELevel1TokenGenerator::generateCollectionsLevel1Token( FLEIndexKey indexKey) { @@ -1364,6 +1393,8 @@ std::pair<BSONType, std::vector<uint8_t>> FLEClientCrypto::decrypt(ConstDataRang return {EOO, vectorFromCDR(pair.second)}; } else if (pair.first == EncryptedBinDataType::kFLE2InsertUpdatePayload) { return {EOO, vectorFromCDR(pair.second)}; + } else if (pair.first == EncryptedBinDataType::kFLE2TransientRaw) { + return {EOO, vectorFromCDR(pair.second)}; } else { uasserted(6373507, "Not supported"); } @@ -1720,6 +1751,8 @@ FLE2FindEqualityPayload FLEClientCrypto::serializeFindPayload(FLEIndexKeyAndId i auto value = ConstDataRange(element.value(), element.value() + element.valuesize()); auto collectionToken = FLELevel1TokenGenerator::generateCollectionsLevel1Token(indexKey.key); + auto serverToken = + FLELevel1TokenGenerator::generateServerDataEncryptionLevel1Token(indexKey.key); auto edcToken = FLECollectionTokenGenerator::generateEDCToken(collectionToken); auto escToken = FLECollectionTokenGenerator::generateESCToken(collectionToken); @@ -1738,6 +1771,7 @@ FLE2FindEqualityPayload FLEClientCrypto::serializeFindPayload(FLEIndexKeyAndId i payload.setEscDerivedToken(escDatakey.toCDR()); payload.setEccDerivedToken(eccDatakey.toCDR()); payload.setMaxCounter(maxContentionFactor); + payload.setServerEncryptionToken(serverToken.toCDR()); return payload; } @@ -2076,6 +2110,44 @@ PrfBlock EDCServerCollection::generateTag(const FLE2IndexedEqualityEncryptedValu return generateTag(edcTwiceDerived, indexedValue.count); } + +StatusWith<FLE2IndexedEqualityEncryptedValue> EDCServerCollection::decryptAndParse( + ServerDataEncryptionLevel1Token token, ConstDataRange serializedServerValue) { + auto pair = fromEncryptedConstDataRange(serializedServerValue); + uassert(6672412, + "Wrong encrypted field type", + pair.first == EncryptedBinDataType::kFLE2EqualityIndexedValue); + + return FLE2IndexedEqualityEncryptedValue::decryptAndParse(token, pair.second); +} + +StatusWith<FLE2IndexedEqualityEncryptedValue> EDCServerCollection::decryptAndParse( + ConstDataRange token, ConstDataRange serializedServerValue) { + auto serverToken = FLETokenFromCDR<FLETokenType::ServerDataEncryptionLevel1Token>(token); + + return FLE2IndexedEqualityEncryptedValue::decryptAndParse(serverToken, serializedServerValue); +} + +std::vector<EDCDerivedFromDataTokenAndContentionFactorToken> EDCServerCollection::generateEDCTokens( + EDCDerivedFromDataToken token, uint64_t maxContentionFactor) { + std::vector<EDCDerivedFromDataTokenAndContentionFactorToken> tokens; + tokens.reserve(maxContentionFactor); + + for (uint64_t i = 0; i <= maxContentionFactor; ++i) { + tokens.push_back(FLEDerivedFromDataTokenAndContentionFactorTokenGenerator:: + generateEDCDerivedFromDataTokenAndContentionFactorToken(token, i)); + } + + return tokens; +} + +std::vector<EDCDerivedFromDataTokenAndContentionFactorToken> EDCServerCollection::generateEDCTokens( + ConstDataRange rawToken, uint64_t maxContentionFactor) { + auto token = FLETokenFromCDR<FLETokenType::EDCDerivedFromDataToken>(rawToken); + + return generateEDCTokens(token, maxContentionFactor); +} + BSONObj EDCServerCollection::finalizeForInsert( const BSONObj& doc, const std::vector<EDCServerPayloadInfo>& serverPayload) { std::vector<TagInfo> tags; @@ -2305,6 +2377,7 @@ EncryptedFieldConfig EncryptionInformationHelpers::getAndValidateSchema( return efc; } + std::pair<EncryptedBinDataType, ConstDataRange> fromEncryptedConstDataRange(ConstDataRange cdr) { ConstDataRangeCursor cdrc(cdr); @@ -2377,6 +2450,12 @@ ParsedFindPayload::ParsedFindPayload(ConstDataRange cdr) { escToken = FLETokenFromCDR<FLETokenType::ESCDerivedFromDataToken>(payload.getEscDerivedToken()); eccToken = FLETokenFromCDR<FLETokenType::ECCDerivedFromDataToken>(payload.getEccDerivedToken()); edcToken = FLETokenFromCDR<FLETokenType::EDCDerivedFromDataToken>(payload.getEdcDerivedToken()); + + if (payload.getServerEncryptionToken().has_value()) { + serverToken = FLETokenFromCDR<FLETokenType::ServerDataEncryptionLevel1Token>( + payload.getServerEncryptionToken().value()); + } + maxCounter = payload.getMaxCounter(); } diff --git a/src/mongo/crypto/fle_crypto.h b/src/mongo/crypto/fle_crypto.h index 5d8285f5790..2520b7a02b3 100644 --- a/src/mongo/crypto/fle_crypto.h +++ b/src/mongo/crypto/fle_crypto.h @@ -41,6 +41,7 @@ #include "mongo/base/status_with.h" #include "mongo/base/string_data.h" #include "mongo/bson/bsonelement.h" +#include "mongo/bson/bsonmisc.h" #include "mongo/bson/bsonobj.h" #include "mongo/bson/bsontypes.h" #include "mongo/crypto/aead_encryption.h" @@ -1016,6 +1017,12 @@ public: */ static std::vector<EDCServerPayloadInfo> getEncryptedFieldInfo(BSONObj& obj); + static StatusWith<FLE2IndexedEqualityEncryptedValue> decryptAndParse( + ServerDataEncryptionLevel1Token token, ConstDataRange serializedServerValue); + + static StatusWith<FLE2IndexedEqualityEncryptedValue> decryptAndParse( + ConstDataRange token, ConstDataRange serializedServerValue); + /** * Generate a search tag * @@ -1026,6 +1033,14 @@ public: static PrfBlock generateTag(const FLE2IndexedEqualityEncryptedValue& indexedValue); /** + * Generate all the EDC tokens + */ + static std::vector<EDCDerivedFromDataTokenAndContentionFactorToken> generateEDCTokens( + EDCDerivedFromDataToken token, uint64_t maxContentionFactor); + static std::vector<EDCDerivedFromDataTokenAndContentionFactorToken> generateEDCTokens( + ConstDataRange rawToken, uint64_t maxContentionFactor); + + /** * Consumes a payload from a MongoDB client for insert. * * Converts FLE2InsertUpdatePayload to a final insert payload and updates __safeContent__ with @@ -1163,6 +1178,7 @@ struct ParsedFindPayload { ESCDerivedFromDataToken escToken; ECCDerivedFromDataToken eccToken; EDCDerivedFromDataToken edcToken; + boost::optional<ServerDataEncryptionLevel1Token> serverToken; boost::optional<std::int64_t> maxCounter; explicit ParsedFindPayload(BSONElement fleFindPayload); @@ -1170,4 +1186,15 @@ struct ParsedFindPayload { explicit ParsedFindPayload(ConstDataRange cdr); }; +/** + * Utility functions manipulating buffers + */ +PrfBlock PrfBlockfromCDR(ConstDataRange block); + +std::vector<uint8_t> toEncryptedVector(EncryptedBinDataType dt, const PrfBlock& block); + +BSONBinData toBSONBinData(const std::vector<uint8_t>& buf); + +std::pair<EncryptedBinDataType, ConstDataRange> fromEncryptedBinData(const Value& value); + } // namespace mongo diff --git a/src/mongo/crypto/fle_crypto_test.cpp b/src/mongo/crypto/fle_crypto_test.cpp index 75c1976c097..99b470497cb 100644 --- a/src/mongo/crypto/fle_crypto_test.cpp +++ b/src/mongo/crypto/fle_crypto_test.cpp @@ -33,6 +33,7 @@ #include "mongo/crypto/fle_crypto.h" #include <algorithm> +#include <cstdint> #include <iostream> #include <limits> #include <stack> @@ -696,7 +697,8 @@ std::vector<char> generatePlaceholder( BSONElement value, Operation operation, mongo::Fle2AlgorithmInt algorithm = mongo::Fle2AlgorithmInt::kEquality, - boost::optional<UUID> key = boost::none) { + boost::optional<UUID> key = boost::none, + uint64_t contention = 0) { FLE2EncryptionPlaceholder ep; if (operation == Operation::kFind) { @@ -709,7 +711,7 @@ std::vector<char> generatePlaceholder( ep.setUserKeyId(userKeyId); ep.setIndexKeyId(key.value_or(indexKeyId)); ep.setValue(value); - ep.setMaxContentionCounter(0); + ep.setMaxContentionCounter(contention); BSONObj obj = ep.toBSON(); @@ -832,6 +834,41 @@ void roundTripMultiencrypted(BSONObj doc1, assertPayload(finalDoc["encrypted2"], operation2); } +// Used to generate the test data for the ExpressionFLETest in expression_test.cpp +TEST(FLE_EDC, PrintTest) { + auto doc = BSON("value" << 1); + auto element = doc.firstElement(); + + TestKeyVault keyVault; + + auto inputDoc = BSON("plainText" + << "sample" + << "encrypted" << element); + + { + auto buf = generatePlaceholder(element, Operation::kInsert, Fle2AlgorithmInt::kEquality); + BSONObjBuilder builder; + builder.append("plainText", "sample"); + builder.appendBinData("encrypted", buf.size(), BinDataType::Encrypt, buf.data()); + + auto finalDoc = encryptDocument(builder.obj(), &keyVault); + + std::cout << finalDoc.jsonString() << std::endl; + } + + { + auto buf = generatePlaceholder( + element, Operation::kInsert, Fle2AlgorithmInt::kEquality, boost::none, 50); + BSONObjBuilder builder; + builder.append("plainText", "sample"); + builder.appendBinData("encrypted", buf.size(), BinDataType::Encrypt, buf.data()); + + auto finalDoc = encryptDocument(builder.obj(), &keyVault); + + std::cout << finalDoc.jsonString() << std::endl; + } +} + TEST(FLE_EDC, Allowed_Types) { const std::vector<std::pair<BSONObj, BSONType>> universallyAllowedObjects{ {BSON("sample" @@ -1928,4 +1965,25 @@ TEST(CompactionHelpersTest, countDeletedTest) { ASSERT_EQ(CompactionHelpers::countDeleted(input), 20); } +TEST(EDCServerCollectionTest, GenerateEDCTokens) { + + auto doc = BSON("sample" << 123456); + auto element = doc.firstElement(); + + auto value = ConstDataRange(element.value(), element.value() + element.valuesize()); + + auto collectionToken = FLELevel1TokenGenerator::generateCollectionsLevel1Token(getIndexKey()); + auto edcToken = FLECollectionTokenGenerator::generateEDCToken(collectionToken); + + EDCDerivedFromDataToken edcDatakey = + FLEDerivedFromDataTokenGenerator::generateEDCDerivedFromDataToken(edcToken, value); + + + ASSERT_EQ(EDCServerCollection::generateEDCTokens(edcDatakey, 0).size(), 1); + ASSERT_EQ(EDCServerCollection::generateEDCTokens(edcDatakey, 1).size(), 2); + ASSERT_EQ(EDCServerCollection::generateEDCTokens(edcDatakey, 2).size(), 3); + ASSERT_EQ(EDCServerCollection::generateEDCTokens(edcDatakey, 3).size(), 4); +} + + } // namespace mongo diff --git a/src/mongo/crypto/fle_field_schema.idl b/src/mongo/crypto/fle_field_schema.idl index 030fa36ef3f..d9e2b54b890 100644 --- a/src/mongo/crypto/fle_field_schema.idl +++ b/src/mongo/crypto/fle_field_schema.idl @@ -51,6 +51,10 @@ enums: kFLE2UnindexedEncryptedValue : 6 # see FLE2IndexedEqualityEncryptedValue kFLE2EqualityIndexedValue : 7 + # Transient encrypted data in query rewrites, not persisted + # same as BinDataGeneral but redacted + kFLE2TransientRaw : 8 + FleVersion: description: "The version / type of field-level encryption in use." type: int diff --git a/src/mongo/crypto/fle_tags.cpp b/src/mongo/crypto/fle_tags.cpp index a0de37b2f42..4737ff13144 100644 --- a/src/mongo/crypto/fle_tags.cpp +++ b/src/mongo/crypto/fle_tags.cpp @@ -56,8 +56,11 @@ void verifyTagsWillFit(size_t tagCount, size_t memoryLimit) { constexpr size_t largestElementSize = arrayElementSize(std::numeric_limits<size_t>::digits10); constexpr size_t ridiculousNumberOfTags = std::numeric_limits<size_t>::max() / largestElementSize; - uassert(6653300, "Encrypted rewrite too many tags", tagCount < ridiculousNumberOfTags); - uassert(6401800, + + uassert(ErrorCodes::FLEMaxTagLimitExceeded, + "Encrypted rewrite too many tags", + tagCount < ridiculousNumberOfTags); + uassert(ErrorCodes::FLEMaxTagLimitExceeded, "Encrypted rewrite memory limit exceeded", sizeArrayElementsMemory(tagCount) <= memoryLimit); } diff --git a/src/mongo/db/dbmessage.h b/src/mongo/db/dbmessage.h index 1f5472b2272..0b8e8ce84c7 100644 --- a/src/mongo/db/dbmessage.h +++ b/src/mongo/db/dbmessage.h @@ -227,7 +227,7 @@ public: * Indicates whether this message is expected to have a ns. */ bool messageShouldHaveNs() const { - return (_msg.operation() >= dbUpdate) & (_msg.operation() <= dbDelete); + return static_cast<int>(_msg.operation() >= dbUpdate) & (_msg.operation() <= dbDelete); } /** diff --git a/src/mongo/db/fle_crud.cpp b/src/mongo/db/fle_crud.cpp index 069cd34b1ac..41484d1a3af 100644 --- a/src/mongo/db/fle_crud.cpp +++ b/src/mongo/db/fle_crud.cpp @@ -801,8 +801,17 @@ write_ops::UpdateCommandReply processUpdate(FLEQueryInterface* queryImpl, // Step 1 ---- std::vector<EDCServerPayloadInfo> serverPayload; auto newUpdateOpEntry = updateRequest.getUpdates()[0]; - newUpdateOpEntry.setQ(fle::rewriteEncryptedFilterInsideTxn( - queryImpl, updateRequest.getDbName(), efc, expCtx, newUpdateOpEntry.getQ())); + + auto highCardinalityModeAllowed = newUpdateOpEntry.getUpsert() + ? fle::HighCardinalityModeAllowed::kDisallow + : fle::HighCardinalityModeAllowed::kAllow; + + newUpdateOpEntry.setQ(fle::rewriteEncryptedFilterInsideTxn(queryImpl, + updateRequest.getDbName(), + efc, + expCtx, + newUpdateOpEntry.getQ(), + highCardinalityModeAllowed)); if (updateModification.type() == write_ops::UpdateModification::Type::kModifier) { auto updateModifier = updateModification.getUpdateModifier(); @@ -972,19 +981,25 @@ std::unique_ptr<BatchedCommandRequest> processFLEBatchExplain( request.getNS(), deleteRequest.getEncryptionInformation().get(), newDeleteOp.getQ(), - &getTransactionWithRetriesForMongoS)); + &getTransactionWithRetriesForMongoS, + fle::HighCardinalityModeAllowed::kAllow)); deleteRequest.setDeletes({newDeleteOp}); deleteRequest.getWriteCommandRequestBase().setEncryptionInformation(boost::none); return std::make_unique<BatchedCommandRequest>(deleteRequest); } else if (request.getBatchType() == BatchedCommandRequest::BatchType_Update) { auto updateRequest = request.getUpdateRequest(); auto newUpdateOp = updateRequest.getUpdates()[0]; + auto highCardinalityModeAllowed = newUpdateOp.getUpsert() + ? fle::HighCardinalityModeAllowed::kDisallow + : fle::HighCardinalityModeAllowed::kAllow; + newUpdateOp.setQ(fle::rewriteQuery(opCtx, getExpCtx(newUpdateOp), request.getNS(), updateRequest.getEncryptionInformation().get(), newUpdateOp.getQ(), - &getTransactionWithRetriesForMongoS)); + &getTransactionWithRetriesForMongoS, + highCardinalityModeAllowed)); updateRequest.setUpdates({newUpdateOp}); updateRequest.getWriteCommandRequestBase().setEncryptionInformation(boost::none); return std::make_unique<BatchedCommandRequest>(updateRequest); @@ -1009,8 +1024,17 @@ write_ops::FindAndModifyCommandReply processFindAndModify( // Step 0 ---- // Rewrite filter - newFindAndModifyRequest.setQuery(fle::rewriteEncryptedFilterInsideTxn( - queryImpl, edcNss.db(), efc, expCtx, findAndModifyRequest.getQuery())); + auto highCardinalityModeAllowed = findAndModifyRequest.getUpsert().value_or(false) + ? fle::HighCardinalityModeAllowed::kDisallow + : fle::HighCardinalityModeAllowed::kAllow; + + newFindAndModifyRequest.setQuery( + fle::rewriteEncryptedFilterInsideTxn(queryImpl, + edcNss.db(), + efc, + expCtx, + findAndModifyRequest.getQuery(), + highCardinalityModeAllowed)); // Make sure not to inherit the command's writeConcern, this should be set at the transaction // level. @@ -1133,8 +1157,17 @@ write_ops::FindAndModifyCommandRequest processFindAndModifyExplain( auto efc = EncryptionInformationHelpers::getAndValidateSchema(edcNss, ei); auto newFindAndModifyRequest = findAndModifyRequest; - newFindAndModifyRequest.setQuery(fle::rewriteEncryptedFilterInsideTxn( - queryImpl, edcNss.db(), efc, expCtx, findAndModifyRequest.getQuery())); + auto highCardinalityModeAllowed = findAndModifyRequest.getUpsert().value_or(false) + ? fle::HighCardinalityModeAllowed::kDisallow + : fle::HighCardinalityModeAllowed::kAllow; + + newFindAndModifyRequest.setQuery( + fle::rewriteEncryptedFilterInsideTxn(queryImpl, + edcNss.db(), + efc, + expCtx, + findAndModifyRequest.getQuery(), + highCardinalityModeAllowed)); newFindAndModifyRequest.setEncryptionInformation(boost::none); return newFindAndModifyRequest; diff --git a/src/mongo/db/fle_crud_mongod.cpp b/src/mongo/db/fle_crud_mongod.cpp index 68327133c88..1e488f1f65a 100644 --- a/src/mongo/db/fle_crud_mongod.cpp +++ b/src/mongo/db/fle_crud_mongod.cpp @@ -284,7 +284,13 @@ BSONObj processFLEWriteExplainD(OperationContext* opCtx, const BSONObj& query) { auto expCtx = make_intrusive<ExpressionContext>( opCtx, fle::collatorFromBSON(opCtx, collation), nss, runtimeConstants, letParameters); - return fle::rewriteQuery(opCtx, expCtx, nss, info, query, &getTransactionWithRetriesForMongoD); + return fle::rewriteQuery(opCtx, + expCtx, + nss, + info, + query, + &getTransactionWithRetriesForMongoD, + fle::HighCardinalityModeAllowed::kAllow); } std::pair<write_ops::FindAndModifyCommandRequest, OpMsgRequest> diff --git a/src/mongo/db/fle_crud_test.cpp b/src/mongo/db/fle_crud_test.cpp index 527dd5bca11..d97f4621987 100644 --- a/src/mongo/db/fle_crud_test.cpp +++ b/src/mongo/db/fle_crud_test.cpp @@ -27,6 +27,7 @@ * it in the license file. */ +#include "mongo/base/error_codes.h" #include "mongo/platform/basic.h" #include <algorithm> @@ -1199,7 +1200,7 @@ TEST_F(FleTagsTest, MemoryLimit) { doSingleInsert(10, doc); // readTags returns 11 tags which does exceed memory limit. - ASSERT_THROWS_CODE(readTags(doc), DBException, 6401800); + ASSERT_THROWS_CODE(readTags(doc), DBException, ErrorCodes::FLEMaxTagLimitExceeded); doSingleDelete(5); diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript index ff2c639db8b..96c7d59a025 100644 --- a/src/mongo/db/pipeline/SConscript +++ b/src/mongo/db/pipeline/SConscript @@ -99,6 +99,7 @@ env.Library( 'expression_context.cpp', 'expression_function.cpp', 'expression_js_emit.cpp', + 'expression_parser.idl', 'expression_test_api_version.cpp', 'expression_trigonometric.cpp', 'javascript_execution.cpp', @@ -106,6 +107,7 @@ env.Library( 'variables.cpp', ], LIBDEPS=[ + '$BUILD_DIR/mongo/crypto/fle_crypto', '$BUILD_DIR/mongo/db/bson/dotted_path_support', '$BUILD_DIR/mongo/db/commands/test_commands_enabled', '$BUILD_DIR/mongo/db/exec/document_value/document_value', @@ -128,6 +130,7 @@ env.Library( LIBDEPS_PRIVATE=[ '$BUILD_DIR/mongo/db/mongohasher', '$BUILD_DIR/mongo/db/vector_clock', + '$BUILD_DIR/mongo/idl/idl_parser', ], ) diff --git a/src/mongo/db/pipeline/abt/agg_expression_visitor.cpp b/src/mongo/db/pipeline/abt/agg_expression_visitor.cpp index 20552e85599..05b9fff8932 100644 --- a/src/mongo/db/pipeline/abt/agg_expression_visitor.cpp +++ b/src/mongo/db/pipeline/abt/agg_expression_visitor.cpp @@ -322,6 +322,10 @@ public: unsupportedExpression(expr->getOpName()); } + void visit(const ExpressionInternalFLEEqual* expr) override final { + unsupportedExpression(expr->getOpName()); + } + void visit(const ExpressionMap* expr) override final { unsupportedExpression("$map"); } diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index 12e1ddf2ec3..464d2ad6953 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -39,6 +39,9 @@ #include <utility> #include <vector> +#include "mongo/bson/bsonmisc.h" +#include "mongo/bson/bsontypes.h" +#include "mongo/crypto/fle_crypto.h" #include "mongo/db/bson/dotted_path_support.h" #include "mongo/db/commands/feature_compatibility_version_documentation.h" #include "mongo/db/exec/document_value/document.h" @@ -46,6 +49,7 @@ #include "mongo/db/hasher.h" #include "mongo/db/jsobj.h" #include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_parser_gen.h" #include "mongo/db/pipeline/variable_validation.h" #include "mongo/db/query/datetime/date_time_support.h" #include "mongo/db/query/sort_pattern.h" @@ -3804,6 +3808,123 @@ const char* ExpressionLog10::getOpName() const { return "$log10"; } +/* ----------------------- ExpressionInternalFLEEqual ---------------------------- */ +constexpr auto kInternalFleEq = "$_internalFleEq"_sd; + +ExpressionInternalFLEEqual::ExpressionInternalFLEEqual(ExpressionContext* const expCtx, + boost::intrusive_ptr<Expression> field, + ConstDataRange serverToken, + int64_t contentionFactor, + ConstDataRange edcToken) + : Expression(expCtx, {std::move(field)}), + _serverToken(PrfBlockfromCDR(serverToken)), + _edcToken(PrfBlockfromCDR(edcToken)), + _contentionFactor(contentionFactor) { + expCtx->sbeCompatible = false; + + auto tokens = + EDCServerCollection::generateEDCTokens(ConstDataRange(_edcToken), _contentionFactor); + + for (auto& token : tokens) { + _cachedEDCTokens.insert(std::move(token.data)); + } +} + +void ExpressionInternalFLEEqual::_doAddDependencies(DepsTracker* deps) const { + for (auto&& operand : _children) { + operand->addDependencies(deps); + } +} + +REGISTER_EXPRESSION_WITH_MIN_VERSION(_internalFleEq, + ExpressionInternalFLEEqual::parse, + AllowedWithApiStrict::kAlways, + AllowedWithClientType::kAny, + multiversion::FeatureCompatibilityVersion::kVersion_6_0); + +intrusive_ptr<Expression> ExpressionInternalFLEEqual::parse(ExpressionContext* const expCtx, + BSONElement expr, + const VariablesParseState& vps) { + + IDLParserErrorContext ctx(kInternalFleEq); + auto fleEq = InternalFleEqStruct::parse(ctx, expr.Obj()); + + auto fieldExpr = Expression::parseOperand(expCtx, fleEq.getField().getElement(), vps); + + auto serverTokenPair = fromEncryptedConstDataRange(fleEq.getServerEncryptionToken()); + + uassert(6672405, + "Invalid server token", + serverTokenPair.first == EncryptedBinDataType::kFLE2TransientRaw && + serverTokenPair.second.length() == sizeof(PrfBlock)); + + auto edcTokenPair = fromEncryptedConstDataRange(fleEq.getEdcDerivedToken()); + + uassert(6672406, + "Invalid edc token", + edcTokenPair.first == EncryptedBinDataType::kFLE2TransientRaw && + edcTokenPair.second.length() == sizeof(PrfBlock)); + + + auto cf = fleEq.getMaxCounter(); + uassert(6672408, "Contention factor must be between 0 and 10000", cf >= 0 && cf < 10000); + + return new ExpressionInternalFLEEqual(expCtx, + std::move(fieldExpr), + serverTokenPair.second, + fleEq.getMaxCounter(), + edcTokenPair.second); +} + +Value toValue(const std::array<std::uint8_t, 32>& buf) { + auto vec = toEncryptedVector(EncryptedBinDataType::kFLE2TransientRaw, buf); + return Value(BSONBinData(vec.data(), vec.size(), BinDataType::Encrypt)); +} + +Value ExpressionInternalFLEEqual::serialize(bool explain) const { + return Value(Document{{kInternalFleEq, + Document{{"field", _children[0]->serialize(explain)}, + {"edc", toValue(_edcToken)}, + {"counter", Value(static_cast<long long>(_contentionFactor))}, + {"server", toValue(_serverToken)}}}}); +} + +Value ExpressionInternalFLEEqual::evaluate(const Document& root, Variables* variables) const { + // Inputs + // 1. Value for FLE2IndexedEqualityEncryptedValue field + + Value fieldValue = _children[0]->evaluate(root, variables); + + if (fieldValue.nullish()) { + return Value(BSONNULL); + } + + if (fieldValue.getType() != BinData) { + return Value(false); + } + + auto fieldValuePair = fromEncryptedBinData(fieldValue); + + uassert(6672407, + "Invalid encrypted indexed field", + fieldValuePair.first == EncryptedBinDataType::kFLE2EqualityIndexedValue); + + // Value matches if + // 1. Decrypt field is successful + // 2. EDC_u Token is in GenTokens(EDC Token, ContentionFactor) + // + auto swIndexed = + EDCServerCollection::decryptAndParse(ConstDataRange(_serverToken), fieldValuePair.second); + uassertStatusOK(swIndexed); + auto indexed = swIndexed.getValue(); + + return Value(_cachedEDCTokens.count(indexed.edc.data) == 1); +} + +const char* ExpressionInternalFLEEqual::getOpName() const { + return kInternalFleEq.rawData(); +} + /* ------------------------ ExpressionNary ----------------------------- */ /** diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h index ff53eaedf3e..4b5745bb2b6 100644 --- a/src/mongo/db/pipeline/expression.h +++ b/src/mongo/db/pipeline/expression.h @@ -29,6 +29,7 @@ #pragma once +#include "mongo/base/data_range.h" #include "mongo/platform/basic.h" #include <algorithm> @@ -2197,6 +2198,38 @@ public: } }; +class ExpressionInternalFLEEqual final : public Expression { +public: + ExpressionInternalFLEEqual(ExpressionContext* expCtx, + boost::intrusive_ptr<Expression> field, + ConstDataRange serverToken, + int64_t contentionFactor, + ConstDataRange edcToken); + Value serialize(bool explain) const final; + + Value evaluate(const Document& root, Variables* variables) const final; + const char* getOpName() const; + + static boost::intrusive_ptr<Expression> parse(ExpressionContext* expCtx, + BSONElement expr, + const VariablesParseState& vps); + void _doAddDependencies(DepsTracker* deps) const final; + + void acceptVisitor(ExpressionMutableVisitor* visitor) final { + return visitor->visit(this); + } + + void acceptVisitor(ExpressionConstVisitor* visitor) const final { + return visitor->visit(this); + } + +private: + std::array<std::uint8_t, 32> _serverToken; + std::array<std::uint8_t, 32> _edcToken; + int64_t _contentionFactor; + stdx::unordered_set<std::array<std::uint8_t, 32>> _cachedEDCTokens; +}; + class ExpressionMap final : public Expression { public: ExpressionMap( diff --git a/src/mongo/db/pipeline/expression_parser.idl b/src/mongo/db/pipeline/expression_parser.idl new file mode 100644 index 00000000000..9f1cde70856 --- /dev/null +++ b/src/mongo/db/pipeline/expression_parser.idl @@ -0,0 +1,57 @@ +# Copyright (C) 2022-present MongoDB, Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the Server Side Public License, version 1, +# as published by MongoDB, Inc. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# Server Side Public License for more details. +# +# You should have received a copy of the Server Side Public License +# along with this program. If not, see +# <http://www.mongodb.com/licensing/server-side-public-license>. +# +# As a special exception, the copyright holders give permission to link the +# code of portions of this program with the OpenSSL library under certain +# conditions as described in each individual source file and distribute +# linked combinations including the program with the OpenSSL library. You +# must comply with the Server Side Public License in all respects for +# all of the code used other than as permitted herein. If you modify file(s) +# with this exception, you may extend this exception to your version of the +# file(s), but you are not obligated to do so. If you do not wish to do so, +# delete this exception statement from your version. If you delete this +# exception statement from all source files in the program, then also delete +# it in the license file. + +global: + cpp_namespace: "mongo" + +imports: + - "mongo/idl/basic_types.idl" + +structs: + + InternalFleEqStruct: + description: "Struct for $_internalFleEq" + strict: true + fields: + field: + description: "Expression" + type: IDLAnyType + cpp_name: field + edc: + description: "EDCDerivedFromDataToken" + type: bindata_encrypt + cpp_name: edcDerivedToken + server: + description: "ServerDataEncryptionLevel1Token" + type: bindata_encrypt + cpp_name: serverEncryptionToken + counter: + description: "Queryable Encryption max counter" + type: long + cpp_name: maxCounter + + diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp index a33de77322c..314062c3f03 100644 --- a/src/mongo/db/pipeline/expression_test.cpp +++ b/src/mongo/db/pipeline/expression_test.cpp @@ -175,6 +175,7 @@ void parseAndVerifyResults( ASSERT_VALUE_EQ(expr->evaluate({}, &expCtx.variables), expected); } + /* ------------------------- ExpressionArrayToObject -------------------------- */ TEST(ExpressionArrayToObjectTest, KVFormatSimple) { @@ -3930,4 +3931,240 @@ TEST(ExpressionAddTest, VerifyNoDoubleDoubleSummation) { ASSERT_VALUE_EQ(result, Value(straightSum)); ASSERT_VALUE_NE(result, Value(compensatedSum.getDouble())); } +TEST(ExpressionFLETest, BadInputs) { + + auto expCtx = ExpressionContextForTest(); + auto vps = expCtx.variablesParseState; + { + auto expr = fromjson("{$_internalFleEq: 12}"); + ASSERT_THROWS_CODE(ExpressionInternalFLEEqual::parse(&expCtx, expr.firstElement(), vps), + DBException, + 10065); + } +} + +// Test we return true if it matches +TEST(ExpressionFLETest, TestBinData) { + auto expCtx = ExpressionContextForTest(); + auto vps = expCtx.variablesParseState; + + { + auto expr = fromjson(R"({$_internalFleEq: { + field: { + "$binary": { + "base64": + "BxI0VngSNJh2EjQSNFZ4kBIQ0JE8aMUFkPk5sSTVqfdNNfjqUfQQ1Uoj0BBcthrWoe9wyU3cN6zmWaQBPJ97t0ZPbecnMsU736yXre6cBO4Zdt/wThtY+v5+7vFgNnWpgRP0e+vam6QPmLvbBrO0LdsvAPTGW4yqwnzCIXCoEg7QPGfbfAXKPDTNenBfRlawiblmTOhO/6ljKotWsMp22q/rpHrn9IEIeJmecwuuPIJ7EA+XYQ3hOKVccYf2ogoK73+8xD/Vul83Qvr84Q8afc4QUMVs8A==", + "subType": "6" + } + }, + server: { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + }, + counter: { + "$numberLong": "3" + }, + edc: { + "$binary": { + "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", + "subType": "6" + } + } } })"); + auto exprFle = ExpressionInternalFLEEqual::parse(&expCtx, expr.firstElement(), vps); + + ASSERT_VALUE_EQ(exprFle->evaluate({}, &expCtx.variables), Value(true)); + } + + // Negative: Use wrong server token + { + auto expr = fromjson(R"({$_internalFleEq: { + field: { + "$binary": { + "base64": + "BxI0VngSNJh2EjQSNFZ4kBIQ0JE8aMUFkPk5sSTVqfdNNfjqUfQQ1Uoj0BBcthrWoe9wyU3cN6zmWaQBPJ97t0ZPbecnMsU736yXre6cBO4Zdt/wThtY+v5+7vFgNnWpgRP0e+vam6QPmLvbBrO0LdsvAPTGW4yqwnzCIXCoEg7QPGfbfAXKPDTNenBfRlawiblmTOhO/6ljKotWsMp22q/rpHrn9IEIeJmecwuuPIJ7EA+XYQ3hOKVccYf2ogoK73+8xD/Vul83Qvr84Q8afc4QUMVs8A==", + "subType": "6" + } + }, + server: { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + }, + counter: { + "$numberLong": "3" + }, + edc: { + "$binary": { + "base64": "CEWSMQID7SFWYAUI3ZKSFKATKRYDQFNXXEOGAD5D4RSG", + "subType": "6" + } + } } })"); + auto exprFle = ExpressionInternalFLEEqual::parse(&expCtx, expr.firstElement(), vps); + + ASSERT_VALUE_EQ(exprFle->evaluate({}, &expCtx.variables), Value(false)); + } + + // Negative: Use wrong edc token + { + auto expr = fromjson(R"({$_internalFleEq: { + field: { + "$binary": { + "base64": + "BxI0VngSNJh2EjQSNFZ4kBIQ0JE8aMUFkPk5sSTVqfdNNfjqUfQQ1Uoj0BBcthrWoe9wyU3cN6zmWaQBPJ97t0ZPbecnMsU736yXre6cBO4Zdt/wThtY+v5+7vFgNnWpgRP0e+vam6QPmLvbBrO0LdsvAPTGW4yqwnzCIXCoEg7QPGfbfAXKPDTNenBfRlawiblmTOhO/6ljKotWsMp22q/rpHrn9IEIeJmecwuuPIJ7EA+XYQ3hOKVccYf2ogoK73+8xD/Vul83Qvr84Q8afc4QUMVs8A==", + "subType": "6" + } + }, + server: { + "$binary": { + "base64": "COUAC/ERLYAKKX6B0VZ1R3QODOQFFJQJD+XLGIPU4/PS", + "subType": "6" + } + }, + counter: { + "$numberLong": "3" + }, + edc: { + "$binary": { + "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", + "subType": "6" + } + } } })"); + auto exprFle = ExpressionInternalFLEEqual::parse(&expCtx, expr.firstElement(), vps); + + ASSERT_THROWS_CODE( + exprFle->evaluate({}, &expCtx.variables), DBException, ErrorCodes::Overflow); + } +} + +TEST(ExpressionFLETest, TestBinData_ContentionFactor) { + auto expCtx = ExpressionContextForTest(); + auto vps = expCtx.variablesParseState; + + // Use the wrong contention factor - 0 + { + auto expr = fromjson(R"({$_internalFleEq: { + field: { + "$binary": { + "base64": + "BxI0VngSNJh2EjQSNFZ4kBIQ5+Wa5+SZafJeRUDGdLNx+i2ADDkyV2qA90Xcve7FqltoDm1PllSSgUS4fYtw3XDjzoNZrFFg8LfG2wH0HYbLMswv681KJpmEw7+RXy4CcPVFgoRFt24N13p7jT+pqu2oQAHAoxYTy/TsiAyY4RnAMiXYGg3hWz4AO/WxHNSyq6B6kX5d7x/hrXvppsZDc2Pmhd+c5xmovlv5RPj7wnNld13kYcMluztjNswiCH05hM/kp2/P7kw30iVnbz0SZxn1FjjCug==", + "subType": "6" + } + }, + server: { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + }, + counter: { + "$numberLong": "0" + }, + edc: { + "$binary": { + "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", + "subType": "6" + } + } } })"); + auto exprFle = ExpressionInternalFLEEqual::parse(&expCtx, expr.firstElement(), vps); + + ASSERT_VALUE_EQ(exprFle->evaluate({}, &expCtx.variables), Value(false)); + } + + // Use the right contention factor - 50 + { + auto expr = fromjson(R"({$_internalFleEq: { + field: { + "$binary": { + "base64": +"BxI0VngSNJh2EjQSNFZ4kBIQ5+Wa5+SZafJeRUDGdLNx+i2ADDkyV2qA90Xcve7FqltoDm1PllSSgUS4fYtw3XDjzoNZrFFg8LfG2wH0HYbLMswv681KJpmEw7+RXy4CcPVFgoRFt24N13p7jT+pqu2oQAHAoxYTy/TsiAyY4RnAMiXYGg3hWz4AO/WxHNSyq6B6kX5d7x/hrXvppsZDc2Pmhd+c5xmovlv5RPj7wnNld13kYcMluztjNswiCH05hM/kp2/P7kw30iVnbz0SZxn1FjjCug==", + "subType": "6" + } + }, + server: { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + }, + counter: { + "$numberLong": "50" + }, + edc: { + "$binary": { + "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", + "subType": "6" + } + } } })"); + auto exprFle = ExpressionInternalFLEEqual::parse(&expCtx, expr.firstElement(), vps); + + ASSERT_VALUE_EQ(exprFle->evaluate({}, &expCtx.variables), Value(true)); + } +} + +TEST(ExpressionFLETest, TestBinData_RoundTrip) { + auto expCtx = ExpressionContextForTest(); + auto vps = expCtx.variablesParseState; + + auto expr = fromjson(R"({$_internalFleEq: { + field: { + "$binary": { + "base64": + "BxI0VngSNJh2EjQSNFZ4kBIQ0JE8aMUFkPk5sSTVqfdNNfjqUfQQ1Uoj0BBcthrWoe9wyU3cN6zmWaQBPJ97t0ZPbecnMsU736yXre6cBO4Zdt/wThtY+v5+7vFgNnWpgRP0e+vam6QPmLvbBrO0LdsvAPTGW4yqwnzCIXCoEg7QPGfbfAXKPDTNenBfRlawiblmTOhO/6ljKotWsMp22q/rpHrn9IEIeJmecwuuPIJ7EA+XYQ3hOKVccYf2ogoK73+8xD/Vul83Qvr84Q8afc4QUMVs8A==", + "subType": "6" + } + }, + server: { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + }, + counter: { + "$numberLong": "3" + }, + edc: { + "$binary": { + "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", + "subType": "6" + } + } } })"); + auto exprFle = ExpressionInternalFLEEqual::parse(&expCtx, expr.firstElement(), vps); + + ASSERT_VALUE_EQ(exprFle->evaluate({}, &expCtx.variables), Value(true)); + + // Verify it round trips + auto value = exprFle->serialize(false); + + auto roundTripExpr = fromjson(R"({$_internalFleEq: { + field: { + "$const" : { "$binary": { + "base64": + "BxI0VngSNJh2EjQSNFZ4kBIQ0JE8aMUFkPk5sSTVqfdNNfjqUfQQ1Uoj0BBcthrWoe9wyU3cN6zmWaQBPJ97t0ZPbecnMsU736yXre6cBO4Zdt/wThtY+v5+7vFgNnWpgRP0e+vam6QPmLvbBrO0LdsvAPTGW4yqwnzCIXCoEg7QPGfbfAXKPDTNenBfRlawiblmTOhO/6ljKotWsMp22q/rpHrn9IEIeJmecwuuPIJ7EA+XYQ3hOKVccYf2ogoK73+8xD/Vul83Qvr84Q8afc4QUMVs8A==", + "subType": "6" + }} + }, + edc: { + "$binary": { + "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", + "subType": "6" + } + }, + counter: { + "$numberLong": "3" + }, + server: { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + } + } })"); + + + ASSERT_BSONOBJ_EQ(value.getDocument().toBson(), roundTripExpr); +} + } // namespace ExpressionTests diff --git a/src/mongo/db/pipeline/expression_visitor.h b/src/mongo/db/pipeline/expression_visitor.h index 46ad3ee6295..6b7c4fc4cdd 100644 --- a/src/mongo/db/pipeline/expression_visitor.h +++ b/src/mongo/db/pipeline/expression_visitor.h @@ -153,6 +153,7 @@ class ExpressionHyperbolicSine; class ExpressionInternalFindSlice; class ExpressionInternalFindPositional; class ExpressionInternalFindElemMatch; +class ExpressionInternalFLEEqual; class ExpressionInternalJsEmit; class ExpressionFunction; class ExpressionDegreesToRadians; @@ -245,6 +246,7 @@ public: virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionLn>) = 0; virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionLog>) = 0; virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionLog10>) = 0; + virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionInternalFLEEqual>) = 0; virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionMap>) = 0; virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionMeta>) = 0; virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionMod>) = 0; @@ -424,6 +426,7 @@ struct SelectiveConstExpressionVisitorBase : public ExpressionConstVisitor { void visit(const ExpressionLn*) override {} void visit(const ExpressionLog*) override {} void visit(const ExpressionLog10*) override {} + void visit(const ExpressionInternalFLEEqual*) override {} void visit(const ExpressionMap*) override {} void visit(const ExpressionMeta*) override {} void visit(const ExpressionMod*) override {} diff --git a/src/mongo/db/query/fle/server_rewrite.cpp b/src/mongo/db/query/fle/server_rewrite.cpp index f4f02bcb383..2aeb99a4061 100644 --- a/src/mongo/db/query/fle/server_rewrite.cpp +++ b/src/mongo/db/query/fle/server_rewrite.cpp @@ -32,6 +32,7 @@ #include <memory> +#include "mongo/bson/bsonmisc.h" #include "mongo/bson/bsonobj.h" #include "mongo/bson/bsonobjbuilder.h" #include "mongo/bson/bsontypes.h" @@ -48,9 +49,14 @@ #include "mongo/db/pipeline/expression.h" #include "mongo/db/query/collation/collator_factory_interface.h" #include "mongo/db/service_context.h" +#include "mongo/logv2/log.h" #include "mongo/s/grid.h" #include "mongo/s/transaction_router_resource_yielder.h" #include "mongo/util/assert_util.h" +#include "mongo/util/intrusive_counter.h" + + +#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kQuery namespace mongo::fle { @@ -68,6 +74,56 @@ std::unique_ptr<CollatorInterface> collatorFromBSON(OperationContext* opCtx, } namespace { +template <typename PayloadT> +boost::intrusive_ptr<ExpressionInternalFLEEqual> generateFleEqualMatch(StringData path, + const PayloadT& ffp, + ExpressionContext* expCtx) { + // Generate { $_internalFleEq: { field: "$field_name", server: f_3, counter: cm, edc: k_EDC] } + auto tokens = ParsedFindPayload(ffp); + + uassert(6672401, + "Missing required field server encryption token in find payload", + tokens.serverToken.has_value()); + + return make_intrusive<ExpressionInternalFLEEqual>( + expCtx, + ExpressionFieldPath::createPathFromString( + expCtx, path.toString(), expCtx->variablesParseState), + tokens.serverToken.get().data, + tokens.maxCounter.value_or(0LL), + tokens.edcToken.data); +} + + +template <typename PayloadT> +std::unique_ptr<ExpressionInternalFLEEqual> generateFleEqualMatchUnique(StringData path, + const PayloadT& ffp, + ExpressionContext* expCtx) { + // Generate { $_internalFleEq: { field: "$field_name", server: f_3, counter: cm, edc: k_EDC] } + auto tokens = ParsedFindPayload(ffp); + + uassert(6672419, + "Missing required field server encryption token in find payload", + tokens.serverToken.has_value()); + + return std::make_unique<ExpressionInternalFLEEqual>( + expCtx, + ExpressionFieldPath::createPathFromString( + expCtx, path.toString(), expCtx->variablesParseState), + tokens.serverToken.get().data, + tokens.maxCounter.value_or(0LL), + tokens.edcToken.data); +} + +std::unique_ptr<MatchExpression> generateFleEqualMatchAndExpr(StringData path, + const BSONElement ffp, + ExpressionContext* expCtx) { + auto fleEqualMatch = generateFleEqualMatch(path, ffp, expCtx); + + return std::make_unique<ExprMatchExpression>(fleEqualMatch, expCtx); +} + + /** * This section defines a mapping from DocumentSources to the dispatch function to appropriately * handle FLE rewriting for that stage. This should be kept in line with code on the client-side @@ -128,7 +184,8 @@ public: * The final output will look like * {$or: [{$in: [tag0, "$__safeContent__"]}, {$in: [tag1, "$__safeContent__"]}, ...]}. */ - std::unique_ptr<Expression> rewriteComparisonsToEncryptedField( + std::unique_ptr<Expression> rewriteInToEncryptedField( + const Expression* leftExpr, const std::vector<boost::intrusive_ptr<Expression>>& equalitiesList) { size_t numFFPs = 0; std::vector<boost::intrusive_ptr<Expression>> orListElems; @@ -140,11 +197,122 @@ public: continue; } - // ... rewrite the payload to a list of tags... numFFPs++; + } + } + + // Finally, construct an $or of all of the $ins. + if (numFFPs == 0) { + return nullptr; + } + + uassert( + 6334102, + "If any elements in an comparison expression are encrypted, then all elements should " + "be encrypted.", + numFFPs == equalitiesList.size()); + + auto leftFieldPath = dynamic_cast<const ExpressionFieldPath*>(leftExpr); + uassert(6672417, + "$in is only supported with Queryable Encryption when the first argument is a " + "field path", + leftFieldPath != nullptr); + + if (!queryRewriter->isForceHighCardinality()) { + try { + for (auto& equality : equalitiesList) { + // For each expression representing a FleFindPayload... + if (auto constChild = dynamic_cast<ExpressionConstant*>(equality.get())) { + // ... rewrite the payload to a list of tags... + auto tags = queryRewriter->rewritePayloadAsTags(constChild->getValue()); + for (auto&& tagElt : tags) { + // ... and for each tag, construct expression {$in: [tag, + // "$__safeContent__"]}. + std::vector<boost::intrusive_ptr<Expression>> inVec{ + ExpressionConstant::create(queryRewriter->expCtx(), tagElt), + ExpressionFieldPath::createPathFromString( + queryRewriter->expCtx(), + kSafeContent, + queryRewriter->expCtx()->variablesParseState)}; + orListElems.push_back(make_intrusive<ExpressionIn>( + queryRewriter->expCtx(), std::move(inVec))); + } + } + } + + didRewrite = true; + + return std::make_unique<ExpressionOr>(queryRewriter->expCtx(), + std::move(orListElems)); + } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) { + LOGV2_DEBUG(6672403, + 2, + "FLE Max tag limit hit during aggregation $in rewrite", + "__error__"_attr = ex.what()); + + if (queryRewriter->getHighCardinalityMode() != + FLEQueryRewriter::HighCardinalityMode::kUseIfNeeded) { + throw; + } + + // fall through + } + } + + for (auto& equality : equalitiesList) { + if (auto constChild = dynamic_cast<ExpressionConstant*>(equality.get())) { + auto fleEqExpr = generateFleEqualMatch( + leftFieldPath->getFieldPathWithoutCurrentPrefix().fullPath(), + constChild->getValue(), + queryRewriter->expCtx()); + orListElems.push_back(fleEqExpr); + } + } + + didRewrite = true; + return std::make_unique<ExpressionOr>(queryRewriter->expCtx(), std::move(orListElems)); + } + + // Rewrite a [$eq : [$fieldpath, constant]] or [$eq: [constant, $fieldpath]] + // to _internalFleEq: {field: $fieldpath, edc: edcToken, counter: N, server: serverToken} + std::unique_ptr<Expression> rewriteComparisonsToEncryptedField( + const std::vector<boost::intrusive_ptr<Expression>>& equalitiesList) { + + auto leftConstant = dynamic_cast<ExpressionConstant*>(equalitiesList[0].get()); + auto rightConstant = dynamic_cast<ExpressionConstant*>(equalitiesList[1].get()); + + bool isLeftFFP = leftConstant && queryRewriter->isFleFindPayload(leftConstant->getValue()); + bool isRightFFP = + rightConstant && queryRewriter->isFleFindPayload(rightConstant->getValue()); + + uassert(6334100, + "Cannot compare two encrypted constants to each other", + !(isLeftFFP && isRightFFP)); + + // No FLE Find Payload + if (!isLeftFFP && !isRightFFP) { + return nullptr; + } + + auto leftFieldPath = dynamic_cast<ExpressionFieldPath*>(equalitiesList[0].get()); + auto rightFieldPath = dynamic_cast<ExpressionFieldPath*>(equalitiesList[1].get()); + + uassert( + 6672413, + "Queryable Encryption only supports comparisons between a field path and a constant", + leftFieldPath || rightFieldPath); + + auto fieldPath = leftFieldPath ? leftFieldPath : rightFieldPath; + auto constChild = isLeftFFP ? leftConstant : rightConstant; + + if (!queryRewriter->isForceHighCardinality()) { + try { + std::vector<boost::intrusive_ptr<Expression>> orListElems; + auto tags = queryRewriter->rewritePayloadAsTags(constChild->getValue()); for (auto&& tagElt : tags) { - // ... and for each tag, construct expression {$in: [tag, "$__safeContent__"]}. + // ... and for each tag, construct expression {$in: [tag, + // "$__safeContent__"]}. std::vector<boost::intrusive_ptr<Expression>> inVec{ ExpressionConstant::create(queryRewriter->expCtx(), tagElt), ExpressionFieldPath::createPathFromString( @@ -154,21 +322,33 @@ public: orListElems.push_back( make_intrusive<ExpressionIn>(queryRewriter->expCtx(), std::move(inVec))); } + + didRewrite = true; + return std::make_unique<ExpressionOr>(queryRewriter->expCtx(), + std::move(orListElems)); + + } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) { + LOGV2_DEBUG(6672409, + 2, + "FLE Max tag limit hit during query $in rewrite", + "__error__"_attr = ex.what()); + + if (queryRewriter->getHighCardinalityMode() != + FLEQueryRewriter::HighCardinalityMode::kUseIfNeeded) { + throw; + } + + // fall through } } - // Finally, construct an $or of all of the $ins. - if (numFFPs == 0) { - return nullptr; - } - uassert( - 6334102, - "If any elements in an comparison expression are encrypted, then all elements should " - "be encrypted.", - numFFPs == equalitiesList.size()); + auto fleEqExpr = + generateFleEqualMatchUnique(fieldPath->getFieldPathWithoutCurrentPrefix().fullPath(), + constChild->getValue(), + queryRewriter->expCtx()); didRewrite = true; - return std::make_unique<ExpressionOr>(queryRewriter->expCtx(), std::move(orListElems)); + return fleEqExpr; } std::unique_ptr<Expression> postVisit(Expression* exp) { @@ -177,30 +357,28 @@ public: // ignored when rewrites are done; there is no extra information in that child that // doesn't exist in the FFPs in the $in list. if (auto inList = dynamic_cast<ExpressionArray*>(inExpr->getOperandList()[1].get())) { - return rewriteComparisonsToEncryptedField(inList->getChildren()); + return rewriteInToEncryptedField(inExpr->getOperandList()[0].get(), + inList->getChildren()); } } else if (auto eqExpr = dynamic_cast<ExpressionCompare*>(exp); eqExpr && (eqExpr->getOp() == ExpressionCompare::EQ || eqExpr->getOp() == ExpressionCompare::NE)) { // Rewrite an $eq comparing an encrypted field and an encrypted constant to an $or. - // Either child may be the constant, so try rewriting both. - auto or0 = rewriteComparisonsToEncryptedField({eqExpr->getChildren()[0]}); - auto or1 = rewriteComparisonsToEncryptedField({eqExpr->getChildren()[1]}); - uassert(6334100, "Cannot compare two encrypted constants to each other", !or0 || !or1); + auto newExpr = rewriteComparisonsToEncryptedField(eqExpr->getChildren()); // Neither child is an encrypted constant, and no rewriting needs to be done. - if (!or0 && !or1) { + if (!newExpr) { return nullptr; } // Exactly one child was an encrypted constant. The other child can be ignored; there is // no extra information in that child that doesn't exist in the FFP. if (eqExpr->getOp() == ExpressionCompare::NE) { - std::vector<boost::intrusive_ptr<Expression>> notChild{(or0 ? or0 : or1).release()}; + std::vector<boost::intrusive_ptr<Expression>> notChild{newExpr.release()}; return std::make_unique<ExpressionNot>(queryRewriter->expCtx(), std::move(notChild)); } - return std::move(or0 ? or0 : or1); + return newExpr; } return nullptr; @@ -213,11 +391,14 @@ public: BSONObj rewriteEncryptedFilter(const FLEStateCollectionReader& escReader, const FLEStateCollectionReader& eccReader, boost::intrusive_ptr<ExpressionContext> expCtx, - BSONObj filter) { + BSONObj filter, + HighCardinalityModeAllowed mode) { + if (auto rewritten = - FLEQueryRewriter(expCtx, escReader, eccReader).rewriteMatchExpression(filter)) { + FLEQueryRewriter(expCtx, escReader, eccReader, mode).rewriteMatchExpression(filter)) { return rewritten.get(); } + return filter; } @@ -273,16 +454,18 @@ public: FilterRewrite(boost::intrusive_ptr<ExpressionContext> expCtx, const NamespaceString& nss, const EncryptionInformation& encryptInfo, - const BSONObj toRewrite) - : RewriteBase(expCtx, nss, encryptInfo), userFilter(toRewrite) {} + const BSONObj toRewrite, + HighCardinalityModeAllowed mode) + : RewriteBase(expCtx, nss, encryptInfo), userFilter(toRewrite), _mode(mode) {} ~FilterRewrite(){}; void doRewrite(FLEStateCollectionReader& escReader, FLEStateCollectionReader& eccReader) final { - rewrittenFilter = rewriteEncryptedFilter(escReader, eccReader, expCtx, userFilter); + rewrittenFilter = rewriteEncryptedFilter(escReader, eccReader, expCtx, userFilter, _mode); } const BSONObj userFilter; BSONObj rewrittenFilter; + HighCardinalityModeAllowed _mode; }; // This helper executes the rewrite(s) inside a transaction. The transaction runs in a separate @@ -324,7 +507,8 @@ BSONObj rewriteEncryptedFilterInsideTxn(FLEQueryInterface* queryImpl, StringData db, const EncryptedFieldConfig& efc, boost::intrusive_ptr<ExpressionContext> expCtx, - BSONObj filter) { + BSONObj filter, + HighCardinalityModeAllowed mode) { auto makeCollectionReader = [&](FLEQueryInterface* queryImpl, const StringData& coll) { NamespaceString nss(db, coll); auto docCount = queryImpl->countDocuments(nss); @@ -332,7 +516,8 @@ BSONObj rewriteEncryptedFilterInsideTxn(FLEQueryInterface* queryImpl, }; auto escReader = makeCollectionReader(queryImpl, efc.getEscCollection().get()); auto eccReader = makeCollectionReader(queryImpl, efc.getEccCollection().get()); - return rewriteEncryptedFilter(escReader, eccReader, expCtx, filter); + + return rewriteEncryptedFilter(escReader, eccReader, expCtx, filter, mode); } BSONObj rewriteQuery(OperationContext* opCtx, @@ -340,8 +525,9 @@ BSONObj rewriteQuery(OperationContext* opCtx, const NamespaceString& nss, const EncryptionInformation& info, BSONObj filter, - GetTxnCallback getTransaction) { - auto sharedBlock = std::make_shared<FilterRewrite>(expCtx, nss, info, filter); + GetTxnCallback getTransaction, + HighCardinalityModeAllowed mode) { + auto sharedBlock = std::make_shared<FilterRewrite>(expCtx, nss, info, filter, mode); doFLERewriteInTxn(opCtx, sharedBlock, getTransaction); return sharedBlock->rewrittenFilter.getOwned(); } @@ -365,7 +551,8 @@ void processFindCommand(OperationContext* opCtx, nss, findCommand->getEncryptionInformation().get(), findCommand->getFilter().getOwned(), - getTransaction)); + getTransaction, + HighCardinalityModeAllowed::kAllow)); // The presence of encryptionInformation is a signal that this is a FLE request that requires // special processing. Once we've rewritten the query, it's no longer a "special" FLE query, but // a normal query that can be executed by the query system like any other, so remove @@ -389,7 +576,8 @@ void processCountCommand(OperationContext* opCtx, nss, countCommand->getEncryptionInformation().get(), countCommand->getQuery().getOwned(), - getTxn)); + getTxn, + HighCardinalityModeAllowed::kAllow)); // The presence of encryptionInformation is a signal that this is a FLE request that requires // special processing. Once we've rewritten the query, it's no longer a "special" FLE query, but // a normal query that can be executed by the query system like any other, so remove @@ -504,59 +692,112 @@ std::vector<Value> FLEQueryRewriter::rewritePayloadAsTags(Value fleFindPayload) return tagVec; } -std::unique_ptr<InMatchExpression> FLEQueryRewriter::rewriteEq( - const EqualityMatchExpression* expr) { + +std::unique_ptr<MatchExpression> FLEQueryRewriter::rewriteEq(const EqualityMatchExpression* expr) { auto ffp = expr->getData(); if (!isFleFindPayload(ffp)) { return nullptr; } - auto obj = rewritePayloadAsTags(ffp); - - auto tags = std::vector<BSONElement>(); - obj.elems(tags); + if (_mode != HighCardinalityMode::kForceAlways) { + try { + auto obj = rewritePayloadAsTags(ffp); + + auto tags = std::vector<BSONElement>(); + obj.elems(tags); + + auto inExpr = std::make_unique<InMatchExpression>(kSafeContent); + inExpr->setBackingBSON(std::move(obj)); + auto status = inExpr->setEqualities(std::move(tags)); + uassertStatusOK(status); + _rewroteLastExpression = true; + return inExpr; + } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) { + LOGV2_DEBUG(6672410, + 2, + "FLE Max tag limit hit during query $eq rewrite", + "__error__"_attr = ex.what()); + + if (_mode != HighCardinalityMode::kUseIfNeeded) { + throw; + } - auto inExpr = std::make_unique<InMatchExpression>(kSafeContent); - inExpr->setBackingBSON(std::move(obj)); - auto status = inExpr->setEqualities(std::move(tags)); - uassertStatusOK(status); + // fall through + } + } + auto exprMatch = generateFleEqualMatchAndExpr(expr->path(), ffp, _expCtx.get()); _rewroteLastExpression = true; - return inExpr; + return exprMatch; } -std::unique_ptr<InMatchExpression> FLEQueryRewriter::rewriteIn(const InMatchExpression* expr) { - auto backingBSONBuilder = BSONArrayBuilder(); +std::unique_ptr<MatchExpression> FLEQueryRewriter::rewriteIn(const InMatchExpression* expr) { size_t numFFPs = 0; for (auto& eq : expr->getEqualities()) { if (isFleFindPayload(eq)) { - auto obj = rewritePayloadAsTags(eq); ++numFFPs; - for (auto&& elt : obj) { - backingBSONBuilder.append(elt); - } } } + if (numFFPs == 0) { return nullptr; } + // All elements in an encrypted $in expression should be FFPs. uassert( 6329400, "If any elements in a $in expression are encrypted, then all elements should be encrypted.", numFFPs == expr->getEqualities().size()); - auto backingBSON = backingBSONBuilder.arr(); - auto allTags = std::vector<BSONElement>(); - backingBSON.elems(allTags); + if (_mode != HighCardinalityMode::kForceAlways) { + + try { + auto backingBSONBuilder = BSONArrayBuilder(); + + for (auto& eq : expr->getEqualities()) { + auto obj = rewritePayloadAsTags(eq); + for (auto&& elt : obj) { + backingBSONBuilder.append(elt); + } + } - auto inExpr = std::make_unique<InMatchExpression>(kSafeContent); - inExpr->setBackingBSON(std::move(backingBSON)); - auto status = inExpr->setEqualities(std::move(allTags)); - uassertStatusOK(status); + auto backingBSON = backingBSONBuilder.arr(); + auto allTags = std::vector<BSONElement>(); + backingBSON.elems(allTags); + + auto inExpr = std::make_unique<InMatchExpression>(kSafeContent); + inExpr->setBackingBSON(std::move(backingBSON)); + auto status = inExpr->setEqualities(std::move(allTags)); + uassertStatusOK(status); + + _rewroteLastExpression = true; + return inExpr; + + } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) { + LOGV2_DEBUG(6672411, + 2, + "FLE Max tag limit hit during query $in rewrite", + "__error__"_attr = ex.what()); + + if (_mode != HighCardinalityMode::kUseIfNeeded) { + throw; + } + + // fall through + } + } + + std::vector<std::unique_ptr<MatchExpression>> matches; + matches.reserve(numFFPs); + + for (auto& eq : expr->getEqualities()) { + auto exprMatch = generateFleEqualMatchAndExpr(expr->path(), eq, _expCtx.get()); + matches.push_back(std::move(exprMatch)); + } + auto orExpr = std::make_unique<OrMatchExpression>(std::move(matches)); _rewroteLastExpression = true; - return inExpr; + return orExpr; } } // namespace mongo::fle diff --git a/src/mongo/db/query/fle/server_rewrite.h b/src/mongo/db/query/fle/server_rewrite.h index ed84ea283c5..bf02eeebd4e 100644 --- a/src/mongo/db/query/fle/server_rewrite.h +++ b/src/mongo/db/query/fle/server_rewrite.h @@ -31,7 +31,7 @@ #include <memory> -#include "boost/smart_ptr/intrusive_ptr.hpp" +#include <boost/smart_ptr/intrusive_ptr.hpp> #include "mongo/bson/bsonobj.h" #include "mongo/crypto/fle_crypto.h" @@ -47,6 +47,14 @@ class FLEQueryInterface; namespace fle { /** + * Low Selectivity rewrites use $expr which is not supported in all commands such as upserts. + */ +enum class HighCardinalityModeAllowed { + kAllow, + kDisallow, +}; + +/** * Make a collator object from its BSON representation. Useful when creating ExpressionContext * objects for parsing MatchExpressions as part of the server-side rewrite. */ @@ -62,7 +70,8 @@ BSONObj rewriteQuery(OperationContext* opCtx, const NamespaceString& nss, const EncryptionInformation& info, BSONObj filter, - GetTxnCallback getTransaction); + GetTxnCallback getTransaction, + HighCardinalityModeAllowed mode); /** * Process a find command with encryptionInformation in-place, rewriting the filter condition so @@ -100,11 +109,13 @@ std::unique_ptr<Pipeline, PipelineDeleter> processPipeline( * from inside an existing transaction using a FLEQueryInterface constructed from a * transaction client. */ -BSONObj rewriteEncryptedFilterInsideTxn(FLEQueryInterface* queryImpl, - StringData db, - const EncryptedFieldConfig& efc, - boost::intrusive_ptr<ExpressionContext> expCtx, - BSONObj filter); +BSONObj rewriteEncryptedFilterInsideTxn( + FLEQueryInterface* queryImpl, + StringData db, + const EncryptedFieldConfig& efc, + boost::intrusive_ptr<ExpressionContext> expCtx, + BSONObj filter, + HighCardinalityModeAllowed mode = HighCardinalityModeAllowed::kDisallow); /** * Class which handles rewriting filter MatchExpressions for FLE2. The functionality is encapsulated @@ -116,14 +127,37 @@ BSONObj rewriteEncryptedFilterInsideTxn(FLEQueryInterface* queryImpl, */ class FLEQueryRewriter { public: + enum class HighCardinalityMode { + // Always use high cardinality filters, used by tests + kForceAlways, + + // Use high cardinality mode if $in rewrites do not fit in the + // internalQueryFLERewriteMemoryLimit memory limit + kUseIfNeeded, + + // Do not rewrite into high cardinality filter, throw exceptions instead + // Some contexts like upsert do not support $expr + kDisallow, + }; + /** * Takes in references to collection readers for the ESC and ECC that are used during tag * computation. */ FLEQueryRewriter(boost::intrusive_ptr<ExpressionContext> expCtx, const FLEStateCollectionReader& escReader, - const FLEStateCollectionReader& eccReader) + const FLEStateCollectionReader& eccReader, + HighCardinalityModeAllowed mode = HighCardinalityModeAllowed::kAllow) : _expCtx(expCtx), _escReader(&escReader), _eccReader(&eccReader) { + + if (internalQueryFLEAlwaysUseHighCardinalityMode.load()) { + _mode = HighCardinalityMode::kForceAlways; + } + + if (mode == HighCardinalityModeAllowed::kDisallow) { + _mode = HighCardinalityMode::kDisallow; + } + // This isn't the "real" query so we don't want to increment Expression // counters here. _expCtx->stopExpressionCounters(); @@ -184,6 +218,18 @@ public: return _expCtx.get(); } + bool isForceHighCardinality() const { + return _mode == HighCardinalityMode::kForceAlways; + } + + void setForceHighCardinalityForTest() { + _mode = HighCardinalityMode::kForceAlways; + } + + HighCardinalityMode getHighCardinalityMode() const { + return _mode; + } + protected: // This constructor should only be used for mocks in testing. FLEQueryRewriter(boost::intrusive_ptr<ExpressionContext> expCtx) @@ -196,8 +242,8 @@ private: std::unique_ptr<MatchExpression> _rewrite(MatchExpression* me); virtual BSONObj rewritePayloadAsTags(BSONElement fleFindPayload) const; - std::unique_ptr<InMatchExpression> rewriteEq(const EqualityMatchExpression* expr); - std::unique_ptr<InMatchExpression> rewriteIn(const InMatchExpression* expr); + std::unique_ptr<MatchExpression> rewriteEq(const EqualityMatchExpression* expr); + std::unique_ptr<MatchExpression> rewriteIn(const InMatchExpression* expr); boost::intrusive_ptr<ExpressionContext> _expCtx; @@ -208,6 +254,9 @@ private: // True if the last Expression or MatchExpression processed by this rewriter was rewritten. bool _rewroteLastExpression = false; + + // Controls how query rewriter rewrites the query + HighCardinalityMode _mode{HighCardinalityMode::kUseIfNeeded}; }; diff --git a/src/mongo/db/query/fle/server_rewrite_test.cpp b/src/mongo/db/query/fle/server_rewrite_test.cpp index cb81656dcb6..034de8f0aa9 100644 --- a/src/mongo/db/query/fle/server_rewrite_test.cpp +++ b/src/mongo/db/query/fle/server_rewrite_test.cpp @@ -31,7 +31,9 @@ #include <memory> #include "mongo/bson/bsonelement.h" +#include "mongo/bson/bsonmisc.h" #include "mongo/bson/bsonobjbuilder.h" +#include "mongo/bson/bsontypes.h" #include "mongo/db/matcher/expression_leaf.h" #include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/query/fle/server_rewrite.h" @@ -42,9 +44,19 @@ namespace mongo { namespace { -class MockFLEQueryRewriter : public fle::FLEQueryRewriter { +class BasicMockFLEQueryRewriter : public fle::FLEQueryRewriter { public: - MockFLEQueryRewriter() : fle::FLEQueryRewriter(new ExpressionContextForTest()), _tags() {} + BasicMockFLEQueryRewriter() : fle::FLEQueryRewriter(new ExpressionContextForTest()) {} + + BSONObj rewriteMatchExpressionForTest(const BSONObj& obj) { + auto res = rewriteMatchExpression(obj); + return res ? res.get() : obj; + } +}; + +class MockFLEQueryRewriter : public BasicMockFLEQueryRewriter { +public: + MockFLEQueryRewriter() : _tags() {} bool isFleFindPayload(const BSONElement& fleFindPayload) const override { return _encryptedFields.find(fleFindPayload.fieldNameStringData()) != @@ -56,11 +68,6 @@ public: _tags[fieldvalue] = tags; } - BSONObj rewriteMatchExpressionForTest(const BSONObj& obj) { - auto res = rewriteMatchExpression(obj); - return res ? res.get() : obj; - } - private: BSONObj rewritePayloadAsTags(BSONElement fleFindPayload) const override { ASSERT(fleFindPayload.isNumber()); // Only accept numbers as mock FFPs. @@ -72,6 +79,7 @@ private: std::map<std::pair<StringData, int>, BSONObj> _tags; std::set<StringData> _encryptedFields; }; + class FLEServerRewriteTest : public unittest::Test { public: FLEServerRewriteTest() {} @@ -361,5 +369,290 @@ TEST_F(FLEServerRewriteTest, ComparisonToObjectIgnored) { } } +template <typename T> +std::vector<uint8_t> toEncryptedVector(EncryptedBinDataType dt, T t) { + BSONObj obj = t.toBSON(); + + std::vector<uint8_t> buf(obj.objsize() + 1); + buf[0] = static_cast<uint8_t>(dt); + + std::copy(obj.objdata(), obj.objdata() + obj.objsize(), buf.data() + 1); + + return buf; +} + +template <typename T> +void toEncryptedBinData(StringData field, EncryptedBinDataType dt, T t, BSONObjBuilder* builder) { + auto buf = toEncryptedVector(dt, t); + + builder->appendBinData(field, buf.size(), BinDataType::Encrypt, buf.data()); +} + +constexpr auto kIndexKeyId = "12345678-1234-9876-1234-123456789012"_sd; +constexpr auto kUserKeyId = "ABCDEFAB-1234-9876-1234-123456789012"_sd; +static UUID indexKeyId = uassertStatusOK(UUID::parse(kIndexKeyId.toString())); +static UUID userKeyId = uassertStatusOK(UUID::parse(kUserKeyId.toString())); + +std::vector<char> testValue = {0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19}; +std::vector<char> testValue2 = {0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29}; + +const FLEIndexKey& getIndexKey() { + static std::string indexVec = hexblob::decode( + "7dbfebc619aa68a659f64b8e23ccd21644ac326cb74a26840c3d2420176c40ae088294d00ad6cae9684237b21b754cf503f085c25cd320bf035c3417416e1e6fe3d9219f79586582112740b2add88e1030d91926ae8afc13ee575cfb8bb965b7"_sd); + static FLEIndexKey indexKey(KeyMaterial(indexVec.begin(), indexVec.end())); + return indexKey; +} + +const FLEUserKey& getUserKey() { + static std::string userVec = hexblob::decode( + "a7ddbc4c8be00d51f68d9d8e485f351c8edc8d2206b24d8e0e1816d005fbe520e489125047d647b0d8684bfbdbf09c304085ed086aba6c2b2b1677ccc91ced8847a733bf5e5682c84b3ee7969e4a5fe0e0c21e5e3ee190595a55f83147d8de2a"_sd); + static FLEUserKey userKey(KeyMaterial(userVec.begin(), userVec.end())); + return userKey; +} + + +BSONObj generateFFP(StringData path, int value) { + auto indexKey = getIndexKey(); + FLEIndexKeyAndId indexKeyAndId(indexKey.data, indexKeyId); + auto userKey = getUserKey(); + FLEUserKeyAndId userKeyAndId(userKey.data, indexKeyId); + + BSONObj doc = BSON("value" << value); + auto element = doc.firstElement(); + auto fpp = FLEClientCrypto::serializeFindPayload(indexKeyAndId, userKeyAndId, element, 0); + + BSONObjBuilder builder; + toEncryptedBinData(path, EncryptedBinDataType::kFLE2FindEqualityPayload, fpp, &builder); + return builder.obj(); +} + +class FLEServerHighCardRewriteTest : public unittest::Test { +public: + FLEServerHighCardRewriteTest() {} + + void setUp() override {} + + void tearDown() override {} + +protected: + BasicMockFLEQueryRewriter _mock; +}; + + +TEST_F(FLEServerHighCardRewriteTest, HighCard_TopLevel_Equality) { + _mock.setForceHighCardinalityForTest(); + + auto match = generateFFP("ssn", 1); + auto expected = fromjson(R"({ + "$expr": { + "$_internalFleEq": { + "field": "$ssn", + "edc": { + "$binary": { + "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", + "subType": "6" + } + }, + "counter": { + "$numberLong": "0" + }, + "server": { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + } + } + } +})"); + + auto actual = _mock.rewriteMatchExpressionForTest(match); + ASSERT_BSONOBJ_EQ(actual, expected); +} + + +TEST_F(FLEServerHighCardRewriteTest, HighCard_TopLevel_In) { + _mock.setForceHighCardinalityForTest(); + + auto ffp1 = generateFFP("ssn", 1); + auto ffp2 = generateFFP("ssn", 2); + auto ffp3 = generateFFP("ssn", 3); + auto expected = fromjson(R"({ + "$or": [ + { + "$expr": { + "$_internalFleEq": { + "field": "$ssn", + "edc": { + "$binary": { + "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", + "subType": "6" + } + }, + "counter": { + "$numberLong": "0" + }, + "server": { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + } + } + } + }, + { + "$expr": { + "$_internalFleEq": { + "field": "$ssn", + "edc": { + "$binary": { + "base64": "CLpCo6rNuYMVT+6n1HCX15MNrVYDNqf6udO46ayo43Sw", + "subType": "6" + } + }, + "counter": { + "$numberLong": "0" + }, + "server": { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + } + } + } + }, + { + "$expr": { + "$_internalFleEq": { + "field": "$ssn", + "edc": { + "$binary": { + "base64": "CPi44oCQHnNDeRqHsNLzbdCeHt2DK/wCly0g2dxU5fqN", + "subType": "6" + } + }, + "counter": { + "$numberLong": "0" + }, + "server": { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + } + } + } + } + ] +})"); + + auto match = + BSON("ssn" << BSON("$in" << BSON_ARRAY(ffp1.firstElement() + << ffp2.firstElement() << ffp3.firstElement()))); + + auto actual = _mock.rewriteMatchExpressionForTest(match); + ASSERT_BSONOBJ_EQ(actual, expected); +} + + +TEST_F(FLEServerHighCardRewriteTest, HighCard_TopLevel_Expr) { + + _mock.setForceHighCardinalityForTest(); + + auto ffp = generateFFP("$ssn", 1); + int len; + auto v = ffp.firstElement().binDataClean(len); + auto match = BSON("$expr" << BSON("$eq" << BSON_ARRAY(ffp.firstElement().fieldName() + << BSONBinData(v, len, Encrypt)))); + + auto expected = fromjson(R"({ "$expr": { + "$_internalFleEq": { + "field": "$ssn", + "edc": { + "$binary": { + "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", + "subType": "6" + } + }, + "counter": { + "$numberLong": "0" + }, + "server": { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + } + } + } + })"); + + auto actual = _mock.rewriteMatchExpressionForTest(match); + ASSERT_BSONOBJ_EQ(actual, expected); +} + +TEST_F(FLEServerHighCardRewriteTest, HighCard_TopLevel_Expr_In) { + + _mock.setForceHighCardinalityForTest(); + + auto ffp = generateFFP("$ssn", 1); + int len; + auto v = ffp.firstElement().binDataClean(len); + + auto ffp2 = generateFFP("$ssn", 1); + int len2; + auto v2 = ffp2.firstElement().binDataClean(len2); + + auto match = BSON( + "$expr" << BSON("$in" << BSON_ARRAY(ffp.firstElement().fieldName() + << BSON_ARRAY(BSONBinData(v, len, Encrypt) + << BSONBinData(v2, len2, Encrypt))))); + + auto expected = fromjson(R"({ "$expr": { "$or" : [ { + "$_internalFleEq": { + "field": "$ssn", + "edc": { + "$binary": { + "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", + "subType": "6" + } + }, + "counter": { + "$numberLong": "0" + }, + "server": { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + } + }}, + { + "$_internalFleEq": { + "field": "$ssn", + "edc": { + "$binary": { + "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", + "subType": "6" + } + }, + "counter": { + "$numberLong": "0" + }, + "server": { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + } + }} + ]}})"); + + auto actual = _mock.rewriteMatchExpressionForTest(match); + ASSERT_BSONOBJ_EQ(actual, expected); +} + } // namespace } // namespace mongo diff --git a/src/mongo/db/query/query_knobs.idl b/src/mongo/db/query/query_knobs.idl index 18851f0ddb9..f894629037f 100644 --- a/src/mongo/db/query/query_knobs.idl +++ b/src/mongo/db/query/query_knobs.idl @@ -863,6 +863,14 @@ server_parameters: gt: 0 lt: 16777216 + internalQueryFLEAlwaysUseHighCardinalityMode: + description: "Boolean flag to force FLE to always use low selectivity mode" + set_at: [ startup, runtime ] + cpp_varname: "internalQueryFLEAlwaysUseHighCardinalityMode" + cpp_vartype: AtomicWord<bool> + default: + expr: false + # Note for adding additional query knobs: # # When adding a new query knob, you should consider whether or not you need to add an 'on_update' diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp index 89541a241bc..cbeac015678 100644 --- a/src/mongo/db/query/sbe_stage_builder_expression.cpp +++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp @@ -377,6 +377,7 @@ public: void visit(const ExpressionLn* expr) final {} void visit(const ExpressionLog* expr) final {} void visit(const ExpressionLog10* expr) final {} + void visit(const ExpressionInternalFLEEqual* expr) final {} void visit(const ExpressionMap* expr) final {} void visit(const ExpressionMeta* expr) final {} void visit(const ExpressionMod* expr) final {} @@ -609,6 +610,7 @@ public: void visit(const ExpressionLn* expr) final {} void visit(const ExpressionLog* expr) final {} void visit(const ExpressionLog10* expr) final {} + void visit(const ExpressionInternalFLEEqual* expr) final {} void visit(const ExpressionMap* expr) final {} void visit(const ExpressionMeta* expr) final {} void visit(const ExpressionMod* expr) final {} @@ -2317,6 +2319,9 @@ public: _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(log10Expr))); } + void visit(const ExpressionInternalFLEEqual* expr) final { + unsupportedExpression("$_internalFleEq"); + } void visit(const ExpressionMap* expr) final { unsupportedExpression("$map"); } diff --git a/src/mongo/idl/basic_types.idl b/src/mongo/idl/basic_types.idl index 07a8e5fbf00..883c7a61e34 100644 --- a/src/mongo/idl/basic_types.idl +++ b/src/mongo/idl/basic_types.idl @@ -157,6 +157,13 @@ types: cpp_type: "std::array<std::uint8_t, 16>" deserializer: "mongo::BSONElement::uuid" + bindata_encrypt: + bson_serialization_type: bindata + bindata_subtype: encrypt + description: "A BSON bindata of encrypt sub type" + cpp_type: "std::vector<std::uint8_t>" + deserializer: "mongo::BSONElement::_binDataVector" + uuid: bson_serialization_type: bindata bindata_subtype: uuid @@ -256,6 +263,7 @@ types: cpp_type: "mongo::IDLAnyTypeOwned" serializer: mongo::IDLAnyTypeOwned::serializeToBSON deserializer: mongo::IDLAnyTypeOwned::parseFromBSON + tenant_id: bson_serialization_type: any description: "A struct representing a tenant id" |